| | import collections |
| | import os |
| | import tarfile |
| | import urllib |
| | import zipfile |
| | from pathlib import Path |
| |
|
| | import numpy as np |
| | import torch |
| | from taming.data.helper_types import Annotation |
| | from torch._six import string_classes |
| | from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format |
| | from tqdm import tqdm |
| |
|
| |
|
def unpack(path):
    """Extract an archive into the directory that contains it.

    The archive type is chosen from the end of the file name: ``tar.gz``
    (gzip-compressed tar), ``tar`` (uncompressed tar) or ``zip``.

    Args:
        path: Location of the archive, given as a ``str`` or as any
            ``os.PathLike`` object such as ``pathlib.Path``.

    Raises:
        NotImplementedError: If the name ends with none of the supported
            suffixes.
    """
    # Normalise first: the suffix checks below are string methods, which a
    # pathlib.Path argument does not provide.
    path = os.fspath(path)
    target_dir = os.path.split(path)[0]

    # NOTE: extractall() writes members wherever the archive says; only use
    # this on archives from a trusted source.
    if path.endswith("tar.gz"):
        with tarfile.open(path, "r:gz") as tar:
            tar.extractall(path=target_dir)
    elif path.endswith("tar"):
        with tarfile.open(path, "r:") as tar:
            tar.extractall(path=target_dir)
    elif path.endswith("zip"):
        with zipfile.ZipFile(path, "r") as f:
            f.extractall(path=target_dir)
    else:
        raise NotImplementedError(
            "Unknown file extension: {}".format(os.path.splitext(path)[1])
        )
| |
|
| |
|
def reporthook(bar):
    """Create a download progress callback that drives a tqdm-style *bar*.

    The returned callable takes ``(b, bsize, tsize)``: the number of blocks
    seen so far, the size of one block and, optionally, the total size. Each
    call advances *bar* so that ``bar.n`` ends up at ``b * bsize``.
    """

    def update_to(b=1, bsize=1, tsize=None):
        # Only overwrite the bar's total when a size was actually reported.
        if tsize is not None:
            bar.total = tsize
        transferred = b * bsize
        # update() is incremental, so pass the delta to the current position.
        bar.update(transferred - bar.n)

    return update_to
| |
|
| |
|
def get_root(name):
    """Return the directory ``data/<name>``, creating it if it is missing.

    The path is relative to the current working directory.
    """
    dataset_dir = os.path.join("data/", name)
    os.makedirs(dataset_dir, exist_ok=True)
    return dataset_dir
| |
|
| |
|
def is_prepared(root):
    """Tell whether *root* contains the ``.ready`` marker file."""
    marker = Path(root) / ".ready"
    return marker.exists()
| |
|
| |
|
def mark_prepared(root):
    """Create the ``.ready`` marker file inside *root*."""
    marker = Path(root) / ".ready"
    marker.touch()
| |
|
| |
|
def prompt_download(file_, source, target_dir, content_dir=None):
    """Ask the user to fetch a file by hand and wait until it is present.

    Loops until ``target_dir/file_`` exists or, when *content_dir* is given,
    until ``target_dir/content_dir`` exists. On every pass the instructions
    are printed and the function blocks on ``input()``.

    Returns:
        The path ``target_dir/file_``, whether or not that file itself exists.
    """
    targetpath = os.path.join(target_dir, file_)
    if content_dir is None:
        content_path = None
    else:
        content_path = os.path.join(target_dir, content_dir)

    while True:
        if os.path.exists(targetpath):
            break
        # Already-unpacked content counts as done, even without the archive.
        if content_path is not None and os.path.exists(content_path):
            break
        print(
            "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
        )
        if content_path is not None:
            print("Or place its content into '{}'.".format(content_path))
        input("Press Enter when done...")
    return targetpath
| |
|
| |
|
def download_url(file_, url, target_dir):
    """Download *url* to ``target_dir/file_`` while showing a tqdm progress bar.

    Args:
        file_: Name of the file to write; also used as the bar description.
        url: Address handed to ``urllib.request.urlretrieve``.
        target_dir: Destination directory; created if it does not exist.

    Returns:
        The path of the downloaded file, ``target_dir/file_``.
    """
    # A plain ``import urllib`` does not load the ``request`` submodule, so
    # import it explicitly instead of relying on another module having done so.
    import urllib.request

    targetpath = os.path.join(target_dir, file_)
    os.makedirs(target_dir, exist_ok=True)
    with tqdm(
        unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
    ) as bar:
        urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
    return targetpath
