"""
Improvement 1: per-tensor anchor-basis ridge mapping.
Each tensor (e.g. layers.5.self_attn.q_proj.lora_B) gets its own α (3 coefficients,
one per anchor adapter) instead of one global α shared by all tensors.
"""
import collections
import re
import shutil
from pathlib import Path

import torch
from safetensors.torch import load_file, save_file

OUT = Path("/app/out")
TASKS = ["A", "B", "C", "D"]


def load_sd(p):
    """Load an adapter's weights as float32 CPU tensors."""
    sd = load_file(str(p / "adapter_model.safetensors"))
    return {k: v.float().cpu() for k, v in sd.items()}


# LoRA adapters for models X and Y, one per task.
X = {t: load_sd(OUT / "X" / f"X_{t}") for t in TASKS}
Y = {t: load_sd(OUT / "Y" / f"Y_{t}") for t in TASKS}

keys_X = sorted(X["A"].keys())
keys_Y = sorted(Y["A"].keys())

# Tensors cannot be paired by key index: X and Y differ in target_modules,
# layer count (X: 24 layers, Y: 16 layers) and hidden size. Instead, pair them by
# (module_kind, A/B, layer_idx), mapping each Y layer to the X layer at the
# nearest normalized depth. Every (module, A/B) group in Y must also exist in X.

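def parse(k):
    """Split a key like '...layers.5.self_attn.q_proj.lora_B...' into (5, 'q', 'B')."""
    m = re.search(r"layers\.(\d+)\.self_attn\.(\w+)_proj\.lora_(A|B)", k)
    if m is None:
        raise ValueError(f"unexpected adapter key: {k}")
    return int(m.group(1)), m.group(2), m.group(3)


def group_by_layer(sd):
    """Index a state dict as {(module, 'A'|'B'): {layer: (key, tensor)}}."""
    g = collections.defaultdict(dict)
    for k, v in sd.items():
        L, mod, AB = parse(k)
        g[(mod, AB)][L] = (k, v)
    return g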

# For per-tensor mapping we operate on each Y tensor independently using the matched X tensor at the
# closest normalized layer position.
ridge = 1e-3
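
# Sketch (illustrative only): the nearest-normalized-depth rule that `nearest_x`
# applies inside `predict_per_tensor`, written as a standalone helper so the
# Y -> X layer mapping can be inspected on its own. `normalized_layer_map` is a
# hypothetical name not used elsewhere in this script; it assumes 0-based,
# contiguous layer indices in both models.
def normalized_layer_map(n_layers_y, n_layers_x):
    """Return, for each Y layer, the X layer at the same relative depth."""
    return [round(L_y * (n_layers_x - 1) / max(1, n_layers_y - 1))
            for L_y in range(n_layers_y)]

# Usage, for the 16-layer Y and 24-layer X in this setup:
#   normalized_layer_map(16, 24)
#   -> [0, 2, 3, 5, 6, 8, 9, 11, 12, 14, 15, 17, 18, 20, 21, 23]
# Both endpoints are pinned (0 -> 0, 15 -> 23); 8 of the 24 X layers are never
# used as a source.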

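def predict_per_tensor(X_anchors_sds, Y_anchors_sds, X_target_sd, ridge=1e-3):
    """Predict the target task's Y adapter one tensor at a time.

    For every Y tensor, take the aligned X tensor (same module and A/B, nearest
    normalized layer) from each of the K anchor adapters. Centre both anchor
    stacks, solve the K x K ridge system that expresses the centred X target in
    the centred X anchors, and apply the same weights to the centred Y anchors:

        alpha = (Xco @ Xco.T + ridge * I)^-1 @ Xco @ (x_target - Xm)
        y_hat = Ym + alpha @ Yco

    alpha is fitted in X space only, so X and Y tensors may differ in shape.
    """
    Yg_a = [group_by_layer(s) for s in Y_anchors_sds]
    Xg_a = [group_by_layer(s) for s in X_anchors_sds]
    Xg_t = group_by_layer(X_target_sd)
    # Layer count of each model, taken over every (module, A/B) group.
    nLY = 1 + max(L for layers in Yg_a[0].values() for L in layers)
    nLX = 1 + max(L for layers in Xg_a[0].values() for L in layers)

    def nearest_x(L_y):
        # Map Y layer L_y to the X layer at the same relative depth.
        return round(L_y * (nLX - 1) / max(1, nLY - 1))

    pred = {}
    for (mod, AB), layers in Yg_a[0].items():
        for L_y, (key, y0) in layers.items():
            L_x = nearest_x(L_y)
            # One flattened (Y, X) tensor pair per anchor adapter.
            Y_vecs, X_vecs = [], []
            for Yg, Xg in zip(Yg_a, Xg_a):
                _, yt = Yg[(mod, AB)][L_y]
                _, xt = Xg[(mod, AB)][L_x]
                Y_vecs.append(yt.reshape(-1))
                X_vecs.append(xt.reshape(-1))
            _, x_target_t = Xg_t[(mod, AB)][L_x]
            x_target = x_target_t.reshape(-1)

            Yc = torch.stack(Y_vecs)  # (K, numel of the Y tensor)
            Xc = torch.stack(X_vecs)  # (K, numel of the X tensor)
            Ym, Xm = Yc.mean(0), Xc.mean(0)
            Yco, Xco = Yc - Ym, Xc - Xm
            # Centring leaves Xco with rank <= K-1, so G is singular and the
            # ridge term is what makes the solve well-posed.
            G = Xco @ Xco.T
            rhs = Xco @ (x_target - Xm)
            eye = torch.eye(G.shape[0], dtype=G.dtype)
            alpha = torch.linalg.solve(G + ridge * eye, rhs)
            y_hat = Ym + alpha @ Yco
            pred[key] = y_hat.reshape(y0.shape)
    return pred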
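
# Sketch (illustrative only): a toy sanity check of the per-tensor ridge step.
# It assumes an idealized case that the real adapters need not satisfy: the
# target X vector is an affine combination (weights w summing to 1) of the
# anchor X vectors, and each Y anchor is a fixed linear map of its X anchor.
# There the method should recover w @ Y_anchors exactly (up to the ridge bias).
# `anchor_ridge_step` and `toy_anchor_ridge_check` are hypothetical helpers.
import torch


def anchor_ridge_step(Xc, Yc, x_target, ridge):
    """Same math as the inner loop of predict_per_tensor, on (K, d) anchor stacks."""
    Xm, Ym = Xc.mean(0), Yc.mean(0)
    Xco, Yco = Xc - Xm, Yc - Ym
    G = Xco @ Xco.T
    eye = torch.eye(G.shape[0], dtype=G.dtype)
    alpha = torch.linalg.solve(G + ridge * eye, Xco @ (x_target - Xm))
    return alpha, Ym + alpha @ Yco


def toy_anchor_ridge_check(seed=0, ridge=1e-6):
    g = torch.Generator().manual_seed(seed)
    w = torch.tensor([0.6, 0.3, 0.1], dtype=torch.float64)  # affine weights, sum to 1
    Xc = torch.randn(3, 50, generator=g, dtype=torch.float64)  # 3 anchors in X space
    M = torch.randn(50, 30, generator=g, dtype=torch.float64)  # hypothetical X -> Y map
    Yc = Xc @ M  # Y anchors live in a different dimension (30 vs 50)
    alpha, y_hat = anchor_ridge_step(Xc, Yc, w @ Xc, ridge)
    return w, alpha, y_hat, w @ Yc

# Because the centred anchors sum to zero, alpha is only determined up to a
# constant shift; the ridge term picks the zero-sum solution (alpha ~= w - 1/3),
# and any shift would cancel in alpha @ Yco anyway:
#   w, alpha, y_hat, y_true = toy_anchor_ridge_check()
#   torch.allclose(y_hat, y_true, atol=1e-5)  -> True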

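# Tasks A, B, C are the anchors and D is the target. Only X["D"] feeds the
# prediction; Y["D"] is held out as the oracle for the diagnostics.
X_anc = [X["A"], X["B"], X["C"]]
Y_anc = [Y["A"], Y["B"], Y["C"]]
pred_sd = predict_per_tensor(X_anc, Y_anc, X["D"], ridge=ridge)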

# Diagnostics
def cos(a, b):
    """Cosine similarity between two tensors after flattening."""
    return torch.nn.functional.cosine_similarity(a.flatten(), b.flatten(), dim=0).item()


# Cosine similarity over all tensors concatenated, against the held-out Y_D.
# Baseline: the element-wise mean of the three anchor Y adapters.
flat_pred = torch.cat([pred_sd[k].flatten() for k in keys_Y])
flat_oracle = torch.cat([Y["D"][k].flatten() for k in keys_Y])
flat_mean = torch.cat(
    [torch.stack([Y[t][k] for t in ["A", "B", "C"]]).mean(0).flatten() for k in keys_Y]
)
print("PER-TENSOR cos(Y_hat,Y_D):", cos(flat_pred, flat_oracle))
print("PER-TENSOR cos(Y_mean,Y_D):", cos(flat_mean, flat_oracle))
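
# Sketch (illustrative only): the concatenated cosine above weights each tensor
# by its norm, so a few large-norm tensors can dominate it. A per-tensor
# breakdown, assuming `pred` and `oracle` share keys and shapes, shows which
# modules and layers the mapping predicts worst. `per_tensor_cosines` is a
# hypothetical helper, not part of the original pipeline.
import torch
import torch.nn.functional as F


def per_tensor_cosines(pred, oracle):
    """Map each key to cos(pred[k], oracle[k]) over the flattened tensors."""
    return {
        k: F.cosine_similarity(pred[k].flatten(), oracle[k].flatten(), dim=0).item()
        for k in oracle
    }

# Usage in this script (lowest-similarity tensors first):
#   per_key = per_tensor_cosines(pred_sd, Y["D"])
#   for k, c in sorted(per_key.items(), key=lambda kv: kv[1])[:10]:
#       print(f"{c:+.3f}  {k}")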

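# Save as a PEFT-style adapter directory: reuse Y_A's adapter_config.json and any
# tokenizer files, and store the predicted weights in bfloat16.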
out_dir = OUT/"Y"/"Y_pred_D_pertensor"
out_dir.mkdir(parents=True, exist_ok=True)
shutil.copy(OUT/"Y"/"Y_A"/"adapter_config.json", out_dir/"adapter_config.json")
for f in ["tokenizer.json","tokenizer_config.json","special_tokens_map.json"]:
    src = OUT/"Y"/"Y_A"/f
    if src.exists(): shutil.copy(src, out_dir/f)
save_file({k: v.to(torch.bfloat16) for k,v in pred_sd.items()}, str(out_dir/"adapter_model.safetensors"))
print("Saved", out_dir)
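
# Sketch (illustrative only): the prediction is computed in float32 but saved in
# bfloat16, which keeps only 8 significand bits. This hypothetical helper (not
# part of the pipeline above) measures the relative Frobenius error introduced
# by the cast; with round-to-nearest it is at most 2**-8 (about 0.4%) for values
# in the normal range.
import torch


def bf16_relative_error(t):
    """||t - bf16(t)|| / ||t|| for a tensor t, computed in float32."""
    t = t.float()
    rt = t.to(torch.bfloat16).float()
    return ((t - rt).norm() / t.norm()).item()

# Usage in this script:
#   worst = max(bf16_relative_error(v) for v in pred_sd.values())
#   print("max bf16 relative error:", worst)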