ayjays132's picture
Upload 478 files
101858b verified
"""
Memory Cleanup Module
Adaptive memory cleanup and optimization utilities.
"""
import torch
import gc
import logging
from typing import Optional
logger = logging.getLogger(__name__)
class MemoryCleanup:
    """
    Memory cleanup and optimization utilities.

    Args:
        memory_threshold: Fraction of the current CUDA device's total memory
            (allocated bytes / total bytes) above which memory pressure is
            considered high.
        cleanup_threshold: Lower fraction of total device memory above which
            memory pressure is considered elevated; only affects
            ``memory_pressure_level``.
    """

    def __init__(self, memory_threshold: float = 0.85, cleanup_threshold: float = 0.75):
        self.memory_threshold = memory_threshold
        self.cleanup_threshold = cleanup_threshold
        # 0 = normal, 1 = elevated (> cleanup_threshold),
        # 2 = high (> memory_threshold). Refreshed by check_memory_pressure().
        self.memory_pressure_level = 0
        logger.debug("MemoryCleanup initialized")

    def _pressure_level(self, memory_ratio: float) -> int:
        """Map a usage ratio (allocated / total memory) to pressure level 0, 1 or 2."""
        if memory_ratio > self.memory_threshold:
            return 2
        if memory_ratio > self.cleanup_threshold:
            return 1
        return 0

    def check_memory_pressure(self) -> bool:
        """
        Check if CUDA memory usage is above ``memory_threshold``.

        Usage is the number of bytes currently allocated by PyTorch tensors on
        the current device divided by that device's total memory. Also
        refreshes ``memory_pressure_level``.

        Returns:
            True if memory pressure is high. False if CUDA is unavailable,
            the device reports zero total memory, or the query fails.
        """
        if not torch.cuda.is_available():
            return False
        try:
            device = torch.cuda.current_device()
            memory_allocated = torch.cuda.memory_allocated(device)
            # Capacity must come from the device itself: max_memory_allocated()
            # is only the peak allocation so far, not how much memory exists.
            total_memory = torch.cuda.get_device_properties(device).total_memory
        except Exception:
            # Best-effort check: a failed query must not break callers.
            logger.debug("[CLEANUP] Could not query CUDA memory usage", exc_info=True)
            return False
        # Avoid division by zero
        if total_memory == 0:
            return False
        memory_ratio = memory_allocated / total_memory
        self.memory_pressure_level = self._pressure_level(memory_ratio)
        return memory_ratio > self.memory_threshold

    def adaptive_cleanup(self, tensor_pool=None) -> None:
        """
        Perform adaptive memory cleanup based on current usage.

        Does nothing when CUDA is unavailable. Otherwise trims ``tensor_pool``
        (if given) via ``clear_pool(keep_ratio=0.5)``, and, when
        check_memory_pressure() reports high pressure, runs Python garbage
        collection and then releases PyTorch's cached CUDA blocks.

        Args:
            tensor_pool: Optional tensor pool to clean; must provide
                ``clear_pool(keep_ratio=...)``.
        """
        if not torch.cuda.is_available():
            return
        # Clear unused tensor pools
        if tensor_pool is not None:
            tensor_pool.clear_pool(keep_ratio=0.5)
        # Clear cache if memory pressure is high
        if self.check_memory_pressure():
            # Collect first so tensors kept alive only by reference cycles are
            # returned to the caching allocator before empty_cache() runs;
            # otherwise their blocks would stay cached.
            gc.collect()
            torch.cuda.empty_cache()
        logger.debug("[CLEANUP] Adaptive cleanup performed")

    def emergency_cleanup(self, tensor_pool=None) -> None:
        """
        Perform emergency memory cleanup.

        Clears ``tensor_pool`` entirely (if given) and runs garbage collection.
        When CUDA is available, it then releases PyTorch's cached CUDA blocks.

        Args:
            tensor_pool: Optional tensor pool to clear; must provide
                ``clear_all()``.
        """
        logger.warning("[CLEANUP] Performing emergency memory cleanup")
        # Clear tensor pools
        if tensor_pool is not None:
            tensor_pool.clear_all()
        # Force garbage collection before emptying the cache so that memory
        # freed by the collection can actually be released.
        gc.collect()
        # Clear PyTorch cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("[CLEANUP] Emergency cleanup completed")

    def get_memory_stats(self) -> dict:
        """
        Get current memory statistics.

        Returns:
            A dict with ``memory_pressure_level``, ``memory_threshold`` and
            ``cleanup_threshold``. When CUDA is available, it also contains
            ``cuda_allocated``, ``cuda_reserved``, ``cuda_max_allocated`` and
            ``cuda_max_reserved`` as reported by ``torch.cuda``.
        """
        stats = {
            'memory_pressure_level': self.memory_pressure_level,
            'memory_threshold': self.memory_threshold,
            'cleanup_threshold': self.cleanup_threshold,
        }
        if torch.cuda.is_available():
            stats.update({
                'cuda_allocated': torch.cuda.memory_allocated(),
                'cuda_reserved': torch.cuda.memory_reserved(),
                'cuda_max_allocated': torch.cuda.max_memory_allocated(),
                'cuda_max_reserved': torch.cuda.max_memory_reserved(),
            })
        return stats