# Author: AlexWortega
# Project: Interactive blog: Same Data, Different Losses, Same Circuits?
# Commit: 378a074 (verified), file size 5.64 kB
#!/usr/bin/env python3
"""Consolidate raw repo JSON into viz-ready data/ files for the blog."""
import json, os
# Input JSON lives next to this script; output goes to a sibling "data/" dir.
# __file__ is made absolute first so the paths do not depend on the current
# working directory (it can be relative on older Pythons / some launchers).
SRC = os.path.dirname(os.path.abspath(__file__))
OUT = os.path.normpath(os.path.join(SRC, "..", "data"))
os.makedirs(OUT, exist_ok=True)
def L(name):
    """Load and return the JSON document stored at SRC/<name>.

    The file is read as UTF-8 (the JSON interchange encoding) rather than the
    platform's locale encoding, and is closed as soon as it has been parsed.
    """
    with open(os.path.join(SRC, name), encoding="utf-8") as fh:
        return json.load(fh)
def W(name, obj):
    """Write *obj* as compact JSON to OUT/<name> and print the resulting size.

    Args:
        name: file name inside OUT.
        obj: JSON-serialisable object.
    """
    path = os.path.join(OUT, name)
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(obj, fh, separators=(",", ":"))
    # The file is closed (hence flushed) here, so getsize sees the final size
    # instead of depending on when the interpreter happens to close it.
    print("wrote", name, os.path.getsize(path), "b")
# Canonical method ids in display order; pair_key() sorts pairs by position
# in this list, so every id it receives must appear here.
ORDER = [
    "sft",
    "rft",
    "dft",
    "rift",
    "grpo",
    "dpo",
    "online_grpo",
    "online_dapo",
]
# Alternative spellings -> canonical id (names not listed map to themselves).
ALIAS = {
    "offline_grpo": "grpo",
    "grpo": "grpo",
}
def mid(m):
    """Return the canonical method id for the method name *m*.

    Surrounding whitespace is stripped, the name is lower-cased, each inner
    space becomes an underscore, and the result is mapped through ALIAS
    (names without an alias are returned as normalised).
    """
    key = "_".join(m.strip().lower().split(" "))
    return ALIAS[key] if key in ALIAS else key
def pair_key(a, b):
    """Return the orientation-independent key "x__y" for a pair of methods.

    Both names are normalised with mid() and ordered by their position in
    ORDER, so (a, b) and (b, a) give the same key. A name that is not in
    ORDER raises ValueError (from list.index).
    """
    first, second = mid(a), mid(b)
    rank_first = ORDER.index(first)
    rank_second = ORDER.index(second)
    if rank_second < rank_first:
        first, second = second, first
    return first + "__" + second
def split_pair(key):
    """Split a pair key into its two (un-normalised) method names.

    Accepts the readable form "a vs b" (checked first) or the canonical
    "a__b". Raises ValueError if the key does not split into exactly two
    parts.
    """
    separator = " vs " if " vs " in key else "__"
    a, b = key.split(separator)
    return a, b
# ---------- cosine matrices ----------
# Global method-by-method cosine matrices: the one in cosine_all8.json
# (labels8/matrix8) and the one stored in results.json (labels6/matrix6).
c8 = L("cosine_all8.json")
res = L("results.json")
cosine_out = {"labels8": c8["labels"], "matrix8": c8["cosine"]}
cosine_out["labels6"] = [label.upper() for label in res["plot_labels"]]  # SFT RFT DFT RIFT GRPO DPO
cosine_out["ids6"] = [mid(name) for name in res["methods"]]  # sft..dpo (offline_grpo->grpo)
cosine_out["matrix6"] = res["cosine_global"]
W("cosine.json", cosine_out)
# ---------- per-layer metrics ----------
# Pairwise per-layer series are keyed by pair_key() and hold one value per
# entry of `layers`, rounded to 4 decimals.
layers = res["layers"]
cos_pl, cka_pl, svd_u = {}, {}, {}
for k, v in res["cosine_per_layer"].items():
    a, b = split_pair(k)
    cos_pl[pair_key(a, b)] = [round(x, 4) for x in v]
for k, v in res["runtime"]["cka_per_layer"].items():
    a, b = split_pair(k)
    cka_pl[pair_key(a, b)] = [round(x, 4) for x in v]
svd = L("top1_svd_cosine.json")
svd_v_scalar = {}
for k, d in svd.items():
    a, b = split_pair(k)
    pk = pair_key(a, b)
    # JSON object keys are always strings, so layer i is looked up as str(i).
    # A layer that is absent, or stored as null, becomes None rather than
    # crashing in round().
    plm = d.get("per_layer_u_mean", {})
    u_vals = [plm.get(str(i)) for i in layers]
    svd_u[pk] = [None if x is None else round(x, 4) for x in u_vals]
    # NOTE(review): a missing scalar defaults to 0 (not None); confirm the viz
    # does not read that as a genuine zero cosine.
    svd_v_scalar[pk] = round(d.get("mean_top1_v_cosine", 0), 4)
