| | |
| | |
| |
|
| | import os |
| | from torch.utils.data import Dataset |
| | from PIL import Image |
| |
|
| | from datasets.transforms import get_pair_transforms |
| |
|
def load_image(impath):
    """Open the image file at *impath* and return it as a PIL Image."""
    return Image.open(impath)
| |
|
def load_pairs_from_cache_file(fname, root=''):
    """Load image pairs from a cache file with one 'im1 im2' line per pair.

    fname: path to the cache file (as written by write_cache_file).
    root: optional folder prefix joined onto every stored path.

    Returns a list of (im1_path, im2_path) tuples.
    """
    assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname)
    with open(fname, 'r') as fid:
        lines = fid.read().strip().splitlines()
    pairs = []
    for line in lines:
        tokens = line.split()  # split once per line instead of twice
        pairs.append((os.path.join(root, tokens[0]), os.path.join(root, tokens[1])))
    return pairs
| | |
def load_pairs_from_list_file(fname, root=''):
    """Build image pairs from a listing file of basenames.

    Each non-comment line holds one basename; the pair is formed by
    appending '_1.jpg' and '_2.jpg'. Lines starting with '#' are skipped.

    fname: path to the listing file.
    root: optional folder prefix joined onto every generated path.

    Returns a list of (im1_path, im2_path) tuples.
    """
    assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname)
    with open(fname, 'r') as fid:
        lines = fid.read().strip().splitlines()
    pairs = []
    for basename in lines:
        if basename.startswith('#'):  # commented-out entry
            continue
        first = os.path.join(root, basename + '_1.jpg')
        second = os.path.join(root, basename + '_2.jpg')
        pairs.append((first, second))
    return pairs
| | |
| | |
def write_cache_file(fname, pairs, root=''):
    """Write image pairs to a cache file, one 'im1 im2' line per pair.

    fname: output text file path.
    pairs: iterable of (im1_path, im2_path) tuples.
    root: optional folder prefix; when given, every path must live under it
        (asserted) and is stored relative to it.
    """
    if len(root) > 0:
        if not root.endswith('/'):
            root += '/'
        assert os.path.isdir(root)
    lines = []
    for im1, im2 in pairs:
        if len(root) > 0:
            # paths must be under root so the stored relative paths stay valid
            assert im1.startswith(root), im1
            assert im2.startswith(root), im2
        lines.append('{:s} {:s}'.format(im1[len(root):], im2[len(root):]))
    with open(fname, 'w') as fid:
        # join instead of quadratic '+=' accumulation; like the original,
        # no trailing newline is written
        fid.write('\n'.join(lines))
| | |
def parse_and_cache_all_pairs(dname, data_dir='./data/'):
    """Scan a dataset folder on disk and cache its image pairs to 'pairs.txt'.

    dname: dataset name; only 'habitat_release' is currently supported.
    data_dir: root folder containing the dataset sub-folder.

    Raises NotImplementedError for unknown dataset names. Asserts if the
    dataset folder is missing or if the cache file already exists (it
    refuses to overwrite an existing cache).
    """
    if dname=='habitat_release':
        dirname = os.path.join(data_dir, 'habitat_release')
        assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname
        cache_file = os.path.join(dirname, 'pairs.txt')
        assert not os.path.isfile(cache_file), "cache file already exists: "+cache_file

        print('Parsing pairs for dataset: '+dname)
        pairs = []
        for root, dirs, files in os.walk(dirname):
            # NOTE(review): substring test — this skips any path containing
            # 'val' anywhere in it, not only a 'val/' split folder; confirm
            # that no legitimate folder name contains 'val'.
            if 'val' in root: continue
            # in-place sort makes os.walk's traversal order deterministic
            # (os.walk honors modifications to `dirs` in top-down mode)
            dirs.sort()
            # each '*_1.jpeg' file is paired with its '*_2.jpeg' sibling
            pairs += [ (os.path.join(root,f), os.path.join(root,f[:-len('_1.jpeg')]+'_2.jpeg')) for f in sorted(files) if f.endswith('_1.jpeg')]
        print('Found {:,} pairs'.format(len(pairs)))
        print('Writing cache to: '+cache_file)
        # paths are stored relative to the dataset folder
        write_cache_file(cache_file, pairs, root=dirname)

    else:
        raise NotImplementedError('Unknown dataset: '+dname)
| | |
def dnames_to_image_pairs(dnames, data_dir='./data/'):
    """Resolve dataset name(s) into a flat list of image pairs.

    dnames: list of datasets with image pairs, separated by +
    data_dir: root folder containing one sub-folder per dataset.

    Returns the concatenated list of (im1_path, im2_path) tuples.

    Raises NotImplementedError for unrecognized dataset names; asserts if
    the expected folders or cache/list files are missing on disk.
    """
    all_pairs = []
    for dname in dnames.split('+'):
        if dname=='habitat_release':
            dirname = os.path.join(data_dir, 'habitat_release')
            assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname
            cache_file = os.path.join(dirname, 'pairs.txt')
            assert os.path.isfile(cache_file), "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. "+cache_file
            pairs = load_pairs_from_cache_file(cache_file, root=dirname)
        elif dname in ['ARKitScenes', 'MegaDepth', '3DStreetView', 'IndoorVL']:
            dirname = os.path.join(data_dir, dname+'_crops')
            assert os.path.isdir(dirname), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname)
            list_file = os.path.join(dirname, 'listing.txt')
            assert os.path.isfile(list_file), "cannot find list file for {:s} pairs, see instructions. {:s}".format(dname, list_file)
            pairs = load_pairs_from_list_file(list_file, root=dirname)
        else:
            # bug fix: previously an unknown name fell through with no else,
            # causing a NameError (or silently reusing the previous dataset's
            # pairs). Raise explicitly, consistent with parse_and_cache_all_pairs.
            raise NotImplementedError('Unknown dataset: '+dname)
        print(' {:s}: {:,} pairs'.format(dname, len(pairs)))
        all_pairs += pairs
    if '+' in dnames: print(' Total: {:,} pairs'.format(len(all_pairs)))
    return all_pairs
| |
|
| |
|
class PairsDataset(Dataset):
    """Torch dataset yielding (im1, im2) image pairs, optionally transformed."""

    def __init__(self, dnames, trfs='', totensor=True, normalize=True, data_dir='./data/'):
        """Resolve the pair list for `dnames` and build the pair transforms."""
        super().__init__()
        self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir)
        self.transforms = get_pair_transforms(transform_str=trfs, totensor=totensor, normalize=normalize)

    def __len__(self):
        """Number of image pairs in the dataset."""
        return len(self.image_pairs)

    def __getitem__(self, index):
        """Load the pair at `index`, apply the joint transforms, return both images."""
        path1, path2 = self.image_pairs[index]
        image1 = load_image(path1)
        image2 = load_image(path2)
        if self.transforms is not None:
            image1, image2 = self.transforms(image1, image2)
        return image1, image2
| |
|
| | |
if __name__=="__main__":
    # Command-line entry point: parse one dataset from disk and write its
    # pairs cache file (see parse_and_cache_all_pairs).
    import argparse
    parser = argparse.ArgumentParser(prog="Computing and caching list of pairs for a given dataset")
    parser.add_argument('--data_dir', default='./data/', type=str, help="path where data are stored")
    parser.add_argument('--dataset', default='habitat_release', type=str, help="name of the dataset")
    args = parser.parse_args()
    parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir)
| |
|