tidal-api / backend /inference /model_loader.py
the-harsh-vardhan's picture
Deploy vR.P.30.1 with forensic controls
ffb6b18 verified
"""Singleton model loader with checkpoint validation for vR.P.30.1."""
from __future__ import annotations
import hashlib
import json
import logging
import os
import threading
from pathlib import Path
import torch
from .model_architecture import build_vrp301_model
# Logger named after this module so records can be filtered per component.
logger = logging.getLogger(__name__)

# Checkpoint directory used when the MODEL_DIR environment variable is unset.
DEFAULT_MODEL_DIR = Path("models", "inference", "vR.P.30.1")

# Runtime configuration. Each value is read from the environment exactly once,
# at import time; changing the environment afterwards has no effect here.
MODEL_DIR = os.environ.get("MODEL_DIR", str(DEFAULT_MODEL_DIR))
MODEL_FILENAME = os.environ.get("MODEL_FILENAME", "best_model.pt")
DEVICE = os.environ.get("DEVICE", "auto")  # "auto" is resolved in _resolve_device()

# Name of the optional JSON manifest looked up inside the model directory.
MANIFEST_FILENAME = "manifest.json"
def _resolve_device():
    """Turn the module-level ``DEVICE`` setting into a ``torch.device``.

    Any value other than ``"auto"`` is passed to ``torch.device`` unchanged.
    ``"auto"`` selects ``"cuda"`` when ``torch.cuda.is_available()`` is true
    and ``"cpu"`` otherwise.
    """
    if DEVICE != "auto":
        return torch.device(DEVICE)
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)
def _compute_sha256(filepath):
    """Return the hexadecimal SHA-256 digest of the file at *filepath*.

    The file is opened in binary mode and consumed in 8192-byte reads so the
    whole checkpoint never has to be held in memory at once.
    """
    digest = hashlib.sha256()
    with open(filepath, "rb") as stream:
        while True:
            piece = stream.read(8192)
            if not piece:  # b"" signals end of file
                break
            digest.update(piece)
    return digest.hexdigest()
def _load_manifest(model_dir: Path) -> dict:
    """Read the JSON manifest stored in *model_dir*, if there is one.

    Returns the parsed content of ``model_dir / MANIFEST_FILENAME`` (decoded
    as UTF-8), or an empty dict when that path does not exist. Malformed JSON
    propagates as the exception raised by ``json.loads``.
    """
    manifest_file = model_dir / MANIFEST_FILENAME
    if manifest_file.exists():
        raw_text = manifest_file.read_text(encoding="utf-8")
        return json.loads(raw_text)
    return {}
class ModelLoader:
    """Lazily loads and holds the inference model, its device and its metadata.

    Use :meth:`get_instance` to obtain the shared process-wide instance. The
    model is loaded on the first access to :attr:`model` or by calling
    :meth:`load` explicitly.

    NOTE(review): only creation of the shared instance is guarded by a lock.
    ``load()`` itself is not serialised, so two threads hitting ``model``
    before the first load finishes may both load the checkpoint.
    """

    _instance = None
    _lock = threading.Lock()

    def __init__(self):
        self._model = None
        self._device = None
        self._checkpoint_hash = None
        self._manifest = {}
        self._loaded = False

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it on first use."""
        # Check, take the lock, then check again so that concurrent first
        # callers end up with the same object.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    @property
    def is_loaded(self) -> bool:
        """Whether a checkpoint has been loaded successfully."""
        return self._loaded

    @property
    def model(self):
        """The loaded model; triggers :meth:`load` when nothing is loaded yet."""
        if not self._loaded:
            self.load()
        return self._model

    @property
    def device(self):
        """The ``torch.device`` in use, resolved on first access if needed."""
        if self._device is None:
            self._device = _resolve_device()
        return self._device

    @property
    def checkpoint_hash(self):
        """SHA-256 hex digest of the last loaded checkpoint file, or ``None``."""
        return self._checkpoint_hash

    @property
    def manifest(self) -> dict:
        """Manifest read alongside the last loaded checkpoint (``{}`` if none)."""
        return self._manifest

    @property
    def model_dir(self) -> Path:
        """Directory the checkpoint is read from (``MODEL_DIR`` as a ``Path``)."""
        return Path(MODEL_DIR)

    @property
    def model_filename(self) -> str:
        """File name of the checkpoint inside :attr:`model_dir`."""
        return MODEL_FILENAME

    def load(self):
        """Load the checkpoint from ``model_dir / model_filename``.

        Everything is prepared in local variables and only published to the
        instance once every step has succeeded. If any step raises, the
        previously held model, hash, manifest and loaded flag are untouched,
        so a failed reload never exposes a model without checkpoint weights.

        Raises:
            FileNotFoundError: if the checkpoint file does not exist.
            Exception: whatever ``torch.load`` or ``load_state_dict`` raise
                for an unreadable or mismatching checkpoint.
        """
        # Go through the property so load() and model_filename always agree.
        path = self.model_dir / self.model_filename
        if not path.exists():
            raise FileNotFoundError(f"Model not found at {path}")
        logger.info("Loading model from %s", path)

        manifest = _load_manifest(self.model_dir)
        # NOTE(review): the digest is recorded but not compared against the
        # manifest here - confirm whether verification happens elsewhere.
        checkpoint_hash = _compute_sha256(path)
        device = _resolve_device()
        model = build_vrp301_model()

        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only point this at checkpoints from
        # a trusted source; switch to weights_only=True if the checkpoint
        # contains nothing but tensors and plain containers.
        ckpt = torch.load(path, map_location=device, weights_only=False)

        # Accept either a wrapper dict ("model_state_dict" / "state_dict") or
        # a bare state dict.
        if isinstance(ckpt, dict):
            state_dict = ckpt.get("model_state_dict", ckpt.get("state_dict", ckpt))
        else:
            state_dict = ckpt
        model.load_state_dict(state_dict, strict=True)
        model.to(device).eval()

        # Publish the new state only after every step above succeeded.
        self._manifest = manifest
        self._checkpoint_hash = checkpoint_hash
        self._device = device
        self._model = model
        self._loaded = True
        logger.info("Model loaded on %s (hash: %s)", device, checkpoint_hash[:16])

    def unload(self) -> None:
        """Drop the model and manifest and mark the loader as not loaded.

        Safe to call repeatedly, including when nothing is loaded. The recorded
        checkpoint hash and the resolved device are left as they are. When CUDA
        is available the CUDA cache is emptied afterwards.
        """
        # Reset unconditionally; relying on the truthiness of a module object
        # is fragile because container modules define __len__.
        self._model = None
        self._manifest = {}
        self._loaded = False
        if torch.cuda.is_available():
            torch.cuda.empty_cache()