File size: 1,668 Bytes
168ec29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
from torch.utils import data
import os
from PIL import Image, ImageFile
# Global Pillow switch: load image files whose data is truncated instead of
# raising an error while decoding them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
class EvalDataset(data.Dataset):
    """Paired prediction / ground-truth grayscale images for evaluation.

    Both roots are scanned one level deep. A sample is created for every
    file name that appears in same-named sub-directories of ``pred_root``
    and ``label_root``; entries present on only one side are ignored.

    Args:
        pred_root: Directory holding the prediction sub-directories.
        label_root: Directory holding the ground-truth sub-directories.
        return_predpath: When true, ``__getitem__`` also returns the
            prediction file path.
        return_gtpath: When true, ``__getitem__`` also returns the
            ground-truth file path.

    Attributes set by the constructor: ``image_path`` and ``label_path``
    (parallel lists of file paths) and ``labels`` (the ground-truth images,
    opened and converted to mode ``'L'``).
    """

    def __init__(self, pred_root, label_root, return_predpath=False,
                 return_gtpath=False):
        self.return_predpath = return_predpath
        self.return_gtpath = return_gtpath

        # Sets give O(1) membership tests. Sorting makes the sample order
        # reproducible, since os.listdir() returns entries in arbitrary order.
        label_dirs = set(os.listdir(label_root))
        dir_name_list = []
        for idir in sorted(os.listdir(pred_root)):
            if idir not in label_dirs:
                continue
            pred_dir = os.path.join(pred_root, idir)
            label_dir = os.path.join(label_root, idir)
            # A stray file that exists under the same name in both roots
            # (e.g. ".DS_Store") is not a sample folder; calling
            # os.listdir() on it would raise NotADirectoryError.
            if not (os.path.isdir(pred_dir) and os.path.isdir(label_dir)):
                continue
            label_names = set(os.listdir(label_dir))
            dir_name_list.extend(
                os.path.join(idir, iname)
                for iname in sorted(os.listdir(pred_dir))
                if iname in label_names
            )

        self.image_path = [
            os.path.join(pred_root, rel) for rel in dir_name_list]
        self.label_path = [
            os.path.join(label_root, rel) for rel in dir_name_list]
        # Ground-truth images are converted to grayscale once, up front, and
        # kept in memory; predictions are opened on demand in __getitem__.
        self.labels = [Image.open(p).convert('L') for p in self.label_path]

    def __getitem__(self, item):
        """Return the sample at index ``item``.

        The prediction is opened and converted to mode ``'L'``; if its size
        differs from the ground truth it is resized to the ground-truth size
        with bilinear resampling.

        Returns:
            A list ``[pred, gt]``, followed by the prediction path if
            ``return_predpath`` is set and then the ground-truth path if
            ``return_gtpath`` is set.
        """
        predpath = self.image_path[item]
        gtpath = self.label_path[item]
        pred = Image.open(predpath).convert('L')
        gt = self.labels[item]
        if pred.size != gt.size:
            pred = pred.resize(gt.size, Image.BILINEAR)
        returns = [pred, gt]
        if self.return_predpath:
            returns.append(predpath)
        if self.return_gtpath:
            returns.append(gtpath)
        return returns

    def __len__(self):
        """Return the number of matched prediction / ground-truth pairs."""
        return len(self.image_path)
|