# File size: 11,968 Bytes
# 3040767
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
"""Adaptive difficulty management for the HONEST environment.

The rolling accuracy window looks at records in ``state.episode_history``
that have a ``"domain"`` key and a ``"correct"`` key.  Those records are
owned and written **exclusively by** ``HonestEnvironment.step()`` — this
module never writes history records; the only ``HonestState`` field it
mutates is ``state.domain_difficulties``.

``update_difficulty`` never modifies history: it reads history, optionally
updates the difficulty scalar in ``state.domain_difficulties``, and returns
``(new_difficulty, changed)``.  The caller (environment) is responsible for
flagging the relevant history record with ``"difficulty_changed": True``;
the hysteresis guard in ``update_difficulty`` relies on that flag to find
the most recent change for a domain.
"""

import random
from collections import deque
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

from models.models import HonestState

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# These module-level constants drive the history-based ``update_difficulty``
# API.  ``DifficultyController`` carries its own class-level window,
# thresholds (0.75 / 0.25) and cooldown.
WINDOW: int = 20               # rolling accuracy window (scored records per domain)
HIGH_THRESHOLD: float = 0.70   # accuracy strictly above this → increase difficulty
LOW_THRESHOLD: float = 0.30    # accuracy strictly below this → decrease difficulty
MIN_DIFFICULTY: int = 1        # lowest difficulty level (inclusive)
MAX_DIFFICULTY: int = 5        # highest difficulty level (inclusive)
HYSTERESIS_EPISODES: int = 10  # min global history entries between changes of one domain


# ---------------------------------------------------------------------------
# Adaptive sampling distribution: static floor + triangular overlay
# ---------------------------------------------------------------------------

# Always-on weight per difficulty 1..5 β€” protects against catastrophic
# forgetting of easy-problem competence as the curriculum advances.
STATIC_FLOOR: List[float] = [0.20, 0.15, 0.10, 0.05, 0.00]  # sums to 0.50

# Remaining weight (0.50) is distributed by a triangular kernel around
# the controller's current target_difficulty.
ADAPTIVE_BUDGET: float = 0.50


def triangular_overlay(target: int, total_weight: float = ADAPTIVE_BUDGET) -> List[float]:
    """Spread ``total_weight`` over difficulties 1..5 as a triangle peaked at ``target``.

    Each difficulty ``d`` (1-indexed) receives the raw score
    ``max(0, 3 - |target - d|)``: 3 at the peak, dropping by one per step.
    The raw scores are then rescaled so the five returned weights add up to
    ``total_weight``.  Near the edges (``target`` of 1 or 5) part of the
    triangle would fall outside 1..5; that part is simply dropped, and the
    rescaling hands its share to the in-range levels.

    Returns a list of five zeros when no level gets a positive score (only
    possible for a ``target`` far outside 1..5).
    """
    kernel = []
    for level in range(1, 6):
        kernel.append(max(0, 3 - abs(target - level)))

    mass = sum(kernel)
    if not mass:
        return [0.0] * 5

    return [k * total_weight / mass for k in kernel]


def compute_distribution(target_difficulty: int) -> List[float]:
    """Return sampling weights for difficulties 1..5 around ``target_difficulty``.

    Adds the always-on ``STATIC_FLOOR`` to the triangular overlay centred at
    ``target_difficulty``.  It then divides by the total so the weights sum to
    1.0, up to float rounding.  The floor alone contributes a positive total,
    so the division is always defined.  The result can be passed straight to
    ``random.choices(..., weights=...)``.
    """
    mixed = [
        floor + bump
        for floor, bump in zip(STATIC_FLOOR, triangular_overlay(target_difficulty))
    ]
    norm = sum(mixed)
    return [weight / norm for weight in mixed]


# ---------------------------------------------------------------------------
# DomainState + DifficultyController
# ---------------------------------------------------------------------------


