peirong26's picture
Upload 187 files
2571f24 verified
import collections
import torch
from ShapeID.DiffEqs.solvers import AdaptiveStepsizeODESolver
from ShapeID.DiffEqs.misc import (
_handle_unused_kwargs, _select_initial_step, _convert_to_tensor, _scaled_dot_product, _is_iterable,
_optimal_step_size, _compute_error_ratio
)
# Lower and upper bounds on the order of the variable-order Adams method.
_MIN_ORDER = 1
_MAX_ORDER = 12

# Coefficients gamma*_j of the implicit Adams method in backward-difference
# form (cf. Hairer, Norsett & Wanner, Solving ODEs I, Sec. III.1), indexed by j.
# In this module they are read as gamma_star[order] when estimating the error
# the method would have at the next-higher order.
gamma_star = [
    1,
    -1 / 2,
    -1 / 12,
    -1 / 24,
    -19 / 720,
    -3 / 160,
    -863 / 60480,
    -275 / 24192,
    -33953 / 3628800,
    # Remaining entries are stored as decimal approximations.
    -0.00789255,
    -0.00678585,
    -0.00592406,
    -0.00523669,
    -0.0046775,
    -0.00421495,
    -0.0038269,
]
class _VCABMState(
    collections.namedtuple(
        '_VCABMState',
        'y_n, prev_f, prev_t, next_t, phi, order',
    )
):
    """Tuple record carried from one step of the VCABM solver to the next.

    Fields:
        y_n: current solution value (what ``advance`` returns).
        prev_f: history of right-hand-side evaluations, most recent first.
        prev_t: history of accepted time points, most recent first.
        next_t: trial end time for the next step attempt.
        phi: divided differences consumed by the next step.
        order: current order of the Adams method.

    Algorithm reference: Ernst Hairer, Syvert P. Norsett and Gerhard Wanner,
    Solving Ordinary Differential Equations I - Nonstiff Problems, Sec. III.5.
    """
def g_and_explicit_phi(prev_t, next_t, implicit_phi, k):
    """Compute the Adams integration coefficients and explicit divided differences.

    Implements the variable-step recurrences of Hairer, Norsett & Wanner,
    Solving ODEs I, Sec. III.5.

    Args:
        prev_t: previous time points indexed from the current time ``prev_t[0]``
            backwards (callers in this module fill it with ``appendleft``).
            Needs at least ``k`` entries.
        next_t: end time of the step being attempted.
        implicit_phi: divided differences from the last accepted step; entry
            ``j`` is a tuple of tensors. Needs at least ``k`` entries.
        k: current method order.

    Returns:
        ``(g, explicit_phi)`` where ``g`` is a length ``k + 1`` tensor with
        ``g[0] == 1`` (dtype/device taken from ``prev_t[0]``), and
        ``explicit_phi`` is a deque (maxlen ``k``) whose entry ``j`` is
        ``implicit_phi[j]`` scaled by
        ``beta_j = prod_{i=1..j} (next_t - prev_t[i-1]) / (prev_t[0] - prev_t[i])``
        (``beta_0 == 1``).
    """
    t_n = prev_t[0]
    step = next_t - t_n

    g = torch.empty(k + 1).to(t_n)
    explicit_phi = collections.deque(maxlen=k)
    beta = torch.tensor(1).to(t_n)
    g[0] = 1
    # coeffs[q] starts at 1 / (q + 1); every recurrence pass drops one entry,
    # and the leading entry after pass j is g[j].
    coeffs = 1 / torch.arange(1, k + 2).to(t_n)
    explicit_phi.append(implicit_phi[0])

    for j in range(1, k):
        beta = (next_t - prev_t[j - 1]) / (t_n - prev_t[j]) * beta
        scale = beta.to(implicit_phi[j][0])
        explicit_phi.append(tuple(component * scale for component in implicit_phi[j]))
        if j == 1:
            # For j == 1 the step ratio is exactly one, so it is omitted.
            coeffs = coeffs[:-1] - coeffs[1:]
        else:
            coeffs = coeffs[:-1] - coeffs[1:] * step / (next_t - prev_t[j - 1])
        g[j] = coeffs[0]

    # Final pass produces the extra coefficient g[k] used by the corrector / error estimate.
    coeffs = coeffs[:-1] - coeffs[1:] * step / (next_t - prev_t[k - 1])
    g[k] = coeffs[0]
    return g, explicit_phi
def compute_implicit_phi(explicit_phi, f_n, k):
    """Build the implicit divided differences from the explicit ones.

    The result satisfies ``out[0] = f_n`` and
    ``out[j] = out[j - 1] - explicit_phi[j - 1]`` (component-wise over tuples).

    Args:
        explicit_phi: indexable sequence of tuples (e.g. the deque returned by
            ``g_and_explicit_phi``).
        f_n: tuple of values seeding the table (callers pass the newest ``f``).
        k: requested number of entries.

    Returns:
        A deque whose ``maxlen`` and length are ``min(len(explicit_phi) + 1, k)``.
    """
    depth = min(k, len(explicit_phi) + 1)
    implicit = collections.deque(maxlen=depth)
    current = f_n
    implicit.append(current)
    # Each new row is the previous row minus the matching explicit difference.
    for idx in range(depth - 1):
        current = tuple(lhs - rhs for lhs, rhs in zip(current, explicit_phi[idx]))
        implicit.append(current)
    return implicit
