PacoFYM committed on
Commit
310e379
·
verified ·
1 Parent(s): 80c9ce8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -89
app.py CHANGED
@@ -1,95 +1,92 @@
1
  import os
2
- import tempfile
3
- import whisperx
4
  import torch
 
 
5
  import gradio as gr
6
 
7
- # 1. Устройство
8
- device = "cuda" if torch.cuda.is_available() else "cpu"
9
-
10
- # 2. Загрузка моделей
11
- asr_model = whisperx.load_model("small", device)
12
- hf_token = os.getenv("HF_TOKEN", None)
13
- diarize_pipeline = whisperx.DiarizationPipeline(use_auth_token=hf_token, device=device)
14
-
15
- def transcribe_and_prepare(audio_path):
16
- # ASR (жёстко русский)
17
- result = asr_model.transcribe(audio_path, language="ru")
18
-
19
- # Alignment
20
- aligned = whisperx.align(
21
- result["segments"], asr_model, audio_path, device=device
22
- )
23
-
24
- # Diarization
25
- diarization = diarize_pipeline(audio_path)
26
- segments = whisperx.diarize(aligned, diarization)
27
-
28
- # Подготовка для UI: возвращаем список dict-ов
29
- ui_data = []
30
- for i, seg in enumerate(segments):
31
- ui_data.append({
32
- "index": i,
33
- "speaker": seg["speaker"],
34
- "start": f"{seg['start']:.2f}",
35
- "end": f"{seg['end']:.2f}",
36
- "text": seg["text"]
37
- })
38
- return ui_data
39
-
40
- def generate_download(ui_data):
41
- # Формируем итоговый TXT
42
- lines = []
43
- for row in ui_data:
44
- lines.append(f"[{row['speaker']}] ({row['start']}-{row['end']}): {row['text']}")
45
- txt = "\n".join(lines)
46
- path = os.path.join(tempfile.gettempdir(), "transcript.txt")
47
- with open(path, "w", encoding="utf-8") as f:
48
- f.write(txt)
49
- return path
50
-
51
- # 3. Интерфейс
52
- with gr.Blocks(css="""
53
- .gradio-container { max-width: 900px; margin: auto; }
54
- @media (max-width: 600px) {
55
- .gradio-container { padding: 0 10px; }
56
- }
57
- """) as demo:
58
-
59
- gr.Markdown("## 🎤 Транскрибация и диаризация аудио (русский)")
60
- audio_in = gr.Audio(label="Загрузите аудио", type="filepath")
61
- btn = gr.Button("Запустить транскрибацию")
62
-
63
- # Таблица сегментов для ручной правки
64
- table = gr.Dataframe(
65
- headers=["index","speaker","start","end","text"],
66
- datatype=["number","text","text","text","text"],
67
- interactive=True,
68
- row_count=(1, None),
69
- col_count=5,
70
- wrap=True,
71
- label="Сегменты (можно править спикера и текст)"
72
- )
73
-
74
- download_btn = gr.Button("Скачать итоговый TXT")
75
- download_txt = gr.File(label="Итоговый файл")
76
-
77
- # Связываем
78
- btn.click(fn=transcribe_and_prepare, inputs=[audio_in], outputs=[table])
79
- download_btn.click(fn=generate_download, inputs=[table], outputs=[download_txt])
80
-
81
- # Плейер для выбранного сегмента
82
- with gr.Row():
83
- idx_in = gr.Number(value=0, label="Номер сегмента для прослушивания")
84
- play_btn = gr.Button("▶️ Прослушать сегмент")
85
- player = gr.Audio(label="Плеер сегмента")
86
-
87
- def play_segment(audio_path, ui_data, idx):
88
- seg = ui_data[int(idx)]
89
- start, end = float(seg["start"]), float(seg["end"])
90
- return {"filepath": audio_path, "start_time": start, "end_time": end}
91
-
92
- play_btn.click(fn=play_segment, inputs=[audio_in, table, idx_in], outputs=[player])
93
 
94
  if __name__ == "__main__":
95
- demo.launch()
 
1
  import os
 
 
2
  import torch
3
+ import whisperx
4
+ from pyannote.audio import Pipeline
5
  import gradio as gr
6
 
