FiberGate / scripts /resplit.py
AzizMiladi's picture
chore: git mv scripts, UI, dev tools, docs into folders
70c46cc
Raw
History Blame
1.53 kB
import json
import os
import random
from collections import defaultdict
random.seed(42)

# Load every annotated page record.
with open('data2/combined_annotations.json', encoding='utf-8') as fh:
    all_records = json.load(fh)

# Bucket page records by source PDF: the id is everything before the LAST
# "_p" in the image file name.
# NOTE(review): a name without "_p" becomes its own one-page group, which
# could leak pages of one document across splits — confirm every image_file
# carries the page suffix.
pdf_groups = defaultdict(list)
for record in all_records:
    source_pdf = record['image_file'].rsplit('_p', 1)[0]
    pdf_groups[source_pdf].append(record)

# Shuffle the distinct PDF ids with the fixed seed above.
# NOTE: before shuffling, ids are ordered by first appearance in the JSON
# file, so the split is reproducible only while that record order is stable.
pdfs = list(pdf_groups)
random.shuffle(pdfs)

# 70/15/15 split over PDF ids (not pages), so all pages of a PDF share a split.
n = len(pdfs)
train_end = int(n * 0.70)
val_end = int(n * 0.85)
train_pdfs = pdfs[:train_end]
val_pdfs = pdfs[train_end:val_end]
test_pdfs = pdfs[val_end:]
def flatten(pdf_list):
    """Return the page records of every PDF id in *pdf_list*, concatenated in order.

    Looks each id up in the module-level ``pdf_groups`` mapping and always
    returns a new list.
    """
    records = []
    for pdf_id in pdf_list:
        records.extend(pdf_groups[pdf_id])
    return records
train = flatten(train_pdfs)
val = flatten(val_pdfs)
test = flatten(test_pdfs)

# Verify the split BEFORE anything is written, so a contaminated or lossy
# split never reaches disk. With the current slicing both invariants always
# hold; they are cheap guards against regressions in the grouping/slicing.
train_pdfs_set = set(train_pdfs)
val_pdfs_set = set(val_pdfs)
test_pdfs_set = set(test_pdfs)
overlaps = {
    "train∩val": len(train_pdfs_set & val_pdfs_set),
    "train∩test": len(train_pdfs_set & test_pdfs_set),
    "val∩test": len(val_pdfs_set & test_pdfs_set),
}
split_total = len(train) + len(val) + len(test)

# Report diagnostics first so they are visible even if the checks fail.
print(f"Train: {len(train)} records | Val: {len(val)} | Test: {len(test)}")
for pair, count in overlaps.items():
    print(f"{pair} overlap: {count} PDFs (should be 0)")

if any(overlaps.values()) or split_total != len(all_records):
    raise RuntimeError(
        f"Invalid split: PDF overlaps={overlaps}, "
        f"{split_total} split records vs {len(all_records)} input records"
    )

# Create the output directory if needed; open() alone fails when it is missing.
os.makedirs('data_combined', exist_ok=True)
for split_name, records in (('train', train), ('val', val), ('test', test)):
    # `with` guarantees each file is flushed and closed deterministically,
    # independent of interpreter garbage-collection behavior.
    with open(f'data_combined/combined_{split_name}_v2.json', 'w', encoding='utf-8') as out:
        json.dump(records, out, ensure_ascii=False, indent=2)