|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import functools
|
|
|
import inspect
|
|
|
import logging
|
|
|
import operator
|
|
|
import traceback
|
|
|
import typing
|
|
|
import typing_extensions
|
|
|
import weakref
|
|
|
from collections import defaultdict, OrderedDict
|
|
|
from collections.abc import Generator, Mapping, Sequence
|
|
|
from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext
|
|
|
from dataclasses import dataclass
|
|
|
from typing import (
|
|
|
Any,
|
|
|
Callable,
|
|
|
Optional,
|
|
|
overload,
|
|
|
Protocol,
|
|
|
TYPE_CHECKING,
|
|
|
TypeVar,
|
|
|
Union,
|
|
|
)
|
|
|
from typing_extensions import Concatenate, ParamSpec, Self, TypeVarTuple, Unpack
|
|
|
from weakref import WeakKeyDictionary
|
|
|
|
|
|
import torch
|
|
|
import torch._ops
|
|
|
import torch.fx as fx
|
|
|
import torch.fx.traceback as fx_traceback
|
|
|
import torch.utils._pytree as pytree
|
|
|
from torch import SymBool, SymInt, Tensor
|
|
|
from torch._dispatch.python import enable_python_dispatcher
|
|
|
from torch._library.fake_class_registry import FakeScriptObject
|
|
|
from torch._logging import trace_structured
|
|
|
from torch._subclasses.fake_impls import fast_detach
|
|
|
from torch._subclasses.fake_tensor import (
|
|
|
FakeTensor,
|
|
|
FakeTensorMode,
|
|
|
is_fake,
|
|
|
unset_fake_temporarily,
|
|
|
)
|
|
|
from torch._subclasses.meta_utils import is_sparse_any
|
|
|
from torch.fx import GraphModule, Proxy, Tracer
|
|
|
from torch.fx.graph_module import _assign_attr
|
|
|
from torch.fx.node import (
|
|
|
_side_effectful_need_to_be_preserved_pre_dispatch,
|
|
|
Argument,
|
|
|
Target,
|
|
|
)
|
|
|
from torch.fx.passes.shape_prop import _extract_tensor_metadata
|
|
|
from torch.nn import Module
|
|
|
from torch.overrides import TorchFunctionMode
|
|
|
from torch.utils._python_dispatch import (
|
|
|
_disable_infra_mode,
|
|
|
_push_mode,
|
|
|
_unset_infra_mode,
|
|
|
TorchDispatchMode,
|
|
|
)
|
|
|
from torch.utils._stats import count
|
|
|
from torch.utils._thunk import Thunk
|
|
|
from torch.utils._traceback import CapturedTraceback
|
|
|
from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary, WeakTensorKeyDictionary
|
|
|
|
|
|
from ._backward_state import BackwardState
|
|
|
from .sym_node import SymNode
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
import types
|
|
|
from collections.abc import MutableMapping
|
|
|
|
|
|
import sympy
|
|
|
|
|
|
from torch._ops import OpOverload
|
|
|
from torch.fx._symbolic_trace import PHBase
|
|
|
from torch.types import IntLikeType
|
|
|
|
|
|
# Public surface of this module; everything else is implementation detail.
__all__ = [
    "PythonKeyTracer",
    "dispatch_trace",
    "make_fx",
    "DecompositionInterpreter",
    "py_sym_types",
    "get_innermost_proxy_mode",
    "get_proxy_mode",
    "handle_sym_dispatch",
    "maybe_enable_thunkify",
    "maybe_disable_thunkify",
]


# Either a full tracer (PythonKeyTracer) or the thin graph-appending tracer.
_ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"]


# Real and fake torchbind script objects are treated uniformly throughout.
_AnyScriptObject = (torch.ScriptObject, FakeScriptObject)
_AnyScriptObjectType = Union[torch.ScriptObject, FakeScriptObject]


aten = torch.ops.aten
prim = torch.ops.prim


log = logging.getLogger(__name__)
not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented")


# Active decomposition table; swapped in/out by the `decompose` context manager.
CURRENT_DECOMPOSITION_TABLE: Mapping[OpOverload, Callable] = {}


# Maximum numel up to which tensors participate in constant propagation
# during tracing (see proxy_call).
CONSTANT_NUMEL_LIMIT = 1


T = TypeVar("T")
U = TypeVar("U")
_P = ParamSpec("_P")
R = TypeVar("R")
_Ts = TypeVarTuple("_Ts")


null_ctx_type = type(nullcontext)


# Register torch.Size with pytree so it can cross the tracing boundary.
# NB: unflatten deliberately produces a plain tuple, not a torch.Size.
pytree.register_pytree_node(
    torch.Size,
    lambda xs: (list(xs), None),
    lambda xs, _: tuple(xs),
    flatten_with_keys_fn=lambda xs: (
        [(pytree.SequenceKey(i), x) for i, x in enumerate(xs)],
        None,
    ),
    serialized_type_name="torch.Size",
)


# pytree container subclasses whose subclass identity is lost by a
# flatten/unflatten round trip (torch.Size comes back as tuple; see the
# registration above).
_pytree_subclasses_that_lose_info = {torch.Size: tuple}
|
|
|
|
|
|
|
|
|
def fake_signature(fn: Callable[_P, R], nargs: int) -> Callable[_P, R]:
    """FX gets confused by varargs, de-confuse it"""
    params = ",".join("arg{}".format(idx) for idx in range(nargs))
    source = f"lambda {params}: fn({params})"
    # eval is the only way to synthesize a function with exactly `nargs`
    # positional parameters; `fn` is the sole name in its globals.
    return eval(source, {"fn": fn})
|
|
|
|
|
|
|
|
|
@contextmanager
def decompose(
    decomposition_table: Optional[Mapping[OpOverload, Callable]],
) -> Generator[Mapping[OpOverload, Callable], None, None]:
    """Install ``decomposition_table`` as the active decomposition table for
    the duration of the context, yielding it, and restore the previous table
    on exit (even on error)."""
    global CURRENT_DECOMPOSITION_TABLE
    saved_table = CURRENT_DECOMPOSITION_TABLE
    CURRENT_DECOMPOSITION_TABLE = decomposition_table or {}
    try:
        yield CURRENT_DECOMPOSITION_TABLE
    finally:
        CURRENT_DECOMPOSITION_TABLE = saved_table
|
|
|
|
|
|
|
|
|
|
|
|
# NOTE(review): opaque sentinel with no uses visible in this chunk —
# confirm external/legacy users before removing.
proxy_slot = object()
|
|
|
|
|
|
|
|
|
class _NoDefault:
|
|
|
pass
|
|
|
|
|
|
|
|
|
no_default = _NoDefault()
|
|
|
|
|
|
from torch.types import py_sym_types, PySymType
|
|
|
|
|
|
|
|
|
class _HasMeta(Protocol):
    """Structural type for FX-node-like objects that carry a ``meta`` dict."""

    meta: dict[str, PySymType]
|
|
|
|
|
|
|
|
|
def is_sym_node(node: _HasMeta) -> bool:
|
|
|
assert hasattr(node, "meta"), "All nodes traced with proxy_tensor should have meta"
|
|
|
return "val" in node.meta and isinstance(node.meta["val"], py_sym_types)
|
|
|
|
|
|
|
|
|
@overload
def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: ...


@overload
def set_proxy_slot(
    obj: _AnyScriptObjectType, tracer: _ProxyTracer, proxy: Proxy
) -> None: ...


@overload
def set_proxy_slot(
    obj: PySymType, tracer: _ProxyTracer, proxy: _PySymProxyType
) -> None: ...


def set_proxy_slot(
    obj: Union[PySymType, _AnyScriptObjectType, Tensor],
    tracer: _ProxyTracer,
    proxy: object,
) -> None:
    """Record ``proxy`` as the producer of ``obj`` on ``tracer``.

    The proxy is stored in the tracker matching ``obj``'s type: tensors in
    ``tensor_tracker``, torchbind objects in ``script_object_tracker``, and
    SymInt/SymFloat/SymBool in ``symnode_tracker``.
    """
    log.debug("set_proxy_slot %s (%s) %s", obj, id(obj), proxy)
    if isinstance(obj, Tensor):
        # Tensor bindings are unconditionally overwritten: a later proxy for
        # the same tensor supersedes the earlier one.
        assert isinstance(proxy, _ProxyTensor)
        tracer.tensor_tracker[obj] = proxy
    elif isinstance(obj, (_AnyScriptObject)):
        assert isinstance(proxy, Proxy)
        tracer.script_object_tracker[obj] = proxy
    else:
        # NB: symbolic values are never clobbered — the first proxy recorded
        # for a SymNode wins (membership test below).
        assert isinstance(obj, py_sym_types), type(obj)
        if obj not in tracer.symnode_tracker:
            tracer.symnode_tracker[obj] = typing.cast(_PySymProxyType, proxy)

            # Additionally key the proxy by the underlying sympy Symbol, so a
            # distinct Sym* object backed by the same symbol can still be
            # resolved later (see get_proxy_slot's fallback path).
            import sympy

            if isinstance(obj.node.expr, sympy.Symbol):
                tracer.sympy_expr_tracker[obj.node.expr] = proxy
|
|
|
|
|
|
|
|
|
def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool:
    """Return True if ``tracer`` has a proxy recorded for ``obj``."""
    assert isinstance(obj, (Tensor, SymNode)), type(obj)
    found = get_proxy_slot(obj, tracer, False, lambda _: True)
    return bool(found)
|
|
|
|
|
|
|
|
|
# SymNode proxies are stored lazily as thunks and only forced on use.
_PySymProxyType = Thunk[Proxy]
|
|
|
|
|
|
|
|
|
@overload
def get_proxy_slot(
    obj: Tensor,
    tracer: _ProxyTracer,
) -> _ProxyTensor: ...


@overload
def get_proxy_slot(
    obj: Tensor,
    tracer: _ProxyTracer,
    default: U,
) -> Union[_ProxyTensor, U]: ...


@overload
def get_proxy_slot(
    obj: Tensor,
    tracer: _ProxyTracer,
    default: U,
    transform: Callable[[_ProxyTensor], R],
) -> Union[R, U]: ...


@overload
def get_proxy_slot(
    obj: _AnyScriptObjectType,
    tracer: _ProxyTracer,
) -> Proxy: ...


@overload
def get_proxy_slot(
    obj: _AnyScriptObjectType,
    tracer: _ProxyTracer,
    default: U,
) -> Union[Proxy, U]: ...


@overload
def get_proxy_slot(
    obj: _AnyScriptObjectType,
    tracer: _ProxyTracer,
    default: U,
    transform: Callable[[Proxy], R],
) -> Union[R, U]: ...


@overload
def get_proxy_slot(
    obj: PySymType,
    tracer: _ProxyTracer,
) -> _PySymProxyType: ...


@overload
def get_proxy_slot(
    obj: PySymType,
    tracer: _ProxyTracer,
    default: T,
) -> Union[T, _PySymProxyType]: ...


@overload
def get_proxy_slot(
    obj: PySymType,
    tracer: _ProxyTracer,
    default: U,
    transform: Callable[[_PySymProxyType], R],
) -> Union[R, U]: ...


