Update app.py
Browse files
app.py
CHANGED
|
@@ -4,25 +4,72 @@ import datetime
|
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import whisperx
|
| 7 |
-
from whisperx.diarize import DiarizationPipeline
|
| 8 |
import gradio as gr
|
| 9 |
|
| 10 |
-
# Выб
|
| 11 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 12 |
|
| 13 |
-
# Загру
|
| 14 |
model = whisperx.load_model("small", device=device, compute_type="float32")
|
| 15 |
-
|
| 16 |
-
# Загружаем пайплайн диаризации
|
| 17 |
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN", None)
|
| 18 |
diarize_pipeline = DiarizationPipeline(use_auth_token=hf_token, device=device)
|
| 19 |
|
|
|
|
| 20 |
def transcribe_with_diarization(audio_path):
|
| 21 |
-
# 1) Транскрипция
|
| 22 |
result = model.transcribe(audio_path)
|
| 23 |
|
| 24 |
-
# 2) Выравнивание точных времён слов
|
| 25 |
align_model, metadata = whisperx.load_align_model(
|
| 26 |
language_code=result["language"], device=device
|
| 27 |
)
|
| 28 |
-
result = whisperx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import whisperx
|
| 7 |
+
from whisperx.diarize import DiarizationPipeline
|
| 8 |
import gradio as gr
|
| 9 |
|
| 10 |
+
# 1) Выбор устройства
|
| 11 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 12 |
|
| 13 |
+
# 2) Загрузка моделей
|
| 14 |
model = whisperx.load_model("small", device=device, compute_type="float32")
|
|
|
|
|
|
|
| 15 |
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN", None)
|
| 16 |
diarize_pipeline = DiarizationPipeline(use_auth_token=hf_token, device=device)
|
| 17 |
|
| 18 |
+
# 3) Основная функция транскрибации + диаризации
|
| 19 |
def transcribe_with_diarization(audio_path):
|
|
|
|
| 20 |
result = model.transcribe(audio_path)
|
| 21 |
|
|
|
|
| 22 |
align_model, metadata = whisperx.load_align_model(
|
| 23 |
language_code=result["language"], device=device
|
| 24 |
)
|
| 25 |
+
result = whisperx.align(
|
| 26 |
+
segments=result["segments"],
|
| 27 |
+
align_model=align_model,
|
| 28 |
+
metadata=metadata,
|
| 29 |
+
audio_path=audio_path,
|
| 30 |
+
device=device
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
diarize_result = diarize_pipeline(audio_path)
|
| 34 |
+
|
| 35 |
+
merged = whisperx.merge_text_with_diarization(
|
| 36 |
+
result["segments"], diarize_result["segments"]
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
lines = []
|
| 40 |
+
for seg in merged:
|
| 41 |
+
spk = seg.get("speaker", "Unknown")
|
| 42 |
+
txt = seg.get("text", "").strip()
|
| 43 |
+
lines.append(f"[{spk}] {txt}")
|
| 44 |
+
return "\n".join(lines)
|
| 45 |
+
|
| 46 |
+
# 4) Экспорт в .txt
|
| 47 |
+
def export_to_txt(text):
|
| 48 |
+
fname = f"transcript_{datetime.datetime.now():%Y%m%d_%H%M%S}.txt"
|
| 49 |
+
path = os.path.join(tempfile.gettempdir(), fname)
|
| 50 |
+
with open(path, "w", encoding="utf-8") as f:
|
| 51 |
+
f.write(text)
|
| 52 |
+
return path
|
| 53 |
+
|
| 54 |
+
# 5) Создаём интерфейс в переменной app
|
| 55 |
+
app = gr.Blocks(title="🎤 Транскрибация и диаризация")
|
| 56 |
+
|
| 57 |
+
with app:
|
| 58 |
+
gr.Markdown(
|
| 59 |
+
"## 🎙️ Audio → Text с разделением спикеров\n"
|
| 60 |
+
"Загрузите аудио, нажмите **Transcribe**, отредактируйте имена спикеров при необходимости и "
|
| 61 |
+
"скачайте результат в `.txt`."
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
audio_input = gr.Audio(type="filepath", label="Загрузить аудио")
|
| 65 |
+
btn_trans = gr.Button("▶️ Transcribe")
|
| 66 |
+
txt_out = gr.Textbox(lines=20, label="Транскрипция + Спикеры")
|
| 67 |
+
btn_save = gr.Button("💾 Скачать .txt")
|
| 68 |
+
file_out = gr.File(label="Файл для скачивания")
|
| 69 |
+
|
| 70 |
+
btn_trans.click(fn=transcribe_with_diarization, inputs=audio_input, outputs=txt_out)
|
| 71 |
+
btn_save.click(fn=export_to_txt, inputs=txt_out, outputs=file_out)
|
| 72 |
+
|
| 73 |
+
# 6) Запуск
|
| 74 |
+
if __name__ == "__main__":
|
| 75 |
+
app.launch()
|