File size: 1,221 Bytes
f4cade0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import warnings
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
def get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]:
"""
Recursively extracts untyped storages from a tensor or its subclasses.
Args:
t (torch.Tensor): The tensor to extract storages from.
Returns:
Set[torch.UntypedStorage]: A set of untyped storages.
"""
unflattened_tensors = [t]
flattened_tensor_storages = set()
while len(unflattened_tensors) > 0:
obj = unflattened_tensors.pop()
if is_traceable_wrapper_subclass(obj):
attrs, _ = obj.__tensor_flatten__() # type: ignore[attr-defined]
unflattened_tensors.extend([getattr(obj, attr) for attr in attrs])
else:
if not hasattr(obj, "untyped_storage"):
warnings.warn(
f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}",
category=UserWarning,
stacklevel=2,
)
else:
flattened_tensor_storages.add(obj.untyped_storage())
return flattened_tensor_storages
|