| |
| |
|
|
| |
| |
| |
|
|
| import sys, os |
| import os.path as osp |
| import pickle |
| import numpy as np |
| from PIL import Image |
| import json |
| import h5py |
| from glob import glob |
| import cv2 |
|
|
| import torch |
| from torch.utils import data |
|
|
| from .augmentor import StereoAugmentor |
|
|
|
|
|
|
# Root directory on disk for each supported dataset.
# (the doubled slashes in some paths are harmless for open()/os.path)
dataset_to_root = {
    'CREStereo': './data/stereoflow//crenet_stereo_trainset/stereo_trainset/crestereo/',
    'SceneFlow': './data/stereoflow//SceneFlow/',
    'ETH3DLowRes': './data/stereoflow/eth3d_lowres/',
    'Booster': './data/stereoflow/booster_gt/',
    'Middlebury2021': './data/stereoflow/middlebury/2021/data/',
    'Middlebury2014': './data/stereoflow/middlebury/2014/',
    'Middlebury2006': './data/stereoflow/middlebury/2006/',
    'Middlebury2005': './data/stereoflow/middlebury/2005/train/',
    'MiddleburyEval3': './data/stereoflow/middlebury/MiddEval3/',
    'Spring': './data/stereoflow/spring/',
    'Kitti15': './data/stereoflow/kitti-stereo-2015/',
    'Kitti12': './data/stereoflow/kitti-stereo-2012/',
}
# Where the per-dataset {split: pairnames} dicts are cached as pickle files
# (built once by each dataset's _build_cache, see StereoDataset._load_or_build_cache).
cache_dir = "./data/stereoflow/datasets_stereo_cache/"
|
|
|
|
# ImageNet-1k channel statistics, used to normalize RGB inputs.
in1k_mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
in1k_std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
def img_to_tensor(img):
    """Convert an HxWx3 uint8 image into a normalized 3xHxW float tensor."""
    tensor = torch.from_numpy(img).permute(2, 0, 1).float()
    tensor = tensor / 255.
    return (tensor - in1k_mean) / in1k_std
def disp_to_tensor(disp):
    """Wrap an HxW disparity map as a 1xHxW torch tensor."""
    return torch.from_numpy(disp).unsqueeze(0)
|
|
class StereoDataset(data.Dataset):
    """Base class for stereo-pair datasets.

    Subclasses must implement:
      * _prepare_data(): set self.name, the pairname_to_* path functions and
        self.load_disparity (called once from __init__),
      * _build_cache(): return a dict mapping each split name to its list of
        pairnames (called only the first time, then pickled under cache_dir).
    """

    def __init__(self, split, augmentor=False, crop_size=None, totensor=True):
        """
        Args:
            split: dataset-specific split name (e.g. 'train', 'subval').
            augmentor: whether to apply data augmentation.
            crop_size: crop size for augmentation (requires augmentor=True).
            totensor: convert outputs to normalized torch tensors.
        """
        self.split = split
        # crop_size is only meaningful when augmentation is enabled
        if not augmentor: assert crop_size is None
        if crop_size: assert augmentor
        self.crop_size = crop_size
        self.augmentor_str = augmentor
        self.augmentor = StereoAugmentor(crop_size) if augmentor else None
        self.totensor = totensor
        self.rmul = 1  # replication factor applied via __rmul__
        self.has_constant_resolution = True  # subclasses may set this to False
        self._prepare_data()
        self._load_or_build_cache()

    def _prepare_data(self):
        """
        to be defined for each dataset

        NOTE(fix): this abstract hook was named prepare_data, but __init__
        calls self._prepare_data() — so the NotImplementedError was unreachable
        and a subclass missing the method died with AttributeError instead.
        """
        raise NotImplementedError

    def __len__(self):
        return len(self.pairnames)

    def __getitem__(self, index):
        """Return (left image, right image, disparity, pairname) for one pair."""
        pairname = self.pairnames[index]

        # resolve filenames for this pair
        Limgname = self.pairname_to_Limgname(pairname)
        Rimgname = self.pairname_to_Rimgname(pairname)
        Ldispname = self.pairname_to_Ldispname(pairname) if self.pairname_to_Ldispname is not None else None

        # load images (and ground-truth disparity when available)
        Limg = _read_img(Limgname)
        Rimg = _read_img(Rimgname)
        disp = self.load_disparity(Ldispname) if Ldispname is not None else None

        # sanity check: loaders mark invalid pixels with +inf, so all values
        # should be strictly positive (Spring is a known exception)
        if disp is not None: assert np.all(disp>0) or self.name=="Spring", (self.name, pairname, Ldispname)

        if self.augmentor is not None:
            Limg, Rimg, disp = self.augmentor(Limg, Rimg, disp, self.name)

        if self.totensor:
            Limg = img_to_tensor(Limg)
            Rimg = img_to_tensor(Rimg)
            if disp is None:
                # empty tensor so the default collate_fn can still batch it
                disp = torch.tensor([])
            else:
                disp = disp_to_tensor(disp)

        return Limg, Rimg, disp, str(pairname)

    def __rmul__(self, v):
        # replicate the pair list v times (used to balance mixed datasets)
        self.rmul *= v
        self.pairnames = v * self.pairnames
        return self

    def __str__(self):
        return f'{self.__class__.__name__}_{self.split}'

    def __repr__(self):
        s = f'{self.__class__.__name__}(split={self.split}, augmentor={self.augmentor_str}, crop_size={str(self.crop_size)}, totensor={self.totensor})'
        if self.rmul==1:
            s+=f'\n\tnum pairs: {len(self.pairnames)}'
        else:
            s+=f'\n\tnum pairs: {len(self.pairnames)} ({len(self.pairnames)//self.rmul}x{self.rmul})'
        return s

    def _set_root(self):
        """Set self.root from dataset_to_root and check it exists on disk."""
        self.root = dataset_to_root[self.name]
        assert os.path.isdir(self.root), f"could not find root directory for dataset {self.name}: {self.root}"

    def _load_or_build_cache(self):
        """Load the pairname lists from the pickle cache, building it on first use."""
        cache_file = osp.join(cache_dir, self.name+'.pkl')
        if osp.isfile(cache_file):
            with open(cache_file, 'rb') as fid:
                self.pairnames = pickle.load(fid)[self.split]
        else:
            tosave = self._build_cache()
            os.makedirs(cache_dir, exist_ok=True)
            with open(cache_file, 'wb') as fid:
                pickle.dump(tosave, fid)
            self.pairnames = tosave[self.split]
