| |
| import os |
| import time |
| import numpy as np |
| import warnings |
| import random |
| from omegaconf.listconfig import ListConfig |
| from webdataset import pipelinefilter |
| import torch |
| import torchvision.transforms.functional as TVF |
| from torchvision.transforms import InterpolationMode |
| from torchvision.transforms.transforms import _interpolation_modes_from_int |
| from typing import Sequence |
|
|
| from michelangelo.utils import instantiate_from_config |
|
|
|
|
| def _uid_buffer_pick(buf_dict, rng): |
| uid_keys = list(buf_dict.keys()) |
| selected_uid = rng.choice(uid_keys) |
| buf = buf_dict[selected_uid] |
|
|
| k = rng.randint(0, len(buf) - 1) |
| sample = buf[k] |
| buf[k] = buf[-1] |
| buf.pop() |
|
|
| if len(buf) == 0: |
| del buf_dict[selected_uid] |
|
|
| return sample |
|
|
|
|
| def _add_to_buf_dict(buf_dict, sample): |
| key = sample["__key__"] |
| uid, uid_sample_id = key.split("_") |
| if uid not in buf_dict: |
| buf_dict[uid] = [] |
| buf_dict[uid].append(sample) |
|
|
| return buf_dict |
|
|
|
|
def _uid_shuffle(data, bufsize=1000, initial=100, rng=None, handler=None):
    """Shuffle a stream of samples using a buffer grouped by uid.

    Samples are bucketed by the uid prefix of ``sample["__key__"]`` (via
    ``_add_to_buf_dict``). Each emitted sample is drawn by ``_uid_buffer_pick``.
    That helper first picks a buffered uid uniformly at random and then a
    random sample of that uid, so a uid with many buffered samples is not
    more likely to be chosen than one with few. Shuffling at startup is less
    random; this is traded off against yielding samples quickly.

    Args:
        data: iterable of sample dicts, each with a ``"__key__"`` entry.
        bufsize: target number of buffered samples.
        initial: number of buffered samples required before the first sample
            is yielded. It is clamped to ``bufsize``.
        rng: ``random`` module or ``random.Random`` instance. When omitted, a
            ``random.Random`` seeded from pid and time is created.
        handler: unused here. It is presumably kept to mirror webdataset's
            shuffle signature.

    Yields:
        Every input sample exactly once, in shuffled order.
    """
    if rng is None:
        rng = random.Random(int((os.getpid() + time.time()) * 1e9))
    initial = min(initial, bufsize)

    # The loop below calls next() on the stream to pull an extra sample per
    # iteration. That requires an iterator, so normalize any iterable (e.g. a
    # list) first. iter() on an iterator returns the iterator itself.
    data = iter(data)

    buf_dict = dict()
    current_samples = 0
    for sample in data:
        _add_to_buf_dict(buf_dict, sample)
        current_samples += 1

        # While below bufsize, pull one extra sample so the buffer grows.
        if current_samples < bufsize:
            try:
                _add_to_buf_dict(buf_dict, next(data))
                current_samples += 1
            except StopIteration:
                pass

        if current_samples >= initial:
            current_samples -= 1
            yield _uid_buffer_pick(buf_dict, rng)

    # Drain whatever is left once the input is exhausted.
    while current_samples > 0:
        current_samples -= 1
        yield _uid_buffer_pick(buf_dict, rng)
