HONEST-RL-Calibrator / data /tests /test_difficulty_controller.py
Rushhaabhhh's picture
HONEST-RL-Calibrator-v0
3040767 verified
"""Tests for the adaptive ``DifficultyController`` in ``server/difficulty.py``.
Run from the project root:
PYTHONPATH=. pytest data/tests/test_difficulty_controller.py -v
"""
from __future__ import annotations
import math
import random
from collections import Counter
import pytest
from server.difficulty import (
ADAPTIVE_BUDGET,
DifficultyController,
STATIC_FLOOR,
compute_distribution,
triangular_overlay,
)
DOMAINS = ["math", "code", "logic"]
# ---------------------------------------------------------------------------
# 1. Distribution sanity
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("target", [1, 2, 3, 4, 5])
def test_distribution_sums_to_one(target):
dist = compute_distribution(target)
assert len(dist) == 5
assert math.isclose(sum(dist), 1.0, abs_tol=1e-9)
@pytest.mark.parametrize("target", [1, 2, 3, 4, 5])
def test_distribution_all_non_negative(target):
dist = compute_distribution(target)
assert all(w >= 0.0 for w in dist)
# ---------------------------------------------------------------------------
# 2. Floor preservation (catastrophic-forgetting protection)
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("target", [1, 2, 3, 4, 5])
def test_floor_preserves_d1_minimum(target):
"""Difficulty-1 weight must always be >= the static floor for d1 (0.20)."""
dist = compute_distribution(target)
assert dist[0] >= STATIC_FLOOR[0] - 1e-9, (
f"target={target}: d1 weight {dist[0]:.3f} dropped below floor"
)
@pytest.mark.parametrize("target", [1, 2, 3, 4, 5])
def test_floor_preserves_easy_combined_minimum(target):
"""Combined d1+d2 weight must always be >= 0.35 (the easy floor)."""
dist = compute_distribution(target)
easy = dist[0] + dist[1]
assert easy >= 0.35 - 1e-9, (
f"target={target}: easy floor d1+d2 = {easy:.3f} fell below 0.35"
)
def test_overlay_sums_to_budget():
for t in [1, 2, 3, 4, 5]:
overlay = triangular_overlay(t)
assert math.isclose(sum(overlay), ADAPTIVE_BUDGET, abs_tol=1e-9)
# ---------------------------------------------------------------------------
# 3. Cooldown enforcement
# ---------------------------------------------------------------------------
def test_cooldown_blocks_early_changes():
"""5 correct outcomes → not enough to update (cooldown=10 AND window not full)."""
ctrl = DifficultyController(DOMAINS)
initial = ctrl.get_target("math")
for _ in range(5):
ctrl.record_outcome("math", correct=True)
assert ctrl.get_target("math") == initial
# ---------------------------------------------------------------------------
# 4. Hysteresis up
# ---------------------------------------------------------------------------
def test_hysteresis_up_promotes_after_window_fills():
"""20 correct outcomes — window full, cooldown elapsed, accuracy=1.0 ≥ 0.75."""
ctrl = DifficultyController(DOMAINS)
assert ctrl.get_target("math") == 1
for _ in range(20):
ctrl.record_outcome("math", correct=True)
assert ctrl.get_target("math") == 2
# Cooldown reset by the bump
assert ctrl.state["math"].episodes_since_last_update == 0
# ---------------------------------------------------------------------------
# 5. Hysteresis down
# ---------------------------------------------------------------------------
def test_hysteresis_down_demotes_after_window_fills():
ctrl = DifficultyController(DOMAINS, initial_target=3)
for _ in range(20):
ctrl.record_outcome("math", correct=False)
assert ctrl.get_target("math") == 2
# ---------------------------------------------------------------------------
# 6. Hysteresis dead zone
# ---------------------------------------------------------------------------
def test_hysteresis_dead_zone_stays_put():
"""50% accuracy is in (0.25, 0.75) → no change."""
ctrl = DifficultyController(DOMAINS, initial_target=3)
outcomes = ([True, False] * 10) # 20 outcomes, 50% accuracy
for c in outcomes:
ctrl.record_outcome("math", correct=c)
assert ctrl.get_target("math") == 3
# ---------------------------------------------------------------------------
# 7. Bounds (floor / ceiling)
# ---------------------------------------------------------------------------
def test_target_does_not_drop_below_min():
ctrl = DifficultyController(DOMAINS, initial_target=1)
for _ in range(40):
ctrl.record_outcome("math", correct=False)
assert ctrl.get_target("math") == 1
def test_target_does_not_exceed_max():
ctrl = DifficultyController(DOMAINS, initial_target=5)
for _ in range(40):
ctrl.record_outcome("math", correct=True)
assert ctrl.get_target("math") == 5
# ---------------------------------------------------------------------------
# 8. Per-domain independence
# ---------------------------------------------------------------------------
def test_domains_track_independently():
ctrl = DifficultyController(DOMAINS)
for _ in range(20):
ctrl.record_outcome("math", correct=True)
ctrl.record_outcome("code", correct=False)
assert ctrl.get_target("math") == 2
assert ctrl.get_target("code") == 1 # already at floor — can't drop further
# logic was untouched
assert ctrl.get_target("logic") == 1
assert ctrl.get_rolling_accuracy("logic") is None
# ---------------------------------------------------------------------------
# 9. Empirical sampling matches computed distribution
# ---------------------------------------------------------------------------
def test_sampling_matches_distribution():
ctrl = DifficultyController(DOMAINS, initial_target=3)
rng = random.Random(20260426)
n = 10_000
samples = [ctrl.sample_difficulty("math", rng=rng) for _ in range(n)]
counts = Counter(samples)
expected = compute_distribution(3)
for d in [1, 2, 3, 4, 5]:
observed = counts[d] / n
# 2 std dev for a binomial proportion at n=10k is ~ 2 * sqrt(p*(1-p)/n)
sigma = math.sqrt(expected[d - 1] * (1 - expected[d - 1]) / n)
tol = max(2 * sigma, 0.005)
assert abs(observed - expected[d - 1]) <= tol, (
f"d={d}: empirical {observed:.4f} vs expected {expected[d-1]:.4f} "
f"(tolerance {tol:.4f})"
)
# ---------------------------------------------------------------------------
# 10. Abstain / malformed do NOT pollute the rolling window
# ---------------------------------------------------------------------------
def test_controller_only_records_real_outcomes():
"""Caller must not pass None into record_outcome; window length tracks
only the True/False outcomes that are actually fed in."""
ctrl = DifficultyController(DOMAINS)
for _ in range(3):
ctrl.record_outcome("math", correct=True)
# Simulate that abstain/malformed episodes were skipped by the caller —
# the window should reflect only the 3 real outcomes.
s = ctrl.state["math"]
assert len(s.rolling_window) == 3
assert sum(s.rolling_window) == 3
# Cooldown also only ticks on real outcomes
assert s.episodes_since_last_update == 3
# ---------------------------------------------------------------------------
# Bonus: snapshot shape
# ---------------------------------------------------------------------------
def test_snapshot_contains_expected_keys():
ctrl = DifficultyController(DOMAINS)
snap = ctrl.snapshot()
assert set(snap.keys()) == set(DOMAINS)
for s in snap.values():
assert {
"target_difficulty",
"rolling_accuracy",
"episodes_since_update",
"window_full",
"window_size",
"distribution",
} <= s.keys()
assert len(s["distribution"]) == 5