# Transformer-Visualizer / transformer.py
# Author: priyadip
# Last commit: "Fix: js in gr.Blocks(), event delegation for card clicks, SVG loss curve"
# Commit hash: dc138e1
"""
transformer.py
Full Transformer implementation for English → Bengali translation
with complete calculation tracking at every step.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import Optional, Tuple, Dict, List
# ─────────────────────────────────────────────
# Calculation Logger
# ─────────────────────────────────────────────
class CalcLog:
    """Records every intermediate tensor so the UI can replay the math."""

    def __init__(self):
        # Ordered list of logged entries, oldest first.
        self.steps: List[Dict] = []

    def log(self, name: str, data, formula: str = "", note: str = ""):
        """Append one entry describing *data* and return *data* unchanged.

        Tensors and numpy arrays are converted to nested Python lists (with
        their shape recorded) so the log is JSON-serializable; any other
        value is stored as-is with shape ``None``.
        """
        shape = None
        if isinstance(data, torch.Tensor):
            shape = list(data.shape)
            value = data.detach().cpu().numpy().tolist()
        elif isinstance(data, np.ndarray):
            shape = list(data.shape)
            value = data.tolist()
        else:
            value = data
        self.steps.append({
            "name": name,
            "formula": formula,
            "note": note,
            "shape": shape,
            "value": value,
        })
        return data

    def clear(self):
        """Drop all recorded entries."""
        self.steps = []

    def to_dict(self):
        """Return the raw entry list (already JSON-friendly)."""
        return self.steps
# ─────────────────────────────────────────────
# Positional Encoding
# ─────────────────────────────────────────────
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017).

    Precomputes a (1, max_len, d_model) table of sin/cos waves and adds the
    first ``seq_len`` rows to the token embeddings on each forward pass.

    Fix: supports odd ``d_model``. Previously ``pe[:, 1::2]`` (floor(d/2)
    cosine columns) was assigned a tensor with ceil(d/2) columns, raising a
    shape-mismatch error whenever ``d_model`` was odd. Behavior for even
    ``d_model`` is unchanged.
    """

    def __init__(self, d_model: int, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        # One frequency per sine column: ceil(d_model/2) entries.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # Cosine occupies only floor(d_model/2) columns — slice div_term so
        # an odd d_model no longer crashes.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor, log: "Optional[CalcLog]" = None) -> torch.Tensor:
        """Add positional encodings to *x* (B, T, d_model) and apply dropout.

        When *log* is given, records the PE slice and the embeddings before
        and after the addition (first 8 dims only, to keep the log small).
        """
        seq_len = x.size(1)
        pe_slice = self.pe[:, :seq_len, :]
        if log:
            log.log("PE_matrix", pe_slice[0, :seq_len, :8],
                    formula="PE(pos,2i)=sin(pos/10000^(2i/d)), PE(pos,2i+1)=cos(...)",
                    note=f"Showing first 8 dims for {seq_len} positions")
            log.log("Embedding_before_PE", x[0, :, :8],
                    note="Token embeddings (first 8 dims)")
        x = x + pe_slice
        if log:
            log.log("Embedding_after_PE", x[0, :, :8],
                    formula="X = Embedding + PE",
                    note="After adding positional encoding")
        return self.dropout(x)
# ─────────────────────────────────────────────
# Scaled Dot-Product Attention
# ─────────────────────────────────────────────
def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    log: Optional[CalcLog] = None,
    head_idx: int = 0,
    layer_idx: int = 0,
    attn_type: str = "self",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Single-head attention: softmax(Q·Kᵀ/√d_k)·V, with optional step logging.

    Returns the attended output and the attention-weight matrix.
    """
    d_k = Q.size(-1)
    prefix = f"L{layer_idx}_H{head_idx}_{attn_type}"

    # Step 1: raw dot-product similarity between every query and every key.
    scores = Q @ K.transpose(-2, -1)
    if log:
        log.log(f"{prefix}_Q", Q[0],
                formula="Q = X · Wq",
                note=f"Query matrix head {head_idx}")
        log.log(f"{prefix}_K", K[0],
                formula="K = X · Wk",
                note=f"Key matrix head {head_idx}")
        log.log(f"{prefix}_V", V[0],
                formula="V = X · Wv",
                note=f"Value matrix head {head_idx}")
        log.log(f"{prefix}_QKt", scores[0],
                formula="scores = Q · Kᵀ",
                note="Raw attention scores (before scaling)")

    # Step 2: scale by √d_k to keep the softmax well-conditioned.
    root_dk = math.sqrt(d_k)
    scores = scores / root_dk
    if log:
        log.log(f"{prefix}_QKt_scaled", scores[0],
                formula=f"scores = Q·Kᵀ / √{d_k} = Q·Kᵀ / {root_dk:.3f}",
                note="Scaled scores — prevents vanishing gradients")

    # Step 3: masking. Masks arrive as (B,1,1,T) or (B,1,T,T) from
    # make_src/tgt_mask, while scores here are 3-D (B,T_q,T_k) because heads
    # are processed one at a time — drop the head dim so broadcasting does
    # not produce a spurious (B,B,...) shape.
    if mask is not None:
        m = mask.squeeze(1) if mask.dim() == 4 else mask
        scores = scores.masked_fill(m == 0, float("-inf"))
        if log:
            log.log(f"{prefix}_mask", m[0].float(),
                    formula="mask[i,j]=0 → score=-inf (future token blocked)",
                    note="Causal mask (training decoder) or padding mask")
            log.log(f"{prefix}_scores_masked", scores[0],
                    note="Scores after masking (-inf will become 0 after softmax)")

    # Step 4: normalize each row into a probability distribution.
    attn_weights = scores.softmax(dim=-1)
    # Rows that were fully masked are all -inf → NaN after softmax; zero them.
    attn_weights = torch.nan_to_num(attn_weights, nan=0.0)
    if log:
        log.log(f"{prefix}_softmax", attn_weights[0],
                formula="α = softmax(scores, dim=-1)",
                note="Attention weights — each row sums to 1.0")

    # Step 5: blend the values using the attention distribution.
    output = attn_weights @ V
    if log:
        log.log(f"{prefix}_output", output[0],
                formula="Attention = α · V",
                note="Weighted sum of values")
    return output, attn_weights
# ─────────────────────────────────────────────
# Multi-Head Attention
# ─────────────────────────────────────────────
class MultiHeadAttention(nn.Module):
    """Multi-head attention that runs heads in a Python loop for logging.

    Heads are computed one at a time (rather than batched) so each head's
    intermediate tensors can be captured; only the first two heads are
    logged to keep the trace compact.
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # Bias-free projections, as in the original paper.
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (B, T, D) → (B, num_heads, T, d_k)."""
        batch, seq, _ = x.shape
        return x.reshape(batch, seq, self.num_heads, self.d_k).permute(0, 2, 1, 3)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        log: Optional[CalcLog] = None,
        layer_idx: int = 0,
        attn_type: str = "self",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        prefix = f"L{layer_idx}_{attn_type}_MHA"

        # Project inputs into query / key / value spaces.
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)
        if log:
            log.log(f"{prefix}_Wq", self.W_q.weight[:4, :4],
                    formula="Wq shape: (d_model, d_model)",
                    note="Query weight matrix (first 4×4 shown)")
            log.log(f"{prefix}_Q_full", Q[0, :, :8],
                    formula="Q = input · Wq",
                    note="Full Q projection (first 8 dims shown)")

        # Give each head its own d_k-sized slice.
        Q = self.split_heads(Q)  # (B, h, T, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)
        if log:
            log.log(f"{prefix}_Q_head0", Q[0, 0, :, :],
                    formula=f"Split: (B,T,D) → (B,{self.num_heads},T,{self.d_k})",
                    note=f"Head 0 queries — d_k={self.d_k}")

        # Attention per head; only heads 0 and 1 are logged to avoid bloat.
        head_outputs = []
        head_weights = []
        for h in range(self.num_heads):
            out_h, w_h = scaled_dot_product_attention(
                Q[:, h], K[:, h], V[:, h],
                mask=mask,
                log=log if h < 2 else None,
                head_idx=h,
                layer_idx=layer_idx,
                attn_type=attn_type,
            )
            head_outputs.append(out_h)
            head_weights.append(w_h)

        # Concatenate head outputs back to (B, T, D).
        concat = torch.cat(head_outputs, dim=-1)
        if log:
            log.log(f"{prefix}_concat", concat[0, :, :8],
                    formula="concat = [head_1; head_2; ...; head_h]",
                    note="Concatenated heads (first 8 dims)")

        # Mix information across heads with the output projection.
        output = self.W_o(concat)
        if log:
            log.log(f"{prefix}_output", output[0, :, :8],
                    formula="MHA_out = concat · Wo",
                    note="Final multi-head attention output")

        # (B, num_heads, T_q, T_k) — returned for visualization.
        attn_weights = torch.stack(head_weights, dim=1)
        return output, attn_weights
# ─────────────────────────────────────────────
# Feed-Forward Network
# ─────────────────────────────────────────────
class FeedForward(nn.Module):
    """Position-wise feed-forward block: Linear → ReLU → Dropout → Linear."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, log: Optional[CalcLog] = None,
                layer_idx: int = 0, loc: str = "enc") -> torch.Tensor:
        prefix = f"L{layer_idx}_{loc}_FFN"
        # Expand into the wider d_ff space.
        hidden = self.linear1(x)
        if log:
            log.log(f"{prefix}_linear1", hidden[0, :, :8],
                    formula="h = X · W1 + b1",
                    note="First linear (d_model→d_ff), showing first 8 dims")
        # Non-linearity.
        hidden = hidden.relu()
        if log:
            log.log(f"{prefix}_relu", hidden[0, :, :8],
                    formula="h = ReLU(h) = max(0, h)",
                    note="Negative values zeroed out")
        # Dropout, then project back down to d_model.
        out = self.linear2(self.dropout(hidden))
        if log:
            log.log(f"{prefix}_linear2", out[0, :, :8],
                    formula="out = h · W2 + b2",
                    note="Second linear (d_ff→d_model)")
        return out
