Spaces:
Running
Running
add LoRA download button after training (gr.File output, like rvc-beatrice)
Browse files
app.py
CHANGED
|
@@ -460,12 +460,12 @@ def gradio_main():
|
|
| 460 |
# -- Validation --
|
| 461 |
if not audio_files:
|
| 462 |
_log("[FAIL] No audio files uploaded.")
|
| 463 |
-
yield _log_text(), gr.update(visible=True), gr.update(visible=False)
|
| 464 |
return
|
| 465 |
|
| 466 |
if len(audio_files) > MAX_AUDIO_FILES:
|
| 467 |
_log(f"[FAIL] Too many files ({len(audio_files)}). Max: {MAX_AUDIO_FILES}")
|
| 468 |
-
yield _log_text(), gr.update(visible=True), gr.update(visible=False)
|
| 469 |
return
|
| 470 |
|
| 471 |
lora_name = (lora_name or "").strip() or "my-lora"
|
|
@@ -485,7 +485,7 @@ def gradio_main():
|
|
| 485 |
|
| 486 |
# Copy uploaded audio files
|
| 487 |
_log(f"[INFO] Preparing {len(audio_files)} audio files...")
|
| 488 |
-
yield _log_text(), gr.update(visible=False), gr.update(visible=True)
|
| 489 |
|
| 490 |
for f in audio_files:
|
| 491 |
src = f.name if hasattr(f, "name") else str(f)
|
|
@@ -493,18 +493,18 @@ def gradio_main():
|
|
| 493 |
|
| 494 |
_log(f"[INFO] LoRA: '{lora_name}' | Files: {len(audio_files)} | "
|
| 495 |
f"Epochs: {epochs} | LR: {lr} | Rank: {rank}")
|
| 496 |
-
yield _log_text(), gr.update(visible=False), gr.update(visible=True)
|
| 497 |
|
| 498 |
# Stop ace-server before training (frees memory)
|
| 499 |
_log("[INFO] Stopping ace-server for training...")
|
| 500 |
-
yield _log_text(), gr.update(visible=False), gr.update(visible=True)
|
| 501 |
_stop_ace_server()
|
| 502 |
_gc.collect()
|
| 503 |
|
| 504 |
try:
|
| 505 |
# -- Phase 1: Preprocessing --
|
| 506 |
_log("[Step 1/2] Preprocessing audio...")
|
| 507 |
-
yield _log_text(), gr.update(visible=False), gr.update(visible=True)
|
| 508 |
|
| 509 |
preprocessed_dir = os.path.join(work_dir, "preprocessed_tensors")
|
| 510 |
|
|
@@ -521,24 +521,24 @@ def gradio_main():
|
|
| 521 |
progress_callback=preprocess_progress,
|
| 522 |
cancel_check=lambda: False,
|
| 523 |
)
|
| 524 |
-
yield _log_text(), gr.update(visible=False), gr.update(visible=True)
|
| 525 |
|
| 526 |
processed = result.get("processed", 0)
|
| 527 |
failed = result.get("failed", 0)
|
| 528 |
total = result.get("total", 0)
|
| 529 |
_log(f"[OK] Preprocessed: {processed}/{total} (failed: {failed})")
|
| 530 |
-
yield _log_text(), gr.update(visible=False), gr.update(visible=True)
|
| 531 |
|
| 532 |
if processed == 0:
|
| 533 |
_log("[FAIL] No files preprocessed successfully. Cannot train.")
|
| 534 |
-
yield _log_text(), gr.update(visible=True), gr.update(visible=False)
|
| 535 |
return
|
| 536 |
|
| 537 |
_gc.collect()
|
| 538 |
|
| 539 |
# -- Phase 2: Training --
|
| 540 |
_log("[Step 2/2] Training LoRA...")
|
| 541 |
-
yield _log_text(), gr.update(visible=False), gr.update(visible=True)
|
| 542 |
|
| 543 |
for msg in train_lora_generator(
|
| 544 |
dataset_dir=preprocessed_dir,
|
|
@@ -568,31 +568,35 @@ def gradio_main():
|
|
| 568 |
break
|
| 569 |
|
| 570 |
_log(msg)
|
| 571 |
-
yield _log_text(), gr.update(visible=False), gr.update(visible=True)
|
| 572 |
|
| 573 |
if msg.strip() == "[DONE]":
|
| 574 |
break
|
| 575 |
|
| 576 |
_log(f"[INFO] Total time: {time.time() - train_start:.0f}s")
|
| 577 |
-
yield _log_text(), gr.update(visible=False), gr.update(visible=True)
|
| 578 |
|
| 579 |
except Exception as exc:
|
| 580 |
_log(f"[FAIL] Training error: {exc}")
|
| 581 |
import traceback
|
| 582 |
_log(traceback.format_exc())
|
| 583 |
-
yield _log_text(), gr.update(visible=True), gr.update(visible=False)
|
| 584 |
|
| 585 |
finally:
|
| 586 |
# Always restart ace-server
|
| 587 |
_log("[INFO] Restarting ace-server...")
|
| 588 |
-
yield _log_text(), gr.update(visible=False), gr.update(visible=True)
|
| 589 |
_gc.collect()
|
| 590 |
ok = _start_ace_server()
|
| 591 |
if ok:
|
| 592 |
_log("[OK] ace-server restarted successfully")
|
| 593 |
else:
|
| 594 |
_log("[WARN] ace-server may not have restarted -- check logs")
|
| 595 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 596 |
|
| 597 |
# -- Cancel handler --
|
| 598 |
def _on_cancel():
|
|
@@ -728,18 +732,19 @@ def gradio_main():
|
|
| 728 |
cancel_btn = gr.Button("Cancel Training", variant="stop", visible=False, scale=1)
|
| 729 |
log_btn = gr.Button("Check Log", scale=1)
|
| 730 |
|
|
|
|
| 731 |
train_log = gr.Textbox(
|
| 732 |
label="Training Log",
|
| 733 |
interactive=False,
|
| 734 |
-
lines=
|
| 735 |
elem_classes="status-box",
|
| 736 |
)
|
| 737 |
|
| 738 |
-
# Training generator -- yields (log, train_btn, cancel_btn)
|
| 739 |
train_event = train_btn.click(
|
| 740 |
train_lora_ui,
|
| 741 |
inputs=[train_audio, lora_name, train_epochs, train_lr, train_rank],
|
| 742 |
-
outputs=[train_log, train_btn, cancel_btn],
|
| 743 |
api_name="train_lora",
|
| 744 |
concurrency_limit=1,
|
| 745 |
)
|
|
|
|
| 460 |
# -- Validation --
|
| 461 |
if not audio_files:
|
| 462 |
_log("[FAIL] No audio files uploaded.")
|
| 463 |
+
yield _log_text(), gr.update(visible=True), gr.update(visible=False), gr.update()
|
| 464 |
return
|
| 465 |
|
| 466 |
if len(audio_files) > MAX_AUDIO_FILES:
|
| 467 |
_log(f"[FAIL] Too many files ({len(audio_files)}). Max: {MAX_AUDIO_FILES}")
|
| 468 |
+
yield _log_text(), gr.update(visible=True), gr.update(visible=False), gr.update()
|
| 469 |
return
|
| 470 |
|
| 471 |
lora_name = (lora_name or "").strip() or "my-lora"
|
|
|
|
| 485 |
|
| 486 |
# Copy uploaded audio files
|
| 487 |
_log(f"[INFO] Preparing {len(audio_files)} audio files...")
|
| 488 |
+
yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
|
| 489 |
|
| 490 |
for f in audio_files:
|
| 491 |
src = f.name if hasattr(f, "name") else str(f)
|
|
|
|
| 493 |
|
| 494 |
_log(f"[INFO] LoRA: '{lora_name}' | Files: {len(audio_files)} | "
|
| 495 |
f"Epochs: {epochs} | LR: {lr} | Rank: {rank}")
|
| 496 |
+
yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
|
| 497 |
|
| 498 |
# Stop ace-server before training (frees memory)
|
| 499 |
_log("[INFO] Stopping ace-server for training...")
|
| 500 |
+
yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
|
| 501 |
_stop_ace_server()
|
| 502 |
_gc.collect()
|
| 503 |
|
| 504 |
try:
|
| 505 |
# -- Phase 1: Preprocessing --
|
| 506 |
_log("[Step 1/2] Preprocessing audio...")
|
| 507 |
+
yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
|
| 508 |
|
| 509 |
preprocessed_dir = os.path.join(work_dir, "preprocessed_tensors")
|
| 510 |
|
|
|
|
| 521 |
progress_callback=preprocess_progress,
|
| 522 |
cancel_check=lambda: False,
|
| 523 |
)
|
| 524 |
+
yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
|
| 525 |
|
| 526 |
processed = result.get("processed", 0)
|
| 527 |
failed = result.get("failed", 0)
|
| 528 |
total = result.get("total", 0)
|
| 529 |
_log(f"[OK] Preprocessed: {processed}/{total} (failed: {failed})")
|
| 530 |
+
yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
|
| 531 |
|
| 532 |
if processed == 0:
|
| 533 |
_log("[FAIL] No files preprocessed successfully. Cannot train.")
|
| 534 |
+
yield _log_text(), gr.update(visible=True), gr.update(visible=False), gr.update()
|
| 535 |
return
|
| 536 |
|
| 537 |
_gc.collect()
|
| 538 |
|
| 539 |
# -- Phase 2: Training --
|
| 540 |
_log("[Step 2/2] Training LoRA...")
|
| 541 |
+
yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
|
| 542 |
|
| 543 |
for msg in train_lora_generator(
|
| 544 |
dataset_dir=preprocessed_dir,
|
|
|
|
| 568 |
break
|
| 569 |
|
| 570 |
_log(msg)
|
| 571 |
+
yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
|
| 572 |
|
| 573 |
if msg.strip() == "[DONE]":
|
| 574 |
break
|
| 575 |
|
| 576 |
_log(f"[INFO] Total time: {time.time() - train_start:.0f}s")
|
| 577 |
+
yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
|
| 578 |
|
| 579 |
except Exception as exc:
|
| 580 |
_log(f"[FAIL] Training error: {exc}")
|
| 581 |
import traceback
|
| 582 |
_log(traceback.format_exc())
|
| 583 |
+
yield _log_text(), gr.update(visible=True), gr.update(visible=False), gr.update()
|
| 584 |
|
| 585 |
finally:
|
| 586 |
# Always restart ace-server
|
| 587 |
_log("[INFO] Restarting ace-server...")
|
| 588 |
+
yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
|
| 589 |
_gc.collect()
|
| 590 |
ok = _start_ace_server()
|
| 591 |
if ok:
|
| 592 |
_log("[OK] ace-server restarted successfully")
|
| 593 |
else:
|
| 594 |
_log("[WARN] ace-server may not have restarted -- check logs")
|
| 595 |
+
adapter_safetensors = os.path.join(adapter_out, "adapter_model.safetensors")
|
| 596 |
+
if os.path.isfile(adapter_safetensors):
|
| 597 |
+
yield _log_text(), gr.update(visible=True), gr.update(visible=False), gr.update(value=adapter_safetensors, visible=True)
|
| 598 |
+
else:
|
| 599 |
+
yield _log_text(), gr.update(visible=True), gr.update(visible=False), gr.update()
|
| 600 |
|
| 601 |
# -- Cancel handler --
|
| 602 |
def _on_cancel():
|
|
|
|
| 732 |
cancel_btn = gr.Button("Cancel Training", variant="stop", visible=False, scale=1)
|
| 733 |
log_btn = gr.Button("Check Log", scale=1)
|
| 734 |
|
| 735 |
+
train_output_file = gr.File(label="Trained LoRA (download)", visible=False)
|
| 736 |
train_log = gr.Textbox(
|
| 737 |
label="Training Log",
|
| 738 |
interactive=False,
|
| 739 |
+
lines=10,
|
| 740 |
elem_classes="status-box",
|
| 741 |
)
|
| 742 |
|
| 743 |
+
# Training generator -- yields (log, train_btn, cancel_btn, output_file)
|
| 744 |
train_event = train_btn.click(
|
| 745 |
train_lora_ui,
|
| 746 |
inputs=[train_audio, lora_name, train_epochs, train_lr, train_rank],
|
| 747 |
+
outputs=[train_log, train_btn, cancel_btn, train_output_file],
|
| 748 |
api_name="train_lora",
|
| 749 |
concurrency_limit=1,
|
| 750 |
)
|