| from torchvision import transforms | |
| import torchvision.transforms.functional as TF | |
| from PIL import Image | |
| from torch.utils.data import DataLoader | |
| import numpy as np | |
| import torch | |
| import os, cv2 | |
| from utils import * | |
| import json | |
| import random | |
| from pycocotools.coco import COCO | |
class SaliconDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding (image, saliency map, fixation map) triples.

    This is a dataset (it implements ``__getitem__`` and ``__len__``) and is
    meant to be wrapped by a ``torch.utils.data.DataLoader``; it therefore
    derives from ``Dataset`` rather than from ``DataLoader`` itself.

    Files are located by image id:

    * image:        ``<img_dir>/<img_id>.jpg``
    * saliency map: ``<gt_dir>/<img_id><exten>``
    * fixation map: ``<fix_dir>/<img_id><exten>``

    Args:
        img_dir: Directory holding the ``.jpg`` input images.
        gt_dir: Directory holding the ground-truth saliency maps.
        fix_dir: Directory holding the fixation maps.
        img_ids: Sequence of image ids (file names without extension).
        exten: File extension (including the dot) of the saliency and
            fixation maps.
    """

    def __init__(self, img_dir, gt_dir, fix_dir, img_ids, exten='.png'):
        super().__init__()
        self.img_dir = img_dir
        self.gt_dir = gt_dir
        self.fix_dir = fix_dir
        self.img_ids = img_ids
        self.exten = exten
        # Resize to 256x256, convert to a tensor, then normalize every
        # channel with mean 0.5 and std 0.5.
        self.img_transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5],
                                 [0.5, 0.5, 0.5]),
        ])

    @staticmethod
    def _read_grayscale(path):
        """Load the image at *path* in 'L' mode as a float NumPy array."""
        # The context manager closes the file handle deterministically
        # instead of leaving it to garbage collection.
        with Image.open(path) as handle:
            return np.array(handle.convert('L')).astype('float')

    def __getitem__(self, idx):
        """Return the sample at position *idx*.

        Returns:
            A tuple ``(img, gt, fixations)``:

            * ``img``: the RGB image after ``self.img_transform``.
            * ``gt``: float32 tensor, the saliency map resized to 256x256
              and divided by 255 when its maximum exceeds 1.
            * ``fixations``: float32 tensor of zeros and ones (pixels of the
              fixation map greater than 0.5), at the map's own resolution.

        Raises:
            AssertionError: If the saliency map falls outside ``[0, 1]`` or
                the fixation map does not contain both a 0 and a 1.
        """
        img_id = self.img_ids[idx]
        img_path = os.path.join(self.img_dir, img_id + '.jpg')
        gt_path = os.path.join(self.gt_dir, img_id + self.exten)
        fix_path = os.path.join(self.fix_dir, img_id + self.exten)

        with Image.open(img_path) as raw:
            img = self.img_transform(raw.convert('RGB'))

        gt = cv2.resize(self._read_grayscale(gt_path), (256, 256))
        # Maps stored as 8-bit images are brought into [0, 1]; maps whose
        # maximum is already <= 1 are left untouched.
        if np.max(gt) > 1.0:
            gt = gt / 255.0

        # NOTE(review): unlike ``gt`` the fixation map is not resized, so
        # batching with the default collate presumably relies on all
        # fixation maps sharing one resolution - confirm against the data.
        fixations = (self._read_grayscale(fix_path) > 0.5).astype('float')

        assert np.min(gt) >= 0.0 and np.max(gt) <= 1.0, (
            'saliency map out of [0, 1]: ' + gt_path)
        assert np.min(fixations) == 0.0 and np.max(fixations) == 1.0, (
            'fixation map is not a non-trivial binary mask: ' + fix_path)

        return (
            img,
            torch.as_tensor(gt, dtype=torch.float32),
            torch.as_tensor(fixations, dtype=torch.float32),
        )

    def __len__(self):
        """Return the number of image ids in the dataset."""
        return len(self.img_ids)