liuyang committed on
Commit
25a2b6b
·
1 Parent(s): 57aeeb0

Enhance audio transcription by adding support for the 'faster_whisper' engine alongside 'whisperx'. Implement lazy loading for both transcription models and improve handling of transcription options. Update the transcribe_full_audio method to accommodate engine selection and adjust the alignment process accordingly.

Browse files
Files changed (1) hide show
  1. app.py +178 -54
app.py CHANGED
@@ -395,6 +395,8 @@ def _process_single_chunk(task: dict, out_dir: str) -> dict:
395
  # Lazy global holder ----------------------------------------------------------
396
  _whipser_x_transcribe_models = {}
397
  _whipser_x_align_models = {}
 
 
398
 
399
  _diarizer = None
400
  _embedder = None
@@ -502,77 +504,198 @@ class WhisperTranscriber:
502
  return meta
503
 
504
  @spaces.GPU # each call gets a GPU slice
505
- def transcribe_full_audio(self, audio_path, language=None, translate=False, prompt=None, batch_size=16, base_offset_s: float = 0.0, clip_timestamps=None, model_name: str = DEFAULT_MODEL, transcribe_options: dict = None):
506
- """Transcribe the entire audio file using WhisperX with alignment"""
507
- global _whipser_x_transcribe_models, _whipser_x_align_models
508
-
509
- # Load whisperX model lazily on first use (within GPU context)
510
- if model_name not in _whipser_x_transcribe_models:
511
- print(f"Loading WhisperX transcribe model '{model_name}' on GPU...")
512
- if model_name not in MODELS:
513
- raise ValueError(f"Model '{model_name}' not found in MODELS registry. Available: {list(MODELS.keys())}")
514
-
515
- whisperx_model_name = MODELS[model_name]["whisperx_name"]
516
- device = "cuda"
517
- compute_type = "float16"
518
-
519
- whisper_model = whisperx.load_model(
520
- whisperx_model_name,
521
- device=device,
522
- compute_type=compute_type,
523
- download_root=CACHE_ROOT
524
- )
525
- _whipser_x_transcribe_models[model_name] = whisper_model
526
- print(f"WhisperX transcribe model '{model_name}' loaded successfully")
527
- else:
528
- whisper_model = _whipser_x_transcribe_models[model_name]
529
-
530
- print(f"Transcribing full audio with WhisperX model '{model_name}' and batch size {batch_size}...")
531
  start_time = time.time()
532
-
533
- # Load audio with whisperx
534
  audio = whisperx.load_audio(audio_path)
535
  print(audio_path)
536
- # Transcribe with whisperx
537
- result = whisper_model.transcribe(
538
- audio,
539
- language=language,
540
- batch_size=batch_size,
541
- #initial_prompt=prompt,
542
- #task="translate" if translate else "transcribe"
543
- )
544
-
545
- detected_language = result.get("language", language if language else "unknown")
546
- segments = result.get("segments", [])
547
-
548
- print(f"Detected language: {detected_language}, segments: {len(segments)}, transcribing done in {time.time() - start_time:.2f} seconds")
549
- print(segments)
550
- # Align whisper output with alignment model if language is supported
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
551
  if detected_language in _whipser_x_align_models:
552
  print(f"Performing WhisperX alignment for language '{detected_language}'...")
553
  align_start = time.time()
554
  try:
555
  align_info = _whipser_x_align_models[detected_language]
556
-
557
- result = whisperx.align(
558
- result["segments"],
559
  align_info["model"],
560
  align_info["metadata"],
561
  audio,
562
  "cuda",
563
  return_char_alignments=False
564
  )
565
- segments = result.get("segments", segments)
566
  print(f"WhisperX alignment completed in {time.time() - align_start:.2f} seconds")
567
  except Exception as e:
568
  print(f"WhisperX alignment failed: {e}, using original timestamps")
569
  else:
570
  print(f"No WhisperX alignment model available for language '{detected_language}', using original timestamps")
571
-
572
  # Process segments into the expected format
573
  results = []
574
  for seg in segments:
575
- # Create result entry with detailed format
576
  words_list = []
577
  if "words" in seg:
578
  for word in seg["words"]:
@@ -581,18 +704,19 @@ class WhisperTranscriber:
581
  "end": float(word.get("end", 0.0)) + float(base_offset_s),
