banao-tech commited on
Commit
3bc012e
·
verified ·
1 Parent(s): e0d7644

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -14
app.py CHANGED
@@ -11,13 +11,15 @@ from pathlib import Path
11
 
12
  import gradio as gr
13
  import pandas as pd
 
14
  import torch
15
  from faster_whisper import WhisperModel
16
  from pyannote.audio import Pipeline
17
 
18
- DIAR_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
- ASR_DEVICE = "cpu"
20
- ASR_COMPUTE_TYPE = "int8"
 
21
 
22
  BAD_PHRASES = [
23
  "transcribe exactly",
@@ -66,6 +68,13 @@ def to_wav_16k_mono(input_path: Path, output_path: Path, enhance_audio: bool):
66
  run_cmd(cmd)
67
  return output_path
68
 
 
 
 
 
 
 
 
69
  def normalize_spaces(text):
70
  text = (text or "").replace("\n", " ").replace("\r", " ")
71
  text = re.sub(r"\s+", " ", text).strip()
@@ -108,10 +117,12 @@ def format_hhmmss_mmm(seconds):
108
  def preflight(media_file, language, enhance_audio, num_speakers, min_speakers, max_speakers):
109
  lines = [
110
  "=== PREFLIGHT ===",
111
- f"Diarization device: {DIAR_DEVICE}",
112
  f"ASR device: {ASR_DEVICE}",
 
113
  "Diarization model: pyannote/speaker-diarization-community-1",
114
- "ASR model: medium (CPU)",
 
115
  f"Language: {language}",
116
  f"Enhance audio: {enhance_audio}",
117
  f"HF_TOKEN present: {bool(os.getenv('HF_TOKEN'))}",
@@ -132,7 +143,7 @@ def preflight(media_file, language, enhance_audio, num_speakers, min_speakers, m
132
  if dur is not None:
133
  lines.append(f"Estimated duration: {dur:.2f} sec")
134
  if dur > 1800:
135
- lines.append("Warning: long file. Community-1 space uses CPU ASR for stability.")
136
  except Exception as e:
137
  lines.append(f"File inspection failed: {e}")
138
  return "\n".join(lines)
@@ -224,11 +235,11 @@ def process_media(media_file, language, enhance_audio, filter_known_bad, num_spe
224
  progress(0.05, desc="Preparing audio")
225
  to_wav_16k_mono(input_path, wav_path, enhance_audio=enhance_audio)
226
 
227
- progress(0.16, desc="Loading ASR model: medium (CPU)")
228
  asr_model = WhisperModel("medium", device=ASR_DEVICE, compute_type=ASR_COMPUTE_TYPE, cpu_threads=4, num_workers=1)
229
  fw_language = None if language == "auto" else language
230
 
231
- progress(0.28, desc="Transcribing")
232
  segments_iter, info = asr_model.transcribe(
233
  str(wav_path),
234
  language=fw_language,
@@ -282,8 +293,9 @@ def process_media(media_file, language, enhance_audio, filter_known_bad, num_spe
282
  if max_speakers and int(max_speakers) > 0:
283
  diar_kwargs["max_speakers"] = int(max_speakers)
284
 
285
- progress(0.72, desc="Running diarization")
286
- output = pipeline(str(wav_path), **diar_kwargs)
 
287
  if hasattr(output, "exclusive_speaker_diarization"):
288
  diarization = output.exclusive_speaker_diarization
289
  elif hasattr(output, "speaker_diarization"):
@@ -345,6 +357,8 @@ def process_media(media_file, language, enhance_audio, filter_known_bad, num_spe
345
  preview_lines = [
346
  "=== RUN SUMMARY ===",
347
  f"Detected language: {info.language}",
 
 
348
  f"ASR segments kept: {asr_segment_count}",
349
  f"ASR words kept: {len(all_words)}",
350
  f"Raw transcript segments: {len(raw_segments)}",
@@ -367,15 +381,13 @@ with gr.Blocks(title="Diarized Speaker Segments Community-1") as demo:
367
  gr.Markdown(
368
  """
369
  # Diarized Speaker Segments Community-1
370
- Uses **pyannote/speaker-diarization-community-1**.
371
 
372
  Cleanup rule:
373
  - if adjacent speaker segments are the same, merge them
374
  - otherwise do not touch them
375
 
376
- Note:
377
- - ASR runs on CPU for compatibility/stability
378
- - diarization uses GPU if available
379
  """
380
  )
381
  with gr.Row():
 
11
 
12
  import gradio as gr
13
  import pandas as pd
14
+ import soundfile as sf
15
  import torch
16
  from faster_whisper import WhisperModel
17
  from pyannote.audio import Pipeline
18
 
19
+ GPU_AVAILABLE = torch.cuda.is_available()
20
+ ASR_DEVICE = "cuda" if GPU_AVAILABLE else "cpu"
21
+ DIAR_DEVICE = "cuda" if GPU_AVAILABLE else "cpu"
22
+ ASR_COMPUTE_TYPE = "float16" if GPU_AVAILABLE else "int8"
23
 
24
  BAD_PHRASES = [
25
  "transcribe exactly",
 
68
  run_cmd(cmd)
69
  return output_path
70
 
71
+ def load_waveform_for_pyannote(wav_path: Path):
72
+ audio, sample_rate = sf.read(str(wav_path), dtype="float32")
73
+ if audio.ndim > 1:
74
+ audio = audio.mean(axis=1)
75
+ waveform = torch.from_numpy(audio).unsqueeze(0)
76
+ return {"waveform": waveform, "sample_rate": int(sample_rate)}
77
+
78
  def normalize_spaces(text):
79
  text = (text or "").replace("\n", " ").replace("\r", " ")
80
  text = re.sub(r"\s+", " ", text).strip()
 
117
  def preflight(media_file, language, enhance_audio, num_speakers, min_speakers, max_speakers):
118
  lines = [
119
  "=== PREFLIGHT ===",
120
+ f"GPU available: {GPU_AVAILABLE}",
121
  f"ASR device: {ASR_DEVICE}",
122
+ f"Diarization device: {DIAR_DEVICE}",
123
  "Diarization model: pyannote/speaker-diarization-community-1",
124
+ "ASR model: medium",
125
+ f"ASR compute type: {ASR_COMPUTE_TYPE}",
126
  f"Language: {language}",
127
  f"Enhance audio: {enhance_audio}",
128
  f"HF_TOKEN present: {bool(os.getenv('HF_TOKEN'))}",
 
143
  if dur is not None:
144
  lines.append(f"Estimated duration: {dur:.2f} sec")
145
  if dur > 1800:
146
+ lines.append("Warning: long file on T4 small. GPU is used, but medium is still recommended.")
147
  except Exception as e:
148
  lines.append(f"File inspection failed: {e}")
149
  return "\n".join(lines)
 
235
  progress(0.05, desc="Preparing audio")
236
  to_wav_16k_mono(input_path, wav_path, enhance_audio=enhance_audio)
237
 
238
+ progress(0.16, desc="Loading ASR model: medium")
239
  asr_model = WhisperModel("medium", device=ASR_DEVICE, compute_type=ASR_COMPUTE_TYPE, cpu_threads=4, num_workers=1)
240
  fw_language = None if language == "auto" else language
241
 
242
+ progress(0.28, desc="Transcribing on GPU")
243
  segments_iter, info = asr_model.transcribe(
244
  str(wav_path),
245
  language=fw_language,
 
293
  if max_speakers and int(max_speakers) > 0:
294
  diar_kwargs["max_speakers"] = int(max_speakers)
295
 
296
+ progress(0.70, desc="Running diarization on GPU")
297
+ media = load_waveform_for_pyannote(wav_path)
298
+ output = pipeline(media, **diar_kwargs)
299
  if hasattr(output, "exclusive_speaker_diarization"):
300
  diarization = output.exclusive_speaker_diarization
301
  elif hasattr(output, "speaker_diarization"):
 
357
  preview_lines = [
358
  "=== RUN SUMMARY ===",
359
  f"Detected language: {info.language}",
360
+ f"ASR device used: {ASR_DEVICE}",
361
+ f"Diarization device used: {DIAR_DEVICE}",
362
  f"ASR segments kept: {asr_segment_count}",
363
  f"ASR words kept: {len(all_words)}",
364
  f"Raw transcript segments: {len(raw_segments)}",
 
381
  gr.Markdown(
382
  """
383
  # Diarized Speaker Segments Community-1
384
+ Uses **pyannote/speaker-diarization-community-1** and **faster-whisper medium**.
385
 
386
  Cleanup rule:
387
  - if adjacent speaker segments are the same, merge them
388
  - otherwise do not touch them
389
 
390
+ This version uses GPU for both ASR and diarization when a GPU is available.
 
 
391
  """
392
  )
393
  with gr.Row():