# Petimot / split_script.py
# Vlmbd
# add split script
# 4e57c0d unverified
import os
import random
import warnings
from collections import defaultdict
def split_families(input_file, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15,
                   seed=42, output_dir='.'):
    """Split identifiers into train/val/test lists by family and write them out.

    Identifiers are grouped by family (the text before the first underscore).
    Families -- not identifiers -- are shuffled with a seeded RNG and cut into
    consecutive train/val/test slices, so no family appears in two splits.
    Train keeps every identifier of its families; val and test keep only the
    first identifier (in file order) of each of their families.

    Split sizes are ``int(n * train_ratio)`` and ``int(n * val_ratio)``
    families (truncated); the test split receives all remaining families.

    Args:
        input_file: path to a text file with one identifier per line; blank
            lines are ignored.
        train_ratio: fraction of families assigned to train (default 0.7).
        val_ratio: fraction of families assigned to validation (default 0.15).
        test_ratio: nominal test fraction. It is not used to size the split
            (test is always the remainder); a warning is emitted when the
            three ratios do not sum to 1.
        seed: seed for the family shuffle. A private ``random.Random`` is
            used, so the module-level random state is left untouched.
        output_dir: directory receiving ``train_list.txt``, ``val_list.txt``
            and ``test_list.txt`` (default: current directory).

    Returns:
        Tuple ``(train_ids, val_ids, test_ids)`` of the identifiers written.

    Raises:
        ValueError: if ``train_ratio`` or ``val_ratio`` is negative, or their
            sum exceeds 1 (the val slice would otherwise be silently cut).
    """
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            f"invalid ratios: train_ratio={train_ratio}, val_ratio={val_ratio} "
            "(both must be >= 0 and sum to at most 1)"
        )
    ratio_sum = train_ratio + val_ratio + test_ratio
    if abs(ratio_sum - 1.0) > 1e-9:
        warnings.warn(
            f"train_ratio + val_ratio + test_ratio = {ratio_sum} != 1; the test "
            "split always receives the remaining families, test_ratio is ignored",
            stacklevel=2,
        )

    families = _read_families(input_file)

    # Dict order is file order, so the list is deterministic before shuffling.
    # random.Random(seed).shuffle yields the same permutation as
    # random.seed(seed); random.shuffle(...), without reseeding the caller's RNG.
    family_list = list(families)
    random.Random(seed).shuffle(family_list)

    total_families = len(family_list)
    train_size = int(total_families * train_ratio)
    val_size = int(total_families * val_ratio)
    train_families = family_list[:train_size]
    val_families = family_list[train_size:train_size + val_size]
    test_families = family_list[train_size + val_size:]  # remainder

    train_ids = [identifier for family in train_families for identifier in families[family]]
    val_ids = [families[family][0] for family in val_families]    # first ID only
    test_ids = [families[family][0] for family in test_families]  # first ID only

    _write_list(os.path.join(output_dir, 'train_list.txt'), train_ids)
    _write_list(os.path.join(output_dir, 'val_list.txt'), val_ids)
    _write_list(os.path.join(output_dir, 'test_list.txt'), test_ids)

    # Statistics (percentages are guarded against an empty input file).
    print(f"Total families: {total_families}")
    print(f"Total identifiers in input: {sum(len(ids) for ids in families.values())}")
    print()
    print(f"Train: {len(train_families)} families ({_pct(len(train_families), total_families):.1f}%), {len(train_ids)} identifiers")
    print(f"Val: {len(val_families)} families ({_pct(len(val_families), total_families):.1f}%), {len(val_ids)} identifiers (1 per family)")
    print(f"Test: {len(test_families)} families ({_pct(len(test_families), total_families):.1f}%), {len(test_ids)} identifiers (1 per family)")
    print()
    print(f"Total identifiers used: {len(train_ids) + len(val_ids) + len(test_ids)}")
    print("Files created: train_list.txt, val_list.txt, test_list.txt"
          + ("" if output_dir == '.' else f" (in {output_dir})"))

    print("\nExample families:")
    print(f"Train: {train_families[:5]}...")
    print(f"Val: {val_families[:3]}...")
    print(f"Test: {test_families[:3]}...")
    _print_first_id_examples("val", val_families, families)
    _print_first_id_examples("test", test_families, families)

    return train_ids, val_ids, test_ids


def _read_families(input_file):
    """Group the identifiers in *input_file* by family.

    Each non-blank line, stripped of surrounding whitespace, is one identifier.
    Its family is the text before the first underscore, or the whole
    identifier when it has none. Families and the identifiers within each
    family keep their order of first appearance in the file.

    Returns:
        dict mapping family -> non-empty list of identifiers.
    """
    families = defaultdict(list)
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            identifier = line.strip()
            if identifier:  # skip blank lines
                families[identifier.split('_', 1)[0]].append(identifier)
    # Plain dict so later lookups cannot silently create empty families.
    return dict(families)


def _write_list(path, identifiers):
    """Write *identifiers* to *path*, one newline-terminated identifier per line."""
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(f"{identifier}\n" for identifier in identifiers)


def _pct(part, total):
    """Return ``part`` as a percentage of ``total``; 0.0 when ``total`` is 0."""
    return part / total * 100 if total else 0.0


def _print_first_id_examples(label, split_fams, families, limit=3):
    """Print the first identifier kept for up to *limit* families of a split."""
    if not split_fams:
        return
    print(f"\nExample {label} entries (first ID per family):")
    for family in split_fams[:limit]:
        ids = families[family]
        print(f" Family {family}: {ids[0]} (from {len(ids)} available)")
if __name__ == "__main__":
    # Minimal command-line interface. Run with no arguments it does exactly
    # what the original script did: split 'full_list.txt' with the default
    # ratios and seed 42, writing the three lists to the current directory.
    import argparse

    parser = argparse.ArgumentParser(
        description="Split identifiers into train/val/test lists by family.")
    parser.add_argument('input_file', nargs='?', default='full_list.txt',
                        help="file with one identifier per line (default: full_list.txt)")
    parser.add_argument('--train-ratio', type=float, default=0.7,
                        help="fraction of families for training (default: 0.7)")
    parser.add_argument('--val-ratio', type=float, default=0.15,
                        help="fraction of families for validation (default: 0.15)")
    parser.add_argument('--seed', type=int, default=42,
                        help="shuffle seed (default: 42)")
    args = parser.parse_args()

    # The test split is whatever families remain, so pass that as test_ratio.
    split_families(args.input_file,
                   train_ratio=args.train_ratio,
                   val_ratio=args.val_ratio,
                   test_ratio=1.0 - args.train_ratio - args.val_ratio,
                   seed=args.seed)