File size: 1,199 Bytes
d7b3a74 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 | import gc
import logging
import torch
import torch.distributed as dist
logger = logging.getLogger(__name__)
def clear_memory(clear_host_memory: bool = False):
    """Run garbage collection and release cached CUDA memory.

    When CUDA is usable, pending GPU work is synchronized first, then Python
    garbage is collected, then the CUDA allocator cache is emptied. On a host
    without a usable CUDA device only the garbage collection runs, so the
    helper can be called unconditionally instead of raising.

    Args:
        clear_host_memory: If True (and CUDA is usable), additionally call
            the private ``torch._C._host_emptyCache()`` hook.
            NOTE(review): presumably this frees the cached pinned host
            memory; it is a private API, so confirm it exists in the
            installed torch version.
    """
    cuda_usable = torch.cuda.is_available()
    if cuda_usable:
        # torch.cuda.synchronize() raises without a usable CUDA device, hence
        # the guard. Wait for queued GPU work before freeing anything.
        torch.cuda.synchronize()
    # Collect before emptying the cache so that memory held only by garbage
    # objects can be returned to the allocator first.
    gc.collect()
    if cuda_usable:
        torch.cuda.empty_cache()
        if clear_host_memory:
            torch._C._host_emptyCache()
def available_memory():
    """Return a memory report for the current CUDA device.

    Returns:
        A dict with the device index as a string under ``"gpu"`` and the
        following sizes converted by ``_byte_to_gb``: ``"total_GB"`` and
        ``"free_GB"`` from ``torch.cuda.mem_get_info``, ``"used_GB"``
        (total minus free), ``"allocated_GB"`` from
        ``torch.cuda.memory_allocated`` and ``"reserved_GB"`` from
        ``torch.cuda.memory_reserved``.
    """
    device_index = torch.cuda.current_device()
    free_bytes, total_bytes = torch.cuda.mem_get_info(device_index)

    # Gather the raw byte counts first; they are converted in one pass below.
    byte_counts = {
        "total_GB": total_bytes,
        "free_GB": free_bytes,
        "used_GB": total_bytes - free_bytes,
        "allocated_GB": torch.cuda.memory_allocated(device_index),
        "reserved_GB": torch.cuda.memory_reserved(device_index),
    }

    report = {"gpu": str(device_index)}
    for label, count in byte_counts.items():
        report[label] = _byte_to_gb(count)
    return report
def _byte_to_gb(n: int):
return round(n / (1024**3), 2)
def print_memory(msg, clear_before_print: bool = False):
    """Log the current CUDA memory report and return it.

    Args:
        msg: Label inserted into the log line after ``Memory-Usage``.
        clear_before_print: If True, call ``clear_memory()`` before the
            memory report is taken, and note that in the log line.

    Returns:
        The dict produced by ``available_memory()``.
    """
    if clear_before_print:
        clear_memory()
    memory_info = available_memory()

    # dist.get_rank() raises when no default process group has been
    # initialised (e.g. a single-process run), so report rank 0 in that case.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0

    cleared_note = " (cleared before print)" if clear_before_print else ""
    # Log on every rank, because different ranks can behave differently.
    logger.info(
        "[Rank %s] Memory-Usage %s%s: %s", rank, msg, cleared_note, memory_info
    )
    return memory_info
|