Nekochu committed on
Commit
32de701
·
1 Parent(s): 4d9a556

cancel, captioning, preprocessing, sidecar upload, elapsed time, GeneratorExit fix

Browse files
Files changed (2) hide show
  1. app.py +162 -126
  2. train_engine.py +13 -8
app.py CHANGED
@@ -19,10 +19,12 @@ from train_engine import (
19
  preprocess_audio,
20
  train_lora_generator,
21
  cancel_training,
 
22
  get_trained_loras as _get_trained_loras_engine,
23
  MAX_TRAINING_TIME,
24
  )
25
 
 
26
  logger = logging.getLogger(__name__)
27
 
28
  # ---------------------------------------------------------------------------
@@ -93,12 +95,14 @@ def _get_props():
93
  return {}
94
 
95
 
96
- def _poll_job(job_id, timeout=600, progress_cb=None):
97
- """Poll a job until done/error/timeout. Returns (status, elapsed)."""
98
  t0 = time.time()
99
  while time.time() - t0 < timeout:
 
 
100
  try:
101
- r = requests.get(f"{ACE_SERVER}/job", params={"id": job_id}, timeout=10)
102
  data = r.json()
103
  status = data.get("status", "unknown")
104
  if progress_cb:
@@ -107,7 +111,7 @@ def _poll_job(job_id, timeout=600, progress_cb=None):
107
  return status, time.time() - t0
108
  except Exception:
109
  pass
110
- time.sleep(2)
111
  return "timeout", time.time() - t0
112
 
113
 
@@ -121,58 +125,41 @@ def _fetch_result(job_id, timeout=60):
121
  return r
122
 
123
 
124
- def _caption_via_understand(audio_path, timeout=120):
125
- """Call ace-server /understand to get a rich caption for an audio file.
126
 
127
- Returns a dict with caption, bpm, key, signature, lyrics on success,
128
- or None on failure (caller should fall back to librosa).
129
- """
130
  fname = os.path.basename(audio_path)
131
  try:
132
  with open(audio_path, "rb") as f:
133
- audio_b64 = base64.b64encode(f.read()).decode("ascii")
134
- except Exception as exc:
135
- logger.warning("[Caption] %s: failed to read file: %s", fname, exc)
136
- return None
137
-
138
- # Submit
139
- try:
140
- r = requests.post(
141
- f"{ACE_SERVER}/understand",
142
- json={"audio": audio_b64},
143
- timeout=30,
144
- )
145
  if r.status_code != 200:
146
- logger.warning("[Caption] %s: /understand returned %d: %s", fname, r.status_code, r.text[:200])
147
  return None
148
  job_id = r.json().get("id")
149
  if not job_id:
150
- logger.warning("[Caption] %s: /understand returned no job id", fname)
151
  return None
152
  except Exception as exc:
153
  logger.warning("[Caption] %s: /understand submit failed: %s", fname, exc)
154
  return None
155
 
156
- # Poll until done
157
- status, _ = _poll_job(job_id, timeout=timeout)
158
  if status != "done":
159
- logger.warning("[Caption] %s: /understand job %s -> %s", fname, job_id, status)
160
  return None
161
 
162
- # Fetch result
163
  try:
164
  r = _fetch_result(job_id, timeout=30)
165
  if r.status_code != 200:
166
- logger.warning("[Caption] %s: /understand result fetch failed: %d", fname, r.status_code)
167
  return None
168
  data = r.json()
169
- # The result should contain caption, bpm, key, signature, lyrics
170
  if isinstance(data, dict) and data.get("caption"):
171
  return data
172
- logger.warning("[Caption] %s: /understand returned no caption field", fname)
173
  return None
174
- except Exception as exc:
175
- logger.warning("[Caption] %s: /understand result parse failed: %s", fname, exc)
176
  return None
177
 
178
 
@@ -559,7 +546,13 @@ def gradio_main():
559
  train_start = time.time()
560
 
561
  def _log(msg):
562
- _train_log_lines.append(msg)
 
 
 
 
 
 
563
  if len(_train_log_lines) > 2000:
564
  _train_log_lines[:] = _train_log_lines[-1000:]
565
 
