"""GPU memory management utilities.""" import gc import torch from typing import Optional class MemoryManager: """Manages GPU memory allocation and cleanup.""" @staticmethod def setup_cuda_optimizations() -> None: """Configure CUDA for optimal performance on L4 GPU.""" if not torch.cuda.is_available(): print("[Memory] CUDA not available") return # Enable TF32 for faster inference on Ampere+ GPUs torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.benchmark = True print("[Memory] CUDA optimizations enabled:") print(f" - Device: {torch.cuda.get_device_name(0)}") print(f" - Total memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB") print(f" - TF32 enabled (20-30% faster)") print(f" - cuDNN benchmark enabled") @staticmethod def cleanup_model(model: Optional[object]) -> None: """Completely destroy a model and free GPU memory.""" if model is None: return print("[Memory] Destroying model...") # Move components to CPU if they exist if hasattr(model, 'text_encoder'): model.text_encoder.to('cpu') if hasattr(model, 'unet'): model.unet.to('cpu') if hasattr(model, 'vae'): model.vae.to('cpu') # Delete model del model # Nuclear garbage collection for _ in range(5): gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() # Report memory status if torch.cuda.is_available(): allocated = torch.cuda.memory_allocated(0) / 1e9 print(f"[Memory] GPU memory after cleanup: {allocated:.2f} GB") @staticmethod def get_memory_stats() -> dict: """Get current GPU memory statistics.""" if not torch.cuda.is_available(): return {"available": False} return { "available": True, "allocated_gb": torch.cuda.memory_allocated(0) / 1e9, "reserved_gb": torch.cuda.memory_reserved(0) / 1e9, "max_allocated_gb": torch.cuda.max_memory_allocated(0) / 1e9, }