| import warnings |
| from math import ceil |
| from . import interp_methods |
|
|
|
|
class NoneClass:
    """Empty placeholder base class, used as ``nnModuleWrapped`` when the
    PyTorch import fails."""
|
|
# Optional PyTorch support: when it is missing, `torch` is None and
# ResizeLayer falls back to a plain (non-nn.Module) base class.
try:
    import torch
    from torch import nn
except ImportError:
    warnings.warn('No PyTorch found, will work only with Numpy')
    torch = None
    nnModuleWrapped = NoneClass
else:
    nnModuleWrapped = nn.Module

# Optional NumPy support: when it is missing, `numpy` is None.
try:
    import numpy
except ImportError:
    warnings.warn('No Numpy found, will work only with PyTorch')
    numpy = None

# At least one array framework is required for anything below to work.
if torch is None and numpy is None:
    raise ImportError("Must have either Numpy or PyTorch but both not found")
|
|
|
|
def resize(input, scale_factors=None, out_shape=None,
           interp_method=interp_methods.cubic, support_sz=None,
           antialiasing=True):
    """Resize a NumPy array or PyTorch tensor by separable interpolation.

    Every dimension whose scale factor differs from 1 is resampled
    independently along that single axis with ``interp_method``.

    Args:
        input: ``numpy.ndarray`` or ``torch.Tensor`` to resize.
        scale_factors: scalar or sequence of per-dimension scale factors. A
            scalar is applied to two dimensions (the leading two for NumPy
            input, the trailing two for torch input); dimensions that are not
            given get factor 1.
        out_shape: desired output shape. A partial shape refers to the
            leading dims for NumPy input and the trailing dims for torch.
        interp_method: interpolation kernel (callable). It must expose a
            ``support_sz`` attribute unless ``support_sz`` is passed.
        support_sz: optional override of the kernel support width.
        antialiasing: when downscaling, stretch the kernel by 1/scale.

    Returns:
        The resized array/tensor, computed with the same framework as
        ``input``.

    Raises:
        ValueError: if neither ``scale_factors`` nor ``out_shape`` is given.
    """
    in_shape, n_dims = input.shape, input.ndim

    # Choose the framework from the input type. `numpy` is None when NumPy
    # is not installed, so guard before touching `numpy.ndarray`. isinstance
    # also sends ndarray subclasses down the NumPy path.
    is_numpy_input = numpy is not None and isinstance(input, numpy.ndarray)
    fw = numpy if is_numpy_input else torch
    eps = fw.finfo(fw.float32).eps

    # Complete both scale factors and output shape to cover every dimension.
    scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape,
                                                    scale_factors, fw)

    # Drop identity dimensions and visit the rest in ascending order of
    # scale factor, i.e. strongest downscaling first (presumably to keep the
    # intermediate results small).
    sorted_filtered_dims_and_scales = [(dim, scale_factors[dim])
                                       for dim in sorted(range(n_dims),
                                       key=lambda ind: scale_factors[ind])
                                       if scale_factors[dim] != 1.]

    if support_sz is None:
        support_sz = interp_method.support_sz

    # NumPy has no device concept; torch tables are built on the input's
    # device.
    device = input.device if fw is torch else None

    output = input
    for dim, scale_factor in sorted_filtered_dims_and_scales:
        # Neighbor indices and their weights for this single dimension.
        field_of_view, weights = prepare_weights_and_field_of_view_1d(
            dim, scale_factor, in_shape[dim], out_shape[dim], interp_method,
            support_sz, antialiasing, fw, eps, device)

        # Resample `output` along `dim` as a weighted sum of neighbors.
        output = apply_weights(output, field_of_view, weights, dim, n_dims,
                               fw)
    return output
