Reinforcement Learning
Transformers
English
post-training
distillation
agentic-coding
composer-2.5
cursor
kimi-k2
grpo
dapo
diloco
openenv
trl
verl
research
methodology
Instructions to use Codeseys/composer-replication-framework with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Codeseys/composer-replication-framework with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Codeseys/composer-replication-framework", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 15,587 Bytes
"""Tests for the held-out collapse kill-switch (HeldOutGuard).
CPU-only, pure-Python — no torch, no cloud. Mirrors the
``datagen/tests/test_feature_deletion.py`` style (small helpers, behavioral
asserts). Covers:
- no-halt on a healthy co-rising run (the held-out-twin "within noise" case);
- HALT on the canonical signature: held-out declines while in-loop rises;
- HALT on KL-to-init hard-stop breach (and an inert KL path when KL is omitted);
- HALT on a fast proxy-real Hacking-Gap blowout, even while held-out still rises;
- window / patience behavior (min_steps warm-up; decline_patience streak);
- calibration tightens-only, plus its input guards (factor, baseline sign, empty);
- idempotent query + latched-fire edge cases;
- constructor validation and the kl_token_trust_filter helper.
"""
from __future__ import annotations
import pytest
from composer_replication.safety import (
CollapseStopError,
HeldOutGuard,
TripwireStatus,
kl_token_trust_filter,
)
def _guard(**kw) -> HeldOutGuard:
    """Build a ``HeldOutGuard`` with fast test defaults.

    A small ``min_steps`` keeps tests quick while still exercising the warm-up
    gate. Any keyword in ``kw`` overrides the matching default (or adds a new
    constructor argument).
    """
    defaults = {
        "min_steps": 3,
        "decline_patience": 3,
        "ema_alpha": 0.5,
        "kl_hard_stop": 0.08,
    }
    return HeldOutGuard(**{**defaults, **kw})
# --- healthy run: both rise => never halt -----------------------------------
def test_no_halt_when_both_rise():
    """Healthy run: proxy and held-out climb together while KL stays in band.

    The held-out twin tracks the in-loop proxy within noise, which is the
    well-behaved case the literature says a clean model exhibits. The guard
    must never fire, and the Hacking Gap must stay near zero.
    """
    g = _guard()
    for i in range(30):
        proxy = 0.30 + 0.01 * i
        real = 0.28 + 0.01 * i  # tracks proxy within noise
        status = g.update(i, in_loop_reward=proxy, heldout_score=real, kl_to_init=0.03)
        assert not status.fire, f"fired unexpectedly at step {i}: {status.reason}"
    assert not g.should_halt()
    # Both signals gained by the same amount, so their gain difference is ~0.
    assert abs(g.proxy_real_gap()) < 0.05
# --- canonical signature: held-out declines while in-loop rises -------------
def test_halt_on_heldout_declines_while_reward_rises():
    """Canonical collapse signature: held-out falls while the proxy climbs.

    The gap-blowout path is disabled with a huge ``max_proxy_real_gap`` and KL
    is kept in band. Only the decline-streak path can fire, and it must do so
    once the streak reaches ``decline_patience`` (3).
    """
    g = _guard(max_proxy_real_gap=10.0)  # disable gap-blowout path to isolate (a)
    # Stable healthy stretch to clear the min_steps warm-up.
    for i in range(6):
        assert not g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03).fire

    # Reward-hacking fingerprint: proxy up by 0.05 per checkpoint, held-out down
    # by the same amount.
    fired_at = None
    for k, i in enumerate(range(6, 12), start=1):
        delta = 0.05 * k
        s = g.update(
            i,
            in_loop_reward=0.40 + delta,  # rising
            heldout_score=0.40 - delta,  # declining
            kl_to_init=0.03,  # KL stays in band
        )
        if s.fire:
            fired_at = i
            break
    assert fired_at is not None, "tripwire never fired on the collapse signature"
    assert g.should_halt()

    last = g.last_status
    assert "held-out" in last.reason
    assert "consecutive" in last.reason
    assert last.proxy_real_gap > 0.0  # proxy gained while real lost
