# pixagram-neo-backup / memory_utils.py
"""
Memory management utilities for Pixagram AI Pixel Art Generator
Provides efficient GPU memory management and model offloading
"""
import torch
import gc
import psutil
import os


class MemoryManager:
    """Manages GPU and CPU memory efficiently for model offloading"""

    def __init__(self, device='cuda', dtype=torch.float16, verbose=True):
        self.device = device
        self.dtype = dtype
        self.verbose = verbose
        self.models_on_gpu = set()

    def offload_to_cpu(self, model, model_name="model"):
        """Move model to CPU and free GPU memory"""
        if model is None:
            return model
        try:
            model = model.to("cpu")
            self.models_on_gpu.discard(model_name)
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            if self.verbose:
                print(f"[MEMORY] Offloaded {model_name} to CPU")
                self.print_memory_status()
            return model
        except Exception as e:
            print(f"[MEMORY] Error offloading {model_name}: {e}")
            return model

    def load_to_gpu(self, model, model_name="model"):
        """Move model to GPU temporarily"""
        if model is None:
            return model
        try:
            model = model.to(self.device)
            self.models_on_gpu.add(model_name)
            if self.verbose:
                print(f"[MEMORY] Loaded {model_name} to GPU")
                self.print_memory_status()
            return model
        except Exception as e:
            print(f"[MEMORY] Error loading {model_name} to GPU: {e}")
            return model

    def cleanup_memory(self, aggressive=True):
        """Perform memory cleanup"""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        if aggressive:
            # Multiple GC passes for thorough cleanup
            for _ in range(3):
                gc.collect()
        else:
            gc.collect()
        if self.verbose:
            self.print_memory_status()

    def print_memory_status(self):
        """Print current memory usage"""
        if torch.cuda.is_available():
            allocated_gb = torch.cuda.memory_allocated() / 1024**3
            reserved_gb = torch.cuda.memory_reserved() / 1024**3
            print(f"  GPU: {allocated_gb:.2f}GB allocated, {reserved_gb:.2f}GB reserved")
        # CPU memory status
        process = psutil.Process(os.getpid())
        cpu_mb = process.memory_info().rss / 1024**2
        print(f"  CPU: {cpu_mb:.0f}MB used")

    def get_available_gpu_memory(self):
        """Get available GPU memory in GB"""
        if not torch.cuda.is_available():
            return 0
        return (torch.cuda.get_device_properties(0).total_memory -
                torch.cuda.memory_reserved()) / 1024**3

    def can_fit_on_gpu(self, estimated_gb):
        """Check if model of estimated size can fit on GPU"""
        available = self.get_available_gpu_memory()
        # Leave 1GB buffer for safety
        return available > (estimated_gb + 1.0)
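

# Example usage of MemoryManager (a minimal sketch; `unet` is a placeholder for
# any torch.nn.Module you want to shuttle between CPU and GPU):
#
#     manager = MemoryManager(device='cuda', verbose=True)
#     if manager.can_fit_on_gpu(estimate_model_size(unet)):
#         unet = manager.load_to_gpu(unet, "unet")
#     ...  # run inference while the model is on the GPU
#     unet = manager.offload_to_cpu(unet, "unet")
#     manager.cleanup_memory(aggressive=True)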


class ModelOffloader:
    """Context manager for temporary GPU loading"""

    def __init__(self, model, memory_manager, model_name="model"):
        self.model = model
        self.memory_manager = memory_manager
        self.model_name = model_name
        self.was_on_gpu = False

    def __enter__(self):
        """Move model to GPU"""
        if self.model is not None and hasattr(self.model, 'device'):
            self.was_on_gpu = (self.model.device.type == 'cuda')
            if not self.was_on_gpu:
                self.model = self.memory_manager.load_to_gpu(self.model, self.model_name)
        return self.model

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Move model back to CPU if it wasn't on GPU before"""
        if self.model is not None and not self.was_on_gpu:
            self.model = self.memory_manager.offload_to_cpu(self.model, self.model_name)
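

# Example usage of ModelOffloader (an illustrative sketch; `controlnet` and
# `inputs` are placeholders, not defined in this module):
#
#     manager = MemoryManager()
#     with ModelOffloader(controlnet, manager, "controlnet") as gpu_model:
#         result = gpu_model(*inputs)  # forward pass runs while the model is on GPU
#     # after the block, controlnet is moved back to CPU unless it started on GPU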


def optimize_for_zero_gpu(pipe):
    """
    Optimize pipeline for Hugging Face Spaces Zero GPU
    This ensures models stay on CPU until @spaces.GPU decorator activates
    """
    if hasattr(pipe, 'enable_model_cpu_offload'):
        pipe.enable_model_cpu_offload()
        print("[MEMORY] Enabled model CPU offloading for Zero GPU")
    if hasattr(pipe, 'enable_vae_slicing'):
        pipe.enable_vae_slicing()
        print("[MEMORY] Enabled VAE slicing for memory efficiency")
    if hasattr(pipe, 'enable_vae_tiling'):
        pipe.enable_vae_tiling()
        print("[MEMORY] Enabled VAE tiling for memory efficiency")
    return pipe
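

# Example usage (a sketch assuming a diffusers pipeline running on a Spaces
# Zero GPU runtime; the model id and the generate() wrapper are illustrative only):
#
#     import spaces
#     from diffusers import StableDiffusionXLPipeline
#
#     pipe = StableDiffusionXLPipeline.from_pretrained(
#         "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
#     )
#     pipe = optimize_for_zero_gpu(pipe)
#
#     @spaces.GPU
#     def generate(prompt):
#         return pipe(prompt).images[0]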


def estimate_model_size(model):
    """Estimate model size in GB"""
    if model is None:
        return 0
    total_bytes = 0
    for param in model.parameters():
        # Use each parameter's actual dtype size (2 bytes for float16, 4 for float32)
        total_bytes += param.numel() * param.element_size()
    return total_bytes / 1024**3
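

# Example (a minimal sketch using a throwaway torch module to sanity-check
# the estimate):
#
#     tiny = torch.nn.Linear(1024, 1024)            # ~1.05M float32 parameters
#     print(f"{estimate_model_size(tiny):.4f} GB")  # parameter memory footprint
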
print("[OK] Memory management utilities loaded")