# HONEST-RL-Calibrator / scripts/verify_controller.py
# Rushhaabhhh's picture
# HONEST-RL-Calibrator-v0
# 3040767 verified
"""End-to-end verification for the adaptive ``DifficultyController``.
Run from the project root:
PYTHONPATH=. python scripts/verify_controller.py
Three independent checks — all must pass before kicking off training:
1. **Live curriculum simulation.** Instantiates ``HonestEnvironment`` and
feeds its difficulty controller 100 rounds of synthetic outcomes (60 mixed
rounds, then 40 all-wrong math rounds) whose per-domain correctness we
control. Confirms the controller actually promotes / demotes the target
difficulty as outcomes accumulate.
2. **Empirical sampling matches the controller distribution.** Samples
5000 difficulties from the controller at a fixed target and checks the
observed frequencies against ``compute_distribution(target)``. This is
the proof that ``env.reset()`` is actually drawing from the published
distribution and not stuck on a single bucket.
3. **WandB callback injects the right keys.** Calls
``DifficultyControllerLogCallback.on_log`` (or an inline shim when the real
callback cannot be imported) with a ``logs`` dict that holds only a dummy
``loss`` entry and confirms the right ``difficulty/<domain>/*`` keys land in
it — this is exactly what TRL forwards to WandB.
The script exits 0 on success, 1 on any failure, and prints a diff so you
can see *what* drifted if a check is borderline.
"""
from __future__ import annotations
import math
import random
import sys
import warnings
from collections import Counter
from pathlib import Path
# Put the repository root (one directory above this script's folder) at the
# front of ``sys.path`` so that ``python scripts/verify_controller.py`` can
# resolve the project's own packages.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Silence every warning category for the remainder of the run.
warnings.filterwarnings("ignore")
from server.difficulty import ( # noqa: E402
DifficultyController,
compute_distribution,
)
# ---------------------------------------------------------------------------
# Pretty printing
# ---------------------------------------------------------------------------
# ANSI escape sequences: three foreground colours, bold, and "reset all".
GREEN, RED, YELLOW = "\033[32m", "\033[31m", "\033[33m"
BOLD, RESET = "\033[1m", "\033[0m"
def banner(title: str) -> None:
    """Print ``title`` as a bold ``=== ... ===`` heading after a blank line."""
    heading = f"\n{BOLD}=== {title} ==={RESET}"
    print(heading)
def ok(msg: str) -> None:
    """Print ``msg`` behind a green ``ok`` marker."""
    line = f" {GREEN}ok{RESET} {msg}"
    print(line)
def fail(msg: str) -> None:
    """Print ``msg`` behind a red ``FAIL`` marker."""
    line = f" {RED}FAIL{RESET} {msg}"
    print(line)
def info(msg: str) -> None:
    """Print ``msg`` behind a yellow ``..`` marker (progress / detail line)."""
    line = f" {YELLOW}..{RESET} {msg}"
    print(line)
