"""
transformer.py

Full Transformer implementation for English → Bengali translation with
complete calculation tracking at every step.

Every module accepts an optional ``CalcLog``; when one is supplied, the
intermediate tensors of that step are recorded (usually truncated to the
first few dimensions) so they can be visualized afterwards.
"""

import math
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# ─────────────────────────────────────────────
# Calculation Logger
# ─────────────────────────────────────────────

class CalcLog:
    """Captures every intermediate tensor for visualization.

    Each call to :meth:`log` appends one dict with the keys ``name``,
    ``formula``, ``note``, ``shape`` and ``value`` to :attr:`steps`.
    """

    def __init__(self):
        self.steps: List[Dict] = []

    def log(self, name: str, data, formula: str = "", note: str = ""):
        """Record one step and return ``data`` unchanged (so calls can be inlined).

        Tensors and NumPy arrays are stored as nested Python lists together
        with their shape; any other value is stored as-is with ``shape=None``.
        """
        entry = {
            "name": name,
            "formula": formula,
            "note": note,
            "shape": None,
            "value": None,
        }
        if isinstance(data, torch.Tensor):
            tensor = data.detach().cpu()
            # NumPy has no bfloat16 dtype, so .numpy() would raise; upcast first.
            if tensor.dtype == torch.bfloat16:
                tensor = tensor.float()
            entry["shape"] = list(tensor.shape)
            entry["value"] = tensor.numpy().tolist()
        elif isinstance(data, np.ndarray):
            entry["shape"] = list(data.shape)
            entry["value"] = data.tolist()
        else:
            entry["value"] = data
        self.steps.append(entry)
        return data

    def clear(self):
        """Discard all recorded steps."""
        self.steps = []

    def to_dict(self):
        """Return the recorded steps (a list of step dicts, despite the name)."""
        return self.steps


