from dataclasses import dataclass, field
from typing import Callable, Dict, Any, Optional


@dataclass
class RegisteredModel:
    """Metadata plus a lazy, cached loader for a single model.

    Attributes:
        id: Registry key for this model.
        display_name: Human-readable label (e.g. for UI dropdowns).
        loader: Zero-argument factory that builds the model instance.
    """

    id: str
    display_name: str
    loader: Callable[[], Any]
    # Cached result of ``loader()``; stays ``None`` until the first ``get()``.
    _instance: Optional[Any] = field(default=None, init=False, repr=False)
    # Whether ``loader`` has already run. Kept separate from ``_instance`` so a
    # loader that legitimately returns ``None`` is still only invoked once.
    _loaded: bool = field(default=False, init=False, repr=False, compare=False)

    def get(self) -> Any:
        """Return the model instance, calling ``loader`` only on the first call.

        If ``loader`` raises, nothing is cached and the next call retries.

        NOTE: there is no locking here; concurrent first calls may each
        invoke ``loader``.
        """
        if not self._loaded:
            self._instance = self.loader()
            self._loaded = True
        return self._instance


def _build_registry(device: str = "cpu") -> Dict[str, RegisteredModel]:
    """Central place to register all models.

    Each model's module is imported inside its factory function, so nothing is
    imported or constructed until that model's ``RegisteredModel.get()`` is
    first called.

    Args:
        device: Forwarded to the ResNet-based models only; the raw-pixel
            LR/SVM models are constructed without a device argument.

    Returns:
        Dict mapping model_id -> RegisteredModel, in registration order.
    """
    # Shared by every model below.
    labels_path = "configs/labels.json"

    # -------- LR on raw pixels --------
    def make_lr_raw():
        from src.inference.lr_model import LRModel

        return LRModel(
            model_path="checkpoints/lr_model.joblib",
            labels_path=labels_path,
        )

    # -------- SVM on raw pixels --------
    def make_svm_raw():
        from src.inference.svm_model import SVMModel

        return SVMModel(
            ckpt_path="checkpoints/svm_model.joblib",
            labels_path=labels_path,
        )

    # -------- ResNet (PT) + LR head --------
    def make_resnet_pt_lr():
        from src.inference.resnet_pt_lr_model import ResNetPTLRModel

        return ResNetPTLRModel(
            ckpt_path="checkpoints/resnet_pt_lr_head.joblib",
            labels_path=labels_path,
            device=device,
        )

    # -------- ResNet (PT) + SVM head --------
    def make_resnet_pt_svm():
        from src.inference.resnet_pt_svm_model import ResNetPTSVMModel

        return ResNetPTSVMModel(
            ckpt_path="checkpoints/resnet_pt_svm_head.joblib",
            labels_path=labels_path,
            device=device,
        )

    entries = (
        RegisteredModel(
            id="lr_raw",
            display_name="LR (raw 64×64 grayscale)",
            loader=make_lr_raw,
        ),
        RegisteredModel(
            id="svm_raw",
            display_name="SVM (raw 64×64 grayscale)",
            loader=make_svm_raw,
        ),
        RegisteredModel(
            id="resnet_pt_lr",
            display_name="ResNet (pretrained) + LR",
            loader=make_resnet_pt_lr,
        ),
        RegisteredModel(
            id="resnet_pt_svm",
            display_name="ResNet (pretrained) + SVM",
            loader=make_resnet_pt_svm,
        ),
    )
    # Key each entry by its own ``id`` so the dict key and ``entry.id`` can
    # never disagree.
    return {entry.id: entry for entry in entries}
# Build registry once (models load lazily)
_REGISTRY: Dict[str, RegisteredModel] = _build_registry(device="cpu")


def get_registry() -> Dict[str, RegisteredModel]:
    """Return the shared registry mapping model id -> RegisteredModel.

    The module's internal dict is returned directly, not a copy.
    """
    return _REGISTRY


def get_model(model_id: str) -> Any:
    """Return the (lazily loaded) model instance registered under ``model_id``.

    Raises:
        KeyError: If ``model_id`` is not in the registry.
    """
    if model_id in _REGISTRY:
        return _REGISTRY[model_id].get()
    raise KeyError(f"Unknown model_id: {model_id}")


def get_models() -> Dict[str, Any]:
    """Load every registered model now and return them keyed by id."""
    loaded: Dict[str, Any] = {}
    for key, item in _REGISTRY.items():
        loaded[key] = item.get()
    return loaded


def get_model_display_names() -> Dict[str, str]:
    """Return a model id -> display name mapping (e.g. for UI dropdowns)."""
    return {key: item.display_name for key, item in _REGISTRY.items()}