"""End-to-end verification for the adaptive ``DifficultyController``. Run from the project root: PYTHONPATH=. python scripts/verify_controller.py Three independent checks — all must pass before kicking off training: 1. **Live curriculum simulation.** Drives ``HonestEnvironment`` through ~120 fake-step episodes with a deterministic "fake model" whose per-domain correctness we control. Confirms the controller actually promotes / demotes the target difficulty as outcomes accumulate. 2. **Empirical sampling matches the controller distribution.** Samples 5000 difficulties from the controller at a fixed target and checks the observed frequencies against ``compute_distribution(target)``. This is the proof that ``env.reset()`` is actually drawing from the published distribution and not stuck on a single bucket. 3. **WandB callback injects the right keys.** Calls ``DifficultyControllerLogCallback.on_log`` with an empty ``logs`` dict and confirms the right ``difficulty//*`` keys land in it — this is exactly what TRL forwards to WandB. The script exits 0 on success, 1 on any failure, and prints a diff so you can see *what* drifted if a check is borderline. """ from __future__ import annotations import math import random import sys import warnings from collections import Counter from pathlib import Path # Allow running as `python scripts/verify_controller.py` from the project root. 
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Silence every warning for the duration of the script so the pass/fail
# output stays readable.
warnings.filterwarnings("ignore")

from server.difficulty import (  # noqa: E402
    DifficultyController,
    compute_distribution,
)

# ---------------------------------------------------------------------------
# Pretty printing
# ---------------------------------------------------------------------------

GREEN = "\033[32m"
RED = "\033[31m"
YELLOW = "\033[33m"
BOLD = "\033[1m"
RESET = "\033[0m"


def banner(title: str) -> None:
    """Print a bold section heading, preceded by a blank line."""
    print(f"\n{BOLD}=== {title} ==={RESET}")


def ok(msg: str) -> None:
    """Print *msg* behind a green ``ok`` marker."""
    print(f" {GREEN}ok{RESET} {msg}")


def fail(msg: str) -> None:
    """Print *msg* behind a red ``FAIL`` marker."""
    print(f" {RED}FAIL{RESET} {msg}")


def info(msg: str) -> None:
    """Print *msg* behind a yellow ``..`` progress marker."""
    print(f" {YELLOW}..{RESET} {msg}")


def _format_accuracy(accuracy: float | None) -> str:
    """Render a rolling accuracy with two decimals, or ``n/a`` when missing.

    The inline callback shim in ``test_wandb_callback_injection`` guards
    against ``rolling_accuracy`` being ``None``, so the progress lines do the
    same instead of crashing on ``None.__format__``.
    NOTE(review): presumably ``None`` means "no outcomes recorded yet" —
    confirm against ``DifficultyController.snapshot``.
    """
    return "n/a" if accuracy is None else f"{accuracy:.2f}"


# ---------------------------------------------------------------------------
# Test 1 — live curriculum simulation through the real env
# ---------------------------------------------------------------------------


def test_live_curriculum() -> bool:
    """Feed scripted outcomes to the env's controller and watch it move.

    The language model and the verifiers are bypassed entirely: outcomes are
    pushed straight into ``env.difficulty_controller.record_outcome``. This
    isolates the curriculum behaviour from the verifier wiring (which,
    according to the original author, is exercised separately by
    data/tests/).

    Returns:
        ``True`` when math was promoted to a target of at least 3, code
        stayed at target 1, and math was demoted again after a run of wrong
        outcomes; ``False`` otherwise. Each check prints its own ok/FAIL
        line.
    """
    banner("Test 1: live curriculum on HonestEnvironment")
    from server.environment import HonestEnvironment

    env = HonestEnvironment()
    controller = env.difficulty_controller

    # Phase A: math always correct, code always wrong, logic 50/50.
    # Expect math to climb, code to stay at floor, logic to drift around.
    rng = random.Random(42)
    for ep in range(60):
        # We avoid running env.step because that would force us to provide
        # answers the various verifiers will accept (e.g. canonical APPS
        # solutions). Instead, we exercise the controller directly the
        # same way env.step does.
        controller.record_outcome("math", correct=True)
        controller.record_outcome("code", correct=False)
        controller.record_outcome("logic", correct=rng.random() < 0.5)
        if (ep + 1) % 10 == 0:
            snap = controller.snapshot()
            parts = []
            for domain in ("math", "code", "logic"):
                entry = snap[domain]
                parts.append(
                    f"{domain} t={entry['target_difficulty']} "
                    f"acc={_format_accuracy(entry['rolling_accuracy'])}"
                )
            info(f"ep={ep + 1:3d} " + " | ".join(parts))

    snap = controller.snapshot()
    math_target = snap["math"]["target_difficulty"]
    code_target = snap["code"]["target_difficulty"]
    passed = True

    # Original author's note: math should have climbed multiple times
    # (1 → 2 after first 20 outcomes, cooldown=10, so after 60 we expect
    # target_difficulty in {3, 4}).
    if math_target >= 3:
        ok(f"math climbed to target={math_target} after 60 correct outcomes")
    else:
        fail(
            f"math target only reached {math_target} "
            "after 60 correct outcomes (expected ≥ 3)"
        )
        passed = False

    # code should be pinned at 1 (already at floor; can't go lower).
    if code_target == 1:
        ok("code pinned at target=1 under 0% accuracy (floor respected)")
    else:
        fail(f"code drifted to target={code_target} (expected 1)")
        passed = False

    # Phase B: invert math — feed all wrong, expect demotion.
    for _ in range(40):
        controller.record_outcome("math", correct=False)
    new_math_target = controller.get_target("math")
    if new_math_target < math_target:
        ok(
            f"math demoted from {math_target} → "
            f"{new_math_target} after 40 wrong outcomes"
        )
    else:
        fail(
            f"math did not demote: still at {new_math_target} (was "
            f"{math_target})"
        )
        passed = False

    return passed


