|
|
import contextlib |
|
|
import functools |
|
|
import itertools |
|
|
import sys |
|
|
import warnings |
|
|
import weakref |
|
|
from dataclasses import dataclass |
|
|
from functools import partial |
|
|
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union |
|
|
|
|
|
import torch |
|
|
from torch._ops import OpOverload |
|
|
from torch._subclasses.meta_utils import MetaConverter, WeakTensorRefKey |
|
|
from torch.fx.operator_schemas import normalize_function |
|
|
from torch.multiprocessing.reductions import StorageWeakRef |
|
|
from torch.overrides import TorchFunctionMode |
|
|
from torch.utils._mode_utils import no_dispatch |
|
|
from torch.utils._python_dispatch import TorchDispatchMode |
|
|
|
|
|
from torch.utils._pytree import PyTree, tree_flatten, tree_map |
|
|
|
|
|
# Module alias so call sites can reach the full pytree API
# (tree_map_only, tree_all_only, ...) beyond the names imported above.
pytree = torch.utils._pytree

T = TypeVar("T")

# Weak reference to a tensor; typed as Any because the concrete weakref
# wrapper type is an implementation detail.
TensorWeakRef = Any

aten = torch.ops.aten

# Tensors with at most this many elements are eligible for constant
# propagation through fake-tensor ops (see FakeTensorMode.may_turn_const).
CONSTANT_NUMEL_LIMIT = 1
|
|
|
|
|
|
|
|
@dataclass
class UnsupportedFakeTensorException(RuntimeError):
    """Raised when a real tensor cannot be converted to a fake tensor
    (e.g. quantized tensors, or inputs the meta converter cannot handle)."""

    # Human-readable description of why the conversion is unsupported.
    reason: str
|
|
|
|
|
|
|
|
@dataclass
class DynamicOutputShapeException(RuntimeError):
    """Raised for ops whose output *shape* depends on tensor data, which a
    fake (meta-backed) tensor does not have (e.g. nonzero, boolean indexing)."""

    # The OpOverload that required data-dependent output shapes.
    func: OpOverload
|
|
|
|
|
|
|
|
@dataclass
class DataDependentOutputException(RuntimeError):
    """Raised (when FakeTensorMode.throw_on_data_dependent_ops is set) for
    ops whose output *values* depend on tensor data, e.g. .item()."""

    # The OpOverload whose result would be data-dependent.
    func: OpOverload
|
|
|
|
|
|
|
|
# Ops that receive a device argument positionally rather than as a keyword,
# so the generic kwarg-based device handling (see `constructors`) does not
# apply to them; they are routed to the `nyi` impl below instead.
_device_not_kwarg_ops = (
    aten._resize_output_.default,
    aten._nested_tensor_from_tensor_list.default,
    aten._nested_tensor_from_tensor_list.out,
    aten.pin_memory.default,
    aten.is_pinned.default,
    aten.to.device,
    aten.to.prim_Device,
    aten._pin_memory.default,
    aten._pin_memory.out,
    aten._resize_output.default,
    aten._resize_output.out,
)

# Constructors that take their device positionally; excluded from the
# generic `constructors` handling by assertion.
_non_kwarg_device_constructors = (aten._list_to_tensor,)
|
|
|
|
|
|
|
|
def contains_tensor_types(type):
    """Return True if the JIT type ``type`` is a Tensor type or recursively
    contains one (e.g. ``List[Tensor]`` or ``Optional[Tensor]``)."""
    tensor_type = torch._C.TensorType.get()
    if type.isSubtypeOf(tensor_type):
        return True
    # Recurse into container types (lists, optionals, tuples, ...).
    return any(contains_tensor_types(inner) for inner in type.containedTypes())
|
|
|
|
|
|
|
|
# Constructors of the form `x_like(t, ...)` / `t.new_x(...)`: they take a
# prototype tensor whose device supplies the default output device, so
# `constructors` below treats them specially.
_like_tensor_constructors = (
    aten.empty_like.default,
    aten.empty_like.out,
    aten.full_like.default,
    aten.full_like.out,
    aten.ones_like.default,
    aten.ones_like.out,
    aten.rand_like.default,
    aten.rand_like.out,
    aten.randn_like.default,
    aten.randn_like.out,
    aten.randint_like.default,
    aten.randint_like.out,
    aten.randint_like.low_dtype,
    aten.randint_like.low_dtype_out,
    aten.zeros_like.default,
    aten.zeros_like.out,
    aten.new_empty.default,
    aten.new_empty.out,
    aten.new_empty_strided.default,
    aten.new_empty_strided.out,
    aten.new_full.default,
    aten.new_full.out,
    aten.new_zeros.default,
    aten.new_zeros.out,
    aten.new_ones.default,
    aten.new_ones.out,
)
|
|
|
|
|
|
|
|
@functools.lru_cache(None) |
|
|
def _is_tensor_constructor(func: OpOverload): |
|
|
assert isinstance(func, OpOverload) |
|
|
schema = func._schema |
|
|
if any(contains_tensor_types(arg.type) for arg in schema.arguments): |
|
|
return False |
|
|
|
|
|
return ( |
|
|
len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get() |
|
|
) |
|
|
|
|
|
|
|
|
@functools.lru_cache(None)
def get_schema_info(func):
    """Return a cached ``torch._C._SchemaInfo`` for ``func``'s schema, used
    to query argument mutability/aliasing information."""
    schema = func._schema
    return torch._C._SchemaInfo(schema)
