Spaces:
Running
Running
cancel, captioning, preprocessing, sidecar upload, elapsed time, GeneratorExit fix
Browse files- app.py +162 -126
- train_engine.py +13 -8
app.py
CHANGED
|
@@ -19,10 +19,12 @@ from train_engine import (
|
|
| 19 |
preprocess_audio,
|
| 20 |
train_lora_generator,
|
| 21 |
cancel_training,
|
|
|
|
| 22 |
get_trained_loras as _get_trained_loras_engine,
|
| 23 |
MAX_TRAINING_TIME,
|
| 24 |
)
|
| 25 |
|
|
|
|
| 26 |
logger = logging.getLogger(__name__)
|
| 27 |
|
| 28 |
# ---------------------------------------------------------------------------
|
|
@@ -93,12 +95,14 @@ def _get_props():
|
|
| 93 |
return {}
|
| 94 |
|
| 95 |
|
| 96 |
-
def _poll_job(job_id, timeout=600, progress_cb=None):
|
| 97 |
-
"""Poll a job until done/error/timeout. Returns (status, elapsed)."""
|
| 98 |
t0 = time.time()
|
| 99 |
while time.time() - t0 < timeout:
|
|
|
|
|
|
|
| 100 |
try:
|
| 101 |
-
r = requests.get(f"{ACE_SERVER}/job", params={"id": job_id}, timeout=
|
| 102 |
data = r.json()
|
| 103 |
status = data.get("status", "unknown")
|
| 104 |
if progress_cb:
|
|
@@ -107,7 +111,7 @@ def _poll_job(job_id, timeout=600, progress_cb=None):
|
|
| 107 |
return status, time.time() - t0
|
| 108 |
except Exception:
|
| 109 |
pass
|
| 110 |
-
time.sleep(
|
| 111 |
return "timeout", time.time() - t0
|
| 112 |
|
| 113 |
|
|
@@ -121,58 +125,41 @@ def _fetch_result(job_id, timeout=60):
|
|
| 121 |
return r
|
| 122 |
|
| 123 |
|
| 124 |
-
def _caption_via_understand(audio_path, timeout=120):
|
| 125 |
-
"""Call ace-server /understand to get a rich caption for an audio file.
|
| 126 |
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
"""
|
| 130 |
fname = os.path.basename(audio_path)
|
| 131 |
try:
|
| 132 |
with open(audio_path, "rb") as f:
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
# Submit
|
| 139 |
-
try:
|
| 140 |
-
r = requests.post(
|
| 141 |
-
f"{ACE_SERVER}/understand",
|
| 142 |
-
json={"audio": audio_b64},
|
| 143 |
-
timeout=30,
|
| 144 |
-
)
|
| 145 |
if r.status_code != 200:
|
| 146 |
-
logger.warning("[Caption] %s: /understand
|
| 147 |
return None
|
| 148 |
job_id = r.json().get("id")
|
| 149 |
if not job_id:
|
| 150 |
-
logger.warning("[Caption] %s: /understand returned no job id", fname)
|
| 151 |
return None
|
| 152 |
except Exception as exc:
|
| 153 |
logger.warning("[Caption] %s: /understand submit failed: %s", fname, exc)
|
| 154 |
return None
|
| 155 |
|
| 156 |
-
|
| 157 |
-
status, _ = _poll_job(job_id, timeout=timeout)
|
| 158 |
if status != "done":
|
| 159 |
-
logger.warning("[Caption] %s: /understand
|
| 160 |
return None
|
| 161 |
|
| 162 |
-
# Fetch result
|
| 163 |
try:
|
| 164 |
r = _fetch_result(job_id, timeout=30)
|
| 165 |
if r.status_code != 200:
|
| 166 |
-
logger.warning("[Caption] %s: /understand result fetch failed: %d", fname, r.status_code)
|
| 167 |
return None
|
| 168 |
data = r.json()
|
| 169 |
-
# The result should contain caption, bpm, key, signature, lyrics
|
| 170 |
if isinstance(data, dict) and data.get("caption"):
|
| 171 |
return data
|
| 172 |
-
logger.warning("[Caption] %s: /understand returned no caption field", fname)
|
| 173 |
return None
|
| 174 |
-
except Exception
|
| 175 |
-
logger.warning("[Caption] %s: /understand result parse failed: %s", fname, exc)
|
| 176 |
return None
|
| 177 |
|
| 178 |
|
|
@@ -559,7 +546,13 @@ def gradio_main():
|
|
| 559 |
train_start = time.time()
|
| 560 |
|
| 561 |
def _log(msg):
|
| 562 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 563 |
if len(_train_log_lines) > 2000:
|
| 564 |
_train_log_lines[:] = _train_log_lines[-1000:]
|
| 565 |
|
|
@@ -587,7 +580,9 @@ def gradio_main():
|
|
| 587 |
work_dir = os.path.join(OUTPUT_DIR, "train_workspace", lora_name)
|
| 588 |
os.makedirs(work_dir, exist_ok=True)
|
| 589 |
audio_dir = os.path.join(work_dir, "audio_input")
|
| 590 |
-
os.
|
|
|
|
|
|
|
| 591 |
adapter_out = os.path.join(ADAPTER_DIR, lora_name)
|
| 592 |
os.makedirs(adapter_out, exist_ok=True)
|
| 593 |
|
|
@@ -603,6 +598,10 @@ def gradio_main():
|
|
| 603 |
for f in audio_files:
|
| 604 |
src = f.name if hasattr(f, "name") else str(f)
|
| 605 |
fname = os.path.basename(src)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 606 |
try:
|
| 607 |
dur = _lr.get_duration(path=src)
|
| 608 |
except Exception:
|
|
@@ -643,37 +642,61 @@ def gradio_main():
|
|
| 643 |
f"Epochs: {epochs} | LR: {lr} | Rank: {rank}")
|
| 644 |
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 645 |
|
| 646 |
-
# Caption
|
| 647 |
-
|
| 648 |
-
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
tempo, _ = _lr.beat.beat_track(y=y_cap, sr=sr_cap)
|
| 666 |
-
bpm_val = float(tempo) if hasattr(tempo, '__float__') else float(tempo[0])
|
| 667 |
-
fallback = {"caption": "", "bpm": round(bpm_val), "key": "", "signature": "", "lyrics": ""}
|
| 668 |
-
with open(caption_json_path, "w") as cj:
|
| 669 |
-
json.dump(fallback, cj)
|
| 670 |
-
except Exception as cap_exc:
|
| 671 |
-
_log(f"[Caption] {audio_fname}: librosa fallback also failed: {cap_exc}")
|
| 672 |
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 673 |
-
|
| 674 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 675 |
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 676 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 677 |
# Stop ace-server before training (frees memory)
|
| 678 |
_training_lock.acquire()
|
| 679 |
_log("[INFO] Stopping ace-server for training...")
|
|
@@ -681,28 +704,54 @@ def gradio_main():
|
|
| 681 |
_stop_ace_server()
|
| 682 |
_gc.collect()
|
| 683 |
|
|
|
|
| 684 |
try:
|
| 685 |
-
# -- Phase 1: Preprocessing --
|
| 686 |
-
_log("[Step 1/2] Preprocessing audio...")
|
| 687 |
-
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 688 |
-
|
| 689 |
preprocessed_dir = os.path.join(work_dir, "preprocessed_tensors")
|
|
|
|
| 690 |
|
| 691 |
def preprocess_progress(current, total, desc):
|
| 692 |
_log(f" {desc} ({current}/{total})")
|
| 693 |
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
| 697 |
-
|
| 698 |
-
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 704 |
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 705 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 706 |
processed = result.get("processed", 0)
|
| 707 |
failed = result.get("failed", 0)
|
| 708 |
total = result.get("total", 0)
|
|
@@ -740,7 +789,6 @@ def gradio_main():
|
|
| 740 |
device="cpu",
|
| 741 |
log_every=5,
|
| 742 |
):
|
| 743 |
-
# Timeout check
|
| 744 |
elapsed = time.time() - train_start
|
| 745 |
if elapsed > MAX_TRAINING_TIME:
|
| 746 |
_log(f"[WARN] Training timed out after {int(elapsed)}s")
|
|
@@ -756,6 +804,16 @@ def gradio_main():
|
|
| 756 |
_log(f"[INFO] Total time: {time.time() - train_start:.0f}s")
|
| 757 |
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 758 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 759 |
except Exception as exc:
|
| 760 |
_log(f"[FAIL] Training error: {exc}")
|
| 761 |
import traceback
|
|
@@ -763,50 +821,36 @@ def gradio_main():
|
|
| 763 |
yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
|
| 764 |
|
| 765 |
finally:
|
| 766 |
-
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
|
| 781 |
-
|
| 782 |
-
|
| 783 |
-
|
| 784 |
-
|
| 785 |
-
|
| 786 |
-
|
| 787 |
-
|
| 788 |
-
|
| 789 |
-
)
|
| 790 |
-
tmp_out.close()
|
| 791 |
-
shutil.copy2(adapter_safetensors, tmp_out.name)
|
| 792 |
-
_log(f"[OK] LoRA saved: {lora_name}")
|
| 793 |
-
yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File(value=tmp_out.name, visible=True)
|
| 794 |
-
else:
|
| 795 |
-
yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
|
| 796 |
-
# Clean up training workspace (preprocessed tensors, temp audio, etc.)
|
| 797 |
-
shutil.rmtree(work_dir, ignore_errors=True)
|
| 798 |
|
| 799 |
# -- Cancel handler --
|
| 800 |
def _on_cancel():
|
| 801 |
cancel_training()
|
| 802 |
logger.info("Cancel requested by user")
|
| 803 |
-
return "Cancelling
|
| 804 |
-
|
| 805 |
-
# -- Check log handler --
|
| 806 |
-
def _check_log():
|
| 807 |
-
if _train_log_lines:
|
| 808 |
-
return "\n".join(_train_log_lines)
|
| 809 |
-
return "No training log available."
|
| 810 |
|
| 811 |
# -- Build LM model choices --
|
| 812 |
def _lm_model_choices():
|
|
@@ -909,9 +953,9 @@ def gradio_main():
|
|
| 909 |
with gr.Row(elem_classes="compact-row"):
|
| 910 |
with gr.Column(scale=2):
|
| 911 |
train_audio = gr.File(
|
| 912 |
-
label="Training Audio Files",
|
| 913 |
file_count="multiple",
|
| 914 |
-
file_types=["audio"],
|
| 915 |
)
|
| 916 |
with gr.Column(scale=1):
|
| 917 |
lora_name = gr.Textbox(label="LoRA Name", value="my-lora")
|
|
@@ -928,7 +972,6 @@ def gradio_main():
|
|
| 928 |
with gr.Row(elem_classes="compact-row"):
|
| 929 |
train_btn = gr.Button("Train", variant="primary", scale=2)
|
| 930 |
cancel_btn = gr.Button("Cancel Training", variant="stop", visible=False, scale=1)
|
| 931 |
-
log_btn = gr.Button("Check Log", scale=1)
|
| 932 |
|
| 933 |
train_output_file = gr.File(label="Trained LoRA (download)", visible=False)
|
| 934 |
train_log = gr.Textbox(
|
|
@@ -975,13 +1018,6 @@ def gradio_main():
|
|
| 975 |
outputs=[train_log],
|
| 976 |
)
|
| 977 |
|
| 978 |
-
# Check log: show last training output
|
| 979 |
-
log_btn.click(
|
| 980 |
-
_check_log,
|
| 981 |
-
outputs=[train_log],
|
| 982 |
-
api_name="check_log",
|
| 983 |
-
)
|
| 984 |
-
|
| 985 |
demo.launch(
|
| 986 |
server_name="0.0.0.0",
|
| 987 |
server_port=7860,
|
|
|
|
| 19 |
preprocess_audio,
|
| 20 |
train_lora_generator,
|
| 21 |
cancel_training,
|
| 22 |
+
_training_cancel,
|
| 23 |
get_trained_loras as _get_trained_loras_engine,
|
| 24 |
MAX_TRAINING_TIME,
|
| 25 |
)
|
| 26 |
|
| 27 |
+
logging.basicConfig(level=logging.INFO, format="%(message)s", stream=sys.stdout)
|
| 28 |
logger = logging.getLogger(__name__)
|
| 29 |
|
| 30 |
# ---------------------------------------------------------------------------
|
|
|
|
| 95 |
return {}
|
| 96 |
|
| 97 |
|
| 98 |
+
def _poll_job(job_id, timeout=600, progress_cb=None, cancel_check=None):
|
| 99 |
+
"""Poll a job until done/error/timeout/cancelled. Returns (status, elapsed)."""
|
| 100 |
t0 = time.time()
|
| 101 |
while time.time() - t0 < timeout:
|
| 102 |
+
if cancel_check and cancel_check():
|
| 103 |
+
return "cancelled", time.time() - t0
|
| 104 |
try:
|
| 105 |
+
r = requests.get(f"{ACE_SERVER}/job", params={"id": job_id}, timeout=5)
|
| 106 |
data = r.json()
|
| 107 |
status = data.get("status", "unknown")
|
| 108 |
if progress_cb:
|
|
|
|
| 111 |
return status, time.time() - t0
|
| 112 |
except Exception:
|
| 113 |
pass
|
| 114 |
+
time.sleep(1)
|
| 115 |
return "timeout", time.time() - t0
|
| 116 |
|
| 117 |
|
|
|
|
| 125 |
return r
|
| 126 |
|
| 127 |
|
|
|
|
|
|
|
| 128 |
|
| 129 |
+
def _caption_via_understand(audio_path, timeout=600, cancel_check=None):
|
| 130 |
+
"""Call ace-server /understand for a rich caption. Returns dict or None."""
|
|
|
|
| 131 |
fname = os.path.basename(audio_path)
|
| 132 |
try:
|
| 133 |
with open(audio_path, "rb") as f:
|
| 134 |
+
r = requests.post(
|
| 135 |
+
f"{ACE_SERVER}/understand",
|
| 136 |
+
files={"audio": (fname, f, "audio/mpeg")},
|
| 137 |
+
timeout=30,
|
| 138 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
if r.status_code != 200:
|
| 140 |
+
logger.warning("[Caption] %s: /understand %d: %s", fname, r.status_code, r.text[:200])
|
| 141 |
return None
|
| 142 |
job_id = r.json().get("id")
|
| 143 |
if not job_id:
|
|
|
|
| 144 |
return None
|
| 145 |
except Exception as exc:
|
| 146 |
logger.warning("[Caption] %s: /understand submit failed: %s", fname, exc)
|
| 147 |
return None
|
| 148 |
|
| 149 |
+
status, elapsed = _poll_job(job_id, timeout=timeout, cancel_check=cancel_check)
|
|
|
|
| 150 |
if status != "done":
|
| 151 |
+
logger.warning("[Caption] %s: /understand -> %s (%.0fs)", fname, status, elapsed)
|
| 152 |
return None
|
| 153 |
|
|
|
|
| 154 |
try:
|
| 155 |
r = _fetch_result(job_id, timeout=30)
|
| 156 |
if r.status_code != 200:
|
|
|
|
| 157 |
return None
|
| 158 |
data = r.json()
|
|
|
|
| 159 |
if isinstance(data, dict) and data.get("caption"):
|
| 160 |
return data
|
|
|
|
| 161 |
return None
|
| 162 |
+
except Exception:
|
|
|
|
| 163 |
return None
|
| 164 |
|
| 165 |
|
|
|
|
| 546 |
train_start = time.time()
|
| 547 |
|
| 548 |
def _log(msg):
|
| 549 |
+
elapsed = int(time.time() - train_start)
|
| 550 |
+
m, s = divmod(elapsed, 60)
|
| 551 |
+
h, m = divmod(m, 60)
|
| 552 |
+
ts = f"+{h}:{m:02d}:{s:02d}" if h else f"+{m:02d}:{s:02d}"
|
| 553 |
+
line = f"[{ts}] {msg}"
|
| 554 |
+
_train_log_lines.append(line)
|
| 555 |
+
logger.info(msg)
|
| 556 |
if len(_train_log_lines) > 2000:
|
| 557 |
_train_log_lines[:] = _train_log_lines[-1000:]
|
| 558 |
|
|
|
|
| 580 |
work_dir = os.path.join(OUTPUT_DIR, "train_workspace", lora_name)
|
| 581 |
os.makedirs(work_dir, exist_ok=True)
|
| 582 |
audio_dir = os.path.join(work_dir, "audio_input")
|
| 583 |
+
if os.path.exists(audio_dir):
|
| 584 |
+
shutil.rmtree(audio_dir)
|
| 585 |
+
os.makedirs(audio_dir)
|
| 586 |
adapter_out = os.path.join(ADAPTER_DIR, lora_name)
|
| 587 |
os.makedirs(adapter_out, exist_ok=True)
|
| 588 |
|
|
|
|
| 598 |
for f in audio_files:
|
| 599 |
src = f.name if hasattr(f, "name") else str(f)
|
| 600 |
fname = os.path.basename(src)
|
| 601 |
+
# .txt/.json sidecars: copy as caption files, skip duration check
|
| 602 |
+
if fname.lower().endswith((".txt", ".json")):
|
| 603 |
+
shutil.copy2(src, os.path.join(audio_dir, fname))
|
| 604 |
+
continue
|
| 605 |
try:
|
| 606 |
dur = _lr.get_duration(path=src)
|
| 607 |
except Exception:
|
|
|
|
| 642 |
f"Epochs: {epochs} | LR: {lr} | Rank: {rank}")
|
| 643 |
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 644 |
|
| 645 |
+
# Caption audio files: GGUF LM if ace-server running, else librosa
|
| 646 |
+
use_understand = _server_ok()
|
| 647 |
+
method = "GGUF LM (BPM, key, mood, lyrics)" if use_understand else "librosa (BPM only)"
|
| 648 |
+
_log(f"[INFO] Auto-captioning via {method}...")
|
| 649 |
+
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 650 |
+
for audio_fname in sorted(os.listdir(audio_dir)):
|
| 651 |
+
if _training_cancel.is_set():
|
| 652 |
+
break
|
| 653 |
+
full_path = os.path.join(audio_dir, audio_fname)
|
| 654 |
+
if not os.path.isfile(full_path):
|
| 655 |
+
continue
|
| 656 |
+
ext = audio_fname.lower().rsplit(".", 1)[-1] if "." in audio_fname else ""
|
| 657 |
+
if ext in ("json", "txt"):
|
| 658 |
+
continue
|
| 659 |
+
stem = audio_fname.rsplit(".", 1)[0] if "." in audio_fname else audio_fname
|
| 660 |
+
sidecar_json = os.path.join(audio_dir, stem + ".json")
|
| 661 |
+
sidecar_txt = os.path.join(audio_dir, stem + ".txt")
|
| 662 |
+
if os.path.isfile(sidecar_json) or os.path.isfile(sidecar_txt):
|
| 663 |
+
_log(f" {audio_fname}: using caption file")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 664 |
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 665 |
+
continue
|
| 666 |
+
caption_data = None
|
| 667 |
+
if use_understand:
|
| 668 |
+
_log(f" {audio_fname}: GGUF LM captioning...")
|
| 669 |
+
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 670 |
+
caption_data = _caption_via_understand(
|
| 671 |
+
full_path, timeout=600,
|
| 672 |
+
cancel_check=lambda: _training_cancel.is_set(),
|
| 673 |
+
)
|
| 674 |
+
if caption_data:
|
| 675 |
+
bpm_s = caption_data.get("bpm", "?")
|
| 676 |
+
key_s = caption_data.get("keyscale", caption_data.get("key", "?"))
|
| 677 |
+
_log(f" {audio_fname}: OK (BPM={bpm_s}, key={key_s})")
|
| 678 |
+
with open(sidecar_json, "w") as cj:
|
| 679 |
+
json.dump(caption_data, cj)
|
| 680 |
+
else:
|
| 681 |
+
try:
|
| 682 |
+
y_cap, sr_cap = _lr.load(full_path, sr=None, mono=True)
|
| 683 |
+
tempo_arr, _ = _lr.beat.beat_track(y=y_cap, sr=sr_cap)
|
| 684 |
+
bpm_val = int(round(float(tempo_arr.item() if hasattr(tempo_arr, 'item') else tempo_arr)))
|
| 685 |
+
fallback = {"caption": "", "bpm": bpm_val, "key": "", "signature": "", "lyrics": ""}
|
| 686 |
+
with open(sidecar_json, "w") as cj:
|
| 687 |
+
json.dump(fallback, cj)
|
| 688 |
+
_log(f" {audio_fname}: librosa fallback BPM={bpm_val}")
|
| 689 |
+
except Exception as cap_exc:
|
| 690 |
+
_log(f" {audio_fname}: caption failed: {cap_exc}")
|
| 691 |
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 692 |
|
| 693 |
+
if _training_cancel.is_set():
|
| 694 |
+
_training_cancel.clear()
|
| 695 |
+
_log("[CANCELLED] Stopped")
|
| 696 |
+
yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
|
| 697 |
+
shutil.rmtree(work_dir, ignore_errors=True)
|
| 698 |
+
return
|
| 699 |
+
|
| 700 |
# Stop ace-server before training (frees memory)
|
| 701 |
_training_lock.acquire()
|
| 702 |
_log("[INFO] Stopping ace-server for training...")
|
|
|
|
| 704 |
_stop_ace_server()
|
| 705 |
_gc.collect()
|
| 706 |
|
| 707 |
+
_cleanup_done = False
|
| 708 |
try:
|
| 709 |
+
# -- Phase 1: Preprocessing (runs in thread for live progress) --
|
|
|
|
|
|
|
|
|
|
| 710 |
preprocessed_dir = os.path.join(work_dir, "preprocessed_tensors")
|
| 711 |
+
_preprocess_log_len = len(_train_log_lines)
|
| 712 |
|
| 713 |
def preprocess_progress(current, total, desc):
|
| 714 |
_log(f" {desc} ({current}/{total})")
|
| 715 |
|
| 716 |
+
_preprocess_result = [None]
|
| 717 |
+
_preprocess_error = [None]
|
| 718 |
+
|
| 719 |
+
def _run_preprocess():
|
| 720 |
+
try:
|
| 721 |
+
_preprocess_result[0] = preprocess_audio(
|
| 722 |
+
audio_dir=audio_dir,
|
| 723 |
+
output_dir=preprocessed_dir,
|
| 724 |
+
checkpoint_dir=ACE_CHECKPOINT_DIR,
|
| 725 |
+
device="cpu",
|
| 726 |
+
variant="turbo",
|
| 727 |
+
max_duration=float(MAX_TOTAL_AUDIO),
|
| 728 |
+
progress_callback=preprocess_progress,
|
| 729 |
+
cancel_check=lambda: _training_cancel.is_set(),
|
| 730 |
+
)
|
| 731 |
+
except Exception as exc:
|
| 732 |
+
_preprocess_error[0] = exc
|
| 733 |
+
|
| 734 |
+
_log("[Step 1/2] Encoding audio → training data (VAE + text encoder)...")
|
| 735 |
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 736 |
|
| 737 |
+
t = threading.Thread(target=_run_preprocess, daemon=True)
|
| 738 |
+
t.start()
|
| 739 |
+
while t.is_alive():
|
| 740 |
+
t.join(timeout=3)
|
| 741 |
+
if len(_train_log_lines) > _preprocess_log_len:
|
| 742 |
+
_preprocess_log_len = len(_train_log_lines)
|
| 743 |
+
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 744 |
+
|
| 745 |
+
if _preprocess_error[0]:
|
| 746 |
+
raise _preprocess_error[0]
|
| 747 |
+
result = _preprocess_result[0]
|
| 748 |
+
|
| 749 |
+
if _training_cancel.is_set():
|
| 750 |
+
_training_cancel.clear()
|
| 751 |
+
_log("[CANCELLED] Stopped during preprocessing")
|
| 752 |
+
yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
|
| 753 |
+
return
|
| 754 |
+
|
| 755 |
processed = result.get("processed", 0)
|
| 756 |
failed = result.get("failed", 0)
|
| 757 |
total = result.get("total", 0)
|
|
|
|
| 789 |
device="cpu",
|
| 790 |
log_every=5,
|
| 791 |
):
|
|
|
|
| 792 |
elapsed = time.time() - train_start
|
| 793 |
if elapsed > MAX_TRAINING_TIME:
|
| 794 |
_log(f"[WARN] Training timed out after {int(elapsed)}s")
|
|
|
|
| 804 |
_log(f"[INFO] Total time: {time.time() - train_start:.0f}s")
|
| 805 |
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 806 |
|
| 807 |
+
except GeneratorExit:
|
| 808 |
+
_training_cancel.set()
|
| 809 |
+
logger.info("Generator closed by Gradio, cleaning up")
|
| 810 |
+
_cleanup_done = True
|
| 811 |
+
_training_lock.release()
|
| 812 |
+
_gc.collect()
|
| 813 |
+
_start_ace_server()
|
| 814 |
+
shutil.rmtree(work_dir, ignore_errors=True)
|
| 815 |
+
return
|
| 816 |
+
|
| 817 |
except Exception as exc:
|
| 818 |
_log(f"[FAIL] Training error: {exc}")
|
| 819 |
import traceback
|
|
|
|
| 821 |
yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
|
| 822 |
|
| 823 |
finally:
|
| 824 |
+
if not _cleanup_done:
|
| 825 |
+
_training_lock.release()
|
| 826 |
+
_log("[INFO] Restarting ace-server...")
|
| 827 |
+
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
| 828 |
+
_gc.collect()
|
| 829 |
+
ok = _start_ace_server()
|
| 830 |
+
if ok:
|
| 831 |
+
_log("[OK] ace-server restarted successfully")
|
| 832 |
+
else:
|
| 833 |
+
_log("[WARN] ace-server may not have restarted -- check logs")
|
| 834 |
+
adapter_safetensors = os.path.join(adapter_out, "adapter_model.safetensors")
|
| 835 |
+
if os.path.isfile(adapter_safetensors):
|
| 836 |
+
tmp_out = tempfile.NamedTemporaryFile(
|
| 837 |
+
suffix=".safetensors",
|
| 838 |
+
prefix=f"{lora_name}_",
|
| 839 |
+
delete=False,
|
| 840 |
+
)
|
| 841 |
+
tmp_out.close()
|
| 842 |
+
shutil.copy2(adapter_safetensors, tmp_out.name)
|
| 843 |
+
_log(f"[OK] LoRA saved: {lora_name}")
|
| 844 |
+
yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File(value=tmp_out.name, visible=True)
|
| 845 |
+
else:
|
| 846 |
+
yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
|
| 847 |
+
shutil.rmtree(work_dir, ignore_errors=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 848 |
|
| 849 |
# -- Cancel handler --
|
| 850 |
def _on_cancel():
|
| 851 |
cancel_training()
|
| 852 |
logger.info("Cancel requested by user")
|
| 853 |
+
return "Cancelling..."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 854 |
|
| 855 |
# -- Build LM model choices --
|
| 856 |
def _lm_model_choices():
|
|
|
|
| 953 |
with gr.Row(elem_classes="compact-row"):
|
| 954 |
with gr.Column(scale=2):
|
| 955 |
train_audio = gr.File(
|
| 956 |
+
label="Training Audio + Caption Files",
|
| 957 |
file_count="multiple",
|
| 958 |
+
file_types=["audio", ".txt", ".json"],
|
| 959 |
)
|
| 960 |
with gr.Column(scale=1):
|
| 961 |
lora_name = gr.Textbox(label="LoRA Name", value="my-lora")
|
|
|
|
| 972 |
with gr.Row(elem_classes="compact-row"):
|
| 973 |
train_btn = gr.Button("Train", variant="primary", scale=2)
|
| 974 |
cancel_btn = gr.Button("Cancel Training", variant="stop", visible=False, scale=1)
|
|
|
|
| 975 |
|
| 976 |
train_output_file = gr.File(label="Trained LoRA (download)", visible=False)
|
| 977 |
train_log = gr.Textbox(
|
|
|
|
| 1018 |
outputs=[train_log],
|
| 1019 |
)
|
| 1020 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1021 |
demo.launch(
|
| 1022 |
server_name="0.0.0.0",
|
| 1023 |
server_port=7860,
|
train_engine.py
CHANGED
|
@@ -2153,13 +2153,16 @@ def preprocess_audio(
|
|
| 2153 |
|
| 2154 |
# Auto-caption: read existing sidecar or analyze
|
| 2155 |
sidecar = _read_caption_sidecar(af)
|
| 2156 |
-
if sidecar
|
| 2157 |
-
caption = sidecar
|
| 2158 |
lyrics = sidecar.get("lyrics", "[Instrumental]")
|
| 2159 |
logger.info("[Caption] %s: using existing sidecar", af.name)
|
| 2160 |
else:
|
| 2161 |
# Auto-select analysis mode based on dataset size
|
| 2162 |
-
|
|
|
|
|
|
|
|
|
|
| 2163 |
analysis_mode = "sas"
|
| 2164 |
elif total <= 100:
|
| 2165 |
analysis_mode = "mid"
|
|
@@ -2535,11 +2538,13 @@ def train_lora_generator(
|
|
| 2535 |
# Cancel check
|
| 2536 |
if _training_cancel.is_set():
|
| 2537 |
_training_cancel.clear()
|
| 2538 |
-
|
| 2539 |
-
|
| 2540 |
-
|
| 2541 |
-
|
| 2542 |
-
|
|
|
|
|
|
|
| 2543 |
yield "[DONE]"
|
| 2544 |
_cuda_sync(device)
|
| 2545 |
unload_models(model)
|
|
|
|
| 2153 |
|
| 2154 |
# Auto-caption: read existing sidecar or analyze
|
| 2155 |
sidecar = _read_caption_sidecar(af)
|
| 2156 |
+
if sidecar is not None:
|
| 2157 |
+
caption = sidecar.get("caption", "") or af.stem
|
| 2158 |
lyrics = sidecar.get("lyrics", "[Instrumental]")
|
| 2159 |
logger.info("[Caption] %s: using existing sidecar", af.name)
|
| 2160 |
else:
|
| 2161 |
# Auto-select analysis mode based on dataset size
|
| 2162 |
+
# mid/sas use Demucs stem separation — GPU only
|
| 2163 |
+
if device == "cpu":
|
| 2164 |
+
analysis_mode = "faf"
|
| 2165 |
+
elif total <= 20:
|
| 2166 |
analysis_mode = "sas"
|
| 2167 |
elif total <= 100:
|
| 2168 |
analysis_mode = "mid"
|
|
|
|
| 2538 |
# Cancel check
|
| 2539 |
if _training_cancel.is_set():
|
| 2540 |
_training_cancel.clear()
|
| 2541 |
+
if epoch > start_epoch:
|
| 2542 |
+
early_path = str(out_path / "early_exit")
|
| 2543 |
+
model.decoder.eval()
|
| 2544 |
+
save_lora_adapter(model, early_path)
|
| 2545 |
+
yield f"[OK] Cancelled at epoch {epoch + 1}, saved to {early_path}"
|
| 2546 |
+
else:
|
| 2547 |
+
yield f"[CANCELLED] Stopped before any epoch completed"
|
| 2548 |
yield "[DONE]"
|
| 2549 |
_cuda_sync(device)
|
| 2550 |
unload_models(model)
|