File size: 5,577 Bytes
bc7101b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 | """Quick test: does our trained model still get GSM8K answers right when forced
to emit ONLY the final answer (no verbal CoT in y)?
This is the *latent-reasoning value proposition* test: if the K=16 latent
vectors actually carry the reasoning, the model should be able to produce
"#### N" directly without writing out steps.
We don't retrain — we just generate with max_new_tokens small enough that
no verbal CoT can fit. Two flavors (one per run, chosen via --max_new_tokens):
short-32: max_new_tokens=32 (the default), just enough for "#### NUMBER" + a few extras
short-8: max_new_tokens=8, basically forces direct emission
The model was trained to emit a full ~150-token GSM8K answer, so this setting
is far out-of-distribution and we expect accuracy to be low. The point is to
calibrate: is there ANY chance z carries the reasoning, or is the model
hopelessly dependent on writing verbal CoT?
"""
from __future__ import annotations
import argparse
import json
import re
import time
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from ..data import GSM8KDataset, collate_batch
from ..model import BLTConfig, LatentProjector, build_base, generate_with_latent
# Number body shared by both patterns: either comma-grouped thousands
# ("1,234,567") or a plain digit run, optionally signed and optionally
# followed by a decimal part. Comma groups are accepted because models (and
# GSM8K solutions) often write large numbers with thousands separators.
_NUM_BODY = r"-?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?"
# GSM8K-style final-answer marker "#### N"; a "$" before the number is tolerated.
GSM8K_NUM = re.compile(r"####\s*\$?\s*(" + _NUM_BODY + r")")
# Any number anywhere in the text; fallback when no "####" marker is present.
ANY_NUM = re.compile(_NUM_BODY)


def parse_pred(text: str):
    """Extract the predicted final answer from generated text.

    Prefers the number that follows a "####" marker. Otherwise, falls back to
    the last number appearing anywhere in *text*. Thousands separators are
    removed so the result can be fed straight to ``float``.

    Returns:
        The number as a string (e.g. ``"1200"`` for ``"#### 1,200"``), or
        ``None`` when *text* contains no number at all.
    """
    m = GSM8K_NUM.search(text)
    if m:
        raw = m.group(1)
    else:
        nums = ANY_NUM.findall(text)
        if not nums:
            return None
        raw = nums[-1]
    return raw.replace(",", "")
def correct(pred, gold):
    """Return True when *pred* and *gold* denote the same number.

    Before conversion with ``float``, both values have thousands separators
    (",") and a "$" sign stripped. As a result, "1,000", "$1000" and "1000.0"
    all match 1000. Values count as equal when they differ by less than 1e-4.

    Returns False when *pred* is None or either value is not numeric.
    """
    if pred is None:
        return False

    def _as_float(value):
        # Model outputs and gold answers may carry "," separators or "$".
        return float(str(value).replace(",", "").replace("$", "").strip())

    try:
        return abs(_as_float(pred) - _as_float(gold)) < 1e-4
    except ValueError:
        return False