| |
class CREStereoDataset(StereoDataset):
    """CREStereo synthetic training set (200k rendered pairs)."""

    def _prepare_data(self):
        self.name = 'CREStereo'
        self._set_root()
        assert self.split in ['train']
        def _left(pairname): return osp.join(self.root, pairname+'_left.jpg')
        def _right(pairname): return osp.join(self.root, pairname+'_right.jpg')
        def _disp(pairname): return osp.join(self.root, pairname+'_left.disp.png')
        self.pairname_to_Limgname = _left
        self.pairname_to_Rimgname = _right
        self.pairname_to_Ldispname = _disp
        self.pairname_to_str = lambda pairname: pairname
        self.load_disparity = _read_crestereo_disp

    def _build_cache(self):
        # pairs are '<subdir>/<prefix>' for every *_left.jpg found on disk
        allpairs = []
        for sub in sorted(os.listdir(self.root)):
            for fname in sorted(os.listdir(self.root+'/'+sub)):
                if fname.endswith('_left.jpg'):
                    allpairs.append(sub+'/'+fname[:-len('_left.jpg')])
        assert len(allpairs)==200000, "incorrect parsing of pairs in CreStereo"
        return {'train': allpairs}
| |
class SceneFlowDataset(StereoDataset):
    """SceneFlow synthetic dataset (Driving + Monkaa + FlyingThings).

    Splits encode the render pass: finalpass, cleanpass, or allpass (both).
    """

    def _prepare_data(self):
        self.name = "SceneFlow"
        self._set_root()
        assert self.split in ['train_finalpass','train_cleanpass','train_allpass','test_finalpass','test_cleanpass','test_allpass','test1of100_cleanpass','test1of100_finalpass']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname).replace('/left/','/right/')
        self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname).replace('/frames_finalpass/','/disparity/').replace('/frames_cleanpass/','/disparity/')[:-4]+'.pfm'
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_sceneflow_disp

    def _build_cache(self):
        def _relative_glob(pattern, expected):
            # glob below self.root and strip the root prefix
            found = [p[len(self.root):] for p in sorted(glob(self.root+pattern))]
            assert len(found) == expected, "incorrect parsing of pairs in SceneFlow"
            return found

        trainpairs = _relative_glob('Driving/frames_finalpass/*/*/*/left/*.png', 4400)
        trainpairs += _relative_glob('Monkaa/frames_finalpass/*/left/*.png', 8664)
        trainpairs += _relative_glob('FlyingThings/frames_finalpass/TRAIN/*/*/left/*.png', 22390)
        assert len(trainpairs) == 35454, "incorrect parsing of pairs in SceneFlow"
        testpairs = _relative_glob('FlyingThings/frames_finalpass/TEST/*/*/left/*.png', 4370)
        test1of100pairs = testpairs[::100]
        assert len(test1of100pairs) == 44, "incorrect parsing of pairs in SceneFlow"

        def _cleanpass(pairs):
            return [p.replace('frames_finalpass','frames_cleanpass') for p in pairs]

        tosave = {'train_finalpass': trainpairs,
                  'train_cleanpass': _cleanpass(trainpairs),
                  'test_finalpass': testpairs,
                  'test_cleanpass': _cleanpass(testpairs),
                  'test1of100_finalpass': test1of100pairs,
                  'test1of100_cleanpass': _cleanpass(test1of100pairs),
                 }
        tosave['train_allpass'] = tosave['train_finalpass']+tosave['train_cleanpass']
        tosave['test_allpass'] = tosave['test_finalpass']+tosave['test_cleanpass']
        return tosave
