PacoFYM committed on
Commit
cfd126b
·
verified ·
1 Parent(s): 310e379

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -71
app.py CHANGED
@@ -3,90 +3,126 @@ import torch
3
  import whisperx
4
  from pyannote.audio import Pipeline
5
  import gradio as gr
 
6
 
7
  def create_app():
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
- hf_token = os.getenv("HF_TOKEN") or os.getenv("hf_token") or ""
10
-
11
  with gr.Blocks() as app:
12
  gr.Markdown("<h1>Транскрипция и диаризация аудио</h1>")
13
- gr.Markdown("Загрузите аудиофайл и нажмите **Транскрибировать**. После обработки вы сможете прослушать сегменты, отредактировать текст и присвоить имена спикерам.")
14
- audio_input = gr.Audio(label="Аудиофайл", source="upload", type="filepath")
 
 
 
 
 
15
  transcribe_btn = gr.Button("Транскрибировать")
16
-
 
 
17
  save_btn = gr.Button("Сохранить результат")
18
- output_file = gr.File(label="Скачайте результат (.txt)")
19
-
20
- @gr.render(inputs=[audio_input], triggers=[transcribe_btn])
21
- def process(audio_path):
22
- if not audio_path:
23
- return
24
- # 1. WhisperX transcription
25
- model = whisperx.load_model("small", device, compute_type="float32")
26
  audio_array = whisperx.load_audio(audio_path)