def get_proxy_slot(
    obj: Union[Tensor, _AnyScriptObjectType, PySymType],
    tracer: _ProxyTracer,
    default: object = no_default,
    transform: Callable = lambda x: x,
) -> object:
    """Return (a ``transform`` of) the proxy recorded for ``obj`` on ``tracer``.

    Lookup goes through the tracker matching ``obj``'s type.  Symbolic
    values additionally fall back to a lookup keyed on their sympy
    expression (populated by set_proxy_slot).  On a miss, ``default`` is
    returned untransformed; with no default supplied, a RuntimeError is
    raised.
    """
    tracker: Any
    if isinstance(obj, Tensor):
        tracker = tracer.tensor_tracker
    elif isinstance(obj, _AnyScriptObject):
        tracker = tracer.script_object_tracker
    else:
        assert isinstance(obj, py_sym_types), type(obj)
        tracker = tracer.symnode_tracker

    if obj not in tracker:
        # Fallback: resolve a symbolic value through its sympy expression.
        if isinstance(obj, py_sym_types) and obj.node.expr in tracer.sympy_expr_tracker:
            value = tracer.sympy_expr_tracker[obj.node.expr]
        else:
            if isinstance(default, _NoDefault):
                raise RuntimeError(
                    f"{obj} ({id(obj)})is not tracked with proxy for {tracer}"
                )
            return default
    else:
        value = tracker[obj]
    # NB: transform is applied only to a found value, never to `default`.
    res = transform(value)
    return res
|
|
|
|
|
|
|
|
|
def snapshot_fake(val: Tensor, include_real: bool = False) -> Optional[Tensor]:
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(val, FakeTensor):
|
|
|
return fast_detach(val.fake_mode, val, include_real)
|
|
|
else:
|
|
|
return val.detach()
|
|
|
|
|
|
|
|
|
# Closed set of types that may appear in a node's meta["val"]; containers
# nest recursively (see extract_val).
_ExtractValType = Optional[
    Union[
        PySymType,
        _AnyScriptObjectType,
        BackwardState,
        list["_ExtractValType"],
        tuple["_ExtractValType", ...],
        dict[str, "_ExtractValType"],
        Tensor,
        int,
        float,
        bool,
    ]
]
|
|
|
|
|
|
|
|
|
def extract_val(val: _ExtractValType, include_real: bool = False) -> _ExtractValType:
    """Convert ``val`` to the object stored in ``node.meta["val"]``.

    Fake tensors are snapshotted (detached); symbolic values, script objects
    and BackwardState pass through unchanged; lists/tuples/dicts convert
    elementwise; real dense tensors are replaced by fresh fake tensors with
    matching shape/stride/device/dtype; real sparse tensors map to None.
    """
    if is_fake(val):
        return snapshot_fake(val, include_real=include_real)
    elif isinstance(val, py_sym_types):
        return val
    elif isinstance(val, _AnyScriptObject):
        return val
    elif isinstance(val, BackwardState):
        return val
    elif isinstance(val, (list, tuple)):
        # Preserve the concrete container class (e.g. a tuple subclass).
        return val.__class__([extract_val(x) for x in val])
    elif isinstance(val, dict):
        return {k: extract_val(v) for k, v in val.items()}
    elif isinstance(val, Tensor):
        if not val.is_sparse:
            # Real tensor reaching tracing: fabricate a fake stand-in so the
            # graph metadata never aliases real data.  Reuse an ambient fake
            # mode if one is active, otherwise make a throwaway one.
            from torch._guards import detect_fake_mode

            fake_tensor_mode = detect_fake_mode(val)
            if not fake_tensor_mode:
                fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True)
            with fake_tensor_mode:
                return torch.empty_strided(
                    val.shape, val.stride(), device=val.device, dtype=val.dtype
                )
        else:
            # Sparse real tensors get no metadata snapshot.
            return None
    elif isinstance(val, (int, float, bool)):
        return val
    elif val is None:
        return None

    # Exhaustiveness check: every _ExtractValType variant is handled above.
    typing_extensions.assert_never(val)
|
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
def _enable_thunkify(
|
|
|
tracer: _ProxyTracer, *, enable: bool = True
|
|
|
) -> Generator[None, None, None]:
|
|
|
"""
|
|
|
Enable thunkification inside the context manager. Thunkification prevents
|
|
|
SymNode computation from directly being traced into an FX graph; instead,
|
|
|
the compute is only added to the graph if it is actually used. This helps
|
|
|
us track SymNode compute when it is computed (since we need /something/
|
|
|
to put in the tracker) even if it is unlikely to be used.
|
|
|
"""
|
|
|
old = tracer.enable_thunkify
|
|
|
tracer.enable_thunkify = enable
|
|
|
try:
|
|
|
yield
|
|
|
finally:
|
|
|
tracer.enable_thunkify = old
|
|
|
|
|
|
|
|
|
@contextmanager
def maybe_disable_thunkify() -> Generator[None, None, None]:
    """Within a context, disable thunkification. See :func:`maybe_enable_thunkify`
    for details.  Useful when a wrapper enables thunkification but an inner
    segment (say, the original user function) is known not to need it.
    Does nothing when no proxy mode is active.
    """
    mode = get_proxy_mode()
    if mode is None:
        yield
        return
    with _enable_thunkify(mode.tracer, enable=False):
        yield
|
|
|
|
|
|
|
|
|
@contextmanager
def maybe_enable_thunkify() -> Generator[None, None, None]:
    """If make_fx tracing is active, thunkify all SymNode compute inside the
    context so it is only traced into the graph when actually needed.

    Prefer to avoid this where possible: lazy SymNode evaluation can build
    long thunk chains that stack-overflow when forced.  It is nevertheless
    sometimes required, as parts of PT2 otherwise fail with
    "s0 is not tracked with proxy" due to insufficient SymNode tracing.
    Does nothing when no proxy mode is active.
    """
    mode = get_proxy_mode()
    if mode is None:
        yield
        return
    with _enable_thunkify(mode.tracer):
        yield
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_meta(proxy: Proxy, val: _ExtractValType) -> Proxy:
    """Populate ``proxy.node.meta`` from ``val`` and return ``proxy``.

    Always sets "val"; for fake tensors and dense real tensors also sets
    "tensor_meta".  Placeholder nodes may additionally capture real values
    (include_real).
    """
    is_placeholder = proxy.node.op == "placeholder"
    proxy.node.meta["val"] = extract_val(val, include_real=is_placeholder)

    with _enable_thunkify(proxy.tracer):
        if is_fake(val) or (isinstance(val, Tensor) and not val.is_sparse):
            proxy.node.meta["tensor_meta"] = _extract_tensor_metadata(val)
    return proxy
|
|
|
|
|
|
|
|
|
def thunkify(
|
|
|
tracer: _ProxyTracer, f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs
|
|
|
) -> Thunk[R]:
|
|
|
"""
|
|
|
Delays computation of f until it's called again
|
|
|
Also caches the result
|
|
|
"""
|
|
|
if tracer.enable_thunkify:
|
|
|
return Thunk(functools.partial(f, *args, **kwargs))
|
|
|
else:
|
|
|
r = f(*args, **kwargs)
|
|
|
return Thunk(lambda: r)
|
|
|
|
|
|
|
|
|
def track_tensor(
    tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tracer: _ProxyTracer
) -> None:
    """Record ``proxy`` as the producer of ``tensor`` on ``tracer``.

    Besides the tensor itself, every symbolic quantity derived from it
    (sizes, strides, numel, storage offset) gets a lazily-traced proxy so
    later SymInt uses can be resolved back to graph nodes.
    """

    def try_set_proxy_slot(
        outer_s: IntLikeType,
        proxy_callable: Callable[Concatenate[PySymType, _P], Proxy],
        *args: _P.args,
        **kwargs: _P.kwargs,
    ) -> None:
        # Only symbolic ints need tracking; plain Python ints are literals.
        assert callable(proxy_callable)
        if isinstance(outer_s, SymInt):
            with _enable_thunkify(tracer):
                set_proxy_slot(
                    outer_s,
                    tracer,
                    thunkify(tracer, proxy_callable, outer_s, *args, **kwargs),
                )

    # NB: the loop index is threaded through try_set_proxy_slot's *args so
    # each lambda sees its own `i` (avoids the late-binding closure pitfall).
    for i, s in enumerate(tensor.shape):
        try_set_proxy_slot(
            s,
            lambda x, i: set_meta(
                tracer.create_proxy(
                    "call_function", torch.ops.aten.sym_size.int, (proxy, i), {}
                ),
                x,
            ),
            i,
        )

    # Sparse tensors have no strides / storage offset to track.
    if not is_sparse_any(tensor):
        for i, s in enumerate(tensor.stride()):
            try_set_proxy_slot(
                s,
                lambda x, i: set_meta(
                    tracer.create_proxy(
                        "call_function", torch.ops.aten.sym_stride.int, (proxy, i), {}
                    ),
                    x,
                ),
                i,
            )

    try_set_proxy_slot(
        tensor.numel(),
        lambda x: set_meta(
            tracer.create_proxy(
                "call_function", torch.ops.aten.sym_numel.default, (proxy,), {}
            ),
            x,
        ),
    )
    if not is_sparse_any(tensor):
        try_set_proxy_slot(
            tensor.storage_offset(),
            lambda x: set_meta(
                tracer.create_proxy(
                    "call_function",
                    torch.ops.aten.sym_storage_offset.default,
                    (proxy,),
                    {},
                ),
                x,
            ),
        )
    # Finally, bind the tensor itself (with its optional constant value).
    set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant))
|
|
|
|
|
|
|
|
|
# Arbitrarily nested containers of proxies / tensors; the two structures are
# expected to mirror each other in track_tensor_tree.
_NestedProxys = Union[
    Proxy, Sequence["_NestedProxys"], Mapping[object, "_NestedProxys"]
]
_NestedTensors = Union[
    Tensor, Sequence["_NestedTensors"], Mapping[object, "_NestedTensors"]
]
|
|
|
|
|
|
|
|
|
def track_tensor_tree(
    inner_res: T,
    proxy_res: _NestedProxys,
    *,
    constant: Optional[_NestedTensors],
    tracer: _ProxyTracer,
) -> T:
    """Recursively associate every tensor / symbolic / script-object leaf of
    ``inner_res`` with the matching proxy in ``proxy_res``, populating the
    proxies' node metadata.  Returns ``inner_res`` unchanged.
    """
    # Propagate unbacked symbol bindings from the result onto the proxy node.
    _set_unbacked_bindings(inner_res, proxy_res)

    def wrap_with_proxy(
        e: object, proxy: _NestedProxys, constant: Optional[_NestedTensors]
    ) -> None:
        if isinstance(e, Tensor):
            assert isinstance(proxy, Proxy)
            assert constant is None or isinstance(constant, Tensor)
            track_tensor(e, proxy, tracer=tracer, constant=constant)
            set_meta(proxy, e)
        elif isinstance(e, py_sym_types):
            assert isinstance(proxy, Proxy)
            set_meta(proxy, e)
            # Bind lazily; the proxy only materializes in the graph if forced.
            set_proxy_slot(e, tracer, thunkify(tracer, lambda: proxy))
        elif isinstance(e, _AnyScriptObject):
            assert isinstance(proxy, Proxy)
            set_proxy_slot(e, tracer, proxy)
            set_meta(proxy, e)
        elif isinstance(e, (tuple, list)):
            # A single Proxy may stand for the whole container (multi-output
            # op); only then does it carry meta itself.
            if isinstance(proxy, fx.Proxy):
                set_meta(proxy, e)

            def get_constant(
                c: Optional[_NestedTensors], idx: int
            ) -> Optional[_NestedTensors]:
                if c is None:
                    return None
                else:
                    assert isinstance(c, (list, tuple))
                    return c[idx]

            for idx, ee in enumerate(e):
                # `proxy[idx]` works whether `proxy` is a sequence of proxies
                # or a single Proxy being indexed.
                wrap_with_proxy(ee, proxy[idx], get_constant(constant, idx))

        elif isinstance(e, dict):
            # Constants are not propagated through dict outputs.
            assert constant is None

            if isinstance(proxy, fx.Proxy):
                set_meta(proxy, e)

            for key, val in e.items():
                wrap_with_proxy(val, proxy[key], None)

        elif isinstance(e, BackwardState):
            assert isinstance(proxy, Proxy)
            set_meta(proxy, e)
            e.proxy = proxy
        else:
            # Other leaves (ints, None, ...) need no tracking.
            pass

    wrap_with_proxy(inner_res, proxy_res, constant)

    return inner_res