# ---------------------------------------------------------------------------
# Test 1 — live curriculum simulation through the real env
# ---------------------------------------------------------------------------
def test_live_curriculum() -> bool:
    """Feed synthetic outcomes to the env's controller and watch targets move.

    ``HonestEnvironment`` is instantiated only to reach the controller it owns
    (``env.difficulty_controller``).  ``env.step`` is never called, so no
    model output and no verifier is involved: outcomes are pushed straight
    into ``record_outcome``.  This isolates the curriculum behaviour from the
    verifier wiring (which is exercised separately by data/tests/).

    Phase A records 60 rounds (math always correct, code always wrong, logic
    a seeded coin flip) and checks that the math target reached at least 3
    while the code target stayed at 1.  Phase B records 40 wrong math
    outcomes and checks that the math target dropped.  The logic domain is
    only printed, never checked.

    Returns:
        ``True`` if every check passed, ``False`` otherwise.
    """
    banner("Test 1: live curriculum on HonestEnvironment")
    from server.environment import HonestEnvironment

    def _acc(value) -> str:
        # The log-callback shim in this file guards against a ``None``
        # rolling accuracy, so the snapshot may contain one; formatting None
        # with ``:.2f`` would raise TypeError and abort the whole script.
        return "n/a" if value is None else f"{value:.2f}"

    env = HonestEnvironment()
    controller = env.difficulty_controller

    # Phase A: math always correct, code always wrong, logic 50/50.
    # Expect math to climb, code to stay at floor, logic to drift around.
    rng = random.Random(42)
    for ep in range(60):
        # We avoid running env.step because that would force us to provide
        # answers the various verifiers will accept (e.g. canonical APPS
        # solutions). Instead, we exercise the controller directly the
        # same way env.step does.
        controller.record_outcome("math", correct=True)
        controller.record_outcome("code", correct=False)
        controller.record_outcome("logic", correct=rng.random() < 0.5)
        if (ep + 1) % 10 == 0:
            snap = controller.snapshot()
            info(
                f"ep={ep+1:3d} "
                f"math t={snap['math']['target_difficulty']} "
                f"acc={_acc(snap['math']['rolling_accuracy'])} | "
                f"code t={snap['code']['target_difficulty']} "
                f"acc={_acc(snap['code']['rolling_accuracy'])} | "
                f"logic t={snap['logic']['target_difficulty']} "
                f"acc={_acc(snap['logic']['rolling_accuracy'])}"
            )

    snap = controller.snapshot()
    passed = True

    # math should have climbed multiple times (1 → 2 after first 20 outcomes,
    # cooldown=10, so after 60 we expect target_difficulty in {3, 4}).
    if snap["math"]["target_difficulty"] >= 3:
        ok(f"math climbed to target={snap['math']['target_difficulty']} after 60 correct outcomes")
    else:
        fail(
            f"math target only reached {snap['math']['target_difficulty']} "
            "after 60 correct outcomes (expected ≥ 3)"
        )
        passed = False

    # code should be pinned at 1 (already at floor; can't go lower).
    if snap["code"]["target_difficulty"] == 1:
        ok("code pinned at target=1 under 0% accuracy (floor respected)")
    else:
        fail(f"code drifted to target={snap['code']['target_difficulty']} (expected 1)")
        passed = False

    # Phase B: invert math — feed all wrong, expect demotion relative to the
    # target captured in ``snap`` at the end of phase A.
    for _ in range(40):
        controller.record_outcome("math", correct=False)
    new_math_target = controller.get_target("math")
    if new_math_target < snap["math"]["target_difficulty"]:
        ok(
            f"math demoted from {snap['math']['target_difficulty']} → "
            f"{new_math_target} after 40 wrong outcomes"
        )
    else:
        fail(
            f"math did not demote: still at {new_math_target} (was "
            f"{snap['math']['target_difficulty']})"
        )
        passed = False
    return passed
# ---------------------------------------------------------------------------
# Test 2 — empirical sampling matches the published distribution
# ---------------------------------------------------------------------------
def test_sampling_matches_distribution() -> bool:
    """Compare sampled difficulty frequencies with ``compute_distribution``.

    For each target in ``[1, 3, 5]`` the math domain's target difficulty is
    set directly on the controller state, 5000 difficulties are sampled with
    a seeded RNG, and the observed frequency of each difficulty 1..5 is
    compared with the published probability.  A bucket passes when the
    absolute difference is at most ``max(3 * sigma, 0.01)``, where ``sigma``
    is the binomial standard error ``sqrt(p * (1 - p) / n)``.

    Each target is reported on its own: a failure for one target does not
    suppress the ``ok`` line of a later target that passes.

    Returns:
        ``True`` if every bucket of every target is within tolerance.
    """
    banner("Test 2: empirical sampling matches compute_distribution()")
    ctrl = DifficultyController(["math", "code", "logic"])
    rng = random.Random(20260426)
    n = 5000
    levels = [1, 2, 3, 4, 5]
    overall = True
    for target in [1, 3, 5]:
        ctrl.state["math"].target_difficulty = target
        expected = compute_distribution(target)
        samples = [ctrl.sample_difficulty("math", rng=rng) for _ in range(n)]
        counts = Counter(samples)
        info(f"target={target} expected={[f'{p:.3f}' for p in expected]}")
        observed = [counts[d] / n for d in levels]
        info(f"target={target} observed={[f'{p:.3f}' for p in observed]}")

        # Pass/fail is tracked per target so the summary line below reflects
        # this target only, not failures carried over from earlier targets.
        target_ok = True
        worst = 0.0
        for d in levels:
            p = expected[d - 1]
            obs = observed[d - 1]
            delta = abs(obs - p)
            # A degenerate bucket (p == 0 or p == 1) has no sampling noise, so
            # only the 1pp floor applies there.
            sigma = math.sqrt(p * (1 - p) / n) if 0 < p < 1 else 0.0
            tol = max(3 * sigma, 0.01)  # 3 sigma OR 1pp, whichever larger
            worst = max(worst, delta)
            if delta > tol:
                fail(
                    f" target={target} d={d}: observed {obs:.4f} vs expected "
                    f"{p:.4f} (delta {delta:.4f} > tol {tol:.4f})"
                )
                target_ok = False
        if target_ok:
            ok(f"target={target} matches within {worst:.4f} (3σ tolerance)")
        else:
            overall = False
    return overall
