#!/usr/bin/env python3
"""
Decisive eval: does ap32_sched fix the >1000-token collapse vs ap32_large?
For each checkpoint: free-run the SP-path (bounded memory, rw=128) to GEN_LEN on held-out
prompts, then teacher-force-score that SAME self-generated trajectory -> per-position
KL(full-KV teacher || SP-path student), binned by position. Lower KL at >1000 = drift fixed.
4-bit frozen LLM (deployment config). Reuses train_sched building blocks.
"""
import sys, json, argparse
sys.path.insert(0, "/workspace")
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.cache_utils import DynamicCache
import train_sched as TS
@torch.no_grad()
def score_traj(pooler, llm, embed, q_ids, q_mask, a_ids, S, rw, C, dev, dt, H, bins):
    """Teacher-force-score ``a_ids`` and accumulate KL(teacher || student) per position bin.

    The answer is scored in chunks of ``C`` tokens. For each chunk the student
    input is rebuilt from scratch as
    ``[query | pooled summary | last R answer tokens | current chunk]`` where
    ``R = min(c0, rw)`` and the pooled summary is ``pooler`` applied to the
    embeddings of every answer token older than that recent window. The
    teacher logits come from a single ``TS.teacher_full`` call on the whole
    sequence.

    Each answer position is credited to the bin that contains that position,
    so a chunk that straddles a bin edge is split between the two bins.

    Args:
        pooler: Module mapping float answer embeddings to summary embeddings.
            Assumed to emit ``S`` positions per row (the logit offset below
            relies on it) -- TODO confirm against ``TS.AttnPoolSP``.
        llm: Causal LM called with ``inputs_embeds`` and a ``DynamicCache``.
        embed: Input-embedding module of ``llm``.
        q_ids: Query token ids, shape ``(B, MQ)``.
        q_mask: Query attention mask, shape ``(B, MQ)``.
        a_ids: Answer token ids to score; second dimension is the answer length.
        S: Number of summary positions preceding the recent window.
        rw: Maximum number of most-recent answer tokens kept verbatim.
        C: Chunk length in tokens.
        dev: Device for newly created tensors.
        dt: Dtype the embeddings are cast to before entering ``llm``.
        H: Hidden size, used for the empty pooler input of the first chunk.
        bins: Iterable of ``(lo, hi)`` half-open position ranges.

    Returns:
        Dict mapping each ``(lo, hi)`` bin to ``[sum_kl, n_tokens]``.
    """
    B, MQ = q_ids.shape
    La = a_ids.size(1)
    qlen = q_mask.sum(-1).long()
    # Position ids for the query: running count of real tokens, clamped so
    # leading padding maps to position 0.
    qpos = (q_mask.cumsum(-1) - 1).clamp(min=0).long()
    # NOTE(review): usage below implies q_last holds the teacher logits for the
    # first answer token and oa_logits[:, i] those for answer token i + 1 --
    # confirm against TS.teacher_full.
    q_last, oa_logits = TS.teacher_full(llm, embed, q_ids, q_mask, a_ids, dev, dt)
    acc = {key: [0.0, 0] for key in bins}  # bin -> [sum_kl, n_tokens]

    # Loop-invariant query inputs, computed once instead of once per chunk.
    q_emb = embed(q_ids).to(dt)
    q_cache_pos = torch.arange(MQ, device=dev)

    for c0 in range(0, La, C):
        c1 = min(c0 + C, La)
        cur = c1 - c0
        R = min(c0, rw)  # size of the verbatim recent window
        dl = c0 - R      # number of older tokens that get pooled

        if dl > 0:
            sp = pooler(embed(a_ids[:, :dl]).float())
        else:
            sp = pooler(torch.zeros(B, 0, H, device=dev))

        # A fresh cache per chunk: the second forward below appends to it, so
        # the query prefill cannot be shared across chunks as-is.
        cache = DynamicCache()
        llm(inputs_embeds=q_emb, attention_mask=q_mask, past_key_values=cache,
            position_ids=qpos, use_cache=True, cache_position=q_cache_pos)

        parts = [sp.to(dt)]
        if R > 0:
            parts.append(embed(a_ids[:, c0 - R:c0]).to(dt))
        parts.append(embed(a_ids[:, c0:c1]).to(dt))
        block = torch.cat(parts, 1)
        nb = block.size(1)

        # Block positions continue right after each row's real query length.
        pos = qlen.unsqueeze(1) + torch.arange(nb, device=dev).unsqueeze(0)
        amask = torch.cat([q_mask, torch.ones(B, nb, device=dev)], 1)
        o = llm(inputs_embeds=block, attention_mask=amask, past_key_values=cache,
                position_ids=pos, cache_position=torch.arange(MQ, MQ + nb, device=dev),
                use_cache=True)

        # The logit that predicts the chunk's first token sits on the last
        # position before the chunk: S summary slots + R recent tokens - 1.
        ps = S - 1 + R
        s_logits = o.logits[:, ps:ps + cur, :].float()
        if c0 == 0:
            t_logits = torch.cat([q_last[:, None, :], oa_logits[:, 0:cur - 1, :]], 1).float()
        else:
            t_logits = oa_logits[:, c0 - 1:c1 - 1, :].float()
        kl = TS.kl_td(t_logits, s_logits)  # expected shape (B, cur)

        # Credit every position to the bin it actually falls in, rather than
        # assigning the whole chunk to the bin of its first token.
        for lo, hi in bins:
            start = max(lo, c0)
            stop = min(hi, c1)
            if start >= stop:
                continue
            seg = kl[:, start - c0:stop - c0]
            slot = acc[(lo, hi)]
            slot[0] += seg.sum().item()
            slot[1] += seg.numel()

    del q_last, oa_logits
    return acc
