Andrew committed
Commit: bd37cca
Parent(s): a3ab20b
github push
Changed files:
- .env.example +5 -0
- .gitignore +64 -2
- CONTRIBUTING.md +25 -0
- LICENSE +21 -0
- README.md +147 -0
- acestep/handler.py +41 -2
- acestep/llm_inference.py +1 -1
- app.py +7 -0
- docs/ACE-Step-1.5-LoRA-HF-Consolidated.md +131 -0
- docs/deploy/ENDPOINT.md +80 -0
- docs/deploy/SPACE.md +40 -0
- docs/guides/README.md +5 -0
- docs/guides/qwen2-audio-train.md +0 -0
- handler.py +44 -6
- lora_train.py +1056 -0
- lora_ui.py +973 -0
- packages.txt +2 -0
- requirements.txt +4 -0
- scripts/endpoint/generate_interactive.py +223 -0
- scripts/endpoint/test.bat +16 -0
- scripts/endpoint/test.ps1 +257 -0
- scripts/endpoint/test_rnb.bat +12 -0
- scripts/endpoint/test_rnb_2min.bat +12 -0
- scripts/hf_clone.py +321 -0
- scripts/jobs/submit_hf_lora_job.ps1 +85 -0
- summaries/findings.md +68 -0
- templates/hf-endpoint/README.md +38 -0
.env.example ADDED

@@ -0,0 +1,5 @@
+HF_TOKEN=hf_xxx_your_token_here
+HF_ENDPOINT_URL=https://your-endpoint-url.endpoints.huggingface.cloud
+
+# Optional defaults used by scripts/hf_clone.py
+HF_USERNAME=your-hf-username
.gitignore CHANGED

@@ -1,4 +1,66 @@
+# Environment
 .env
-
-
+.env.*
+!.env.example
+
+# Python cache/build artifacts
+__pycache__/
+*.py[cod]
+*.pyo
+*.pyd
+.pytest_cache/
+.mypy_cache/
+.ruff_cache/
+.ipynb_checkpoints/
+.coverage
+htmlcov/
+build/
+dist/
+*.egg-info/
+
+# Virtual environments
+.venv/
+venv/
+env/
+
+# Tool and local runtime caches
+.cache/
+.huggingface/
+.gradio/
+
+# Logs/temp
+*.log
+*.tmp
+*.temp
+*.bak
+
+# Model/data/runtime artifacts
+checkpoints/
+lora_output/
+outputs/
+artifacts/
+models/
+datasets/
+/data/
 *.wav
+*.flac
+*.mp3
+*.ogg
+*.opus
+*.m4a
+*.aac
+*.pt
+*.bin
+*.safetensors
+*.ckpt
+*.onnx
+
+# OS/editor
+.DS_Store
+Thumbs.db
+.idea/
+.vscode/
+
+# Optional local working copies
+Lora-ace-step/
+song_summaries_llm*.md
CONTRIBUTING.md ADDED

@@ -0,0 +1,25 @@
+# Contributing
+
+## Development Setup
+
+```bash
+python -m pip install --upgrade pip
+python -m pip install -r requirements.txt
+python app.py
+```
+
+## Before Opening A PR
+
+1. Keep secrets out of git (`HF_TOKEN`, endpoint URLs, `.env`).
+2. Do not commit local artifacts (`checkpoints/`, `lora_output/`, generated audio).
+3. Run quick CLI sanity checks:
+   - `python lora_train.py --help`
+   - `python scripts/hf_clone.py --help`
+   - `python scripts/endpoint/generate_interactive.py --help`
+4. Update docs (`README.md`, `docs/deploy/*`) if behavior or workflows changed.
+
+## Scope Guidelines
+
+- UI + training workflow changes belong in `lora_ui.py` / `lora_train.py`.
+- Inference endpoint changes belong in `handler.py`.
+- Shared ACE-Step runtime logic belongs in `acestep/`.
LICENSE ADDED

@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2026 ACE-Step LoRA Studio contributors
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md CHANGED

@@ -0,0 +1,147 @@
+---
+title: ACE-Step 1.5 LoRA Studio
+emoji: music
+colorFrom: blue
+colorTo: teal
+sdk: gradio
+app_file: app.py
+pinned: false
+---
+
+# ACE-Step 1.5 LoRA Studio
+
+Train ACE-Step 1.5 LoRA adapters, deploy your own Hugging Face Space, and run production-style inference through a Dedicated Endpoint.
+
+[](https://huggingface.co/new-space)
+[](https://huggingface.co/new-model)
+[](LICENSE)
+
+## What you get
+
+- LoRA training UI and workflow: `app.py`, `lora_ui.py`
+- CLI LoRA trainer for local/HF datasets: `lora_train.py`
+- Custom endpoint runtime: `handler.py`, `acestep/`
+- Bootstrap automation for cloning into your HF account: `scripts/hf_clone.py`
+- Endpoint test clients and HF job launcher: `scripts/endpoint/`, `scripts/jobs/`
+
+## Quick start (local)
+
+```bash
+python -m pip install --upgrade pip
+python -m pip install -r requirements.txt
+python app.py
+```
+
+Open `http://localhost:7860`.
+
+## Clone to your HF account
+
+Use the two buttons near the top of this README to create target repos in your HF account, then run the commands below.
+
+Set token once:
+
+```bash
+# Linux/macOS
+export HF_TOKEN=hf_xxx
+
+# Windows PowerShell
+$env:HF_TOKEN="hf_xxx"
+```
+
+Clone your own Space:
+
+```bash
+python scripts/hf_clone.py space --repo-id YOUR_USERNAME/YOUR_SPACE_NAME
+```
+
+Clone your own Endpoint repo:
+
+```bash
+python scripts/hf_clone.py endpoint --repo-id YOUR_USERNAME/YOUR_ENDPOINT_REPO
+```
+
+Clone both in one run:
+
+```bash
+python scripts/hf_clone.py all \
+  --space-repo-id YOUR_USERNAME/YOUR_SPACE_NAME \
+  --endpoint-repo-id YOUR_USERNAME/YOUR_ENDPOINT_REPO
+```
+
+## Project layout
+
+```text
+.
+|- app.py
+|- lora_ui.py
+|- lora_train.py
+|- handler.py
+|- acestep/
+|- scripts/
+|  |- hf_clone.py
+|  |- endpoint/
+|  |  |- generate_interactive.py
+|  |  |- test.ps1
+|  |  |- test.bat
+|  |  |- test_rnb.bat
+|  |  `- test_rnb_2min.bat
+|  `- jobs/
+|     `- submit_hf_lora_job.ps1
+|- docs/
+|  |- deploy/
+|  `- guides/
+|- summaries/
+|  `- findings.md
+`- templates/hf-endpoint/
+```
+
+## Dataset format
+
+Supported audio:
+
+- `.wav`, `.flac`, `.mp3`, `.ogg`, `.opus`, `.m4a`, `.aac`
+
+Optional sidecar metadata per track:
+
+- `song_001.wav`
+- `song_001.json`
+
+```json
+{
+  "caption": "melodic emotional rnb pop with warm pads",
+  "lyrics": "[Verse]\\n...",
+  "bpm": 92,
+  "keyscale": "Am",
+  "timesignature": "4/4",
+  "vocal_language": "en",
+  "duration": 120
+}
+```
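The basename pairing described above (each audio file plus an optional same-name `.json` sidecar) can be sketched as follows. This is a minimal illustration, not the repo's actual loader; `scan_dataset` is a hypothetical helper name:

```python
import json
from pathlib import Path

# Mirrors the "Supported audio" list above.
AUDIO_EXTS = {".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aac"}

def scan_dataset(folder):
    """Pair each audio file with an optional same-basename .json sidecar."""
    tracks = []
    for audio_path in sorted(Path(folder).iterdir()):
        if audio_path.suffix.lower() not in AUDIO_EXTS:
            continue
        sidecar = audio_path.with_suffix(".json")
        # Sidecar metadata is optional; missing files just mean no extra fields.
        meta = json.loads(sidecar.read_text(encoding="utf-8")) if sidecar.exists() else {}
        tracks.append({"audio": str(audio_path), **meta})
    return tracks
```

A track without a sidecar simply carries only its `audio` path; one with a sidecar gains whatever keys the JSON provides (`caption`, `bpm`, etc.).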
+
+## Endpoint testing
+
+```bash
+python scripts/endpoint/generate_interactive.py
+```
+
+Or run scripted tests:
+
+- `scripts/endpoint/test.ps1`
+- `scripts/endpoint/test.bat`
+
+## Findings and notes
+
+Current baseline analysis and improvement ideas are tracked in `summaries/findings.md`.
+
+## Docs
+
+- Space deployment: `docs/deploy/SPACE.md`
+- Endpoint deployment: `docs/deploy/ENDPOINT.md`
+- Additional guides: `docs/guides/qwen2-audio-train.md`
+
+## Open-source readiness checklist
+
+- Secrets are env-driven (`HF_TOKEN`, `HF_ENDPOINT_URL`, `.env`).
+- Local artifacts are ignored via `.gitignore`.
+- MIT license included.
+- Reproducible clone/deploy paths documented.
acestep/handler.py CHANGED

@@ -24,6 +24,7 @@ from typing import Optional, Dict, Any, Tuple, List, Union
 import torch
 import torchaudio
 import soundfile as sf
+import numpy as np
 import time
 from tqdm import tqdm
 from loguru import logger

@@ -1655,7 +1656,7 @@ class AceStepHandler:
         try:
             # Load audio file
-            audio, sr =
+            audio, sr = self._load_audio_any_backend(audio_file)

             logger.debug(f"[process_reference_audio] Reference audio shape: {audio.shape}")
             logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}")

@@ -1710,7 +1711,7 @@ class AceStepHandler:
         try:
             # Load audio file
-            audio, sr =
+            audio, sr = self._load_audio_any_backend(audio_file)

             # Normalize to stereo 48kHz
             audio = self._normalize_audio_to_stereo_48k(audio, sr)

@@ -1720,6 +1721,44 @@ class AceStepHandler:
         except Exception as e:
             logger.exception("[process_src_audio] Error processing source audio")
             return None
+
+    def _load_audio_any_backend(self, audio_file):
+        """Load audio with torchaudio first, then soundfile fallback."""
+        def _coerce_audio_tensor(audio_obj):
+            if isinstance(audio_obj, list):
+                audio_obj = np.asarray(audio_obj, dtype=np.float32)
+            if isinstance(audio_obj, np.ndarray):
+                audio_obj = torch.from_numpy(audio_obj)
+            if not torch.is_tensor(audio_obj):
+                raise TypeError(f"Unsupported audio type: {type(audio_obj)}")
+
+            if not torch.is_floating_point(audio_obj):
+                audio_obj = audio_obj.float()
+
+            # Normalize to [C, T]
+            if audio_obj.dim() == 1:
+                audio_obj = audio_obj.unsqueeze(0)
+            elif audio_obj.dim() == 2:
+                if audio_obj.shape[0] > audio_obj.shape[1] and audio_obj.shape[1] <= 8:
+                    audio_obj = audio_obj.transpose(0, 1)
+            elif audio_obj.dim() == 3:
+                audio_obj = audio_obj[0]
+            else:
+                raise ValueError(f"Unexpected audio dims: {tuple(audio_obj.shape)}")
+            return audio_obj.contiguous()
+
+        try:
+            audio, sr = torchaudio.load(audio_file)
+            return _coerce_audio_tensor(audio), sr
+        except Exception as torchaudio_exc:
+            try:
+                audio_np, sr = sf.read(audio_file, dtype="float32", always_2d=True)
+                return _coerce_audio_tensor(audio_np.T), sr
+            except Exception as sf_exc:
+                raise RuntimeError(
+                    f"Audio decode failed for '{audio_file}' with torchaudio ({torchaudio_exc}) "
+                    f"and soundfile ({sf_exc})."
+                ) from sf_exc

     def convert_src_audio_to_codes(self, audio_file) -> str:
         """
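The shape-coercion rule in `_load_audio_any_backend` (decoded buffers normalized to a float `[channels, samples]` layout, transposing soundfile's `[T, C]` output when the row count exceeds a plausible channel count) can be exercised without torch. A numpy-only sketch of that heuristic, assuming at most 8 channels; `to_channels_first` is an illustrative name, not a function in the repo:

```python
import numpy as np

def to_channels_first(audio: np.ndarray, max_channels: int = 8) -> np.ndarray:
    """Normalize decoded audio to float32 [C, T], mirroring the handler's heuristic."""
    audio = np.asarray(audio, dtype=np.float32)
    if audio.ndim == 1:
        audio = audio[np.newaxis, :]  # mono -> [1, T]
    elif audio.ndim == 2:
        # soundfile returns [T, C]; transpose when rows outnumber a plausible channel count
        if audio.shape[0] > audio.shape[1] and audio.shape[1] <= max_channels:
            audio = audio.T
    else:
        raise ValueError(f"Unexpected audio dims: {audio.shape}")
    return np.ascontiguousarray(audio)
```

Note the heuristic is ambiguous for very short multichannel clips (e.g. a `[4, 6]` buffer), which is why the handler caps the channel guess at 8.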
acestep/llm_inference.py CHANGED

@@ -457,7 +457,7 @@ class LLMHandler:
         # If lm_model_path is None, use default
         if lm_model_path is None:
-            lm_model_path = "acestep-5Hz-lm-
+            lm_model_path = "acestep-5Hz-lm-4B"
             logger.info(f"[initialize] lm_model_path is None, using default: {lm_model_path}")

         full_lm_model_path = os.path.join(checkpoint_dir, lm_model_path)
app.py CHANGED

@@ -1,5 +1,12 @@
 import os

+# On Hugging Face Spaces Zero, `spaces` must be imported before CUDA-related modules.
+if os.getenv("SPACE_ID"):
+    try:
+        import spaces  # noqa: F401
+    except Exception:
+        pass
+
 from lora_ui import build_ui

 app = build_ui()
docs/ACE-Step-1.5-LoRA-HF-Consolidated.md ADDED

@@ -0,0 +1,131 @@
+# ACE-Step 1.5 LoRA Pipeline (Simple + HF Spaces)
+
+Last updated: 2026-02-12
+
+## 1. What is already implemented in this repo
+- Drag/drop dataset loading and folder scan.
+- Optional per-track sidecar JSON (`song.wav` + `song.json`).
+- New **Auto-Label All** option in `lora_ui.py`:
+  - Uses ACE audio understanding (`audio -> semantic codes -> caption/lyrics/metadata`).
+  - Writes/updates sidecar JSON for each track.
+- LoRA training with ACE flow-matching defaults and adapter checkpoints.
+- Training log now shows device plus elapsed time and ETA.
+
+## 2. Direct answers to your core questions
+
+### Is LoRA using HF GPU?
+Yes. If the Space hardware is GPU and the model device is `auto`/`cuda`, training runs on that Space GPU.
+
+### Do we get time estimates?
+Yes. The training status now shows elapsed time and ETA in the log.
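The elapsed/ETA figure can be derived from step progress alone, as in this minimal sketch (`format_eta` is an illustrative name, not the repo's actual logging helper):

```python
import time
from typing import Optional

def format_eta(start_time: float, current_step: int, total_steps: int,
               now: Optional[float] = None) -> str:
    """Return an 'elapsed / ETA' string from step progress."""
    now = time.time() if now is None else now
    elapsed = now - start_time
    if current_step <= 0:
        # No steps completed yet: per-step cost is unknown.
        return f"elapsed {elapsed:.0f}s | ETA unknown"
    per_step = elapsed / current_step
    remaining = per_step * (total_steps - current_step)
    return f"elapsed {elapsed:.0f}s | ETA {remaining:.0f}s"
```

Halfway through a run the ETA equals the elapsed time, which is a quick sanity check on any implementation of this formula.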
+### How are metadata and lyrics paired per song?
+By basename in the same folder:
+- `track01.wav`
+- `track01.json`
+
+### Do you need all metadata?
+No. In this pipeline, metadata is optional.
+- Required minimum: audio files.
+- Strongly recommended: `caption` and/or `lyrics` for better conditioning quality.
+- Optional but helpful: `bpm`, `keyscale`, `timesignature`, `vocal_language`, `duration`.
+
+### Where are trained adapters saved?
+- Local run: `lora_output/...` by default.
+- HF Space run: `/data/lora_output/...` by default (as configured in UI code).
+- Final adapter checkpoint is saved under a `final` subfolder.
+
+### Cloud GPU + local files?
+- Training on Spaces uses cloud GPU and writes artifacts to the Space filesystem.
+- To keep results outside the Space, download them or upload to a Hub model repo.
+
+### Can HF Endpoint GPU train this?
+Not the right product. Inference Endpoints are for model serving/inference; use Spaces (interactive) or Jobs (batch) for training.
+
+## 3. Minimal dataset format
+
+Drop files into one folder:
+
+```text
+dataset_inbox/
+  song_a.wav
+  song_b.flac
+  song_c.mp3
+```
+
+Optional sidecar for tighter control:
+
+```text
+dataset_inbox/
+  song_a.wav
+  song_a.json
+```
+
+Example `song_a.json`:
+
+```json
+{
+  "caption": "emotional indie pop with airy female vocal and warm pads",
+  "lyrics": "[Verse]\n...",
+  "bpm": 96,
+  "keyscale": "Am",
+  "timesignature": "4/4",
+  "vocal_language": "en",
+  "duration": 120
+}
+```
+
+## 4. Super simple training flow (UI)
+1. Start UI:
+   - Local: `python app.py`
+   - Space: app starts automatically from `app.py`.
+2. Step 1 tab: initialize `acestep-v15-base` (best LoRA plasticity).
+3. Step 2 tab: scan folder or drag/drop files.
+4. Optional: initialize auto-label LM and click **Auto-Label All**.
+5. Step 3 tab: keep defaults for first run, click **Start Training**.
+6. Click **Refresh Log** to monitor status/loss/ETA.
+7. Step 4 tab: load adapter from output folder and A/B test against base.
+
+## 5. HF Spaces setup (step by step)
+1. Create a new Hugging Face **Space** with SDK = `Gradio`.
+2. Push this repo to that Space repo.
+3. Ensure Space metadata/front matter includes:
+   - `sdk: gradio`
+   - `app_file: app.py`
+4. In Space `Settings -> Hardware`, select a GPU tier.
+5. In Space `Settings -> Variables and secrets`, add any needed tokens as secrets (never hardcode).
+6. Open the Space and run the 4-step UI flow.
+
+## 6. GPU association and cost control
+
+### Pick hardware for your stage
+- Fast/cheap iteration: start with T4 or A10G.
+- Heavier runs or bigger LM usage: A100/L40S/H100 class.
+
+### Keep spend under control
+1. Use a smaller auto-label LM (`0.6B`) unless you need higher quality labels.
+2. Train with `acestep-v15-base` only for final-quality runs; iterate on turbo variants if needed.
+3. Pause or downgrade hardware immediately when idle.
+4. Export/upload adapters right after training so you can shut hardware down.
+
+### Current billing behavior to remember
+HF Spaces docs indicate upgraded hardware is billed by the minute while the Space is running; pause or stop upgraded hardware when not in use.
+
+## 7. Suggested first-run defaults
+- Model: `acestep-v15-base`
+- LoRA rank/alpha/dropout: `64 / 64 / 0.1`
+- Optimizer: `adamw_8bit`
+- LR: `1e-4`
+- Warmup: `0.03`
+- Scheduler: `constant_with_warmup`
+- Shift: `3.0`
+- Max grad norm: `1.0`
+
+## 8. Source links (official)
+- ACE-Step Gradio guide: https://huggingface.co/spaces/ACE-Step/Ace-Step-v1.5/blob/main/docs/en/GRADIO_GUIDE.md
+- ACE-Step README: https://huggingface.co/spaces/ACE-Step/Ace-Step-v1.5/blob/main/README.md
+- ACE-Step LoRA model card note (DiT-only LoRA): https://huggingface.co/ACE-Step/Ace-Step-v1.5-lo-ra-new-year
+- HF Spaces overview: https://huggingface.co/docs/hub/en/spaces-overview
+- HF Spaces GPU/hardware docs: https://huggingface.co/docs/hub/en/spaces-gpus
+- HF Spaces config reference: https://huggingface.co/docs/hub/en/spaces-config-reference
+- HF Inference Endpoints overview: https://huggingface.co/docs/inference-endpoints/en/index
docs/deploy/ENDPOINT.md ADDED

@@ -0,0 +1,80 @@
+# Deploy Inference To Your Own HF Dedicated Endpoint
+
+This guide deploys the custom `handler.py` inference runtime to a Hugging Face Dedicated Inference Endpoint.
+
+## Prerequisites
+
+- Hugging Face account
+- `HF_TOKEN` with repo write access
+- Dedicated Endpoint access on your HF plan
+
+## 1) Create/Update Your Endpoint Repo
+
+```bash
+python scripts/hf_clone.py endpoint --repo-id YOUR_USERNAME/YOUR_ENDPOINT_REPO
+```
+
+This uploads:
+
+- `handler.py`
+- `acestep/`
+- `requirements.txt`
+- `packages.txt`
+- endpoint-specific README template
+
+## 2) Create Endpoint In HF UI
+
+1. Go to **Inference Endpoints** -> **New endpoint**.
+2. Select your custom model repo: `YOUR_USERNAME/YOUR_ENDPOINT_REPO`.
+3. Choose GPU hardware.
+4. Deploy.
+
+## 3) Recommended Endpoint Environment Variables
+
+- `ACE_CONFIG_PATH` (default: `acestep-v15-sft`)
+- `ACE_LM_MODEL_PATH` (default: `acestep-5Hz-lm-4B`)
+- `ACE_LM_BACKEND` (default: `pt`)
+- `ACE_DOWNLOAD_SOURCE` (`huggingface` or `modelscope`)
+- `ACE_ENABLE_FALLBACK` (`false` recommended for strict failure visibility)
+
+## 4) Test The Endpoint
+
+Set credentials:
+
+```bash
+# Linux/macOS
+export HF_TOKEN=hf_xxx
+export HF_ENDPOINT_URL=https://your-endpoint-url.endpoints.huggingface.cloud
+
+# Windows PowerShell
+$env:HF_TOKEN="hf_xxx"
+$env:HF_ENDPOINT_URL="https://your-endpoint-url.endpoints.huggingface.cloud"
+```
+
+Test with:
+
+- `python scripts/endpoint/generate_interactive.py`
+- `scripts/endpoint/test.ps1`
+
+## Request Contract
+
+```json
+{
+  "inputs": {
+    "prompt": "upbeat pop rap with emotional guitar",
+    "lyrics": "[Verse] city lights and midnight rain",
+    "duration_sec": 12,
+    "sample_rate": 44100,
+    "seed": 42,
+    "guidance_scale": 7.0,
+    "steps": 50,
+    "use_lm": true
+  }
+}
+```
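A request carrying the contract above can be assembled with the documented defaults and posted with only the standard library. This is a sketch, not the repo's test client; the URL and token come from the `HF_TOKEN`/`HF_ENDPOINT_URL` env vars set earlier, and `build_payload`/`post_generate` are illustrative names:

```python
import json
import os
import urllib.request

def build_payload(prompt: str, lyrics: str = "", **overrides) -> dict:
    """Assemble the endpoint request body, starting from the documented defaults."""
    inputs = {
        "prompt": prompt,
        "lyrics": lyrics,
        "duration_sec": 12,
        "sample_rate": 44100,
        "seed": 42,
        "guidance_scale": 7.0,
        "steps": 50,
        "use_lm": True,
    }
    inputs.update(overrides)  # caller-supplied fields win over defaults
    return {"inputs": inputs}

def post_generate(payload: dict) -> bytes:
    """POST the payload to the endpoint with bearer-token auth."""
    req = urllib.request.Request(
        os.environ["HF_ENDPOINT_URL"],
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Authorization": f"Bearer {os.environ['HF_TOKEN']}",
            "Content-Type": "application/json",
        },
    )
    with urllib.request.urlopen(req) as resp:
        return resp.read()
```

For example, `post_generate(build_payload("upbeat pop rap", steps=25))` sends the contract above with only `steps` overridden.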
+
+## Cost Control
+
+- Use scale-to-zero for idle periods.
+- Pause endpoint for immediate spend stop.
+- Expect cold starts when scaled to zero.
docs/deploy/SPACE.md ADDED

@@ -0,0 +1,40 @@
+# Deploy LoRA Studio To Your Own HF Space
+
+This guide deploys the full LoRA Studio UI to your own Hugging Face Space.
+
+## Prerequisites
+
+- Hugging Face account
+- `HF_TOKEN` with repo write access
+- Python environment with `requirements.txt` installed
+
+## Fast Path (Recommended)
+
+```bash
+python scripts/hf_clone.py space --repo-id YOUR_USERNAME/YOUR_SPACE_NAME
+```
+
+Optional private Space:
+
+```bash
+python scripts/hf_clone.py space --repo-id YOUR_USERNAME/YOUR_SPACE_NAME --private
+```
+
+## Manual Path
+
+1. Create a new Space on Hugging Face:
+   - SDK: `Gradio`
+2. Push this repo content (excluding local artifacts) to that Space repo.
+3. Ensure README front matter has:
+   - `sdk: gradio`
+   - `app_file: app.py`
+4. In Space settings:
+   - select GPU hardware (A10G/A100/etc.) if needed
+   - add secrets (`HF_TOKEN`) if your flow requires private Hub access
+
+## Runtime Notes
+
+- Space output defaults to `/data/lora_output` on Hugging Face Spaces.
+- Enable persistent storage if you need checkpoint retention across restarts.
+- For long-running non-interactive training, HF Jobs may be more cost-efficient than keeping a Space running.
docs/guides/README.md ADDED

@@ -0,0 +1,5 @@
+# Guides
+
+Additional step-by-step guides that are useful but not required for the core LoRA Studio flow.
+
+- `qwen2-audio-train.md`
docs/guides/qwen2-audio-train.md ADDED

File without changes
handler.py
CHANGED
|
@@ -3,6 +3,7 @@ import base64
|
|
| 3 |
import io
|
| 4 |
import os
|
| 5 |
import traceback
|
|
|
|
| 6 |
from typing import Any, Dict, Optional, Tuple
|
| 7 |
|
| 8 |
import numpy as np
|
|
@@ -27,7 +28,7 @@ class EndpointHandler:
|
|
| 27 |
"sample_rate": 44100,
|
| 28 |
"seed": 42,
|
| 29 |
"guidance_scale": 7.0,
|
| 30 |
-
"steps":
|
| 31 |
"use_lm": true,
|
| 32 |
"simple_prompt": false,
|
| 33 |
"instrumental": false,
|
|
@@ -50,8 +51,10 @@ class EndpointHandler:
|
|
| 50 |
self.project_root = os.path.dirname(os.path.abspath(__file__))
|
| 51 |
|
| 52 |
self.model_repo = os.getenv("ACE_MODEL_REPO", "ACE-Step/Ace-Step1.5")
|
| 53 |
-
|
| 54 |
-
|
|
|
|
|
|
|
| 55 |
self.lm_backend = os.getenv("ACE_LM_BACKEND", "pt")
|
| 56 |
self.download_source = os.getenv("ACE_DOWNLOAD_SOURCE", "huggingface")
|
| 57 |
|
|
@@ -233,6 +236,31 @@ class EndpointHandler:
|
|
| 233 |
|
| 234 |
try:
|
| 235 |
checkpoint_dir = os.path.join(self.project_root, "checkpoints")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
status, ok = self.llm_handler.initialize(
|
| 237 |
checkpoint_dir=checkpoint_dir,
|
| 238 |
lm_model_path=self.lm_model_path,
|
|
@@ -352,8 +380,14 @@ class EndpointHandler:
|
|
| 352 |
|
| 353 |
seed = self._to_int(raw_inputs.get("seed", 42), 42)
|
| 354 |
guidance_scale = self._to_float(raw_inputs.get("guidance_scale", 7.0), 7.0)
|
| 355 |
-
steps = self._to_int(raw_inputs.get("steps", raw_inputs.get("inference_steps",
|
| 356 |
steps = max(1, min(steps, 200))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
use_lm = self._to_bool(raw_inputs.get("use_lm", raw_inputs.get("thinking", True)), True)
|
| 358 |
allow_fallback = self._to_bool(raw_inputs.get("allow_fallback"), self.enable_fallback)
|
| 359 |
|
|
@@ -365,6 +399,7 @@ class EndpointHandler:
|
|
| 365 |
"seed": seed,
|
| 366 |
"guidance_scale": guidance_scale,
|
| 367 |
"steps": steps,
|
|
|
|
| 368 |
"use_lm": use_lm,
|
| 369 |
"instrumental": instrumental,
|
| 370 |
"simple_prompt": simple_prompt,
|
|
@@ -383,7 +418,7 @@ class EndpointHandler:
|
|
| 383 |
"simple_expansion_error": None,
|
| 384 |
}
|
| 385 |
|
| 386 |
-
bpm =
|
| 387 |
keyscale = ""
|
| 388 |
timesignature = ""
|
| 389 |
vocal_language = "unknown"
|
|
@@ -399,7 +434,9 @@ class EndpointHandler:
|
|
| 399 |
if getattr(sample, "success", False):
|
| 400 |
caption = getattr(sample, "caption", "") or caption
|
| 401 |
lyrics = getattr(sample, "lyrics", "") or lyrics
|
| 402 |
-
|
|
|
|
|
|
|
| 403 |
keyscale = getattr(sample, "keyscale", "") or ""
|
| 404 |
timesignature = getattr(sample, "timesignature", "") or ""
|
| 405 |
vocal_language = getattr(sample, "language", "") or "unknown"
|
|
@@ -526,6 +563,7 @@
             "seed": req["seed"],
             "guidance_scale": req["guidance_scale"],
             "steps": req["steps"],
             "use_lm": req["use_lm"],
             "simple_prompt": req["simple_prompt"],
             "instrumental": req["instrumental"],
@@ +3,7 @@
 import io
 import os
 import traceback
+from pathlib import Path
 from typing import Any, Dict, Optional, Tuple
 
 import numpy as np

@@ +28,7 @@
     "sample_rate": 44100,
     "seed": 42,
     "guidance_scale": 7.0,
+    "steps": 50,
     "use_lm": true,
     "simple_prompt": false,
     "instrumental": false,

@@ +51,9 @@
         self.project_root = os.path.dirname(os.path.abspath(__file__))
 
         self.model_repo = os.getenv("ACE_MODEL_REPO", "ACE-Step/Ace-Step1.5")
+        # Default to the larger quality-oriented setup.
+        # Override via ACE_CONFIG_PATH / ACE_LM_MODEL_PATH when needed.
+        self.config_path = os.getenv("ACE_CONFIG_PATH", "acestep-v15-sft")
+        self.lm_model_path = os.getenv("ACE_LM_MODEL_PATH", "acestep-5Hz-lm-4B")
         self.lm_backend = os.getenv("ACE_LM_BACKEND", "pt")
         self.download_source = os.getenv("ACE_DOWNLOAD_SOURCE", "huggingface")

@@ +236,29 @@
 
         try:
             checkpoint_dir = os.path.join(self.project_root, "checkpoints")
+            full_lm_model_path = os.path.join(checkpoint_dir, self.lm_model_path)
+            if not os.path.exists(full_lm_model_path):
+                try:
+                    from acestep.model_downloader import ensure_lm_model, ensure_main_model
+                except Exception as e:
+                    self.llm_error = f"LM download helper import failed: {type(e).__name__}: {e}"
+                    return False
+
+                # 1.7B ships with main; 0.6B/4B are standalone submodels.
+                if self.lm_model_path == "acestep-5Hz-lm-1.7B":
+                    dl_ok, dl_msg = ensure_main_model(
+                        checkpoints_dir=Path(checkpoint_dir),
+                        prefer_source=self.download_source,
+                    )
+                else:
+                    dl_ok, dl_msg = ensure_lm_model(
+                        model_name=self.lm_model_path,
+                        checkpoints_dir=Path(checkpoint_dir),
+                        prefer_source=self.download_source,
+                    )
+                self.init_details["llm_download"] = dl_msg
+                if not dl_ok:
+                    self.llm_error = f"LM download failed: {dl_msg}"
+                    return False
+
             status, ok = self.llm_handler.initialize(
                 checkpoint_dir=checkpoint_dir,
                 lm_model_path=self.lm_model_path,

@@ +380,14 @@
 
         seed = self._to_int(raw_inputs.get("seed", 42), 42)
         guidance_scale = self._to_float(raw_inputs.get("guidance_scale", 7.0), 7.0)
+        steps = self._to_int(raw_inputs.get("steps", raw_inputs.get("inference_steps", 50)), 50)
         steps = max(1, min(steps, 200))
+        bpm_raw = raw_inputs.get("bpm")
+        bpm = None
+        if bpm_raw is not None and str(bpm_raw).strip() != "":
+            bpm = self._to_int(bpm_raw, 0)
+            if bpm <= 0:
+                bpm = None
         use_lm = self._to_bool(raw_inputs.get("use_lm", raw_inputs.get("thinking", True)), True)
         allow_fallback = self._to_bool(raw_inputs.get("allow_fallback"), self.enable_fallback)

@@ +399,7 @@
             "seed": seed,
             "guidance_scale": guidance_scale,
             "steps": steps,
+            "bpm": bpm,
             "use_lm": use_lm,
             "instrumental": instrumental,
             "simple_prompt": simple_prompt,

@@ +418,7 @@
             "simple_expansion_error": None,
         }
 
+        bpm = req.get("bpm")
         keyscale = ""
         timesignature = ""
         vocal_language = "unknown"

@@ +434,9 @@
         if getattr(sample, "success", False):
             caption = getattr(sample, "caption", "") or caption
             lyrics = getattr(sample, "lyrics", "") or lyrics
+            sample_bpm = getattr(sample, "bpm", None)
+            if bpm is None:
+                bpm = sample_bpm
             keyscale = getattr(sample, "keyscale", "") or ""
             timesignature = getattr(sample, "timesignature", "") or ""
             vocal_language = getattr(sample, "language", "") or "unknown"

@@ +563,7 @@
             "seed": req["seed"],
             "guidance_scale": req["guidance_scale"],
             "steps": req["steps"],
+            "bpm": req.get("bpm"),
             "use_lm": req["use_lm"],
             "simple_prompt": req["simple_prompt"],
             "instrumental": req["instrumental"],
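The new bpm plumbing accepts strings, numbers, or blanks from the request payload and normalizes anything non-positive to `None`. The coercion rule can be sketched standalone (a minimal sketch; the real handler routes through its `_to_int` helper rather than this hypothetical `parse_bpm`):

```python
from typing import Any, Optional

def parse_bpm(raw: Any) -> Optional[int]:
    """Coerce a raw request value to a positive int BPM, else None."""
    if raw is None or str(raw).strip() == "":
        return None
    try:
        bpm = int(float(raw))
    except (TypeError, ValueError):
        return None
    return bpm if bpm > 0 else None

print(parse_bpm("128"))   # 128
print(parse_bpm(""))      # None
print(parse_bpm(-4))      # None
```

Downstream, the same key is echoed back in the response (`"bpm": req.get("bpm")`), so clients can tell whether their value was kept, replaced by the LM's suggestion, or dropped.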
lora_train.py ADDED
@@ -0,0 +1,1056 @@
"""
ACE-Step 1.5 LoRA Training Engine

Handles dataset building, VAE encoding, and flow-matching LoRA training
of the DiT decoder. Designed to work with the existing AceStepHandler.
"""

import os
import sys
import json
import math
import time
import random
import hashlib
import argparse
import tempfile
from pathlib import Path
from dataclasses import dataclass, field, asdict
from typing import Optional, List, Dict, Any, Tuple

import torch
import torch.nn.functional as F
import torchaudio
import soundfile as sf
import numpy as np
from loguru import logger
from tqdm import tqdm

# ---------------------------------------------------------------------------
# Dataset helpers
# ---------------------------------------------------------------------------

AUDIO_EXTENSIONS = {".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aac"}


@dataclass
class TrackEntry:
    """One audio file + its metadata."""

    audio_path: str
    caption: str = ""
    lyrics: str = ""
    bpm: Optional[int] = None
    keyscale: str = ""
    timesignature: str = "4/4"
    vocal_language: str = "en"
    duration: Optional[float] = None  # seconds (measured at scan time)


def _load_track_entry(audio_path: Path) -> TrackEntry:
    """Load one track + optional sidecar metadata."""
    sidecar = audio_path.with_suffix(".json")
    meta: Dict[str, Any] = {}
    if sidecar.exists():
        try:
            meta = json.loads(sidecar.read_text(encoding="utf-8"))
        except Exception as exc:
            logger.warning(f"Bad sidecar {sidecar}: {exc}")

    try:
        info = torchaudio.info(str(audio_path))
        duration = info.num_frames / info.sample_rate
    except Exception:
        duration = meta.get("duration")

    return TrackEntry(
        audio_path=str(audio_path),
        caption=meta.get("caption", ""),
        lyrics=meta.get("lyrics", ""),
        bpm=meta.get("bpm"),
        keyscale=meta.get("keyscale", ""),
        timesignature=meta.get("timesignature", "4/4"),
        vocal_language=meta.get("vocal_language", "en"),
        duration=duration,
    )


def scan_dataset_folder(folder: str) -> List[TrackEntry]:
    """Scan *folder* for audio files and optional JSON sidecars.

    For every ``track.wav`` found, if ``track.json`` exists next to it the
    metadata fields are loaded from the sidecar. Missing sidecars are fine –
    the entry will have empty metadata that can be filled later.
    """
    folder = Path(folder)
    if not folder.is_dir():
        raise FileNotFoundError(f"Dataset folder not found: {folder}")

    entries: List[TrackEntry] = []
    for audio_path in sorted(folder.rglob("*")):
        if audio_path.suffix.lower() not in AUDIO_EXTENSIONS:
            continue
        entries.append(_load_track_entry(audio_path))

    logger.info(f"Scanned {len(entries)} audio files in {folder}")
    return entries
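A sidecar consumed by `_load_track_entry` is plain JSON carrying the `TrackEntry` fields, all optional. A minimal round-trip sketch (file and field values are hypothetical examples):

```python
import json
import tempfile
from pathlib import Path

# Hypothetical sidecar that would sit next to "song.wav";
# keys mirror what _load_track_entry reads.
sidecar = {
    "caption": "warm lo-fi hip hop with vinyl crackle",
    "lyrics": "[instrumental]",
    "bpm": 84,
    "keyscale": "C minor",
    "timesignature": "4/4",
    "vocal_language": "en",
}

with tempfile.TemporaryDirectory() as d:
    path = Path(d) / "song.json"
    path.write_text(json.dumps(sidecar, indent=2), encoding="utf-8")
    # This is the same read path the scanner uses.
    meta = json.loads(path.read_text(encoding="utf-8"))

print(meta["bpm"], meta["keyscale"])  # 84 C minor
```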
def scan_uploaded_files(file_paths: List[str]) -> List[TrackEntry]:
    """Build entries from dropped/uploaded files.

    Supports uploading audio files together with optional ``.json`` sidecars.
    Sidecars are matched by basename stem (``song.mp3`` <-> ``song.json``).
    """
    meta_by_stem: Dict[str, Dict[str, Any]] = {}
    for path in file_paths:
        p = Path(path)
        if not p.exists() or p.suffix.lower() != ".json":
            continue
        try:
            meta_by_stem[p.stem] = json.loads(p.read_text(encoding="utf-8"))
        except Exception as exc:
            logger.warning(f"Bad uploaded sidecar {p}: {exc}")

    entries: List[TrackEntry] = []
    for path in file_paths:
        p = Path(path)
        if not p.exists() or p.suffix.lower() not in AUDIO_EXTENSIONS:
            continue

        uploaded_meta = meta_by_stem.get(p.stem)
        if uploaded_meta is None:
            entries.append(_load_track_entry(p))
            continue

        try:
            info = torchaudio.info(str(p))
            duration = info.num_frames / info.sample_rate
        except Exception:
            duration = uploaded_meta.get("duration")

        bpm_val = uploaded_meta.get("bpm")
        if isinstance(bpm_val, str) and bpm_val.strip():
            try:
                bpm_val = int(float(bpm_val))
            except Exception:
                bpm_val = None

        entries.append(
            TrackEntry(
                audio_path=str(p),
                caption=uploaded_meta.get("caption", "") or "",
                lyrics=uploaded_meta.get("lyrics", "") or "",
                bpm=bpm_val if isinstance(bpm_val, int) else None,
                keyscale=uploaded_meta.get("keyscale", "") or "",
                timesignature=uploaded_meta.get("timesignature", "4/4") or "4/4",
                vocal_language=uploaded_meta.get("vocal_language", uploaded_meta.get("language", "en")) or "en",
                duration=duration,
            )
        )

    logger.info(
        "Loaded {} uploaded audio files ({} uploaded sidecars detected)".format(
            len(entries), len(meta_by_stem)
        )
    )
    return entries
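The stem-matching rule can be sketched without any real files (a simplification: the real `scan_uploaded_files` also checks that paths exist and parses each sidecar's contents):

```python
from pathlib import Path

AUDIO_EXTENSIONS = {".wav", ".flac", ".mp3", ".ogg", ".opus", ".m4a", ".aac"}

def pair_uploads(paths):
    """Pair audio paths with same-stem .json sidecars, mirroring scan_uploaded_files."""
    sidecar_stems = {Path(p).stem for p in paths if Path(p).suffix.lower() == ".json"}
    return [
        (p, Path(p).stem in sidecar_stems)   # (audio path, has sidecar?)
        for p in paths
        if Path(p).suffix.lower() in AUDIO_EXTENSIONS
    ]

print(pair_uploads(["a.mp3", "a.json", "b.wav"]))
# [('a.mp3', True), ('b.wav', False)]
```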

# ---------------------------------------------------------------------------
# Training hyper-parameters
# ---------------------------------------------------------------------------


@dataclass
class LoRATrainConfig:
    """All tuneable knobs for a LoRA run."""

    # LoRA architecture
    lora_rank: int = 64
    lora_alpha: int = 64
    lora_dropout: float = 0.1
    lora_target_modules: List[str] = field(
        default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"]
    )

    # Optimiser
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    optimizer: str = "adamw_8bit"  # "adamw" | "adamw_8bit"
    max_grad_norm: float = 1.0

    # Schedule
    warmup_ratio: float = 0.03
    scheduler: str = "constant_with_warmup"

    # Training loop
    num_epochs: int = 50
    batch_size: int = 1
    gradient_accumulation_steps: int = 1
    save_every_n_epochs: int = 10
    log_every_n_steps: int = 5

    # Flow matching
    shift: float = 3.0  # timestep shift factor

    # Audio pre-processing
    max_duration_sec: float = 240.0  # clamp audio to this length
    sample_rate: int = 48000

    # Paths
    output_dir: str = "lora_output"
    resume_from: Optional[str] = None

    # Device
    device: str = "auto"
    dtype: str = "bf16"  # "bf16" | "fp16" | "fp32"
    mixed_precision: bool = True


# ---------------------------------------------------------------------------
# Core trainer
# ---------------------------------------------------------------------------
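These defaults imply a step budget that is easy to compute up front. A sketch of the arithmetic (the trainer's scheduler setup is not shown in this excerpt, so treat this as an estimate of what `warmup_ratio` means, not the trainer's exact internals):

```python
import math

def step_budget(n_tracks, num_epochs=50, batch_size=1, grad_accum=1, warmup_ratio=0.03):
    """Optimizer steps and warmup steps implied by the config defaults."""
    steps_per_epoch = math.ceil(n_tracks / batch_size) // max(grad_accum, 1)
    total = steps_per_epoch * num_epochs
    warmup = int(total * warmup_ratio)
    return total, warmup

# A hypothetical 20-track dataset with the defaults above.
print(step_budget(20))  # (1000, 30)
```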
class LoRATrainer:
    """Thin training loop that wraps the existing AceStepHandler."""

    def __init__(self, handler, config: LoRATrainConfig):
        """
        Args:
            handler: Initialised ``AceStepHandler`` (model, vae, text_encoder loaded).
            config: Training hyper-parameters.
        """
        self.handler = handler
        self.cfg = config

        self.device = handler.device
        self.dtype = handler.dtype

        # Will be set during prepare()
        self.peft_model = None
        self.optimizer = None
        self.scheduler = None
        self.global_step = 0
        self.current_epoch = 0

        # Loss history for UI
        self.loss_history: List[Dict[str, Any]] = []
        self._stop_requested = False

    # ------------------------------------------------------------------
    # Setup
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_lora_target_modules(model, requested_targets: Optional[List[str]]) -> List[str]:
        """Resolve LoRA target module suffixes against the actual decoder module names."""
        linear_module_names = [
            name for name, module in model.named_modules() if isinstance(module, torch.nn.Linear)
        ]

        def _exists_as_suffix(target: str) -> bool:
            return any(name.endswith(target) for name in linear_module_names)

        requested_targets = requested_targets or []
        resolved = [target for target in requested_targets if _exists_as_suffix(target)]
        if resolved:
            return resolved

        fallback_groups = [
            ["q_proj", "k_proj", "v_proj", "o_proj"],
            ["to_q", "to_k", "to_v", "to_out.0"],
            ["query", "key", "value", "out_proj"],
            ["wq", "wk", "wv", "wo"],
            ["qkv", "proj_out"],
        ]
        for group in fallback_groups:
            group_resolved = [target for target in group if _exists_as_suffix(target)]
            if len(group_resolved) >= 2:
                return group_resolved

        sample = ", ".join(linear_module_names[:30])
        raise ValueError(
            "Could not find LoRA target modules in decoder. "
            f"Requested={requested_targets}. "
            f"Sample linear modules: {sample}"
        )
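The suffix resolution above reduces to a string check against the decoder's `nn.Linear` module names. A torch-free sketch on made-up module names:

```python
# Hypothetical decoder Linear names; real names come from model.named_modules().
linear_names = [
    "blocks.0.attn.q_proj", "blocks.0.attn.k_proj",
    "blocks.0.attn.v_proj", "blocks.0.attn.o_proj",
    "blocks.0.mlp.fc1",
]

def exists_as_suffix(target):
    """Same rule as the trainer's inner _exists_as_suffix helper."""
    return any(name.endswith(target) for name in linear_names)

requested = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"]
resolved = [t for t in requested if exists_as_suffix(t)]
print(resolved)  # ['q_proj', 'k_proj', 'v_proj', 'o_proj']
```

Targets that do not match any module ("gate_proj" here) are silently dropped; only if nothing matches do the fallback groups kick in.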
    def prepare(self):
        """Attach LoRA adapters to the decoder and build the optimiser."""
        import copy
        from peft import LoraConfig, PeftModel, TaskType, get_peft_model

        # Keep a backup of the plain base decoder so load/unload logic remains valid.
        if self.handler._base_decoder is None:
            self.handler._base_decoder = copy.deepcopy(self.handler.model.decoder)
        else:
            self.handler.model.decoder = copy.deepcopy(self.handler._base_decoder)
        self.handler.model.decoder = self.handler.model.decoder.to(self.device).to(self.dtype)
        self.handler.model.decoder.eval()

        resume_adapter = None
        if self.cfg.resume_from:
            adapter_cfg = os.path.join(self.cfg.resume_from, "adapter_config.json")
            if os.path.isfile(adapter_cfg):
                resume_adapter = self.cfg.resume_from

        if resume_adapter:
            logger.info(f"Loading existing LoRA adapter for resume: {resume_adapter}")
            self.peft_model = PeftModel.from_pretrained(
                self.handler.model.decoder,
                resume_adapter,
                is_trainable=True,
            )
        else:
            resolved_targets = self._resolve_lora_target_modules(
                self.handler.model.decoder,
                self.cfg.lora_target_modules,
            )
            logger.info(f"Using LoRA target modules: {resolved_targets}")
            peft_cfg = LoraConfig(
                r=self.cfg.lora_rank,
                lora_alpha=self.cfg.lora_alpha,
                lora_dropout=self.cfg.lora_dropout,
                target_modules=resolved_targets,
                bias="none",
                task_type=TaskType.FEATURE_EXTRACTION,
            )
            self.peft_model = get_peft_model(self.handler.model.decoder, peft_cfg)

        self.peft_model.print_trainable_parameters()
        self.handler.model.decoder = self.peft_model
        self.handler.model.decoder.to(self.device).to(self.dtype)
        self.handler.model.decoder.train()
        self.handler.lora_loaded = True
        self.handler.use_lora = True

        # Build optimiser (only LoRA params are trainable)
        trainable_params = [p for p in self.peft_model.parameters() if p.requires_grad]
        if self.cfg.optimizer == "adamw_8bit":
            try:
                import bitsandbytes as bnb
                self.optimizer = bnb.optim.AdamW8bit(
                    trainable_params,
                    lr=self.cfg.learning_rate,
                    weight_decay=self.cfg.weight_decay,
                )
            except ImportError:
                logger.warning("bitsandbytes not found – falling back to standard AdamW")
                self.optimizer = torch.optim.AdamW(
                    trainable_params,
                    lr=self.cfg.learning_rate,
                    weight_decay=self.cfg.weight_decay,
                )
        else:
            self.optimizer = torch.optim.AdamW(
                trainable_params,
                lr=self.cfg.learning_rate,
                weight_decay=self.cfg.weight_decay,
            )

        # Resume checkpoint state (after model/adapter restore).
        if self.cfg.resume_from and os.path.isfile(
            os.path.join(self.cfg.resume_from, "training_state.pt")
        ):
            state = torch.load(
                os.path.join(self.cfg.resume_from, "training_state.pt"),
                weights_only=False,
            )
            try:
                self.optimizer.load_state_dict(state["optimizer"])
            except Exception as exc:
                logger.warning(f"Could not restore optimizer state, continuing fresh optimizer: {exc}")
            self.global_step = int(state.get("global_step", 0))
            # Saved epoch is completed epoch index; continue from next epoch.
            self.current_epoch = int(state.get("epoch", -1)) + 1
            loss_path = os.path.join(self.cfg.resume_from, "loss_history.json")
            if os.path.isfile(loss_path):
                try:
                    with open(loss_path, "r", encoding="utf-8") as f:
                        self.loss_history = json.load(f)
                except Exception:
                    pass
            logger.info(
                f"Resumed from {self.cfg.resume_from} "
                f"(epoch {self.current_epoch}, step {self.global_step})"
            )
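With rank-64 adapters on four attention projections, the trainable-parameter count follows the usual LoRA formula: each adapted Linear gains an A (in x r) and a B (r x out) matrix. The hidden size and layer count below are placeholder numbers, not ACE-Step's real dimensions; `print_trainable_parameters()` reports the true figure at runtime:

```python
# Rough LoRA parameter arithmetic for square projections (in = out = hidden).
rank, hidden, n_layers, n_targets = 64, 2048, 24, 4   # hidden/n_layers are made up
params_per_linear = rank * (hidden + hidden)          # A: hidden*r + B: r*hidden
total = params_per_linear * n_targets * n_layers
scaling = 64 / rank                                   # lora_alpha / lora_rank
print(params_per_linear, total, scaling)  # 262144 25165824 1.0
```

With the defaults `lora_alpha == lora_rank`, so the adapter output is applied at scale 1.0.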
    # ------------------------------------------------------------------
    # Data loading
    # ------------------------------------------------------------------

    @staticmethod
    def _coerce_audio_tensor(audio: Any) -> torch.Tensor:
        """Coerce decoded audio into torch.Tensor with shape [C, T]."""
        if isinstance(audio, list):
            audio = np.asarray(audio, dtype=np.float32)
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)
        if not torch.is_tensor(audio):
            raise TypeError(f"Unsupported audio type: {type(audio)}")

        # Ensure floating point for downstream resample/vae encode.
        if not torch.is_floating_point(audio):
            audio = audio.float()

        # Normalize dimensions to [C, T].
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)
        elif audio.dim() == 2:
            # Accept either [T, C] or [C, T]; transpose only when clearly [T, C].
            if audio.shape[0] > audio.shape[1] and audio.shape[1] <= 8:
                audio = audio.transpose(0, 1)
        elif audio.dim() == 3:
            # If batched, take first item.
            audio = audio[0]
        else:
            raise ValueError(f"Unexpected audio dims: {tuple(audio.shape)}")

        return audio.contiguous()

    def _load_audio(self, path: str) -> torch.Tensor:
        """Load audio, resample to 48 kHz stereo, clamp to max_duration."""
        try:
            wav, sr = torchaudio.load(path)
        except Exception as torchaudio_exc:
            # torchaudio on some Space images requires torchcodec for decode.
            # Fallback to soundfile so training can proceed without torchcodec.
            try:
                audio_np, sr = sf.read(path, dtype="float32", always_2d=True)
                wav = torch.from_numpy(audio_np.T)
            except Exception as sf_exc:
                raise RuntimeError(
                    f"Failed to decode audio '{path}' with torchaudio ({torchaudio_exc}) "
                    f"and soundfile ({sf_exc})."
                ) from sf_exc

        wav = self._coerce_audio_tensor(wav)

        # Resample if needed
        if sr != self.cfg.sample_rate:
            wav = torchaudio.functional.resample(wav, sr, self.cfg.sample_rate)

        # Convert mono → stereo
        if wav.shape[0] == 1:
            wav = wav.repeat(2, 1)
        elif wav.shape[0] > 2:
            wav = wav[:2]

        # Clamp length
        max_samples = int(self.cfg.max_duration_sec * self.cfg.sample_rate)
        if wav.shape[1] > max_samples:
            wav = wav[:, :max_samples]

        return wav  # [2, T]
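The [T, C] vs [C, T] disambiguation in `_coerce_audio_tensor` is a pure shape heuristic: transpose only when the first axis is longer and the second is small enough to be a channel count. Sketched on plain tuples:

```python
def normalize_shape(shape):
    """Mirror _coerce_audio_tensor's [T, C] vs [C, T] heuristic on a 2-D shape."""
    rows, cols = shape
    if rows > cols and cols <= 8:   # looks like [samples, channels] -> transpose
        return (cols, rows)
    return (rows, cols)

print(normalize_shape((480000, 2)))  # (2, 480000): [T, C] input gets transposed
print(normalize_shape((2, 480000)))  # (2, 480000): already [C, T], untouched
```

The `cols <= 8` guard keeps short many-channel-looking buffers (e.g. a square spectrogram-like array) from being flipped by accident.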
    def _encode_audio(self, wav: torch.Tensor) -> torch.Tensor:
        """Encode raw waveform → VAE latent on device."""
        with torch.no_grad():
            latent = self.handler._encode_audio_to_latents(wav)
            if latent.dim() == 2:
                latent = latent.unsqueeze(0)
            latent = latent.to(self.dtype)
        return latent

    def _build_text_embeddings(self, caption: str, lyrics: str):
        """Compute text & lyric embeddings using the text encoder."""
        tokenizer = self.handler.text_tokenizer
        text_encoder = self.handler.text_encoder

        # Caption embedding
        text_tokens = tokenizer(
            caption or "",
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512,
        ).to(self.device)

        with torch.no_grad():
            text_hidden = text_encoder(
                input_ids=text_tokens["input_ids"]
            ).last_hidden_state.to(self.dtype)
        text_mask = text_tokens["attention_mask"].to(self.dtype)

        # Lyric embedding (token-level via embed_tokens)
        lyric_tokens = tokenizer(
            lyrics or "",
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=512,
        ).to(self.device)

        with torch.no_grad():
            lyric_hidden = text_encoder.embed_tokens(
                lyric_tokens["input_ids"]
            ).to(self.dtype)
        lyric_mask = lyric_tokens["attention_mask"].to(self.dtype)

        return text_hidden, text_mask, lyric_hidden, lyric_mask
    # ------------------------------------------------------------------
    # Flow matching loss
    # ------------------------------------------------------------------

    def _flow_matching_loss(
        self,
        x1: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        context_latents: torch.Tensor,
    ) -> torch.Tensor:
        """Compute rectified-flow MSE loss for one sample.

        Notation follows ACE-Step convention:
            x0 = noise, x1 = clean latent
            xt = t * x0 + (1 - t) * x1
            target velocity = x0 - x1
        """
        bsz = x1.shape[0]

        # Sample random timestep per element
        t = torch.rand(bsz, device=self.device, dtype=self.dtype)

        # Apply timestep shift: t_shifted = shift * t / (1 + (shift - 1) * t)
        if self.cfg.shift != 1.0:
            t = self.cfg.shift * t / (1.0 + (self.cfg.shift - 1.0) * t)

        t = t.clamp(1e-5, 1.0 - 1e-5)

        # Noise
        x0 = torch.randn_like(x1)

        # Interpolate
        t_expand = t.view(bsz, 1, 1)
        xt = t_expand * x0 + (1.0 - t_expand) * x1

        # Target velocity
        velocity_target = x0 - x1

        # Attention mask
        attention_mask = torch.ones(
            bsz, x1.shape[1], device=self.device, dtype=self.dtype
        )

        # Forward through decoder
        decoder_out = self.handler.model.decoder(
            hidden_states=xt,
            timestep=t,
            timestep_r=t,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            context_latents=context_latents,
            use_cache=False,
            output_attentions=False,
        )

        velocity_pred = decoder_out[0]  # first element is the predicted output
        loss = F.mse_loss(velocity_pred, velocity_target)
        return loss

    @staticmethod
    def _pad_and_stack(tensors: List[torch.Tensor], pad_value: float = 0.0) -> torch.Tensor:
        """Pad variable-length tensors on dimension 0 and stack as batch."""
        normalized = []
        for t in tensors:
            if t.dim() >= 2 and t.shape[0] == 1:
                normalized.append(t.squeeze(0))
            else:
                normalized.append(t)

        max_len = max(t.shape[0] for t in normalized)
        template = normalized[0]
        out_shape = (len(normalized), max_len, *template.shape[1:])
        out = template.new_full(out_shape, pad_value)
        for i, t in enumerate(normalized):
            out[i, : t.shape[0]] = t
        return out
| 573 |
+
# ------------------------------------------------------------------
|
| 574 |
+
# Main training loop
|
| 575 |
+
# ------------------------------------------------------------------
|
| 576 |
+
|
| 577 |
+
def request_stop(self):
|
| 578 |
+
"""Ask the training loop to stop after the current step."""
|
| 579 |
+
self._stop_requested = True
|
| 580 |
+
|
| 581 |
+
def train(
|
| 582 |
+
self,
|
| 583 |
+
entries: List[TrackEntry],
|
| 584 |
+
progress_callback=None,
|
| 585 |
+
) -> str:
|
| 586 |
+
"""Run the full LoRA training.
|
| 587 |
+
|
| 588 |
+
Args:
|
| 589 |
+
entries: List of scanned TrackEntry objects.
|
| 590 |
+
progress_callback: ``fn(step, total_steps, loss, epoch)`` for UI updates.
|
| 591 |
+
|
| 592 |
+
Returns:
|
| 593 |
+
Status message.
|
| 594 |
+
"""
|
| 595 |
+
self._stop_requested = False
|
| 596 |
+
self.loss_history.clear()
|
| 597 |
+
os.makedirs(self.cfg.output_dir, exist_ok=True)
|
| 598 |
+
|
| 599 |
+
if not entries:
|
| 600 |
+
return "No training data provided."
|
| 601 |
+
|
| 602 |
+
num_entries = len(entries)
|
| 603 |
+
total_steps = (
|
| 604 |
+
math.ceil(num_entries / self.cfg.batch_size)
|
| 605 |
+
* self.cfg.num_epochs
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
# ---- Pre-encode all audio & text (fits in CPU RAM) ----
|
| 609 |
+
logger.info("Pre-encoding dataset through VAE & text encoder ...")
|
| 610 |
+
dataset: List[Dict[str, Any]] = []
|
| 611 |
+
failed_encode: List[str] = []
|
| 612 |
+
|
| 613 |
+
# Freeze VAE and text encoder (they are not trained)
|
| 614 |
+
self.handler.vae.eval()
|
| 615 |
+
self.handler.text_encoder.eval()
|
| 616 |
+
|
| 617 |
+
# Reuse silence reference latent (same as handler's internal fallback path).
|
| 618 |
+
ref_latent = self.handler.silence_latent[:, :750, :].to(self.device).to(self.dtype)
|
| 619 |
+
ref_order_mask = torch.zeros(1, device=self.device, dtype=torch.long)
|
| 620 |
+
|
| 621 |
+
for idx, entry in enumerate(tqdm(entries, desc="Encoding dataset")):
|
| 622 |
+
try:
|
| 623 |
+
wav = self._load_audio(entry.audio_path)
|
| 624 |
+
latent = self._encode_audio(wav)
|
| 625 |
+
text_h, text_m, lyric_h, lyric_m = self._build_text_embeddings(
|
| 626 |
+
entry.caption, entry.lyrics
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
# Prepare condition using the model's own prepare_condition
|
| 630 |
+
with torch.no_grad():
|
| 631 |
+
enc_hs, enc_mask, ctx_lat = self.handler.model.prepare_condition(
|
| 632 |
+
text_hidden_states=text_h,
|
| 633 |
+
text_attention_mask=text_m,
|
| 634 |
+
lyric_hidden_states=lyric_h,
|
| 635 |
+
lyric_attention_mask=lyric_m,
|
| 636 |
+
refer_audio_acoustic_hidden_states_packed=ref_latent,
|
| 637 |
+
refer_audio_order_mask=ref_order_mask,
|
| 638 |
+
hidden_states=latent,
|
| 639 |
+
attention_mask=torch.ones(
|
| 640 |
+
1, latent.shape[1],
|
| 641 |
+
device=self.device, dtype=self.dtype,
|
| 642 |
+
),
|
| 643 |
+
silence_latent=self.handler.silence_latent,
|
| 644 |
+
src_latents=latent,
|
| 645 |
+
chunk_masks=torch.ones_like(latent),
|
| 646 |
+
is_covers=[False],
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
dataset.append(
|
| 650 |
+
{
|
| 651 |
+
"latent": latent.cpu(),
|
| 652 |
+
"enc_hs": enc_hs.cpu(),
|
| 653 |
+
"enc_mask": enc_mask.cpu(),
|
| 654 |
+
"ctx_lat": ctx_lat.cpu(),
|
| 655 |
+
"name": Path(entry.audio_path).stem,
|
| 656 |
+
}
|
| 657 |
+
)
|
| 658 |
+
except Exception as exc:
|
| 659 |
+
reason = f"{Path(entry.audio_path).name}: {exc}"
|
| 660 |
+
failed_encode.append(reason)
|
| 661 |
+
logger.warning(f"Skipping {entry.audio_path}: {exc}")
|
| 662 |
+
|
| 663 |
+
if not dataset:
|
| 664 |
+
preview = "\n".join(f"- {msg}" for msg in failed_encode[:8]) or "- (no detailed errors captured)"
|
| 665 |
+
return (
|
| 666 |
+
"All tracks failed to encode. Check audio files.\n"
|
| 667 |
+
"First errors:\n"
|
| 668 |
+
f"{preview}\n"
|
| 669 |
+
"Tip: try WAV/FLAC files and dataset folder scan instead of temporary uploads."
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
logger.info(f"Encoded {len(dataset)}/{num_entries} tracks.")
|
| 673 |
+
|
| 674 |
+
# ---- Warmup scheduler ----
|
| 675 |
+
total_optim_steps = math.ceil(
|
| 676 |
+
total_steps / self.cfg.gradient_accumulation_steps
|
| 677 |
+
)
|
| 678 |
+
warmup_steps = int(total_optim_steps * self.cfg.warmup_ratio)
|
| 679 |
+
|
| 680 |
+
if self.cfg.scheduler in {"constant_with_warmup", "linear", "cosine"}:
|
| 681 |
+
try:
|
| 682 |
+
from transformers import get_scheduler
|
| 683 |
+
self.scheduler = get_scheduler(
|
| 684 |
+
name=self.cfg.scheduler,
|
| 685 |
+
optimizer=self.optimizer,
|
| 686 |
+
num_warmup_steps=warmup_steps,
|
| 687 |
+
num_training_steps=total_optim_steps,
|
| 688 |
+
)
|
| 689 |
+
except Exception as exc:
|
| 690 |
+
logger.warning(f"Could not create scheduler '{self.cfg.scheduler}', disabling scheduler: {exc}")
|
| 691 |
+
self.scheduler = None
|
| 692 |
+
else:
|
| 693 |
+
self.scheduler = None
|
| 694 |
+
|
| 695 |
+
# ---- Training loop ----
|
| 696 |
+
logger.info(
|
| 697 |
+
f"Starting LoRA training: {self.cfg.num_epochs} epochs, "
|
| 698 |
+
f"{len(dataset)} samples, {total_optim_steps} optimiser steps"
|
| 699 |
+
)
|
| 700 |
+
|
| 701 |
+
self.peft_model.train()
|
| 702 |
+
accum_loss = 0.0
|
| 703 |
+
step_in_accum = 0
|
| 704 |
+
|
| 705 |
+
for epoch in range(self.current_epoch, self.cfg.num_epochs):
|
| 706 |
+
if self._stop_requested:
|
| 707 |
+
break
|
| 708 |
+
|
| 709 |
+
self.current_epoch = epoch
|
| 710 |
+
indices = list(range(len(dataset)))
|
| 711 |
+
random.shuffle(indices)
|
| 712 |
+
|
| 713 |
+
epoch_loss = 0.0
|
| 714 |
+
epoch_steps = 0
|
| 715 |
+
|
| 716 |
+
for i in range(0, len(indices), self.cfg.batch_size):
|
| 717 |
+
if self._stop_requested:
|
| 718 |
+
break
|
| 719 |
+
|
| 720 |
+
batch_indices = indices[i : i + self.cfg.batch_size]
|
| 721 |
+
batch_items = [dataset[j] for j in batch_indices]
|
| 722 |
+
|
| 723 |
+
# Move batch to device
|
| 724 |
+
latents = self._pad_and_stack([it["latent"] for it in batch_items]).to(self.device, self.dtype)
|
| 725 |
+
enc_hs = self._pad_and_stack([it["enc_hs"] for it in batch_items]).to(self.device, self.dtype)
|
| 726 |
+
enc_mask = self._pad_and_stack([it["enc_mask"] for it in batch_items], pad_value=0.0).to(self.device)
|
| 727 |
+
if enc_mask.dtype != self.dtype:
|
| 728 |
+
enc_mask = enc_mask.to(self.dtype)
|
| 729 |
+
ctx_lat = self._pad_and_stack([it["ctx_lat"] for it in batch_items]).to(self.device, self.dtype)
|
| 730 |
+
|
| 731 |
+
# Forward + loss
|
| 732 |
+
loss = self._flow_matching_loss(latents, enc_hs, enc_mask, ctx_lat)
|
| 733 |
+
loss = loss / self.cfg.gradient_accumulation_steps
|
| 734 |
+
loss.backward()
|
| 735 |
+
|
| 736 |
+
accum_loss += loss.item()
|
| 737 |
+
step_in_accum += 1
|
| 738 |
+
|
| 739 |
+
if step_in_accum >= self.cfg.gradient_accumulation_steps:
|
| 740 |
+
torch.nn.utils.clip_grad_norm_(
|
| 741 |
+
self.peft_model.parameters(), self.cfg.max_grad_norm
|
| 742 |
+
)
|
| 743 |
+
self.optimizer.step()
|
| 744 |
+
if self.scheduler is not None:
|
| 745 |
+
self.scheduler.step()
|
| 746 |
+
self.optimizer.zero_grad()
|
| 747 |
+
|
| 748 |
+
self.global_step += 1
|
| 749 |
+
avg_loss = accum_loss
|
| 750 |
+
accum_loss = 0.0
|
| 751 |
+
step_in_accum = 0
|
| 752 |
+
|
| 753 |
+
self.loss_history.append(
|
| 754 |
+
{
|
| 755 |
+
"step": self.global_step,
|
| 756 |
+
"epoch": epoch,
|
| 757 |
+
"loss": avg_loss,
|
| 758 |
+
"lr": self.optimizer.param_groups[0]["lr"],
|
| 759 |
+
}
|
| 760 |
+
)
|
| 761 |
+
|
| 762 |
+
if self.global_step % self.cfg.log_every_n_steps == 0:
|
| 763 |
+
logger.info(
|
| 764 |
+
f"Epoch {epoch+1}/{self.cfg.num_epochs} "
|
| 765 |
+
f"Step {self.global_step}/{total_optim_steps} "
|
| 766 |
+
f"Loss {avg_loss:.6f} "
|
| 767 |
+
f"LR {self.optimizer.param_groups[0]['lr']:.2e}"
|
| 768 |
+
)
|
| 769 |
+
|
| 770 |
+
if progress_callback:
|
| 771 |
+
progress_callback(
|
| 772 |
+
self.global_step, total_optim_steps, avg_loss, epoch
|
| 773 |
+
)
|
| 774 |
+
|
| 775 |
+
epoch_loss += loss.item() * self.cfg.gradient_accumulation_steps
|
| 776 |
+
epoch_steps += 1
|
| 777 |
+
|
| 778 |
+
# Flush remaining micro-batches when len(dataset) is not divisible by grad accumulation.
|
| 779 |
+
if step_in_accum > 0:
|
| 780 |
+
torch.nn.utils.clip_grad_norm_(self.peft_model.parameters(), self.cfg.max_grad_norm)
|
| 781 |
+
self.optimizer.step()
|
| 782 |
+
if self.scheduler is not None:
|
| 783 |
+
self.scheduler.step()
|
| 784 |
+
self.optimizer.zero_grad()
|
| 785 |
+
self.global_step += 1
|
| 786 |
+
avg_loss = accum_loss
|
| 787 |
+
accum_loss = 0.0
|
| 788 |
+
step_in_accum = 0
|
| 789 |
+
self.loss_history.append(
|
| 790 |
+
{
|
| 791 |
+
"step": self.global_step,
|
| 792 |
+
"epoch": epoch,
|
| 793 |
+
"loss": avg_loss,
|
| 794 |
+
"lr": self.optimizer.param_groups[0]["lr"],
|
| 795 |
+
}
|
| 796 |
+
)
|
| 797 |
+
|
| 798 |
+
# End of epoch – checkpoint?
|
| 799 |
+
if (
|
| 800 |
+
(epoch + 1) % self.cfg.save_every_n_epochs == 0
|
| 801 |
+
or epoch == self.cfg.num_epochs - 1
|
| 802 |
+
or self._stop_requested
|
| 803 |
+
):
|
| 804 |
+
self._save_checkpoint(epoch)
|
| 805 |
+
|
| 806 |
+
if epoch_steps > 0:
|
| 807 |
+
avg_epoch_loss = epoch_loss / epoch_steps
|
| 808 |
+
logger.info(
|
| 809 |
+
f"Epoch {epoch+1} complete – avg loss {avg_epoch_loss:.6f}"
|
| 810 |
+
)
|
| 811 |
+
|
| 812 |
+
# Final save
|
| 813 |
+
final_dir = self._save_checkpoint(self.current_epoch, final=True)
|
| 814 |
+
status = (
|
| 815 |
+
"Training stopped early." if self._stop_requested else "Training complete!"
|
| 816 |
+
)
|
| 817 |
+
return f"{status} Adapter saved to {final_dir}"
|
| 818 |
+
|
| 819 |
+
# ------------------------------------------------------------------
|
| 820 |
+
# Checkpointing
|
| 821 |
+
# ------------------------------------------------------------------
|
| 822 |
+
|
| 823 |
+
def _save_checkpoint(self, epoch: int, final: bool = False) -> str:
|
| 824 |
+
tag = "final" if final else f"epoch-{epoch+1}"
|
| 825 |
+
save_dir = os.path.join(self.cfg.output_dir, tag)
|
| 826 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 827 |
+
|
| 828 |
+
# Save PEFT adapter
|
| 829 |
+
self.peft_model.save_pretrained(save_dir)
|
| 830 |
+
|
| 831 |
+
# Save training state
|
| 832 |
+
torch.save(
|
| 833 |
+
{
|
| 834 |
+
"optimizer": self.optimizer.state_dict(),
|
| 835 |
+
"global_step": self.global_step,
|
| 836 |
+
"epoch": epoch,
|
| 837 |
+
},
|
| 838 |
+
os.path.join(save_dir, "training_state.pt"),
|
| 839 |
+
)
|
| 840 |
+
|
| 841 |
+
# Save loss curve
|
| 842 |
+
loss_path = os.path.join(save_dir, "loss_history.json")
|
| 843 |
+
with open(loss_path, "w") as f:
|
| 844 |
+
json.dump(self.loss_history, f)
|
| 845 |
+
|
| 846 |
+
# Save config
|
| 847 |
+
cfg_path = os.path.join(save_dir, "train_config.json")
|
| 848 |
+
with open(cfg_path, "w") as f:
|
| 849 |
+
json.dump(asdict(self.cfg), f, indent=2)
|
| 850 |
+
|
| 851 |
+
logger.info(f"Checkpoint saved → {save_dir}")
|
| 852 |
+
return save_dir
|
| 853 |
+
|
| 854 |
+
# ------------------------------------------------------------------
|
| 855 |
+
# Adapter listing
|
| 856 |
+
# ------------------------------------------------------------------
|
| 857 |
+
|
| 858 |
+
@staticmethod
|
| 859 |
+
def list_adapters(output_dir: str = "lora_output") -> List[str]:
|
| 860 |
+
"""Return adapter directories inside *output_dir* (recursive)."""
|
| 861 |
+
results = []
|
| 862 |
+
root = Path(output_dir)
|
| 863 |
+
if not root.is_dir():
|
| 864 |
+
return results
|
| 865 |
+
for cfg in sorted(root.rglob("adapter_config.json")):
|
| 866 |
+
d = cfg.parent
|
| 867 |
+
if d.is_dir():
|
| 868 |
+
results.append(str(d))
|
| 869 |
+
return results
|
| 870 |
+
|
| 871 |
+
|
| 872 |
+
def _build_arg_parser() -> argparse.ArgumentParser:
|
| 873 |
+
parser = argparse.ArgumentParser(description="ACE-Step 1.5 LoRA trainer (CLI)")
|
| 874 |
+
|
| 875 |
+
# Dataset
|
| 876 |
+
parser.add_argument("--dataset-dir", type=str, default="", help="Local dataset folder path")
|
| 877 |
+
parser.add_argument("--dataset-repo", type=str, default="", help="HF dataset repo id (optional)")
|
| 878 |
+
parser.add_argument("--dataset-revision", type=str, default="main", help="HF dataset revision")
|
| 879 |
+
parser.add_argument("--dataset-subdir", type=str, default="", help="Subdirectory inside downloaded dataset")
|
| 880 |
+
|
| 881 |
+
# Model init
|
| 882 |
+
parser.add_argument("--model-config", type=str, default="acestep-v15-base", help="DiT config name")
|
| 883 |
+
parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "mps", "xpu", "cpu"])
|
| 884 |
+
parser.add_argument("--offload-to-cpu", action="store_true")
|
| 885 |
+
parser.add_argument("--offload-dit-to-cpu", action="store_true")
|
| 886 |
+
parser.add_argument("--prefer-source", type=str, default="huggingface", choices=["huggingface", "modelscope"])
|
| 887 |
+
|
| 888 |
+
# Train config
|
| 889 |
+
parser.add_argument("--output-dir", type=str, default="lora_output")
|
| 890 |
+
parser.add_argument("--resume-from", type=str, default="")
|
| 891 |
+
parser.add_argument("--num-epochs", type=int, default=50)
|
| 892 |
+
parser.add_argument("--batch-size", type=int, default=1)
|
| 893 |
+
parser.add_argument("--grad-accum", type=int, default=1)
|
| 894 |
+
parser.add_argument("--save-every", type=int, default=10)
|
| 895 |
+
parser.add_argument("--log-every", type=int, default=5)
|
| 896 |
+
parser.add_argument("--max-duration-sec", type=float, default=240.0)
|
| 897 |
+
|
| 898 |
+
parser.add_argument("--lora-rank", type=int, default=64)
|
| 899 |
+
parser.add_argument("--lora-alpha", type=int, default=64)
|
| 900 |
+
parser.add_argument("--lora-dropout", type=float, default=0.1)
|
| 901 |
+
|
| 902 |
+
parser.add_argument("--learning-rate", type=float, default=1e-4)
|
| 903 |
+
parser.add_argument("--weight-decay", type=float, default=0.01)
|
| 904 |
+
parser.add_argument("--optimizer", type=str, default="adamw_8bit", choices=["adamw", "adamw_8bit"])
|
| 905 |
+
parser.add_argument("--max-grad-norm", type=float, default=1.0)
|
| 906 |
+
parser.add_argument("--warmup-ratio", type=float, default=0.03)
|
| 907 |
+
parser.add_argument("--scheduler", type=str, default="constant_with_warmup", choices=["constant_with_warmup", "linear", "cosine"])
|
| 908 |
+
parser.add_argument("--shift", type=float, default=3.0)
|
| 909 |
+
|
| 910 |
+
# Optional upload
|
| 911 |
+
parser.add_argument("--upload-repo", type=str, default="", help="HF model repo to upload final adapter")
|
| 912 |
+
parser.add_argument("--upload-path", type=str, default="", help="Path inside upload repo (optional)")
|
| 913 |
+
parser.add_argument("--upload-private", action="store_true")
|
| 914 |
+
parser.add_argument("--hf-token-env", type=str, default="HF_TOKEN", help="Environment variable name for HF token")
|
| 915 |
+
|
| 916 |
+
return parser
|
| 917 |
+
|
| 918 |
+
|
| 919 |
+
def _resolve_dataset_dir(args) -> str:
|
| 920 |
+
if args.dataset_dir:
|
| 921 |
+
return args.dataset_dir
|
| 922 |
+
|
| 923 |
+
if not args.dataset_repo:
|
| 924 |
+
raise ValueError("Provide --dataset-dir or --dataset-repo.")
|
| 925 |
+
|
| 926 |
+
from huggingface_hub import snapshot_download
|
| 927 |
+
|
| 928 |
+
token = os.getenv(args.hf_token_env)
|
| 929 |
+
temp_root = tempfile.mkdtemp(prefix="acestep_lora_dataset_")
|
| 930 |
+
local_dir = os.path.join(temp_root, "dataset")
|
| 931 |
+
logger.info(f"Downloading dataset repo {args.dataset_repo}@{args.dataset_revision} to {local_dir}")
|
| 932 |
+
snapshot_download(
|
| 933 |
+
repo_id=args.dataset_repo,
|
| 934 |
+
repo_type="dataset",
|
| 935 |
+
revision=args.dataset_revision,
|
| 936 |
+
local_dir=local_dir,
|
| 937 |
+
local_dir_use_symlinks=False,
|
| 938 |
+
token=token,
|
| 939 |
+
)
|
| 940 |
+
if args.dataset_subdir:
|
| 941 |
+
sub = os.path.join(local_dir, args.dataset_subdir)
|
| 942 |
+
if not os.path.isdir(sub):
|
| 943 |
+
raise FileNotFoundError(f"Dataset subdir not found: {sub}")
|
| 944 |
+
return sub
|
| 945 |
+
return local_dir
|
| 946 |
+
|
| 947 |
+
|
| 948 |
+
def _upload_adapter_if_requested(args, final_dir: str):
|
| 949 |
+
if not args.upload_repo:
|
| 950 |
+
return
|
| 951 |
+
|
| 952 |
+
from huggingface_hub import HfApi
|
| 953 |
+
|
| 954 |
+
token = os.getenv(args.hf_token_env)
|
| 955 |
+
if not token:
|
| 956 |
+
raise RuntimeError(
|
| 957 |
+
f"{args.hf_token_env} is not set. Needed for upload to {args.upload_repo}."
|
| 958 |
+
)
|
| 959 |
+
|
| 960 |
+
api = HfApi(token=token)
|
| 961 |
+
api.create_repo(
|
| 962 |
+
repo_id=args.upload_repo,
|
| 963 |
+
repo_type="model",
|
| 964 |
+
exist_ok=True,
|
| 965 |
+
private=bool(args.upload_private),
|
| 966 |
+
)
|
| 967 |
+
|
| 968 |
+
path_in_repo = args.upload_path.strip().strip("/") if args.upload_path else ""
|
| 969 |
+
commit_message = f"Upload ACE-Step LoRA adapter from {Path(final_dir).name}"
|
| 970 |
+
logger.info(f"Uploading adapter from {final_dir} to {args.upload_repo}/{path_in_repo}")
|
| 971 |
+
api.upload_folder(
|
| 972 |
+
repo_id=args.upload_repo,
|
| 973 |
+
repo_type="model",
|
| 974 |
+
folder_path=final_dir,
|
| 975 |
+
path_in_repo=path_in_repo,
|
| 976 |
+
commit_message=commit_message,
|
| 977 |
+
)
|
| 978 |
+
logger.info("Upload complete")
|
| 979 |
+
|
| 980 |
+
|
| 981 |
+
def main():
|
| 982 |
+
args = _build_arg_parser().parse_args()
|
| 983 |
+
|
| 984 |
+
dataset_dir = _resolve_dataset_dir(args)
|
| 985 |
+
entries = scan_dataset_folder(dataset_dir)
|
| 986 |
+
if not entries:
|
| 987 |
+
raise RuntimeError(f"No audio files found in dataset: {dataset_dir}")
|
| 988 |
+
|
| 989 |
+
from acestep.handler import AceStepHandler
|
| 990 |
+
|
| 991 |
+
project_root = str(Path(__file__).resolve().parent)
|
| 992 |
+
handler = AceStepHandler()
|
| 993 |
+
status, ok = handler.initialize_service(
|
| 994 |
+
project_root=project_root,
|
| 995 |
+
config_path=args.model_config,
|
| 996 |
+
device=args.device,
|
| 997 |
+
use_flash_attention=False,
|
| 998 |
+
compile_model=False,
|
| 999 |
+
offload_to_cpu=bool(args.offload_to_cpu),
|
| 1000 |
+
offload_dit_to_cpu=bool(args.offload_dit_to_cpu),
|
| 1001 |
+
prefer_source=args.prefer_source,
|
| 1002 |
+
)
|
| 1003 |
+
print(status)
|
| 1004 |
+
if not ok:
|
| 1005 |
+
raise RuntimeError("Model initialization failed")
|
| 1006 |
+
|
| 1007 |
+
cfg = LoRATrainConfig(
|
| 1008 |
+
lora_rank=args.lora_rank,
|
| 1009 |
+
lora_alpha=args.lora_alpha,
|
| 1010 |
+
lora_dropout=args.lora_dropout,
|
| 1011 |
+
learning_rate=args.learning_rate,
|
| 1012 |
+
weight_decay=args.weight_decay,
|
| 1013 |
+
optimizer=args.optimizer,
|
| 1014 |
+
max_grad_norm=args.max_grad_norm,
|
| 1015 |
+
warmup_ratio=args.warmup_ratio,
|
| 1016 |
+
scheduler=args.scheduler,
|
| 1017 |
+
num_epochs=args.num_epochs,
|
| 1018 |
+
batch_size=args.batch_size,
|
| 1019 |
+
gradient_accumulation_steps=args.grad_accum,
|
| 1020 |
+
save_every_n_epochs=args.save_every,
|
| 1021 |
+
log_every_n_steps=args.log_every,
|
| 1022 |
+
shift=args.shift,
|
| 1023 |
+
max_duration_sec=args.max_duration_sec,
|
| 1024 |
+
output_dir=args.output_dir,
|
| 1025 |
+
resume_from=(args.resume_from.strip() if args.resume_from else None),
|
| 1026 |
+
device=args.device,
|
| 1027 |
+
)
|
| 1028 |
+
|
| 1029 |
+
trainer = LoRATrainer(handler, cfg)
|
| 1030 |
+
trainer.prepare()
|
| 1031 |
+
|
| 1032 |
+
start = time.time()
|
| 1033 |
+
|
| 1034 |
+
def _progress(step, total, loss, epoch):
|
| 1035 |
+
elapsed = time.time() - start
|
| 1036 |
+
rate = step / elapsed if elapsed > 0 else 0.0
|
| 1037 |
+
remaining = max(0.0, total - step)
|
| 1038 |
+
eta_sec = remaining / rate if rate > 0 else -1.0
|
| 1039 |
+
eta_msg = f"{eta_sec/60:.1f}m" if eta_sec >= 0 else "unknown"
|
| 1040 |
+
logger.info(
|
| 1041 |
+
f"[progress] step={step}/{total} epoch={epoch+1} loss={loss:.6f} elapsed={elapsed/60:.1f}m eta={eta_msg}"
|
| 1042 |
+
)
|
| 1043 |
+
|
| 1044 |
+
msg = trainer.train(entries, progress_callback=_progress)
|
| 1045 |
+
print(msg)
|
| 1046 |
+
|
| 1047 |
+
final_dir = os.path.join(cfg.output_dir, "final")
|
| 1048 |
+
if os.path.isdir(final_dir):
|
| 1049 |
+
_upload_adapter_if_requested(args, final_dir)
|
| 1050 |
+
print(f"Final adapter directory: {final_dir}")
|
| 1051 |
+
else:
|
| 1052 |
+
print(f"Warning: final adapter directory not found at {final_dir}")
|
| 1053 |
+
|
| 1054 |
+
|
| 1055 |
+
if __name__ == "__main__":
|
| 1056 |
+
main()
|
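The rectified-flow quantities used in `_flow_matching_loss` (the timestep shift, the noise/latent interpolation, and the velocity target) can be checked in isolation. The sketch below restates those three formulas with plain floats instead of tensors; the function names are illustrative and not part of the repository.

```python
# Standalone sketch of the math inside LoRATrainer._flow_matching_loss.
# Scalars stand in for tensors; the formulas are identical.

def shift_timestep(t: float, shift: float = 3.0) -> float:
    """Warp a uniform t in (0, 1): t' = shift * t / (1 + (shift - 1) * t).

    shift=1.0 is the identity; shift>1 biases sampling toward noisier steps.
    """
    return shift * t / (1.0 + (shift - 1.0) * t)


def interpolate(t: float, x0: float, x1: float) -> float:
    """Noisy sample x_t = t * x0 + (1 - t) * x1 (t=0 -> clean latent, t=1 -> pure noise)."""
    return t * x0 + (1.0 - t) * x1


def target_velocity(x0: float, x1: float) -> float:
    """The regression target x0 - x1, i.e. d(x_t)/dt of the interpolation above."""
    return x0 - x1
```

Note that the velocity target is constant in `t`: differentiating `x_t = t*x0 + (1-t)*x1` gives `x0 - x1` regardless of the sampled timestep, which is why the loss needs no per-step weighting.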
lora_ui.py
ADDED
@@ -0,0 +1,973 @@
| 1 |
+
"""
|
| 2 |
+
ACE-Step 1.5 LoRA Training and Evaluation UI.
|
| 3 |
+
|
| 4 |
+
Gradio interface with four tabs:
|
| 5 |
+
1. Model Setup: initialize base DiT, VAE, and text encoder
|
| 6 |
+
2. Dataset: scan folder or drop files, then edit/save sidecars
|
| 7 |
+
3. Training: configure hyperparameters and run LoRA training
|
| 8 |
+
4. Evaluation: load adapters and run deterministic A/B generation
|
| 9 |
+
"""
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
import json
|
| 13 |
+
import math
|
| 14 |
+
import random
|
| 15 |
+
import threading
|
| 16 |
+
import tempfile
|
| 17 |
+
import time
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import List, Optional
|
| 20 |
+
|
| 21 |
+
import gradio as gr
|
| 22 |
+
# On Hugging Face Spaces Zero, `spaces` must be imported before CUDA-related modules.
|
| 23 |
+
if os.getenv("SPACE_ID"):
|
| 24 |
+
try:
|
| 25 |
+
import spaces # noqa: F401
|
| 26 |
+
except Exception:
|
| 27 |
+
pass
|
| 28 |
+
import torch
|
| 29 |
+
from loguru import logger
|
| 30 |
+
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
# Ensure project root is on sys.path so `acestep` imports work
|
| 33 |
+
# ---------------------------------------------------------------------------
|
| 34 |
+
PROJECT_ROOT = str(Path(__file__).resolve().parent)
|
| 35 |
+
if PROJECT_ROOT not in sys.path:
|
| 36 |
+
sys.path.insert(0, PROJECT_ROOT)
|
| 37 |
+
|
| 38 |
+
from acestep.handler import AceStepHandler
|
| 39 |
+
from acestep.audio_utils import AudioSaver
|
| 40 |
+
from acestep.llm_inference import LLMHandler
|
| 41 |
+
from acestep.inference import understand_music
|
| 42 |
+
from lora_train import (
|
| 43 |
+
LoRATrainConfig,
|
| 44 |
+
LoRATrainer,
|
| 45 |
+
TrackEntry,
|
| 46 |
+
scan_dataset_folder,
|
| 47 |
+
scan_uploaded_files,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
# Globals (shared across Gradio callbacks)
|
| 52 |
+
# ---------------------------------------------------------------------------
|
| 53 |
+
handler = AceStepHandler()
|
| 54 |
+
llm_handler = LLMHandler()
|
| 55 |
+
trainer: Optional[LoRATrainer] = None
|
| 56 |
+
dataset_entries: List[TrackEntry] = []
|
| 57 |
+
_training_thread: Optional[threading.Thread] = None
|
| 58 |
+
_training_log: List[str] = []
|
| 59 |
+
_training_status: str = "idle" # idle | running | stopped | done
|
| 60 |
+
_training_started_at: Optional[float] = None
|
| 61 |
+
_model_init_ok: bool = False
|
| 62 |
+
_model_init_status: str = ""
|
| 63 |
+
_last_model_init_args: Optional[dict] = None
|
| 64 |
+
_lm_init_ok: bool = False
|
| 65 |
+
_last_lm_init_args: Optional[dict] = None
|
| 66 |
+
_auto_label_cursor: int = 0
|
| 67 |
+
|
| 68 |
+
audio_saver = AudioSaver(default_format="wav")
|
| 69 |
+
IS_SPACE = bool(os.getenv("SPACE_ID"))
|
| 70 |
+
DEFAULT_OUTPUT_DIR = "/data/lora_output" if IS_SPACE else "lora_output"
|
| 71 |
+
|
| 72 |
+
if IS_SPACE:
|
| 73 |
+
try:
|
| 74 |
+
import spaces as _hf_spaces
|
| 75 |
+
_gpu_callback = _hf_spaces.GPU(duration=300)
|
| 76 |
+
except Exception:
|
| 77 |
+
_gpu_callback = lambda fn: fn
|
| 78 |
+
else:
|
| 79 |
+
_gpu_callback = lambda fn: fn
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _rows_from_entries(entries: List[TrackEntry]):
|
| 83 |
+
rows = []
|
| 84 |
+
for e in entries:
|
| 85 |
+
rows.append([
|
| 86 |
+
Path(e.audio_path).name,
|
| 87 |
+
f"{e.duration:.1f}s" if e.duration else "?",
|
| 88 |
+
e.caption or "(none)",
|
| 89 |
+
e.lyrics[:60] + "..." if len(e.lyrics) > 60 else (e.lyrics or "(none)"),
|
| 90 |
+
e.vocal_language,
|
| 91 |
+
])
|
| 92 |
+
return rows
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# ===========================================================================
# Tab 1 - Model Setup
# ===========================================================================

def get_available_models():
    models = handler.get_available_acestep_v15_models()
    return models if models else ["acestep-v15-base"]


def init_model(
    model_name: str,
    device: str,
    offload_cpu: bool,
    offload_dit_cpu: bool,
):
    global _model_init_ok, _model_init_status, _last_model_init_args
    _last_model_init_args = dict(
        project_root=PROJECT_ROOT,
        config_path=model_name,
        device=device,
        use_flash_attention=False,
        compile_model=False,
        offload_to_cpu=offload_cpu,
        offload_dit_to_cpu=offload_dit_cpu,
    )
    status, ok = _init_model_gpu(**_last_model_init_args)
    _model_init_ok = bool(ok)
    _model_init_status = status or ""
    return status


@_gpu_callback
def _init_model_gpu(**kwargs):
    return _init_model_impl(**kwargs)


def _init_model_impl(**kwargs):
    return handler.initialize_service(**kwargs)

# ===========================================================================
# Tab 2 - Dataset
# ===========================================================================

def scan_folder(folder_path: str):
    global dataset_entries, _auto_label_cursor
    if not folder_path or not os.path.isdir(folder_path):
        return "Provide a valid folder path.", []
    dataset_entries = scan_dataset_folder(folder_path)
    _auto_label_cursor = 0
    rows = _rows_from_entries(dataset_entries)
    msg = f"Found {len(dataset_entries)} audio files."
    return msg, rows


def load_uploaded(file_paths: List[str]):
    global dataset_entries, _auto_label_cursor
    if not file_paths:
        return "Drop audio files (and optional .json sidecars) first.", []
    sidecar_count = sum(
        1 for p in file_paths if isinstance(p, str) and Path(p).suffix.lower() == ".json"
    )
    dataset_entries = scan_uploaded_files(file_paths)
    _auto_label_cursor = 0
    rows = _rows_from_entries(dataset_entries)
    msg = (
        f"Loaded {len(dataset_entries)} dropped audio files."
        + (f" Matched {sidecar_count} uploaded sidecar JSON file(s)." if sidecar_count else "")
    )
    return msg, rows


def save_sidecar(index: int, caption: str, lyrics: str, bpm: str, keyscale: str, lang: str):
    """Save metadata edits back to a JSON sidecar and update the in-memory entry."""
    global dataset_entries
    index = int(index)  # Gradio Number components may deliver floats.
    if index < 0 or index >= len(dataset_entries):
        return "Invalid track index."
    entry = dataset_entries[index]
    entry.caption = caption
    entry.lyrics = lyrics
    if bpm.strip():
        try:
            entry.bpm = int(float(bpm))
        except ValueError:
            return "Invalid BPM value. Use an integer or leave empty."
    else:
        entry.bpm = None
    entry.keyscale = keyscale
    entry.vocal_language = lang

    sidecar_path = Path(entry.audio_path).with_suffix(".json")
    meta = {
        "caption": entry.caption,
        "lyrics": entry.lyrics,
        "bpm": entry.bpm,
        "keyscale": entry.keyscale,
        "timesignature": entry.timesignature,
        "vocal_language": entry.vocal_language,
        "duration": entry.duration,
    }
    sidecar_path.write_text(json.dumps(meta, indent=2, ensure_ascii=False), encoding="utf-8")
    return f"Saved sidecar for {Path(entry.audio_path).name}"

def init_auto_label_lm(lm_model_path: str, lm_backend: str, lm_device: str):
    global _lm_init_ok, _last_lm_init_args
    _last_lm_init_args = dict(
        lm_model_path=lm_model_path,
        lm_backend=lm_backend,
        lm_device=lm_device,
    )
    status = _init_auto_label_lm_gpu(**_last_lm_init_args)
    # Any failure path below returns a message starting with one of these prefixes.
    failure_prefixes = ("LM init failed:", "LM init exception:", "Failed to download LM model:")
    _lm_init_ok = not str(status).startswith(failure_prefixes)
    return status


@_gpu_callback
def _init_auto_label_lm_gpu(lm_model_path: str, lm_backend: str, lm_device: str):
    return _init_auto_label_lm_impl(lm_model_path, lm_backend, lm_device)


def _init_auto_label_lm_impl(lm_model_path: str, lm_backend: str, lm_device: str):
    """Initialize the LLM used for dataset auto-labeling."""
    checkpoint_dir = os.path.join(PROJECT_ROOT, "checkpoints")
    full_lm_path = os.path.join(checkpoint_dir, lm_model_path)

    try:
        if not os.path.exists(full_lm_path):
            from pathlib import Path as _Path
            from acestep.model_downloader import ensure_main_model, ensure_lm_model

            if lm_model_path == "acestep-5Hz-lm-1.7B":
                ok, msg = ensure_main_model(
                    checkpoints_dir=_Path(checkpoint_dir),
                    prefer_source="huggingface",
                )
            else:
                ok, msg = ensure_lm_model(
                    model_name=lm_model_path,
                    checkpoints_dir=_Path(checkpoint_dir),
                    prefer_source="huggingface",
                )
            if not ok:
                return f"Failed to download LM model: {msg}"

        status, ok = llm_handler.initialize(
            checkpoint_dir=checkpoint_dir,
            lm_model_path=lm_model_path,
            backend=lm_backend,
            device=lm_device,
            offload_to_cpu=False,
        )
        return status if ok else f"LM init failed:\n{status}"
    except Exception as exc:
        logger.exception("LM init failed for auto-label")
        return f"LM init exception: {exc}"

def _write_entry_sidecar(entry: TrackEntry):
    sidecar_path = Path(entry.audio_path).with_suffix(".json")
    meta = {
        "caption": entry.caption,
        "lyrics": entry.lyrics,
        "bpm": entry.bpm,
        "keyscale": entry.keyscale,
        "timesignature": entry.timesignature,
        "vocal_language": entry.vocal_language,
        "duration": entry.duration,
    }
    sidecar_path.write_text(json.dumps(meta, indent=2, ensure_ascii=False), encoding="utf-8")

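For reference, the sidecar written next to each audio file is plain JSON with the schema above. A minimal round-trip sketch, with made-up field values (the filename and values are illustrative only):

```python
import json
import tempfile
from pathlib import Path

# Illustrative sidecar payload; values are invented for this example.
meta = {
    "caption": "warm lo-fi hip hop beat with vinyl crackle",
    "lyrics": "",
    "bpm": 84,
    "keyscale": "Am",
    "timesignature": "4/4",
    "vocal_language": "en",
    "duration": 182.4,
}

with tempfile.TemporaryDirectory() as tmp:
    # The sidecar lives next to the audio file: same stem, .json suffix.
    sidecar = Path(tmp, "track01.wav").with_suffix(".json")
    sidecar.write_text(json.dumps(meta, indent=2, ensure_ascii=False), encoding="utf-8")
    loaded = json.loads(sidecar.read_text(encoding="utf-8"))
    print(loaded["bpm"], loaded["keyscale"])  # -> 84 Am
```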
@_gpu_callback
def auto_label_all(overwrite_existing: bool, caption_only: bool, max_files_per_run: int = 6, reset_cursor: bool = False):
    """Auto-label all loaded tracks using ACE audio understanding (audio -> codes -> metadata)."""
    global dataset_entries, _auto_label_cursor

    if handler.model is None:
        if _model_init_ok and _last_model_init_args:
            status, ok = _init_model_impl(**_last_model_init_args)
            if not ok:
                return f"Model reload failed before auto-label:\n{status}", [], "Auto-label skipped."
        else:
            return "Initialize model first in Step 1.", [], "Auto-label skipped."
    if not dataset_entries:
        return "Load dataset first in Step 2.", [], "Auto-label skipped."
    if not llm_handler.llm_initialized:
        if _lm_init_ok and _last_lm_init_args:
            status = _init_auto_label_lm_impl(**_last_lm_init_args)
            if not llm_handler.llm_initialized:
                return (
                    f"Auto-label LM reload failed:\n{status}",
                    _rows_from_entries(dataset_entries),
                    "Auto-label skipped.",
                )
        else:
            return "Initialize Auto-Label LM first.", _rows_from_entries(dataset_entries), "Auto-label skipped."

    if max_files_per_run <= 0:
        max_files_per_run = 6
    if reset_cursor:
        _auto_label_cursor = 0
    if _auto_label_cursor < 0 or _auto_label_cursor >= len(dataset_entries):
        _auto_label_cursor = 0

    start_idx = _auto_label_cursor
    end_idx = min(len(dataset_entries), start_idx + int(max_files_per_run))

    updated = 0
    skipped = 0
    failed = 0
    logs: List[str] = []

    for idx in range(start_idx, end_idx):
        entry = dataset_entries[idx]
        try:
            missing_fields = []
            if not (entry.caption or "").strip():
                missing_fields.append("caption")
            if (not caption_only) and (not (entry.lyrics or "").strip()):
                missing_fields.append("lyrics")
            if entry.bpm is None:
                missing_fields.append("bpm")
            if not (entry.keyscale or "").strip():
                missing_fields.append("keyscale")
            if entry.duration is None:
                missing_fields.append("duration")

            # Skip only when every core field is already available.
            if (not overwrite_existing) and (len(missing_fields) == 0):
                skipped += 1
                logs.append(f"[{idx}] Skipped (already fully labeled): {Path(entry.audio_path).name}")
                continue

            codes = handler.convert_src_audio_to_codes(entry.audio_path)
            if not codes or codes.startswith("❌"):
                failed += 1
                logs.append(f"[{idx}] Failed to convert audio to codes: {Path(entry.audio_path).name}")
                continue

            result = understand_music(
                llm_handler=llm_handler,
                audio_codes=codes,
                temperature=0.85,
                use_constrained_decoding=True,
                constrained_decoding_debug=False,
            )
            if not result.success:
                failed += 1
                logs.append(f"[{idx}] Failed to label: {Path(entry.audio_path).name} ({result.error or result.status_message})")
                continue

            # Update fields. If overwrite is false, fill only missing values.
            if overwrite_existing or not (entry.caption or "").strip():
                entry.caption = (result.caption or entry.caption or "").strip()
            if not caption_only:
                if overwrite_existing or not (entry.lyrics or "").strip():
                    entry.lyrics = (result.lyrics or entry.lyrics or "").strip()
            if entry.bpm is None and result.bpm is not None:
                entry.bpm = int(result.bpm)
            if (not entry.keyscale) and result.keyscale:
                entry.keyscale = result.keyscale
            if (not entry.timesignature) and result.timesignature:
                entry.timesignature = result.timesignature
            if (not entry.vocal_language) and result.language:
                entry.vocal_language = result.language
            if entry.duration is None and result.duration is not None:
                entry.duration = float(result.duration)

            _write_entry_sidecar(entry)
            updated += 1
            logs.append(f"[{idx}] Labeled: {Path(entry.audio_path).name}")
        except Exception as exc:
            failed += 1
            logs.append(f"[{idx}] Exception: {Path(entry.audio_path).name} ({exc})")

    _auto_label_cursor = 0 if end_idx >= len(dataset_entries) else end_idx
    mode = "caption-only" if caption_only else "caption+lyrics"
    progress_msg = (
        f"Processed batch {start_idx + 1}-{end_idx} of {len(dataset_entries)}. "
        if len(dataset_entries) > 0 else ""
    )
    if _auto_label_cursor == 0 and len(dataset_entries) > 0:
        progress_msg += "Reached end of dataset."
    else:
        progress_msg += f"Next start index: {_auto_label_cursor}."
    summary = (
        f"Auto-label ({mode}) complete. Updated={updated}, Skipped={skipped}, Failed={failed}. "
        f"{progress_msg}"
    )
    detail = "\n".join(logs[-40:]) if logs else "No logs."
    return summary, _rows_from_entries(dataset_entries), detail

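The resumable batching in `auto_label_all` (useful on Zero GPU, where each click labels at most `max_files_per_run` tracks) reduces to simple cursor arithmetic. A standalone sketch; `next_batch` is a hypothetical helper written just to show the wrap-around behavior:

```python
def next_batch(cursor, total, max_files_per_run=6):
    """Return (start, end, new_cursor) for one labeling pass over `total` tracks."""
    if cursor < 0 or cursor >= total:
        cursor = 0
    start = cursor
    end = min(total, cursor + max_files_per_run)
    # Cursor wraps to 0 once the end of the dataset is reached.
    new_cursor = 0 if end >= total else end
    return start, end, new_cursor

print(next_batch(0, 14))   # -> (0, 6, 6)
print(next_batch(6, 14))   # -> (6, 12, 12)
print(next_batch(12, 14))  # -> (12, 14, 0)  wrapped: next click restarts
```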
# ===========================================================================
# Tab 3 - Training
# ===========================================================================

def _run_training(config_dict: dict):
    """Target for the background training thread."""
    global trainer, _training_status, _training_log, _training_started_at
    _training_status = "running"
    _training_log.clear()
    _training_started_at = time.time()

    try:
        cfg = LoRATrainConfig(**config_dict)
        trainer = LoRATrainer(handler, cfg)
        trainer.prepare()
        _training_log.append(f"Training device: {handler.device}")

        def _cb(step, total, loss, epoch):
            elapsed = 0.0 if _training_started_at is None else max(0.0, time.time() - _training_started_at)
            rate = (step / elapsed) if elapsed > 0 else 0.0
            remaining = max(0, total - step)
            eta_sec = (remaining / rate) if rate > 0 else -1.0
            eta_msg = f"{eta_sec / 60:.1f}m" if eta_sec >= 0 else "unknown"
            msg = (
                f"Step {step}/{total} Epoch {epoch + 1} Loss {loss:.6f} "
                f"Elapsed {elapsed / 60:.1f}m ETA {eta_msg}"
            )
            _training_log.append(msg)

        result = trainer.train(dataset_entries, progress_callback=_cb)
        _training_log.append(result)
        _training_status = "done"
    except Exception as exc:
        _training_log.append(f"ERROR: {exc}")
        _training_status = "stopped"
        logger.exception("Training failed")

def start_training(
    lora_rank, lora_alpha, lora_dropout,
    lr, weight_decay, optimizer_name,
    max_grad_norm, warmup_ratio, scheduler_name,
    num_epochs, batch_size, grad_accum,
    save_every, log_every, shift,
    max_duration, output_dir, resume_from,
):
    global _training_thread, _training_status

    if handler.model is None:
        return "Model not initialized. Go to Model Setup first."
    if not dataset_entries:
        return "No dataset loaded. Go to Dataset tab first."
    if _training_status == "running":
        return "Training already in progress."

    config_dict = dict(
        lora_rank=int(lora_rank),
        lora_alpha=int(lora_alpha),
        lora_dropout=float(lora_dropout),
        learning_rate=float(lr),
        weight_decay=float(weight_decay),
        optimizer=optimizer_name,
        max_grad_norm=float(max_grad_norm),
        warmup_ratio=float(warmup_ratio),
        scheduler=scheduler_name,
        num_epochs=int(num_epochs),
        batch_size=int(batch_size),
        gradient_accumulation_steps=int(grad_accum),
        save_every_n_epochs=int(save_every),
        log_every_n_steps=int(log_every),
        shift=float(shift),
        max_duration_sec=float(max_duration),
        output_dir=output_dir,
        resume_from=(resume_from.strip() if isinstance(resume_from, str) and resume_from.strip() else None),
        device=str(handler.device),
    )

    steps_per_epoch = math.ceil(len(dataset_entries) / int(batch_size))
    total_steps = steps_per_epoch * int(num_epochs)
    total_optim_steps = math.ceil(total_steps / int(grad_accum))

    _training_thread = threading.Thread(target=_run_training, args=(config_dict,), daemon=True)
    _training_thread.start()
    return (
        f"Training started on {handler.device}. "
        f"Estimated optimizer steps: {total_optim_steps}."
    )


def stop_training():
    global trainer, _training_status
    if trainer:
        trainer.request_stop()
        _training_status = "stopped"
        return "Stop requested - will finish the current step."
    return "No training in progress."

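The step estimate reported by `start_training` is plain ceiling arithmetic. A worked example with sample numbers (7 tracks, batch size 2, 50 epochs, gradient accumulation 4; the numbers are illustrative, not defaults):

```python
import math

num_tracks, batch_size, num_epochs, grad_accum = 7, 2, 50, 4

# A partial final batch still counts as a step, hence the ceiling.
steps_per_epoch = math.ceil(num_tracks / batch_size)       # 4
total_steps = steps_per_epoch * num_epochs                 # 200
# One optimizer update per `grad_accum` forward/backward steps.
total_optim_steps = math.ceil(total_steps / grad_accum)    # 50

print(steps_per_epoch, total_steps, total_optim_steps)  # -> 4 200 50
```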
def poll_training():
    """Return the current log plus loss-chart data."""
    log_text = "\n".join(_training_log[-50:]) if _training_log else "(no output yet)"

    # Build loss curve data
    chart_data = []
    if trainer and trainer.loss_history:
        chart_data = [[h["step"], h["loss"]] for h in trainer.loss_history]

    status = _training_status
    device_line = f"Device: {handler.device}"
    if torch.cuda.is_available() and str(handler.device).startswith("cuda"):
        try:
            idx = torch.cuda.current_device()
            name = torch.cuda.get_device_name(idx)
            allocated = torch.cuda.memory_allocated(idx) / (1024 ** 3)
            reserved = torch.cuda.memory_reserved(idx) / (1024 ** 3)
            device_line = (
                f"Device: {handler.device} ({name}) | "
                f"VRAM allocated={allocated:.2f}GB reserved={reserved:.2f}GB"
            )
        except Exception:
            pass

    return f"Status: {status}\n{device_line}\n\n{log_text}", chart_data

# ===========================================================================
# Tab 4 - Evaluation / A-B Test
# ===========================================================================

def list_adapters(output_dir: str):
    adapters = LoRATrainer.list_adapters(output_dir)
    return adapters if adapters else ["(none found)"]


@_gpu_callback
def load_adapter(adapter_path: str):
    if not adapter_path or adapter_path == "(none found)":
        return "Select a valid adapter path."
    return handler.load_lora(adapter_path)


@_gpu_callback
def unload_adapter():
    return handler.unload_lora()


def set_lora_scale(scale: float):
    return handler.set_lora_scale(scale)


@_gpu_callback
def generate_sample(
    prompt: str,
    lyrics: str,
    duration: float,
    bpm: int,
    steps: int,
    guidance: float,
    seed: int,
    use_lora: bool,
    lora_scale: float,
):
    """Generate a single audio sample for evaluation."""
    if handler.model is None:
        return None, "Model not initialized."

    # Toggle LoRA if loaded
    if handler.lora_loaded:
        handler.set_use_lora(use_lora)
        if use_lora:
            handler.set_lora_scale(lora_scale)

    actual_seed = int(seed) if seed >= 0 else random.randint(0, 2**32 - 1)

    result = handler.generate_music(
        captions=prompt,
        lyrics=lyrics,
        bpm=bpm if bpm > 0 else None,
        inference_steps=steps,
        guidance_scale=guidance,
        use_random_seed=False,
        seed=actual_seed,
        audio_duration=duration,
        batch_size=1,
    )

    if not result.get("success", False):
        return None, result.get("error", "Generation failed.")

    audios = result.get("audios", [])
    if not audios:
        return None, "No audio produced."

    # Save to a temp file
    audio_data = audios[0]
    wav_tensor = audio_data.get("tensor")
    sr = audio_data.get("sample_rate", 48000)

    if wav_tensor is None:
        path = audio_data.get("path")
        if path and os.path.exists(path):
            return path, f"Generated (from file), seed={actual_seed}."
        return None, "No audio tensor."

    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()  # Close the open handle so the saver can write on Windows too.
    audio_saver.save_audio(wav_tensor, tmp.name, sample_rate=sr)
    return tmp.name, f"Generated successfully, seed={actual_seed}."


@_gpu_callback
def ab_test(
    prompt, lyrics, duration, bpm, steps, guidance, seed,
    lora_scale_b,
):
    """Generate two samples with the same seed: A = base model, B = LoRA at the given scale."""
    resolved_seed = int(seed) if seed >= 0 else random.randint(0, 2**32 - 1)
    results = {}
    for label, use, scale in [("A (base)", False, 0.0), ("B (LoRA)", True, lora_scale_b)]:
        path, msg = generate_sample(
            prompt, lyrics, duration, bpm, steps, guidance, resolved_seed,
            use_lora=use, lora_scale=scale,
        )
        results[label] = (path, msg)

    return (
        results["A (base)"][0],
        results["A (base)"][1],
        results["B (LoRA)"][0],
        results["B (LoRA)"][1],
    )

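Seed handling is the key to a fair A/B comparison: `ab_test` resolves one seed up front (a negative seed means "draw a random one once") and reuses it for both renders, so the only difference between A and B is the adapter. A minimal sketch; `resolve_seed` is a hypothetical helper extracted for illustration:

```python
import random

def resolve_seed(seed: int) -> int:
    """Negative seed -> draw a fresh random seed; otherwise use it verbatim."""
    return int(seed) if seed >= 0 else random.randint(0, 2**32 - 1)

# Fixed seeds pass through unchanged, so a run is reproducible.
print(resolve_seed(1234))  # -> 1234

# A negative seed is resolved once, then shared by both A and B renders.
shared = resolve_seed(-1)
seeds_for_ab = (shared, shared)
```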
# ===========================================================================
# Build the Gradio App
# ===========================================================================

def get_workflow_status():
    model_is_ready = (handler.model is not None) or _model_init_ok
    model_ready = "Ready" if model_is_ready else "Not initialized"
    tracks = len(dataset_entries)
    training_state = _training_status
    lora_status = handler.get_lora_status() if handler.model is not None else {"loaded": False, "active": False, "scale": 1.0}
    init_note = ""
    if IS_SPACE and _model_init_ok and handler.model is None:
        init_note = " (Zero GPU callback context)"
    return (
        f"Model: {model_ready}{init_note}\n"
        f"Tracks Loaded: {tracks}\n"
        f"Training: {training_state}\n"
        f"LoRA Loaded: {lora_status.get('loaded', False)}\n"
        f"LoRA Active: {lora_status.get('active', False)}\n"
        f"LoRA Scale: {lora_status.get('scale', 1.0)}"
    )


def init_model_and_status(
    model_name: str,
    device: str,
    offload_cpu: bool,
    offload_dit_cpu: bool,
):
    status = init_model(model_name, device, offload_cpu, offload_dit_cpu)
    return status, get_workflow_status()

def build_ui():
    available_models = get_available_models()

    with gr.Blocks(title="ACE-Step 1.5 LoRA Studio", theme=gr.themes.Soft()) as app:
        gr.Markdown(
            "# ACE-Step 1.5 LoRA Studio\n"
            "Use this guided workflow from left to right.\n\n"
            "**Step 1:** Initialize model  \n"
            "**Step 2:** Load dataset  \n"
            "**Step 3:** Start training  \n"
            "**Step 4:** Evaluate adapter"
        )
        with gr.Row():
            workflow_status = gr.Textbox(label="Workflow Status", value=get_workflow_status(), lines=6, interactive=False)
            refresh_status_btn = gr.Button("Refresh Status")
        refresh_status_btn.click(get_workflow_status, outputs=workflow_status, api_name="workflow_status")

        # ---- Step 1 ----
        with gr.Tab("Step 1 - Initialize Model"):
            gr.Markdown(
                "### Instructions\n"
                "1. Pick a model (`acestep-v15-base` recommended for LoRA).\n"
                "2. Keep device on `auto` unless you need manual override.\n"
                "3. Click **Initialize Model** and confirm status is success."
            )
            with gr.Row():
                model_dd = gr.Dropdown(
                    choices=available_models,
                    value=available_models[0] if available_models else None,
                    label="DiT Model",
                )
                device_dd = gr.Dropdown(
                    choices=["auto", "cuda", "mps", "cpu"],
                    value="auto",
                    label="Device",
                )
            with gr.Row():
                offload_cb = gr.Checkbox(label="Offload To CPU (optional)", value=False)
                offload_dit_cb = gr.Checkbox(label="Offload DiT To CPU (optional)", value=False)
            init_btn = gr.Button("Initialize Model", variant="primary")
            init_out = gr.Textbox(label="Initialization Output", lines=8, interactive=False)
            init_btn.click(
                init_model_and_status,
                [model_dd, device_dd, offload_cb, offload_dit_cb],
                [init_out, workflow_status],
                api_name="init_model",
            )

        # ---- Step 2 ----
        with gr.Tab("Step 2 - Load Dataset"):
            gr.Markdown(
                "### Instructions\n"
                "1. Either scan a folder or drag/drop audio files (+ optional .json sidecars).\n"
                "2. Confirm tracks appear in the table.\n"
                "3. Optional: run Auto-Label All to fill caption/lyrics/metadata.\n"
                "4. Optional: edit metadata manually and save the sidecar JSON."
            )
            with gr.Row():
                folder_input = gr.Textbox(label="Dataset Folder Path", placeholder="e.g. ./dataset_inbox")
                scan_btn = gr.Button("Scan Folder")
            with gr.Row():
                upload_files = gr.Files(
                    label="Drag/Drop Audio Files (+ Optional JSON Sidecars)",
                    file_count="multiple",
                    file_types=["audio", ".json"],
                    type="filepath",
                )
                upload_btn = gr.Button("Load Dropped Files")
            scan_msg = gr.Textbox(label="Dataset Result", interactive=False)
            dataset_table = gr.Dataframe(
                headers=["File", "Duration", "Caption", "Lyrics", "Language"],
                datatype=["str", "str", "str", "str", "str"],
                label="Tracks",
                interactive=False,
            )
            scan_btn.click(
                scan_folder,
                folder_input,
                [scan_msg, dataset_table],
                api_name="scan_folder",
            )
            upload_btn.click(
                load_uploaded,
                upload_files,
                [scan_msg, dataset_table],
                api_name="load_uploaded",
            )

            with gr.Accordion("Auto-Label (ACE audio understanding)", open=False):
                gr.Markdown(
                    "Auto-label uses ACE: audio -> semantic codes -> metadata/lyrics.\n"
                    "Initialize the LM first, then run Auto-Label All.\n"
                    "Use Caption-Only if your dataset has no lyrics.\n"
                    "On Zero GPU, process in smaller batches and click Auto-Label All repeatedly."
                )
                with gr.Row():
                    lm_model_dd = gr.Dropdown(
                        choices=["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B", "acestep-5Hz-lm-4B"],
                        value="acestep-5Hz-lm-0.6B",
                        label="Auto-Label LM Model",
                    )
                    lm_backend_dd = gr.Dropdown(
                        choices=["pt", "vllm", "mlx"],
                        value="pt",
                        label="LM Backend",
                    )
                    lm_device_dd = gr.Dropdown(
                        choices=["auto", "cuda", "mps", "xpu", "cpu"],
                        value="auto",
                        label="LM Device",
                    )
                with gr.Row():
                    init_lm_btn = gr.Button("Initialize Auto-Label LM")
                    overwrite_cb = gr.Checkbox(label="Overwrite Existing Caption/Lyrics", value=False)
                    caption_only_cb = gr.Checkbox(label="Caption-Only (Skip Lyrics)", value=True)
                    auto_label_btn = gr.Button("Auto-Label All", variant="primary")
                with gr.Row():
                    max_files_per_run = gr.Slider(1, 25, value=6, step=1, label="Files Per Run (Zero GPU Safe)")
                    reset_cursor_cb = gr.Checkbox(label="Restart From First Track", value=False)
                lm_init_status = gr.Textbox(label="Auto-Label LM Status", lines=5, interactive=False)
                auto_label_status = gr.Textbox(label="Auto-Label Summary", interactive=False)
                auto_label_log = gr.Textbox(label="Auto-Label Log", lines=8, interactive=False)
                init_lm_btn.click(
                    init_auto_label_lm,
                    [lm_model_dd, lm_backend_dd, lm_device_dd],
                    lm_init_status,
                    api_name="init_auto_label_lm",
                )
                auto_label_btn.click(
                    auto_label_all,
                    [overwrite_cb, caption_only_cb, max_files_per_run, reset_cursor_cb],
                    [auto_label_status, dataset_table, auto_label_log],
                    api_name="auto_label_all",
                )

            with gr.Accordion("Optional: Edit Metadata Sidecar", open=False):
                with gr.Row():
                    edit_idx = gr.Number(label="Track Index (0-based)", value=0, precision=0)
                    edit_caption = gr.Textbox(label="Caption")
                edit_lyrics = gr.Textbox(label="Lyrics", lines=3)
                with gr.Row():
                    edit_bpm = gr.Textbox(label="BPM", placeholder="e.g. 120")
                    edit_key = gr.Textbox(label="Key/Scale", placeholder="e.g. Am")
                    edit_lang = gr.Textbox(label="Language", value="en")
                save_btn = gr.Button("Save Sidecar")
                save_msg = gr.Textbox(label="Save Result", interactive=False)
                save_btn.click(
                    save_sidecar,
                    [edit_idx, edit_caption, edit_lyrics, edit_bpm, edit_key, edit_lang],
                    save_msg,
                    api_name="save_sidecar",
                )

        # ---- Step 3 ----
        with gr.Tab("Step 3 - Train LoRA"):
            gr.Markdown(
                "### Instructions\n"
                "1. Keep default settings for a first run.\n"
                "2. Set the output directory (defaults are good).\n"
                "3. Click **Start Training** and monitor logs/loss.\n"
                "4. Use **Stop Training** for a graceful stop."
            )
            with gr.Row():
                t_epochs = gr.Slider(1, 500, value=50, step=1, label="Epochs")
                t_bs = gr.Slider(1, 8, value=1, step=1, label="Batch Size")
                t_accum = gr.Slider(1, 16, value=1, step=1, label="Grad Accumulation")
            with gr.Row():
                t_outdir = gr.Textbox(label="Output Directory", value=DEFAULT_OUTPUT_DIR)
                t_resume = gr.Textbox(label="Resume From Adapter Directory (optional)", value="")

            with gr.Accordion("Advanced Training Settings (optional)", open=False):
                with gr.Row():
                    t_rank = gr.Slider(4, 256, value=64, step=4, label="LoRA Rank")
                    t_alpha = gr.Slider(4, 256, value=64, step=4, label="LoRA Alpha")
                    t_dropout = gr.Slider(0.0, 0.5, value=0.1, step=0.01, label="LoRA Dropout")
                with gr.Row():
                    t_lr = gr.Number(label="Learning Rate", value=1e-4)
                    t_wd = gr.Number(label="Weight Decay", value=0.01)
                    t_optim = gr.Dropdown(["adamw", "adamw_8bit"], value="adamw_8bit", label="Optimizer")
                with gr.Row():
                    t_grad_norm = gr.Number(label="Max Grad Norm", value=1.0)
                    t_warmup = gr.Number(label="Warmup Ratio", value=0.03)
                    t_sched = gr.Dropdown(
                        ["constant_with_warmup", "linear", "cosine"],
                        value="constant_with_warmup",
                        label="Scheduler",
                    )
                with gr.Row():
                    t_save = gr.Slider(1, 100, value=10, step=1, label="Save Every N Epochs")
                    t_log = gr.Slider(1, 100, value=5, step=1, label="Log Every N Steps")
                    t_shift = gr.Number(label="Timestep Shift", value=3.0)
                    t_maxdur = gr.Number(label="Max Audio Duration (s)", value=240)

            with gr.Row():
                train_btn = gr.Button("Start Training", variant="primary")
                stop_btn = gr.Button("Stop Training", variant="stop")
                poll_btn = gr.Button("Refresh Log")

            train_status = gr.Textbox(label="Training Log", lines=12, interactive=False)
            loss_chart = gr.LinePlot(
                x="Step",
                y="Loss",
                title="Training Loss",
                x_title="Step",
                y_title="Loss",
            )

            train_btn.click(
                start_training,
                [
                    t_rank, t_alpha, t_dropout,
                    t_lr, t_wd, t_optim,
                    t_grad_norm, t_warmup, t_sched,
                    t_epochs, t_bs, t_accum,
                    t_save, t_log, t_shift,
                    t_maxdur, t_outdir, t_resume,
                ],
                train_status,
                api_name="start_training",
            )
            stop_btn.click(stop_training, outputs=train_status, api_name="stop_training")

            def _poll_and_format():
                import pandas as pd

                log_text, chart_data = poll_training()
                if chart_data:
                    df = pd.DataFrame(chart_data, columns=["Step", "Loss"])
                else:
                    df = pd.DataFrame({"Step": [], "Loss": []})
                return log_text, df

            poll_btn.click(_poll_and_format, outputs=[train_status, loss_chart], api_name="poll_training")

# ---- Step 4 ----
|
| 889 |
+
with gr.Tab("Step 4 - Evaluate"):
|
| 890 |
+
gr.Markdown(
|
| 891 |
+
"### Instructions\n"
|
| 892 |
+
"1. Refresh adapter list and load a trained adapter.\n"
|
| 893 |
+
"2. Run single generation or A/B test.\n"
|
| 894 |
+
"3. Use same seed for fair comparison."
|
| 895 |
+
)
|
| 896 |
+
|
| 897 |
+
with gr.Accordion("Adapter Management", open=True):
|
| 898 |
+
with gr.Row():
|
| 899 |
+
adapter_dir = gr.Textbox(label="Adapters Directory", value=DEFAULT_OUTPUT_DIR)
|
| 900 |
+
refresh_btn = gr.Button("Refresh List")
|
| 901 |
+
adapter_dd = gr.Dropdown(label="Select Adapter", choices=[])
|
| 902 |
+
with gr.Row():
|
| 903 |
+
load_btn = gr.Button("Load Adapter", variant="primary")
|
| 904 |
+
unload_btn = gr.Button("Unload Adapter")
|
| 905 |
+
adapter_status = gr.Textbox(label="Adapter Status", interactive=False)
|
| 906 |
+
|
| 907 |
+
def _refresh(d):
|
| 908 |
+
adapters = list_adapters(d)
|
| 909 |
+
return gr.update(choices=adapters, value=adapters[0] if adapters else None)
|
| 910 |
+
|
| 911 |
+
refresh_btn.click(_refresh, adapter_dir, adapter_dd, api_name="list_adapters")
|
| 912 |
+
load_btn.click(load_adapter, adapter_dd, adapter_status, api_name="load_adapter")
|
| 913 |
+
unload_btn.click(unload_adapter, outputs=adapter_status, api_name="unload_adapter")
|
| 914 |
+
|
| 915 |
+
with gr.Accordion("Generation Settings", open=True):
|
| 916 |
+
with gr.Row():
|
| 917 |
+
eval_prompt = gr.Textbox(label="Prompt / Caption", lines=2, placeholder="upbeat pop rock with electric guitar")
|
| 918 |
+
eval_lyrics = gr.Textbox(label="Lyrics", lines=3, placeholder="[Instrumental]")
|
| 919 |
+
with gr.Row():
|
| 920 |
+
eval_dur = gr.Slider(10, 300, value=30, step=5, label="Duration (s)")
|
| 921 |
+
eval_bpm = gr.Number(label="BPM (0 = auto)", value=0)
|
| 922 |
+
eval_steps = gr.Slider(1, 100, value=8, step=1, label="Inference Steps")
|
| 923 |
+
with gr.Row():
|
| 924 |
+
eval_guidance = gr.Slider(1.0, 15.0, value=7.0, step=0.5, label="Guidance Scale")
|
| 925 |
+
eval_seed = gr.Number(label="Seed (-1 = random)", value=-1)
|
| 926 |
+
|
| 927 |
+
with gr.Row():
|
| 928 |
+
sg_use_lora = gr.Checkbox(label="Use LoRA", value=True)
|
| 929 |
+
sg_scale = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="LoRA Scale")
|
| 930 |
+
sg_btn = gr.Button("Generate", variant="primary")
|
| 931 |
+
sg_audio = gr.Audio(label="Single Output", type="filepath")
|
| 932 |
+
sg_msg = gr.Textbox(label="Generation Status", interactive=False)
|
| 933 |
+
sg_btn.click(
|
| 934 |
+
generate_sample,
|
| 935 |
+
[eval_prompt, eval_lyrics, eval_dur, eval_bpm, eval_steps, eval_guidance, eval_seed, sg_use_lora, sg_scale],
|
| 936 |
+
[sg_audio, sg_msg],
|
| 937 |
+
api_name="generate_sample",
|
| 938 |
+
)
|
| 939 |
+
|
| 940 |
+
gr.Markdown("#### A/B Test (Base vs LoRA)")
|
| 941 |
+
with gr.Row():
|
| 942 |
+
ab_scale = gr.Slider(0.0, 1.0, value=1.0, step=0.05, label="LoRA Scale for B")
|
| 943 |
+
ab_btn = gr.Button("Run A/B Test")
|
| 944 |
+
with gr.Row():
|
| 945 |
+
ab_audio_a = gr.Audio(label="A - Base", type="filepath")
|
| 946 |
+
ab_audio_b = gr.Audio(label="B - Base + LoRA", type="filepath")
|
| 947 |
+
with gr.Row():
|
| 948 |
+
ab_msg_a = gr.Textbox(label="Status A", interactive=False)
|
| 949 |
+
ab_msg_b = gr.Textbox(label="Status B", interactive=False)
|
| 950 |
+
|
| 951 |
+
ab_btn.click(
|
| 952 |
+
ab_test,
|
| 953 |
+
[eval_prompt, eval_lyrics, eval_dur, eval_bpm, eval_steps, eval_guidance, eval_seed, ab_scale],
|
| 954 |
+
[ab_audio_a, ab_msg_a, ab_audio_b, ab_msg_b],
|
| 955 |
+
api_name="ab_test",
|
| 956 |
+
)
|
| 957 |
+
|
| 958 |
+
app.queue(default_concurrency_limit=1)
|
| 959 |
+
return app
|
| 960 |
+
|
| 961 |
+
|
| 962 |
+
# ===========================================================================
|
| 963 |
+
# Entry point
|
| 964 |
+
# ===========================================================================
|
| 965 |
+
|
| 966 |
+
if __name__ == "__main__":
|
| 967 |
+
app = build_ui()
|
| 968 |
+
app.launch(
|
| 969 |
+
server_name="0.0.0.0",
|
| 970 |
+
server_port=7860,
|
| 971 |
+
share=False,
|
| 972 |
+
)
|
| 973 |
+
|
packages.txt
ADDED
@@ -0,0 +1,2 @@
+ffmpeg
+libsndfile1
requirements.txt
CHANGED
@@ -14,3 +14,7 @@ vector-quantize-pytorch
 PyYAML
 modelscope
 filelock>=3.13.0
+peft>=0.11.0
+gradio>=4.0.0
+pandas
+bitsandbytes
scripts/endpoint/generate_interactive.py
ADDED
@@ -0,0 +1,223 @@
+import argparse
+import base64
+import json
+import os
+import sys
+import time
+from pathlib import Path
+from urllib.error import HTTPError, URLError
+from urllib.request import Request, urlopen
+
+DEFAULT_URL = "https://your-endpoint-url.endpoints.huggingface.cloud"
+DEFAULT_SAMPLE_RATE = 44100
+
+
+def read_dotenv_value(key: str, dotenv_path: str = ".env") -> str:
+    path = Path(dotenv_path)
+    if not path.exists():
+        return ""
+    for raw in path.read_text(encoding="utf-8").splitlines():
+        line = raw.strip()
+        if not line or line.startswith("#") or "=" not in line:
+            continue
+        k, v = line.split("=", 1)
+        if k.strip() == key:
+            return v.strip().strip('"').strip("'")
+    return ""
+
+
+def prompt_text(label: str, default: str = "", required: bool = False) -> str:
+    while True:
+        suffix = f" [{default}]" if default else ""
+        value = input(f"{label}{suffix}: ").strip()
+        if not value:
+            value = default
+        if value or not required:
+            return value
+        print("Value required.")
+
+
+def prompt_int(label: str, default: int | None = None, allow_blank: bool = False) -> int | None:
+    while True:
+        default_str = "" if default is None else str(default)
+        value = prompt_text(label, default_str, required=not allow_blank)
+        if not value and allow_blank:
+            return None
+        try:
+            return int(value)
+        except ValueError:
+            print("Enter a valid integer.")
+
+
+def prompt_float(label: str, default: float) -> float:
+    while True:
+        value = prompt_text(label, str(default), required=True)
+        try:
+            return float(value)
+        except ValueError:
+            print("Enter a valid number.")
+
+
+def prompt_yes_no(label: str, default: bool) -> bool:
+    default_text = "y" if default else "n"
+    while True:
+        value = prompt_text(f"{label} (y/n)", default_text, required=True).lower()
+        if value in {"y", "yes", "1", "true", "t"}:
+            return True
+        if value in {"n", "no", "0", "false", "f"}:
+            return False
+        print("Please answer y or n.")
+
+
+def prompt_multiline(label: str, end_token: str = "END") -> str:
+    print(label)
+    print(f"Finish lyrics by typing {end_token} on its own line.")
+    lines: list[str] = []
+    while True:
+        line = input()
+        if line.strip() == end_token:
+            break
+        lines.append(line)
+    return "\n".join(lines).strip()
+
+
+def prompt_lyrics_optional() -> str:
+    use_lyrics = prompt_yes_no("Add custom lyrics", True)
+    if not use_lyrics:
+        return ""
+    return prompt_multiline("Paste lyrics (or just type END for none)")
+
+
+def send_request(url: str, token: str, payload: dict) -> dict:
+    data = json.dumps(payload).encode("utf-8")
+    req = Request(
+        url=url,
+        data=data,
+        method="POST",
+        headers={
+            "Authorization": f"Bearer {token}",
+            "Content-Type": "application/json",
+        },
+    )
+    try:
+        with urlopen(req, timeout=3600) as resp:
+            body = resp.read().decode("utf-8")
+            return json.loads(body)
+    except HTTPError as e:
+        text = e.read().decode("utf-8", errors="replace")
+        raise RuntimeError(f"HTTP {e.code}: {text}") from e
+    except URLError as e:
+        raise RuntimeError(f"Network error: {e}") from e
+
+
+def resolve_token(cli_token: str) -> str:
+    if cli_token:
+        return cli_token
+    env_token = os.getenv("HF_TOKEN") or os.getenv("hf_token")
+    if env_token:
+        return env_token
+    dotenv_token = read_dotenv_value("hf_token") or read_dotenv_value("HF_TOKEN")
+    return dotenv_token
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser(description="Interactive ACE-Step endpoint generator")
+    parser.add_argument("--url", default=os.getenv("HF_ENDPOINT_URL", DEFAULT_URL), help="Inference endpoint URL")
+    parser.add_argument("--token", default="", help="HF token. If omitted, uses env/.env")
+    parser.add_argument("--prompt", default="", help="Initial prompt")
+    parser.add_argument("--out-file", default="", help="Output WAV file path")
+    parser.add_argument(
+        "--advanced",
+        action="store_true",
+        help="Ask advanced generation options (seed/guidance/steps/sample-rate/LM).",
+    )
+    args = parser.parse_args()
+
+    print("=== ACE-Step Interactive Generation ===")
+
+    token = resolve_token(args.token)
+    if not token:
+        print("No token found. Set HF_TOKEN or hf_token in .env, or pass --token.")
+        return 1
+
+    url = prompt_text("Endpoint URL", args.url, required=True)
+    music_prompt = prompt_text("Music prompt", args.prompt, required=True)
+    bpm = prompt_int("BPM (blank for auto)", None, allow_blank=True)
+    duration_sec = prompt_int("Duration seconds", 120)
+    instrumental = prompt_yes_no("Instrumental (no vocals)", False)
+    lyrics = "" if instrumental else prompt_lyrics_optional()
+
+    # Quality-first defaults: use SFT + LM path configured on the endpoint.
+    sample_rate = DEFAULT_SAMPLE_RATE
+    steps = 50
+    guidance_scale = 7.0
+    seed = 42
+    use_lm = True
+    allow_fallback = False
+    simple_prompt = False
+
+    if args.advanced:
+        print("\nAdvanced options:")
+        sample_rate = prompt_int("Sample rate", DEFAULT_SAMPLE_RATE)
+        steps = prompt_int("Steps", 50)
+        guidance_scale = prompt_float("Guidance scale", 7.0)
+        seed = prompt_int("Seed", 42)
+        use_lm = prompt_yes_no("Use LM planning (higher quality, slower)", True)
+        allow_fallback = prompt_yes_no("Allow fallback sine audio", False)
+
+    default_out = args.out_file or f"music_{int(time.time())}.wav"
+    out_file = prompt_text("Output file", default_out, required=True)
+
+    inputs = {
+        "prompt": music_prompt,
+        "duration_sec": duration_sec,
+        "sample_rate": sample_rate,
+        "seed": seed,
+        "guidance_scale": guidance_scale,
+        "steps": steps,
+        "use_lm": use_lm,
+        "simple_prompt": simple_prompt,
+        "instrumental": instrumental,
+        "allow_fallback": allow_fallback,
+    }
+    if bpm is not None:
+        inputs["bpm"] = bpm
+    if lyrics:
+        inputs["lyrics"] = lyrics
+
+    payload = {"inputs": inputs}
+
+    print("\nSending request...")
+    try:
+        response = send_request(url, token, payload)
+    except Exception as e:
+        print(f"Request failed: {e}")
+        return 1
+
+    print("Response summary:")
+    print(json.dumps({
+        "used_fallback": response.get("used_fallback"),
+        "model_loaded": response.get("model_loaded"),
+        "model_error": response.get("model_error"),
+        "sample_rate": response.get("sample_rate"),
+        "duration_sec": response.get("duration_sec"),
+    }, indent=2))
+
+    if response.get("error"):
+        print(f"Endpoint error: {response['error']}")
+        return 1
+
+    audio_b64 = response.get("audio_base64_wav")
+    if not audio_b64:
+        print("No audio_base64_wav in response.")
+        return 1
+
+    audio_bytes = base64.b64decode(audio_b64)
+    Path(out_file).write_bytes(audio_bytes)
+    print(f"Saved audio: {out_file}")
+
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
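The script's final step — pulling `audio_base64_wav` out of the JSON response and writing the decoded bytes to disk — can be exercised without a live endpoint. A self-contained sketch where the response payload is faked locally (the `fake_wav` bytes are a stand-in, not a real WAV header):

```python
import base64
import json

# Fake the endpoint's JSON response: the handler base64-encodes the WAV bytes.
fake_wav = b"RIFF....WAVEfmt "  # stand-in bytes, not a playable file
raw = json.dumps({"audio_base64_wav": base64.b64encode(fake_wav).decode("ascii")})

# Client side, as in generate_interactive.py: parse JSON, decode the field.
response = json.loads(raw)
audio_bytes = base64.b64decode(response["audio_base64_wav"])
assert audio_bytes == fake_wav  # round-trip preserves the payload exactly
```

The same decode step appears in `test.ps1` via `[Convert]::FromBase64String`, so both clients agree on the wire format.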
scripts/endpoint/test.bat
ADDED
@@ -0,0 +1,16 @@
+@echo off
+setlocal
+
+set "PROMPT=%*"
+if not defined PROMPT set "PROMPT=upbeat pop rap with emotional guitar"
+set "PROMPT=%PROMPT:"=%"
+
+powershell -NoProfile -ExecutionPolicy Bypass -File "%~dp0test.ps1" -Prompt "%PROMPT%" -SimplePrompt -DurationSec 12 -SampleRate 44100 -Seed 42 -GuidanceScale 7.0 -Steps 50 -UseLM 1 -OutFile "test_music.wav"
+
+if errorlevel 1 (
+    echo Request failed.
+    exit /b 1
+)
+
+echo Done.
+endlocal
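The batch line `set "PROMPT=%PROMPT:"=%"` strips every double quote from the prompt before handing it to PowerShell, so quoted arguments cannot break the command line. The same transformation in Python (`strip_quotes` is a hypothetical helper shown only to make the substitution syntax concrete):

```python
def strip_quotes(prompt: str) -> str:
    # Equivalent of the batch substitution %PROMPT:"=% :
    # replace every double-quote character with nothing.
    return prompt.replace('"', "")

cleaned = strip_quotes('say "hello" world')
assert cleaned == "say hello world"
```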
scripts/endpoint/test.ps1
ADDED
@@ -0,0 +1,257 @@
+param(
+    [string]$Token = "",
+    [string]$Url = "",
+    [string]$Prompt = "upbeat pop rap with emotional guitar",
+    [string]$Lyrics = "",
+    [int]$DurationSec = 3,
+    [int]$SampleRate = 44100,
+    [int]$Seed = 42,
+    [double]$GuidanceScale = 7.0,
+    [int]$Steps = 50,
+    [string]$UseLM = "true",
+    [switch]$SimplePrompt,
+    [switch]$Instrumental,
+    [switch]$RnbLoveTemptation,
+    [switch]$RnbPopRap2Min,
+    [switch]$AllowFallback,
+    [string]$OutFile = "test_music.wav"
+)
+
+$ErrorActionPreference = "Stop"
+
+if (-not $Token) {
+    $Token = $env:HF_TOKEN
+}
+if (-not $Url) {
+    $Url = $env:HF_ENDPOINT_URL
+}
+
+if (-not $Token) {
+    throw "HF token not provided. Use -Token or set HF_TOKEN."
+}
+if (-not $Url) {
+    throw "Endpoint URL not provided. Use -Url or set HF_ENDPOINT_URL."
+}
+
+if ($RnbLoveTemptation.IsPresent) {
+    $Prompt = "melodic RnB pop with ambient melodies, evolving chord progression, emotional modern production, intimate vocals, soulful hooks"
+    $Lyrics = @"
+[Verse 1]
+Late night shadows on my skin,
+I hear your name where the silence has been.
+I swore I'd run when the fire got close,
+But I keep chasing what hurts me the most.
+
+[Pre-Chorus]
+Every promise pulls me back again,
+Sweet poison dressed like a loyal friend.
+
+[Chorus]
+I'm fighting love and temptation,
+Heart in a war with my own salvation.
+One touch and my defenses break,
+I know it's danger but I stay awake.
+I'm fighting love and temptation,
+Drowning slow in this sweet devastation.
+I want to leave but I hesitate,
+Cause I still crave what I know I should hate.
+
+[Verse 2]
+Your voice is velvet over broken glass,
+I learn the pain but I still relapse.
+Truth on my lips, lies in my veins,
+I pray for peace while I dance in flames.
+
+[Bridge]
+If love is a test, I'm failing with grace,
+Still falling for fire, still calling your name.
+
+[Final Chorus]
+I'm fighting love and temptation,
+Heart in a war with my own salvation.
+One touch and my defenses break,
+I know it's danger but I stay awake.
+"@
+
+    if ($DurationSec -eq 3) {
+        $DurationSec = 24
+    }
+
+    if (-not $PSBoundParameters.ContainsKey("SimplePrompt")) {
+        $SimplePrompt = $false
+    }
+}
+
+if ($RnbPopRap2Min.IsPresent) {
+    $Prompt = "2 minute RnB Pop Rap song, melodic ambient pads, emotional chord progression, intimate female and male vocal blend, catchy hooks, modern drums, deep 808, vulnerable but confident tone"
+    $Lyrics = @"
+[Intro]
+Mm, midnight in my head, no sleep.
+Same old war in my chest, still deep.
+I say I'm done, then I call your name,
+I run from the fire, then I walk in the flame.
+
+[Verse 1]
+Streetlights drip on the window pane,
+I wear my pride like a silver chain.
+Say I don't need you, that's what I say,
+But your voice in my mind don't fade away.
+I got dreams, got scars, got bills to pay,
+Still I fold when your eyes pull me in that way.
+I know better, I swear I do,
+But temptation sounds like truth when it sounds like you.
+
+[Pre-Chorus]
+Every promise tastes sweet then turns to smoke,
+I keep rebuilding hearts that we already broke.
+I want peace, but I want your touch,
+I know it's too much, still it's never enough.
+
+[Chorus]
+I'm fighting love and temptation,
+Heart on trial with no salvation.
+One more kiss and the walls cave in,
+I lose myself just to feel again.
+I'm fighting love and temptation,
+Drowning slow in sweet devastation.
+I say goodbye, then I hesitate,
+Cause I still crave what I know I should hate.
+
+[Rap Verse 1]
+Look, I been in and out the same lane,
+Different night, same rain.
+Tell myself "don't text back,"
+Still type your name, press send, same pain.
+You the high and the low in one dose,
+Got me praying for distance, still close.
+I play tough, but the truth is loud,
+When you're gone, all this noise in the crowd.
+Yeah, I hustle, I grind, I glow,
+But alone in the dark, I'm a different soul.
+If love was logic, I'd be free by now,
+But my heart ain't science, I just bleed it out.
+
+[Verse 2]
+Your perfume still lives in my hoodie seams,
+Like a ghost in the corners of all my dreams.
+I learned your chaos, your every disguise,
+The saint in your smile, the storm in your eyes.
+I touch your hand and forget my name,
+Call it desire, call it blame.
+I need healing, I need release,
+But your lips keep turning my war to peace.
+
+[Pre-Chorus]
+Every promise tastes sweet then turns to smoke,
+I keep rebuilding hearts that we already broke.
+I want peace, but I want your touch,
+I know it's too much, still it's never enough.
+
+[Chorus]
+I'm fighting love and temptation,
+Heart on trial with no salvation.
+One more kiss and the walls cave in,
+I lose myself just to feel again.
+I'm fighting love and temptation,
+Drowning slow in sweet devastation.
+I say goodbye, then I hesitate,
+Cause I still crave what I know I should hate.
+
+[Rap Verse 2]
+Uh, late calls, no sleep, red eyes,
+Truth hurts more than sweet lies.
+We toxic, but the chemistry loud,
+Like thunder in a summer night over this town.
+Tell me leave, then you pull me near,
+Tell me "trust me," then feed my fear.
+I keep faith in a broken map,
+Tryna find us on roads that don't lead back.
+I got plans, got goals, got pride,
+But temptation got hands on the wheel tonight.
+If I fall, let me fall with grace,
+I still see home when I look in your face.
+
+[Bridge]
+If this love is a test, I'm failing in style,
+Smiling through fire for one more while.
+I know I should run, I know I should wait,
+But your name on my tongue sounds too much like fate.
+
+[Final Chorus]
+I'm fighting love and temptation,
+Heart on trial with no salvation.
+One more kiss and the walls cave in,
+I lose myself just to feel again.
+I'm fighting love and temptation,
+Drowning slow in sweet devastation.
+I say goodbye, then I hesitate,
+Cause I still crave what I know I should hate.
+
+[Outro]
+Mm, midnight in my head, no sleep.
+Still your name in my chest, too deep.
+"@
+
+    if ($DurationSec -eq 3) {
+        $DurationSec = 120
+    }
+
+    if (-not $PSBoundParameters.ContainsKey("SimplePrompt")) {
+        $SimplePrompt = $false
+    }
+}
+
+$useLmBool = $true
+if ($null -ne $UseLM -and $UseLM -ne "") {
+    try {
+        $useLmBool = [System.Convert]::ToBoolean($UseLM)
+    }
+    catch {
+        $useLmBool = ($UseLM -match '^(1|true|t|yes|y|on)$')
+    }
+}
+
+$inputs = @{
+    prompt = $Prompt
+    duration_sec = $DurationSec
+    sample_rate = $SampleRate
+    seed = $Seed
+    guidance_scale = $GuidanceScale
+    steps = $Steps
+    use_lm = $useLmBool
+    allow_fallback = $AllowFallback.IsPresent
+}
+
+if ($Lyrics) {
+    $inputs["lyrics"] = $Lyrics
+}
+if ($SimplePrompt.IsPresent) {
+    $inputs["simple_prompt"] = $true
+}
+if ($Instrumental.IsPresent) {
+    $inputs["instrumental"] = $true
+}
+
+$body = @{ inputs = $inputs } | ConvertTo-Json -Depth 8
+
+$response = Invoke-RestMethod -Method Post -Uri $Url -Headers @{
+    Authorization = "Bearer $Token"
+    "Content-Type" = "application/json"
+} -Body $body
+
+$response | ConvertTo-Json -Depth 6
+
+if ($response.error) {
+    throw "Endpoint returned error: $($response.error)"
+}
+
+if ($response.used_fallback -and -not $AllowFallback.IsPresent) {
+    throw "Endpoint used fallback audio. Set -AllowFallback only if you want fallback behavior."
+}
+
+if (-not $response.audio_base64_wav) {
+    throw "No audio_base64_wav returned."
+}
+
+[IO.File]::WriteAllBytes($OutFile, [Convert]::FromBase64String($response.audio_base64_wav))
+Write-Host "Saved audio file: $OutFile"
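test.ps1 accepts `-UseLM` as a string and coerces it with `[System.Convert]::ToBoolean`, falling back to a regex match on `1|true|t|yes|y|on` so that `test.bat` can pass `-UseLM 1`. The same coercion in Python (a hypothetical `to_bool`, shown to pin down which spellings count as true):

```python
def to_bool(value: str) -> bool:
    """Accept 1/true/t/yes/y/on (case-insensitive) as True; anything else is False."""
    return value.strip().lower() in {"1", "true", "t", "yes", "y", "on"}

assert to_bool("1") and to_bool("True") and to_bool("on")
assert not to_bool("0") and not to_bool("off")
```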
scripts/endpoint/test_rnb.bat
ADDED
@@ -0,0 +1,12 @@
+@echo off
+setlocal
+
+powershell -NoProfile -ExecutionPolicy Bypass -File "%~dp0test.ps1" -RnbLoveTemptation -DurationSec 24 -SampleRate 44100 -Seed 42 -GuidanceScale 7.0 -Steps 8 -UseLM 1 -OutFile "test_rnb_music.wav"
+
+if errorlevel 1 (
+    echo Request failed.
+    exit /b 1
+)
+
+echo Done.
+endlocal
scripts/endpoint/test_rnb_2min.bat
ADDED
@@ -0,0 +1,12 @@
+@echo off
+setlocal
+
+powershell -NoProfile -ExecutionPolicy Bypass -File "%~dp0test.ps1" -RnbPopRap2Min -DurationSec 120 -SampleRate 44100 -Seed 42 -GuidanceScale 7.0 -Steps 8 -UseLM 1 -OutFile "test_rnb_pop_rap_2min.wav"
+
+if errorlevel 1 (
+    echo Request failed.
+    exit /b 1
+)
+
+echo Done.
+endlocal
scripts/hf_clone.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#!/usr/bin/env python
"""
Bootstrap this project into your own Hugging Face Space and/or Endpoint repo.

Examples:
    python scripts/hf_clone.py space --repo-id your-name/ace-step-lora-studio
    python scripts/hf_clone.py endpoint --repo-id your-name/ace-step-endpoint
    python scripts/hf_clone.py all --space-repo-id your-name/ace-step-lora-studio --endpoint-repo-id your-name/ace-step-endpoint
"""

from __future__ import annotations

import argparse
import os
import shutil
import tempfile
from pathlib import Path
from typing import Iterable

from huggingface_hub import HfApi


PROJECT_ROOT = Path(__file__).resolve().parents[1]

COMMON_SKIP_DIRS = {
    ".git",
    "__pycache__",
    ".pytest_cache",
    ".mypy_cache",
    ".ruff_cache",
    ".venv",
    "venv",
    "env",
    ".idea",
    ".vscode",
    ".cache",
    ".huggingface",
    ".gradio",
    "checkpoints",
    "lora_output",
    "outputs",
    "artifacts",
    "models",
    "datasets",
    "Lora-ace-step",
}

COMMON_SKIP_FILES = {
    ".env",
}

COMMON_SKIP_PREFIXES = (
    "song_summaries_llm",
)

COMMON_SKIP_SUFFIXES = {
    ".wav",
    ".flac",
    ".mp3",
    ".ogg",
    ".opus",
    ".m4a",
    ".aac",
    ".pt",
    ".bin",
    ".safetensors",
    ".ckpt",
    ".onnx",
    ".log",
    ".pyc",
    ".pyo",
    ".pyd",
}

MAX_FILE_BYTES = 30 * 1024 * 1024  # 30 MB safety cap for the upload snapshot


def _should_skip_common(rel_path: Path, is_dir: bool) -> bool:
    if any(part in COMMON_SKIP_DIRS for part in rel_path.parts):
        return True
    if rel_path.name in COMMON_SKIP_FILES:
        return True
    if any(rel_path.name.startswith(prefix) for prefix in COMMON_SKIP_PREFIXES):
        return True
    if not is_dir and rel_path.suffix.lower() in COMMON_SKIP_SUFFIXES:
        return True
    return False


def _copy_file(src: Path, dst: Path) -> None:
    dst.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(src, dst)


def _stage_space_snapshot(staging_dir: Path) -> tuple[int, int, list[str]]:
    copied = 0
    bytes_total = 0
    skipped: list[str] = []

    for src in PROJECT_ROOT.rglob("*"):
        rel = src.relative_to(PROJECT_ROOT)

        if src.is_dir():
            if _should_skip_common(rel, is_dir=True):
                skipped.append(f"{rel}/")
            continue

        if _should_skip_common(rel, is_dir=False):
            skipped.append(str(rel))
            continue

        size = src.stat().st_size
        if size > MAX_FILE_BYTES:
            skipped.append(f"{rel} (>{MAX_FILE_BYTES // (1024 * 1024)}MB)")
            continue

        dst = staging_dir / rel
        _copy_file(src, dst)
        copied += 1
        bytes_total += size

    return copied, bytes_total, skipped


def _iter_endpoint_paths() -> Iterable[Path]:
    # Minimal runtime set for custom endpoint repos.
    required = [
        PROJECT_ROOT / "handler.py",
        PROJECT_ROOT / "requirements.txt",
        PROJECT_ROOT / "packages.txt",
        PROJECT_ROOT / "acestep",
    ]
    for p in required:
        if p.exists():
            yield p

    template_readme = PROJECT_ROOT / "templates" / "hf-endpoint" / "README.md"
    if template_readme.exists():
        yield template_readme


def _stage_endpoint_snapshot(staging_dir: Path) -> tuple[int, int]:
    copied = 0
    bytes_total = 0

    for src in _iter_endpoint_paths():
        if src.is_file():
            # The endpoint template README becomes the repo's top-level README.
            rel_dst = Path("README.md") if src.name == "README.md" and "templates" in src.parts else Path(src.name)
            dst = staging_dir / rel_dst
            _copy_file(src, dst)
            copied += 1
            bytes_total += src.stat().st_size
            continue

        if src.is_dir():
            for nested in src.rglob("*"):
                rel_nested = nested.relative_to(src)
                if nested.is_dir():
                    continue
                if _should_skip_common(Path(src.name) / rel_nested, is_dir=False):
                    continue
                if nested.suffix.lower() in {".wav", ".flac", ".mp3", ".ogg"}:
                    continue

                dst = staging_dir / src.name / rel_nested
                _copy_file(nested, dst)
                copied += 1
                bytes_total += nested.stat().st_size

    return copied, bytes_total


def _resolve_token(arg_token: str) -> str | None:
    if arg_token:
        return arg_token
    return os.getenv("HF_TOKEN")


def _ensure_repo(
    api: HfApi,
    repo_id: str,
    repo_type: str,
    private: bool,
    space_sdk: str | None = None,
) -> None:
    kwargs = {
        "repo_id": repo_id,
        "repo_type": repo_type,
        "private": private,
        "exist_ok": True,
    }
    if repo_type == "space" and space_sdk:
        kwargs["space_sdk"] = space_sdk
    api.create_repo(**kwargs)


def _upload_snapshot(
    api: HfApi,
    repo_id: str,
    repo_type: str,
    folder_path: Path,
    commit_message: str,
) -> None:
    api.upload_folder(
        repo_id=repo_id,
        repo_type=repo_type,
        folder_path=str(folder_path),
        commit_message=commit_message,
        delete_patterns=[],
    )


def _fmt_mb(num_bytes: int) -> str:
    return f"{num_bytes / (1024 * 1024):.2f} MB"


def clone_space(repo_id: str, private: bool, token: str | None, dry_run: bool) -> None:
    with tempfile.TemporaryDirectory(prefix="hf_space_clone_") as tmp:
        staging = Path(tmp)
        copied, bytes_total, skipped = _stage_space_snapshot(staging)
        print(f"[space] staged files: {copied}, size: {_fmt_mb(bytes_total)}")
        if skipped:
            print(f"[space] skipped entries: {len(skipped)}")
            for item in skipped[:20]:
                print(f"  - {item}")
            if len(skipped) > 20:
                print(f"  ... and {len(skipped) - 20} more")

        if dry_run:
            print("[space] dry-run complete (nothing uploaded).")
            return

        api = HfApi(token=token)
        _ensure_repo(api, repo_id=repo_id, repo_type="space", private=private, space_sdk="gradio")
        _upload_snapshot(
            api,
            repo_id=repo_id,
            repo_type="space",
            folder_path=staging,
            commit_message="Bootstrap ACE-Step LoRA Studio Space",
        )
        print(f"[space] uploaded to https://huggingface.co/spaces/{repo_id}")


def clone_endpoint(repo_id: str, private: bool, token: str | None, dry_run: bool) -> None:
    with tempfile.TemporaryDirectory(prefix="hf_endpoint_clone_") as tmp:
        staging = Path(tmp)
        copied, bytes_total = _stage_endpoint_snapshot(staging)
        print(f"[endpoint] staged files: {copied}, size: {_fmt_mb(bytes_total)}")

        if dry_run:
            print("[endpoint] dry-run complete (nothing uploaded).")
            return

        api = HfApi(token=token)
        _ensure_repo(api, repo_id=repo_id, repo_type="model", private=private)
        _upload_snapshot(
            api,
            repo_id=repo_id,
            repo_type="model",
            folder_path=staging,
            commit_message="Bootstrap ACE-Step custom endpoint repo",
        )
        print(f"[endpoint] uploaded to https://huggingface.co/{repo_id}")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Clone this project into your own HF Space/Endpoint repos.")
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    p_space = subparsers.add_parser("space", help="Create/update your HF Space from this project.")
    p_space.add_argument("--repo-id", required=True, help="Target space repo id, e.g. username/my-space.")
    p_space.add_argument("--private", action="store_true", help="Create repo as private.")
    p_space.add_argument("--token", type=str, default="", help="HF token (default: HF_TOKEN env var).")
    p_space.add_argument("--dry-run", action="store_true", help="Stage files only; do not upload.")

    p_endpoint = subparsers.add_parser("endpoint", help="Create/update your custom endpoint model repo.")
    p_endpoint.add_argument("--repo-id", required=True, help="Target model repo id, e.g. username/my-endpoint.")
    p_endpoint.add_argument("--private", action="store_true", help="Create repo as private.")
    p_endpoint.add_argument("--token", type=str, default="", help="HF token (default: HF_TOKEN env var).")
    p_endpoint.add_argument("--dry-run", action="store_true", help="Stage files only; do not upload.")

    p_all = subparsers.add_parser("all", help="Run both Space and Endpoint bootstrap.")
    p_all.add_argument("--space-repo-id", required=True, help="Target space repo id.")
    p_all.add_argument("--endpoint-repo-id", required=True, help="Target endpoint model repo id.")
    p_all.add_argument("--space-private", action="store_true", help="Create Space as private.")
    p_all.add_argument("--endpoint-private", action="store_true", help="Create endpoint repo as private.")
    p_all.add_argument("--token", type=str, default="", help="HF token (default: HF_TOKEN env var).")
    p_all.add_argument("--dry-run", action="store_true", help="Stage files only; do not upload.")

    return parser


def main() -> int:
    args = build_parser().parse_args()
    token = _resolve_token(args.token)

    if not token and not args.dry_run:
        print("HF token not found. Set HF_TOKEN or pass --token.")
        return 1

    if args.cmd == "space":
        clone_space(args.repo_id, private=bool(args.private), token=token, dry_run=bool(args.dry_run))
    elif args.cmd == "endpoint":
        clone_endpoint(args.repo_id, private=bool(args.private), token=token, dry_run=bool(args.dry_run))
    else:
        clone_space(args.space_repo_id, private=bool(args.space_private), token=token, dry_run=bool(args.dry_run))
        clone_endpoint(
            args.endpoint_repo_id,
            private=bool(args.endpoint_private),
            token=token,
            dry_run=bool(args.dry_run),
        )

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
scripts/jobs/submit_hf_lora_job.ps1
ADDED
|
@@ -0,0 +1,85 @@
param(
    [string]$CodeRepo = "YOUR_USERNAME/ace-step-lora-studio",
    [string]$DatasetRepo = "",
    [string]$DatasetRevision = "main",
    [string]$DatasetSubdir = "",
    [string]$ModelConfig = "acestep-v15-base",
    [string]$Flavor = "a10g-large",
    [string]$Timeout = "8h",
    [int]$Epochs = 20,
    [int]$BatchSize = 1,
    [int]$GradAccum = 1,
    [string]$OutputDir = "/workspace/output",
    [string]$UploadRepo = "",
    [switch]$UploadPrivate,
    [switch]$Detach
)

$ErrorActionPreference = "Stop"

if (-not $DatasetRepo) {
    throw "Provide -DatasetRepo (HF dataset repo containing your audio + optional sidecars)."
}

$secretArgs = @("--secrets", "HF_TOKEN")

$uploadArgs = ""
if ($UploadRepo) {
    $uploadArgs = "--upload-repo `"$UploadRepo`""
    if ($UploadPrivate.IsPresent) {
        $uploadArgs += " --upload-private"
    }
}

$datasetSubdirArgs = ""
if ($DatasetSubdir) {
    $datasetSubdirArgs = "--dataset-subdir `"$DatasetSubdir`""
}

$detachArg = ""
if ($Detach.IsPresent) {
    $detachArg = "--detach"
}

$jobCommand = @"
set -e
python -m pip install --no-cache-dir --upgrade pip
git clone https://huggingface.co/$CodeRepo /workspace/code
cd /workspace/code
python -m pip install --no-cache-dir -r requirements.txt
python lora_train.py \
  --dataset-repo "$DatasetRepo" \
  --dataset-revision "$DatasetRevision" \
  $datasetSubdirArgs \
  --model-config "$ModelConfig" \
  --device auto \
  --num-epochs $Epochs \
  --batch-size $BatchSize \
  --grad-accum $GradAccum \
  --output-dir "$OutputDir" \
  $uploadArgs
"@

$argsList = @(
    "jobs", "run",
    "--flavor", $Flavor,
    "--timeout", $Timeout
) + $secretArgs

if ($detachArg) {
    $argsList += $detachArg
}

$argsList += @(
    "pytorch/pytorch:2.5.1-cuda12.1-cudnn9-runtime",
    "bash", "-lc", $jobCommand
)

Write-Host "Submitting HF Job with flavor=$Flavor timeout=$Timeout ..."
Write-Host "Dataset repo: $DatasetRepo"
Write-Host "Code repo: $CodeRepo"
if ($UploadRepo) {
    Write-Host "Will upload final adapter to: $UploadRepo"
}

& hf @argsList
summaries/findings.md
ADDED
|
@@ -0,0 +1,68 @@
# Improving ACE-Step LoRA with Time-Event-Based Annotation

[Back to project README](../README.md)

## Baseline context in this repo

This project already provides a solid end-to-end workflow:

- Train LoRA adapters with `lora_train.py` and the Gradio UI (`app.py`, `lora_ui.py`).
- Deploy generation through a custom endpoint runtime (`handler.py`, `acestep/`).
- Test prompts and lyrics quickly with the endpoint client scripts in `scripts/endpoint/`.

Today, most conditioning in this pipeline is still global (caption, lyrics, BPM, key, tags). That is a strong baseline, but it does not explicitly teach the model *when* events happen inside a track.

## Core limitation

Current annotations usually describe *what* a song is, not *when* events occur. The model can learn style and texture, but temporal structure is weaker:

- Verse/chorus transitions are often less deliberate than in human-produced songs.
- Build-ups, drops, and effect changes can feel averaged or blurred.
- Subgenre-specific arrangement timing is harder to reproduce consistently.

## Why time-event labels are promising

1. Better musical structure: teach the model where sections start and end and where key transitions occur.
2. Better genre fidelity: encode timing differences between styles that share similar instruments.
3. Better control at inference: allow prompting for both content and structure (what + when).

## Practical direction for this codebase

A useful next step is to extend the current sidecar metadata approach with optional timed events.

Example direction:

- Keep existing fields (`caption`, `lyrics`, `bpm`, etc.).
- Add an `events` list with event type + start/end times.
- Start with a small, high-quality subset before scaling.

Illustrative shape:

```json
{
  "caption": "emotional rnb pop with warm pads",
  "bpm": 92,
  "events": [
    {"type": "intro", "start": 0.0, "end": 8.0},
    {"type": "verse", "start": 8.0, "end": 32.0},
    {"type": "chorus", "start": 32.0, "end": 48.0}
  ]
}
```
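
As a minimal sketch of how such a sidecar could be folded into the existing text conditioning. The `[type@start-end]` tag format and the field names are assumptions for illustration, not part of the current pipeline; any stable, tokenizer-friendly encoding would do:

```python
import json


def events_to_tags(sidecar: dict) -> str:
    """Render a sidecar with timed events as one conditioning string.

    The "[type@start-end]" convention here is hypothetical; the point is
    that structure becomes part of the text the model is trained on.
    """
    parts = [sidecar.get("caption", "")]
    if "bpm" in sidecar:
        parts.append(f"bpm:{sidecar['bpm']}")
    for ev in sidecar.get("events", []):
        parts.append(f"[{ev['type']}@{ev['start']:.0f}-{ev['end']:.0f}s]")
    return " ".join(p for p in parts if p)


sidecar = json.loads("""
{"caption": "emotional rnb pop with warm pads", "bpm": 92,
 "events": [{"type": "intro", "start": 0.0, "end": 8.0},
            {"type": "verse", "start": 8.0, "end": 32.0}]}
""")
print(events_to_tags(sidecar))
# emotional rnb pop with warm pads bpm:92 [intro@0-8s] [verse@8-32s]
```

Keeping `events` optional means unannotated tracks train exactly as before, so the two annotation styles can be mixed in one dataset.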

## Early experiments worth running

- Compare baseline LoRA vs time-event LoRA on the same curated mini-dataset.
- Score structural accuracy (section order, transition timing tolerance).
- Run blind listening tests for perceived musical arc and arrangement coherence.
- Track whether time labels improve consistency without reducing creativity.
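
One way to make the "transition timing tolerance" score concrete. This is a sketch under the assumption that predicted and reference sections are both available as `(type, start_seconds)` tuples; the 2-second tolerance is an arbitrary starting point:

```python
def timing_accuracy(pred, ref, tol=2.0):
    """Fraction of reference section boundaries matched by a predicted
    boundary of the same type within `tol` seconds.

    pred, ref: lists of (section_type, start_seconds) tuples.
    """
    if not ref:
        return 1.0
    hits = 0
    for r_type, r_start in ref:
        if any(p_type == r_type and abs(p_start - r_start) <= tol
               for p_type, p_start in pred):
            hits += 1
    return hits / len(ref)


ref = [("intro", 0.0), ("verse", 8.0), ("chorus", 32.0)]
pred = [("intro", 0.5), ("verse", 9.0), ("chorus", 36.0)]
# intro and verse land within 2 s; the chorus is 4 s late, so 2 of 3 match.
print(timing_accuracy(pred, ref))
```

Sweeping `tol` gives a tolerance curve per model, which is easier to compare across runs than a single number.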

## Expected outcomes

If this works, this repo can evolve from "style-conditioned generation" toward "structure-aware generation":

- More intentional song progression.
- Stronger subgenre identity.
- Better controllability for creators.

This is still a baseline research note, but it gives a clear technical direction that fits the current project architecture.
templates/hf-endpoint/README.md
ADDED
|
@@ -0,0 +1,38 @@
# ACE-Step Custom Endpoint Repo

This repo is intended for a Hugging Face **Dedicated Inference Endpoint** with a custom `handler.py`.

## Contents

- `handler.py`: Endpoint request/response logic.
- `acestep/`: Core inference utilities.
- `requirements.txt`: Python dependencies.
- `packages.txt`: System dependencies.

## Expected Request Payload

```json
{
  "inputs": {
    "prompt": "upbeat pop rap with emotional guitar",
    "lyrics": "[Verse] city lights and midnight rain",
    "duration_sec": 12,
    "sample_rate": 44100,
    "seed": 42,
    "guidance_scale": 7.0,
    "steps": 50,
    "use_lm": true
  }
}
```
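
A minimal client sketch for posting this payload from Python. The endpoint URL and token are placeholders you supply via environment variables, and the response handling depends on what your `handler.py` actually returns, so only the request shape is shown as certain:

```python
import json
import urllib.request


def build_request(endpoint_url: str, token: str, inputs: dict) -> urllib.request.Request:
    """Build the POST request in the shape the custom handler expects."""
    return urllib.request.Request(
        endpoint_url,
        data=json.dumps({"inputs": inputs}).encode("utf-8"),
        headers={
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        },
    )


# To actually call a deployed endpoint (requires HF_ENDPOINT_URL and HF_TOKEN):
#   req = build_request(os.environ["HF_ENDPOINT_URL"], os.environ["HF_TOKEN"],
#                       {"prompt": "upbeat pop rap with emotional guitar",
#                        "lyrics": "[Verse] city lights and midnight rain",
#                        "duration_sec": 12, "seed": 42})
#   with urllib.request.urlopen(req) as resp:
#       body = resp.read()  # audio bytes or JSON, per your handler
```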

## Quick Setup

1. Create a model repo on Hugging Face.
2. Push this folder's content to that repo.
3. Create a new dedicated endpoint from the custom repo.
4. Set environment variables on the endpoint as needed:
   - `ACE_CONFIG_PATH` (default `acestep-v15-sft`)
   - `ACE_LM_MODEL_PATH` (default `acestep-5Hz-lm-4B`)
   - `ACE_DOWNLOAD_SOURCE` (`huggingface` or `modelscope`)
5. Scale down or pause the endpoint when idle to control cost.