def _new_rolling_window() -> deque:
    """Build an empty outcome buffer that keeps only the newest 20 entries."""
    return deque(maxlen=20)


@dataclass
class DomainState:
    """Mutable per-domain bookkeeping owned by ``DifficultyController``.

    Fields:
      target_difficulty: centre of the sampling distribution for the domain.
      rolling_window: most recent 0/1 outcomes; older entries fall off once
          20 are stored.
      episodes_since_last_update: outcomes recorded since the target last moved.
    """
    target_difficulty: int = 1
    rolling_window: deque = field(default_factory=_new_rolling_window)
    episodes_since_last_update: int = 0


class DifficultyController:
    """Adaptive difficulty controller with a static floor.

    Per-domain state: a rolling window of the last ``WINDOW_SIZE`` outcomes
    (20 by default), a target difficulty scalar, and a cooldown counter.
    Once the window is full, the target moves up when window accuracy is
    >= ``UPDATE_THRESHOLD_UP`` (0.75) and down when it is
    <= ``UPDATE_THRESHOLD_DOWN`` (0.25).  It is kept within
    ``DIFFICULTY_MIN``..``DIFFICULTY_MAX``.  At least ``COOLDOWN_EPISODES``
    (10) outcomes for the domain must be recorded between moves.

    Lifetime: one instance per ``HonestEnvironment`` (or one per training
    process for the local-rollout path).  The controller persists across
    episode boundaries — its state is *not* reset by ``env.reset()``.
    """

    UPDATE_THRESHOLD_UP: float = 0.75
    UPDATE_THRESHOLD_DOWN: float = 0.25
    COOLDOWN_EPISODES: int = 10
    WINDOW_SIZE: int = 20
    DIFFICULTY_MIN: int = MIN_DIFFICULTY
    DIFFICULTY_MAX: int = MAX_DIFFICULTY

    def __init__(self, domains: List[str], initial_target: int = 1) -> None:
        """Create fresh per-domain state for every name in ``domains``.

        Each domain starts at ``initial_target`` (used as given, not clamped).
        It also starts with an empty rolling window bounded to ``WINDOW_SIZE``
        outcomes and a zero cooldown counter.
        """
        self.domains = list(domains)
        # Size each window from WINDOW_SIZE instead of DomainState's default
        # (maxlen=20).  Otherwise a subclass that changes WINDOW_SIZE gets a
        # window that can never reach it (so ``record_outcome`` never
        # updates).  Or it gets one that overshoots it (so ``window_full`` in
        # ``snapshot`` is never True once the window grows past WINDOW_SIZE).
        self.state: Dict[str, DomainState] = {
            d: DomainState(
                target_difficulty=initial_target,
                rolling_window=deque(maxlen=self.WINDOW_SIZE),
            )
            for d in self.domains
        }

    # --- sampling ----------------------------------------------------------

    def sample_difficulty(
        self,
        domain: str,
        rng: Optional[random.Random] = None,
    ) -> int:
        """Sample a difficulty 1..5 for ``domain`` using the current distribution.

        Pass a ``random.Random`` as ``rng`` for reproducible draws.  When it
        is omitted, the module-level ``random`` functions are used.
        Unknown domains raise ``KeyError``.
        """
        target = self.state[domain].target_difficulty
        weights = compute_distribution(target)
        chooser = rng if rng is not None else random
        return chooser.choices([1, 2, 3, 4, 5], weights=weights, k=1)[0]

    # --- outcome tracking --------------------------------------------------

    def record_outcome(
        self, domain: str, correct: Optional[bool]
    ) -> Tuple[int, bool]:
        """Record an episode outcome and possibly move the domain's target.

        Parameters
        ----------
        domain:
            A domain passed to the constructor (unknown names raise
            ``KeyError``).
        correct:
            ``True`` / ``False`` for a graded answer.  ``None`` (abstain /
            malformed) is ignored.  It does not enter the rolling window and
            does not advance the cooldown counter.

        Returns
        -------
        ``(new_target_difficulty, did_update)`` — the domain's target after
        this call, and whether this call changed it.
        """
        s = self.state[domain]
        if correct is None:
            # Abstains / malformed answers carry no accuracy signal; keep them
            # out of the window as the contract requires.
            return s.target_difficulty, False

        s.rolling_window.append(1 if correct else 0)
        s.episodes_since_last_update += 1

        did_update = False
        if (
            s.episodes_since_last_update >= self.COOLDOWN_EPISODES
            and len(s.rolling_window) >= self.WINDOW_SIZE
        ):
            accuracy = sum(s.rolling_window) / len(s.rolling_window)

            if (
                accuracy >= self.UPDATE_THRESHOLD_UP
                and s.target_difficulty < self.DIFFICULTY_MAX
            ):
                s.target_difficulty += 1
                did_update = True
            elif (
                accuracy <= self.UPDATE_THRESHOLD_DOWN
                and s.target_difficulty > self.DIFFICULTY_MIN
            ):
                s.target_difficulty -= 1
                did_update = True

            if did_update:
                # Reset the cooldown but keep the rolling window — the new
                # target's accuracy estimate phases in as fresh outcomes flow.
                s.episodes_since_last_update = 0

        return s.target_difficulty, did_update

    # --- introspection -----------------------------------------------------

    def get_distribution(self, domain: str) -> List[float]:
        """Return the current sampling weights for difficulties 1..5."""
        return compute_distribution(self.state[domain].target_difficulty)

    def get_target(self, domain: str) -> int:
        """Return the domain's current target difficulty."""
        return self.state[domain].target_difficulty

    def get_rolling_accuracy(self, domain: str) -> Optional[float]:
        """Return the mean of the rolling window, or ``None`` if it is empty."""
        s = self.state[domain]
        if len(s.rolling_window) == 0:
            return None
        return sum(s.rolling_window) / len(s.rolling_window)

    def snapshot(self) -> Dict[str, dict]:
        """Return a JSON-serialisable snapshot for logging / debugging."""
        return {
            d: {
                "target_difficulty": s.target_difficulty,
                "rolling_accuracy": self.get_rolling_accuracy(d),
                "episodes_since_update": s.episodes_since_last_update,
                "window_full": len(s.rolling_window) == self.WINDOW_SIZE,
                "window_size": len(s.rolling_window),
                "distribution": self.get_distribution(d),
            }
            for d, s in self.state.items()
        }


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

