""" ACE-Step Engine - Core generation module Handles interaction with ACE-Step model for music generation """ import torch import torchaudio from pathlib import Path import logging from typing import Optional, Dict, Any import numpy as np logger = logging.getLogger(__name__) class ACEStepEngine: """Core engine for ACE-Step music generation.""" def __init__(self, config: Dict[str, Any]): """ Initialize ACE-Step engine. Args: config: Configuration dictionary """ self.config = config self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logger.info(f"ACE-Step Engine initialized on {self.device}") self.model = None self.text_tokenizer = None # Tokenizer for text encoder (Qwen3-Embedding-0.6B) self.text_encoder = None # Text encoder model self.llm_tokenizer = None # Tokenizer for LLM (5Hz Language Model) self.llm = None # LLM model for planning self.vae = None # VAE for audio encoding/decoding # Load tokenizer self.tokenizer = AutoTokenizer.from_pretrained(model_path) # Load main model self.model = AutoModel.from_pretrained( model_path, torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32, device_map="auto", low_cpu_mem_usage=True ) self.model.eval() logger.info("✅ ACE-Step model loaded successfully") except Exception as e: logger.error(f"Failed to load models: {e}") raise def generate( self, prompt: str, lyrics: Optional[str] = None, duration: int = 30, temperature: float = 0.7, top_p: float = 0.9, seed: int = -1, style: str = "auto", lora_path: Optional[str] = None ) -> str: """ Generate music using ACE-Step. Args: prompt: Text description of desired music lyrics: Optional lyrics duration: Duration in seconds temperature: Sampling temperature top_p: Nucleus sampling parameter seed: Random seed (-1 for random) style: Music style lora_path: Path to LoRA model if using Returns: Path to generated audio file """ try: # Set seed if seed >= 0: torch.manual_seed(seed) np.random.seed(seed) # Load LoRA if specified if lora_path: self._load_lora(lora_path) # Prepare input input_text = self._prepare_input(prompt, lyrics, style, duration) # Tokenize using text encoder tokenizer inputs = self.text_tokenizer( input_text, return_tensors="pt", padding=True, truncation=True ).to(self.device) # Generate logger.info(f"Generating {duration}s audio...") with torch.no_grad(): outputs = self.model.generate( **inputs, max_length=duration * 50, # Approximate tokens per second temperature=temperature, top_p=top_p, do_sample=True, num_return_sequences=1 ) # Decode to audio audio_tensor = self._decode_to_audio(outputs) # Save audio output_path = self._save_audio(audio_tensor, duration) logger.info(f"✅ Generated audio: {output_path}") return str(output_path) except Exception as e: logger.error(f"Generation failed: {e}") raise finally: # Unload LoRA if it was loaded if lora_path: self._unload_lora() def generate_clip( self, prompt: str, lyrics: str, duration: int, context_audio: Optional[np.ndarray] = None, style: str = "auto", temperature: float = 0.7, seed: int = -1 ) -> str: """ Generate audio clip with context conditioning. Used for timeline-based generation. Args: prompt: Text prompt lyrics: Lyrics for this clip duration: Duration in seconds (typically 32) context_audio: Previous audio for style conditioning style: Music style temperature: Sampling temperature seed: Random seed Returns: Path to generated clip """ try: if seed >= 0: torch.manual_seed(seed) # Prepare input with context input_text = self._prepare_input(prompt, lyrics, style, duration) # If context provided, use it for conditioning context_embedding = None if context_audio is not None: context_embedding = self._encode_audio_context(context_audio) inputs = self.text_tokenizer(input_text, return_tensors="pt").to(self.device) # Generate with context conditioning with torch.no_grad(): if context_embedding is not None: outputs = self.model.generate( **inputs, context_embedding=context_embedding, max_length=duration * 50, temperature=temperature, do_sample=True ) else: outputs = self.model.generate( **inputs, max_length=duration * 50, temperature=temperature, do_sample=True ) audio_tensor = self._decode_to_audio(outputs) output_path = self._save_audio(audio_tensor, duration, prefix="clip") return str(output_path) except Exception as e: logger.error(f"Clip generation failed: {e}") raise def generate_variation(self, audio_path: str, strength: float = 0.5) -> str: """Generate variation of existing audio.""" try: # Load audio audio, sr = torchaudio.load(audio_path) # Encode to latent space latent = self._encode_audio(audio) # Add noise for variation noise = torch.randn_like(latent) * strength varied_latent = latent + noise # Decode back to audio varied_audio = self._decode_from_latent(varied_latent) # Save output_path = self._save_audio(varied_audio, audio.shape[-1] / sr, prefix="variation") return str(output_path) except Exception as e: logger.error(f"Variation generation failed: {e}") raise def repaint( self, audio_path: str, start_time: float, end_time: float, new_prompt: str ) -> str: """Repaint specific section of audio.""" try: # Load original audio audio, sr = torchaudio.load(audio_path) # Calculate frame indices start_frame = int(start_time * sr) end_frame = int(end_time * sr) # Encode to latent latent = self._encode_audio(audio) # Generate new section section_duration = end_time - start_time new_section = self.generate( prompt=new_prompt, duration=int(section_duration), temperature=0.8 ) # Load new section new_audio, _ = torchaudio.load(new_section) # Blend sections result = audio.clone() result[:, start_frame:end_frame] = new_audio[:, :end_frame-start_frame] # Smooth transitions blend_length = int(0.5 * sr) # 0.5s blend if start_frame > blend_length: fade_in = torch.linspace(0, 1, blend_length).unsqueeze(0) result[:, start_frame:start_frame+blend_length] = ( result[:, start_frame:start_frame+blend_length] * fade_in + audio[:, start_frame:start_frame+blend_length] * (1 - fade_in) ) if end_frame < audio.shape[-1] - blend_length: fade_out = torch.linspace(1, 0, blend_length).unsqueeze(0) result[:, end_frame-blend_length:end_frame] = ( result[:, end_frame-blend_length:end_frame] * fade_out + audio[:, end_frame-blend_length:end_frame] * (1 - fade_out) ) # Save output_path = self._save_audio(result, audio.shape[-1] / sr, prefix="repainted") return str(output_path) except Exception as e: logger.error(f"Repainting failed: {e}") raise def edit_lyrics(self, audio_path: str, new_lyrics: str) -> str: """Edit lyrics while maintaining music.""" try: # This is a simplified version - full implementation would: # 1. Extract musical features (harmony, rhythm, melody) # 2. Generate new vocals with new lyrics # 3. Blend new vocals with original instrumental # For now, regenerate with new lyrics while using audio as reference audio, sr = torchaudio.load(audio_path) duration = audio.shape[-1] / sr # Extract style from original context = self._encode_audio_context(audio.numpy()) # Generate with new lyrics result = self.generate( prompt="Match the style of the reference", lyrics=new_lyrics, duration=int(duration), temperature=0.6 ) return result except Exception as e: logger.error(f"Lyric editing failed: {e}") raise def _prepare_input( self, prompt: str, lyrics: Optional[str], style: str, duration: int ) -> str: """Prepare input text for model.""" parts = [] if style and style != "auto": parts.append(f"[STYLE: {style}]") parts.append(f"[DURATION: {duration}s]") parts.append(prompt) if lyrics: parts.append(f"[LYRICS]\n{lyrics}") return " ".join(parts) def _encode_audio(self, audio: torch.Tensor) -> torch.Tensor: """Encode audio to latent space using DCAE.""" # Placeholder - would use actual DCAE encoder return audio def _decode_from_latent(self, latent: torch.Tensor) -> torch.Tensor: """Decode latent to audio using DCAE.""" # Placeholder - would use actual DCAE decoder return latent def _encode_audio_context(self, audio: np.ndarray) -> torch.Tensor: """Encode audio context for conditioning.""" # This would extract style/semantic features # Placeholder implementation audio_tensor = torch.from_numpy(audio).float().to(self.device) return audio_tensor def _decode_to_audio(self, outputs: torch.Tensor) -> torch.Tensor: """Decode model outputs to audio tensor.""" # Placeholder - actual implementation would use DCAE decoder # For now, generate dummy audio sample_rate = 44100 duration = outputs.shape[1] / 50 # Approximate samples = int(duration * sample_rate) # Generate placeholder audio (would be replaced with actual decoding) audio = torch.randn(2, samples) * 0.1 return audio def _save_audio( self, audio: torch.Tensor, duration: float, prefix: str = "generated" ) -> Path: """Save audio tensor to file.""" output_dir = Path(self.config.get("output_dir", "outputs")) output_dir.mkdir(exist_ok=True) # Generate filename from datetime import datetime timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"{prefix}_{timestamp}.wav" output_path = output_dir / filename # Save torchaudio.save( str(output_path), audio, sample_rate=44100, encoding="PCM_S", bits_per_sample=16 ) return output_path def _load_lora(self, lora_path: str): """Load LoRA weights into model.""" try: from peft import PeftModel self.model = PeftModel.from_pretrained(self.model, lora_path) logger.info(f"✅ Loaded LoRA from {lora_path}") except Exception as e: logger.warning(f"Failed to load LoRA: {e}") def _unload_lora(self): """Unload LoRA weights.""" try: if hasattr(self.model, "unload"): self.model.unload() except Exception as e: logger.warning(f"Failed to unload LoRA: {e}")