# ACE-Step-Custom / src/ace_step_engine.py.backup
# ACE-Step Custom — Deploy ACE-Step Custom Edition with bug fixes
# (commit a602628)
"""
ACE-Step Engine - Core generation module
Handles interaction with ACE-Step model for music generation
"""
import torch
import torchaudio
from pathlib import Path
import logging
from typing import Optional, Dict, Any
import numpy as np
logger = logging.getLogger(__name__)
class ACEStepEngine:
    """Core engine for ACE-Step music generation.

    Wraps model loading and the generation entry points: full tracks
    (``generate``), timeline clips (``generate_clip``), variations,
    section repainting, and lyric editing.  The audio codec hooks
    (``_encode_audio`` / ``_decode_from_latent`` / ``_decode_to_audio``)
    are explicit placeholders until the real DCAE codec is wired in.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize ACE-Step engine.

        Args:
            config: Configuration dictionary.  Recognized keys:
                "model_path"  - HuggingFace id / local dir of the model
                                (optional; loading is skipped without it),
                "output_dir"  - directory for generated audio (default "outputs"),
                "sample_rate" - output sample rate in Hz (default 44100).
        """
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"ACE-Step Engine initialized on {self.device}")

        self.model = None
        self.text_tokenizer = None  # Tokenizer for text encoder (Qwen3-Embedding-0.6B)
        self.text_encoder = None    # Text encoder model
        self.llm_tokenizer = None   # Tokenizer for LLM (5Hz Language Model)
        self.llm = None             # LLM model for planning
        self.vae = None             # VAE for audio encoding/decoding

        # BUGFIX: the original pasted model-loading code straight into
        # __init__ with an orphaned `except` (no matching `try`), an
        # undefined `model_path`, and it stored the tokenizer under
        # self.tokenizer while generate() reads self.text_tokenizer.
        # Loading now lives in _load_models() and only runs when the
        # config actually provides a model path.
        model_path = config.get("model_path")
        if model_path:
            self._load_models(model_path)
        else:
            logger.warning(
                "No 'model_path' in config; call _load_models() before generating"
            )

    def _load_models(self, model_path: str) -> None:
        """Load the tokenizer and main ACE-Step model from *model_path*.

        Raises:
            Exception: re-raised after logging if any component fails to load.
        """
        try:
            # Local import keeps transformers optional until loading is needed.
            from transformers import AutoTokenizer, AutoModel

            # Stored under the attribute name generate()/generate_clip() read.
            self.text_tokenizer = AutoTokenizer.from_pretrained(model_path)

            # fp16 on GPU halves memory; fp32 on CPU for compatibility.
            self.model = AutoModel.from_pretrained(
                model_path,
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                device_map="auto",
                low_cpu_mem_usage=True
            )
            self.model.eval()
            logger.info("✅ ACE-Step model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load models: {e}")
            raise

    def generate(
        self,
        prompt: str,
        lyrics: Optional[str] = None,
        duration: int = 30,
        temperature: float = 0.7,
        top_p: float = 0.9,
        seed: int = -1,
        style: str = "auto",
        lora_path: Optional[str] = None
    ) -> str:
        """
        Generate music using ACE-Step.

        Args:
            prompt: Text description of desired music
            lyrics: Optional lyrics
            duration: Duration in seconds
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            seed: Random seed (-1 for random)
            style: Music style
            lora_path: Path to LoRA model if using

        Returns:
            Path to generated audio file

        Raises:
            Exception: re-raised after logging if generation fails.
        """
        try:
            # Seed both torch and numpy for reproducibility; -1 means random.
            if seed >= 0:
                torch.manual_seed(seed)
                np.random.seed(seed)

            # Load LoRA if specified (unloaded again in `finally`).
            if lora_path:
                self._load_lora(lora_path)

            # Build the tagged prompt string.
            input_text = self._prepare_input(prompt, lyrics, style, duration)

            # Tokenize using the text-encoder tokenizer.
            inputs = self.text_tokenizer(
                input_text,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(self.device)

            logger.info(f"Generating {duration}s audio...")
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=duration * 50,  # Approximate tokens per second
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    num_return_sequences=1
                )

            # Decode tokens to an audio tensor and write it to disk.
            audio_tensor = self._decode_to_audio(outputs)
            output_path = self._save_audio(audio_tensor, duration)
            logger.info(f"✅ Generated audio: {output_path}")
            return str(output_path)
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise
        finally:
            # Unload LoRA even if generation raised, so the base model is
            # restored for subsequent calls.
            if lora_path:
                self._unload_lora()

    def generate_clip(
        self,
        prompt: str,
        lyrics: str,
        duration: int,
        context_audio: Optional[np.ndarray] = None,
        style: str = "auto",
        temperature: float = 0.7,
        seed: int = -1
    ) -> str:
        """
        Generate audio clip with context conditioning.

        Used for timeline-based generation.

        Args:
            prompt: Text prompt
            lyrics: Lyrics for this clip
            duration: Duration in seconds (typically 32)
            context_audio: Previous audio for style conditioning
            style: Music style
            temperature: Sampling temperature
            seed: Random seed

        Returns:
            Path to generated clip
        """
        try:
            if seed >= 0:
                torch.manual_seed(seed)

            input_text = self._prepare_input(prompt, lyrics, style, duration)

            # If previous audio was supplied, encode it for conditioning.
            context_embedding = None
            if context_audio is not None:
                context_embedding = self._encode_audio_context(context_audio)

            inputs = self.text_tokenizer(input_text, return_tensors="pt").to(self.device)

            # NOTE(review): `context_embedding` is forwarded as a generate()
            # kwarg — assumes the model's generate() accepts it; confirm
            # against the actual ACE-Step model API.
            with torch.no_grad():
                if context_embedding is not None:
                    outputs = self.model.generate(
                        **inputs,
                        context_embedding=context_embedding,
                        max_length=duration * 50,
                        temperature=temperature,
                        do_sample=True
                    )
                else:
                    outputs = self.model.generate(
                        **inputs,
                        max_length=duration * 50,
                        temperature=temperature,
                        do_sample=True
                    )

            audio_tensor = self._decode_to_audio(outputs)
            output_path = self._save_audio(audio_tensor, duration, prefix="clip")
            return str(output_path)
        except Exception as e:
            logger.error(f"Clip generation failed: {e}")
            raise

    def generate_variation(self, audio_path: str, strength: float = 0.5) -> str:
        """Generate a variation of existing audio.

        Encodes the source to latent space, perturbs it with Gaussian
        noise scaled by *strength*, and decodes back.

        Args:
            audio_path: Path to the source audio file.
            strength: Noise scale in [0, 1]; higher means more deviation.

        Returns:
            Path to the saved variation.
        """
        try:
            audio, sr = torchaudio.load(audio_path)

            # Encode to latent space (placeholder codec: identity).
            latent = self._encode_audio(audio)

            # Perturb the latent; strength controls deviation from source.
            noise = torch.randn_like(latent) * strength
            varied_latent = latent + noise

            varied_audio = self._decode_from_latent(varied_latent)
            output_path = self._save_audio(
                varied_audio, audio.shape[-1] / sr, prefix="variation"
            )
            return str(output_path)
        except Exception as e:
            logger.error(f"Variation generation failed: {e}")
            raise

    def repaint(
        self,
        audio_path: str,
        start_time: float,
        end_time: float,
        new_prompt: str
    ) -> str:
        """Repaint a specific section of audio.

        Regenerates [start_time, end_time) from *new_prompt*, splices it
        into the original, and crossfades 0.5s at each boundary.

        Args:
            audio_path: Path to the original audio file.
            start_time: Section start in seconds.
            end_time: Section end in seconds.
            new_prompt: Prompt describing the replacement section.

        Returns:
            Path to the repainted audio file.
        """
        try:
            audio, sr = torchaudio.load(audio_path)

            # Convert the time window to sample indices.
            start_frame = int(start_time * sr)
            end_frame = int(end_time * sr)

            # NOTE(review): the latent is computed but never used — kept for
            # the future latent-space repaint; remove if it stays unused.
            latent = self._encode_audio(audio)

            # Generate the replacement section from the new prompt.
            section_duration = end_time - start_time
            new_section = self.generate(
                prompt=new_prompt,
                duration=int(section_duration),
                temperature=0.8
            )
            new_audio, _ = torchaudio.load(new_section)

            # Splice the new section in.  BUGFIX: clamp the copied length so
            # a generated clip shorter than the window cannot raise a shape
            # mismatch on assignment.
            # assumes new_audio has the same channel count as audio — TODO confirm
            result = audio.clone()
            seg_len = min(end_frame - start_frame, new_audio.shape[-1])
            result[:, start_frame:start_frame + seg_len] = new_audio[:, :seg_len]

            # Crossfade 0.5s at each boundary to smooth the transitions.
            blend_length = int(0.5 * sr)
            if start_frame > blend_length:
                # Ramp the new material in (0 -> 1) against the original.
                fade_in = torch.linspace(0, 1, blend_length).unsqueeze(0)
                result[:, start_frame:start_frame + blend_length] = (
                    result[:, start_frame:start_frame + blend_length] * fade_in +
                    audio[:, start_frame:start_frame + blend_length] * (1 - fade_in)
                )
            if end_frame < audio.shape[-1] - blend_length:
                # Ramp the new material out (1 -> 0) back to the original.
                fade_out = torch.linspace(1, 0, blend_length).unsqueeze(0)
                result[:, end_frame - blend_length:end_frame] = (
                    result[:, end_frame - blend_length:end_frame] * fade_out +
                    audio[:, end_frame - blend_length:end_frame] * (1 - fade_out)
                )

            output_path = self._save_audio(
                result, audio.shape[-1] / sr, prefix="repainted"
            )
            return str(output_path)
        except Exception as e:
            logger.error(f"Repainting failed: {e}")
            raise

    def edit_lyrics(self, audio_path: str, new_lyrics: str) -> str:
        """Edit lyrics while maintaining the music.

        Simplified version — a full implementation would:
          1. Extract musical features (harmony, rhythm, melody)
          2. Generate new vocals with the new lyrics
          3. Blend the new vocals with the original instrumental
        For now it regenerates the track with the new lyrics.

        Args:
            audio_path: Path to the original audio.
            new_lyrics: Replacement lyrics.

        Returns:
            Path to the regenerated audio.
        """
        try:
            audio, sr = torchaudio.load(audio_path)
            duration = audio.shape[-1] / sr

            # NOTE(review): this context embedding is computed but never
            # passed to generate() — generate() has no context parameter.
            # Kept as a marker for the intended style-conditioning path.
            context = self._encode_audio_context(audio.numpy())

            result = self.generate(
                prompt="Match the style of the reference",
                lyrics=new_lyrics,
                duration=int(duration),
                temperature=0.6
            )
            return result
        except Exception as e:
            logger.error(f"Lyric editing failed: {e}")
            raise

    def _prepare_input(
        self,
        prompt: str,
        lyrics: Optional[str],
        style: str,
        duration: int
    ) -> str:
        """Build the tagged prompt string the model consumes.

        Format: ``[STYLE: s] [DURATION: Ns] <prompt> [LYRICS]\\n<lyrics>``
        where the STYLE tag is omitted for "auto"/empty and the LYRICS
        section is omitted when no lyrics are given.
        """
        parts = []
        if style and style != "auto":
            parts.append(f"[STYLE: {style}]")
        parts.append(f"[DURATION: {duration}s]")
        parts.append(prompt)
        if lyrics:
            parts.append(f"[LYRICS]\n{lyrics}")
        return " ".join(parts)

    def _encode_audio(self, audio: torch.Tensor) -> torch.Tensor:
        """Encode audio to latent space using DCAE.

        Placeholder — identity until the real DCAE encoder is integrated.
        """
        return audio

    def _decode_from_latent(self, latent: torch.Tensor) -> torch.Tensor:
        """Decode a latent back to audio using DCAE.

        Placeholder — identity until the real DCAE decoder is integrated.
        """
        return latent

    def _encode_audio_context(self, audio: np.ndarray) -> torch.Tensor:
        """Encode audio context for conditioning.

        Placeholder — a real implementation would extract style/semantic
        features; currently it just moves the raw audio to the device.
        """
        audio_tensor = torch.from_numpy(audio).float().to(self.device)
        return audio_tensor

    def _decode_to_audio(self, outputs: torch.Tensor) -> torch.Tensor:
        """Decode model outputs to a (2, samples) audio tensor.

        Placeholder — emits low-amplitude noise of the right length until
        the DCAE decoder is integrated.
        """
        sample_rate = int(self.config.get("sample_rate", 44100))
        duration = outputs.shape[1] / 50  # Approximate: ~50 tokens/second
        samples = int(duration * sample_rate)
        # Stereo placeholder audio (to be replaced with actual decoding).
        audio = torch.randn(2, samples) * 0.1
        return audio

    def _save_audio(
        self,
        audio: torch.Tensor,
        duration: float,
        prefix: str = "generated"
    ) -> Path:
        """Save an audio tensor as 16-bit PCM WAV.

        Args:
            audio: (channels, samples) waveform tensor.
            duration: Nominal duration in seconds (currently unused; kept
                for interface stability).
            prefix: Filename prefix identifying the generation type.

        Returns:
            Path of the written file ("<output_dir>/<prefix>_<timestamp>.wav").
        """
        output_dir = Path(self.config.get("output_dir", "outputs"))
        # parents=True so a nested configured path doesn't fail.
        output_dir.mkdir(parents=True, exist_ok=True)

        from datetime import datetime
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"{prefix}_{timestamp}.wav"
        output_path = output_dir / filename

        # Sample rate from config, matching _decode_to_audio (default 44100).
        torchaudio.save(
            str(output_path),
            audio,
            sample_rate=int(self.config.get("sample_rate", 44100)),
            encoding="PCM_S",
            bits_per_sample=16
        )
        return output_path

    def _load_lora(self, lora_path: str):
        """Load LoRA weights into the model (best-effort; logs on failure)."""
        try:
            from peft import PeftModel
            self.model = PeftModel.from_pretrained(self.model, lora_path)
            logger.info(f"✅ Loaded LoRA from {lora_path}")
        except Exception as e:
            logger.warning(f"Failed to load LoRA: {e}")

    def _unload_lora(self):
        """Unload LoRA weights, restoring the base model (best-effort)."""
        try:
            if hasattr(self.model, "unload"):
                self.model.unload()
        except Exception as e:
            logger.warning(f"Failed to unload LoRA: {e}")