| from torch.utils import data |
| import os |
| from PIL import Image, ImageFile |
|
|
|
|
# Let Pillow decode truncated/partially written image files instead of raising.
# NOTE: this mutates the imported ``PIL.ImageFile`` module, so it is a
# process-wide setting that affects every Pillow load, not just this dataset.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
|
|
class EvalDataset(data.Dataset):
    """Pair prediction maps with ground-truth images for evaluation.

    Both roots are expected to share a ``<category>/<file name>`` layout. Only
    relative paths that exist as a directory/file under *both* roots are kept,
    in sorted order so the dataset is deterministic across filesystems.
    Every matched label is decoded eagerly (converted to Pillow mode ``'L'``)
    when the dataset is built; predictions are decoded lazily in
    ``__getitem__``.

    Args:
        pred_root: Directory containing one sub-directory per category of
            prediction images.
        label_root: Directory laid out like ``pred_root`` holding the
            ground-truth images.
        return_predpath: If true, ``__getitem__`` also returns the prediction
            file path.
        return_gtpath: If true, ``__getitem__`` also returns the ground-truth
            file path (after the prediction path when both are enabled).
    """

    def __init__(self, pred_root, label_root, return_predpath=False, return_gtpath=False):
        self.return_predpath = return_predpath
        self.return_gtpath = return_gtpath

        dir_name_list = self._matched_relative_paths(pred_root, label_root)

        self.image_path = [os.path.join(pred_root, x) for x in dir_name_list]
        self.label_path = [os.path.join(label_root, x) for x in dir_name_list]

        # Labels are cached up front so __getitem__ only has to decode the
        # prediction image.
        self.labels = [self._load_gray(p) for p in self.label_path]

    @staticmethod
    def _matched_relative_paths(pred_root, label_root):
        """Return sorted ``<category>/<name>`` paths present under both roots."""
        label_dirs = set(os.listdir(label_root))
        matched = []
        # os.listdir order is arbitrary; sort for a reproducible dataset order.
        for idir in sorted(os.listdir(pred_root)):
            if idir not in label_dirs:
                continue
            pred_dir = os.path.join(pred_root, idir)
            label_dir = os.path.join(label_root, idir)
            # Stray top-level files would make the os.listdir calls below fail.
            if not (os.path.isdir(pred_dir) and os.path.isdir(label_dir)):
                continue
            label_names = set(os.listdir(label_dir))
            for iname in sorted(os.listdir(pred_dir)):
                if iname not in label_names:
                    continue
                # Skip nested directories; only file pairs can be opened as images.
                if os.path.isfile(os.path.join(pred_dir, iname)) and os.path.isfile(
                    os.path.join(label_dir, iname)
                ):
                    matched.append(os.path.join(idir, iname))
        return matched

    @staticmethod
    def _load_gray(path):
        """Open ``path`` and return a fully loaded single-channel ('L') image.

        ``convert`` produces an independent in-memory image, so the source file
        can be closed immediately by the ``with`` block.
        """
        with Image.open(path) as img:
            return img.convert('L')

    def __getitem__(self, item):
        """Return ``[pred, gt]`` plus the optional prediction/ground-truth paths.

        The prediction is resized (bilinear) to the ground-truth size when the
        two differ, so both returned images have identical dimensions.
        """
        predpath = self.image_path[item]
        gtpath = self.label_path[item]
        pred = self._load_gray(predpath)
        gt = self.labels[item]
        if pred.size != gt.size:
            pred = pred.resize(gt.size, Image.BILINEAR)
        returns = [pred, gt]
        if self.return_predpath:
            returns.append(predpath)
        if self.return_gtpath:
            returns.append(gtpath)
        return returns

    def __len__(self):
        """Return the number of matched prediction/label pairs."""
        return len(self.image_path)
|
|