| |
| """Consolidate raw repo JSON into viz-ready data/ files for the blog.""" |
| import json, os |
|
|
# Inputs live next to this script; outputs go to the sibling ../data directory.
# abspath makes SRC independent of how the script was invoked.
SRC = os.path.dirname(os.path.abspath(__file__))
OUT = os.path.join(SRC, "..", "data")
os.makedirs(OUT, exist_ok=True)


def L(name):
    """Load and return the JSON document ``name`` from the script directory."""
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(os.path.join(SRC, name), encoding="utf-8") as fh:
        return json.load(fh)


def W(name, obj):
    """Write ``obj`` as compact JSON to ``OUT/name`` and print its size in bytes.

    The file is closed (and therefore flushed) before its size is measured, so
    the reported size is correct on any Python implementation.
    """
    path = os.path.join(OUT, name)
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(obj, fh, separators=(",", ":"))
    print("wrote", name, os.path.getsize(path), "b")
|
|
| |
# Canonical method ids, in the order used to arrange the two halves of a pair key.
ORDER = ["sft", "rft", "dft", "rift", "grpo", "dpo", "online_grpo", "online_dapo"]
# Raw (normalized) names that map onto a different canonical id.
ALIAS = {"offline_grpo": "grpo", "grpo": "grpo"}
# id -> position in ORDER, so sorting a pair is a dict lookup rather than list.index.
_RANK = {m: i for i, m in enumerate(ORDER)}


def mid(m):
    """Return the canonical method id for a raw method name.

    The name is stripped, lower-cased, spaces are replaced by underscores, and
    ALIAS is applied, e.g. ``" Offline GRPO"`` -> ``"grpo"``.
    """
    m = m.strip().lower().replace(" ", "_")
    return ALIAS.get(m, m)


def pair_key(a, b):
    """Return the canonical ``"x__y"`` key for an unordered pair of methods.

    Both names are normalized with :func:`mid` and joined in ORDER order, so
    ``pair_key("dpo", "sft") == pair_key("sft", "dpo") == "sft__dpo"``.

    Raises:
        ValueError: if either normalized id is not listed in ORDER.
    """
    ids = [mid(a), mid(b)]
    unknown = [m for m in ids if m not in _RANK]
    if unknown:
        raise ValueError(f"unknown method id(s) {unknown!r}; expected one of {ORDER}")
    return "__".join(sorted(ids, key=_RANK.__getitem__))


def split_pair(key):
    """Split a pair key into its two raw method names.

    Accepts both the human-readable ``"A vs B"`` form and the ``"a__b"`` form;
    the halves are returned un-normalized.

    Raises:
        ValueError: if ``key`` does not split into exactly two parts.
    """
    sep = " vs " if " vs " in key else "__"
    parts = key.split(sep)
    if len(parts) != 2:
        raise ValueError(f"cannot split pair key {key!r} on {sep!r}")
    a, b = parts
    return a, b
|
|
| |
c8 = L("cosine_all8.json")
res = L("results.json")

# cosine.json: the matrix from cosine_all8.json plus the global cosine matrix
# and method list from results.json (labels upper-cased, ids canonicalized).
cosine_payload = {
    "labels8": c8["labels"],
    "matrix8": c8["cosine"],
}
cosine_payload["labels6"] = [label.upper() for label in res["plot_labels"]]
cosine_payload["ids6"] = [mid(method) for method in res["methods"]]
cosine_payload["matrix6"] = res["cosine_global"]
W("cosine.json", cosine_payload)
|
|
| |
layers = res["layers"]


def _round4(values):
    """Return ``values`` as a list with every element rounded to 4 decimals."""
    return [round(x, 4) for x in values]


def _layer_value(per_layer, layer):
    """Return ``per_layer[layer]`` rounded to 4 decimals, or None if absent/null.

    JSON object keys are always strings, so ``str(layer)`` is tried first and
    the raw ``layer`` value second.
    """
    value = per_layer.get(str(layer), per_layer.get(layer))
    return None if value is None else round(value, 4)


# Per-layer pairwise curves from results.json, keyed by canonical pair_key.
cos_pl = {pair_key(*split_pair(k)): _round4(v) for k, v in res["cosine_per_layer"].items()}
cka_pl = {pair_key(*split_pair(k)): _round4(v) for k, v in res["runtime"]["cka_per_layer"].items()}

svd = L("top1_svd_cosine.json")
svd_u, svd_v_scalar = {}, {}
for k, d in svd.items():
    pk = pair_key(*split_pair(k))
    plm = d.get("per_layer_u_mean", {})
    # Layers absent from the file are emitted as None rather than raising.
    svd_u[pk] = [_layer_value(plm, i) for i in layers]
    # A missing scalar is emitted as None, not a fabricated 0.0 cosine,
    # consistent with how missing per-layer values are handled above.
    v_cos = d.get("mean_top1_v_cosine")
    svd_v_scalar[pk] = None if v_cos is None else round(v_cos, 4)

