# behavior_backend/app/utils/device_utils.py
# Author: hibatorrahmen — "Add backend application and Dockerfile" (commit 8ae78b0)
import torch
import time
import functools
import logging
import os
import psutil
import gc
# Prefer the package-absolute import (how the app runs inside Docker); fall
# back to the repo-root-qualified path so the module also imports when run
# directly from the project root.
try:
    from app.utils.logging_utils import setup_logger
except ImportError:
    # Try relative imports for running from project root
    from behavior_backend.app.utils.logging_utils import setup_logger
# Configure logging — module-level logger via the project's shared helper.
logger = setup_logger(__name__)
def get_system_memory_info():
    """
    Return a snapshot of system-wide and per-process memory usage.

    Returns:
        dict: ``total`` and ``available`` system memory in GB,
        ``percent_used`` as reported by psutil, and ``process_usage``
        (resident set size of the current process) in GB.
    """
    gib = 1024 ** 3  # bytes per GiB
    vm = psutil.virtual_memory()
    rss = psutil.Process(os.getpid()).memory_info().rss
    return {
        "total": vm.total / gib,
        "available": vm.available / gib,
        "percent_used": vm.percent,
        "process_usage": rss / gib,
    }
def log_memory_usage(message=""):
    """
    Emit a single INFO log line summarising current memory usage.

    Args:
        message: Optional context string included in the log line
            (e.g. "before my_func").
    """
    info = get_system_memory_info()
    summary = (
        f"Memory usage {message}: "
        f"Total: {info['total']:.2f}GB, "
        f"Available: {info['available']:.2f}GB, "
        f"Used: {info['percent_used']}%, "
        f"Process: {info['process_usage']:.2f}GB"
    )
    logger.info(summary)
def get_available_device():
    """
    Pick the best torch device that is verified to actually work.

    Each candidate backend (CUDA first, then MPS) is exercised with a tiny
    tensor operation before being selected, so a backend that is "available"
    but broken falls through to the next option. Very low system memory
    forces CPU regardless of GPU availability.

    Returns:
        str: 'cuda', 'mps', or 'cpu' depending on availability
    """
    logger.info("=== GPU DETECTION ===")

    # Guard: with under 2 GB of free RAM, GPU use is too risky — bail to CPU.
    available_gb = get_system_memory_info()['available']
    if available_gb < 2.0:
        logger.warning(f"Low system memory: {available_gb:.2f}GB available. Forcing CPU usage.")
        return "cpu"

    # Candidate 1: NVIDIA CUDA.
    if torch.cuda.is_available():
        try:
            logger.info("CUDA detected - attempting verification")
            probe = torch.tensor([1.0], device="cuda")
            probe = probe + 1.0
            # Round-trip the value back to host to prove the op really ran.
            result = probe.item()
            probe = probe.cpu()       # release the device-side tensor
            torch.cuda.empty_cache()  # return cached memory to the driver
            logger.info(f" NVIDIA GPU (CUDA) detected and verified working (test result: {result})")
            return "cuda"
        except Exception as e:
            logger.warning(f"CUDA detected but test failed: {e}")
            torch.cuda.empty_cache()

    # Candidate 2: Apple Silicon MPS (guarded hasattr for older torch builds).
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        try:
            probe = torch.zeros(1).to('mps')
            probe = probe + 1
            probe.cpu()  # move back to host to free MPS memory
            logger.info(" Apple Silicon GPU (MPS) detected and verified working")
            return "mps"
        except Exception as e:
            logger.warning(f" MPS detected but test failed: {e}")

    # Last resort: plain CPU always works.
    logger.info(" No GPU detected or all GPU tests failed, using CPU")
    return "cpu"
def run_on_device(func):
    """
    Decorator that runs a function on the best available torch device.

    Injects a ``device`` keyword argument (auto-detected unless the caller
    already supplied one), logs memory usage before and after the call, and
    transparently retries on CPU when the known "SparseMPS" backend error
    occurs on Apple Silicon.

    Args:
        func: The function to decorate. It must accept a ``device`` kwarg.

    Returns:
        A wrapped function that runs on the best available device.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Log memory before operation
        log_memory_usage(f"before {func.__name__}")
        # Force garbage collection before a potentially large allocation
        gc.collect()

        # Honor an explicit caller choice; only auto-detect when absent.
        # (The original always auto-detected, so when the caller passed
        # device=... explicitly, the debug log, the SparseMPS-fallback check,
        # and the CUDA cache cleanup below all used the wrong device.)
        device = kwargs.get('device')
        if device is None:
            device = get_available_device()
            kwargs['device'] = device
        try:
            start_time = time.time()
            result = func(*args, **kwargs)
            end_time = time.time()
            logger.debug(f"Function {func.__name__} ran on {device} in {end_time - start_time:.4f} seconds")
            return result
        except Exception as e:
            # Sparse tensor ops are unsupported on the MPS backend; retry on CPU.
            if "SparseMPS" in str(e) and device == "mps":
                logger.warning(f"MPS error detected: {e}")
                logger.warning("Falling back to CPU for this operation")
                # Update device to CPU and retry
                kwargs['device'] = 'cpu'
                # Force garbage collection before retry
                gc.collect()
                start_time = time.time()
                result = func(*args, **kwargs)
                end_time = time.time()
                logger.debug(f"Function {func.__name__} ran on CPU (fallback) in {end_time - start_time:.4f} seconds")
                return result
            # Re-raise other exceptions
            raise
        finally:
            # Force garbage collection after operation
            gc.collect()
            if device == 'cuda':
                torch.cuda.empty_cache()
            # Log memory after operation
            log_memory_usage(f"after {func.__name__}")
    return wrapper
# Initialize device once at module level
# NOTE: this runs GPU probing (and its logging) as an import-time side effect;
# consumers can simply `from ...device_utils import device`.
device = get_available_device()