# File size: 6,815 Bytes
# 2571f24
import torch
from ShapeID.DiffEqs.misc import _scaled_dot_product, _convert_to_tensor, _is_finite, _select_initial_step, _handle_unused_kwargs
from ShapeID.DiffEqs.solvers import AdaptiveStepsizeODESolver
from ShapeID.DiffEqs.rk_common import _RungeKuttaState, _ButcherTableau, _runge_kutta_step
# Parameters from Tsitouras (2011).
# Coefficient table passed as `tableau=` to `_runge_kutta_step` by `Tsit5Solver`.
# NOTE(review): the field meanings suggested below follow the usual Butcher-tableau
# layout; confirm against the `_ButcherTableau` definition in rk_common.
_TSITOURAS_TABLEAU = _ButcherTableau(
# Six entries, one per row of `beta` (presumably the fractional time offsets of
# the stages after the first -- TODO confirm).
alpha=[0.161, 0.327, 0.9, 0.9800255409045097, 1., 1.],
# Six rows of increasing length (1 to 6 entries), i.e. a lower-triangular table.
beta=[
[0.161],
[-0.008480655492357, 0.3354806554923570],
[2.897153057105494, -6.359448489975075, 4.362295432869581],
[5.32586482843925895, -11.74888356406283, 7.495539342889836, -0.09249506636175525],
[5.86145544294642038, -12.92096931784711, 8.159367898576159, -0.071584973281401006, -0.02826905039406838],
# This last row repeats the first six values of `c_sol` below.
[0.09646076681806523, 0.01, 0.4798896504144996, 1.379008574103742, -3.290069515436081, 2.324710524099774],
],
# Seven entries: the last `beta` row followed by a trailing 0.
c_sol=[0.09646076681806523, 0.01, 0.4798896504144996, 1.379008574103742, -3.290069515436081, 2.324710524099774, 0],
# Seven entries. The first six are written as (matching `c_sol` value) minus a
# second constant; the seventh is -1/66. Presumably this is the difference
# between the solution weights and an embedded comparison set -- TODO confirm.
c_error=[
0.09646076681806523 - 0.001780011052226,
0.01 - 0.000816434459657,
0.4798896504144996 - -0.007880878010262,
1.379008574103742 - 0.144711007173263,
-3.290069515436081 - -0.582357165452555,
2.324710524099774 - 0.458082105929187,
-1 / 66,
],
)
def _interp_coeff_tsit5(t0, dt, eval_t):
t = float((eval_t - t0) / dt)
b1 = -1.0530884977290216 * t * (t - 1.3299890189751412) * (t**2 - 1.4364028541716351 * t + 0.7139816917074209)
b2 = 0.1017 * t**2 * (t**2 - 2.1966568338249754 * t + 1.2949852507374631)
b3 = 2.490627285651252793 * t**2 * (t**2 - 2.38535645472061657 * t + 1.57803468208092486)
b4 = -16.54810288924490272 * (t - 1.21712927295533244) * (t - 0.61620406037800089) * t**2
b5 = 47.37952196281928122 * (t - 1.203071208372362603) * (t - 0.658047292653547382) * t**2
b6 = -34.87065786149660974 * (t - 1.2) * (t - 0.666666666666666667) * t**2
b7 = 2.5 * (t - 1) * (t - 0.6) * t**2
return [b1, b2, b3, b4, b5, b6, b7]
def _interp_eval_tsit5(t0, t1, k, eval_t):
    """Interpolate every state component at ``eval_t`` within ``[t0, t1]``.

    Args:
        t0: Start of the step.
        t1: End of the step.
        k: One indexable sequence per state component; it is combined with the
            interpolation coefficients through ``_scaled_dot_product``.
        eval_t: Point at which to evaluate the interpolant.

    Returns:
        A tuple with one interpolated value per entry of ``k``.
    """
    step = t1 - t0
    # The base value of each component is the first entry of its sequence.
    # NOTE(review): if ``k`` holds stage derivatives this is f(t0, y0) rather
    # than y(t0) -- confirm that this is the intended base value.
    bases = tuple(stages[0] for stages in k)
    weights = _interp_coeff_tsit5(t0, step, eval_t)

    interpolated = []
    for base, stages in zip(bases, k):
        interpolated.append(base + _scaled_dot_product(step, weights, stages))
    return tuple(interpolated)
