#!/usr/bin/env python3
"""
Decisive eval: does ap32_sched fix the >1000-token collapse vs ap32_large?

For each checkpoint: free-run the SP-path (bounded memory, rw=128) to GEN_LEN on
held-out prompts, then teacher-force-score that SAME self-generated trajectory ->
per-position KL(full-KV teacher || SP-path student), binned by answer position.
Lower KL at >1000 tokens = drift fixed.

4-bit frozen LLM (deployment config). Reuses train_sched building blocks.
"""
import argparse
import json
import sys

sys.path.insert(0, "/workspace")

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.cache_utils import DynamicCache

import train_sched as TS

# Answer-position bin edges (half-open bins [lo, hi)); granular out to the
# long-context regime. The last edge is an open-end sentinel printed as "inf".
BIN_EDGES = (0, 768, 1500, 3000, 6000, 12000, 20000, 99999)
_OPEN_END = BIN_EDGES[-1]


def _bin_label(lo, hi):
    """Render bin (lo, hi) as 'lo-hi', showing the open-ended last bin as 'lo-inf'."""
    return f"{lo}-{hi if hi < _OPEN_END else 'inf'}"


def _accumulate_by_position(acc, kl, c0, bins):
    """Add kl's columns (answer positions c0, c0+1, ...) into the bins that contain them.

    acc maps each bin to a mutable [sum_kl, n_tokens] pair and is updated in place.
    A chunk straddling a bin edge is split so every token lands in its own bin.
    """
    c1 = c0 + kl.size(1)
    for lo, hi in bins:
        start, stop = max(lo, c0), min(hi, c1)
        if start < stop:
            seg = kl[:, start - c0:stop - c0]
            acc[(lo, hi)][0] += seg.sum().item()
            acc[(lo, hi)][1] += seg.numel()


@torch.no_grad()
def score_traj(pooler, llm, embed, q_ids, q_mask, a_ids, S, rw, C, dev, dt, H, bins):
    """Teacher-force a_ids through the SP path and bin per-position KL(teacher || student).

    The answer is scored in chunks of C tokens. For a chunk starting at answer
    position c0 the student sees: the query (fresh KV prefill), the pooler's
    summary of a_ids[:, :c0 - R], the R = min(c0, rw) most recent raw answer
    tokens, and the chunk itself. Its chunk logits are compared with the full-KV
    teacher logits returned by TS.teacher_full.

    Args:
        pooler: soft-prompt pooler; assumed to emit S vectors per row (the
            student-logit offset S - 1 + R below relies on this).
        llm, embed: frozen causal LM and its input-embedding module.
        q_ids, q_mask: query ids / attention mask, shape (B, MQ).
        a_ids: answer ids (here the self-generated trajectory), shape (B, La).
        S: number of soft-prompt slots. rw: raw recent-window length. C: chunk size.
        dev, dt: device and compute dtype. H: hidden size (used for the empty
            pooler input when nothing is old enough to compress).
        bins: list of half-open (lo, hi) answer-position ranges.

    Returns:
        dict mapping each bin to [sum_kl, n_tokens]; every scored token is counted
        in the bin containing its own answer position.
    """
    B, MQ = q_ids.shape
    La = a_ids.size(1)
    qlen = q_mask.sum(-1).long()
    # Positions for (left-)padded queries: 0 at the first real token, pads clamped to 0.
    qpos = (q_mask.cumsum(-1) - 1).clamp(min=0).long()
    q_last, oa_logits = TS.teacher_full(llm, embed, q_ids, q_mask, a_ids, dev, dt)
    acc = {b: [0.0, 0] for b in bins}  # (sum_kl, n_tokens)

    # Loop-invariant query inputs. The prefill itself must be redone per chunk
    # because the chunk forward appends to the DynamicCache in place.
    q_emb = embed(q_ids).to(dt)
    q_cache_pos = torch.arange(MQ, device=dev)

    for c0 in range(0, La, C):
        c1 = min(c0 + C, La)
        cur = c1 - c0
        R = min(c0, rw)  # most recent tokens kept verbatim
        dl = c0 - R      # older prefix compressed by the pooler
        if dl > 0:
            sp = pooler(embed(a_ids[:, :dl]).float())
        else:
            sp = pooler(torch.zeros(B, 0, H, device=dev))

        cache = DynamicCache()
        llm(inputs_embeds=q_emb, attention_mask=q_mask, past_key_values=cache,
            position_ids=qpos, use_cache=True, cache_position=q_cache_pos)

        parts = [sp.to(dt)]
        if R > 0:
            parts.append(embed(a_ids[:, c0 - R:c0]).to(dt))
        parts.append(embed(a_ids[:, c0:c1]).to(dt))
        block = torch.cat(parts, 1)
        nb = block.size(1)
        pos = qlen.unsqueeze(1) + torch.arange(nb, device=dev).unsqueeze(0)
        # Build the extension with q_mask's dtype so the concatenated mask is homogeneous.
        amask = torch.cat([q_mask, torch.ones(B, nb, device=dev, dtype=q_mask.dtype)], 1)
        o = llm(inputs_embeds=block, attention_mask=amask, past_key_values=cache,
                position_ids=pos, cache_position=torch.arange(MQ, MQ + nb, device=dev),
                use_cache=True)

        # Block index S - 1 + R is the last context slot before the chunk; its
        # logit predicts a_ids[:, c0], and the following cur slots cover the chunk.
        ps = S - 1 + R
        s_logits = o.logits[:, ps:ps + cur, :].float()
        if c0 == 0:
            # presumably q_last = teacher logits at the final query position,
            # i.e. the prediction of the first answer token
            t_logits = torch.cat([q_last[:, None, :], oa_logits[:, 0:cur - 1, :]], 1).float()
        else:
            t_logits = oa_logits[:, c0 - 1:c1 - 1, :].float()
        kl = TS.kl_td(t_logits, s_logits)  # (B, cur)
        _accumulate_by_position(acc, kl, c0, bins)
    return acc


