clarify-rl / scripts /backfill_wandb.py
agarwalanu3103's picture
plots: add training progression + diagnostics, drop W&B links
099bec8 verified
"""Post-hoc W&B import for the 3 trained ClarifyRL runs.
Reads the local TRL ``log_history.json`` files saved during training, replays
every step into a public W&B project so judges can click in and explore the
curves. Eval JSONs are summarized into a final ``eval/`` block per run.
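Safe to re-run: each run has a fixed id and is opened with ``resume="allow"``, so a re-run attaches to the existing W&B run for that id rather than creating a duplicate.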
Usage::
WANDB_API_KEY=... python scripts/backfill_wandb.py
"""
from __future__ import annotations
import json
import os
import sys
from pathlib import Path
from typing import Any
import wandb
# Parent of the directory that contains this script; all artifact paths in
# RUNS are joined onto it.
REPO_ROOT = Path(__file__).resolve().parent.parent
# W&B project every run is uploaded to.
PROJECT = "clarify-rl"
# W&B entity to upload under; None when WANDB_ENTITY is unset.
ENTITY = os.environ.get("WANDB_ENTITY") # default: user's default entity
# ---------------------------------------------------------------------------
# Run definitions: where to find the log history + eval, and how to label them
# ---------------------------------------------------------------------------
# One entry per training run to backfill. Keys, as consumed by ``main``:
#   id / name / group / tags / config / notes -> passed straight to
#       ``wandb.init`` (``id`` is the fixed W&B run id).
#   log_history  -> JSON file, relative to REPO_ROOT, replayed row by row.
#   eval_dir     -> directory, relative to REPO_ROOT, scanned for eval JSONs.
#   summary_path -> JSON file, relative to REPO_ROOT, whose scalar entries are
#       copied into the run summary under ``training/``.
# Only runs 1, 2 and 4 are listed here; there is no "run3" entry.
RUNS: list[dict[str, Any]] = [
    {
        "id": "run1-0p6b-no-kl",
        "name": "Run 1 — 0.6B GRPO (β=0)",
        "group": "GRPO ablation",
        "tags": ["0.6B", "no-KL", "grpo", "completed"],
        # Hyperparameters and provenance recorded as the W&B run config.
        "config": {
            "model": "Qwen/Qwen3-0.6B",
            "size_b": 0.6,
            "max_steps": 300,
            "beta_kl": 0.0,
            "lr": 1e-6,
            "num_generations": 8,
            "vllm_gpu_mem_util": 0.55,
            "max_completion_len": 1024,
            "save_steps": 50,
            "compute": "HF Jobs a10g-large",
            "training_account": "agarwalanu3103",
            "training_minutes": 41.0,
            "hub_repo": "agarwalanu3103/clarify-rl-grpo-qwen3-0-6b",
        },
        "log_history": "outputs/run_artifacts/0.6B/log_history.json",
        "eval_dir": "outputs/run_artifacts/0.6B/evals",
        "summary_path": "outputs/run_artifacts/0.6B/training_summary.json",
        "notes": (
            "Run 1: baseline GRPO at 0.6B with no KL anchor. Successfully "
            "unlocked the event_planning family (0/12 -> 1/12, max 0.382). "
            "Aggregate score 0.0 -> 0.0076. Trains the 'GRPO can teach a "
            "weak base new behaviors' part of the thesis."
        ),
    },
    {
        "id": "run2-1p7b-no-kl",
        "name": "Run 2 — 1.7B GRPO (β=0)",
        "group": "GRPO ablation",
        "tags": ["1.7B", "no-KL", "grpo", "completed", "regression"],
        # Hyperparameters and provenance recorded as the W&B run config.
        "config": {
            "model": "Qwen/Qwen3-1.7B",
            "size_b": 1.7,
            "max_steps": 400,
            "beta_kl": 0.0,
            "lr": 1e-6,
            "num_generations": 8,
            "vllm_gpu_mem_util": 0.40,
            "max_completion_len": 768,
            "save_steps": 50,
            "compute": "HF Jobs a100-large",
            "training_account": "agarwalanu3103",
            "training_minutes": 73.5,
            "hub_repo": "agarwalanu3103/clarify-rl-grpo-qwen3-1-7b",
        },
        "log_history": "outputs/run_artifacts/1.7B/log_history.json",
        "eval_dir": "outputs/run_artifacts/1.7B/evals",
        "summary_path": "outputs/run_artifacts/1.7B/training_summary.json",
        "notes": (
            "Run 2: same GRPO recipe as Run 1 scaled to 1.7B. AGGREGATE "
            "REGRESSION: 0.0669 -> 0.0286. Catastrophic on event_planning "
            "(0.138 mean -> 0.000) but raised meeting_scheduling peak to "
            "0.725. This is the regression Run 4 was designed to fix."
        ),
    },
    {
        "id": "run4-1p7b-kl-anchor",
        "name": "Run 4 — 1.7B GRPO (β=0.2 KL anchor)",
        "group": "GRPO ablation",
        "tags": [
            "1.7B",
            "kl-anchor",
            "grpo",
            "completed",
            "headline-result",
        ],
        # Hyperparameters and provenance recorded as the W&B run config.
        "config": {
            "model": "Qwen/Qwen3-1.7B",
            "size_b": 1.7,
            "max_steps": 300,
            "beta_kl": 0.2,
            "lr": 5e-7,
            "num_generations": 8,
            "vllm_gpu_mem_util": 0.40,
            "max_completion_len": 768,
            "save_steps": 50,
            "compute": "HF Jobs a100-large",
            "training_account": "2022uec1542",
            "training_minutes": 78.2,
            "hub_repo": "2022uec1542/clarify-rl-grpo-qwen3-1-7b",
        },
        "log_history": "outputs/run_artifacts/1.7B-KL/log_history.json",
        "eval_dir": "outputs/run_artifacts/1.7B-KL/evals",
        "summary_path": "outputs/run_artifacts/1.7B-KL/training_summary.json",
        "notes": (
            "Run 4: same model + env + steps as Run 2, with TRL beta=0.2 KL "
            "anchor and half LR. CENTRAL HACKATHON FINDING: event_planning "
            "recovers 0.000 -> 0.175 (beats 1.7B base 0.138), aggregate "
            "0.029 -> 0.056. Trade-off: meeting_scheduling peak 0.725 -> "
            "0.350. KL stayed bounded 0.005-0.010 throughout training, "
            "confirming the anchor was active. Same model. Same env. Same "
            "steps. One hyperparameter flips the sign of the regression."
        ),
    },
]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _load_json(path: Path) -> Any:
with path.open() as f:
return json.load(f)
def _eval_summary(eval_dir: Path) -> dict[str, Any]:
    """Summarize the most recently modified eval JSON found in ``eval_dir``.

    The newest ``*.json`` file (by modification time) is loaded and its rows
    are read from the ``"rows"`` key, falling back to ``"results"``. Each row
    contributes its ``"score"``; a missing or null score counts as 0.0.

    Rows are grouped by family, taken from ``"family"``, then
    ``"task_family"``, then the part of ``"scenario"`` before the first
    underscore.

    Args:
        eval_dir: Directory that holds the eval JSON files. It must live under
            ``REPO_ROOT`` because the source path is reported relative to it.

    Returns:
        An empty dict when the directory is missing or holds no JSON files.
        A dict with only ``eval/_source`` when the file has no rows.
        Otherwise ``eval/avg_score``, ``eval/completion_rate`` (fraction of
        rows with a score above zero), ``eval/n``, ``eval/_source`` and, per
        family, ``eval/family/<family>/mean`` and ``.../max``.
    """
    if not eval_dir.exists():
        return {}
    candidates = sorted(eval_dir.glob("*.json"), key=lambda p: p.stat().st_mtime)
    if not candidates:
        return {}
    eval_path = candidates[-1]
    eval_data = _load_json(eval_path)
    source = str(eval_path.relative_to(REPO_ROOT))
    rows = eval_data.get("rows") or eval_data.get("results") or []
    if not rows:
        return {"eval/_source": source}

    # ``or 0.0`` also covers rows that carry an explicit null score.
    scores = [r.get("score") or 0.0 for r in rows]

    by_family: dict[str, list[float]] = {}
    for r, score in zip(rows, scores):
        fam = (
            r.get("family")
            or r.get("task_family")
            or (r.get("scenario") or "").split("_")[0]
        )
        by_family.setdefault(fam, []).append(score)

    total = len(scores)  # non-zero: the empty case returned above
    out: dict[str, Any] = {
        "eval/avg_score": sum(scores) / total,
        "eval/completion_rate": sum(1 for s in scores if s > 0) / total,
        "eval/n": total,
        "eval/_source": source,
    }
    for fam, fam_scores in by_family.items():
        out[f"eval/family/{fam}/mean"] = sum(fam_scores) / len(fam_scores)
        # Take the true maximum instead of seeding with 0.0, which would hide
        # the real peak of a family whose scores are all negative.
        out[f"eval/family/{fam}/max"] = max(fam_scores)
    return out
def _step_metrics(row: dict) -> dict:
"""Strip the row down to scalar metrics safe to push to W&B for a single
optimization step."""
out = {}
for k, v in row.items():
if k in ("step", "epoch"):
continue
if isinstance(v, (int, float)):
out[k] = v
return out
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> int:
    """Upload every run listed in ``RUNS`` to W&B and return an exit code.

    For each run whose log-history file exists, this opens a W&B run with the
    run's fixed id, replays the numeric metrics of every log row, and then
    fills the run summary from the eval JSONs and ``training_summary.json``.
    Runs whose log-history file is missing are skipped with a message.

    Returns:
        1 when ``WANDB_API_KEY`` is not set, otherwise 0.

    NOTE(review): runs are opened with ``resume="allow"``; on a re-run W&B may
    ignore history rows whose step is not above the last step already logged.
    Confirm that re-running gives the intended result.
    """
    api_key = os.environ.get("WANDB_API_KEY")
    if not api_key:
        print("ERROR: WANDB_API_KEY env var is required", file=sys.stderr)
        return 1
    wandb.login(key=api_key, relogin=True)

    for spec in RUNS:
        log_path = REPO_ROOT / spec["log_history"]
        if not log_path.exists():
            print(f"[skip] {spec['id']}: missing {log_path}")
            continue
        rows = _load_json(log_path)
        print(f"\n=== {spec['name']} ===")
        print(f" rows: {len(rows)}")
        print(f" log: {log_path.relative_to(REPO_ROOT)}")

        # Using the run as a context manager guarantees it is finished even if
        # logging or summary building raises, so a failed upload does not
        # leave a dangling run behind for the next loop iteration.
        with wandb.init(
            project=PROJECT,
            entity=ENTITY,
            id=spec["id"],
            name=spec["name"],
            group=spec["group"],
            tags=spec["tags"],
            config=spec["config"],
            notes=spec["notes"],
            resume="allow",
            reinit=True,
        ) as run:
            # Replay step-level metrics. The log history mixes per-step rows
            # with a final summary row, so only numeric scalars are pushed.
            # The row index is used as the global step so curves line up
            # cleanly across the runs (they all start at 0).
            step_count = 0
            for i, row in enumerate(rows):
                metrics = _step_metrics(row)
                if not metrics:
                    continue
                run.log(metrics, step=i)
                step_count += 1
            print(f" pushed: {step_count} step rows")

            # Final eval summary goes into the run summary so it appears in
            # the leaderboard view and can be compared across runs.
            eval_summary = _eval_summary(REPO_ROOT / spec["eval_dir"])
            if eval_summary:
                for k, v in eval_summary.items():
                    run.summary[k] = v
                print(f" eval summary keys: {sorted(eval_summary)}")

            # Also surface the training_summary.json scalars (reward, loss
            # totals, walltime, etc) in the run summary for at-a-glance
            # comparison.
            summary_path = REPO_ROOT / spec["summary_path"]
            if summary_path.exists():
                ts = _load_json(summary_path)
                for k, v in ts.items():
                    if isinstance(v, (int, float, str, bool)):
                        run.summary[f"training/{k}"] = v

    # Point at the entity the runs were actually uploaded to; fall back to the
    # previously hard-coded account when WANDB_ENTITY is not set.
    dashboard_entity = ENTITY or "anurag203"
    print("\nDone. Public dashboard:")
    print(f" https://wandb.ai/{dashboard_entity}/{PROJECT}")
    return 0
if __name__ == "__main__":
sys.exit(main())