import functools
import gc
import logging
import os
import time

import psutil
import torch

try:
    from app.utils.logging_utils import setup_logger
except ImportError:
    # Try absolute imports for running from project root
    from behavior_backend.app.utils.logging_utils import setup_logger

# Configure logging
logger = setup_logger(__name__)


def get_system_memory_info():
    """
    Get system memory information.

    Returns:
        dict: Memory information with keys:
            total (float): total system RAM in GB.
            available (float): currently available RAM in GB.
            percent_used (float): system-wide usage percentage.
            process_usage (float): this process's RSS in GB.
    """
    memory = psutil.virtual_memory()
    return {
        "total": memory.total / (1024 ** 3),  # GB
        "available": memory.available / (1024 ** 3),  # GB
        "percent_used": memory.percent,
        "process_usage": psutil.Process(os.getpid()).memory_info().rss / (1024 ** 3),  # GB
    }


def log_memory_usage(message=""):
    """
    Log current memory usage.

    Args:
        message: Optional message to include in the log
    """
    mem_info = get_system_memory_info()
    logger.info(f"Memory usage {message}: "
                f"Total: {mem_info['total']:.2f}GB, "
                f"Available: {mem_info['available']:.2f}GB, "
                f"Used: {mem_info['percent_used']}%, "
                f"Process: {mem_info['process_usage']:.2f}GB")


def get_available_device():
    """
    Determine the best available device with proper error handling.

    Each candidate backend is verified with a tiny tensor operation before
    being reported, because `is_available()` alone does not guarantee the
    device actually works (driver/toolkit mismatches, broken MPS builds).

    Returns:
        str: 'cuda', 'mps', or 'cpu' depending on availability
    """
    logger.info("=== GPU DETECTION ===")

    # Check available memory first; with very little free RAM, GPU setup
    # (driver buffers, staging copies) is likely to fail or thrash.
    mem_info = get_system_memory_info()
    if mem_info['available'] < 2.0:  # Less than 2GB available
        # NOTE: original literal was split by a raw newline (syntax error);
        # reconstructed as a single-line message.
        logger.warning(f"Low system memory: {mem_info['available']:.2f}GB available. Forcing CPU usage.")
        return "cpu"

    # First try CUDA (NVIDIA GPUs)
    if torch.cuda.is_available():
        try:
            # Simplified CUDA test with better error handling
            logger.info("CUDA detected - attempting verification")

            # Use a smaller and simpler operation
            test_tensor = torch.tensor([1.0], device="cuda")
            test_tensor = test_tensor + 1.0  # Simple operation
            result = test_tensor.item()  # Get the value back to validate operation

            # If we get here, the CUDA operation worked
            test_tensor = test_tensor.cpu()  # Move back to CPU to free CUDA memory
            torch.cuda.empty_cache()  # Clear CUDA cache

            logger.info(f" NVIDIA GPU (CUDA) detected and verified working (test result: {result})")
            return "cuda"
        except Exception as e:
            logger.warning(f"CUDA detected but test failed: {e}")
            torch.cuda.empty_cache()  # Clear CUDA cache

    # Then try MPS (Apple Silicon)
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        try:
            # Test MPS with a small operation
            test_tensor = torch.zeros(1).to('mps')
            test_tensor = test_tensor + 1
            test_tensor.cpu()  # Move back to CPU to free MPS memory
            logger.info(" Apple Silicon GPU (MPS) detected and verified working")
            return "mps"
        except Exception as e:
            logger.warning(f" MPS detected but test failed: {e}")

    # Fall back to CPU
    logger.info(" No GPU detected or all GPU tests failed, using CPU")
    return "cpu"


def run_on_device(func):
    """
    Decorator to run a function on the best available device.

    The wrapped function receives a `device` keyword argument (unless the
    caller supplied one), is timed, and — when an MPS sparse-tensor error
    ("SparseMPS") occurs — is transparently retried once on CPU. Memory is
    logged and garbage-collected around each call.

    Args:
        func: The function to decorate

    Returns:
        A wrapped function that runs on the best available device
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Log memory before operation
        log_memory_usage(f"before {func.__name__}")

        # Force garbage collection before operation
        gc.collect()

        # Get device if not already specified
        device = get_available_device()

        # Add device to kwargs if not already present
        if 'device' not in kwargs:
            kwargs['device'] = device

        # BUGFIX: fallback/cleanup logic must act on the device actually
        # used for the call (the caller may have passed one explicitly),
        # not on the freshly auto-detected one.
        run_device = kwargs['device']

        try:
            start_time = time.time()
            result = func(*args, **kwargs)
            end_time = time.time()

            logger.debug(f"Function {func.__name__} ran on {run_device} in {end_time - start_time:.4f} seconds")
            return result
        except Exception as e:
            # Check if this is the SparseMPS error
            if "SparseMPS" in str(e) and run_device == "mps":
                logger.warning(f"MPS error detected: {e}")
                logger.warning("Falling back to CPU for this operation")

                # Update device to CPU and retry
                kwargs['device'] = 'cpu'
                run_device = 'cpu'

                # Force garbage collection before retry
                gc.collect()

                start_time = time.time()
                result = func(*args, **kwargs)
                end_time = time.time()

                logger.debug(f"Function {func.__name__} ran on CPU (fallback) in {end_time - start_time:.4f} seconds")
                return result
            else:
                # Re-raise other exceptions
                raise
        finally:
            # Force garbage collection after operation
            gc.collect()
            if run_device == 'cuda':
                torch.cuda.empty_cache()

            # Log memory after operation
            log_memory_usage(f"after {func.__name__}")

    return wrapper


# Initialize device once at module level
device = get_available_device()