File size: 1,092 Bytes
59f1501 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import contextlib
from collections.abc import Generator
from typing import Any, Union
import torch
from torch._C._functorch import (
get_single_level_autograd_function_allowed,
set_single_level_autograd_function_allowed,
unwrap_if_dead,
)
from torch.utils._exposed_in import exposed_in
# Public names of this module, i.e. what ``from <module> import *`` exports.
# ``exposed_in`` is not defined here; it is re-exported from
# ``torch.utils._exposed_in`` (imported above).
__all__ = [
    "exposed_in",
    "argnums_t",
    "enable_single_level_autograd_function",
    "unwrap_dead_wrappers",
]
@contextlib.contextmanager
def enable_single_level_autograd_function() -> Generator[None, None, None]:
    """Temporarily set the single-level autograd function flag to ``True``.

    On entry, the current flag value is read with
    ``get_single_level_autograd_function_allowed`` and the flag is set to
    ``True``. On exit the previously read value is written back, whether
    the ``with`` body finished normally or raised.

    Yields:
        None. The context manager is used purely for its side effect.
    """
    # Read the previous state *before* entering the ``try``: if this call
    # fails, the error propagates unchanged instead of being masked by an
    # UnboundLocalError on ``prev_state`` in the ``finally`` clause.
    prev_state = get_single_level_autograd_function_allowed()
    try:
        set_single_level_autograd_function_allowed(True)
        yield
    finally:
        # Restore whatever value was active before the context was entered.
        set_single_level_autograd_function_allowed(prev_state)
def unwrap_dead_wrappers(args: tuple[Any, ...]) -> tuple[Any, ...]:
    """Apply ``unwrap_if_dead`` to every top-level tensor in ``args``.

    Only the direct elements of ``args`` are inspected; nested containers
    are not traversed. Elements that are not ``torch.Tensor`` instances are
    passed through untouched.

    Args:
        args: The positional values to process.

    Returns:
        A new tuple of the same length and order as ``args``, where each
        ``torch.Tensor`` element has been replaced by the result of
        ``unwrap_if_dead`` on it.
    """
    # NB: deliberately a flat pass rather than tree_map_only, for
    # performance reasons.
    processed = []
    for item in args:
        if isinstance(item, torch.Tensor):
            item = unwrap_if_dead(item)
        processed.append(item)
    return tuple(processed)
# Type alias: either a single int or a tuple of ints.
# NOTE(review): going by the name, these are presumably argument position(s)
# -- confirm against the call sites that use this alias.
argnums_t = Union[int, tuple[int, ...]]
|