# BrainFM/ShapeID/DiffEqs/adjoint.py
# peirong26's picture
# Upload 187 files
# 2571f24 verified
import torch
import torch.nn as nn
from ShapeID.DiffEqs.odeint import odeint
from ShapeID.DiffEqs.misc import _flatten, _flatten_convert_none_to_zeros
class OdeintAdjointMethod(torch.autograd.Function):
    """Autograd function that solves an ODE forward and back-propagates via an adjoint solve.

    ``forward`` runs ``odeint`` under ``torch.no_grad()``; ``backward`` rebuilds
    the gradients by integrating an augmented system (state, adjoint of the
    state, adjoint of time, adjoint of the parameters) backwards between
    consecutive time points.

    Call as ``apply(*y0, func, t, dt, flat_params, rtol, atol, method, options)``.
    """

    @staticmethod
    def forward(ctx, *args):
        """Solve the ODE without recording an autograd graph.

        Args:
            ctx: autograd context; stores ``func`` and the solver settings as
                attributes and ``t``, ``flat_params`` and the solution tensors
                via ``save_for_backward``.
            *args: ``(*y0, func, t, dt, flat_params, rtol, atol, method, options)``.
                The trailing eight entries are fixed; everything before them is
                treated as the tuple of initial-state tensors ``y0``.

        Returns:
            Whatever ``odeint`` returns for the tuple state ``y0`` (unpacked
            into ``save_for_backward``, so it must be an iterable of tensors).
        """
        assert len(args) >= 8, 'Internal error: all arguments required.'
        y0, func, t, dt, flat_params, rtol, atol, method, options = \
            args[:-8], args[-8], args[-7], args[-6], args[-5], args[-4], args[-3], args[-2], args[-1]

        ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options = func, rtol, atol, method, options

        with torch.no_grad():
            ans = odeint(func, y0, t, dt, rtol=rtol, atol=atol, method=method, options=options)
        ctx.save_for_backward(t, flat_params, *ans)
        return ans

    @staticmethod
    def backward(ctx, *grad_output):
        """Compute gradients by integrating the augmented adjoint system backwards.

        Args:
            ctx: context populated by ``forward``.
            *grad_output: one gradient tensor per solution tensor, indexed along
                dimension 0 by time point in the same way as the solution.

        Returns:
            A tuple with exactly one entry per ``forward`` input, in the same
            order: ``(*adj_y, None, time_vjps, None, adj_params, None, None,
            None, None)`` for ``(*y0, func, t, dt, flat_params, rtol, atol,
            method, options)``.
        """
        t, flat_params, *ans = ctx.saved_tensors
        ans = tuple(ans)
        func, rtol, atol, method, options = ctx.func, ctx.rtol, ctx.atol, ctx.method, ctx.options
        n_tensors = len(ans)
        f_params = tuple(func.parameters())

        # TODO: use a nn.Module and call odeint_adjoint to implement higher order derivatives.
        def augmented_dynamics(t, y_aug):
            # Dynamics of the original system augmented with
            # the adjoint wrt y, and an integrator wrt t and args.
            y, adj_y = y_aug[:n_tensors], y_aug[n_tensors:2 * n_tensors]  # Ignore adj_time and adj_params.

            # The outer loop runs under no_grad, so gradient tracking has to be
            # re-enabled locally to build the graph used for the VJPs.
            with torch.set_grad_enabled(True):
                t = t.to(y[0].device).detach().requires_grad_(True)
                y = tuple(y_.detach().requires_grad_(True) for y_ in y)
                func_eval = func(t, y)
                # Vector-Jacobian products of func w.r.t. (t, y, params) with -adj_y.
                vjp_t, *vjp_y_and_params = torch.autograd.grad(
                    func_eval, (t,) + y + f_params,
                    tuple(-adj_y_ for adj_y_ in adj_y), allow_unused=True, retain_graph=True
                )
            vjp_y = vjp_y_and_params[:n_tensors]
            vjp_params = vjp_y_and_params[n_tensors:]

            # autograd.grad returns None if no gradient, set to zero.
            vjp_t = torch.zeros_like(t) if vjp_t is None else vjp_t
            vjp_y = tuple(torch.zeros_like(y_) if vjp_y_ is None else vjp_y_ for vjp_y_, y_ in zip(vjp_y, y))
            vjp_params = _flatten_convert_none_to_zeros(vjp_params, f_params)

            # Keep a scalar placeholder so the augmented state always has a
            # parameter slot, even for a parameter-free func.
            if len(f_params) == 0:
                vjp_params = torch.tensor(0.).to(vjp_y[0])
            return (*func_eval, *vjp_y, vjp_t, vjp_params)

        # Number of time points, read from the leading dimension of the solution.
        T = ans[0].shape[0]

        with torch.no_grad():
            # Start from the gradient at the final time point.
            adj_y = tuple(grad_output_[-1] for grad_output_ in grad_output)
            adj_params = torch.zeros_like(flat_params)
            adj_time = torch.tensor(0.).to(t)
            time_vjps = []
            for i in range(T - 1, 0, -1):

                ans_i = tuple(ans_[i] for ans_ in ans)
                grad_output_i = tuple(grad_output_[i] for grad_output_ in grad_output)
                func_i = func(t[i], ans_i)

                # Compute the effect of moving the current time measurement point.
                dLd_cur_t = sum(
                    torch.dot(func_i_.reshape(-1), grad_output_i_.reshape(-1)).reshape(1)
                    for func_i_, grad_output_i_ in zip(func_i, grad_output_i)
                )
                adj_time = adj_time - dLd_cur_t
                time_vjps.append(dLd_cur_t)

                # Run the augmented system backwards in time.
                if adj_params.numel() == 0:
                    adj_params = torch.tensor(0.).to(adj_y[0])
                aug_y0 = (*ans_i, *adj_y, adj_time, adj_params)
                # NOTE(review): unlike the forward call, `dt` is not passed to
                # odeint here -- confirm odeint's signature allows omitting it.
                aug_ans = odeint(
                    augmented_dynamics, aug_y0,
                    torch.tensor([t[i], t[i - 1]]), rtol=rtol, atol=atol, method=method, options=options
                )

                # Unpack aug_ans.
                adj_y = aug_ans[n_tensors:2 * n_tensors]
                adj_time = aug_ans[2 * n_tensors]
                adj_params = aug_ans[2 * n_tensors + 1]

                # Index 1 is the value at t[i - 1], the end of the backward segment.
                adj_y = tuple(adj_y_[1] if len(adj_y_) > 0 else adj_y_ for adj_y_ in adj_y)
                if len(adj_time) > 0:
                    adj_time = adj_time[1]
                if len(adj_params) > 0:
                    adj_params = adj_params[1]

                # Add the direct gradient contribution of the previous time point.
                adj_y = tuple(adj_y_ + grad_output_[i - 1] for adj_y_, grad_output_ in zip(adj_y, grad_output))

                del aug_y0, aug_ans

            time_vjps.append(adj_time)
            time_vjps = torch.cat(time_vjps[::-1])

            # One gradient per forward input, in forward's argument order:
            #   *y0 -> adj_y, func -> None, t -> time_vjps, dt -> None,
            #   flat_params -> adj_params, rtol/atol/method/options -> None.
            # adj_params must sit in the flat_params slot (after dt's None),
            # otherwise the parameter gradient is routed to dt and lost.
            return (*adj_y, None, time_vjps, None, adj_params, None, None, None, None)
