|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Created in September 2022 |
|
|
@author: fabrizio.guillaro |
|
|
""" |
|
|
|
|
|
from torch.utils.data import Dataset |
|
|
import random |
|
|
|
|
|
from dataset.dataset_FantasticReality import FantasticReality |
|
|
from dataset.dataset_IMD2020 import IMD2020 |
|
|
from dataset.dataset_CASIA import CASIA |
|
|
from dataset.dataset_TampCOCO import tampCOCO |
|
|
from dataset.dataset_CompRAISE import compRAISE |
|
|
|
|
|
|
|
|
class myDataset(Dataset):
    """Concatenation of several image sub-datasets selected from the config.

    Which sub-datasets are loaded depends on ``mode``:

    * ``'train'``: names listed in ``config.DATASET.TRAIN``;
    * ``'valid'``: names listed in ``config.DATASET.VALID``.

    The recognised names are ``'FR'``, ``'IMD'``, ``'CA'``, ``'COCO'`` and
    ``'RAISE'``. Membership is tested with the ``in`` operator. Each name can
    expand to several list files (see ``_catalog``).

    Indexing scheme:

    * ``'train'``: every sub-dataset contributes exactly ``self.smallest``
      consecutive indices. ``self.smallest`` is the size of the smallest
      sub-dataset, optionally capped by ``config.TRAIN.NUM_SAMPLES``. Index
      ``i`` maps to sub-dataset ``i // smallest``, local index
      ``i % smallest``. ``shuffle()`` permutes each sub-dataset's
      ``img_list``, presumably so a different subset is drawn each epoch.
    * ``'valid'``: the sub-datasets are concatenated in full.
    """

    @staticmethod
    def _catalog():
        """Return the sub-dataset catalogue in loading order.

        Each entry is ``(selection_key, dataset_class, list_files)``. Each item
        of ``list_files`` is ``(path_template, extra_kwargs)``:
        ``path_template`` contains ``{split}``, which is replaced by the mode
        name, and ``extra_kwargs`` are additional constructor keyword
        arguments.
        """
        return [
            ('FR', FantasticReality, [
                ('FR_{split}_list.txt', {}),
                ('FR_auth_{split}_list.txt', {'is_auth_list': True}),
            ]),
            ('IMD', IMD2020, [
                ('IMD_{split}_list.txt', {}),
            ]),
            # NOTE(review): unlike the FR authentic list, the CASIA authentic
            # list is not passed is_auth_list=True. Confirm that CASIA
            # identifies authentic images by itself.
            ('CA', CASIA, [
                ('CASIA_v2_{split}_list.txt', {}),
                ('CASIA_v2_auth_{split}_list.txt', {}),
            ]),
            ('COCO', tampCOCO, [
                ('cm_COCO_{split}_list.txt', {}),
                ('sp_COCO_{split}_list.txt', {}),
                ('bcm_COCO_{split}_list.txt', {}),
                ('bcmc_COCO_{split}_list.txt', {}),
            ]),
            ('RAISE', compRAISE, [
                ('compRAISE_{split}.txt', {}),
            ]),
        ]

    def __init__(self, config, crop_size, grid_crop, mode="train", max_dim=None, aug=None):
        """Build the list of sub-datasets for ``mode``.

        Args:
            config: reads ``DATASET.TRAIN`` or ``DATASET.VALID`` (the selection,
                depending on ``mode``) and ``TRAIN.NUM_SAMPLES``.
            crop_size: forwarded unchanged to every sub-dataset.
            grid_crop: forwarded unchanged to every sub-dataset.
            mode: ``'train'`` or ``'valid'``.
            max_dim: forwarded to the sub-datasets only when ``mode == 'valid'``.
            aug: forwarded unchanged to every sub-dataset.

        Raises:
            KeyError: if ``mode`` is neither ``'train'`` nor ``'valid'``.
            ValueError: if the selection matches no known sub-dataset.
        """
        if mode == "train":
            selected = config.DATASET.TRAIN
            split_kwargs = {}
        elif mode == "valid":
            selected = config.DATASET.VALID
            # Only validation sub-datasets receive max_dim; training ones keep
            # their own default.
            split_kwargs = {'max_dim': max_dim}
        else:
            raise KeyError(f"Invalid mode: {mode}")

        self.dataset_list = []
        for key, ds_class, list_files in self._catalog():
            if key not in selected:
                continue
            for template, extra_kwargs in list_files:
                list_path = "dataset/data/" + template.format(split=mode)
                self.dataset_list.append(
                    ds_class(crop_size, grid_crop, list_path, aug=aug, **extra_kwargs, **split_kwargs)
                )

        self.crop_size = crop_size
        self.grid_crop = grid_crop
        self.mode = mode

        if not self.dataset_list:
            raise ValueError(f"No datasets selected for mode '{mode}' (selection: {selected!r})")

        # Per-sub-dataset sample count in training mode. In 'valid' mode it is
        # only reported by get_info().
        self.smallest = min(len(ds) for ds in self.dataset_list)
        num_samples = config.TRAIN.NUM_SAMPLES
        if 0 < num_samples < self.smallest:
            self.smallest = num_samples

    def shuffle(self):
        """Shuffle, in place, the ``img_list`` of every sub-dataset."""
        for dataset in self.dataset_list:
            random.shuffle(dataset.img_list)

    def _locate(self, index):
        """Map a global ``index`` to ``(sub_dataset, local_index)``.

        Uses the same scheme as ``__len__``. Raises ``ValueError`` for an
        out-of-range index in ``'train'`` mode and ``IndexError`` otherwise.
        """
        if self.mode == 'train':
            if index >= self.smallest * len(self.dataset_list):
                raise ValueError(f"Index {index} out of range for training dataset of length {len(self)}")
            return self.dataset_list[index // self.smallest], index % self.smallest

        for dataset in self.dataset_list:
            if index < len(dataset):
                return dataset, index
            index -= len(dataset)
        raise IndexError("Index out of range for validation dataset")

    def get_filename(self, index):
        """Return the image name of sample ``index``.

        The name refers to the same image that ``self[index]`` returns.
        """
        dataset, local_index = self._locate(index)
        return dataset.get_img_name(local_index)

    def __len__(self):
        """Return ``smallest * n_sub_datasets`` in training mode, else the total size."""
        if self.mode == 'train':
            return self.smallest * len(self.dataset_list)
        return sum(len(dataset) for dataset in self.dataset_list)

    def __getitem__(self, index):
        """Return ``get_img(local_index)`` of the sub-dataset that owns ``index``."""
        dataset, local_index = self._locate(index)
        return dataset.get_img(local_index)

    def get_info(self):
        """Return a human-readable summary of sub-dataset sizes and ``smallest``."""
        lines = [f'{ds.__class__.__name__}: \t{len(ds)} images \n' for ds in self.dataset_list]
        return ''.join(lines) + f'Smallest: {self.smallest}'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|