| """ |
| Improvement 1: per-tensor anchor-basis ridge mapping. |
Each tensor (e.g. layer.5.q_proj.lora_B) gets its own α ∈ R^3 (one weight per anchor), instead of one global α.
| """ |
| import json, shutil |
| from pathlib import Path |
| import torch |
| from safetensors.torch import load_file, save_file |
|
|
OUT = Path("/app/out")  # root directory holding the X/ and Y/ adapter trees
TASKS = ["A","B","C","D"]  # A/B/C are used as anchors below; D is the held-out target
|
|
def load_sd(p):
    """Load the adapter state dict under directory *p* as float32 CPU tensors."""
    raw = load_file(str(p / "adapter_model.safetensors"))
    return {name: tensor.float().cpu() for name, tensor in raw.items()}
|
|
# Load every task's X and Y adapter state dicts from disk (float32, CPU).
X = {t: load_sd(OUT/"X"/f"X_{t}") for t in TASKS}
Y = {t: load_sd(OUT/"Y"/f"Y_{t}") for t in TASKS}


# Canonical key orderings for flattening; assumes all tasks share task A's key set — TODO confirm.
keys_X = sorted(X["A"].keys())
keys_Y = sorted(Y["A"].keys())
|
|
| |
| |
| |
| |
| import re, collections |
|
|
def parse(k):
    """Extract (layer_index, module_name, 'A' or 'B') from a LoRA tensor key.

    Example: 'model.layers.5.self_attn.q_proj.lora_A.weight' -> (5, 'q', 'A').

    Raises:
        ValueError: if *k* does not match the expected
            'layers.<n>.self_attn.<mod>_proj.lora_<A|B>' naming scheme
            (previously this surfaced as a cryptic AttributeError on None).
    """
    m = re.search(r'layers\.(\d+)\.self_attn\.(\w+)_proj\.lora_(A|B)', k)
    if m is None:
        raise ValueError(f"unrecognized LoRA tensor key: {k!r}")
    return int(m.group(1)), m.group(2), m.group(3)
|
|
def group_by_layer(sd):
    """Bucket a state dict as {(module, 'A'|'B'): {layer_index: (key, tensor)}}."""
    grouped = collections.defaultdict(dict)
    for key, tensor in sd.items():
        layer_idx, module, ab = parse(key)
        grouped[(module, ab)][layer_idx] = (key, tensor)
    return grouped
|
|
| |
| |
ridge = 1e-3  # Tikhonov regularization added to the Gram matrix diagonal in the solve below
|
|
def predict_per_tensor(X_anchors_sds, Y_anchors_sds, X_target_sd, ridge=1e-3):
    """Predict each target-Y tensor via per-tensor anchor-basis ridge regression.

    For each Y tensor (module/AB at layer L_y), find the aligned X tensor at the
    nearest normalized layer in EACH anchor adapter, stack the flattened anchor
    pairs, solve the centered (n_anchors x n_anchors) ridge system for the
    combination weights alpha, and blend the anchor Y tensors accordingly.

    Args:
        X_anchors_sds: list of anchor X state dicts (one per anchor task).
        Y_anchors_sds: list of anchor Y state dicts, parallel to X_anchors_sds.
        X_target_sd: state dict of the target task's X adapter.
        ridge: regularization added to the Gram matrix diagonal.

    Returns:
        dict mapping each Y key to its predicted tensor (same shapes as anchors).
    """
    n_anchors = len(X_anchors_sds)  # was hard-coded to 3 in the eye() below
    Yg_a = [group_by_layer(s) for s in Y_anchors_sds]
    Xg_a = [group_by_layer(s) for s in X_anchors_sds]
    Xg_t = group_by_layer(X_target_sd)
    # Layer counts inferred from the first anchor's q-proj A tensors; assumes
    # every adapter covers layers 0..n-1 contiguously — TODO confirm.
    nLY = max(Yg_a[0][("q", "A")].keys()) + 1
    nLX = max(Xg_a[0][("q", "A")].keys()) + 1

    def nearest_x(L_y):
        # Map a Y layer index onto the X layer grid by normalized position.
        return round(L_y * (nLX - 1) / max(1, nLY - 1))

    pred = {}
    for (mod, AB), layers in Yg_a[0].items():
        for L_y, (key, ref_t) in layers.items():
            L_x = nearest_x(L_y)
            # Build the aligned (Y, X) anchor pairs as flat vectors.
            Y_vecs = []
            X_vecs = []
            for Yg, Xg in zip(Yg_a, Xg_a):
                _, yt = Yg[(mod, AB)][L_y]
                _, xt = Xg[(mod, AB)][L_x]
                Y_vecs.append(yt.reshape(-1))
                X_vecs.append(xt.reshape(-1))
            _, x_target_t = Xg_t[(mod, AB)][L_x]
            x_target = x_target_t.reshape(-1)
            Yc = torch.stack(Y_vecs)
            Xc = torch.stack(X_vecs)
            # Center both stacks and solve the ridge-regularized Gram system.
            Ym = Yc.mean(0)
            Xm = Xc.mean(0)
            Yco = Yc - Ym
            Xco = Xc - Xm
            G = Xco @ Xco.T
            rhs = Xco @ (x_target - Xm)
            alpha = torch.linalg.solve(G + ridge * torch.eye(n_anchors), rhs)
            y_hat = Ym + alpha @ Yco
            # Shape comes from the first anchor's tensor for this key, not from
            # a variable leaked out of the inner loop.
            pred[key] = y_hat.reshape(ref_t.shape)
    return pred
|
|
# Tasks A/B/C act as anchors; D is the held-out prediction target.
anchor_tasks = TASKS[:3]
X_anc = [X[t] for t in anchor_tasks]
Y_anc = [Y[t] for t in anchor_tasks]
pred_sd = predict_per_tensor(X_anc, Y_anc, X["D"], ridge=ridge)
|
|
| |
def cos(a, b):
    """Cosine similarity between two tensors (flattened) as a Python float."""
    av = a.reshape(-1)[None, :]
    bv = b.reshape(-1)[None, :]
    return torch.nn.functional.cosine_similarity(av, bv).item()
|
|
| |
# Flatten predicted / oracle / baseline adapters into single vectors so one
# cosine similarity summarizes the whole prediction.
flat_pred = torch.cat([pred_sd[k].flatten() for k in keys_Y])
flat_oracle = torch.cat([Y["D"][k].flatten() for k in keys_Y])
# Baseline: unweighted mean of the three anchor Y adapters.
flat_mean = torch.cat([torch.stack([Y[t][k] for t in ["A","B","C"]]).mean(0).flatten() for k in keys_Y])
print("PER-TENSOR cos(Y_hat,Y_D):", cos(flat_pred, flat_oracle))
print("PER-TENSOR cos(Y_mean,Y_D):", cos(flat_mean, flat_oracle))
|
|
| |
# Package the prediction as a loadable adapter directory: reuse Y_A's config
# (and tokenizer files when present), save predicted weights in bfloat16.
out_dir = OUT/"Y"/"Y_pred_D_pertensor"
out_dir.mkdir(parents=True, exist_ok=True)
shutil.copy(OUT/"Y"/"Y_A"/"adapter_config.json", out_dir/"adapter_config.json")
for f in ["tokenizer.json","tokenizer_config.json","special_tokens_map.json"]:
    src = OUT/"Y"/"Y_A"/f
    if src.exists(): shutil.copy(src, out_dir/f)  # tokenizer files are optional
save_file({k: v.to(torch.bfloat16) for k,v in pred_sd.items()}, str(out_dir/"adapter_model.safetensors"))
print("Saved", out_dir)
|
|