|
|
|
|
|
|
|
|
|
@dataclass
class _ProxyTensor:
    # Graph proxy that produces this tensor.
    proxy: Proxy
    # Eagerly-computed constant value when constant propagation applied
    # (see proxy_call); otherwise None.
    constant: Optional[Tensor]
|
|
|
|
|
|
|
|
|
def fetch_sym_proxy(
|
|
|
tracer: _ProxyTracer,
|
|
|
) -> Callable[[PySymType], Union[bool, int, float, Proxy]]:
|
|
|
def inner(e: PySymType) -> Union[int, bool, float, Proxy]:
|
|
|
n = e.node
|
|
|
if n.constant is not None:
|
|
|
return n.constant
|
|
|
if e.node.expr.is_number:
|
|
|
if isinstance(e, SymBool):
|
|
|
return bool(e.node.expr)
|
|
|
elif isinstance(e, SymInt):
|
|
|
return int(e.node.expr)
|
|
|
return float(e.node.expr)
|
|
|
else:
|
|
|
assert isinstance(e, py_sym_types)
|
|
|
|
|
|
return get_proxy_slot(e, tracer).force()
|
|
|
|
|
|
return inner
|
|
|
|
|
|
|
|
|
@overload
def fetch_object_proxy(
    tracer: _ProxyTracer, t: Tensor
) -> Union[_ProxyTensor, Tensor]: ...


@overload
def fetch_object_proxy(
    tracer: _ProxyTracer, t: _AnyScriptObjectType
) -> Union[Proxy, _AnyScriptObjectType]: ...


@overload
def fetch_object_proxy(
    tracer: _ProxyTracer, t: PySymType
) -> Union[_PySymProxyType, PySymType]: ...


def fetch_object_proxy(
    tracer: _ProxyTracer, t: Union[Tensor, _AnyScriptObjectType, PySymType]
) -> object:
    # Return the tracked proxy for ``t`` if one exists; otherwise ``t``
    # itself (``t`` doubles as get_proxy_slot's default).
    return get_proxy_slot(t, tracer, t)
|
|
|
|
|
|
|
|
|
# Tensor types proxy_call accepts directly, without a pre-registered proxy.
HANDLED_TYPES = (Tensor, torch.nn.Parameter, FakeTensor)
|
|
|
|
|
|
|
|
|
def _maybe_record_pointwise_barrier(
|
|
|
func: object, proxy_mode: ProxyTorchDispatchMode
|
|
|
) -> None:
|
|
|
"""
|
|
|
Records pointwise operators in user program (non decomposed) that were output in fp16/bf16
|
|
|
"""
|
|
|
if proxy_mode.decomp_layers or not proxy_mode.emulate_precision_casts:
|
|
|
return
|
|
|
|
|
|
if (
|
|
|
not isinstance(func, torch._ops.OpOverload)
|
|
|
or torch.Tag.pointwise not in func.tags
|
|
|
):
|
|
|
return
|
|
|
|
|
|
last_node = next(iter(reversed(proxy_mode.tracer.graph.nodes)))
|
|
|
t = last_node.meta.get("val")
|
|
|
if not isinstance(t, torch.Tensor) or t.dtype not in (
|
|
|
torch.bfloat16,
|
|
|
torch.float16,
|
|
|
):
|
|
|
return
|
|
|
|
|
|
last_node.meta["low_precision_pointwise_barrier"] = True
|
|
|
|
|
|
|
|
|
def proxy_call(
    proxy_mode: ProxyTorchDispatchMode,
    func: OpOverload,
    pre_dispatch: bool,
    args: tuple[object, ...],
    kwargs: dict[str, object],
) -> object:
    """Trace one dispatched op call into ``proxy_mode``'s graph.

    Order of operations: bail out on unrecognized tensor subclasses, try
    registered decompositions, try the op's own ``decompose``, handle
    data-dependent ops (constant-propagate or error), create the
    call_function proxy, run the op for real, optionally constant-propagate
    small results, then track the outputs against the proxy.
    """
    unrecognized_types: list[type] = []
    flat_args_kwargs, spec = pytree.tree_flatten((args, kwargs))

    def can_handle_tensor(x: Tensor) -> bool:
        # A tensor is handleable if its type is directly supported or it
        # already has a proxy registered on this tracer.
        r = type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer)
        if proxy_mode._allow_fake_constant:
            r = r or type(x) in (torch._subclasses.FakeTensor,)
        if not r:
            unrecognized_types.append(type(x))
        return r

    if not all(can_handle_tensor(x) for x in flat_args_kwargs if isinstance(x, Tensor)):
        not_implemented_log.debug(
            "ProxyTensorMode tensors without proxy had unrecognized subclasses: %s",
            unrecognized_types,
        )
        return NotImplemented

    # First chance: an entry in the active decomposition table.
    r = maybe_handle_decomp(proxy_mode, func, args, kwargs)
    if r is not NotImplemented:
        _maybe_record_pointwise_barrier(func, proxy_mode)
        return r

    # Second chance: the op's own decomposition (CompositeImplicitAutograd),
    # re-traced under the proxy mode.  The size/stride/storage_offset
    # accessors are exempted.
    if not pre_dispatch and func not in [
        torch.ops.aten.size.default,
        torch.ops.aten.stride.default,
        torch.ops.aten.storage_offset.default,
    ]:
        with proxy_mode:
            r = func.decompose(*args, **kwargs)
            if r is not NotImplemented:
                return r

    # is_nonzero is rewritten as (x != 0).item() with an explicit numel check.
    if func is torch.ops.aten.is_nonzero.default:
        with proxy_mode:
            torch._check(
                args[0].numel() == 1,
                lambda: "Boolean value of Tensor with more than one value is ambiguous",
            )
            return (args[0] != 0).item()

    tracer = proxy_mode.tracer
    # Swap each tensor / script object for its tracked proxy record.
    f_flat_args_kwargs = [
        (
            fetch_object_proxy(tracer, x)
            if isinstance(x, (Tensor, _AnyScriptObject))
            else x
        )
        for x in flat_args_kwargs
    ]

    # All tensor inputs carry constants, and nothing is symbolic.
    all_constant = (
        not any(
            t.constant is None
            for t in f_flat_args_kwargs
            if isinstance(t, _ProxyTensor)
        )
        and not any(isinstance(x, py_sym_types) for x in flat_args_kwargs)
    )

    if torch.Tag.data_dependent_output in func.tags:
        # With all-constant inputs the op can run eagerly on the constants.
        if all_constant:
            const_flat_args_kwargs = [
                t.constant if isinstance(t, _ProxyTensor) else t
                for t in f_flat_args_kwargs
            ]
            const_args, const_kwargs = pytree.tree_unflatten(
                const_flat_args_kwargs, spec
            )
            with unset_fake_temporarily():
                return func(*const_args, **const_kwargs)

        # Otherwise a data-dependent op on traced values is an error (unless
        # the inputs are real tensors and erroring is disabled).
        if proxy_mode._error_on_data_dependent_ops and pytree.tree_all_only(
            Tensor, lambda t: not is_fake(t), (args, kwargs)
        ):
            raise RuntimeError(
                f"It appears that you're trying to get value out of a tracing tensor with {func} - erroring out! "
                "It's likely that this is caused by data-dependent control flow or similar. "
                "It may be possible to trace this with dynamic shapes; try setting tracing_mode='symbolic' "
                "in your make_fx call."
            )

    # Lower to plain proxies / concrete values for the graph call.
    proxy_flat_args_kwargs = [
        e.proxy if isinstance(e, _ProxyTensor) else e for e in f_flat_args_kwargs
    ]
    proxy_flat_args_kwargs = [
        (fetch_sym_proxy(proxy_mode.tracer)(e) if isinstance(e, py_sym_types) else e)
        for e in proxy_flat_args_kwargs
    ]
    proxy_args, proxy_kwargs = pytree.tree_unflatten(proxy_flat_args_kwargs, spec)

    # lift_fresh is recorded as its copying variant.
    if func is torch.ops.aten.lift_fresh.default:
        func = torch.ops.aten.lift_fresh_copy.default

    proxy_out = proxy_mode.tracer.create_proxy(
        "call_function",
        func,
        proxy_args,
        proxy_kwargs,
        name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__),
    )

    # Execute the op for real (thunkifying any SymNode compute it performs).
    with _enable_thunkify(proxy_mode.tracer):
        out = func(*args, **kwargs)

    any_constant = any(
        t.constant is not None
        for t in f_flat_args_kwargs
        if isinstance(t, _ProxyTensor)
    )

    constant = None

    def tensor_numel_in_limit(t: Tensor) -> bool:
        return t.numel() <= CONSTANT_NUMEL_LIMIT

    # Constant propagation: lift_fresh_copy of a tiny tensor keeps a clone of
    # the source as the constant; otherwise deterministic ops over constant
    # inputs with tiny outputs recompute the constant eagerly.
    if (
        func is torch.ops.aten.lift_fresh_copy.default
        and out.numel() <= CONSTANT_NUMEL_LIMIT
    ):
        with unset_fake_temporarily():
            assert isinstance(args[0], (Proxy, Tensor)), type(args[0])
            constant = args[0].clone()
    elif (
        torch.Tag.nondeterministic_seeded not in func.tags
        and all_constant
        and any_constant
        and pytree.tree_all_only(Tensor, tensor_numel_in_limit, out)
    ):
        with unset_fake_temporarily():
            const_flat_args_kwargs = [
                t.constant if isinstance(t, _ProxyTensor) else t
                for t in f_flat_args_kwargs
            ]
            const_args, const_kwargs = pytree.tree_unflatten(
                const_flat_args_kwargs, spec
            )
            constant = func(*const_args, **const_kwargs)
    else:
        constant = None

    track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer)
    _maybe_record_pointwise_barrier(func, proxy_mode)
    return out
