"""
Model loading and caching for ASR.
"""
import sys
from pathlib import Path

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

from config import MODEL_PATHS, INFERENCE_PRECISION

# Module-level cache, populated once by load_models().
_model_cache = {
    # list of dicts: {"path": <entry of MODEL_PATHS>, "processor": proc, "model": model}
    "bundles": [],
    # set to True after the first load attempt, even if every model failed
    # (failures are cached and not retried)
    "loaded": False,
    # one entry per MODEL_PATHS entry: None on success, an error string on failure
    "errors": [],
}


def _get_target_device():
    """Return cuda when available, otherwise cpu.

    On HF Spaces with ZeroGPU, always returns CPU to defer CUDA init until
    inside a @gpu_decorator function.
    """
    # Imported lazily, as in the original code.
    from config import IS_HF_SPACE

    if IS_HF_SPACE:
        return torch.device("cpu")  # Defer GPU until inference
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _get_torch_dtype():
    """Get torch dtype based on INFERENCE_PRECISION config and device availability.

    Returns float32 whenever CUDA is unavailable. Otherwise INFERENCE_PRECISION
    (compared case-insensitively, surrounding whitespace ignored) maps to:
    "fp16" -> float16, "bf16" -> bfloat16 (float16 if the GPU lacks bf16
    support), anything else -> float32.

    Note: this depends on CUDA *availability*, not on the device a model is
    loaded onto, so on HF Spaces a half-precision model may sit on CPU until
    move_models_to_gpu() is called.
    """
    if not torch.cuda.is_available():
        return torch.float32  # CPU always uses fp32

    precision = str(INFERENCE_PRECISION).strip().lower()
    if precision == "fp16":
        return torch.float16
    if precision == "bf16":
        # Check if bf16 is supported (Ampere+ GPUs)
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16
        print("⚠️ bf16 not supported on this GPU, falling back to fp16")
        return torch.float16
    return torch.float32  # Default to fp32


def _model_device(model):
    """Return the device of the model's first parameter (or buffer), or None if it has neither."""
    tensor = next(model.parameters(), None)
    if tensor is None:
        tensor = next(model.buffers(), None)
    return None if tensor is None else tensor.device


def _is_on_device(model, device):
    """Return True if ``model``'s tensors already live on ``device``.

    An index-less device such as ``torch.device("cuda")`` matches any index
    of that type. Parameters always report a concrete index (``cuda:0``), and
    ``torch.device("cuda") != torch.device("cuda:0")``, so a plain equality
    check would never see a GPU-resident model as already placed.
    """
    current = _model_device(model)
    if current is None or current.type != device.type:
        return False
    return device.index is None or current.index == device.index


def _ensure_device(model):
    """
    Move a model to the target device (see _get_target_device) if needed.

    Returns None when ``model`` is None, otherwise the (possibly moved) model.
    """
    if model is None:
        return None
    target_device = _get_target_device()
    if not _is_on_device(model, target_device):
        model = model.to(target_device)
        print(f"Moved ASR model to {target_device}")
    return model


def move_models_to_gpu():
    """Move cached ASR models to GPU.

    Call this inside @gpu_decorator functions on HF Spaces. On ZeroGPU, models
    are loaded on CPU at startup to avoid CUDA init in the main process. This
    function moves them to GPU when a GPU lease is active.

    Does nothing when CUDA is unavailable. Models not yet on a CUDA device are
    moved and cast to the dtype from _get_torch_dtype(). Models already on a
    CUDA device are left untouched.
    """
    if not torch.cuda.is_available():
        return
    device = torch.device("cuda")
    dtype = _get_torch_dtype()
    for bundle in _model_cache["bundles"]:
        model = bundle.get("model")
        if model is not None and not _is_on_device(model, device):
            bundle["model"] = model.to(device, dtype=dtype)
            print(f"Moved ASR model to {device}")


def load_model():
    """
    Load the ASR model with caching.

    Returns:
        Tuple of (processor, model, error_message).
        When no model loaded, processor and model are None and error_message
        joins every load error (or is "No models loaded").
        Otherwise the first successfully loaded bundle is returned, and
        error_message is the entry for the *first configured path*: it is None
        if that path loaded, or that path's error string if a later path
        supplied the model.
    """
    models, errors = load_models()
    if not models:
        messages = [err for err in errors if err]
        return None, None, "; ".join(messages) if messages else "No models loaded"
    # Return first model for backward compatibility
    primary = models[0]
    return primary["processor"], primary["model"], errors[0] if errors else None


def get_model():
    """
    Get the cached model or load it if not loaded.

    Returns:
        Tuple of (processor, model, error_message) -- see load_model().
    """
    return load_model()


def load_models():
    """
    Load all ASR models defined in MODEL_PATHS.

    The result is cached in ``_model_cache`` after the first call, including
    failures: later calls return the same lists without retrying, and cached
    models are not moved (use move_models_to_gpu() for that).

    Returns:
        Tuple (bundles, errors). bundles is a list of dicts with keys path,
        processor and model, for the paths that loaded successfully. errors
        has exactly one entry per MODEL_PATHS entry: None on success, an
        error string on failure.
    """
    # Return cached models if already loaded (don't move them - use move_models_to_gpu() for that)
    if _model_cache["loaded"]:
        return _model_cache["bundles"], _model_cache["errors"]

    bundles = []
    errors = []
    dtype = _get_torch_dtype()
    dtype_name = {
        torch.float32: "fp32",
        torch.float16: "fp16",
        torch.bfloat16: "bf16",
    }.get(dtype, str(dtype))

    for idx, path in enumerate(MODEL_PATHS):
        try:
            print(f"Loading ASR model {idx+1} from {path}...")
            processor = Wav2Vec2Processor.from_pretrained(str(path), token=True)
            model = Wav2Vec2ForCTC.from_pretrained(str(path), token=True, torch_dtype=dtype)
            model = _ensure_device(model)
            model.eval()
        except Exception as e:
            # Best-effort: record the failure and keep loading the remaining paths.
            error_msg = f"Failed to load model {idx+1} ({path}): {str(e)}"
            errors.append(error_msg)
            print(f"✗ {error_msg}")
        else:
            # Success bookkeeping lives outside the try so that exactly one
            # errors entry is recorded per path.
            bundles.append({"path": path, "processor": processor, "model": model})
            errors.append(None)
            device = _model_device(model)
            print(f"✓ Model {idx+1} loaded on {device} ({dtype_name})")

    _model_cache["bundles"] = bundles
    _model_cache["errors"] = errors
    _model_cache["loaded"] = True
    return bundles, errors


def get_models():
    """
    Get all loaded models (or load them if needed).

    Returns:
        Tuple (bundles, errors) -- see load_models().
    """
    return load_models()