def _domain_records(state: HonestState, domain: str) -> list[dict]:
    """Return up to the last WINDOW scored records for *domain*, oldest first.

    A record qualifies when its ``"domain"`` equals *domain* and it has a
    ``"correct"`` key (whatever its value, including ``None``).  That means
    the rich records written by the environment, not stale auxiliary records.

    History is walked newest-to-oldest and the walk stops as soon as WINDOW
    matches are collected.  This function runs on every ``update_difficulty``
    evaluation while ``episode_history`` keeps growing.  The old approach
    filtered the whole history and then sliced it, so the total cost grew
    quadratically over a run.
    """
    recent: list[dict] = []
    for record in reversed(state.episode_history):
        if record.get("domain") == domain and "correct" in record:
            recent.append(record)
            if len(recent) == WINDOW:
                break
    # Collected newest-first; restore chronological order for callers.
    recent.reverse()
    return recent


def _last_change_episode(state: HonestState, domain: str) -> int:
    """Find the newest history index at which *domain*'s difficulty changed.

    Walks ``state.episode_history`` from the end toward the start.  It returns
    the index of the first record whose ``"domain"`` equals *domain* and
    whose ``"difficulty_changed"`` value is truthy.  When no such record
    exists the result is 0 — the same value a change stamped on the very
    first record would give — so the hysteresis distance is then measured
    from the start of history.
    """
    history = state.episode_history
    return next(
        (
            idx
            for idx in reversed(range(len(history)))
            if history[idx].get("domain") == domain
            and history[idx].get("difficulty_changed")
        ),
        0,
    )


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def get_rolling_accuracy(state: HonestState, domain: str) -> float:
    """Return the fraction (0.0–1.0) of *domain*'s recent records marked correct.

    The window is the last WINDOW records for *domain* that carry a
    ``"correct"`` key.  Only ``correct is True`` counts as a hit, so abstain
    / malformed records (``correct`` is ``None``) count as misses.  With no
    records yet the neutral value 0.5 is returned.
    """
    window = _domain_records(state, domain)
    if len(window) == 0:
        # Neutral default: lies between LOW_THRESHOLD and HIGH_THRESHOLD,
        # so it never triggers a change on its own.
        return 0.5
    hits = [record for record in window if record.get("correct") is True]
    return len(hits) / len(window)


