from collections import defaultdict

import torch


class MemUsageMonitor:
    device = None
    disabled = False
    opts = None
    data = None

    def __init__(self, name, device):
        self.name = name
        self.device = device
        self.data = defaultdict(int)

        if not torch.cuda.is_available():
            self.disabled = True
        else:
            # Probe the device once up front; if either query fails (e.g. an
            # unsupported backend or driver), disable monitoring instead of
            # failing later on every read.
            try:
                torch.cuda.mem_get_info(self.device.index if self.device.index is not None else torch.cuda.current_device())
                torch.cuda.memory_stats(self.device)
            except Exception:
                self.disabled = True

    def cuda_mem_get_info(self):  # legacy for extensions only
        if self.disabled:
            return 0, 0

        return torch.cuda.mem_get_info(self.device.index if self.device.index is not None else torch.cuda.current_device())

    def reset(self):
        if not self.disabled:
            try:
                torch.cuda.reset_peak_memory_stats(self.device)
                self.data["retries"] = 0
                self.data["oom"] = 0
                # torch.cuda.reset_accumulated_memory_stats(self.device)
                # torch.cuda.reset_max_memory_allocated(self.device)
                # torch.cuda.reset_max_memory_cached(self.device)
            except Exception:
                pass

    def read(self):
        if not self.disabled:
            try:
                # Driver-level view of the whole device: free and total bytes.
                self.data["free"], self.data["total"] = torch.cuda.mem_get_info(self.device.index if self.device.index is not None else torch.cuda.current_device())

                # Allocator-level view from torch's caching allocator. Note that
                # "active.all.current" counts allocation blocks, while the
                # *_bytes keys are measured in bytes.
                torch_stats = torch.cuda.memory_stats(self.device)
                self.data["active"] = torch_stats["active.all.current"]
                self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
                self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
                self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]
                self.data["retries"] = torch_stats["num_alloc_retries"]
                self.data["oom"] = torch_stats["num_ooms"]
                self.data["used"] = self.data["total"] - self.data["free"]
            except Exception:
                self.disabled = True

        return self.data
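
# --- Usage example (a minimal sketch, not part of the original module) ---
# Shows the intended reset/work/read cycle. The monitor name "demo" and the
# matrix-multiply workload are assumptions chosen purely for illustration;
# on a machine without CUDA the monitor disables itself and every counter
# reads as 0, since self.data is a defaultdict(int).
if __name__ == "__main__":
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    monitor = MemUsageMonitor("demo", device)

    monitor.reset()  # clear peak stats and the retry/OOM counters

    if torch.cuda.is_available():
        x = torch.randn(4096, 4096, device=device)  # allocate something measurable
        y = x @ x  # do some work so the allocator has stats to report

    stats = monitor.read()
    for key in ("free", "total", "used", "active_peak", "reserved_peak", "retries", "oom"):
        print(f"{key}: {stats[key]}")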