def test_does_not_fire_before_patience_window():
    """Fewer than ``decline_patience`` divergent checkpoints must not fire.

    With patience 3, two consecutive checkpoints where held-out drops while the
    proxy rises are still inside the window, so neither status may fire.
    """
    g = _guard(decline_patience=3, max_proxy_real_gap=10.0)
    for step in range(6):
        g.update(step, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
    # (round_idx, in-loop reward, held-out score): one checkpoint short of patience.
    divergent = [(6, 0.45, 0.35), (7, 0.50, 0.30)]
    statuses = [
        g.update(step, in_loop_reward=proxy, heldout_score=real, kl_to_init=0.03)
        for step, proxy, real in divergent
    ]
    assert not any(st.fire for st in statuses)
def test_decline_streak_resets_on_recovery():
    """A recovering held-out checkpoint resets the decline streak.

    Sequence: two declines, one recovery, then two more declines. Without a
    reset, the accumulated streak would reach ``decline_patience`` (3) and
    fire. With a reset, the post-recovery streak stays below patience.
    """
    g = _guard(decline_patience=3, max_proxy_real_gap=10.0)
    for i in range(6):
        g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
    # 2 declines (streak 2, below patience)...
    g.update(6, in_loop_reward=0.45, heldout_score=0.35, kl_to_init=0.03)
    g.update(7, in_loop_reward=0.50, heldout_score=0.30, kl_to_init=0.03)
    # ...then held-out recovers, which should reset the streak...
    s = g.update(8, in_loop_reward=0.50, heldout_score=0.45, kl_to_init=0.03)
    assert not s.fire
    # ...then two more declines. Two are used rather than one: if the guard
    # judges "decline" on its ema_alpha=0.5 smoothed held-out score (assumption,
    # not visible here), the first post-recovery drop may not register as a
    # decline at all (~0.394 -> ~0.397 still rises). In that case a single step
    # could not tell "reset" from "no reset". The second drop does register, so
    # a missing reset would carry the earlier streak to >= patience and fire.
    s = g.update(9, in_loop_reward=0.55, heldout_score=0.40, kl_to_init=0.03)
    assert not s.fire
    s = g.update(10, in_loop_reward=0.60, heldout_score=0.35, kl_to_init=0.03)
    assert not s.fire
# --- KL hard-stop ------------------------------------------------------------
def test_halt_on_kl_hard_stop_breach():
    """KL-to-init driven past ``kl_hard_stop`` trips the hard stop.

    Both reward signals are held flat and the gap path is disabled, so the KL
    signal is the only thing that can fire. A 0.20 KL against a 0.08 ceiling
    must trip within the next few updates.
    """
    g = _guard(kl_hard_stop=0.08, max_proxy_real_gap=10.0)
    flat = dict(in_loop_reward=0.40, heldout_score=0.40)
    # Healthy KL through the warm-up.
    for i in range(5):
        assert not g.update(i, kl_to_init=0.04, **flat).fire
    # KL spikes well above the ceiling; its EMA climbs and then crosses.
    fired = False
    for i in range(5, 12):
        s = g.update(i, kl_to_init=0.20, **flat)
        if not s.fire:
            continue
        fired = True
        assert "kl_to_init" in s.reason and "hard stop" in s.reason
        break
    assert fired, "KL hard-stop never fired despite KL EMA crossing the ceiling"
def test_kl_none_never_fires_kl_path():
    """Omitting ``kl_to_init`` leaves the KL path inert.

    KL is an optional float. With it always ``None`` and both scores flat, the
    final status must not fire and its ``kl_ema`` must remain ``None``.
    """
    g = _guard(max_proxy_real_gap=10.0)
    statuses = [
        g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=None)
        for i in range(20)
    ]
    last = statuses[-1]
    assert last is not None and not last.fire
    assert last.kl_ema is None
# --- proxy-real gap blowout (fast divergence) -------------------------------
def test_halt_on_proxy_real_gap_blowout():
    """A fast proxy/real divergence fires through the gap-blowout path.

    ``decline_patience=100`` disables the decline-streak path. The proxy then
    jumps while held-out stays flat, and the Hacking Gap must exceed its 0.10
    ceiling and fire with a "Hacking Gap" reason.
    """
    g = _guard(max_proxy_real_gap=0.10, decline_patience=100)  # disable (a)
    for i in range(5):
        g.update(i, in_loop_reward=0.30, heldout_score=0.30, kl_to_init=0.03)
    # Proxy blows up to 0.90 while held-out stays at 0.30; with ema_alpha=0.5
    # the gap should cross 0.10 quickly.
    hit = None
    for i in range(5, 12):
        s = g.update(i, in_loop_reward=0.90, heldout_score=0.30, kl_to_init=0.03)
        if s.fire:
            hit = s
            break
    assert hit is not None, "gap-blowout tripwire never fired"
    assert "Hacking Gap" in hit.reason
    assert hit.proxy_real_gap > 0.10
