| import warnings | |
| import torch | |
| from torch.utils._python_dispatch import is_traceable_wrapper_subclass | |
| def get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]: | |
| """ | |
| Recursively extracts untyped storages from a tensor or its subclasses. | |
| Args: | |
| t (torch.Tensor): The tensor to extract storages from. | |
| Returns: | |
| Set[torch.UntypedStorage]: A set of untyped storages. | |
| """ | |
| unflattened_tensors = [t] | |
| flattened_tensor_storages = set() | |
| while len(unflattened_tensors) > 0: | |
| obj = unflattened_tensors.pop() | |
| if is_traceable_wrapper_subclass(obj): | |
| attrs, _ = obj.__tensor_flatten__() # type: ignore[attr-defined] | |
| unflattened_tensors.extend([getattr(obj, attr) for attr in attrs]) | |
| else: | |
| if not hasattr(obj, "untyped_storage"): | |
| warnings.warn( | |
| f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}", | |
| category=UserWarning, | |
| stacklevel=2, | |
| ) | |
| else: | |
| flattened_tensor_storages.add(obj.untyped_storage()) | |
| return flattened_tensor_storages | |