File size: 3,581 Bytes
2571f24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
try:
    import jitfields
    available = True
except (ImportError, ModuleNotFoundError):
    jitfields = None
    available = False
from .utils import make_list
import torch


def first2last(input, ndim):
    """Move the dimension at index ``-ndim-1`` to the end, or append one.

    Parameters
    ----------
    input : tensor
    ndim : int
        Number of trailing dimensions that sit after the dimension to move
        (presumably the spatial dimensions -- confirm against callers).

    Returns
    -------
    output : tensor
        If ``input`` has more than ``ndim`` dimensions, ``input`` with
        dimension ``-ndim-1`` moved to the last position. Otherwise,
        ``input`` with a trailing singleton dimension appended.
    inserted : bool
        True if a singleton dimension was appended.
    """
    if input.dim() > ndim:
        # There is a dimension in front of the last `ndim` ones: relocate it.
        return torch.movedim(input, -ndim - 1, -1), False
    # No such dimension: create a singleton one at the end instead.
    return input.unsqueeze(-1), True


def last2first(input, ndim, inserted, grad=False):
    """Undo `first2last`: drop the appended dimension or move it back.

    Parameters
    ----------
    input : tensor
    ndim : int
        Number of dimensions that must end up after the moved dimension.
    inserted : bool
        Whether a singleton dimension was appended beforehand. If so, it is
        squeezed out; otherwise the dimension is moved back by ``ndim``
        positions.
    grad : bool, default=False
        If True, the dimension of interest is the second to last one
        (index ``-2``) rather than the last one, and the last dimension is
        kept in place.

    Returns
    -------
    output : tensor
    """
    # Index of the dimension to remove/move: -1, or -2 when `grad` is set.
    source = -1 - grad
    if not inserted:
        return torch.movedim(input, source, source - ndim)
    return input.squeeze(source)


