| try: | |
| import jitfields | |
| available = True | |
| except (ImportError, ModuleNotFoundError): | |
| jitfields = None | |
| available = False | |
| from .utils import make_list | |
| import torch | |
def first2last(input, ndim):
    """Move the channel axis of `input` to the last position.

    The axis immediately preceding the `ndim` trailing axes (axis
    ``-ndim-1``) is moved to the end. If `input` has at most `ndim`
    dimensions there is no such axis, so a singleton axis is appended
    instead.

    Parameters
    ----------
    input : tensor
        Tensor whose last `ndim` axes are treated as spatial.
    ndim : int
        Number of trailing spatial axes.

    Returns
    -------
    output : tensor
        Tensor with the channel axis last.
    inserted : bool
        True when a singleton axis was appended rather than moved.
    """
    if input.dim() > ndim:
        return torch.movedim(input, -ndim - 1, -1), False
    # Nothing sits in front of the spatial axes: append a dummy channel axis.
    return input.unsqueeze(-1), True
def last2first(input, ndim, inserted, grad=False):
    """Undo `first2last`: put the channel axis back in front of the spatial axes.

    Parameters
    ----------
    input : tensor
        Tensor with the channel axis last, or second-to-last when `grad`
        is True (one extra trailing axis follows the channel axis).
    ndim : int
        Number of spatial axes that precede the channel axis.
    inserted : bool
        Whether a singleton channel axis was appended by `first2last`;
        if so it is squeezed out instead of moved.
    grad : bool, default=False
        Whether `input` carries one extra trailing axis after the channel.

    Returns
    -------
    tensor
        Tensor with the channel axis moved back (or removed).
    """
    # `grad` is used arithmetically: True shifts every index by one axis.
    offset = 1 + grad
    if not inserted:
        return torch.movedim(input, -offset, -ndim - offset)
    return input.squeeze(-offset)
