| """ | |
| training.py | |
| Training loop for English→Bengali transformer with full calculation capture. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import numpy as np | |
| import math | |
| from typing import Dict, List, Tuple, Optional | |
| from transformer import Transformer, CalcLog | |
| from vocab import get_vocabs, PARALLEL_DATA, PAD_IDX, BOS_IDX, EOS_IDX | |

# ─────────────────────────────────────────────
# Data helpers
# ─────────────────────────────────────────────
def collate_batch(pairs: List[Tuple[str, str]], src_v, tgt_v, device: str = "cpu"):
    src_seqs, tgt_seqs = [], []
    for en, bn in pairs:
        src_seqs.append(src_v.encode(en))
        tgt_seqs.append(tgt_v.encode(bn))

    def pad(seqs):
        max_len = max(len(s) for s in seqs)
        padded = [s + [PAD_IDX] * (max_len - len(s)) for s in seqs]
        return torch.tensor(padded, dtype=torch.long, device=device)

    return pad(src_seqs), pad(tgt_seqs)
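
# Example (illustrative, not executed on import): two pairs of unequal length
# are padded to a common length with PAD_IDX so the batch is rectangular:
#   src, tgt = collate_batch([("i eat rice", "আমি ভাত খাই"),
#                             ("he reads", "সে পড়ে")], src_v, tgt_v)
#   # src.shape == (2, max_src_len), tgt.shape == (2, max_tgt_len)
# The sentence pairs above are placeholders; any pairs covered by the
# vocabularies from get_vocabs() behave the same way.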

# ─────────────────────────────────────────────
# Label-smoothed cross-entropy
# ─────────────────────────────────────────────
class LabelSmoothingLoss(nn.Module):
    def __init__(self, vocab_size: int, pad_idx: int, smoothing: float = 0.1):
        super().__init__()
        self.vocab_size = vocab_size
        self.pad_idx = pad_idx
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing

    def forward(self, logits: torch.Tensor, target: torch.Tensor,
                log: Optional[CalcLog] = None) -> torch.Tensor:
        B, T, V = logits.shape
        logits_flat = logits.reshape(-1, V)
        target_flat = target.reshape(-1)
        log_probs = torch.log_softmax(logits_flat, dim=-1)
        with torch.no_grad():
            smooth_dist = torch.full_like(log_probs, self.smoothing / (V - 2))
            smooth_dist.scatter_(1, target_flat.unsqueeze(1), self.confidence)
            smooth_dist[:, self.pad_idx] = 0
            mask = (target_flat == self.pad_idx)
            smooth_dist[mask] = 0
        loss = -(smooth_dist * log_probs).sum(dim=-1)
        non_pad = (~mask).sum()
        loss = loss.sum() / non_pad.clamp(min=1)
        if log:
            probs_sample = torch.exp(log_probs[:4])
            log.log("LOSS_log_probs_sample", probs_sample,
                    formula="log P(token) = log_softmax(logits)",
                    note="Softmax probabilities for first 4 target positions")
            log.log("LOSS_smooth_dist_sample", smooth_dist[:4],
                    formula=f"smooth: correct={self.confidence:.2f}, others={self.smoothing/(V-2):.5f}",
                    note="Label-smoothed target distribution")
            log.log("LOSS_value", loss.item(),
                    formula="L = -Σ smooth_dist · log_probs / n_tokens",
                    note=f"Label-smoothed cross-entropy loss = {loss.item():.4f}")
        return loss
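
# Worked example of the smoothing math above (illustrative numbers): with a
# target vocabulary of V = 50 and smoothing = 0.1, the correct class gets
# probability 0.9 and each of the remaining V - 2 classes (everything except
# the correct token and <PAD>) gets 0.1 / 48 ≈ 0.00208, so each non-pad row
# of smooth_dist sums to 1.0. Rows whose target is <PAD> are zeroed out and
# excluded from the average through the non_pad count, e.g.
#   crit = LabelSmoothingLoss(vocab_size=50, pad_idx=PAD_IDX, smoothing=0.1)
#   loss = crit(torch.randn(2, 5, 50), torch.randint(0, 50, (2, 5)))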

# ─────────────────────────────────────────────
# Build model
# ─────────────────────────────────────────────
def build_model(src_vocab_size: int, tgt_vocab_size: int,
                device: str = "cpu") -> Transformer:
    model = Transformer(
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        d_model=64,
        num_heads=4,
        num_layers=2,
        d_ff=128,
        max_len=32,
        dropout=0.1,
        pad_idx=PAD_IDX,
    ).to(device)
    return model
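
# Shape sanity check (a minimal sketch, mirroring how training_step calls the
# model below):
#   model = build_model(len(src_v), len(tgt_v))
#   logits, meta = model(src, tgt[:, :-1], log=None)
#   # logits.shape == (batch, tgt_len - 1, len(tgt_v))
# With d_model=64, 2 layers and 4 heads this is a deliberately small model,
# sized for step-by-step calculation capture rather than translation quality.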

# ─────────────────────────────────────────────
# Single training step (with full logging)
# ─────────────────────────────────────────────
def training_step(
    model: Transformer,
    src: torch.Tensor,
    tgt: torch.Tensor,
    criterion: LabelSmoothingLoss,
    optimizer: optim.Optimizer,
    log: CalcLog,
    step_num: int = 0,
) -> Dict:
    model.train()
    log.clear()
    # Teacher forcing: decoder input = [BOS, token_1, ..., token_{T-1}]
    tgt_input = tgt[:, :-1]
    tgt_target = tgt[:, 1:]
    log.log("TRAINING_SETUP", {
        "mode": "TRAINING",
        "step": step_num,
        "src_shape": list(src.shape),
        "tgt_input_shape": list(tgt_input.shape),
        "tgt_target_shape": list(tgt_target.shape),
    }, formula="Teacher Forcing: feed ground-truth Bengali tokens as decoder input",
       note="During training, decoder sees actual Bengali tokens (not its own predictions)")
    log.log("SRC_sentence_ids", src[0].tolist(),
            note="Source (English) token IDs fed to encoder")
    log.log("TGT_input_ids", tgt_input[0].tolist(),
            note="Target input to decoder (shifted right — starts with <BOS>)")
    log.log("TGT_target_ids", tgt_target[0].tolist(),
            note="What decoder must predict (shifted left — ends with <EOS>)")
    # Forward
    logits, meta = model(src, tgt_input, log=log)
    # Loss
    loss = criterion(logits, tgt_target, log=log)
    log.log("LOSS_final", loss.item(),
            formula="Total loss = label-smoothed cross-entropy averaged over tokens",
            note=f"Loss = {loss.item():.4f} (lower = better prediction)")
    # Backward
    optimizer.zero_grad()
    loss.backward()
    # Gradient stats
    grad_norms = {}
    for name, param in model.named_parameters():
        if param.grad is not None:
            gn = param.grad.norm().item()
            grad_norms[name] = round(gn, 6)
    log.log("GRADIENTS_norm_sample", dict(list(grad_norms.items())[:8]),
            formula="∂L/∂W via backpropagation (chain rule)",
            note="Gradient norms for first 8 parameter tensors")
    # Gradient clipping
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    log.log("OPTIMIZER_step", {
        "algorithm": "Adam",
        "lr": optimizer.param_groups[0]["lr"],
        "note": "W = W - lr × (m̂ / (√v̂ + ε)) (Adam update rule)",
    }, formula="Adam: adaptive learning rate with momentum",
       note="Weights updated — model slightly improved")
    return {
        "loss": loss.item(),
        "calc_log": log.to_dict(),
        "meta": {k: v.tolist() if hasattr(v, "tolist") else v for k, v in meta.items()
                 if k != "enc_attn"},
    }
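
# Typical call (illustrative): visualize_training_step() below wires this up
# for a single sentence pair; for an arbitrary batch it looks like
#   log = CalcLog()
#   result = training_step(model, src, tgt, criterion, optimizer, log, step_num=1)
#   result["loss"]       # scalar float for this step
#   result["calc_log"]   # every value logged during forward, loss, and backward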

# ─────────────────────────────────────────────
# Full training run (quick demo)
# ─────────────────────────────────────────────
def run_training(
    epochs: int = 30,
    device: str = "cpu",
    progress_cb=None,
) -> Tuple[Transformer, List[float]]:
    src_v, tgt_v = get_vocabs()
    model = build_model(len(src_v), len(tgt_v), device)
    criterion = LabelSmoothingLoss(len(tgt_v), PAD_IDX, smoothing=0.1)
    optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.98), eps=1e-9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)
    losses = []
    src_batch, tgt_batch = collate_batch(PARALLEL_DATA, src_v, tgt_v, device)
    for epoch in range(1, epochs + 1):
        model.train()
        tgt_input = tgt_batch[:, :-1]
        tgt_target = tgt_batch[:, 1:]
        logits, _ = model(src_batch, tgt_input, log=None)
        loss = criterion(logits, tgt_target)
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step(loss.item())
        losses.append(loss.item())
        if progress_cb:
            progress_cb(epoch, epochs, loss.item())
    return model, losses
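
# progress_cb is optional; any callable accepting (epoch, total_epochs, loss)
# works, e.g.
#   run_training(epochs=30, progress_cb=lambda e, n, l: print(f"{e}/{n}: {l:.4f}"))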

# ─────────────────────────────────────────────
# Single-sample step for visualization
# ─────────────────────────────────────────────
def visualize_training_step(
    model: Transformer,
    en_sentence: str,
    bn_sentence: str,
    device: str = "cpu",
) -> Dict:
    src_v, tgt_v = get_vocabs()
    log = CalcLog()
    src_ids = src_v.encode(en_sentence)
    tgt_ids = tgt_v.encode(bn_sentence)
    log.log("TOKENIZATION_EN", {
        "sentence": en_sentence,
        "tokens": en_sentence.lower().split(),
        "ids": src_ids,
        "vocab_size": len(src_v),
    }, formula="token_id = vocab[word]",
       note="English → token IDs (BOS prepended, EOS appended)")
    log.log("TOKENIZATION_BN", {
        "sentence": bn_sentence,
        "tokens": bn_sentence.split(),
        "ids": tgt_ids,
        "vocab_size": len(tgt_v),
    }, note="Bengali → token IDs (teacher-forced during training)")
    src = torch.tensor([src_ids], dtype=torch.long, device=device)
    tgt = torch.tensor([tgt_ids], dtype=torch.long, device=device)
    criterion = LabelSmoothingLoss(len(tgt_v), PAD_IDX)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    result = training_step(model, src, tgt, criterion, optimizer, log)
    result["src_tokens"] = src_v.tokens(src_ids)
    result["tgt_tokens"] = tgt_v.tokens(tgt_ids)
    result["en_sentence"] = en_sentence
    result["bn_sentence"] = bn_sentence
    return result
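
# ─────────────────────────────────────────────
# Quick demo (a minimal sketch; runs only when this file is executed directly.
# Assumes PARALLEL_DATA is a non-empty list of (english, bengali) pairs, as it
# is used in collate_batch above)
# ─────────────────────────────────────────────
if __name__ == "__main__":
    def _print_progress(epoch, total, loss):
        # Print every fifth epoch plus the final one.
        if epoch % 5 == 0 or epoch == total:
            print(f"epoch {epoch:3d}/{total} | loss {loss:.4f}")

    trained_model, loss_history = run_training(epochs=30, progress_cb=_print_progress)
    demo_en, demo_bn = PARALLEL_DATA[0]
    step_info = visualize_training_step(trained_model, demo_en, demo_bn)
    print("single-step loss on", repr(demo_en), "->", step_info["loss"])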