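# Author: Baladithya Balamurugan
# Commit bd0c358 - Wave 3: close the HIGH review findings (kill-switch wiring, HeldoutSplit, EKS entrypoint bug)
# (Repository web-view residue, not part of the module: "Raw", "History / Blame / Contribute / Delete", file size 15.3 kB.)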
"""holdout.py — held-out / train set-disjointness enforcer (the #2 safeguard,
second half).
``kill_switch.py`` (the ``HeldOutGuard`` run-level collapse tripwire) is only
sound if the held-out eval it watches is *genuinely disjoint* from the tasks the
generator trains on. If a single held-out task leaks back into the train /
generator pool, the "real" eval drifts WITH the train set and the proxy-real
Hacking-Gap signal becomes meaningless (see the Shumailov / Gao collapse
references in ``kill_switch.py``: the held-out eval must stay anchored to REAL
tasks that are NEVER fed back to the generator). This module enforces that
discipline mechanically rather than leaving it to convention.
``HeldoutSplit`` enforces disjointness two ways, both pure-Python:
- **id-based** — the train/generator ``task_id`` set and the held-out
``task_id`` set must not intersect. This is the cheap, exact check.
- **content-hash-based** (optional, ``check_content=True``) — a sha256 over a
*normalized* view of each task's content. This catches NEAR-DUPLICATES that
slipped through with DIFFERENT ids: the same broken repo + same
``fail_to_pass`` targets re-minted under a fresh ``task_id`` would pass the
id check but is, for collapse purposes, the same eval task leaking into
train. The EvilGenie failure-mode literature (arXiv 2511.21654, cited in
``kill_switch.py``) is explicit that "holdout tests have many surprising
failure modes" — silent re-id'd duplicates are one of them.
The ``split(all_tasks, holdout_frac, seed)`` constructor produces a
GUARANTEED-disjoint (train, holdout) partition deterministically: a fixed seed
yields the same partition every run, so the held-out anchor is reproducible
across the long self-evolving run.
Pure-Python: only ``hashlib`` / ``random`` from the stdlib. No torch, no cloud
deps. Accepts either raw ``task_id`` strings OR ``FeatureDeletionTask`` objects
(anything with a ``task_id`` attribute) on every entry point.
"""
from __future__ import annotations
import hashlib
import random
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field
from typing import Any
class HeldoutOverlapError(ValueError):
    """Raised when the train/generator pool and the held-out eval pool overlap.

    Carries the offending identifiers so the caller can log exactly which tasks
    leaked across the boundary (mirroring how ``datagen/monitor.py`` surfaces the
    specific suspected hacks rather than a bare boolean).

    The exception round-trips through ``copy`` / ``pickle`` (e.g. when it is
    raised in a worker process): it is rebuilt from its two identifier tuples,
    not from the rendered message.

    Attributes:
        overlapping_ids: task ids present in BOTH pools (id-based leak), stored
            as a tuple in the order supplied by the caller.
        overlapping_hashes: content hashes present in both pools with
            *different* ids (content-based near-duplicate leak), stored as a
            tuple in the order supplied; empty unless content-hashing was
            enabled by the caller.
    """

    def __init__(
        self,
        overlapping_ids: Sequence[str] = (),
        overlapping_hashes: Sequence[str] = (),
    ) -> None:
        """Build the error and its human-readable message.

        Args:
            overlapping_ids: ids found on both sides of the boundary. A bare
                string is treated as a single id.
            overlapping_hashes: content hashes found on both sides under
                different ids. A bare string is treated as a single hash.
        """
        # A bare str is itself a Sequence[str]; without this guard tuple("abc")
        # would explode one identifier into ('a', 'b', 'c').
        if isinstance(overlapping_ids, str):
            overlapping_ids = (overlapping_ids,)
        if isinstance(overlapping_hashes, str):
            overlapping_hashes = (overlapping_hashes,)
        self.overlapping_ids = tuple(overlapping_ids)
        self.overlapping_hashes = tuple(overlapping_hashes)

        parts: list[str] = []
        if self.overlapping_ids:
            parts.append(
                f"{len(self.overlapping_ids)} task id(s) appear in BOTH the "
                f"train/generator pool and the held-out eval pool: "
                f"{list(self.overlapping_ids)}"
            )
        if self.overlapping_hashes:
            parts.append(
                f"{len(self.overlapping_hashes)} content hash(es) collide across "
                f"the boundary with DIFFERENT ids (re-id'd near-duplicates): "
                f"{list(self.overlapping_hashes)}"
            )
        if not parts:  # defensive — should not be raised with nothing overlapping
            parts.append("train/held-out overlap detected (no identifiers captured)")
        super().__init__(
            "held-out eval is NOT disjoint from the train/generator pool — "
            "this corrupts the proxy-real collapse signal. " + "; ".join(parts)
        )

    def __reduce__(self) -> tuple:
        """Rebuild from the identifier tuples when copied or pickled.

        The inherited ``BaseException.__reduce__`` re-calls the class with
        ``self.args`` — here the single rendered message — which ``__init__``
        would misread as ``overlapping_ids`` and render into a garbage message.
        """
        return (type(self), (self.overlapping_ids, self.overlapping_hashes))
def task_id_of(task: Any) -> str:
    """Return the id string for ``task``.

    ``task`` may be the id itself (a ``str``) or any object exposing a string
    ``task_id`` attribute (e.g. a ``FeatureDeletionTask``-like object).

    Raises:
        TypeError: if ``task`` is not a string and has no string ``task_id``.
    """
    # A string is its own id; anything else must carry one as an attribute.
    candidate = task if isinstance(task, str) else getattr(task, "task_id", None)
    if not isinstance(candidate, str):
        raise TypeError(
            f"expected a task_id str or an object with a str .task_id attribute, "
            f"got {type(task).__name__!r}"
        )
    return candidate
def content_hash(task: Any) -> str:
    """Return a sha256 hex digest of a task's normalized, id-independent content.

    ``task_id`` is deliberately left out of the hashed material, so two tasks
    that differ only in their id produce the same digest — exactly the
    near-duplicate leak that ``check_content=True`` is meant to expose.

    What gets hashed is whatever ``_content_fields(task)`` returns, in order:

    - for ``FeatureDeletionTask``-like objects: repo, base_commit,
      broken_image, test_command, the SORTED fail_to_pass / pass_to_pass
      entries, granularity and the sorted deleted_symbols — not task_id, and
      not advisory fields such as difficulty_prior or upstream_license;
    - for a bare string: the whitespace-collapsed, lower-cased text (a plain id
      string is its own content);
    - collection-valued fields are sorted, so reordering the same tests leaves
      the digest unchanged.

    For plain id strings the id check in ``HeldoutSplit.validate`` already
    reports a shared string, so the string path mostly serves callers that
    have no structured tasks.
    """
    hasher = hashlib.sha256()
    # Fields are joined with the ASCII unit separator so that field boundaries
    # stay unambiguous in the hashed byte stream.
    hasher.update("\x1f".join(_content_fields(task)).encode("utf-8"))
    return hasher.hexdigest()
def _normalize_text(text: str) -> str:
"""Collapse runs of whitespace and lower-case, so cosmetic reformatting of a
command / repo string does not defeat content-hash matching."""
return " ".join(text.split()).lower()
def _content_fields(task: Any) -> list[str]:
    """Return the ordered, normalized strings that define a task's content.

    The task id is never part of the result. Three input shapes are handled:

    - a ``str``: a one-element list holding the normalized text;
    - an object with a ``task_id`` attribute: eight entries, in fixed order
      (repo, base_commit, broken_image, test_command, fail_to_pass,
      pass_to_pass, granularity, deleted_symbols); a missing or ``None``
      attribute contributes an empty string;
    - anything else: a one-element list holding the normalized ``repr``
      (best-effort fallback).
    """
    if isinstance(task, str):
        return [_normalize_text(task)]
    if not hasattr(task, "task_id"):
        # Not a string and not task-like: hash its repr as a last resort.
        return [_normalize_text(repr(task))]

    def scalar_field(name: str) -> str:
        # Single-valued attribute; absent / None becomes the empty string.
        raw = getattr(task, name, None)
        if raw is None:
            return ""
        return _normalize_text(str(raw))

    def collection_field(name: str) -> str:
        # Multi-valued attribute: normalize each entry, then sort so that the
        # order in which entries were listed does not affect the result.
        members = getattr(task, name, None) or ()
        cleaned = [_normalize_text(str(member)) for member in members]
        cleaned.sort()
        return "\x1e".join(cleaned)

    # (attribute name, renderer) pairs; the order here IS the hash layout.
    layout = (
        ("repo", scalar_field),
        ("base_commit", scalar_field),
        ("broken_image", scalar_field),
        ("test_command", scalar_field),
        ("fail_to_pass", collection_field),
        ("pass_to_pass", collection_field),
        ("granularity", scalar_field),
        ("deleted_symbols", collection_field),
    )
    return [render(name) for name, render in layout]
@dataclass(frozen=True)
class HeldoutSplit:
    """A (train/generator, held-out eval) partition with a disjointness contract.

    Construct directly from two iterables of task ids (or
    ``FeatureDeletionTask`` objects)::

        split = HeldoutSplit(train_tasks, holdout_tasks)
        split.assert_disjoint()  # raises HeldoutOverlapError on a leak
        if split.is_disjoint: ...

    or deterministically partition one pool::

        split = HeldoutSplit.split(all_tasks, holdout_frac=0.2, seed=1234)

    Set ``check_content=True`` to also reject re-id'd near-duplicates (same
    normalized content under a different ``task_id``). A content collision with
    the SAME id is just the id leak and is reported via ``overlapping_ids``; a
    collision with DIFFERENT ids is the near-duplicate leak reported via
    ``overlapping_content_hashes``.

    The instance is frozen; the id/hash sets are computed once at construction.
    Instances are hashable: the hash covers ``train_ids``, ``holdout_ids`` and
    ``check_content``, while equality additionally compares the content-hash
    indexes.
    """

    train_ids: frozenset[str]
    holdout_ids: frozenset[str]
    check_content: bool = False
    # content hash -> set of ids, per pool (only populated when check_content).
    # hash=False: these are dicts (unhashable); leaving them in the generated
    # __hash__ would make hash(instance) raise TypeError on a frozen dataclass.
    _train_hashes: dict[str, frozenset[str]] = field(
        default_factory=dict, repr=False, hash=False
    )
    _holdout_hashes: dict[str, frozenset[str]] = field(
        default_factory=dict, repr=False, hash=False
    )

    # ------------------------------------------------------------------------
    # construction
    # ------------------------------------------------------------------------
    def __init__(
        self,
        train: Iterable[Any],
        holdout: Iterable[Any],
        *,
        check_content: bool = False,
    ) -> None:
        """Index both pools by id (and by content hash when requested).

        Args:
            train: the train/generator pool (ids or task objects).
            holdout: the held-out eval pool (ids or task objects).
            check_content: also build the content-hash indexes used by
                ``overlapping_content_hashes``.

        Raises:
            TypeError: (from ``task_id_of``) if an element is neither a string
                nor an object with a string ``task_id``.
        """
        # Materialize once: the inputs may be one-shot iterators and are read
        # twice when content hashing is enabled.
        train_list = list(train)
        holdout_list = list(holdout)
        # object.__setattr__ bypasses the frozen dataclass's __setattr__ guard.
        object.__setattr__(self, "train_ids", frozenset(map(task_id_of, train_list)))
        object.__setattr__(self, "holdout_ids", frozenset(map(task_id_of, holdout_list)))
        object.__setattr__(self, "check_content", bool(check_content))
        if check_content:
            object.__setattr__(self, "_train_hashes", _hash_index(train_list))
            object.__setattr__(self, "_holdout_hashes", _hash_index(holdout_list))
        else:
            object.__setattr__(self, "_train_hashes", {})
            object.__setattr__(self, "_holdout_hashes", {})

    # ------------------------------------------------------------------------
    # deterministic constructor
    # ------------------------------------------------------------------------
    @classmethod
    def split(
        cls,
        all_tasks: Iterable[Any],
        holdout_frac: float = 0.2,
        seed: int = 0,
        *,
        check_content: bool = False,
    ) -> HeldoutSplit:
        """Deterministically partition ``all_tasks`` into an id-disjoint (train,
        held-out) split.

        The partition depends only on the SET of task ids and on ``seed``: tasks
        are de-duplicated by id (first occurrence wins, so a duplicate id cannot
        land on both sides), put into a canonical id-sorted order, shuffled with
        a SEEDED ``random.Random`` and sliced. The order in which the caller
        happens to supply the pool (e.g. a ``set`` or a directory listing) does
        therefore not move the held-out anchor.

        NOTE(review): the result is reproducible for a given interpreter; the
        stdlib does not promise that ``shuffle`` yields the same permutation for
        a given seed across Python versions — persist the split if that matters.

        Args:
            all_tasks: the full pool (ids or ``FeatureDeletionTask`` objects).
            holdout_frac: fraction routed to the held-out pool, in [0, 1]. The
                held-out size is ``round(n * holdout_frac)`` over the ``n``
                unique ids; when ``0 < holdout_frac < 1`` and ``n >= 2`` it is
                clamped so that each side keeps at least one task.
            seed: PRNG seed for the deterministic shuffle.
            check_content: build the content-hash indexes on the result too.

        Returns:
            A ``HeldoutSplit`` whose pools share no task id by construction.
            With ``check_content=True`` the split is NOT guaranteed to be
            content-disjoint: de-duplication is by id only, so re-id'd
            near-duplicates in ``all_tasks`` can still straddle the boundary —
            call ``validate()`` on the result to detect that.

        Raises:
            ValueError: if ``holdout_frac`` is outside [0, 1] (including NaN).
            TypeError: (from ``task_id_of``) for an element without a usable id.
        """
        if not (0.0 <= holdout_frac <= 1.0):
            raise ValueError(
                f"holdout_frac must be in [0, 1], got {holdout_frac!r}"
            )
        # De-dup by id, keeping the first-seen original object so that
        # content-hashing (if enabled) sees the structured task.
        seen: set[str] = set()
        unique: list[tuple[str, Any]] = []
        for t in all_tasks:
            tid = task_id_of(t)
            if tid not in seen:
                seen.add(tid)
                unique.append((tid, t))
        n = len(unique)
        n_holdout = round(n * holdout_frac)
        # Clamp so a meaningful frac never collapses one side to empty.
        if 0.0 < holdout_frac < 1.0 and n >= 2:
            n_holdout = min(max(n_holdout, 1), n - 1)
        # Canonical order first: ids are unique, so sorting on the id alone is a
        # total order and never falls back to comparing task objects. Works on
        # our own list, so the caller's input is not mutated.
        unique.sort(key=lambda pair: pair[0])
        order = [t for _, t in unique]
        random.Random(seed).shuffle(order)
        holdout = order[:n_holdout]
        train = order[n_holdout:]
        return cls(train, holdout, check_content=check_content)

    # ------------------------------------------------------------------------
    # disjointness checks
    # ------------------------------------------------------------------------
    def overlapping_ids(self) -> tuple[str, ...]:
        """Sorted task ids present in BOTH pools (the id-based leak set)."""
        return tuple(sorted(self.train_ids & self.holdout_ids))

    def overlapping_content_hashes(self) -> tuple[str, ...]:
        """Sorted content hashes that collide across pools with DIFFERENT ids.

        Empty when ``check_content`` is False. A hash present in both pools whose
        only shared ids are already plain id-overlaps is not reported here (that
        leak surfaces via ``overlapping_ids``); only collisions that involve at
        least one id that is not an id-overlap on each side count, so the two
        checks do not double-report the same leak.
        """
        if not self.check_content:
            return ()
        id_overlap = self.train_ids & self.holdout_ids
        bad: list[str] = []
        for h, train_ids in self._train_hashes.items():
            holdout_ids = self._holdout_hashes.get(h)
            if holdout_ids is None:
                continue
            # Same content on both sides via at least one id that is NOT itself a
            # plain id-overlap => a re-id'd near-duplicate leak.
            if (holdout_ids - id_overlap) and (train_ids - id_overlap):
                bad.append(h)
        return tuple(sorted(bad))

    @property
    def is_disjoint(self) -> bool:
        """True iff the pools share no task id (and, when ``check_content``, no
        cross-id near-duplicate content)."""
        if self.train_ids & self.holdout_ids:
            return False
        if self.check_content and self.overlapping_content_hashes():
            return False
        return True

    def validate(self) -> HeldoutSplit:
        """Assert disjointness; return ``self`` so the call can be chained.

        Raises:
            HeldoutOverlapError: listing the overlapping ids (and, when
                ``check_content``, the near-duplicate content hashes).
        """
        id_overlap = self.overlapping_ids()
        hash_overlap = self.overlapping_content_hashes()
        if id_overlap or hash_overlap:
            raise HeldoutOverlapError(id_overlap, hash_overlap)
        return self

    # Documented alias: the task spec names both `validate()` and
    # `assert_disjoint()` — expose both so either calling convention works.
    def assert_disjoint(self) -> HeldoutSplit:
        """Alias for ``validate()`` — raise ``HeldoutOverlapError`` on any leak."""
        return self.validate()
def _hash_index(tasks: Iterable[Any]) -> dict[str, frozenset[str]]:
    """Group task ids by content hash.

    Returns a dict whose keys are ``content_hash(task)`` digests and whose
    values are frozensets of the ``task_id_of(task)`` ids that produced each
    digest.
    """
    buckets: dict[str, set[str]] = {}
    for task in tasks:
        digest = content_hash(task)
        owner = task_id_of(task)
        if digest in buckets:
            buckets[digest].add(owner)
        else:
            buckets[digest] = {owner}
    # Freeze the per-hash id sets before handing the index out.
    return {digest: frozenset(owners) for digest, owners in buckets.items()}