File size: 5,307 Bytes
3546a09 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 | """
train_pilot.py — toy MQAR (Multi-Query Associative Recall) training script.
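Generates synthetic sequences of the form:
[k1] [v1] [k2] [v2] ... [kN] [vN] [QUERY] [k_i] [v_i]
where [QUERY] is a reserved marker token and k_i is one of the N keys. The
single supervised target is v_i, the value paired with k_i. Despite the
"multi-query" name, each sequence carries exactly one query.
The model must recall the value associated with a queried key seen earlier
in the same sequence. Length-OOD generalization is tested by training on
short sequences (few key-value pairs) and evaluating on longer ones.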
This is a SMOKE-LEVEL pilot, intended to verify the implementation runs and
to establish a baseline-vs-SMW signal for the eventual full pilot. It is
NOT the 350M / 1.3B pre-registered run.
Usage:
python scripts/train_pilot.py --condition C5_smw --steps 1000
python scripts/train_pilot.py --condition C0_baseline --steps 1000
python scripts/train_pilot.py --condition C4_all_independent --steps 1000
"""
import argparse
import math
import time
import torch
import torch.nn.functional as F
from smw import SMWModel, CONDITIONS
# ---- synthetic MQAR data ----
VOCAB_SIZE = 64
# Token ids 0 and 1 are reserved (SEP_TOKEN, QUERY_TOKEN), so keys start at 2.
# Otherwise a sampled key could equal the query marker and make the query
# position ambiguous.
KEY_RANGE = (2, 16)  # [lo, hi) token ids usable as keys
VAL_RANGE = (16, 64)  # [lo, hi) token ids usable as values
SEP_TOKEN = 0  # not used as a key/val
QUERY_TOKEN = 1  # special "go retrieve" marker
def make_mqar_batch(batch_size: int, n_pairs: int, device="cpu"):
    """
    Build one batch of associative-recall sequences.

    Each row is laid out as

        k1 v1 k2 v2 ... kN vN QUERY_TOKEN k_q v_q

    Keys are distinct ids drawn from KEY_RANGE. Values are drawn
    independently (so they may repeat) from VAL_RANGE. k_q is one of the
    row's own keys, and v_q is the value paired with it. Despite the
    "multi-query" name, each row carries exactly one query.

    Args:
        batch_size: number of rows in the batch.
        n_pairs: number of key-value pairs N per row. Must be between 1 and
            the number of available keys (KEY_RANGE[1] - KEY_RANGE[0]).
        device: device the returned tensors are placed on.

    Returns:
        (idx, targets), both int64 with shape (batch_size, 2 * n_pairs + 3).
        targets is -100 (ignore) everywhere except the last column, which
        holds v_q. targets is aligned with idx, not shifted: idx also holds
        v_q in its last column. NOTE(review): the training loop passes these
        straight to the model, and evaluate() reads logits at ans_pos - 1.
        Presumably SMWModel.forward shifts targets internally; confirm this,
        otherwise the answer leaks into the input during training.

    Raises:
        ValueError: if n_pairs is outside [1, number of available keys].
    """
    n_keys = KEY_RANGE[1] - KEY_RANGE[0]
    if not 1 <= n_pairs <= n_keys:
        raise ValueError(f"n_pairs must be in [1, {n_keys}], got {n_pairs}")

    seq_len = 2 * n_pairs + 2  # pairs + [query_marker, query_key]
    ans_pos = seq_len  # the answer slot is the final column

    # Assemble on CPU, where the RNG draws below already happen, and move to
    # `device` once at the end. This avoids one tiny device write per element.
    idx = torch.zeros(batch_size, seq_len + 1, dtype=torch.long)
    targets = torch.full((batch_size, seq_len + 1), -100, dtype=torch.long)
    for b in range(batch_size):
        keys = torch.randperm(n_keys)[:n_pairs] + KEY_RANGE[0]
        vals = torch.randint(VAL_RANGE[0], VAL_RANGE[1], (n_pairs,))
        # Interleave: even columns hold keys, odd columns hold values.
        idx[b, 0 : 2 * n_pairs : 2] = keys
        idx[b, 1 : 2 * n_pairs : 2] = vals
        # Query marker followed by one of this row's keys.
        q_i = torch.randint(0, n_pairs, (1,)).item()
        idx[b, 2 * n_pairs] = QUERY_TOKEN
        idx[b, 2 * n_pairs + 1] = keys[q_i]
        # Answer slot: the value paired with the queried key.
        idx[b, ans_pos] = vals[q_i]
        targets[b, ans_pos] = vals[q_i]
    return idx.to(device), targets.to(device)
def evaluate(model, n_pairs, n_batches=8, batch_size=16, device="cpu"):
    """
    Measure answer accuracy on freshly sampled MQAR batches.

    Args:
        model: called as model(idx). It must return a 3-tuple whose first
            element is the logits, indexed here as [batch, position, ...]
            with class scores on the last axis.
        n_pairs: key-value pairs per sequence in the sampled batches.
        n_batches: number of batches to sample.
        batch_size: rows per batch.
        device: device the batches are created on.

    Returns:
        Fraction of scored answer positions predicted correctly. Returns 0.0
        if nothing was scored.

    The model's train/eval mode on entry is restored on exit, even if an
    exception is raised. The model is no longer forced into train mode.
    """
    was_training = model.training
    model.eval()
    correct, total = 0, 0
    try:
        with torch.no_grad():
            for _ in range(n_batches):
                idx, targets = make_mqar_batch(batch_size, n_pairs, device=device)
                logits, _, _ = model(idx)
                # The answer is the LAST token of idx. The logits one position
                # earlier are the prediction made from the preceding context.
                ans_pos = idx.size(1) - 1
                preds = logits[:, ans_pos - 1].argmax(-1)
                true = targets[:, ans_pos]
                mask = true != -100
                correct += (preds[mask] == true[mask]).sum().item()
                total += mask.sum().item()
    finally:
        model.train(was_training)
    return correct / max(total, 1)
# ---- training ----
def main():
    """
    Train one SMW condition on synthetic MQAR and log progress.

    Parses CLI arguments and builds the requested condition with toy-sized
    overrides. Trains with AdamW and gradient clipping. Roughly every
    tenth of the run (and at the last step), it prints loss, in-distribution
    and length-OOD accuracy, and the model's slow-band / mask-entropy
    diagnostics.
    """
    import copy

    parser = argparse.ArgumentParser()
    parser.add_argument("--condition", type=str, default="C5_smw", choices=list(CONDITIONS.keys()))
    parser.add_argument("--steps", type=int, default=500)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--n_pairs_train", type=int, default=4)
    parser.add_argument("--n_pairs_eval_id", type=int, default=4)  # in-distribution
    parser.add_argument("--n_pairs_eval_ood", type=int, default=8)  # length-OOD (P1-style probe)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()
    torch.manual_seed(args.seed)

    # Work on a copy so the overrides below don't mutate the shared
    # CONDITIONS entry imported from smw.
    cfg = copy.copy(CONDITIONS[args.condition])
    cfg.vocab_size = VOCAB_SIZE
    # make_mqar_batch emits 2 * n_pairs + 3 tokens per row. Size the context
    # for the largest n_pairs used by ANY phase, not just the OOD eval, so that
    # e.g. --n_pairs_train > --n_pairs_eval_ood does not overflow block_size.
    # The +4 leaves one spare position.
    max_pairs = max(args.n_pairs_train, args.n_pairs_eval_id, args.n_pairs_eval_ood)
    cfg.block_size = max(2 * max_pairs + 4, cfg.block_size)
    cfg.d_model = 128
    cfg.n_layers = 4
    cfg.d_w = 16
    cfg.k_bottleneck = 2

    model = SMWModel(cfg).to(args.device)
    n_params = sum(param.numel() for param in model.parameters())
    print(f"[{cfg.name}] params: {n_params:,} device: {args.device}")
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr)

    log_every = max(args.steps // 10, 1)
    t0 = time.perf_counter()  # monotonic clock for elapsed time
    for step in range(1, args.steps + 1):
        idx, targets = make_mqar_batch(args.batch_size, args.n_pairs_train, device=args.device)
        _, ce_loss, reg_loss = model(idx, targets)
        # reg_loss may be None; it then contributes nothing to the objective.
        loss = ce_loss + (reg_loss if reg_loss is not None else 0.0)
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        if step % log_every == 0 or step == args.steps:
            id_acc = evaluate(model, args.n_pairs_eval_id, device=args.device)
            ood_acc = evaluate(model, args.n_pairs_eval_ood, device=args.device)
            slow_band = model.slow_band_mass()
            mask_H = model.last_mask_entropy()
            print(
                f"step {step:4d} ce={float(ce_loss):.3f} reg={float(reg_loss) if reg_loss is not None else 0.0:.3f}"
                f" id_acc={id_acc:.2f} ood_acc={ood_acc:.2f}"
                f" slow_band={f'{slow_band:.2f}' if slow_band is not None else '-'}"
                f" mask_H={f'{mask_H:.2f}' if mask_H is not None else '-'}"
            )
    print(f"done in {time.perf_counter() - t0:.1f}s")
if __name__ == "__main__":
    # Train only when this file is executed as a script, not when imported.
    main()
|