582
  "word": word.get("word", ""),
583
  "probability": word.get("score", 1.0),
584
- "speaker": "SPEAKER_00" # No speaker identification in full transcription
585
  })
586
-
587
  results.append({
588
  "start": float(seg.get("start", 0.0)) + float(base_offset_s),
589
  "end": float(seg.get("end", 0.0)) + float(base_offset_s),
590
  "text": seg.get("text", ""),
591
- "speaker": "SPEAKER_00", # Single speaker assumption
592
  "avg_logprob": seg.get("avg_logprob", 0.0) if "avg_logprob" in seg else 0.0,
593
  "words": words_list,
594
  "duration": float(seg.get("end", 0.0)) - float(seg.get("start", 0.0))
595
  })
 
596
  print(results)
597
  transcription_time = time.time() - start_time
598
  print(f"Full audio transcribed and aligned in {transcription_time:.2f} seconds using batch size {batch_size}")
@@ -1043,7 +1167,7 @@ class WhisperTranscriber:
1043
 
1044
  # Step 2: Transcribe full audio once
1045
  transcription_results, detected_language = self.transcribe_full_audio(
1046
- wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s, clip_timestamps=None, model_name=model_name, transcribe_options=transcribe_options
1047
  )
1048
 
1049
  # Step 6: Return results
@@ -1094,7 +1218,7 @@ class WhisperTranscriber:
1094
 
1095
  # Step 2: Transcribe full audio once
1096
  transcription_result, detected_language = self.transcribe_full_audio(
1097
- wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s, clip_timestamps=None, model_name=model_name, transcribe_options=transcribe_options
1098
  )
1099
 
1100
  # Step 6: Return results
 
395
  # Lazy global holder ----------------------------------------------------------
396
  _whipser_x_transcribe_models = {}
397
  _whipser_x_align_models = {}
398
+ _faster_whisper_transcribe_models = {}
399
+ _faster_whisper_batched_pipelines = {}
400
 
401
  _diarizer = None
402
  _embedder = None
 
504
  return meta
505
 
506
  @spaces.GPU # each call gets a GPU slice
507
+ def transcribe_full_audio(self, audio_path, language=None, translate=False, prompt=None, batch_size=16, base_offset_s: float = 0.0, clip_timestamps=None, engine="whisperx", model_name: str = DEFAULT_MODEL, transcribe_options: dict = None):
508
+ """Transcribe the entire audio file using selected engine, then align with WhisperX.
509
+
510
+ engine: "whisperx" | "faster_whisper"
511
+ Always uses WhisperX alignment regardless of transcription engine.
512
+ """
513
+ global _whipser_x_transcribe_models, _whipser_x_align_models, _faster_whisper_transcribe_models
514
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
  start_time = time.time()
516
+
517
+ # Load audio (float32, 16k) once
518
  audio = whisperx.load_audio(audio_path)
519
  print(audio_path)
