#!/usr/bin/env python3
"""
Supervised Cortex adapter tuning.

This script trains only the Cortex module parameters, using the same
multiple-choice log-likelihood objective as the benchmark runner. It is
intended as a small, explicit tuning step to run before expecting Cortex to
outperform the base model.
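
Example invocation (the script filename here is illustrative; the flags match
the argparse definition below):

    python tune_cortex.py --tasks hellaswag piqa --n-train 16 \
        --epochs 2 --output checkpoints/cortex_tuned.pt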
"""

import argparse
import os
import random
import sys
import time

import torch

# Ensure parent directory is on path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from benchmark.runner import BenchmarkRunner
from benchmark.tasks import TASK_REGISTRY
from benchmark.tuning import cortex_auxiliary_loss, multiple_choice_loss


def load_examples(task_names, n_per_task, seed):
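    """Return a flat list of (task_name, example) pairs, n_per_task from each task."""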
    examples = []
    for task_name in task_names:
        task_cls = TASK_REGISTRY[task_name]
        task = task_cls() if callable(task_cls) else task_cls
        task_examples = task.load_examples(n=n_per_task, seed=seed)
        examples.extend((task_name, ex) for ex in task_examples)
        print(f"Loaded {len(task_examples)} examples for {task_name}")
    return examples


def main():
    parser = argparse.ArgumentParser(description="Train Cortex modules on benchmark-style MC data")
    parser.add_argument(
        "--model", type=str, default="HuggingFaceTB/SmolLM2-135M",
        help="HuggingFace model ID to tune",
    )
    parser.add_argument(
        "--tasks", nargs="+", default=["hellaswag", "piqa", "arc-easy", "winogrande"],
        help="Tasks to train on",
    )
    parser.add_argument(
        "--n-train", type=int, default=8,
        help="Examples per task for tuning",
    )
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--weight-decay", type=float, default=0.01)
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--device", type=str, default="auto",
        help="Device: cuda, mps, cpu, or auto",
    )
    parser.add_argument(
        "--dtype", type=str, default="float32",
        choices=["float32", "float16", "bfloat16"],
    )
    parser.add_argument(
        "--init-cortex-weights", type=str, default=None,
        help="Optional Cortex weights to resume from",
    )
    parser.add_argument(
        "--output", type=str, default="cortex_tuned.pt",
        help="Path to save tuned Cortex weights",
    )
    parser.add_argument("--log-every", type=int, default=4)
    args = parser.parse_args()

    random.seed(args.seed)
    torch.manual_seed(args.seed)

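    # Build the benchmark runner (which loads the base model), then splice the
    # Cortex modules into it before training.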
    runner = BenchmarkRunner(
        model_name=args.model,
        device=args.device,
        dtype=args.dtype,
        cortex_weights=args.init_cortex_weights,
    )
    runner.inject_cortex()

    model = runner.model
    tokenizer = runner.tokenizer
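    # runner._surgeon is the runner's internal helper that owns the injected
    # Cortex modules; this script reaches into it for the trainable parameters
    # and for checkpoint saving.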
    surgeon = runner._surgeon
    # train() puts the whole model in training mode (dropout and friends),
    # even though only the Cortex parameters receive optimizer updates.
    model.train()

    examples = load_examples(args.tasks, args.n_train, args.seed)
    if not examples:
        raise RuntimeError("No training examples loaded")

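    # Only the parameters the surgeon exposes (the Cortex modules) go into the
    # optimizer; the base model's weights are never updated.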
    trainable_params = list(surgeon.get_trainable_parameters())
    if not trainable_params:
        raise RuntimeError("Surgeon exposed no trainable Cortex parameters")
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=args.lr,
        weight_decay=args.weight_decay,
    )

    print(f"Training on {len(examples)} examples for {args.epochs} epoch(s)")
    start = time.time()

    for epoch in range(args.epochs):
        rng = random.Random(args.seed + epoch)
        rng.shuffle(examples)

        total_loss = 0.0
        correct = 0
        seen = 0
        skipped = 0

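        # Plain per-example updates: one optimizer step per training example
        # (effectively batch size 1), which keeps this tuning pass simple and
        # memory-light.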
        for step, (task_name, example) in enumerate(examples, start=1):
            optimizer.zero_grad(set_to_none=True)

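            # multiple_choice_loss returns (loss, predicted choice index); a
            # None loss signals the example could not be scored, so skip it.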
            loss, pred = multiple_choice_loss(model, tokenizer, example, runner.device)
            if loss is None:
                skipped += 1
                continue

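            # The Cortex modules contribute an auxiliary loss term, added on
            # top of the multiple-choice objective before backprop.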
            aux_loss = cortex_auxiliary_loss(model)
            train_loss = loss + aux_loss
            train_loss.backward()

            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm)

            optimizer.step()

            seen += 1
            total_loss += train_loss.item()
            correct += int(pred == example["gold_idx"])

            if step % args.log_every == 0 or step == len(examples):
                avg_loss = total_loss / max(seen, 1)
                acc = correct / max(seen, 1)
                print(
                    f"epoch={epoch + 1} step={step}/{len(examples)} "
                    f"task={task_name} loss={avg_loss:.4f} acc={acc:.3f}"
                )

        avg_loss = total_loss / max(seen, 1)
        acc = correct / max(seen, 1)
        print(
            f"Epoch {epoch + 1} done: loss={avg_loss:.4f} "
            f"acc={acc:.3f} skipped={skipped}"
        )

    output_dir = os.path.dirname(args.output)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

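    # Only the tuned Cortex module weights are saved; the base model itself is
    # left untouched.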
    surgeon.save_cortex_modules(args.output)
    elapsed = time.time() - start
    print(f"Saved Cortex weights to {args.output} [{elapsed:.1f}s]")


if __name__ == "__main__":
    main()