Nekochu committed on
Commit
2d3c27c
·
1 Parent(s): 5b7a56f

add LoRA download button after training (gr.File output, like rvc-beatrice)

Browse files
Files changed (1) hide show
  1. app.py +23 -18
app.py CHANGED
@@ -460,12 +460,12 @@ def gradio_main():
460
  # -- Validation --
461
  if not audio_files:
462
  _log("[FAIL] No audio files uploaded.")
463
- yield _log_text(), gr.update(visible=True), gr.update(visible=False)
464
  return
465
 
466
  if len(audio_files) > MAX_AUDIO_FILES:
467
  _log(f"[FAIL] Too many files ({len(audio_files)}). Max: {MAX_AUDIO_FILES}")
468
- yield _log_text(), gr.update(visible=True), gr.update(visible=False)
469
  return
470
 
471
  lora_name = (lora_name or "").strip() or "my-lora"
@@ -485,7 +485,7 @@ def gradio_main():
485
 
486
  # Copy uploaded audio files
487
  _log(f"[INFO] Preparing {len(audio_files)} audio files...")
488
- yield _log_text(), gr.update(visible=False), gr.update(visible=True)
489
 
490
  for f in audio_files:
491
  src = f.name if hasattr(f, "name") else str(f)
@@ -493,18 +493,18 @@ def gradio_main():
493
 
494
  _log(f"[INFO] LoRA: '{lora_name}' | Files: {len(audio_files)} | "
495
  f"Epochs: {epochs} | LR: {lr} | Rank: {rank}")
496
- yield _log_text(), gr.update(visible=False), gr.update(visible=True)
497
 
498
  # Stop ace-server before training (frees memory)
499
  _log("[INFO] Stopping ace-server for training...")
500
- yield _log_text(), gr.update(visible=False), gr.update(visible=True)
501
  _stop_ace_server()
502
  _gc.collect()
503
 
504
  try:
505
  # -- Phase 1: Preprocessing --
506
  _log("[Step 1/2] Preprocessing audio...")
507
- yield _log_text(), gr.update(visible=False), gr.update(visible=True)
508
 
509
  preprocessed_dir = os.path.join(work_dir, "preprocessed_tensors")
510
 
@@ -521,24 +521,24 @@ def gradio_main():
521
  progress_callback=preprocess_progress,
522
  cancel_check=lambda: False,
523
  )
524
- yield _log_text(), gr.update(visible=False), gr.update(visible=True)
525
 
526
  processed = result.get("processed", 0)
527
  failed = result.get("failed", 0)
528
  total = result.get("total", 0)
529
  _log(f"[OK] Preprocessed: {processed}/{total} (failed: {failed})")
530
- yield _log_text(), gr.update(visible=False), gr.update(visible=True)
531
 
532
  if processed == 0:
533
  _log("[FAIL] No files preprocessed successfully. Cannot train.")
534
- yield _log_text(), gr.update(visible=True), gr.update(visible=False)
535
  return
536
 
537
  _gc.collect()
538
 
539
  # -- Phase 2: Training --
540
  _log("[Step 2/2] Training LoRA...")
541
- yield _log_text(), gr.update(visible=False), gr.update(visible=True)
542
 
