liuyang committed on
Commit
c18172e
·
1 Parent(s): 02f099e
Files changed (1) hide show
  1. app.py +54 -43
app.py CHANGED
@@ -363,34 +363,12 @@ _whipser_x_align_models = {}
363
  _diarizer = None
364
  _embedder = None
365
 
366
- # Preload all WhisperX transcribe models
367
- print("Preloading all WhisperX transcribe models...")
368
- for model_name in MODELS.keys():
369
- try:
370
- print(f"Loading WhisperX model '{model_name}'...")
371
- whisperx_model_name = MODELS[model_name]["whisperx_name"]
372
- device = "cuda" # Load on CPU initially, will move to GPU when needed
373
- compute_type = "float16"
374
-
375
- model = whisperx.load_model(
376
- whisperx_model_name,
377
- device=device,
378
- compute_type=compute_type,
379
- download_root=CACHE_ROOT
380
- )
381
- _whipser_x_transcribe_models[model_name] = model
382
- print(f"WhisperX model '{model_name}' loaded successfully")
383
- except Exception as e:
384
- import traceback
385
- traceback.print_exc()
386
- print(f"Could not load WhisperX model '{model_name}': {e}")
387
-
388
- # Preload all alignment models for supported languages
389
  print("Preloading all WhisperX alignment models...")
390
  for lang in ALIGN_LANGUAGES:
391
  try:
392
  print(f"Loading alignment model for language '{lang}'...")
393
- device = "cuda" # Load on CPU initially, will move to GPU when needed
394
 