def main():
    """Evaluate a latent-reasoning checkpoint on GSM8K with a tight token budget.

    Steps:
    1. Load the base model from ``--config``, plus a PEFT adapter from
       ``<ckpt>/model`` when ``adapter_config.json`` exists there.
    2. Load the latent projector from ``<ckpt>/projector.pt``.
    3. Greedily generate at most ``--max_new_tokens`` tokens for the first
       ``--n`` GSM8K test problems.
    4. Score each parsed answer against the gold final answer.

    The JSON summary holds accuracy, counts, elapsed time and up to 8 sample
    outputs. It is written to ``--out`` or, by default, to
    ``<ckpt>/short_y_eval_M{max_new_tokens}_block{block_y_to_x}.json``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", required=True)
    p.add_argument("--config", required=True)
    p.add_argument("--n", type=int, default=100)
    p.add_argument("--K", type=int, default=16)
    p.add_argument("--no_block_y_to_x", action="store_true")
    p.add_argument("--max_new_tokens", type=int, default=32)
    p.add_argument("--out", default=None)
    p.add_argument("--batch_size", type=int, default=8)
    args = p.parse_args()

    with open(args.config, encoding="utf-8") as f:
        cfg = json.load(f)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt = Path(args.ckpt)

    # use_lora=False here. The trained adapter saved under <ckpt>/model, if
    # present, is attached with PEFT below instead.
    bcfg = BLTConfig(
        base_model=cfg["base_model"], use_lora=False,
        lora_r=cfg["lora_r"], lora_alpha=cfg["lora_alpha"],
        lora_dropout=cfg["lora_dropout"],
        lora_target_modules=tuple(cfg["lora_target_modules"]),
        K_latents=args.K, block_y_to_x=cfg["block_y_to_x"],
        proj_init_scale=cfg["proj_init_scale"],
        dtype=cfg["dtype"], attn_impl=cfg["attn_impl"],
        gradient_checkpointing=False,
    )
    base_model, tokenizer = build_base(bcfg)

    from peft import PeftModel

    adapter_dir = ckpt / "model"
    if (adapter_dir / "adapter_config.json").exists():
        model = PeftModel.from_pretrained(base_model, str(adapter_dir))
        print(f"[load] adapter from {adapter_dir}")
    else:
        model = base_model
    model.to(device).eval()

    # hidden_size lives on the underlying model's config; unwrap PEFT if present.
    inner = model.get_base_model() if hasattr(model, "get_base_model") else model
    d_model = inner.config.hidden_size
    projector = LatentProjector(
        d_model, init_scale=cfg["proj_init_scale"],
        use_mlp=cfg.get("proj_mlp", False),
        hidden_mult=cfg.get("proj_hidden_mult", 4),
    ).to(device).to(next(model.parameters()).dtype)
    projector.load_state_dict(torch.load(ckpt / "projector.pt", map_location=device))
    projector.eval()

    val_ds = GSM8KDataset(split="test", max_examples=args.n)
    loader = DataLoader(
        val_ds, batch_size=args.batch_size, shuffle=False,
        collate_fn=lambda b: collate_batch(
            b, tokenizer,
            max_prompt_len=cfg["max_prompt_len"],
            max_answer_len=cfg["max_answer_len"],
        ),
    )

    # NOTE(review): the model above is built with cfg["block_y_to_x"], while
    # generation uses this CLI-controlled value. Confirm they are meant to differ.
    block_y_to_x = not args.no_block_y_to_x
    print(f"[mode] block_y_to_x={block_y_to_x} max_new_tokens={args.max_new_tokens} K={args.K}")

    correct_n = 0
    total = 0
    examples = []
    t0 = time.perf_counter()
    # Pure inference: no_grad avoids building an autograd graph through the
    # model and projector while generating.
    with torch.no_grad():
        for batch in loader:
            x_ids = batch.x_ids.to(device)
            x_attn = batch.x_attn.to(device)
            gen = generate_with_latent(
                model, tokenizer, projector,
                x_ids=x_ids, x_attn=x_attn, K=args.K,
                block_y_to_x=block_y_to_x,
                max_new_tokens=args.max_new_tokens,
                temperature=0.0, eos_token_id=tokenizer.eos_token_id,
            )
            for seq, final_str in zip(gen, batch.final_strs):
                text = tokenizer.decode(seq, skip_special_tokens=True)
                pred = parse_pred(text)
                # Tolerate gold strings with or without a space after "####".
                gold = final_str.replace("####", "").strip()
                ok = correct(pred, gold)
                correct_n += int(ok)
                total += 1
                if len(examples) < 8:
                    examples.append({"text": text[:120], "pred": pred, "gold": gold, "ok": ok})
    elapsed = time.perf_counter() - t0

    summary = {
        "ckpt": str(ckpt), "n": args.n, "K": args.K,
        "block_y_to_x": block_y_to_x,
        "max_new_tokens": args.max_new_tokens,
        "acc": correct_n / max(total, 1),
        "correct": correct_n,
        "total": total,
        "elapsed_s": elapsed,
        "examples": examples,
    }
    if args.out:
        out = Path(args.out)
    else:
        out = ckpt / f"short_y_eval_M{args.max_new_tokens}_block{block_y_to_x}.json"
    # A custom --out may point into a directory that does not exist yet.
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    print(f"[done] acc={summary['acc']:.4f} ({correct_n}/{total}) elapsed={summary['elapsed_s']:.0f}s")
    print(f"[written] {out}")
    for e in examples[:5]:
        print(f" text={e['text']!r} pred={e['pred']} gold={e['gold']} ok={e['ok']}")


if __name__ == "__main__":
    main()
|