|
|
|
|
# Pipeline-stage wrapper around _uid_shuffle, built with webdataset's
# pipelinefilter. Presumably called as uid_shuffle(bufsize=..., initial=...)
# to obtain a stage applied to the upstream sample iterator. TODO confirm
# against webdataset.pipelinefilter.
uid_shuffle = pipelinefilter(_uid_shuffle)
|
|
|
|
class RandomSample(object):
    """Pick one surface entry and subsample labelled volume/near query points.

    ``sample["surface"]`` is indexed with a single random integer along its
    first axis. Presumably it stacks several pre-sampled surface point sets
    and one of them is chosen (TODO confirm the expected shape).

    Volume and near points are subsampled independently. Each label is
    appended as an extra last column, and the two sets are stacked into
    ``geo_points``. If a sample has fewer points than requested, the draw
    uses replacement instead of raising ValueError.

    Args:
        num_volume_samples: rows drawn from ``vol_points``/``vol_label``.
        num_near_samples: rows drawn from ``near_points``/``near_label``.
    """

    def __init__(self,
                 num_volume_samples: int = 1024,
                 num_near_samples: int = 1024):
        super().__init__()
        self.num_volume_samples = num_volume_samples
        self.num_near_samples = num_near_samples

    @staticmethod
    def _sample_labelled_points(rng, points, labels, num_samples):
        """Draw ``num_samples`` rows of ``points`` with ``labels`` appended as a last column."""
        # Sample without replacement when possible. Only fall back to
        # replacement when there are too few points; otherwise rng.choice
        # would raise ValueError.
        replace = points.shape[0] < num_samples
        ind = rng.choice(points.shape[0], num_samples, replace=replace)
        return np.concatenate([points[ind], labels[ind][:, np.newaxis]], axis=1)

    def __call__(self, sample):
        """Return a new dict ``{"surface": ..., "geo_points": ...}``.

        All other keys of ``sample`` are dropped.
        """
        rng = np.random.default_rng()

        # 1. Pick one entry along the first axis of the surface array.
        total_surface = sample["surface"]
        ind = rng.choice(total_surface.shape[0], replace=False)
        surface = total_surface[ind]

        # 2. Subsample volume and near query points together with their labels.
        vol_points_labels = self._sample_labelled_points(
            rng, sample["vol_points"], sample["vol_label"], self.num_volume_samples)
        near_points_labels = self._sample_labelled_points(
            rng, sample["near_points"], sample["near_label"], self.num_near_samples)

        # 3. Volume rows first, then near rows.
        geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0)

        return {
            "surface": surface,
            "geo_points": geo_points
        }
|
|
|
|
class SplitRandomSample(object):
    """Optionally subsample surface rows and subsample labelled query points.

    ``sample["surface"]`` is used as a single array of rows. When
    ``use_surface_sample`` is set, ``num_surface_samples`` rows are drawn from
    it, with replacement only if there are too few rows. Volume and near
    points are subsampled the same way (with replacement only when too few).
    Each label is appended as an extra last column, and the two sets are
    stacked into ``geo_points``.

    Args:
        use_surface_sample: whether to subsample the surface rows.
        num_surface_samples: number of surface rows drawn when subsampling.
        num_volume_samples: rows drawn from ``vol_points``/``vol_label``.
        num_near_samples: rows drawn from ``near_points``/``near_label``.
    """

    def __init__(self,
                 use_surface_sample: bool = False,
                 num_surface_samples: int = 4096,
                 num_volume_samples: int = 1024,
                 num_near_samples: int = 1024):
        super().__init__()
        self.use_surface_sample = use_surface_sample
        self.num_surface_samples = num_surface_samples
        self.num_volume_samples = num_volume_samples
        self.num_near_samples = num_near_samples

    @staticmethod
    def _sample_labelled_points(rng, points, labels, num_samples):
        """Draw ``num_samples`` rows of ``points`` with ``labels`` appended as a last column."""
        # Same policy as for surface rows: replacement only when needed,
        # instead of letting rng.choice raise ValueError.
        replace = points.shape[0] < num_samples
        ind = rng.choice(points.shape[0], num_samples, replace=replace)
        return np.concatenate([points[ind], labels[ind][:, np.newaxis]], axis=1)

    def __call__(self, sample):
        """Return a new dict ``{"surface": ..., "geo_points": ...}``.

        All other keys of ``sample`` are dropped.
        """
        rng = np.random.default_rng()

        # 1. Surface rows, optionally subsampled.
        surface = sample["surface"]
        if self.use_surface_sample:
            replace = surface.shape[0] < self.num_surface_samples
            ind = rng.choice(surface.shape[0], self.num_surface_samples, replace=replace)
            surface = surface[ind]

        # 2. Volume and near query points together with their labels.
        vol_points_labels = self._sample_labelled_points(
            rng, sample["vol_points"], sample["vol_label"], self.num_volume_samples)
        near_points_labels = self._sample_labelled_points(
            rng, sample["near_points"], sample["near_label"], self.num_near_samples)

        # 3. Volume rows first, then near rows.
        geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0)

        return {
            "surface": surface,
            "geo_points": geo_points
        }
