Spaces:
Sleeping
Sleeping
| """ | |
| transformer.py | |
| Full Transformer implementation for English → Bengali translation | |
| with complete calculation tracking at every step. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import math | |
| from typing import Optional, Tuple, Dict, List | |
| # ───────────────────────────────────────────── | |
| # Calculation Logger | |
| # ───────────────────────────────────────────── | |
class CalcLog:
    """Collect a flat, JSON-friendly trace of intermediate values.

    Every call to :meth:`log` appends one record (a ``dict`` with the keys
    ``name``, ``formula``, ``note``, ``shape`` and ``value``, in that order)
    to :attr:`steps`. Tensors and NumPy arrays are stored as nested Python
    lists together with their shape; any other object is stored unchanged.
    """

    def __init__(self):
        # Records in the order they were logged.
        self.steps: List[Dict] = []

    def log(self, name: str, data, formula: str = "", note: str = ""):
        """Record *data* under *name* and hand it back unchanged.

        Args:
            name: Identifier for this step.
            data: A ``torch.Tensor`` (detached and copied to CPU before being
                converted to lists), a ``np.ndarray``, or any other object,
                which is stored as-is with ``shape`` left as ``None``.
            formula: Optional formula text to display with the value.
            note: Optional free-form description.

        Returns:
            The very same *data* object, so the call can wrap an expression.
        """
        shape, value = None, data
        if isinstance(data, torch.Tensor):
            shape = list(data.shape)
            value = data.detach().cpu().numpy().tolist()
        elif isinstance(data, np.ndarray):
            shape = list(data.shape)
            value = data.tolist()
        self.steps.append(
            dict(name=name, formula=formula, note=note, shape=shape, value=value)
        )
        return data

    def clear(self):
        """Start a new trace by rebinding :attr:`steps` to a fresh empty list."""
        self.steps = []

    def to_dict(self):
        """Return the recorded steps (the live list object, not a copy)."""
        return self.steps
| # ───────────────────────────────────────────── | |
| # Positional Encoding | |
| # ───────────────────────────────────────────── | |
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to token embeddings.

    Column ``2i`` holds ``sin(pos / 10000^(2i/d_model))`` and column ``2i+1``
    the matching cosine. The table is precomputed for ``max_len`` positions
    and kept as the non-trainable buffer ``pe`` of shape
    ``(1, max_len, d_model)``.

    Args:
        d_model: Embedding width. Odd widths are accepted; the final, unpaired
            column then carries only a sine term.
        max_len: Longest sequence :meth:`forward` can encode.
        dropout: Dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()  # (max_len, 1)
        # One frequency per even column, i.e. ceil(d_model / 2) entries.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd columns. When d_model is odd the
        # last frequency has no cosine partner, so trim it to avoid a
        # shape-mismatch error on assignment (a no-op for even d_model).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor, log: Optional["CalcLog"] = None) -> torch.Tensor:
        """Return ``dropout(x + PE[:seq_len])``.

        Args:
            x: Embeddings indexed as ``(batch, seq_len, d_model)``; the
                sequence length is read from dim 1.
            log: Optional calculation logger. When given, it records the PE
                slice and the embeddings of batch element 0 before and after
                the addition (first 8 dims).

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` the table was
                built for (instead of an opaque broadcasting error).
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding "
                f"max_len={max_len}"
            )
        pe_slice = self.pe[:, :seq_len, :]
        if log:
            log.log("PE_matrix", pe_slice[0, :seq_len, :8],
                    formula="PE(pos,2i)=sin(pos/10000^(2i/d)), PE(pos,2i+1)=cos(...)",
                    note=f"Showing first 8 dims for {seq_len} positions")
            log.log("Embedding_before_PE", x[0, :, :8],
                    note="Token embeddings (first 8 dims)")
        x = x + pe_slice
        if log:
            log.log("Embedding_after_PE", x[0, :, :8],
                    formula="X = Embedding + PE",
                    note="After adding positional encoding")
        return self.dropout(x)
| # ───────────────────────────────────────────── | |
| # Scaled Dot-Product Attention | |
| # ───────────────────────────────────────────── | |
def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    log: Optional[CalcLog] = None,
    head_idx: int = 0,
    layer_idx: int = 0,
    attn_type: str = "self",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute ``softmax(Q·Kᵀ / sqrt(d_k), masked) · V`` for a single head.

    The matmuls act on the last two dims. In this file
    :class:`MultiHeadAttention` calls it once per head with 3-D slices.

    Args:
        Q, K, V: Query, key and value tensors; ``d_k`` is ``Q.size(-1)``.
        mask: Optional mask; positions where it equals 0 get a score of
            ``-inf``. A 4-D mask has dim 1 squeezed first, so the
            ``(B,1,1,T)`` / ``(B,1,T,T)`` masks built by the Transformer
            broadcast against the 3-D per-head scores instead of producing
            a spurious extra batch axis.
        log: Optional logger; when given, element ``[0]`` of every
            intermediate is recorded under ``L{layer_idx}_H{head_idx}_{attn_type}_*``.
        head_idx, layer_idx, attn_type: Only used for log names and notes.

    Returns:
        ``(output, attn_weights)``. Rows whose scores were all ``-inf``
        (fully masked) give NaN from softmax; those NaNs are replaced by 0.
    """
    tag = f"L{layer_idx}_H{head_idx}_{attn_type}"
    d_k = Q.size(-1)
    scale = math.sqrt(d_k)

    raw_scores = Q @ K.transpose(-2, -1)
    if log:
        for label, mat, formula, note in (
            ("Q", Q, "Q = X · Wq", f"Query matrix head {head_idx}"),
            ("K", K, "K = X · Wk", f"Key matrix head {head_idx}"),
            ("V", V, "V = X · Wv", f"Value matrix head {head_idx}"),
        ):
            log.log(f"{tag}_{label}", mat[0], formula=formula, note=note)
        log.log(f"{tag}_QKt", raw_scores[0],
                formula="scores = Q · Kᵀ",
                note="Raw attention scores (before scaling)")

    scores = raw_scores / scale
    if log:
        log.log(f"{tag}_QKt_scaled", scores[0],
                formula=f"scores = Q·Kᵀ / √{d_k} = Q·Kᵀ / {scale:.3f}",
                note="Scaled scores — prevents vanishing gradients")

    if mask is not None:
        # Drop the head axis of 4-D masks: (B,1,T,T)->(B,T,T), (B,1,1,T)->(B,1,T).
        keep = mask.squeeze(1) if mask.dim() == 4 else mask
        scores = scores.masked_fill(keep == 0, float("-inf"))
        if log:
            log.log(f"{tag}_mask", keep[0].float(),
                    formula="mask[i,j]=0 → score=-inf (future token blocked)",
                    note="Causal mask (training decoder) or padding mask")
            log.log(f"{tag}_scores_masked", scores[0],
                    note="Scores after masking (-inf will become 0 after softmax)")

    # A row that is entirely -inf softmaxes to NaN; treat it as "attend to nothing".
    weights = torch.nan_to_num(F.softmax(scores, dim=-1), nan=0.0)
    if log:
        log.log(f"{tag}_softmax", weights[0],
                formula="α = softmax(scores, dim=-1)",
                note="Attention weights — each row sums to 1.0")

    output = weights @ V
    if log:
        log.log(f"{tag}_output", output[0],
                formula="Attention = α · V",
                note="Weighted sum of values")
    return output, weights
| # ───────────────────────────────────────────── | |
| # Multi-Head Attention | |
| # ───────────────────────────────────────────── | |
class MultiHeadAttention(nn.Module):
    """Multi-head attention assembled from per-head scaled dot-product calls.

    Inputs are projected with bias-free ``W_q``/``W_k``/``W_v`` and split into
    ``num_heads`` heads of width ``d_k = d_model // num_heads``. Each head runs
    through :func:`scaled_dot_product_attention` separately, so individual
    heads can be logged. The head outputs are concatenated and projected by
    ``W_o``.
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``(B, T, d_model)`` into ``(B, num_heads, T, d_k)``."""
        batch, length, _ = x.shape
        per_head = x.view(batch, length, self.num_heads, self.d_k)
        return per_head.transpose(1, 2)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        log: Optional[CalcLog] = None,
        layer_idx: int = 0,
        attn_type: str = "self",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend from *query* to *key*/*value*.

        Args:
            query: ``(B, T_q, d_model)``.
            key, value: ``(B, T_k, d_model)``. The same tensor as *query* for
                self-attention; the encoder output for cross-attention.
            mask: Forwarded unchanged to every head.
            log: Optional logger; only heads 0 and 1 are traced in detail.
            layer_idx, attn_type: Only used to build log names.

        Returns:
            ``(output, attn_weights)`` of shapes ``(B, T_q, d_model)`` and
            ``(B, num_heads, T_q, T_k)``.
        """
        tag = f"L{layer_idx}_{attn_type}_MHA"
        batch = query.size(0)

        q_proj = self.W_q(query)
        k_proj = self.W_k(key)
        v_proj = self.W_v(value)
        if log:
            log.log(f"{tag}_Wq", self.W_q.weight[:4, :4],
                    formula="Wq shape: (d_model, d_model)",
                    note="Query weight matrix (first 4×4 shown)")
            log.log(f"{tag}_Q_full", q_proj[0, :, :8],
                    formula="Q = input · Wq",
                    note="Full Q projection (first 8 dims shown)")

        q_heads = self.split_heads(q_proj)  # (B, h, T_q, d_k)
        k_heads = self.split_heads(k_proj)  # (B, h, T_k, d_k)
        v_heads = self.split_heads(v_proj)  # (B, h, T_k, d_k)
        if log:
            log.log(f"{tag}_Q_head0", q_heads[0, 0, :, :],
                    formula=f"Split: (B,T,D) → (B,{self.num_heads},T,{self.d_k})",
                    note=f"Head 0 queries — d_k={self.d_k}")

        # Heads run one at a time; only the first two receive the logger so
        # the trace stays compact.
        per_head = [
            scaled_dot_product_attention(
                q_heads[:, h], k_heads[:, h], v_heads[:, h],
                mask=mask,
                log=log if h < 2 else None,
                head_idx=h,
                layer_idx=layer_idx,
                attn_type=attn_type,
            )
            for h in range(self.num_heads)
        ]
        head_outputs = [out for out, _ in per_head]
        head_weights = [w for _, w in per_head]

        # Stacking on dim 2 yields (B, T, h, d_k); merging the last two dims
        # lays the head outputs side by side: [head_0; head_1; ...; head_h].
        concat = torch.stack(head_outputs, dim=2).reshape(batch, -1, self.d_model)
        if log:
            log.log(f"{tag}_concat", concat[0, :, :8],
                    formula="concat = [head_1; head_2; ...; head_h]",
                    note="Concatenated heads (first 8 dims)")

        output = self.W_o(concat)
        if log:
            log.log(f"{tag}_output", output[0, :, :8],
                    formula="MHA_out = concat · Wo",
                    note="Final multi-head attention output")

        return output, torch.stack(head_weights, dim=1)  # weights: (B, h, T_q, T_k)
| # ───────────────────────────────────────────── | |
| # Feed-Forward Network | |
| # ───────────────────────────────────────────── | |
class FeedForward(nn.Module):
    """Position-wise feed-forward block: ``linear2(dropout(relu(linear1(x))))``.

    Args:
        d_model: Input and output width.
        d_ff: Hidden width.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, log: Optional[CalcLog] = None,
                layer_idx: int = 0, loc: str = "enc") -> torch.Tensor:
        """Apply the block to *x*.

        When *log* is given, the first 8 features of batch element 0 are
        recorded after the first linear, after the ReLU (before dropout), and
        after the second linear, under ``L{layer_idx}_{loc}_FFN_*``.
        """
        tag = f"L{layer_idx}_{loc}_FFN"

        pre_activation = self.linear1(x)
        if log:
            log.log(f"{tag}_linear1", pre_activation[0, :, :8],
                    formula="h = X · W1 + b1",
                    note="First linear (d_model→d_ff), showing first 8 dims")

        activated = F.relu(pre_activation)
        if log:
            log.log(f"{tag}_relu", activated[0, :, :8],
                    formula="h = ReLU(h) = max(0, h)",
                    note="Negative values zeroed out")

        out = self.linear2(self.dropout(activated))
        if log:
            log.log(f"{tag}_linear2", out[0, :, :8],
                    formula="out = h · W2 + b2",
                    note="Second linear (d_ff→d_model)")
        return out
| # ───────────────────────────────────────────── | |
| # Layer Norm + Residual | |
| # ───────────────────────────────────────────── | |
class AddNorm(nn.Module):
    """Residual connection followed by LayerNorm: ``LayerNorm(x + sublayer_out)``.

    Args:
        d_model: Width of the normalized feature dimension.
        eps: Epsilon passed to ``nn.LayerNorm``.
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, eps=eps)

    def forward(self, x: torch.Tensor, sublayer_out: torch.Tensor,
                log: Optional[CalcLog] = None, tag: str = "") -> torch.Tensor:
        """Add *sublayer_out* to *x* and normalize.

        With *log*, the first 8 features of batch element 0 are recorded as
        ``{tag}_residual`` (before the norm) and ``{tag}_layernorm`` (after).
        """
        summed = x + sublayer_out
        normed = self.norm(summed)
        if log:
            log.log(f"{tag}_residual", summed[0, :, :8],
                    formula="residual = x + sublayer(x)",
                    note="Residual (skip) connection")
            log.log(f"{tag}_layernorm", normed[0, :, :8],
                    formula="LayerNorm(x) = γ·(x−μ)/σ + β",
                    note="Layer normalization output")
        return normed
| # ───────────────────────────────────────────── | |
| # Encoder Layer | |
| # ───────────────────────────────────────────── | |
class EncoderLayer(nn.Module):
    """One encoder block: self-attention → Add&Norm → feed-forward → Add&Norm.

    ``dropout`` is only forwarded to the feed-forward sub-block.
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.add_norm1 = AddNorm(d_model)
        self.add_norm2 = AddNorm(d_model)

    def forward(self, x: torch.Tensor, src_mask: Optional[torch.Tensor] = None,
                log: Optional[CalcLog] = None, layer_idx: int = 0):
        """Encode *x* through one layer.

        Returns:
            ``(out, attn_w)``: the layer output and the self-attention
            weights produced by :class:`MultiHeadAttention`.
        """
        attn_out, attn_w = self.self_attn(
            x, x, x,
            mask=src_mask,
            log=log,
            layer_idx=layer_idx,
            attn_type="enc_self",
        )
        after_attn = self.add_norm1(x, attn_out, log=log, tag=f"L{layer_idx}_enc_self")

        ffn_out = self.ffn(after_attn, log=log, layer_idx=layer_idx, loc="enc")
        out = self.add_norm2(after_attn, ffn_out, log=log, tag=f"L{layer_idx}_enc_ffn")
        return out, attn_w
| # ───────────────────────────────────────────── | |
| # Decoder Layer | |
| # ───────────────────────────────────────────── | |
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention and FFN.

    Each sub-block is followed by its own Add&Norm. ``dropout`` is only
    forwarded to the feed-forward sub-block.
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.masked_self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.add_norm1 = AddNorm(d_model)
        self.add_norm2 = AddNorm(d_model)
        self.add_norm3 = AddNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        enc_out: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        src_mask: Optional[torch.Tensor] = None,
        log: Optional[CalcLog] = None,
        layer_idx: int = 0,
    ):
        """Decode *x* through one layer, attending to the encoder output *enc_out*.

        Returns:
            ``(out, self_attn_w, cross_attn_w)``: the layer output and the
            attention weights of the masked self-attention and cross-attention.
        """
        # Self-attention over the decoder input, restricted by tgt_mask.
        self_out, self_w = self.masked_self_attn(
            x, x, x,
            mask=tgt_mask,
            log=log,
            layer_idx=layer_idx,
            attn_type="dec_masked",
        )
        hidden = self.add_norm1(x, self_out, log=log, tag=f"L{layer_idx}_dec_masked")

        # Cross-attention: queries from the decoder state, keys/values from the encoder.
        if log:
            log.log(f"L{layer_idx}_cross_Q_source", hidden[0, :, :8],
                    note="Cross-attn Q comes from DECODER (Bengali context)")
            log.log(f"L{layer_idx}_cross_KV_source", enc_out[0, :, :8],
                    note="Cross-attn K,V come from ENCODER (English context)")
        cross_out, cross_w = self.cross_attn(
            query=hidden,
            key=enc_out,
            value=enc_out,
            mask=src_mask,
            log=log,
            layer_idx=layer_idx,
            attn_type="dec_cross",
        )
        hidden = self.add_norm2(hidden, cross_out, log=log, tag=f"L{layer_idx}_dec_cross")

        # Position-wise feed-forward.
        ffn_out = self.ffn(hidden, log=log, layer_idx=layer_idx, loc="dec")
        hidden = self.add_norm3(hidden, ffn_out, log=log, tag=f"L{layer_idx}_dec_ffn")
        return hidden, self_w, cross_w
| # ───────────────────────────────────────────── | |
| # Full Transformer | |
| # ───────────────────────────────────────────── | |
class Transformer(nn.Module):
    """Encoder–decoder Transformer with optional step-by-step calculation logging.

    Args:
        src_vocab_size: Source (English) vocabulary size.
        tgt_vocab_size: Target (Bengali) vocabulary size; width of the logits.
        d_model: Model width.
        num_heads: Attention heads per layer (must divide ``d_model``).
        num_layers: Number of encoder layers, and separately decoder layers.
        d_ff: Hidden width of each feed-forward block.
        max_len: Longest sequence the positional encodings cover.
        dropout: Dropout used by the positional encodings and feed-forward blocks.
        pad_idx: Token id treated as padding in both vocabularies.
    """

    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int = 128,
        num_heads: int = 4,
        num_layers: int = 2,
        d_ff: int = 256,
        max_len: int = 64,
        dropout: float = 0.1,
        pad_idx: int = 0,
    ):
        super().__init__()
        self.d_model = d_model
        self.pad_idx = pad_idx
        self.num_layers = num_layers
        self.src_embed = nn.Embedding(src_vocab_size, d_model, padding_idx=pad_idx)
        self.tgt_embed = nn.Embedding(tgt_vocab_size, d_model, padding_idx=pad_idx)
        self.src_pe = PositionalEncoding(d_model, max_len, dropout)
        self.tgt_pe = PositionalEncoding(d_model, max_len, dropout)
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.output_linear = nn.Linear(d_model, tgt_vocab_size)
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialize every weight matrix, then re-zero the padding rows.

        ``xavier_uniform_`` also overwrites the embedding tables, including the
        all-zero row that ``nn.Embedding`` reserves for ``padding_idx``. That
        row never receives gradient, so without the reset it would keep a
        random vector for the whole of training. The Xavier calls run in the
        same order as before, so the RNG stream and all other weights are
        unchanged.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        with torch.no_grad():
            for emb in (self.src_embed, self.tgt_embed):
                if emb.padding_idx is not None:
                    emb.weight[emb.padding_idx].zero_()

    # ── mask helpers ──────────────────────────
    def make_src_mask(self, src: torch.Tensor) -> torch.Tensor:
        """Return the encoder key padding mask ``(B, 1, 1, T_src)``, True where not pad."""
        return (src != self.pad_idx).unsqueeze(1).unsqueeze(2)

    def make_tgt_mask(self, tgt: torch.Tensor) -> torch.Tensor:
        """Return the decoder self-attention mask ``(B, 1, T, T)``.

        Entry ``[b, 0, i, j]`` is True iff ``j <= i`` (causal) and
        ``tgt[b, j]`` is not padding.
        """
        T = tgt.size(1)
        pad_mask = (tgt != self.pad_idx).unsqueeze(1).unsqueeze(2)  # (B,1,1,T)
        causal = torch.tril(torch.ones(T, T, device=tgt.device)).bool()  # (T,T)
        return pad_mask & causal  # (B,1,T,T)

    # ── forward ───────────────────────────────
    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        log: Optional[CalcLog] = None,
    ) -> Tuple[torch.Tensor, Dict]:
        """Run a teacher-forced forward pass.

        Args:
            src: Source token ids, ``(B, T_src)``.
            tgt: Decoder input token ids, ``(B, T_tgt)``.
            log: Optional calculation logger; records batch element 0.

        Returns:
            ``(logits, meta)``. ``logits`` has shape ``(B, T_tgt, tgt_vocab_size)``.
            ``meta`` holds per-layer attention weights and both masks as NumPy
            arrays.
        """
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)

        # ── Encoder ──────────────────────────
        # Scale embeddings by sqrt(d_model) before adding positional encodings.
        src_emb = self.src_embed(src) * math.sqrt(self.d_model)
        if log:
            log.log("SRC_tokens", src[0],
                    note="Source token IDs (English)")
            log.log("SRC_embedding_raw", src_emb[0, :, :8],
                    formula=f"emb = Embedding(token_id) × √{self.d_model}",
                    note="Token embeddings (first 8 dims)")
        enc_x = self.src_pe(src_emb, log=log)
        enc_attn_weights = []
        for i, layer in enumerate(self.encoder_layers):
            enc_x, ew = layer(enc_x, src_mask=src_mask, log=log, layer_idx=i)
            enc_attn_weights.append(ew.detach().cpu().numpy())
        if log:
            log.log("ENCODER_output", enc_x[0, :, :8],
                    note="Final encoder output — passed as K,V to every decoder cross-attention")

        # ── Decoder ──────────────────────────
        tgt_emb = self.tgt_embed(tgt) * math.sqrt(self.d_model)
        if log:
            log.log("TGT_tokens", tgt[0],
                    note="Target token IDs (Bengali, teacher-forced in training)")
            log.log("TGT_embedding_raw", tgt_emb[0, :, :8],
                    formula=f"emb = Embedding(token_id) × √{self.d_model}",
                    note="Bengali token embeddings")
        dec_x = self.tgt_pe(tgt_emb, log=log)
        dec_self_attn_w = []
        dec_cross_attn_w = []
        for i, layer in enumerate(self.decoder_layers):
            dec_x, mw, cw = layer(
                dec_x, enc_x,
                tgt_mask=tgt_mask, src_mask=src_mask,
                log=log, layer_idx=i,
            )
            dec_self_attn_w.append(mw.detach().cpu().numpy())
            dec_cross_attn_w.append(cw.detach().cpu().numpy())

        # ── Output projection ─────────────────
        logits = self.output_linear(dec_x)  # (B, T, vocab)
        if log:
            log.log("LOGITS", logits[0, :, :16],
                    formula="logits = dec_out · W_out (first 16 vocab entries shown)",
                    note=f"Raw scores over vocab of {logits.size(-1)} Bengali tokens")
            # Probabilities are only materialized for the trace; the model
            # itself returns raw logits.
            probs = F.softmax(logits[0], dim=-1)
            log.log("SOFTMAX_probs", probs[:, :16],
                    formula="P(token) = exp(logit) / Σ exp(logits)",
                    note="Probability distribution over Bengali vocabulary")

        meta = {
            "enc_attn": enc_attn_weights,
            "dec_self_attn": dec_self_attn_w,
            "dec_cross_attn": dec_cross_attn_w,
            "src_mask": src_mask.cpu().numpy(),
            "tgt_mask": tgt_mask.cpu().numpy(),
        }
        return logits, meta