Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,060 Bytes
b56342d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import os
from copy import deepcopy
from functools import partial
from glob import glob
from hashlib import sha1
from typing import Callable, Iterable, Optional, Tuple
import cv2
import numpy as np
from glog import logger
from joblib import Parallel, cpu_count, delayed
from skimage.io import imread
from torch.utils.data import Dataset
from tqdm import tqdm
import aug
def subsample(data: Iterable, bounds: Tuple[float, float], hash_fn: Callable, n_buckets=100, salt='', verbose=True):
    """Keep the samples whose hash bucket falls inside a fractional range.

    Every sample is assigned a bucket in ``[0, n_buckets)`` by
    ``split_into_buckets``; a sample is kept when its bucket lies in the
    half-open interval ``[bounds[0] * n_buckets, bounds[1] * n_buckets)``.

    Args:
        data: samples to filter; materialised into a list first.
        bounds: ``(lower, upper)`` fractions, each multiplied by ``n_buckets``.
        hash_fn: callable invoked as ``hash_fn(sample, salt=salt)`` that
            returns a hexadecimal string.
        n_buckets: total number of buckets.
        salt: forwarded to ``hash_fn``; also reported in the log message
            when non-empty.
        verbose: when true, the selected bucket range is logged.

    Returns:
        ``np.ndarray`` built from the retained samples, in input order.
    """
    samples = list(data)
    bucket_ids = split_into_buckets(samples, n_buckets=n_buckets, salt=salt, hash_fn=hash_fn)
    lower_bound, upper_bound = (fraction * n_buckets for fraction in bounds)

    report = f'Subsampling buckets from {lower_bound} to {upper_bound}, total buckets number is {n_buckets}'
    if salt:
        report = report + f'; salt is {salt}'
    if verbose:
        logger.info(report)

    selected = []
    for bucket, sample in zip(bucket_ids, samples):
        if lower_bound <= bucket < upper_bound:
            selected.append(sample)
    return np.array(selected)
def hash_from_paths(x: Tuple[str, str], salt: str = '') -> str:
    """Return a SHA-1 hex digest derived from the base names of a path pair.

    Only the final path components are hashed, so the digest does not depend
    on the directories the two files live in.

    Args:
        x: pair of file paths.
        salt: text appended (after an underscore) to the concatenated names.

    Returns:
        Hexadecimal SHA-1 digest of ``'<name_a><name_b>_<salt>'``.
    """
    first, second = x
    base_names = [os.path.basename(path) for path in (first, second)]
    payload = f"{''.join(base_names)}_{salt}".encode()
    return sha1(payload).hexdigest()
def split_into_buckets(data: Iterable, n_buckets: int, hash_fn: Callable, salt=''):
    """Assign every item of ``data`` to a bucket based on its hash.

    Args:
        data: items to distribute.
        n_buckets: number of buckets; each result is ``hash % n_buckets``.
        hash_fn: callable invoked as ``hash_fn(item, salt=salt)`` that
            returns a hexadecimal string.
        salt: forwarded to ``hash_fn`` as a keyword argument.

    Returns:
        ``np.ndarray`` of bucket indices, one per item, in input order.
    """
    bucket_ids = []
    for item in data:
        hex_digest = hash_fn(item, salt=salt)
        # The digest is parsed as a base-16 integer before taking the modulus.
        bucket_ids.append(int(hex_digest, 16) % n_buckets)
    return np.array(bucket_ids)
def _read_img(x: str):
    """Load the image at path ``x``, preferring OpenCV.

    ``cv2.imread`` is tried first. When it returns ``None`` a warning is
    logged and the file is read with scikit-image's ``imread`` instead.

    NOTE(review): the two readers may not agree on channel order (OpenCV
    documents BGR for colour images) — confirm whether callers care.
    """
    loaded = cv2.imread(x)
    if loaded is not None:
        return loaded
    logger.warning(f'Can not read image {x} with OpenCV, switching to scikit-image')
    return imread(x)
