# voice-tools / src/lib/gpu_utils.py
# Provenance: Hugging Face Space, commit 3ff2f18 by jcudit (HF Staff) —
# "fix: also correct lib/ in gitignore to only exclude root-level, add src/lib package"
"""GPU resource management utilities for ZeroGPU compatibility.
This module provides utilities for managing GPU resources, including model device
transfers, cache management, and context managers for automatic cleanup.
"""
import logging
import time
from contextlib import contextmanager
from typing import Any, Optional
import torch
from src.config.gpu_config import GPUConfig
logger = logging.getLogger(__name__)
def acquire_gpu(model: torch.nn.Module, device: str = "cuda") -> bool:
    """Move a model to the specified GPU device.

    Args:
        model: PyTorch model to move to GPU
        device: Target device (default: "cuda")

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        # perf_counter is monotonic, so the measured duration cannot go
        # negative or jump if the system clock is adjusted mid-transfer
        # (time.time() is wall-clock and has neither guarantee).
        start_time = time.perf_counter()
        target_device = torch.device(device)
        model.to(target_device)
        elapsed = time.perf_counter() - start_time
        logger.debug(f"Model {model.__class__.__name__} moved to {device} in {elapsed:.3f}s")
        return True
    except Exception as e:
        # Broad catch is deliberate: callers rely on the bool contract
        # rather than handling device-specific exception types themselves.
        logger.error(f"Failed to move model to {device}: {e}")
        return False
def release_gpu(model: torch.nn.Module, clear_cache: bool = True) -> bool:
    """Move a model back to CPU and optionally clear CUDA cache.

    Args:
        model: PyTorch model to move to CPU
        clear_cache: Whether to clear CUDA cache after moving

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        # perf_counter is monotonic: immune to wall-clock adjustments that
        # could make a time.time()-based duration negative or inflated.
        start_time = time.perf_counter()
        model.to(torch.device("cpu"))
        # empty_cache() is only meaningful when CUDA is present;
        # GPUConfig.ENABLE_CACHE_CLEARING is the project-level switch.
        if clear_cache and GPUConfig.ENABLE_CACHE_CLEARING and torch.cuda.is_available():
            torch.cuda.empty_cache()
        elapsed = time.perf_counter() - start_time
        if elapsed > GPUConfig.CLEANUP_TIMEOUT:
            # Exceeding the cleanup budget is logged loudly but is not
            # treated as a failure — the model did reach the CPU.
            logger.warning(
                f"GPU cleanup took {elapsed:.3f}s, exceeding {GPUConfig.CLEANUP_TIMEOUT}s limit"
            )
        else:
            logger.debug(f"GPU released in {elapsed:.3f}s")
        return True
    except Exception as e:
        # Broad catch keeps the bool contract; callers decide how to react.
        logger.error(f"Failed to release GPU: {e}")
        return False
@contextmanager
def gpu_context(model: torch.nn.Module, device: str = "cuda"):
    """Context manager for automatic GPU resource management.

    Acquires GPU on entry and releases it on exit, even if an exception occurs.

    Args:
        model: PyTorch model to manage
        device: Target GPU device (default: "cuda")

    Yields:
        torch.nn.Module: The model on the GPU device

    Example:
        >>> with gpu_context(my_model) as model:
        ...     result = model(input_data)
    """
    acquired = False
    try:
        acquired = acquire_gpu(model, device)
        if not acquired:
            # Bug fix: torch.nn.Module has no `.device` attribute, so the
            # previous f-string raised AttributeError on this failure path.
            # Report the device of the first parameter instead (a model with
            # no parameters has no single device to report).
            first_param = next(model.parameters(), None)
            location = first_param.device if first_param is not None else "unknown"
            logger.warning(f"Failed to acquire GPU, model remains on {location}")
        yield model
    finally:
        # Release even when the body raised, but only if we actually acquired.
        if acquired:
            release_gpu(model, clear_cache=True)
def move_to_device(data: Any, device: torch.device) -> Any:
    """Recursively move tensors to the specified device.

    Handles nested structures like lists, tuples, and dicts.

    Args:
        data: Data to move (tensor, list, tuple, dict, or other)
        device: Target device

    Returns:
        Data with all tensors moved to the device
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {key: move_to_device(value, device) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        moved = [move_to_device(element, device) for element in data]
        # Preserve the original container type for the caller.
        return moved if isinstance(data, list) else tuple(moved)
    # Non-container, non-tensor values pass through untouched.
    return data
def get_gpu_memory_info() -> Optional[dict]:
    """Get current GPU memory usage information.

    Returns:
        dict: Memory information with 'allocated' and 'reserved' in GB, or None if CUDA unavailable
    """
    if not torch.cuda.is_available():
        return None
    try:
        bytes_per_gb = 1024**3
        return {
            "allocated_gb": round(torch.cuda.memory_allocated() / bytes_per_gb, 2),
            "reserved_gb": round(torch.cuda.memory_reserved() / bytes_per_gb, 2),
        }
    except Exception as e:
        logger.error(f"Failed to get GPU memory info: {e}")
        return None
def log_gpu_usage(operation: str):
    """Log current GPU memory usage for a specific operation.

    Args:
        operation: Description of the operation being performed
    """
    info = get_gpu_memory_info()
    if not info:
        # Nothing to report when CUDA is unavailable or the query failed.
        return
    logger.info(
        f"[{operation}] GPU Memory - Allocated: {info['allocated_gb']}GB, "
        f"Reserved: {info['reserved_gb']}GB"
    )