""" Training loop for the Transformer translator. =============================================== Provides: • ``TranslationDataset`` – a PyTorch Dataset that tokenises and pads source/target sentence pairs. • ``create_dataloaders`` – builds train / validation DataLoaders with an 90/10 split. • ``train_one_epoch`` – one full pass over the training set. • ``evaluate_loss`` – average loss on the validation set. • ``train`` – full training driver with logging, LR scheduling, checkpointing, and early stopping. Design choices: • Label-smoothed cross-entropy (smoothing = 0.1) for better generalisation. • AdamW with a linear-warmup + cosine-decay schedule (stable for small datasets). • Mixed-precision (AMP) with ``torch.amp`` for speed / memory on T4. • Gradient clipping at max_norm = 1.0 to avoid exploding gradients. """ from __future__ import annotations import math import os import time from dataclasses import dataclass, field from pathlib import Path from typing import List, Optional, Tuple import torch import torch.nn as nn from torch.utils.data import Dataset, DataLoader, random_split from tokenizers import Tokenizer # ────────────────────────────────────────────────────────────────────── # 1. Translation Dataset # ────────────────────────────────────────────────────────────────────── class TranslationDataset(Dataset): """ Wraps a HuggingFace dataset of translation pairs into a PyTorch Dataset that returns padded token-ID tensors. Each ``__getitem__`` returns:: { "src": LongTensor[max_len], # source token IDs (padded) "tgt": LongTensor[max_len], # target input (with [BOS], no final [EOS]) "label": LongTensor[max_len], # target labels (no [BOS], with [EOS]) } The *tgt* / *label* split implements **teacher forcing**: the decoder receives ``[BOS] w1 w2 …`` and must predict ``w1 w2 … [EOS]``. 
""" def __init__( self, hf_dataset, src_tokenizer: Tokenizer, tgt_tokenizer: Tokenizer, src_lang: str = "en", tgt_lang: str = "ms", max_len: int = 128, pad_id: int = 0, ): self.data = hf_dataset self.src_tok = src_tokenizer self.tgt_tok = tgt_tokenizer self.src_lang = src_lang self.tgt_lang = tgt_lang self.max_len = max_len self.pad_id = pad_id def __len__(self) -> int: return len(self.data) def _pad(self, ids: List[int]) -> List[int]: """Truncate to max_len, then right-pad with pad_id.""" ids = ids[: self.max_len] return ids + [self.pad_id] * (self.max_len - len(ids)) def __getitem__(self, idx: int) -> dict: pair = self.data[idx]["translation"] # Encode (includes [BOS] … [EOS] from post-processor) src_ids = self.src_tok.encode(pair[self.src_lang]).ids tgt_ids = self.tgt_tok.encode(pair[self.tgt_lang]).ids # Teacher-forcing split: # tgt_input = [BOS] w1 w2 … wN (drop last token) # tgt_label = w1 w2 … wN [EOS] (drop first token) tgt_input = tgt_ids[:-1] tgt_label = tgt_ids[1:] return { "src": torch.tensor(self._pad(src_ids), dtype=torch.long), "tgt": torch.tensor(self._pad(tgt_input), dtype=torch.long), "label": torch.tensor(self._pad(tgt_label), dtype=torch.long), } # ────────────────────────────────────────────────────────────────────── # 2. DataLoader factory # ────────────────────────────────────────────────────────────────────── def create_dataloaders( hf_dataset, src_tokenizer: Tokenizer, tgt_tokenizer: Tokenizer, src_lang: str = "en", tgt_lang: str = "ms", max_len: int = 128, batch_size: int = 32, val_ratio: float = 0.1, pad_id: int = 0, seed: int = 42, ) -> Tuple[DataLoader, DataLoader, TranslationDataset]: """ Build training and validation DataLoaders from a HuggingFace dataset. 
Returns ------- train_loader, val_loader, full_dataset """ full_ds = TranslationDataset( hf_dataset, src_tokenizer, tgt_tokenizer, src_lang, tgt_lang, max_len, pad_id, ) val_size = max(1, int(len(full_ds) * val_ratio)) train_size = len(full_ds) - val_size generator = torch.Generator().manual_seed(seed) train_ds, val_ds = random_split(full_ds, [train_size, val_size], generator=generator) train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=False) val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, drop_last=False) print(f"Train: {train_size} | Val: {val_size} | Batch size: {batch_size}") return train_loader, val_loader, full_ds # ────────────────────────────────────────────────────────────────────── # 3. Training configuration dataclass # ────────────────────────────────────────────────────────────────────── @dataclass class TrainConfig: """All tuneable knobs in one place.""" epochs: int = 50 batch_size: int = 32 max_len: int = 128 lr: float = 5e-4 warmup_steps: int = 200 label_smoothing: float = 0.1 grad_clip: float = 1.0 use_amp: bool = True val_ratio: float = 0.1 checkpoint_dir: str = "training/checkpoints" log_every: int = 10 # print loss every N steps patience: int = 10 # early-stopping patience (epochs) seed: int = 42 # ────────────────────────────────────────────────────────────────────── # 4. LR scheduler with linear warmup + cosine decay # ────────────────────────────────────────────────────────────────────── def _build_scheduler(optimizer, warmup_steps: int, total_steps: int): """Linear warmup for `warmup_steps`, then cosine decay to 0.""" def lr_lambda(step): if step < warmup_steps: return step / max(1, warmup_steps) progress = (step - warmup_steps) / max(1, total_steps - warmup_steps) return 0.5 * (1.0 + math.cos(math.pi * progress)) return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) # ────────────────────────────────────────────────────────────────────── # 5. 
Single-epoch training # ────────────────────────────────────────────────────────────────────── def train_one_epoch( model: nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer, scheduler, criterion: nn.Module, device: torch.device, scaler: Optional[torch.amp.GradScaler], grad_clip: float = 1.0, log_every: int = 10, epoch: int = 0, ) -> float: """Train for one epoch. Returns average loss.""" model.train() total_loss = 0.0 n_tokens = 0 for step, batch in enumerate(loader): src = batch["src"].to(device) tgt = batch["tgt"].to(device) label = batch["label"].to(device) optimizer.zero_grad() amp_enabled = scaler is not None with torch.amp.autocast("cuda", enabled=amp_enabled): logits = model(src, tgt) # (B, T, V) loss = criterion(logits.reshape(-1, logits.size(-1)), label.reshape(-1)) if scaler is not None: scaler.scale(loss).backward() scaler.unscale_(optimizer) nn.utils.clip_grad_norm_(model.parameters(), grad_clip) scaler.step(optimizer) scaler.update() else: loss.backward() nn.utils.clip_grad_norm_(model.parameters(), grad_clip) optimizer.step() scheduler.step() # Accumulate loss (ignore padding contribution) non_pad = (label != model.pad_idx).sum().item() total_loss += loss.item() * non_pad n_tokens += non_pad if (step + 1) % log_every == 0: avg = total_loss / max(n_tokens, 1) lr = scheduler.get_last_lr()[0] print(f" Epoch {epoch+1} | Step {step+1}/{len(loader)} | Loss {avg:.4f} | LR {lr:.2e}") return total_loss / max(n_tokens, 1) # ────────────────────────────────────────────────────────────────────── # 6. 
# ──────────────────────────────────────────────────────────────────────
# 6. Validation loss
# ──────────────────────────────────────────────────────────────────────
@torch.no_grad()  # evaluation only — no autograd graph is built
def evaluate_loss(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
    use_amp: bool = False,
) -> float:
    """Compute average loss over a validation set (with AMP to match training).

    Parameters
    ----------
    use_amp : bool
        Run the forward pass under CUDA autocast so validation numerics
        match the (mixed-precision) training forward pass.

    Returns
    -------
    float
        Token-weighted average loss over all non-padding positions.
    """
    model.eval()
    total_loss = 0.0
    n_tokens = 0
    n_batches = len(loader)
    for step, batch in enumerate(loader):
        src = batch["src"].to(device)
        tgt = batch["tgt"].to(device)
        label = batch["label"].to(device)
        with torch.amp.autocast("cuda", enabled=use_amp):
            logits = model(src, tgt)
            loss = criterion(logits.reshape(-1, logits.size(-1)), label.reshape(-1))
        # Criterion is expected to ignore PAD (see ``train``), so ``loss`` is
        # a mean over real tokens; re-weight by the real-token count.
        # NOTE(review): relies on ``model.pad_idx`` existing — same contract
        # as train_one_epoch; confirm the model class defines it.
        non_pad = (label != model.pad_idx).sum().item()
        total_loss += loss.item() * non_pad
        n_tokens += non_pad
        # Lightweight progress indicator: print at ~25% intervals and on the
        # final batch; ``\r`` keeps it on one console line.
        if (step + 1) % max(1, n_batches // 4) == 0 or (step + 1) == n_batches:
            print(f" Val {step+1}/{n_batches}", end="\r")
    return total_loss / max(n_tokens, 1)


# ──────────────────────────────────────────────────────────────────────
# 7. Full training driver
# ──────────────────────────────────────────────────────────────────────
def train(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    cfg: TrainConfig,
    device: torch.device,
    trial=None,
    resume_from: Optional[str] = None,
    epoch_callback=None,
) -> dict:
    """
    Full training loop with logging, checkpointing, and early stopping.

    Parameters
    ----------
    trial : optuna.trial.Trial, optional
        If provided, reports val_loss after each epoch for ASHA pruning.
    resume_from : str, optional
        Path to a ``resume_state.pt`` file.  If provided, training resumes
        from the saved epoch with the exact optimizer / scheduler / scaler
        state, history, and early-stopping counters.
    epoch_callback : callable, optional
        Called after every epoch as ``epoch_callback(epoch, history)``.
        Useful for live plotting in notebooks.

    Returns
    -------
    history : dict
        ``{"train_loss": [...], "val_loss": [...], "lr": [...]}``
    """
    # --- Loss function (label-smoothed CE, ignoring PAD) ---------------
    criterion = nn.CrossEntropyLoss(
        ignore_index=model.pad_idx,
        label_smoothing=cfg.label_smoothing,
    )

    # --- Optimiser ------------------------------------------------------
    # betas=(0.9, 0.98), eps=1e-9 follows the original Transformer recipe.
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, betas=(0.9, 0.98), eps=1e-9)

    # --- LR schedule ---------------------------------------------------
    # One scheduler step per optimisation step (see train_one_epoch).
    total_steps = cfg.epochs * len(train_loader)
    scheduler = _build_scheduler(optimizer, cfg.warmup_steps, total_steps)

    # --- AMP scaler ----------------------------------------------------
    # Scaler (and thus AMP) is only active on CUDA devices.
    scaler = torch.amp.GradScaler("cuda") if (cfg.use_amp and device.type == "cuda") else None

    # --- Checkpoint dir ------------------------------------------------
    ckpt_dir = Path(cfg.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    history: dict = {"train_loss": [], "val_loss": [], "lr": []}
    best_val = float("inf")
    patience_ctr = 0
    start_epoch = 0

    # --- Resume from checkpoint ----------------------------------------
    if resume_from is not None and os.path.exists(resume_from):
        print(f"\n🔄 Resuming from {resume_from}")
        # weights_only=False: the checkpoint stores plain Python state
        # (history dict, counters) alongside tensors. Only load trusted files.
        ckpt = torch.load(resume_from, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        if scaler is not None and "scaler_state_dict" in ckpt:
            scaler.load_state_dict(ckpt["scaler_state_dict"])
        start_epoch = ckpt["epoch"] + 1  # resume from *next* epoch
        best_val = ckpt["best_val_loss"]
        patience_ctr = ckpt["patience_ctr"]
        history = ckpt["history"]
        print(f" Resumed at epoch {start_epoch+1}/{cfg.epochs} | "
              f"best_val={best_val:.4f} | patience={patience_ctr}/{cfg.patience}")

    print(f"\n{'='*60}")
    print(f"Starting training: {cfg.epochs} epochs (from epoch {start_epoch+1}), lr={cfg.lr}, AMP={cfg.use_amp}")
    print(f"{'='*60}\n")

    for epoch in range(start_epoch, cfg.epochs):
        t0 = time.time()
        train_loss = train_one_epoch(
            model, train_loader, optimizer, scheduler, criterion, device,
            scaler, cfg.grad_clip, cfg.log_every, epoch,
        )
        # Match validation numerics to training (autocast only on CUDA).
        use_amp = cfg.use_amp and device.type == "cuda"
        val_loss = evaluate_loss(model, val_loader, criterion, device, use_amp=use_amp)
        lr = scheduler.get_last_lr()[0]
        elapsed = time.time() - t0

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["lr"].append(lr)

        print(
            f"Epoch {epoch+1}/{cfg.epochs} | "
            f"Train {train_loss:.4f} | Val {val_loss:.4f} | "
            f"LR {lr:.2e} | {elapsed:.1f}s"
        )

        # --- Optuna ASHA pruning (if trial provided) ------------------
        # Imported lazily so optuna is only required when a trial is passed.
        if trial is not None:
            import optuna
            trial.report(val_loss, epoch)
            if trial.should_prune():
                print(f"\n✂ Optuna pruned this trial at epoch {epoch+1}.")
                raise optuna.TrialPruned()

        # --- Checkpoint best model ------------------------------------
        # NOTE(review): if val_loss is NaN, ``nan < best_val`` is always
        # False, so best_model.pt is never written and the final
        # load_state_dict below will fail — consider guarding.
        if val_loss < best_val:
            best_val = val_loss
            patience_ctr = 0
            torch.save(model.state_dict(), ckpt_dir / "best_model.pt")
            print(f" ↳ New best val loss — checkpoint saved.")
        else:
            patience_ctr += 1
            if patience_ctr >= cfg.patience:
                print(f"\n⏹ Early stopping after {cfg.patience} epochs without improvement.")
                # NOTE(review): this break skips the resume_state save below,
                # so resume_state.pt stays one epoch behind on early stop.
                break

        # --- Save resumable state after every epoch --------------------
        resume_state = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "scaler_state_dict": scaler.state_dict() if scaler is not None else None,
            "best_val_loss": best_val,
            "patience_ctr": patience_ctr,
            "history": history,
            "cfg_epochs": cfg.epochs,
        }
        torch.save(resume_state, ckpt_dir / "resume_state.pt")

        # --- Epoch callback (e.g. live plotting) ----------------------
        if epoch_callback is not None:
            epoch_callback(epoch, history)

    # Load best checkpoint (weights only — tensors, no arbitrary objects)
    model.load_state_dict(torch.load(ckpt_dir / "best_model.pt", map_location=device, weights_only=True))
    print(f"\n✓ Training complete. Best val loss: {best_val:.4f}")
    return history