#!/usr/bin/env python3
"""
Supervised Cortex adapter tuning.

This trains only Cortex module parameters against the same multiple-choice
log-likelihood objective used by the benchmark runner. It is intended as a
small, explicit tuning step before expecting Cortex to outperform the base
model.
"""

import argparse
import os
import random
import sys
import time

import torch

# Ensure parent directory is on path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from benchmark.runner import BenchmarkRunner
from benchmark.tasks import TASK_REGISTRY
from benchmark.tuning import cortex_auxiliary_loss, multiple_choice_loss


def load_examples(task_names, n_per_task, seed):
    """Load training examples for each named task.

    Args:
        task_names: Iterable of keys into ``TASK_REGISTRY``.
        n_per_task: Forwarded as ``n`` to each task's ``load_examples``.
        seed: Forwarded as ``seed`` to each task's ``load_examples``.

    Returns:
        A flat list of ``(task_name, example)`` tuples, in task order.

    Raises:
        KeyError: If a task name is not present in ``TASK_REGISTRY``.
    """
    examples = []
    for task_name in task_names:
        task_cls = TASK_REGISTRY[task_name]
        # Registry entries may be classes/factories or ready-made instances.
        task = task_cls() if callable(task_cls) else task_cls
        task_examples = task.load_examples(n=n_per_task, seed=seed)
        examples.extend((task_name, ex) for ex in task_examples)
        print(f"Loaded {len(task_examples)} examples for {task_name}")
    return examples


def _build_parser():
    """Build the command-line argument parser for the tuning script."""
    parser = argparse.ArgumentParser(
        description="Train Cortex modules on benchmark-style MC data"
    )
    parser.add_argument(
        "--model",
        type=str,
        default="HuggingFaceTB/SmolLM2-135M",
        help="HuggingFace model ID to tune",
    )
    parser.add_argument(
        "--tasks",
        nargs="+",
        default=["hellaswag", "piqa", "arc-easy", "winogrande"],
        help="Tasks to train on",
    )
    parser.add_argument(
        "--n-train",
        type=int,
        default=8,
        help="Examples per task for tuning",
    )
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        help="Device: cuda, mps, cpu, or auto",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float32",
        choices=["float32", "float16", "bfloat16"],
    )
    parser.add_argument(
        "--init-cortex-weights",
        type=str,
        default=None,
        help="Optional Cortex weights to resume from",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="cortex_tuned.pt",
        help="Path to save tuned Cortex weights",
    )
    parser.add_argument(
        "--log-every",
        type=int,
        default=4,
        help="Print running metrics every N steps (<= 0 logs only the final step)",
    )
    return parser


def _run_epoch(model, tokenizer, device, examples, optimizer, trainable_params,
               args, epoch):
    """Run one pass over ``examples``, taking one optimizer step per example.

    Examples for which ``multiple_choice_loss`` returns ``None`` as the loss
    are counted as skipped and do not contribute to the running metrics.

    Args:
        model: Model passed to the loss functions.
        tokenizer: Tokenizer passed to ``multiple_choice_loss``.
        device: Device passed to ``multiple_choice_loss``.
        examples: List of ``(task_name, example)`` tuples, already shuffled.
        optimizer: Optimizer over ``trainable_params``.
        trainable_params: Parameters whose gradients are clipped.
        args: Parsed CLI arguments (reads ``max_grad_norm`` and ``log_every``).
        epoch: Zero-based epoch index, used only for log output.

    Returns:
        ``(avg_loss, acc, skipped)`` where the averages are taken over the
        examples that produced a loss.
    """
    total_loss = 0.0
    correct = 0
    seen = 0
    skipped = 0
    n_steps = len(examples)

    for step, (task_name, example) in enumerate(examples, start=1):
        optimizer.zero_grad(set_to_none=True)
        loss, pred = multiple_choice_loss(model, tokenizer, example, device)

        if loss is None:
            skipped += 1
        else:
            aux_loss = cortex_auxiliary_loss(model)
            train_loss = loss + aux_loss
            train_loss.backward()
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm)
            optimizer.step()

            seen += 1
            total_loss += float(train_loss.detach().cpu())
            correct += int(pred == example["gold_idx"])

        # Logging sits outside the skip branch so a skipped example on a
        # logging boundary (or on the last step) does not drop the log line.
        # The `> 0` guard avoids a ZeroDivisionError for `--log-every 0`.
        periodic = args.log_every > 0 and step % args.log_every == 0
        if periodic or step == n_steps:
            avg_loss = total_loss / max(seen, 1)
            acc = correct / max(seen, 1)
            print(
                f"epoch={epoch + 1} step={step}/{n_steps} "
                f"task={task_name} loss={avg_loss:.4f} acc={acc:.3f}"
            )

    # max(seen, 1) keeps the averages defined when every example was skipped.
    avg_loss = total_loss / max(seen, 1)
    acc = correct / max(seen, 1)
    return avg_loss, acc, skipped


def main():
    """Parse CLI arguments, tune the Cortex modules, and save their weights.

    Raises:
        RuntimeError: If no training examples were loaded, or if the surgeon
            reports no trainable parameters.
    """
    args = _build_parser().parse_args()

    random.seed(args.seed)
    torch.manual_seed(args.seed)

    runner = BenchmarkRunner(
        model_name=args.model,
        device=args.device,
        dtype=args.dtype,
        cortex_weights=args.init_cortex_weights,
    )
    runner.inject_cortex()

    model = runner.model
    tokenizer = runner.tokenizer
    surgeon = runner._surgeon
    model.train()

    examples = load_examples(args.tasks, args.n_train, args.seed)
    if not examples:
        raise RuntimeError("No training examples loaded")

    trainable_params = list(surgeon.get_trainable_parameters())
    if not trainable_params:
        # Fail with a specific message instead of the optimizer's generic
        # complaint about an empty parameter list.
        raise RuntimeError("No trainable Cortex parameters found")

    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=args.lr,
        weight_decay=args.weight_decay,
    )

    print(f"Training on {len(examples)} examples for {args.epochs} epoch(s)")
    start = time.time()

    for epoch in range(args.epochs):
        # Per-epoch RNG: the shuffle order depends only on seed and epoch.
        rng = random.Random(args.seed + epoch)
        rng.shuffle(examples)

        avg_loss, acc, skipped = _run_epoch(
            model,
            tokenizer,
            runner.device,
            examples,
            optimizer,
            trainable_params,
            args,
            epoch,
        )
        print(
            f"Epoch {epoch + 1} done: loss={avg_loss:.4f} "
            f"acc={acc:.3f} skipped={skipped}"
        )

    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    surgeon.save_cortex_modules(args.output)

    elapsed = time.time() - start
    print(f"Saved Cortex weights to {args.output} [{elapsed:.1f}s]")


if __name__ == "__main__":
    main()