|
|
|
|
|
|
|
|
def tree_flatten_only(ty: Type[T], pytree: PyTree):
    """Flatten ``pytree`` and keep only the leaves that are instances of ``ty``."""
    leaves = tree_flatten(pytree)[0]
    return [leaf for leaf in leaves if isinstance(leaf, ty)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FakeTensorConverter(object):
    """Converts real tensors into FakeTensors.

    Results are memoized so the same real tensor always maps to the same
    fake tensor, and fakes created with ``make_constant=True`` are tracked
    per-storage so all constant aliases can be invalidated together when
    their underlying storage is written to.
    """

    # WeakTensorRefKey(real tensor) -> FakeTensor; values held weakly so
    # fakes are collectible once unreferenced.
    tensor_memo: weakref.WeakValueDictionary
    meta_converter: MetaConverter
    # Constant tensor storage -> weakrefs of every FakeTensor whose
    # `.constant` aliases that storage.
    constant_storage_mapping: Dict[StorageWeakRef, List[TensorWeakRef]]

    def __init__(self):
        self.tensor_memo = weakref.WeakValueDictionary()
        self.meta_converter = MetaConverter()
        self.constant_storage_mapping = {}

    def add_constant_storage_mapping(self, fake_tensor):
        # Track `fake_tensor` under its constant's storage so a later write
        # to any alias of that storage can clear the constants (see
        # invalidate_constant_aliases).
        assert isinstance(fake_tensor, FakeTensor) and fake_tensor.constant is not None
        weak_st = StorageWeakRef(fake_tensor.constant.storage())

        if weak_st not in self.constant_storage_mapping:
            self.constant_storage_mapping[weak_st] = []
        self.constant_storage_mapping[weak_st].append(weakref.ref(fake_tensor))

    def invalidate_constant_aliases(self, tensor):
        # The real tensor `tensor` may have been mutated; drop `.constant`
        # from every live FakeTensor whose constant shares its storage.
        assert not isinstance(tensor, FakeTensor)

        weak_st = StorageWeakRef(tensor.storage())
        if weak_st not in self.constant_storage_mapping:
            return

        for weak_tensor_ref in self.constant_storage_mapping[weak_st]:
            ten = weak_tensor_ref()
            if ten is not None:
                # NOTE(review): presumably repairs the subclass PyObject after
                # weakref round-trip before mutating attributes — confirm
                # _fix_weakref semantics against torch._C.
                ten._fix_weakref()
                ten.constant = None

        del self.constant_storage_mapping[weak_st]

    def _get_memo(self, t):
        # Return the previously-created fake for `t` if it is still alive.
        if WeakTensorRefKey(t) in self.tensor_memo:
            out = self.tensor_memo[WeakTensorRefKey(t)]
            out._fix_weakref()
            return out
        return None

    def set_tensor_memo(self, t, v):
        th = WeakTensorRefKey(t)

        # Hold the converter weakly inside the finalizer so memo bookkeeping
        # does not keep `self` alive.
        self_weak_ref = weakref.ref(self)

        def del_ten():
            self_ref = self_weak_ref()
            if self_ref is None:
                return
            # Evict the memo entry once the real tensor is collected.
            self_ref.tensor_memo.pop(th, None)

        weakref.finalize(t, del_ten)
        self.tensor_memo[th] = v

    def from_real_tensor(self, fake_mode, t, make_constant=False, shape_env=None):
        """Create (or return the memoized) FakeTensor mirroring real tensor ``t``.

        With ``make_constant``, ``t`` itself is kept as the fake's ``.constant``
        so its data remains queryable. Raises UnsupportedFakeTensorException
        for inputs the meta converter cannot handle.
        """
        maybe_memo = self._get_memo(t)
        if maybe_memo is not None:
            return maybe_memo
        existing_device = t.device
        # not yet supported in metatensors
        if t.is_quantized:
            raise UnsupportedFakeTensorException("quantized nyi in meta tensors")
        with no_dispatch():
            meta_t = self.meta_converter(t, shape_env=shape_env)
            if meta_t.device.type != "meta":
                raise UnsupportedFakeTensorException("meta converter nyi")
            out = FakeTensor(
                fake_mode,
                meta_t,
                existing_device,
                constant=t if make_constant else None,
            )
            out.requires_grad_(t.requires_grad)
            if make_constant:
                self.add_constant_storage_mapping(out)
            # Preserve Parameter-ness so module code sees parameters.
            if type(t) is torch.nn.Parameter:
                assert not make_constant
                out = torch.nn.Parameter(out, requires_grad=out.requires_grad)
        # Accessing .grad can emit a UserWarning for non-leaf tensors;
        # suppress exactly that warning while probing it.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
            grad_not_none = t.grad is not None
        if grad_not_none:
            out.grad = self.from_real_tensor(fake_mode, t.grad)
        self.set_tensor_memo(t, out)
        return out

    def from_meta_and_device(self, fake_mode, t, device):
        # Wrap an already-meta tensor `t` as a FakeTensor reporting `device`
        # (memoized, like from_real_tensor).
        maybe_memo = self._get_memo(t)
        if maybe_memo is not None:
            return maybe_memo
        out = FakeTensor(fake_mode, t, device)
        self.set_tensor_memo(t, out)
        return out

    def __call__(
        self, fake_mode, t, device=None, *, make_constant=False, shape_env=None
    ):
        # Without a device, `t` is a real tensor to be faked; with a device,
        # `t` must already live on the meta device and is merely wrapped.
        if device is None:
            return self.from_real_tensor(
                fake_mode, t, make_constant, shape_env=shape_env
            )
        else:
            assert make_constant is False
            assert t.device.type == "meta"
            return self.from_meta_and_device(fake_mode, t, device)
|
|
|
|
|
|
|
|
# Registry of (predicate, implementation) pairs consulted by
# FakeTensorMode.__torch_dispatch__: the first predicate that matches the
# dispatched OpOverload wins.
op_implementations = []


def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
    """Decorator registering a fake-tensor implementation for some ops.

    ``run_impl_check`` is either a single OpOverload (registering the impl
    for exactly that op) or a predicate over OpOverloads. The decorated
    function is returned unchanged so it stays directly callable.
    """

    def impl_decorator(op_impl):
        # No `global` needed: we only mutate (append to) the module-level
        # list, never rebind the name.
        if isinstance(run_impl_check, OpOverload):
            op_implementations.append((lambda func: func == run_impl_check, op_impl))
        else:
            op_implementations.append((run_impl_check, op_impl))

        return op_impl

    return impl_decorator
|
|
|
|
|
|
|
|
@register_op_impl(
    lambda func: (_is_tensor_constructor(func) or func in _like_tensor_constructors)
)
def constructors(fake_mode, func, *args, **kwargs):
    """Fake impl for tensor constructors (empty, randn, *_like, new_*, ...):
    run the constructor on the meta device, then wrap the result as a
    FakeTensor on the device the real call would have produced."""
    assert func not in _non_kwarg_device_constructors
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    if func in _like_tensor_constructors:
        # *_like ops default their device to the prototype tensor's device;
        # the prototype is forwarded positionally below.
        default_device = new_kwargs["input"].device

        args = (new_kwargs.pop("input"),)
    else:
        # Plain constructors default to cpu when no device is given.
        default_device = torch.device("cpu")
        args = ()
    out_device = new_kwargs.pop("device", None)
    out_device = out_device if out_device is not None else default_device
    # Run the actual allocation on the meta device.
    new_kwargs["device"] = torch.device("meta")
    r = func(*args, **new_kwargs)
    return FakeTensor(fake_mode, r, out_device)
|
|
|
|
|
|
|
|
@register_op_impl(lambda func: func in (aten.to.prim_Device, aten.to.device))
def non_kwarg_to(fake_mode, func, *args, **kwargs):
    """Fake impl for the positional-device ``aten.to`` overloads.

    Determines the output device from the target device argument (falling
    back to the input's device when none is given), runs the op with
    device="meta", then rewraps the result on the computed device.
    """
    _, new_kwargs = normalize_function(
        func, args, kwargs, normalize_to_only_use_kwargs=True
    )
    input_device = new_kwargs["device"]
    out_device = input_device if input_device else new_kwargs["input"].device
    new_kwargs["device"] = torch.device("meta")
    # Bug fix: previously this called `func(*args, **new_kwargs)`. With
    # normalize_to_only_use_kwargs=True every positional argument is already
    # present in new_kwargs, so forwarding `args` as well passed each
    # argument twice. Pop the input and pass it positionally (the schema
    # names it "self", so it cannot be passed via the normalized "input"
    # keyword), with everything else as keywords — mirroring `constructors`
    # and `to_copy`.
    inp = new_kwargs.pop("input")
    r = func(inp, **new_kwargs)
    return fake_mode.fake_tensor_converter(fake_mode, r, out_device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_op_impl(aten.resize_as_.default)
def resize_as_(fake_mode, func, *args, **kwargs):
    # Bypass the default device-propagation wrapping and just run the op:
    # resize_as_ mutates and returns its input in place.
    # NOTE(review): unlike index_put_, this does not wrap the call in
    # in_kernel_invocation_manager — confirm that is intentional.
    return func(*args, **kwargs)
|
|
|
|
|
|
|
|
@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
    # Handled like a constructor: run on the meta device and wrap the result
    # with the appropriate output device (see `constructors`).
    return constructors(fake_mode, func, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_op_impl(aten._to_copy.default)
def to_copy(fake_mode, func, *args, **kwargs):
    """Fake impl of aten._to_copy: perform the copy on the meta device, then
    wrap the result as a FakeTensor on the requested (or input's) device."""
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    # No explicit target device means the copy stays on the input's device.
    input_device = new_kwargs.pop("device", None)
    out_device = input_device if input_device else new_kwargs["input"].device
    with no_dispatch(), in_kernel_invocation_manager(fake_mode):
        # Move the (fake) input to meta and run the remaining conversion
        # kwargs (dtype, layout, ...) through the meta kernel.
        input = new_kwargs.pop("input").to("meta")
        return FakeTensor(fake_mode, aten._to_copy(input, **new_kwargs), out_device)
|
|
|
|
|
|
|
|
|
|
|
# Ops tagged dynamic_output_shape produce outputs whose sizes depend on the
# input *data*, which a fake tensor does not have. aten.index.Tensor is
# excluded because it has a dedicated impl (`index_tensor` below) that only
# rejects the boolean-mask case.
@register_op_impl(
    lambda func: torch.Tag.dynamic_output_shape in func.tags
    and func != aten.index.Tensor
)
def dyn_shape(fake_mode, func, *args, **kwargs):
    raise DynamicOutputShapeException(func)
|
|
|
|
|
|
|
|
# Ops tagged data_dependent_output (e.g. .item()) need actual tensor data,
# which a fake tensor cannot supply.
@register_op_impl(
    lambda func: torch.Tag.data_dependent_output in func.tags
)
def data_dep(fake_mode, func, *args, **kwargs):
    if fake_mode.throw_on_data_dependent_ops:
        raise DataDependentOutputException(func)
    # NotImplemented makes the dispatcher skip this impl and continue with
    # the regular kernel path (see the op_implementations loop in
    # FakeTensorMode.__torch_dispatch__).
    return NotImplemented
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_no_bool_index_tensors(func, self, indices):
    """Reject boolean/byte mask indices for ``func``: the output size of a
    masked index depends on the mask's data, which a fake tensor lacks."""
    mask_dtypes = (torch.bool, torch.uint8)
    for idx in indices:
        if idx is None:
            continue
        if idx.dtype in mask_dtypes:
            raise DynamicOutputShapeException(func)
|
|
|
|
|
|
|
|
def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs):
    """Run ``func`` through the meta kernels and wrap the result as a fresh
    FakeTensor on the same device as the (normalized) ``input`` argument."""
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    out_device = new_kwargs["input"].device
    with in_kernel_invocation_manager(fake_mode):
        out = func(*args, **kwargs)

    return FakeTensor(fake_mode, out, out_device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_op_impl(aten.index.Tensor)
def index_tensor(fake_mode, func, *args, **kwargs):
    # aten.index is only dynamically shaped for boolean/byte masks; reject
    # those, then run the meta kernel with input-device propagation.
    check_no_bool_index_tensors(func, *args, **kwargs)

    return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
|
|
|
|
|
|
|
|
|
|
|
@register_op_impl(aten.index_put.default)
def index_put(fake_mode, func, *args, **kwargs):
    # Out-of-place index_put: the output lives on the input's device.
    return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
|
|
|
|
|
|
|
|
|
|
|
@register_op_impl(aten.index_put_.default)
def index_put_(fake_mode, func, *args, **kwargs):
    # Run the meta kernel for its shape/dtype checking...
    with in_kernel_invocation_manager(fake_mode):
        out = func(*args, **kwargs)

    # ...but return the input fake tensor itself, preserving the identity
    # semantics of an in-place op.
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    return new_kwargs["input"]
|
|
|
|
|
|
|
|
@register_op_impl(lambda fn: fn in _device_not_kwarg_ops)
def nyi(fake_mode, func, *args, **kwargs):
    # Registered for exactly the ops in _device_not_kwarg_ops, so this assert
    # always fires: it exists to produce a clear "not yet implemented" error
    # instead of silently mishandling the positional device argument.
    assert func not in _device_not_kwarg_ops, f"NYI: {func}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager
def in_kernel_invocation_manager(fake_mode):
    """While active: FakeTensors report device "meta" (via the prim.device
    handling) and the meta key is forced into the TLS dispatch include set so
    meta kernels run. Both are restored on exit, supporting nesting."""
    # Snapshot prior state so nested uses restore correctly.
    meta_in_tls = torch._C._meta_in_tls_dispatch_include()
    prev = fake_mode.in_kernel_invocation

    fake_mode.in_kernel_invocation = True
    # Only add the meta key if it wasn't already included — and only remove
    # it afterwards in that same case.
    if not meta_in_tls:
        torch._C._add_meta_to_tls_dispatch_include()
    try:
        yield
    finally:
        fake_mode.in_kernel_invocation = prev
        if not meta_in_tls:
            torch._C._remove_meta_from_tls_dispatch_include()
|
|
|
|
|
|
|
|
class FakeTensor(torch.Tensor):
    """A tensor subclass backed by a meta tensor that *reports* a different
    device (``fake_device``), so programs can be traced without allocating or
    computing real data.

    ``constant`` optionally holds the original real tensor when this fake was
    produced by constant propagation (see FakeTensorMode).
    """

    # The device this fake pretends to live on.
    fake_device: torch.device
    # The mode that created this tensor; all fakes in one trace share it.
    fake_mode: "FakeTensorMode"
    # Real-data twin for constant-propagated tensors, else None.
    constant: Optional[torch.Tensor]

    @staticmethod
    def __new__(cls, fake_mode, elem, device, constant=None):
        # Wrap the meta tensor `elem`; dispatch keys come from `device` so
        # the fake participates in dispatch as if it lived there.
        return torch.Tensor._make_subclass(
            cls,
            elem,
            elem.requires_grad,
            dispatch_device=True,
            device_for_backend_keys=device,
        )

    def __init__(
        self,
        fake_mode,
        elem,
        device: Union[torch.device, str],
        constant: Optional[torch.Tensor] = None,
    ):
        # The backing tensor must always live on the meta device.
        assert elem.device.type == "meta", elem.device.type
        device = device if isinstance(device, torch.device) else torch.device(device)
        if not fake_mode.allow_meta:
            assert device.type != "meta"
        # Normalize a bare "cuda" to an explicit index so device equality
        # checks (e.g. in _find_common_device) are exact.
        if device.type == "cuda" and device.index is None:
            device = torch.device(f"cuda:{torch.cuda.current_device()}")
        self.fake_device = device
        self.fake_mode = fake_mode
        self.constant = constant

    @staticmethod
    def from_tensor(t, fake_mode):
        # Convenience wrapper around fake_mode.from_tensor.
        return fake_mode.from_tensor(t)

    def __repr__(self):
        # Render through the meta path so repr never touches real data.
        with in_kernel_invocation_manager(self.fake_mode):
            self_repr = super().__repr__()
        return f"FakeTensor({self_repr}, {self.fake_device})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Answer device queries with the fake device — except while inside a
        # kernel invocation, where "meta" must be reported so meta kernels
        # see a consistent device.
        if func == torch.ops.prim.device.default:
            assert len(args) == 1 and isinstance(args[0], FakeTensor)
            if args[0].fake_mode.in_kernel_invocation:
                return torch.device("meta")
            else:
                return args[0].fake_device

        # Defer to tensor subclasses we don't know how to handle.
        if any(not issubclass(t, FakeTensor) and t is not torch.Tensor for t in types):
            return NotImplemented

        # All fake arguments must share a single mode; re-dispatch under it.
        fake_mode = None
        for arg in itertools.chain(tree_flatten(args)[0], tree_flatten(kwargs)[0]):
            if isinstance(arg, FakeTensor):
                if fake_mode is None:
                    fake_mode = arg.fake_mode
                else:
                    assert fake_mode is arg.fake_mode, "Mixing modes NYI"

        assert fake_mode is not None
        with fake_mode:
            return func(*args, **kwargs)

    @staticmethod
    def _find_common_device(func, args, kwargs):
        """Compute the single device shared by all FakeTensor arguments.

        Zero-dim CPU tensors may mix with another device (mirroring eager
        type promotion); any other mismatch raises RuntimeError.
        """
        common_device = None
        is_cpu_zero_dim = None

        def cpu_zero_dim(t):
            return t.device.type == "cpu" and t.dim() == 0

        def merge_devices(t):
            nonlocal common_device
            nonlocal is_cpu_zero_dim
            if not isinstance(t, FakeTensor):
                return

            # First FakeTensor seen sets the candidate device.
            if common_device is None:
                common_device = t.device
                is_cpu_zero_dim = cpu_zero_dim(t)
                return

            t_is_cpu_zero_dim = cpu_zero_dim(t)
            if t.device == common_device:
                # A non-zero-dim tensor on the common device "locks it in".
                if is_cpu_zero_dim:
                    is_cpu_zero_dim = t_is_cpu_zero_dim
                return

            # Mismatching devices: a cpu zero-dim tensor never overrides
            # the established common device.
            if t_is_cpu_zero_dim:
                return

            # The common device so far came only from cpu zero-dim tensors;
            # promote to this tensor's device.
            if is_cpu_zero_dim:
                common_device = t.device
                is_cpu_zero_dim = t_is_cpu_zero_dim
                return

            raise RuntimeError(
                f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
            )

        tree_map(merge_devices, args)
        tree_map(merge_devices, kwargs)

        # Ops that accept python numbers as tensors may legitimately have no
        # tensor arguments; default to cpu in that case.
        if (
            torch._C._should_allow_numbers_as_tensors(
                func.name().split("::")[-1].split(".")[0]
            )
            and common_device is None
        ):
            common_device = torch.device("cpu")

        assert common_device is not None, f"Could not find common device for {func}"

        return common_device

    # FakeTensor participates only in torch_dispatch; plain torch_function
    # handling is disabled.
    __torch_function__ = torch._C._disabled_torch_function_impl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FakeTensorMode(TorchDispatchMode):
    """Dispatch mode that runs operations on FakeTensors via meta kernels,
    propagating the pretend ("fake") device to the outputs.

    allow_fallback_kernels: when no meta kernel exists, run the real kernel
        on zero-filled stand-ins (see run_fallback_kernel).
    allow_meta: permit FakeTensors whose fake device is itself "meta".
    throw_on_data_dependent_ops: raise DataDependentOutputException for
        data-dependent ops instead of returning NotImplemented.
    """

    def __init__(
        self,
        *,
        allow_fallback_kernels=True,
        allow_meta=False,
        throw_on_data_dependent_ops=False,
    ):
        self.allow_fallback_kernels = allow_fallback_kernels
        self.fake_tensor_converter = FakeTensorConverter()
        self.allow_meta = allow_meta

        self.throw_on_data_dependent_ops = throw_on_data_dependent_ops

        # True while a kernel runs under in_kernel_invocation_manager;
        # makes FakeTensors report device "meta" for device queries.
        self.in_kernel_invocation = False

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}

        # Device queries: mirror FakeTensor.__torch_dispatch__.
        if func == torch.ops.prim.device.default:
            assert len(args) == 1 and isinstance(args[0], FakeTensor)
            if args[0].fake_mode.in_kernel_invocation:
                return torch.device("meta")
            else:
                return args[0].fake_device

        flat_arg_fake_tensors = tree_flatten_only(FakeTensor, (args, kwargs))
        flat_symints = tree_flatten_only(torch.SymIntNode, (args, kwargs))
        # Symbolic shapes require the meta-table / decomposition path below.
        has_symbolic_sizes = (
            any([i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors])
            or len(flat_symints) > 0
        )

        converter = self.fake_tensor_converter

        # lift_fresh on a small real tensor: clone it and keep the clone as
        # the fake's constant so its data stays queryable.
        if func in self.lift_fns:
            out = func(*args, **kwargs)
            if self.may_turn_const(out):
                with no_dispatch():
                    return converter(self, out.clone(), make_constant=True)

        with no_dispatch():
            flat_arg_tensors = tree_flatten_only(torch.Tensor, (args, kwargs))

            # Unknown tensor subclasses: let them handle dispatch instead.
            if self.check_for_subclass(flat_arg_tensors):
                return NotImplemented

            # lift of a plain real tensor: convert it to a fake.
            if func in self.lift_fns:
                assert (
                    len(kwargs) == 0
                    and len(args) == 1
                    and type(args[0]) is torch.Tensor
                ), f"{args} {kwargs}"
                return converter(self, args[0])

            # Any remaining real (non-fake) tensor input is a user error.
            if self.check_for_non_fake(flat_arg_tensors):
                raise Exception(
                    "Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. "
                    f"Please convert all Tensors to FakeTensors first. Found in {func}(*{args}, **{kwargs})"
                )

            # Constant propagation: if every fake input carries a constant,
            # run the real op on the constants and (when small enough) keep
            # propagating constants through the result.
            all_constant = all(e.constant is not None for e in flat_arg_fake_tensors)
            if (
                torch.Tag.nondeterministic_seeded not in func.tags
                and torch.Tag.inplace_view not in func.tags
                and all_constant
                and len(flat_arg_fake_tensors) != 0
                and not has_symbolic_sizes
            ):
                with no_dispatch():
                    const_args, const_kwargs = pytree.tree_map_only(
                        FakeTensor, lambda t: t.constant, (args, kwargs)
                    )
                    out = func(*const_args, **const_kwargs)

                    all_constant = pytree.tree_all_only(
                        torch.Tensor, lambda t: self.may_turn_const(t), out
                    )

                    if all_constant:
                        return pytree.tree_map_only(
                            torch.Tensor,
                            lambda t: converter(self, t, make_constant=True),
                            out,
                        )

                    # Output too large/unsuitable to stay constant: any
                    # constants aliasing its storage are now stale.
                    for ten in tree_flatten_only(torch.Tensor, out):
                        converter.invalidate_constant_aliases(ten)

            # Mutable ops invalidate the constants of written-to arguments.
            self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)

            # Symbolic shapes: must go through python meta functions or
            # decompositions (C++ meta kernels mostly can't handle SymInts).
            if (
                has_symbolic_sizes
                and func not in self.functions_with_cpp_meta_impl_that_support_symint
            ):
                from torch._decomp import decomposition_table
                from torch._meta_registrations import meta_table

                with no_dispatch():
                    if func == aten.size.default:
                        sys.stderr.write(
                            "Trying to call aten.size on a tensor with symbolic shapes. "
                            "It's likely that this is from calling tensor.shape in C++"
                        )
                        # can't raise here; warn and bail out instead
                        return None

                # Re-enter this mode so the meta/decomposition code is
                # itself faked.
                with self:
                    if func in meta_table:
                        r = meta_table[func](*args, **kwargs)
                        return r
                    if func in decomposition_table:
                        return decomposition_table[func](*args, **kwargs)

                    r = func.decompose(*args, **kwargs)
                    if r is not NotImplemented:
                        return r

            # prims with a python meta implementation run it directly.
            if (
                "prims::" in func._schema.name
                and len(flat_arg_fake_tensors) != 0
                and hasattr(func, "prim_meta_impl")
            ):
                with self:
                    return func.prim_meta_impl(*args, **kwargs)

            if has_symbolic_sizes:
                if func not in self.functions_with_cpp_meta_impl_that_support_symint:
                    raise RuntimeError(
                        f"{func} - couldn't find symbolic meta function/decomposition"
                    )

            with no_dispatch():
                # Registered special-case implementations take precedence
                # (first matching predicate wins; NotImplemented skips).
                for run_impl_check, op_impl in op_implementations:
                    if run_impl_check(func):
                        op_impl_out = op_impl(self, func, *args, **kwargs)
                        if op_impl_out != NotImplemented:
                            return op_impl_out

                # Default path: run the meta kernel; if there is none, fall
                # back to running the real kernel on zero stand-ins.
                try:
                    with in_kernel_invocation_manager(self):
                        r = func(*args, **kwargs)
                except NotImplementedError as not_implemented_error:
                    if not self.allow_fallback_kernels:
                        raise not_implemented_error
                    return run_fallback_kernel(
                        self, func, args, kwargs, not_implemented_error
                    )

                # Wrap fresh meta outputs as FakeTensors on the common device.
                return self.wrap_meta_outputs_with_default_device_logic(
                    r, func, args, kwargs
                )

    def check_for_subclass(self, flat_arg_tensors):
        # True if any argument is a tensor subclass other than
        # FakeTensor/Tensor/Parameter — we don't know how to fake those.
        return any(
            not isinstance(x, FakeTensor)
            and type(x) is not torch.Tensor
            and type(x) is not torch.nn.Parameter
            for x in flat_arg_tensors
        )

    def check_for_non_fake(self, flat_arg_tensors):
        # True if any argument is a real (non-fake) tensor.
        return any(
            isinstance(x, torch.Tensor) and not isinstance(x, FakeTensor)
            for x in flat_arg_tensors
        )

    def wrap_meta_outputs_with_default_device_logic(self, r, func, args, kwargs):
        """Wrap every fresh meta tensor in output ``r`` as a FakeTensor on
        the explicit ``device`` kwarg if present, else the args' common device."""
        wrap = self.gen_wrap_fn(func, args, kwargs)

        # if device is specified, use that
        if kwargs.get("device", None):
            return tree_map(partial(wrap, device=kwargs["device"]), r)

        return tree_map(partial(wrap), r)

    def gen_wrap_fn(self, func, args, kwargs):
        converter = self.fake_tensor_converter

        # Lazily computed: only pay for _find_common_device if some output
        # actually needs wrapping.
        common_device = None

        def wrap(e, device=None):
            nonlocal common_device
            if isinstance(e, torch.Tensor) and not isinstance(e, FakeTensor):
                if common_device is None:
                    common_device = FakeTensor._find_common_device(func, args, kwargs)
                return converter(self, e, device or common_device)
            else:
                return e

        return wrap

    @property
    def functions_with_cpp_meta_impl_that_support_symint(self):
        # Ops whose C++ meta kernels can already handle SymInt sizes, so the
        # python meta/decomposition requirement is waived for them.
        return [
            aten.empty_strided.default,
            aten.as_strided.default,
            aten.zeros.default,
            aten.detach.default,
        ]

    @property
    def lift_fns(self):
        # Ops that "lift" a real tensor into the traced graph.
        return (aten.lift_fresh.default, aten.lift_fresh_copy.default)

    def may_turn_const(self, t):
        # Only tiny, dense, real tensors are kept as constants.
        return (
            t.numel() <= CONSTANT_NUMEL_LIMIT
            and not t.is_sparse
            and not isinstance(t, FakeTensor)
        )

    def invalidate_written_to_constants(
        self, func, flat_arg_fake_tensors, args, kwargs
    ):
        """For a mutable op, drop the constants of every argument the schema
        says is written to (their cached real data is now stale)."""
        any_constant = any(e.constant is not None for e in flat_arg_fake_tensors)
        if any_constant and get_schema_info(func).is_mutable():
            schema_info = get_schema_info(func)
            _, new_kwargs = normalize_function(
                func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
            )
            for k, v in new_kwargs.items():
                # normalize_function renames schema "self" to "input";
                # translate back for schema queries.
                k = k if (k != "input" or schema_info.has_argument(k)) else "self"
                if (
                    isinstance(v, FakeTensor)
                    and schema_info.is_mutable(k)
                    and v.constant is not None
                ):
                    self.fake_tensor_converter.invalidate_constant_aliases(v.constant)

    def from_tensor(self, tensor, shape_env=None):
        # Convert a real tensor into a FakeTensor under this mode.
        return self.fake_tensor_converter(self, tensor, shape_env=shape_env)
|
|
|
|
|
|
|
|
|
|
|
def run_fallback_kernel(fake_mode, func, args, kwargs, orig_not_implemented_exception):
    """Last-resort path when no meta kernel exists: materialize zero-filled
    real stand-ins for the fake inputs, run the real kernel, then map the
    outputs back to fakes.

    Re-raises ``orig_not_implemented_exception`` whenever the result could
    alias an input (inplace_view ops, or outputs sharing input storage),
    since the zeros stand-ins cannot model that faithfully.
    """
    # inplace views mutate metadata of their input; can't be emulated here.
    if torch.Tag.inplace_view in func.tags:
        raise orig_not_implemented_exception

    with no_dispatch():
        # id(real stand-in) -> original FakeTensor, used to map inputs that
        # are returned as-is back to their fakes.
        inp_impls = {}

        def to_real_tensor(e):
            if isinstance(e, FakeTensor):
                # zeros stand-in on the fake device (values are irrelevant;
                # only metadata of the result matters)
                out = torch.zeros_like(e, device=e.fake_device)
                if e.is_sparse:
                    out._coalesced_(e.is_coalesced())
                inp_impls[id(out)] = e
                return out
            return e

        args = tree_map(to_real_tensor, args)
        kwargs = tree_map(to_real_tensor, kwargs)

        r = func(*args, **kwargs)

        # NOTE(review): tensor_impls is never read below — appears dead.
        tensor_impls = set()
        storages = set()

        for e in tree_flatten((args, kwargs))[0]:
            if isinstance(e, torch.Tensor):
                if not e.is_sparse:
                    storages.add(e.storage()._cdata)

        # Any output tensor that is not an input itself but shares storage
        # with an input is a view of it — unsupported by this fallback.
        for e in tree_flatten(r)[0]:
            if id(e) not in inp_impls and (
                isinstance(e, torch.Tensor)
                and not e.is_sparse
                and e.storage()._cdata in storages
            ):
                raise orig_not_implemented_exception

        def map_out(e):
            if isinstance(e, torch.Tensor):
                if id(e) in inp_impls:
                    # input returned as-is: give back its original fake
                    return inp_impls[id(e)]
                else:
                    # fresh output: convert the real result to a fake
                    return fake_mode.fake_tensor_converter(fake_mode, e)
            else:
                return e

        return tree_map(map_out, r)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FakeCopyMode(TorchFunctionMode):
    """Torch-function mode under which clone/deepcopy of real tensors yields
    FakeTensors from ``fake_mode`` instead of real copies."""

    def __init__(self, fake_mode):
        self.fake_mode = fake_mode

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs if kwargs else {}

        # clone: convert the source to a fake first, then clone the fake.
        if func == torch._C._TensorBase.clone:
            return func(self.fake_mode.from_tensor(args[0]), **kwargs)
        elif func == torch.Tensor.__deepcopy__:
            assert len(args) == 2 and len(kwargs) == 0
            tensor, memo = args

            # Honor deepcopy's memo dict so aliasing is preserved.
            if id(tensor) in memo:
                return memo[id(tensor)]

            out = self.fake_mode.from_tensor(tensor)
            memo[id(tensor)] = out
            return out
        else:
            # Everything else runs normally, without this mode interposing.
            with torch._C.DisableTorchFunction():
                return func(*args, **kwargs)
|
|
|