|
|
import dataclasses |
|
|
import traceback |
|
|
from collections import OrderedDict |
|
|
from typing import Any, Callable, Dict, List, Set, Tuple, Union |
|
|
|
|
|
import torch |
|
|
from torch.nn.modules.batchnorm import _BatchNorm |
|
|
from torch.nn.parallel.scatter_gather import ( |
|
|
_is_namedtuple, |
|
|
) |
|
|
from torch.nn.utils.rnn import PackedSequence |
|
|
|
|
|
"""Useful functions to deal with tensor types with other python container types.""" |
|
|
|
|
|
__all__ = ["p_assert"] |
|
|
|
|
|
def _contains_batchnorm(module): |
|
|
return any( |
|
|
isinstance(mod, _BatchNorm) for mod in module.modules() |
|
|
) |
|
|
|
|
|
|
|
|
def _override_batchnorm_mixed_precision(module): |
|
|
for mod in module.modules(): |
|
|
if isinstance(mod, _BatchNorm): |
|
|
mod._wrap_overrides = {"mixed_precision": None} |
|
|
|
|
|
|
|
|
def _apply_to_tensors( |
|
|
fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence] |
|
|
) -> Any: |
|
|
"""Recursively apply to all tensor in different kinds of container types.""" |
|
|
|
|
|
def apply(x: Union[torch.Tensor, Dict, List, Tuple, Set, OrderedDict, PackedSequence]) -> Any: |
|
|
if torch.is_tensor(x): |
|
|
return fn(x) |
|
|
elif hasattr(x, "__dataclass_fields__"): |
|
|
dc = dataclasses.replace(x) |
|
|
for f in dataclasses.fields(dc): |
|
|
name = f.name |
|
|
setattr(dc, name, apply(getattr(dc, name))) |
|
|
return dc |
|
|
elif isinstance(x, OrderedDict): |
|
|
od = x.__class__() |
|
|
for key, value in x.items(): |
|
|
od[key] = apply(value) |
|
|
return od |
|
|
elif isinstance(x, PackedSequence): |
|
|
apply(x.data) |
|
|
return x |
|
|
elif isinstance(x, dict): |
|
|
return {key: apply(value) for key, value in x.items()} |
|
|
elif _is_namedtuple(x): |
|
|
res = (apply(el) for el in x) |
|
|
return type(x)(*res) |
|
|
elif isinstance(x, (list, tuple, set)): |
|
|
return type(x)(apply(el) for el in x) |
|
|
else: |
|
|
return x |
|
|
|
|
|
return apply(container) |
|
|
|
|
|
|
|
|
def _apply_to_modules( |
|
|
root_module: torch.nn.Module, |
|
|
module_fn: Callable, |
|
|
return_fn: Callable, |
|
|
*args, |
|
|
**kwargs, |
|
|
): |
|
|
""" |
|
|
Performs a pre-order traversal of the modules in the hierarchy rooted at |
|
|
``root_module``, applying ``module_fn`` at each module and finally |
|
|
returning a value using ``return_fn``. The traversal constructs the full |
|
|
module prefix name (e.g. "module.submodule." just like in model state dict) |
|
|
and makes that available to ``module_fn``. |
|
|
""" |
|
|
def f(module: torch.nn.Module, prefix: str, *args, **kwargs): |
|
|
|
|
|
module_fn(module, prefix, *args, **kwargs) |
|
|
for submodule_name, submodule in module.named_children(): |
|
|
if submodule is not None: |
|
|
new_prefix = prefix + submodule_name + "." |
|
|
f(submodule, new_prefix, *args, **kwargs) |
|
|
|
|
|
f(root_module, "", *args, **kwargs) |
|
|
return return_fn(*args, **kwargs) |
|
|
|
|
|
|
|
|
@torch.no_grad()
def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> bool:
    """
    Allocate storage for ``tensor`` sized to ``size.numel()`` elements.

    Returns:
        bool: ``True`` if this method allocated storage and ``False`` if the
        storage was already allocated.
    """
    storage = tensor.storage()
    if storage.size() == size.numel():
        # Storage already has the requested number of elements.
        return False
    tensor_storage_size = storage.size()
    # Only a fully-freed (zero-sized) storage may be (re)allocated here.
    p_assert(
        tensor_storage_size == 0,
        f"Tensor storage should have been resized to be 0 but got {tensor_storage_size}",
    )
    storage.resize_(size.numel())
    return True
|
|
|
|
|
|
|
|
@torch.no_grad()
def _free_storage(tensor: torch.Tensor) -> bool:
    """
    Frees the underlying storage of ``tensor`` by resizing it to zero.

    Returns:
        bool: ``True`` if the method freed the storage and ``False`` if the
        storage was already freed.
    """
    storage = tensor.storage()
    if storage.size() == 0:
        # Nothing to free.
        return False
    # Freeing shared storage would clobber other tensors viewing it.
    p_assert(
        tensor.storage_offset() == 0,
        "Freeing a tensor's storage is unsafe when it is not the sole occupant",
    )
    storage.resize_(0)
    return True
|
|
|
|
|
|
|
|
def p_assert(cond: Any, s: Any, raise_assertion_error: bool = True) -> None:
    """Alternative to ``assert`` for use in the backward context, where a
    plain assertion's message would otherwise be swallowed: prints ``s`` and
    the current stack trace, then optionally raises ``AssertionError``."""
    if cond:
        return
    print(s)
    traceback.print_stack()
    if raise_assertion_error:
        raise AssertionError
|
|
|