# modell-name/src/models/registry.py
# RabidUmarell's picture
# Add model checkpoint and source
# 8006486 verified
"""Model registry.

To add a new architecture:
  1. Create src/models/yourmodel.py with a class that:
       - accepts __init__(d_input, d_model, d_output, n_layers, **kwargs)
       - implements forward(x: Tensor[B, T, d_input]) → Tensor[B, T, d_output]
       - defines a static method extra_kwargs(model_cfg) → dict
         returning any model-specific kwargs from the ModelConfig.
  2. Call register("name", YourClass) below.
"""
from src.models.lstm import LSTMModel
from src.models.transformer import TransformerModel
from src.models.transformer_glu import TransformerGLUModel
from src.models.tidar import TiDARModel
# Maps an architecture name to its model class; filled in through register().
REGISTRY: dict = {}


def register(name: str, cls=None):
    """Store a model class in REGISTRY under ``name``.

    Usable as a plain call or as a class decorator::

        register("mymodel", MyModel)

        @register("mymodel")
        class MyModel: ...

    Args:
        name: Key the class is stored under in REGISTRY.
        cls: Class to register. When omitted, a decorator is returned that
            registers whatever class it is applied to.

    Returns:
        ``cls`` when it is given, otherwise the registering decorator.

    An existing entry with the same name is replaced without warning.
    """
    if cls is None:
        # Decorator form: defer the registration until the class is known.
        def _decorator(target):
            REGISTRY[name] = target
            return target

        return _decorator

    REGISTRY[name] = cls
    # Returning the class lets callers chain or re-bind the result.
    return cls
# Built-in architectures, registered in this order at import time.
# To expose a new model, add a (name, class) pair here or call
# register("name", YourClass) directly after the loop.
for _arch_name, _arch_cls in (
    ("lstm", LSTMModel),
    ("transformer", TransformerModel),
    ("transformer_glu", TransformerGLUModel),
    ("tidar", TiDARModel),
):
    register(_arch_name, _arch_cls)
# Drop the loop variables so the module namespace stays unchanged.
del _arch_name, _arch_cls