# Source: BrainFM / Generator / interpol / jitfields.py
# Uploaded by peirong26 ("Upload 187 files"), commit 2571f24 (verified).
try:
import jitfields
available = True
except (ImportError, ModuleNotFoundError):
jitfields = None
available = False
from .utils import make_list
import torch
def first2last(input, ndim):
    """Move the channel dimension of ``input`` to the last position.

    The channel dimension is the one immediately preceding the last
    ``ndim`` dimensions. When ``input`` has at most ``ndim`` dimensions it
    is treated as having no channel dimension, and a singleton one is
    appended instead.

    Parameters
    ----------
    input : tensor
        Tensor laid out as ``(..., channel, *spatial)``, or ``(*spatial)``.
    ndim : int
        Number of trailing spatial dimensions.

    Returns
    -------
    input : tensor
        Tensor with the channel dimension last.
    insert : bool
        True when a singleton channel dimension was appended.
    """
    lacks_channel = input.dim() <= ndim
    if lacks_channel:
        return input.unsqueeze(-1), lacks_channel
    # The channel axis sits right before the ndim spatial axes.
    return torch.movedim(input, -ndim - 1, -1), lacks_channel
def last2first(input, ndim, inserted, grad=False):
    """Undo :func:`first2last`: move the channel dimension back before the spatial dims.

    Parameters
    ----------
    input : tensor
        Channels-last tensor. When ``grad`` is true, one extra trailing
        dimension follows the channel dimension, so every axis index used
        here is shifted by one.
    ndim : int
        Number of spatial dimensions.
    inserted : bool
        Value returned by :func:`first2last`. If true, the singleton channel
        dimension that was appended is squeezed away instead of moved.
    grad : bool, default False
        Whether a trailing extra dimension is present after the channels.

    Returns
    -------
    tensor
    """
    channel_axis = -1 - grad
    if not inserted:
        # Destination is the slot just before the ndim spatial axes.
        return torch.movedim(input, channel_axis, channel_axis - ndim)
    return input.squeeze(channel_axis)
