|
|
import gc |
|
|
import threading |
|
|
|
|
|
import psutil |
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
def b2mb(x):
    """Convert a byte count *x* to whole mebibytes (truncated toward zero)."""
    bytes_per_mib = 2**20
    return int(x / bytes_per_mib)
|
|
|
|
|
|
|
|
|
|
|
class TorchTracemalloc:
    """Context manager that tracks GPU (CUDA) and CPU memory usage.

    On exit the following attributes are populated (MiB, via ``b2mb``):
      - ``used`` / ``peaked``: delta / peak-delta of CUDA allocated memory
        across the ``with`` block.
      - ``cpu_used`` / ``cpu_peaked``: delta / peak-delta of this process's
        resident set size across the ``with`` block.

    The CPU peak is sampled by a busy-polling daemon thread; it deliberately
    does not sleep so short-lived peaks are less likely to be missed.
    """

    def __enter__(self):
        gc.collect()
        torch.cuda.empty_cache()
        # reset_peak_memory_stats() is the supported replacement for the
        # deprecated reset_max_memory_allocated(); it resets the same peak
        # counter read back by max_memory_allocated() in __exit__.
        torch.cuda.reset_peak_memory_stats()
        self.begin = torch.cuda.memory_allocated()
        self.process = psutil.Process()

        self.cpu_begin = self.cpu_mem_used()
        # Initialize the peak BEFORE starting the thread: the original set it
        # inside peak_monitor_func, so __exit__ could race the thread and read
        # a not-yet-created attribute (AttributeError) on very short blocks.
        self.cpu_peak = -1
        self.peak_monitoring = True
        self.peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        self.peak_monitor_thread.daemon = True
        self.peak_monitor_thread.start()
        return self

    def cpu_mem_used(self):
        """Get resident set size (bytes) for the current process."""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        # Busy-poll RSS until __exit__ clears the flag. Intentionally no
        # sleep: sleeping would let brief allocation spikes go unobserved.
        while True:
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)
            if not self.peak_monitoring:
                break

    def __exit__(self, *exc):
        self.peak_monitoring = False
        # Join so cpu_peak is final before we read it below; the monitor
        # exits on its next loop iteration, so this returns almost instantly.
        self.peak_monitor_thread.join()

        gc.collect()
        torch.cuda.empty_cache()
        self.end = torch.cuda.memory_allocated()
        self.peak = torch.cuda.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)

        self.cpu_end = self.cpu_mem_used()
        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)
        self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)
|
|
|
|
|
|