|
|
|
|
|
|
|
|
|
class _SymNodeDict:
|
|
|
"""
|
|
|
Wrapper around a dictionary that will hash SymInts with their nodes
|
|
|
"""
|
|
|
|
|
|
def __init__(self) -> None:
|
|
|
self.sym_node_dict: dict[PySymType, _PySymProxyType] = {}
|
|
|
|
|
|
def __setitem__(self, key: PySymType, value: _PySymProxyType) -> None:
|
|
|
self.sym_node_dict[key.node] = value
|
|
|
|
|
|
def __getitem__(self, key: PySymType) -> _PySymProxyType:
|
|
|
return self.sym_node_dict[key.node]
|
|
|
|
|
|
def __contains__(self, key: PySymType) -> bool:
|
|
|
return key.node in self.sym_node_dict
|
|
|
|
|
|
def get(
|
|
|
self, key: PySymType, default: Optional[_PySymProxyType] = None
|
|
|
) -> _PySymProxyType:
|
|
|
|
|
|
|
|
|
return self.sym_node_dict.get(key.node, default)
|
|
|
|
|
|
def __iter__(self) -> Any:
|
|
|
raise NotImplementedError
|
|
|
|
|
|
def __len__(self) -> int:
|
|
|
return len(self.sym_node_dict)
|
|
|
|
|
|
|
|
|
class PythonKeyTracer(Tracer):
    """FX tracer used by proxy-tensor tracing.

    Unlike the default symbolic tracer, submodules are inlined
    (``call_module`` simply invokes ``forward``) and attribute access is not
    proxied (``getattr`` returns the raw value).  Trackers map live tensors,
    SymNodes and script objects to the proxies that produced them.
    """

    # Proxies for (fake) torchbind objects, weakly keyed by identity.
    script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy]
    # Lazy proxies for SymInt/SymFloat/SymBool, keyed by their SymNode.
    symnode_tracker: _SymNodeDict
    # Fallback proxies keyed by the underlying sympy.Symbol.
    sympy_expr_tracker: dict[sympy.Symbol, object]
    # Proxies for tensors, weakly keyed so tracing does not pin tensors.
    tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
    torch_fn_counts: dict[OpOverload, int]
    enable_thunkify: bool = False
    # When True, create_node attaches a filtered user stack trace to nodes.
    stack_trace: bool = False

    def __init__(self) -> None:
        super().__init__(autowrap_modules=())
        self.tensor_tracker = WeakTensorKeyDictionary()
        self.symnode_tracker = _SymNodeDict()
        self.script_object_tracker = WeakIdKeyDictionary(
            dict=None, ref_type=_WeakHashRef
        )
        self.sympy_expr_tracker = dict()

        self.torch_fn_metadata = None

        self.torch_fn_counts = {}
        self.enable_thunkify = False

    # Modules are inlined during tracing rather than recorded as call_module.
    def call_module(
        self,
        m: Module,
        forward: Callable[..., Any],
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        return forward(*args, **kwargs)

    # Attribute access is not proxied; return the value as-is.
    def getattr(
        self, attr: str, attr_val: object, parameter_proxy_cache: dict[str, Proxy]
    ) -> object:
        return attr_val

    def create_arg(self, a: object) -> fx.node.Node:
        # Parameters become get_attr nodes: reuse the registered name if the
        # parameter already lives on root, else install it under a fresh
        # "_param_constant*" attribute.
        if isinstance(a, torch.nn.Parameter):
            for n, p in self.root.named_parameters():
                if a is p:
                    return self.create_node("get_attr", n, (), {})

            qualname = self.get_fresh_qualname("_param_constant")
            setattr(self.root, qualname, a)

            return self.create_node("get_attr", qualname, (), {})
        elif isinstance(a, py_sym_types):
            # Only constant-valued symbolic values can be embedded as args;
            # NB: the constant itself is returned here, not a Node.
            assert a.node.constant is not None
            return a.node.constant
        return super().create_arg(a)

    @overload
    def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]: ...

    @overload
    def unwrap_proxy(self, e: PySymType) -> Union[Proxy, PySymType]: ...

    @overload
    def unwrap_proxy(
        self, e: _AnyScriptObjectType
    ) -> Union[Proxy, _AnyScriptObjectType]: ...

    def unwrap_proxy(self, e: T) -> object:
        """Return the proxy tracked for ``e``, or ``e`` itself if untracked."""
        if isinstance(e, Tensor):
            return get_proxy_slot(e, self, e, lambda x: x.proxy)
        elif isinstance(e, py_sym_types):
            # Sym proxies are stored as thunks; force them on use.
            return get_proxy_slot(e, self, e, lambda e: e.force())
        elif isinstance(e, _AnyScriptObject):
            return get_proxy_slot(e, self, e)
        else:
            return e

    def create_node(
        self,
        kind: str,
        target: Target,
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        name: Optional[str] = None,
        type_expr: Optional[Any] = None,
    ) -> torch.fx.Node:
        node = super().create_node(kind, target, args, kwargs, name, type_expr)

        # Optionally attach a user-facing stack trace: keep only the user's
        # `forward` frames and torch/__init__.py frames, then drop tracing
        # machinery frames.
        if (
            self.stack_trace
            and "stack_trace" not in node.meta
            and node.op not in ["placeholder", "output"]
        ):
            user_frame_summary = CapturedTraceback.extract().summary()
            if user_frame_summary:
                stack_trace = [
                    frame
                    for frame in user_frame_summary
                    if (
                        frame.name == "forward"
                        or frame.filename.endswith("torch/__init__.py")
                    )
                ]

                stack_trace = [
                    frame
                    for frame in stack_trace
                    if not frame.filename.endswith(
                        ("fx/_symbolic_trace.py", "export/_trace.py")
                    )
                ]
                if (
                    stack_trace
                ):
                    stack_trace = traceback.StackSummary.from_list(stack_trace)
                    node.meta["stack_trace"] = "".join(stack_trace.format()).strip()

        # get_attr nodes over tensors get a "val" snapshot, taken outside any
        # proxy mode so the snapshot itself is not traced.
        if kind == "get_attr":
            assert isinstance(target, str)
            attr = getattr(self.root, target)
            if isinstance(attr, torch.Tensor):
                with disable_proxy_modes_tracing():
                    node.meta["val"] = extract_val(attr)

        def map_fn(v: Any) -> Optional[_ExtractValType]:
            if not isinstance(v, torch.fx.Node) or "val" not in v.meta:
                return None
            val = v.meta["val"]

            # Skip real (non-fake) tensors; only fake-tensor metadata is
            # snapshotted into eager_input_vals.
            if isinstance(val, torch.Tensor) and not isinstance(val, FakeTensor):
                return None
            return extract_val(v.meta["val"])

        if _should_save_eager_input_vals(target, (args, kwargs)):
            # Record per-argument "val" snapshots for targets that need exact
            # input metadata (e.g. exact strides) downstream.
            arg_inp, kwarg_inp = torch.fx.node.map_aggregate((args, kwargs), map_fn)
            node.meta["eager_input_vals"] = (arg_inp, kwarg_inp)

        return node
|
|
|
|
|
|
|
|
|
def _should_save_eager_input_vals(
    target: Any,
    args_kwargs: Optional[tuple[tuple[Argument, ...], dict[str, Argument]]] = None,
) -> bool:
    """Decide whether a node for ``target`` should record its eager/fake input
    values into ``node.meta["eager_input_vals"]``.

    Args:
        target: the candidate call target (OpOverload, HigherOrderOperator,
            or any other callable).
        args_kwargs: the ``(args, kwargs)`` the node is being created with.
            Only consulted for the auto_functionalized wrappers and for the
            HOP NYI check below.

    Raises:
        RuntimeError: for a HigherOrderOperator whose arguments themselves
            require eager input vals (not yet implemented).
    """
    from torch._higher_order_ops.invoke_subgraph import InvokeSubgraphHOP

    if not callable(target):
        return False
    # Triton kernel wrappers and invoke_subgraph always want eager input
    # values recorded.
    if isinstance(
        target,
        (
            torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperFunctional,
            torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation,
            InvokeSubgraphHOP,
        ),
    ):
        return True
    # auto_functionalized(op, ...) carries the wrapped op as its first
    # argument; defer to that inner op's answer (with args_kwargs=None so the
    # recursion terminates at the OpOverload/HOP branches below).
    if args_kwargs is not None and (
        target is torch.ops.higher_order.auto_functionalized
        or target is torch.ops.higher_order.auto_functionalized_v2
    ):
        args = args_kwargs[0]
        assert isinstance(
            args[0], (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
        )
        return _should_save_eager_input_vals(args[0], None)
    if target is torch.ops.higher_order.with_effects:
        # Explicit opt-out for with_effects.
        return False
    if isinstance(target, torch._ops.HigherOrderOperator):
        # A HOP whose *inputs* contain a target that needs eager input vals is
        # not supported: the fake values would have to be propagated into the
        # subgraph.
        if pytree.tree_any(_should_save_eager_input_vals, args_kwargs):
            raise RuntimeError(
                f"NYI: The HOP {target} has an input that is an OpOverload that "
                f"needs exact strides. We probably need special logic to "
                f"propagate the FakeTensor vals. Please file an issue."
            )
    if isinstance(target, torch._ops.OpOverload):
        from torch._library.utils import get_layout_constraint_tag

        # Plain ops need their eager inputs only when tagged as requiring
        # exact strides.
        return get_layout_constraint_tag(target) == torch._C.Tag.needs_exact_strides
    return False
|
|
|
|
|
|
|
|
|
def _make_temp_remove_mode_context_manager(
    mode_ty: type[TorchFunctionMode],
) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]:
    """Build a context-manager factory that temporarily removes the topmost
    active TorchFunctionMode of type ``mode_ty`` from the torch-function mode
    stack.

    The produced context manager yields the removed mode (or None if no such
    mode was active) and, on exit, reinstates it at its original position in
    the stack.
    """

    @contextmanager
    def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]:
        from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode

        # Modes popped while searching, in top-of-stack-first order.
        temp_elements = []
        removed_mode = None

        # Pop until we find an instance of mode_ty or exhaust the stack.
        while _len_torch_function_stack() > 0:
            mode = _pop_mode()
            if isinstance(mode, mode_ty):
                removed_mode = mode
                break
            else:
                temp_elements.append(mode)

        # Re-push everything that was above the removed mode, preserving order.
        for mode in reversed(temp_elements):
            _push_mode(mode)

        try:
            yield removed_mode

        finally:
            if removed_mode is not None:
                # Pop the modes we re-pushed above (temp_elements still holds
                # references to them), then re-push with removed_mode slotted
                # back underneath, restoring the original stack layout.
                count = len(temp_elements)
                while count > 0:
                    mode = _pop_mode()
                    count -= 1

                temp_elements.append(removed_mode)

                for mode in reversed(temp_elements):
                    _push_mode(mode)

    return context_manager_fn
|
|
|
|
|
|
|
|
|
@torch._disable_dynamo
def dispatch_trace(
    root: Union[Module, Callable],
    tracer: Tracer,
    concrete_args: Optional[tuple[Any, ...]] = None,
) -> GraphModule:
    """Trace ``root`` with ``tracer``, prune dead code, dedupe SymInt uses,
    and package the result as a (lazy) GraphModule.

    A node is kept by dead-code elimination if it is impure; pure symbolic
    arithmetic (SymInt/SymBool-valued nodes whose node inputs are all
    symbolic) and accessor nodes are considered removable.
    """
    traced_graph = tracer.trace(root, concrete_args)

    def impure_pred(candidate: fx.Node) -> bool:
        from .symbolic_shapes import is_accessor_node

        # Genuinely impure nodes must always survive DCE.
        if candidate.is_impure():
            return True
        # Accessor nodes are safe to eliminate.
        if is_accessor_node(candidate):
            return False
        # Symbolic-scalar computation fed purely by symbolic inputs is pure.
        produces_sym = isinstance(candidate.meta.get("val"), py_sym_types)
        consumes_only_sym = all(
            isinstance(inp.meta.get("val"), py_sym_types)
            for inp in candidate.args
            if isinstance(inp, fx.Node)
        )
        return not (produces_sym and consumes_only_sym)

    traced_graph.eliminate_dead_code(impure_pred)

    from torch._inductor.fx_passes.dedupe_symint_uses import dedupe_symints

    dedupe_symints(traced_graph)

    if isinstance(root, Module):
        module_name = root.__class__.__name__
    else:
        module_name = root.__name__
    return fx._lazy_graph_module._make_graph_module(tracer.root, traced_graph, module_name)
