Spaces:
Sleeping
Sleeping
| """Train and compare baseline vs motif-aware transformer on algorithmic tasks.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import time | |
| from pathlib import Path | |
| import torch | |
| from src.fog.config import ( | |
| FOGConfig, | |
| BASELINE_SMALL, MOTIF_SMALL, | |
| BASELINE_TINY, MOTIF_TINY, UNIFORM_TINY, | |
| BASELINE_MICRO, MOTIF_MICRO, UNIFORM_MICRO, | |
| BASELINE_MED, MOTIF_MED, UNIFORM_MED, | |
| ) | |
| from src.fog.model_baseline import BaselineTransformer | |
| from src.fog.model_motif import MotifTransformer | |
| from src.fog.data import ( | |
| CopyTask, ReverseTask, SelectiveRetrieval, | |
| DistractorRetrieval, NoisyRetrieval, MultiQueryRetrieval, | |
| ChainedRetrieval, | |
| ConditionalRetrieval, SetIntersection, ComposeArithmetic, MultiHopChained, | |
| prebatch_dataset, TensorBatchIterator, | |
| ) | |
def count_params(model: torch.nn.Module) -> int:
    """Return the total number of scalar elements across ``model``'s parameters.

    Every tensor yielded by ``model.parameters()`` is counted, whether or not
    it requires gradients.
    """
    n_elements = 0
    for param in model.parameters():
        n_elements += param.numel()
    return n_elements
def train_epoch(
    model: torch.nn.Module,
    loader: TensorBatchIterator,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> float:
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(input_ids, targets, loss_mask=...)``;
            its result must be a mapping containing a ``"loss"`` tensor.
        loader: Iterable of batch dicts with ``"input_ids"``, ``"targets"``
            and ``"loss_mask"`` tensors.
        optimizer: Optimizer stepped once per batch.
        device: Device each batch tensor is moved to before the forward pass.

    Returns:
        The average of the per-batch loss values, or ``0.0`` when ``loader``
        yields no batches.
    """
    model.train()
    loss_sum = 0.0
    steps = 0
    for steps, batch in enumerate(loader, start=1):
        tokens, labels, mask = (
            batch[key].to(device) for key in ("input_ids", "targets", "loss_mask")
        )
        batch_loss = model(tokens, labels, loss_mask=mask)["loss"]

        optimizer.zero_grad()
        batch_loss.backward()
        # Cap the global gradient norm at 1.0 before applying the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        loss_sum += batch_loss.item()
    # ``steps`` stays 0 for an empty loader; max() avoids dividing by zero.
    return loss_sum / max(steps, 1)
