WCNegentropy's picture
Upload 510 files
3d7f6c5 verified
"""Training infrastructure for standalone WrinkleBrane model.
Provides training loops, evaluation, and model comparison utilities
shared across all three training tasks.
Key components
--------------
``train_step``
Single optimisation step with orthogonality regularisation.
``train_loop``
Multi-step training loop with logging.
``evaluate``
Evaluation on held-out data.
``compare_models``
Side-by-side WrinkleBrane vs transformer training comparison.
"""
from __future__ import annotations

import math
import time
from typing import Dict, List, Optional, Tuple

import torch
from torch import nn, Tensor

from wrinklebrane.standalone_model import WrinkleBraneModel, WrinkleBraneConfig
from wrinklebrane.baseline_transformer import SmallTransformer, SmallTransformerConfig
from wrinklebrane.tasks import compute_accuracy
# ---------------------------------------------------------------------------
# Training step
# ---------------------------------------------------------------------------
def train_step(
    model: nn.Module,
    input_ids: Tensor,
    target_ids: Tensor,
    optimizer: torch.optim.Optimizer,
    ortho_lambda: float = 0.0,
    ignore_index: int = -100,
) -> Dict[str, float]:
    """Run one optimisation step and report its scalar metrics.

    The model is switched to training mode, gradients are cleared, and a
    token-level cross-entropy loss is computed on the flattened logits.
    When ``ortho_lambda`` is positive and the model exposes an
    ``ortho_loss`` method, that term is added with weight ``ortho_lambda``.
    Gradients are clipped to a global norm of 1.0 before the optimizer
    update.

    Parameters
    ----------
    model : nn.Module
        WrinkleBraneModel or SmallTransformer; called as ``model(input_ids)``
        and expected to return rank-3 logits ``[B, T, V]``.
    input_ids : Tensor ``[B, T]``
        Input token ids.
    target_ids : Tensor ``[B, T]``
        Target token ids; positions equal to ``ignore_index`` are skipped
        by the loss.
    optimizer : Optimizer
        Optimizer that owns the model parameters.
    ortho_lambda : float
        Orthogonality regularisation weight (0 for transformer).
    ignore_index : int
        Cross-entropy ignore index, also forwarded to ``compute_accuracy``.

    Returns
    -------
    dict
        ``task_loss``, ``ortho_loss``, ``total_loss`` as Python floats and
        ``accuracy`` as returned by ``compute_accuracy``.
    """
    model.train()
    optimizer.zero_grad()

    logits = model(input_ids)  # [B, T, V]

    # Flatten batch and time so every position is one classification sample.
    n_batch, n_time, n_vocab = logits.shape
    flat_logits = logits.reshape(n_batch * n_time, n_vocab)
    flat_targets = target_ids.reshape(n_batch * n_time)
    task_loss = nn.functional.cross_entropy(
        flat_logits, flat_targets, ignore_index=ignore_index
    )

    # Only models that provide ``ortho_loss`` are regularised; otherwise the
    # term is a zero scalar on the same device as the task loss.
    regularise = ortho_lambda > 0 and hasattr(model, "ortho_loss")
    ortho = (
        model.ortho_loss()
        if regularise
        else torch.tensor(0.0, device=task_loss.device)
    )

    total_loss = task_loss + ortho_lambda * ortho
    total_loss.backward()

    # Clip the global gradient norm before applying the update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    # Accuracy is measured on the logits produced before the update.
    with torch.no_grad():
        acc = compute_accuracy(logits.detach(), target_ids, ignore_index)

    scalars = {
        "task_loss": task_loss,
        "ortho_loss": ortho,
        "total_loss": total_loss,
    }
    metrics = {name: float(value.detach()) for name, value in scalars.items()}
    metrics["accuracy"] = acc
    return metrics
# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
def train_loop(
    model: nn.Module,
    task,
    *,
    n_steps: int = 500,
    batch_size: int = 32,
    lr: float = 3e-4,
    ortho_lambda: float = 0.0,
    log_every: int = 50,
    device: str = "cpu",
    ignore_index: int = -100,
) -> List[Dict[str, float]]:
    """Train a model on a task for ``n_steps``.

    The model is moved to ``device`` and optimised with AdamW
    (``weight_decay=0.01``).  The learning rate follows a linear warmup
    over ``min(n_steps // 10, 100)`` steps followed by cosine decay.

    Parameters
    ----------
    model : nn.Module
        Model to train; forwarded to ``train_step``.
    task : SequenceCopyTask, AssociativeRecallTask, or SyntheticGrammarTask
        Must provide ``generate_batch(batch_size)`` returning
        ``(input_ids, target_ids)``.
    n_steps : int
        Number of optimisation steps.
    batch_size : int
        Batch size requested from ``task.generate_batch``.
    lr : float
        Peak learning rate (reached at the end of warmup).
    ortho_lambda : float
        Orthogonality regularisation weight passed to ``train_step``.
    log_every : int
        Record metrics every ``log_every`` steps.  ``0`` disables periodic
        logging so that only the final step is recorded.
    device : str
        Device the model and batches are moved to.
    ignore_index : int
        Cross-entropy ignore index passed to ``train_step``.

    Returns
    -------
    list of dict
        Metrics for the logged steps (every ``log_every`` steps plus the
        final step).  Each entry holds the ``train_step`` metrics together
        with ``step``, ``lr`` and ``elapsed_s``.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

    # Learning rate schedule: linear warmup + cosine decay.
    warmup_steps = min(n_steps // 10, 100)
    # Guard against a zero-length decay phase.
    decay_steps = max(1, n_steps - warmup_steps)

    def lr_lambda(step: int) -> float:
        # ``warmup_steps`` may be 0 for very short runs; the comparison below
        # is then never true, so the division is safe.
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        progress = (step - warmup_steps) / decay_steps
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    history: List[Dict[str, float]] = []
    last_step = n_steps - 1
    # Monotonic clock: elapsed time is unaffected by system clock changes.
    t0 = time.perf_counter()

    for step in range(n_steps):
        input_ids, target_ids = task.generate_batch(batch_size)
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)

        metrics = train_step(
            model, input_ids, target_ids, optimizer,
            ortho_lambda=ortho_lambda,
            ignore_index=ignore_index,
        )
        metrics["step"] = step
        # Read before ``scheduler.step()`` so this is the rate used above.
        metrics["lr"] = optimizer.param_groups[0]["lr"]
        scheduler.step()

        # ``log_every == 0`` would otherwise raise ZeroDivisionError.
        periodic = log_every != 0 and step % log_every == 0
        if periodic or step == last_step:
            metrics["elapsed_s"] = time.perf_counter() - t0
            history.append(metrics)

    return history
# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------
@torch.no_grad()
def evaluate(
    model: nn.Module,
    task,
    *,
    n_batches: int = 10,
    batch_size: int = 32,
    device: str = "cpu",
    ignore_index: int = -100,
) -> Dict[str, float]:
    """Evaluate a model on freshly generated batches of a task.

    The model is put into eval mode (and left there) and moved to
    ``device``.  Gradients are disabled for the whole call.

    Parameters
    ----------
    model : nn.Module
        Called as ``model(input_ids)``; must return rank-3 logits
        ``[B, T, V]``.
    task
        Must provide ``generate_batch(batch_size)`` returning
        ``(input_ids, target_ids)``.
    n_batches : int
        Number of batches to draw.
    batch_size : int
        Batch size requested from ``task.generate_batch``.
    device : str
        Device the model and batches are moved to.
    ignore_index : int
        Target value excluded from both the loss and the accuracy.

    Returns
    -------
    dict
        ``loss`` (per-batch mean cross-entropy, averaged with each batch
        weighted by its actual size), ``accuracy`` (fraction of
        non-ignored positions predicted correctly) and ``perplexity``
        (``exp(loss)`` capped at ``1e6``).  If nothing was evaluated,
        ``loss`` and ``accuracy`` are ``0.0``.
    """
    max_perplexity = 1e6

    model.eval()
    model = model.to(device)

    total_loss = 0.0
    total_samples = 0
    total_correct = 0
    total_counted = 0

    for _ in range(n_batches):
        input_ids, target_ids = task.generate_batch(batch_size)
        input_ids = input_ids.to(device)
        target_ids = target_ids.to(device)

        logits = model(input_ids)
        B, T, V = logits.shape
        loss = nn.functional.cross_entropy(
            logits.reshape(B * T, V),
            target_ids.reshape(B * T),
            ignore_index=ignore_index,
        )
        # Weight by the batch size actually observed, and normalise by the
        # same quantity below, so a task that returns a different number of
        # rows than requested does not skew the average.
        total_loss += float(loss) * B
        total_samples += B

        # Accuracy over non-ignored positions only.
        preds = logits.argmax(dim=-1)
        mask = target_ids != ignore_index
        total_correct += int(((preds == target_ids) & mask).sum())
        total_counted += int(mask.sum())

    # ``max(..., 1)`` keeps the empty case (e.g. ``n_batches=0``) defined.
    avg_loss = total_loss / max(total_samples, 1)
    accuracy = total_correct / max(total_counted, 1)

    # ``math.exp`` raises OverflowError for very large losses; such a value
    # is above the cap anyway.
    try:
        perplexity = min(math.exp(avg_loss), max_perplexity)
    except OverflowError:
        perplexity = max_perplexity

    return {
        "loss": avg_loss,
        "accuracy": accuracy,
        "perplexity": perplexity,
    }
# ---------------------------------------------------------------------------
# Model comparison
# ---------------------------------------------------------------------------
def compare_models(
    task,
    *,
    wb_config: Optional[WrinkleBraneConfig] = None,
    tf_config: Optional[SmallTransformerConfig] = None,
    n_steps: int = 500,
    batch_size: int = 32,
    lr: float = 3e-4,
    log_every: int = 50,
    device: str = "cpu",
    ignore_index: int = -100,
) -> Dict[str, object]:
    """Train a WrinkleBrane model and a transformer baseline on one task.

    Both models are built first, then trained one after the other (the
    WrinkleBrane model first) with identical loop settings, and finally
    evaluated in the same order.  Only the WrinkleBrane run uses
    orthogonality regularisation (``wb_config.ortho_lambda``).

    Parameters
    ----------
    task
        Task object passed to ``train_loop`` and ``evaluate``.
    wb_config : WrinkleBraneConfig, optional
        Defaults to ``WrinkleBraneConfig()``.
    tf_config : SmallTransformerConfig, optional
        When omitted, built from the matching fields of ``wb_config``
        (``vocab_size``, ``d_model``, ``max_seq_len``, ``n_layers``,
        ``n_heads``, ``ffn_expansion``, ``dropout``, ``weight_tying``).
    n_steps, batch_size, lr, log_every, device, ignore_index
        Forwarded to ``train_loop`` for both models.

    Returns
    -------
    dict
        ``wb_history``, ``tf_history``, ``wb_eval``, ``tf_eval``,
        ``wb_params``, ``tf_params``.
    """
    wb_config = WrinkleBraneConfig() if wb_config is None else wb_config

    if tf_config is None:
        # Mirror the WrinkleBrane hyper-parameters onto the baseline.
        mirrored_fields = (
            "vocab_size",
            "d_model",
            "max_seq_len",
            "n_layers",
            "n_heads",
            "ffn_expansion",
            "dropout",
            "weight_tying",
        )
        tf_config = SmallTransformerConfig(
            **{name: getattr(wb_config, name) for name in mirrored_fields}
        )

    wb_model = WrinkleBraneModel(wb_config)
    tf_model = SmallTransformer(tf_config)

    wb_params = wb_model.count_parameters()
    tf_params = tf_model.count_parameters()

    # Settings shared by both training runs.
    loop_settings = dict(
        n_steps=n_steps,
        batch_size=batch_size,
        lr=lr,
        log_every=log_every,
        device=device,
        ignore_index=ignore_index,
    )

    # WrinkleBrane first, with its configured regularisation weight ...
    wb_history = train_loop(
        wb_model, task, ortho_lambda=wb_config.ortho_lambda, **loop_settings
    )
    # ... then the transformer baseline without regularisation.
    tf_history = train_loop(tf_model, task, ortho_lambda=0.0, **loop_settings)

    # NOTE(review): ``batch_size`` is not forwarded here, so evaluation runs
    # with ``evaluate``'s own defaults -- confirm this is intended.
    eval_settings = dict(device=device, ignore_index=ignore_index)
    wb_eval = evaluate(wb_model, task, **eval_settings)
    tf_eval = evaluate(tf_model, task, **eval_settings)

    return {
        "wb_history": wb_history,
        "tf_history": tf_history,
        "wb_eval": wb_eval,
        "tf_eval": tf_eval,
        "wb_params": wb_params,
        "tf_params": tf_params,
    }