""" Memory Cleanup Module Adaptive memory cleanup and optimization utilities. """ import torch import gc import logging from typing import Optional logger = logging.getLogger(__name__) class MemoryCleanup: """ Memory cleanup and optimization utilities. """ def __init__(self, memory_threshold: float = 0.85, cleanup_threshold: float = 0.75): self.memory_threshold = memory_threshold self.cleanup_threshold = cleanup_threshold self.memory_pressure_level = 0 logger.debug("MemoryCleanup initialized") def check_memory_pressure(self) -> bool: """ Check if memory usage is above threshold. Returns: True if memory pressure is high """ if not torch.cuda.is_available(): return False try: memory_allocated = torch.cuda.memory_allocated() max_memory = torch.cuda.max_memory_allocated() # Avoid division by zero if max_memory == 0: return False memory_ratio = memory_allocated / max_memory return memory_ratio > self.memory_threshold except Exception: return False def adaptive_cleanup(self, tensor_pool=None) -> None: """ Perform adaptive memory cleanup based on usage patterns. Args: tensor_pool: Optional tensor pool to clean """ if not torch.cuda.is_available(): return # Clear unused tensor pools if tensor_pool is not None: tensor_pool.clear_pool(keep_ratio=0.5) # Clear cache if memory pressure is high if self.check_memory_pressure(): torch.cuda.empty_cache() gc.collect() logger.debug("[CLEANUP] Adaptive cleanup performed") def emergency_cleanup(self, tensor_pool=None) -> None: """ Perform emergency memory cleanup. Args: tensor_pool: Optional tensor pool to clear """ logger.warning("[CLEANUP] Performing emergency memory cleanup") # Clear tensor pools if tensor_pool is not None: tensor_pool.clear_all() # Clear PyTorch cache if torch.cuda.is_available(): torch.cuda.empty_cache() # Force garbage collection gc.collect() logger.info("[CLEANUP] Emergency cleanup completed") def get_memory_stats(self) -> dict: """Get current memory statistics.""" stats = { 'memory_pressure_level': self.memory_pressure_level, 'memory_threshold': self.memory_threshold } if torch.cuda.is_available(): stats.update({ 'cuda_allocated': torch.cuda.memory_allocated(), 'cuda_reserved': torch.cuda.memory_reserved(), 'cuda_max_allocated': torch.cuda.max_memory_allocated(), 'cuda_max_reserved': torch.cuda.max_memory_reserved() }) return stats