|
|
""" |
|
|
This module contains utility functions that are explicitly allowed to be called during |
|
|
TorchDynamo compilation. These functions are carefully vetted to ensure they work |
|
|
correctly within the TorchDynamo tracing and compilation process. |
|
|
|
|
|
Key functionality groups: |
|
|
|
|
|
- Compilation State: |
|
|
Functions for checking compilation state (is_compiling) |
|
|
|
|
|
- Function Wrapping: |
|
|
Utilities for wrapping functions (wrap_inline, wrap_numpy) to work with |
|
|
TorchDynamo compilation |
|
|
|
|
|
- Autograd Hooks: |
|
|
Functions and classes for handling autograd hooks and backward passes |
|
|
(call_hook, FakeBackwardCFunction, etc.) |
|
|
|
|
|
- Tensor Operations: |
|
|
Utility functions for tensor operations and transformations |
|
|
""" |
|
|
|
|
|
import functools |
|
|
import warnings |
|
|
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union |
|
|
from typing_extensions import deprecated, ParamSpec |
|
|
|
|
|
import torch |
|
|
import torch.utils._pytree as pytree |
|
|
|
|
|
|
|
|
try: |
|
|
import numpy as np |
|
|
except ModuleNotFoundError: |
|
|
np = None |
|
|
|
|
|
_P = ParamSpec("_P") |
|
|
_R = TypeVar("_R") |
|
|
|
|
|
if TYPE_CHECKING: |
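    # Type checkers see the @deprecated alias below; at runtime the plain
    # definition in the else-branch is used instead.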
|
|
|
|
|
|
|
|
@deprecated( |
|
|
"`torch._dynamo.external_utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.", |
|
|
category=FutureWarning, |
|
|
) |
|
|
def is_compiling() -> bool: |
|
|
return torch.compiler.is_compiling() |
|
|
|
|
|
else: |
|
|
|
|
|
def is_compiling() -> bool: |
|
|
""" |
|
|
Indicates whether we are tracing/compiling with torch.compile() or torch.export(). |
|
|
""" |
|
|
|
|
|
|
|
|
return torch.compiler.is_compiling() |
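# Illustrative usage (sketch): branch on compilation state so the traced and eager
# paths can differ.
#
#     def f(x):
#         if torch.compiler.is_compiling():
#             return x.sin()
#         return x.cos()
#
#     f(torch.ones(3))                 # eager: takes the .cos() branch
#     torch.compile(f)(torch.ones(3))  # compiled: takes the .sin() branch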
|
|
|
|
|
|
|
|
def wrap_inline(fn: Callable[_P, _R]) -> Callable[_P, _R]: |
|
|
""" |
|
|
Create an extra frame around fn that is not in skipfiles. |
|
|
""" |
|
|
|
|
|
@functools.wraps(fn) |
|
|
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R: |
|
|
return fn(*args, **kwargs) |
|
|
|
|
|
return inner |
|
|
|
|
|
|
|
|
def call_hook( |
|
|
hook: Callable[..., Optional[torch.Tensor]], *args: Any, **kwargs: Any |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
    Used by compiled autograd to handle hooks that return None.
|
|
""" |
|
|
result = hook(*args) |
|
|
if result is None: |
|
|
return args[0] |
|
|
elif kwargs.get("hook_type") == "post_acc_grad_hook": |
|
|
raise RuntimeError("Tensor post accumulate grad hooks should return None.") |
|
|
return result |
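# Illustrative behavior (sketch): a hook that returns None passes the original
# first argument through unchanged; otherwise the hook's result is used.
#
#     grad = torch.ones(3)
#     call_hook(lambda g: None, grad)   # returns grad unchanged
#     call_hook(lambda g: g * 2, grad)  # returns the doubled tensor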
|
|
|
|
|
|
|
|
def wrap_numpy(f: Callable[_P, _R]) -> Callable[_P, _R]: |
|
|
r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function |
|
|
from ``torch.Tensor``s to ``torch.Tensor``s. |
|
|
""" |
|
|
if not np: |
|
|
return f |
|
|
|
|
|
@functools.wraps(f) |
|
|
def wrap(*args: _P.args, **kwargs: _P.kwargs) -> pytree.PyTree: |
|
|
args, kwargs = pytree.tree_map_only( |
|
|
torch.Tensor, lambda x: x.numpy(), (args, kwargs) |
|
|
) |
|
|
out = f(*args, **kwargs) |
|
|
return pytree.tree_map_only(np.ndarray, lambda x: torch.as_tensor(x), out) |
|
|
|
|
|
return wrap |
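# Illustrative usage (a sketch, assuming NumPy is installed; otherwise wrap_numpy
# is a no-op): the decorated function accepts and returns torch.Tensors even
# though its body operates on np.ndarrays.
#
#     @wrap_numpy
#     def numpy_mul(a, b):
#         return np.multiply(a, b)
#
#     out = numpy_mul(torch.ones(3), torch.full((3,), 2.0))  # tensor([2., 2., 2.])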
|
|
|
|
|
|
|
|
class FakeBackwardCFunction: |
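    """
    Stand-in for ``torch.autograd.function.BackwardCFunction`` that carries an
    explicit list of saved tensors and forwards any other attribute access to the
    wrapped real ctx object.
    """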
|
|
def __init__( |
|
|
self, |
|
|
real: torch.autograd.function.BackwardCFunction, |
|
|
saved_tensors: list[torch.Tensor], |
|
|
) -> None: |
|
|
self.real = real |
|
|
self.saved_tensors = saved_tensors |
|
|
|
|
|
def __getattr__(self, name: str) -> Any: |
|
|
if name == "saved_variables": |
|
|
warnings.warn( |
|
|
"'saved_variables' is deprecated; use 'saved_tensors'", |
|
|
DeprecationWarning, |
|
|
) |
|
|
return self.saved_tensors |
|
|
|
|
|
return getattr(self.real, name) |
|
|
|
|
|
|
|
|
def call_backward( |
|
|
backward_c_function: torch.autograd.function.BackwardCFunction, |
|
|
saved_tensors: list[torch.Tensor], |
|
|
*args: Any, |
|
|
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]: |
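    """
    Invoke the original ``autograd.Function``'s backward through a FakeBackwardCFunction
    holding the given saved tensors; the result is always normalized to a tuple.
    """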
|
|
fake = FakeBackwardCFunction(backward_c_function, saved_tensors) |
|
|
grads = fake._forward_cls.backward(fake, *args) |
|
|
|
|
|
if not isinstance(grads, tuple): |
|
|
grads = (grads,) |
|
|
|
|
|
return grads |
|
|
|
|
|
|
|
|
def normalize_as_list(x: Any) -> list[Any]: |
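    """Return ``x`` as a list: convert tuples, pass lists through unchanged, and wrap any other value."""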
|
|
if isinstance(x, tuple): |
|
|
return list(x) |
|
|
elif isinstance(x, list): |
|
|
return x |
|
|
return [x] |
|
|
|
|
|
|
|
|
def untyped_storage_size(x: torch.Tensor) -> int: |
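    """Return the size (in bytes) of ``x``'s untyped storage."""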
|
|
return x.untyped_storage().size() |
|
|
|
|
|
|
|
|
class FakeCompiledAutogradEngine: |
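    """
    Stand-in for the autograd engine's final-callback queue used by compiled
    autograd: callbacks are collected in a plain list and executed in order
    (including any callbacks queued while running) once the backward completes.
    """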
|
|
@staticmethod |
|
|
def queue_callback( |
|
|
final_callbacks: list[Callable[[], None]], cb: Callable[[], None] |
|
|
) -> None: |
|
|
final_callbacks.append(cb) |
|
|
|
|
|
@staticmethod |
|
|
def exec_final_callbacks(final_callbacks: list[Callable[[], None]]) -> None: |
|
|
i = 0 |
|
|
while i < len(final_callbacks): |
|
|
cb = final_callbacks[i] |
|
|
cb() |
|
|
i += 1 |
|
|
final_callbacks.clear() |
|
|
|
|
|
@staticmethod |
|
|
def _exec_final_callbacks_stub() -> None: |
|
|
pass |
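# Illustrative usage (sketch): callbacks queued during the compiled backward are
# collected in a plain list and executed together at the end.
#
#     final_callbacks: list = []
#     FakeCompiledAutogradEngine.queue_callback(final_callbacks, lambda: print("done"))
#     FakeCompiledAutogradEngine.exec_final_callbacks(final_callbacks)  # prints "done"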
|
|
|
|
|
|
|
|
def call_hook_from_backward_state( |
|
|
*args: Any, bw_state: Any, hook_name: str, **kwargs: Any |
|
|
) -> Any: |
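    """Look up the hook named ``hook_name`` on ``bw_state`` and call it with the given arguments."""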
|
|
return getattr(bw_state, hook_name)(*args, **kwargs) |
|
|
|
|
|
|
|
|
def call_module_hooks_from_backward_state( |
|
|
_: Any, result: Any, *args: Any, bw_state: Any, hooks_name: str, module_name: str |
|
|
) -> Any: |
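    """
    Run the hooks stored as ``hooks_name`` on ``bw_state`` against ``result``; a hook
    that returns a non-None value replaces ``result`` for the hooks that follow.
    """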
|
|
module = getattr(bw_state, module_name) |
|
|
hooks = getattr(bw_state, hooks_name) |
|
|
for hook in hooks: |
|
|
new_result = hook(module, result, *args) |
|
|
if new_result is not None: |
|
|
result = new_result |
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
def get_nonrecursive_disable_wrapper(fn: Callable[_P, _R]) -> Callable[_P, _R]: |
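    """
    Create an extra wrapper frame around ``fn``; used for the non-recursive form of
    ``torch._dynamo.disable``, which needs a distinct frame to skip while still
    allowing the frames that ``fn`` calls into to be traced.
    """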
|
|
|
|
|
|
|
|
@functools.wraps(fn) |
|
|
def nonrecursive_disable_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: |
|
|
return fn(*args, **kwargs) |
|
|
|
|
|
return nonrecursive_disable_wrapper |
|
|
|
|
|
|
|
|
def wrap_dunder_call_ctx_manager(self: Any, func: Callable[_P, _R]) -> Callable[_P, _R]: |
|
|
""" |
|
|
    Apply ``self`` as a context manager around a call to ``func``.
|
|
""" |
|
|
|
|
|
|
|
|
def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R: |
|
|
with self: |
|
|
return func(*args, **kwargs) |
|
|
|
|
|
return inner |
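# Illustrative usage (sketch): enter a context-manager instance around every call
# to the wrapped function.
#
#     quiet_sin = wrap_dunder_call_ctx_manager(torch.no_grad(), torch.sin)
#     y = quiet_sin(torch.ones(3, requires_grad=True))  # y.requires_grad is False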
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def unwrap_maybe_dynamic_int(x: Union[torch.Tensor, int]) -> int: |
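    """
    Return ``x`` unchanged if it is already an int; otherwise recover the int from
    the tensor's dim-1 size (the value is assumed to be encoded there by the caller).
    """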
|
|
if isinstance(x, torch.Tensor): |
|
|
|
|
|
return x.size(1) |
|
|
return x |
|
|
|
|
|
|
|
|
def call_accumulate_grad( |
|
|
variable: torch.Tensor, grad: torch.Tensor, has_post_hooks: bool |
|
|
) -> None: |
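    """Accumulate ``grad`` into ``variable.grad`` via the compiled autograd ``AccumulateGrad`` op."""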
|
|
updated_grad = torch._dynamo.compiled_autograd.ops.AccumulateGrad( |
|
|
[grad], variable, variable.grad, has_post_hooks |
|
|
) |
|
|
variable.grad = updated_grad[0] |
|
|
|
|
|
|
|
|
def wrap_inline_with_error_on_graph_break( |
|
|
fn: Callable[_P, _R], error_on_graph_break: bool |
|
|
) -> Callable[_P, _R]: |
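    """
    Like ``wrap_inline``, but run ``fn`` with ``torch._dynamo.error_on_graph_break``
    set to the given value.
    """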
|
|
|
|
|
|
|
|
|
|
|
if error_on_graph_break: |
|
|
|
|
|
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: |
|
|
with torch._dynamo.error_on_graph_break(True): |
|
|
return fn(*args, **kwargs) |
|
|
|
|
|
else: |
|
|
|
|
|
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: |
|
|
with torch._dynamo.error_on_graph_break(False): |
|
|
return fn(*args, **kwargs) |
|
|
|
|
|
return wrapper |
|
|
|
|
|
|
|
|
def filter_out_const_values(tup: tuple[Any, ...], masks: list[bool]) -> tuple[Any, ...]: |
|
|
""" |
|
|
masks is a list of bools, where True means the corresponding element in tup |
|
|
is a const value. Filter out the const values. |
|
|
""" |
|
|
out = [] |
|
|
for mask_idx, mask in enumerate(masks): |
|
|
if not mask: |
|
|
out.append(tup[mask_idx]) |
|
|
return tuple(out) |
|
|
|
|
|
|
|
|
def insert_const_values_with_mask( |
|
|
tup: tuple[Any, ...], masks: list[bool], values: tuple[Any, ...] |
|
|
) -> tuple[Any, ...]: |
|
|
""" |
|
|
    masks and values are of the same length. For indices where the mask is True,
    fill in from ``values``; otherwise take the next element of ``tup``.
|
|
""" |
|
|
out = [] |
|
|
idx = 0 |
|
|
for mask_idx, mask in enumerate(masks): |
|
|
if mask: |
|
|
out.append(values[mask_idx]) |
|
|
else: |
|
|
out.append(tup[idx]) |
|
|
idx += 1 |
|
|
return tuple(out) |
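# Illustrative round trip (sketch): filter_out_const_values drops the entries whose
# mask is True, and insert_const_values_with_mask restores them from `values`.
#
#     masks = [True, False, True, False]
#     consts = ("a", None, "c", None)   # only the True positions are read
#     filtered = filter_out_const_values(("a", 1, "c", 2), masks)        # (1, 2)
#     restored = insert_const_values_with_mask(filtered, masks, consts)  # ("a", 1, "c", 2)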
|
|
|