395
  align_model, align_metadata = whisperx.load_align_model(
396
  language_code=lang,
@@ -405,7 +383,7 @@ for lang in ALIGN_LANGUAGES:
405
  except Exception as e:
406
  print(f"Could not load alignment model for '{lang}': {e}")
407
 
408
- # Create global diarization pipeline
409
  try:
410
  print("Loading diarization model...")
411
  torch.backends.cuda.matmul.allow_tf32 = True
@@ -424,25 +402,61 @@ except Exception as e:
424
  print(f"Could not load diarization model: {e}")
425
  _diarizer = None
426
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
  @spaces.GPU # GPU is guaranteed to exist *inside* this function
428
  def _load_models(model_name: str = DEFAULT_MODEL):
429
  global _whipser_x_transcribe_models, _whipser_x_align_models, _diarizer
430
 
431
- if model_name not in _whipser_x_transcribe_models:
432
- raise ValueError(f"Model '{model_name}' not preloaded. Available models: {list(_whipser_x_transcribe_models.keys())}")
433
-
434
- whisper_model = _whipser_x_transcribe_models[model_name]
435
-
436
- # Move model to GPU if not already
437
- if hasattr(whisper_model, 'model') and hasattr(whisper_model.model, 'device'):
438
- current_device = str(whisper_model.model.device)
439
- if 'cpu' in current_device:
440
- print(f"Moving WhisperX model '{model_name}' to GPU...")
441
- whisper_model = whisper_model.to("cuda")
442
- _whipser_x_transcribe_models[model_name] = whisper_model
443
 
444
  return whisper_model, _diarizer
445
 
 
 
 
 
 
 
 
 
 
 
 
 
446
  # -----------------------------------------------------------------------------
447
  class WhisperTranscriber:
448
  def __init__(self):
@@ -498,12 +512,6 @@ class WhisperTranscriber:
498
  print(f"Performing alignment for language '{detected_language}'...")
499
  align_info = _whipser_x_align_models[detected_language]
500
 
501
- # Move alignment model to GPU if needed
502
- align_model = align_info["model"]
503
- if hasattr(align_model, 'to'):
504
- align_model = align_model.to("cuda")
505
- _whipser_x_align_models[detected_language]["model"] = align_model
506
-
507
  result = whisperx.align(
508
  result["segments"],
509
  align_info["model"],
@@ -1501,6 +1509,9 @@ with demo:
1501
  - Languages: Supports 100+ languages with auto-detection
1502
  - Vocabulary: Add names and technical terms in the prompt for better accuracy
1503
  """)
 
 
 
1504
 
1505
  if __name__ == "__main__":
1506
  demo.launch(debug=True)
 
363
  _diarizer = None
364
  _embedder = None
365
 
366
+ # Preload alignment and diarization models at startup (no GPU decorator needed)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
  print("Preloading all WhisperX alignment models...")
368
  for lang in ALIGN_LANGUAGES:
369
  try:
370
  print(f"Loading alignment model for language '{lang}'...")
371
+ device = "cuda"
372
 
373
  align_model, align_metadata = whisperx.load_align_model(
374
  language_code=lang,
 
383
  except Exception as e:
384
  print(f"Could not load alignment model for '{lang}': {e}")
385
 
386
+ # Create global diarization pipeline at startup
387
  try:
388
  print("Loading diarization model...")
389
  torch.backends.cuda.matmul.allow_tf32 = True
 
402
  print(f"Could not load diarization model: {e}")
403
  _diarizer = None
404
 
405
+ print("Alignment and diarization models preloaded successfully!")
406
+
407
@spaces.GPU  # GPU is guaranteed to exist *inside* this function
def _load_whisper_model(model_name: str):
    """Load (and cache) a single WhisperX transcribe model on the GPU.

    Models are loaded lazily: the first request for a given name calls
    ``whisperx.load_model``; later requests return the cached instance
    from ``_whipser_x_transcribe_models``.

    Args:
        model_name: Key into the ``MODELS`` registry.

    Returns:
        The loaded WhisperX model instance.

    Raises:
        ValueError: If ``model_name`` is not a key of ``MODELS``.
        RuntimeError: If ``whisperx.load_model`` fails; the original
            exception is chained as ``__cause__``.
    """
    global _whipser_x_transcribe_models

    # Cache hit: model was already loaded by an earlier GPU call.
    if model_name in _whipser_x_transcribe_models:
        print(f"WhisperX model '{model_name}' already loaded")
        return _whipser_x_transcribe_models[model_name]

    if model_name not in MODELS:
        raise ValueError(f"Model '{model_name}' not found in MODELS registry. Available: {list(MODELS.keys())}")

    print(f"Loading WhisperX model '{model_name}' on GPU...")
    whisperx_model_name = MODELS[model_name]["whisperx_name"]
    device = "cuda"
    compute_type = "float16"

    # Keep the try body minimal: only the load call can raise here.
    try:
        model = whisperx.load_model(
            whisperx_model_name,
            device=device,
            compute_type=compute_type,
            download_root=CACHE_ROOT,
        )
    except Exception as e:
        import traceback
        traceback.print_exc()
        # Chain the original error so callers can inspect the root cause.
        raise RuntimeError(f"Could not load WhisperX model '{model_name}': {e}") from e

    _whipser_x_transcribe_models[model_name] = model
    print(f"WhisperX model '{model_name}' loaded successfully")
    return model
438
+
439
@spaces.GPU  # GPU is guaranteed to exist *inside* this function
def _load_models(model_name: str = DEFAULT_MODEL):
    """Return the (transcribe model, diarizer) pair for *model_name*.

    The transcribe model is fetched lazily via ``_load_whisper_model``;
    the diarizer is the process-wide singleton created at startup.
    """
    global _whipser_x_transcribe_models, _whipser_x_align_models, _diarizer

    return _load_whisper_model(model_name), _diarizer
447
 
448
# Optional: Preload all whisper models explicitly
@spaces.GPU
def preload_all_whisper_models():
    """Warm every WhisperX transcribe model so first requests are fast.

    A failure for one model is reported but does not abort the
    remaining preloads.
    """
    print("Preloading all WhisperX transcribe models...")
    for name in MODELS:
        try:
            _load_whisper_model(name)
        except Exception as e:
            print(f"Failed to preload model '{name}': {e}")
    print("All WhisperX transcribe models preloaded!")
459
+
460
  # -----------------------------------------------------------------------------
461
  class WhisperTranscriber:
462
  def __init__(self):
 
512
  print(f"Performing alignment for language '{detected_language}'...")
513
  align_info = _whipser_x_align_models[detected_language]
514
 
 
 
 
 
 
 
515
  result = whisperx.align(
516
  result["segments"],
517
  align_info["model"],
 
1509
  - Languages: Supports 100+ languages with auto-detection
1510
  - Vocabulary: Add names and technical terms in the prompt for better accuracy
1511
  """)
1512
+
1513
+ # Preload all whisper models on startup
1514
+ demo.load(fn=preload_all_whisper_models, inputs=None, outputs=None)
1515
 
1516
  if __name__ == "__main__":
1517
  demo.launch(debug=True)