File size: 4,205 Bytes
4e57c0d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 | import random
from collections import defaultdict
def split_families(input_file, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15,
                   seed=42, output_dir="."):
    """Split identifiers into train/val/test lists by family.

    An identifier's family is the text before its first underscore
    (``"abc_001"`` -> ``"abc"``; an identifier without an underscore is its
    own family). Families, not identifiers, are shuffled and partitioned, so
    no family appears in more than one split. Train keeps every identifier of
    its families; val and test keep only the first identifier (in file order)
    of each of their families.

    Args:
        input_file: Path to a text file with one identifier per line. Blank
            lines and surrounding whitespace are ignored.
        train_ratio: Fraction of families assigned to train (rounded down).
        val_ratio: Fraction of families assigned to val (rounded down).
        test_ratio: Nominal test fraction. NOT used for sizing: test always
            receives every family left over after train and val, so rounding
            remainders land in test. Kept for interface compatibility.
        seed: Seed for the family shuffle. A private ``random.Random`` is
            used, so the caller's global ``random`` state is not modified.
        output_dir: Directory where ``train_list.txt``, ``val_list.txt`` and
            ``test_list.txt`` are written (default: current directory).

    Returns:
        A ``(train_ids, val_ids, test_ids)`` tuple of identifier lists, in the
        same order as written to the output files.

    Raises:
        ValueError: If ``train_ratio`` or ``val_ratio`` is negative, or their
            sum exceeds 1.
    """
    import os  # local import: the module header only brings in random/defaultdict

    # Negative ratios would produce negative slice bounds (silently wrong
    # splits); a sum above 1 would leave test empty without telling anyone.
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative and sum to at most 1 "
            f"(got {train_ratio} and {val_ratio})"
        )

    families = _read_families(input_file)

    # Shuffle family names (dict order = order of first appearance in the
    # file). Random(seed).shuffle yields the same permutation as
    # random.seed(seed) + random.shuffle, without touching global state.
    family_list = list(families)
    random.Random(seed).shuffle(family_list)

    total_families = len(family_list)
    train_size = int(total_families * train_ratio)
    val_size = int(total_families * val_ratio)
    train_families = family_list[:train_size]
    val_families = family_list[train_size:train_size + val_size]
    test_families = family_list[train_size + val_size:]  # remainder goes to test

    # Train: every identifier of each family; val/test: first identifier only.
    train_ids = [identifier for family in train_families for identifier in families[family]]
    val_ids = [families[family][0] for family in val_families]
    test_ids = [families[family][0] for family in test_families]

    for filename, ids in (
        ("train_list.txt", train_ids),
        ("val_list.txt", val_ids),
        ("test_list.txt", test_ids),
    ):
        _write_list(os.path.join(output_dir, filename), ids)

    _print_summary(families, train_families, val_families, test_families,
                   train_ids, val_ids, test_ids)
    return train_ids, val_ids, test_ids


def _read_families(input_file):
    """Group the non-blank, stripped lines of *input_file* by family prefix.

    Returns a dict mapping family -> identifiers in file order; the dict's
    key order is the order in which each family first appears.
    """
    families = defaultdict(list)
    with open(input_file, "r", encoding="utf-8") as f:
        for line in f:
            identifier = line.strip()
            if identifier:  # skip blank lines
                families[identifier.split("_", 1)[0]].append(identifier)
    return families


def _write_list(path, identifiers):
    """Write *identifiers* to *path*, one newline-terminated entry per line."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(identifier + "\n" for identifier in identifiers)


def _print_summary(families, train_families, val_families, test_families,
                   train_ids, val_ids, test_ids):
    """Print split statistics and a few example families/entries to stdout."""
    total_families = len(families)

    def pct(count):
        # Avoid ZeroDivisionError when the input file contained no identifiers.
        return count / total_families * 100 if total_families else 0.0

    print(f"Total families: {total_families}")
    print(f"Total identifiers in input: {sum(len(ids) for ids in families.values())}")
    print()
    print(f"Train: {len(train_families)} families ({pct(len(train_families)):.1f}%), {len(train_ids)} identifiers")
    print(f"Val: {len(val_families)} families ({pct(len(val_families)):.1f}%), {len(val_ids)} identifiers (1 per family)")
    print(f"Test: {len(test_families)} families ({pct(len(test_families)):.1f}%), {len(test_ids)} identifiers (1 per family)")
    print()
    print(f"Total identifiers used: {len(train_ids) + len(val_ids) + len(test_ids)}")
    print("Files created: train_list.txt, val_list.txt, test_list.txt")

    # Show some example families in each split.
    print("\nExample families:")
    print(f"Train: {train_families[:5]}...")
    print(f"Val: {val_families[:3]}...")
    print(f"Test: {test_families[:3]}...")

    # Show what actually goes into val/test (first ID per family).
    for label, split in (("val", val_families), ("test", test_families)):
        if split:
            print(f"\nExample {label} entries (first ID per family):")
            for family in split[:3]:
                print(f" Family {family}: {families[family][0]} (from {len(families[family])} available)")
if __name__ == "__main__":
    # Script entry point: split the default identifier list, resolved
    # relative to the current working directory.
    default_input = "full_list.txt"
    split_families(default_input)