def eval_accuracy(
    model: torch.nn.Module,
    loader: TensorBatchIterator,
    device: torch.device,
) -> dict[str, float]:
    """Evaluate ``model`` on ``loader`` and return loss and accuracy metrics.

    The model is switched to eval mode and the whole pass runs under
    ``torch.no_grad()``: evaluation never calls ``backward()``, so recording
    an autograd graph would only cost memory and time.

    Args:
        model: Module called as ``model(input_ids, targets, loss_mask=...)``;
            its result must be a mapping with ``"loss"`` and ``"logits"``.
        loader: Iterable of batch dicts with ``"input_ids"``, ``"targets"``
            and ``"loss_mask"`` tensors.
        device: Device each batch tensor is moved to before the forward pass.

    Returns:
        A dict with:
          * ``"loss"``: mean of the per-batch losses (0 for an empty loader),
          * ``"accuracy"``: fraction of masked positions predicted correctly,
          * ``"exact_match"``: fraction of sequences (among those with at
            least one masked position) whose masked positions are all correct,
          * ``"total_tokens"``: number of masked positions scored.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    seq_correct = 0
    seq_total = 0
    n_batches = 0
    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            targets = batch["targets"].to(device)
            loss_mask = batch["loss_mask"].to(device)
            out = model(input_ids, targets, loss_mask=loss_mask)
            total_loss += out["loss"].item()
            n_batches += 1

            preds = out["logits"].argmax(dim=-1)
            m = loss_mask.bool()
            # Token level: a hit is a correct prediction at a masked position.
            # Assumes preds, targets and the mask share one (batch, ...) shape.
            hits = (preds == targets) & m
            correct += hits.sum().item()
            total += m.sum().item()

            # Sequence level: a row is an exact match when every masked
            # position is a hit; rows with no masked position are not scored.
            scored = m.flatten(1).any(dim=1)
            perfect = (hits | ~m).flatten(1).all(dim=1)
            seq_total += scored.sum().item()
            seq_correct += (perfect & scored).sum().item()
    return {
        "loss": total_loss / max(n_batches, 1),
        "accuracy": correct / max(total, 1),
        "exact_match": seq_correct / max(seq_total, 1),
        "total_tokens": total,
    }
# Registry of task names to the dataset class that generates them.
# ``run_experiment`` looks task names up here and reports the available keys
# when a name is unknown, so insertion order is what users see in that error.
TASK_MAP: dict[str, type] = {
    "copy": CopyTask,
    "reverse": ReverseTask,
    "retrieval": SelectiveRetrieval,
    "distractor": DistractorRetrieval,
    "noisy": NoisyRetrieval,
    "multiquery": MultiQueryRetrieval,
    "chained": ChainedRetrieval,
    "conditional": ConditionalRetrieval,
    "intersection": SetIntersection,
    "compose_add": ComposeArithmetic,
    "multihop": MultiHopChained,
}
def run_experiment(
    task_name: str,
    cfg: FOGConfig,
    model_type: str,
    n_epochs: int,
    batch_size: int,
    lr: float,
    device: torch.device,
    seed: int = 42,
    n_train: int = 2000,
    n_eval: int = 500,
) -> dict:
    """Train one model on one task and return its metrics and per-epoch history.

    Args:
        task_name: Key into ``TASK_MAP`` selecting the dataset class.
        cfg: Model configuration; ``cfg.vocab_size`` and ``cfg.max_seq_len``
            are also passed to the dataset constructor.
        model_type: ``"baseline"`` or ``"uniform_small"`` build a
            ``BaselineTransformer``; ``"motif"`` builds a ``MotifTransformer``.
        n_epochs: Number of train + eval epochs to run.
        batch_size: Batch size for both the train and eval iterators.
        lr: AdamW learning rate (weight decay is fixed at 0.01).
        device: Device the model and batches are placed on.
        seed: Seed for torch's global RNG. The datasets themselves are built
            with fixed seeds (0 for train, 99 for eval) regardless of this.
        n_train: Number of training examples to generate.
        n_eval: Number of evaluation examples to generate.

    Returns:
        A JSON-serialisable dict with run metadata, the final-epoch metrics
        (``None`` when no epoch was run) and the full ``"history"`` list.

    Raises:
        ValueError: If ``task_name`` or ``model_type`` is not recognised.
    """
    torch.manual_seed(seed)

    if task_name not in TASK_MAP:
        raise ValueError(f"Unknown task: {task_name}. Choose from {list(TASK_MAP.keys())}")
    dataset_cls = TASK_MAP[task_name]

    # Per-task constructor overrides; tasks absent from this table are built
    # with the dataset class's own defaults. "chained" uses n_pairs=6 because
    # it needs enough pairs for chains to form.
    task_overrides = {
        "chained": {"n_pairs": 6},
        "multihop": {"n_pairs": 10},
        "conditional": {"n_pairs": 6},
        "intersection": {"set_size": 8, "overlap": 3},
        "compose_add": {"n_pairs": 6},
        "distractor": {"n_pairs": 4},
        "noisy": {"n_pairs": 4},
        "multiquery": {"n_pairs": 4},
        "retrieval": {"n_pairs": 4},
    }
    ds_kwargs = task_overrides.get(task_name, {})

    train_ds = dataset_cls(cfg.vocab_size, cfg.max_seq_len, n_train, seed=0, **ds_kwargs)
    eval_ds = dataset_cls(cfg.vocab_size, cfg.max_seq_len, n_eval, seed=99, **ds_kwargs)

    # Pre-batch into contiguous tensors for speed
    train_data = prebatch_dataset(train_ds, cfg.max_seq_len)
    eval_data = prebatch_dataset(eval_ds, cfg.max_seq_len)
    train_loader = TensorBatchIterator(train_data, batch_size, shuffle=True)
    eval_loader = TensorBatchIterator(eval_data, batch_size, shuffle=False)

    if model_type == "motif":
        model_cls = MotifTransformer
    elif model_type in ("baseline", "uniform_small"):
        model_cls = BaselineTransformer
    else:
        raise ValueError(f"Unknown model: {model_type}")
    model = model_cls(cfg).to(device)

    n_params = count_params(model)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

    history: list[dict] = []
    started = time.time()
    for epoch in range(1, n_epochs + 1):
        train_loss = train_epoch(model, train_loader, optimizer, device)
        metrics = eval_accuracy(model, eval_loader, device)
        record = {
            "epoch": epoch,
            "train_loss": round(train_loss, 4),
            "eval_loss": round(metrics["loss"], 4),
            "eval_accuracy": round(metrics["accuracy"], 4),
            "eval_exact_match": round(metrics["exact_match"], 4),
        }
        history.append(record)
        # Log the first epoch and then every tenth one.
        if epoch == 1 or epoch % 10 == 0:
            print(f" [{model_type}/{task_name}] epoch {epoch:>3d} "
                  f"train={train_loss:.4f} eval={metrics['loss']:.4f} "
                  f"acc={metrics['accuracy']:.4f} em={metrics['exact_match']:.4f}")
    elapsed = time.time() - started

    # With zero epochs there is no last record; the final_* fields become None.
    last = history[-1] if history else {}
    return {
        "model_type": model_type,
        "task": task_name,
        "seed": seed,
        "n_params": n_params,
        "n_epochs": n_epochs,
        "elapsed_s": round(elapsed, 1),
        "final_train_loss": last.get("train_loss"),
        "final_eval_loss": last.get("eval_loss"),
        "final_accuracy": last.get("eval_accuracy"),
        "final_exact_match": last.get("eval_exact_match"),
        "history": history,
    }
def main() -> None:
    """Command-line entry point: run the ablation grid and save the results.

    For every requested task and seed, each model variant of the chosen size
    is trained via ``run_experiment``. A per-run line and a final summary
    table are printed, and all result dicts are written as JSON to
    ``--output`` (parent directories are created if needed).

    Task names are validated by argparse against ``TASK_MAP`` so that a typo
    is rejected before any training starts, rather than after earlier tasks
    have already been trained and their unsaved results lost.
    """
    parser = argparse.ArgumentParser(description="FOG Ablation: baseline vs motif-aware")
    parser.add_argument("--tasks", nargs="+", default=["copy", "reverse", "retrieval"],
                        choices=list(TASK_MAP))
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--size", type=str, default="med",
                        choices=["micro", "tiny", "med", "small"])
    parser.add_argument("--seeds", type=int, nargs="+", default=[42])
    parser.add_argument("--n_train", type=int, default=2000)
    parser.add_argument("--n_eval", type=int, default=500)
    parser.add_argument("--output", type=str, default="archive/fog_ablation.json")
    args = parser.parse_args()

    device = torch.device(args.device)

    # (model_type, config) pairs per size; "small" has no uniform variant.
    # argparse's ``choices`` guarantees ``args.size`` is one of these keys.
    size_configs = {
        "micro": [
            ("baseline", BASELINE_MICRO),
            ("uniform_small", UNIFORM_MICRO),
            ("motif", MOTIF_MICRO),
        ],
        "tiny": [
            ("baseline", BASELINE_TINY),
            ("uniform_small", UNIFORM_TINY),
            ("motif", MOTIF_TINY),
        ],
        "med": [
            ("baseline", BASELINE_MED),
            ("uniform_small", UNIFORM_MED),
            ("motif", MOTIF_MED),
        ],
        "small": [("baseline", BASELINE_SMALL), ("motif", MOTIF_SMALL)],
    }
    configs = size_configs[args.size]

    def _metric(value: float | None) -> float:
        # Final metrics are None when no epoch ran (e.g. --epochs 0); map
        # them to NaN so the ``:.4f`` format specs below do not raise.
        return float("nan") if value is None else value

    results = []
    for task in args.tasks:
        for seed in args.seeds:
            print(f"\n{'='*60}")
            print(f" Task: {task} (size={args.size}, seed={seed})")
            print(f"{'='*60}")
            for model_type, cfg in configs:
                result = run_experiment(
                    task_name=task,
                    cfg=cfg,
                    model_type=model_type,
                    n_epochs=args.epochs,
                    batch_size=args.batch_size,
                    lr=args.lr,
                    device=device,
                    seed=seed,
                    n_train=args.n_train,
                    n_eval=args.n_eval,
                )
                results.append(result)
                print(f" -> {model_type}: params={result['n_params']:,} "
                      f"acc={_metric(result['final_accuracy']):.4f} "
                      f"em={_metric(result['final_exact_match']):.4f} "
                      f"time={result['elapsed_s']}s")

    # Summary
    print(f"\n{'='*60}")
    print(" SUMMARY")
    print(f"{'='*60}")
    print(f"{'Task':<12} {'Model':<15} {'Params':>8} {'Loss':>8} {'Acc':>8} {'EM':>8} {'Time':>6}")
    print("-" * 70)
    for r in results:
        print(f"{r['task']:<12} {r['model_type']:<15} {r['n_params']:>8,} "
              f"{_metric(r['final_eval_loss']):>8.4f} {_metric(r['final_accuracy']):>8.4f} "
              f"{_metric(r['final_exact_match']):>8.4f} {r['elapsed_s']:>5.0f}s")

    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(results, indent=2), encoding="utf-8")
    print(f"\nSaved: {out_path}")


if __name__ == "__main__":
    main()