# abpt/src/utils/testformer_runner.py
"""Utilities for training, evaluating, and summarizing TestFormer language models."""
from __future__ import annotations

import json
import math
from dataclasses import replace
from pathlib import Path
from typing import Any

import torch

from src.data.anchor_synthetic import load_anchor_synthetic
from src.data.openwebmath_bpe import load_openwebmath_bpe
from src.data.shakespeare import load_shakespeare
from src.data.the_stack_bpe import load_the_stack_bpe
from src.data.tinystories_bpe import load_tinystories_bpe
from src.data.wikitext_bpe import load_wikitext_bpe
from src.model.testformer import TestFormerLM
from src.model.testformer_config import TestFormerConfig, build_testformer_config

def _default_learning_rate(cfg: TestFormerConfig) -> float:
    if cfg.d_model <= 384:
        return 3.0e-4
    if cfg.d_model <= 640:
        return 2.0e-4
    return 1.5e-4
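# Illustrative mapping of the tiers above (values follow directly from the branches
# in _default_learning_rate):
#   d_model=256 -> 3.0e-4, d_model=512 -> 2.0e-4, d_model=768 -> 1.5e-4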
def _make_cosine_warmup_scheduler(
    optimizer: torch.optim.Optimizer,
    total_steps: int,
    warmup_fraction: float,
) -> torch.optim.lr_scheduler.LambdaLR:
    warmup_steps = max(1, int(total_steps * warmup_fraction))

    def lr_lambda(current_step: int) -> float:
        if current_step < warmup_steps:
            return float(current_step + 1) / float(warmup_steps)
        progress = (current_step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
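# Sanity check of the multiplier produced by lr_lambda for total_steps=100 and
# warmup_fraction=0.02 (the train_testformer defaults below, so warmup_steps=2):
#   step 0  -> 0.5    (linear warmup: (0 + 1) / 2)
#   step 1  -> 1.0    (warmup complete)
#   step 50 -> ~0.52  (near the cosine midpoint)
#   step 99 -> ~0.0   (cosine tail)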
def load_testformer_dataset(
    dataset: str,
    seq_len: int,
    device: str,
    data_dir: str = "data_cache",
    the_stack_repo: str = "bigcode/the-stack-smol-xs",
    the_stack_lang: str = "python",
    the_stack_bytes: int = 8_000_000,
    the_stack_vocab_size: int = 4096,
    tinystories_repo: str = "roneneldan/TinyStories",
    tinystories_bytes: int = 16_000_000,
    tinystories_vocab_size: int = 4096,
    openwebmath_repo: str = "open-web-math/open-web-math",
    openwebmath_bytes: int = 200_000,
    openwebmath_vocab_size: int = 256,
    wikitext_repo: str = "wikitext",
    wikitext_config_name: str = "wikitext-2-raw-v1",
    wikitext_bytes: int = 2_000_000,
    wikitext_vocab_size: int = 4096,
) -> tuple[Any, Any]:
    """Dispatch to the named dataset loader and return (train_split, val_split)."""
    if dataset == "anchor-synthetic":
        # The synthetic anchor task uses a fixed sequence length of 24 regardless
        # of the requested seq_len.
        return load_anchor_synthetic(seq_len=24, device=device)
    if dataset == "shakespeare":
        return load_shakespeare(seq_len=seq_len, device=device, data_dir=data_dir)
    if dataset == "the-stack-bpe":
        return load_the_stack_bpe(
            seq_len=seq_len,
            device=device,
            data_dir=data_dir,
            repo_id=the_stack_repo,
            lang=the_stack_lang,
            target_bytes=the_stack_bytes,
            vocab_size=the_stack_vocab_size,
        )
    if dataset == "tinystories-bpe":
        return load_tinystories_bpe(
            seq_len=seq_len,
            device=device,
            data_dir=data_dir,
            repo_id=tinystories_repo,
            target_bytes=tinystories_bytes,
            vocab_size=tinystories_vocab_size,
        )
    if dataset == "openwebmath-bpe":
        return load_openwebmath_bpe(
            seq_len=seq_len,
            device=device,
            data_dir=data_dir,
            repo_id=openwebmath_repo,
            target_bytes=openwebmath_bytes,
            vocab_size=openwebmath_vocab_size,
        )
    if dataset == "wikitext-bpe":
        return load_wikitext_bpe(
            seq_len=seq_len,
            device=device,
            data_dir=data_dir,
            repo_id=wikitext_repo,
            config_name=wikitext_config_name,
            target_bytes=wikitext_bytes,
            vocab_size=wikitext_vocab_size,
        )
    raise ValueError(f"Unknown TestFormer dataset: {dataset}")
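# Example usage (illustrative; assumes the chosen dataset can be cached under
# data_cache/ and that each loader returns splits exposing .get_batch() and
# .vocab_size, and optionally .seq_len, as the training loop below relies on):
#
#   train_split, val_split = load_testformer_dataset(
#       dataset="shakespeare", seq_len=256, device="cpu"
#   )
#   x, y = train_split.get_batch(8)  # input/target token batches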
def evaluate_testformer(
    model: TestFormerLM,
    dataset: Any,
    batch_size: int,
    device: str,
    max_batches: int = 5,
) -> dict[str, float]:
    """Estimate mean per-token loss over up to max_batches batches, plus its bits equivalent ("bpb")."""
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for _ in range(max_batches):
            x, y = dataset.get_batch(batch_size)
            x = x.to(device)
            y = y.to(device)
            out = model(x, y)
            # Weight each batch by its token count so the mean is per-token.
            total_loss += float(out["loss"].item()) * y.numel()
            total_tokens += y.numel()
    mean_loss = total_loss / max(1, total_tokens)
    return {
        "loss": mean_loss,
        "bpb": mean_loss / math.log(2.0),
    }
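# Illustrative call (names mirror the training loop below; the val split comes
# from load_testformer_dataset):
#
#   metrics = evaluate_testformer(model, val_split, batch_size=16, device="cpu")
#   print(f"val loss {metrics['loss']:.3f} nats, {metrics['bpb']:.3f} bits")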
def train_testformer(
    cfg: TestFormerConfig,
    dataset: str = "anchor-synthetic",
    device: str = "cpu",
    data_dir: str = "data_cache",
    steps: int = 100,
    batch_size: int = 16,
    eval_every: int = 20,
    eval_batches: int = 5,
    learning_rate: float | None = None,
    weight_decay: float = 0.1,
    beta1: float = 0.9,
    beta2: float = 0.95,
    grad_clip: float = 1.0,
    warmup_fraction: float = 0.02,
    the_stack_repo: str = "bigcode/the-stack-smol-xs",
    the_stack_lang: str = "python",
    the_stack_bytes: int = 8_000_000,
    the_stack_vocab_size: int = 4096,
    tinystories_repo: str = "roneneldan/TinyStories",
    tinystories_bytes: int = 16_000_000,
    tinystories_vocab_size: int = 4096,
    openwebmath_repo: str = "open-web-math/open-web-math",
    openwebmath_bytes: int = 200_000,
    openwebmath_vocab_size: int = 256,
    wikitext_repo: str = "wikitext",
    wikitext_config_name: str = "wikitext-2-raw-v1",
    wikitext_bytes: int = 2_000_000,
    wikitext_vocab_size: int = 4096,
) -> tuple[TestFormerLM, list[dict[str, float]], Any, Any]:
    """Train a TestFormerLM on the named dataset.

    Returns (model, history, train_data, val_data).
    """
    train_data, val_data = load_testformer_dataset(
        dataset=dataset,
        seq_len=cfg.max_seq_len,
        device=device,
        data_dir=data_dir,
        the_stack_repo=the_stack_repo,
        the_stack_lang=the_stack_lang,
        the_stack_bytes=the_stack_bytes,
        the_stack_vocab_size=the_stack_vocab_size,
        tinystories_repo=tinystories_repo,
        tinystories_bytes=tinystories_bytes,
        tinystories_vocab_size=tinystories_vocab_size,
        openwebmath_repo=openwebmath_repo,
        openwebmath_bytes=openwebmath_bytes,
        openwebmath_vocab_size=openwebmath_vocab_size,
        wikitext_repo=wikitext_repo,
        wikitext_config_name=wikitext_config_name,
        wikitext_bytes=wikitext_bytes,
        wikitext_vocab_size=wikitext_vocab_size,
    )
    # Adopt the vocabulary and sequence length the dataset actually provides.
    effective_cfg = replace(
        cfg,
        vocab_size=train_data.vocab_size,
        max_seq_len=getattr(train_data, "seq_len", cfg.max_seq_len),
    )
    model = TestFormerLM(effective_cfg).to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate if learning_rate is not None else _default_learning_rate(effective_cfg),
        betas=(beta1, beta2),
        weight_decay=weight_decay,
    )
    scheduler = _make_cosine_warmup_scheduler(
        optimizer=optimizer,
        total_steps=max(steps, 1),
        warmup_fraction=warmup_fraction,
    )
    history: list[dict[str, float]] = []
    for step in range(steps):
        model.train()
        x, y = train_data.get_batch(batch_size)
        x = x.to(device)
        y = y.to(device)
        out = model(x, y)
        optimizer.zero_grad()
        out["loss"].backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()
        scheduler.step()
        if (step + 1) % eval_every == 0 or step == steps - 1:
            metrics = evaluate_testformer(
                model=model,
                dataset=val_data,
                batch_size=batch_size,
                device=device,
                max_batches=eval_batches,
            )
            history.append(
                {
                    "step": float(step + 1),
                    "train_loss": float(out["loss"].item()),
                    "train_bpb": float(out["loss"].item() / math.log(2.0)),
                    "val_loss": metrics["loss"],
                    "val_bpb": metrics["bpb"],
                    "lr": float(optimizer.param_groups[0]["lr"]),
                }
            )
    model.training_history = history
    return model, history, train_data, val_data
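# End-to-end sketch (illustrative; "small" and "baseline" are placeholder preset and
# motif names -- use whatever build_testformer_config accepts in this repo):
#
#   cfg = build_runtime_testformer_config("small", "baseline", seq_len=256)
#   model, history, train_data, val_data = train_testformer(
#       cfg, dataset="shakespeare", device="cpu", steps=500, batch_size=16
#   )
#   print(history[-1]["val_bpb"])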
def summarize_testformer_result(
    preset_name: str,
    motif_name: str,
    dataset_name: str,
    model: TestFormerLM,
    history: list[dict[str, float]],
) -> dict[str, Any]:
    last = history[-1] if history else {}
    return {
        "preset": preset_name,
        "motif": motif_name,
        "dataset": dataset_name,
        "parameters": model.parameter_count(),
        "body_parameters": model.body_parameter_count(),
        "d_model": model.cfg.d_model,
        "n_layers": model.cfg.n_layers,
        "n_heads": model.cfg.n_heads,
        "d_ff": model.cfg.d_ff,
        "qk_dim": model.cfg.qk_dim,
        "v_dim": model.cfg.v_dim,
        "final_train_loss": last.get("train_loss"),
        "final_val_loss": last.get("val_loss"),
        "final_val_bpb": last.get("val_bpb"),
        "history": history,
    }
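# Typical follow-up (illustrative; continues the train_testformer sketch above,
# with the same placeholder preset/motif names):
#
#   summary = summarize_testformer_result(
#       preset_name="small", motif_name="baseline",
#       dataset_name="shakespeare", model=model, history=history,
#   )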
def save_testformer_json(payload: Any, path: str | Path) -> None:
    path = Path(path)
    if path.parent:
        path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
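# Example (illustrative output path; pairs with the summary dict built above):
#
#   save_testformer_json(summary, "results/testformer_shakespeare_small.json")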
def build_runtime_testformer_config(
    preset_name: str,
    motif_name: str,
    vocab_size: int = 32000,
    seq_len: int | None = None,
    resid_dropout: float = 0.0,
    attn_dropout: float = 0.0,
    emb_dropout: float = 0.0,
) -> TestFormerConfig:
    return build_testformer_config(
        preset_name=preset_name,
        motif_name=motif_name,
        vocab_size=vocab_size,
        max_seq_len=seq_len,
        resid_dropout=resid_dropout,
        attn_dropout=attn_dropout,
        emb_dropout=emb_dropout,
    )
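# Note: train_testformer replaces vocab_size and max_seq_len with the values the
# loaded dataset reports, so the defaults here mainly serve as placeholders.
# Illustrative call (preset/motif names are placeholders for whatever
# build_testformer_config defines):
#
#   cfg = build_runtime_testformer_config("small", "baseline", seq_len=512)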