|
|
|
|
class ResizeLayer(nnModuleWrapped):
    """Resize layer whose interpolation tables are precomputed for one shape.

    PyTorch only. The per-dimension neighbor indices and weights are built
    once in ``__init__`` and registered as non-trainable parameters, so
    ``forward`` only gathers and sums.
    """

    def __init__(self, in_shape, scale_factors=None, out_shape=None,
                 interp_method=interp_methods.cubic, support_sz=None,
                 antialiasing=True, device=None):
        """Precompute the resampling tables for inputs of shape ``in_shape``.

        Args:
            in_shape: shape of the tensors that will be passed to ``forward``.
            scale_factors: scalar or per-dimension scale factors; a scalar
                applies to the trailing two dimensions.
            out_shape: desired output shape; a partial shape refers to the
                trailing dimensions.
            interp_method: interpolation kernel (callable). It must expose
                ``support_sz`` unless ``support_sz`` is passed.
            support_sz: optional override of the kernel support width.
            antialiasing: when downscaling, stretch the kernel by 1/scale.
            device: device on which to build the tables. With the default
                None no explicit move is performed. The module can still be
                relocated later with ``.to(...)``.
        """
        super(ResizeLayer, self).__init__()

        fw = torch
        eps = fw.finfo(fw.float32).eps

        scale_factors, out_shape = set_scale_and_out_sz(in_shape, out_shape,
                                                        scale_factors, fw)

        if support_sz is None:
            support_sz = interp_method.support_sz

        self.n_dims = len(in_shape)

        # Same ordering/filtering as `resize`: skip identity dimensions and
        # visit the rest in ascending order of scale factor.
        self.sorted_filtered_dims_and_scales = [
            (dim, scale_factors[dim])
            for dim in sorted(range(self.n_dims),
                              key=lambda ind: scale_factors[ind])
            if scale_factors[dim] != 1.]

        field_of_view_list = []
        weights_list = []
        for dim, scale_factor in self.sorted_filtered_dims_and_scales:
            # There is no input tensor at construction time, so the device
            # comes from the `device` argument. Inside this method the name
            # `input` is the Python builtin, which has no `.device`.
            field_of_view, weights = prepare_weights_and_field_of_view_1d(
                dim, scale_factor, in_shape[dim], out_shape[dim],
                interp_method, support_sz, antialiasing, fw, eps, device)

            # Register as frozen parameters so the tables move with the module.
            weights_list.append(nn.Parameter(weights, requires_grad=False))
            field_of_view_list.append(nn.Parameter(field_of_view,
                                                   requires_grad=False))

        self.field_of_view = nn.ParameterList(field_of_view_list)
        self.weights = nn.ParameterList(weights_list)
        self.in_shape = in_shape

    def forward(self, input):
        """Apply the precomputed per-dimension resampling to ``input``.

        ``input`` is expected to have the ``in_shape`` given at construction;
        this is not checked here.
        """
        output = input
        for (dim, _), field_of_view, weights in zip(
                self.sorted_filtered_dims_and_scales,
                self.field_of_view,
                self.weights):
            output = apply_weights(output, field_of_view, weights, dim,
                                   self.n_dims, torch)
        return output
|
|
|
|
def prepare_weights_and_field_of_view_1d(dim, scale_factor, in_sz, out_sz,
                                         interp_method, support_sz,
                                         antialiasing, fw, eps, device=None):
    """Build the 1-D resampling table for one dimension.

    ``dim`` is accepted for interface symmetry but is not used here.

    Returns:
        ``(field_of_view, weights)``: integer neighbor indices into the input
        dimension, one row per output sample, and the matching normalized
        interpolation weights.
    """
    # Possibly stretched kernel (and its support) when antialiasing applies.
    kernel, kernel_support = apply_antialiasing_if_needed(
        interp_method, support_sz, scale_factor, antialiasing)

    # Fractional input coordinate of every output sample.
    grid = get_projected_grid(in_sz, out_sz, scale_factor, fw, device)

    # Input indices that fall inside the kernel window of each output sample.
    fov = get_field_of_view(grid, kernel_support, in_sz, fw, eps, device)

    return fov, get_weights(kernel, grid, fov)
