import torch
import torch.nn as nn
import torch.nn.functional as F

from multiheadattention import MultiHeadAttention


class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: a self-attention sublayer and a
    feed-forward sublayer, each wrapped in a residual connection."""

    def __init__(self, embed_dim, num_heads, ff_dim):
        super().__init__()
        # Multi-head self-attention sublayer with its pre-norm.
        self.attn = MultiHeadAttention(embed_dim, num_heads)
        self.ln1 = nn.LayerNorm(embed_dim)
        # Position-wise feed-forward sublayer: expand to ff_dim, then project back.
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x, mask=None):
        # Pre-norm residual: normalize, attend, then add the result back to x.
        x = x + self.attn(self.ln1(x), mask=mask)
        # Same pattern for the feed-forward sublayer.
        x = x + self.ff(self.ln2(x))
        return x
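

# A minimal smoke test, sketched under the assumption that MultiHeadAttention
# (from the earlier section) takes (embed_dim, num_heads) in its constructor and
# returns a tensor with the same shape as its input. The hyperparameters and
# tensor shapes below are illustrative only.
if __name__ == "__main__":
    block = TransformerBlock(embed_dim=64, num_heads=8, ff_dim=256)
    x = torch.randn(2, 10, 64)  # (batch, sequence length, embed_dim)
    out = block(x)              # mask defaults to None
    print(out.shape)            # expected: torch.Size([2, 10, 64])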