| |
"""
Decisive eval: does ap32_sched fix the >1000-token collapse vs ap32_large?

For each checkpoint: free-run the SP-path (bounded memory, rw=128) to GEN_LEN (--gen_len) on
held-out prompts, then teacher-force-score that SAME self-generated trajectory -> per-position
KL(full-KV teacher || SP-path student), binned by position. Lower KL in the bins past position
1000 means the drift is fixed.
4-bit frozen LLM (deployment config). Reuses train_sched building blocks.
"""
| import sys, json, argparse |
| sys.path.insert(0, "/workspace") |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
| from transformers.cache_utils import DynamicCache |
| import train_sched as TS |
|
|
|
|
@torch.no_grad()
def score_traj(pooler, llm, embed, q_ids, q_mask, a_ids, S, rw, C, dev, dt, H, bins):
    """Teacher-force ``a_ids`` through the SP path and total KL(teacher || student) per bin.

    The answer is walked in chunks of ``C`` tokens. For every chunk the student
    sees: the prompt (prefilled into a fresh KV cache), the pooler output for
    all answer tokens older than the recent window, up to ``rw`` raw recent
    tokens, and the chunk itself. The student logits that predict the chunk's
    tokens are compared against the matching slice of the teacher logits
    returned by ``TS.teacher_full``.

    Args:
        pooler: module called on float embeddings of the out-of-window tokens.
        llm: causal LM called with ``inputs_embeds`` and a ``DynamicCache``.
        embed: the LM's token-embedding module.
        q_ids, q_mask: prompt ids and mask; ``q_ids`` is ``(B, MQ)``.
        a_ids: answer token ids to score; dim 1 is the answer length.
        S: offset used to locate the chunk in the student logits -- presumably
            the number of vectors the pooler emits (TODO confirm in train_sched).
        rw: maximum number of raw recent tokens placed before each chunk.
        C: chunk length in tokens.
        dev, dt: device and compute dtype for LM inputs.
        H: hidden size, used to build the empty pooler input for early chunks.
        bins: iterable of hashable ``(lo, hi)`` pairs.

    Returns:
        dict mapping each bin to ``[kl_sum, kl_count]``. A whole chunk is
        credited to the bin(s) whose ``lo <= chunk_start < hi``, i.e. binning
        is by chunk start, not by each individual token position.
    """
    batch, q_width = q_ids.shape
    ans_len = a_ids.size(1)
    q_lengths = q_mask.sum(-1).long()
    # Position ids that count only unmasked prompt tokens (masked slots clamp to 0).
    q_positions = (q_mask.cumsum(-1) - 1).clamp(min=0).long()
    q_last, oa_logits = TS.teacher_full(llm, embed, q_ids, q_mask, a_ids, dev, dt)
    totals = {key: [0.0, 0] for key in bins}

    # Ceil-division: every chunk index below has at least one token in it.
    n_chunks = -(-ans_len // C)
    for chunk_idx in range(n_chunks):
        start = chunk_idx * C
        stop = min(start + C, ans_len)
        width = stop - start
        recent = min(start, rw)
        pooled_len = start - recent

        # Everything older than the recent window is summarised by the pooler;
        # with nothing to summarise it is still called, on an empty sequence.
        if pooled_len > 0:
            pooler_in = embed(a_ids[:, :pooled_len]).float()
        else:
            pooler_in = torch.zeros(batch, 0, H, device=dev)
        sp = pooler(pooler_in)

        # Prefill the prompt into a fresh cache for this chunk.
        cache = DynamicCache()
        llm(
            inputs_embeds=embed(q_ids).to(dt),
            attention_mask=q_mask,
            past_key_values=cache,
            position_ids=q_positions,
            use_cache=True,
            cache_position=torch.arange(q_width, device=dev),
        )

        # Student input after the prompt: [pooled summary | recent window | chunk].
        pieces = [sp.to(dt)]
        if recent > 0:
            pieces.append(embed(a_ids[:, start - recent:start]).to(dt))
        pieces.append(embed(a_ids[:, start:stop]).to(dt))
        block = torch.cat(pieces, 1)
        block_len = block.size(1)

        offsets = torch.arange(block_len, device=dev)
        positions = q_lengths.unsqueeze(1) + offsets.unsqueeze(0)
        attn = torch.cat([q_mask, torch.ones(batch, block_len, device=dev)], 1)
        out = llm(
            inputs_embeds=block,
            attention_mask=attn,
            past_key_values=cache,
            position_ids=positions,
            cache_position=torch.arange(q_width, q_width + block_len, device=dev),
            use_cache=True,
        )

        # Logit at index S - 1 + recent is the one emitted just before the
        # chunk's first token, so this slice predicts the chunk's tokens.
        first = S - 1 + recent
        s_logits = out.logits[:, first:first + width, :].float()

        # Teacher logits shifted by one; the very first answer token is paired
        # with q_last (presumably the teacher logits at the last prompt token).
        if start == 0:
            t_logits = torch.cat([q_last[:, None, :], oa_logits[:, :width - 1, :]], 1).float()
        else:
            t_logits = oa_logits[:, start - 1:stop - 1, :].float()

        kl = TS.kl_td(t_logits, s_logits)
        for key in bins:
            lo, hi = key
            if lo <= start < hi:
                slot = totals[key]
                slot[0] += kl.sum().item()
                slot[1] += kl.numel()

    # Drop the full teacher logits before returning to the caller.
    del q_last, oa_logits
    return totals
|
|
|
|
def main():
    """Evaluate each checkpoint in ``--ckpts`` and print KL-by-position tables.

    Steps:
      1. Parse CLI arguments and validate the ``name=path`` checkpoint list
         up front, before any model is loaded.
      2. Load the tokenizer and the base LM in 4-bit NF4 on CUDA device 0,
         freeze all parameters and cast any float32 parameter to bfloat16.
      3. Estimate the mean embedding norm from 512 random token ids; it is
         passed to ``TS.AttnPoolSP`` as its last constructor argument.
      4. Read the ``"q"`` field of every non-blank JSONL line in ``--prompts``
         and keep the last ``--n`` prompts.
      5. For every checkpoint and every batch of ``--B`` prompts: generate
         with ``TS.rollout_sp``, score that trajectory with ``score_traj``,
         and accumulate the per-bin KL sums and counts.
      6. Print one line per checkpoint, then a summary table; the ``delta``
         column is ``second - first`` when exactly two checkpoints were given
         and ``0.0`` otherwise.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--base_model", default="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
    p.add_argument("--ckpts", default="orig=/workspace/ap32_large.pt,sched=/workspace/ap32_sched.pt")
    p.add_argument("--prompts", default="/workspace/prompts.jsonl")
    p.add_argument("--n", type=int, default=24)
    p.add_argument("--B", type=int, default=12)
    p.add_argument("--gen_len", type=int, default=1400)
    p.add_argument("--rw", type=int, default=128)
    p.add_argument("--C", type=int, default=64)
    p.add_argument("--temp", type=float, default=0.7)
    p.add_argument("--max_q_len", type=int, default=160)
    args = p.parse_args()

    # Validate the checkpoint list now so a typo fails fast instead of after
    # the model load. partition() keeps any '=' inside the path intact.
    ckpt_specs = []
    for spec in args.ckpts.split(","):
        spec = spec.strip()
        if not spec:
            continue
        name, sep, path = spec.partition("=")
        if not sep or not name or not path:
            p.error(f"--ckpts entries must look like name=path, got {spec!r}")
        ckpt_specs.append((name, path))
    if not ckpt_specs:
        p.error("--ckpts must contain at least one name=path entry")

    dev = torch.device("cuda")
    dt = torch.bfloat16
    tok = AutoTokenizer.from_pretrained(args.base_model)
    tok.padding_side = "left"
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                             bnb_4bit_compute_dtype=dt, bnb_4bit_use_double_quant=True)
    llm = AutoModelForCausalLM.from_pretrained(args.base_model, quantization_config=bnb,
                                               torch_dtype=dt, device_map={"": 0},
                                               attn_implementation="sdpa")
    llm.config.use_cache = True
    for pr in llm.parameters():
        pr.requires_grad_(False)
        # Parameters left in float32 after loading are cast to the compute dtype.
        if pr.dtype == torch.float32:
            pr.data = pr.data.to(dt)
    llm.eval()
    H = llm.config.hidden_size
    embed = llm.get_input_embeddings()
    with torch.no_grad():
        # Mean L2 norm of 512 randomly chosen token embeddings.
        ids = torch.randint(0, embed.weight.size(0), (512,), device=dev)
        tn = embed(ids).float().norm(dim=-1).mean().item()

    # Context manager closes the file deterministically; blank lines (e.g. a
    # trailing newline) are skipped instead of crashing json.loads.
    with open(args.prompts, encoding="utf-8") as fh:
        allp = [json.loads(line)["q"] for line in fh if line.strip()]
    prompts = allp[-args.n:]

    edges = [0, 768, 1500, 3000, 6000, 12000, 20000, 99999]
    bins = [(edges[i], edges[i + 1]) for i in range(len(edges) - 1)]

    def load_pooler(path):
        """Build a ``TS.AttnPoolSP`` from the checkpoint at ``path``; return ``(pooler, S)``."""
        # NOTE(security): weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints from a trusted source.
        ck = torch.load(path, map_location="cpu", weights_only=False)
        # Hyper-parameters live under "args" or "src_args"; tolerate either key
        # being missing, empty or None.
        a = ck.get("args") or ck.get("src_args") or {}
        S = a.get("n_sp") or 32
        pl = TS.AttnPoolSP(H, S, a.get("heads", 8), a.get("layers", 3), a.get("ffn", 2048), tn).to(dev)
        pl.load_state_dict(ck["pooler"])
        pl.eval()
        return pl, S

    results = {}
    for name, path in ckpt_specs:
        pooler, S = load_pooler(path)
        tot = {b: [0.0, 0] for b in bins}
        for s in range(0, len(prompts), args.B):
            batch = prompts[s:s + args.B]
            q_ids, q_mask = TS.pad_queries(tok, batch, dev)
            gen = TS.rollout_sp(pooler, llm, embed, q_ids, q_mask, args.gen_len, S,
                                args.rw, args.C, dev, dt, H, args.temp)
            acc = score_traj(pooler, llm, embed, q_ids, q_mask, gen, S,
                             args.rw, args.C, dev, dt, H, bins)
            for b in bins:
                tot[b][0] += acc[b][0]
                tot[b][1] += acc[b][1]
            # Release cached allocator blocks between batches.
            torch.cuda.empty_cache()
        # Bins that received no tokens are reported as NaN rather than dividing by zero.
        results[name] = {b: (tot[b][0] / tot[b][1] if tot[b][1] else float("nan")) for b in bins}
        print(f"[{name}] " + " ".join(f"{lo}-{hi if hi<99999 else 'inf'}:{results[name][(lo,hi)]:.4f}"
                                      for lo, hi in bins), flush=True)

    print("\n=== free-running SP-path KL(full-KV || SP) by position (lower=less drift) ===", flush=True)
    print(f"{'bin':>12} | " + " | ".join(f"{n:>8}" for n in results) + " | delta", flush=True)
    for lo, hi in bins:
        row = [results[n][(lo, hi)] for n in results]
        lab = f"{lo}-{hi if hi<99999 else 'inf'}"
        d = (row[1] - row[0]) if len(row) == 2 else 0.0
        print(f"{lab:>12} | " + " | ".join(f"{v:8.4f}" for v in row) + f" | {d:+.4f}", flush=True)
    print("DONE_EVAL", flush=True)
|
|
|
|
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|