File size: 15,587 Bytes
7a55e1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd0c358
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
"""Tests for the held-out collapse kill-switch (HeldOutGuard).

CPU-only, pure-Python — no torch, no cloud. Mirrors the
``datagen/tests/test_feature_deletion.py`` style (small helpers, behavioral
asserts). Covers:
  - no-halt on a healthy co-rising run (the held-out-twin "within noise" case);
  - HALT on the canonical signature: held-out declines while in-loop rises;
  - HALT on KL-to-init hard-stop breach;
  - HALT on a fast proxy-real Hacking-Gap blowout;
  - window / patience behavior (min_steps warm-up; decline_patience streak);
  - calibration tightens-only;
  - idempotent query + latched-fire edge cases.
"""
from __future__ import annotations

import pytest

from composer_replication.safety import (
    CollapseStopError,
    HeldOutGuard,
    TripwireStatus,
    kl_token_trust_filter,
)


def _guard(**kw) -> HeldOutGuard:
    """Build a ``HeldOutGuard`` preconfigured with the suite's test defaults.

    Defaults are ``min_steps=3``, ``decline_patience=3``, ``ema_alpha=0.5`` and
    ``kl_hard_stop=0.08``. Every keyword in ``kw`` either replaces the matching
    default or is forwarded to the constructor unchanged.
    """
    # A short warm-up keeps each test to a handful of updates while the
    # warm-up gate is still exercised.
    settings = {
        "min_steps": 3,
        "decline_patience": 3,
        "ema_alpha": 0.5,
        "kl_hard_stop": 0.08,
        **kw,
    }
    return HeldOutGuard(**settings)


# --- healthy run: both rise => never halt -----------------------------------

def test_no_halt_when_both_rise():
    """Healthy run: proxy and held-out climb together, KL pinned at 0.03.

    The held-out score follows the in-loop reward with the same slope (offset
    of roughly 0.02), so no update may fire, ``should_halt()`` must stay false
    and the proxy-real gap must remain below 0.05 in magnitude.
    """
    guard = _guard()
    for step in range(30):
        outcome = guard.update(
            step,
            in_loop_reward=0.30 + 0.01 * step,
            heldout_score=0.28 + 0.01 * step,  # same slope as the proxy
            kl_to_init=0.03,
        )
        assert not outcome.fire, (
            f"fired unexpectedly at step {step}: {outcome.reason}"
        )
    assert not guard.should_halt()
    # Both series gained the same amount, so the gap should sit near zero.
    assert abs(guard.proxy_real_gap()) < 0.05


# --- canonical signature: held-out declines while in-loop rises -------------

def test_halt_on_heldout_declines_while_reward_rises():
    """Rising proxy reward with a falling held-out score must trip the guard.

    After six flat warm-up updates, the proxy rises and the held-out score
    falls by 0.05 per checkpoint. The guard (``decline_patience=3``) has to
    fire within the six divergent checkpoints, report a reason mentioning
    "held-out" and "consecutive", and expose a positive proxy-real gap.
    """
    # A ceiling of 10.0 keeps the gap-blowout path out of the picture so only
    # the decline-streak path can be responsible for the halt.
    guard = _guard(max_proxy_real_gap=10.0)

    # Stable, healthy stretch that carries the guard past min_steps.
    for step in range(6):
        warmup = guard.update(
            step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03
        )
        assert not warmup.fire

    # Divergence phase: the reward-hacking fingerprint, KL still in band.
    fired_at = None
    for offset, step in enumerate(range(6, 12), start=1):
        outcome = guard.update(
            step,
            in_loop_reward=0.40 + 0.05 * offset,  # rising
            heldout_score=0.40 - 0.05 * offset,   # declining
            kl_to_init=0.03,
        )
        if outcome.fire:
            fired_at = step
            break

    assert fired_at is not None, "tripwire never fired on the collapse signature"
    assert guard.should_halt()
    final = guard.last_status
    assert "held-out" in final.reason
    assert "consecutive" in final.reason
    assert final.proxy_real_gap > 0.0  # proxy gained while real lost


