| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| All rights reserved. | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
import logging
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union

import numpy as np
import torch as th

logger = logging.getLogger(__name__)

TensorOrContainer = Union[
    th.Tensor, str, int, Sequence["TensorOrContainer"], Mapping[str, "TensorOrContainer"]
]
NdarrayOrContainer = Union[
    np.ndarray,
    str,
    int,
    Sequence["NdarrayOrContainer"],
    Mapping[str, "NdarrayOrContainer"],
]
TensorNdarrayOrContainer = Union[
    th.Tensor,
    np.ndarray,
    str,
    int,
    Sequence["TensorNdarrayOrContainer"],
    Mapping[str, "TensorNdarrayOrContainer"],
]
TensorNdarrayModuleOrContainer = Union[
    th.Tensor,
    np.ndarray,
    th.nn.Module,
    str,
    int,
    Sequence["TensorNdarrayModuleOrContainer"],
    Mapping[str, "TensorNdarrayModuleOrContainer"],
]

TTensorOrContainer = TypeVar("TTensorOrContainer", bound=TensorOrContainer)
TNdarrayOrContainer = TypeVar("TNdarrayOrContainer", bound=NdarrayOrContainer)
TTensorNdarrayOrContainer = TypeVar("TTensorNdarrayOrContainer", bound=TensorNdarrayOrContainer)
TTensorNdarrayModuleOrContainer = TypeVar(
    "TTensorNdarrayModuleOrContainer", bound=TensorNdarrayModuleOrContainer
)
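

# A minimal usage sketch (added for illustration; `_example_container` is a
# hypothetical helper, not part of the original source). The aliases above
# describe arbitrarily nested containers whose leaves are tensors, strings,
# or ints.
def _example_container() -> TensorOrContainer:
    return {
        "images": th.zeros(2, 3, 8, 8),  # tensor leaf
        "meta": {"ids": [1, 2], "split": "train"},  # nested mapping of int/str leaves
        "landmarks": [th.zeros(2, 5), th.zeros(2, 5)],  # sequence of tensor leaves
    }
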
class ParamHolder(th.nn.Module):
    """Holds one learnable parameter tensor per key, addressable by key."""

    def __init__(
        self,
        param_shape: Tuple[int, ...],
        key_list: Sequence[str],
        init_value: Union[None, bool, float, int, th.Tensor] = None,
    ) -> None:
        super().__init__()
        if isinstance(param_shape, int):
            param_shape = (param_shape,)
        self.key_list: Sequence[str] = sorted(key_list)
        shp = (len(self.key_list),) + param_shape
        self.params = th.nn.Parameter(th.zeros(*shp))
        if init_value is not None:
            self.params.data[:] = init_value
    def state_dict(self, *args: Any, saving: bool = False, **kwargs: Any) -> Dict[str, Any]:
        sd = super().state_dict(*args, **kwargs)
        if saving:
            # Persist the key list alongside the parameters so a later load
            # can align rows even if the key sets differ.
            assert "key_list" not in sd
            sd["key_list"] = self.key_list
        return sd
    # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module`
    #  inconsistently.
    def load_state_dict(
        self, state_dict: Mapping[str, Any], strict: bool = True, **kwargs: Any
    ) -> th.nn.modules.module._IncompatibleKeys:
        # Note: Mapping is immutable while Dict is mutable. According to pyre
        # ErrorCode[14], the type of state_dict must be Mapping or a supertype
        # of Mapping to stay consistent with the overridden method in the
        # superclass.
        sd = dict(state_dict)
        if "key_list" not in sd:
            logger.warning("Missing key list in state dict, only checking params shape.")
            assert sd["params"].shape == self.params.shape
            sd["key_list"] = self.key_list

        matching_kl = sd["key_list"] == self.key_list
        if strict:
            assert matching_kl, "Mismatched key lists."
        if not matching_kl:
            logger.warning("Attempting to load from mismatched key lists.")
            assert sd["params"].shape[1:] == self.params.shape[1:]

            # Merge the two key lists and copy rows: current values first,
            # then values from the checkpoint (which win on overlap).
            src_kl = sd["key_list"]
            new_kl = sorted(set(self.key_list) | set(src_kl))
            new_shp = (len(new_kl),) + tuple(self.params.shape[1:])
            new_params = th.zeros(*new_shp, device=self.params.device)
            for f in self.key_list:
                new_params[new_kl.index(f)] = self.params[self.key_list.index(f)]
            upd = 0
            new = 0
            for f in src_kl:
                new_params[new_kl.index(f)] = sd["params"][src_kl.index(f)]
                if f in self.key_list:
                    upd += 1
                else:
                    new += 1
            logger.info(
                f"Updated {upd} keys ({100 * upd / len(self.key_list):0.2f}%), added {new} new keys."
            )
            self.key_list = new_kl
            sd["params"] = new_params
            self.params = th.nn.Parameter(new_params)

        del sd["key_list"]
        return super().load_state_dict(sd, strict=strict, **kwargs)
    def to_idx(self, *args: Any) -> th.Tensor:
        # Accepts either one sequence of keys, or several parallel sequences
        # of key components that are zipped into composite keys.
        if len(args) == 1:
            keys = args[0]
        else:
            keys = zip(*args)
        return th.tensor(
            [self.key_list.index(k) for k in keys],
            dtype=th.long,
            device=self.params.device,
        )

    def from_idx(self, idxs: th.Tensor) -> List[str]:
        return [self.key_list[idx] for idx in idxs]

    def forward(self, idxs: th.Tensor) -> th.Tensor:
        return self.params[idxs]
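

# Usage sketch (added for illustration; `_example_param_holder` and the frame
# names are hypothetical, not part of the original source). It exercises
# per-key lookup and the non-strict merge path of `load_state_dict`.
def _example_param_holder() -> None:
    holder = ParamHolder(param_shape=(3,), key_list=["frame0", "frame1"], init_value=0.5)
    values = holder(holder.to_idx(["frame1", "frame0"]))  # rows for frame1, frame0
    assert values.shape == (2, 3)

    # A checkpoint with an overlapping but different key set: with
    # strict=False the key lists are merged and `holder.params` grows.
    donor = ParamHolder(param_shape=(3,), key_list=["frame1", "frame2"], init_value=1.0)
    with th.no_grad():
        holder.load_state_dict(donor.state_dict(saving=True), strict=False)
    assert holder.key_list == ["frame0", "frame1", "frame2"]
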
def to_device(
    things: TTensorNdarrayModuleOrContainer,
    device: th.device,
    cache: Optional[Dict[str, th.Tensor]] = None,
    key: Optional[str] = None,
    verbose: bool = False,
    max_bs: Optional[int] = None,
    non_blocking: bool = False,
) -> TTensorNdarrayModuleOrContainer:
    """Sends a potentially nested container of Tensors to the specified
    device. Non-tensors are preserved as-is.

    Args:
        things: Container with tensors or other containers of tensors to send
            to a GPU.
        device: Device to send the tensors to.
        cache: Optional dictionary to use as a cache for CUDAfied tensors. If
            passed, use this cache to allocate a tensor once and then resize /
            refill it on future calls to to_device() instead of reallocating
            it.
        key: If using the cache, store the tensor under this key; only for
            internal use.
        verbose: Print some info when a cached tensor is resized.
        max_bs: Maximum batch size allowed for tensors in the cache.
        non_blocking: If True and this copy is between CPU and GPU, the copy
            may occur asynchronously with respect to the host. For other
            cases, this argument has no effect.

    Returns:
        collection: The input collection with all tensors transferred to the
            given device.
    """
    device = th.device(device)
    pr = print if verbose else lambda *args, **kwargs: None
    if isinstance(things, th.Tensor) and things.device != device:
        if cache is not None:
            assert key is not None
            batch_size = things.shape[0]
            if key in cache:
                # Reuse the staging buffer, growing it if this batch is larger
                # than any seen before.
                assert things.shape[1:] == cache[key].shape[1:]
                if batch_size > cache[key].shape[0]:
                    pr("Resized:", key, "from", cache[key].shape[0], "to", batch_size)
                    cache[key].resize_as_(things)
            else:
                buf_shape = list(things.shape)
                if max_bs is not None:
                    assert max_bs >= batch_size
                    buf_shape[0] = max_bs
                cache[key] = th.zeros(*buf_shape, dtype=things.dtype, device=device)
                pr("Allocated:", key, buf_shape)
            cache[key][:batch_size].copy_(things, non_blocking=non_blocking)
            return cache[key][:batch_size]
        else:
            return things.to(device, non_blocking=non_blocking)
    elif isinstance(things, th.nn.Module):
        return things.to(device, non_blocking=non_blocking)
    elif isinstance(things, dict):
        # Build dotted cache keys ("parent.child") while recursing into dicts.
        key = key + "." if key is not None else ""
        return {
            k: to_device(v, device, cache, key + k, verbose, max_bs, non_blocking)
            for k, v in things.items()
        }
    elif isinstance(things, Sequence) and not isinstance(things, str):
        key = key if key is not None else ""
        out = [
            to_device(v, device, cache, key + f"_{i}", verbose, max_bs, non_blocking)
            for i, v in enumerate(things)
        ]
        if isinstance(things, tuple):
            out = tuple(out)
        return out
    elif isinstance(things, np.ndarray):
        return to_device(th.from_numpy(things), device, cache, key, verbose, max_bs, non_blocking)
    else:
        return things
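

# Usage sketch (added for illustration; `_example_to_device` is a hypothetical
# helper, not part of the original source). A shared `cache` dict lets repeated
# calls reuse one staging buffer per dotted key instead of reallocating.
def _example_to_device() -> None:
    device = th.device("cuda" if th.cuda.is_available() else "cpu")
    cache: Dict[str, th.Tensor] = {}
    batch = {"img": th.zeros(4, 3, 8, 8), "ids": np.arange(4), "tag": "train"}
    out = to_device(batch, device, cache=cache, max_bs=8)
    assert out["img"].device.type == device.type
    assert out["tag"] == "train"  # non-tensor leaves pass through unchanged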