# Source: tempsal/src/dataloader_clean.py (uploaded by baharay, "Upload 31 files", commit a3fc39c)
from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import DataLoader
import numpy as np
import torch
import os, cv2
from utils import *
import json
import random
from pycocotools.coco import COCO
class SaliconDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding (image, saliency map, fixation map) samples.

    Each sample is identified by a string id (a file stem). Files are looked
    up as ``<img_dir>/<id><img_exten>``, ``<gt_dir>/<id><exten>`` and
    ``<fix_dir>/<id><exten>``.

    Args:
        img_dir: directory containing the input images.
        gt_dir: directory containing the continuous saliency maps.
        fix_dir: directory containing the fixation maps.
        img_ids: indexable sequence of sample ids (strings).
        exten: file extension of the saliency and fixation maps.
        img_size: ``(height, width)`` the image and saliency map are resized
            to. Defaults to ``(256, 256)``.
        img_exten: file extension of the input images. Defaults to ``'.jpg'``.
    """

    def __init__(self, img_dir, gt_dir, fix_dir, img_ids, exten='.png',
                 img_size=(256, 256), img_exten='.jpg'):
        self.img_dir = img_dir
        self.gt_dir = gt_dir
        self.fix_dir = fix_dir
        self.img_ids = img_ids
        self.exten = exten
        self.img_size = tuple(img_size)
        self.img_exten = img_exten
        # Resize, convert to a [0, 1] tensor, then map each channel to [-1, 1].
        self.img_transform = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5],
                                 [0.5, 0.5, 0.5]),
        ])

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(img, gt, fixations)``:
            - ``img``: the transformed RGB image tensor.
            - ``gt``: a FloatTensor saliency map resized to ``img_size`` and
              scaled into [0, 1].
            - ``fixations``: a binary FloatTensor at the fixation file's
              native resolution (it is NOT resized).

        Raises:
            AssertionError: if the saliency map falls outside [0, 1], or if
                the fixation map does not contain both 0 and 1 values.
        """
        img_id = self.img_ids[idx]
        img_path = os.path.join(self.img_dir, img_id + self.img_exten)
        gt_path = os.path.join(self.gt_dir, img_id + self.exten)
        fix_path = os.path.join(self.fix_dir, img_id + self.exten)

        # Context managers release the underlying file handles promptly.
        # convert() returns a new image, so it stays valid after the close.
        with Image.open(img_path) as im:
            img = self.img_transform(im.convert('RGB'))

        with Image.open(gt_path) as im:
            gt = np.array(im.convert('L')).astype('float')
        height, width = self.img_size
        # cv2.resize takes dsize as (width, height), the reverse of Resize.
        gt = cv2.resize(gt, (width, height))
        # 8-bit maps are stored in 0..255; rescale them into [0, 1].
        if np.max(gt) > 1.0:
            gt = gt / 255.0

        with Image.open(fix_path) as im:
            fixations = np.array(im.convert('L')).astype('float')
        # Any non-zero pixel counts as a fixation.
        fixations = (fixations > 0.5).astype('float')

        assert np.min(gt) >= 0.0 and np.max(gt) <= 1.0, \
            f"saliency map values outside [0, 1]: {gt_path}"
        assert np.min(fixations) == 0.0 and np.max(fixations) == 1.0, \
            f"fixation map must contain both 0 and 1 values: {fix_path}"
        return img, torch.FloatTensor(gt), torch.FloatTensor(fixations)

    def __len__(self):
        """Return the number of sample ids."""
        return len(self.img_ids)