| import torch |
| import gc |
|
|
|
|
def clear_gradients(*args):
    """Drop stale ``.grad`` buffers from every tensor among ``args``.

    Non-tensor arguments and tensors without an accumulated gradient are
    ignored; the gradients are released (set to ``None``) rather than
    zeroed, so the memory can be reclaimed by the allocator.
    """
    tensors = (a for a in args if isinstance(a, torch.Tensor))
    for tensor in tensors:
        if tensor.grad is not None:
            tensor.grad = None
|
|
|
|
| def clear_memory(device): |
| torch._C._cuda_clearCublasWorkspaces() |
| torch._dynamo.reset() |
| gc.collect() |
| torch.cuda.empty_cache() |
| torch.cuda.reset_peak_memory_stats(device) |
|
|
|
|
def peak_memory(f, *args, device, num_iters=3):
    """Return the peak GPU memory, in bytes, allocated while running ``f(*args)``.

    The call is repeated ``num_iters`` times; before each run the cached
    memory, compiled-graph state, and stale gradients are cleared so every
    iteration starts from a clean allocator. The value reported comes from
    the final run.

    Args:
        f: Callable invoked as ``f(*args)``; its return value is discarded.
        *args: Positional arguments forwarded to ``f``. Tensor gradients
            among them are dropped before each run.
        device: CUDA device whose allocator statistics are reset and read.
        num_iters: Number of repetitions (keyword-only; was hard-coded to 3).

    Returns:
        int: ``torch.cuda.max_memory_allocated(device)`` after the last run,
        or 0 when ``num_iters`` is 0.
    """
    memory = 0  # defined up front so num_iters == 0 cannot raise NameError
    for _ in range(num_iters):
        clear_memory(device)
        clear_gradients(*args)

        f(*args)

        # CUDA kernels launch asynchronously; wait for completion so the
        # peak-memory counter reflects the whole run.
        torch.cuda.synchronize()
        memory = torch.cuda.max_memory_allocated(device)

    return memory
|
|
|
|
def current_memory(device):
    """Return the GPU memory currently allocated on ``device``, in GiB."""
    allocated_bytes = torch.cuda.memory_allocated(device)
    return allocated_bytes / 1024 ** 3
|
|
|
|
def memory_measure(f, device, num_iters=3):
    """Measure, print, and return the peak GPU memory used by calling ``f``.

    Prints the currently-allocated memory before measuring and the measured
    peak afterwards, both in GiB.

    Args:
        f: Zero-argument callable to profile.
        device: CUDA device to measure on.
        num_iters: Accepted but currently unused — ``peak_memory`` runs its
            own fixed repeat count. NOTE(review): confirm whether this was
            meant to be forwarded.

    Returns:
        float: Peak memory allocated while running ``f``, in GiB.
    """
    clear_memory(device)

    print("Current memory: ", current_memory(device))

    peak_bytes = peak_memory(f, device=device)
    peak_gib = peak_bytes / 1024 ** 3

    print("Peak memory: ", peak_gib)
    return peak_gib
|
|
|
|
def memory_measure_simple(f, device, *args, **kwargs):
    """Run ``f(*args, **kwargs)`` once and measure the extra peak memory.

    The allocator caches and peak counters are cleared first, the memory
    allocated before the call is taken as a baseline, and the reported
    figure is the peak during the call minus that baseline.

    Args:
        f: Callable to profile; invoked exactly once.
        device: CUDA device to measure on.
        *args: Positional arguments for ``f``; tensor gradients among them
            are dropped before the call.
        **kwargs: Keyword arguments forwarded to ``f``.

    Returns:
        tuple: ``(output, delta_gib)`` — the value returned by ``f`` and the
        additional peak memory it used, in GiB.
    """
    clear_memory(device)
    clear_gradients(*args)

    baseline_gib = current_memory(device)

    out = f(*args, **kwargs)

    # Wait for asynchronous kernels so the peak counter is final.
    torch.cuda.synchronize()
    peak_gib = torch.cuda.max_memory_allocated(device) / 1024 ** 3

    return out, peak_gib - baseline_gib
|
|