Andrew committed
Commit bd37cca · 1 Parent(s): a3ab20b

github push
.env.example ADDED
@@ -0,0 +1,5 @@
+ HF_TOKEN=hf_xxx_your_token_here
+ HF_ENDPOINT_URL=https://your-endpoint-url.endpoints.huggingface.cloud
+
+ # Optional defaults used by scripts/hf_clone.py
+ HF_USERNAME=your-hf-username
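These variables are read from the environment at runtime. A minimal sketch of consuming them from Python (the `load_hf_settings` helper is hypothetical, not part of this repo):

```python
import os

def load_hf_settings():
    """Read the variables documented in .env.example from the environment."""
    token = os.getenv("HF_TOKEN")
    if not token:
        raise RuntimeError("HF_TOKEN is not set; copy .env.example to .env and fill it in")
    return {
        "token": token,
        "endpoint_url": os.getenv("HF_ENDPOINT_URL", ""),
        "username": os.getenv("HF_USERNAME", ""),
    }
```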
.gitignore CHANGED
@@ -1,4 +1,66 @@
+ # Environment
  .env
- *.bat
- *.ps1
+ .env.*
+ !.env.example
+
+ # Python cache/build artifacts
+ __pycache__/
+ *.py[cod]
+ *.pyo
+ *.pyd
+ .pytest_cache/
+ .mypy_cache/
+ .ruff_cache/
+ .ipynb_checkpoints/
+ .coverage
+ htmlcov/
+ build/
+ dist/
+ *.egg-info/
+
+ # Virtual environments
+ .venv/
+ venv/
+ env/
+
+ # Tool and local runtime caches
+ .cache/
+ .huggingface/
+ .gradio/
+
+ # Logs/temp
+ *.log
+ *.tmp
+ *.temp
+ *.bak
+
+ # Model/data/runtime artifacts
+ checkpoints/
+ lora_output/
+ outputs/
+ artifacts/
+ models/
+ datasets/
+ /data/
  *.wav
+ *.flac
+ *.mp3
+ *.ogg
+ *.opus
+ *.m4a
+ *.aac
+ *.pt
+ *.bin
+ *.safetensors
+ *.ckpt
+ *.onnx
+
+ # OS/editor
+ .DS_Store
+ Thumbs.db
+ .idea/
+ .vscode/
+
+ # Optional local working copies
+ Lora-ace-step/
+ song_summaries_llm*.md
CONTRIBUTING.md ADDED
@@ -0,0 +1,25 @@
+ # Contributing
+
+ ## Development Setup
+
+ ```bash
+ python -m pip install --upgrade pip
+ python -m pip install -r requirements.txt
+ python app.py
+ ```
+
+ ## Before Opening A PR
+
+ 1. Keep secrets out of git (`HF_TOKEN`, endpoint URLs, `.env`).
+ 2. Do not commit local artifacts (`checkpoints/`, `lora_output/`, generated audio).
+ 3. Run quick CLI sanity checks:
+    - `python lora_train.py --help`
+    - `python scripts/hf_clone.py --help`
+    - `python scripts/endpoint/generate_interactive.py --help`
+ 4. Update docs (`README.md`, `docs/deploy/*`) if behavior or workflows changed.
+
+ ## Scope Guidelines
+
+ - UI + training workflow changes belong in `lora_ui.py` / `lora_train.py`.
+ - Inference endpoint changes belong in `handler.py`.
+ - Shared ACE-Step runtime logic belongs in `acestep/`.
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 ACE-Step LoRA Studio contributors
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -0,0 +1,147 @@
+ ---
+ title: ACE-Step 1.5 LoRA Studio
+ emoji: music
+ colorFrom: blue
+ colorTo: teal
+ sdk: gradio
+ app_file: app.py
+ pinned: false
+ ---
+
+ # ACE-Step 1.5 LoRA Studio
+
+ Train ACE-Step 1.5 LoRA adapters, deploy your own Hugging Face Space, and run production-style inference through a Dedicated Endpoint.
+
+ [![Create HF Space](https://img.shields.io/badge/Create-HF%20Space-FFD21E?logo=huggingface&logoColor=black)](https://huggingface.co/new-space)
+ [![Create HF Endpoint Repo](https://img.shields.io/badge/Create-HF%20Endpoint%20Repo-FFB000?logo=huggingface&logoColor=black)](https://huggingface.co/new-model)
+ [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](LICENSE)
+
+ ## What you get
+
+ - LoRA training UI and workflow: `app.py`, `lora_ui.py`
+ - CLI LoRA trainer for local/HF datasets: `lora_train.py`
+ - Custom endpoint runtime: `handler.py`, `acestep/`
+ - Bootstrap automation for cloning into your HF account: `scripts/hf_clone.py`
+ - Endpoint test clients and HF job launcher: `scripts/endpoint/`, `scripts/jobs/`
+
+ ## Quick start (local)
+
+ ```bash
+ python -m pip install --upgrade pip
+ python -m pip install -r requirements.txt
+ python app.py
+ ```
+
+ Open `http://localhost:7860`.
+
+ ## Clone to your HF account
+
+ Use the two buttons near the top of this README to create target repos in your HF account, then run the commands below.
+
+ Set token once:
+
+ ```bash
+ # Linux/macOS
+ export HF_TOKEN=hf_xxx
+
+ # Windows PowerShell
+ $env:HF_TOKEN="hf_xxx"
+ ```
+
+ Clone your own Space:
+
+ ```bash
+ python scripts/hf_clone.py space --repo-id YOUR_USERNAME/YOUR_SPACE_NAME
+ ```
+
+ Clone your own Endpoint repo:
+
+ ```bash
+ python scripts/hf_clone.py endpoint --repo-id YOUR_USERNAME/YOUR_ENDPOINT_REPO
+ ```
+
+ Clone both in one run:
+
+ ```bash
+ python scripts/hf_clone.py all \
+     --space-repo-id YOUR_USERNAME/YOUR_SPACE_NAME \
+     --endpoint-repo-id YOUR_USERNAME/YOUR_ENDPOINT_REPO
+ ```
+
+ ## Project layout
+
+ ```text
+ .
+ |- app.py
+ |- lora_ui.py
+ |- lora_train.py
+ |- handler.py
+ |- acestep/
+ |- scripts/
+ |  |- hf_clone.py
+ |  |- endpoint/
+ |  |  |- generate_interactive.py
+ |  |  |- test.ps1
+ |  |  |- test.bat
+ |  |  |- test_rnb.bat
+ |  |  `- test_rnb_2min.bat
+ |  `- jobs/
+ |     `- submit_hf_lora_job.ps1
+ |- docs/
+ |  |- deploy/
+ |  `- guides/
+ |- summaries/
+ |  `- findings.md
+ `- templates/hf-endpoint/
+ ```
+
+ ## Dataset format
+
+ Supported audio:
+
+ - `.wav`, `.flac`, `.mp3`, `.ogg`, `.opus`, `.m4a`, `.aac`
+
+ Optional sidecar metadata per track:
+
+ - `song_001.wav`
+ - `song_001.json`
+
+ ```json
+ {
+   "caption": "melodic emotional rnb pop with warm pads",
+   "lyrics": "[Verse]\n...",
+   "bpm": 92,
+   "keyscale": "Am",
+   "timesignature": "4/4",
+   "vocal_language": "en",
+   "duration": 120
+ }
+ ```
+
+ ## Endpoint testing
+
+ ```bash
+ python scripts/endpoint/generate_interactive.py
+ ```
+
+ Or run scripted tests:
+
+ - `scripts/endpoint/test.ps1`
+ - `scripts/endpoint/test.bat`
+
+ ## Findings and notes
+
+ Current baseline analysis and improvement ideas are tracked in `summaries/findings.md`.
+
+ ## Docs
+
+ - Space deployment: `docs/deploy/SPACE.md`
+ - Endpoint deployment: `docs/deploy/ENDPOINT.md`
+ - Additional guides: `docs/guides/qwen2-audio-train.md`
+
+ ## Open-source readiness checklist
+
+ - Secrets are env-driven (`HF_TOKEN`, `HF_ENDPOINT_URL`, `.env`).
+ - Local artifacts are ignored via `.gitignore`.
+ - MIT license included.
+ - Reproducible clone/deploy paths documented.
acestep/handler.py CHANGED
@@ -24,6 +24,7 @@ from typing import Optional, Dict, Any, Tuple, List, Union
  import torch
  import torchaudio
  import soundfile as sf
+ import numpy as np
  import time
  from tqdm import tqdm
  from loguru import logger
@@ -1655,7 +1656,7 @@ class AceStepHandler:
 
          try:
              # Load audio file
-             audio, sr = torchaudio.load(audio_file)
+             audio, sr = self._load_audio_any_backend(audio_file)
 
              logger.debug(f"[process_reference_audio] Reference audio shape: {audio.shape}")
              logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}")
@@ -1710,7 +1711,7 @@ class AceStepHandler:
 
          try:
              # Load audio file
-             audio, sr = torchaudio.load(audio_file)
+             audio, sr = self._load_audio_any_backend(audio_file)
 
              # Normalize to stereo 48kHz
              audio = self._normalize_audio_to_stereo_48k(audio, sr)
@@ -1720,6 +1721,44 @@ class AceStepHandler:
          except Exception as e:
              logger.exception("[process_src_audio] Error processing source audio")
              return None
+
+     def _load_audio_any_backend(self, audio_file):
+         """Load audio with torchaudio first, then soundfile fallback."""
+         def _coerce_audio_tensor(audio_obj):
+             if isinstance(audio_obj, list):
+                 audio_obj = np.asarray(audio_obj, dtype=np.float32)
+             if isinstance(audio_obj, np.ndarray):
+                 audio_obj = torch.from_numpy(audio_obj)
+             if not torch.is_tensor(audio_obj):
+                 raise TypeError(f"Unsupported audio type: {type(audio_obj)}")
+
+             if not torch.is_floating_point(audio_obj):
+                 audio_obj = audio_obj.float()
+
+             # Normalize to [C, T]
+             if audio_obj.dim() == 1:
+                 audio_obj = audio_obj.unsqueeze(0)
+             elif audio_obj.dim() == 2:
+                 if audio_obj.shape[0] > audio_obj.shape[1] and audio_obj.shape[1] <= 8:
+                     audio_obj = audio_obj.transpose(0, 1)
+             elif audio_obj.dim() == 3:
+                 audio_obj = audio_obj[0]
+             else:
+                 raise ValueError(f"Unexpected audio dims: {tuple(audio_obj.shape)}")
+             return audio_obj.contiguous()
+
+         try:
+             audio, sr = torchaudio.load(audio_file)
+             return _coerce_audio_tensor(audio), sr
+         except Exception as torchaudio_exc:
+             try:
+                 audio_np, sr = sf.read(audio_file, dtype="float32", always_2d=True)
+                 return _coerce_audio_tensor(audio_np.T), sr
+             except Exception as sf_exc:
+                 raise RuntimeError(
+                     f"Audio decode failed for '{audio_file}' with torchaudio ({torchaudio_exc}) "
+                     f"and soundfile ({sf_exc})."
+                 ) from sf_exc
 
      def convert_src_audio_to_codes(self, audio_file) -> str:
          """
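The `[C, T]` normalization heuristic in `_load_audio_any_backend` (channels-first, transposing when soundfile's `[T, C]` layout is detected) can be illustrated standalone with plain NumPy. This sketch mirrors, but is not, the `_coerce_audio_tensor` helper above:

```python
import numpy as np

def to_channels_first(audio):
    """Return audio as a float32 [channels, samples] array."""
    audio = np.asarray(audio, dtype=np.float32)
    if audio.ndim == 1:
        # Mono [T] -> [1, T]
        audio = audio[None, :]
    elif audio.ndim == 2:
        # soundfile's always_2d=True yields [T, C]; transpose when the
        # leading axis is long and the trailing axis looks like channels.
        if audio.shape[0] > audio.shape[1] and audio.shape[1] <= 8:
            audio = audio.T
    else:
        raise ValueError(f"Unexpected audio dims: {audio.shape}")
    return audio
```

The `<= 8` channel cap is what keeps a short multichannel clip from being misread as time-major.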
acestep/llm_inference.py CHANGED
@@ -457,7 +457,7 @@ class LLMHandler:
 
          # If lm_model_path is None, use default
          if lm_model_path is None:
-             lm_model_path = "acestep-5Hz-lm-1.7B"
+             lm_model_path = "acestep-5Hz-lm-4B"
              logger.info(f"[initialize] lm_model_path is None, using default: {lm_model_path}")
 
          full_lm_model_path = os.path.join(checkpoint_dir, lm_model_path)
app.py CHANGED
@@ -1,5 +1,12 @@
  import os
 
+ # On Hugging Face Spaces Zero, `spaces` must be imported before CUDA-related modules.
+ if os.getenv("SPACE_ID"):
+     try:
+         import spaces  # noqa: F401
+     except Exception:
+         pass
+
  from lora_ui import build_ui
 
  app = build_ui()
docs/ACE-Step-1.5-LoRA-HF-Consolidated.md ADDED
@@ -0,0 +1,131 @@
+ # ACE-Step 1.5 LoRA Pipeline (Simple + HF Spaces)
+
+ Last updated: 2026-02-12
+
+ ## 1. What is already implemented in this repo
+ - Drag/drop dataset loading and folder scan.
+ - Optional per-track sidecar JSON (`song.wav` + `song.json`).
+ - New **Auto-Label All** option in `lora_ui.py`:
+   - Uses ACE audio understanding (`audio -> semantic codes -> caption/lyrics/metadata`).
+   - Writes/updates sidecar JSON for each track.
+ - LoRA training with ACE flow-matching defaults and adapter checkpoints.
+ - Training log now shows device plus elapsed time and ETA.
+
+ ## 2. Direct answers to your core questions
+
+ ### Is LoRA using HF GPU?
+ Yes. If the Space hardware is a GPU tier and the model device is `auto`/`cuda`, training runs on that Space GPU.
+
+ ### Do we get time estimates?
+ Yes. The training status now shows elapsed time and ETA in the log.
+
+ ### How are metadata and lyrics paired per song?
+ By basename in the same folder:
+ - `track01.wav`
+ - `track01.json`
+
+ ### Do you need all metadata?
+ No. In this pipeline, metadata is optional.
+ - Required minimum: audio files.
+ - Strongly recommended: `caption` and/or `lyrics` for better conditioning quality.
+ - Optional but helpful: `bpm`, `keyscale`, `timesignature`, `vocal_language`, `duration`.
+
+ ### Where are trained adapters saved?
+ - Local run: `lora_output/...` by default.
+ - HF Space run: `/data/lora_output/...` by default (as configured in UI code).
+ - The final adapter checkpoint is saved under a `final` subfolder.
+
+ ### Cloud GPU + local files?
+ - Training on Spaces uses the cloud GPU and writes artifacts to the Space filesystem.
+ - To keep results outside the Space, download them or upload them to a Hub model repo.
+
+ ### Can HF Endpoint GPU train this?
+ Not the right product. Inference Endpoints are for model serving/inference; use Spaces (interactive) or Jobs (batch) for training.
+
+ ## 3. Minimal dataset format
+
+ Drop files into one folder:
+
+ ```text
+ dataset_inbox/
+   song_a.wav
+   song_b.flac
+   song_c.mp3
+ ```
+
+ Optional sidecar for tighter control:
+
+ ```text
+ dataset_inbox/
+   song_a.wav
+   song_a.json
+ ```
+
+ Example `song_a.json`:
+
+ ```json
+ {
+   "caption": "emotional indie pop with airy female vocal and warm pads",
+   "lyrics": "[Verse]\n...",
+   "bpm": 96,
+   "keyscale": "Am",
+   "timesignature": "4/4",
+   "vocal_language": "en",
+   "duration": 120
+ }
+ ```
+
+ ## 4. Super simple training flow (UI)
+ 1. Start UI:
+    - Local: `python app.py`
+    - Space: app starts automatically from `app.py`.
+ 2. Step 1 tab: initialize `acestep-v15-base` (best LoRA plasticity).
+ 3. Step 2 tab: scan folder or drag/drop files.
+ 4. Optional: initialize auto-label LM and click **Auto-Label All**.
+ 5. Step 3 tab: keep defaults for first run, click **Start Training**.
+ 6. Click **Refresh Log** to monitor status/loss/ETA.
+ 7. Step 4 tab: load adapter from output folder and A/B test against base.
+
+ ## 5. HF Spaces setup (step by step)
+ 1. Create a new Hugging Face **Space** with SDK = `Gradio`.
+ 2. Push this repo to that Space repo.
+ 3. Ensure Space metadata/front matter includes:
+    - `sdk: gradio`
+    - `app_file: app.py`
+ 4. In Space `Settings -> Hardware`, select a GPU tier.
+ 5. In Space `Settings -> Variables and secrets`, add any needed tokens as secrets (never hardcode).
+ 6. Open the Space and run the 4-step UI flow.
+
+ ## 6. GPU association and cost control
+
+ ### Pick hardware for your stage
+ - Fast/cheap iteration: start with T4 or A10G.
+ - Heavier runs or bigger LM usage: A100/L40S/H100 class.
+
+ ### Keep spend under control
+ 1. Use the smaller auto-label LM (`0.6B`) unless you need higher-quality labels.
+ 2. Train with `acestep-v15-base` only for final-quality runs; iterate on turbo variants if needed.
+ 3. Pause or downgrade hardware immediately when idle.
+ 4. Export/upload adapters right after training so you can shut hardware down.
+
+ ### Current billing behavior to remember
+ HF Spaces docs indicate upgraded hardware is billed by the minute while the Space is running; pause or stop upgraded hardware when not in use.
+
+ ## 7. Suggested first-run defaults
+ - Model: `acestep-v15-base`
+ - LoRA rank/alpha/dropout: `64 / 64 / 0.1`
+ - Optimizer: `adamw_8bit`
+ - LR: `1e-4`
+ - Warmup: `0.03`
+ - Scheduler: `constant_with_warmup`
+ - Shift: `3.0`
+ - Max grad norm: `1.0`
+
+ ## 8. Source links (official)
+ - ACE-Step Gradio guide: https://huggingface.co/spaces/ACE-Step/Ace-Step-v1.5/blob/main/docs/en/GRADIO_GUIDE.md
+ - ACE-Step README: https://huggingface.co/spaces/ACE-Step/Ace-Step-v1.5/blob/main/README.md
+ - ACE-Step LoRA model card note (DiT-only LoRA): https://huggingface.co/ACE-Step/Ace-Step-v1.5-lo-ra-new-year
+ - HF Spaces overview: https://huggingface.co/docs/hub/en/spaces-overview
+ - HF Spaces GPU/hardware docs: https://huggingface.co/docs/hub/en/spaces-gpus
+ - HF Spaces config reference: https://huggingface.co/docs/hub/en/spaces-config-reference
+ - HF Inference Endpoints overview: https://huggingface.co/docs/inference-endpoints/en/index
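The elapsed-time/ETA readout mentioned above reduces to simple rate arithmetic: remaining steps times the average seconds per step so far. A minimal sketch (this `estimate_eta` helper is hypothetical, not the UI code):

```python
def estimate_eta(elapsed_sec, steps_done, total_steps):
    """Linear ETA: remaining steps times the average seconds per step so far."""
    if steps_done <= 0:
        return None  # no rate information yet
    sec_per_step = elapsed_sec / steps_done
    return (total_steps - steps_done) * sec_per_step
```

For example, 50 of 100 steps done in 100 s gives 2 s/step, so roughly 100 s remain.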
docs/deploy/ENDPOINT.md ADDED
@@ -0,0 +1,80 @@
+ # Deploy Inference To Your Own HF Dedicated Endpoint
+
+ This guide deploys the custom `handler.py` inference runtime to a Hugging Face Dedicated Inference Endpoint.
+
+ ## Prerequisites
+
+ - Hugging Face account
+ - `HF_TOKEN` with repo write access
+ - Dedicated Endpoint access on your HF plan
+
+ ## 1) Create/Update Your Endpoint Repo
+
+ ```bash
+ python scripts/hf_clone.py endpoint --repo-id YOUR_USERNAME/YOUR_ENDPOINT_REPO
+ ```
+
+ This uploads:
+
+ - `handler.py`
+ - `acestep/`
+ - `requirements.txt`
+ - `packages.txt`
+ - endpoint-specific README template
+
+ ## 2) Create Endpoint In HF UI
+
+ 1. Go to **Inference Endpoints** -> **New endpoint**.
+ 2. Select your custom model repo: `YOUR_USERNAME/YOUR_ENDPOINT_REPO`.
+ 3. Choose GPU hardware.
+ 4. Deploy.
+
+ ## 3) Recommended Endpoint Environment Variables
+
+ - `ACE_CONFIG_PATH` (default: `acestep-v15-sft`)
+ - `ACE_LM_MODEL_PATH` (default: `acestep-5Hz-lm-4B`)
+ - `ACE_LM_BACKEND` (default: `pt`)
+ - `ACE_DOWNLOAD_SOURCE` (`huggingface` or `modelscope`)
+ - `ACE_ENABLE_FALLBACK` (`false` recommended for strict failure visibility)
+
+ ## 4) Test The Endpoint
+
+ Set credentials:
+
+ ```bash
+ # Linux/macOS
+ export HF_TOKEN=hf_xxx
+ export HF_ENDPOINT_URL=https://your-endpoint-url.endpoints.huggingface.cloud
+
+ # Windows PowerShell
+ $env:HF_TOKEN="hf_xxx"
+ $env:HF_ENDPOINT_URL="https://your-endpoint-url.endpoints.huggingface.cloud"
+ ```
+
+ Test with:
+
+ - `python scripts/endpoint/generate_interactive.py`
+ - `scripts/endpoint/test.ps1`
+
+ ## Request Contract
+
+ ```json
+ {
+   "inputs": {
+     "prompt": "upbeat pop rap with emotional guitar",
+     "lyrics": "[Verse] city lights and midnight rain",
+     "duration_sec": 12,
+     "sample_rate": 44100,
+     "seed": 42,
+     "guidance_scale": 7.0,
+     "steps": 50,
+     "use_lm": true
+   }
+ }
+ ```
+
+ ## Cost Control
+
+ - Use scale-to-zero for idle periods.
+ - Pause endpoint for immediate spend stop.
+ - Expect cold starts when scaled to zero.
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Deploy LoRA Studio To Your Own HF Space
2
+
3
+ This guide deploys the full LoRA Studio UI to your own Hugging Face Space.
4
+
5
+ ## Prerequisites
6
+
7
+ - Hugging Face account
8
+ - `HF_TOKEN` with repo write access
9
+ - Python environment with `requirements.txt` installed
10
+
11
+ ## Fast Path (Recommended)
12
+
13
+ ```bash
14
+ python scripts/hf_clone.py space --repo-id YOUR_USERNAME/YOUR_SPACE_NAME
15
+ ```
16
+
17
+ Optional private Space:
18
+
19
+ ```bash
20
+ python scripts/hf_clone.py space --repo-id YOUR_USERNAME/YOUR_SPACE_NAME --private
21
+ ```
22
+
23
+ ## Manual Path
24
+
25
+ 1. Create a new Space on Hugging Face:
26
+ - SDK: `Gradio`
27
+ 2. Push this repo content (excluding local artifacts) to that Space repo.
28
+ 3. Ensure README front matter has:
29
+ - `sdk: gradio`
30
+ - `app_file: app.py`
31
+ 4. In Space settings:
32
+ - select GPU hardware (A10G/A100/etc.) if needed
33
+ - add secrets (`HF_TOKEN`) if your flow requires private Hub access
34
+
35
+ ## Runtime Notes
36
+
37
+ - Space output defaults to `/data/lora_output` on Hugging Face Spaces.
38
+ - Enable persistent storage if you need checkpoint retention across restarts.
39
+ - For long-running non-interactive training, HF Jobs may be more cost-efficient than keeping a Space running.
40
+
docs/guides/README.md ADDED
@@ -0,0 +1,5 @@
+ # Guides
+
+ Additional step-by-step guides that are useful but not required for the core LoRA Studio flow.
+
+ - `qwen2-audio-train.md`
docs/guides/qwen2-audio-train.md ADDED
File without changes
handler.py CHANGED
@@ -3,6 +3,7 @@ import base64
  import io
  import os
  import traceback
+ from pathlib import Path
  from typing import Any, Dict, Optional, Tuple
 
  import numpy as np
@@ -27,7 +28,7 @@ class EndpointHandler:
      "sample_rate": 44100,
      "seed": 42,
      "guidance_scale": 7.0,
-     "steps": 8,
+     "steps": 50,
      "use_lm": true,
      "simple_prompt": false,
      "instrumental": false,
@@ -50,8 +51,10 @@ class EndpointHandler:
          self.project_root = os.path.dirname(os.path.abspath(__file__))
 
          self.model_repo = os.getenv("ACE_MODEL_REPO", "ACE-Step/Ace-Step1.5")
-         self.config_path = os.getenv("ACE_CONFIG_PATH", "acestep-v15-turbo")
-         self.lm_model_path = os.getenv("ACE_LM_MODEL_PATH", "acestep-5Hz-lm-1.7B")
+         # Default to the larger quality-oriented setup.
+         # Override via ACE_CONFIG_PATH / ACE_LM_MODEL_PATH when needed.
+         self.config_path = os.getenv("ACE_CONFIG_PATH", "acestep-v15-sft")
+         self.lm_model_path = os.getenv("ACE_LM_MODEL_PATH", "acestep-5Hz-lm-4B")
          self.lm_backend = os.getenv("ACE_LM_BACKEND", "pt")
          self.download_source = os.getenv("ACE_DOWNLOAD_SOURCE", "huggingface")
 
@@ -233,6 +236,31 @@ class EndpointHandler:
 
          try:
              checkpoint_dir = os.path.join(self.project_root, "checkpoints")
+             full_lm_model_path = os.path.join(checkpoint_dir, self.lm_model_path)
+             if not os.path.exists(full_lm_model_path):
+                 try:
+                     from acestep.model_downloader import ensure_lm_model, ensure_main_model
+                 except Exception as e:
+                     self.llm_error = f"LM download helper import failed: {type(e).__name__}: {e}"
+                     return False
+
+                 # 1.7B ships with main; 0.6B/4B are standalone submodels.
+                 if self.lm_model_path == "acestep-5Hz-lm-1.7B":
+                     dl_ok, dl_msg = ensure_main_model(
+                         checkpoints_dir=Path(checkpoint_dir),
+                         prefer_source=self.download_source,
+                     )
+                 else:
+                     dl_ok, dl_msg = ensure_lm_model(
+                         model_name=self.lm_model_path,
+                         checkpoints_dir=Path(checkpoint_dir),
+                         prefer_source=self.download_source,
+                     )
+                 self.init_details["llm_download"] = dl_msg
+                 if not dl_ok:
+                     self.llm_error = f"LM download failed: {dl_msg}"
+                     return False
+
              status, ok = self.llm_handler.initialize(
                  checkpoint_dir=checkpoint_dir,
                  lm_model_path=self.lm_model_path,
@@ -352,8 +380,14 @@ class EndpointHandler:
 
          seed = self._to_int(raw_inputs.get("seed", 42), 42)
          guidance_scale = self._to_float(raw_inputs.get("guidance_scale", 7.0), 7.0)
-         steps = self._to_int(raw_inputs.get("steps", raw_inputs.get("inference_steps", 8)), 8)
+         steps = self._to_int(raw_inputs.get("steps", raw_inputs.get("inference_steps", 50)), 50)
          steps = max(1, min(steps, 200))
+         bpm_raw = raw_inputs.get("bpm")
+         bpm = None
+         if bpm_raw is not None and str(bpm_raw).strip() != "":
+             bpm = self._to_int(bpm_raw, 0)
+             if bpm <= 0:
+                 bpm = None
          use_lm = self._to_bool(raw_inputs.get("use_lm", raw_inputs.get("thinking", True)), True)
          allow_fallback = self._to_bool(raw_inputs.get("allow_fallback"), self.enable_fallback)
 
@@ -365,6 +399,7 @@ class EndpointHandler:
              "seed": seed,
              "guidance_scale": guidance_scale,
              "steps": steps,
+             "bpm": bpm,
              "use_lm": use_lm,
              "instrumental": instrumental,
              "simple_prompt": simple_prompt,
@@ -383,7 +418,7 @@ class EndpointHandler:
              "simple_expansion_error": None,
          }
 
-         bpm = None
+         bpm = req.get("bpm")
          keyscale = ""
          timesignature = ""
          vocal_language = "unknown"
@@ -399,7 +434,9 @@ class EndpointHandler:
          if getattr(sample, "success", False):
              caption = getattr(sample, "caption", "") or caption
              lyrics = getattr(sample, "lyrics", "") or lyrics
-             bpm = getattr(sample, "bpm", None)
+             sample_bpm = getattr(sample, "bpm", None)
+             if bpm is None:
+                 bpm = sample_bpm
              keyscale = getattr(sample, "keyscale", "") or ""
              timesignature = getattr(sample, "timesignature", "") or ""
              vocal_language = getattr(sample, "language", "") or "unknown"
@@ -526,6 +563,7 @@ class EndpointHandler:
              "seed": req["seed"],
              "guidance_scale": req["guidance_scale"],
              "steps": req["steps"],
+             "bpm": req.get("bpm"),
              "use_lm": req["use_lm"],
              "simple_prompt": req["simple_prompt"],
              "instrumental": req["instrumental"],
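The bpm handling added above (treat blank or non-positive request values as absent, then defer to the LM sample's bpm) reduces to a small pure function. This restatement is for illustration only and is not code from `handler.py`:

```python
def parse_bpm(raw, sample_bpm=None):
    """Return a positive int bpm from the request, falling back to the LM's value."""
    bpm = None
    if raw is not None and str(raw).strip() != "":
        try:
            bpm = int(float(raw))
        except (TypeError, ValueError):
            bpm = 0  # unparseable values are treated like non-positive ones
        if bpm <= 0:
            bpm = None
    return bpm if bpm is not None else sample_bpm
```

An explicit `"bpm": 92` in the request wins; empty strings, `None`, and junk fall through to whatever the LM inferred.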
lora_train.py ADDED
@@ -0,0 +1,1056 @@
+ """
+ ACE-Step 1.5 LoRA Training Engine
+
+ Handles dataset building, VAE encoding, and flow-matching LoRA training
+ of the DiT decoder. Designed to work with the existing AceStepHandler.
+ """
+
+ import os
+ import sys
+ import json
+ import math
+ import time
+ import random
+ import hashlib
+ import argparse
+ import tempfile
+ from pathlib import Path
+ from dataclasses import dataclass, field, asdict
+ from typing import Optional, List, Dict, Any, Tuple
+
+ import torch
+ import torch.nn.functional as F
+ import torchaudio
+ import soundfile as sf
+ import numpy as np
+ from loguru import logger
+ from tqdm import tqdm
+
+ # ---------------------------------------------------------------------------
+ # Dataset helpers
+ # ---------------------------------------------------------------------------
+
+ AUDIO_EXTENSIONS = {".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aac"}
+
+
+ @dataclass
+ class TrackEntry:
+     """One audio file + its metadata."""
+
+     audio_path: str
+     caption: str = ""
+     lyrics: str = ""
+     bpm: Optional[int] = None
+     keyscale: str = ""
+     timesignature: str = "4/4"
+     vocal_language: str = "en"
+     duration: Optional[float] = None  # seconds (measured at scan time)
+
+
+ def _load_track_entry(audio_path: Path) -> TrackEntry:
+     """Load one track + optional sidecar metadata."""
+     sidecar = audio_path.with_suffix(".json")
+     meta: Dict[str, Any] = {}
+     if sidecar.exists():
+         try:
+             meta = json.loads(sidecar.read_text(encoding="utf-8"))
+         except Exception as exc:
+             logger.warning(f"Bad sidecar {sidecar}: {exc}")
+
+     try:
+         info = torchaudio.info(str(audio_path))
+         duration = info.num_frames / info.sample_rate
+     except Exception:
+         duration = meta.get("duration")
+
+     return TrackEntry(
+         audio_path=str(audio_path),
+         caption=meta.get("caption", ""),
+         lyrics=meta.get("lyrics", ""),
+         bpm=meta.get("bpm"),
+         keyscale=meta.get("keyscale", ""),
+         timesignature=meta.get("timesignature", "4/4"),
+         vocal_language=meta.get("vocal_language", "en"),
+         duration=duration,
+     )
+
+
+ def scan_dataset_folder(folder: str) -> List[TrackEntry]:
+     """Scan *folder* for audio files and optional JSON sidecars.
+
+     For every ``track.wav`` found, if ``track.json`` exists next to it the
+     metadata fields are loaded from the sidecar. Missing sidecars are fine –
+     the entry will have empty metadata that can be filled later.
+     """
+     folder = Path(folder)
+     if not folder.is_dir():
+         raise FileNotFoundError(f"Dataset folder not found: {folder}")
+
+     entries: List[TrackEntry] = []
+     for audio_path in sorted(folder.rglob("*")):
91
+ if audio_path.suffix.lower() not in AUDIO_EXTENSIONS:
92
+ continue
93
+ entries.append(_load_track_entry(audio_path))
94
+
95
+ logger.info(f"Scanned {len(entries)} audio files in {folder}")
96
+ return entries
97
+
98
+
99
+ def scan_uploaded_files(file_paths: List[str]) -> List[TrackEntry]:
100
+ """Build entries from dropped/uploaded files.
101
+
102
+ Supports uploading audio files together with optional ``.json`` sidecars.
103
+ Sidecars are matched by basename stem (``song.mp3`` <-> ``song.json``).
104
+ """
105
+ meta_by_stem: Dict[str, Dict[str, Any]] = {}
106
+ for path in file_paths:
107
+ p = Path(path)
108
+ if not p.exists() or p.suffix.lower() != ".json":
109
+ continue
110
+ try:
111
+ meta_by_stem[p.stem] = json.loads(p.read_text(encoding="utf-8"))
112
+ except Exception as exc:
113
+ logger.warning(f"Bad uploaded sidecar {p}: {exc}")
114
+
115
+ entries: List[TrackEntry] = []
116
+ for path in file_paths:
117
+ p = Path(path)
118
+ if not p.exists() or p.suffix.lower() not in AUDIO_EXTENSIONS:
119
+ continue
120
+
121
+ uploaded_meta = meta_by_stem.get(p.stem)
122
+ if uploaded_meta is None:
123
+ entries.append(_load_track_entry(p))
124
+ continue
125
+
126
+ try:
127
+ info = torchaudio.info(str(p))
128
+ duration = info.num_frames / info.sample_rate
129
+ except Exception:
130
+ duration = uploaded_meta.get("duration")
131
+
132
+ bpm_val = uploaded_meta.get("bpm")
133
+ if isinstance(bpm_val, str) and bpm_val.strip():
134
+ try:
135
+ bpm_val = int(float(bpm_val))
136
+ except Exception:
137
+ bpm_val = None
138
+
139
+ entries.append(
140
+ TrackEntry(
141
+ audio_path=str(p),
142
+ caption=uploaded_meta.get("caption", "") or "",
143
+ lyrics=uploaded_meta.get("lyrics", "") or "",
144
+ bpm=bpm_val if isinstance(bpm_val, int) else None,
145
+ keyscale=uploaded_meta.get("keyscale", "") or "",
146
+ timesignature=uploaded_meta.get("timesignature", "4/4") or "4/4",
147
+ vocal_language=uploaded_meta.get("vocal_language", uploaded_meta.get("language", "en")) or "en",
148
+ duration=duration,
149
+ )
150
+ )
151
+
152
+ logger.info(
153
+ "Loaded {} uploaded audio files ({} uploaded sidecars detected)".format(
154
+ len(entries), len(meta_by_stem)
155
+ )
156
+ )
157
+ return entries
158
+
159
+
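Sidecar files are matched to tracks purely by basename stem. A minimal self-contained sketch of that matching rule (hypothetical file names, no audio decoding):

```python
from pathlib import Path

AUDIO_EXTENSIONS = {".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aac"}

def match_sidecars(paths):
    """Pair each audio path with the .json sidecar sharing its stem, if any."""
    sidecars = {Path(p).stem: p for p in paths if Path(p).suffix.lower() == ".json"}
    pairs = []
    for p in paths:
        if Path(p).suffix.lower() in AUDIO_EXTENSIONS:
            pairs.append((p, sidecars.get(Path(p).stem)))
    return pairs

pairs = match_sidecars(["song.mp3", "song.json", "loop.wav"])
print(pairs)  # [('song.mp3', 'song.json'), ('loop.wav', None)]
```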
160
+ # ---------------------------------------------------------------------------
161
+ # Training hyper-parameters
162
+ # ---------------------------------------------------------------------------
163
+
164
+
165
+ @dataclass
166
+ class LoRATrainConfig:
167
+ """All tuneable knobs for a LoRA run."""
168
+
169
+ # LoRA architecture
170
+ lora_rank: int = 64
171
+ lora_alpha: int = 64
172
+ lora_dropout: float = 0.1
173
+ lora_target_modules: List[str] = field(
174
+ default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
175
+ )
176
+
177
+ # Optimiser
178
+ learning_rate: float = 1e-4
179
+ weight_decay: float = 0.01
180
+ optimizer: str = "adamw_8bit" # "adamw" | "adamw_8bit"
181
+ max_grad_norm: float = 1.0
182
+
183
+ # Schedule
184
+ warmup_ratio: float = 0.03
185
+ scheduler: str = "constant_with_warmup"
186
+
187
+ # Training loop
188
+ num_epochs: int = 50
189
+ batch_size: int = 1
190
+ gradient_accumulation_steps: int = 1
191
+ save_every_n_epochs: int = 10
192
+ log_every_n_steps: int = 5
193
+
194
+ # Flow matching
195
+ shift: float = 3.0 # timestep shift factor
196
+
197
+ # Audio pre-processing
198
+ max_duration_sec: float = 240.0 # clamp audio to this length
199
+ sample_rate: int = 48000
200
+
201
+ # Paths
202
+ output_dir: str = "lora_output"
203
+ resume_from: Optional[str] = None
204
+
205
+ # Device
206
+ device: str = "auto"
207
+ dtype: str = "bf16" # "bf16" | "fp16" | "fp32"
208
+ mixed_precision: bool = True
209
+
210
+
211
+ # ---------------------------------------------------------------------------
212
+ # Core trainer
213
+ # ---------------------------------------------------------------------------
214
+
215
+
216
+ class LoRATrainer:
217
+ """Thin training loop that wraps the existing AceStepHandler."""
218
+
219
+ def __init__(self, handler, config: LoRATrainConfig):
220
+ """
221
+ Args:
222
+ handler: Initialised ``AceStepHandler`` (model, vae, text_encoder loaded).
223
+ config: Training hyper-parameters.
224
+ """
225
+ self.handler = handler
226
+ self.cfg = config
227
+
228
+ self.device = handler.device
229
+ self.dtype = handler.dtype
230
+
231
+ # Will be set during prepare()
232
+ self.peft_model = None
233
+ self.optimizer = None
234
+ self.scheduler = None
235
+ self.global_step = 0
236
+ self.current_epoch = 0
237
+
238
+ # Loss history for UI
239
+ self.loss_history: List[Dict[str, Any]] = []
240
+ self._stop_requested = False
241
+
242
+ # ------------------------------------------------------------------
243
+ # Setup
244
+ # ------------------------------------------------------------------
245
+
246
+ @staticmethod
247
+ def _resolve_lora_target_modules(model, requested_targets: Optional[List[str]]) -> List[str]:
248
+ """Resolve LoRA target module suffixes against the actual decoder module names."""
249
+ linear_module_names = [
250
+ name for name, module in model.named_modules() if isinstance(module, torch.nn.Linear)
251
+ ]
252
+
253
+ def _exists_as_suffix(target: str) -> bool:
254
+ return any(name.endswith(target) for name in linear_module_names)
255
+
256
+ requested_targets = requested_targets or []
257
+ resolved = [target for target in requested_targets if _exists_as_suffix(target)]
258
+ if resolved:
259
+ return resolved
260
+
261
+ fallback_groups = [
262
+ ["q_proj", "k_proj", "v_proj", "o_proj"],
263
+ ["to_q", "to_k", "to_v", "to_out.0"],
264
+ ["query", "key", "value", "out_proj"],
265
+ ["wq", "wk", "wv", "wo"],
266
+ ["qkv", "proj_out"],
267
+ ]
268
+ for group in fallback_groups:
269
+ group_resolved = [target for target in group if _exists_as_suffix(target)]
270
+ if len(group_resolved) >= 2:
271
+ return group_resolved
272
+
273
+ sample = ", ".join(linear_module_names[:30])
274
+ raise ValueError(
275
+ "Could not find LoRA target modules in decoder. "
276
+ f"Requested={requested_targets}. "
277
+ f"Sample linear modules: {sample}"
278
+ )
279
+
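The suffix-matching logic above can be exercised without a real model. A standalone sketch against a made-up module name list (names and fallback groups here are illustrative, not the full set used by the trainer):

```python
def resolve_targets(linear_names, requested, fallback_groups):
    """Pick requested suffixes present in linear_names, else the first
    fallback group with at least two matches (mirrors the resolution order)."""
    def exists(target):
        return any(name.endswith(target) for name in linear_names)

    resolved = [t for t in (requested or []) if exists(t)]
    if resolved:
        return resolved
    for group in fallback_groups:
        hit = [t for t in group if exists(t)]
        if len(hit) >= 2:
            return hit
    raise ValueError("no LoRA targets found")

names = ["blocks.0.attn.to_q", "blocks.0.attn.to_k",
         "blocks.0.attn.to_v", "blocks.0.attn.to_out.0"]
# Requested q_proj is absent, so resolution falls through to the to_* group.
print(resolve_targets(names, ["q_proj"], [["to_q", "to_k", "to_v", "to_out.0"]]))
# ['to_q', 'to_k', 'to_v', 'to_out.0']
```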
280
+ def prepare(self):
281
+ """Attach LoRA adapters to the decoder and build the optimiser."""
282
+ import copy
283
+ from peft import LoraConfig, PeftModel, TaskType, get_peft_model
284
+
285
+ # Keep a backup of the plain base decoder so load/unload logic remains valid.
286
+ if self.handler._base_decoder is None:
287
+ self.handler._base_decoder = copy.deepcopy(self.handler.model.decoder)
288
+ else:
289
+ self.handler.model.decoder = copy.deepcopy(self.handler._base_decoder)
290
+ self.handler.model.decoder = self.handler.model.decoder.to(self.device).to(self.dtype)
291
+ self.handler.model.decoder.eval()
292
+
293
+ resume_adapter = None
294
+ if self.cfg.resume_from:
295
+ adapter_cfg = os.path.join(self.cfg.resume_from, "adapter_config.json")
296
+ if os.path.isfile(adapter_cfg):
297
+ resume_adapter = self.cfg.resume_from
298
+
299
+ if resume_adapter:
300
+ logger.info(f"Loading existing LoRA adapter for resume: {resume_adapter}")
301
+ self.peft_model = PeftModel.from_pretrained(
302
+ self.handler.model.decoder,
303
+ resume_adapter,
304
+ is_trainable=True,
305
+ )
306
+ else:
307
+ resolved_targets = self._resolve_lora_target_modules(
308
+ self.handler.model.decoder,
309
+ self.cfg.lora_target_modules,
310
+ )
311
+ logger.info(f"Using LoRA target modules: {resolved_targets}")
312
+ peft_cfg = LoraConfig(
313
+ r=self.cfg.lora_rank,
314
+ lora_alpha=self.cfg.lora_alpha,
315
+ lora_dropout=self.cfg.lora_dropout,
316
+ target_modules=resolved_targets,
317
+ bias="none",
318
+ task_type=TaskType.FEATURE_EXTRACTION,
319
+ )
320
+ self.peft_model = get_peft_model(self.handler.model.decoder, peft_cfg)
321
+
322
+ self.peft_model.print_trainable_parameters()
323
+ self.handler.model.decoder = self.peft_model
324
+ self.handler.model.decoder.to(self.device).to(self.dtype)
325
+ self.handler.model.decoder.train()
326
+ self.handler.lora_loaded = True
327
+ self.handler.use_lora = True
328
+
329
+ # Build optimiser (only LoRA params are trainable)
330
+ trainable_params = [p for p in self.peft_model.parameters() if p.requires_grad]
331
+ if self.cfg.optimizer == "adamw_8bit":
332
+ try:
333
+ import bitsandbytes as bnb
334
+ self.optimizer = bnb.optim.AdamW8bit(
335
+ trainable_params,
336
+ lr=self.cfg.learning_rate,
337
+ weight_decay=self.cfg.weight_decay,
338
+ )
339
+ except ImportError:
340
+ logger.warning("bitsandbytes not found – falling back to standard AdamW")
341
+ self.optimizer = torch.optim.AdamW(
342
+ trainable_params,
343
+ lr=self.cfg.learning_rate,
344
+ weight_decay=self.cfg.weight_decay,
345
+ )
346
+ else:
347
+ self.optimizer = torch.optim.AdamW(
348
+ trainable_params,
349
+ lr=self.cfg.learning_rate,
350
+ weight_decay=self.cfg.weight_decay,
351
+ )
352
+
353
+ # Resume checkpoint state (after model/adapter restore).
354
+ if self.cfg.resume_from and os.path.isfile(
355
+ os.path.join(self.cfg.resume_from, "training_state.pt")
356
+ ):
357
+ state = torch.load(
358
+ os.path.join(self.cfg.resume_from, "training_state.pt"),
359
+ weights_only=False,
360
+ )
361
+ try:
362
+ self.optimizer.load_state_dict(state["optimizer"])
363
+ except Exception as exc:
364
+ logger.warning(f"Could not restore optimizer state, continuing fresh optimizer: {exc}")
365
+ self.global_step = int(state.get("global_step", 0))
366
+ # Saved epoch is completed epoch index; continue from next epoch.
367
+ self.current_epoch = int(state.get("epoch", -1)) + 1
368
+ loss_path = os.path.join(self.cfg.resume_from, "loss_history.json")
369
+ if os.path.isfile(loss_path):
370
+ try:
371
+ with open(loss_path, "r", encoding="utf-8") as f:
372
+ self.loss_history = json.load(f)
373
+ except Exception:
374
+ pass
375
+ logger.info(
376
+ f"Resumed from {self.cfg.resume_from} "
377
+ f"(epoch {self.current_epoch}, step {self.global_step})"
378
+ )
379
+
380
+ # ------------------------------------------------------------------
381
+ # Data loading
382
+ # ------------------------------------------------------------------
383
+
384
+ @staticmethod
385
+ def _coerce_audio_tensor(audio: Any) -> torch.Tensor:
386
+ """Coerce decoded audio into torch.Tensor with shape [C, T]."""
387
+ if isinstance(audio, list):
388
+ audio = np.asarray(audio, dtype=np.float32)
389
+ if isinstance(audio, np.ndarray):
390
+ audio = torch.from_numpy(audio)
391
+ if not torch.is_tensor(audio):
392
+ raise TypeError(f"Unsupported audio type: {type(audio)}")
393
+
394
+ # Ensure floating point for downstream resample/vae encode.
395
+ if not torch.is_floating_point(audio):
396
+ audio = audio.float()
397
+
398
+ # Normalize dimensions to [C, T].
399
+ if audio.dim() == 1:
400
+ audio = audio.unsqueeze(0)
401
+ elif audio.dim() == 2:
402
+ # Accept either [T, C] or [C, T]; transpose only when clearly [T, C].
403
+ if audio.shape[0] > audio.shape[1] and audio.shape[1] <= 8:
404
+ audio = audio.transpose(0, 1)
405
+ elif audio.dim() == 3:
406
+ # If batched, take first item.
407
+ audio = audio[0]
408
+ else:
409
+ raise ValueError(f"Unexpected audio dims: {tuple(audio.shape)}")
410
+
411
+ return audio.contiguous()
412
+
413
+ def _load_audio(self, path: str) -> torch.Tensor:
414
+ """Load audio, resample to 48 kHz stereo, clamp to max_duration."""
415
+ try:
416
+ wav, sr = torchaudio.load(path)
417
+ except Exception as torchaudio_exc:
418
+ # torchaudio on some Space images requires torchcodec for decode.
419
+ # Fallback to soundfile so training can proceed without torchcodec.
420
+ try:
421
+ audio_np, sr = sf.read(path, dtype="float32", always_2d=True)
422
+ wav = torch.from_numpy(audio_np.T)
423
+ except Exception as sf_exc:
424
+ raise RuntimeError(
425
+ f"Failed to decode audio '{path}' with torchaudio ({torchaudio_exc}) "
426
+ f"and soundfile ({sf_exc})."
427
+ ) from sf_exc
428
+
429
+ wav = self._coerce_audio_tensor(wav)
430
+
431
+ # Resample if needed
432
+ if sr != self.cfg.sample_rate:
433
+ wav = torchaudio.functional.resample(wav, sr, self.cfg.sample_rate)
434
+
435
+ # Convert mono → stereo
436
+ if wav.shape[0] == 1:
437
+ wav = wav.repeat(2, 1)
438
+ elif wav.shape[0] > 2:
439
+ wav = wav[:2]
440
+
441
+ # Clamp length
442
+ max_samples = int(self.cfg.max_duration_sec * self.cfg.sample_rate)
443
+ if wav.shape[1] > max_samples:
444
+ wav = wav[:, :max_samples]
445
+
446
+ return wav # [2, T]
447
+
448
+ def _encode_audio(self, wav: torch.Tensor) -> torch.Tensor:
449
+ """Encode raw waveform → VAE latent on device."""
450
+ with torch.no_grad():
451
+ latent = self.handler._encode_audio_to_latents(wav)
452
+ if latent.dim() == 2:
453
+ latent = latent.unsqueeze(0)
454
+ latent = latent.to(self.dtype)
455
+ return latent
456
+
457
+ def _build_text_embeddings(self, caption: str, lyrics: str):
458
+ """Compute text & lyric embeddings using the text encoder."""
459
+ tokenizer = self.handler.text_tokenizer
460
+ text_encoder = self.handler.text_encoder
461
+
462
+ # Caption embedding
463
+ text_tokens = tokenizer(
464
+ caption or "",
465
+ return_tensors="pt",
466
+ padding="max_length",
467
+ truncation=True,
468
+ max_length=512,
469
+ ).to(self.device)
470
+
471
+ with torch.no_grad():
472
+ text_hidden = text_encoder(
473
+ input_ids=text_tokens["input_ids"]
474
+ ).last_hidden_state.to(self.dtype)
475
+ text_mask = text_tokens["attention_mask"].to(self.dtype)
476
+
477
+ # Lyric embedding (token-level via embed_tokens)
478
+ lyric_tokens = tokenizer(
479
+ lyrics or "",
480
+ return_tensors="pt",
481
+ padding="max_length",
482
+ truncation=True,
483
+ max_length=512,
484
+ ).to(self.device)
485
+
486
+ with torch.no_grad():
487
+ lyric_hidden = text_encoder.embed_tokens(
488
+ lyric_tokens["input_ids"]
489
+ ).to(self.dtype)
490
+ lyric_mask = lyric_tokens["attention_mask"].to(self.dtype)
491
+
492
+ return text_hidden, text_mask, lyric_hidden, lyric_mask
493
+
494
+ # ------------------------------------------------------------------
495
+ # Flow matching loss
496
+ # ------------------------------------------------------------------
497
+
498
+ def _flow_matching_loss(
499
+ self,
500
+ x1: torch.Tensor,
501
+ encoder_hidden_states: torch.Tensor,
502
+ encoder_attention_mask: torch.Tensor,
503
+ context_latents: torch.Tensor,
504
+ ) -> torch.Tensor:
505
+ """Compute rectified-flow MSE loss for one sample.
506
+
507
+ Notation follows ACE-Step convention:
508
+ x0 = noise, x1 = clean latent
509
+ xt = t * x0 + (1 - t) * x1
510
+ target velocity = x0 - x1
511
+ """
512
+ bsz = x1.shape[0]
513
+
514
+ # Sample random timestep per element
515
+ t = torch.rand(bsz, device=self.device, dtype=self.dtype)
516
+
517
+ # Apply timestep shift: t_shifted = shift * t / (1 + (shift - 1) * t)
518
+ if self.cfg.shift != 1.0:
519
+ t = self.cfg.shift * t / (1.0 + (self.cfg.shift - 1.0) * t)
520
+
521
+ t = t.clamp(1e-5, 1.0 - 1e-5)
522
+
523
+ # Noise
524
+ x0 = torch.randn_like(x1)
525
+
526
+ # Interpolate
527
+ t_expand = t.view(bsz, 1, 1)
528
+ xt = t_expand * x0 + (1.0 - t_expand) * x1
529
+
530
+ # Target velocity
531
+ velocity_target = x0 - x1
532
+
533
+ # Attention mask
534
+ attention_mask = torch.ones(
535
+ bsz, x1.shape[1], device=self.device, dtype=self.dtype
536
+ )
537
+
538
+ # Forward through decoder
539
+ decoder_out = self.handler.model.decoder(
540
+ hidden_states=xt,
541
+ timestep=t,
542
+ timestep_r=t,
543
+ attention_mask=attention_mask,
544
+ encoder_hidden_states=encoder_hidden_states,
545
+ encoder_attention_mask=encoder_attention_mask,
546
+ context_latents=context_latents,
547
+ use_cache=False,
548
+ output_attentions=False,
549
+ )
550
+
551
+ velocity_pred = decoder_out[0] # first element is the predicted output
552
+ loss = F.mse_loss(velocity_pred, velocity_target)
553
+ return loss
554
+
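The timestep shift used in the loss biases uniform samples toward higher noise levels while preserving the endpoints. A quick standalone check of the mapping:

```python
def shift_timestep(t, shift=3.0):
    """t_shifted = shift * t / (1 + (shift - 1) * t), as in the loss above."""
    return shift * t / (1.0 + (shift - 1.0) * t)

# Endpoints are fixed points; the midpoint moves up toward 1.
print(shift_timestep(0.0), shift_timestep(0.5), shift_timestep(1.0))  # 0.0 0.75 1.0
```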
555
+ @staticmethod
556
+ def _pad_and_stack(tensors: List[torch.Tensor], pad_value: float = 0.0) -> torch.Tensor:
557
+ """Pad variable-length tensors on dimension 0 and stack as batch."""
558
+ normalized = []
559
+ for t in tensors:
560
+ if t.dim() >= 2 and t.shape[0] == 1:
561
+ normalized.append(t.squeeze(0))
562
+ else:
563
+ normalized.append(t)
564
+
565
+ max_len = max(t.shape[0] for t in normalized)
566
+ template = normalized[0]
567
+ out_shape = (len(normalized), max_len, *template.shape[1:])
568
+ out = template.new_full(out_shape, pad_value)
569
+ for i, t in enumerate(normalized):
570
+ out[i, : t.shape[0]] = t
571
+ return out
572
+
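`_pad_and_stack` right-pads along the time dimension before stacking. The same semantics in a small NumPy sketch (NumPy stands in for torch here):

```python
import numpy as np

def pad_and_stack(arrays, pad_value=0.0):
    """Right-pad 2-D [T, C] arrays to a common length, stack into [B, T, C]."""
    max_len = max(a.shape[0] for a in arrays)
    out = np.full((len(arrays), max_len, arrays[0].shape[1]), pad_value,
                  dtype=arrays[0].dtype)
    for i, a in enumerate(arrays):
        out[i, : a.shape[0]] = a  # leading rows are data, the rest stay padded
    return out

batch = pad_and_stack([np.ones((2, 3)), np.ones((4, 3))])
print(batch.shape)  # (2, 4, 3)
```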
573
+ # ------------------------------------------------------------------
574
+ # Main training loop
575
+ # ------------------------------------------------------------------
576
+
577
+ def request_stop(self):
578
+ """Ask the training loop to stop after the current step."""
579
+ self._stop_requested = True
580
+
581
+ def train(
582
+ self,
583
+ entries: List[TrackEntry],
584
+ progress_callback=None,
585
+ ) -> str:
586
+ """Run the full LoRA training.
587
+
588
+ Args:
589
+ entries: List of scanned TrackEntry objects.
590
+ progress_callback: ``fn(step, total_steps, loss, epoch)`` for UI updates.
591
+
592
+ Returns:
593
+ Status message.
594
+ """
595
+ self._stop_requested = False
596
+ # Keep history restored by prepare() when resuming; start fresh otherwise.
+ if not self.cfg.resume_from:
+ self.loss_history.clear()
597
+ os.makedirs(self.cfg.output_dir, exist_ok=True)
598
+
599
+ if not entries:
600
+ return "No training data provided."
601
+
602
+ num_entries = len(entries)
603
+ total_steps = (
604
+ math.ceil(num_entries / self.cfg.batch_size)
605
+ * self.cfg.num_epochs
606
+ )
607
+
608
+ # ---- Pre-encode all audio & text (fits in CPU RAM) ----
609
+ logger.info("Pre-encoding dataset through VAE & text encoder ...")
610
+ dataset: List[Dict[str, Any]] = []
611
+ failed_encode: List[str] = []
612
+
613
+ # Freeze VAE and text encoder (they are not trained)
614
+ self.handler.vae.eval()
615
+ self.handler.text_encoder.eval()
616
+
617
+ # Reuse silence reference latent (same as handler's internal fallback path).
618
+ ref_latent = self.handler.silence_latent[:, :750, :].to(self.device).to(self.dtype)
619
+ ref_order_mask = torch.zeros(1, device=self.device, dtype=torch.long)
620
+
621
+ for idx, entry in enumerate(tqdm(entries, desc="Encoding dataset")):
622
+ try:
623
+ wav = self._load_audio(entry.audio_path)
624
+ latent = self._encode_audio(wav)
625
+ text_h, text_m, lyric_h, lyric_m = self._build_text_embeddings(
626
+ entry.caption, entry.lyrics
627
+ )
628
+
629
+ # Prepare condition using the model's own prepare_condition
630
+ with torch.no_grad():
631
+ enc_hs, enc_mask, ctx_lat = self.handler.model.prepare_condition(
632
+ text_hidden_states=text_h,
633
+ text_attention_mask=text_m,
634
+ lyric_hidden_states=lyric_h,
635
+ lyric_attention_mask=lyric_m,
636
+ refer_audio_acoustic_hidden_states_packed=ref_latent,
637
+ refer_audio_order_mask=ref_order_mask,
638
+ hidden_states=latent,
639
+ attention_mask=torch.ones(
640
+ 1, latent.shape[1],
641
+ device=self.device, dtype=self.dtype,
642
+ ),
643
+ silence_latent=self.handler.silence_latent,
644
+ src_latents=latent,
645
+ chunk_masks=torch.ones_like(latent),
646
+ is_covers=[False],
647
+ )
648
+
649
+ dataset.append(
650
+ {
651
+ "latent": latent.cpu(),
652
+ "enc_hs": enc_hs.cpu(),
653
+ "enc_mask": enc_mask.cpu(),
654
+ "ctx_lat": ctx_lat.cpu(),
655
+ "name": Path(entry.audio_path).stem,
656
+ }
657
+ )
658
+ except Exception as exc:
659
+ reason = f"{Path(entry.audio_path).name}: {exc}"
660
+ failed_encode.append(reason)
661
+ logger.warning(f"Skipping {entry.audio_path}: {exc}")
662
+
663
+ if not dataset:
664
+ preview = "\n".join(f"- {msg}" for msg in failed_encode[:8]) or "- (no detailed errors captured)"
665
+ return (
666
+ "All tracks failed to encode. Check audio files.\n"
667
+ "First errors:\n"
668
+ f"{preview}\n"
669
+ "Tip: try WAV/FLAC files and dataset folder scan instead of temporary uploads."
670
+ )
671
+
672
+ logger.info(f"Encoded {len(dataset)}/{num_entries} tracks.")
673
+
674
+ # ---- Warmup scheduler ----
675
+ total_optim_steps = math.ceil(
676
+ total_steps / self.cfg.gradient_accumulation_steps
677
+ )
678
+ warmup_steps = int(total_optim_steps * self.cfg.warmup_ratio)
679
+
680
+ if self.cfg.scheduler in {"constant_with_warmup", "linear", "cosine"}:
681
+ try:
682
+ from transformers import get_scheduler
683
+ self.scheduler = get_scheduler(
684
+ name=self.cfg.scheduler,
685
+ optimizer=self.optimizer,
686
+ num_warmup_steps=warmup_steps,
687
+ num_training_steps=total_optim_steps,
688
+ )
689
+ except Exception as exc:
690
+ logger.warning(f"Could not create scheduler '{self.cfg.scheduler}', disabling scheduler: {exc}")
691
+ self.scheduler = None
692
+ else:
693
+ self.scheduler = None
694
+
695
+ # ---- Training loop ----
696
+ logger.info(
697
+ f"Starting LoRA training: {self.cfg.num_epochs} epochs, "
698
+ f"{len(dataset)} samples, {total_optim_steps} optimiser steps"
699
+ )
700
+
701
+ self.peft_model.train()
702
+ accum_loss = 0.0
703
+ step_in_accum = 0
704
+
705
+ for epoch in range(self.current_epoch, self.cfg.num_epochs):
706
+ if self._stop_requested:
707
+ break
708
+
709
+ self.current_epoch = epoch
710
+ indices = list(range(len(dataset)))
711
+ random.shuffle(indices)
712
+
713
+ epoch_loss = 0.0
714
+ epoch_steps = 0
715
+
716
+ for i in range(0, len(indices), self.cfg.batch_size):
717
+ if self._stop_requested:
718
+ break
719
+
720
+ batch_indices = indices[i : i + self.cfg.batch_size]
721
+ batch_items = [dataset[j] for j in batch_indices]
722
+
723
+ # Move batch to device
724
+ latents = self._pad_and_stack([it["latent"] for it in batch_items]).to(self.device, self.dtype)
725
+ enc_hs = self._pad_and_stack([it["enc_hs"] for it in batch_items]).to(self.device, self.dtype)
726
+ enc_mask = self._pad_and_stack([it["enc_mask"] for it in batch_items], pad_value=0.0).to(self.device)
727
+ if enc_mask.dtype != self.dtype:
728
+ enc_mask = enc_mask.to(self.dtype)
729
+ ctx_lat = self._pad_and_stack([it["ctx_lat"] for it in batch_items]).to(self.device, self.dtype)
730
+
731
+ # Forward + loss
732
+ loss = self._flow_matching_loss(latents, enc_hs, enc_mask, ctx_lat)
733
+ loss = loss / self.cfg.gradient_accumulation_steps
734
+ loss.backward()
735
+
736
+ accum_loss += loss.item()
737
+ step_in_accum += 1
738
+
739
+ if step_in_accum >= self.cfg.gradient_accumulation_steps:
740
+ torch.nn.utils.clip_grad_norm_(
741
+ self.peft_model.parameters(), self.cfg.max_grad_norm
742
+ )
743
+ self.optimizer.step()
744
+ if self.scheduler is not None:
745
+ self.scheduler.step()
746
+ self.optimizer.zero_grad()
747
+
748
+ self.global_step += 1
749
+ avg_loss = accum_loss
750
+ accum_loss = 0.0
751
+ step_in_accum = 0
752
+
753
+ self.loss_history.append(
754
+ {
755
+ "step": self.global_step,
756
+ "epoch": epoch,
757
+ "loss": avg_loss,
758
+ "lr": self.optimizer.param_groups[0]["lr"],
759
+ }
760
+ )
761
+
762
+ if self.global_step % self.cfg.log_every_n_steps == 0:
763
+ logger.info(
764
+ f"Epoch {epoch+1}/{self.cfg.num_epochs} "
765
+ f"Step {self.global_step}/{total_optim_steps} "
766
+ f"Loss {avg_loss:.6f} "
767
+ f"LR {self.optimizer.param_groups[0]['lr']:.2e}"
768
+ )
769
+
770
+ if progress_callback:
771
+ progress_callback(
772
+ self.global_step, total_optim_steps, avg_loss, epoch
773
+ )
774
+
775
+ epoch_loss += loss.item() * self.cfg.gradient_accumulation_steps
776
+ epoch_steps += 1
777
+
778
+ # Flush remaining micro-batches when len(dataset) is not divisible by grad accumulation.
779
+ if step_in_accum > 0:
780
+ torch.nn.utils.clip_grad_norm_(self.peft_model.parameters(), self.cfg.max_grad_norm)
781
+ self.optimizer.step()
782
+ if self.scheduler is not None:
783
+ self.scheduler.step()
784
+ self.optimizer.zero_grad()
785
+ self.global_step += 1
786
+ avg_loss = accum_loss
787
+ accum_loss = 0.0
788
+ step_in_accum = 0
789
+ self.loss_history.append(
790
+ {
791
+ "step": self.global_step,
792
+ "epoch": epoch,
793
+ "loss": avg_loss,
794
+ "lr": self.optimizer.param_groups[0]["lr"],
795
+ }
796
+ )
797
+
798
+ # End of epoch – checkpoint?
799
+ if (
800
+ (epoch + 1) % self.cfg.save_every_n_epochs == 0
801
+ or epoch == self.cfg.num_epochs - 1
802
+ or self._stop_requested
803
+ ):
804
+ self._save_checkpoint(epoch)
805
+
806
+ if epoch_steps > 0:
807
+ avg_epoch_loss = epoch_loss / epoch_steps
808
+ logger.info(
809
+ f"Epoch {epoch+1} complete – avg loss {avg_epoch_loss:.6f}"
810
+ )
811
+
812
+ # Final save
813
+ final_dir = self._save_checkpoint(self.current_epoch, final=True)
814
+ status = (
815
+ "Training stopped early." if self._stop_requested else "Training complete!"
816
+ )
817
+ return f"{status} Adapter saved to {final_dir}"
818
+
819
+ # ------------------------------------------------------------------
820
+ # Checkpointing
821
+ # ------------------------------------------------------------------
822
+
823
+ def _save_checkpoint(self, epoch: int, final: bool = False) -> str:
824
+ tag = "final" if final else f"epoch-{epoch+1}"
825
+ save_dir = os.path.join(self.cfg.output_dir, tag)
826
+ os.makedirs(save_dir, exist_ok=True)
827
+
828
+ # Save PEFT adapter
829
+ self.peft_model.save_pretrained(save_dir)
830
+
831
+ # Save training state
832
+ torch.save(
833
+ {
834
+ "optimizer": self.optimizer.state_dict(),
835
+ "global_step": self.global_step,
836
+ "epoch": epoch,
837
+ },
838
+ os.path.join(save_dir, "training_state.pt"),
839
+ )
840
+
841
+ # Save loss curve
842
+ loss_path = os.path.join(save_dir, "loss_history.json")
843
+ with open(loss_path, "w") as f:
844
+ json.dump(self.loss_history, f)
845
+
846
+ # Save config
847
+ cfg_path = os.path.join(save_dir, "train_config.json")
848
+ with open(cfg_path, "w") as f:
849
+ json.dump(asdict(self.cfg), f, indent=2)
850
+
851
+ logger.info(f"Checkpoint saved → {save_dir}")
852
+ return save_dir
853
+
854
+ # ------------------------------------------------------------------
855
+ # Adapter listing
856
+ # ------------------------------------------------------------------
857
+
858
+ @staticmethod
859
+ def list_adapters(output_dir: str = "lora_output") -> List[str]:
860
+ """Return adapter directories inside *output_dir* (recursive)."""
861
+ results = []
862
+ root = Path(output_dir)
863
+ if not root.is_dir():
864
+ return results
865
+ for cfg in sorted(root.rglob("adapter_config.json")):
866
+ d = cfg.parent
867
+ if d.is_dir():
868
+ results.append(str(d))
869
+ return results
870
+
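Adapter discovery keys off the presence of `adapter_config.json`. A self-contained check of the same recursive scan against a throwaway directory tree (checkpoint tags here are just examples):

```python
import json
import tempfile
from pathlib import Path

def list_adapters(output_dir):
    """Return every directory under output_dir containing adapter_config.json."""
    root = Path(output_dir)
    if not root.is_dir():
        return []
    return [str(cfg.parent) for cfg in sorted(root.rglob("adapter_config.json"))]

with tempfile.TemporaryDirectory() as tmp:
    for tag in ("epoch-10", "final"):
        d = Path(tmp, tag)
        d.mkdir(parents=True)
        (d / "adapter_config.json").write_text(json.dumps({"r": 64}))
    found = list_adapters(tmp)
    print([Path(p).name for p in found])  # ['epoch-10', 'final']
```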
871
+
872
+ def _build_arg_parser() -> argparse.ArgumentParser:
873
+ parser = argparse.ArgumentParser(description="ACE-Step 1.5 LoRA trainer (CLI)")
874
+
875
+ # Dataset
876
+ parser.add_argument("--dataset-dir", type=str, default="", help="Local dataset folder path")
877
+ parser.add_argument("--dataset-repo", type=str, default="", help="HF dataset repo id (optional)")
878
+ parser.add_argument("--dataset-revision", type=str, default="main", help="HF dataset revision")
879
+ parser.add_argument("--dataset-subdir", type=str, default="", help="Subdirectory inside downloaded dataset")
880
+
881
+ # Model init
882
+ parser.add_argument("--model-config", type=str, default="acestep-v15-base", help="DiT config name")
883
+ parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "mps", "xpu", "cpu"])
884
+ parser.add_argument("--offload-to-cpu", action="store_true")
885
+ parser.add_argument("--offload-dit-to-cpu", action="store_true")
886
+ parser.add_argument("--prefer-source", type=str, default="huggingface", choices=["huggingface", "modelscope"])
887
+
888
+ # Train config
889
+ parser.add_argument("--output-dir", type=str, default="lora_output")
890
+ parser.add_argument("--resume-from", type=str, default="")
891
+ parser.add_argument("--num-epochs", type=int, default=50)
892
+ parser.add_argument("--batch-size", type=int, default=1)
893
+ parser.add_argument("--grad-accum", type=int, default=1)
894
+ parser.add_argument("--save-every", type=int, default=10)
895
+ parser.add_argument("--log-every", type=int, default=5)
896
+ parser.add_argument("--max-duration-sec", type=float, default=240.0)
897
+
898
+ parser.add_argument("--lora-rank", type=int, default=64)
899
+ parser.add_argument("--lora-alpha", type=int, default=64)
900
+ parser.add_argument("--lora-dropout", type=float, default=0.1)
901
+
902
+ parser.add_argument("--learning-rate", type=float, default=1e-4)
903
+ parser.add_argument("--weight-decay", type=float, default=0.01)
904
+ parser.add_argument("--optimizer", type=str, default="adamw_8bit", choices=["adamw", "adamw_8bit"])
905
+ parser.add_argument("--max-grad-norm", type=float, default=1.0)
906
+ parser.add_argument("--warmup-ratio", type=float, default=0.03)
907
+ parser.add_argument("--scheduler", type=str, default="constant_with_warmup", choices=["constant_with_warmup", "linear", "cosine"])
908
+ parser.add_argument("--shift", type=float, default=3.0)
909
+
910
+ # Optional upload
911
+ parser.add_argument("--upload-repo", type=str, default="", help="HF model repo to upload final adapter")
912
+ parser.add_argument("--upload-path", type=str, default="", help="Path inside upload repo (optional)")
913
+ parser.add_argument("--upload-private", action="store_true")
914
+ parser.add_argument("--hf-token-env", type=str, default="HF_TOKEN", help="Environment variable name for HF token")
915
+
916
+ return parser
917
+
918
+
919
+ def _resolve_dataset_dir(args) -> str:
920
+ if args.dataset_dir:
921
+ return args.dataset_dir
922
+
923
+ if not args.dataset_repo:
924
+ raise ValueError("Provide --dataset-dir or --dataset-repo.")
925
+
926
+ from huggingface_hub import snapshot_download
927
+
928
+ token = os.getenv(args.hf_token_env)
929
+ temp_root = tempfile.mkdtemp(prefix="acestep_lora_dataset_")
930
+ local_dir = os.path.join(temp_root, "dataset")
931
+ logger.info(f"Downloading dataset repo {args.dataset_repo}@{args.dataset_revision} to {local_dir}")
932
+ snapshot_download(
933
+ repo_id=args.dataset_repo,
934
+ repo_type="dataset",
935
+ revision=args.dataset_revision,
936
+ local_dir=local_dir,
937
+ local_dir_use_symlinks=False,
938
+ token=token,
939
+ )
940
+ if args.dataset_subdir:
941
+ sub = os.path.join(local_dir, args.dataset_subdir)
942
+ if not os.path.isdir(sub):
943
+ raise FileNotFoundError(f"Dataset subdir not found: {sub}")
944
+ return sub
945
+ return local_dir
946
+
947
+
948
+ def _upload_adapter_if_requested(args, final_dir: str):
949
+ if not args.upload_repo:
950
+ return
951
+
952
+ from huggingface_hub import HfApi
953
+
954
+ token = os.getenv(args.hf_token_env)
955
+ if not token:
956
+ raise RuntimeError(
957
+ f"{args.hf_token_env} is not set. Needed for upload to {args.upload_repo}."
958
+ )
959
+
960
+ api = HfApi(token=token)
961
+ api.create_repo(
962
+ repo_id=args.upload_repo,
963
+ repo_type="model",
964
+ exist_ok=True,
965
+ private=bool(args.upload_private),
966
+ )
967
+
968
+ path_in_repo = args.upload_path.strip().strip("/") if args.upload_path else ""
969
+ commit_message = f"Upload ACE-Step LoRA adapter from {Path(final_dir).name}"
970
+ logger.info(f"Uploading adapter from {final_dir} to {args.upload_repo}/{path_in_repo}")
971
+ api.upload_folder(
972
+ repo_id=args.upload_repo,
973
+ repo_type="model",
974
+ folder_path=final_dir,
975
+ path_in_repo=path_in_repo,
976
+ commit_message=commit_message,
977
+ )
978
+ logger.info("Upload complete")
979
+
980
+
981
+ def main():
982
+ args = _build_arg_parser().parse_args()
983
+
984
+ dataset_dir = _resolve_dataset_dir(args)
985
+ entries = scan_dataset_folder(dataset_dir)
986
+ if not entries:
987
+ raise RuntimeError(f"No audio files found in dataset: {dataset_dir}")
988
+
989
+ from acestep.handler import AceStepHandler
990
+
991
+ project_root = str(Path(__file__).resolve().parent)
992
+ handler = AceStepHandler()
993
+ status, ok = handler.initialize_service(
994
+ project_root=project_root,
995
+ config_path=args.model_config,
996
+ device=args.device,
997
+ use_flash_attention=False,
998
+ compile_model=False,
999
+ offload_to_cpu=bool(args.offload_to_cpu),
1000
+ offload_dit_to_cpu=bool(args.offload_dit_to_cpu),
1001
+ prefer_source=args.prefer_source,
1002
+ )
1003
+ print(status)
1004
+ if not ok:
1005
+ raise RuntimeError("Model initialization failed")
1006
+
1007
+ cfg = LoRATrainConfig(
1008
+ lora_rank=args.lora_rank,
1009
+ lora_alpha=args.lora_alpha,
1010
+ lora_dropout=args.lora_dropout,
1011
+ learning_rate=args.learning_rate,
1012
+ weight_decay=args.weight_decay,
1013
+ optimizer=args.optimizer,
1014
+ max_grad_norm=args.max_grad_norm,
1015
+ warmup_ratio=args.warmup_ratio,
1016
+ scheduler=args.scheduler,
1017
+ num_epochs=args.num_epochs,
1018
+ batch_size=args.batch_size,
1019
+ gradient_accumulation_steps=args.grad_accum,
1020
+ save_every_n_epochs=args.save_every,
1021
+ log_every_n_steps=args.log_every,
1022
+ shift=args.shift,
1023
+ max_duration_sec=args.max_duration_sec,
1024
+ output_dir=args.output_dir,
1025
+ resume_from=(args.resume_from.strip() if args.resume_from else None),
1026
+ device=args.device,
1027
+ )
1028
+
1029
+ trainer = LoRATrainer(handler, cfg)
1030
+ trainer.prepare()
1031
+
1032
+ start = time.time()
1033
+
1034
+ def _progress(step, total, loss, epoch):
1035
+ elapsed = time.time() - start
1036
+ rate = step / elapsed if elapsed > 0 else 0.0
1037
+ remaining = max(0.0, total - step)
1038
+ eta_sec = remaining / rate if rate > 0 else -1.0
1039
+ eta_msg = f"{eta_sec/60:.1f}m" if eta_sec >= 0 else "unknown"
1040
+ logger.info(
1041
+ f"[progress] step={step}/{total} epoch={epoch+1} loss={loss:.6f} elapsed={elapsed/60:.1f}m eta={eta_msg}"
1042
+ )
1043
+
1044
+ msg = trainer.train(entries, progress_callback=_progress)
1045
+ print(msg)
1046
+
1047
+ final_dir = os.path.join(cfg.output_dir, "final")
1048
+ if os.path.isdir(final_dir):
1049
+ _upload_adapter_if_requested(args, final_dir)
1050
+ print(f"Final adapter directory: {final_dir}")
1051
+ else:
1052
+ print(f"Warning: final adapter directory not found at {final_dir}")
1053
+
1054
+
1055
+ if __name__ == "__main__":
1056
+ main()
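
The CLI above and the UI below exchange per-track metadata through JSON sidecars sitting next to each audio file (`scan_dataset_folder` reads them, the UI's `save_sidecar` writes them). A minimal sketch of producing one such sidecar — field names are taken from this diff, while `write_sidecar` is a hypothetical helper, not part of the repo:

```python
import json
import tempfile
from pathlib import Path

def write_sidecar(audio_path: Path, **meta) -> Path:
    # One <track>.json next to each <track>.wav, holding training metadata.
    sidecar = audio_path.with_suffix(".json")
    sidecar.write_text(json.dumps(meta, indent=2, ensure_ascii=False), encoding="utf-8")
    return sidecar

root = Path(tempfile.mkdtemp())
audio = root / "track01.wav"
audio.touch()  # stand-in for a real audio file
sidecar = write_sidecar(
    audio,
    caption="warm lo-fi beat with vinyl crackle",
    lyrics="",
    bpm=84,
    keyscale="A minor",
    timesignature="4/4",
    vocal_language="en",
    duration=183.0,
)
loaded = json.loads(sidecar.read_text(encoding="utf-8"))
print(loaded["bpm"], sidecar.name)  # prints: 84 track01.json
```

Missing or empty fields are tolerated by the loader; the auto-label flow in the UI fills them in later.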
lora_ui.py ADDED
@@ -0,0 +1,973 @@
1
+ """
2
+ ACE-Step 1.5 LoRA Training and Evaluation UI.
3
+
4
+ Gradio interface with four tabs:
5
+ 1. Model Setup: initialize base DiT, VAE, and text encoder
6
+ 2. Dataset: scan folder or drop files, then edit/save sidecars
7
+ 3. Training: configure hyperparameters and run LoRA training
8
+ 4. Evaluation: load adapters and run deterministic A/B generation
9
+ """
10
+ import os
11
+ import sys
12
+ import json
13
+ import math
14
+ import random
15
+ import threading
16
+ import tempfile
17
+ import time
18
+ from pathlib import Path
19
+ from typing import List, Optional
20
+
21
+ import gradio as gr
22
+ # On Hugging Face Spaces Zero, `spaces` must be imported before CUDA-related modules.
23
+ if os.getenv("SPACE_ID"):
24
+ try:
25
+ import spaces # noqa: F401
26
+ except Exception:
27
+ pass
28
+ import torch
29
+ from loguru import logger
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # Ensure project root is on sys.path so `acestep` imports work
33
+ # ---------------------------------------------------------------------------
34
+ PROJECT_ROOT = str(Path(__file__).resolve().parent)
35
+ if PROJECT_ROOT not in sys.path:
36
+ sys.path.insert(0, PROJECT_ROOT)
37
+
38
+ from acestep.handler import AceStepHandler
39
+ from acestep.audio_utils import AudioSaver
40
+ from acestep.llm_inference import LLMHandler
41
+ from acestep.inference import understand_music
42
+ from lora_train import (
43
+ LoRATrainConfig,
44
+ LoRATrainer,
45
+ TrackEntry,
46
+ scan_dataset_folder,
47
+ scan_uploaded_files,
48
+ )
49
+
50
+ # ---------------------------------------------------------------------------
51
+ # Globals (shared across Gradio callbacks)
52
+ # ---------------------------------------------------------------------------
53
+ handler = AceStepHandler()
54
+ llm_handler = LLMHandler()
55
+ trainer: Optional[LoRATrainer] = None
56
+ dataset_entries: List[TrackEntry] = []
57
+ _training_thread: Optional[threading.Thread] = None
58
+ _training_log: List[str] = []
59
+ _training_status: str = "idle" # idle | running | stopped | done
60
+ _training_started_at: Optional[float] = None
61
+ _model_init_ok: bool = False
62
+ _model_init_status: str = ""
63
+ _last_model_init_args: Optional[dict] = None
64
+ _lm_init_ok: bool = False
65
+ _last_lm_init_args: Optional[dict] = None
66
+ _auto_label_cursor: int = 0
67
+
68
+ audio_saver = AudioSaver(default_format="wav")
69
+ IS_SPACE = bool(os.getenv("SPACE_ID"))
70
+ DEFAULT_OUTPUT_DIR = "/data/lora_output" if IS_SPACE else "lora_output"
71
+
72
+ if IS_SPACE:
73
+ try:
74
+ import spaces as _hf_spaces
75
+ _gpu_callback = _hf_spaces.GPU(duration=300)
76
+ except Exception:
77
+ _gpu_callback = lambda fn: fn
78
+ else:
79
+ _gpu_callback = lambda fn: fn
80
+
81
+
82
+ def _rows_from_entries(entries: List[TrackEntry]):
83
+ rows = []
84
+ for e in entries:
85
+ rows.append([
86
+ Path(e.audio_path).name,
87
+ f"{e.duration:.1f}s" if e.duration else "?",
88
+ e.caption or "(none)",
89
+ (e.lyrics[:60] + "...") if e.lyrics and len(e.lyrics) > 60 else (e.lyrics or "(none)"),
90
+ e.vocal_language,
91
+ ])
92
+ return rows
93
+
94
+
95
+ # ===========================================================================
96
+ # Tab 1 - Model Setup
97
+ # ===========================================================================
98
+
99
+ def get_available_models():
100
+ models = handler.get_available_acestep_v15_models()
101
+ return models if models else ["acestep-v15-base"]
102
+
103
+
104
+ def init_model(
105
+ model_name: str,
106
+ device: str,
107
+ offload_cpu: bool,
108
+ offload_dit_cpu: bool,
109
+ ):
110
+ global _model_init_ok, _model_init_status, _last_model_init_args
111
+ _last_model_init_args = dict(
112
+ project_root=PROJECT_ROOT,
113
+ config_path=model_name,
114
+ device=device,
115
+ use_flash_attention=False,
116
+ compile_model=False,
117
+ offload_to_cpu=offload_cpu,
118
+ offload_dit_to_cpu=offload_dit_cpu,
119
+ )
120
+ status, ok = _init_model_gpu(**_last_model_init_args)
121
+ _model_init_ok = bool(ok)
122
+ _model_init_status = status or ""
123
+ return status
124
+
125
+
126
+ @_gpu_callback
127
+ def _init_model_gpu(**kwargs):
128
+ return _init_model_impl(**kwargs)
129
+
130
+
131
+ def _init_model_impl(**kwargs):
132
+ return handler.initialize_service(**kwargs)
133
+
134
+
135
+ # ===========================================================================
136
+ # Tab 2 - Dataset
137
+ # ===========================================================================
138
+
139
+ def scan_folder(folder_path: str):
140
+ global dataset_entries, _auto_label_cursor
141
+ if not folder_path or not os.path.isdir(folder_path):
142
+ return "Provide a valid folder path.", []
143
+ dataset_entries = scan_dataset_folder(folder_path)
144
+ _auto_label_cursor = 0
145
+ rows = _rows_from_entries(dataset_entries)
146
+ msg = f"Found {len(dataset_entries)} audio files."
147
+ return msg, rows
148
+
149
+
150
+ def load_uploaded(file_paths: List[str]):
151
+ global dataset_entries, _auto_label_cursor
152
+ if not file_paths:
153
+ return "Drop audio files (and optional .json sidecars) first.", []
154
+ sidecar_count = sum(
155
+ 1 for p in file_paths if isinstance(p, str) and Path(p).suffix.lower() == ".json"
156
+ )
157
+ dataset_entries = scan_uploaded_files(file_paths)
158
+ _auto_label_cursor = 0
159
+ rows = _rows_from_entries(dataset_entries)
160
+ msg = (
161
+ f"Loaded {len(dataset_entries)} dropped audio files."
162
+ + (f" Matched {sidecar_count} uploaded sidecar JSON file(s)." if sidecar_count else "")
163
+ )
164
+ return msg, rows
165
+
166
+
167
+ def save_sidecar(index: int, caption: str, lyrics: str, bpm: str, keyscale: str, lang: str):
168
+ """Save metadata edits back to a JSON sidecar and update in-memory entry."""
169
+ global dataset_entries
170
+ if index < 0 or index >= len(dataset_entries):
171
+ return "Invalid track index."
172
+ entry = dataset_entries[index]
173
+ entry.caption = caption
174
+ entry.lyrics = lyrics
175
+ if bpm.strip():
176
+ try:
177
+ entry.bpm = int(float(bpm))
178
+ except ValueError:
179
+ return "Invalid BPM value. Use an integer or leave empty."
180
+ else:
181
+ entry.bpm = None
182
+ entry.keyscale = keyscale
183
+ entry.vocal_language = lang
184
+
185
+ sidecar_path = Path(entry.audio_path).with_suffix(".json")
186
+ meta = {
187
+ "caption": entry.caption,
188
+ "lyrics": entry.lyrics,
189
+ "bpm": entry.bpm,
190
+ "keyscale": entry.keyscale,
191
+ "timesignature": entry.timesignature,
192
+ "vocal_language": entry.vocal_language,
193
+ "duration": entry.duration,
194
+ }
195
+ sidecar_path.write_text(json.dumps(meta, indent=2, ensure_ascii=False), encoding="utf-8")
196
+ return f"Saved sidecar for {Path(entry.audio_path).name}"
197
+
198
+
199
+ def init_auto_label_lm(lm_model_path: str, lm_backend: str, lm_device: str):
200
+ global _lm_init_ok, _last_lm_init_args
201
+ _last_lm_init_args = dict(
202
+ lm_model_path=lm_model_path,
203
+ lm_backend=lm_backend,
204
+ lm_device=lm_device,
205
+ )
206
+ status = _init_auto_label_lm_gpu(**_last_lm_init_args)
207
+ _lm_init_ok = not str(status).startswith("LM init failed:") and not str(status).startswith("LM init exception:")
208
+ return status
209
+
210
+
211
+ @_gpu_callback
212
+ def _init_auto_label_lm_gpu(lm_model_path: str, lm_backend: str, lm_device: str):
213
+ return _init_auto_label_lm_impl(lm_model_path, lm_backend, lm_device)
214
+
215
+
216
+ def _init_auto_label_lm_impl(lm_model_path: str, lm_backend: str, lm_device: str):
217
+ """Initialize LLM for dataset auto-labeling."""
218
+ checkpoint_dir = os.path.join(PROJECT_ROOT, "checkpoints")
219
+ full_lm_path = os.path.join(checkpoint_dir, lm_model_path)
220
+
221
+ try:
222
+ if not os.path.exists(full_lm_path):
223
+ from pathlib import Path as _Path
224
+ from acestep.model_downloader import ensure_main_model, ensure_lm_model
225
+
226
+ if lm_model_path == "acestep-5Hz-lm-1.7B":
227
+ ok, msg = ensure_main_model(
228
+ checkpoints_dir=_Path(checkpoint_dir),
229
+ prefer_source="huggingface",
230
+ )
231
+ else:
232
+ ok, msg = ensure_lm_model(
233
+ model_name=lm_model_path,
234
+ checkpoints_dir=_Path(checkpoint_dir),
235
+ prefer_source="huggingface",
236
+ )
237
+ if not ok:
238
+ return f"Failed to download LM model: {msg}"
239
+
240
+ status, ok = llm_handler.initialize(
241
+ checkpoint_dir=checkpoint_dir,
242
+ lm_model_path=lm_model_path,
243
+ backend=lm_backend,
244
+ device=lm_device,
245
+ offload_to_cpu=False,
246
+ )
247
+ return status if ok else f"LM init failed:\n{status}"
248
+ except Exception as exc:
249
+ logger.exception("LM init failed for auto-label")
250
+ return f"LM init exception: {exc}"
251
+
252
+
253
+ def _write_entry_sidecar(entry: TrackEntry):
254
+ sidecar_path = Path(entry.audio_path).with_suffix(".json")
255
+ meta = {
256
+ "caption": entry.caption,
257
+ "lyrics": entry.lyrics,
258
+ "bpm": entry.bpm,
259
+ "keyscale": entry.keyscale,
260
+ "timesignature": entry.timesignature,
261
+ "vocal_language": entry.vocal_language,
262
+ "duration": entry.duration,
263
+ }
264
+ sidecar_path.write_text(json.dumps(meta, indent=2, ensure_ascii=False), encoding="utf-8")
265
+
266
+
267
+ @_gpu_callback
268
+ def auto_label_all(overwrite_existing: bool, caption_only: bool, max_files_per_run: int = 6, reset_cursor: bool = False):
269
+ """Auto-label all loaded tracks using ACE audio understanding (audio->codes->metadata)."""
270
+ global dataset_entries, _auto_label_cursor
271
+
272
+ if handler.model is None:
273
+ if _model_init_ok and _last_model_init_args:
274
+ status, ok = _init_model_impl(**_last_model_init_args)
275
+ if not ok:
276
+ return f"Model reload failed before auto-label:\n{status}", [], "Auto-label skipped."
277
+ else:
278
+ return "Initialize model first in Step 1.", [], "Auto-label skipped."
279
+ if not dataset_entries:
280
+ return "Load dataset first in Step 2.", [], "Auto-label skipped."
281
+ if not llm_handler.llm_initialized:
282
+ if _lm_init_ok and _last_lm_init_args:
283
+ status = _init_auto_label_lm_impl(**_last_lm_init_args)
284
+ if not llm_handler.llm_initialized:
285
+ return (
286
+ f"Auto-label LM reload failed:\n{status}",
287
+ _rows_from_entries(dataset_entries),
288
+ "Auto-label skipped.",
289
+ )
290
+ else:
291
+ return "Initialize Auto-Label LM first.", _rows_from_entries(dataset_entries), "Auto-label skipped."
292
+
293
+ if max_files_per_run <= 0:
294
+ max_files_per_run = 6
295
+ if reset_cursor:
296
+ _auto_label_cursor = 0
297
+ if _auto_label_cursor < 0 or _auto_label_cursor >= len(dataset_entries):
298
+ _auto_label_cursor = 0
299
+
300
+ start_idx = _auto_label_cursor
301
+ end_idx = min(len(dataset_entries), start_idx + int(max_files_per_run))
302
+
303
+ updated = 0
304
+ skipped = 0
305
+ failed = 0
306
+ logs: List[str] = []
307
+
308
+ for idx in range(start_idx, end_idx):
309
+ entry = dataset_entries[idx]
310
+ try:
311
+ missing_fields = []
312
+ if not (entry.caption or "").strip():
313
+ missing_fields.append("caption")
314
+ if (not caption_only) and (not (entry.lyrics or "").strip()):
315
+ missing_fields.append("lyrics")
316
+ if entry.bpm is None:
317
+ missing_fields.append("bpm")
318
+ if not (entry.keyscale or "").strip():
319
+ missing_fields.append("keyscale")
320
+ if entry.duration is None:
321
+ missing_fields.append("duration")
322
+
323
+ # Skip only when every core field is already available.
324
+ if (not overwrite_existing) and (len(missing_fields) == 0):
325
+ skipped += 1
326
+ logs.append(f"[{idx}] Skipped (already fully labeled): {Path(entry.audio_path).name}")
327
+ continue
328
+
329
+ codes = handler.convert_src_audio_to_codes(entry.audio_path)
330
+ if not codes or codes.startswith("❌"):
331
+ failed += 1
332
+ logs.append(f"[{idx}] Failed to convert audio to codes: {Path(entry.audio_path).name}")
333
+ continue
334
+
335
+ result = understand_music(
336
+ llm_handler=llm_handler,
337
+ audio_codes=codes,
338
+ temperature=0.85,
339
+ use_constrained_decoding=True,
340
+ constrained_decoding_debug=False,
341
+ )
342
+ if not result.success:
343
+ failed += 1
344
+ logs.append(f"[{idx}] Failed to label: {Path(entry.audio_path).name} ({result.error or result.status_message})")
345
+ continue
346
+
347
+ # Update fields. If overwrite is false, fill only missing values.
348
+ if overwrite_existing or not (entry.caption or "").strip():
349
+ entry.caption = (result.caption or entry.caption or "").strip()
350
+ if not caption_only:
351
+ if overwrite_existing or not (entry.lyrics or "").strip():
352
+ entry.lyrics = (result.lyrics or entry.lyrics or "").strip()
353
+ if entry.bpm is None and result.bpm is not None:
354
+ entry.bpm = int(result.bpm)
355
+ if (not entry.keyscale) and result.keyscale:
356
+ entry.keyscale = result.keyscale
357
+ if (not entry.timesignature) and result.timesignature:
358
+ entry.timesignature = result.timesignature
359
+ if (not entry.vocal_language) and result.language:
360
+ entry.vocal_language = result.language
361
+ if entry.duration is None and result.duration is not None:
362
+ entry.duration = float(result.duration)
363
+
364
+ _write_entry_sidecar(entry)
365
+ updated += 1
366
+ logs.append(f"[{idx}] Labeled: {Path(entry.audio_path).name}")
367
+ except Exception as exc:
368
+ failed += 1
369
+ logs.append(f"[{idx}] Exception: {Path(entry.audio_path).name} ({exc})")
370
+
371
+ _auto_label_cursor = 0 if end_idx >= len(dataset_entries) else end_idx
372
+ mode = "caption-only" if caption_only else "caption+lyrics"
373
+ progress_msg = (
374
+ f"Processed batch {start_idx + 1}-{end_idx} of {len(dataset_entries)}. "
375
+ if len(dataset_entries) > 0 else ""
376
+ )
377
+ if _auto_label_cursor == 0 and len(dataset_entries) > 0:
378
+ progress_msg += "Reached end of dataset."
379
+ else:
380
+ progress_msg += f"Next start index: {_auto_label_cursor}."
381
+ summary = (
382
+ f"Auto-label ({mode}) complete. Updated={updated}, Skipped={skipped}, Failed={failed}. "
383
+ f"{progress_msg}"
384
+ )
385
+ detail = "\n".join(logs[-40:]) if logs else "No logs."
386
+ return summary, _rows_from_entries(dataset_entries), detail
387
+
388
+
389
+ # ===========================================================================
390
+ # Tab 3 - Training
391
+ # ===========================================================================
392
+
393
+ def _run_training(config_dict: dict):
394
+ """Target for the background training thread."""
395
+ global trainer, _training_status, _training_log, _training_started_at
396
+ _training_status = "running"
397
+ _training_log.clear()
398
+ _training_started_at = time.time()
399
+
400
+ try:
401
+ cfg = LoRATrainConfig(**config_dict)
402
+ trainer = LoRATrainer(handler, cfg)
403
+ trainer.prepare()
404
+ _training_log.append(f"Training device: {handler.device}")
405
+
406
+ def _cb(step, total, loss, epoch):
407
+ elapsed = 0.0 if _training_started_at is None else max(0.0, time.time() - _training_started_at)
408
+ rate = (step / elapsed) if elapsed > 0 else 0.0
409
+ remaining = max(0, total - step)
410
+ eta_sec = (remaining / rate) if rate > 0 else -1.0
411
+ eta_msg = f"{eta_sec/60:.1f}m" if eta_sec >= 0 else "unknown"
412
+ msg = (
413
+ f"Step {step}/{total} Epoch {epoch+1} Loss {loss:.6f} "
414
+ f"Elapsed {elapsed/60:.1f}m ETA {eta_msg}"
415
+ )
416
+ _training_log.append(msg)
417
+
418
+ result = trainer.train(dataset_entries, progress_callback=_cb)
419
+ _training_log.append(result)
420
+ _training_status = "done"
421
+ except Exception as exc:
422
+ _training_log.append(f"ERROR: {exc}")
423
+ _training_status = "stopped"
424
+ logger.exception("Training failed")
425
+
426
+
427
+ def start_training(
428
+ lora_rank, lora_alpha, lora_dropout,
429
+ lr, weight_decay, optimizer_name,
430
+ max_grad_norm, warmup_ratio, scheduler_name,
431
+ num_epochs, batch_size, grad_accum,
432
+ save_every, log_every, shift,
433
+ max_duration, output_dir, resume_from,
434
+ ):
435
+ global _training_thread, _training_status
436
+
437
+ if handler.model is None:
438
+ return "Model not initialized. Initialize the model in Step 1 first."
439
+ if not dataset_entries:
440
+ return "No dataset loaded. Go to Dataset tab first."
441
+ if _training_status == "running":
442
+ return "Training already in progress."
443
+
444
+ config_dict = dict(
445
+ lora_rank=int(lora_rank),
446
+ lora_alpha=int(lora_alpha),
447
+ lora_dropout=float(lora_dropout),
448
+ learning_rate=float(lr),
449
+ weight_decay=float(weight_decay),
450
+ optimizer=optimizer_name,
451
+ max_grad_norm=float(max_grad_norm),
452
+ warmup_ratio=float(warmup_ratio),
453
+ scheduler=scheduler_name,
454
+ num_epochs=int(num_epochs),
455
+ batch_size=int(batch_size),
456
+ gradient_accumulation_steps=int(grad_accum),
457
+ save_every_n_epochs=int(save_every),
458
+ log_every_n_steps=int(log_every),
459
+ shift=float(shift),
460
+ max_duration_sec=float(max_duration),
461
+ output_dir=output_dir,
462
+ resume_from=(resume_from.strip() if isinstance(resume_from, str) and resume_from.strip() else None),
463
+ device=str(handler.device),
464
+ )
465
+
466
+ steps_per_epoch = math.ceil(len(dataset_entries) / int(batch_size))
467
+ total_steps = steps_per_epoch * int(num_epochs)
468
+ total_optim_steps = math.ceil(total_steps / int(grad_accum))
469
+
470
+ _training_thread = threading.Thread(target=_run_training, args=(config_dict,), daemon=True)
471
+ _training_thread.start()
472
+ return (
473
+ f"Training started on {handler.device}. "
474
+ f"Estimated optimizer steps: {total_optim_steps}."
475
+ )
476
+
477
+
478
+ def stop_training():
479
+ global trainer, _training_status
480
+ if trainer:
481
+ trainer.request_stop()
482
+ _training_status = "stopped"
483
+ return "Stop requested - will finish current step."
484
+ return "No training in progress."
485
+
486
+
487
+ def poll_training():
488
+ """Return current log + loss chart data."""
489
+ log_text = "\n".join(_training_log[-50:]) if _training_log else "(no output yet)"
490
+
491
+ # Build loss curve data
492
+ chart_data = []
493
+ if trainer and trainer.loss_history:
494
+ chart_data = [[h["step"], h["loss"]] for h in trainer.loss_history]
495
+
496
+ status = _training_status
497
+ device_line = f"Device: {handler.device}"
498
+ if torch.cuda.is_available() and str(handler.device).startswith("cuda"):
499
+ try:
500
+ idx = torch.cuda.current_device()
501
+ name = torch.cuda.get_device_name(idx)
502
+ allocated = torch.cuda.memory_allocated(idx) / (1024 ** 3)
503
+ reserved = torch.cuda.memory_reserved(idx) / (1024 ** 3)
504
+ device_line = (
505
+ f"Device: {handler.device} ({name}) | "
506
+ f"VRAM allocated={allocated:.2f}GB reserved={reserved:.2f}GB"
507
+ )
508
+ except Exception:
509
+ pass
510
+
511
+ return f"Status: {status}\n{device_line}\n\n{log_text}", chart_data
512
+
513
+
514
+ # ===========================================================================
515
+ # Tab 4 - Evaluation / A-B Test
516
+ # ===========================================================================
517
+
518
+ def list_adapters(output_dir: str):
519
+ adapters = LoRATrainer.list_adapters(output_dir)
520
+ return adapters if adapters else ["(none found)"]
521
+
522
+
523
+ @_gpu_callback
524
+ def load_adapter(adapter_path: str):
525
+ if not adapter_path or adapter_path == "(none found)":
526
+ return "Select a valid adapter path."
527
+ return handler.load_lora(adapter_path)
528
+
529
+
530
+ @_gpu_callback
531
+ def unload_adapter():
532
+ return handler.unload_lora()
533
+
534
+
535
+ def set_lora_scale(scale: float):
536
+ return handler.set_lora_scale(scale)
537
+
538
+
539
+ @_gpu_callback
540
+ def generate_sample(
541
+ prompt: str,
542
+ lyrics: str,
543
+ duration: float,
544
+ bpm: int,
545
+ steps: int,
546
+ guidance: float,
547
+ seed: int,
548
+ use_lora: bool,
549
+ lora_scale: float,
550
+ ):
551
+ """Generate a single audio sample for evaluation."""
552
+ if handler.model is None:
553
+ return None, "Model not initialized."
554
+
555
+ # Toggle LoRA if loaded
556
+ if handler.lora_loaded:
557
+ handler.set_use_lora(use_lora)
558
+ if use_lora:
559
+ handler.set_lora_scale(lora_scale)
560
+
561
+ actual_seed = int(seed) if seed >= 0 else random.randint(0, 2**32 - 1)
562
+
563
+ result = handler.generate_music(
564
+ captions=prompt,
565
+ lyrics=lyrics,
566
+ bpm=bpm if bpm > 0 else None,
567
+ inference_steps=steps,
568
+ guidance_scale=guidance,
569
+ use_random_seed=False,
570
+ seed=actual_seed,
571
+ audio_duration=duration,
572
+ batch_size=1,
573
+ )
574
+
575
+ if not result.get("success", False):
576
+ return None, result.get("error", "Generation failed.")
577
+
578
+ audios = result.get("audios", [])
579
+ if not audios:
580
+ return None, "No audio produced."
581
+
582
+ # Save to temp file
583
+ audio_data = audios[0]
584
+ wav_tensor = audio_data.get("tensor")
585
+ sr = audio_data.get("sample_rate", 48000)
586
+
587
+ if wav_tensor is None:
588
+ path = audio_data.get("path")
589
+ if path and os.path.exists(path):
590
+ return path, f"Generated (from file), seed={actual_seed}."
591
+ return None, "No audio tensor."
592
+
593
+ tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
594
+ audio_saver.save_audio(wav_tensor, tmp.name, sample_rate=sr)
595
+ return tmp.name, f"Generated successfully, seed={actual_seed}."
596
+
597
+
598
+ @_gpu_callback
599
+ def ab_test(
600
+ prompt, lyrics, duration, bpm, steps, guidance, seed,
601
+ lora_scale_b,
602
+ ):
603
+ """Generate two samples: A = base, B = LoRA at given scale."""
604
+ resolved_seed = int(seed) if seed >= 0 else random.randint(0, 2**32 - 1)
605
+ results = {}
606
+ for label, use, scale in [("A (base)", False, 0.0), ("B (LoRA)", True, lora_scale_b)]:
607
+ path, msg = generate_sample(
608
+ prompt, lyrics, duration, bpm, steps, guidance, resolved_seed,
609
+ use_lora=use, lora_scale=scale,
610
+ )
611
+ results[label] = (path, msg)
612
+
613
+ return (
614
+ results["A (base)"][0],
615
+ results["A (base)"][1],
616
+ results["B (LoRA)"][0],
617
+ results["B (LoRA)"][1],
618
+ )
619
+
620
+
621
+ # ===========================================================================
622
+ # Build the Gradio App
623
+ # ===========================================================================
624
+
625
+ def get_workflow_status():
626
+ model_is_ready = (handler.model is not None) or _model_init_ok
627
+ model_ready = "Ready" if model_is_ready else "Not initialized"
628
+ tracks = len(dataset_entries)
629
+ training_state = _training_status
630
+ lora_status = handler.get_lora_status() if handler.model is not None else {"loaded": False, "active": False, "scale": 1.0}
631
+ init_note = ""
632
+ if IS_SPACE and _model_init_ok and handler.model is None:
633
+ init_note = " (Zero GPU callback context)"
634
+ return (
635
+ f"Model: {model_ready}{init_note}\n"
636
+ f"Tracks Loaded: {tracks}\n"
637
+ f"Training: {training_state}\n"
638
+ f"LoRA Loaded: {lora_status.get('loaded', False)}\n"
639
+ f"LoRA Active: {lora_status.get('active', False)}\n"
640
+ f"LoRA Scale: {lora_status.get('scale', 1.0)}"
641
+ )
642
+
643
+
644
+ def init_model_and_status(
645
+ model_name: str,
646
+ device: str,
647
+ offload_cpu: bool,
648
+ offload_dit_cpu: bool,
649
+ ):
650
+ status = init_model(model_name, device, offload_cpu, offload_dit_cpu)
651
+ return status, get_workflow_status()
652
+
653
+
654
+ def build_ui():
655
+ available_models = get_available_models()
656
+
657
+ with gr.Blocks(title="ACE-Step 1.5 LoRA Studio", theme=gr.themes.Soft()) as app:
658
+ gr.Markdown(
659
+ "# ACE-Step 1.5 LoRA Studio\n"
660
+ "Use this guided workflow from left to right.\n\n"
661
+ "**Step 1:** Initialize model \n"
662
+ "**Step 2:** Load dataset \n"
663
+ "**Step 3:** Start training \n"
664
+ "**Step 4:** Evaluate adapter"
665
+ )
666
+ with gr.Row():
667
+ workflow_status = gr.Textbox(label="Workflow Status", value=get_workflow_status(), lines=6, interactive=False)
668
+ refresh_status_btn = gr.Button("Refresh Status")
669
+ refresh_status_btn.click(get_workflow_status, outputs=workflow_status, api_name="workflow_status")
670
+
671
+ # ---- Step 1 ----
672
+ with gr.Tab("Step 1 - Initialize Model"):
673
+ gr.Markdown(
674
+ "### Instructions\n"
675
+ "1. Pick a model (`acestep-v15-base` recommended for LoRA).\n"
676
+ "2. Keep device on `auto` unless you need manual override.\n"
677
+ "3. Click **Initialize Model** and confirm status is success."
678
+ )
679
+ with gr.Row():
680
+ model_dd = gr.Dropdown(
681
+ choices=available_models,
682
+ value=available_models[0] if available_models else None,
683
+ label="DiT Model",
684
+ )
685
+ device_dd = gr.Dropdown(
686
+ choices=["auto", "cuda", "mps", "cpu"],
687
+ value="auto",
688
+ label="Device",
689
+ )
690
+ with gr.Row():
691
+ offload_cb = gr.Checkbox(label="Offload To CPU (optional)", value=False)
692
+ offload_dit_cb = gr.Checkbox(label="Offload DiT To CPU (optional)", value=False)
693
+ init_btn = gr.Button("Initialize Model", variant="primary")
694
+ init_out = gr.Textbox(label="Initialization Output", lines=8, interactive=False)
695
+ init_btn.click(
696
+ init_model_and_status,
697
+ [model_dd, device_dd, offload_cb, offload_dit_cb],
698
+ [init_out, workflow_status],
699
+ api_name="init_model",
700
+ )
701
+
702
+ # ---- Step 2 ----
703
+ with gr.Tab("Step 2 - Load Dataset"):
704
+ gr.Markdown(
705
+ "### Instructions\n"
706
+ "1. Either scan a folder or drag/drop audio files (+ optional .json sidecars).\n"
707
+ "2. Confirm tracks appear in the table.\n"
708
+ "3. Optional: run Auto-Label All to fill caption/lyrics/metas.\n"
709
+ "4. Optional: edit metadata manually and save sidecar JSON."
710
+ )
711
+ with gr.Row():
712
+ folder_input = gr.Textbox(label="Dataset Folder Path", placeholder="e.g. ./dataset_inbox")
713
+ scan_btn = gr.Button("Scan Folder")
714
+ with gr.Row():
715
+ upload_files = gr.Files(
716
+ label="Drag/Drop Audio Files (+ Optional JSON Sidecars)",
717
+ file_count="multiple",
718
+ file_types=["audio", ".json"],
719
+ type="filepath",
720
+ )
721
+ upload_btn = gr.Button("Load Dropped Files")
722
+ scan_msg = gr.Textbox(label="Dataset Result", interactive=False)
723
+ dataset_table = gr.Dataframe(
724
+ headers=["File", "Duration", "Caption", "Lyrics", "Language"],
725
+ datatype=["str", "str", "str", "str", "str"],
726
+ label="Tracks",
727
+ interactive=False,
728
+ )
729
+ scan_btn.click(
730
+ scan_folder,
731
+ folder_input,
732
+ [scan_msg, dataset_table],
733
+ api_name="scan_folder",
734
+ )
735
+ upload_btn.click(
736
+ load_uploaded,
737
+ upload_files,
738
+ [scan_msg, dataset_table],
739
+ api_name="load_uploaded",
740
+ )
741
+
742
+ with gr.Accordion("Auto-Label (ACE audio understanding)", open=False):
743
+ gr.Markdown(
744
+ "Auto-label uses ACE: audio -> semantic codes -> metadata/lyrics.\n"
745
+ "Initialize LM first, then run Auto-Label All.\n"
746
+ "Use Caption-Only if your dataset has no lyrics.\n"
747
+ "On Zero GPU, process in smaller batches and click Auto-Label All repeatedly."
748
+ )
749
+ with gr.Row():
750
+ lm_model_dd = gr.Dropdown(
751
+ choices=["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B", "acestep-5Hz-lm-4B"],
752
+ value="acestep-5Hz-lm-0.6B",
753
+ label="Auto-Label LM Model",
754
+ )
755
+ lm_backend_dd = gr.Dropdown(
756
+ choices=["pt", "vllm", "mlx"],
757
+ value="pt",
758
+ label="LM Backend",
759
+ )
760
+ lm_device_dd = gr.Dropdown(
761
+ choices=["auto", "cuda", "mps", "xpu", "cpu"],
762
+ value="auto",
763
+ label="LM Device",
764
+ )
765
+ with gr.Row():
766
+ init_lm_btn = gr.Button("Initialize Auto-Label LM")
767
+ overwrite_cb = gr.Checkbox(label="Overwrite Existing Caption/Lyrics", value=False)
768
+ caption_only_cb = gr.Checkbox(label="Caption-Only (Skip Lyrics)", value=True)
769
+ auto_label_btn = gr.Button("Auto-Label All", variant="primary")
770
+ with gr.Row():
771
+ max_files_per_run = gr.Slider(1, 25, value=6, step=1, label="Files Per Run (Zero GPU Safe)")
772
+ reset_cursor_cb = gr.Checkbox(label="Restart From First Track", value=False)
773
+ lm_init_status = gr.Textbox(label="Auto-Label LM Status", lines=5, interactive=False)
774
+ auto_label_status = gr.Textbox(label="Auto-Label Summary", interactive=False)
775
+ auto_label_log = gr.Textbox(label="Auto-Label Log", lines=8, interactive=False)
776
+ init_lm_btn.click(
777
+ init_auto_label_lm,
778
+ [lm_model_dd, lm_backend_dd, lm_device_dd],
779
+ lm_init_status,
780
+ api_name="init_auto_label_lm",
781
+ )
782
+ auto_label_btn.click(
783
+ auto_label_all,
784
+ [overwrite_cb, caption_only_cb, max_files_per_run, reset_cursor_cb],
785
+ [auto_label_status, dataset_table, auto_label_log],
786
+ api_name="auto_label_all",
787
+ )
788
+
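The "Files Per Run" slider plus the "Restart From First Track" checkbox imply a cursor that walks the dataset in batches and wraps around between clicks. A small sketch of that bookkeeping; names and signature are illustrative, not the app's internals:

```python
def next_batch(items: list, cursor: int, batch_size: int, restart: bool = False):
    """Return (batch, new_cursor); the cursor wraps to 0 once the end is reached."""
    if restart or cursor >= len(items):
        cursor = 0
    batch = items[cursor:cursor + batch_size]
    return batch, cursor + len(batch)
```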
789
+ with gr.Accordion("Optional: Edit Metadata Sidecar", open=False):
790
+ with gr.Row():
791
+ edit_idx = gr.Number(label="Track Index (0-based)", value=0, precision=0)
792
+ edit_caption = gr.Textbox(label="Caption")
793
+ edit_lyrics = gr.Textbox(label="Lyrics", lines=3)
794
+ with gr.Row():
795
+ edit_bpm = gr.Textbox(label="BPM", placeholder="e.g. 120")
796
+ edit_key = gr.Textbox(label="Key/Scale", placeholder="e.g. Am")
797
+ edit_lang = gr.Textbox(label="Language", value="en")
798
+ save_btn = gr.Button("Save Sidecar")
799
+ save_msg = gr.Textbox(label="Save Result", interactive=False)
800
+ save_btn.click(
801
+ save_sidecar,
802
+ [edit_idx, edit_caption, edit_lyrics, edit_bpm, edit_key, edit_lang],
803
+ save_msg,
804
+ api_name="save_sidecar",
805
+ )
806
+
807
+ # ---- Step 3 ----
808
+ with gr.Tab("Step 3 - Train LoRA"):
809
+ gr.Markdown(
810
+ "### Instructions\n"
811
+ "1. Keep default settings for first run.\n"
812
+ "2. Set output directory (defaults are good).\n"
813
+ "3. Click **Start Training** and monitor logs/loss.\n"
814
+ "4. Use **Stop Training** for graceful stop."
815
+ )
816
+ with gr.Row():
817
+ t_epochs = gr.Slider(1, 500, value=50, step=1, label="Epochs")
818
+ t_bs = gr.Slider(1, 8, value=1, step=1, label="Batch Size")
819
+ t_accum = gr.Slider(1, 16, value=1, step=1, label="Grad Accumulation")
820
+ with gr.Row():
821
+ t_outdir = gr.Textbox(label="Output Directory", value=DEFAULT_OUTPUT_DIR)
822
+ t_resume = gr.Textbox(label="Resume From Adapter Directory (optional)", value="")
823
+
824
+ with gr.Accordion("Advanced Training Settings (optional)", open=False):
825
+ with gr.Row():
826
+ t_rank = gr.Slider(4, 256, value=64, step=4, label="LoRA Rank")
827
+ t_alpha = gr.Slider(4, 256, value=64, step=4, label="LoRA Alpha")
828
+ t_dropout = gr.Slider(0.0, 0.5, value=0.1, step=0.01, label="LoRA Dropout")
829
+ with gr.Row():
830
+ t_lr = gr.Number(label="Learning Rate", value=1e-4)
831
+ t_wd = gr.Number(label="Weight Decay", value=0.01)
832
+ t_optim = gr.Dropdown(["adamw", "adamw_8bit"], value="adamw_8bit", label="Optimizer")
833
+ with gr.Row():
834
+ t_grad_norm = gr.Number(label="Max Grad Norm", value=1.0)
835
+ t_warmup = gr.Number(label="Warmup Ratio", value=0.03)
836
+ t_sched = gr.Dropdown(
837
+ ["constant_with_warmup", "linear", "cosine"],
838
+ value="constant_with_warmup",
839
+ label="Scheduler",
840
+ )
841
+ with gr.Row():
842
+ t_save = gr.Slider(1, 100, value=10, step=1, label="Save Every N Epochs")
843
+ t_log = gr.Slider(1, 100, value=5, step=1, label="Log Every N Steps")
844
+ t_shift = gr.Number(label="Timestep Shift", value=3.0)
845
+ t_maxdur = gr.Number(label="Max Audio Duration (s)", value=240)
846
+
847
+ with gr.Row():
848
+ train_btn = gr.Button("Start Training", variant="primary")
849
+ stop_btn = gr.Button("Stop Training", variant="stop")
850
+ poll_btn = gr.Button("Refresh Log")
851
+
852
+ train_status = gr.Textbox(label="Training Log", lines=12, interactive=False)
853
+ loss_chart = gr.LinePlot(
854
+ x="Step",
855
+ y="Loss",
856
+ title="Training Loss",
857
+ x_title="Step",
858
+ y_title="Loss",
859
+ )
860
+
861
+ train_btn.click(
862
+ start_training,
863
+ [
864
+ t_rank, t_alpha, t_dropout,
865
+ t_lr, t_wd, t_optim,
866
+ t_grad_norm, t_warmup, t_sched,
867
+ t_epochs, t_bs, t_accum,
868
+ t_save, t_log, t_shift,
869
+ t_maxdur, t_outdir, t_resume,
870
+ ],
871
+ train_status,
872
+ api_name="start_training",
873
+ )
874
+ stop_btn.click(stop_training, outputs=train_status, api_name="stop_training")
875
+
876
+ def _poll_and_format():
+ import pandas as pd  # single import serves both branches
+ log_text, chart_data = poll_training()
+ if chart_data:
+ df = pd.DataFrame(chart_data, columns=["Step", "Loss"])
+ else:
+ df = pd.DataFrame({"Step": [], "Loss": []})
+ return log_text, df
885
+
886
+ poll_btn.click(_poll_and_format, outputs=[train_status, loss_chart], api_name="poll_training")
887
+
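To drive the loss chart, `(Step, Loss)` pairs have to be pulled out of the raw training log. A sketch under an assumed log format; the real trainer's lines may differ, so treat the regex as a starting point:

```python
import re

# Hypothetical log format such as "step 5 | loss=0.42".
LOG_RE = re.compile(r"step\s*(\d+).*?loss\s*[=:]\s*([0-9.]+)", re.IGNORECASE)

def parse_loss_points(log_text: str) -> list[tuple[int, float]]:
    """Extract (step, loss) pairs for the chart from raw log text."""
    points = []
    for line in log_text.splitlines():
        m = LOG_RE.search(line)
        if m:
            points.append((int(m.group(1)), float(m.group(2))))
    return points
```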
888
+ # ---- Step 4 ----
889
+ with gr.Tab("Step 4 - Evaluate"):
890
+ gr.Markdown(
891
+ "### Instructions\n"
892
+ "1. Refresh adapter list and load a trained adapter.\n"
893
+ "2. Run single generation or A/B test.\n"
894
+ "3. Use same seed for fair comparison."
895
+ )
896
+
897
+ with gr.Accordion("Adapter Management", open=True):
898
+ with gr.Row():
899
+ adapter_dir = gr.Textbox(label="Adapters Directory", value=DEFAULT_OUTPUT_DIR)
900
+ refresh_btn = gr.Button("Refresh List")
901
+ adapter_dd = gr.Dropdown(label="Select Adapter", choices=[])
902
+ with gr.Row():
903
+ load_btn = gr.Button("Load Adapter", variant="primary")
904
+ unload_btn = gr.Button("Unload Adapter")
905
+ adapter_status = gr.Textbox(label="Adapter Status", interactive=False)
906
+
907
+ def _refresh(d):
908
+ adapters = list_adapters(d)
909
+ return gr.update(choices=adapters, value=adapters[0] if adapters else None)
910
+
911
+ refresh_btn.click(_refresh, adapter_dir, adapter_dd, api_name="list_adapters")
912
+ load_btn.click(load_adapter, adapter_dd, adapter_status, api_name="load_adapter")
913
+ unload_btn.click(unload_adapter, outputs=adapter_status, api_name="unload_adapter")
914
+
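Refreshing the adapter list amounts to finding checkpoint directories. A sketch assuming PEFT's convention that an adapter directory contains `adapter_config.json`; the app's `list_adapters` may use different criteria:

```python
from pathlib import Path

def find_adapter_dirs(adapters_dir: str) -> list[str]:
    """List subdirectories that look like PEFT adapters
    (assumption: presence of adapter_config.json marks a checkpoint)."""
    root = Path(adapters_dir)
    if not root.is_dir():
        return []
    return sorted(
        str(p) for p in root.iterdir()
        if p.is_dir() and (p / "adapter_config.json").exists()
    )
```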
915
+ with gr.Accordion("Generation Settings", open=True):
916
+ with gr.Row():
917
+ eval_prompt = gr.Textbox(label="Prompt / Caption", lines=2, placeholder="upbeat pop rock with electric guitar")
918
+ eval_lyrics = gr.Textbox(label="Lyrics", lines=3, placeholder="[Instrumental]")
919
+ with gr.Row():
920
+ eval_dur = gr.Slider(10, 300, value=30, step=5, label="Duration (s)")
921
+ eval_bpm = gr.Number(label="BPM (0 = auto)", value=0)
922
+ eval_steps = gr.Slider(1, 100, value=8, step=1, label="Inference Steps")
923
+ with gr.Row():
924
+ eval_guidance = gr.Slider(1.0, 15.0, value=7.0, step=0.5, label="Guidance Scale")
925
+ eval_seed = gr.Number(label="Seed (-1 = random)", value=-1)
926
+
927
+ with gr.Row():
928
+ sg_use_lora = gr.Checkbox(label="Use LoRA", value=True)
929
+ sg_scale = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="LoRA Scale")
930
+ sg_btn = gr.Button("Generate", variant="primary")
931
+ sg_audio = gr.Audio(label="Single Output", type="filepath")
932
+ sg_msg = gr.Textbox(label="Generation Status", interactive=False)
933
+ sg_btn.click(
934
+ generate_sample,
935
+ [eval_prompt, eval_lyrics, eval_dur, eval_bpm, eval_steps, eval_guidance, eval_seed, sg_use_lora, sg_scale],
936
+ [sg_audio, sg_msg],
937
+ api_name="generate_sample",
938
+ )
939
+
940
+ gr.Markdown("#### A/B Test (Base vs LoRA)")
941
+ with gr.Row():
942
+ ab_scale = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="LoRA Scale for B")
943
+ ab_btn = gr.Button("Run A/B Test")
944
+ with gr.Row():
945
+ ab_audio_a = gr.Audio(label="A - Base", type="filepath")
946
+ ab_audio_b = gr.Audio(label="B - Base + LoRA", type="filepath")
947
+ with gr.Row():
948
+ ab_msg_a = gr.Textbox(label="Status A", interactive=False)
949
+ ab_msg_b = gr.Textbox(label="Status B", interactive=False)
950
+
951
+ ab_btn.click(
952
+ ab_test,
953
+ [eval_prompt, eval_lyrics, eval_dur, eval_bpm, eval_steps, eval_guidance, eval_seed, ab_scale],
954
+ [ab_audio_a, ab_msg_a, ab_audio_b, ab_msg_b],
955
+ api_name="ab_test",
956
+ )
957
+
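The `-1 = random` seed convention used by both single generation and the A/B test can be handled with a tiny helper; a sketch with a hypothetical name:

```python
import random

def resolve_seed(seed: int) -> int:
    """Any negative value (the UI's -1) means 'pick a random seed'."""
    if seed is None or seed < 0:
        return random.randint(0, 2**31 - 1)
    return int(seed)
```

Reusing the resolved value for both the base and LoRA runs is what makes the A/B comparison fair.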
958
+ app.queue(default_concurrency_limit=1)
959
+ return app
960
+
961
+
962
+ # ===========================================================================
963
+ # Entry point
964
+ # ===========================================================================
965
+
966
+ if __name__ == "__main__":
967
+ app = build_ui()
968
+ app.launch(
969
+ server_name="0.0.0.0",
970
+ server_port=7860,
971
+ share=False,
972
+ )
973
+
packages.txt ADDED
@@ -0,0 +1,2 @@
1
+ ffmpeg
2
+ libsndfile1
requirements.txt CHANGED
@@ -14,3 +14,7 @@ vector-quantize-pytorch
14
  PyYAML
15
  modelscope
16
  filelock>=3.13.0
17
+ peft>=0.11.0
18
+ gradio>=4.0.0
19
+ pandas
20
+ bitsandbytes
scripts/endpoint/generate_interactive.py ADDED
@@ -0,0 +1,223 @@
1
+ import argparse
2
+ import base64
3
+ import json
4
+ import os
5
+ import sys
6
+ import time
7
+ from pathlib import Path
8
+ from urllib.error import HTTPError, URLError
9
+ from urllib.request import Request, urlopen
10
+
11
+ DEFAULT_URL = "https://your-endpoint-url.endpoints.huggingface.cloud"
12
+ DEFAULT_SAMPLE_RATE = 44100
13
+
14
+
15
+ def read_dotenv_value(key: str, dotenv_path: str = ".env") -> str:
16
+ path = Path(dotenv_path)
17
+ if not path.exists():
18
+ return ""
19
+ for raw in path.read_text(encoding="utf-8").splitlines():
20
+ line = raw.strip()
21
+ if not line or line.startswith("#") or "=" not in line:
22
+ continue
23
+ k, v = line.split("=", 1)
24
+ if k.strip() == key:
25
+ return v.strip().strip('"').strip("'")
26
+ return ""
27
+
28
+
29
+ def prompt_text(label: str, default: str = "", required: bool = False) -> str:
30
+ while True:
31
+ suffix = f" [{default}]" if default else ""
32
+ value = input(f"{label}{suffix}: ").strip()
33
+ if not value:
34
+ value = default
35
+ if value or not required:
36
+ return value
37
+ print("Value required.")
38
+
39
+
40
+ def prompt_int(label: str, default: int | None = None, allow_blank: bool = False) -> int | None:
41
+ while True:
42
+ default_str = "" if default is None else str(default)
43
+ value = prompt_text(label, default_str, required=not allow_blank)
44
+ if not value and allow_blank:
45
+ return None
46
+ try:
47
+ return int(value)
48
+ except ValueError:
49
+ print("Enter a valid integer.")
50
+
51
+
52
+ def prompt_float(label: str, default: float) -> float:
53
+ while True:
54
+ value = prompt_text(label, str(default), required=True)
55
+ try:
56
+ return float(value)
57
+ except ValueError:
58
+ print("Enter a valid number.")
59
+
60
+
61
+ def prompt_yes_no(label: str, default: bool) -> bool:
62
+ default_text = "y" if default else "n"
63
+ while True:
64
+ value = prompt_text(f"{label} (y/n)", default_text, required=True).lower()
65
+ if value in {"y", "yes", "1", "true", "t"}:
66
+ return True
67
+ if value in {"n", "no", "0", "false", "f"}:
68
+ return False
69
+ print("Please answer y or n.")
70
+
71
+
72
+ def prompt_multiline(label: str, end_token: str = "END") -> str:
73
+ print(label)
74
+ print(f"Finish lyrics by typing {end_token} on its own line.")
75
+ lines: list[str] = []
76
+ while True:
77
+ line = input()
78
+ if line.strip() == end_token:
79
+ break
80
+ lines.append(line)
81
+ return "\n".join(lines).strip()
82
+
83
+
84
+ def prompt_lyrics_optional() -> str:
85
+ use_lyrics = prompt_yes_no("Add custom lyrics", True)
86
+ if not use_lyrics:
87
+ return ""
88
+ return prompt_multiline("Paste lyrics (or just type END for none)")
89
+
90
+
91
+ def send_request(url: str, token: str, payload: dict) -> dict:
92
+ data = json.dumps(payload).encode("utf-8")
93
+ req = Request(
94
+ url=url,
95
+ data=data,
96
+ method="POST",
97
+ headers={
98
+ "Authorization": f"Bearer {token}",
99
+ "Content-Type": "application/json",
100
+ },
101
+ )
102
+ try:
103
+ with urlopen(req, timeout=3600) as resp:
104
+ body = resp.read().decode("utf-8")
105
+ return json.loads(body)
106
+ except HTTPError as e:
107
+ text = e.read().decode("utf-8", errors="replace")
108
+ raise RuntimeError(f"HTTP {e.code}: {text}") from e
109
+ except URLError as e:
110
+ raise RuntimeError(f"Network error: {e}") from e
111
+
112
+
113
+ def resolve_token(cli_token: str) -> str:
114
+ if cli_token:
115
+ return cli_token
116
+ env_token = os.getenv("HF_TOKEN") or os.getenv("hf_token")
117
+ if env_token:
118
+ return env_token
119
+ dotenv_token = read_dotenv_value("hf_token") or read_dotenv_value("HF_TOKEN")
120
+ return dotenv_token
121
+
122
+
123
+ def main() -> int:
124
+ parser = argparse.ArgumentParser(description="Interactive ACE-Step endpoint generator")
125
+ parser.add_argument("--url", default=os.getenv("HF_ENDPOINT_URL", DEFAULT_URL), help="Inference endpoint URL")
126
+ parser.add_argument("--token", default="", help="HF token. If omitted, uses env/.env")
127
+ parser.add_argument("--prompt", default="", help="Initial prompt")
128
+ parser.add_argument("--out-file", default="", help="Output WAV file path")
129
+ parser.add_argument(
130
+ "--advanced",
131
+ action="store_true",
132
+ help="Ask advanced generation options (seed/guidance/steps/sample-rate/LM).",
133
+ )
134
+ args = parser.parse_args()
135
+
136
+ print("=== ACE-Step Interactive Generation ===")
137
+
138
+ token = resolve_token(args.token)
139
+ if not token:
140
+ print("No token found. Set HF_TOKEN or hf_token in .env, or pass --token.")
141
+ return 1
142
+
143
+ url = prompt_text("Endpoint URL", args.url, required=True)
144
+ music_prompt = prompt_text("Music prompt", args.prompt, required=True)
145
+ bpm = prompt_int("BPM (blank for auto)", None, allow_blank=True)
146
+ duration_sec = prompt_int("Duration seconds", 120)
147
+ instrumental = prompt_yes_no("Instrumental (no vocals)", False)
148
+ lyrics = "" if instrumental else prompt_lyrics_optional()
149
+
150
+ # Quality-first defaults: use SFT + LM path configured on the endpoint.
151
+ sample_rate = DEFAULT_SAMPLE_RATE
152
+ steps = 50
153
+ guidance_scale = 7.0
154
+ seed = 42
155
+ use_lm = True
156
+ allow_fallback = False
157
+ simple_prompt = False
158
+
159
+ if args.advanced:
160
+ print("\nAdvanced options:")
161
+ sample_rate = prompt_int("Sample rate", DEFAULT_SAMPLE_RATE)
162
+ steps = prompt_int("Steps", 50)
163
+ guidance_scale = prompt_float("Guidance scale", 7.0)
164
+ seed = prompt_int("Seed", 42)
165
+ use_lm = prompt_yes_no("Use LM planning (higher quality, slower)", True)
166
+ allow_fallback = prompt_yes_no("Allow fallback sine audio", False)
167
+
168
+ default_out = args.out_file or f"music_{int(time.time())}.wav"
169
+ out_file = prompt_text("Output file", default_out, required=True)
170
+
171
+ inputs = {
172
+ "prompt": music_prompt,
173
+ "duration_sec": duration_sec,
174
+ "sample_rate": sample_rate,
175
+ "seed": seed,
176
+ "guidance_scale": guidance_scale,
177
+ "steps": steps,
178
+ "use_lm": use_lm,
179
+ "simple_prompt": simple_prompt,
180
+ "instrumental": instrumental,
181
+ "allow_fallback": allow_fallback,
182
+ }
183
+ if bpm is not None:
184
+ inputs["bpm"] = bpm
185
+ if lyrics:
186
+ inputs["lyrics"] = lyrics
187
+
188
+ payload = {"inputs": inputs}
189
+
190
+ print("\nSending request...")
191
+ try:
192
+ response = send_request(url, token, payload)
193
+ except Exception as e:
194
+ print(f"Request failed: {e}")
195
+ return 1
196
+
197
+ print("Response summary:")
198
+ print(json.dumps({
199
+ "used_fallback": response.get("used_fallback"),
200
+ "model_loaded": response.get("model_loaded"),
201
+ "model_error": response.get("model_error"),
202
+ "sample_rate": response.get("sample_rate"),
203
+ "duration_sec": response.get("duration_sec"),
204
+ }, indent=2))
205
+
206
+ if response.get("error"):
207
+ print(f"Endpoint error: {response['error']}")
208
+ return 1
209
+
210
+ audio_b64 = response.get("audio_base64_wav")
211
+ if not audio_b64:
212
+ print("No audio_base64_wav in response.")
213
+ return 1
214
+
215
+ audio_bytes = base64.b64decode(audio_b64)
216
+ Path(out_file).write_bytes(audio_bytes)
217
+ print(f"Saved audio: {out_file}")
218
+
219
+ return 0
220
+
221
+
222
+ if __name__ == "__main__":
223
+ raise SystemExit(main())
scripts/endpoint/test.bat ADDED
@@ -0,0 +1,16 @@
1
+ @echo off
2
+ setlocal
3
+
4
+ set "PROMPT=%*"
5
+ if not defined PROMPT set "PROMPT=upbeat pop rap with emotional guitar"
6
+ set "PROMPT=%PROMPT:"=%"
7
+
8
+ powershell -NoProfile -ExecutionPolicy Bypass -File "%~dp0test.ps1" -Prompt "%PROMPT%" -SimplePrompt -DurationSec 12 -SampleRate 44100 -Seed 42 -GuidanceScale 7.0 -Steps 50 -UseLM 1 -OutFile "test_music.wav"
9
+
10
+ if errorlevel 1 (
11
+ echo Request failed.
12
+ exit /b 1
13
+ )
14
+
15
+ echo Done.
16
+ endlocal
scripts/endpoint/test.ps1 ADDED
@@ -0,0 +1,257 @@
1
+ param(
2
+ [string]$Token = "",
3
+ [string]$Url = "",
4
+ [string]$Prompt = "upbeat pop rap with emotional guitar",
5
+ [string]$Lyrics = "",
6
+ [int]$DurationSec = 3,
7
+ [int]$SampleRate = 44100,
8
+ [int]$Seed = 42,
9
+ [double]$GuidanceScale = 7.0,
10
+ [int]$Steps = 50,
11
+ [string]$UseLM = "true",
12
+ [switch]$SimplePrompt,
13
+ [switch]$Instrumental,
14
+ [switch]$RnbLoveTemptation,
15
+ [switch]$RnbPopRap2Min,
16
+ [switch]$AllowFallback,
17
+ [string]$OutFile = "test_music.wav"
18
+ )
19
+
20
+ $ErrorActionPreference = "Stop"
21
+
22
+ if (-not $Token) {
23
+ $Token = $env:HF_TOKEN
24
+ }
25
+ if (-not $Url) {
26
+ $Url = $env:HF_ENDPOINT_URL
27
+ }
28
+
29
+ if (-not $Token) {
30
+ throw "HF token not provided. Use -Token or set HF_TOKEN."
31
+ }
32
+ if (-not $Url) {
33
+ throw "Endpoint URL not provided. Use -Url or set HF_ENDPOINT_URL."
34
+ }
35
+
36
+ if ($RnbLoveTemptation.IsPresent) {
37
+ $Prompt = "melodic RnB pop with ambient melodies, evolving chord progression, emotional modern production, intimate vocals, soulful hooks"
38
+ $Lyrics = @"
39
+ [Verse 1]
40
+ Late night shadows on my skin,
41
+ I hear your name where the silence has been.
42
+ I swore I'd run when the fire got close,
43
+ But I keep chasing what hurts me the most.
44
+
45
+ [Pre-Chorus]
46
+ Every promise pulls me back again,
47
+ Sweet poison dressed like a loyal friend.
48
+
49
+ [Chorus]
50
+ I'm fighting love and temptation,
51
+ Heart in a war with my own salvation.
52
+ One touch and my defenses break,
53
+ I know it's danger but I stay awake.
54
+ I'm fighting love and temptation,
55
+ Drowning slow in this sweet devastation.
56
+ I want to leave but I hesitate,
57
+ Cause I still crave what I know I should hate.
58
+
59
+ [Verse 2]
60
+ Your voice is velvet over broken glass,
61
+ I learn the pain but I still relapse.
62
+ Truth on my lips, lies in my veins,
63
+ I pray for peace while I dance in flames.
64
+
65
+ [Bridge]
66
+ If love is a test, I'm failing with grace,
67
+ Still falling for fire, still calling your name.
68
+
69
+ [Final Chorus]
70
+ I'm fighting love and temptation,
71
+ Heart in a war with my own salvation.
72
+ One touch and my defenses break,
73
+ I know it's danger but I stay awake.
74
+ "@
75
+
76
+ if ($DurationSec -eq 3) {
77
+ $DurationSec = 24
78
+ }
79
+
80
+ if (-not $PSBoundParameters.ContainsKey("SimplePrompt")) {
81
+ $SimplePrompt = $false
82
+ }
83
+ }
84
+
85
+ if ($RnbPopRap2Min.IsPresent) {
86
+ $Prompt = "2 minute RnB Pop Rap song, melodic ambient pads, emotional chord progression, intimate female and male vocal blend, catchy hooks, modern drums, deep 808, vulnerable but confident tone"
87
+ $Lyrics = @"
88
+ [Intro]
89
+ Mm, midnight in my head, no sleep.
90
+ Same old war in my chest, still deep.
91
+ I say I'm done, then I call your name,
92
+ I run from the fire, then I walk in the flame.
93
+
94
+ [Verse 1]
95
+ Streetlights drip on the window pane,
96
+ I wear my pride like a silver chain.
97
+ Say I don't need you, that's what I say,
98
+ But your voice in my mind don't fade away.
99
+ I got dreams, got scars, got bills to pay,
100
+ Still I fold when your eyes pull me in that way.
101
+ I know better, I swear I do,
102
+ But temptation sounds like truth when it sounds like you.
103
+
104
+ [Pre-Chorus]
105
+ Every promise tastes sweet then turns to smoke,
106
+ I keep rebuilding hearts that we already broke.
107
+ I want peace, but I want your touch,
108
+ I know it's too much, still it's never enough.
109
+
110
+ [Chorus]
111
+ I'm fighting love and temptation,
112
+ Heart on trial with no salvation.
113
+ One more kiss and the walls cave in,
114
+ I lose myself just to feel again.
115
+ I'm fighting love and temptation,
116
+ Drowning slow in sweet devastation.
117
+ I say goodbye, then I hesitate,
118
+ Cause I still crave what I know I should hate.
119
+
120
+ [Rap Verse 1]
121
+ Look, I been in and out the same lane,
122
+ Different night, same rain.
123
+ Tell myself "don't text back,"
124
+ Still type your name, press send, same pain.
125
+ You the high and the low in one dose,
126
+ Got me praying for distance, still close.
127
+ I play tough, but the truth is loud,
128
+ When you're gone, all this noise in the crowd.
129
+ Yeah, I hustle, I grind, I glow,
130
+ But alone in the dark, I'm a different soul.
131
+ If love was logic, I'd be free by now,
132
+ But my heart ain't science, I just bleed it out.
133
+
134
+ [Verse 2]
135
+ Your perfume still lives in my hoodie seams,
136
+ Like a ghost in the corners of all my dreams.
137
+ I learned your chaos, your every disguise,
138
+ The saint in your smile, the storm in your eyes.
139
+ I touch your hand and forget my name,
140
+ Call it desire, call it blame.
141
+ I need healing, I need release,
142
+ But your lips keep turning my war to peace.
143
+
144
+ [Pre-Chorus]
145
+ Every promise tastes sweet then turns to smoke,
146
+ I keep rebuilding hearts that we already broke.
147
+ I want peace, but I want your touch,
148
+ I know it's too much, still it's never enough.
149
+
150
+ [Chorus]
151
+ I'm fighting love and temptation,
152
+ Heart on trial with no salvation.
153
+ One more kiss and the walls cave in,
154
+ I lose myself just to feel again.
155
+ I'm fighting love and temptation,
156
+ Drowning slow in sweet devastation.
157
+ I say goodbye, then I hesitate,
158
+ Cause I still crave what I know I should hate.
159
+
160
+ [Rap Verse 2]
161
+ Uh, late calls, no sleep, red eyes,
162
+ Truth hurts more than sweet lies.
163
+ We toxic, but the chemistry loud,
164
+ Like thunder in a summer night over this town.
165
+ Tell me leave, then you pull me near,
166
+ Tell me "trust me," then feed my fear.
167
+ I keep faith in a broken map,
168
+ Tryna find us on roads that don't lead back.
169
+ I got plans, got goals, got pride,
170
+ But temptation got hands on the wheel tonight.
171
+ If I fall, let me fall with grace,
172
+ I still see home when I look in your face.
173
+
174
+ [Bridge]
175
+ If this love is a test, I'm failing in style,
176
+ Smiling through fire for one more while.
177
+ I know I should run, I know I should wait,
178
+ But your name on my tongue sounds too much like fate.
179
+
180
+ [Final Chorus]
181
+ I'm fighting love and temptation,
182
+ Heart on trial with no salvation.
183
+ One more kiss and the walls cave in,
184
+ I lose myself just to feel again.
185
+ I'm fighting love and temptation,
186
+ Drowning slow in sweet devastation.
187
+ I say goodbye, then I hesitate,
188
+ Cause I still crave what I know I should hate.
189
+
190
+ [Outro]
191
+ Mm, midnight in my head, no sleep.
192
+ Still your name in my chest, too deep.
193
+ "@
194
+
195
+ if ($DurationSec -eq 3) {
196
+ $DurationSec = 120
197
+ }
198
+
199
+ if (-not $PSBoundParameters.ContainsKey("SimplePrompt")) {
200
+ $SimplePrompt = $false
201
+ }
202
+ }
203
+
204
+ $useLmBool = $true
205
+ if ($null -ne $UseLM -and $UseLM -ne "") {
206
+ try {
207
+ $useLmBool = [System.Convert]::ToBoolean($UseLM)
208
+ }
209
+ catch {
210
+ $useLmBool = ($UseLM -match '^(1|true|t|yes|y|on)$')
211
+ }
212
+ }
213
+
214
+ $inputs = @{
215
+ prompt = $Prompt
216
+ duration_sec = $DurationSec
217
+ sample_rate = $SampleRate
218
+ seed = $Seed
219
+ guidance_scale = $GuidanceScale
220
+ steps = $Steps
221
+ use_lm = $useLmBool
222
+ allow_fallback = $AllowFallback.IsPresent
223
+ }
224
+
225
+ if ($Lyrics) {
226
+ $inputs["lyrics"] = $Lyrics
227
+ }
228
+ if ($SimplePrompt.IsPresent) {
229
+ $inputs["simple_prompt"] = $true
230
+ }
231
+ if ($Instrumental.IsPresent) {
232
+ $inputs["instrumental"] = $true
233
+ }
234
+
235
+ $body = @{ inputs = $inputs } | ConvertTo-Json -Depth 8
236
+
237
+ $response = Invoke-RestMethod -Method Post -Uri $Url -Headers @{
238
+ Authorization = "Bearer $Token"
239
+ "Content-Type" = "application/json"
240
+ } -Body $body
241
+
242
+ $response | ConvertTo-Json -Depth 6
243
+
244
+ if ($response.error) {
245
+ throw "Endpoint returned error: $($response.error)"
246
+ }
247
+
248
+ if ($response.used_fallback -and -not $AllowFallback.IsPresent) {
249
+ throw "Endpoint used fallback audio. Set -AllowFallback only if you want fallback behavior."
250
+ }
251
+
252
+ if (-not $response.audio_base64_wav) {
253
+ throw "No audio_base64_wav returned."
254
+ }
255
+
256
+ [IO.File]::WriteAllBytes($OutFile, [Convert]::FromBase64String($response.audio_base64_wav))
257
+ Write-Host "Saved audio file: $OutFile"
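`test.ps1` accepts `-UseLM` as a string and coerces it to a boolean with a truthy-string fallback. The same rule, sketched in Python for reuse in the CLI scripts; the helper is hypothetical:

```python
def parse_bool(value, default: bool = True) -> bool:
    """Coerce '1'/'true'/'yes'/'on' (any case) to True; blank/None -> default."""
    if value is None or str(value).strip() == "":
        return default
    return str(value).strip().lower() in {"1", "true", "t", "yes", "y", "on"}
```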
scripts/endpoint/test_rnb.bat ADDED
@@ -0,0 +1,12 @@
1
+ @echo off
2
+ setlocal
3
+
4
+ powershell -NoProfile -ExecutionPolicy Bypass -File "%~dp0test.ps1" -RnbLoveTemptation -DurationSec 24 -SampleRate 44100 -Seed 42 -GuidanceScale 7.0 -Steps 8 -UseLM 1 -OutFile "test_rnb_music.wav"
5
+
6
+ if errorlevel 1 (
7
+ echo Request failed.
8
+ exit /b 1
9
+ )
10
+
11
+ echo Done.
12
+ endlocal
scripts/endpoint/test_rnb_2min.bat ADDED
@@ -0,0 +1,12 @@
1
+ @echo off
2
+ setlocal
3
+
4
+ powershell -NoProfile -ExecutionPolicy Bypass -File "%~dp0test.ps1" -RnbPopRap2Min -DurationSec 120 -SampleRate 44100 -Seed 42 -GuidanceScale 7.0 -Steps 8 -UseLM 1 -OutFile "test_rnb_pop_rap_2min.wav"
5
+
6
+ if errorlevel 1 (
7
+ echo Request failed.
8
+ exit /b 1
9
+ )
10
+
11
+ echo Done.
12
+ endlocal
scripts/hf_clone.py ADDED
@@ -0,0 +1,321 @@
1
+ #!/usr/bin/env python
2
+ """
3
+ Bootstrap this project into your own Hugging Face Space and/or Endpoint repo.
4
+
5
+ Examples:
6
+ python scripts/hf_clone.py space --repo-id your-name/ace-step-lora-studio
7
+ python scripts/hf_clone.py endpoint --repo-id your-name/ace-step-endpoint
8
+ python scripts/hf_clone.py all --space-repo-id your-name/ace-step-lora-studio --endpoint-repo-id your-name/ace-step-endpoint
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import argparse
14
+ import os
15
+ import shutil
16
+ import tempfile
17
+ from pathlib import Path
18
+ from typing import Iterable
19
+
20
+ from huggingface_hub import HfApi
21
+
22
+
23
+ PROJECT_ROOT = Path(__file__).resolve().parents[1]
24
+
25
+ COMMON_SKIP_DIRS = {
26
+ ".git",
27
+ "__pycache__",
28
+ ".pytest_cache",
29
+ ".mypy_cache",
30
+ ".ruff_cache",
31
+ ".venv",
32
+ "venv",
33
+ "env",
34
+ ".idea",
35
+ ".vscode",
36
+ ".cache",
37
+ ".huggingface",
38
+ ".gradio",
39
+ "checkpoints",
40
+ "lora_output",
41
+ "outputs",
42
+ "artifacts",
43
+ "models",
44
+ "datasets",
45
+ "Lora-ace-step",
46
+ }
47
+
48
+ COMMON_SKIP_FILES = {
49
+ ".env",
50
+ }
51
+
52
+ COMMON_SKIP_PREFIXES = (
53
+ "song_summaries_llm",
54
+ )
55
+
56
+ COMMON_SKIP_SUFFIXES = {
57
+ ".wav",
58
+ ".flac",
59
+ ".mp3",
60
+ ".ogg",
61
+ ".opus",
62
+ ".m4a",
63
+ ".aac",
64
+ ".pt",
65
+ ".bin",
66
+ ".safetensors",
67
+ ".ckpt",
68
+ ".onnx",
69
+ ".log",
70
+ ".pyc",
71
+ ".pyo",
72
+ ".pyd",
73
+ }
74
+
75
+ MAX_FILE_BYTES = 30 * 1024 * 1024 # 30MB safety cap for upload snapshot
76
+
77
+
78
+ def _should_skip_common(rel_path: Path, is_dir: bool) -> bool:
79
+ if any(part in COMMON_SKIP_DIRS for part in rel_path.parts):
80
+ return True
81
+ if rel_path.name in COMMON_SKIP_FILES:
82
+ return True
83
+ if any(rel_path.name.startswith(prefix) for prefix in COMMON_SKIP_PREFIXES):
84
+ return True
85
+ if not is_dir and rel_path.suffix.lower() in COMMON_SKIP_SUFFIXES:
86
+ return True
87
+ return False
88
+
89
+
90
+ def _copy_file(src: Path, dst: Path) -> None:
91
+ dst.parent.mkdir(parents=True, exist_ok=True)
92
+ shutil.copy2(src, dst)
93
+
94
+
95
+ def _stage_space_snapshot(staging_dir: Path) -> tuple[int, int, list[str]]:
96
+ copied = 0
97
+ bytes_total = 0
98
+ skipped: list[str] = []
99
+
100
+ for src in PROJECT_ROOT.rglob("*"):
101
+ rel = src.relative_to(PROJECT_ROOT)
102
+
103
+ if src.is_dir():
104
+ if _should_skip_common(rel, is_dir=True):
105
+ skipped.append(f"{rel}/")
106
+ continue
107
+
108
+ if _should_skip_common(rel, is_dir=False):
109
+ skipped.append(str(rel))
110
+ continue
111
+
112
+ size = src.stat().st_size
113
+ if size > MAX_FILE_BYTES:
114
+ skipped.append(f"{rel} (>{MAX_FILE_BYTES // (1024 * 1024)}MB)")
115
+ continue
116
+
117
+ dst = staging_dir / rel
118
+ _copy_file(src, dst)
119
+ copied += 1
120
+ bytes_total += size
121
+
122
+ return copied, bytes_total, skipped
123
+
124
+
125
+ def _iter_endpoint_paths() -> Iterable[Path]:
126
+ # Minimal runtime set for custom endpoint repos.
127
+ required = [
128
+ PROJECT_ROOT / "handler.py",
129
+ PROJECT_ROOT / "requirements.txt",
130
+ PROJECT_ROOT / "packages.txt",
131
+ PROJECT_ROOT / "acestep",
132
+ ]
133
+ for p in required:
134
+ if p.exists():
135
+ yield p
136
+
137
+ template_readme = PROJECT_ROOT / "templates" / "hf-endpoint" / "README.md"
138
+ if template_readme.exists():
139
+ yield template_readme
140
+
141
+
142
+ def _stage_endpoint_snapshot(staging_dir: Path) -> tuple[int, int]:
143
+ copied = 0
144
+ bytes_total = 0
145
+
146
+ for src in _iter_endpoint_paths():
147
+ if src.is_file():
148
+ rel_dst = Path("README.md") if src.name == "README.md" and "templates" in src.parts else Path(src.name)
149
+ dst = staging_dir / rel_dst
150
+ _copy_file(src, dst)
151
+ copied += 1
152
+ bytes_total += src.stat().st_size
153
+ continue
154
+
155
+ if src.is_dir():
156
+ for nested in src.rglob("*"):
157
+ rel_nested = nested.relative_to(src)
158
+ if nested.is_dir():
159
+ if _should_skip_common(Path(src.name) / rel_nested, is_dir=True):
160
+ continue
161
+ continue
162
+ if _should_skip_common(Path(src.name) / rel_nested, is_dir=False):
163
+ continue
164
+ if nested.suffix.lower() in {".wav", ".flac", ".mp3", ".ogg"}:
165
+ continue
166
+
167
+ dst = staging_dir / src.name / rel_nested
168
+ _copy_file(nested, dst)
169
+ copied += 1
170
+ bytes_total += nested.stat().st_size
171
+
172
+ return copied, bytes_total
173
+
174
+
175
+ def _resolve_token(arg_token: str) -> str | None:
176
+ if arg_token:
177
+ return arg_token
178
+ return os.getenv("HF_TOKEN")
179
+
180
+
181
+ def _ensure_repo(
182
+ api: HfApi,
183
+ repo_id: str,
184
+ repo_type: str,
185
+ private: bool,
186
+ space_sdk: str | None = None,
187
+ ) -> None:
188
+ kwargs = {
189
+ "repo_id": repo_id,
190
+ "repo_type": repo_type,
191
+ "private": private,
192
+ "exist_ok": True,
193
+ }
194
+ if repo_type == "space" and space_sdk:
195
+ kwargs["space_sdk"] = space_sdk
196
+ api.create_repo(**kwargs)
197
+
198
+
199
+ def _upload_snapshot(
200
+ api: HfApi,
201
+ repo_id: str,
202
+ repo_type: str,
203
+ folder_path: Path,
204
+ commit_message: str,
205
+ ) -> None:
206
+ api.upload_folder(
207
+ repo_id=repo_id,
208
+ repo_type=repo_type,
209
+ folder_path=str(folder_path),
210
+ commit_message=commit_message,
211
+ delete_patterns=[],
212
+ )
213
+
214
+
215
+ def _fmt_mb(num_bytes: int) -> str:
216
+ return f"{num_bytes / (1024 * 1024):.2f} MB"
217
+
218
+
219
+ def clone_space(repo_id: str, private: bool, token: str | None, dry_run: bool) -> None:
220
+ with tempfile.TemporaryDirectory(prefix="hf_space_clone_") as tmp:
221
+ staging = Path(tmp)
222
+ copied, bytes_total, skipped = _stage_space_snapshot(staging)
223
+ print(f"[space] staged files: {copied}, size: {_fmt_mb(bytes_total)}")
224
+ if skipped:
225
+ print(f"[space] skipped entries: {len(skipped)}")
226
+ for item in skipped[:20]:
227
+ print(f" - {item}")
228
+ if len(skipped) > 20:
229
+ print(f" ... and {len(skipped) - 20} more")
230
+
231
+ if dry_run:
232
+ print("[space] dry-run complete (nothing uploaded).")
233
+ return
234
+
235
+ api = HfApi(token=token)
236
+ _ensure_repo(api, repo_id=repo_id, repo_type="space", private=private, space_sdk="gradio")
237
+ _upload_snapshot(
238
+ api,
239
+ repo_id=repo_id,
240
+ repo_type="space",
241
+ folder_path=staging,
242
+ commit_message="Bootstrap ACE-Step LoRA Studio Space",
243
+ )
244
+ print(f"[space] uploaded to https://huggingface.co/spaces/{repo_id}")
245
+
246
+
247
+ def clone_endpoint(repo_id: str, private: bool, token: str | None, dry_run: bool) -> None:
248
+ with tempfile.TemporaryDirectory(prefix="hf_endpoint_clone_") as tmp:
249
+ staging = Path(tmp)
250
+ copied, bytes_total = _stage_endpoint_snapshot(staging)
251
+ print(f"[endpoint] staged files: {copied}, size: {_fmt_mb(bytes_total)}")
252
+
253
+ if dry_run:
254
+ print("[endpoint] dry-run complete (nothing uploaded).")
255
+ return
256
+
257
+ api = HfApi(token=token)
258
+ _ensure_repo(api, repo_id=repo_id, repo_type="model", private=private)
259
+ _upload_snapshot(
260
+ api,
261
+ repo_id=repo_id,
262
+ repo_type="model",
263
+ folder_path=staging,
264
+ commit_message="Bootstrap ACE-Step custom endpoint repo",
265
+ )
266
+ print(f"[endpoint] uploaded to https://huggingface.co/{repo_id}")
267
+
268
+
269
+ def build_parser() -> argparse.ArgumentParser:
270
+ parser = argparse.ArgumentParser(description="Clone this project into your own HF Space/Endpoint repos.")
271
+ subparsers = parser.add_subparsers(dest="cmd", required=True)
272
+
273
+ p_space = subparsers.add_parser("space", help="Create/update your HF Space from this project.")
274
+ p_space.add_argument("--repo-id", required=True, help="Target space repo id, e.g. username/my-space.")
275
+ p_space.add_argument("--private", action="store_true", help="Create repo as private.")
276
+ p_space.add_argument("--token", type=str, default="", help="HF token (default: HF_TOKEN env var).")
277
+ p_space.add_argument("--dry-run", action="store_true", help="Stage files only; do not upload.")
278
+
279
+ p_endpoint = subparsers.add_parser("endpoint", help="Create/update your custom endpoint model repo.")
280
+ p_endpoint.add_argument("--repo-id", required=True, help="Target model repo id, e.g. username/my-endpoint.")
281
+ p_endpoint.add_argument("--private", action="store_true", help="Create repo as private.")
282
+ p_endpoint.add_argument("--token", type=str, default="", help="HF token (default: HF_TOKEN env var).")
283
+ p_endpoint.add_argument("--dry-run", action="store_true", help="Stage files only; do not upload.")
284
+
285
+ p_all = subparsers.add_parser("all", help="Run both Space and Endpoint bootstrap.")
286
+ p_all.add_argument("--space-repo-id", required=True, help="Target space repo id.")
287
+ p_all.add_argument("--endpoint-repo-id", required=True, help="Target endpoint model repo id.")
288
+ p_all.add_argument("--space-private", action="store_true", help="Create Space as private.")
289
+ p_all.add_argument("--endpoint-private", action="store_true", help="Create endpoint repo as private.")
290
+ p_all.add_argument("--token", type=str, default="", help="HF token (default: HF_TOKEN env var).")
291
+ p_all.add_argument("--dry-run", action="store_true", help="Stage files only; do not upload.")
292
+
293
+ return parser
294
+
295
+
296
+ def main() -> int:
297
+ args = build_parser().parse_args()
298
+ token = _resolve_token(args.token)
299
+
300
+ if not token and not args.dry_run:
301
+ print("HF token not found. Set HF_TOKEN or pass --token.")
302
+ return 1
303
+
304
+ if args.cmd == "space":
305
+ clone_space(args.repo_id, private=bool(args.private), token=token, dry_run=bool(args.dry_run))
306
+ elif args.cmd == "endpoint":
307
+ clone_endpoint(args.repo_id, private=bool(args.private), token=token, dry_run=bool(args.dry_run))
308
+ else:
309
+ clone_space(args.space_repo_id, private=bool(args.space_private), token=token, dry_run=bool(args.dry_run))
310
+ clone_endpoint(
311
+ args.endpoint_repo_id,
312
+ private=bool(args.endpoint_private),
313
+ token=token,
314
+ dry_run=bool(args.dry_run),
315
+ )
316
+
317
+ return 0
318
+
319
+
320
+ if __name__ == "__main__":
321
+ raise SystemExit(main())
scripts/jobs/submit_hf_lora_job.ps1 ADDED
@@ -0,0 +1,85 @@
+ param(
+     [string]$CodeRepo = "YOUR_USERNAME/ace-step-lora-studio",
+     [string]$DatasetRepo = "",
+     [string]$DatasetRevision = "main",
+     [string]$DatasetSubdir = "",
+     [string]$ModelConfig = "acestep-v15-base",
+     [string]$Flavor = "a10g-large",
+     [string]$Timeout = "8h",
+     [int]$Epochs = 20,
+     [int]$BatchSize = 1,
+     [int]$GradAccum = 1,
+     [string]$OutputDir = "/workspace/output",
+     [string]$UploadRepo = "",
+     [switch]$UploadPrivate,
+     [switch]$Detach
+ )
+
+ $ErrorActionPreference = "Stop"
+
+ if (-not $DatasetRepo) {
+     throw "Provide -DatasetRepo (HF dataset repo containing your audio + optional sidecars)."
+ }
+
+ $secretArgs = @("--secrets", "HF_TOKEN")
+
+ $uploadArgs = ""
+ if ($UploadRepo) {
+     $uploadArgs = "--upload-repo `"$UploadRepo`""
+     if ($UploadPrivate.IsPresent) {
+         $uploadArgs += " --upload-private"
+     }
+ }
+
+ $datasetSubdirArgs = ""
+ if ($DatasetSubdir) {
+     $datasetSubdirArgs = "--dataset-subdir `"$DatasetSubdir`""
+ }
+
+ $detachArg = ""
+ if ($Detach.IsPresent) {
+     $detachArg = "--detach"
+ }
+
+ $jobCommand = @"
+ set -e
+ python -m pip install --no-cache-dir --upgrade pip
+ git clone https://huggingface.co/$CodeRepo /workspace/code
+ cd /workspace/code
+ python -m pip install --no-cache-dir -r requirements.txt
+ python lora_train.py \
+   --dataset-repo "$DatasetRepo" \
+   --dataset-revision "$DatasetRevision" \
+   $datasetSubdirArgs \
+   --model-config "$ModelConfig" \
+   --device auto \
+   --num-epochs $Epochs \
+   --batch-size $BatchSize \
+   --grad-accum $GradAccum \
+   --output-dir "$OutputDir" \
+   $uploadArgs
+ "@
+
+ $argsList = @(
+     "jobs", "run",
+     "--flavor", $Flavor,
+     "--timeout", $Timeout
+ ) + $secretArgs
+
+ if ($detachArg) {
+     $argsList += $detachArg
+ }
+
+ $argsList += @(
+     "pytorch/pytorch:2.5.1-cuda12.1-cudnn9-runtime",
+     "bash", "-lc", $jobCommand
+ )
+
+ Write-Host "Submitting HF Job with flavor=$Flavor timeout=$Timeout ..."
+ Write-Host "Dataset repo: $DatasetRepo"
+ Write-Host "Code repo: $CodeRepo"
+ if ($UploadRepo) {
+     Write-Host "Will upload final adapter to: $UploadRepo"
+ }
+
+ & hf @argsList
summaries/findings.md ADDED
@@ -0,0 +1,68 @@
+ # Improving ACE-Step LoRA with Time-Event-Based Annotation
+
+ [Back to project README](../README.md)
+
+ ## Baseline context in this repo
+
+ This project already provides a solid end-to-end workflow:
+
+ - Train LoRA adapters with `lora_train.py` and the Gradio UI (`app.py`, `lora_ui.py`).
+ - Deploy generation through a custom endpoint runtime (`handler.py`, `acestep/`).
+ - Test prompts and lyrics quickly with the endpoint client scripts in `scripts/endpoint/`.
+
+ Today, most conditioning in this pipeline is still global (caption, lyrics, BPM, key, tags). That is a strong baseline, but it does not explicitly teach the model *when* events happen inside a track.
+
+ ## Core limitation
+
+ Current annotations usually describe *what* a song is, not *when* events occur. The model can learn style and texture, but temporal structure is weaker:
+
+ - Verse/chorus transitions are often less deliberate than in human-produced songs.
+ - Build-ups, drops, and effect changes can feel averaged or blurred.
+ - Subgenre-specific arrangement timing is harder to reproduce consistently.
+
+ ## Why time-event labels are promising
+
+ 1. Better musical structure: teach the model where sections start and end and where key transitions occur.
+ 2. Better genre fidelity: encode timing differences between styles that share similar instruments.
+ 3. Better control at inference: allow prompting for both content and structure (what + when).
+
+ ## Practical direction for this codebase
+
+ A useful next step is to extend the current sidecar metadata approach with optional timed events:
+
+ - Keep existing fields (`caption`, `lyrics`, `bpm`, etc.).
+ - Add an `events` list with an event type plus start/end times.
+ - Start with a small, high-quality subset before scaling.
+
+ Illustrative shape:
+
+ ```json
+ {
+   "caption": "emotional rnb pop with warm pads",
+   "bpm": 92,
+   "events": [
+     {"type": "intro", "start": 0.0, "end": 8.0},
+     {"type": "verse", "start": 8.0, "end": 32.0},
+     {"type": "chorus", "start": 32.0, "end": 48.0}
+   ]
+ }
+ ```
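Before training on such sidecars, a quick validation pass helps keep the timed labels trustworthy. A minimal sketch, assuming the illustrative schema above; `validate_events` is a hypothetical helper, not part of the current pipeline:

```python
def validate_events(events, duration_sec=None):
    """Check that timed events are well-formed: non-negative times,
    start < end, sorted by start, and non-overlapping."""
    prev_end = 0.0
    for ev in events:
        start, end = float(ev["start"]), float(ev["end"])
        if start < 0 or end <= start:
            return False
        if start < prev_end:  # overlaps or precedes the previous event
            return False
        if duration_sec is not None and end > duration_sec:
            return False
        prev_end = end
    return True

events = [
    {"type": "intro", "start": 0.0, "end": 8.0},
    {"type": "verse", "start": 8.0, "end": 32.0},
    {"type": "chorus", "start": 32.0, "end": 48.0},
]
print(validate_events(events))  # True for the example above
```

Rejecting malformed sidecars early (rather than letting them average out during training) fits the "small, high-quality subset first" approach.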
+
+ ## Early experiments worth running
+
+ - Compare baseline LoRA vs. time-event LoRA on the same curated mini-dataset.
+ - Score structural accuracy (section order, transition timing tolerance).
+ - Run blind listening tests for perceived musical arc and arrangement coherence.
+ - Track whether time labels improve consistency without reducing creativity.
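One concrete way to score transition timing is boundary recall within a tolerance window. This is a hedged sketch, not an established metric from this repo; the tolerance value and boundary lists are illustrative:

```python
def transition_recall(pred_bounds, ref_bounds, tol=1.0):
    """Fraction of reference section boundaries (in seconds) that have
    a predicted boundary within +/- tol seconds."""
    if not ref_bounds:
        return 1.0
    hits = sum(
        1 for r in ref_bounds
        if any(abs(r - p) <= tol for p in pred_bounds)
    )
    return hits / len(ref_bounds)

ref = [8.0, 32.0, 48.0]   # e.g. intro->verse, verse->chorus, chorus end
pred = [7.4, 31.2, 51.0]  # boundaries detected in a generated track
print(transition_recall(pred, ref))  # 2 of 3 boundaries land within 1 s
```

Comparing this score between the baseline and time-event LoRA on the same mini-dataset gives a cheap structural signal before committing to listening tests.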
+
+ ## Expected outcomes
+
+ If this works, this repo can evolve from "style-conditioned generation" toward "structure-aware generation":
+
+ - More intentional song progression.
+ - Stronger subgenre identity.
+ - Better controllability for creators.
+
+ This is still a baseline research note, but it gives a clear technical direction that fits the current project architecture.
templates/hf-endpoint/README.md ADDED
@@ -0,0 +1,38 @@
+ # ACE-Step Custom Endpoint Repo
+
+ This repo is intended for a Hugging Face **Dedicated Inference Endpoint** with a custom `handler.py`.
+
+ ## Contents
+
+ - `handler.py`: Endpoint request/response logic.
+ - `acestep/`: Core inference utilities.
+ - `requirements.txt`: Python dependencies.
+ - `packages.txt`: System dependencies.
+
+ ## Expected Request Payload
+
+ ```json
+ {
+   "inputs": {
+     "prompt": "upbeat pop rap with emotional guitar",
+     "lyrics": "[Verse] city lights and midnight rain",
+     "duration_sec": 12,
+     "sample_rate": 44100,
+     "seed": 42,
+     "guidance_scale": 7.0,
+     "steps": 50,
+     "use_lm": true
+   }
+ }
+ ```
+
+ ## Quick Setup
+
+ 1. Create a model repo on Hugging Face.
+ 2. Push the contents of this folder to that repo.
+ 3. Create a new dedicated endpoint from the custom repo.
+ 4. Set environment variables on the endpoint as needed:
+    - `ACE_CONFIG_PATH` (default `acestep-v15-sft`)
+    - `ACE_LM_MODEL_PATH` (default `acestep-5Hz-lm-4B`)
+    - `ACE_DOWNLOAD_SOURCE` (`huggingface` or `modelscope`)
+ 5. Scale down or pause the endpoint when idle to control cost.
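A minimal client sketch for exercising the payload documented above, assuming the `HF_ENDPOINT_URL` and `HF_TOKEN` variables from `.env.example`; the response format depends on `handler.py` and is not shown here:

```python
import json
import os
import urllib.request

def build_payload(prompt, lyrics, duration_sec=12):
    # Mirrors the "Expected Request Payload" shape above.
    return {
        "inputs": {
            "prompt": prompt,
            "lyrics": lyrics,
            "duration_sec": duration_sec,
            "sample_rate": 44100,
            "seed": 42,
            "guidance_scale": 7.0,
            "steps": 50,
            "use_lm": True,
        }
    }

payload = build_payload(
    "upbeat pop rap with emotional guitar",
    "[Verse] city lights and midnight rain",
)

url = os.getenv("HF_ENDPOINT_URL")
if url:  # only send when an endpoint is configured
    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Authorization": f"Bearer {os.environ['HF_TOKEN']}",
            "Content-Type": "application/json",
        },
    )
    with urllib.request.urlopen(req, timeout=600) as resp:
        body = resp.read()  # handler.py defines the response body
```

Generation can take a while at 50 steps, so a generous client timeout is safer than the default.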