| """ |
| GPU Configuration Module |
| Centralized GPU memory detection and adaptive configuration management |
| |
| Debug Mode: |
| Set environment variable MAX_CUDA_VRAM to simulate different GPU memory sizes. |
| Example: MAX_CUDA_VRAM=8 python acestep # Simulates 8GB GPU |
| |
| For MPS testing, use MAX_MPS_VRAM to simulate MPS memory. |
| Example: MAX_MPS_VRAM=16 python acestep # Simulates 16GB MPS |
| |
| This is useful for testing GPU tier configurations on high-end hardware. |
| """ |
|
|
| import os |
| import sys |
| from dataclasses import dataclass |
| from typing import Optional, List, Dict, Tuple |
| from loguru import logger |
|
|
|
|
| |
# Environment variables that override detected VRAM — debug/testing aids
# for simulating smaller GPUs on high-end hardware (see module docstring).
DEBUG_MAX_CUDA_VRAM_ENV = "MAX_CUDA_VRAM"
DEBUG_MAX_MPS_VRAM_ENV = "MAX_MPS_VRAM"


# Cards marketed as "16GB" often report slightly less than 16.0 GB, so
# anything within this tolerance is treated as a 16GB-class GPU (see
# get_gpu_tier, which logs when it applies this rounding-up).
VRAM_16GB_TOLERANCE_GB = 0.5
VRAM_16GB_MIN_GB = 16.0 - VRAM_16GB_TOLERANCE_GB  # 15.5


# Official PyTorch wheel indexes, referenced in diagnostic messages printed
# by _log_gpu_diagnostic_info when no GPU is detected.
PYTORCH_CUDA_INSTALL_URL = "https://download.pytorch.org/whl/cu121"
PYTORCH_ROCM_INSTALL_URL = "https://download.pytorch.org/whl/rocm6.0"
|
|
|
|
@dataclass
class GPUConfig:
    """GPU configuration based on available memory"""
    # Tier label assigned by get_gpu_tier(): "tier1".."tier6" or "unlimited".
    tier: str
    # Detected (or debug-simulated) GPU memory in GB.
    gpu_memory_gb: float

    # Maximum generation duration in seconds when the LM is initialized.
    max_duration_with_lm: int
    # Maximum generation duration in seconds when running without the LM.
    max_duration_without_lm: int

    # Maximum batch size when the LM is initialized.
    max_batch_size_with_lm: int
    # Maximum batch size when running without the LM.
    max_batch_size_without_lm: int

    # Whether the LM should be initialized by default on this tier.
    init_lm_default: bool
    # LM model identifiers that fit within this tier's VRAM budget.
    available_lm_models: List[str]

    # Approximate VRAM required per LM size label, e.g. {"0.6B": 3}.
    lm_memory_gb: Dict[str, float]