| |
class Md21Dataset(StereoDataset):
    """Middlebury 2021 dataset (ambient illumination captures)."""

    def _prepare_data(self):
        self.name = "Middlebury2021"
        self._set_root()
        assert self.split in ['train','subtrain','subval']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname.replace('/im0','/im1'))
        self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp0.pfm')
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_middlebury_disp

    def _build_cache(self):
        seqs = sorted(os.listdir(self.root))
        trainpairs = []
        for seq in seqs:
            for lighting in sorted(os.listdir(osp.join(self.root,seq,'ambient'))):
                for fname in sorted(os.listdir(osp.join(self.root,seq,'ambient',lighting))):
                    if fname.startswith('im0'):
                        trainpairs.append(seq+'/ambient/'+lighting+'/'+fname)
        assert len(trainpairs)==355
        # the last two sequences (alphabetically) are held out for validation
        heldout = seqs[-2:]
        subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in heldout)]
        subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in heldout)]
        assert len(subtrainpairs)==335 and len(subvalpairs)==20, "incorrect parsing of pairs in Middlebury 2021"
        return {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs}
|
|
class Md14Dataset(StereoDataset):
    """Middlebury 2014 dataset; each scene has 3 right-image variants (im1/im1E/im1L)."""

    def _prepare_data(self):
        self.name = "Middlebury2014"
        self._set_root()
        assert self.split in ['train','subtrain','subval']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'im0.png')
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname)
        self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'disp0.pfm')
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_middlebury_disp
        self.has_constant_resolution = False

    def _build_cache(self):
        seqs = sorted(os.listdir(self.root))
        # one pair per right-image variant (default / Exposure / Lighting)
        trainpairs = [s+'/'+im for s in seqs for im in ('im1.png','im1E.png','im1L.png')]
        assert len(trainpairs)==138
        valseqs = ['Umbrella-imperfect','Vintage-perfect']
        assert all(s in seqs for s in valseqs)
        subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in valseqs)]
        subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)]
        assert len(subtrainpairs)==132 and len(subvalpairs)==6, "incorrect parsing of pairs in Middlebury 2014"
        return {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs}
|
|
class Md06Dataset(StereoDataset):
    """Middlebury 2006 dataset: 3 illuminations x 3 exposures per scene.

    Disparities are 8-bit PNGs (disp1.png); zeros (invalid) are mapped to
    +inf by _read_middlebury20052006_disp.
    """

    def _prepare_data(self):
        self.name = "Middlebury2006"
        self._set_root()
        assert self.split in ['train','subtrain','subval']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'view5.png')
        self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp1.png')
        # fix: pairname_to_str was missing here although every other dataset
        # defines it; use the same form as Md05Dataset, which shares this layout
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_middlebury20052006_disp
        self.has_constant_resolution = False

    def _build_cache(self):
        seqs = sorted(os.listdir(self.root))
        trainpairs = []
        for s in seqs:
            for i in ['Illum1','Illum2','Illum3']:
                for e in ['Exp0','Exp1','Exp2']:
                    trainpairs.append(osp.join(s,i,e,'view1.png'))
        assert len(trainpairs)==189
        valseqs = ['Rocks1','Wood2']
        assert all(s in seqs for s in valseqs)
        subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in valseqs)]
        subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)]
        assert len(subtrainpairs)==171 and len(subvalpairs)==18, "incorrect parsing of pairs in Middlebury 2006"
        tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs}
        return tosave