def main():
    """Compare checkpoints by free-running SP-path KL against the full-KV teacher.

    Loads the 4-bit frozen base model once, then for every ``name=path`` entry
    in ``--ckpts`` loads the pooler, rolls out ``--gen_len`` tokens per prompt
    with ``TS.rollout_sp``, scores that same trajectory with ``score_traj`` and
    prints the mean KL per position bin. A final table lists all checkpoints
    side by side; the ``delta`` column is ``second - first`` when exactly two
    checkpoints were given and ``0.0`` otherwise.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--base_model", default="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
    p.add_argument("--ckpts", default="orig=/workspace/ap32_large.pt,sched=/workspace/ap32_sched.pt")
    p.add_argument("--prompts", default="/workspace/prompts.jsonl")
    p.add_argument("--n", type=int, default=24)
    p.add_argument("--B", type=int, default=12)
    p.add_argument("--gen_len", type=int, default=1400)
    p.add_argument("--rw", type=int, default=128)
    p.add_argument("--C", type=int, default=64)
    p.add_argument("--temp", type=float, default=0.7)
    p.add_argument("--max_q_len", type=int, default=160)
    args = p.parse_args()

    dev = torch.device("cuda")
    dt = torch.bfloat16
    tok = AutoTokenizer.from_pretrained(args.base_model)
    tok.padding_side = "left"
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                             bnb_4bit_compute_dtype=dt, bnb_4bit_use_double_quant=True)
    llm = AutoModelForCausalLM.from_pretrained(args.base_model, quantization_config=bnb,
                                               torch_dtype=dt, device_map={"": 0},
                                               attn_implementation="sdpa")
    llm.config.use_cache = True
    for pr in llm.parameters():
        pr.requires_grad_(False)
        # Any parameter still held in fp32 is cast to the compute dtype.
        if pr.dtype == torch.float32:
            pr.data = pr.data.to(dt)
    llm.eval()

    H = llm.config.hidden_size
    embed = llm.get_input_embeddings()
    with torch.no_grad():
        # Mean embedding norm over 512 random token ids; passed to the pooler
        # constructor below.
        ids = torch.randint(0, embed.weight.size(0), (512,), device=dev)
        tn = embed(ids).float().norm(dim=-1).mean().item()

    # held-out prompts = LAST n of the pool (least likely used by the ptr sweep start)
    # Blank lines are skipped so a trailing newline does not crash json.loads.
    with open(args.prompts, encoding="utf-8") as fh:
        allp = [json.loads(line)["q"] for line in fh if line.strip()]
    # NOTE: --n 0 selects the whole pool, because allp[-0:] is allp[0:].
    prompts = allp[-args.n:]

    # granular position bins out to long-context regime
    edges = [0, 768, 1500, 3000, 6000, 12000, 20000, 99999]
    bins = [(edges[i], edges[i + 1]) for i in range(len(edges) - 1)]

    def load_pooler(path):
        """Build a ``TS.AttnPoolSP`` from the checkpoint at *path*; return ``(pooler, S)``."""
        # SECURITY: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        ck = torch.load(path, map_location="cpu", weights_only=False)
        # Hyper-parameters live under "args" or "src_args"; fall back to
        # defaults when neither is present (or both are empty/None).
        a = ck.get("args") or ck.get("src_args") or {}
        S = a.get("n_sp") or 32
        pl = TS.AttnPoolSP(H, S, a.get("heads", 8), a.get("layers", 3), a.get("ffn", 2048), tn).to(dev)
        pl.load_state_dict(ck["pooler"])
        pl.eval()
        return pl, S

    results = {}
    for spec in args.ckpts.split(","):
        # Split on the first "=" only so paths containing "=" stay intact.
        name, path = spec.split("=", 1)
        pooler, S = load_pooler(path)
        tot = {b: [0.0, 0] for b in bins}
        for s in range(0, len(prompts), args.B):
            batch = prompts[s:s + args.B]
            q_ids, q_mask = TS.pad_queries(tok, batch, dev)
            gen = TS.rollout_sp(pooler, llm, embed, q_ids, q_mask, args.gen_len, S,
                                args.rw, args.C, dev, dt, H, args.temp)
            acc = score_traj(pooler, llm, embed, q_ids, q_mask, gen, S,
                             args.rw, args.C, dev, dt, H, bins)
            for b in bins:
                tot[b][0] += acc[b][0]
                tot[b][1] += acc[b][1]
            torch.cuda.empty_cache()
        # Mean KL per bin; bins that received no tokens report NaN.
        results[name] = {b: (tot[b][0] / tot[b][1] if tot[b][1] else float("nan")) for b in bins}
        print(f"[{name}] " + " ".join(f"{lo}-{hi if hi<99999 else 'inf'}:{results[name][(lo,hi)]:.4f}"
                                      for lo, hi in bins), flush=True)

    print("\n=== free-running SP-path KL(full-KV || SP) by position (lower=less drift) ===", flush=True)
    print(f"{'bin':>12} | " + " | ".join(f"{n:>8}" for n in results) + " | delta", flush=True)
    for lo, hi in bins:
        row = [results[n][(lo, hi)] for n in results]
        lab = f"{lo}-{hi if hi<99999 else 'inf'}"
        d = (row[1] - row[0]) if len(row) == 2 else 0.0
        print(f"{lab:>12} | " + " | ".join(f"{v:8.4f}" for v in row) + f" | {d:+.4f}", flush=True)
    print("DONE_EVAL", flush=True)
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|