# ace-step15-endpoint / lora_ui.py
# Author: Andrew — "Add LoRA upload flow in Space UI and update README" (commit 048d6f4)
"""
ACE-Step 1.5 LoRA Training and Evaluation UI.
Gradio interface with four tabs:
1. Model Setup: initialize base DiT, VAE, and text encoder
2. Dataset: scan folder or drop files, then edit/save sidecars
3. Training: configure hyperparameters and run LoRA training
4. Evaluation: load adapters and run deterministic A/B generation
"""
import os
import sys
import json
import math
import random
import threading
import tempfile
import time
import shutil
import zipfile
from pathlib import Path
from typing import List, Optional
import gradio as gr
# On Hugging Face Spaces Zero, `spaces` must be imported before CUDA-related modules.
if os.getenv("SPACE_ID"):
try:
import spaces # noqa: F401
except Exception:
pass
import torch
from loguru import logger
# ---------------------------------------------------------------------------
# Ensure project root is on sys.path so `acestep` imports work
# ---------------------------------------------------------------------------
PROJECT_ROOT = str(Path(__file__).resolve().parent)
if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT)
from acestep.handler import AceStepHandler
from acestep.audio_utils import AudioSaver
from acestep.llm_inference import LLMHandler
from acestep.inference import understand_music
from lora_train import (
LoRATrainConfig,
LoRATrainer,
TrackEntry,
scan_dataset_folder,
scan_uploaded_files,
)
# ---------------------------------------------------------------------------
# Globals (shared across Gradio callbacks)
# ---------------------------------------------------------------------------
handler = AceStepHandler()
llm_handler = LLMHandler()
trainer: Optional[LoRATrainer] = None
dataset_entries: List[TrackEntry] = []
_training_thread: Optional[threading.Thread] = None
_training_log: List[str] = []
_training_status: str = "idle"  # idle | running | stopped | done
_training_started_at: Optional[float] = None
_model_init_ok: bool = False
_model_init_status: str = ""
_last_model_init_args: Optional[dict] = None
_lm_init_ok: bool = False
_last_lm_init_args: Optional[dict] = None
# Resumable batch cursor for auto-labeling (index of next track to process).
_auto_label_cursor: int = 0
audio_saver = AudioSaver(default_format="wav")

IS_SPACE = bool(os.getenv("SPACE_ID"))
DEFAULT_OUTPUT_DIR = "/data/lora_output" if IS_SPACE else "lora_output"
DEFAULT_UPLOADED_ADAPTER_SUBDIR = "uploaded_adapters"


def _identity_decorator(fn):
    """No-op decorator used when the Zero GPU `spaces.GPU` wrapper is unavailable."""
    return fn


if IS_SPACE:
    try:
        import spaces as _hf_spaces
        # Zero GPU: decorated callbacks get a 300-second GPU allocation window.
        _gpu_callback = _hf_spaces.GPU(duration=300)
    except Exception:
        _gpu_callback = _identity_decorator
else:
    _gpu_callback = _identity_decorator
def _rows_from_entries(entries: List[TrackEntry]):
    """Build dataset-table display rows from in-memory track entries.

    Each row is [file name, duration, caption, lyrics preview, language].
    Long lyrics are truncated to 60 chars for the preview; missing caption or
    lyrics show "(none)".
    """
    rows = []
    for e in entries:
        # Guard against None lyrics: the original `len(e.lyrics)` raised
        # TypeError when a track had no lyrics at all.
        lyrics = e.lyrics or ""
        lyrics_cell = lyrics[:60] + "..." if len(lyrics) > 60 else (lyrics or "(none)")
        rows.append([
            Path(e.audio_path).name,
            f"{e.duration:.1f}s" if e.duration else "?",
            e.caption or "(none)",
            lyrics_cell,
            e.vocal_language,
        ])
    return rows
# ===========================================================================
# Tab 1 - Model Setup
# ===========================================================================
def get_available_models():
    """Return the selectable ACE-Step v1.5 model names, with a safe default."""
    discovered = handler.get_available_acestep_v15_models()
    if discovered:
        return discovered
    return ["acestep-v15-base"]
def init_model(
    model_name: str,
    device: str,
    offload_cpu: bool,
    offload_dit_cpu: bool,
):
    """Initialize the base model stack and record the args used.

    The args are kept in `_last_model_init_args` so the model can be
    re-initialized later inside a GPU callback context. Returns the
    handler's status string.
    """
    global _model_init_ok, _model_init_status, _last_model_init_args
    init_kwargs = {
        "project_root": PROJECT_ROOT,
        "config_path": model_name,
        "device": device,
        "use_flash_attention": False,
        "compile_model": False,
        "offload_to_cpu": offload_cpu,
        "offload_dit_to_cpu": offload_dit_cpu,
    }
    _last_model_init_args = init_kwargs
    status, ok = _init_model_gpu(**init_kwargs)
    _model_init_ok = bool(ok)
    _model_init_status = status or ""
    return status
@_gpu_callback
def _init_model_gpu(**kwargs):
    """GPU-decorated wrapper so model init runs inside a Zero GPU allocation."""
    return _init_model_impl(**kwargs)
def _init_model_impl(**kwargs):
    """Undecorated init, reusable from inside other GPU callbacks (e.g. auto-label)."""
    return handler.initialize_service(**kwargs)
# ===========================================================================
# Tab 2 - Dataset
# ===========================================================================
def scan_folder(folder_path: str):
    """Scan a local folder for audio files and load them as the active dataset.

    Returns (status message, table rows). Resets the auto-label cursor.
    """
    global dataset_entries, _auto_label_cursor
    if not (folder_path and os.path.isdir(folder_path)):
        return "Provide a valid folder path.", []
    dataset_entries = scan_dataset_folder(folder_path)
    _auto_label_cursor = 0
    status = f"Found {len(dataset_entries)} audio files."
    return status, _rows_from_entries(dataset_entries)
