Spaces:
Running
Running
| """ACE-Step 1.5 XL (CPU) - Gradio frontend + CLI for ace-server GGUF inference""" | |
| import os | |
| import sys | |
| import time | |
| import json | |
| import argparse | |
| import tempfile | |
| import requests | |
# Base URL of the ace-server HTTP API (override with ACE_SERVER).
ACE_SERVER = os.environ.get("ACE_SERVER", "http://127.0.0.1:8085")
# Directory where generated audio files are written; created at import time
# because the pipeline writes its temp files straight into it.
OUTPUT_DIR = os.environ.get("ACE_OUTPUT_DIR", "/app/outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Model checkpoints used by LoRA training (downloaded on demand when missing).
ACE_CHECKPOINT_DIR = os.environ.get("ACE_CHECKPOINT_DIR", "/app/checkpoints")
# ACE-Step Python sources; prepended to sys.path before training imports.
# Env-overridable like the other paths; the default is unchanged.
ACE_SOURCE_DIR = os.environ.get("ACE_SOURCE_DIR", "/app/ace-step-source")
# Hugging Face repo the training checkpoints are downloaded from.
ACE_HF_MODEL = os.environ.get("ACE_HF_MODEL", "ACE-Step/Ace-Step1.5")
# Directory where trained LoRA adapters are stored (one sub-directory each).
ADAPTER_DIR = os.environ.get("ACE_ADAPTER_DIR", "/app/adapters")
| # --------------------------------------------------------------------------- | |
| # ace-server helpers | |
| # --------------------------------------------------------------------------- | |
def _server_ok():
    """Return True when ace-server answers ``GET /health`` with HTTP 200.

    Any failure to reach the server (refused connection, timeout, ...) is
    reported as False rather than raised.
    """
    try:
        resp = requests.get(f"{ACE_SERVER}/health", timeout=5)
    except Exception:
        return False
    return resp.status_code == 200
def _get_props():
    """Return ace-server's ``/props`` JSON (models, adapters).

    Yields an empty dict when the server is unreachable, answers with a
    non-200 status, or sends a body that is not valid JSON.
    """
    try:
        resp = requests.get(f"{ACE_SERVER}/props", timeout=10)
        return resp.json() if resp.status_code == 200 else {}
    except Exception:
        return {}
def _poll_job(job_id, timeout=600, progress_cb=None, interval=2.0):
    """Poll ace-server's ``/job`` endpoint until the job finishes.

    Args:
        job_id: Job identifier returned by a ``/lm`` or ``/synth`` submission.
        timeout: Maximum number of seconds to keep polling.
        progress_cb: Optional ``callback(status, data)`` called with every
            successfully decoded status payload. Exceptions it raises are
            propagated, not swallowed.
        interval: Seconds to wait between polls (default 2, as before).

    Returns:
        ``(status, elapsed)``. *status* is ``"done"``, ``"error"`` or
        ``"timeout"``, and *elapsed* is the polling time in seconds.
    """
    # Monotonic clock: elapsed time must not jump with wall-clock changes.
    t0 = time.monotonic()
    while True:
        elapsed = time.monotonic() - t0
        if elapsed >= timeout:
            return "timeout", elapsed
        try:
            r = requests.get(f"{ACE_SERVER}/job", params={"id": job_id}, timeout=10)
            data = r.json()
        except (requests.RequestException, ValueError):
            # Transient network failure or undecodable body: retry until timeout.
            data = None
        if isinstance(data, dict):
            status = data.get("status", "unknown")
            if progress_cb:
                progress_cb(status, data)
            if status in ("done", "error"):
                return status, time.monotonic() - t0
        # Never sleep past the deadline.
        remaining = timeout - (time.monotonic() - t0)
        time.sleep(max(0.0, min(interval, remaining)))
def _fetch_result(job_id, timeout=60):
    """Request the result of a finished job via ``GET /job?id=...&result=1``.

    The raw ``requests.Response`` is returned without any status check.
    Callers decode it themselves: JSON for LM jobs, audio bytes for synth jobs.
    """
    query = {"id": job_id, "result": 1}
    return requests.get(f"{ACE_SERVER}/job", params=query, timeout=timeout)
def _build_lm_request(caption, lyrics, bpm, duration, seed, steps, adapter=None):
    """Translate UI/CLI parameters into an ace-server ``/lm`` request body.

    Optional numeric fields are only included when positive (the seed when
    non-negative); otherwise they are omitted. Duration is clamped to 300 s.
    """
    req = {
        "caption": caption or "upbeat electronic dance music",
        "lyrics": lyrics if lyrics and lyrics.strip() else "[Instrumental]",
    }
    if bpm and int(bpm) > 0:
        req["bpm"] = int(bpm)
    if duration and float(duration) > 0:
        req["duration"] = min(float(duration), 300)
    if seed is not None and int(seed) >= 0:
        req["seed"] = int(seed)
    if steps and int(steps) > 0:
        req["inference_steps"] = int(steps)
    if adapter:
        req["adapter"] = adapter
    return req


def _submit_job(endpoint, payload, label):
    """POST *payload* to ``{ACE_SERVER}/{endpoint}`` and return the new job id.

    Raises:
        RuntimeError: if the server rejects the request or its reply carries no
            job id (polling ``None`` would otherwise spin until the timeout).
    """
    r = requests.post(f"{ACE_SERVER}/{endpoint}", json=payload, timeout=30)
    if r.status_code != 200:
        raise RuntimeError(f"{label} submit failed: {r.status_code} {r.text}")
    try:
        job_id = r.json().get("id")
    except (ValueError, AttributeError):
        # Body was not JSON, or JSON that is not an object.
        job_id = None
    if job_id is None:
        raise RuntimeError(f"{label} submit returned no job id: {r.text}")
    return job_id


def _run_pipeline(caption, lyrics, bpm, duration, seed, steps, output_format,
                  adapter=None, progress_cb=None):
    """Run the full LM -> synth pipeline on ace-server and save the audio.

    Args:
        caption, lyrics, bpm, duration, seed, steps, adapter: generation
            parameters (see ``_build_lm_request`` for how they are filtered).
        output_format: ``"wav"`` or ``"mp3"``; anything else falls back to wav.
        progress_cb: optional ``callback(phase, data)``. It is called with
            the phases ``lm_submit``, ``lm_poll``, ``synth_submit``,
            ``synth_poll`` and ``fetch``.

    Returns:
        ``(audio_path, status_msg)``. The audio is a new file in OUTPUT_DIR.

    Raises:
        RuntimeError: on any server-side failure, timeout or malformed reply.
        requests.RequestException: on connection problems while submitting.
    """
    t0 = time.monotonic()

    def report(phase, data=None):
        if progress_cb:
            progress_cb(phase, data)

    req = _build_lm_request(caption, lyrics, bpm, duration, seed, steps, adapter)
    fmt = output_format if output_format in ("wav", "mp3") else "wav"
    synth_fmt = "wav16" if fmt == "wav" else "mp3"

    # -- LM phase: turn caption/lyrics into a synth request --
    report("lm_submit")
    lm_job_id = _submit_job("lm", req, "LM")
    report("lm_poll", {"job_id": lm_job_id})
    lm_status, lm_elapsed = _poll_job(lm_job_id, timeout=300)
    if lm_status != "done":
        raise RuntimeError(f"LM {lm_status} after {lm_elapsed:.0f}s")
    r = _fetch_result(lm_job_id)
    if r.status_code != 200:
        raise RuntimeError(f"LM result fetch failed: {r.status_code} {r.text}")
    try:
        lm_results = r.json()
    except ValueError:
        raise RuntimeError(f"LM returned invalid JSON: {r.text[:200]}") from None
    if not isinstance(lm_results, list) or len(lm_results) == 0:
        raise RuntimeError(f"LM returned no results: {lm_results}")
    synth_request = lm_results[0]
    if not isinstance(synth_request, dict):
        raise RuntimeError(f"LM returned unexpected result: {synth_request!r}")

    # -- Synth phase --
    synth_request["output_format"] = synth_fmt
    report("synth_submit")
    synth_job_id = _submit_job("synth", synth_request, "Synth")
    report("synth_poll", {"job_id": synth_job_id})
    synth_status, synth_elapsed = _poll_job(synth_job_id, timeout=600)
    if synth_status != "done":
        raise RuntimeError(f"Synth {synth_status} after {synth_elapsed:.0f}s")

    # -- Fetch audio --
    report("fetch")
    r = _fetch_result(synth_job_id, timeout=60)
    if r.status_code != 200:
        raise RuntimeError(f"Audio fetch failed: {r.status_code}")
    # delete=False keeps the file after closing; the `with` guarantees the
    # handle is closed even if the write fails.
    with tempfile.NamedTemporaryFile(suffix=f".{fmt}", dir=OUTPUT_DIR,
                                     delete=False) as tmp:
        tmp.write(r.content)

    elapsed = time.monotonic() - t0
    # Report the duration actually requested (after the 300 s clamp).
    audio_len = f"{req['duration']:g}s" if "duration" in req else "default-length"
    msg = f"Done in {elapsed:.0f}s | {audio_len} audio, {steps} steps, {fmt}"
    return tmp.name, msg
| # --------------------------------------------------------------------------- | |
| # CLI mode | |
| # --------------------------------------------------------------------------- | |
def cli_main():
    """Run one generation from the command line and print the output path.

    Parses the arguments and checks that ace-server is reachable. It then
    runs the LM -> synth pipeline and optionally moves the result to
    ``--output``. Exits with status 1 when the server is unreachable or
    generation fails, including network and malformed-response errors.
    """
    global ACE_SERVER
    import shutil

    parser = argparse.ArgumentParser(
        description="ACE-Step 1.5 XL (CPU) - CLI inference via ace-server",
    )
    parser.add_argument("caption", nargs="?", default="upbeat electronic dance music",
                        help="Music description / caption")
    parser.add_argument("--lyrics", "-l", default="[Instrumental]",
                        help="Lyrics text (use '[Instrumental]' for no vocals)")
    parser.add_argument("--bpm", type=int, default=120, help="Beats per minute")
    parser.add_argument("--duration", "-d", type=float, default=10,
                        help="Duration in seconds (max 300)")
    parser.add_argument("--steps", "-s", type=int, default=8,
                        help="Inference steps (1-32)")
    parser.add_argument("--seed", type=int, default=-1,
                        help="Random seed (-1 for random)")
    parser.add_argument("--format", "-f", choices=["wav", "mp3"], default="wav",
                        help="Output audio format")
    parser.add_argument("--adapter", "-a", default=None,
                        help="LoRA adapter name")
    parser.add_argument("-o", "--output", default=None,
                        help="Output file path (default: auto in outputs dir)")
    # Show the effective default (which honours the ACE_SERVER env var).
    # argparse %-formats help text, so any literal '%' must be doubled.
    parser.add_argument("--server", default=None,
                        help=f"ace-server URL (default: {ACE_SERVER.replace('%', '%%')})")
    args = parser.parse_args()

    if args.server:
        ACE_SERVER = args.server
    if not _server_ok():
        print(f"ERROR: ace-server not reachable at {ACE_SERVER}", file=sys.stderr)
        sys.exit(1)

    seed = args.seed if args.seed >= 0 else None

    def cli_progress(phase, data):
        job = f" (job {data['job_id']})" if data and "job_id" in data else ""
        phases = {
            "lm_submit": "Submitting LM job...",
            "lm_poll": f"LM generating{job}...",
            "synth_submit": "Submitting synth job...",
            "synth_poll": f"Synthesizing{job}...",
            "fetch": "Fetching audio...",
        }
        print(f" [{phase}] {phases.get(phase, phase)}")

    print(f"ACE-Step CLI | caption: {args.caption}")
    print(f" lyrics: {args.lyrics} | bpm: {args.bpm} | duration: {args.duration}s "
          f"| steps: {args.steps} | seed: {args.seed} | format: {args.format}")
    try:
        audio_path, status = _run_pipeline(
            caption=args.caption,
            lyrics=args.lyrics,
            bpm=args.bpm,
            duration=args.duration,
            seed=seed,
            steps=args.steps,
            output_format=args.format,
            adapter=args.adapter,
            progress_cb=cli_progress,
        )
    except (RuntimeError, requests.RequestException, OSError, ValueError) as e:
        # requests errors: connection problems/timeouts; OSError: writing the
        # audio file; ValueError: undecodable server replies.
        print(f"ERROR: {e}", file=sys.stderr)
        sys.exit(1)

    # Move to requested output path if specified
    if args.output:
        out_dir = os.path.dirname(os.path.abspath(args.output))
        os.makedirs(out_dir, exist_ok=True)
        shutil.move(audio_path, args.output)
        audio_path = args.output
    print(f" {status}")
    print(f" Output: {audio_path}")
| # --------------------------------------------------------------------------- | |
| # Gradio UI mode | |
| # --------------------------------------------------------------------------- | |
def _safe_adapter_name(name, default="my-lora"):
    """Reduce a user-supplied LoRA name to a single safe directory name.

    Each run of characters other than letters, digits, ``_``, ``.`` and ``-``
    (including path separators) becomes ``-``. Leading and trailing dots and
    dashes are stripped. The result therefore cannot point outside
    ADAPTER_DIR (``"../x"`` -> ``"x"``, ``"/etc"`` -> ``"etc"``). Returns
    *default* when nothing usable remains.
    """
    import re

    cleaned = re.sub(r"[^\w.-]+", "-", (name or "").strip())
    return cleaned.strip(".-") or default


def gradio_main():
    """Build and launch the Gradio UI with the Generate and Train LoRA tabs."""
    import gradio as gr

    # -- Generate tab handler --
    def generate_music(caption, lyrics, instrumental, bpm, duration, seed,
                       steps, output_format, progress=gr.Progress(track_tqdm=True)):
        """Run the LM -> synth pipeline and return ``(audio_path, status)``."""
        if not _server_ok():
            return None, "ace-server not running. Check logs."
        if instrumental or not lyrics or lyrics.strip() == "":
            lyrics = "[Instrumental]"
        actual_seed = None if seed is None or int(seed) < 0 else int(seed)
        # Phase -> (progress fraction, description) for the progress bar.
        progress_map = {
            "lm_submit": (0.05, "Submitting LM job..."),
            "lm_poll": (0.10, "LM generating..."),
            "synth_submit": (0.40, "Submitting synth job..."),
            "synth_poll": (0.50, "Synthesizing audio..."),
            "fetch": (0.90, "Fetching audio..."),
        }

        def gr_progress(phase, data):
            pct, desc = progress_map.get(phase, (0.5, phase))
            if data and "job_id" in data:
                desc += f" (job {data['job_id']})"
            progress(pct, desc=desc)

        try:
            audio_path, status = _run_pipeline(
                caption=caption,
                lyrics=lyrics,
                bpm=bpm,
                duration=duration,
                seed=actual_seed,
                steps=steps,
                output_format=output_format,
                progress_cb=gr_progress,
            )
            return audio_path, status
        except RuntimeError as e:
            return None, str(e)
        except Exception as e:
            return None, f"Unexpected error: {e}"

    # -- Server info helper --
    def get_server_status():
        """Return a text summary of ace-server health and its ``/props``."""
        if not _server_ok():
            return "ace-server: OFFLINE"
        props = _get_props()
        lines = ["ace-server: ONLINE"]
        if props:
            lines.append(json.dumps(props, indent=2))
        return "\n".join(lines)

    # -- Training --
    def train_lora(audio_files, lora_name, epochs, lr, rank,
                   progress=gr.Progress(track_tqdm=True)):
        """Preprocess uploaded audio and train a LoRA adapter on CPU.

        Returns the accumulated log as one string. Errors during training are
        written to the log instead of being raised.
        """
        import gc
        import shutil
        import traceback

        if not audio_files:
            return "No audio files uploaded."
        # The name comes from a public endpoint and becomes a directory name:
        # sanitize it so it cannot escape ADAPTER_DIR.
        lora_name = _safe_adapter_name(lora_name)
        try:
            epochs = max(1, min(int(epochs), 10))
            lr = float(lr)
            rank = max(1, min(int(rank), 64))
        except (TypeError, ValueError):
            # Empty gr.Number fields arrive as None.
            return "Invalid training parameters: epochs, learning rate and rank must be numbers."

        output_dir = os.path.join(ADAPTER_DIR, lora_name)
        audio_dir = os.path.join(output_dir, "audio_input")
        os.makedirs(audio_dir, exist_ok=True)  # also creates output_dir
        for f in audio_files:
            src = f.name if hasattr(f, "name") else str(f)
            shutil.copy2(src, os.path.join(audio_dir, os.path.basename(src)))

        log_lines = [
            f"LoRA Training: '{lora_name}'",
            f"Audio files: {len(audio_files)}",
            f"Epochs: {epochs}, LR: {lr}, Rank: {rank}",
            f"Output: {output_dir}",
            "",
        ]
        torchaudio = None
        orig_load = None
        model = trainer = None
        try:
            ckpt_files = os.listdir(ACE_CHECKPOINT_DIR) if os.path.isdir(ACE_CHECKPOINT_DIR) else []
            if len(ckpt_files) < 3:
                log_lines.append("[Step 0] Downloading model checkpoints...")
                progress(0.02, desc="Downloading checkpoints...")
                from huggingface_hub import snapshot_download
                snapshot_download(
                    ACE_HF_MODEL,
                    local_dir=ACE_CHECKPOINT_DIR,
                    ignore_patterns=["*.md", "*.txt", ".gitattributes"],
                )
                log_lines.append(" Checkpoints downloaded.")
            if ACE_SOURCE_DIR not in sys.path:
                sys.path.insert(0, ACE_SOURCE_DIR)

            # Default torchaudio.load to the soundfile backend for this run only;
            # the original function is restored in `finally`.
            import torchaudio
            orig_load = torchaudio.load

            def _load_soundfile(filepath, *args, **kwargs):
                kwargs.setdefault('backend', 'soundfile')
                return orig_load(filepath, *args, **kwargs)

            torchaudio.load = _load_soundfile

            log_lines.append("[Step 1/2] Preprocessing audio files...")
            progress(0.10, desc="Preprocessing audio...")
            tensor_dir = os.path.join(output_dir, "preprocessed_tensors")
            os.makedirs(tensor_dir, exist_ok=True)
            from acestep.training_v2.preprocess import preprocess_audio_files
            result = preprocess_audio_files(
                audio_dir=audio_dir,
                output_dir=tensor_dir,
                checkpoint_dir=ACE_CHECKPOINT_DIR,
                variant="turbo",
                max_duration=60.0,
                device="cpu",
                precision="float32",
            )
            processed = result.get("processed", 0)
            total_files = result.get("total", 0)
            failed = result.get("failed", 0)
            log_lines.append(f" Preprocessed: {processed}/{total_files} (failed: {failed})")
            if processed == 0:
                log_lines.append("ERROR: No files preprocessed successfully.")
                return "\n".join(log_lines)

            log_lines.append("[Step 2/2] Training LoRA adapter (CPU, this will be slow)...")
            progress(0.30, desc="Loading model for training...")
            from acestep.training_v2.model_loader import load_decoder_for_training
            from acestep.training_v2.trainer_fixed import FixedLoRATrainer
            from acestep.training_v2.configs import TrainingConfigV2, LoRAConfigV2
            model = load_decoder_for_training(
                checkpoint_dir=ACE_CHECKPOINT_DIR,
                variant="turbo",
                device="cpu",
                precision="float32",
            )
            model = model.float()
            adapter_cfg = LoRAConfigV2(r=rank, alpha=rank, dropout=0.0)
            train_cfg = TrainingConfigV2(
                checkpoint_dir=ACE_CHECKPOINT_DIR,
                model_variant="turbo",
                dataset_dir=tensor_dir,
                output_dir=output_dir,
                max_epochs=epochs,
                batch_size=1,
                learning_rate=lr,
                device="cpu",
                precision="float32",
                seed=42,
                num_workers=0,
                pin_memory=False,
            )
            trainer = FixedLoRATrainer(model, adapter_cfg, train_cfg)
            step_count = 0
            last_loss = 0.0
            # Updates may be objects with .step/.loss or (step, loss, ...) tuples.
            for update in trainer.train():
                if hasattr(update, "step"):
                    step_count = update.step
                    last_loss = update.loss
                elif isinstance(update, tuple) and len(update) >= 2:
                    step_count = update[0]
                    last_loss = update[1]
                if step_count % 5 == 0:
                    log_lines.append(f" Step {step_count}: loss={last_loss:.4f}")
                # batch_size=1, so roughly `processed` steps per epoch.
                pct = 0.30 + 0.65 * min(step_count / max(epochs * processed, 1), 1.0)
                progress(pct, desc=f"Step {step_count}, loss={last_loss:.4f}")
            log_lines.append(f"Training complete! Final: step {step_count}, loss={last_loss:.4f}")
            log_lines.append(f"LoRA saved to: {output_dir}")
        except ImportError as e:
            log_lines.append(f"Import error: {e}")
            log_lines.append(f"Check ACE-Step source at {ACE_SOURCE_DIR}")
            log_lines.append(traceback.format_exc())
        except Exception as e:
            log_lines.append(f"ERROR: {e}")
            log_lines.append(traceback.format_exc())
        finally:
            if torchaudio is not None and orig_load is not None:
                torchaudio.load = orig_load
            # Drop model references on every path (success, error, early
            # return) so the memory can be reclaimed.
            model = trainer = None
            gc.collect()
        return "\n".join(log_lines)

    # -- Build UI --
    CSS = (
        ".compact-row { gap: 8px !important; }\n"
        ".status-box textarea { font-family: monospace; font-size: 13px; }\n"
    )
    with gr.Blocks(title="ACE-Step 1.5 XL (CPU)", css=CSS) as demo:
        with gr.Tabs():
            # ============================================================
            # Tab 1: Generate Music
            # ============================================================
            with gr.Tab("Generate Music"):
                gr.Markdown(
                    "**[ACE-Step 1.5 XL (CPU)](https://github.com/ace-step/ACE-Step-1.5)** "
                    "GGUF Q4_K_M via "
                    "[acestep.cpp](https://github.com/ServeurpersoCom/acestep.cpp)"
                )
                with gr.Row(elem_classes="compact-row"):
                    with gr.Column(scale=2):
                        caption = gr.Textbox(
                            label="Music Description",
                            lines=2,
                            value="upbeat electronic dance music, energetic synth leads",
                        )
                        lyrics = gr.Textbox(
                            label="Lyrics",
                            lines=3,
                            value="[Instrumental]",
                            placeholder="Enter lyrics or [Instrumental] for no vocals",
                        )
                    with gr.Column(scale=1):
                        audio_out = gr.Audio(label="Output", type="filepath")
                        status = gr.Textbox(
                            label="Status",
                            interactive=False,
                            lines=2,
                            elem_classes="status-box",
                        )
                with gr.Row(elem_classes="compact-row"):
                    instrumental = gr.Checkbox(label="Instrumental", value=True, scale=1)
                    bpm = gr.Number(label="BPM", value=120, minimum=0, maximum=300, scale=1)
                    duration = gr.Slider(
                        label="Duration (s)", minimum=10, maximum=120,
                        value=10, step=5, scale=1,
                    )
                    steps = gr.Slider(
                        label="Steps", minimum=1, maximum=32,
                        value=8, step=1, scale=1,
                    )
                    seed = gr.Number(label="Seed (-1=random)", value=-1, scale=1)
                    output_format = gr.Radio(
                        label="Format", choices=["wav", "mp3"],
                        value="wav", scale=1,
                    )
                with gr.Row(elem_classes="compact-row"):
                    gen_btn = gr.Button("Generate Music", variant="primary", scale=2)
                    status_btn = gr.Button("Server Status", scale=1)
                gen_btn.click(
                    fn=generate_music,
                    inputs=[caption, lyrics, instrumental, bpm, duration,
                            seed, steps, output_format],
                    outputs=[audio_out, status],
                    api_name="generate",
                )
                status_btn.click(
                    fn=get_server_status,
                    inputs=[],
                    outputs=[status],
                    api_name="server_status",
                )
            # ============================================================
            # Tab 2: Train LoRA
            # ============================================================
            with gr.Tab("Train LoRA"):
                gr.Markdown(
                    "### LoRA Training\n"
                    "Fine-tune ACE-Step on your own audio data. "
                    "CPU training is very slow. Checkpoints downloaded on first run (~10GB)."
                )
                with gr.Row(elem_classes="compact-row"):
                    with gr.Column(scale=2):
                        train_audio = gr.File(
                            label="Training Audio Files",
                            file_count="multiple",
                            file_types=["audio"],
                        )
                    with gr.Column(scale=1):
                        lora_name = gr.Textbox(label="LoRA Name", value="my-lora")
                        epochs = gr.Number(label="Epochs", value=5, minimum=1, maximum=10)
                        lr = gr.Number(label="Learning Rate", value=1e-4)
                        rank = gr.Number(label="Rank (r)", value=16, minimum=1, maximum=64)
                        train_btn = gr.Button("Train", variant="primary")
                train_log = gr.Textbox(
                    label="Training Log",
                    interactive=False,
                    lines=10,
                    elem_classes="status-box",
                )
                train_btn.click(
                    fn=train_lora,
                    inputs=[train_audio, lora_name, epochs, lr, rank],
                    outputs=[train_log],
                    api_name="train_lora",
                )
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        mcp_server=True,
    )
| # --------------------------------------------------------------------------- | |
| # Entry point | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Any extra command-line argument selects CLI mode. start.sh runs
    # `python3 /app/app.py` with no arguments, which starts the Gradio UI.
    entry = cli_main if sys.argv[1:] else gradio_main
    entry()