"""
Simple, flexible core models.
"""

import torch

from models.components.layers.transformer_blocks import GenericTransformerBlock


class GenericTransformer(torch.nn.Module):
    """
    Generic transformer class intended to support as broad a range of
    transformer models as possible.
    """

    def __init__(self, model_cfg):
        super().__init__()

        # Input dropout followed by a stack of transformer blocks.
        self.transformer = torch.nn.ModuleDict(
            {
                # torch.nn.Dropout() is left at its default rate (p=0.5).
                "drop": torch.nn.Dropout(),
                "h": torch.nn.ModuleList(
                    [
                        GenericTransformerBlock(
                            hidden_dim=model_cfg["hidden_dim"],
                            context_window=model_cfg["context_window"],
                            use_rope=model_cfg["positional_encoding_type"] == "rope",
                            ffn_cfg=model_cfg["core_model"]["ffn"],
                            attn_cfg=model_cfg["core_model"]["attn"],
                        )
                        for _ in range(model_cfg["core_model"]["num_layers"])
                    ]
                ),
            }
        )

    def forward(self, x, attention_mask=None):
        """
        Pass an input through the model.

        Args:
            x: torch.Tensor of shape (B, S, H)
            attention_mask: optional mask forwarded to every block

        Returns:
            x: torch.Tensor of shape (B, S, H)
        """
        # Apply input dropout, then run the input through each block in turn.
        x = self.transformer.drop(x)

        for block in self.transformer.h:
            x = block(x, attention_mask=attention_mask)

        return x
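

if __name__ == "__main__":
    # Minimal usage sketch (illustrative). The top-level keys mirror what
    # __init__ reads above; the numeric values and the "ffn"/"attn" sub-config
    # contents are hypothetical placeholders -- their real keys depend on
    # GenericTransformerBlock, which this module does not define.
    demo_cfg = {
        "hidden_dim": 512,           # hypothetical value
        "context_window": 1024,      # hypothetical value
        "positional_encoding_type": "rope",
        "core_model": {
            "num_layers": 6,         # hypothetical value
            "ffn": {},               # hypothetical placeholder
            "attn": {},              # hypothetical placeholder
        },
    }
    model = GenericTransformer(demo_cfg)
    x = torch.randn(2, 1024, 512)    # (B, S, H)
    print(model(x).shape)            # expected: torch.Size([2, 1024, 512])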