Spaces: Running on Zero
| """ | |
| ACE-Step Engine - Wrapper for ACE-Step 1.5 official architecture | |
| Properly integrates AceStepHandler (DiT) and LLMHandler (5Hz LM) | |
| """ | |
| import torch | |
| from pathlib import Path | |
| import logging | |
| from typing import Optional, Dict, Any, Tuple | |
| import os | |
| logger = logging.getLogger(__name__) | |
# Import ACE-Step 1.5 official handlers.
# Wrapped in try/except so this module can still be imported (and report a
# clear error later) in environments where the acestep package is missing.
try:
    from acestep.handler import AceStepHandler
    from acestep.llm_inference import LLMHandler
    from acestep.inference import GenerationParams, GenerationConfig, generate_music
    from acestep.model_downloader import ensure_main_model, get_checkpoints_dir, check_main_model_exists
    # Flag checked by ACEStepEngine before attempting to load models.
    ACE_STEP_AVAILABLE = True
except ImportError as e:
    logger.warning(f"ACE-Step 1.5 modules not available: {e}")
    ACE_STEP_AVAILABLE = False
class ACEStepEngine:
    """Wrapper engine for ACE-Step 1.5 with custom interface.

    Wraps the official AceStepHandler (DiT diffusion model) and LLMHandler
    (5Hz language model). Handlers and weights are loaded lazily on first
    use so the engine can be constructed outside a GPU context (ZeroGPU).
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize ACE-Step engine.

        Args:
            config: Configuration dictionary. Recognized keys (all optional):
                "checkpoint_dir", "dit_model_path", "lm_model_path",
                "output_dir".
        """
        self.config = config
        self._initialized = False
        # Handlers are created lazily in _ensure_models_loaded() so heavy
        # model objects are only built inside a GPU context.
        self.dit_handler = None
        self.llm_handler = None
        # Fix: define the attribute up front; it was previously set only in
        # _ensure_models_loaded(), so it did not exist before first use.
        # Device is resolved lazily (ZeroGPU allocates the GPU late).
        self.device = None
        if not ACE_STEP_AVAILABLE:
            # Fix: previously a "created" message was logged unconditionally
            # (via an f-string with no placeholders) even on this error path.
            logger.error("ACE-Step 1.5 modules not available")
            logger.error("Please ensure acestep package is installed in your environment")
            return
        logger.info("✓ ACE-Step Engine created (models will load on first use)")