520
+
521
+ # Resolve engine (allow override from transcribe_options)
522
+ if transcribe_options and isinstance(transcribe_options, dict) and transcribe_options.get("engine"):
523
+ engine = str(transcribe_options.get("engine")).strip().lower()
524
+
525
+ # Transcribe using the selected engine
526
+ initial_segments = []
527
+ detected_language = language if language else "unknown"
528
+
529
+ if engine == "whisperx":
530
+ # Lazy-load WhisperX model on first use
531
+ if model_name not in _whipser_x_transcribe_models:
532
+ print(f"Loading WhisperX transcribe model '{model_name}' on GPU...")
533
+ if model_name not in MODELS:
534
+ raise ValueError(f"Model '{model_name}' not found in MODELS registry. Available: {list(MODELS.keys())}")
535
+ whisperx_model_name = MODELS[model_name]["whisperx_name"]
536
+ device = "cuda"
537
+ compute_type = "float16"
538
+ whisper_model = whisperx.load_model(
539
+ whisperx_model_name,
540
+ device=device,
541
+ compute_type=compute_type,
542
+ download_root=CACHE_ROOT,
543
+ asr_options=transcribe_options
544
+ )
545
+ _whipser_x_transcribe_models[model_name] = whisper_model
546
+ print(f"WhisperX transcribe model '{model_name}' loaded successfully")
547
+ else:
548
+ whisper_model = _whipser_x_transcribe_models[model_name]
549
+
550
+ print(f"Transcribing full audio with WhisperX model '{model_name}' and batch size {batch_size}...")
551
+ result = whisper_model.transcribe(
552
+ audio,
553
+ language=language,
554
+ batch_size=batch_size,
555
+ #initial_prompt=prompt,
556
+ #task="translate" if translate else "transcribe"
557
+ )
558
+ detected_language = result.get("language", detected_language)
559
+ initial_segments = result.get("segments", [])
560
+
561
+ elif engine == "faster_whisper":
562
+ # Lazy-load Faster-Whisper model on first use
563
+ if model_name not in _faster_whisper_transcribe_models:
564
+ print(f"Loading Faster-Whisper transcribe model '{model_name}' on GPU...")
565
+ # Use the same name by default; extend MODELS with specific mapping if needed
566
+ faster_name = MODELS.get(model_name, {}).get("whisperx_name", model_name)
567
+ fw_model = WhisperModel(
568
+ faster_name,
569
+ device="cuda",
570
+ compute_type="float16",
571
+ download_root=CACHE_ROOT,
572
+ )
573
+ _faster_whisper_transcribe_models[model_name] = fw_model
574
+ print(f"Faster-Whisper transcribe model '{model_name}' loaded successfully")
575
+ else:
576
+ fw_model = _faster_whisper_transcribe_models[model_name]
577
+
578
+ print(f"Transcribing full audio with Faster-Whisper model '{model_name}' and batch size {batch_size}...")
579
+ task = "translate" if translate else "transcribe"
580
+
581
+ # Build kwargs from transcribe_options for Faster-Whisper's transcribe API
582
+ fw_kwargs = {}
583
+ if isinstance(transcribe_options, dict):
584
+ allowed = {
585
+ "log_progress",
586
+ "beam_size",
587
+ "best_of",
588
+ "patience",
589
+ "length_penalty",
590
+ "repetition_penalty",
591
+ "no_repeat_ngram_size",
592
+ "temperature",
593
+ "compression_ratio_threshold",
594
+ "log_prob_threshold",
595
+ "no_speech_threshold",
596
+ "condition_on_previous_text",
597
+ "prompt_reset_on_temperature",
598
+ "initial_prompt",
599
+ "prefix",
600
+ "suppress_blank",
601
+ "suppress_tokens",
602
+ "without_timestamps",
603
+ "max_initial_timestamp",
604
+ #"word_timestamps",
605
+ #"prepend_punctuations",
606
+ #"append_punctuations",
607
+ "multilingual",
608
+ "vad_filter",
609
+ "vad_parameters",
610
+ "max_new_tokens",
611
+ "chunk_length",
612
+ "clip_timestamps",
613
+ "hallucination_silence_threshold",
614
+ "batch_size",
615
+ "hotwords",
616
+ "language_detection_threshold",
617
+ "language_detection_segments",
618
+ }
619
+ for k in allowed:
620
+ if k in transcribe_options and transcribe_options[k] is not None:
621
+ fw_kwargs[k] = transcribe_options[k]
622
+
623
+ # Ensure sensible defaults and avoid duplicates
624
+ if "initial_prompt" not in fw_kwargs and prompt is not None:
625
+ fw_kwargs["initial_prompt"] = prompt
626
+ if "batch_size" not in fw_kwargs and batch_size is not None:
627
+ fw_kwargs["batch_size"] = batch_size
628
+ if "vad_filter" not in fw_kwargs:
629
+ fw_kwargs["vad_filter"] = False # preserve boundaries for alignment
630
+
631
+ # language and task are passed explicitly; do not include in fw_kwargs
632
+ fw_kwargs.pop("language", None)
633
+ fw_kwargs.pop("task", None)
634
+ fw_kwargs["prepend_punctuations"] = "\"'“¿([{-"
635
+ fw_kwargs["append_punctuations"] = "\"'.。,,!!??::”)]}、"
636
+ fw_kwargs["without_timestamps"] = True
637
+ fw_kwargs["max_initial_timestamp"] = 0.0
638
+ fw_kwargs["word_timestamps"] = False
639
+
640
+ # Choose between single and batched transcription per docs
641
+ effective_bs = int(fw_kwargs.get("batch_size", batch_size if batch_size is not None else 8))
642
+ use_batched = effective_bs > 1
643
+
644
+ # Note: pass numpy audio
645
+ if use_batched:
646
+ if model_name not in _faster_whisper_batched_pipelines:
647
+ _faster_whisper_batched_pipelines[model_name] = BatchedInferencePipeline(model=fw_model)
648
+ batched_model = _faster_whisper_batched_pipelines[model_name]
649
+ segments_iter, info = batched_model.transcribe(
650
+ audio,
651
+ language=language,
652
+ task=task,
653
+ **fw_kwargs,
654
+ )
655
+ else:
656
+ segments_iter, info = fw_model.transcribe(
657
+ audio,
658
+ language=language,
659
+ task=task,
660
+ **fw_kwargs,
661
+ )
662
+ detected_language = getattr(info, "language", detected_language)
663
+ # Convert to WhisperX-like segment dicts
664
+ initial_segments = [{
665
+ "start": float(s.start),
666
+ "end": float(s.end),
667
+ "text": s.text or "",
668
+ } for s in segments_iter]
669
+
670
+ else:
671
+ raise ValueError(f"Unknown engine '{engine}'. Supported: 'whisperx', 'faster_whisper'")
672
+
673
+ print(f"Detected language: {detected_language}, segments: {len(initial_segments)}, transcribing done in {time.time() - start_time:.2f} seconds")
674
+ # Align with WhisperX if supported for detected language (always attempt when available)
675
+ segments = initial_segments
676
  if detected_language in _whipser_x_align_models:
677
  print(f"Performing WhisperX alignment for language '{detected_language}'...")
678
  align_start = time.time()
679
  try:
680
  align_info = _whipser_x_align_models[detected_language]
681
+ align_result = whisperx.align(
682
+ initial_segments,
 
683
  align_info["model"],
684
  align_info["metadata"],
685
  audio,
686
  "cuda",
687
  return_char_alignments=False
688
  )
689
+ segments = align_result.get("segments", segments)
690
  print(f"WhisperX alignment completed in {time.time() - align_start:.2f} seconds")
691
  except Exception as e:
692
  print(f"WhisperX alignment failed: {e}, using original timestamps")
693
  else:
694
  print(f"No WhisperX alignment model available for language '{detected_language}', using original timestamps")
695
+
696
  # Process segments into the expected format
697
  results = []
698
  for seg in segments:
 
699
  words_list = []
700
  if "words" in seg:
701
  for word in seg["words"]:
 
704
  "end": float(word.get("end", 0.0)) + float(base_offset_s),
705
  "word": word.get("word", ""),
706
  "probability": word.get("score", 1.0),
707
+ "speaker": "SPEAKER_00"
708
  })
709
+
710
  results.append({
711
  "start": float(seg.get("start", 0.0)) + float(base_offset_s),
712
  "end": float(seg.get("end", 0.0)) + float(base_offset_s),
713
  "text": seg.get("text", ""),
714
+ "speaker": "SPEAKER_00",
715
  "avg_logprob": seg.get("avg_logprob", 0.0) if "avg_logprob" in seg else 0.0,
716
  "words": words_list,
717
  "duration": float(seg.get("end", 0.0)) - float(seg.get("start", 0.0))
718
  })
719
+
720
  print(results)
721
  transcription_time = time.time() - start_time
722
  print(f"Full audio transcribed and aligned in {transcription_time:.2f} seconds using batch size {batch_size}")
 
1167
 
1168
  # Step 2: Transcribe full audio once
1169
  transcription_results, detected_language = self.transcribe_full_audio(
1170
+ wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s, engine=transcribe_options.get("engine", "whisperx"), model_name=model_name, transcribe_options=transcribe_options
1171
  )
1172
 
1173
  # Step 6: Return results
 
1218
 
1219
  # Step 2: Transcribe full audio once
1220
  transcription_result, detected_language = self.transcribe_full_audio(
1221
+ wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s, engine=transcribe_options.get("engine", "faster_whisper"), model_name=model_name, transcribe_options=transcribe_options
1222
  )
1223
 
1224
  # Step 6: Return results