# Source: cross-model-lora-prediction / improve_pertensor.py
# Uploaded by Samarth0710 with huggingface_hub (commit e039734, verified).
"""
Improvement 1: per-tensor anchor-basis ridge mapping.
Each tensor (e.g. layer.5.q_proj.lora_B) gets its own α_3, instead of one global α.
"""
import json, shutil
from pathlib import Path
import torch
from safetensors.torch import load_file, save_file
OUT = Path("/app/out")
TASKS = ["A","B","C","D"]
def load_sd(p):
    """Read ``<p>/adapter_model.safetensors`` and return its tensors as float32 on CPU."""
    raw = load_file(str(p / "adapter_model.safetensors"))
    converted = {}
    for name, tensor in raw.items():
        # Upcast (e.g. from bf16) so all downstream linear algebra runs in float32.
        converted[name] = tensor.float().cpu()
    return converted
X = {t: load_sd(OUT/"X"/f"X_{t}") for t in TASKS}
Y = {t: load_sd(OUT/"Y"/f"Y_{t}") for t in TASKS}
keys_X = sorted(X["A"].keys())
keys_Y = sorted(Y["A"].keys())
# Per-tensor mapping cannot simply pair the same key index in X and Y: the two models
# differ in layer count (and possibly target_modules). Hidden sizes and layer counts differ
# between X (24 layers) and Y (16 layers), so the per-tensor map is keyed by
# (module_kind, A/B, layer index within each model).
# Strategy: align by (module, A/B, normalized layer position), mapping each Y layer to the
# nearest X layer.
import re, collections
# Matches e.g. "...layers.5.self_attn.q_proj.lora_A..." or "...layers.3.mlp.down_proj.lora_B...".
_LORA_KEY_RE = re.compile(r'layers\.(\d+)\.(?:self_attn|mlp)\.(\w+)_proj\.lora_(A|B)')
def parse(k):
    """Split a LoRA state-dict key into ``(layer_index, module, "A" | "B")``.

    ``module`` is the projection name without its ``_proj`` suffix (``"q"``,
    ``"o"``, ``"gate"``, ...). Attention (``self_attn``) and MLP (``mlp``)
    projections are both recognised.

    Raises:
        ValueError: if ``k`` is not a recognised LoRA projection key.
    """
    m = _LORA_KEY_RE.search(k)
    if m is None:
        # Without this check the caller gets an opaque "'NoneType' has no attribute 'group'".
        raise ValueError(f"unrecognised LoRA key: {k!r}")
    return int(m.group(1)), m.group(2), m.group(3)
def group_by_layer(sd):
    """Bucket a LoRA state dict first by ``(module, A/B)``, then by layer index.

    Returns a ``defaultdict(dict)`` mapping ``(module, AB)`` to
    ``{layer_index: (original_key, tensor)}``; keys are decoded with ``parse``.
    """
    buckets = collections.defaultdict(dict)
    for name in sd:
        layer, module, which = parse(name)
        buckets[(module, which)][layer] = (name, sd[name])
    return buckets
