File size: 2,973 Bytes
99ec8a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# from torch.utils.data import Dataset
from typing import Union

from data.data_utils import LabelTransform
from utils.omega_parser import DataParams
from data.parsers import set_parser_class
from data.perceptual_transforms import PerceptualFeatureMapTransform
from data.dataset import EPUDataset
from utils.config_utils import data_cfg_to_dataparser
from ius.ius_eval_parser import IUSEvalParser


# Creates an EPUDataset from IUS config.
# Used in EPUCNN training, classification performance eval & calculation of cb vectors
class EPUDatasetFromConfig:
    """Build ``EPUDataset`` objects from a data configuration.

    The constructor copies the fields it needs from ``dataconfig``;
    ``get_dataset`` then assembles the data parser, the perceptual transform
    and the label transform for one dataset split.

    Args:
        dataconfig: Configuration object exposing ``dataset_path``,
            ``images_extension`` and ``data_preprocessing``.
        **kwargs: Only ``group_by`` is read (``None`` when absent); it is
            forwarded to ``data_cfg_to_dataparser``. Any other keyword
            argument is ignored.
    """

    # Split names accepted by ``get_dataset``.
    _DATASET_MODES = ("train", "validation", "test")

    def __init__(self, dataconfig: DataParams, **kwargs):
        self.dataset_path = dataconfig.dataset_path
        self.images_extension = dataconfig.images_extension

        self.data_preprocessing = dataconfig.data_preprocessing

        self.group_by = kwargs.get('group_by')

    def get_dataset(self, dataset_mode: str) -> EPUDataset:
        """Create the ``EPUDataset`` for the requested split.

        Args:
            dataset_mode: One of ``"train"``, ``"validation"`` or ``"test"``.

        Returns:
            An ``EPUDataset`` wired with the configured parser, a bicubic
            ``PerceptualFeatureMapTransform`` and a ``LabelTransform`` built
            from ``data_preprocessing.label_mapping``.

        Raises:
            AssertionError: If ``dataset_mode`` is not an accepted split name.
        """
        if dataset_mode not in self._DATASET_MODES:
            # Raised explicitly instead of using ``assert`` so the check is
            # still enforced under ``python -O``. The exception type is kept
            # as AssertionError for callers that already handle it.
            raise AssertionError(
                "Dataset mode must be one of 'train', 'validation' or 'test', "
                "got {!r}.".format(dataset_mode)
            )

        preprocessing = self.data_preprocessing

        # Resolve the parser class named in the config, then instantiate it
        # with the keyword arguments derived from the data configuration.
        parser_cls = set_parser_class(name=preprocessing.data_parser)
        parser_kwargs = data_cfg_to_dataparser(
            dataset_path=self.dataset_path,
            images_extension=self.images_extension,
            data_mode=dataset_mode,
            preprocessing_cfg=preprocessing,
            group_by=self.group_by,
        )
        parser = parser_cls(**parser_kwargs)

        # NOTE: the transform's ``data_mode`` comes from the preprocessing
        # config; it is not the ``dataset_mode`` split name given to the parser.
        perceptual_transform = PerceptualFeatureMapTransform(
            resize_dims=preprocessing.resize_dims,
            resize_mode="bicubic",
            data_mode=preprocessing.data_mode,
        )

        label_transform = LabelTransform(
            mapping_dict=preprocessing.label_mapping
        )

        return EPUDataset(
            data_parser=parser,
            perceptual_transform=perceptual_transform,
            label_transform=label_transform,
        )


# used during IUS evaluation
class IUSEvalDataset:
    """Wrap an already-built parser into an ``EPUDataset`` without labels.

    Args:
        dataconfig: Configuration object exposing ``dataset_path``,
            ``images_extension`` and ``data_preprocessing``.
        **kwargs: Only ``group_by`` is read (``None`` when absent).
    """

    def __init__(self, dataconfig: DataParams, **kwargs):
        # Keep local copies of the configuration fields.
        self.dataset_path = dataconfig.dataset_path
        self.images_extension = dataconfig.images_extension
        self.data_preprocessing = dataconfig.data_preprocessing
        self.group_by = kwargs.get("group_by", None)

    def get_dataset(self, parser: Union[IUSEvalParser]) -> EPUDataset:
        """Return an ``EPUDataset`` backed by ``parser``.

        The dataset uses a bicubic ``PerceptualFeatureMapTransform`` configured
        from ``data_preprocessing`` and no label transform.
        """
        preprocessing_cfg = self.data_preprocessing
        feature_map_transform = PerceptualFeatureMapTransform(
            resize_dims=preprocessing_cfg.resize_dims,
            resize_mode="bicubic",
            data_mode=preprocessing_cfg.data_mode,
        )
        return EPUDataset(
            data_parser=parser,
            perceptual_transform=feature_map_transform,
            label_transform=None,
        )