File size: 1,431 Bytes
ea99d10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#!/usr/bin/env python3
"""Unit tests for split_manager module."""

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from split_manager import assign_splits, split_distribution


def test_same_video_stays_in_same_split():
    """All chunks of one video must be assigned to a single split.

    The rows carry no ``speaker_cluster_id``, so every row's group key is
    expected to be video-based (``video:`` prefix).
    """
    chunk_specs = (("vid_a", "a_1"), ("vid_a", "a_2"), ("vid_b", "b_1"))
    chunks = [{"video_id": video, "chunk_id": chunk} for video, chunk in chunk_specs]

    # assign_splits annotates the rows in place; its return value is unused.
    assign_splits(chunks, {"split_seed": "test_seed"})

    splits_for_vid_a = set()
    for chunk in chunks:
        if chunk["video_id"] == "vid_a":
            splits_for_vid_a.add(chunk["split"])
    assert len(splits_for_vid_a) == 1

    has_video_key = [chunk["split_group_key"].startswith("video:") for chunk in chunks]
    assert all(has_video_key)


def test_speaker_cluster_takes_priority():
    """A shared speaker cluster overrides the video as the grouping key.

    The two rows come from different videos but share ``speaker_cluster_id``
    ``spk_1``. They must land in the same split, and *both* rows, not just
    the first, must be grouped under ``speaker:spk_1``.
    """
    rows = [
        {"video_id": "vid_a", "speaker_cluster_id": "spk_1"},
        {"video_id": "vid_b", "speaker_cluster_id": "spk_1"},
    ]
    assign_splits(rows, {"split_seed": "test_seed"})
    assert rows[0]["split"] == rows[1]["split"]
    # Verify every row's key. The second row is the one whose video differs,
    # so checking only rows[0] would not prove priority holds across videos.
    assert [row["split_group_key"] for row in rows] == [
        "speaker:spk_1",
        "speaker:spk_1",
    ]


def test_split_distribution_counts_rows():
    """``split_distribution`` returns a per-split count of the given rows."""
    labelled_rows = [{"split": name} for name in ("train", "train", "val")]
    expected_counts = {"train": 2, "val": 1}
    assert split_distribution(labelled_rows) == expected_counts


if __name__ == "__main__":
    # Run every module-level ``test_*`` function in definition order, so that
    # newly added tests are executed without also having to be listed here.
    # The snapshot via list() keeps the iteration stable while this block
    # binds new module globals.
    _tests = [
        obj
        for name, obj in list(globals().items())
        if name.startswith("test_") and callable(obj)
    ]
    for _test in _tests:
        _test()
    print("All split_manager tests passed!")