from torch import nn

from GroupedQueryAttention import GroupedQueryAttention
from FeedForward import FeedForward


class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Grouped-query attention: fewer key/value groups than query heads
        self.att = GroupedQueryAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            num_kv_groups=cfg["n_kv_groups"],
            dtype=cfg["dtype"]
        )
        self.ff = FeedForward(cfg)
        # Pre-normalization: RMSNorm is applied before each sub-layer
        self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
        self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])

    def forward(self, x, mask, cos, sin):
        # Shortcut connection for the attention block
        shortcut = x
        x = self.norm1(x)
        x = self.att(x, mask, cos, sin)  # Shape [batch_size, num_tokens, emb_size]
        x = x + shortcut  # Add the original input back

        # Shortcut connection for the feed-forward block
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = x + shortcut  # Add the original input back

        return x
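Below is a minimal usage sketch, not part of the original listing. It assumes the GroupedQueryAttention and FeedForward modules imported above are available, that the attention layer expects a boolean causal mask plus RoPE cos/sin tables of shape (num_tokens, head_dim), and that emb_dim is divisible by n_heads; the cfg values chosen here are illustrative only.

import torch

cfg = {
    "emb_dim": 256,          # embedding dimension (illustrative value)
    "n_heads": 8,            # number of query heads
    "n_kv_groups": 2,        # number of key/value groups for GQA
    "dtype": torch.float32,
}

block = TransformerBlock(cfg)

batch_size, num_tokens = 2, 16
head_dim = cfg["emb_dim"] // cfg["n_heads"]

x = torch.randn(batch_size, num_tokens, cfg["emb_dim"], dtype=cfg["dtype"])

# Boolean causal mask: True above the diagonal marks future positions to hide
mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)

# Placeholder RoPE tables; in practice these come from the shared RoPE helper
# used by all transformer blocks (assumed to produce (num_tokens, head_dim) tensors)
positions = torch.arange(num_tokens, dtype=torch.float32)
inv_freq = 1.0 / (10_000 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
angles = positions[:, None] * inv_freq[None, :]   # (num_tokens, head_dim // 2)
angles = torch.cat([angles, angles], dim=-1)      # (num_tokens, head_dim)
cos, sin = torch.cos(angles), torch.sin(angles)

out = block(x, mask, cos, sin)
print(out.shape)  # torch.Size([2, 16, 256]) -- same shape as the input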