# DaniilOr's picture
# Upload folder using huggingface_hub
# 5f0437a verified
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
#
# All rights reserved.
# This work should only be used for nonprofit purposes.
#
# By downloading and/or using any of these files, you implicitly agree to all the
# terms of the license, as specified in the document LICENSE.txt
# (included in this package) and online at
# http://www.grip.unina.it/download/LICENSE_OPEN.txt
"""
Created in September 2022
@author: fabrizio.guillaro
"""
from torch.utils.data import Dataset
import random
from dataset.dataset_FantasticReality import FantasticReality
from dataset.dataset_IMD2020 import IMD2020
from dataset.dataset_CASIA import CASIA
from dataset.dataset_TampCOCO import tampCOCO
from dataset.dataset_CompRAISE import compRAISE
class myDataset(Dataset):
    """Union of the forgery-detection datasets selected in *config*.

    Training uses class-balanced sampling: each sub-dataset contributes
    exactly ``self.smallest`` samples per epoch (call :meth:`shuffle`
    between epochs so a different balanced subset is drawn).  Validation
    exposes every sample of every selected sub-dataset, concatenated.
    """

    def __init__(self, config, crop_size, grid_crop, mode="train", max_dim=None, aug=None):
        """
        :param config:    experiment configuration; reads ``DATASET.TRAIN``,
                          ``DATASET.VALID`` and ``TRAIN.NUM_SAMPLES``
        :param crop_size: crop size forwarded to every sub-dataset
        :param grid_crop: grid-crop flag forwarded to every sub-dataset
        :param mode:      'train' or 'valid'
        :param max_dim:   maximum image dimension (forwarded in valid mode only)
        :param aug:       augmentation pipeline forwarded to the sub-datasets
        :raises KeyError:   if *mode* is neither 'train' nor 'valid'
        :raises ValueError: if no dataset is selected for *mode*
        """
        if mode == "train":
            selected = config.DATASET.TRAIN
            suffix = "train"
            extra = {}                    # no size cap while training
        elif mode == "valid":
            selected = config.DATASET.VALID
            suffix = "valid"
            extra = {"max_dim": max_dim}  # cap image size at validation
        else:
            raise KeyError("Invalid mode: " + mode)

        # The train/valid branches differ only in the list-file suffix and
        # in the extra ``max_dim`` keyword, so build the list once.
        self.dataset_list = []
        if 'FR' in selected:
            self.dataset_list.append(FantasticReality(crop_size, grid_crop, f"dataset/data/FR_{suffix}_list.txt", aug=aug, **extra))
            self.dataset_list.append(FantasticReality(crop_size, grid_crop, f"dataset/data/FR_auth_{suffix}_list.txt", is_auth_list=True, aug=aug, **extra))
        if 'IMD' in selected:
            self.dataset_list.append(IMD2020(crop_size, grid_crop, f"dataset/data/IMD_{suffix}_list.txt", aug=aug, **extra))
        if 'CA' in selected:
            self.dataset_list.append(CASIA(crop_size, grid_crop, f"dataset/data/CASIA_v2_{suffix}_list.txt", aug=aug, **extra))
            self.dataset_list.append(CASIA(crop_size, grid_crop, f"dataset/data/CASIA_v2_auth_{suffix}_list.txt", aug=aug, **extra))
        if 'COCO' in selected:
            for prefix in ('cm', 'sp', 'bcm', 'bcmc'):
                self.dataset_list.append(tampCOCO(crop_size, grid_crop, f"dataset/data/{prefix}_COCO_{suffix}_list.txt", aug=aug, **extra))
        if 'RAISE' in selected:
            # NOTE: the compRAISE list files carry no '_list' suffix
            self.dataset_list.append(compRAISE(crop_size, grid_crop, f"dataset/data/compRAISE_{suffix}.txt", aug=aug, **extra))
        if not self.dataset_list:
            # min() below would raise a cryptic ValueError anyway; fail clearly
            raise ValueError("No datasets selected for mode: " + mode)

        self.crop_size = crop_size
        self.grid_crop = grid_crop
        self.mode = mode

        # Size of the smallest sub-dataset, optionally capped by
        # TRAIN.NUM_SAMPLES; drives class-balanced sampling in train mode.
        self.smallest = min(len(ds) for ds in self.dataset_list)
        num_samples = config.TRAIN.NUM_SAMPLES
        if 0 < num_samples < self.smallest:
            self.smallest = num_samples

    def shuffle(self):
        """Shuffle every sub-dataset in place (rotates the balanced subset)."""
        for ds in self.dataset_list:
            random.shuffle(ds.img_list)

    def _locate(self, index):
        """Map a flat *index* onto ``(sub_dataset, local_index)``.

        :raises IndexError: if *index* exceeds the total number of samples
        """
        for ds in self.dataset_list:
            if index < len(ds):
                return ds, index
            index -= len(ds)
        raise IndexError("index out of range: " + str(index))

    def get_filename(self, index):
        """Return the image filename at flat position *index* (valid-mode indexing)."""
        ds, local = self._locate(index)
        return ds.get_img_name(local)

    def __len__(self):
        if self.mode == 'train':
            # class-balanced sampling: `smallest` samples per sub-dataset
            return self.smallest * len(self.dataset_list)
        return sum(len(ds) for ds in self.dataset_list)

    def __getitem__(self, index):
        if self.mode == 'train':
            # class-balanced sampling: consecutive chunks of `smallest`
            # indices map onto consecutive sub-datasets
            if index >= self.smallest * len(self.dataset_list):
                raise ValueError("Something wrong.")
            return self.dataset_list[index // self.smallest].get_img(index % self.smallest)
        ds, local = self._locate(index)
        return ds.get_img(local)

    def get_info(self):
        """Return a human-readable summary of the sub-dataset sizes."""
        s = ''
        for ds in self.dataset_list:
            s += f'{ds.__class__.__name__}: \t{len(ds)} images \n'
        s += f'Smallest: {self.smallest}'
        return s