Spaces:
Running
Running
add LoRA training, fix css kwarg
Browse files- Dockerfile +7 -2
- app.py +147 -14
Dockerfile
CHANGED
|
@@ -69,8 +69,13 @@ RUN curl -fL --retry 3 --retry-delay 5 -o /app/models/Qwen3-Embedding-0.6B-Q8_0.
|
|
| 69 |
RUN curl -fL --retry 3 --retry-delay 5 -o /app/models/vae-BF16.gguf \
|
| 70 |
"https://huggingface.co/Serveurperso/ACE-Step-1.5-GGUF/resolve/main/vae-BF16.gguf"
|
| 71 |
|
| 72 |
-
# Install Python deps for Gradio UI
|
| 73 |
-
RUN pip3 install --no-cache-dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
# Copy application files
|
| 76 |
COPY app.py /app/app.py
|
|
|
|
| 69 |
RUN curl -fL --retry 3 --retry-delay 5 -o /app/models/vae-BF16.gguf \
|
| 70 |
"https://huggingface.co/Serveurperso/ACE-Step-1.5-GGUF/resolve/main/vae-BF16.gguf"
|
| 71 |
|
| 72 |
+
# Install Python deps for Gradio UI + training
|
| 73 |
+
# Version specifiers are quoted on purpose: this is a shell-form RUN, and an
# unquoted `>=` would be parsed by /bin/sh as an output redirection, silently
# dropping the constraint and writing pip's stdout to a file named "=<version>".
RUN pip3 install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu \
    gradio==5.29.0 requests torch safetensors \
    "transformers>=4.51.0" "peft>=0.18.0" "accelerate>=1.12.0"
