"""holdout.py — held-out / train set-disjointness enforcer (the #2 safeguard, second half).

``kill_switch.py`` (the ``HeldOutGuard`` run-level collapse tripwire) is only
sound if the held-out eval it watches is *genuinely disjoint* from the tasks the
generator trains on. If a single held-out task leaks back into the train /
generator pool, the "real" eval drifts WITH the train set and the proxy-real
Hacking-Gap signal becomes meaningless (see the Shumailov / Gao collapse
references in ``kill_switch.py``: the held-out eval must stay anchored to REAL
tasks that are NEVER fed back to the generator). This module enforces that
discipline mechanically rather than leaving it to convention.

``HeldoutSplit`` enforces disjointness two ways, both pure-Python:

- **id-based** — the train/generator ``task_id`` set and the held-out
  ``task_id`` set must not intersect. This is the cheap, exact check.
- **content-hash-based** (optional, ``check_content=True``) — a sha256 over a
  *normalized* view of each task's content. This catches NEAR-DUPLICATES that
  slipped through with DIFFERENT ids: the same broken repo + same
  ``fail_to_pass`` targets re-minted under a fresh ``task_id`` would pass the
  id check but is, for collapse purposes, the same eval task leaking into
  train. The EvilGenie failure-mode literature (arXiv 2511.21654, cited in
  ``kill_switch.py``) is explicit that "holdout tests have many surprising
  failure modes" — silent re-id'd duplicates are one of them.

The ``split(all_tasks, holdout_frac, seed)`` constructor produces a
GUARANTEED-disjoint (train, holdout) partition deterministically: the same set
of task ids and the same seed yield the same partition on every run, regardless
of the order in which the tasks are supplied, so the held-out anchor is
reproducible across the long self-evolving run.

Pure-Python, standard library only (``hashlib`` / ``random`` plus typing and
dataclass helpers). No torch, no cloud deps. Accepts either raw ``task_id``
strings OR ``FeatureDeletionTask`` objects (anything with a ``task_id``
attribute) on every entry point.
"""

from __future__ import annotations

import hashlib
import random
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field
from typing import Any


class HeldoutOverlapError(ValueError):
    """Raised when the train/generator pool and the held-out eval pool overlap.

    Carries the offending identifiers so the caller can log exactly which tasks
    leaked across the boundary (mirroring how ``datagen/monitor.py`` surfaces
    the specific suspected hacks rather than a bare boolean).

    Attributes:
        overlapping_ids: task ids present in BOTH pools (id-based leak);
            sorted when raised by ``HeldoutSplit.validate``.
        overlapping_hashes: content hashes present in both pools with
            *different* ids (content-based near-duplicate leak); empty unless
            content-hashing was enabled.
    """

    def __init__(
        self,
        overlapping_ids: Sequence[str] = (),
        overlapping_hashes: Sequence[str] = (),
    ) -> None:
        self.overlapping_ids = tuple(overlapping_ids)
        self.overlapping_hashes = tuple(overlapping_hashes)
        parts: list[str] = []
        if self.overlapping_ids:
            parts.append(
                f"{len(self.overlapping_ids)} task id(s) appear in BOTH the "
                f"train/generator pool and the held-out eval pool: "
                f"{list(self.overlapping_ids)}"
            )
        if self.overlapping_hashes:
            parts.append(
                f"{len(self.overlapping_hashes)} content hash(es) collide across "
                f"the boundary with DIFFERENT ids (re-id'd near-duplicates): "
                f"{list(self.overlapping_hashes)}"
            )
        if not parts:  # defensive — should not be raised with nothing overlapping
            parts.append("train/held-out overlap detected (no identifiers captured)")
        super().__init__(
            "held-out eval is NOT disjoint from the train/generator pool — "
            "this corrupts the proxy-real collapse signal. " + "; ".join(parts)
        )

    def __reduce__(self) -> tuple[Any, ...]:
        """Pickle via the identifier tuples rather than the formatted message.

        The default ``BaseException`` pickling replays ``self.args`` (here the
        single formatted message) into ``__init__``. That would treat the
        message string as ``overlapping_ids`` and rebuild a garbled message on
        unpickle, e.g. when the error crosses a multiprocessing boundary.
        ``self.__dict__`` is passed as state so extra attributes (such as
        ``__notes__``) also survive.
        """
        return (
            type(self),
            (self.overlapping_ids, self.overlapping_hashes),
            self.__dict__,
        )


