# Tajweed-AI / recitation_engine / model_loader.py
# (Hugging Face Hub page artifact — author: hetchyy, commit: "Add i8n", cee53d8)
"""
Model loading and caching for ASR.
"""
import sys
from pathlib import Path
# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from config import MODEL_PATHS, INFERENCE_PRECISION
# Module-level cache, populated once by load_models() and reused on every
# subsequent call. "loaded" flips to True after the first load attempt;
# "errors" keeps the failure messages recorded during that attempt.
_model_cache = {
    "bundles": [],  # list of dicts: {"path": str, "processor": proc, "model": model}
    "loaded": False,
    "errors": []
}
def _get_target_device():
    """Pick the device newly loaded models should live on.

    On HF Spaces (ZeroGPU) this is always CPU: CUDA must not be
    initialized in the main process, only inside a @gpu_decorator
    function. Elsewhere it is cuda when available, cpu otherwise.
    """
    from config import IS_HF_SPACE
    if IS_HF_SPACE:
        # Defer GPU until inference
        return torch.device("cpu")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def _get_torch_dtype():
"""Get torch dtype based on INFERENCE_PRECISION config and device availability."""
if not torch.cuda.is_available():
return torch.float32 # CPU always uses fp32
precision = INFERENCE_PRECISION.lower()
if precision == "fp16":
return torch.float16
elif precision == "bf16":
# Check if bf16 is supported (Ampere+ GPUs)
if torch.cuda.is_bf16_supported():
return torch.bfloat16
else:
print("⚠️ bf16 not supported on this GPU, falling back to fp16")
return torch.float16
else:
return torch.float32 # Default to fp32
def _ensure_device(model):
"""
Move a model to the currently available device if needed.
"""
if model is None:
return None
target_device = _get_target_device()
current_device = next(model.parameters()).device
if current_device != target_device:
model = model.to(target_device)
print(f"Moved ASR model to {target_device}")
return model
def move_models_to_gpu():
    """Shift every cached ASR model onto the GPU, converting to the configured dtype.

    Call this inside @gpu_decorator functions on HF Spaces: under ZeroGPU
    the models are loaded on CPU at startup (CUDA init is forbidden in the
    main process) and hop to the GPU only once a lease is active. Silently
    does nothing when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return
    gpu = torch.device("cuda")
    dtype = _get_torch_dtype()
    for entry in _model_cache["bundles"]:
        cached = entry.get("model")
        if cached is None:
            continue
        if next(cached.parameters()).device.type != "cuda":
            entry["model"] = cached.to(gpu, dtype=dtype)
            print(f"Moved ASR model to {gpu}")
def load_model():
    """
    Load the ASR model with caching.

    Returns:
        Tuple of (processor, model, error_message). error_message is None
        when no load failure occurred; if all models failed, processor and
        model are None and error_message joins the failure reasons.
    """
    models, errors = load_models()
    # Defensive: tolerate None placeholders in `errors` (older versions of
    # load_models() appended None for successful loads, which would make
    # "; ".join() raise and could mask a later real failure behind a None
    # in slot 0). Keep only actual message strings.
    real_errors = [e for e in errors if e]
    if not models:
        return None, None, "; ".join(real_errors) if real_errors else "No models loaded"
    # Return first model for backward compatibility; surface the first real
    # error (if any) so secondary-model failures are not silently dropped.
    primary = models[0]
    return primary["processor"], primary["model"], real_errors[0] if real_errors else None
def get_model():
    """
    Get the cached model or load it if not loaded.

    Returns:
        Tuple of (processor, model, error_message)
    """
    # load_model() already caches via load_models(); this is a plain alias.
    processor, model, error = load_model()
    return processor, model, error
def load_models():
    """
    Load all ASR models defined in MODEL_PATHS.

    Returns:
        Tuple (bundles, errors) where bundles is a list of dicts with keys
        path, processor, model. errors is a list of error strings — empty
        when every model loaded successfully.
    """
    # Return cached models if already loaded (don't move them here —
    # use move_models_to_gpu() for that).
    if _model_cache["loaded"]:
        return _model_cache["bundles"], _model_cache["errors"]
    bundles = []
    errors = []
    dtype = _get_torch_dtype()
    # Human-readable dtype tag for the log line below.
    dtype_name = {torch.float32: "fp32", torch.float16: "fp16", torch.bfloat16: "bf16"}.get(dtype, str(dtype))
    for idx, path in enumerate(MODEL_PATHS):
        try:
            print(f"Loading ASR model {idx+1} from {path}...")
            processor = Wav2Vec2Processor.from_pretrained(str(path), token=True)
            model = Wav2Vec2ForCTC.from_pretrained(str(path), token=True, torch_dtype=dtype)
            model = _ensure_device(model)
            model.eval()
            bundles.append({"path": path, "processor": processor, "model": model})
            # FIX: no longer append a None placeholder to `errors` on success —
            # the documented contract (and load_model's join) expects strings only.
            device = next(model.parameters()).device
            print(f"✓ Model {idx+1} loaded on {device} ({dtype_name})")
        except Exception as e:
            # Keep going: one bad checkpoint should not block the others.
            error_msg = f"Failed to load model {idx+1} ({path}): {str(e)}"
            errors.append(error_msg)
            print(f"✗ {error_msg}")
    _model_cache["bundles"] = bundles
    _model_cache["errors"] = errors
    _model_cache["loaded"] = True
    return bundles, errors
def get_models():
    """
    Get all loaded models (or load them if needed).

    Returns:
        Tuple (bundles, errors)
    """
    # Thin alias: caching happens inside load_models() itself.
    bundles, errors = load_models()
    return bundles, errors