RustCoSeg / dataloader.py
HirraA's picture
Upload 30 files
168ec29 verified
from torch.utils import data
import os
from PIL import Image, ImageFile
# Tell Pillow to load truncated image files rather than rejecting them
# (module-wide Pillow setting; affects every image opened after this line).
ImageFile.LOAD_TRUNCATED_IMAGES = True
class EvalDataset(data.Dataset):
    """Paired prediction / ground-truth images for evaluation.

    Both roots are scanned one directory level deep
    (``<root>/<group>/<image>``).  A sample is created for every relative
    path ``<group>/<image>`` that exists as a regular file under *both*
    ``pred_root`` and ``label_root``; entries found on only one side, stray
    files lying next to the group folders, and nested folders are ignored.

    All ground-truth images are decoded once in ``__init__`` and kept in
    memory; predictions are read from disk on every ``__getitem__`` call.

    Args:
        pred_root: Directory holding the predicted maps.
        label_root: Directory holding the ground-truth maps.
        return_predpath: If true, ``__getitem__`` appends the prediction
            file path to its result.
        return_gtpath: If true, ``__getitem__`` appends the ground-truth
            file path to its result.

    Raises:
        OSError: If ``pred_root`` or ``label_root`` cannot be listed.
    """

    def __init__(self, pred_root, label_root,
                 return_predpath=False, return_gtpath=False):
        super().__init__()
        self.return_predpath = return_predpath
        self.return_gtpath = return_gtpath

        # Sets give O(1) membership tests inside the loops below.
        label_dirs = set(os.listdir(label_root))

        rel_paths = []
        # os.listdir() returns entries in arbitrary order; sort so the sample
        # order is reproducible across runs and machines.
        for group in sorted(os.listdir(pred_root)):
            if group not in label_dirs:
                continue
            pred_dir = os.path.join(pred_root, group)
            label_dir = os.path.join(label_root, group)
            # A plain file that exists in both roots (".DS_Store", a results
            # text file, ...) is not a group folder; listing it would raise
            # NotADirectoryError, so skip it.
            if not (os.path.isdir(pred_dir) and os.path.isdir(label_dir)):
                continue

            label_names = set(os.listdir(label_dir))
            for name in sorted(os.listdir(pred_dir)):
                if name not in label_names:
                    continue
                # Only regular files can be opened as images.
                if (os.path.isfile(os.path.join(pred_dir, name))
                        and os.path.isfile(os.path.join(label_dir, name))):
                    rel_paths.append(os.path.join(group, name))

        # Absolute-ish paths of the predictions, index-aligned with label_path.
        self.image_path = [os.path.join(pred_root, rel) for rel in rel_paths]
        # Paths of the matching ground-truth files.
        self.label_path = [os.path.join(label_root, rel) for rel in rel_paths]
        # Ground-truth images, decoded up front as 8-bit grayscale ('L').
        self.labels = [self._load_gray(path) for path in self.label_path]

    @staticmethod
    def _load_gray(path):
        """Open *path* and return it as a mode ``'L'`` image, closing the file."""
        with Image.open(path) as img:
            # convert() returns a new image object, so the result stays valid
            # after the source file is closed.
            return img.convert('L')

    def __getitem__(self, item):
        """Return ``[pred, gt]`` for sample *item*, plus any requested paths.

        The prediction is resized (bilinear) to the ground-truth size when
        the two differ.  The result is a list: ``[pred, gt]``, followed by
        the prediction path if ``return_predpath`` is set, followed by the
        ground-truth path if ``return_gtpath`` is set.
        """
        predpath = self.image_path[item]
        gtpath = self.label_path[item]

        pred = self._load_gray(predpath)
        gt = self.labels[item]
        if pred.size != gt.size:
            pred = pred.resize(gt.size, Image.BILINEAR)

        returns = [pred, gt]
        if self.return_predpath:
            returns.append(predpath)
        if self.return_gtpath:
            returns.append(gtpath)
        return returns

    def __len__(self):
        """Return the number of matched prediction / ground-truth pairs."""
        return len(self.image_path)