liuyang committed
Commit 3de05cb · 1 Parent(s): d2ef882

switch to whisperX

Files changed (2):
  1. app.py +115 -100
  2. requirements.txt +2 -3
app.py CHANGED
@@ -35,8 +35,7 @@ import subprocess
 import os
 import tempfile
 import spaces
-from faster_whisper import WhisperModel, BatchedInferencePipeline
-from faster_whisper.vad import VadOptions
+import whisperx
 import requests
 import base64
 from pyannote.audio import Pipeline, Inference, Model
@@ -118,39 +117,19 @@ from huggingface_hub import snapshot_download
 # -----------------------------------------------------------------------------
 MODELS = {
     "large-v3-turbo": {
-        "repo_id": "deepdml/faster-whisper-large-v3-turbo-ct2",
-        "local_dir": f"{CACHE_ROOT}/whisper_turbo_v3"
+        "whisperx_name": "large-v3-turbo",
     },
     "large-v3": {
-        "repo_id": "Systran/faster-whisper-large-v3",
-        "local_dir": f"{CACHE_ROOT}/whisper_large_v3"
+        "whisperx_name": "large-v3",
     },
     "large-v2": {
-        "repo_id": "Systran/faster-whisper-large-v2",
-        "local_dir": f"{CACHE_ROOT}/whisper_large_v2"
+        "whisperx_name": "large-v2",
     },
 }
 DEFAULT_MODEL = "large-v3-turbo"
 
-def _download_model(model_name: str):
-    """Downloads a model from the hub if not already present."""
-    if model_name not in MODELS:
-        raise ValueError(f"Model '{model_name}' not found in MODELS registry.")
-
-    model_info = MODELS[model_name]
-    if not os.path.exists(model_info["local_dir"]):
-        print(f"Downloading model '{model_name}' from {model_info['repo_id']}...")
-        snapshot_download(
-            repo_id=model_info["repo_id"],
-            local_dir=model_info["local_dir"],
-            local_dir_use_symlinks=True,
-            resume_download=True
-        )
-    return model_info["local_dir"]
-
-# Download the default model on startup
-for model in MODELS:
-    _download_model(model)
+# Supported languages for alignment models
+ALIGN_LANGUAGES = ["en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh", "ar", "nl", "tr", "pl", "cs", "sv", "da", "fi", "no", "uk"]
 
 
 # -----------------------------------------------------------------------------
@@ -378,9 +357,54 @@ def _process_single_chunk(task: dict, out_dir: str) -> dict:
 # Lazy global holder ----------------------------------------------------------
 _whisper_models = {}
 _batched_whisper_models = {}
+_whipser_x_transcribe_models = {}
+_whipser_x_align_models = {}
+
 _diarizer = None
 _embedder = None
 
+# Preload all WhisperX transcribe models
+print("Preloading all WhisperX transcribe models...")
+for model_name in MODELS.keys():
+    try:
+        print(f"Loading WhisperX model '{model_name}'...")
+        whisperx_model_name = MODELS[model_name]["whisperx_name"]
+        device = "cpu"  # Load on CPU initially, will move to GPU when needed
+        compute_type = "float16"
+
+        model = whisperx.load_model(
+            whisperx_model_name,
+            device=device,
+            compute_type=compute_type,
+            download_root=CACHE_ROOT
+        )
+        _whipser_x_transcribe_models[model_name] = model
+        print(f"WhisperX model '{model_name}' loaded successfully")
+    except Exception as e:
+        import traceback
+        traceback.print_exc()
+        print(f"Could not load WhisperX model '{model_name}': {e}")
+
+# Preload all alignment models for supported languages
+print("Preloading all WhisperX alignment models...")
+for lang in ALIGN_LANGUAGES:
+    try:
+        print(f"Loading alignment model for language '{lang}'...")
+        device = "cpu"  # Load on CPU initially, will move to GPU when needed
+
+        align_model, align_metadata = whisperx.load_align_model(
+            language_code=lang,
+            device=device,
+            model_dir=CACHE_ROOT
+        )
+        _whipser_x_align_models[lang] = {
+            "model": align_model,
+            "metadata": align_metadata
+        }
+        print(f"Alignment model for '{lang}' loaded successfully")
+    except Exception as e:
+        print(f"Could not load alignment model for '{lang}': {e}")
+
 # Create global diarization pipeline
 try:
     print("Loading diarization model...")
@@ -402,31 +426,22 @@ except Exception as e:
 
 @spaces.GPU  # GPU is guaranteed to exist *inside* this function
 def _load_models(model_name: str = DEFAULT_MODEL):
-    global _whisper_models, _batched_whisper_models, _diarizer
+    global _whipser_x_transcribe_models, _whipser_x_align_models, _diarizer
 
-    if model_name not in _whisper_models:
-        print(f"Loading Whisper model '{model_name}'...")
-
-        model_cache_path = _download_model(model_name)
-
-        model = WhisperModel(
-            model_cache_path,
-            device="cuda",
-            compute_type="float16",
-        )
-
-        # Create batched inference pipeline for improved performance
-        batched_model = BatchedInferencePipeline(model=model)
-
-        _whisper_models[model_name] = model
-        _batched_whisper_models[model_name] = batched_model
-
-        print(f"Whisper model '{model_name}' and batched pipeline loaded successfully")
-
-    whisper = _whisper_models[model_name]
-    batched_whisper = _batched_whisper_models[model_name]
+    if model_name not in _whipser_x_transcribe_models:
+        raise ValueError(f"Model '{model_name}' not preloaded. Available models: {list(_whipser_x_transcribe_models.keys())}")
 
-    return whisper, batched_whisper, _diarizer
+    whisper_model = _whipser_x_transcribe_models[model_name]
+
+    # Move model to GPU if not already
+    if hasattr(whisper_model, 'model') and hasattr(whisper_model.model, 'device'):
+        current_device = str(whisper_model.model.device)
+        if 'cpu' in current_device:
+            print(f"Moving WhisperX model '{model_name}' to GPU...")
+            whisper_model = whisper_model.to("cuda")
+            _whipser_x_transcribe_models[model_name] = whisper_model
+
+    return whisper_model, _diarizer
 
 # -----------------------------------------------------------------------------
 class WhisperTranscriber:
@@ -455,76 +470,76 @@ class WhisperTranscriber:
 
     @spaces.GPU  # each call gets a GPU slice
     def transcribe_full_audio(self, audio_path, language=None, translate=False, prompt=None, batch_size=16, base_offset_s: float = 0.0, clip_timestamps=None, model_name: str = DEFAULT_MODEL, transcribe_options: dict = None):
-        """Transcribe the entire audio file without speaker diarization using batched inference"""
-        whisper, batched_whisper, _ = _load_models(model_name)  # models live on the GPU
+        """Transcribe the entire audio file without speaker diarization using WhisperX"""
+        whisper_model, _ = _load_models(model_name)  # models live on the GPU
 
-        print(f"Transcribing full audio with '{model_name}' and batch size {batch_size}...")
+        print(f"Transcribing full audio with WhisperX model '{model_name}' and batch size {batch_size}...")
         start_time = time.time()
 
-        # Prepare options for batched inference
-        options = dict(
+        # Load audio with whisperx
+        audio = whisperx.load_audio(audio_path)
+
+        # Transcribe with whisperx
+        result = whisper_model.transcribe(
+            audio,
             language=language,
-            beam_size=5,
-            word_timestamps=True,
+            batch_size=batch_size,
             initial_prompt=prompt,
-            condition_on_previous_text=False,  # avoid runaway context
-            language_detection_segments=1,
-            task="translate" if translate else "transcribe",
+            task="translate" if translate else "transcribe"
         )
-        if clip_timestamps:
-            options["vad_filter"] = False
-            options["clip_timestamps"] = clip_timestamps
-        else:
-            vad_options = transcribe_options.get("vad_parameters", None)
-            options["vad_filter"] = True  # VAD is enabled by default for batched transcription
-            options["vad_parameters"] = VadOptions(**vad_options) if vad_options else VadOptions(
-                max_speech_duration_s=whisper.feature_extractor.chunk_length,
-                min_speech_duration_ms=180,  # ignore ultra-short blips
-                min_silence_duration_ms=120,  # split on short Mandarin pauses (if supported)
-                speech_pad_ms=120,
-                threshold=0.35,
-                neg_threshold=0.2,
-            )
-        if batch_size > 1:
-            # Use batched inference for better performance
-            segments, transcript_info = batched_whisper.transcribe(
-                audio_path,
-                batch_size=batch_size,
-                **options
+
+        detected_language = result.get("language", language if language else "unknown")
+        segments = result.get("segments", [])
+
+        print(f"Detected language: {detected_language}, segments: {len(segments)}")
+
+        # Align whisper output with alignment model if language is supported
+        if detected_language in _whipser_x_align_models:
+            print(f"Performing alignment for language '{detected_language}'...")
+            align_info = _whipser_x_align_models[detected_language]
+
+            # Move alignment model to GPU if needed
+            align_model = align_info["model"]
+            if hasattr(align_model, 'to'):
+                align_model = align_model.to("cuda")
+                _whipser_x_align_models[detected_language]["model"] = align_model
+
+            result = whisperx.align(
+                result["segments"],
                align_info["model"],
                align_info["metadata"],
                audio,
                "cuda",
                return_char_alignments=False
             )
+            segments = result.get("segments", segments)
+            print(f"Alignment completed")
         else:
-            segments, transcript_info = whisper.transcribe(
-                audio_path,
-                **options
-            )
-        segments = list(segments)
-
-        detected_language = transcript_info.language
-        print("Detected language: ", detected_language, "segments: ", len(segments))
+            print(f"No alignment model available for language '{detected_language}', using original timestamps")
 
-        # Process segments
+        # Process segments into the expected format
         results = []
         for seg in segments:
             # Create result entry with detailed format
             words_list = []
-            if seg.words:
-                for word in seg.words:
+            if "words" in seg:
+                for word in seg["words"]:
                     words_list.append({
-                        "start": float(word.start) + float(base_offset_s),
-                        "end": float(word.end) + float(base_offset_s),
-                        "word": word.word,
-                        "probability": word.probability,
+                        "start": float(word.get("start", 0.0)) + float(base_offset_s),
+                        "end": float(word.get("end", 0.0)) + float(base_offset_s),
+                        "word": word.get("word", ""),
+                        "probability": word.get("score", 1.0),
                         "speaker": "SPEAKER_00"  # No speaker identification in full transcription
                     })
 
             results.append({
-                "start": float(seg.start) + float(base_offset_s),
-                "end": float(seg.end) + float(base_offset_s),
-                "text": seg.text,
+                "start": float(seg.get("start", 0.0)) + float(base_offset_s),
+                "end": float(seg.get("end", 0.0)) + float(base_offset_s),
+                "text": seg.get("text", ""),
                 "speaker": "SPEAKER_00",  # Single speaker assumption
-                "avg_logprob": seg.avg_logprob,
+                "avg_logprob": seg.get("avg_logprob", 0.0) if "avg_logprob" in seg else 0.0,
                 "words": words_list,
-                "duration": float(seg.end - seg.start)
+                "duration": float(seg.get("end", 0.0)) - float(seg.get("start", 0.0))
             })
 
         transcription_time = time.time() - start_time
@@ -549,7 +564,7 @@ class WhisperTranscriber:
         try:
            embedder = self._load_embedder()
            # Provide waveform as (channel, time) and pad if too short
-            min_embed_duration_sec = 3.0
+            min_embed_duration_sec = 1.0
            min_samples = int(min_embed_duration_sec * sample_rate)
            if waveform.shape[1] < min_samples:
                pad_len = min_samples - waveform.shape[1]
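
For reference, the new transcription path boils down to WhisperX's load → transcribe → align flow. The standalone sketch below distills the calls this commit relies on (whisperx.load_model, whisperx.load_audio, whisperx.load_align_model, whisperx.align); the model name, the "sample.wav" path, and loading straight onto CUDA are illustrative assumptions, not part of the commit:

    # Minimal sketch of the WhisperX flow used in app.py above.
    # Assumes whisperx is installed, a CUDA device is available, and a
    # local "sample.wav" exists -- all illustrative assumptions.
    import whisperx

    device = "cuda"

    # Batched transcription (whisperx runs a faster-whisper backend under the hood)
    model = whisperx.load_model("large-v3-turbo", device, compute_type="float16")
    audio = whisperx.load_audio("sample.wav")
    result = model.transcribe(audio, batch_size=16)

    # Second pass: forced alignment for word-level timestamps,
    # available only for languages with an alignment model
    align_model, metadata = whisperx.load_align_model(
        language_code=result["language"], device=device
    )
    aligned = whisperx.align(
        result["segments"], align_model, metadata, audio, device,
        return_char_alignments=False
    )

    for seg in aligned["segments"]:
        # Aligned segments carry start/end times; each word carries a "score"
        print(f'[{seg["start"]:.2f} - {seg["end"]:.2f}] {seg["text"]}')

Note the split in app.py: models are preloaded on CPU at import time and only moved to CUDA inside the @spaces.GPU-decorated functions, since (as the code comments note) the GPU is only guaranteed to exist inside those calls.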
requirements.txt CHANGED
@@ -4,9 +4,8 @@ transformers==4.48.0
 # https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.8/flash_attn-2.7.4.post1+cu126torch2.4-cp310-cp310-linux_x86_64.whl
 pydantic==2.10.6
 
-# 2. Main whisper model
-faster-whisper==1.1.1
-ctranslate2==4.5.0
+# 2. Main whisper model - using whisperx instead of faster-whisper
+whisperx
 torch
 
 # 3. Extra libs your app really needs
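
One packaging caveat: the pinned faster-whisper==1.1.1 / ctranslate2==4.5.0 pair is replaced by an unpinned whisperx, so a rebuild of the Space may pull a newer whisperx with a different API surface. If reproducibility matters, a pinned line is cheap insurance; the version below is illustrative, not taken from this commit:

    # 2. Main whisper model - using whisperx instead of faster-whisper
    whisperx==3.1.1  # example pin; use whatever version the Space was tested against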