| import torch | |
| from ShapeID.DiffEqs.misc import _convert_to_tensor, _dot_product | |
def _interp_fit(y0, y1, y_mid, f0, f1, dt):
    """Fit a quartic interpolant across one solver step.

    The polynomial ``p(x) = a * x ** 4 + b * x ** 3 + c * x ** 2 + d * x + e``,
    with ``x`` running from 0 (start of the interval) to 1 (end of the
    interval), reproduces the function values at ``x = 0``, ``x = 0.5`` and
    ``x = 1`` and the ``dt``-scaled derivatives at both ends.

    Args:
        y0: function value at the start of the interval.
        y1: function value at the end of the interval.
        y_mid: function value at the mid-point of the interval.
        f0: derivative value at the start of the interval.
        f1: derivative value at the end of the interval.
        dt: width of the interval.

    Returns:
        List of coefficients `[a, b, c, d, e]`, highest power first. `a` to `d`
        are tuples with one entry per component of the inputs; `e` is `y0`.
    """
    # Weights applied to (f0, f1, y0, y1, y_mid) for the x**4, x**3 and x**2
    # coefficients, in that order.
    weight_rows = (
        (-2 * dt, 2 * dt, -8, -8, 16),
        (5 * dt, -3 * dt, 18, 14, -32),
        (-4 * dt, dt, -11, -5, 16),
    )
    coefficients = [
        tuple(
            _dot_product(list(weights), [slope_start, slope_end, val_start, val_end, val_mid])
            for slope_start, slope_end, val_start, val_end, val_mid in zip(f0, f1, y0, y1, y_mid)
        )
        for weights in weight_rows
    ]
    # Linear term: the derivative at the start, rescaled to the unit interval.
    coefficients += [tuple(dt * slope for slope in f0), y0]
    return coefficients
class _InterpolationRangeError(ValueError, AssertionError):
    """Raised when an interpolation point lies outside ``[t0, t1]``.

    Derives from ``ValueError`` because this is an invalid-argument condition.
    It also derives from ``AssertionError`` so that callers written against
    the earlier ``assert``-based check still catch it.
    """


def _interp_evaluate(coefficients, t0, t1, t):
    """Evaluate polynomial interpolation at the given time point.

    Args:
        coefficients: list of per-power coefficient tuples, highest power
            first, as created by `_interp_fit`.
        t0: scalar giving the start of the interval.
        t1: scalar giving the end of the interval.
        t: scalar giving the desired interpolation point; must satisfy
            ``t0 <= t <= t1``.

    Returns:
        Tuple with one interpolated value per component, i.e. the polynomial
        evaluated at ``x = (t - t0) / (t1 - t0)``.

    Raises:
        ValueError: if ``t`` is not within ``[t0, t1]``. The raised exception
            is also an ``AssertionError``, for backward compatibility.
    """
    # All time values are converted to the dtype/device of the coefficients.
    dtype = coefficients[0][0].dtype
    device = coefficients[0][0].device
    t0 = _convert_to_tensor(t0, dtype=dtype, device=device)
    t1 = _convert_to_tensor(t1, dtype=dtype, device=device)
    t = _convert_to_tensor(t, dtype=dtype, device=device)

    # Explicit check rather than `assert`: assertions are stripped under
    # `python -O`, which would silently turn this into extrapolation.
    if not ((t0 <= t) & (t <= t1)):
        raise _InterpolationRangeError(
            'invalid interpolation, fails `t0 <= t <= t1`: {}, {}, {}'.format(t0, t, t1)
        )

    x = ((t - t0) / (t1 - t0)).type(dtype).to(device)

    # Powers of x: [1, x, x**2, ...], one per coefficient.
    xs = [torch.ones((), dtype=dtype, device=device), x]
    for _ in range(2, len(coefficients)):
        xs.append(xs[-1] * x)

    # Coefficients are stored highest power first, so pair them with the
    # powers in reverse order.
    return tuple(_dot_product(coefficients_, reversed(xs)) for coefficients_ in zip(*coefficients))