@@ -587,7 +580,9 @@ def gradio_main():
587
  work_dir = os.path.join(OUTPUT_DIR, "train_workspace", lora_name)
588
  os.makedirs(work_dir, exist_ok=True)
589
  audio_dir = os.path.join(work_dir, "audio_input")
590
- os.makedirs(audio_dir, exist_ok=True)
 
 
591
  adapter_out = os.path.join(ADAPTER_DIR, lora_name)
592
  os.makedirs(adapter_out, exist_ok=True)
593
 
@@ -603,6 +598,10 @@ def gradio_main():
603
  for f in audio_files:
604
  src = f.name if hasattr(f, "name") else str(f)
605
  fname = os.path.basename(src)
 
 
 
 
606
  try:
607
  dur = _lr.get_duration(path=src)
608
  except Exception:
@@ -643,37 +642,61 @@ def gradio_main():
643
  f"Epochs: {epochs} | LR: {lr} | Rank: {rank}")
644
  yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
645
 
646
- # Caption each audio file via ace-server /understand BEFORE stopping it
647
- if _server_ok():
648
- _log("[INFO] Captioning audio via ace-server /understand...")
649
- yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
650
- for audio_fname in sorted(os.listdir(audio_dir)):
651
- full_path = os.path.join(audio_dir, audio_fname)
652
- if not os.path.isfile(full_path) or audio_fname.endswith(".json"):
653
- continue
654
- caption_json_path = full_path + ".json"
655
- caption_data = _caption_via_understand(full_path, timeout=120)
656
- if caption_data:
657
- _log(f"[Caption] {audio_fname}: using ace-server /understand")
658
- with open(caption_json_path, "w") as cj:
659
- json.dump(caption_data, cj)
660
- else:
661
- # Fallback to librosa for basic metadata
662
- _log(f"[Caption] {audio_fname}: fallback to librosa")
663
- try:
664
- y_cap, sr_cap = _lr.load(full_path, sr=None, mono=True)
665
- tempo, _ = _lr.beat.beat_track(y=y_cap, sr=sr_cap)
666
- bpm_val = float(tempo) if hasattr(tempo, '__float__') else float(tempo[0])
667
- fallback = {"caption": "", "bpm": round(bpm_val), "key": "", "signature": "", "lyrics": ""}
668
- with open(caption_json_path, "w") as cj:
669
- json.dump(fallback, cj)
670
- except Exception as cap_exc:
671
- _log(f"[Caption] {audio_fname}: librosa fallback also failed: {cap_exc}")
672
  yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
673
- else:
674
- _log("[INFO] ace-server not running, skipping /understand captioning")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
675
  yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
676
 
 
 
 
 
 
 
 
677
  # Stop ace-server before training (frees memory)
678
  _training_lock.acquire()
679
  _log("[INFO] Stopping ace-server for training...")
@@ -681,28 +704,54 @@ def gradio_main():
681
  _stop_ace_server()
682
  _gc.collect()
683
 
 
684
  try:
685
- # -- Phase 1: Preprocessing --
686
- _log("[Step 1/2] Preprocessing audio...")
687
- yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
688
-
689
  preprocessed_dir = os.path.join(work_dir, "preprocessed_tensors")
 
690
 
691
  def preprocess_progress(current, total, desc):
692
  _log(f" {desc} ({current}/{total})")
693
 
694
- result = preprocess_audio(
695
- audio_dir=audio_dir,
696
- output_dir=preprocessed_dir,
697
- checkpoint_dir=ACE_CHECKPOINT_DIR,
698
- device="cpu",
699
- variant="turbo",
700
- max_duration=float(MAX_TOTAL_AUDIO),
701
- progress_callback=preprocess_progress,
702
- cancel_check=lambda: False,
703
- )
 
 
 
 
 
 
 
 
 
704
  yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
705
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
706
  processed = result.get("processed", 0)
707
  failed = result.get("failed", 0)
708
  total = result.get("total", 0)
@@ -740,7 +789,6 @@ def gradio_main():
740
  device="cpu",
741
  log_every=5,
742
  ):
743
- # Timeout check
744
  elapsed = time.time() - train_start
