Spaces: Running on Zero
| """GPU resource management utilities for ZeroGPU compatibility. | |
| This module provides utilities for managing GPU resources, including model device | |
| transfers, cache management, and context managers for automatic cleanup. | |
| """ | |
| import logging | |
| import time | |
| from contextlib import contextmanager | |
| from typing import Any, Optional | |
| import torch | |
| from src.config.gpu_config import GPUConfig | |
| logger = logging.getLogger(__name__) | |
def acquire_gpu(model: torch.nn.Module, device: str = "cuda") -> bool:
    """Transfer a model onto the specified GPU device.

    Args:
        model: PyTorch module to transfer.
        device: Device string to target (default: "cuda").

    Returns:
        bool: True when the transfer succeeded, False on any error.
    """
    try:
        t0 = time.time()
        model.to(torch.device(device))
        # Timing is logged at debug level to keep normal operation quiet.
        logger.debug(
            f"Model {model.__class__.__name__} moved to {device} in {time.time() - t0:.3f}s"
        )
    except Exception as e:
        # Best-effort contract: report failure via the return value, never raise.
        logger.error(f"Failed to move model to {device}: {e}")
        return False
    return True
def release_gpu(model: torch.nn.Module, clear_cache: bool = True) -> bool:
    """Return a model to the CPU and optionally empty the CUDA cache.

    Args:
        model: PyTorch module to move back to CPU.
        clear_cache: When True, also empty the CUDA allocator cache
            (subject to GPUConfig.ENABLE_CACHE_CLEARING and CUDA availability).

    Returns:
        bool: True on success, False if any step failed.
    """
    try:
        t0 = time.time()
        model.to(torch.device("cpu"))
        should_clear = (
            clear_cache
            and GPUConfig.ENABLE_CACHE_CLEARING
            and torch.cuda.is_available()
        )
        if should_clear:
            torch.cuda.empty_cache()
        elapsed = time.time() - t0
        # Flag slow cleanups so ZeroGPU release deadlines can be diagnosed.
        if elapsed > GPUConfig.CLEANUP_TIMEOUT:
            logger.warning(
                f"GPU cleanup took {elapsed:.3f}s, exceeding {GPUConfig.CLEANUP_TIMEOUT}s limit"
            )
        else:
            logger.debug(f"GPU released in {elapsed:.3f}s")
    except Exception as e:
        # Best-effort contract: report failure via the return value, never raise.
        logger.error(f"Failed to release GPU: {e}")
        return False
    return True
@contextmanager
def gpu_context(model: torch.nn.Module, device: str = "cuda"):
    """Context manager for automatic GPU resource management.

    Acquires the GPU on entry and releases it on exit, even if an exception
    occurs inside the ``with`` body.

    Args:
        model: PyTorch model to manage
        device: Target GPU device (default: "cuda")

    Yields:
        torch.nn.Module: The model on the GPU device (unchanged if
        acquisition failed)

    Example:
        >>> with gpu_context(my_model) as model:
        ...     result = model(input_data)
    """
    # BUG FIX: the @contextmanager decorator was missing, so this generator
    # function could not actually be used in a `with` statement (it has no
    # __enter__/__exit__). `contextmanager` was already imported but unused.
    acquired = False
    try:
        acquired = acquire_gpu(model, device)
        if not acquired:
            # BUG FIX: torch.nn.Module has no `.device` attribute; derive the
            # current device from a parameter so this warning path cannot raise.
            try:
                current_device = next(model.parameters()).device
            except StopIteration:
                # Parameter-less module: device is indeterminate.
                current_device = "cpu"
            logger.warning(f"Failed to acquire GPU, model remains on {current_device}")
        yield model
    finally:
        # Release even when the `with` body raised.
        if acquired:
            release_gpu(model, clear_cache=True)
def move_to_device(data: Any, device: torch.device) -> Any:
    """Recursively move tensors to the specified device.

    Handles nested structures like lists, tuples, and dicts. NamedTuple
    instances are rebuilt as their own type rather than degraded to a
    plain tuple.

    Args:
        data: Data to move (tensor, list, tuple, dict, or other)
        device: Target device

    Returns:
        Data with all tensors moved to the device; container structure
        (including concrete tuple subclasses) is preserved. Non-tensor
        leaves are returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    elif isinstance(data, dict):
        return {k: move_to_device(v, device) for k, v in data.items()}
    elif isinstance(data, list):
        return [move_to_device(item, device) for item in data]
    elif isinstance(data, tuple):
        moved = (move_to_device(item, device) for item in data)
        # Generalization: preserve NamedTuples (detected via `_fields`),
        # which must be reconstructed positionally via their own class.
        if hasattr(data, "_fields"):
            return type(data)(*moved)
        return tuple(moved)
    else:
        return data
def get_gpu_memory_info() -> Optional[dict]:
    """Report current CUDA memory usage.

    Returns:
        Optional[dict]: Mapping with 'allocated_gb' and 'reserved_gb'
        (gigabytes, rounded to two decimals), or None when CUDA is
        unavailable or the query fails.
    """
    if not torch.cuda.is_available():
        return None
    gib = 1024 ** 3  # bytes per gigabyte
    try:
        return {
            "allocated_gb": round(torch.cuda.memory_allocated() / gib, 2),
            "reserved_gb": round(torch.cuda.memory_reserved() / gib, 2),
        }
    except Exception as e:
        logger.error(f"Failed to get GPU memory info: {e}")
        return None
def log_gpu_usage(operation: str):
    """Log current GPU memory usage for a specific operation.

    Args:
        operation: Description of the operation being performed
    """
    info = get_gpu_memory_info()
    # Silently skip when CUDA is unavailable (info is None).
    if not info:
        return
    logger.info(
        f"[{operation}] GPU Memory - Allocated: {info['allocated_gb']}GB, "
        f"Reserved: {info['reserved_gb']}GB"
    )