baya1116 commited on
Commit
3fcff03
·
verified ·
1 Parent(s): a0650a5

Upload consistency_canon.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. consistency_canon.py +163 -0
consistency_canon.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Is approach (A) [canonicalize oracle SP* then MSE-regress] viable?
4
+ Test: for the SAME context, optimize the SP from TWO different random inits -> sp_a, sp_b
5
+ (both reach low KL, but symmetry makes them differ). Then PROCRUSTES-align sp_b to sp_a.
6
+ raw_dist = ||sp_a - sp_b||^2 / token (how different the two valid SPs are)
7
+ canon_dist= ||sp_a - sp_b@R||^2 / token (after removing an orthogonal rotation)
8
+ If canon_dist << raw_dist => the symmetry is (mostly) a rotation -> canonicalization makes
9
+ the target unique -> MSE-regression viable => GO (A).
10
+ If canon_dist ~ raw_dist => symmetry is deeper than rotation -> MSE won't work => GO (B, KL).
11
+ Also try per-token permutation alignment (the 128 soft tokens may be reorderable).
12
+ """
13
+ import sys
14
+ sys.path.insert(0, "/workspace")
15
+ import argparse, json
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from transformers import AutoModelForCausalLM, AutoTokenizer
19
+ from transformers.cache_utils import DynamicCache
20
+ from train_qwen_distill import (HyperNetwork, Config, extract_qa, CJK_RE, TOOLCALL_RE,
21
+ soft_prompt_stability_loss)
22
+
23
+
24
+ @torch.no_grad()
25
+ def teacher_dist(llm, embed, q, a, mq, ma, dev, dt):
26
+ cache = DynamicCache()
27
+ oq = llm(inputs_embeds=embed(q[:, :mq]).to(dt), attention_mask=torch.ones(1, mq, device=dev),
28
+ past_key_values=cache, use_cache=True, cache_position=torch.arange(mq, device=dev))
29
+ t0 = oq.logits[:, -1, :]
30
+ pos = torch.arange(mq, mq + ma, device=dev)
31
+ oa = llm(inputs_embeds=embed(a[:, :ma]).to(dt), attention_mask=torch.ones(1, mq + ma, device=dev),
32
+ past_key_values=cache, position_ids=pos.unsqueeze(0), use_cache=True, cache_position=pos)
33
+ V = oa.logits.size(-1); T = torch.empty(1, ma, V, dtype=oa.logits.dtype, device=dev)
34
+ T[:, 0] = t0; T[:, 1:] = oa.logits[:, :ma - 1]; return T
35
+
36
+
37
+ @torch.no_grad()
38
+ def prefill(llm, embed, q, mq, dev, dt):
39
+ c = DynamicCache()
40
+ llm(inputs_embeds=embed(q[:, :mq]).to(dt), attention_mask=torch.ones(1, mq, device=dev),
41
+ past_key_values=c, use_cache=True, cache_position=torch.arange(mq, device=dev))
42
+ return c
43
+
44
+
45
+ def opt_sp(llm, embed, q, a, teacher, c0, c1, mq, dev, dt, cfg, mn, rw, steps, lr, S, T, seed):
46
+ cur = c1 - c0; R = min(c0, rw)
47
+ raw = embed(a[:, c0 - R:c0]).to(dt) if R > 0 else None
48
+ chunk = embed(a[:, c0:c1]).to(dt); n = S + R + cur
49
+ cp = torch.arange(mq, mq + n, device=dev)
50
+ tp = F.softmax(teacher[:, c0:c1].float() / T, dim=-1)
51
+ g = torch.Generator(device=dev).manual_seed(seed)
52
+ sp = torch.randn(1, S, cfg.hidden_dim, generator=g, device=dev)
53
+ sp = (sp / sp.norm(dim=-1, keepdim=True).clamp(min=1e-6) * cfg.target_norm).requires_grad_(True)
54
+ opt = torch.optim.Adam([sp], lr=lr); best = float("inf"); bv = None
55
+ for _ in range(steps):
56
+ opt.zero_grad(set_to_none=True)
57
+ nm = sp.norm(dim=-1, keepdim=True).clamp(min=1e-6); sc = torch.where(nm > mn, mn / nm, torch.ones_like(nm))
58
+ spc = (sp * sc).to(dt)
59
+ x = torch.cat([spc, raw, chunk], 1) if R > 0 else torch.cat([spc, chunk], 1)
60
+ cache = prefill(llm, embed, q, mq, dev, dt)
61
+ o = llm(inputs_embeds=x, attention_mask=torch.ones(1, mq + n, device=dev), past_key_values=cache,
62
+ position_ids=cp.unsqueeze(0), use_cache=True, cache_position=cp)
63
+ lp = F.log_softmax(o.logits[:, S - 1 + R:S - 1 + R + cur].float() / T, dim=-1)
64
+ kl = (tp * (tp.clamp_min(1e-9).log() - lp)).sum(-1).mean() * (T * T)
65
+ (kl + soft_prompt_stability_loss(sp, cfg)).backward(); opt.step()
66
+ if kl.item() < best:
67
+ best = kl.item(); bv = (sp.detach() * sc).clone()
68
+ return bv[0], best # (S,H), kl
69
+
70
+
71
+ def procrustes(Bm, Am): # align Bm to Am: return Bm@R, R orthogonal minimizing ||Am - Bm R||
72
+ M = Bm.transpose(-1, -2) @ Am
73
+ U, S, Vh = torch.linalg.svd(M)
74
+ return Bm @ (U @ Vh)
75
+
76
+
77
+ def perm_align(Bm, Am): # align rows of Bm to Am by greedy nearest (permutation of 128 tokens)
78
+ # cost[i,j] = ||Am[i]-Bm[j]||^2 ; greedy match
79
+ d = torch.cdist(Am, Bm) # (S,S)
80
+ used = torch.zeros(Bm.size(0), dtype=torch.bool, device=Bm.device)
81
+ out = torch.empty_like(Bm)
82
+ for i in range(Am.size(0)):
83
+ row = d[i].clone(); row[used] = float("inf"); j = int(row.argmin()); used[j] = True; out[i] = Bm[j]
84
+ return out
85
+
86
+
87
+ def load_samples(path, tok, cfg, n, mc, mal, mt):
88
+ out = []
89
+ with open(path) as f:
90
+ for line in f:
91
+ if len(out) >= n: break
92
+ line = line.strip()
93
+ if not line: continue
94
+ try: row = json.loads(line)
95
+ except Exception: continue
96
+ q, a = extract_qa(row, cfg)
97
+ if not q or not a or len(a) < mc: continue
98
+ if CJK_RE.search(a) or CJK_RE.search(q) or TOOLCALL_RE.search(a) or TOOLCALL_RE.search(q): continue
99
+ qi = tok(q, max_length=cfg.max_query_len, truncation=True, add_special_tokens=True).input_ids
100
+ ai = tok(a, max_length=mal, truncation=True, add_special_tokens=False).input_ids
101
+ if len(ai) < mt: continue
102
+ out.append((qi, ai))
103
+ return out[300:300 + n] if len(out) > 300 + n else out
104
+
105
+
106
+ def main():
107
+ p = argparse.ArgumentParser()
108
+ p.add_argument("--ckpt", default="/workspace/hypernet_qwen/hn_step7750.pt")
109
+ p.add_argument("--base_model", default="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
110
+ p.add_argument("--data", default="/workspace/dolphin_subset.jsonl")
111
+ p.add_argument("--n", type=int, default=12)
112
+ p.add_argument("--depths", default="128,256,384")
113
+ p.add_argument("--chunk_size", type=int, default=64)
114
+ p.add_argument("--raw_window", type=int, default=32)
115
+ p.add_argument("--steps", type=int, default=120)
116
+ p.add_argument("--lr", type=float, default=0.03)
117
+ args = p.parse_args()
118
+ dev = torch.device("cuda"); dt = torch.bfloat16
119
+ cfg = Config(); cfg.base_model = args.base_model
120
+ C = args.chunk_size; S = cfg.num_soft_tokens; T = 1.0
121
+ print("Loading frozen base...", flush=True)
122
+ tok = AutoTokenizer.from_pretrained(cfg.base_model)
123
+ if tok.pad_token is None: tok.pad_token = tok.eos_token
124
+ llm = AutoModelForCausalLM.from_pretrained(cfg.base_model, dtype=dt, device_map="cuda", attn_implementation="sdpa")
125
+ llm.config.use_cache = True
126
+ for prm in llm.parameters(): prm.requires_grad_(False)
127
+ llm.eval(); embed = llm.get_input_embeddings(); cfg.hidden_dim = llm.config.hidden_size
128
+ with torch.no_grad():
129
+ ids = torch.randint(0, embed.weight.size(0), (512,), device=dev)
130
+ cfg.target_norm = embed(ids).float().norm(dim=-1).mean().item()
131
+ mn = cfg.target_norm * 3.0
132
+ hn = HyperNetwork(cfg).to(dtype=torch.float32, device=dev); hn.eval()
133
+ ckd = torch.load(args.ckpt, map_location="cpu", weights_only=False); hn.load_state_dict(ckd["hypernet"], strict=False)
134
+ depths = [int(x) for x in args.depths.split(",")]
135
+ samples = load_samples(args.data, tok, cfg, args.n, 1500, 512, 400)
136
+ raw_l, rot_l, prm_l, kl_l = [], [], [], []
137
+ for si, (qi, ai) in enumerate(samples):
138
+ q = torch.tensor([qi], device=dev); a = torch.tensor([ai], device=dev); mq = q.size(1); ma = a.size(1)
139
+ teacher = teacher_dist(llm, embed, q, a, mq, ma, dev, dt)
140
+ for c0 in depths:
141
+ c1 = min(c0 + C, ma)
142
+ if c1 - c0 < 4 or c0 + 1 >= ma: continue
143
+ spa, ka = opt_sp(llm, embed, q, a, teacher, c0, c1, mq, dev, dt, cfg, mn, args.raw_window, args.steps, args.lr, S, T, 111)
144
+ spb, kb = opt_sp(llm, embed, q, a, teacher, c0, c1, mq, dev, dt, cfg, mn, args.raw_window, args.steps, args.lr, S, T, 999)
145
+ raw = ((spa - spb) ** 2).sum(-1).mean().item()
146
+ rot = ((spa - procrustes(spb, spa)) ** 2).sum(-1).mean().item()
147
+ prm = ((spa - perm_align(spb, spa)) ** 2).sum(-1).mean().item()
148
+ raw_l.append(raw); rot_l.append(rot); prm_l.append(prm); kl_l.append(0.5 * (ka + kb))
149
+ print(f" s{si+1} c0={c0}: KL~{0.5*(ka+kb):.4f} raw={raw:.3f} afterRotation={rot:.3f} afterPerm={prm:.3f}", flush=True)
150
+ del teacher; torch.cuda.empty_cache()
151
+ R = sum(raw_l)/len(raw_l); RO = sum(rot_l)/len(rot_l); PR = sum(prm_l)/len(prm_l)
152
+ print("\n" + "=" * 64)
153
+ print(f"mean KL={sum(kl_l)/len(kl_l):.4f} (both SPs valid)")
154
+ print(f"raw disagreement = {R:.3f}")
155
+ print(f"after rotation align = {RO:.3f} ({100*RO/R:.0f}% of raw)")
156
+ print(f"after permutation align = {PR:.3f} ({100*PR/R:.0f}% of raw)")
157
+ print("=" * 64)
158
+ print("if rotation/perm align << raw (e.g. <40%): symmetry is removable -> (A) canonicalize+MSE VIABLE")
159
+ print("if still ~raw: symmetry deeper than rotation/perm -> go (B) output-space KL")
160
+
161
+
162
+ if __name__ == "__main__":
163
+ main()