ACE-Step-CPU / app.py
Nekochu's picture
add _is_space flag, block inference during training, understand clone fix
3c15b8b
raw
history blame
37.2 kB
"""ACE-Step 1.5 XL (CPU) - Gradio frontend + CLI for ace-server GGUF inference"""
import os
import sys
import time
import json
import argparse
import base64
import tempfile
import subprocess
import shutil
import string
import random
import requests
import logging
from train_engine import (
preprocess_audio,
train_lora_generator,
cancel_training,
get_trained_loras as _get_trained_loras_engine,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Configurable limits (edit here, not buried in code)
# ---------------------------------------------------------------------------
MAX_TOTAL_AUDIO = 1800 # seconds total across all uploaded files (30 min)
MAX_TRAINING_TIME = 28800 # 8 hours hard training timeout (seconds)
MAX_AUDIO_FILES = 50 # max number of training audio files per run
# ---------------------------------------------------------------------------
# Paths & constants
# ---------------------------------------------------------------------------
# Base URL of the ace-server HTTP API; the request helpers below all target it.
# cli_main() may overwrite it from the --server argument.
ACE_SERVER = os.environ.get("ACE_SERVER", "http://127.0.0.1:8085")
# Generated audio and per-run training workspaces live here (created at import).
OUTPUT_DIR = os.environ.get("ACE_OUTPUT_DIR", "/app/outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Checkpoint directory handed to preprocess_audio() / train_lora_generator().
ACE_CHECKPOINT_DIR = os.environ.get("ACE_CHECKPOINT_DIR", "/app/checkpoints")
# NOTE(review): ACE_SOURCE_DIR and ACE_HF_MODEL are not referenced in the
# visible part of this file -- confirm they are still needed.
ACE_SOURCE_DIR = "/app/ace-step-source"
ACE_HF_MODEL = "ACE-Step/Ace-Step1.5"
# One sub-directory per trained LoRA; also passed to ace-server via --adapters.
ADAPTER_DIR = os.environ.get("ACE_ADAPTER_DIR", "/app/adapters")
# GGUF model directory; passed to ace-server via --models.
MODELS_DIR = os.environ.get("ACE_MODELS_DIR", "/app/models")
# Executable launched by _start_ace_server().
ACE_SERVER_BIN = "/app/ace-server"
# Detect if running on HF Space (ace-server available) vs locally (PyTorch only)
_is_space = os.path.isfile(ACE_SERVER_BIN) or os.environ.get("SPACE_ID") is not None
# True while the training handler owns the machine; the generate handler
# refuses inference requests while this is set.
_training_in_progress = False
# HF repo for on-demand GGUF downloads
GGUF_HF_REPO = "Serveurperso/ACE-Step-1.5-GGUF"
# ---------------------------------------------------------------------------
# ace-server helpers
# ---------------------------------------------------------------------------
def _server_ok():
try:
return requests.get(f"{ACE_SERVER}/health", timeout=5).status_code == 200
except Exception:
return False
def _get_props():
"""Fetch server properties (models, adapters)."""
try:
r = requests.get(f"{ACE_SERVER}/props", timeout=10)
if r.status_code == 200:
return r.json()
except Exception:
pass
return {}
def _poll_job(job_id, timeout=600, progress_cb=None):
"""Poll a job until done/error/timeout. Returns (status, elapsed)."""
t0 = time.time()
while time.time() - t0 < timeout:
try:
r = requests.get(f"{ACE_SERVER}/job", params={"id": job_id}, timeout=10)
data = r.json()
status = data.get("status", "unknown")
if progress_cb:
progress_cb(status, data)
if status in ("done", "error"):
return status, time.time() - t0
except Exception:
pass
time.sleep(2)
return "timeout", time.time() - t0
def _fetch_result(job_id, timeout=60):
    """Request the result of a finished job and return the raw response.

    The caller inspects ``status_code`` and decodes the body (JSON or audio
    bytes) itself.
    """
    query = {"id": job_id, "result": 1}
    return requests.get(f"{ACE_SERVER}/job", params=query, timeout=timeout)
def _caption_via_understand(audio_path, timeout=120):
    """Ask ace-server ``/understand`` for a rich caption of one audio file.

    The file is sent base64-encoded, the resulting job is polled for up to
    ``timeout`` seconds, and its result is fetched.

    Returns the result dict (expected to carry caption, bpm, key, signature,
    lyrics) when it contains a non-empty ``caption``; returns ``None`` on any
    failure so the caller can fall back to librosa.
    """
    fname = os.path.basename(audio_path)

    # Step 1: read and encode the audio.
    try:
        with open(audio_path, "rb") as handle:
            encoded = base64.b64encode(handle.read()).decode("ascii")
    except Exception as exc:
        logger.warning("[Caption] %s: failed to read file: %s", fname, exc)
        return None

    # Step 2: submit the job.
    try:
        submit = requests.post(
            f"{ACE_SERVER}/understand",
            json={"audio": encoded},
            timeout=30,
        )
        if submit.status_code != 200:
            logger.warning("[Caption] %s: /understand returned %d", fname, submit.status_code)
            return None
        job_id = submit.json().get("id")
        if not job_id:
            logger.warning("[Caption] %s: /understand returned no job id", fname)
            return None
    except Exception as exc:
        logger.warning("[Caption] %s: /understand submit failed: %s", fname, exc)
        return None

    # Step 3: wait for completion.
    outcome, _elapsed = _poll_job(job_id, timeout=timeout)
    if outcome != "done":
        logger.warning("[Caption] %s: /understand job %s -> %s", fname, job_id, outcome)
        return None

    # Step 4: fetch and validate the result.
    try:
        result = _fetch_result(job_id, timeout=30)
        if result.status_code != 200:
            logger.warning("[Caption] %s: /understand result fetch failed: %d", fname, result.status_code)
            return None
        payload = result.json()
        has_caption = isinstance(payload, dict) and payload.get("caption")
        if not has_caption:
            logger.warning("[Caption] %s: /understand returned no caption field", fname)
            return None
        return payload
    except Exception as exc:
        logger.warning("[Caption] %s: /understand result parse failed: %s", fname, exc)
        return None
def _run_pipeline(caption, lyrics, bpm, duration, seed, steps, output_format,
                  adapter=None, lm_model=None, progress_cb=None):
    """Run full LM -> synth pipeline. Returns (audio_path, status_msg) or raises.

    Args:
        caption: Music description; a default caption is used when empty.
        lyrics: Lyrics text; blank means ``"[Instrumental]"``.
        bpm, duration, seed, steps: Optional generation settings, forwarded
            only when positive (seed: when non-negative). Duration is capped
            at 300 seconds.
        output_format: ``"wav"`` or ``"mp3"``; anything else means mp3.
        adapter: Optional LoRA adapter name sent with both requests.
        lm_model: Optional LM model file name sent with the LM request.
        progress_cb: Optional ``callable(phase, data)`` notified at each step.

    Raises:
        RuntimeError: a submit was rejected, a job did not finish with
            ``"done"``, or a response was malformed (bad JSON, no job id,
            no results).
    """
    t0 = time.time()

    def _notify(phase, data=None):
        if progress_cb:
            progress_cb(phase, data)

    def _decode_json(response, what):
        # Malformed bodies surface as RuntimeError, the error type callers
        # of this function already handle.
        try:
            return response.json()
        except ValueError as exc:
            raise RuntimeError(f"{what} returned invalid JSON: {exc}") from exc

    def _job_id(response, what):
        body = _decode_json(response, what)
        job_id = body.get("id") if isinstance(body, dict) else None
        if not job_id:
            # Without this check the poller would query id=None until timeout.
            raise RuntimeError(f"{what} returned no job id: {body}")
        return job_id

    # -- Build LM request --
    req = {"caption": caption or "upbeat electronic dance music"}
    req["lyrics"] = lyrics if lyrics and lyrics.strip() else "[Instrumental]"
    if bpm and int(bpm) > 0:
        req["bpm"] = int(bpm)
    if duration and float(duration) > 0:
        req["duration"] = min(float(duration), 300)
    if seed is not None and int(seed) >= 0:
        req["seed"] = int(seed)
    if steps and int(steps) > 0:
        req["inference_steps"] = int(steps)
    if adapter:
        req["adapter"] = adapter
    if lm_model:
        req["model"] = lm_model
    fmt = output_format if output_format in ("wav", "mp3") else "mp3"
    synth_fmt = "wav16" if fmt == "wav" else "mp3"
    suffix = f".{fmt}"

    # -- LM phase --
    _notify("lm_submit")
    r = requests.post(f"{ACE_SERVER}/lm", json=req, timeout=30)
    if r.status_code != 200:
        raise RuntimeError(f"LM submit failed: {r.status_code} {r.text}")
    lm_job_id = _job_id(r, "LM submit")
    _notify("lm_poll", {"job_id": lm_job_id})
    lm_status, lm_elapsed = _poll_job(lm_job_id, timeout=900)
    if lm_status != "done":
        raise RuntimeError(f"LM {lm_status} after {lm_elapsed:.0f}s")

    # Fetch LM result (same status check as the audio fetch below)
    r = _fetch_result(lm_job_id)
    if r.status_code != 200:
        raise RuntimeError(f"LM result fetch failed: {r.status_code}")
    lm_results = _decode_json(r, "LM result")
    if not isinstance(lm_results, list) or len(lm_results) == 0:
        raise RuntimeError(f"LM returned no results: {lm_results}")
    synth_request = lm_results[0]
    if not isinstance(synth_request, dict):
        raise RuntimeError(f"LM returned no results: {lm_results}")

    # -- Synth phase --
    synth_request["output_format"] = synth_fmt
    if adapter:
        synth_request["adapter"] = adapter
    synth_request["synth_model"] = "acestep-v15-turbo-Q4_K_M.gguf"
    _notify("synth_submit")
    r = requests.post(f"{ACE_SERVER}/synth", json=synth_request, timeout=30)
    if r.status_code != 200:
        raise RuntimeError(f"Synth submit failed: {r.status_code} {r.text}")
    synth_job_id = _job_id(r, "Synth submit")
    _notify("synth_poll", {"job_id": synth_job_id})
    synth_status, synth_elapsed = _poll_job(synth_job_id, timeout=600)
    if synth_status != "done":
        raise RuntimeError(f"Synth {synth_status} after {synth_elapsed:.0f}s")

    # Fetch audio
    _notify("fetch")
    r = _fetch_result(synth_job_id, timeout=60)
    if r.status_code != 200:
        raise RuntimeError(f"Audio fetch failed: {r.status_code}")
    # delete=False: the file must outlive this function; the context manager
    # only guarantees the handle is closed even if the write fails.
    with tempfile.NamedTemporaryFile(suffix=suffix, dir=OUTPUT_DIR, delete=False) as tmp:
        tmp.write(r.content)
        audio_path = tmp.name
    elapsed = time.time() - t0
    msg = f"Done in {elapsed:.0f}s | {duration}s audio, {steps} steps, {fmt}"
    return audio_path, msg
# ---------------------------------------------------------------------------
# LM model scanning & on-demand download
# ---------------------------------------------------------------------------
# LM model file name preselected in the UI's "LM Model" dropdown.
DEFAULT_LM = "acestep-5Hz-lm-1.7B-Q8_0.gguf"
# LM model files offered in the dropdown; _scan_lm_models() tags the ones
# missing from MODELS_DIR with " [not installed]".
AVAILABLE_LM_MODELS = [
    "acestep-5Hz-lm-1.7B-Q8_0.gguf",
    "acestep-5Hz-lm-0.6B-Q8_0.gguf",
    "acestep-5Hz-lm-4B-Q5_K_M.gguf",
]
def _scan_lm_models():
    """Build the LM dropdown choices from ``AVAILABLE_LM_MODELS``.

    A model whose ``*-lm-*.gguf`` file exists in ``MODELS_DIR`` is listed
    under its plain name; every other model gets a `` [not installed]``
    suffix. Order follows ``AVAILABLE_LM_MODELS``.
    """
    if os.path.isdir(MODELS_DIR):
        on_disk = {
            entry
            for entry in os.listdir(MODELS_DIR)
            if "-lm-" in entry and entry.endswith(".gguf")
        }
    else:
        on_disk = set()
    return [
        model if model in on_disk else f"{model} [not installed]"
        for model in AVAILABLE_LM_MODELS
    ]
def _download_lm_model(filename):
    """Ensure the GGUF LM model ``filename`` is present in ``MODELS_DIR``.

    Returns the local path (existing file, or the path reported by
    ``hf_hub_download`` after fetching from ``GGUF_HF_REPO``), or ``None``
    when the download fails; the failure is logged.
    """
    local_path = os.path.join(MODELS_DIR, filename)
    if not os.path.isfile(local_path):
        try:
            # Imported lazily so the dependency is only needed on download.
            from huggingface_hub import hf_hub_download
            return hf_hub_download(
                repo_id=GGUF_HF_REPO,
                filename=filename,
                local_dir=MODELS_DIR,
            )
        except Exception as exc:
            logger.error("Failed to download %s: %s", filename, exc)
            return None
    return local_path
# ---------------------------------------------------------------------------
# LoRA listing for UI dropdowns
# ---------------------------------------------------------------------------
def _list_lora_choices():
    """List LoRA dropdown entries: the 'None' option, then adapter folders.

    Every sub-directory of ``ADAPTER_DIR`` counts as one adapter; when the
    directory does not exist only the 'None' option is returned.
    """
    no_lora = "None (no LoRA)"
    if not os.path.isdir(ADAPTER_DIR):
        return [no_lora]
    adapters = [
        entry
        for entry in os.listdir(ADAPTER_DIR)
        if os.path.isdir(os.path.join(ADAPTER_DIR, entry))
    ]
    return [no_lora, *adapters]
# ---------------------------------------------------------------------------
# ace-server stop/start helpers
# ---------------------------------------------------------------------------
# Handle of the ace-server process started by this module (None if untracked).
_ace_proc = None


def _stop_ace_server():
    """Stop ace-server process.

    A tracked, still-running child is terminated (then killed if it does not
    exit within 10 seconds). Otherwise ``pkill ace-server`` is used as a best
    effort, e.g. for a server started outside this module. Always pauses one
    second before returning.
    """
    global _ace_proc
    logger.info("[ace-server] Stopping...")
    if _ace_proc and _ace_proc.poll() is None:
        _ace_proc.terminate()
        try:
            _ace_proc.wait(timeout=10)
        except subprocess.TimeoutExpired:
            _ace_proc.kill()
            # Reap the killed child so it does not linger as a zombie.
            try:
                _ace_proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                logger.warning("[ace-server] Process did not exit after kill")
        _ace_proc = None
        logger.info("[ace-server] Stopped (tracked PID)")
    else:
        # Any tracked handle is already dead here; drop it.
        _ace_proc = None
        try:
            subprocess.run(["pkill", "ace-server"],
                           stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL,
                           timeout=10)
            logger.info("[ace-server] Stopped (pkill)")
        except Exception as exc:
            # Still best effort, but leave a trace instead of failing silently.
            logger.warning("[ace-server] pkill failed: %s", exc)
    time.sleep(1)
def _start_ace_server():
    """Start ace-server in background and wait for health.

    Launches ``ACE_SERVER_BIN`` on 127.0.0.1:8085 with the models and
    adapters directories, then probes the health endpoint up to 30 times,
    two seconds apart.

    Returns:
        True once the server reports healthy; False if the process could not
        be started, exited during startup, or never became healthy.
    """
    global _ace_proc
    logger.info("[ace-server] Starting with --adapters %s", ADAPTER_DIR)
    try:
        _ace_proc = subprocess.Popen(
            [ACE_SERVER_BIN, "--host", "127.0.0.1", "--port", "8085",
             "--models", MODELS_DIR, "--adapters", ADAPTER_DIR, "--max-batch", "1"],
        )
    except Exception as exc:
        logger.error("[ace-server] Failed to start: %s", exc)
        return False
    for _ in range(30):
        # Health first: a server that answers counts as success even if the
        # process spawned here is not the one answering.
        if _server_ok():
            logger.info("[ace-server] Healthy")
            return True
        exit_code = _ace_proc.poll()
        if exit_code is not None:
            # The child is gone; waiting out the remaining probes is pointless.
            logger.error("[ace-server] Exited during startup (code %s)", exit_code)
            return False
        time.sleep(2)
    logger.error("[ace-server] Health check timeout")
    return False
# ---------------------------------------------------------------------------
# CLI mode
# ---------------------------------------------------------------------------
def cli_main():
    """Command-line entry point: generate one audio file via ace-server.

    Parses the arguments, checks that the server is reachable, runs
    ``_run_pipeline`` while printing progress, and optionally moves the
    result to ``--output``. Exits with status 1 when the server is
    unreachable or the pipeline fails.
    """
    global ACE_SERVER

    parser = argparse.ArgumentParser(
        description="ACE-Step 1.5 XL (CPU) - CLI inference via ace-server",
    )
    parser.add_argument("caption", nargs="?", default="upbeat electronic dance music",
                        help="Music description / caption")
    parser.add_argument("--lyrics", "-l", default="[Instrumental]",
                        help="Lyrics text (use '[Instrumental]' for no vocals)")
    parser.add_argument("--bpm", type=int, default=120, help="Beats per minute")
    parser.add_argument("--duration", "-d", type=float, default=10,
                        help="Duration in seconds (max 300)")
    parser.add_argument("--steps", "-s", type=int, default=8,
                        help="Inference steps (1-32)")
    parser.add_argument("--seed", type=int, default=-1,
                        help="Random seed (-1 for random)")
    parser.add_argument("--format", "-f", choices=["wav", "mp3"], default="wav",
                        help="Output audio format")
    parser.add_argument("--adapter", "-a", default=None,
                        help="LoRA adapter name")
    parser.add_argument("-o", "--output", default=None,
                        help="Output file path (default: auto in outputs dir)")
    parser.add_argument("--server", default=None,
                        help="ace-server URL (default: http://127.0.0.1:8085)")
    args = parser.parse_args()

    if args.server:
        ACE_SERVER = args.server
    if not _server_ok():
        print(f"ERROR: ace-server not reachable at {ACE_SERVER}", file=sys.stderr)
        sys.exit(1)

    # A negative seed means "random": let the server pick.
    seed = args.seed if args.seed >= 0 else None

    def cli_progress(phase, data):
        phases = {
            "lm_submit": "Submitting LM job...",
            "lm_poll": f"LM generating (job {data['job_id']})..." if data else "LM generating...",
            "synth_submit": "Submitting synth job...",
            "synth_poll": f"Synthesizing (job {data['job_id']})..." if data else "Synthesizing...",
            "fetch": "Fetching audio...",
        }
        msg = phases.get(phase, phase)
        print(f" [{phase}] {msg}")

    print(f"ACE-Step CLI | caption: {args.caption}")
    print(f" lyrics: {args.lyrics} | bpm: {args.bpm} | duration: {args.duration}s "
          f"| steps: {args.steps} | seed: {args.seed} | format: {args.format}")
    try:
        audio_path, status = _run_pipeline(
            caption=args.caption,
            lyrics=args.lyrics,
            bpm=args.bpm,
            duration=args.duration,
            seed=seed,
            steps=args.steps,
            output_format=args.format,
            adapter=args.adapter,
            progress_cb=cli_progress,
        )
    except (RuntimeError, requests.RequestException, ValueError) as e:
        # Network failures and malformed responses take the same clean
        # error exit as pipeline errors instead of dumping a traceback.
        print(f"ERROR: {e}", file=sys.stderr)
        sys.exit(1)

    # Move to requested output path if specified
    if args.output:
        out_dir = os.path.dirname(os.path.abspath(args.output))
        os.makedirs(out_dir, exist_ok=True)
        shutil.move(audio_path, args.output)
        audio_path = args.output
    print(f" {status}")
    print(f" Output: {audio_path}")
# ---------------------------------------------------------------------------
# Gradio UI mode
# ---------------------------------------------------------------------------
def gradio_main():
    """Build the Gradio UI (Generate Music / Train LoRA tabs) and launch it.

    Generation goes through ``_run_pipeline``. Training runs in-process via
    ``preprocess_audio`` and ``train_lora_generator`` while ace-server is
    stopped; the server is restarted afterwards. The app is served on
    0.0.0.0:7860.
    """
    import gradio as gr
    import gc

    # -- Persistent training log buffer (survives across yields) --
    _train_log_lines = []

    # -- Generate tab handler --
    def generate_music(caption, lyrics, instrumental, bpm, duration, seed,
                       steps, lora_select, lm_model_select,
                       progress=gr.Progress(track_tqdm=True)):
        """Run one generation and return (audio_path or None, status text)."""
        if _training_in_progress:
            return None, "Training in progress. Inference unavailable until training completes. Press Cancel to stop training."
        if not _server_ok():
            return None, "ace-server not running. Check logs."
        if instrumental or not lyrics or lyrics.strip() == "":
            lyrics = "[Instrumental]"
        actual_seed = None if seed is None or int(seed) < 0 else int(seed)
        adapter = None if lora_select == "None (no LoRA)" else lora_select
        lm_model_file = lm_model_select.replace(" [not installed]", "") if lm_model_select else None
        if lm_model_file and "[not installed]" in (lm_model_select or ""):
            # Report a failed download here rather than letting the server
            # fail later on a model file that is not there.
            if _download_lm_model(lm_model_file) is None:
                return None, f"Failed to download LM model: {lm_model_file}"
        lm_model = lm_model_file
        progress_map = {
            "lm_submit": (0.05, "Submitting LM job..."),
            "lm_poll": (0.10, "LM generating..."),
            "synth_submit": (0.40, "Submitting synth job..."),
            "synth_poll": (0.50, "Synthesizing audio..."),
            "fetch": (0.90, "Fetching audio..."),
        }

        def gr_progress(phase, data):
            pct, desc = progress_map.get(phase, (0.5, phase))
            if data and "job_id" in data:
                desc += f" (job {data['job_id']})"
            progress(pct, desc=desc)

        try:
            audio_path, status = _run_pipeline(
                caption=caption,
                lyrics=lyrics,
                bpm=bpm,
                duration=duration,
                seed=actual_seed,
                steps=steps,
                output_format="mp3",
                adapter=adapter,
                lm_model=lm_model,
                progress_cb=gr_progress,
            )
            return audio_path, status
        except RuntimeError as e:
            return None, str(e)
        except Exception as e:
            return None, f"Unexpected error: {e}"

    # -- Server info helper --
    def get_server_status():
        """Return a text summary: ONLINE/OFFLINE plus the server props."""
        if not _server_ok():
            return "ace-server: OFFLINE"
        props = _get_props()
        lines = ["ace-server: ONLINE"]
        if props:
            lines.append(json.dumps(props, indent=2))
        return "\n".join(lines)

    # -- Training generator (direct integration, no subprocess) --
    def train_lora_ui(audio_files, lora_name, epochs, lr, rank):
        """Generator that yields (train_log, train_btn, cancel_btn, output_file) updates."""
        global _training_in_progress
        _train_log_lines.clear()
        train_start = time.time()

        def _log(msg):
            _train_log_lines.append(msg)

        def _log_text():
            return "\n".join(_train_log_lines)

        def _busy():
            # UI state while working: Train hidden, Cancel shown.
            return _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()

        def _idle():
            # UI state when stopped: Train shown, Cancel hidden.
            return _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()

        # -- Validation --
        if not audio_files:
            _log("[FAIL] No audio files uploaded.")
            yield _idle()
            return
        if len(audio_files) > MAX_AUDIO_FILES:
            _log(f"[FAIL] Too many files ({len(audio_files)}). Max: {MAX_AUDIO_FILES}")
            yield _idle()
            return
        lora_name = (lora_name or "").strip() or "my-lora"
        # Sanitize: alphanumeric, dash, underscore only
        lora_name = "".join(c if c.isalnum() or c in "-_" else "-" for c in lora_name)
        # Append random suffix to prevent naming collisions between users
        suffix = "".join(random.choices(string.ascii_lowercase + string.digits, k=4))
        lora_name = f"{lora_name}-{suffix}"
        requested_epochs = int(epochs)
        epochs = max(1, min(requested_epochs, 10))
        lr = float(lr)
        requested_rank = int(rank)
        rank = max(1, min(requested_rank, 64))
        # The sliders allow larger values than the limits enforced here, so
        # tell the user when a setting was adjusted.
        if epochs != requested_epochs:
            _log(f"[WARN] Epochs clamped to {epochs} (requested {requested_epochs})")
        if rank != requested_rank:
            _log(f"[WARN] Rank clamped to {rank} (requested {requested_rank})")
        work_dir = os.path.join(OUTPUT_DIR, "train_workspace", lora_name)
        os.makedirs(work_dir, exist_ok=True)
        audio_dir = os.path.join(work_dir, "audio_input")
        os.makedirs(audio_dir, exist_ok=True)
        adapter_out = os.path.join(ADAPTER_DIR, lora_name)
        os.makedirs(adapter_out, exist_ok=True)

        # Copy uploaded audio files + check total duration
        _log(f"[INFO] Preparing {len(audio_files)} audio files...")
        yield _busy()
        import librosa as _lr
        total_dur = 0.0
        accepted = 0
        skipped_names = []
        truncated_names = []
        for f in audio_files:
            src = f.name if hasattr(f, "name") else str(f)
            fname = os.path.basename(src)
            try:
                dur = _lr.get_duration(path=src)
            except Exception:
                dur = 0.0
            remaining = MAX_TOTAL_AUDIO - total_dur
            if remaining <= 0:
                skipped_names.append(fname)
                continue
            if dur > remaining:
                # Truncate this file to fit
                import soundfile as _sf
                y, sr = _lr.load(src, sr=None, mono=False)
                max_samples = int(remaining * sr)
                if y.ndim == 1:
                    y = y[:max_samples]
                else:
                    y = y[:, :max_samples]
                dst = os.path.join(audio_dir, fname)
                _sf.write(dst, y.T if y.ndim > 1 else y, sr)
                truncated_names.append(f"{fname} ({dur:.0f}s -> {remaining:.0f}s)")
                total_dur += remaining
                accepted += 1
            else:
                shutil.copy2(src, os.path.join(audio_dir, fname))
                total_dur += dur
                accepted += 1
        if truncated_names:
            _log(f"[WARN] Truncated: {', '.join(truncated_names)}")
        if skipped_names:
            _log(f"[WARN] Skipped (over {MAX_TOTAL_AUDIO/60:.0f} min cap): {', '.join(skipped_names)}")
        _log(f"[INFO] Total audio: {total_dur:.0f}s ({total_dur/60:.1f} min), {accepted} files")
        _log(f"[INFO] LoRA: '{lora_name}' | Files: {len(audio_files)} | "
             f"Epochs: {epochs} | LR: {lr} | Rank: {rank}")
        yield _busy()

        # Caption each audio file via ace-server /understand BEFORE stopping it
        if _server_ok():
            _log("[INFO] Captioning audio via ace-server /understand...")
            yield _busy()
            for audio_fname in sorted(os.listdir(audio_dir)):
                full_path = os.path.join(audio_dir, audio_fname)
                if not os.path.isfile(full_path) or audio_fname.endswith(".json"):
                    continue
                caption_json_path = full_path + ".json"
                caption_data = _caption_via_understand(full_path, timeout=120)
                if caption_data:
                    _log(f"[Caption] {audio_fname}: using ace-server /understand")
                    with open(caption_json_path, "w") as cj:
                        json.dump(caption_data, cj)
                else:
                    # Fallback to librosa for basic metadata
                    _log(f"[Caption] {audio_fname}: fallback to librosa")
                    try:
                        y_cap, sr_cap = _lr.load(full_path, sr=None, mono=True)
                        tempo, _ = _lr.beat.beat_track(y=y_cap, sr=sr_cap)
                        bpm_val = float(tempo) if hasattr(tempo, '__float__') else float(tempo[0])
                        fallback = {"caption": "", "bpm": round(bpm_val), "key": "", "signature": "", "lyrics": ""}
                        with open(caption_json_path, "w") as cj:
                            json.dump(fallback, cj)
                    except Exception as cap_exc:
                        _log(f"[Caption] {audio_fname}: librosa fallback also failed: {cap_exc}")
                yield _busy()
        else:
            _log("[INFO] ace-server not running, skipping /understand captioning")
            yield _busy()

        # From here on the inference flag is set and ace-server is stopped.
        # All of it sits inside try/finally, and the finally block never
        # yields: if this generator is closed while suspended at a yield,
        # the cleanup must still run to completion.
        stop_attempted = False
        try:
            try:
                # Stop ace-server before training (frees memory)
                _training_in_progress = True
                _log("[INFO] Stopping ace-server for training...")
                yield _busy()
                stop_attempted = True
                _stop_ace_server()
                gc.collect()

                # -- Phase 1: Preprocessing --
                _log("[Step 1/2] Preprocessing audio...")
                yield _busy()
                preprocessed_dir = os.path.join(work_dir, "preprocessed_tensors")

                def preprocess_progress(current, total, desc):
                    _log(f" {desc} ({current}/{total})")

                result = preprocess_audio(
                    audio_dir=audio_dir,
                    output_dir=preprocessed_dir,
                    checkpoint_dir=ACE_CHECKPOINT_DIR,
                    device="cpu",
                    variant="turbo",
                    max_duration=float(MAX_TOTAL_AUDIO),
                    progress_callback=preprocess_progress,
                    cancel_check=lambda: False,
                )
                yield _busy()
                processed = result.get("processed", 0)
                failed = result.get("failed", 0)
                total = result.get("total", 0)
                _log(f"[OK] Preprocessed: {processed}/{total} (failed: {failed})")
                yield _busy()
                if processed == 0:
                    _log("[FAIL] No files preprocessed successfully. Cannot train.")
                    yield _idle()
                else:
                    gc.collect()
                    # -- Phase 2: Training --
                    _log("[Step 2/2] Training LoRA...")
                    yield _busy()
                    for msg in train_lora_generator(
                        dataset_dir=preprocessed_dir,
                        output_dir=adapter_out,
                        checkpoint_dir=ACE_CHECKPOINT_DIR,
                        epochs=epochs,
                        lr=lr,
                        rank=rank,
                        alpha=rank * 2,
                        dropout=0.0,
                        batch_size=1,
                        gradient_accumulation_steps=4,
                        warmup_steps=100,
                        weight_decay=0.01,
                        max_grad_norm=1.0,
                        save_every_n_epochs=max(1, epochs // 2),
                        seed=42,
                        variant="turbo",
                        device="cpu",
                        log_every=5,
                    ):
                        # Timeout check
                        elapsed = time.time() - train_start
                        if elapsed > MAX_TRAINING_TIME:
                            _log(f"[WARN] Training timed out after {int(elapsed)}s")
                            cancel_training()
                            break
                        _log(msg)
                        yield _busy()
                        if msg.strip() == "[DONE]":
                            break
                    _log(f"[INFO] Total time: {time.time() - train_start:.0f}s")
                    yield _busy()
            except Exception as exc:
                _log(f"[FAIL] Training error: {exc}")
                import traceback
                _log(traceback.format_exc())
                yield _idle()
            # Show the restart notice before the (slow) restart begins.
            _log("[INFO] Restarting ace-server...")
            yield _busy()
        finally:
            _training_in_progress = False
            if stop_attempted:
                # Always restart ace-server once a stop was attempted.
                gc.collect()
                ok = _start_ace_server()
                if ok:
                    _log("[OK] ace-server restarted successfully")
                else:
                    _log("[WARN] ace-server may not have restarted -- check logs")

        adapter_safetensors = os.path.join(adapter_out, "adapter_model.safetensors")
        if os.path.isfile(adapter_safetensors):
            # Copy to a temp file so Gradio doesn't try to validate /app paths
            # (avoids InvalidPathError: "Cannot move /app to the gradio cache dir
            # because it was not uploaded by a user")
            tmp_out = tempfile.NamedTemporaryFile(
                suffix=".safetensors",
                prefix=f"{lora_name}_",
                delete=False,
            )
            tmp_out.close()
            shutil.copy2(adapter_safetensors, tmp_out.name)
            _log(f"[OK] LoRA saved: {lora_name}")
            yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File(value=tmp_out.name, visible=True)
        else:
            yield _idle()

    # -- Cancel handler --
    def _on_cancel():
        cancel_training()
        logger.info("Cancel requested by user")
        return "Cancelling after current epoch... please wait"

    # -- Check log handler --
    def _check_log():
        if _train_log_lines:
            return "\n".join(_train_log_lines)
        return "No training log available."

    # -- Build LM model choices --
    def _lm_model_choices():
        return _scan_lm_models()

    # -- Build UI --
    CSS = """
    .compact-row { gap: 8px !important; }
    .status-box textarea { font-family: monospace; font-size: 13px; }
    """
    with gr.Blocks(title="ACE-Step 1.5 XL (CPU)") as demo:
        with gr.Tabs():
            # ============================================================
            # Tab 1: Generate Music
            # ============================================================
            with gr.Tab("Generate Music"):
                gr.Markdown(
                    "**[ACE-Step 1.5 XL (CPU)](https://github.com/ace-step/ACE-Step-1.5)** "
                    "GGUF Q4_K_M via "
                    "[acestep.cpp](https://github.com/ServeurpersoCom/acestep.cpp)"
                )
                with gr.Row(elem_classes="compact-row"):
                    with gr.Column(scale=2):
                        caption = gr.Textbox(
                            label="Music Description",
                            lines=2,
                            value="upbeat electronic dance music, energetic synth leads",
                        )
                        lyrics = gr.Textbox(
                            label="Lyrics",
                            lines=3,
                            value="[Instrumental]",
                            placeholder="Enter lyrics or [Instrumental] for no vocals",
                        )
                    with gr.Column(scale=1):
                        audio_out = gr.Audio(label="Output", type="filepath")
                        status = gr.Textbox(
                            label="Status",
                            interactive=False,
                            lines=2,
                            elem_classes="status-box",
                        )
                with gr.Row(elem_classes="compact-row"):
                    instrumental = gr.Checkbox(label="Instrumental", value=True, scale=1)
                    bpm = gr.Number(label="BPM", value=120, minimum=0, maximum=300, scale=1)
                    duration = gr.Slider(
                        label="Duration (s)", minimum=10, maximum=120,
                        value=10, step=5, scale=1,
                    )
                    steps = gr.Slider(
                        label="Steps", minimum=1, maximum=32,
                        value=8, step=1, scale=1,
                    )
                    seed = gr.Number(label="Seed (-1=random)", value=-1, scale=1)
                with gr.Row(elem_classes="compact-row"):
                    lora_select = gr.Dropdown(
                        label="LoRA", choices=_list_lora_choices(),
                        value="None (no LoRA)", scale=1,
                        allow_custom_value=True,
                    )
                    lm_model_select = gr.Dropdown(
                        label="LM Model", choices=_lm_model_choices(),
                        value=DEFAULT_LM, scale=1,
                    )
                with gr.Row(elem_classes="compact-row"):
                    gen_btn = gr.Button("Generate Music", variant="primary", scale=2)
                    status_btn = gr.Button("Server Status", scale=1)
                gen_btn.click(
                    fn=generate_music,
                    inputs=[caption, lyrics, instrumental, bpm, duration,
                            seed, steps, lora_select, lm_model_select],
                    outputs=[audio_out, status],
                    api_name="generate",
                )
                status_btn.click(
                    fn=get_server_status,
                    inputs=[],
                    outputs=[status],
                    api_name="server_status",
                )
            # ============================================================
            # Tab 2: Train LoRA
            # ============================================================
            with gr.Tab("Train LoRA"):
                gr.Markdown(
                    "### LoRA Training\n"
                    "Fine-tune ACE-Step on your audio. "
                    "CPU training is slow -- ace-server stops during training."
                )
                with gr.Row(elem_classes="compact-row"):
                    with gr.Column(scale=2):
                        train_audio = gr.File(
                            label="Training Audio Files",
                            file_count="multiple",
                            file_types=["audio"],
                        )
                    with gr.Column(scale=1):
                        lora_name = gr.Textbox(label="LoRA Name", value="my-lora")
                        train_epochs = gr.Slider(
                            label="Epochs", minimum=1, maximum=1000,
                            value=3, step=1,
                        )
                        train_lr = gr.Number(label="Learning Rate", value=3e-4)
                        train_rank = gr.Slider(
                            label="Rank (r)", minimum=1, maximum=128,
                            value=32, step=1,
                        )
                with gr.Row(elem_classes="compact-row"):
                    train_btn = gr.Button("Train", variant="primary", scale=2)
                    cancel_btn = gr.Button("Cancel Training", variant="stop", visible=False, scale=1)
                    log_btn = gr.Button("Check Log", scale=1)
                train_output_file = gr.File(label="Trained LoRA (download)", visible=False)
                train_log = gr.Textbox(
                    label="Training Log",
                    interactive=False,
                    lines=10,
                    elem_classes="status-box",
                )
                # Button swap on click (separate handler, like rvc-beatrice)
                # This fires immediately so user sees Cancel even if training
                # queues behind concurrency_limit=1
                train_btn.click(
                    lambda: (gr.Button(visible=False), gr.Button(visible=True)),
                    outputs=[train_btn, cancel_btn],
                )
                # Training generator -- yields (log, train_btn, cancel_btn, output_file)
                train_event = train_btn.click(
                    train_lora_ui,
                    inputs=[train_audio, lora_name, train_epochs, train_lr, train_rank],
                    outputs=[train_log, train_btn, cancel_btn, train_output_file],
                    api_name="train_lora",
                    concurrency_limit=1,
                )

                # After training completes, restore buttons and refresh LoRA dropdown
                # This ensures cleanup even if the user navigated away
                def _post_training():
                    return (
                        gr.Button(visible=True),
                        gr.Button(visible=False),
                        gr.Dropdown(choices=_list_lora_choices()),
                    )

                train_event.then(
                    _post_training,
                    outputs=[train_btn, cancel_btn, lora_select],
                )
                # Cancel: set the flag, update status
                cancel_btn.click(
                    _on_cancel,
                    outputs=[train_log],
                )
                # Check log: show last training output
                log_btn.click(
                    _check_log,
                    outputs=[train_log],
                    api_name="check_log",
                )
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        mcp_server=True,
        css=CSS,
    )
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
# If any CLI arguments besides the script name, run CLI mode
# (Gradio sets no extra args; start.sh calls `python3 /app/app.py`)
if len(sys.argv) > 1:
cli_main()
else:
gradio_main()