#!/usr/bin/env python3
# hypernet-sp-distill / eval_collapse.py
# Uploaded by baya1116 with huggingface_hub (commit 788f72f, verified).
"""
Decisive eval: does ap32_sched fix the >1000-token collapse vs ap32_large?
For each checkpoint: free-run the SP-path (bounded memory, rw=128) to GEN_LEN on held-out
prompts, then teacher-force-score that SAME self-generated trajectory -> per-position
KL(full-KV teacher || SP-path student), binned by position. Lower KL at >1000 = drift fixed.
4-bit frozen LLM (deployment config). Reuses train_sched building blocks.
"""
import sys, json, argparse
sys.path.insert(0, "/workspace")
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.cache_utils import DynamicCache
import train_sched as TS
@torch.no_grad()
def score_traj(pooler, llm, embed, q_ids, q_mask, a_ids, S, rw, C, dev, dt, H, bins):
    """Per-position KL(teacher||student) teacher-forced on a_ids, accumulated into position bins.

    The answer is processed in chunks of ``C`` tokens. For a chunk starting at
    answer position ``c0``, the student (SP path) sees, in order:
      1. the query, prefilled into a fresh KV cache;
      2. the soft-prompt slots that ``pooler`` produces from every answer token
         older than the recent window;
      3. the last ``min(c0, rw)`` answer tokens, verbatim;
      4. the current chunk.
    The student's next-token logits are compared with the full-KV teacher logits
    from ``TS.teacher_full`` using ``TS.kl_td``.

    Args:
        pooler: soft-prompt pooler, called on float embeddings of shape (B, T, H).
            Its output is indexed as if it has ``S`` positions; presumably
            (B, S, H). TODO: confirm.
        llm: frozen causal LM that accepts ``inputs_embeds`` and ``past_key_values``.
        embed: the LLM's input-embedding module.
        q_ids, q_mask: (B, MQ) query ids and attention mask. The position math
            below assumes left padding.
        a_ids: (B, La) answer tokens to score, e.g. a self-generated rollout.
        S: number of soft-prompt slots emitted by ``pooler``.
        rw: recent-window size (verbatim answer tokens kept before each chunk).
        C: chunk length.
        dev, dt: device and compute dtype for the LLM inputs.
        H: hidden size, used for the empty pooler input when there is no history.
        bins: iterable of half-open ``(lo, hi)`` tuples of answer positions.

    Returns:
        dict mapping each bin to ``[sum_kl, n_tokens]``. The answer token at
        position ``t`` is counted in the bin with ``lo <= t < hi``, so a chunk
        that straddles an edge is split across bins. Tokens outside every bin
        are dropped.
    """
    B, MQ = q_ids.shape
    La = a_ids.size(1)
    qlen = q_mask.sum(-1).long()
    # Left padding: real query tokens get positions 0..qlen-1; pads clamp to 0.
    qpos = (q_mask.cumsum(-1) - 1).clamp(min=0).long()
    # As indexed below: q_last predicts a_ids[:, 0], oa_logits[:, i] predicts a_ids[:, i + 1].
    q_last, oa_logits = TS.teacher_full(llm, embed, q_ids, q_mask, a_ids, dev, dt)
    acc = {b: [0.0, 0] for b in bins}  # (sum_kl, n_tokens)

    for c0 in range(0, La, C):
        c1 = min(c0 + C, La)
        cur = c1 - c0
        R = min(c0, rw)  # verbatim recent-window tokens immediately before this chunk
        dl = c0 - R      # older answer tokens, compressed by the pooler
        if dl > 0:
            sp = pooler(embed(a_ids[:, :dl]).float())
        else:
            sp = pooler(torch.zeros(B, 0, H, device=dev))

        # Fresh cache per chunk: prefill the query, then append the SP block.
        cache = DynamicCache()
        llm(inputs_embeds=embed(q_ids).to(dt), attention_mask=q_mask, past_key_values=cache,
            position_ids=qpos, use_cache=True, cache_position=torch.arange(MQ, device=dev))

        parts = [sp.to(dt)]
        if R > 0:
            parts.append(embed(a_ids[:, c0 - R:c0]).to(dt))
        parts.append(embed(a_ids[:, c0:c1]).to(dt))
        block = torch.cat(parts, 1)
        nb = block.size(1)
        # Block positions continue right after each row's last real query token.
        pos = qlen.unsqueeze(1) + torch.arange(nb, device=dev).unsqueeze(0)
        # Same dtype as q_mask (already accepted by the prefill call above);
        # this avoids a mixed-dtype concatenation.
        amask = torch.cat([q_mask, torch.ones(B, nb, device=dev, dtype=q_mask.dtype)], 1)
        o = llm(inputs_embeds=block, attention_mask=amask, past_key_values=cache,
                position_ids=pos, cache_position=torch.arange(MQ, MQ + nb, device=dev),
                use_cache=True)

        # Block index S-1+R is the token just before a_ids[c0]; its logit and the
        # next cur-1 logits predict a_ids[c0:c1].
        ps = S - 1 + R
        s_logits = o.logits[:, ps:ps + cur, :].float()
        if c0 == 0:
            t_logits = torch.cat([q_last[:, None, :], oa_logits[:, 0:cur - 1, :]], 1).float()
        else:
            t_logits = oa_logits[:, c0 - 1:c1 - 1, :].float()
        kl = TS.kl_td(t_logits, s_logits)  # (B, cur)

        # Credit each token to the bin containing its own position, not the chunk start.
        for lo, hi in bins:
            start, stop = max(lo, c0), min(hi, c1)
            if start < stop:
                seg = kl[:, start - c0:stop - c0]
                acc[(lo, hi)][0] += seg.sum().item()
                acc[(lo, hi)][1] += seg.numel()

    del q_last, oa_logits
    return acc