# --- warm-up window (min_steps) ---------------------------------------------
def test_respects_min_steps_no_early_fire():
    """Nothing fires during the ``min_steps`` warm-up, however bad the signals.

    Every path is primed from step 0: KL far above its ceiling, proxy climbing
    and held-out falling. The first nine updates must still stay silent. Only
    the tenth update, the first not below ``min_steps=10``, may fire.
    """
    g = _guard(
        min_steps=10,
        decline_patience=2,
        kl_hard_stop=0.08,
        max_proxy_real_gap=0.01,
    )
    ramp = [(0.10 + 0.1 * i, 0.90 - 0.1 * i) for i in range(9)]
    for i, (proxy, real) in enumerate(ramp):  # 9 updates, all < min_steps
        s = g.update(i, in_loop_reward=proxy, heldout_score=real, kl_to_init=0.9)
        assert not s.fire, f"fired during warm-up at step {i}: {s.reason}"
    # 10th update: the counter now equals min_steps, so firing is allowed.
    s = g.update(9, in_loop_reward=1.5, heldout_score=0.0, kl_to_init=0.9)
    assert s.fire
# --- calibration -------------------------------------------------------------
def test_calibrate_kl_threshold_tightens_only():
    """A low baseline tightens the KL ceiling.

    The baseline mean is 0.01; with ``factor=3.0`` that gives 0.03. This is
    below the configured 0.08, so it becomes the new ``kl_hard_stop`` and is
    also returned.
    """
    g = _guard(kl_hard_stop=0.08)
    expected = pytest.approx(0.03, abs=1e-9)
    assert g.calibrate_kl_threshold([0.008, 0.010, 0.012], factor=3.0) == expected
    assert g.kl_hard_stop == expected
def test_calibrate_never_loosens_past_band():
    """A drifting baseline cannot loosen the KL ceiling.

    The baseline mean is 0.05; with ``factor=3.0`` that gives 0.15. This is
    above the configured 0.08, so the ceiling stays at 0.08 and that value is
    returned.
    """
    g = _guard(kl_hard_stop=0.08)
    unchanged = pytest.approx(0.08, abs=1e-9)
    assert g.calibrate_kl_threshold([0.05, 0.05, 0.05], factor=3.0) == unchanged
    assert g.kl_hard_stop == unchanged
def test_calibrate_empty_raises():
    """Calibrating against an empty baseline raises ``ValueError``."""
    guard = _guard()
    empty_baseline = []
    with pytest.raises(ValueError, match="non-empty"):
        guard.calibrate_kl_threshold(empty_baseline)
# --- proxy_real_gap definition ----------------------------------------------
def test_proxy_real_gap_is_gain_difference():
    """``proxy_real_gap`` equals the in-loop EMA gain minus the held-out EMA gain.

    Firing is disabled (``min_steps=100`` plus a huge gap ceiling). The first
    update sets the baseline. With ``ema_alpha=0.5`` the second update moves
    each EMA halfway:
      in-loop : 0.5*0.20 + 0.5*0.60 = 0.40 -> gain 0.40 - 0.20 = 0.20
      held-out: 0.5*0.20 + 0.5*0.30 = 0.25 -> gain 0.25 - 0.20 = 0.05
    so the gap is 0.20 - 0.05 = 0.15.
    """
    g = _guard(min_steps=100, max_proxy_real_gap=10.0)  # disable firing
    for step, proxy, real in ((0, 0.20, 0.20), (1, 0.60, 0.30)):
        g.update(step, in_loop_reward=proxy, heldout_score=real, kl_to_init=0.02)
    assert g.proxy_real_gap() == pytest.approx(0.15, abs=1e-9)
def test_proxy_real_gap_zero_before_update():
    """A freshly built guard with no updates reports a zero Hacking Gap."""
    fresh = _guard()
    gap = fresh.proxy_real_gap()
    assert gap == 0.0
# --- idempotency / edge cases -----------------------------------------------
def test_should_halt_is_idempotent_query():
    """``should_halt`` is a pure query on a clean run.

    Repeated calls must neither advance the guard's state (the gap is
    unchanged) nor flip the verdict.
    """
    g = _guard(max_proxy_real_gap=10.0)
    for i in range(6):
        g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
    gap_before = g.proxy_real_gap()
    for _ in range(2):
        assert g.should_halt() is False
    assert g.proxy_real_gap() == gap_before  # unchanged by querying
    last = g.last_status
    assert last is not None
    assert not last.fire
