| |
| |
| import os |
| from torch.utils.data import Dataset |
| from PIL import Image |
| import torch |
|
|
class GTResDataset(Dataset):
    """Paired dataset of input images and their ground-truth counterparts.

    Every ``.jpg`` / ``.png`` file directly inside ``root_path`` is paired with
    the same-named file inside ``gt_dir``. The ground-truth file is expected to
    be a ``.jpg``, so a ``.png`` input is paired with ``<stem>.jpg``. Only the
    file's own extension is swapped; directory names are never rewritten.

    Args:
        root_path: Directory containing the input images.
        gt_dir: Directory containing the ground-truth images. Required; the
            ``None`` default is kept only for signature compatibility.
        transform: Optional callable applied to each image in ``__getitem__``.
        transform_train: Stored as ``self.transform_train`` but not used by
            this class.

    Raises:
        ValueError: If ``gt_dir`` is ``None``.
    """

    _IMAGE_EXTENSIONS = (".jpg", ".png")

    def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
        if gt_dir is None:
            # Fail early with a clear message instead of a TypeError from os.path.join.
            raise ValueError("gt_dir must be provided to locate ground-truth images")
        self.pairs = []
        # os.listdir order is arbitrary; sort so dataset indices are reproducible.
        for name in sorted(os.listdir(root_path)):
            if not name.endswith(self._IMAGE_EXTENSIONS):
                continue
            stem, ext = os.path.splitext(name)
            # Ground truth is stored as .jpg: swap only this file's extension so
            # ".png" appearing elsewhere in the path is left untouched.
            gt_name = stem + ".jpg" if ext == ".png" else name
            # The third element is an unused placeholder, kept so existing code
            # that unpacks three-element pairs keeps working.
            self.pairs.append(
                [os.path.join(root_path, name), os.path.join(gt_dir, gt_name), None]
            )
        self.transform = transform
        self.transform_train = transform_train

    def __len__(self):
        """Return the number of (input, ground-truth) pairs."""
        return len(self.pairs)

    def __getitem__(self, index):
        """Return ``(input_image, gt_image)`` for ``index``.

        Both images are loaded as PIL RGB images and, if ``self.transform`` is
        set, replaced by whatever the transform returns for each of them.
        """
        from_path, to_path, _ = self.pairs[index]
        from_im = self._load_rgb(from_path)
        to_im = self._load_rgb(to_path)

        if self.transform:
            # NOTE(review): the transform is called separately on each image, so a
            # random transform would not be applied identically to the pair —
            # confirm this is intended.
            to_im = self.transform(to_im)
            from_im = self.transform(from_im)

        return from_im, to_im

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy, closing the file."""
        with Image.open(path) as im:
            return im.convert("RGB")
|
|