#!/usr/bin/env python3
"""Consolidate raw repo JSON into viz-ready data/ files for the blog.

Reads raw JSON result files located next to this script and writes compact
(no-whitespace) JSON files into ``../data`` relative to it: cosine.json,
per_layer.json, geometry.json, accuracy.json and seed_lr.json.
"""
import json
import os

SRC = os.path.dirname(__file__)
OUT = os.path.join(SRC, "..", "data")
os.makedirs(OUT, exist_ok=True)

# Learning-rate labels emitted alongside both the online grid and seed/LR data.
LRS = ["5e-7", "5e-6", "5e-5"]


def L(name):
    """Load and return the parsed JSON file *name* from the script directory."""
    # `with` closes the handle deterministically instead of relying on GC;
    # JSON is UTF-8 by spec, so don't depend on the platform locale encoding.
    with open(os.path.join(SRC, name), encoding="utf-8") as fh:
        return json.load(fh)


def W(name, obj):
    """Write *obj* as compact JSON to ``OUT/name`` and print its size in bytes."""
    path = os.path.join(OUT, name)
    # The file must be closed (hence flushed) before getsize() so the
    # reported size reflects everything json.dump wrote.
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(obj, fh, separators=(",", ":"))
    print("wrote", name, os.path.getsize(path), "b")


# canonical method order + aliases
ORDER = ["sft", "rft", "dft", "rift", "grpo", "dpo", "online_grpo", "online_dapo"]
ALIAS = {"offline_grpo": "grpo", "grpo": "grpo"}


def mid(m):
    """Return the canonical method id for a raw/display method name *m*.

    Strips surrounding whitespace, lowercases, turns spaces into underscores,
    then maps through ALIAS (e.g. "Offline GRPO" -> "grpo").
    """
    m = m.strip().lower().replace(" ", "_")
    return ALIAS.get(m, m)


def pair_key(a, b):
    """Return an orientation-free key ``"x__y"`` for methods *a* and *b*.

    Both names are canonicalized with mid() and ordered by their position in
    ORDER, so pair_key(a, b) == pair_key(b, a). Raises ValueError if a
    canonical id is not listed in ORDER.
    """
    a, b = mid(a), mid(b)
    return "__".join(sorted([a, b], key=ORDER.index))


def split_pair(key):
    """Split a pair key of the form ``"a vs b"`` or ``"a__b"`` into ``(a, b)``.

    Raises ValueError if the key does not split into exactly two parts.
    """
    sep = " vs " if " vs " in key else "__"
    a, b = key.split(sep)
    return a, b


def _norm_pair(key):
    """Return the canonical pair_key() for a raw pair *key* ("a vs b" / "a__b")."""
    return pair_key(*split_pair(key))


def _round_list(values, ndigits=4):
    """Return a new list with every element of *values* rounded to *ndigits*."""
    return [round(x, ndigits) for x in values]


def _layer_value(table, layer):
    """Look up *layer* in *table*, trying ``str(layer)`` first, then *layer*.

    Returns the value rounded to 4 places, or None when neither key exists.
    """
    for key in (str(layer), layer):
        if key in table:
            return round(table[key], 4)
    return None


# ---------- cosine matrices ----------
c8 = L("cosine_all8.json")
res = L("results.json")
W("cosine.json", {
    "labels8": c8["labels"],
    "matrix8": c8["cosine"],
    "labels6": [m.upper() for m in res["plot_labels"]],  # SFT RFT DFT RIFT GRPO DPO
    "ids6": [mid(m) for m in res["methods"]],  # sft..dpo (offline_grpo->grpo)
    "matrix6": res["cosine_global"],
})

# ---------- per-layer metrics ----------
layers = res["layers"]
cos_pl = {_norm_pair(k): _round_list(v) for k, v in res["cosine_per_layer"].items()}
cka_pl = {_norm_pair(k): _round_list(v) for k, v in res["runtime"]["cka_per_layer"].items()}

svd = L("top1_svd_cosine.json")
svd_u, svd_v_scalar = {}, {}
for k, d in svd.items():
    pk = _norm_pair(k)
    plm = d.get("per_layer_u_mean", {})
    # One entry per layer in `layers`; layers absent from plm become None.
    svd_u[pk] = [_layer_value(plm, i) for i in layers]
    # NOTE(review): a missing scalar is recorded as 0 rather than None — confirm intended.
    svd_v_scalar[pk] = round(d.get("mean_top1_v_cosine", 0), 4)