def load_uploaded(file_paths: List[str]):
    """Load drag/dropped audio files (plus optional .json sidecars) as the dataset.

    Returns (status message, table rows). Resets the auto-label cursor.
    """
    global dataset_entries, _auto_label_cursor
    if not file_paths:
        return "Drop audio files (and optional .json sidecars) first.", []
    sidecar_count = 0
    for candidate in file_paths:
        if isinstance(candidate, str) and Path(candidate).suffix.lower() == ".json":
            sidecar_count += 1
    dataset_entries = scan_uploaded_files(file_paths)
    _auto_label_cursor = 0
    msg = f"Loaded {len(dataset_entries)} dropped audio files."
    if sidecar_count:
        msg += f" Matched {sidecar_count} uploaded sidecar JSON file(s)."
    return msg, _rows_from_entries(dataset_entries)
def save_sidecar(index: int, caption: str, lyrics: str, bpm: str, keyscale: str, lang: str):
    """Save metadata edits back to a JSON sidecar and update in-memory entry.

    Args:
        index: 0-based track index into the loaded dataset.
        bpm: integer-like string (e.g. "120" or "120.0"), or empty to clear.
    Returns a status message for the UI.
    """
    global dataset_entries
    if index < 0 or index >= len(dataset_entries):
        return "Invalid track index."
    entry = dataset_entries[index]
    entry.caption = caption
    entry.lyrics = lyrics
    if bpm.strip():
        try:
            entry.bpm = int(float(bpm))
        except ValueError:
            return "Invalid BPM value. Use an integer or leave empty."
    else:
        entry.bpm = None
    entry.keyscale = keyscale
    entry.vocal_language = lang
    # Reuse the shared sidecar writer so the on-disk JSON schema stays
    # consistent with the auto-label flow (was duplicated inline before).
    _write_entry_sidecar(entry)
    return f"Saved sidecar for {Path(entry.audio_path).name}"
def init_auto_label_lm(lm_model_path: str, lm_backend: str, lm_device: str):
    """Initialize the auto-label LM and remember the args for GPU-context reloads."""
    global _lm_init_ok, _last_lm_init_args
    lm_args = {
        "lm_model_path": lm_model_path,
        "lm_backend": lm_backend,
        "lm_device": lm_device,
    }
    _last_lm_init_args = lm_args
    status = _init_auto_label_lm_gpu(**lm_args)
    # Success is inferred from the status text produced by the impl.
    failed = str(status).startswith(("LM init failed:", "LM init exception:"))
    _lm_init_ok = not failed
    return status
@_gpu_callback
def _init_auto_label_lm_gpu(lm_model_path: str, lm_backend: str, lm_device: str):
    """GPU-decorated wrapper so LM init runs inside a Zero GPU allocation."""
    return _init_auto_label_lm_impl(lm_model_path, lm_backend, lm_device)
def _init_auto_label_lm_impl(lm_model_path: str, lm_backend: str, lm_device: str):
    """Initialize LLM for dataset auto-labeling.

    Downloads the checkpoint if it is missing locally, then initializes
    `llm_handler`. Errors are returned as status strings (never raised) so
    the UI can display them.
    """
    checkpoint_dir = os.path.join(PROJECT_ROOT, "checkpoints")
    full_lm_path = os.path.join(checkpoint_dir, lm_model_path)
    try:
        if not os.path.exists(full_lm_path):
            from pathlib import Path as _Path
            from acestep.model_downloader import ensure_main_model, ensure_lm_model
            if lm_model_path == "acestep-5Hz-lm-1.7B":
                # NOTE(review): the 1.7B LM goes through ensure_main_model —
                # presumably it ships inside the main model bundle; confirm.
                ok, msg = ensure_main_model(
                    checkpoints_dir=_Path(checkpoint_dir),
                    prefer_source="huggingface",
                )
            else:
                ok, msg = ensure_lm_model(
                    model_name=lm_model_path,
                    checkpoints_dir=_Path(checkpoint_dir),
                    prefer_source="huggingface",
                )
            if not ok:
                return f"Failed to download LM model: {msg}"
        status, ok = llm_handler.initialize(
            checkpoint_dir=checkpoint_dir,
            lm_model_path=lm_model_path,
            backend=lm_backend,
            device=lm_device,
            offload_to_cpu=False,
        )
        return status if ok else f"LM init failed:\n{status}"
    except Exception as exc:
        logger.exception("LM init failed for auto-label")
        return f"LM init exception: {exc}"
