# Source: abpt/src/fog/train.py
# Last commit by kharki (a4b762c, verified):
#   feat: FOG stress ablation with new tasks (conditional, intersection, compose_add, multihop)
"""Train and compare baseline vs motif-aware transformer on algorithmic tasks."""
from __future__ import annotations
import argparse
import json
import time
from pathlib import Path
import torch
from src.fog.config import (
FOGConfig,
BASELINE_SMALL, MOTIF_SMALL,
BASELINE_TINY, MOTIF_TINY, UNIFORM_TINY,
BASELINE_MICRO, MOTIF_MICRO, UNIFORM_MICRO,
BASELINE_MED, MOTIF_MED, UNIFORM_MED,
)
from src.fog.model_baseline import BaselineTransformer
from src.fog.model_motif import MotifTransformer
from src.fog.data import (
CopyTask, ReverseTask, SelectiveRetrieval,
DistractorRetrieval, NoisyRetrieval, MultiQueryRetrieval,
ChainedRetrieval,
ConditionalRetrieval, SetIntersection, ComposeArithmetic, MultiHopChained,
prebatch_dataset, TensorBatchIterator,
)
def count_params(model: torch.nn.Module) -> int:
    """Return the total number of scalar elements over all parameters of ``model``.

    Every tensor yielded by ``model.parameters()`` is counted, whether or not
    it requires gradients. A module without parameters yields ``0``.
    """
    n_elements = 0
    for tensor in model.parameters():
        n_elements += tensor.numel()
    return n_elements
