|
|
from __future__ import annotations
|
|
|
|
|
|
import contextlib
|
|
|
import dataclasses
|
|
|
import functools
|
|
|
import threading
|
|
|
import typing
|
|
|
import warnings
|
|
|
import weakref
|
|
|
from abc import abstractmethod
|
|
|
from contextlib import AbstractContextManager, contextmanager
|
|
|
from dataclasses import dataclass
|
|
|
from typing import (
|
|
|
Any,
|
|
|
Callable,
|
|
|
ClassVar,
|
|
|
Generic,
|
|
|
NewType,
|
|
|
Optional,
|
|
|
Protocol,
|
|
|
TYPE_CHECKING,
|
|
|
TypeVar,
|
|
|
Union,
|
|
|
)
|
|
|
from typing_extensions import override, TypedDict, TypeGuard, TypeIs, Unpack
|
|
|
|
|
|
import torch
|
|
|
from torch._C._autograd import CreationMeta
|
|
|
from torch._C._functorch import (
|
|
|
_add_batch_dim,
|
|
|
_unwrap_functional_tensor,
|
|
|
_wrap_functional_tensor,
|
|
|
get_unwrapped,
|
|
|
is_batchedtensor,
|
|
|
is_functorch_wrapped_tensor,
|
|
|
is_gradtrackingtensor,
|
|
|
is_legacy_batchedtensor,
|
|
|
maybe_get_bdim,
|
|
|
maybe_get_level,
|
|
|
peek_interpreter_stack,
|
|
|
)
|
|
|
from torch._dispatch.python import enable_python_dispatcher
|
|
|
from torch._logging import trace_structured
|
|
|
from torch.utils._mode_utils import no_dispatch
|
|
|
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
|
|
from torch.utils.weak import WeakIdKeyDictionary
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
from collections.abc import Generator
|
|
|
|
|
|
from torch._C._functorch import CInterpreter
|
|
|
from torch._guards import Source
|
|
|
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv, SymbolicContext
|
|
|
|
|
|
|
|
|
def _is_fake_tensor(t: object) -> TypeIs[FakeTensor]:
|
|
|
from torch._subclasses.fake_tensor import FakeTensor
|
|
|
|
|
|
return isinstance(t, FakeTensor)
|
|
|
|
|
|
|
|
|
# Alias used for lists of dimensions in descriptors (plain builtin list).
DimList = list

# "Tensor-like": either a descriptor or a real tensor, for helpers that
# accept both (e.g. safe_grad).
_TensorLikeT = TypeVar("_TensorLikeT", "MetaTensorDesc", torch.Tensor)

_T = TypeVar("_T")

# Concrete tensor subtype a MetaConverter produces (e.g. FakeTensor).
_TensorT = TypeVar("_TensorT", bound=torch.Tensor)

# Covariant variant for use in Protocol return positions.
_TensorT_cov = TypeVar("_TensorT_cov", bound=torch.Tensor, covariant=True)
|
|
|
|
|
|
|
|
|
def safe_is_leaf(t: Union[MetaTensorDesc, torch.Tensor]) -> bool:
    """Query ``t.is_leaf`` without raising.

    Some tensors raise RuntimeError when ``.is_leaf`` is accessed; those are
    reported as non-leaf.
    """
    try:
        leaf = t.is_leaf
    except RuntimeError:
        return False
    return leaf
|
|
|
|
|
|
|
|
|
def safe_grad(t: _TensorLikeT) -> Optional[_TensorLikeT]:
    """Read ``t.grad`` while silencing the non-leaf ``.grad`` UserWarning."""
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
        grad = t.grad
    return grad
|
|
|
|
|
|
|
|
|
def _expect_safe_grad(t: _TensorLikeT) -> _TensorLikeT:
    """Like :func:`safe_grad`, but the caller guarantees a grad exists."""
    result = safe_grad(t)
    assert result is not None
    return result
|
|
|
|
|
|
|
|
|
def assert_eq(a: _T, b: _T) -> None:
    """Assert ``a == b``, reporting both values on mismatch.

    Kept as a bare ``assert`` (rather than an explicit raise) so it is
    stripped under ``-O`` like the rest of this module's internal checks.
    """
    assert a == b, f"{a} != {b}"
|
|
|
|
|
|
|
|
|
# Module-wide thread-local state.
tls = threading.local()

# When True, fake-tensor property inference treats every tensor as NOT being
# an inference tensor (consumed via getattr(tls, "disable_inference_mode",
# False) in assert_metadata_eq and describe_tensor).
tls.disable_inference_mode = False
|
|
|
|
|
|
|
|
|
@contextmanager
def disable_inference_mode_for_fake_prop() -> Generator[None, None, None]:
    """Context manager: set ``tls.disable_inference_mode`` for the duration.

    The previous value is restored on exit, even if the body raises.
    """
    saved = getattr(tls, "disable_inference_mode", False)
    tls.disable_inference_mode = True
    try:
        yield
    finally:
        tls.disable_inference_mode = saved
|
|
|
|
|
|
|
|
|
def assert_metadata_eq(
    # NOTE: this parameter intentionally shadows the module-level assert_eq;
    # callers supply the comparison primitive (e.g. to collect mismatches
    # instead of raising).
    assert_eq: Callable[[object, object], None],
    m1: Union[MetaTensorDesc, torch.Tensor],
    m2: torch.Tensor,
    *,
    skip_symbolic: bool = False,
    skip_leaf: bool = False,
) -> None:
    """Check that descriptor/tensor ``m1`` has the same metadata as ``m2``.

    ``m1`` may be a real tensor, in which case it is first described into a
    MetaTensorDesc; ``m2`` is always a real tensor. Recurses into grads and
    view bases. ``skip_symbolic`` skips shape/stride/offset comparisons;
    ``skip_leaf`` skips the is_leaf comparison.
    """
    m1 = (
        MetaTensorDescriber().describe_tensor(m1)
        if isinstance(m1, torch.Tensor)
        else m1
    )

    def go(m1: MetaTensorDesc, m2: torch.Tensor) -> None:
        assert_eq(m1.dtype, m2.dtype)
        if not skip_symbolic:
            assert_eq(m1.shape, m2.shape)
        assert_eq(m1.requires_grad, m2.requires_grad)
        if not skip_leaf:
            assert_eq(m1.is_leaf, m2.is_leaf)
        assert_eq(m1.is_sparse, m2.is_sparse)
        # When inference-mode handling is disabled (see
        # disable_inference_mode_for_fake_prop), descriptors always record
        # is_inference=False, so compare against that constant instead.
        if not getattr(tls, "disable_inference_mode", False):
            assert_eq(m1.is_inference, m2.is_inference())
        else:
            assert_eq(m1.is_inference, False)
        assert_eq(m1.is_conj, m2.is_conj())
        assert_eq(m1.is_neg, m2.is_neg())
        assert_eq(m1.grad is not None, safe_grad(m2) is not None)
        if m1.grad is not None:
            # Recurse into gradients (descriptor grad vs real grad).
            go(m1.grad, _expect_safe_grad(m2))
        if m1.is_sparse:
            assert_eq(m1.layout, m2.layout)
            assert_eq(m1.dense_dim, m2.dense_dim())
            assert_eq(m1.sparse_dim, m2.sparse_dim())
            assert_eq(m1.is_coalesced, m2.is_coalesced())
        elif is_sparse_compressed(m1):
            # Compressed sparse layouts have no is_coalesced concept.
            assert_eq(m1.layout, m2.layout)
            assert_eq(m1.dense_dim, m2.dense_dim())
            assert_eq(m1.sparse_dim, m2.sparse_dim())
        else:
            # Dense path: strided metadata plus view structure.
            if not skip_symbolic:
                assert_eq(m1.stride, m2.stride())
                assert_eq(m1.storage_offset, m2.storage_offset())
            assert_eq(m1.is_view, m2._is_view())
            if m1.is_view:
                assert m1.base is not None
                assert m2._base is not None
                go(m1.base, m2._base)

    return go(m1, m2)
|
|
|
|
|
|
|
|
|
|
|
|
def is_sparse_coo(t: object) -> TypeGuard[torch.Tensor]:
    """Return True iff *t* is a tensor with the sparse COO layout."""
    if not isinstance(t, torch.Tensor):
        return False
    return t.layout is torch.sparse_coo
|
|
|
|
|
|
|
|
|
def is_sparse_compressed_layout(layout: torch.layout) -> bool:
    """Return True iff *layout* is a compressed sparse layout (CSR/CSC/BSR/BSC)."""
    compressed_layouts = (
        torch.sparse_csr,
        torch.sparse_csc,
        torch.sparse_bsr,
        torch.sparse_bsc,
    )
    return layout in compressed_layouts
