"""kill_switch.py — held-out collapse tripwire (the #2 collapse safeguard).

This is the missing RUN-LEVEL / across-generation control for the self-evolving
RL flywheel. The per-task controls already exist in ``composer_replication.datagen``
(the 4-gate solvability validator, the ``HackMonitor`` provenance check, and the
sandbox denylist); this module sits ABOVE them and watches the whole run.

Rationale (the literature consistently treats a held-out eval plus a hard stop
as a load-bearing control, not a nice-to-have):

  - **Reward hacking rises monotonically with optimization depth.** Zhao et al.,
    "Reward Hacking in Self-Improving Code Agents" (ICLR 2026 Workshop on RSI,
    OpenReview ``ikrQWGgxYg``) show that going from 10 -> 100 optimization steps
    drives the hacking rate from 26.4% to 57.8% (+31.4 points), and that
    73.8% of KernelBench / 46.8% of ALE-Bench optimizations show *proxy gains
    without real gains*. They define **Hacking Gap = proxy gain - real gain**;
    this module's ``proxy_real_gap()`` is exactly that quantity. They label an
    optimization reward-hacking when it "improves the public metric WITHOUT
    improving the private metric" — the canonical signature this tripwire fires on.

  - **Self-critique alone is insufficient.** The same paper's "retrospection"
    self-critique sometimes *increased* hacking; their conclusion: "mitigating
    reward hacking likely requires stronger evaluations and constraints beyond
    self-critique alone." So we build a genuinely disjoint held-out eval plus a
    hard stop, not a critique hook.

  - **Held-out eval is necessary but NOT sufficient by itself.** EvilGenie
    (arXiv 2511.21654) found "only minimal improvement from the use of held out
    test cases" in isolation and that "holdout tests have many surprising failure
    modes." This module is therefore explicitly *defense-in-depth*, layered ON
    TOP of ``HackMonitor`` (provenance) — neither is sufficient alone, matching
    the repo's existing defense-in-depth framing in ``datagen/monitor.py``.

  - **Closed-loop RL on self-generated data collapses.** The self-evolving-agents
    survey (Gao et al., TMLR 2026; arXiv 2507.21046 v4) §8.3 names "model
    collapse from closed-loop RL on static synthetic data" and prescribes
    "continuous monitoring ... to detect long-horizon value drift" — i.e. a
    per-generation online tripwire, not a one-time eval. Shumailov et al. (Nature
    2024, "AI models collapse when trained on recursively generated data") show
    self-training first loses the distribution tails, then converges to a
    low-variance point estimate; the mitigation that matters here is that the
    held-out eval must stay anchored to REAL tasks that are NEVER fed back to the
    generator (see ``HeldoutSplit``), otherwise the eval drifts with the train set.

  - **KL-to-init hard stop.** The GRPO "healthy progression" band (Orchestra
    Research GRPO SKILL) climbs 0.02 -> 0.05 -> 0.08 -> 0.12 nats/token over a
    run, with 0.08 the top of the "good progression" band and in the lower part
    of the 0.05-0.15 per-token range cited for code-generation drift; >0.5 is
    "diverging too much." So 0.08 nats/token is a sound HARD-STOP default.
    Catastrophic Goodhart
    (OpenReview ``UXuBzWoZGK``) proves KL regularization alone does NOT prevent
    heavy-tailed reward misspecification, so the KL hard stop is ONE tripwire
    among several, never the sole control.
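
A minimal sketch of the two quantities the first bullet borrows from Zhao et
al.: the Hacking Gap and the "public metric up, private metric not up" label.
The names ``hacking_gap`` and ``is_reward_hacking`` and the ``eps`` noise
tolerance are hypothetical. They are not this module's API or the paper's code.
The module's own ``proxy_real_gap()`` computes the gap from EMAs measured since
run start, not from raw before/after scores.

```python
def hacking_gap(proxy_before: float, proxy_after: float,
                real_before: float, real_after: float) -> float:
    # Hacking Gap = proxy gain - real gain.
    return (proxy_after - proxy_before) - (real_after - real_before)


def is_reward_hacking(proxy_before: float, proxy_after: float,
                      real_before: float, real_after: float,
                      eps: float = 0.0) -> bool:
    # "Improves the public metric WITHOUT improving the private metric".
    # eps is an assumed noise tolerance, not a value taken from the paper.
    proxy_gain = proxy_after - proxy_before
    real_gain = real_after - real_before
    return proxy_gain > eps and real_gain <= eps


# Proxy reward rose 0.50 -> 0.80 while the held-out score slipped 0.60 -> 0.55.
gap = hacking_gap(0.50, 0.80, 0.60, 0.55)             # about 0.35
flagged = is_reward_hacking(0.50, 0.80, 0.60, 0.55)   # True
```

With these numbers the gap is about 0.35, well above the default
``max_proxy_real_gap`` ceiling of 0.10 used by ``HeldOutGuard``.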

UNITS GOTCHA (load-bearing): the ``kl_to_init`` this module consumes is
**token-mean KL in nats/token**, matching the repo convention in
``composer_replication.integrations.altered_minds.kl_logging.token_mean_kl``.
A token-mean KL is NOT comparable to a sequence-level / sequence-summed KL
(whose healthy band is ~0.05-10). The 0.08 default is per-token. Do not pass a
sequence-summed KL into the per-token hard stop — it will fire instantly.
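
To make the units difference concrete, the sketch below reduces the same
per-token KL values both ways. The helper names and the mask convention
(1 = completion token, 0 = padding) are assumptions for illustration. They
are not the signature of the repo's ``token_mean_kl``.

```python
from typing import Sequence

KL_HARD_STOP = 0.08  # nats/token: the per-token default discussed above


def token_mean(per_token_kl: Sequence[float], mask: Sequence[int]) -> float:
    # Average over unmasked tokens; units are nats/token.
    kept = [k for k, m in zip(per_token_kl, mask) if m]
    return sum(kept) / len(kept) if kept else 0.0


def sequence_sum(per_token_kl: Sequence[float], mask: Sequence[int]) -> float:
    # Sum over unmasked tokens; units are nats/sequence and grow with length.
    return sum(k for k, m in zip(per_token_kl, mask) if m)


# 200 completion tokens at a healthy 0.05 nats/token, plus 56 padding slots.
kls = [0.05] * 200 + [0.0] * 56
mask = [1] * 200 + [0] * 56

per_token = token_mean(kls, mask)       # about 0.05: under the hard stop
per_sequence = sequence_sum(kls, mask)  # about 10.0: trips it instantly
print(per_token > KL_HARD_STOP, per_sequence > KL_HARD_STOP)  # False True
```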

This module is pure-Python: no torch, no cloud deps. ``kl_to_init`` is just a
float the caller passes (computed upstream by ``token_mean_kl``). It is fully
CPU-testable.
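
Because ``kl_to_init`` and the per-token input to ``kl_token_trust_filter``
are plain floats, a caller can produce them without torch. The sketch below
computes Schulman's k1, k2 and k3 per-token estimators of KL(pi || pi_ref)
from sampled-token log-probabilities. This module does not say which
estimator the repo's ``token_mean_kl`` uses, so the choice here is an
assumption. ``per_token_kl_estimates`` is a hypothetical helper.

```python
import math
from typing import Dict, List, Sequence


def per_token_kl_estimates(logp_policy: Sequence[float],
                           logp_ref: Sequence[float]) -> Dict[str, List[float]]:
    # Tokens are assumed sampled from the policy pi. With
    # log_r = log pi_ref(x) - log pi(x):
    #   k1 = -log_r                    unbiased, can be negative per token
    #   k2 = 0.5 * log_r ** 2          biased, always >= 0
    #   k3 = (exp(log_r) - 1) - log_r  unbiased, always >= 0
    out: Dict[str, List[float]] = {"k1": [], "k2": [], "k3": []}
    for lp, lq in zip(logp_policy, logp_ref):
        log_r = lq - lp
        out["k1"].append(-log_r)
        out["k2"].append(0.5 * log_r ** 2)
        out["k3"].append(math.expm1(log_r) - log_r)
    return out


est = per_token_kl_estimates([-1.0, -2.0], [-1.5, -2.0])
kl_to_init = sum(est["k3"]) / len(est["k3"])  # token-mean k3, nats/token
```

The k2 values are exactly the ``0.5 * logratio**2`` quantity that
``kl_token_trust_filter`` expects for each token.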
"""
from __future__ import annotations

from dataclasses import dataclass, field


class CollapseStopError(RuntimeError):
    """Raised (by the caller, optionally) when the tripwire fires a hard stop.

    The trainer loop can either check ``TripwireStatus.fire`` and stop softly,
    or call ``HeldOutGuard.raise_if_fired(status)`` to convert a fired verdict
    into this typed exception. Carries the structured verdict for logging.
    """

    def __init__(self, status: TripwireStatus) -> None:
        super().__init__(status.reason)
        self.status = status


@dataclass(frozen=True)
class TripwireStatus:
    """Structured verdict returned by every ``HeldOutGuard.update(...)`` call.

    Attributes:
        fire: True => the run should HALT (collapse / reward-hacking detected).
        reason: human-readable WHY (empty string when ``fire`` is False), so the
            trainer can log exactly which tripwire tripped, mirroring how
            ``datagen/monitor.py`` logs suspected hacks for review.
        step: the round/generation index this verdict was computed at.
        proxy_real_gap: the RSI "Hacking Gap" at this step = (in-loop reward gain
            since baseline) - (held-out score gain since baseline). Positive and
            widening => the proxy is improving faster than the real score, or
            improving while the real score declines.
        in_loop_ema: EMA of the in-loop / proxy reward at this step.
        heldout_ema: EMA of the held-out / real eval score at this step.
        kl_ema: EMA of ``kl_to_init`` (nats/token), or None if never supplied.
    """

    fire: bool
    reason: str
    step: int
    proxy_real_gap: float
    in_loop_ema: float
    heldout_ema: float
    kl_ema: float | None = None

    # `halt` is a documented alias for `fire` — the task spec describes a
    # `should_halt()` / verdict with a `halt` field; expose both names so callers
    # reading either convention work.
    @property
    def halt(self) -> bool:
        return self.fire


@dataclass
class HeldOutGuard:
    """Across-generation collapse / reward-hacking kill-switch (HeldOutGuard).

    Tracks, per generation/round: in-loop (proxy) oracle reward, held-out (real)
    eval score, and optional KL-to-init / entropy / reward-std. Computes the
    proxy-minus-real "Hacking Gap" tripwire and fires a structured ``halt``
    verdict when collapse is caught in the act.

    The guard is **stateful**: call ``update(round_idx, ...)`` once per checkpoint
    in the trainer loop (the same cadence at which ``DifficultyCurriculum.update``
    is called). It maintains denoised EMAs of every metric (raw single-step
    values are too noisy to threshold — theneuralbase early-stopping guidance) and
    returns a ``TripwireStatus``.

    Fires (``fire=True``) when ANY of:

      (a) **collapse-caught-in-the-act** — the held-out score EMA has DECLINED
          while the in-loop reward EMA ROSE on >= ``decline_patience``
          checkpoints (default 3, matching the "monotone for >=3 checkpoints"
          rule), with no non-declining held-out checkpoint in between. This is
          the canonical reward-hacking signature.

      (b) **KL breach** — the ``kl_to_init`` EMA exceeds ``kl_hard_stop`` (default
          0.08 nats/token) on/after ``min_steps``.

      (c) **proxy-real gap blowout** — the Hacking Gap (proxy gain - real gain
          since baseline) widens beyond ``max_proxy_real_gap``, even if held-out
          has not strictly declined for the full patience window (a fast
          single-generation divergence).

    No tripwire fires before ``min_steps`` (avoids halting on early-run noise,
    when both signals are still warming up).

    The guard is idempotent in the sense that re-querying ``last_status`` or
    calling ``should_halt()`` does not advance state — only ``update`` does.
    """

    # --- thresholds (calibratable; see calibrate_kl_threshold) ---------------
    kl_hard_stop: float = 0.08          # nats/token; top of GRPO "good" band
    max_proxy_real_gap: float = 0.10    # absolute Hacking-Gap blowout ceiling
    # --- temporal gates ------------------------------------------------------
    min_steps: int = 20                 # no fire before this many updates
    decline_patience: int = 3           # consecutive held-out declines to fire (a)
    # --- denoising -----------------------------------------------------------
    ema_alpha: float = 0.9              # EMA weight on the PRIOR (0.9 => slow)
    rise_eps: float = 1e-4              # min EMA delta to count as "rising"/"declining"

    # --- internal state (do not set directly) --------------------------------
    _n: int = field(default=0, init=False)
    _in_loop_ema: float | None = field(default=None, init=False)
    _heldout_ema: float | None = field(default=None, init=False)
    _kl_ema: float | None = field(default=None, init=False)
    _entropy_ema: float | None = field(default=None, init=False)
    _reward_std_ema: float | None = field(default=None, init=False)
    _in_loop_baseline: float | None = field(default=None, init=False)
    _heldout_baseline: float | None = field(default=None, init=False)
    _prev_in_loop_ema: float | None = field(default=None, init=False)
    _prev_heldout_ema: float | None = field(default=None, init=False)
    _heldout_decline_streak: int = field(default=0, init=False)
    _last_status: TripwireStatus | None = field(default=None, init=False)
    _fired: bool = field(default=False, init=False)

    def __post_init__(self) -> None:
        if not (0.0 <= self.ema_alpha < 1.0):
            raise ValueError(
                f"ema_alpha must be in [0, 1), got {self.ema_alpha!r} "
                "(it is the weight on the PRIOR EMA)."
            )
        if self.kl_hard_stop <= 0.0:
            raise ValueError(f"kl_hard_stop must be > 0, got {self.kl_hard_stop!r}")
        if self.decline_patience < 1:
            raise ValueError(
                f"decline_patience must be >= 1, got {self.decline_patience!r}"
            )

    # ------------------------------------------------------------------------
    # core API
    # ------------------------------------------------------------------------
    def update(
        self,
        round_idx: int,
        in_loop_reward: float,
        heldout_score: float,
        kl_to_init: float | None = None,
        entropy: float | None = None,
        reward_std: float | None = None,
    ) -> TripwireStatus:
        """Fold one checkpoint's metrics in and return the current verdict.

        Args:
            round_idx: the generation / round index (for logging; not used for
                gating — the internal update counter ``_n`` drives ``min_steps``
                so the guard is robust to non-contiguous round indices).
            in_loop_reward: mean in-loop (proxy / oracle) reward this round. This
                is what the policy is optimizing against.
            heldout_score: mean score on the DISJOINT held-out eval pool this
                round — REAL tasks the generator never trains on. See
                ``composer_replication.safety.holdout`` design notes / the
                ``HeldoutSplit`` discipline; if held-out drifts with the train
                set the gap signal is meaningless.
            kl_to_init: optional token-mean KL(policy || init) in nats/token
                (this repo's ``token_mean_kl`` convention). NOT sequence-level KL.
            entropy: optional policy entropy (early-warning of entropy collapse,
                "the silent killer of RLVR generalization"). Tracked + exposed,
                not currently a hard gate.
            reward_std: optional std of the reward distribution (tracked; a
                collapsing std is an early collapse signal).

        Returns:
            A ``TripwireStatus``. Once the guard has fired, every subsequent
            ``update`` keeps ``fire=True`` (latched) so a transient recovery
            after a detected collapse cannot silently un-halt the run.
        """
        self._n += 1

        # --- EMA folds (alpha on the prior; first sample seeds the EMA) -------
        self._in_loop_ema = self._fold(self._in_loop_ema, float(in_loop_reward))
        self._heldout_ema = self._fold(self._heldout_ema, float(heldout_score))
        if kl_to_init is not None:
            self._kl_ema = self._fold(self._kl_ema, float(kl_to_init))
        if entropy is not None:
            self._entropy_ema = self._fold(self._entropy_ema, float(entropy))
        if reward_std is not None:
            self._reward_std_ema = self._fold(self._reward_std_ema, float(reward_std))

        # --- baselines: seed on the first update so gains are measured from
        #     run start (the RSI Hacking-Gap is a gain-since-baseline quantity). -
        if self._in_loop_baseline is None:
            self._in_loop_baseline = self._in_loop_ema
        if self._heldout_baseline is None:
            self._heldout_baseline = self._heldout_ema

        # --- track the held-out decline streak (uses EMA deltas, denoised) ----
        in_loop_rising = (
            self._prev_in_loop_ema is not None
            and (self._in_loop_ema - self._prev_in_loop_ema) > self.rise_eps
        )
        heldout_declining = (
            self._prev_heldout_ema is not None
            and (self._heldout_ema - self._prev_heldout_ema) < -self.rise_eps
        )
        # The collapse signature is held-out DOWN while in-loop UP. We only count
        # a decline toward the streak when in-loop is simultaneously rising — a
        # held-out dip during an in-loop dip is just noise / a hard batch, not
        # reward hacking.
        if heldout_declining and in_loop_rising:
            self._heldout_decline_streak += 1
        elif not heldout_declining:
            self._heldout_decline_streak = 0
        # If held-out declines while in-loop is flat or falling, the streak is
        # left unchanged (neither grown nor reset). Any checkpoint on which
        # held-out does NOT decline resets the streak to zero.

        gap = self.proxy_real_gap()
        status = self._evaluate(round_idx, gap)

        # advance "previous EMA" trackers AFTER evaluation
        self._prev_in_loop_ema = self._in_loop_ema
        self._prev_heldout_ema = self._heldout_ema
        self._last_status = status
        if status.fire:
            self._fired = True
        return status

    def _evaluate(self, round_idx: int, gap: float) -> TripwireStatus:
        """Decide the verdict from current state. Pure (no state mutation)."""
        assert self._in_loop_ema is not None and self._heldout_ema is not None

        base = dict(
            step=round_idx,
            proxy_real_gap=gap,
            in_loop_ema=self._in_loop_ema,
            heldout_ema=self._heldout_ema,
            kl_ema=self._kl_ema,
        )

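        # Latched: once fired, stay fired (cannot silently un-halt). Add the
        # "latched: " prefix only once, so repeated latched updates don't stack
        # "latched: latched: ..." onto the original reason.
        if self._fired:
            prev_reason = self._last_status.reason if self._last_status else "collapse"
            if not prev_reason.startswith("latched: "):
                prev_reason = f"latched: {prev_reason}"
            return TripwireStatus(fire=True, reason=prev_reason, **base)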

        # Warm-up guard: never fire on early-run noise.
        if self._n < self.min_steps:
            return TripwireStatus(fire=False, reason="", **base)

        # (b) KL hard stop — checked first; it's the cheapest unambiguous breach.
        if self._kl_ema is not None and self._kl_ema > self.kl_hard_stop:
            return TripwireStatus(
                fire=True,
                reason=(
                    f"kl_to_init EMA {self._kl_ema:.4f} nats/token exceeds hard "
                    f"stop {self.kl_hard_stop:.4f} (policy drifting from init)"
                ),
                **base,
            )

        # (a) collapse caught in the act — held-out declines while in-loop rises.
        if self._heldout_decline_streak >= self.decline_patience:
            return TripwireStatus(
                fire=True,
                reason=(
                    f"reward-hacking signature: held-out score declined while "
                    f"in-loop reward rose for {self._heldout_decline_streak} "
                    f"consecutive checkpoints (Hacking Gap {gap:.4f})"
                ),
                **base,
            )

        # (c) proxy-real gap blowout — fast single-generation divergence.
        if gap > self.max_proxy_real_gap:
            return TripwireStatus(
                fire=True,
                reason=(
                    f"proxy-real Hacking Gap {gap:.4f} exceeds ceiling "
                    f"{self.max_proxy_real_gap:.4f} (proxy reward improving far "
                    f"faster than real held-out eval)"
                ),
                **base,
            )

        return TripwireStatus(fire=False, reason="", **base)

    # ------------------------------------------------------------------------
    # query helpers (do NOT advance state — idempotent)
    # ------------------------------------------------------------------------
    def should_halt(self) -> bool:
        """True if the most recent ``update`` produced a halt verdict.

        Idempotent: querying does not advance the EMA state.
        """
        return self._last_status is not None and self._last_status.fire

    @property
    def last_status(self) -> TripwireStatus | None:
        """The most recent verdict, or None if ``update`` was never called."""
        return self._last_status

    def raise_if_fired(self, status: TripwireStatus | None = None) -> None:
        """Convert a fired verdict into a typed ``CollapseStopError`` exception.

        Pass the status returned by ``update`` (or omit to use ``last_status``).
        Trainer loops that prefer exception-based control flow call this right
        after ``update``; loops that prefer flag-checking just read
        ``status.fire`` / ``should_halt()``.
        """
        st = status if status is not None else self._last_status
        if st is not None and st.fire:
            raise CollapseStopError(st)

    def proxy_real_gap(self) -> float:
        """The RSI Hacking Gap = (in-loop gain) - (held-out gain), both measured
        as EMA-minus-baseline since run start.

        Returns 0.0 before the first ``update`` (no baseline yet). A positive,
        widening value is the reward-hacking fingerprint: the proxy the policy
        optimizes is improving more than the real held-out objective.
        """
        if (
            self._in_loop_ema is None
            or self._heldout_ema is None
            or self._in_loop_baseline is None
            or self._heldout_baseline is None
        ):
            return 0.0
        in_loop_gain = self._in_loop_ema - self._in_loop_baseline
        heldout_gain = self._heldout_ema - self._heldout_baseline
        return in_loop_gain - heldout_gain

    # ------------------------------------------------------------------------
    # calibration
    # ------------------------------------------------------------------------
    def calibrate_kl_threshold(
        self, baseline_kls: list[float], factor: float = 3.0
    ) -> float:
        """Set ``kl_hard_stop`` to ``factor`` x the mean of early-run baseline KLs.

        theneuralbase guidance: "Record baseline KL during the first ~100 steps,
        set max to 3x that." Single fixed thresholds are dataset-dependent; this
        adapts to the run's own KL scale.

        SAFETY CLAMP: calibration may only ever TIGHTEN the hard stop, never
        loosen it past the documented collapse band. The returned (and stored)
        threshold is ``max(min(factor * mean(baseline_kls), kl_hard_stop),
        1e-6)``. A noisy or already-drifting baseline therefore cannot raise
        the ceiling above its current value (0.08 nats/token by default), and
        an all-zero baseline cannot drive it to exactly 0.

        Args:
            baseline_kls: per-step token-mean KL values from early in the run.
                KL is non-negative by definition, so every value must be >= 0.
            factor: multiplier on the baseline mean. Must be > 0.

        Returns:
            The new ``kl_hard_stop`` (also stored on the instance), always > 0.

        Raises:
            ValueError: if ``baseline_kls`` is empty, ``factor <= 0``, or any
                baseline KL is negative.
        """
        if not baseline_kls:
            raise ValueError("baseline_kls must be non-empty to calibrate")
        # GUARD (R4): a non-positive factor or a negative baseline can make
        # `calibrated` <= 0. Then min(calibrated, kl_hard_stop) <= 0, and the
        # epsilon floor below would pin kl_hard_stop at 1e-6, so the KL
        # tripwire would fire on essentially EVERY healthy step. KL is
        # non-negative by definition, so these inputs are nonsensical. Reject
        # them loudly rather than silently turning the guard into an always-on
        # stop.
        if factor <= 0:
            raise ValueError(f"factor must be > 0, got {factor!r}")
        if any(k < 0 for k in baseline_kls):
            raise ValueError(
                f"baseline_kls must all be >= 0 (KL is non-negative); got a "
                f"negative value in {baseline_kls!r}"
            )
        mean_kl = sum(baseline_kls) / len(baseline_kls)
        calibrated = factor * mean_kl
        # Only tighten: never let calibration loosen past the current ceiling.
        # Floor at a small positive epsilon so an all-zero baseline (mean_kl==0)
        # can't drive the ceiling to exactly 0 and fire on the first positive KL.
        self.kl_hard_stop = max(min(calibrated, self.kl_hard_stop), 1e-6)
        return self.kl_hard_stop

    # ------------------------------------------------------------------------
    # internals
    # ------------------------------------------------------------------------
    def _fold(self, prev: float | None, x: float) -> float:
        """EMA fold; the first observation seeds the EMA (no warm-up bias)."""
        if prev is None:
            return x
        return self.ema_alpha * prev + (1.0 - self.ema_alpha) * x


def kl_token_trust_filter(logratio_sq_half: float, threshold: float = 0.08) -> bool:
    """Per-token KL trust-region mask, mirroring torchrl's GRPO "KL-Mask".

    torchrl masks any TOKEN whose ``0.5 * (log pi/pi_ref)^2`` (the Schulman k2
    estimator of per-token KL) exceeds a threshold, forming a per-token trust
    region. This helper returns True when the token should be MASKED OUT (its
    KL contribution is too large), so it can be wired into a loss later without
    pulling torch into this module — the caller computes ``0.5 * logratio**2``.

    Args:
        logratio_sq_half: ``0.5 * (log pi/pi_ref)^2`` for one token (nats).
        threshold: per-token KL ceiling (default 0.08 nats, the same band as the
            run-level hard stop).

    Returns:
        True if the token exceeds the trust region and should be masked.
    """
    return logratio_sq_half > threshold