""" ACE-Step Engine - Wrapper for ACE-Step 1.5 official architecture Properly integrates AceStepHandler (DiT) and LLMHandler (5Hz LM) """ import torch from pathlib import Path import logging from typing import Optional, Dict, Any, Tuple import os logger = logging.getLogger(__name__) # Import ACE-Step 1.5 official handlers try: from acestep.handler import AceStepHandler from acestep.llm_inference import LLMHandler from acestep.inference import GenerationParams, GenerationConfig, generate_music from acestep.model_downloader import ensure_main_model, get_checkpoints_dir, check_main_model_exists ACE_STEP_AVAILABLE = True except ImportError as e: logger.warning(f"ACE-Step 1.5 modules not available: {e}") ACE_STEP_AVAILABLE = False class ACEStepEngine: """Wrapper engine for ACE-Step 1.5 with custom interface.""" def __init__(self, config: Dict[str, Any]): """ Initialize ACE-Step engine. Args: config: Configuration dictionary """ self.config = config self._initialized = False self.dit_handler = None self.llm_handler = None logger.info(f"ACE-Step Engine created (GPU will be detected on first use)") if not ACE_STEP_AVAILABLE: logger.error("ACE-Step 1.5 modules not available") logger.error("Please ensure acestep package is installed in your environment") return logger.info("✓ ACE-Step Engine created (models will load on first use)") def _download_checkpoints(self): """Download model checkpoints from HuggingFace if not present.""" checkpoints_dir = get_checkpoints_dir(self.config.get("checkpoint_dir")) # Check if main model already exists if check_main_model_exists(checkpoints_dir): logger.info(f"✓ ACE-Step 1.5 models already exist at {checkpoints_dir}") return logger.info("Downloading ACE-Step 1.5 models from HuggingFace...") logger.info("This may take several minutes (models are ~7GB total)...") try: # Use the built-in model downloader success, message = ensure_main_model( checkpoints_dir=checkpoints_dir, prefer_source="huggingface" # Use HuggingFace for Spaces ) if not success: raise RuntimeError(f"Failed to download models: {message}") logger.info(f"✓ {message}") logger.info("✓ All ACE-Step 1.5 models downloaded successfully") except Exception as e: logger.error(f"Failed to download checkpoints: {e}") raise def _load_models(self): """Initialize and load ACE-Step models.""" try: if not ACE_STEP_AVAILABLE: raise RuntimeError("ACE-Step 1.5 not available") checkpoint_dir = self.config.get("checkpoint_dir", "./checkpoints") dit_model_path = self.config.get("dit_model_path", "acestep-v15-turbo") lm_model_path = self.config.get("lm_model_path", "acestep-5Hz-lm-1.7B") # Get checkpoints directory using helper function checkpoints_dir = get_checkpoints_dir(checkpoint_dir) # Get project root project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) logger.info(f"Initializing DiT handler with model: {dit_model_path}") # Initialize DiT handler (handles main diffusion model, VAE, text encoder) # Note: handler auto-detects checkpoints dir as project_root/checkpoints # config_path should be just the model name, not full path status_dit, success_dit = self.dit_handler.initialize_service( project_root=project_root, config_path=dit_model_path, # Just model name, handler adds checkpoints/ device="auto", use_flash_attention=False, compile_model=False, offload_to_cpu=False, ) if not success_dit: raise RuntimeError(f"Failed to initialize DiT: {status_dit}") logger.info(f"✓ DiT initialized: {status_dit}") # Initialize LLM handler (handles 5Hz Language Model) logger.info(f"Initializing LLM handler with model: {lm_model_path}") status_llm, success_llm = self.llm_handler.initialize( checkpoint_dir=str(checkpoints_dir), lm_model_path=lm_model_path, backend="pt", # Use PyTorch backend for compatibility device="auto", offload_to_cpu=False, ) if not success_llm: logger.warning(f"LLM initialization failed: {status_llm}") logger.warning("Continuing without LLM (DiT-only mode)") else: logger.info(f" LLM initialized: {status_llm}") self._initialized = True logger.info(" ACE-Step engine fully initialized") except Exception as e: logger.error(f"Failed to initialize models: {e}") raise def _ensure_models_loaded(self): """Ensure models are loaded (lazy loading for ZeroGPU compatibility).""" if not self._initialized: logger.info("Lazy loading models on first use...") # Detect device now (within GPU context on ZeroGPU) self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"Using device: {self.device}") # Create handlers if not already created if self.dit_handler is None: self.dit_handler = AceStepHandler() if self.llm_handler is None: self.llm_handler = LLMHandler() try: # Download and load models self._download_checkpoints() self._load_models() logger.info("✓ Models loaded successfully") except Exception as e: logger.error(f"Failed to load models: {e}") raise def generate( self, prompt: str, lyrics: Optional[str] = None, duration: int = 30, temperature: float = 0.7, top_p: float = 0.9, seed: int = -1, style: str = "auto", lora_path: Optional[str] = None ) -> str: """ Generate music using ACE-Step. Args: prompt: Text description of desired music lyrics: Optional lyrics duration: Duration in seconds temperature: Sampling temperature (for LLM) top_p: Nucleus sampling parameter (for LLM) seed: Random seed (-1 for random) style: Music style lora_path: Path to LoRA model if using Returns: Path to generated audio file """ # Ensure models are loaded (lazy loading for ZeroGPU) self._ensure_models_loaded() try: # Prepare generation parameters params = GenerationParams( task_type="text2music", caption=prompt, lyrics=lyrics or "", duration=duration, inference_steps=8, # Turbo model default seed=seed if seed >= 0 else -1, thinking=True, # Use LLM planning lm_temperature=temperature, lm_top_p=top_p, ) # Prepare generation config config = GenerationConfig( batch_size=1, use_random_seed=(seed < 0), audio_format="wav", ) # Generate using official inference output_dir = self.config.get("output_dir", "outputs") os.makedirs(output_dir, exist_ok=True) logger.info(f"Generating {duration}s audio: {prompt[:50]}...") result = generate_music( dit_handler=self.dit_handler, llm_handler=self.llm_handler, params=params, config=config, save_dir=output_dir, ) if result.audio_paths: output_path = result.audio_paths[0] logger.info(f" Generated: {output_path}") return output_path else: raise RuntimeError("No audio generated") except Exception as e: logger.error(f"Generation failed: {e}") raise def generate_clip( self, prompt: str, lyrics: str, duration: int, context_audio: Optional[str] = None, style: str = "auto", temperature: float = 0.7, seed: int = -1 ) -> str: """ Generate audio clip for timeline (with context conditioning). Args: prompt: Text prompt lyrics: Lyrics for this clip duration: Duration in seconds (typically 32) context_audio: Path to previous audio for style conditioning style: Music style temperature: Sampling temperature seed: Random seed Returns: Path to generated clip """ # For timeline clips, use regular generation with extended context # Context conditioning would require custom implementation return self.generate( prompt=prompt, lyrics=lyrics, duration=duration, temperature=temperature, seed=seed, style=style ) def generate_variation(self, audio_path: str, strength: float = 0.5) -> str: """Generate variation of existing audio.""" # Ensure models are loaded (lazy loading for ZeroGPU) self._ensure_models_loaded() try: params = GenerationParams( task_type="audio_variation", audio_path=audio_path, audio_cover_strength=strength, inference_steps=8, ) config = GenerationConfig( batch_size=1, audio_format="wav", ) output_dir = self.config.get("output_dir", "outputs") result = generate_music( self.dit_handler, self.llm_handler, params, config, save_dir=output_dir, ) return result.audio_paths[0] if result.audio_paths else audio_path except Exception as e: logger.error(f"Variation generation failed: {e}") raise def repaint( self, audio_path: str, start_time: float, end_time: float, new_prompt: str ) -> str: """Repaint specific section of audio.""" if not self._initialized: raise RuntimeError("Engine not initialized") try: params = GenerationParams( task_type="repainting", audio_path=audio_path, caption=new_prompt, repainting_start=start_time, repainting_end=end_time, inference_steps=8, ) config = GenerationConfig( batch_size=1, audio_format="wav", ) output_dir = self.config.get("output_dir", "outputs") result = generate_music( self.dit_handler, self.llm_handler, params, config, save_dir=output_dir, ) return result.audio_paths[0] if result.audio_paths else audio_path except Exception as e: logger.error(f"Repainting failed: {e}") raise def edit_lyrics(self, audio_path: str, new_lyrics: str) -> str: """Edit lyrics while maintaining music.""" # This would require custom implementation # For now, regenerate with same style logger.warning("Lyric editing not fully implemented - regenerating with new lyrics") return self.generate( prompt="Match the style of the reference", lyrics=new_lyrics, duration=30, ) def is_initialized(self) -> bool: """Check if engine is initialized.""" return self._initialized