def grid_pull(input, grid, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Sample `input` at the coordinates stored in `grid` via ``jitfields.pull``.

    The number of spatial dimensions is read from ``grid.shape[-1]``. The
    channel axis of `input` is moved last (or a singleton one appended)
    before the call and restored afterwards -- presumably because
    jitfields works channel-last; confirm against the jitfields docs.

    Parameters
    ----------
    input : tensor
        Image to sample; assumed ``(..., [channel], *inshape)`` -- TODO confirm.
    grid : tensor
        Sampling coordinates whose last axis has length ``ndim``.
    interpolation : optional
        Forwarded to jitfields as ``order``.
    bound, extrapolate, prefilter : optional
        Forwarded unchanged to jitfields.

    Returns
    -------
    tensor
        Sampled values with the channel axis back in its original place.
    """
    spatial_ndim = grid.shape[-1]
    channel_last, was_inserted = first2last(input, spatial_ndim)
    sampled = jitfields.pull(channel_last, grid, order=interpolation,
                             bound=bound, extrapolate=extrapolate,
                             prefilter=prefilter)
    return last2first(sampled, spatial_ndim, was_inserted)
def grid_push(input, grid, shape=None, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Splat `input` onto a lattice at the coordinates in `grid` via ``jitfields.push``.

    The number of spatial dimensions is read from ``grid.shape[-1]``. The
    channel axis of `input` is moved last (or a singleton one appended)
    around the jitfields call and restored afterwards.

    Parameters
    ----------
    input : tensor
        Values to splat; assumed ``(..., [channel], *inshape)`` -- TODO confirm.
    grid : tensor
        Coordinates whose last axis has length ``ndim``.
    shape : sequence[int], optional
        Output spatial shape, forwarded positionally to jitfields.
    interpolation : optional
        Forwarded to jitfields as ``order``.
    bound, extrapolate, prefilter : optional
        Forwarded unchanged to jitfields.

    Returns
    -------
    tensor
        Splatted values with the channel axis back in its original place.
    """
    spatial_ndim = grid.shape[-1]
    channel_last, was_inserted = first2last(input, spatial_ndim)
    pushed = jitfields.push(channel_last, grid, shape, order=interpolation,
                            bound=bound, extrapolate=extrapolate,
                            prefilter=prefilter)
    return last2first(pushed, spatial_ndim, was_inserted)
def grid_count(grid, shape=None, interpolation='linear', bound='zero',
               extrapolate=False):
    """Forward to ``jitfields.count`` with interpol-style argument names.

    Parameters
    ----------
    grid : tensor
        Coordinates, passed through unchanged.
    shape : sequence[int], optional
        Output spatial shape, forwarded positionally.
    interpolation : optional
        Forwarded to jitfields as ``order``.
    bound, extrapolate : optional
        Forwarded unchanged.

    Returns
    -------
    tensor
        Whatever ``jitfields.count`` returns (no axis reordering is done here).
    """
    options = dict(order=interpolation, bound=bound, extrapolate=extrapolate)
    return jitfields.count(grid, shape, **options)
def grid_grad(input, grid, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Sample the spatial gradients of `input` at `grid` via ``jitfields.grad``.

    The number of spatial dimensions is read from ``grid.shape[-1]``. The
    channel axis of `input` is moved last (or a singleton one appended)
    before the call. It is restored afterwards with ``grad=True``, i.e.
    the result is expected to carry one extra trailing (gradient) axis
    after the channel axis.

    Parameters
    ----------
    input : tensor
        Image to differentiate; assumed ``(..., [channel], *inshape)`` --
        TODO confirm.
    grid : tensor
        Coordinates whose last axis has length ``ndim``.
    interpolation : optional
        Forwarded to jitfields as ``order``.
    bound, extrapolate, prefilter : optional
        Forwarded unchanged to jitfields.

    Returns
    -------
    tensor
        Gradients with the channel axis back in front of the spatial axes
        and the gradient axis kept last.
    """
    spatial_ndim = grid.shape[-1]
    channel_last, was_inserted = first2last(input, spatial_ndim)
    gradients = jitfields.grad(channel_last, grid, order=interpolation,
                               bound=bound, extrapolate=extrapolate,
                               prefilter=prefilter)
    return last2first(gradients, spatial_ndim, was_inserted, grad=True)
def spline_coeff(input, interpolation='linear', bound='dct2', dim=-1,
                 inplace=False):
    """Compute spline coefficients along one dimension via jitfields.

    Parameters
    ----------
    input : tensor
        Input values.
    interpolation : optional
        Spline order, forwarded positionally.
    bound : str, default='dct2'
        Boundary condition, forwarded unchanged.
    dim : int, default=-1
        Dimension along which coefficients are computed.
    inplace : bool, default=False
        If True, call ``jitfields.spline_coeff_`` instead of
        ``jitfields.spline_coeff``.

    Returns
    -------
    tensor
        Result of the selected jitfields function.
    """
    if inplace:
        return jitfields.spline_coeff_(input, interpolation, bound=bound,
                                       dim=dim)
    return jitfields.spline_coeff(input, interpolation, bound=bound, dim=dim)
def spline_coeff_nd(input, interpolation='linear', bound='dct2', dim=None,
                    inplace=False):
    """Compute spline coefficients over several trailing dimensions via jitfields.

    Parameters
    ----------
    input : tensor
        Input values.
    interpolation : optional
        Spline order, forwarded positionally.
    bound : str, default='dct2'
        Boundary condition, forwarded unchanged.
    dim : int, optional
        Number of dimensions to process; forwarded to jitfields as ``ndim``.
    inplace : bool, default=False
        If True, call ``jitfields.spline_coeff_nd_`` instead of
        ``jitfields.spline_coeff_nd``.

    Returns
    -------
    tensor
        Result of the selected jitfields function.
    """
    if inplace:
        return jitfields.spline_coeff_nd_(input, interpolation, bound=bound,
                                          ndim=dim)
    return jitfields.spline_coeff_nd(input, interpolation, bound=bound,
                                     ndim=dim)
def resize(image, factor=None, shape=None, anchor='c',
           interpolation=1, prefilter=True, **kwargs):
    """Resize `image` by a factor or to a given shape via ``jitfields.resize``.

    Parameters
    ----------
    image : tensor
        Image to resize. If the number of spatial dimensions cannot be
        inferred from `factor`, `shape` or `anchor`, the last
        ``image.dim() - 2`` dimensions are treated as spatial (presumably
        a leading batch and channel dimension -- confirm against callers).
    factor : float or sequence[float], optional
        Resizing factor(s). A scalar applies to every spatial dimension.
    shape : int or sequence[int], optional
        Output spatial shape. A scalar applies to every spatial dimension.
    anchor : str or sequence[str], default='c'
        Anchor mode(s), forwarded to jitfields.
    interpolation : int or sequence[int], default=1
        Interpolation order, forwarded as ``order``.
    prefilter : bool, default=True
        Forwarded to jitfields.
    **kwargs
        Only ``bound`` (default ``'nearest'``) is read; other keys are ignored.

    Returns
    -------
    tensor
        Output of ``jitfields.resize``.
    """
    bound = kwargs.get('bound', 'nearest')
    # Scalar options (including the default anchor 'c') broadcast to all
    # spatial dimensions, so only explicit sequences pin the dimensionality.
    # Counting scalars as length 1 made the `image.dim() - 2` fallback
    # unreachable: a scalar factor with the default anchor gave ndim=1.
    lengths = [len(opt) for opt in (factor, shape, anchor)
               if isinstance(opt, (list, tuple))]
    ndim = max(lengths, default=0) or (image.dim() - 2)
    return jitfields.resize(image, factor=factor, shape=shape, ndim=ndim,
                            anchor=anchor, order=interpolation,
                            bound=bound, prefilter=prefilter)
def restrict(image, factor=None, shape=None, anchor='c',
             interpolation=1, reduce_sum=False, **kwargs):
    """Restrict `image` by a factor or to a given shape via ``jitfields.restrict``.

    Parameters
    ----------
    image : tensor
        Image to restrict. If the number of spatial dimensions cannot be
        inferred from `factor`, `shape` or `anchor`, the last
        ``image.dim() - 2`` dimensions are treated as spatial (presumably
        a leading batch and channel dimension -- confirm against callers).
    factor : float or sequence[float], optional
        Restriction factor(s). A scalar applies to every spatial dimension.
    shape : int or sequence[int], optional
        Output spatial shape. A scalar applies to every spatial dimension.
    anchor : str or sequence[str], default='c'
        Anchor mode(s), forwarded to jitfields.
    interpolation : int or sequence[int], default=1
        Interpolation order, forwarded as ``order``.
    reduce_sum : bool, default=False
        Forwarded to jitfields.
    **kwargs
        Only ``bound`` (default ``'nearest'``) is read; other keys are ignored.

    Returns
    -------
    tensor
        Output of ``jitfields.restrict``.
    """
    bound = kwargs.get('bound', 'nearest')
    # Scalar options (including the default anchor 'c') broadcast to all
    # spatial dimensions, so only explicit sequences pin the dimensionality.
    # Counting scalars as length 1 made the `image.dim() - 2` fallback
    # unreachable: a scalar factor with the default anchor gave ndim=1.
    lengths = [len(opt) for opt in (factor, shape, anchor)
               if isinstance(opt, (list, tuple))]
    ndim = max(lengths, default=0) or (image.dim() - 2)
    return jitfields.restrict(image, factor=factor, shape=shape, ndim=ndim,
                              anchor=anchor, order=interpolation,
                              bound=bound, reduce_sum=reduce_sum)