import os
import random
import sys
from collections import defaultdict


def _load_families(input_file):
    """Group the non-blank lines of *input_file* by family.

    The family of an identifier is the text before its first underscore
    (the whole identifier when it contains none). Families are keyed in
    order of first appearance and identifiers keep their file order.
    """
    families = defaultdict(list)
    with open(input_file, 'r') as f:
        for line in f:
            identifier = line.strip()
            if identifier:
                families[identifier.split('_')[0]].append(identifier)
    # Return a plain dict so later lookups cannot silently create families.
    return dict(families)


def _write_id_list(path, identifiers):
    """Write *identifiers* to *path*, one per line (overwrites the file)."""
    with open(path, 'w') as f:
        f.writelines(identifier + '\n' for identifier in identifiers)


def split_families(input_file, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15,
                   seed=42, output_dir=None):
    """
    Split families into train/val/test and write one ID list per split.

    Identifiers are read one per line from ``input_file`` and grouped by
    family (the text before the first underscore). Whole families are
    shuffled and assigned to a split, so no family spans two splits.
    The train list keeps every identifier of its families; the val and test
    lists keep only the first identifier (in file order) of each family.

    The results are written to ``train_list.txt``, ``val_list.txt`` and
    ``test_list.txt`` and a summary is printed to stdout.

    Args:
        input_file: path to file containing identifiers, one per line
            (blank lines are ignored)
        train_ratio: proportion of families for training (0.7 = 70%)
        val_ratio: proportion of families for validation (0.15 = 15%)
        test_ratio: nominal proportion of families for testing. It is not
            used to size the split: the test split always receives every
            family not assigned to train or val. A warning is printed if
            it disagrees with ``1 - train_ratio - val_ratio``.
        seed: random seed for reproducibility. Note that this reseeds the
            module-level ``random`` generator.
        output_dir: directory in which to write the three list files. It is
            created if missing. ``None`` (default) or an empty string writes
            to the current working directory.

    Raises:
        ValueError: if ``train_ratio`` or ``val_ratio`` is outside [0, 1],
            or if together they exceed 1.
    """
    # Validate before touching any file so bad arguments leave no output.
    if not (0 <= train_ratio <= 1 and 0 <= val_ratio <= 1):
        raise ValueError(
            f"train_ratio and val_ratio must be within [0, 1], "
            f"got {train_ratio} and {val_ratio}"
        )
    if train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            f"train_ratio + val_ratio must not exceed 1, "
            f"got {train_ratio + val_ratio}"
        )
    remainder = 1 - train_ratio - val_ratio
    if abs(test_ratio - remainder) > 1e-9:
        print(f"Warning: test_ratio={test_ratio} is ignored; the test split "
              f"receives all families not assigned to train or val "
              f"(about {remainder:.1%}).")

    random.seed(seed)

    families = _load_families(input_file)

    family_list = list(families)
    random.shuffle(family_list)

    # Sizes are truncated; any rounding leftovers end up in the test split.
    total_families = len(family_list)
    train_size = int(total_families * train_ratio)
    val_size = int(total_families * val_ratio)

    train_families = family_list[:train_size]
    val_families = family_list[train_size:train_size + val_size]
    test_families = family_list[train_size + val_size:]

    # Train keeps every member; val/test keep one representative per family.
    train_ids = [
        identifier
        for family in train_families
        for identifier in families[family]
    ]
    val_ids = [families[family][0] for family in val_families]
    test_ids = [families[family][0] for family in test_families]

    file_names = ('train_list.txt', 'val_list.txt', 'test_list.txt')
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        paths = [os.path.join(output_dir, name) for name in file_names]
    else:
        paths = list(file_names)
    for path, ids in zip(paths, (train_ids, val_ids, test_ids)):
        _write_id_list(path, ids)

    def _percent(count):
        # An empty input has no families; report 0% instead of dividing by 0.
        return count / total_families * 100 if total_families else 0.0

    print(f"Total families: {total_families}")
    print(f"Total identifiers in input: {sum(len(ids) for ids in families.values())}")
    print()
    print(f"Train: {len(train_families)} families ({_percent(len(train_families)):.1f}%), {len(train_ids)} identifiers")
    print(f"Val: {len(val_families)} families ({_percent(len(val_families)):.1f}%), {len(val_ids)} identifiers (1 per family)")
    print(f"Test: {len(test_families)} families ({_percent(len(test_families)):.1f}%), {len(test_ids)} identifiers (1 per family)")
    print()
    print(f"Total identifiers used: {len(train_ids) + len(val_ids) + len(test_ids)}")
    print(f"Files created: {', '.join(paths)}")

    print("\nExample families:")
    print(f"Train: {train_families[:5]}...")
    print(f"Val: {val_families[:3]}...")
    print(f"Test: {test_families[:3]}...")

    # Show which representative was kept for a few val/test families.
    for label, split in (("val", val_families), ("test", test_families)):
        if split:
            print(f"\nExample {label} entries (first ID per family):")
            for family in split[:3]:
                print(f" Family {family}: {families[family][0]} (from {len(families[family])} available)")


if __name__ == "__main__":
    # Script entry point. The first command-line argument, when given, is
    # the input file; otherwise fall back to the default 'full_list.txt'.
    split_families(sys.argv[1] if len(sys.argv) > 1 else 'full_list.txt')