| | """ |
| | Simple, flexible core models. |
| | """ |
| |
|
| | import torch |
| |
|
| | from models.components.layers.transformer_blocks import GenericTransformerBlock |
| |
|
| |
|
class GenericTransformer(torch.nn.Module):
    """
    General-purpose transformer backbone, intended to cover as wide a
    range of transformer architectures as possible.
    """

    def __init__(self, model_cfg):
        """
        Build the dropout layer and the stack of transformer blocks.
        Args:
            model_cfg: dict with keys "hidden_dim", "context_window",
                "positional_encoding_type", and a "core_model" sub-dict
                holding "ffn", "attn", and "num_layers".
        """
        super().__init__()

        core_cfg = model_cfg["core_model"]
        blocks = [
            GenericTransformerBlock(
                hidden_dim=model_cfg["hidden_dim"],
                context_window=model_cfg["context_window"],
                use_rope=model_cfg["positional_encoding_type"] == "rope",
                ffn_cfg=core_cfg["ffn"],
                attn_cfg=core_cfg["attn"],
            )
            for _ in range(core_cfg["num_layers"])
        ]

        # NOTE(review): Dropout() keeps torch's default p=0.5 — confirm this
        # is intentional rather than a rate that should come from model_cfg.
        self.transformer = torch.nn.ModuleDict(
            {
                "drop": torch.nn.Dropout(),
                "h": torch.nn.ModuleList(blocks),
            }
        )

    def forward(self, x, attention_mask=None):
        """
        Apply dropout, then run the input through every block in order.
        Args:
            x: torch.tensor(B, S, H)
            attention_mask: optional mask, forwarded unchanged to each block
        Returns:
            x: torch.tensor(B, S, H)
        """
        hidden = self.transformer.drop(x)

        for layer in self.transformer.h:
            hidden = layer(hidden, attention_mask=attention_mask)

        return hidden
| |
|
| |
|