|
| 76 |
+
|
| 77 |
+
# Clone ACE-Step repo for training module
|
| 78 |
+
RUN git clone --depth 1 https://github.com/ace-step/ACE-Step-1.5 /app/ace-step-source
|
| 79 |
|
| 80 |
# Copy application files
|
| 81 |
COPY app.py /app/app.py
|
app.py
CHANGED
|
@@ -12,6 +12,11 @@ ACE_SERVER = os.environ.get("ACE_SERVER", "http://127.0.0.1:8085")
|
|
| 12 |
OUTPUT_DIR = os.environ.get("ACE_OUTPUT_DIR", "/app/outputs")
|
| 13 |
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
# ---------------------------------------------------------------------------
|
| 16 |
# ace-server helpers
|
| 17 |
# ---------------------------------------------------------------------------
|
|
@@ -280,14 +285,143 @@ def gradio_main():
|
|
| 280 |
lines.append(json.dumps(props, indent=2))
|
| 281 |
return "\n".join(lines)
|
| 282 |
|
| 283 |
-
# -- Training
|
| 284 |
-
def
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
|
| 292 |
# -- Build UI --
|
| 293 |
CSS = """
|
|
@@ -295,7 +429,7 @@ def gradio_main():
|
|
| 295 |
.status-box textarea { font-family: monospace; font-size: 13px; }
|
| 296 |
"""
|
| 297 |
|
| 298 |
-
with gr.Blocks(title="ACE-Step 1.5 XL (CPU)") as demo:
|
| 299 |
|
| 300 |
with gr.Tabs():
|
| 301 |
# ============================================================
|
|
@@ -373,7 +507,7 @@ def gradio_main():
|
|
| 373 |
gr.Markdown(
|
| 374 |
"### LoRA Training\n"
|
| 375 |
"Fine-tune ACE-Step on your own audio data. "
|
| 376 |
-
"
|
| 377 |
)
|
| 378 |
|
| 379 |
with gr.Row(elem_classes="compact-row"):
|
|
@@ -385,9 +519,9 @@ def gradio_main():
|
|
| 385 |
)
|
| 386 |
with gr.Column(scale=1):
|
| 387 |
lora_name = gr.Textbox(label="LoRA Name", value="my-lora")
|
| 388 |
-
epochs = gr.Number(label="Epochs", value=
|
| 389 |
lr = gr.Number(label="Learning Rate", value=1e-4)
|
| 390 |
-
rank = gr.Number(label="Rank (r)", value=16, minimum=1, maximum=
|
| 391 |
|
| 392 |
train_btn = gr.Button("Train", variant="primary")
|
| 393 |
train_log = gr.Textbox(
|
|
@@ -398,7 +532,7 @@ def gradio_main():
|
|
| 398 |
)
|
| 399 |
|
| 400 |
train_btn.click(
|
| 401 |
-
fn=
|
| 402 |
inputs=[train_audio, lora_name, epochs, lr, rank],
|
| 403 |
outputs=[train_log],
|
| 404 |
api_name="train_lora",
|
|
@@ -408,7 +542,6 @@ def gradio_main():
|
|
| 408 |
server_name="0.0.0.0",
|
| 409 |
server_port=7860,
|
| 410 |
mcp_server=True,
|
| 411 |
-
css=CSS,
|
| 412 |
)
|
| 413 |
|
| 414 |
|
|
|
|
| 12 |
OUTPUT_DIR = os.environ.get("ACE_OUTPUT_DIR", "/app/outputs")
|
| 13 |
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 14 |
|
| 15 |
+
ACE_CHECKPOINT_DIR = os.environ.get("ACE_CHECKPOINT_DIR", "/app/checkpoints")
|
| 16 |
+
ACE_SOURCE_DIR = "/app/ace-step-source"
|
| 17 |
+
ACE_HF_MODEL = "ACE-Step/Ace-Step1.5"
|
| 18 |
+
ADAPTER_DIR = os.environ.get("ACE_ADAPTER_DIR", "/app/adapters")
|
| 19 |
+
|
| 20 |
# ---------------------------------------------------------------------------
|
| 21 |
# ace-server helpers
|
| 22 |
# ---------------------------------------------------------------------------
|
|
|
|
| 285 |
lines.append(json.dumps(props, indent=2))
|
| 286 |
return "\n".join(lines)
|
| 287 |
|
| 288 |
+
# -- Training --
|
| 289 |
+
def train_lora(audio_files, lora_name, epochs, lr, rank,
               progress=gr.Progress(track_tqdm=True)):
    """Train a LoRA adapter on the uploaded audio files (CPU, float32).

    The uploads are copied to ``ADAPTER_DIR/<lora_name>/audio_input``. If
    ``ACE_CHECKPOINT_DIR`` holds fewer than three entries the base checkpoints
    are downloaded from ``ACE_HF_MODEL`` first. The audio is then preprocessed
    into ``preprocessed_tensors`` and the trainer is run with ``output_dir``
    set to the adapter directory.

    Args:
        audio_files: Uploaded files. Each item is either an object exposing
            the file path as ``.name`` or a value whose ``str()`` is the path.
        lora_name: Name of the adapter; used as a single directory name under
            ``ADAPTER_DIR``. Blank or unusable names fall back to "my-lora".
        epochs: Number of epochs; clamped to the range 1..10.
        lr: Learning rate; must be a positive, finite number.
        rank: LoRA rank; clamped to the range 1..64.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        A newline-joined text log of the run. Failures are reported in the
        log (including a traceback) instead of being raised.
    """
    import gc
    import shutil
    import traceback

    if not audio_files:
        return "No audio files uploaded."

    # The name arrives from the UI / public API and becomes a directory under
    # ADAPTER_DIR. Keep only its final path component so that values such as
    # "../x" or "/etc" cannot point output_dir outside ADAPTER_DIR
    # (os.path.join discards ADAPTER_DIR when given an absolute component).
    lora_name = str(lora_name or "").strip().replace("\\", "/").rstrip("/")
    lora_name = os.path.basename(lora_name)
    if lora_name in ("", ".", ".."):
        lora_name = "my-lora"

    # A cleared Number field reaches us as None; report it instead of raising.
    try:
        epochs = max(1, min(int(epochs), 10))
        lr = float(lr)
        rank = max(1, min(int(rank), 64))
    except (TypeError, ValueError, OverflowError):
        return "Epochs, learning rate and rank must be numbers."
    if not 0 < lr < float("inf"):  # also rejects NaN
        return "Learning rate must be a positive, finite number."

    output_dir = os.path.join(ADAPTER_DIR, lora_name)
    os.makedirs(output_dir, exist_ok=True)

    audio_dir = os.path.join(output_dir, "audio_input")
    os.makedirs(audio_dir, exist_ok=True)
    for f in audio_files:
        src = f.name if hasattr(f, "name") else str(f)
        shutil.copy2(src, os.path.join(audio_dir, os.path.basename(src)))

    log_lines = [
        f"LoRA Training: '{lora_name}'",
        f"Audio files: {len(audio_files)}",
        f"Epochs: {epochs}, LR: {lr}, Rank: {rank}",
        f"Output: {output_dir}",
        "",
    ]

    # Bound before the try block so the finally clause can always release them.
    model = None
    trainer = None

    try:
        ckpt_files = os.listdir(ACE_CHECKPOINT_DIR) if os.path.isdir(ACE_CHECKPOINT_DIR) else []
        if len(ckpt_files) < 3:
            log_lines.append("[Step 0] Downloading model checkpoints...")
            progress(0.02, desc="Downloading checkpoints...")
            from huggingface_hub import snapshot_download
            snapshot_download(
                ACE_HF_MODEL,
                local_dir=ACE_CHECKPOINT_DIR,
                ignore_patterns=["*.md", "*.txt", ".gitattributes"],
            )
            log_lines.append("  Checkpoints downloaded.")

        # The training package lives in the cloned source tree, not site-packages.
        if ACE_SOURCE_DIR not in sys.path:
            sys.path.insert(0, ACE_SOURCE_DIR)

        log_lines.append("[Step 1/2] Preprocessing audio files...")
        progress(0.10, desc="Preprocessing audio...")

        tensor_dir = os.path.join(output_dir, "preprocessed_tensors")
        os.makedirs(tensor_dir, exist_ok=True)

        from acestep.training_v2.preprocess import preprocess_audio_files
        result = preprocess_audio_files(
            audio_dir=audio_dir,
            output_dir=tensor_dir,
            checkpoint_dir=ACE_CHECKPOINT_DIR,
            variant="turbo",
            max_duration=60.0,
            device="cpu",
            precision="float32",
        )

        processed = result.get("processed", 0)
        total_files = result.get("total", 0)
        failed = result.get("failed", 0)
        log_lines.append(f"  Preprocessed: {processed}/{total_files} (failed: {failed})")

        if processed == 0:
            log_lines.append("ERROR: No files preprocessed successfully.")
            return "\n".join(log_lines)

        log_lines.append("[Step 2/2] Training LoRA adapter (CPU, this will be slow)...")
        progress(0.30, desc="Loading model for training...")

        from acestep.training_v2.model_loader import load_decoder_for_training
        from acestep.training_v2.trainer_fixed import FixedLoRATrainer
        from acestep.training_v2.configs import TrainingConfigV2, LoRAConfigV2

        model = load_decoder_for_training(
            checkpoint_dir=ACE_CHECKPOINT_DIR,
            variant="turbo",
            device="cpu",
            precision="float32",
        )
        model = model.float()

        adapter_cfg = LoRAConfigV2(r=rank, alpha=rank, dropout=0.0)
        train_cfg = TrainingConfigV2(
            checkpoint_dir=ACE_CHECKPOINT_DIR,
            model_variant="turbo",
            dataset_dir=tensor_dir,
            output_dir=output_dir,
            max_epochs=epochs,
            batch_size=1,
            learning_rate=lr,
            device="cpu",
            precision="float32",
            seed=42,
            num_workers=0,
            pin_memory=False,
        )

        trainer = FixedLoRATrainer(model, adapter_cfg, train_cfg)

        # With batch_size=1 this is a rough estimate of the total step count;
        # it only scales the progress bar.
        expected_steps = max(epochs * processed, 1)

        step_count = 0
        last_loss = 0.0
        for update in trainer.train():
            # The trainer may yield objects with .step/.loss or (step, loss, ...) tuples.
            if hasattr(update, "step"):
                step_count = update.step
                last_loss = update.loss
            elif isinstance(update, tuple) and len(update) >= 2:
                step_count = update[0]
                last_loss = update[1]
            # Only the text log is throttled; the progress bar follows every step.
            if step_count % 5 == 0:
                log_lines.append(f"  Step {step_count}: loss={last_loss:.4f}")
            pct = 0.30 + 0.65 * min(step_count / expected_steps, 1.0)
            progress(pct, desc=f"Step {step_count}, loss={last_loss:.4f}")

        log_lines.append(f"Training complete! Final: step {step_count}, loss={last_loss:.4f}")
        log_lines.append(f"LoRA saved to: {output_dir}")

    except ImportError as e:
        log_lines.append(f"Import error: {e}")
        log_lines.append(f"Check ACE-Step source at {ACE_SOURCE_DIR}")
        log_lines.append(traceback.format_exc())
    except Exception as e:
        # UI boundary: surface the failure in the log box rather than crashing the request.
        log_lines.append(f"ERROR: {e}")
        log_lines.append(traceback.format_exc())
    finally:
        # Release the model on failure paths too; it is large and the process is long-lived.
        model = None
        trainer = None
        gc.collect()

    return "\n".join(log_lines)
|
| 425 |
|
| 426 |
# -- Build UI --
|
| 427 |
CSS = """
|
|
|
|
| 429 |
.status-box textarea { font-family: monospace; font-size: 13px; }
|
| 430 |
"""
|
| 431 |
|
| 432 |
+
with gr.Blocks(title="ACE-Step 1.5 XL (CPU)", css=CSS) as demo:
|
| 433 |
|
| 434 |
with gr.Tabs():
|
| 435 |
# ============================================================
|
|
|
|
| 507 |
gr.Markdown(
|
| 508 |
"### LoRA Training\n"
|
| 509 |
"Fine-tune ACE-Step on your own audio data. "
|
| 510 |
+
"CPU training is very slow. Checkpoints downloaded on first run (~10GB)."
|
| 511 |
)
|
| 512 |
|
| 513 |
with gr.Row(elem_classes="compact-row"):
|
|
|
|
| 519 |
)
|
| 520 |
with gr.Column(scale=1):
|
| 521 |
lora_name = gr.Textbox(label="LoRA Name", value="my-lora")
|
| 522 |
+
epochs = gr.Number(label="Epochs", value=5, minimum=1, maximum=10)
|
| 523 |
lr = gr.Number(label="Learning Rate", value=1e-4)
|
| 524 |
+
rank = gr.Number(label="Rank (r)", value=16, minimum=1, maximum=64)
|
| 525 |
|
| 526 |
train_btn = gr.Button("Train", variant="primary")
|
| 527 |
train_log = gr.Textbox(
|
|
|
|
| 532 |
)
|
| 533 |
|
| 534 |
train_btn.click(
|
| 535 |
+
fn=train_lora,
|
| 536 |
inputs=[train_audio, lora_name, epochs, lr, rank],
|
| 537 |
outputs=[train_log],
|
| 538 |
api_name="train_lora",
|
|
|
|
| 542 |
server_name="0.0.0.0",
|
| 543 |
server_port=7860,
|
| 544 |
mcp_server=True,
|
|
|
|
| 545 |
)
|
| 546 |
|
| 547 |
|