peirong26's picture
Upload 187 files
2571f24 verified
import torch
from ShapeID.DiffEqs.misc import _convert_to_tensor, _dot_product
def _interp_fit(y0, y1, y_mid, f0, f1, dt):
    """Compute quartic interpolation coefficients over a single interval.

    The polynomial is written in the rescaled variable `x`, with `x = 0` at
    the start of the interval and `x = 1` at its end:
    `p(x) = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e`.
    The weights used below make it satisfy `p(0) = y0`, `p(1/2) = y_mid`,
    `p(1) = y1`, `p'(0) = dt * f0` and `p'(1) = dt * f1`.

    Args:
      y0: sequence of function values at the start of the interval.
      y1: sequence of function values at the end of the interval.
      y_mid: sequence of function values at the mid-point of the interval.
      f0: sequence of derivative values at the start of the interval.
      f1: sequence of derivative values at the end of the interval.
      dt: width of the interval.

    Returns:
      List `[a, b, c, d, e]`. `a`, `b`, `c` and `d` are tuples holding one
      entry per component; `e` is `y0` itself, returned as-is.
    """

    def _combine(weights):
        # One weighted sum per component; zip stops at the shortest input.
        return tuple(
            _dot_product(weights, list(terms))
            for terms in zip(f0, f1, y0, y1, y_mid)
        )

    # Each weight list is ordered to match (f0, f1, y0, y1, y_mid).
    quartic = _combine([-2 * dt, 2 * dt, -8, -8, 16])
    cubic = _combine([5 * dt, -3 * dt, 18, 14, -32])
    quadratic = _combine([-4 * dt, dt, -11, -5, 16])

    # The slope at x = 0 is the start derivative rescaled to the unit interval.
    linear = tuple(dt * slope for slope in f0)
    constant = y0

    return [quartic, cubic, quadratic, linear, constant]
def _interp_evaluate(coefficients, t0, t1, t):
    """Evaluate a fitted interpolating polynomial at time `t`.

    Args:
      coefficients: list of per-component coefficient sequences ordered from
        the highest power down to the constant term, as produced by
        `_interp_fit`. The dtype and device of `coefficients[0][0]` are used
        for all intermediate tensors.
      t0: start of the interval.
      t1: end of the interval.
      t: time at which to interpolate.

    Returns:
      Tuple with one interpolated value per component.

    Raises:
      AssertionError: if `t` does not satisfy `t0 <= t <= t1` (only when
        assertions are enabled).
    """
    reference = coefficients[0][0]
    dtype = reference.dtype
    device = reference.device

    t0 = _convert_to_tensor(t0, dtype=dtype, device=device)
    t1 = _convert_to_tensor(t1, dtype=dtype, device=device)
    t = _convert_to_tensor(t, dtype=dtype, device=device)

    assert (t0 <= t) & (t <= t1), 'invalid interpolation, fails `t0 <= t <= t1`: {}, {}, {}'.format(t0, t, t1)

    # Map t onto the unit interval used by the fitted polynomial.
    elapsed = t - t0
    span = t1 - t0
    x = (elapsed / span).type(dtype).to(device)

    # Build ascending powers 1, x, x**2, ... with one entry per coefficient.
    powers = [torch.tensor(1, dtype=dtype, device=device), x]
    while len(powers) < len(coefficients):
        powers.append(powers[-1] * x)

    # Coefficients run highest power first, so pair them with the powers
    # in descending order.
    return tuple(
        _dot_product(per_component, reversed(powers))
        for per_component in zip(*coefficients)
    )