# Page-scrape residue ("Spaces: Sleeping") -- not part of the script below.
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import math | |
| import sys | |
| import traceback | |
| from datetime import datetime, timezone | |
| from pathlib import Path | |
| from typing import Any | |
| import torch | |
# Repository root: the parent of the directory that contains this file.
ROOT = Path(__file__).resolve().parents[1]
# Put the root at the front of sys.path (only once) so the `src.*` imports
# that follow resolve when this file is executed directly as a script.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
| from src.data.wikitext_bpe import load_wikitext_bpe | |
| from src.model.testformer import TestFormerLM | |
| from src.model.testformer_combined import TestFormerCombinedLM | |
| from src.model.testformer_combined_config import build_testformer_combined_config | |
| from src.model.testformer_config import TESTFORMER_MOTIFS, TestFormerConfig, build_testformer_config | |
# Directory that receives one JSON report per benchmark run. It is created as
# a side effect of importing this module if it does not exist yet.
ARCHIVE_DIR = ROOT / "archive"
ARCHIVE_DIR.mkdir(exist_ok=True)
# Motif names used when the caller supplies none.
DEFAULT_MOTIFS = ("Uniform-Baseline", "Narrow-Compare", "Wide-Memory")
# Memo for the parameter-matching search, keyed by
# (motif_name, target_params, vocab_size, max_seq_len).
_PARAM_MATCH_CACHE: dict[tuple[str, int, int, int], TestFormerConfig] = {}
| def _default_learning_rate(d_model: int) -> float: | |
| if d_model <= 384: | |
| return 3.0e-4 | |
| if d_model <= 640: | |
| return 2.0e-4 | |
| return 1.5e-4 | |
| def _make_cosine_warmup_scheduler( | |
| optimizer: torch.optim.Optimizer, | |
| total_steps: int, | |
| warmup_fraction: float, | |
| ) -> torch.optim.lr_scheduler.LambdaLR: | |
| warmup_steps = max(1, int(total_steps * warmup_fraction)) | |
| def lr_lambda(current_step: int) -> float: | |
| if current_step < warmup_steps: | |
| return float(current_step + 1) / float(warmup_steps) | |
| progress = (current_step - warmup_steps) / max(1, total_steps - warmup_steps) | |
| return 0.5 * (1.0 + math.cos(math.pi * progress)) | |
| return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) | |
def _fit_language_model(
    model: torch.nn.Module,
    train_data: Any,
    val_data: Any,
    device: str,
    steps: int,
    batch_size: int,
    eval_every: int,
    eval_batches: int,
    learning_rate: float,
    weight_decay: float,
    beta1: float,
    beta2: float,
    grad_clip: float,
    warmup_fraction: float,
) -> list[dict[str, float]]:
    """Train ``model`` with AdamW and a warmup/cosine schedule, logging evals.

    Each step draws one batch via ``train_data.get_batch(batch_size)``, calls
    ``model(x, y)`` and back-propagates ``out["loss"]``, clips the gradient
    norm to ``grad_clip`` and advances the optimizer and the scheduler.

    Validation runs every ``eval_every`` steps and always after the final
    step. ``eval_every == 0`` means "evaluate only after the final step"
    (previously this raised ``ZeroDivisionError``).

    Args:
        model: Module called as ``model(x, y)`` that returns a mapping with a
            ``"loss"`` tensor.
        train_data: Object exposing ``get_batch(batch_size) -> (x, y)``.
        val_data: Dataset passed to ``_evaluate_model``.
        device: Device the batches are moved to.
        steps: Number of optimizer steps; ``0`` yields an empty history.
        batch_size: Batch size for training and evaluation.
        eval_every: Evaluation period in steps (see above for ``0``).
        eval_batches: Number of validation batches per evaluation.
        learning_rate: Base AdamW learning rate.
        weight_decay: AdamW weight decay.
        beta1: First AdamW beta.
        beta2: Second AdamW beta.
        grad_clip: Maximum gradient norm passed to ``clip_grad_norm_``.
        warmup_fraction: Fraction of ``steps`` used for linear warmup.

    Returns:
        One record per evaluation with the keys ``step``, ``train_loss``,
        ``train_bpb``, ``val_loss``, ``val_bpb`` and ``lr``. ``train_loss`` is
        the loss of the batch used in that step, and ``lr`` is read after the
        scheduler update.
    """
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        betas=(beta1, beta2),
        weight_decay=weight_decay,
    )
    scheduler = _make_cosine_warmup_scheduler(
        optimizer=optimizer,
        total_steps=max(steps, 1),
        warmup_fraction=warmup_fraction,
    )
    history: list[dict[str, float]] = []
    for step in range(steps):
        # _evaluate_model leaves the model in eval mode, so re-enable training
        # mode at the top of every step.
        model.train()
        x, y = train_data.get_batch(batch_size)
        x = x.to(device)
        y = y.to(device)
        out = model(x, y)
        loss = out["loss"]
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        scheduler.step()

        # Guard the modulo: eval_every == 0 disables periodic evaluation and
        # leaves only the final-step evaluation.
        is_periodic_eval = eval_every != 0 and (step + 1) % eval_every == 0
        if not (is_periodic_eval or step == steps - 1):
            continue

        metrics = _evaluate_model(
            model=model,
            dataset=val_data,
            batch_size=batch_size,
            device=device,
            max_batches=eval_batches,
        )
        # Read the scalar once; each .item() call forces a device sync.
        train_loss = float(loss.item())
        history.append(
            {
                "step": float(step + 1),
                "train_loss": train_loss,
                "train_bpb": float(train_loss / math.log(2.0)),
                "val_loss": metrics["loss"],
                "val_bpb": metrics["bpb"],
                "lr": float(optimizer.param_groups[0]["lr"]),
            }
        )
    return history
| def _evaluate_model( | |
| model: torch.nn.Module, | |
| dataset: Any, | |
| batch_size: int, | |
| device: str, | |
| max_batches: int, | |
| ) -> dict[str, float]: | |
| model.eval() | |
| total_loss = 0.0 | |
| total_tokens = 0 | |
| with torch.no_grad(): | |
| for _ in range(max_batches): | |
| x, y = dataset.get_batch(batch_size) | |
| x = x.to(device) | |
| y = y.to(device) | |
| out = model(x, y) | |
| total_loss += float(out["loss"].item()) * y.numel() | |
| total_tokens += y.numel() | |
| mean_loss = total_loss / max(1, total_tokens) | |
| return { | |
| "loss": mean_loss, | |
| "bpb": mean_loss / math.log(2.0), | |
| } | |
def _find_param_matched_single_config(
    motif_name: str,
    target_params: int,
    vocab_size: int,
    max_seq_len: int,
) -> TestFormerConfig:
    """Grid-search a single-motif config whose size is closest to a target.

    Candidates cover ``d_model`` from 256 to 1024 in steps of 64 and
    ``n_layers`` from 8 to 32, with ``n_heads = d_model // 64`` and
    ``d_ff = round(d_model * motif.r_ff)``. Each candidate is instantiated on
    the ``meta`` device and its ``parameter_count()`` is compared with
    ``target_params``; the first candidate with the smallest absolute
    difference wins. Results are memoised in ``_PARAM_MATCH_CACHE``.

    Raises:
        KeyError: If ``motif_name`` is not in ``TESTFORMER_MOTIFS``.
        RuntimeError: If no candidate was produced.
    """
    key = (motif_name, target_params, vocab_size, max_seq_len)
    cached = _PARAM_MATCH_CACHE.get(key)
    if cached is not None:
        return cached

    motif = TESTFORMER_MOTIFS[motif_name]
    meta = torch.device("meta")

    def _candidates():
        # Widths outermost, depths innermost: this fixes the tie-break order.
        for width in range(256, 1025, 64):
            heads = width // 64
            ff_width = int(round(width * motif.r_ff))
            for depth in range(8, 33):
                yield TestFormerConfig(
                    name=f"TestFormer-ParamMatched-{motif_name}",
                    vocab_size=vocab_size,
                    d_model=width,
                    n_layers=depth,
                    n_heads=heads,
                    d_ff=ff_width,
                    max_seq_len=max_seq_len,
                    alpha_q=motif.alpha_q,
                    alpha_k=motif.alpha_k,
                    beta_v=motif.beta_v,
                    motif_name=motif.name,
                )

    winner: TestFormerConfig | None = None
    smallest_gap: int | None = None
    for candidate in _candidates():
        size = TestFormerLM(candidate, device=meta).parameter_count()
        gap = abs(size - target_params)
        # Strict "<" keeps the earliest candidate on ties.
        if smallest_gap is None or gap < smallest_gap:
            winner, smallest_gap = candidate, gap

    if winner is None:
        raise RuntimeError(f"Could not find a param-matched config for {motif_name}")
    _PARAM_MATCH_CACHE[key] = winner
    return winner
