File size: 6,149 Bytes
5f0437a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Copyright (c) 2023 Image Processing Research Group of University Federico II of Naples ('GRIP-UNINA').
#
# All rights reserved.
# This work should only be used for nonprofit purposes.
#
# By downloading and/or using any of these files, you implicitly agree to all the
# terms of the license, as specified in the document LICENSE.txt
# (included in this package) and online at
# http://www.grip.unina.it/download/LICENSE_OPEN.txt

"""
Created in September 2022
@author: fabrizio.guillaro
"""

from torch.utils.data import Dataset
import random

from dataset.dataset_FantasticReality import FantasticReality
from dataset.dataset_IMD2020 import IMD2020
from dataset.dataset_CASIA import CASIA
from dataset.dataset_TampCOCO import tampCOCO
from dataset.dataset_CompRAISE import compRAISE


class myDataset(Dataset):
    """Concatenation of several image-forgery datasets.

    In ``"train"`` mode samples are drawn class-balanced: each epoch exposes
    the same number of images (``self.smallest``) from every sub-dataset.
    In ``"valid"`` mode the plain concatenation of all sub-datasets is used.

    Args:
        config:    experiment configuration; reads ``config.DATASET.TRAIN``,
                   ``config.DATASET.VALID`` and ``config.TRAIN.NUM_SAMPLES``.
        crop_size: crop size forwarded to every sub-dataset constructor.
        grid_crop: grid-crop flag forwarded to every sub-dataset constructor.
        mode:      ``"train"`` or ``"valid"``; anything else raises KeyError.
        max_dim:   maximum image dimension; forwarded to sub-datasets only in
                   "valid" mode (training never passed it in the original code).
        aug:       augmentation pipeline forwarded to every sub-dataset.

    Raises:
        KeyError:  if ``mode`` is neither "train" nor "valid".
        ValueError: if the configured selection produces no sub-dataset.
    """

    # tampCOCO is distributed as four separate file lists sharing one loader.
    _COCO_PREFIXES = ('cm', 'sp', 'bcm', 'bcmc')

    def __init__(self, config, crop_size, grid_crop, mode="train", max_dim=None, aug=None):
        if mode == "train":
            selection = config.DATASET.TRAIN
            extra = {}                       # training never limits image size
        elif mode == "valid":
            selection = config.DATASET.VALID
            extra = {'max_dim': max_dim}
        else:
            raise KeyError("Invalid mode: " + mode)

        # The file-list paths differ between splits only by the mode suffix
        # ("..._train_list.txt" vs "..._valid_list.txt"), so both branches of
        # the original code collapse into one data-driven construction.
        self.dataset_list = []
        if 'FR' in selection:
            self.dataset_list.append(FantasticReality(
                crop_size, grid_crop, f"dataset/data/FR_{mode}_list.txt",
                aug=aug, **extra))
            self.dataset_list.append(FantasticReality(
                crop_size, grid_crop, f"dataset/data/FR_auth_{mode}_list.txt",
                is_auth_list=True, aug=aug, **extra))

        if 'IMD' in selection:
            self.dataset_list.append(IMD2020(
                crop_size, grid_crop, f"dataset/data/IMD_{mode}_list.txt",
                aug=aug, **extra))

        if 'CA' in selection:
            self.dataset_list.append(CASIA(
                crop_size, grid_crop, f"dataset/data/CASIA_v2_{mode}_list.txt",
                aug=aug, **extra))
            self.dataset_list.append(CASIA(
                crop_size, grid_crop, f"dataset/data/CASIA_v2_auth_{mode}_list.txt",
                aug=aug, **extra))

        if 'COCO' in selection:
            for prefix in self._COCO_PREFIXES:
                self.dataset_list.append(tampCOCO(
                    crop_size, grid_crop, f"dataset/data/{prefix}_COCO_{mode}_list.txt",
                    aug=aug, **extra))

        if 'RAISE' in selection:
            # note: compRAISE lists have no "_list" suffix in their file names
            self.dataset_list.append(compRAISE(
                crop_size, grid_crop, f"dataset/data/compRAISE_{mode}.txt",
                aug=aug, **extra))

        if not self.dataset_list:
            # min() over an empty sequence below would raise a cryptic error.
            raise ValueError(f"No dataset selected for mode '{mode}': {selection}")

        self.crop_size = crop_size
        self.grid_crop = grid_crop
        self.mode = mode
        # Per-sub-dataset sample cap for class-balanced training: at most
        # NUM_SAMPLES (when positive), never more than the smallest sub-dataset.
        self.smallest = min(len(ds) for ds in self.dataset_list)
        if 0 < config.TRAIN.NUM_SAMPLES < self.smallest:
            self.smallest = config.TRAIN.NUM_SAMPLES

    def shuffle(self):
        """Shuffle every sub-dataset's file list in place."""
        for dataset in self.dataset_list:
            random.shuffle(dataset.img_list)

    def _locate(self, index):
        """Map a global (concatenated) index to ``(sub_dataset, local_index)``.

        Raises:
            IndexError: if ``index`` exceeds the concatenated length.
        """
        remaining = index
        for ds in self.dataset_list:
            if remaining < len(ds):
                return ds, remaining
            remaining -= len(ds)
        raise IndexError(f"index {index} out of range")

    def get_filename(self, index):
        """Return the image file name at the given concatenated index."""
        ds, local = self._locate(index)
        return ds.get_img_name(local)

    def __len__(self):
        if self.mode == 'train':
            # class-balanced sampling: same count from every sub-dataset
            return self.smallest * len(self.dataset_list)
        return sum(len(ds) for ds in self.dataset_list)

    def __getitem__(self, index):
        if self.mode == 'train':
            # class-balanced sampling: the block of indices
            # [i*smallest, (i+1)*smallest) maps to the first `smallest`
            # images of sub-dataset i.
            if index >= self.smallest * len(self.dataset_list):
                # was ValueError("Something wrong."): the sequence protocol
                # expects IndexError for out-of-range access
                raise IndexError(f"index {index} out of range")
            return self.dataset_list[index // self.smallest].get_img(index % self.smallest)
        ds, local = self._locate(index)
        return ds.get_img(local)

    def get_info(self):
        """Return a human-readable summary of sub-dataset sizes."""
        s = ''
        for ds in self.dataset_list:
            s += f'{ds.__class__.__name__}: \t{len(ds)} images \n'
        s += f'Smallest: {self.smallest}'
        return s