|
|
|
|
|
|
|
|
|
def wrap_key(
    f: Callable[[Unpack[_Ts]], R],
    tensors: tuple[Unpack[_Ts]],
    tracer: _ProxyTracer,
    pre_dispatch: bool,
) -> Callable[_P, R]:
    """Wrap ``f`` so that, when fx invokes it with placeholder proxies, the
    concrete ``tensors`` get associated with those proxies on ``tracer`` and
    the outputs are mapped back from runtime values to their recorded proxies.

    Note: ``pre_dispatch`` is accepted but not used in this body.
    """
    flat_tensors, _tensors_spec = pytree.tree_flatten(tensors)

    @functools.wraps(f)
    def wrapped(*proxies: _P.args, **_unused: _P.kwargs) -> R:
        flat_proxies, _proxies_spec = pytree.tree_flatten(proxies)
        assert len(flat_proxies) == len(flat_tensors)
        # Associate each input tensor with its placeholder proxy while proxy
        # tracing is disabled, so the bookkeeping itself is not recorded.
        with disable_proxy_modes_tracing() as m:
            assert isinstance(m, ProxyTorchDispatchMode)
            track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer)

        def get_tensor_proxy_slot(t: Tensor) -> Union[Tensor, Proxy]:
            # Fall back to the tensor itself if no proxy was recorded for it.
            return get_proxy_slot(t, tracer, t, lambda x: x.proxy)

        # Run the real function on the real tensors, then replace outputs with
        # their proxies: tensors first, then script objects, then symbolic
        # scalars.
        out = f(*tensors)
        out = pytree.tree_map_only(Tensor, get_tensor_proxy_slot, out)
        out = pytree.tree_map_only(
            _AnyScriptObject, lambda t: get_proxy_slot(t, tracer, t, lambda x: x), out
        )

        def get_sym_proxy_slot(t: PySymType) -> Proxy:
            return get_proxy_slot(t, tracer).force()

        out = pytree.tree_map_only(py_sym_types, get_sym_proxy_slot, out)
        return out

    return wrapped
|
|
|
|
|
|
|
|
|
|
|
|
# The ATen op most recently recorded by set_original_aten_op(); None whenever
# no op is being dispatched with preserved fx node metadata.
ORIGINAL_ATEN: Optional[object] = None
|
|
|
|
|
|
|
|
|
@contextmanager
def set_original_aten_op(func: OpOverload) -> Generator[None, None, None]:
    """Record ``func`` as the "original" ATen op for the duration of the
    context, both in the module-global ORIGINAL_ATEN and in
    ``fx_traceback.current_meta["original_aten"]``.

    Only the outermost call records anything: nested calls observe
    ORIGINAL_ATEN already set and become no-ops.  Nothing is recorded unless
    fx_traceback is currently preserving node metadata.
    """
    global ORIGINAL_ATEN
    if ORIGINAL_ATEN is None and fx_traceback.has_preserved_node_meta():
        ORIGINAL_ATEN = func
        fx_traceback.current_meta["original_aten"] = func
        try:
            yield
        finally:
            # Clear both records on exit so subsequent ops start fresh.
            ORIGINAL_ATEN = None
            fx_traceback.current_meta["original_aten"] = None
    else:
        yield
|
|
|
|
|
|
|
|
|
class TorchFunctionMetadataMode(TorchFunctionMode):
    """Torch-function mode that records, on its tracer, which torch-level
    function is currently executing (``torch_fn_metadata``) and how many times
    each such function has been seen (``torch_fn_counts``), then delegates the
    call unchanged.
    """

    def __init__(self, tracer: _ProxyTracer) -> None:
        self.tracer = tracer

    def __torch_function__(
        self,
        func: OpOverload,
        types: tuple[torch._C._TensorMeta, ...],
        args: tuple[object, ...] = (),
        kwargs: Optional[dict[str, object]] = None,
    ) -> object:
        if kwargs is None:
            kwargs = {}
        # Stash the currently-executing torch function and bump its call
        # count; downstream tracers read these to populate node.meta.
        bookkeeping = self.tracer
        bookkeeping.torch_fn_metadata = func
        counts = bookkeeping.torch_fn_counts
        counts[func] = counts.get(func, 0) + 1
        return func(*args, **kwargs)
|
|
|
|
|
|
|
|
|
# Context manager that temporarily removes an active TorchFunctionMetadataMode
# from the torch-function mode stack (see _make_temp_remove_mode_context_manager).
_temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager(
    TorchFunctionMetadataMode
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PreDispatchTorchFunctionMode(TorchFunctionMode):
    """Torch-function mode used during pre-dispatch tracing to record
    side-effectful operations that would otherwise be invisible in the graph
    (grad-mode toggles, autocast enter/exit) as call_function nodes.
    """

    def __init__(self, tracer: _ProxyTracer) -> None:
        self.tracer = tracer
        # Stack of _enter_autocast nodes awaiting their matching
        # _exit_autocast; exits pop the most recent enter (LIFO pairing).
        self.enter_autocast_nodes: list[torch.fx.Node] = []

    def __torch_function__(
        self,
        func: Union[OpOverload, Callable],
        types: tuple[torch._C._TensorMeta, ...],
        args: tuple[object, ...] = (),
        kwargs: Optional[dict[str, object]] = None,
    ) -> object:
        kwargs = kwargs or {}
        if func in _side_effectful_need_to_be_preserved_pre_dispatch:
            # _exit_autocast is rewritten to take its matching enter node as
            # its sole argument, linking the enter/exit pair in the graph.
            if func == torch.amp.autocast_mode._exit_autocast:
                enter_node = self.enter_autocast_nodes.pop()
                args = (enter_node,)
            node = self.tracer.create_node("call_function", func, args, {})
            if func == torch.amp.autocast_mode._enter_autocast:
                self.enter_autocast_nodes.append(node)
            # These calls produce no meaningful value; mark that explicitly.
            if func in [
                torch._C._set_grad_enabled,
                torch.amp.autocast_mode._enter_autocast,
                torch.amp.autocast_mode._exit_autocast,
            ]:
                node.meta["val"] = None
            return node

        # Everything else passes through untouched.
        return func(*args, **kwargs)
|
|
|
|
|
|
|
|
|
# Context manager that temporarily removes an active
# PreDispatchTorchFunctionMode from the torch-function mode stack.
_temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager(
    PreDispatchTorchFunctionMode
)
|
|
|
|
|
|
|
|
|
class ProxyTorchDispatchMode(TorchDispatchMode):
    """Dispatch mode that intercepts ATen calls and records them as proxy
    graph nodes on ``tracer`` (via ``proxy_call``), and records symbolic
    scalar arithmetic via ``__sym_dispatch__``.
    """

    @property
    def enable_tracing(self) -> bool:
        # Tracing is unconditionally enabled for this mode.
        return True

    def __init__(
        self,
        tracer: _ProxyTracer,
        tracing_mode: str,
        pre_dispatch: bool = False,
        _allow_fake_constant: bool = False,
        _error_on_data_dependent_ops: bool = True,
    ) -> None:
        # Under pre_dispatch, hook the PreDispatch dispatch key instead of
        # the default mode slot.
        dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None
        super().__init__(dk)
        self.tracer = tracer
        self.tracing_mode = tracing_mode  # "real" | "fake" | "symbolic"
        self.pre_dispatch = pre_dispatch
        self._allow_fake_constant = _allow_fake_constant
        self._error_on_data_dependent_ops = _error_on_data_dependent_ops
        # Registered as the PROXY infra mode (see is_infra_mode / __enter__).
        self._mode_key = torch._C._TorchDispatchModeKey.PROXY
        # On each __enter__, any previously-installed proxy mode is unset and
        # pushed here; __exit__ pops and reinstates it.
        self.enter_stack: list[Optional[ProxyTorchDispatchMode]] = []
        self.decomp_layers: int = 0
        from torch._inductor import config

        self.emulate_precision_casts: bool = config.emulate_precision_casts

    @count
    def __torch_dispatch__(
        self,
        func: OpOverload,
        types: tuple[torch._C._TensorMeta, ...],
        args: tuple[object, ...] = (),
        kwargs: Optional[dict[str, object]] = None,
    ) -> object:
        with set_original_aten_op(func):
            kwargs = kwargs or {}

            # prim.device is a metadata query; execute it without recording a
            # graph node.
            if func in (prim.device.default,):
                return func(*args, **kwargs)

            return proxy_call(self, func, self.pre_dispatch, args, kwargs)

    def __enter__(self) -> Self:
        # Infra modes are singletons: displace any existing proxy mode and
        # remember it so __exit__ can restore it.
        maybe_prev_proxy_mode = _unset_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
        self.enter_stack.append(maybe_prev_proxy_mode)
        return super().__enter__()

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[types.TracebackType],
    ) -> Optional[bool]:
        b = super().__exit__(exc_type, exc_value, traceback)

        # Reinstate whichever proxy mode this one displaced, if any.
        mb_previous_proxy_mode = self.enter_stack.pop()
        if mb_previous_proxy_mode is not None:
            _push_mode(mb_previous_proxy_mode)

        return b

    @classmethod
    def is_infra_mode(cls) -> bool:
        return True

    def _compute_proxy(
        self, func: OpOverload, args: tuple[object, ...], out: PySymType
    ) -> Proxy:
        """Materialize a call_function node for a symbolic op, replacing any
        symbolic args by their proxy nodes, and attach meta for ``out``."""
        n_args: tuple[object, ...]
        if len(args) == 1 and isinstance(args[0], (list, tuple)):
            # Single list/tuple argument: map over its elements, keeping the
            # one-element outer tuple shape.
            n_args = (
                tuple(
                    (
                        get_proxy_slot(a, self.tracer).force().node
                        if isinstance(a, py_sym_types)
                        else a
                    )
                    for a in args[0]
                ),
            )
        else:
            n_args = tuple(
                (
                    get_proxy_slot(a, self.tracer).force().node
                    if isinstance(a, py_sym_types)
                    else a
                )
                for a in args
            )

        n_out = self.tracer.create_node("call_function", func, n_args, {})
        p_out = fx.Proxy(n_out, self.tracer)
        set_meta(p_out, out)
        return p_out

    def __sym_dispatch__(
        self,
        func: OpOverload,
        types: tuple[torch._C._TensorMeta, ...],
        args: tuple[object, ...],
        kwargs: dict[str, object],
    ) -> object:
        # Peephole: multiplication by literal 1 needs no node at all.
        if func == operator.mul:
            if isinstance(args[1], int) and args[1] == 1:
                return args[0]
            elif isinstance(args[0], int) and args[0] == 1:
                return args[1]

        assert not kwargs
        out = func(*args, **kwargs)

        # Record the proxy lazily (thunkified): the node is only created if
        # something actually needs this symbolic value's proxy.
        if isinstance(out, py_sym_types):
            p_out_thunk = thunkify(
                self.tracer, self._compute_proxy, func=func, args=args, out=out
            )
            set_proxy_slot(out, self.tracer, p_out_thunk)

        return out