|
|
|
|
| |
# Per-tier generation limits and LM availability. Tier boundaries (in GB of
# VRAM) are decided by get_gpu_tier(): tier1 <=4, tier2 <=6, tier3 <=8,
# tier4 <=12, tier5 <15.5, tier6 15.5..24, unlimited >24.
GPU_TIER_CONFIGS = {
    # <=4GB (or no GPU detected): shortest durations, no LM support.
    "tier1": {
        "max_duration_with_lm": 180,
        "max_duration_without_lm": 180,
        "max_batch_size_with_lm": 1,
        "max_batch_size_without_lm": 1,
        "init_lm_default": False,
        "available_lm_models": [],
        "lm_memory_gb": {},
    },
    # <=6GB: longer durations, still no LM support.
    "tier2": {
        "max_duration_with_lm": 360,
        "max_duration_without_lm": 360,
        "max_batch_size_with_lm": 1,
        "max_batch_size_without_lm": 1,
        "init_lm_default": False,
        "available_lm_models": [],
        "lm_memory_gb": {},
    },
    # <=8GB: smallest LM (0.6B) available but not enabled by default.
    "tier3": {
        "max_duration_with_lm": 240,
        "max_duration_without_lm": 360,
        "max_batch_size_with_lm": 1,
        "max_batch_size_without_lm": 2,
        "init_lm_default": False,
        "available_lm_models": ["acestep-5Hz-lm-0.6B"],
        "lm_memory_gb": {"0.6B": 3},
    },
    # <=12GB: larger batches, 0.6B LM available but off by default.
    "tier4": {
        "max_duration_with_lm": 240,
        "max_duration_without_lm": 360,
        "max_batch_size_with_lm": 2,
        "max_batch_size_without_lm": 4,
        "init_lm_default": False,
        "available_lm_models": ["acestep-5Hz-lm-0.6B"],
        "lm_memory_gb": {"0.6B": 3},
    },
    # <15.5GB: LM on by default, up to the 1.7B model.
    "tier5": {
        "max_duration_with_lm": 240,
        "max_duration_without_lm": 360,
        "max_batch_size_with_lm": 2,
        "max_batch_size_without_lm": 4,
        "init_lm_default": True,
        "available_lm_models": ["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B"],
        "lm_memory_gb": {"0.6B": 3, "1.7B": 8},
    },
    # 15.5..24GB (16GB/24GB-class cards): all LM sizes including 4B.
    "tier6": {
        "max_duration_with_lm": 480,
        "max_duration_without_lm": 480,
        "max_batch_size_with_lm": 4,
        "max_batch_size_without_lm": 8,
        "init_lm_default": True,
        "available_lm_models": ["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B", "acestep-5Hz-lm-4B"],
        "lm_memory_gb": {"0.6B": 3, "1.7B": 8, "4B": 12},
    },
    # >24GB: maximum durations and batch sizes.
    "unlimited": {
        "max_duration_with_lm": 600,
        "max_duration_without_lm": 600,
        "max_batch_size_with_lm": 8,
        "max_batch_size_without_lm": 8,
        "init_lm_default": True,
        "available_lm_models": ["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B", "acestep-5Hz-lm-4B"],
        "lm_memory_gb": {"0.6B": 3, "1.7B": 8, "4B": 12},
    },
}
|
|
|
|
def _parse_debug_vram_env(env_var: str, label: str) -> Optional[float]:
    """
    Parse a VRAM-override environment variable used for debug simulation.

    Args:
        env_var: Name of the environment variable to read.
        label: Device label used in the log message ("GPU" or "MPS").

    Returns:
        The simulated memory in GB, or None if the variable is unset or
        cannot be parsed as a float.
    """
    raw_value = os.environ.get(env_var)
    if raw_value is None:
        return None
    try:
        simulated_gb = float(raw_value)
    except ValueError:
        # Malformed override: warn and fall back to real detection.
        logger.warning(f"Invalid {env_var} value: {raw_value}, ignoring")
        return None
    logger.warning(f"⚠️ DEBUG MODE: Simulating {label} memory as {simulated_gb:.1f}GB (set via {env_var} environment variable)")
    return simulated_gb


