|
|
|
|
|
import collections |
|
|
from ShapeID.DiffEqs.misc import _scaled_dot_product, _convert_to_tensor |
|
|
|
|
|
_ButcherTableau = collections.namedtuple('_ButcherTableau', 'alpha beta c_sol c_error') |
|
|
|
|
|
|
|
|
class _RungeKuttaState(collections.namedtuple('_RungeKuttaState', 'y1, f1, t0, t1, dt, interp_coeff')): |
|
|
"""Saved state of the Runge Kutta solver. |
|
|
|
|
|
Attributes: |
|
|
y1: Tensor giving the function value at the end of the last time step. |
|
|
f1: Tensor giving derivative at the end of the last time step. |
|
|
t0: scalar float64 Tensor giving start of the last time step. |
|
|
t1: scalar float64 Tensor giving end of the last time step. |
|
|
dt: scalar float64 Tensor giving the size for the next time step. |
|
|
interp_coef: list of Tensors giving coefficients for polynomial |
|
|
interpolation between `t0` and `t1`. |
|
|
""" |
|
|
|
|
|
|
|
|
def _runge_kutta_step(func, y0, f0, t0, dt, tableau):
    """Take an arbitrary explicit Runge-Kutta step and estimate its error.

    Args:
        func: Function to evaluate like `func(t, y)` to compute the time
            derivative of `y`. It is called once per stage of `tableau` and must
            return a sequence parallel to `y`.
        y0: Sequence of Tensors giving the initial value of each state component.
        f0: Sequence of Tensors giving the initial derivative, computed from
            `func(t0, y0)`; it seeds the list of stage derivatives.
        t0: Scalar giving the initial time. It is converted to a Tensor with the
            dtype and device of `y0[0]`.
        dt: Scalar giving the size of the desired time step, converted like `t0`.
        tableau: _ButcherTableau describing how to take the Runge-Kutta step.

    Returns:
        Tuple `(y1, f1, y1_error, k)`:
          - `y1`: tuple with the estimated value of each state component at
            `t1 = t0 + dt`.
          - `f1`: tuple with each component's derivative from the final stage
            (`k_[-1]`). This is the derivative at `(t1, y1)` when the tableau's
            last stage is evaluated at `t1` on `y1`, as in first-same-as-last
            tableaus.
          - `y1_error`: tuple with the estimated error of each component.
          - `k`: tuple holding, for each state component, the list of stage
            derivatives used to compute the terms above.
    """
    dtype = y0[0].dtype
    device = y0[0].device

    t0 = _convert_to_tensor(t0, dtype=dtype, device=device)
    dt = _convert_to_tensor(dt, dtype=dtype, device=device)

    # One list of stage derivatives per state component, seeded with f0.
    k = tuple([f0_] for f0_ in f0)
    for alpha_i, beta_i in zip(tableau.alpha, tableau.beta):
        ti = t0 + alpha_i * dt
        yi = tuple(y0_ + _scaled_dot_product(dt, beta_i, k_) for y0_, k_ in zip(y0, k))
        # Append this stage's derivative to each component's list in place.
        for k_, f_ in zip(k, func(ti, yi)):
            k_.append(f_)

    # If c_sol ends in 0 and its other weights equal the last stage's beta row,
    # the last stage's input state `yi` already is the solution, so the extra
    # weighted combination is skipped (this holds for Dormand-Prince, for
    # example). Assumes `_scaled_dot_product(dt, w, k)` forms dt * sum(w_j * k_j).
    if not (tableau.c_sol[-1] == 0 and tableau.c_sol[:-1] == tableau.beta[-1]):
        yi = tuple(y0_ + _scaled_dot_product(dt, tableau.c_sol, k_) for y0_, k_ in zip(y0, k))

    y1 = yi
    f1 = tuple(k_[-1] for k_ in k)
    y1_error = tuple(_scaled_dot_product(dt, tableau.c_error, k_) for k_ in k)
    return (y1, f1, y1_error, k)
|
|
|
|
|
|
|
|
def rk4_step_func(func, t, dt, y, k1=None):
    """Return the increment of one classic fourth-order Runge-Kutta step.

    Args:
        func: Callable `func(t, y)` returning a sequence of derivatives
            parallel to `y`.
        t: Time at the start of the step.
        dt: Step size.
        y: Sequence with the state components at time `t`.
        k1: Optional precomputed `func(t, y)`; evaluated here when omitted.

    Returns:
        Tuple with one increment per component. The caller adds it to `y` to
        obtain the state at `t + dt`; the new state itself is not returned.
    """
    if k1 is None:
        k1 = func(t, y)

    mid_state_a = tuple(state + dt * slope / 2 for state, slope in zip(y, k1))
    k2 = func(t + dt / 2, mid_state_a)

    mid_state_b = tuple(state + dt * slope / 2 for state, slope in zip(y, k2))
    k3 = func(t + dt / 2, mid_state_b)

    end_state = tuple(state + dt * slope for state, slope in zip(y, k3))
    k4 = func(t + dt, end_state)

    weight = dt / 6
    increments = []
    for s1, s2, s3, s4 in zip(k1, k2, k3, k4):
        increments.append((s1 + 2 * s2 + 2 * s3 + s4) * weight)
    return tuple(increments)
|
|
|
|
|
|
|
|
def rk4_alt_step_func(func, t, dt, y, k1=None):
    """Return the increment of one 3/8-rule fourth-order Runge-Kutta step.

    Uses Kutta's 3/8-rule variant of RK4: stages are evaluated at `t`,
    `t + dt/3`, `t + 2*dt/3`, and `t + dt`, and combined with weights 1/8, 3/8,
    3/8, and 1/8. The original note on this variant: smaller error with
    slightly more compute.

    Args:
        func: Callable `func(t, y)` returning a sequence of derivatives
            parallel to `y`.
        t: Time at the start of the step.
        dt: Step size.
        y: Sequence with the state components at time `t`.
        k1: Optional precomputed `func(t, y)`; evaluated here when omitted.

    Returns:
        Tuple with one increment per component. The caller adds it to `y` to
        obtain the state at `t + dt`.
    """
    if k1 is None:
        k1 = func(t, y)

    k2 = func(
        t + dt / 3,
        tuple(state + dt * a / 3 for state, a in zip(y, k1)),
    )
    k3 = func(
        t + dt * 2 / 3,
        tuple(state + dt * (a / -3 + b) for state, a, b in zip(y, k1, k2)),
    )
    k4 = func(
        t + dt,
        tuple(state + dt * (a - b + c) for state, a, b, c in zip(y, k1, k2, k3)),
    )

    weight = dt / 8
    return tuple(
        (a + 3 * b + 3 * c + d) * weight
        for a, b, c, d in zip(k1, k2, k3, k4)
    )
|
|
|