Spaces:
Sleeping
Sleeping
File size: 11,652 Bytes
"""End-to-end verification for the adaptive ``DifficultyController``.
Run from the project root:
PYTHONPATH=. python scripts/verify_controller.py
Three independent checks — all must pass before kicking off training:
1. **Live curriculum simulation.** Builds a real ``HonestEnvironment`` and
feeds deterministic, scripted per-domain outcomes straight into its
controller (60 mixed-domain episodes, then 40 wrong math answers;
``env.step`` itself is bypassed). Confirms the controller actually
promotes / demotes the target difficulty as outcomes accumulate.
2. **Empirical sampling matches the controller distribution.** Samples
5000 difficulties from the controller at a fixed target and checks the
observed frequencies against ``compute_distribution(target)``. This is
the proof that ``env.reset()`` is actually drawing from the published
distribution and not stuck on a single bucket.
3. **WandB callback injects the right keys.** Calls
``DifficultyControllerLogCallback.on_log`` with a ``logs`` dict
and confirms the right ``difficulty/<domain>/*`` keys land in it — this
is exactly what TRL forwards to WandB.
The script exits 0 on success, 1 on any failure, and prints expected vs.
observed values so you can see *what* drifted if a check is borderline.
"""
from __future__ import annotations
import math
import random
import sys
import warnings
from collections import Counter
from pathlib import Path
# Allow running as `python scripts/verify_controller.py` from the project root:
# put the repository root (the parent of this script's directory) at the front
# of sys.path so the project-local ``server`` / ``training`` packages import
# without PYTHONPATH being set.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
_project_root_str = str(PROJECT_ROOT)
if _project_root_str not in sys.path:
    sys.path.insert(0, _project_root_str)

# Silence every warning category for the whole process, presumably so library
# noise does not drown out the check report.
warnings.filterwarnings("ignore")
from server.difficulty import ( # noqa: E402
DifficultyController,
compute_distribution,
)
# ---------------------------------------------------------------------------
# Pretty printing
# ---------------------------------------------------------------------------
# ANSI escape sequences used to colour the console report.
BOLD = "\033[1m"
RESET = "\033[0m"
GREEN = "\033[32m"
RED = "\033[31m"
YELLOW = "\033[33m"


def _status_line(colour: str, tag: str, msg: str) -> None:
    """Print one space-indented line: a coloured *tag* followed by *msg*."""
    print(f" {colour}{tag}{RESET} {msg}")


def banner(title: str) -> None:
    """Print a bold ``=== title ===`` section header after a blank line."""
    heading = f"=== {title} ==="
    print(f"\n{BOLD}{heading}{RESET}")


def ok(msg: str) -> None:
    """Report a passing check, tagged with a green ``ok``."""
    _status_line(GREEN, "ok", msg)


def fail(msg: str) -> None:
    """Report a failing check, tagged with a red ``FAIL``."""
    _status_line(RED, "FAIL", msg)


def info(msg: str) -> None:
    """Print an informational line, tagged with a yellow ``..``."""
    _status_line(YELLOW, "..", msg)
# ---------------------------------------------------------------------------
# Test 1 — live curriculum simulation through the real env
# ---------------------------------------------------------------------------
def test_live_curriculum() -> bool:
    """Feed scripted outcomes to a real env's controller and watch it move.

    A real ``HonestEnvironment`` is built, but ``env.step`` is never called:
    that would require answers the various verifiers accept (e.g. canonical
    APPS solutions). Outcomes are instead recorded directly on
    ``env.difficulty_controller`` via ``record_outcome`` -- the same way
    ``env.step`` drives the controller -- which isolates the curriculum
    behaviour from the verifier wiring (exercised separately by data/tests/).

    Phase A (60 episodes): math always correct, code always wrong, logic a
    seeded 50/50 coin flip. Phase B: 40 additional wrong math outcomes.

    Returns:
        True if math reached target >= 3 in phase A, code stayed pinned at 1,
        and math was demoted during phase B; False otherwise.
    """
    banner("Test 1: live curriculum on HonestEnvironment")
    from server.environment import HonestEnvironment

    env = HonestEnvironment()
    controller = env.difficulty_controller

    def fmt_acc(snapshot: dict, domain: str) -> str:
        # ``rolling_accuracy`` may be None (the log-callback shim in this file
        # guards for that case too); formatting None with ``:.2f`` would raise
        # TypeError and abort the whole report.
        acc = snapshot[domain]["rolling_accuracy"]
        return "n/a" if acc is None else f"{acc:.2f}"

    # Phase A: expect math to climb, code to stay at the floor, and logic to
    # drift around (logic is only logged, not asserted).
    rng = random.Random(42)
    for ep in range(1, 61):
        controller.record_outcome("math", correct=True)
        controller.record_outcome("code", correct=False)
        controller.record_outcome("logic", correct=rng.random() < 0.5)
        if ep % 10 == 0:
            snap = controller.snapshot()
            info(
                f"ep={ep:3d} "
                f"math t={snap['math']['target_difficulty']} "
                f"acc={fmt_acc(snap, 'math')} | "
                f"code t={snap['code']['target_difficulty']} "
                f"acc={fmt_acc(snap, 'code')} | "
                f"logic t={snap['logic']['target_difficulty']} "
                f"acc={fmt_acc(snap, 'logic')}"
            )

    snap = controller.snapshot()
    math_target = snap["math"]["target_difficulty"]
    code_target = snap["code"]["target_difficulty"]
    passed = True

    # Math should have been promoted more than once. Assuming the controller's
    # defaults (first promotion 1 -> 2 after 20 outcomes, cooldown of 10 --
    # confirm in server.difficulty), 60 correct outcomes should land on 3 or 4;
    # the check only requires >= 3.
    if math_target >= 3:
        ok(f"math climbed to target={math_target} after 60 correct outcomes")
    else:
        fail(
            f"math target only reached {math_target} "
            "after 60 correct outcomes (expected >= 3)"
        )
        passed = False

    # Code starts at the floor (1) and cannot go lower under 0% accuracy.
    if code_target == 1:
        ok("code pinned at target=1 under 0% accuracy (floor respected)")
    else:
        fail(f"code drifted to target={code_target} (expected 1)")
        passed = False

    # Phase B: invert math -- feed only wrong answers and expect the target to
    # drop below where phase A left it.
    for _ in range(40):
        controller.record_outcome("math", correct=False)
    new_math_target = controller.get_target("math")
    if new_math_target < math_target:
        ok(
            f"math demoted from {math_target} -> "
            f"{new_math_target} after 40 wrong outcomes"
        )
    else:
        fail(
            f"math did not demote: still at {new_math_target} (was "
            f"{math_target})"
        )
        passed = False

    return passed
