Commit d36869b by liuyang · 1 Parent(s): 62ed41c

switch to whisperX

Files changed (1): app.py (+147 -126)
app.py CHANGED
@@ -84,6 +84,7 @@ import tempfile
 import spaces
 from faster_whisper import WhisperModel, BatchedInferencePipeline
 from faster_whisper.vad import VadOptions
+import whisperx
 import requests
 import base64
 from pyannote.audio import Pipeline, Inference, Model
@@ -154,39 +155,19 @@ from huggingface_hub import snapshot_download
 # -----------------------------------------------------------------------------
 MODELS = {
     "large-v3-turbo": {
-        "repo_id": "deepdml/faster-whisper-large-v3-turbo-ct2",
-        "local_dir": f"{CACHE_ROOT}/whisper_turbo_v3"
+        "whisperx_name": "large-v3-turbo",
     },
     "large-v3": {
-        "repo_id": "Systran/faster-whisper-large-v3",
-        "local_dir": f"{CACHE_ROOT}/whisper_large_v3"
+        "whisperx_name": "large-v3",
     },
     "large-v2": {
-        "repo_id": "Systran/faster-whisper-large-v2",
-        "local_dir": f"{CACHE_ROOT}/whisper_large_v2"
+        "whisperx_name": "large-v2",
     },
 }
 DEFAULT_MODEL = "large-v3-turbo"
 
-def _download_model(model_name: str):
-    """Downloads a model from the hub if not already present."""
-    if model_name not in MODELS:
-        raise ValueError(f"Model '{model_name}' not found in MODELS registry.")
-
-    model_info = MODELS[model_name]
-    if not os.path.exists(model_info["local_dir"]):
-        print(f"Downloading model '{model_name}' from {model_info['repo_id']}...")
-        snapshot_download(
-            repo_id=model_info["repo_id"],
-            local_dir=model_info["local_dir"],
-            local_dir_use_symlinks=True,
-            resume_download=True
-        )
-    return model_info["local_dir"]
-
-# Download the default model on startup
-for model in MODELS:
-    _download_model(model)
+# Supported languages for alignment models (whisperX)
+ALIGN_LANGUAGES = ["en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh", "ar", "nl", "tr", "pl", "cs", "sv", "da", "fi", "no", "uk"]
 
 
 # -----------------------------------------------------------------------------
@@ -412,60 +393,88 @@ def _process_single_chunk(task: dict, out_dir: str) -> dict:
 # model_cache_path = LOCAL_DIR # <‑‑ this is what we pass to WhisperModel
 
 # Lazy global holder ----------------------------------------------------------
-_whisper_models = {}
-_batched_whisper_models = {}
 _whipser_x_transcribe_models = {}
 _whipser_x_align_models = {}
 
 _diarizer = None
 _embedder = None
 
-# Create global diarization pipeline
-try:
-    print("Loading diarization model...")
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-    torch.set_float32_matmul_precision('high')
-
-    _diarizer = Pipeline.from_pretrained(
-        "pyannote/speaker-diarization-3.1",
-        use_auth_token=os.getenv("HF_TOKEN"),
-    ).to(torch.device("cuda"))
-
-    print("Diarization model loaded successfully")
-except Exception as e:
-    import traceback
-    traceback.print_exc()
-    print(f"Could not load diarization model: {e}")
-    _diarizer = None
-
-@spaces.GPU # GPU is guaranteed to exist *inside* this function
-def _load_models(model_name: str = DEFAULT_MODEL):
-    global _whisper_models, _batched_whisper_models, _diarizer
-
-    if model_name not in _whisper_models:
-        print(f"Loading Whisper model '{model_name}'...")
-
-        model_cache_path = _download_model(model_name)
-
-        model = WhisperModel(
-            model_cache_path,
-            device="cuda",
-            compute_type="float16",
-        )
-
-        # Create batched inference pipeline for improved performance
-        batched_model = BatchedInferencePipeline(model=model)
-
-        _whisper_models[model_name] = model
-        _batched_whisper_models[model_name] = batched_model
-
-        print(f"Whisper model '{model_name}' and batched pipeline loaded successfully")
-
-    whisper = _whisper_models[model_name]
-    batched_whisper = _batched_whisper_models[model_name]
-
-    return whisper, batched_whisper, _diarizer
+# Preload alignment and diarization models at startup (no GPU decorator)
+def _preload_alignment_and_diarization_models():
+    """Preload WhisperX alignment and diarization models on CUDA device"""
+    global _whipser_x_align_models, _diarizer
+
+    print("Preloading all WhisperX alignment models...")
+    for lang in ALIGN_LANGUAGES:
+        try:
+            print(f"Loading alignment model for language '{lang}'...")
+            device = "cuda"
+
+            align_model, align_metadata = whisperx.load_align_model(
+                language_code=lang,
+                device=device,
+                model_dir=CACHE_ROOT
+            )
+            _whipser_x_align_models[lang] = {
+                "model": align_model,
+                "metadata": align_metadata
+            }
+            print(f"Alignment model for '{lang}' loaded successfully")
+        except Exception as e:
+            print(f"Could not load alignment model for '{lang}': {e}")
+
+    # Create global diarization pipeline
+    try:
+        print("Loading diarization model...")
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+        torch.set_float32_matmul_precision('high')
+
+        _diarizer = Pipeline.from_pretrained(
+            "pyannote/speaker-diarization-3.1",
+            use_auth_token=os.getenv("HF_TOKEN"),
+        ).to(torch.device("cuda"))
+
+        print("Diarization model loaded successfully")
+    except Exception as e:
+        import traceback
+        traceback.print_exc()
+        print(f"Could not load diarization model: {e}")
+        _diarizer = None
+
+    print("WhisperX alignment and diarization models preloaded successfully!")
+
+# Call preload function at startup
+_preload_alignment_and_diarization_models()
+
+# Preload WhisperX transcribe models with GPU decorator
+@spaces.GPU
+def _preload_whisperx_transcribe_models():
+    """Preload all WhisperX transcribe models on GPU"""
+    global _whipser_x_transcribe_models
+
+    print("Preloading all WhisperX transcribe models on GPU...")
+    for model_name in MODELS.keys():
+        try:
+            print(f"Loading WhisperX transcribe model '{model_name}'...")
+            whisperx_model_name = MODELS[model_name]["whisperx_name"]
+            device = "cuda"
+            compute_type = "float16"
+
+            model = whisperx.load_model(
+                whisperx_model_name,
+                device=device,
+                compute_type=compute_type,
+                download_root=CACHE_ROOT
+            )
+            _whipser_x_transcribe_models[model_name] = model
+            print(f"WhisperX transcribe model '{model_name}' loaded successfully")
+        except Exception as e:
+            import traceback
+            traceback.print_exc()
+            print(f"Could not load WhisperX transcribe model '{model_name}': {e}")
+
+    print("All WhisperX transcribe models preloaded successfully!")
 
 # -----------------------------------------------------------------------------
 class WhisperTranscriber:
@@ -494,81 +503,84 @@ class WhisperTranscriber:
 
     @spaces.GPU # each call gets a GPU slice
     def transcribe_full_audio(self, audio_path, language=None, translate=False, prompt=None, batch_size=16, base_offset_s: float = 0.0, clip_timestamps=None, model_name: str = DEFAULT_MODEL, transcribe_options: dict = None):
-        """Transcribe the entire audio file without speaker diarization using batched inference"""
-        whisper, batched_whisper, _ = _load_models(model_name) # models live on the GPU
+        """Transcribe the entire audio file using WhisperX with alignment"""
+        global _whipser_x_transcribe_models, _whipser_x_align_models
+
+        # Get preloaded whisperX model
+        if model_name not in _whipser_x_transcribe_models:
+            raise ValueError(f"WhisperX model '{model_name}' not preloaded. Available models: {list(_whipser_x_transcribe_models.keys())}")
 
-        print(f"Transcribing full audio with '{model_name}' and batch size {batch_size}...")
+        whisper_model = _whipser_x_transcribe_models[model_name]
+
+        print(f"Transcribing full audio with WhisperX model '{model_name}' and batch size {batch_size}...")
         start_time = time.time()
 
-        # Prepare options for batched inference
-        options = dict(
+        # Load audio with whisperx
+        audio = whisperx.load_audio(audio_path)
+
+        # Transcribe with whisperx
+        result = whisper_model.transcribe(
+            audio,
             language=language,
-            beam_size=5,
-            word_timestamps=True,
-            initial_prompt=prompt,
-            condition_on_previous_text=False, # avoid runaway context
-            language_detection_segments=1,
-            task="translate" if translate else "transcribe",
+            batch_size=batch_size,
+            #initial_prompt=prompt,
+            #task="translate" if translate else "transcribe"
         )
-        if clip_timestamps:
-            options["vad_filter"] = False
-            options["clip_timestamps"] = clip_timestamps
-        else:
-            vad_options = transcribe_options.get("vad_parameters", None)
-            options["vad_filter"] = True # VAD is enabled by default for batched transcription
-            options["vad_parameters"] = VadOptions(**vad_options) if vad_options else VadOptions(
-                max_speech_duration_s=whisper.feature_extractor.chunk_length,
-                min_speech_duration_ms=180, # ignore ultra-short blips
-                min_silence_duration_ms=120, # split on short Mandarin pauses (if supported)
-                speech_pad_ms=120,
-                threshold=0.35,
-                neg_threshold=0.2,
-            )
-        if batch_size > 1:
-            # Use batched inference for better performance
-            segments, transcript_info = batched_whisper.transcribe(
-                audio_path,
-                batch_size=batch_size,
-                **options
-            )
-        else:
-            segments, transcript_info = whisper.transcribe(
-                audio_path,
-                **options
-            )
-        segments = list(segments)
 
-        detected_language = transcript_info.language
-        print("Detected language: ", detected_language, "segments: ", len(segments))
+        detected_language = result.get("language", language if language else "unknown")
+        segments = result.get("segments", [])
+
+        print(f"Detected language: {detected_language}, segments: {len(segments)}, transcribing done in {time.time() - start_time:.2f} seconds")
 
-        # Process segments
+        # Align whisper output with alignment model if language is supported
+        if detected_language in _whipser_x_align_models:
+            print(f"Performing WhisperX alignment for language '{detected_language}'...")
+            align_start = time.time()
+            try:
+                align_info = _whipser_x_align_models[detected_language]
+
+                result = whisperx.align(
+                    result["segments"],
+                    align_info["model"],
+                    align_info["metadata"],
+                    audio,
+                    "cuda",
+                    return_char_alignments=False
+                )
+                segments = result.get("segments", segments)
+                print(f"WhisperX alignment completed in {time.time() - align_start:.2f} seconds")
+            except Exception as e:
+                print(f"WhisperX alignment failed: {e}, using original timestamps")
+        else:
+            print(f"No WhisperX alignment model available for language '{detected_language}', using original timestamps")
+
+        # Process segments into the expected format
         results = []
         for seg in segments:
            # Create result entry with detailed format
            words_list = []
-            if seg.words:
-                for word in seg.words:
+            if "words" in seg:
+                for word in seg["words"]:
                     words_list.append({
-                        "start": float(word.start) + float(base_offset_s),
-                        "end": float(word.end) + float(base_offset_s),
-                        "word": word.word,
-                        "probability": word.probability,
+                        "start": float(word.get("start", 0.0)) + float(base_offset_s),
+                        "end": float(word.get("end", 0.0)) + float(base_offset_s),
+                        "word": word.get("word", ""),
+                        "probability": word.get("score", 1.0),
                         "speaker": "SPEAKER_00" # No speaker identification in full transcription
                     })
 
             results.append({
-                "start": float(seg.start) + float(base_offset_s),
-                "end": float(seg.end) + float(base_offset_s),
-                "text": seg.text,
+                "start": float(seg.get("start", 0.0)) + float(base_offset_s),
+                "end": float(seg.get("end", 0.0)) + float(base_offset_s),
+                "text": seg.get("text", ""),
                 "speaker": "SPEAKER_00", # Single speaker assumption
-                "avg_logprob": seg.avg_logprob,
+                "avg_logprob": seg.get("avg_logprob", 0.0) if "avg_logprob" in seg else 0.0,
                 "words": words_list,
-                "duration": float(seg.end - seg.start)
+                "duration": float(seg.get("end", 0.0)) - float(seg.get("start", 0.0))
             })
 
         transcription_time = time.time() - start_time
-        print(f"Full audio transcribed in {transcription_time:.2f} seconds using batch size {batch_size}")
-        print(results)
+        print(f"Full audio transcribed and aligned in {transcription_time:.2f} seconds using batch size {batch_size}")
        return results, detected_language
 
     # Removed audio cutting; transcription is done once on the full (preprocessed) audio
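whisperX returns plain dicts rather than faster_whisper's typed segment objects, which is why the parsing loop in the hunk above switches from attribute access to .get(). A representative aligned segment, with hypothetical values throughout, looks like this:

# Hypothetical shape of one aligned whisperX segment consumed by the loop above.
seg = {
    "start": 0.52,
    "end": 1.40,
    "text": " Hello there.",
    "words": [
        {"word": "Hello", "start": 0.52, "end": 0.90, "score": 0.97},
        {"word": "there.", "start": 1.02, "end": 1.40, "score": 0.94},
    ],
}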
@@ -576,9 +588,9 @@ class WhisperTranscriber:
     @spaces.GPU # each call gets a GPU slice
     def perform_diarization(self, audio_path, num_speakers=None, base_offset_s: float = 0.0):
         """Perform speaker diarization; return segments with global timestamps and per-speaker embeddings."""
-        _, _, diarizer = _load_models() # models live on the GPU
+        global _diarizer
 
-        if diarizer is None:
+        if _diarizer is None:
             print("Diarization model not available, creating single speaker segment")
             # Load audio to get duration
             waveform, sample_rate = torchaudio.load(audio_path)
@@ -611,7 +623,7 @@ class WhisperTranscriber:
         waveform, sample_rate = torchaudio.load(audio_path)
 
         # Perform diarization
-        diarization = diarizer(
+        diarization = _diarizer(
             {"waveform": waveform, "sample_rate": sample_rate},
             num_speakers=num_speakers,
         )
@@ -1526,5 +1538,14 @@ with demo:
     - Vocabulary: Add names and technical terms in the prompt for better accuracy
     """)
 
+# Preload all WhisperX transcribe models once at service initialization
+print("Preloading all WhisperX transcribe models at startup...")
+try:
+    _preload_whisperx_transcribe_models()
+    print("All WhisperX transcribe models preloaded at startup!")
+except Exception as e:
+    print(f"Warning: Could not preload WhisperX transcribe models at startup: {e}")
+    print("Models will be loaded on first use instead.")
+
 if __name__ == "__main__":
     demo.launch(debug=True)
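
For reference, the transcribe-then-align whisperX flow this commit switches to, as a minimal standalone sketch (not part of the commit): the model name, audio path, and batch size are illustrative placeholders, and it assumes whisperx is installed and a CUDA device is available.

# Minimal sketch of the whisperX flow adopted above; placeholders throughout.
import whisperx

device = "cuda"
audio = whisperx.load_audio("sample.wav")  # placeholder input file

# Batched transcription with a CTranslate2-backed Whisper model.
model = whisperx.load_model("large-v3-turbo", device=device, compute_type="float16")
result = model.transcribe(audio, batch_size=16)  # -> {"segments": [...], "language": ...}

# Forced alignment with a language-specific model for word-level timestamps.
align_model, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
result = whisperx.align(result["segments"], align_model, metadata, audio, device,
                        return_char_alignments=False)

for seg in result["segments"]:
    print(seg["start"], seg["end"], seg["text"])  # each segment also carries "words" with scores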
 