"""Training and evaluation harness for TestFormer language models."""

from __future__ import annotations

import json
import math
from dataclasses import replace
from pathlib import Path
from typing import Any

import torch

from src.data.anchor_synthetic import load_anchor_synthetic
from src.data.openwebmath_bpe import load_openwebmath_bpe
from src.data.shakespeare import load_shakespeare
from src.data.the_stack_bpe import load_the_stack_bpe
from src.data.tinystories_bpe import load_tinystories_bpe
from src.data.wikitext_bpe import load_wikitext_bpe
from src.model.testformer import TestFormerLM
from src.model.testformer_config import TestFormerConfig, build_testformer_config


def _default_learning_rate(cfg: TestFormerConfig) -> float:
    """Pick a peak learning rate by model width; narrower models tolerate higher rates."""
    if cfg.d_model <= 384:
        return 3.0e-4
    if cfg.d_model <= 640:
        return 2.0e-4
    return 1.5e-4


def _make_cosine_warmup_scheduler(
    optimizer: torch.optim.Optimizer,
    total_steps: int,
    warmup_fraction: float,
) -> torch.optim.lr_scheduler.LambdaLR:
    """Linear warmup over the first `warmup_fraction` of steps, then cosine decay to zero."""
    warmup_steps = max(1, int(total_steps * warmup_fraction))

    def lr_lambda(current_step: int) -> float:
        if current_step < warmup_steps:
            return float(current_step + 1) / float(warmup_steps)
        progress = (current_step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)


def load_testformer_dataset(
    dataset: str,
    seq_len: int,
    device: str,
    data_dir: str = "data_cache",
    the_stack_repo: str = "bigcode/the-stack-smol-xs",
    the_stack_lang: str = "python",
    the_stack_bytes: int = 8_000_000,
    the_stack_vocab_size: int = 4096,
    tinystories_repo: str = "roneneldan/TinyStories",
    tinystories_bytes: int = 16_000_000,
    tinystories_vocab_size: int = 4096,
    openwebmath_repo: str = "open-web-math/open-web-math",
    openwebmath_bytes: int = 200_000,
    openwebmath_vocab_size: int = 256,
    wikitext_repo: str = "wikitext",
    wikitext_config_name: str = "wikitext-2-raw-v1",
    wikitext_bytes: int = 2_000_000,
    wikitext_vocab_size: int = 4096,
) -> tuple[Any, Any]:
    """Return (train, val) splits for a named TestFormer dataset."""
    if dataset == "anchor-synthetic":
        # The synthetic anchor task uses a fixed sequence length of 24,
        # ignoring the caller-supplied seq_len.
        return load_anchor_synthetic(seq_len=24, device=device)
    if dataset == "shakespeare":
        return load_shakespeare(seq_len=seq_len, device=device, data_dir=data_dir)
    if dataset == "the-stack-bpe":
        return load_the_stack_bpe(
            seq_len=seq_len,
            device=device,
            data_dir=data_dir,
            repo_id=the_stack_repo,
            lang=the_stack_lang,
            target_bytes=the_stack_bytes,
            vocab_size=the_stack_vocab_size,
        )
    if dataset == "tinystories-bpe":
        return load_tinystories_bpe(
            seq_len=seq_len,
            device=device,
            data_dir=data_dir,
            repo_id=tinystories_repo,
            target_bytes=tinystories_bytes,
            vocab_size=tinystories_vocab_size,
        )
    if dataset == "openwebmath-bpe":
        return load_openwebmath_bpe(
            seq_len=seq_len,
            device=device,
            data_dir=data_dir,
            repo_id=openwebmath_repo,
            target_bytes=openwebmath_bytes,
            vocab_size=openwebmath_vocab_size,
        )
    if dataset == "wikitext-bpe":
        return load_wikitext_bpe(
            seq_len=seq_len,
            device=device,
            data_dir=data_dir,
            repo_id=wikitext_repo,
            config_name=wikitext_config_name,
            target_bytes=wikitext_bytes,
            vocab_size=wikitext_vocab_size,
        )
    raise ValueError(f"Unknown TestFormer dataset: {dataset}")


def evaluate_testformer(
    model: TestFormerLM,
    dataset: Any,
    batch_size: int,
    device: str,
    max_batches: int = 5,
) -> dict[str, float]:
    """Compute a token-weighted mean loss over up to `max_batches` validation batches."""
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for _ in range(max_batches):
            x, y = dataset.get_batch(batch_size)
            x = x.to(device)
            y = y.to(device)
            out = model(x, y)
            total_loss += float(out["loss"].item()) * y.numel()
            total_tokens += y.numel()
    mean_loss = total_loss / max(1, total_tokens)
    return {
        "loss": mean_loss,
        # Dividing by ln(2) converts nats to bits per token; this equals
        # bits-per-byte only for byte-level vocabularies.
        "bpb": mean_loss / math.log(2.0),
    }