| def _summarize_single_run( | |
| motif_name: str, | |
| model: TestFormerLM, | |
| history: list[dict[str, float]], | |
| ) -> dict[str, Any]: | |
| last = history[-1] | |
| return { | |
| "label": motif_name, | |
| "model_type": "single", | |
| "motif": motif_name, | |
| "parameters": model.parameter_count(), | |
| "body_parameters": model.body_parameter_count(), | |
| "d_model": model.cfg.d_model, | |
| "n_layers": model.cfg.n_layers, | |
| "n_heads": model.cfg.n_heads, | |
| "d_ff": model.cfg.d_ff, | |
| "qk_dim": model.cfg.qk_dim, | |
| "v_dim": model.cfg.v_dim, | |
| "final_train_loss": last["train_loss"], | |
| "final_val_loss": last["val_loss"], | |
| "final_val_bpb": last["val_bpb"], | |
| "history": history, | |
| } | |
| def _summarize_combined_run( | |
| model: TestFormerCombinedLM, | |
| history: list[dict[str, float]], | |
| ) -> dict[str, Any]: | |
| last = history[-1] | |
| blend_weights = model.current_blend_weights().cpu() | |
| submodel_parameters = { | |
| motif_name: submodel.parameter_count() | |
| for motif_name, submodel in zip(model.motif_names, model.submodels) | |
| } | |
| return { | |
| "label": "Combined", | |
| "model_type": "combined", | |
| "motifs": list(model.motif_names), | |
| "parameters": model.parameter_count(), | |
| "body_parameters": model.body_parameter_count(), | |
| "blend_weights": { | |
| motif_name: float(weight.item()) | |
| for motif_name, weight in zip(model.motif_names, blend_weights) | |
| }, | |
| "submodel_parameters": submodel_parameters, | |
| "final_train_loss": last["train_loss"], | |
| "final_val_loss": last["val_loss"], | |
| "final_val_bpb": last["val_bpb"], | |
| "history": history, | |
| } | |
def run_testformer_wikitext_combo(
    preset_name: str,
    motif_names: tuple[str, ...],
    seq_len: int,
    steps: int,
    batch_size: int,
    eval_every: int,
    eval_batches: int,
    device: str,
    data_dir: str,
    wikitext_repo: str,
    wikitext_config_name: str,
    wikitext_bytes: int,
    wikitext_vocab_size: int,
    weight_decay: float,
    beta1: float,
    beta2: float,
    grad_clip: float,
    warmup_fraction: float,
    match_param_budget: bool = False,
    target_params: int | None = None,
) -> dict[str, Any]:
    """Train one model per motif plus the combined model and archive a report.

    Steps:
      1. Load the train/validation data with ``load_wikitext_bpe``.
      2. Build the combined config and measure its parameter count on the
         ``meta`` device; this is the parameter target unless a truthy
         ``target_params`` is given.
      3. For every motif, train a ``TestFormerLM`` built either from the
         preset or, when ``match_param_budget`` is set, from the
         parameter-matched search.
      4. Train the combined model with the same hyper-parameters; its learning
         rate is derived from the first submodel's ``d_model``.
      5. Rank all runs by final validation loss and write the report as JSON
         into ``ARCHIVE_DIR`` under a UTC-timestamped file name.

    Returns:
        The report dictionary that was written to disk (it includes
        ``archive_path``).
    """
    train_data, val_data = load_wikitext_bpe(
        seq_len=seq_len,
        device=device,
        data_dir=data_dir,
        repo_id=wikitext_repo,
        config_name=wikitext_config_name,
        target_bytes=wikitext_bytes,
        vocab_size=wikitext_vocab_size,
    )
    # The loader may report its own sequence length; fall back to the request.
    actual_seq_len = getattr(train_data, "seq_len", seq_len)
    actual_vocab_size = int(train_data.vocab_size)

    combined_cfg = build_testformer_combined_config(
        preset_name=preset_name,
        motif_names=motif_names,
        vocab_size=actual_vocab_size,
        max_seq_len=actual_seq_len,
    )
    combined_reference_params = TestFormerCombinedLM(combined_cfg, device=torch.device("meta")).parameter_count()
    resolved_target_params = target_params or combined_reference_params

    # Everything except the model and its learning rate is shared by all runs.
    fit_kwargs: dict[str, Any] = {
        "train_data": train_data,
        "val_data": val_data,
        "device": device,
        "steps": steps,
        "batch_size": batch_size,
        "eval_every": eval_every,
        "eval_batches": eval_batches,
        "weight_decay": weight_decay,
        "beta1": beta1,
        "beta2": beta2,
        "grad_clip": grad_clip,
        "warmup_fraction": warmup_fraction,
    }

    runs: list[dict[str, Any]] = []
    for motif_name in motif_names:
        if not match_param_budget:
            cfg = build_testformer_config(
                preset_name=preset_name,
                motif_name=motif_name,
                vocab_size=actual_vocab_size,
                max_seq_len=actual_seq_len,
            )
        else:
            cfg = _find_param_matched_single_config(
                motif_name=motif_name,
                target_params=resolved_target_params,
                vocab_size=actual_vocab_size,
                max_seq_len=actual_seq_len,
            )
        single_model = TestFormerLM(cfg).to(device)
        single_history = _fit_language_model(
            model=single_model,
            learning_rate=_default_learning_rate(cfg.d_model),
            **fit_kwargs,
        )
        runs.append(
            _summarize_single_run(motif_name=motif_name, model=single_model, history=single_history)
        )

    combined_model = TestFormerCombinedLM(combined_cfg).to(device)
    combined_history = _fit_language_model(
        model=combined_model,
        learning_rate=_default_learning_rate(combined_model.submodels[0].cfg.d_model),
        **fit_kwargs,
    )
    runs.append(_summarize_combined_run(model=combined_model, history=combined_history))

    # Best (lowest) final validation loss first.
    runs_best_first = sorted(runs, key=lambda entry: float(entry["final_val_loss"]))
    ranking_fields = ("label", "model_type", "final_val_loss", "parameters")
    ranking_by_val_loss = [
        {field: entry[field] for field in ranking_fields} for entry in runs_best_first
    ]

    stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    archive_path = ARCHIVE_DIR / f"testformer_wikitext_combo_{stamp}.json"
    report: dict[str, Any] = {
        "status": "success",
        "preset": preset_name,
        "dataset": "wikitext-bpe",
        "motifs": list(motif_names),
        "device": device,
        "steps": steps,
        "batch_size": batch_size,
        "eval_every": eval_every,
        "eval_batches": eval_batches,
        "match_param_budget": match_param_budget,
        "target_params": resolved_target_params,
        "combined_reference_params": combined_reference_params,
        "seq_len": actual_seq_len,
        "vocab_size": actual_vocab_size,
        "wikitext_repo": wikitext_repo,
        "wikitext_config_name": wikitext_config_name,
        "wikitext_bytes": wikitext_bytes,
        "train_token_count": int(len(train_data)),
        "val_token_count": int(len(val_data)),
        "runs": runs,
        "ranking_by_val_loss": ranking_by_val_loss,
        "archive_path": str(archive_path),
    }
    serialized = json.dumps(report, indent=2, ensure_ascii=False)
    archive_path.write_text(serialized, encoding="utf-8")
    return report
