| """ |
| ACE-Step 1.5 Custom Edition - Main Application |
| A comprehensive music generation system with three main interfaces: |
| 1. Standard ACE-Step GUI |
| 2. Custom Timeline-based Workflow |
| 3. LoRA Training Studio |
| """ |
|
|
| import gradio as gr |
| import torch |
| import numpy as np |
| from pathlib import Path |
| import json |
| from typing import Optional, List, Tuple |
| import spaces |
|
|
| from src.ace_step_engine import ACEStepEngine |
| from src.timeline_manager import TimelineManager |
| from src.lora_trainer import LoRATrainer |
| from src.audio_processor import AudioProcessor |
| from src.utils import setup_logging, load_config |
|
|
| |
# Shared application services, configured once at import time.
logger = setup_logging()
config = load_config()

# Heavyweight components are deliberately NOT constructed at import time;
# they are created lazily by the get_* accessors below so that importing
# this module does not trigger model downloads or GPU allocation.
ace_engine = None
timeline_manager = None
lora_trainer = None
audio_processor = None
|
|
def get_ace_engine():
    """Return the shared ACE-Step engine, constructing it on first use."""
    global ace_engine
    if ace_engine is not None:
        return ace_engine
    # First call: build the engine (model load happens here, not at import).
    ace_engine = ACEStepEngine(config)
    return ace_engine
|
|
def get_timeline_manager():
    """Return the shared timeline manager, constructing it on first use."""
    global timeline_manager
    if timeline_manager is not None:
        return timeline_manager
    timeline_manager = TimelineManager(config)
    return timeline_manager
|
|
def get_lora_trainer():
    """Return the shared LoRA trainer, constructing it on first use."""
    global lora_trainer
    if lora_trainer is not None:
        return lora_trainer
    lora_trainer = LoRATrainer(config)
    return lora_trainer
|
|
def get_audio_processor():
    """Return the shared audio processor, constructing it on first use."""
    global audio_processor
    if audio_processor is not None:
        return audio_processor
    audio_processor = AudioProcessor(config)
    return audio_processor
|
|
|
|
| |
|
|
@spaces.GPU(duration=300)
def standard_generate(
    prompt: str,
    lyrics: str,
    duration: int,
    temperature: float,
    top_p: float,
    seed: int,
    style: str,
    use_lora: bool,
    lora_path: Optional[str] = None
) -> Tuple[Optional[str], str]:
    """Run a standard ACE-Step generation pass.

    Args:
        prompt: Free-text description of the desired music.
        lyrics: Lyrics to sing; may be empty for instrumental output.
        duration: Target length of the output in seconds.
        temperature: Sampling temperature (higher = more random).
        top_p: Nucleus-sampling threshold.
        seed: RNG seed; -1 means "random" (interpreted by the engine).
        style: Style preset name, or "auto".
        use_lora: Whether to apply the LoRA adapter at ``lora_path``.
        lora_path: Filesystem path to a LoRA adapter; ignored unless
            ``use_lora`` is True.

    Returns:
        ``(audio_path, status_message)``. ``audio_path`` is None on failure
        (fixed annotation: the error path returns None, not str).
    """
    try:
        logger.info(f"Standard generation: {prompt[:50]}...")

        engine = get_ace_engine()

        audio_path = engine.generate(
            prompt=prompt,
            lyrics=lyrics,
            duration=duration,
            temperature=temperature,
            top_p=top_p,
            seed=seed,
            style=style,
            # Only forward the LoRA path when the checkbox is ticked.
            lora_path=lora_path if use_lora else None
        )

        info = f"✅ Generated {duration}s audio successfully"
        return audio_path, info

    except Exception as e:
        # Broad catch is deliberate: surface any engine failure in the UI
        # status box instead of crashing the Gradio worker.
        logger.error(f"Standard generation failed: {e}")
        return None, f"❌ Error: {str(e)}"
