import os
import json

from torch.utils.data import random_split
dataset_path = os.path.join(os.sep, 'path', 'to', 'dataset')

# Collect every leaf directory of the dataset tree as one entry.
datasets = []
for dataset_root, dataset_dirs, dataset_files in os.walk(dataset_path, topdown=True, followlinks=True):
    if len(dataset_dirs):
        # Only leaf directories hold images; skip intermediate levels.
        continue
    # Dataset id: the path component(s) below the Real/ or Fake/ directory.
    dataset_id = dataset_root.split('Real' + os.sep)[-1].split('Fake' + os.sep)[-1]
    # Sharing platform: the first component below the dataset root.
    shared = dataset_root.replace(dataset_path + os.sep, '').split(os.sep)[0]
    # Strip extensions so pre- and post-sharing copies of a file compare equal.
    files = sorted(os.path.splitext(file)[0] for file in dataset_files)
    if 'FORLAB' in dataset_id or 'FFHQ' in dataset_id:
        # Cap the FORLAB and FFHQ sources at 40,000 files each.
        files = files[:40000]
    datasets.append({'id': dataset_id, 'shared': shared, 'root': dataset_root, 'files': files})
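# Assumed layout (inferred from the parsing above, not stated in the source):
#   <dataset_path>/<shared>/<Real|Fake>/<id>/<name>.{jpg,png}
# e.g. <dataset_path>/Telegram/Fake/<generator>/<name>.jpg, so 'shared' names
# the sharing platform and 'id' names the image source.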
train_set = []
val_set = []
test_set = []
# Split each Telegram-shared dataset, plus the pre-sharing files left over
# once their Telegram counterparts are accounted for.
for dataset in [d for d in datasets if d['shared'] == 'Telegram']:
    print(dataset['id'])
    files_pre = next(d for d in datasets
                     if d['shared'] == 'PreSocial' and d['id'] == dataset['id'])['files']
    files_post = dataset['files']
    train_set_post, val_set_post, test_set_post = random_split(files_post, [0.7, 0.15, 0.15])
    # Pre-sharing files without a Telegram counterpart are split on their own, so
    # the two versions of a shared file can never land in different partitions.
    files_post_set = set(files_post)
    residual_pre = [file for file in files_pre if file not in files_post_set]
    train_set_pre, val_set_pre, test_set_pre = random_split(residual_pre, [0.7, 0.15, 0.15])
    train_set += [os.path.join(dataset['id'], file) for file in train_set_post]
    train_set += [os.path.join(dataset['id'], file) for file in train_set_pre]
    val_set += [os.path.join(dataset['id'], file) for file in val_set_post]
    val_set += [os.path.join(dataset['id'], file) for file in val_set_pre]
    test_set += [os.path.join(dataset['id'], file) for file in test_set_post]
    test_set += [os.path.join(dataset['id'], file) for file in test_set_pre]
    # Per-dataset sanity prints: post-sharing, pre-sharing, and combined counts.
    print(len(train_set_post), len(val_set_post), len(test_set_post), ':',
          len(train_set_post) + len(val_set_post) + len(test_set_post))
    print(len(train_set_pre), len(val_set_pre), len(test_set_pre), ':',
          len(train_set_pre) + len(val_set_pre) + len(test_set_pre))
    print(len(train_set_pre) + len(train_set_post),
          len(val_set_pre) + len(val_set_post),
          len(test_set_pre) + len(test_set_post), ':',
          len(files_post) + len(residual_pre))

print(len(train_set), len(val_set), len(test_set), ':',
      len(train_set) + len(val_set) + len(test_set))
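# Note: random_split draws from torch's global RNG, so each run produces a
# different partition; writing the result to split.json below is what keeps
# experiments comparable. A seeded generator would make the split itself
# reproducible (the seed value here is illustrative, not from the source):
#   generator = torch.Generator().manual_seed(0)
#   random_split(files_post, [0.7, 0.15, 0.15], generator=generator)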
# Persist the split so later runs train and evaluate on identical partitions.
with open('split.json', 'w') as f:
    json.dump({'train': sorted(train_set), 'val': sorted(val_set), 'test': sorted(test_set)}, f)
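
For reference, a minimal sketch of how split.json could be consumed downstream. The SplitImageDataset class, the '.png' extension, and the mapping from the stored relative entries back to full paths are assumptions about the layout, not part of the script above:

import json
import os

from PIL import Image
from torch.utils.data import Dataset

class SplitImageDataset(Dataset):
    # Hypothetical reader for split.json; adapt the path handling to the
    # actual directory layout.
    def __init__(self, root, split_file='split.json', partition='train'):
        with open(split_file) as f:
            self.files = json.load(f)[partition]  # entries look like '<id>/<name>'
        self.root = root

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        # Extensions were stripped when the split was built; '.png' is assumed.
        path = os.path.join(self.root, self.files[index] + '.png')
        return Image.open(path).convert('RGB')

# Usage (paths hypothetical):
# train_images = SplitImageDataset('/path/to/dataset/Telegram', partition='train')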