og = L("online_grpo_cosine.json")
ortho = og["online_grpo_orthogonal_fraction_per_layer"]
online_ortho = [round(ortho[str(i)], 4) for i in layers]
W("per_layer.json", {
    "layers": layers,
    "cosine": cos_pl,
    "cka": cka_pl,
    "svd_u": svd_u,
    "svd_v_scalar": svd_v_scalar,
    "online_grpo_ortho_off_sft": online_ortho,
})
|
|
| |
# Per-method scalars from results.json, keyed by canonical method id.
frob = {}
for name, info in res["inventory"]["methods"].items():
    frob[mid(name)] = round(info["frobenius_norm"], 4)
erank = {}
for name, val in res["effective_rank_avg"].items():
    erank[mid(name)] = round(val, 3)

# Principal-angle summary statistics per canonical method pair.
pa = {}
for pair, summ in L("principal_angles_summary.json").items():
    pa[pair_key(*split_pair(pair))] = {
        "median_top1": round(summ["median_top1"], 2),
        "median_top10_max": round(summ["median_top10_max"], 2),
        "p25": round(summ["p25_top10_max"], 2),
        "p75": round(summ["p75_top10_max"], 2),
        "n_modules": summ["n_modules"],
    }

# NOTE(review): unlike the other pairwise dicts, LMC keys keep the pair order
# from results.json instead of pair_key's canonical order -- presumably because
# each loss curve is directional along `alphas`; confirm before canonicalizing.
lmc = {}
for pair, curve in res["runtime"]["mode_connectivity_loss"].items():
    first, second = split_pair(pair)
    lmc["__".join((mid(first), mid(second)))] = [round(x, 4) for x in curve]

alphas = res["runtime"]["alphas"]
pred_depth = {mid(name): round(val, 2) for name, val in res["runtime"]["prediction_depth_mean"].items()}
W("geometry.json", {
    "frobenius": frob,
    "effective_rank": erank,
    "principal_angles": pa,
    "lmc": lmc,
    "alphas": alphas,
    "prediction_depth": pred_depth,
})
|
|
| |
# Percent accuracies keyed by the raw method names used in each results file.
gsm = {rec["method"]: round(rec["pass@1_greedy"] * 100, 1) for rec in L("gsm8k_accuracy_all.json")}
aime = {rec["method"]: round(rec["pass_at_1"] * 100, 1) for rec in L("aime26_qwen3_accuracy.json")}

# Offline rows: (row id, display label, key in the accuracy files,
# gsm8k fallback, aime26 fallback). A None fallback leaves the cell empty
# when the file has no entry for that method.
offline_specs = [
    ("sft", "SFT", "sft", 87.6, 10.0),
    ("rft", "RFT", "rft", None, None),
    ("dft", "DFT", "dft", None, None),
    ("rift", "RIFT", "rift", None, None),
    ("grpo", "Offline GRPO", "offline_grpo", None, None),
    ("dpo", "DPO", "dpo", 93.5, 13.3),
]

rows = [{"method": "base", "label": "Base instruct", "gsm8k": 94.0, "aime26": 16.7, "group": "base"}]
for method_id, label, src_key, gsm_fallback, aime_fallback in offline_specs:
    rows.append({
        "method": method_id,
        "label": label,
        "gsm8k": gsm.get(src_key, gsm_fallback),
        "aime26": aime.get(src_key, aime_fallback),
        "group": "offline",
    })
# Online rows use hard-coded numbers rather than values from the files above.
rows.append({"method": "online_grpo", "label": "Online GRPO", "gsm8k": 93.7, "aime26": 20.0, "group": "online"})
rows.append({"method": "online_dapo", "label": "Online DAPO", "gsm8k": 93.3, "aime26": 16.7, "group": "online"})

grid = L("online_accuracy_grid.json")
online_grid = {
    "grpo": grid["grpo"],
    "dapo": grid["dapo"],
    "lrs": ["5e-7", "5e-6", "5e-5"],
    "seeds": ["42", "123"],
}
W("accuracy.json", {"methods": rows, "online_grid": online_grid})
|
|
| |
slr = L("seed_lr_sensitivity.json")
dapo = L("dapo_seed_lr.json")

# seed_lr.json: seed/learning-rate tables passed through unchanged, plus the
# learning-rate labels used for the x-axis.
seed_lr_payload = dict(
    seed_cosine=slr["within_config_seed_cosine"],
    frobenius_by_lr=slr["frobenius_norm"],
    dapo_seed_cosine=dapo["within_config_seed_cosine_dapo"],
    dapo_lr_sensitivity=dapo["lr_sensitivity_dapo_seed42"],
    lrs=["5e-7", "5e-6", "5e-5"],
)
W("seed_lr.json", seed_lr_payload)
print("done")
|
|