|
|
|
|
def apply_weights(input, field_of_view, weights, dim, n_dims, fw):
    """Resample ``input`` along ``dim`` as weighted sums of neighbors.

    Args:
        input: array/tensor with ``n_dims`` dimensions.
        field_of_view: integer indices into ``input`` along ``dim``. As built
            by ``get_field_of_view`` it holds one row of neighbor indices per
            output sample.
        weights: weights with the same shape as ``field_of_view``.
        dim: dimension being resampled.
        n_dims: number of dimensions of ``input``.
        fw: ``numpy`` or ``torch``.

    Returns:
        ``input`` with dimension ``dim`` replaced by the resampled one.
    """
    # Move the target dimension to the front so one fancy index gathers
    # along it.
    front = fw_swapaxes(input, dim, 0, fw)

    # Shape becomes (*field_of_view.shape, *remaining dims of `front`).
    gathered = front[field_of_view]

    # Append one singleton axis per remaining dimension so the weights
    # broadcast over them.
    weights_shape = tuple(weights.shape) + (1,) * (n_dims - 1)
    weighted = gathered * fw.reshape(weights, weights_shape)

    # Collapse the neighbor axis, then restore the original axis order.
    return fw_swapaxes(weighted.sum(1), 0, dim, fw)
|
|
|
|
def set_scale_and_out_sz(in_shape, out_shape, scale_factors, fw):
    """Complete scale factors and output shape so both cover every dimension.

    Partial specifications are aligned to the leading dimensions when
    ``fw is numpy`` and to the trailing dimensions otherwise (torch).
    Dimensions that are not mentioned keep their size / get factor 1.

    Args:
        in_shape: shape of the input.
        out_shape: optional (partial) output shape.
        scale_factors: optional scalar or (partial) sequence of factors. A
            scalar is expanded to two dimensions.
        fw: ``numpy`` or ``torch``.

    Returns:
        ``(scale_factors, out_shape)``. Both are lists with one entry per
        input dimension; the scale factors are floats.

    Raises:
        ValueError: if both ``scale_factors`` and ``out_shape`` are None.
    """
    if scale_factors is None and out_shape is None:
        raise ValueError("either scale_factors or out_shape should be "
                         "provided")
    n_dims = len(in_shape)

    if out_shape is not None:
        # Fill the unspecified dimensions with the unchanged input sizes.
        # NumPy: given sizes are the leading dims, so append the trailing
        # input sizes from index len(out_shape) onward. Torch: given sizes
        # are the trailing dims, so prepend the leading input sizes. Slice
        # by explicit lengths so an empty out_shape keeps the whole input
        # shape.
        n_given = len(out_shape)
        if fw is numpy:
            out_shape = list(out_shape) + list(in_shape[n_given:])
        else:
            out_shape = (list(in_shape[:max(n_dims - n_given, 0)]) +
                         list(out_shape))
        if scale_factors is None:
            # Derive the factors from the requested sizes.
            scale_factors = [out_sz / in_sz for out_sz, in_sz
                             in zip(out_shape, in_shape)]

    if scale_factors is not None:
        # A single scalar factor is applied to two dimensions.
        if not isinstance(scale_factors, (list, tuple)):
            scale_factors = [scale_factors, scale_factors]

        # Pad the missing dimensions with factor 1, on the side matching the
        # framework's alignment.
        n_missing = n_dims - len(scale_factors)
        if fw is numpy:
            scale_factors = list(scale_factors) + [1] * n_missing
        else:
            scale_factors = [1] * n_missing + list(scale_factors)

        if out_shape is None:
            # NOTE: ceil of a floating product can round up on representation
            # error (e.g. ceil(0.1 * 30) == 4); pass out_shape explicitly when
            # exact sizes matter.
            out_shape = [ceil(scale_factor * in_sz)
                         for scale_factor, in_sz in
                         zip(scale_factors, in_shape)]

    scale_factors = [float(sf) for sf in scale_factors]
    return scale_factors, out_shape
|
|
|
|
def get_projected_grid(in_sz, out_sz, scale_factor, fw, device=None):
    """Return the fractional input-space coordinate of every output sample.

    Uses x_in = x_out / scale + (in_sz - 1) / 2 - (out_sz - 1) / (2 * scale),
    which maps the center of the output grid onto the center of the input
    grid.
    """
    positions = fw_set_device(fw.arange(out_sz), device, fw)
    in_center = (in_sz - 1) / 2
    out_center_scaled = (out_sz - 1) / (2 * scale_factor)
    return positions / scale_factor + in_center - out_center_scaled
