Spaces:
Sleeping
Sleeping
| """Singleton model loader with checkpoint validation for vR.P.30.1.""" | |
| from __future__ import annotations | |
| import hashlib | |
| import json | |
| import logging | |
| import os | |
| import threading | |
| from pathlib import Path | |
| import torch | |
| from .model_architecture import build_vrp301_model | |
logger = logging.getLogger(__name__)

# Checkpoint directory used when the MODEL_DIR environment variable is unset.
DEFAULT_MODEL_DIR = Path("models", "inference", "vR.P.30.1")

# Deployment settings, read from the environment once at import time.
# MODEL_DIR stays a plain string here; consumers wrap it in Path as needed.
MODEL_DIR = os.getenv("MODEL_DIR", str(DEFAULT_MODEL_DIR))
MODEL_FILENAME = os.getenv("MODEL_FILENAME", "best_model.pt")
# "auto" is resolved to a concrete device by _resolve_device().
DEVICE = os.getenv("DEVICE", "auto")

# Name of the optional JSON manifest looked up inside the model directory.
MANIFEST_FILENAME = "manifest.json"
def _resolve_device():
    """Return the ``torch.device`` selected by the ``DEVICE`` setting.

    Any value other than ``"auto"`` is handed to ``torch.device`` unchanged.
    ``"auto"`` selects CUDA when ``torch.cuda.is_available()`` reports it and
    CPU otherwise.
    """
    if DEVICE != "auto":
        return torch.device(DEVICE)
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
| def _compute_sha256(filepath): | |
| sha = hashlib.sha256() | |
| with open(filepath, "rb") as f: | |
| for chunk in iter(lambda: f.read(8192), b""): | |
| sha.update(chunk) | |
| return sha.hexdigest() | |
def _load_manifest(model_dir: Path) -> dict:
    """Return the parsed manifest stored in *model_dir*.

    The manifest is the UTF-8 JSON file named ``MANIFEST_FILENAME``. When that
    file does not exist an empty dict is returned instead of raising.
    """
    candidate = model_dir / MANIFEST_FILENAME
    if candidate.exists():
        raw_text = candidate.read_text(encoding="utf-8")
        return json.loads(raw_text)
    return {}
class ModelLoader:
    """Lazily loaded, process-wide holder for the vR.P.30.1 model.

    Obtain the shared instance with :meth:`get_instance`. The checkpoint is
    read from ``MODEL_DIR / MODEL_FILENAME`` on the first access to
    :attr:`model` or on an explicit :meth:`load` call.
    """

    _instance = None
    # Guards creation of the shared instance in get_instance().
    _lock = threading.Lock()

    def __init__(self):
        self._model = None
        self._device = None
        self._checkpoint_hash = None
        self._manifest = {}
        self._loaded = False

    @classmethod
    def get_instance(cls):
        """Return the shared loader, creating it on first use.

        The instance is checked once without the lock and once more while
        holding it, so concurrent first calls construct only one loader.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    @property
    def is_loaded(self) -> bool:
        """Whether a checkpoint has been loaded successfully."""
        return self._loaded

    @property
    def model(self):
        """The loaded model; triggers :meth:`load` when nothing is loaded yet."""
        if not self._loaded:
            self.load()
        return self._model

    @property
    def device(self) -> torch.device:
        """The device in use, resolved from the ``DEVICE`` setting on demand."""
        if self._device is None:
            self._device = _resolve_device()
        return self._device

    @property
    def checkpoint_hash(self) -> str | None:
        """SHA-256 hex digest of the loaded checkpoint, or ``None`` if unloaded."""
        return self._checkpoint_hash

    @property
    def manifest(self) -> dict:
        """Manifest read alongside the checkpoint (empty when absent/unloaded)."""
        return self._manifest

    @property
    def model_dir(self) -> Path:
        """Directory containing the checkpoint, as a ``Path``."""
        return Path(MODEL_DIR)

    @property
    def model_filename(self) -> str:
        """File name of the checkpoint inside :attr:`model_dir`."""
        return MODEL_FILENAME

    def load(self) -> None:
        """Build the model and load its weights from the checkpoint file.

        Loader state is only updated after every step has succeeded, so a
        failed load leaves the previous state untouched.

        Raises:
            FileNotFoundError: if the checkpoint file does not exist.
            RuntimeError: propagated from ``load_state_dict(strict=True)`` when
                the checkpoint keys do not match the model.
        """
        model_dir = self.model_dir
        path = model_dir / MODEL_FILENAME
        if not path.exists():
            raise FileNotFoundError(f"Model not found at {path}")

        logger.info("Loading model from %s", path)
        manifest = _load_manifest(model_dir)
        checkpoint_hash = _compute_sha256(path)
        device = _resolve_device()
        model = build_vrp301_model()

        # SECURITY: weights_only=False makes torch.load unpickle arbitrary
        # objects, which can execute code. Only point MODEL_DIR at trusted
        # checkpoints; switch to weights_only=True if the file holds tensors only.
        ckpt = torch.load(path, map_location=device, weights_only=False)

        # Accept both wrapped checkpoints and a bare state dict.
        if isinstance(ckpt, dict):
            sd = ckpt.get("model_state_dict", ckpt.get("state_dict", ckpt))
        else:
            sd = ckpt
        model.load_state_dict(sd, strict=True)
        model.to(device).eval()

        # Commit only now that nothing above has raised.
        self._manifest = manifest
        self._checkpoint_hash = checkpoint_hash
        self._device = device
        self._model = model
        self._loaded = True
        logger.info("Model loaded on %s (hash: %s)", device, checkpoint_hash[:16])

    def unload(self) -> None:
        """Drop the model and its load metadata, then release cached CUDA memory."""
        # Cleared unconditionally: a truthiness test on the model would skip
        # container modules that report a length of zero.
        self._model = None
        self._manifest = {}
        # The hash describes the loaded checkpoint, so it is reset with it.
        self._checkpoint_hash = None
        self._loaded = False
        if torch.cuda.is_available():
            torch.cuda.empty_cache()