Spaces:
Runtime error
Runtime error
| import torch,os | |
| from torch.utils.data.dataset import Dataset | |
| from PIL import Image | |
| import torchvision.transforms as transforms | |
| import re | |
| from easydict import EasyDict as edict | |
def data_list(img_root, mode):
    """Return the panorama image paths listed in a dataset split file.

    The split file is ``<img_root>/splits/train-19zl.csv`` when ``mode`` is
    ``'train'`` and ``<img_root>/splits/val-19zl.csv`` for any other mode.
    Each non-blank line holds comma-separated fields: the first is the aerial
    image path and the second the panorama path, both relative to
    ``img_root``. Only the panorama paths are returned.

    Args:
        img_root: Dataset root directory (with or without a trailing slash).
        mode: ``'train'`` selects the training split; anything else selects
            the validation split.

    Returns:
        List of panorama paths joined onto ``img_root``, in file order.

    Raises:
        IndexError: if a non-blank line has fewer than two fields.
    """
    split_name = 'train-19zl.csv' if mode == 'train' else 'val-19zl.csv'
    # Use os.path.join for both splits; the validation split used to be built
    # with string concatenation, which broke when img_root had no trailing '/'.
    split_file = os.path.join(img_root, 'splits', split_name)

    panorama_names = []
    with open(split_file) as f:
        for line in f:
            line = line.rstrip('\n')
            if not line.strip():
                # Skip blank lines (e.g. a trailing empty line); they would
                # otherwise raise IndexError when taking the second field.
                continue
            fields = line.split(',')
            # fields[0] is the aerial image name; only the panorama is needed.
            panorama_names.append(fields[1])

    print('length of dataset is: ', len(panorama_names))
    return [os.path.join(img_root, name) for name in panorama_names]
def img_read(img, size=None, datatype='RGB'):
    """Open an image file, optionally resize it, and convert it to a tensor.

    Args:
        img: Path (or file object) accepted by ``PIL.Image.open``.
        size: Optional target size. An ``int`` is expanded to the square
            ``(size, size)``; any other truthy value is handed to PIL's
            ``resize`` unchanged (PIL interprets a pair as (width, height)).
            A falsy value skips resizing.
        datatype: ``'RGB'`` loads a 3-channel RGB image resized with bicubic
            interpolation; any other value loads a single-channel ``'L'``
            image resized with nearest-neighbour (so mask values are not
            blended).

    Returns:
        The result of ``torchvision.transforms.ToTensor()`` on the image.
    """
    is_rgb = datatype == 'RGB'
    picture = Image.open(img).convert('RGB' if is_rgb else "L")
    if size:
        target = (size, size) if type(size) is int else size
        resample = Image.BICUBIC if is_rgb else Image.NEAREST
        picture = picture.resize(size=target, resample=resample)
    return transforms.ToTensor()(picture)
class Dataset(Dataset):
    """Dataset of panorama images paired with aerial images (and sky masks).

    Note: this class deliberately reuses the name ``Dataset``; once defined,
    the module-level name ``Dataset`` refers to this subclass rather than
    ``torch.utils.data.dataset.Dataset``.
    """

    def __init__(self, opt, split='train', sub=None, sty_img=None):
        """Build the list of panorama paths.

        Args:
            opt: Options object. Reads ``opt.data.root`` and ``opt.task`` here,
                ``opt.demo_img`` when ``opt.task == 'test_vid'``, and
                ``opt.sty_img`` when ``sty_img`` is truthy. ``__getitem__``
                later reads further ``opt.data`` fields.
            split: ``'train'`` for the training split file, anything else for
                the validation split (passed to ``data_list`` as ``mode``).
            sub: If truthy, keep only the first ``sub`` panoramas.
            sty_img: If truthy, replace the list with the single panorama
                named by ``opt.sty_img`` (which must end in ``.jpg``).
        """
        self.pano_list = data_list(img_root=opt.data.root, mode=split)
        if sub:
            self.pano_list = self.pano_list[:sub]
        pano_dir = os.path.join(opt.data.root, 'streetview/panos')
        if opt.task == 'test_vid':
            # Demo mode: work on one user-selected panorama only.
            self.pano_list = [os.path.join(pano_dir, opt.demo_img)]
        if sty_img:
            # NOTE(review): the switch is the ``sty_img`` argument but the file
            # name is taken from ``opt.sty_img`` — confirm they are meant to agree.
            assert opt.sty_img.split('.')[-1] == 'jpg'
            self.pano_list = [os.path.join(pano_dir, opt.sty_img)]
        self.opt = opt

    def __len__(self):
        """Return the number of panoramas."""
        return len(self.pano_list)

    def __getitem__(self, index):
        """Load one sample as a dict.

        Keys:
            ``'sat'``: aerial image tensor, resized to ``opt.data.sat_size``.
            ``'pano'``: panorama tensor, resized to ``opt.data.pano_size``.
            ``'paths'``: the panorama file path.
        When ``opt.data.sky_mask`` is truthy, additionally:
            ``'sky_mask'``: single-channel (``'L'``) mask tensor, ``pano_size``.
            ``'sky_histc'``: histogram of the sky-masked panorama; present only
                when ``opt.data.histo_mode`` is ``'grey'``, ``'rgb'`` or ``'RGB'``.
        """
        pano_path = self.pano_list[index]
        data_opt = self.opt.data
        # The aerial image sits at the same relative path under bingmap/19.
        sat_path = pano_path.replace('streetview/panos', 'bingmap/19')
        sample = {
            'sat': img_read(sat_path, size=data_opt.sat_size),
            'pano': img_read(pano_path, size=data_opt.pano_size),
            'paths': pano_path,
        }
        if data_opt.sky_mask:
            sky = img_read(self._sky_mask_path(pano_path),
                           size=data_opt.pano_size, datatype='L')
            sample['sky_mask'] = sky
            histc = self._sky_histc(sample['pano'], sky, data_opt.histo_mode)
            if histc is not None:
                sample['sky_histc'] = histc
        return sample

    @staticmethod
    def _sky_mask_path(pano_path):
        """Map a panorama path to its sky-mask PNG path."""
        sky_path = pano_path.replace('streetview/panos', 'sky_mask')
        # Swap only the file extension. A plain str.replace('jpg', 'png')
        # would also rewrite any 'jpg' inside directory or file names.
        if sky_path.endswith('.jpg'):
            sky_path = sky_path[:-len('.jpg')] + '.png'
        return sky_path

    @staticmethod
    def _sky_histc(pano, sky, histo_mode):
        """Return the histogram of the sky region of ``pano``, or None.

        ``pano`` is multiplied by ``sky``, so pixels outside the mask become 0.
        (This equals the former ``pano*sky + zeros*(1-sky)``.) ``torch.histc``
        is called with its defaults and the first 10 bins are dropped.
        NOTE(review): with default ``min=max=0`` histc spans each tensor's own
        min..max, so the dropped bins are the lowest 10% of that data-relative
        range, not a fixed band near 0 — confirm this is intended.
        """
        masked = pano * sky
        if histo_mode == 'grey':
            return masked.histc()[10:]
        if histo_mode in ['rgb', 'RGB']:
            # Per-channel histograms concatenated in REVERSE channel order
            # (last channel first). This matches the original accumulation
            # loop and is kept so existing consumers see the same layout.
            return torch.cat(
                [masked[c].histc()[10:] for c in reversed(range(len(masked)))],
                dim=0,
            )
        return None