#!/usr/bin/env python3
"""Unit tests for the split_manager module.

Covers ``assign_splits`` and ``split_distribution``. The tests can be
collected by a test runner or executed directly as a script.
"""

import sys
from pathlib import Path

# Make the directory one level above this file's folder importable; this is
# presumably where split_manager lives (verify against the repo layout).
sys.path.insert(0, str(Path(__file__).parent.parent))

from split_manager import assign_splits, split_distribution


def test_same_video_stays_in_same_split():
    """Rows sharing a video_id are expected to end up in one split."""
    chunks = [
        {"video_id": "vid_a", "chunk_id": "a_1"},
        {"video_id": "vid_a", "chunk_id": "a_2"},
        {"video_id": "vid_b", "chunk_id": "b_1"},
    ]
    config = {"split_seed": "test_seed"}
    assign_splits(chunks, config)

    # Collect every distinct split value seen on the vid_a rows.
    splits_for_vid_a = set()
    for chunk in chunks:
        if chunk["video_id"] == "vid_a":
            splits_for_vid_a.add(chunk["split"])
    assert len(splits_for_vid_a) == 1

    # No row carries a speaker cluster, so each group key should be
    # video-based.
    for chunk in chunks:
        assert chunk["split_group_key"].startswith("video:")


def test_speaker_cluster_takes_priority():
    """A shared speaker_cluster_id is expected to group rows across videos."""
    rows = [
        {"video_id": "vid_a", "speaker_cluster_id": "spk_1"},
        {"video_id": "vid_b", "speaker_cluster_id": "spk_1"},
    ]
    assign_splits(rows, {"split_seed": "test_seed"})

    first, second = rows
    assert first["split"] == second["split"]
    assert first["split_group_key"] == "speaker:spk_1"


def test_split_distribution_counts_rows():
    """split_distribution is expected to return a per-split row count."""
    labelled = [{"split": label} for label in ("train", "train", "val")]
    expected = {"train": 2, "val": 1}
    assert split_distribution(labelled) == expected


if __name__ == "__main__":
    # Script mode: run each test in declaration order; any failing assert
    # aborts before the success message is printed.
    for run_test in (
        test_same_video_stays_in_same_split,
        test_speaker_cluster_takes_priority,
        test_split_distribution_counts_rows,
    ):
        run_test()
    print("All split_manager tests passed!")