# ---------------------------------------------------------------------------
# Test 2 — empirical sampling matches the published distribution
# ---------------------------------------------------------------------------


def test_sampling_matches_distribution() -> bool:
    """Check sampled difficulty frequencies against ``compute_distribution``.

    For each target in ``[1, 3, 5]`` the math target is forced to that value,
    5000 difficulties are drawn with a seeded RNG, and every bucket's
    observed frequency must lie within ``max(3 sigma, 0.01)`` of the expected
    probability (sigma being the binomial standard error for that bucket).

    Returns:
        ``True`` when every bucket of every target is within tolerance.
    """
    banner("Test 2: empirical sampling matches compute_distribution()")
    ctrl = DifficultyController(["math", "code", "logic"])
    rng = random.Random(20260426)
    n = 5000
    buckets = [1, 2, 3, 4, 5]
    overall = True

    for target in [1, 3, 5]:
        ctrl.state["math"].target_difficulty = target
        expected = compute_distribution(target)
        counts = Counter(ctrl.sample_difficulty("math", rng=rng) for _ in range(n))
        info(f"target={target} expected={[f'{p:.3f}' for p in expected]}")
        observed = [counts[d] / n for d in buckets]
        info(f"target={target} observed={[f'{p:.3f}' for p in observed]}")

        # Track success per target: gating the ok line on the cumulative
        # flag would suppress it for every target after the first failure.
        target_ok = True
        worst = 0.0
        for d in buckets:
            p = expected[d - 1]
            obs = observed[d - 1]
            delta = abs(obs - p)
            sigma = math.sqrt(p * (1 - p) / n) if 0 < p < 1 else 0.0
            tol = max(3 * sigma, 0.01)  # 3 sigma OR 1pp, whichever larger
            worst = max(worst, delta)
            if delta > tol:
                fail(
                    f" target={target} d={d}: observed {obs:.4f} vs expected "
                    f"{p:.4f} (delta {delta:.4f} > tol {tol:.4f})"
                )
                target_ok = False

        if target_ok:
            ok(f"target={target} matches within {worst:.4f} (3σ tolerance)")
        else:
            overall = False

    return overall


# ---------------------------------------------------------------------------
# Test 3 — wandb callback injects the right keys into the logs dict
# ---------------------------------------------------------------------------


