# src/segmenter/segmenter_model.py
"""Model lifecycle and device management for the VAD segmenter."""
import torch
from config import SEGMENTER_MODEL, DTYPE, IS_HF_SPACE, TORCH_COMPILE
from ..zero_gpu import ZERO_GPU_AVAILABLE, is_quota_exhausted, is_user_forced_cpu
# =============================================================================
# Model caches
# =============================================================================
_segmenter_cache = {"model": None, "processor": None, "loaded": False, "load_time": 0.0, "device": None}
_env_logged = False


def _log_env_once():
    """Log library and GPU versions once for debugging HF Space mismatches."""
    global _env_logged
    if _env_logged:
        return
    _env_logged = True
    try:
        import importlib.metadata as _im

        def _ver(pkg: str) -> str:
            try:
                return _im.version(pkg)
            except Exception:
                return "unknown"

        cudnn = torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else "none"
        print(f"[ENV] torch={torch.__version__} cuda={torch.version.cuda} cudnn={cudnn}")
        print(f"[ENV] transformers={_ver('transformers')} recitations_segmenter={_ver('recitations_segmenter')}")
        if torch.cuda.is_available():
            print(f"[ENV] GPU={torch.cuda.get_device_name(0)}")
    except Exception as e:
        print(f"[ENV] Failed to log env: {e}")

_TORCH_DTYPE = torch.float16 if DTYPE == "float16" else torch.float32


def _get_device_and_dtype():
    """Pick the device and dtype for the initial model load.

    On HF Spaces / ZeroGPU the model is loaded on CPU; ensure_models_on_gpu()
    moves it to CUDA later, inside a GPU-decorated function.
    """
    if IS_HF_SPACE or ZERO_GPU_AVAILABLE:
        return torch.device("cpu"), _TORCH_DTYPE
    if torch.cuda.is_available():
        return torch.device("cuda"), _TORCH_DTYPE
    return torch.device("cpu"), _TORCH_DTYPE


def ensure_models_on_gpu(asr_model_name=None):
    """
    Move models to GPU. Call this INSIDE a GPU-decorated function
    after the ZeroGPU lease is acquired.

    Args:
        asr_model_name: If provided, move only this ASR model to GPU.
            If None, skip ASR model movement (e.g. during a VAD-only lease).

    Skips if quota is exhausted or CUDA is unavailable.
    Idempotent: checks the current device before moving.

    Returns:
        float: Time in seconds spent moving models to GPU.
    """
    import time

    from ..alignment.phoneme_asr import move_phoneme_asr_to_gpu

    if is_user_forced_cpu() or is_quota_exhausted() or not torch.cuda.is_available():
        return 0.0

    device = torch.device("cuda")
    dtype = _TORCH_DTYPE
    move_start = time.time()

    # Move segmenter to GPU
    if _segmenter_cache["loaded"] and _segmenter_cache["model"] is not None:
        model = _segmenter_cache["model"]
        if next(model.parameters()).device.type != "cuda":
            print("[GPU] Moving segmenter to CUDA...")
            model.to(device, dtype=dtype)
            _segmenter_cache["model"] = model
            _segmenter_cache["device"] = "cuda"
            print("[GPU] Segmenter on CUDA")

    # Move phoneme ASR to GPU (only the requested model)
    if asr_model_name is not None:
        move_phoneme_asr_to_gpu(asr_model_name)

    return time.time() - move_start
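

# Usage sketch (illustrative only, not part of this module's API): models are loaded on
# CPU at startup, then moved to CUDA from inside a ZeroGPU-decorated function once the
# lease is active. `spaces.GPU` is the standard HF ZeroGPU decorator; `segment_with_gpu`
# and its body are hypothetical placeholders.
#
#     import spaces
#
#     model, processor, _ = load_segmenter()   # cached on CPU at startup
#
#     @spaces.GPU
#     def segment_with_gpu(waveform):
#         ensure_models_on_gpu()               # lease is active here; segmenter -> CUDA
#         ...                                  # run frame classification on `waveform`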


def ensure_models_on_cpu():
    """
    Move all models back to CPU. Called when the GPU lease fails or quota
    is exhausted, so that CPU fallback inference can proceed.

    Idempotent: checks the current device before moving.
    """
    from ..alignment.phoneme_asr import move_phoneme_asr_to_cpu

    device = torch.device("cpu")
    dtype = _TORCH_DTYPE

    # Move segmenter to CPU
    if _segmenter_cache["loaded"] and _segmenter_cache["model"] is not None:
        model = _segmenter_cache["model"]
        if next(model.parameters()).device.type != "cpu":
            print("[CPU] Moving segmenter to CPU...")
            model.to(device, dtype=dtype)
            _segmenter_cache["model"] = model
            _segmenter_cache["device"] = "cpu"
            print("[CPU] Segmenter on CPU")

    # Move phoneme ASR to CPU
    move_phoneme_asr_to_cpu()
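

# Fallback sketch (illustrative only): when the ZeroGPU lease cannot be acquired or the
# quota is exhausted, callers are expected to pull everything back to CPU before retrying
# inference there. `run_vad_on_cpu` is a hypothetical placeholder for the CPU path.
#
#     try:
#         result = segment_with_gpu(waveform)  # GPU-decorated path from the sketch above
#     except Exception:
#         ensure_models_on_cpu()               # segmenter + phoneme ASR back on CPU
#         result = run_vad_on_cpu(waveform)    # hypothetical CPU inference path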


def load_segmenter():
    """Load the VAD segmenter model (CPU on HF Spaces / ZeroGPU, CUDA only on a
    local machine with a GPU). Returns (model, processor, load_time).

    Models are loaded once and cached. Use ensure_models_on_gpu()
    inside GPU-decorated functions to move them to CUDA.
    """
    if _segmenter_cache["loaded"]:
        return _segmenter_cache["model"], _segmenter_cache["processor"], 0.0

    import time

    start_time = time.time()
    try:
        from transformers import AutoModelForAudioFrameClassification, AutoFeatureExtractor

        print(f"Loading segmenter: {SEGMENTER_MODEL}")
        device, dtype = _get_device_and_dtype()
        model = AutoModelForAudioFrameClassification.from_pretrained(SEGMENTER_MODEL)
        model.to(device, dtype=dtype)
        model.eval()
        if TORCH_COMPILE and not (IS_HF_SPACE or ZERO_GPU_AVAILABLE):
            model = torch.compile(model, mode="reduce-overhead")
        processor = AutoFeatureExtractor.from_pretrained(SEGMENTER_MODEL)

        load_time = time.time() - start_time
        _segmenter_cache["model"] = model
        _segmenter_cache["processor"] = processor
        _segmenter_cache["loaded"] = True
        _segmenter_cache["load_time"] = load_time
        _segmenter_cache["device"] = device.type
        print(f"Segmenter loaded on {device} in {load_time:.2f}s")
        return model, processor, load_time
    except Exception as e:
        print(f"Failed to load segmenter: {e}")
        return None, None, 0.0
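

# Local smoke test (illustrative sketch, not part of the pipeline): loads the cached
# segmenter and runs one frame-classification forward pass over a second of silence.
# Assumes a 16 kHz wav2vec2-style feature extractor (the processor's own `sampling_rate`
# is used when present) and runs on whatever device/dtype the config selected.
if __name__ == "__main__":
    model, processor, _ = load_segmenter()
    if model is not None:
        _log_env_once()
        sr = getattr(processor, "sampling_rate", 16000)
        waveform = torch.zeros(sr)  # one second of silence as a stand-in input
        inputs = processor(waveform.numpy(), sampling_rate=sr, return_tensors="pt")
        param = next(model.parameters())
        # Match the model's device/dtype; keep integer tensors (e.g. attention masks) as ints.
        inputs = {
            k: v.to(param.device, dtype=param.dtype) if v.is_floating_point() else v.to(param.device)
            for k, v in inputs.items()
        }
        with torch.inference_mode():
            logits = model(**inputs).logits  # (batch, frames, num_labels)
        print(f"Frame logits shape: {tuple(logits.shape)}")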