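# Provenance (residue from a web-page scrape, kept as a comment so the module
# parses): uploaded by peirong26 in "Upload 187 files", commit 2571f24 (verified).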
from ShapeID.DiffEqs.tsit5 import Tsit5Solver
from ShapeID.DiffEqs.dopri5 import Dopri5Solver
from ShapeID.DiffEqs.fixed_grid import Euler, Midpoint, RK4
from ShapeID.DiffEqs.fixed_adams import AdamsBashforth, AdamsBashforthMoulton
from ShapeID.DiffEqs.adams import VariableCoefficientAdamsBashforth
from ShapeID.DiffEqs.misc import _check_inputs
# Registry of integration methods: maps each `method` name that `odeint`
# accepts to the solver class implementing it. `odeint` looks the name up
# here and instantiates the resulting class.
SOLVERS = dict(
    explicit_adams=AdamsBashforth,
    fixed_adams=AdamsBashforthMoulton,
    adams=VariableCoefficientAdamsBashforth,
    tsit5=Tsit5Solver,
    dopri5=Dopri5Solver,
    euler=Euler,
    midpoint=Midpoint,
    rk4=RK4,
)
def odeint(func, y0, t, dt, step_size=None, rtol=1e-7, atol=1e-9, method=None, options=None):
    """Integrate a system of ordinary differential equations.

    Solves the initial value problem for a non-stiff system of first order ODEs::

        dy/dt = func(t, y), y(t[0]) = y0

    where y is a Tensor of any shape.

    Output dtypes and numerical precision are based on the dtypes of the inputs `y0`.

    Args:
        func: Function that maps a Tensor holding the state `y` and a scalar Tensor
            `t` into a Tensor of state derivatives with respect to time.
        y0: N-D Tensor giving the starting value of `y` at time point `t[0]`.
            Passed through `_check_inputs` together with `func` and `t`.
        t: 1-D Tensor holding the sequence of time points at which to solve for
            `y`. The initial time point should be the first element of this
            sequence. Normalized/validated by `_check_inputs` (see that helper
            for the exact accepted dtypes and ordering).
        dt: Forwarded unchanged to the solver constructor as the `dt` keyword.
            NOTE(review): presumably the integration step size; confirm against
            the solver implementations.
        step_size: Currently unused by this function; retained only so existing
            callers that pass it keep working.
        rtol: Relative error tolerance forwarded to the solver (default 1e-7).
        atol: Absolute error tolerance forwarded to the solver (default 1e-9).
        method: Optional name of the integration method; must be a key of
            `SOLVERS`. Defaults to 'dopri5' when omitted.
        options: Optional dict of solver-specific options, forwarded unchanged to
            the solver constructor as the single `options` keyword (it is NOT
            unpacked into separate keywords). A non-empty `options` may only be
            given together with an explicit `method`.

    Returns:
        The result of the solver's `integrate(t)`. When `_check_inputs` reports
        that `y0` was a single Tensor, only element 0 of that result is returned;
        presumably a Tensor whose first dimension indexes the time points in `t`,
        with `y0` as its first element (TODO confirm against the solvers).

    Raises:
        ValueError: if a non-empty `options` is supplied without `method`, or if
            `method` is not one of the keys of `SOLVERS`.
        TypeError: presumably raised by `_check_inputs` if `t` or `y0` has an
            invalid dtype (verify against that helper).
    """
    tensor_input, func, y0, t = _check_inputs(func, y0, t)

    if options and method is None:
        raise ValueError('cannot supply `options` without specifying `method`')
    if method is None:
        method = 'dopri5'

    # Translate an unknown method name into the documented ValueError instead of
    # leaking a bare KeyError from the registry lookup.
    try:
        solver_cls = SOLVERS[method]
    except KeyError:
        raise ValueError(
            'Invalid method {!r}; expected one of {}'.format(method, sorted(SOLVERS))
        ) from None

    solver = solver_cls(func, y0, rtol=rtol, atol=atol, dt=dt, options=options)
    solution = solver.integrate(t)

    # For a single-Tensor y0, unwrap the first (and presumably only) result;
    # otherwise return the solver output as-is.
    if tensor_input:
        solution = solution[0]
    return solution