|
|
class Md05Dataset(StereoDataset):
    """Middlebury 2005 dataset: 3 illuminations x 3 exposures per scene."""

    def _prepare_data(self):
        self.name = "Middlebury2005"
        self._set_root()
        assert self.split in ['train','subtrain','subval']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'view5.png')
        self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp1.png')
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_middlebury20052006_disp

    def _build_cache(self):
        seqs = sorted(os.listdir(self.root))
        trainpairs = [osp.join(s, illum, exp, 'view1.png')
                      for s in seqs
                      for illum in ('Illum1','Illum2','Illum3')
                      for exp in ('Exp0','Exp1','Exp2')]
        assert len(trainpairs)==54, "incorrect parsing of pairs in Middlebury 2005"
        valseqs = ['Reindeer']
        assert all(s in seqs for s in valseqs)
        subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in valseqs)]
        subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)]
        assert len(subtrainpairs)==45 and len(subvalpairs)==9, "incorrect parsing of pairs in Middlebury 2005"
        return {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs}
| |
class MdEval3Dataset(StereoDataset):
    """Middlebury Eval v3 benchmark at full ('F'), half ('H') or quarter ('Q') resolution.

    The resolution is encoded as a suffix of the split name
    (e.g. 'train_half'); full/half splits redirect self.root to the
    MiddEval3_F / MiddEval3_H copies. Test pairs have no ground truth.
    """

    def _prepare_data(self):
        self.name = "MiddleburyEval3"
        self._set_root()
        assert self.split in [s+'_'+r for s in ['train','subtrain','subval','test','all'] for r in ['full','half','quarter']]
        # the default root holds the quarter-resolution data; switch for F/H
        if self.split.endswith('_full'):
            self.root = self.root.replace('/MiddEval3','/MiddEval3_F')
        elif self.split.endswith('_half'):
            self.root = self.root.replace('/MiddEval3','/MiddEval3_H')
        else:
            assert self.split.endswith('_quarter')
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname, 'im0.png')
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname, 'im1.png')
        # no ground-truth disparity for test pairs
        self.pairname_to_Ldispname = lambda pairname: None if pairname.startswith('test') else osp.join(self.root, pairname, 'disp0GT.pfm')
        self.pairname_to_str = lambda pairname: pairname
        self.load_disparity = _read_middlebury_disp
        # metadata used when writing official benchmark submissions
        self.submission_methodname = "CroCo-Stereo"
        self.submission_sresolution = 'F' if self.split.endswith('_full') else ('H' if self.split.endswith('_half') else 'Q')

    def _build_cache(self):
        # the same pair lists are reused for all three resolutions
        trainpairs = ['train/'+s for s in sorted(os.listdir(self.root+'train/'))]
        testpairs = ['test/'+s for s in sorted(os.listdir(self.root+'test/'))]
        # the last training scene is held out as the validation split
        subvalpairs = trainpairs[-1:]
        subtrainpairs = trainpairs[:-1]
        allpairs = trainpairs+testpairs
        assert len(trainpairs)==15 and len(testpairs)==15 and len(subvalpairs)==1 and len(subtrainpairs)==14 and len(allpairs)==30, "incorrect parsing of pairs in Middlebury Eval v3"
        tosave = {}
        for r in ['full','half','quarter']:
            tosave.update(**{'train_'+r: trainpairs, 'subtrain_'+r: subtrainpairs, 'subval_'+r: subvalpairs, 'test_'+r: testpairs, 'all_'+r: allpairs})
        return tosave

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Save one predicted disparity map and its runtime in the official submission layout."""
        assert prediction.ndim==2
        assert prediction.dtype==np.float32
        # layout: <outdir>/training<R>/<scene>/disp0<method>.pfm ('train' -> 'training')
        outfile = os.path.join(outdir, pairname.split('/')[0].replace('train','training')+self.submission_sresolution, pairname.split('/')[1], 'disp0'+self.submission_methodname+'.pfm')
        os.makedirs( os.path.dirname(outfile), exist_ok=True)
        writePFM(outfile, prediction)
        timefile = os.path.join( os.path.dirname(outfile), "time"+self.submission_methodname+'.txt')
        with open(timefile, 'w') as fid:
            fid.write(str(time))

    def finalize_submission(self, outdir):
        """Zip the populated output directory into the file to upload."""
        cmd = f'cd {outdir}/; zip -r "{self.submission_methodname}.zip" .'
        print(cmd)
        os.system(cmd)
        print(f'Done. Submission file at {outdir}/{self.submission_methodname}.zip')
|
|
class ETH3DLowResDataset(StereoDataset):
    """ETH3D low-resolution two-view benchmark."""

    def _prepare_data(self):
        self.name = "ETH3DLowRes"
        self._set_root()
        assert self.split in ['train','test','subtrain','subval','all']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname, 'im0.png')
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname, 'im1.png')
        if self.split=='test':
            # no ground truth at all for the test split
            self.pairname_to_Ldispname = None
        else:
            # the 'all' split mixes train and test pairs; test ones have no gt
            self.pairname_to_Ldispname = lambda pairname: None if pairname.startswith('test/') else osp.join(self.root, pairname.replace('train/','train_gt/'), 'disp0GT.pfm')
        self.pairname_to_str = lambda pairname: pairname
        self.load_disparity = _read_eth3d_disp
        self.has_constant_resolution = False

    def _build_cache(self):
        trainpairs = ['train/' + s for s in sorted(os.listdir(self.root+'train/'))]
        testpairs = ['test/' + s for s in sorted(os.listdir(self.root+'test/'))]
        assert len(trainpairs) == 27 and len(testpairs) == 20, "incorrect parsing of pairs in ETH3D Low Res"
        # three hand-picked scenes are held out for validation
        subvalpairs = ['train/delivery_area_3s','train/electro_3l','train/playground_3l']
        assert all(p in trainpairs for p in subvalpairs)
        subtrainpairs = [p for p in trainpairs if p not in subvalpairs]
        assert len(subvalpairs)==3 and len(subtrainpairs)==24, "incorrect parsing of pairs in ETH3D Low Res"
        return {'train': trainpairs, 'test': testpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs, 'all': trainpairs+testpairs}

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Save one predicted map plus its runtime file in the benchmark layout."""
        assert prediction.ndim==2
        assert prediction.dtype==np.float32
        scene = pairname.split('/')[1]
        outfile = os.path.join(outdir, 'low_res_two_view', scene+'.pfm')
        os.makedirs(os.path.dirname(outfile), exist_ok=True)
        writePFM(outfile, prediction)
        with open(outfile[:-4]+'.txt', 'w') as fid:
            fid.write('runtime '+str(time))

    def finalize_submission(self, outdir):
        """Zip the low_res_two_view directory into the file to upload."""
        cmd = f'cd {outdir}/; zip -r "eth3d_low_res_two_view_results.zip" low_res_two_view'
        print(cmd)
        os.system(cmd)
        print(f'Done. Submission file at {outdir}/eth3d_low_res_two_view_results.zip')