def get_gpu_memory_gb() -> float:
    """
    Get GPU memory in GB. Returns 0.0 if no GPU is available.

    Debug Mode:
        Set environment variable MAX_CUDA_VRAM to override the detected GPU memory.
        Example: MAX_CUDA_VRAM=8 python acestep  # Simulates 8GB GPU

        For MPS testing, set MAX_MPS_VRAM to override MPS memory detection.
        Example: MAX_MPS_VRAM=16 python acestep  # Simulates 16GB MPS

    This allows testing different GPU tier configurations on high-end hardware.
    """
    # Debug overrides take precedence over real hardware detection; the CUDA
    # override is checked before the MPS one, matching historical behavior.
    for env_var, label in ((DEBUG_MAX_CUDA_VRAM_ENV, "GPU"), (DEBUG_MAX_MPS_VRAM_ENV, "MPS")):
        simulated_gb = _parse_debug_vram_env(env_var, label)
        if simulated_gb is not None:
            return simulated_gb

    try:
        import torch
        if torch.cuda.is_available():
            # Covers both CUDA and ROCm builds (ROCm exposes the torch.cuda API).
            total_memory = torch.cuda.get_device_properties(0).total_memory
            memory_gb = total_memory / (1024**3)
            device_name = torch.cuda.get_device_name(0)
            is_rocm = hasattr(torch.version, 'hip') and torch.version.hip is not None
            if is_rocm:
                logger.info(f"ROCm GPU detected: {device_name} ({memory_gb:.1f} GB, HIP {torch.version.hip})")
            else:
                logger.info(f"CUDA GPU detected: {device_name} ({memory_gb:.1f} GB)")
            return memory_gb
        elif hasattr(torch, 'xpu') and torch.xpu.is_available():
            # Intel XPU path.
            total_memory = torch.xpu.get_device_properties(0).total_memory
            return total_memory / (1024**3)
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            # Apple Silicon: newer torch versions expose MPS memory info
            # directly; probe those APIs first.
            mps_module = getattr(torch, "mps", None)
            try:
                if mps_module is not None and hasattr(mps_module, "recommended_max_memory"):
                    total_memory = mps_module.recommended_max_memory()
                    return total_memory / (1024**3)
                if mps_module is not None and hasattr(mps_module, "get_device_properties"):
                    props = mps_module.get_device_properties(0)
                    total_memory = getattr(props, "total_memory", None)
                    if total_memory:
                        return total_memory / (1024**3)
            except Exception as e:
                logger.warning(f"Failed to detect MPS memory: {e}")

            # Fallback: derive from total system RAM (unified memory on Apple
            # Silicon), assuming ~75% of it is usable by the GPU.
            try:
                import subprocess
                result = subprocess.run(
                    ["sysctl", "-n", "hw.memsize"],
                    capture_output=True, text=True, timeout=5
                )
                total_system_bytes = int(result.stdout.strip())
                return (total_system_bytes / (1024**3)) * 0.75
            except Exception:
                logger.warning(f"MPS available but total memory not exposed. Set {DEBUG_MAX_MPS_VRAM_ENV} to enable tiering.")
                # Conservative default so MPS users still get a usable tier.
                return 8.0
        else:
            # No accelerator found: log troubleshooting hints for the user.
            _log_gpu_diagnostic_info(torch)
            return 0.0
    except Exception as e:
        logger.warning(f"Failed to detect GPU memory: {e}")
        return 0.0
