| import os | |
| import random | |
| import torch.utils.data as data | |
| import torchvision.transforms as T | |
| from PIL import Image | |
class MEFDataset(data.Dataset):
    """Yield random pairs of images drawn from the same sub-directory of *root*.

    The directory layout implied by the path handling below is::

        root/
            <sample dir>/      # one entry of ``self.numbers``
                <image file>
                ...

    (The class name suggests multi-exposure fusion, i.e. each sample
    directory holds differently exposed shots of one scene -- an assumption,
    nothing here checks it.)

    Hidden entries (names starting with ``.``, e.g. ``.DS_Store`` or
    ``.ipynb_checkpoints``) are ignored. So are plain files directly under
    *root* and sub-directories inside a sample directory.
    """

    def __init__(self, root, transform=None):
        """Index the sample directories under *root*.

        Args:
            root: Dataset root directory.
            transform: Callable applied to each loaded RGB ``PIL.Image``.
                Defaults to ``T.Compose([T.ToTensor()])`` (the previous
                hard-coded behaviour).
        """
        self.img_root = root
        # Only directories are samples; sorted so index -> directory is
        # stable across filesystems.
        self.numbers = self._listdir_sorted(root, os.path.isdir)
        print(len(self.numbers))
        if transform is None:
            transform = T.Compose([T.ToTensor()])
        self.preproc = transform

    @staticmethod
    def _listdir_sorted(directory, keep):
        """Return sorted non-hidden names in *directory* whose full path satisfies *keep*."""
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith(".") and keep(os.path.join(directory, name))
        )

    def _load(self, path):
        """Open *path*, convert it to RGB and apply ``self.preproc``."""
        # The context manager closes the file handle deterministically
        # instead of leaving it to the garbage collector.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        return self.preproc(rgb)

    def __getitem__(self, idx):
        """Return ``(img1, img2, name1, name2)`` for two distinct random files of sample *idx*.

        Each name is ``f'{sample_dir}_{file_name}'``.

        Raises:
            ValueError: if the sample directory holds fewer than two usable files.
        """
        number = self.numbers[idx]
        im_dir = os.path.join(self.img_root, number)
        # Sorting makes the pick reproducible under a seeded ``random``;
        # raw ``os.listdir`` order is filesystem-dependent.
        files = self._listdir_sorted(im_dir, os.path.isfile)
        if len(files) < 2:
            raise ValueError(
                f"{im_dir!r} needs at least 2 image files, found {len(files)}"
            )
        fn1, fn2 = random.sample(files, k=2)
        img1 = self._load(os.path.join(im_dir, fn1))
        img2 = self._load(os.path.join(im_dir, fn2))
        return img1, img2, f'{number}_{fn1}', f'{number}_{fn2}'

    def __len__(self):
        """Return the number of sample directories."""
        return len(self.numbers)