File size: 7,080 Bytes
59f1501 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 |
# This module contains functions that *will be allowed* by dynamo
"""
This module contains utility functions that are explicitly allowed to be called during
TorchDynamo compilation. These functions are carefully vetted to ensure they work
correctly within the TorchDynamo tracing and compilation process.
Key functionality groups:
- Compilation State:
Functions for checking compilation state (is_compiling)
- Function Wrapping:
Utilities for wrapping functions (wrap_inline, wrap_numpy) to work with
TorchDynamo compilation
- Autograd Hooks:
Functions and classes for handling autograd hooks and backward passes
(call_hook, FakeBackwardCFunction, etc.)
- Tensor Operations:
Utility functions for tensor operations and transformations
"""
import functools
import warnings
from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import deprecated, ParamSpec
import torch
import torch.utils._pytree as pytree
try:
import numpy as np
except ModuleNotFoundError:
np = None # type: ignore[assignment]
# Parameter specification of a wrapped callable; lets the wrapper factories in
# this module declare that the wrapper accepts the same arguments as its target.
_P = ParamSpec("_P")
# Return type of a wrapped callable, used together with `_P` above.
_R = TypeVar("_R")
if TYPE_CHECKING:
    # TorchScript does not support `@deprecated`
    # This is a workaround to avoid breaking TorchScript: the decorated
    # definition in this branch is only seen by static type checkers, because
    # `TYPE_CHECKING` is False at runtime.
    @deprecated(
        "`torch._dynamo.external_utils.is_compiling` is deprecated. Use `torch.compiler.is_compiling` instead.",
        category=FutureWarning,
    )
    def is_compiling() -> bool:
        return torch.compiler.is_compiling()

else:
    # Runtime definition: same body as above, without the `@deprecated` decorator.
    def is_compiling() -> bool:
        """
        Indicates whether we are tracing/compiling with torch.compile() or torch.export().

        Returns:
            The value of ``torch.compiler.is_compiling()``.
        """
        # NOTE: With `@torch.compile(backend="eager")`, torch._dynamo.is_compiling() will get traced
        # and return true. torch.compiler.is_compiling() is skipped and will return false.
        return torch.compiler.is_compiling()