|
|
|
|
class FeatureSelection(object):
    """Keep only the surface columns that belong to a given feature type.

    Args:
        surface_feature_type: one of the keys of ``VALID_SURFACE_FEATURE_DIMS``.
            An unknown type raises KeyError.
    """

    # Column indices of sample["surface"] kept for each feature type.
    # NOTE(review): the indices imply at least 9 columns (presumably xyz plus
    # two 3-vectors). Confirm against the data producer.
    VALID_SURFACE_FEATURE_DIMS = {
        "none": [0, 1, 2],
        "watertight_normal": [0, 1, 2, 3, 4, 5],
        "normal": [0, 1, 2, 6, 7, 8]
    }

    def __init__(self, surface_feature_type: str):
        dims_by_type = self.VALID_SURFACE_FEATURE_DIMS
        self.surface_feature_type = surface_feature_type
        self.surface_dims = dims_by_type[surface_feature_type]

    def __call__(self, sample):
        """Replace ``sample["surface"]`` with its selected columns and return ``sample``."""
        selected = sample["surface"][:, self.surface_dims]
        sample["surface"] = selected
        return sample
|
|
|
|
class AxisScaleTransform(object):
    """Randomly rescale the xyz columns of ``surface`` and ``geo_points``.

    Steps:

    1. Columns ``0:3`` of both entries are multiplied by the same random
       per-axis factor, drawn uniformly from ``[interval[0], interval[1])``.
    2. Both are multiplied by a common factor that brings the largest
       absolute surface coordinate to ``0.999999``.
    3. Optionally, Gaussian noise is added to the surface coordinates, which
       are then clamped to ``[-1.015, 1.015]``.

    Results are written back into columns ``0:3`` of the original arrays in
    ``sample``, and other columns are untouched. NOTE(review): this assumes
    both entries are torch tensors (e.g. after ToTensor). Confirm the pipeline
    order.

    Args:
        interval: ``(low, high)`` range of the per-axis scale factors.
        jitter: whether to add Gaussian noise to the surface coordinates.
        jitter_scale: multiplier (standard deviation) of that noise.
    """

    def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
        assert isinstance(interval, (tuple, list, ListConfig))
        low, high = interval[0], interval[1]
        self.interval = interval
        self.min_val = low
        self.max_val = high
        self.inter_size = high - low
        self.jitter = jitter
        self.jitter_scale = jitter_scale

    def __call__(self, sample):
        xyz_surface = sample["surface"][..., 0:3]
        xyz_geo = sample["geo_points"][..., 0:3]

        # One random factor per axis, broadcast over every point.
        axis_factors = torch.rand(1, 3) * self.inter_size + self.min_val
        xyz_surface = xyz_surface * axis_factors
        xyz_geo = xyz_geo * axis_factors

        # Renormalize from the surface extent only. Query points get the same
        # factor so they stay aligned with the surface.
        norm = (1 / torch.abs(xyz_surface).max().item()) * 0.999999
        xyz_surface *= norm
        xyz_geo *= norm

        if self.jitter:
            xyz_surface += self.jitter_scale * torch.randn_like(xyz_surface)
            xyz_surface.clamp_(min=-1.015, max=1.015)

        sample["surface"][..., 0:3] = xyz_surface
        sample["geo_points"][..., 0:3] = xyz_geo
        return sample