# ─────────────────────────────────────────────
# Layer Norm + Residual
# ─────────────────────────────────────────────
class AddNorm(nn.Module):
    """Residual (skip) connection followed by LayerNorm — post-norm style."""

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, eps=eps)

    def forward(self, x: torch.Tensor, sublayer_out: torch.Tensor,
                log: Optional[CalcLog] = None, tag: str = "") -> torch.Tensor:
        skip = x + sublayer_out
        normed = self.norm(skip)
        if log:
            log.log(f"{tag}_residual", skip[0, :, :8],
                    formula="residual = x + sublayer(x)",
                    note="Residual (skip) connection")
            log.log(f"{tag}_layernorm", normed[0, :, :8],
                    formula="LayerNorm(x) = γ·(x−μ)/σ + β",
                    note="Layer normalization output")
        return normed
# ─────────────────────────────────────────────
# Encoder Layer
# ─────────────────────────────────────────────
class EncoderLayer(nn.Module):
    """One encoder block: self-attention → Add&Norm → FFN → Add&Norm."""

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.add_norm1 = AddNorm(d_model)
        self.add_norm2 = AddNorm(d_model)

    def forward(self, x: torch.Tensor, src_mask: Optional[torch.Tensor] = None,
                log: Optional[CalcLog] = None, layer_idx: int = 0):
        # Self-attention sub-layer (Q = K = V = x).
        sa_out, sa_weights = self.self_attn(
            x, x, x, mask=src_mask, log=log,
            layer_idx=layer_idx, attn_type="enc_self"
        )
        x = self.add_norm1(x, sa_out, log=log, tag=f"L{layer_idx}_enc_self")
        # Position-wise feed-forward sub-layer.
        x = self.add_norm2(
            x,
            self.ffn(x, log=log, layer_idx=layer_idx, loc="enc"),
            log=log,
            tag=f"L{layer_idx}_enc_ffn",
        )
        return x, sa_weights
