Transformers
jespark commited on
Commit
37fcdfe
·
verified ·
1 Parent(s): 21900f4

Upload dataloader.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dataloader.py +572 -0
dataloader.py ADDED
@@ -0,0 +1,572 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### Dataloader for fake/real image classification
2
+
3
+ import torch
4
+ import pandas as pd
5
+ import numpy as np
6
+ import os
7
+ import PIL.Image
8
+ import random
9
+ import custom_transforms as ctrans
10
+ import math
11
+ import utils as ut
12
+
13
+ from torchvision import transforms
14
+ #from torchvision.transforms import v2 as transforms
15
+ from torch.utils.data.distributed import DistributedSampler
16
+ from custom_sampler import DistributedEvalSampler
17
+ from functools import partial
18
+ import datasets as ds
19
+ import io
20
+ import logging
21
+
22
+ class dataset_huggingface(torch.utils.data.Dataset):
23
+ """
24
+ Dataset for Community Forensics
25
+ """
26
+ def __init__(
27
+ self,
28
+ args,
29
+ repo_id='OwensLab/CommunityForensics',
30
+ split='Systematic+Manual',
31
+ mode='train',
32
+ cache_dir='',
33
+ dtype=torch.float32,
34
+ ):
35
+ """
36
+ args: Namespace of argument parser
37
+ split: split of the dataset to use
38
+ mode: 'train' or 'eval'
39
+ cache_dir: directory to cache the dataset
40
+ dtype: data type
41
+ """
42
+ super(dataset_huggingface).__init__()
43
+ self.args = args
44
+ self.repo_id = repo_id
45
+ self.split = split
46
+ self.mode = mode
47
+ self.cache_dir = cache_dir
48
+ self.dtype = dtype
49
+ self.dataset = self.get_hf_dataset()
50
+
51
+ def __getitem__(self, index):
52
+ """
53
+ Returns the image and label for the given index.
54
+ """
55
+ data = self.dataset[index]
56
+ image_bytes = data['image_data']
57
+ label = int(data['label'])
58
+ generator_name = data['model_name']
59
+
60
+ img = PIL.Image.open(io.BytesIO(image_bytes)).convert("RGB")
61
+
62
+ return img, label, generator_name
63
+
64
+ def get_hf_dataset(self):
65
+ """
66
+ Returns the huggingface dataset object
67
+ """
68
+ hf_repo_id = self.repo_id
69
+ if self.mode == 'train':
70
+ shuffle=True
71
+ shuffle_batch_size=3000
72
+ elif self.mode == 'eval':
73
+ shuffle=False
74
+
75
+ #### TEST TOKEN PART ####
76
+ #### TEST TOKEN PART ####
77
+ #### TEST TOKEN PART ####
78
+ token_df = pd.read_csv("/nfs/turbo/coe-ahowens/jespark/tokens.csv")
79
+ HF_TOKEN = token_df.loc[token_df['label'] == 'huggingface_write_token', 'token'].values[0]
80
+ #### TEST TOKEN PART ####
81
+ #### TEST TOKEN PART ####
82
+ #### TEST TOKEN PART ####
83
+
84
+ hf_dataset = ds.load_dataset(hf_repo_id, split=self.split, cache_dir=self.cache_dir, token=HF_TOKEN)
85
+ if shuffle:
86
+ hf_dataset = hf_dataset.shuffle(seed=self.args.seed, writer_batch_size=shuffle_batch_size)
87
+
88
+ return hf_dataset
89
+
90
+ def __len__(self):
91
+ """
92
+ Returns the length of the dataset.
93
+ """
94
+ return len(self.dataset)
95
+
96
+ class dataset_folder_based(torch.utils.data.Dataset):
97
+ """
98
+ Dataset for sourcing images from a directory; designed to be used with the huggingface datasets library.
99
+ """
100
+ def __init__(
101
+ self,
102
+ args,
103
+ dir,
104
+ labels="real:0,fake:1",
105
+ logger: logging.Logger = None,
106
+ dtype=torch.float32,
107
+ ):
108
+ """
109
+ args: Namespace of argument parser
110
+ dir: directory to index
111
+ labels: labels for the dataset. Default: "real:0,fake:1" -- assigns integer label 0 to images under "real" and 1 to images under "fake".
112
+ dtype: data type
113
+
114
+ The directory must be formatted as follows:
115
+ - <generator_or_dataset_name>
116
+ ∟ <label -- "real" or "fake">
117
+ ∟ <image_name>.{jpg,png,...}
118
+ `dir` should point to the parent directory of the `generator_or_dataset_name` folders.
119
+ """
120
+ super(dataset_folder_based).__init__()
121
+ self.args = args
122
+ self.dir = dir
123
+ self.labels = self.parse_labels(labels)
124
+ assert len(self.labels) == 2, f"Labels must be in the format 'label1:int,label2:int'. It only supports two labels. Instead, it is: {labels}."
125
+
126
+ self.logger = logger
127
+ if self.logger is None:
128
+ self.logger = ut.logger
129
+ self.dtype = dtype
130
+ self.df = self.get_index(dir)
131
+
132
+ def __getitem__(self, index):
133
+ """
134
+ Returns the image and label for the given index.
135
+ """
136
+ img_path = self.df.iloc[index]['ImagePath']
137
+ label = int(self.df.iloc[index]['Label'])
138
+ generator_name = self.df.iloc[index]['GeneratorName']
139
+
140
+ img = PIL.Image.open(img_path).convert("RGB")
141
+
142
+ return img, label, generator_name
143
+
144
+ def __len__(self):
145
+ """
146
+ Returns the length of the dataset.
147
+ """
148
+ return len(self.df)
149
+
150
+ def parse_labels(self, labels):
151
+ """
152
+ Parses the labels string and returns a dictionary of labels.
153
+ """
154
+ labels_dict = {}
155
+ for label in labels.split(','):
156
+ label_name, label_value = label.split(':')
157
+ labels_dict[label_name] = int(label_value)
158
+
159
+ return labels_dict
160
+
161
+ def get_label_int(self, label):
162
+ """
163
+ Returns the integer label for the given label name.
164
+ """
165
+ if label in self.labels:
166
+ return self.labels[label]
167
+ else:
168
+ raise ValueError(f"Label {label} not found in labels: {self.labels}. Please check the labels.")
169
+
170
+ def get_index(self, dir):
171
+ """
172
+ Check the `dir` for the index file. If it exists, load it. If not, index the directory and save the index file.
173
+ """
174
+ index_path = os.path.join(dir, 'index.csv')
175
+ if os.path.exists(index_path):
176
+ df = pd.read_csv(index_path)
177
+ if self.args.rank == 0:
178
+ self.logger.info(f"Loaded index file from {index_path}")
179
+ else:
180
+ if self.args.rank == 0:
181
+ self.logger.info(f"Index file not found. Indexing the directory {dir}. This may take a while...")
182
+ df = self.index_directory(dir)
183
+ return df
184
+
185
+ def index_directory(self, dir, report_every=1000):
186
+ """
187
+ Indexes the given directory and returns a dataframe with the image paths, labels, and generator names.
188
+ The directory must be formatted as follows:
189
+ - <generator_or_dataset_name>
190
+ ∟ <label -- "real" or "fake">
191
+ ∟ <image_name>.{jpg,png,...}
192
+ `dir` should point to the parent directory of the `generator_or_dataset_name` folders.
193
+ """
194
+ df = pd.DataFrame(columns=['ImagePath', 'Label', 'GeneratorName'])
195
+ temp_dfs=[]
196
+ for root, dirs, files in os.walk(dir):
197
+ for file in files:
198
+ if file.endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp', '.gif')):
199
+ # get the generator name and label from the directory structure
200
+ generator_name=os.path.basename(os.path.dirname(root))
201
+ label=os.path.basename(root) # should be "real" or "fake"
202
+ label_int=self.get_label_int(label)
203
+ # get the image path
204
+ image_path=os.path.join(root, file)
205
+ # append the image path, label, and generator name to the list
206
+ temp_dfs.append(pd.DataFrame([[image_path, label_int, generator_name]], columns=['ImagePath', 'Label', 'GeneratorName']))
207
+ if len(temp_dfs) % report_every == 0 and self.args.rank == 0:
208
+ print(f"\rIndexed {len(temp_dfs)} images... ", end='', flush=True)
209
+ df = pd.concat(temp_dfs, ignore_index=True)
210
+ print("") # print a new line after the progress bar
211
+ # sort the dataframe by generator name, label, and image name
212
+ df = df.sort_values(by=['GeneratorName', 'Label', 'ImagePath'])
213
+ df = df.reset_index(drop=True)
214
+ # save the dataframe
215
+ df.to_csv(os.path.join(dir, 'index.csv'), index=False)
216
+
217
+ self.logger.info(f"Indexed the directory {dir} and saved the index file to {os.path.join(dir, 'index.csv')}")
218
+ return df
219
+
220
+ def limit_real_data(self, df, num_max_images):
221
+ """
222
+ Limits the real data to contain `num_max_images` total images by preserving the smallest datasets first.
223
+ """
224
+ new_df=pd.DataFrame()
225
+ # get the number of images per dataset name
226
+ real_df = df[df['Label'] == 0]
227
+ fake_df = df[df['Label'] == 1]
228
+
229
+ if len(real_df) <= num_max_images:
230
+ self.logger.info(f"The size of real data: {len(real_df)} is less than or equal to the target size: {num_max_images}. No need to limit the real data. Note that the original model is trained with near 50/50 real/fake to avoid bias -- too much deviation from this may lead to unwanted detection bias.")
231
+ return df
232
+
233
+ dataset_counts = real_df['GeneratorName'].value_counts()
234
+ # sort the dataset counts in descending order
235
+ dataset_counts = dataset_counts.sort_values(ascending=True)
236
+ smallest_sum=0
237
+ smallest_idx=0
238
+ num_not_appended_datasets=len(dataset_counts)
239
+
240
+ while True:
241
+ perModelLen = dataset_counts.iloc[smallest_idx]
242
+ if (perModelLen * num_not_appended_datasets + smallest_sum) >= num_max_images: # reached data target
243
+ perModelLen = math.ceil((num_max_images - smallest_sum) / num_not_appended_datasets) # number of images to sample from datasets not yet fully appended
244
+ break
245
+ elif smallest_idx == len(dataset_counts)-1:
246
+ break # ran out of datasets; this is when size of all real data is less than num_max_images
247
+ else: # continuously grow perModelLen with the next smallest dataset
248
+ smallest_sum += dataset_counts.iloc[smallest_idx] # fully append the smallest dataset
249
+ smallest_idx+=1
250
+ num_not_appended_datasets-=1
251
+
252
+ # sample the datasets
253
+ for dataset_name in dataset_counts.index[smallest_idx:]:
254
+ dataset_df = real_df[real_df['GeneratorName'] == dataset_name]
255
+ if len(dataset_df) > perModelLen:
256
+ dataset_df = dataset_df.sample(n=perModelLen, random_state=self.args.seed)
257
+ new_df = pd.concat([new_df, dataset_df], ignore_index=True)
258
+
259
+ # append the remaining datasets
260
+ for dataset_name in dataset_counts.index[:smallest_idx]:
261
+ dataset_df = real_df[real_df['GeneratorName'] == dataset_name]
262
+ new_df = pd.concat([new_df, dataset_df], ignore_index=True)
263
+
264
+ # report the proportions per dataset
265
+ if self.args.rank == 0:
266
+ pd.options.display.float_format = '{:.2f} %'.format
267
+ self.logger.info(f"Max images per dataset limited to {perModelLen}. Affected datasets: {dataset_counts.index[smallest_idx:]}")
268
+ # Update the dataset counts for reporting proportions
269
+ dataset_counts = new_df['GeneratorName'].value_counts()
270
+ dataset_counts = dataset_counts / dataset_counts.sum() * 100 # composition percentage
271
+ self.logger.info(f"Dataset composition: \n{dataset_counts}")
272
+
273
+ # append the fake data
274
+ new_df = pd.concat([new_df, fake_df], ignore_index=True)
275
+
276
+ return new_df
277
+
278
+ def determine_resize_crop_sizes(args):
279
+ """
280
+ Determine resize and crop sizes based on input size.
281
+ """
282
+ if args.input_size==224:
283
+ resize_size=256
284
+ crop_size=224
285
+ elif args.input_size==384:
286
+ resize_size=440
287
+ crop_size=384
288
+ return resize_size, crop_size
289
+
290
+ def get_transform(args, mode="train", dtype=torch.float32):
291
+ norm_mean = [0.485, 0.456, 0.406] #imagenet norm
292
+ norm_std = [0.229, 0.224, 0.225]
293
+ resize_size, crop_size = determine_resize_crop_sizes(args)
294
+ augment_list = []
295
+
296
+ if mode=="train":
297
+ augment_list.append(transforms.Resize(resize_size))
298
+
299
+ # RandomStateAugmentation
300
+ if args.rsa_ops != '':
301
+ # parse rsa_ops and their num_ops
302
+ # Default "rsa_ops" is "JPEGinMemory,RandomResizeWithRandomIntpl,RandomCrop,RandomHorizontalFlip,RandomVerticalFlip,RRCWithRandomIntpl,RandomRotation,RandomTranslate,RandomShear,RandomPadding,RandomCutout"
303
+ augment_list.append(ctrans.RandomStateAugmentation(resize_size=resize_size, crop_size=crop_size, auglist=args.rsa_ops, min_augs=args.rsa_min_num_ops, max_augs=args.rsa_max_num_ops))
304
+
305
+ augment_list.append(transforms.RandomCrop(crop_size))
306
+
307
+ # basic augmentations
308
+ augment_list.extend([
309
+ ctrans.ToTensor_range(val_min=0, val_max=1),
310
+ transforms.Normalize(mean=norm_mean, std=norm_std),
311
+ transforms.ConvertImageDtype(dtype)
312
+ ])
313
+ elif mode=="val" or mode=="test":
314
+ augment_list.append(transforms.Resize(resize_size))
315
+ augment_list.extend([
316
+ transforms.CenterCrop(crop_size),
317
+ ctrans.ToTensor_range(val_min=0, val_max=1),
318
+ transforms.Normalize(mean=norm_mean, std=norm_std),
319
+ transforms.ConvertImageDtype(dtype),
320
+ ])
321
+ transform = transforms.Compose(augment_list)
322
+ return transform
323
+
324
+ class SubsetWithTransform(torch.utils.data.Dataset):
325
+ """
326
+ Custom subset class which allows to customize transform for each subsets got from random_split()
327
+ """
328
+ def __init__(self, subset, transform=None):
329
+ self.subset = subset
330
+ self.subset_len = len(subset)
331
+ self.transform = transform
332
+
333
+ def __getitem__(self, index):
334
+ img, lab, generator_name = self.subset[index]
335
+ if self.transform:
336
+ img = self.transform(img)
337
+ return img, lab, generator_name
338
+
339
+ def __len__(self):
340
+ return self.subset_len
341
+
342
+ def set_seeds_for_data(seed=11997733):
343
+ """
344
+ Set seeds for Python, numpy, and pytorch. Used to split the dataset consistantly across DDP instances.
345
+ """
346
+ torch.manual_seed(seed)
347
+ random.seed(seed)
348
+ np.random.seed(seed)
349
+
350
+ def set_seeds_for_worker(seed=11997733, id=0):
351
+ """
352
+ Set seeds for python, and numpy. Default seed=11997733.
353
+ PyTorch seeding is handled by torch.Generator passed into the DataLoader
354
+ """
355
+ seed = seed % (2**31)
356
+ random.seed(seed+id)
357
+ np.random.seed(seed+id)
358
+
359
+ def worker_seed_reporter(id=None):
360
+ """
361
+ Debug: reports worker seeds
362
+ """
363
+ workerseed = torch.utils.data.get_worker_info().seed
364
+ numwkr = torch.utils.data.get_worker_info().num_workers
365
+ baseseed = torch.initial_seed()
366
+ print(f"Worker id: {id+1}/{numwkr}, worker seed: {workerseed}, baseseed: {baseseed}, workerseed % (2**31): {workerseed % (2**31)}")
367
+
368
+ def set_seeds_and_report(report=True, id=0):
369
+ """
370
+ Debug: set seeds and report
371
+ """
372
+ workerseed = torch.utils.data.get_worker_info().seed
373
+ set_seeds_for_worker(workerseed, id)
374
+ if report:
375
+ worker_seed_reporter(id)
376
+
377
+ def get_seedftn_and_generator(args, seed=None):
378
+ """
379
+ Get the seed function and generator for the dataloader.
380
+ Args:
381
+ args: Namespace of argument parser
382
+ seed: seed for random number generation
383
+ """
384
+ rank = args.rank
385
+ if seed is not None:
386
+ seedftn = partial(set_seeds_and_report, False)
387
+ seed_generator = torch.Generator(device='cpu')
388
+ seed_generator.manual_seed(seed+rank)
389
+ else:
390
+ seedftn = None
391
+ seed_generator = None
392
+ seed = random.randint(0, 1000000000)
393
+
394
+ return seedftn, seed_generator, seed
395
+
396
+ def get_train_dataloaders(
397
+ args,
398
+ huggingface_repo_id='',
399
+ huggingface_split='Systematic+Manual',
400
+ additional_data_path='',
401
+ additional_data_label_format='real:0,fake:1',
402
+ batch_size=128,
403
+ num_workers=4,
404
+ val_frac=0.01,
405
+ logger: logging.Logger = None,
406
+ seed=None,
407
+ ):
408
+ """
409
+ Get train and validation dataloaders for the dataset.
410
+ Args:
411
+ args: Namespace of argument parser
412
+ huggingface_repo_id: huggingface repo id for the dataset
413
+ huggingface_split: split of the dataset to use
414
+ additional_data_path: path to the folder containing the dataset
415
+ batch_size: size of batch
416
+ num_workers: number of subprocesses to spawn
417
+ val_frac: fraction of data to use for validation (default: 0.01)
418
+ seed: seed for random number generation
419
+ """
420
+ rank = args.rank
421
+ world_size = args.world_size
422
+
423
+ seedftn, seed_generator, seed = get_seedftn_and_generator(args, seed)
424
+ if logger is None:
425
+ logger = ut.logger
426
+
427
+ hf_dataset=None
428
+ if huggingface_repo_id != '':
429
+ hf_dataset=dataset_huggingface(args, huggingface_repo_id, split=huggingface_split, mode='train', cache_dir=args.cache_dir, dtype=torch.float32)
430
+
431
+ folder_dataset=None
432
+ if additional_data_path != '':
433
+ folder_dataset=dataset_folder_based(args, additional_data_path, additional_data_label_format, logger=logger, dtype=torch.float32)
434
+ num_fake_images = len(folder_dataset.df[folder_dataset.df['Label'] == 1])
435
+ if hf_dataset is not None and not args.dont_limit_real_data_to_fake: # limit real data to the length of fake data
436
+ num_hf_fake_images = len(hf_dataset.dataset.filter(lambda x: x['label'] == 1, num_proc=num_workers))
437
+ num_hf_real_images = len(hf_dataset.dataset) - num_hf_fake_images
438
+ num_fake_images = num_fake_images + num_hf_fake_images
439
+ #num_real_images = num_hf_real_images + len(folder_dataset.df[folder_dataset.df['Label'] == 0])
440
+
441
+ folder_based_real_limit = num_fake_images - num_hf_real_images
442
+ if folder_based_real_limit < 0:
443
+ folder_based_real_limit = 0
444
+ else:
445
+ if rank == 0:
446
+ logger.info(f"Limiting folder-based real data to {folder_based_real_limit} images to match the number of fake images.")
447
+ folder_dataset.df = folder_dataset.limit_real_data(folder_dataset.df, folder_based_real_limit)
448
+
449
+ # merge two datasets
450
+ if hf_dataset is not None and folder_dataset is not None:
451
+ dataset_object = torch.utils.data.ConcatDataset([hf_dataset, folder_dataset])
452
+ elif hf_dataset is not None:
453
+ dataset_object = hf_dataset
454
+ elif folder_dataset is not None:
455
+ dataset_object = folder_dataset
456
+ else:
457
+ raise ValueError("No dataset provided. Please provide a huggingface repo id or a folder path.")
458
+
459
+ set_seeds_for_data(seed) # Set same seeds for dataset split
460
+
461
+ # Split the dataset into train and validation sets
462
+ train_frac = 1 - val_frac
463
+ if val_frac > 0:
464
+ traindata_split, valdata_split = torch.utils.data.random_split(dataset_object, (train_frac, val_frac))
465
+ else:
466
+ traindata_split = dataset_object
467
+ valdata_split = []
468
+
469
+ set_seeds_for_data(seed+rank) # after dataset is split, use different seeds for augmentations and shuffling.
470
+
471
+ # Get dataloaders
472
+ traindata_split = SubsetWithTransform(traindata_split, transform=get_transform(args, mode='train', dtype=torch.float32))
473
+ train_sampler = DistributedSampler(
474
+ traindata_split, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False,
475
+ )
476
+ trainloader = torch.utils.data.DataLoader(
477
+ traindata_split, batch_size=batch_size, pin_memory=True,
478
+ shuffle=False, num_workers=num_workers, sampler=train_sampler,
479
+ worker_init_fn=seedftn, generator=seed_generator
480
+ )
481
+
482
+ if len(valdata_split) > 0:
483
+ valdata_split = SubsetWithTransform(valdata_split, transform=get_transform(args, mode='val', dtype=torch.float32))
484
+ val_sampler = DistributedEvalSampler(
485
+ valdata_split, num_replicas=world_size, rank=rank, shuffle=False,
486
+ )
487
+ valloader = torch.utils.data.DataLoader(
488
+ valdata_split, batch_size=batch_size, pin_memory=True,
489
+ shuffle=False, num_workers=num_workers, sampler=val_sampler,
490
+ worker_init_fn=seedftn, generator=seed_generator
491
+ )
492
+ else:
493
+ valloader = None
494
+
495
+ if rank == 0:
496
+ if huggingface_repo_id != '':
497
+ logger.info(f"Loaded huggingface dataset from {huggingface_repo_id}. Split: {huggingface_split}.")
498
+ if additional_data_path != '':
499
+ logger.info(f"Loaded folder dataset from {additional_data_path}.")
500
+ logger.info(f"Train/Val split: num_total: {len(dataset_object)}, num_train: {len(traindata_split)}, num_val: {len(valdata_split)} ")
501
+
502
+ return trainloader, valloader
503
+
504
+ def get_test_dataloader(
505
+ args,
506
+ huggingface_repo_id='',
507
+ huggingface_split='PublicEval',
508
+ additional_data_path='',
509
+ additional_data_label_format='real:0,fake:1',
510
+ batch_size=128,
511
+ num_workers=4,
512
+ logger: logging.Logger = None,
513
+ seed=None,
514
+ ):
515
+ """
516
+ Get test dataloader for the dataset.
517
+ Args:
518
+ args: Namespace of argument parser
519
+ huggingface_repo_id: huggingface repo id for the dataset
520
+ huggingface_split: split of the dataset to use
521
+ additional_data_path: path to the folder containing the dataset
522
+ batch_size: size of batch
523
+ num_workers: number of subprocesses to spawn
524
+ seed: seed for random number generation
525
+ """
526
+ rank = args.rank
527
+ world_size = args.world_size
528
+
529
+ if logger is None:
530
+ logger = ut.logger
531
+
532
+ seedftn, seed_generator, seed = get_seedftn_and_generator(args, seed)
533
+
534
+ hf_dataset=None
535
+ if huggingface_repo_id != '':
536
+ hf_dataset=dataset_huggingface(args, huggingface_repo_id, split=huggingface_split, mode='eval', cache_dir=args.cache_dir, dtype=torch.float32)
537
+
538
+ folder_dataset=None
539
+ if additional_data_path != '':
540
+ folder_dataset=dataset_folder_based(args, additional_data_path, additional_data_label_format, logger=logger, dtype=torch.float32)
541
+
542
+ # merge two datasets
543
+ if hf_dataset is not None and folder_dataset is not None:
544
+ dataset_object = torch.utils.data.ConcatDataset([hf_dataset, folder_dataset])
545
+ elif hf_dataset is not None:
546
+ dataset_object = hf_dataset
547
+ elif folder_dataset is not None:
548
+ dataset_object = folder_dataset
549
+ else:
550
+ raise ValueError("No dataset provided. Please provide a huggingface repo id or a folder path.")
551
+
552
+ set_seeds_for_data(seed+rank)
553
+
554
+ # Create dataset subset with eval transform
555
+ dataset_object = SubsetWithTransform(dataset_object, transform=get_transform(args, mode='val', dtype=torch.float32))
556
+
557
+ # Get dataloaders
558
+ test_sampler = DistributedEvalSampler(
559
+ dataset_object, num_replicas=world_size, rank=rank, shuffle=True,
560
+ )
561
+ testloader = torch.utils.data.DataLoader(
562
+ dataset_object, batch_size=batch_size, pin_memory=True,
563
+ shuffle=False, num_workers=num_workers, sampler=test_sampler,
564
+ worker_init_fn=seedftn, generator=seed_generator
565
+ )
566
+ if rank == 0:
567
+ if huggingface_repo_id != '':
568
+ logger.info(f"Loaded huggingface dataset from {huggingface_repo_id}. Split: {huggingface_split}.")
569
+ if additional_data_path != '':
570
+ logger.info(f"Loaded folder dataset from {additional_data_path}.")
571
+ logger.info(f"Test set size: {len(dataset_object)} ")
572
+ return testloader