|
|
|
|
def _log_gpu_diagnostic_info(torch_module):
    """
    Log diagnostic information when GPU is not detected to help users troubleshoot.

    Inspects the installed PyTorch build (ROCm / CUDA / CPU-only) and prints
    targeted troubleshooting steps for each case. Logging-only; no return value.

    Args:
        torch_module: The torch module to inspect for build information
    """
    logger.warning("=" * 80)
    logger.warning("⚠️ GPU NOT DETECTED - DIAGNOSTIC INFORMATION")
    logger.warning("=" * 80)

    # Determine build type from torch.version: a ROCm build exposes a non-None
    # `hip` attribute, a CUDA build a non-None `cuda` attribute.
    is_rocm_build = hasattr(torch_module.version, 'hip') and torch_module.version.hip is not None
    is_cuda_build = hasattr(torch_module.version, 'cuda') and torch_module.version.cuda is not None

    if is_rocm_build:
        logger.warning("✓ PyTorch ROCm build detected")
        logger.warning(f"  HIP version: {torch_module.version.hip}")
        logger.warning("")
        logger.warning("❌ torch.cuda.is_available() returned False")
        logger.warning("")
        logger.warning("Common causes for AMD/ROCm GPUs:")
        logger.warning("  1. ROCm drivers not installed or not properly configured")
        logger.warning("  2. GPU not supported by installed ROCm version")
        logger.warning("  3. Missing or incorrect HSA_OVERRIDE_GFX_VERSION environment variable")
        logger.warning("  4. ROCm runtime libraries not in system path")
        logger.warning("")

        # HSA_OVERRIDE_GFX_VERSION is the most common missing piece on
        # consumer RDNA3 cards; echo its current state and suggest values.
        hsa_override = os.environ.get('HSA_OVERRIDE_GFX_VERSION')
        if hsa_override:
            logger.warning(f"  HSA_OVERRIDE_GFX_VERSION is set to: {hsa_override}")
        else:
            logger.warning("  ⚠️ HSA_OVERRIDE_GFX_VERSION is not set")
            logger.warning("     For RDNA3 GPUs (RX 7000 series, RX 9000 series):")
            logger.warning("     - RX 7900 XT/XTX, RX 9070 XT: set HSA_OVERRIDE_GFX_VERSION=11.0.0")
            logger.warning("     - RX 7800 XT, RX 7700 XT: set HSA_OVERRIDE_GFX_VERSION=11.0.1")
            logger.warning("     - RX 7600: set HSA_OVERRIDE_GFX_VERSION=11.0.2")

        logger.warning("")
        logger.warning("Troubleshooting steps:")
        logger.warning("  1. Verify ROCm installation:")
        logger.warning("     rocm-smi  # Should list your GPU")
        logger.warning("  2. Check PyTorch ROCm build:")
        logger.warning("     python -c \"import torch; print(f'ROCm: {torch.version.hip}')\"")
        logger.warning("  3. Set HSA_OVERRIDE_GFX_VERSION for your GPU (see above)")
        logger.warning("  4. On Windows: Use start_gradio_ui_rocm.bat which sets required env vars")
        logger.warning("  5. See docs/en/ACE-Step1.5-Rocm-Manual-Linux.md for Linux setup")
        logger.warning("  6. See requirements-rocm.txt for Windows ROCm setup instructions")

    elif is_cuda_build:
        logger.warning("✓ PyTorch CUDA build detected")
        logger.warning(f"  CUDA version: {torch_module.version.cuda}")
        logger.warning("")
        logger.warning("❌ torch.cuda.is_available() returned False")
        logger.warning("")
        logger.warning("Common causes for NVIDIA GPUs:")
        logger.warning("  1. NVIDIA drivers not installed")
        logger.warning("  2. CUDA runtime not installed or version mismatch")
        logger.warning("  3. GPU not supported by installed CUDA version")
        logger.warning("")
        logger.warning("Troubleshooting steps:")
        logger.warning("  1. Verify NVIDIA driver installation:")
        logger.warning("     nvidia-smi  # Should list your GPU")
        logger.warning("  2. Check CUDA version compatibility")
        logger.warning("  3. Reinstall PyTorch with CUDA support:")
        logger.warning(f"     pip install torch --index-url {PYTORCH_CUDA_INSTALL_URL}")

    else:
        # Neither hip nor cuda present: user installed the CPU-only wheel.
        logger.warning("⚠️ PyTorch build type: CPU-only")
        logger.warning("")
        logger.warning("You have installed a CPU-only version of PyTorch!")
        logger.warning("")
        logger.warning("For NVIDIA GPUs:")
        logger.warning(f"  pip install torch --index-url {PYTORCH_CUDA_INSTALL_URL}")
        logger.warning("")
        logger.warning("For AMD GPUs with ROCm:")
        logger.warning("  Windows: See requirements-rocm.txt for detailed instructions")
        logger.warning(f"  Linux: pip install torch --index-url {PYTORCH_ROCM_INSTALL_URL}")
        logger.warning("")
        logger.warning("For more information, see README.md section 'AMD / ROCm GPUs'")

    logger.warning("=" * 80)
|
|
|
|
def get_gpu_tier(gpu_memory_gb: float) -> str:
    """
    Map an amount of GPU memory to a configuration tier.

    Args:
        gpu_memory_gb: GPU memory in GB (zero or negative means no GPU detected)

    Returns:
        Tier string: "tier1", "tier2", "tier3", "tier4", "tier5", "tier6", or "unlimited"
    """
    # Guard clauses from the smallest tier upward; no GPU (<= 0) lands in
    # tier1 just like a <=4GB card.
    if gpu_memory_gb <= 4:
        return "tier1"
    if gpu_memory_gb <= 6:
        return "tier2"
    if gpu_memory_gb <= 8:
        return "tier3"
    if gpu_memory_gb <= 12:
        return "tier4"
    if gpu_memory_gb < VRAM_16GB_MIN_GB:
        return "tier5"
    if gpu_memory_gb > 24:
        return "unlimited"
    # 15.5GB..24GB: 16GB-class and 24GB-class cards share tier6; note when a
    # slightly-under-16GB card is rounded up.
    if gpu_memory_gb < 16.0:
        logger.info(f"Detected {gpu_memory_gb:.2f}GB VRAM — treating as 16GB class GPU")
    return "tier6"
