Upload improve_pertensor.py with huggingface_hub
Browse files- improve_pertensor.py +102 -0
improve_pertensor.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Improvement 1: per-tensor anchor-basis ridge mapping.
|
| 3 |
+
Each tensor (e.g. layer.5.q_proj.lora_B) gets its own α_3, instead of one global α.
|
| 4 |
+
"""
|
| 5 |
+
import json, shutil
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import torch
|
| 8 |
+
from safetensors.torch import load_file, save_file
|
| 9 |
+
|
| 10 |
+
OUT = Path("/app/out")
|
| 11 |
+
TASKS = ["A","B","C","D"]
|
| 12 |
+
|
| 13 |
+
def load_sd(p): return {k: v.float().cpu() for k,v in load_file(str(p/"adapter_model.safetensors")).items()}
|
| 14 |
+
|
| 15 |
+
X = {t: load_sd(OUT/"X"/f"X_{t}") for t in TASKS}
|
| 16 |
+
Y = {t: load_sd(OUT/"Y"/f"Y_{t}") for t in TASKS}
|
| 17 |
+
|
| 18 |
+
keys_X = sorted(X["A"].keys())
|
| 19 |
+
keys_Y = sorted(Y["A"].keys())
|
| 20 |
+
|
| 21 |
+
# Map per-tensor: same key index in X and Y because target_modules and layer count differ -> can't.
|
| 22 |
+
# Hidden sizes & layer counts differ between X(24L) and Y(16L). So we need a per-tensor map keyed by
|
| 23 |
+
# (module_kind, AB, layer_idx_within_each_model).
|
| 24 |
+
# Strategy: align by (module, AB, normalized_layer_position rounded to nearest in Y).
|
| 25 |
+
import re, collections
|
| 26 |
+
|
| 27 |
+
def parse(k):
|
| 28 |
+
m = re.search(r'layers\.(\d+)\.self_attn\.(\w+)_proj\.lora_(A|B)', k)
|
| 29 |
+
return int(m.group(1)), m.group(2), m.group(3)
|
| 30 |
+
|
| 31 |
+
def group_by_layer(sd):
|
| 32 |
+
g = collections.defaultdict(dict) # (module, AB) -> {layer: tensor}
|
| 33 |
+
for k,v in sd.items():
|
| 34 |
+
L,mod,AB = parse(k)
|
| 35 |
+
g[(mod,AB)][L] = (k, v)
|
| 36 |
+
return g
|
| 37 |
+
|
| 38 |
+
# For per-tensor mapping we operate on each Y tensor independently using the matched X tensor at the
|
| 39 |
+
# closest normalized layer position.
|
| 40 |
+
ridge = 1e-3
|
| 41 |
+
|
| 42 |
+
def predict_per_tensor(X_anchors_sds, Y_anchors_sds, X_target_sd, ridge=1e-3):
|
| 43 |
+
"""For each Y tensor, find aligned X tensor (same module/AB at nearest normalized layer)
|
| 44 |
+
in EACH adapter, build (Y_target_tensor_flat, X_aligned_tensor_flat) anchor pairs of length 3,
|
| 45 |
+
solve 3x3 ridge → α, predict Y_hat tensor."""
|
| 46 |
+
Yg_a = [group_by_layer(s) for s in Y_anchors_sds]
|
| 47 |
+
Xg_a = [group_by_layer(s) for s in X_anchors_sds]
|
| 48 |
+
Xg_t = group_by_layer(X_target_sd)
|
| 49 |
+
nLY = max(Yg_a[0][("q","A")].keys()) + 1
|
| 50 |
+
nLX = max(Xg_a[0][("q","A")].keys()) + 1
|
| 51 |
+
# pre-build mapping Y-layer -> nearest X-layer
|
| 52 |
+
def nearest_x(L_y):
|
| 53 |
+
# normalize
|
| 54 |
+
return round(L_y * (nLX-1) / max(1, nLY-1))
|
| 55 |
+
pred = {}
|
| 56 |
+
for (mod,AB), layers in Yg_a[0].items():
|
| 57 |
+
for L_y, (key, _) in layers.items():
|
| 58 |
+
L_x = nearest_x(L_y)
|
| 59 |
+
# collect anchors
|
| 60 |
+
Y_vecs = []; X_vecs = []
|
| 61 |
+
for Yg, Xg in zip(Yg_a, Xg_a):
|
| 62 |
+
_, yt = Yg[(mod,AB)][L_y]
|
| 63 |
+
_, xt = Xg[(mod,AB)][L_x]
|
| 64 |
+
Y_vecs.append(yt.reshape(-1))
|
| 65 |
+
X_vecs.append(xt.reshape(-1))
|
| 66 |
+
_, x_target_t = Xg_t[(mod,AB)][L_x]
|
| 67 |
+
x_target = x_target_t.reshape(-1)
|
| 68 |
+
Yc = torch.stack(Y_vecs); Xc = torch.stack(X_vecs)
|
| 69 |
+
Ym = Yc.mean(0); Xm = Xc.mean(0)
|
| 70 |
+
Yco = Yc - Ym; Xco = Xc - Xm
|
| 71 |
+
G = Xco @ Xco.T
|
| 72 |
+
rhs = Xco @ (x_target - Xm)
|
| 73 |
+
alpha = torch.linalg.solve(G + ridge*torch.eye(3), rhs)
|
| 74 |
+
y_hat = Ym + alpha @ Yco
|
| 75 |
+
pred[key] = y_hat.reshape(yt.shape)
|
| 76 |
+
return pred
|
| 77 |
+
|
| 78 |
+
X_anc = [X["A"], X["B"], X["C"]]
|
| 79 |
+
Y_anc = [Y["A"], Y["B"], Y["C"]]
|
| 80 |
+
pred_sd = predict_per_tensor(X_anc, Y_anc, X["D"], ridge=ridge)
|
| 81 |
+
|
| 82 |
+
# Diagnostics
|
| 83 |
+
def cos(a,b):
|
| 84 |
+
a = a.flatten(); b = b.flatten()
|
| 85 |
+
return torch.nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()
|
| 86 |
+
|
| 87 |
+
# overall cosine
|
| 88 |
+
flat_pred = torch.cat([pred_sd[k].flatten() for k in keys_Y])
|
| 89 |
+
flat_oracle = torch.cat([Y["D"][k].flatten() for k in keys_Y])
|
| 90 |
+
flat_mean = torch.cat([torch.stack([Y[t][k] for t in ["A","B","C"]]).mean(0).flatten() for k in keys_Y])
|
| 91 |
+
print("PER-TENSOR cos(Y_hat,Y_D):", cos(flat_pred, flat_oracle))
|
| 92 |
+
print("PER-TENSOR cos(Y_mean,Y_D):", cos(flat_mean, flat_oracle))
|
| 93 |
+
|
| 94 |
+
# save
|
| 95 |
+
out_dir = OUT/"Y"/"Y_pred_D_pertensor"
|
| 96 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 97 |
+
shutil.copy(OUT/"Y"/"Y_A"/"adapter_config.json", out_dir/"adapter_config.json")
|
| 98 |
+
for f in ["tokenizer.json","tokenizer_config.json","special_tokens_map.json"]:
|
| 99 |
+
src = OUT/"Y"/"Y_A"/f
|
| 100 |
+
if src.exists(): shutil.copy(src, out_dir/f)
|
| 101 |
+
save_file({k: v.to(torch.bfloat16) for k,v in pred_sd.items()}, str(out_dir/"adapter_model.safetensors"))
|
| 102 |
+
print("Saved", out_dir)
|