def odeint_adjoint(func, y0, t, dt, rtol=1e-6, atol=1e-12, method=None, options=None):
    """Solve an ODE through ``OdeintAdjointMethod`` so gradients use the adjoint solve.

    Args:
        func: the dynamics; must be an ``nn.Module`` because its parameters
            are collected through ``func.parameters()``.
        y0: initial state, either a single tensor or an iterable of tensors.
        t: time points handed to ``OdeintAdjointMethod``.
        dt: step argument forwarded unchanged to ``OdeintAdjointMethod``.
        rtol, atol, method, options: solver settings forwarded unchanged.

    Returns:
        The result of ``OdeintAdjointMethod.apply``; when ``y0`` was a single
        tensor, only the first element of that result is returned.

    Raises:
        ValueError: if ``func`` is not an ``nn.Module``.
    """
    # A module is mandatory: it is the only handle available for reaching the
    # variables used along the execution path.
    if not isinstance(func, nn.Module):
        raise ValueError('func is required to be an instance of nn.Module.')

    single_tensor_state = torch.is_tensor(y0)
    if single_tensor_state:

        class TupleFunc(nn.Module):
            """Adapter exposing a tensor-valued func through the tuple interface."""

            def __init__(self, base_func):
                super().__init__()
                self.base_func = base_func

            def forward(self, t, y):
                return (self.base_func(t, y[0]),)

        # Promote both the state and the dynamics to their 1-tuple form.
        func, y0 = TupleFunc(func), (y0,)

    solution = OdeintAdjointMethod.apply(
        *y0, func, t, dt, _flatten(func.parameters()), rtol, atol, method, options
    )
    return solution[0] if single_tensor_state else solution