|
|
|
|
@spaces.GPU(duration=180)
def standard_variation(audio_path: str, variation_strength: float) -> Tuple[Optional[str], str]:
    """Generate a variation of an existing generated clip.

    Args:
        audio_path: Path to the source audio file.
        variation_strength: How far the variation may drift (UI range 0.1-1.0).

    Returns:
        ``(audio_path, status_message)``; audio path is None on failure
        (fixed annotation: the error path returns None, not str).
    """
    try:
        result = get_ace_engine().generate_variation(audio_path, variation_strength)
        return result, "✅ Variation generated"
    except Exception as e:
        # Report failures to the UI instead of raising inside the worker.
        return None, f"❌ Error: {str(e)}"
|
|
|
|
@spaces.GPU(duration=180)
def standard_repaint(
    audio_path: str,
    start_time: float,
    end_time: float,
    new_prompt: str
) -> Tuple[Optional[str], str]:
    """Regenerate ("repaint") a time-bounded section of existing audio.

    Args:
        audio_path: Path to the source audio file.
        start_time: Section start, in seconds.
        end_time: Section end, in seconds.
        new_prompt: Prompt describing the replacement content.

    Returns:
        ``(audio_path, status_message)``; audio path is None on failure
        (fixed annotation: the error path returns None, not str).
    """
    try:
        result = get_ace_engine().repaint(audio_path, start_time, end_time, new_prompt)
        return result, f"✅ Repainted {start_time}s-{end_time}s"
    except Exception as e:
        # Report failures to the UI instead of raising inside the worker.
        return None, f"❌ Error: {str(e)}"
|
|
@spaces.GPU(duration=180)
def standard_lyric_edit(
    audio_path: str,
    new_lyrics: str
) -> Tuple[Optional[str], str]:
    """Replace the lyrics of existing audio while keeping the music.

    Args:
        audio_path: Path to the source audio file.
        new_lyrics: Replacement lyrics text.

    Returns:
        ``(audio_path, status_message)``; audio path is None on failure
        (fixed annotation: the error path returns None, not str).
    """
    try:
        result = get_ace_engine().edit_lyrics(audio_path, new_lyrics)
        return result, "✅ Lyrics edited"
    except Exception as e:
        # Report failures to the UI instead of raising inside the worker.
        return None, f"❌ Error: {str(e)}"
|
|
|
|
| |
@spaces.GPU(duration=300)
def timeline_generate(
    prompt: str,
    lyrics: str,
    context_length: int,
    style: str,
    temperature: float,
    seed: int,
    session_state: dict
) -> Tuple[Optional[str], Optional[str], Optional[str], dict, str]:
    """Generate a 32-second clip (2s lead-in, 28s main, 2s lead-out) and
    append it to the session's timeline.

    The new clip is cross-blended with the previous clip, and up to
    ``context_length`` seconds of prior timeline audio is fed back to the
    engine for style continuity.

    Args:
        prompt: Description of this section's music.
        lyrics: Lyrics for this clip (may be empty).
        context_length: Seconds of prior timeline audio to condition on.
        style: Style preset name, or "auto".
        temperature: Sampling temperature.
        seed: RNG seed; -1 for random.
        session_state: Per-session dict holding "timeline_id" and
            "total_clips"; created here if None (first call).

    Returns:
        ``(clip_audio, full_timeline_audio, timeline_image, session_state,
        status_message)``. Fixed annotation: the function returns FIVE
        values (both paths), not the four previously declared. Audio/image
        slots are None on failure.
    """
    try:
        # First call in a fresh session: initialize the state dict.
        if session_state is None:
            session_state = {"timeline_id": None, "total_clips": 0}

        logger.info(f"Timeline generation with {context_length}s context")

        tm = get_timeline_manager()
        engine = get_ace_engine()
        ap = get_audio_processor()

        # Prior timeline audio used to keep the new clip stylistically
        # consistent with what came before (None timeline_id = no context).
        context_audio = tm.get_context(
            session_state.get("timeline_id"),
            context_length
        )

        clip = engine.generate_clip(
            prompt=prompt,
            lyrics=lyrics,
            duration=32,
            context_audio=context_audio,
            style=style,
            temperature=temperature,
            seed=seed
        )

        # Cross-fade the new clip against the previous one (2s each side).
        blended_clip = ap.blend_clip(
            clip,
            tm.get_last_clip(session_state.get("timeline_id")),
            lead_in=2.0,
            lead_out=2.0
        )

        timeline_id = tm.add_clip(
            session_state.get("timeline_id"),
            blended_clip,
            metadata={
                "prompt": prompt,
                "lyrics": lyrics,
                "context_length": context_length
            }
        )

        # add_clip may create the timeline on first use; remember its id.
        session_state["timeline_id"] = timeline_id
        session_state["total_clips"] = session_state.get("total_clips", 0) + 1

        full_audio = tm.export_timeline(timeline_id)

        timeline_viz = tm.visualize_timeline(timeline_id)

        info = f"✅ Clip {session_state['total_clips']} added • Total: {tm.get_duration(timeline_id):.1f}s"

        return blended_clip, full_audio, timeline_viz, session_state, info

    except Exception as e:
        logger.error(f"Timeline generation failed: {e}")
        return None, None, None, session_state, f"❌ Error: {str(e)}"