|
|
class BoosterDataset(StereoDataset):
    """Booster dataset (balanced subset); camera_00 is left, camera_02 is right."""

    def _prepare_data(self):
        self.name = "Booster"
        self._set_root()
        assert self.split in ['train_balanced','test_balanced','subtrain_balanced','subval_balanced']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname).replace('/camera_00/','/camera_02/')
        # a single ground-truth disparity is shared by all images of a sequence
        self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, osp.dirname(pairname), '../disp_00.npy')
        self.pairname_to_str = lambda pairname: pairname[:-4].replace('/camera_00/','/')
        self.load_disparity = _read_booster_disp

    def _build_cache(self):
        def _list_pairs(split_dir, seqs):
            return [split_dir+'/'+s+'/camera_00/'+imname
                    for s in seqs
                    for imname in sorted(os.listdir(self.root+split_dir+'/'+s+'/camera_00/'))]
        trainseqs = sorted(os.listdir(self.root+'train/balanced'))
        trainpairs = _list_pairs('train/balanced', trainseqs)
        testpairs = _list_pairs('test/balanced', sorted(os.listdir(self.root+'test/balanced')))
        assert len(trainpairs) == 228 and len(testpairs) == 191
        # the last two training sequences are held out for validation
        subtrainpairs = [p for p in trainpairs if any(s in p for s in trainseqs[:-2])]
        subvalpairs = [p for p in trainpairs if any(s in p for s in trainseqs[-2:])]
        return {'train_balanced': trainpairs, 'test_balanced': testpairs, 'subtrain_balanced': subtrainpairs, 'subval_balanced': subvalpairs,}
