"""GPU resource management utilities for ZeroGPU compatibility.

This module provides utilities for managing GPU resources, including model
device transfers, cache management, and context managers for automatic cleanup.
"""

import logging
import time
from contextlib import contextmanager
from typing import Any, Iterator, Optional

import torch

from src.config.gpu_config import GPUConfig

logger = logging.getLogger(__name__)


def _describe_model_device(model: torch.nn.Module) -> str:
    """Return a human-readable description of where *model* currently lives.

    Plain ``torch.nn.Module`` objects have no ``device`` attribute, so this
    falls back to the first parameter's (or buffer's) device.
    """
    # Some model wrappers expose a ``device`` property; prefer it when present.
    explicit = getattr(model, "device", None)
    if explicit is not None:
        return str(explicit)
    tensor = next(model.parameters(), None)
    if tensor is None:
        tensor = next(model.buffers(), None)
    return str(tensor.device) if tensor is not None else "unknown device"


def acquire_gpu(model: torch.nn.Module, device: str = "cuda") -> bool:
    """Move a model to the specified GPU device.

    Args:
        model: PyTorch model to move to GPU (moved in place via ``Module.to``).
        device: Target device (default: "cuda").

    Returns:
        bool: True if successful, False otherwise. Errors are logged, not raised.
    """
    try:
        # perf_counter is monotonic; time.time() can jump with clock changes.
        start_time = time.perf_counter()
        target_device = torch.device(device)
        model.to(target_device)
        elapsed = time.perf_counter() - start_time
        logger.debug(
            "Model %s moved to %s in %.3fs", model.__class__.__name__, device, elapsed
        )
        return True
    except Exception as e:
        # NOTE(review): a failure part-way through ``model.to`` may leave some
        # tensors on the target device; callers only see False.
        logger.error("Failed to move model to %s: %s", device, e)
        return False


def release_gpu(model: torch.nn.Module, clear_cache: bool = True) -> bool:
    """Move a model back to CPU and optionally clear the CUDA cache.

    Args:
        model: PyTorch model to move to CPU (moved in place via ``Module.to``).
        clear_cache: Whether to call ``torch.cuda.empty_cache()`` after moving.
            Only honoured when ``GPUConfig.ENABLE_CACHE_CLEARING`` is set and
            CUDA is available.

    Returns:
        bool: True if successful, False otherwise. Errors are logged, not raised.
    """
    try:
        start_time = time.perf_counter()
        model.to(torch.device("cpu"))

        if clear_cache and GPUConfig.ENABLE_CACHE_CLEARING and torch.cuda.is_available():
            torch.cuda.empty_cache()

        elapsed = time.perf_counter() - start_time
        # Exceeding the configured budget is reported, not treated as failure.
        if elapsed > GPUConfig.CLEANUP_TIMEOUT:
            logger.warning(
                "GPU cleanup took %.3fs, exceeding %ss limit",
                elapsed,
                GPUConfig.CLEANUP_TIMEOUT,
            )
        else:
            logger.debug("GPU released in %.3fs", elapsed)

        return True
    except Exception as e:
        logger.error("Failed to release GPU: %s", e)
        return False


@contextmanager
def gpu_context(model: torch.nn.Module, device: str = "cuda") -> Iterator[torch.nn.Module]:
    """Context manager for automatic GPU resource management.

    Acquires the GPU on entry and releases it on exit, even if an exception
    occurs. If acquisition fails, a warning is logged and the model is yielded
    unchanged (on whatever device it was already on); no release is performed
    in that case. On a successful acquisition the model is always moved to CPU
    on exit, regardless of its device before entry.

    Args:
        model: PyTorch model to manage.
        device: Target GPU device (default: "cuda").

    Yields:
        torch.nn.Module: The same model object, on ``device`` if acquisition
        succeeded.

    Example:
        >>> with gpu_context(my_model) as model:
        ...     result = model(input_data)
    """
    acquired = False
    try:
        acquired = acquire_gpu(model, device)
        if not acquired:
            # nn.Module has no generic ``.device`` attribute; use the helper so
            # this fallback path cannot itself raise AttributeError.
            logger.warning(
                "Failed to acquire GPU, model remains on %s",
                _describe_model_device(model),
            )
        yield model
    finally:
        if acquired:
            release_gpu(model, clear_cache=True)


def move_to_device(data: Any, device: torch.device) -> Any:
    """Recursively move tensors to the specified device.

    Handles nested dicts, lists, tuples and namedtuples. Namedtuples keep their
    type; other dict/list/tuple subclasses are rebuilt as the plain builtin.
    Non-tensor leaves are returned unchanged.

    Args:
        data: Data to move (tensor, list, tuple, dict, or other).
        device: Target device.

    Returns:
        Data with all tensors moved to the device.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {k: move_to_device(v, device) for k, v in data.items()}
    if isinstance(data, list):
        return [move_to_device(item, device) for item in data]
    if isinstance(data, tuple):
        moved = [move_to_device(item, device) for item in data]
        # Namedtuples take positional fields; tuple(...) would drop field names.
        if hasattr(data, "_fields"):
            return type(data)(*moved)
        return tuple(moved)
    return data


def get_gpu_memory_info() -> Optional[dict]:
    """Get current GPU memory usage information.

    Returns:
        dict: ``{"allocated_gb": float, "reserved_gb": float}``, each rounded
        to 2 decimals (GiB, i.e. bytes / 1024**3), or None if CUDA is
        unavailable or the query fails.
    """
    if not torch.cuda.is_available():
        return None

    try:
        gib = 1024**3
        allocated = torch.cuda.memory_allocated() / gib
        reserved = torch.cuda.memory_reserved() / gib
        return {
            "allocated_gb": round(allocated, 2),
            "reserved_gb": round(reserved, 2),
        }
    except Exception as e:
        logger.error("Failed to get GPU memory info: %s", e)
        return None


def log_gpu_usage(operation: str) -> None:
    """Log current GPU memory usage for a specific operation.

    Logs nothing when memory information is unavailable.

    Args:
        operation: Description of the operation being performed.
    """
    memory_info = get_gpu_memory_info()
    if memory_info is not None:
        logger.info(
            "[%s] GPU Memory - Allocated: %sGB, Reserved: %sGB",
            operation,
            memory_info["allocated_gb"],
            memory_info["reserved_gb"],
        )