Spaces:
Sleeping
Sleeping
ACE-Step Custom committed on
Commit ·
4709141
1
Parent(s): a602628
Fix engine initialization and audio I/O components
Browse files- app.py +70 -25
- src/ace_step_engine.py +25 -12
- src/utils.py +7 -1
app.py
CHANGED
|
@@ -24,11 +24,39 @@ from src.utils import setup_logging, load_config
|
|
| 24 |
logger = setup_logging()
|
| 25 |
config = load_config()
|
| 26 |
|
| 27 |
-
#
|
| 28 |
-
ace_engine =
|
| 29 |
-
timeline_manager =
|
| 30 |
-
lora_trainer =
|
| 31 |
-
audio_processor =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
|
| 34 |
# ==================== TAB 1: STANDARD ACE-STEP GUI ====================
|
|
@@ -49,8 +77,11 @@ def standard_generate(
|
|
| 49 |
try:
|
| 50 |
logger.info(f"Standard generation: {prompt[:50]}...")
|
| 51 |
|
|
|
|
|
|
|
|
|
|
| 52 |
# Generate audio
|
| 53 |
-
audio_path =
|
| 54 |
prompt=prompt,
|
| 55 |
lyrics=lyrics,
|
| 56 |
duration=duration,
|
|
@@ -73,7 +104,7 @@ def standard_generate(
|
|
| 73 |
def standard_variation(audio_path: str, variation_strength: float) -> Tuple[str, str]:
|
| 74 |
"""Generate variation of existing audio."""
|
| 75 |
try:
|
| 76 |
-
result =
|
| 77 |
return result, "✅ Variation generated"
|
| 78 |
except Exception as e:
|
| 79 |
return None, f"❌ Error: {str(e)}"
|
|
@@ -88,7 +119,7 @@ def standard_repaint(
|
|
| 88 |
) -> Tuple[str, str]:
|
| 89 |
"""Repaint specific section of audio."""
|
| 90 |
try:
|
| 91 |
-
result =
|
| 92 |
return result, f"✅ Repainted {start_time}s-{end_time}s"
|
| 93 |
except Exception as e:
|
| 94 |
return None, f"❌ Error: {str(e)}"
|
|
@@ -101,7 +132,7 @@ def standard_lyric_edit(
|
|
| 101 |
) -> Tuple[str, str]:
|
| 102 |
"""Edit lyrics while maintaining music."""
|
| 103 |
try:
|
| 104 |
-
result =
|
| 105 |
return result, "✅ Lyrics edited"
|
| 106 |
except Exception as e:
|
| 107 |
return None, f"❌ Error: {str(e)}"
|
|
@@ -130,14 +161,19 @@ def timeline_generate(
|
|
| 130 |
|
| 131 |
logger.info(f"Timeline generation with {context_length}s context")
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
# Get context from timeline
|
| 134 |
-
context_audio =
|
| 135 |
session_state.get("timeline_id"),
|
| 136 |
context_length
|
| 137 |
)
|
| 138 |
|
| 139 |
# Generate 32s clip
|
| 140 |
-
clip =
|
| 141 |
prompt=prompt,
|
| 142 |
lyrics=lyrics,
|
| 143 |
duration=32,
|
|
@@ -148,15 +184,15 @@ def timeline_generate(
|
|
| 148 |
)
|
| 149 |
|
| 150 |
# Blend with timeline (2s lead-in and lead-out)
|
| 151 |
-
blended_clip =
|
| 152 |
clip,
|
| 153 |
-
|
| 154 |
lead_in=2.0,
|
| 155 |
lead_out=2.0
|
| 156 |
)
|
| 157 |
|
| 158 |
# Add to timeline
|
| 159 |
-
timeline_id =
|
| 160 |
session_state.get("timeline_id"),
|
| 161 |
blended_clip,
|
| 162 |
metadata={
|
|
@@ -171,12 +207,12 @@ def timeline_generate(
|
|
| 171 |
session_state["total_clips"] = session_state.get("total_clips", 0) + 1
|
| 172 |
|
| 173 |
# Get full timeline audio
|
| 174 |
-
full_audio =
|
| 175 |
|
| 176 |
# Get timeline visualization
|
| 177 |
-
timeline_viz =
|
| 178 |
|
| 179 |
-
info = f"✅ Clip {session_state['total_clips']} added • Total: {
|
| 180 |
|
| 181 |
return blended_clip, full_audio, timeline_viz, session_state, info
|
| 182 |
|
|
@@ -210,16 +246,17 @@ def timeline_inpaint(
|
|
| 210 |
if session_state is None:
|
| 211 |
session_state = {"timeline_id": None, "total_clips": 0}
|
| 212 |
|
|
|
|
| 213 |
timeline_id = session_state.get("timeline_id")
|
| 214 |
-
result =
|
| 215 |
timeline_id,
|
| 216 |
start_time,
|
| 217 |
end_time,
|
| 218 |
new_prompt
|
| 219 |
)
|
| 220 |
|
| 221 |
-
full_audio =
|
| 222 |
-
timeline_viz =
|
| 223 |
|
| 224 |
info = f"✅ Inpainted {start_time:.1f}s-{end_time:.1f}s"
|
| 225 |
return full_audio, timeline_viz, session_state, info
|
|
@@ -234,7 +271,7 @@ def timeline_reset(session_state: dict) -> Tuple[None, None, str, dict]:
|
|
| 234 |
if session_state is None:
|
| 235 |
session_state = {"timeline_id": None, "total_clips": 0}
|
| 236 |
elif session_state.get("timeline_id"):
|
| 237 |
-
|
| 238 |
|
| 239 |
session_state = {"timeline_id": None, "total_clips": 0}
|
| 240 |
return None, None, "Timeline cleared", session_state
|
|
@@ -245,7 +282,7 @@ def timeline_reset(session_state: dict) -> Tuple[None, None, str, dict]:
|
|
| 245 |
def lora_upload_files(files: List[str]) -> str:
|
| 246 |
"""Upload and prepare audio files for LoRA training."""
|
| 247 |
try:
|
| 248 |
-
prepared_files =
|
| 249 |
return f"✅ Prepared {len(prepared_files)} files for training"
|
| 250 |
except Exception as e:
|
| 251 |
return f"❌ Error: {str(e)}"
|
|
@@ -365,7 +402,15 @@ def create_ui():
|
|
| 365 |
std_generate_btn = gr.Button("🎵 Generate", variant="primary", size="lg")
|
| 366 |
|
| 367 |
with gr.Column():
|
| 368 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
std_info = gr.Textbox(label="Status", lines=2)
|
| 370 |
|
| 371 |
gr.Markdown("### Advanced Controls")
|
|
@@ -460,8 +505,8 @@ def create_ui():
|
|
| 460 |
tl_info = gr.Textbox(label="Status", lines=2)
|
| 461 |
|
| 462 |
with gr.Column():
|
| 463 |
-
tl_clip_audio = gr.Audio(label="Latest Clip"
|
| 464 |
-
tl_full_audio = gr.Audio(label="Full Timeline"
|
| 465 |
tl_timeline_viz = gr.Image(label="Timeline Visualization")
|
| 466 |
|
| 467 |
with gr.Accordion("🎨 Inpaint Timeline Region", open=False):
|
|
|
|
| 24 |
logger = setup_logging()
|
| 25 |
config = load_config()
|
| 26 |
|
| 27 |
+
# Lazy initialize components (will be initialized on first use)
|
| 28 |
+
ace_engine = None
|
| 29 |
+
timeline_manager = None
|
| 30 |
+
lora_trainer = None
|
| 31 |
+
audio_processor = None
|
| 32 |
+
|
| 33 |
+
def get_ace_engine():
|
| 34 |
+
"""Lazy-load ACE-Step engine."""
|
| 35 |
+
global ace_engine
|
| 36 |
+
if ace_engine is None:
|
| 37 |
+
ace_engine = ACEStepEngine(config)
|
| 38 |
+
return ace_engine
|
| 39 |
+
|
| 40 |
+
def get_timeline_manager():
|
| 41 |
+
"""Lazy-load timeline manager."""
|
| 42 |
+
global timeline_manager
|
| 43 |
+
if timeline_manager is None:
|
| 44 |
+
timeline_manager = TimelineManager(config)
|
| 45 |
+
return timeline_manager
|
| 46 |
+
|
| 47 |
+
def get_lora_trainer():
|
| 48 |
+
"""Lazy-load LoRA trainer."""
|
| 49 |
+
global lora_trainer
|
| 50 |
+
if lora_trainer is None:
|
| 51 |
+
lora_trainer = LoRATrainer(config)
|
| 52 |
+
return lora_trainer
|
| 53 |
+
|
| 54 |
+
def get_audio_processor():
|
| 55 |
+
"""Lazy-load audio processor."""
|
| 56 |
+
global audio_processor
|
| 57 |
+
if audio_processor is None:
|
| 58 |
+
audio_processor = AudioProcessor(config)
|
| 59 |
+
return audio_processor
|
| 60 |
|
| 61 |
|
| 62 |
# ==================== TAB 1: STANDARD ACE-STEP GUI ====================
|
|
|
|
| 77 |
try:
|
| 78 |
logger.info(f"Standard generation: {prompt[:50]}...")
|
| 79 |
|
| 80 |
+
# Get engine instance
|
| 81 |
+
engine = get_ace_engine()
|
| 82 |
+
|
| 83 |
# Generate audio
|
| 84 |
+
audio_path = engine.generate(
|
| 85 |
prompt=prompt,
|
| 86 |
lyrics=lyrics,
|
| 87 |
duration=duration,
|
|
|
|
| 104 |
def standard_variation(audio_path: str, variation_strength: float) -> Tuple[str, str]:
|
| 105 |
"""Generate variation of existing audio."""
|
| 106 |
try:
|
| 107 |
+
result = get_ace_engine().generate_variation(audio_path, variation_strength)
|
| 108 |
return result, "✅ Variation generated"
|
| 109 |
except Exception as e:
|
| 110 |
return None, f"❌ Error: {str(e)}"
|
|
|
|
| 119 |
) -> Tuple[str, str]:
|
| 120 |
"""Repaint specific section of audio."""
|
| 121 |
try:
|
| 122 |
+
result = get_ace_engine().repaint(audio_path, start_time, end_time, new_prompt)
|
| 123 |
return result, f"✅ Repainted {start_time}s-{end_time}s"
|
| 124 |
except Exception as e:
|
| 125 |
return None, f"❌ Error: {str(e)}"
|
|
|
|
| 132 |
) -> Tuple[str, str]:
|
| 133 |
"""Edit lyrics while maintaining music."""
|
| 134 |
try:
|
| 135 |
+
result = get_ace_engine().edit_lyrics(audio_path, new_lyrics)
|
| 136 |
return result, "✅ Lyrics edited"
|
| 137 |
except Exception as e:
|
| 138 |
return None, f"❌ Error: {str(e)}"
|
|
|
|
| 161 |
|
| 162 |
logger.info(f"Timeline generation with {context_length}s context")
|
| 163 |
|
| 164 |
+
# Get managers
|
| 165 |
+
tm = get_timeline_manager()
|
| 166 |
+
engine = get_ace_engine()
|
| 167 |
+
ap = get_audio_processor()
|
| 168 |
+
|
| 169 |
# Get context from timeline
|
| 170 |
+
context_audio = tm.get_context(
|
| 171 |
session_state.get("timeline_id"),
|
| 172 |
context_length
|
| 173 |
)
|
| 174 |
|
| 175 |
# Generate 32s clip
|
| 176 |
+
clip = engine.generate_clip(
|
| 177 |
prompt=prompt,
|
| 178 |
lyrics=lyrics,
|
| 179 |
duration=32,
|
|
|
|
| 184 |
)
|
| 185 |
|
| 186 |
# Blend with timeline (2s lead-in and lead-out)
|
| 187 |
+
blended_clip = ap.blend_clip(
|
| 188 |
clip,
|
| 189 |
+
tm.get_last_clip(session_state.get("timeline_id")),
|
| 190 |
lead_in=2.0,
|
| 191 |
lead_out=2.0
|
| 192 |
)
|
| 193 |
|
| 194 |
# Add to timeline
|
| 195 |
+
timeline_id = tm.add_clip(
|
| 196 |
session_state.get("timeline_id"),
|
| 197 |
blended_clip,
|
| 198 |
metadata={
|
|
|
|
| 207 |
session_state["total_clips"] = session_state.get("total_clips", 0) + 1
|
| 208 |
|
| 209 |
# Get full timeline audio
|
| 210 |
+
full_audio = tm.export_timeline(timeline_id)
|
| 211 |
|
| 212 |
# Get timeline visualization
|
| 213 |
+
timeline_viz = tm.visualize_timeline(timeline_id)
|
| 214 |
|
| 215 |
+
info = f"✅ Clip {session_state['total_clips']} added • Total: {tm.get_duration(timeline_id):.1f}s"
|
| 216 |
|
| 217 |
return blended_clip, full_audio, timeline_viz, session_state, info
|
| 218 |
|
|
|
|
| 246 |
if session_state is None:
|
| 247 |
session_state = {"timeline_id": None, "total_clips": 0}
|
| 248 |
|
| 249 |
+
tm = get_timeline_manager()
|
| 250 |
timeline_id = session_state.get("timeline_id")
|
| 251 |
+
result = tm.inpaint_region(
|
| 252 |
timeline_id,
|
| 253 |
start_time,
|
| 254 |
end_time,
|
| 255 |
new_prompt
|
| 256 |
)
|
| 257 |
|
| 258 |
+
full_audio = tm.export_timeline(timeline_id)
|
| 259 |
+
timeline_viz = tm.visualize_timeline(timeline_id)
|
| 260 |
|
| 261 |
info = f"✅ Inpainted {start_time:.1f}s-{end_time:.1f}s"
|
| 262 |
return full_audio, timeline_viz, session_state, info
|
|
|
|
| 271 |
if session_state is None:
|
| 272 |
session_state = {"timeline_id": None, "total_clips": 0}
|
| 273 |
elif session_state.get("timeline_id"):
|
| 274 |
+
get_timeline_manager().delete_timeline(session_state["timeline_id"])
|
| 275 |
|
| 276 |
session_state = {"timeline_id": None, "total_clips": 0}
|
| 277 |
return None, None, "Timeline cleared", session_state
|
|
|
|
| 282 |
def lora_upload_files(files: List[str]) -> str:
|
| 283 |
"""Upload and prepare audio files for LoRA training."""
|
| 284 |
try:
|
| 285 |
+
prepared_files = get_lora_trainer().prepare_dataset(files)
|
| 286 |
return f"✅ Prepared {len(prepared_files)} files for training"
|
| 287 |
except Exception as e:
|
| 288 |
return f"❌ Error: {str(e)}"
|
|
|
|
| 402 |
std_generate_btn = gr.Button("🎵 Generate", variant="primary", size="lg")
|
| 403 |
|
| 404 |
with gr.Column():
|
| 405 |
+
gr.Markdown("### Audio Input (Optional)")
|
| 406 |
+
std_audio_input = gr.Audio(
|
| 407 |
+
label="Style Reference Audio",
|
| 408 |
+
type="filepath",
|
| 409 |
+
info="Upload audio file or record to use as style guidance"
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
gr.Markdown("### Generated Output")
|
| 413 |
+
std_audio_out = gr.Audio(label="Generated Audio")
|
| 414 |
std_info = gr.Textbox(label="Status", lines=2)
|
| 415 |
|
| 416 |
gr.Markdown("### Advanced Controls")
|
|
|
|
| 505 |
tl_info = gr.Textbox(label="Status", lines=2)
|
| 506 |
|
| 507 |
with gr.Column():
|
| 508 |
+
tl_clip_audio = gr.Audio(label="Latest Clip")
|
| 509 |
+
tl_full_audio = gr.Audio(label="Full Timeline")
|
| 510 |
tl_timeline_viz = gr.Image(label="Timeline Visualization")
|
| 511 |
|
| 512 |
with gr.Accordion("🎨 Inpaint Timeline Region", open=False):
|
src/ace_step_engine.py
CHANGED
|
@@ -35,21 +35,32 @@ class ACEStepEngine:
|
|
| 35 |
"""
|
| 36 |
self.config = config
|
| 37 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 38 |
-
self._initialized = False
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
if not ACE_STEP_AVAILABLE:
|
| 42 |
-
logger.error("ACE-Step 1.5 not
|
| 43 |
-
|
| 44 |
-
self.llm_handler = None
|
| 45 |
return
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
def _download_checkpoints(self):
|
| 55 |
"""Download model checkpoints from HuggingFace if not present."""
|
|
@@ -184,7 +195,9 @@ class ACEStepEngine:
|
|
| 184 |
Path to generated audio file
|
| 185 |
"""
|
| 186 |
if not self._initialized:
|
| 187 |
-
|
|
|
|
|
|
|
| 188 |
|
| 189 |
try:
|
| 190 |
# Prepare generation parameters
|
|
|
|
| 35 |
"""
|
| 36 |
self.config = config
|
| 37 |
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 38 |
+
self._initialized = False
|
| 39 |
+
self.dit_handler = None
|
| 40 |
+
self.llm_handler = None
|
| 41 |
+
|
| 42 |
+
logger.info(f"ACE-Step Engine initializing on {self.device}")
|
| 43 |
|
| 44 |
if not ACE_STEP_AVAILABLE:
|
| 45 |
+
logger.error("ACE-Step 1.5 modules not available")
|
| 46 |
+
logger.error("Please ensure acestep package is installed in your environment")
|
|
|
|
| 47 |
return
|
| 48 |
|
| 49 |
+
try:
|
| 50 |
+
# Initialize official handlers
|
| 51 |
+
self.dit_handler = AceStepHandler()
|
| 52 |
+
self.llm_handler = LLMHandler()
|
| 53 |
+
|
| 54 |
+
# Download and load models
|
| 55 |
+
self._download_checkpoints()
|
| 56 |
+
self._load_models()
|
| 57 |
+
|
| 58 |
+
logger.info("✓ ACE-Step Engine fully initialized")
|
| 59 |
+
except Exception as e:
|
| 60 |
+
logger.error(f"Failed to initialize ACE-Step Engine: {e}")
|
| 61 |
+
logger.error("Engine will not be available for generation")
|
| 62 |
+
import traceback
|
| 63 |
+
traceback.print_exc()
|
| 64 |
|
| 65 |
def _download_checkpoints(self):
|
| 66 |
"""Download model checkpoints from HuggingFace if not present."""
|
|
|
|
| 195 |
Path to generated audio file
|
| 196 |
"""
|
| 197 |
if not self._initialized:
|
| 198 |
+
error_msg = "❌ Engine not initialized - ACE-Step 1.5 may not be installed or models are not loaded"
|
| 199 |
+
logger.error(error_msg)
|
| 200 |
+
raise RuntimeError(error_msg)
|
| 201 |
|
| 202 |
try:
|
| 203 |
# Prepare generation parameters
|
src/utils.py
CHANGED
|
@@ -57,13 +57,19 @@ def load_config(config_path: str = "config.yaml") -> Dict[str, Any]:
|
|
| 57 |
else:
|
| 58 |
# Default configuration
|
| 59 |
config = {
|
|
|
|
|
|
|
|
|
|
| 60 |
"model_path": "ACE-Step/ACE-Step-v1-3.5B",
|
| 61 |
"sample_rate": 44100,
|
| 62 |
"output_dir": "outputs",
|
| 63 |
"timeline_dir": "timelines",
|
| 64 |
"training_dir": "lora_training",
|
| 65 |
"chunk_duration": 30,
|
| 66 |
-
"force_mono": False
|
|
|
|
|
|
|
|
|
|
| 67 |
}
|
| 68 |
|
| 69 |
# Save default config
|
|
|
|
| 57 |
else:
|
| 58 |
# Default configuration
|
| 59 |
config = {
|
| 60 |
+
"checkpoint_dir": "./checkpoints",
|
| 61 |
+
"dit_model_path": "acestep-v15-turbo",
|
| 62 |
+
"lm_model_path": "acestep-5Hz-lm-1.7B",
|
| 63 |
"model_path": "ACE-Step/ACE-Step-v1-3.5B",
|
| 64 |
"sample_rate": 44100,
|
| 65 |
"output_dir": "outputs",
|
| 66 |
"timeline_dir": "timelines",
|
| 67 |
"training_dir": "lora_training",
|
| 68 |
"chunk_duration": 30,
|
| 69 |
+
"force_mono": False,
|
| 70 |
+
"device": "auto",
|
| 71 |
+
"use_flash_attention": False,
|
| 72 |
+
"offload_to_cpu": False
|
| 73 |
}
|
| 74 |
|
| 75 |
# Save default config
|