def task_id_of(task: Any) -> str:
    """Coerce a task to its id string.

    ``task`` is either a ``task_id`` string or a ``FeatureDeletionTask``-like
    object with a ``.task_id`` attribute.

    Raises:
        TypeError: if ``task`` is neither a string nor has a str ``task_id``.
    """
    if isinstance(task, str):
        return task
    tid = getattr(task, "task_id", None)
    if isinstance(tid, str):
        return tid
    raise TypeError(
        f"expected a task_id str or an object with a str .task_id attribute, "
        f"got {type(task).__name__!r}"
    )


def content_hash(task: Any) -> str:
    """Return a sha256 over a NORMALIZED view of a task's content (id-independent).

    The hash deliberately EXCLUDES ``task_id``, so two tasks that are identical
    apart from their id collide. That collision is exactly the near-duplicate
    leak that ``check_content=True`` is meant to catch.

    Normalization (so cosmetic differences do not defeat the check):

    - for ``FeatureDeletionTask``-like objects, hash the load-bearing content
      fields: repo, base_commit, broken_image, test_command, the SORTED
      fail_to_pass / pass_to_pass test sets, granularity, and sorted
      deleted_symbols. NOT task_id, and NOT volatile/advisory fields like
      difficulty_prior or upstream_license;
    - for a bare string, hash the whitespace-collapsed, lower-cased text (a
      plain id string is its own content);
    - test-set tuples are sorted, so reordering the same tests does not change
      the hash.

    A plain ``task_id`` string therefore hashes to a stable, content-derived
    value. Passing the same strings to both pools collides on id FIRST (the id
    check fires before the content check), so the string path is mainly a
    graceful fallback for callers without structured tasks.
    """
    fields = _content_fields(task)
    blob = "\x1f".join(fields)  # unit-separator join: unambiguous field boundary
    return hashlib.sha256(blob.encode("utf-8")).hexdigest()


def _normalize_text(text: str) -> str:
    """Collapse runs of whitespace and lower-case.

    This stops cosmetic reformatting of a command / repo string from defeating
    content-hash matching. ``str.split()`` also treats the ASCII separator
    characters ``\\x1c``-``\\x1f`` as whitespace, so normalized text never
    contains the ``\\x1e`` / ``\\x1f`` joiners used by ``_content_fields`` and
    ``content_hash``.
    """
    return " ".join(text.split()).lower()


def _content_fields(task: Any) -> list[str]:
    """Return the ordered, normalized content fields for hashing (id excluded)."""
    if isinstance(task, str):
        return [_normalize_text(task)]

    # FeatureDeletionTask-like: pull the content-defining fields if present.
    def norm(attr: str) -> str:
        val = getattr(task, attr, None)
        return _normalize_text(str(val)) if val is not None else ""

    def norm_set(attr: str) -> str:
        # Sorted so test-order does not change the hash; each test normalized.
        vals = getattr(task, attr, None) or ()
        return "\x1e".join(sorted(_normalize_text(str(v)) for v in vals))

    if hasattr(task, "task_id"):
        return [
            norm("repo"),
            norm("base_commit"),
            norm("broken_image"),
            norm("test_command"),
            norm_set("fail_to_pass"),
            norm_set("pass_to_pass"),
            norm("granularity"),
            norm_set("deleted_symbols"),
        ]
    # Last resort: a non-string, non-task object — hash its repr (best-effort).
    # NOTE: a default ``object.__repr__`` embeds the object's address, so such
    # hashes are not stable across processes.
    return [_normalize_text(repr(task))]


