"""ACE-Step 1.5 XL (CPU) - Gradio frontend + CLI for ace-server GGUF inference""" import os import sys import time import json import argparse import tempfile import subprocess import shutil import requests import logging from train_engine import ( preprocess_audio, train_lora_generator, cancel_training, get_trained_loras as _get_trained_loras_engine, ) logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Configurable limits (edit here, not buried in code) # --------------------------------------------------------------------------- MAX_AUDIO_DURATION = 240 # seconds, cap per audio file for training MAX_TRAINING_TIME = 28800 # 8 hours hard training timeout (seconds) MAX_AUDIO_FILES = 50 # max number of training audio files per run # --------------------------------------------------------------------------- # Paths & constants # --------------------------------------------------------------------------- ACE_SERVER = os.environ.get("ACE_SERVER", "http://127.0.0.1:8085") OUTPUT_DIR = os.environ.get("ACE_OUTPUT_DIR", "/app/outputs") os.makedirs(OUTPUT_DIR, exist_ok=True) ACE_CHECKPOINT_DIR = os.environ.get("ACE_CHECKPOINT_DIR", "/app/checkpoints") ACE_SOURCE_DIR = "/app/ace-step-source" ACE_HF_MODEL = "ACE-Step/Ace-Step1.5" ADAPTER_DIR = os.environ.get("ACE_ADAPTER_DIR", "/app/adapters") MODELS_DIR = os.environ.get("ACE_MODELS_DIR", "/app/models") ACE_SERVER_BIN = "/app/ace-server" # HF repo for on-demand GGUF downloads GGUF_HF_REPO = "Serveurperso/ACE-Step-1.5-GGUF" # --------------------------------------------------------------------------- # ace-server helpers # --------------------------------------------------------------------------- def _server_ok(): try: return requests.get(f"{ACE_SERVER}/health", timeout=5).status_code == 200 except Exception: return False def _get_props(): """Fetch server properties (models, adapters).""" try: r = requests.get(f"{ACE_SERVER}/props", timeout=10) if r.status_code == 200: return r.json() except Exception: pass return {} def _poll_job(job_id, timeout=600, progress_cb=None): """Poll a job until done/error/timeout. Returns (status, elapsed).""" t0 = time.time() while time.time() - t0 < timeout: try: r = requests.get(f"{ACE_SERVER}/job", params={"id": job_id}, timeout=10) data = r.json() status = data.get("status", "unknown") if progress_cb: progress_cb(status, data) if status in ("done", "error"): return status, time.time() - t0 except Exception: pass time.sleep(2) return "timeout", time.time() - t0 def _fetch_result(job_id, timeout=60): """Fetch result bytes/json for a completed job.""" r = requests.get( f"{ACE_SERVER}/job", params={"id": job_id, "result": 1}, timeout=timeout, ) return r def _run_pipeline(caption, lyrics, bpm, duration, seed, steps, output_format, adapter=None, lm_model=None, progress_cb=None): """Run full LM -> synth pipeline. Returns (audio_path, status_msg) or raises.""" t0 = time.time() # -- Build LM request -- req = {"caption": caption or "upbeat electronic dance music"} req["lyrics"] = lyrics if lyrics and lyrics.strip() else "[Instrumental]" if bpm and int(bpm) > 0: req["bpm"] = int(bpm) if duration and float(duration) > 0: req["duration"] = min(float(duration), 300) if seed is not None and int(seed) >= 0: req["seed"] = int(seed) if steps and int(steps) > 0: req["inference_steps"] = int(steps) if adapter: req["adapter"] = adapter if lm_model: req["model"] = lm_model fmt = output_format if output_format in ("wav", "mp3") else "mp3" synth_fmt = "wav16" if fmt == "wav" else "mp3" suffix = f".{fmt}" # -- LM phase -- if progress_cb: progress_cb("lm_submit", None) r = requests.post(f"{ACE_SERVER}/lm", json=req, timeout=30) if r.status_code != 200: raise RuntimeError(f"LM submit failed: {r.status_code} {r.text}") lm_job_id = r.json().get("id") if progress_cb: progress_cb("lm_poll", {"job_id": lm_job_id}) lm_status, lm_elapsed = _poll_job(lm_job_id, timeout=900) if lm_status != "done": raise RuntimeError(f"LM {lm_status} after {lm_elapsed:.0f}s") # Fetch LM result r = _fetch_result(lm_job_id) lm_results = r.json() if not isinstance(lm_results, list) or len(lm_results) == 0: raise RuntimeError(f"LM returned no results: {lm_results}") synth_request = lm_results[0] # -- Synth phase -- synth_request["output_format"] = synth_fmt if adapter: synth_request["adapter"] = adapter synth_request["synth_model"] = "acestep-v15-turbo-Q4_K_M.gguf" if progress_cb: progress_cb("synth_submit", None) r = requests.post(f"{ACE_SERVER}/synth", json=synth_request, timeout=30) if r.status_code != 200: raise RuntimeError(f"Synth submit failed: {r.status_code} {r.text}") synth_job_id = r.json().get("id") if progress_cb: progress_cb("synth_poll", {"job_id": synth_job_id}) synth_status, synth_elapsed = _poll_job(synth_job_id, timeout=600) if synth_status != "done": raise RuntimeError(f"Synth {synth_status} after {synth_elapsed:.0f}s") # Fetch audio if progress_cb: progress_cb("fetch", None) r = _fetch_result(synth_job_id, timeout=60) if r.status_code != 200: raise RuntimeError(f"Audio fetch failed: {r.status_code}") tmp = tempfile.NamedTemporaryFile(suffix=suffix, dir=OUTPUT_DIR, delete=False) tmp.write(r.content) tmp.close() elapsed = time.time() - t0 msg = f"Done in {elapsed:.0f}s | {duration}s audio, {steps} steps, {fmt}" return tmp.name, msg # --------------------------------------------------------------------------- # LM model scanning & on-demand download # --------------------------------------------------------------------------- DEFAULT_LM = "acestep-5Hz-lm-1.7B-Q8_0.gguf" AVAILABLE_LM_MODELS = [ "acestep-5Hz-lm-1.7B-Q8_0.gguf", "acestep-5Hz-lm-0.6B-Q8_0.gguf", "acestep-5Hz-lm-4B-Q5_K_M.gguf", ] def _scan_lm_models(): """Return LM model choices. Installed shown as-is, others need download.""" installed = set() if os.path.isdir(MODELS_DIR): for f in os.listdir(MODELS_DIR): if "-lm-" in f and f.endswith(".gguf"): installed.add(f) choices = [] for m in AVAILABLE_LM_MODELS: if m in installed: choices.append(m) else: choices.append(f"{m} [not installed]") return choices def _download_lm_model(filename): """Download a GGUF LM model from HF if not already present.""" dest = os.path.join(MODELS_DIR, filename) if os.path.isfile(dest): return dest try: from huggingface_hub import hf_hub_download path = hf_hub_download( repo_id=GGUF_HF_REPO, filename=filename, local_dir=MODELS_DIR, ) return path except Exception as exc: logger.error("Failed to download %s: %s", filename, exc) return None # --------------------------------------------------------------------------- # LoRA listing for UI dropdowns # --------------------------------------------------------------------------- def _list_lora_choices(): """Return list of LoRA choices for dropdown, including 'None'.""" choices = ["None (no LoRA)"] if os.path.isdir(ADAPTER_DIR): for d in os.listdir(ADAPTER_DIR): if os.path.isdir(os.path.join(ADAPTER_DIR, d)): choices.append(d) return choices # --------------------------------------------------------------------------- # ace-server stop/start helpers # --------------------------------------------------------------------------- _ace_proc = None def _stop_ace_server(): """Stop ace-server process.""" global _ace_proc logger.info("[ace-server] Stopping...") if _ace_proc and _ace_proc.poll() is None: _ace_proc.terminate() try: _ace_proc.wait(timeout=10) except subprocess.TimeoutExpired: _ace_proc.kill() _ace_proc = None logger.info("[ace-server] Stopped (tracked PID)") else: try: subprocess.run(["pkill", "ace-server"], stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL, timeout=10) logger.info("[ace-server] Stopped (pkill)") except Exception: pass time.sleep(1) def _start_ace_server(): """Start ace-server in background and wait for health.""" global _ace_proc logger.info("[ace-server] Starting with --adapters %s", ADAPTER_DIR) try: _ace_proc = subprocess.Popen( [ACE_SERVER_BIN, "--host", "127.0.0.1", "--port", "8085", "--models", MODELS_DIR, "--adapters", ADAPTER_DIR, "--max-batch", "1"], ) except Exception as exc: logger.error("[ace-server] Failed to start: %s", exc) return False for _ in range(30): if _server_ok(): logger.info("[ace-server] Healthy") return True time.sleep(2) logger.error("[ace-server] Health check timeout") return False # --------------------------------------------------------------------------- # CLI mode # --------------------------------------------------------------------------- def cli_main(): parser = argparse.ArgumentParser( description="ACE-Step 1.5 XL (CPU) - CLI inference via ace-server", ) parser.add_argument("caption", nargs="?", default="upbeat electronic dance music", help="Music description / caption") parser.add_argument("--lyrics", "-l", default="[Instrumental]", help="Lyrics text (use '[Instrumental]' for no vocals)") parser.add_argument("--bpm", type=int, default=120, help="Beats per minute") parser.add_argument("--duration", "-d", type=float, default=10, help="Duration in seconds (max 300)") parser.add_argument("--steps", "-s", type=int, default=8, help="Inference steps (1-32)") parser.add_argument("--seed", type=int, default=-1, help="Random seed (-1 for random)") parser.add_argument("--format", "-f", choices=["wav", "mp3"], default="wav", help="Output audio format") parser.add_argument("--adapter", "-a", default=None, help="LoRA adapter name") parser.add_argument("-o", "--output", default=None, help="Output file path (default: auto in outputs dir)") parser.add_argument("--server", default=None, help="ace-server URL (default: http://127.0.0.1:8085)") args = parser.parse_args() if args.server: global ACE_SERVER ACE_SERVER = args.server if not _server_ok(): print(f"ERROR: ace-server not reachable at {ACE_SERVER}", file=sys.stderr) sys.exit(1) seed = args.seed if args.seed >= 0 else None def cli_progress(phase, data): phases = { "lm_submit": "Submitting LM job...", "lm_poll": f"LM generating (job {data['job_id']})..." if data else "LM generating...", "synth_submit": "Submitting synth job...", "synth_poll": f"Synthesizing (job {data['job_id']})..." if data else "Synthesizing...", "fetch": "Fetching audio...", } msg = phases.get(phase, phase) print(f" [{phase}] {msg}") print(f"ACE-Step CLI | caption: {args.caption}") print(f" lyrics: {args.lyrics} | bpm: {args.bpm} | duration: {args.duration}s " f"| steps: {args.steps} | seed: {args.seed} | format: {args.format}") try: audio_path, status = _run_pipeline( caption=args.caption, lyrics=args.lyrics, bpm=args.bpm, duration=args.duration, seed=seed, steps=args.steps, output_format=args.format, adapter=args.adapter, progress_cb=cli_progress, ) except RuntimeError as e: print(f"ERROR: {e}", file=sys.stderr) sys.exit(1) # Move to requested output path if specified if args.output: out_dir = os.path.dirname(os.path.abspath(args.output)) os.makedirs(out_dir, exist_ok=True) shutil.move(audio_path, args.output) audio_path = args.output print(f" {status}") print(f" Output: {audio_path}") # --------------------------------------------------------------------------- # Gradio UI mode # --------------------------------------------------------------------------- def gradio_main(): import gradio as gr import gc # -- Persistent training log buffer (survives across yields) -- _train_log_lines = [] # -- Generate tab handler -- def generate_music(caption, lyrics, instrumental, bpm, duration, seed, steps, lora_select, lm_model_select, progress=gr.Progress(track_tqdm=True)): if not _server_ok(): return None, "ace-server not running. Check logs." if instrumental or not lyrics or lyrics.strip() == "": lyrics = "[Instrumental]" actual_seed = None if seed is None or int(seed) < 0 else int(seed) adapter = None if lora_select == "None (no LoRA)" else lora_select lm_model_file = lm_model_select.replace(" [not installed]", "") if lm_model_select else None if lm_model_file and "[not installed]" in (lm_model_select or ""): _download_lm_model(lm_model_file) lm_model = lm_model_file progress_map = { "lm_submit": (0.05, "Submitting LM job..."), "lm_poll": (0.10, "LM generating..."), "synth_submit": (0.40, "Submitting synth job..."), "synth_poll": (0.50, "Synthesizing audio..."), "fetch": (0.90, "Fetching audio..."), } def gr_progress(phase, data): pct, desc = progress_map.get(phase, (0.5, phase)) if data and "job_id" in data: desc += f" (job {data['job_id']})" progress(pct, desc=desc) try: audio_path, status = _run_pipeline( caption=caption, lyrics=lyrics, bpm=bpm, duration=duration, seed=actual_seed, steps=steps, output_format="mp3", adapter=adapter, lm_model=lm_model, progress_cb=gr_progress, ) return audio_path, status except RuntimeError as e: return None, str(e) except Exception as e: return None, f"Unexpected error: {e}" # -- Server info helper -- def get_server_status(): if not _server_ok(): return "ace-server: OFFLINE" props = _get_props() lines = ["ace-server: ONLINE"] if props: lines.append(json.dumps(props, indent=2)) return "\n".join(lines) # -- Training generator (direct integration, no subprocess) -- def train_lora_ui(audio_files, lora_name, epochs, lr, rank): """Generator that yields (train_log, train_btn_update, cancel_btn_update).""" import gc as _gc _train_log_lines.clear() train_start = time.time() def _log(msg): _train_log_lines.append(msg) def _log_text(): return "\n".join(_train_log_lines) # -- Validation -- if not audio_files: _log("[FAIL] No audio files uploaded.") yield _log_text(), gr.update(visible=True), gr.update(visible=False) return if len(audio_files) > MAX_AUDIO_FILES: _log(f"[FAIL] Too many files ({len(audio_files)}). Max: {MAX_AUDIO_FILES}") yield _log_text(), gr.update(visible=True), gr.update(visible=False) return lora_name = (lora_name or "").strip() or "my-lora" # Sanitize: alphanumeric, dash, underscore only lora_name = "".join(c if c.isalnum() or c in "-_" else "-" for c in lora_name) epochs = max(1, min(int(epochs), 10)) lr = float(lr) rank = max(1, min(int(rank), 64)) work_dir = os.path.join(OUTPUT_DIR, "train_workspace", lora_name) os.makedirs(work_dir, exist_ok=True) audio_dir = os.path.join(work_dir, "audio_input") os.makedirs(audio_dir, exist_ok=True) adapter_out = os.path.join(ADAPTER_DIR, lora_name) os.makedirs(adapter_out, exist_ok=True) # Copy uploaded audio files _log(f"[INFO] Preparing {len(audio_files)} audio files...") yield _log_text(), gr.update(visible=False), gr.update(visible=True) for f in audio_files: src = f.name if hasattr(f, "name") else str(f) shutil.copy2(src, os.path.join(audio_dir, os.path.basename(src))) _log(f"[INFO] LoRA: '{lora_name}' | Files: {len(audio_files)} | " f"Epochs: {epochs} | LR: {lr} | Rank: {rank}") yield _log_text(), gr.update(visible=False), gr.update(visible=True) # Stop ace-server before training (frees memory) _log("[INFO] Stopping ace-server for training...") yield _log_text(), gr.update(visible=False), gr.update(visible=True) _stop_ace_server() _gc.collect() try: # -- Phase 1: Preprocessing -- _log("[Step 1/2] Preprocessing audio...") yield _log_text(), gr.update(visible=False), gr.update(visible=True) preprocessed_dir = os.path.join(work_dir, "preprocessed_tensors") def preprocess_progress(current, total, desc): _log(f" {desc} ({current}/{total})") result = preprocess_audio( audio_dir=audio_dir, output_dir=preprocessed_dir, checkpoint_dir=ACE_CHECKPOINT_DIR, device="cpu", variant="turbo", max_duration=float(MAX_AUDIO_DURATION), progress_callback=preprocess_progress, cancel_check=lambda: False, ) yield _log_text(), gr.update(visible=False), gr.update(visible=True) processed = result.get("processed", 0) failed = result.get("failed", 0) total = result.get("total", 0) _log(f"[OK] Preprocessed: {processed}/{total} (failed: {failed})") yield _log_text(), gr.update(visible=False), gr.update(visible=True) if processed == 0: _log("[FAIL] No files preprocessed successfully. Cannot train.") yield _log_text(), gr.update(visible=True), gr.update(visible=False) return _gc.collect() # -- Phase 2: Training -- _log("[Step 2/2] Training LoRA...") yield _log_text(), gr.update(visible=False), gr.update(visible=True) for msg in train_lora_generator( dataset_dir=preprocessed_dir, output_dir=adapter_out, checkpoint_dir=ACE_CHECKPOINT_DIR, epochs=epochs, lr=lr, rank=rank, alpha=rank * 2, dropout=0.0, batch_size=1, gradient_accumulation_steps=4, warmup_steps=100, weight_decay=0.01, max_grad_norm=1.0, save_every_n_epochs=max(1, epochs // 2), seed=42, variant="turbo", device="cpu", log_every=5, ): # Timeout check elapsed = time.time() - train_start if elapsed > MAX_TRAINING_TIME: _log(f"[WARN] Training timed out after {int(elapsed)}s") cancel_training() break _log(msg) yield _log_text(), gr.update(visible=False), gr.update(visible=True) if msg.strip() == "[DONE]": break _log(f"[INFO] Total time: {time.time() - train_start:.0f}s") yield _log_text(), gr.update(visible=False), gr.update(visible=True) except Exception as exc: _log(f"[FAIL] Training error: {exc}") import traceback _log(traceback.format_exc()) yield _log_text(), gr.update(visible=True), gr.update(visible=False) finally: # Always restart ace-server _log("[INFO] Restarting ace-server...") yield _log_text(), gr.update(visible=False), gr.update(visible=True) _gc.collect() ok = _start_ace_server() if ok: _log("[OK] ace-server restarted successfully") else: _log("[WARN] ace-server may not have restarted -- check logs") yield _log_text(), gr.update(visible=True), gr.update(visible=False) # -- Cancel handler -- def _on_cancel(): cancel_training() logger.info("Cancel requested by user") return "Cancelling after current epoch... please wait" # -- Check log handler -- def _check_log(): if _train_log_lines: return "\n".join(_train_log_lines) return "No training log available." # -- Build LM model choices -- def _lm_model_choices(): return _scan_lm_models() # -- Build UI -- CSS = """ .compact-row { gap: 8px !important; } .status-box textarea { font-family: monospace; font-size: 13px; } """ with gr.Blocks(title="ACE-Step 1.5 XL (CPU)", css=CSS) as demo: with gr.Tabs(): # ============================================================ # Tab 1: Generate Music # ============================================================ with gr.Tab("Generate Music"): gr.Markdown( "**[ACE-Step 1.5 XL (CPU)](https://github.com/ace-step/ACE-Step-1.5)** " "GGUF Q4_K_M via " "[acestep.cpp](https://github.com/ServeurpersoCom/acestep.cpp)" ) with gr.Row(elem_classes="compact-row"): with gr.Column(scale=2): caption = gr.Textbox( label="Music Description", lines=2, value="upbeat electronic dance music, energetic synth leads", ) lyrics = gr.Textbox( label="Lyrics", lines=3, value="[Instrumental]", placeholder="Enter lyrics or [Instrumental] for no vocals", ) with gr.Column(scale=1): audio_out = gr.Audio(label="Output", type="filepath") status = gr.Textbox( label="Status", interactive=False, lines=2, elem_classes="status-box", ) with gr.Row(elem_classes="compact-row"): instrumental = gr.Checkbox(label="Instrumental", value=True, scale=1) bpm = gr.Number(label="BPM", value=120, minimum=0, maximum=300, scale=1) duration = gr.Slider( label="Duration (s)", minimum=10, maximum=120, value=10, step=5, scale=1, ) steps = gr.Slider( label="Steps", minimum=1, maximum=32, value=8, step=1, scale=1, ) seed = gr.Number(label="Seed (-1=random)", value=-1, scale=1) with gr.Row(elem_classes="compact-row"): lora_select = gr.Dropdown( label="LoRA", choices=_list_lora_choices(), value="None (no LoRA)", scale=1, allow_custom_value=True, ) lm_model_select = gr.Dropdown( label="LM Model", choices=_lm_model_choices(), value=DEFAULT_LM, scale=1, ) with gr.Row(elem_classes="compact-row"): gen_btn = gr.Button("Generate Music", variant="primary", scale=2) status_btn = gr.Button("Server Status", scale=1) gen_btn.click( fn=generate_music, inputs=[caption, lyrics, instrumental, bpm, duration, seed, steps, lora_select, lm_model_select], outputs=[audio_out, status], api_name="generate", ) status_btn.click( fn=get_server_status, inputs=[], outputs=[status], api_name="server_status", ) # ============================================================ # Tab 2: Train LoRA # ============================================================ with gr.Tab("Train LoRA"): gr.Markdown( "### LoRA Training\n" "Fine-tune ACE-Step on your audio. " "CPU training is slow -- ace-server stops during training." ) with gr.Row(elem_classes="compact-row"): with gr.Column(scale=2): train_audio = gr.File( label="Training Audio Files", file_count="multiple", file_types=["audio"], ) with gr.Column(scale=1): lora_name = gr.Textbox(label="LoRA Name", value="my-lora") train_epochs = gr.Slider( label="Epochs", minimum=1, maximum=1000, value=3, step=1, ) train_lr = gr.Number(label="Learning Rate", value=3e-4) train_rank = gr.Slider( label="Rank (r)", minimum=1, maximum=128, value=32, step=1, ) with gr.Row(elem_classes="compact-row"): train_btn = gr.Button("Train", variant="primary", scale=2) cancel_btn = gr.Button("Cancel Training", variant="stop", visible=False, scale=1) log_btn = gr.Button("Check Log", scale=1) train_log = gr.Textbox( label="Training Log", interactive=False, lines=12, elem_classes="status-box", ) # Training generator -- yields (log, train_btn, cancel_btn) train_event = train_btn.click( train_lora_ui, inputs=[train_audio, lora_name, train_epochs, train_lr, train_rank], outputs=[train_log, train_btn, cancel_btn], api_name="train_lora", concurrency_limit=1, ) # After training completes, restore buttons and refresh LoRA dropdown def _post_training(): return ( gr.update(visible=True), gr.update(visible=False), gr.update(choices=_list_lora_choices()), ) train_event.then( _post_training, outputs=[train_btn, cancel_btn, lora_select], ) # Cancel: set the flag, update status cancel_btn.click( _on_cancel, outputs=[train_log], ) # Check log: show last training output log_btn.click( _check_log, outputs=[train_log], api_name="check_log", ) demo.launch( server_name="0.0.0.0", server_port=7860, mcp_server=True, ) # --------------------------------------------------------------------------- # Entry point # --------------------------------------------------------------------------- if __name__ == "__main__": # If any CLI arguments besides the script name, run CLI mode # (Gradio sets no extra args; start.sh calls `python3 /app/app.py`) if len(sys.argv) > 1: cli_main() else: gradio_main()