| |
class SpringDataset(StereoDataset):
    """Spring dataset (rendered movie scenes).

    At test time both left and right frames act as reference views; the
    ground-truth disparity is stored in hdf5 (.dsp5) files and may contain
    NaNs (mapped to +inf by _read_hdf5_disp).
    """

    def _prepare_data(self):
        self.name = "Spring"
        self._set_root()
        assert self.split in ['train', 'test', 'subtrain', 'subval']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'.png')
        # swap frame_left <-> frame_right using a temporary placeholder
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname+'.png').replace('frame_right','<frame_right>').replace('frame_left','frame_right').replace('<frame_right>','frame_left')
        self.pairname_to_Ldispname = lambda pairname: None if pairname.startswith('test') else osp.join(self.root, pairname+'.dsp5').replace('frame_left','disp1_left').replace('frame_right','disp1_right')
        self.pairname_to_str = lambda pairname: pairname
        self.load_disparity = _read_hdf5_disp

    def _build_cache(self):
        trainseqs = sorted(os.listdir( osp.join(self.root,'train')))
        trainpairs = [osp.join('train',s,'frame_left',f[:-4]) for s in trainseqs for f in sorted(os.listdir(osp.join(self.root,'train',s,'frame_left')))]
        testseqs = sorted(os.listdir( osp.join(self.root,'test')))
        testpairs = [osp.join('test',s,'frame_left',f[:-4]) for s in testseqs for f in sorted(os.listdir(osp.join(self.root,'test',s,'frame_left')))]
        # for the test split, right frames are also used as reference views
        testpairs += [p.replace('frame_left','frame_right') for p in testpairs]
        """maxnorm = {'0001': 32.88, '0002': 228.5, '0004': 298.2, '0005': 142.5, '0006': 113.6, '0007': 27.3, '0008': 554.5, '0009': 155.6, '0010': 126.1, '0011': 87.6, '0012': 303.2, '0013': 24.14, '0014': 82.56, '0015': 98.44, '0016': 156.9, '0017': 28.17, '0018': 21.03, '0020': 178.0, '0021': 58.06, '0022': 354.2, '0023': 8.79, '0024': 97.06, '0025': 55.16, '0026': 91.9, '0027': 156.6, '0030': 200.4, '0032': 58.66, '0033': 373.5, '0036': 149.4, '0037': 5.625, '0038': 37.0, '0039': 12.2, '0041': 453.5, '0043': 457.0, '0044': 379.5, '0045': 161.8, '0047': 105.44} # => let'use 0041"""
        # sequence 0041 is held out for validation
        subtrainpairs = [p for p in trainpairs if p.split('/')[1]!='0041']
        subvalpairs = [p for p in trainpairs if p.split('/')[1]=='0041']
        assert len(trainpairs)==5000 and len(testpairs)==2000 and len(subtrainpairs)==4904 and len(subvalpairs)==96, "incorrect parsing of pairs in Spring"
        tosave = {'train': trainpairs, 'test': testpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs}
        return tosave

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Save one predicted disparity map as a .dsp5 file in the submission layout."""
        assert prediction.ndim==2
        assert prediction.dtype==np.float32
        outfile = os.path.join(outdir, pairname+'.dsp5').replace('frame_left','disp1_left').replace('frame_right','disp1_right')
        os.makedirs( os.path.dirname(outfile), exist_ok=True)
        writeDsp5File(prediction, outfile)

    def finalize_submission(self, outdir):
        """Run the official subsampling executable over the predicted maps."""
        assert self.split=='test'
        # BUGFIX: this string was missing its f-prefix, so the literal text
        # "{self.root}/disp1_subsampling" was tested and the executable was
        # never found even when present
        exe = f"{self.root}/disp1_subsampling"
        if os.path.isfile(exe):
            cmd = f'cd "{outdir}/test"; {exe} .'
            print(cmd)
            os.system(cmd)
        else:
            print('Could not find disp1_subsampling executable for submission.')
            print('Please download it and run:')
            print(f'cd "{outdir}/test"; <disp1_subsampling_exe> .')
|
|
class Kitti12Dataset(StereoDataset):
    """KITTI 2012 stereo benchmark."""

    def _prepare_data(self):
        self.name = "Kitti12"
        self._set_root()
        assert self.split in ['train','test']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'_10.png')
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname.replace('/colored_0/','/colored_1/')+'_10.png')
        if self.split=='test':
            self.pairname_to_Ldispname = None
        else:
            self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.replace('/colored_0/','/disp_occ/')+'_10.png')
        self.pairname_to_str = lambda pairname: pairname.replace('/colored_0/','/')
        self.load_disparity = _read_kitti_disp

    def _build_cache(self):
        trainseqs = [f"training/colored_0/{i:06d}" for i in range(194)]
        testseqs = [f"testing/colored_0/{i:06d}" for i in range(195)]
        assert len(trainseqs)==194 and len(testseqs)==195, "incorrect parsing of pairs in Kitti12"
        return {'train': trainseqs, 'test': testseqs}

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Save one predicted disparity as a KITTI-encoded 16-bit PNG."""
        assert prediction.ndim==2
        assert prediction.dtype==np.float32
        outfile = os.path.join(outdir, pairname.split('/')[-1]+'_10.png')
        os.makedirs(os.path.dirname(outfile), exist_ok=True)
        # KITTI encodes disparities as uint16 with a factor of 256
        Image.fromarray((prediction * 256).astype('uint16')).save(outfile)

    def finalize_submission(self, outdir):
        """Zip the output directory into the file to upload."""
        assert self.split=='test'
        cmd = f'cd {outdir}/; zip -r "kitti12_results.zip" .'
        print(cmd)
        os.system(cmd)
        print(f'Done. Submission file at {outdir}/kitti12_results.zip')
