Update app.py
Browse files
app.py
CHANGED
|
@@ -11,13 +11,15 @@ from pathlib import Path
|
|
| 11 |
|
| 12 |
import gradio as gr
|
| 13 |
import pandas as pd
|
|
|
|
| 14 |
import torch
|
| 15 |
from faster_whisper import WhisperModel
|
| 16 |
from pyannote.audio import Pipeline
|
| 17 |
|
| 18 |
-
|
| 19 |
-
ASR_DEVICE = "cpu"
|
| 20 |
-
|
|
|
|
| 21 |
|
| 22 |
BAD_PHRASES = [
|
| 23 |
"transcribe exactly",
|
|
@@ -66,6 +68,13 @@ def to_wav_16k_mono(input_path: Path, output_path: Path, enhance_audio: bool):
|
|
| 66 |
run_cmd(cmd)
|
| 67 |
return output_path
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
def normalize_spaces(text):
|
| 70 |
text = (text or "").replace("\n", " ").replace("\r", " ")
|
| 71 |
text = re.sub(r"\s+", " ", text).strip()
|
|
@@ -108,10 +117,12 @@ def format_hhmmss_mmm(seconds):
|
|
| 108 |
def preflight(media_file, language, enhance_audio, num_speakers, min_speakers, max_speakers):
|
| 109 |
lines = [
|
| 110 |
"=== PREFLIGHT ===",
|
| 111 |
-
f"
|
| 112 |
f"ASR device: {ASR_DEVICE}",
|
|
|
|
| 113 |
"Diarization model: pyannote/speaker-diarization-community-1",
|
| 114 |
-
"ASR model: medium
|
|
|
|
| 115 |
f"Language: {language}",
|
| 116 |
f"Enhance audio: {enhance_audio}",
|
| 117 |
f"HF_TOKEN present: {bool(os.getenv('HF_TOKEN'))}",
|
|
@@ -132,7 +143,7 @@ def preflight(media_file, language, enhance_audio, num_speakers, min_speakers, m
|
|
| 132 |
if dur is not None:
|
| 133 |
lines.append(f"Estimated duration: {dur:.2f} sec")
|
| 134 |
if dur > 1800:
|
| 135 |
-
lines.append("Warning: long file.
|
| 136 |
except Exception as e:
|
| 137 |
lines.append(f"File inspection failed: {e}")
|
| 138 |
return "\n".join(lines)
|
|
@@ -224,11 +235,11 @@ def process_media(media_file, language, enhance_audio, filter_known_bad, num_spe
|
|
| 224 |
progress(0.05, desc="Preparing audio")
|
| 225 |
to_wav_16k_mono(input_path, wav_path, enhance_audio=enhance_audio)
|
| 226 |
|
| 227 |
-
progress(0.16, desc="Loading ASR model: medium
|
| 228 |
asr_model = WhisperModel("medium", device=ASR_DEVICE, compute_type=ASR_COMPUTE_TYPE, cpu_threads=4, num_workers=1)
|
| 229 |
fw_language = None if language == "auto" else language
|
| 230 |
|
| 231 |
-
progress(0.28, desc="Transcribing")
|
| 232 |
segments_iter, info = asr_model.transcribe(
|
| 233 |
str(wav_path),
|
| 234 |
language=fw_language,
|
|
@@ -282,8 +293,9 @@ def process_media(media_file, language, enhance_audio, filter_known_bad, num_spe
|
|
| 282 |
if max_speakers and int(max_speakers) > 0:
|
| 283 |
diar_kwargs["max_speakers"] = int(max_speakers)
|
| 284 |
|
| 285 |
-
progress(0.
|
| 286 |
-
|
|
|
|
| 287 |
if hasattr(output, "exclusive_speaker_diarization"):
|
| 288 |
diarization = output.exclusive_speaker_diarization
|
| 289 |
elif hasattr(output, "speaker_diarization"):
|
|
@@ -345,6 +357,8 @@ def process_media(media_file, language, enhance_audio, filter_known_bad, num_spe
|
|
| 345 |
preview_lines = [
|
| 346 |
"=== RUN SUMMARY ===",
|
| 347 |
f"Detected language: {info.language}",
|
|
|
|
|
|
|
| 348 |
f"ASR segments kept: {asr_segment_count}",
|
| 349 |
f"ASR words kept: {len(all_words)}",
|
| 350 |
f"Raw transcript segments: {len(raw_segments)}",
|
|
@@ -367,15 +381,13 @@ with gr.Blocks(title="Diarized Speaker Segments Community-1") as demo:
|
|
| 367 |
gr.Markdown(
|
| 368 |
"""
|
| 369 |
# Diarized Speaker Segments Community-1
|
| 370 |
-
Uses **pyannote/speaker-diarization-community-1**.
|
| 371 |
|
| 372 |
Cleanup rule:
|
| 373 |
- if adjacent speaker segments are the same, merge them
|
| 374 |
- otherwise do not touch them
|
| 375 |
|
| 376 |
-
|
| 377 |
-
- ASR runs on CPU for compatibility/stability
|
| 378 |
-
- diarization uses GPU if available
|
| 379 |
"""
|
| 380 |
)
|
| 381 |
with gr.Row():
|
|
|
|
| 11 |
|
| 12 |
import gradio as gr
|
| 13 |
import pandas as pd
|
| 14 |
+
import soundfile as sf
|
| 15 |
import torch
|
| 16 |
from faster_whisper import WhisperModel
|
| 17 |
from pyannote.audio import Pipeline
|
| 18 |
|
| 19 |
+
GPU_AVAILABLE = torch.cuda.is_available()
|
| 20 |
+
ASR_DEVICE = "cuda" if GPU_AVAILABLE else "cpu"
|
| 21 |
+
DIAR_DEVICE = "cuda" if GPU_AVAILABLE else "cpu"
|
| 22 |
+
ASR_COMPUTE_TYPE = "float16" if GPU_AVAILABLE else "int8"
|
| 23 |
|
| 24 |
BAD_PHRASES = [
|
| 25 |
"transcribe exactly",
|
|
|
|
| 68 |
run_cmd(cmd)
|
| 69 |
return output_path
|
| 70 |
|
| 71 |
+
def load_waveform_for_pyannote(wav_path: Path):
|
| 72 |
+
audio, sample_rate = sf.read(str(wav_path), dtype="float32")
|
| 73 |
+
if audio.ndim > 1:
|
| 74 |
+
audio = audio.mean(axis=1)
|
| 75 |
+
waveform = torch.from_numpy(audio).unsqueeze(0)
|
| 76 |
+
return {"waveform": waveform, "sample_rate": int(sample_rate)}
|
| 77 |
+
|
| 78 |
def normalize_spaces(text):
|
| 79 |
text = (text or "").replace("\n", " ").replace("\r", " ")
|
| 80 |
text = re.sub(r"\s+", " ", text).strip()
|
|
|
|
| 117 |
def preflight(media_file, language, enhance_audio, num_speakers, min_speakers, max_speakers):
|
| 118 |
lines = [
|
| 119 |
"=== PREFLIGHT ===",
|
| 120 |
+
f"GPU available: {GPU_AVAILABLE}",
|
| 121 |
f"ASR device: {ASR_DEVICE}",
|
| 122 |
+
f"Diarization device: {DIAR_DEVICE}",
|
| 123 |
"Diarization model: pyannote/speaker-diarization-community-1",
|
| 124 |
+
"ASR model: medium",
|
| 125 |
+
f"ASR compute type: {ASR_COMPUTE_TYPE}",
|
| 126 |
f"Language: {language}",
|
| 127 |
f"Enhance audio: {enhance_audio}",
|
| 128 |
f"HF_TOKEN present: {bool(os.getenv('HF_TOKEN'))}",
|
|
|
|
| 143 |
if dur is not None:
|
| 144 |
lines.append(f"Estimated duration: {dur:.2f} sec")
|
| 145 |
if dur > 1800:
|
| 146 |
+
lines.append("Warning: long file on T4 small. GPU is used, but medium is still recommended.")
|
| 147 |
except Exception as e:
|
| 148 |
lines.append(f"File inspection failed: {e}")
|
| 149 |
return "\n".join(lines)
|
|
|
|
| 235 |
progress(0.05, desc="Preparing audio")
|
| 236 |
to_wav_16k_mono(input_path, wav_path, enhance_audio=enhance_audio)
|
| 237 |
|
| 238 |
+
progress(0.16, desc="Loading ASR model: medium")
|
| 239 |
asr_model = WhisperModel("medium", device=ASR_DEVICE, compute_type=ASR_COMPUTE_TYPE, cpu_threads=4, num_workers=1)
|
| 240 |
fw_language = None if language == "auto" else language
|
| 241 |
|
| 242 |
+
progress(0.28, desc="Transcribing on GPU")
|
| 243 |
segments_iter, info = asr_model.transcribe(
|
| 244 |
str(wav_path),
|
| 245 |
language=fw_language,
|
|
|
|
| 293 |
if max_speakers and int(max_speakers) > 0:
|
| 294 |
diar_kwargs["max_speakers"] = int(max_speakers)
|
| 295 |
|
| 296 |
+
progress(0.70, desc="Running diarization on GPU")
|
| 297 |
+
media = load_waveform_for_pyannote(wav_path)
|
| 298 |
+
output = pipeline(media, **diar_kwargs)
|
| 299 |
if hasattr(output, "exclusive_speaker_diarization"):
|
| 300 |
diarization = output.exclusive_speaker_diarization
|
| 301 |
elif hasattr(output, "speaker_diarization"):
|
|
|
|
| 357 |
preview_lines = [
|
| 358 |
"=== RUN SUMMARY ===",
|
| 359 |
f"Detected language: {info.language}",
|
| 360 |
+
f"ASR device used: {ASR_DEVICE}",
|
| 361 |
+
f"Diarization device used: {DIAR_DEVICE}",
|
| 362 |
f"ASR segments kept: {asr_segment_count}",
|
| 363 |
f"ASR words kept: {len(all_words)}",
|
| 364 |
f"Raw transcript segments: {len(raw_segments)}",
|
|
|
|
| 381 |
gr.Markdown(
|
| 382 |
"""
|
| 383 |
# Diarized Speaker Segments Community-1
|
| 384 |
+
Uses **pyannote/speaker-diarization-community-1** and **faster-whisper medium**.
|
| 385 |
|
| 386 |
Cleanup rule:
|
| 387 |
- if adjacent speaker segments are the same, merge them
|
| 388 |
- otherwise do not touch them
|
| 389 |
|
| 390 |
+
This version uses GPU for both ASR and diarization when a GPU is available.
|
|
|
|
|
|
|
| 391 |
"""
|
| 392 |
)
|
| 393 |
with gr.Row():
|