Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| ACE-Step Engine - Core generation module | |
| Handles interaction with ACE-Step model for music generation | |
| """ | |
| import torch | |
| import torchaudio | |
| from pathlib import Path | |
| import logging | |
| from typing import Optional, Dict, Any | |
| import numpy as np | |
| logger = logging.getLogger(__name__) | |
| class ACEStepEngine: | |
| """Core engine for ACE-Step music generation.""" | |
def __init__(self, config: Dict[str, Any]):
    """
    Create an engine and pick the compute device.

    Stores the configuration, selects CUDA when it is available (CPU
    otherwise), and sets every model/tokenizer slot to None.

    Args:
        config: Configuration dictionary
    """
    self.config = config

    if torch.cuda.is_available():
        device_name = "cuda"
    else:
        device_name = "cpu"
    self.device = torch.device(device_name)
    logger.info(f"ACE-Step Engine initialized on {self.device}")

    # Model/tokenizer slots all start out empty.
    self.model = None
    self.text_tokenizer = None  # Tokenizer for text encoder (Qwen3-Embedding-0.6B)
    self.text_encoder = None  # Text encoder model
    self.llm_tokenizer = None  # Tokenizer for LLM (5Hz Language Model)
    self.llm = None  # LLM model for planning
    self.vae = None  # VAE for audio encoding/decoding
| # Load tokenizer | |
| self.tokenizer = AutoTokenizer.from_pretrained(model_path) | |
| # Load main model | |
| self.model = AutoModel.from_pretrained( | |
| model_path, | |
| torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32, | |
| device_map="auto", | |
| low_cpu_mem_usage=True | |
| ) | |
| self.model.eval() | |
| logger.info("✅ ACE-Step model loaded successfully") | |
| except Exception as e: | |
| logger.error(f"Failed to load models: {e}") | |
| raise | |
def generate(
    self,
    prompt: str,
    lyrics: Optional[str] = None,
    duration: int = 30,
    temperature: float = 0.7,
    top_p: float = 0.9,
    seed: int = -1,
    style: str = "auto",
    lora_path: Optional[str] = None
) -> str:
    """
    Produce music from a text description and write it to disk.

    Args:
        prompt: Text description of desired music
        lyrics: Optional lyrics
        duration: Duration in seconds
        temperature: Sampling temperature
        top_p: Nucleus sampling parameter
        seed: Random seed (a negative value leaves the RNGs unseeded)
        style: Music style
        lora_path: Path to LoRA model if using

    Returns:
        Path to generated audio file, as a string

    Raises:
        Exception: any failure is logged and then re-raised unchanged
    """
    try:
        # Seed both torch and NumPy when a non-negative seed is given
        if seed >= 0:
            torch.manual_seed(seed)
            np.random.seed(seed)

        # Attach the LoRA adapter for this call only (removed in `finally`)
        if lora_path:
            self._load_lora(lora_path)

        # Build the tagged prompt and tokenize it with the text-encoder tokenizer
        input_text = self._prepare_input(prompt, lyrics, style, duration)
        encoded = self.text_tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True
        )
        inputs = encoded.to(self.device)

        logger.info(f"Generating {duration}s audio...")
        sampling_options = {
            "max_length": duration * 50,  # Approximate tokens per second
            "temperature": temperature,
            "top_p": top_p,
            "do_sample": True,
            "num_return_sequences": 1,
        }
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **sampling_options)

        # Turn model output into a waveform and persist it
        audio_tensor = self._decode_to_audio(outputs)
        output_path = self._save_audio(audio_tensor, duration)
        logger.info(f"✅ Generated audio: {output_path}")
        return str(output_path)
    except Exception as e:
        logger.error(f"Generation failed: {e}")
        raise
    finally:
        # Unload LoRA if one was requested, whether or not generation succeeded
        if lora_path:
            self._unload_lora()
| def generate_clip( | |
| self, | |
| prompt: str, | |
| lyrics: str, | |
| duration: int, | |
| context_audio: Optional[np.ndarray] = None, | |
| style: str = "auto", | |
| temperature: float = 0.7, | |
| seed: int = -1 | |
| ) -> str: | |
| """ | |
| Generate audio clip with context conditioning. | |
| Used for timeline-based generation. | |
| Args: | |
| prompt: Text prompt | |
| lyrics: Lyrics for this clip | |
| duration: Duration in seconds (typically 32) | |
| context_audio: Previous audio for style conditioning | |
| style: Music style | |
| temperature: Sampling temperature | |
| seed: Random seed | |
| Returns: | |
| Path to generated clip | |
| """ | |
| try: | |
| if seed >= 0: | |
| torch.manual_seed(seed) | |
| # Prepare input with context | |
| input_text = self._prepare_input(prompt, lyrics, style, duration) | |
| # If context provided, use it for conditioning | |
| context_embedding = None | |
| if context_audio is not None: | |
| context_embedding = self._encode_audio_context(context_audio) | |
| inputs = self.text_tokenizer(input_text, return_tensors="pt").to(self.device) | |
| # Generate with context conditioning | |
| with torch.no_grad(): | |
| if context_embedding is not None: | |
| outputs = self.model.generate( | |
| **inputs, | |
| context_embedding=context_embedding, | |
| max_length=duration * 50, | |
| temperature=temperature, | |
| do_sample=True | |
| ) | |
| else: | |
| outputs = self.model.generate( | |
| **inputs, | |
| max_length=duration * 50, | |
| temperature=temperature, | |
| do_sample=True | |
| ) | |
| audio_tensor = self._decode_to_audio(outputs) | |
| output_path = self._save_audio(audio_tensor, duration, prefix="clip") | |
| return str(output_path) | |
| except Exception as e: | |
| logger.error(f"Clip generation failed: {e}") | |
| raise | |
def generate_variation(self, audio_path: str, strength: float = 0.5) -> str:
    """
    Generate variation of existing audio.

    The audio is encoded, perturbed with Gaussian noise scaled by
    ``strength``, decoded again and saved with the "variation" prefix.

    Args:
        audio_path: Path of the audio file to vary
        strength: Multiplier applied to the unit-variance noise

    Returns:
        Path to the saved variation, as a string

    Raises:
        Exception: any failure is logged and then re-raised unchanged
    """
    # _save_audio always writes a 44100 Hz header, so the waveform handed
    # to it must actually be at that rate.
    output_sample_rate = 44100
    try:
        audio, sr = torchaudio.load(audio_path)
        duration_seconds = audio.shape[-1] / sr

        # Encode to latent space and add noise for variation
        latent = self._encode_audio(audio)
        varied_latent = latent + torch.randn_like(latent) * strength

        # Decode back to audio
        varied_audio = self._decode_from_latent(varied_latent)

        # Without this, input recorded at another rate would be written
        # with the wrong header and play back at the wrong speed/pitch.
        if sr != output_sample_rate:
            varied_audio = torchaudio.functional.resample(
                varied_audio, sr, output_sample_rate
            )

        output_path = self._save_audio(varied_audio, duration_seconds, prefix="variation")
        return str(output_path)
    except Exception as e:
        logger.error(f"Variation generation failed: {e}")
        raise
def repaint(
    self,
    audio_path: str,
    start_time: float,
    end_time: float,
    new_prompt: str
) -> str:
    """
    Repaint specific section of audio.

    Generates new audio for ``new_prompt``, substitutes it for the
    ``[start_time, end_time)`` span of the original and cross-fades the
    edges of the span back into the original material.

    Args:
        audio_path: Path of the audio file to modify
        start_time: Start of the section, in seconds
        end_time: End of the section, in seconds
        new_prompt: Text prompt used to generate the replacement

    Returns:
        Path to the saved result, as a string

    Raises:
        ValueError: if the section is empty or lies outside the audio
        Exception: any other failure is logged and then re-raised unchanged
    """
    import math  # local import: only needed here

    # _save_audio always writes a 44100 Hz header
    output_sample_rate = 44100
    try:
        # Load original audio
        audio, sr = torchaudio.load(audio_path)
        num_channels, total_frames = audio.shape[0], audio.shape[-1]

        if end_time <= start_time:
            raise ValueError(
                f"end_time ({end_time}) must be greater than start_time ({start_time})"
            )

        # Clamp the section to the audio so the slices below stay in range
        start_frame = max(0, int(start_time * sr))
        end_frame = min(total_frames, int(end_time * sr))
        section_frames = end_frame - start_frame
        if section_frames <= 0:
            raise ValueError(
                f"Section {start_time}s-{end_time}s lies outside the audio"
            )

        # Round the duration up (and never to 0) so the generated audio is
        # not shorter than the span it has to fill.
        new_section = self.generate(
            prompt=new_prompt,
            duration=max(1, math.ceil(section_frames / sr)),
            temperature=0.8
        )

        # Load new section and bring it to the original's rate/layout
        new_audio, new_sr = torchaudio.load(new_section)
        if new_sr != sr:
            new_audio = torchaudio.functional.resample(new_audio, new_sr, sr)
        if new_audio.shape[0] != num_channels:
            # Down-mix, then repeat across the original's channels
            new_audio = new_audio.mean(dim=0, keepdim=True).expand(num_channels, -1)
        new_audio = new_audio[:, :section_frames].to(audio.dtype)
        shortfall = section_frames - new_audio.shape[-1]
        if shortfall > 0:
            # Zero-pad rather than fail on a shape mismatch
            new_audio = torch.nn.functional.pad(new_audio, (0, shortfall))

        # Substitute the section
        result = audio.clone()
        result[:, start_frame:end_frame] = new_audio

        # Smooth transitions: up to 0.5s per edge, but never more than half
        # the section so the two fades cannot overlap or leave the span.
        blend_length = min(int(0.5 * sr), section_frames // 2)
        if blend_length > 0:
            ramp = torch.linspace(
                0, 1, blend_length, dtype=audio.dtype, device=audio.device
            ).unsqueeze(0)
            if start_frame > 0:
                # original -> new at the start of the section
                head = slice(start_frame, start_frame + blend_length)
                result[:, head] = result[:, head] * ramp + audio[:, head] * (1 - ramp)
            if end_frame < total_frames:
                # new -> original at the end of the section
                tail = slice(end_frame - blend_length, end_frame)
                fade_out = ramp.flip(-1)
                result[:, tail] = (
                    result[:, tail] * fade_out + audio[:, tail] * (1 - fade_out)
                )

        # Match the rate _save_audio writes so playback speed is preserved
        if sr != output_sample_rate:
            result = torchaudio.functional.resample(result, sr, output_sample_rate)

        output_path = self._save_audio(result, total_frames / sr, prefix="repainted")
        return str(output_path)
    except Exception as e:
        logger.error(f"Repainting failed: {e}")
        raise
def edit_lyrics(self, audio_path: str, new_lyrics: str) -> str:
    """
    Edit lyrics while maintaining music.

    Simplified implementation: the track is regenerated with the new
    lyrics at the same (rounded) duration as the original file. A full
    implementation would:
      1. Extract musical features (harmony, rhythm, melody)
      2. Generate new vocals with new lyrics
      3. Blend new vocals with original instrumental

    Args:
        audio_path: Path of the original audio file
        new_lyrics: Lyrics to use for the regenerated track

    Returns:
        Path to the regenerated audio, as returned by generate()

    Raises:
        Exception: any failure is logged and then re-raised unchanged
    """
    try:
        audio, sr = torchaudio.load(audio_path)
        duration = audio.shape[-1] / sr

        # TODO: condition on the original audio. generate() takes no context
        # argument, so the reference was previously encoded (copying the
        # whole waveform to the device) and then discarded; that dead work
        # is dropped until conditioning is actually wired in.

        # Round instead of truncating, and never request a 0-second track
        target_duration = max(1, int(round(duration)))

        # Generate with new lyrics
        return self.generate(
            prompt="Match the style of the reference",
            lyrics=new_lyrics,
            duration=target_duration,
            temperature=0.6
        )
    except Exception as e:
        logger.error(f"Lyric editing failed: {e}")
        raise
| def _prepare_input( | |
| self, | |
| prompt: str, | |
| lyrics: Optional[str], | |
| style: str, | |
| duration: int | |
| ) -> str: | |
| """Prepare input text for model.""" | |
| parts = [] | |
| if style and style != "auto": | |
| parts.append(f"[STYLE: {style}]") | |
| parts.append(f"[DURATION: {duration}s]") | |
| parts.append(prompt) | |
| if lyrics: | |
| parts.append(f"[LYRICS]\n{lyrics}") | |
| return " ".join(parts) | |
| def _encode_audio(self, audio: torch.Tensor) -> torch.Tensor: | |
| """Encode audio to latent space using DCAE.""" | |
| # Placeholder - would use actual DCAE encoder | |
| return audio | |
| def _decode_from_latent(self, latent: torch.Tensor) -> torch.Tensor: | |
| """Decode latent to audio using DCAE.""" | |
| # Placeholder - would use actual DCAE decoder | |
| return latent | |
| def _encode_audio_context(self, audio: np.ndarray) -> torch.Tensor: | |
| """Encode audio context for conditioning.""" | |
| # This would extract style/semantic features | |
| # Placeholder implementation | |
| audio_tensor = torch.from_numpy(audio).float().to(self.device) | |
| return audio_tensor | |
| def _decode_to_audio(self, outputs: torch.Tensor) -> torch.Tensor: | |
| """Decode model outputs to audio tensor.""" | |
| # Placeholder - actual implementation would use DCAE decoder | |
| # For now, generate dummy audio | |
| sample_rate = 44100 | |
| duration = outputs.shape[1] / 50 # Approximate | |
| samples = int(duration * sample_rate) | |
| # Generate placeholder audio (would be replaced with actual decoding) | |
| audio = torch.randn(2, samples) * 0.1 | |
| return audio | |
def _save_audio(
    self,
    audio: torch.Tensor,
    duration: float,
    prefix: str = "generated"
) -> Path:
    """
    Save audio tensor to a 16-bit PCM WAV file at 44100 Hz.

    The file is written to ``config["output_dir"]`` (default "outputs")
    as ``<prefix>_<YYYYmmdd_HHMMSS>.wav``; a numeric suffix is appended if
    that name is already taken.

    Args:
        audio: Waveform tensor to write
        duration: Duration in seconds (currently unused; kept for callers)
        prefix: Filename prefix

    Returns:
        Path of the written file
    """
    from datetime import datetime

    output_dir = Path(self.config.get("output_dir", "outputs"))
    # parents=True so a nested output_dir works when its parents are missing
    output_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = output_dir / f"{prefix}_{timestamp}.wav"
    # The timestamp has one-second resolution; avoid overwriting a file
    # saved with the same prefix within the same second.
    suffix = 1
    while output_path.exists():
        output_path = output_dir / f"{prefix}_{timestamp}_{suffix}.wav"
        suffix += 1

    # Hand the writer a detached CPU tensor regardless of where the
    # waveform was produced.
    torchaudio.save(
        str(output_path),
        audio.detach().cpu(),
        sample_rate=44100,
        encoding="PCM_S",
        bits_per_sample=16
    )
    return output_path
def _load_lora(self, lora_path: str):
    """
    Wrap the current model with LoRA weights from ``lora_path``.

    Best effort: if peft cannot be imported or loading fails, a warning
    is logged and the model is left as it was.
    """
    try:
        from peft import PeftModel

        with_adapter = PeftModel.from_pretrained(self.model, lora_path)
        self.model = with_adapter
        logger.info(f"✅ Loaded LoRA from {lora_path}")
    except Exception as e:
        logger.warning(f"Failed to load LoRA: {e}")
| def _unload_lora(self): | |
| """Unload LoRA weights.""" | |
| try: | |
| if hasattr(self.model, "unload"): | |
| self.model.unload() | |
| except Exception as e: | |
| logger.warning(f"Failed to unload LoRA: {e}") | |