"""GPU memory management utilities."""

import gc
from typing import Optional

import torch


class MemoryManager:
    """Manages GPU memory allocation and cleanup."""
    
    @staticmethod
    def setup_cuda_optimizations() -> None:
        """Configure CUDA for optimal performance on L4 GPU."""
        if not torch.cuda.is_available():
            print("[Memory] CUDA not available")
            return
        
        # Enable TF32 for faster inference on Ampere+ GPUs
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        
        print("[Memory] CUDA optimizations enabled:")
        print(f"  - Device: {torch.cuda.get_device_name(0)}")
        print(f"  - Total memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
        print(f"  - TF32 enabled (20-30% faster)")
        print(f"  - cuDNN benchmark enabled")
    
    @staticmethod
    def cleanup_model(model: Optional[object]) -> None:
        """Completely destroy a model and free GPU memory."""
        if model is None:
            return
        
        print("[Memory] Destroying model...")
        
        # Move known submodules to CPU first so their CUDA tensors are freed.
        # Guard against attributes that exist but are None (e.g. optional
        # pipeline components).
        for name in ('text_encoder', 'unet', 'vae'):
            component = getattr(model, name, None)
            if component is not None:
                component.to('cpu')
        
        # Drop the local reference so the object becomes collectable
        del model
        
        # Nuclear garbage collection
        for _ in range(5):
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
        
        # Report memory status
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated(0) / 1e9
            print(f"[Memory] GPU memory after cleanup: {allocated:.2f} GB")
    
    @staticmethod
    def get_memory_stats() -> dict:
        """Get current GPU memory statistics."""
        if not torch.cuda.is_available():
            return {"available": False}
        
        return {
            "available": True,
            "allocated_gb": torch.cuda.memory_allocated(0) / 1e9,
            "reserved_gb": torch.cuda.memory_reserved(0) / 1e9,
            "max_allocated_gb": torch.cuda.max_memory_allocated(0) / 1e9,
        }