| import gc | |
| import logging | |
| import torch | |
| import torch.distributed as dist | |
# Module-level logger; print_memory emits its memory reports through it at INFO level.
logger = logging.getLogger(__name__)
def clear_memory(clear_host_memory: bool = False):
    """Free cached CUDA memory on the current device, optionally host cache too.

    Runs, in sequence: ``torch.cuda.synchronize()``, ``gc.collect()`` and
    ``torch.cuda.empty_cache()``.

    Args:
        clear_host_memory: When True, additionally calls the private binding
            ``torch._C._host_emptyCache()``. NOTE(review): presumably this
            releases PyTorch's cached pinned host buffers; being a private
            API it may be missing in some torch versions -- confirm.
    """
    # Drain queued CUDA work first, then drop unreachable Python objects (and
    # the tensor storage they hold), and only then return unused cached blocks
    # to the driver.
    for release_step in (torch.cuda.synchronize, gc.collect, torch.cuda.empty_cache):
        release_step()
    if clear_host_memory:
        torch._C._host_emptyCache()
def available_memory():
    """Return a memory-usage snapshot for the current CUDA device.

    Returns:
        dict with, in this key order:
          - "gpu": the current device index as a string.
          - "total_GB", "free_GB": values from ``torch.cuda.mem_get_info``.
          - "used_GB": total minus free as reported by ``mem_get_info``.
          - "allocated_GB": ``torch.cuda.memory_allocated`` for the device.
          - "reserved_GB": ``torch.cuda.memory_reserved`` for the device.
        Every size is converted via ``_byte_to_gb`` (bytes / 1024**3,
        rounded to 2 decimals).
    """
    idx = torch.cuda.current_device()
    free_bytes, total_bytes = torch.cuda.mem_get_info(idx)
    byte_counts = (
        ("total_GB", total_bytes),
        ("free_GB", free_bytes),
        ("used_GB", total_bytes - free_bytes),
        ("allocated_GB", torch.cuda.memory_allocated(idx)),
        ("reserved_GB", torch.cuda.memory_reserved(idx)),
    )
    report = {"gpu": str(idx)}
    report.update((key, _byte_to_gb(count)) for key, count in byte_counts)
    return report
| def _byte_to_gb(n: int): | |
| return round(n / (1024**3), 2) | |
def print_memory(msg, clear_before_print: bool = False):
    """Log this process's CUDA memory snapshot at INFO level and return it.

    Args:
        msg: Label embedded in the log line to identify the call site.
        clear_before_print: If True, call ``clear_memory()`` first (default
            arguments, so host memory is not cleared) and append
            " (cleared before print)" to the log label.

    Returns:
        The dict produced by ``available_memory()``.
    """
    if clear_before_print:
        clear_memory()
    memory_info = available_memory()
    # dist.get_rank() raises when no default process group is initialized
    # (e.g. single-process runs), and torch.distributed may be unavailable in
    # this build altogether; report rank 0 in those cases instead of crashing.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = 0
    suffix = " (cleared before print)" if clear_before_print else ""
    # Need to print for all ranks, b/c different rank can have different behaviors
    logger.info("[Rank %s] Memory-Usage %s%s: %s", rank, msg, suffix, memory_info)
    return memory_info