# BrainFM/utils/interpol/splines.py
# Uploaded by peirong26 ("Upload 187 files", commit 2571f24, verified).
"""Weights and derivatives of spline orders 0 to 7."""
import torch
from enum import Enum
from .jit_utils import square, cube, pow4, pow5, pow6, pow7
class InterpolationType(Enum):
    """Spline interpolation orders, addressable by descriptive or ordinal name.

    For orders 0 to 3 the first name bound to a value is the canonical member
    and the second one (``zeroth``, ``first``, ``second``, ``third``) is an
    alias that resolves to that same member.
    """
    nearest = 0
    zeroth = 0
    linear = 1
    first = 1
    quadratic = 2
    second = 2
    cubic = 3
    third = 3
    fourth = 4
    fifth = 5
    sixth = 6
    seventh = 7
@torch.jit.script
class Spline:
    """Centred B-spline kernel of order 0 to 7 and its first two derivatives.

    The kernel is even in ``x`` and supported on ``|x| < (order + 1) / 2``.
    Each quantity is available in two flavours:

    * ``weight`` / ``grad`` / ``hess`` set values outside the support to 0;
    * ``fastweight`` / ``fastgrad`` / ``fasthess`` skip that masking. They are
      only meaningful for inputs already known to lie inside the support;
      outside it they keep evaluating the outermost polynomial piece.

    Orders outside 0..7 are not rejected at construction time; evaluating
    them raises ``NotImplementedError``.
    """

    def __init__(self, order: int = 1):
        # Polynomial degree of the spline (0 = nearest, 1 = linear, ...).
        self.order = order

    def weight(self, x):
        """Evaluate the kernel at ``x``, returning 0 outside its support."""
        w = self.fastweight(x)
        zero = torch.zeros([1], dtype=x.dtype, device=x.device)
        # The support is the open interval |x| < (order + 1) / 2.
        w = torch.where(x.abs() >= (self.order + 1)/2, zero, w)
        return w

    def fastweight(self, x):
        """Evaluate the kernel at ``x`` without masking outside the support."""
        if self.order == 0:
            return torch.ones(x.shape, dtype=x.dtype, device=x.device)
        # The kernel is even: evaluate the polynomial pieces on |x|.
        # Even orders switch pieces at half-integers, odd orders at integers.
        x = x.abs()
        if self.order == 1:
            return 1 - x
        if self.order == 2:
            x_low = 0.75 - square(x)
            x_up = 0.5 * square(1.5 - x)
            return torch.where(x < 0.5, x_low, x_up)
        if self.order == 3:
            x_low = (x * x * (x - 2.) * 3. + 4.) / 6.
            x_up = cube(2. - x) / 6.
            return torch.where(x < 1., x_low, x_up)
        if self.order == 4:
            x_low = square(x)
            x_low = x_low * (x_low * 0.25 - 0.625) + 115. / 192.
            x_mid = x * (x * (x * (5. - x) / 6. - 1.25) + 5./24.) + 55./96.
            x_up = pow4(x - 2.5) / 24.
            return torch.where(x < 0.5, x_low, torch.where(x < 1.5, x_mid, x_up))
        if self.order == 5:
            x_low = square(x)
            x_low = x_low * (x_low * (0.25 - x / 12.) - 0.5) + 0.55
            x_mid = x * (x * (x * (x * (x / 24. - 0.375) + 1.25) - 1.75) + 0.625) + 0.425
            x_up = pow5(3 - x) / 120.
            return torch.where(x < 1., x_low, torch.where(x < 2., x_mid, x_up))
        if self.order == 6:
            x_low = square(x)
            x_low = x_low * (x_low * (7./48. - x_low/36.) - 77./192.) + 5887./11520.
            x_mid_low = (x * (x * (x * (x * (x * (x / 48. - 7./48.) + 0.328125)
                         - 35./288.) - 91./256.) - 7./768.) + 7861./15360.)
            x_mid_up = (x * (x * (x * (x * (x * (7./60. - x / 120.) - 0.65625)
                        + 133./72.) - 2.5703125) + 1267./960.) + 1379./7680.)
            x_up = pow6(x - 3.5) / 720.
            return torch.where(x < .5, x_low,
                               torch.where(x < 1.5, x_mid_low,
                                           torch.where(x < 2.5, x_mid_up, x_up)))
        if self.order == 7:
            x_low = square(x)
            x_low = (x_low * (x_low * (x_low * (x / 144. - 1./36.)
                     + 1./9.) - 1./3.) + 151./315.)
            x_mid_low = (x * (x * (x * (x * (x * (x * (0.05 - x/240.) - 7./30.)
                         + 0.5) - 7./18.) - 0.1) - 7./90.) + 103./210.)
            x_mid_up = (x * (x * (x * (x * (x * (x * (x / 720. - 1./36.)
                        + 7./30.) - 19./18.) + 49./18.) - 23./6.) + 217./90.)
                        - 139./630.)
            x_up = pow7(4 - x) / 5040.
            return torch.where(x < 1., x_low,
                               torch.where(x < 2., x_mid_low,
                                           torch.where(x < 3., x_mid_up, x_up)))
        raise NotImplementedError

    def grad(self, x):
        """First derivative of the kernel at ``x``, 0 outside the support."""
        if self.order == 0:
            # Piecewise-constant kernel: its derivative is taken to be 0.
            return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
        g = self.fastgrad(x)
        zero = torch.zeros([1], dtype=x.dtype, device=x.device)
        g = torch.where(x.abs() >= (self.order + 1)/2, zero, g)
        return g

    def fastgrad(self, x):
        """First derivative of the kernel at ``x`` without support masking."""
        if self.order == 0:
            return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
        # Chain rule for an even kernel: d/dx k(|x|) = k'(|x|) * sign(x).
        # Because sign(0) == 0, the result is 0 at the origin.
        return self._fastgrad(x.abs()).mul(x.sign())

    def _fastgrad(self, x):
        """Derivative of the kernel with respect to ``|x|``.

        ``x`` is expected to be non-negative (``fastgrad`` passes
        ``x.abs()``). Orders 0 and > 7 raise ``NotImplementedError``.
        """
        if self.order == 1:
            # d/d|x| (1 - |x|) = -1. This matches the sign convention of
            # every other order below, which return the true derivative.
            return -torch.ones(x.shape, dtype=x.dtype, device=x.device)
        if self.order == 2:
            return torch.where(x < 0.5, -2*x, x - 1.5)
        if self.order == 3:
            g_low = x * (x * 1.5 - 2)
            g_up = -0.5 * square(2 - x)
            return torch.where(x < 1, g_low, g_up)
        if self.order == 4:
            g_low = x * (square(x) - 1.25)
            g_mid = x * (x * (x * (-2./3.) + 2.5) - 2.5) + 5./24.
            g_up = cube(2. * x - 5.) / 48.
            return torch.where(x < 0.5, g_low,
                               torch.where(x < 1.5, g_mid, g_up))
        if self.order == 5:
            g_low = x * (x * (x * (x * (-5./12.) + 1.)) - 1.)
            g_mid = x * (x * (x * (x * (5./24.) - 1.5) + 3.75) - 3.5) + 0.625
            g_up = pow4(x - 3.) / (-24.)
            return torch.where(x < 1, g_low,
                               torch.where(x < 2, g_mid, g_up))
        if self.order == 6:
            g_low = square(x)
            g_low = x * (g_low * (7./12.) - square(g_low) / 6. - 77./96.)
            g_mid_low = (x * (x * (x * (x * (x * 0.125 - 35./48.) + 1.3125)
                         - 35./96.) - 0.7109375) - 7./768.)
            g_mid_up = (x * (x * (x * (x * (x / (-20.) + 7./12.) - 2.625)
                        + 133./24.) - 5.140625) + 1267./960.)
            g_up = pow5(2*x - 7) / 3840.
            return torch.where(x < 0.5, g_low,
                               torch.where(x < 1.5, g_mid_low,
                                           torch.where(x < 2.5, g_mid_up,
                                                       g_up)))
        if self.order == 7:
            g_low = square(x)
            g_low = x * (g_low * (g_low * (x * (7./144.) - 1./6.) + 4./9.) - 2./3.)
            g_mid_low = (x * (x * (x * (x * (x * (x * (-7./240.) + 3./10.)
                         - 7./6.) + 2.) - 7./6.) - 1./5.) - 7./90.)
            g_mid_up = (x * (x * (x * (x * (x * (x * (7./720.) - 1./6.)
                        + 7./6.) - 38./9.) + 49./6.) - 23./3.) + 217./90.)
            g_up = pow6(x - 4) / (-720.)
            return torch.where(x < 1, g_low,
                               torch.where(x < 2, g_mid_low,
                                           torch.where(x < 3, g_mid_up, g_up)))
        raise NotImplementedError

    def hess(self, x):
        """Second derivative of the kernel at ``x``, 0 outside the support."""
        if self.order == 0:
            return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
        h = self.fasthess(x)
        zero = torch.zeros([1], dtype=x.dtype, device=x.device)
        h = torch.where(x.abs() >= (self.order + 1)/2, zero, h)
        return h

    def fasthess(self, x):
        """Second derivative of the kernel at ``x`` without support masking."""
        if self.order in (0, 1):
            # Piecewise-constant / piecewise-linear: second derivative taken as 0.
            return torch.zeros(x.shape, dtype=x.dtype, device=x.device)
        # The second derivative of an even kernel is even: evaluate on |x|.
        x = x.abs()
        if self.order == 2:
            one = torch.ones([1], dtype=x.dtype, device=x.device)
            return torch.where(x < 0.5, -2 * one, one)
        if self.order == 3:
            return torch.where(x < 1, 3. * x - 2., 2. - x)
        if self.order == 4:
            return torch.where(x < 0.5, 3. * square(x) - 1.25,
                               torch.where(x < 1.5, x * (-2. * x + 5.) - 2.5,
                                           square(2. * x - 5.) / 8.))
        if self.order == 5:
            h_low = square(x)
            h_low = - h_low * (x * (5./3.) - 3.) - 1.
            h_mid = x * (x * (x * (5./6.) - 9./2.) + 15./2.) - 7./2.
            h_up = 9./2. - x * (x * (x/6. - 3./2.) + 9./2.)
            return torch.where(x < 1, h_low,
                               torch.where(x < 2, h_mid, h_up))
        if self.order == 6:
            h_low = square(x)
            h_low = - h_low * (h_low * (5./6) - 7./4.) - 77./96.
            h_mid_low = (x * (x * (x * (x * (5./8.) - 35./12.) + 63./16.)
                         - 35./48.) - 91./128.)
            h_mid_up = -(x * (x * (x * (x/4. - 7./3.) + 63./8.) - 133./12.)
                         + 329./64.)
            h_up = (x * (x * (x * (x/24. - 7./12.) + 49./16.) - 343./48.)
                    + 2401./384.)
            return torch.where(x < 0.5, h_low,
                               torch.where(x < 1.5, h_mid_low,
                                           torch.where(x < 2.5, h_mid_up,
                                                       h_up)))
        if self.order == 7:
            h_low = square(x)
            h_low = h_low * (h_low*(x * (7./24.) - 5./6.) + 4./3.) - 2./3.
            h_mid_low = - (x * (x * (x * (x * (x * (7./40.) - 3./2.) + 14./3.)
                           - 6.) + 7./3.) + 1./5.)
            h_mid_up = (x * (x * (x * (x * (x * (7./120.) - 5./6.) + 14./3.)
                        - 38./3.) + 49./3.) - 23./3.)
            h_up = - (x * (x * (x * (x * (x/120. - 1./6.) + 4./3.) - 16./3.)
                      + 32./3.) - 128./15.)
            return torch.where(x < 1, h_low,
                               torch.where(x < 2, h_mid_low,
                                           torch.where(x < 3, h_mid_up,
                                                       h_up)))
        raise NotImplementedError