#!/usr/bin/env python3
"""Unit tests for split_manager module."""
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from split_manager import assign_splits, split_distribution
def test_same_video_stays_in_same_split():
    """Check that chunks from one video share a split and use video group keys.

    The test reads ``split`` and ``split_group_key`` back from the input rows
    after the call, so it relies on ``assign_splits`` annotating them in place.
    """
    chunk_ids_by_video = (
        ("vid_a", "a_1"),
        ("vid_a", "a_2"),
        ("vid_b", "b_1"),
    )
    records = [
        {"video_id": video, "chunk_id": chunk}
        for video, chunk in chunk_ids_by_video
    ]

    assign_splits(records, {"split_seed": "test_seed"})

    # Gather every split that a "vid_a" chunk was assigned to.
    vid_a_splits = set()
    for record in records:
        if record["video_id"] == "vid_a":
            vid_a_splits.add(record["split"])
    assert len(vid_a_splits) == 1

    # No row carries a speaker cluster, so each group key must be video-based.
    for record in records:
        assert record["split_group_key"].startswith("video:")
def test_speaker_cluster_takes_priority():
    """Check that a shared speaker cluster groups rows from different videos.

    Both rows come from distinct videos but carry the same
    ``speaker_cluster_id``; they must receive the same split, and the first
    row's group key must be the speaker-based one.
    """
    records = [
        {"video_id": video, "speaker_cluster_id": "spk_1"}
        for video in ("vid_a", "vid_b")
    ]

    assign_splits(records, {"split_seed": "test_seed"})

    first, second = records
    assert first["split"] == second["split"]
    assert first["split_group_key"] == "speaker:spk_1"
def test_split_distribution_counts_rows():
    """Check that ``split_distribution`` tallies rows per split name."""
    split_names = ("train", "train", "val")
    records = [{"split": split_name} for split_name in split_names]

    expected_counts = {"train": 2, "val": 1}
    assert split_distribution(records) == expected_counts
if __name__ == "__main__":
    # Allow running this file directly, without a test runner: execute each
    # test in declaration order; a failing assert aborts before the message.
    for run_test in (
        test_same_video_stays_in_same_split,
        test_speaker_cluster_takes_priority,
        test_split_distribution_counts_rows,
    ):
        run_test()
    print("All split_manager tests passed!")
|