745
  if elapsed > MAX_TRAINING_TIME:
746
  _log(f"[WARN] Training timed out after {int(elapsed)}s")
@@ -756,6 +804,16 @@ def gradio_main():
756
  _log(f"[INFO] Total time: {time.time() - train_start:.0f}s")
757
  yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
758
 
 
 
 
 
 
 
 
 
 
 
759
  except Exception as exc:
760
  _log(f"[FAIL] Training error: {exc}")
761
  import traceback
@@ -763,50 +821,36 @@ def gradio_main():
763
  yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
764
 
765
  finally:
766
- _training_lock.release()
767
- # Always restart ace-server
768
- _log("[INFO] Restarting ace-server...")
769
- yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
770
- _gc.collect()
771
- ok = _start_ace_server()
772
- if ok:
773
- _log("[OK] ace-server restarted successfully")
774
- else:
775
- _log("[WARN] ace-server may not have restarted -- check logs")
776
- if os.path.isdir(adapter_out):
777
- logger.info("Adapter dir %s: %s", adapter_out, os.listdir(adapter_out))
778
- else:
779
- logger.warning("Adapter dir %s does not exist", adapter_out)
780
- adapter_safetensors = os.path.join(adapter_out, "adapter_model.safetensors")
781
- if os.path.isfile(adapter_safetensors):
782
- # Copy to a temp file so Gradio doesn't try to validate /app paths
783
- # (avoids InvalidPathError: "Cannot move /app to the gradio cache dir
784
- # because it was not uploaded by a user")
785
- tmp_out = tempfile.NamedTemporaryFile(
786
- suffix=".safetensors",
787
- prefix=f"{lora_name}_",
788
- delete=False,
789
- )
790
- tmp_out.close()
791
- shutil.copy2(adapter_safetensors, tmp_out.name)
792
- _log(f"[OK] LoRA saved: {lora_name}")
793
- yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File(value=tmp_out.name, visible=True)
794
- else:
795
- yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
796
- # Clean up training workspace (preprocessed tensors, temp audio, etc.)
797
- shutil.rmtree(work_dir, ignore_errors=True)
798
 
799
  # -- Cancel handler --
800
  def _on_cancel():
801
  cancel_training()
802
  logger.info("Cancel requested by user")
803
- return "Cancelling after current epoch... please wait"
804
-
805
- # -- Check log handler --
806
- def _check_log():
807
- if _train_log_lines:
808
- return "\n".join(_train_log_lines)
809
- return "No training log available."
810
 
811
  # -- Build LM model choices --
812
  def _lm_model_choices():
@@ -909,9 +953,9 @@ def gradio_main():
909
  with gr.Row(elem_classes="compact-row"):
910
  with gr.Column(scale=2):
911
  train_audio = gr.File(
912
- label="Training Audio Files",
913
  file_count="multiple",
914
- file_types=["audio"],
915
  )
916
  with gr.Column(scale=1):
917
  lora_name = gr.Textbox(label="LoRA Name", value="my-lora")
@@ -928,7 +972,6 @@ def gradio_main():
928
  with gr.Row(elem_classes="compact-row"):
929
  train_btn = gr.Button("Train", variant="primary", scale=2)
930
  cancel_btn = gr.Button("Cancel Training", variant="stop", visible=False, scale=1)
931
- log_btn = gr.Button("Check Log", scale=1)
932
 
933
  train_output_file = gr.File(label="Trained LoRA (download)", visible=False)
934
  train_log = gr.Textbox(
@@ -975,13 +1018,6 @@ def gradio_main():
975
  outputs=[train_log],
976
  )
977
 
978
- # Check log: show last training output
979
- log_btn.click(
980
- _check_log,
981
- outputs=[train_log],
982
- api_name="check_log",
983
- )
984
-
985
  demo.launch(
986
  server_name="0.0.0.0",
987
  server_port=7860,
 
19
  preprocess_audio,
20
  train_lora_generator,
21
  cancel_training,
22
+ _training_cancel,
23
  get_trained_loras as _get_trained_loras_engine,
24
  MAX_TRAINING_TIME,
25
  )
26
 
