File size: 6,616 Bytes
788f72f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
#!/usr/bin/env python3
"""
Decisive eval: does ap32_sched fix the >1000-token collapse seen with ap32_large?

For each checkpoint: free-run the SP path (bounded memory, rw=128 by default) for --gen_len
tokens on held-out prompts, then teacher-force-score that SAME self-generated trajectory to
obtain the per-position KL(full-KV teacher || SP-path student), binned by position.
Lower KL at positions >1000 means the drift is fixed.
Uses a frozen 4-bit LLM (the deployment config) and reuses the train_sched building blocks.
"""
import sys, json, argparse
sys.path.insert(0, "/workspace")
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.cache_utils import DynamicCache
import train_sched as TS


@torch.no_grad()
def score_traj(pooler, llm, embed, q_ids, q_mask, a_ids, S, rw, C, dev, dt, H, bins):
    """Accumulate per-position KL(teacher || student), teacher-forced on ``a_ids``, into position bins.

    The teacher logits come from ``TS.teacher_full`` (one call for the whole
    trajectory). The student is evaluated chunk by chunk: for each chunk of up
    to ``C`` answer tokens the LLM sees the query, the pooler output computed
    from all answer tokens older than the recent window, the last ``R`` raw
    answer tokens, and the chunk itself.

    Args:
        pooler: callable mapping float embeddings to the SP block; the logit
            slicing below assumes it emits ``S`` positions -- TODO confirm
            against ``TS.AttnPoolSP``.
        llm: causal LM called with ``inputs_embeds`` and a ``DynamicCache``.
        embed: the LLM's input-embedding module.
        q_ids: query token ids, shape ``(B, MQ)``.
        q_mask: query attention mask, same leading shape as ``q_ids``.
        a_ids: answer (trajectory) token ids; dimension 1 is the answer length.
        S: number of SP positions that precede the raw tokens in the block.
        rw: maximum number of most recent raw answer tokens kept verbatim.
        C: chunk length in answer tokens.
        dev: torch device for newly created tensors.
        dt: dtype the embeddings are cast to before entering the LLM.
        H: hidden size, used only to build the empty pooler input.
        bins: iterable of ``(lo, hi)`` half-open answer-position ranges.

    Returns:
        dict mapping each bin to ``[sum_kl, n_tokens]``.
    """
    B, MQ = q_ids.shape
    La = a_ids.size(1)
    qlen = q_mask.sum(-1).long()                          # real (unmasked) query tokens per row
    qpos = (q_mask.cumsum(-1) - 1).clamp(min=0).long()    # position ids that ignore masked slots
    q_last, oa_logits = TS.teacher_full(llm, embed, q_ids, q_mask, a_ids, dev, dt)
    acc = {b: [0.0, 0] for b in bins}                     # (sum_kl, n_tokens)
    n_chunks = (La + C - 1) // C
    for j in range(n_chunks):
        c0 = j * C
        c1 = min(c0 + C, La)
        cur = c1 - c0                                     # always >= 1 because c0 < La here
        R = min(c0, rw)                                   # raw recent tokens kept before the chunk
        dl = c0 - R                                       # older tokens that go through the pooler
        if dl > 0:
            sp = pooler(embed(a_ids[:, :dl]).float())
        else:
            sp = pooler(torch.zeros(B, 0, H, device=dev))
        # A fresh cache is prefilled with the query for every chunk, because the
        # second forward pass below appends this chunk's block to it.
        cache = DynamicCache()
        llm(inputs_embeds=embed(q_ids).to(dt), attention_mask=q_mask, past_key_values=cache,
            position_ids=qpos, use_cache=True, cache_position=torch.arange(MQ, device=dev))
        parts = [sp.to(dt)]
        if R > 0:
            parts.append(embed(a_ids[:, c0 - R:c0]).to(dt))
        parts.append(embed(a_ids[:, c0:c1]).to(dt))
        block = torch.cat(parts, 1)
        nb = block.size(1)
        # Block positions continue right after each row's real query length.
        pos = qlen.unsqueeze(1) + torch.arange(nb, device=dev).unsqueeze(0)
        amask = torch.cat([q_mask, torch.ones(B, nb, device=dev)], 1)
        o = llm(inputs_embeds=block, attention_mask=amask, past_key_values=cache,
                position_ids=pos, cache_position=torch.arange(MQ, MQ + nb, device=dev), use_cache=True)
        # The logit predicting the first chunk token sits on the last position
        # before the chunk, i.e. index S + R - 1 inside the block.
        ps = S - 1 + R
        s_logits = o.logits[:, ps:ps + cur, :].float()
        if c0 == 0:
            # The first answer token is predicted from the last query position.
            t_logits = torch.cat([q_last[:, None, :], oa_logits[:, 0:cur - 1, :]], 1).float()
        else:
            t_logits = oa_logits[:, c0 - 1:c1 - 1, :].float()
        kl = TS.kl_td(t_logits, s_logits)                 # (B, cur)
        # Bin by each token's own position rather than by the chunk start, so a
        # chunk that straddles a bin edge is split between the two bins.
        for lo, hi in bins:
            a = max(lo, c0)
            b = min(hi, c1)
            if a < b:
                seg = kl[:, a - c0:b - c0]
                acc[(lo, hi)][0] += seg.sum().item()
                acc[(lo, hi)][1] += seg.numel()
    del q_last, oa_logits
    return acc


