Shashwat98's picture
Update src/registry.py
627d37a verified
from dataclasses import dataclass, field
from typing import Callable, Dict, Any, Optional
@dataclass
class RegisteredModel:
    """
    Metadata plus a lazy, memoizing loader for a single model.

    Attributes:
        id: Registry key for this model.
        display_name: Human-readable label for the model.
        loader: Zero-argument factory that builds the model instance.
    """

    id: str
    display_name: str
    loader: Callable[[], Any]
    # Cached result of ``loader()``; populated by the first successful get().
    _instance: Optional[Any] = field(default=None, init=False, repr=False)
    # Records that loader() has completed, independently of _instance, so a
    # loader that legitimately returns None is still only called once.
    _loaded: bool = field(default=False, init=False, repr=False, compare=False)

    def get(self) -> Any:
        """
        Return the model instance, building it via ``loader`` on first use.

        Subsequent calls return the cached object without calling ``loader``
        again, even if that object is ``None``. If ``loader`` raises, nothing
        is cached and the exception propagates; the next call will retry.
        """
        if not self._loaded:
            self._instance = self.loader()
            # Only mark as loaded after loader() returned without raising.
            self._loaded = True
        return self._instance
def _build_registry(device: str = "cpu") -> Dict[str, RegisteredModel]:
    """
    Assemble every known model into a lookup table.

    Each entry wraps a factory in a RegisteredModel; the factory performs its
    (potentially heavy) import and construction only when it is invoked.

    Args:
        device: Passed through to the constructors of the ResNet-based models.
            The raw-pixel LR and SVM models do not receive it.

    Returns:
        Dict mapping model id -> RegisteredModel, in registration order.
    """
    labels = "configs/labels.json"

    # Each factory imports its model class locally so that the import cost is
    # only paid when that particular model is first requested.
    def make_lr_raw():
        from src.inference.lr_model import LRModel

        return LRModel(
            model_path="checkpoints/lr_model.joblib",
            labels_path=labels,
        )

    def make_svm_raw():
        from src.inference.svm_model import SVMModel

        return SVMModel(
            ckpt_path="checkpoints/svm_model.joblib",
            labels_path=labels,
        )

    def make_resnet_pt_lr():
        from src.inference.resnet_pt_lr_model import ResNetPTLRModel

        return ResNetPTLRModel(
            ckpt_path="checkpoints/resnet_pt_lr_head.joblib",
            labels_path=labels,
            device=device,
        )

    def make_resnet_pt_svm():
        from src.inference.resnet_pt_svm_model import ResNetPTSVMModel

        return ResNetPTSVMModel(
            ckpt_path="checkpoints/resnet_pt_svm_head.joblib",
            labels_path=labels,
            device=device,
        )

    # (id, display name, factory) -- order here is the registry's order.
    specs = [
        ("lr_raw", "LR (raw 64×64 grayscale)", make_lr_raw),
        ("svm_raw", "SVM (raw 64×64 grayscale)", make_svm_raw),
        ("resnet_pt_lr", "ResNet (pretrained) + LR", make_resnet_pt_lr),
        ("resnet_pt_svm", "ResNet (pretrained) + SVM", make_resnet_pt_svm),
    ]
    return {
        model_id: RegisteredModel(id=model_id, display_name=label, loader=factory)
        for model_id, label, factory in specs
    }
# Built exactly once, at import time. This is cheap: no model object is
# constructed until RegisteredModel.get() is first called on its entry.
_REGISTRY: Dict[str, RegisteredModel] = _build_registry("cpu")
def get_registry() -> Dict[str, RegisteredModel]:
    """
    Expose the module-wide model registry.

    Returns:
        The shared id -> RegisteredModel dict itself (not a copy); changes a
        caller makes to it are visible to every other user of this module.
    """
    registry = _REGISTRY
    return registry
def get_model(model_id: str) -> Any:
    """
    Look up one model by id and return its (lazily built) instance.

    Args:
        model_id: A key of the registry.

    Raises:
        KeyError: If ``model_id`` is not registered.
    """
    # Registry values are always RegisteredModel objects, so None from .get()
    # can only mean the id is absent.
    entry = _REGISTRY.get(model_id)
    if entry is None:
        raise KeyError(f"Unknown model_id: {model_id}")
    return entry.get()
def get_models() -> Dict[str, Any]:
    """
    Force-load every registered model and return them keyed by id.

    Models are loaded one at a time in registration order; any loader error
    propagates immediately, leaving later models unloaded.
    """
    loaded: Dict[str, Any] = {}
    for model_id, entry in _REGISTRY.items():
        loaded[model_id] = entry.get()
    return loaded
def get_model_display_names() -> Dict[str, str]:
    """
    Map each model id to its human-readable label.

    Does not load any model; only registry metadata is read.
    """
    return {model_id: _REGISTRY[model_id].display_name for model_id in _REGISTRY}