|
|
|
|
|
|
|
|
|
class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer):
    """GraphAppendingTracer extended with the value->proxy tracking tables
    that the proxy-tensor machinery expects on any tracer it works with.
    """

    # Live-value -> proxy maps.  The tensor/symnode/script-object tables hold
    # their keys weakly so tracking never extends value lifetimes.
    script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy]
    symnode_tracker: MutableMapping[PySymType, _PySymProxyType]
    tensor_tracker: MutableMapping[Tensor, _ProxyTensor]
    sympy_expr_tracker: dict[sympy.Symbol, object]
    # Last torch-level function observed and per-function call counts
    # (populated via TorchFunctionMetadataMode).
    torch_fn_metadata: Optional[OpOverload]
    torch_fn_counts: dict[OpOverload, int]
    enable_thunkify: bool = False

    def __init__(self, graph: fx.graph.Graph) -> None:
        super().__init__(graph)
        self.symnode_tracker = weakref.WeakKeyDictionary()
        self.tensor_tracker = WeakTensorKeyDictionary()
        self.sympy_expr_tracker = {}
        # Script objects are keyed by identity via _WeakHashRef.
        self.script_object_tracker = WeakIdKeyDictionary(
            dict=None, ref_type=_WeakHashRef
        )
        self.torch_fn_metadata = None
        self.torch_fn_counts = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DecompositionInterpreter(fx.Interpreter):
    """Interpreter that re-executes ``module`` under a proxy dispatch mode so
    its (possibly decomposed) ops are re-recorded into ``new_graph``.
    """

    def __init__(
        self,
        module: fx.GraphModule,
        new_graph: fx.Graph,
        decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
        **kwargs: object,
    ) -> None:
        super().__init__(module, **kwargs)
        self.new_graph = new_graph
        self.tracer = _GraphAppendingTracerEx(self.new_graph)
        self.decomposition_table = decomposition_table or {}
        self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")

    def placeholder(
        self,
        target: str,
        args: tuple[object, ...],
        kwargs: dict[str, object],
    ) -> object:
        """Mirror the placeholder into new_graph and track its proxy."""
        out = super().placeholder(target, args, kwargs)
        proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer)
        track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
        # TODO handle case where the first character of target is '*'
        return out

    def get_attr(
        self,
        target: str,
        args: tuple[object, ...],
        kwargs: dict[str, object],
    ) -> object:
        """Mirror the get_attr into new_graph and track its proxy."""
        out = super().get_attr(target, args, kwargs)
        proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer)
        track_tensor_tree(out, proxy, constant=None, tracer=self.tracer)
        return out

    # call_function, call_method and call_module are handled automatically by
    # the interpreter running under self.mode.

    def output(
        self,
        target: str,
        args: tuple[object, ...],
        kwargs: dict[str, object],
    ) -> object:
        """Emit new_graph's output node, replacing tracked tensors with their
        proxy nodes (untracked values pass through unchanged)."""
        out = super().output(target, args, kwargs)

        def get_proxy_node(x: _ProxyTensor) -> fx.node.Node:
            return x.proxy.node

        def unwrap(e: Tensor) -> Union[Tensor, fx.Node]:
            # Default to the tensor itself when no proxy was recorded.
            return get_proxy_slot(e, self.tracer, e, get_proxy_node)

        self.new_graph.output(pytree.tree_map(unwrap, out))
        return out

    def run(self, *args: object, **kwargs: object) -> object:
        """Run the module with decompositions and proxy recording active."""
        with decompose(self.decomposition_table), self.mode:
            return super().run(*args, **kwargs)
|
|
|
|
|
|
|
|
|
def wrapper_and_args_for_make_fx(
    func: Callable[..., R], args: tuple[object, ...], kwargs: dict[str, object]
) -> tuple[Callable[[list[object]], R], list[object]]:
    """Adapt ``func(*args, **kwargs)`` into a callable over one flat list.

    make_fx only accepts flat positional arguments, so the ``(args, kwargs)``
    pytree is flattened once up front; the returned wrapper unflattens a flat
    argument list back into the original structure and forwards the call.

    Returns:
        A ``(wrapper, flat_leaves)`` pair, where ``flat_leaves`` are the
        flattened leaves of ``(args, kwargs)``.
    """
    leaves, treespec = pytree.tree_flatten((args, kwargs))

    def call_with_flat_args(flat_args: list[object]) -> R:
        # Rebuild the original (args, kwargs) structure before delegating.
        rebuilt_args, rebuilt_kwargs = pytree.tree_unflatten(flat_args, treespec)
        return func(*rebuilt_args, **rebuilt_kwargs)

    return call_with_flat_args, leaves
|
|
|
|
|
|
|
|
|
@contextmanager
def disable_autocast_cache() -> Generator[None, None, None]:
    """Context manager that switches the autocast weight cache off and
    restores the previous enablement state on exit, even on exception."""
    previous_state = torch.is_autocast_cache_enabled()
    torch.set_autocast_cache_enabled(False)
    try:
        yield
    finally:
        # Put back whatever the caller had configured.
        torch.set_autocast_cache_enabled(previous_state)
|
|
|
|
|
|
|
|
|
class _ModuleNotInstalledAsSubmoduleError(NameError):
    """Raised by _ModuleStackTracer.path_of_module when a module cannot be
    resolved as a (transitive) submodule of the tracer's root."""

    pass
|
|
|
|
|
|
|
|
|
|
|
|
class _AttrProxy:
    """Base marker class for the per-tracer AttrProxy types built inside
    _ModuleStackTracer.__init__; primarily used for isinstance checks."""

    def reset_proxy_mapping(self, base: Module, path: str) -> None:
        # No-op in the base class; invoked by _ModuleStackTracer.getattr to
        # re-point an already-cached proxy at a new (base, path).
        pass
|
|
|
|
|
|
|
|
|
class _ModuleStackTracer(PythonKeyTracer):
    r"""Customized version of PythonKeyTracer that retains module stack
    information in node.meta["nn_module_stack"].

    FX symbolic trace actually does this already, but it relies on `self.root`
    being the actual module being traced. Since make_fx traces a lambda of our
    creation, things don't work properly.

    So for this version we hold onto a reference to the original module
    (scope_root) and use that to match the path. Also when we see,
            A
           / \
          B   C
           \ /
            D
    we want to record the path as A.B.D by recording only one path.
    See Note [Preserving the nn module stack metadata during export non-strict mode] # noqa: W605
    """

    def __init__(self, scope_root: GraphModule) -> None:
        super().__init__()
        self.stack_trace = True
        self.scope_root = scope_root
        self.enable_attr_proxy = False
        # Maps each submodule to one canonical path; shared submodules
        # (reachable by multiple paths) flip on AttrProxy tracking.
        self.submodule_paths = {}
        for name, m in self.scope_root.named_modules(remove_duplicate=False):
            if m in self.submodule_paths:
                log.info(
                    "Shared module found between %s and %s, AttrProxy is enabled.",
                    self.submodule_paths[m],
                    name,
                )
                self.enable_attr_proxy = True
            else:
                self.submodule_paths[m] = name

        # Weak maps so proxies/modules are not kept alive by the tracer:
        # proxy -> access path, module -> its cached proxy, proxy -> module.
        self.proxy_paths: WeakKeyDictionary[_AttrProxy, str] = WeakKeyDictionary()
        self.attr_proxy_map: WeakKeyDictionary[Module, _AttrProxy] = WeakKeyDictionary()
        self.proxy_modules: WeakKeyDictionary[_AttrProxy, Module] = WeakKeyDictionary()
        self.counter = 0

        # All names each module object is registered under in scope_root.
        self.module_id_cache = defaultdict(list)
        for name, mod in self.scope_root.named_modules(remove_duplicate=False):
            self.module_id_cache[id(mod)].append(name)

        # Closure variable so the AttrProxy class below can reach this tracer.
        tracer = self

        class AttrProxy(_AttrProxy):
            """Stand-in for a submodule that remembers the attribute path it
            was reached through, so shared modules get one path per access."""

            def __init__(self, base: Union[Module, _AttrProxy], path: str) -> None:
                # Unwrap nested proxies down to the real module.
                if isinstance(base, _AttrProxy):
                    base = base.get_base()

                assert isinstance(base, Module)

                # Dynamically subclass the real module's class so this proxy
                # passes isinstance checks against it, and share its __dict__
                # so attribute state stays in sync with the real module.
                self.__class__ = type(
                    base.__class__.__name__,
                    (self.__class__, base.__class__),
                    {},
                )
                self.__dict__ = base.__dict__
                self.__class__.__module__ = base.__class__.__module__
                self.__class__.__qualname__ = base.__class__.__qualname__

                tracer.proxy_paths[self] = path
                tracer.proxy_modules[self] = base

            def __getattr__(self, name: str) -> AttrProxy:
                assert isinstance(self, Module)

                # Submodule accesses return a fresh proxy carrying the
                # extended path; non-module attributes pass through.
                attr_val = super().__getattr__(name)
                if not isinstance(attr_val, Module):
                    return attr_val

                return AttrProxy(attr_val, tracer.proxy_paths[self] + "." + name)

            def get_base(self) -> Module:
                return tracer.proxy_modules[self]

            def __getitem__(self, idx: Union[int, slice]) -> AttrProxy:
                # Slicing a Sequential/ModuleList builds a new container, so
                # wrap the result in a proxy keyed by the sliced path.
                if isinstance(idx, slice):
                    if isinstance(self, torch.nn.Sequential):
                        res = torch.nn.Sequential(
                            OrderedDict(list(self._modules.items())[idx])
                        )
                        return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}")
                    elif isinstance(self, torch.nn.ModuleList):
                        res = torch.nn.ModuleList(list(self._modules.values())[idx])
                        return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}")

                return super().__getitem__(idx)

            @property
            def _modules(self) -> dict[str, AttrProxy]:
                # Present child modules as proxies too, so iteration over
                # _modules keeps extending access paths.
                assert "_modules" in self.__dict__
                submodules = self.__dict__["_modules"]
                assert isinstance(submodules, dict)
                return {
                    key: (
                        AttrProxy(value, tracer.proxy_paths[self] + "." + str(key))
                        if value is not None
                        else value
                    )
                    for key, value in submodules.items()
                }

        self.proxy_type = AttrProxy

    def path_of_module(self, mod: Module) -> str:
        """
        Use tracked access path during tracing instead of the default BFS behavior.
        Still use all the possible module paths to verify the result.
        """
        if mod is self.scope_root:
            return ""

        # Proxies know the exact path they were reached through.
        if isinstance(mod, _AttrProxy):
            return self.proxy_paths[mod]

        try:
            return Tracer.path_of_module(self, mod)
        except NameError as e:
            raise _ModuleNotInstalledAsSubmoduleError from e

    def getattr(
        self, attr: str, attr_val: object, parameter_proxy_cache: dict[str, Proxy]
    ) -> object:
        # Only wrap plain modules, and only when shared modules were detected.
        if (
            not isinstance(attr_val, Module)
            or isinstance(attr_val, fx.GraphModule)
            or not self.enable_attr_proxy
        ):
            return super().getattr(attr, attr_val, parameter_proxy_cache)
        if isinstance(attr_val, _AttrProxy):
            return attr_val

        # Cache one proxy per module; re-point it on subsequent accesses.
        if attr_val not in self.attr_proxy_map:
            self.attr_proxy_map[attr_val] = self.proxy_type(attr_val, attr)
        else:
            self.attr_proxy_map[attr_val].reset_proxy_mapping(attr_val, attr)
        return self.attr_proxy_map[attr_val]

    def trace(
        self, root: Union[Module, Callable], concrete_args: Optional[dict[str, object]]
    ) -> fx.Graph:
        res = super().trace(root, concrete_args)

        # Tracing may have installed AttrProxy instances as attributes of the
        # root; swap each one back for the real module it stood in for.
        proxy_module_names_to_be_replaced: list[tuple[str, _AttrProxy]] = []
        for name, module in self.root.named_modules():
            if module in self.proxy_modules:
                proxy_module_names_to_be_replaced.append((name, module))

        def _delete_proxy_attr(obj: Module, target: str) -> bool:
            # Walk the dotted path; bail out (False) if any hop is missing or
            # the final attribute is not actually an AttrProxy.
            atoms = target.split(".")
            path, target_submod = atoms[:-1], atoms[-1]
            assert isinstance(obj, Module)
            mod = obj

            for item in path:
                if not hasattr(mod, item):
                    return False

                mod = getattr(mod, item)

                if not isinstance(mod, (_AttrProxy, Module)):
                    return False

            if not hasattr(mod, target_submod):
                return False

            if not isinstance(getattr(mod, target_submod), _AttrProxy):
                return False

            delattr(mod, target_submod)
            return True

        for proxy_module_name, proxy_module in proxy_module_names_to_be_replaced:
            _delete_proxy_attr(self.root, proxy_module_name)
            actual_module = self.proxy_modules[proxy_module]
            _assign_attr(actual_module, self.root, proxy_module_name)

        return res

    def call_module(
        self,
        m: Module,
        forward: Callable,
        args: tuple[object, ...],
        kwargs: dict[str, object],
    ) -> None:
        """PythonKeyTracer overrides call_module to avoid the scope handling,
        but we actually want it.
        """
        from torch._dynamo import OptimizedModule

        # Compiled/graph modules are treated as opaque: call straight through
        # without recording a module scope.
        if isinstance(m, (OptimizedModule, GraphModule)):
            return forward(*args, **kwargs)

        try:
            return Tracer.call_module(self, m, forward, args, kwargs)
        except _ModuleNotInstalledAsSubmoduleError:
            log.debug(
                "Unable to find the path of the module %s. "
                "This might be because the module was not properly registered "
                "as a submodule, which is not good practice. We will trace "
                "through the module without recording stack information.",
                str(m),
            )
            return forward(*args, **kwargs)

    def is_leaf_module(self, m: Module, module_qualified_name: str) -> bool:
        # Always trace through modules; no leaves.
        return False

    def create_node(self, *args: object, **kwargs: object) -> fx.node.Node:
        """
        Create node and add on metadata.
        Add nn_module_stack here instead of TracerBase,
        since calls to make_fx() might not want to record module stack metadata.
        Add torch_fn by looking at torch_fn_metadata and torch_fn_counts.
        Add stack_trace by filtering out forward() stack frames.
        """
        node = super().create_node(*args, **kwargs)

        if node.op not in ["placeholder", "output"]:
            if "nn_module_stack" not in node.meta:
                node.meta["nn_module_stack"] = self.module_stack

            # Normalize class objects in the stack to fully-qualified strings.
            for key, (fqn, mod_cls) in node.meta["nn_module_stack"].items():
                if isinstance(mod_cls, type):
                    node.meta["nn_module_stack"][key] = (
                        fqn,
                        mod_cls.__module__ + "." + mod_cls.__qualname__,
                    )

        # Record which torch-level function produced this call_function node,
        # tagged with its per-function occurrence count.
        if (
            node.op == "call_function"
            and self.torch_fn_metadata is not None
            and "torch_fn" not in node.meta
        ):
            node.meta["torch_fn"] = (
                f"{self.torch_fn_metadata.__name__}_{self.torch_fn_counts[self.torch_fn_metadata]}",
                f"{self.torch_fn_metadata.__class__.__name__}.{self.torch_fn_metadata.__name__}",
            )

        return node