def main():
    """Run the free-running SP-path KL evaluation for every checkpoint and print a per-bin table.

    Loads the 4-bit quantized base model once, then for each ``name=path``
    entry in ``--ckpts`` builds the pooler, generates trajectories with
    ``TS.rollout_sp`` on the held-out prompts, scores them with ``score_traj``
    and prints the mean KL per position bin. Requires a CUDA device.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--base_model", default="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
    p.add_argument("--ckpts", default="orig=/workspace/ap32_large.pt,sched=/workspace/ap32_sched.pt")
    p.add_argument("--prompts", default="/workspace/prompts.jsonl")
    p.add_argument("--n", type=int, default=24)
    p.add_argument("--B", type=int, default=12)
    p.add_argument("--gen_len", type=int, default=1400)
    p.add_argument("--rw", type=int, default=128)
    p.add_argument("--C", type=int, default=64)
    p.add_argument("--temp", type=float, default=0.7)
    p.add_argument("--max_q_len", type=int, default=160)
    args = p.parse_args()
    dev = torch.device("cuda")
    dt = torch.bfloat16
    tok = AutoTokenizer.from_pretrained(args.base_model)
    tok.padding_side = "left"
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4",
                             bnb_4bit_compute_dtype=dt, bnb_4bit_use_double_quant=True)
    llm = AutoModelForCausalLM.from_pretrained(args.base_model, quantization_config=bnb,
                                               torch_dtype=dt, device_map={"": 0}, attn_implementation="sdpa")
    llm.config.use_cache = True
    # Freeze the LLM and cast any remaining fp32 parameters to the compute dtype.
    for pr in llm.parameters():
        pr.requires_grad_(False)
        if pr.dtype == torch.float32:
            pr.data = pr.data.to(dt)
    llm.eval()
    H = llm.config.hidden_size
    embed = llm.get_input_embeddings()
    # Mean L2 norm of 512 randomly sampled token embeddings; handed to the
    # pooler constructor (presumably as a target output scale -- TODO confirm).
    with torch.no_grad():
        ids = torch.randint(0, embed.weight.size(0), (512,), device=dev)
        tn = embed(ids).float().norm(dim=-1).mean().item()

    # held-out prompts = LAST n of the pool (least likely used by the ptr sweep start)
    with open(args.prompts, encoding="utf-8") as fh:
        # Skip blank lines so a trailing newline does not crash json.loads.
        allp = [json.loads(line)["q"] for line in fh if line.strip()]
    # allp[-0:] would return the whole pool, so treat n <= 0 as "no prompts".
    prompts = allp[-args.n:] if args.n > 0 else []
    # granular position bins out to long-context regime
    edges = [0, 768, 1500, 3000, 6000, 12000, 20000, 99999]
    bins = [(edges[i], edges[i + 1]) for i in range(len(edges) - 1)]

    def load_pooler(path):
        """Build the pooler from the checkpoint at ``path``; return ``(pooler, n_sp)``."""
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from a trusted source.
        ck = torch.load(path, map_location="cpu", weights_only=False)
        a = ck.get("args", {}) or ck.get("src_args", {})
        S = a.get("n_sp") or 32
        pl = TS.AttnPoolSP(H, S, a.get("heads", 8), a.get("layers", 3), a.get("ffn", 2048), tn).to(dev)
        pl.load_state_dict(ck["pooler"])
        pl.eval()
        return pl, S

    results = {}
    for spec in args.ckpts.split(","):
        # Split on the first "=" only, so paths containing "=" still parse.
        name, path = spec.split("=", 1)
        pooler, S = load_pooler(path)
        tot = {b: [0.0, 0] for b in bins}
        for s in range(0, len(prompts), args.B):
            batch = prompts[s:s + args.B]
            q_ids, q_mask = TS.pad_queries(tok, batch, dev)
            gen = TS.rollout_sp(pooler, llm, embed, q_ids, q_mask, args.gen_len, S, args.rw, args.C, dev, dt, H, args.temp)
            acc = score_traj(pooler, llm, embed, q_ids, q_mask, gen, S, args.rw, args.C, dev, dt, H, bins)
            for b in bins:
                tot[b][0] += acc[b][0]
                tot[b][1] += acc[b][1]
            torch.cuda.empty_cache()
        # Mean KL per bin; NaN marks bins that received no tokens.
        results[name] = {b: (tot[b][0] / tot[b][1] if tot[b][1] else float("nan")) for b in bins}
        print(f"[{name}] " + "  ".join(f"{lo}-{hi if hi<99999 else 'inf'}:{results[name][(lo,hi)]:.4f}"
                                       for lo, hi in bins), flush=True)

    print("\n=== free-running SP-path KL(full-KV || SP) by position (lower=less drift) ===", flush=True)
    print(f"{'bin':>12} | " + " | ".join(f"{n:>8}" for n in results) + " | delta", flush=True)
    for lo, hi in bins:
        row = [results[n][(lo, hi)] for n in results]
        lab = f"{lo}-{hi if hi<99999 else 'inf'}"
        # The delta (second minus first) is only computed for exactly two
        # checkpoints; otherwise 0.0 is printed.
        d = (row[1] - row[0]) if len(row) == 2 else 0.0
        print(f"{lab:>12} | " + " | ".join(f"{v:8.4f}" for v in row) + f" | {d:+.4f}", flush=True)
    print("DONE_EVAL", flush=True)


# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()