| |
| """Unit tests for split_manager module.""" |
|
|
| import sys |
| from pathlib import Path |
|
|
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
| from split_manager import assign_splits, split_distribution |
|
|
|
|
def test_same_video_stays_in_same_split():
    """Check that every chunk of one video ends up in a single split.

    The test relies on ``assign_splits`` writing ``split`` and
    ``split_group_key`` onto the row dicts it is given.
    """
    fixture = (("vid_a", "a_1"), ("vid_a", "a_2"), ("vid_b", "b_1"))
    samples = [
        {"video_id": video, "chunk_id": chunk} for video, chunk in fixture
    ]
    assign_splits(samples, {"split_seed": "test_seed"})

    # Collect the distinct splits seen for vid_a; there must be exactly one.
    splits_for_vid_a = set()
    for sample in samples:
        if sample["video_id"] == "vid_a":
            splits_for_vid_a.add(sample["split"])
    assert len(splits_for_vid_a) == 1

    # No speaker cluster is supplied, so every group key is video-based.
    for sample in samples:
        assert sample["split_group_key"].startswith("video:")
|
|
|
|
def test_speaker_cluster_takes_priority():
    """Check that a shared speaker cluster groups rows from different videos.

    Both rows carry ``speaker_cluster_id == "spk_1"`` but different
    ``video_id`` values; they must share a split and the group key must be
    derived from the speaker cluster.
    """
    config = {"split_seed": "test_seed"}
    rows = [
        {"video_id": video, "speaker_cluster_id": "spk_1"}
        for video in ("vid_a", "vid_b")
    ]
    assign_splits(rows, config)

    # Read results back from the list that was handed to assign_splits.
    first, second = rows[0], rows[1]
    assert first["split"] == second["split"]
    assert first["split_group_key"] == "speaker:spk_1"
|
|
|
|
def test_split_distribution_counts_rows():
    """Check that ``split_distribution`` tallies rows per ``split`` value."""
    labels = ("train", "train", "val")
    rows = [{"split": label} for label in labels]
    expected = {"train": 2, "val": 1}
    assert split_distribution(rows) == expected
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly as a script: execute each test in
    # definition order, then report success (a failing assert aborts first).
    for check in (
        test_same_video_stays_in_same_split,
        test_speaker_cluster_takes_priority,
        test_split_distribution_counts_rows,
    ):
        check()
    print("All split_manager tests passed!")
|
|