abpt / scripts /run_testformer_wikitext_combo_remote.py
Search
feat: add param-matched testformer mode
6536cc7
from __future__ import annotations
import argparse
import json
import math
import sys
import traceback
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import torch
# Repository root: ``parents[1]`` of this file's resolved path, i.e. the
# directory that contains the folder holding this script.
ROOT = Path(__file__).resolve().parents[1]
# Put the root at the front of sys.path (once) so the ``src.*`` imports that
# follow resolve when this file is launched directly as a script.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from src.data.wikitext_bpe import load_wikitext_bpe
from src.model.testformer import TestFormerLM
from src.model.testformer_combined import TestFormerCombinedLM
from src.model.testformer_combined_config import build_testformer_combined_config
from src.model.testformer_config import TESTFORMER_MOTIFS, TestFormerConfig, build_testformer_config
# Directory that receives the JSON run reports. It is created at import time
# if missing (only the leaf directory; parents are not created).
ARCHIVE_DIR = ROOT / "archive"
ARCHIVE_DIR.mkdir(exist_ok=True)
# Motifs used when the --motifs option is omitted or parses to nothing.
DEFAULT_MOTIFS = ("Uniform-Baseline", "Narrow-Compare", "Wide-Memory")
# Memo for _find_param_matched_single_config, keyed by
# (motif_name, target_params, vocab_size, max_seq_len).
_PARAM_MATCH_CACHE: dict[tuple[str, int, int, int], TestFormerConfig] = {}
def _default_learning_rate(d_model: int) -> float:
    """Pick the base learning rate for a model of width ``d_model``.

    Wider models get a smaller rate: 3e-4 for widths up to 384, 2e-4 for
    widths up to 640, and 1.5e-4 for anything wider.
    """
    # (inclusive upper bound on d_model, learning rate), narrowest tier first.
    width_tiers = ((384, 3.0e-4), (640, 2.0e-4))
    for max_width, rate in width_tiers:
        if d_model <= max_width:
            return rate
    return 1.5e-4