|
|
|
|
def get_gpu_config(gpu_memory_gb: Optional[float] = None) -> GPUConfig:
    """
    Build a GPUConfig for the detected or explicitly provided GPU memory.

    Args:
        gpu_memory_gb: GPU memory in GB. If None, will be auto-detected.

    Returns:
        GPUConfig object with all configuration parameters
    """
    memory_gb = get_gpu_memory_gb() if gpu_memory_gb is None else gpu_memory_gb
    tier = get_gpu_tier(memory_gb)
    # Tier-table keys mirror the GPUConfig field names, so the settings can
    # be unpacked straight into the dataclass constructor.
    tier_settings = GPU_TIER_CONFIGS[tier]
    return GPUConfig(tier=tier, gpu_memory_gb=memory_gb, **tier_settings)
|
|
|
|
def get_lm_model_size(model_path: str) -> str:
    """
    Extract the LM model size token from a model path.

    Args:
        model_path: Model path string (e.g., "acestep-5Hz-lm-0.6B")

    Returns:
        Model size string: "0.6B", "1.7B", or "4B" (defaults to "0.6B"
        when no known size token appears in the path)
    """
    # Probe known size tokens in the same precedence order as before.
    for size_token in ("0.6B", "1.7B", "4B"):
        if size_token in model_path:
            return size_token
    # Unknown path: fall back to the smallest model size.
    return "0.6B"
|
|
|
|
def get_lm_gpu_memory_ratio(model_path: str, total_gpu_memory_gb: float) -> Tuple[float, float]:
    """
    Calculate GPU memory utilization ratio for LM model.

    Args:
        model_path: LM model path (e.g., "acestep-5Hz-lm-0.6B")
        total_gpu_memory_gb: Total GPU memory in GB

    Returns:
        Tuple of (gpu_memory_utilization_ratio, target_memory_gb)
    """
    # Target VRAM budget per model size, in GB; unknown sizes get the
    # smallest budget.
    budgets = {"0.6B": 3.0, "1.7B": 8.0, "4B": 12.0}
    target_gb = budgets.get(get_lm_model_size(model_path), 3.0)

    # High-memory cards (>=24GB) use a higher floor so the LM reserves a
    # meaningful share; either way the ratio is capped at 0.9.
    floor = 0.2 if total_gpu_memory_gb >= 24 else 0.1
    ratio = min(0.9, max(floor, target_gb / total_gpu_memory_gb))

    return ratio, target_gb
|
|
|
|
def check_duration_limit(
    duration: float,
    gpu_config: GPUConfig,
    lm_initialized: bool
) -> Tuple[bool, str]:
    """
    Check if requested duration is within limits for current GPU configuration.

    Args:
        duration: Requested duration in seconds
        gpu_config: Current GPU configuration
        lm_initialized: Whether LM is initialized

    Returns:
        Tuple of (is_valid, warning_message); the message is empty when valid
    """
    # The applicable ceiling depends on whether the LM is resident in VRAM.
    if lm_initialized:
        limit = gpu_config.max_duration_with_lm
    else:
        limit = gpu_config.max_duration_without_lm

    if duration <= limit:
        return True, ""

    mode = "with" if lm_initialized else "without"
    warning_msg = (
        f"⚠️ Requested duration ({duration:.0f}s) exceeds the limit for your GPU "
        f"({gpu_config.gpu_memory_gb:.1f}GB). Maximum allowed: {limit}s "
        f"({mode} LM). "
        f"Duration will be clamped to {limit}s."
    )
    return False, warning_msg