def _write_entry_sidecar(entry: TrackEntry):
    """Persist a track's metadata next to its audio file as `<stem>.json`."""
    payload = dict(
        caption=entry.caption,
        lyrics=entry.lyrics,
        bpm=entry.bpm,
        keyscale=entry.keyscale,
        timesignature=entry.timesignature,
        vocal_language=entry.vocal_language,
        duration=entry.duration,
    )
    destination = Path(entry.audio_path).with_suffix(".json")
    destination.write_text(
        json.dumps(payload, indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
@_gpu_callback
def auto_label_all(overwrite_existing: bool, caption_only: bool, max_files_per_run: int = 6, reset_cursor: bool = False):
    """Auto-label all loaded tracks using ACE audio understanding (audio->codes->metadata).

    Processes at most `max_files_per_run` tracks per call; a module-level
    cursor (`_auto_label_cursor`) makes repeated calls resume where the last
    batch stopped (useful under Zero GPU time limits). Returns
    (summary message, refreshed table rows, detail log text).
    """
    global dataset_entries, _auto_label_cursor
    # Lazily re-initialize the model inside this GPU callback if it is not
    # present here but was successfully initialized earlier.
    if handler.model is None:
        if _model_init_ok and _last_model_init_args:
            status, ok = _init_model_impl(**_last_model_init_args)
            if not ok:
                return f"Model reload failed before auto-label:\n{status}", [], "Auto-label skipped."
        else:
            return "Initialize model first in Step 1.", [], "Auto-label skipped."
    if not dataset_entries:
        return "Load dataset first in Step 2.", [], "Auto-label skipped."
    # Same lazy-reload pattern for the labeling LM.
    if not llm_handler.llm_initialized:
        if _lm_init_ok and _last_lm_init_args:
            status = _init_auto_label_lm_impl(**_last_lm_init_args)
            if not llm_handler.llm_initialized:
                return (
                    f"Auto-label LM reload failed:\n{status}",
                    _rows_from_entries(dataset_entries),
                    "Auto-label skipped.",
                )
        else:
            return "Initialize Auto-Label LM first.", _rows_from_entries(dataset_entries), "Auto-label skipped."
    if max_files_per_run <= 0:
        max_files_per_run = 6
    if reset_cursor:
        _auto_label_cursor = 0
    if _auto_label_cursor < 0 or _auto_label_cursor >= len(dataset_entries):
        _auto_label_cursor = 0
    start_idx = _auto_label_cursor
    end_idx = min(len(dataset_entries), start_idx + int(max_files_per_run))
    updated = 0
    skipped = 0
    failed = 0
    logs: List[str] = []
    for idx in range(start_idx, end_idx):
        entry = dataset_entries[idx]
        try:
            # Collect which core fields are still missing for this track.
            missing_fields = []
            if not (entry.caption or "").strip():
                missing_fields.append("caption")
            if (not caption_only) and (not (entry.lyrics or "").strip()):
                missing_fields.append("lyrics")
            if entry.bpm is None:
                missing_fields.append("bpm")
            if not (entry.keyscale or "").strip():
                missing_fields.append("keyscale")
            if entry.duration is None:
                missing_fields.append("duration")
            # Skip only when every core field is already available.
            if (not overwrite_existing) and (len(missing_fields) == 0):
                skipped += 1
                logs.append(f"[{idx}] Skipped (already fully labeled): {Path(entry.audio_path).name}")
                continue
            # Audio -> semantic codes; the "❌" prefix is the handler's error marker.
            codes = handler.convert_src_audio_to_codes(entry.audio_path)
            if not codes or codes.startswith("❌"):
                failed += 1
                logs.append(f"[{idx}] Failed to convert audio to codes: {Path(entry.audio_path).name}")
                continue
            # Codes -> metadata/lyrics via the labeling LM.
            result = understand_music(
                llm_handler=llm_handler,
                audio_codes=codes,
                temperature=0.85,
                use_constrained_decoding=True,
                constrained_decoding_debug=False,
            )
            if not result.success:
                failed += 1
                logs.append(f"[{idx}] Failed to label: {Path(entry.audio_path).name} ({result.error or result.status_message})")
                continue
            # Update fields. If overwrite is false, fill only missing values.
            if overwrite_existing or not (entry.caption or "").strip():
                entry.caption = (result.caption or entry.caption or "").strip()
            if not caption_only:
                if overwrite_existing or not (entry.lyrics or "").strip():
                    entry.lyrics = (result.lyrics or entry.lyrics or "").strip()
            if entry.bpm is None and result.bpm is not None:
                entry.bpm = int(result.bpm)
            if (not entry.keyscale) and result.keyscale:
                entry.keyscale = result.keyscale
            if (not entry.timesignature) and result.timesignature:
                entry.timesignature = result.timesignature
            if (not entry.vocal_language) and result.language:
                entry.vocal_language = result.language
            if entry.duration is None and result.duration is not None:
                entry.duration = float(result.duration)
            _write_entry_sidecar(entry)
            updated += 1
            logs.append(f"[{idx}] Labeled: {Path(entry.audio_path).name}")
        except Exception as exc:
            failed += 1
            logs.append(f"[{idx}] Exception: {Path(entry.audio_path).name} ({exc})")
    # Advance the resumable cursor; wrap to 0 when the dataset is exhausted.
    _auto_label_cursor = 0 if end_idx >= len(dataset_entries) else end_idx
    mode = "caption-only" if caption_only else "caption+lyrics"
    progress_msg = (
        f"Processed batch {start_idx + 1}-{end_idx} of {len(dataset_entries)}. "
        if len(dataset_entries) > 0 else ""
    )
    if _auto_label_cursor == 0 and len(dataset_entries) > 0:
        progress_msg += "Reached end of dataset."
    else:
        progress_msg += f"Next start index: {_auto_label_cursor}."
    summary = (
        f"Auto-label ({mode}) complete. Updated={updated}, Skipped={skipped}, Failed={failed}. "
        f"{progress_msg}"
    )
    detail = "\n".join(logs[-40:]) if logs else "No logs."
    return summary, _rows_from_entries(dataset_entries), detail
# ===========================================================================
# Tab 3 - Training
# ===========================================================================
def _run_training(config_dict: dict):
    """Target for the background training thread.

    Mutates module-level status/log state that `poll_training` reads.
    Never raises: failures are appended to the log and the status is set
    to "stopped".
    """
    global trainer, _training_status, _training_log, _training_started_at
    _training_status = "running"
    _training_log.clear()
    _training_started_at = time.time()
    try:
        cfg = LoRATrainConfig(**config_dict)
        trainer = LoRATrainer(handler, cfg)
        trainer.prepare()
        _training_log.append(f"Training device: {handler.device}")

        def _cb(step, total, loss, epoch):
            # Progress callback: estimate ETA from the average step rate so far.
            elapsed = 0.0 if _training_started_at is None else max(0.0, time.time() - _training_started_at)
            rate = (step / elapsed) if elapsed > 0 else 0.0
            remaining = max(0, total - step)
            eta_sec = (remaining / rate) if rate > 0 else -1.0
            eta_msg = f"{eta_sec/60:.1f}m" if eta_sec >= 0 else "unknown"
            msg = (
                f"Step {step}/{total} Epoch {epoch+1} Loss {loss:.6f} "
                f"Elapsed {elapsed/60:.1f}m ETA {eta_msg}"
            )
            _training_log.append(msg)

        result = trainer.train(dataset_entries, progress_callback=_cb)
        _training_log.append(result)
        _training_status = "done"
    except Exception as exc:
        _training_log.append(f"ERROR: {exc}")
        _training_status = "stopped"
        logger.exception("Training failed")
def start_training(
    lora_rank, lora_alpha, lora_dropout,
    lr, weight_decay, optimizer_name,
    max_grad_norm, warmup_ratio, scheduler_name,
    num_epochs, batch_size, grad_accum,
    save_every, log_every, shift,
    max_duration, output_dir, resume_from,
):
    """Validate preconditions and launch LoRA training on a daemon thread."""
    global _training_thread, _training_status
    if handler.model is None:
        return "Model not initialised. Go to Model Setup first."
    if not dataset_entries:
        return "No dataset loaded. Go to Dataset tab first."
    if _training_status == "running":
        return "Training already in progress."

    resume_path = None
    if isinstance(resume_from, str) and resume_from.strip():
        resume_path = resume_from.strip()

    config_dict = {
        "lora_rank": int(lora_rank),
        "lora_alpha": int(lora_alpha),
        "lora_dropout": float(lora_dropout),
        "learning_rate": float(lr),
        "weight_decay": float(weight_decay),
        "optimizer": optimizer_name,
        "max_grad_norm": float(max_grad_norm),
        "warmup_ratio": float(warmup_ratio),
        "scheduler": scheduler_name,
        "num_epochs": int(num_epochs),
        "batch_size": int(batch_size),
        "gradient_accumulation_steps": int(grad_accum),
        "save_every_n_epochs": int(save_every),
        "log_every_n_steps": int(log_every),
        "shift": float(shift),
        "max_duration_sec": float(max_duration),
        "output_dir": output_dir,
        "resume_from": resume_path,
        "device": str(handler.device),
    }

    # Rough optimiser-step estimate for the status message.
    steps_per_epoch = math.ceil(len(dataset_entries) / int(batch_size))
    total_steps = steps_per_epoch * int(num_epochs)
    total_optim_steps = math.ceil(total_steps / int(grad_accum))

    _training_thread = threading.Thread(target=_run_training, args=(config_dict,), daemon=True)
    _training_thread.start()
    return (
        f"Training started on {handler.device}. "
        f"Estimated optimiser steps: {total_optim_steps}."
    )
def stop_training():
    """Request a graceful stop of the running trainer, if any."""
    global trainer, _training_status
    if not trainer:
        return "No training in progress."
    trainer.request_stop()
    _training_status = "stopped"
    return "Stop requested - will finish current step."
def poll_training():
    """Return current log + loss chart data."""
    if _training_log:
        log_text = "\n".join(_training_log[-50:])
    else:
        log_text = "(no output yet)"

    # Build loss curve data from the trainer's recorded history.
    chart_data = []
    if trainer and trainer.loss_history:
        chart_data = [[point["step"], point["loss"]] for point in trainer.loss_history]

    device_line = f"Device: {handler.device}"
    if torch.cuda.is_available() and str(handler.device).startswith("cuda"):
        try:
            dev_idx = torch.cuda.current_device()
            gpu_name = torch.cuda.get_device_name(dev_idx)
            alloc_gb = torch.cuda.memory_allocated(dev_idx) / (1024 ** 3)
            reserved_gb = torch.cuda.memory_reserved(dev_idx) / (1024 ** 3)
            device_line = (
                f"Device: {handler.device} ({gpu_name}) | "
                f"VRAM allocated={alloc_gb:.2f}GB reserved={reserved_gb:.2f}GB"
            )
        except Exception:
            # VRAM stats are best-effort decoration only.
            pass

    return f"Status: {_training_status}\n{device_line}\n\n{log_text}", chart_data
# ===========================================================================
# Tab 4 - Evaluation / A-B Test
# ===========================================================================
def list_adapters(output_dir: str):
    """List trained adapter paths under *output_dir*, or a placeholder entry."""
    found = LoRATrainer.list_adapters(output_dir)
    if not found:
        return ["(none found)"]
    return found
def _safe_adapter_name(name: str) -> str:
    """Sanitize a user-supplied adapter name into a filesystem-safe slug.

    Non-alphanumeric characters (except `-`, `_`, `.`) become underscores;
    leading/trailing dots and underscores are stripped. Falls back to a
    timestamped name when the result is empty.
    """
    trimmed = (name or "").strip()
    fallback = f"adapter_{int(time.time())}"
    if not trimmed:
        return fallback
    sanitized = "".join(
        ch if (ch.isalnum() or ch in "-_.") else "_"
        for ch in trimmed
    )
    cleaned = sanitized.strip("._")
    return cleaned or fallback
def _safe_extract_zip(zip_path: str, target_dir: Path) -> int:
    """Extract *zip_path* into *target_dir* after validating every member path.

    Guards against zip-slip (members escaping the target via `..` segments or
    absolute paths). All members are validated BEFORE anything is extracted.

    Returns:
        The number of archive entries extracted.
    Raises:
        RuntimeError: if any member would resolve outside *target_dir*.
    """
    target_resolved = target_dir.resolve()
    with zipfile.ZipFile(zip_path, "r") as zf:
        members = zf.infolist()
        for member in members:
            member_path = (target_dir / member.filename).resolve()
            # commonpath, unlike str.startswith, cannot be fooled by sibling
            # directories that share a prefix (e.g. "/out" vs "/out_evil").
            if os.path.commonpath([target_resolved, member_path]) != str(target_resolved):
                raise RuntimeError(f"Unsafe archive path detected: {member.filename}")
        zf.extractall(target_dir)
        return len(members)
def upload_adapter_files(uploaded_files: List[str], adapter_dir: str, adapter_name: str):
    """Upload LoRA adapter files/zip and make them available in adapter dropdown.

    Returns (status message, `gr.update` for the adapter dropdown). A single
    .zip upload is extracted; multiple loose files are copied as-is. The
    upload counts as valid only if an `adapter_config.json` is found somewhere
    under the target directory.
    """
    if not uploaded_files:
        adapters = list_adapters(adapter_dir)
        return "Please upload .zip or adapter files first.", gr.update(choices=adapters, value=adapters[0] if adapters else None)
    root_dir = Path(adapter_dir or DEFAULT_OUTPUT_DIR)
    target_root = root_dir / DEFAULT_UPLOADED_ADAPTER_SUBDIR
    target_root.mkdir(parents=True, exist_ok=True)
    target_dir = target_root / _safe_adapter_name(adapter_name)
    target_dir.mkdir(parents=True, exist_ok=True)
    copied = 0
    extracted = 0
    try:
        # If a single zip is uploaded, extract it; otherwise copy files directly.
        if len(uploaded_files) == 1 and str(uploaded_files[0]).lower().endswith(".zip"):
            zip_path = uploaded_files[0]
            extracted = _safe_extract_zip(zip_path, target_dir)
        else:
            for src in uploaded_files:
                src_path = Path(src)
                if not src_path.exists():
                    continue
                dst = target_dir / src_path.name
                shutil.copy2(src_path, dst)
                copied += 1
        # An adapter is identified by the directory holding adapter_config.json.
        found = sorted({str(p.parent) for p in target_dir.rglob("adapter_config.json")})
        if not found:
            adapters = list_adapters(str(root_dir))
            return (
                f"Uploaded to {target_dir}, but no adapter_config.json found. "
                "Upload a valid LoRA adapter folder or zip.",
                gr.update(choices=adapters, value=adapters[0] if adapters else None),
            )
        adapters = list_adapters(str(root_dir))
        primary = found[0]
        msg = (
            f"Adapter upload complete. Copied {copied} file(s), extracted {extracted} archive entries. "
            f"Detected {len(found)} adapter path(s). Primary: {primary}"
        )
        return msg, gr.update(choices=adapters, value=primary)
    except Exception as exc:
        logger.exception("Adapter upload failed")
        adapters = list_adapters(str(root_dir))
        return f"Adapter upload failed: {exc}", gr.update(choices=adapters, value=adapters[0] if adapters else None)
@_gpu_callback
def load_adapter(adapter_path: str):
    """Load the selected LoRA adapter into the handler."""
    if adapter_path and adapter_path != "(none found)":
        return handler.load_lora(adapter_path)
    return "Select a valid adapter path."
@_gpu_callback
def unload_adapter():
    """Unload any loaded LoRA adapter from the handler."""
    return handler.unload_lora()
def set_lora_scale(scale: float):
    """Set the active LoRA blending scale on the handler."""
    return handler.set_lora_scale(scale)
@_gpu_callback
def generate_sample(
    prompt: str,
    lyrics: str,
    duration: float,
    bpm: int,
    steps: int,
    guidance: float,
    seed: int,
    use_lora: bool,
    lora_scale: float,
):
    """Generate a single audio sample for evaluation.

    Returns (audio_file_path_or_None, status_message). A negative seed picks
    a random one; the resolved seed is echoed in the status message so A/B
    runs can be reproduced.
    """
    if handler.model is None:
        return None, "Model not initialised."
    # Toggle LoRA if loaded
    if handler.lora_loaded:
        handler.set_use_lora(use_lora)
        if use_lora:
            handler.set_lora_scale(lora_scale)
    actual_seed = int(seed) if seed >= 0 else random.randint(0, 2**32 - 1)
    result = handler.generate_music(
        captions=prompt,
        lyrics=lyrics,
        bpm=bpm if bpm > 0 else None,
        inference_steps=steps,
        guidance_scale=guidance,
        use_random_seed=False,
        seed=actual_seed,
        audio_duration=duration,
        batch_size=1,
    )
    if not result.get("success", False):
        return None, result.get("error", "Generation failed.")
    audios = result.get("audios", [])
    if not audios:
        return None, "No audio produced."
    audio_data = audios[0]
    wav_tensor = audio_data.get("tensor")
    sr = audio_data.get("sample_rate", 48000)
    if wav_tensor is None:
        # Some results carry a ready-made file path instead of a tensor.
        path = audio_data.get("path")
        if path and os.path.exists(path):
            return path, f"Generated (from file), seed={actual_seed}."
        return None, "No audio tensor."
    # Save to a temp file. mkstemp + close avoids leaking the open file handle
    # that NamedTemporaryFile(delete=False) kept around while AudioSaver
    # wrote to the same path.
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    audio_saver.save_audio(wav_tensor, tmp_path, sample_rate=sr)
    return tmp_path, f"Generated successfully, seed={actual_seed}."
@_gpu_callback
def ab_test(
    prompt, lyrics, duration, bpm, steps, guidance, seed,
    lora_scale_b,
):
    """Generate two samples: A = base, B = LoRA at given scale."""
    # Resolve the seed once so both renders share it (fair comparison).
    resolved_seed = int(seed) if seed >= 0 else random.randint(0, 2**32 - 1)
    path_a, msg_a = generate_sample(
        prompt, lyrics, duration, bpm, steps, guidance, resolved_seed,
        use_lora=False, lora_scale=0.0,
    )
    path_b, msg_b = generate_sample(
        prompt, lyrics, duration, bpm, steps, guidance, resolved_seed,
        use_lora=True, lora_scale=lora_scale_b,
    )
    return path_a, msg_a, path_b, msg_b
# ===========================================================================
# Build the Gradio App
# ===========================================================================
def get_workflow_status():
    """Summarize model/dataset/training/LoRA state for the status textbox."""
    model_is_ready = (handler.model is not None) or _model_init_ok
    model_ready = "Ready" if model_is_ready else "Not initialized"
    if handler.model is not None:
        lora_status = handler.get_lora_status()
    else:
        lora_status = {"loaded": False, "active": False, "scale": 1.0}
    # On Zero GPU the model may live only inside GPU callback contexts.
    init_note = ""
    if IS_SPACE and _model_init_ok and handler.model is None:
        init_note = " (Zero GPU callback context)"
    lines = [
        f"Model: {model_ready}{init_note}",
        f"Tracks Loaded: {len(dataset_entries)}",
        f"Training: {_training_status}",
        f"LoRA Loaded: {lora_status.get('loaded', False)}",
        f"LoRA Active: {lora_status.get('active', False)}",
        f"LoRA Scale: {lora_status.get('scale', 1.0)}",
    ]
    return "\n".join(lines)
def init_model_and_status(
    model_name: str,
    device: str,
    offload_cpu: bool,
    offload_dit_cpu: bool,
):
    """Run model init, then return (init output, refreshed workflow status)."""
    init_output = init_model(model_name, device, offload_cpu, offload_dit_cpu)
    return init_output, get_workflow_status()
def build_ui():
    """Construct the four-step Gradio app (setup/dataset/train/evaluate) and return it."""
    available_models = get_available_models()
    with gr.Blocks(title="ACE-Step 1.5 LoRA Studio", theme=gr.themes.Soft()) as app:
        gr.Markdown(
            "# ACE-Step 1.5 LoRA Studio\n"
            "Use this guided workflow from left to right.\n\n"
            "**Step 1:** Initialize model \n"
            "**Step 2:** Load dataset \n"
            "**Step 3:** Start training \n"
            "**Step 4:** Evaluate adapter"
        )
        # Shared status box reused as an output by several tab callbacks.
        with gr.Row():
            workflow_status = gr.Textbox(label="Workflow Status", value=get_workflow_status(), lines=6, interactive=False)
            refresh_status_btn = gr.Button("Refresh Status")
        refresh_status_btn.click(get_workflow_status, outputs=workflow_status, api_name="workflow_status")
        # ---- Step 1 ----
        with gr.Tab("Step 1 - Initialize Model"):
            gr.Markdown(
                "### Instructions\n"
                "1. Pick a model (`acestep-v15-base` recommended for LoRA).\n"
                "2. Keep device on `auto` unless you need manual override.\n"
                "3. Click **Initialize Model** and confirm status is success."
            )
            with gr.Row():
                model_dd = gr.Dropdown(
                    choices=available_models,
                    value=available_models[0] if available_models else None,
                    label="DiT Model",
                )
                device_dd = gr.Dropdown(
                    choices=["auto", "cuda", "mps", "cpu"],
                    value="auto",
                    label="Device",
                )
            with gr.Row():
                offload_cb = gr.Checkbox(label="Offload To CPU (optional)", value=False)
                offload_dit_cb = gr.Checkbox(label="Offload DiT To CPU (optional)", value=False)
            init_btn = gr.Button("Initialize Model", variant="primary")
            init_out = gr.Textbox(label="Initialization Output", lines=8, interactive=False)
            init_btn.click(
                init_model_and_status,
                [model_dd, device_dd, offload_cb, offload_dit_cb],
                [init_out, workflow_status],
                api_name="init_model",
            )
        # ---- Step 2 ----
        with gr.Tab("Step 2 - Load Dataset"):
            gr.Markdown(
                "### Instructions\n"
                "1. Either scan a folder or drag/drop audio files (+ optional .json sidecars).\n"
                "2. Confirm tracks appear in the table.\n"
                "3. Optional: run Auto-Label All to fill caption/lyrics/metas.\n"
                "4. Optional: edit metadata manually and save sidecar JSON."
            )
            with gr.Row():
                folder_input = gr.Textbox(label="Dataset Folder Path", placeholder="e.g. ./dataset_inbox")
                scan_btn = gr.Button("Scan Folder")
            with gr.Row():
                upload_files = gr.Files(
                    label="Drag/Drop Audio Files (+ Optional JSON Sidecars)",
                    file_count="multiple",
                    file_types=["audio", ".json"],
                    type="filepath",
                )
                upload_btn = gr.Button("Load Dropped Files")
            scan_msg = gr.Textbox(label="Dataset Result", interactive=False)
            dataset_table = gr.Dataframe(
                headers=["File", "Duration", "Caption", "Lyrics", "Language"],
                datatype=["str", "str", "str", "str", "str"],
                label="Tracks",
                interactive=False,
            )
            scan_btn.click(
                scan_folder,
                folder_input,
                [scan_msg, dataset_table],
                api_name="scan_folder",
            )
            upload_btn.click(
                load_uploaded,
                upload_files,
                [scan_msg, dataset_table],
                api_name="load_uploaded",
            )
            # Optional auto-labeling via the ACE understanding LM.
            with gr.Accordion("Auto-Label (ACE audio understanding)", open=False):
                gr.Markdown(
                    "Auto-label uses ACE: audio -> semantic codes -> metadata/lyrics.\n"
                    "Initialize LM first, then run Auto-Label All.\n"
                    "Use Caption-Only if your dataset has no lyrics.\n"
                    "On Zero GPU, process in smaller batches and click Auto-Label All repeatedly."
                )
                with gr.Row():
                    lm_model_dd = gr.Dropdown(
                        choices=["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B", "acestep-5Hz-lm-4B"],
                        value="acestep-5Hz-lm-0.6B",
                        label="Auto-Label LM Model",
                    )
                    lm_backend_dd = gr.Dropdown(
                        choices=["pt", "vllm", "mlx"],
                        value="pt",
                        label="LM Backend",
                    )
                    lm_device_dd = gr.Dropdown(
                        choices=["auto", "cuda", "mps", "xpu", "cpu"],
                        value="auto",
                        label="LM Device",
                    )
                with gr.Row():
                    init_lm_btn = gr.Button("Initialize Auto-Label LM")
                    overwrite_cb = gr.Checkbox(label="Overwrite Existing Caption/Lyrics", value=False)
                    caption_only_cb = gr.Checkbox(label="Caption-Only (Skip Lyrics)", value=True)
                    auto_label_btn = gr.Button("Auto-Label All", variant="primary")
                with gr.Row():
                    max_files_per_run = gr.Slider(1, 25, value=6, step=1, label="Files Per Run (Zero GPU Safe)")
                    reset_cursor_cb = gr.Checkbox(label="Restart From First Track", value=False)
                lm_init_status = gr.Textbox(label="Auto-Label LM Status", lines=5, interactive=False)
                auto_label_status = gr.Textbox(label="Auto-Label Summary", interactive=False)
                auto_label_log = gr.Textbox(label="Auto-Label Log", lines=8, interactive=False)
                init_lm_btn.click(
                    init_auto_label_lm,
                    [lm_model_dd, lm_backend_dd, lm_device_dd],
                    lm_init_status,
                    api_name="init_auto_label_lm",
                )
                auto_label_btn.click(
                    auto_label_all,
                    [overwrite_cb, caption_only_cb, max_files_per_run, reset_cursor_cb],
                    [auto_label_status, dataset_table, auto_label_log],
                    api_name="auto_label_all",
                )
            # Manual per-track metadata editing.
            with gr.Accordion("Optional: Edit Metadata Sidecar", open=False):
                with gr.Row():
                    edit_idx = gr.Number(label="Track Index (0-based)", value=0, precision=0)
                    edit_caption = gr.Textbox(label="Caption")
                edit_lyrics = gr.Textbox(label="Lyrics", lines=3)
                with gr.Row():
                    edit_bpm = gr.Textbox(label="BPM", placeholder="e.g. 120")
                    edit_key = gr.Textbox(label="Key/Scale", placeholder="e.g. Am")
                    edit_lang = gr.Textbox(label="Language", value="en")
                save_btn = gr.Button("Save Sidecar")
                save_msg = gr.Textbox(label="Save Result", interactive=False)
                save_btn.click(
                    save_sidecar,
                    [edit_idx, edit_caption, edit_lyrics, edit_bpm, edit_key, edit_lang],
                    save_msg,
                    api_name="save_sidecar",
                )
        # ---- Step 3 ----
        with gr.Tab("Step 3 - Train LoRA"):
            gr.Markdown(
                "### Instructions\n"
                "1. Keep default settings for first run.\n"
                "2. Set output directory (defaults are good).\n"
                "3. Click **Start Training** and monitor logs/loss.\n"
                "4. Use **Stop Training** for graceful stop."
            )
            with gr.Row():
                t_epochs = gr.Slider(1, 500, value=50, step=1, label="Epochs")
                t_bs = gr.Slider(1, 8, value=1, step=1, label="Batch Size")
                t_accum = gr.Slider(1, 16, value=1, step=1, label="Grad Accumulation")
            with gr.Row():
                t_outdir = gr.Textbox(label="Output Directory", value=DEFAULT_OUTPUT_DIR)
                t_resume = gr.Textbox(label="Resume From Adapter Directory (optional)", value="")
            with gr.Accordion("Advanced Training Settings (optional)", open=False):
                with gr.Row():
                    t_rank = gr.Slider(4, 256, value=64, step=4, label="LoRA Rank")
                    t_alpha = gr.Slider(4, 256, value=64, step=4, label="LoRA Alpha")
                    t_dropout = gr.Slider(0.0, 0.5, value=0.1, step=0.01, label="LoRA Dropout")
                with gr.Row():
                    t_lr = gr.Number(label="Learning Rate", value=1e-4)
                    t_wd = gr.Number(label="Weight Decay", value=0.01)
                    t_optim = gr.Dropdown(["adamw", "adamw_8bit"], value="adamw_8bit", label="Optimizer")
                with gr.Row():
                    t_grad_norm = gr.Number(label="Max Grad Norm", value=1.0)
                    t_warmup = gr.Number(label="Warmup Ratio", value=0.03)
                    t_sched = gr.Dropdown(
                        ["constant_with_warmup", "linear", "cosine"],
                        value="constant_with_warmup",
                        label="Scheduler",
                    )
                with gr.Row():
                    t_save = gr.Slider(1, 100, value=10, step=1, label="Save Every N Epochs")
                    t_log = gr.Slider(1, 100, value=5, step=1, label="Log Every N Steps")
                    t_shift = gr.Number(label="Timestep Shift", value=3.0)
                    t_maxdur = gr.Number(label="Max Audio Duration (s)", value=240)
            with gr.Row():
                train_btn = gr.Button("Start Training", variant="primary")
                stop_btn = gr.Button("Stop Training", variant="stop")
                poll_btn = gr.Button("Refresh Log")
            train_status = gr.Textbox(label="Training Log", lines=12, interactive=False)
            loss_chart = gr.LinePlot(
                x="Step",
                y="Loss",
                title="Training Loss",
                x_title="Step",
                y_title="Loss",
            )
            train_btn.click(
                start_training,
                [
                    t_rank, t_alpha, t_dropout,
                    t_lr, t_wd, t_optim,
                    t_grad_norm, t_warmup, t_sched,
                    t_epochs, t_bs, t_accum,
                    t_save, t_log, t_shift,
                    t_maxdur, t_outdir, t_resume,
                ],
                train_status,
                api_name="start_training",
            )
            stop_btn.click(stop_training, outputs=train_status, api_name="stop_training")

            def _poll_and_format():
                # Adapt poll_training()'s [[step, loss], ...] pairs into the
                # DataFrame shape gr.LinePlot expects.
                log_text, chart_data = poll_training()
                if chart_data:
                    import pandas as pd
                    df = pd.DataFrame(chart_data, columns=["Step", "Loss"])
                else:
                    import pandas as pd
                    df = pd.DataFrame({"Step": [], "Loss": []})
                return log_text, df

            poll_btn.click(_poll_and_format, outputs=[train_status, loss_chart], api_name="poll_training")
        # ---- Step 4 ----
        with gr.Tab("Step 4 - Evaluate"):
            gr.Markdown(
                "### Instructions\n"
                "1. Refresh adapter list and load a trained adapter.\n"
                "2. Run single generation or A/B test.\n"
                "3. Use same seed for fair comparison."
            )
            with gr.Accordion("Adapter Management", open=True):
                with gr.Row():
                    adapter_dir = gr.Textbox(label="Adapters Directory", value=DEFAULT_OUTPUT_DIR)
                    refresh_btn = gr.Button("Refresh List")
                adapter_dd = gr.Dropdown(label="Select Adapter", choices=[])
                with gr.Row():
                    upload_adapter_files_input = gr.Files(
                        label="Upload LoRA Adapter (.zip or adapter files)",
                        file_count="multiple",
                        file_types=[".zip", ".json", ".safetensors", ".bin", ".pt", ".pth"],
                        type="filepath",
                    )
                    upload_adapter_name = gr.Textbox(
                        label="Uploaded Adapter Name (optional)",
                        placeholder="my-lora-adapter",
                    )
                    upload_adapter_btn = gr.Button("Upload Adapter")
                with gr.Row():
                    load_btn = gr.Button("Load Adapter", variant="primary")
                    unload_btn = gr.Button("Unload Adapter")
                adapter_status = gr.Textbox(label="Adapter Status", interactive=False)

                def _refresh(d):
                    # Refresh dropdown choices from the adapters directory.
                    adapters = list_adapters(d)
                    return gr.update(choices=adapters, value=adapters[0] if adapters else None)

                refresh_btn.click(_refresh, adapter_dir, adapter_dd, api_name="list_adapters")
                upload_adapter_btn.click(
                    upload_adapter_files,
                    [upload_adapter_files_input, adapter_dir, upload_adapter_name],
                    [adapter_status, adapter_dd],
                    api_name="upload_adapter_files",
                )
                load_btn.click(load_adapter, adapter_dd, adapter_status, api_name="load_adapter")
                unload_btn.click(unload_adapter, outputs=adapter_status, api_name="unload_adapter")
            with gr.Accordion("Generation Settings", open=True):
                with gr.Row():
                    eval_prompt = gr.Textbox(label="Prompt / Caption", lines=2, placeholder="upbeat pop rock with electric guitar")
                    eval_lyrics = gr.Textbox(label="Lyrics", lines=3, placeholder="[Instrumental]")
                with gr.Row():
                    eval_dur = gr.Slider(10, 300, value=30, step=5, label="Duration (s)")
                    eval_bpm = gr.Number(label="BPM (0 = auto)", value=0)
                    eval_steps = gr.Slider(1, 100, value=8, step=1, label="Inference Steps")
                with gr.Row():
                    eval_guidance = gr.Slider(1.0, 15.0, value=7.0, step=0.5, label="Guidance Scale")
                    eval_seed = gr.Number(label="Seed (-1 = random)", value=-1)
                with gr.Row():
                    sg_use_lora = gr.Checkbox(label="Use LoRA", value=True)
                    sg_scale = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="LoRA Scale")
                sg_btn = gr.Button("Generate", variant="primary")
                sg_audio = gr.Audio(label="Single Output", type="filepath")
                sg_msg = gr.Textbox(label="Generation Status", interactive=False)
                sg_btn.click(
                    generate_sample,
                    [eval_prompt, eval_lyrics, eval_dur, eval_bpm, eval_steps, eval_guidance, eval_seed, sg_use_lora, sg_scale],
                    [sg_audio, sg_msg],
                    api_name="generate_sample",
                )
                gr.Markdown("#### A/B Test (Base vs LoRA)")
                with gr.Row():
                    ab_scale = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="LoRA Scale for B")
                    ab_btn = gr.Button("Run A/B Test")
                with gr.Row():
                    ab_audio_a = gr.Audio(label="A - Base", type="filepath")
                    ab_audio_b = gr.Audio(label="B - Base + LoRA", type="filepath")
                with gr.Row():
                    ab_msg_a = gr.Textbox(label="Status A", interactive=False)
                    ab_msg_b = gr.Textbox(label="Status B", interactive=False)
                ab_btn.click(
                    ab_test,
                    [eval_prompt, eval_lyrics, eval_dur, eval_bpm, eval_steps, eval_guidance, eval_seed, ab_scale],
                    [ab_audio_a, ab_msg_a, ab_audio_b, ab_msg_b],
                    api_name="ab_test",
                )
    # Serialize callbacks: the model/trainer globals are not concurrency-safe.
    app.queue(default_concurrency_limit=1)
    return app
# ===========================================================================
# Entry point
# ===========================================================================
if __name__ == "__main__":
    app = build_ui()
    # Bind to all interfaces on the conventional Gradio port so the app is
    # reachable from outside a container / Space.
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )