Nekochu committed
Commit 625132a · 1 Parent(s): 72e4b69

add LoRA training, fix css kwarg

Files changed (2):
  1. Dockerfile +7 -2
  2. app.py +147 -14
Dockerfile CHANGED
@@ -69,8 +69,13 @@ RUN curl -fL --retry 3 --retry-delay 5 -o /app/models/Qwen3-Embedding-0.6B-Q8_0.
 RUN curl -fL --retry 3 --retry-delay 5 -o /app/models/vae-BF16.gguf \
     "https://huggingface.co/Serveurperso/ACE-Step-1.5-GGUF/resolve/main/vae-BF16.gguf"
 
-# Install Python deps for Gradio UI
-RUN pip3 install --no-cache-dir gradio==5.29.0 requests
+# Install Python deps for Gradio UI + training
+# (version-range specs are quoted so the shell does not parse ">=" as a redirection)
+RUN pip3 install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu \
+    gradio==5.29.0 requests torch safetensors \
+    "transformers>=4.51.0" "peft>=0.18.0" "accelerate>=1.12.0"
+
+# Clone ACE-Step repo for training module
+RUN git clone --depth 1 https://github.com/ace-step/ACE-Step-1.5 /app/ace-step-source
 
 # Copy application files
 COPY app.py /app/app.py
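
Since torch is installed unpinned from the CPU wheel index, a quick import check inside the built image confirms the training stack resolved as intended. A minimal sketch in Python (the script name and docker invocation are illustrative, not part of this commit):

    # smoke_test.py -- hypothetical check; run inside the image, e.g.:
    #   docker run --rm <image> python3 smoke_test.py
    import torch, transformers, peft, accelerate, safetensors

    print("torch:", torch.__version__)                   # expect a "+cpu" build from the CPU wheel index
    print("cuda available:", torch.cuda.is_available())  # expect False on a CPU-only Space
    print("transformers:", transformers.__version__)     # should satisfy >=4.51.0
    print("peft:", peft.__version__)                     # should satisfy >=0.18.0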
app.py CHANGED
@@ -12,6 +12,11 @@ ACE_SERVER = os.environ.get("ACE_SERVER", "http://127.0.0.1:8085")
 OUTPUT_DIR = os.environ.get("ACE_OUTPUT_DIR", "/app/outputs")
 os.makedirs(OUTPUT_DIR, exist_ok=True)
 
+ACE_CHECKPOINT_DIR = os.environ.get("ACE_CHECKPOINT_DIR", "/app/checkpoints")
+ACE_SOURCE_DIR = "/app/ace-step-source"
+ACE_HF_MODEL = "ACE-Step/Ace-Step1.5"
+ADAPTER_DIR = os.environ.get("ACE_ADAPTER_DIR", "/app/adapters")
+
 # ---------------------------------------------------------------------------
 # ace-server helpers
 # ---------------------------------------------------------------------------
@@ -280,14 +285,143 @@ def gradio_main():
         lines.append(json.dumps(props, indent=2))
         return "\n".join(lines)
 
-    # -- Training placeholder --
-    def train_lora_placeholder(*args):
-        return ("Training requires PyTorch and the ACE-Step Python package.\n\n"
-                "To enable training, install dependencies:\n"
-                " pip install torch torchaudio safetensors transformers "
-                "diffusers peft accelerate einops\n\n"
-                "Then restart the app. Training is not available on the "
-                "CPU-only HF Space — use a local GPU machine or a GPU Space.")
+    # -- Training --
+    def train_lora(audio_files, lora_name, epochs, lr, rank,
+                   progress=gr.Progress(track_tqdm=True)):
+        import shutil
+        import gc
+
+        if not audio_files:
+            return "No audio files uploaded."
+
+        lora_name = (lora_name or "").strip() or "my-lora"
+        epochs = max(1, min(int(epochs), 10))
+        lr = float(lr)
+        rank = max(1, min(int(rank), 64))
+
+        output_dir = os.path.join(ADAPTER_DIR, lora_name)
+        os.makedirs(output_dir, exist_ok=True)
+
+        audio_dir = os.path.join(output_dir, "audio_input")
+        os.makedirs(audio_dir, exist_ok=True)
+        for f in audio_files:
+            src = f.name if hasattr(f, "name") else str(f)
+            shutil.copy2(src, os.path.join(audio_dir, os.path.basename(src)))
+
+        log_lines = [
+            f"LoRA Training: '{lora_name}'",
+            f"Audio files: {len(audio_files)}",
+            f"Epochs: {epochs}, LR: {lr}, Rank: {rank}",
+            f"Output: {output_dir}",
+            "",
+        ]
+
+        try:
+            ckpt_files = os.listdir(ACE_CHECKPOINT_DIR) if os.path.isdir(ACE_CHECKPOINT_DIR) else []
+            if len(ckpt_files) < 3:
+                log_lines.append("[Step 0] Downloading model checkpoints...")
+                progress(0.02, desc="Downloading checkpoints...")
+                from huggingface_hub import snapshot_download
+                snapshot_download(
+                    ACE_HF_MODEL,
+                    local_dir=ACE_CHECKPOINT_DIR,
+                    ignore_patterns=["*.md", "*.txt", ".gitattributes"],
+                )
+                log_lines.append(" Checkpoints downloaded.")
+
+            if ACE_SOURCE_DIR not in sys.path:
+                sys.path.insert(0, ACE_SOURCE_DIR)
+
+            log_lines.append("[Step 1/2] Preprocessing audio files...")
+            progress(0.10, desc="Preprocessing audio...")
+
+            tensor_dir = os.path.join(output_dir, "preprocessed_tensors")
+            os.makedirs(tensor_dir, exist_ok=True)
+
+            from acestep.training_v2.preprocess import preprocess_audio_files
+            result = preprocess_audio_files(
+                audio_dir=audio_dir,
+                output_dir=tensor_dir,
+                checkpoint_dir=ACE_CHECKPOINT_DIR,
+                variant="turbo",
+                max_duration=60.0,
+                device="cpu",
+                precision="float32",
+            )
+
+            processed = result.get("processed", 0)
+            total_files = result.get("total", 0)
+            failed = result.get("failed", 0)
+            log_lines.append(f" Preprocessed: {processed}/{total_files} (failed: {failed})")
+
+            if processed == 0:
+                log_lines.append("ERROR: No files preprocessed successfully.")
+                return "\n".join(log_lines)
+
+            log_lines.append("[Step 2/2] Training LoRA adapter (CPU, this will be slow)...")
+            progress(0.30, desc="Loading model for training...")
+
+            from acestep.training_v2.model_loader import load_decoder_for_training
+            from acestep.training_v2.trainer_fixed import FixedLoRATrainer
+            from acestep.training_v2.configs import TrainingConfigV2, LoRAConfigV2
+
+            model = load_decoder_for_training(
+                checkpoint_dir=ACE_CHECKPOINT_DIR,
+                variant="turbo",
+                device="cpu",
+                precision="float32",
+            )
+            model = model.float()
+
+            adapter_cfg = LoRAConfigV2(r=rank, alpha=rank, dropout=0.0)
+            train_cfg = TrainingConfigV2(
+                checkpoint_dir=ACE_CHECKPOINT_DIR,
+                model_variant="turbo",
+                dataset_dir=tensor_dir,
+                output_dir=output_dir,
+                max_epochs=epochs,
+                batch_size=1,
+                learning_rate=lr,
+                device="cpu",
+                precision="float32",
+                seed=42,
+                num_workers=0,
+                pin_memory=False,
+            )
+
+            trainer = FixedLoRATrainer(model, adapter_cfg, train_cfg)
+
+            step_count = 0
+            last_loss = 0.0
+            for update in trainer.train():
+                if hasattr(update, "step"):
+                    step_count = update.step
+                    last_loss = update.loss
+                elif isinstance(update, tuple) and len(update) >= 2:
+                    step_count = update[0]
+                    last_loss = update[1]
+                if step_count % 5 == 0:
+                    log_lines.append(f" Step {step_count}: loss={last_loss:.4f}")
+                pct = 0.30 + 0.65 * min(step_count / max(epochs * processed, 1), 1.0)
+                progress(pct, desc=f"Step {step_count}, loss={last_loss:.4f}")
+
+            log_lines.append(f"Training complete! Final: step {step_count}, loss={last_loss:.4f}")
+            log_lines.append(f"LoRA saved to: {output_dir}")
+
+            del model, trainer
+            gc.collect()
+
+        except ImportError as e:
+            log_lines.append(f"Import error: {e}")
+            log_lines.append(f"Check ACE-Step source at {ACE_SOURCE_DIR}")
+            import traceback
+            log_lines.append(traceback.format_exc())
+        except Exception as e:
+            import traceback
+            log_lines.append(f"ERROR: {e}")
+            log_lines.append(traceback.format_exc())
+
+        return "\n".join(log_lines)
 
     # -- Build UI --
     CSS = """
@@ -295,7 +429,7 @@
     .status-box textarea { font-family: monospace; font-size: 13px; }
     """
 
-    with gr.Blocks(title="ACE-Step 1.5 XL (CPU)") as demo:
+    with gr.Blocks(title="ACE-Step 1.5 XL (CPU)", css=CSS) as demo:
 
         with gr.Tabs():
             # ============================================================
@@ -373,7 +507,7 @@
             gr.Markdown(
                 "### LoRA Training\n"
                 "Fine-tune ACE-Step on your own audio data. "
-                "Requires PyTorch + GPU (not available on CPU Spaces)."
+                "CPU training is very slow. Checkpoints downloaded on first run (~10GB)."
            )
 
            with gr.Row(elem_classes="compact-row"):
@@ -385,9 +519,9 @@
            )
            with gr.Column(scale=1):
                lora_name = gr.Textbox(label="LoRA Name", value="my-lora")
-                epochs = gr.Number(label="Epochs", value=100, minimum=1, maximum=10000)
+                epochs = gr.Number(label="Epochs", value=5, minimum=1, maximum=10)
                lr = gr.Number(label="Learning Rate", value=1e-4)
-                rank = gr.Number(label="Rank (r)", value=16, minimum=1, maximum=256)
+                rank = gr.Number(label="Rank (r)", value=16, minimum=1, maximum=64)
 
            train_btn = gr.Button("Train", variant="primary")
            train_log = gr.Textbox(
@@ -398,7 +532,7 @@
            )
 
            train_btn.click(
-                fn=train_lora_placeholder,
+                fn=train_lora,
                inputs=[train_audio, lora_name, epochs, lr, rank],
                outputs=[train_log],
                api_name="train_lora",
@@ -408,7 +542,6 @@
        server_name="0.0.0.0",
        server_port=7860,
        mcp_server=True,
-        css=CSS,
    )
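
Because the click handler registers api_name="train_lora" on a running Gradio app, the new endpoint is also callable programmatically, not just from the UI. A minimal sketch using gradio_client (installed separately with pip install gradio_client; the host, file names, and argument values here are assumptions, and the server-side clamps come from train_lora above):

    # Hypothetical client-side call to the /train_lora endpoint; assumes the
    # container is reachable on localhost:7860 and the sample WAVs exist.
    from gradio_client import Client, handle_file

    client = Client("http://127.0.0.1:7860")
    log = client.predict(
        [handle_file("vocals_01.wav"), handle_file("vocals_02.wav")],  # train_audio
        "my-lora",  # lora_name
        5,          # epochs (server clamps to 1..10)
        1e-4,       # learning rate
        16,         # rank (server clamps to 1..64)
        api_name="/train_lora",
    )
    print(log)  # full training log text returned by train_lora()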