| def _parse_motifs(raw: str) -> tuple[str, ...]: | |
| motifs = tuple(part.strip() for part in raw.split(",") if part.strip()) | |
| return motifs or DEFAULT_MOTIFS | |
def main() -> None:
    """Command-line entry point: parse options, run the benchmark, print JSON.

    Unrecognised command-line arguments are ignored. Any exception raised by
    the run is caught and turned into an ``{"status": "error", ...}`` report,
    so a ``===FINAL_RESULT===`` marker followed by a JSON document is printed
    whether the run succeeded or failed.
    """
    # (flag, add_argument keyword arguments), in registration order.
    option_table = (
        ("--preset", {"default": "TestFormer-0.25x"}),
        ("--motifs", {"default": ",".join(DEFAULT_MOTIFS)}),
        ("--seq-len", {"type": int, "default": 256}),
        ("--steps", {"type": int, "default": 300}),
        ("--batch-size", {"type": int, "default": 16}),
        ("--eval-every", {"type": int, "default": 100}),
        ("--eval-batches", {"type": int, "default": 8}),
        ("--device", {"default": "cuda" if torch.cuda.is_available() else "cpu"}),
        ("--data-dir", {"default": "data_cache"}),
        ("--wikitext-repo", {"default": "wikitext"}),
        ("--wikitext-config-name", {"default": "wikitext-2-raw-v1"}),
        ("--wikitext-bytes", {"type": int, "default": 1_000_000}),
        ("--wikitext-vocab-size", {"type": int, "default": 2048}),
        ("--weight-decay", {"type": float, "default": 0.1}),
        ("--beta1", {"type": float, "default": 0.9}),
        ("--beta2", {"type": float, "default": 0.95}),
        ("--grad-clip", {"type": float, "default": 1.0}),
        ("--warmup-fraction", {"type": float, "default": 0.02}),
        ("--match-param-budget", {"action": "store_true"}),
        ("--target-params", {"type": int, "default": None}),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    # parse_known_args: extra arguments are tolerated rather than rejected.
    cli, _unrecognised = parser.parse_known_args()

    try:
        outcome = run_testformer_wikitext_combo(
            preset_name=cli.preset,
            motif_names=_parse_motifs(cli.motifs),
            seq_len=cli.seq_len,
            steps=cli.steps,
            batch_size=cli.batch_size,
            eval_every=cli.eval_every,
            eval_batches=cli.eval_batches,
            device=cli.device,
            data_dir=cli.data_dir,
            wikitext_repo=cli.wikitext_repo,
            wikitext_config_name=cli.wikitext_config_name,
            wikitext_bytes=cli.wikitext_bytes,
            wikitext_vocab_size=cli.wikitext_vocab_size,
            weight_decay=cli.weight_decay,
            beta1=cli.beta1,
            beta2=cli.beta2,
            grad_clip=cli.grad_clip,
            warmup_fraction=cli.warmup_fraction,
            match_param_budget=cli.match_param_budget,
            target_params=cli.target_params,
        )
    except Exception as exc:
        # Top-level boundary: report the failure in the same JSON channel.
        outcome = {
            "status": "error",
            "error": str(exc),
            "traceback": traceback.format_exc(),
        }

    print("\n===FINAL_RESULT===")
    print(json.dumps(outcome, indent=2, ensure_ascii=False))
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()