| def _download_checkpoints(self): | |
| """Download model checkpoints from HuggingFace if not present.""" | |
| checkpoints_dir = get_checkpoints_dir(self.config.get("checkpoint_dir")) | |
| # Check if main model already exists | |
| if check_main_model_exists(checkpoints_dir): | |
| logger.info(f"✓ ACE-Step 1.5 models already exist at {checkpoints_dir}") | |
| return | |
| logger.info("Downloading ACE-Step 1.5 models from HuggingFace...") | |
| logger.info("This may take several minutes (models are ~7GB total)...") | |
| try: | |
| # Use the built-in model downloader | |
| success, message = ensure_main_model( | |
| checkpoints_dir=checkpoints_dir, | |
| prefer_source="huggingface" # Use HuggingFace for Spaces | |
| ) | |
| if not success: | |
| raise RuntimeError(f"Failed to download models: {message}") | |
| logger.info(f"✓ {message}") | |
| logger.info("✓ All ACE-Step 1.5 models downloaded successfully") | |
| except Exception as e: | |
| logger.error(f"Failed to download checkpoints: {e}") | |
| raise | |
| def _load_models(self): | |
| """Initialize and load ACE-Step models.""" | |
| try: | |
| if not ACE_STEP_AVAILABLE: | |
| raise RuntimeError("ACE-Step 1.5 not available") | |
| checkpoint_dir = self.config.get("checkpoint_dir", "./checkpoints") | |
| dit_model_path = self.config.get("dit_model_path", "acestep-v15-turbo") | |
| lm_model_path = self.config.get("lm_model_path", "acestep-5Hz-lm-1.7B") | |
| # Get checkpoints directory using helper function | |
| checkpoints_dir = get_checkpoints_dir(checkpoint_dir) | |
| # Get project root | |
| project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | |
| logger.info(f"Initializing DiT handler with model: {dit_model_path}") | |
| # Initialize DiT handler (handles main diffusion model, VAE, text encoder) | |
| # Note: handler auto-detects checkpoints dir as project_root/checkpoints | |
| # config_path should be just the model name, not full path | |
| status_dit, success_dit = self.dit_handler.initialize_service( | |
| project_root=project_root, | |
| config_path=dit_model_path, # Just model name, handler adds checkpoints/ | |
| device="auto", | |
| use_flash_attention=False, | |
| compile_model=False, | |
| offload_to_cpu=False, | |
| ) | |
| if not success_dit: | |
| raise RuntimeError(f"Failed to initialize DiT: {status_dit}") | |
| logger.info(f"✓ DiT initialized: {status_dit}") | |
| # Initialize LLM handler (handles 5Hz Language Model) | |
| logger.info(f"Initializing LLM handler with model: {lm_model_path}") | |
| status_llm, success_llm = self.llm_handler.initialize( | |
| checkpoint_dir=str(checkpoints_dir), | |
| lm_model_path=lm_model_path, | |
| backend="pt", # Use PyTorch backend for compatibility | |
| device="auto", | |
| offload_to_cpu=False, | |
| ) | |
| if not success_llm: | |
| logger.warning(f"LLM initialization failed: {status_llm}") | |
| logger.warning("Continuing without LLM (DiT-only mode)") | |
| else: | |
| logger.info(f" LLM initialized: {status_llm}") | |
| self._initialized = True | |
| logger.info(" ACE-Step engine fully initialized") | |
| except Exception as e: | |
| logger.error(f"Failed to initialize models: {e}") | |
| raise | |
| def _ensure_models_loaded(self): | |
| """Ensure models are loaded (lazy loading for ZeroGPU compatibility).""" | |
| if not self._initialized: | |
| logger.info("Lazy loading models on first use...") | |
| # Detect device now (within GPU context on ZeroGPU) | |
| self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| logger.info(f"Using device: {self.device}") | |
| # Create handlers if not already created | |
| if self.dit_handler is None: | |
| self.dit_handler = AceStepHandler() | |
| if self.llm_handler is None: | |
| self.llm_handler = LLMHandler() | |
| try: | |
| # Download and load models | |
| self._download_checkpoints() | |
| self._load_models() | |
| logger.info("✓ Models loaded successfully") | |
| except Exception as e: | |
| logger.error(f"Failed to load models: {e}") | |
| raise | |
| def generate( | |
| self, | |
| prompt: str, | |
| lyrics: Optional[str] = None, | |
| duration: int = 30, | |
| temperature: float = 0.7, | |
| top_p: float = 0.9, | |
| seed: int = -1, | |
| style: str = "auto", | |
| lora_path: Optional[str] = None | |
| ) -> str: | |
| """ | |
| Generate music using ACE-Step. | |
| Args: | |
| prompt: Text description of desired music | |
| lyrics: Optional lyrics | |
| duration: Duration in seconds | |
| temperature: Sampling temperature (for LLM) | |
| top_p: Nucleus sampling parameter (for LLM) | |
| seed: Random seed (-1 for random) | |
| style: Music style | |
| lora_path: Path to LoRA model if using | |
| Returns: | |
| Path to generated audio file | |
| """ | |
| # Ensure models are loaded (lazy loading for ZeroGPU) | |
| self._ensure_models_loaded() | |
| try: | |
| # Prepare generation parameters | |
| params = GenerationParams( | |
| task_type="text2music", | |
| caption=prompt, | |
| lyrics=lyrics or "", | |
| duration=duration, | |
| inference_steps=8, # Turbo model default | |
| seed=seed if seed >= 0 else -1, | |
| thinking=True, # Use LLM planning | |
| lm_temperature=temperature, | |
| lm_top_p=top_p, | |
| ) | |
| # Prepare generation config | |
| config = GenerationConfig( | |
| batch_size=1, | |
| use_random_seed=(seed < 0), | |
| audio_format="wav", | |
| ) | |
| # Generate using official inference | |
| output_dir = self.config.get("output_dir", "outputs") | |
| os.makedirs(output_dir, exist_ok=True) | |
| logger.info(f"Generating {duration}s audio: {prompt[:50]}...") | |
| result = generate_music( | |
| dit_handler=self.dit_handler, | |
| llm_handler=self.llm_handler, | |
| params=params, | |
| config=config, | |
| save_dir=output_dir, | |
| ) | |
| if result.audio_paths: | |
| output_path = result.audio_paths[0] | |
| logger.info(f" Generated: {output_path}") | |
| return output_path | |
| else: | |
| raise RuntimeError("No audio generated") | |
| except Exception as e: | |
| logger.error(f"Generation failed: {e}") | |
| raise | |
| def generate_clip( | |
| self, | |
| prompt: str, | |
| lyrics: str, | |
| duration: int, | |
| context_audio: Optional[str] = None, | |
| style: str = "auto", | |
| temperature: float = 0.7, | |
| seed: int = -1 | |
| ) -> str: | |
| """ | |
| Generate audio clip for timeline (with context conditioning). | |
| Args: | |
| prompt: Text prompt | |
| lyrics: Lyrics for this clip | |
| duration: Duration in seconds (typically 32) | |
| context_audio: Path to previous audio for style conditioning | |
| style: Music style | |
| temperature: Sampling temperature | |
| seed: Random seed | |
| Returns: | |
| Path to generated clip | |
| """ | |
| # For timeline clips, use regular generation with extended context | |
| # Context conditioning would require custom implementation | |
| return self.generate( | |
| prompt=prompt, | |
| lyrics=lyrics, | |
| duration=duration, | |
| temperature=temperature, | |
| seed=seed, | |
| style=style | |
| ) | |
| def generate_variation(self, audio_path: str, strength: float = 0.5) -> str: | |
| """Generate variation of existing audio.""" | |
| # Ensure models are loaded (lazy loading for ZeroGPU) | |
| self._ensure_models_loaded() | |
| try: | |
| params = GenerationParams( | |
| task_type="audio_variation", | |
| audio_path=audio_path, | |
| audio_cover_strength=strength, | |
| inference_steps=8, | |
| ) | |
| config = GenerationConfig( | |
| batch_size=1, | |
| audio_format="wav", | |
| ) | |
| output_dir = self.config.get("output_dir", "outputs") | |
| result = generate_music( | |
| self.dit_handler, | |
| self.llm_handler, | |
| params, | |
| config, | |
| save_dir=output_dir, | |
| ) | |
| return result.audio_paths[0] if result.audio_paths else audio_path | |
| except Exception as e: | |
| logger.error(f"Variation generation failed: {e}") | |
| raise | |
| def repaint( | |
| self, | |
| audio_path: str, | |
| start_time: float, | |
| end_time: float, | |
| new_prompt: str | |
| ) -> str: | |
| """Repaint specific section of audio.""" | |
| if not self._initialized: | |
| raise RuntimeError("Engine not initialized") | |
| try: | |
| params = GenerationParams( | |
| task_type="repainting", | |
| audio_path=audio_path, | |
| caption=new_prompt, | |
| repainting_start=start_time, | |
| repainting_end=end_time, | |
| inference_steps=8, | |
| ) | |
| config = GenerationConfig( | |
| batch_size=1, | |
| audio_format="wav", | |
| ) | |
| output_dir = self.config.get("output_dir", "outputs") | |
| result = generate_music( | |
| self.dit_handler, | |
| self.llm_handler, | |
| params, | |
| config, | |
| save_dir=output_dir, | |
| ) | |
| return result.audio_paths[0] if result.audio_paths else audio_path | |
| except Exception as e: | |
| logger.error(f"Repainting failed: {e}") | |
| raise | |
| def edit_lyrics(self, audio_path: str, new_lyrics: str) -> str: | |
| """Edit lyrics while maintaining music.""" | |
| # This would require custom implementation | |
| # For now, regenerate with same style | |
| logger.warning("Lyric editing not fully implemented - regenerating with new lyrics") | |
| return self.generate( | |
| prompt="Match the style of the reference", | |
| lyrics=new_lyrics, | |
| duration=30, | |
| ) | |
| def is_initialized(self) -> bool: | |
| """Check if engine is initialized.""" | |
| return self._initialized | |