|
|
class Kitti15Dataset(StereoDataset):
    """KITTI 2015 stereo benchmark."""

    def _prepare_data(self):
        self.name = "Kitti15"
        self._set_root()
        assert self.split in ['train','subtrain','subval','test']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'_10.png')
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname.replace('/image_2/','/image_3/')+'_10.png')
        if self.split=='test':
            self.pairname_to_Ldispname = None
        else:
            self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.replace('/image_2/','/disp_occ_0/')+'_10.png')
        self.pairname_to_str = lambda pairname: pairname.replace('/image_2/','/')
        self.load_disparity = _read_kitti_disp

    def _build_cache(self):
        trainseqs = [f"training/image_2/{i:06d}" for i in range(200)]
        # the last 5 training frames are held out for validation
        subtrainseqs = trainseqs[:-5]
        subvalseqs = trainseqs[-5:]
        testseqs = [f"testing/image_2/{i:06d}" for i in range(200)]
        assert len(trainseqs)==200 and len(subtrainseqs)==195 and len(subvalseqs)==5 and len(testseqs)==200, "incorrect parsing of pairs in Kitti15"
        return {'train': trainseqs, 'subtrain': subtrainseqs, 'subval': subvalseqs, 'test': testseqs}

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Save one predicted disparity as a KITTI-encoded 16-bit PNG under disp_0/."""
        assert prediction.ndim==2
        assert prediction.dtype==np.float32
        outfile = os.path.join(outdir, 'disp_0', pairname.split('/')[-1]+'_10.png')
        os.makedirs(os.path.dirname(outfile), exist_ok=True)
        # KITTI encodes disparities as uint16 with a factor of 256
        Image.fromarray((prediction * 256).astype('uint16')).save(outfile)

    def finalize_submission(self, outdir):
        """Zip the disp_0 directory into the file to upload."""
        assert self.split=='test'
        cmd = f'cd {outdir}/; zip -r "kitti15_results.zip" disp_0'
        print(cmd)
        os.system(cmd)
        print(f'Done. Submission file at {outdir}/kitti15_results.zip')
|
|
|
|
| |
|
|
def _read_img(filename):
    """Load an image file as an HxWx3 uint8 RGB numpy array."""
    rgb = Image.open(filename).convert('RGB')
    return np.asarray(rgb)
|
|
| def _read_booster_disp(filename): |
| disp = np.load(filename) |
| disp[disp==0.0] = np.inf |
| return disp |
|
|
def _read_png_disp(filename, coef=1.0):
    """Load a PNG disparity map and divide by coef; zero (invalid) pixels become +inf."""
    raw = np.asarray(Image.open(filename))
    disp = raw.astype(np.float32) / coef
    disp[disp == 0.0] = np.inf
    return disp
|
|
def _read_pfm_disp(filename):
    """Load a PFM disparity map; non-positive (invalid) values become +inf."""
    raw, _scale = _read_pfm(filename)
    disp = np.ascontiguousarray(raw)
    disp[disp <= 0] = np.inf
    return disp
|
|
| def _read_npy_disp(filename): |
| return np.load(filename) |
|
|
# Per-dataset disparity loaders. PNG-encoded datasets store disparity scaled
# by a constant factor (CREStereo: 32, KITTI: 256, Middlebury 2005/2006: 1).
def _read_crestereo_disp(filename): return _read_png_disp(filename, coef=32.0)
def _read_middlebury20052006_disp(filename): return _read_png_disp(filename, coef=1.0)
def _read_kitti_disp(filename): return _read_png_disp(filename, coef=256.0)
# PFM-encoded ground truth
_read_sceneflow_disp = _read_pfm_disp
_read_eth3d_disp = _read_pfm_disp
_read_middlebury_disp = _read_pfm_disp
_read_carla_disp = _read_pfm_disp
# raw numpy ground truth
_read_tartanair_disp = _read_npy_disp
| |
def _read_hdf5_disp(filename):
    """Load a disparity map from an hdf5 file (Spring .dsp5 format).

    NaN values (invalid pixels) are mapped to +inf, consistently with the
    other loaders which mark invalid pixels with +inf.
    """
    # fix: use a context manager so the hdf5 handle is actually closed
    # (the previous code left the file open until garbage collection)
    with h5py.File(filename, 'r') as f:
        disp = np.asarray(f['disparity'])
    disp[np.isnan(disp)] = np.inf # let's make invalid pixels +inf
    return disp.astype(np.float32)
| |
| import re |
| def _read_pfm(file): |
| file = open(file, 'rb') |
|
|
| color = None |
| width = None |
| height = None |
| scale = None |
| endian = None |
|
|
| header = file.readline().rstrip() |
| if header.decode("ascii") == 'PF': |
| color = True |
| elif header.decode("ascii") == 'Pf': |
| color = False |
| else: |
| raise Exception('Not a PFM file.') |
|
|
| dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii")) |
| if dim_match: |
| width, height = list(map(int, dim_match.groups())) |
| else: |
| raise Exception('Malformed PFM header.') |
|
|
| scale = float(file.readline().decode("ascii").rstrip()) |
| if scale < 0: |
| endian = '<' |
| scale = -scale |
| else: |
| endian = '>' |
|
|
| data = np.fromfile(file, endian + 'f') |
| shape = (height, width, 3) if color else (height, width) |
|
|
| data = np.reshape(data, shape) |
| data = np.flipud(data) |
| return data, scale |
|
|
def writePFM(file, image, scale=1):
    """Write a float32 image to a PFM file.

    Args:
        file: output path.
        image: HxW, HxWx1 or HxWx3 float32 array.
        scale: PFM scale factor; its sign encodes the data endianness.
    Raises:
        Exception if the dtype is not float32 or the shape is unsupported.
    """
    if image.dtype.name != 'float32':
        raise Exception('Image dtype must be float32.')

    # PFM stores rows bottom-up
    image = np.flipud(image)

    if len(image.shape) == 3 and image.shape[2] == 3:
        color = True
    elif len(image.shape) == 2 or (len(image.shape) == 3 and image.shape[2] == 1):
        color = False
    else:
        raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')

    # fix: also close the file via a context manager (previously left open)
    with open(file, 'wb') as fid:
        # fix: the original "'PF\n' if color else 'Pf\n'.encode()" only encoded
        # the 'Pf\n' branch (precedence bug), so writing a color image passed a
        # str to the binary stream and raised TypeError
        fid.write(('PF\n' if color else 'Pf\n').encode())
        fid.write(('%d %d\n' % (image.shape[1], image.shape[0])).encode())

        endian = image.dtype.byteorder
        # '=' means native order: on a little-endian host, negate the scale
        if endian == '<' or (endian == '=' and sys.byteorder == 'little'):
            scale = -scale

        fid.write(('%f\n' % scale).encode())
        image.tofile(fid)
|
|
def writeDsp5File(disp, filename):
    """Save a disparity map into a gzip-compressed hdf5 (.dsp5) file."""
    with h5py.File(filename, "w") as fout:
        fout.create_dataset("disparity", data=disp, compression="gzip", compression_opts=5)
|
|
|
|
| |
|
|
def vis_disparity(disp, m=None, M=None):
    """Render a disparity map as a color image for visualization.

    Args:
        disp: HxW disparity array.
        m, M: normalization bounds (default: the min/max of disp).
    Returns:
        HxWx3 uint8 color image (cv2 INFERNO colormap).
    """
    if m is None: m = disp.min()
    if M is None: M = disp.max()
    # fix: guard against a constant map (M == m), which divided by zero
    span = (M - m) if M > m else 1.0
    disp_vis = (disp - m) / span * 255.0
    disp_vis = disp_vis.astype("uint8")
    disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
    return disp_vis
|
|
| |
| |
def get_train_dataset_stereo(dataset_str, augmentor=True, crop_size=None):
    """Instantiate a training dataset from a short spec string.

    e.g. "CREStereo('train')" -> CREStereoDataset('train', augmentor=True, crop_size=...).
    NOTE: the spec is passed to eval() — only use trusted strings.
    """
    spec = dataset_str.replace('(', 'Dataset(')
    extra_args = ''
    if augmentor:
        extra_args += ', augmentor=True'
    if crop_size is not None:
        extra_args += ', crop_size={:s}'.format(str(crop_size))
    if extra_args:
        spec = spec.replace(')', extra_args + ')')
    return eval(spec)
| |
def get_test_datasets_stereo(dataset_str):
    """Instantiate test dataset(s) from a '+'-separated spec string.

    NOTE: each spec is passed to eval() — only use trusted strings.
    """
    specs = dataset_str.replace('(', 'Dataset(').split('+')
    return [eval(spec) for spec in specs]