def _make_cosine_warmup_scheduler(
    optimizer: torch.optim.Optimizer,
    total_steps: int,
    warmup_fraction: float,
) -> torch.optim.lr_scheduler.LambdaLR:
    """Build a linear-warmup / cosine-decay learning-rate schedule.

    The returned ``LambdaLR`` scales the optimizer's base learning rate by a
    multiplier that depends on the scheduler step:

    * during the first ``warmup_steps`` steps the multiplier rises linearly
      from ``1 / warmup_steps`` to ``1.0`` (it never starts at zero);
    * afterwards it follows half a cosine wave from ``1.0`` down to ``0.0``,
      reaching zero at ``total_steps``;
    * past ``total_steps`` it stays at ``0.0``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        total_steps: Number of scheduler steps the schedule is laid out over.
        warmup_fraction: Fraction of ``total_steps`` spent in warmup. The
            warmup always lasts at least one step.

    Returns:
        The configured ``LambdaLR`` scheduler.
    """
    warmup_steps = max(1, int(total_steps * warmup_fraction))
    # Guard against a zero-length decay phase (warmup covering every step).
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(current_step: int) -> float:
        if current_step < warmup_steps:
            return float(current_step + 1) / float(warmup_steps)
        # Clamp so that stepping beyond total_steps holds the rate at zero;
        # an unclamped cosine would swing the rate back up toward its peak.
        progress = min(1.0, (current_step - warmup_steps) / decay_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
def _fit_language_model(
    model: torch.nn.Module,
    train_data: Any,
    val_data: Any,
    device: str,
    steps: int,
    batch_size: int,
    eval_every: int,
    eval_batches: int,
    learning_rate: float,
    weight_decay: float,
    beta1: float,
    beta2: float,
    grad_clip: float,
    warmup_fraction: float,
) -> list[dict[str, float]]:
    """Train ``model`` with AdamW and record periodic validation metrics.

    Each step draws one batch via ``train_data.get_batch(batch_size)``, calls
    ``model(x, y)`` and back-propagates ``out["loss"]``. Gradients are clipped
    to a total norm of ``grad_clip`` before the optimizer step, and a
    warmup/cosine scheduler is advanced once per step.

    Args:
        model: Module called as ``model(x, y)`` that returns a mapping with a
            ``"loss"`` entry.
        train_data: Object exposing ``get_batch(batch_size) -> (x, y)``.
        val_data: Dataset handed to ``_evaluate_model`` for validation.
        device: Device the batches are moved to.
        steps: Number of optimisation steps. With ``0`` nothing is trained and
            an empty history is returned.
        batch_size: Batch size for both training and validation batches.
        eval_every: Validate every this many steps. ``0`` disables the
            periodic evaluations; the final step is always evaluated.
        eval_batches: Number of validation batches per evaluation.
        learning_rate: Peak learning rate for AdamW.
        weight_decay: AdamW weight decay.
        beta1: First AdamW beta.
        beta2: Second AdamW beta.
        grad_clip: Maximum gradient norm.
        warmup_fraction: Fraction of the run used for learning-rate warmup.

    Returns:
        One entry per evaluation with the keys ``step`` (1-based),
        ``train_loss``, ``train_bpb``, ``val_loss``, ``val_bpb`` and ``lr``.
    """
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        betas=(beta1, beta2),
        weight_decay=weight_decay,
    )
    scheduler = _make_cosine_warmup_scheduler(
        optimizer=optimizer,
        total_steps=max(steps, 1),
        warmup_fraction=warmup_fraction,
    )
    history: list[dict[str, float]] = []
    for step in range(steps):
        # Evaluation switches the model to eval mode, so re-enable training
        # mode at the top of every step.
        model.train()
        x, y = train_data.get_batch(batch_size)
        x = x.to(device)
        y = y.to(device)
        out = model(x, y)
        optimizer.zero_grad()
        out["loss"].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        scheduler.step()

        # eval_every == 0 would make the modulo raise ZeroDivisionError;
        # treat it as "no periodic evaluation" instead.
        is_periodic_eval = eval_every != 0 and (step + 1) % eval_every == 0
        is_final_step = step == steps - 1
        if not (is_periodic_eval or is_final_step):
            continue

        # Read the scalar once; every .item() call copies it to the host.
        train_loss = float(out["loss"].item())
        metrics = _evaluate_model(
            model=model,
            dataset=val_data,
            batch_size=batch_size,
            device=device,
            max_batches=eval_batches,
        )
        history.append(
            {
                "step": float(step + 1),
                "train_loss": train_loss,
                # Loss divided by ln 2 (nats -> bits, assuming the model
                # reports a natural-log loss -- TODO confirm).
                "train_bpb": float(train_loss / math.log(2.0)),
                "val_loss": metrics["loss"],
                "val_bpb": metrics["bpb"],
                # Read after scheduler.step(), so this is the rate the next
                # optimizer step would use.
                "lr": float(optimizer.param_groups[0]["lr"]),
            }
        )
    return history
def _evaluate_model(
    model: torch.nn.Module,
    dataset: Any,
    batch_size: int,
    device: str,
    max_batches: int,
) -> dict[str, float]:
    """Estimate the mean loss of ``model`` over batches drawn from ``dataset``.

    The model is switched to eval mode (and left there) and gradients are
    disabled. Each batch loss is weighted by the number of target elements in
    that batch, so the result is an element-weighted mean across batches.

    Args:
        model: Module called as ``model(x, y)`` that returns a mapping with a
            ``"loss"`` entry.
        dataset: Object exposing ``get_batch(batch_size) -> (x, y)``.
        batch_size: Batch size requested from ``dataset``.
        device: Device the batches are moved to.
        max_batches: Number of batches to draw. With ``0`` both metrics come
            back as ``0.0``.

    Returns:
        ``{"loss": mean_loss, "bpb": mean_loss / ln(2)}``.
        NOTE(review): ``bpb`` is simply the loss rescaled by ``1 / ln 2``;
        whether that is per byte or per token depends on the tokenisation --
        confirm before comparing against external numbers.
    """
    model.eval()
    loss_sum = 0.0
    token_count = 0
    with torch.no_grad():
        for _batch_index in range(max_batches):
            inputs, targets = dataset.get_batch(batch_size)
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_loss = float(model(inputs, targets)["loss"].item())
            batch_tokens = targets.numel()
            loss_sum += batch_loss * batch_tokens
            token_count += batch_tokens
    # max(1, ...) keeps the division defined when no batch was evaluated.
    average = loss_sum / max(1, token_count)
    return {"loss": average, "bpb": average / math.log(2.0)}