def update_difficulty(
    state: HonestState,
    last_correctness: Optional[bool],
    domain: Optional[str] = None,
) -> Tuple[int, bool]:
    """Evaluate rolling accuracy and adjust ``state.domain_difficulties``
    in-place if a threshold is crossed.

    Precondition: the caller has already appended its rich record for the
    current step (with ``"domain"`` and ``"correct"`` keys) to
    ``state.episode_history``.  That way the rolling accuracy includes the
    current result.

    Parameters
    ----------
    state:
        The current environment state.  Only
        ``state.domain_difficulties[domain]`` is written, and only when the
        difficulty actually changes.  History is read, never modified.
    last_correctness:
        Whether the most recent answer was correct (``None`` for abstain /
        malformed).  This function does not read it.  The outcome is taken
        from the history record the caller already appended, where ``None``
        counts as incorrect.  The parameter is kept for call-site
        compatibility.
    domain:
        Override for the active domain.  Defaults to ``state.current_domain``.

    Returns
    -------
    (new_difficulty, changed) where ``changed`` is True iff the difficulty
    scalar was actually modified this call.  The caller can use this to
    stamp ``"difficulty_changed": True`` onto its own history record.  The
    hysteresis guard looks for exactly that flag on later calls.
    """
    if domain is None:
        domain = state.current_domain

    current_difficulty = state.domain_difficulties.get(domain, 1)

    # Index of the newest history entry, i.e. the record the caller just
    # appended for this step (-1 if history is empty).
    global_episode_idx = len(state.episode_history) - 1

    # --- hysteresis guard ---
    # The gap is measured in global history indices, so records of every
    # domain count.  It runs from the last record for *this* domain flagged
    # ``difficulty_changed`` (index 0 if there is none).
    last_change = _last_change_episode(state, domain)
    episodes_since_change = global_episode_idx - last_change
    if episodes_since_change < HYSTERESIS_EPISODES:
        return current_difficulty, False  # too soon to change

    # --- rolling accuracy (includes the just-appended record) ---
    # Computed only after the guard passes.  It only reads history, so
    # deferring it does not change the result.
    # NOTE(review): there is no minimum sample count here.  A domain's very
    # first record alone (accuracy 0.0 or 1.0) can cross a threshold once the
    # global gap is met.  DifficultyController, by contrast, waits for a full
    # window.  Confirm this is intended.
    accuracy = get_rolling_accuracy(state, domain)

    # --- apply threshold rules (strict inequalities, clamped to range) ---
    new_difficulty = current_difficulty
    if accuracy > HIGH_THRESHOLD:
        new_difficulty = min(current_difficulty + 1, MAX_DIFFICULTY)
    elif accuracy < LOW_THRESHOLD:
        new_difficulty = max(current_difficulty - 1, MIN_DIFFICULTY)

    changed = new_difficulty != current_difficulty
    if changed:
        state.domain_difficulties[domain] = new_difficulty

    return new_difficulty, changed