ACE-Step-Custom / acestep /gpu_config.py
ACE-Step Custom
Deploy ACE-Step Custom Edition with bug fixes
a602628
"""
GPU Configuration Module
Centralized GPU memory detection and adaptive configuration management
Debug Mode:
Set environment variable MAX_CUDA_VRAM to simulate different GPU memory sizes.
Example: MAX_CUDA_VRAM=8 python acestep # Simulates 8GB GPU
For MPS testing, use MAX_MPS_VRAM to simulate MPS memory.
Example: MAX_MPS_VRAM=16 python acestep # Simulates 16GB MPS
This is useful for testing GPU tier configurations on high-end hardware.
"""
import os
import sys
from dataclasses import dataclass
from typing import Optional, List, Dict, Tuple
from loguru import logger
# Environment variable for debugging/testing different GPU memory configurations
DEBUG_MAX_CUDA_VRAM_ENV = "MAX_CUDA_VRAM"
DEBUG_MAX_MPS_VRAM_ENV = "MAX_MPS_VRAM"
# Tolerance for 16GB detection: reported VRAM like 15.5GB is effectively 16GB hardware
# Real-world 16GB GPUs often report 15.7-15.9GB due to system/driver reservations
VRAM_16GB_TOLERANCE_GB = 0.5
VRAM_16GB_MIN_GB = 16.0 - VRAM_16GB_TOLERANCE_GB # treat as 16GB class if >= this
# PyTorch installation URLs for diagnostics
PYTORCH_CUDA_INSTALL_URL = "https://download.pytorch.org/whl/cu121"
PYTORCH_ROCM_INSTALL_URL = "https://download.pytorch.org/whl/rocm6.0"
@dataclass
class GPUConfig:
"""GPU configuration based on available memory"""
tier: str # "tier1", "tier2", etc. or "unlimited"
gpu_memory_gb: float
# Duration limits (in seconds)
max_duration_with_lm: int # When LM is initialized
max_duration_without_lm: int # When LM is not initialized
# Batch size limits
max_batch_size_with_lm: int
max_batch_size_without_lm: int
# LM configuration
init_lm_default: bool # Whether to initialize LM by default
available_lm_models: List[str] # Available LM models for this tier
# LM memory allocation (GB) for each model size
lm_memory_gb: Dict[str, float] # e.g., {"0.6B": 3, "1.7B": 8, "4B": 12}
# GPU tier configurations
GPU_TIER_CONFIGS = {
"tier1": { # <= 4GB
"max_duration_with_lm": 180, # 3 minutes
"max_duration_without_lm": 180, # 3 minutes
"max_batch_size_with_lm": 1,
"max_batch_size_without_lm": 1,
"init_lm_default": False,
"available_lm_models": [],
"lm_memory_gb": {},
},
"tier2": { # 4-6GB
"max_duration_with_lm": 360, # 6 minutes
"max_duration_without_lm": 360, # 6 minutes
"max_batch_size_with_lm": 1,
"max_batch_size_without_lm": 1,
"init_lm_default": False,
"available_lm_models": [],
"lm_memory_gb": {},
},
"tier3": { # 6-8GB
"max_duration_with_lm": 240, # 4 minutes with LM
"max_duration_without_lm": 360, # 6 minutes without LM
"max_batch_size_with_lm": 1,
"max_batch_size_without_lm": 2,
"init_lm_default": False, # Don't init by default due to limited memory
"available_lm_models": ["acestep-5Hz-lm-0.6B"],
"lm_memory_gb": {"0.6B": 3},
},
"tier4": { # 8-12GB
"max_duration_with_lm": 240, # 4 minutes with LM
"max_duration_without_lm": 360, # 6 minutes without LM
"max_batch_size_with_lm": 2,
"max_batch_size_without_lm": 4,
"init_lm_default": False, # Don't init by default
"available_lm_models": ["acestep-5Hz-lm-0.6B"],
"lm_memory_gb": {"0.6B": 3},
},
"tier5": { # 12-16GB
"max_duration_with_lm": 240, # 4 minutes with LM
"max_duration_without_lm": 360, # 6 minutes without LM
"max_batch_size_with_lm": 2,
"max_batch_size_without_lm": 4,
"init_lm_default": True,
"available_lm_models": ["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B"],
"lm_memory_gb": {"0.6B": 3, "1.7B": 8},
},
"tier6": { # 16-24GB
"max_duration_with_lm": 480, # 8 minutes
"max_duration_without_lm": 480, # 8 minutes
"max_batch_size_with_lm": 4,
"max_batch_size_without_lm": 8,
"init_lm_default": True,
"available_lm_models": ["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B", "acestep-5Hz-lm-4B"],
"lm_memory_gb": {"0.6B": 3, "1.7B": 8, "4B": 12},
},
"unlimited": { # >= 24GB
"max_duration_with_lm": 600, # 10 minutes (max supported)
"max_duration_without_lm": 600, # 10 minutes
"max_batch_size_with_lm": 8,
"max_batch_size_without_lm": 8,
"init_lm_default": True,
"available_lm_models": ["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B", "acestep-5Hz-lm-4B"],
"lm_memory_gb": {"0.6B": 3, "1.7B": 8, "4B": 12},
},
}
def get_gpu_memory_gb() -> float:
"""
Get GPU memory in GB. Returns 0 if no GPU is available.
Debug Mode:
Set environment variable MAX_CUDA_VRAM to override the detected GPU memory.
Example: MAX_CUDA_VRAM=8 python acestep # Simulates 8GB GPU
For MPS testing, set MAX_MPS_VRAM to override MPS memory detection.
Example: MAX_MPS_VRAM=16 python acestep # Simulates 16GB MPS
This allows testing different GPU tier configurations on high-end hardware.
"""
# Check for debug override first
debug_vram = os.environ.get(DEBUG_MAX_CUDA_VRAM_ENV)
if debug_vram is not None:
try:
simulated_gb = float(debug_vram)
logger.warning(f"⚠️ DEBUG MODE: Simulating GPU memory as {simulated_gb:.1f}GB (set via {DEBUG_MAX_CUDA_VRAM_ENV} environment variable)")
return simulated_gb
except ValueError:
logger.warning(f"Invalid {DEBUG_MAX_CUDA_VRAM_ENV} value: {debug_vram}, ignoring")
debug_mps_vram = os.environ.get(DEBUG_MAX_MPS_VRAM_ENV)
if debug_mps_vram is not None:
try:
simulated_gb = float(debug_mps_vram)
logger.warning(f"⚠️ DEBUG MODE: Simulating MPS memory as {simulated_gb:.1f}GB (set via {DEBUG_MAX_MPS_VRAM_ENV} environment variable)")
return simulated_gb
except ValueError:
logger.warning(f"Invalid {DEBUG_MAX_MPS_VRAM_ENV} value: {debug_mps_vram}, ignoring")
try:
import torch
if torch.cuda.is_available():
# Get total memory of the first GPU in GB
total_memory = torch.cuda.get_device_properties(0).total_memory
memory_gb = total_memory / (1024**3) # Convert bytes to GB
device_name = torch.cuda.get_device_name(0)
is_rocm = hasattr(torch.version, 'hip') and torch.version.hip is not None
if is_rocm:
logger.info(f"ROCm GPU detected: {device_name} ({memory_gb:.1f} GB, HIP {torch.version.hip})")
else:
logger.info(f"CUDA GPU detected: {device_name} ({memory_gb:.1f} GB)")
return memory_gb
elif hasattr(torch, 'xpu') and torch.xpu.is_available():
# Get total memory of the first XPU in GB
total_memory = torch.xpu.get_device_properties(0).total_memory
memory_gb = total_memory / (1024**3) # Convert bytes to GB
return memory_gb
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
mps_module = getattr(torch, "mps", None)
try:
if mps_module is not None and hasattr(mps_module, "recommended_max_memory"):
total_memory = mps_module.recommended_max_memory()
memory_gb = total_memory / (1024**3) # Convert bytes to GB
return memory_gb
if mps_module is not None and hasattr(mps_module, "get_device_properties"):
props = mps_module.get_device_properties(0)
total_memory = getattr(props, "total_memory", None)
if total_memory:
memory_gb = total_memory / (1024**3)
return memory_gb
except Exception as e:
logger.warning(f"Failed to detect MPS memory: {e}")
# Fallback: estimate from system unified memory (Apple Silicon shares CPU/GPU RAM)
try:
import subprocess
result = subprocess.run(
["sysctl", "-n", "hw.memsize"],
capture_output=True, text=True, timeout=5
)
total_system_bytes = int(result.stdout.strip())
# MPS can use up to ~75% of unified memory for GPU workloads
memory_gb = (total_system_bytes / (1024**3)) * 0.75
return memory_gb
except Exception:
logger.warning(f"MPS available but total memory not exposed. Set {DEBUG_MAX_MPS_VRAM_ENV} to enable tiering.")
# Conservative fallback for M1/M2
return 8.0
else:
# No GPU detected - provide diagnostic information
_log_gpu_diagnostic_info(torch)
return 0
except Exception as e:
logger.warning(f"Failed to detect GPU memory: {e}")
return 0
def _log_gpu_diagnostic_info(torch_module):
"""
Log diagnostic information when GPU is not detected to help users troubleshoot.
Args:
torch_module: The torch module to inspect for build information
"""
logger.warning("=" * 80)
logger.warning("⚠️ GPU NOT DETECTED - DIAGNOSTIC INFORMATION")
logger.warning("=" * 80)
# Check PyTorch build type
is_rocm_build = hasattr(torch_module.version, 'hip') and torch_module.version.hip is not None
is_cuda_build = hasattr(torch_module.version, 'cuda') and torch_module.version.cuda is not None
if is_rocm_build:
logger.warning("✓ PyTorch ROCm build detected")
logger.warning(f" HIP version: {torch_module.version.hip}")
logger.warning("")
logger.warning("❌ torch.cuda.is_available() returned False")
logger.warning("")
logger.warning("Common causes for AMD/ROCm GPUs:")
logger.warning(" 1. ROCm drivers not installed or not properly configured")
logger.warning(" 2. GPU not supported by installed ROCm version")
logger.warning(" 3. Missing or incorrect HSA_OVERRIDE_GFX_VERSION environment variable")
logger.warning(" 4. ROCm runtime libraries not in system path")
logger.warning("")
# Check for common environment variables
hsa_override = os.environ.get('HSA_OVERRIDE_GFX_VERSION')
if hsa_override:
logger.warning(f" HSA_OVERRIDE_GFX_VERSION is set to: {hsa_override}")
else:
logger.warning(" ⚠️ HSA_OVERRIDE_GFX_VERSION is not set")
logger.warning(" For RDNA3 GPUs (RX 7000 series, RX 9000 series):")
logger.warning(" - RX 7900 XT/XTX, RX 9070 XT: set HSA_OVERRIDE_GFX_VERSION=11.0.0")
logger.warning(" - RX 7800 XT, RX 7700 XT: set HSA_OVERRIDE_GFX_VERSION=11.0.1")
logger.warning(" - RX 7600: set HSA_OVERRIDE_GFX_VERSION=11.0.2")
logger.warning("")
logger.warning("Troubleshooting steps:")
logger.warning(" 1. Verify ROCm installation:")
logger.warning(" rocm-smi # Should list your GPU")
logger.warning(" 2. Check PyTorch ROCm build:")
logger.warning(" python -c \"import torch; print(f'ROCm: {torch.version.hip}')\"")
logger.warning(" 3. Set HSA_OVERRIDE_GFX_VERSION for your GPU (see above)")
logger.warning(" 4. On Windows: Use start_gradio_ui_rocm.bat which sets required env vars")
logger.warning(" 5. See docs/en/ACE-Step1.5-Rocm-Manual-Linux.md for Linux setup")
logger.warning(" 6. See requirements-rocm.txt for Windows ROCm setup instructions")
elif is_cuda_build:
logger.warning("✓ PyTorch CUDA build detected")
logger.warning(f" CUDA version: {torch_module.version.cuda}")
logger.warning("")
logger.warning("❌ torch.cuda.is_available() returned False")
logger.warning("")
logger.warning("Common causes for NVIDIA GPUs:")
logger.warning(" 1. NVIDIA drivers not installed")
logger.warning(" 2. CUDA runtime not installed or version mismatch")
logger.warning(" 3. GPU not supported by installed CUDA version")
logger.warning("")
logger.warning("Troubleshooting steps:")
logger.warning(" 1. Verify NVIDIA driver installation:")
logger.warning(" nvidia-smi # Should list your GPU")
logger.warning(" 2. Check CUDA version compatibility")
logger.warning(" 3. Reinstall PyTorch with CUDA support:")
logger.warning(f" pip install torch --index-url {PYTORCH_CUDA_INSTALL_URL}")
else:
logger.warning("⚠️ PyTorch build type: CPU-only")
logger.warning("")
logger.warning("You have installed a CPU-only version of PyTorch!")
logger.warning("")
logger.warning("For NVIDIA GPUs:")
logger.warning(f" pip install torch --index-url {PYTORCH_CUDA_INSTALL_URL}")
logger.warning("")
logger.warning("For AMD GPUs with ROCm:")
logger.warning(" Windows: See requirements-rocm.txt for detailed instructions")
logger.warning(f" Linux: pip install torch --index-url {PYTORCH_ROCM_INSTALL_URL}")
logger.warning("")
logger.warning("For more information, see README.md section 'AMD / ROCm GPUs'")
logger.warning("=" * 80)
def get_gpu_tier(gpu_memory_gb: float) -> str:
"""
Determine GPU tier based on available memory.
Args:
gpu_memory_gb: GPU memory in GB
Returns:
Tier string: "tier1", "tier2", "tier3", "tier4", "tier5", "tier6", or "unlimited"
"""
if gpu_memory_gb <= 0:
# CPU mode - use tier1 limits
return "tier1"
elif gpu_memory_gb <= 4:
return "tier1"
elif gpu_memory_gb <= 6:
return "tier2"
elif gpu_memory_gb <= 8:
return "tier3"
elif gpu_memory_gb <= 12:
return "tier4"
elif gpu_memory_gb < VRAM_16GB_MIN_GB:
return "tier5"
elif gpu_memory_gb <= 24:
if gpu_memory_gb < 16.0:
logger.info(f"Detected {gpu_memory_gb:.2f}GB VRAM — treating as 16GB class GPU")
return "tier6"
else:
return "unlimited"
def get_gpu_config(gpu_memory_gb: Optional[float] = None) -> GPUConfig:
"""
Get GPU configuration based on detected or provided GPU memory.
Args:
gpu_memory_gb: GPU memory in GB. If None, will be auto-detected.
Returns:
GPUConfig object with all configuration parameters
"""
if gpu_memory_gb is None:
gpu_memory_gb = get_gpu_memory_gb()
tier = get_gpu_tier(gpu_memory_gb)
config = GPU_TIER_CONFIGS[tier]
return GPUConfig(
tier=tier,
gpu_memory_gb=gpu_memory_gb,
max_duration_with_lm=config["max_duration_with_lm"],
max_duration_without_lm=config["max_duration_without_lm"],
max_batch_size_with_lm=config["max_batch_size_with_lm"],
max_batch_size_without_lm=config["max_batch_size_without_lm"],
init_lm_default=config["init_lm_default"],
available_lm_models=config["available_lm_models"],
lm_memory_gb=config["lm_memory_gb"],
)
def get_lm_model_size(model_path: str) -> str:
"""
Extract LM model size from model path.
Args:
model_path: Model path string (e.g., "acestep-5Hz-lm-0.6B")
Returns:
Model size string: "0.6B", "1.7B", or "4B"
"""
if "0.6B" in model_path:
return "0.6B"
elif "1.7B" in model_path:
return "1.7B"
elif "4B" in model_path:
return "4B"
else:
# Default to smallest model assumption
return "0.6B"
def get_lm_gpu_memory_ratio(model_path: str, total_gpu_memory_gb: float) -> Tuple[float, float]:
"""
Calculate GPU memory utilization ratio for LM model.
Args:
model_path: LM model path (e.g., "acestep-5Hz-lm-0.6B")
total_gpu_memory_gb: Total GPU memory in GB
Returns:
Tuple of (gpu_memory_utilization_ratio, target_memory_gb)
"""
model_size = get_lm_model_size(model_path)
# Target memory allocation for each model size
target_memory = {
"0.6B": 3.0,
"1.7B": 8.0,
"4B": 12.0,
}
target_gb = target_memory.get(model_size, 3.0)
# For large GPUs (>=24GB), don't restrict memory too much
if total_gpu_memory_gb >= 24:
# Use a reasonable ratio that allows the model to run efficiently
ratio = min(0.9, max(0.2, target_gb / total_gpu_memory_gb))
else:
# For smaller GPUs, strictly limit memory usage
ratio = min(0.9, max(0.1, target_gb / total_gpu_memory_gb))
return ratio, target_gb
def check_duration_limit(
duration: float,
gpu_config: GPUConfig,
lm_initialized: bool
) -> Tuple[bool, str]:
"""
Check if requested duration is within limits for current GPU configuration.
Args:
duration: Requested duration in seconds
gpu_config: Current GPU configuration
lm_initialized: Whether LM is initialized
Returns:
Tuple of (is_valid, warning_message)
"""
max_duration = gpu_config.max_duration_with_lm if lm_initialized else gpu_config.max_duration_without_lm
if duration > max_duration:
warning_msg = (
f"⚠️ Requested duration ({duration:.0f}s) exceeds the limit for your GPU "
f"({gpu_config.gpu_memory_gb:.1f}GB). Maximum allowed: {max_duration}s "
f"({'with' if lm_initialized else 'without'} LM). "
f"Duration will be clamped to {max_duration}s."
)
return False, warning_msg
return True, ""
def check_batch_size_limit(
batch_size: int,
gpu_config: GPUConfig,
lm_initialized: bool
) -> Tuple[bool, str]:
"""
Check if requested batch size is within limits for current GPU configuration.
Args:
batch_size: Requested batch size
gpu_config: Current GPU configuration
lm_initialized: Whether LM is initialized
Returns:
Tuple of (is_valid, warning_message)
"""
max_batch_size = gpu_config.max_batch_size_with_lm if lm_initialized else gpu_config.max_batch_size_without_lm
if batch_size > max_batch_size:
warning_msg = (
f"⚠️ Requested batch size ({batch_size}) exceeds the limit for your GPU "
f"({gpu_config.gpu_memory_gb:.1f}GB). Maximum allowed: {max_batch_size} "
f"({'with' if lm_initialized else 'without'} LM). "
f"Batch size will be clamped to {max_batch_size}."
)
return False, warning_msg
return True, ""
def is_lm_model_supported(model_path: str, gpu_config: GPUConfig) -> Tuple[bool, str]:
"""
Check if the specified LM model is supported for current GPU configuration.
Args:
model_path: LM model path
gpu_config: Current GPU configuration
Returns:
Tuple of (is_supported, warning_message)
"""
if not gpu_config.available_lm_models:
return False, (
f"⚠️ Your GPU ({gpu_config.gpu_memory_gb:.1f}GB) does not have enough memory "
f"to run any LM model. Please disable LM initialization."
)
model_size = get_lm_model_size(model_path)
# Check if model size is in available models
for available_model in gpu_config.available_lm_models:
if model_size in available_model:
return True, ""
return False, (
f"⚠️ LM model {model_path} ({model_size}) is not supported for your GPU "
f"({gpu_config.gpu_memory_gb:.1f}GB). Available models: {', '.join(gpu_config.available_lm_models)}"
)
def get_recommended_lm_model(gpu_config: GPUConfig) -> Optional[str]:
"""
Get recommended LM model for current GPU configuration.
Args:
gpu_config: Current GPU configuration
Returns:
Recommended LM model path, or None if LM is not supported
"""
if not gpu_config.available_lm_models:
return None
# Return the largest available model (last in the list)
return gpu_config.available_lm_models[-1]
def print_gpu_config_info(gpu_config: GPUConfig):
"""Print GPU configuration information for debugging."""
logger.info(f"GPU Configuration:")
logger.info(f" - GPU Memory: {gpu_config.gpu_memory_gb:.1f} GB")
logger.info(f" - Tier: {gpu_config.tier}")
logger.info(f" - Max Duration (with LM): {gpu_config.max_duration_with_lm}s ({gpu_config.max_duration_with_lm // 60} min)")
logger.info(f" - Max Duration (without LM): {gpu_config.max_duration_without_lm}s ({gpu_config.max_duration_without_lm // 60} min)")
logger.info(f" - Max Batch Size (with LM): {gpu_config.max_batch_size_with_lm}")
logger.info(f" - Max Batch Size (without LM): {gpu_config.max_batch_size_without_lm}")
logger.info(f" - Init LM by Default: {gpu_config.init_lm_default}")
logger.info(f" - Available LM Models: {gpu_config.available_lm_models or 'None'}")
# Global GPU config instance (initialized lazily)
_global_gpu_config: Optional[GPUConfig] = None
def get_global_gpu_config() -> GPUConfig:
"""Get the global GPU configuration, initializing if necessary."""
global _global_gpu_config
if _global_gpu_config is None:
_global_gpu_config = get_gpu_config()
return _global_gpu_config
def set_global_gpu_config(config: GPUConfig):
"""Set the global GPU configuration."""
global _global_gpu_config
_global_gpu_config = config