# Source: ground-zero/src/engine/whisper_base.py (2.55 kB)
# Commit 76db545 by jefffffff9: "Initial commit: Sahel-Agri Voice AI"
"""
Loads the Whisper backbone model and processor once.
All other modules receive references to this shared instance.
"""
from __future__ import annotations
import logging
from pathlib import Path
import torch
import yaml
from transformers import WhisperForConditionalGeneration, WhisperProcessor
logger = logging.getLogger(__name__)
class WhisperBackbone:
    """Owner of the shared Whisper model and processor.

    Construction only reads the model identifier from a YAML config file;
    the weights themselves are loaded by :meth:`load`. Nothing in this class
    enforces a single instance -- callers are expected to create one object
    and pass it around.
    """

    def __init__(self, config_path: str = "configs/base_config.yaml") -> None:
        """Read the model identifier from the YAML file at *config_path*.

        Args:
            config_path: Path to a YAML file containing a ``model.id`` entry.

        Raises:
            OSError: If the config file cannot be opened.
            KeyError: If the ``model`` or ``id`` key is missing.
        """
        # Explicit encoding so the config parses the same way regardless of
        # the platform's locale default.
        with open(Path(config_path), encoding="utf-8") as f:
            cfg = yaml.safe_load(f)
        self._model_id: str = cfg["model"]["id"]
        self._model: WhisperForConditionalGeneration | None = None
        self._processor: WhisperProcessor | None = None
        self._device: str = "cpu"

    @staticmethod
    def _resolve_device(requested: str, cuda_available: bool) -> str:
        """Return the device string to use for a *requested* device.

        ``"cuda"`` and indexed forms such as ``"cuda:1"`` are returned
        unchanged when *cuda_available* is true. Every other request, and any
        CUDA request when CUDA is unavailable, resolves to ``"cpu"``.
        """
        wants_cuda = requested == "cuda" or requested.startswith("cuda:")
        return requested if wants_cuda and cuda_available else "cpu"

    def load(self, device: str = "cuda", hf_token: str | None = None) -> None:
        """Load the processor and model into memory. Call once at startup.

        Args:
            device: Requested device. ``"cuda"`` or an indexed form such as
                ``"cuda:1"`` is honoured when CUDA is available; otherwise
                the model is placed on ``"cpu"``.
            hf_token: Forwarded as ``token`` to both ``from_pretrained``
                calls.
        """
        resolved = self._resolve_device(device, torch.cuda.is_available())
        if resolved != device:
            # Make the fallback visible instead of silently running on CPU.
            logger.warning(
                "Requested device %r is unavailable or unsupported; using %s",
                device,
                resolved,
            )
        logger.info("Loading %s on %s", self._model_id, resolved)

        processor = WhisperProcessor.from_pretrained(
            self._model_id,
            token=hf_token,
        )
        # Half precision on CUDA devices, full precision on CPU.
        dtype = torch.float16 if resolved.startswith("cuda") else torch.float32
        model = WhisperForConditionalGeneration.from_pretrained(
            self._model_id,
            torch_dtype=dtype,
            token=hf_token,
        ).to(resolved)
        model.eval()

        # Publish state only after both loads succeeded, so a failure above
        # never leaves the instance with a processor but no model.
        self._device = resolved
        self._processor = processor
        self._model = model
        logger.info("Model loaded successfully (dtype=%s, device=%s)", dtype, self._device)

    @property
    def model(self) -> WhisperForConditionalGeneration:
        """The loaded model.

        Raises:
            RuntimeError: If :meth:`load` has not been called (or the
                backbone has since been freed).
        """
        if self._model is None:
            raise RuntimeError("Call WhisperBackbone.load() before accessing the model.")
        return self._model

    @property
    def processor(self) -> WhisperProcessor:
        """The loaded processor.

        Raises:
            RuntimeError: If :meth:`load` has not been called (or the
                backbone has since been freed).
        """
        if self._processor is None:
            raise RuntimeError("Call WhisperBackbone.load() before accessing the processor.")
        return self._processor

    @property
    def device(self) -> str:
        """Device string chosen by the last :meth:`load`; ``"cpu"`` before that."""
        return self._device

    @property
    def model_id(self) -> str:
        """Model identifier read from the config file."""
        return self._model_id

    def free(self) -> None:
        """Drop the model and processor and release cached CUDA memory.

        Only this object's references are dropped; memory is reclaimed once
        no other holder still references the model.
        """
        # Rebinding to None releases the references; a separate `del` first
        # is unnecessary.
        self._model = None
        self._processor = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("Backbone freed from memory.")