"""
Improvement 1: per-tensor anchor-basis ridge mapping.

Each tensor (e.g. layer.5.q_proj.lora_B) gets its own α_3, instead of one
global α.
"""
import collections
import json
import re
import shutil
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file

OUT = Path("/app/out")
TASKS = ["A", "B", "C", "D"]


def load_sd(p):
    """Load ``p/adapter_model.safetensors`` as float32 CPU tensors.

    Args:
        p: Directory (``Path``) containing ``adapter_model.safetensors``.

    Returns:
        dict mapping each tensor name to the tensor after ``.float().cpu()``.
    """
    return {
        k: v.float().cpu()
        for k, v in load_file(str(p / "adapter_model.safetensors")).items()
    }


X = {t: load_sd(OUT / "X" / f"X_{t}") for t in TASKS}
Y = {t: load_sd(OUT / "Y" / f"Y_{t}") for t in TASKS}
keys_X = sorted(X["A"].keys())
keys_Y = sorted(Y["A"].keys())

# X and Y tensors cannot be paired by matching key index: target_modules and
# layer counts differ between the two models (hidden sizes and layer counts
# differ: X has 24 layers, Y has 16). So tensors are grouped by
# (module_kind, A/B, layer_idx within each model), and each Y layer is aligned
# to the X layer at the nearest normalized layer position.


def parse(k):
    """Split a LoRA tensor name into ``(layer_idx, module, "A" | "B")``.

    Example: a key containing ``layers.5.self_attn.q_proj.lora_B`` parses to
    ``(5, "q", "B")``.

    Args:
        k: State-dict key.

    Returns:
        Tuple ``(int layer index, module name without "_proj", "A" or "B")``.

    Raises:
        ValueError: if ``k`` does not contain
            ``layers.<i>.self_attn.<module>_proj.lora_<A|B>``.
    """
    m = re.search(r'layers\.(\d+)\.self_attn\.(\w+)_proj\.lora_(A|B)', k)
    if m is None:
        # Fail with the offending key instead of an AttributeError on None.
        raise ValueError(
            "unrecognised LoRA key (expected "
            f"'layers.<i>.self_attn.<module>_proj.lora_<A|B>'): {k!r}"
        )
    return int(m.group(1)), m.group(2), m.group(3)


def group_by_layer(sd):
    """Group a LoRA state dict by ``(module, A/B)`` and then by layer index.

    Args:
        sd: dict mapping state-dict keys to tensors; every key must be
            accepted by :func:`parse`.

    Returns:
        ``defaultdict`` mapping ``(module, "A" | "B")`` to
        ``{layer_idx: (original_key, tensor)}``.

    Raises:
        ValueError: propagated from :func:`parse` for unrecognised keys.
    """
    groups = collections.defaultdict(dict)  # (module, AB) -> {layer: (key, tensor)}
    for k, v in sd.items():
        layer, mod, ab = parse(k)
        groups[(mod, ab)][layer] = (k, v)
    return groups

# For per-tensor mapping we operate on each Y tensor independently using the
# matched X tensor at the closest normalized layer position.
ridge = 1e-3


