import weakref

import torch
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._mode_utils import no_dispatch


def safe_is_leaf(t):
    try:
        return t.is_leaf
    except RuntimeError:
        # Accessing is_leaf can raise (e.g. for inference tensors); treat
        # such tensors as non-leaf rather than propagating the error.
        return False


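# torch.Tensor defines value-based __eq__/__hash__, so tensors cannot be
# used directly as identity-keyed dictionary keys.  This wrapper hashes on
# the id() of the tensor while holding only a weak reference to it, so a
# memo entry does not keep its tensor alive.  A hypothetical illustration
# (not part of this module):
#
#   t = torch.empty(1)
#   k1, k2 = WeakTensorRefKey(t), WeakTensorRefKey(t)
#   assert k1 == k2 and hash(k1) == hash(k2)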
class WeakTensorRefKey(object):
    def __init__(self, ten):
        self.ten = weakref.ref(ten)
        # Cache the id eagerly: once `ten` dies the weak ref returns None,
        # but the key must keep hashing to the same value so it can still
        # be found (and removed) in the memo dictionaries.
        self.id = id(self.ten())

    def __hash__(self):
        return self.id

    def __eq__(self, other):
        if id(self) == id(other):
            return True
        return self.id == other.id


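# Converts real tensors to meta tensors while preserving view/storage
# structure: tensors sharing a storage map to meta tensors sharing one meta
# "storage" (see meta_storage), and views are rebuilt as views over their
# converted base.  Allocate one converter and call it on every tensor you
# want to convert together; only weak references to the originals are held.
#
# A minimal usage sketch (illustrative only, not part of this module; the
# exact view identity below is an assumption about as_strided behavior):
#
#   converter = MetaConverter()
#   x = torch.randn(4, 4)
#   y = x.view(16)                  # a view sharing x's storage
#   mx = converter(x)
#   my = converter(y)
#   assert mx.device.type == "meta"
#   assert my._base is mx           # view relationship is reproduced
#   assert converter.successful()   # every tensor seen was converted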
class MetaConverter:
    def __init__(self):
        self.storage_memo = {}
        self.tensor_memo = {}
        self.maybe_storages_to_delete = []
        self.check_expired_frequency = 128
        self.check_expired_count = 0
        self.hit = 0
        self.miss = 0
        self.del_hook = None
        self.arg_cnt = 0

    def successful(self):
        # True only if something was converted and nothing was rejected.
        return self.hit > 0 and self.miss == 0

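    # Retire storage memo entries whose storage has finally expired; see
    # the note in del_ten (set_tensor_memo) for why expiry can lag the
    # death of the owning tensor.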
    def check_for_expired_weak_storages(self):
        new_li = []
        stor_to_delete = []
        for obj in self.maybe_storages_to_delete:
            if not obj.expired():
                new_li.append(obj)
            else:
                stor_to_delete.append(obj)
        for obj in stor_to_delete:
            self.storage_memo.pop(obj, None)
        self.maybe_storages_to_delete = new_li

        # If many unexpired storages have piled up, check less frequently
        # so the cost of these sweeps stays bounded by the work they can
        # actually retire.
        self.check_expired_frequency = max(
            self.check_expired_frequency, len(self.maybe_storages_to_delete)
        )

    def get_tensor_memo(self, t):
        return self.tensor_memo.get(WeakTensorRefKey(t), None)

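    # Record v as the meta conversion of t, and register a finalizer so
    # both the tensor memo entry and (eventually) the storage memo entry
    # are dropped once t is deallocated.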
    def set_tensor_memo(self, t, v):
        # Hold only a weak reference to self: a strong reference captured
        # by del_ten would keep the converter alive for as long as any
        # memoized tensor is.
        self_weak_ref = weakref.ref(self)
        if t.is_sparse:
            weak_st = None
        else:
            weak_st = StorageWeakRef(t.storage())
        tensor_ref_key = WeakTensorRefKey(t)

        def del_ten():
            # The tensor may outlive the converter.
            self_ref = self_weak_ref()
            if self_ref is None:
                return
            # On shutdown the key may already have been removed.
            self_ref.tensor_memo.pop(tensor_ref_key, None)
            if weak_st and weak_st.expired():
                self_ref.storage_memo.pop(weak_st, None)
            elif weak_st is not None:
                # Even though the tensor has died, deallocation of its
                # storage can lag behind, even with no other users or
                # views.  Park the weak ref here and let the periodic
                # check_for_expired_weak_storages() sweep retire it once
                # the storage actually expires.
                self_ref.maybe_storages_to_delete.append(weak_st)

        weakref.finalize(t, del_ten)
        self.tensor_memo[tensor_ref_key] = v

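    # NB: despite the name, this does not return a Storage: meta storages
    # are not supported, so each real storage is memoized as a meta tensor
    # spanning it, which callers then set_() into.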
    def meta_storage(self, s):
        swr = StorageWeakRef(s)
        if swr not in self.storage_memo:
            self.storage_memo[swr] = torch.empty(s.size(), dtype=s.dtype, device="meta")
        return self.storage_memo[swr]

    def meta_tensor(self, t, shape_env=None):
        arg_cnt = self.arg_cnt
        self.arg_cnt += 1

        # Parameters are treated as static, so only plain tensors get
        # symbolic sizes/strides/offsets, and only when a shape_env is
        # supplied to record them in.
        make_symbolic = shape_env is not None and not isinstance(t, torch.nn.Parameter)

        def sym(name, x):
            if make_symbolic:
                return shape_env.create_symint(f"t{arg_cnt}.{name}()", x)
            else:
                return x

        def sym_list(name, xs):
            if make_symbolic:
                return [
                    shape_env.create_symint(f"t{arg_cnt}.{name}({i})", x)
                    for i, x in enumerate(xs)
                ]
            else:
                return xs

        def sym_size(t):
            return sym_list("size", t.size())

        def sym_stride(t):
            return sym_list("stride", t.stride())

        def sym_storage_offset(t):
            return sym("storage_offset", t.storage_offset())

        # Amortized sweep of expired weak storages (see
        # check_for_expired_weak_storages).
        self.check_expired_count += 1
        if self.check_expired_count >= self.check_expired_frequency:
            self.check_for_expired_weak_storages()
            self.check_expired_count = 0

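        # Three cases follow: sparse tensors are rebuilt wholesale, views
        # are rebuilt as views over their recursively-converted base, and
        # everything else is rebuilt from an empty stub plus set_().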
        if self.get_tensor_memo(t) is None:
            with torch.inference_mode(t.is_inference()):
                if t.is_sparse:
                    assert shape_env is None, "symbolic on sparse NYI"
                    is_leaf = safe_is_leaf(t)
                    r = torch.ops.aten._sparse_coo_tensor_with_dims(
                        t.sparse_dim(),
                        t.dense_dim(),
                        t.shape,
                        dtype=t.dtype,
                        layout=torch.sparse_coo,
                        device="meta",
                    )
                    r._coalesced_(t.is_coalesced())
                    if t.requires_grad:
                        r.requires_grad = True
                    if t.requires_grad and not is_leaf:
                        # Fake up autograd history: clone under enable_grad
                        # so r is a non-leaf like t (re-applying the
                        # coalesced flag, which clone does not preserve).
                        with torch.enable_grad():
                            r = r.clone()
                            r._coalesced_(t.is_coalesced())

                elif t._is_view():
                    # Construct views in two steps: recursively convert the
                    # base to meta, then rebuild the view on top of it with
                    # as_strided.  Building the view directly off storage
                    # would be wrong: version counters would not be shared
                    # with the base.
                    assert t._is_view()
                    base = self.meta_tensor(t._base)

                    def is_c_of_r(complex_dtype, real_dtype):
                        return (
                            utils.is_complex_dtype(complex_dtype)
                            and utils.corresponding_real_dtype(complex_dtype)
                            == real_dtype
                        )

                    if base.dtype == t.dtype:
                        pass
                    elif is_c_of_r(base.dtype, t.dtype):
                        base = torch.view_as_real(base)
                    elif is_c_of_r(t.dtype, base.dtype):
                        base = torch.view_as_complex(base)
                    else:
                        # Not guaranteed to succeed; failure here means some
                        # other dtype-converting view function is involved
                        # that is not handled above.
                        base = base.view(t.dtype)

                    with torch.enable_grad():
                        r = base.as_strided(
                            sym_size(t), sym_stride(t), sym_storage_offset(t)
                        )

                else:
                    is_leaf = safe_is_leaf(t)
                    # Fake up autograd history on a zero-sized stub first;
                    # the real sizes/strides are installed below via set_().
                    if t.requires_grad:
                        r = torch.empty(
                            (0,), dtype=t.dtype, device="meta", requires_grad=True
                        )
                        if not is_leaf:
                            # The grad_fn this clone produces is wrong, but
                            # that is fine: only the autograd metadata
                            # (requires_grad, leaf-ness) has to match t.
                            with torch.enable_grad():
                                r = r.clone()
                    else:
                        r = torch.empty((0,), dtype=t.dtype, device="meta")
                    # no_dispatch prevents redispatching on set_(), which
                    # would choke on the meta "storage" (actually a memoized
                    # meta tensor; see meta_storage).
                    s = self.meta_storage(t.storage())
                    with no_dispatch():
                        with torch.no_grad():
                            r.set_(s, sym_storage_offset(t), sym_size(t), sym_stride(t))

                torch._C._set_conj(r, t.is_conj())
                torch._C._set_neg(r, t.is_neg())

            self.set_tensor_memo(t, r)

        return self.get_tensor_memo(t)

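    # Top-level entry point.  Filters out tensors that cannot be converted
    # faithfully (counting them as misses) and restores Parameter-ness on
    # the way out.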
    def __call__(self, t, shape_env=None):
        # Imported locally to avoid a circular import with fake_tensor.
        from torch._subclasses.fake_tensor import FakeTensor

        if (
            type(t) is torch.Tensor
            or type(t) is torch.nn.Parameter
            or isinstance(t, FakeTensor)
        ):
            if any(
                [
                    t.is_sparse_csr,
                    t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],
                    t.is_mkldnn,
                    t.is_quantized,
                    t.is_nested,
                    t._is_view() and t._base is not None and t._base.is_sparse,
                    torch._is_functional_tensor(t),
                    # These convert in principle, but downstream fallbacks
                    # cannot cope with them, so treat them as misses too.
                    t.is_neg(),
                    t.is_conj(),
                    t.device.type in ("lazy", "meta"),
                ]
            ):
                self.miss += 1
                return t
            else:
                self.hit += 1
                r = self.meta_tensor(t, shape_env=shape_env)
                if type(t) is torch.nn.Parameter:
                    # meta_tensor returns a plain Tensor; restore
                    # Parameter-ness (and requires_grad) here.
                    r = torch.nn.Parameter(r, requires_grad=r.requires_grad)
                return r
        elif torch.overrides.is_tensor_like(t):
            # Blindly converting arbitrary tensor-like subclasses is more
            # trouble than it is worth (e.g. tracing meta tensors into FX
            # graphs, subclasses without meta support), so count a miss
            # and return the input unchanged.
            self.miss += 1
            return t
        else:
            # Non-tensor inputs pass through and count as neither a hit
            # nor a miss.
            return t


# NB: this import stays at the bottom of the file; hoisting it to the top
# appears to risk a circular import, and `utils` is only needed inside
# meta_tensor's view handling.
import torch._prims_common as utils