|
|
|
|
|
|
|
|
|
|
|
|
def is_sparse_compressed(t: object) -> TypeGuard[torch.Tensor]:
    """Return True iff *t* is a tensor with a compressed sparse layout."""
    if not isinstance(t, torch.Tensor):
        return False
    return is_sparse_compressed_layout(t.layout)
|
|
|
|
|
|
|
|
|
|
|
|
def is_sparse_any(t: object) -> TypeGuard[torch.Tensor]:
    """Return True iff *t* is sparse in any layout (COO or compressed)."""
    if is_sparse_coo(t):
        return True
    return is_sparse_compressed(t)
|
|
|
|
|
|
|
|
|
def _checked_cast(ty: type[_T], obj: object) -> _T:
|
|
|
assert isinstance(obj, ty), f"expected {ty} but got {type(obj)}"
|
|
|
return obj
|
|
|
|
|
|
|
|
|
def _get_real_storage(base: torch.UntypedStorage) -> torch.UntypedStorage:
|
|
|
return base.real_storage
|
|
|
|
|
|
|
|
|
def _set_real_storage(
|
|
|
base: torch.UntypedStorage, real_storage: torch.UntypedStorage
|
|
|
) -> None:
|
|
|
base.real_storage = real_storage
|
|
|
|
|
|
|
|
|
|
|
|
# Stable integer handles identifying storages/tensors within one describer.
MetaStorageId = NewType("MetaStorageId", int)
MetaTensorId = NewType("MetaTensorId", int)

# Monotonic counter handing out a unique id per MetaTensorDescriber instance.
_DescriberId = NewType("_DescriberId", int)
DESCRIBER_NEXT_ID = _DescriberId(0)
|
|
|
|
|
|
|
|
|
class MetaTensorDescriber:
    """
    Given a Tensor/Storage, generate a MetaTensorDesc/MetaStorageDesc
    for it, which is enough information to reconstruct a meta tensor/fake tensor
    corresponding to a Tensor as faithfully as possible.

    This is a stateful conversion object because we keep track of the IDs
    of the tensors/storages passed to us, so we can consistently give
    the same ID when we see the same tensor/storage.
    """

    def __init__(self, *, copy_data: bool = False) -> None:
        # Each describer gets a globally unique id so descriptors from
        # different describers can be told apart in trace logs.
        global DESCRIBER_NEXT_ID
        self.id = DESCRIBER_NEXT_ID
        DESCRIBER_NEXT_ID = _DescriberId(DESCRIBER_NEXT_ID + 1)
        self.next_tensor_id: MetaTensorId = MetaTensorId(0)
        self.next_storage_id: MetaStorageId = MetaStorageId(0)
        # Weak-keyed: entries disappear when the described tensor/storage dies.
        self.lookup_tensor = WeakIdKeyDictionary()
        self.lookup_storage = WeakIdKeyDictionary()
        # When True, descriptors also hold a reference to the real data.
        self.copy_data = copy_data
        # Ids already emitted via trace_structured (emit each only once).
        self.traced_tensors: set[int] = set()
        self.traced_storages: set[int] = set()

    def get_tensor_id(self, t: torch.Tensor) -> MetaTensorId:
        """Return a stable id for *t*, allocating a fresh one on first sight."""
        if t not in self.lookup_tensor:
            self.lookup_tensor[t] = self.next_tensor_id
            self.next_tensor_id = MetaTensorId(self.next_tensor_id + 1)
        return self.lookup_tensor[t]

    def get_storage_id(self, s: torch.UntypedStorage) -> MetaStorageId:
        """Return a stable id for *s*, allocating a fresh one on first sight."""
        if s not in self.lookup_storage:
            self.lookup_storage[s] = self.next_storage_id
            self.next_storage_id = MetaStorageId(self.next_storage_id + 1)
        return self.lookup_storage[s]

    def describe_storage(
        self, s: torch.UntypedStorage, *, trace: bool = False
    ) -> MetaStorageDesc:
        """Describe an untyped storage; optionally emit a structured trace event."""
        r = MetaStorageDesc(
            id=self.get_storage_id(s),
            size=s.size(),
            data=s if self.copy_data else None,
        )
        if trace and r.id not in self.traced_storages:
            trace_structured(
                "describe_storage",
                metadata_fn=lambda: r.as_json(self.id),
            )
            self.traced_storages.add(r.id)
        return r

    def describe_tensor(
        self, t: torch.Tensor, *, recurse: bool = True, trace: bool = False
    ) -> MetaTensorDesc:
        """Describe *t* into a MetaTensorDesc.

        ``recurse=False`` limits recursion into constituent tensors (used for
        sparse index/value tensors); ``trace=True`` emits structured trace
        events for each new id.
        """
        # Snapshot all the cheap predicates up front; several of the checks
        # below depend on combinations of them.
        is_leaf = safe_is_leaf(t)
        is_view = t._is_view()
        is_sparse = t.is_sparse
        layout = t.layout
        is_nested = t.is_nested
        is_traceable_wrapper_subclass_v = is_traceable_wrapper_subclass(t)
        is_functorch_wrapped = is_functorch_wrapped_tensor(t)
        is_mkldnn = t.is_mkldnn
        is_batchedtensor_v = is_batchedtensor(t)
        is_legacy_batchedtensor_v = is_legacy_batchedtensor(t)
        is_gradtrackingtensor_v = is_gradtrackingtensor(t)
        is_functional = torch._is_functional_tensor(t)

        # Only tensors with a conventional strided storage get a storage
        # descriptor; sparse/nested/mkldnn/functorch-wrapped tensors do not
        # expose untyped_storage()/storage_offset() in the usual way.
        storage = None
        storage_offset = 0
        if not (
            is_sparse
            or is_sparse_compressed_layout(layout)
            or (is_nested and not is_traceable_wrapper_subclass_v)
            or is_mkldnn
            or is_functorch_wrapped
            or is_legacy_batchedtensor_v
        ):
            storage = self.describe_storage(t.untyped_storage(), trace=trace)
            storage_offset = t.storage_offset()

        # Strides only exist for strided-ish layouts.
        stride = None
        if not (
            is_sparse
            or is_sparse_compressed_layout(layout)
            or (is_nested and not is_traceable_wrapper_subclass_v)
        ):
            stride = t.stride()

        # Unwrap one level of functorch/functionalization wrapping and
        # describe the inner tensor recursively.
        unwrapped = None
        autograd_meta_from = None
        current_level = None
        if is_batchedtensor_v or is_gradtrackingtensor_v:
            unwrapped = self.describe_tensor(get_unwrapped(t), trace=trace)
        elif is_functional and t.device.type not in ("xla", "lazy"):
            if t._is_view():
                raise RuntimeError(
                    "Cannot safely fakify a view because this process drops the view information right now."
                )
            if not is_functorch_wrapped:
                # Flush pending functionalization updates before unwrapping.
                torch._sync(t)
                unwrapped = self.describe_tensor(
                    torch._from_functional_tensor(t), trace=trace
                )
                # Keep the original around: its autograd metadata is the
                # source of truth for the reconstruction.
                autograd_meta_from = t
            else:
                reapply_views = torch._C._functionalization_reapply_views_tls()
                unwrapped = self.describe_tensor(
                    _unwrap_functional_tensor(t, reapply_views), trace=trace
                )
                current_level = torch._C._functorch.current_level()

        # Snapshot the functorch interpreter stack.
        # NOTE(review): this enters and immediately exits the context purely
        # to capture the stack that temporarily_clear_interpreter_stack()
        # yields — confirm the yielded stack remains valid after exit.
        maybe_functorch_stack = None
        if is_functorch_wrapped:
            with torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack() as maybe_functorch_stack:
                pass

        # Traceable wrapper subclasses are described via their flattened
        # inner tensors plus the subclass type and context.
        attrs = None
        ctx = None
        type_v = None
        if is_traceable_wrapper_subclass_v:
            assert hasattr(t, "__tensor_flatten__")
            raw_attrs, ctx = t.__tensor_flatten__()
            attrs = {
                attr: self.describe_tensor(getattr(t, attr), trace=trace)
                for attr in raw_attrs
            }
            type_v = type(t)

        from torch.nested._internal.nested_tensor import _tensor_symint_registry

        view_func = ViewFunc.from_tensor(t)

        # Mirror assert_metadata_eq: when inference-mode handling is
        # disabled, record is_inference=False unconditionally.
        is_inference_mode_disabled = getattr(tls, "disable_inference_mode", False)
        r: MetaTensorDesc = MetaTensorDesc(
            id=self.get_tensor_id(t),
            storage=storage,
            is_inference=False if is_inference_mode_disabled else t.is_inference(),
            is_leaf=is_leaf,
            requires_grad=t.requires_grad,
            ndim=t.dim(),
            dtype=t.dtype,
            is_sparse=is_sparse,
            is_mkldnn=is_mkldnn,
            is_functorch_wrapped=is_functorch_wrapped,
            is_batchedtensor=is_batchedtensor_v,
            is_legacy_batchedtensor=is_legacy_batchedtensor_v,
            is_gradtrackingtensor=is_gradtrackingtensor_v,
            is_view=is_view,
            is_conj=t.is_conj(),
            is_neg=t.is_neg(),
            is_parameter=isinstance(t, torch.nn.Parameter),
            is_traceable_wrapper_subclass=is_traceable_wrapper_subclass_v,
            is_nested=is_nested,
            nested_int=(
                _tensor_symint_registry[t].node.nested_int()
                if t in _tensor_symint_registry
                else None
            ),
            is_functional=is_functional,
            layout=layout,
            device=t.device,
            size=t.size(),
            stride=stride,
            storage_offset=storage_offset,
            dynamo_dynamic_indices=list(getattr(t, "_dynamo_dynamic_indices", set())),
            sparse_dim=(
                t.sparse_dim() if t.is_sparse or is_sparse_compressed(t) else None
            ),
            dense_dim=t.dense_dim() if t.is_sparse or is_sparse_compressed(t) else None,
            is_coalesced=t.is_coalesced() if t.is_sparse else None,
            # Sparse-compressed constituents are described without further
            # recursion (recurse=False) to keep the descriptors flat.
            crow_indices=(
                self.describe_tensor(t.crow_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr}
                else None
            ),
            col_indices=(
                self.describe_tensor(t.col_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csr, torch.sparse_bsr}
                else None
            ),
            ccol_indices=(
                self.describe_tensor(t.ccol_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc}
                else None
            ),
            row_indices=(
                self.describe_tensor(t.row_indices(), recurse=False, trace=trace)
                if recurse and t.layout in {torch.sparse_csc, torch.sparse_bsc}
                else None
            ),
            values=(
                self.describe_tensor(t.values(), recurse=False, trace=trace)
                if recurse and is_sparse_compressed(t)
                else None
            ),
            grad=(
                self.describe_tensor(grad, trace=trace)
                if (grad := safe_grad(t)) is not None
                else None
            ),
            creation_meta=(
                torch._C._autograd._get_creation_meta(t) if t._is_view() else None
            ),
            unwrapped=unwrapped,
            level=(
                maybe_get_level(t)
                if is_batchedtensor_v or is_gradtrackingtensor_v
                else None
            ),
            bdim=maybe_get_bdim(t) if is_batchedtensor_v else None,
            base=(
                self.describe_tensor(t._base, trace=trace)
                if recurse and t._is_view() and t._base is not None
                else None
            ),
            fake_mode=torch._subclasses.fake_tensor.maybe_get_fake_mode(t),
            view_func=view_func,
            attrs=attrs,
            ctx=ctx,
            type=type_v,
            functorch_stack=maybe_functorch_stack,
            autograd_meta_from=autograd_meta_from,
            current_level=current_level,
            data=t if self.copy_data else None,
        )
        if trace and r.id not in self.traced_tensors:
            trace_structured(
                "describe_tensor",
                metadata_fn=lambda: r.as_json(self.id),
            )
            self.traced_tensors.add(r.id)
        return r
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class MetaStorageDesc:
    """Serializable description of an untyped storage.

    ``data`` optionally holds the real storage (only when the describer was
    created with ``copy_data=True``) and is never serialized.
    """

    id: MetaStorageId
    size: int
    data: Optional[torch.UntypedStorage]

    def as_json(self, describer_id: _DescriberId) -> dict[str, object]:
        """Render a JSON-friendly dict for structured trace logging.

        A non-int size (e.g. a symbolic value) is rendered via repr().
        """
        size_json: object = (
            self.size if isinstance(self.size, int) else repr(self.size)
        )
        return {
            "id": self.id,
            "describer_id": describer_id,
            "size": size_json,
        }
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class ViewFunc(Generic[_TensorT]):
    """Strategy object that can replay a view transformation on a new base.

    Subclasses implement :meth:`apply`; use :meth:`from_tensor` to pick the
    right implementation for a given tensor.
    """

    @abstractmethod
    def apply(
        self,
        t: _TensorT,
        new_base: _TensorT,
        symint_visitor_fn: Optional[Callable[[int], int]] = None,
        tensor_visitor_fn: Optional[Callable[[torch.Tensor], _TensorT]] = None,
    ) -> _TensorT:
        ...

    @staticmethod
    def from_tensor(t: torch.Tensor) -> ViewFunc:
        """Select the ViewFunc implementation appropriate for *t*."""
        if _is_fake_tensor(t):
            return _FakeTensorViewFunc()
        # Non-fake tensors capture their own view-replay closure.
        return _CustomViewFunc(t._view_func_unsafe)
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class _FakeTensorViewFunc(ViewFunc["FakeTensor"]):
    """ViewFunc implementation used when the source tensor is a FakeTensor."""

    @override
    def apply(
        self,
        t: torch.Tensor,
        new_base: torch.Tensor,
        symint_visitor_fn: Optional[Callable[[int], int]] = None,
        tensor_visitor_fn: Optional[Callable[[torch.Tensor], FakeTensor]] = None,
    ) -> FakeTensor:
        # Delegate directly to FakeTensor's own view-replay entry point.
        return torch._subclasses.fake_tensor.FakeTensor._view_func_unsafe(
            t, new_base, symint_visitor_fn, tensor_visitor_fn
        )
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class _CustomViewFunc(ViewFunc[_TensorT], Generic[_TensorT]):
    """ViewFunc wrapping a view-replay callable captured from a tensor.

    ``func`` is typically a bound ``Tensor._view_func_unsafe`` (see
    ``ViewFunc.from_tensor``), so it already closes over the original tensor.
    """

    # Callable taking (new_base, symint_visitor_fn, tensor_visitor_fn).
    func: Callable[
        [
            torch.Tensor,
            Optional[Callable[[int], int]],
            Optional[Callable[[torch.Tensor], _TensorT]],
        ],
        _TensorT,
    ]

    @override
    def apply(
        self,
        t: torch.Tensor,
        new_base: torch.Tensor,
        symint_visitor_fn: Optional[Callable[[int], int]] = None,
        tensor_visitor_fn: Optional[Callable[[torch.Tensor], _TensorT]] = None,
    ) -> _TensorT:
        # NOTE: `t` is intentionally unused — `self.func` was captured from
        # the original tensor, so only the new base and visitors are needed.
        return self.func(new_base, symint_visitor_fn, tensor_visitor_fn)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]):
    # Structural type for conversion callbacks: given a thunk that produces
    # a (meta) tensor and a required `device` keyword, return the final
    # tensor (e.g. a FakeTensor).
    def __call__(
        self, arg: Callable[[], torch.Tensor], /, *, device: Union[torch.device, str]
    ) -> _TensorT_cov:
        ...
