"""Training infrastructure for standalone WrinkleBrane model. Provides training loops, evaluation, and model comparison utilities shared across all three training tasks. Key components -------------- ``train_step`` Single optimisation step with orthogonality regularisation. ``train_loop`` Multi-step training loop with logging. ``evaluate`` Evaluation on held-out data. ``compare_models`` Side-by-side WrinkleBrane vs transformer training comparison. """ from __future__ import annotations import time from typing import Dict, List, Optional, Tuple import torch from torch import nn, Tensor from wrinklebrane.standalone_model import WrinkleBraneModel, WrinkleBraneConfig from wrinklebrane.baseline_transformer import SmallTransformer, SmallTransformerConfig from wrinklebrane.tasks import compute_accuracy # --------------------------------------------------------------------------- # Training step # --------------------------------------------------------------------------- def train_step( model: nn.Module, input_ids: Tensor, target_ids: Tensor, optimizer: torch.optim.Optimizer, ortho_lambda: float = 0.0, ignore_index: int = -100, ) -> Dict[str, float]: """Single training step. Parameters ---------- model : nn.Module WrinkleBraneModel or SmallTransformer. input_ids : Tensor ``[B, T]`` target_ids : Tensor ``[B, T]`` optimizer : Optimizer ortho_lambda : float Orthogonality regularisation weight (0 for transformer). ignore_index : int Cross-entropy ignore index. Returns ------- dict ``task_loss``, ``ortho_loss``, ``total_loss``, ``accuracy``. """ model.train() optimizer.zero_grad() logits = model(input_ids) # [B, T, V] # Cross-entropy loss B, T, V = logits.shape task_loss = nn.functional.cross_entropy( logits.reshape(B * T, V), target_ids.reshape(B * T), ignore_index=ignore_index, ) # Orthogonality regularisation (WrinkleBrane only) ortho = torch.tensor(0.0, device=task_loss.device) if ortho_lambda > 0 and hasattr(model, "ortho_loss"): ortho = model.ortho_loss() total_loss = task_loss + ortho_lambda * ortho total_loss.backward() # Gradient clipping torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() with torch.no_grad(): acc = compute_accuracy(logits.detach(), target_ids, ignore_index) return { "task_loss": float(task_loss.detach()), "ortho_loss": float(ortho.detach()), "total_loss": float(total_loss.detach()), "accuracy": acc, } # --------------------------------------------------------------------------- # Training loop # --------------------------------------------------------------------------- def train_loop( model: nn.Module, task, *, n_steps: int = 500, batch_size: int = 32, lr: float = 3e-4, ortho_lambda: float = 0.0, log_every: int = 50, device: str = "cpu", ignore_index: int = -100, ) -> List[Dict[str, float]]: """Train a model on a task for ``n_steps``. Parameters ---------- model : nn.Module task : SequenceCopyTask, AssociativeRecallTask, or SyntheticGrammarTask n_steps : int batch_size : int lr : float ortho_lambda : float log_every : int device : str ignore_index : int Returns ------- list of dict Per-step metrics (logged at ``log_every`` intervals). """ model = model.to(device) optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01) # Learning rate schedule: linear warmup + cosine decay warmup_steps = min(n_steps // 10, 100) def lr_lambda(step): if step < warmup_steps: return (step + 1) / warmup_steps progress = (step - warmup_steps) / max(1, n_steps - warmup_steps) return 0.5 * (1.0 + __import__("math").cos(__import__("math").pi * progress)) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) history = [] t0 = time.time() for step in range(n_steps): input_ids, target_ids = task.generate_batch(batch_size) input_ids = input_ids.to(device) target_ids = target_ids.to(device) metrics = train_step( model, input_ids, target_ids, optimizer, ortho_lambda=ortho_lambda, ignore_index=ignore_index, ) metrics["step"] = step metrics["lr"] = optimizer.param_groups[0]["lr"] scheduler.step() if step % log_every == 0 or step == n_steps - 1: elapsed = time.time() - t0 metrics["elapsed_s"] = elapsed history.append(metrics) return history # --------------------------------------------------------------------------- # Evaluation # --------------------------------------------------------------------------- @torch.no_grad() def evaluate( model: nn.Module, task, *, n_batches: int = 10, batch_size: int = 32, device: str = "cpu", ignore_index: int = -100, ) -> Dict[str, float]: """Evaluate a model on a task. Returns ------- dict ``loss``, ``accuracy``, ``perplexity``. """ model.eval() model = model.to(device) total_loss = 0.0 total_correct = 0 total_counted = 0 for _ in range(n_batches): input_ids, target_ids = task.generate_batch(batch_size) input_ids = input_ids.to(device) target_ids = target_ids.to(device) logits = model(input_ids) B, T, V = logits.shape loss = nn.functional.cross_entropy( logits.reshape(B * T, V), target_ids.reshape(B * T), ignore_index=ignore_index, ) total_loss += float(loss) * B # Accuracy preds = logits.argmax(dim=-1) mask = target_ids != ignore_index total_correct += int(((preds == target_ids) & mask).sum()) total_counted += int(mask.sum()) avg_loss = total_loss / (n_batches * batch_size) accuracy = total_correct / max(total_counted, 1) perplexity = min(__import__("math").exp(avg_loss), 1e6) return { "loss": avg_loss, "accuracy": accuracy, "perplexity": perplexity, } # --------------------------------------------------------------------------- # Model comparison # --------------------------------------------------------------------------- def compare_models( task, *, wb_config: Optional[WrinkleBraneConfig] = None, tf_config: Optional[SmallTransformerConfig] = None, n_steps: int = 500, batch_size: int = 32, lr: float = 3e-4, log_every: int = 50, device: str = "cpu", ignore_index: int = -100, ) -> Dict[str, object]: """Train both models side-by-side on the same task. Returns ------- dict ``wb_history``, ``tf_history``, ``wb_eval``, ``tf_eval``, ``wb_params``, ``tf_params``. """ if wb_config is None: wb_config = WrinkleBraneConfig() if tf_config is None: tf_config = SmallTransformerConfig( vocab_size=wb_config.vocab_size, d_model=wb_config.d_model, max_seq_len=wb_config.max_seq_len, n_layers=wb_config.n_layers, n_heads=wb_config.n_heads, ffn_expansion=wb_config.ffn_expansion, dropout=wb_config.dropout, weight_tying=wb_config.weight_tying, ) wb_model = WrinkleBraneModel(wb_config) tf_model = SmallTransformer(tf_config) wb_params = wb_model.count_parameters() tf_params = tf_model.count_parameters() # Train WrinkleBrane wb_history = train_loop( wb_model, task, n_steps=n_steps, batch_size=batch_size, lr=lr, ortho_lambda=wb_config.ortho_lambda, log_every=log_every, device=device, ignore_index=ignore_index, ) # Train transformer tf_history = train_loop( tf_model, task, n_steps=n_steps, batch_size=batch_size, lr=lr, ortho_lambda=0.0, log_every=log_every, device=device, ignore_index=ignore_index, ) # Evaluate both wb_eval = evaluate( wb_model, task, device=device, ignore_index=ignore_index, ) tf_eval = evaluate( tf_model, task, device=device, ignore_index=ignore_index, ) return { "wb_history": wb_history, "tf_history": tf_history, "wb_eval": wb_eval, "tf_eval": tf_eval, "wb_params": wb_params, "tf_params": tf_params, }