Shashwat98 commited on
Commit
627d37a
·
verified ·
1 Parent(s): 52dd1ca

Update src/registry.py

Browse files
Files changed (1) hide show
  1. src/registry.py +107 -108
src/registry.py CHANGED
@@ -1,108 +1,107 @@
1
- # src/registry.py
2
-
3
- from dataclasses import dataclass, field
4
- from typing import Callable, Dict, Any, Optional
5
-
6
-
7
- @dataclass
8
- class RegisteredModel:
9
- """Metadata + lazy loader for a single model."""
10
- id: str
11
- display_name: str
12
- loader: Callable[[], Any]
13
- _instance: Optional[Any] = field(default=None, init=False, repr=False)
14
-
15
- def get(self) -> Any:
16
- """Instantiate on first call, then cache."""
17
- if self._instance is None:
18
- self._instance = self.loader()
19
- return self._instance
20
-
21
-
22
- def _build_registry(device: str = "cpu") -> Dict[str, RegisteredModel]:
23
- """
24
- Central place to register all models.
25
- Returns a dict: model_id -> RegisteredModel.
26
- """
27
-
28
- def make_lr_raw():
29
- from src.inference.lr_model import LRModel
30
- return LRModel(
31
- ckpt_path="checkpoints/lr_model.joblib",
32
- labels_path="configs/labels.json",
33
- device=device,
34
- )
35
-
36
- def make_svm_raw():
37
- from src.inference.svm_model import SVMModel
38
- return SVMModel(
39
- ckpt_path="checkpoints/svm_model.joblib",
40
- labels_path="configs/labels.json",
41
- device=device,
42
- )
43
-
44
- def make_resnet_pt_lr():
45
- from src.inference.resnet_pt_lr_model import ResNetPTLRModel
46
- return ResNetPTLRModel(
47
- ckpt_path="checkpoints/resnet_pt_lr_head.joblib",
48
- labels_path="configs/labels.json",
49
- device=device,
50
- )
51
-
52
- def make_resnet_pt_svm():
53
- from src.inference.resnet_pt_svm_model import ResNetPTSVMModel
54
- return ResNetPTSVMModel(
55
- ckpt_path="checkpoints/resnet_pt_svm_head.joblib",
56
- labels_path="configs/labels.json",
57
- device=device,
58
- )
59
-
60
- return {
61
- "lr_raw": RegisteredModel(
62
- id="lr_raw",
63
- display_name="LR (raw 64×64 grayscale)",
64
- loader=make_lr_raw,
65
- ),
66
- "svm_raw": RegisteredModel(
67
- id="svm_raw",
68
- display_name="SVM (raw 64×64 grayscale)",
69
- loader=make_svm_raw,
70
- ),
71
- "resnet_pt_lr": RegisteredModel(
72
- id="resnet_pt_lr",
73
- display_name="ResNet(PT) + LR",
74
- loader=make_resnet_pt_lr,
75
- ),
76
- "resnet_pt_svm": RegisteredModel(
77
- id="resnet_pt_svm",
78
- display_name="ResNet(PT) + SVM",
79
- loader=make_resnet_pt_svm,
80
- ),
81
- }
82
-
83
-
84
- # Build once at import; models themselves are loaded lazily.
85
- _REGISTRY: Dict[str, RegisteredModel] = _build_registry()
86
-
87
-
88
- def get_registry() -> Dict[str, RegisteredModel]:
89
- """Return the full registry (id -> RegisteredModel)."""
90
- return _REGISTRY
91
-
92
-
93
- def get_models() -> Dict[str, Any]:
94
- """
95
- Eagerly instantiate all models and return id -> model_instance.
96
- Useful for simple scripts or for initializing everything at UI startup.
97
- """
98
- return {mid: entry.get() for mid, entry in _REGISTRY.items()}
99
-
100
-
101
- def get_model(model_id: str) -> Any:
102
- """Get a single model instance by id (instantiates on first use)."""
103
- return _REGISTRY[model_id].get()
104
-
105
-
106
- def get_model_display_names() -> Dict[str, str]:
107
- """Return mapping id -> human-readable name (for dropdown choices)."""
108
- return {mid: entry.display_name for mid, entry in _REGISTRY.items()}
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Callable, Dict, Any, Optional
3
+
4
+
5
+ @dataclass
6
+ class RegisteredModel:
7
+ """Metadata + lazy loader for a single model."""
8
+ id: str
9
+ display_name: str
10
+ loader: Callable[[], Any]
11
+ _instance: Optional[Any] = field(default=None, init=False, repr=False)
12
+
13
+ def get(self) -> Any:
14
+ """Instantiate on first call, then cache."""
15
+ if self._instance is None:
16
+ self._instance = self.loader()
17
+ return self._instance
18
+
19
+
20
+ def _build_registry(device: str = "cpu") -> Dict[str, RegisteredModel]:
21
+ """
22
+ Central place to register all models.
23
+ Returns a dict: model_id -> RegisteredModel.
24
+ """
25
+
26
+ # -------- LR on raw pixels --------
27
+ def make_lr_raw():
28
+ from src.inference.lr_model import LRModel
29
+ return LRModel(
30
+ model_path="checkpoints/lr_model.joblib",
31
+ labels_path="configs/labels.json",
32
+ )
33
+
34
+ # -------- SVM on raw pixels --------
35
+ def make_svm_raw():
36
+ from src.inference.svm_model import SVMModel
37
+ return SVMModel(
38
+ ckpt_path="checkpoints/svm_model.joblib",
39
+ labels_path="configs/labels.json",
40
+ )
41
+
42
+ # -------- ResNet (PT) + LR head --------
43
+ def make_resnet_pt_lr():
44
+ from src.inference.resnet_pt_lr_model import ResNetPTLRModel
45
+ return ResNetPTLRModel(
46
+ ckpt_path="checkpoints/resnet_pt_lr_head.joblib",
47
+ labels_path="configs/labels.json",
48
+ device=device,
49
+ )
50
+
51
+ # -------- ResNet (PT) + SVM head --------
52
+ def make_resnet_pt_svm():
53
+ from src.inference.resnet_pt_svm_model import ResNetPTSVMModel
54
+ return ResNetPTSVMModel(
55
+ ckpt_path="checkpoints/resnet_pt_svm_head.joblib",
56
+ labels_path="configs/labels.json",
57
+ device=device,
58
+ )
59
+
60
+ return {
61
+ "lr_raw": RegisteredModel(
62
+ id="lr_raw",
63
+ display_name="LR (raw 64×64 grayscale)",
64
+ loader=make_lr_raw,
65
+ ),
66
+ "svm_raw": RegisteredModel(
67
+ id="svm_raw",
68
+ display_name="SVM (raw 64×64 grayscale)",
69
+ loader=make_svm_raw,
70
+ ),
71
+ "resnet_pt_lr": RegisteredModel(
72
+ id="resnet_pt_lr",
73
+ display_name="ResNet (pretrained) + LR",
74
+ loader=make_resnet_pt_lr,
75
+ ),
76
+ "resnet_pt_svm": RegisteredModel(
77
+ id="resnet_pt_svm",
78
+ display_name="ResNet (pretrained) + SVM",
79
+ loader=make_resnet_pt_svm,
80
+ ),
81
+ }
82
+
83
+
84
+ # Build registry once (models load lazily)
85
+ _REGISTRY: Dict[str, RegisteredModel] = _build_registry(device="cpu")
86
+
87
+
88
+ def get_registry() -> Dict[str, RegisteredModel]:
89
+ """Return the full registry (id -> RegisteredModel)."""
90
+ return _REGISTRY
91
+
92
+
93
+ def get_model(model_id: str) -> Any:
94
+ """Get a single model instance by id."""
95
+ if model_id not in _REGISTRY:
96
+ raise KeyError(f"Unknown model_id: {model_id}")
97
+ return _REGISTRY[model_id].get()
98
+
99
+
100
+ def get_models() -> Dict[str, Any]:
101
+ """Eagerly load all models (optional)."""
102
+ return {mid: entry.get() for mid, entry in _REGISTRY.items()}
103
+
104
+
105
+ def get_model_display_names() -> Dict[str, str]:
106
+ """Mapping id -> display name (for UI dropdowns)."""
107
+ return {mid: entry.display_name for mid, entry in _REGISTRY.items()}