# For per-tensor mapping we operate on each Y tensor independently using the matched X tensor at the
# closest normalized layer position.
ridge = 1.0e-3  # diagonal term added to each per-tensor anchor Gram matrix before solving
def predict_per_tensor(X_anchors_sds, Y_anchors_sds, X_target_sd, ridge=1e-3):
    """Predict each Y-model LoRA tensor for a new task via a per-tensor anchor ridge map.

    For every Y tensor (grouped by ``(module, A/B)`` and layer), the X tensor of the
    same module/AB at the nearest *normalized* layer position is taken from each
    anchor adapter and from the target X adapter. With flattened, anchor-centered
    tensors ``Xco`` / ``Yco`` (one row per anchor), the coefficients solve

        (Xco @ Xco.T + ridge * I) @ alpha = Xco @ (x_target - X_mean)

    and the prediction is ``Y_mean + alpha @ Yco`` reshaped to the Y tensor's shape.

    Args:
        X_anchors_sds: X-model adapter state dicts, one per anchor task.
        Y_anchors_sds: Y-model adapter state dicts for the same anchor tasks, in
            the same order as ``X_anchors_sds``.
        X_target_sd: X-model adapter state dict of the task to predict.
        ridge: value added to the diagonal of each anchor Gram matrix. Keep it
            positive: after centering, the Gram matrix has rank at most
            ``n_anchors - 1``, so it is singular without regularization.

    Returns:
        dict mapping every key of the first Y anchor to its predicted tensor.

    Raises:
        ValueError: if no anchors are given, or the X and Y anchor lists differ in
            length (``zip`` would otherwise silently drop the extra anchors).
        KeyError: if an adapter lacks a module/AB/layer that a prediction needs.
    """
    n_anchors = len(Y_anchors_sds)
    if n_anchors == 0 or n_anchors != len(X_anchors_sds):
        raise ValueError(
            f"need the same non-zero number of X and Y anchors, got "
            f"{len(X_anchors_sds)} X and {n_anchors} Y"
        )
    Yg_a = [group_by_layer(s) for s in Y_anchors_sds]
    Xg_a = [group_by_layer(s) for s in X_anchors_sds]
    Xg_t = group_by_layer(X_target_sd)

    def _n_layers(groups):
        # Highest layer index seen in ANY module group, plus one (no assumption that
        # q_proj is among the target modules).
        return 1 + max((L for layers in groups.values() for L in layers), default=-1)

    nLY = _n_layers(Yg_a[0])
    nLX = _n_layers(Xg_a[0])

    def nearest_x(L_y):
        # Map a Y layer onto the X stack by relative depth, then round to an index.
        return round(L_y * (nLX - 1) / max(1, nLY - 1))

    def _get(groups, mod, AB, L, which):
        # Same lookup as before, but with a message that says what is missing.
        try:
            return groups[(mod, AB)][L][1]
        except KeyError:
            raise KeyError(
                f"{which} adapter has no {mod}_proj.lora_{AB} tensor at layer {L}"
            ) from None

    pred = {}
    for (mod, AB), layers in Yg_a[0].items():
        for L_y, (key, ref) in layers.items():
            L_x = nearest_x(L_y)
            # One flattened row per anchor, in anchor order.
            Yc = torch.stack([_get(Yg, mod, AB, L_y, "Y anchor").reshape(-1) for Yg in Yg_a])
            Xc = torch.stack([_get(Xg, mod, AB, L_x, "X anchor").reshape(-1) for Xg in Xg_a])
            x_target = _get(Xg_t, mod, AB, L_x, "X target").reshape(-1)
            Ym = Yc.mean(0)
            Xm = Xc.mean(0)
            Yco = Yc - Ym
            Xco = Xc - Xm
            G = Xco @ Xco.T  # (n_anchors, n_anchors) Gram matrix of centered X anchors
            rhs = Xco @ (x_target - Xm)
            eye = torch.eye(n_anchors, dtype=G.dtype, device=G.device)
            alpha = torch.linalg.solve(G + ridge * eye, rhs)
            pred[key] = (Ym + alpha @ Yco).reshape(ref.shape)
    return pred
X_anc = [X["A"], X["B"], X["C"]]
Y_anc = [Y["A"], Y["B"], Y["C"]]
pred_sd = predict_per_tensor(X_anc, Y_anc, X["D"], ridge=ridge)
# Diagnostics
def cos(a,b):
    """Return the cosine similarity of two tensors after flattening each to 1-D."""
    row_a = a.reshape(1, -1)
    row_b = b.reshape(1, -1)
    similarity = torch.nn.functional.cosine_similarity(row_a, row_b)
    return similarity.item()
# Overall cosine: all Y tensors flattened and concatenated in sorted-key order.
flat_pred = torch.cat([pred_sd[name].flatten() for name in keys_Y])
flat_oracle = torch.cat([Y["D"][name].flatten() for name in keys_Y])
# Baseline that ignores X: element-wise mean of the three anchor Y adapters.
flat_mean = torch.cat([
    torch.stack([Y[task][name] for task in ["A", "B", "C"]]).mean(0).flatten()
    for name in keys_Y
])
print("PER-TENSOR cos(Y_hat,Y_D):", cos(flat_pred, flat_oracle))
print("PER-TENSOR cos(Y_mean,Y_D):", cos(flat_mean, flat_oracle))
# Save the prediction as an adapter directory, reusing Y_A's config and tokenizer files.
out_dir = OUT / "Y" / "Y_pred_D_pertensor"
out_dir.mkdir(parents=True, exist_ok=True)
template_dir = OUT / "Y" / "Y_A"
shutil.copy(template_dir / "adapter_config.json", out_dir / "adapter_config.json")
for fname in ("tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"):
    src = template_dir / fname
    # Tokenizer files are optional: copy only those present in the template directory.
    if src.exists():
        shutil.copy(src, out_dir / fname)
bf16_sd = {name: t.to(torch.bfloat16) for name, t in pred_sd.items()}
save_file(bf16_sd, str(out_dir / "adapter_model.safetensors"))
print("Saved", out_dir)