|
|
|
|
class ToTensor(object):
    """Convert selected entries of a sample dict to float32 torch tensors.

    Args:
        tensor_keys: keys to convert. Keys missing from the sample are skipped.
    """

    def __init__(self, tensor_keys=("surface", "geo_points", "tex_points")):
        self.tensor_keys = tensor_keys

    def __call__(self, sample):
        """Convert the present keys in place (``torch.tensor`` copies the data) and return ``sample``."""
        present_keys = [name for name in self.tensor_keys if name in sample]
        for name in present_keys:
            sample[name] = torch.tensor(sample[name], dtype=torch.float32)
        return sample
|
|
|
|
class AxisScale(object):
    """Randomly rescale a point cloud per axis, renormalize it, and optionally jitter.

    Args:
        interval: ``(low, high)`` range from which each per-axis scale factor
            is drawn uniformly. The default ``(0.75, 1.25)`` reproduces the
            previously hard-coded range.
        jitter: whether to add Gaussian noise to ``surface`` after scaling.
        jitter_scale: multiplier (standard deviation) of that noise.
    """

    def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
        assert isinstance(interval, (tuple, list, ListConfig))
        self.interval = interval
        self.jitter = jitter
        self.jitter_scale = jitter_scale

    def __call__(self, surface, *args):
        """Scale ``surface`` and any extra point arrays consistently.

        ``surface`` (and each entry of ``args``) is multiplied by a random
        ``(1, 3)`` per-axis factor, then by a common factor that brings the
        largest absolute surface coordinate to ``0.999999``. If ``jitter`` is
        set, noise is added to ``surface`` only, which is then clamped to
        ``[-1, 1]``.

        Returns:
            ``surface`` alone when no extra arrays are given, otherwise the
            tuple ``(surface, *scaled_args)``.
        """
        # Previously this ignored self.interval and always used [0.75, 1.25).
        low, high = self.interval[0], self.interval[1]
        scaling = torch.rand(1, 3) * (high - low) + low

        surface = surface * scaling
        scale = (1 / torch.abs(surface).max().item()) * 0.999999
        surface *= scale

        # Extra arrays get the same transform so they stay aligned with surface.
        args_outputs = [_arg * scaling * scale for _arg in args]

        if self.jitter:
            surface += self.jitter_scale * torch.randn_like(surface)
            surface.clamp_(min=-1, max=1)

        if len(args) == 0:
            return surface
        return (surface, *args_outputs)
