| from torchvision import transforms |
| import torchvision.transforms.functional as TF |
| from PIL import Image |
| from torch.utils.data import DataLoader |
| import numpy as np |
| import torch |
| import os, cv2 |
| from utils import * |
| import json |
| import random |
| from pycocotools.coco import COCO |
|
|
| |
class SaliconDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(image, saliency map, fixation map)`` triples.

    For each id in ``img_ids`` it reads:

    * ``<img_dir>/<id>.jpg`` as the RGB input image,
    * ``<gt_dir>/<id><exten>`` as the continuous saliency (ground-truth) map,
    * ``<fix_dir>/<id><exten>`` as the fixation map.

    This is a ``Dataset``, not a ``DataLoader``. Wrap it in
    ``torch.utils.data.DataLoader`` to batch/shuffle it.
    """

    def __init__(self, img_dir, gt_dir, fix_dir, img_ids, exten='.png',
                 img_size=(256, 256)):
        """
        Args:
            img_dir: directory containing the ``.jpg`` input images.
            gt_dir: directory containing the saliency maps.
            fix_dir: directory containing the fixation maps.
            img_ids: indexable sequence of string ids (file stems); each id
                is concatenated with a file extension to build paths.
            exten: extension shared by the saliency and fixation map files.
            img_size: ``(height, width)`` that the image and the saliency map
                are resized to. An int means a square size. Defaults to
                ``(256, 256)``.
        """
        self.img_dir = img_dir
        self.gt_dir = gt_dir
        self.fix_dir = fix_dir
        self.img_ids = img_ids
        self.exten = exten
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        self.img_size = tuple(img_size)
        self.img_transform = transforms.Compose([
            transforms.Resize(self.img_size),  # (height, width)
            transforms.ToTensor(),
            # Maps ToTensor's [0, 1] range to [-1, 1] per channel.
            transforms.Normalize([0.5, 0.5, 0.5],
                                 [0.5, 0.5, 0.5]),
        ])

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(img, gt, fixations)``:

            * ``img``: image tensor produced by ``self.img_transform``.
            * ``gt``: float32 tensor of shape ``img_size`` with values in [0, 1].
            * ``fixations``: float32 binary (0/1) tensor at the fixation
              file's native resolution.

            NOTE(review): fixations are not resized, presumably so that
            single-pixel fixations are not blurred away. Default ``DataLoader``
            collation then requires all fixation files to share one size —
            confirm against the dataset.

        Raises:
            ValueError: if the saliency map leaves [0, 1], or if the fixation
                map does not contain both 0 and 1 pixels.
        """
        img_id = self.img_ids[idx]
        img_path = os.path.join(self.img_dir, img_id + '.jpg')
        gt_path = os.path.join(self.gt_dir, img_id + self.exten)
        fix_path = os.path.join(self.fix_dir, img_id + self.exten)
        height, width = self.img_size

        # ``with`` closes each file handle promptly, even if decoding fails.
        with Image.open(img_path) as pil_img:
            img = self.img_transform(pil_img.convert('RGB'))

        with Image.open(gt_path) as pil_gt:
            gt = np.array(pil_gt.convert('L'), dtype=np.float64)
        # cv2.resize takes dsize as (width, height), unlike torchvision.
        gt = cv2.resize(gt, (width, height))
        # 8-bit maps are rescaled to [0, 1]. A map whose max is already <= 1
        # is treated as pre-normalised and left unchanged.
        if np.max(gt) > 1.0:
            gt = gt / 255.0

        with Image.open(fix_path) as pil_fix:
            fix_raw = np.array(pil_fix.convert('L'), dtype=np.float64)
        # Grayscale values are integers 0..255, so any non-zero pixel counts
        # as a fixation.
        fixations = (fix_raw > 0.5).astype(np.float64)

        # Explicit checks rather than ``assert``, so they still run under -O.
        if not (np.min(gt) >= 0.0 and np.max(gt) <= 1.0):
            raise ValueError(
                f"saliency map {gt_path!r} has values outside [0, 1]")
        if not (np.min(fixations) == 0.0 and np.max(fixations) == 1.0):
            raise ValueError(
                f"fixation map {fix_path!r} must contain both 0 and 1 pixels")

        return (img,
                torch.as_tensor(gt, dtype=torch.float32),
                torch.as_tensor(fixations, dtype=torch.float32))

    def __len__(self):
        """Return the number of samples, i.e. ``len(self.img_ids)``."""
        return len(self.img_ids)
|
|