class VariableCoefficientAdamsBashforth(AdaptiveStepsizeODESolver):
    """Variable-step, variable-order Adams predictor-corrector ODE solver.

    Each call to ``_adaptive_adams_step`` forms an explicit predictor from the
    stored divided differences, evaluates ``func`` at the predicted state,
    applies a single implicit correction term, and accepts or rejects the step
    by comparing an error ratio against ``atol + rtol * max(|y_n|, |y_next|)``.
    The order starts at 1 and is adjusted after accepted steps, staying within
    ``1 .. max_order``.
    Algorithm reference: Hairer, Norsett & Wanner, Solving Ordinary
    Differential Equations I, Sec. III.5.
    """

    def __init__(
        self, func, y0, rtol, atol, implicit=True, max_order=_MAX_ORDER, safety=0.9, ifactor=10.0, dfactor=0.2,
        **unused_kwargs
    ):
        """Store the right-hand side, initial state and step-control settings.

        Args:
            func: right-hand side, called as ``func(t, y)``. Its result is
                iterated component-wise alongside ``y0`` (presumably a tuple
                of tensors, one per element of ``y0`` -- confirm with callers).
            y0: sequence of initial-state tensors; ``y0[0].device`` is used
                for the step-control constants below.
            rtol, atol: relative / absolute tolerances. A value that
                ``_is_iterable`` rejects is repeated ``len(y0)`` times.
            implicit: stored on ``self.implicit``; not read by any method of
                this class as written.
            max_order: requested maximum order, clamped to
                ``[_MIN_ORDER, _MAX_ORDER]`` and converted to ``int``.
            safety, ifactor, dfactor: step-size controller factors, stored as
                float64 tensors and passed to ``_optimal_step_size``.
            **unused_kwargs: handed to ``_handle_unused_kwargs`` (presumably
                to warn about ignored options) and then discarded.
        """
        _handle_unused_kwargs(self, unused_kwargs)
        del unused_kwargs
        self.func = func
        self.y0 = y0
        # Broadcast scalar tolerances to one entry per state component.
        self.rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0)
        self.atol = atol if _is_iterable(atol) else [atol] * len(y0)
        self.implicit = implicit
        self.max_order = int(max(_MIN_ORDER, min(max_order, _MAX_ORDER)))
        self.safety = _convert_to_tensor(safety, dtype=torch.float64, device=y0[0].device)
        self.ifactor = _convert_to_tensor(ifactor, dtype=torch.float64, device=y0[0].device)
        self.dfactor = _convert_to_tensor(dfactor, dtype=torch.float64, device=y0[0].device)

    def before_integrate(self, t):
        """Initialise the step history at ``t[0]`` and choose the first step.

        ``prev_t`` / ``prev_f`` keep up to ``max_order + 1`` entries and ``phi``
        up to ``max_order``. New times and ``f`` values are pushed on the left,
        so index 0 is always the most recent one.
        """
        prev_f = collections.deque(maxlen=self.max_order + 1)
        prev_t = collections.deque(maxlen=self.max_order + 1)
        phi = collections.deque(maxlen=self.max_order)
        t0 = t[0]
        # Evaluate f at t0 cast to the state's dtype.
        f0 = self.func(t0.type_as(self.y0[0]), self.y0)
        prev_t.appendleft(t0)
        prev_f.appendleft(f0)
        # With a single point, the only divided difference is f0 itself.
        phi.appendleft(f0)
        # NOTE(review): the literal 2 is presumably the order assumed by the
        # initial-step heuristic; only the first tolerance entries are used.
        first_step = _select_initial_step(self.func, t[0], self.y0, 2, self.rtol[0], self.atol[0], f0=f0).to(t)
        # Start at order 1 with trial end time t0 + first_step.
        self.vcabm_state = _VCABMState(self.y0, prev_f, prev_t, next_t=t[0] + first_step, phi=phi, order=1)

    def advance(self, final_t):
        """Integrate up to ``final_t`` and return the solver state there.

        Repeats ``_adaptive_adams_step`` until the latest accepted time
        ``prev_t[0]`` reaches ``final_t``. A rejected step leaves ``prev_t``
        unchanged and only proposes a new trial ``next_t``, so the loop retries.
        """
        final_t = _convert_to_tensor(final_t).to(self.vcabm_state.prev_t[0])
        while final_t > self.vcabm_state.prev_t[0]:
            self.vcabm_state = self._adaptive_adams_step(self.vcabm_state, final_t)
        # Steps are clipped to final_t, so integration should land on it exactly.
        assert final_t == self.vcabm_state.prev_t[0]
        return self.vcabm_state.y_n

    def _adaptive_adams_step(self, vcabm_state, final_t):
        """Attempt one step from ``prev_t[0]`` to ``vcabm_state.next_t``.

        The target time is clipped to ``final_t``.
        - On rejection, returns a state with the same history and order and a
          new trial ``next_t``.
        - On acceptance, pushes the new time and ``f`` value onto the history
          deques (mutated in place, so they are shared with the incoming
          state), replaces ``phi``, updates the order, and proposes the next
          trial time.

        Returns:
            A new ``_VCABMState``.
        """
        y0, prev_f, prev_t, next_t, prev_phi, order = vcabm_state
        # Never step past the requested output time.
        if next_t > final_t:
            next_t = final_t
        dt = (next_t - prev_t[0])
        dt_cast = dt.to(y0[0])
        # Explicit predictor step.
        g, phi = g_and_explicit_phi(prev_t, next_t, prev_phi, order)
        g = g.to(y0[0])
        # Predictor: y_n plus a dt-scaled combination of the first
        # max(1, order - 1) coefficients and explicit differences, computed per
        # state component (zip(*phi) regroups the differences by component).
        # The exact semantics of _scaled_dot_product are defined elsewhere.
        p_next = tuple(
            y0_ + _scaled_dot_product(dt_cast, g[:max(1, order - 1)], phi_[:max(1, order - 1)])
            for y0_, phi_ in zip(y0, tuple(zip(*phi)))
        )
        # Update phi to implicit.
        next_f0 = self.func(next_t.to(p_next[0]), p_next)
        implicit_phi_p = compute_implicit_phi(phi, next_f0, order + 1)
        # Implicit corrector step: add one extra term to the predictor.
        y_next = tuple(
            p_next_ + dt_cast * g[order - 1] * iphi_ for p_next_, iphi_ in zip(p_next, implicit_phi_p[order - 1])
        )
        # Error estimation.
        # Per-component tolerance: atol + rtol * elementwise max(|y_n|, |y_next|).
        tolerance = tuple(
            atol_ + rtol_ * torch.max(torch.abs(y0_), torch.abs(y1_))
            for atol_, rtol_, y0_, y1_ in zip(self.atol, self.rtol, y0, y_next)
        )
        # Local error estimate from the difference of consecutive coefficients.
        local_error = tuple(dt_cast * (g[order] - g[order - 1]) * iphi_ for iphi_ in implicit_phi_p[order])
        error_k = _compute_error_ratio(local_error, tolerance)
        # Accept only if every entry of error_k is <= 1 (error_k appears to hold
        # one ratio per state component -- see _compute_error_ratio).
        accept_step = (torch.tensor(error_k) <= 1).all()
        if not accept_step:
            # Retry with adjusted step size if step is rejected.
            dt_next = _optimal_step_size(dt, error_k, self.safety, self.ifactor, self.dfactor, order=order)
            return _VCABMState(y0, prev_f, prev_t, prev_t[0] + dt_next, prev_phi, order=order)
        # We accept the step. Evaluate f and update phi.
        # NOTE(review): f is evaluated at the corrected y_next, but the state
        # returned at the end stores the predictor p_next as y_n. Confirm
        # whether this is intended.
        next_f0 = self.func(next_t.to(p_next[0]), y_next)
        implicit_phi = compute_implicit_phi(phi, next_f0, order + 2)
        next_order = order
        # While at most 4 time points are stored, or at low order, raise the
        # order by one, capped at 3 and at max_order.
        if len(prev_t) <= 4 or order < 3:
            next_order = min(order + 1, 3, self.max_order)
        else:
            # Error estimates the method would have had at orders k-1 and k-2.
            error_km1 = _compute_error_ratio(
                tuple(dt_cast * (g[order - 1] - g[order - 2]) * iphi_ for iphi_ in implicit_phi_p[order - 1]), tolerance
            )
            error_km2 = _compute_error_ratio(
                tuple(dt_cast * (g[order - 2] - g[order - 3]) * iphi_ for iphi_ in implicit_phi_p[order - 2]), tolerance
            )
            # NOTE(review): if _compute_error_ratio returns tuples/lists, `+`
            # concatenates rather than adding elementwise. The result is then the
            # smallest ratio from either lower order, compared against the
            # largest ratio at the current order.
            if min(error_km1 + error_km2) < max(error_k):
                next_order = order - 1
            elif order < self.max_order:
                # Estimated error at order k+1, using the gamma_star coefficient.
                error_kp1 = _compute_error_ratio(
                    tuple(dt_cast * gamma_star[order] * iphi_ for iphi_ in implicit_phi_p[order]), tolerance
                )
                if max(error_kp1) < max(error_k):
                    next_order = order + 1
        # Keep step size constant if increasing order. Else use adaptive step size.
        dt_next = dt if next_order > order else _optimal_step_size(
            dt, error_k, self.safety, self.ifactor, self.dfactor, order=order + 1
        )
        # In-place pushes onto the (bounded) history deques; the oldest entries drop off.
        prev_f.appendleft(next_f0)
        prev_t.appendleft(next_t)
        return _VCABMState(p_next, prev_f, prev_t, next_t + dt_next, implicit_phi, order=next_order)