File size: 1,230 Bytes
99ec8a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from torch.utils.data import Dataset

from .data_utils import load_image

from .parsers import BaseParser
from .data_utils import LabelTransform
from .perceptual_transforms import PerceptualFeatureMapTransform


class EPUDataset(Dataset):
    """Map-style dataset pairing image files with labels supplied by a parser.

    Image paths and labels are taken from ``data_parser`` once, at
    construction time. Images are loaded lazily, one per ``__getitem__`` call.

    Args:
        data_parser: Parser exposing ``image_filenames`` and ``labels``.
            ``__getitem__`` indexes both by position and ``__len__`` measures
            both with ``len()``, so they must be index-aligned sequences.
        perceptual_transform: Optional callable applied to each loaded image;
            skipped when ``None``.
        label_transform: Optional callable applied to each label; skipped
            when ``None``.
        **kwargs: Extra options stored on ``self.kwargs``. They are not read
            anywhere in this class.
    """

    def __init__(self,
                 data_parser: BaseParser,
                 perceptual_transform: PerceptualFeatureMapTransform = None,
                 label_transform: LabelTransform = None,
                 **kwargs):
        self.data_parser = data_parser
        self.transform = perceptual_transform
        self.label_transform = label_transform
        self.kwargs = kwargs

        # Whatever objects the parser returns are kept as-is (not copied here).
        self.image_paths = self.data_parser.image_filenames
        self.labels = self.data_parser.labels

    def __len__(self):
        """Return the number of (image, label) pairs in the dataset.

        Raises:
            ValueError: If the number of image paths differs from the number
                of labels.
        """
        n_images = len(self.image_paths)
        n_labels = len(self.labels)
        # Explicit check rather than ``assert``: assertions are removed under
        # ``python -O``, which would silently hide a misaligned dataset.
        if n_images != n_labels:
            raise ValueError(
                "Mismatch in image paths and labels: "
                f"{n_images} image paths vs {n_labels} labels"
            )
        return n_images

    def __getitem__(self, idx):
        """Load the sample at position ``idx`` and return ``(image, label)``.

        The image is produced by ``load_image`` on the stored path. The
        perceptual transform and label transform are applied only when they
        are not ``None``.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        img = load_image(img_path)
        if self.transform is not None:
            img = self.transform(img)
        if self.label_transform is not None:
            label = self.label_transform(label)
        return img, label