|
|
|
|
|
|
|
|
|
class _MetaTensorCallbackKwargs(TypedDict, total=False):
    # Optional keyword(s) accepted by _MetaTensorCallbackOptDevice callbacks.
    device: Union[torch.device, str]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _MetaTensorCallbackOptDevice(Protocol, Generic[_TensorT_cov]):
    # Like _MetaTensorCallback, but `device` is optional (supplied via
    # TypedDict kwargs); used once the device has been pre-bound with
    # functools.partial.
    def __call__(
        self,
        arg: Callable[[], torch.Tensor],
        /,
        **kwargs: Unpack[_MetaTensorCallbackKwargs],
    ) -> _TensorT_cov:
        ...
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
class MetaTensorDesc(Generic[_TensorT]):
    """Serializable description of a tensor, produced by MetaTensorDescriber.

    Required fields first; everything after `dynamo_dynamic_indices` has a
    default and is only serialized by as_json() when it differs from that
    default. Fields listed in _UNSERIALIZABLE hold live objects and are only
    ever rendered via repr() (or dropped entirely).
    """

    id: MetaTensorId
    ndim: int
    dtype: torch.dtype
    device: torch.device
    size: tuple[int, ...]
    dynamo_dynamic_indices: list[int]

    layout: torch.layout = torch.strided
    is_inference: bool = False
    is_leaf: bool = False
    requires_grad: bool = False
    is_sparse: bool = False
    is_mkldnn: bool = False
    is_functorch_wrapped: bool = False
    is_batchedtensor: bool = False
    is_legacy_batchedtensor: bool = False
    is_gradtrackingtensor: bool = False
    is_view: bool = False
    is_nested: bool = False
    # Nested-int symbol associated with a nested tensor, if registered.
    nested_int: Optional[int] = None
    is_traceable_wrapper_subclass: bool = False
    is_functional: bool = False
    is_conj: bool = False
    is_neg: bool = False
    is_parameter: bool = False
    # Only populated for strided-ish layouts (see describe_tensor).
    stride: Optional[tuple[int, ...]] = None
    storage_offset: int = 0
    storage: Optional[MetaStorageDesc] = None
    # Sparse-only metadata.
    sparse_dim: Optional[int] = None
    dense_dim: Optional[int] = None
    is_coalesced: Optional[bool] = None
    # Compressed-sparse constituents (described with recurse=False).
    crow_indices: Optional[MetaTensorDesc] = None
    col_indices: Optional[MetaTensorDesc] = None
    ccol_indices: Optional[MetaTensorDesc] = None
    row_indices: Optional[MetaTensorDesc] = None
    values: Optional[MetaTensorDesc] = None
    # Inner tensor for functorch/functional wrappers.
    unwrapped: Optional[MetaTensorDesc] = None
    bdim: Optional[int] = None
    # Base tensor descriptor when this tensor is a view.
    base: Optional[MetaTensorDesc] = None
    # Flattened inner tensors for traceable wrapper subclasses.
    attrs: Optional[dict[str, MetaTensorDesc]] = None
    creation_meta: Optional[CreationMeta] = None
    grad: Optional[MetaTensorDesc] = None

    # Fields holding live (non-serializable) objects; as_json renders them
    # via repr() — except data/autograd_meta_from, which are dropped.
    _UNSERIALIZABLE: ClassVar[set[str]] = {
        "ctx",
        "type",
        "fake_mode",
        "view_func",
        "level",
        "current_level",
        "functorch_stack",
        "autograd_meta_from",
        "data",
        "nested_int",
    }

    ctx: Optional[object] = None
    type: Optional[type] = None
    fake_mode: Optional[FakeTensorMode] = None
    view_func: Optional[ViewFunc] = None
    level: Optional[int] = None
    current_level: Optional[int] = None
    functorch_stack: Optional[list[CInterpreter]] = None
    autograd_meta_from: Optional[torch.Tensor] = None
    # The real tensor, kept only when the describer had copy_data=True.
    data: Optional[torch.Tensor] = None

    def as_json(self, describer_id: _DescriberId) -> dict[str, object]:
        """Render a JSON-friendly dict for structured trace logging.

        Fields equal to their dataclass default are omitted; nested
        descriptors collapse to their ids.
        """

        def json(k: str, v: object) -> object:
            # data/autograd_meta_from hold real tensors — never serialized,
            # not even as repr().
            if k in ["data", "autograd_meta_from"]:
                return None
            if k in MetaTensorDesc._UNSERIALIZABLE:
                return repr(v)
            if isinstance(v, (torch.device, torch.dtype, torch.layout)):
                return repr(v)
            if isinstance(v, torch.SymInt):
                return repr(v)
            if isinstance(v, (tuple, list)):
                return [json(k, v1) for v1 in v]
            if isinstance(v, (MetaStorageDesc, MetaTensorDesc)):
                # Nested descriptors are referenced by id only.
                return v.id
            if isinstance(v, CreationMeta):
                return str(v)
            if k == "attrs" and isinstance(v, dict):
                return {k1: v1.id for k1, v1 in v.items()}
            return v

        # Skip default-valued fields; dynamo_dynamic_indices is additionally
        # skipped when empty (its default is a fresh list, so `is` alone
        # would not catch it).
        r = {
            field.name: json(field.name, getattr(self, field.name))
            for field in dataclasses.fields(self)
            if not (
                getattr(self, field.name) is field.default
                or (
                    field.name == "dynamo_dynamic_indices"
                    and not getattr(self, field.name)
                )
            )
        }
        r.update({"describer_id": describer_id})
        return r

    @property
    def shape(self) -> tuple[int, ...]:
        # Tensor-like alias so descriptors can be compared against real
        # tensors (see assert_metadata_eq).
        return self.size
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _safe_copy(dst: torch.Tensor, src: Optional[torch.Tensor]) -> None:
|
|
|
if type(src) is not torch.Tensor:
|
|
|
return
|
|
|
dst.copy_(src)
|
|
|
|
|
|
|
|
|
def _safe_clone(src: torch.Tensor) -> Optional[torch.Tensor]:
|
|
|
if type(src) is not torch.Tensor:
|
|
|
return None
|
|
|
return src.clone()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MetaConverter(Generic[_TensorT]):
|
|
|
    def __init__(self, *, copy_data: bool = False) -> None:
        """Initialize an empty converter.

        Args:
            copy_data: when True, real data is carried along with the
                converted results (forwarded to MetaTensorDescriber).
        """
        # Weak-valued memo tables: entries disappear automatically once the
        # converted storage/tensor is garbage collected.
        self.storage_memo: weakref.WeakValueDictionary[
            MetaStorageId, torch.UntypedStorage
        ] = weakref.WeakValueDictionary()
        self.tensor_memo: weakref.WeakValueDictionary[
            MetaTensorId, _TensorT
        ] = weakref.WeakValueDictionary()
        # Conversion statistics, consumed by successful().
        self.hit = 0
        self.miss = 0
        self.del_hook = None
        self.arg_cnt = 0
        self.copy_data = copy_data
        # Dedicated describer so ids are consistent across this converter's
        # lifetime.
        self.describer = MetaTensorDescriber(copy_data=copy_data)