def _load_frozen_llm(base_model, dt):
    """Load the 4-bit NF4 base model on GPU 0, frozen, with float32 params cast to dt."""
    bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                             bnb_4bit_compute_dtype=dt, bnb_4bit_use_double_quant=True)
    llm = AutoModelForCausalLM.from_pretrained(base_model, quantization_config=bnb,
                                               torch_dtype=dt, device_map={"": 0},
                                               attn_implementation="sdpa")
    llm.config.use_cache = True
    for pr in llm.parameters():
        pr.requires_grad_(False)
        if pr.dtype == torch.float32:
            pr.data = pr.data.to(dt)
    llm.eval()
    return llm


def _load_prompts(path, n):
    """Return the 'q' field of the LAST n records of a JSONL file (blank lines skipped)."""
    with open(path, encoding="utf-8") as f:
        allp = [json.loads(line)["q"] for line in f if line.strip()]
    # Explicit start index: allp[-n:] would wrongly return everything when n == 0.
    return allp[max(len(allp) - n, 0):]


def _load_pooler(path, H, tn, dev):
    """Rebuild a TS.AttnPoolSP from a checkpoint and return (pooler, S)."""
    # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
    ck = torch.load(path, map_location="cpu", weights_only=False)
    a = ck.get("args", {}) or ck.get("src_args", {})
    S = a.get("n_sp") or 32
    pl = TS.AttnPoolSP(H, S, a.get("heads", 8), a.get("layers", 3), a.get("ffn", 2048), tn).to(dev)
    pl.load_state_dict(ck["pooler"])
    pl.eval()
    return pl, S


