| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from multiheadattention import MultiHeadAttention | |
class TransformerBlock(nn.Module):
    """Residual block pairing an attention sub-layer with a feed-forward sub-layer.

    Each sub-layer receives a layer-normalised copy of its input, and the
    sub-layer's output is added back onto the un-normalised input
    (normalise -> transform -> residual add).

    Args:
        embed_dim: Size of the feature dimension handled by both LayerNorms
            and by the input/output of the feed-forward network.
        num_heads: Forwarded unchanged to ``MultiHeadAttention``.
        ff_dim: Width of the hidden layer inside the feed-forward network.
    """

    def __init__(self, embed_dim, num_heads, ff_dim):
        super().__init__()

        # Attention sub-layer and the normalisation applied to its input.
        self.attn = MultiHeadAttention(embed_dim, num_heads)
        self.ln1 = nn.LayerNorm(embed_dim)

        # Feed-forward sub-layer: widen to ff_dim, apply GELU, narrow back.
        widen = nn.Linear(embed_dim, ff_dim)
        activation = nn.GELU()
        narrow = nn.Linear(ff_dim, embed_dim)
        self.ff = nn.Sequential(widen, activation, narrow)
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x, mask=None):
        """Run both residual sub-layers over ``x`` and return the result.

        Args:
            x: Input tensor; its last dimension must be ``embed_dim`` for the
                LayerNorm and Linear layers to accept it.
            mask: Optional value handed straight to the attention module as
                its ``mask`` keyword; its meaning is defined by
                ``MultiHeadAttention``, not by this block.

        Returns:
            The input plus the attention output, plus the feed-forward
            output computed from that intermediate sum.
        """
        attended = self.attn(self.ln1(x), mask=mask)
        after_attention = x + attended

        transformed = self.ff(self.ln2(after_attention))
        return after_attention + transformed