| """Training infrastructure for standalone WrinkleBrane model. |
| |
| Provides training loops, evaluation, and model comparison utilities |
| shared across all three training tasks. |
| |
| Key components |
| -------------- |
| ``train_step`` |
| Single optimisation step with orthogonality regularisation. |
| |
| ``train_loop`` |
| Multi-step training loop with logging. |
| |
| ``evaluate`` |
| Evaluation on held-out data. |
| |
| ``compare_models`` |
| Side-by-side WrinkleBrane vs transformer training comparison. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import time |
| from typing import Dict, List, Optional, Tuple |
|
|
| import torch |
| from torch import nn, Tensor |
|
|
| from wrinklebrane.standalone_model import WrinkleBraneModel, WrinkleBraneConfig |
| from wrinklebrane.baseline_transformer import SmallTransformer, SmallTransformerConfig |
| from wrinklebrane.tasks import compute_accuracy |
|
|
|
|
| |
| |
| |
|
|
def train_step(
    model: nn.Module,
    input_ids: Tensor,
    target_ids: Tensor,
    optimizer: torch.optim.Optimizer,
    ortho_lambda: float = 0.0,
    ignore_index: int = -100,
) -> Dict[str, float]:
    """Run one optimisation step and report its metrics.

    The model is switched to training mode, gradients are cleared, and the
    cross-entropy between ``model(input_ids)`` and ``target_ids`` is
    computed over the flattened batch and time dimensions.  When
    ``ortho_lambda`` is positive and the model has an ``ortho_loss``
    attribute, ``model.ortho_loss()`` is added with weight ``ortho_lambda``.
    Gradients are clipped to a total norm of 1.0 before the optimizer update.

    Parameters
    ----------
    model : nn.Module
        Called as ``model(input_ids)``; the result must have exactly three
        dimensions, which are unpacked as ``[B, T, V]``.
    input_ids : Tensor ``[B, T]``
    target_ids : Tensor ``[B, T]``
        Flattened to ``B * T`` entries for the loss.
    optimizer : Optimizer
        Its gradients are zeroed at the start and it is stepped once.
    ortho_lambda : float
        Weight of the orthogonality term; the term stays at zero unless this
        is positive and the model provides ``ortho_loss``.
    ignore_index : int
        Target value excluded from the cross-entropy; also forwarded to
        ``compute_accuracy``.

    Returns
    -------
    dict
        ``task_loss``, ``ortho_loss`` and ``total_loss`` as Python floats,
        plus ``accuracy`` exactly as returned by ``compute_accuracy``.
    """
    model.train()
    optimizer.zero_grad()

    logits = model(input_ids)
    batch, seq_len, vocab = logits.shape

    # Token-level loss: every (batch, time) position is one classification.
    flat_logits = logits.reshape(batch * seq_len, vocab)
    flat_targets = target_ids.reshape(batch * seq_len)
    task_loss = nn.functional.cross_entropy(
        flat_logits, flat_targets, ignore_index=ignore_index
    )

    # Only query the model for its regulariser when it will actually be used.
    if ortho_lambda > 0 and hasattr(model, "ortho_loss"):
        ortho_term = model.ortho_loss()
    else:
        ortho_term = torch.tensor(0.0, device=task_loss.device)

    total_loss = task_loss + ortho_lambda * ortho_term
    total_loss.backward()

    # Clip after backward and before the update so the step size is bounded.
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    # Accuracy is measured on the logits produced before this update.
    with torch.no_grad():
        accuracy = compute_accuracy(logits.detach(), target_ids, ignore_index)

    losses = {
        "task_loss": task_loss,
        "ortho_loss": ortho_term,
        "total_loss": total_loss,
    }
    report = {key: float(value.detach()) for key, value in losses.items()}
    report["accuracy"] = accuracy
    return report
|
|
|
|
| |
| |
| |
|
|
def train_loop(
    model: nn.Module,
    task,
    *,
    n_steps: int = 500,
    batch_size: int = 32,
    lr: float = 3e-4,
    ortho_lambda: float = 0.0,
    log_every: int = 50,
    device: str = "cpu",
    ignore_index: int = -100,
) -> List[Dict[str, float]]:
    """Train a model on a task for ``n_steps``.

    Uses AdamW (``weight_decay=0.01``) with a learning-rate multiplier that
    ramps up linearly during warmup and then follows a half-cosine decay.
    A fresh batch is drawn from ``task`` at every step.

    Parameters
    ----------
    model : nn.Module
        Moved to ``device`` before training.
    task : object
        Must provide ``generate_batch(batch_size)`` returning an
        ``(input_ids, target_ids)`` pair of tensors.
    n_steps : int
        Number of optimisation steps.
    batch_size : int
        Passed to ``task.generate_batch``.
    lr : float
        Base learning rate, scaled by the warmup/cosine multiplier.
    ortho_lambda : float
        Forwarded to ``train_step``.
    log_every : int
        A step's metrics are kept when ``step % log_every == 0`` or when it
        is the final step.
    device : str
        Device for the model and for every batch.
    ignore_index : int
        Forwarded to ``train_step``.

    Returns
    -------
    list of dict
        One entry per logged step: the ``train_step`` metrics plus ``step``,
        ``lr`` (the rate in effect for that step) and ``elapsed_s``
        (seconds since the loop started).
    """
    # Function-scope import: avoids resolving ``__import__("math")`` twice on
    # every scheduler step while leaving the file's import block untouched.
    import math

    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

    # Warmup lasts 10% of the run, capped at 100 steps.  For n_steps < 10 this
    # is 0, in which case the warmup branch below is never taken.
    warmup_steps = min(n_steps // 10, 100)
    # Loop-invariant denominator; max(1, ...) guards against division by zero.
    decay_steps = max(1, n_steps - warmup_steps)

    def lr_lambda(step: int) -> float:
        """Return the multiplier applied to ``lr`` at scheduler step ``step``."""
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        progress = (step - warmup_steps) / decay_steps
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    history: List[Dict[str, float]] = []
    t0 = time.time()

    for step in range(n_steps):
        input_ids, target_ids = task.generate_batch(batch_size)
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)

        metrics = train_step(
            model, input_ids, target_ids, optimizer,
            ortho_lambda=ortho_lambda,
            ignore_index=ignore_index,
        )
        metrics["step"] = step
        # Read the rate before advancing the scheduler so the logged value is
        # the one this step was actually taken with.
        metrics["lr"] = optimizer.param_groups[0]["lr"]
        scheduler.step()

        if step % log_every == 0 or step == n_steps - 1:
            metrics["elapsed_s"] = time.time() - t0
            history.append(metrics)

    return history
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def evaluate(
    model: nn.Module,
    task,
    *,
    n_batches: int = 10,
    batch_size: int = 32,
    device: str = "cpu",
    ignore_index: int = -100,
) -> Dict[str, float]:
    """Evaluate a model on freshly generated batches from a task.

    The model is put in evaluation mode (and left there) and moved to
    ``device``.  Gradients are disabled for the whole call.

    Parameters
    ----------
    model : nn.Module
        Called as ``model(input_ids)``; the result must have exactly three
        dimensions, which are unpacked as ``[B, T, V]``.
    task : object
        Must provide ``generate_batch(batch_size)`` returning an
        ``(input_ids, target_ids)`` pair of tensors.
    n_batches : int
        Number of batches to draw.
    batch_size : int
        Passed to ``task.generate_batch``.
    device : str
        Device for the model and for every batch.
    ignore_index : int
        Target value excluded from both the loss and the accuracy.

    Returns
    -------
    dict
        ``loss``
            Per-batch mean cross-entropy, averaged with each batch weighted
            by the number of sequences it actually contained.  NaN when no
            sequences were evaluated (e.g. ``n_batches <= 0``).
        ``accuracy``
            Fraction of non-ignored positions whose argmax prediction equals
            the target; 0.0 when there are no such positions.
        ``perplexity``
            ``exp(loss)`` capped at ``1e6``.
    """
    # Function-scope import so the block does not depend on a module-level
    # ``math`` import being present.
    import math

    model.eval()
    model = model.to(device)

    total_loss = 0.0
    total_samples = 0
    total_correct = 0
    total_counted = 0

    for _ in range(n_batches):
        input_ids, target_ids = task.generate_batch(batch_size)
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)

        logits = model(input_ids)
        B, T, V = logits.shape

        loss = nn.functional.cross_entropy(
            logits.reshape(B * T, V),
            target_ids.reshape(B * T),
            ignore_index=ignore_index,
        )
        # Weight by the batch's real size and remember it, so the final
        # average stays correct even if the task returns a batch whose size
        # differs from the requested ``batch_size``.
        total_loss += float(loss) * B
        total_samples += B

        preds = logits.argmax(dim=-1)
        mask = target_ids != ignore_index
        total_correct += int(((preds == target_ids) & mask).sum())
        total_counted += int(mask.sum())

    avg_loss = total_loss / total_samples if total_samples else float("nan")
    accuracy = total_correct / max(total_counted, 1)

    # math.exp raises OverflowError for large finite inputs (above ~709),
    # which would bypass the cap; treat that case as "at the cap".
    try:
        perplexity = min(math.exp(avg_loss), 1e6)
    except OverflowError:
        perplexity = 1e6

    return {
        "loss": avg_loss,
        "accuracy": accuracy,
        "perplexity": perplexity,
    }
|
|
|
|
| |
| |
| |
|
|
def compare_models(
    task,
    *,
    wb_config: Optional[WrinkleBraneConfig] = None,
    tf_config: Optional[SmallTransformerConfig] = None,
    n_steps: int = 500,
    batch_size: int = 32,
    lr: float = 3e-4,
    log_every: int = 50,
    device: str = "cpu",
    ignore_index: int = -100,
) -> Dict[str, object]:
    """Train a WrinkleBrane model and a transformer on the same task.

    Both models are built first, then the WrinkleBrane model is trained,
    then the transformer, and finally both are evaluated.  The two training
    runs share every hyper-parameter except the orthogonality weight, which
    is ``wb_config.ortho_lambda`` for WrinkleBrane and ``0.0`` for the
    transformer.

    Parameters
    ----------
    task : object
        Passed unchanged to ``train_loop`` and ``evaluate``.
    wb_config : WrinkleBraneConfig, optional
        Defaults to ``WrinkleBraneConfig()``.
    tf_config : SmallTransformerConfig, optional
        When omitted, built from ``wb_config`` by copying ``vocab_size``,
        ``d_model``, ``max_seq_len``, ``n_layers``, ``n_heads``,
        ``ffn_expansion``, ``dropout`` and ``weight_tying``.
    n_steps, batch_size, lr, log_every : training settings
        Forwarded to both ``train_loop`` calls.
    device : str
        Forwarded to training and evaluation.
    ignore_index : int
        Forwarded to training and evaluation.

    Returns
    -------
    dict
        ``wb_history``, ``tf_history``, ``wb_eval``, ``tf_eval``,
        ``wb_params``, ``tf_params``.

    Notes
    -----
    Evaluation is called with only ``device`` and ``ignore_index``, so it
    runs with ``evaluate``'s own default ``n_batches`` and ``batch_size``.
    """
    if wb_config is None:
        wb_config = WrinkleBraneConfig()

    if tf_config is None:
        # Mirror the WrinkleBrane architecture settings onto the baseline.
        mirrored_fields = (
            "vocab_size",
            "d_model",
            "max_seq_len",
            "n_layers",
            "n_heads",
            "ffn_expansion",
            "dropout",
            "weight_tying",
        )
        tf_config = SmallTransformerConfig(
            **{field: getattr(wb_config, field) for field in mirrored_fields}
        )

    wrinkle_net = WrinkleBraneModel(wb_config)
    baseline_net = SmallTransformer(tf_config)

    # Parameter counts are taken before any training happens.
    wrinkle_size = wrinkle_net.count_parameters()
    baseline_size = baseline_net.count_parameters()

    shared_train_args = dict(
        n_steps=n_steps,
        batch_size=batch_size,
        lr=lr,
        log_every=log_every,
        device=device,
        ignore_index=ignore_index,
    )
    shared_eval_args = dict(device=device, ignore_index=ignore_index)

    wrinkle_log = train_loop(
        wrinkle_net, task, ortho_lambda=wb_config.ortho_lambda, **shared_train_args
    )
    baseline_log = train_loop(
        baseline_net, task, ortho_lambda=0.0, **shared_train_args
    )

    wrinkle_scores = evaluate(wrinkle_net, task, **shared_eval_args)
    baseline_scores = evaluate(baseline_net, task, **shared_eval_args)

    return {
        "wb_history": wrinkle_log,
        "tf_history": baseline_log,
        "wb_eval": wrinkle_scores,
        "tf_eval": baseline_scores,
        "wb_params": wrinkle_size,
        "tf_params": baseline_size,
    }
|
|