Vlmbd commited on
Commit
fa98756
·
unverified ·
1 Parent(s): f1804f1

add split script

Browse files
Files changed (1) hide show
  1. split_script.py +107 -0
split_script.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from collections import defaultdict
3
+
4
+ def split_families(input_file, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, seed=42):
5
+ """
6
+ Split families into train/val/test (70/15/15), but for val and test,
7
+ take only the first unique identifier per family.
8
+
9
+ Args:
10
+ input_file: path to file containing identifiers
11
+ train_ratio: proportion of families for training (0.7 = 70%)
12
+ val_ratio: proportion of families for validation (0.15 = 15%)
13
+ test_ratio: proportion of families for testing (0.15 = 15%)
14
+ seed: random seed for reproducibility
15
+ """
16
+
17
+ # Set seed for reproducibility
18
+ random.seed(seed)
19
+
20
+ # Dictionary to group identifiers by family
21
+ families = defaultdict(list)
22
+
23
+ # Read file and group by family
24
+ with open(input_file, 'r') as f:
25
+ for line in f:
26
+ identifier = line.strip()
27
+ if identifier: # Skip empty lines
28
+ # Extract family (part before first underscore)
29
+ family = identifier.split('_')[0]
30
+ families[family].append(identifier)
31
+
32
+ # Convert to list of families for shuffling
33
+ family_list = list(families.keys())
34
+ random.shuffle(family_list)
35
+
36
+ # Calculate split sizes
37
+ total_families = len(family_list)
38
+ train_size = int(total_families * train_ratio)
39
+ val_size = int(total_families * val_ratio)
40
+ # test_size will be the remainder
41
+
42
+ # Split families
43
+ train_families = family_list[:train_size]
44
+ val_families = family_list[train_size:train_size + val_size]
45
+ test_families = family_list[train_size + val_size:]
46
+
47
+ # Create identifier lists for each split
48
+ train_ids = []
49
+ val_ids = []
50
+ test_ids = []
51
+
52
+ # Train: take all identifiers from train families
53
+ for family in train_families:
54
+ train_ids.extend(families[family])
55
+
56
+ # Val: take only the first identifier from each val family
57
+ for family in val_families:
58
+ val_ids.append(families[family][0]) # First identifier only
59
+
60
+ # Test: take only the first identifier from each test family
61
+ for family in test_families:
62
+ test_ids.append(families[family][0]) # First identifier only
63
+
64
+ # Save files
65
+ with open('train_list.txt', 'w') as f:
66
+ for identifier in train_ids:
67
+ f.write(identifier + '\n')
68
+
69
+ with open('val_list.txt', 'w') as f:
70
+ for identifier in val_ids:
71
+ f.write(identifier + '\n')
72
+
73
+ with open('test_list.txt', 'w') as f:
74
+ for identifier in test_ids:
75
+ f.write(identifier + '\n')
76
+
77
+ # Print statistics
78
+ print(f"Total families: {total_families}")
79
+ print(f"Total identifiers in input: {sum(len(ids) for ids in families.values())}")
80
+ print()
81
+ print(f"Train: {len(train_families)} families ({len(train_families)/total_families*100:.1f}%), {len(train_ids)} identifiers")
82
+ print(f"Val: {len(val_families)} families ({len(val_families)/total_families*100:.1f}%), {len(val_ids)} identifiers (1 per family)")
83
+ print(f"Test: {len(test_families)} families ({len(test_families)/total_families*100:.1f}%), {len(test_ids)} identifiers (1 per family)")
84
+ print()
85
+ print(f"Total identifiers used: {len(train_ids) + len(val_ids) + len(test_ids)}")
86
+ print("Files created: train_list.txt, val_list.txt, test_list.txt")
87
+
88
+ # Show some example families in each split
89
+ print("\nExample families:")
90
+ print(f"Train: {train_families[:5]}...")
91
+ print(f"Val: {val_families[:3]}...")
92
+ print(f"Test: {test_families[:3]}...")
93
+
94
+ # Show some examples of what goes into val/test
95
+ if val_families:
96
+ print(f"\nExample val entries (first ID per family):")
97
+ for i, family in enumerate(val_families[:3]):
98
+ print(f" Family {family}: {families[family][0]} (from {len(families[family])} available)")
99
+
100
+ if test_families:
101
+ print(f"\nExample test entries (first ID per family):")
102
+ for i, family in enumerate(test_families[:3]):
103
+ print(f" Family {family}: {families[family][0]} (from {len(families[family])} available)")
104
+
105
+ if __name__ == "__main__":
106
+ # Run the script
107
+ split_families('full_list.txt')