|
|
|
|
def check_batch_size_limit(
    batch_size: int,
    gpu_config: GPUConfig,
    lm_initialized: bool
) -> Tuple[bool, str]:
    """
    Check if requested batch size is within limits for current GPU configuration.

    Args:
        batch_size: Requested batch size
        gpu_config: Current GPU configuration
        lm_initialized: Whether LM is initialized

    Returns:
        Tuple of (is_valid, warning_message); the message is empty when valid
    """
    # The applicable ceiling depends on whether the LM is resident in VRAM.
    if lm_initialized:
        limit = gpu_config.max_batch_size_with_lm
    else:
        limit = gpu_config.max_batch_size_without_lm

    if batch_size <= limit:
        return True, ""

    mode = "with" if lm_initialized else "without"
    warning_msg = (
        f"⚠️ Requested batch size ({batch_size}) exceeds the limit for your GPU "
        f"({gpu_config.gpu_memory_gb:.1f}GB). Maximum allowed: {limit} "
        f"({mode} LM). "
        f"Batch size will be clamped to {limit}."
    )
    return False, warning_msg
|
|
|
|
def is_lm_model_supported(model_path: str, gpu_config: GPUConfig) -> Tuple[bool, str]:
    """
    Check if the specified LM model is supported for current GPU configuration.

    Args:
        model_path: LM model path
        gpu_config: Current GPU configuration

    Returns:
        Tuple of (is_supported, warning_message); the message is empty when supported
    """
    available = gpu_config.available_lm_models
    if not available:
        # This tier cannot host any LM at all.
        return False, (
            f"⚠️ Your GPU ({gpu_config.gpu_memory_gb:.1f}GB) does not have enough memory "
            f"to run any LM model. Please disable LM initialization."
        )

    model_size = get_lm_model_size(model_path)

    # Supported when any allowed model id carries the same size token.
    if any(model_size in candidate for candidate in available):
        return True, ""

    return False, (
        f"⚠️ LM model {model_path} ({model_size}) is not supported for your GPU "
        f"({gpu_config.gpu_memory_gb:.1f}GB). Available models: {', '.join(available)}"
    )
|
|
|
|
def get_recommended_lm_model(gpu_config: GPUConfig) -> Optional[str]:
    """
    Get recommended LM model for current GPU configuration.

    Args:
        gpu_config: Current GPU configuration

    Returns:
        Recommended LM model path, or None if LM is not supported
    """
    models = gpu_config.available_lm_models
    # The tier tables list models smallest-first, so the last entry is the
    # largest model that fits this GPU.
    return models[-1] if models else None
|
|
|
|
def print_gpu_config_info(gpu_config: GPUConfig):
    """
    Log the resolved GPU configuration at INFO level for debugging.

    Args:
        gpu_config: The configuration to report. Logging-only; no return value.
    """
    # Fixed header (was an f-string with no placeholders — lint F541).
    logger.info("GPU Configuration:")
    logger.info(f"  - GPU Memory: {gpu_config.gpu_memory_gb:.1f} GB")
    logger.info(f"  - Tier: {gpu_config.tier}")
    logger.info(f"  - Max Duration (with LM): {gpu_config.max_duration_with_lm}s ({gpu_config.max_duration_with_lm // 60} min)")
    logger.info(f"  - Max Duration (without LM): {gpu_config.max_duration_without_lm}s ({gpu_config.max_duration_without_lm // 60} min)")
    logger.info(f"  - Max Batch Size (with LM): {gpu_config.max_batch_size_with_lm}")
    logger.info(f"  - Max Batch Size (without LM): {gpu_config.max_batch_size_without_lm}")
    logger.info(f"  - Init LM by Default: {gpu_config.init_lm_default}")
    logger.info(f"  - Available LM Models: {gpu_config.available_lm_models or 'None'}")
|
|
|
|
| |
# Process-wide cached GPUConfig; populated lazily by get_global_gpu_config()
# and replaceable via set_global_gpu_config().
_global_gpu_config: Optional[GPUConfig] = None
|
|
|
|
def get_global_gpu_config() -> GPUConfig:
    """Return the process-wide GPU configuration, detecting it on first use."""
    global _global_gpu_config
    cached = _global_gpu_config
    if cached is None:
        # First call: run detection once and memoize the result.
        cached = get_gpu_config()
        _global_gpu_config = cached
    return cached
|
|
|
|
def set_global_gpu_config(config: GPUConfig):
    """Set the global GPU configuration.

    Overrides the lazily-detected value returned by get_global_gpu_config().

    Args:
        config: The configuration to install as the process-wide default.
    """
    global _global_gpu_config
    _global_gpu_config = config
|
|