import functools
import gc
import os

import psutil
import torch
def print_memory_usage():
    """Print this process's resident memory and PyTorch's allocated CUDA memory.

    Both figures are converted from bytes by dividing by ``1024 ** 2`` and are
    printed with two decimal places.
    """
    process = psutil.Process(os.getpid())
    # ``.2f`` means two digits after the decimal point; ``2f`` would only set a
    # minimum field width of 2 and still print six decimals.
    print(f"Memory usage: {process.memory_info().rss / 1024 ** 2:.2f} MB")
    # memory_allocated() reports bytes currently occupied by tensors on the
    # current CUDA device (cached-but-unused allocator memory is not included).
    print(f"GPU usage: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
def clear_cuda_and_gc():
    """Free Python garbage and cached CUDA memory, reporting usage before and after.

    Garbage collection runs first so that objects it reclaims can have their
    CUDA blocks returned by the subsequent cache release.
    """
    print_memory_usage()
    print("Clearing cuda and gc")
    for cleanup_step in (clear_gc, clear_cuda):
        cleanup_step()
    print_memory_usage()
def clear_cuda():
    """Release unoccupied cached memory held by PyTorch's CUDA caching allocator.

    Memory still occupied by live tensors is not freed by this call. No
    ``torch.no_grad()`` context is needed: ``empty_cache`` performs no
    autograd-tracked tensor operations.
    """
    torch.cuda.empty_cache()
def clear_gc():
    """Run a full (oldest-generation) garbage-collection pass; returns nothing."""
    full_generation = 2
    gc.collect(full_generation)
def auto_clear_cuda_and_gc(controlnet):
    """Build a decorator that releases resources when the wrapped call fails.

    When the decorated function raises an ``Exception``, ``controlnet.cleanup()``
    is called, then ``clear_cuda_and_gc()``, and the exception is re-raised.
    Successful calls return the function's result unchanged.

    Args:
        controlnet: Object exposing a ``cleanup()`` method, invoked on failure.

    Returns:
        A decorator that adds this cleanup-on-error behaviour to a callable.
    """

    def auto_clear_cuda_and_gc_wrapper(func):
        # Preserve __name__, __doc__ and __wrapped__ of the decorated function.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception:
                try:
                    controlnet.cleanup()
                finally:
                    # Runs even if cleanup() itself raises, so the GC pass and
                    # CUDA cache release are not skipped.
                    # NOTE(review): while this exception is being handled, its
                    # traceback still references func's frames, so objects held
                    # as locals there stay alive during this cleanup.
                    clear_cuda_and_gc()
                # Bare ``raise`` re-raises the in-flight exception with its
                # original traceback.
                raise

        return wrapper

    return auto_clear_cuda_and_gc_wrapper