"""
Memory management utilities for Pixagram AI Pixel Art Generator
Provides efficient GPU memory management and model offloading
"""
import torch
import gc
import psutil
import os


class MemoryManager:
    """Manages GPU and CPU memory efficiently for model offloading"""
    
    def __init__(self, device='cuda', dtype=torch.float16, verbose=True):
        self.device = device
        self.dtype = dtype
        self.verbose = verbose
        self.models_on_gpu = set()
        
    def offload_to_cpu(self, model, model_name="model"):
        """Move model to CPU and free GPU memory"""
        if model is None:
            return model
            
        try:
            model = model.to("cpu")
            self.models_on_gpu.discard(model_name)
            
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            
            if self.verbose:
                print(f"[MEMORY] Offloaded {model_name} to CPU")
                self.print_memory_status()
                
            return model
        except Exception as e:
            print(f"[MEMORY] Error offloading {model_name}: {e}")
            return model
    
    def load_to_gpu(self, model, model_name="model"):
        """Move model to GPU temporarily"""
        if model is None:
            return model
            
        try:
            model = model.to(self.device)
            self.models_on_gpu.add(model_name)
            
            if self.verbose:
                print(f"[MEMORY] Loaded {model_name} to GPU")
                self.print_memory_status()
                
            return model
        except Exception as e:
            print(f"[MEMORY] Error loading {model_name} to GPU: {e}")
            return model
    
    def cleanup_memory(self, aggressive=True):
        """Perform memory cleanup"""
        # Collect Python references first so tensors freed by the GC are
        # returned to the CUDA caching allocator before the cache is emptied
        if aggressive:
            # Multiple GC passes to break reference cycles across generations
            for _ in range(3):
                gc.collect()
        else:
            gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            
        if self.verbose:
            self.print_memory_status()
    
    def print_memory_status(self):
        """Print current memory usage"""
        if torch.cuda.is_available():
            allocated_gb = torch.cuda.memory_allocated() / 1024**3
            reserved_gb = torch.cuda.memory_reserved() / 1024**3
            print(f"  GPU: {allocated_gb:.2f}GB allocated, {reserved_gb:.2f}GB reserved")
        
        # CPU memory status
        process = psutil.Process(os.getpid())
        cpu_mb = process.memory_info().rss / 1024**2
        print(f"  CPU: {cpu_mb:.0f}MB used")
        
    def get_available_gpu_memory(self):
        """Get available GPU memory in GB"""
        if not torch.cuda.is_available():
            return 0.0

        # Query the active device rather than assuming index 0; note that
        # total - reserved ignores memory held by other processes
        device_index = torch.cuda.current_device()
        total = torch.cuda.get_device_properties(device_index).total_memory
        reserved = torch.cuda.memory_reserved(device_index)
        return (total - reserved) / 1024**3
    
    def can_fit_on_gpu(self, estimated_gb):
        """Check if model of estimated size can fit on GPU"""
        available = self.get_available_gpu_memory()
        # Leave 1GB buffer for safety
        return available > (estimated_gb + 1.0)
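

# Illustrative sketch of the intended load/offload cycle; `unet` and
# `run_denoising` are placeholders, not names from this project:
#
#     mm = MemoryManager(device='cuda', dtype=torch.float16)
#     unet = mm.load_to_gpu(unet, "unet")
#     latents = run_denoising(unet)
#     unet = mm.offload_to_cpu(unet, "unet")
#     mm.cleanup_memory(aggressive=True)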


class ModelOffloader:
    """Context manager for temporary GPU loading"""
    
    def __init__(self, model, memory_manager, model_name="model"):
        self.model = model
        self.memory_manager = memory_manager
        self.model_name = model_name
        self.was_on_gpu = False
        
    def __enter__(self):
        """Move model to GPU"""
        if self.model is not None:
            # Diffusers pipelines expose .device; plain nn.Modules do not,
            # so fall back to the device of the first parameter
            device = getattr(self.model, 'device', None)
            if device is None and hasattr(self.model, 'parameters'):
                device = next(self.model.parameters(), torch.empty(0)).device
            if device is not None:
                self.was_on_gpu = (device.type == 'cuda')
                if not self.was_on_gpu:
                    self.model = self.memory_manager.load_to_gpu(self.model, self.model_name)
        return self.model
        
    def __exit__(self, exc_type, exc_val, exc_tb):
        """Move model back to CPU if it wasn't on GPU before"""
        if self.model is not None and not self.was_on_gpu:
            self.model = self.memory_manager.offload_to_cpu(self.model, self.model_name)
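

# Illustrative usage; `text_encoder` and `input_ids` are placeholders.
# The model is lifted to the GPU only for the duration of the `with`
# block, then returned to the CPU if that is where it started:
#
#     mm = MemoryManager()
#     with ModelOffloader(text_encoder, mm, "text_encoder") as te:
#         embeddings = te(input_ids)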


def optimize_for_zero_gpu(pipe):
    """
    Optimize pipeline for Hugging Face Spaces Zero GPU
    This ensures models stay on CPU until @spaces.GPU decorator activates
    """
    if hasattr(pipe, 'enable_model_cpu_offload'):
        pipe.enable_model_cpu_offload()
        print("[MEMORY] Enabled model CPU offloading for Zero GPU")
    
    if hasattr(pipe, 'enable_vae_slicing'):
        pipe.enable_vae_slicing()
        print("[MEMORY] Enabled VAE slicing for memory efficiency")
    
    if hasattr(pipe, 'enable_vae_tiling'):
        pipe.enable_vae_tiling()
        print("[MEMORY] Enabled VAE tiling for memory efficiency")
    
    return pipe
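

# On Zero GPU Spaces this is typically paired with the `spaces` decorator,
# roughly as below (a sketch; `generate` and the pipeline construction are
# placeholders):
#
#     import spaces
#     pipe = optimize_for_zero_gpu(pipe)
#
#     @spaces.GPU
#     def generate(prompt):
#         return pipe(prompt).images[0]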


def estimate_model_size(model):
    """Estimate model size in GB from actual parameter and buffer dtypes"""
    if model is None:
        return 0.0

    # element_size() gives bytes per element, so this handles
    # fp32/fp16/bf16/int8 weights alike instead of assuming float16
    total_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    total_bytes += sum(b.numel() * b.element_size() for b in model.buffers())
    return total_bytes / 1024**3
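

# Sketch: gate a GPU load on the size estimate (`controlnet` is a
# hypothetical name); can_fit_on_gpu already includes the 1GB buffer.
#
#     mm = MemoryManager()
#     if mm.can_fit_on_gpu(estimate_model_size(controlnet)):
#         controlnet = mm.load_to_gpu(controlnet, "controlnet")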


print("[OK] Memory management utilities loaded")