|
|
|
|
def timeline_extend(
    prompt: str,
    lyrics: str,
    context_length: int,
    session_state: dict
) -> Tuple[Optional[str], Optional[str], Optional[str], dict, str]:
    """Extend the current timeline using default generation settings.

    Convenience wrapper around ``timeline_generate`` with style="auto",
    temperature=0.7 and a random seed (-1).

    Returns:
        The 5-tuple from ``timeline_generate`` (fixed annotation: the
        wrapped function returns five values, not four).
    """
    return timeline_generate(
        prompt, lyrics, context_length, "auto", 0.7, -1, session_state
    )
|
|
@spaces.GPU(duration=240)
def timeline_inpaint(
    start_time: float,
    end_time: float,
    new_prompt: str,
    session_state: dict
) -> Tuple[Optional[str], Optional[str], dict, str]:
    """Regenerate the [start_time, end_time] region of the session timeline.

    Args:
        start_time: Region start, in seconds.
        end_time: Region end, in seconds.
        new_prompt: Prompt describing the replacement content.
        session_state: Per-session dict; created here if None.

    Returns:
        ``(full_audio, timeline_image, session_state, status_message)``.
        Fixed annotation: the function returns FOUR values, not the three
        previously declared. Audio/image slots are None on failure.
    """
    try:
        if session_state is None:
            session_state = {"timeline_id": None, "total_clips": 0}

        tm = get_timeline_manager()
        timeline_id = session_state.get("timeline_id")
        # Mutates the timeline in place; the previous `result` binding was
        # unused, so the return value is intentionally discarded.
        tm.inpaint_region(
            timeline_id,
            start_time,
            end_time,
            new_prompt
        )

        full_audio = tm.export_timeline(timeline_id)
        timeline_viz = tm.visualize_timeline(timeline_id)

        info = f"✅ Inpainted {start_time:.1f}s-{end_time:.1f}s"
        return full_audio, timeline_viz, session_state, info

    except Exception as e:
        return None, None, session_state, f"❌ Error: {str(e)}"
|
|
|
|
def timeline_reset(session_state: dict) -> Tuple[None, None, str, dict]:
    """Discard the current timeline and hand back a fresh session state."""
    # Tear down any server-side timeline before resetting local state.
    # A None state means nothing was ever generated, so there is nothing
    # to delete.
    if session_state is not None and session_state.get("timeline_id"):
        get_timeline_manager().delete_timeline(session_state["timeline_id"])

    fresh_state = {"timeline_id": None, "total_clips": 0}
    return None, None, "Timeline cleared", fresh_state
|
|
|
|
| |
|
|
def lora_upload_files(files: List[str]) -> str:
    """Stage the uploaded audio files as a LoRA training dataset."""
    try:
        prepared = get_lora_trainer().prepare_dataset(files)
    except Exception as exc:
        return f"❌ Error: {str(exc)}"
    return f"✅ Prepared {len(prepared)} files for training"
