# Author: Baladithya Balamurugan
# Commit bd0c358 -- "Wave 3: close the HIGH review findings (kill-switch
# wiring, HeldoutSplit, EKS entrypoint bug)"
"""Tests for HeldoutSplit — the train/held-out set-disjointness enforcer.
This is the second half of C1 (the first half is HeldOutGuard in kill_switch.py):
the guard's proxy-real-gap signal is only meaningful if the held-out eval set is
genuinely DISJOINT from the train/generator set. HeldoutSplit enforces that.
Written during Wave-3 integration (the build agent shipped holdout.py without a
test module — same test-gap pattern as the SageMaker/EKS executors).
"""
from __future__ import annotations
import pytest
from composer_replication.safety import HeldoutOverlapError, HeldoutSplit
# A tiny FeatureDeletionTask-like stand-in (HeldoutSplit reads task_id + content
# fields via duck-typing; a plain object with task_id works, and a string is
# treated as its own id).
class _Task:
"""FeatureDeletionTask-like stand-in. content-hashing in holdout.py reads the
real task fields (repo/base_commit/test_command/...), so the content kwarg
populates `repo` (one of the hashed fields), not a generic `content` attr."""
def __init__(self, task_id, content=""):
self.task_id = task_id
self.repo = content # `repo` is one of the fields _content_fields hashes
# ---------------------------------------------------------------------
# id-based disjointness
# ---------------------------------------------------------------------
def test_disjoint_passes():
    """A split with no shared ids is disjoint and validates silently."""
    split = HeldoutSplit(train=["a", "b", "c"], holdout=["d", "e"])
    assert split.is_disjoint
    assert split.overlapping_ids() == ()
    # Both checkers are fluent: they return the split itself and raise
    # nothing when there is no overlap.
    for check in (split.validate, split.assert_disjoint):
        assert check() is split
def test_overlap_raises_and_lists_ids():
    """A shared id is reported by overlapping_ids() and named in the error."""
    leaky = HeldoutSplit(train=["a", "b", "shared"], holdout=["shared", "z"])
    assert not leaky.is_disjoint
    assert leaky.overlapping_ids() == ("shared",)

    with pytest.raises(HeldoutOverlapError) as excinfo:
        leaky.validate()
    # The message must name the leaked id so the overlap is debuggable.
    assert "shared" in str(excinfo.value)
def test_object_tasks_use_task_id():
    """Task objects are compared by their ``task_id`` attribute."""
    train_tasks = [_Task(tid) for tid in ("t1", "t2")]
    clean = HeldoutSplit(train_tasks, [_Task("h1")])
    assert clean.is_disjoint

    # A held-out object with a train task's id is a leak, even though it is
    # a distinct Python object.
    leaky = HeldoutSplit(train_tasks, [_Task("t1")])
    assert not leaky.is_disjoint
    assert leaky.overlapping_ids() == ("t1",)
# ---------------------------------------------------------------------
# deterministic split()
# ---------------------------------------------------------------------
def test_split_is_disjoint_by_construction():
    """split() partitions the pool: no id on both sides, none dropped."""
    pool = ["task{}".format(n) for n in range(10)]
    split = HeldoutSplit.split(pool, holdout_frac=0.3, seed=0)
    assert split.is_disjoint

    # Every id lands on exactly one side, and the two sides together cover
    # the (de-duped) pool.
    train_side, held_side = split.train_ids, split.holdout_ids
    assert train_side.isdisjoint(held_side)
    assert (train_side | held_side) == set(pool)
def test_split_is_deterministic():
    """The same seed reproduces the partition; another seed changes it."""
    pool = [f"task{n}" for n in range(20)]

    def make(seed):
        return HeldoutSplit.split(pool, holdout_frac=0.25, seed=seed)

    first, again = make(42), make(42)
    assert first.holdout_ids == again.holdout_ids
    assert first.train_ids == again.train_ids

    # A different seed gives a (very likely) different partition.
    other = make(7)
    assert (
        other.holdout_ids != first.holdout_ids
        or other.train_ids != first.train_ids
    )
def test_split_never_collapses_a_side():
    """Extreme holdout fractions still leave at least one id on each side.

    With a 5-item pool, 0.01 * 5 rounds to 0 held-out ids, which threatens
    the held-out side. 0.99 * 5 rounds to 5 held-out ids, which threatens
    the train side. Both extremes are checked so that each side's guarantee
    is exercised.
    """
    pool = [f"t{i}" for i in range(5)]
    for frac in (0.01, 0.99):
        s = HeldoutSplit.split(pool, holdout_frac=frac, seed=0)
        assert len(s.holdout_ids) >= 1, f"holdout side collapsed at frac={frac}"
        assert len(s.train_ids) >= 1, f"train side collapsed at frac={frac}"
def test_split_rejects_bad_frac():
    """A holdout fraction above 1 is invalid and raises ValueError."""
    out_of_range = 1.5
    with pytest.raises(ValueError):
        HeldoutSplit.split(["a", "b"], holdout_frac=out_of_range)
def test_split_dedups_by_id():
    """A repeated id in the pool never ends up on both sides."""
    # "a" appears twice; it must be placed on exactly one side.
    split = HeldoutSplit.split(["a", "a", "b", "c"], holdout_frac=0.5, seed=0)
    assert split.is_disjoint
    covered = split.train_ids | split.holdout_ids
    assert covered == {"a", "b", "c"}
# ---------------------------------------------------------------------
# content-hash disjointness (catches same-content / different-id near-dups)
# ---------------------------------------------------------------------
def test_content_hash_catches_same_content_different_id():
    """Identical content under different ids is still flagged as a leak."""
    # An id-only check passes here. Only the content hash exposes the leak.
    body = "fix the off-by-one in range()"
    split = HeldoutSplit(
        [_Task("t1", content=body)],
        [_Task("h1", content=body)],
        check_content=True,
    )
    assert split.overlapping_ids() == ()       # ids are disjoint ...
    assert not split.is_disjoint               # ... but content collides
    assert split.overlapping_content_hashes()  # at least one shared hash
    with pytest.raises(HeldoutOverlapError):
        split.validate()
def test_content_hash_disjoint_when_content_differs():
    """Distinct ids with distinct content pass the content-hash check."""
    split = HeldoutSplit(
        [_Task("t1", content="alpha")],
        [_Task("h1", content="beta")],
        check_content=True,
    )
    assert split.is_disjoint