|
|
|
|
|
|
def successful(self) -> bool:
|
|
|
return self.hit > 0 and self.miss == 0
|
|
|
|
|
|
def get_tensor_memo(self, t: MetaTensorDesc) -> Optional[torch.Tensor]:
|
|
|
return self.tensor_memo.get(t.id, None)
|
|
|
|
|
|
def _checked_get_tensor_memo(self, t: MetaTensorDesc) -> _TensorT:
|
|
|
r = self.tensor_memo.get(t.id, None)
|
|
|
assert r is not None
|
|
|
return r
|
|
|
|
|
|
    def set_tensor_memo(self, t: MetaTensorDesc, v: _TensorT) -> None:
        """Memoize *v* as the conversion result for descriptor *t*."""
        self.tensor_memo[t.id] = v
|
|
|
|
|
|
def get_storage_memo(self, s: MetaStorageDesc) -> Optional[torch.UntypedStorage]:
|
|
|
return self.storage_memo.get(s.id, None)
|
|
|
|
|
|
    def set_storage_memo(self, s: MetaStorageDesc, v: torch.UntypedStorage) -> None:
        """Memoize *v* as the conversion result for storage descriptor *s*."""
        self.storage_memo[s.id] = v
|
|
|
|
|
|
    def meta_storage(
        self,
        s: MetaStorageDesc,
        callback: Callable[[Callable[[], torch.Tensor]], _TensorT],
    ) -> torch.UntypedStorage:
        """Convert a storage descriptor into a meta untyped storage (memoized).

        The storage is materialized by asking *callback* to create a uint8
        meta tensor of the right size and taking its untyped storage. With
        copy_data, the real storage is cloned and attached to the result via
        _set_real_storage.
        """
        if (memo := self.get_storage_memo(s)) is None:
            r_s = callback(
                lambda: torch.empty(s.size, dtype=torch.uint8, device="meta"),
            ).untyped_storage()
            if self.copy_data:
                # Clone outside autograd and dispatch so the copy has no
                # side effects on modes/graphs.
                with torch.no_grad(), no_dispatch():
                    assert s.data is not None
                    _set_real_storage(r_s, s.data.clone())
            self.set_storage_memo(s, r_s)
            return r_s
        else:
            return memo
|
|
|
|
|
|
    @classmethod
    def _checked_cast_tensor_t(cls, t: torch.Tensor) -> _TensorT:
        """Cast *t* to the converter's tensor type.

        Despite the name, this is a typing-only cast; no runtime check is
        performed here.
        """
        return typing.cast(_TensorT, t)
|
|
|
|
|
|
    @classmethod
    def _identity_callable(
        cls,
        t: Callable[[], torch.Tensor],
        device: Optional[Union[torch.device, str]] = None,
    ) -> _TensorT:
        """Default conversion callback: just materialize the tensor thunk.

        The *device* keyword is accepted for callback-signature
        compatibility but is ignored.
        """
        return cls._checked_cast_tensor_t(t())