543
  for msg in train_lora_generator(
544
  dataset_dir=preprocessed_dir,
@@ -568,31 +568,35 @@ def gradio_main():
568
  break
569
 
570
  _log(msg)
571
- yield _log_text(), gr.update(visible=False), gr.update(visible=True)
572
 
573
  if msg.strip() == "[DONE]":
574
  break
575
 
576
  _log(f"[INFO] Total time: {time.time() - train_start:.0f}s")
577
- yield _log_text(), gr.update(visible=False), gr.update(visible=True)
578
 
579
  except Exception as exc:
580
  _log(f"[FAIL] Training error: {exc}")
581
  import traceback
582
  _log(traceback.format_exc())
583
- yield _log_text(), gr.update(visible=True), gr.update(visible=False)
584
 
585
  finally:
586
  # Always restart ace-server
587
  _log("[INFO] Restarting ace-server...")
588
- yield _log_text(), gr.update(visible=False), gr.update(visible=True)
589
  _gc.collect()
590
  ok = _start_ace_server()
591
  if ok:
592
  _log("[OK] ace-server restarted successfully")
593
  else:
594
  _log("[WARN] ace-server may not have restarted -- check logs")
595
- yield _log_text(), gr.update(visible=True), gr.update(visible=False)
 
 
 
 
596
 
597
  # -- Cancel handler --
598
  def _on_cancel():
@@ -728,18 +732,19 @@ def gradio_main():
728
  cancel_btn = gr.Button("Cancel Training", variant="stop", visible=False, scale=1)
729
  log_btn = gr.Button("Check Log", scale=1)
730
 
 
731
  train_log = gr.Textbox(
732
  label="Training Log",
733
  interactive=False,
734
- lines=12,
735
  elem_classes="status-box",
736
  )
737
 
738
- # Training generator -- yields (log, train_btn, cancel_btn)
739
  train_event = train_btn.click(
740
  train_lora_ui,
741
  inputs=[train_audio, lora_name, train_epochs, train_lr, train_rank],
742
- outputs=[train_log, train_btn, cancel_btn],
743
  api_name="train_lora",
744
  concurrency_limit=1,
745
  )
 
460
  # -- Validation --
461
  if not audio_files:
462
  _log("[FAIL] No audio files uploaded.")
463
+ yield _log_text(), gr.update(visible=True), gr.update(visible=False), gr.update()
464
  return
465
 
466
  if len(audio_files) > MAX_AUDIO_FILES:
467
  _log(f"[FAIL] Too many files ({len(audio_files)}). Max: {MAX_AUDIO_FILES}")
468
+ yield _log_text(), gr.update(visible=True), gr.update(visible=False), gr.update()
469
  return
470
 
471
  lora_name = (lora_name or "").strip() or "my-lora"
 
485
 
486
  # Copy uploaded audio files
487
  _log(f"[INFO] Preparing {len(audio_files)} audio files...")
488
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
489
 
490
  for f in audio_files:
491
  src = f.name if hasattr(f, "name") else str(f)
 
493
 
494
  _log(f"[INFO] LoRA: '{lora_name}' | Files: {len(audio_files)} | "
495
  f"Epochs: {epochs} | LR: {lr} | Rank: {rank}")
496
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
497
 
498
  # Stop ace-server before training (frees memory)
499
  _log("[INFO] Stopping ace-server for training...")
500
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
501
  _stop_ace_server()
502
  _gc.collect()
503
 
504
  try:
505
  # -- Phase 1: Preprocessing --
506
  _log("[Step 1/2] Preprocessing audio...")
507
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
508
 
509
  preprocessed_dir = os.path.join(work_dir, "preprocessed_tensors")
510
 
 
521
  progress_callback=preprocess_progress,
522
  cancel_check=lambda: False,
523
  )
524
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
525
 
526
  processed = result.get("processed", 0)
527
  failed = result.get("failed", 0)
528
  total = result.get("total", 0)
529
  _log(f"[OK] Preprocessed: {processed}/{total} (failed: {failed})")
530
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
531
 
532
  if processed == 0:
533
  _log("[FAIL] No files preprocessed successfully. Cannot train.")
534
+ yield _log_text(), gr.update(visible=True), gr.update(visible=False), gr.update()
535
  return
536
 
537
  _gc.collect()
538
 
539
  # -- Phase 2: Training --
540
  _log("[Step 2/2] Training LoRA...")
541
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
542
 
543
  for msg in train_lora_generator(
544
  dataset_dir=preprocessed_dir,
 
568
  break
569
 
570
  _log(msg)
571
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
572
 
573
  if msg.strip() == "[DONE]":
574
  break
575
 
576
  _log(f"[INFO] Total time: {time.time() - train_start:.0f}s")
577
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
578
 
579
  except Exception as exc:
580
  _log(f"[FAIL] Training error: {exc}")
581
  import traceback
582
  _log(traceback.format_exc())
583
+ yield _log_text(), gr.update(visible=True), gr.update(visible=False), gr.update()
584
 
585
  finally:
586
  # Always restart ace-server
587
  _log("[INFO] Restarting ace-server...")
588
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
589
  _gc.collect()
590
  ok = _start_ace_server()
591
  if ok:
592
  _log("[OK] ace-server restarted successfully")
593
  else:
594
  _log("[WARN] ace-server may not have restarted -- check logs")
595
+ adapter_safetensors = os.path.join(adapter_out, "adapter_model.safetensors")
596
+ if os.path.isfile(adapter_safetensors):
597
+ yield _log_text(), gr.update(visible=True), gr.update(visible=False), gr.update(value=adapter_safetensors, visible=True)
598
+ else:
599
+ yield _log_text(), gr.update(visible=True), gr.update(visible=False), gr.update()
600
 
601
  # -- Cancel handler --
602
  def _on_cancel():
 
732
  cancel_btn = gr.Button("Cancel Training", variant="stop", visible=False, scale=1)
733
  log_btn = gr.Button("Check Log", scale=1)
734
 
735
+ train_output_file = gr.File(label="Trained LoRA (download)", visible=False)
736
  train_log = gr.Textbox(
737
  label="Training Log",
738
  interactive=False,
739
+ lines=10,
740
  elem_classes="status-box",
741
  )
742
 
743
+ # Training generator -- yields (log, train_btn, cancel_btn, output_file)
744
  train_event = train_btn.click(
745
  train_lora_ui,
746
  inputs=[train_audio, lora_name, train_epochs, train_lr, train_rank],
747
+ outputs=[train_log, train_btn, cancel_btn, train_output_file],
748
  api_name="train_lora",
749
  concurrency_limit=1,
750
  )