|
|
|
|
|
|
|
|
|
class _MakefxTracer:
|
|
|
def __init__(
|
|
|
self,
|
|
|
decomposition_table: Optional[Mapping[OpOverload, Callable]],
|
|
|
tracing_mode: str,
|
|
|
_allow_non_fake_inputs: bool,
|
|
|
pre_dispatch: bool,
|
|
|
record_module_stack: bool,
|
|
|
_allow_fake_constant: bool,
|
|
|
_error_on_data_dependent_ops: bool,
|
|
|
stack_trace: bool = False,
|
|
|
) -> None:
|
|
|
|
|
|
|
|
|
self.decomposition_table: dict[OpOverload, Callable] = dict(
|
|
|
decomposition_table or {}
|
|
|
)
|
|
|
self.decomposition_table.setdefault(
|
|
|
torch.ops.aten.sym_numel.default, torch._decomp.decompositions.sym_numel
|
|
|
)
|
|
|
self.tracing_mode: str = tracing_mode
|
|
|
self._allow_non_fake_inputs: bool = _allow_non_fake_inputs
|
|
|
self.pre_dispatch: bool = pre_dispatch
|
|
|
self.record_module_stack: bool = record_module_stack
|
|
|
self._allow_fake_constant: bool = _allow_fake_constant
|
|
|
self._error_on_data_dependent_ops: bool = _error_on_data_dependent_ops
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.fake_tensor_mode: Optional[FakeTensorMode] = None
|
|
|
self.proxy_mode: Union[nullcontext, ProxyTorchDispatchMode] = nullcontext()
|
|
|
self.proxy_function_mode: Union[nullcontext, PreDispatchTorchFunctionMode] = (
|
|
|
nullcontext()
|
|
|
)
|
|
|
self.fx_tracer: Optional[PythonKeyTracer] = None
|
|
|
self.python_dispatcher_mode: Union[nullcontext, Any] = nullcontext()
|
|
|
self.torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode] = (
|
|
|
nullcontext()
|
|
|
)
|
|
|
self.stack_trace = stack_trace
|
|
|
|
|
|
def _checkpoint_modes(self) -> list[Any]:
|
|
|
return [
|
|
|
self.fake_tensor_mode,
|
|
|
self.proxy_mode,
|
|
|
self.proxy_function_mode,
|
|
|
self.fx_tracer,
|
|
|
self.python_dispatcher_mode,
|
|
|
self.torch_fn_metadata_mode,
|
|
|
]
|
|
|
|
|
|
def _restore_modes(
|
|
|
self,
|
|
|
prev_fake_tensor_mode: Optional[FakeTensorMode],
|
|
|
prev_proxy_mode: Union[nullcontext, ProxyTorchDispatchMode],
|
|
|
prev_proxy_function_mode: Union[nullcontext, PreDispatchTorchFunctionMode],
|
|
|
prev_fx_tracer: Optional[PythonKeyTracer],
|
|
|
prev_python_dispatcher_mode: Union[nullcontext, Any],
|
|
|
prev_torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode],
|
|
|
) -> None:
|
|
|
self.fake_tensor_mode = prev_fake_tensor_mode
|
|
|
self.proxy_mode = prev_proxy_mode
|
|
|
self.proxy_function_mode = prev_proxy_function_mode
|
|
|
self.fx_tracer = prev_fx_tracer
|
|
|
self.python_dispatcher_mode = prev_python_dispatcher_mode
|
|
|
self.torch_fn_metadata_mode = prev_torch_fn_metadata_mode
|
|
|
|
|
|
@contextmanager
|
|
|
def _init_modes_from_inputs(
|
|
|
self, f: Callable, args: tuple[object, ...]
|
|
|
) -> Generator[None, None, None]:
|
|
|
prev_modes = self._checkpoint_modes()
|
|
|
try:
|
|
|
|
|
|
from .symbolic_shapes import ShapeEnv
|
|
|
|
|
|
if hasattr(f, "_orig_mod") and self.record_module_stack:
|
|
|
scope_root = f._orig_mod
|
|
|
|
|
|
self.fx_tracer = _ModuleStackTracer(scope_root)
|
|
|
else:
|
|
|
self.fx_tracer = PythonKeyTracer()
|
|
|
self.fx_tracer.stack_trace = self.stack_trace
|
|
|
|
|
|
if self.tracing_mode == "fake":
|
|
|
import torch._dynamo
|
|
|
|
|
|
fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args)
|
|
|
if fake_tensor_mode is None:
|
|
|
import torch._functorch.config as _config
|
|
|
|
|
|
with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
|
|
|
fake_tensor_mode = FakeTensorMode(
|
|
|
allow_fallback_kernels=True,
|
|
|
allow_non_fake_inputs=self._allow_non_fake_inputs,
|
|
|
shape_env=ShapeEnv(),
|
|
|
static_shapes=True,
|
|
|
)
|
|
|
self.fake_tensor_mode = fake_tensor_mode
|
|
|
elif self.tracing_mode == "symbolic":
|
|
|
import torch._dynamo
|
|
|
|
|
|
fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args)
|
|
|
if fake_tensor_mode is None:
|
|
|
shape_env = ShapeEnv()
|
|
|
import torch._functorch.config as _config
|
|
|
|
|
|
with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
|
|
|
fake_tensor_mode = FakeTensorMode(
|
|
|
allow_fallback_kernels=False,
|
|
|
allow_non_fake_inputs=self._allow_non_fake_inputs,
|
|
|
shape_env=shape_env,
|
|
|
)
|
|
|
assert fake_tensor_mode.shape_env is not None, (
|
|
|
"shape_env should be set if tracing with 'symbolic'"
|
|
|
)
|
|
|
self.fake_tensor_mode = fake_tensor_mode
|
|
|
else:
|
|
|
if not self.tracing_mode == "real":
|
|
|
raise AssertionError(
|
|
|
f"Unexpected tracing type: {self.tracing_mode}"
|
|
|
)
|
|
|
|
|
|
self._construct_modes_with_fx_tracer(self.fx_tracer)
|
|
|
yield
|
|
|
finally:
|
|
|
self._restore_modes(*prev_modes)
|
|
|
|
|
|
def _construct_modes_with_fx_tracer(self, fx_tracer: _ProxyTracer) -> None:
|
|
|
self.proxy_mode = ProxyTorchDispatchMode(
|
|
|
fx_tracer,
|
|
|
self.tracing_mode,
|
|
|
pre_dispatch=self.pre_dispatch,
|
|
|
_allow_fake_constant=self._allow_fake_constant,
|
|
|
_error_on_data_dependent_ops=self._error_on_data_dependent_ops,
|
|
|
)
|
|
|
|
|
|
if self.pre_dispatch:
|
|
|
self.proxy_function_mode = PreDispatchTorchFunctionMode(fx_tracer)
|
|
|
|
|
|
|
|
|
|
|
|
if self.tracing_mode == "symbolic" or self.pre_dispatch:
|
|
|
self.python_dispatcher_mode = enable_python_dispatcher()
|
|
|
|
|
|
self.torch_fn_metadata_mode = TorchFunctionMetadataMode(fx_tracer)
|
|
|
|
|
|
@contextmanager
|
|
|
def _init_modes_from_parent(
|
|
|
self, parent_tracer: _MakefxTracer
|
|
|
) -> Generator[None, None, None]:
|
|
|
|
|
|
|
|
|
|
|
|
prev_modes = self._checkpoint_modes()
|
|
|
try:
|
|
|
self.fake_tensor_mode = parent_tracer.fake_tensor_mode
|
|
|
|
|
|
def _create_sub_fx_tracer(parent_tracer: _ProxyTracer) -> PythonKeyTracer:
|
|
|
if type(parent_tracer) == PythonKeyTracer:
|
|
|
return PythonKeyTracer()
|
|
|
elif type(parent_tracer) == _ModuleStackTracer:
|
|
|
return _ModuleStackTracer(parent_tracer.scope_root)
|
|
|
else:
|
|
|
raise RuntimeError(
|
|
|
f"Unexpected tracer type: {type(parent_tracer)}."
|
|
|
)
|
|
|
|
|
|
assert parent_tracer.fx_tracer is not None
|
|
|
self.fx_tracer = _create_sub_fx_tracer(parent_tracer.fx_tracer)
|
|
|
self._construct_modes_with_fx_tracer(self.fx_tracer)
|
|
|
yield
|
|
|
finally:
|
|
|
self._restore_modes(*prev_modes)
|
|
|
|
|
|
def _trace_inner(self, f: Callable, *args: object) -> GraphModule:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch._dynamo
|
|
|
|
|
|
phs = pytree.tree_map(lambda _: torch.fx._symbolic_trace.PH, args)
|
|
|
|
|
|
def _wrap_fake(args: T) -> T:
|
|
|
arg_count = 0
|
|
|
|
|
|
def inner_wrap_fake(x: object) -> object:
|
|
|
nonlocal arg_count
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torch._dynamo.source import ConstantSource
|
|
|
|
|
|
assert self.fake_tensor_mode is not None
|
|
|
source = ConstantSource(f"input{arg_count}")
|
|
|
if isinstance(x, Tensor):
|
|
|
arg_count += 1
|
|
|
return self.fake_tensor_mode.from_tensor(x, source=source)
|
|
|
|
|
|
elif type(x) is int and self.tracing_mode == "symbolic":
|
|
|
assert self.fake_tensor_mode.shape_env is not None, (
|
|
|
"shape_env should be set if tracing with 'symbolic'"
|
|
|
)
|
|
|
return self.fake_tensor_mode.shape_env.create_symintnode(
|
|
|
self.fake_tensor_mode.shape_env.create_symbol(
|
|
|
x, source, positive=None
|
|
|
),
|
|
|
hint=x,
|
|
|
source=source,
|
|
|
)
|
|
|
elif isinstance(x, torch.ScriptObject):
|
|
|
return torch._library.fake_class_registry.maybe_to_fake_obj(
|
|
|
self.fake_tensor_mode, x
|
|
|
)
|
|
|
|
|
|
assert not isinstance(x, FakeScriptObject), (
|
|
|
f"ScriptObject {x} has been fakified. Cannot wrap_fake it again."
|
|
|
)
|
|
|
return x
|
|
|
|
|
|
wrap_fn_map = {
|
|
|
"real": lambda x: x,
|
|
|
"fake": inner_wrap_fake,
|
|
|
"symbolic": inner_wrap_fake,
|
|
|
}
|
|
|
return pytree.tree_map(wrap_fn_map[self.tracing_mode], args)
|
|
|
|
|
|
def _wrap_func(f: Callable[_P, R], phs: Sequence[PHBase]) -> Callable[_P, R]:
|
|
|
if (
|
|
|
not hasattr(inspect.unwrap(f), "__code__")
|
|
|
or inspect.unwrap(f).__code__.co_flags & inspect.CO_VARARGS
|
|
|
):
|
|
|
|
|
|
|
|
|
return fake_signature(f, len(phs))
|
|
|
return f
|
|
|
|
|
|
args = _wrap_fake(args)
|
|
|
func = _wrap_func(f, phs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
proxy_mode: ProxyTorchDispatchMode = typing.cast(
|
|
|
ProxyTorchDispatchMode, self.proxy_mode
|
|
|
)
|
|
|
with ExitStack() as stack:
|
|
|
stack.enter_context(decompose(self.decomposition_table))
|
|
|
if self.fake_tensor_mode:
|
|
|
stack.enter_context(self.fake_tensor_mode)
|
|
|
stack.enter_context(self.python_dispatcher_mode)
|
|
|
stack.enter_context(self.proxy_function_mode)
|
|
|
stack.enter_context(self.torch_fn_metadata_mode)
|
|
|
stack.enter_context(proxy_mode)
|
|
|
stack.enter_context(disable_autocast_cache())
|
|
|
stack.enter_context(_set_make_fx_tracer(self))
|
|
|
|
|
|
assert self.fx_tracer is not None
|
|
|
try:
|
|
|
t = dispatch_trace(
|
|
|
wrap_key(func, args, self.fx_tracer, self.pre_dispatch),
|
|
|
tracer=self.fx_tracer,
|
|
|
concrete_args=tuple(phs),
|
|
|
)
|
|
|
except Exception:
|
|
|
trace_structured(
|
|
|
"artifact",
|
|
|
metadata_fn=lambda: {
|
|
|
"name": "make_fx_fail_partial",
|
|
|
"encoding": "string",
|
|
|
},
|
|
|
payload_fn=lambda: self.fx_tracer.graph.python_code(
|
|
|
root_module="self",
|
|
|
verbose=True,
|
|
|
include_stride=True,
|
|
|
include_device=True,
|
|
|
).src,
|
|
|
)
|
|
|
raise
|
|
|
|
|
|
|
|
|
if self.tracing_mode == "symbolic":
|
|
|
assert self.fake_tensor_mode is not None
|
|
|
t.shape_env = self.fake_tensor_mode.shape_env
|
|
|
return t
|
|
|
|
|
|
def trace(self, f: Callable, *args: object) -> fx.GraphModule:
|
|
|
with self._init_modes_from_inputs(f, args):
|
|
|
return self._trace_inner(f, *args)
|
|
|
|
|
|
def trace_subgraph(self, f: Callable, *args: object) -> GraphModule:
    """Trace ``f`` with a child tracer that copies this tracer's
    configuration and initializes its modes from this (parent) tracer."""
    # Build the child from the parent's config. tracing_mode is "real"
    # here — presumably because the parent's modes are inherited via
    # _init_modes_from_parent rather than re-created; confirm if changing.
    # NOTE(review): stack_trace is not forwarded to the child tracer —
    # verify whether subgraph traces should preserve it.
    child = _MakefxTracer(
        self.decomposition_table,
        "real",
        self._allow_non_fake_inputs,
        self.pre_dispatch,
        self.record_module_stack,
        self._allow_fake_constant,
        self._error_on_data_dependent_ops,
    )
    with child._init_modes_from_parent(self):
        return child._trace_inner(f, *args)
|
|
|
|
|
|
|
|
|
# The tracer installed by the innermost active make_fx invocation (see
# _set_make_fx_tracer below); None when no make_fx trace is in progress.
_CURRENT_MAKE_FX_TRACER: Optional[_MakefxTracer] = None
|
|
|
|
|
|
|
|
|
@contextmanager
def _set_make_fx_tracer(tracer: _MakefxTracer) -> Generator[None, None, None]:
    """Install ``tracer`` as the current global make_fx tracer for the
    duration of the context, restoring the previous tracer on exit."""
    global _CURRENT_MAKE_FX_TRACER
    saved = _CURRENT_MAKE_FX_TRACER
    # Plain assignment cannot raise, so it is safe outside the try block;
    # the finally clause still guarantees restoration.
    _CURRENT_MAKE_FX_TRACER = tracer
    try:
        yield
    finally:
        _CURRENT_MAKE_FX_TRACER = saved
|
|
|
|
|
|
|
|
|
def make_fx(
    f: Callable,
    decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
    tracing_mode: str = "real",
    _allow_non_fake_inputs: bool = False,
    *,
    pre_dispatch: bool = False,
    record_module_stack: bool = False,
    _allow_fake_constant: bool = False,
    _error_on_data_dependent_ops: bool = True,
    stack_trace: bool = False,
) -> Callable[..., GraphModule]:
    """Build a tracing wrapper around ``f``.

    The returned callable, when invoked with valid arguments for ``f``,
    records the operations executed and returns an FX GraphModule
    representing them.

    If stack_trace is True, the stack_trace will be preserved on
    node.meta["stack_trace"].
    """
    assert tracing_mode in ("real", "fake", "symbolic")

    from torch._inductor import config

    # stack traces are also forced on when inductor tracing is enabled.
    tracer = _MakefxTracer(
        decomposition_table,
        tracing_mode,
        _allow_non_fake_inputs,
        pre_dispatch,
        record_module_stack,
        _allow_fake_constant,
        _error_on_data_dependent_ops,
        stack_trace=stack_trace or config.trace.enabled,
    )

    @functools.wraps(f)
    def wrapped(*args: object) -> GraphModule:
        return tracer.trace(f, *args)

    return wrapped
|
|
|
|
|
|
|
|
|
def get_torch_dispatch_modes() -> list[TorchDispatchMode]:
    """Return the stack of currently active torch dispatch modes."""
    stack = torch.utils._python_dispatch._get_current_dispatch_mode_stack()
    return stack
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_innermost_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
    """Thin wrapper that delegates to :func:`get_proxy_mode`."""
    mode = get_proxy_mode()
    return mode
|
|
|
|
|
|
|
|
|
def get_proxy_mode() -> Optional[ProxyTorchDispatchMode]:
|
|
|
"""
|
|
|
Current the currently active proxy tracing mode, or None if
|
|
|
we are not currently tracing. This includes pre-dispatch proxy
|
|
|
tracing.
|
|
|
"""
|
|
|
pre_dispatch_mode = torch._ops._get_dispatch_mode_pre_dispatch(
|
|
|
torch._C._TorchDispatchModeKey.PROXY
|
|
|
)
|
|
|
mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY)
|
|
|
assert pre_dispatch_mode is None or mode is None, (
|
|
|
f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}"
|
|
|
)
|
|
|
return pre_dispatch_mode or mode
|
|
|
|
|
|
|
|
|
def handle_sym_dispatch(
    func: Callable[_P, R],
    args: _P.args,
    kwargs: _P.kwargs,
) -> R:
    """
    Dispatch ``func(*args, **kwargs)`` — an operation on
    SymInt/SymFloat/SymBool arguments — through the currently active
    proxy tracing mode.
    """
    mode = get_proxy_mode()
    assert mode
    # Temporarily disable proxy tracing while performing the sym dispatch
    # itself.
    with disable_proxy_modes_tracing():
        # __sym_dispatch__ takes a list of overload types; none are
        # supplied here.
        overload_types: list[type] = []
        return mode.__sym_dispatch__(func, overload_types, args, kwargs)