og = L("online_grpo_cosine.json")
ortho = og["online_grpo_orthogonal_fraction_per_layer"]
# Unlike svd_u, a missing layer here raises KeyError.
online_ortho = [round(ortho[str(i)], 4) for i in layers]
W("per_layer.json", {
    "layers": layers,
    "cosine": cos_pl,
    "cka": cka_pl,
    "svd_u": svd_u,
    "svd_v_scalar": svd_v_scalar,
    "online_grpo_ortho_off_sft": online_ortho,
})
# ---------- geometry: norms, rank, principal angles, LMC ----------
# Per-method scalars keyed by canonical method id.
frob = {}
for method, info in res["inventory"]["methods"].items():
    frob[mid(method)] = round(info["frobenius_norm"], 4)
erank = {}
for method, value in res["effective_rank_avg"].items():
    erank[mid(method)] = round(value, 3)

# Principal-angle summary per method pair: (output field, source field),
# rounded to 2 decimals; n_modules is copied unchanged.
_PA_FIELDS = (
    ("median_top1", "median_top1"),
    ("median_top10_max", "median_top10_max"),
    ("p25", "p25_top10_max"),
    ("p75", "p75_top10_max"),
)
pa = {}
for k, d in L("principal_angles_summary.json").items():
    a, b = split_pair(k)
    entry = {dst: round(d[src], 2) for dst, src in _PA_FIELDS}
    entry["n_modules"] = d["n_modules"]
    pa[pair_key(a, b)] = entry

# Mode-connectivity loss curves. The key is NOT sorted like pair_key(): it
# keeps the file's orientation, i.e. key order == interpolation direction (a→b).
lmc = {}
for k, v in res["runtime"]["mode_connectivity_loss"].items():
    a, b = split_pair(k)
    lmc["__".join((mid(a), mid(b)))] = [round(x, 4) for x in v]

runtime_alphas = res["runtime"]["alphas"]
depth = {mid(m): round(v, 2) for m, v in res["runtime"]["prediction_depth_mean"].items()}
W("geometry.json", {
    "frobenius": frob,
    "effective_rank": erank,
    "principal_angles": pa,
    "lmc": lmc,
    "alphas": runtime_alphas,
    "prediction_depth": depth,
})
# ---------- accuracy ----------
# Per-method accuracy in percent (1 decimal). Method names from the eval
# files are normalised with mid(), as in every other section, so spellings
# such as "offline_grpo" or "Offline GRPO" resolve to the canonical id "grpo"
# instead of silently producing null.
gsm = {mid(r["method"]): round(r["pass@1_greedy"] * 100, 1) for r in L("gsm8k_accuracy_all.json")}
aime = {mid(r["method"]): round(r["pass_at_1"] * 100, 1) for r in L("aime26_qwen3_accuracy.json")}
# base + online from RESULTS.md headline table. Offline methods come from the
# eval files above: sft/dpo fall back to fixed values when absent there, the
# others are written as null.
rows = [
    {"method": "base", "label": "Base instruct", "gsm8k": 94.0, "aime26": 16.7, "group": "base"},
    {"method": "sft", "label": "SFT", "gsm8k": gsm.get("sft", 87.6), "aime26": aime.get("sft", 10.0), "group": "offline"},
    {"method": "rft", "label": "RFT", "gsm8k": gsm.get("rft"), "aime26": aime.get("rft"), "group": "offline"},
    {"method": "dft", "label": "DFT", "gsm8k": gsm.get("dft"), "aime26": aime.get("dft"), "group": "offline"},
    {"method": "rift", "label": "RIFT", "gsm8k": gsm.get("rift"), "aime26": aime.get("rift"), "group": "offline"},
    {"method": "grpo", "label": "Offline GRPO", "gsm8k": gsm.get("grpo"), "aime26": aime.get("grpo"), "group": "offline"},
    {"method": "dpo", "label": "DPO", "gsm8k": gsm.get("dpo", 93.5), "aime26": aime.get("dpo", 13.3), "group": "offline"},
    {"method": "online_grpo", "label": "Online GRPO", "gsm8k": 93.7, "aime26": 20.0, "group": "online"},
    {"method": "online_dapo", "label": "Online DAPO", "gsm8k": 93.3, "aime26": 16.7, "group": "online"},
]
grid = L("online_accuracy_grid.json")
W("accuracy.json", {
    "methods": rows,
    "online_grid": {"grpo": grid["grpo"], "dapo": grid["dapo"], "lrs": ["5e-7", "5e-6", "5e-5"], "seeds": ["42", "123"]},
})
# ---------- seed / LR sensitivity ----------
# Copies the seed/learning-rate sensitivity tables through unchanged, together
# with the learning-rate labels used for display.
slr = L("seed_lr_sensitivity.json")
dapo = L("dapo_seed_lr.json")
seed_lr_out = {
    # method -> lr -> {cos, top1_u, top1_v, med_top8_angle_deg}
    "seed_cosine": slr["within_config_seed_cosine"],
    "frobenius_by_lr": slr["frobenius_norm"],
}
seed_lr_out["dapo_seed_cosine"] = dapo["within_config_seed_cosine_dapo"]
seed_lr_out["dapo_lr_sensitivity"] = dapo["lr_sensitivity_dapo_seed42"]
seed_lr_out["lrs"] = ["5e-7", "5e-6", "5e-5"]
W("seed_lr.json", seed_lr_out)
print("done")