def _find_param_matched_single_config(
    motif_name: str,
    target_params: int,
    vocab_size: int,
    max_seq_len: int,
) -> TestFormerConfig:
    """Grid-search a single-motif config whose size is closest to a budget.

    The search covers ``d_model`` from 256 to 1024 in steps of 64 (with
    ``n_heads = d_model // 64`` and ``d_ff = round(d_model * motif.r_ff)``)
    and ``n_layers`` from 8 to 32. Every candidate is instantiated on the
    ``meta`` device purely to read its ``parameter_count()``. The candidate
    with the smallest absolute distance to ``target_params`` wins; on a tie
    the one visited first (smaller ``d_model``, then fewer layers) is kept.
    Results are memoised in ``_PARAM_MATCH_CACHE``.

    Args:
        motif_name: Key into ``TESTFORMER_MOTIFS``.
        target_params: Parameter count to approximate.
        vocab_size: Vocabulary size passed to every candidate config.
        max_seq_len: Maximum sequence length passed to every candidate config.

    Returns:
        The closest-matching ``TestFormerConfig``.

    Raises:
        KeyError: If ``motif_name`` is not in ``TESTFORMER_MOTIFS``.
        RuntimeError: If the search produced no candidate.
    """
    memo_key = (motif_name, target_params, vocab_size, max_seq_len)
    remembered = _PARAM_MATCH_CACHE.get(memo_key)
    if remembered is not None:
        return remembered

    spec = TESTFORMER_MOTIFS[motif_name]
    meta = torch.device("meta")
    winner: TestFormerConfig | None = None
    smallest_gap: int | None = None

    for width in range(256, 1025, 64):
        # Settings that depend only on the width are shared by every depth.
        per_width = {
            "name": f"TestFormer-ParamMatched-{motif_name}",
            "vocab_size": vocab_size,
            "d_model": width,
            "n_heads": width // 64,
            "d_ff": int(round(width * spec.r_ff)),
            "max_seq_len": max_seq_len,
            "alpha_q": spec.alpha_q,
            "alpha_k": spec.alpha_k,
            "beta_v": spec.beta_v,
            "motif_name": spec.name,
        }
        for depth in range(8, 33):
            candidate = TestFormerConfig(n_layers=depth, **per_width)
            size = TestFormerLM(candidate, device=meta).parameter_count()
            gap = abs(size - target_params)
            # Only a strictly smaller gap replaces the current winner.
            if smallest_gap is not None and gap >= smallest_gap:
                continue
            winner, smallest_gap = candidate, gap

    if winner is None:
        raise RuntimeError(f"Could not find a param-matched config for {motif_name}")
    _PARAM_MATCH_CACHE[memo_key] = winner
    return winner