27
+ logging.basicConfig(level=logging.INFO, format="%(message)s", stream=sys.stdout)
28
  logger = logging.getLogger(__name__)
29
 
30
  # ---------------------------------------------------------------------------
 
95
  return {}
96
 
97
 
98
+ def _poll_job(job_id, timeout=600, progress_cb=None, cancel_check=None):
99
+ """Poll a job until done/error/timeout/cancelled. Returns (status, elapsed)."""
100
  t0 = time.time()
101
  while time.time() - t0 < timeout:
102
+ if cancel_check and cancel_check():
103
+ return "cancelled", time.time() - t0
104
  try:
105
+ r = requests.get(f"{ACE_SERVER}/job", params={"id": job_id}, timeout=5)
106
  data = r.json()
107
  status = data.get("status", "unknown")
108
  if progress_cb:
 
111
  return status, time.time() - t0
112
  except Exception:
113
  pass
114
+ time.sleep(1)
115
  return "timeout", time.time() - t0
116
 
117
 
 
125
  return r
126
 
127
 
 
 
128
 
129
def _caption_via_understand(audio_path, timeout=600, cancel_check=None):
    """Ask ace-server's /understand endpoint to caption one audio file.

    The file is uploaded as a multipart form field named "audio", the
    returned job id is polled with _poll_job, and the finished result is
    fetched with _fetch_result.

    Args:
        audio_path: Path of the audio file to upload.
        timeout: Maximum number of seconds passed to _poll_job for the job
            to finish.
        cancel_check: Optional zero-argument callable forwarded unchanged to
            _poll_job.

    Returns:
        The decoded JSON result when it is a dict with a non-empty "caption"
        entry, otherwise None. Every failure path logs a warning so that a
        caller falling back to another captioning method can be diagnosed.
    """
    # Function-scope import keeps this block self-contained.
    import mimetypes

    fname = os.path.basename(audio_path)
    # Label the upload with the type implied by its extension instead of
    # always claiming MP3; unknown extensions keep the previous label.
    content_type = mimetypes.guess_type(fname)[0] or "audio/mpeg"

    # Submit the file. Any failure here (unreadable file, network error,
    # non-JSON body) is deliberately best-effort: log and give up.
    try:
        with open(audio_path, "rb") as f:
            r = requests.post(
                f"{ACE_SERVER}/understand",
                files={"audio": (fname, f, content_type)},
                timeout=30,
            )
        if r.status_code != 200:
            logger.warning("[Caption] %s: /understand %d: %s", fname, r.status_code, r.text[:200])
            return None
        job_id = r.json().get("id")
        if not job_id:
            logger.warning("[Caption] %s: /understand returned no job id", fname)
            return None
    except Exception as exc:
        logger.warning("[Caption] %s: /understand submit failed: %s", fname, exc)
        return None

    # Wait for the job; any status other than "done" is treated as failure.
    status, elapsed = _poll_job(job_id, timeout=timeout, cancel_check=cancel_check)
    if status != "done":
        logger.warning("[Caption] %s: /understand -> %s (%.0fs)", fname, status, elapsed)
        return None

    # Fetch and decode the result.
    try:
        r = _fetch_result(job_id, timeout=30)
        if r.status_code != 200:
            logger.warning("[Caption] %s: /understand result fetch failed: %d", fname, r.status_code)
            return None
        data = r.json()
    except Exception as exc:
        logger.warning("[Caption] %s: /understand result parse failed: %s", fname, exc)
        return None

    # Only a dict carrying a non-empty caption counts as success.
    if isinstance(data, dict) and data.get("caption"):
        return data
    logger.warning("[Caption] %s: /understand returned no caption field", fname)
    return None
164
 
165
 
 
546
  train_start = time.time()
547
 
548
  def _log(msg):
549
+ elapsed = int(time.time() - train_start)
550
+ m, s = divmod(elapsed, 60)
551
+ h, m = divmod(m, 60)
552
+ ts = f"+{h}:{m:02d}:{s:02d}" if h else f"+{m:02d}:{s:02d}"
553
+ line = f"[{ts}] {msg}"
554
+ _train_log_lines.append(line)
555
+ logger.info(msg)
556
  if len(_train_log_lines) > 2000:
557
  _train_log_lines[:] = _train_log_lines[-1000:]
558
 
 
580
  work_dir = os.path.join(OUTPUT_DIR, "train_workspace", lora_name)
581
  os.makedirs(work_dir, exist_ok=True)
582
  audio_dir = os.path.join(work_dir, "audio_input")
583
+ if os.path.exists(audio_dir):
584
+ shutil.rmtree(audio_dir)
585
+ os.makedirs(audio_dir)
586
  adapter_out = os.path.join(ADAPTER_DIR, lora_name)
587
  os.makedirs(adapter_out, exist_ok=True)
588
 
 
598
  for f in audio_files:
599
  src = f.name if hasattr(f, "name") else str(f)
600
  fname = os.path.basename(src)
601
+ # .txt/.json sidecars: copy as caption files, skip duration check
602
+ if fname.lower().endswith((".txt", ".json")):
603
+ shutil.copy2(src, os.path.join(audio_dir, fname))
604
+ continue
605
  try:
606
  dur = _lr.get_duration(path=src)
607
  except Exception:
 
642
  f"Epochs: {epochs} | LR: {lr} | Rank: {rank}")
643
  yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
644
 
645
+ # Caption audio files: GGUF LM if ace-server running, else librosa
646
+ use_understand = _server_ok()
647
+ method = "GGUF LM (BPM, key, mood, lyrics)" if use_understand else "librosa (BPM only)"
648
+ _log(f"[INFO] Auto-captioning via {method}...")
649
+ yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
650
+ for audio_fname in sorted(os.listdir(audio_dir)):
651
+ if _training_cancel.is_set():
652
+ break
653
+ full_path = os.path.join(audio_dir, audio_fname)
654
+ if not os.path.isfile(full_path):
655
+ continue
656
+ ext = audio_fname.lower().rsplit(".", 1)[-1] if "." in audio_fname else ""
657
+ if ext in ("json", "txt"):
658
+ continue
659
+ stem = audio_fname.rsplit(".", 1)[0] if "." in audio_fname else audio_fname
660
+ sidecar_json = os.path.join(audio_dir, stem + ".json")
661
+ sidecar_txt = os.path.join(audio_dir, stem + ".txt")
662
+ if os.path.isfile(sidecar_json) or os.path.isfile(sidecar_txt):
663
+ _log(f" {audio_fname}: using caption file")
 
 
 
 
 
 
 
664
  yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
665
+ continue
666
+ caption_data = None
667
+ if use_understand:
668
+ _log(f" {audio_fname}: GGUF LM captioning...")
669
+ yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
670
+ caption_data = _caption_via_understand(
671
+ full_path, timeout=600,
672
+ cancel_check=lambda: _training_cancel.is_set(),
673
+ )
674
+ if caption_data:
675
+ bpm_s = caption_data.get("bpm", "?")
676
+ key_s = caption_data.get("keyscale", caption_data.get("key", "?"))
677
+ _log(f" {audio_fname}: OK (BPM={bpm_s}, key={key_s})")
678
+ with open(sidecar_json, "w") as cj:
679
+ json.dump(caption_data, cj)
680
+ else:
681
+ try:
682
+ y_cap, sr_cap = _lr.load(full_path, sr=None, mono=True)
683
+ tempo_arr, _ = _lr.beat.beat_track(y=y_cap, sr=sr_cap)
684
+ bpm_val = int(round(float(tempo_arr.item() if hasattr(tempo_arr, 'item') else tempo_arr)))
685
+ fallback = {"caption": "", "bpm": bpm_val, "key": "", "signature": "", "lyrics": ""}
686
+ with open(sidecar_json, "w") as cj:
687
+ json.dump(fallback, cj)
688
+ _log(f" {audio_fname}: librosa fallback BPM={bpm_val}")
689
+ except Exception as cap_exc:
690
+ _log(f" {audio_fname}: caption failed: {cap_exc}")
691
  yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
692
 
693
+ if _training_cancel.is_set():
694
+ _training_cancel.clear()
695
+ _log("[CANCELLED] Stopped")
696
+ yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
697
+ shutil.rmtree(work_dir, ignore_errors=True)
698
+ return
699
+
700
  # Stop ace-server before training (frees memory)
