| from torch.utils import data | |
| import os | |
| from PIL import Image, ImageFile | |
# Let Pillow decode image files whose data is truncated/incomplete on disk
# instead of raising an error when the pixels are loaded. This is a global
# flag on the ImageFile module, so it affects all Pillow image loading in the
# process, not just this dataset.
ImageFile.LOAD_TRUNCATED_IMAGES = True
class EvalDataset(data.Dataset):
    """Dataset of (prediction, ground-truth) grayscale image pairs.

    Both roots are expected to contain one level of sub-directories. A sample
    is every regular file whose ``<subdir>/<filename>`` relative path exists
    under *both* roots; everything else is ignored. Samples are ordered by
    sub-directory name, then file name, so iteration order is reproducible.

    All ground-truth images are decoded in ``__init__`` and cached in
    ``self.labels``; predictions are read from disk in ``__getitem__``.

    Args:
        pred_root: Directory containing the prediction sub-directories.
        label_root: Directory containing the ground-truth sub-directories.
        return_predpath: If true, each sample also carries the prediction path.
        return_gtpath: If true, each sample also carries the ground-truth path.
    """

    def __init__(self, pred_root, label_root, return_predpath=False,
                 return_gtpath=False):
        self.return_predpath = return_predpath
        self.return_gtpath = return_gtpath

        dir_name_list = self._matched_relpaths(pred_root, label_root)
        self.image_path = [os.path.join(pred_root, x) for x in dir_name_list]
        self.label_path = [os.path.join(label_root, x) for x in dir_name_list]

        # Decode every ground-truth mask once and keep it in memory. convert()
        # returns an independent copy, so the source file can be closed here.
        self.labels = []
        for p in self.label_path:
            with Image.open(p) as img:
                self.labels.append(img.convert('L'))

    @staticmethod
    def _matched_relpaths(pred_root, label_root):
        """Return sorted ``subdir/filename`` paths present under both roots."""

        def subdirs(root):
            # Only directories can hold samples; skip stray top-level files.
            return {name for name in os.listdir(root)
                    if os.path.isdir(os.path.join(root, name))}

        def files(folder):
            # Only regular files can be opened as images.
            return {name for name in os.listdir(folder)
                    if os.path.isfile(os.path.join(folder, name))}

        matched = []
        # os.listdir order is arbitrary; sort so sample order is deterministic.
        for idir in sorted(subdirs(pred_root) & subdirs(label_root)):
            common = (files(os.path.join(pred_root, idir))
                      & files(os.path.join(label_root, idir)))
            matched.extend(os.path.join(idir, name) for name in sorted(common))
        return matched

    def __getitem__(self, item):
        """Return ``[pred, gt]`` grayscale images, plus optional paths.

        The prediction is resized (bilinear) to the ground-truth size when the
        two differ. If enabled, the prediction path and then the ground-truth
        path are appended to the returned list.
        """
        predpath = self.image_path[item]
        gtpath = self.label_path[item]
        with Image.open(predpath) as img:
            pred = img.convert('L')
        gt = self.labels[item]
        if pred.size != gt.size:
            pred = pred.resize(gt.size, Image.BILINEAR)
        returns = [pred, gt]
        if self.return_predpath:
            returns.append(predpath)
        if self.return_gtpath:
            returns.append(gtpath)
        return returns

    def __len__(self):
        """Return the number of matched prediction/ground-truth pairs."""
        return len(self.image_path)