def _print_table(results, bins):
    """Print per-bin mean KL for every checkpoint; delta = 2nd minus 1st (only with 2 ckpts)."""
    print("\n=== free-running SP-path KL(full-KV || SP) by position (lower=less drift) ===", flush=True)
    print(f"{'bin':>12} | " + " | ".join(f"{n:>8}" for n in results) + " | delta", flush=True)
    for lo, hi in bins:
        row = [results[n][(lo, hi)] for n in results]
        d = (row[1] - row[0]) if len(row) == 2 else 0.0
        print(f"{_bin_label(lo, hi):>12} | " + " | ".join(f"{v:8.4f}" for v in row)
              + f" | {d:+.4f}", flush=True)


def main():
    """Parse CLI args, score every checkpoint in --ckpts, and print the KL-by-position table."""
    p = argparse.ArgumentParser()
    p.add_argument("--base_model", default="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
    p.add_argument("--ckpts", default="orig=/workspace/ap32_large.pt,sched=/workspace/ap32_sched.pt")
    p.add_argument("--prompts", default="/workspace/prompts.jsonl")
    p.add_argument("--n", type=int, default=24)
    p.add_argument("--B", type=int, default=12)
    p.add_argument("--gen_len", type=int, default=1400)
    p.add_argument("--rw", type=int, default=128)
    p.add_argument("--C", type=int, default=64)
    p.add_argument("--temp", type=float, default=0.7)
    # NOTE(review): --max_q_len is parsed but not used anywhere in this script;
    # TODO confirm whether it should be forwarded to TS.pad_queries.
    p.add_argument("--max_q_len", type=int, default=160)
    p.add_argument("--seed", type=int, default=0)
    args = p.parse_args()

    dev = torch.device("cuda")
    dt = torch.bfloat16
    torch.manual_seed(args.seed)

    tok = AutoTokenizer.from_pretrained(args.base_model)
    tok.padding_side = "left"
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    llm = _load_frozen_llm(args.base_model, dt)
    H = llm.config.hidden_size
    embed = llm.get_input_embeddings()
    with torch.no_grad():
        # Mean L2 norm of 512 random token embeddings, passed to AttnPoolSP as `tn`.
        ids = torch.randint(0, embed.weight.size(0), (512,), device=dev)
        tn = embed(ids).float().norm(dim=-1).mean().item()

    # held-out prompts = LAST n of the pool (least likely used by the ptr sweep start)
    prompts = _load_prompts(args.prompts, args.n)
    bins = list(zip(BIN_EDGES[:-1], BIN_EDGES[1:]))

    results = {}
    for spec in args.ckpts.split(","):
        spec = spec.strip()
        if not spec:
            continue  # tolerate trailing / doubled commas
        name, path = spec.split("=", 1)  # maxsplit=1: paths may contain '='
        pooler, S = _load_pooler(path, H, tn, dev)
        # Re-seed so every checkpoint starts sampling from the same RNG state
        # (assuming TS.rollout_sp samples via torch's global RNG — confirm).
        torch.manual_seed(args.seed)
        tot = {b: [0.0, 0] for b in bins}
        for s in range(0, len(prompts), args.B):
            batch = prompts[s:s + args.B]
            q_ids, q_mask = TS.pad_queries(tok, batch, dev)
            gen = TS.rollout_sp(pooler, llm, embed, q_ids, q_mask, args.gen_len, S,
                                args.rw, args.C, dev, dt, H, args.temp)
            acc = score_traj(pooler, llm, embed, q_ids, q_mask, gen, S, args.rw,
                             args.C, dev, dt, H, bins)
            for b, (kl_sum, n_tok) in acc.items():
                tot[b][0] += kl_sum
                tot[b][1] += n_tok
            torch.cuda.empty_cache()
        results[name] = {b: (tot[b][0] / tot[b][1] if tot[b][1] else float("nan")) for b in bins}
        print(f"[{name}] " + " ".join(f"{_bin_label(lo, hi)}:{results[name][(lo, hi)]:.4f}"
                                      for lo, hi in bins), flush=True)

    _print_table(results, bins)
    print("DONE_EVAL", flush=True)


if __name__ == "__main__":
    main()