Spaces:
Paused
Paused
| import itertools | |
| from typing import List, Tuple | |
| import torch | |
| from torch import Tensor | |
def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
    """
    Return the device of the first tensor held by a PyTorch module.

    Registered parameters are consulted first, then registered buffers. If the
    module has neither, tensors stored as plain attributes in a module's
    ``__dict__`` (collected via ``parameter._named_members``) are used instead.

    Args:
        parameter (`torch.nn.Module`): The module to inspect.

    Returns:
        `torch.device`: The device of the first tensor found.

    Raises:
        StopIteration: If the module holds no parameters, buffers, or tensor attributes.
    """
    candidates = itertools.chain(parameter.parameters(), parameter.buffers())
    try:
        tensor = next(candidates)
    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5
        def tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
            return [(name, value) for name, value in module.__dict__.items() if torch.is_tensor(value)]

        # Intentionally no default: an empty module still raises StopIteration here.
        _name, tensor = next(parameter._named_members(get_members_fn=tensor_attributes))
    return tensor.device
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
    """
    Gets the data type of a PyTorch module's parameters or buffers.

    Lookup order: the first registered parameter, then the first registered
    buffer, then the first tensor stored as a plain attribute in a module's
    ``__dict__`` (collected via ``parameter._named_members``).

    Args:
        parameter (`torch.nn.Module`): The PyTorch module from which to get the data type.

    Returns:
        `torch.dtype`: The data type of the first tensor found, or `None` if the
        module holds no parameters, buffers, or tensor attributes.
    """
    # `next(..., None)` reads only the first element instead of materialising
    # every parameter/buffer into a tuple.
    first_param = next(parameter.parameters(), None)
    if first_param is not None:
        return first_param.dtype

    first_buffer = next(parameter.buffers(), None)
    if first_buffer is not None:
        return first_buffer.dtype

    # For torch.nn.DataParallel compatibility in PyTorch 1.5.
    # This fallback used to sit in an `except StopIteration` that could never
    # trigger (tuple()/len() do not raise StopIteration), so it was dead code.
    def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
        return [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]

    first_tuple = next(parameter._named_members(get_members_fn=find_tensor_attributes), None)
    if first_tuple is not None:
        return first_tuple[1].dtype

    # No tensors at all: keep the historical `None` result rather than raising.
    return None
class ModelMixin(torch.nn.Module):
    """
    Provides convenient properties to access the device and data type
    of a PyTorch module.

    By inheriting from this class, your custom PyTorch modules can access these properties
    (`model.device`, `model.dtype`) without manual retrieval of device and data type information.

    These properties assume that all module parameters and buffers reside
    on the same device and have the same data type, respectively.
    """

    def __init__(self):
        super().__init__()

    # `@property` was missing, so `model.device` / `model.dtype` evaluated to
    # bound methods instead of a `torch.device` / `torch.dtype` as documented.
    @property
    def device(self) -> torch.device:
        """
        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
        device).
        """
        return get_parameter_device(self)

    @property
    def dtype(self) -> torch.dtype:
        """
        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        return get_parameter_dtype(self)