File size: 1,118 Bytes
58fda56 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | """
Model registry — @register decorator pattern.
Adding a new model means adding one file in backbones/ whose class is decorated with @register("name").
Training, ensemble, and inference code never import specific model classes; they look backbones up by name via get_backbone().
"""
from __future__ import annotations
import torch.nn as nn
# Process-wide mapping from registered name to backbone class. Entries are
# added by the @register(...) decorator at class-definition time, i.e. when
# the module defining the backbone is imported; get_backbone() and
# list_backbones() only read it.
_BACKBONE_REGISTRY: dict[str, type[nn.Module]] = dict()
def register(name: str):
    """Build a class decorator that records the decorated backbone under *name*.

    The decorator stores the class in the module-level registry and returns
    the class itself, so the decorated name keeps referring to the original
    class.

    Raises:
        ValueError: When the decorator is applied and *name* is already
            present in the registry.
    """
    def _add_to_registry(backbone_cls: type[nn.Module]) -> type[nn.Module]:
        name_taken = name in _BACKBONE_REGISTRY
        if name_taken:
            raise ValueError(f"Backbone '{name}' already registered")
        _BACKBONE_REGISTRY[name] = backbone_cls
        return backbone_cls

    return _add_to_registry
def get_backbone(name: str, **kwargs) -> nn.Module:
    """Instantiate the backbone registered under *name*.

    Args:
        name: Key the backbone class was registered with via ``@register``.
        **kwargs: Forwarded unchanged to the backbone class constructor.

    Returns:
        A new instance of the registered class.

    Raises:
        ValueError: If no backbone is registered under *name*. The message
            lists the registered names (or says none are registered) so
            typos and missing imports are easy to diagnose.
    """
    # Single EAFP lookup instead of a membership test followed by indexing.
    try:
        backbone_cls = _BACKBONE_REGISTRY[name]
    except KeyError:
        # An empty registry usually means the modules that apply @register
        # were never imported; say so explicitly instead of printing an
        # empty list.
        available = ", ".join(sorted(_BACKBONE_REGISTRY)) or "(none registered)"
        # `from None` hides the internal KeyError so callers see one clean
        # ValueError rather than a chained traceback.
        raise ValueError(
            f"Unknown backbone '{name}'. Available: {available}"
        ) from None
    # Constructed outside the try so a KeyError raised by the constructor
    # is not misreported as an unknown backbone.
    return backbone_cls(**kwargs)
def list_backbones() -> list[str]:
    """Return a new list of every registered backbone name, sorted ascending."""
    registered_names = list(_BACKBONE_REGISTRY)
    registered_names.sort()
    return registered_names
|