|
|
@spaces.GPU(duration=300)
def lora_train(
    dataset_path: str,
    model_name: str,
    learning_rate: float,
    batch_size: int,
    num_epochs: int,
    rank: int,
    alpha: int,
    use_existing_lora: bool,
    existing_lora_path: Optional[str] = None,
    progress=gr.Progress()
) -> Tuple[Optional[str], str]:
    """Train a LoRA adapter on a prepared dataset.

    Args:
        dataset_path: Path to the dataset prepared by ``lora_upload_files``.
        model_name: Name under which to save the trained adapter.
        learning_rate: Optimizer learning rate.
        batch_size: Training batch size.
        num_epochs: Number of training epochs.
        rank: LoRA rank (used only when starting from scratch).
        alpha: LoRA alpha (used only when starting from scratch).
        use_existing_lora: Continue training from an existing adapter.
        existing_lora_path: Path to that adapter; required with the flag.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(model_path, status_message)``; model_path is None on failure
        (fixed annotation: the error path returns None, not str).
    """
    try:
        logger.info(f"Starting LoRA training: {model_name}")

        # BUG FIX: the original used the bare module global `lora_trainer`,
        # which is None until get_lora_trainer() has run at least once
        # (e.g. via the upload step) — training straight away would raise
        # AttributeError. Always go through the lazy accessor.
        trainer = get_lora_trainer()

        if use_existing_lora and existing_lora_path:
            trainer.load_lora(existing_lora_path)
        else:
            trainer.initialize_lora(rank=rank, alpha=alpha)

        def progress_callback(step, total_steps, loss):
            # Forward trainer progress into the Gradio progress bar.
            progress((step, total_steps), desc=f"Training (loss: {loss:.4f})")

        result_path = trainer.train(
            dataset_path=dataset_path,
            model_name=model_name,
            learning_rate=learning_rate,
            batch_size=batch_size,
            num_epochs=num_epochs,
            progress_callback=progress_callback
        )

        info = f"✅ Training complete! Model saved to {result_path}"
        return result_path, info

    except Exception as e:
        logger.error(f"LoRA training failed: {e}")
        return None, f"❌ Error: {str(e)}"
|
|
|
|
def lora_download(lora_path: str) -> Optional[str]:
    """Return ``lora_path`` for download if the file exists, else None.

    Gradio's File output treats None as "nothing to download".
    Fixed annotation: the miss path returns None, so the return type is
    Optional[str], not str.
    """
    return lora_path if Path(lora_path).exists() else None
