| from torch.utils.data import Dataset |
|
|
| from .data_utils import load_image |
|
|
| from .parsers import BaseParser |
| from .data_utils import LabelTransform |
| from .perceptual_transforms import PerceptualFeatureMapTransform |
|
|
|
|
class EPUDataset(Dataset):
    """Indexable dataset pairing parsed image paths with their labels.

    The image paths and labels are read once from ``data_parser`` when the
    dataset is constructed. Each item is produced by passing the stored path
    to ``load_image`` and then applying the optional image and label
    transforms.
    """

    def __init__(self,
                 data_parser: BaseParser,
                 perceptual_transform: PerceptualFeatureMapTransform = None,
                 label_transform: LabelTransform = None,
                 **kwargs):
        """Store the parser, the optional transforms and any extra options.

        Args:
            data_parser: Object exposing ``image_filenames`` and ``labels``;
                both are expected to be indexable and of equal length.
            perceptual_transform: Optional callable applied to every loaded
                image. Skipped when ``None``.
            label_transform: Optional callable applied to every label.
                Skipped when ``None``.
            **kwargs: Extra keyword arguments, kept on ``self.kwargs``. They
                are not read by any method of this class shown here.
        """
        self.data_parser = data_parser
        self.transform = perceptual_transform
        self.label_transform = label_transform
        self.kwargs = kwargs

        # Keep direct references to the parser's sequences so item access
        # does not go through the parser on every lookup.
        self.image_paths = self.data_parser.image_filenames
        self.labels = self.data_parser.labels

    def __len__(self):
        """Return the number of samples.

        Raises:
            ValueError: If the number of image paths differs from the number
                of labels.
        """
        num_images = len(self.image_paths)
        # Explicit raise instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would let a mismatch pass unnoticed.
        if num_images != len(self.labels):
            raise ValueError("Mismatch in image paths and labels")
        return num_images

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        The image is obtained from ``load_image`` and the label is taken from
        ``self.labels``; each is passed through its transform when one was
        supplied.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        img = load_image(img_path)
        if self.transform is not None:
            img = self.transform(img)
        if self.label_transform is not None:
            label = self.label_transform(label)
        return img, label
|
|