"""Tests for HeldoutSplit — the train/held-out set-disjointness enforcer. This is the second half of C1 (the first half is HeldOutGuard in kill_switch.py): the guard's proxy-real-gap signal is only meaningful if the held-out eval set is genuinely DISJOINT from the train/generator set. HeldoutSplit enforces that. Written during Wave-3 integration (the build agent shipped holdout.py without a test module — same test-gap pattern as the SageMaker/EKS executors). """ from __future__ import annotations import pytest from composer_replication.safety import HeldoutOverlapError, HeldoutSplit # A tiny FeatureDeletionTask-like stand-in (HeldoutSplit reads task_id + content # fields via duck-typing; a plain object with task_id works, and a string is # treated as its own id). class _Task: """FeatureDeletionTask-like stand-in. content-hashing in holdout.py reads the real task fields (repo/base_commit/test_command/...), so the content kwarg populates `repo` (one of the hashed fields), not a generic `content` attr.""" def __init__(self, task_id, content=""): self.task_id = task_id self.repo = content # `repo` is one of the fields _content_fields hashes # --------------------------------------------------------------------- # id-based disjointness # --------------------------------------------------------------------- def test_disjoint_passes(): s = HeldoutSplit(train=["a", "b", "c"], holdout=["d", "e"]) assert s.is_disjoint assert s.overlapping_ids() == () # validate / assert_disjoint return self and do not raise assert s.validate() is s assert s.assert_disjoint() is s def test_overlap_raises_and_lists_ids(): s = HeldoutSplit(train=["a", "b", "shared"], holdout=["shared", "z"]) assert not s.is_disjoint assert s.overlapping_ids() == ("shared",) with pytest.raises(HeldoutOverlapError) as exc: s.validate() assert "shared" in str(exc.value) def test_object_tasks_use_task_id(): train = [_Task("t1"), _Task("t2")] holdout = [_Task("h1")] assert HeldoutSplit(train, holdout).is_disjoint # an object sharing an id with train is a leak leak = HeldoutSplit(train, [_Task("t1")]) assert not leak.is_disjoint assert leak.overlapping_ids() == ("t1",) # --------------------------------------------------------------------- # deterministic split() # --------------------------------------------------------------------- def test_split_is_disjoint_by_construction(): pool = [f"task{i}" for i in range(10)] s = HeldoutSplit.split(pool, holdout_frac=0.3, seed=0) assert s.is_disjoint # every id is on exactly one side; union covers the (de-duped) pool assert s.train_ids.isdisjoint(s.holdout_ids) assert (s.train_ids | s.holdout_ids) == set(pool) def test_split_is_deterministic(): pool = [f"task{i}" for i in range(20)] a = HeldoutSplit.split(pool, holdout_frac=0.25, seed=42) b = HeldoutSplit.split(pool, holdout_frac=0.25, seed=42) assert a.holdout_ids == b.holdout_ids assert a.train_ids == b.train_ids # a different seed gives a (very likely) different partition c = HeldoutSplit.split(pool, holdout_frac=0.25, seed=7) assert c.holdout_ids != a.holdout_ids or c.train_ids != a.train_ids def test_split_never_collapses_a_side(): s = HeldoutSplit.split([f"t{i}" for i in range(5)], holdout_frac=0.01, seed=0) assert len(s.holdout_ids) >= 1 assert len(s.train_ids) >= 1 def test_split_rejects_bad_frac(): with pytest.raises(ValueError): HeldoutSplit.split(["a", "b"], holdout_frac=1.5) def test_split_dedups_by_id(): # duplicate ids cannot land on both sides pool = ["a", "a", "b", "c"] s = HeldoutSplit.split(pool, holdout_frac=0.5, seed=0) assert s.is_disjoint assert (s.train_ids | s.holdout_ids) == {"a", "b", "c"} # --------------------------------------------------------------------- # content-hash disjointness (catches same-content / different-id near-dups) # --------------------------------------------------------------------- def test_content_hash_catches_same_content_different_id(): # different ids, identical content -> id-disjoint but content-leaked train = [_Task("t1", content="fix the off-by-one in range()")] holdout = [_Task("h1", content="fix the off-by-one in range()")] s = HeldoutSplit(train, holdout, check_content=True) assert s.overlapping_ids() == () # ids are disjoint assert not s.is_disjoint # but content collides assert s.overlapping_content_hashes() # non-empty with pytest.raises(HeldoutOverlapError): s.validate() def test_content_hash_disjoint_when_content_differs(): train = [_Task("t1", content="alpha")] holdout = [_Task("h1", content="beta")] s = HeldoutSplit(train, holdout, check_content=True) assert s.is_disjoint