# MDS_demonstrator/support/json_compile.py
# Author: AMontiB (commit 9c4b1c4)
import os
import glob
import pandas
import random
from torch.utils.data import DataLoader, random_split
import json
import bisect
# Root of the image dataset tree; placeholder path — replace with the real location.
dataset_path = os.path.join(os.sep, 'path', 'to', 'dataset')

# One entry per leaf directory found under dataset_path:
#   {'id': dataset identifier, 'shared': sharing platform,
#    'root': absolute directory, 'files': extension-stripped file names}
datasets = []
for dataset_root, dataset_dirs, dataset_files in os.walk(dataset_path, topdown=True, followlinks=True):
    # Only leaf directories (no subdirectories) contain images; skip interior nodes.
    if dataset_dirs:
        continue
    # Dataset id = path portion after the 'Real/' or 'Fake/' component.
    # os.sep instead of a hard-coded '/' keeps this portable to Windows.
    dataset_id = dataset_root.split('Real' + os.sep)[-1].split('Fake' + os.sep)[-1]
    # Sharing platform = first path component below dataset_path
    # (e.g. 'Telegram' or 'PreSocial' — presumably; verify against the tree layout).
    shared_via = dataset_root.replace(dataset_path + os.sep, '').split(os.sep)[0]
    # Strip image extensions.  NOTE: str.replace removes the substring anywhere
    # in the name, not only a trailing extension — kept from the original code.
    stems = sorted(name.replace('.jpg', '').replace('.png', '') for name in dataset_files)
    if 'FORLAB' in dataset_id or 'FFHQ' in dataset_id:
        # Cap the very large FORLAB / FFHQ collections at 40k images.
        stems = stems[:40000]
    datasets.append({'id': dataset_id, 'shared': shared_via, 'root': dataset_root, 'files': stems})
# Build 70/15/15 train/val/test splits for every Telegram-shared dataset, and
# add the residual pre-social files (those never seen post-sharing) of the same
# dataset id, split independently with the same proportions.
split = []
train_set = []
val_set = []
test_set = []
for dataset in [d for d in datasets if d['shared'] == 'Telegram']:
    print(dataset['id'])
    # Matching pre-social dataset with the same id — assumed to exist exactly
    # once; an IndexError here means the pairing is broken.
    files_pre = [d for d in datasets
                 if d['shared'] == 'PreSocial' and d['id'] == dataset['id']][0]['files']
    files_post = dataset['files']
    # random_split with fractional lengths partitions the list 70/15/15.
    train_set_post, val_set_post, test_set_post = random_split(files_post, [0.7, 0.15, 0.15])
    # Pre-social files that never appeared post-sharing.  A set makes the
    # membership test O(1) per file instead of scanning the whole list.
    post_lookup = set(files_post)
    residual_pre = [name for name in files_pre if name not in post_lookup]
    train_set_pre, val_set_pre, test_set_pre = random_split(residual_pre, [0.7, 0.15, 0.15])
    # Accumulate '<id>/<file>' entries from both the post- and pre-social splits.
    train_set = train_set + [os.path.join(dataset['id'], name) for name in train_set_post] + [os.path.join(dataset['id'], name) for name in train_set_pre]
    val_set = val_set + [os.path.join(dataset['id'], name) for name in val_set_post] + [os.path.join(dataset['id'], name) for name in val_set_pre]
    test_set = test_set + [os.path.join(dataset['id'], name) for name in test_set_post] + [os.path.join(dataset['id'], name) for name in test_set_pre]
    # Per-dataset sanity counts: post-social, pre-social, and combined.
    print(len(train_set_post), len(val_set_post), len(test_set_post), ':', len(train_set_post)+len(val_set_post)+len(test_set_post))
    print(len(train_set_pre), len(val_set_pre), len(test_set_pre), ':', len(train_set_pre)+len(val_set_pre)+len(test_set_pre))
    print(len(train_set_pre)+len(train_set_post), len(val_set_pre)+len(val_set_post), len(test_set_pre)+len(test_set_post), ':', len(train_set_pre)+len(train_set_post)+len(val_set_pre)+len(val_set_post)+len(test_set_pre)+len(test_set_post))
print(len(train_set), len(val_set), len(test_set), ':', len(train_set)+len(val_set)+len(test_set))
# Persist the combined splits as a single JSON file; entries are sorted so the
# output is deterministic and diff-friendly across runs.
with open("split.json", "w") as f:
    json.dump({'train': sorted(train_set), 'val': sorted(val_set), 'test': sorted(test_set)}, f)