@dataclass(frozen=True)
class HeldoutSplit:
    """A (train/generator, held-out eval) partition with a disjointness contract.

    Construct directly from two iterables of task ids (or
    ``FeatureDeletionTask`` objects)::

        split = HeldoutSplit(train_tasks, holdout_tasks)
        split.assert_disjoint()          # raises HeldoutOverlapError on a leak
        if split.is_disjoint: ...

    or deterministically partition one pool::

        split = HeldoutSplit.split(all_tasks, holdout_frac=0.2, seed=1234)

    Set ``check_content=True`` to also reject re-id'd near-duplicates (same
    normalized content under a different ``task_id``). Content-hashing is a
    superset check:

    - a content collision with the SAME id is just the id leak and is reported
      via ``overlapping_ids``;
    - a collision with DIFFERENT ids is the near-duplicate leak and is reported
      via ``overlapping_content_hashes``.

    The instance is frozen and hashable; the id/hash sets are computed once at
    construction. The explicit ``__init__`` below takes precedence over the
    dataclass-generated one, because dataclasses never replace an ``__init__``
    defined in the class body.
    """

    train_ids: frozenset[str]
    holdout_ids: frozenset[str]
    check_content: bool = False
    # content hash -> set of ids, per pool (only populated when check_content).
    # hash=False: dicts are unhashable, so including them would make the frozen
    # dataclass's generated __hash__ raise TypeError. They still take part in
    # __eq__, and hashing a subset of the compared fields keeps hash/eq
    # consistent.
    _train_hashes: dict[str, frozenset[str]] = field(
        default_factory=dict, repr=False, hash=False
    )
    _holdout_hashes: dict[str, frozenset[str]] = field(
        default_factory=dict, repr=False, hash=False
    )

    # ------------------------------------------------------------------------
    # construction
    # ------------------------------------------------------------------------
    def __init__(
        self,
        train: Iterable[Any],
        holdout: Iterable[Any],
        *,
        check_content: bool = False,
    ) -> None:
        # Materialize once: the inputs may be one-shot iterators and are read
        # twice when content-hashing is enabled.
        train_list = list(train)
        holdout_list = list(holdout)
        # Frozen dataclass: bypass the generated __setattr__ guard.
        object.__setattr__(self, "train_ids", frozenset(map(task_id_of, train_list)))
        object.__setattr__(self, "holdout_ids", frozenset(map(task_id_of, holdout_list)))
        object.__setattr__(self, "check_content", bool(check_content))
        if check_content:
            object.__setattr__(self, "_train_hashes", _hash_index(train_list))
            object.__setattr__(self, "_holdout_hashes", _hash_index(holdout_list))
        else:
            object.__setattr__(self, "_train_hashes", {})
            object.__setattr__(self, "_holdout_hashes", {})

    # ------------------------------------------------------------------------
    # deterministic constructor
    # ------------------------------------------------------------------------
    @classmethod
    def split(
        cls,
        all_tasks: Iterable[Any],
        holdout_frac: float = 0.2,
        seed: int = 0,
        *,
        check_content: bool = False,
    ) -> HeldoutSplit:
        """Deterministically partition ``all_tasks`` into a disjoint (train, held-out) split.

        The partition is keyed on the SET of task ids, so it is reproducible
        across runs: the same ids plus the same ``seed`` give the same split,
        whatever order the tasks arrive in. For example, passing a ``set`` of
        ids whose iteration order varies with hash randomization still gives
        the same split. Tasks are de-duplicated by id first (the first-seen
        object per id is kept, so a duplicate id cannot land on both sides).
        The unique tasks are then ordered by id, shuffled with a SEEDED
        ``random.Random``, and sliced. This guarantees a disjoint result by
        construction.

        Args:
            all_tasks: the full pool (ids or ``FeatureDeletionTask`` objects).
            holdout_frac: fraction routed to the held-out pool, in [0, 1]. The
                held-out size is ``round(n * holdout_frac)`` over the ``n``
                unique ids. When ``n >= 2`` and ``0 < holdout_frac < 1`` it is
                clamped so that at least one task lands on EACH side. A
                one-task pool cannot be split; it follows ``round`` alone.
            seed: PRNG seed for the deterministic shuffle.
            check_content: enable content-hash disjointness on the result too.

        Returns:
            A ``HeldoutSplit`` whose ``is_disjoint`` is True by id construction.

        Raises:
            ValueError: if ``holdout_frac`` is outside [0, 1] (or NaN).
            TypeError: if an element is neither an id string nor task-like.
        """
        if not (0.0 <= holdout_frac <= 1.0):
            raise ValueError(
                f"holdout_frac must be in [0, 1], got {holdout_frac!r}"
            )
        # De-dup by id, keeping the first-seen object, so content-hashing (if
        # enabled) still sees the structured task.
        first_seen: dict[str, Any] = {}
        for task in all_tasks:
            first_seen.setdefault(task_id_of(task), task)
        # Order by id BEFORE shuffling so the partition depends only on the id
        # set and the seed, never on caller-supplied order. This builds a new
        # list, so the caller's input is never mutated.
        order = [first_seen[tid] for tid in sorted(first_seen)]
        n = len(order)
        n_holdout = round(n * holdout_frac)
        # Clamp so a meaningful frac never collapses one side to empty.
        if 0.0 < holdout_frac < 1.0 and n >= 2:
            n_holdout = min(max(n_holdout, 1), n - 1)
        random.Random(seed).shuffle(order)
        holdout = order[:n_holdout]
        train = order[n_holdout:]
        return cls(train, holdout, check_content=check_content)

    # ------------------------------------------------------------------------
    # disjointness checks
    # ------------------------------------------------------------------------
    def overlapping_ids(self) -> tuple[str, ...]:
        """Return the sorted task ids present in BOTH pools (the id-based leak set)."""
        return tuple(sorted(self.train_ids & self.holdout_ids))

    def overlapping_content_hashes(self) -> tuple[str, ...]:
        """Return the sorted content hashes shared across pools under DIFFERENT ids.

        Empty when ``check_content`` is False. A hash shared by both pools is
        reported when some train id and some held-out id carrying that content
        differ. That holds exactly when the two id sets together contain at
        least two distinct ids. A hash whose ONLY id is the same on both sides
        is the plain id leak, already reported by ``overlapping_ids``, so it is
        not double-reported here.

        A re-id'd duplicate is still reported when it sits next to an id leak.
        Example: train has ``a`` and ``c``, the held-out pool has ``a``, and
        all three share content. Then ``c`` duplicates held-out ``a``, and
        removing ``a`` from train alone would not fix the leak.
        """
        if not self.check_content:
            return ()
        shared = self._train_hashes.keys() & self._holdout_hashes.keys()
        bad = [
            h
            for h in shared
            if len(self._train_hashes[h] | self._holdout_hashes[h]) > 1
        ]
        return tuple(sorted(bad))

    @property
    def is_disjoint(self) -> bool:
        """True iff the pools share no task id and (when ``check_content``) no cross-id near-duplicate content."""
        if not self.train_ids.isdisjoint(self.holdout_ids):
            return False
        if self.check_content and self.overlapping_content_hashes():
            return False
        return True

    def validate(self) -> HeldoutSplit:
        """Assert disjointness; return ``self`` so it chains in a constructor.

        Raises:
            HeldoutOverlapError: listing the overlapping ids (and, when
                ``check_content``, the near-duplicate content hashes).
        """
        id_overlap = self.overlapping_ids()
        hash_overlap = self.overlapping_content_hashes()
        if id_overlap or hash_overlap:
            raise HeldoutOverlapError(id_overlap, hash_overlap)
        return self

    # Documented alias: the task spec names both `validate()` and
    # `assert_disjoint()` — expose both so either calling convention works.
    def assert_disjoint(self) -> HeldoutSplit:
        """Alias for ``validate()`` — raise ``HeldoutOverlapError`` on any leak."""
        return self.validate()


def _hash_index(tasks: Iterable[Any]) -> dict[str, frozenset[str]]:
    """Map each content hash to the frozenset of task ids that produce it."""
    acc: dict[str, set[str]] = {}
    for t in tasks:
        acc.setdefault(content_hash(t), set()).add(task_id_of(t))
    return {h: frozenset(ids) for h, ids in acc.items()}