|
|
|
|
def get_field_of_view(projected_grid, cur_support_sz, in_sz, fw, eps, device):
    """Return the input indices each output sample draws from, in 1-D.

    Args:
        projected_grid: fractional input coordinate per output sample.
        cur_support_sz: kernel support width (already antialias-stretched).
        in_sz: size of the input dimension.
        fw: ``numpy`` or ``torch``.
        eps: small tolerance for boundaries that land exactly on integers.
        device: torch device for the result (ignored for NumPy).

    Returns:
        Integer index array of shape ``(len(projected_grid), window)``, with
        every index folded into ``[0, in_sz)`` by mirror reflection.
    """
    # Leftmost input index inside each window (eps handles exact integers).
    left_boundaries = fw_ceil(projected_grid - cur_support_sz / 2 - eps, fw)

    # Window offsets 0..window-1 added to every left boundary.
    ordinal_numbers = fw.arange(ceil(cur_support_sz - eps))
    ordinal_numbers = fw_set_device(ordinal_numbers, device, fw)
    field_of_view = left_boundaries[:, None] + ordinal_numbers

    # Fold out-of-range indices back into the input by reflection with the
    # edge sample repeated: the table [0..in_sz-1, in_sz-1..0] is indexed
    # modulo its length, so -1 -> 0 and in_sz -> in_sz - 1.
    mirror = fw_cat((fw.arange(in_sz), fw.arange(in_sz - 1, -1, step=-1)), fw)
    # The lookup table must live on the same device as the indices used to
    # gather from it; PyTorch rejects indexing a CPU tensor with indices that
    # live on another device such as CUDA.
    mirror = fw_set_device(mirror, device, fw)
    field_of_view = mirror[fw.remainder(field_of_view, mirror.shape[0])]
    field_of_view = fw_set_device(field_of_view, device, fw)
    return field_of_view
|
|
|
|
def get_weights(interp_method, projected_grid, field_of_view):
    """Evaluate the kernel for every (output sample, neighbor) pair.

    The kernel is applied to the distance between each output sample's
    projected coordinate and each neighbor index. Each row is then normalized
    to sum to 1; rows whose weights sum to 0 are left as zeros.
    """
    distances = projected_grid[:, None] - field_of_view
    raw = interp_method(distances)

    normalizer = raw.sum(1, keepdims=True)
    # Avoid division by zero; such rows stay all-zero after the division.
    normalizer[normalizer == 0] = 1
    return raw / normalizer
|
|
|
|
def apply_antialiasing_if_needed(interp_method, support_sz, scale_factor,
                                 antialiasing):
    """Return the kernel and support to use for this scale factor.

    For downscaling (``scale_factor < 1``) with antialiasing enabled, the
    kernel is stretched by ``1 / scale_factor``:
    ``k'(x) = scale_factor * k(scale_factor * x)``. The support grows by the
    same factor. Otherwise the inputs are returned unchanged.
    """
    if scale_factor >= 1.0 or not antialiasing:
        return interp_method, support_sz

    def stretched_kernel(arg):
        return scale_factor * interp_method(scale_factor * arg)

    return stretched_kernel, support_sz / scale_factor
|
|
|
|
def fw_ceil(x, fw):
    """Round ``x`` up elementwise and cast it to the framework's integer type."""
    if fw is not numpy:
        return x.ceil().long()
    return fw.int_(fw.ceil(x))
|
|
|
|
def fw_cat(x, fw):
    """Concatenate a sequence of arrays/tensors along their first axis."""
    return fw.concatenate(x) if fw is numpy else fw.cat(x)
|
|
|
|
def fw_swapaxes(x, ax_1, ax_2, fw):
    """Swap two axes of an array (NumPy) or tensor (torch)."""
    if fw is not numpy:
        return x.transpose(ax_1, ax_2)
    return fw.swapaxes(x, ax_1, ax_2)
|
|
def fw_set_device(x, device, fw):
    """Move a torch tensor to ``device``; NumPy arrays are returned as-is."""
    return x if fw is numpy else x.to(device)
|
|