def wrap_inline(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """
    Create an extra frame around fn that is not in skipfiles.

    Args:
        fn: The callable to wrap.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``fn`` and returns its result. ``functools.wraps`` copies ``fn``'s
        metadata onto the wrapper.
    """

    @functools.wraps(fn)
    def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        # Pure pass-through: the wrapper exists only for the additional frame.
        return fn(*args, **kwargs)

    return inner
def call_hook(
    hook: Callable[..., Optional[torch.Tensor]], *args: Any, **kwargs: Any
) -> torch.Tensor:
    """
    Call ``hook`` for compiled autograd and normalise a ``None`` result.

    ``hook`` receives the positional ``args`` only; ``kwargs`` are consulted
    solely for the ``"hook_type"`` key.

    Returns:
        The hook's result, or the first positional argument when the hook
        returned ``None``.

    Raises:
        RuntimeError: if ``hook_type`` is ``"post_acc_grad_hook"`` and the
            hook returned something other than ``None``.
    """
    hook_output = hook(*args)
    if hook_output is not None:
        # A non-None result is only an error for post-accumulate-grad hooks.
        if kwargs.get("hook_type") == "post_acc_grad_hook":
            raise RuntimeError("Tensor post accumulate grad hooks should return None.")
        return hook_output
    # The hook returned nothing: hand back its first argument instead.
    return args[0]
def wrap_numpy(f: Callable[_P, _R]) -> Callable[_P, _R]:
    r"""Adapt a function over ``np.ndarray``s so it can be called with ``torch.Tensor``s.

    Every ``torch.Tensor`` found in the positional and keyword arguments is
    converted with ``Tensor.numpy()`` before ``f`` is called, and every
    ``np.ndarray`` found in ``f``'s result is converted back with
    ``torch.as_tensor``. When NumPy could not be imported, ``f`` is returned
    unchanged.
    """
    if np is None:
        return f

    @functools.wraps(f)
    def wrap(*args: _P.args, **kwargs: _P.kwargs) -> pytree.PyTree:
        # Convert tensors anywhere inside the (args, kwargs) pytree to ndarrays.
        numpy_inputs = pytree.tree_map_only(
            torch.Tensor, lambda t: t.numpy(), (args, kwargs)
        )
        numpy_args, numpy_kwargs = numpy_inputs
        result = f(*numpy_args, **numpy_kwargs)
        # Convert ndarrays anywhere inside the result back to tensors.
        return pytree.tree_map_only(np.ndarray, lambda a: torch.as_tensor(a), result)

    return wrap
class FakeBackwardCFunction:
    """
    Proxy for a ``BackwardCFunction`` that carries an explicit list of saved tensors.

    ``saved_tensors`` is served from the list passed to the constructor; any
    attribute not set on the proxy itself is looked up on the wrapped ``real``
    object.
    """

    def __init__(
        self,
        real: torch.autograd.function.BackwardCFunction,
        saved_tensors: list[torch.Tensor],
    ) -> None:
        # The object that all unknown attribute lookups are delegated to.
        self.real = real
        # Exposed directly as `saved_tensors` (and, deprecated, `saved_variables`).
        self.saved_tensors = saved_tensors

    def __getattr__(self, name: str) -> Any:
        """
        Resolve attributes that are not set on the proxy itself.

        Python only calls this after normal lookup has failed.

        Returns:
            ``self.saved_tensors`` for the deprecated ``saved_variables``
            alias, otherwise the attribute ``name`` of ``self.real``.

        Raises:
            AttributeError: if ``name`` is ``real`` or ``saved_tensors`` (the
                instance was never initialised), or if ``self.real`` has no
                such attribute.
        """
        if name in ("real", "saved_tensors"):
            # Reaching this point means the instance attribute is missing,
            # e.g. on an object created without __init__ by copy/pickle.
            # Delegating would read `self.real` again and recurse until
            # RecursionError, so report the attribute as absent instead.
            raise AttributeError(name)
        if name == "saved_variables":
            warnings.warn(
                "'saved_variables' is deprecated; use 'saved_tensors'",
                DeprecationWarning,
                # Attribute the warning to the code that accessed the alias.
                stacklevel=2,
            )
            return self.saved_tensors
        return getattr(self.real, name)
def call_backward(
    backward_c_function: torch.autograd.function.BackwardCFunction,
    saved_tensors: list[torch.Tensor],
    *args: Any,
) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
    """
    Run ``_forward_cls.backward`` with a ``FakeBackwardCFunction`` as its context.

    Args:
        backward_c_function: Object wrapped by the fake context; its
            ``_forward_cls`` attribute supplies the ``backward`` to call.
        saved_tensors: Tensors exposed by the context as ``saved_tensors``.
        *args: Forwarded to ``backward`` after the context.

    Returns:
        The result of ``backward``; a result that is not already a tuple is
        wrapped in a one-element tuple.
    """
    ctx = FakeBackwardCFunction(backward_c_function, saved_tensors)
    backward_fn = ctx._forward_cls.backward  # type: ignore[attr-defined]
    outputs = backward_fn(ctx, *args)
    return outputs if isinstance(outputs, tuple) else (outputs,)
def normalize_as_list(x: Any) -> list[Any]:
    """
    Coerce ``x`` to a list.

    A list is returned as-is (the same object), a tuple is copied into a new
    list, and any other value is wrapped in a single-element list.
    """
    if isinstance(x, list):
        return x
    if isinstance(x, tuple):
        return list(x)
    return [x]
def untyped_storage_size(x: torch.Tensor) -> int:
    """Return ``size()`` of the untyped storage backing ``x``."""
    storage = x.untyped_storage()
    return storage.size()
class FakeCompiledAutogradEngine:
    """
    Static helpers that queue and run callbacks held in a caller-owned list.
    """

    @staticmethod
    def queue_callback(
        final_callbacks: list[Callable[[], None]], cb: Callable[[], None]
    ) -> None:
        """Append ``cb`` to the end of ``final_callbacks``."""
        final_callbacks.append(cb)

    @staticmethod
    def exec_final_callbacks(final_callbacks: list[Callable[[], None]]) -> None:
        """Call every queued callback in order, then empty the list in place."""
        # The length is re-read on every iteration, so callbacks appended
        # while earlier ones run are executed in the same pass.
        position = 0
        while position < len(final_callbacks):
            final_callbacks[position]()
            position += 1
        final_callbacks.clear()

    @staticmethod
    def _exec_final_callbacks_stub() -> None:
        """Do nothing and return ``None``."""
        return None
def call_hook_from_backward_state(
    *args: Any, bw_state: Any, hook_name: str, **kwargs: Any
) -> Any:
    """
    Look up the attribute ``hook_name`` on ``bw_state`` and call it.

    All positional arguments and the remaining keyword arguments are
    forwarded to the hook, and its result is returned unchanged.
    """
    hook = getattr(bw_state, hook_name)
    return hook(*args, **kwargs)
def call_module_hooks_from_backward_state(
    _: Any, result: Any, *args: Any, bw_state: Any, hooks_name: str, module_name: str
) -> Any:
    """
    Run the hooks stored on ``bw_state`` in order, threading ``result`` through them.

    The module and the iterable of hooks are read from ``bw_state`` via the
    attribute names ``module_name`` and ``hooks_name``. Each hook is called as
    ``hook(module, current_result, *args)``; a return value other than
    ``None`` replaces the current result. The first positional argument is
    ignored.

    Returns:
        The result after all hooks have run.
    """
    module = getattr(bw_state, module_name)
    current = result
    for hook in getattr(bw_state, hooks_name):
        replacement = hook(module, current, *args)
        if replacement is None:
            # The hook left the result untouched.
            continue
        current = replacement
    return current
# used for torch._dynamo.disable(recursive=False)
def get_nonrecursive_disable_wrapper(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    """
    Return a pass-through wrapper around ``fn``.

    Args:
        fn: The callable to wrap.

    Returns:
        A function that forwards all positional and keyword arguments to
        ``fn`` and returns its result. ``functools.wraps`` copies ``fn``'s
        metadata onto the wrapper.
    """
    # wrap function to get the right error message
    # this function is in external_utils so that convert_frame doesn't skip it.
    @functools.wraps(fn)
    def nonrecursive_disable_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        return fn(*args, **kwargs)

    return nonrecursive_disable_wrapper
def _dynamo_config_patch_proxy_dunder_call(
    self: Any, func: Callable[_P, _R]
) -> Callable[_P, _R]:
    """
    Return a wrapper that runs ``func`` inside a ``with self:`` block.

    Args:
        self: A context manager; it is entered on every call of the returned
            wrapper. NOTE(review): presumably a dynamo config-patch object,
            judging by this function's name - confirm against the caller.
        func: The callable to run while ``self`` is active.

    Returns:
        A function with ``func``'s metadata (via ``functools.wraps``) that
        forwards its arguments to ``func`` and returns its result.
    """

    @functools.wraps(func)
    def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        # `self` is entered for each call and exited when `func` returns or raises.
        with self:
            return func(*args, **kwargs)

    return inner
# Use only on ints marked dynamic via torch.empty(0, integer)
# Currently only way to mark ints as dynamic: https://github.com/pytorch/pytorch/issues/129623
def unwrap_maybe_dynamic_int(x: Union[torch.Tensor, int]) -> int:
    """
    Recover an int that may have been encoded as a tensor dimension.

    Returns:
        ``x`` itself when it is not a ``torch.Tensor``; otherwise the size of
        the tensor's dimension 1.
    """
    if not isinstance(x, torch.Tensor):
        return x
    # x.size() is expected to be [0, dynamic_int]
    return x.size(1)
def call_accumulate_grad(
    variable: torch.Tensor, grad: torch.Tensor, has_post_hooks: bool
) -> None:
    """
    Update ``variable.grad`` using the compiled autograd ``AccumulateGrad`` op.

    The op is called with ``[grad]``, ``variable``, the current
    ``variable.grad`` and ``has_post_hooks``; the first element of its result
    is stored back into ``variable.grad``.
    """
    accumulate = torch._dynamo.compiled_autograd.ops.AccumulateGrad  # type: ignore[attr-defined]
    new_grads = accumulate([grad], variable, variable.grad, has_post_hooks)
    variable.grad = new_grads[0]
|