701
  _training_lock.acquire()
702
  _log("[INFO] Stopping ace-server for training...")
 
704
  _stop_ace_server()
705
  _gc.collect()
706
 
707
+ _cleanup_done = False
708
  try:
709
+ # -- Phase 1: Preprocessing (runs in thread for live progress) --
 
 
 
710
  preprocessed_dir = os.path.join(work_dir, "preprocessed_tensors")
711
+ _preprocess_log_len = len(_train_log_lines)
712
 
713
  def preprocess_progress(current, total, desc):
714
  _log(f" {desc} ({current}/{total})")
715
 
716
+ _preprocess_result = [None]
717
+ _preprocess_error = [None]
718
+
719
+ def _run_preprocess():
720
+ try:
721
+ _preprocess_result[0] = preprocess_audio(
722
+ audio_dir=audio_dir,
723
+ output_dir=preprocessed_dir,
724
+ checkpoint_dir=ACE_CHECKPOINT_DIR,
725
+ device="cpu",
726
+ variant="turbo",
727
+ max_duration=float(MAX_TOTAL_AUDIO),
728
+ progress_callback=preprocess_progress,
729
+ cancel_check=lambda: _training_cancel.is_set(),
730
+ )
731
+ except Exception as exc:
732
+ _preprocess_error[0] = exc
733
+
734
+ _log("[Step 1/2] Encoding audio → training data (VAE + text encoder)...")
735
  yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
736
 
737
+ t = threading.Thread(target=_run_preprocess, daemon=True)
738
+ t.start()
739
+ while t.is_alive():
740
+ t.join(timeout=3)
741
+ if len(_train_log_lines) > _preprocess_log_len:
742
+ _preprocess_log_len = len(_train_log_lines)
743
+ yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
744
+
745
+ if _preprocess_error[0]:
746
+ raise _preprocess_error[0]
747
+ result = _preprocess_result[0]
748
+
749
+ if _training_cancel.is_set():
750
+ _training_cancel.clear()
751
+ _log("[CANCELLED] Stopped during preprocessing")
752
+ yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
753
+ return
754
+
755
  processed = result.get("processed", 0)
756
  failed = result.get("failed", 0)
757
  total = result.get("total", 0)
 
789
  device="cpu",
790
  log_every=5,
791
  ):
 
792
  elapsed = time.time() - train_start
793
  if elapsed > MAX_TRAINING_TIME:
794
  _log(f"[WARN] Training timed out after {int(elapsed)}s")
 
804
  _log(f"[INFO] Total time: {time.time() - train_start:.0f}s")
805
  yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
806
 
807
+ except GeneratorExit:
808
+ _training_cancel.set()
809
+ logger.info("Generator closed by Gradio, cleaning up")
810
+ _cleanup_done = True
811
+ _training_lock.release()
812
+ _gc.collect()
813
+ _start_ace_server()
814
+ shutil.rmtree(work_dir, ignore_errors=True)
815
+ return
816
+
817
  except Exception as exc:
818
  _log(f"[FAIL] Training error: {exc}")
819
  import traceback
 
821
  yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
822
 
823
  finally:
824
+ if not _cleanup_done:
825
+ _training_lock.release()
826
+ _log("[INFO] Restarting ace-server...")
827
+ yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
828
+ _gc.collect()
829
+ ok = _start_ace_server()
830
+ if ok:
831
+ _log("[OK] ace-server restarted successfully")
832
+ else:
833
+ _log("[WARN] ace-server may not have restarted -- check logs")
834
+ adapter_safetensors = os.path.join(adapter_out, "adapter_model.safetensors")
835
+ if os.path.isfile(adapter_safetensors):
836
+ tmp_out = tempfile.NamedTemporaryFile(
837
+ suffix=".safetensors",
838
+ prefix=f"{lora_name}_",
839
+ delete=False,
840
+ )
841
+ tmp_out.close()
842
+ shutil.copy2(adapter_safetensors, tmp_out.name)
843
+ _log(f"[OK] LoRA saved: {lora_name}")
844
+ yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File(value=tmp_out.name, visible=True)
845
+ else:
846
+ yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
847
+ shutil.rmtree(work_dir, ignore_errors=True)
 
 
 
 
 
 
 
 
848
 
