Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass, field | |
| from typing import Callable, Dict, Any, Optional | |
@dataclass
class RegisteredModel:
    """Metadata plus a lazy, cached loader for a single model.

    Attributes:
        id: Identifier of the model; the registry uses it as the dict key.
        display_name: Human-readable label (used for UI dropdowns).
        loader: Zero-argument factory that builds the model instance.
    """

    id: str
    display_name: str
    loader: Callable[[], Any]
    # Cached result of ``loader()``; not a constructor argument, hidden from repr.
    _instance: Optional[Any] = field(default=None, init=False, repr=False)
    # Records that ``loader`` has already run, so a loader that legitimately
    # returns None is still called only once.
    _loaded: bool = field(default=False, init=False, repr=False)

    def get(self) -> Any:
        """Return the model, calling ``loader`` on first use and caching the result.

        If ``loader`` raises, nothing is cached and the next call retries.
        NOTE(review): there is no locking here, so two concurrent first calls
        may both run ``loader``.
        """
        if not self._loaded:
            self._instance = self.loader()
            self._loaded = True
        return self._instance
def _build_registry(device: str = "cpu") -> Dict[str, RegisteredModel]:
    """Register every available model and index the entries by id.

    Building the registry is cheap: each loader below defers its
    ``src.inference`` import and model construction until that entry's
    ``get()`` is first called.

    Args:
        device: Forwarded to the ResNet-based model constructors; the
            raw-pixel LR/SVM models are constructed without it.

    Returns:
        A dict mapping ``model_id`` -> ``RegisteredModel``.

    Raises:
        ValueError: If two entries are registered under the same id.
    """
    # All four models are constructed with the same labels file.
    labels_path = "configs/labels.json"

    # -------- LR on raw pixels --------
    def make_lr_raw():
        from src.inference.lr_model import LRModel

        # NOTE(review): LRModel takes ``model_path`` while the other models
        # take ``ckpt_path``; kept as-is to match its constructor.
        return LRModel(
            model_path="checkpoints/lr_model.joblib",
            labels_path=labels_path,
        )

    # -------- SVM on raw pixels --------
    def make_svm_raw():
        from src.inference.svm_model import SVMModel

        return SVMModel(
            ckpt_path="checkpoints/svm_model.joblib",
            labels_path=labels_path,
        )

    # -------- ResNet (PT) + LR head --------
    def make_resnet_pt_lr():
        from src.inference.resnet_pt_lr_model import ResNetPTLRModel

        return ResNetPTLRModel(
            ckpt_path="checkpoints/resnet_pt_lr_head.joblib",
            labels_path=labels_path,
            device=device,
        )

    # -------- ResNet (PT) + SVM head --------
    def make_resnet_pt_svm():
        from src.inference.resnet_pt_svm_model import ResNetPTSVMModel

        return ResNetPTSVMModel(
            ckpt_path="checkpoints/resnet_pt_svm_head.joblib",
            labels_path=labels_path,
            device=device,
        )

    entries = [
        RegisteredModel(
            id="lr_raw",
            display_name="LR (raw 64×64 grayscale)",
            loader=make_lr_raw,
        ),
        RegisteredModel(
            id="svm_raw",
            display_name="SVM (raw 64×64 grayscale)",
            loader=make_svm_raw,
        ),
        RegisteredModel(
            id="resnet_pt_lr",
            display_name="ResNet (pretrained) + LR",
            loader=make_resnet_pt_lr,
        ),
        RegisteredModel(
            id="resnet_pt_svm",
            display_name="ResNet (pretrained) + SVM",
            loader=make_resnet_pt_svm,
        ),
    ]

    # Key each entry by its own ``id`` so the dict key and the entry can never
    # disagree, and refuse to silently overwrite an entry on a duplicate id.
    registry: Dict[str, RegisteredModel] = {}
    for entry in entries:
        if entry.id in registry:
            raise ValueError(f"Duplicate model id: {entry.id!r}")
        registry[entry.id] = entry
    return registry
# Constructed a single time at import; entries defer model loading until first get().
_REGISTRY: Dict[str, RegisteredModel] = _build_registry("cpu")
def get_registry() -> Dict[str, RegisteredModel]:
    """Expose the module-wide registry mapping model id to its entry.

    The shared dict object itself is returned (not a copy).
    """
    registry = _REGISTRY
    return registry
def get_model(model_id: str) -> Any:
    """Look up ``model_id`` and return its lazily loaded instance.

    Raises:
        KeyError: If ``model_id`` is not registered.
    """
    try:
        entry = _REGISTRY[model_id]
    except KeyError:
        # Re-raise with a descriptive message, hiding the bare lookup error.
        raise KeyError(f"Unknown model_id: {model_id}") from None
    return entry.get()
def get_models() -> Dict[str, Any]:
    """Force every registered model to load and return them keyed by id.

    Optional helper: calling it triggers each entry's loader up front
    instead of waiting for first use.
    """
    loaded: Dict[str, Any] = {}
    for model_id, entry in _REGISTRY.items():
        loaded[model_id] = entry.get()
    return loaded
def get_model_display_names() -> Dict[str, str]:
    """Map each model id to its human-readable label (for UI dropdowns).

    Only reads metadata; no model is loaded.
    """
    labels: Dict[str, str] = {}
    for model_id, entry in _REGISTRY.items():
        labels[model_id] = entry.display_name
    return labels