def test_does_not_fire_before_patience_window():
    """Fewer divergent checkpoints than ``decline_patience`` must not fire.

    Two checkpoints of held-out decline with rising proxy reward are fed to a
    guard whose patience is 3; neither resulting status may report a fire.
    """
    guard = _guard(decline_patience=3, max_proxy_real_gap=10.0)
    steady = dict(in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
    for step in range(6):
        guard.update(step, **steady)

    # (round index, proxy reward, held-out score): only two divergent steps.
    divergent = ((6, 0.45, 0.35), (7, 0.50, 0.30))
    outcomes = []
    for step, proxy, real in divergent:
        outcomes.append(
            guard.update(
                step, in_loop_reward=proxy, heldout_score=real, kl_to_init=0.03
            )
        )
    assert not any(outcome.fire for outcome in outcomes)


def test_decline_streak_resets_on_recovery():
    """A held-out recovery must clear the decline streak.

    Sequence: six flat updates, two declining checkpoints, one checkpoint where
    held-out recovers, then a single further decline. With patience 3 neither
    the recovery checkpoint nor the later lone decline may fire.
    """
    guard = _guard(decline_patience=3, max_proxy_real_gap=10.0)
    for step in range(6):
        guard.update(step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)

    # Two consecutive declines (one short of the patience threshold).
    for step, proxy, real in ((6, 0.45, 0.35), (7, 0.50, 0.30)):
        guard.update(step, in_loop_reward=proxy, heldout_score=real, kl_to_init=0.03)

    # Held-out bounces back, which is expected to reset the streak.
    recovered = guard.update(
        8, in_loop_reward=0.50, heldout_score=0.45, kl_to_init=0.03
    )
    assert not recovered.fire

    # A single new decline must count as streak 1, not as a continuation.
    relapse = guard.update(
        9, in_loop_reward=0.55, heldout_score=0.40, kl_to_init=0.03
    )
    assert not relapse.fire


# --- KL hard-stop ------------------------------------------------------------

def test_halt_on_kl_hard_stop_breach():
    """Sustained KL far above ``kl_hard_stop`` must halt with a KL reason.

    Rewards stay flat throughout, so only the KL path can fire. After five
    updates at KL 0.04, KL jumps to 0.20; the guard has to fire within the next
    seven updates with a reason naming "kl_to_init" and "hard stop".
    """
    guard = _guard(kl_hard_stop=0.08, max_proxy_real_gap=10.0)
    flat = dict(in_loop_reward=0.40, heldout_score=0.40)

    # Warm-up with KL comfortably under the ceiling.
    for step in range(5):
        calm = guard.update(step, kl_to_init=0.04, **flat)
        assert not calm.fire

    # KL spike: keep feeding it until the guard reacts (or the budget runs out).
    breach = None
    for step in range(5, 12):
        outcome = guard.update(step, kl_to_init=0.20, **flat)
        if outcome.fire:
            breach = outcome
            break

    assert breach is not None, (
        "KL hard-stop never fired despite KL EMA crossing the ceiling"
    )
    assert "kl_to_init" in breach.reason and "hard stop" in breach.reason


def test_kl_none_never_fires_kl_path():
    """Omitting KL (``kl_to_init=None``) must leave the KL path inert.

    Twenty flat updates without a KL value may not fire, and the final status
    must still report ``kl_ema`` as ``None``.
    """
    guard = _guard(max_proxy_real_gap=10.0)
    last = None
    for step in range(20):
        last = guard.update(
            step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=None
        )
    assert last is not None
    assert not last.fire
    assert last.kl_ema is None


# --- proxy-real gap blowout (fast divergence) -------------------------------

def test_halt_on_proxy_real_gap_blowout():
    """A sudden proxy jump with flat held-out must fire via the gap path.

    ``decline_patience=100`` takes the decline-streak path out of play. After a
    flat warm-up the proxy reward jumps from 0.30 to 0.90 while held-out stays
    at 0.30; the guard must fire within seven updates, cite the "Hacking Gap"
    and report a gap above the 0.10 ceiling.
    """
    guard = _guard(max_proxy_real_gap=0.10, decline_patience=100)
    for step in range(5):
        guard.update(step, in_loop_reward=0.30, heldout_score=0.30, kl_to_init=0.03)

    # Proxy blows up; held-out does not move.
    blowout = None
    for step in range(5, 12):
        outcome = guard.update(
            step, in_loop_reward=0.90, heldout_score=0.30, kl_to_init=0.03
        )
        if outcome.fire:
            blowout = outcome
            break

    assert blowout is not None, "gap-blowout tripwire never fired"
    assert "Hacking Gap" in blowout.reason
    assert blowout.proxy_real_gap > 0.10


# --- warm-up window (min_steps) ---------------------------------------------

def test_respects_min_steps_no_early_fire():
    """No fire is allowed during warm-up, however bad the signals look.

    With ``min_steps=10`` the first nine updates carry huge KL, a rising proxy
    and a falling held-out score, yet none may fire. The tenth update is the
    first one permitted to fire, and it must.
    """
    guard = _guard(
        min_steps=10,
        decline_patience=2,
        kl_hard_stop=0.08,
        max_proxy_real_gap=0.01,
    )

    # Nine updates, all still inside the warm-up window.
    for step in range(9):
        outcome = guard.update(
            step,
            in_loop_reward=0.10 + 0.1 * step,
            heldout_score=0.90 - 0.1 * step,
            kl_to_init=0.9,
        )
        assert not outcome.fire, (
            f"fired during warm-up at step {step}: {outcome.reason}"
        )

    # Tenth update: the warm-up gate no longer applies.
    tenth = guard.update(9, in_loop_reward=1.5, heldout_score=0.0, kl_to_init=0.9)
    assert tenth.fire


# --- calibration -------------------------------------------------------------

def test_calibrate_kl_threshold_tightens_only():
    """Calibration lowers the ceiling when ``factor * mean`` is below it.

    Baseline mean 0.01 times factor 3.0 gives 0.03, which is under the
    configured 0.08, so both the return value and ``kl_hard_stop`` become 0.03.
    """
    guard = _guard(kl_hard_stop=0.08)
    baseline = [0.008, 0.010, 0.012]
    tightened = guard.calibrate_kl_threshold(baseline, factor=3.0)
    expected = pytest.approx(0.03, abs=1e-9)
    assert tightened == expected
    assert guard.kl_hard_stop == expected


def test_calibrate_never_loosens_past_band():
    """Calibration must not raise the ceiling above its configured value.

    A drifting baseline (mean 0.05, times 3.0 = 0.15) exceeds the configured
    0.08, so the returned value and ``kl_hard_stop`` must both remain 0.08.
    """
    guard = _guard(kl_hard_stop=0.08)
    drifting = [0.05, 0.05, 0.05]
    result = guard.calibrate_kl_threshold(drifting, factor=3.0)
    unchanged = pytest.approx(0.08, abs=1e-9)
    assert result == unchanged
    assert guard.kl_hard_stop == unchanged


def test_calibrate_empty_raises():
    """An empty baseline must raise ``ValueError`` mentioning "non-empty"."""
    guard = _guard()
    no_samples = []
    with pytest.raises(ValueError, match="non-empty"):
        guard.calibrate_kl_threshold(no_samples)


# --- proxy_real_gap definition ----------------------------------------------

def test_proxy_real_gap_is_gain_difference():
    """The gap equals proxy EMA gain minus held-out EMA gain.

    Expected arithmetic with ``ema_alpha=0.5`` (each EMA moves halfway towards
    the second sample):

    * in-loop EMA: 0.5 * 0.20 + 0.5 * 0.60 = 0.40, gain 0.20
    * held-out EMA: 0.5 * 0.20 + 0.5 * 0.30 = 0.25, gain 0.05
    * gap: 0.20 - 0.05 = 0.15
    """
    # min_steps=100 and a ceiling of 10.0 keep the guard from firing here.
    guard = _guard(min_steps=100, max_proxy_real_gap=10.0)
    samples = (
        (0, 0.20, 0.20),  # baseline sample
        (1, 0.60, 0.30),  # second sample, folded into the EMAs
    )
    for step, proxy, real in samples:
        guard.update(step, in_loop_reward=proxy, heldout_score=real, kl_to_init=0.02)
    assert guard.proxy_real_gap() == pytest.approx(0.15, abs=1e-9)


def test_proxy_real_gap_zero_before_update():
    """A guard that has seen no updates reports a gap of exactly 0.0."""
    fresh = _guard()
    gap = fresh.proxy_real_gap()
    assert gap == 0.0


# --- idempotency / edge cases -----------------------------------------------

def test_should_halt_is_idempotent_query():
    """Repeated ``should_halt()`` calls must not change state or verdict.

    After six healthy updates the query is issued twice; it must return
    ``False`` both times, leave ``proxy_real_gap()`` unchanged, and the last
    status must still be a non-firing one.
    """
    guard = _guard(max_proxy_real_gap=10.0)
    for step in range(6):
        guard.update(step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)

    gap_before = guard.proxy_real_gap()
    for _ in range(2):
        assert guard.should_halt() is False
    assert guard.proxy_real_gap() == gap_before  # querying changed nothing

    latest = guard.last_status
    assert latest is not None
    assert not latest.fire


def test_fire_is_latched():
    """After a fire, a later healthy update must not clear the halt.

    A KL breach is provoked, then a healthy update (KL 0.01) is fed. The
    resulting status must still fire with a reason starting with "latched:",
    and ``should_halt()`` must remain true.
    """
    guard = _guard(kl_hard_stop=0.08, max_proxy_real_gap=10.0)
    flat = dict(in_loop_reward=0.40, heldout_score=0.40)
    for step in range(5):
        guard.update(step, kl_to_init=0.04, **flat)

    # Push KL well past the ceiling until the guard trips.
    tripped = False
    for step in range(5, 12):
        if guard.update(step, kl_to_init=0.30, **flat).fire:
            tripped = True
            break
    assert tripped

    # KL drops back to a healthy level; the verdict has to stay fired.
    after = guard.update(99, kl_to_init=0.01, **flat)
    assert after.fire
    assert after.reason.startswith("latched:")
    assert guard.should_halt()


def test_raise_if_fired_raises_typed_exception():
    """``raise_if_fired`` raises ``CollapseStopError`` carrying the status.

    A KL breach produces a firing status; passing it to ``raise_if_fired`` must
    raise ``CollapseStopError`` whose ``status`` attribute is that very object
    and whose string form is non-empty.
    """
    guard = _guard(kl_hard_stop=0.08, max_proxy_real_gap=10.0)
    flat = dict(in_loop_reward=0.40, heldout_score=0.40)
    for step in range(5):
        guard.update(step, kl_to_init=0.04, **flat)

    fired_status = None
    for step in range(5, 12):
        fired_status = guard.update(step, kl_to_init=0.30, **flat)
        if fired_status.fire:
            break
    assert fired_status is not None and fired_status.fire

    with pytest.raises(CollapseStopError) as caught:
        guard.raise_if_fired(fired_status)
    error = caught.value
    assert error.status is fired_status
    assert isinstance(str(error), str) and str(error)


def test_raise_if_fired_noop_when_clean():
    """``raise_if_fired`` must not raise when the status is not firing.

    Checked both with the status passed explicitly and with the argument
    omitted (the guard then falls back to its last status).
    """
    guard = _guard(max_proxy_real_gap=10.0)
    clean = guard.update(
        0, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03
    )
    guard.raise_if_fired(clean)  # explicit status
    guard.raise_if_fired()       # implicit: last status


def test_status_halt_alias_matches_fire():
    """``halt`` mirrors ``fire`` on a clean status, and both are ``False``.

    Also pins the return type of ``update`` to ``TripwireStatus``.
    """
    guard = _guard(max_proxy_real_gap=10.0)
    status = guard.update(
        0, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03
    )
    # Written out explicitly rather than as a chained comparison.
    fire_flag = status.fire
    assert status.halt == fire_flag
    assert fire_flag is False
    assert isinstance(status, TripwireStatus)


def test_non_contiguous_round_idx_uses_internal_counter():
    """Warm-up is counted in updates, not in the round indices supplied.

    Three updates with sparse round indices (1000, 2000, 3000) are fed to a
    guard with ``min_steps=3``; the third update must be allowed to fire even
    though the indices are far from contiguous.
    """
    guard = _guard(min_steps=3, max_proxy_real_gap=0.01, decline_patience=1)
    # (round index, proxy reward, held-out score)
    sparse_rounds = (
        (1000, 0.10, 0.90),
        (2000, 0.50, 0.50),
        (3000, 0.90, 0.10),
    )
    outcome = None
    for round_idx, proxy, real in sparse_rounds:
        outcome = guard.update(
            round_idx, in_loop_reward=proxy, heldout_score=real, kl_to_init=0.9
        )
    assert outcome.fire  # third update clears the warm-up gate


# --- config validation -------------------------------------------------------

def test_bad_ema_alpha_rejected():
    """``ema_alpha`` values of 1.0 and -0.1 are rejected with ``ValueError``."""
    for bad_alpha in (1.0, -0.1):
        with pytest.raises(ValueError, match="ema_alpha"):
            HeldOutGuard(ema_alpha=bad_alpha)


def test_bad_kl_hard_stop_rejected():
    """A ``kl_hard_stop`` of 0.0 is rejected with ``ValueError``."""
    zero_ceiling = 0.0
    with pytest.raises(ValueError, match="kl_hard_stop"):
        HeldOutGuard(kl_hard_stop=zero_ceiling)


def test_bad_decline_patience_rejected():
    """A ``decline_patience`` of 0 is rejected with ``ValueError``."""
    zero_patience = 0
    with pytest.raises(ValueError, match="decline_patience"):
        HeldOutGuard(decline_patience=zero_patience)


# --- kl_token_trust_filter helper -------------------------------------------

def test_kl_token_trust_filter_masks_above_threshold():
    """Values strictly above the threshold are masked; others are kept.

    With ``threshold=0.08``: 0.20 is masked (``True``), 0.05 is not, and the
    boundary value 0.08 itself is not.
    """
    # NOTE(review): the first argument is presumably an already-computed
    # per-token KL estimate (e.g. 0.5 * logratio^2) rather than a raw log-ratio
    # -- confirm against kl_token_trust_filter.
    cases = (
        (0.20, True),   # too large -> mask
        (0.05, False),  # within trust region
        (0.08, False),  # boundary, not masked
    )
    for value, masked in cases:
        assert kl_token_trust_filter(value, threshold=0.08) is masked


# --- R4: calibrate_kl_threshold input guards (negative factor / baseline) -----

def test_calibrate_rejects_nonpositive_factor():
    """R4: ``factor <= 0`` must be rejected loudly.

    A non-positive factor would yield a non-positive calibrated ceiling, after
    which the KL tripwire would fire on every healthy step. Both a negative
    factor and zero must raise ``ValueError`` ("factor must be > 0").
    """
    guard = _guard()
    baseline = [0.01, 0.02]
    for bad_factor in (-3.0, 0.0):
        with pytest.raises(ValueError, match="factor must be > 0"):
            guard.calibrate_kl_threshold(baseline, factor=bad_factor)


def test_calibrate_rejects_negative_baseline_kl():
    """R4: a baseline containing a negative KL value must be rejected.

    KL is non-negative by definition, so a negative sample is nonsensical and
    could invert the ceiling; ``ValueError`` ("non-negative") is expected.
    """
    guard = _guard()
    tainted_baseline = [0.01, -0.5, 0.02]
    with pytest.raises(ValueError, match="non-negative"):
        guard.calibrate_kl_threshold(tainted_baseline)


def test_calibrate_never_yields_nonpositive_threshold():
    """R4: an all-zero baseline must still leave a strictly positive ceiling.

    Otherwise any later positive KL would trip the guard spuriously. Both the
    returned value and ``kl_hard_stop`` are checked.
    """
    guard = _guard()
    all_zero = [0.0, 0.0, 0.0]
    calibrated = guard.calibrate_kl_threshold(all_zero, factor=3.0)
    assert calibrated > 0.0
    assert guard.kl_hard_stop > 0.0


# --- R10: path-(c) gap-blowout is a divergence-RATE gate, not a real-decline --

def test_gap_blowout_fires_even_when_real_still_rising():
    """R10: the gap path fires on divergence rate, not on a real decline.

    Path (c) must fire when the proxy gain outpaces the real gain beyond the
    ceiling EVEN WHILE the held-out (real) score is still rising. This is
    intentional: path (c) guards against fast single-generation hacking and is
    distinct from path (a)'s real-decline streak. The test pins that behavior
    so it cannot silently become a real-decline gate.
    """
    # decline_patience=99 keeps path (a) from interfering with path (c).
    guard = _guard(max_proxy_real_gap=0.1, decline_patience=99)
    final = None
    for step in range(8):
        final = guard.update(
            step,
            in_loop_reward=0.30 + 0.20 * step,  # proxy sprints
            heldout_score=0.30 + 0.01 * step,   # real rises, but slowly
            kl_to_init=0.02,
        )
    assert final.fire, "path (c) should fire on a fast proxy/real divergence"
    assert "gap" in final.reason.lower()

    # The held-out EMA must sit above its starting point, i.e. the fire was not
    # driven by a decline. Presumably ``_fold(None, x)`` yields the EMA seed for
    # a first sample ``x`` -- confirm against HeldOutGuard.
    final_ema = final.heldout_ema
    seed_ema = guard._fold(None, 0.30)  # type: ignore[attr-defined]
    assert final_ema > seed_ema