def _summarize_single_run(
    motif_name: str,
    model: TestFormerLM,
    history: list[dict[str, float]],
) -> dict[str, Any]:
    """Assemble the report entry for one single-motif training run.

    Args:
        motif_name: Motif the model was built for; used as label and motif.
        model: Trained model; its parameter counts and the shape fields of
            ``model.cfg`` are copied into the entry.
        history: Evaluation history of the run. Must be non-empty, because
            the last entry supplies the ``final_*`` metrics.

    Returns:
        A dict with label, model type, sizes, architecture fields, final
        metrics and the full ``history``.

    Raises:
        IndexError: If ``history`` is empty.
    """
    final = history[-1]
    entry: dict[str, Any] = {
        "label": motif_name,
        "model_type": "single",
        "motif": motif_name,
    }
    entry["parameters"] = model.parameter_count()
    entry["body_parameters"] = model.body_parameter_count()
    cfg = model.cfg
    for field in ("d_model", "n_layers", "n_heads", "d_ff", "qk_dim", "v_dim"):
        entry[field] = getattr(cfg, field)
    for metric in ("train_loss", "val_loss", "val_bpb"):
        entry[f"final_{metric}"] = final[metric]
    entry["history"] = history
    return entry
def _summarize_combined_run(
    model: TestFormerCombinedLM,
    history: list[dict[str, float]],
) -> dict[str, Any]:
    """Assemble the report entry for the combined-model training run.

    Args:
        model: Trained combined model. Its ``motif_names`` are paired
            positionally with ``submodels`` and with the values returned by
            ``current_blend_weights()``.
        history: Evaluation history of the run. Must be non-empty, because
            the last entry supplies the ``final_*`` metrics.

    Returns:
        A dict labelled ``"Combined"`` with total and per-submodel parameter
        counts, per-motif blend weights as plain floats, final metrics and
        the full ``history``.

    Raises:
        IndexError: If ``history`` is empty.
    """
    final = history[-1]
    # Move the weights to the CPU once before reading them element by element.
    weights_on_cpu = model.current_blend_weights().cpu()

    per_submodel: dict[str, Any] = {}
    for name, submodel in zip(model.motif_names, model.submodels):
        per_submodel[name] = submodel.parameter_count()

    entry: dict[str, Any] = {
        "label": "Combined",
        "model_type": "combined",
        "motifs": list(model.motif_names),
        "parameters": model.parameter_count(),
        "body_parameters": model.body_parameter_count(),
    }
    per_motif_weight: dict[str, float] = {}
    for name, weight in zip(model.motif_names, weights_on_cpu):
        per_motif_weight[name] = float(weight.item())
    entry["blend_weights"] = per_motif_weight
    entry["submodel_parameters"] = per_submodel
    for metric in ("train_loss", "val_loss", "val_bpb"):
        entry[f"final_{metric}"] = final[metric]
    entry["history"] = history
    return entry
