"""End-to-end verification for the adaptive ``DifficultyController``.

Run from the project root:
    PYTHONPATH=. python scripts/verify_controller.py

Three independent checks, all of which must pass before kicking off training:

1. **Live curriculum simulation.**  Feeds the controller owned by
   ``HonestEnvironment`` synthetic outcomes from a deterministic "fake
   model" whose per-domain correctness we control: 60 rounds across all
   three domains, then 40 wrong math outcomes.  ``env.step`` is never
   called.  Confirms the controller actually promotes / demotes the
   target difficulty as outcomes accumulate.

2. **Empirical sampling matches the controller distribution.**  Samples
   5000 difficulties from the controller at each of three fixed targets
   (1, 3, 5) and checks the observed frequencies against
   ``compute_distribution(target)``.  This shows that the sampler behind
   ``env.reset()`` draws from the published distribution and is not stuck
   on a single bucket.

3. **WandB callback injects the right keys.**  Calls
   ``DifficultyControllerLogCallback.on_log`` (or an inline shim when the
   training dependencies cannot be imported) with a ``logs`` dict holding
   only a dummy ``loss`` and confirms the right ``difficulty/<domain>/*``
   keys land in it; this is exactly what TRL forwards to WandB.

The script exits 0 on success, 1 on any failure, and prints a diff so you
can see *what* drifted if a check is borderline.
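
As a rough illustration of the tolerance used in check 2 (a sketch, not the
script's own code path): the helper name ``tolerance`` and the five-bucket
``expected`` list below are made up for the example and are not what
``compute_distribution`` returns.

```python
import math
import random
from collections import Counter


def tolerance(p: float, n: int, floor: float = 0.01) -> float:
    # Binomial standard error of an observed frequency over n draws,
    # widened to 3 sigma, with a 1 percentage-point floor so that buckets
    # with near-zero probability are not held to an impossibly tight band.
    sigma = math.sqrt(p * (1 - p) / n) if 0 < p < 1 else 0.0
    return max(3 * sigma, floor)


# Hypothetical distribution over difficulties 1..5 (illustration only).
expected = [0.10, 0.20, 0.40, 0.20, 0.10]
n = 5000
rng = random.Random(0)

# Draw n difficulties and convert the counts to observed frequencies.
samples = rng.choices([1, 2, 3, 4, 5], weights=expected, k=n)
counts = Counter(samples)
observed = [counts[d] / n for d in [1, 2, 3, 4, 5]]

# One boolean per bucket: is the observed frequency inside the band?
within = [abs(o - p) <= tolerance(p, n) for o, p in zip(observed, expected)]
print(within)
```

With ``n = 5000`` the 3-sigma band is about 2 percentage points for a bucket
near ``p = 0.5`` and falls back to the 1pp floor for buckets with tiny or
zero probability.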
"""

from __future__ import annotations

import math
import random
import sys
import warnings
from collections import Counter
from pathlib import Path

# Allow running as `python scripts/verify_controller.py` from the project root.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

warnings.filterwarnings("ignore")

from server.difficulty import (  # noqa: E402
    DifficultyController,
    compute_distribution,
)


# ---------------------------------------------------------------------------
# Pretty printing
# ---------------------------------------------------------------------------

GREEN = "\033[32m"
RED = "\033[31m"
YELLOW = "\033[33m"
BOLD = "\033[1m"
RESET = "\033[0m"


def banner(title: str) -> None:
    print(f"\n{BOLD}=== {title} ==={RESET}")


def ok(msg: str) -> None:
    print(f"  {GREEN}ok{RESET}    {msg}")


def fail(msg: str) -> None:
    print(f"  {RED}FAIL{RESET}  {msg}")


def info(msg: str) -> None:
    print(f"  {YELLOW}..{RESET}    {msg}")


# ---------------------------------------------------------------------------
# Test 1: live curriculum simulation through the real env
# ---------------------------------------------------------------------------


def test_live_curriculum() -> bool:
    """Drive the env through fake episodes and watch the controller move.

    We bypass the language model and the verifiers entirely: instead of
    calling ``env.step``, we record hand-picked outcomes directly on
    ``env.difficulty_controller``.  This isolates the curriculum behaviour
    from the verifier wiring (which is exercised separately by data/tests/).
    """
    banner("Test 1: live curriculum on HonestEnvironment")

    from server.environment import HonestEnvironment

    env = HonestEnvironment()

    # Phase A: math always correct, code always wrong, logic 50/50.
    # Expect math to climb, code to stay at floor, logic to drift around.
    rng = random.Random(42)
    for ep in range(60):
        # We avoid running env.step because that would force us to provide
        # answers the various verifiers will accept (e.g. canonical APPS
        # solutions).  Instead, we exercise the controller directly the
        # same way env.step does.
        env.difficulty_controller.record_outcome("math", correct=True)
        env.difficulty_controller.record_outcome("code", correct=False)
        env.difficulty_controller.record_outcome("logic", correct=rng.random() < 0.5)

        if (ep + 1) % 10 == 0:
            snap = env.difficulty_controller.snapshot()
            info(
                f"ep={ep+1:3d}  "
                f"math t={snap['math']['target_difficulty']} "
                f"acc={snap['math']['rolling_accuracy']:.2f}  |  "
                f"code t={snap['code']['target_difficulty']} "
                f"acc={snap['code']['rolling_accuracy']:.2f}  |  "
                f"logic t={snap['logic']['target_difficulty']} "
                f"acc={snap['logic']['rolling_accuracy']:.2f}"
            )

    snap = env.difficulty_controller.snapshot()
    passed = True

    # math should have climbed multiple times (1 → 2 after first 20 outcomes,
    # cooldown=10, so after 60 we expect target_difficulty in {3, 4}).
    if snap["math"]["target_difficulty"] >= 3:
        ok(f"math climbed to target={snap['math']['target_difficulty']} after 60 correct outcomes")
    else:
        fail(
            f"math target only reached {snap['math']['target_difficulty']} "
            "after 60 correct outcomes (expected ≥ 3)"
        )
        passed = False

    # code should be pinned at 1 (already at floor; can't go lower).
    if snap["code"]["target_difficulty"] == 1:
        ok("code pinned at target=1 under 0% accuracy (floor respected)")
    else:
        fail(f"code drifted to target={snap['code']['target_difficulty']} (expected 1)")
        passed = False

    # Phase B: invert math by feeding all wrong outcomes; expect demotion.
    for _ in range(40):
        env.difficulty_controller.record_outcome("math", correct=False)

    new_math_target = env.difficulty_controller.get_target("math")
    if new_math_target < snap["math"]["target_difficulty"]:
        ok(
            f"math demoted from {snap['math']['target_difficulty']} → "
            f"{new_math_target} after 40 wrong outcomes"
        )
    else:
        fail(
            f"math did not demote: still at {new_math_target} (was "
            f"{snap['math']['target_difficulty']})"
        )
        passed = False

    return passed


# ---------------------------------------------------------------------------
# Test 2: empirical sampling matches the published distribution
# ---------------------------------------------------------------------------


def test_sampling_matches_distribution() -> bool:
    banner("Test 2: empirical sampling matches compute_distribution()")

    ctrl = DifficultyController(["math", "code", "logic"])
    rng = random.Random(20260426)
    n = 5000
    overall = True

    for target in [1, 3, 5]:
        ctrl.state["math"].target_difficulty = target
        expected = compute_distribution(target)
        samples = [ctrl.sample_difficulty("math", rng=rng) for _ in range(n)]
        counts = Counter(samples)

        info(f"target={target}  expected={[f'{p:.3f}' for p in expected]}")
        observed = [counts[d] / n for d in [1, 2, 3, 4, 5]]
        info(f"target={target}  observed={[f'{p:.3f}' for p in observed]}")

        worst = 0.0
        target_ok = True
        for d in [1, 2, 3, 4, 5]:
            p = expected[d - 1]
            obs = observed[d - 1]
            # Binomial standard error of an observed frequency over n draws.
            sigma = math.sqrt(p * (1 - p) / n) if 0 < p < 1 else 0.0
            tol = max(3 * sigma, 0.01)  # 3 sigma OR 1pp, whichever larger
            delta = abs(obs - p)
            worst = max(worst, delta)
            if delta > tol:
                fail(
                    f"  target={target} d={d}: observed {obs:.4f} vs expected "
                    f"{p:.4f} (delta {delta:.4f} > tol {tol:.4f})"
                )
                target_ok = False

        # Report per target so an earlier failure does not hide later passes.
        if target_ok:
            ok(f"target={target} matches within {worst:.4f} (tolerance: max(3σ, 1pp))")
        else:
            overall = False

    return overall


# ---------------------------------------------------------------------------
# Test 3: wandb callback injects the right keys into the logs dict
# ---------------------------------------------------------------------------


def test_wandb_callback_injection() -> bool:
    banner("Test 3: DifficultyControllerLogCallback injects the right keys")

    # The callback class is defined in train_grpo.py.  Importing that module
    # has heavy ML dependencies (torch / trl / unsloth), so when the import
    # fails we fall back to the inline shim below, which re-implements the
    # same shape.  If the two ever diverge, this test is the canary.

    class _FakeCallback:
        def __init__(self, controller):
            self.controller = controller

        def on_log(self, args, state, control, logs=None, **kwargs):
            if logs is None:
                return
            snap = self.controller.snapshot()
            for domain, s in snap.items():
                logs[f"difficulty/{domain}/target"] = s["target_difficulty"]
                logs[f"difficulty/{domain}/rolling_acc"] = (
                    s["rolling_accuracy"] if s["rolling_accuracy"] is not None else 0.0
                )
                dist = s["distribution"]
                logs[f"difficulty/{domain}/dist_d1"] = dist[0]
                logs[f"difficulty/{domain}/dist_d3"] = dist[2]
                logs[f"difficulty/{domain}/dist_d5"] = dist[4]

    # Try to import the *real* callback first; fall back to the fake if the
    # heavy deps are missing.
    callback_cls = None
    try:
        from training.train_grpo import DifficultyControllerLogCallback as _Real

        callback_cls = _Real
        info("using real DifficultyControllerLogCallback from training.train_grpo")
    except Exception as exc:
        info(f"real callback import skipped ({type(exc).__name__}); using inline shim")
        callback_cls = _FakeCallback

    ctrl = DifficultyController(["math", "code", "logic"])
    # Populate a non-trivial state so the keys are interesting.
    for _ in range(20):
        ctrl.record_outcome("math", correct=True)
    cb = callback_cls(ctrl)

    logs: dict = {"loss": 0.42}  # pretend TRL handed us a logs dict
    cb.on_log(args=None, state=None, control=None, logs=logs)

    expected_keys = {
        f"difficulty/{d}/{k}"
        for d in ("math", "code", "logic")
        for k in ("target", "rolling_acc", "dist_d1", "dist_d3", "dist_d5")
    }
    missing = expected_keys - logs.keys()
    if missing:
        fail(f"callback did not inject keys: {sorted(missing)}")
        return False
    ok(f"all 15 difficulty/* keys present in logs (math target = {logs['difficulty/math/target']})")

    # Sanity-check a couple of values.
    if logs["difficulty/math/target"] != 2:
        fail(f"math target should be 2 after 20 correct, got {logs['difficulty/math/target']}")
        return False
    ok("math target=2 after 20 correct outcomes (one cooldown-elapsed promotion)")

    dist = compute_distribution(2)
    for d_idx, key in [(0, "dist_d1"), (2, "dist_d3"), (4, "dist_d5")]:
        if abs(logs[f"difficulty/math/{key}"] - dist[d_idx]) > 1e-9:
            fail(f"math {key} mismatch")
            return False
    ok("distribution values in logs match compute_distribution(2)")

    return True


# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------


def main() -> int:
    results = {
        "live_curriculum": test_live_curriculum(),
        "sampling_distribution": test_sampling_matches_distribution(),
        "wandb_callback": test_wandb_callback_injection(),
    }

    banner("Summary")
    for name, passed in results.items():
        status = f"{GREEN}PASS{RESET}" if passed else f"{RED}FAIL{RESET}"
        print(f"  {status}  {name}")

    if all(results.values()):
        print(f"\n{GREEN}{BOLD}All controller verifications passed.{RESET} Safe to start training.")
        return 0
    print(f"\n{RED}{BOLD}One or more checks failed.{RESET} Investigate before training.")
    return 1


if __name__ == "__main__":
    sys.exit(main())