Spaces:
Running
on
Zero
Running
on
Zero
File size: 12,884 Bytes
a602628 ee19acb a602628 4709141 052ca84 a602628 4709141 a602628 052ca84 a602628 ee19acb a602628 ee19acb a602628 ee19acb a602628 ee19acb fa7f63d ee19acb a602628 ee19acb a602628 ee19acb a602628 78910e3 a602628 78910e3 a602628 78910e3 a602628 6a590ee a602628 6a590ee a602628 78910e3 a602628 78910e3 a602628 052ca84 a602628 052ca84 a602628 052ca84 a602628 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 |
"""
ACE-Step Engine - Wrapper for ACE-Step 1.5 official architecture
Properly integrates AceStepHandler (DiT) and LLMHandler (5Hz LM)
"""
import torch
from pathlib import Path
import logging
from typing import Optional, Dict, Any, Tuple
import os
logger = logging.getLogger(__name__)
# Import ACE-Step 1.5 official handlers.
# Guarded so the module still imports (and can report a useful error) when the
# acestep package is absent; ACE_STEP_AVAILABLE gates all later model loading.
try:
    from acestep.handler import AceStepHandler
    from acestep.llm_inference import LLMHandler
    from acestep.inference import GenerationParams, GenerationConfig, generate_music
    from acestep.model_downloader import ensure_main_model, get_checkpoints_dir, check_main_model_exists
    ACE_STEP_AVAILABLE = True
except ImportError as e:
    logger.warning(f"ACE-Step 1.5 modules not available: {e}")
    ACE_STEP_AVAILABLE = False
