"""Post-hoc W&B import for the 3 trained ClarifyRL runs. Reads the local TRL ``log_history.json`` files saved during training, replays every step into a public W&B project so judges can click in and explore the curves. Eval JSONs are summarized into a final ``eval/`` block per run. Idempotent: re-runs overwrite the existing W&B run for the same id. Usage:: WANDB_API_KEY=... python scripts/backfill_wandb.py """ from __future__ import annotations import json import os import sys from pathlib import Path from typing import Any import wandb REPO_ROOT = Path(__file__).resolve().parent.parent PROJECT = "clarify-rl" ENTITY = os.environ.get("WANDB_ENTITY") # default: user's default entity # --------------------------------------------------------------------------- # Run definitions: where to find the log history + eval, and how to label them # --------------------------------------------------------------------------- RUNS = [ { "id": "run1-0p6b-no-kl", "name": "Run 1 — 0.6B GRPO (β=0)", "group": "GRPO ablation", "tags": ["0.6B", "no-KL", "grpo", "completed"], "config": { "model": "Qwen/Qwen3-0.6B", "size_b": 0.6, "max_steps": 300, "beta_kl": 0.0, "lr": 1e-6, "num_generations": 8, "vllm_gpu_mem_util": 0.55, "max_completion_len": 1024, "save_steps": 50, "compute": "HF Jobs a10g-large", "training_account": "agarwalanu3103", "training_minutes": 41.0, "hub_repo": "agarwalanu3103/clarify-rl-grpo-qwen3-0-6b", }, "log_history": "outputs/run_artifacts/0.6B/log_history.json", "eval_dir": "outputs/run_artifacts/0.6B/evals", "summary_path": "outputs/run_artifacts/0.6B/training_summary.json", "notes": ( "Run 1: baseline GRPO at 0.6B with no KL anchor. Successfully " "unlocked the event_planning family (0/12 -> 1/12, max 0.382). " "Aggregate score 0.0 -> 0.0076. Trains the 'GRPO can teach a " "weak base new behaviors' part of the thesis." ), }, { "id": "run2-1p7b-no-kl", "name": "Run 2 — 1.7B GRPO (β=0)", "group": "GRPO ablation", "tags": ["1.7B", "no-KL", "grpo", "completed", "regression"], "config": { "model": "Qwen/Qwen3-1.7B", "size_b": 1.7, "max_steps": 400, "beta_kl": 0.0, "lr": 1e-6, "num_generations": 8, "vllm_gpu_mem_util": 0.40, "max_completion_len": 768, "save_steps": 50, "compute": "HF Jobs a100-large", "training_account": "agarwalanu3103", "training_minutes": 73.5, "hub_repo": "agarwalanu3103/clarify-rl-grpo-qwen3-1-7b", }, "log_history": "outputs/run_artifacts/1.7B/log_history.json", "eval_dir": "outputs/run_artifacts/1.7B/evals", "summary_path": "outputs/run_artifacts/1.7B/training_summary.json", "notes": ( "Run 2: same GRPO recipe as Run 1 scaled to 1.7B. AGGREGATE " "REGRESSION: 0.0669 -> 0.0286. Catastrophic on event_planning " "(0.138 mean -> 0.000) but raised meeting_scheduling peak to " "0.725. This is the regression Run 4 was designed to fix." ), }, { "id": "run4-1p7b-kl-anchor", "name": "Run 4 — 1.7B GRPO (β=0.2 KL anchor)", "group": "GRPO ablation", "tags": [ "1.7B", "kl-anchor", "grpo", "completed", "headline-result", ], "config": { "model": "Qwen/Qwen3-1.7B", "size_b": 1.7, "max_steps": 300, "beta_kl": 0.2, "lr": 5e-7, "num_generations": 8, "vllm_gpu_mem_util": 0.40, "max_completion_len": 768, "save_steps": 50, "compute": "HF Jobs a100-large", "training_account": "2022uec1542", "training_minutes": 78.2, "hub_repo": "2022uec1542/clarify-rl-grpo-qwen3-1-7b", }, "log_history": "outputs/run_artifacts/1.7B-KL/log_history.json", "eval_dir": "outputs/run_artifacts/1.7B-KL/evals", "summary_path": "outputs/run_artifacts/1.7B-KL/training_summary.json", "notes": ( "Run 4: same model + env + steps as Run 2, with TRL beta=0.2 KL " "anchor and half LR. CENTRAL HACKATHON FINDING: event_planning " "recovers 0.000 -> 0.175 (beats 1.7B base 0.138), aggregate " "0.029 -> 0.056. Trade-off: meeting_scheduling peak 0.725 -> " "0.350. KL stayed bounded 0.005-0.010 throughout training, " "confirming the anchor was active. Same model. Same env. Same " "steps. One hyperparameter flips the sign of the regression." ), }, ] # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _load_json(path: Path) -> Any: with path.open() as f: return json.load(f) def _eval_summary(eval_dir: Path) -> dict[str, Any]: """Pick the most recently modified eval JSON in ``eval_dir`` and pull aggregate + per-family numbers out of it. """ if not eval_dir.exists(): return {} candidates = sorted(eval_dir.glob("*.json"), key=lambda p: p.stat().st_mtime) if not candidates: return {} eval_path = candidates[-1] eval_data = _load_json(eval_path) rows = eval_data.get("rows") or eval_data.get("results") or [] if not rows: return {"eval/_source": str(eval_path.relative_to(REPO_ROOT))} avg_score = sum(r.get("score", 0.0) for r in rows) / max(len(rows), 1) completion = sum(1 for r in rows if r.get("score", 0.0) > 0) / max(len(rows), 1) fam_means: dict[str, list[float]] = {} fam_max: dict[str, float] = {} for r in rows: fam = r.get("family") or r.get("task_family") or r.get("scenario", "").split("_")[0] fam_means.setdefault(fam, []).append(r.get("score", 0.0)) fam_max[fam] = max(fam_max.get(fam, 0.0), r.get("score", 0.0)) out: dict[str, Any] = { "eval/avg_score": avg_score, "eval/completion_rate": completion, "eval/n": len(rows), "eval/_source": str(eval_path.relative_to(REPO_ROOT)), } for fam, scores in fam_means.items(): out[f"eval/family/{fam}/mean"] = sum(scores) / len(scores) out[f"eval/family/{fam}/max"] = fam_max[fam] return out def _step_metrics(row: dict) -> dict: """Strip the row down to scalar metrics safe to push to W&B for a single optimization step.""" out = {} for k, v in row.items(): if k in ("step", "epoch"): continue if isinstance(v, (int, float)): out[k] = v return out # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main() -> int: api_key = os.environ.get("WANDB_API_KEY") if not api_key: print("ERROR: WANDB_API_KEY env var is required", file=sys.stderr) return 1 wandb.login(key=api_key, relogin=True) for spec in RUNS: log_path = REPO_ROOT / spec["log_history"] if not log_path.exists(): print(f"[skip] {spec['id']}: missing {log_path}") continue rows = _load_json(log_path) print(f"\n=== {spec['name']} ===") print(f" rows: {len(rows)}") print(f" log: {log_path.relative_to(REPO_ROOT)}") run = wandb.init( project=PROJECT, entity=ENTITY, id=spec["id"], name=spec["name"], group=spec["group"], tags=spec["tags"], config=spec["config"], notes=spec["notes"], resume="allow", reinit=True, ) # Replay step-level metrics. TRL's log_history rows that contain # both training step metrics and a final summary row are heterogeneous; # we only push numeric scalars. Use row index as the global step so # curves line up cleanly across the three runs (they all start at 0). step_count = 0 for i, row in enumerate(rows): metrics = _step_metrics(row) if not metrics: continue run.log(metrics, step=i) step_count += 1 print(f" pushed: {step_count} step rows") # Final eval summary attached to the run summary so it appears in the # leaderboard view and can be charted against each other across runs. eval_summary = _eval_summary(REPO_ROOT / spec["eval_dir"]) if eval_summary: for k, v in eval_summary.items(): run.summary[k] = v print(f" eval summary keys: {sorted(eval_summary)}") # Also surface the training_summary.json scalars (reward, loss totals, # walltime, etc) as run summary for at-a-glance comparison. summary_path = REPO_ROOT / spec["summary_path"] if summary_path.exists(): ts = _load_json(summary_path) for k, v in ts.items(): if isinstance(v, (int, float, str, bool)): run.summary[f"training/{k}"] = v run.finish() print("\nDone. Public dashboard:") print(f" https://wandb.ai/anurag203/{PROJECT}") return 0 if __name__ == "__main__": sys.exit(main())