Spaces:
Runtime error
Runtime error
liuyang
committed on
Commit
·
c18172e
1
Parent(s):
02f099e
preload
Browse files
app.py
CHANGED
|
@@ -363,34 +363,12 @@ _whipser_x_align_models = {}
|
|
| 363 |
_diarizer = None
|
| 364 |
_embedder = None
|
| 365 |
|
| 366 |
-
# Preload
|
| 367 |
-
print("Preloading all WhisperX transcribe models...")
|
| 368 |
-
for model_name in MODELS.keys():
|
| 369 |
-
try:
|
| 370 |
-
print(f"Loading WhisperX model '{model_name}'...")
|
| 371 |
-
whisperx_model_name = MODELS[model_name]["whisperx_name"]
|
| 372 |
-
device = "cuda" # Load on CPU initially, will move to GPU when needed
|
| 373 |
-
compute_type = "float16"
|
| 374 |
-
|
| 375 |
-
model = whisperx.load_model(
|
| 376 |
-
whisperx_model_name,
|
| 377 |
-
device=device,
|
| 378 |
-
compute_type=compute_type,
|
| 379 |
-
download_root=CACHE_ROOT
|
| 380 |
-
)
|
| 381 |
-
_whipser_x_transcribe_models[model_name] = model
|
| 382 |
-
print(f"WhisperX model '{model_name}' loaded successfully")
|
| 383 |
-
except Exception as e:
|
| 384 |
-
import traceback
|
| 385 |
-
traceback.print_exc()
|
| 386 |
-
print(f"Could not load WhisperX model '{model_name}': {e}")
|
| 387 |
-
|
| 388 |
-
# Preload all alignment models for supported languages
|
| 389 |
print("Preloading all WhisperX alignment models...")
|
| 390 |
for lang in ALIGN_LANGUAGES:
|
| 391 |
try:
|
| 392 |
print(f"Loading alignment model for language '{lang}'...")
|
| 393 |
-
device = "cuda"
|
| 394 |
|
| 395 |
align_model, align_metadata = whisperx.load_align_model(
|
| 396 |
language_code=lang,
|
|
@@ -405,7 +383,7 @@ for lang in ALIGN_LANGUAGES:
|
|
| 405 |
except Exception as e:
|
| 406 |
print(f"Could not load alignment model for '{lang}': {e}")
|
| 407 |
|
| 408 |
-
# Create global diarization pipeline
|
| 409 |
try:
|
| 410 |
print("Loading diarization model...")
|
| 411 |
torch.backends.cuda.matmul.allow_tf32 = True
|
|
@@ -424,25 +402,61 @@ except Exception as e:
|
|
| 424 |
print(f"Could not load diarization model: {e}")
|
| 425 |
_diarizer = None
|
| 426 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
@spaces.GPU # GPU is guaranteed to exist *inside* this function
|
| 428 |
def _load_models(model_name: str = DEFAULT_MODEL):
|
| 429 |
global _whipser_x_transcribe_models, _whipser_x_align_models, _diarizer
|
| 430 |
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
whisper_model = _whipser_x_transcribe_models[model_name]
|
| 435 |
-
|
| 436 |
-
# Move model to GPU if not already
|
| 437 |
-
if hasattr(whisper_model, 'model') and hasattr(whisper_model.model, 'device'):
|
| 438 |
-
current_device = str(whisper_model.model.device)
|
| 439 |
-
if 'cpu' in current_device:
|
| 440 |
-
print(f"Moving WhisperX model '{model_name}' to GPU...")
|
| 441 |
-
whisper_model = whisper_model.to("cuda")
|
| 442 |
-
_whipser_x_transcribe_models[model_name] = whisper_model
|
| 443 |
|
| 444 |
return whisper_model, _diarizer
|
| 445 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
# -----------------------------------------------------------------------------
|
| 447 |
class WhisperTranscriber:
|
| 448 |
def __init__(self):
|
|
@@ -498,12 +512,6 @@ class WhisperTranscriber:
|
|
| 498 |
print(f"Performing alignment for language '{detected_language}'...")
|
| 499 |
align_info = _whipser_x_align_models[detected_language]
|
| 500 |
|
| 501 |
-
# Move alignment model to GPU if needed
|
| 502 |
-
align_model = align_info["model"]
|
| 503 |
-
if hasattr(align_model, 'to'):
|
| 504 |
-
align_model = align_model.to("cuda")
|
| 505 |
-
_whipser_x_align_models[detected_language]["model"] = align_model
|
| 506 |
-
|
| 507 |
result = whisperx.align(
|
| 508 |
result["segments"],
|
| 509 |
align_info["model"],
|
|
@@ -1501,6 +1509,9 @@ with demo:
|
|
| 1501 |
- Languages: Supports 100+ languages with auto-detection
|
| 1502 |
- Vocabulary: Add names and technical terms in the prompt for better accuracy
|
| 1503 |
""")
|
|
|
|
|
|
|
|
|
|
| 1504 |
|
| 1505 |
if __name__ == "__main__":
|
| 1506 |
demo.launch(debug=True)
|
|
|
|
| 363 |
_diarizer = None
|
| 364 |
_embedder = None
|
| 365 |
|
| 366 |
+
# Preload alignment and diarization models at startup (no GPU decorator needed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
print("Preloading all WhisperX alignment models...")
|
| 368 |
for lang in ALIGN_LANGUAGES:
|
| 369 |
try:
|
| 370 |
print(f"Loading alignment model for language '{lang}'...")
|
| 371 |
+
device = "cuda"
|
| 372 |
|
| 373 |
align_model, align_metadata = whisperx.load_align_model(
|
| 374 |
language_code=lang,
|
|
|
|
| 383 |
except Exception as e:
|
| 384 |
print(f"Could not load alignment model for '{lang}': {e}")
|
| 385 |
|
| 386 |
+
# Create global diarization pipeline at startup
|
| 387 |
try:
|
| 388 |
print("Loading diarization model...")
|
| 389 |
torch.backends.cuda.matmul.allow_tf32 = True
|
|
|
|
| 402 |
print(f"Could not load diarization model: {e}")
|
| 403 |
_diarizer = None
|
| 404 |
|
| 405 |
+
print("Alignment and diarization models preloaded successfully!")
|
| 406 |
+
|
| 407 |
+
@spaces.GPU # GPU is guaranteed to exist *inside* this function
|
| 408 |
+
def _load_whisper_model(model_name: str):
|
| 409 |
+
"""Load a specific WhisperX transcribe model on GPU (lazy loading)"""
|
| 410 |
+
global _whipser_x_transcribe_models
|
| 411 |
+
|
| 412 |
+
if model_name in _whipser_x_transcribe_models:
|
| 413 |
+
print(f"WhisperX model '{model_name}' already loaded")
|
| 414 |
+
return _whipser_x_transcribe_models[model_name]
|
| 415 |
+
|
| 416 |
+
if model_name not in MODELS:
|
| 417 |
+
raise ValueError(f"Model '{model_name}' not found in MODELS registry. Available: {list(MODELS.keys())}")
|
| 418 |
+
|
| 419 |
+
print(f"Loading WhisperX model '{model_name}' on GPU...")
|
| 420 |
+
whisperx_model_name = MODELS[model_name]["whisperx_name"]
|
| 421 |
+
device = "cuda"
|
| 422 |
+
compute_type = "float16"
|
| 423 |
+
|
| 424 |
+
try:
|
| 425 |
+
model = whisperx.load_model(
|
| 426 |
+
whisperx_model_name,
|
| 427 |
+
device=device,
|
| 428 |
+
compute_type=compute_type,
|
| 429 |
+
download_root=CACHE_ROOT
|
| 430 |
+
)
|
| 431 |
+
_whipser_x_transcribe_models[model_name] = model
|
| 432 |
+
print(f"WhisperX model '{model_name}' loaded successfully")
|
| 433 |
+
return model
|
| 434 |
+
except Exception as e:
|
| 435 |
+
import traceback
|
| 436 |
+
traceback.print_exc()
|
| 437 |
+
raise RuntimeError(f"Could not load WhisperX model '{model_name}': {e}")
|
| 438 |
+
|
| 439 |
@spaces.GPU # GPU is guaranteed to exist *inside* this function
|
| 440 |
def _load_models(model_name: str = DEFAULT_MODEL):
|
| 441 |
global _whipser_x_transcribe_models, _whipser_x_align_models, _diarizer
|
| 442 |
|
| 443 |
+
# Load the specific whisper model (lazy loading)
|
| 444 |
+
whisper_model = _load_whisper_model(model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
|
| 446 |
return whisper_model, _diarizer
|
| 447 |
|
| 448 |
+
# Optional: Preload all whisper models explicitly
|
| 449 |
+
@spaces.GPU
|
| 450 |
+
def preload_all_whisper_models():
|
| 451 |
+
"""Preload all WhisperX transcribe models - optional, for faster first-time use"""
|
| 452 |
+
print("Preloading all WhisperX transcribe models...")
|
| 453 |
+
for model_name in MODELS.keys():
|
| 454 |
+
try:
|
| 455 |
+
_load_whisper_model(model_name)
|
| 456 |
+
except Exception as e:
|
| 457 |
+
print(f"Failed to preload model '{model_name}': {e}")
|
| 458 |
+
print("All WhisperX transcribe models preloaded!")
|
| 459 |
+
|
| 460 |
# -----------------------------------------------------------------------------
|
| 461 |
class WhisperTranscriber:
|
| 462 |
def __init__(self):
|
|
|
|
| 512 |
print(f"Performing alignment for language '{detected_language}'...")
|
| 513 |
align_info = _whipser_x_align_models[detected_language]
|
| 514 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
result = whisperx.align(
|
| 516 |
result["segments"],
|
| 517 |
align_info["model"],
|
|
|
|
| 1509 |
- Languages: Supports 100+ languages with auto-detection
|
| 1510 |
- Vocabulary: Add names and technical terms in the prompt for better accuracy
|
| 1511 |
""")
|
| 1512 |
+
|
| 1513 |
+
# Preload all whisper models on startup
|
| 1514 |
+
demo.load(fn=preload_all_whisper_models, inputs=None, outputs=None)
|
| 1515 |
|
| 1516 |
if __name__ == "__main__":
|
| 1517 |
demo.launch(debug=True)
|