| import os | |
| from torch.utils.data import Dataset | |
| from PIL import Image | |
class SuperviselyPersonDataset(Dataset):
    """Dataset of (image, segmentation mask) pairs read from two directories.

    The entries of ``imgdir`` and ``segdir`` are listed and sorted
    independently; the image at sorted position ``i`` is paired with the
    mask at sorted position ``i``.  Pairing is purely positional, so this
    presumably relies on corresponding files having names that sort
    identically in both directories -- verify against the data layout.
    """

    def __init__(self, imgdir, segdir, transform=None):
        """Index the files of both directories.

        Args:
            imgdir: Directory containing the input images.
            segdir: Directory containing the segmentation masks.
            transform: Optional callable invoked as ``transform(img, seg)``
                that returns the transformed ``(img, seg)`` pair.

        Raises:
            ValueError: If the two directories do not contain the same
                number of entries.
        """
        super().__init__()
        self.img_dir = imgdir
        self.img_files = sorted(os.listdir(imgdir))
        self.seg_dir = segdir
        self.seg_files = sorted(os.listdir(segdir))
        # Explicit raise instead of ``assert``: asserts are removed under
        # ``python -O``, which would let a mismatched dataset through.
        if len(self.img_files) != len(self.seg_files):
            raise ValueError(
                f"Number of images ({len(self.img_files)}) in {imgdir!r} "
                f"does not match number of masks ({len(self.seg_files)}) "
                f"in {segdir!r}"
            )
        self.transform = transform

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.img_files)

    def __getitem__(self, idx):
        """Load the pair at position ``idx``.

        The image is converted to mode ``'RGB'`` and the mask to mode
        ``'L'``.  If a transform was supplied it is applied jointly to
        both.

        Returns:
            The ``(img, seg)`` pair, or whatever ``transform(img, seg)``
            unpacks to when a transform is set.
        """
        img_path = os.path.join(self.img_dir, self.img_files[idx])
        seg_path = os.path.join(self.seg_dir, self.seg_files[idx])
        # ``convert`` returns a new in-memory image, so the file-backed
        # handles can be closed as soon as the ``with`` block exits.
        with Image.open(img_path) as img_file, Image.open(seg_path) as seg_file:
            img = img_file.convert('RGB')
            seg = seg_file.convert('L')
        if self.transform is not None:
            img, seg = self.transform(img, seg)
        return img, seg