|
|
|
|
class RandomResize(torch.nn.Module):
    """Degrade an image by resizing it to a random smaller size and back.

    Each call draws a ratio uniformly from ``[resize_radio[0], resize_radio[1])``.
    The image is resized to ``size`` scaled by that ratio, using an
    interpolation mode drawn uniformly from ``allow_resize_interpolations``.
    It is then resized back to ``size`` with ``interpolation``. The transform
    is applied on every call; there is no probability parameter.

    Args:
        size: target size. An int or a sequence of 1 or 2 ints, passed to
            ``torchvision.transforms.functional.resize``.
        resize_radio: ``(low, high)`` range of the random ratio.
        allow_resize_interpolations: modes to choose from for the first resize.
            Entries may repeat to weight the choice.
        interpolation: mode for the final resize back to ``size``. An int is
            accepted but deprecated.
        max_size: forwarded to both resize calls.
        antialias: forwarded to both resize calls.
    """

    def __init__(
        self,
        size,
        resize_radio=(0.5, 1),
        allow_resize_interpolations=(InterpolationMode.BICUBIC, InterpolationMode.BILINEAR, InterpolationMode.BILINEAR),
        interpolation=InterpolationMode.BICUBIC,
        max_size=None,
        antialias=None,
    ):
        super().__init__()
        if not isinstance(size, (int, Sequence)):
            raise TypeError(f"Size should be int or sequence. Got {type(size)}")
        if isinstance(size, Sequence) and len(size) not in (1, 2):
            raise ValueError("If size is a sequence, it should have 1 or 2 values")

        self.size = size
        self.max_size = max_size

        if isinstance(interpolation, int):
            warnings.warn(
                "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
                "Please use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.antialias = antialias

        self.resize_radio = resize_radio
        self.allow_resize_interpolations = allow_resize_interpolations

    def random_resize_params(self):
        """Draw the intermediate size and interpolation mode for one call.

        Returns:
            ``(size, interpolation)``. ``size`` is ``self.size`` scaled by the
            random ratio and truncated, with every entry at least 1. It is an
            int for an int ``self.size`` and a tuple of the same length
            otherwise.
        """
        low, high = self.resize_radio[0], self.resize_radio[1]
        radio = torch.rand(1) * (high - low) + low

        # max(1, ...) keeps a tiny ratio from producing a zero-sized resize.
        if isinstance(self.size, int):
            size = max(1, int(self.size * radio))
        elif isinstance(self.size, Sequence):
            # Scale every entry, so 1-element sequences (accepted by __init__)
            # work too. Previously they raised IndexError on size[1].
            size = tuple(max(1, int(s * radio)) for s in self.size)
        else:
            raise RuntimeError()

        choice = torch.randint(low=0, high=len(self.allow_resize_interpolations), size=(1,)).item()
        interpolation = self.allow_resize_interpolations[choice]
        return size, interpolation

    def forward(self, img):
        size, interpolation = self.random_resize_params()
        img = TVF.resize(img, size, interpolation, self.max_size, self.antialias)
        img = TVF.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
        return img

    def __repr__(self) -> str:
        # All parameters go inside one pair of parentheses. Previously
        # resize_radio was printed after the closing parenthesis.
        detail = (
            f"(size={self.size}, interpolation={self.interpolation.value}, "
            f"max_size={self.max_size}, antialias={self.antialias}, "
            f"resize_radio={self.resize_radio})"
        )
        return f"{self.__class__.__name__}{detail}"
|
|
|
|
class Compose(object):
    """Chain several transforms, feeding each one the previous one's output.

    If a transform returns a ``tuple``, the tuple is unpacked into the next
    transform's positional arguments (e.g. ``AxisScale`` called with
    ``(surface, points)`` returns ``(surface, points)``). Any other return
    value, such as a sample dict or a single tensor, is passed on as ONE
    positional argument. Previously every return value was unpacked, so a
    dict was split into its keys and a tensor into its rows.

    The last transform's return value is returned unchanged. With no
    transforms, the positional arguments are returned as a tuple.

    This transform does not support torchscript.

    Args:
        transforms: sequence of callables applied in order.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, *args):
        result = args
        for t in self.transforms:
            result = t(*args)
            # Only tuples are multiple values; wrap everything else so it is
            # forwarded intact.
            args = result if isinstance(result, tuple) else (result,)
        return result

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += ' {0}'.format(t)
        format_string += '\n)'
        return format_string
|
|
|
|
def identity(*args, **kwargs):
    """Return the positional input unchanged, ignoring keyword arguments.

    A single positional argument is returned as-is. Zero or several are
    returned as a tuple.
    """
    return args[0] if len(args) == 1 else args
|
|
|
|
def build_transforms(cfg):
    """Instantiate every transform described in ``cfg`` and chain them.

    Args:
        cfg: ``None`` or a mapping whose values are configs accepted by
            ``instantiate_from_config``. Keys are only labels and are ignored.

    Returns:
        ``identity`` when ``cfg`` is ``None`` or empty. Otherwise a ``Compose``
        of the instantiated transforms, in the mapping's iteration order.
    """
    # An empty config would otherwise produce Compose([]). That returns its
    # positional args as a tuple, e.g. (sample,) instead of sample, which is
    # inconsistent with the identity returned for None.
    if cfg is None or len(cfg) == 0:
        return identity

    transforms = []
    for _, cfg_instance in cfg.items():
        transform_instance = instantiate_from_config(cfg_instance)
        transforms.append(transform_instance)
        print(f"Build transform: {transform_instance}")

    return Compose(transforms)
|
|
|
|