# ─────────────────────────────────────────────
# Positional Encoding
# ─────────────────────────────────────────────

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to embeddings, followed by dropout.

    Args:
        d_model: embedding size (odd sizes are supported).
        max_len: longest sequence covered by the pre-computed table.
        dropout: dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than frequencies, so
        # trim div_term; for even d_model this slice keeps every frequency.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor, log: Optional[CalcLog] = None) -> torch.Tensor:
        """Add the positional encoding to ``x`` of shape (B, T, d_model).

        Raises:
            ValueError: if T exceeds the ``max_len`` given at construction.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds positional-encoding "
                f"max_len {max_len}"
            )
        pe_slice = self.pe[:, :seq_len, :]

        if log is not None:
            log.log("PE_matrix", pe_slice[0, :seq_len, :8],
                    formula="PE(pos,2i)=sin(pos/10000^(2i/d)), PE(pos,2i+1)=cos(...)",
                    note=f"Showing first 8 dims for {seq_len} positions")
            log.log("Embedding_before_PE", x[0, :, :8],
                    note="Token embeddings (first 8 dims)")

        x = x + pe_slice

        if log is not None:
            log.log("Embedding_after_PE", x[0, :, :8],
                    formula="X = Embedding + PE",
                    note="After adding positional encoding")

        return self.dropout(x)


# ─────────────────────────────────────────────
# Scaled Dot-Product Attention
# ─────────────────────────────────────────────

def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    log: Optional[CalcLog] = None,
    head_idx: int = 0,
    layer_idx: int = 0,
    attn_type: str = "self",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute ``softmax(Q·Kᵀ / √d_k) · V`` for a single attention head.

    Args:
        Q, K, V: query/key/value tensors; as passed by ``MultiHeadAttention``
            they are (B, T_q, d_k), (B, T_k, d_k) and (B, T_k, d_k).
        mask: positions where ``mask == 0`` are blocked. A 4-D mask has axis 1
            squeezed (when it has size 1), so the (B, 1, 1, T) / (B, 1, T, T)
            masks from ``Transformer.make_src_mask`` / ``make_tgt_mask``
            broadcast against the 3-D per-head scores.
        log: optional calculation logger.
        head_idx, layer_idx, attn_type: only used to build log entry names.

    Returns:
        ``(output, attn_weights)``. Query rows whose keys are all masked get
        all-zero attention weights and an all-zero output.
    """
    d_k = Q.size(-1)
    prefix = f"L{layer_idx}_H{head_idx}_{attn_type}"

    # Raw scores
    scores = torch.matmul(Q, K.transpose(-2, -1))
    if log is not None:
        log.log(f"{prefix}_Q", Q[0], formula="Q = X · Wq",
                note=f"Query matrix head {head_idx}")
        log.log(f"{prefix}_K", K[0], formula="K = X · Wk",
                note=f"Key matrix head {head_idx}")
        log.log(f"{prefix}_V", V[0], formula="V = X · Wv",
                note=f"Value matrix head {head_idx}")
        log.log(f"{prefix}_QKt", scores[0], formula="scores = Q · Kᵀ",
                note="Raw attention scores (before scaling)")

    # Scale
    scale = math.sqrt(d_k)
    scores = scores / scale
    if log is not None:
        log.log(f"{prefix}_QKt_scaled", scores[0],
                formula=f"scores = Q·Kᵀ / √{d_k} = Q·Kᵀ / {scale:.3f}",
                note="Scaled scores — prevents vanishing gradients")

    # Mask
    # masks arrive as (B,1,1,T) or (B,1,T,T) from make_src/tgt_mask;
    # scores here are 3-D (B,T_q,T_k) because we loop per-head,
    # so squeeze the head dim to avoid (B,B,...) broadcasting.
    fully_masked = None
    if mask is not None:
        if mask.dim() == 4:
            mask = mask.squeeze(1)  # (B,1,T,T) or (B,1,1,T) → (B,T,T) or (B,1,T)
        scores = scores.masked_fill(mask == 0, float("-inf"))
        if log is not None:
            log.log(f"{prefix}_mask", mask[0].float(),
                    formula="mask[i,j]=0 → score=-inf (future token blocked)",
                    note="Causal mask (training decoder) or padding mask")
            log.log(f"{prefix}_scores_masked", scores[0],
                    note="Scores after masking (-inf will become 0 after softmax)")
        # A row with every key blocked would make softmax return NaN, and
        # patching the forward value afterwards still leaves NaN gradients.
        # Give such rows finite dummy scores and zero their weights below.
        fully_masked = torch.isneginf(scores).all(dim=-1, keepdim=True)
        scores = scores.masked_fill(fully_masked, 0.0)

    # Softmax
    attn_weights = F.softmax(scores, dim=-1)
    if fully_masked is not None:
        attn_weights = attn_weights.masked_fill(fully_masked, 0.0)
    if log is not None:
        log.log(f"{prefix}_softmax", attn_weights[0],
                formula="α = softmax(scores, dim=-1)",
                note="Attention weights — each row sums to 1.0")

    # Weighted sum
    output = torch.matmul(attn_weights, V)
    if log is not None:
        log.log(f"{prefix}_output", output[0],
                formula="Attention = α · V",
                note="Weighted sum of values")

    return output, attn_weights


# ─────────────────────────────────────────────
# Multi-Head Attention
# ─────────────────────────────────────────────

