File size: 6,622 Bytes
2571f24 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
import warnings
import torch
def _flatten(sequence):
flat = [p.contiguous().view(-1) for p in sequence]
return torch.cat(flat) if len(flat) > 0 else torch.tensor([])
def _flatten_convert_none_to_zeros(sequence, like_sequence):
flat = [
p.contiguous().view(-1) if p is not None else torch.zeros_like(q).view(-1)
for p, q in zip(sequence, like_sequence)
]
return torch.cat(flat) if len(flat) > 0 else torch.tensor([])
def _possibly_nonzero(x):
return isinstance(x, torch.Tensor) or x != 0
def _scaled_dot_product(scale, xs, ys):
"""Calculate a scaled, vector inner product between lists of Tensors."""
# Using _possibly_nonzero lets us avoid wasted computation.
return sum([(scale * x) * y for x, y in zip(xs, ys) if _possibly_nonzero(x) or _possibly_nonzero(y)])
def _dot_product(xs, ys):
"""Calculate the vector inner product between two lists of Tensors."""
return sum([x * y for x, y in zip(xs, ys)])
def _has_converged(y0, y1, rtol, atol):
"""Checks that each element is within the error tolerance."""
error_tol = tuple(atol + rtol * torch.max(torch.abs(y0_), torch.abs(y1_)) for y0_, y1_ in zip(y0, y1))
error = tuple(torch.abs(y0_ - y1_) for y0_, y1_ in zip(y0, y1))
return all((error_ < error_tol_).all() for error_, error_tol_ in zip(error, error_tol))
def _convert_to_tensor(a, dtype=None, device=None):
if not isinstance(a, torch.Tensor):
a = torch.tensor(a)
if dtype is not None:
a = a.type(dtype)
if device is not None:
a = a.to(device)
return a
def _is_finite(tensor):
_check = (tensor == float('inf')) + (tensor == float('-inf')) + torch.isnan(tensor)
return not _check.any()
def _decreasing(t):
return (t[1:] < t[:-1]).all()
def _assert_increasing(t):
assert (t[1:] > t[:-1]).all(), 't must be strictly increasing or decrasing'
def _is_iterable(inputs):
try:
iter(inputs)
return True
except TypeError:
return False
def _norm(x):
"""Compute RMS norm."""
if torch.is_tensor(x):
return x.norm() / (x.numel()**0.5)
else:
return torch.sqrt(sum(x_.norm()**2 for x_ in x) / sum(x_.numel() for x_ in x))
def _handle_unused_kwargs(solver, unused_kwargs):
if len(unused_kwargs) > 0:
warnings.warn('{}: Unexpected arguments {}'.format(solver.__class__.__name__, unused_kwargs))
def _select_initial_step(fun, t0, y0, order, rtol, atol, f0=None):
    """Empirically select a good initial step.

    The algorithm is described in [1]_.

    Parameters
    ----------
    fun : callable
        Right-hand side of the system.
    t0 : Tensor (scalar)
        Initial value of the independent variable.
    y0 : tuple of Tensor
        Initial value of the dependent variable.
    order : float
        Method order.
    rtol : float or iterable of float
        Desired relative tolerance.
    atol : float or iterable of float
        Desired absolute tolerance.
    f0 : tuple of Tensor, optional
        Value of fun(t0, y0); computed here when not supplied.

    Returns
    -------
    h_abs : Tensor (scalar)
        Absolute value of the suggested initial step.

    References
    ----------
    .. [1] E. Hairer, S. P. Norsett G. Wanner, "Solving Ordinary Differential
        Equations I: Nonstiff Problems", Sec. II.4.
    """
    t0 = t0.to(y0[0])
    if f0 is None:
        f0 = fun(t0, y0)

    # Broadcast scalar tolerances to one value per state component.
    rtol = rtol if _is_iterable(rtol) else [rtol] * len(y0)
    atol = atol if _is_iterable(atol) else [atol] * len(y0)

    scale = tuple(atol_ + torch.abs(y0_) * rtol_ for y0_, atol_, rtol_ in zip(y0, atol, rtol))

    # d0, d1: tolerance-scaled magnitudes of the state and of its derivative.
    d0 = tuple(_norm(y0_ / scale_) for y0_, scale_ in zip(y0, scale))
    d1 = tuple(_norm(f0_ / scale_) for f0_, scale_ in zip(f0, scale))

    if max(d0).item() < 1e-5 or max(d1).item() < 1e-5:
        # Degenerate magnitudes: fall back to a tiny fixed first guess.
        h0 = torch.tensor(1e-6).to(t0)
    else:
        h0 = 0.01 * max(d0_ / d1_ for d0_, d1_ in zip(d0, d1))

    # Take one explicit Euler step of size h0 and use the change in the
    # derivative to estimate the second derivative (d2).
    y1 = tuple(y0_ + h0 * f0_ for y0_, f0_ in zip(y0, f0))
    f1 = fun(t0 + h0, y1)
    d2 = tuple(_norm((f1_ - f0_) / scale_) / h0 for f1_, f0_, scale_ in zip(f1, f0, scale))

    if max(d1).item() <= 1e-15 and max(d2).item() <= 1e-15:
        h1 = torch.max(torch.tensor(1e-6).to(h0), h0 * 1e-3)
    else:
        # Step size for which an order-`order` method has local error ~0.01.
        # Note `d1 + d2` concatenates the tuples, so max runs over both.
        h1 = (0.01 / max(d1 + d2))**(1. / float(order + 1))

    return torch.min(100 * h0, h1)