|
|
|
|
|
|
|
|
|
@contextmanager
|
|
|
def disable_proxy_modes_tracing() -> Generator[ProxyTorchDispatchMode, None, None]:
|
|
|
return _disable_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
|
|
|
|
|
|
|
|
|
def maybe_handle_decomp(
    proxy_mode: ProxyTorchDispatchMode,
    op: OpOverload,
    args: tuple[object, ...],
    kwargs: dict[str, object],
) -> object:
    """Run ``op`` through its entry in ``CURRENT_DECOMPOSITION_TABLE``.

    Returns the decomposition's result, or ``NotImplemented`` when there
    is no registered decomposition or the compiler bisector has disabled
    the decomposition subsystem.
    """
    from torch._inductor.compiler_bisector import CompilerBisector

    if op in CURRENT_DECOMPOSITION_TABLE:
        # Let the compiler bisector selectively disable this decomposition.
        if CompilerBisector.disable_subsystem(
            "aot_eager_decomp_partition", "decomposition", lambda: repr(op)
        ):
            return NotImplemented

        with proxy_mode:
            proxy_mode.decomp_layers += 1
            # Robustness fix: restore the depth counter even when the
            # decomposition raises, so a failing decomp does not leave
            # the mode with a stale decomp_layers count.
            try:
                return CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs)
            finally:
                proxy_mode.decomp_layers -= 1

    return NotImplemented
|
|
|
|
|
|
|
|
|
def get_isolated_graphmodule(
    func: Callable,
    args: tuple[object, ...],
    kwargs: dict[str, object],
    tracing_mode: str = "real",
    decomposition_table: Optional[Mapping[OpOverload, Callable]] = None,
) -> GraphModule:
    """A helper function used to get the GraphModule for the given func.

    It's expected to be used in the ProxyTensor tracing context.
    It detaches the args and kwargs from the current tracer so that the trace of
    the current graph module can be created without any side-effects.
    """
    wrapped, all_args = wrapper_and_args_for_make_fx(func, args, kwargs)

    # Trace with proxy modes disabled so this nested trace does not leak
    # into whatever trace is currently in flight.
    with disable_proxy_modes_tracing():
        fx_fn = make_fx(
            wrapped,
            decomposition_table=decomposition_table,
            tracing_mode=tracing_mode,
        )
        gm = fx_fn(all_args)
    return gm
|
|
|
|
|
|
|
|
|
def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None:
    """A helper function for setting up unbacked_bindings on the destination FX graph."""
    from .symbolic_shapes import compute_unbacked_bindings

    fake_mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
    # Nothing to record unless a fake mode with a shape env is active.
    if not fake_mode or not fake_mode.shape_env:
        return
    bindings = compute_unbacked_bindings(fake_mode.shape_env, out)
    if not bindings:
        return
    assert isinstance(out_proxy, Proxy), out_proxy
    out_proxy.node.meta["unbacked_bindings"] = bindings
|
|
|
|