|
|
|
|
| |
|
|
def create_ui():
    """Create the three-tab Gradio interface.

    Tabs: (1) standard one-shot generation with variation/repaint/lyric-edit
    tools, (2) timeline-based incremental composition, (3) LoRA training.
    Returns the assembled gr.Blocks app (not yet launched).
    """

    with gr.Blocks(title="ACE-Step 1.5 Custom Edition", theme=gr.themes.Soft()) as app:

        gr.Markdown("""
        # 🎵 ACE-Step 1.5 Custom Edition

        **Three powerful interfaces for music generation and training**

        Models will download automatically on first use (~7GB from HuggingFace)
        """)

        with gr.Tabs():

            # ---- Tab 1: standard one-shot generation -----------------------
            with gr.Tab("🎼 Standard ACE-Step"):
                gr.Markdown("### Full-featured standard ACE-Step 1.5 interface")

                with gr.Row():
                    with gr.Column():
                        std_prompt = gr.Textbox(
                            label="Prompt",
                            placeholder="Describe the music style, mood, instruments...",
                            lines=3
                        )
                        std_lyrics = gr.Textbox(
                            label="Lyrics (optional)",
                            placeholder="Enter lyrics here...",
                            lines=5
                        )

                        with gr.Row():
                            std_duration = gr.Slider(
                                minimum=10, maximum=240, value=30, step=10,
                                label="Duration (seconds)"
                            )
                            std_style = gr.Dropdown(
                                choices=["auto", "pop", "rock", "jazz", "classical", "electronic", "hip-hop"],
                                value="auto",
                                label="Style"
                            )

                        with gr.Row():
                            std_temperature = gr.Slider(
                                minimum=0.1, maximum=1.5, value=0.7, step=0.1,
                                label="Temperature"
                            )
                            std_top_p = gr.Slider(
                                minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                                label="Top P"
                            )

                        std_seed = gr.Number(label="Seed (-1 for random)", value=-1)

                        with gr.Row():
                            std_use_lora = gr.Checkbox(label="Use LoRA", value=False)
                            std_lora_path = gr.Textbox(
                                label="LoRA Path",
                                placeholder="Path to LoRA model (if using)"
                            )

                        std_generate_btn = gr.Button("🎵 Generate", variant="primary", size="lg")

                    with gr.Column():
                        gr.Markdown("### Audio Input (Optional)")
                        gr.Markdown("*Upload audio file or record to use as style guidance*")
                        # NOTE(review): std_audio_input is not wired into any
                        # event handler below, so the style-reference upload is
                        # currently inert — confirm whether it should be passed
                        # to standard_generate.
                        std_audio_input = gr.Audio(
                            label="Style Reference Audio",
                            type="filepath"
                        )

                        gr.Markdown("### Generated Output")
                        std_audio_out = gr.Audio(label="Generated Audio")
                        std_info = gr.Textbox(label="Status", lines=2)

                gr.Markdown("### Advanced Controls")

                with gr.Accordion("🔄 Generate Variation", open=False):
                    std_var_strength = gr.Slider(0.1, 1.0, 0.5, label="Variation Strength")
                    std_var_btn = gr.Button("Generate Variation")

                with gr.Accordion("🎨 Repaint Section", open=False):
                    std_repaint_start = gr.Number(label="Start Time (s)", value=0)
                    std_repaint_end = gr.Number(label="End Time (s)", value=10)
                    std_repaint_prompt = gr.Textbox(label="New Prompt", lines=2)
                    std_repaint_btn = gr.Button("Repaint")

                with gr.Accordion("✏️ Edit Lyrics", open=False):
                    std_edit_lyrics = gr.Textbox(label="New Lyrics", lines=4)
                    std_edit_btn = gr.Button("Edit Lyrics")

                # Event wiring for tab 1. Variation/repaint/edit all read from
                # and write back to std_audio_out, so they operate on the most
                # recent result in place.
                std_generate_btn.click(
                    fn=standard_generate,
                    inputs=[std_prompt, std_lyrics, std_duration, std_temperature,
                            std_top_p, std_seed, std_style, std_use_lora, std_lora_path],
                    outputs=[std_audio_out, std_info]
                )

                std_var_btn.click(
                    fn=standard_variation,
                    inputs=[std_audio_out, std_var_strength],
                    outputs=[std_audio_out, std_info]
                )

                std_repaint_btn.click(
                    fn=standard_repaint,
                    inputs=[std_audio_out, std_repaint_start, std_repaint_end, std_repaint_prompt],
                    outputs=[std_audio_out, std_info]
                )

                std_edit_btn.click(
                    fn=standard_lyric_edit,
                    inputs=[std_audio_out, std_edit_lyrics],
                    outputs=[std_audio_out, std_info]
                )

            # ---- Tab 2: timeline workflow ----------------------------------
            with gr.Tab("⏱️ Timeline Workflow"):
                gr.Markdown("""
                ### Custom Timeline-based Generation
                Generate 32-second clips that seamlessly blend together on a master timeline.
                """)

                # Per-browser-session state: {"timeline_id": ..., "total_clips": ...};
                # the handlers initialize it from None on first use.
                timeline_state = gr.State(value=None)

                with gr.Row():
                    with gr.Column():
                        tl_prompt = gr.Textbox(
                            label="Prompt",
                            placeholder="Describe this section...",
                            lines=3
                        )
                        tl_lyrics = gr.Textbox(
                            label="Lyrics for this clip",
                            placeholder="Enter lyrics for this 32s section...",
                            lines=4
                        )

                        gr.Markdown("*How far back to reference for style guidance*")
                        tl_context_length = gr.Slider(
                            minimum=0, maximum=120, value=30, step=10,
                            label="Context Length (seconds)"
                        )

                        with gr.Row():
                            tl_style = gr.Dropdown(
                                choices=["auto", "pop", "rock", "jazz", "electronic"],
                                value="auto",
                                label="Style"
                            )
                            tl_temperature = gr.Slider(
                                minimum=0.5, maximum=1.0, value=0.7, step=0.05,
                                label="Temperature"
                            )

                        tl_seed = gr.Number(label="Seed (-1 for random)", value=-1)

                        with gr.Row():
                            tl_generate_btn = gr.Button("🎵 Generate Clip", variant="primary", size="lg")
                            tl_extend_btn = gr.Button("➕ Extend", size="lg")
                            tl_reset_btn = gr.Button("🔄 Reset Timeline", variant="secondary")

                        tl_info = gr.Textbox(label="Status", lines=2)

                    with gr.Column():
                        tl_clip_audio = gr.Audio(label="Latest Clip")
                        tl_full_audio = gr.Audio(label="Full Timeline")
                        tl_timeline_viz = gr.Image(label="Timeline Visualization")

                with gr.Accordion("🎨 Inpaint Timeline Region", open=False):
                    tl_inpaint_start = gr.Number(label="Start Time (s)", value=0)
                    tl_inpaint_end = gr.Number(label="End Time (s)", value=10)
                    tl_inpaint_prompt = gr.Textbox(label="New Prompt", lines=2)
                    tl_inpaint_btn = gr.Button("Inpaint Region")

                # Event wiring for tab 2. Each handler returns the (possibly
                # re-created) session state so Gradio keeps it current.
                tl_generate_btn.click(
                    fn=timeline_generate,
                    inputs=[tl_prompt, tl_lyrics, tl_context_length, tl_style,
                            tl_temperature, tl_seed, timeline_state],
                    outputs=[tl_clip_audio, tl_full_audio, tl_timeline_viz, timeline_state, tl_info]
                )

                tl_extend_btn.click(
                    fn=timeline_extend,
                    inputs=[tl_prompt, tl_lyrics, tl_context_length, timeline_state],
                    outputs=[tl_clip_audio, tl_full_audio, tl_timeline_viz, timeline_state, tl_info]
                )

                tl_reset_btn.click(
                    fn=timeline_reset,
                    inputs=[timeline_state],
                    outputs=[tl_clip_audio, tl_full_audio, tl_info, timeline_state]
                )

                tl_inpaint_btn.click(
                    fn=timeline_inpaint,
                    inputs=[tl_inpaint_start, tl_inpaint_end, tl_inpaint_prompt, timeline_state],
                    outputs=[tl_full_audio, tl_timeline_viz, timeline_state, tl_info]
                )

            # ---- Tab 3: LoRA training studio -------------------------------
            with gr.Tab("🎓 LoRA Training Studio"):
                gr.Markdown("""
                ### Train Custom LoRA Models
                Upload audio files to train specialized models for voice cloning, style adaptation, etc.
                """)

                with gr.Row():
                    with gr.Column():
                        gr.Markdown("#### 1. Upload Training Data")
                        lora_files = gr.File(
                            label="Audio Files",
                            file_count="multiple",
                            file_types=["audio"]
                        )
                        lora_upload_btn = gr.Button("📤 Upload & Prepare Dataset")
                        lora_upload_status = gr.Textbox(label="Upload Status", lines=2)

                        gr.Markdown("#### 2. Training Configuration")
                        lora_dataset_path = gr.Textbox(
                            label="Dataset Path",
                            placeholder="Path to prepared dataset"
                        )
                        lora_model_name = gr.Textbox(
                            label="Model Name",
                            placeholder="my_custom_lora"
                        )

                        with gr.Row():
                            lora_learning_rate = gr.Number(
                                label="Learning Rate",
                                value=1e-4
                            )
                            lora_batch_size = gr.Slider(
                                minimum=1, maximum=16, value=4, step=1,
                                label="Batch Size"
                            )

                        with gr.Row():
                            lora_num_epochs = gr.Slider(
                                minimum=1, maximum=100, value=10, step=1,
                                label="Epochs"
                            )
                            lora_rank = gr.Slider(
                                minimum=4, maximum=128, value=16, step=4,
                                label="LoRA Rank"
                            )
                            lora_alpha = gr.Slider(
                                minimum=4, maximum=128, value=32, step=4,
                                label="LoRA Alpha"
                            )

                        lora_use_existing = gr.Checkbox(
                            label="Continue training from existing LoRA",
                            value=False
                        )
                        lora_existing_path = gr.Textbox(
                            label="Existing LoRA Path",
                            placeholder="Path to existing LoRA model"
                        )

                        lora_train_btn = gr.Button("🚀 Start Training", variant="primary", size="lg")

                    with gr.Column():
                        lora_train_status = gr.Textbox(label="Training Status", lines=3)
                        lora_model_path = gr.Textbox(label="Trained Model Path", lines=1)
                        lora_download_btn = gr.Button("💾 Download Model")
                        lora_download_file = gr.File(label="Download")

                        gr.Markdown("""
                        #### Training Tips
                        - Upload 10+ audio samples for best results
                        - Keep samples consistent in style/quality
                        - Higher rank = more capacity but slower training
                        - Start with 10-20 epochs and adjust
                        - Use existing LoRA to continue training
                        """)

                # Event wiring for tab 3.
                lora_upload_btn.click(
                    fn=lora_upload_files,
                    inputs=[lora_files],
                    outputs=[lora_upload_status]
                )

                lora_train_btn.click(
                    fn=lora_train,
                    inputs=[lora_dataset_path, lora_model_name, lora_learning_rate,
                            lora_batch_size, lora_num_epochs, lora_rank, lora_alpha,
                            lora_use_existing, lora_existing_path],
                    outputs=[lora_model_path, lora_train_status]
                )

                lora_download_btn.click(
                    fn=lora_download,
                    inputs=[lora_model_path],
                    outputs=[lora_download_file]
                )

        gr.Markdown("""
        ---
        ### About
        ACE-Step 1.5 Custom Edition by Gamahea | Based on [ACE-Step](https://ace-step.github.io/)
        """)

    return app