def _optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0, dfactor=0.2, order=5):
"""Calculate the optimal size for the next Runge-Kutta step."""
if mean_error_ratio == 0:
return last_step * ifactor
if mean_error_ratio < 1:
dfactor = _convert_to_tensor(1, dtype=torch.float64, device=mean_error_ratio.device)
error_ratio = torch.sqrt(mean_error_ratio).type_as(last_step)
exponent = torch.tensor(1 / order).type_as(last_step)
factor = torch.max(1 / ifactor, torch.min(error_ratio**exponent / safety, 1 / dfactor))
return last_step / factor
def _abs_square(x):
return torch.mul(x, x)
class Tsit5Solver(AdaptiveStepsizeODESolver):
    """Adaptive-step ODE solver that steps with the ``_TSITOURAS_TABLEAU`` coefficients.

    The state ``y`` is handled as a sequence of tensors. Each call to
    ``advance`` takes adaptive steps until the requested time is covered and
    then returns values interpolated by ``_interp_eval_tsit5``.
    """

    def __init__(
        self, func, y0, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1,
        **unused_kwargs
    ):
        """Store the problem definition and the step-size control parameters.

        Args:
            func: Callable invoked as ``func(t, y)`` to obtain the derivative.
            y0: Sequence of tensors with the initial state; the device of its
                first element is used for the control parameters below.
            rtol: Relative tolerance used in the error test.
            atol: Absolute tolerance used in the error test.
            first_step: Initial step size, or ``None`` to have it chosen by
                ``_select_initial_step``.
            safety: Forwarded to ``_optimal_step_size``.
            ifactor: Forwarded to ``_optimal_step_size``.
            dfactor: Forwarded to ``_optimal_step_size``.
            max_num_steps: Maximum number of steps a single ``advance`` call
                may take.
            **unused_kwargs: Passed to ``_handle_unused_kwargs``.
        """
        _handle_unused_kwargs(self, unused_kwargs)
        del unused_kwargs

        self.func = func
        self.y0 = y0
        self.rtol = rtol
        self.atol = atol
        self.first_step = first_step

        device = y0[0].device
        self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=device)
        self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=device)
        self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=device)
        self.max_num_steps = _convert_to_tensor(max_num_steps, dtype=torch.int32, device=device)

    def before_integrate(self, t):
        """Initialise ``self.rk_state`` at ``t[0]``.

        Args:
            t: Tensor of time points; only ``t[0]``, ``t.dtype`` and
                ``t.device`` are used here.
        """
        if self.first_step is None:
            first_step = _select_initial_step(self.func, t[0], self.y0, 4, self.rtol, self.atol).to(t)
        else:
            # Honour the caller-supplied initial step instead of a fixed 0.01.
            first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device)
        self.rk_state = _RungeKuttaState(
            self.y0,
            self.func(t[0].type_as(self.y0[0]), self.y0), t[0], t[0], first_step,
            # Interpolation data starts as seven copies of each state component.
            tuple(map(lambda x: [x] * 7, self.y0))
        )

    def advance(self, next_t):
        """Interpolate through the next time point, integrating as necessary.

        Args:
            next_t: Time at which the solution is requested.

        Returns:
            The tuple produced by ``_interp_eval_tsit5`` for ``next_t``.

        Raises:
            AssertionError: If more than ``max_num_steps`` steps are needed.
        """
        n_steps = 0
        while next_t > self.rk_state.t1:
            assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps)
            self.rk_state = self._adaptive_tsit5_step(self.rk_state)
            n_steps += 1
        return _interp_eval_tsit5(self.rk_state.t0, self.rk_state.t1, self.rk_state.interp_coeff, next_t)

    def _adaptive_tsit5_step(self, rk_state):
        """Take an adaptive Runge-Kutta step to integrate the DiffEqs.

        Args:
            rk_state: Current ``_RungeKuttaState``.

        Returns:
            A new ``_RungeKuttaState``. When the step is rejected, the state,
            derivative, time and interpolation data are carried over unchanged
            and only the step size is updated.

        Raises:
            AssertionError: If ``dt`` no longer advances ``t0`` or the current
                state contains non-finite values.
        """
        # The 4th field is the time already reached; it becomes this step's start.
        y0, f0, _, t0, dt, _ = rk_state

        ########################################################
        #                      Assertions                      #
        ########################################################
        assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item())
        for y0_ in y0:
            assert _is_finite(torch.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_)

        y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_TSITOURAS_TABLEAU)

        ########################################################
        #                     Error Ratio                      #
        ########################################################
        error_tol = tuple(self.atol + self.rtol * torch.max(torch.abs(y0_), torch.abs(y1_)) for y0_, y1_ in zip(y0, y1))
        tensor_error_ratio = tuple(y1_error_ / error_tol_ for y1_error_, error_tol_ in zip(y1_error, error_tol))
        sq_error_ratio = tuple(
            torch.mul(tensor_error_ratio_, tensor_error_ratio_) for tensor_error_ratio_ in tensor_error_ratio
        )
        # Mean of the squared ratios over every element of every component.
        mean_error_ratio = (
            sum(torch.sum(sq_error_ratio_) for sq_error_ratio_ in sq_error_ratio) /
            sum(sq_error_ratio_.numel() for sq_error_ratio_ in sq_error_ratio)
        )
        accept_step = mean_error_ratio <= 1

        ########################################################
        #                   Update RK State                    #
        ########################################################
        y_next = y1 if accept_step else y0
        f_next = f1 if accept_step else f0
        t_next = t0 + dt if accept_step else t0
        dt_next = _optimal_step_size(dt, mean_error_ratio, self.safety, self.ifactor, self.dfactor)
        # On rejection keep the interpolation data of the state that was passed
        # in, rather than reaching back to `self.rk_state`.
        k_next = k if accept_step else rk_state.interp_coeff
        rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, k_next)
        return rk_state
|