og = L("online_grpo_cosine.json")
ortho = og["online_grpo_orthogonal_fraction_per_layer"]
online_ortho = [round(ortho[str(i)], 4) for i in layers]
W("per_layer.json", {
    "layers": layers,
    "cosine": cos_pl,
    "cka": cka_pl,
    "svd_u": svd_u,
    "svd_v_scalar": svd_v_scalar,
    "online_grpo_ortho_off_sft": online_ortho,
})

# ---------- geometry: norms, rank, principal angles, LMC ----------
frob = {mid(m): round(d["frobenius_norm"], 4) for m, d in res["inventory"]["methods"].items()}
erank = {mid(m): round(v, 3) for m, v in res["effective_rank_avg"].items()}

pa = {}
for k, d in L("principal_angles_summary.json").items():
    pa[_norm_pair(k)] = {
        "median_top1": round(d["median_top1"], 2),
        "median_top10_max": round(d["median_top10_max"], 2),
        "p25": round(d["p25_top10_max"], 2),
        "p75": round(d["p75_top10_max"], 2),
        "n_modules": d["n_modules"],
    }

lmc = {}  # preserve α-orientation: key order == interpolation direction (a→b)
for k, v in res["runtime"]["mode_connectivity_loss"].items():
    a, b = split_pair(k)
    # Deliberately NOT pair_key(): that would sort the pair and lose direction.
    lmc[mid(a) + "__" + mid(b)] = _round_list(v)

W("geometry.json", {
    "frobenius": frob,
    "effective_rank": erank,
    "principal_angles": pa,
    "lmc": lmc,
    "alphas": res["runtime"]["alphas"],
    "prediction_depth": {
        mid(m): round(v, 2) for m, v in res["runtime"]["prediction_depth_mean"].items()
    },
})

# ---------- accuracy ----------
gsm = {r["method"]: round(r["pass@1_greedy"] * 100, 1) for r in L("gsm8k_accuracy_all.json")}
aime = {r["method"]: round(r["pass_at_1"] * 100, 1) for r in L("aime26_qwen3_accuracy.json")}

# base + online from RESULTS.md headline table (hard-coded here); sft and dpo
# fall back to hard-coded values when missing from the JSON files.
rows = [
    {"method": "base", "label": "Base instruct", "gsm8k": 94.0, "aime26": 16.7, "group": "base"},
    {"method": "sft", "label": "SFT", "gsm8k": gsm.get("sft", 87.6),
     "aime26": aime.get("sft", 10.0), "group": "offline"},
    {"method": "rft", "label": "RFT", "gsm8k": gsm.get("rft"),
     "aime26": aime.get("rft"), "group": "offline"},
    {"method": "dft", "label": "DFT", "gsm8k": gsm.get("dft"),
     "aime26": aime.get("dft"), "group": "offline"},
    {"method": "rift", "label": "RIFT", "gsm8k": gsm.get("rift"),
     "aime26": aime.get("rift"), "group": "offline"},
    # Accuracy files key this method by its raw name "offline_grpo".
    {"method": "grpo", "label": "Offline GRPO", "gsm8k": gsm.get("offline_grpo"),
     "aime26": aime.get("offline_grpo"), "group": "offline"},
    {"method": "dpo", "label": "DPO", "gsm8k": gsm.get("dpo", 93.5),
     "aime26": aime.get("dpo", 13.3), "group": "offline"},
    {"method": "online_grpo", "label": "Online GRPO", "gsm8k": 93.7, "aime26": 20.0, "group": "online"},
    {"method": "online_dapo", "label": "Online DAPO", "gsm8k": 93.3, "aime26": 16.7, "group": "online"},
]

grid = L("online_accuracy_grid.json")
W("accuracy.json", {
    "methods": rows,
    "online_grid": {
        "grpo": grid["grpo"],
        "dapo": grid["dapo"],
        "lrs": LRS,
        "seeds": ["42", "123"],
    },
})

# ---------- seed / LR sensitivity ----------
slr = L("seed_lr_sensitivity.json")
dapo = L("dapo_seed_lr.json")
W("seed_lr.json", {
    "seed_cosine": slr["within_config_seed_cosine"],  # method -> lr -> {cos, top1_u, top1_v, med_top8_angle_deg}
    "frobenius_by_lr": slr["frobenius_norm"],
    "dapo_seed_cosine": dapo["within_config_seed_cosine_dapo"],
    "dapo_lr_sensitivity": dapo["lr_sensitivity_dapo_seed42"],
    "lrs": LRS,
})
print("done")