from __future__ import annotations import argparse import json import math import sys import traceback from datetime import datetime, timezone from pathlib import Path from typing import Any import torch ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) from src.data.wikitext_bpe import load_wikitext_bpe from src.model.testformer import TestFormerLM from src.model.testformer_combined import TestFormerCombinedLM from src.model.testformer_combined_config import build_testformer_combined_config from src.model.testformer_config import TESTFORMER_MOTIFS, TestFormerConfig, build_testformer_config ARCHIVE_DIR = ROOT / "archive" ARCHIVE_DIR.mkdir(exist_ok=True) DEFAULT_MOTIFS = ("Uniform-Baseline", "Narrow-Compare", "Wide-Memory") _PARAM_MATCH_CACHE: dict[tuple[str, int, int, int], TestFormerConfig] = {} def _default_learning_rate(d_model: int) -> float: if d_model <= 384: return 3.0e-4 if d_model <= 640: return 2.0e-4 return 1.5e-4 def _make_cosine_warmup_scheduler( optimizer: torch.optim.Optimizer, total_steps: int, warmup_fraction: float, ) -> torch.optim.lr_scheduler.LambdaLR: warmup_steps = max(1, int(total_steps * warmup_fraction)) def lr_lambda(current_step: int) -> float: if current_step < warmup_steps: return float(current_step + 1) / float(warmup_steps) progress = (current_step - warmup_steps) / max(1, total_steps - warmup_steps) return 0.5 * (1.0 + math.cos(math.pi * progress)) return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) def _fit_language_model( model: torch.nn.Module, train_data: Any, val_data: Any, device: str, steps: int, batch_size: int, eval_every: int, eval_batches: int, learning_rate: float, weight_decay: float, beta1: float, beta2: float, grad_clip: float, warmup_fraction: float, ) -> list[dict[str, float]]: optimizer = torch.optim.AdamW( model.parameters(), lr=learning_rate, betas=(beta1, beta2), weight_decay=weight_decay, ) scheduler = _make_cosine_warmup_scheduler( optimizer=optimizer, total_steps=max(steps, 1), warmup_fraction=warmup_fraction, ) history: list[dict[str, float]] = [] for step in range(steps): model.train() x, y = train_data.get_batch(batch_size) x = x.to(device) y = y.to(device) out = model(x, y) optimizer.zero_grad() out["loss"].backward() torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) optimizer.step() scheduler.step() if (step + 1) % eval_every == 0 or step == steps - 1: metrics = _evaluate_model( model=model, dataset=val_data, batch_size=batch_size, device=device, max_batches=eval_batches, ) history.append( { "step": float(step + 1), "train_loss": float(out["loss"].item()), "train_bpb": float(out["loss"].item() / math.log(2.0)), "val_loss": metrics["loss"], "val_bpb": metrics["bpb"], "lr": float(optimizer.param_groups[0]["lr"]), } ) return history def _evaluate_model( model: torch.nn.Module, dataset: Any, batch_size: int, device: str, max_batches: int, ) -> dict[str, float]: model.eval() total_loss = 0.0 total_tokens = 0 with torch.no_grad(): for _ in range(max_batches): x, y = dataset.get_batch(batch_size) x = x.to(device) y = y.to(device) out = model(x, y) total_loss += float(out["loss"].item()) * y.numel() total_tokens += y.numel() mean_loss = total_loss / max(1, total_tokens) return { "loss": mean_loss, "bpb": mean_loss / math.log(2.0), } def _find_param_matched_single_config( motif_name: str, target_params: int, vocab_size: int, max_seq_len: int, ) -> TestFormerConfig: cache_key = (motif_name, target_params, vocab_size, max_seq_len) if cache_key in _PARAM_MATCH_CACHE: return _PARAM_MATCH_CACHE[cache_key] motif = TESTFORMER_MOTIFS[motif_name] meta_device = torch.device("meta") best_cfg: TestFormerConfig | None = None best_diff: int | None = None for d_model in range(256, 1025, 64): n_heads = d_model // 64 d_ff = int(round(d_model * motif.r_ff)) for n_layers in range(8, 33): cfg = TestFormerConfig( name=f"TestFormer-ParamMatched-{motif_name}", vocab_size=vocab_size, d_model=d_model, n_layers=n_layers, n_heads=n_heads, d_ff=d_ff, max_seq_len=max_seq_len, alpha_q=motif.alpha_q, alpha_k=motif.alpha_k, beta_v=motif.beta_v, motif_name=motif.name, ) params = TestFormerLM(cfg, device=meta_device).parameter_count() diff = abs(params - target_params) if best_diff is None or diff < best_diff: best_cfg = cfg best_diff = diff if best_cfg is None: raise RuntimeError(f"Could not find a param-matched config for {motif_name}") _PARAM_MATCH_CACHE[cache_key] = best_cfg return best_cfg def _summarize_single_run( motif_name: str, model: TestFormerLM, history: list[dict[str, float]], ) -> dict[str, Any]: last = history[-1] return { "label": motif_name, "model_type": "single", "motif": motif_name, "parameters": model.parameter_count(), "body_parameters": model.body_parameter_count(), "d_model": model.cfg.d_model, "n_layers": model.cfg.n_layers, "n_heads": model.cfg.n_heads, "d_ff": model.cfg.d_ff, "qk_dim": model.cfg.qk_dim, "v_dim": model.cfg.v_dim, "final_train_loss": last["train_loss"], "final_val_loss": last["val_loss"], "final_val_bpb": last["val_bpb"], "history": history, } def _summarize_combined_run( model: TestFormerCombinedLM, history: list[dict[str, float]], ) -> dict[str, Any]: last = history[-1] blend_weights = model.current_blend_weights().cpu() submodel_parameters = { motif_name: submodel.parameter_count() for motif_name, submodel in zip(model.motif_names, model.submodels) } return { "label": "Combined", "model_type": "combined", "motifs": list(model.motif_names), "parameters": model.parameter_count(), "body_parameters": model.body_parameter_count(), "blend_weights": { motif_name: float(weight.item()) for motif_name, weight in zip(model.motif_names, blend_weights) }, "submodel_parameters": submodel_parameters, "final_train_loss": last["train_loss"], "final_val_loss": last["val_loss"], "final_val_bpb": last["val_bpb"], "history": history, } def run_testformer_wikitext_combo( preset_name: str, motif_names: tuple[str, ...], seq_len: int, steps: int, batch_size: int, eval_every: int, eval_batches: int, device: str, data_dir: str, wikitext_repo: str, wikitext_config_name: str, wikitext_bytes: int, wikitext_vocab_size: int, weight_decay: float, beta1: float, beta2: float, grad_clip: float, warmup_fraction: float, match_param_budget: bool = False, target_params: int | None = None, ) -> dict[str, Any]: train_data, val_data = load_wikitext_bpe( seq_len=seq_len, device=device, data_dir=data_dir, repo_id=wikitext_repo, config_name=wikitext_config_name, target_bytes=wikitext_bytes, vocab_size=wikitext_vocab_size, ) actual_seq_len = getattr(train_data, "seq_len", seq_len) actual_vocab_size = int(train_data.vocab_size) results: list[dict[str, Any]] = [] combined_cfg = build_testformer_combined_config( preset_name=preset_name, motif_names=motif_names, vocab_size=actual_vocab_size, max_seq_len=actual_seq_len, ) combined_reference_params = TestFormerCombinedLM(combined_cfg, device=torch.device("meta")).parameter_count() resolved_target_params = target_params or combined_reference_params for motif_name in motif_names: if match_param_budget: cfg = _find_param_matched_single_config( motif_name=motif_name, target_params=resolved_target_params, vocab_size=actual_vocab_size, max_seq_len=actual_seq_len, ) else: cfg = build_testformer_config( preset_name=preset_name, motif_name=motif_name, vocab_size=actual_vocab_size, max_seq_len=actual_seq_len, ) model = TestFormerLM(cfg).to(device) history = _fit_language_model( model=model, train_data=train_data, val_data=val_data, device=device, steps=steps, batch_size=batch_size, eval_every=eval_every, eval_batches=eval_batches, learning_rate=_default_learning_rate(cfg.d_model), weight_decay=weight_decay, beta1=beta1, beta2=beta2, grad_clip=grad_clip, warmup_fraction=warmup_fraction, ) results.append(_summarize_single_run(motif_name=motif_name, model=model, history=history)) combined_model = TestFormerCombinedLM(combined_cfg).to(device) combined_history = _fit_language_model( model=combined_model, train_data=train_data, val_data=val_data, device=device, steps=steps, batch_size=batch_size, eval_every=eval_every, eval_batches=eval_batches, learning_rate=_default_learning_rate(combined_model.submodels[0].cfg.d_model), weight_decay=weight_decay, beta1=beta1, beta2=beta2, grad_clip=grad_clip, warmup_fraction=warmup_fraction, ) results.append(_summarize_combined_run(model=combined_model, history=combined_history)) ranking_by_val_loss = [ { "label": run["label"], "model_type": run["model_type"], "final_val_loss": run["final_val_loss"], "parameters": run["parameters"], } for run in sorted(results, key=lambda run: float(run["final_val_loss"])) ] timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") archive_path = ARCHIVE_DIR / f"testformer_wikitext_combo_{timestamp}.json" report = { "status": "success", "preset": preset_name, "dataset": "wikitext-bpe", "motifs": list(motif_names), "device": device, "steps": steps, "batch_size": batch_size, "eval_every": eval_every, "eval_batches": eval_batches, "match_param_budget": match_param_budget, "target_params": resolved_target_params, "combined_reference_params": combined_reference_params, "seq_len": actual_seq_len, "vocab_size": actual_vocab_size, "wikitext_repo": wikitext_repo, "wikitext_config_name": wikitext_config_name, "wikitext_bytes": wikitext_bytes, "train_token_count": int(len(train_data)), "val_token_count": int(len(val_data)), "runs": results, "ranking_by_val_loss": ranking_by_val_loss, "archive_path": str(archive_path), } archive_path.write_text(json.dumps(report, indent=2, ensure_ascii=False), encoding="utf-8") return report def _parse_motifs(raw: str) -> tuple[str, ...]: motifs = tuple(part.strip() for part in raw.split(",") if part.strip()) return motifs or DEFAULT_MOTIFS def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--preset", default="TestFormer-0.25x") parser.add_argument("--motifs", default=",".join(DEFAULT_MOTIFS)) parser.add_argument("--seq-len", type=int, default=256) parser.add_argument("--steps", type=int, default=300) parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--eval-every", type=int, default=100) parser.add_argument("--eval-batches", type=int, default=8) parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") parser.add_argument("--data-dir", default="data_cache") parser.add_argument("--wikitext-repo", default="wikitext") parser.add_argument("--wikitext-config-name", default="wikitext-2-raw-v1") parser.add_argument("--wikitext-bytes", type=int, default=1_000_000) parser.add_argument("--wikitext-vocab-size", type=int, default=2048) parser.add_argument("--weight-decay", type=float, default=0.1) parser.add_argument("--beta1", type=float, default=0.9) parser.add_argument("--beta2", type=float, default=0.95) parser.add_argument("--grad-clip", type=float, default=1.0) parser.add_argument("--warmup-fraction", type=float, default=0.02) parser.add_argument("--match-param-budget", action="store_true") parser.add_argument("--target-params", type=int, default=None) args, _ = parser.parse_known_args() try: report = run_testformer_wikitext_combo( preset_name=args.preset, motif_names=_parse_motifs(args.motifs), seq_len=args.seq_len, steps=args.steps, batch_size=args.batch_size, eval_every=args.eval_every, eval_batches=args.eval_batches, device=args.device, data_dir=args.data_dir, wikitext_repo=args.wikitext_repo, wikitext_config_name=args.wikitext_config_name, wikitext_bytes=args.wikitext_bytes, wikitext_vocab_size=args.wikitext_vocab_size, weight_decay=args.weight_decay, beta1=args.beta1, beta2=args.beta2, grad_clip=args.grad_clip, warmup_fraction=args.warmup_fraction, match_param_budget=args.match_param_budget, target_params=args.target_params, ) except Exception as exc: report = { "status": "error", "error": str(exc), "traceback": traceback.format_exc(), } print("\n===FINAL_RESULT===") print(json.dumps(report, indent=2, ensure_ascii=False)) if __name__ == "__main__": main()