"""Causal multi-head self-attention, a position-wise MLP, and a pre-LayerNorm
transformer block composed from them."""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """Multi-head self-attention with a causal (lower-triangular) mask.

    Args:
        hidden_dim: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights and to
            the projected output.
    """

    def __init__(self, hidden_dim: int, n_heads: int, dropout: float):
        super().__init__()
        assert hidden_dim % n_heads == 0, "hidden_dim must be divisible by n_heads"
        self.n_heads = n_heads
        self.head_dim = hidden_dim // n_heads
        # One fused projection produces Q, K and V with a single matmul.
        self.qkv = nn.Linear(hidden_dim, hidden_dim * 3)
        # Projection applied after the heads are concatenated back together.
        self.out_proj = nn.Linear(hidden_dim, hidden_dim)
        # Reused for both the attention weights and the final output.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Let each position attend to itself and all earlier positions.

        Args:
            x: Tensor of shape ``(batch, seq_length, hidden_dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size, seq_length, hidden_dim = x.size()

        # (batch, seq, 3*hidden) -> (batch, seq, 3, n_heads, head_dim)
        #                        -> (3, batch, n_heads, seq, head_dim)
        qkv = self.qkv(x).reshape(batch_size, seq_length, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # Scaled dot-product scores: (batch, n_heads, seq, seq)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Causal mask: query i may only see keys j <= i. The diagonal is always
        # kept, so no row is fully masked and softmax never produces NaN.
        causal = torch.ones(seq_length, seq_length, dtype=torch.bool, device=x.device).tril()
        attn_scores = attn_scores.masked_fill(~causal, float("-inf"))

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of values: (batch, n_heads, seq, head_dim)
        attn_output = torch.matmul(attn_weights, v)

        # Merge heads: the tensor is 4-D, so swap the head and sequence dims to
        # (batch, seq, n_heads, head_dim) before flattening to hidden_dim.
        # reshape (not view) copes with the non-contiguous layout.
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_length, hidden_dim)

        output = self.out_proj(attn_output)
        return self.dropout(output)


class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> GELU -> Linear -> Dropout.

    The inner layer is ``4 * hidden_dim`` wide; the output width equals the
    input width.

    Args:
        hidden_dim: Size of the last dimension of the input and output.
        dropout: Dropout probability applied to the output.
    """

    def __init__(self, hidden_dim: int, dropout: float):
        super().__init__()
        self.fc1 = nn.Linear(hidden_dim, 4 * hidden_dim)
        self.fc2 = nn.Linear(4 * hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP over the last dimension of ``x``; shape is preserved."""
        x = F.gelu(self.fc1(x))
        x = self.fc2(x)
        return self.dropout(x)


class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    Computes ``x = x + attn(ln1(x))`` followed by ``x = x + ff(ln2(x))``.

    Args:
        hidden_dim: Model width.
        n_heads: Number of attention heads; must divide ``hidden_dim``.
        dropout: Dropout probability used by both sub-layers.
    """

    def __init__(self, hidden_dim: int, n_heads: int, dropout: float):
        super().__init__()
        # Creation order is kept so parameter init / state_dict keys are unchanged.
        self.ln1 = nn.LayerNorm(hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)
        self.attn = SelfAttention(hidden_dim, n_heads, dropout)
        self.ff = FeedForward(hidden_dim, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and feed-forward sub-layers with residuals."""
        # Normalization is applied to the sub-layer input; the residual path stays raw.
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))
        return x