def train_epoch(
    model: torch.nn.Module,
    loader: TensorBatchIterator,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> float:
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(input_ids, targets, loss_mask=...)``;
            its return value must be a mapping with a scalar ``"loss"`` entry.
        loader: Iterable of batches, each a mapping holding ``"input_ids"``,
            ``"targets"`` and ``"loss_mask"`` tensors.
        optimizer: Optimizer that updates the parameters of ``model``.
        device: Device every batch tensor is moved to before the forward pass.

    Returns:
        The unweighted mean of the per-batch losses, or ``0.0`` when the
        loader yields no batches.
    """
    model.train()
    running_loss = 0.0
    steps = 0  # stays 0 if the loader is empty
    for steps, batch in enumerate(loader, start=1):
        tokens = batch["input_ids"].to(device)
        labels = batch["targets"].to(device)
        mask = batch["loss_mask"].to(device)

        step_loss = model(tokens, labels, loss_mask=mask)["loss"]

        optimizer.zero_grad()
        step_loss.backward()
        # Cap the global gradient norm at 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += step_loss.item()

    # max(..., 1) guards the division when no batch was seen.
    return running_loss / max(steps, 1)
@torch.no_grad()
def eval_accuracy(
    model: torch.nn.Module,
    loader: TensorBatchIterator,
    device: torch.device,
) -> dict[str, float]:
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Args:
        model: Module called as ``model(input_ids, targets, loss_mask=...)``;
            its return value must be a mapping with ``"loss"`` (scalar) and
            ``"logits"`` (class scores on the last axis).
        loader: Iterable of batches, each a mapping holding ``"input_ids"``,
            ``"targets"`` and ``"loss_mask"`` tensors.
        device: Device every batch tensor is moved to before the forward pass.

    Returns:
        A dict with:
          * ``"loss"``: unweighted mean of the per-batch losses;
          * ``"accuracy"``: fraction of masked positions predicted correctly;
          * ``"exact_match"``: fraction of sequences whose masked positions
            are all correct, counted only over sequences that have at least
            one masked position;
          * ``"total_tokens"``: number of masked positions scored.
        Ratios fall back to ``0.0`` when their denominator is zero.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0
    correct = 0
    total = 0
    seq_correct = 0
    seq_total = 0
    for batch in loader:
        input_ids = batch["input_ids"].to(device)
        targets = batch["targets"].to(device)
        loss_mask = batch["loss_mask"].to(device)

        out = model(input_ids, targets, loss_mask=loss_mask)
        total_loss += out["loss"].item()
        n_batches += 1

        preds = out["logits"].argmax(dim=-1)
        m = loss_mask.bool()

        # Token-level accuracy over the masked positions only.
        correct += (preds[m] == targets[m]).sum().item()
        total += m.sum().item()

        # Sequence-level exact match, computed for the whole batch at once
        # instead of looping over rows (one host sync per batch, not per row).
        # Assumes preds, targets and loss_mask share a (batch, ...) shape.
        row_mask = m.flatten(start_dim=1)
        row_hits = (preds == targets).flatten(start_dim=1)
        has_target = row_mask.any(dim=1)
        # A row is perfect when every position is either correct or unmasked.
        all_hit = (row_hits | ~row_mask).all(dim=1)
        seq_total += has_target.sum().item()
        seq_correct += (all_hit & has_target).sum().item()

    return {
        "loss": total_loss / max(n_batches, 1),
        "accuracy": correct / max(total, 1),
        "exact_match": seq_correct / max(seq_total, 1),
        "total_tokens": total,
    }
# Registry of task names to the dataset classes imported from src.fog.data.
# run_experiment() looks its ``task_name`` up here (the values passed via the
# ``--tasks`` CLI option end up as that argument) and rejects unknown names.
TASK_MAP: dict[str, type] = {
    "copy": CopyTask,
    "reverse": ReverseTask,
    "retrieval": SelectiveRetrieval,
    "distractor": DistractorRetrieval,
    "noisy": NoisyRetrieval,
    "multiquery": MultiQueryRetrieval,
    "chained": ChainedRetrieval,
    "conditional": ConditionalRetrieval,
    "intersection": SetIntersection,
    "compose_add": ComposeArithmetic,
    "multihop": MultiHopChained,
}
# Per-task keyword overrides forwarded to the dataset constructor. Tasks that
# are not listed here (e.g. "copy", "reverse") are built with no extra kwargs.
_TASK_EXTRA_KWARGS: dict[str, dict[str, int]] = {
    "chained": {"n_pairs": 6},
    "multihop": {"n_pairs": 10},
    "conditional": {"n_pairs": 6},
    "intersection": {"set_size": 8, "overlap": 3},
    "compose_add": {"n_pairs": 6},
    "distractor": {"n_pairs": 4},
    "noisy": {"n_pairs": 4},
    "multiquery": {"n_pairs": 4},
    "retrieval": {"n_pairs": 4},
}


def run_experiment(
    task_name: str,
    cfg: FOGConfig,
    model_type: str,
    n_epochs: int,
    batch_size: int,
    lr: float,
    device: torch.device,
    seed: int = 42,
    n_train: int = 2000,
    n_eval: int = 500,
) -> dict:
    """Train one model on one task and return its metrics and per-epoch history.

    Args:
        task_name: Key into ``TASK_MAP`` selecting the dataset class.
        cfg: Model/data configuration; ``vocab_size`` and ``max_seq_len`` are
            also passed to the dataset constructor.
        model_type: ``"baseline"`` or ``"uniform_small"`` (both build a
            ``BaselineTransformer``) or ``"motif"`` (``MotifTransformer``).
        n_epochs: Number of train-then-evaluate epochs.
        batch_size: Batch size for both the train and eval iterators.
        lr: AdamW learning rate (weight decay is fixed at 0.01).
        device: Device the model and batches are placed on.
        seed: Value passed to ``torch.manual_seed``. The datasets themselves
            are always built with ``seed=0`` (train) and ``seed=99`` (eval).
        n_train: Number of training examples requested from the task.
        n_eval: Number of evaluation examples requested from the task.

    Returns:
        A JSON-serialisable dict with the run identity (``model_type``,
        ``task``, ``seed``), ``n_params``, ``n_epochs``, ``elapsed_s``, the
        ``final_*`` metrics of the last epoch (``None`` when ``n_epochs`` is
        0) and the full ``history`` list.

    Raises:
        ValueError: If ``task_name`` or ``model_type`` is not recognised.
    """
    torch.manual_seed(seed)

    if task_name not in TASK_MAP:
        raise ValueError(f"Unknown task: {task_name}. Choose from {list(TASK_MAP.keys())}")
    task_cls = TASK_MAP[task_name]

    # Validate the model type up front so a typo fails before any dataset is
    # generated and pre-batched.
    if model_type in ("baseline", "uniform_small"):
        model_cls = BaselineTransformer
    elif model_type == "motif":
        model_cls = MotifTransformer
    else:
        raise ValueError(f"Unknown model: {model_type}")

    # Copy so the shared table can never be mutated through this call.
    extra_kwargs = dict(_TASK_EXTRA_KWARGS.get(task_name, {}))

    train_ds = task_cls(cfg.vocab_size, cfg.max_seq_len, n_train, seed=0, **extra_kwargs)
    eval_ds = task_cls(cfg.vocab_size, cfg.max_seq_len, n_eval, seed=99, **extra_kwargs)

    # Pre-batch into contiguous tensors for speed
    train_data = prebatch_dataset(train_ds, cfg.max_seq_len)
    eval_data = prebatch_dataset(eval_ds, cfg.max_seq_len)
    train_loader = TensorBatchIterator(train_data, batch_size, shuffle=True)
    eval_loader = TensorBatchIterator(eval_data, batch_size, shuffle=False)

    # The model is still instantiated after the datasets, as before, so the
    # order in which the seeded global RNG is consumed does not change.
    model = model_cls(cfg).to(device)
    n_params = count_params(model)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

    history: list[dict] = []
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # system clock adjustments during a long run.
    t0 = time.perf_counter()
    for epoch in range(1, n_epochs + 1):
        train_loss = train_epoch(model, train_loader, optimizer, device)
        metrics = eval_accuracy(model, eval_loader, device)
        history.append({
            "epoch": epoch,
            "train_loss": round(train_loss, 4),
            "eval_loss": round(metrics["loss"], 4),
            "eval_accuracy": round(metrics["accuracy"], 4),
            "eval_exact_match": round(metrics["exact_match"], 4),
        })
        # Log the first epoch and then every tenth one.
        if epoch % 10 == 0 or epoch == 1:
            print(f" [{model_type}/{task_name}] epoch {epoch:>3d} "
                  f"train={train_loss:.4f} eval={metrics['loss']:.4f} "
                  f"acc={metrics['accuracy']:.4f} em={metrics['exact_match']:.4f}")
    elapsed = time.perf_counter() - t0

    # With n_epochs == 0 there is no history; the final_* fields become None.
    final = history[-1] if history else {}
    return {
        "model_type": model_type,
        "task": task_name,
        "seed": seed,
        "n_params": n_params,
        "n_epochs": n_epochs,
        "elapsed_s": round(elapsed, 1),
        "final_train_loss": final.get("train_loss"),
        "final_eval_loss": final.get("eval_loss"),
        "final_accuracy": final.get("eval_accuracy"),
        "final_exact_match": final.get("eval_exact_match"),
        "history": history,
    }
def _fmt_metric(value: float | None) -> str:
    """Format a metric with four decimals, or ``"n/a"`` when it is missing."""
    return "n/a" if value is None else f"{value:.4f}"


def main() -> None:
    """Command-line entry point: run the ablation grid and save the results.

    For every requested task and seed, each model variant of the chosen size
    is trained via ``run_experiment``. A per-run line and a final summary
    table are printed, and the list of result dicts is written as JSON to
    ``--output`` (parent directories are created as needed).

    Metrics that are ``None`` (which ``run_experiment`` returns when
    ``--epochs`` is 0) are printed as ``n/a`` instead of raising a formatting
    error.
    """
    parser = argparse.ArgumentParser(description="FOG Ablation: baseline vs motif-aware")
    parser.add_argument("--tasks", nargs="+", default=["copy", "reverse", "retrieval"])
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--size", type=str, default="med",
                        choices=["micro", "tiny", "med", "small"])
    parser.add_argument("--seeds", type=int, nargs="+", default=[42])
    parser.add_argument("--n_train", type=int, default=2000)
    parser.add_argument("--n_eval", type=int, default=500)
    parser.add_argument("--output", type=str, default="archive/fog_ablation.json")
    args = parser.parse_args()

    device = torch.device(args.device)

    # Model variants compared at each size. argparse's ``choices`` guarantees
    # that args.size is one of these keys. "small" has no uniform variant.
    size_configs = {
        "micro": [
            ("baseline", BASELINE_MICRO),
            ("uniform_small", UNIFORM_MICRO),
            ("motif", MOTIF_MICRO),
        ],
        "tiny": [
            ("baseline", BASELINE_TINY),
            ("uniform_small", UNIFORM_TINY),
            ("motif", MOTIF_TINY),
        ],
        "med": [
            ("baseline", BASELINE_MED),
            ("uniform_small", UNIFORM_MED),
            ("motif", MOTIF_MED),
        ],
        "small": [("baseline", BASELINE_SMALL), ("motif", MOTIF_SMALL)],
    }
    configs = size_configs[args.size]

    results = []
    for task in args.tasks:
        for seed in args.seeds:
            print(f"\n{'='*60}")
            print(f" Task: {task} (size={args.size}, seed={seed})")
            print(f"{'='*60}")
            for model_type, cfg in configs:
                result = run_experiment(
                    task_name=task,
                    cfg=cfg,
                    model_type=model_type,
                    n_epochs=args.epochs,
                    batch_size=args.batch_size,
                    lr=args.lr,
                    device=device,
                    seed=seed,
                    n_train=args.n_train,
                    n_eval=args.n_eval,
                )
                results.append(result)
                print(f" -> {model_type}: params={result['n_params']:,} "
                      f"acc={_fmt_metric(result['final_accuracy'])} "
                      f"em={_fmt_metric(result['final_exact_match'])} "
                      f"time={result['elapsed_s']}s")

    # Summary
    print(f"\n{'='*60}")
    print(" SUMMARY")
    print(f"{'='*60}")
    print(f"{'Task':<12} {'Model':<15} {'Params':>8} {'Loss':>8} {'Acc':>8} {'EM':>8} {'Time':>6}")
    print("-" * 70)
    for r in results:
        # All three metrics go through the same None-safe formatter so the
        # table renders even for runs without any completed epoch.
        loss_txt = _fmt_metric(r["final_eval_loss"])
        acc_txt = _fmt_metric(r["final_accuracy"])
        em_txt = _fmt_metric(r.get("final_exact_match"))
        print(f"{r['task']:<12} {r['model_type']:<15} {r['n_params']:>8,} "
              f"{loss_txt:>8} {acc_txt:>8} "
              f"{em_txt:>8} {r['elapsed_s']:>5.0f}s")

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(results, indent=2), encoding="utf-8")
    print(f"\nSaved: {out_path}")


if __name__ == "__main__":
    main()