def grid_pull(input, grid, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Call `jitfields.pull` on `input` and `grid`.

    The number of dimensions ``ndim`` is read from ``grid.shape[-1]``.
    Before the call, dimension ``-ndim-1`` of `input` is moved to the end
    (or a trailing singleton dimension is appended when `input` has at most
    ``ndim`` dimensions); the inverse rearrangement is applied to the result.

    Parameters
    ----------
    input : tensor
    grid : tensor
        Its last dimension has length ``ndim``.
    interpolation
        Forwarded to `jitfields.pull` as ``order``.
    bound, extrapolate, prefilter
        Forwarded unchanged to `jitfields.pull`.

    Returns
    -------
    output : tensor
    """
    ndim = grid.shape[-1]
    options = dict(order=interpolation, bound=bound,
                   extrapolate=extrapolate, prefilter=prefilter)
    moved, inserted = first2last(input, ndim)
    pulled = jitfields.pull(moved, grid, **options)
    return last2first(pulled, ndim, inserted)


def grid_push(input, grid, shape=None, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Call `jitfields.push` on `input` and `grid`.

    The number of dimensions ``ndim`` is read from ``grid.shape[-1]``.
    Before the call, dimension ``-ndim-1`` of `input` is moved to the end
    (or a trailing singleton dimension is appended when `input` has at most
    ``ndim`` dimensions); the inverse rearrangement is applied to the result.

    Parameters
    ----------
    input : tensor
    grid : tensor
        Its last dimension has length ``ndim``.
    shape : optional
        Forwarded as the third positional argument of `jitfields.push`.
    interpolation
        Forwarded to `jitfields.push` as ``order``.
    bound, extrapolate, prefilter
        Forwarded unchanged to `jitfields.push`.

    Returns
    -------
    output : tensor
    """
    ndim = grid.shape[-1]
    options = dict(order=interpolation, bound=bound,
                   extrapolate=extrapolate, prefilter=prefilter)
    moved, inserted = first2last(input, ndim)
    pushed = jitfields.push(moved, grid, shape, **options)
    return last2first(pushed, ndim, inserted)


def grid_count(grid, shape=None, interpolation='linear', bound='zero',
               extrapolate=False):
    """Call `jitfields.count` on `grid`.

    Parameters
    ----------
    grid : tensor
    shape : optional
        Forwarded as the second positional argument of `jitfields.count`.
    interpolation
        Forwarded to `jitfields.count` as ``order``.
    bound, extrapolate
        Forwarded unchanged to `jitfields.count`.

    Returns
    -------
    output
        Whatever `jitfields.count` returns; no rearrangement is applied.
    """
    options = dict(order=interpolation, bound=bound, extrapolate=extrapolate)
    return jitfields.count(grid, shape, **options)


def grid_grad(input, grid, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Call `jitfields.grad` on `input` and `grid`.

    The number of dimensions ``ndim`` is read from ``grid.shape[-1]``.
    Before the call, dimension ``-ndim-1`` of `input` is moved to the end
    (or a trailing singleton dimension is appended when `input` has at most
    ``ndim`` dimensions). On the result, the inverse rearrangement is
    applied to the second to last dimension, leaving the last dimension in
    place (presumably the gradient components -- confirm against jitfields).

    Parameters
    ----------
    input : tensor
    grid : tensor
        Its last dimension has length ``ndim``.
    interpolation
        Forwarded to `jitfields.grad` as ``order``.
    bound, extrapolate, prefilter
        Forwarded unchanged to `jitfields.grad`.

    Returns
    -------
    output : tensor
    """
    ndim = grid.shape[-1]
    options = dict(order=interpolation, bound=bound,
                   extrapolate=extrapolate, prefilter=prefilter)
    moved, inserted = first2last(input, ndim)
    gradient = jitfields.grad(moved, grid, **options)
    # grad=True: operate on dimension -2 and keep the trailing one as is.
    return last2first(gradient, ndim, inserted, True)


def spline_coeff(input, interpolation='linear', bound='dct2', dim=-1,
                 inplace=False):
    """Call `jitfields.spline_coeff` or its in-place variant.

    Parameters
    ----------
    input : tensor
    interpolation
        Forwarded as the second positional argument.
    bound
        Forwarded as ``bound``.
    dim : int, default=-1
        Forwarded as ``dim``.
    inplace : bool, default=False
        If True, `jitfields.spline_coeff_` is used instead of
        `jitfields.spline_coeff`.

    Returns
    -------
    output
        Whatever the selected jitfields function returns.
    """
    if inplace:
        backend = jitfields.spline_coeff_
    else:
        backend = jitfields.spline_coeff
    return backend(input, interpolation, bound=bound, dim=dim)


def spline_coeff_nd(input, interpolation='linear', bound='dct2', dim=None,
                    inplace=False):
    """Call `jitfields.spline_coeff_nd` or its in-place variant.

    Parameters
    ----------
    input : tensor
    interpolation
        Forwarded as the second positional argument.
    bound
        Forwarded as ``bound``.
    dim : default=None
        Forwarded to jitfields under the keyword ``ndim``.
    inplace : bool, default=False
        If True, `jitfields.spline_coeff_nd_` is used instead of
        `jitfields.spline_coeff_nd`.

    Returns
    -------
    output
        Whatever the selected jitfields function returns.
    """
    if inplace:
        backend = jitfields.spline_coeff_nd_
    else:
        backend = jitfields.spline_coeff_nd
    # This wrapper's `dim` argument is named `ndim` on the jitfields side.
    return backend(input, interpolation, bound=bound, ndim=dim)


def resize(image, factor=None, shape=None, anchor='c',
           interpolation=1, prefilter=True, **kwargs):
    """Call `jitfields.resize` on `image`.

    The ``ndim`` passed to jitfields is the largest length among
    ``make_list(factor)``, ``make_list(shape)`` and ``make_list(anchor)``
    (falsy arguments count as empty); if that length is zero,
    ``image.dim() - 2`` is used instead.

    Parameters
    ----------
    image : tensor
    factor, shape, anchor
        Forwarded unchanged to `jitfields.resize`.
    interpolation : default=1
        Forwarded to `jitfields.resize` as ``order``.
    prefilter : default=True
        Forwarded unchanged to `jitfields.resize`.
    bound : default='nearest'
        Read from ``kwargs``. Any other extra keyword is ignored.

    Returns
    -------
    output
        Whatever `jitfields.resize` returns.
    """
    bound = kwargs.get('bound', 'nearest')
    # NOTE(review): `anchor` defaults to 'c'; if make_list wraps a string
    # into a one-element list, the longest length is never 0 with the
    # default anchor and the `image.dim() - 2` fallback is unreachable --
    # confirm against make_list.
    lengths = [len(make_list(arg or [])) for arg in (factor, shape, anchor)]
    ndim = max(lengths)
    if not ndim:
        ndim = image.dim() - 2
    return jitfields.resize(image, factor=factor, shape=shape, ndim=ndim,
                            anchor=anchor, order=interpolation,
                            bound=bound, prefilter=prefilter)


def restrict(image, factor=None, shape=None, anchor='c',
             interpolation=1, reduce_sum=False, **kwargs):
    """Call `jitfields.restrict` on `image`.

    The ``ndim`` passed to jitfields is the largest length among
    ``make_list(factor)``, ``make_list(shape)`` and ``make_list(anchor)``
    (falsy arguments count as empty); if that length is zero,
    ``image.dim() - 2`` is used instead.

    Parameters
    ----------
    image : tensor
    factor, shape, anchor
        Forwarded unchanged to `jitfields.restrict`.
    interpolation : default=1
        Forwarded to `jitfields.restrict` as ``order``.
    reduce_sum : default=False
        Forwarded unchanged to `jitfields.restrict`.
    bound : default='nearest'
        Read from ``kwargs``. Any other extra keyword is ignored.

    Returns
    -------
    output
        Whatever `jitfields.restrict` returns.
    """
    bound = kwargs.get('bound', 'nearest')
    # NOTE(review): `anchor` defaults to 'c'; if make_list wraps a string
    # into a one-element list, the longest length is never 0 with the
    # default anchor and the `image.dim() - 2` fallback is unreachable --
    # confirm against make_list.
    lengths = [len(make_list(arg or [])) for arg in (factor, shape, anchor)]
    ndim = max(lengths)
    if not ndim:
        ndim = image.dim() - 2
    return jitfields.restrict(image, factor=factor, shape=shape, ndim=ndim,
                              anchor=anchor, order=interpolation,
                              bound=bound, reduce_sum=reduce_sum)