|
|
import gc |
|
|
import torch |
|
|
|
|
|
|
|
|
__all__ = ['print_memory_size', 'garbage_collection_cuda'] |
|
|
|
|
|
|
|
|
def print_memory_size(a):
    """Print the memory footprint of tensor ``a`` in human-readable units.

    The size in bytes (``element_size * nelement``) is reported with the
    largest unit it strictly exceeds: Gb, Mb, Kb, or plain bytes.

    Args:
        a: the tensor whose storage size is reported. Must be a
            ``torch.Tensor``.
    """
    assert isinstance(a, torch.Tensor)

    memory = a.element_size() * a.nelement()

    # Largest unit first; pick the first threshold the size strictly exceeds.
    for power, unit in ((3, 'Gb'), (2, 'Mb'), (1, 'Kb')):
        scale = 1024 ** power
        if memory > scale:
            print(f'Memory: {memory / scale:0.3f} {unit}')
            return

    print(f'Memory: {memory:0.3f} bytes')
|
|
|
|
|
|
|
|
def is_oom_error(exception: BaseException) -> bool:
    """Return ``True`` when ``exception`` signals any known out-of-memory
    condition: CUDA OOM, the cuDNN "not supported" failure mode, or a CPU
    allocator failure."""
    oom_checks = (is_cuda_out_of_memory, is_cudnn_snafu, is_out_of_cpu_memory)
    return any(check(exception) for check in oom_checks)
|
|
|
|
|
|
|
|
|
|
|
def is_cuda_out_of_memory(exception: BaseException) -> bool:
    """Return ``True`` if ``exception`` is a CUDA out-of-memory ``RuntimeError``.

    Matches a ``RuntimeError`` carrying exactly one argument whose message
    mentions both "CUDA" and "out of memory".
    """
    if not isinstance(exception, RuntimeError):
        return False
    if len(exception.args) != 1:
        return False
    message = exception.args[0]
    return "CUDA" in message and "out of memory" in message
|
|
|
|
|
|
|
|
|
|
|
def is_cudnn_snafu(exception: BaseException) -> bool:
    """Return ``True`` if ``exception`` is the cuDNN "not supported" failure.

    This cuDNN error is commonly raised in lieu of a proper OOM when the
    workspace cannot be allocated, so it is treated as an OOM indicator.
    Matches a single-argument ``RuntimeError`` containing the exact cuDNN
    status message.
    """
    if not isinstance(exception, RuntimeError):
        return False
    if len(exception.args) != 1:
        return False
    return "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in exception.args[0]
|
|
|
|
|
|
|
|
|
|
|
def is_out_of_cpu_memory(exception: BaseException) -> bool:
    """Return ``True`` if ``exception`` is a CPU out-of-memory ``RuntimeError``.

    Matches a single-argument ``RuntimeError`` whose message contains the
    ``DefaultCPUAllocator`` allocation-failure marker.
    """
    if not isinstance(exception, RuntimeError):
        return False
    if len(exception.args) != 1:
        return False
    return "DefaultCPUAllocator: can't allocate memory" in exception.args[0]
|
|
|
|
|
|
|
|
|
|
|
def garbage_collection_cuda() -> None:
    """Garbage collection Torch (CUDA) memory.

    Runs Python's garbage collector, then releases cached CUDA memory back
    to the driver. An OOM raised while emptying the cache is swallowed
    (freeing memory under pressure is exactly when it can happen); any
    other ``RuntimeError`` propagates to the caller.
    """
    gc.collect()
    try:
        torch.cuda.empty_cache()
    except RuntimeError as err:
        if is_oom_error(err):
            # OOM during cache release is expected under memory pressure.
            return
        raise
|
|
|