"""Re-exports of functorch-style function transforms (grad, vmap, jacobians, etc.)."""
from torch._functorch.apis import grad, grad_and_value, vmap
from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
from torch._functorch.eager_transforms import (
debug_unwrap,
functionalize,
hessian,
jacfwd,
jacrev,
jvp,
linearize,
vjp,
)
from torch._functorch.functional_call import functional_call, stack_module_state
__all__ = [
"grad",
"grad_and_value",
"vmap",
"replace_all_batch_norm_modules_",
"functionalize",
"hessian",
"jacfwd",
"jacrev",
"jvp",
"linearize",
"vjp",
"functional_call",
"stack_module_state",
"debug_unwrap",
]