def run_testformer_wikitext_combo(
    preset_name: str,
    motif_names: tuple[str, ...],
    seq_len: int,
    steps: int,
    batch_size: int,
    eval_every: int,
    eval_batches: int,
    device: str,
    data_dir: str,
    wikitext_repo: str,
    wikitext_config_name: str,
    wikitext_bytes: int,
    wikitext_vocab_size: int,
    weight_decay: float,
    beta1: float,
    beta2: float,
    grad_clip: float,
    warmup_fraction: float,
    match_param_budget: bool = False,
    target_params: int | None = None,
) -> dict[str, Any]:
    """Train one model per motif plus the combined model and archive a report.

    Steps, in order:

    1. Load the BPE-tokenised WikiText splits.
    2. Build the combined config and measure its parameter count on the
       ``meta`` device; that count is the default parameter budget.
    3. For each motif, train a single-motif model -- either the preset
       config, or (with ``match_param_budget``) the grid-searched config
       closest to the budget.
    4. Train the combined model, using the learning rate chosen for the
       width of its first submodel.
    5. Rank all runs by final validation loss, write the report as JSON
       into ``ARCHIVE_DIR`` and return it.

    Args:
        preset_name: Preset used for the combined config and, unless
            ``match_param_budget`` is set, for the single-motif configs.
        motif_names: Motifs to train individually and to combine.
        seq_len: Requested sequence length; the dataset's own ``seq_len``
            attribute takes precedence when present.
        steps, batch_size, eval_every, eval_batches: Training-loop settings
            forwarded to ``_fit_language_model``.
        device: Device used for data and models.
        data_dir, wikitext_repo, wikitext_config_name, wikitext_bytes,
            wikitext_vocab_size: Dataset settings forwarded to
            ``load_wikitext_bpe``.
        weight_decay, beta1, beta2, grad_clip, warmup_fraction: Optimiser
            and schedule settings forwarded to ``_fit_language_model``.
        match_param_budget: Size every single-motif model to the budget.
        target_params: Explicit budget. ``None`` (or ``0``) falls back to
            the combined model's parameter count.

    Returns:
        The report dict that was also written to the archive file.
    """
    train_data, val_data = load_wikitext_bpe(
        seq_len=seq_len,
        device=device,
        data_dir=data_dir,
        repo_id=wikitext_repo,
        config_name=wikitext_config_name,
        target_bytes=wikitext_bytes,
        vocab_size=wikitext_vocab_size,
    )
    # The loaded dataset is the source of truth for model dimensions.
    actual_seq_len = getattr(train_data, "seq_len", seq_len)
    actual_vocab_size = int(train_data.vocab_size)

    combined_cfg = build_testformer_combined_config(
        preset_name=preset_name,
        motif_names=motif_names,
        vocab_size=actual_vocab_size,
        max_seq_len=actual_seq_len,
    )
    # Instantiated on the meta device only to read the parameter count.
    combined_reference_params = TestFormerCombinedLM(
        combined_cfg, device=torch.device("meta")
    ).parameter_count()
    resolved_target_params = target_params or combined_reference_params

    # Everything passed to _fit_language_model except the model itself and
    # its learning rate is identical for every run.
    common_fit_args: dict[str, Any] = {
        "train_data": train_data,
        "val_data": val_data,
        "device": device,
        "steps": steps,
        "batch_size": batch_size,
        "eval_every": eval_every,
        "eval_batches": eval_batches,
        "weight_decay": weight_decay,
        "beta1": beta1,
        "beta2": beta2,
        "grad_clip": grad_clip,
        "warmup_fraction": warmup_fraction,
    }

    results: list[dict[str, Any]] = []
    for motif_name in motif_names:
        if not match_param_budget:
            cfg = build_testformer_config(
                preset_name=preset_name,
                motif_name=motif_name,
                vocab_size=actual_vocab_size,
                max_seq_len=actual_seq_len,
            )
        else:
            cfg = _find_param_matched_single_config(
                motif_name=motif_name,
                target_params=resolved_target_params,
                vocab_size=actual_vocab_size,
                max_seq_len=actual_seq_len,
            )
        model = TestFormerLM(cfg).to(device)
        history = _fit_language_model(
            model=model,
            learning_rate=_default_learning_rate(cfg.d_model),
            **common_fit_args,
        )
        results.append(
            _summarize_single_run(motif_name=motif_name, model=model, history=history)
        )

    combined_model = TestFormerCombinedLM(combined_cfg).to(device)
    combined_lr = _default_learning_rate(combined_model.submodels[0].cfg.d_model)
    combined_history = _fit_language_model(
        model=combined_model,
        learning_rate=combined_lr,
        **common_fit_args,
    )
    results.append(
        _summarize_combined_run(model=combined_model, history=combined_history)
    )

    # Best (lowest) final validation loss first.
    ranking_fields = ("label", "model_type", "final_val_loss", "parameters")
    runs_by_val_loss = sorted(results, key=lambda run: float(run["final_val_loss"]))
    ranking_by_val_loss = [
        {field: run[field] for field in ranking_fields} for run in runs_by_val_loss
    ]

    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    archive_path = ARCHIVE_DIR / f"testformer_wikitext_combo_{timestamp}.json"
    report = {
        "status": "success",
        "preset": preset_name,
        "dataset": "wikitext-bpe",
        "motifs": list(motif_names),
        "device": device,
        "steps": steps,
        "batch_size": batch_size,
        "eval_every": eval_every,
        "eval_batches": eval_batches,
        "match_param_budget": match_param_budget,
        "target_params": resolved_target_params,
        "combined_reference_params": combined_reference_params,
        "seq_len": actual_seq_len,
        "vocab_size": actual_vocab_size,
        "wikitext_repo": wikitext_repo,
        "wikitext_config_name": wikitext_config_name,
        "wikitext_bytes": wikitext_bytes,
        "train_token_count": int(len(train_data)),
        "val_token_count": int(len(val_data)),
        "runs": results,
        "ranking_by_val_loss": ranking_by_val_loss,
        "archive_path": str(archive_path),
    }
    serialized = json.dumps(report, indent=2, ensure_ascii=False)
    archive_path.write_text(serialized, encoding="utf-8")
    return report
