Samarth0710 committed on
Commit
e039734
·
verified ·
1 Parent(s): d2398b2

Upload improve_pertensor.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. improve_pertensor.py +102 -0
improve_pertensor.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Improvement 1: per-tensor anchor-basis ridge mapping.
3
+ Each tensor (e.g. layer.5.q_proj.lora_B) gets its own 3-dimensional α (one coefficient per anchor), instead of one global α.
4
+ """
5
+ import json, shutil
6
+ from pathlib import Path
7
+ import torch
8
+ from safetensors.torch import load_file, save_file
9
+
10
# Root directory of the experiment; adapters are read from, and the
# prediction is written to, sub-directories of this path.
OUT = Path("/app/out")
# Task identifiers; one X adapter and one Y adapter is loaded per task.
TASKS = ["A", "B", "C", "D"]
12
+
13
def load_sd(p):
    """Load an adapter's weights as float32 CPU tensors.

    Args:
        p: Adapter directory (``str`` or ``Path``) that contains
            ``adapter_model.safetensors``.

    Returns:
        dict mapping every tensor name in the file to that tensor after
        ``.float().cpu()``.
    """
    # Path() lets plain string directories work; a Path passes through as-is.
    weights_file = Path(p) / "adapter_model.safetensors"
    return {name: t.float().cpu() for name, t in load_file(str(weights_file)).items()}
14
+
15
# One state dict per task for each model family, read from
# OUT/X/X_<task> and OUT/Y/Y_<task>.
X = {task: load_sd(OUT / "X" / f"X_{task}") for task in TASKS}
Y = {task: load_sd(OUT / "Y" / f"Y_{task}") for task in TASKS}

# Tensor names of task A's adapters in sorted order; keys_Y fixes the
# concatenation order used by the diagnostics further down.
keys_X = sorted(X["A"])
keys_Y = sorted(Y["A"])
20
+
21
+ # Map per-tensor: pairing the same key index in X and Y is not possible, because target_modules and layer counts differ.
22
+ # Hidden sizes & layer counts differ between X(24L) and Y(16L). So we need a per-tensor map keyed by
23
+ # (module_kind, AB, layer_idx_within_each_model).
24
+ # Strategy: align by (module, AB, normalized_layer_position rounded to nearest in Y).
25
+ import re, collections
26
+
27
# Compiled once: matches e.g. "...layers.5.self_attn.q_proj.lora_B..." and
# captures the layer index, the projection name and the A/B factor.
_LORA_KEY_RE = re.compile(r'layers\.(\d+)\.self_attn\.(\w+)_proj\.lora_(A|B)')


def parse(k):
    """Split a LoRA tensor name into ``(layer, module, AB)``.

    Args:
        k: Tensor name containing
            ``layers.<n>.self_attn.<module>_proj.lora_<A|B>``.

    Returns:
        Tuple ``(layer, module, AB)``: the layer index as ``int``, the
        projection name without the ``_proj`` suffix, and ``"A"`` or ``"B"``.

    Raises:
        ValueError: If ``k`` does not contain the expected pattern (for
            example a non-attention module).
    """
    m = _LORA_KEY_RE.search(k)
    if m is None:
        # Without this guard the failure is an opaque AttributeError on None.
        raise ValueError(f"not a self-attention LoRA tensor name: {k!r}")
    return int(m.group(1)), m.group(2), m.group(3)
30
+
31
def group_by_layer(sd):
    """Index a state dict by ``(module, AB)`` group and layer.

    Args:
        sd: Mapping of tensor name to tensor; each name must be accepted by
            ``parse``.

    Returns:
        ``collections.defaultdict(dict)`` of the form
        ``{(module, AB): {layer: (original_name, tensor)}}``.
    """
    grouped = collections.defaultdict(dict)
    for name, tensor in sd.items():
        layer_idx, module, ab = parse(name)
        # Keep the original name with the tensor so callers can rebuild keys.
        grouped[module, ab][layer_idx] = (name, tensor)
    return grouped
37
+
38
# The per-tensor mapping treats every Y tensor independently, pairing it with
# the X tensor found at the closest normalized layer position.
# Ridge strength handed to predict_per_tensor via its `ridge=` argument.
ridge = 0.001
41
+
42
def predict_per_tensor(X_anchors_sds, Y_anchors_sds, X_target_sd, ridge=1e-3):
    """Predict a Y-side adapter tensor by tensor from anchor adapter pairs.

    For every tensor of the first Y anchor, the X tensor of the same
    ``(module, AB)`` group at the nearest normalized layer position is looked
    up in each X anchor and in the X target. The mean-centred X target is
    expressed in terms of the mean-centred X anchors by solving a
    ridge-regularised Gram system; the resulting per-tensor coefficients
    ``alpha`` (one per anchor) are then applied to the mean-centred Y anchors.

    Args:
        X_anchors_sds: Sequence of X state dicts (tensor name -> tensor), one
            per anchor.
        Y_anchors_sds: Sequence of Y state dicts, in the same anchor order.
        X_target_sd: X state dict of the adapter to transfer.
        ridge: Value added to the diagonal of the anchor Gram matrix.

    Returns:
        dict mapping each tensor name of the first Y anchor to the predicted
        tensor, shaped like that anchor's tensor.

    Raises:
        ValueError: If no anchors are given, the X and Y anchor counts differ,
            or the first anchor of either family holds no tensors.
    """
    n_anchors = len(X_anchors_sds)
    if n_anchors == 0 or n_anchors != len(Y_anchors_sds):
        raise ValueError(
            "need the same non-zero number of X and Y anchors, got "
            f"{n_anchors} and {len(Y_anchors_sds)}"
        )

    Yg_a = [group_by_layer(s) for s in Y_anchors_sds]
    Xg_a = [group_by_layer(s) for s in X_anchors_sds]
    Xg_t = group_by_layer(X_target_sd)

    # Model depth = highest layer index in any group + 1. Looking at every
    # group (not only ("q", "A")) keeps this working when q_proj is not
    # among the adapted modules.
    nLY = max(L for layers in Yg_a[0].values() for L in layers) + 1
    nLX = max(L for layers in Xg_a[0].values() for L in layers) + 1

    def nearest_x(L_y):
        # Rescale the Y layer index onto the X layer range; max(1, ...)
        # avoids dividing by zero for a single-layer Y model.
        return round(L_y * (nLX - 1) / max(1, nLY - 1))

    pred = {}
    for (mod, AB), layers in Yg_a[0].items():
        for L_y, (key, y_ref) in layers.items():
            L_x = nearest_x(L_y)

            # One flattened row per anchor.
            Yc = torch.stack([Yg[(mod, AB)][L_y][1].reshape(-1) for Yg in Yg_a])
            Xc = torch.stack([Xg[(mod, AB)][L_x][1].reshape(-1) for Xg in Xg_a])
            x_target = Xg_t[(mod, AB)][L_x][1].reshape(-1)

            Ym = Yc.mean(0)
            Xm = Xc.mean(0)
            Yco = Yc - Ym
            Xco = Xc - Xm

            # (n_anchors x n_anchors) Gram system; the ridge term keeps it
            # solvable even though the centred rows are linearly dependent.
            G = Xco @ Xco.T
            rhs = Xco @ (x_target - Xm)
            eye = torch.eye(n_anchors, dtype=G.dtype, device=G.device)
            alpha = torch.linalg.solve(G + ridge * eye, rhs)

            y_hat = Ym + alpha @ Yco
            pred[key] = y_hat.reshape(y_ref.shape)
    return pred
77
+
78
# Tasks A, B and C serve as anchors; task D's X adapter is the one to map.
X_anc = [X[task] for task in ("A", "B", "C")]
Y_anc = [Y[task] for task in ("A", "B", "C")]
pred_sd = predict_per_tensor(X_anc, Y_anc, X["D"], ridge=ridge)
81
+
82
+ # Diagnostics
83
def cos(a, b):
    """Return the cosine similarity of two tensors as a Python float.

    Both inputs are flattened first, so only their element counts have to
    agree, not their shapes.
    """
    row_a = a.reshape(1, -1)
    row_b = b.reshape(1, -1)
    similarity = torch.nn.functional.cosine_similarity(row_a, row_b)
    return similarity.item()
86
+
87
# Overall cosine: concatenate every Y tensor (in keys_Y order) into one long
# vector per candidate and compare it with the real task-D adapter.
flat_pred = torch.cat([pred_sd[name].reshape(-1) for name in keys_Y])
flat_oracle = torch.cat([Y["D"][name].reshape(-1) for name in keys_Y])
# Baseline: element-wise mean of the three anchor adapters.
flat_mean = torch.cat([
    torch.stack([Y[task][name] for task in ["A", "B", "C"]]).mean(dim=0).reshape(-1)
    for name in keys_Y
])
print("PER-TENSOR cos(Y_hat,Y_D):", cos(flat_pred, flat_oracle))
print("PER-TENSOR cos(Y_mean,Y_D):", cos(flat_mean, flat_oracle))
93
+
94
# Save the prediction as an adapter directory, borrowing the adapter config
# (and any tokenizer files that exist) from the Y_A adapter.
out_dir = OUT / "Y" / "Y_pred_D_pertensor"
out_dir.mkdir(parents=True, exist_ok=True)
template_dir = OUT / "Y" / "Y_A"
shutil.copy(template_dir / "adapter_config.json", out_dir / "adapter_config.json")
for fname in ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"]:
    candidate = template_dir / fname
    if not candidate.exists():
        continue
    shutil.copy(candidate, out_dir / fname)
# Weights are stored in bfloat16.
bf16_sd = {name: tensor.to(torch.bfloat16) for name, tensor in pred_sd.items()}
save_file(bf16_sd, str(out_dir / "adapter_model.safetensors"))
print("Saved", out_dir)