| import torch |
| import time |
| import functools |
| import logging |
| import os |
| import psutil |
| import gc |
| try: |
| from app.utils.logging_utils import setup_logger |
| except ImportError: |
| |
| from behavior_backend.app.utils.logging_utils import setup_logger |
|
|
| |
# Module-level logger configured through the project's logging utilities.
logger = setup_logger(__name__)
|
|
def get_system_memory_info():
    """
    Collect system-wide and current-process memory statistics.

    Returns:
        dict: ``total`` and ``available`` system memory in GiB,
        ``percent_used`` as a percentage, and ``process_usage`` —
        the resident set size of this process in GiB.
    """
    gib = 1024 ** 3
    vmem = psutil.virtual_memory()
    process_rss = psutil.Process(os.getpid()).memory_info().rss
    return {
        "total": vmem.total / gib,
        "available": vmem.available / gib,
        "percent_used": vmem.percent,
        "process_usage": process_rss / gib,
    }
|
|
def log_memory_usage(message=""):
    """
    Emit a single INFO log line summarizing current memory usage.

    Args:
        message: Optional context string inserted into the log line.
    """
    info = get_system_memory_info()
    summary = (
        f"Memory usage {message}: "
        f"Total: {info['total']:.2f}GB, "
        f"Available: {info['available']:.2f}GB, "
        f"Used: {info['percent_used']}%, "
        f"Process: {info['process_usage']:.2f}GB"
    )
    logger.info(summary)
|
|
def get_available_device():
    """
    Pick the best usable compute device, verifying each candidate.

    Preference order is CUDA, then MPS, then CPU. A tiny tensor
    operation is executed on each candidate GPU backend so that a
    backend reported as available but actually broken falls through
    to the next option. Low system memory (< 2 GB available) forces
    CPU regardless of GPU availability.

    Returns:
        str: 'cuda', 'mps', or 'cpu' depending on availability
    """
    logger.info("=== GPU DETECTION ===")

    # Guard: with very little free RAM, skip GPU probing entirely.
    mem_info = get_system_memory_info()
    if mem_info['available'] < 2.0:
        logger.warning(f"Low system memory: {mem_info['available']:.2f}GB available. Forcing CPU usage.")
        return "cpu"

    # CUDA: run a tiny computation to prove the backend really works.
    if torch.cuda.is_available():
        try:
            logger.info("CUDA detected - attempting verification")
            probe = torch.tensor([1.0], device="cuda")
            probe = probe + 1.0
            result = probe.item()

            # Move the probe off the GPU and release cached memory.
            probe = probe.cpu()
            torch.cuda.empty_cache()
            logger.info(f" NVIDIA GPU (CUDA) detected and verified working (test result: {result})")
            return "cuda"
        except Exception as e:
            logger.warning(f"CUDA detected but test failed: {e}")
            torch.cuda.empty_cache()

    # MPS (Apple Silicon): same probe strategy; attribute may be absent
    # on older torch builds, hence the getattr guard.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        try:
            probe = torch.zeros(1).to('mps')
            probe = probe + 1
            probe.cpu()
            logger.info(" Apple Silicon GPU (MPS) detected and verified working")
            return "mps"
        except Exception as e:
            logger.warning(f" MPS detected but test failed: {e}")

    logger.info(" No GPU detected or all GPU tests failed, using CPU")
    return "cpu"
|
|
def run_on_device(func):
    """
    Decorator that runs a function on the best available device.

    Behavior:
      * Logs memory usage before and after the call and runs
        ``gc.collect()`` around it to keep peak memory down.
      * Injects ``device=<best device>`` into ``kwargs`` unless the
        caller already supplied one (an explicit ``device`` kwarg wins).
      * If the call fails on MPS with a "SparseMPS" error, retries once
        on CPU; any other exception propagates unchanged.
      * Empties the CUDA cache in ``finally`` when the call ran on CUDA.

    Args:
        func: The function to decorate. It must accept a ``device``
            keyword argument.

    Returns:
        A wrapped function that runs on the best available device.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        log_memory_usage(f"before {func.__name__}")

        # Reduce fragmentation before potentially large allocations.
        gc.collect()

        # Detection always runs (preserves its logging side effects),
        # but an explicitly passed device takes precedence.
        kwargs.setdefault('device', get_available_device())
        # BUGFIX: fallback and cleanup must key off the device actually
        # used for the call, not the auto-detected one -- previously an
        # explicit caller-supplied device='mps' never got the CPU
        # fallback, and an explicit 'cuda' skipped cache cleanup.
        device = kwargs['device']

        try:
            start_time = time.time()
            result = func(*args, **kwargs)
            end_time = time.time()

            logger.debug(f"Function {func.__name__} ran on {device} in {end_time - start_time:.4f} seconds")
            return result
        except Exception as e:
            # Known MPS limitation: sparse ops are unsupported there.
            if "SparseMPS" in str(e) and device == "mps":
                logger.warning(f"MPS error detected: {e}")
                logger.warning("Falling back to CPU for this operation")

                # Retry the call once on CPU.
                kwargs['device'] = 'cpu'
                gc.collect()

                start_time = time.time()
                result = func(*args, **kwargs)
                end_time = time.time()

                logger.debug(f"Function {func.__name__} ran on CPU (fallback) in {end_time - start_time:.4f} seconds")
                return result
            else:
                # Not an MPS-specific failure: let the caller handle it.
                raise
        finally:
            # Always reclaim memory, even when an exception escaped.
            gc.collect()
            if device == 'cuda':
                torch.cuda.empty_cache()

            log_memory_usage(f"after {func.__name__}")

    return wrapper
|
|
| |
# Module-level default device, detected once at import time.
device = get_available_device()