File size: 2,973 Bytes
99ec8a2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 | # from torch.utils.data import Dataset
from typing import Union
from data.data_utils import LabelTransform
from utils.omega_parser import DataParams
from data.parsers import set_parser_class
from data.perceptual_transforms import PerceptualFeatureMapTransform
from data.dataset import EPUDataset
from utils.config_utils import data_cfg_to_dataparser
from ius.ius_eval_parser import IUSEvalParser
# Creates an EPUDataset from IUS config.
# Used in EPUCNN training, classification performance eval & calculation of cb vectors
class EPUDatasetFromConfig:
    """Build an ``EPUDataset`` for a requested split from a data config.

    Used in EPUCNN training, classification performance evaluation and the
    calculation of cb vectors.

    Args:
        dataconfig: Data configuration. Its ``dataset_path``,
            ``images_extension`` and ``data_preprocessing`` fields are stored
            on the instance.
        **kwargs: Optional extras. Only ``group_by`` is read (``None`` when
            absent); it is forwarded to ``data_cfg_to_dataparser``.
    """

    # Splits accepted by ``get_dataset``.
    _DATASET_MODES = ("train", "validation", "test")

    def __init__(self, dataconfig: DataParams, **kwargs):
        self.dataset_path = dataconfig.dataset_path
        self.images_extension = dataconfig.images_extension
        self.data_preprocessing = dataconfig.data_preprocessing
        self.group_by = kwargs.get('group_by')

    def get_dataset(self, dataset_mode: str) -> EPUDataset:
        """Create the ``EPUDataset`` for one split.

        Args:
            dataset_mode: One of ``"train"``, ``"validation"`` or ``"test"``.

        Returns:
            An ``EPUDataset`` combining three parts:

            - a parser whose class is looked up by
              ``data_preprocessing.data_parser``;
            - a bicubic ``PerceptualFeatureMapTransform``;
            - a ``LabelTransform`` built from
              ``data_preprocessing.label_mapping``.

        Raises:
            ValueError: If ``dataset_mode`` is not a supported split.
        """
        # Explicit check rather than ``assert`` so the validation is not
        # stripped when Python runs with ``-O``.
        if dataset_mode not in self._DATASET_MODES:
            raise ValueError(
                f"Dataset mode must be one of {self._DATASET_MODES}, "
                f"got {dataset_mode!r}."
            )

        # Resolve the parser class by name, then instantiate it with the
        # keyword arguments derived from the config for this split.
        parser_cls = set_parser_class(name=self.data_preprocessing.data_parser)
        parser = parser_cls(
            **data_cfg_to_dataparser(
                dataset_path=self.dataset_path,
                images_extension=self.images_extension,
                data_mode=dataset_mode,
                preprocessing_cfg=self.data_preprocessing,
                group_by=self.group_by,
            )
        )

        # Note: ``data_mode`` below comes from the preprocessing config and
        # is distinct from the ``dataset_mode`` split argument.
        perceptual_transform = PerceptualFeatureMapTransform(
            resize_dims=self.data_preprocessing.resize_dims,
            resize_mode="bicubic",
            data_mode=self.data_preprocessing.data_mode,
        )
        label_transform = LabelTransform(
            mapping_dict=self.data_preprocessing.label_mapping
        )

        return EPUDataset(
            data_parser=parser,
            perceptual_transform=perceptual_transform,
            label_transform=label_transform,
        )
# used during IUS evaluation
class IUSEvalDataset:
    """Build an ``EPUDataset`` around an externally supplied parser.

    Used during IUS evaluation. The caller constructs the parser; this class
    only adds the perceptual transform configured in ``data_preprocessing``.

    Args:
        dataconfig: Data configuration. Its ``dataset_path``,
            ``images_extension`` and ``data_preprocessing`` fields are stored
            on the instance.
        **kwargs: Optional extras. Only ``group_by`` is read (``None`` when
            absent) and stored; ``get_dataset`` does not use it.
    """

    def __init__(self, dataconfig: DataParams, **kwargs):
        self.dataset_path = dataconfig.dataset_path
        self.images_extension = dataconfig.images_extension
        self.data_preprocessing = dataconfig.data_preprocessing
        self.group_by = kwargs.get('group_by')

    def get_dataset(self, parser: IUSEvalParser) -> EPUDataset:
        """Wrap ``parser`` in an ``EPUDataset`` with the configured transform.

        Args:
            parser: Parser instance that supplies the samples.

        Returns:
            An ``EPUDataset`` with a bicubic ``PerceptualFeatureMapTransform``
            built from ``data_preprocessing``. No label transform is applied
            (``label_transform=None``).
        """
        perceptual_transform = PerceptualFeatureMapTransform(
            resize_dims=self.data_preprocessing.resize_dims,
            resize_mode="bicubic",
            data_mode=self.data_preprocessing.data_mode,
        )
        return EPUDataset(
            data_parser=parser,
            perceptual_transform=perceptual_transform,
            label_transform=None,
        )
|