ius / data /dataset.py
pgatoula's picture
Sync from GitHub via hub-sync
99ec8a2 verified
from torch.utils.data import Dataset
from .data_utils import load_image
from .parsers import BaseParser
from .data_utils import LabelTransform
from .perceptual_transforms import PerceptualFeatureMapTransform
class EPUDataset(Dataset):
    """Map-style dataset pairing image files with labels from a data parser.

    The image paths and labels are taken from ``data_parser`` when the
    dataset is constructed. Images are loaded on access through
    ``load_image``; the optional transforms are applied per item in
    ``__getitem__``.
    """

    def __init__(self,
                 data_parser: BaseParser,
                 perceptual_transform: PerceptualFeatureMapTransform = None,
                 label_transform: LabelTransform = None,
                 **kwargs):
        """Store the parser, the optional transforms and any extra options.

        Args:
            data_parser: Source of the samples; its ``image_filenames`` and
                ``labels`` attributes are referenced here and later indexed
                in parallel.
            perceptual_transform: Optional callable applied to each loaded
                image. ``None`` leaves the image as returned by
                ``load_image``.
            label_transform: Optional callable applied to each label.
                ``None`` returns the label unchanged.
            **kwargs: Extra keyword arguments, kept as-is on ``self.kwargs``;
                this class does not read them itself.
        """
        self.data_parser = data_parser
        self.transform = perceptual_transform
        self.label_transform = label_transform
        self.kwargs = kwargs
        self.image_paths = self.data_parser.image_filenames
        self.labels = self.data_parser.labels

    def __len__(self):
        """Return the number of samples.

        Raises:
            AssertionError: If the number of image paths differs from the
                number of labels.
        """
        # Raised explicitly rather than through an ``assert`` statement so the
        # consistency check still runs under ``python -O``. The exception type
        # and message are kept so existing callers observe the same error.
        if len(self.image_paths) != len(self.labels):
            raise AssertionError("Mismatch in image paths and labels")
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the sample at ``idx`` and apply the configured transforms.

        Args:
            idx: Index into the parallel ``image_paths`` / ``labels``
                sequences.

        Returns:
            An ``(img, label)`` tuple: the image loaded from the path at
            ``idx`` and the label at ``idx``, each passed through its
            transform when one was configured.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        img = load_image(img_path)

        if self.transform is not None:
            img = self.transform(img)
        if self.label_transform is not None:
            label = self.label_transform(label)

        return img, label