# ---------------------------------------------------------------------------
# Test 2 — empirical sampling matches the published distribution
# ---------------------------------------------------------------------------
def test_sampling_matches_distribution() -> bool:
    """Check that ``sample_difficulty`` reproduces ``compute_distribution``.

    For targets 1, 3 and 5, forces the "math" target on a fresh controller,
    draws ``n`` samples with a seeded RNG, and compares each difficulty
    bucket's observed frequency with the published probability. A bucket
    passes when ``|observed - expected| <= max(3 * sigma, 0.01)``, where
    ``sigma = sqrt(p * (1 - p) / n)`` is the binomial standard error.

    Each target is judged on its own, so a passing target still reports
    ``ok`` even if an earlier target failed.

    Returns:
        True if every bucket of every target is within tolerance and no
        sample falls outside difficulties 1-5; False otherwise.
    """
    banner("Test 2: empirical sampling matches compute_distribution()")
    ctrl = DifficultyController(["math", "code", "logic"])
    rng = random.Random(20260426)
    n = 5000
    levels = (1, 2, 3, 4, 5)
    overall = True
    for target in (1, 3, 5):
        # Force the target directly on the controller's per-domain state.
        ctrl.state["math"].target_difficulty = target
        expected = compute_distribution(target)
        counts = Counter(ctrl.sample_difficulty("math", rng=rng) for _ in range(n))
        observed = [counts[d] / n for d in levels]
        info(f"target={target} expected={[f'{p:.3f}' for p in expected]}")
        info(f"target={target} observed={[f'{p:.3f}' for p in observed]}")

        target_ok = True
        # Out-of-range samples would otherwise only show up indirectly as a
        # shortfall in some bucket.
        stray = sorted(set(counts) - set(levels))
        if stray:
            fail(f" target={target}: sampled difficulties outside 1-5: {stray}")
            target_ok = False

        worst = 0.0
        for d in levels:
            p = expected[d - 1]
            obs = observed[d - 1]
            delta = abs(obs - p)
            # A bucket with p == 0 or p == 1 has zero variance; fall back to
            # the 1pp floor.
            sigma = math.sqrt(p * (1 - p) / n) if 0 < p < 1 else 0.0
            tol = max(3 * sigma, 0.01)  # 3 sigma OR 1pp, whichever larger
            worst = max(worst, delta)
            if delta > tol:
                fail(
                    f" target={target} d={d}: observed {obs:.4f} vs expected "
                    f"{p:.4f} (delta {delta:.4f} > tol {tol:.4f})"
                )
                target_ok = False

        if target_ok:
            ok(f"target={target} matches within {worst:.4f} (max(3-sigma, 1pp) tolerance)")
        overall = overall and target_ok
    return overall
