"""Adaptive-step Tsitouras 5(4) Runge-Kutta solver built on the shared RK helpers."""
import torch
from ShapeID.DiffEqs.misc import _scaled_dot_product, _convert_to_tensor, _is_finite, _select_initial_step, _handle_unused_kwargs
from ShapeID.DiffEqs.solvers import AdaptiveStepsizeODESolver
from ShapeID.DiffEqs.rk_common import _RungeKuttaState, _ButcherTableau, _runge_kutta_step

# Parameters from Tsitouras (2011).
_TSITOURAS_TABLEAU = _ButcherTableau(
    alpha=[0.161, 0.327, 0.9, 0.9800255409045097, 1., 1.],
    beta=[
        [0.161],
        [-0.008480655492357, 0.3354806554923570],
        [2.897153057105494, -6.359448489975075, 4.362295432869581],
        [5.32586482843925895, -11.74888356406283, 7.495539342889836, -0.09249506636175525],
        [5.86145544294642038, -12.92096931784711, 8.159367898576159, -0.071584973281401006, -0.02826905039406838],
        [0.09646076681806523, 0.01, 0.4798896504144996, 1.379008574103742, -3.290069515436081, 2.324710524099774],
    ],
    c_sol=[0.09646076681806523, 0.01, 0.4798896504144996, 1.379008574103742, -3.290069515436081, 2.324710524099774, 0],
    c_error=[
        0.09646076681806523 - 0.001780011052226,
        0.01 - 0.000816434459657,
        0.4798896504144996 - -0.007880878010262,
        1.379008574103742 - 0.144711007173263,
        -3.290069515436081 - -0.582357165452555,
        2.324710524099774 - 0.458082105929187,
        -1 / 66,
    ],
)


def _interp_coeff_tsit5(t0, dt, eval_t):
    """Return the seven interpolation weights for the Tsit5 dense output.

    Args:
        t0: Start time of the step.
        dt: Length of the step.
        eval_t: Time at which the interpolant is evaluated.

    Returns:
        A list ``[b1, ..., b7]`` of Python floats, each a polynomial in the
        normalised position ``(eval_t - t0) / dt``.
    """
    # The position inside the step is collapsed to a Python float, so the
    # returned weights are plain floats rather than tensors.
    t = float((eval_t - t0) / dt)
    b1 = -1.0530884977290216 * t * (t - 1.3299890189751412) * (t**2 - 1.4364028541716351 * t + 0.7139816917074209)
    b2 = 0.1017 * t**2 * (t**2 - 2.1966568338249754 * t + 1.2949852507374631)
    b3 = 2.490627285651252793 * t**2 * (t**2 - 2.38535645472061657 * t + 1.57803468208092486)
    b4 = -16.54810288924490272 * (t - 1.21712927295533244) * (t - 0.61620406037800089) * t**2
    b5 = 47.37952196281928122 * (t - 1.203071208372362603) * (t - 0.658047292653547382) * t**2
    b6 = -34.87065786149660974 * (t - 1.2) * (t - 0.666666666666666667) * t**2
    b7 = 2.5 * (t - 1) * (t - 0.6) * t**2
    return [b1, b2, b3, b4, b5, b6, b7]


def _interp_eval_tsit5(t0, t1, k, eval_t):
    """Evaluate the Tsit5 interpolant at ``eval_t`` for every state component.

    Args:
        t0: Start time of the step the interpolation data belongs to.
        t1: End time of that step.
        k: Tuple with one indexable sequence per state component; combined
            with the interpolation weights via ``_scaled_dot_product``.
        eval_t: Time at which to evaluate.

    Returns:
        Tuple with one interpolated value per state component.
    """
    dt = t1 - t0
    # NOTE(review): the first entry of each sequence is used as the base value
    # of the interpolation. `Tsit5Solver.before_integrate` seeds the sequences
    # with copies of y0, but after an accepted step they are whatever
    # `_runge_kutta_step` returns as `k`. If those are stage derivatives, the
    # base value here would be f0 rather than y(t0) -- confirm against
    # rk_common before relying on interpolated outputs.
    y0 = tuple(k_[0] for k_ in k)
    interp_coeff = _interp_coeff_tsit5(t0, dt, eval_t)
    y_t = tuple(y0_ + _scaled_dot_product(dt, interp_coeff, k_) for y0_, k_ in zip(y0, k))
    return y_t


def _optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0, dfactor=0.2, order=5):
    """Calculate the optimal size for the next Runge-Kutta step.

    Args:
        last_step: Tensor holding the step size that was just attempted.
        mean_error_ratio: Tensor holding the mean squared ratio of the error
            estimate to the tolerance (its ``.device`` is read).
        safety: Safety factor dividing the raw scaling estimate.
        ifactor: Largest factor by which the step may grow.
        dfactor: Smallest factor to which the step may shrink.
        order: Order used for the exponent ``1 / order``.

    Returns:
        The proposed next step size, ``last_step`` divided by a clamped factor.
    """
    if mean_error_ratio == 0:
        # No measurable error: grow by the maximum allowed factor.
        return last_step * ifactor
    if mean_error_ratio < 1:
        # With dfactor == 1 the clamp below caps `factor` at 1, so a step that
        # met the tolerance is never followed by a smaller one.
        dfactor = _convert_to_tensor(1, dtype=torch.float64, device=mean_error_ratio.device)
    # The square root turns the mean squared ratio into an RMS ratio.
    error_ratio = torch.sqrt(mean_error_ratio).type_as(last_step)
    exponent = torch.tensor(1 / order).type_as(last_step)
    factor = torch.max(1 / ifactor, torch.min(error_ratio**exponent / safety, 1 / dfactor))
    return last_step / factor


def _abs_square(x):
    """Return the element-wise square of ``x``."""
    return torch.mul(x, x)


class Tsit5Solver(AdaptiveStepsizeODESolver):
    """Adaptive step-size ODE solver using the Tsitouras 5(4) tableau."""

    def __init__(
        self, func, y0, rtol, atol, first_step=None, safety=0.9, ifactor=10.0, dfactor=0.2, max_num_steps=2**31 - 1,
        **unused_kwargs
    ):
        """Store the problem definition and step-size control parameters.

        Args:
            func: Right-hand side, called as ``func(t, y)``.
            y0: Tuple of tensors holding the initial state; ``y0[0].device``
                decides where the control tensors are placed.
            rtol: Relative tolerance used in the error test.
            atol: Absolute tolerance used in the error test.
            first_step: Optional initial step size. When ``None`` it is
                chosen by ``_select_initial_step`` in ``before_integrate``.
            safety: Safety factor forwarded to ``_optimal_step_size``.
            ifactor: Maximum growth factor forwarded to ``_optimal_step_size``.
            dfactor: Maximum shrink factor forwarded to ``_optimal_step_size``.
            max_num_steps: Upper bound on attempted steps per ``advance`` call.
            **unused_kwargs: Passed to ``_handle_unused_kwargs``.
        """
        _handle_unused_kwargs(self, unused_kwargs)
        del unused_kwargs

        self.func = func
        self.y0 = y0
        self.rtol = rtol
        self.atol = atol
        self.first_step = first_step
        self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device)
        self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device)
        self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device)
        self.max_num_steps = _convert_to_tensor(max_num_steps, dtype=torch.int32, device=y0[0].device)

    def before_integrate(self, t):
        """Initialise ``self.rk_state`` at the first time point ``t[0]``.

        Args:
            t: Tensor of output times; only ``t[0]``, ``t.dtype`` and
                ``t.device`` are used here.
        """
        if self.first_step is None:
            first_step = _select_initial_step(self.func, t[0], self.y0, 4, self.rtol, self.atol).to(t)
        else:
            # Honour the caller-supplied step size. The previous code fell
            # back to a hard-coded 0.01 here and silently ignored it.
            first_step = _convert_to_tensor(self.first_step, dtype=t.dtype, device=t.device)
        # Seven copies of each initial state component stand in for the
        # interpolation data until the first step has been accepted.
        initial_interp = tuple([y0_] * 7 for y0_ in self.y0)
        self.rk_state = _RungeKuttaState(
            self.y0,
            self.func(t[0].type_as(self.y0[0]), self.y0), t[0], t[0], first_step,
            initial_interp
        )

    def advance(self, next_t):
        """Interpolate through the next time point, integrating as necessary.

        Steps are attempted until the end time of the current state reaches
        ``next_t``; the solution is then interpolated at ``next_t``.

        Args:
            next_t: Time at which the solution is requested.

        Returns:
            Tuple with the interpolated state components at ``next_t``.

        Raises:
            AssertionError: If ``max_num_steps`` attempts are exhausted, or
                if a step fails the checks in ``_adaptive_tsit5_step``.
        """
        n_steps = 0
        while next_t > self.rk_state.t1:
            assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps)
            self.rk_state = self._adaptive_tsit5_step(self.rk_state)
            n_steps += 1
        return _interp_eval_tsit5(self.rk_state.t0, self.rk_state.t1, self.rk_state.interp_coeff, next_t)

    def _adaptive_tsit5_step(self, rk_state):
        """Take an adaptive Runge-Kutta step to integrate the DiffEqs.

        Args:
            rk_state: Current ``_RungeKuttaState``.

        Returns:
            A new ``_RungeKuttaState``. If the step is accepted it holds the
            advanced solution; otherwise it keeps the previous solution and
            interpolation data. Either way it carries the newly proposed
            step size.

        Raises:
            AssertionError: If ``t0 + dt`` does not exceed ``t0`` or the
                current state contains non-finite values.
        """
        # The third field is the previous step's start time, which is not
        # needed; the fourth field (the previous end time) is the new start.
        y0, f0, _, t0, dt, interp_coeff = rk_state

        ########################################################
        #                      Assertions                      #
        ########################################################
        assert t0 + dt > t0, 'underflow in dt {}'.format(dt.item())
        for y0_ in y0:
            assert _is_finite(torch.abs(y0_)), 'non-finite values in state `y`: {}'.format(y0_)
        y1, f1, y1_error, k = _runge_kutta_step(self.func, y0, f0, t0, dt, tableau=_TSITOURAS_TABLEAU)

        ########################################################
        #                     Error Ratio                      #
        ########################################################
        error_tol = tuple(self.atol + self.rtol * torch.max(torch.abs(y0_), torch.abs(y1_)) for y0_, y1_ in zip(y0, y1))
        tensor_error_ratio = tuple(y1_error_ / error_tol_ for y1_error_, error_tol_ in zip(y1_error, error_tol))
        sq_error_ratio = tuple(_abs_square(tensor_error_ratio_) for tensor_error_ratio_ in tensor_error_ratio)
        # Mean over every element of every state component.
        mean_error_ratio = (
            sum(torch.sum(sq_error_ratio_) for sq_error_ratio_ in sq_error_ratio) /
            sum(sq_error_ratio_.numel() for sq_error_ratio_ in sq_error_ratio)
        )
        accept_step = mean_error_ratio <= 1

        ########################################################
        #                   Update RK State                    #
        ########################################################
        y_next = y1 if accept_step else y0
        f_next = f1 if accept_step else f0
        t_next = t0 + dt if accept_step else t0
        dt_next = _optimal_step_size(dt, mean_error_ratio, self.safety, self.ifactor, self.dfactor)
        # On rejection keep the interpolation data of the incoming state.
        k_next = k if accept_step else interp_coeff
        rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, k_next)
        return rk_state