|
|
|
|
|
|
    @classmethod
    def _backward_error(cls, t: _TensorT) -> _TensorT:
        """Wrap *t* so that backward() through it raises an internal error.

        DelayedError forwards the value but plants an error node that fires
        if autograd ever traverses it.
        """
        errfn = torch._C._functions.DelayedError(
            "Internal error: Tried to backward() through example input",
            1,
        )
        err = errfn(t)
        return typing.cast(_TensorT, err)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def meta_tensor(
|
|
|
self,
|
|
|
t: MetaTensorDesc,
|
|
|
shape_env: Optional[ShapeEnv],
|
|
|
callback_: _MetaTensorCallback[_TensorT],
|
|
|
source: Optional[Source],
|
|
|
symbolic_context: Optional[SymbolicContext],
|
|
|
) -> _TensorT:
|
|
|
callback: _MetaTensorCallbackOptDevice = functools.partial(
|
|
|
callback_, device=t.device
|
|
|
)
|
|
|
if source is None:
|
|
|
from torch._dynamo.source import ConstantSource
|
|
|
|
|
|
|
|
|
source = ConstantSource(
|
|
|
f"__meta_utils_unknown_tensor{len(self.tensor_memo)}"
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert not torch._C._dispatch_tls_local_exclude_set().has(
|
|
|
torch._C.DispatchKey.Python
|
|
|
)
|
|
|
self.arg_cnt += 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
maybe_suppress: Callable[[], Any] = contextlib.nullcontext
|
|
|
if shape_env is not None:
|
|
|
maybe_suppress = shape_env.suppress_guards
|
|
|
|
|
|
def sym_sizes_strides_storage_offset(
|
|
|
t: MetaTensorDesc,
|
|
|
src: torch._guards.Source,
|
|
|
symbolic_context: Optional[
|
|
|
torch.fx.experimental.symbolic_shapes.SymbolicContext
|
|
|
] = symbolic_context,
|
|
|
) -> tuple[tuple[int, ...], tuple[int, ...], int]:
|
|
|
assert t.stride is not None
|
|
|
if shape_env is not None:
|
|
|
fake_mode = t.fake_mode
|
|
|
if fake_mode is not None and fake_mode.shape_env is shape_env:
|
|
|
|
|
|
|
|
|
return (t.size, t.stride, t.storage_offset)
|
|
|
else:
|
|
|
|
|
|
t_size = tuple(
|
|
|
shape_env._maybe_specialize_sym_int_with_hint(sz)
|
|
|
for sz in t.size
|
|
|
)
|
|
|
t_stride = tuple(
|
|
|
shape_env._maybe_specialize_sym_int_with_hint(sd)
|
|
|
for sd in t.stride
|
|
|
)
|
|
|
t_storage_offset = shape_env._maybe_specialize_sym_int_with_hint(
|
|
|
t.storage_offset
|
|
|
)
|
|
|
return shape_env._create_symbolic_sizes_strides_storage_offset(
|
|
|
t_size,
|
|
|
t_stride,
|
|
|
t_storage_offset,
|
|
|
[d in t.dynamo_dynamic_indices for d in range(t.ndim)],
|
|
|
src,
|
|
|
symbolic_context=symbolic_context,
|
|
|
)
|
|
|
else:
|
|
|
return (t.size, t.stride, t.storage_offset)
|
|
|
|
|
|
def empty_create(
|
|
|
inner_t: MetaTensorDesc,
|
|
|
inner_src: torch._guards.Source,
|
|
|
symbolic_context: Optional[
|
|
|
torch.fx.experimental.symbolic_shapes.SymbolicContext
|
|
|
] = symbolic_context,
|
|
|
) -> torch.Tensor:
|
|
|
(
|
|
|
inner_sizes,
|
|
|
inner_strides,
|
|
|
_inner_storage_offset,
|
|
|
) = sym_sizes_strides_storage_offset(inner_t, inner_src, symbolic_context)
|
|
|
return torch.empty_strided(
|
|
|
inner_sizes,
|
|
|
inner_strides,
|
|
|
dtype=inner_t.dtype,
|
|
|
device="meta",
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def empty_create_subclass(
|
|
|
t: MetaTensorDesc,
|
|
|
outer_size: tuple[int, ...],
|
|
|
outer_stride: tuple[int, ...],
|
|
|
symbolic_context: Optional[
|
|
|
torch.fx.experimental.symbolic_shapes.SymbolicContext
|
|
|
] = symbolic_context,
|
|
|
source: Optional[torch._guards.Source] = source,
|
|
|
) -> _TensorT:
|
|
|
from torch._dynamo.source import AttrSource
|
|
|
from torch.fx.experimental.symbolic_shapes import SubclassSymbolicContext
|
|
|
|
|
|
assert t.attrs is not None
|
|
|
assert t.type is not None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
outer_size = outer_size if outer_size is not None else t.size
|
|
|
outer_stride = outer_stride if outer_stride is not None else t.stride
|
|
|
|
|
|
assert symbolic_context is None or isinstance(
|
|
|
symbolic_context, SubclassSymbolicContext
|
|
|
)
|
|
|
|
|
|
def _empty_create_subclass(
|
|
|
t: MetaTensorDesc,
|
|
|
outer_size: Optional[tuple[int, ...]],
|
|
|
outer_stride: Optional[tuple[int, ...]],
|
|
|
symbolic_context: Optional[
|
|
|
torch.fx.experimental.symbolic_shapes.SymbolicContext
|
|
|
],
|
|
|
callback: _MetaTensorCallbackOptDevice[_TensorT],
|
|
|
source: torch._guards.Source,
|
|
|
) -> _TensorT:
|
|
|
|
|
|
|
|
|
if t.attrs is None:
|
|
|
return self.meta_tensor(
|
|
|
t,
|
|
|
shape_env,
|
|
|
callback,
|
|
|
source,
|
|
|
symbolic_context,
|
|
|
)
|
|
|
|
|
|
inner_tensors = {}
|
|
|
for attr, meta_tensor_desc in t.attrs.items():
|
|
|
current_context = None
|
|
|
if symbolic_context is not None:
|
|
|
assert isinstance(symbolic_context, SubclassSymbolicContext)
|
|
|
if (
|
|
|
current_context_ := symbolic_context.inner_contexts[attr]
|
|
|
) is not None:
|
|
|
current_context = _checked_cast(
|
|
|
torch.fx.experimental.symbolic_shapes.SymbolicContext,
|
|
|
current_context_,
|
|
|
)
|
|
|
|
|
|
current_source = AttrSource(source, attr)
|
|
|
inner_callback = functools.partial(
|
|
|
callback, device=meta_tensor_desc.device
|
|
|
)
|
|
|
new_empty_tensor = _empty_create_subclass(
|
|
|
meta_tensor_desc,
|
|
|
meta_tensor_desc.size,
|
|
|
meta_tensor_desc.stride,
|
|
|
current_context,
|
|
|
inner_callback,
|
|
|
current_source,
|
|
|
)
|
|
|
inner_tensors[attr] = new_empty_tensor
|
|
|
|
|
|
assert t.type is not None
|
|
|
return t.type.__tensor_unflatten__(
|
|
|
inner_tensors, t.ctx, outer_size, outer_stride
|
|
|
)
|
|
|
|
|
|
assert source is not None
|
|
|
sub = _empty_create_subclass(
|
|
|
t, outer_size, outer_stride, symbolic_context, callback, source
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert sub.shape == outer_size, (
|
|
|
f"Expected return value from {t.type}__tensor_unflatten__() to have "
|
|
|
f"shape equal to {outer_size}, but got: {sub.shape}"
|
|
|
)
|
|
|
assert sub.stride() == outer_stride, (
|
|
|
f"Expected return value from {t.type}__tensor_unflatten__() to have "
|
|
|
f"stride equal to {outer_stride}, but got: {sub.stride()}"
|
|
|
)
|
|
|
|
|
|
return sub
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def all_dynamic_symbolic_context(
|
|
|
t: MetaTensorDesc,
|
|
|
source: torch._guards.Source,
|
|
|
shape_env: Optional[torch.fx.experimental.symbolic_shapes.ShapeEnv],
|
|
|
callback: _MetaTensorCallback[_TensorT],
|
|
|
) -> torch.fx.experimental.symbolic_shapes.SymbolicContext:
|
|
|
from torch._dynamo.source import AttrSource
|
|
|
from torch.fx.experimental.symbolic_shapes import (
|
|
|
DimDynamic,
|
|
|
StatelessSymbolicContext,
|
|
|
SubclassSymbolicContext,
|
|
|
)
|
|
|
|
|
|
view_base_context: Optional[
|
|
|
torch.fx.experimental.symbolic_shapes.SymbolicContext
|
|
|
] = None
|
|
|
if t.is_view:
|
|
|
assert t.base is not None
|
|
|
view_base_context = all_dynamic_symbolic_context(
|
|
|
t.base, AttrSource(source, "_base"), shape_env, callback
|
|
|
)
|
|
|
|
|
|
t_symbolic_context: torch.fx.experimental.symbolic_shapes.SymbolicContext
|
|
|
t_dynamic_sizes = [DimDynamic.DYNAMIC] * t.ndim
|
|
|
if t.is_traceable_wrapper_subclass:
|
|
|
assert t.attrs is not None
|
|
|
inner_contexts: dict[
|
|
|
str, torch.fx.experimental.symbolic_shapes.SymbolicContext
|
|
|
] = {}
|
|
|
for attr, inner in t.attrs.items():
|
|
|
assert isinstance(attr, str)
|
|
|
inner_contexts[attr] = all_dynamic_symbolic_context(
|
|
|
inner, AttrSource(source, attr), shape_env, callback
|
|
|
)
|
|
|
t_symbolic_context = SubclassSymbolicContext(
|
|
|
dynamic_sizes=t_dynamic_sizes,
|
|
|
constraint_sizes=[None] * t.ndim,
|
|
|
inner_contexts=inner_contexts,
|
|
|
tensor_source=source,
|
|
|
view_base_context=view_base_context,
|
|
|
)
|
|
|
else:
|
|
|
t_symbolic_context = StatelessSymbolicContext(
|
|
|
dynamic_sizes=t_dynamic_sizes,
|
|
|
constraint_sizes=[None] * t.ndim,
|
|
|
view_base_context=view_base_context,
|
|
|
)
|
|
|
|
|
|
return t_symbolic_context
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def view_from_base(
    base: _TensorT,
    t: MetaTensorDesc,
    # NOTE: shape_env is captured from the enclosing scope via the default
    # argument, pinning the binding at definition time.
    shape_env: Optional[
        torch.fx.experimental.symbolic_shapes.ShapeEnv
    ] = shape_env,
) -> _TensorT:
    # Reconstruct the view described by `t` on top of the already-fakeified
    # `base` tensor.  Closes over `self`, `source`, `callback`,
    # `symbolic_context`, `maybe_suppress` and helpers from the enclosing
    # scope.
    with enable_python_dispatcher():
        # Symbolic sizes/strides/offset for the *view* itself (not the base).
        (sizes, strides, storage_offset) = sym_sizes_strides_storage_offset(
            t, source
        )
        if (
            not t.is_traceable_wrapper_subclass
            and not is_traceable_wrapper_subclass(base)
        ):
            # Fast path: neither side is a wrapper subclass, so a single
            # as_strided call reproduces the view.
            with maybe_suppress():
                return self._checked_cast_tensor_t(
                    base.as_strided(sizes, strides, storage_offset)
                )

        from torch._dynamo.source import EphemeralSource
        from torch.fx.experimental.symbolic_shapes import (
            StatelessSymbolicContext,
            sym_eq,
        )

        def symint_visitor_fn(s: int) -> int:
            # Called by view_func for each int/SymInt argument of the
            # original view operation.  Either returns `s` unchanged (when
            # every size is static or there is no shape_env) or replaces it
            # with a fresh ephemeral symbol.
            nonlocal symbolic_context
            from torch.fx.experimental.symbolic_shapes import DimDynamic

            # True when the caller explicitly requested fully static sizes
            # through a StatelessSymbolicContext.
            all_static_sizes = (
                symbolic_context is not None
                and isinstance(symbolic_context, StatelessSymbolicContext)
                and all(
                    x is DimDynamic.STATIC
                    for x in symbolic_context.dynamic_sizes
                )
            )

            if all_static_sizes or shape_env is None:
                return s

            # EphemeralSource: this symbol has no stable user-visible
            # source; it exists only so the view replay can be symbolic.
            sym_source = EphemeralSource("symint_visitor_fn")

            symbol = shape_env.create_symbol(s, sym_source, positive=None)
            return shape_env.create_symintnode(
                symbol, hint=s, source=sym_source
            )

        # Maps real inner-tensor ids -> already-created fake inner tensors,
        # so tensor_visitor_fn can reuse them instead of re-fakeifying.
        real_to_fake_mapping = {}
        if t.is_traceable_wrapper_subclass:
            assert t.attrs is not None
            assert t.type is not None
            # Pre-create the fake subclass so its flattened attributes can
            # seed the mapping below.
            fake_t: _TensorT = empty_create_subclass(
                t, outer_size=sizes, outer_stride=strides
            )
            attrs, _ = fake_t.__tensor_flatten__()
            for attr in attrs:
                real_to_fake_mapping[t.attrs[attr].id] = getattr(fake_t, attr)

        def tensor_visitor_fn(
            visited_t: torch.Tensor,
            # shape_env / callback captured via default args, same pattern
            # as the outer function.
            shape_env: Optional[
                torch.fx.experimental.symbolic_shapes.ShapeEnv
            ] = shape_env,
            callback: _MetaTensorCallbackOptDevice[_TensorT] = callback,
        ) -> torch.Tensor:
            # Called by view_func for each tensor argument of the original
            # view operation; returns its fake counterpart.
            if visited_t is None:
                return None

            # Reuse a fake inner tensor if this is one of the subclass's
            # flattened attributes.
            visited_id = self.describer.get_tensor_id(visited_t)
            fake_visited_t = real_to_fake_mapping.get(visited_id, None)
            if fake_visited_t is not None:
                return fake_visited_t

            visited_desc = self.describer.describe_tensor(visited_t)

            # Otherwise fakeify from scratch with an ephemeral source and
            # an all-dynamic symbolic context.
            temp_source = EphemeralSource("tensor_visitor_fn")
            return self.meta_tensor(
                visited_desc,
                shape_env,
                callback,
                temp_source,
                all_dynamic_symbolic_context(
                    visited_desc, temp_source, shape_env, callback
                ),
            )

        assert t.view_func is not None

        # Replay the original view operation on the fake base, substituting
        # symbolic ints and fake tensors via the visitor callbacks.
        fake_t = t.view_func.apply(
            t, base, symint_visitor_fn, tensor_visitor_fn
        )

        # Sanity-check that the replayed view has the expected symbolic
        # metadata.
        torch._check(sym_eq(fake_t.size(), sizes))
        torch._check(sym_eq(fake_t.stride(), strides))
        torch._check(sym_eq(fake_t.storage_offset(), storage_offset))
        return fake_t
|
|
|
|
|
|
if self.get_tensor_memo(t) is None:
|
|
|
GRAD_TENSOR_SENTINEL_VALUE = -2
|
|
|
|
|
|
with torch.inference_mode(t.is_inference):
|
|
|
if t.is_sparse:
|
|
|
is_leaf = t.is_leaf
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r = callback(
|
|
|
lambda: torch.ops.aten._sparse_coo_tensor_with_dims(
|
|
|
t.sparse_dim,
|
|
|
t.dense_dim,
|
|
|
t.size,
|
|
|
dtype=t.dtype,
|
|
|
layout=torch.sparse_coo,
|
|
|
device="meta",
|
|
|
)
|
|
|
)
|
|
|
if self.copy_data:
|
|
|
|
|
|
assert t.data is not None
|
|
|
with torch.no_grad(), no_dispatch():
|
|
|
assert _is_fake_tensor(r)
|
|
|
r.real_tensor = _safe_clone(t.data)
|
|
|
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r._coalesced_(bool(t.is_coalesced))
|
|
|
if t.requires_grad:
|
|
|
r.requires_grad = True
|
|
|
if t.requires_grad and not is_leaf:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r = self._checked_cast_tensor_t(r.clone())
|
|
|
with torch.enable_grad():
|
|
|
r._coalesced_(bool(t.is_coalesced))
|
|
|
elif is_sparse_compressed_layout(t.layout):
|
|
|
is_leaf = t.is_leaf
|
|
|
|
|
|
if t.layout in {torch.sparse_bsr, torch.sparse_bsc}:
|
|
|
assert t.sparse_dim is not None
|
|
|
assert t.dense_dim is not None
|
|
|
assert t.values is not None
|
|
|
batch_dim = t.ndim - t.sparse_dim - t.dense_dim
|
|
|
blocksize = t.values.shape[batch_dim + 1 : batch_dim + 3]
|
|
|
else:
|
|
|
blocksize = ()
|
|
|
if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
|
|
|
assert t.crow_indices is not None
|
|
|
index_dtype = t.crow_indices.dtype
|
|
|
else:
|
|
|
assert t.ccol_indices is not None
|
|
|
index_dtype = t.ccol_indices.dtype
|
|
|
|
|
|
r = callback(
|
|
|
lambda: torch.ops.aten._sparse_compressed_tensor_with_dims(
|
|
|
0,
|
|
|
t.dense_dim,
|
|
|
t.shape,
|
|
|
blocksize,
|
|
|
index_dtype,
|
|
|
layout=t.layout,
|
|
|
dtype=t.dtype,
|
|
|
device="meta",
|
|
|
)
|
|
|
)
|
|
|
if self.copy_data:
|
|
|
|
|
|
assert t.data is not None
|
|
|
with torch.no_grad(), no_dispatch():
|
|
|
assert _is_fake_tensor(r)
|
|
|
r.real_tensor = _safe_clone(t.data)
|
|
|
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
|
|
|
if t.requires_grad:
|
|
|
r.requires_grad = True
|
|
|
if t.requires_grad and not is_leaf:
|
|
|
r = self._backward_error(r)
|
|
|
elif t.is_nested and not t.is_traceable_wrapper_subclass:
|
|
|
|
|
|
|
|
|
|
|
|
from torch._dynamo.exc import unimplemented
|
|
|
|
|
|
unimplemented(
|
|
|
"strided nested tensors are not supported by meta conversion"
|
|
|
)
|
|
|
elif t.is_mkldnn:
|
|
|
is_leaf = t.is_leaf
|
|
|
(
|
|
|
sizes,
|
|
|
strides,
|
|
|
_storage_offset,
|
|
|
) = sym_sizes_strides_storage_offset(t, source)
|
|
|
|
|
|
|
|
|
r = callback(
|
|
|
lambda: torch.empty_strided(
|
|
|
sizes, strides, dtype=t.dtype, device="meta"
|
|
|
)
|
|
|
)
|
|
|
if self.copy_data:
|
|
|
with torch.no_grad(), no_dispatch():
|
|
|
assert t.size is not None
|
|
|
assert t.stride is not None
|
|
|
assert _is_fake_tensor(r)
|
|
|
r.real_tensor = torch.empty_strided(
|
|
|
t.size, t.stride, dtype=t.dtype, device=t.device
|
|
|
)
|
|
|
assert t.data is not None
|
|
|
_safe_copy(r.real_tensor, t.data)
|
|
|
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
|
|
|
if t.requires_grad:
|
|
|
r.requires_grad = True
|
|
|
if t.requires_grad and not is_leaf:
|
|
|
r = self._backward_error(r)
|
|
|
elif t.is_functorch_wrapped:
|
|
|
if t.is_view:
|
|
|
from torch._dynamo.exc import unimplemented
|
|
|
|
|
|
unimplemented(
|
|
|
"view functorch tensors are not supported by meta conversion"
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def _to_fake_tensor(t: MetaTensorDesc) -> _TensorT:
    # Recursively convert a functorch-wrapped tensor description into a
    # fake tensor, unwrapping one functorch layer per recursive call.
    # Closes over `self`, `shape_env`, `callback`, `source`,
    # `symbolic_context` and GRAD_TENSOR_SENTINEL_VALUE from the
    # enclosing scope.
    r: _TensorT
    if t.is_batchedtensor:
        assert t.unwrapped is not None
        assert t.level is not None
        assert t.bdim is not None
        # Fakeify the inner tensor first, then re-apply the batch dim.
        ft = _to_fake_tensor(t.unwrapped)
        lvl = t.level
        bdim = t.bdim
        # Re-wrapping must happen under the original functorch
        # interpreter stack.
        with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
            t.functorch_stack
        ):
            r = self._checked_cast_tensor_t(
                _add_batch_dim(ft, bdim, lvl)
            )
    elif t.is_gradtrackingtensor:
        assert t.unwrapped is not None
        assert t.level is not None
        # Unwrap the inner tensor with functorch disabled so the recursion
        # sees the raw value.
        disable_functorch = torch._C._DisableFuncTorch
        with disable_functorch():
            ft = _to_fake_tensor(t.unwrapped)
        lvl = t.level
        if lvl == GRAD_TENSOR_SENTINEL_VALUE:
            # Sentinel level: keep the unwrapped fake as-is.
            r = ft
        else:
            with torch._functorch.pyfunctorch.temporarily_restore_interpreter_stack(
                t.functorch_stack
            ):
                r = self._checked_cast_tensor_t(
                    torch._C._functorch._wrap_for_grad(ft, lvl),
                )

        is_leaf = t.is_leaf
        if t.requires_grad and safe_is_leaf(r):
            r.requires_grad = True
        elif t.requires_grad and not is_leaf:
            # Non-leaf requiring grad: attach an erroring backward so
            # accidental .backward() through the fake is caught.
            r = self._backward_error(r)
    elif t.is_functional:
        assert t.unwrapped is not None
        assert t.current_level is not None
        # Fakeify the inner tensor via the full meta_tensor path, then
        # re-wrap it as a functional tensor at the recorded level.
        ft = self.meta_tensor(
            t.unwrapped,
            shape_env,
            callback,
            source,
            symbolic_context,
        )
        r = self._checked_cast_tensor_t(
            _wrap_functional_tensor(ft, t.current_level),
        )
    else:
        # Base case: a plain tensor; materialize an empty meta tensor with
        # matching sizes/strides via the callback.
        assert t.stride is not None

        sizes = t.size
        strides = t.stride
        r = callback(
            lambda: torch.empty_strided(
                sizes,
                strides,
                dtype=t.dtype,
                device="meta",
            ),
        )
        if self.copy_data:
            # Also keep a real (non-meta) copy of the data when requested.
            with torch.no_grad(), no_dispatch():
                r.real_tensor = torch.empty_strided(
                    t.size,
                    t.stride,
                    dtype=t.dtype,
                    device=t.device,
                )
                assert t.data is not None
                _safe_copy(r.real_tensor, t.data)
    return r
|
|
|
|
|
|
r = _to_fake_tensor(t)
|
|
|
|
|
|
elif t.is_functional and t.device.type not in ["xla", "lazy"]:
|
|
|
assert t.unwrapped is not None
|
|
|
assert not t.is_functorch_wrapped
|
|
|
unwrapped = self.meta_tensor(
|
|
|
t.unwrapped,
|
|
|
shape_env,
|
|
|
callback,
|
|
|
source,
|
|
|
symbolic_context,
|
|
|
)
|
|
|
r = self._checked_cast_tensor_t(
|
|
|
torch._to_functional_tensor(unwrapped)
|
|
|
)
|
|
|
torch._mirror_autograd_meta_to(t.autograd_meta_from, r)
|
|
|
|
|
|
elif t.is_view:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert t.base is not None
|
|
|
|
|
|
base_symbolic_context = None
|
|
|
if shape_env and symbolic_context is not None:
|
|
|
from torch.fx.experimental.symbolic_shapes import (
|
|
|
StatelessSymbolicContext,
|
|
|
)
|
|
|
|
|
|
assert isinstance(symbolic_context, StatelessSymbolicContext)
|
|
|
|
|
|
|
|
|
|
|
|
if symbolic_context.view_base_context is not None:
|
|
|
base_symbolic_context = symbolic_context.view_base_context
|
|
|
|
|
|
base = self.meta_tensor(
|
|
|
t.base,
|
|
|
shape_env,
|
|
|
callback,
|
|
|
torch._dynamo.source.AttrSource(source, "_base"),
|
|
|
base_symbolic_context,
|
|
|
)
|
|
|
|
|
|
def is_c_of_r(
|
|
|
complex_dtype: torch.dtype, real_dtype: torch.dtype
|
|
|
) -> bool:
|
|
|
return (
|
|
|
utils.is_complex_dtype(complex_dtype)
|
|
|
and utils.corresponding_real_dtype(complex_dtype)
|
|
|
== real_dtype
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded(
|
|
|
torch._C.DispatchKey.ADInplaceOrView
|
|
|
)
|
|
|
torch._C._dispatch_tls_set_dispatch_key_excluded(
|
|
|
torch._C.DispatchKey.ADInplaceOrView, False
|
|
|
)
|
|
|
try:
|
|
|
if base.dtype == t.dtype:
|
|
|
pass
|
|
|
elif is_c_of_r(base.dtype, t.dtype):
|
|
|
base = self._checked_cast_tensor_t(torch.view_as_real(base))
|
|
|
elif is_c_of_r(t.dtype, base.dtype):
|
|
|
base = self._checked_cast_tensor_t(
|
|
|
torch.view_as_complex(base)
|
|
|
)
|
|
|
else:
|
|
|
|
|
|
|
|
|
|
|
|
base = self._checked_cast_tensor_t(base.view(t.dtype))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if t.is_leaf:
|
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
|
r = view_from_base(base, t)
|
|
|
|
|
|
r.requires_grad = t.requires_grad
|
|
|
else:
|
|
|
if t.base.requires_grad == t.requires_grad:
|
|
|
|
|
|
with torch.enable_grad():
|
|
|
r = view_from_base(base, t)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
|
|
|
|
assert t.requires_grad
|
|
|
with torch.no_grad(), enable_python_dispatcher():
|
|
|
mid = self._checked_cast_tensor_t(
|
|
|
base.view(base.shape)
|
|
|
)
|
|
|
mid.requires_grad = t.requires_grad
|
|
|
with torch.enable_grad():
|
|
|
r = view_from_base(mid, t)
|
|
|
|
|
|
|
|
|
|
|
|
assert t.creation_meta is not None
|
|
|
torch._C._autograd._set_creation_meta(r, t.creation_meta)
|
|
|
finally:
|
|
|
torch._C._dispatch_tls_set_dispatch_key_excluded(
|
|
|
torch._C.DispatchKey.ADInplaceOrView, old_exclude
|
|
|
)
|
|
|
|
|
|
else:
|
|
|
is_leaf = t.is_leaf
|
|
|
|
|
|
|
|
|
if (
|
|
|
not (t.is_batchedtensor or t.is_gradtrackingtensor)
|
|
|
and t.is_functorch_wrapped
|
|
|
) or t.is_legacy_batchedtensor:
|
|
|
return NotImplemented
|
|
|
|
|
|
(
|
|
|
sizes,
|
|
|
strides,
|
|
|
storage_offset,
|
|
|
) = sym_sizes_strides_storage_offset(t, source, symbolic_context)
|
|
|
|
|
|
|
|
|
|
|
|
if t.is_traceable_wrapper_subclass:
|
|
|
r = empty_create_subclass(
|
|
|
t, outer_size=sizes, outer_stride=strides
|
|
|
)
|
|
|
else:
|
|
|
r = callback(
|
|
|
lambda: torch.empty_strided(
|
|
|
sizes,
|
|
|
strides,
|
|
|
dtype=t.dtype,
|
|
|
device="meta",
|
|
|
)
|
|
|
)
|
|
|
if self.copy_data:
|
|
|
with torch.no_grad(), no_dispatch():
|
|
|
assert t.size is not None
|
|
|
assert t.stride is not None
|
|
|
assert _is_fake_tensor(r)
|
|
|
r.real_tensor = torch.empty_strided(
|
|
|
t.size, t.stride, dtype=t.dtype, device=t.device
|
|
|
)
|
|
|
_safe_copy(r.real_tensor, t.data)
|
|
|
|
|
|
assert safe_is_leaf(r), "the callback you passed in doesn't detach"
|
|
|
if t.requires_grad:
|
|
|
r.requires_grad = t.requires_grad
|
|
|
if not is_leaf:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r = self._backward_error(r)
|
|
|
|
|
|
s = t.storage
|
|
|
assert s is not None
|
|
|
if s.id not in self.storage_memo and (
|
|
|
r.is_nested
|
|
|
or (
|
|
|
r.stride() == strides
|
|
|
and r.storage_offset() == storage_offset
|
|
|
)
|
|
|
):
|
|
|
|
|
|
self.set_storage_memo(s, r.untyped_storage())
|
|
|
if self.copy_data:
|
|
|
assert _is_fake_tensor(r)
|
|
|
assert r.real_tensor is not None
|
|
|
_set_real_storage(
|
|
|
r.untyped_storage(), r.real_tensor.untyped_storage()
|
|
|
)
|
|
|
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
r_s = self.meta_storage(s, callback=callback)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
maybe_fake_mgr: AbstractContextManager[
|
|
|
None
|
|
|
] = contextlib.nullcontext()
|
|
|
from torch._subclasses.fake_tensor import (
|
|
|
in_kernel_invocation_manager,
|
|
|
maybe_get_fake_mode,
|
|
|
)
|
|
|
|
|
|
mb_fake_mode = maybe_get_fake_mode(r)
|
|
|
if mb_fake_mode is not None:
|
|
|
maybe_fake_mgr = in_kernel_invocation_manager(mb_fake_mode)
|
|
|
with torch.no_grad(), maybe_suppress():
|
|
|
with maybe_fake_mgr:
|
|
|
r.set_(r_s, storage_offset, sizes, strides)
|
|
|
if self.copy_data:
|
|
|
with torch.no_grad(), no_dispatch():
|
|
|
assert _is_fake_tensor(r)
|
|
|
assert r.real_tensor is not None
|
|
|
assert t.stride is not None
|
|
|
r.real_tensor.set_(
|
|
|
_get_real_storage(r_s),
|
|
|
t.storage_offset,
|
|
|
t.size,
|
|
|
t.stride,
|
|
|
)
|
|
|
|
|
|
if t.grad is not None:
|
|
|
from torch._dynamo.source import AttrSource
|
|
|
|
|
|
|
|
|
|
|
|
r.grad = self.meta_tensor(
|
|
|
t.grad,
|
|
|
shape_env,
|
|
|
callback,
|
|
|
AttrSource(source, "grad"),
|
|
|
symbolic_context,
|
|
|
)
|
|
|
torch._C._set_conj(r, t.is_conj)
|
|
|
torch._C._set_neg(r, t.is_neg)
|
|
|
|
|
|
skip_leaf = (
|
|
|
t.is_gradtrackingtensor and t.level == GRAD_TENSOR_SENTINEL_VALUE
|
|
|
)
|
|
|
assert_metadata_eq(assert_eq, t, r, skip_symbolic=True, skip_leaf=skip_leaf)
|
|
|
|
|
|
|
|
|
|
|
|
from torch.fx.experimental.symbolic_shapes import guard_or_false
|
|
|
|
|
|
if t.storage is not None and guard_or_false(t.storage.size == 0):
|
|
|
r.untyped_storage().resize_(0)
|
|
|
|
|
|
if t.is_parameter:
|
|
|
r._is_param = True
|
|
|
|
|
|
|
|
|
if t.nested_int is not None:
|
|
|
assert _is_fake_tensor(r)
|
|
|
r.nested_int_memo = r.fake_mode.create_symbolic_nested_int(
|
|
|
nt_tensor_id=t.nested_int
|
|
|
)
|
|
|
|
|
|
self.set_tensor_memo(t, r)
|
|
|
|
|
|
return self._checked_get_tensor_memo(t)
|
|
|
|
|
|
def __call__(
    self,
    t: torch.Tensor,
    shape_env: Optional[ShapeEnv] = None,
    *,
    callback: Optional[_MetaTensorCallback[_TensorT]] = None,
    source: Optional[Source] = None,
    symbolic_context: Optional[SymbolicContext] = None,
    trace: bool = True,
) -> _TensorT:
    """Convert ``t`` into its meta/fake counterpart.

    Returns ``NotImplemented`` for tensors this converter does not
    support, and passes non-tensor-like inputs through unchanged.
    When ``callback`` is omitted, ``self._identity_callable`` is used.
    """
    cb: _MetaTensorCallback[_TensorT] = (
        self._identity_callable if callback is None else callback
    )

    if not isinstance(t, torch.Tensor):
        if torch.overrides.is_tensor_like(t):
            # Tensor-like but not an actual Tensor: count as a miss.
            self.miss += 1
            return NotImplemented
        # Not tensor-like at all: pass through untouched.
        return t

    # Unsupported tensor kinds bail out with NotImplemented.
    if (
        t.device.type == "lazy"
        or t.is_quantized
        or (t._is_view() and t._base is not None and t._base.is_sparse)
    ):
        self.miss += 1
        return NotImplemented
    self.hit += 1

    # Tracing requires a source to report against.
    if source is None:
        trace = False

    desc = self.describer.describe_tensor(t, trace=trace)

    if trace:
        assert source is not None
        trace_structured(
            "describe_source",
            metadata_fn=lambda: {
                "describer_id": self.describer.id,
                "id": desc.id,
                "source": source.name(),
            },
        )

    # Run the conversion with functionalization suspended and, if a
    # functorch interpreter stack is active, with that stack cleared.
    with contextlib.ExitStack() as stack:
        stack.enter_context(
            torch._dispatch.python.suspend_functionalization()
        )
        if peek_interpreter_stack() is not None:
            stack.enter_context(
                torch._functorch.pyfunctorch.temporarily_clear_interpreter_stack()
            )

        r = self.meta_tensor(
            desc,
            shape_env,
            cb,
            source,
            symbolic_context,
        )

    # Preserve parameter-ness on the fake result.
    if type(t) is torch.nn.Parameter:
        r._is_param = True

    return r
|
|
|
|
|
|
|
|
|
import torch._prims_common as utils
|
|
|
|