def test_fire_is_latched():
    """Once fired, the verdict stays fired even after the signal recovers.

    A KL breach trips the guard. A later healthy KL reading must not silently
    un-halt the run, and that status's reason starts with ``latched:``.
    """
    g = _guard(kl_hard_stop=0.08, max_proxy_real_gap=10.0)
    for i in range(5):
        g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.04)
    # Drive a KL breach; any() stops feeding updates at the first fire.
    fired = any(
        g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.30).fire
        for i in range(5, 12)
    )
    assert fired
    # KL back to healthy: the verdict must remain latched.
    s = g.update(99, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.01)
    assert s.fire and s.reason.startswith("latched:")
    assert g.should_halt()
def test_raise_if_fired_raises_typed_exception():
    """``raise_if_fired`` raises ``CollapseStopError`` carrying the fired status.

    After driving a KL breach, the exception must be raised both when the
    status is passed explicitly and when the argument is omitted. With no
    argument the guard falls back to its ``last_status``, which is the fired
    status here.
    """
    g = _guard(kl_hard_stop=0.08, max_proxy_real_gap=10.0)
    for i in range(5):
        g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.04)
    status = None
    for i in range(5, 12):
        status = g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.30)
        if status.fire:
            break
    assert status is not None and status.fire
    with pytest.raises(CollapseStopError) as exc:
        g.raise_if_fired(status)
    assert exc.value.status is status
    # str() of an exception is always a str, so only non-emptiness is a real check.
    assert str(exc.value)
    # Omitted argument -> falls back to last_status (the fired one) -> raises too.
    with pytest.raises(CollapseStopError):
        g.raise_if_fired()