def _parse_motifs(raw: str) -> tuple[str, ...]:
    """Split a comma-separated motif list into a tuple of names.

    Whitespace around each name is stripped and empty items are dropped.
    If nothing remains, ``DEFAULT_MOTIFS`` is returned instead.
    """
    names: list[str] = []
    for piece in raw.split(","):
        cleaned = piece.strip()
        if cleaned:
            names.append(cleaned)
    return tuple(names) if names else DEFAULT_MOTIFS
def main() -> None:
    """Command-line entry point: parse options, run the benchmark, print JSON.

    Unrecognised command-line arguments are ignored (``parse_known_args``).
    Any exception raised by the run is caught and turned into an error
    report containing the message and traceback, so the
    ``===FINAL_RESULT===`` marker followed by a JSON document is printed
    whether the run succeeded or failed.
    """
    # (flag, add_argument keyword arguments), in help-output order.
    option_table = (
        ("--preset", {"default": "TestFormer-0.25x"}),
        ("--motifs", {"default": ",".join(DEFAULT_MOTIFS)}),
        ("--seq-len", {"type": int, "default": 256}),
        ("--steps", {"type": int, "default": 300}),
        ("--batch-size", {"type": int, "default": 16}),
        ("--eval-every", {"type": int, "default": 100}),
        ("--eval-batches", {"type": int, "default": 8}),
        ("--device", {"default": "cuda" if torch.cuda.is_available() else "cpu"}),
        ("--data-dir", {"default": "data_cache"}),
        ("--wikitext-repo", {"default": "wikitext"}),
        ("--wikitext-config-name", {"default": "wikitext-2-raw-v1"}),
        ("--wikitext-bytes", {"type": int, "default": 1_000_000}),
        ("--wikitext-vocab-size", {"type": int, "default": 2048}),
        ("--weight-decay", {"type": float, "default": 0.1}),
        ("--beta1", {"type": float, "default": 0.9}),
        ("--beta2", {"type": float, "default": 0.95}),
        ("--grad-clip", {"type": float, "default": 1.0}),
        ("--warmup-fraction", {"type": float, "default": 0.02}),
        ("--match-param-budget", {"action": "store_true"}),
        ("--target-params", {"type": int, "default": None}),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    args, _ignored = parser.parse_known_args()

    try:
        # Every option name matches a keyword of the run function except
        # --preset and --motifs, which are renamed (and parsed) here.
        run_kwargs = dict(vars(args))
        run_kwargs["preset_name"] = run_kwargs.pop("preset")
        run_kwargs["motif_names"] = _parse_motifs(run_kwargs.pop("motifs"))
        report = run_testformer_wikitext_combo(**run_kwargs)
    except Exception as exc:
        # Deliberately broad: failures are reported through the same JSON
        # channel as successes instead of aborting the process.
        report = {
            "status": "error",
            "error": str(exc),
            "traceback": traceback.format_exc(),
        }

    print("\n===FINAL_RESULT===")
    print(json.dumps(report, indent=2, ensure_ascii=False))
if __name__ == "__main__":
main()