"""
Training loop for the Transformer translator.
===============================================
Provides:
• ``TranslationDataset`` – a PyTorch Dataset that tokenises and pads
source/target sentence pairs.
• ``create_dataloaders`` – builds train / validation DataLoaders with
  a 90/10 split.
• ``train_one_epoch`` – one full pass over the training set.
• ``evaluate_loss`` – average loss on the validation set.
• ``train`` – full training driver with logging, LR
scheduling, checkpointing, and early stopping.
Design choices:
• Label-smoothed cross-entropy (smoothing = 0.1) for better
generalisation.
• AdamW with a linear-warmup + cosine-decay schedule (stable for
small datasets).
• Mixed-precision (AMP) with ``torch.amp`` for speed / memory on T4.
• Gradient clipping at max_norm = 1.0 to avoid exploding gradients.
"""
from __future__ import annotations
import math
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from tokenizers import Tokenizer
# ──────────────────────────────────────────────────────────────────────
# 1. Translation Dataset
# ──────────────────────────────────────────────────────────────────────
class TranslationDataset(Dataset):
"""
Wraps a HuggingFace dataset of translation pairs into a PyTorch
Dataset that returns padded token-ID tensors.
Each ``__getitem__`` returns::
{
"src": LongTensor[max_len], # source token IDs (padded)
"tgt": LongTensor[max_len], # target input (with [BOS], no final [EOS])
"label": LongTensor[max_len], # target labels (no [BOS], with [EOS])
}
The *tgt* / *label* split implements **teacher forcing**: the decoder
receives ``[BOS] w1 w2 …`` and must predict ``w1 w2 … [EOS]``.
"""
def __init__(
self,
hf_dataset,
src_tokenizer: Tokenizer,
tgt_tokenizer: Tokenizer,
src_lang: str = "en",
tgt_lang: str = "ms",
max_len: int = 128,
pad_id: int = 0,
):
self.data = hf_dataset
self.src_tok = src_tokenizer
self.tgt_tok = tgt_tokenizer
self.src_lang = src_lang
self.tgt_lang = tgt_lang
self.max_len = max_len
self.pad_id = pad_id
def __len__(self) -> int:
return len(self.data)
def _pad(self, ids: List[int]) -> List[int]:
"""Truncate to max_len, then right-pad with pad_id."""
ids = ids[: self.max_len]
return ids + [self.pad_id] * (self.max_len - len(ids))
def __getitem__(self, idx: int) -> dict:
pair = self.data[idx]["translation"]
# Encode (includes [BOS] … [EOS] from post-processor)
src_ids = self.src_tok.encode(pair[self.src_lang]).ids
tgt_ids = self.tgt_tok.encode(pair[self.tgt_lang]).ids
# Teacher-forcing split:
# tgt_input = [BOS] w1 w2 … wN (drop last token)
# tgt_label = w1 w2 … wN [EOS] (drop first token)
tgt_input = tgt_ids[:-1]
tgt_label = tgt_ids[1:]
return {
"src": torch.tensor(self._pad(src_ids), dtype=torch.long),
"tgt": torch.tensor(self._pad(tgt_input), dtype=torch.long),
"label": torch.tensor(self._pad(tgt_label), dtype=torch.long),
}
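# A minimal usage sketch (hypothetical names: ``raw_ds`` is any HuggingFace
# dataset with a "translation" column; ``en_tok`` / ``ms_tok`` are trained
# tokenizers whose post-processors add [BOS]/[EOS], as assumed above):
#
#     ds = TranslationDataset(raw_ds, en_tok, ms_tok, max_len=128)
#     item = ds[0]
#     item["src"].shape   # torch.Size([128])
#     item["tgt"][0]      # [BOS] id: the decoder input starts with BOS
#     item["label"][-1]   # pad id, unless the sentence fills max_len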
# ──────────────────────────────────────────────────────────────────────
# 2. DataLoader factory
# ──────────────────────────────────────────────────────────────────────
def create_dataloaders(
hf_dataset,
src_tokenizer: Tokenizer,
tgt_tokenizer: Tokenizer,
src_lang: str = "en",
tgt_lang: str = "ms",
max_len: int = 128,
batch_size: int = 32,
val_ratio: float = 0.1,
pad_id: int = 0,
seed: int = 42,
) -> Tuple[DataLoader, DataLoader, TranslationDataset]:
"""
Build training and validation DataLoaders from a HuggingFace dataset.
Returns
-------
train_loader, val_loader, full_dataset
"""
full_ds = TranslationDataset(
hf_dataset, src_tokenizer, tgt_tokenizer,
src_lang, tgt_lang, max_len, pad_id,
)
val_size = max(1, int(len(full_ds) * val_ratio))
train_size = len(full_ds) - val_size
generator = torch.Generator().manual_seed(seed)
train_ds, val_ds = random_split(full_ds, [train_size, val_size], generator=generator)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=False)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, drop_last=False)
print(f"Train: {train_size} | Val: {val_size} | Batch size: {batch_size}")
return train_loader, val_loader, full_ds
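# Example call (same hypothetical ``raw_ds`` / tokenizer names as above;
# ``pad_id`` must match the tokenizers' [PAD] id and the model's ``pad_idx``):
#
#     train_loader, val_loader, full_ds = create_dataloaders(
#         raw_ds, en_tok, ms_tok, batch_size=64, max_len=128, pad_id=0,
#     )
#     batch = next(iter(train_loader))
#     batch["src"].shape   # torch.Size([64, 128])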
# ──────────────────────────────────────────────────────────────────────
# 3. Training configuration dataclass
# ──────────────────────────────────────────────────────────────────────
@dataclass
class TrainConfig:
"""All tuneable knobs in one place."""
epochs: int = 50
batch_size: int = 32
max_len: int = 128
lr: float = 5e-4
warmup_steps: int = 200
label_smoothing: float = 0.1
grad_clip: float = 1.0
use_amp: bool = True
val_ratio: float = 0.1
checkpoint_dir: str = "training/checkpoints"
log_every: int = 10 # print loss every N steps
patience: int = 10 # early-stopping patience (epochs)
seed: int = 42
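# Every field has a default, so a config is just the overrides you care
# about, e.g. ``cfg = TrainConfig(epochs=30, lr=3e-4, batch_size=64)``.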
# ──────────────────────────────────────────────────────────────────────
# 4. LR scheduler with linear warmup + cosine decay
# ──────────────────────────────────────────────────────────────────────
def _build_scheduler(optimizer, warmup_steps: int, total_steps: int):
"""Linear warmup for `warmup_steps`, then cosine decay to 0."""
def lr_lambda(step):
if step < warmup_steps:
return step / max(1, warmup_steps)
progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
return 0.5 * (1.0 + math.cos(math.pi * progress))
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
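# Sanity check of the schedule shape (illustrative numbers): with
# warmup_steps=200 and total_steps=1000, the LR multiplier ramps linearly
# 0 -> 1 over the first 200 steps, then follows a half-cosine 1 -> 0:
#
#     lr_lambda(100)  = 100 / 200              = 0.50  (mid-warmup)
#     lr_lambda(600)  = 0.5 * (1 + cos(pi/2))  = 0.50  (mid-decay)
#     lr_lambda(1000) = 0.5 * (1 + cos(pi))    = 0.00  (end of training)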
# ──────────────────────────────────────────────────────────────────────
# 5. Single-epoch training
# ──────────────────────────────────────────────────────────────────────
def train_one_epoch(
model: nn.Module,
loader: DataLoader,
optimizer: torch.optim.Optimizer,
scheduler,
criterion: nn.Module,
device: torch.device,
scaler: Optional[torch.amp.GradScaler],
grad_clip: float = 1.0,
log_every: int = 10,
epoch: int = 0,
) -> float:
"""Train for one epoch. Returns average loss."""
model.train()
total_loss = 0.0
n_tokens = 0
for step, batch in enumerate(loader):
src = batch["src"].to(device)
tgt = batch["tgt"].to(device)
label = batch["label"].to(device)
optimizer.zero_grad()
amp_enabled = scaler is not None
        with torch.amp.autocast(device.type, enabled=amp_enabled):
logits = model(src, tgt) # (B, T, V)
loss = criterion(logits.reshape(-1, logits.size(-1)), label.reshape(-1))
if scaler is not None:
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
optimizer.step()
scheduler.step()
# Accumulate loss (ignore padding contribution)
non_pad = (label != model.pad_idx).sum().item()
total_loss += loss.item() * non_pad
n_tokens += non_pad
if (step + 1) % log_every == 0:
avg = total_loss / max(n_tokens, 1)
lr = scheduler.get_last_lr()[0]
print(f" Epoch {epoch+1} | Step {step+1}/{len(loader)} | Loss {avg:.4f} | LR {lr:.2e}")
return total_loss / max(n_tokens, 1)
# ──────────────────────────────────────────────────────────────────────
# 6. Validation loss
# ──────────────────────────────────────────────────────────────────────
@torch.no_grad()
def evaluate_loss(
model: nn.Module,
loader: DataLoader,
criterion: nn.Module,
device: torch.device,
use_amp: bool = False,
) -> float:
"""Compute average loss over a validation set (with AMP to match training)."""
model.eval()
total_loss = 0.0
n_tokens = 0
n_batches = len(loader)
for step, batch in enumerate(loader):
src = batch["src"].to(device)
tgt = batch["tgt"].to(device)
label = batch["label"].to(device)
        with torch.amp.autocast(device.type, enabled=use_amp):
logits = model(src, tgt)
loss = criterion(logits.reshape(-1, logits.size(-1)), label.reshape(-1))
non_pad = (label != model.pad_idx).sum().item()
total_loss += loss.item() * non_pad
n_tokens += non_pad
if (step + 1) % max(1, n_batches // 4) == 0 or (step + 1) == n_batches:
print(f" Val {step+1}/{n_batches}", end="\r")
return total_loss / max(n_tokens, 1)
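# Note: the returned value is an average per-token cross-entropy, so
# ``math.exp(val_loss)`` approximates validation perplexity (only
# approximately here, since label smoothing inflates the raw loss).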
# ──────────────────────────────────────────────────────────────────────
# 7. Full training driver
# ──────────────────────────────────────────────────────────────────────
def train(
model: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
cfg: TrainConfig,
device: torch.device,
trial=None,
resume_from: Optional[str] = None,
epoch_callback=None,
) -> dict:
"""
Full training loop with logging, checkpointing, and early stopping.
Parameters
----------
trial : optuna.trial.Trial, optional
If provided, reports val_loss after each epoch for ASHA pruning.
resume_from : str, optional
Path to a ``resume_state.pt`` file. If provided, training resumes
from the saved epoch with the exact optimizer / scheduler / scaler
state, history, and early-stopping counters.
epoch_callback : callable, optional
Called after every epoch as ``epoch_callback(epoch, history)``.
Useful for live plotting in notebooks.
Returns
-------
history : dict
``{"train_loss": [...], "val_loss": [...], "lr": [...]}``
"""
# --- Loss function (label-smoothed CE, ignoring PAD) ---------------
criterion = nn.CrossEntropyLoss(
ignore_index=model.pad_idx,
label_smoothing=cfg.label_smoothing,
)
# --- Optimiser ------------------------------------------------------
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, betas=(0.9, 0.98), eps=1e-9)
# --- LR schedule ---------------------------------------------------
total_steps = cfg.epochs * len(train_loader)
scheduler = _build_scheduler(optimizer, cfg.warmup_steps, total_steps)
# --- AMP scaler ----------------------------------------------------
scaler = torch.amp.GradScaler("cuda") if (cfg.use_amp and device.type == "cuda") else None
# --- Checkpoint dir ------------------------------------------------
ckpt_dir = Path(cfg.checkpoint_dir)
ckpt_dir.mkdir(parents=True, exist_ok=True)
history: dict = {"train_loss": [], "val_loss": [], "lr": []}
best_val = float("inf")
patience_ctr = 0
start_epoch = 0
# --- Resume from checkpoint ----------------------------------------
if resume_from is not None and os.path.exists(resume_from):
print(f"\n🔄 Resuming from {resume_from}")
ckpt = torch.load(resume_from, map_location=device, weights_only=False)
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
if scaler is not None and "scaler_state_dict" in ckpt:
scaler.load_state_dict(ckpt["scaler_state_dict"])
start_epoch = ckpt["epoch"] + 1 # resume from *next* epoch
best_val = ckpt["best_val_loss"]
patience_ctr = ckpt["patience_ctr"]
history = ckpt["history"]
print(f" Resumed at epoch {start_epoch+1}/{cfg.epochs} | "
f"best_val={best_val:.4f} | patience={patience_ctr}/{cfg.patience}")
print(f"\n{'='*60}")
print(f"Starting training: {cfg.epochs} epochs (from epoch {start_epoch+1}), lr={cfg.lr}, AMP={cfg.use_amp}")
print(f"{'='*60}\n")
for epoch in range(start_epoch, cfg.epochs):
t0 = time.time()
train_loss = train_one_epoch(
model, train_loader, optimizer, scheduler, criterion,
device, scaler, cfg.grad_clip, cfg.log_every, epoch,
)
use_amp = cfg.use_amp and device.type == "cuda"
val_loss = evaluate_loss(model, val_loader, criterion, device, use_amp=use_amp)
lr = scheduler.get_last_lr()[0]
elapsed = time.time() - t0
history["train_loss"].append(train_loss)
history["val_loss"].append(val_loss)
history["lr"].append(lr)
print(
f"Epoch {epoch+1}/{cfg.epochs} | "
f"Train {train_loss:.4f} | Val {val_loss:.4f} | "
f"LR {lr:.2e} | {elapsed:.1f}s"
)
# --- Optuna ASHA pruning (if trial provided) ------------------
if trial is not None:
import optuna
trial.report(val_loss, epoch)
if trial.should_prune():
print(f"\n✂ Optuna pruned this trial at epoch {epoch+1}.")
raise optuna.TrialPruned()
# --- Checkpoint best model ------------------------------------
if val_loss < best_val:
best_val = val_loss
patience_ctr = 0
torch.save(model.state_dict(), ckpt_dir / "best_model.pt")
print(f" ↳ New best val loss — checkpoint saved.")
else:
patience_ctr += 1
if patience_ctr >= cfg.patience:
print(f"\n⏹ Early stopping after {cfg.patience} epochs without improvement.")
break
# --- Save resumable state after every epoch --------------------
resume_state = {
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"scaler_state_dict": scaler.state_dict() if scaler is not None else None,
"best_val_loss": best_val,
"patience_ctr": patience_ctr,
"history": history,
"cfg_epochs": cfg.epochs,
}
torch.save(resume_state, ckpt_dir / "resume_state.pt")
# --- Epoch callback (e.g. live plotting) ----------------------
if epoch_callback is not None:
epoch_callback(epoch, history)
# Load best checkpoint
model.load_state_dict(torch.load(ckpt_dir / "best_model.pt", map_location=device, weights_only=True))
print(f"\n✓ Training complete. Best val loss: {best_val:.4f}")
return history
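# End-to-end sketch (hypothetical names: ``Seq2SeqTransformer`` stands in
# for whichever model class this repo defines; it must expose ``pad_idx``
# and a ``forward(src, tgt) -> (B, T, V)`` logits signature, as assumed by
# ``train_one_epoch`` / ``evaluate_loss`` above):
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     cfg = TrainConfig(epochs=50, batch_size=32)
#     train_loader, val_loader, _ = create_dataloaders(
#         raw_ds, en_tok, ms_tok,
#         batch_size=cfg.batch_size, max_len=cfg.max_len,
#     )
#     model = Seq2SeqTransformer(...).to(device)   # hypothetical constructor
#
#     def plot_cb(epoch, history):                 # optional live feedback
#         print(f"epoch {epoch+1}: val={history['val_loss'][-1]:.4f}")
#
#     history = train(model, train_loader, val_loader, cfg, device,
#                     epoch_callback=plot_cb)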