cranky-coder08's picture
Add files using upload-large-folder tool
f4cade0 verified
import warnings
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
def get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]:
"""
Recursively extracts untyped storages from a tensor or its subclasses.
Args:
t (torch.Tensor): The tensor to extract storages from.
Returns:
Set[torch.UntypedStorage]: A set of untyped storages.
"""
unflattened_tensors = [t]
flattened_tensor_storages = set()
while len(unflattened_tensors) > 0:
obj = unflattened_tensors.pop()
if is_traceable_wrapper_subclass(obj):
attrs, _ = obj.__tensor_flatten__() # type: ignore[attr-defined]
unflattened_tensors.extend([getattr(obj, attr) for attr in attrs])
else:
if not hasattr(obj, "untyped_storage"):
warnings.warn(
f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}",
category=UserWarning,
stacklevel=2,
)
else:
flattened_tensor_storages.add(obj.untyped_storage())
return flattened_tensor_storages