# ---------------------------------------------------------------------------
# Test 3 — wandb callback injects the right keys into the logs dict
# ---------------------------------------------------------------------------
def test_wandb_callback_injection() -> bool:
    """Check that the WandB log callback writes ``difficulty/<domain>/*`` keys.

    Prefers the real ``DifficultyControllerLogCallback`` from
    ``training.train_grpo``; if that import raises (the module pulls in heavy
    ML dependencies), falls back to an inline shim with the expected output
    shape. The chosen callback's ``on_log`` is called on a logs dict, and the
    injected keys and math values are then checked.

    NOTE(review): when the shim is used, this only validates the controller's
    ``snapshot()`` shape, not the real callback. A divergence in
    ``train_grpo.py`` is caught only where that import succeeds.

    Returns:
        True if all expected keys are present and the math values match a
        controller promoted once (target 2); False otherwise.
    """
    banner("Test 3: DifficultyControllerLogCallback injects the right keys")

    # Inline stand-in for the real callback, mirroring the log keys it is
    # expected to emit.
    class _FakeCallback:
        def __init__(self, controller):
            self.controller = controller

        def on_log(self, args, state, control, logs=None, **kwargs):
            if logs is None:
                return
            for domain, s in self.controller.snapshot().items():
                logs[f"difficulty/{domain}/target"] = s["target_difficulty"]
                acc = s["rolling_accuracy"]
                logs[f"difficulty/{domain}/rolling_acc"] = acc if acc is not None else 0.0
                dist = s["distribution"]
                logs[f"difficulty/{domain}/dist_d1"] = dist[0]
                logs[f"difficulty/{domain}/dist_d3"] = dist[2]
                logs[f"difficulty/{domain}/dist_d5"] = dist[4]

    # Try the real callback first. The catch is deliberately broad: importing
    # the training stack can presumably fail with errors other than
    # ImportError, and any failure should fall back to the shim.
    try:
        from training.train_grpo import DifficultyControllerLogCallback as callback_cls
    except Exception as exc:
        info(f"real callback import skipped ({type(exc).__name__}); using inline shim")
        callback_cls = _FakeCallback
    else:
        info("using real DifficultyControllerLogCallback from training.train_grpo")

    ctrl = DifficultyController(["math", "code", "logic"])
    # Populate a non-trivial state so the logged values are interesting.
    for _ in range(20):
        ctrl.record_outcome("math", correct=True)

    cb = callback_cls(ctrl)
    logs: dict = {"loss": 0.42}  # pretend TRL handed us a logs dict
    cb.on_log(args=None, state=None, control=None, logs=logs)

    expected_keys = {
        f"difficulty/{d}/{k}"
        for d in ("math", "code", "logic")
        for k in ("target", "rolling_acc", "dist_d1", "dist_d3", "dist_d5")
    }
    missing = expected_keys - logs.keys()
    if missing:
        fail(f"callback did not inject keys: {sorted(missing)}")
        return False
    ok(
        f"all {len(expected_keys)} difficulty/* keys present in logs "
        f"(math target = {logs['difficulty/math/target']})"
    )

    # Sanity-check a couple of values. Assumes 20 correct outcomes trigger
    # exactly one promotion (1 -> 2) under the controller's defaults.
    math_target = logs["difficulty/math/target"]
    if math_target != 2:
        fail(f"math target should be 2 after 20 correct, got {math_target}")
        return False
    ok("math target=2 after 20 correct outcomes (one cooldown-elapsed promotion)")

    expected_dist = compute_distribution(2)
    for d_idx, key in ((0, "dist_d1"), (2, "dist_d3"), (4, "dist_d5")):
        logged = logs[f"difficulty/math/{key}"]
        # Absolute tolerance only; unlike ``abs(a - b) > tol``, isclose also
        # rejects NaN instead of silently passing it.
        if not math.isclose(logged, expected_dist[d_idx], rel_tol=0.0, abs_tol=1e-9):
            fail(
                f"math {key} mismatch: logged {logged!r}, "
                f"expected {expected_dist[d_idx]!r}"
            )
            return False
    ok("distribution values in logs match compute_distribution(2)")
    return True
# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------
def main() -> int:
    """Run every verification, print a summary, and return the exit code.

    Every check runs even if an earlier one raised. An exception is printed
    (with its traceback) and counted as a FAIL for that check, instead of
    aborting the run, so the summary always covers all checks.

    Returns:
        0 if every check passed, 1 otherwise.
    """
    import traceback

    checks = (
        ("live_curriculum", test_live_curriculum),
        ("sampling_distribution", test_sampling_matches_distribution),
        ("wandb_callback", test_wandb_callback_injection),
    )
    results: dict[str, bool] = {}
    for name, check in checks:
        try:
            results[name] = bool(check())
        except Exception as exc:
            # Report the crash as this check's failure and keep going, so one
            # broken import does not hide the other checks' results.
            traceback.print_exc()
            fail(f"{name} raised {type(exc).__name__}: {exc}")
            results[name] = False

    banner("Summary")
    for name, passed in results.items():
        status = f"{GREEN}PASS{RESET}" if passed else f"{RED}FAIL{RESET}"
        print(f" {status} {name}")

    if all(results.values()):
        print(f"\n{GREEN}{BOLD}All controller verifications passed.{RESET} Safe to start training.")
        return 0
    print(f"\n{RED}{BOLD}One or more checks failed.{RESET} Investigate before training.")
    return 1
if __name__ == "__main__":
    # Hand main()'s 0/1 result to the interpreter as the process exit status.
    raise SystemExit(main())
|