def test_wandb_callback_injection() -> bool:
    """Check that the log callback writes the ``difficulty/<domain>/*`` keys.

    Uses the real ``DifficultyControllerLogCallback`` when
    ``training.train_grpo`` can be imported, otherwise an inline shim with
    the same ``on_log`` shape. The controller is primed with 20 correct math
    outcomes, ``on_log`` is called on a small logs dict, and the injected
    keys and a few values are verified.

    Returns:
        ``True`` when all expected keys are present, the math target is 2,
        and the logged d1/d3/d5 probabilities equal
        ``compute_distribution(2)``; ``False`` at the first failed check.
    """
    banner("Test 3: DifficultyControllerLogCallback injects the right keys")

    # The callback class is defined in train_grpo.py. Importing that module
    # has heavy ML dependencies (torch / trl / unsloth), so an inline shim
    # of the same shape stands in when that import is not possible.
    class _FakeCallback:
        """Stand-in that mirrors the expected ``on_log`` key layout."""

        def __init__(self, controller):
            self.controller = controller

        def on_log(self, args, state, control, logs=None, **kwargs):
            """Write per-domain target, accuracy and d1/d3/d5 into *logs*."""
            if logs is None:
                return
            for domain, s in self.controller.snapshot().items():
                prefix = f"difficulty/{domain}"
                accuracy = s["rolling_accuracy"]
                dist = s["distribution"]
                logs[f"{prefix}/target"] = s["target_difficulty"]
                # A missing rolling accuracy is logged as 0.0.
                logs[f"{prefix}/rolling_acc"] = accuracy if accuracy is not None else 0.0
                logs[f"{prefix}/dist_d1"] = dist[0]
                logs[f"{prefix}/dist_d3"] = dist[2]
                logs[f"{prefix}/dist_d5"] = dist[4]

    # Try to import the *real* callback first; fall back to the shim if the
    # heavy deps are missing. The broad catch is deliberate: this is a
    # best-effort import and any failure should route to the shim.
    try:
        from training.train_grpo import DifficultyControllerLogCallback as _Real
    except Exception as exc:
        info(f"real callback import skipped ({type(exc).__name__}); using inline shim")
        callback_cls = _FakeCallback
    else:
        callback_cls = _Real
        info("using real DifficultyControllerLogCallback from training.train_grpo")

    ctrl = DifficultyController(["math", "code", "logic"])
    # Populate a non-trivial state so the keys are interesting.
    for _ in range(20):
        ctrl.record_outcome("math", correct=True)

    cb = callback_cls(ctrl)
    logs: dict = {"loss": 0.42}  # pretend TRL handed us a logs dict
    cb.on_log(args=None, state=None, control=None, logs=logs)

    expected_keys = {
        f"difficulty/{d}/{k}"
        for d in ("math", "code", "logic")
        for k in ("target", "rolling_acc", "dist_d1", "dist_d3", "dist_d5")
    }
    missing = expected_keys - logs.keys()
    if missing:
        fail(f"callback did not inject keys: {sorted(missing)}")
        return False
    # Derive the count from the set so the message cannot drift from it.
    ok(
        f"all {len(expected_keys)} difficulty/* keys present in logs "
        f"(math target = {logs['difficulty/math/target']})"
    )

    # Sanity-check a couple of values.
    math_target = logs["difficulty/math/target"]
    if math_target != 2:
        fail(f"math target should be 2 after 20 correct, got {math_target}")
        return False
    ok("math target=2 after 20 correct outcomes (one cooldown-elapsed promotion)")

    dist = compute_distribution(2)
    for d_idx, key in [(0, "dist_d1"), (2, "dist_d3"), (4, "dist_d5")]:
        if abs(logs[f"difficulty/math/{key}"] - dist[d_idx]) > 1e-9:
            fail(f"math {key} mismatch")
            return False
    ok("distribution values in logs match compute_distribution(2)")
    return True


# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------


def main() -> int:
    """Run every check, print a summary, and return the process exit code.

    Each check runs in isolation: an exception in one is printed with its
    traceback and counted as a failure, so the remaining checks still run
    and the summary is always shown.

    Returns:
        ``0`` when all checks passed, ``1`` otherwise.
    """
    import traceback

    checks = {
        "live_curriculum": test_live_curriculum,
        "sampling_distribution": test_sampling_matches_distribution,
        "wandb_callback": test_wandb_callback_injection,
    }
    results: dict[str, bool] = {}
    for name, check in checks.items():
        try:
            results[name] = bool(check())
        except Exception:
            # Top-level boundary: report the crash and keep going.
            fail(f"{name} raised an exception")
            traceback.print_exc()
            results[name] = False

    banner("Summary")
    for name, passed in results.items():
        status = f"{GREEN}PASS{RESET}" if passed else f"{RED}FAIL{RESET}"
        print(f" {status} {name}")

    if all(results.values()):
        print(f"\n{GREEN}{BOLD}All controller verifications passed.{RESET} Safe to start training.")
        return 0
    print(f"\n{RED}{BOLD}One or more checks failed.{RESET} Investigate before training.")
    return 1


if __name__ == "__main__":
    sys.exit(main())