class ACEStepEngine:
    """Wrapper engine for ACE-Step 1.5 with custom interface.

    Wraps the official ``AceStepHandler`` (DiT diffusion model) and
    ``LLMHandler`` (5Hz language model). Checkpoints are downloaded and
    models are loaded lazily on first use so the engine can be constructed
    outside a GPU context (required for ZeroGPU Spaces, where CUDA is only
    available inside a GPU-decorated call).
    """

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize ACE-Step engine.

        Args:
            config: Configuration dictionary. Keys read by this class:
                "checkpoint_dir", "dit_model_path", "lm_model_path",
                "output_dir" — all optional, with defaults.
        """
        self.config = config
        self._initialized = False
        self.dit_handler = None
        self.llm_handler = None
        # Resolved lazily in _ensure_models_loaded so CUDA detection happens
        # inside the GPU context on ZeroGPU; None until first use.
        self.device = None
        logger.info("ACE-Step Engine created (GPU will be detected on first use)")
        if not ACE_STEP_AVAILABLE:
            logger.error("ACE-Step 1.5 modules not available")
            logger.error("Please ensure acestep package is installed in your environment")
            return
        logger.info("✓ ACE-Step Engine created (models will load on first use)")

    def _download_checkpoints(self):
        """Download model checkpoints from HuggingFace if not present.

        Raises:
            RuntimeError: If the downloader reports failure.
        """
        checkpoints_dir = get_checkpoints_dir(self.config.get("checkpoint_dir"))
        # Skip the download entirely when the main model is already on disk.
        if check_main_model_exists(checkpoints_dir):
            logger.info(f"✓ ACE-Step 1.5 models already exist at {checkpoints_dir}")
            return
        logger.info("Downloading ACE-Step 1.5 models from HuggingFace...")
        logger.info("This may take several minutes (models are ~7GB total)...")
        try:
            # Use the built-in model downloader
            success, message = ensure_main_model(
                checkpoints_dir=checkpoints_dir,
                prefer_source="huggingface",  # Use HuggingFace for Spaces
            )
            if not success:
                raise RuntimeError(f"Failed to download models: {message}")
            logger.info(f"✓ {message}")
            logger.info("✓ All ACE-Step 1.5 models downloaded successfully")
        except Exception as e:
            logger.error(f"Failed to download checkpoints: {e}")
            raise

    def _load_models(self):
        """Initialize and load ACE-Step models (DiT required, LLM optional).

        Sets ``self._initialized`` on success. A DiT failure raises; an LLM
        failure only logs a warning and continues in DiT-only mode.

        Raises:
            RuntimeError: If ACE-Step is unavailable or DiT init fails.
        """
        try:
            if not ACE_STEP_AVAILABLE:
                raise RuntimeError("ACE-Step 1.5 not available")
            checkpoint_dir = self.config.get("checkpoint_dir", "./checkpoints")
            dit_model_path = self.config.get("dit_model_path", "acestep-v15-turbo")
            lm_model_path = self.config.get("lm_model_path", "acestep-5Hz-lm-1.7B")
            # Get checkpoints directory using helper function
            checkpoints_dir = get_checkpoints_dir(checkpoint_dir)
            # Project root = parent of this file's directory.
            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            logger.info(f"Initializing DiT handler with model: {dit_model_path}")
            # Initialize DiT handler (handles main diffusion model, VAE, text encoder)
            # Note: handler auto-detects checkpoints dir as project_root/checkpoints
            # config_path should be just the model name, not full path
            status_dit, success_dit = self.dit_handler.initialize_service(
                project_root=project_root,
                config_path=dit_model_path,  # Just model name, handler adds checkpoints/
                device="auto",
                use_flash_attention=False,
                compile_model=False,
                offload_to_cpu=False,
            )
            if not success_dit:
                raise RuntimeError(f"Failed to initialize DiT: {status_dit}")
            logger.info(f"✓ DiT initialized: {status_dit}")
            # Initialize LLM handler (handles 5Hz Language Model)
            logger.info(f"Initializing LLM handler with model: {lm_model_path}")
            status_llm, success_llm = self.llm_handler.initialize(
                checkpoint_dir=str(checkpoints_dir),
                lm_model_path=lm_model_path,
                backend="pt",  # Use PyTorch backend for compatibility
                device="auto",
                offload_to_cpu=False,
            )
            if not success_llm:
                # LLM is optional: fall back to DiT-only generation.
                logger.warning(f"LLM initialization failed: {status_llm}")
                logger.warning("Continuing without LLM (DiT-only mode)")
            else:
                logger.info(f" LLM initialized: {status_llm}")
            self._initialized = True
            logger.info(" ACE-Step engine fully initialized")
        except Exception as e:
            logger.error(f"Failed to initialize models: {e}")
            raise

    def _ensure_models_loaded(self):
        """Ensure models are loaded (lazy loading for ZeroGPU compatibility).

        Idempotent: returns immediately once initialization has succeeded.
        """
        if not self._initialized:
            logger.info("Lazy loading models on first use...")
            # Detect device now (within GPU context on ZeroGPU)
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            logger.info(f"Using device: {self.device}")
            # Create handlers if not already created
            if self.dit_handler is None:
                self.dit_handler = AceStepHandler()
            if self.llm_handler is None:
                self.llm_handler = LLMHandler()
            try:
                # Download and load models
                self._download_checkpoints()
                self._load_models()
                logger.info("✓ Models loaded successfully")
            except Exception as e:
                logger.error(f"Failed to load models: {e}")
                raise

    def generate(
        self,
        prompt: str,
        lyrics: Optional[str] = None,
        duration: int = 30,
        temperature: float = 0.7,
        top_p: float = 0.9,
        seed: int = -1,
        style: str = "auto",
        lora_path: Optional[str] = None
    ) -> str:
        """
        Generate music using ACE-Step.

        Args:
            prompt: Text description of desired music
            lyrics: Optional lyrics
            duration: Duration in seconds
            temperature: Sampling temperature (for LLM)
            top_p: Nucleus sampling parameter (for LLM)
            seed: Random seed (-1 for random)
            style: Music style (currently not forwarded to the backend)
            lora_path: Path to LoRA model (currently not forwarded)

        Returns:
            Path to generated audio file

        Raises:
            RuntimeError: If generation produces no audio.
        """
        # Ensure models are loaded (lazy loading for ZeroGPU)
        self._ensure_models_loaded()
        try:
            # Prepare generation parameters
            params = GenerationParams(
                task_type="text2music",
                caption=prompt,
                lyrics=lyrics or "",
                duration=duration,
                inference_steps=8,  # Turbo model default
                seed=seed if seed >= 0 else -1,
                thinking=True,  # Use LLM planning
                lm_temperature=temperature,
                lm_top_p=top_p,
            )
            # Prepare generation config
            config = GenerationConfig(
                batch_size=1,
                use_random_seed=(seed < 0),
                audio_format="wav",
            )
            # Generate using official inference
            output_dir = self.config.get("output_dir", "outputs")
            os.makedirs(output_dir, exist_ok=True)
            logger.info(f"Generating {duration}s audio: {prompt[:50]}...")
            result = generate_music(
                dit_handler=self.dit_handler,
                llm_handler=self.llm_handler,
                params=params,
                config=config,
                save_dir=output_dir,
            )
            if result.audio_paths:
                output_path = result.audio_paths[0]
                logger.info(f" Generated: {output_path}")
                return output_path
            else:
                raise RuntimeError("No audio generated")
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise

    def generate_clip(
        self,
        prompt: str,
        lyrics: str,
        duration: int,
        context_audio: Optional[str] = None,
        style: str = "auto",
        temperature: float = 0.7,
        seed: int = -1
    ) -> str:
        """
        Generate audio clip for timeline (with context conditioning).

        Args:
            prompt: Text prompt
            lyrics: Lyrics for this clip
            duration: Duration in seconds (typically 32)
            context_audio: Path to previous audio for style conditioning
                (currently unused — see note below)
            style: Music style
            temperature: Sampling temperature
            seed: Random seed

        Returns:
            Path to generated clip
        """
        # For timeline clips, use regular generation with extended context
        # Context conditioning would require custom implementation
        return self.generate(
            prompt=prompt,
            lyrics=lyrics,
            duration=duration,
            temperature=temperature,
            seed=seed,
            style=style
        )

    def generate_variation(self, audio_path: str, strength: float = 0.5) -> str:
        """Generate variation of existing audio.

        Args:
            audio_path: Source audio file to vary.
            strength: Cover strength passed to the backend (0..1 presumed
                — TODO confirm against GenerationParams docs).

        Returns:
            Path to the variation, or ``audio_path`` if nothing was produced.
        """
        # Ensure models are loaded (lazy loading for ZeroGPU)
        self._ensure_models_loaded()
        try:
            params = GenerationParams(
                task_type="audio_variation",
                audio_path=audio_path,
                audio_cover_strength=strength,
                inference_steps=8,
            )
            config = GenerationConfig(
                batch_size=1,
                audio_format="wav",
            )
            output_dir = self.config.get("output_dir", "outputs")
            result = generate_music(
                self.dit_handler,
                self.llm_handler,
                params,
                config,
                save_dir=output_dir,
            )
            return result.audio_paths[0] if result.audio_paths else audio_path
        except Exception as e:
            logger.error(f"Variation generation failed: {e}")
            raise

    def repaint(
        self,
        audio_path: str,
        start_time: float,
        end_time: float,
        new_prompt: str
    ) -> str:
        """Repaint specific section of audio.

        Args:
            audio_path: Source audio file.
            start_time: Section start in seconds.
            end_time: Section end in seconds.
            new_prompt: Caption describing the repainted section.

        Returns:
            Path to the repainted audio, or ``audio_path`` on empty result.
        """
        # Fix: lazy-load like every other entry point instead of raising
        # "Engine not initialized" when repaint() happens to be called first.
        self._ensure_models_loaded()
        try:
            params = GenerationParams(
                task_type="repainting",
                audio_path=audio_path,
                caption=new_prompt,
                repainting_start=start_time,
                repainting_end=end_time,
                inference_steps=8,
            )
            config = GenerationConfig(
                batch_size=1,
                audio_format="wav",
            )
            output_dir = self.config.get("output_dir", "outputs")
            result = generate_music(
                self.dit_handler,
                self.llm_handler,
                params,
                config,
                save_dir=output_dir,
            )
            return result.audio_paths[0] if result.audio_paths else audio_path
        except Exception as e:
            logger.error(f"Repainting failed: {e}")
            raise

    def edit_lyrics(self, audio_path: str, new_lyrics: str) -> str:
        """Edit lyrics while maintaining music.

        NOTE: true lyric editing is not implemented — ``audio_path`` is
        currently ignored and a fresh 30s clip is generated instead.
        """
        # This would require custom implementation
        # For now, regenerate with same style
        logger.warning("Lyric editing not fully implemented - regenerating with new lyrics")
        return self.generate(
            prompt="Match the style of the reference",
            lyrics=new_lyrics,
            duration=30,
        )

    def is_initialized(self) -> bool:
        """Check if engine is initialized."""
        return self._initialized
|