"""
Training loop for the Transformer translator.
===============================================

Provides:

• ``TranslationDataset`` – a PyTorch Dataset that tokenises and pads
  source/target sentence pairs.
• ``create_dataloaders`` – builds train / validation DataLoaders with
  a 90/10 split.
• ``train_one_epoch`` – one full pass over the training set.
• ``evaluate_loss`` – average loss on the validation set.
• ``train`` – full training driver with logging, LR scheduling,
  checkpointing, and early stopping.

Design choices:

• Label-smoothed cross-entropy (smoothing = 0.1) for better
  generalisation.
• AdamW with a linear-warmup + cosine-decay schedule (stable for
  small datasets).
• Mixed-precision (AMP) with ``torch.amp`` for speed / memory on T4.
• Gradient clipping at max_norm = 1.0 to avoid exploding gradients.
"""

from __future__ import annotations

import math
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from tokenizers import Tokenizer


class TranslationDataset(Dataset):
    """
    Wraps a HuggingFace dataset of translation pairs into a PyTorch
    Dataset that returns padded token-ID tensors.

    Each ``__getitem__`` returns::

        {
            "src":   LongTensor[max_len],  # source token IDs (padded)
            "tgt":   LongTensor[max_len],  # target input (with [BOS], no final [EOS])
            "label": LongTensor[max_len],  # target labels (no [BOS], with [EOS])
        }

    The *tgt* / *label* split implements **teacher forcing**: the decoder
    receives ``[BOS] w1 w2 …`` and must predict ``w1 w2 … [EOS]``.
    """

    def __init__(
        self,
        hf_dataset,
        src_tokenizer: Tokenizer,
        tgt_tokenizer: Tokenizer,
        src_lang: str = "en",
        tgt_lang: str = "ms",
        max_len: int = 128,
        pad_id: int = 0,
    ):
        self.data = hf_dataset
        self.src_tok = src_tokenizer
        self.tgt_tok = tgt_tokenizer
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.max_len = max_len
        self.pad_id = pad_id

    def __len__(self) -> int:
        return len(self.data)

    def _pad(self, ids: List[int]) -> List[int]:
        """Truncate to max_len, then right-pad with pad_id."""
        ids = ids[: self.max_len]
        return ids + [self.pad_id] * (self.max_len - len(ids))

    def __getitem__(self, idx: int) -> dict:
        pair = self.data[idx]["translation"]

        src_ids = self.src_tok.encode(pair[self.src_lang]).ids
        tgt_ids = self.tgt_tok.encode(pair[self.tgt_lang]).ids

        # Teacher-forcing shift: the decoder input keeps [BOS] and drops the
        # final token; the labels drop [BOS], so position t predicts token t+1.
        tgt_input = tgt_ids[:-1]
        tgt_label = tgt_ids[1:]

        return {
            "src": torch.tensor(self._pad(src_ids), dtype=torch.long),
            "tgt": torch.tensor(self._pad(tgt_input), dtype=torch.long),
            "label": torch.tensor(self._pad(tgt_label), dtype=torch.long),
        }


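# A concrete view of the teacher-forcing shift above (hypothetical token IDs;
# actual values depend on the trained tokenizer, assuming it wraps sentences
# as [BOS] ... [EOS]):
#
#     tgt_ids   = [1, 57, 802, 33, 2]   # [BOS] w1 w2 w3 [EOS]
#     tgt_input = tgt_ids[:-1]          # [1, 57, 802, 33]  → decoder input
#     tgt_label = tgt_ids[1:]           # [57, 802, 33, 2]  → prediction targets
#
# At position t the decoder sees tgt_input[: t + 1] and is scored on tgt_label[t].

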
def create_dataloaders(
    hf_dataset,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    src_lang: str = "en",
    tgt_lang: str = "ms",
    max_len: int = 128,
    batch_size: int = 32,
    val_ratio: float = 0.1,
    pad_id: int = 0,
    seed: int = 42,
) -> Tuple[DataLoader, DataLoader, TranslationDataset]:
    """
    Build training and validation DataLoaders from a HuggingFace dataset.

    Returns
    -------
    train_loader, val_loader, full_dataset
    """
    full_ds = TranslationDataset(
        hf_dataset, src_tokenizer, tgt_tokenizer,
        src_lang, tgt_lang, max_len, pad_id,
    )

    val_size = max(1, int(len(full_ds) * val_ratio))
    train_size = len(full_ds) - val_size

    generator = torch.Generator().manual_seed(seed)
    train_ds, val_ds = random_split(full_ds, [train_size, val_size], generator=generator)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=False)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, drop_last=False)

    print(f"Train: {train_size} | Val: {val_size} | Batch size: {batch_size}")
    return train_loader, val_loader, full_ds


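# Sketch of typical usage. Hedged: the dataset and tokenizer files are
# assumptions; any HuggingFace `datasets` split whose rows look like
# {"translation": {"en": ..., "ms": ...}} works, with tokenizers trained
# and saved elsewhere in the project.
#
#     from datasets import load_dataset
#     from tokenizers import Tokenizer
#
#     raw_dataset = load_dataset("...", split="train")    # dataset name elided
#     src_tok = Tokenizer.from_file("tokenizer_en.json")  # assumed filename
#     tgt_tok = Tokenizer.from_file("tokenizer_ms.json")  # assumed filename
#     train_loader, val_loader, full_ds = create_dataloaders(
#         raw_dataset, src_tok, tgt_tok, batch_size=32, max_len=128,
#     )

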
@dataclass
class TrainConfig:
    """All tuneable knobs in one place."""
    epochs: int = 50
    batch_size: int = 32
    max_len: int = 128
    lr: float = 5e-4
    warmup_steps: int = 200
    label_smoothing: float = 0.1
    grad_clip: float = 1.0
    use_amp: bool = True
    val_ratio: float = 0.1
    checkpoint_dir: str = "training/checkpoints"
    log_every: int = 10
    patience: int = 10
    seed: int = 42


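# TrainConfig is a plain dataclass, so a run overrides only what it needs:
#
#     cfg = TrainConfig(epochs=20, lr=3e-4)   # everything else keeps the defaults

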
def _build_scheduler(optimizer, warmup_steps: int, total_steps: int):
    """Linear warmup for `warmup_steps`, then cosine decay to 0."""

    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


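# Worked values of lr_lambda for the default warmup_steps = 200, with
# total_steps = 1000 chosen purely for illustration:
#
#     step    0 → multiplier 0.00   (warmup start)
#     step  100 → multiplier 0.50   (halfway through linear warmup)
#     step  200 → multiplier 1.00   (warmup done, cosine decay begins)
#     step  600 → multiplier 0.50   (0.5 * (1 + cos(π/2)))
#     step 1000 → multiplier 0.00   (0.5 * (1 + cos(π)))

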
def train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler,
    criterion: nn.Module,
    device: torch.device,
    scaler: Optional[torch.amp.GradScaler],
    grad_clip: float = 1.0,
    log_every: int = 10,
    epoch: int = 0,
) -> float:
    """Train for one epoch. Returns average loss."""
    model.train()
    total_loss = 0.0
    n_tokens = 0

    for step, batch in enumerate(loader):
        src = batch["src"].to(device)
        tgt = batch["tgt"].to(device)
        label = batch["label"].to(device)

        optimizer.zero_grad()

        amp_enabled = scaler is not None
        with torch.amp.autocast("cuda", enabled=amp_enabled):
            logits = model(src, tgt)
            loss = criterion(logits.reshape(-1, logits.size(-1)), label.reshape(-1))

        if scaler is not None:
            # Unscale before clipping so grad_clip applies in true gradient
            # units rather than loss-scaled ones.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        # One LR-scheduler step per optimizer step (the schedule is defined
        # in steps, not epochs).
        scheduler.step()

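        # The running totals below weight each batch's mean loss by its
        # non-pad token count. Illustrative numbers: mean losses of 2.0 over
        # 300 tokens and 4.0 over 100 tokens combine to
        #     (2.0 * 300 + 4.0 * 100) / 400 = 2.5
        # rather than the naive (2.0 + 4.0) / 2 = 3.0.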
        non_pad = (label != model.pad_idx).sum().item()
        total_loss += loss.item() * non_pad
        n_tokens += non_pad

        if (step + 1) % log_every == 0:
            avg = total_loss / max(n_tokens, 1)
            lr = scheduler.get_last_lr()[0]
            print(f"  Epoch {epoch+1} | Step {step+1}/{len(loader)} | Loss {avg:.4f} | LR {lr:.2e}")

    return total_loss / max(n_tokens, 1)


@torch.no_grad()
def evaluate_loss(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
    use_amp: bool = False,
) -> float:
    """Compute average loss over a validation set (with AMP to match training)."""
    model.eval()
    total_loss = 0.0
    n_tokens = 0
    n_batches = len(loader)

    for step, batch in enumerate(loader):
        src = batch["src"].to(device)
        tgt = batch["tgt"].to(device)
        label = batch["label"].to(device)

        with torch.amp.autocast("cuda", enabled=use_amp):
            logits = model(src, tgt)
            loss = criterion(logits.reshape(-1, logits.size(-1)), label.reshape(-1))

        non_pad = (label != model.pad_idx).sum().item()
        total_loss += loss.item() * non_pad
        n_tokens += non_pad

        if (step + 1) % max(1, n_batches // 4) == 0 or (step + 1) == n_batches:
            print(f"  Val {step+1}/{n_batches}", end="\r")

    return total_loss / max(n_tokens, 1)


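# The value returned by evaluate_loss is an average cross-entropy per non-pad
# token, so it converts directly to an (approximate) perplexity; approximate
# because the criterion also applies label smoothing:
#
#     ppl = math.exp(val_loss)    # e.g. a val loss of 3.0 → perplexity ≈ 20.1

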
def train(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    cfg: TrainConfig,
    device: torch.device,
    trial=None,
    resume_from: Optional[str] = None,
    epoch_callback=None,
) -> dict:
    """
    Full training loop with logging, checkpointing, and early stopping.

    Parameters
    ----------
    trial : optuna.trial.Trial, optional
        If provided, reports val_loss after each epoch for ASHA pruning.
    resume_from : str, optional
        Path to a ``resume_state.pt`` file. If provided, training resumes
        from the saved epoch with the exact optimizer / scheduler / scaler
        state, history, and early-stopping counters.
    epoch_callback : callable, optional
        Called after every epoch as ``epoch_callback(epoch, history)``.
        Useful for live plotting in notebooks.

    Returns
    -------
    history : dict
        ``{"train_loss": [...], "val_loss": [...], "lr": [...]}``
    """
    criterion = nn.CrossEntropyLoss(
        ignore_index=model.pad_idx,
        label_smoothing=cfg.label_smoothing,
    )
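
    # How label smoothing reshapes the targets (standard definition, noted
    # here for reference): with smoothing ε and vocabulary size V, the
    # one-hot target becomes
    #
    #     q(gold) = 1 - ε + ε/V
    #     q(k)    = ε/V        for every other token k
    #
    # so the model is never pushed toward probability 1.0 on a single token.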

    # AdamW with the transformer-standard betas/eps from "Attention Is All
    # You Need" (β1 = 0.9, β2 = 0.98, eps = 1e-9).
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, betas=(0.9, 0.98), eps=1e-9)

    # The schedule is defined over optimizer steps, not epochs.
    total_steps = cfg.epochs * len(train_loader)
    scheduler = _build_scheduler(optimizer, cfg.warmup_steps, total_steps)

    # AMP loss scaling is only used when training on CUDA.
    scaler = torch.amp.GradScaler("cuda") if (cfg.use_amp and device.type == "cuda") else None

    ckpt_dir = Path(cfg.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    history: dict = {"train_loss": [], "val_loss": [], "lr": []}
    best_val = float("inf")
    patience_ctr = 0
    start_epoch = 0

    # Optionally restore a mid-run snapshot (saved at the end of every epoch
    # below) so an interrupted run continues with identical state.
    if resume_from is not None and os.path.exists(resume_from):
        print(f"\n🔄 Resuming from {resume_from}")
        ckpt = torch.load(resume_from, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        if scaler is not None and "scaler_state_dict" in ckpt:
            scaler.load_state_dict(ckpt["scaler_state_dict"])
        start_epoch = ckpt["epoch"] + 1
        best_val = ckpt["best_val_loss"]
        patience_ctr = ckpt["patience_ctr"]
        history = ckpt["history"]
        print(f"   Resumed at epoch {start_epoch+1}/{cfg.epochs} | "
              f"best_val={best_val:.4f} | patience={patience_ctr}/{cfg.patience}")

    print(f"\n{'='*60}")
    print(f"Starting training: {cfg.epochs} epochs (from epoch {start_epoch+1}), lr={cfg.lr}, AMP={cfg.use_amp}")
    print(f"{'='*60}\n")

    for epoch in range(start_epoch, cfg.epochs):
        t0 = time.time()

        train_loss = train_one_epoch(
            model, train_loader, optimizer, scheduler, criterion,
            device, scaler, cfg.grad_clip, cfg.log_every, epoch,
        )
        use_amp = cfg.use_amp and device.type == "cuda"
        val_loss = evaluate_loss(model, val_loader, criterion, device, use_amp=use_amp)
        lr = scheduler.get_last_lr()[0]

        elapsed = time.time() - t0
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["lr"].append(lr)

        print(
            f"Epoch {epoch+1}/{cfg.epochs} | "
            f"Train {train_loss:.4f} | Val {val_loss:.4f} | "
            f"LR {lr:.2e} | {elapsed:.1f}s"
        )

        # Report to Optuna so the pruner can stop unpromising trials early.
        if trial is not None:
            import optuna
            trial.report(val_loss, epoch)
            if trial.should_prune():
                print(f"\n✂ Optuna pruned this trial at epoch {epoch+1}.")
                raise optuna.TrialPruned()

        # Checkpoint on improvement; otherwise count toward early stopping.
        if val_loss < best_val:
            best_val = val_loss
            patience_ctr = 0
            torch.save(model.state_dict(), ckpt_dir / "best_model.pt")
            print("  ↳ New best val loss — checkpoint saved.")
        else:
            patience_ctr += 1
            if patience_ctr >= cfg.patience:
                print(f"\n⏹ Early stopping after {cfg.patience} epochs without improvement.")
                break

        # Snapshot everything needed to resume mid-run (see `resume_from`).
        resume_state = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "scaler_state_dict": scaler.state_dict() if scaler is not None else None,
            "best_val_loss": best_val,
            "patience_ctr": patience_ctr,
            "history": history,
            "cfg_epochs": cfg.epochs,
        }
        torch.save(resume_state, ckpt_dir / "resume_state.pt")

        if epoch_callback is not None:
            epoch_callback(epoch, history)

    # Restore the best weights before returning.
    model.load_state_dict(torch.load(ckpt_dir / "best_model.pt", map_location=device, weights_only=True))
    print(f"\n✓ Training complete. Best val loss: {best_val:.4f}")
    return history
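

# End-to-end usage sketch. Hedged: `build_model`, the raw dataset, and the
# tokenizers are assumptions standing in for the project's own setup code;
# the model must expose a `pad_idx` attribute, which the loops above rely on.
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     cfg = TrainConfig(epochs=50, batch_size=32)
#     train_loader, val_loader, _ = create_dataloaders(
#         raw_dataset, src_tok, tgt_tok, batch_size=cfg.batch_size, max_len=cfg.max_len,
#     )
#     model = build_model(...).to(device)   # hypothetical factory
#     history = train(model, train_loader, val_loader, cfg, device)
#
# To continue an interrupted run, point `resume_from` at the rolling snapshot:
#
#     history = train(model, train_loader, val_loader, cfg, device,
#                     resume_from="training/checkpoints/resume_state.pt")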