File size: 5,643 Bytes
378a074
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#!/usr/bin/env python3
"""Consolidate raw repo JSON into viz-ready data/ files for the blog."""
import json, os

# Input directory: the folder containing this script (may be a relative or
# empty string, depending on how the script was launched).
SRC = os.path.dirname(__file__)
# Output directory: the sibling "data" folder one level above SRC.
OUT = os.path.join(SRC, "..", "data")
# Create the output folder up front; exist_ok makes re-runs a no-op.
os.makedirs(OUT, exist_ok=True)

def L(name):
    """Load and return the parsed JSON document ``name`` from ``SRC``.

    The file is opened inside a ``with`` block so the handle is closed as soon
    as parsing finishes rather than being left to the garbage collector, and
    it is decoded as UTF-8 explicitly instead of with the locale default.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(os.path.join(SRC, name), encoding="utf-8") as fh:
        return json.load(fh)
def W(name, obj):
    """Write ``obj`` as compact JSON to ``OUT/name`` and print the file size.

    The output handle is managed by a ``with`` block so the data is flushed
    and the file closed before its size is measured; the original relied on
    the interpreter finalizing the anonymous file object in time.

    Raises:
        OSError: if the file cannot be created or written.
        TypeError: if ``obj`` contains values ``json`` cannot serialize.
    """
    path = os.path.join(OUT, name)
    with open(path, "w", encoding="utf-8") as fh:
        # No whitespace after separators keeps the emitted files small.
        json.dump(obj, fh, separators=(",", ":"))
    # The handle is closed at this point, so the reported size is final.
    print("wrote", name, os.path.getsize(path), "b")

# Canonical method ids; list position is the sort rank used by pair_key().
ORDER = [
    "sft",
    "rft",
    "dft",
    "rift",
    "grpo",
    "dpo",
    "online_grpo",
    "online_dapo",
]
# Alternate spellings mapped onto a canonical id (looked up by mid()).
ALIAS = {
    "offline_grpo": "grpo",
    "grpo": "grpo",
}
def mid(m):
    """Return the canonical method id for the label ``m``.

    The label is trimmed, lower-cased and has each space replaced by an
    underscore before ``ALIAS`` is consulted; labels without an alias entry
    are returned in that normalized form.
    """
    normalized = "_".join(m.strip().lower().split(" "))
    if normalized in ALIAS:
        return ALIAS[normalized]
    return normalized
def pair_key(a, b):
    """Return the ``"x__y"`` key for the method pair, independent of argument order.

    Both labels are canonicalized with ``mid`` and arranged by their position
    in ``ORDER``. A label that is missing from ``ORDER`` makes ``list.index``
    raise ``ValueError``.
    """
    first, second = sorted((mid(a), mid(b)), key=ORDER.index)
    return first + "__" + second
def split_pair(key):
    """Split a pair key into its two method labels.

    Two spellings are understood: ``'a vs b'`` and ``'a__b'``; the ``' vs '``
    form takes precedence when present. Tuple unpacking raises ``ValueError``
    when the key does not split into exactly two parts.
    """
    separator = " vs " if " vs " in key else "__"
    left, right = key.split(separator)
    return left, right

# ---------- cosine matrices ----------
# cosine_all8.json supplies the *8 entries; results.json supplies the *6
# entries and is reused by the sections further down.
c8 = L("cosine_all8.json")
res = L("results.json")

cosine_payload = {}
cosine_payload["labels8"] = c8["labels"]
cosine_payload["matrix8"] = c8["cosine"]
# Upper-cased plot labels (original note: SFT RFT DFT RIFT GRPO DPO).
cosine_payload["labels6"] = [label.upper() for label in res["plot_labels"]]
# Canonical ids via mid(), so offline_grpo is emitted as grpo.
cosine_payload["ids6"] = [mid(name) for name in res["methods"]]
cosine_payload["matrix6"] = res["cosine_global"]
W("cosine.json", cosine_payload)

# ---------- per-layer metrics ----------
layers = res["layers"]


def _rounded_by_pair(table):
    """Re-key ``table`` by canonical pair key, rounding each series to 4 places."""
    out = {}
    for name, series in table.items():
        left, right = split_pair(name)
        values = [round(x, 4) for x in series]
        out[pair_key(left, right)] = values
    return out


cos_pl = _rounded_by_pair(res["cosine_per_layer"])
cka_pl = _rounded_by_pair(res["runtime"]["cka_per_layer"])

svd = L("top1_svd_cosine.json")
svd_u, svd_v_scalar = {}, {}
for name, stats in svd.items():
    left, right = split_pair(name)
    pk = pair_key(left, right)
    per_layer = stats.get("per_layer_u_mean", {})
    row = []
    for layer in layers:
        # The layer may be keyed by its string form or by the raw value;
        # layers absent under both spellings are emitted as null.
        if str(layer) in per_layer:
            row.append(round(per_layer[str(layer)], 4))
        elif layer in per_layer:
            row.append(round(per_layer[layer], 4))
        else:
            row.append(None)
    svd_u[pk] = row
    # NOTE(review): a missing mean_top1_v_cosine is written as 0, which reads
    # the same as a true zero cosine - confirm that is intended.
    svd_v_scalar[pk] = round(stats.get("mean_top1_v_cosine", 0), 4)

og = L("online_grpo_cosine.json")
ortho = og["online_grpo_orthogonal_fraction_per_layer"]
# Unlike svd_u above, every layer must be present here (KeyError otherwise).
online_ortho = [round(ortho[str(layer)], 4) for layer in layers]

per_layer_payload = {
    "layers": layers,
    "cosine": cos_pl,
    "cka": cka_pl,
    "svd_u": svd_u,
    "svd_v_scalar": svd_v_scalar,
    "online_grpo_ortho_off_sft": online_ortho,
}
W("per_layer.json", per_layer_payload)

# ---------- geometry: norms, rank, principal angles, LMC ----------
frob = {}
for method, info in res["inventory"]["methods"].items():
    method_id = mid(method)
    frob[method_id] = round(info["frobenius_norm"], 4)

erank = {}
for method, value in res["effective_rank_avg"].items():
    method_id = mid(method)
    erank[method_id] = round(value, 3)

# (output key, source key) for the principal-angle fields rounded to 2 places.
_PA_ROUNDED = (
    ("median_top1", "median_top1"),
    ("median_top10_max", "median_top10_max"),
    ("p25", "p25_top10_max"),
    ("p75", "p75_top10_max"),
)
pa = {}
for name, stats in L("principal_angles_summary.json").items():
    left, right = split_pair(name)
    entry = {out_key: round(stats[src_key], 2) for out_key, src_key in _PA_ROUNDED}
    entry["n_modules"] = stats["n_modules"]   # copied through unrounded
    pa[pair_key(left, right)] = entry

# LMC keys are deliberately NOT passed through pair_key(): the original order
# of the pair is kept because it encodes the interpolation direction (a→b).
lmc = {}
for name, losses in res["runtime"]["mode_connectivity_loss"].items():
    left, right = split_pair(name)
    series = [round(x, 4) for x in losses]
    lmc["__".join((mid(left), mid(right)))] = series

geometry_payload = {
    "frobenius": frob,
    "effective_rank": erank,
    "principal_angles": pa,
    "lmc": lmc,
    "alphas": res["runtime"]["alphas"],
}
prediction_depth = {}
for method, value in res["runtime"]["prediction_depth_mean"].items():
    method_id = mid(method)
    prediction_depth[method_id] = round(value, 2)
geometry_payload["prediction_depth"] = prediction_depth
W("geometry.json", geometry_payload)

# ---------- accuracy ----------
# Both tables map the method name exactly as spelled in the input file to a
# percentage with one decimal place.
# NOTE(review): these lookups do not go through mid(), so the offline GRPO row
# depends on the files spelling it "offline_grpo" - confirm against the inputs.
gsm = {}
for record in L("gsm8k_accuracy_all.json"):
    name = record["method"]
    gsm[name] = round(record["pass@1_greedy"] * 100, 1)
aime = {}
for record in L("aime26_qwen3_accuracy.json"):
    name = record["method"]
    aime[name] = round(record["pass_at_1"] * 100, 1)


def _row(method, label, gsm8k, aime26, group):
    """Build one entry of the ``methods`` list with a fixed key order."""
    return {"method": method, "label": label, "gsm8k": gsm8k, "aime26": aime26, "group": group}


def _measured(method, label, source=None, fallback=(None, None)):
    """Build an "offline" row from the loaded tables.

    ``source`` is the name to look up (defaults to ``method``); ``fallback``
    holds the (gsm8k, aime26) values used when that name is absent.
    """
    lookup = method if source is None else source
    return _row(method, label, gsm.get(lookup, fallback[0]), aime.get(lookup, fallback[1]), "offline")


# base + online from RESULTS.md headline table
rows = [
    _row("base", "Base instruct", 94.0, 16.7, "base"),
    _measured("sft", "SFT", fallback=(87.6, 10.0)),
    _measured("rft", "RFT"),
    _measured("dft", "DFT"),
    _measured("rift", "RIFT"),
    _measured("grpo", "Offline GRPO", source="offline_grpo"),
    _measured("dpo", "DPO", fallback=(93.5, 13.3)),
    _row("online_grpo", "Online GRPO", 93.7, 20.0, "online"),
    _row("online_dapo", "Online DAPO", 93.3, 16.7, "online"),
]

grid = L("online_accuracy_grid.json")
online_grid = {
    "grpo": grid["grpo"],
    "dapo": grid["dapo"],
    "lrs": ["5e-7", "5e-6", "5e-5"],
    "seeds": ["42", "123"],
}
W("accuracy.json", {"methods": rows, "online_grid": online_grid})

# ---------- seed / LR sensitivity ----------
slr = L("seed_lr_sensitivity.json")
dapo = L("dapo_seed_lr.json")

# (output key, source document, source key); entries are copied through as-is.
# Original note on seed_cosine's shape:
#   method -> lr -> {cos, top1_u, top1_v, med_top8_angle_deg}
_SEED_LR_FIELDS = (
    ("seed_cosine", slr, "within_config_seed_cosine"),
    ("frobenius_by_lr", slr, "frobenius_norm"),
    ("dapo_seed_cosine", dapo, "within_config_seed_cosine_dapo"),
    ("dapo_lr_sensitivity", dapo, "lr_sensitivity_dapo_seed42"),
)
seed_lr_payload = {}
for out_key, document, src_key in _SEED_LR_FIELDS:
    seed_lr_payload[out_key] = document[src_key]
seed_lr_payload["lrs"] = ["5e-7", "5e-6", "5e-5"]
W("seed_lr.json", seed_lr_payload)
print("done")