Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,695 Bytes
3ff2f18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
"""GPU resource management utilities for ZeroGPU compatibility.
This module provides utilities for managing GPU resources, including model device
transfers, cache management, and context managers for automatic cleanup.
"""
import logging
import time
from contextlib import contextmanager
from typing import Any, Optional
import torch
from src.config.gpu_config import GPUConfig
logger = logging.getLogger(__name__)
def acquire_gpu(model: torch.nn.Module, device: str = "cuda") -> bool:
    """Transfer a model onto the specified GPU device.

    Args:
        model: PyTorch model to move to GPU
        device: Target device (default: "cuda")

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        began = time.time()
        model.to(torch.device(device))
        took = time.time() - began
        logger.debug(f"Model {model.__class__.__name__} moved to {device} in {took:.3f}s")
    except Exception as e:
        logger.error(f"Failed to move model to {device}: {e}")
        return False
    return True
def release_gpu(model: torch.nn.Module, clear_cache: bool = True) -> bool:
    """Return a model to the CPU and optionally empty the CUDA cache.

    Args:
        model: PyTorch model to move to CPU
        clear_cache: Whether to clear CUDA cache after moving

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        began = time.time()
        model.to(torch.device("cpu"))
        # Cache clearing is gated both by the caller and by global config.
        wants_clear = clear_cache and GPUConfig.ENABLE_CACHE_CLEARING
        if wants_clear and torch.cuda.is_available():
            torch.cuda.empty_cache()
        took = time.time() - began
        if took > GPUConfig.CLEANUP_TIMEOUT:
            logger.warning(
                f"GPU cleanup took {took:.3f}s, exceeding {GPUConfig.CLEANUP_TIMEOUT}s limit"
            )
        else:
            logger.debug(f"GPU released in {took:.3f}s")
        return True
    except Exception as e:
        logger.error(f"Failed to release GPU: {e}")
        return False
@contextmanager
def gpu_context(model: torch.nn.Module, device: str = "cuda"):
    """Context manager for automatic GPU resource management.

    Acquires GPU on entry and releases it on exit, even if an exception occurs.

    Args:
        model: PyTorch model to manage
        device: Target GPU device (default: "cuda")

    Yields:
        torch.nn.Module: The model on the GPU device

    Example:
        >>> with gpu_context(my_model) as model:
        ...     result = model(input_data)
    """
    acquired = False
    try:
        acquired = acquire_gpu(model, device)
        if not acquired:
            # Bug fix: torch.nn.Module has no `.device` attribute, so the old
            # f"...{model.device}" raised AttributeError on the failure path.
            # Derive the current device from the first parameter instead.
            try:
                current_device = next(model.parameters()).device
            except StopIteration:
                # Parameterless module: no device to report.
                current_device = "unknown"
            logger.warning(f"Failed to acquire GPU, model remains on {current_device}")
        yield model
    finally:
        # Always release what we acquired, even if the body raised.
        if acquired:
            release_gpu(model, clear_cache=True)
def move_to_device(data: Any, device: torch.device) -> Any:
    """Recursively move tensors to the specified device.

    Handles nested structures like lists, tuples (including namedtuples),
    and dicts. Non-tensor leaves are returned unchanged.

    Args:
        data: Data to move (tensor, list, tuple, dict, or other)
        device: Target device

    Returns:
        Data with all tensors moved to the device
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {k: move_to_device(v, device) for k, v in data.items()}
    if isinstance(data, list):
        return [move_to_device(item, device) for item in data]
    if isinstance(data, tuple):
        moved = (move_to_device(item, device) for item in data)
        # Generalization: preserve namedtuple subclasses. Plain tuple(...)
        # would silently downgrade them, breaking field access for callers.
        # Namedtuple constructors take positional fields, not an iterable.
        if hasattr(data, "_fields"):
            return type(data)(*moved)
        return tuple(moved)
    return data
def get_gpu_memory_info() -> Optional[dict]:
    """Report current GPU memory usage.

    Returns:
        dict: Memory information with 'allocated' and 'reserved' in GB, or None if CUDA unavailable
    """
    if not torch.cuda.is_available():
        return None
    try:
        gib = 1024 ** 3  # bytes per GiB
        return {
            "allocated_gb": round(torch.cuda.memory_allocated() / gib, 2),
            "reserved_gb": round(torch.cuda.memory_reserved() / gib, 2),
        }
    except Exception as e:
        logger.error(f"Failed to get GPU memory info: {e}")
        return None
def log_gpu_usage(operation: str):
    """Log current GPU memory usage for a specific operation.

    Args:
        operation: Description of the operation being performed
    """
    stats = get_gpu_memory_info()
    # Nothing to report when CUDA is unavailable or the query failed.
    if not stats:
        return
    logger.info(
        f"[{operation}] GPU Memory - Allocated: {stats['allocated_gb']}GB, "
        f"Reserved: {stats['reserved_gb']}GB"
    )
|