class PairedDataset(Dataset):
    """Dataset of aligned image pairs, served as ``{'a': ..., 'b': ...}``.

    Images are either read from disk on every access or, when ``preload`` is
    enabled, decoded once in the constructor and kept in memory.
    """

    def __init__(self,
                 files_a: Tuple[str, ...],
                 files_b: Tuple[str, ...],
                 transform_fn: Callable,
                 normalize_fn: Callable,
                 corrupt_fn: Optional[Callable] = None,
                 preload: bool = True,
                 preload_size: Optional[int] = 0,
                 verbose=True):
        """Store the pair lists and optionally decode every image up front.

        Args:
            files_a: paths of the first image of each pair.
            files_b: paths of the second image of each pair; must have the
                same length as ``files_a``.
            transform_fn: called as ``transform_fn(a, b)``; must return the
                transformed pair.
            normalize_fn: called as ``normalize_fn(a, b)``; must return an
                iterable of two arrays.
            corrupt_fn: optional callable applied to ``a`` only, after
                ``transform_fn``.
            preload: when truthy, all images are loaded in the constructor.
            preload_size: when truthy, preloaded images are rescaled so that
                their shorter side is at least this many pixels.
            verbose: enables the preloading progress bar.
        """
        assert len(files_a) == len(files_b), \
            f'files_a and files_b differ in length: {len(files_a)} != {len(files_b)}'

        self.preload = bool(preload)
        self.data_a = files_a
        self.data_b = files_b
        self.verbose = verbose
        self.corrupt_fn = corrupt_fn
        self.transform_fn = transform_fn
        self.normalize_fn = normalize_fn
        logger.info(f'Dataset has been created with {len(self.data_a)} samples')

        if preload:
            preload_fn = partial(self._bulk_preload, preload_size=preload_size)
            if files_a == files_b:
                # Both sides point at the same files: decode them once and share.
                self.data_a = self.data_b = preload_fn(self.data_a)
            else:
                self.data_a, self.data_b = map(preload_fn, (self.data_a, self.data_b))

    def _bulk_preload(self, data: Iterable[str], preload_size: int):
        """Load every path in ``data`` using a joblib thread pool.

        Returns:
            List of decoded images, in the order of ``data``.
        """
        jobs = [delayed(self._preload)(x, preload_size=preload_size) for x in data]
        jobs = tqdm(jobs, desc='preloading images', disable=not self.verbose)
        return Parallel(n_jobs=cpu_count(), backend='threading')(jobs)

    @staticmethod
    def _preload(x: str, preload_size: int):
        """Read one image and, if ``preload_size`` is truthy, rescale it.

        The scale factor is the larger of ``preload_size / height`` and
        ``preload_size / width``, so the shorter side ends up at
        ``preload_size`` and the aspect ratio is kept.
        """
        img = _read_img(x)
        if preload_size:
            h, w, *_ = img.shape
            h_scale = preload_size / h
            w_scale = preload_size / w
            scale = max(h_scale, w_scale)
            img = cv2.resize(img, fx=scale, fy=scale, dsize=None)
            # Guards against the resized shape ending up below the requested size.
            assert min(img.shape[:2]) >= preload_size, f'weird img shape: {img.shape}'
        return img

    def _preprocess(self, img, res):
        """Normalise a pair and move axis 2 of each array to the front.

        Returns:
            An iterator over the transposed arrays produced by
            ``normalize_fn(img, res)``.
        """
        def transpose(x):
            # (0, 1, 2) -> (2, 0, 1); presumably HWC -> CHW, confirm with normalize_fn.
            return np.transpose(x, (2, 0, 1))

        return map(transpose, self.normalize_fn(img, res))

    def __len__(self):
        """Return the number of pairs."""
        return len(self.data_a)

    def __getitem__(self, idx):
        """Return the pair at ``idx`` as ``{'a': ..., 'b': ...}``.

        Without preloading, the stored entries are paths and are read here.
        The pair then goes through ``transform_fn``, the optional
        ``corrupt_fn`` (applied to ``a`` only) and ``_preprocess``.
        """
        a, b = self.data_a[idx], self.data_b[idx]
        if not self.preload:
            a, b = map(_read_img, (a, b))
        a, b = self.transform_fn(a, b)
        if self.corrupt_fn is not None:
            a = self.corrupt_fn(a)
        a, b = self._preprocess(a, b)
        return {'a': a, 'b': b}

    @staticmethod
    def from_config(config):
        """Build a dataset from a configuration mapping.

        Keys read from ``config``: ``files_a`` and ``files_b`` (glob patterns,
        expanded recursively and sorted), ``size``, ``scope``, ``crop``,
        ``preload``, ``preload_size``, plus the optional ``verbose``
        (default ``True``) and ``bounds`` (default ``(0, 1)``).

        Raises:
            ValueError: if the two patterns match a different number of
                files, match no files at all, or if subsampling with
                ``bounds`` leaves no pairs.
        """
        config = deepcopy(config)
        files_a, files_b = (sorted(glob(config[key], recursive=True)) for key in ('files_a', 'files_b'))

        # zip() below would silently drop the surplus files and hide the
        # mismatch from the length check in __init__, so fail loudly here.
        if len(files_a) != len(files_b):
            raise ValueError(
                f"files_a pattern {config['files_a']!r} matched {len(files_a)} files "
                f"but files_b pattern {config['files_b']!r} matched {len(files_b)}"
            )
        if not files_a:
            raise ValueError(
                f"no files matched files_a pattern {config['files_a']!r} "
                f"and files_b pattern {config['files_b']!r}"
            )

        transform_fn = aug.get_transforms(size=config['size'], scope=config['scope'], crop=config['crop'])
        normalize_fn = aug.get_normalize()

        hash_fn = hash_from_paths
        # ToDo: add more hash functions
        verbose = config.get('verbose', True)
        data = subsample(data=zip(files_a, files_b),
                         bounds=config.get('bounds', (0, 1)),
                         hash_fn=hash_fn,
                         verbose=verbose)
        if len(data) == 0:
            # Unpacking zip(*data) on an empty selection fails with an opaque message.
            raise ValueError(f"subsampling with bounds {config.get('bounds', (0, 1))} left no pairs")

        files_a, files_b = map(list, zip(*data))

        return PairedDataset(files_a=files_a,
                             files_b=files_b,
                             preload=config['preload'],
                             preload_size=config['preload_size'],
                             normalize_fn=normalize_fn,
                             transform_fn=transform_fn,
                             verbose=verbose)