| |
|
| |
|
def download_urls(urls, target_dir):
    """Download every entry of *urls* into *target_dir*.

    Args:
        urls: Mapping from file name to URL.
        target_dir: Directory passed on to ``download_url`` for each entry.

    Returns:
        A dict mapping each file name to the path ``download_url`` returned.
    """
    return {
        fname: download_url(fname, url, target_dir) for fname, url in urls.items()
    }
| |
|
| |
|
def quadratic_crop(x, bbox, alpha=1.0):
    """Cut a square window centred on a bounding box out of an image array.

    bbox is xmin, ymin, xmax, ymax. The window's side is
    ``int(alpha * max(width, height))`` of the clipped box, but never less
    than 2. The first two axes of *x* are treated as (height, width); any
    further axes are kept whole.

    If the box centre lies closer than one full side length to any image
    border, *x* is first reflect-padded by the missing amount on both leading
    axes so the window always fits.

    Returns:
        A new array holding the ``side x side`` window.
    """
    im_h, im_w = x.shape[:2]
    box = np.clip(np.array(bbox, dtype=np.float32), 0, max(im_h, im_w))

    cx = 0.5 * (box[0] + box[2])
    cy = 0.5 * (box[1] + box[3])
    box_w = box[2] - box[0]
    box_h = box[3] - box[1]
    side = max(int(alpha * max(box_w, box_h)), 2)

    # Smallest distance from (centre +/- side) to the image borders; a
    # negative value is how far that region sticks out of the image.
    margins = (cx - side, cy - side, im_w - (cx + side), im_h - (cy + side))
    pad = int(np.ceil(-1 * min(margins)))
    if pad > 0:
        pad_spec = [[pad, pad], [pad, pad]]
        pad_spec += [[0, 0]] * (len(x.shape) - 2)  # leave trailing axes alone
        x = np.pad(x, pad_spec, "reflect")
        # Padding shifts the image origin, so move the centre with it.
        cx, cy = cx + pad, cy + pad

    left = int(cx - side / 2)
    top = int(cy - side / 2)
    return np.array(x[top : top + side, left : left + side, ...])
| |
|
| |
|
def custom_collate(batch):
    r"""Collate a list of samples into one batch.

    Source: PyTorch 1.9.0 ``default_collate``, with one modification: when the
    samples are sequences whose first item is an ``Annotation``, the batch is
    returned unchanged as the list of samples instead of being transposed.

    Handling by type of the first sample ``batch[0]``:

    * ``torch.Tensor``: stacked along a new leading dimension.
    * numpy ``ndarray`` / ``memmap``: converted to tensors and collated again;
      string/object dtypes raise ``TypeError``. Numpy scalars become a tensor.
    * ``float``: ``float64`` tensor. ``int``: tensor.
    * ``str`` / ``bytes``: returned unchanged.
    * mapping: dict with each key collated separately.
    * namedtuple: same namedtuple type with each field collated.
    * other sequence: list with each position collated across samples.

    Raises:
        RuntimeError: If sequence samples differ in length.
        TypeError: If the sample type is not supported.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # Worker info is set, i.e. we run inside a DataLoader worker:
            # stack directly into a shared-memory tensor.
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # Arrays of strings or objects cannot be turned into tensors.
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return custom_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # numpy scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, (str, bytes)):
        # str/bytes are checked directly; these are the types that
        # torch._six.string_classes stands for on Python 3, and spelling them
        # out avoids relying on that private module inside this function.
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: custom_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(custom_collate(samples) for samples in zip(*batch)))

    if isinstance(elem, collections.abc.Sequence):
        # Project-specific addition: lists of Annotation objects are passed
        # through untouched. The length guard keeps empty sequences from
        # raising IndexError on elem[0]; they fall through to the generic path.
        if len(elem) > 0 and isinstance(elem[0], Annotation):
            return batch

        # Every sample must have the same length to be transposed.
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(sample) == elem_size for sample in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [custom_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
| |
|