class MultiHeadAttention(nn.Module):
    """Multi-head attention computed head-by-head so each head can be logged.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads (each of size d_model // num_heads).
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (B, T, d_model) into (B, num_heads, T, d_k)."""
        B, T, _ = x.shape
        return x.view(B, T, self.num_heads, self.d_k).transpose(1, 2)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        log: Optional[CalcLog] = None,
        layer_idx: int = 0,
        attn_type: str = "self",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend from ``query`` (B, T_q, d_model) over ``key``/``value`` (B, T_k, d_model).

        Returns:
            ``(output, attn_weights)`` with shapes (B, T_q, d_model) and
            (B, num_heads, T_q, T_k).
        """
        B = query.size(0)
        prefix = f"L{layer_idx}_{attn_type}_MHA"

        # Linear projections
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)
        if log is not None:
            log.log(f"{prefix}_Wq", self.W_q.weight[:4, :4],
                    formula="Wq shape: (d_model, d_model)",
                    note="Query weight matrix (first 4×4 shown)")
            log.log(f"{prefix}_Q_full", Q[0, :, :8],
                    formula="Q = input · Wq",
                    note="Full Q projection (first 8 dims shown)")

        # Split into heads
        Q = self.split_heads(Q)  # (B, h, T, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)
        if log is not None:
            log.log(f"{prefix}_Q_head0", Q[0, 0, :, :],
                    formula=f"Split: (B,T,D) → (B,{self.num_heads},T,{self.d_k})",
                    note=f"Head 0 queries — d_k={self.d_k}")

        # Per-head attention (log only first 2 heads to avoid bloat)
        all_attn = []
        all_weights = []
        for h in range(self.num_heads):
            h_log = log if h < 2 else None
            out_h, w_h = scaled_dot_product_attention(
                Q[:, h], K[:, h], V[:, h],
                mask=mask, log=h_log,
                head_idx=h, layer_idx=layer_idx, attn_type=attn_type,
            )
            all_attn.append(out_h)
            all_weights.append(w_h)

        # Concat heads
        concat = torch.stack(all_attn, dim=1)          # (B, h, T, d_k)
        concat = concat.transpose(1, 2).contiguous()   # (B, T, h, d_k)
        concat = concat.view(B, -1, self.d_model)      # (B, T, D)
        if log is not None:
            log.log(f"{prefix}_concat", concat[0, :, :8],
                    formula="concat = [head_1; head_2; ...; head_h]",
                    note="Concatenated heads (first 8 dims)")

        # Final projection
        output = self.W_o(concat)
        if log is not None:
            log.log(f"{prefix}_output", output[0, :, :8],
                    formula="MHA_out = concat · Wo",
                    note="Final multi-head attention output")

        # Stack all attention weights: (B, h, T_q, T_k)
        attn_weights = torch.stack(all_weights, dim=1)
        return output, attn_weights


# ─────────────────────────────────────────────
# Feed-Forward Network
# ─────────────────────────────────────────────

class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear → ReLU → Dropout → Linear."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, log: Optional[CalcLog] = None,
                layer_idx: int = 0, loc: str = "enc") -> torch.Tensor:
        """Apply the FFN to ``x``; ``loc`` only labels the log entries."""
        prefix = f"L{layer_idx}_{loc}_FFN"

        h = self.linear1(x)
        if log is not None:
            log.log(f"{prefix}_linear1", h[0, :, :8],
                    formula="h = X · W1 + b1",
                    note="First linear (d_model→d_ff), showing first 8 dims")

        h = F.relu(h)
        if log is not None:
            log.log(f"{prefix}_relu", h[0, :, :8],
                    formula="h = ReLU(h) = max(0, h)",
                    note="Negative values zeroed out")

        h = self.dropout(h)
        out = self.linear2(h)
        if log is not None:
            log.log(f"{prefix}_linear2", out[0, :, :8],
                    formula="out = h · W2 + b2",
                    note="Second linear (d_ff→d_model)")
        return out


# ─────────────────────────────────────────────
# Layer Norm + Residual
# ─────────────────────────────────────────────

class AddNorm(nn.Module):
    """Post-norm residual block: ``LayerNorm(x + sublayer_out)``."""

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, eps=eps)

    def forward(self, x: torch.Tensor, sublayer_out: torch.Tensor,
                log: Optional[CalcLog] = None, tag: str = "") -> torch.Tensor:
        """Add the residual and normalise; ``tag`` prefixes the log entry names."""
        residual = x + sublayer_out
        out = self.norm(residual)
        if log is not None:
            log.log(f"{tag}_residual", residual[0, :, :8],
                    formula="residual = x + sublayer(x)",
                    note="Residual (skip) connection")
            log.log(f"{tag}_layernorm", out[0, :, :8],
                    formula="LayerNorm(x) = γ·(x−μ)/σ + β",
                    note="Layer normalization output")
        return out


# ─────────────────────────────────────────────
# Encoder Layer
# ─────────────────────────────────────────────

class EncoderLayer(nn.Module):
    """Self-attention → Add&Norm → FFN → Add&Norm."""

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.add_norm1 = AddNorm(d_model)
        self.add_norm2 = AddNorm(d_model)

    def forward(self, x: torch.Tensor, src_mask: Optional[torch.Tensor] = None,
                log: Optional[CalcLog] = None, layer_idx: int = 0):
        """Return ``(encoded x, self-attention weights)``."""
        attn_out, attn_w = self.self_attn(
            x, x, x, mask=src_mask, log=log, layer_idx=layer_idx, attn_type="enc_self"
        )
        x = self.add_norm1(x, attn_out, log=log, tag=f"L{layer_idx}_enc_self")

        ffn_out = self.ffn(x, log=log, layer_idx=layer_idx, loc="enc")
        x = self.add_norm2(x, ffn_out, log=log, tag=f"L{layer_idx}_enc_ffn")
        return x, attn_w


# ─────────────────────────────────────────────
# Decoder Layer
# ─────────────────────────────────────────────

class DecoderLayer(nn.Module):
    """Masked self-attention → cross-attention over the encoder → FFN, each with Add&Norm."""

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.masked_self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.add_norm1 = AddNorm(d_model)
        self.add_norm2 = AddNorm(d_model)
        self.add_norm3 = AddNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        enc_out: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        src_mask: Optional[torch.Tensor] = None,
        log: Optional[CalcLog] = None,
        layer_idx: int = 0,
    ):
        """Return ``(decoded x, masked self-attn weights, cross-attn weights)``."""
        # 1. Masked self-attention
        m_attn_out, m_attn_w = self.masked_self_attn(
            x, x, x, mask=tgt_mask, log=log, layer_idx=layer_idx, attn_type="dec_masked"
        )
        x = self.add_norm1(x, m_attn_out, log=log, tag=f"L{layer_idx}_dec_masked")

        # 2. Cross-attention: Q from decoder, K/V from encoder
        if log is not None:
            log.log(f"L{layer_idx}_cross_Q_source", x[0, :, :8],
                    note="Cross-attn Q comes from DECODER (Bengali context)")
            log.log(f"L{layer_idx}_cross_KV_source", enc_out[0, :, :8],
                    note="Cross-attn K,V come from ENCODER (English context)")
        c_attn_out, c_attn_w = self.cross_attn(
            query=x, key=enc_out, value=enc_out,
            mask=src_mask, log=log, layer_idx=layer_idx, attn_type="dec_cross"
        )
        x = self.add_norm2(x, c_attn_out, log=log, tag=f"L{layer_idx}_dec_cross")

        # 3. FFN
        ffn_out = self.ffn(x, log=log, layer_idx=layer_idx, loc="dec")
        x = self.add_norm3(x, ffn_out, log=log, tag=f"L{layer_idx}_dec_ffn")
        return x, m_attn_w, c_attn_w


# ─────────────────────────────────────────────
# Full Transformer
# ─────────────────────────────────────────────

class Transformer(nn.Module):
    """Encoder-decoder Transformer with optional step-by-step calculation logging."""

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int = 128,
        num_heads: int = 4,
        num_layers: int = 2,
        d_ff: int = 256,
        max_len: int = 64,
        dropout: float = 0.1,
        pad_idx: int = 0,
    ):
        super().__init__()
        self.d_model = d_model
        self.pad_idx = pad_idx
        self.num_layers = num_layers

        self.src_embed = nn.Embedding(src_vocab_size, d_model, padding_idx=pad_idx)
        self.tgt_embed = nn.Embedding(tgt_vocab_size, d_model, padding_idx=pad_idx)
        self.src_pe = PositionalEncoding(d_model, max_len, dropout)
        self.tgt_pe = PositionalEncoding(d_model, max_len, dropout)

        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.output_linear = nn.Linear(d_model, tgt_vocab_size)
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise every weight matrix, then re-zero embedding padding rows.

        ``xavier_uniform_`` overwrites whole embedding tables, including the
        ``padding_idx`` row that ``nn.Embedding`` initialises to zeros. That row
        never receives gradient, so without this reset pad tokens would keep a
        random, frozen embedding.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        with torch.no_grad():
            for emb in (self.src_embed, self.tgt_embed):
                if emb.padding_idx is not None:
                    emb.weight[emb.padding_idx].zero_()

    # ── mask helpers ──────────────────────────
    def make_src_mask(self, src: torch.Tensor) -> torch.Tensor:
        """(B, 1, 1, T_src) boolean mask, True where the token is not padding."""
        return (src != self.pad_idx).unsqueeze(1).unsqueeze(2)

    def make_tgt_mask(self, tgt: torch.Tensor) -> torch.Tensor:
        """(B, 1, T, T) mask: key is not padding AND key position ≤ query position."""
        T = tgt.size(1)
        pad_mask = (tgt != self.pad_idx).unsqueeze(1).unsqueeze(2)       # (B,1,1,T)
        causal = torch.tril(torch.ones(T, T, device=tgt.device)).bool()  # (T,T)
        return pad_mask & causal                                         # (B,1,T,T)

    # ── forward ───────────────────────────────
    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        log: Optional[CalcLog] = None,
    ) -> Tuple[torch.Tensor, Dict]:
        """Run the encoder on ``src`` (B, T_src) and the decoder on ``tgt`` (B, T_tgt).

        Returns:
            ``(logits, meta)``: logits of shape (B, T_tgt, tgt_vocab_size), and a
            dict with per-layer attention weights and the masks as NumPy arrays.
        """
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)

        # ── Encoder ──────────────────────────
        src_emb = self.src_embed(src) * math.sqrt(self.d_model)
        if log is not None:
            log.log("SRC_tokens", src[0], note="Source token IDs (English)")
            log.log("SRC_embedding_raw", src_emb[0, :, :8],
                    formula=f"emb = Embedding(token_id) × √{self.d_model}",
                    note="Token embeddings (first 8 dims)")
        enc_x = self.src_pe(src_emb, log=log)

        enc_attn_weights = []
        for i, layer in enumerate(self.encoder_layers):
            enc_x, ew = layer(enc_x, src_mask=src_mask, log=log, layer_idx=i)
            enc_attn_weights.append(ew.detach().cpu().numpy())

        if log is not None:
            log.log("ENCODER_output", enc_x[0, :, :8],
                    note="Final encoder output — passed as K,V to every decoder cross-attention")

        # ── Decoder ──────────────────────────
        tgt_emb = self.tgt_embed(tgt) * math.sqrt(self.d_model)
        if log is not None:
            log.log("TGT_tokens", tgt[0],
                    note="Target token IDs (Bengali, teacher-forced in training)")
            log.log("TGT_embedding_raw", tgt_emb[0, :, :8],
                    formula=f"emb = Embedding(token_id) × √{self.d_model}",
                    note="Bengali token embeddings")
        dec_x = self.tgt_pe(tgt_emb, log=log)

        dec_self_attn_w = []
        dec_cross_attn_w = []
        for i, layer in enumerate(self.decoder_layers):
            dec_x, mw, cw = layer(
                dec_x, enc_x,
                tgt_mask=tgt_mask, src_mask=src_mask,
                log=log, layer_idx=i,
            )
            dec_self_attn_w.append(mw.detach().cpu().numpy())
            dec_cross_attn_w.append(cw.detach().cpu().numpy())

        # ── Output projection ─────────────────
        logits = self.output_linear(dec_x)  # (B, T, vocab)
        if log is not None:
            log.log("LOGITS", logits[0, :, :16],
                    formula="logits = dec_out · W_out (first 16 vocab entries shown)",
                    note=f"Raw scores over vocab of {logits.size(-1)} Bengali tokens")
            probs = F.softmax(logits[0], dim=-1)
            log.log("SOFTMAX_probs", probs[:, :16],
                    formula="P(token) = exp(logit) / Σ exp(logits)",
                    note="Probability distribution over Bengali vocabulary")

        meta = {
            "enc_attn": enc_attn_weights,
            "dec_self_attn": dec_self_attn_w,
            "dec_cross_attn": dec_cross_attn_w,
            "src_mask": src_mask.cpu().numpy(),
            "tgt_mask": tgt_mask.cpu().numpy(),
        }
        return logits, meta