# ---------------------------------------------------------------------------
# Test 3 — wandb callback injects the right keys into the logs dict
# ---------------------------------------------------------------------------
def test_wandb_callback_injection() -> bool:
    """Check that the logging callback writes ``difficulty/*`` entries.

    The real ``DifficultyControllerLogCallback`` is imported from
    ``training.train_grpo`` when possible; if that import raises any
    ``Exception`` an inline shim with the same ``on_log`` shape is used
    instead.  The callback is invoked once on a ``logs`` dict holding only a
    dummy ``loss`` value, after 20 correct math outcomes were recorded.

    Three things are then checked, stopping at the first failure: all
    ``difficulty/<domain>/<key>`` entries exist for the three domains, the
    logged math target equals 2, and the logged math ``dist_d1`` / ``dist_d3``
    / ``dist_d5`` values equal ``compute_distribution(2)`` within ``1e-9``.

    Returns:
        ``True`` if all three checks passed, ``False`` otherwise.
    """
    banner("Test 3: DifficultyControllerLogCallback injects the right keys")
    from server.difficulty import compute_distribution

    class _InlineShim:
        """Stand-in used when the real callback's module cannot be imported.

        Importing ``training.train_grpo`` pulls in heavy ML dependencies
        (torch / trl / unsloth), so this shim mirrors the expected shape of
        the real callback's ``on_log`` output.
        """

        def __init__(self, controller):
            self.controller = controller

        def on_log(self, args, state, control, logs=None, **kwargs):
            if logs is None:
                return
            for domain, entry in self.controller.snapshot().items():
                logs[f"difficulty/{domain}/target"] = entry["target_difficulty"]
                accuracy = entry["rolling_accuracy"]
                # A missing rolling accuracy is logged as 0.0.
                logs[f"difficulty/{domain}/rolling_acc"] = (
                    0.0 if accuracy is None else accuracy
                )
                probabilities = entry["distribution"]
                for suffix, position in (("dist_d1", 0), ("dist_d3", 2), ("dist_d5", 4)):
                    logs[f"difficulty/{domain}/{suffix}"] = probabilities[position]

    # Prefer the real callback; any import-time failure selects the shim.
    try:
        from training.train_grpo import DifficultyControllerLogCallback as _Real
    except Exception as exc:
        info(f"real callback import skipped ({type(exc).__name__}); using inline shim")
        callback_cls = _InlineShim
    else:
        callback_cls = _Real
        info("using real DifficultyControllerLogCallback from training.train_grpo")

    ctrl = DifficultyController(["math", "code", "logic"])
    # Give the math domain some history so the logged values are non-trivial.
    for _ in range(20):
        ctrl.record_outcome("math", correct=True)

    cb = callback_cls(ctrl)
    logs: dict = {"loss": 0.42}  # pretend TRL handed us a logs dict
    cb.on_log(args=None, state=None, control=None, logs=logs)

    domains = ("math", "code", "logic")
    suffixes = ("target", "rolling_acc", "dist_d1", "dist_d3", "dist_d5")
    expected_keys = {
        f"difficulty/{domain}/{suffix}" for domain in domains for suffix in suffixes
    }
    missing = expected_keys.difference(logs)
    if missing:
        fail(f"callback did not inject keys: {sorted(missing)}")
        return False
    ok(f"all 15 difficulty/* keys present in logs (math target = {logs['difficulty/math/target']})")

    # Spot-check the math target value.
    if logs["difficulty/math/target"] != 2:
        fail(f"math target should be 2 after 20 correct, got {logs['difficulty/math/target']}")
        return False
    ok("math target=2 after 20 correct outcomes (one cooldown-elapsed promotion)")

    # Spot-check the three logged distribution entries.
    dist = compute_distribution(2)
    for d_idx, key in ((0, "dist_d1"), (2, "dist_d3"), (4, "dist_d5")):
        logged = logs[f"difficulty/math/{key}"]
        if abs(logged - dist[d_idx]) > 1e-9:
            fail(f"math {key} mismatch")
            return False
    ok("distribution values in logs match compute_distribution(2)")
    return True
# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------
def main() -> int:
    """Run all three checks, print a summary, and return the exit status.

    Every check runs before the summary is printed, in the order
    live curriculum, sampling distribution, WandB callback.

    Returns:
        ``0`` when every check returned a truthy result, ``1`` otherwise.
    """
    checks = (
        ("live_curriculum", test_live_curriculum),
        ("sampling_distribution", test_sampling_matches_distribution),
        ("wandb_callback", test_wandb_callback_injection),
    )
    results = {name: check() for name, check in checks}

    banner("Summary")
    for name, passed in results.items():
        if passed:
            status = f"{GREEN}PASS{RESET}"
        else:
            status = f"{RED}FAIL{RESET}"
        print(f" {status} {name}")

    if not all(results.values()):
        print(f"\n{RED}{BOLD}One or more checks failed.{RESET} Investigate before training.")
        return 1
    print(f"\n{GREEN}{BOLD}All controller verifications passed.{RESET} Safe to start training.")
    return 0
# Script entry point: the process exit status is whatever ``main`` returns.
if __name__ == "__main__":
    raise SystemExit(main())