File size: 3,326 Bytes
627d37a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from dataclasses import dataclass, field
from typing import Callable, Dict, Any, Optional


@dataclass
class RegisteredModel:
    """Metadata plus a lazy, cached loader for a single model.

    Attributes:
        id: Identifier of the model within the registry.
        display_name: Human-readable label for the model.
        loader: Zero-argument callable that constructs the model instance.
    """
    id: str
    display_name: str
    loader: Callable[[], Any]
    _instance: Optional[Any] = field(default=None, init=False, repr=False)
    # Records whether `loader` has already completed successfully. A dedicated
    # flag is used rather than testing `_instance is None`, so a loader that
    # legitimately returns None is still only invoked once.
    _loaded: bool = field(default=False, init=False, repr=False, compare=False)

    def get(self) -> Any:
        """Return the model instance, calling ``loader`` only on first use.

        The result is cached after the first successful call. If ``loader``
        raises, nothing is cached and the exception propagates, so a later
        call will try again.
        """
        if not self._loaded:
            self._instance = self.loader()
            # Only mark as loaded after the loader returned without raising.
            self._loaded = True
        return self._instance


def _build_registry(
    device: str = "cpu",
    labels_path: str = "configs/labels.json",
    ckpt_dir: str = "checkpoints",
) -> Dict[str, RegisteredModel]:
    """
    Central place to register all models.

    Building the registry does not import or construct any model: each entry
    only stores a zero-argument loader, and the model-specific import happens
    inside that loader when it is first invoked.

    Args:
        device: Passed to the ResNet-based models; the raw-pixel LR/SVM
            models are constructed without a device argument.
        labels_path: Labels file passed to every model.
        ckpt_dir: Directory containing the ``*.joblib`` checkpoint files.

    Returns:
        A dict: model_id -> RegisteredModel, keyed by each entry's ``id``.

    Raises:
        ValueError: If two registered entries share the same id.
    """

    def _ckpt(filename: str) -> str:
        # Joined with "/" so the defaults reproduce the original relative
        # paths (e.g. "checkpoints/lr_model.joblib") exactly.
        return f"{ckpt_dir}/{filename}"

    # -------- LR on raw pixels --------
    def make_lr_raw():
        from src.inference.lr_model import LRModel
        # NOTE: LRModel's checkpoint keyword is `model_path`, unlike the
        # `ckpt_path` keyword used by the other model classes below.
        return LRModel(
            model_path=_ckpt("lr_model.joblib"),
            labels_path=labels_path,
        )

    # -------- SVM on raw pixels --------
    def make_svm_raw():
        from src.inference.svm_model import SVMModel
        return SVMModel(
            ckpt_path=_ckpt("svm_model.joblib"),
            labels_path=labels_path,
        )

    # -------- ResNet (PT) + LR head --------
    def make_resnet_pt_lr():
        from src.inference.resnet_pt_lr_model import ResNetPTLRModel
        return ResNetPTLRModel(
            ckpt_path=_ckpt("resnet_pt_lr_head.joblib"),
            labels_path=labels_path,
            device=device,
        )

    # -------- ResNet (PT) + SVM head --------
    def make_resnet_pt_svm():
        from src.inference.resnet_pt_svm_model import ResNetPTSVMModel
        return ResNetPTSVMModel(
            ckpt_path=_ckpt("resnet_pt_svm_head.joblib"),
            labels_path=labels_path,
            device=device,
        )

    entries = [
        RegisteredModel(
            id="lr_raw",
            display_name="LR (raw 64×64 grayscale)",
            loader=make_lr_raw,
        ),
        RegisteredModel(
            id="svm_raw",
            display_name="SVM (raw 64×64 grayscale)",
            loader=make_svm_raw,
        ),
        RegisteredModel(
            id="resnet_pt_lr",
            display_name="ResNet (pretrained) + LR",
            loader=make_resnet_pt_lr,
        ),
        RegisteredModel(
            id="resnet_pt_svm",
            display_name="ResNet (pretrained) + SVM",
            loader=make_resnet_pt_svm,
        ),
    ]

    # Key by each entry's own id so the dict key and `RegisteredModel.id`
    # can never drift apart.
    registry = {entry.id: entry for entry in entries}
    # A repeated id would otherwise silently overwrite an earlier entry.
    if len(registry) != len(entries):
        raise ValueError("Duplicate model id in registry")
    return registry


# Single module-level registry. Constructing it only wires up the loader
# callables; each model is imported and built on its first `.get()`.
_REGISTRY: Dict[str, RegisteredModel] = _build_registry("cpu")


def get_registry() -> Dict[str, RegisteredModel]:
    """Return the full registry (id -> RegisteredModel).

    A shallow copy is returned so callers cannot add, remove, or replace
    entries in the module-level registry. The RegisteredModel objects
    themselves, and any instances they have cached, are still shared.
    """
    return dict(_REGISTRY)


def get_model(model_id: str) -> Any:
    """Look up ``model_id`` in the registry and return its (lazily loaded) instance.

    Raises:
        KeyError: If ``model_id`` is not a registered id.
    """
    # Keep the try body to the lookup alone, so a KeyError raised while
    # loading the model is not misreported as an unknown id.
    try:
        entry = _REGISTRY[model_id]
    except KeyError:
        raise KeyError(f"Unknown model_id: {model_id}") from None
    return entry.get()


def get_models() -> Dict[str, Any]:
    """Load every registered model now and return them keyed by id (optional)."""
    loaded: Dict[str, Any] = {}
    for model_id in _REGISTRY:
        loaded[model_id] = _REGISTRY[model_id].get()
    return loaded


def get_model_display_names() -> Dict[str, str]:
    """Return id -> human-readable name, e.g. for populating UI dropdowns."""
    names: Dict[str, str] = {}
    for model_id, registered in _REGISTRY.items():
        names[model_id] = registered.display_name
    return names