from typing import Any, Iterable, Optional, Union

import torch


def B_to_GiB(num_bytes: Union[int, float]) -> float:
    """Converts a number of bytes to GiB."""
    return num_bytes / 2**30


| def get_tensor_bytes(tensor: torch.Tensor) -> int: |
| """ |
| Returns the bytes of storage a given tensor takes up. If `tensor` is a view of a larger tensor, |
| this function only returns the bytes associated with the view. |
| """ |
| tensor_bytes = tensor.numel() * tensor.element_size() |
| return tensor_bytes |
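

# Illustrative example (not executed): `get_tensor_bytes` counts only a view's own elements,
# while the view's backing storage may be larger. With default float32 (4 bytes per element):
#
#     base = torch.randn(4, 4)            # 16 elements -> 64 bytes of storage
#     view = base[:2]                     # view of 8 elements -> get_tensor_bytes(view) == 32
#     view.untyped_storage().nbytes()     # 64: the full storage backing the view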


| class AllocatedMemContext: |
| """ |
| Context manager which captures the allocated GPU memory at context exit and the change between |
| enter and exit. |
| |
| Only includes `allocated_bytes.all.`-prefixed keys in `memory_stats` with all readings converted |
| to GiB. |
| |
| Example: |
| |
| ```python |
| |
| ``` |
| """ |
|
|
    def __init__(self) -> None:
        # Ensure CUDA and cuBLAS are initialized before any readings are taken, so that their
        # one-time lazy allocations don't pollute the measured values.
        torch.cuda.current_blas_handle()
|
|
| self.before: dict[str, int] = {} |
| self.after: dict[str, int] = {} |
| self.delta: dict[str, int] = {} |
|
|
| self._mem_key_prefix = "allocated_bytes.all." |
|
|
| def _get_mem_dict(self) -> dict[str, int]: |
| return { |
| k.replace(self._mem_key_prefix, ""): v |
| for k, v in torch.cuda.memory_stats().items() |
| if self._mem_key_prefix in k |
| } |
|
|
| def __enter__(self) -> "AllocatedMemContext": |
| self.before = self._get_mem_dict() |
| return self |
|
|
| def __exit__(self, *args: Any, **kwargs: Any) -> None: |
| self.after = self._get_mem_dict() |
        self.delta = {k: v - self.before[k] for k, v in self.after.items()}


| class SavedTensorContext: |
| """ |
| Context manager which captures all tensors which are registered as being saved for backwards |
| within the context window. Does not work with `meta`-device tensors. |
| |
| All saved tensors are stored in the `saved_tensor_dict` attr, which is an instance of torch's |
| WeakTensorKeyDictionary with tensor/data_ptr key/value pairs. Some of these tensors may be |
| views of the same underlying storage. The total memory of all saved tensors in bytes, accounting |
| for redundant views, can be accessed through `saved_tensor_mem`. |
| |
| Use: |
| ``` |
| model = ... |
| with SavedTensorContext(ignored_tensors=model.parameters()) as saved: |
| # Do some computation with `model` and capture saved tensors which are not model weights |
| |
| ``` |
| saved.saved_tensor_dict # WeakTensorKeyDictionary of all saved tensors. |
| saved.saved_tensor_mem # bytes from all saved tensors (activation memory). |
| """ |
|
|
| def __init__( |
| self, |
| ignored_tensors: Optional[Iterable[torch.Tensor]] = None, |
| ) -> None: |
| |
| |
| self._ignored_data_ptrs = ( |
| set() |
| if ignored_tensors is None |
| else {t.untyped_storage().data_ptr() for t in ignored_tensors} |
| ) |
|
|
| |
| |
| self.saved_tensor_dict = torch.utils.weak.WeakTensorKeyDictionary() |
|
|
        def pack_hook(saved_tensor: torch.Tensor) -> torch.Tensor:
            # Record every saved tensor whose storage does not belong to an ignored tensor
            # (e.g. the model weights passed via `ignored_tensors`).
            data_ptr = saved_tensor.untyped_storage().data_ptr()
            if data_ptr not in self._ignored_data_ptrs:
                self.saved_tensor_dict[saved_tensor] = data_ptr
            return saved_tensor
|
|
| def unpack_hook(saved_tensor: torch.Tensor) -> torch.Tensor: |
| return saved_tensor |
|
|
| self._saved_tensors_hook = torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook) |
|
|
| def __enter__(self) -> "SavedTensorContext": |
| self._saved_tensors_hook.__enter__() |
| return self |
|
|
| def __exit__(self, *args: Any, **kwargs: Any) -> None: |
| self._saved_tensors_hook.__exit__(*args, **kwargs) |
|
|
| @property |
| def saved_tensor_mem(self) -> int: |
| """ |
| The memory in bytes of all saved tensors, accounting for views into the same storage. |
| """ |
        accounted_for = self._ignored_data_ptrs.copy()
        total_bytes = 0
        for t in self.saved_tensor_dict:
            data_ptr = t.untyped_storage().data_ptr()
            if data_ptr not in accounted_for:
                # Count each underlying storage only once, however many saved views share it.
                total_bytes += t.untyped_storage().nbytes()
                accounted_for.add(data_ptr)
        return total_bytes
|
|
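# Minimal usage sketch (illustrative only; not part of the utilities above). Assumes a CUDA
# device is available; the model, shapes, and sizes below are arbitrary demonstration choices.
if __name__ == "__main__":
    device = torch.device("cuda")
    model = torch.nn.Linear(1024, 1024, device=device)
    inputs = torch.randn(64, 1024, device=device)

    # Measure the allocated-memory delta and the activation memory saved for backward,
    # ignoring the model's own weights.
    with AllocatedMemContext() as mem, SavedTensorContext(
        ignored_tensors=model.parameters()
    ) as saved:
        out = model(inputs)

    print(f"Allocated delta: {B_to_GiB(mem.delta['current']):.6f} GiB")
    print(f"Saved activation memory: {B_to_GiB(saved.saved_tensor_mem):.6f} GiB")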