# ─────────────────────────────────────────────
# Decoder Layer
# ─────────────────────────────────────────────
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attn → cross-attn → FFN, each wrapped
    in a residual Add&Norm."""

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.masked_self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        self.add_norm1 = AddNorm(d_model)
        self.add_norm2 = AddNorm(d_model)
        self.add_norm3 = AddNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        enc_out: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        src_mask: Optional[torch.Tensor] = None,
        log: Optional[CalcLog] = None,
        layer_idx: int = 0,
    ):
        # 1) Masked self-attention: a position may only see itself and
        #    earlier positions.
        self_out, self_weights = self.masked_self_attn(
            x, x, x, mask=tgt_mask, log=log,
            layer_idx=layer_idx, attn_type="dec_masked"
        )
        x = self.add_norm1(x, self_out, log=log, tag=f"L{layer_idx}_dec_masked")

        # 2) Cross-attention: queries come from the decoder state, keys and
        #    values from the encoder output.
        if log:
            log.log(f"L{layer_idx}_cross_Q_source", x[0, :, :8],
                    note="Cross-attn Q comes from DECODER (Bengali context)")
            log.log(f"L{layer_idx}_cross_KV_source", enc_out[0, :, :8],
                    note="Cross-attn K,V come from ENCODER (English context)")
        cross_out, cross_weights = self.cross_attn(
            query=x, key=enc_out, value=enc_out,
            mask=src_mask, log=log,
            layer_idx=layer_idx, attn_type="dec_cross"
        )
        x = self.add_norm2(x, cross_out, log=log, tag=f"L{layer_idx}_dec_cross")

        # 3) Position-wise feed-forward.
        x = self.add_norm3(
            x,
            self.ffn(x, log=log, layer_idx=layer_idx, loc="dec"),
            log=log,
            tag=f"L{layer_idx}_dec_ffn",
        )
        return x, self_weights, cross_weights
# ─────────────────────────────────────────────
# Full Transformer
# ─────────────────────────────────────────────
class Transformer(nn.Module):
    """Encoder–decoder Transformer for English → Bengali translation.

    Deliberately small (defaults: d_model=128, 4 heads, 2 layers) so every
    intermediate tensor can be recorded into a CalcLog and visualized.
    ``forward`` returns the output logits plus a dict of per-layer attention
    weights and the masks used.
    """
    def __init__(
        self,
        src_vocab_size: int,
        tgt_vocab_size: int,
        d_model: int = 128,
        num_heads: int = 4,
        num_layers: int = 2,
        d_ff: int = 256,
        max_len: int = 64,
        dropout: float = 0.1,
        pad_idx: int = 0,
    ):
        super().__init__()
        self.d_model = d_model
        self.pad_idx = pad_idx
        self.num_layers = num_layers
        # Separate embedding tables and positional encodings for the source
        # (English) and target (Bengali) vocabularies.
        self.src_embed = nn.Embedding(src_vocab_size, d_model, padding_idx=pad_idx)
        self.tgt_embed = nn.Embedding(tgt_vocab_size, d_model, padding_idx=pad_idx)
        self.src_pe = PositionalEncoding(d_model, max_len, dropout)
        self.tgt_pe = PositionalEncoding(d_model, max_len, dropout)
        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )
        # Final projection from d_model onto the target vocabulary.
        self.output_linear = nn.Linear(d_model, tgt_vocab_size)
        self._init_weights()
    def _init_weights(self) -> None:
        """Xavier-uniform init for every parameter with more than one dim.

        NOTE(review): this also re-initializes the embedding tables,
        including the padding row that nn.Embedding zeroed via padding_idx —
        confirm that is intended.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    # ── mask helpers ──────────────────────────
    def make_src_mask(self, src: torch.Tensor) -> torch.Tensor:
        """Padding mask over the source: (B, 1, 1, T_src), True where not pad."""
        # (B, 1, 1, T_src) — 1 where not pad
        return (src != self.pad_idx).unsqueeze(1).unsqueeze(2)
    def make_tgt_mask(self, tgt: torch.Tensor) -> torch.Tensor:
        """Combined padding + causal mask for the decoder: (B, 1, T, T)."""
        T = tgt.size(1)
        pad_mask = (tgt != self.pad_idx).unsqueeze(1).unsqueeze(2)  # (B,1,1,T)
        # Lower-triangular matrix blocks attention to future positions.
        causal = torch.tril(torch.ones(T, T, device=tgt.device)).bool()  # (T,T)
        return pad_mask & causal  # (B,1,T,T)
    # ── forward ───────────────────────────────
    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        log: Optional[CalcLog] = None,
    ) -> Tuple[torch.Tensor, Dict]:
        """Teacher-forced forward pass.

        Args:
            src: (B, T_src) English token IDs.
            tgt: (B, T_tgt) Bengali token IDs fed to the decoder
                (presumably already shifted right by the caller — TODO
                confirm against the training loop).
            log: optional CalcLog; when given, intermediate tensors of
                batch element 0 are recorded for visualization.

        Returns:
            logits: (B, T_tgt, tgt_vocab_size) unnormalized scores.
            meta: dict of per-layer attention weights (numpy arrays) and
                the source/target masks.
        """
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)
        # ── Encoder ──────────────────────────
        # Embeddings scaled by √d_model, as in the original paper.
        src_emb = self.src_embed(src) * math.sqrt(self.d_model)
        if log:
            log.log("SRC_tokens", src[0],
                    note="Source token IDs (English)")
            log.log("SRC_embedding_raw", src_emb[0, :, :8],
                    formula=f"emb = Embedding(token_id) × √{self.d_model}",
                    note="Token embeddings (first 8 dims)")
        enc_x = self.src_pe(src_emb, log=log)
        enc_attn_weights = []
        for i, layer in enumerate(self.encoder_layers):
            enc_x, ew = layer(enc_x, src_mask=src_mask, log=log, layer_idx=i)
            # Detach to numpy for the visualizer (no autograd needed).
            enc_attn_weights.append(ew.detach().cpu().numpy())
        if log:
            log.log("ENCODER_output", enc_x[0, :, :8],
                    note="Final encoder output — passed as K,V to every decoder cross-attention")
        # ── Decoder ──────────────────────────
        tgt_emb = self.tgt_embed(tgt) * math.sqrt(self.d_model)
        if log:
            log.log("TGT_tokens", tgt[0],
                    note="Target token IDs (Bengali, teacher-forced in training)")
            log.log("TGT_embedding_raw", tgt_emb[0, :, :8],
                    formula=f"emb = Embedding(token_id) × √{self.d_model}",
                    note="Bengali token embeddings")
        dec_x = self.tgt_pe(tgt_emb, log=log)
        dec_self_attn_w = []
        dec_cross_attn_w = []
        for i, layer in enumerate(self.decoder_layers):
            dec_x, mw, cw = layer(
                dec_x, enc_x,
                tgt_mask=tgt_mask, src_mask=src_mask,
                log=log, layer_idx=i,
            )
            dec_self_attn_w.append(mw.detach().cpu().numpy())
            dec_cross_attn_w.append(cw.detach().cpu().numpy())
        # ── Output projection ─────────────────
        logits = self.output_linear(dec_x)  # (B, T, vocab)
        if log:
            log.log("LOGITS", logits[0, :, :16],
                    formula="logits = dec_out · W_out (first 16 vocab entries shown)",
                    note=f"Raw scores over vocab of {logits.size(-1)} Bengali tokens")
            # Softmax is computed here for display only; training should use
            # the raw logits with cross-entropy.
            probs = F.softmax(logits[0], dim=-1)
            log.log("SOFTMAX_probs", probs[:, :16],
                    formula="P(token) = exp(logit) / Σ exp(logits)",
                    note="Probability distribution over Bengali vocabulary")
        meta = {
            "enc_attn": enc_attn_weights,
            "dec_self_attn": dec_self_attn_w,
            "dec_cross_attn": dec_cross_attn_w,
            "src_mask": src_mask.cpu().numpy(),
            "tgt_mask": tgt_mask.cpu().numpy(),
        }
        return logits, meta