def train_testformer(
    cfg: TestFormerConfig,
    dataset: str = "anchor-synthetic",
    device: str = "cpu",
    data_dir: str = "data_cache",
    steps: int = 100,
    batch_size: int = 16,
    eval_every: int = 20,
    eval_batches: int = 5,
    learning_rate: float | None = None,
    weight_decay: float = 0.1,
    beta1: float = 0.9,
    beta2: float = 0.95,
    grad_clip: float = 1.0,
    warmup_fraction: float = 0.02,
    the_stack_repo: str = "bigcode/the-stack-smol-xs",
    the_stack_lang: str = "python",
    the_stack_bytes: int = 8_000_000,
    the_stack_vocab_size: int = 4096,
    tinystories_repo: str = "roneneldan/TinyStories",
    tinystories_bytes: int = 16_000_000,
    tinystories_vocab_size: int = 4096,
    openwebmath_repo: str = "open-web-math/open-web-math",
    openwebmath_bytes: int = 200_000,
    openwebmath_vocab_size: int = 256,
    wikitext_repo: str = "wikitext",
    wikitext_config_name: str = "wikitext-2-raw-v1",
    wikitext_bytes: int = 2_000_000,
    wikitext_vocab_size: int = 4096,
) -> tuple[TestFormerLM, list[dict[str, float]], Any, Any]:
    """Train a TestFormerLM and return (model, eval history, train split, val split)."""
    train_data, val_data = load_testformer_dataset(
        dataset=dataset,
        seq_len=cfg.max_seq_len,
        device=device,
        data_dir=data_dir,
        the_stack_repo=the_stack_repo,
        the_stack_lang=the_stack_lang,
        the_stack_bytes=the_stack_bytes,
        the_stack_vocab_size=the_stack_vocab_size,
        tinystories_repo=tinystories_repo,
        tinystories_bytes=tinystories_bytes,
        tinystories_vocab_size=tinystories_vocab_size,
        openwebmath_repo=openwebmath_repo,
        openwebmath_bytes=openwebmath_bytes,
        openwebmath_vocab_size=openwebmath_vocab_size,
        wikitext_repo=wikitext_repo,
        wikitext_config_name=wikitext_config_name,
        wikitext_bytes=wikitext_bytes,
        wikitext_vocab_size=wikitext_vocab_size,
    )
    # Align the config with the tokenizer's vocabulary and the dataset's
    # native sequence length before instantiating the model.
    effective_cfg = replace(
        cfg,
        vocab_size=train_data.vocab_size,
        max_seq_len=getattr(train_data, "seq_len", cfg.max_seq_len),
    )
    model = TestFormerLM(effective_cfg).to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        # Note: `or` treats an explicit 0.0 as unset and falls back to the default.
        lr=learning_rate or _default_learning_rate(effective_cfg),
        betas=(beta1, beta2),
        weight_decay=weight_decay,
    )
    scheduler = _make_cosine_warmup_scheduler(
        optimizer=optimizer,
        total_steps=max(steps, 1),
        warmup_fraction=warmup_fraction,
    )
    history: list[dict[str, float]] = []
    for step in range(steps):
        model.train()
        x, y = train_data.get_batch(batch_size)
        x = x.to(device)
        y = y.to(device)
        out = model(x, y)
        optimizer.zero_grad()
        out["loss"].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        scheduler.step()
        # Evaluate on the cadence set by eval_every, and always on the final step.
        if (step + 1) % eval_every == 0 or step == steps - 1:
            metrics = evaluate_testformer(
                model=model,
                dataset=val_data,
                batch_size=batch_size,
                device=device,
                max_batches=eval_batches,
            )
            history.append(
                {
                    "step": float(step + 1),
                    "train_loss": float(out["loss"].item()),
                    "train_bpb": float(out["loss"].item() / math.log(2.0)),
                    "val_loss": metrics["loss"],
                    "val_bpb": metrics["bpb"],
                    "lr": float(optimizer.param_groups[0]["lr"]),
                }
            )
    model.training_history = history
    return model, history, train_data, val_data


def summarize_testformer_result(
    preset_name: str,
    motif_name: str,
    dataset_name: str,
    model: TestFormerLM,
    history: list[dict[str, float]],
) -> dict[str, Any]:
    """Flatten a finished run into a JSON-serializable summary record."""
    last = history[-1] if history else {}
    return {
        "preset": preset_name,
        "motif": motif_name,
        "dataset": dataset_name,
        "parameters": model.parameter_count(),
        "body_parameters": model.body_parameter_count(),
        "d_model": model.cfg.d_model,
        "n_layers": model.cfg.n_layers,
        "n_heads": model.cfg.n_heads,
        "d_ff": model.cfg.d_ff,
        "qk_dim": model.cfg.qk_dim,
        "v_dim": model.cfg.v_dim,
        "final_train_loss": last.get("train_loss"),
        "final_val_loss": last.get("val_loss"),
        "final_val_bpb": last.get("val_bpb"),
        "history": history,
    }


def save_testformer_json(payload: Any, path: str | Path) -> None:
    """Write `payload` as indented JSON, creating parent directories as needed."""
    path = Path(path)
    if path.parent:
        path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, indent=2), encoding="utf-8")


def build_runtime_testformer_config(
    preset_name: str,
    motif_name: str,
    vocab_size: int = 32000,
    seq_len: int | None = None,
    resid_dropout: float = 0.0,
    attn_dropout: float = 0.0,
    emb_dropout: float = 0.0,
) -> TestFormerConfig:
    """Build a TestFormerConfig from a named preset/motif with runtime overrides."""
    return build_testformer_config(
        preset_name=preset_name,
        motif_name=motif_name,
        vocab_size=vocab_size,
        max_seq_len=seq_len,
        resid_dropout=resid_dropout,
        attn_dropout=attn_dropout,
        emb_dropout=emb_dropout,
    )