| | import gc |
| | import threading |
| |
|
| | import psutil |
| | import torch |
| |
|
| |
|
| | |
def b2mb(x):
    """Convert a byte count to whole megabytes (MiB), truncating toward zero."""
    return int(x / (1 << 20))
| |
|
| |
|
| | |
class TorchTracemalloc:
    """Context manager that measures GPU and CPU memory consumed by its body.

    On exit it exposes, in MiB (converted via ``b2mb``):
      * ``used`` / ``peaked``         -- GPU memory delta / peak above the entry baseline
      * ``cpu_used`` / ``cpu_peaked`` -- process RSS delta / peak above the entry baseline

    CPU peak usage is sampled by a daemon thread that busy-polls RSS; it
    deliberately does not sleep, because sleeping would miss short-lived peaks.
    """

    def __enter__(self):
        gc.collect()
        torch.cuda.empty_cache()
        # reset_max_memory_allocated() is deprecated; reset_peak_memory_stats()
        # is the supported API for clearing the peak counters.
        torch.cuda.reset_peak_memory_stats()
        self.begin = torch.cuda.memory_allocated()
        self.process = psutil.Process()

        self.cpu_begin = self.cpu_mem_used()
        # Initialize the peak BEFORE starting the thread so __exit__ never
        # reads an attribute the monitor thread has not created yet.
        self.cpu_peak = self.cpu_begin
        self.peak_monitoring = True
        # Keep a handle to the thread so __exit__ can join it; otherwise
        # cpu_peak could be updated after cpu_peaked has been computed.
        self.peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        self.peak_monitor_thread.daemon = True
        self.peak_monitor_thread.start()
        return self

    def cpu_mem_used(self):
        """Return the resident set size (RSS) of the current process, in bytes."""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        """Busy-poll RSS and record the maximum seen while monitoring is on.

        Intentionally has no sleep: a sleep would let short allocation spikes
        slip between samples.
        """
        while True:
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)
            if not self.peak_monitoring:
                break

    def __exit__(self, *exc):
        self.peak_monitoring = False
        # Wait for the monitor thread to observe the flag and take its final
        # sample, so cpu_peak is stable before we read it below.
        self.peak_monitor_thread.join()

        gc.collect()
        torch.cuda.empty_cache()
        self.end = torch.cuda.memory_allocated()
        self.peak = torch.cuda.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)

        self.cpu_end = self.cpu_mem_used()
        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)
        self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)
| | |
| |
|