File size: 6,149 Bytes
5f0437a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
#
# All rights reserved.
# This work should only be used for nonprofit purposes.
#
# By downloading and/or using any of these files, you implicitly agree to all the
# terms of the license, as specified in the document LICENSE.txt
# (included in this package) and online at
# http://www.grip.unina.it/download/LICENSE_OPEN.txt
"""
Created in September 2022
@author: fabrizio.guillaro
"""
from torch.utils.data import Dataset
import random
from dataset.dataset_FantasticReality import FantasticReality
from dataset.dataset_IMD2020 import IMD2020
from dataset.dataset_CASIA import CASIA
from dataset.dataset_TampCOCO import tampCOCO
from dataset.dataset_CompRAISE import compRAISE
class myDataset(Dataset):
def __init__(self, config, crop_size, grid_crop, mode="train", max_dim=None, aug=None):
self.dataset_list = []
training_set = config.DATASET.TRAIN
valid_set = config.DATASET.VALID
if mode == "train":
if 'FR' in training_set:
self.dataset_list.append(FantasticReality(crop_size, grid_crop, "dataset/data/FR_train_list.txt", aug=aug))
self.dataset_list.append(FantasticReality(crop_size, grid_crop, "dataset/data/FR_auth_train_list.txt", is_auth_list=True, aug=aug))
if 'IMD' in training_set:
self.dataset_list.append(IMD2020(crop_size, grid_crop, "dataset/data/IMD_train_list.txt", aug=aug))
if 'CA' in training_set:
self.dataset_list.append(CASIA(crop_size, grid_crop, "dataset/data/CASIA_v2_train_list.txt", aug=aug))
self.dataset_list.append(CASIA(crop_size, grid_crop, "dataset/data/CASIA_v2_auth_train_list.txt", aug=aug))
if 'COCO' in training_set:
self.dataset_list.append(tampCOCO(crop_size, grid_crop, "dataset/data/cm_COCO_train_list.txt", aug=aug))
self.dataset_list.append(tampCOCO(crop_size, grid_crop, "dataset/data/sp_COCO_train_list.txt", aug=aug))
self.dataset_list.append(tampCOCO(crop_size, grid_crop, "dataset/data/bcm_COCO_train_list.txt", aug=aug))
self.dataset_list.append(tampCOCO(crop_size, grid_crop, "dataset/data/bcmc_COCO_train_list.txt", aug=aug))
if 'RAISE' in training_set:
self.dataset_list.append(compRAISE(crop_size, grid_crop, "dataset/data/compRAISE_train.txt", aug=aug))
elif mode == "valid":
if 'FR' in valid_set:
self.dataset_list.append(FantasticReality(crop_size, grid_crop, "dataset/data/FR_valid_list.txt", max_dim=max_dim, aug=aug))
self.dataset_list.append(FantasticReality(crop_size, grid_crop, "dataset/data/FR_auth_valid_list.txt", is_auth_list=True, max_dim=max_dim, aug=aug))
if 'IMD' in valid_set:
self.dataset_list.append(IMD2020(crop_size, grid_crop, "dataset/data/IMD_valid_list.txt", max_dim=max_dim, aug=aug))
if 'CA' in valid_set:
self.dataset_list.append(CASIA(crop_size, grid_crop, "dataset/data/CASIA_v2_valid_list.txt", max_dim=max_dim, aug=aug))
self.dataset_list.append(CASIA(crop_size, grid_crop, "dataset/data/CASIA_v2_auth_valid_list.txt", max_dim=max_dim, aug=aug))
if 'COCO' in valid_set:
self.dataset_list.append(tampCOCO(crop_size, grid_crop, "dataset/data/cm_COCO_valid_list.txt", max_dim=max_dim, aug=aug))
self.dataset_list.append(tampCOCO(crop_size, grid_crop, "dataset/data/sp_COCO_valid_list.txt", max_dim=max_dim, aug=aug))
self.dataset_list.append(tampCOCO(crop_size, grid_crop, "dataset/data/bcm_COCO_valid_list.txt", max_dim=max_dim, aug=aug))
self.dataset_list.append(tampCOCO(crop_size, grid_crop, "dataset/data/bcmc_COCO_valid_list.txt", max_dim=max_dim, aug=aug))
if 'RAISE' in valid_set:
self.dataset_list.append(compRAISE(crop_size, grid_crop, "dataset/data/compRAISE_valid.txt", max_dim=max_dim, aug=aug))
else:
raise KeyError("Invalid mode: " + mode)
self.crop_size = crop_size
self.grid_crop = grid_crop
self.mode = mode
lengths = [len(ds) for ds in self.dataset_list]
self.smallest = min(lengths)
if config.TRAIN.NUM_SAMPLES > 0 and config.TRAIN.NUM_SAMPLES < self.smallest:
self.smallest = config.TRAIN.NUM_SAMPLES
def shuffle(self):
for dataset in self.dataset_list:
random.shuffle(dataset.img_list)
def get_filename(self, index):
it = 0
while True:
if index >= len(self.dataset_list[it]):
index -= len(self.dataset_list[it])
it += 1
continue
return self.dataset_list[it].get_img_name(index)
def __len__(self):
if self.mode == 'train':
# class-balanced sampling
return self.smallest * len(self.dataset_list)
else:
return sum([len(lst) for lst in self.dataset_list])
def __getitem__(self, index):
if self.mode == 'train':
# class-balanced sampling
if index < self.smallest * len(self.dataset_list):
return self.dataset_list[index//self.smallest].get_img(index % self.smallest)
else:
raise ValueError("Something wrong.")
else:
it = 0
while True:
if index >= len(self.dataset_list[it]):
index -= len(self.dataset_list[it])
it += 1
continue
return self.dataset_list[it].get_img(index)
def get_info(self):
s = ''
for ds in self.dataset_list:
s += f'{ds.__class__.__name__}: \t{len(ds)} images \n'
s += f'Smallest: {self.smallest}'
return s
|