dboris commited on
Commit
58fda56
·
verified ·
1 Parent(s): 154b6ce

Upload src/registry.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/registry.py +35 -0
src/registry.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model registry — @register decorator pattern.
3
+
4
+ Adding a new model = 1 file in backbones/, decorated with @register("name").
5
+ Training, ensemble, and inference code never import specific models.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import torch.nn as nn
11
+
12
+ _BACKBONE_REGISTRY: dict[str, type[nn.Module]] = {}
13
+
14
+
15
+ def register(name: str):
16
+ """Decorator to register a backbone class by name."""
17
+ def decorator(cls: type[nn.Module]) -> type[nn.Module]:
18
+ if name in _BACKBONE_REGISTRY:
19
+ raise ValueError(f"Backbone '{name}' already registered")
20
+ _BACKBONE_REGISTRY[name] = cls
21
+ return cls
22
+ return decorator
23
+
24
+
25
+ def get_backbone(name: str, **kwargs) -> nn.Module:
26
+ """Instantiate a registered backbone by name."""
27
+ if name not in _BACKBONE_REGISTRY:
28
+ available = ", ".join(sorted(_BACKBONE_REGISTRY.keys()))
29
+ raise ValueError(f"Unknown backbone '{name}'. Available: {available}")
30
+ return _BACKBONE_REGISTRY[name](**kwargs)
31
+
32
+
33
+ def list_backbones() -> list[str]:
34
+ """Return names of all registered backbones."""
35
+ return sorted(_BACKBONE_REGISTRY.keys())