ACE-Step-Custom / src /ace_step_engine.py
ACE-Step Custom
Fix: Implement lazy model loading for ZeroGPU compatibility
052ca84
"""
ACE-Step Engine - Wrapper for ACE-Step 1.5 official architecture
Properly integrates AceStepHandler (DiT) and LLMHandler (5Hz LM)
"""
import torch
from pathlib import Path
import logging
from typing import Optional, Dict, Any, Tuple
import os
logger = logging.getLogger(__name__)

# Optional dependency: the official ACE-Step 1.5 handlers. If any of these
# imports fails, a warning is logged and ACE_STEP_AVAILABLE stays False, so
# the module still imports and callers can check the flag.
ACE_STEP_AVAILABLE = False
try:
    from acestep.handler import AceStepHandler
    from acestep.llm_inference import LLMHandler
    from acestep.inference import GenerationParams, GenerationConfig, generate_music
    from acestep.model_downloader import (
        check_main_model_exists,
        ensure_main_model,
        get_checkpoints_dir,
    )
except ImportError as import_error:
    logger.warning(f"ACE-Step 1.5 modules not available: {import_error}")
else:
    ACE_STEP_AVAILABLE = True
class ACEStepEngine:
    """Wrapper engine for ACE-Step 1.5 with a custom interface.

    Construction is cheap. No device is detected and no model is downloaded
    or loaded until the first call that needs them (see
    ``_ensure_models_loaded``). Deferring this work keeps device detection
    inside the GPU context on ZeroGPU.

    Config keys read by this class (all optional):
        checkpoint_dir: Passed to ``get_checkpoints_dir`` to locate the
            checkpoints for both downloading and LLM loading.
        dit_model_path: DiT model name (default ``"acestep-v15-turbo"``).
        lm_model_path: 5Hz LM model name (default ``"acestep-5Hz-lm-1.7B"``).
        output_dir: Directory for generated audio (default ``"outputs"``).
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize ACE-Step engine without loading any models.

        Args:
            config: Configuration dictionary (keys listed in the class docstring).
        """
        self.config = config
        self._initialized = False
        self.dit_handler = None
        self.llm_handler = None
        # Set in _ensure_models_loaded(); None until the first model load.
        self.device = None
        logger.info("ACE-Step Engine created (GPU will be detected on first use)")
        if not ACE_STEP_AVAILABLE:
            logger.error("ACE-Step 1.5 modules not available")
            logger.error("Please ensure acestep package is installed in your environment")
            return
        logger.info("✓ ACE-Step Engine created (models will load on first use)")

    def _checkpoints_dir(self):
        """Resolve the checkpoint directory from ``config["checkpoint_dir"]``.

        Both ``_download_checkpoints`` and ``_load_models`` use this helper,
        so models are loaded from the same directory they were downloaded to.
        """
        return get_checkpoints_dir(self.config.get("checkpoint_dir"))

    def _prepare_output_dir(self) -> str:
        """Return ``config["output_dir"]`` (default ``"outputs"``), creating it if missing."""
        output_dir = self.config.get("output_dir", "outputs")
        os.makedirs(output_dir, exist_ok=True)
        return output_dir

    def _download_checkpoints(self):
        """Download model checkpoints from HuggingFace if not already present.

        Raises:
            RuntimeError: If ``ensure_main_model`` reports failure. Any other
                exception from the downloader is logged and re-raised.
        """
        checkpoints_dir = self._checkpoints_dir()
        # Skip the (multi-GB) download when the main model is already on disk.
        if check_main_model_exists(checkpoints_dir):
            logger.info(f"✓ ACE-Step 1.5 models already exist at {checkpoints_dir}")
            return
        logger.info("Downloading ACE-Step 1.5 models from HuggingFace...")
        logger.info("This may take several minutes (models are ~7GB total)...")
        try:
            success, message = ensure_main_model(
                checkpoints_dir=checkpoints_dir,
                prefer_source="huggingface",  # Use HuggingFace for Spaces
            )
            if not success:
                raise RuntimeError(f"Failed to download models: {message}")
            logger.info(f"✓ {message}")
            logger.info("✓ All ACE-Step 1.5 models downloaded successfully")
        except Exception as e:
            logger.error(f"Failed to download checkpoints: {e}")
            raise

    def _load_models(self):
        """Initialize the DiT handler (required) and the LLM handler (optional).

        Creates the handlers if they do not exist yet. Sets ``_initialized``
        on success. An LLM initialization failure is logged and tolerated,
        leaving the engine in DiT-only mode.

        Raises:
            RuntimeError: If ACE-Step is unavailable or the DiT fails to
                initialize. Other exceptions are logged and re-raised.
        """
        try:
            if not ACE_STEP_AVAILABLE:
                raise RuntimeError("ACE-Step 1.5 not available")
            # Create the handlers here so this method also works when it is
            # called directly, not only via _ensure_models_loaded().
            if self.dit_handler is None:
                self.dit_handler = AceStepHandler()
            if self.llm_handler is None:
                self.llm_handler = LLMHandler()

            dit_model_path = self.config.get("dit_model_path", "acestep-v15-turbo")
            lm_model_path = self.config.get("lm_model_path", "acestep-5Hz-lm-1.7B")
            # Same directory that _download_checkpoints() checked/populated.
            checkpoints_dir = self._checkpoints_dir()
            # Project root = parent of the directory containing this file.
            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

            logger.info(f"Initializing DiT handler with model: {dit_model_path}")
            # Initialize DiT handler (handles main diffusion model, VAE, text encoder).
            # Note: handler auto-detects checkpoints dir as project_root/checkpoints;
            # config_path should be just the model name, not a full path.
            # NOTE(review): the DiT handler never receives checkpoints_dir. Confirm
            # it resolves the same directory when config["checkpoint_dir"] is customised.
            status_dit, success_dit = self.dit_handler.initialize_service(
                project_root=project_root,
                config_path=dit_model_path,  # Just model name, handler adds checkpoints/
                device="auto",
                use_flash_attention=False,
                compile_model=False,
                offload_to_cpu=False,
            )
            if not success_dit:
                raise RuntimeError(f"Failed to initialize DiT: {status_dit}")
            logger.info(f"✓ DiT initialized: {status_dit}")

            # Initialize LLM handler (handles 5Hz Language Model)
            logger.info(f"Initializing LLM handler with model: {lm_model_path}")
            status_llm, success_llm = self.llm_handler.initialize(
                checkpoint_dir=str(checkpoints_dir),
                lm_model_path=lm_model_path,
                backend="pt",  # Use PyTorch backend for compatibility
                device="auto",
                offload_to_cpu=False,
            )
            if not success_llm:
                # Deliberately non-fatal: generation can proceed without the LM.
                logger.warning(f"LLM initialization failed: {status_llm}")
                logger.warning("Continuing without LLM (DiT-only mode)")
            else:
                logger.info(f" LLM initialized: {status_llm}")

            self._initialized = True
            logger.info(" ACE-Step engine fully initialized")
        except Exception as e:
            logger.error(f"Failed to initialize models: {e}")
            raise

    def _ensure_models_loaded(self):
        """Load models on first use (lazy loading for ZeroGPU compatibility).

        This is a no-op once the engine is initialized. The device is
        detected here rather than in ``__init__``, so detection happens
        inside the GPU context on ZeroGPU.

        Raises:
            RuntimeError: If the ACE-Step package could not be imported.
                Exceptions from downloading/initializing are logged and re-raised.
        """
        if self._initialized:
            return
        # Fail with a clear error before touching any ACE-Step name. Otherwise
        # the first use of an un-imported handler class raises NameError.
        if not ACE_STEP_AVAILABLE:
            raise RuntimeError("ACE-Step 1.5 not available")
        logger.info("Lazy loading models on first use...")
        # Detect device now (within GPU context on ZeroGPU)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {self.device}")
        try:
            self._download_checkpoints()
            self._load_models()
            logger.info("✓ Models loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load models: {e}")
            raise

    def generate(
        self,
        prompt: str,
        lyrics: Optional[str] = None,
        duration: int = 30,
        temperature: float = 0.7,
        top_p: float = 0.9,
        seed: int = -1,
        style: str = "auto",
        lora_path: Optional[str] = None
    ) -> str:
        """
        Generate music using ACE-Step (text2music task).

        Args:
            prompt: Text description of desired music (used as the caption).
            lyrics: Optional lyrics; ``None`` is sent as an empty string.
            duration: Duration in seconds.
            temperature: Sampling temperature (for LLM).
            top_p: Nucleus sampling parameter (for LLM).
            seed: Random seed. Any negative value requests a random seed.
            style: Music style. Currently not forwarded to the model.
            lora_path: Path to a LoRA model. LoRA loading is not implemented;
                a warning is logged and the path is ignored.

        Returns:
            Path to generated audio file.

        Raises:
            RuntimeError: If models cannot be loaded or no audio is produced.
        """
        # Ensure models are loaded (lazy loading for ZeroGPU)
        self._ensure_models_loaded()
        if lora_path:
            # Do not silently drop a requested adapter.
            logger.warning(f"LoRA path provided but LoRA loading is not implemented; ignoring: {lora_path}")
        try:
            params = GenerationParams(
                task_type="text2music",
                caption=prompt,
                lyrics=lyrics or "",
                duration=duration,
                inference_steps=8,  # Turbo model default
                seed=seed if seed >= 0 else -1,  # normalize any negative seed to -1
                thinking=True,  # Use LLM planning
                lm_temperature=temperature,
                lm_top_p=top_p,
            )
            config = GenerationConfig(
                batch_size=1,
                use_random_seed=(seed < 0),
                audio_format="wav",
            )
            output_dir = self._prepare_output_dir()
            logger.info(f"Generating {duration}s audio: {prompt[:50]}...")
            result = generate_music(
                dit_handler=self.dit_handler,
                llm_handler=self.llm_handler,
                params=params,
                config=config,
                save_dir=output_dir,
            )
            if not result.audio_paths:
                raise RuntimeError("No audio generated")
            output_path = result.audio_paths[0]
            logger.info(f" Generated: {output_path}")
            return output_path
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise

    def generate_clip(
        self,
        prompt: str,
        lyrics: str,
        duration: int,
        context_audio: Optional[str] = None,
        style: str = "auto",
        temperature: float = 0.7,
        seed: int = -1
    ) -> str:
        """
        Generate an audio clip for the timeline.

        Context conditioning is not implemented. ``context_audio`` is accepted
        for interface compatibility but ignored, and the call is delegated to
        ``generate``.

        Args:
            prompt: Text prompt.
            lyrics: Lyrics for this clip.
            duration: Duration in seconds (typically 32).
            context_audio: Path to previous audio (currently unused).
            style: Music style (forwarded to ``generate``).
            temperature: Sampling temperature.
            seed: Random seed.

        Returns:
            Path to generated clip.
        """
        return self.generate(
            prompt=prompt,
            lyrics=lyrics,
            duration=duration,
            temperature=temperature,
            seed=seed,
            style=style
        )

    def generate_variation(self, audio_path: str, strength: float = 0.5) -> str:
        """Generate a variation of existing audio.

        Args:
            audio_path: Source audio file.
            strength: Forwarded as ``audio_cover_strength``.

        Returns:
            Path to the generated audio. If nothing is generated, a warning
            is logged and ``audio_path`` itself is returned.
        """
        # Ensure models are loaded (lazy loading for ZeroGPU)
        self._ensure_models_loaded()
        try:
            # NOTE(review): confirm "audio_variation" is a task_type that
            # GenerationParams accepts and that the source-audio field is named audio_path.
            params = GenerationParams(
                task_type="audio_variation",
                audio_path=audio_path,
                audio_cover_strength=strength,
                inference_steps=8,
            )
            config = GenerationConfig(
                batch_size=1,
                audio_format="wav",
            )
            result = generate_music(
                dit_handler=self.dit_handler,
                llm_handler=self.llm_handler,
                params=params,
                config=config,
                save_dir=self._prepare_output_dir(),
            )
            if result.audio_paths:
                return result.audio_paths[0]
            logger.warning(f"No variation generated; returning original audio: {audio_path}")
            return audio_path
        except Exception as e:
            logger.error(f"Variation generation failed: {e}")
            raise

    def repaint(
        self,
        audio_path: str,
        start_time: float,
        end_time: float,
        new_prompt: str
    ) -> str:
        """Repaint a specific section of audio.

        Models are loaded lazily on first use, as for every other entry point.

        Args:
            audio_path: Source audio file.
            start_time: Section start, forwarded as ``repainting_start``
                (presumably seconds; confirm against GenerationParams).
            end_time: Section end, forwarded as ``repainting_end``.
            new_prompt: Caption for the repainted section.

        Returns:
            Path to the generated audio. If nothing is generated, a warning
            is logged and ``audio_path`` itself is returned.
        """
        # Previously this raised unless another call had already loaded the
        # models; lazy loading makes repaint usable as the first call.
        self._ensure_models_loaded()
        try:
            # NOTE(review): confirm "repainting" is the task_type name
            # GenerationParams expects.
            params = GenerationParams(
                task_type="repainting",
                audio_path=audio_path,
                caption=new_prompt,
                repainting_start=start_time,
                repainting_end=end_time,
                inference_steps=8,
            )
            config = GenerationConfig(
                batch_size=1,
                audio_format="wav",
            )
            result = generate_music(
                dit_handler=self.dit_handler,
                llm_handler=self.llm_handler,
                params=params,
                config=config,
                save_dir=self._prepare_output_dir(),
            )
            if result.audio_paths:
                return result.audio_paths[0]
            logger.warning(f"No repainted audio generated; returning original audio: {audio_path}")
            return audio_path
        except Exception as e:
            logger.error(f"Repainting failed: {e}")
            raise

    def edit_lyrics(self, audio_path: str, new_lyrics: str) -> str:
        """Edit lyrics while maintaining music (approximation).

        Real lyric editing is not implemented. ``audio_path`` is currently
        unused. A fresh 30-second track is generated with ``new_lyrics`` and
        a generic prompt.
        """
        logger.warning("Lyric editing not fully implemented - regenerating with new lyrics")
        return self.generate(
            prompt="Match the style of the reference",
            lyrics=new_lyrics,
            duration=30,
        )

    def is_initialized(self) -> bool:
        """Return True once the models have been loaded successfully."""
        return self._initialized