27
- result = model.transcribe(audio_array, batch_size=16, language="ru")
28
- model_a, metadata = whisperx.load_align_model(language_code="ru", device=device)
29
- result = whisperx.align(result["segments"], model_a, metadata, audio_array, device=device, return_char_alignments=False)
30
- # 2. Speaker diarization
31
- pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", use_auth_token=hf_token)
32
- pipeline.to(device)
33
- diarization = pipeline(audio_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  result = whisperx.assign_word_speakers(diarization, result)
 
 
35
  segments = result["segments"]
36
- # Unique speakers
37
  speakers = sorted({seg["speaker"] for seg in segments})
38
-
39
- # Input fields for speaker names
40
- if speakers:
41
- gr.Markdown("**Имена спикеров:**")
42
- speaker_name_inputs = []
43
- for spk in speakers:
44
- tb = gr.Textbox(label=f"Спикер {spk}", value=f"Спикер {spk}", interactive=True, key=f"name_{spk}")
45
- speaker_name_inputs.append(tb)
46
-
47
- # Load audio for slicing segments
48
- try:
49
- import torchaudio
50
- waveform, sample_rate = torchaudio.load(audio_path)
51
- except Exception:
52
- waveform, sample_rate = None, None
53
-
54
- transcripts = []
55
- transcript_text_inputs = []
56
- # Render each segment
57
- for i, seg in enumerate(segments):
58
- speaker = seg["speaker"]
59
- text = seg["text"]
60
- transcripts.append((speaker, text))
61
- start, end = seg["start"], seg["end"]
62
- seg_audio_path = audio_path
63
- if waveform is not None and sample_rate is not None:
64
- start_idx = int(start * sample_rate)
65
- end_idx = int(end * sample_rate)
66
- segment_waveform = waveform[:, start_idx:end_idx]
67
- seg_audio_path = f"segment_{i}.wav"
68
- torchaudio.save(seg_audio_path, segment_waveform, sample_rate)
69
- with gr.Row():
70
- gr.Audio(value=seg_audio_path, format="audio/wav", show_label=False, interactive=False, key=f"audio_{i}")
71
- tb_seg = gr.Textbox(value=text, lines=2, label=f"Спикер {speaker}", key=f"text_{i}", interactive=True)
72
- transcript_text_inputs.append(tb_seg)
73
-
74
- # Define save function
75
- def save_func(*args):
76
- names = list(args[:len(speakers)])
77
- texts = list(args[len(speakers):])
78
- name_map = {speakers[j]: names[j] for j in range(len(speakers))}
 
79
  with open("result.txt", "w", encoding="utf-8") as f:
80
- for idx, (speaker, _) in enumerate(transcripts):
81
- name = name_map.get(speaker, f"Спикер {speaker}")
82
- text = texts[idx]
83
- f.write(f"{name}: {text}\n")
84
  return "result.txt"
85
-
86
- if speakers or transcripts:
87
- save_btn.click(save_func, inputs=speaker_name_inputs + transcript_text_inputs, outputs=output_file)
88
-
89
- app.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  if __name__ == "__main__":
92
  create_app()
 
3
  import whisperx
4
  from pyannote.audio import Pipeline
5
  import gradio as gr
6
+ import torchaudio
7
 
8
  def create_app():
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ hf_token = os.getenv("HF_TOKEN", "")
11
+
12
  with gr.Blocks() as app:
13
  gr.Markdown("<h1>Транскрипция и диаризация аудио</h1>")
14
+ gr.Markdown(
15
+ "Загрузите аудиофайл (формат WAV/MP3), нажмите **Транскрибировать**, "
16
+ "отредактируйте результат и сохраните его."
17
+ )
18
+
19
+ # Убираем `source="upload"` — по умолчанию Audio позволяет загрузку
20
+ audio_input = gr.Audio(label="Аудиофайл", type="filepath")
21
  transcribe_btn = gr.Button("Транскрибировать")
22
+
23
+ # Здесь будут динамически добавляться поля для редактирования
24
+ segment_container = gr.Column()
25
  save_btn = gr.Button("Сохранить результат")
26
+ output_file = gr.File(label="Скачать .txt")
27
+
28
+ def transcribe_with_diarization(audio_path):
29
+ # 1) Транскрипция WhisperX с фиксированным языком "ru"
30
+ asr_model = whisperx.load_model("small", device, compute_type="float32")
 
 
 
31
  audio_array = whisperx.load_audio(audio_path)
32
+ result = asr_model.transcribe(
33
+ audio_array,
34
+ batch_size=16,
35
+ language="ru"
36
+ )
37
+ align_model, metadata = whisperx.load_align_model(
38
+ language_code="ru", device=device
39
+ )
40
+ result = whisperx.align(
41
+ result["segments"],
42
+ align_model,
43
+ metadata,
44
+ audio_array,
45
+ device=device,
46
+ return_char_alignments=False
47
+ )
48
+
49
+ # 2) Диаризация Pyannote
50
+ diar_pipeline = Pipeline.from_pretrained(
51
+ "pyannote/speaker-diarization-3.1",
52
+ use_auth_token=hf_token
53
+ ).to(device)
54
+ diarization = diar_pipeline(audio_path)
55
  result = whisperx.assign_word_speakers(diarization, result)
56
+
57
+ # 3) Подготовка UI сегментов
58
  segments = result["segments"]
 
59
  speakers = sorted({seg["speaker"] for seg in segments})
60
+
61
+ # Очищаем контейнер и добавляем новые поля
62
+ segment_container.clear()
63
+
64
+ # Поля для переименования спикеров
65
+ name_inputs = {}
66
+ with segment_container:
67
+ gr.Markdown("**Укажите имена спикеров:**")
68
+ for spk in speakers:
69
+ name_inputs[spk] = gr.Textbox(
70
+ label=f"Спикер {spk}",
71
+ value=f"Спикер {spk}"
72
+ )
73
+
74
+ gr.Markdown("---")
75
+ gr.Markdown("**Отредактируйте текст сегментов:**")
76
+ text_inputs = []
77
+ for i, seg in enumerate(segments):
78
+ start, end = seg["start"], seg["end"]
79
+ speaker = seg["speaker"]
80
+ txt = seg["text"]
81
+ # Срез аудио для сегмента
82
+ seg_path = f"seg_{i}.wav"
83
+ wave, sr = torchaudio.load(audio_path)
84
+ torchaudio.save(
85
+ seg_path,
86
+ wave[:, int(start*sr):int(end*sr)],
87
+ sr
88
+ )
89
+ with gr.Row():
90
+ gr.Audio(value=seg_path, format="wav", label=None)
91
+ ti = gr.Textbox(
92
+ value=txt,
93
+ label=f"{name_inputs[speaker].value}: {start:.1f}-{end:.1f}s",
94
+ lines=2
95
+ )
96
+ text_inputs.append((speaker, ti))
97
+
98
+ # Функция сохранения
99
+ def save_result(**kwargs):
100
+ # kwargs содержит сначала name_inputs, потом text_inputs
101
+ names = {spk: kwargs[f"Спикер {spk}"] for spk in speakers}
102
  with open("result.txt", "w", encoding="utf-8") as f:
103
+ for spk, ti in text_inputs:
104
+ text = kwargs[ti.label]
105
+ f.write(f"{names[spk]}: {text}\n")
 
106
  return "result.txt"
107
+
108
+ # Создаем привязку кнопки сохранения
109
+ save_btn.click(
110
+ fn=save_result,
111
+ inputs=list(name_inputs.values()) + [ti for _, ti in text_inputs],
112
+ outputs=output_file
113
+ )
114
+
115
+ transcribe_btn.click(
116
+ fn=transcribe_with_diarization,
117
+ inputs=audio_input,
118
+ outputs=[]
119
+ )
120
+
121
+ app.launch(
122
+ server_name="0.0.0.0",
123
+ server_port=7860,
124
+ show_api=False
125
+ )
126
 
127
  if __name__ == "__main__":
128
  create_app()