"""holdout.py — held-out / train set-disjointness enforcer (the #2 safeguard,
   second half).

``kill_switch.py`` (the ``HeldOutGuard`` run-level collapse tripwire) is only
sound if the held-out eval it watches is *genuinely disjoint* from the tasks the
generator trains on. If a single held-out task leaks back into the train /
generator pool, the "real" eval drifts WITH the train set and the proxy-real
Hacking-Gap signal becomes meaningless (see the Shumailov / Gao collapse
references in ``kill_switch.py``: the held-out eval must stay anchored to REAL
tasks that are NEVER fed back to the generator). This module enforces that
discipline mechanically rather than leaving it to convention.

``HeldoutSplit`` enforces disjointness two ways, both pure-Python:

  - **id-based** — the train/generator ``task_id`` set and the held-out
    ``task_id`` set must not intersect. This is the cheap, exact check.

  - **content-hash-based** (optional, ``check_content=True``) — a sha256 over a
    *normalized* view of each task's content. This catches NEAR-DUPLICATES that
    slipped through with DIFFERENT ids: the same broken repo + same
    ``fail_to_pass`` targets re-minted under a fresh ``task_id`` would pass the
    id check but is, for collapse purposes, the same eval task leaking into
    train. The EvilGenie failure-mode literature (arXiv 2511.21654, cited in
    ``kill_switch.py``) is explicit that "holdout tests have many surprising
    failure modes" — silent re-id'd duplicates are one of them.

The ``split(all_tasks, holdout_frac, seed)`` constructor deterministically
produces a (train, holdout) partition that is id-disjoint by construction: the
same tasks in the same order with the same seed yield the same partition every
run, so the held-out anchor is reproducible across the long self-evolving run.

Pure-Python: only ``hashlib`` / ``random`` from the stdlib. No torch, no cloud
deps. Accepts either raw ``task_id`` strings OR ``FeatureDeletionTask`` objects
(anything with a ``task_id`` attribute) on every entry point.
"""
from __future__ import annotations

import hashlib
import random
from collections.abc import Iterable, Sequence
from dataclasses import dataclass, field
from typing import Any


class HeldoutOverlapError(ValueError):
    """Raised when the train/generator pool and the held-out eval pool overlap.

    Carries the offending identifiers so the caller can log exactly which tasks
    leaked across the boundary (mirroring how ``datagen/monitor.py`` surfaces the
    specific suspected hacks rather than a bare boolean).

    Attributes:
        overlapping_ids: sorted task ids present in BOTH pools (id-based leak).
        overlapping_hashes: sorted content hashes present in both pools with
            *different* ids (content-based near-duplicate leak); empty unless
            content-hashing was enabled.
    """

    def __init__(
        self,
        overlapping_ids: Sequence[str] = (),
        overlapping_hashes: Sequence[str] = (),
    ) -> None:
        self.overlapping_ids = tuple(overlapping_ids)
        self.overlapping_hashes = tuple(overlapping_hashes)
        parts: list[str] = []
        if self.overlapping_ids:
            parts.append(
                f"{len(self.overlapping_ids)} task id(s) appear in BOTH the "
                f"train/generator pool and the held-out eval pool: "
                f"{list(self.overlapping_ids)}"
            )
        if self.overlapping_hashes:
            parts.append(
                f"{len(self.overlapping_hashes)} content hash(es) collide across "
                f"the boundary with DIFFERENT ids (re-id'd near-duplicates): "
                f"{list(self.overlapping_hashes)}"
            )
        if not parts:  # defensive — should not be raised with nothing overlapping
            parts.append("train/held-out overlap detected (no identifiers captured)")
        super().__init__(
            "held-out eval is NOT disjoint from the train/generator pool — "
            "this corrupts the proxy-real collapse signal. " + "; ".join(parts)
        )


def task_id_of(task: Any) -> str:
    """Coerce a task (a ``task_id`` string or a ``FeatureDeletionTask``-like
    object with a ``.task_id`` attribute) to its id string.

    Raises:
        TypeError: if ``task`` is neither a string nor an object with a string
            ``task_id`` attribute.
    """
    if isinstance(task, str):
        return task
    tid = getattr(task, "task_id", None)
    if isinstance(tid, str):
        return tid
    raise TypeError(
        f"expected a task_id str or an object with a str .task_id attribute, "
        f"got {type(task).__name__!r}"
    )


def content_hash(task: Any) -> str:
    """sha256 over a NORMALIZED view of a task's content (id-independent).

    The hash deliberately EXCLUDES ``task_id`` so two tasks that are identical
    apart from their id collide — that collision is exactly the near-duplicate
    leak we want ``check_content=True`` to catch.

    Normalization (so cosmetic differences do not defeat the check):
      - for ``FeatureDeletionTask``-like objects, hash the load-bearing content
        fields (repo, base_commit, broken_image, test_command, the SORTED
        fail_to_pass / pass_to_pass test sets, granularity, sorted
        deleted_symbols) — NOT task_id, and NOT volatile/advisory fields like
        difficulty_prior or upstream_license;
      - for a bare string, hash the whitespace-collapsed, lower-cased text (a
        plain id string is its own content);
      - test-set tuples are sorted so reordering the same tests does not change
        the hash.

    A plain ``task_id`` string therefore hashes to a stable, content-derived
    value; passing the same strings to both pools will collide on id FIRST
    (the id check fires before the content check), so the string path is mainly
    a graceful fallback for callers without structured tasks.
    """
    fields = _content_fields(task)
    blob = "\x1f".join(fields)  # unit-separator join: unambiguous field boundary
    return hashlib.sha256(blob.encode("utf-8")).hexdigest()


def _normalize_text(text: str) -> str:
    """Collapse runs of whitespace and lower-case, so cosmetic reformatting of a
    command / repo string does not defeat content-hash matching."""
    return " ".join(text.split()).lower()


def _content_fields(task: Any) -> list[str]:
    """Ordered, normalized content fields for hashing (id excluded)."""
    if isinstance(task, str):
        return [_normalize_text(task)]

    # FeatureDeletionTask-like: pull the content-defining fields if present.
    def norm(attr: str) -> str:
        val = getattr(task, attr, None)
        return _normalize_text(str(val)) if val is not None else ""

    def norm_set(attr: str) -> str:
        # Sorted so test-order does not change the hash; each test normalized.
        vals = getattr(task, attr, None) or ()
        return "\x1e".join(sorted(_normalize_text(str(v)) for v in vals))

    if hasattr(task, "task_id"):
        return [
            norm("repo"),
            norm("base_commit"),
            norm("broken_image"),
            norm("test_command"),
            norm_set("fail_to_pass"),
            norm_set("pass_to_pass"),
            norm("granularity"),
            norm_set("deleted_symbols"),
        ]

    # Last resort: a non-string, non-task object — hash its repr (best-effort).
    return [_normalize_text(repr(task))]


@dataclass(frozen=True)
class HeldoutSplit:
    """A (train/generator, held-out eval) partition with a disjointness contract.

    Construct directly from two iterables of task ids (or
    ``FeatureDeletionTask`` objects)::

        split = HeldoutSplit(train_tasks, holdout_tasks)
        split.assert_disjoint()           # raises HeldoutOverlapError on a leak
        if split.is_disjoint: ...

    or deterministically partition one pool::

        split = HeldoutSplit.split(all_tasks, holdout_frac=0.2, seed=1234)

    Set ``check_content=True`` to also reject re-id'd near-duplicates (same
    normalized content under a different ``task_id``). Content-hashing is a
    superset check: a content collision with the SAME id is just the id leak and
    is reported via ``overlapping_ids``; a collision with DIFFERENT ids is the
    near-duplicate leak reported via ``overlapping_content_hashes``.

    The instance is frozen; the id/hash sets are computed once at construction.
    """

    train_ids: frozenset[str]
    holdout_ids: frozenset[str]
    check_content: bool = False
    # content hash -> set of ids, per pool (only populated when check_content).
    _train_hashes: dict[str, frozenset[str]] = field(default_factory=dict, repr=False)
    _holdout_hashes: dict[str, frozenset[str]] = field(default_factory=dict, repr=False)

    # ------------------------------------------------------------------------
    # construction
    # ------------------------------------------------------------------------
    def __init__(
        self,
        train: Iterable[Any],
        holdout: Iterable[Any],
        *,
        check_content: bool = False,
    ) -> None:
        train_list = list(train)
        holdout_list = list(holdout)

        object.__setattr__(self, "train_ids", frozenset(map(task_id_of, train_list)))
        object.__setattr__(self, "holdout_ids", frozenset(map(task_id_of, holdout_list)))
        object.__setattr__(self, "check_content", bool(check_content))

        if check_content:
            object.__setattr__(self, "_train_hashes", _hash_index(train_list))
            object.__setattr__(self, "_holdout_hashes", _hash_index(holdout_list))
        else:
            object.__setattr__(self, "_train_hashes", {})
            object.__setattr__(self, "_holdout_hashes", {})

    # ------------------------------------------------------------------------
    # deterministic constructor
    # ------------------------------------------------------------------------
    @classmethod
    def split(
        cls,
        all_tasks: Iterable[Any],
        holdout_frac: float = 0.2,
        seed: int = 0,
        *,
        check_content: bool = False,
    ) -> HeldoutSplit:
        """Deterministically partition ``all_tasks`` into a disjoint (train,
        held-out) split.

        The partition depends only on the sequence of ``task_id`` values and the
        seed, so it is reproducible across runs (same ids in the same order +
        same ``seed`` => same split; reordering the input changes the split).
        Tasks are de-duplicated by id first (a duplicate id cannot land on both
        sides), then shuffled with a SEEDED ``random.Random`` and sliced, which
        guarantees an id-disjoint result by construction.

        Args:
            all_tasks: the full pool (ids or ``FeatureDeletionTask`` objects).
            holdout_frac: fraction routed to the held-out pool, in [0, 1]. The
                held-out size is ``round(n * holdout_frac)``, clamped so that a
                pool of at least two unique tasks with ``0 < holdout_frac < 1``
                always leaves at least one task on EACH side.
            seed: PRNG seed for the deterministic shuffle.
            check_content: also index content hashes on the result. The split
                itself does not separate re-id'd near-duplicates, so call
                ``validate()`` to detect any that straddle the boundary.

        Returns:
            A ``HeldoutSplit`` whose pools share no ``task_id`` by construction.

        Raises:
            ValueError: if ``holdout_frac`` is outside [0, 1].
        """
        if not (0.0 <= holdout_frac <= 1.0):
            raise ValueError(
                f"holdout_frac must be in [0, 1], got {holdout_frac!r}"
            )

        # De-dup by id, preserving first-seen order, keeping the original object
        # so content-hashing (if enabled) sees the structured task.
        seen: set[str] = set()
        unique: list[Any] = []
        for t in all_tasks:
            tid = task_id_of(t)
            if tid not in seen:
                seen.add(tid)
                unique.append(t)

        n = len(unique)
        n_holdout = round(n * holdout_frac)
        # round() rounds exact halves to the nearest even integer. Clamp so a
        # fraction strictly inside (0, 1) never collapses one side to empty.
        if 0.0 < holdout_frac < 1.0 and n >= 2:
            n_holdout = min(max(n_holdout, 1), n - 1)

        # Deterministic shuffle on a COPY (does not mutate caller input).
        order = list(unique)
        random.Random(seed).shuffle(order)
        holdout = order[:n_holdout]
        train = order[n_holdout:]
        return cls(train, holdout, check_content=check_content)

    # ------------------------------------------------------------------------
    # disjointness checks
    # ------------------------------------------------------------------------
    def overlapping_ids(self) -> tuple[str, ...]:
        """Sorted task ids present in BOTH pools (the id-based leak set)."""
        return tuple(sorted(self.train_ids & self.holdout_ids))

    def overlapping_content_hashes(self) -> tuple[str, ...]:
        """Sorted content hashes that collide across pools with DIFFERENT ids.

        Empty when ``check_content`` is False. A hash is reported only when each
        pool holds it under at least one id that is NOT a plain id-overlap. A
        hash shared solely through overlapping ids is left to
        ``overlapping_ids``, so the two checks do not double-report the same
        leak.
        """
        if not self.check_content:
            return ()
        id_overlap = self.train_ids & self.holdout_ids
        bad: list[str] = []
        for h, train_ids in self._train_hashes.items():
            holdout_ids = self._holdout_hashes.get(h)
            if holdout_ids is None:
                continue
            # Same content on both sides via at least one id that is NOT itself a
            # plain id-overlap => a re-id'd near-duplicate leak.
            if (holdout_ids - id_overlap) and (train_ids - id_overlap):
                bad.append(h)
        return tuple(sorted(bad))

    @property
    def is_disjoint(self) -> bool:
        """True iff the pools share no task id (and, when ``check_content``, no
        cross-id near-duplicate content)."""
        if self.train_ids & self.holdout_ids:
            return False
        if self.check_content and self.overlapping_content_hashes():
            return False
        return True

    def validate(self) -> HeldoutSplit:
        """Assert disjointness; return ``self`` so the call can be chained onto
        construction.

        Raises:
            HeldoutOverlapError: listing the overlapping ids (and, when
                ``check_content``, the near-duplicate content hashes).
        """
        id_overlap = self.overlapping_ids()
        hash_overlap = self.overlapping_content_hashes()
        if id_overlap or hash_overlap:
            raise HeldoutOverlapError(id_overlap, hash_overlap)
        return self

    # Documented alias: the task spec names both `validate()` and
    # `assert_disjoint()` — expose both so either calling convention works.
    def assert_disjoint(self) -> HeldoutSplit:
        """Alias for ``validate()`` — raise ``HeldoutOverlapError`` on any leak."""
        return self.validate()
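

# Illustrative sketch (an addition for readers, not part of the original API):
# the held-out sizing arithmetic used by ``HeldoutSplit.split`` above, pulled
# out as a standalone helper so its edge cases are easy to check by hand.
# ``_holdout_size_sketch`` is a hypothetical name; ``split`` does not call it.
def _holdout_size_sketch(n: int, holdout_frac: float) -> int:
    """Return how many of ``n`` unique tasks the sizing rule routes to held-out."""
    # Built-in round() rounds exact halves to the nearest even integer, so
    # n=5 with holdout_frac=0.5 gives round(2.5) == 2, not 3.
    size = round(n * holdout_frac)
    # With a fraction strictly inside (0, 1) and at least two tasks, keep at
    # least one task on each side. A single-task pool is left unclamped.
    if 0.0 < holdout_frac < 1.0 and n >= 2:
        size = min(max(size, 1), n - 1)
    return size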


def _hash_index(tasks: Iterable[Any]) -> dict[str, frozenset[str]]:
    """Map content hash -> frozenset of task ids producing that hash."""
    acc: dict[str, set[str]] = {}
    for t in tasks:
        acc.setdefault(content_hash(t), set()).add(task_id_of(t))
    return {h: frozenset(ids) for h, ids in acc.items()}
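

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an addition for readers, not part of the original
# module API). ``_DemoTask`` is a hypothetical stand-in carrying only a few of
# the fields that ``_content_fields`` reads; the real ``FeatureDeletionTask``
# is defined elsewhere and may differ. ``_demo`` is likewise a hypothetical
# helper that only exercises the public entry points defined above.
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class _DemoTask:
    task_id: str
    repo: str
    fail_to_pass: tuple[str, ...] = ()


def _demo() -> dict[str, Any]:
    """Exercise the id check, the content check, and the seeded split."""
    a = _DemoTask("t-1", "org/repo", ("test_x", "test_y"))
    # Same content as ``a`` (only case and test order differ) under a fresh id.
    b = _DemoTask("t-2", "Org/Repo", ("test_y", "test_x"))

    # The id check alone sees two distinct ids and reports no leak.
    id_only = HeldoutSplit([a], [b])
    # The content check normalizes and hashes the fields, so it flags the pair.
    with_content = HeldoutSplit([a], [b], check_content=True)

    ids = [f"task-{i}" for i in range(10)]
    first = HeldoutSplit.split(ids, holdout_frac=0.2, seed=1234)
    second = HeldoutSplit.split(ids, holdout_frac=0.2, seed=1234)
    return {
        "id_only_disjoint": id_only.is_disjoint,
        "content_disjoint": with_content.is_disjoint,
        "holdout_size": len(first.holdout_ids),
        "reproducible": first.holdout_ids == second.holdout_ids,
    }


if __name__ == "__main__":
    print(_demo())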