def _load_frozen_llm(base_model, dt):
    """Load ``base_model`` as a 4-bit NF4 model on GPU 0, frozen, in eval mode, with KV caching on."""
    bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                             bnb_4bit_compute_dtype=dt, bnb_4bit_use_double_quant=True)
    llm = AutoModelForCausalLM.from_pretrained(base_model, quantization_config=bnb,
                                               torch_dtype=dt, device_map={"": 0},
                                               attn_implementation="sdpa")
    llm.config.use_cache = True
    for pr in llm.parameters():
        pr.requires_grad_(False)
        # Cast any remaining fp32 (non-quantized) parameters to the compute dtype.
        if pr.dtype == torch.float32:
            pr.data = pr.data.to(dt)
    llm.eval()
    return llm


def main():
    """Compare SP-path drift across pooler checkpoints.

    For each ``name=path`` entry in ``--ckpts``, the script:
      1. free-runs ``TS.rollout_sp`` for ``--gen_len`` tokens on the last ``--n``
         prompts of ``--prompts``;
      2. teacher-force-scores the generated tokens with ``score_traj``;
      3. prints the mean KL(full-KV || SP) per answer-position bin.
    A final table lists every checkpoint side by side. The ``delta`` column is
    (second - first) when exactly two checkpoints are given, otherwise 0.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--base_model", default="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
    p.add_argument("--ckpts", default="orig=/workspace/ap32_large.pt,sched=/workspace/ap32_sched.pt")
    p.add_argument("--prompts", default="/workspace/prompts.jsonl")
    p.add_argument("--n", type=int, default=24)
    p.add_argument("--B", type=int, default=12)
    p.add_argument("--gen_len", type=int, default=1400)
    p.add_argument("--rw", type=int, default=128)
    p.add_argument("--C", type=int, default=64)
    p.add_argument("--temp", type=float, default=0.7)
    p.add_argument("--max_q_len", type=int, default=160)
    p.add_argument("--seed", type=int, default=0,
                   help="RNG seed for the embedding-norm probe and for each checkpoint's sampled rollouts")
    args = p.parse_args()
    # allp[-0:] would silently select the whole pool, and range() rejects a zero step.
    if args.n < 1:
        p.error("--n must be >= 1")
    if args.B < 1:
        p.error("--B must be >= 1")

    dev = torch.device("cuda")
    dt = torch.bfloat16
    tok = AutoTokenizer.from_pretrained(args.base_model)
    tok.padding_side = "left"
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    llm = _load_frozen_llm(args.base_model, dt)
    H = llm.config.hidden_size
    embed = llm.get_input_embeddings()

    # Mean norm of embeddings for random vocab ids; passed to AttnPoolSP below.
    # Seeded so the value is reproducible across runs.
    torch.manual_seed(args.seed)
    with torch.no_grad():
        ids = torch.randint(0, embed.weight.size(0), (512,), device=dev)
        tn = embed(ids).float().norm(dim=-1).mean().item()

    # held-out prompts = LAST n of the pool (least likely used by the ptr sweep start)
    with open(args.prompts, encoding="utf-8") as f:
        allp = [json.loads(line)["q"] for line in f if line.strip()]
    prompts = allp[-args.n:]

    # granular position bins out to long-context regime
    edges = [0, 768, 1500, 3000, 6000, 12000, 20000, 99999]
    bins = [(edges[i], edges[i + 1]) for i in range(len(edges) - 1)]

    def label(lo, hi):
        """Printable bin name; the final open-ended bin is shown as ``lo-inf``."""
        return f"{lo}-{hi if hi < edges[-1] else 'inf'}"

    def load_pooler(path):
        """Build an AttnPoolSP from a checkpoint's saved hyperparameters and weights; return (pooler, S)."""
        # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
        ck = torch.load(path, map_location="cpu", weights_only=False)
        a = ck.get("args", {}) or ck.get("src_args", {})
        S = a.get("n_sp") or 32
        pl = TS.AttnPoolSP(H, S, a.get("heads", 8), a.get("layers", 3), a.get("ffn", 2048), tn).to(dev)
        pl.load_state_dict(ck["pooler"])
        pl.eval()
        return pl, S

    results = {}
    for spec in args.ckpts.split(","):
        if not spec.strip():
            continue  # tolerate trailing/double commas
        name, path = spec.split("=", 1)  # split once so paths may contain '='
        pooler, S = load_pooler(path)
        # Reseed per checkpoint so every checkpoint's sampling starts from the same RNG state.
        torch.manual_seed(args.seed)
        tot = {b: [0.0, 0] for b in bins}
        for s in range(0, len(prompts), args.B):
            batch = prompts[s:s + args.B]
            q_ids, q_mask = TS.pad_queries(tok, batch, dev)
            gen = TS.rollout_sp(pooler, llm, embed, q_ids, q_mask, args.gen_len, S, args.rw,
                                args.C, dev, dt, H, args.temp)
            acc = score_traj(pooler, llm, embed, q_ids, q_mask, gen, S, args.rw, args.C,
                             dev, dt, H, bins)
            for b in bins:
                tot[b][0] += acc[b][0]
                tot[b][1] += acc[b][1]
            torch.cuda.empty_cache()
        results[name] = {b: (tot[b][0] / tot[b][1] if tot[b][1] else float("nan")) for b in bins}
        print(f"[{name}] " + " ".join(f"{label(lo, hi)}:{results[name][(lo, hi)]:.4f}"
                                      for lo, hi in bins), flush=True)
        del pooler
        torch.cuda.empty_cache()

    print("\n=== free-running SP-path KL(full-KV || SP) by position (lower=less drift) ===", flush=True)
    print(f"{'bin':>12} | " + " | ".join(f"{n:>8}" for n in results) + " | delta", flush=True)
    for lo, hi in bins:
        row = [results[n][(lo, hi)] for n in results]
        lab = label(lo, hi)
        d = (row[1] - row[0]) if len(row) == 2 else 0.0
        print(f"{lab:>12} | " + " | ".join(f"{v:8.4f}" for v in row) + f" | {d:+.4f}", flush=True)
    print("DONE_EVAL", flush=True)


if __name__ == "__main__":
    main()