File size: 875 Bytes
8006486
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
"""Model registry.

To add a new architecture:
  1. Create src/models/yourmodel.py with a class that:
       - accepts __init__(d_input, d_model, d_output, n_layers, **kwargs)
       - implements forward(x: Tensor[B, T, d_input]) → Tensor[B, T, d_output]
       - defines a static method extra_kwargs(model_cfg) → dict
         that returns any model-specific kwargs from the ModelConfig.
  2. Import the class in this module and call register("name", YourClass)
     below.
"""

from src.models.lstm import LSTMModel
from src.models.transformer import TransformerModel
from src.models.transformer_glu import TransformerGLUModel
from src.models.tidar import TiDARModel

# Maps a registry name (e.g. "lstm") to the model class registered under it.
REGISTRY: dict = {}


def register(name: str, cls, *, override: bool = False):
    """Register model class *cls* under *name* and return *cls*.

    Args:
        name: Key under which the class is stored in REGISTRY.
        cls: The model class to register.
        override: If True, replace an entry already registered under *name*
            with a different class instead of raising.

    Returns:
        *cls* unchanged, so the call can be used inline.

    Raises:
        ValueError: If *name* is already registered to a different class and
            *override* is False.
    """
    # Re-registering the very same class is harmless (idempotent). Binding a
    # different class to a taken name would silently make the earlier model
    # unreachable, so that requires an explicit opt-in.
    if name in REGISTRY and REGISTRY[name] is not cls and not override:
        raise ValueError(
            f"Model name {name!r} is already registered to {REGISTRY[name]!r}; "
            "pass override=True to replace it."
        )
    REGISTRY[name] = cls
    return cls


# Built-in architectures, registered in exactly this order.
for _name, _model_cls in (
    ("lstm", LSTMModel),
    ("transformer", TransformerModel),
    ("transformer_glu", TransformerGLUModel),
    ("tidar", TiDARModel),
):
    register(_name, _model_cls)
# Drop the loop temporaries so they do not linger as module attributes.
del _name, _model_cls