File size: 4,205 Bytes
4e57c0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import random
from collections import defaultdict

def split_families(input_file, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=42):
    """
    Split families into train/val/test lists and write them to disk.

    Identifiers are read one per line from ``input_file`` and grouped by
    family, where the family is the part of the identifier before the first
    underscore. Families (not individual identifiers) are shuffled and
    split, so a family never appears in more than one split.

    * Train keeps every identifier of its families.
    * Val and test keep only the first identifier (in file order) of each
      of their families.

    The results are written to ``train_list.txt``, ``val_list.txt`` and
    ``test_list.txt`` in the current working directory, and a summary is
    printed to stdout. An input without any identifiers produces three
    empty files and a summary with 0.0% for every split.

    Args:
        input_file: path to file containing identifiers, one per line
            (surrounding whitespace is stripped, blank lines are skipped)
        train_ratio: proportion of families for training (0.7 = 70%)
        val_ratio: proportion of families for validation (0.15 = 15%)
        test_ratio: nominal proportion of families for testing; it is not
            used in the computation -- the test split receives every family
            left over after the train and val splits are taken
        seed: random seed for reproducibility (seeds the module-level
            ``random`` generator)

    Returns:
        None
    """

    # Set seed for reproducibility
    random.seed(seed)

    # Group identifiers by family, preserving file order within a family
    families = defaultdict(list)
    with open(input_file, 'r') as f:
        for line in f:
            identifier = line.strip()
            if identifier:  # Skip empty lines
                # Family is the part before the first underscore
                family = identifier.split('_')[0]
                families[family].append(identifier)

    # Shuffle at the family level so no family is shared between splits
    family_list = list(families)
    random.shuffle(family_list)

    # Split sizes are truncated; the test split absorbs the remainder
    total_families = len(family_list)
    train_size = int(total_families * train_ratio)
    val_size = int(total_families * val_ratio)

    train_families = family_list[:train_size]
    val_families = family_list[train_size:train_size + val_size]
    test_families = family_list[train_size + val_size:]

    # Train: all identifiers; val/test: first identifier of each family only
    train_ids = [
        identifier
        for family in train_families
        for identifier in families[family]
    ]
    val_ids = [families[family][0] for family in val_families]
    test_ids = [families[family][0] for family in test_families]

    # Save files
    outputs = (
        ('train_list.txt', train_ids),
        ('val_list.txt', val_ids),
        ('test_list.txt', test_ids),
    )
    for path, ids in outputs:
        with open(path, 'w') as f:
            f.writelines(identifier + '\n' for identifier in ids)

    def percent(count):
        # Guard against an empty input, which would otherwise divide by zero
        return count / total_families * 100 if total_families else 0.0

    # Print statistics
    print(f"Total families: {total_families}")
    print(f"Total identifiers in input: {sum(len(ids) for ids in families.values())}")
    print()
    print(f"Train: {len(train_families)} families ({percent(len(train_families)):.1f}%), {len(train_ids)} identifiers")
    print(f"Val:   {len(val_families)} families ({percent(len(val_families)):.1f}%), {len(val_ids)} identifiers (1 per family)")
    print(f"Test:  {len(test_families)} families ({percent(len(test_families)):.1f}%), {len(test_ids)} identifiers (1 per family)")
    print()
    print(f"Total identifiers used: {len(train_ids) + len(val_ids) + len(test_ids)}")
    print("Files created: train_list.txt, val_list.txt, test_list.txt")

    # Show some example families in each split
    print("\nExample families:")
    print(f"Train: {train_families[:5]}...")
    print(f"Val: {val_families[:3]}...")
    print(f"Test: {test_families[:3]}...")

    # Show some examples of what goes into val/test
    if val_families:
        print("\nExample val entries (first ID per family):")
        for family in val_families[:3]:
            print(f"  Family {family}: {families[family][0]} (from {len(families[family])} available)")

    if test_families:
        print("\nExample test entries (first ID per family):")
        for family in test_families[:3]:
            print(f"  Family {family}: {families[family][0]} (from {len(families[family])} available)")

if __name__ == "__main__":
    # Script entry point: split the identifiers listed in full_list.txt
    # using the default ratios and seed.
    split_families(input_file='full_list.txt')