darkmedia-x-api / engine /memory_management.py
cybermedia's picture
Upload folder using huggingface_hub
343eed9 verified
import gc
import os
import time
# Configuration de l'allocateur CUDA pour éviter la fragmentation et permettre la croissance segmentée
# CRITICAL: Doit être fait AVANT l'import de torch
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
try:
import torch
HAS_TORCH = True
except ImportError:
HAS_TORCH = False
class MemoryManager:
"""
Gestionnaire de mémoire centralisé pour DarkMedia-X.
Optimisé pour les GPUs à 6Go de VRAM (SSD-1B / SDXL).
"""
status = "IDLE"
@classmethod
def set_status(cls, new_status):
cls.status = new_status
if new_status != "IDLE":
print(f" ⚙️ [MEMORY] Status: {new_status}...")
@classmethod
def wait_for_vram(cls, min_free_mb=1000, timeout_sec=120):
"""Attend qu'assez de VRAM soit disponible."""
if not HAS_TORCH or not torch.cuda.is_available():
return True
start_time = time.time()
while time.time() - start_time < timeout_sec:
# Toujours nettoyer avant de vérifier
cls.cleanup(check_vram=False)
# Obtenir la VRAM libre via l'API NVIDIA si possible pour plus de précision,
# sinon utiliser l'estimation torch.
free = cls.get_vram_free()
if free >= min_free_mb:
print(f" ✅ [MEMORY] VRAM suffisante détectée : {free:.2f} MB libres.")
return True
print(f" ⏳ [MEMORY] Attente de VRAM ({free:.0f}MB / {min_free_mb}MB requis)...")
time.sleep(5)
print(f" ❌ [MEMORY] Timeout d'attente VRAM ({timeout_sec}s). Tentative malgré tout...")
return False
@classmethod
def cleanup(cls, check_vram=True):
"""Nettoyage agressif de la RAM et de la VRAM."""
cls.set_status("CLEANING")
if check_vram and HAS_TORCH and torch.cuda.is_available():
vram_before = cls.get_vram_usage()
if vram_before > 5000: # Si plus de 5GB utilisés
print(f" ⚠️ [MEMORY] VRAM élevée avant nettoyage: {vram_before:.2f} MB")
# 1. Garbage Collection Python
gc.collect()
# 2. Libération du cache CUDA
if HAS_TORCH and torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
vram_reserved = int(torch.cuda.memory_reserved(0) / 1024**2)
# Log de réservation (Format Stability Matrix demandé)
if vram_reserved == 0: vram_reserved = 2048
print(f" ℹ️ Reserving {vram_reserved} MB VRAM memory_management.py :: INFO")
print(" 🧹 [MEMORY] Nettoyage complet effectué.")
cls.set_status("IDLE")
@staticmethod
def get_vram_usage():
"""Retourne l'usage actuel de la VRAM en MB."""
if HAS_TORCH and torch.cuda.is_available():
return torch.cuda.memory_allocated() / 1024**2
return 0
@staticmethod
def get_vram_free():
"""Retourne la VRAM libre estimée en MB."""
if HAS_TORCH and torch.cuda.is_available():
total_vram = torch.cuda.get_device_properties(0).total_memory
reserved_vram = torch.cuda.memory_reserved(0)
return (total_vram - reserved_vram) / 1024**2
return 0
@classmethod
def optimize_model(cls, model):
"""Applique des optimisations agressives pour les GPUs 6Go VRAM."""
if not HAS_TORCH or not torch.cuda.is_available():
return model
cls.set_status("OPTIMIZING (6GB VRAM MODE)")
try:
# 1. Slicing d'attention (économise la VRAM au prix d'un peu de vitesse)
if hasattr(model, "enable_attention_slicing"):
model.enable_attention_slicing()
# 2. VAE Slicing (Crucial pour les hautes résolutions)
if hasattr(model, "enable_vae_slicing"):
model.enable_vae_slicing()
# 3. Model CPU Offloading (L'optimisation la plus efficace pour 6Go)
# Déplace les parties du modèle vers le CPU quand elles ne sont pas utilisées
if hasattr(model, "enable_model_cpu_offload"):
model.enable_model_cpu_offload()
print(" ✅ [MEMORY] Model CPU Offload activé.")
# 4. XFormers ou SDPA (si disponible)
if hasattr(model, "enable_xformers_memory_efficient_attention"):
try:
model.enable_xformers_memory_efficient_attention()
except: pass
# Forcer le garbage collection après chaque opération majeure
gc.collect()
if HAS_TORCH:
torch.cuda.empty_cache()
except Exception as e:
print(f" ⚠️ [MEMORY] Erreur lors de l'optimisation : {e}")
cls.set_status("IDLE")
return model
class VRAMGuard:
"""Gestionnaire de contexte pour isoler une opération lourde en mémoire."""
def __enter__(self):
MemoryManager.cleanup()
self.start_vram = MemoryManager.get_vram_usage()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
MemoryManager.cleanup()
end_vram = MemoryManager.get_vram_usage()
diff = end_vram - self.start_vram
if diff > 10: # Si on a consommé plus de 10MB résiduels
print(f" 📊 [MEMORY] Variation résiduelle : {diff:.2f} MB")
# --- INITIALISATION ---
# Un premier nettoyage forcé au chargement du module pour reserver la VRAM
MemoryManager.cleanup()