from typing import NamedTuple

import torch
import torch.utils._pytree as pytree
from torch._C._functorch import (
    _unwrap_for_grad,
    _wrap_for_grad,
    current_level,
    TransformType,
)
from torch._functorch.apis import vmap
from torch._functorch.utils import enable_single_level_autograd_function
from torch._functorch.vmap import (
    _add_batch_dim,
    _broadcast_to_and_flatten,
    restore_vmap,
    unwrap_batched,
    wrap_batched,
)
from torch._ops import HigherOrderOperator
from torch.autograd.forward_ad import _set_fwd_grad_enabled


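# NOTE: [custom_function_call dispatch]
# Sketch of the intent, inferred from the implementations below rather than
# from an authoritative docstring: custom_function_call routes an
# autograd.Function through the functorch transform stack. With no functorch
# transforms active it falls back to the ordinary autograd_function.apply(...)
# path; otherwise the per-transform rules registered below via
# @custom_function_call.py_impl(...) take over.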
class CustomFunctionHigherOrderOperator(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("custom_function_call")

    def __call__(self, autograd_function, *args, **kwargs):
        # If a functorch transform (grad, jvp, vmap, ...) is active, dispatch
        # through the HigherOrderOperator machinery so the per-transform rules
        # registered below get a chance to run.
        if torch._C._are_functorch_transforms_active():
            return super().__call__(autograd_function, *args, **kwargs)
        # Otherwise, fall back to the regular autograd.Function.apply path.
        return autograd_function.apply(*args, **kwargs)


custom_function_call = CustomFunctionHigherOrderOperator()
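# Usage sketch (illustrative only; MySquare is a made-up example class, not
# part of this module): a user-defined autograd.Function flows through
# custom_function_call whenever a functorch transform is active, e.g.
#
#     class MySquare(torch.autograd.Function):
#         generate_vmap_rule = True
#
#         @staticmethod
#         def forward(x):
#             return x * x
#
#         @staticmethod
#         def setup_context(ctx, inputs, output):
#             ctx.save_for_backward(inputs[0])
#
#         @staticmethod
#         def backward(ctx, grad_output):
#             (x,) = ctx.saved_tensors
#             return 2 * x * grad_output
#
#     # Hits the Grad rule below:
#     torch.func.grad(lambda x: MySquare.apply(x).sum())(torch.randn(3))
#     # Hits the Vmap rule below (generated vmap rule):
#     torch.func.vmap(MySquare.apply)(torch.randn(5))

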
# The same rule serves both the Grad and Jvp transforms: wrap the user's
# autograd.Function in a single-level function and apply it at this level.
@custom_function_call.py_impl(TransformType.Grad)
@custom_function_call.py_impl(TransformType.Jvp)
def custom_function_call_grad(interpreter, autograd_function, *operands):
    Generated = generate_single_level_function(interpreter, autograd_function)
    with enable_single_level_autograd_function():
        flat_out = Generated.apply(*operands)
    return flat_out


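# Sketch, inferred from the body below: builds a fresh _SingleLevelFunction
# subclass whose forward unwraps the operands out of the current grad/jvp
# level, re-enters custom_function_call underneath this interpreter, and wraps
# the results back into the current level while preserving input/output
# identity.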
def generate_single_level_function(interpreter, autograd_function):
    level = interpreter.level()

    def forward(*operands):
        unwrapped_operands = pytree.tree_map_only(
            torch.Tensor, lambda x: _unwrap_for_grad(x, level), operands
        )
        # Re-enable both backward-mode and forward-mode AD before recursing:
        # the inner call may be nested under either a grad or a jvp transform.
        with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
            unwrapped_output = custom_function_call(
                autograd_function, *unwrapped_operands
            )

        def wrap_fn(output):
            return _wrap_for_grad(output, level)

        return wrap_outputs_maintaining_identity(
            unwrapped_output, unwrapped_operands, operands, wrap_fn
        )

    def setup_context(ctx, inputs, output):
        return autograd_function.setup_context(ctx, inputs, output)

    # Used when the active transform is Grad.
    def backward(ctx, *grads):
        result = autograd_function.backward(ctx, *grads)
        return result

    # Used when the active transform is Jvp.
    def jvp(ctx, *tangents):
        result = autograd_function.jvp(ctx, *tangents)
        return result

    name = f"{autograd_function.__name__}Generated"
    Generated = type(
        name,
        (torch.autograd.function._SingleLevelFunction,),
        {
            "forward": staticmethod(forward),
            "backward": staticmethod(backward),
            "jvp": staticmethod(jvp),
            "setup_context": staticmethod(setup_context),
        },
    )
    return Generated


NO_OUT_DIMS = "not specified"


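# NOTE: [Preserving object identity when wrapping outputs]
# Sketch inferred from the implementation: if an output Tensor is literally
# one of the unwrapped inputs (e.g. the function returned an input unchanged,
# as can happen with ctx.mark_dirty), return the *original* wrapped input
# object instead of re-wrapping, so object identity is preserved across the
# transform boundary.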
def wrap_outputs_maintaining_identity(
    outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=NO_OUT_DIMS
):
    flat_unwrapped_inputs = pytree.arg_tree_leaves(*unwrapped_inputs)
    flat_orig_inputs = pytree.arg_tree_leaves(*orig_inputs)

    unwrapped_input_to_orig_input = {
        id(unwrapped): orig
        for unwrapped, orig in zip(flat_unwrapped_inputs, flat_orig_inputs)
    }

    flat_outputs, spec = pytree.tree_flatten(outputs)
    result = []

    out_dims_specified = out_dims != NO_OUT_DIMS

    if out_dims_specified:
        flat_out_dims = _broadcast_to_and_flatten(out_dims, spec)
        # _broadcast_to_and_flatten returns None when out_dims cannot be
        # broadcast to the output's pytree structure.
        if flat_out_dims is None:
            raise RuntimeError(
                f"The autograd.Function's vmap staticmethod returned an "
                f"incompatible (output, out_dims) tuple. "
                f"Expected out_dims={out_dims} "
                f"to be compatible with the structure of `output`. "
                f"out_dims has structure {pytree.tree_flatten(out_dims)[1]} "
                f"but output has structure {spec}. "
                f"For more details, please see "
                f"https://pytorch.org/docs/main/notes/extending.func.html"
            )

    for i, output in enumerate(flat_outputs):
        if not isinstance(output, torch.Tensor):
            result.append(output)
            continue
        if id(output) in unwrapped_input_to_orig_input:
            result.append(unwrapped_input_to_orig_input[id(output)])
            continue
        if out_dims_specified:
            result.append(wrap_fn(output, flat_out_dims[i]))
        else:
            result.append(wrap_fn(output))

    return pytree.tree_unflatten(result, spec)


class VmapInfo(NamedTuple):
    batch_size: int
    randomness: str


def has_overridden_vmap_rule(autograd_function):
    return autograd_function.vmap is not torch.autograd.Function.vmap


def validate_vmap_returns_tuple_of_two_elements(result):
    base_error_msg = (
        "Expected the vmap staticmethod to have two returns, an output "
        "and out_dims with pytree structure compatible with the output. "
    )
    if not isinstance(result, tuple):
        raise RuntimeError(base_error_msg + f"Got a {type(result)} instead")
    if len(result) != 2:
        raise RuntimeError(base_error_msg + f"Got {len(result)} returns instead")


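# Vmap rule for custom_function_call. Sketch inferred from the branches below:
# either the autograd.Function asked for an automatically generated rule
# (generate_vmap_rule=True), or it supplies its own vmap staticmethod; having
# both, or neither, is an error.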
@custom_function_call.py_impl(TransformType.Vmap)
def custom_function_call_vmap(interpreter, autograd_function, *operands, **kwargs):
    if any(
        isinstance(val, torch.Tensor)
        for val in torch.utils._pytree.tree_flatten(kwargs)[0]
    ):
        raise NotImplementedError(
            "Trying to vmap over an autograd.Function that takes kwarg-only "
            "Tensor arguments is not supported. Please pass all Tensors as "
            f"positional arguments to the autograd.Function. Got kwargs: {kwargs}"
        )

    if autograd_function.generate_vmap_rule:
        if has_overridden_vmap_rule(autograd_function):
            raise RuntimeError(
                f"You tried to vmap over {autograd_function.__name__}, but "
                f"it has both generate_vmap_rule=True and an overridden vmap "
                f"staticmethod. Please set generate_vmap_rule=False or delete "
                f"the overridden vmap staticmethod to avoid ambiguity. "
                f"For more details, please see "
                f"https://pytorch.org/docs/main/notes/extending.func.html"
            )
        return custom_function_call_vmap_generate_rule(
            interpreter, autograd_function, *operands
        )

    if not has_overridden_vmap_rule(autograd_function):
        raise RuntimeError(
            f"You tried to vmap over {autograd_function.__name__}, but "
            f"it does not have vmap support. Please override and implement the "
            f"vmap staticmethod or set generate_vmap_rule=True. "
            f"For more details, please see "
            f"https://pytorch.org/docs/main/notes/extending.func.html"
        )

    return custom_function_call_vmap_helper(
        interpreter, autograd_function.vmap, autograd_function, *operands, **kwargs
    )


def custom_function_call_vmap_helper(
    interpreter, vmap_function, op, *operands, **kwargs
):
    current_level = interpreter.level()
    info = VmapInfo(
        batch_size=interpreter.batch_size(),
        randomness=interpreter.randomness(),
    )

    # `op` is either an autograd.Function subclass (dispatched via
    # custom_function_call) or some other callable/operator.
    autograd_function_case = isinstance(op, torch.autograd.function.FunctionMeta)

    def lower_to_next():
        if autograd_function_case:
            return interpreter.lower()
        else:
            return torch._C._ExcludeDispatchKeyGuard(
                torch._C.DispatchKeySet(torch._C.DispatchKey.FuncTorchBatched)
            )

    unwrapped_operands, in_dims = unwrap_batched(operands, current_level)

    # If none of the operands are batched at the current level, skip the
    # user-supplied vmap rule and simply call the function underneath this
    # vmap layer.
    if pytree.tree_all(lambda dim: dim is None, in_dims):
        with lower_to_next():
            if autograd_function_case:
                return custom_function_call(op, *operands)
            else:
                return op(*operands, **kwargs)

    with lower_to_next():
        result = vmap_function(info, in_dims, *unwrapped_operands, **kwargs)
    validate_vmap_returns_tuple_of_two_elements(result)
    unwrapped_output, out_dims = result

    # Wrap outputs back into batched tensors at the current level, respecting
    # the out_dims reported by the user's vmap rule.
    def wrap_fn(output, out_dim):
        return (
            output
            if out_dim is None
            else _add_batch_dim(output, out_dim, current_level)
        )

    return wrap_outputs_maintaining_identity(
        unwrapped_output, unwrapped_operands, operands, wrap_fn, out_dims=out_dims
    )


def unpack_outputs(outputs):
    # The generated vmapped forward appends out_dims as the last element of
    # its outputs tuple; peel it back off here. If out_dims is not a tuple,
    # the original output was a single Tensor.
    out_dims = outputs[-1]
    if isinstance(out_dims, tuple):
        outputs = outputs[:-1]
    else:
        outputs = outputs[0]
    return outputs, out_dims


def custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands):
    unwrapped_operands, in_dims = unwrap_batched(operands, interpreter.level())
    vmapped_function = vmapify_autograd_function(
        autograd_function, in_dims, interpreter.batch_size(), interpreter.randomness()
    )
    with interpreter.lower():
        outputs = custom_function_call(vmapped_function, *unwrapped_operands)

    assert isinstance(outputs, tuple)
    outputs, out_dims = unpack_outputs(outputs)
    return wrap_batched(outputs, out_dims, interpreter.level())


@custom_function_call.py_impl(TransformType.Functionalize)
def custom_function_call_functionalize(
    interpreter, autograd_function, generate_vmap_rule, *operands
):
    raise RuntimeError("NYI: Functionalize rule for custom_function_call")


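# Sketch of the generated-vmap-rule strategy, inferred from the code below:
# given an autograd.Function with generate_vmap_rule=True, build a new
# autograd.Function whose forward/setup_context/backward/jvp each run the
# original staticmethods under restore_vmap. The batch dims of saved tensors,
# the input shapes, and the output dims are stashed on ctx (keyed by the
# generated class) so that backward/jvp can re-batch and reduce gradients
# correctly.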
def vmapify_autograd_function(autograd_function, in_dims, batch_size, randomness):
    def forward(*operands):
        outputs, out_dims = restore_vmap(
            autograd_function.forward, in_dims, batch_size, randomness
        )(*operands)
        if isinstance(outputs, torch.Tensor):
            return outputs, out_dims
        else:
            return *outputs, out_dims

    def setup_context(ctx, inputs, outputs):
        outputs, out_dims = unpack_outputs(outputs)
        key = id(Generated)

        def inner(inputs, outputs):
            # wrapped_ctx.save_for_backward unwraps batched tensors before
            # saving them and records their batch dims on
            # wrapped_ctx._pt_saved_tensors_bdims.
            wrapped_ctx = CtxCustomSave(ctx, current_level())
            autograd_function.setup_context(wrapped_ctx, inputs, outputs)

            # Input shapes are stashed so that backward can reduce expanded
            # gradients back to the per-example input shape (see reductify).
            input_shapes = tuple(
                inp.shape if isinstance(inp, torch.Tensor) else None for inp in inputs
            )
            if not hasattr(ctx, "_pt_input_shapes"):
                ctx._pt_input_shapes = {}
            ctx._pt_input_shapes.update({key: input_shapes})

            if not hasattr(ctx, "_pt_saved_tensors_bdims_stack"):
                ctx._pt_saved_tensors_bdims_stack = {}
            ctx._pt_saved_tensors_bdims_stack.update(
                {key: (wrapped_ctx._pt_saved_tensors_bdims)}
            )

        # Run the user's setup_context under a vmap so that it sees inputs and
        # outputs with the same batched semantics as the original forward.
        restore_vmap(
            inner,
            (in_dims, out_dims),
            batch_size,
            randomness,
        )(inputs, outputs)

        if not hasattr(ctx, "_pt_out_dims"):
            ctx._pt_out_dims = {}
        ctx._pt_out_dims.update({key: out_dims})

    def jvp(ctx, *tangents):
        key = id(Generated)

        def jvp_no_context(saved_tensors, tangents):
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.jvp(wrapped_ctx, *tangents)

        tangent_in_dims = get_tangents_in_dims(in_dims, tangents)
        out_tangents, out_tangents_dims = restore_vmap(
            jvp_no_context,
            (ctx._pt_saved_tensors_bdims_stack[key], tangent_in_dims),
            batch_size,
            randomness,
        )(ctx.saved_tensors, tangents)

        result = reductify(
            out_tangents, out_tangents_dims, ctx._pt_out_dims[key], batch_size
        )
        if isinstance(result, torch.Tensor):
            return result, None
        else:
            return *result, None

    def backward(ctx, *grad_outputs):
        key = id(Generated)
        # The last grad_output corresponds to the appended out_dims output of
        # the generated forward and is not a real gradient; drop it.
        grad_outputs_ = grad_outputs[:-1]
        grad_outputs_in_dims = ctx._pt_out_dims[key]

        if not isinstance(grad_outputs_in_dims, tuple):
            grad_outputs_in_dims = (grad_outputs_in_dims,)

        grad_outputs_in_dims = tuple(
            in_dim if grad_output is not None else None
            for grad_output, in_dim in zip(grad_outputs_, grad_outputs_in_dims)
        )

        def backward_no_context(inputs):
            saved_tensors, grad_outputs = inputs
            wrapped_ctx = CtxWithSavedTensors(ctx, saved_tensors)
            return autograd_function.backward(wrapped_ctx, *grad_outputs)

        grad_ins, grad_ins_dims = restore_vmap(
            backward_no_context,
            ((ctx._pt_saved_tensors_bdims_stack[key], grad_outputs_in_dims),),
            batch_size,
            randomness,
        )((ctx.saved_tensors, grad_outputs_))
        result = reductify(
            grad_ins, grad_ins_dims, in_dims, batch_size, ctx._pt_input_shapes[key]
        )
        return result

    name = f"Vmapped{autograd_function.__name__}"
    Generated = type(
        name,
        (torch.autograd.Function,),
        {
            "forward": staticmethod(forward),
            "backward": staticmethod(backward),
            "jvp": staticmethod(jvp),
            "setup_context": staticmethod(setup_context),
            "generate_vmap_rule": True,
        },
    )

    return Generated


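# A tangent of None means no tangent was provided for that input; treat its
# batch dim as None so restore_vmap does not try to batch it.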
def get_tangents_in_dims(input_dims, tangents):
    flat_in_dims, spec = pytree.tree_flatten(input_dims)
    flat_tangents = pytree.arg_tree_leaves(*tangents)
    result = [
        None if tangent is None else in_dim
        for in_dim, tangent in zip(flat_in_dims, flat_tangents)
    ]
    return pytree.tree_unflatten(result, spec)


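# NOTE: [Wrapping the ctx object]
# Sketch inferred from the classes below: backward/jvp of the generated
# Function need to hand the user's staticmethods a ctx that behaves like the
# real one while overriding a few attributes (saved_tensors, the
# save_for_backward/save_for_forward hooks). WrappedCtx forwards everything
# else to the wrapped ctx so user-set attributes still round-trip.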
class WrappedCtx:
    _pt_reserved_attrs: tuple[str, ...] = ("_pt_reserved_attrs", "_pt_inner_ctx")

    def __init__(self, ctx):
        if not isinstance(ctx, WrappedCtx):
            reserved_attrs = type(self)._pt_reserved_attrs
            for name in reserved_attrs:
                if not hasattr(ctx, name):
                    continue
                raise RuntimeError(
                    f"PyTorch reserves the {reserved_attrs} field on ctx. "
                    "Please name your fields on ctx something else to avoid name "
                    "collision."
                )
        self._pt_inner_ctx = ctx

    def __getattr__(self, name):
        return getattr(self._pt_inner_ctx, name)

    def __setattr__(self, name, value):
        # Reserved attributes live on the wrapper itself; everything else is
        # forwarded to the wrapped ctx.
        if name in type(self)._pt_reserved_attrs:
            self.__dict__[name] = value
            return
        return setattr(self._pt_inner_ctx, name, value)


class CtxWithSavedTensors(WrappedCtx):
    _pt_reserved_attrs = ("_pt_new_saved_tensors", *WrappedCtx._pt_reserved_attrs)

    def __init__(self, ctx, new_saved_tensors):
        super().__init__(ctx)
        self._pt_new_saved_tensors = new_saved_tensors

    @property
    def saved_tensors(self):
        return self._pt_new_saved_tensors


class CtxCustomSave(WrappedCtx):
    _pt_reserved_attrs = (
        "_pt_saved_tensors_bdims",
        "_pt_current_level",
        *WrappedCtx._pt_reserved_attrs,
    )

    def __init__(self, ctx, current_level):
        super().__init__(ctx)
        self._pt_saved_tensors_bdims = ()
        self._pt_current_level = current_level

    def save_for_backward(self, *tensors):
        # Unwrap batched tensors before handing them to the real ctx,
        # remembering which dim (if any) was batched for each saved tensor.
        unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_backward(*unwrapped_tensors)
        self._pt_saved_tensors_bdims = bdims

    def save_for_forward(self, *tensors):
        unwrapped_tensors, bdims = unwrap_batched(tensors, self._pt_current_level)
        self._pt_inner_ctx.save_for_forward(*unwrapped_tensors)
        self._pt_saved_tensors_bdims = bdims


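# NOTE: [Reducing expanded gradients]
# Sketch inferred from reductify/reductify_leaf: when an input was not batched
# but its gradient came back with a batch dim, the gradient must be summed
# over that dim; when an input was batched but the gradient came back
# unbatched, the gradient is expanded to include the batch dim. The optional
# target shape lets us sum broadcasted gradients back down to the original
# per-example input shape, since autograd cannot do that reduction for us
# across the vmap boundary.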
def reductify(
    grad_input,
    grad_input_bdim,
    input_bdim,
    batch_size,
    target_shape_without_bdim_to_reduce_to=None,
):
    if not isinstance(grad_input, tuple):
        grad_input = (grad_input,)
    if not isinstance(grad_input_bdim, tuple):
        grad_input_bdim = (grad_input_bdim,)
    if not isinstance(input_bdim, tuple):
        input_bdim = (input_bdim,)

    if target_shape_without_bdim_to_reduce_to is None:
        target_shape_without_bdim_to_reduce_to = len(grad_input) * (None,)
    result = tuple(
        reductify_leaf(gi, gi_bdim, i_bdim, batch_size, maybe_ishape)
        for gi, gi_bdim, i_bdim, maybe_ishape in zip(
            grad_input,
            grad_input_bdim,
            input_bdim,
            target_shape_without_bdim_to_reduce_to,
        )
    )
    return result


def reductify_leaf(
    grad_input,
    grad_input_bdim,
    input_bdim,
    batch_size,
    target_shape_without_bdim_to_reduce_to=None,
):
    if grad_input is None:
        return None

    if grad_input_bdim is None and input_bdim is None:
        return grad_input

    # The input was not batched but its gradient came back with a batch dim:
    # sum the gradient over that dim.
    if grad_input_bdim is not None and input_bdim is None:
        return grad_input.sum(grad_input_bdim)

    # In the remaining cases the input was batched at input_bdim.
    assert input_bdim is not None

    # If the gradient came back without a batch dim, materialize one by
    # expanding to the batch size.
    if grad_input_bdim is None:
        grad_input = grad_input.unsqueeze(input_bdim)
        new_shape = list(grad_input.shape)
        new_shape[input_bdim] = batch_size
        grad_input = grad_input.expand(new_shape)
        grad_input_bdim = input_bdim

    # If a per-example target shape was recorded, sum broadcasted gradients
    # back down to it, one batch element at a time.
    if target_shape_without_bdim_to_reduce_to is not None:
        return vmap(
            torch.Tensor.sum_to_size,
            in_dims=(grad_input_bdim, None),
            out_dims=input_bdim,
        )(grad_input, target_shape_without_bdim_to_reduce_to)

    # Otherwise just move the batch dim to where the input expects it.
    if input_bdim != grad_input_bdim:
        grad_input = grad_input.movedim(grad_input_bdim, input_bdim)
    return grad_input


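# Adapts a separate forward/setup_context pair back into the older
# forward(ctx, *args) calling convention (a small shim; its callers are not
# shown in this file).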
def autograd_function_forward_rewritten(original_forward, original_setup_context):
    def new_forward(ctx, *args, **kwargs):
        output = original_forward(*args, **kwargs)
        original_setup_context(ctx, args, output)
        return output

    return new_forward


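# Sketch, inferred from the body below (the exact callers are not shown in
# this file): a HigherOrderOperator that rebuilds an autograd.Function from
# separately provided fwd/bwd callables. fwd_kwargs carries metadata
# ("args_tensor_mask", "non_differentiable_idx") describing which of fwd_args
# are real tensor arguments and which outputs should be marked
# non-differentiable.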
class AutogradFunctionApply(HigherOrderOperator):
    def __init__(self) -> None:
        super().__init__("autograd_function_apply")

    def __call__(self, fwd, bwd, *fwd_args, **fwd_kwargs):
        saved_values = None
        args_tensor_mask = fwd_kwargs["args_tensor_mask"]
        non_differentiable_idx = fwd_kwargs["non_differentiable_idx"]
        length_of_tensor_args = sum(args_tensor_mask)
        # Only the leading tensor args (per args_tensor_mask) are passed to
        # ApplyTemplate.apply; the remaining fwd_args do not need gradients.
        new_fwd_args = fwd_args[:length_of_tensor_args]

        class ApplyTemplate(torch.autograd.Function):
            @staticmethod
            def forward(ctx, *args):
                nonlocal saved_values
                output, saved_values = fwd(None, *fwd_args)

                # Honor any outputs the original forward marked as
                # non-differentiable.
                if len(non_differentiable_idx) > 0:
                    non_differentiable_output = []
                    for i, x in enumerate(output):
                        if i in non_differentiable_idx:
                            non_differentiable_output.append(x)
                    ctx.mark_non_differentiable(*non_differentiable_output)

                return output

            @staticmethod
            def backward(ctx, *grad):
                return bwd(None, *grad, *saved_values)

        return ApplyTemplate.apply(*new_fwd_args)


autograd_function_apply = AutogradFunctionApply()