File size: 3,346 Bytes
9c4b1c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import bisect
import glob
import json
import os
import random

import pandas
import torch
from torch.utils.data import DataLoader, random_split

dataset_path = os.path.join(os.sep, 'path', 'to', 'dataset')

# Leaf datasets whose id contains one of these markers keep only the first
# MAX_FILES_PER_CAPPED_DATASET file names (in sorted order).
CAPPED_DATASET_MARKERS = ('FORLAB', 'FFHQ')
MAX_FILES_PER_CAPPED_DATASET = 40000
# Extensions removed from file names, only when they are the trailing suffix.
IMAGE_EXTENSIONS = ('.jpg', '.png')


def strip_image_ext(filename):
    """Return *filename* with one trailing '.jpg' or '.png' removed.

    Unlike ``str.replace``, an extension occurring elsewhere in the name is
    left alone, e.g. ``'a.jpg.txt'`` is returned unchanged.
    """
    for ext in IMAGE_EXTENSIONS:
        if filename.endswith(ext):
            return filename[:-len(ext)]
    return filename


def collect_datasets(root):
    """Describe every leaf directory (one without subdirectories) under *root*.

    Symbolic links are followed while walking. Returns a list of dicts with:

    * ``'id'``     -- the leaf path after its last ``'Real/'`` and then after
      its last ``'Fake/'`` (NOTE(review): '/'-separated, so POSIX paths are
      assumed);
    * ``'shared'`` -- the first path component of the leaf relative to *root*;
    * ``'root'``   -- the leaf directory path as yielded by ``os.walk``;
    * ``'files'``  -- the sorted file names with image extensions stripped,
      truncated to MAX_FILES_PER_CAPPED_DATASET for capped datasets.
    """
    collected = []
    for dataset_root, dataset_dirs, dataset_files in os.walk(root, topdown=True, followlinks=True):
        if dataset_dirs:
            continue  # only leaf directories are treated as datasets

        dataset_id = dataset_root.split('Real/')[-1].split('Fake/')[-1]
        # relpath is robust to a trailing separator on *root*, unlike the
        # previous string replace of ``root + os.sep``.
        shared = os.path.relpath(dataset_root, root).split(os.sep)[0]

        files = sorted(strip_image_ext(name) for name in dataset_files)
        if any(marker in dataset_id for marker in CAPPED_DATASET_MARKERS):
            files = files[:MAX_FILES_PER_CAPPED_DATASET]

        collected.append({'id': dataset_id, 'shared': shared, 'root': dataset_root, 'files': files})
    return collected


datasets = collect_datasets(dataset_path)


# Build train/val/test lists of '<dataset id>/<file>' keys. Every 'Telegram'
# leaf dataset is paired with the 'PreSocial' leaf of the same id: the Telegram
# files are split 70/15/15, and so are the PreSocial files not among them.

# Fixed seed so that re-running the script reproduces the same split.json.
SPLIT_SEED = 0
# Train / val / test fractions passed to ``random_split``.
SPLIT_FRACTIONS = [0.7, 0.15, 0.15]

train_set = []
val_set = []
test_set = []

# File lists of the 'PreSocial' datasets keyed by id; the first match in walk
# order wins, mirroring the previous ``[...][0]`` lookup.
pre_social_files = {}
for candidate in datasets:
    if candidate['shared'] == 'PreSocial':
        pre_social_files.setdefault(candidate['id'], candidate['files'])

# One seeded generator for all splits. Datasets are visited in id order so the
# random draws do not depend on the order os.walk happened to list them.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
telegram_datasets = sorted(
    (candidate for candidate in datasets if candidate['shared'] == 'Telegram'),
    key=lambda candidate: candidate['id'],
)

for dataset in telegram_datasets:
    dataset_id = dataset['id']
    print(dataset_id)

    files_pre = pre_social_files.get(dataset_id)
    if files_pre is None:
        raise LookupError(f"no 'PreSocial' dataset with id {dataset_id!r} to pair with 'Telegram'")
    files_post = dataset['files']

    train_set_post, val_set_post, test_set_post = random_split(
        files_post, SPLIT_FRACTIONS, generator=split_generator)

    # Both halves are keyed as '<id>/<file>', so a PreSocial file that also
    # appears in files_post would produce the same key as its post copy; only
    # the remaining PreSocial files are split. A set keeps lookups O(1).
    post_names = set(files_post)
    residual_pre = [name for name in files_pre if name not in post_names]
    train_set_pre, val_set_pre, test_set_pre = random_split(
        residual_pre, SPLIT_FRACTIONS, generator=split_generator)

    for target, post_part, pre_part in (
        (train_set, train_set_post, train_set_pre),
        (val_set, val_set_post, val_set_pre),
        (test_set, test_set_post, test_set_pre),
    ):
        # extend() instead of ``target = target + [...]`` avoids re-copying the
        # accumulated list for every dataset.
        target.extend(os.path.join(dataset_id, name) for name in post_part)
        target.extend(os.path.join(dataset_id, name) for name in pre_part)

    # Per-dataset sizes: post, pre, then combined, each followed by its total.
    post_sizes = (len(train_set_post), len(val_set_post), len(test_set_post))
    pre_sizes = (len(train_set_pre), len(val_set_pre), len(test_set_pre))
    combined_sizes = tuple(post + pre for post, pre in zip(post_sizes, pre_sizes))
    for sizes in (post_sizes, pre_sizes, combined_sizes):
        print(*sizes, ':', sum(sizes))

print(len(train_set), len(val_set), len(test_set), ':', len(train_set) + len(val_set) + len(test_set))

# Persist all three splits as one JSON object with 'train', 'val' and 'test'
# keys. Each list is sorted so the file content does not depend on the order
# in which entries were accumulated. (Separate train.json / val.json /
# test.json files are not written.)
with open("split.json", "w") as split_file:
    split_payload = {
        name: sorted(entries)
        for name, entries in (('train', train_set), ('val', val_set), ('test', test_set))
    }
    json.dump(split_payload, split_file)