# NOTE(review): removed Hugging Face web-UI residue that was pasted into the file
# (avatar alt-text, commit message "Fix Hunyuan3D error handling + enhanced logging",
# commit hash 0e805d4, "raw history blame", file-size badge). It was not Python and
# made the module unimportable.
"""GPU memory management utilities."""
import gc
import torch
from typing import Optional
class MemoryManager:
    """Manages GPU memory allocation and cleanup."""

    @staticmethod
    def setup_cuda_optimizations() -> None:
        """Configure CUDA for optimal performance on L4 GPU.

        Enables TF32 matmul/cuDNN kernels (Ampere+ only) and cuDNN
        autotuning. No-op (with a log line) when CUDA is unavailable.
        """
        if not torch.cuda.is_available():
            print("[Memory] CUDA not available")
            return
        # Enable TF32 for faster inference on Ampere+ GPUs
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        # Autotune cuDNN kernels; beneficial when input shapes are stable.
        torch.backends.cudnn.benchmark = True
        print("[Memory] CUDA optimizations enabled:")
        print(f"  - Device: {torch.cuda.get_device_name(0)}")
        print(f"  - Total memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
        print("  - TF32 enabled (20-30% faster)")
        print("  - cuDNN benchmark enabled")

    @staticmethod
    def cleanup_model(model: Optional[object]) -> None:
        """Completely destroy a model and free GPU memory.

        Moves known pipeline components to CPU first, drops the local
        reference, then aggressively collects garbage and empties the
        CUDA cache. Safe to call with ``None`` or with objects missing
        any of the component attributes.

        Note: the caller must also drop its own reference (e.g.
        ``model = None``) for the memory to actually be reclaimed.
        """
        if model is None:
            return
        print("[Memory] Destroying model...")
        # Move components to CPU if present. Use getattr + None check:
        # diffusers-style pipelines may define these attributes but set
        # them to None, in which case hasattr() alone would pass and
        # None.to('cpu') would raise AttributeError.
        for component_name in ('text_encoder', 'unet', 'vae'):
            component = getattr(model, component_name, None)
            if component is not None:
                component.to('cpu')
        # Drop this function's reference to the model
        del model
        # Nuclear garbage collection: repeated passes help break
        # reference cycles before releasing cached CUDA blocks.
        for _ in range(5):
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        # Report memory status
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated(0) / 1e9
            print(f"[Memory] GPU memory after cleanup: {allocated:.2f} GB")

    @staticmethod
    def get_memory_stats() -> dict:
        """Get current GPU memory statistics.

        Returns:
            dict: ``{"available": False}`` when CUDA is absent, otherwise
            allocated/reserved/peak-allocated figures in GB for device 0.
        """
        if not torch.cuda.is_available():
            return {"available": False}
        return {
            "available": True,
            "allocated_gb": torch.cuda.memory_allocated(0) / 1e9,
            "reserved_gb": torch.cuda.memory_reserved(0) / 1e9,
            "max_allocated_gb": torch.cuda.max_memory_allocated(0) / 1e9,
        }