Spaces:
Sleeping
Sleeping
| """End-to-end verification for the adaptive ``DifficultyController``. | |
| Run from the project root: | |
| PYTHONPATH=. python scripts/verify_controller.py | |
| Three independent checks — all must pass before kicking off training: | |
| 1. **Live curriculum simulation.** Drives ``HonestEnvironment`` through ~120 | |
| fake-step episodes with a deterministic "fake model" whose per-domain | |
| correctness we control. Confirms the controller actually promotes / | |
| demotes the target difficulty as outcomes accumulate. | |
| 2. **Empirical sampling matches the controller distribution.** Samples | |
| 5000 difficulties from the controller at a fixed target and checks the | |
| observed frequencies against ``compute_distribution(target)``. This is | |
| the proof that ``env.reset()`` is actually drawing from the published | |
| distribution and not stuck on a single bucket. | |
| 3. **WandB callback injects the right keys.** Calls | |
| ``DifficultyControllerLogCallback.on_log`` with an empty ``logs`` dict | |
| and confirms the right ``difficulty/<domain>/*`` keys land in it — this | |
| is exactly what TRL forwards to WandB. | |
| The script exits 0 on success, 1 on any failure, and prints a diff so you | |
| can see *what* drifted if a check is borderline. | |
| """ | |
| from __future__ import annotations | |
| import math | |
| import random | |
| import sys | |
| import warnings | |
| from collections import Counter | |
| from pathlib import Path | |
| # Allow running as `python scripts/verify_controller.py` from the project root. | |
| PROJECT_ROOT = Path(__file__).resolve().parents[1] | |
| if str(PROJECT_ROOT) not in sys.path: | |
| sys.path.insert(0, str(PROJECT_ROOT)) | |
| warnings.filterwarnings("ignore") | |
| from server.difficulty import ( # noqa: E402 | |
| DifficultyController, | |
| compute_distribution, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Pretty printing | |
| # --------------------------------------------------------------------------- | |
| GREEN = "\033[32m" | |
| RED = "\033[31m" | |
| YELLOW = "\033[33m" | |
| BOLD = "\033[1m" | |
| RESET = "\033[0m" | |
| def banner(title: str) -> None: | |
| print(f"\n{BOLD}=== {title} ==={RESET}") | |
| def ok(msg: str) -> None: | |
| print(f" {GREEN}ok{RESET} {msg}") | |
| def fail(msg: str) -> None: | |
| print(f" {RED}FAIL{RESET} {msg}") | |
| def info(msg: str) -> None: | |
| print(f" {YELLOW}..{RESET} {msg}") | |
| # --------------------------------------------------------------------------- | |
| # Test 1 — live curriculum simulation through the real env | |
| # --------------------------------------------------------------------------- | |
| def test_live_curriculum() -> bool: | |
| """Drive the env through fake episodes and watch the controller move. | |
| We bypass the language model entirely by injecting a hand-crafted action | |
| string and *forcing* the verifier outcome via the rolling-window helper | |
| on the controller. This isolates the curriculum behaviour from the | |
| verifier wiring (which is exercised separately by data/tests/). | |
| """ | |
| banner("Test 1: live curriculum on HonestEnvironment") | |
| from server.environment import HonestEnvironment | |
| env = HonestEnvironment() | |
| # Phase A: math always correct, code always wrong, logic 50/50. | |
| # Expect math to climb, code to stay at floor, logic to drift around. | |
| rng = random.Random(42) | |
| for ep in range(60): | |
| # We avoid running env.step because that would force us to provide | |
| # answers the various verifiers will accept (e.g. canonical APPS | |
| # solutions). Instead, we exercise the controller directly the | |
| # same way env.step does. | |
| env.difficulty_controller.record_outcome("math", correct=True) | |
| env.difficulty_controller.record_outcome("code", correct=False) | |
| env.difficulty_controller.record_outcome("logic", correct=rng.random() < 0.5) | |
| if (ep + 1) % 10 == 0: | |
| snap = env.difficulty_controller.snapshot() | |
| info( | |
| f"ep={ep+1:3d} " | |
| f"math t={snap['math']['target_difficulty']} " | |
| f"acc={snap['math']['rolling_accuracy']:.2f} | " | |
| f"code t={snap['code']['target_difficulty']} " | |
| f"acc={snap['code']['rolling_accuracy']:.2f} | " | |
| f"logic t={snap['logic']['target_difficulty']} " | |
| f"acc={snap['logic']['rolling_accuracy']:.2f}" | |
| ) | |
| snap = env.difficulty_controller.snapshot() | |
| passed = True | |
| # math should have climbed multiple times (1 → 2 after first 20 outcomes, | |
| # cooldown=10, so after 60 we expect target_difficulty in {3, 4}). | |
| if snap["math"]["target_difficulty"] >= 3: | |
| ok(f"math climbed to target={snap['math']['target_difficulty']} after 60 correct outcomes") | |
| else: | |
| fail( | |
| f"math target only reached {snap['math']['target_difficulty']} " | |
| "after 60 correct outcomes (expected ≥ 3)" | |
| ) | |
| passed = False | |
| # code should be pinned at 1 (already at floor; can't go lower). | |
| if snap["code"]["target_difficulty"] == 1: | |
| ok("code pinned at target=1 under 0% accuracy (floor respected)") | |
| else: | |
| fail(f"code drifted to target={snap['code']['target_difficulty']} (expected 1)") | |
| passed = False | |
| # Phase B: invert math — feed all wrong, expect demotion. | |
| for _ in range(40): | |
| env.difficulty_controller.record_outcome("math", correct=False) | |
| new_math_target = env.difficulty_controller.get_target("math") | |
| if new_math_target < snap["math"]["target_difficulty"]: | |
| ok( | |
| f"math demoted from {snap['math']['target_difficulty']} → " | |
| f"{new_math_target} after 40 wrong outcomes" | |
| ) | |
| else: | |
| fail( | |
| f"math did not demote: still at {new_math_target} (was " | |
| f"{snap['math']['target_difficulty']})" | |
| ) | |
| passed = False | |
| return passed | |
| # --------------------------------------------------------------------------- | |
| # Test 2 — empirical sampling matches the published distribution | |
| # --------------------------------------------------------------------------- | |
| def test_sampling_matches_distribution() -> bool: | |
| banner("Test 2: empirical sampling matches compute_distribution()") | |
| ctrl = DifficultyController(["math", "code", "logic"]) | |
| rng = random.Random(20260426) | |
| n = 5000 | |
| overall = True | |
| for target in [1, 3, 5]: | |
| ctrl.state["math"].target_difficulty = target | |
| expected = compute_distribution(target) | |
| samples = [ctrl.sample_difficulty("math", rng=rng) for _ in range(n)] | |
| counts = Counter(samples) | |
| info(f"target={target} expected={[f'{p:.3f}' for p in expected]}") | |
| observed = [counts[d] / n for d in [1, 2, 3, 4, 5]] | |
| info(f"target={target} observed={[f'{p:.3f}' for p in observed]}") | |
| worst = 0.0 | |
| for d in [1, 2, 3, 4, 5]: | |
| p = expected[d - 1] | |
| obs = observed[d - 1] | |
| sigma = math.sqrt(p * (1 - p) / n) if 0 < p < 1 else 0.0 | |
| tol = max(3 * sigma, 0.01) # 3 sigma OR 1pp, whichever larger | |
| if abs(obs - p) > tol: | |
| fail( | |
| f" target={target} d={d}: observed {obs:.4f} vs expected " | |
| f"{p:.4f} (delta {abs(obs-p):.4f} > tol {tol:.4f})" | |
| ) | |
| overall = False | |
| else: | |
| worst = max(worst, abs(obs - p)) | |
| if overall: | |
| ok(f"target={target} matches within {worst:.4f} (3σ tolerance)") | |
| return overall | |
| # --------------------------------------------------------------------------- | |
| # Test 3 — wandb callback injects the right keys into the logs dict | |
| # --------------------------------------------------------------------------- | |
| def test_wandb_callback_injection() -> bool: | |
| banner("Test 3: DifficultyControllerLogCallback injects the right keys") | |
| # The callback class is defined in train_grpo.py. Importing that module | |
| # has heavy ML dependencies (torch / trl / unsloth) — we avoid the import | |
| # cost here by re-implementing the same shape inline; if it ever | |
| # diverges, this test would be the canary. | |
| from server.difficulty import compute_distribution | |
| class _FakeCallback: | |
| def __init__(self, controller): | |
| self.controller = controller | |
| def on_log(self, args, state, control, logs=None, **kwargs): | |
| if logs is None: | |
| return | |
| snap = self.controller.snapshot() | |
| for domain, s in snap.items(): | |
| logs[f"difficulty/{domain}/target"] = s["target_difficulty"] | |
| logs[f"difficulty/{domain}/rolling_acc"] = ( | |
| s["rolling_accuracy"] if s["rolling_accuracy"] is not None else 0.0 | |
| ) | |
| dist = s["distribution"] | |
| logs[f"difficulty/{domain}/dist_d1"] = dist[0] | |
| logs[f"difficulty/{domain}/dist_d3"] = dist[2] | |
| logs[f"difficulty/{domain}/dist_d5"] = dist[4] | |
| # Try to import the *real* callback first; fall back to the fake if the | |
| # heavy deps are missing. | |
| callback_cls = None | |
| try: | |
| from training.train_grpo import DifficultyControllerLogCallback as _Real | |
| callback_cls = _Real | |
| info("using real DifficultyControllerLogCallback from training.train_grpo") | |
| except Exception as exc: | |
| info(f"real callback import skipped ({type(exc).__name__}); using inline shim") | |
| callback_cls = _FakeCallback | |
| ctrl = DifficultyController(["math", "code", "logic"]) | |
| # Populate a non-trivial state so the keys are interesting. | |
| for _ in range(20): | |
| ctrl.record_outcome("math", correct=True) | |
| cb = callback_cls(ctrl) | |
| logs: dict = {"loss": 0.42} # pretend TRL handed us a logs dict | |
| cb.on_log(args=None, state=None, control=None, logs=logs) | |
| expected_keys = { | |
| f"difficulty/{d}/{k}" | |
| for d in ("math", "code", "logic") | |
| for k in ("target", "rolling_acc", "dist_d1", "dist_d3", "dist_d5") | |
| } | |
| missing = expected_keys - logs.keys() | |
| if missing: | |
| fail(f"callback did not inject keys: {sorted(missing)}") | |
| return False | |
| ok(f"all 15 difficulty/* keys present in logs (math target = {logs['difficulty/math/target']})") | |
| # Sanity-check a couple of values. | |
| if logs["difficulty/math/target"] != 2: | |
| fail(f"math target should be 2 after 20 correct, got {logs['difficulty/math/target']}") | |
| return False | |
| ok("math target=2 after 20 correct outcomes (one cooldown-elapsed promotion)") | |
| dist = compute_distribution(2) | |
| for d_idx, key in [(0, "dist_d1"), (2, "dist_d3"), (4, "dist_d5")]: | |
| if abs(logs[f"difficulty/math/{key}"] - dist[d_idx]) > 1e-9: | |
| fail(f"math {key} mismatch") | |
| return False | |
| ok("distribution values in logs match compute_distribution(2)") | |
| return True | |
| # --------------------------------------------------------------------------- | |
| # Runner | |
| # --------------------------------------------------------------------------- | |
| def main() -> int: | |
| results = { | |
| "live_curriculum": test_live_curriculum(), | |
| "sampling_distribution": test_sampling_matches_distribution(), | |
| "wandb_callback": test_wandb_callback_injection(), | |
| } | |
| banner("Summary") | |
| for name, passed in results.items(): | |
| status = f"{GREEN}PASS{RESET}" if passed else f"{RED}FAIL{RESET}" | |
| print(f" {status} {name}") | |
| if all(results.values()): | |
| print(f"\n{GREEN}{BOLD}All controller verifications passed.{RESET} Safe to start training.") | |
| return 0 | |
| print(f"\n{RED}{BOLD}One or more checks failed.{RESET} Investigate before training.") | |
| return 1 | |
| if __name__ == "__main__": | |
| sys.exit(main()) | |