def _compute_error_ratio(error_estimate, error_tol=None, rtol=None, atol=None, y0=None, y1=None):
if error_tol is None:
assert rtol is not None and atol is not None and y0 is not None and y1 is not None
rtol if _is_iterable(rtol) else [rtol] * len(y0)
atol if _is_iterable(atol) else [atol] * len(y0)
error_tol = tuple(
atol_ + rtol_ * torch.max(torch.abs(y0_), torch.abs(y1_))
for atol_, rtol_, y0_, y1_ in zip(atol, rtol, y0, y1)
)
error_ratio = tuple(error_estimate_ / error_tol_ for error_estimate_, error_tol_ in zip(error_estimate, error_tol))
mean_sq_error_ratio = tuple(torch.mean(error_ratio_ * error_ratio_) for error_ratio_ in error_ratio)
return mean_sq_error_ratio
def _optimal_step_size(last_step, mean_error_ratio, safety=0.9, ifactor=10.0, dfactor=0.2, order=5):
    """Calculate the optimal size for the next step.

    `mean_error_ratio` is a tuple of per-component mean squared error/tolerance
    ratios (see _compute_error_ratio); `last_step` is the step just taken.
    `safety` damps step growth; `ifactor`/`dfactor` bound how much the step
    may grow/shrink in one update.
    """
    mean_error_ratio = max(mean_error_ratio)  # Compute step size based on highest ratio.
    if mean_error_ratio == 0:
        # Zero error: grow by the maximum allowed factor.
        return last_step * ifactor
    if mean_error_ratio < 1:
        # Error already within tolerance: disable shrinking (dfactor -> 1).
        dfactor = _convert_to_tensor(1, dtype=torch.float64, device=mean_error_ratio.device)
    error_ratio = torch.sqrt(mean_error_ratio).to(last_step)
    exponent = torch.tensor(1 / order).to(last_step)
    # Clamp the shrink factor to [1/ifactor, 1/dfactor]; dividing by it below
    # grows the step when error is small and shrinks it when error is large.
    factor = torch.max(1 / ifactor, torch.min(error_ratio**exponent / safety, 1 / dfactor))
    return last_step / factor
def _check_inputs(func, y0, t):
    """Normalize solver inputs to tuple-of-Tensor form and forward time.

    Returns (tensor_input, func, y0, t) where `tensor_input` records whether
    the caller passed a bare Tensor (so outputs can be unwrapped later).
    Raises TypeError when y0 or t is not floating point.
    """
    tensor_input = False
    if torch.is_tensor(y0):
        # Bare-Tensor problem: wrap state and func output as 1-tuples.
        tensor_input = True
        y0 = (y0,)
        _base_nontuple_func_ = func
        func = lambda t, y: (_base_nontuple_func_(t, y[0]),)
    assert isinstance(y0, tuple), 'y0 must be either a torch.Tensor or a tuple'
    for y0_ in y0:
        assert torch.is_tensor(y0_), 'each element must be a torch.Tensor but received {}'.format(type(y0_))
    if _decreasing(t):
        # Backward-time integration: negate t and the vector field so the
        # solver always sees strictly increasing time. Note this wraps the
        # (possibly already tuple-wrapped) func a second time.
        t = -t
        _base_reverse_func = func
        func = lambda t, y: tuple(-f_ for f_ in _base_reverse_func(-t, y))
    for y0_ in y0:
        if not torch.is_floating_point(y0_):
            raise TypeError('`y0` must be a floating point Tensor but is a {}'.format(y0_.type()))
    if not torch.is_floating_point(t):
        raise TypeError('`t` must be a floating point Tensor but is a {}'.format(t.type()))
    return tensor_input, func, y0, t
|