def grid_pull(input, grid, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Pull (sample) ``input`` at the coordinates in ``grid`` via ``jitfields.pull``.

    The channel dimension of ``input`` (the one preceding its last
    ``grid.shape[-1]`` dimensions) is moved last for the backend call and
    moved back afterwards. An input with no such dimension gets a temporary
    singleton channel, which is removed from the result.

    Parameters
    ----------
    input : tensor
        Values to sample, channels before the spatial dimensions.
    grid : tensor
        Sampling coordinates; ``grid.shape[-1]`` is the number of spatial dims.
    interpolation : default 'linear'
        Forwarded to jitfields as ``order``.
    bound, extrapolate, prefilter
        Forwarded unchanged to ``jitfields.pull``.

    Returns
    -------
    tensor
        Sampled values, channels back in their original position.
    """
    spatial_ndim = grid.shape[-1]
    channels_last, was_inserted = first2last(input, spatial_ndim)
    pulled = jitfields.pull(channels_last, grid, order=interpolation,
                            bound=bound, extrapolate=extrapolate,
                            prefilter=prefilter)
    return last2first(pulled, spatial_ndim, was_inserted)
def grid_push(input, grid, shape=None, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Push (splat) ``input`` to the coordinates in ``grid`` via ``jitfields.push``.

    As in :func:`grid_pull`, the channel dimension of ``input`` is moved last
    for the backend call and restored afterwards. A singleton channel is
    added temporarily when ``input`` has no channel dimension.

    Parameters
    ----------
    input : tensor
        Values to push, channels before the spatial dimensions.
    grid : tensor
        Target coordinates; ``grid.shape[-1]`` is the number of spatial dims.
    shape : optional
        Forwarded positionally to ``jitfields.push``. Presumably the output
        spatial shape — confirm against the jitfields docs.
    interpolation : default 'linear'
        Forwarded to jitfields as ``order``.
    bound, extrapolate, prefilter
        Forwarded unchanged to ``jitfields.push``.

    Returns
    -------
    tensor
        Pushed values, channels back in their original position.
    """
    spatial_ndim = grid.shape[-1]
    channels_last, was_inserted = first2last(input, spatial_ndim)
    pushed = jitfields.push(channels_last, grid, shape, order=interpolation,
                            bound=bound, extrapolate=extrapolate,
                            prefilter=prefilter)
    return last2first(pushed, spatial_ndim, was_inserted)
def grid_count(grid, shape=None, interpolation='linear', bound='zero',
               extrapolate=False):
    """Forward to ``jitfields.count`` with interpol-style argument names.

    Parameters
    ----------
    grid : tensor
        Coordinates passed straight through.
    shape : optional
        Forwarded positionally. Presumably the output spatial shape —
        confirm against the jitfields docs.
    interpolation : default 'linear'
        Forwarded as ``order``.
    bound, extrapolate
        Forwarded unchanged.

    Returns
    -------
    Whatever ``jitfields.count`` returns; no channel reordering is done here.
    """
    options = dict(order=interpolation, bound=bound, extrapolate=extrapolate)
    return jitfields.count(grid, shape, **options)
def grid_grad(input, grid, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Compute spatial gradients of ``input`` at ``grid`` via ``jitfields.grad``.

    The channel dimension is moved last for the backend call. It is restored
    with ``grad=True``, which accounts for the extra trailing dimension that
    the backend output is expected to carry after the channels.

    Parameters
    ----------
    input : tensor
        Values to differentiate, channels before the spatial dimensions.
    grid : tensor
        Sampling coordinates; ``grid.shape[-1]`` is the number of spatial dims.
    interpolation : default 'linear'
        Forwarded to jitfields as ``order``.
    bound, extrapolate, prefilter
        Forwarded unchanged to ``jitfields.grad``.

    Returns
    -------
    tensor
        Gradients with channels back in their original position and the
        gradient dimension kept last.
    """
    spatial_ndim = grid.shape[-1]
    channels_last, was_inserted = first2last(input, spatial_ndim)
    gradients = jitfields.grad(channels_last, grid, order=interpolation,
                               bound=bound, extrapolate=extrapolate,
                               prefilter=prefilter)
    return last2first(gradients, spatial_ndim, was_inserted, grad=True)
def spline_coeff(input, interpolation='linear', bound='dct2', dim=-1,
                 inplace=False):
    """Compute spline coefficients along one dimension via jitfields.

    Parameters
    ----------
    input : tensor
        Forwarded unchanged.
    interpolation : default 'linear'
        Forwarded positionally as the spline order.
    bound : default 'dct2'
        Forwarded unchanged.
    dim : int, default -1
        Dimension along which coefficients are computed.
    inplace : bool, default False
        If true, use ``jitfields.spline_coeff_`` (the in-place variant);
        otherwise use ``jitfields.spline_coeff``.
    """
    if inplace:
        return jitfields.spline_coeff_(input, interpolation,
                                       bound=bound, dim=dim)
    return jitfields.spline_coeff(input, interpolation, bound=bound, dim=dim)
def spline_coeff_nd(input, interpolation='linear', bound='dct2', dim=None,
                    inplace=False):
    """Compute spline coefficients over several trailing dimensions via jitfields.

    Parameters
    ----------
    input : tensor
        Forwarded unchanged.
    interpolation : default 'linear'
        Forwarded positionally as the spline order.
    bound : default 'dct2'
        Forwarded unchanged.
    dim : int, optional
        Forwarded to jitfields under the keyword ``ndim``.
    inplace : bool, default False
        If true, use ``jitfields.spline_coeff_nd_`` (the in-place variant);
        otherwise use ``jitfields.spline_coeff_nd``.
    """
    if inplace:
        return jitfields.spline_coeff_nd_(input, interpolation,
                                          bound=bound, ndim=dim)
    return jitfields.spline_coeff_nd(input, interpolation,
                                     bound=bound, ndim=dim)
def resize(image, factor=None, shape=None, anchor='c',
           interpolation=1, prefilter=True, **kwargs):
    """Resize ``image`` with ``jitfields.resize``.

    Parameters
    ----------
    image : tensor
        Image to resize. When none of ``factor``/``shape``/``anchor`` is a
        sequence, the number of spatial dimensions is taken to be
        ``image.dim() - 2``. This presumes a ``(batch, channel, *spatial)``
        layout — confirm against callers.
    factor : float or sequence of float, optional
        Resizing factor, forwarded unchanged.
    shape : int or sequence of int, optional
        Output spatial shape, forwarded unchanged.
    anchor : str or sequence of str, default 'c'
        Forwarded unchanged.
    interpolation : default 1
        Forwarded to jitfields as ``order``.
    prefilter : bool, default True
        Forwarded unchanged.
    **kwargs
        Only ``bound`` is read (default ``'nearest'``); other keys are ignored.

    Returns
    -------
    tensor
        Whatever ``jitfields.resize`` returns.
    """
    bound = kwargs.get('bound', 'nearest')
    # Infer the number of spatial dimensions from sequence-valued arguments
    # only. A scalar factor/shape or a single anchor string (including the
    # default 'c') applies to every spatial dimension. Counting it as
    # length 1 would pass ndim=1 to jitfields and make the image.dim() - 2
    # fallback unreachable.
    lengths = [len(arg) for arg in (factor, shape, anchor)
               if isinstance(arg, (list, tuple, range))]
    ndim = max(lengths, default=0) or (image.dim() - 2)
    return jitfields.resize(image, factor=factor, shape=shape, ndim=ndim,
                            anchor=anchor, order=interpolation,
                            bound=bound, prefilter=prefilter)
def restrict(image, factor=None, shape=None, anchor='c',
             interpolation=1, reduce_sum=False, **kwargs):
    """Restrict (downsample) ``image`` with ``jitfields.restrict``.

    Parameters
    ----------
    image : tensor
        Image to restrict. When none of ``factor``/``shape``/``anchor`` is a
        sequence, the number of spatial dimensions is taken to be
        ``image.dim() - 2``. This presumes a ``(batch, channel, *spatial)``
        layout — confirm against callers.
    factor : float or sequence of float, optional
        Forwarded unchanged.
    shape : int or sequence of int, optional
        Output spatial shape, forwarded unchanged.
    anchor : str or sequence of str, default 'c'
        Forwarded unchanged.
    interpolation : default 1
        Forwarded to jitfields as ``order``.
    reduce_sum : bool, default False
        Forwarded unchanged.
    **kwargs
        Only ``bound`` is read (default ``'nearest'``); other keys are ignored.

    Returns
    -------
    tensor
        Whatever ``jitfields.restrict`` returns.
    """
    bound = kwargs.get('bound', 'nearest')
    # Infer the number of spatial dimensions from sequence-valued arguments
    # only. A scalar factor/shape or a single anchor string (including the
    # default 'c') applies to every spatial dimension. Counting it as
    # length 1 would pass ndim=1 to jitfields and make the image.dim() - 2
    # fallback unreachable.
    lengths = [len(arg) for arg in (factor, shape, anchor)
               if isinstance(arg, (list, tuple, range))]
    ndim = max(lengths, default=0) or (image.dim() - 2)
    return jitfields.restrict(image, factor=factor, shape=shape, ndim=ndim,
                              anchor=anchor, order=interpolation,
                              bound=bound, reduce_sum=reduce_sum)