7
+ def create_app():
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ hf_token = os.getenv("HF_TOKEN") or os.getenv("hf_token") or ""
10
+
11
+ with gr.Blocks() as app:
12
+ gr.Markdown("<h1>Транскрипция и диаризация аудио</h1>")
13
+ gr.Markdown("Загрузите аудиофайл и нажмите **Транскрибировать**. После обработки вы сможете прослушать сегменты, отредактировать текст и присвоить имена спикерам.")
14
+ audio_input = gr.Audio(label="Аудиофайл", source="upload", type="filepath")
15
+ transcribe_btn = gr.Button("Транскрибировать")
16
+
17
+ save_btn = gr.Button("Сохранить результат")
18
+ output_file = gr.File(label="Скачайте результат (.txt)")
19
+
20
+ @gr.render(inputs=[audio_input], triggers=[transcribe_btn])
21
+ def process(audio_path):
22
+ if not audio_path:
23
+ return
24
+ # 1. WhisperX transcription
25
+ model = whisperx.load_model("small", device, compute_type="float32")
26
+ audio_array = whisperx.load_audio(audio_path)
27
+ result = model.transcribe(audio_array, batch_size=16, language="ru")
28
+ model_a, metadata = whisperx.load_align_model(language_code="ru", device=device)
29
+ result = whisperx.align(result["segments"], model_a, metadata, audio_array, device=device, return_char_alignments=False)
30
+ # 2. Speaker diarization
31
+ pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=hf_token)
32
+ pipeline.to(device)
33
+ diarization = pipeline(audio_path)
34
+ result = whisperx.assign_word_speakers(diarization, result)
35
+ segments = result["segments"]
36
+ # Unique speakers
37
+ speakers = sorted({seg["speaker"] for seg in segments})
38
+
39
+ # Input fields for speaker names
40
+ if speakers:
41
+ gr.Markdown("**Имена спикеров:**")
42
+ speaker_name_inputs = []
43
+ for spk in speakers:
44
+ tb = gr.Textbox(label=f"Спикер {spk}", value=f"Спикер {spk}", interactive=True, key=f"name_{spk}")
45
+ speaker_name_inputs.append(tb)
46
+
47
+ # Load audio for slicing segments
48
+ try:
49
+ import torchaudio
50
+ waveform, sample_rate = torchaudio.load(audio_path)
51
+ except Exception:
52
+ waveform, sample_rate = None, None
53
+
54
+ transcripts = []
55
+ transcript_text_inputs = []
56
+ # Render each segment
57
+ for i, seg in enumerate(segments):
58
+ speaker = seg["speaker"]
59
+ text = seg["text"]
60
+ transcripts.append((speaker, text))
61
+ start, end = seg["start"], seg["end"]
62
+ seg_audio_path = audio_path
63
+ if waveform is not None and sample_rate is not None:
64
+ start_idx = int(start * sample_rate)
65
+ end_idx = int(end * sample_rate)
66
+ segment_waveform = waveform[:, start_idx:end_idx]
67
+ seg_audio_path = f"segment_{i}.wav"
68
+ torchaudio.save(seg_audio_path, segment_waveform, sample_rate)
69
+ with gr.Row():
70
+ gr.Audio(value=seg_audio_path, format="audio/wav", show_label=False, interactive=False, key=f"audio_{i}")
71
+ tb_seg = gr.Textbox(value=text, lines=2, label=f"Спикер {speaker}", key=f"text_{i}", interactive=True)
72
+ transcript_text_inputs.append(tb_seg)
73
+
74
+ # Define save function
75
+ def save_func(*args):
76
+ names = list(args[:len(speakers)])
77
+ texts = list(args[len(speakers):])
78
+ name_map = {speakers[j]: names[j] for j in range(len(speakers))}
79
+ with open("result.txt", "w", encoding="utf-8") as f:
80
+ for idx, (speaker, _) in enumerate(transcripts):
81
+ name = name_map.get(speaker, f"Спикер {speaker}")
82
+ text = texts[idx]
83
+ f.write(f"{name}: {text}\n")
84
+ return "result.txt"
85
+
86
+ if speakers or transcripts:
87
+ save_btn.click(save_func, inputs=speaker_name_inputs + transcript_text_inputs, outputs=output_file)
88
+
89
+ app.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
90
 
91
  if __name__ == "__main__":
92
+ create_app()