voice-tools / src /config /gpu_config.py
jcudit's picture
jcudit HF Staff
git commit -m "feat: add HuggingFace ZeroGPU compatibility for
ffe9fdb
"""GPU configuration for HuggingFace ZeroGPU compatibility.
This module provides configuration constants and utilities for managing GPU resources
in both local and HuggingFace Spaces ZeroGPU environments.
"""
import os
from typing import Optional

import torch
class GPUConfig:
    """GPU configuration constants and environment detection.

    Every attribute is evaluated once, when the class body executes at import
    time; later changes to the process environment are not picked up.
    """

    # Environment detection.
    # Both flags only test for the *presence* of the variable; its value
    # (even an empty string) is not inspected.
    IS_ZEROGPU: bool = os.environ.get("SPACES_ZERO_GPU") is not None
    IS_SPACES: bool = os.environ.get("SPACE_ID") is not None

    # Device configuration.
    GPU_AVAILABLE: bool = torch.cuda.is_available()
    # CPU is selected in ZeroGPU mode even when CUDA reports as available.
    # NOTE(review): presumably because the GPU is only attached inside
    # @spaces.GPU-decorated calls -- confirm against the callers.
    DEFAULT_DEVICE: torch.device = torch.device(
        "cuda" if GPU_AVAILABLE and not IS_ZEROGPU else "cpu"
    )

    # Duration limits for @spaces.GPU decorator (seconds).
    # These values are based on typical processing times per workflow.
    SEPARATION_DURATION: int = 90  # Speaker separation (longest operation)
    EXTRACTION_DURATION: int = 60  # Speaker extraction
    DENOISING_DURATION: int = 45  # Voice denoising (fastest operation)
    MAX_DURATION: int = 120  # Maximum allowed by ZeroGPU

    # Resource management.
    CLEANUP_TIMEOUT: float = 2.0  # Maximum time for GPU cleanup (SC-004)
    ENABLE_CACHE_CLEARING: bool = True  # Clear CUDA cache after operations

    @classmethod
    def get_device(cls) -> torch.device:
        """Get the appropriate device for model operations.

        Returns:
            torch.device: CUDA device if available and not in ZeroGPU mode,
            else CPU.
        """
        return cls.DEFAULT_DEVICE

    @classmethod
    def get_environment_type(cls) -> str:
        """Get a string describing the current execution environment.

        The checks run in priority order: ZeroGPU, then Spaces, then local GPU.

        Returns:
            str: One of "zerogpu", "spaces_cpu", "local_gpu", or "local_cpu".
        """
        if cls.IS_ZEROGPU:
            return "zerogpu"
        # NOTE(review): a non-ZeroGPU Space is labelled "spaces_cpu" even when
        # GPU_AVAILABLE is True (DEFAULT_DEVICE would be CUDA in that case).
        # Kept as-is so the set of returned labels does not change.
        if cls.IS_SPACES:
            return "spaces_cpu"
        if cls.GPU_AVAILABLE:
            return "local_gpu"
        return "local_cpu"

    @classmethod
    def validate_duration(
        cls, duration: int, max_duration: Optional[int] = None
    ) -> int:
        """Validate and clamp duration to acceptable limits.

        Args:
            duration: Requested duration in seconds.
            max_duration: Maximum allowed duration (defaults to MAX_DURATION).

        Returns:
            int: ``duration`` unchanged when it is within the limit, otherwise
            the limit itself.

        Raises:
            ValueError: If duration is less than 1 second, or if the effective
                maximum is less than 1 second (clamping to it would otherwise
                return an invalid duration).
        """
        if duration < 1:
            raise ValueError(f"Duration must be at least 1 second, got {duration}")

        max_limit = max_duration if max_duration is not None else cls.MAX_DURATION
        # A limit below 1 can never yield a valid result: any accepted
        # duration (>= 1) would be clamped down to a non-positive value.
        if max_limit < 1:
            raise ValueError(
                f"Maximum duration must be at least 1 second, got {max_limit}"
            )

        return min(duration, max_limit)

    @classmethod
    def info(cls) -> dict:
        """Get a dictionary of current GPU configuration.

        Returns:
            dict: Configuration information for debugging and logging. Contains
            the environment flags, the default device (as a string), every
            duration limit, and the resource-management settings.
        """
        return {
            "environment_type": cls.get_environment_type(),
            "is_zerogpu": cls.IS_ZEROGPU,
            "is_spaces": cls.IS_SPACES,
            "gpu_available": cls.GPU_AVAILABLE,
            "default_device": str(cls.DEFAULT_DEVICE),
            "separation_duration": cls.SEPARATION_DURATION,
            "extraction_duration": cls.EXTRACTION_DURATION,
            "denoising_duration": cls.DENOISING_DURATION,
            "max_duration": cls.MAX_DURATION,
            "cleanup_timeout": cls.CLEANUP_TIMEOUT,
            "enable_cache_clearing": cls.ENABLE_CACHE_CLEARING,
        }