Spaces: Running on Zero
| """ | |
| Model loading and caching for ASR. | |
| """ | |
| import sys | |
| from pathlib import Path | |
| # Add parent directory to path for imports | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| import torch | |
| from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor | |
| from config import MODEL_PATHS, INFERENCE_PRECISION | |
# Module-level cache so models are loaded at most once per process.
_model_cache = {
    "bundles": [],  # list of dicts: {"path": str, "processor": proc, "model": model}
    "loaded": False,  # set True after load_models() has run, even if some loads failed
    "errors": []  # aligned 1:1 with MODEL_PATHS: None on success, error string on failure
}
def _get_target_device():
    """Pick the device ASR models should be placed on.

    On HF Spaces with ZeroGPU, CUDA must not be initialized in the main
    process, so CPU is always reported there; GPU placement is deferred
    until code running inside a @gpu_decorator function.
    """
    from config import IS_HF_SPACE

    if IS_HF_SPACE:
        # Defer GPU until inference
        return torch.device("cpu")
    use_cuda = torch.cuda.is_available()
    return torch.device("cuda") if use_cuda else torch.device("cpu")
| def _get_torch_dtype(): | |
| """Get torch dtype based on INFERENCE_PRECISION config and device availability.""" | |
| if not torch.cuda.is_available(): | |
| return torch.float32 # CPU always uses fp32 | |
| precision = INFERENCE_PRECISION.lower() | |
| if precision == "fp16": | |
| return torch.float16 | |
| elif precision == "bf16": | |
| # Check if bf16 is supported (Ampere+ GPUs) | |
| if torch.cuda.is_bf16_supported(): | |
| return torch.bfloat16 | |
| else: | |
| print("⚠️ bf16 not supported on this GPU, falling back to fp16") | |
| return torch.float16 | |
| else: | |
| return torch.float32 # Default to fp32 | |
| def _ensure_device(model): | |
| """ | |
| Move a model to the currently available device if needed. | |
| """ | |
| if model is None: | |
| return None | |
| target_device = _get_target_device() | |
| current_device = next(model.parameters()).device | |
| if current_device != target_device: | |
| model = model.to(target_device) | |
| print(f"Moved ASR model to {target_device}") | |
| return model | |
def move_models_to_gpu():
    """Move cached ASR models to GPU.

    Call this inside @gpu_decorator functions on HF Spaces. On ZeroGPU,
    models are loaded on CPU at startup to avoid CUDA init in the main
    process; this function moves them to GPU once a GPU lease is active.

    No-op when CUDA is unavailable, a bundle has no model, or the model is
    already on a CUDA device.
    """
    if not torch.cuda.is_available():
        return
    device = torch.device("cuda")
    dtype = _get_torch_dtype()
    for bundle in _model_cache["bundles"]:
        model = bundle.get("model")
        if model is None:
            continue
        # next(..., None) guards parameterless modules: a bare next() would
        # raise StopIteration on an empty parameter iterator.
        param = next(model.parameters(), None)
        if param is None or param.device.type == "cuda":
            continue
        bundle["model"] = model.to(device, dtype=dtype)
        print(f"Moved ASR model to {device}")
def load_model():
    """Load the primary ASR model with caching.

    Returns:
        Tuple of (processor, model, error_message). error_message is None on
        success; when no model could be loaded, processor and model are None
        and error_message describes the failures.
    """
    models, errors = load_models()
    if not models:
        # load_models() pads errors with None for successful paths, so keep
        # only real messages before joining (defensive; when no model loaded
        # every entry should be a string).
        real_errors = [e for e in errors if e]
        return None, None, "; ".join(real_errors) if real_errors else "No models loaded"
    # At least one model loaded: report success with no error. (Previously
    # errors[0] — the slot for MODEL_PATHS[0] — was returned, which could
    # attach another path's failure message to a successfully loaded model.)
    primary = models[0]
    return primary["processor"], primary["model"], None
def get_model():
    """Return the cached primary ASR model, loading it on first use.

    Returns:
        Tuple of (processor, model, error_message).
    """
    # load_model() already handles caching via load_models(); just delegate.
    return load_model()
def load_models():
    """
    Load all ASR models defined in MODEL_PATHS, caching the result.

    Subsequent calls return the cached bundles without reloading and without
    moving models between devices (use move_models_to_gpu() for that).

    Returns:
        Tuple (bundles, errors). bundles is a list of dicts with keys
        path, processor, model — one per *successfully* loaded path.
        errors is aligned 1:1 with MODEL_PATHS: None for a path that loaded,
        an error string for a path that failed (so len(errors) may exceed
        len(bundles)).
    """
    # Return cached models if already loaded (don't move them - use move_models_to_gpu() for that)
    if _model_cache["loaded"]:
        return _model_cache["bundles"], _model_cache["errors"]
    bundles = []
    errors = []
    # Resolve dtype once for all models; dtype_name is only for logging.
    dtype = _get_torch_dtype()
    dtype_name = {torch.float32: "fp32", torch.float16: "fp16", torch.bfloat16: "bf16"}.get(dtype, str(dtype))
    for idx, path in enumerate(MODEL_PATHS):
        try:
            print(f"Loading ASR model {idx+1} from {path}...")
            # token=True uses the locally stored HF auth token (gated/private repos).
            processor = Wav2Vec2Processor.from_pretrained(str(path), token=True)
            model = Wav2Vec2ForCTC.from_pretrained(str(path), token=True, torch_dtype=dtype)
            model = _ensure_device(model)
            model.eval()
            bundles.append({"path": path, "processor": processor, "model": model})
            # None placeholder keeps errors index-aligned with MODEL_PATHS;
            # load_model() relies on this alignment.
            errors.append(None)
            device = next(model.parameters()).device
            print(f"✓ Model {idx+1} loaded on {device} ({dtype_name})")
        except Exception as e:
            # One bad path must not prevent the remaining models from loading.
            error_msg = f"Failed to load model {idx+1} ({path}): {str(e)}"
            errors.append(error_msg)
            print(f"✗ {error_msg}")
    # Mark loaded even on partial/total failure so we don't retry every call.
    _model_cache["bundles"] = bundles
    _model_cache["errors"] = errors
    _model_cache["loaded"] = True
    return bundles, errors
def get_models():
    """Return every ASR model bundle, loading them on first call.

    Returns:
        Tuple (bundles, errors) exactly as produced by load_models().
    """
    # load_models() caches internally; delegate directly.
    return load_models()