849
  # -- Cancel handler --
850
  def _on_cancel():
851
  cancel_training()
852
  logger.info("Cancel requested by user")
853
+ return "Cancelling..."
 
 
 
 
 
 
854
 
855
  # -- Build LM model choices --
856
  def _lm_model_choices():
 
953
  with gr.Row(elem_classes="compact-row"):
954
  with gr.Column(scale=2):
955
  train_audio = gr.File(
956
+ label="Training Audio + Caption Files",
957
  file_count="multiple",
958
+ file_types=["audio", ".txt", ".json"],
959
  )
960
  with gr.Column(scale=1):
961
  lora_name = gr.Textbox(label="LoRA Name", value="my-lora")
 
972
  with gr.Row(elem_classes="compact-row"):
973
  train_btn = gr.Button("Train", variant="primary", scale=2)
974
  cancel_btn = gr.Button("Cancel Training", variant="stop", visible=False, scale=1)
 
975
 
976
  train_output_file = gr.File(label="Trained LoRA (download)", visible=False)
977
  train_log = gr.Textbox(
 
1018
  outputs=[train_log],
1019
  )
1020
 
 
 
 
 
 
 
 
1021
  demo.launch(
1022
  server_name="0.0.0.0",
1023
  server_port=7860,
train_engine.py CHANGED
@@ -2153,13 +2153,16 @@ def preprocess_audio(
2153
 
2154
  # Auto-caption: read existing sidecar or analyze
2155
  sidecar = _read_caption_sidecar(af)
2156
- if sidecar and sidecar.get("caption"):
2157
- caption = sidecar["caption"]
2158
  lyrics = sidecar.get("lyrics", "[Instrumental]")
2159
  logger.info("[Caption] %s: using existing sidecar", af.name)
2160
  else:
2161
  # Auto-select analysis mode based on dataset size
2162
- if total <= 20:
 
 
 
2163
  analysis_mode = "sas"
2164
  elif total <= 100:
2165
  analysis_mode = "mid"
@@ -2535,11 +2538,13 @@ def train_lora_generator(
2535
  # Cancel check
2536
  if _training_cancel.is_set():
2537
  _training_cancel.clear()
2538
- early_path = str(out_path / "early_exit")
2539
- model.decoder.eval()
2540
- save_lora_adapter(model, early_path)
2541
- model.decoder.train()
2542
- yield f"[OK] Cancelled at epoch {epoch + 1}, saved to {early_path}"
 
 
2543
  yield "[DONE]"
2544
  _cuda_sync(device)
2545
  unload_models(model)
 
2153
 
2154
  # Auto-caption: read existing sidecar or analyze
2155
  sidecar = _read_caption_sidecar(af)
2156
+ if sidecar is not None:
2157
+ caption = sidecar.get("caption", "") or af.stem
2158
  lyrics = sidecar.get("lyrics", "[Instrumental]")
2159
  logger.info("[Caption] %s: using existing sidecar", af.name)
2160
  else:
2161
  # Auto-select analysis mode based on dataset size
2162
+ # mid/sas use Demucs stem separation — GPU only
2163
+ if device == "cpu":
2164
+ analysis_mode = "faf"
2165
+ elif total <= 20:
2166
  analysis_mode = "sas"
2167
  elif total <= 100:
2168
  analysis_mode = "mid"
 
2538
  # Cancel check
2539
  if _training_cancel.is_set():
2540
  _training_cancel.clear()
2541
+ if epoch > start_epoch:
2542
+ early_path = str(out_path / "early_exit")
2543
+ model.decoder.eval()
2544
+ save_lora_adapter(model, early_path)
2545
+ yield f"[OK] Cancelled at epoch {epoch + 1}, saved to {early_path}"
2546
+ else:
2547
+ yield f"[CANCELLED] Stopped before any epoch completed"
2548
  yield "[DONE]"
2549
  _cuda_sync(device)
2550
  unload_models(model)