def predict_per_tensor(X_anchors_sds, Y_anchors_sds, X_target_sd, ridge=1e-3):
    """Predict the target Y adapter tensor-by-tensor from paired anchors.

    For every Y tensor (module, A/B, layer ``L_y``) of the first Y anchor,
    find the X tensor with the same module and A/B at the nearest normalized
    layer ``L_x``. Do this in each anchor pair and in the target X adapter.
    With K anchor pairs, solve a K x K ridge system on the mean-centered,
    flattened X anchors:

        (Xco @ Xco.T + ridge * I) alpha = Xco @ (x_target - Xm)

    Then apply the same ``alpha`` to the centered Y anchors:
    ``Y_hat = Ym + alpha @ (Y_anchors - Ym)``.

    Args:
        X_anchors_sds: list of X state dicts, one per anchor task, in the same
            order as ``Y_anchors_sds``.
        Y_anchors_sds: list of Y state dicts for the same anchor tasks.
        X_target_sd: X state dict of the task whose Y adapter is predicted.
        ridge: value added to the diagonal of the K x K Gram matrix.

    Returns:
        dict mapping each key of ``Y_anchors_sds[0]`` to its predicted tensor,
        reshaped to that anchor tensor's shape.

    Raises:
        ValueError: if the X and Y anchor lists differ in length, or (from
            ``parse``) on an unrecognised LoRA key.
        KeyError: if an X adapter lacks the module/layer aligned to a Y tensor.
    """
    if len(X_anchors_sds) != len(Y_anchors_sds):
        # zip() below would otherwise silently drop the unpaired anchors.
        raise ValueError(
            f"need one X anchor per Y anchor, got {len(X_anchors_sds)} X "
            f"and {len(Y_anchors_sds)} Y"
        )
    Yg_a = [group_by_layer(s) for s in Y_anchors_sds]
    Xg_a = [group_by_layer(s) for s in X_anchors_sds]
    Xg_t = group_by_layer(X_target_sd)

    def n_layers(groups):
        # Highest layer index present in any (module, A/B) group, plus one;
        # does not require any particular module (e.g. q_proj) to be adapted.
        return 1 + max(max(layers) for layers in groups.values())

    nLY = n_layers(Yg_a[0])
    nLX = n_layers(Xg_a[0])

    def nearest_x(L_y):
        # Map a Y layer to the X layer at the same normalized depth
        # (first -> first, last -> last); max(1, ...) guards a 1-layer Y.
        return round(L_y * (nLX - 1) / max(1, nLY - 1))

    pred = {}
    for (mod, AB), layers in Yg_a[0].items():
        for L_y, (key, _) in layers.items():
            L_x = nearest_x(L_y)
            # Collect the aligned, flattened tensors from every anchor pair.
            Y_vecs = []
            X_vecs = []
            for Yg, Xg in zip(Yg_a, Xg_a):
                _, yt = Yg[(mod, AB)][L_y]
                _, xt = Xg[(mod, AB)][L_x]
                Y_vecs.append(yt.reshape(-1))
                X_vecs.append(xt.reshape(-1))
            _, x_target_t = Xg_t[(mod, AB)][L_x]
            x_target = x_target_t.reshape(-1)

            Yc = torch.stack(Y_vecs)  # (K, numel of Y tensor)
            Xc = torch.stack(X_vecs)  # (K, numel of X tensor)
            Ym = Yc.mean(0)
            Xm = Xc.mean(0)
            Yco = Yc - Ym
            Xco = Xc - Xm
            G = Xco @ Xco.T  # (K, K) Gram matrix of centered X anchors
            rhs = Xco @ (x_target - Xm)
            # Centering leaves G rank-deficient; the ridge term, sized to the
            # actual anchor count K, keeps the system solvable.
            eye = torch.eye(G.shape[0], dtype=G.dtype, device=G.device)
            alpha = torch.linalg.solve(G + ridge * eye, rhs)
            y_hat = Ym + alpha @ Yco
            pred[key] = y_hat.reshape(yt.shape)
    return pred


X_anc = [X["A"], X["B"], X["C"]]
Y_anc = [Y["A"], Y["B"], Y["C"]]
pred_sd = predict_per_tensor(X_anc, Y_anc, X["D"], ridge=ridge)


# Diagnostics
def cos(a, b):
    """Return the cosine similarity of two tensors after flattening both."""
    a = a.flatten()
    b = b.flatten()
    return torch.nn.functional.cosine_similarity(a.unsqueeze(0), b.unsqueeze(0)).item()


# Overall cosine over all Y tensors concatenated in sorted-key order.
flat_pred = torch.cat([pred_sd[k].flatten() for k in keys_Y])
flat_oracle = torch.cat([Y["D"][k].flatten() for k in keys_Y])
# Baseline: element-wise mean of the three anchor Y adapters.
flat_mean = torch.cat(
    [torch.stack([Y[t][k] for t in ["A", "B", "C"]]).mean(0).flatten() for k in keys_Y]
)
print("PER-TENSOR cos(Y_hat,Y_D):", cos(flat_pred, flat_oracle))
print("PER-TENSOR cos(Y_mean,Y_D):", cos(flat_mean, flat_oracle))

# Save the prediction as an adapter directory, reusing Y_A's adapter config
# and whichever tokenizer files Y_A has.
out_dir = OUT / "Y" / "Y_pred_D_pertensor"
out_dir.mkdir(parents=True, exist_ok=True)
shutil.copy(OUT / "Y" / "Y_A" / "adapter_config.json", out_dir / "adapter_config.json")
for f in ["tokenizer.json", "tokenizer_config.json", "special_tokens_map.json"]:
    src = OUT / "Y" / "Y_A" / f
    if src.exists():
        shutil.copy(src, out_dir / f)
save_file(
    {k: v.to(torch.bfloat16) for k, v in pred_sd.items()},
    str(out_dir / "adapter_model.safetensors"),
)
print("Saved", out_dir)