"""Model registry.

To add a new architecture:

1. Create src/models/yourmodel.py with a class that:
   - accepts __init__(d_input, d_model, d_output, n_layers, **kwargs);
   - implements forward(x: Tensor[B, T, d_input]) → Tensor[B, T, d_output];
   - defines a static method extra_kwargs(model_cfg) → dict that returns
     any model-specific kwargs taken from the ModelConfig.
2. Import the class in this module and call register("name", YourClass) below.
"""
from src.models.lstm import LSTMModel
from src.models.transformer import TransformerModel
from src.models.transformer_glu import TransformerGLUModel
from src.models.tidar import TiDARModel
# Maps a registry name (e.g. "lstm") to the model class registered under it.
REGISTRY: dict = {}


def register(name: str, cls: type) -> type:
    """Register a model class under ``name`` in ``REGISTRY``.

    Registering the same class under the same name again is a no-op, so
    repeated registration is harmless. Registering a *different* class
    under a name that is already taken is treated as an error, because it
    would otherwise silently replace the earlier model.

    Args:
        name: Key under which the class is stored in ``REGISTRY``.
        cls: The model class to store.

    Returns:
        ``cls`` unchanged, so the call can be chained or used by
        decorator-style helpers.

    Raises:
        ValueError: If ``name`` is already registered to a different class.
    """
    if name in REGISTRY and REGISTRY[name] is not cls:
        raise ValueError(
            f"model name {name!r} is already registered to "
            f"{REGISTRY[name]!r}; refusing to overwrite with {cls!r}"
        )
    REGISTRY[name] = cls
    return cls
# Built-in architectures, registered in this order (REGISTRY keeps
# insertion order).
for _name, _cls in (
    ("lstm", LSTMModel),
    ("transformer", TransformerModel),
    ("transformer_glu", TransformerGLUModel),
    ("tidar", TiDARModel),
):
    register(_name, _cls)
# Drop the loop variables so they do not linger as module attributes.
del _name, _cls
|