|
|
|
|
| |
|
|
if __name__ == "__main__":
    logger.info("Starting ACE-Step 1.5 Custom Edition...")

    try:
        app = create_ui()

        # Gradio builds an API schema (the /info endpoint) from all
        # components; some configurations can make get_api_info() raise and
        # crash startup. Wrap it so a schema failure degrades to an empty
        # endpoint listing instead of taking the whole app down.
        original_get_api_info = app.get_api_info

        def safe_get_api_info(*args, **kwargs):
            """Patched get_api_info that returns minimal info to avoid schema errors"""
            try:
                return original_get_api_info(*args, **kwargs)
            except (TypeError, AttributeError, KeyError) as e:
                logger.warning(f"API info generation failed, returning minimal info: {e}")
                return {
                    "named_endpoints": {},
                    "unnamed_endpoints": {}
                }

        app.get_api_info = safe_get_api_info
        logger.info("✓ Patched get_api_info method")

        # Bind on all interfaces (container/Spaces deployment) on port 7860.
        app.launch(
            server_name="0.0.0.0",
            server_port=7860,
            share=False,
            show_error=True
        )
    except Exception as e:
        # Log and re-raise so the hosting platform sees a nonzero exit.
        logger.error(f"Failed to launch app: {e}")
        import traceback
        traceback.print_exc()
        raise
|
|