def test_raise_if_fired_noop_when_clean():
    """On a non-firing status ``raise_if_fired`` returns quietly.

    Checked both with the status passed explicitly and with the argument
    omitted (the guard then uses ``last_status``).
    """
    g = _guard(max_proxy_real_gap=10.0)
    clean = g.update(0, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
    g.raise_if_fired(clean)
    g.raise_if_fired()
def test_status_halt_alias_matches_fire():
    """``TripwireStatus.halt`` mirrors ``fire`` on both clean and fired statuses.

    Checking only a clean status would let a constant-False ``halt`` pass, so a
    KL breach is also driven and the alias is checked on the firing status.
    """
    g = _guard(max_proxy_real_gap=10.0)
    s = g.update(0, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.03)
    assert isinstance(s, TripwireStatus)
    # Written as two asserts instead of the easily-misread chained
    # ``s.halt == s.fire is False``.
    assert s.fire is False
    assert s.halt == s.fire
    # Rewards flat, gap path disabled: only the KL hard stop (0.08) can fire.
    for i in range(1, 12):
        s = g.update(i, in_loop_reward=0.40, heldout_score=0.40, kl_to_init=0.30)
        if s.fire:
            break
    assert s.fire, "KL breach never fired"
    assert s.halt == s.fire
def test_non_contiguous_round_idx_uses_internal_counter():
    """``min_steps`` gates on the internal update counter, not ``round_idx``.

    A caller that logs sparse or non-contiguous round indices must still get
    the full warm-up. KL is far above the ceiling on every update. A guard
    gating on ``round_idx`` (1000 >= 3) would therefore fire on the very first
    update, and because a fire is latched, checking only the third update
    could not detect it. Asserting that the first two stay silent is what
    distinguishes the two behaviors.
    """
    g = _guard(min_steps=3, max_proxy_real_gap=0.01, decline_patience=1)
    # Pass huge round_idx values; only the 3rd UPDATE clears warm-up.
    s1 = g.update(1000, in_loop_reward=0.10, heldout_score=0.90, kl_to_init=0.9)
    assert not s1.fire, f"fired during warm-up on update 1: {s1.reason}"
    s2 = g.update(2000, in_loop_reward=0.50, heldout_score=0.50, kl_to_init=0.9)
    assert not s2.fire, f"fired during warm-up on update 2: {s2.reason}"
    s3 = g.update(3000, in_loop_reward=0.90, heldout_score=0.10, kl_to_init=0.9)
    assert s3.fire  # 3rd update, n==3 not < min_steps
# --- config validation -------------------------------------------------------
def test_bad_ema_alpha_rejected():
    """Out-of-range ``ema_alpha`` values (1.0 and -0.1) raise ``ValueError``."""
    for bad_alpha in (1.0, -0.1):
        with pytest.raises(ValueError, match="ema_alpha"):
            HeldOutGuard(ema_alpha=bad_alpha)
def test_bad_kl_hard_stop_rejected():
    """A zero ``kl_hard_stop`` ceiling is rejected at construction time."""
    zero_ceiling = 0.0
    with pytest.raises(ValueError, match="kl_hard_stop"):
        HeldOutGuard(kl_hard_stop=zero_ceiling)
def test_bad_decline_patience_rejected():
    """A ``decline_patience`` of zero is rejected at construction time."""
    zero_patience = 0
    with pytest.raises(ValueError, match="decline_patience"):
        HeldOutGuard(decline_patience=zero_patience)
# --- kl_token_trust_filter helper -------------------------------------------
def test_kl_token_trust_filter_masks_above_threshold():
    """Values above ``threshold`` are masked; the boundary value itself is not.

    The first argument is compared against ``threshold`` directly: 0.20 is
    masked, while 0.05 and the 0.08 boundary are kept.
    NOTE(review): presumably the argument is a per-token KL estimate the caller
    computes upstream (e.g. ``0.5 * logratio**2``), not the raw log-ratio. A raw
    log-ratio of 0.20 would give 0.02 < 0.08 and not be masked. Confirm against
    ``kl_token_trust_filter``.
    """
    cases = [
        (0.20, True),  # too large -> mask
        (0.05, False),  # within trust region
        (0.08, False),  # boundary, not masked
    ]
    for value, should_mask in cases:
        assert kl_token_trust_filter(value, threshold=0.08) is should_mask
# --- R4: calibrate_kl_threshold input guards (negative factor / baseline) -----
def test_calibrate_rejects_nonpositive_factor():
    """R4: ``factor <= 0`` is rejected loudly.

    Such a factor makes the calibrated ceiling <= 0, so min(<= 0, 0.08) <= 0.
    After that, the KL tripwire would fire on every healthy step.
    """
    g = _guard()
    for bad_factor in (-3.0, 0.0):
        with pytest.raises(ValueError, match="factor must be > 0"):
            g.calibrate_kl_threshold([0.01, 0.02], factor=bad_factor)
def test_calibrate_rejects_negative_baseline_kl():
    """R4: a baseline containing any negative KL value is rejected.

    KL is non-negative by definition, so a negative baseline is nonsensical
    and could invert the ceiling.
    """
    baseline = [0.01, -0.5, 0.02]
    guard = _guard()
    with pytest.raises(ValueError, match="non-negative"):
        guard.calibrate_kl_threshold(baseline)
def test_calibrate_never_yields_nonpositive_threshold():
    """R4: an all-zero baseline (mean 0) still leaves a positive ceiling.

    This ensures a later positive KL does not fire spuriously. Both the
    returned value and the stored ``kl_hard_stop`` must be > 0.
    """
    g = _guard()
    ceiling = g.calibrate_kl_threshold([0.0] * 3, factor=3.0)
    assert ceiling > 0.0
    assert g.kl_hard_stop > 0.0
# --- R10: path-(c) gap-blowout is a divergence-RATE gate, not a real-decline --
def test_gap_blowout_fires_even_when_real_still_rising():
    """R10: path (c) fires when the proxy gain outpaces the real gain beyond the
    ceiling EVEN WHILE the held-out (real) score is still genuinely RISING. This
    is INTENTIONAL — path (c) is a divergence-RATE gate (fast single-generation
    hacking), distinct from path (a)'s real-decline streak. Locking the intended
    behavior so a future change can't silently turn it into a real-decline gate."""
    g = _guard(max_proxy_real_gap=0.1, decline_patience=99)  # isolate path (c) from (a)
    statuses = [
        g.update(
            i,
            in_loop_reward=0.30 + 0.20 * i,  # proxy sprints
            heldout_score=0.30 + 0.01 * i,  # real still rising, but slowly
            kl_to_init=0.02,
        )
        for i in range(8)
    ]
    first, status = statuses[0], statuses[-1]
    assert status.fire, "path (c) should fire on a fast proxy/real divergence"
    assert "gap" in status.reason.lower()
    # And the real score WAS rising the whole time (not a decline-driven fire):
    # the held-out EMA ends above where it started. The comparison uses the
    # first reported status instead of the private ``_fold`` helper, so the
    # test depends only on the public ``TripwireStatus.heldout_ema`` field.
    assert status.heldout_ema > first.heldout_ema
|