"""Model registry.

To add a new architecture:
    1. Create src/models/yourmodel.py containing a class that:
       - accepts __init__(d_input, d_model, d_output, n_layers, **kwargs);
       - implements forward(x: Tensor[B, T, d_input]) -> Tensor[B, T, d_output];
       - defines a @staticmethod extra_kwargs(model_cfg) -> dict that returns
         any model-specific kwargs taken from the ModelConfig.
    2. Import the class in this module and call register("name", YourClass)
       below.
"""
|
|
| from src.models.lstm import LSTMModel |
| from src.models.transformer import TransformerModel |
| from src.models.transformer_glu import TransformerGLUModel |
| from src.models.tidar import TiDARModel |
|
|
# Maps a registry name (e.g. "lstm") to the model class registered under it.
# Populated at import time by the register(...) calls in this module.
REGISTRY: dict = {}


def register(name: str, cls):
    """Register model class ``cls`` under ``name`` in ``REGISTRY``.

    Registering the same class under the same name again is a no-op.
    Registering a *different* class under a name that is already taken
    raises instead of silently replacing the existing entry, so a
    copy-pasted name cannot shadow another model unnoticed.

    Args:
        name: Key under which the model class is stored in ``REGISTRY``.
        cls: The model class to register.

    Returns:
        ``cls`` unchanged, for convenience.

    Raises:
        ValueError: If ``name`` is already registered to a different class.
    """
    if name in REGISTRY and REGISTRY[name] is not cls:
        raise ValueError(
            f"model name {name!r} is already registered to "
            f"{REGISTRY[name]!r}; refusing to overwrite it with {cls!r}"
        )
    REGISTRY[name] = cls
    return cls
|
|
|
|
# Built-in architectures, registered in this order under their config names.
for _model_name, _model_cls in (
    ("lstm", LSTMModel),
    ("transformer", TransformerModel),
    ("transformer_glu", TransformerGLUModel),
    ("tidar", TiDARModel),
):
    register(_model_name, _model_cls)
# Drop the loop variables so they do not linger as module attributes.
del _model_name, _model_cls
|
|