mrfakename commited on
Commit
9f5c8f7
·
verified ·
1 Parent(s): f0bb57d

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. code/.env.example +4 -0
  3. code/.gitignore +228 -0
  4. code/LICENSE +21 -0
  5. code/README.md +221 -0
  6. code/acestep/__init__.py +1 -0
  7. code/acestep/acestep_v15_pipeline.py +298 -0
  8. code/acestep/api_server.py +1725 -0
  9. code/acestep/audio_utils.py +327 -0
  10. code/acestep/constants.py +109 -0
  11. code/acestep/constrained_logits_processor.py +0 -0
  12. code/acestep/dataset_handler.py +37 -0
  13. code/acestep/dit_alignment_score.py +870 -0
  14. code/acestep/genres_vocab.txt +0 -0
  15. code/acestep/gradio_ui/__init__.py +1 -0
  16. code/acestep/gradio_ui/events/__init__.py +1129 -0
  17. code/acestep/gradio_ui/events/generation_handlers.py +974 -0
  18. code/acestep/gradio_ui/events/results_handlers.py +0 -0
  19. code/acestep/gradio_ui/events/training_handlers.py +644 -0
  20. code/acestep/gradio_ui/i18n.py +152 -0
  21. code/acestep/gradio_ui/i18n/en.json +243 -0
  22. code/acestep/gradio_ui/i18n/ja.json +243 -0
  23. code/acestep/gradio_ui/i18n/zh.json +243 -0
  24. code/acestep/gradio_ui/interfaces/__init__.py +90 -0
  25. code/acestep/gradio_ui/interfaces/dataset.py +101 -0
  26. code/acestep/gradio_ui/interfaces/generation.py +766 -0
  27. code/acestep/gradio_ui/interfaces/result.py +552 -0
  28. code/acestep/gradio_ui/interfaces/training.py +562 -0
  29. code/acestep/handler.py +0 -0
  30. code/acestep/inference.py +1164 -0
  31. code/acestep/llm_inference.py +0 -0
  32. code/acestep/local_cache.py +129 -0
  33. code/acestep/test_time_scaling.py +410 -0
  34. code/acestep/third_parts/nano-vllm/LICENSE +21 -0
  35. code/acestep/third_parts/nano-vllm/README.md +66 -0
  36. code/acestep/third_parts/nano-vllm/assets/logo.png +3 -0
  37. code/acestep/third_parts/nano-vllm/bench.py +32 -0
  38. code/acestep/third_parts/nano-vllm/example.py +33 -0
  39. code/acestep/third_parts/nano-vllm/nanovllm/__init__.py +2 -0
  40. code/acestep/third_parts/nano-vllm/nanovllm/config.py +26 -0
  41. code/acestep/third_parts/nano-vllm/nanovllm/engine/block_manager.py +112 -0
  42. code/acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py +124 -0
  43. code/acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py +529 -0
  44. code/acestep/third_parts/nano-vllm/nanovllm/engine/scheduler.py +222 -0
  45. code/acestep/third_parts/nano-vllm/nanovllm/engine/sequence.py +96 -0
  46. code/acestep/third_parts/nano-vllm/nanovllm/layers/activation.py +14 -0
  47. code/acestep/third_parts/nano-vllm/nanovllm/layers/attention.py +75 -0
  48. code/acestep/third_parts/nano-vllm/nanovllm/layers/embed_head.py +66 -0
  49. code/acestep/third_parts/nano-vllm/nanovllm/layers/layernorm.py +50 -0
  50. code/acestep/third_parts/nano-vllm/nanovllm/layers/linear.py +153 -0
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ code/acestep/third_parts/nano-vllm/assets/logo.png filter=lfs diff=lfs merge=lfs -text
37
+ code/assets/ACE-Step_framework.png filter=lfs diff=lfs merge=lfs -text
38
+ code/assets/acestudio_logo.png filter=lfs diff=lfs merge=lfs -text
39
+ code/assets/application_map.png filter=lfs diff=lfs merge=lfs -text
40
+ code/assets/model_zoo.png filter=lfs diff=lfs merge=lfs -text
41
+ code/assets/orgnization_logos.png filter=lfs diff=lfs merge=lfs -text
code/.env.example ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ACESTEP_CONFIG_PATH=acestep-v15-turbo
2
+ ACESTEP_LM_MODEL_PATH=acestep-5Hz-lm-1.7B
3
+ ACESTEP_DEVICE=auto
4
+ ACESTEP_LM_BACKEND=vllm
code/.gitignore ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data/
2
+ *.mp3
3
+ *.wav
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[codz]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py.cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # UV
102
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ uv.lock
106
+
107
+ # poetry
108
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
109
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
110
+ # commonly ignored for libraries.
111
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
112
+ #poetry.lock
113
+ #poetry.toml
114
+
115
+ # pdm
116
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
117
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
118
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
119
+ #pdm.lock
120
+ #pdm.toml
121
+ .pdm-python
122
+ .pdm-build/
123
+
124
+ # pixi
125
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
126
+ #pixi.lock
127
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
128
+ # in the .venv directory. It is recommended not to include this directory in version control.
129
+ .pixi
130
+
131
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
132
+ __pypackages__/
133
+
134
+ # Celery stuff
135
+ celerybeat-schedule
136
+ celerybeat.pid
137
+
138
+ # SageMath parsed files
139
+ *.sage.py
140
+
141
+ # Environments
142
+ .env
143
+ .envrc
144
+ .venv
145
+ env/
146
+ venv/
147
+ ENV/
148
+ env.bak/
149
+ venv.bak/
150
+
151
+ # Spyder project settings
152
+ .spyderproject
153
+ .spyproject
154
+
155
+ # Rope project settings
156
+ .ropeproject
157
+
158
+ # mkdocs documentation
159
+ /site
160
+
161
+ # mypy
162
+ .mypy_cache/
163
+ .dmypy.json
164
+ dmypy.json
165
+
166
+ # Pyre type checker
167
+ .pyre/
168
+
169
+ # pytype static type analyzer
170
+ .pytype/
171
+
172
+ # Cython debug symbols
173
+ cython_debug/
174
+
175
+ # PyCharm
176
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
177
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
178
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
179
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
180
+ #.idea/
181
+
182
+ # Abstra
183
+ # Abstra is an AI-powered process automation framework.
184
+ # Ignore directories containing user credentials, local state, and settings.
185
+ # Learn more at https://abstra.io/docs
186
+ .abstra/
187
+
188
+ # Visual Studio Code
189
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
190
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
191
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
192
+ # you could uncomment the following to ignore the entire vscode folder
193
+ # .vscode/
194
+
195
+ # Ruff stuff:
196
+ .ruff_cache/
197
+
198
+ # PyPI configuration file
199
+ .pypirc
200
+
201
+ # Cursor
202
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
203
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
204
+ # refer to https://docs.cursor.com/context/ignore-files
205
+ .cursorignore
206
+ .cursorindexingignore
207
+
208
+ # Marimo
209
+ marimo/_static/
210
+ marimo/_lsp/
211
+ __marimo__/
212
+ tests/
213
+ checkpoints/
214
+ playground.ipynb
215
+ .history/
216
+ upload_checkpoints.sh
217
+ checkpoints.7z
218
+ README_old.md
219
+ discord_bot/
220
+ feishu_bot/
221
+ tmp*
222
+ torchinductor_root/
223
+ scripts/
224
+ checkpoints_legacy/
225
+ lora_output/
226
+ datasets/
227
+ python_embeded/
228
+ checkpoints_pack/
code/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 ACEStep
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
code/README.md ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <h1 align="center">ACE-Step 1.5</h1>
2
+ <h1 align="center">Pushing the Boundaries of Open-Source Music Generation</h1>
3
+ <p align="center">
4
+ <a href="https://ace-step.github.io/ace-step-v1.5.github.io/">Project</a> |
5
+ <a href="https://huggingface.co/collections/ACE-Step/ace-step-15">Hugging Face</a> |
6
+ <a href="https://modelscope.cn/models/ACE-Step/ACE-Step-v1-5">ModelScope</a> |
7
+ <a href="https://huggingface.co/spaces/ACE-Step/Ace-Step-v1.5">Space Demo</a> |
8
+ <a href="https://discord.gg/PeWDxrkdj7">Discord</a> |
9
+ <a href="https://arxiv.org/abs/2506.00045">Technical Report</a>
10
+ </p>
11
+
12
+ <p align="center">
13
+ <img src="./assets/orgnization_logos.png" width="100%" alt="StepFun Logo">
14
+ </p>
15
+
16
+ ## Table of Contents
17
+
18
+ - [✨ Features](#-features)
19
+ - [📦 Installation](#-installation)
20
+ - [🚀 Usage](#-usage)
21
+ - [🔨 Train](#-train)
22
+ - [🏗️ Architecture](#️-architecture)
23
+ - [🦁 Model Zoo](#-model-zoo)
24
+
25
+ ## 📝 Abstract
26
+ 🚀 We present ACE-Step v1.5, a highly efficient open-source music foundation model that brings commercial-grade generation to consumer hardware. On commonly used evaluation metrics, ACE-Step v1.5 achieves quality beyond most commercial music models while remaining extremely fast—under 2 seconds per full song on an A100 and under 10 seconds on an RTX 3090. The model runs locally with less than 4GB of VRAM, and supports lightweight personalization: users can train a LoRA from just a few songs to capture their own style.
27
+
28
+ 🌉 At its core lies a novel hybrid architecture where the Language Model (LM) functions as an omni-capable planner: it transforms simple user queries into comprehensive song blueprints—scaling from short loops to 10-minute compositions—while synthesizing metadata, lyrics, and captions via Chain-of-Thought to guide the Diffusion Transformer (DiT). ⚡ Uniquely, this alignment is achieved through intrinsic reinforcement learning relying solely on the model's internal mechanisms, thereby eliminating the biases inherent in external reward models or human preferences. 🎚️
29
+
30
+ 🔮 Beyond standard synthesis, ACE-Step v1.5 unifies precise stylistic control with versatile editing capabilities—such as cover generation, repainting, and vocal-to-BGM conversion—while maintaining strict adherence to prompts across 50+ languages. This paves the way for powerful tools that seamlessly integrate into the creative workflows of music artists, producers, and content creators. 🎸
31
+
32
+
33
+ ## ✨ Features
34
+
35
+ <p align="center">
36
+ <img src="./assets/application_map.png" width="100%" alt="ACE-Step Framework">
37
+ </p>
38
+
39
+ ### ⚡ Performance
40
+ - ✅ **Ultra-Fast Generation** — Under 2s per full song on A100, under 10s on RTX 3090 (0.5s to 10s on A100 depending on think mode & diffusion steps)
41
+ - ✅ **Flexible Duration** — Supports 10 seconds to 10 minutes (600s) audio generation
42
+ - ✅ **Batch Generation** — Generate up to 8 songs simultaneously
43
+
44
+ ### 🎵 Generation Quality
45
+ - ✅ **Commercial-Grade Output** — Quality beyond most commercial music models (between Suno v4.5 and Suno v5)
46
+ - ✅ **Rich Style Support** — 1000+ instruments and styles with fine-grained timbre description
47
+ - ✅ **Multi-Language Lyrics** — Supports 50+ languages with lyrics prompt for structure & style control
48
+
49
+ ### 🎛️ Versatility & Control
50
+
51
+ | Feature | Description |
52
+ |---------|-------------|
53
+ | ✅ Reference Audio Input | Use reference audio to guide generation style |
54
+ | ✅ Cover Generation | Create covers from existing audio |
55
+ | ✅ Repaint & Edit | Selective local audio editing and regeneration |
56
+ | ✅ Track Separation | Separate audio into individual stems |
57
+ | ✅ Multi-Track Generation | Add layers like Suno Studio's "Add Layer" feature |
58
+ | ✅ Vocal2BGM | Auto-generate accompaniment for vocal tracks |
59
+ | ✅ Metadata Control | Control duration, BPM, key/scale, time signature |
60
+ | ✅ Simple Mode | Generate full songs from simple descriptions |
61
+ | ✅ Query Rewriting | Auto LM expansion of tags and lyrics |
62
+ | ✅ Audio Understanding | Extract BPM, key/scale, time signature & caption from audio |
63
+ | ✅ LRC Generation | Auto-generate lyric timestamps for generated music |
64
+ | ✅ LoRA Training | One-click annotation & training in Gradio. 8 songs, 1 hour on 3090 (12GB VRAM) |
65
+ | ✅ Quality Scoring | Automatic quality assessment for generated audio |
66
+
67
+
68
+
69
+ ## 📦 Installation
70
+
71
+ > **Requirements:** Python 3.11, CUDA GPU recommended (works on CPU/MPS but slower)
72
+
73
+ ### 1. Install uv (Package Manager)
74
+
75
+ ```bash
76
+ # macOS / Linux
77
+ curl -LsSf https://astral.sh/uv/install.sh | sh
78
+
79
+ # Windows (PowerShell)
80
+ powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
81
+ ```
82
+
83
+ ### 2. Clone & Install
84
+
85
+ ```bash
86
+ git clone https://github.com/ACE-Step/ACE-Step-1.5.git
87
+ cd ACE-Step-1.5
88
+ uv sync
89
+ ```
90
+
91
+ ### 3. Launch
92
+
93
+ #### 🖥️ Gradio Web UI (Recommended)
94
+
95
+ ```bash
96
+ uv run acestep
97
+ ```
98
+
99
+ Open http://localhost:7860 in your browser. Models will be downloaded automatically on first run.
100
+
101
+ #### 🌐 REST API Server
102
+
103
+ ```bash
104
+ uv run acestep-api
105
+ ```
106
+
107
+ API runs at http://localhost:8001. See [API Documentation](./docs/en/API.md) for endpoints.
108
+
109
+ ### Command Line Options
110
+
111
+ **Gradio UI (`acestep`):**
112
+
113
+ | Option | Default | Description |
114
+ |--------|---------|-------------|
115
+ | `--port` | 7860 | Server port |
116
+ | `--server-name` | 127.0.0.1 | Server address (use `0.0.0.0` for network access) |
117
+ | `--share` | false | Create public Gradio link |
118
+ | `--language` | en | UI language: `en`, `zh`, `ja` |
119
+ | `--init_service` | false | Auto-initialize models on startup |
120
+ | `--config_path` | auto | DiT model (e.g., `acestep-v15-turbo`, `acestep-v15-turbo-shift3`) |
121
+ | `--lm_model_path` | auto | LM model (e.g., `acestep-5Hz-lm-0.6B`, `acestep-5Hz-lm-1.7B`) |
122
+ | `--offload_to_cpu` | auto | CPU offload (auto-enabled if VRAM < 16GB) |
123
+
124
+ **Examples:**
125
+
126
+ ```bash
127
+ # Public access with Chinese UI
128
+ uv run acestep --server-name 0.0.0.0 --share --language zh
129
+
130
+ # Pre-initialize models on startup
131
+ uv run acestep --init_service true --config_path acestep-v15-turbo
132
+ ```
133
+
134
+ ### Development
135
+
136
+ ```bash
137
+ # Add dependencies
138
+ uv add package-name
139
+ uv add --dev package-name
140
+
141
+ # Update all dependencies
142
+ uv sync --upgrade
143
+ ```
144
+
145
+ ## 🚀 Usage
146
+
147
+ We provide multiple ways to use ACE-Step:
148
+
149
+ | Method | Description | Documentation |
150
+ |--------|-------------|---------------|
151
+ | 🖥️ **Gradio Web UI** | Interactive web interface for music generation | [Gradio Guide](./docs/en/GRADIO_GUIDE.md) |
152
+ | 🐍 **Python API** | Programmatic access for integration | [Inference API](./docs/en/INFERENCE.md) |
153
+ | 🌐 **REST API** | HTTP-based async API for services | [REST API](./docs/en/API.md) |
154
+
155
+ **📚 Documentation available in:** [English](./docs/en/) | [中文](./docs/zh/) | [日本語](./docs/ja/)
156
+
157
+
158
+ ## 🔨 Train
159
+
160
+ See the **LoRA Training** tab in Gradio UI for one-click training, or check [Gradio Guide - LoRA Training](./docs/en/GRADIO_GUIDE.md#lora-training) for details.
161
+
162
+ ## 🏗️ Architecture
163
+
164
+ <p align="center">
165
+ <img src="./assets/ACE-Step_framework.png" width="100%" alt="ACE-Step Framework">
166
+ </p>
167
+
168
+ ## 🦁 Model Zoo
169
+
170
+ <p align="center">
171
+ <img src="./assets/model_zoo.png" width="100%" alt="Model Zoo">
172
+ </p>
173
+
174
+ ### DiT Models
175
+
176
+ | DiT Model | Pre-Training | SFT | RL | CFG | Step | Refer audio | Text2Music | Cover | Repaint | Extract | Lego | Complete | Quality | Diversity | Fine-Tunability | Hugging Face |
177
+ |-----------|:------------:|:---:|:--:|:---:|:----:|:-----------:|:----------:|:-----:|:-------:|:-------:|:----:|:--------:|:-------:|:---------:|:---------------:|--------------|
178
+ | `acestep-v15-base` | ✅ | ❌ | ❌ | ✅ | 50 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | Medium | High | Easy | [Link](https://huggingface.co/ACE-Step/acestep-v15-base) |
179
+ | `acestep-v15-sft` | ✅ | ✅ | ❌ | ✅ | 50 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | High | Medium | Easy | [Link](https://huggingface.co/ACE-Step/acestep-v15-sft) |
180
+ | `acestep-v15-turbo` | ✅ | ✅ | ❌ | ❌ | 8 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | Very High | Medium | Medium | [Link](https://huggingface.co/ACE-Step/Ace-Step1.5) |
181
+ | `acestep-v15-turbo-rl` | ✅ | ✅ | ✅ | ❌ | 8 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | Very High | Medium | Medium | To be released |
182
+
183
+ ### LM Models
184
+
185
+ | LM Model | Pretrain from | Pre-Training | SFT | RL | CoT metas | Query rewrite | Audio Understanding | Composition Capability | Copy Melody | Hugging Face |
186
+ |----------|---------------|:------------:|:---:|:--:|:---------:|:-------------:|:-------------------:|:----------------------:|:-----------:|--------------|
187
+ | `acestep-5Hz-lm-0.6B` | Qwen3-0.6B | ✅ | ✅ | ✅ | ✅ | ✅ | Medium | Medium | Weak | ✅ |
188
+ | `acestep-5Hz-lm-1.7B` | Qwen3-1.7B | ✅ | ✅ | ✅ | ✅ | ✅ | Medium | Medium | Medium | ✅ |
189
+ | `acestep-5Hz-lm-4B` | Qwen3-4B | ✅ | ✅ | ✅ | ✅ | ✅ | Strong | Strong | Strong | To be released |
190
+
191
+ ## 📜 License & Disclaimer
192
+
193
+ This project is licensed under [MIT](./LICENSE)
194
+
195
+ ACE-Step enables original music generation across diverse genres, with applications in creative production, education, and entertainment. While designed to support positive and artistic use cases, we acknowledge potential risks such as unintentional copyright infringement due to stylistic similarity, inappropriate blending of cultural elements, and misuse for generating harmful content. To ensure responsible use, we encourage users to verify the originality of generated works, clearly disclose AI involvement, and obtain appropriate permissions when adapting protected styles or materials. By using ACE-Step, you agree to uphold these principles and respect artistic integrity, cultural diversity, and legal compliance. The authors are not responsible for any misuse of the model, including but not limited to copyright violations, cultural insensitivity, or the generation of harmful content.
196
+
197
+ 🔔 Important Notice
198
+ The only official website for the ACE-Step project is our GitHub Pages site.
199
+ We do not operate any other websites.
200
+ 🚫 Fake domains include but are not limited to:
201
+ ac\*\*p.com, a\*\*p.org, a\*\*\*c.org
202
+ ⚠️ Please be cautious. Do not visit, trust, or make payments on any of those sites.
203
+
204
+ ## 🙏 Acknowledgements
205
+
206
+ This project is co-led by ACE Studio and StepFun.
207
+
208
+
209
+ ## 📖 Citation
210
+
211
+ If you find this project useful for your research, please consider citing:
212
+
213
+ ```BibTeX
214
+ @misc{gong2026acestep,
215
+ title={ACE-Step 1.5: Pushing the Boundaries of Open-Source Music Generation},
216
+ author={Junmin Gong, Yulin Song, Wenxiao Zhao, Sen Wang, Shengyuan Xu, Jing Guo},
217
+ howpublished={\url{https://github.com/ace-step/ACE-Step-1.5}},
218
+ year={2026},
219
+ note={GitHub repository}
220
+ }
221
+ ```
code/acestep/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """ACE-Step package."""
code/acestep/acestep_v15_pipeline.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ACE-Step V1.5 Pipeline
3
+ Handler wrapper connecting model and UI
4
+ """
5
+ import os
6
+ import sys
7
+
8
+ # Load environment variables from .env file in project root
9
+ # This allows configuration without hardcoding values
10
+ # Falls back to .env.example if .env is not found
11
+ try:
12
+ from dotenv import load_dotenv
13
+ # Get project root directory
14
+ _current_file = os.path.abspath(__file__)
15
+ _project_root = os.path.dirname(os.path.dirname(_current_file))
16
+ _env_path = os.path.join(_project_root, '.env')
17
+ _env_example_path = os.path.join(_project_root, '.env.example')
18
+
19
+ if os.path.exists(_env_path):
20
+ load_dotenv(_env_path)
21
+ print(f"Loaded configuration from {_env_path}")
22
+ elif os.path.exists(_env_example_path):
23
+ load_dotenv(_env_example_path)
24
+ print(f"Loaded configuration from {_env_example_path} (fallback)")
25
+ except ImportError:
26
+ # python-dotenv not installed, skip loading .env
27
+ pass
28
+
29
+ # Clear proxy settings that may affect Gradio
30
+ for proxy_var in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY']:
31
+ os.environ.pop(proxy_var, None)
32
+
33
+ try:
34
+ # When executed as a module: `python -m acestep.acestep_v15_pipeline`
35
+ from .handler import AceStepHandler
36
+ from .llm_inference import LLMHandler
37
+ from .dataset_handler import DatasetHandler
38
+ from .gradio_ui import create_gradio_interface
39
+ except ImportError:
40
+ # When executed as a script: `python acestep/acestep_v15_pipeline.py`
41
+ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
42
+ if project_root not in sys.path:
43
+ sys.path.insert(0, project_root)
44
+ from acestep.handler import AceStepHandler
45
+ from acestep.llm_inference import LLMHandler
46
+ from acestep.dataset_handler import DatasetHandler
47
+ from acestep.gradio_ui import create_gradio_interface
48
+
49
+
50
+ def create_demo(init_params=None, language='en'):
51
+ """
52
+ Create Gradio demo interface
53
+
54
+ Args:
55
+ init_params: Dictionary containing initialization parameters and state.
56
+ If None, service will not be pre-initialized.
57
+ Keys: 'pre_initialized' (bool), 'checkpoint', 'config_path', 'device',
58
+ 'init_llm', 'lm_model_path', 'backend', 'use_flash_attention',
59
+ 'offload_to_cpu', 'offload_dit_to_cpu', 'init_status',
60
+ 'dit_handler', 'llm_handler' (initialized handlers if pre-initialized),
61
+ 'language' (UI language code)
62
+ language: UI language code ('en', 'zh', 'ja', default: 'en')
63
+
64
+ Returns:
65
+ Gradio Blocks instance
66
+ """
67
+ # Use pre-initialized handlers if available, otherwise create new ones
68
+ if init_params and init_params.get('pre_initialized') and 'dit_handler' in init_params:
69
+ dit_handler = init_params['dit_handler']
70
+ llm_handler = init_params['llm_handler']
71
+ else:
72
+ dit_handler = AceStepHandler() # DiT handler
73
+ llm_handler = LLMHandler() # LM handler
74
+
75
+ dataset_handler = DatasetHandler() # Dataset handler
76
+
77
+ # Create Gradio interface with all handlers and initialization parameters
78
+ demo = create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=init_params, language=language)
79
+
80
+ return demo
81
+
82
+
83
+ def get_gpu_memory_gb():
84
+ """
85
+ Get GPU memory in GB. Returns 0 if no GPU is available.
86
+ """
87
+ try:
88
+ import torch
89
+ if torch.cuda.is_available():
90
+ # Get total memory of the first GPU in GB
91
+ total_memory = torch.cuda.get_device_properties(0).total_memory
92
+ memory_gb = total_memory / (1024**3) # Convert bytes to GB
93
+ return memory_gb
94
+ else:
95
+ return 0
96
+ except Exception as e:
97
+ print(f"Warning: Failed to detect GPU memory: {e}", file=sys.stderr)
98
+ return 0
99
+
100
+
101
+ def main():
102
+ """Main entry function"""
103
+ import argparse
104
+
105
+ # Detect GPU memory to auto-configure offload settings
106
+ gpu_memory_gb = get_gpu_memory_gb()
107
+ auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < 16
108
+
109
+ if auto_offload:
110
+ print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (< 16GB)")
111
+ print("Auto-enabling CPU offload to reduce GPU memory usage")
112
+ elif gpu_memory_gb > 0:
113
+ print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (>= 16GB)")
114
+ print("CPU offload disabled by default")
115
+ else:
116
+ print("No GPU detected, running on CPU")
117
+
118
+ parser = argparse.ArgumentParser(description="Gradio Demo for ACE-Step V1.5")
119
+ parser.add_argument("--port", type=int, default=7860, help="Port to run the gradio server on")
120
+ parser.add_argument("--share", action="store_true", help="Create a public link")
121
+ parser.add_argument("--debug", action="store_true", help="Enable debug mode")
122
+ parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Server name (default: 127.0.0.1, use 0.0.0.0 for all interfaces)")
123
+ parser.add_argument("--language", type=str, default="en", choices=["en", "zh", "ja"], help="UI language: en (English), zh (中文), ja (日本語)")
124
+
125
+ # Service mode argument
126
+ parser.add_argument("--service_mode", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False,
127
+ help="Enable service mode (default: False). When enabled, uses preset models and restricts UI options.")
128
+
129
+ # Service initialization arguments
130
+ parser.add_argument("--init_service", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Initialize service on startup (default: False)")
131
+ parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint file path (optional, for display purposes)")
132
+ parser.add_argument("--config_path", type=str, default=None, help="Main model path (e.g., 'acestep-v15-turbo')")
133
+ parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"], help="Processing device (default: auto)")
134
+ parser.add_argument("--init_llm", type=lambda x: x.lower() in ['true', '1', 'yes'], default=True, help="Initialize 5Hz LM (default: True)")
135
+ parser.add_argument("--lm_model_path", type=str, default=None, help="5Hz LM model path (e.g., 'acestep-5Hz-lm-0.6B')")
136
+ parser.add_argument("--backend", type=str, default="vllm", choices=["vllm", "pt"], help="5Hz LM backend (default: vllm)")
137
+ parser.add_argument("--use_flash_attention", type=lambda x: x.lower() in ['true', '1', 'yes'], default=None, help="Use flash attention (default: auto-detect)")
138
+ parser.add_argument("--offload_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=auto_offload, help=f"Offload models to CPU (default: {'True' if auto_offload else 'False'}, auto-detected based on GPU VRAM)")
139
+ parser.add_argument("--offload_dit_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Offload DiT to CPU (default: False)")
140
+
141
+ args = parser.parse_args()
142
+
143
+ # Service mode defaults (can be configured via .env file)
144
+ if args.service_mode:
145
+ print("Service mode enabled - applying preset configurations...")
146
+ # Force init_service in service mode
147
+ args.init_service = True
148
+ # Default DiT model for service mode (from env or fallback)
149
+ if args.config_path is None:
150
+ args.config_path = os.environ.get(
151
+ "SERVICE_MODE_DIT_MODEL",
152
+ "acestep-v15-turbo-fix-inst-shift-dynamic"
153
+ )
154
+ # Default LM model for service mode (from env or fallback)
155
+ if args.lm_model_path is None:
156
+ args.lm_model_path = os.environ.get(
157
+ "SERVICE_MODE_LM_MODEL",
158
+ "acestep-5Hz-lm-1.7B-v4-fix"
159
+ )
160
+ # Backend for service mode (from env or fallback to vllm)
161
+ args.backend = os.environ.get("SERVICE_MODE_BACKEND", "vllm")
162
+ print(f" DiT model: {args.config_path}")
163
+ print(f" LM model: {args.lm_model_path}")
164
+ print(f" Backend: {args.backend}")
165
+
166
+ try:
167
+ init_params = None
168
+
169
+ # If init_service is True, perform initialization before creating UI
170
+ if args.init_service:
171
+ print("Initializing service from command line...")
172
+
173
+ # Create handler instances for initialization
174
+ dit_handler = AceStepHandler()
175
+ llm_handler = LLMHandler()
176
+
177
+ # Auto-select config_path if not provided
178
+ if args.config_path is None:
179
+ available_models = dit_handler.get_available_acestep_v15_models()
180
+ if available_models:
181
+ args.config_path = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else available_models[0]
182
+ print(f"Auto-selected config_path: {args.config_path}")
183
+ else:
184
+ print("Error: No available models found. Please specify --config_path", file=sys.stderr)
185
+ sys.exit(1)
186
+
187
+ # Get project root (same logic as in handler)
188
+ current_file = os.path.abspath(__file__)
189
+ project_root = os.path.dirname(os.path.dirname(current_file))
190
+
191
+ # Determine flash attention setting
192
+ use_flash_attention = args.use_flash_attention
193
+ if use_flash_attention is None:
194
+ use_flash_attention = dit_handler.is_flash_attention_available()
195
+
196
+ # Initialize DiT handler
197
+ print(f"Initializing DiT model: {args.config_path} on {args.device}...")
198
+ init_status, enable_generate = dit_handler.initialize_service(
199
+ project_root=project_root,
200
+ config_path=args.config_path,
201
+ device=args.device,
202
+ use_flash_attention=use_flash_attention,
203
+ compile_model=False,
204
+ offload_to_cpu=args.offload_to_cpu,
205
+ offload_dit_to_cpu=args.offload_dit_to_cpu
206
+ )
207
+
208
+ if not enable_generate:
209
+ print(f"Error initializing DiT model: {init_status}", file=sys.stderr)
210
+ sys.exit(1)
211
+
212
+ print(f"DiT model initialized successfully")
213
+
214
+ # Initialize LM handler if requested
215
+ lm_status = ""
216
+ if args.init_llm:
217
+ if args.lm_model_path is None:
218
+ # Try to get default LM model
219
+ available_lm_models = llm_handler.get_available_5hz_lm_models()
220
+ if available_lm_models:
221
+ args.lm_model_path = available_lm_models[0]
222
+ print(f"Using default LM model: {args.lm_model_path}")
223
+ else:
224
+ print("Warning: No LM models available, skipping LM initialization", file=sys.stderr)
225
+ args.init_llm = False
226
+
227
+ if args.init_llm and args.lm_model_path:
228
+ checkpoint_dir = os.path.join(project_root, "checkpoints")
229
+ print(f"Initializing 5Hz LM: {args.lm_model_path} on {args.device}...")
230
+ lm_status, lm_success = llm_handler.initialize(
231
+ checkpoint_dir=checkpoint_dir,
232
+ lm_model_path=args.lm_model_path,
233
+ backend=args.backend,
234
+ device=args.device,
235
+ offload_to_cpu=args.offload_to_cpu,
236
+ dtype=dit_handler.dtype
237
+ )
238
+
239
+ if lm_success:
240
+ print(f"5Hz LM initialized successfully")
241
+ init_status += f"\n{lm_status}"
242
+ else:
243
+ print(f"Warning: 5Hz LM initialization failed: {lm_status}", file=sys.stderr)
244
+ init_status += f"\n{lm_status}"
245
+
246
+ # Prepare initialization parameters for UI
247
+ init_params = {
248
+ 'pre_initialized': True,
249
+ 'service_mode': args.service_mode,
250
+ 'checkpoint': args.checkpoint,
251
+ 'config_path': args.config_path,
252
+ 'device': args.device,
253
+ 'init_llm': args.init_llm,
254
+ 'lm_model_path': args.lm_model_path,
255
+ 'backend': args.backend,
256
+ 'use_flash_attention': use_flash_attention,
257
+ 'offload_to_cpu': args.offload_to_cpu,
258
+ 'offload_dit_to_cpu': args.offload_dit_to_cpu,
259
+ 'init_status': init_status,
260
+ 'enable_generate': enable_generate,
261
+ 'dit_handler': dit_handler,
262
+ 'llm_handler': llm_handler,
263
+ 'language': args.language
264
+ }
265
+
266
+ print("Service initialization completed successfully!")
267
+
268
+ # Create and launch demo
269
+ print(f"Creating Gradio interface with language: {args.language}...")
270
+ demo = create_demo(init_params=init_params, language=args.language)
271
+
272
+ # Enable queue for multi-user support
273
+ # This ensures proper request queuing and prevents concurrent generation conflicts
274
+ print("Enabling queue for multi-user support...")
275
+ demo.queue(
276
+ max_size=20, # Maximum queue size (adjust based on your needs)
277
+ status_update_rate="auto", # Update rate for queue status
278
+ )
279
+
280
+ print(f"Launching server on {args.server_name}:{args.port}...")
281
+ demo.launch(
282
+ server_name=args.server_name,
283
+ server_port=args.port,
284
+ share=args.share,
285
+ debug=args.debug,
286
+ show_error=True,
287
+ prevent_thread_lock=False, # Keep thread locked to maintain server running
288
+ inbrowser=False, # Don't auto-open browser
289
+ )
290
+ except Exception as e:
291
+ print(f"Error launching Gradio: {e}", file=sys.stderr)
292
+ import traceback
293
+ traceback.print_exc()
294
+ sys.exit(1)
295
+
296
+
297
+ if __name__ == "__main__":
298
+ main()
code/acestep/api_server.py ADDED
@@ -0,0 +1,1725 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI server for ACE-Step V1.5.
2
+
3
+ Endpoints:
4
+ - POST /release_task Create music generation task
5
+ - POST /query_result Batch query task results
6
+ - GET /v1/models List available models
7
+ - GET /v1/audio Download audio file
8
+ - GET /health Health check
9
+
10
+ NOTE:
11
+ - In-memory queue and job store -> run uvicorn with workers=1.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import asyncio
17
+ import json
18
+ import os
19
+ import sys
20
+ import time
21
+ import traceback
22
+ import tempfile
23
+ import urllib.parse
24
+ from collections import deque
25
+ from concurrent.futures import ThreadPoolExecutor
26
+ from contextlib import asynccontextmanager
27
+ from dataclasses import dataclass
28
+ from pathlib import Path
29
+ from threading import Lock
30
+ from typing import Any, Dict, List, Literal, Optional
31
+ from uuid import uuid4
32
+
33
+ try:
34
+ from dotenv import load_dotenv
35
+ except ImportError: # Optional dependency
36
+ load_dotenv = None # type: ignore
37
+
38
+ from fastapi import FastAPI, HTTPException, Request
39
+ from pydantic import BaseModel, Field
40
+ from starlette.datastructures import UploadFile as StarletteUploadFile
41
+
42
+ from acestep.handler import AceStepHandler
43
+ from acestep.llm_inference import LLMHandler
44
+ from acestep.constants import (
45
+ DEFAULT_DIT_INSTRUCTION,
46
+ DEFAULT_LM_INSTRUCTION,
47
+ TASK_INSTRUCTIONS,
48
+ )
49
+ from acestep.inference import (
50
+ GenerationParams,
51
+ GenerationConfig,
52
+ generate_music,
53
+ create_sample,
54
+ format_sample,
55
+ )
56
+ from acestep.gradio_ui.events.results_handlers import _build_generation_info
57
+
58
+
59
+ # =============================================================================
60
+ # Constants
61
+ # =============================================================================
62
+
63
+ RESULT_KEY_PREFIX = "ace_step_v1.5_"
64
+ RESULT_EXPIRE_SECONDS = 7 * 24 * 60 * 60 # 7 days
65
+ TASK_TIMEOUT_SECONDS = 3600 # 1 hour
66
+ JOB_STORE_CLEANUP_INTERVAL = 300 # 5 minutes - interval for cleaning up old jobs
67
+ JOB_STORE_MAX_AGE_SECONDS = 86400 # 24 hours - completed jobs older than this will be cleaned
68
+ STATUS_MAP = {"queued": 0, "running": 0, "succeeded": 1, "failed": 2}
69
+
70
+ LM_DEFAULT_TEMPERATURE = 0.85
71
+ LM_DEFAULT_CFG_SCALE = 2.5
72
+ LM_DEFAULT_TOP_P = 0.9
73
+
74
+ # Parameter aliases for request parsing
75
+ PARAM_ALIASES = {
76
+ "prompt": ["prompt"],
77
+ "sample_mode": ["sample_mode", "sampleMode"],
78
+ "sample_query": ["sample_query", "sampleQuery", "description", "desc"],
79
+ "use_format": ["use_format", "useFormat", "format"],
80
+ "model": ["model", "dit_model", "ditModel"],
81
+ "key_scale": ["key_scale", "keyscale", "keyScale"],
82
+ "time_signature": ["time_signature", "timesignature", "timeSignature"],
83
+ "audio_duration": ["audio_duration", "duration", "audioDuration", "target_duration", "targetDuration"],
84
+ "vocal_language": ["vocal_language", "vocalLanguage"],
85
+ "inference_steps": ["inference_steps", "inferenceSteps"],
86
+ "guidance_scale": ["guidance_scale", "guidanceScale"],
87
+ "use_random_seed": ["use_random_seed", "useRandomSeed"],
88
+ "audio_code_string": ["audio_code_string", "audioCodeString"],
89
+ "audio_cover_strength": ["audio_cover_strength", "audioCoverStrength"],
90
+ "task_type": ["task_type", "taskType"],
91
+ "infer_method": ["infer_method", "inferMethod"],
92
+ "use_tiled_decode": ["use_tiled_decode", "useTiledDecode"],
93
+ "constrained_decoding": ["constrained_decoding", "constrainedDecoding", "constrained"],
94
+ "constrained_decoding_debug": ["constrained_decoding_debug", "constrainedDecodingDebug"],
95
+ "use_cot_caption": ["use_cot_caption", "cot_caption", "cot-caption"],
96
+ "use_cot_language": ["use_cot_language", "cot_language", "cot-language"],
97
+ "is_format_caption": ["is_format_caption", "isFormatCaption"],
98
+ }
99
+
100
+
101
+ def _parse_description_hints(description: str) -> tuple[Optional[str], bool]:
102
+ """
103
+ Parse a description string to extract language code and instrumental flag.
104
+
105
+ This function analyzes user descriptions like "Pop rock. English" or "piano solo"
106
+ to detect:
107
+ - Language: Maps language names to ISO codes (e.g., "English" -> "en")
108
+ - Instrumental: Detects patterns indicating instrumental/no-vocal music
109
+
110
+ Args:
111
+ description: User's natural language music description
112
+
113
+ Returns:
114
+ (language_code, is_instrumental) tuple:
115
+ - language_code: ISO language code (e.g., "en", "zh") or None if not detected
116
+ - is_instrumental: True if description indicates instrumental music
117
+ """
118
+ import re
119
+
120
+ if not description:
121
+ return None, False
122
+
123
+ description_lower = description.lower().strip()
124
+
125
+ # Language mapping: input patterns -> ISO code
126
+ language_mapping = {
127
+ 'english': 'en', 'en': 'en',
128
+ 'chinese': 'zh', '中文': 'zh', 'zh': 'zh', 'mandarin': 'zh',
129
+ 'japanese': 'ja', '日本語': 'ja', 'ja': 'ja',
130
+ 'korean': 'ko', '한국어': 'ko', 'ko': 'ko',
131
+ 'spanish': 'es', 'español': 'es', 'es': 'es',
132
+ 'french': 'fr', 'français': 'fr', 'fr': 'fr',
133
+ 'german': 'de', 'deutsch': 'de', 'de': 'de',
134
+ 'italian': 'it', 'italiano': 'it', 'it': 'it',
135
+ 'portuguese': 'pt', 'português': 'pt', 'pt': 'pt',
136
+ 'russian': 'ru', 'русский': 'ru', 'ru': 'ru',
137
+ 'bengali': 'bn', 'bn': 'bn',
138
+ 'hindi': 'hi', 'hi': 'hi',
139
+ 'arabic': 'ar', 'ar': 'ar',
140
+ 'thai': 'th', 'th': 'th',
141
+ 'vietnamese': 'vi', 'vi': 'vi',
142
+ 'indonesian': 'id', 'id': 'id',
143
+ 'turkish': 'tr', 'tr': 'tr',
144
+ 'dutch': 'nl', 'nl': 'nl',
145
+ 'polish': 'pl', 'pl': 'pl',
146
+ }
147
+
148
+ # Detect language
149
+ detected_language = None
150
+ for lang_name, lang_code in language_mapping.items():
151
+ if len(lang_name) <= 2:
152
+ pattern = r'(?:^|\s|[.,;:!?])' + re.escape(lang_name) + r'(?:$|\s|[.,;:!?])'
153
+ else:
154
+ pattern = r'\b' + re.escape(lang_name) + r'\b'
155
+
156
+ if re.search(pattern, description_lower):
157
+ detected_language = lang_code
158
+ break
159
+
160
+ # Detect instrumental
161
+ is_instrumental = False
162
+ if 'instrumental' in description_lower:
163
+ is_instrumental = True
164
+ elif 'pure music' in description_lower or 'pure instrument' in description_lower:
165
+ is_instrumental = True
166
+ elif description_lower.endswith(' solo') or description_lower == 'solo':
167
+ is_instrumental = True
168
+
169
+ return detected_language, is_instrumental
170
+
171
+
172
+ JobStatus = Literal["queued", "running", "succeeded", "failed"]
173
+
174
+
175
+ class GenerateMusicRequest(BaseModel):
176
+ prompt: str = Field(default="", description="Text prompt describing the music")
177
+ lyrics: str = Field(default="", description="Lyric text")
178
+
179
+ # New API semantics:
180
+ # - thinking=True: use 5Hz LM to generate audio codes (lm-dit behavior)
181
+ # - thinking=False: do not use LM to generate codes (dit behavior)
182
+ # Regardless of thinking, if some metas are missing, server may use LM to fill them.
183
+ thinking: bool = False
184
+ # Sample-mode requests auto-generate caption/lyrics/metas via LM (no user prompt).
185
+ sample_mode: bool = False
186
+ # Description for sample mode: auto-generate caption/lyrics from description query
187
+ sample_query: str = Field(default="", description="Query/description for sample mode (use create_sample)")
188
+ # Whether to use format_sample() to enhance input caption/lyrics
189
+ use_format: bool = Field(default=False, description="Use format_sample() to enhance input (default: False)")
190
+ # Model name for multi-model support (select which DiT model to use)
191
+ model: Optional[str] = Field(default=None, description="Model name to use (e.g., 'acestep-v15-turbo')")
192
+
193
+ bpm: Optional[int] = None
194
+ # Accept common client keys via manual parsing (see RequestParser).
195
+ key_scale: str = ""
196
+ time_signature: str = ""
197
+ vocal_language: str = "en"
198
+ inference_steps: int = 8
199
+ guidance_scale: float = 7.0
200
+ use_random_seed: bool = True
201
+ seed: int = -1
202
+
203
+ reference_audio_path: Optional[str] = None
204
+ src_audio_path: Optional[str] = None
205
+ audio_duration: Optional[float] = None
206
+ batch_size: Optional[int] = None
207
+
208
+ audio_code_string: str = ""
209
+
210
+ repainting_start: float = 0.0
211
+ repainting_end: Optional[float] = None
212
+
213
+ instruction: str = DEFAULT_DIT_INSTRUCTION
214
+ audio_cover_strength: float = 1.0
215
+ task_type: str = "text2music"
216
+
217
+ use_adg: bool = False
218
+ cfg_interval_start: float = 0.0
219
+ cfg_interval_end: float = 1.0
220
+ infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
221
+ shift: float = Field(
222
+ default=3.0,
223
+ description="Timestep shift factor (range 1.0~5.0, default 3.0). Only effective for base models, not turbo models."
224
+ )
225
+ timesteps: Optional[str] = Field(
226
+ default=None,
227
+ description="Custom timesteps (comma-separated, e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference_steps and shift."
228
+ )
229
+
230
+ audio_format: str = "mp3"
231
+ use_tiled_decode: bool = True
232
+
233
+ # 5Hz LM (server-side): used for metadata completion and (when thinking=True) codes generation.
234
+ lm_model_path: Optional[str] = None # e.g. "acestep-5Hz-lm-0.6B"
235
+ lm_backend: Literal["vllm", "pt"] = "vllm"
236
+
237
+ constrained_decoding: bool = True
238
+ constrained_decoding_debug: bool = False
239
+ use_cot_caption: bool = True
240
+ use_cot_language: bool = True
241
+ is_format_caption: bool = False
242
+
243
+ lm_temperature: float = 0.85
244
+ lm_cfg_scale: float = 2.5
245
+ lm_top_k: Optional[int] = None
246
+ lm_top_p: Optional[float] = 0.9
247
+ lm_repetition_penalty: float = 1.0
248
+ lm_negative_prompt: str = "NO USER INPUT"
249
+
250
+ class Config:
251
+ allow_population_by_field_name = True
252
+ allow_population_by_alias = True
253
+
254
+
255
+ class CreateJobResponse(BaseModel):
256
+ task_id: str
257
+ status: JobStatus
258
+ queue_position: int = 0 # 1-based best-effort position when queued
259
+
260
+
261
+ class JobResult(BaseModel):
262
+ first_audio_path: Optional[str] = None
263
+ second_audio_path: Optional[str] = None
264
+ audio_paths: list[str] = Field(default_factory=list)
265
+
266
+ generation_info: str = ""
267
+ status_message: str = ""
268
+ seed_value: str = ""
269
+
270
+ metas: Dict[str, Any] = Field(default_factory=dict)
271
+ bpm: Optional[int] = None
272
+ duration: Optional[float] = None
273
+ genres: Optional[str] = None
274
+ keyscale: Optional[str] = None
275
+ timesignature: Optional[str] = None
276
+
277
+ # Model information
278
+ lm_model: Optional[str] = None
279
+ dit_model: Optional[str] = None
280
+
281
+
282
+ class JobResponse(BaseModel):
283
+ job_id: str
284
+ status: JobStatus
285
+ created_at: float
286
+ started_at: Optional[float] = None
287
+ finished_at: Optional[float] = None
288
+
289
+ # queue observability
290
+ queue_position: int = 0
291
+ eta_seconds: Optional[float] = None
292
+ avg_job_seconds: Optional[float] = None
293
+
294
+ result: Optional[JobResult] = None
295
+ error: Optional[str] = None
296
+
297
+
298
+ @dataclass
299
+ class _JobRecord:
300
+ job_id: str
301
+ status: JobStatus
302
+ created_at: float
303
+ started_at: Optional[float] = None
304
+ finished_at: Optional[float] = None
305
+ result: Optional[Dict[str, Any]] = None
306
+ error: Optional[str] = None
307
+ env: str = "development"
308
+
309
+
310
+ class _JobStore:
311
+ def __init__(self, max_age_seconds: int = JOB_STORE_MAX_AGE_SECONDS) -> None:
312
+ self._lock = Lock()
313
+ self._jobs: Dict[str, _JobRecord] = {}
314
+ self._max_age = max_age_seconds
315
+
316
+ def create(self) -> _JobRecord:
317
+ job_id = str(uuid4())
318
+ rec = _JobRecord(job_id=job_id, status="queued", created_at=time.time())
319
+ with self._lock:
320
+ self._jobs[job_id] = rec
321
+ return rec
322
+
323
+ def create_with_id(self, job_id: str, env: str = "development") -> _JobRecord:
324
+ """Create job record with specified ID"""
325
+ rec = _JobRecord(
326
+ job_id=job_id,
327
+ status="queued",
328
+ created_at=time.time(),
329
+ env=env
330
+ )
331
+ with self._lock:
332
+ self._jobs[job_id] = rec
333
+ return rec
334
+
335
+ def get(self, job_id: str) -> Optional[_JobRecord]:
336
+ with self._lock:
337
+ return self._jobs.get(job_id)
338
+
339
+ def mark_running(self, job_id: str) -> None:
340
+ with self._lock:
341
+ rec = self._jobs[job_id]
342
+ rec.status = "running"
343
+ rec.started_at = time.time()
344
+
345
+ def mark_succeeded(self, job_id: str, result: Dict[str, Any]) -> None:
346
+ with self._lock:
347
+ rec = self._jobs[job_id]
348
+ rec.status = "succeeded"
349
+ rec.finished_at = time.time()
350
+ rec.result = result
351
+ rec.error = None
352
+
353
+ def mark_failed(self, job_id: str, error: str) -> None:
354
+ with self._lock:
355
+ rec = self._jobs[job_id]
356
+ rec.status = "failed"
357
+ rec.finished_at = time.time()
358
+ rec.result = None
359
+ rec.error = error
360
+
361
+ def cleanup_old_jobs(self, max_age_seconds: Optional[int] = None) -> int:
362
+ """
363
+ Clean up completed jobs older than max_age_seconds.
364
+
365
+ Only removes jobs with status 'succeeded' or 'failed'.
366
+ Jobs that are 'queued' or 'running' are never removed.
367
+
368
+ Returns the number of jobs removed.
369
+ """
370
+ max_age = max_age_seconds if max_age_seconds is not None else self._max_age
371
+ now = time.time()
372
+ removed = 0
373
+
374
+ with self._lock:
375
+ to_remove = []
376
+ for job_id, rec in self._jobs.items():
377
+ if rec.status in ("succeeded", "failed"):
378
+ finish_time = rec.finished_at or rec.created_at
379
+ age = now - finish_time
380
+ if age > max_age:
381
+ to_remove.append(job_id)
382
+
383
+ for job_id in to_remove:
384
+ del self._jobs[job_id]
385
+ removed += 1
386
+
387
+ return removed
388
+
389
+ def get_stats(self) -> Dict[str, int]:
390
+ """Get statistics about jobs in the store."""
391
+ with self._lock:
392
+ stats = {
393
+ "total": len(self._jobs),
394
+ "queued": 0,
395
+ "running": 0,
396
+ "succeeded": 0,
397
+ "failed": 0,
398
+ }
399
+ for rec in self._jobs.values():
400
+ if rec.status in stats:
401
+ stats[rec.status] += 1
402
+ return stats
403
+
404
+
405
+ def _env_bool(name: str, default: bool) -> bool:
406
+ v = os.getenv(name)
407
+ if v is None:
408
+ return default
409
+ return v.strip().lower() in {"1", "true", "yes", "y", "on"}
410
+
411
+
412
+ def _get_project_root() -> str:
413
+ current_file = os.path.abspath(__file__)
414
+ return os.path.dirname(os.path.dirname(current_file))
415
+
416
+
417
+ def _get_model_name(config_path: str) -> str:
418
+ """
419
+ Extract model name from config_path.
420
+
421
+ Args:
422
+ config_path: Path like "acestep-v15-turbo" or "/path/to/acestep-v15-turbo"
423
+
424
+ Returns:
425
+ Model name (last directory name from config_path)
426
+ """
427
+ if not config_path:
428
+ return ""
429
+ normalized = config_path.rstrip("/\\")
430
+ return os.path.basename(normalized)
431
+
432
+
433
+ def _load_project_env() -> None:
434
+ if load_dotenv is None:
435
+ return
436
+ try:
437
+ project_root = _get_project_root()
438
+ env_path = os.path.join(project_root, ".env")
439
+ if os.path.exists(env_path):
440
+ load_dotenv(env_path, override=False)
441
+ except Exception:
442
+ # Optional best-effort: continue even if .env loading fails.
443
+ pass
444
+
445
+
446
+ _load_project_env()
447
+
448
+
449
+ def _to_int(v: Any, default: Optional[int] = None) -> Optional[int]:
450
+ if v is None:
451
+ return default
452
+ if isinstance(v, int):
453
+ return v
454
+ s = str(v).strip()
455
+ if s == "":
456
+ return default
457
+ try:
458
+ return int(s)
459
+ except Exception:
460
+ return default
461
+
462
+
463
+ def _to_float(v: Any, default: Optional[float] = None) -> Optional[float]:
464
+ if v is None:
465
+ return default
466
+ if isinstance(v, float):
467
+ return v
468
+ s = str(v).strip()
469
+ if s == "":
470
+ return default
471
+ try:
472
+ return float(s)
473
+ except Exception:
474
+ return default
475
+
476
+
477
+ def _to_bool(v: Any, default: bool = False) -> bool:
478
+ if v is None:
479
+ return default
480
+ if isinstance(v, bool):
481
+ return v
482
+ s = str(v).strip().lower()
483
+ if s == "":
484
+ return default
485
+ return s in {"1", "true", "yes", "y", "on"}
486
+
487
+
488
+ def _map_status(status: str) -> int:
489
+ """Map job status string to integer code."""
490
+ return STATUS_MAP.get(status, 2)
491
+
492
+
493
+ def _parse_timesteps(s: Optional[str]) -> Optional[List[float]]:
494
+ """Parse comma-separated timesteps string to list of floats."""
495
+ if not s or not s.strip():
496
+ return None
497
+ try:
498
+ return [float(t.strip()) for t in s.split(",") if t.strip()]
499
+ except (ValueError, Exception):
500
+ return None
501
+
502
+
503
+ def _is_instrumental(lyrics: str) -> bool:
504
+ """
505
+ Determine if the music should be instrumental based on lyrics.
506
+
507
+ Returns True if:
508
+ - lyrics is empty or whitespace only
509
+ - lyrics (lowercased and trimmed) is "[inst]" or "[instrumental]"
510
+ """
511
+ if not lyrics:
512
+ return True
513
+ lyrics_clean = lyrics.strip().lower()
514
+ if not lyrics_clean:
515
+ return True
516
+ return lyrics_clean in ("[inst]", "[instrumental]")
517
+
518
+
519
+ class RequestParser:
520
+ """Parse request parameters from multiple sources with alias support."""
521
+
522
+ def __init__(self, raw: dict):
523
+ self._raw = dict(raw) if raw else {}
524
+ self._param_obj = self._parse_json(self._raw.get("param_obj"))
525
+ self._metas = self._find_metas()
526
+
527
+ def _parse_json(self, v) -> dict:
528
+ if isinstance(v, dict):
529
+ return v
530
+ if isinstance(v, str) and v.strip():
531
+ try:
532
+ return json.loads(v)
533
+ except Exception:
534
+ pass
535
+ return {}
536
+
537
+ def _find_metas(self) -> dict:
538
+ for key in ("metas", "meta", "metadata", "user_metadata", "userMetadata"):
539
+ v = self._raw.get(key)
540
+ if v:
541
+ return self._parse_json(v)
542
+ return {}
543
+
544
+ def get(self, name: str, default=None):
545
+ """Get parameter by canonical name from all sources."""
546
+ aliases = PARAM_ALIASES.get(name, [name])
547
+ for source in (self._raw, self._param_obj, self._metas):
548
+ for alias in aliases:
549
+ v = source.get(alias)
550
+ if v is not None:
551
+ return v
552
+ return default
553
+
554
+ def str(self, name: str, default: str = "") -> str:
555
+ v = self.get(name)
556
+ return str(v) if v is not None else default
557
+
558
+ def int(self, name: str, default: Optional[int] = None) -> Optional[int]:
559
+ return _to_int(self.get(name), default)
560
+
561
+ def float(self, name: str, default: Optional[float] = None) -> Optional[float]:
562
+ return _to_float(self.get(name), default)
563
+
564
+ def bool(self, name: str, default: bool = False) -> bool:
565
+ return _to_bool(self.get(name), default)
566
+
567
+
568
+ async def _save_upload_to_temp(upload: StarletteUploadFile, *, prefix: str) -> str:
569
+ suffix = Path(upload.filename or "").suffix
570
+ fd, path = tempfile.mkstemp(prefix=f"{prefix}_", suffix=suffix)
571
+ os.close(fd)
572
+ try:
573
+ with open(path, "wb") as f:
574
+ while True:
575
+ chunk = await upload.read(1024 * 1024)
576
+ if not chunk:
577
+ break
578
+ f.write(chunk)
579
+ except Exception:
580
+ try:
581
+ os.remove(path)
582
+ except Exception:
583
+ pass
584
+ raise
585
+ finally:
586
+ try:
587
+ await upload.close()
588
+ except Exception:
589
+ pass
590
+ return path
591
+
592
+
593
+ def create_app() -> FastAPI:
594
+ store = _JobStore()
595
+
596
+ QUEUE_MAXSIZE = int(os.getenv("ACESTEP_QUEUE_MAXSIZE", "200"))
597
+ WORKER_COUNT = int(os.getenv("ACESTEP_QUEUE_WORKERS", "1")) # Single GPU recommended
598
+
599
+ INITIAL_AVG_JOB_SECONDS = float(os.getenv("ACESTEP_AVG_JOB_SECONDS", "5.0"))
600
+ AVG_WINDOW = int(os.getenv("ACESTEP_AVG_WINDOW", "50"))
601
+
602
+ def _path_to_audio_url(path: str) -> str:
603
+ """Convert local file path to downloadable relative URL"""
604
+ if not path:
605
+ return path
606
+ if path.startswith("http://") or path.startswith("https://"):
607
+ return path
608
+ encoded_path = urllib.parse.quote(path, safe="")
609
+ return f"/v1/audio?path={encoded_path}"
610
+
611
+ @asynccontextmanager
612
+ async def lifespan(app: FastAPI):
613
+ # Clear proxy env that may affect downstream libs
614
+ for proxy_var in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]:
615
+ os.environ.pop(proxy_var, None)
616
+
617
+ # Ensure compilation/temp caches do not fill up small default /tmp.
618
+ # Triton/Inductor (and the system compiler) can create large temporary files.
619
+ project_root = _get_project_root()
620
+ cache_root = os.path.join(project_root, ".cache", "acestep")
621
+ tmp_root = (os.getenv("ACESTEP_TMPDIR") or os.path.join(cache_root, "tmp")).strip()
622
+ triton_cache_root = (os.getenv("TRITON_CACHE_DIR") or os.path.join(cache_root, "triton")).strip()
623
+ inductor_cache_root = (os.getenv("TORCHINDUCTOR_CACHE_DIR") or os.path.join(cache_root, "torchinductor")).strip()
624
+
625
+ for p in [cache_root, tmp_root, triton_cache_root, inductor_cache_root]:
626
+ try:
627
+ os.makedirs(p, exist_ok=True)
628
+ except Exception:
629
+ # Best-effort: do not block startup if directory creation fails.
630
+ pass
631
+
632
+ # Respect explicit user overrides; if ACESTEP_TMPDIR is set, it should win.
633
+ if os.getenv("ACESTEP_TMPDIR"):
634
+ os.environ["TMPDIR"] = tmp_root
635
+ os.environ["TEMP"] = tmp_root
636
+ os.environ["TMP"] = tmp_root
637
+ else:
638
+ os.environ.setdefault("TMPDIR", tmp_root)
639
+ os.environ.setdefault("TEMP", tmp_root)
640
+ os.environ.setdefault("TMP", tmp_root)
641
+
642
+ os.environ.setdefault("TRITON_CACHE_DIR", triton_cache_root)
643
+ os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", inductor_cache_root)
644
+
645
+ handler = AceStepHandler()
646
+ llm_handler = LLMHandler()
647
+ init_lock = asyncio.Lock()
648
+ app.state._initialized = False
649
+ app.state._init_error = None
650
+ app.state._init_lock = init_lock
651
+
652
+ app.state.llm_handler = llm_handler
653
+ app.state._llm_initialized = False
654
+ app.state._llm_init_error = None
655
+ app.state._llm_init_lock = Lock()
656
+
657
+ # Multi-model support: secondary DiT handlers
658
+ handler2 = None
659
+ handler3 = None
660
+ config_path2 = os.getenv("ACESTEP_CONFIG_PATH2", "").strip()
661
+ config_path3 = os.getenv("ACESTEP_CONFIG_PATH3", "").strip()
662
+
663
+ if config_path2:
664
+ handler2 = AceStepHandler()
665
+ if config_path3:
666
+ handler3 = AceStepHandler()
667
+
668
+ app.state.handler2 = handler2
669
+ app.state.handler3 = handler3
670
+ app.state._initialized2 = False
671
+ app.state._initialized3 = False
672
+ app.state._config_path = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo")
673
+ app.state._config_path2 = config_path2
674
+ app.state._config_path3 = config_path3
675
+
676
+ max_workers = int(os.getenv("ACESTEP_API_WORKERS", "1"))
677
+ executor = ThreadPoolExecutor(max_workers=max_workers)
678
+
679
+ # Queue & observability
680
+ app.state.job_queue = asyncio.Queue(maxsize=QUEUE_MAXSIZE) # (job_id, req)
681
+ app.state.pending_ids = deque() # queued job_ids
682
+ app.state.pending_lock = asyncio.Lock()
683
+
684
+ # temp files per job (from multipart uploads)
685
+ app.state.job_temp_files = {} # job_id -> list[path]
686
+ app.state.job_temp_files_lock = asyncio.Lock()
687
+
688
+ # stats
689
+ app.state.stats_lock = asyncio.Lock()
690
+ app.state.recent_durations = deque(maxlen=AVG_WINDOW)
691
+ app.state.avg_job_seconds = INITIAL_AVG_JOB_SECONDS
692
+
693
+ app.state.handler = handler
694
+ app.state.executor = executor
695
+ app.state.job_store = store
696
+ app.state._python_executable = sys.executable
697
+
698
+ # Temporary directory for saving generated audio files
699
+ app.state.temp_audio_dir = os.path.join(tmp_root, "api_audio")
700
+ os.makedirs(app.state.temp_audio_dir, exist_ok=True)
701
+
702
+ # Initialize local cache
703
+ try:
704
+ from acestep.local_cache import get_local_cache
705
+ local_cache_dir = os.path.join(cache_root, "local_redis")
706
+ app.state.local_cache = get_local_cache(local_cache_dir)
707
+ except ImportError:
708
+ app.state.local_cache = None
709
+
710
+ async def _ensure_initialized() -> None:
711
+ h: AceStepHandler = app.state.handler
712
+
713
+ if getattr(app.state, "_initialized", False):
714
+ return
715
+ if getattr(app.state, "_init_error", None):
716
+ raise RuntimeError(app.state._init_error)
717
+
718
+ async with app.state._init_lock:
719
+ if getattr(app.state, "_initialized", False):
720
+ return
721
+ if getattr(app.state, "_init_error", None):
722
+ raise RuntimeError(app.state._init_error)
723
+
724
+ project_root = _get_project_root()
725
+ config_path = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo")
726
+ device = os.getenv("ACESTEP_DEVICE", "auto")
727
+
728
+ use_flash_attention = _env_bool("ACESTEP_USE_FLASH_ATTENTION", True)
729
+ offload_to_cpu = _env_bool("ACESTEP_OFFLOAD_TO_CPU", False)
730
+ offload_dit_to_cpu = _env_bool("ACESTEP_OFFLOAD_DIT_TO_CPU", False)
731
+
732
+ # Initialize primary model
733
+ status_msg, ok = h.initialize_service(
734
+ project_root=project_root,
735
+ config_path=config_path,
736
+ device=device,
737
+ use_flash_attention=use_flash_attention,
738
+ compile_model=False,
739
+ offload_to_cpu=offload_to_cpu,
740
+ offload_dit_to_cpu=offload_dit_to_cpu,
741
+ )
742
+ if not ok:
743
+ app.state._init_error = status_msg
744
+ raise RuntimeError(status_msg)
745
+ app.state._initialized = True
746
+
747
+ # Initialize secondary model if configured
748
+ if app.state.handler2 and app.state._config_path2:
749
+ try:
750
+ status_msg2, ok2 = app.state.handler2.initialize_service(
751
+ project_root=project_root,
752
+ config_path=app.state._config_path2,
753
+ device=device,
754
+ use_flash_attention=use_flash_attention,
755
+ compile_model=False,
756
+ offload_to_cpu=offload_to_cpu,
757
+ offload_dit_to_cpu=offload_dit_to_cpu,
758
+ )
759
+ app.state._initialized2 = ok2
760
+ if ok2:
761
+ print(f"[API Server] Secondary model loaded: {_get_model_name(app.state._config_path2)}")
762
+ else:
763
+ print(f"[API Server] Warning: Secondary model failed to load: {status_msg2}")
764
+ except Exception as e:
765
+ print(f"[API Server] Warning: Failed to initialize secondary model: {e}")
766
+ app.state._initialized2 = False
767
+
768
+ # Initialize third model if configured
769
+ if app.state.handler3 and app.state._config_path3:
770
+ try:
771
+ status_msg3, ok3 = app.state.handler3.initialize_service(
772
+ project_root=project_root,
773
+ config_path=app.state._config_path3,
774
+ device=device,
775
+ use_flash_attention=use_flash_attention,
776
+ compile_model=False,
777
+ offload_to_cpu=offload_to_cpu,
778
+ offload_dit_to_cpu=offload_dit_to_cpu,
779
+ )
780
+ app.state._initialized3 = ok3
781
+ if ok3:
782
+ print(f"[API Server] Third model loaded: {_get_model_name(app.state._config_path3)}")
783
+ else:
784
+ print(f"[API Server] Warning: Third model failed to load: {status_msg3}")
785
+ except Exception as e:
786
+ print(f"[API Server] Warning: Failed to initialize third model: {e}")
787
+ app.state._initialized3 = False
788
+
789
+ async def _cleanup_job_temp_files(job_id: str) -> None:
790
+ async with app.state.job_temp_files_lock:
791
+ paths = app.state.job_temp_files.pop(job_id, [])
792
+ for p in paths:
793
+ try:
794
+ os.remove(p)
795
+ except Exception:
796
+ pass
797
+
798
+ def _update_local_cache(job_id: str, result: Optional[Dict], status: str) -> None:
799
+ """Update local cache with job result"""
800
+ local_cache = getattr(app.state, 'local_cache', None)
801
+ if not local_cache:
802
+ return
803
+
804
+ rec = store.get(job_id)
805
+ env = getattr(rec, 'env', 'development') if rec else 'development'
806
+ create_time = rec.created_at if rec else time.time()
807
+
808
+ status_int = _map_status(status)
809
+
810
+ if status == "succeeded" and result:
811
+ audio_paths = result.get("audio_paths", [])
812
+ # Final prompt/lyrics (may be modified by thinking/format)
813
+ final_prompt = result.get("prompt", "")
814
+ final_lyrics = result.get("lyrics", "")
815
+ # Original user input from metas
816
+ metas_raw = result.get("metas", {}) or {}
817
+ original_prompt = metas_raw.get("prompt", "")
818
+ original_lyrics = metas_raw.get("lyrics", "")
819
+ # metas contains original input + other metadata
820
+ metas = {
821
+ "bpm": metas_raw.get("bpm"),
822
+ "duration": metas_raw.get("duration"),
823
+ "genres": metas_raw.get("genres", ""),
824
+ "keyscale": metas_raw.get("keyscale", ""),
825
+ "timesignature": metas_raw.get("timesignature", ""),
826
+ "prompt": original_prompt,
827
+ "lyrics": original_lyrics,
828
+ }
829
+ # Extra fields for Discord bot
830
+ generation_info = result.get("generation_info", "")
831
+ seed_value = result.get("seed_value", "")
832
+ lm_model = result.get("lm_model", "")
833
+ dit_model = result.get("dit_model", "")
834
+
835
+ if audio_paths:
836
+ result_data = [
837
+ {
838
+ "file": p,
839
+ "wave": "",
840
+ "status": status_int,
841
+ "create_time": int(create_time),
842
+ "env": env,
843
+ "prompt": final_prompt,
844
+ "lyrics": final_lyrics,
845
+ "metas": metas,
846
+ "generation_info": generation_info,
847
+ "seed_value": seed_value,
848
+ "lm_model": lm_model,
849
+ "dit_model": dit_model,
850
+ }
851
+ for p in audio_paths
852
+ ]
853
+ else:
854
+ result_data = [{
855
+ "file": "",
856
+ "wave": "",
857
+ "status": status_int,
858
+ "create_time": int(create_time),
859
+ "env": env,
860
+ "prompt": final_prompt,
861
+ "lyrics": final_lyrics,
862
+ "metas": metas,
863
+ "generation_info": generation_info,
864
+ "seed_value": seed_value,
865
+ "lm_model": lm_model,
866
+ "dit_model": dit_model,
867
+ }]
868
+ else:
869
+ result_data = [{"file": "", "wave": "", "status": status_int, "create_time": int(create_time), "env": env}]
870
+
871
+ result_key = f"{RESULT_KEY_PREFIX}{job_id}"
872
+ local_cache.set(result_key, result_data, ex=RESULT_EXPIRE_SECONDS)
873
+
874
+ async def _run_one_job(job_id: str, req: GenerateMusicRequest) -> None:
875
+ job_store: _JobStore = app.state.job_store
876
+ llm: LLMHandler = app.state.llm_handler
877
+ executor: ThreadPoolExecutor = app.state.executor
878
+
879
+ await _ensure_initialized()
880
+ job_store.mark_running(job_id)
881
+
882
+ # Select DiT handler based on user's model choice
883
+ # Default: use primary handler
884
+ selected_handler: AceStepHandler = app.state.handler
885
+ selected_model_name = _get_model_name(app.state._config_path)
886
+
887
+ if req.model:
888
+ model_matched = False
889
+
890
+ # Check if it matches the second model
891
+ if app.state.handler2 and getattr(app.state, "_initialized2", False):
892
+ model2_name = _get_model_name(app.state._config_path2)
893
+ if req.model == model2_name:
894
+ selected_handler = app.state.handler2
895
+ selected_model_name = model2_name
896
+ model_matched = True
897
+ print(f"[API Server] Job {job_id}: Using second model: {model2_name}")
898
+
899
+ # Check if it matches the third model
900
+ if not model_matched and app.state.handler3 and getattr(app.state, "_initialized3", False):
901
+ model3_name = _get_model_name(app.state._config_path3)
902
+ if req.model == model3_name:
903
+ selected_handler = app.state.handler3
904
+ selected_model_name = model3_name
905
+ model_matched = True
906
+ print(f"[API Server] Job {job_id}: Using third model: {model3_name}")
907
+
908
+ if not model_matched:
909
+ available_models = [_get_model_name(app.state._config_path)]
910
+ if app.state.handler2 and getattr(app.state, "_initialized2", False):
911
+ available_models.append(_get_model_name(app.state._config_path2))
912
+ if app.state.handler3 and getattr(app.state, "_initialized3", False):
913
+ available_models.append(_get_model_name(app.state._config_path3))
914
+ print(f"[API Server] Job {job_id}: Model '{req.model}' not found in {available_models}, using primary: {selected_model_name}")
915
+
916
+ # Use selected handler for generation
917
+ h: AceStepHandler = selected_handler
918
+
919
+ def _blocking_generate() -> Dict[str, Any]:
920
+ """Generate music using unified inference logic from acestep.inference"""
921
+
922
+ def _ensure_llm_ready() -> None:
923
+ """Ensure LLM handler is initialized when needed"""
924
+ with app.state._llm_init_lock:
925
+ initialized = getattr(app.state, "_llm_initialized", False)
926
+ had_error = getattr(app.state, "_llm_init_error", None)
927
+ if initialized or had_error is not None:
928
+ return
929
+
930
+ project_root = _get_project_root()
931
+ checkpoint_dir = os.path.join(project_root, "checkpoints")
932
+ lm_model_path = (req.lm_model_path or os.getenv("ACESTEP_LM_MODEL_PATH") or "acestep-5Hz-lm-0.6B").strip()
933
+ backend = (req.lm_backend or os.getenv("ACESTEP_LM_BACKEND") or "vllm").strip().lower()
934
+ if backend not in {"vllm", "pt"}:
935
+ backend = "vllm"
936
+
937
+ lm_device = os.getenv("ACESTEP_LM_DEVICE", os.getenv("ACESTEP_DEVICE", "auto"))
938
+ lm_offload = _env_bool("ACESTEP_LM_OFFLOAD_TO_CPU", False)
939
+
940
+ status, ok = llm.initialize(
941
+ checkpoint_dir=checkpoint_dir,
942
+ lm_model_path=lm_model_path,
943
+ backend=backend,
944
+ device=lm_device,
945
+ offload_to_cpu=lm_offload,
946
+ dtype=h.dtype,
947
+ )
948
+ if not ok:
949
+ app.state._llm_init_error = status
950
+ else:
951
+ app.state._llm_initialized = True
952
+
953
+ def _normalize_metas(meta: Dict[str, Any]) -> Dict[str, Any]:
954
+ """Ensure a stable `metas` dict (keys always present)."""
955
+ meta = meta or {}
956
+ out: Dict[str, Any] = dict(meta)
957
+
958
+ # Normalize key aliases
959
+ if "keyscale" not in out and "key_scale" in out:
960
+ out["keyscale"] = out.get("key_scale")
961
+ if "timesignature" not in out and "time_signature" in out:
962
+ out["timesignature"] = out.get("time_signature")
963
+
964
+ # Ensure required keys exist
965
+ for k in ["bpm", "duration", "genres", "keyscale", "timesignature"]:
966
+ if out.get(k) in (None, ""):
967
+ out[k] = "N/A"
968
+ return out
969
+
970
+ # Normalize LM sampling parameters
971
+ lm_top_k = req.lm_top_k if req.lm_top_k and req.lm_top_k > 0 else 0
972
+ lm_top_p = req.lm_top_p if req.lm_top_p and req.lm_top_p < 1.0 else 0.9
973
+
974
+ # Determine if LLM is needed
975
+ thinking = bool(req.thinking)
976
+ sample_mode = bool(req.sample_mode)
977
+ has_sample_query = bool(req.sample_query and req.sample_query.strip())
978
+ use_format = bool(req.use_format)
979
+ use_cot_caption = bool(req.use_cot_caption)
980
+ use_cot_language = bool(req.use_cot_language)
981
+
982
+ # LLM is needed for:
983
+ # - thinking mode (LM generates audio codes)
984
+ # - sample_mode (LM generates random caption/lyrics/metas)
985
+ # - sample_query/description (LM generates from description)
986
+ # - use_format (LM enhances caption/lyrics)
987
+ # - use_cot_caption or use_cot_language (LM enhances metadata)
988
+ need_llm = thinking or sample_mode or has_sample_query or use_format or use_cot_caption or use_cot_language
989
+
990
+ # Ensure LLM is ready if needed
991
+ if need_llm:
992
+ _ensure_llm_ready()
993
+ if getattr(app.state, "_llm_init_error", None):
994
+ raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
995
+
996
+ # Handle sample mode or description: generate caption/lyrics/metas via LM
997
+ caption = req.prompt
998
+ lyrics = req.lyrics
999
+ bpm = req.bpm
1000
+ key_scale = req.key_scale
1001
+ time_signature = req.time_signature
1002
+ audio_duration = req.audio_duration
1003
+
1004
+ # Save original user input for metas
1005
+ original_prompt = req.prompt or ""
1006
+ original_lyrics = req.lyrics or ""
1007
+
1008
+ if sample_mode or has_sample_query:
1009
+ # Parse description hints from sample_query (if provided)
1010
+ sample_query = req.sample_query if has_sample_query else "NO USER INPUT"
1011
+ parsed_language, parsed_instrumental = _parse_description_hints(sample_query)
1012
+
1013
+ # Determine vocal_language with priority:
1014
+ # 1. User-specified vocal_language (if not default "en")
1015
+ # 2. Language parsed from description
1016
+ # 3. None (no constraint)
1017
+ if req.vocal_language and req.vocal_language not in ("en", "unknown", ""):
1018
+ sample_language = req.vocal_language
1019
+ else:
1020
+ sample_language = parsed_language
1021
+
1022
+ sample_result = create_sample(
1023
+ llm_handler=llm,
1024
+ query=sample_query,
1025
+ instrumental=parsed_instrumental,
1026
+ vocal_language=sample_language,
1027
+ temperature=req.lm_temperature,
1028
+ top_k=lm_top_k if lm_top_k > 0 else None,
1029
+ top_p=lm_top_p if lm_top_p < 1.0 else None,
1030
+ use_constrained_decoding=True,
1031
+ )
1032
+
1033
+ if not sample_result.success:
1034
+ raise RuntimeError(f"create_sample failed: {sample_result.error or sample_result.status_message}")
1035
+
1036
+ # Use generated sample data
1037
+ caption = sample_result.caption
1038
+ lyrics = sample_result.lyrics
1039
+ bpm = sample_result.bpm
1040
+ key_scale = sample_result.keyscale
1041
+ time_signature = sample_result.timesignature
1042
+ audio_duration = sample_result.duration
1043
+
1044
+ # Apply format_sample() if use_format is True and caption/lyrics are provided
1045
+ format_has_duration = False
1046
+
1047
+ if req.use_format and (caption or lyrics):
1048
+ _ensure_llm_ready()
1049
+ if getattr(app.state, "_llm_init_error", None):
1050
+ raise RuntimeError(f"5Hz LM init failed (needed for format): {app.state._llm_init_error}")
1051
+
1052
+ # Build user_metadata from request params (matching bot.py behavior)
1053
+ user_metadata_for_format = {}
1054
+ if bpm is not None:
1055
+ user_metadata_for_format['bpm'] = bpm
1056
+ if audio_duration is not None and float(audio_duration) > 0:
1057
+ user_metadata_for_format['duration'] = float(audio_duration)
1058
+ if key_scale:
1059
+ user_metadata_for_format['keyscale'] = key_scale
1060
+ if time_signature:
1061
+ user_metadata_for_format['timesignature'] = time_signature
1062
+ if req.vocal_language and req.vocal_language != "unknown":
1063
+ user_metadata_for_format['language'] = req.vocal_language
1064
+
1065
+ format_result = format_sample(
1066
+ llm_handler=llm,
1067
+ caption=caption,
1068
+ lyrics=lyrics,
1069
+ user_metadata=user_metadata_for_format if user_metadata_for_format else None,
1070
+ temperature=req.lm_temperature,
1071
+ top_k=lm_top_k if lm_top_k > 0 else None,
1072
+ top_p=lm_top_p if lm_top_p < 1.0 else None,
1073
+ use_constrained_decoding=True,
1074
+ )
1075
+
1076
+ if format_result.success:
1077
+ # Extract all formatted data (matching bot.py behavior)
1078
+ caption = format_result.caption or caption
1079
+ lyrics = format_result.lyrics or lyrics
1080
+ if format_result.duration:
1081
+ audio_duration = format_result.duration
1082
+ format_has_duration = True
1083
+ if format_result.bpm:
1084
+ bpm = format_result.bpm
1085
+ if format_result.keyscale:
1086
+ key_scale = format_result.keyscale
1087
+ if format_result.timesignature:
1088
+ time_signature = format_result.timesignature
1089
+
1090
+ # Parse timesteps string to list of floats if provided
1091
+ parsed_timesteps = _parse_timesteps(req.timesteps)
1092
+
1093
+ # Determine actual inference steps (timesteps override inference_steps)
1094
+ actual_inference_steps = len(parsed_timesteps) if parsed_timesteps else req.inference_steps
1095
+
1096
+ # Auto-select instruction based on task_type if user didn't provide custom instruction
1097
+ # This matches gradio behavior which uses TASK_INSTRUCTIONS for each task type
1098
+ instruction_to_use = req.instruction
1099
+ if instruction_to_use == DEFAULT_DIT_INSTRUCTION and req.task_type in TASK_INSTRUCTIONS:
1100
+ instruction_to_use = TASK_INSTRUCTIONS[req.task_type]
1101
+
1102
+ # Build GenerationParams using unified interface
1103
+ # Note: thinking controls LM code generation, sample_mode only affects CoT metas
1104
+ params = GenerationParams(
1105
+ task_type=req.task_type,
1106
+ instruction=instruction_to_use,
1107
+ reference_audio=req.reference_audio_path,
1108
+ src_audio=req.src_audio_path,
1109
+ audio_codes=req.audio_code_string,
1110
+ caption=caption,
1111
+ lyrics=lyrics,
1112
+ instrumental=_is_instrumental(lyrics),
1113
+ vocal_language=req.vocal_language,
1114
+ bpm=bpm,
1115
+ keyscale=key_scale,
1116
+ timesignature=time_signature,
1117
+ duration=audio_duration if audio_duration else -1.0,
1118
+ inference_steps=req.inference_steps,
1119
+ seed=req.seed,
1120
+ guidance_scale=req.guidance_scale,
1121
+ use_adg=req.use_adg,
1122
+ cfg_interval_start=req.cfg_interval_start,
1123
+ cfg_interval_end=req.cfg_interval_end,
1124
+ shift=req.shift,
1125
+ infer_method=req.infer_method,
1126
+ timesteps=parsed_timesteps,
1127
+ repainting_start=req.repainting_start,
1128
+ repainting_end=req.repainting_end if req.repainting_end else -1,
1129
+ audio_cover_strength=req.audio_cover_strength,
1130
+ # LM parameters
1131
+ thinking=thinking, # Use LM for code generation when thinking=True
1132
+ lm_temperature=req.lm_temperature,
1133
+ lm_cfg_scale=req.lm_cfg_scale,
1134
+ lm_top_k=lm_top_k,
1135
+ lm_top_p=lm_top_p,
1136
+ lm_negative_prompt=req.lm_negative_prompt,
1137
+ # use_cot_metas logic:
1138
+ # - sample_mode: metas already generated, skip Phase 1
1139
+ # - format with duration: metas already generated, skip Phase 1
1140
+ # - format without duration: need Phase 1 to generate duration
1141
+ # - no format: need Phase 1 to generate all metas
1142
+ use_cot_metas=not sample_mode and not format_has_duration,
1143
+ use_cot_caption=req.use_cot_caption,
1144
+ use_cot_language=req.use_cot_language,
1145
+ use_constrained_decoding=True,
1146
+ )
1147
+
1148
+ # Build GenerationConfig - default to 2 audios like gradio_ui
1149
+ batch_size = req.batch_size if req.batch_size is not None else 2
1150
+ config = GenerationConfig(
1151
+ batch_size=batch_size,
1152
+ use_random_seed=req.use_random_seed,
1153
+ seeds=None, # Let unified logic handle seed generation
1154
+ audio_format=req.audio_format,
1155
+ constrained_decoding_debug=req.constrained_decoding_debug,
1156
+ )
1157
+
1158
+ # Check LLM initialization status
1159
+ llm_is_initialized = getattr(app.state, "_llm_initialized", False)
1160
+ llm_to_pass = llm if llm_is_initialized else None
1161
+
1162
+ # Generate music using unified interface
1163
+ result = generate_music(
1164
+ dit_handler=h,
1165
+ llm_handler=llm_to_pass,
1166
+ params=params,
1167
+ config=config,
1168
+ save_dir=app.state.temp_audio_dir,
1169
+ progress=None,
1170
+ )
1171
+
1172
+ if not result.success:
1173
+ raise RuntimeError(f"Music generation failed: {result.error or result.status_message}")
1174
+
1175
+ # Extract results
1176
+ audio_paths = [audio["path"] for audio in result.audios if audio.get("path")]
1177
+ first_audio = audio_paths[0] if len(audio_paths) > 0 else None
1178
+ second_audio = audio_paths[1] if len(audio_paths) > 1 else None
1179
+
1180
+ # Get metadata from LM or CoT results
1181
+ lm_metadata = result.extra_outputs.get("lm_metadata", {})
1182
+ metas_out = _normalize_metas(lm_metadata)
1183
+
1184
+ # Update metas with actual values used
1185
+ if params.cot_bpm:
1186
+ metas_out["bpm"] = params.cot_bpm
1187
+ elif bpm:
1188
+ metas_out["bpm"] = bpm
1189
+
1190
+ if params.cot_duration:
1191
+ metas_out["duration"] = params.cot_duration
1192
+ elif audio_duration:
1193
+ metas_out["duration"] = audio_duration
1194
+
1195
+ if params.cot_keyscale:
1196
+ metas_out["keyscale"] = params.cot_keyscale
1197
+ elif key_scale:
1198
+ metas_out["keyscale"] = key_scale
1199
+
1200
+ if params.cot_timesignature:
1201
+ metas_out["timesignature"] = params.cot_timesignature
1202
+ elif time_signature:
1203
+ metas_out["timesignature"] = time_signature
1204
+
1205
+ # Store original user input in metas (not the final/modified values)
1206
+ metas_out["prompt"] = original_prompt
1207
+ metas_out["lyrics"] = original_lyrics
1208
+
1209
+ # Extract seed values for response (comma-separated for multiple audios)
1210
+ seed_values = []
1211
+ for audio in result.audios:
1212
+ audio_params = audio.get("params", {})
1213
+ seed = audio_params.get("seed")
1214
+ if seed is not None:
1215
+ seed_values.append(str(seed))
1216
+ seed_value = ",".join(seed_values) if seed_values else ""
1217
+
1218
+ # Build generation_info using the helper function (like gradio_ui)
1219
+ time_costs = result.extra_outputs.get("time_costs", {})
1220
+ generation_info = _build_generation_info(
1221
+ lm_metadata=lm_metadata,
1222
+ time_costs=time_costs,
1223
+ seed_value=seed_value,
1224
+ inference_steps=req.inference_steps,
1225
+ num_audios=len(result.audios),
1226
+ )
1227
+
1228
+ def _none_if_na_str(v: Any) -> Optional[str]:
1229
+ if v is None:
1230
+ return None
1231
+ s = str(v).strip()
1232
+ if s in {"", "N/A"}:
1233
+ return None
1234
+ return s
1235
+
1236
+ # Get model information
1237
+ lm_model_name = os.getenv("ACESTEP_LM_MODEL_PATH", "acestep-5Hz-lm-0.6B")
1238
+ # Use selected_model_name (set at the beginning of _run_one_job)
1239
+ dit_model_name = selected_model_name
1240
+
1241
+ return {
1242
+ "first_audio_path": _path_to_audio_url(first_audio) if first_audio else None,
1243
+ "second_audio_path": _path_to_audio_url(second_audio) if second_audio else None,
1244
+ "audio_paths": [_path_to_audio_url(p) for p in audio_paths],
1245
+ "generation_info": generation_info,
1246
+ "status_message": result.status_message,
1247
+ "seed_value": seed_value,
1248
+ # Final prompt/lyrics (may be modified by thinking/format)
1249
+ "prompt": caption or "",
1250
+ "lyrics": lyrics or "",
1251
+ # metas contains original user input + other metadata
1252
+ "metas": metas_out,
1253
+ "bpm": metas_out.get("bpm") if isinstance(metas_out.get("bpm"), int) else None,
1254
+ "duration": metas_out.get("duration") if isinstance(metas_out.get("duration"), (int, float)) else None,
1255
+ "genres": _none_if_na_str(metas_out.get("genres")),
1256
+ "keyscale": _none_if_na_str(metas_out.get("keyscale")),
1257
+ "timesignature": _none_if_na_str(metas_out.get("timesignature")),
1258
+ "lm_model": lm_model_name,
1259
+ "dit_model": dit_model_name,
1260
+ }
1261
+
1262
+ t0 = time.time()
1263
+ try:
1264
+ loop = asyncio.get_running_loop()
1265
+ result = await loop.run_in_executor(executor, _blocking_generate)
1266
+ job_store.mark_succeeded(job_id, result)
1267
+
1268
+ # Update local cache
1269
+ _update_local_cache(job_id, result, "succeeded")
1270
+ except Exception:
1271
+ job_store.mark_failed(job_id, traceback.format_exc())
1272
+
1273
+ # Update local cache
1274
+ _update_local_cache(job_id, None, "failed")
1275
+ finally:
1276
+ dt = max(0.0, time.time() - t0)
1277
+ async with app.state.stats_lock:
1278
+ app.state.recent_durations.append(dt)
1279
+ if app.state.recent_durations:
1280
+ app.state.avg_job_seconds = sum(app.state.recent_durations) / len(app.state.recent_durations)
1281
+
1282
+ async def _queue_worker(worker_idx: int) -> None:
1283
+ while True:
1284
+ job_id, req = await app.state.job_queue.get()
1285
+ try:
1286
+ async with app.state.pending_lock:
1287
+ try:
1288
+ app.state.pending_ids.remove(job_id)
1289
+ except ValueError:
1290
+ pass
1291
+
1292
+ await _run_one_job(job_id, req)
1293
+ finally:
1294
+ await _cleanup_job_temp_files(job_id)
1295
+ app.state.job_queue.task_done()
1296
+
1297
+ async def _job_store_cleanup_worker() -> None:
1298
+ """Background task to periodically clean up old completed jobs."""
1299
+ while True:
1300
+ try:
1301
+ await asyncio.sleep(JOB_STORE_CLEANUP_INTERVAL)
1302
+ removed = store.cleanup_old_jobs()
1303
+ if removed > 0:
1304
+ stats = store.get_stats()
1305
+ print(f"[API Server] Cleaned up {removed} old jobs. Current stats: {stats}")
1306
+ except asyncio.CancelledError:
1307
+ break
1308
+ except Exception as e:
1309
+ print(f"[API Server] Job cleanup error: {e}")
1310
+
1311
+ worker_count = max(1, WORKER_COUNT)
1312
+ workers = [asyncio.create_task(_queue_worker(i)) for i in range(worker_count)]
1313
+ cleanup_task = asyncio.create_task(_job_store_cleanup_worker())
1314
+ app.state.worker_tasks = workers
1315
+ app.state.cleanup_task = cleanup_task
1316
+
1317
+ try:
1318
+ yield
1319
+ finally:
1320
+ cleanup_task.cancel()
1321
+ for t in workers:
1322
+ t.cancel()
1323
+ executor.shutdown(wait=False, cancel_futures=True)
1324
+
1325
+ app = FastAPI(title="ACE-Step API", version="1.0", lifespan=lifespan)
1326
+
1327
+ async def _queue_position(job_id: str) -> int:
1328
+ async with app.state.pending_lock:
1329
+ try:
1330
+ return list(app.state.pending_ids).index(job_id) + 1
1331
+ except ValueError:
1332
+ return 0
1333
+
1334
+ async def _eta_seconds_for_position(pos: int) -> Optional[float]:
1335
+ if pos <= 0:
1336
+ return None
1337
+ async with app.state.stats_lock:
1338
+ avg = float(getattr(app.state, "avg_job_seconds", INITIAL_AVG_JOB_SECONDS))
1339
+ return pos * avg
1340
+
1341
+ @app.post("/release_task", response_model=CreateJobResponse)
1342
+ async def create_music_generate_job(request: Request) -> CreateJobResponse:
1343
+ content_type = (request.headers.get("content-type") or "").lower()
1344
+ temp_files: list[str] = []
1345
+
1346
+ def _build_request(p: RequestParser, **kwargs) -> GenerateMusicRequest:
1347
+ """Build GenerateMusicRequest from parsed parameters."""
1348
+ return GenerateMusicRequest(
1349
+ prompt=p.str("prompt"),
1350
+ lyrics=p.str("lyrics"),
1351
+ thinking=p.bool("thinking"),
1352
+ sample_mode=p.bool("sample_mode"),
1353
+ sample_query=p.str("sample_query"),
1354
+ use_format=p.bool("use_format"),
1355
+ model=p.str("model") or None,
1356
+ bpm=p.int("bpm"),
1357
+ key_scale=p.str("key_scale"),
1358
+ time_signature=p.str("time_signature"),
1359
+ audio_duration=p.float("audio_duration"),
1360
+ vocal_language=p.str("vocal_language", "en"),
1361
+ inference_steps=p.int("inference_steps", 8),
1362
+ guidance_scale=p.float("guidance_scale", 7.0),
1363
+ use_random_seed=p.bool("use_random_seed", True),
1364
+ seed=p.int("seed", -1),
1365
+ batch_size=p.int("batch_size"),
1366
+ audio_code_string=p.str("audio_code_string"),
1367
+ repainting_start=p.float("repainting_start", 0.0),
1368
+ repainting_end=p.float("repainting_end"),
1369
+ instruction=p.str("instruction", DEFAULT_DIT_INSTRUCTION),
1370
+ audio_cover_strength=p.float("audio_cover_strength", 1.0),
1371
+ task_type=p.str("task_type", "text2music"),
1372
+ use_adg=p.bool("use_adg"),
1373
+ cfg_interval_start=p.float("cfg_interval_start", 0.0),
1374
+ cfg_interval_end=p.float("cfg_interval_end", 1.0),
1375
+ infer_method=p.str("infer_method", "ode"),
1376
+ shift=p.float("shift", 3.0),
1377
+ audio_format=p.str("audio_format", "mp3"),
1378
+ use_tiled_decode=p.bool("use_tiled_decode", True),
1379
+ lm_model_path=p.str("lm_model_path") or None,
1380
+ lm_backend=p.str("lm_backend", "vllm"),
1381
+ lm_temperature=p.float("lm_temperature", LM_DEFAULT_TEMPERATURE),
1382
+ lm_cfg_scale=p.float("lm_cfg_scale", LM_DEFAULT_CFG_SCALE),
1383
+ lm_top_k=p.int("lm_top_k"),
1384
+ lm_top_p=p.float("lm_top_p", LM_DEFAULT_TOP_P),
1385
+ lm_repetition_penalty=p.float("lm_repetition_penalty", 1.0),
1386
+ lm_negative_prompt=p.str("lm_negative_prompt", "NO USER INPUT"),
1387
+ constrained_decoding=p.bool("constrained_decoding", True),
1388
+ constrained_decoding_debug=p.bool("constrained_decoding_debug"),
1389
+ use_cot_caption=p.bool("use_cot_caption", True),
1390
+ use_cot_language=p.bool("use_cot_language", True),
1391
+ is_format_caption=p.bool("is_format_caption"),
1392
+ **kwargs,
1393
+ )
1394
+
1395
+ if content_type.startswith("application/json"):
1396
+ body = await request.json()
1397
+ if not isinstance(body, dict):
1398
+ raise HTTPException(status_code=400, detail="JSON payload must be an object")
1399
+ req = _build_request(RequestParser(body))
1400
+
1401
+ elif content_type.endswith("+json"):
1402
+ body = await request.json()
1403
+ if not isinstance(body, dict):
1404
+ raise HTTPException(status_code=400, detail="JSON payload must be an object")
1405
+ req = _build_request(RequestParser(body))
1406
+
1407
+ elif content_type.startswith("multipart/form-data"):
1408
+ form = await request.form()
1409
+
1410
+ ref_up = form.get("reference_audio")
1411
+ src_up = form.get("src_audio")
1412
+
1413
+ reference_audio_path = None
1414
+ src_audio_path = None
1415
+
1416
+ if isinstance(ref_up, StarletteUploadFile):
1417
+ reference_audio_path = await _save_upload_to_temp(ref_up, prefix="reference_audio")
1418
+ temp_files.append(reference_audio_path)
1419
+ else:
1420
+ reference_audio_path = str(form.get("reference_audio_path") or "").strip() or None
1421
+
1422
+ if isinstance(src_up, StarletteUploadFile):
1423
+ src_audio_path = await _save_upload_to_temp(src_up, prefix="src_audio")
1424
+ temp_files.append(src_audio_path)
1425
+ else:
1426
+ src_audio_path = str(form.get("src_audio_path") or "").strip() or None
1427
+
1428
+ req = _build_request(
1429
+ RequestParser(dict(form)),
1430
+ reference_audio_path=reference_audio_path,
1431
+ src_audio_path=src_audio_path,
1432
+ )
1433
+
1434
+ elif content_type.startswith("application/x-www-form-urlencoded"):
1435
+ form = await request.form()
1436
+ reference_audio_path = str(form.get("reference_audio_path") or "").strip() or None
1437
+ src_audio_path = str(form.get("src_audio_path") or "").strip() or None
1438
+ req = _build_request(
1439
+ RequestParser(dict(form)),
1440
+ reference_audio_path=reference_audio_path,
1441
+ src_audio_path=src_audio_path,
1442
+ )
1443
+
1444
+ else:
1445
+ raw = await request.body()
1446
+ raw_stripped = raw.lstrip()
1447
+ # Best-effort: accept missing/incorrect Content-Type if payload is valid JSON.
1448
+ if raw_stripped.startswith(b"{") or raw_stripped.startswith(b"["):
1449
+ try:
1450
+ body = json.loads(raw.decode("utf-8"))
1451
+ if isinstance(body, dict):
1452
+ req = _build_request(RequestParser(body))
1453
+ else:
1454
+ raise HTTPException(status_code=400, detail="JSON payload must be an object")
1455
+ except HTTPException:
1456
+ raise
1457
+ except Exception:
1458
+ raise HTTPException(
1459
+ status_code=400,
1460
+ detail="Invalid JSON body (hint: set 'Content-Type: application/json')",
1461
+ )
1462
+ # Best-effort: parse key=value bodies even if Content-Type is missing.
1463
+ elif raw_stripped and b"=" in raw:
1464
+ parsed = urllib.parse.parse_qs(raw.decode("utf-8"), keep_blank_values=True)
1465
+ flat = {k: (v[0] if isinstance(v, list) and v else v) for k, v in parsed.items()}
1466
+ reference_audio_path = str(flat.get("reference_audio_path") or "").strip() or None
1467
+ src_audio_path = str(flat.get("src_audio_path") or "").strip() or None
1468
+ req = _build_request(
1469
+ RequestParser(flat),
1470
+ reference_audio_path=reference_audio_path,
1471
+ src_audio_path=src_audio_path,
1472
+ )
1473
+ else:
1474
+ raise HTTPException(
1475
+ status_code=415,
1476
+ detail=(
1477
+ f"Unsupported Content-Type: {content_type or '(missing)'}; "
1478
+ "use application/json, application/x-www-form-urlencoded, or multipart/form-data"
1479
+ ),
1480
+ )
1481
+
1482
+ rec = store.create()
1483
+
1484
+ q: asyncio.Queue = app.state.job_queue
1485
+ if q.full():
1486
+ for p in temp_files:
1487
+ try:
1488
+ os.remove(p)
1489
+ except Exception:
1490
+ pass
1491
+ raise HTTPException(status_code=429, detail="Server busy: queue is full")
1492
+
1493
+ if temp_files:
1494
+ async with app.state.job_temp_files_lock:
1495
+ app.state.job_temp_files[rec.job_id] = temp_files
1496
+
1497
+ async with app.state.pending_lock:
1498
+ app.state.pending_ids.append(rec.job_id)
1499
+ position = len(app.state.pending_ids)
1500
+
1501
+ await q.put((rec.job_id, req))
1502
+ return CreateJobResponse(task_id=rec.job_id, status="queued", queue_position=position)
1503
+
1504
+ @app.post("/query_result")
1505
+ async def query_result(request: Request) -> List[Dict[str, Any]]:
1506
+ """Batch query job results"""
1507
+ content_type = (request.headers.get("content-type") or "").lower()
1508
+
1509
+ if "json" in content_type:
1510
+ body = await request.json()
1511
+ else:
1512
+ form = await request.form()
1513
+ body = {k: v for k, v in form.items()}
1514
+
1515
+ task_id_list_str = body.get("task_id_list", "[]")
1516
+
1517
+ # Parse task ID list
1518
+ if isinstance(task_id_list_str, list):
1519
+ task_id_list = task_id_list_str
1520
+ else:
1521
+ try:
1522
+ task_id_list = json.loads(task_id_list_str)
1523
+ except Exception:
1524
+ task_id_list = []
1525
+
1526
+ local_cache = getattr(app.state, 'local_cache', None)
1527
+ data_list = []
1528
+ current_time = time.time()
1529
+
1530
+ for task_id in task_id_list:
1531
+ result_key = f"{RESULT_KEY_PREFIX}{task_id}"
1532
+
1533
+ # Read from local cache first
1534
+ if local_cache:
1535
+ data = local_cache.get(result_key)
1536
+ if data:
1537
+ try:
1538
+ data_json = json.loads(data)
1539
+ except Exception:
1540
+ data_json = []
1541
+
1542
+ if len(data_json) <= 0:
1543
+ data_list.append({"task_id": task_id, "result": data, "status": 2})
1544
+ else:
1545
+ status = data_json[0].get("status")
1546
+ create_time = data_json[0].get("create_time", 0)
1547
+ if status == 0 and (current_time - create_time) > TASK_TIMEOUT_SECONDS:
1548
+ data_list.append({"task_id": task_id, "result": data, "status": 2})
1549
+ else:
1550
+ data_list.append({
1551
+ "task_id": task_id,
1552
+ "result": data,
1553
+ "status": int(status) if status is not None else 1,
1554
+ })
1555
+ continue
1556
+
1557
+ # Fallback to job_store query
1558
+ rec = store.get(task_id)
1559
+ if rec:
1560
+ env = getattr(rec, 'env', 'development')
1561
+ create_time = rec.created_at
1562
+ status_int = _map_status(rec.status)
1563
+
1564
+ if rec.result and rec.status == "succeeded":
1565
+ audio_paths = rec.result.get("audio_paths", [])
1566
+ metas = rec.result.get("metas", {}) or {}
1567
+ result_data = [
1568
+ {
1569
+ "file": p, "wave": "", "status": status_int,
1570
+ "create_time": int(create_time), "env": env,
1571
+ "prompt": metas.get("caption", ""),
1572
+ "lyrics": metas.get("lyrics", ""),
1573
+ "metas": {
1574
+ "bpm": metas.get("bpm"),
1575
+ "duration": metas.get("duration"),
1576
+ "genres": metas.get("genres", ""),
1577
+ "keyscale": metas.get("keyscale", ""),
1578
+ "timesignature": metas.get("timesignature", ""),
1579
+ }
1580
+ }
1581
+ for p in audio_paths
1582
+ ] if audio_paths else [{
1583
+ "file": "", "wave": "", "status": status_int,
1584
+ "create_time": int(create_time), "env": env,
1585
+ "prompt": metas.get("caption", ""),
1586
+ "lyrics": metas.get("lyrics", ""),
1587
+ "metas": {
1588
+ "bpm": metas.get("bpm"),
1589
+ "duration": metas.get("duration"),
1590
+ "genres": metas.get("genres", ""),
1591
+ "keyscale": metas.get("keyscale", ""),
1592
+ "timesignature": metas.get("timesignature", ""),
1593
+ }
1594
+ }]
1595
+ else:
1596
+ result_data = [{
1597
+ "file": "", "wave": "", "status": status_int,
1598
+ "create_time": int(create_time), "env": env,
1599
+ "prompt": "", "lyrics": "",
1600
+ "metas": {}
1601
+ }]
1602
+
1603
+ data_list.append({
1604
+ "task_id": task_id,
1605
+ "result": json.dumps(result_data, ensure_ascii=False),
1606
+ "status": status_int,
1607
+ })
1608
+ else:
1609
+ data_list.append({"task_id": task_id, "result": "[]", "status": 0})
1610
+
1611
+ return data_list
1612
+
1613
+ @app.get("/health")
1614
+ async def health_check():
1615
+ """Health check endpoint for service status."""
1616
+ return {
1617
+ "status": "ok",
1618
+ "service": "ACE-Step API",
1619
+ "version": "1.0",
1620
+ }
1621
+
1622
+ @app.get("/v1/stats")
1623
+ async def get_stats():
1624
+ """Get server statistics including job store stats."""
1625
+ job_stats = store.get_stats()
1626
+ async with app.state.stats_lock:
1627
+ avg_job_seconds = getattr(app.state, "avg_job_seconds", INITIAL_AVG_JOB_SECONDS)
1628
+ return {
1629
+ "jobs": job_stats,
1630
+ "queue_size": app.state.job_queue.qsize(),
1631
+ "queue_maxsize": QUEUE_MAXSIZE,
1632
+ "avg_job_seconds": avg_job_seconds,
1633
+ }
1634
+
1635
+ @app.get("/v1/models")
1636
+ async def list_models():
1637
+ """List available DiT models."""
1638
+ models = []
1639
+
1640
+ # Primary model (always available if initialized)
1641
+ if getattr(app.state, "_initialized", False):
1642
+ primary_model = _get_model_name(app.state._config_path)
1643
+ if primary_model:
1644
+ models.append({
1645
+ "name": primary_model,
1646
+ "is_default": True,
1647
+ })
1648
+
1649
+ # Secondary model
1650
+ if getattr(app.state, "_initialized2", False) and app.state._config_path2:
1651
+ secondary_model = _get_model_name(app.state._config_path2)
1652
+ if secondary_model:
1653
+ models.append({
1654
+ "name": secondary_model,
1655
+ "is_default": False,
1656
+ })
1657
+
1658
+ # Third model
1659
+ if getattr(app.state, "_initialized3", False) and app.state._config_path3:
1660
+ third_model = _get_model_name(app.state._config_path3)
1661
+ if third_model:
1662
+ models.append({
1663
+ "name": third_model,
1664
+ "is_default": False,
1665
+ })
1666
+
1667
+ return {
1668
+ "models": models,
1669
+ "default_model": models[0]["name"] if models else None,
1670
+ }
1671
+
1672
+ @app.get("/v1/audio")
1673
+ async def get_audio(path: str):
1674
+ """Serve audio file by path."""
1675
+ from fastapi.responses import FileResponse
1676
+
1677
+ if not os.path.exists(path):
1678
+ raise HTTPException(status_code=404, detail=f"Audio file not found: {path}")
1679
+
1680
+ ext = os.path.splitext(path)[1].lower()
1681
+ media_types = {
1682
+ ".mp3": "audio/mpeg",
1683
+ ".wav": "audio/wav",
1684
+ ".flac": "audio/flac",
1685
+ ".ogg": "audio/ogg",
1686
+ }
1687
+ media_type = media_types.get(ext, "audio/mpeg")
1688
+
1689
+ return FileResponse(path, media_type=media_type)
1690
+
1691
+ return app
1692
+
1693
+
1694
+ app = create_app()
1695
+
1696
+
1697
+ def main() -> None:
1698
+ import argparse
1699
+ import uvicorn
1700
+
1701
+ parser = argparse.ArgumentParser(description="ACE-Step API server")
1702
+ parser.add_argument(
1703
+ "--host",
1704
+ default=os.getenv("ACESTEP_API_HOST", "127.0.0.1"),
1705
+ help="Bind host (default from ACESTEP_API_HOST or 127.0.0.1)",
1706
+ )
1707
+ parser.add_argument(
1708
+ "--port",
1709
+ type=int,
1710
+ default=int(os.getenv("ACESTEP_API_PORT", "8001")),
1711
+ help="Bind port (default from ACESTEP_API_PORT or 8001)",
1712
+ )
1713
+ args = parser.parse_args()
1714
+
1715
+ # IMPORTANT: in-memory queue/store -> workers MUST be 1
1716
+ uvicorn.run(
1717
+ "acestep.api_server:app",
1718
+ host=str(args.host),
1719
+ port=int(args.port),
1720
+ reload=False,
1721
+ workers=1,
1722
+ )
1723
+
1724
+ if __name__ == "__main__":
1725
+ main()
code/acestep/audio_utils.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Audio saving and transcoding utility module
3
+
4
+ Independent audio file operations outside of handler, supporting:
5
+ - Save audio tensor/numpy to files (default FLAC format, fast)
6
+ - Format conversion (FLAC/WAV/MP3)
7
+ - Batch processing
8
+ """
9
+
10
+ import os
11
+ import hashlib
12
+ import json
13
+ from pathlib import Path
14
+ from typing import Union, Optional, List, Tuple
15
+ import torch
16
+ import numpy as np
17
+ import torchaudio
18
+ from loguru import logger
19
+
20
+
21
+ class AudioSaver:
22
+ """Audio saving and transcoding utility class"""
23
+
24
+ def __init__(self, default_format: str = "flac"):
25
+ """
26
+ Initialize audio saver
27
+
28
+ Args:
29
+ default_format: Default save format ('flac', 'wav', 'mp3')
30
+ """
31
+ self.default_format = default_format.lower()
32
+ if self.default_format not in ["flac", "wav", "mp3"]:
33
+ logger.warning(f"Unsupported format {default_format}, using 'flac'")
34
+ self.default_format = "flac"
35
+
36
+ def save_audio(
37
+ self,
38
+ audio_data: Union[torch.Tensor, np.ndarray],
39
+ output_path: Union[str, Path],
40
+ sample_rate: int = 48000,
41
+ format: Optional[str] = None,
42
+ channels_first: bool = True,
43
+ ) -> str:
44
+ """
45
+ Save audio data to file
46
+
47
+ Args:
48
+ audio_data: Audio data, torch.Tensor [channels, samples] or numpy.ndarray
49
+ output_path: Output file path (extension can be omitted)
50
+ sample_rate: Sample rate
51
+ format: Audio format ('flac', 'wav', 'mp3'), defaults to default_format
52
+ channels_first: If True, tensor format is [channels, samples], else [samples, channels]
53
+
54
+ Returns:
55
+ Actual saved file path
56
+ """
57
+ format = (format or self.default_format).lower()
58
+ if format not in ["flac", "wav", "mp3"]:
59
+ logger.warning(f"Unsupported format {format}, using {self.default_format}")
60
+ format = self.default_format
61
+
62
+ # Ensure output path has correct extension
63
+ output_path = Path(output_path)
64
+ if output_path.suffix.lower() not in ['.flac', '.wav', '.mp3']:
65
+ output_path = output_path.with_suffix(f'.{format}')
66
+
67
+ # Convert to torch tensor
68
+ if isinstance(audio_data, np.ndarray):
69
+ if channels_first:
70
+ # numpy [samples, channels] -> tensor [channels, samples]
71
+ audio_tensor = torch.from_numpy(audio_data.T).float()
72
+ else:
73
+ # numpy [samples, channels] -> tensor [samples, channels] -> [channels, samples]
74
+ audio_tensor = torch.from_numpy(audio_data).float()
75
+ if audio_tensor.dim() == 2 and audio_tensor.shape[0] < audio_tensor.shape[1]:
76
+ audio_tensor = audio_tensor.T
77
+ else:
78
+ # torch tensor
79
+ audio_tensor = audio_data.cpu().float()
80
+ if not channels_first and audio_tensor.dim() == 2:
81
+ # [samples, channels] -> [channels, samples]
82
+ if audio_tensor.shape[0] > audio_tensor.shape[1]:
83
+ audio_tensor = audio_tensor.T
84
+
85
+ # Ensure memory is contiguous
86
+ audio_tensor = audio_tensor.contiguous()
87
+
88
+ # Select backend and save
89
+ try:
90
+ if format == "mp3":
91
+ # MP3 uses ffmpeg backend
92
+ torchaudio.save(
93
+ str(output_path),
94
+ audio_tensor,
95
+ sample_rate,
96
+ channels_first=True,
97
+ backend='ffmpeg',
98
+ )
99
+ elif format in ["flac", "wav"]:
100
+ # FLAC and WAV use soundfile backend (fastest)
101
+ torchaudio.save(
102
+ str(output_path),
103
+ audio_tensor,
104
+ sample_rate,
105
+ channels_first=True,
106
+ backend='soundfile',
107
+ )
108
+ else:
109
+ # Other formats use default backend
110
+ torchaudio.save(
111
+ str(output_path),
112
+ audio_tensor,
113
+ sample_rate,
114
+ channels_first=True,
115
+ )
116
+
117
+ logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
118
+ return str(output_path)
119
+
120
+ except Exception as e:
121
+ try:
122
+ import soundfile as sf
123
+ audio_np = audio_tensor.transpose(0, 1).numpy() # -> [samples, channels]
124
+ sf.write(str(output_path), audio_np, sample_rate, format=format.upper())
125
+ logger.debug(f"[AudioSaver] Fallback soundfile Saved audio to {output_path} ({format}, {sample_rate}Hz)")
126
+ return str(output_path)
127
+ except Exception as e:
128
+ logger.error(f"[AudioSaver] Failed to save audio: {e}")
129
+ raise
130
+
131
+ def convert_audio(
132
+ self,
133
+ input_path: Union[str, Path],
134
+ output_path: Union[str, Path],
135
+ output_format: str,
136
+ remove_input: bool = False,
137
+ ) -> str:
138
+ """
139
+ Convert audio format
140
+
141
+ Args:
142
+ input_path: Input audio file path
143
+ output_path: Output audio file path
144
+ output_format: Target format ('flac', 'wav', 'mp3')
145
+ remove_input: Whether to delete input file
146
+
147
+ Returns:
148
+ Output file path
149
+ """
150
+ input_path = Path(input_path)
151
+ output_path = Path(output_path)
152
+
153
+ if not input_path.exists():
154
+ raise FileNotFoundError(f"Input file not found: {input_path}")
155
+
156
+ # Load audio
157
+ audio_tensor, sample_rate = torchaudio.load(str(input_path))
158
+
159
+ # Save as new format
160
+ output_path = self.save_audio(
161
+ audio_tensor,
162
+ output_path,
163
+ sample_rate=sample_rate,
164
+ format=output_format,
165
+ channels_first=True
166
+ )
167
+
168
+ # Delete input file if needed
169
+ if remove_input:
170
+ input_path.unlink()
171
+ logger.debug(f"[AudioSaver] Removed input file: {input_path}")
172
+
173
+ return output_path
174
+
175
+ def save_batch(
176
+ self,
177
+ audio_batch: Union[List[torch.Tensor], torch.Tensor],
178
+ output_dir: Union[str, Path],
179
+ file_prefix: str = "audio",
180
+ sample_rate: int = 48000,
181
+ format: Optional[str] = None,
182
+ channels_first: bool = True,
183
+ ) -> List[str]:
184
+ """
185
+ Save audio batch
186
+
187
+ Args:
188
+ audio_batch: Audio batch, List[tensor] or tensor [batch, channels, samples]
189
+ output_dir: Output directory
190
+ file_prefix: File prefix
191
+ sample_rate: Sample rate
192
+ format: Audio format
193
+ channels_first: Tensor format flag
194
+
195
+ Returns:
196
+ List of saved file paths
197
+ """
198
+ output_dir = Path(output_dir)
199
+ output_dir.mkdir(parents=True, exist_ok=True)
200
+
201
+ # Process batch
202
+ if isinstance(audio_batch, torch.Tensor) and audio_batch.dim() == 3:
203
+ # [batch, channels, samples]
204
+ audio_list = [audio_batch[i] for i in range(audio_batch.shape[0])]
205
+ elif isinstance(audio_batch, list):
206
+ audio_list = audio_batch
207
+ else:
208
+ audio_list = [audio_batch]
209
+
210
+ saved_paths = []
211
+ for i, audio in enumerate(audio_list):
212
+ output_path = output_dir / f"{file_prefix}_{i:04d}"
213
+ saved_path = self.save_audio(
214
+ audio,
215
+ output_path,
216
+ sample_rate=sample_rate,
217
+ format=format,
218
+ channels_first=channels_first
219
+ )
220
+ saved_paths.append(saved_path)
221
+
222
+ return saved_paths
223
+
224
+
225
+ def get_audio_file_hash(audio_file) -> str:
226
+ """
227
+ Get hash identifier for an audio file.
228
+
229
+ Args:
230
+ audio_file: Path to audio file (str) or file-like object
231
+
232
+ Returns:
233
+ Hash string or empty string
234
+ """
235
+ if audio_file is None:
236
+ return ""
237
+
238
+ try:
239
+ if isinstance(audio_file, str):
240
+ if os.path.exists(audio_file):
241
+ with open(audio_file, 'rb') as f:
242
+ return hashlib.md5(f.read()).hexdigest()
243
+ return hashlib.md5(audio_file.encode('utf-8')).hexdigest()
244
+ elif hasattr(audio_file, 'name'):
245
+ return hashlib.md5(str(audio_file.name).encode('utf-8')).hexdigest()
246
+ return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
247
+ except Exception:
248
+ return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
249
+
250
+
251
+ def generate_uuid_from_params(params_dict) -> str:
252
+ """
253
+ Generate deterministic UUID from generation parameters.
254
+ Same parameters will always generate the same UUID.
255
+
256
+ Args:
257
+ params_dict: Dictionary of parameters
258
+
259
+ Returns:
260
+ UUID string
261
+ """
262
+
263
+ params_json = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
264
+ hash_obj = hashlib.sha256(params_json.encode('utf-8'))
265
+ hash_hex = hash_obj.hexdigest()
266
+ uuid_str = f"{hash_hex[0:8]}-{hash_hex[8:12]}-{hash_hex[12:16]}-{hash_hex[16:20]}-{hash_hex[20:32]}"
267
+ return uuid_str
268
+
269
+
270
+ def generate_uuid_from_audio_data(
271
+ audio_data: Union[torch.Tensor, np.ndarray],
272
+ seed: Optional[int] = None
273
+ ) -> str:
274
+ """
275
+ Generate UUID from audio data (for caching/deduplication)
276
+
277
+ Args:
278
+ audio_data: Audio data
279
+ seed: Optional seed value
280
+
281
+ Returns:
282
+ UUID string
283
+ """
284
+ if isinstance(audio_data, torch.Tensor):
285
+ # Convert to numpy and calculate hash
286
+ audio_np = audio_data.cpu().numpy()
287
+ else:
288
+ audio_np = audio_data
289
+
290
+ # Calculate data hash
291
+ data_hash = hashlib.md5(audio_np.tobytes()).hexdigest()
292
+
293
+ if seed is not None:
294
+ combined = f"{data_hash}_{seed}"
295
+ return hashlib.md5(combined.encode()).hexdigest()
296
+
297
+ return data_hash
298
+
299
+
300
+ # Global default instance
301
+ _default_saver = AudioSaver(default_format="flac")
302
+
303
+
304
+ def save_audio(
305
+ audio_data: Union[torch.Tensor, np.ndarray],
306
+ output_path: Union[str, Path],
307
+ sample_rate: int = 48000,
308
+ format: Optional[str] = None,
309
+ channels_first: bool = True,
310
+ ) -> str:
311
+ """
312
+ Convenience function: save audio (using default configuration)
313
+
314
+ Args:
315
+ audio_data: Audio data
316
+ output_path: Output path
317
+ sample_rate: Sample rate
318
+ format: Format (default flac)
319
+ channels_first: Tensor format flag
320
+
321
+ Returns:
322
+ Saved file path
323
+ """
324
+ return _default_saver.save_audio(
325
+ audio_data, output_path, sample_rate, format, channels_first
326
+ )
327
+
code/acestep/constants.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Constants for ACE-Step
3
+ Centralized constants used across the codebase
4
+ """
5
+
6
+ # ==============================================================================
7
+ # Language Constants
8
+ # ==============================================================================
9
+
10
+ VALID_LANGUAGES = [
11
+ 'ar', 'az', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en',
12
+ 'es', 'fa', 'fi', 'fr', 'he', 'hi', 'hr', 'ht', 'hu', 'id',
13
+ 'is', 'it', 'ja', 'ko', 'la', 'lt', 'ms', 'ne', 'nl', 'no',
14
+ 'pa', 'pl', 'pt', 'ro', 'ru', 'sa', 'sk', 'sr', 'sv', 'sw',
15
+ 'ta', 'te', 'th', 'tl', 'tr', 'uk', 'ur', 'vi', 'yue', 'zh',
16
+ 'unknown'
17
+ ]
18
+
19
+
20
+ # ==============================================================================
21
+ # Keyscale Constants
22
+ # ==============================================================================
23
+
24
+ KEYSCALE_NOTES = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
25
+ KEYSCALE_ACCIDENTALS = ['', '#', 'b', '♯', '♭'] # empty + ASCII sharp/flat + Unicode sharp/flat
26
+ KEYSCALE_MODES = ['major', 'minor']
27
+
28
+ # Generate all valid keyscales: 7 notes × 5 accidentals × 2 modes = 70 combinations
29
+ VALID_KEYSCALES = set()
30
+ for note in KEYSCALE_NOTES:
31
+ for acc in KEYSCALE_ACCIDENTALS:
32
+ for mode in KEYSCALE_MODES:
33
+ VALID_KEYSCALES.add(f"{note}{acc} {mode}")
34
+
35
+
36
+ # ==============================================================================
37
+ # Metadata Range Constants
38
+ # ==============================================================================
39
+
40
+ # BPM (Beats Per Minute) range
41
+ BPM_MIN = 30
42
+ BPM_MAX = 300
43
+
44
+ # Duration range (in seconds)
45
+ DURATION_MIN = 10
46
+ DURATION_MAX = 600
47
+
48
+ # Valid time signatures
49
+ VALID_TIME_SIGNATURES = [2, 3, 4, 6]
50
+
51
+
52
+ # ==============================================================================
53
+ # Task Type Constants
54
+ # ==============================================================================
55
+
56
+ TASK_TYPES = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
57
+
58
+ # Task types available for turbo models (subset)
59
+ TASK_TYPES_TURBO = ["text2music", "repaint", "cover"]
60
+
61
+ # Task types available for base models (full set)
62
+ TASK_TYPES_BASE = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
63
+
64
+
65
+ # ==============================================================================
66
+ # Instruction Constants
67
+ # ==============================================================================
68
+
69
+ # Default instructions
70
+ DEFAULT_DIT_INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
71
+ DEFAULT_LM_INSTRUCTION = "Generate audio semantic tokens based on the given conditions:"
72
+ DEFAULT_LM_UNDERSTAND_INSTRUCTION = "Understand the given musical conditions and describe the audio semantics accordingly:"
73
+ DEFAULT_LM_INSPIRED_INSTRUCTION = "Expand the user's input into a more detailed and specific musical description:"
74
+ DEFAULT_LM_REWRITE_INSTRUCTION = "Format the user's input into a more detailed and specific musical description:"
75
+
76
+ # Instruction templates for each task type
77
+ # Note: Some instructions use placeholders like {TRACK_NAME} or {TRACK_CLASSES}
78
+ # These should be formatted using .format() or f-strings when used
79
+ TASK_INSTRUCTIONS = {
80
+ "text2music": "Fill the audio semantic mask based on the given conditions:",
81
+ "repaint": "Repaint the mask area based on the given conditions:",
82
+ "cover": "Generate audio semantic tokens based on the given conditions:",
83
+ "extract": "Extract the {TRACK_NAME} track from the audio:",
84
+ "extract_default": "Extract the track from the audio:",
85
+ "lego": "Generate the {TRACK_NAME} track based on the audio context:",
86
+ "lego_default": "Generate the track based on the audio context:",
87
+ "complete": "Complete the input track with {TRACK_CLASSES}:",
88
+ "complete_default": "Complete the input track:",
89
+ }
90
+
91
+
92
+ # ==============================================================================
93
+ # Track/Instrument Constants
94
+ # ==============================================================================
95
+
96
+ TRACK_NAMES = [
97
+ "woodwinds", "brass", "fx", "synth", "strings", "percussion",
98
+ "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"
99
+ ]
100
+
101
+ SFT_GEN_PROMPT = """# Instruction
102
+ {}
103
+
104
+ # Caption
105
+ {}
106
+
107
+ # Metas
108
+ {}<|endoftext|>
109
+ """
code/acestep/constrained_logits_processor.py ADDED
The diff for this file is too large to render. See raw diff
 
code/acestep/dataset_handler.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset Handler
3
+ Handles dataset import and exploration functionality
4
+ """
5
+ from typing import Optional, Tuple, Any, Dict
6
+
7
+
8
+ class DatasetHandler:
9
+ """Dataset Handler for Dataset Explorer functionality"""
10
+
11
+ def __init__(self):
12
+ """Initialize dataset handler"""
13
+ self.dataset = None
14
+ self.dataset_imported = False
15
+
16
+ def import_dataset(self, dataset_type: str) -> str:
17
+ """
18
+ Import dataset (temporarily disabled)
19
+
20
+ Args:
21
+ dataset_type: Type of dataset to import (e.g., "train", "test")
22
+
23
+ Returns:
24
+ Status message string
25
+ """
26
+ self.dataset_imported = False
27
+ return f"⚠️ Dataset import is currently disabled. Text2MusicDataset dependency not available."
28
+
29
+ def get_item_data(self, *args, **kwargs) -> Tuple:
30
+ """
31
+ Get dataset item (temporarily disabled)
32
+
33
+ Returns:
34
+ Tuple of placeholder values matching the expected return format
35
+ """
36
+ return "", "", "", "", "", None, None, None, "❌ Dataset not available", "", 0, "", None, None, None, {}, "text2music"
37
+
code/acestep/dit_alignment_score.py ADDED
@@ -0,0 +1,870 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DiT Alignment Score Module
3
+
4
+ This module provides lyrics-to-audio alignment using cross-attention matrices
5
+ from DiT model for generating LRC timestamps.
6
+
7
+ Refactored from lyrics_alignment_infos.py for integration with ACE-Step.
8
+ """
9
+ import numba
10
+ import torch
11
+ import numpy as np
12
+ import torch.nn.functional as F
13
+ from dataclasses import dataclass, asdict
14
+ from typing import List, Dict, Any, Optional, Tuple, Union
15
+
16
+
17
+ # ================= Data Classes =================
18
+ @dataclass
19
+ class TokenTimestamp:
20
+ """Stores per-token timing information."""
21
+ token_id: int
22
+ text: str
23
+ start: float
24
+ end: float
25
+ probability: float
26
+
27
+
28
+ @dataclass
29
+ class SentenceTimestamp:
30
+ """Stores per-sentence timing information with token list."""
31
+ text: str
32
+ start: float
33
+ end: float
34
+ tokens: List[TokenTimestamp]
35
+ confidence: float
36
+
37
+
38
+ # ================= DTW Algorithm (Numba Optimized) =================
39
+ @numba.jit(nopython=True)
40
+ def dtw_cpu(x: np.ndarray):
41
+ """
42
+ Dynamic Time Warping algorithm optimized with Numba.
43
+
44
+ Args:
45
+ x: Cost matrix of shape [N, M]
46
+
47
+ Returns:
48
+ Tuple of (text_indices, time_indices) arrays
49
+ """
50
+ N, M = x.shape
51
+ # Use float32 for memory efficiency
52
+ cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
53
+ trace = -np.ones((N + 1, M + 1), dtype=np.float32)
54
+ cost[0, 0] = 0
55
+
56
+ for j in range(1, M + 1):
57
+ for i in range(1, N + 1):
58
+ c0 = cost[i - 1, j - 1]
59
+ c1 = cost[i - 1, j]
60
+ c2 = cost[i, j - 1]
61
+
62
+ if c0 < c1 and c0 < c2:
63
+ c, t = c0, 0
64
+ elif c1 < c0 and c1 < c2:
65
+ c, t = c1, 1
66
+ else:
67
+ c, t = c2, 2
68
+
69
+ cost[i, j] = x[i - 1, j - 1] + c
70
+ trace[i, j] = t
71
+
72
+ return _backtrace(trace, N, M)
73
+
74
+
75
+ @numba.jit(nopython=True)
76
+ def _backtrace(trace: np.ndarray, N: int, M: int):
77
+ """
78
+ Optimized backtrace function for DTW.
79
+
80
+ Args:
81
+ trace: Trace matrix of shape (N+1, M+1)
82
+ N, M: Original matrix dimensions
83
+
84
+ Returns:
85
+ Path array of shape (2, path_len) - first row is text indices, second is time indices
86
+ """
87
+ # Boundary handling
88
+ trace[0, :] = 2
89
+ trace[:, 0] = 1
90
+
91
+ # Pre-allocate array, max path length is N+M
92
+ max_path_len = N + M
93
+ path = np.zeros((2, max_path_len), dtype=np.int32)
94
+
95
+ i, j = N, M
96
+ path_idx = max_path_len - 1
97
+
98
+ while i > 0 or j > 0:
99
+ path[0, path_idx] = i - 1 # text index
100
+ path[1, path_idx] = j - 1 # time index
101
+ path_idx -= 1
102
+
103
+ t = trace[i, j]
104
+ if t == 0:
105
+ i -= 1
106
+ j -= 1
107
+ elif t == 1:
108
+ i -= 1
109
+ elif t == 2:
110
+ j -= 1
111
+ else:
112
+ break
113
+
114
+ actual_len = max_path_len - path_idx - 1
115
+ return path[:, path_idx + 1:max_path_len]
116
+
117
+
118
+ # ================= Utility Functions =================
119
+ def median_filter(x: torch.Tensor, filter_width: int) -> torch.Tensor:
120
+ """
121
+ Apply median filter to tensor.
122
+
123
+ Args:
124
+ x: Input tensor
125
+ filter_width: Width of median filter
126
+
127
+ Returns:
128
+ Filtered tensor
129
+ """
130
+ pad_width = filter_width // 2
131
+ if x.shape[-1] <= pad_width:
132
+ return x
133
+ if x.ndim == 2:
134
+ x = x[None, :]
135
+ x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
136
+ result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
137
+ if result.ndim > 2:
138
+ result = result.squeeze(0)
139
+ return result
140
+
141
+
142
+ # ================= Main Aligner Class =================
143
+ class MusicStampsAligner:
144
+ """
145
+ Aligner class for generating lyrics timestamps from cross-attention matrices.
146
+
147
+ Uses bidirectional consensus denoising and DTW for alignment.
148
+ """
149
+
150
+ def __init__(self, tokenizer):
151
+ """
152
+ Initialize the aligner.
153
+
154
+ Args:
155
+ tokenizer: Text tokenizer for decoding tokens
156
+ """
157
+ self.tokenizer = tokenizer
158
+
159
+ def _apply_bidirectional_consensus(
160
+ self,
161
+ weights_stack: torch.Tensor,
162
+ violence_level: float,
163
+ medfilt_width: int
164
+ ) -> tuple:
165
+ """
166
+ Core denoising logic using bidirectional consensus.
167
+
168
+ Args:
169
+ weights_stack: Attention weights [Heads, Tokens, Frames]
170
+ violence_level: Denoising strength coefficient
171
+ medfilt_width: Median filter width
172
+
173
+ Returns:
174
+ Tuple of (calc_matrix, energy_matrix) as numpy arrays
175
+ """
176
+ # A. Bidirectional Consensus
177
+ row_prob = F.softmax(weights_stack, dim=-1) # Token -> Frame
178
+ col_prob = F.softmax(weights_stack, dim=-2) # Frame -> Token
179
+ processed = row_prob * col_prob
180
+
181
+ # 1. Row suppression (kill horizontal crossing lines)
182
+ row_medians = torch.quantile(processed, 0.5, dim=-1, keepdim=True)
183
+ processed = processed - (violence_level * row_medians)
184
+ processed = torch.relu(processed)
185
+
186
+ # 2. Column suppression (kill vertical crossing lines)
187
+ col_medians = torch.quantile(processed, 0.5, dim=-2, keepdim=True)
188
+ processed = processed - (violence_level * col_medians)
189
+ processed = torch.relu(processed)
190
+
191
+ # C. Power sharpening
192
+ processed = processed ** 2
193
+
194
+ # Energy matrix for confidence
195
+ energy_matrix = processed.mean(dim=0).cpu().numpy()
196
+
197
+ # D. Z-Score normalization
198
+ std, mean = torch.std_mean(processed, unbiased=False)
199
+ weights_processed = (processed - mean) / (std + 1e-9)
200
+
201
+ # E. Median filtering
202
+ weights_processed = median_filter(weights_processed, filter_width=medfilt_width)
203
+ calc_matrix = weights_processed.mean(dim=0).numpy()
204
+
205
+ return calc_matrix, energy_matrix
206
+
207
+ def _preprocess_attention(
208
+ self,
209
+ attention_matrix: torch.Tensor,
210
+ custom_config: Dict[int, List[int]],
211
+ violence_level: float,
212
+ medfilt_width: int = 7
213
+ ) -> tuple:
214
+ """
215
+ Preprocess attention matrix for alignment.
216
+
217
+ Args:
218
+ attention_matrix: Attention tensor [Layers, Heads, Tokens, Frames]
219
+ custom_config: Dict mapping layer indices to head indices
220
+ violence_level: Denoising strength
221
+ medfilt_width: Median filter width
222
+
223
+ Returns:
224
+ Tuple of (calc_matrix, energy_matrix, visual_matrix)
225
+ """
226
+ if not isinstance(attention_matrix, torch.Tensor):
227
+ weights = torch.tensor(attention_matrix)
228
+ else:
229
+ weights = attention_matrix.clone()
230
+
231
+ weights = weights.cpu().float()
232
+
233
+ selected_tensors = []
234
+ for layer_idx, head_indices in custom_config.items():
235
+ for head_idx in head_indices:
236
+ if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
237
+ head_matrix = weights[layer_idx, head_idx]
238
+ selected_tensors.append(head_matrix)
239
+
240
+ if not selected_tensors:
241
+ return None, None, None
242
+
243
+ # Stack selected heads: [Heads, Tokens, Frames]
244
+ weights_stack = torch.stack(selected_tensors, dim=0)
245
+ visual_matrix = weights_stack.mean(dim=0).numpy()
246
+
247
+ calc_matrix, energy_matrix = self._apply_bidirectional_consensus(
248
+ weights_stack, violence_level, medfilt_width
249
+ )
250
+
251
+ return calc_matrix, energy_matrix, visual_matrix
252
+
253
+ def stamps_align_info(
254
+ self,
255
+ attention_matrix: torch.Tensor,
256
+ lyrics_tokens: List[int],
257
+ total_duration_seconds: float,
258
+ custom_config: Dict[int, List[int]],
259
+ return_matrices: bool = False,
260
+ violence_level: float = 2.0,
261
+ medfilt_width: int = 1
262
+ ) -> Dict[str, Any]:
263
+ """
264
+ Get alignment information from attention matrix.
265
+
266
+ Args:
267
+ attention_matrix: Cross-attention tensor [Layers, Heads, Tokens, Frames]
268
+ lyrics_tokens: List of lyrics token IDs
269
+ total_duration_seconds: Total audio duration in seconds
270
+ custom_config: Dict mapping layer indices to head indices
271
+ return_matrices: Whether to return intermediate matrices
272
+ violence_level: Denoising strength
273
+ medfilt_width: Median filter width
274
+
275
+ Returns:
276
+ Dict containing calc_matrix, lyrics_tokens, total_duration_seconds,
277
+ and optionally energy_matrix and vis_matrix
278
+ """
279
+ calc_matrix, energy_matrix, visual_matrix = self._preprocess_attention(
280
+ attention_matrix, custom_config, violence_level, medfilt_width
281
+ )
282
+
283
+ if calc_matrix is None:
284
+ return {
285
+ "calc_matrix": None,
286
+ "lyrics_tokens": lyrics_tokens,
287
+ "total_duration_seconds": total_duration_seconds,
288
+ "error": "No valid attention heads found"
289
+ }
290
+
291
+ return_dict = {
292
+ "calc_matrix": calc_matrix,
293
+ "lyrics_tokens": lyrics_tokens,
294
+ "total_duration_seconds": total_duration_seconds
295
+ }
296
+
297
+ if return_matrices:
298
+ return_dict['energy_matrix'] = energy_matrix
299
+ return_dict['vis_matrix'] = visual_matrix
300
+
301
+ return return_dict
302
+
303
+ def _decode_tokens_incrementally(self, token_ids: List[int]) -> List[str]:
304
+ """
305
+ Decode tokens incrementally to properly handle multi-byte UTF-8 characters.
306
+
307
+ For Chinese and other multi-byte characters, the tokenizer may split them
308
+ into multiple byte-level tokens. Decoding each token individually produces
309
+ invalid UTF-8 sequences (showing as �). This method uses byte-level comparison
310
+ to correctly track which characters each token contributes.
311
+
312
+ Args:
313
+ token_ids: List of token IDs
314
+
315
+ Returns:
316
+ List of decoded text for each token position
317
+ """
318
+ decoded_tokens = []
319
+ prev_bytes = b""
320
+
321
+ for i in range(len(token_ids)):
322
+ # Decode tokens from start to current position
323
+ current_text = self.tokenizer.decode(token_ids[:i+1], skip_special_tokens=False)
324
+ current_bytes = current_text.encode('utf-8', errors='surrogatepass')
325
+
326
+ # The contribution of current token is the new bytes added
327
+ if len(current_bytes) >= len(prev_bytes):
328
+ new_bytes = current_bytes[len(prev_bytes):]
329
+ # Try to decode the new bytes; if incomplete, use empty string
330
+ try:
331
+ token_text = new_bytes.decode('utf-8')
332
+ except UnicodeDecodeError:
333
+ # Incomplete UTF-8 sequence, this token doesn't complete a character
334
+ token_text = ""
335
+ else:
336
+ # Edge case: current decode is shorter (shouldn't happen normally)
337
+ token_text = ""
338
+
339
+ decoded_tokens.append(token_text)
340
+ prev_bytes = current_bytes
341
+
342
+ return decoded_tokens
343
+
344
+ def token_timestamps(
345
+ self,
346
+ calc_matrix: np.ndarray,
347
+ lyrics_tokens: List[int],
348
+ total_duration_seconds: float
349
+ ) -> List[TokenTimestamp]:
350
+ """
351
+ Generate per-token timestamps using DTW.
352
+
353
+ Args:
354
+ calc_matrix: Processed attention matrix [Tokens, Frames]
355
+ lyrics_tokens: List of token IDs
356
+ total_duration_seconds: Total audio duration
357
+
358
+ Returns:
359
+ List of TokenTimestamp objects
360
+ """
361
+ n_frames = calc_matrix.shape[-1]
362
+ text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float64))
363
+
364
+ seconds_per_frame = total_duration_seconds / n_frames
365
+ alignment_results = []
366
+
367
+ # Use incremental decoding to properly handle multi-byte UTF-8 characters
368
+ decoded_tokens = self._decode_tokens_incrementally(lyrics_tokens)
369
+
370
+ for i in range(len(lyrics_tokens)):
371
+ mask = (text_indices == i)
372
+
373
+ if not np.any(mask):
374
+ start = alignment_results[-1].end if alignment_results else 0.0
375
+ end = start
376
+ token_conf = 0.0
377
+ else:
378
+ times = time_indices[mask] * seconds_per_frame
379
+ start = times[0]
380
+ end = times[-1]
381
+ token_conf = 0.0
382
+
383
+ if end < start:
384
+ end = start
385
+
386
+ alignment_results.append(TokenTimestamp(
387
+ token_id=lyrics_tokens[i],
388
+ text=decoded_tokens[i],
389
+ start=float(start),
390
+ end=float(end),
391
+ probability=token_conf
392
+ ))
393
+
394
+ return alignment_results
395
+
396
+ def _decode_sentence_from_tokens(self, tokens: List[TokenTimestamp]) -> str:
397
+ """
398
+ Decode a sentence by decoding all token IDs together.
399
+ This avoids UTF-8 encoding issues from joining individual token texts.
400
+
401
+ Args:
402
+ tokens: List of TokenTimestamp objects
403
+
404
+ Returns:
405
+ Properly decoded sentence text
406
+ """
407
+ token_ids = [t.token_id for t in tokens]
408
+ return self.tokenizer.decode(token_ids, skip_special_tokens=False)
409
+
410
+ def sentence_timestamps(
411
+ self,
412
+ token_alignment: List[TokenTimestamp]
413
+ ) -> List[SentenceTimestamp]:
414
+ """
415
+ Group token timestamps into sentence timestamps.
416
+
417
+ Args:
418
+ token_alignment: List of TokenTimestamp objects
419
+
420
+ Returns:
421
+ List of SentenceTimestamp objects
422
+ """
423
+ results = []
424
+ current_tokens = []
425
+
426
+ for token in token_alignment:
427
+ current_tokens.append(token)
428
+
429
+ if '\n' in token.text:
430
+ # Decode all token IDs together to avoid UTF-8 issues
431
+ full_text = self._decode_sentence_from_tokens(current_tokens)
432
+
433
+ if full_text.strip():
434
+ valid_scores = [t.probability for t in current_tokens if t.probability > 0]
435
+ sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
436
+
437
+ results.append(SentenceTimestamp(
438
+ text=full_text.strip(),
439
+ start=round(current_tokens[0].start, 3),
440
+ end=round(current_tokens[-1].end, 3),
441
+ tokens=list(current_tokens),
442
+ confidence=sent_conf
443
+ ))
444
+
445
+ current_tokens = []
446
+
447
+ # Handle last sentence
448
+ if current_tokens:
449
+ # Decode all token IDs together to avoid UTF-8 issues
450
+ full_text = self._decode_sentence_from_tokens(current_tokens)
451
+ if full_text.strip():
452
+ valid_scores = [t.probability for t in current_tokens if t.probability > 0]
453
+ sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
454
+
455
+ results.append(SentenceTimestamp(
456
+ text=full_text.strip(),
457
+ start=round(current_tokens[0].start, 3),
458
+ end=round(current_tokens[-1].end, 3),
459
+ tokens=list(current_tokens),
460
+ confidence=sent_conf
461
+ ))
462
+
463
+ # Normalize confidence scores
464
+ if results:
465
+ all_scores = [s.confidence for s in results]
466
+ min_score = min(all_scores)
467
+ max_score = max(all_scores)
468
+ score_range = max_score - min_score
469
+
470
+ if score_range > 1e-9:
471
+ for s in results:
472
+ normalized_score = (s.confidence - min_score) / score_range
473
+ s.confidence = round(normalized_score, 2)
474
+ else:
475
+ for s in results:
476
+ s.confidence = round(s.confidence, 2)
477
+
478
+ return results
479
+
480
+ def format_lrc(
481
+ self,
482
+ sentence_timestamps: List[SentenceTimestamp],
483
+ include_end_time: bool = False
484
+ ) -> str:
485
+ """
486
+ Format sentence timestamps as LRC lyrics format.
487
+
488
+ Args:
489
+ sentence_timestamps: List of SentenceTimestamp objects
490
+ include_end_time: Whether to include end time (enhanced LRC format)
491
+
492
+ Returns:
493
+ LRC formatted string
494
+ """
495
+ lines = []
496
+
497
+ for sentence in sentence_timestamps:
498
+ # Convert seconds to mm:ss.xx format
499
+ start_minutes = int(sentence.start // 60)
500
+ start_seconds = sentence.start % 60
501
+
502
+ if include_end_time:
503
+ end_minutes = int(sentence.end // 60)
504
+ end_seconds = sentence.end % 60
505
+ timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}][{end_minutes:02d}:{end_seconds:05.2f}]"
506
+ else:
507
+ timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}]"
508
+
509
+ # Clean the text (remove structural tags like [verse], [chorus])
510
+ text = sentence.text
511
+
512
+ lines.append(f"{timestamp}{text}")
513
+
514
+ return "\n".join(lines)
515
+
516
+ def get_timestamps_and_lrc(
517
+ self,
518
+ calc_matrix: np.ndarray,
519
+ lyrics_tokens: List[int],
520
+ total_duration_seconds: float
521
+ ) -> Dict[str, Any]:
522
+ """
523
+ Convenience method to get both timestamps and LRC in one call.
524
+
525
+ Args:
526
+ calc_matrix: Processed attention matrix
527
+ lyrics_tokens: List of token IDs
528
+ total_duration_seconds: Total audio duration
529
+
530
+ Returns:
531
+ Dict containing token_timestamps, sentence_timestamps, and lrc_text
532
+ """
533
+ token_stamps = self.token_timestamps(
534
+ calc_matrix=calc_matrix,
535
+ lyrics_tokens=lyrics_tokens,
536
+ total_duration_seconds=total_duration_seconds
537
+ )
538
+
539
+ sentence_stamps = self.sentence_timestamps(token_stamps)
540
+ lrc_text = self.format_lrc(sentence_stamps)
541
+
542
+ return {
543
+ "token_timestamps": token_stamps,
544
+ "sentence_timestamps": sentence_stamps,
545
+ "lrc_text": lrc_text
546
+ }
547
+
548
+
549
+ class MusicLyricScorer:
550
+ """
551
+ Scorer class for evaluating lyrics-to-audio alignment quality.
552
+
553
+ Focuses on calculating alignment quality metrics (Coverage, Monotonicity, Confidence)
554
+ using tensor operations for potential differentiability or GPU acceleration.
555
+ """
556
+
557
+ def __init__(self, tokenizer: Any):
558
+ """
559
+ Initialize the aligner.
560
+
561
+ Args:
562
+ tokenizer: Tokenizer instance (must implement .decode()).
563
+ """
564
+ self.tokenizer = tokenizer
565
+
566
+ def _generate_token_type_mask(self, token_ids: List[int]) -> np.ndarray:
567
+ """
568
+ Generate a mask distinguishing lyrics (1) from structural tags (0).
569
+ Uses self.tokenizer to decode tokens.
570
+
571
+ Args:
572
+ token_ids: List of token IDs.
573
+
574
+ Returns:
575
+ Numpy array of shape [len(token_ids)] with 1 or 0.
576
+ """
577
+ decoded_tokens = [self.tokenizer.decode([tid]) for tid in token_ids]
578
+ mask = np.ones(len(token_ids), dtype=np.int32)
579
+ in_bracket = False
580
+
581
+ for i, token_str in enumerate(decoded_tokens):
582
+ if '[' in token_str:
583
+ in_bracket = True
584
+ if in_bracket:
585
+ mask[i] = 0
586
+ if ']' in token_str:
587
+ in_bracket = False
588
+ mask[i] = 0
589
+ return mask
590
+
591
+ def _preprocess_attention(
592
+ self,
593
+ attention_matrix: Union[torch.Tensor, np.ndarray],
594
+ custom_config: Dict[int, List[int]],
595
+ medfilt_width: int = 1
596
+ ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[torch.Tensor]]:
597
+ """
598
+ Extracts and normalizes the attention matrix.
599
+
600
+ Logic V4: Uses Min-Max normalization to highlight energy differences.
601
+
602
+ Args:
603
+ attention_matrix: Raw attention tensor [Layers, Heads, Tokens, Frames].
604
+ custom_config: Config mapping layers to heads.
605
+ medfilt_width: Width for median filtering.
606
+
607
+ Returns:
608
+ Tuple of (calc_matrix, energy_matrix, avg_weights_tensor).
609
+ """
610
+ # 1. Prepare Tensor
611
+ if not isinstance(attention_matrix, torch.Tensor):
612
+ weights = torch.tensor(attention_matrix)
613
+ else:
614
+ weights = attention_matrix.clone()
615
+ weights = weights.cpu().float()
616
+
617
+ # 2. Select Heads based on config
618
+ selected_tensors = []
619
+ for layer_idx, head_indices in custom_config.items():
620
+ for head_idx in head_indices:
621
+ if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
622
+ selected_tensors.append(weights[layer_idx, head_idx])
623
+
624
+ if not selected_tensors:
625
+ return None, None, None
626
+
627
+ weights_stack = torch.stack(selected_tensors, dim=0)
628
+
629
+ # 3. Average Heads
630
+ avg_weights = weights_stack.mean(dim=0) # [Tokens, Frames]
631
+
632
+ # 4. Preprocessing Logic
633
+ # Min-Max normalization preserving energy distribution
634
+ # Median filter is applied to the energy matrix
635
+ energy_tensor = median_filter(avg_weights, filter_width=medfilt_width)
636
+ energy_matrix = energy_tensor.numpy()
637
+
638
+ e_min, e_max = energy_matrix.min(), energy_matrix.max()
639
+
640
+ if e_max - e_min > 1e-9:
641
+ energy_matrix = (energy_matrix - e_min) / (e_max - e_min)
642
+ else:
643
+ energy_matrix = np.zeros_like(energy_matrix)
644
+
645
+ # Contrast enhancement for DTW pathfinding
646
+ # calc_matrix is used for pathfinding, energy_matrix for scoring
647
+ calc_matrix = energy_matrix ** 2
648
+
649
+ return calc_matrix, energy_matrix, avg_weights
650
+
651
+ def _compute_alignment_metrics(
652
+ self,
653
+ energy_matrix: torch.Tensor,
654
+ path_coords: torch.Tensor,
655
+ type_mask: torch.Tensor,
656
+ time_weight: float = 0.01,
657
+ overlap_frames: float = 9.0,
658
+ instrumental_weight: float = 1.0
659
+ ) -> Tuple[float, float, float]:
660
+ """
661
+ Core metric calculation logic using high-precision Tensor operations.
662
+
663
+ Args:
664
+ energy_matrix: Normalized energy [Rows, Cols].
665
+ path_coords: DTW path coordinates [Steps, 2].
666
+ type_mask: Token type mask [Rows] (1=Lyrics, 0=Tags).
667
+ time_weight: Minimum energy threshold for monotonicity.
668
+ overlap_frames: Allowed overlap for monotonicity check.
669
+ instrumental_weight: Weight for non-lyric tokens in confidence calc.
670
+
671
+ Returns:
672
+ Tuple of (coverage, monotonicity, confidence).
673
+ """
674
+ # Ensure high precision for internal calculation
675
+ energy_matrix = energy_matrix.to(dtype=torch.float64)
676
+ path_coords = path_coords.long()
677
+ type_mask = type_mask.long()
678
+
679
+ device = energy_matrix.device
680
+ rows, cols = energy_matrix.shape
681
+
682
+ is_lyrics_row = (type_mask == 1)
683
+
684
+ # ================= A. Coverage Score =================
685
+ # Ratio of lyric lines that have significant energy peak
686
+ row_max_energies = energy_matrix.max(dim=1).values
687
+ total_sung_rows = is_lyrics_row.sum().double()
688
+
689
+ coverage_threshold = 0.1
690
+ valid_sung_mask = is_lyrics_row & (row_max_energies > coverage_threshold)
691
+ valid_sung_rows = valid_sung_mask.sum().double()
692
+
693
+ if total_sung_rows > 0:
694
+ coverage_score = valid_sung_rows / total_sung_rows
695
+ else:
696
+ coverage_score = torch.tensor(1.0, device=device, dtype=torch.float64)
697
+
698
+ # ================= B. Monotonicity Score =================
699
+ # Check if the "center of mass" of lyric lines moves forward in time
700
+ col_indices = torch.arange(cols, device=device, dtype=torch.float64)
701
+
702
+ # Zero out low energy noise
703
+ weights = torch.where(
704
+ energy_matrix > time_weight,
705
+ energy_matrix,
706
+ torch.zeros_like(energy_matrix)
707
+ )
708
+
709
+ sum_w = weights.sum(dim=1)
710
+ sum_t = (weights * col_indices).sum(dim=1)
711
+
712
+ # Calculate centroids
713
+ centroids = torch.full((rows,), -1.0, device=device, dtype=torch.float64)
714
+ valid_w_mask = sum_w > 1e-9
715
+ centroids[valid_w_mask] = sum_t[valid_w_mask] / sum_w[valid_w_mask]
716
+
717
+ # Extract sequence of valid lyrics centroids
718
+ valid_sequence_mask = is_lyrics_row & (centroids >= 0)
719
+ sung_centroids = centroids[valid_sequence_mask]
720
+
721
+ cnt = sung_centroids.shape[0]
722
+ if cnt > 1:
723
+ curr_c = sung_centroids[:-1]
724
+ next_c = sung_centroids[1:]
725
+
726
+ # Check non-decreasing order with overlap tolerance
727
+ non_decreasing = (next_c >= (curr_c - overlap_frames)).double().sum()
728
+ pairs = torch.tensor(cnt - 1, device=device, dtype=torch.float64)
729
+ monotonicity_score = non_decreasing / pairs
730
+ else:
731
+ monotonicity_score = torch.tensor(1.0, device=device, dtype=torch.float64)
732
+
733
+ # ================= C. Path Confidence =================
734
+ # Average energy along the optimal path
735
+ if path_coords.shape[0] > 0:
736
+ p_rows = path_coords[:, 0]
737
+ p_cols = path_coords[:, 1]
738
+
739
+ path_energies = energy_matrix[p_rows, p_cols]
740
+ step_weights = torch.ones_like(path_energies)
741
+
742
+ # Lower weight for instrumental/tag steps
743
+ is_inst_step = (type_mask[p_rows] == 0)
744
+ step_weights[is_inst_step] = instrumental_weight
745
+
746
+ total_energy = (path_energies * step_weights).sum()
747
+ total_steps = step_weights.sum()
748
+
749
+ if total_steps > 0:
750
+ path_confidence = total_energy / total_steps
751
+ else:
752
+ path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
753
+ else:
754
+ path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
755
+
756
+ return coverage_score.item(), monotonicity_score.item(), path_confidence.item()
757
+
758
+ def lyrics_alignment_info(
759
+ self,
760
+ attention_matrix: Union[torch.Tensor, np.ndarray],
761
+ token_ids: List[int],
762
+ custom_config: Dict[int, List[int]],
763
+ return_matrices: bool = False,
764
+ medfilt_width: int = 1
765
+ ) -> Dict[str, Any]:
766
+ """
767
+ Generates alignment path and processed matrices.
768
+
769
+ Args:
770
+ attention_matrix: Input attention tensor.
771
+ token_ids: Corresponding token IDs.
772
+ custom_config: Layer/Head configuration.
773
+ return_matrices: If True, returns matrices in the output.
774
+ medfilt_width: Median filter width.
775
+
776
+ Returns:
777
+ Dict or AlignmentInfo object containing path and masks.
778
+ """
779
+ calc_matrix, energy_matrix, vis_matrix = self._preprocess_attention(
780
+ attention_matrix, custom_config, medfilt_width
781
+ )
782
+
783
+ if calc_matrix is None:
784
+ return {
785
+ "calc_matrix": None,
786
+ "error": "No valid attention heads found"
787
+ }
788
+
789
+ # 1. Generate Semantic Mask (1=Lyrics, 0=Tags)
790
+ # Uses self.tokenizer internally
791
+ type_mask = self._generate_token_type_mask(token_ids)
792
+
793
+ # Safety check for shape mismatch
794
+ if len(type_mask) != energy_matrix.shape[0]:
795
+ # Fallback to all lyrics if shapes don't align
796
+ type_mask = np.ones(energy_matrix.shape[0], dtype=np.int32)
797
+
798
+ # 2. DTW Pathfinding
799
+ # Using negative calc_matrix because DTW minimizes cost
800
+ text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float32))
801
+ path_coords = np.stack([text_indices, time_indices], axis=1)
802
+
803
+ return_dict = {
804
+ "path_coords": path_coords,
805
+ "type_mask": type_mask,
806
+ "energy_matrix": energy_matrix
807
+ }
808
+ if return_matrices:
809
+ return_dict['calc_matrix'] = calc_matrix
810
+ return_dict['vis_matrix'] = vis_matrix
811
+
812
+ return return_dict
813
+
814
+ def calculate_score(
815
+ self,
816
+ energy_matrix: Union[torch.Tensor, np.ndarray],
817
+ type_mask: Union[torch.Tensor, np.ndarray],
818
+ path_coords: Union[torch.Tensor, np.ndarray],
819
+ time_weight: float = 0.01,
820
+ overlap_frames: float = 9.0,
821
+ instrumental_weight: float = 1.0
822
+ ) -> Dict[str, Any]:
823
+ """
824
+ Calculates the final alignment score based on pre-computed components.
825
+
826
+ Args:
827
+ energy_matrix: Processed energy matrix.
828
+ type_mask: Token type mask.
829
+ path_coords: DTW path coordinates.
830
+ time_weight: Minimum energy threshold for monotonicity.
831
+ overlap_frames: Allowed backward movement frames.
832
+ instrumental_weight: Weight for non-lyric path steps.
833
+
834
+ Returns:
835
+ AlignmentScore object containing individual metrics and final score.
836
+ """
837
+ # Ensure Inputs are Tensors on the correct device
838
+ if not isinstance(energy_matrix, torch.Tensor):
839
+ energy_matrix = torch.tensor(energy_matrix, device='cuda', dtype=torch.float32)
840
+
841
+ device = energy_matrix.device
842
+
843
+ if not isinstance(type_mask, torch.Tensor):
844
+ type_mask = torch.tensor(type_mask, device=device, dtype=torch.long)
845
+ else:
846
+ type_mask = type_mask.to(device=device, dtype=torch.long)
847
+
848
+ if not isinstance(path_coords, torch.Tensor):
849
+ path_coords = torch.tensor(path_coords, device=device, dtype=torch.long)
850
+ else:
851
+ path_coords = path_coords.to(device=device, dtype=torch.long)
852
+
853
+ # Compute Metrics
854
+ coverage, monotonicity, confidence = self._compute_alignment_metrics(
855
+ energy_matrix=energy_matrix,
856
+ path_coords=path_coords,
857
+ type_mask=type_mask,
858
+ time_weight=time_weight,
859
+ overlap_frames=overlap_frames,
860
+ instrumental_weight=instrumental_weight
861
+ )
862
+
863
+ # Final Score Calculation
864
+ # (Cov^2 * Mono^2 * Conf)
865
+ final_score = (coverage ** 2) * (monotonicity ** 2) * confidence
866
+ final_score = float(np.clip(final_score, 0.0, 1.0))
867
+
868
+ return {
869
+ "lyrics_score": round(final_score, 4)
870
+ }
code/acestep/genres_vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
code/acestep/gradio_ui/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from acestep.gradio_ui.interfaces import create_gradio_interface
code/acestep/gradio_ui/events/__init__.py ADDED
@@ -0,0 +1,1129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Event Handlers Module
3
+ Main entry point for setting up all event handlers
4
+ """
5
+ import gradio as gr
6
+ from typing import Optional
7
+
8
+ # Import handler modules
9
+ from . import generation_handlers as gen_h
10
+ from . import results_handlers as res_h
11
+ from . import training_handlers as train_h
12
+ from acestep.gradio_ui.i18n import t
13
+
14
+
15
+ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section):
16
+ """Setup event handlers connecting UI components and business logic"""
17
+
18
+ # ========== Dataset Handlers ==========
19
+ dataset_section["import_dataset_btn"].click(
20
+ fn=dataset_handler.import_dataset,
21
+ inputs=[dataset_section["dataset_type"]],
22
+ outputs=[dataset_section["data_status"]]
23
+ )
24
+
25
+ # ========== Service Initialization ==========
26
+ generation_section["refresh_btn"].click(
27
+ fn=lambda: gen_h.refresh_checkpoints(dit_handler),
28
+ outputs=[generation_section["checkpoint_dropdown"]]
29
+ )
30
+
31
+ generation_section["config_path"].change(
32
+ fn=gen_h.update_model_type_settings,
33
+ inputs=[generation_section["config_path"]],
34
+ outputs=[
35
+ generation_section["inference_steps"],
36
+ generation_section["guidance_scale"],
37
+ generation_section["use_adg"],
38
+ generation_section["shift"],
39
+ generation_section["cfg_interval_start"],
40
+ generation_section["cfg_interval_end"],
41
+ generation_section["task_type"],
42
+ ]
43
+ )
44
+
45
+ generation_section["init_btn"].click(
46
+ fn=lambda *args: gen_h.init_service_wrapper(dit_handler, llm_handler, *args),
47
+ inputs=[
48
+ generation_section["checkpoint_dropdown"],
49
+ generation_section["config_path"],
50
+ generation_section["device"],
51
+ generation_section["init_llm_checkbox"],
52
+ generation_section["lm_model_path"],
53
+ generation_section["backend_dropdown"],
54
+ generation_section["use_flash_attention_checkbox"],
55
+ generation_section["offload_to_cpu_checkbox"],
56
+ generation_section["offload_dit_to_cpu_checkbox"],
57
+ ],
58
+ outputs=[
59
+ generation_section["init_status"],
60
+ generation_section["generate_btn"],
61
+ generation_section["service_config_accordion"],
62
+ # Model type settings (updated based on actual loaded model)
63
+ generation_section["inference_steps"],
64
+ generation_section["guidance_scale"],
65
+ generation_section["use_adg"],
66
+ generation_section["shift"],
67
+ generation_section["cfg_interval_start"],
68
+ generation_section["cfg_interval_end"],
69
+ generation_section["task_type"],
70
+ ]
71
+ )
72
+
73
+ # ========== LoRA Handlers ==========
74
+ generation_section["load_lora_btn"].click(
75
+ fn=dit_handler.load_lora,
76
+ inputs=[generation_section["lora_path"]],
77
+ outputs=[generation_section["lora_status"]]
78
+ ).then(
79
+ # Update checkbox to enabled state after loading
80
+ fn=lambda: gr.update(value=True),
81
+ outputs=[generation_section["use_lora_checkbox"]]
82
+ )
83
+
84
+ generation_section["unload_lora_btn"].click(
85
+ fn=dit_handler.unload_lora,
86
+ outputs=[generation_section["lora_status"]]
87
+ ).then(
88
+ # Update checkbox to disabled state after unloading
89
+ fn=lambda: gr.update(value=False),
90
+ outputs=[generation_section["use_lora_checkbox"]]
91
+ )
92
+
93
+ generation_section["use_lora_checkbox"].change(
94
+ fn=dit_handler.set_use_lora,
95
+ inputs=[generation_section["use_lora_checkbox"]],
96
+ outputs=[generation_section["lora_status"]]
97
+ )
98
+
99
+ # ========== UI Visibility Updates ==========
100
+ generation_section["init_llm_checkbox"].change(
101
+ fn=gen_h.update_negative_prompt_visibility,
102
+ inputs=[generation_section["init_llm_checkbox"]],
103
+ outputs=[generation_section["lm_negative_prompt"]]
104
+ )
105
+
106
+ generation_section["init_llm_checkbox"].change(
107
+ fn=gen_h.update_audio_cover_strength_visibility,
108
+ inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"]],
109
+ outputs=[generation_section["audio_cover_strength"]]
110
+ )
111
+
112
+ generation_section["task_type"].change(
113
+ fn=gen_h.update_audio_cover_strength_visibility,
114
+ inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"]],
115
+ outputs=[generation_section["audio_cover_strength"]]
116
+ )
117
+
118
+ generation_section["batch_size_input"].change(
119
+ fn=gen_h.update_audio_components_visibility,
120
+ inputs=[generation_section["batch_size_input"]],
121
+ outputs=[
122
+ results_section["audio_col_1"],
123
+ results_section["audio_col_2"],
124
+ results_section["audio_col_3"],
125
+ results_section["audio_col_4"],
126
+ results_section["audio_row_5_8"],
127
+ results_section["audio_col_5"],
128
+ results_section["audio_col_6"],
129
+ results_section["audio_col_7"],
130
+ results_section["audio_col_8"],
131
+ ]
132
+ )
133
+
134
+ # ========== Audio Conversion ==========
135
+ generation_section["convert_src_to_codes_btn"].click(
136
+ fn=lambda src: gen_h.convert_src_audio_to_codes_wrapper(dit_handler, src),
137
+ inputs=[generation_section["src_audio"]],
138
+ outputs=[generation_section["text2music_audio_code_string"]]
139
+ )
140
+
141
+ # ========== Instruction UI Updates ==========
142
+ for trigger in [generation_section["task_type"], generation_section["track_name"], generation_section["complete_track_classes"]]:
143
+ trigger.change(
144
+ fn=lambda *args: gen_h.update_instruction_ui(dit_handler, *args),
145
+ inputs=[
146
+ generation_section["task_type"],
147
+ generation_section["track_name"],
148
+ generation_section["complete_track_classes"],
149
+ generation_section["text2music_audio_code_string"],
150
+ generation_section["init_llm_checkbox"]
151
+ ],
152
+ outputs=[
153
+ generation_section["instruction_display_gen"],
154
+ generation_section["track_name"],
155
+ generation_section["complete_track_classes"],
156
+ generation_section["audio_cover_strength"],
157
+ generation_section["repainting_group"],
158
+ generation_section["text2music_audio_codes_group"],
159
+ ]
160
+ )
161
+
162
+ # ========== Sample/Transcribe Handlers ==========
163
+ # Load random example from ./examples/text2music directory
164
+ generation_section["sample_btn"].click(
165
+ fn=lambda task: gen_h.load_random_example(task) + (True,),
166
+ inputs=[
167
+ generation_section["task_type"],
168
+ ],
169
+ outputs=[
170
+ generation_section["captions"],
171
+ generation_section["lyrics"],
172
+ generation_section["think_checkbox"],
173
+ generation_section["bpm"],
174
+ generation_section["audio_duration"],
175
+ generation_section["key_scale"],
176
+ generation_section["vocal_language"],
177
+ generation_section["time_signature"],
178
+ results_section["is_format_caption_state"]
179
+ ]
180
+ )
181
+
182
+ generation_section["text2music_audio_code_string"].change(
183
+ fn=gen_h.update_transcribe_button_text,
184
+ inputs=[generation_section["text2music_audio_code_string"]],
185
+ outputs=[generation_section["transcribe_btn"]]
186
+ )
187
+
188
+ generation_section["transcribe_btn"].click(
189
+ fn=lambda codes, debug: gen_h.transcribe_audio_codes(llm_handler, codes, debug),
190
+ inputs=[
191
+ generation_section["text2music_audio_code_string"],
192
+ generation_section["constrained_decoding_debug"]
193
+ ],
194
+ outputs=[
195
+ results_section["status_output"],
196
+ generation_section["captions"],
197
+ generation_section["lyrics"],
198
+ generation_section["bpm"],
199
+ generation_section["audio_duration"],
200
+ generation_section["key_scale"],
201
+ generation_section["vocal_language"],
202
+ generation_section["time_signature"],
203
+ results_section["is_format_caption_state"]
204
+ ]
205
+ )
206
+
207
+ # ========== Reset Format Caption Flag ==========
208
+ for trigger in [generation_section["captions"], generation_section["lyrics"], generation_section["bpm"],
209
+ generation_section["key_scale"], generation_section["time_signature"],
210
+ generation_section["vocal_language"], generation_section["audio_duration"]]:
211
+ trigger.change(
212
+ fn=gen_h.reset_format_caption_flag,
213
+ inputs=[],
214
+ outputs=[results_section["is_format_caption_state"]]
215
+ )
216
+
217
+ # ========== Audio Uploads Accordion ==========
218
+ for trigger in [generation_section["reference_audio"], generation_section["src_audio"]]:
219
+ trigger.change(
220
+ fn=gen_h.update_audio_uploads_accordion,
221
+ inputs=[generation_section["reference_audio"], generation_section["src_audio"]],
222
+ outputs=[generation_section["audio_uploads_accordion"]]
223
+ )
224
+
225
+ # ========== Instrumental Checkbox ==========
226
+ generation_section["instrumental_checkbox"].change(
227
+ fn=gen_h.handle_instrumental_checkbox,
228
+ inputs=[generation_section["instrumental_checkbox"], generation_section["lyrics"]],
229
+ outputs=[generation_section["lyrics"]]
230
+ )
231
+
232
+ # ========== Format Button ==========
233
+ # Note: cfg_scale and negative_prompt are not supported in format mode
234
+ generation_section["format_btn"].click(
235
+ fn=lambda caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug: gen_h.handle_format_sample(
236
+ llm_handler, caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug
237
+ ),
238
+ inputs=[
239
+ generation_section["captions"],
240
+ generation_section["lyrics"],
241
+ generation_section["bpm"],
242
+ generation_section["audio_duration"],
243
+ generation_section["key_scale"],
244
+ generation_section["time_signature"],
245
+ generation_section["lm_temperature"],
246
+ generation_section["lm_top_k"],
247
+ generation_section["lm_top_p"],
248
+ generation_section["constrained_decoding_debug"],
249
+ ],
250
+ outputs=[
251
+ generation_section["captions"],
252
+ generation_section["lyrics"],
253
+ generation_section["bpm"],
254
+ generation_section["audio_duration"],
255
+ generation_section["key_scale"],
256
+ generation_section["vocal_language"],
257
+ generation_section["time_signature"],
258
+ results_section["is_format_caption_state"],
259
+ results_section["status_output"],
260
+ ]
261
+ )
262
+
263
+ # ========== Simple/Custom Mode Toggle ==========
264
+ generation_section["generation_mode"].change(
265
+ fn=gen_h.handle_generation_mode_change,
266
+ inputs=[generation_section["generation_mode"]],
267
+ outputs=[
268
+ generation_section["simple_mode_group"],
269
+ generation_section["caption_accordion"],
270
+ generation_section["lyrics_accordion"],
271
+ generation_section["generate_btn"],
272
+ generation_section["simple_sample_created"],
273
+ generation_section["optional_params_accordion"],
274
+ ]
275
+ )
276
+
277
+ # ========== Simple Mode Instrumental Checkbox ==========
278
+ # When instrumental is checked, disable vocal language and set to ["unknown"]
279
+ generation_section["simple_instrumental_checkbox"].change(
280
+ fn=gen_h.handle_simple_instrumental_change,
281
+ inputs=[generation_section["simple_instrumental_checkbox"]],
282
+ outputs=[generation_section["simple_vocal_language"]]
283
+ )
284
+
285
+ # ========== Random Description Button ==========
286
+ generation_section["random_desc_btn"].click(
287
+ fn=gen_h.load_random_simple_description,
288
+ inputs=[],
289
+ outputs=[
290
+ generation_section["simple_query_input"],
291
+ generation_section["simple_instrumental_checkbox"],
292
+ generation_section["simple_vocal_language"],
293
+ ]
294
+ )
295
+
296
+ # ========== Create Sample Button (Simple Mode) ==========
297
+ # Note: cfg_scale and negative_prompt are not supported in create_sample mode
298
+ generation_section["create_sample_btn"].click(
299
+ fn=lambda query, instrumental, vocal_lang, temp, top_k, top_p, debug: gen_h.handle_create_sample(
300
+ llm_handler, query, instrumental, vocal_lang, temp, top_k, top_p, debug
301
+ ),
302
+ inputs=[
303
+ generation_section["simple_query_input"],
304
+ generation_section["simple_instrumental_checkbox"],
305
+ generation_section["simple_vocal_language"],
306
+ generation_section["lm_temperature"],
307
+ generation_section["lm_top_k"],
308
+ generation_section["lm_top_p"],
309
+ generation_section["constrained_decoding_debug"],
310
+ ],
311
+ outputs=[
312
+ generation_section["captions"],
313
+ generation_section["lyrics"],
314
+ generation_section["bpm"],
315
+ generation_section["audio_duration"],
316
+ generation_section["key_scale"],
317
+ generation_section["vocal_language"],
318
+ generation_section["simple_vocal_language"],
319
+ generation_section["time_signature"],
320
+ generation_section["instrumental_checkbox"],
321
+ generation_section["caption_accordion"],
322
+ generation_section["lyrics_accordion"],
323
+ generation_section["generate_btn"],
324
+ generation_section["simple_sample_created"],
325
+ generation_section["think_checkbox"],
326
+ results_section["is_format_caption_state"],
327
+ results_section["status_output"],
328
+ ]
329
+ )
330
+
331
+ # ========== Load/Save Metadata ==========
332
+ generation_section["load_file"].upload(
333
+ fn=gen_h.load_metadata,
334
+ inputs=[generation_section["load_file"]],
335
+ outputs=[
336
+ generation_section["task_type"],
337
+ generation_section["captions"],
338
+ generation_section["lyrics"],
339
+ generation_section["vocal_language"],
340
+ generation_section["bpm"],
341
+ generation_section["key_scale"],
342
+ generation_section["time_signature"],
343
+ generation_section["audio_duration"],
344
+ generation_section["batch_size_input"],
345
+ generation_section["inference_steps"],
346
+ generation_section["guidance_scale"],
347
+ generation_section["seed"],
348
+ generation_section["random_seed_checkbox"],
349
+ generation_section["use_adg"],
350
+ generation_section["cfg_interval_start"],
351
+ generation_section["cfg_interval_end"],
352
+ generation_section["shift"],
353
+ generation_section["infer_method"],
354
+ generation_section["custom_timesteps"],
355
+ generation_section["audio_format"],
356
+ generation_section["lm_temperature"],
357
+ generation_section["lm_cfg_scale"],
358
+ generation_section["lm_top_k"],
359
+ generation_section["lm_top_p"],
360
+ generation_section["lm_negative_prompt"],
361
+ generation_section["use_cot_metas"], # Added: use_cot_metas
362
+ generation_section["use_cot_caption"],
363
+ generation_section["use_cot_language"],
364
+ generation_section["audio_cover_strength"],
365
+ generation_section["think_checkbox"],
366
+ generation_section["text2music_audio_code_string"],
367
+ generation_section["repainting_start"],
368
+ generation_section["repainting_end"],
369
+ generation_section["track_name"],
370
+ generation_section["complete_track_classes"],
371
+ generation_section["instrumental_checkbox"], # Added: instrumental_checkbox
372
+ results_section["is_format_caption_state"]
373
+ ]
374
+ )
375
+
376
+ # Save buttons for all 8 audio outputs
377
+ download_existing_js = """(current_audio, batch_files) => {
378
+ // Debug: print what the input actually is
379
+ console.log("👉 [Debug] Current Audio Input:", current_audio);
380
+
381
+ // 1. Safety check
382
+ if (!current_audio) {
383
+ console.warn("⚠️ No audio selected or audio is empty.");
384
+ return;
385
+ }
386
+ if (!batch_files || !Array.isArray(batch_files)) {
387
+ console.warn("⚠️ Batch file list is empty/not ready.");
388
+ return;
389
+ }
390
+
391
+ // 2. Smartly extract path string
392
+ let pathString = "";
393
+
394
+ if (typeof current_audio === "string") {
395
+ // Case A: direct path string received
396
+ pathString = current_audio;
397
+ } else if (typeof current_audio === "object") {
398
+ // Case B: an object is received, try common properties
399
+ // Gradio file objects usually have path, url, or name
400
+ pathString = current_audio.path || current_audio.name || current_audio.url || "";
401
+ }
402
+
403
+ if (!pathString) {
404
+ console.error("❌ Error: Could not extract a valid path string from input.", current_audio);
405
+ return;
406
+ }
407
+
408
+ // 3. Extract Key (UUID)
409
+ // Path could be /tmp/.../uuid.mp3 or url like /file=.../uuid.mp3
410
+ let filename = pathString.split(/[\\\\/]/).pop(); // get the filename
411
+ let key = filename.split('.')[0]; // get UUID without extension
412
+
413
+ console.log(`🔑 Key extracted: ${key}`);
414
+
415
+ // 4. Find matching file(s) in the list
416
+ let targets = batch_files.filter(f => {
417
+ // Also extract names from batch_files objects
418
+ // f usually contains name (backend path) and orig_name (download name)
419
+ const fPath = f.name || f.path || "";
420
+ return fPath.includes(key);
421
+ });
422
+
423
+ if (targets.length === 0) {
424
+ console.warn("❌ No matching files found in batch list for key:", key);
425
+ alert("Batch list does not contain this file yet. Please wait for generation to finish.");
426
+ return;
427
+ }
428
+
429
+ // 5. Trigger download(s)
430
+ console.log(`🎯 Found ${targets.length} files to download.`);
431
+ targets.forEach((f, index) => {
432
+ setTimeout(() => {
433
+ const a = document.createElement('a');
434
+ // Prefer url (frontend-accessible link), otherwise try data
435
+ a.href = f.url || f.data;
436
+ a.download = f.orig_name || "download";
437
+ a.style.display = 'none';
438
+ document.body.appendChild(a);
439
+ a.click();
440
+ document.body.removeChild(a);
441
+ }, index * 1000); // 300ms interval to avoid browser blocking
442
+ });
443
+ }
444
+ """
445
+ for btn_idx in range(1, 9):
446
+ results_section[f"save_btn_{btn_idx}"].click(
447
+ fn=None,
448
+ inputs=[
449
+ results_section[f"generated_audio_{btn_idx}"],
450
+ results_section["generated_audio_batch"],
451
+ ],
452
+ js=download_existing_js # Run the above JS
453
+ )
454
+ # ========== Send to SRC Handlers ==========
455
+ for btn_idx in range(1, 9):
456
+ results_section[f"send_to_src_btn_{btn_idx}"].click(
457
+ fn=res_h.send_audio_to_src_with_metadata,
458
+ inputs=[
459
+ results_section[f"generated_audio_{btn_idx}"],
460
+ results_section["lm_metadata_state"]
461
+ ],
462
+ outputs=[
463
+ generation_section["src_audio"],
464
+ generation_section["bpm"],
465
+ generation_section["captions"],
466
+ generation_section["lyrics"],
467
+ generation_section["audio_duration"],
468
+ generation_section["key_scale"],
469
+ generation_section["vocal_language"],
470
+ generation_section["time_signature"],
471
+ results_section["is_format_caption_state"]
472
+ ]
473
+ )
474
+
475
+ # ========== Score Calculation Handlers ==========
476
+ # Use default argument to capture btn_idx value at definition time (Python closure fix)
477
+ def make_score_handler(idx):
478
+ return lambda scale, batch_idx, queue: res_h.calculate_score_handler_with_selection(
479
+ dit_handler, llm_handler, idx, scale, batch_idx, queue
480
+ )
481
+
482
+ for btn_idx in range(1, 9):
483
+ results_section[f"score_btn_{btn_idx}"].click(
484
+ fn=make_score_handler(btn_idx),
485
+ inputs=[
486
+ generation_section["score_scale"],
487
+ results_section["current_batch_index"],
488
+ results_section["batch_queue"],
489
+ ],
490
+ outputs=[
491
+ results_section[f"score_display_{btn_idx}"],
492
+ results_section[f"details_accordion_{btn_idx}"],
493
+ results_section["batch_queue"]
494
+ ]
495
+ )
496
+
497
+ # ========== LRC Timestamp Handlers ==========
498
+ # Use default argument to capture btn_idx value at definition time (Python closure fix)
499
+ def make_lrc_handler(idx):
500
+ return lambda batch_idx, queue, vocal_lang, infer_steps: res_h.generate_lrc_handler(
501
+ dit_handler, idx, batch_idx, queue, vocal_lang, infer_steps
502
+ )
503
+
504
+ for btn_idx in range(1, 9):
505
+ results_section[f"lrc_btn_{btn_idx}"].click(
506
+ fn=make_lrc_handler(btn_idx),
507
+ inputs=[
508
+ results_section["current_batch_index"],
509
+ results_section["batch_queue"],
510
+ generation_section["vocal_language"],
511
+ generation_section["inference_steps"],
512
+ ],
513
+ outputs=[
514
+ results_section[f"lrc_display_{btn_idx}"],
515
+ results_section[f"details_accordion_{btn_idx}"],
516
+ # NOTE: Removed generated_audio output!
517
+ # Audio subtitles are now updated via lrc_display.change() event.
518
+ results_section["batch_queue"]
519
+ ]
520
+ )
521
+
522
+ def generation_wrapper(*args):
523
+ yield from res_h.generate_with_batch_management(dit_handler, llm_handler, *args)
524
+ # ========== Generation Handler ==========
525
+ generation_section["generate_btn"].click(
526
+ fn=generation_wrapper,
527
+ inputs=[
528
+ generation_section["captions"],
529
+ generation_section["lyrics"],
530
+ generation_section["bpm"],
531
+ generation_section["key_scale"],
532
+ generation_section["time_signature"],
533
+ generation_section["vocal_language"],
534
+ generation_section["inference_steps"],
535
+ generation_section["guidance_scale"],
536
+ generation_section["random_seed_checkbox"],
537
+ generation_section["seed"],
538
+ generation_section["reference_audio"],
539
+ generation_section["audio_duration"],
540
+ generation_section["batch_size_input"],
541
+ generation_section["src_audio"],
542
+ generation_section["text2music_audio_code_string"],
543
+ generation_section["repainting_start"],
544
+ generation_section["repainting_end"],
545
+ generation_section["instruction_display_gen"],
546
+ generation_section["audio_cover_strength"],
547
+ generation_section["task_type"],
548
+ generation_section["use_adg"],
549
+ generation_section["cfg_interval_start"],
550
+ generation_section["cfg_interval_end"],
551
+ generation_section["shift"],
552
+ generation_section["infer_method"],
553
+ generation_section["custom_timesteps"],
554
+ generation_section["audio_format"],
555
+ generation_section["lm_temperature"],
556
+ generation_section["think_checkbox"],
557
+ generation_section["lm_cfg_scale"],
558
+ generation_section["lm_top_k"],
559
+ generation_section["lm_top_p"],
560
+ generation_section["lm_negative_prompt"],
561
+ generation_section["use_cot_metas"],
562
+ generation_section["use_cot_caption"],
563
+ generation_section["use_cot_language"],
564
+ results_section["is_format_caption_state"],
565
+ generation_section["constrained_decoding_debug"],
566
+ generation_section["allow_lm_batch"],
567
+ generation_section["auto_score"],
568
+ generation_section["auto_lrc"],
569
+ generation_section["score_scale"],
570
+ generation_section["lm_batch_chunk_size"],
571
+ generation_section["track_name"],
572
+ generation_section["complete_track_classes"],
573
+ generation_section["autogen_checkbox"],
574
+ results_section["current_batch_index"],
575
+ results_section["total_batches"],
576
+ results_section["batch_queue"],
577
+ results_section["generation_params_state"],
578
+ ],
579
+ outputs=[
580
+ results_section["generated_audio_1"],
581
+ results_section["generated_audio_2"],
582
+ results_section["generated_audio_3"],
583
+ results_section["generated_audio_4"],
584
+ results_section["generated_audio_5"],
585
+ results_section["generated_audio_6"],
586
+ results_section["generated_audio_7"],
587
+ results_section["generated_audio_8"],
588
+ results_section["generated_audio_batch"],
589
+ results_section["generation_info"],
590
+ results_section["status_output"],
591
+ generation_section["seed"],
592
+ results_section["score_display_1"],
593
+ results_section["score_display_2"],
594
+ results_section["score_display_3"],
595
+ results_section["score_display_4"],
596
+ results_section["score_display_5"],
597
+ results_section["score_display_6"],
598
+ results_section["score_display_7"],
599
+ results_section["score_display_8"],
600
+ results_section["codes_display_1"],
601
+ results_section["codes_display_2"],
602
+ results_section["codes_display_3"],
603
+ results_section["codes_display_4"],
604
+ results_section["codes_display_5"],
605
+ results_section["codes_display_6"],
606
+ results_section["codes_display_7"],
607
+ results_section["codes_display_8"],
608
+ results_section["details_accordion_1"],
609
+ results_section["details_accordion_2"],
610
+ results_section["details_accordion_3"],
611
+ results_section["details_accordion_4"],
612
+ results_section["details_accordion_5"],
613
+ results_section["details_accordion_6"],
614
+ results_section["details_accordion_7"],
615
+ results_section["details_accordion_8"],
616
+ results_section["lrc_display_1"],
617
+ results_section["lrc_display_2"],
618
+ results_section["lrc_display_3"],
619
+ results_section["lrc_display_4"],
620
+ results_section["lrc_display_5"],
621
+ results_section["lrc_display_6"],
622
+ results_section["lrc_display_7"],
623
+ results_section["lrc_display_8"],
624
+ results_section["lm_metadata_state"],
625
+ results_section["is_format_caption_state"],
626
+ results_section["current_batch_index"],
627
+ results_section["total_batches"],
628
+ results_section["batch_queue"],
629
+ results_section["generation_params_state"],
630
+ results_section["batch_indicator"],
631
+ results_section["prev_batch_btn"],
632
+ results_section["next_batch_btn"],
633
+ results_section["next_batch_status"],
634
+ results_section["restore_params_btn"],
635
+ ]
636
+ ).then(
637
+ fn=lambda *args: res_h.generate_next_batch_background(dit_handler, llm_handler, *args),
638
+ inputs=[
639
+ generation_section["autogen_checkbox"],
640
+ results_section["generation_params_state"],
641
+ results_section["current_batch_index"],
642
+ results_section["total_batches"],
643
+ results_section["batch_queue"],
644
+ results_section["is_format_caption_state"],
645
+ ],
646
+ outputs=[
647
+ results_section["batch_queue"],
648
+ results_section["total_batches"],
649
+ results_section["next_batch_status"],
650
+ results_section["next_batch_btn"],
651
+ ]
652
+ )
653
+
654
+ # ========== Batch Navigation Handlers ==========
655
+ results_section["prev_batch_btn"].click(
656
+ fn=res_h.navigate_to_previous_batch,
657
+ inputs=[
658
+ results_section["current_batch_index"],
659
+ results_section["batch_queue"],
660
+ ],
661
+ outputs=[
662
+ results_section["generated_audio_1"],
663
+ results_section["generated_audio_2"],
664
+ results_section["generated_audio_3"],
665
+ results_section["generated_audio_4"],
666
+ results_section["generated_audio_5"],
667
+ results_section["generated_audio_6"],
668
+ results_section["generated_audio_7"],
669
+ results_section["generated_audio_8"],
670
+ results_section["generated_audio_batch"],
671
+ results_section["generation_info"],
672
+ results_section["current_batch_index"],
673
+ results_section["batch_indicator"],
674
+ results_section["prev_batch_btn"],
675
+ results_section["next_batch_btn"],
676
+ results_section["status_output"],
677
+ results_section["score_display_1"],
678
+ results_section["score_display_2"],
679
+ results_section["score_display_3"],
680
+ results_section["score_display_4"],
681
+ results_section["score_display_5"],
682
+ results_section["score_display_6"],
683
+ results_section["score_display_7"],
684
+ results_section["score_display_8"],
685
+ results_section["codes_display_1"],
686
+ results_section["codes_display_2"],
687
+ results_section["codes_display_3"],
688
+ results_section["codes_display_4"],
689
+ results_section["codes_display_5"],
690
+ results_section["codes_display_6"],
691
+ results_section["codes_display_7"],
692
+ results_section["codes_display_8"],
693
+ results_section["lrc_display_1"],
694
+ results_section["lrc_display_2"],
695
+ results_section["lrc_display_3"],
696
+ results_section["lrc_display_4"],
697
+ results_section["lrc_display_5"],
698
+ results_section["lrc_display_6"],
699
+ results_section["lrc_display_7"],
700
+ results_section["lrc_display_8"],
701
+ results_section["details_accordion_1"],
702
+ results_section["details_accordion_2"],
703
+ results_section["details_accordion_3"],
704
+ results_section["details_accordion_4"],
705
+ results_section["details_accordion_5"],
706
+ results_section["details_accordion_6"],
707
+ results_section["details_accordion_7"],
708
+ results_section["details_accordion_8"],
709
+ results_section["restore_params_btn"],
710
+ ]
711
+ )
712
+
713
+ results_section["next_batch_btn"].click(
714
+ fn=res_h.capture_current_params,
715
+ inputs=[
716
+ generation_section["captions"],
717
+ generation_section["lyrics"],
718
+ generation_section["bpm"],
719
+ generation_section["key_scale"],
720
+ generation_section["time_signature"],
721
+ generation_section["vocal_language"],
722
+ generation_section["inference_steps"],
723
+ generation_section["guidance_scale"],
724
+ generation_section["random_seed_checkbox"],
725
+ generation_section["seed"],
726
+ generation_section["reference_audio"],
727
+ generation_section["audio_duration"],
728
+ generation_section["batch_size_input"],
729
+ generation_section["src_audio"],
730
+ generation_section["text2music_audio_code_string"],
731
+ generation_section["repainting_start"],
732
+ generation_section["repainting_end"],
733
+ generation_section["instruction_display_gen"],
734
+ generation_section["audio_cover_strength"],
735
+ generation_section["task_type"],
736
+ generation_section["use_adg"],
737
+ generation_section["cfg_interval_start"],
738
+ generation_section["cfg_interval_end"],
739
+ generation_section["shift"],
740
+ generation_section["infer_method"],
741
+ generation_section["custom_timesteps"],
742
+ generation_section["audio_format"],
743
+ generation_section["lm_temperature"],
744
+ generation_section["think_checkbox"],
745
+ generation_section["lm_cfg_scale"],
746
+ generation_section["lm_top_k"],
747
+ generation_section["lm_top_p"],
748
+ generation_section["lm_negative_prompt"],
749
+ generation_section["use_cot_metas"],
750
+ generation_section["use_cot_caption"],
751
+ generation_section["use_cot_language"],
752
+ generation_section["constrained_decoding_debug"],
753
+ generation_section["allow_lm_batch"],
754
+ generation_section["auto_score"],
755
+ generation_section["auto_lrc"],
756
+ generation_section["score_scale"],
757
+ generation_section["lm_batch_chunk_size"],
758
+ generation_section["track_name"],
759
+ generation_section["complete_track_classes"],
760
+ ],
761
+ outputs=[results_section["generation_params_state"]]
762
+ ).then(
763
+ fn=res_h.navigate_to_next_batch,
764
+ inputs=[
765
+ generation_section["autogen_checkbox"],
766
+ results_section["current_batch_index"],
767
+ results_section["total_batches"],
768
+ results_section["batch_queue"],
769
+ ],
770
+ outputs=[
771
+ results_section["generated_audio_1"],
772
+ results_section["generated_audio_2"],
773
+ results_section["generated_audio_3"],
774
+ results_section["generated_audio_4"],
775
+ results_section["generated_audio_5"],
776
+ results_section["generated_audio_6"],
777
+ results_section["generated_audio_7"],
778
+ results_section["generated_audio_8"],
779
+ results_section["generated_audio_batch"],
780
+ results_section["generation_info"],
781
+ results_section["current_batch_index"],
782
+ results_section["batch_indicator"],
783
+ results_section["prev_batch_btn"],
784
+ results_section["next_batch_btn"],
785
+ results_section["status_output"],
786
+ results_section["next_batch_status"],
787
+ results_section["score_display_1"],
788
+ results_section["score_display_2"],
789
+ results_section["score_display_3"],
790
+ results_section["score_display_4"],
791
+ results_section["score_display_5"],
792
+ results_section["score_display_6"],
793
+ results_section["score_display_7"],
794
+ results_section["score_display_8"],
795
+ results_section["codes_display_1"],
796
+ results_section["codes_display_2"],
797
+ results_section["codes_display_3"],
798
+ results_section["codes_display_4"],
799
+ results_section["codes_display_5"],
800
+ results_section["codes_display_6"],
801
+ results_section["codes_display_7"],
802
+ results_section["codes_display_8"],
803
+ results_section["lrc_display_1"],
804
+ results_section["lrc_display_2"],
805
+ results_section["lrc_display_3"],
806
+ results_section["lrc_display_4"],
807
+ results_section["lrc_display_5"],
808
+ results_section["lrc_display_6"],
809
+ results_section["lrc_display_7"],
810
+ results_section["lrc_display_8"],
811
+ results_section["details_accordion_1"],
812
+ results_section["details_accordion_2"],
813
+ results_section["details_accordion_3"],
814
+ results_section["details_accordion_4"],
815
+ results_section["details_accordion_5"],
816
+ results_section["details_accordion_6"],
817
+ results_section["details_accordion_7"],
818
+ results_section["details_accordion_8"],
819
+ results_section["restore_params_btn"],
820
+ ]
821
+ ).then(
822
+ fn=lambda *args: res_h.generate_next_batch_background(dit_handler, llm_handler, *args),
823
+ inputs=[
824
+ generation_section["autogen_checkbox"],
825
+ results_section["generation_params_state"],
826
+ results_section["current_batch_index"],
827
+ results_section["total_batches"],
828
+ results_section["batch_queue"],
829
+ results_section["is_format_caption_state"],
830
+ ],
831
+ outputs=[
832
+ results_section["batch_queue"],
833
+ results_section["total_batches"],
834
+ results_section["next_batch_status"],
835
+ results_section["next_batch_btn"],
836
+ ]
837
+ )
838
+
839
+ # ========== Restore Parameters Handler ==========
840
+ results_section["restore_params_btn"].click(
841
+ fn=res_h.restore_batch_parameters,
842
+ inputs=[
843
+ results_section["current_batch_index"],
844
+ results_section["batch_queue"]
845
+ ],
846
+ outputs=[
847
+ generation_section["text2music_audio_code_string"],
848
+ generation_section["captions"],
849
+ generation_section["lyrics"],
850
+ generation_section["bpm"],
851
+ generation_section["key_scale"],
852
+ generation_section["time_signature"],
853
+ generation_section["vocal_language"],
854
+ generation_section["audio_duration"],
855
+ generation_section["batch_size_input"],
856
+ generation_section["inference_steps"],
857
+ generation_section["lm_temperature"],
858
+ generation_section["lm_cfg_scale"],
859
+ generation_section["lm_top_k"],
860
+ generation_section["lm_top_p"],
861
+ generation_section["think_checkbox"],
862
+ generation_section["use_cot_caption"],
863
+ generation_section["use_cot_language"],
864
+ generation_section["allow_lm_batch"],
865
+ generation_section["track_name"],
866
+ generation_section["complete_track_classes"],
867
+ ]
868
+ )
869
+
870
+ # ========== LRC Display Change Handlers ==========
871
+ # NEW APPROACH: Use lrc_display.change() to update audio subtitles
872
+ # This decouples audio value updates from subtitle updates, avoiding flickering.
873
+ #
874
+ # When lrc_display text changes (from generate, LRC button, or manual edit):
875
+ # 1. lrc_display.change() is triggered
876
+ # 2. update_audio_subtitles_from_lrc() parses LRC and updates audio subtitles
877
+ # 3. Audio value is NEVER updated here - only subtitles
878
+ for lrc_idx in range(1, 9):
879
+ results_section[f"lrc_display_{lrc_idx}"].change(
880
+ fn=res_h.update_audio_subtitles_from_lrc,
881
+ inputs=[
882
+ results_section[f"lrc_display_{lrc_idx}"],
883
+ # audio_duration not needed - parse_lrc_to_subtitles calculates end time from timestamps
884
+ ],
885
+ outputs=[
886
+ results_section[f"generated_audio_{lrc_idx}"], # Only updates subtitles, not value
887
+ ]
888
+ )
889
+
890
+
891
+ def setup_training_event_handlers(demo, dit_handler, llm_handler, training_section):
892
+ """Setup event handlers for the training tab (dataset builder and LoRA training)"""
893
+
894
+ # ========== Load Existing Dataset (Top Section) ==========
895
+
896
+ # Load existing dataset JSON at the top of Dataset Builder
897
+ training_section["load_json_btn"].click(
898
+ fn=train_h.load_existing_dataset_for_preprocess,
899
+ inputs=[
900
+ training_section["load_json_path"],
901
+ training_section["dataset_builder_state"],
902
+ ],
903
+ outputs=[
904
+ training_section["load_json_status"],
905
+ training_section["audio_files_table"],
906
+ training_section["sample_selector"],
907
+ training_section["dataset_builder_state"],
908
+ # Also update preview fields with first sample
909
+ training_section["preview_audio"],
910
+ training_section["preview_filename"],
911
+ training_section["edit_caption"],
912
+ training_section["edit_lyrics"],
913
+ training_section["edit_bpm"],
914
+ training_section["edit_keyscale"],
915
+ training_section["edit_timesig"],
916
+ training_section["edit_duration"],
917
+ training_section["edit_language"],
918
+ training_section["edit_instrumental"],
919
+ ]
920
+ )
921
+
922
+ # ========== Dataset Builder Handlers ==========
923
+
924
+ # Scan directory for audio files
925
+ training_section["scan_btn"].click(
926
+ fn=lambda dir, name, tag, pos, instr, state: train_h.scan_directory(
927
+ dir, name, tag, pos, instr, state
928
+ ),
929
+ inputs=[
930
+ training_section["audio_directory"],
931
+ training_section["dataset_name"],
932
+ training_section["custom_tag"],
933
+ training_section["tag_position"],
934
+ training_section["all_instrumental"],
935
+ training_section["dataset_builder_state"],
936
+ ],
937
+ outputs=[
938
+ training_section["audio_files_table"],
939
+ training_section["scan_status"],
940
+ training_section["sample_selector"],
941
+ training_section["dataset_builder_state"],
942
+ ]
943
+ )
944
+
945
+ # Auto-label all samples
946
+ training_section["auto_label_btn"].click(
947
+ fn=lambda state, skip: train_h.auto_label_all(dit_handler, llm_handler, state, skip),
948
+ inputs=[
949
+ training_section["dataset_builder_state"],
950
+ training_section["skip_metas"],
951
+ ],
952
+ outputs=[
953
+ training_section["audio_files_table"],
954
+ training_section["label_progress"],
955
+ training_section["dataset_builder_state"],
956
+ ]
957
+ )
958
+
959
+ # Sample selector change - update preview
960
+ training_section["sample_selector"].change(
961
+ fn=train_h.get_sample_preview,
962
+ inputs=[
963
+ training_section["sample_selector"],
964
+ training_section["dataset_builder_state"],
965
+ ],
966
+ outputs=[
967
+ training_section["preview_audio"],
968
+ training_section["preview_filename"],
969
+ training_section["edit_caption"],
970
+ training_section["edit_lyrics"],
971
+ training_section["edit_bpm"],
972
+ training_section["edit_keyscale"],
973
+ training_section["edit_timesig"],
974
+ training_section["edit_duration"],
975
+ training_section["edit_language"],
976
+ training_section["edit_instrumental"],
977
+ ]
978
+ )
979
+
980
+ # Save sample edit
981
+ training_section["save_edit_btn"].click(
982
+ fn=train_h.save_sample_edit,
983
+ inputs=[
984
+ training_section["sample_selector"],
985
+ training_section["edit_caption"],
986
+ training_section["edit_lyrics"],
987
+ training_section["edit_bpm"],
988
+ training_section["edit_keyscale"],
989
+ training_section["edit_timesig"],
990
+ training_section["edit_language"],
991
+ training_section["edit_instrumental"],
992
+ training_section["dataset_builder_state"],
993
+ ],
994
+ outputs=[
995
+ training_section["audio_files_table"],
996
+ training_section["edit_status"],
997
+ training_section["dataset_builder_state"],
998
+ ]
999
+ )
1000
+
1001
+ # Update settings when changed
1002
+ for trigger in [training_section["custom_tag"], training_section["tag_position"], training_section["all_instrumental"]]:
1003
+ trigger.change(
1004
+ fn=train_h.update_settings,
1005
+ inputs=[
1006
+ training_section["custom_tag"],
1007
+ training_section["tag_position"],
1008
+ training_section["all_instrumental"],
1009
+ training_section["dataset_builder_state"],
1010
+ ],
1011
+ outputs=[training_section["dataset_builder_state"]]
1012
+ )
1013
+
1014
+ # Save dataset
1015
+ training_section["save_dataset_btn"].click(
1016
+ fn=train_h.save_dataset,
1017
+ inputs=[
1018
+ training_section["save_path"],
1019
+ training_section["dataset_name"],
1020
+ training_section["dataset_builder_state"],
1021
+ ],
1022
+ outputs=[training_section["save_status"]]
1023
+ )
1024
+
1025
+ # ========== Preprocess Handlers ==========
1026
+
1027
+ # Load existing dataset JSON for preprocessing
1028
+ # This also updates the preview section so users can view/edit samples
1029
+ training_section["load_existing_dataset_btn"].click(
1030
+ fn=train_h.load_existing_dataset_for_preprocess,
1031
+ inputs=[
1032
+ training_section["load_existing_dataset_path"],
1033
+ training_section["dataset_builder_state"],
1034
+ ],
1035
+ outputs=[
1036
+ training_section["load_existing_status"],
1037
+ training_section["audio_files_table"],
1038
+ training_section["sample_selector"],
1039
+ training_section["dataset_builder_state"],
1040
+ # Also update preview fields with first sample
1041
+ training_section["preview_audio"],
1042
+ training_section["preview_filename"],
1043
+ training_section["edit_caption"],
1044
+ training_section["edit_lyrics"],
1045
+ training_section["edit_bpm"],
1046
+ training_section["edit_keyscale"],
1047
+ training_section["edit_timesig"],
1048
+ training_section["edit_duration"],
1049
+ training_section["edit_language"],
1050
+ training_section["edit_instrumental"],
1051
+ ]
1052
+ )
1053
+
1054
+ # Preprocess dataset to tensor files
1055
+ training_section["preprocess_btn"].click(
1056
+ fn=lambda output_dir, state: train_h.preprocess_dataset(
1057
+ output_dir, dit_handler, state
1058
+ ),
1059
+ inputs=[
1060
+ training_section["preprocess_output_dir"],
1061
+ training_section["dataset_builder_state"],
1062
+ ],
1063
+ outputs=[training_section["preprocess_progress"]]
1064
+ )
1065
+
1066
+ # ========== Training Tab Handlers ==========
1067
+
1068
+ # Load preprocessed tensor dataset
1069
+ training_section["load_dataset_btn"].click(
1070
+ fn=train_h.load_training_dataset,
1071
+ inputs=[training_section["training_tensor_dir"]],
1072
+ outputs=[training_section["training_dataset_info"]]
1073
+ )
1074
+
1075
+ # Start training from preprocessed tensors
1076
+ def training_wrapper(tensor_dir, r, a, d, lr, ep, bs, ga, se, sh, sd, od, ts):
1077
+ try:
1078
+ for progress, log, plot, state in train_h.start_training(
1079
+ tensor_dir, dit_handler, r, a, d, lr, ep, bs, ga, se, sh, sd, od, ts
1080
+ ):
1081
+ yield progress, log, plot, state
1082
+ except Exception as e:
1083
+ logger.exception("Training wrapper error")
1084
+ yield f"❌ Error: {str(e)}", str(e), None, ts
1085
+
1086
+ training_section["start_training_btn"].click(
1087
+ fn=training_wrapper,
1088
+ inputs=[
1089
+ training_section["training_tensor_dir"],
1090
+ training_section["lora_rank"],
1091
+ training_section["lora_alpha"],
1092
+ training_section["lora_dropout"],
1093
+ training_section["learning_rate"],
1094
+ training_section["train_epochs"],
1095
+ training_section["train_batch_size"],
1096
+ training_section["gradient_accumulation"],
1097
+ training_section["save_every_n_epochs"],
1098
+ training_section["training_shift"],
1099
+ training_section["training_seed"],
1100
+ training_section["lora_output_dir"],
1101
+ training_section["training_state"],
1102
+ ],
1103
+ outputs=[
1104
+ training_section["training_progress"],
1105
+ training_section["training_log"],
1106
+ training_section["training_loss_plot"],
1107
+ training_section["training_state"],
1108
+ ]
1109
+ )
1110
+
1111
+ # Stop training
1112
+ training_section["stop_training_btn"].click(
1113
+ fn=train_h.stop_training,
1114
+ inputs=[training_section["training_state"]],
1115
+ outputs=[
1116
+ training_section["training_progress"],
1117
+ training_section["training_state"],
1118
+ ]
1119
+ )
1120
+
1121
+ # Export LoRA
1122
+ training_section["export_lora_btn"].click(
1123
+ fn=train_h.export_lora,
1124
+ inputs=[
1125
+ training_section["export_path"],
1126
+ training_section["lora_output_dir"],
1127
+ ],
1128
+ outputs=[training_section["export_status"]]
1129
+ )
code/acestep/gradio_ui/events/generation_handlers.py ADDED
@@ -0,0 +1,974 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generation Input Handlers Module
3
+ Contains event handlers and helper functions related to generation inputs
4
+ """
5
+ import os
6
+ import json
7
+ import random
8
+ import glob
9
+ import gradio as gr
10
+ from typing import Optional, List, Tuple
11
+ from acestep.constants import (
12
+ TASK_TYPES_TURBO,
13
+ TASK_TYPES_BASE,
14
+ )
15
+ from acestep.gradio_ui.i18n import t
16
+ from acestep.inference import understand_music, create_sample, format_sample
17
+
18
+
19
+ def parse_and_validate_timesteps(
20
+ timesteps_str: str,
21
+ inference_steps: int
22
+ ) -> Tuple[Optional[List[float]], bool, str]:
23
+ """
24
+ Parse timesteps string and validate.
25
+
26
+ Args:
27
+ timesteps_str: Comma-separated timesteps string (e.g., "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
28
+ inference_steps: Expected number of inference steps
29
+
30
+ Returns:
31
+ Tuple of (parsed_timesteps, has_warning, warning_message)
32
+ - parsed_timesteps: List of float timesteps, or None if invalid/empty
33
+ - has_warning: Whether a warning was shown
34
+ - warning_message: Description of the warning
35
+ """
36
+ if not timesteps_str or not timesteps_str.strip():
37
+ return None, False, ""
38
+
39
+ # Parse comma-separated values
40
+ values = [v.strip() for v in timesteps_str.split(",") if v.strip()]
41
+
42
+ if not values:
43
+ return None, False, ""
44
+
45
+ # Handle optional trailing 0
46
+ if values[-1] != "0":
47
+ values.append("0")
48
+
49
+ try:
50
+ timesteps = [float(v) for v in values]
51
+ except ValueError:
52
+ gr.Warning(t("messages.invalid_timesteps_format"))
53
+ return None, True, "Invalid format"
54
+
55
+ # Validate range [0, 1]
56
+ if any(ts < 0 or ts > 1 for ts in timesteps):
57
+ gr.Warning(t("messages.timesteps_out_of_range"))
58
+ return None, True, "Out of range"
59
+
60
+ # Check if count matches inference_steps
61
+ actual_steps = len(timesteps) - 1
62
+ if actual_steps != inference_steps:
63
+ gr.Warning(t("messages.timesteps_count_mismatch", actual=actual_steps, expected=inference_steps))
64
+ return timesteps, True, f"Using {actual_steps} steps from timesteps"
65
+
66
+ return timesteps, False, ""
67
+
68
+
69
+ def load_metadata(file_obj):
70
+ """Load generation parameters from a JSON file"""
71
+ if file_obj is None:
72
+ gr.Warning(t("messages.no_file_selected"))
73
+ return [None] * 36 + [False] # Return None for all fields, False for is_format_caption
74
+
75
+ try:
76
+ # Read the uploaded file
77
+ if hasattr(file_obj, 'name'):
78
+ filepath = file_obj.name
79
+ else:
80
+ filepath = file_obj
81
+
82
+ with open(filepath, 'r', encoding='utf-8') as f:
83
+ metadata = json.load(f)
84
+
85
+ # Extract all fields
86
+ task_type = metadata.get('task_type', 'text2music')
87
+ captions = metadata.get('caption', '')
88
+ lyrics = metadata.get('lyrics', '')
89
+ vocal_language = metadata.get('vocal_language', 'unknown')
90
+
91
+ # Convert bpm
92
+ bpm_value = metadata.get('bpm')
93
+ if bpm_value is not None and bpm_value != "N/A":
94
+ try:
95
+ bpm = int(bpm_value) if bpm_value else None
96
+ except:
97
+ bpm = None
98
+ else:
99
+ bpm = None
100
+
101
+ key_scale = metadata.get('keyscale', '')
102
+ time_signature = metadata.get('timesignature', '')
103
+
104
+ # Convert duration
105
+ duration_value = metadata.get('duration', -1)
106
+ if duration_value is not None and duration_value != "N/A":
107
+ try:
108
+ audio_duration = float(duration_value)
109
+ except:
110
+ audio_duration = -1
111
+ else:
112
+ audio_duration = -1
113
+
114
+ batch_size = metadata.get('batch_size', 2)
115
+ inference_steps = metadata.get('inference_steps', 8)
116
+ guidance_scale = metadata.get('guidance_scale', 7.0)
117
+ seed = metadata.get('seed', '-1')
118
+ random_seed = False # Always set to False when loading to enable reproducibility with saved seed
119
+ use_adg = metadata.get('use_adg', False)
120
+ cfg_interval_start = metadata.get('cfg_interval_start', 0.0)
121
+ cfg_interval_end = metadata.get('cfg_interval_end', 1.0)
122
+ audio_format = metadata.get('audio_format', 'mp3')
123
+ lm_temperature = metadata.get('lm_temperature', 0.85)
124
+ lm_cfg_scale = metadata.get('lm_cfg_scale', 2.0)
125
+ lm_top_k = metadata.get('lm_top_k', 0)
126
+ lm_top_p = metadata.get('lm_top_p', 0.9)
127
+ lm_negative_prompt = metadata.get('lm_negative_prompt', 'NO USER INPUT')
128
+ use_cot_metas = metadata.get('use_cot_metas', True) # Added: read use_cot_metas
129
+ use_cot_caption = metadata.get('use_cot_caption', True)
130
+ use_cot_language = metadata.get('use_cot_language', True)
131
+ audio_cover_strength = metadata.get('audio_cover_strength', 1.0)
132
+ think = metadata.get('thinking', True) # Fixed: read 'thinking' not 'think'
133
+ audio_codes = metadata.get('audio_codes', '')
134
+ repainting_start = metadata.get('repainting_start', 0.0)
135
+ repainting_end = metadata.get('repainting_end', -1)
136
+ track_name = metadata.get('track_name')
137
+ complete_track_classes = metadata.get('complete_track_classes', [])
138
+ shift = metadata.get('shift', 3.0) # Default 3.0 for base models
139
+ infer_method = metadata.get('infer_method', 'ode') # Default 'ode' for diffusion inference
140
+ custom_timesteps = metadata.get('timesteps', '') # Custom timesteps (stored as 'timesteps' in JSON)
141
+ if custom_timesteps is None:
142
+ custom_timesteps = ''
143
+ instrumental = metadata.get('instrumental', False) # Added: read instrumental
144
+
145
+ gr.Info(t("messages.params_loaded", filename=os.path.basename(filepath)))
146
+
147
+ return (
148
+ task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature,
149
+ audio_duration, batch_size, inference_steps, guidance_scale, seed, random_seed,
150
+ use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method,
151
+ custom_timesteps, # Added: custom_timesteps (between infer_method and audio_format)
152
+ audio_format, lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
153
+ use_cot_metas, use_cot_caption, use_cot_language, audio_cover_strength,
154
+ think, audio_codes, repainting_start, repainting_end,
155
+ track_name, complete_track_classes, instrumental,
156
+ True # Set is_format_caption to True when loading from file
157
+ )
158
+
159
+ except json.JSONDecodeError as e:
160
+ gr.Warning(t("messages.invalid_json", error=str(e)))
161
+ return [None] * 36 + [False]
162
+ except Exception as e:
163
+ gr.Warning(t("messages.load_error", error=str(e)))
164
+ return [None] * 36 + [False]
165
+
166
+
167
+ def load_random_example(task_type: str):
168
+ """Load a random example from the task-specific examples directory
169
+
170
+ Args:
171
+ task_type: The task type (e.g., "text2music")
172
+
173
+ Returns:
174
+ Tuple of (caption, lyrics, think, bpm, duration, keyscale, language, timesignature) for updating UI components
175
+ """
176
+ try:
177
+ # Get the project root directory
178
+ current_file = os.path.abspath(__file__)
179
+ # This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
180
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
181
+
182
+ # Construct the examples directory path
183
+ examples_dir = os.path.join(project_root, "examples", task_type)
184
+
185
+ # Check if directory exists
186
+ if not os.path.exists(examples_dir):
187
+ gr.Warning(f"Examples directory not found: examples/{task_type}/")
188
+ return "", "", True, None, None, "", "", ""
189
+
190
+ # Find all JSON files in the directory
191
+ json_files = glob.glob(os.path.join(examples_dir, "*.json"))
192
+
193
+ if not json_files:
194
+ gr.Warning(f"No JSON files found in examples/{task_type}/")
195
+ return "", "", True, None, None, "", "", ""
196
+
197
+ # Randomly select one file
198
+ selected_file = random.choice(json_files)
199
+
200
+ # Read and parse JSON
201
+ try:
202
+ with open(selected_file, 'r', encoding='utf-8') as f:
203
+ data = json.load(f)
204
+
205
+ # Extract caption (prefer 'caption', fallback to 'prompt')
206
+ caption_value = data.get('caption', data.get('prompt', ''))
207
+ if not isinstance(caption_value, str):
208
+ caption_value = str(caption_value) if caption_value else ''
209
+
210
+ # Extract lyrics
211
+ lyrics_value = data.get('lyrics', '')
212
+ if not isinstance(lyrics_value, str):
213
+ lyrics_value = str(lyrics_value) if lyrics_value else ''
214
+
215
+ # Extract think (default to True if not present)
216
+ think_value = data.get('think', True)
217
+ if not isinstance(think_value, bool):
218
+ think_value = True
219
+
220
+ # Extract optional metadata fields
221
+ bpm_value = None
222
+ if 'bpm' in data and data['bpm'] not in [None, "N/A", ""]:
223
+ try:
224
+ bpm_value = int(data['bpm'])
225
+ except (ValueError, TypeError):
226
+ pass
227
+
228
+ duration_value = None
229
+ if 'duration' in data and data['duration'] not in [None, "N/A", ""]:
230
+ try:
231
+ duration_value = float(data['duration'])
232
+ except (ValueError, TypeError):
233
+ pass
234
+
235
+ keyscale_value = data.get('keyscale', '')
236
+ if keyscale_value in [None, "N/A"]:
237
+ keyscale_value = ''
238
+
239
+ language_value = data.get('language', '')
240
+ if language_value in [None, "N/A"]:
241
+ language_value = ''
242
+
243
+ timesignature_value = data.get('timesignature', '')
244
+ if timesignature_value in [None, "N/A"]:
245
+ timesignature_value = ''
246
+
247
+ gr.Info(t("messages.example_loaded", filename=os.path.basename(selected_file)))
248
+ return caption_value, lyrics_value, think_value, bpm_value, duration_value, keyscale_value, language_value, timesignature_value
249
+
250
+ except json.JSONDecodeError as e:
251
+ gr.Warning(t("messages.example_failed", filename=os.path.basename(selected_file), error=str(e)))
252
+ return "", "", True, None, None, "", "", ""
253
+ except Exception as e:
254
+ gr.Warning(t("messages.example_error", error=str(e)))
255
+ return "", "", True, None, None, "", "", ""
256
+
257
+ except Exception as e:
258
+ gr.Warning(t("messages.example_error", error=str(e)))
259
+ return "", "", True, None, None, "", "", ""
260
+
261
+
262
+ def sample_example_smart(llm_handler, task_type: str, constrained_decoding_debug: bool = False):
263
+ """Smart sample function that uses LM if initialized, otherwise falls back to examples
264
+
265
+ This is a Gradio wrapper that uses the understand_music API from acestep.inference
266
+ to generate examples when LM is available.
267
+
268
+ Args:
269
+ llm_handler: LLM handler instance
270
+ task_type: The task type (e.g., "text2music")
271
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
272
+
273
+ Returns:
274
+ Tuple of (caption, lyrics, think, bpm, duration, keyscale, language, timesignature) for updating UI components
275
+ """
276
+ # Check if LM is initialized
277
+ if llm_handler.llm_initialized:
278
+ # Use LM to generate example via understand_music API
279
+ try:
280
+ result = understand_music(
281
+ llm_handler=llm_handler,
282
+ audio_codes="NO USER INPUT", # Empty input triggers example generation
283
+ temperature=0.85,
284
+ use_constrained_decoding=True,
285
+ constrained_decoding_debug=constrained_decoding_debug,
286
+ )
287
+
288
+ if result.success:
289
+ gr.Info(t("messages.lm_generated"))
290
+ return (
291
+ result.caption,
292
+ result.lyrics,
293
+ True, # Always enable think when using LM-generated examples
294
+ result.bpm,
295
+ result.duration,
296
+ result.keyscale,
297
+ result.language,
298
+ result.timesignature,
299
+ )
300
+ else:
301
+ gr.Warning(t("messages.lm_fallback"))
302
+ return load_random_example(task_type)
303
+
304
+ except Exception as e:
305
+ gr.Warning(t("messages.lm_fallback"))
306
+ return load_random_example(task_type)
307
+ else:
308
+ # LM not initialized, use examples directory
309
+ return load_random_example(task_type)
310
+
311
+
312
+ def load_random_simple_description():
313
+ """Load a random description from the simple_mode examples directory.
314
+
315
+ Returns:
316
+ Tuple of (description, instrumental, vocal_language) for updating UI components
317
+ """
318
+ try:
319
+ # Get the project root directory
320
+ current_file = os.path.abspath(__file__)
321
+ # This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
322
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
323
+
324
+ # Construct the examples directory path
325
+ examples_dir = os.path.join(project_root, "examples", "simple_mode")
326
+
327
+ # Check if directory exists
328
+ if not os.path.exists(examples_dir):
329
+ gr.Warning(t("messages.simple_examples_not_found"))
330
+ return gr.update(), gr.update(), gr.update()
331
+
332
+ # Find all JSON files in the directory
333
+ json_files = glob.glob(os.path.join(examples_dir, "*.json"))
334
+
335
+ if not json_files:
336
+ gr.Warning(t("messages.simple_examples_empty"))
337
+ return gr.update(), gr.update(), gr.update()
338
+
339
+ # Randomly select one file
340
+ selected_file = random.choice(json_files)
341
+
342
+ # Read and parse JSON
343
+ try:
344
+ with open(selected_file, 'r', encoding='utf-8') as f:
345
+ data = json.load(f)
346
+
347
+ # Extract fields
348
+ description = data.get('description', '')
349
+ instrumental = data.get('instrumental', False)
350
+ vocal_language = data.get('vocal_language', 'unknown')
351
+
352
+ # Ensure vocal_language is a string
353
+ if isinstance(vocal_language, list):
354
+ vocal_language = vocal_language[0] if vocal_language else 'unknown'
355
+
356
+ gr.Info(t("messages.simple_example_loaded", filename=os.path.basename(selected_file)))
357
+ return description, instrumental, vocal_language
358
+
359
+ except json.JSONDecodeError as e:
360
+ gr.Warning(t("messages.example_failed", filename=os.path.basename(selected_file), error=str(e)))
361
+ return gr.update(), gr.update(), gr.update()
362
+ except Exception as e:
363
+ gr.Warning(t("messages.example_error", error=str(e)))
364
+ return gr.update(), gr.update(), gr.update()
365
+
366
+ except Exception as e:
367
+ gr.Warning(t("messages.example_error", error=str(e)))
368
+ return gr.update(), gr.update(), gr.update()
369
+
370
+
371
+ def refresh_checkpoints(dit_handler):
372
+ """Refresh available checkpoints"""
373
+ choices = dit_handler.get_available_checkpoints()
374
+ return gr.update(choices=choices)
375
+
376
+
377
+ def update_model_type_settings(config_path):
378
+ """Update UI settings based on model type (fallback when handler not initialized yet)
379
+
380
+ Note: This is used as a fallback when the user changes config_path dropdown
381
+ before initializing the model. The actual settings are determined by the
382
+ handler's is_turbo_model() method after initialization.
383
+ """
384
+ if config_path is None:
385
+ config_path = ""
386
+ config_path_lower = config_path.lower()
387
+
388
+ # Determine is_turbo based on config_path string
389
+ # This is a heuristic fallback - actual model type is determined after loading
390
+ if "turbo" in config_path_lower:
391
+ is_turbo = True
392
+ elif "base" in config_path_lower:
393
+ is_turbo = False
394
+ else:
395
+ # Default to turbo settings for unknown model types
396
+ is_turbo = True
397
+
398
+ return get_model_type_ui_settings(is_turbo)
399
+
400
+
401
+ def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
402
+ """Wrapper for service initialization, returns status, button state, accordion state, and model type settings"""
403
+ # Initialize DiT handler
404
+ status, enable = dit_handler.initialize_service(
405
+ checkpoint, config_path, device,
406
+ use_flash_attention=use_flash_attention, compile_model=False,
407
+ offload_to_cpu=offload_to_cpu, offload_dit_to_cpu=offload_dit_to_cpu
408
+ )
409
+
410
+ # Initialize LM handler if requested
411
+ if init_llm:
412
+ # Get checkpoint directory
413
+ current_file = os.path.abspath(__file__)
414
+ # This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
415
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
416
+ checkpoint_dir = os.path.join(project_root, "checkpoints")
417
+
418
+ lm_status, lm_success = llm_handler.initialize(
419
+ checkpoint_dir=checkpoint_dir,
420
+ lm_model_path=lm_model_path,
421
+ backend=backend,
422
+ device=device,
423
+ offload_to_cpu=offload_to_cpu,
424
+ dtype=dit_handler.dtype
425
+ )
426
+
427
+ if lm_success:
428
+ status += f"\n{lm_status}"
429
+ else:
430
+ status += f"\n{lm_status}"
431
+ # Don't fail the entire initialization if LM fails, but log it
432
+ # Keep enable as is (DiT initialization result) even if LM fails
433
+
434
+ # Check if model is initialized - if so, collapse the accordion
435
+ is_model_initialized = dit_handler.model is not None
436
+ accordion_state = gr.Accordion(open=not is_model_initialized)
437
+
438
+ # Get model type settings based on actual loaded model
439
+ is_turbo = dit_handler.is_turbo_model()
440
+ model_type_settings = get_model_type_ui_settings(is_turbo)
441
+
442
+ return (
443
+ status,
444
+ gr.update(interactive=enable),
445
+ accordion_state,
446
+ *model_type_settings
447
+ )
448
+
449
+
450
+ def get_model_type_ui_settings(is_turbo: bool):
451
+ """Get UI settings based on whether the model is turbo or base"""
452
+ if is_turbo:
453
+ # Turbo model: max 20 steps, default 8, show shift with default 3.0, only show text2music/repaint/cover
454
+ return (
455
+ gr.update(value=8, maximum=20, minimum=1), # inference_steps
456
+ gr.update(visible=False), # guidance_scale
457
+ gr.update(visible=False), # use_adg
458
+ gr.update(value=3.0, visible=True), # shift (show with default 3.0)
459
+ gr.update(visible=False), # cfg_interval_start
460
+ gr.update(visible=False), # cfg_interval_end
461
+ gr.update(choices=TASK_TYPES_TURBO), # task_type
462
+ )
463
+ else:
464
+ # Base model: max 200 steps, default 32, show CFG/ADG/shift, show all task types
465
+ return (
466
+ gr.update(value=32, maximum=200, minimum=1), # inference_steps
467
+ gr.update(visible=True), # guidance_scale
468
+ gr.update(visible=True), # use_adg
469
+ gr.update(value=3.0, visible=True), # shift (effective for base, default 3.0)
470
+ gr.update(visible=True), # cfg_interval_start
471
+ gr.update(visible=True), # cfg_interval_end
472
+ gr.update(choices=TASK_TYPES_BASE), # task_type
473
+ )
474
+
475
+
476
+ def update_negative_prompt_visibility(init_llm_checked):
477
+ """Update negative prompt visibility: show if Initialize 5Hz LM checkbox is checked"""
478
+ return gr.update(visible=init_llm_checked)
479
+
480
+
481
+ def update_audio_cover_strength_visibility(task_type_value, init_llm_checked):
482
+ """Update audio_cover_strength visibility and label"""
483
+ # Show if task is cover OR if LM is initialized
484
+ is_visible = (task_type_value == "cover") or init_llm_checked
485
+ # Change label based on context
486
+ if init_llm_checked and task_type_value != "cover":
487
+ label = "LM codes strength"
488
+ info = "Control how many denoising steps use LM-generated codes"
489
+ else:
490
+ label = "Audio Cover Strength"
491
+ info = "Control how many denoising steps use cover mode"
492
+
493
+ return gr.update(visible=is_visible, label=label, info=info)
494
+
495
+
496
+ def convert_src_audio_to_codes_wrapper(dit_handler, src_audio):
497
+ """Wrapper for converting src audio to codes"""
498
+ codes_string = dit_handler.convert_src_audio_to_codes(src_audio)
499
+ return codes_string
500
+
501
+
502
+ def update_instruction_ui(
503
+ dit_handler,
504
+ task_type_value: str,
505
+ track_name_value: Optional[str],
506
+ complete_track_classes_value: list,
507
+ audio_codes_content: str = "",
508
+ init_llm_checked: bool = False
509
+ ) -> tuple:
510
+ """Update instruction and UI visibility based on task type."""
511
+ instruction = dit_handler.generate_instruction(
512
+ task_type=task_type_value,
513
+ track_name=track_name_value,
514
+ complete_track_classes=complete_track_classes_value
515
+ )
516
+
517
+ # Show track_name for lego and extract
518
+ track_name_visible = task_type_value in ["lego", "extract"]
519
+ # Show complete_track_classes for complete
520
+ complete_visible = task_type_value == "complete"
521
+ # Show audio_cover_strength for cover OR when LM is initialized
522
+ audio_cover_strength_visible = (task_type_value == "cover") or init_llm_checked
523
+ # Determine label and info based on context
524
+ if init_llm_checked and task_type_value != "cover":
525
+ audio_cover_strength_label = "LM codes strength"
526
+ audio_cover_strength_info = "Control how many denoising steps use LM-generated codes"
527
+ else:
528
+ audio_cover_strength_label = "Audio Cover Strength"
529
+ audio_cover_strength_info = "Control how many denoising steps use cover mode"
530
+ # Show repainting controls for repaint and lego
531
+ repainting_visible = task_type_value in ["repaint", "lego"]
532
+ # Show text2music_audio_codes if task is text2music OR if it has content
533
+ # This allows it to stay visible even if user switches task type but has codes
534
+ has_audio_codes = audio_codes_content and str(audio_codes_content).strip()
535
+ text2music_audio_codes_visible = task_type_value == "text2music" or has_audio_codes
536
+
537
+ return (
538
+ instruction, # instruction_display_gen
539
+ gr.update(visible=track_name_visible), # track_name
540
+ gr.update(visible=complete_visible), # complete_track_classes
541
+ gr.update(visible=audio_cover_strength_visible, label=audio_cover_strength_label, info=audio_cover_strength_info), # audio_cover_strength
542
+ gr.update(visible=repainting_visible), # repainting_group
543
+ gr.update(visible=text2music_audio_codes_visible), # text2music_audio_codes_group
544
+ )
545
+
546
+
547
+ def transcribe_audio_codes(llm_handler, audio_code_string, constrained_decoding_debug):
548
+ """
549
+ Transcribe audio codes to metadata using LLM understanding.
550
+ If audio_code_string is empty, generate a sample example instead.
551
+
552
+ This is a Gradio wrapper around the understand_music API in acestep.inference.
553
+
554
+ Args:
555
+ llm_handler: LLM handler instance
556
+ audio_code_string: String containing audio codes (or empty for example generation)
557
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
558
+
559
+ Returns:
560
+ Tuple of (status_message, caption, lyrics, bpm, duration, keyscale, language, timesignature, is_format_caption)
561
+ """
562
+ # Call the inference API
563
+ result = understand_music(
564
+ llm_handler=llm_handler,
565
+ audio_codes=audio_code_string,
566
+ use_constrained_decoding=True,
567
+ constrained_decoding_debug=constrained_decoding_debug,
568
+ )
569
+
570
+ # Handle error case with localized message
571
+ if not result.success:
572
+ # Use localized error message for LLM not initialized
573
+ if result.error == "LLM not initialized":
574
+ return t("messages.lm_not_initialized"), "", "", None, None, "", "", "", False
575
+ return result.status_message, "", "", None, None, "", "", "", False
576
+
577
+ return (
578
+ result.status_message,
579
+ result.caption,
580
+ result.lyrics,
581
+ result.bpm,
582
+ result.duration,
583
+ result.keyscale,
584
+ result.language,
585
+ result.timesignature,
586
+ True # Set is_format_caption to True (from Transcribe/LM understanding)
587
+ )
588
+
589
+
590
+ def update_transcribe_button_text(audio_code_string):
591
+ """
592
+ Update the transcribe button text based on input content.
593
+ If empty: "Generate Example"
594
+ If has content: "Transcribe"
595
+ """
596
+ if not audio_code_string or not audio_code_string.strip():
597
+ return gr.update(value="Generate Example")
598
+ else:
599
+ return gr.update(value="Transcribe")
600
+
601
+
602
+ def reset_format_caption_flag():
603
+ """Reset is_format_caption to False when user manually edits caption/metadata"""
604
+ return False
605
+
606
+
607
+ def update_audio_uploads_accordion(reference_audio, src_audio):
608
+ """Update Audio Uploads accordion open state based on whether audio files are present"""
609
+ has_audio = (reference_audio is not None) or (src_audio is not None)
610
+ return gr.Accordion(open=has_audio)
611
+
612
+
613
+ def handle_instrumental_checkbox(instrumental_checked, current_lyrics):
614
+ """
615
+ Handle instrumental checkbox changes.
616
+ When checked: if no lyrics, fill with [Instrumental]
617
+ When unchecked: if lyrics is [Instrumental], clear it
618
+ """
619
+ if instrumental_checked:
620
+ # If checked and no lyrics, fill with [Instrumental]
621
+ if not current_lyrics or not current_lyrics.strip():
622
+ return "[Instrumental]"
623
+ else:
624
+ # Has lyrics, don't change
625
+ return current_lyrics
626
+ else:
627
+ # If unchecked and lyrics is exactly [Instrumental], clear it
628
+ if current_lyrics and current_lyrics.strip() == "[Instrumental]":
629
+ return ""
630
+ else:
631
+ # Has other lyrics, don't change
632
+ return current_lyrics
633
+
634
+
635
+ def handle_simple_instrumental_change(is_instrumental: bool):
636
+ """
637
+ Handle simple mode instrumental checkbox changes.
638
+ When checked: set vocal_language to "unknown" and disable editing.
639
+ When unchecked: enable vocal_language editing.
640
+
641
+ Args:
642
+ is_instrumental: Whether instrumental checkbox is checked
643
+
644
+ Returns:
645
+ gr.update for simple_vocal_language dropdown
646
+ """
647
+ if is_instrumental:
648
+ return gr.update(value="unknown", interactive=False)
649
+ else:
650
+ return gr.update(interactive=True)
651
+
652
+
653
+ def update_audio_components_visibility(batch_size):
654
+ """Show/hide individual audio components based on batch size (1-8)
655
+
656
+ Row 1: Components 1-4 (batch_size 1-4)
657
+ Row 2: Components 5-8 (batch_size 5-8)
658
+ """
659
+ # Clamp batch size to 1-8 range for UI
660
+ batch_size = min(max(int(batch_size), 1), 8)
661
+
662
+ # Row 1 columns (1-4)
663
+ updates_row1 = (
664
+ gr.update(visible=True), # audio_col_1: always visible
665
+ gr.update(visible=batch_size >= 2), # audio_col_2
666
+ gr.update(visible=batch_size >= 3), # audio_col_3
667
+ gr.update(visible=batch_size >= 4), # audio_col_4
668
+ )
669
+
670
+ # Row 2 container and columns (5-8)
671
+ show_row_5_8 = batch_size >= 5
672
+ updates_row2 = (
673
+ gr.update(visible=show_row_5_8), # audio_row_5_8 (container)
674
+ gr.update(visible=batch_size >= 5), # audio_col_5
675
+ gr.update(visible=batch_size >= 6), # audio_col_6
676
+ gr.update(visible=batch_size >= 7), # audio_col_7
677
+ gr.update(visible=batch_size >= 8), # audio_col_8
678
+ )
679
+
680
+ return updates_row1 + updates_row2
681
+
682
+
683
+ def handle_generation_mode_change(mode: str):
684
+ """
685
+ Handle generation mode change between Simple and Custom modes.
686
+
687
+ In Simple mode:
688
+ - Show simple mode group (query input, instrumental checkbox, create button)
689
+ - Collapse caption and lyrics accordions
690
+ - Hide optional parameters accordion
691
+ - Disable generate button until sample is created
692
+
693
+ In Custom mode:
694
+ - Hide simple mode group
695
+ - Expand caption and lyrics accordions
696
+ - Show optional parameters accordion
697
+ - Enable generate button
698
+
699
+ Args:
700
+ mode: "simple" or "custom"
701
+
702
+ Returns:
703
+ Tuple of updates for:
704
+ - simple_mode_group (visibility)
705
+ - caption_accordion (open state)
706
+ - lyrics_accordion (open state)
707
+ - generate_btn (interactive state)
708
+ - simple_sample_created (reset state)
709
+ - optional_params_accordion (visibility)
710
+ """
711
+ is_simple = mode == "simple"
712
+
713
+ return (
714
+ gr.update(visible=is_simple), # simple_mode_group
715
+ gr.Accordion(open=not is_simple), # caption_accordion - collapsed in simple, open in custom
716
+ gr.Accordion(open=not is_simple), # lyrics_accordion - collapsed in simple, open in custom
717
+ gr.update(interactive=not is_simple), # generate_btn - disabled in simple until sample created
718
+ False, # simple_sample_created - reset to False on mode change
719
+ gr.Accordion(open=not is_simple), # optional_params_accordion - hidden in simple mode
720
+ )
721
+
722
+
723
+ def handle_create_sample(
724
+ llm_handler,
725
+ query: str,
726
+ instrumental: bool,
727
+ vocal_language: str,
728
+ lm_temperature: float,
729
+ lm_top_k: int,
730
+ lm_top_p: float,
731
+ constrained_decoding_debug: bool = False,
732
+ ):
733
+ """
734
+ Handle the Create Sample button click in Simple mode.
735
+
736
+ Creates a sample from the user's query using the LLM, then populates
737
+ the caption, lyrics, and metadata fields.
738
+
739
+ Note: cfg_scale and negative_prompt are not supported in create_sample mode.
740
+
741
+ Args:
742
+ llm_handler: LLM handler instance
743
+ query: User's natural language music description
744
+ instrumental: Whether to generate instrumental music
745
+ vocal_language: Preferred vocal language for constrained decoding
746
+ lm_temperature: LLM temperature for generation
747
+ lm_top_k: LLM top-k sampling
748
+ lm_top_p: LLM top-p sampling
749
+ constrained_decoding_debug: Whether to enable debug logging
750
+
751
+ Returns:
752
+ Tuple of updates for:
753
+ - captions
754
+ - lyrics
755
+ - bpm
756
+ - audio_duration
757
+ - key_scale
758
+ - vocal_language
759
+ - time_signature
760
+ - instrumental_checkbox
761
+ - caption_accordion (open)
762
+ - lyrics_accordion (open)
763
+ - generate_btn (interactive)
764
+ - simple_sample_created (True)
765
+ - think_checkbox (True)
766
+ - is_format_caption_state (True)
767
+ - status_output
768
+ """
769
+ # Check if LLM is initialized
770
+ if not llm_handler.llm_initialized:
771
+ gr.Warning(t("messages.lm_not_initialized"))
772
+ return (
773
+ gr.update(), # captions - no change
774
+ gr.update(), # lyrics - no change
775
+ gr.update(), # bpm - no change
776
+ gr.update(), # audio_duration - no change
777
+ gr.update(), # key_scale - no change
778
+ gr.update(), # vocal_language - no change
779
+ gr.update(), # time_signature - no change
780
+ gr.update(), # instrumental_checkbox - no change
781
+ gr.update(), # caption_accordion - no change
782
+ gr.update(), # lyrics_accordion - no change
783
+ gr.update(interactive=False), # generate_btn - keep disabled
784
+ False, # simple_sample_created - still False
785
+ gr.update(), # think_checkbox - no change
786
+ gr.update(), # is_format_caption_state - no change
787
+ t("messages.lm_not_initialized"), # status_output
788
+ )
789
+
790
+ # Convert LM parameters
791
+ top_k_value = None if not lm_top_k or lm_top_k == 0 else int(lm_top_k)
792
+ top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p
793
+
794
+ # Call create_sample API
795
+ # Note: cfg_scale and negative_prompt are not supported in create_sample mode
796
+ result = create_sample(
797
+ llm_handler=llm_handler,
798
+ query=query,
799
+ instrumental=instrumental,
800
+ vocal_language=vocal_language,
801
+ temperature=lm_temperature,
802
+ top_k=top_k_value,
803
+ top_p=top_p_value,
804
+ use_constrained_decoding=True,
805
+ constrained_decoding_debug=constrained_decoding_debug,
806
+ )
807
+
808
+ # Handle error
809
+ if not result.success:
810
+ gr.Warning(result.status_message or t("messages.sample_creation_failed"))
811
+ return (
812
+ gr.update(), # captions - no change
813
+ gr.update(), # lyrics - no change
814
+ gr.update(), # bpm - no change
815
+ gr.update(), # audio_duration - no change
816
+ gr.update(), # key_scale - no change
817
+ gr.update(), # vocal_language - no change
818
+ gr.update(), # simple vocal_language - no change
819
+ gr.update(), # time_signature - no change
820
+ gr.update(), # instrumental_checkbox - no change
821
+ gr.update(), # caption_accordion - no change
822
+ gr.update(), # lyrics_accordion - no change
823
+ gr.update(interactive=False), # generate_btn - keep disabled
824
+ False, # simple_sample_created - still False
825
+ gr.update(), # think_checkbox - no change
826
+ gr.update(), # is_format_caption_state - no change
827
+ result.status_message or t("messages.sample_creation_failed"), # status_output
828
+ )
829
+
830
+ # Success - populate fields
831
+ gr.Info(t("messages.sample_created"))
832
+
833
+ return (
834
+ result.caption, # captions
835
+ result.lyrics, # lyrics
836
+ result.bpm, # bpm
837
+ result.duration if result.duration and result.duration > 0 else -1, # audio_duration
838
+ result.keyscale, # key_scale
839
+ result.language, # vocal_language
840
+ result.language, # simple vocal_language
841
+ result.timesignature, # time_signature
842
+ result.instrumental, # instrumental_checkbox
843
+ gr.Accordion(open=True), # caption_accordion - expand
844
+ gr.Accordion(open=True), # lyrics_accordion - expand
845
+ gr.update(interactive=True), # generate_btn - enable
846
+ True, # simple_sample_created - True
847
+ True, # think_checkbox - enable thinking
848
+ True, # is_format_caption_state - True (LM-generated)
849
+ result.status_message, # status_output
850
+ )
851
+
852
+
853
+ def handle_format_sample(
854
+ llm_handler,
855
+ caption: str,
856
+ lyrics: str,
857
+ bpm,
858
+ audio_duration,
859
+ key_scale: str,
860
+ time_signature: str,
861
+ lm_temperature: float,
862
+ lm_top_k: int,
863
+ lm_top_p: float,
864
+ constrained_decoding_debug: bool = False,
865
+ ):
866
+ """
867
+ Handle the Format button click to format caption and lyrics.
868
+
869
+ Takes user-provided caption and lyrics, and uses the LLM to generate
870
+ structured music metadata and an enhanced description.
871
+
872
+ Note: cfg_scale and negative_prompt are not supported in format mode.
873
+
874
+ Args:
875
+ llm_handler: LLM handler instance
876
+ caption: User's caption/description
877
+ lyrics: User's lyrics
878
+ bpm: User-provided BPM (optional, for constrained decoding)
879
+ audio_duration: User-provided duration (optional, for constrained decoding)
880
+ key_scale: User-provided key scale (optional, for constrained decoding)
881
+ time_signature: User-provided time signature (optional, for constrained decoding)
882
+ lm_temperature: LLM temperature for generation
883
+ lm_top_k: LLM top-k sampling
884
+ lm_top_p: LLM top-p sampling
885
+ constrained_decoding_debug: Whether to enable debug logging
886
+
887
+ Returns:
888
+ Tuple of updates for:
889
+ - captions
890
+ - lyrics
891
+ - bpm
892
+ - audio_duration
893
+ - key_scale
894
+ - vocal_language
895
+ - time_signature
896
+ - is_format_caption_state
897
+ - status_output
898
+ """
899
+ # Check if LLM is initialized
900
+ if not llm_handler.llm_initialized:
901
+ gr.Warning(t("messages.lm_not_initialized"))
902
+ return (
903
+ gr.update(), # captions - no change
904
+ gr.update(), # lyrics - no change
905
+ gr.update(), # bpm - no change
906
+ gr.update(), # audio_duration - no change
907
+ gr.update(), # key_scale - no change
908
+ gr.update(), # vocal_language - no change
909
+ gr.update(), # time_signature - no change
910
+ gr.update(), # is_format_caption_state - no change
911
+ t("messages.lm_not_initialized"), # status_output
912
+ )
913
+
914
+ # Build user_metadata from provided values for constrained decoding
915
+ user_metadata = {}
916
+ if bpm is not None and bpm > 0:
917
+ user_metadata['bpm'] = int(bpm)
918
+ if audio_duration is not None and float(audio_duration) > 0:
919
+ user_metadata['duration'] = int(audio_duration)
920
+ if key_scale and key_scale.strip():
921
+ user_metadata['keyscale'] = key_scale.strip()
922
+ if time_signature and time_signature.strip():
923
+ user_metadata['timesignature'] = time_signature.strip()
924
+
925
+ # Only pass user_metadata if we have at least one field
926
+ user_metadata_to_pass = user_metadata if user_metadata else None
927
+
928
+ # Convert LM parameters
929
+ top_k_value = None if not lm_top_k or lm_top_k == 0 else int(lm_top_k)
930
+ top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p
931
+
932
+ # Call format_sample API
933
+ result = format_sample(
934
+ llm_handler=llm_handler,
935
+ caption=caption,
936
+ lyrics=lyrics,
937
+ user_metadata=user_metadata_to_pass,
938
+ temperature=lm_temperature,
939
+ top_k=top_k_value,
940
+ top_p=top_p_value,
941
+ use_constrained_decoding=True,
942
+ constrained_decoding_debug=constrained_decoding_debug,
943
+ )
944
+
945
+ # Handle error
946
+ if not result.success:
947
+ gr.Warning(result.status_message or t("messages.format_failed"))
948
+ return (
949
+ gr.update(), # captions - no change
950
+ gr.update(), # lyrics - no change
951
+ gr.update(), # bpm - no change
952
+ gr.update(), # audio_duration - no change
953
+ gr.update(), # key_scale - no change
954
+ gr.update(), # vocal_language - no change
955
+ gr.update(), # time_signature - no change
956
+ gr.update(), # is_format_caption_state - no change
957
+ result.status_message or t("messages.format_failed"), # status_output
958
+ )
959
+
960
+ # Success - populate fields
961
+ gr.Info(t("messages.format_success"))
962
+
963
+ return (
964
+ result.caption, # captions
965
+ result.lyrics, # lyrics
966
+ result.bpm, # bpm
967
+ result.duration if result.duration and result.duration > 0 else -1, # audio_duration
968
+ result.keyscale, # key_scale
969
+ result.language, # vocal_language
970
+ result.timesignature, # time_signature
971
+ True, # is_format_caption_state - True (LM-formatted)
972
+ result.status_message, # status_output
973
+ )
974
+
code/acestep/gradio_ui/events/results_handlers.py ADDED
The diff for this file is too large to render. See raw diff
 
code/acestep/gradio_ui/events/training_handlers.py ADDED
@@ -0,0 +1,644 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Event Handlers for Training Tab
3
+
4
+ Contains all event handler functions for the dataset builder and training UI.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ from typing import Any, Dict, List, Tuple, Optional
10
+ from loguru import logger
11
+ import gradio as gr
12
+
13
+ from acestep.training.dataset_builder import DatasetBuilder, AudioSample
14
+
15
+
16
+ def create_dataset_builder() -> DatasetBuilder:
17
+ """Create a new DatasetBuilder instance."""
18
+ return DatasetBuilder()
19
+
20
+
21
+ def scan_directory(
22
+ audio_dir: str,
23
+ dataset_name: str,
24
+ custom_tag: str,
25
+ tag_position: str,
26
+ all_instrumental: bool,
27
+ builder_state: Optional[DatasetBuilder],
28
+ ) -> Tuple[Any, str, Any, DatasetBuilder]:
29
+ """Scan a directory for audio files.
30
+
31
+ Returns:
32
+ Tuple of (table_data, status, slider_update, builder_state)
33
+ """
34
+ if not audio_dir or not audio_dir.strip():
35
+ return [], "❌ Please enter a directory path", gr.Slider(maximum=0, value=0), builder_state
36
+
37
+ # Create or use existing builder
38
+ builder = builder_state if builder_state else DatasetBuilder()
39
+
40
+ # Set metadata before scanning
41
+ builder.metadata.name = dataset_name
42
+ builder.metadata.custom_tag = custom_tag
43
+ builder.metadata.tag_position = tag_position
44
+ builder.metadata.all_instrumental = all_instrumental
45
+
46
+ # Scan directory
47
+ samples, status = builder.scan_directory(audio_dir.strip())
48
+
49
+ if not samples:
50
+ return [], status, gr.Slider(maximum=0, value=0), builder
51
+
52
+ # Set instrumental and tag for all samples
53
+ builder.set_all_instrumental(all_instrumental)
54
+ if custom_tag:
55
+ builder.set_custom_tag(custom_tag, tag_position)
56
+
57
+ # Get table data
58
+ table_data = builder.get_samples_dataframe_data()
59
+
60
+ # Calculate slider max and return as Slider update
61
+ slider_max = max(0, len(samples) - 1)
62
+
63
+ return table_data, status, gr.Slider(maximum=slider_max, value=0), builder
64
+
65
+
66
+ def auto_label_all(
67
+ dit_handler,
68
+ llm_handler,
69
+ builder_state: Optional[DatasetBuilder],
70
+ skip_metas: bool = False,
71
+ progress=None,
72
+ ) -> Tuple[List[List[Any]], str, DatasetBuilder]:
73
+ """Auto-label all samples in the dataset.
74
+
75
+ Args:
76
+ dit_handler: DiT handler for audio processing
77
+ llm_handler: LLM handler for caption generation
78
+ builder_state: Dataset builder state
79
+ skip_metas: If True, skip LLM labeling. BPM/Key/TimeSig = N/A, Language = unknown for instrumental
80
+ progress: Progress callback
81
+
82
+ Returns:
83
+ Tuple of (table_data, status, builder_state)
84
+ """
85
+ if builder_state is None:
86
+ return [], "❌ Please scan a directory first", builder_state
87
+
88
+ if not builder_state.samples:
89
+ return [], "❌ No samples to label. Please scan a directory first.", builder_state
90
+
91
+ # If skip_metas is True, just set default values without LLM
92
+ if skip_metas:
93
+ for sample in builder_state.samples:
94
+ sample.bpm = None # Will display as N/A
95
+ sample.keyscale = "N/A"
96
+ sample.timesignature = "N/A"
97
+ # For instrumental, language should be "unknown"
98
+ if sample.is_instrumental:
99
+ sample.language = "unknown"
100
+ else:
101
+ sample.language = "unknown"
102
+ # Use custom tag as caption if set, otherwise use filename
103
+ if builder_state.metadata.custom_tag:
104
+ sample.caption = builder_state.metadata.custom_tag
105
+ else:
106
+ sample.caption = sample.filename
107
+
108
+ table_data = builder_state.get_samples_dataframe_data()
109
+ return table_data, f"✅ Skipped AI labeling. {len(builder_state.samples)} samples set with default values.", builder_state
110
+
111
+ # Check if handlers are initialized
112
+ if dit_handler is None or dit_handler.model is None:
113
+ return builder_state.get_samples_dataframe_data(), "❌ Model not initialized. Please initialize the service first.", builder_state
114
+
115
+ if llm_handler is None or not llm_handler.llm_initialized:
116
+ return builder_state.get_samples_dataframe_data(), "❌ LLM not initialized. Please initialize the service with LLM enabled.", builder_state
117
+
118
+ def progress_callback(msg):
119
+ if progress:
120
+ try:
121
+ progress(msg)
122
+ except:
123
+ pass
124
+
125
+ # Label all samples
126
+ samples, status = builder_state.label_all_samples(
127
+ dit_handler=dit_handler,
128
+ llm_handler=llm_handler,
129
+ progress_callback=progress_callback,
130
+ )
131
+
132
+ # Get updated table data
133
+ table_data = builder_state.get_samples_dataframe_data()
134
+
135
+ return table_data, status, builder_state
136
+
137
+
138
+ def get_sample_preview(
139
+ sample_idx: int,
140
+ builder_state: Optional[DatasetBuilder],
141
+ ) -> Tuple[str, str, str, str, Optional[int], str, str, float, str, bool]:
142
+ """Get preview data for a specific sample.
143
+
144
+ Returns:
145
+ Tuple of (audio_path, filename, caption, lyrics, bpm, keyscale, timesig, duration, language, instrumental)
146
+ """
147
+ if builder_state is None or not builder_state.samples:
148
+ return None, "", "", "", None, "", "", 0.0, "instrumental", True
149
+
150
+ idx = int(sample_idx)
151
+ if idx < 0 or idx >= len(builder_state.samples):
152
+ return None, "", "", "", None, "", "", 0.0, "instrumental", True
153
+
154
+ sample = builder_state.samples[idx]
155
+
156
+ return (
157
+ sample.audio_path,
158
+ sample.filename,
159
+ sample.caption,
160
+ sample.lyrics,
161
+ sample.bpm,
162
+ sample.keyscale,
163
+ sample.timesignature,
164
+ sample.duration,
165
+ sample.language,
166
+ sample.is_instrumental,
167
+ )
168
+
169
+
170
+ def save_sample_edit(
171
+ sample_idx: int,
172
+ caption: str,
173
+ lyrics: str,
174
+ bpm: Optional[int],
175
+ keyscale: str,
176
+ timesig: str,
177
+ language: str,
178
+ is_instrumental: bool,
179
+ builder_state: Optional[DatasetBuilder],
180
+ ) -> Tuple[List[List[Any]], str, DatasetBuilder]:
181
+ """Save edits to a sample.
182
+
183
+ Returns:
184
+ Tuple of (table_data, status, builder_state)
185
+ """
186
+ if builder_state is None:
187
+ return [], "❌ No dataset loaded", builder_state
188
+
189
+ idx = int(sample_idx)
190
+
191
+ # Update sample
192
+ sample, status = builder_state.update_sample(
193
+ idx,
194
+ caption=caption,
195
+ lyrics=lyrics if not is_instrumental else "[Instrumental]",
196
+ bpm=int(bpm) if bpm else None,
197
+ keyscale=keyscale,
198
+ timesignature=timesig,
199
+ language="instrumental" if is_instrumental else language,
200
+ is_instrumental=is_instrumental,
201
+ labeled=True,
202
+ )
203
+
204
+ # Get updated table data
205
+ table_data = builder_state.get_samples_dataframe_data()
206
+
207
+ return table_data, status, builder_state
208
+
209
+
210
+ def update_settings(
211
+ custom_tag: str,
212
+ tag_position: str,
213
+ all_instrumental: bool,
214
+ builder_state: Optional[DatasetBuilder],
215
+ ) -> DatasetBuilder:
216
+ """Update dataset settings.
217
+
218
+ Returns:
219
+ Updated builder_state
220
+ """
221
+ if builder_state is None:
222
+ return builder_state
223
+
224
+ if custom_tag:
225
+ builder_state.set_custom_tag(custom_tag, tag_position)
226
+
227
+ builder_state.set_all_instrumental(all_instrumental)
228
+
229
+ return builder_state
230
+
231
+
232
+ def save_dataset(
233
+ save_path: str,
234
+ dataset_name: str,
235
+ builder_state: Optional[DatasetBuilder],
236
+ ) -> str:
237
+ """Save the dataset to a JSON file.
238
+
239
+ Returns:
240
+ Status message
241
+ """
242
+ if builder_state is None:
243
+ return "❌ No dataset to save. Please scan a directory first."
244
+
245
+ if not builder_state.samples:
246
+ return "❌ No samples in dataset."
247
+
248
+ if not save_path or not save_path.strip():
249
+ return "❌ Please enter a save path."
250
+
251
+ # Check if any samples are labeled
252
+ labeled_count = builder_state.get_labeled_count()
253
+ if labeled_count == 0:
254
+ return "⚠️ Warning: No samples have been labeled. Consider auto-labeling first.\nSaving anyway..."
255
+
256
+ return builder_state.save_dataset(save_path.strip(), dataset_name)
257
+
258
+
259
+ def load_existing_dataset_for_preprocess(
260
+ dataset_path: str,
261
+ builder_state: Optional[DatasetBuilder],
262
+ ) -> Tuple[str, Any, Any, DatasetBuilder, str, str, str, str, Optional[int], str, str, float, str, bool]:
263
+ """Load an existing dataset JSON file for preprocessing.
264
+
265
+ This allows users to load a previously saved dataset and proceed to preprocessing
266
+ without having to re-scan and re-label.
267
+
268
+ Returns:
269
+ Tuple of (status, table_data, slider_update, builder_state,
270
+ audio_path, filename, caption, lyrics, bpm, keyscale, timesig, duration, language, instrumental)
271
+ """
272
+ empty_preview = (None, "", "", "", None, "", "", 0.0, "instrumental", True)
273
+
274
+ if not dataset_path or not dataset_path.strip():
275
+ return ("❌ Please enter a dataset path", [], gr.Slider(maximum=0, value=0), builder_state) + empty_preview
276
+
277
+ dataset_path = dataset_path.strip()
278
+
279
+ if not os.path.exists(dataset_path):
280
+ return (f"❌ Dataset not found: {dataset_path}", [], gr.Slider(maximum=0, value=0), builder_state) + empty_preview
281
+
282
+ # Create new builder (don't reuse old state when loading a file)
283
+ builder = DatasetBuilder()
284
+
285
+ # Load the dataset
286
+ samples, status = builder.load_dataset(dataset_path)
287
+
288
+ if not samples:
289
+ return (status, [], gr.Slider(maximum=0, value=0), builder) + empty_preview
290
+
291
+ # Get table data
292
+ table_data = builder.get_samples_dataframe_data()
293
+
294
+ # Calculate slider max
295
+ slider_max = max(0, len(samples) - 1)
296
+
297
+ # Create info text
298
+ labeled_count = builder.get_labeled_count()
299
+ info = f"✅ Loaded dataset: {builder.metadata.name}\n"
300
+ info += f"📊 Samples: {len(samples)} ({labeled_count} labeled)\n"
301
+ info += f"🏷️ Custom Tag: {builder.metadata.custom_tag or '(none)'}\n"
302
+ info += "📝 Ready for preprocessing! You can also edit samples below."
303
+
304
+ # Get first sample preview
305
+ first_sample = builder.samples[0]
306
+ preview = (
307
+ first_sample.audio_path,
308
+ first_sample.filename,
309
+ first_sample.caption,
310
+ first_sample.lyrics,
311
+ first_sample.bpm,
312
+ first_sample.keyscale,
313
+ first_sample.timesignature,
314
+ first_sample.duration,
315
+ first_sample.language,
316
+ first_sample.is_instrumental,
317
+ )
318
+
319
+ return (info, table_data, gr.Slider(maximum=slider_max, value=0), builder) + preview
320
+
321
+
322
+ def preprocess_dataset(
323
+ output_dir: str,
324
+ dit_handler,
325
+ builder_state: Optional[DatasetBuilder],
326
+ progress=None,
327
+ ) -> str:
328
+ """Preprocess dataset to tensor files for fast training.
329
+
330
+ This converts audio files to VAE latents and text to embeddings.
331
+
332
+ Returns:
333
+ Status message
334
+ """
335
+ if builder_state is None:
336
+ return "❌ No dataset loaded. Please scan a directory first."
337
+
338
+ if not builder_state.samples:
339
+ return "❌ No samples in dataset."
340
+
341
+ labeled_count = builder_state.get_labeled_count()
342
+ if labeled_count == 0:
343
+ return "❌ No labeled samples. Please auto-label or manually label samples first."
344
+
345
+ if not output_dir or not output_dir.strip():
346
+ return "❌ Please enter an output directory."
347
+
348
+ if dit_handler is None or dit_handler.model is None:
349
+ return "❌ Model not initialized. Please initialize the service first."
350
+
351
+ def progress_callback(msg):
352
+ if progress:
353
+ try:
354
+ progress(msg)
355
+ except:
356
+ pass
357
+
358
+ # Run preprocessing
359
+ output_paths, status = builder_state.preprocess_to_tensors(
360
+ dit_handler=dit_handler,
361
+ output_dir=output_dir.strip(),
362
+ progress_callback=progress_callback,
363
+ )
364
+
365
+ return status
366
+
367
+
368
+ def load_training_dataset(
369
+ tensor_dir: str,
370
+ ) -> str:
371
+ """Load a preprocessed tensor dataset for training.
372
+
373
+ Returns:
374
+ Info text about the dataset
375
+ """
376
+ if not tensor_dir or not tensor_dir.strip():
377
+ return "❌ Please enter a tensor directory path"
378
+
379
+ tensor_dir = tensor_dir.strip()
380
+
381
+ if not os.path.exists(tensor_dir):
382
+ return f"❌ Directory not found: {tensor_dir}"
383
+
384
+ if not os.path.isdir(tensor_dir):
385
+ return f"❌ Not a directory: {tensor_dir}"
386
+
387
+ # Check for manifest
388
+ manifest_path = os.path.join(tensor_dir, "manifest.json")
389
+ if os.path.exists(manifest_path):
390
+ try:
391
+ with open(manifest_path, 'r') as f:
392
+ manifest = json.load(f)
393
+
394
+ num_samples = manifest.get("num_samples", 0)
395
+ metadata = manifest.get("metadata", {})
396
+ name = metadata.get("name", "Unknown")
397
+ custom_tag = metadata.get("custom_tag", "")
398
+
399
+ info = f"✅ Loaded preprocessed dataset: {name}\n"
400
+ info += f"📊 Samples: {num_samples} preprocessed tensors\n"
401
+ info += f"🏷️ Custom Tag: {custom_tag or '(none)'}"
402
+
403
+ return info
404
+ except Exception as e:
405
+ logger.warning(f"Failed to read manifest: {e}")
406
+
407
+ # Fallback: count .pt files
408
+ pt_files = [f for f in os.listdir(tensor_dir) if f.endswith('.pt')]
409
+
410
+ if not pt_files:
411
+ return f"❌ No .pt tensor files found in {tensor_dir}"
412
+
413
+ info = f"✅ Found {len(pt_files)} tensor files in {tensor_dir}\n"
414
+ info += "⚠️ No manifest.json found - using all .pt files"
415
+
416
+ return info
417
+
418
+
419
+ # Training handlers
420
+
421
+ import time
422
+ import re
423
+
424
+
425
+ def _format_duration(seconds):
426
+ """Format seconds to human readable string."""
427
+ seconds = int(seconds)
428
+ if seconds < 60:
429
+ return f"{seconds}s"
430
+ elif seconds < 3600:
431
+ return f"{seconds // 60}m {seconds % 60}s"
432
+ else:
433
+ return f"{seconds // 3600}h {(seconds % 3600) // 60}m"
434
+
435
+
436
+ def start_training(
437
+ tensor_dir: str,
438
+ dit_handler,
439
+ lora_rank: int,
440
+ lora_alpha: int,
441
+ lora_dropout: float,
442
+ learning_rate: float,
443
+ train_epochs: int,
444
+ train_batch_size: int,
445
+ gradient_accumulation: int,
446
+ save_every_n_epochs: int,
447
+ training_shift: float,
448
+ training_seed: int,
449
+ lora_output_dir: str,
450
+ training_state: Dict,
451
+ progress=None,
452
+ ):
453
+ """Start LoRA training from preprocessed tensors.
454
+
455
+ This is a generator function that yields progress updates.
456
+ """
457
+ if not tensor_dir or not tensor_dir.strip():
458
+ yield "❌ Please enter a tensor directory path", "", None, training_state
459
+ return
460
+
461
+ tensor_dir = tensor_dir.strip()
462
+
463
+ if not os.path.exists(tensor_dir):
464
+ yield f"❌ Tensor directory not found: {tensor_dir}", "", None, training_state
465
+ return
466
+
467
+ if dit_handler is None or dit_handler.model is None:
468
+ yield "❌ Model not initialized. Please initialize the service first.", "", None, training_state
469
+ return
470
+
471
+ # Check for required training dependencies
472
+ try:
473
+ from lightning.fabric import Fabric
474
+ from peft import get_peft_model, LoraConfig
475
+ except ImportError as e:
476
+ yield f"❌ Missing required packages: {e}\nPlease install: pip install peft lightning", "", None, training_state
477
+ return
478
+
479
+ training_state["is_training"] = True
480
+ training_state["should_stop"] = False
481
+
482
+ try:
483
+ from acestep.training.trainer import LoRATrainer
484
+ from acestep.training.configs import LoRAConfig as LoRAConfigClass, TrainingConfig
485
+
486
+ # Create configs
487
+ lora_config = LoRAConfigClass(
488
+ r=lora_rank,
489
+ alpha=lora_alpha,
490
+ dropout=lora_dropout,
491
+ )
492
+
493
+ training_config = TrainingConfig(
494
+ shift=training_shift,
495
+ learning_rate=learning_rate,
496
+ batch_size=train_batch_size,
497
+ gradient_accumulation_steps=gradient_accumulation,
498
+ max_epochs=train_epochs,
499
+ save_every_n_epochs=save_every_n_epochs,
500
+ seed=training_seed,
501
+ output_dir=lora_output_dir,
502
+ )
503
+
504
+ import pandas as pd
505
+
506
+ # Initialize training log and loss history
507
+ log_lines = []
508
+ loss_data = pd.DataFrame({"step": [0], "loss": [0.0]})
509
+
510
+ # Start timer
511
+ start_time = time.time()
512
+
513
+ yield f"🚀 Starting training from {tensor_dir}...", "", loss_data, training_state
514
+
515
+ # Create trainer
516
+ trainer = LoRATrainer(
517
+ dit_handler=dit_handler,
518
+ lora_config=lora_config,
519
+ training_config=training_config,
520
+ )
521
+
522
+ # Collect loss history
523
+ step_list = []
524
+ loss_list = []
525
+
526
+ # Train with progress updates using preprocessed tensors
527
+ for step, loss, status in trainer.train_from_preprocessed(tensor_dir, training_state):
528
+ # Calculate elapsed time and ETA
529
+ elapsed_seconds = time.time() - start_time
530
+ time_info = f"⏱️ Elapsed: {_format_duration(elapsed_seconds)}"
531
+
532
+ # Parse "Epoch x/y" from status to calculate ETA
533
+ match = re.search(r"Epoch\s+(\d+)/(\d+)", str(status))
534
+ if match:
535
+ current_ep = int(match.group(1))
536
+ total_ep = int(match.group(2))
537
+ if current_ep > 0:
538
+ eta_seconds = (elapsed_seconds / current_ep) * (total_ep - current_ep)
539
+ time_info += f" | ETA: ~{_format_duration(eta_seconds)}"
540
+
541
+ # Display status with time info
542
+ display_status = f"{status}\n{time_info}"
543
+
544
+ # Terminal log
545
+ log_msg = f"[{_format_duration(elapsed_seconds)}] Step {step}: {status}"
546
+ logger.info(log_msg)
547
+
548
+ # Add to UI log
549
+ log_lines.append(status)
550
+ if len(log_lines) > 15:
551
+ log_lines = log_lines[-15:]
552
+ log_text = "\n".join(log_lines)
553
+
554
+ # Track loss for plot (only valid values)
555
+ if step > 0 and loss is not None and loss == loss: # Check for NaN
556
+ step_list.append(step)
557
+ loss_list.append(float(loss))
558
+ loss_data = pd.DataFrame({"step": step_list, "loss": loss_list})
559
+
560
+ yield display_status, log_text, loss_data, training_state
561
+
562
+ if training_state.get("should_stop", False):
563
+ logger.info("⏹️ Training stopped by user")
564
+ log_lines.append("⏹️ Training stopped by user")
565
+ yield f"⏹️ Stopped ({time_info})", "\n".join(log_lines[-15:]), loss_data, training_state
566
+ break
567
+
568
+ total_time = time.time() - start_time
569
+ training_state["is_training"] = False
570
+ completion_msg = f"✅ Training completed! Total time: {_format_duration(total_time)}"
571
+
572
+ logger.info(completion_msg)
573
+ log_lines.append(completion_msg)
574
+
575
+ yield completion_msg, "\n".join(log_lines[-15:]), loss_data, training_state
576
+
577
+ except Exception as e:
578
+ logger.exception("Training error")
579
+ training_state["is_training"] = False
580
+ import pandas as pd
581
+ empty_df = pd.DataFrame({"step": [], "loss": []})
582
+ yield f"❌ Error: {str(e)}", str(e), empty_df, training_state
583
+
584
+
585
+ def stop_training(training_state: Dict) -> Tuple[str, Dict]:
586
+ """Stop the current training process.
587
+
588
+ Returns:
589
+ Tuple of (status, training_state)
590
+ """
591
+ if not training_state.get("is_training", False):
592
+ return "⚠️ No training in progress", training_state
593
+
594
+ training_state["should_stop"] = True
595
+ return "⏹️ Stopping training...", training_state
596
+
597
+
598
+ def export_lora(
599
+ export_path: str,
600
+ lora_output_dir: str,
601
+ ) -> str:
602
+ """Export the trained LoRA weights.
603
+
604
+ Returns:
605
+ Status message
606
+ """
607
+ if not export_path or not export_path.strip():
608
+ return "❌ Please enter an export path"
609
+
610
+ # Check if there's a trained model to export
611
+ final_dir = os.path.join(lora_output_dir, "final")
612
+ checkpoint_dir = os.path.join(lora_output_dir, "checkpoints")
613
+
614
+ # Prefer final, fallback to checkpoints
615
+ if os.path.exists(final_dir):
616
+ source_path = final_dir
617
+ elif os.path.exists(checkpoint_dir):
618
+ # Find the latest checkpoint
619
+ checkpoints = [d for d in os.listdir(checkpoint_dir) if d.startswith("epoch_")]
620
+ if not checkpoints:
621
+ return "❌ No checkpoints found"
622
+
623
+ checkpoints.sort(key=lambda x: int(x.split("_")[1]))
624
+ latest = checkpoints[-1]
625
+ source_path = os.path.join(checkpoint_dir, latest)
626
+ else:
627
+ return f"❌ No trained model found in {lora_output_dir}"
628
+
629
+ try:
630
+ import shutil
631
+
632
+ export_path = export_path.strip()
633
+ os.makedirs(os.path.dirname(export_path) if os.path.dirname(export_path) else ".", exist_ok=True)
634
+
635
+ if os.path.exists(export_path):
636
+ shutil.rmtree(export_path)
637
+
638
+ shutil.copytree(source_path, export_path)
639
+
640
+ return f"✅ LoRA exported to {export_path}"
641
+
642
+ except Exception as e:
643
+ logger.exception("Export error")
644
+ return f"❌ Export failed: {str(e)}"
code/acestep/gradio_ui/i18n.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Internationalization (i18n) module for Gradio UI
3
+ Supports multiple languages with easy translation management
4
+ """
5
+ import os
6
+ import json
7
+ from typing import Dict, Optional
8
+
9
+
10
+ class I18n:
11
+ """Internationalization handler"""
12
+
13
+ def __init__(self, default_language: str = "en"):
14
+ """
15
+ Initialize i18n handler
16
+
17
+ Args:
18
+ default_language: Default language code (en, zh, ja, etc.)
19
+ """
20
+ self.current_language = default_language
21
+ self.translations: Dict[str, Dict[str, str]] = {}
22
+ self._load_all_translations()
23
+
24
+ def _load_all_translations(self):
25
+ """Load all translation files from i18n directory"""
26
+ current_file = os.path.abspath(__file__)
27
+ module_dir = os.path.dirname(current_file)
28
+ i18n_dir = os.path.join(module_dir, "i18n")
29
+
30
+ if not os.path.exists(i18n_dir):
31
+ # Create i18n directory if it doesn't exist
32
+ os.makedirs(i18n_dir)
33
+ return
34
+
35
+ # Load all JSON files in i18n directory
36
+ for filename in os.listdir(i18n_dir):
37
+ if filename.endswith(".json"):
38
+ lang_code = filename[:-5] # Remove .json extension
39
+ filepath = os.path.join(i18n_dir, filename)
40
+ try:
41
+ with open(filepath, 'r', encoding='utf-8') as f:
42
+ self.translations[lang_code] = json.load(f)
43
+ except Exception as e:
44
+ print(f"Error loading translation file {filename}: {e}")
45
+
46
+ def set_language(self, language: str):
47
+ """Set current language"""
48
+ if language in self.translations:
49
+ self.current_language = language
50
+ else:
51
+ print(f"Warning: Language '{language}' not found, using default")
52
+
53
+ def t(self, key: str, **kwargs) -> str:
54
+ """
55
+ Translate a key to current language
56
+
57
+ Args:
58
+ key: Translation key (dot-separated for nested keys)
59
+ **kwargs: Optional format parameters
60
+
61
+ Returns:
62
+ Translated string
63
+ """
64
+ # Get translation from current language
65
+ translation = self._get_nested_value(
66
+ self.translations.get(self.current_language, {}),
67
+ key
68
+ )
69
+
70
+ # Fallback to English if not found
71
+ if translation is None:
72
+ translation = self._get_nested_value(
73
+ self.translations.get('en', {}),
74
+ key
75
+ )
76
+
77
+ # Final fallback to key itself
78
+ if translation is None:
79
+ translation = key
80
+
81
+ # Apply formatting if kwargs provided
82
+ if kwargs:
83
+ try:
84
+ translation = translation.format(**kwargs)
85
+ except KeyError:
86
+ pass
87
+
88
+ return translation
89
+
90
+ def _get_nested_value(self, data: dict, key: str) -> Optional[str]:
91
+ """
92
+ Get nested dictionary value using dot notation
93
+
94
+ Args:
95
+ data: Dictionary to search
96
+ key: Dot-separated key (e.g., "section.subsection.key")
97
+
98
+ Returns:
99
+ Value if found, None otherwise
100
+ """
101
+ keys = key.split('.')
102
+ current = data
103
+
104
+ for k in keys:
105
+ if isinstance(current, dict) and k in current:
106
+ current = current[k]
107
+ else:
108
+ return None
109
+
110
+ return current if isinstance(current, str) else None
111
+
112
+ def get_available_languages(self) -> list:
113
+ """Get list of available language codes"""
114
+ return list(self.translations.keys())
115
+
116
+
117
+ # Global i18n instance
118
+ _i18n_instance: Optional[I18n] = None
119
+
120
+
121
+ def get_i18n(language: Optional[str] = None) -> I18n:
122
+ """
123
+ Get global i18n instance
124
+
125
+ Args:
126
+ language: Optional language to set
127
+
128
+ Returns:
129
+ I18n instance
130
+ """
131
+ global _i18n_instance
132
+
133
+ if _i18n_instance is None:
134
+ _i18n_instance = I18n(default_language=language or "en")
135
+ elif language is not None:
136
+ _i18n_instance.set_language(language)
137
+
138
+ return _i18n_instance
139
+
140
+
141
+ def t(key: str, **kwargs) -> str:
142
+ """
143
+ Convenience function for translation
144
+
145
+ Args:
146
+ key: Translation key
147
+ **kwargs: Optional format parameters
148
+
149
+ Returns:
150
+ Translated string
151
+ """
152
+ return get_i18n().t(key, **kwargs)
code/acestep/gradio_ui/i18n/en.json ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "app": {
3
+ "title": "🎛️ ACE-Step V1.5 Playground💡",
4
+ "subtitle": "Pushing the Boundaries of Open-Source Music Generation"
5
+ },
6
+ "dataset": {
7
+ "title": "📊 Dataset Explorer",
8
+ "dataset_label": "Dataset",
9
+ "dataset_info": "Choose dataset to explore",
10
+ "import_btn": "📥 Import Dataset",
11
+ "search_type_label": "Search Type",
12
+ "search_type_info": "How to find items",
13
+ "search_value_label": "Search Value",
14
+ "search_value_placeholder": "Enter keys or index (leave empty for random)",
15
+ "search_value_info": "Keys: exact match, Index: 0 to dataset size-1",
16
+ "instruction_label": "📝 Instruction",
17
+ "instruction_placeholder": "No instruction available",
18
+ "metadata_title": "📋 Item Metadata (JSON)",
19
+ "metadata_label": "Complete Item Information",
20
+ "source_audio": "Source Audio",
21
+ "target_audio": "Target Audio",
22
+ "reference_audio": "Reference Audio",
23
+ "get_item_btn": "🔍 Get Item",
24
+ "use_src_checkbox": "Use Source Audio from Dataset",
25
+ "use_src_info": "Check to use the source audio from dataset",
26
+ "data_status_label": "📊 Data Status",
27
+ "data_status_default": "❌ No dataset imported",
28
+ "autofill_btn": "📋 Auto-fill Generation Form"
29
+ },
30
+ "service": {
31
+ "title": "🔧 Service Configuration",
32
+ "checkpoint_label": "Checkpoint File",
33
+ "checkpoint_info": "Select a trained model checkpoint file (full path or filename)",
34
+ "refresh_btn": "🔄 Refresh",
35
+ "model_path_label": "Main Model Path",
36
+ "model_path_info": "Select the model configuration directory (auto-scanned from checkpoints)",
37
+ "device_label": "Device",
38
+ "device_info": "Processing device (auto-detect recommended)",
39
+ "lm_model_path_label": "5Hz LM Model Path",
40
+ "lm_model_path_info": "Select the 5Hz LM model checkpoint (auto-scanned from checkpoints)",
41
+ "backend_label": "5Hz LM Backend",
42
+ "backend_info": "Select backend for 5Hz LM: vllm (faster) or pt (PyTorch, more compatible)",
43
+ "init_llm_label": "Initialize 5Hz LM",
44
+ "init_llm_info": "Check to initialize 5Hz LM during service initialization",
45
+ "flash_attention_label": "Use Flash Attention",
46
+ "flash_attention_info_enabled": "Enable flash attention for faster inference (requires flash_attn package)",
47
+ "flash_attention_info_disabled": "Flash attention not available (flash_attn package not installed)",
48
+ "offload_cpu_label": "Offload to CPU",
49
+ "offload_cpu_info": "Offload models to CPU when not in use to save GPU memory",
50
+ "offload_dit_cpu_label": "Offload DiT to CPU",
51
+ "offload_dit_cpu_info": "Offload DiT to CPU (needs Offload to CPU)",
52
+ "init_btn": "Initialize Service",
53
+ "status_label": "Status",
54
+ "language_label": "UI Language",
55
+ "language_info": "Select interface language"
56
+ },
57
+ "generation": {
58
+ "required_inputs": "📝 Required Inputs",
59
+ "task_type_label": "Task Type",
60
+ "task_type_info": "Select the task type for generation",
61
+ "instruction_label": "Instruction",
62
+ "instruction_info": "Instruction is automatically generated based on task type",
63
+ "load_btn": "Load",
64
+ "track_name_label": "Track Name",
65
+ "track_name_info": "Select track name for lego/extract tasks",
66
+ "track_classes_label": "Track Names",
67
+ "track_classes_info": "Select multiple track classes for complete task",
68
+ "audio_uploads": "🎵 Audio Uploads",
69
+ "reference_audio": "Reference Audio (optional)",
70
+ "source_audio": "Source Audio (optional)",
71
+ "convert_codes_btn": "Convert to Codes",
72
+ "lm_codes_hints": "🎼 LM Codes Hints",
73
+ "lm_codes_label": "LM Codes Hints",
74
+ "lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
75
+ "lm_codes_info": "Paste LM codes hints for text2music generation",
76
+ "lm_codes_sample": "LM Codes Hints (Sample {n})",
77
+ "lm_codes_sample_info": "Codes for sample {n}",
78
+ "transcribe_btn": "Transcribe",
79
+ "repainting_controls": "🎨 Repainting Controls (seconds)",
80
+ "repainting_start": "Repainting Start",
81
+ "repainting_end": "Repainting End",
82
+ "mode_label": "Generation Mode",
83
+ "mode_info": "Simple: describe music in natural language. Custom: full control over caption and lyrics.",
84
+ "mode_simple": "Simple",
85
+ "mode_custom": "Custom",
86
+ "simple_query_label": "Song Description",
87
+ "simple_query_placeholder": "Describe the music you want to create, e.g., 'a soft Bengali love song for a quiet evening'. Leave empty for a random sample.",
88
+ "simple_query_info": "Enter a natural language description of the music you want to generate",
89
+ "simple_vocal_language_label": "Vocal Language (optional)",
90
+ "simple_vocal_language_info": "Select preferred language(s) for lyrics. Use 'unknown' for any language.",
91
+ "create_sample_btn": "Create Sample",
92
+ "caption_title": "📝 Music Caption",
93
+ "caption_label": "Music Caption (optional)",
94
+ "caption_placeholder": "A peaceful acoustic guitar melody with soft vocals...",
95
+ "caption_info": "Describe the style, genre, instruments, and mood",
96
+ "lyrics_title": "📝 Lyrics",
97
+ "lyrics_label": "Lyrics (optional)",
98
+ "lyrics_placeholder": "[Verse 1]\\nUnder the starry night\\nI feel so alive...",
99
+ "lyrics_info": "Song lyrics with structure",
100
+ "instrumental_label": "Instrumental",
101
+ "format_btn": "Format",
102
+ "optional_params": "⚙️ Optional Parameters",
103
+ "vocal_language_label": "Vocal Language (optional)",
104
+ "vocal_language_info": "use `unknown` for inst",
105
+ "bpm_label": "BPM (optional)",
106
+ "bpm_info": "leave empty for N/A",
107
+ "keyscale_label": "KeyScale (optional)",
108
+ "keyscale_placeholder": "Leave empty for N/A",
109
+ "keyscale_info": "A-G, #/♭, major/minor",
110
+ "timesig_label": "Time Signature (optional)",
111
+ "timesig_info": "2/4, 3/4, 4/4...",
112
+ "duration_label": "Audio Duration (seconds)",
113
+ "duration_info": "Use -1 for random",
114
+ "batch_size_label": "Batch Size",
115
+ "batch_size_info": "Number of audio to generate (max 8)",
116
+ "advanced_settings": "🔧 Advanced Settings",
117
+ "inference_steps_label": "DiT Inference Steps",
118
+ "inference_steps_info": "Turbo: max 8, Base: max 200",
119
+ "guidance_scale_label": "DiT Guidance Scale (Only support for base model)",
120
+ "guidance_scale_info": "Higher values follow text more closely",
121
+ "seed_label": "Seed",
122
+ "seed_info": "Use comma-separated values for batches",
123
+ "random_seed_label": "Random Seed",
124
+ "random_seed_info": "Enable to auto-generate seeds",
125
+ "audio_format_label": "Audio Format",
126
+ "audio_format_info": "Audio format for saved files",
127
+ "use_adg_label": "Use ADG",
128
+ "use_adg_info": "Enable Angle Domain Guidance",
129
+ "shift_label": "Shift",
130
+ "shift_info": "Timestep shift factor for base models (range 1.0~5.0, default 3.0). Not effective for turbo models.",
131
+ "infer_method_label": "Inference Method",
132
+ "infer_method_info": "Diffusion inference method. ODE (Euler) is faster, SDE (stochastic) may produce different results.",
133
+ "custom_timesteps_label": "Custom Timesteps",
134
+ "custom_timesteps_info": "Optional: comma-separated values from 1.0 to 0.0 (e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference steps and shift.",
135
+ "cfg_interval_start": "CFG Interval Start",
136
+ "cfg_interval_end": "CFG Interval End",
137
+ "lm_params_title": "🤖 LM Generation Parameters",
138
+ "lm_temperature_label": "LM Temperature",
139
+ "lm_temperature_info": "5Hz LM temperature (higher = more random)",
140
+ "lm_cfg_scale_label": "LM CFG Scale",
141
+ "lm_cfg_scale_info": "5Hz LM CFG (1.0 = no CFG)",
142
+ "lm_top_k_label": "LM Top-K",
143
+ "lm_top_k_info": "Top-K (0 = disabled)",
144
+ "lm_top_p_label": "LM Top-P",
145
+ "lm_top_p_info": "Top-P (1.0 = disabled)",
146
+ "lm_negative_prompt_label": "LM Negative Prompt",
147
+ "lm_negative_prompt_placeholder": "Enter negative prompt for CFG (default: NO USER INPUT)",
148
+ "lm_negative_prompt_info": "Negative prompt (use when LM CFG Scale > 1.0)",
149
+ "cot_metas_label": "CoT Metas",
150
+ "cot_metas_info": "Use LM to generate CoT metadata (uncheck to skip LM CoT generation)",
151
+ "cot_language_label": "CoT Language",
152
+ "cot_language_info": "Generate language in CoT (chain-of-thought)",
153
+ "constrained_debug_label": "Constrained Decoding Debug",
154
+ "constrained_debug_info": "Enable debug logging for constrained decoding (check to see detailed logs)",
155
+ "auto_score_label": "Auto Score",
156
+ "auto_score_info": "Automatically calculate quality scores for all generated audios",
157
+ "auto_lrc_label": "Auto LRC",
158
+ "auto_lrc_info": "Automatically generate LRC lyrics timestamps for all generated audios",
159
+ "lm_batch_chunk_label": "LM Batch Chunk Size",
160
+ "lm_batch_chunk_info": "Max items per LM batch chunk (default: 8, limited by GPU memory)",
161
+ "codes_strength_label": "LM Codes Strength",
162
+ "codes_strength_info": "Control how many denoising steps use LM-generated codes",
163
+ "cover_strength_label": "Audio Cover Strength",
164
+ "cover_strength_info": "Control how many denoising steps use cover mode",
165
+ "score_sensitivity_label": "Quality Score Sensitivity",
166
+ "score_sensitivity_info": "Lower = more sensitive (default: 1.0). Adjusts how PMI maps to [0,1]",
167
+ "think_label": "Think",
168
+ "parallel_thinking_label": "ParallelThinking",
169
+ "generate_btn": "🎵 Generate Music",
170
+ "autogen_label": "AutoGen",
171
+ "caption_rewrite_label": "CaptionRewrite"
172
+ },
173
+ "results": {
174
+ "title": "🎵 Results",
175
+ "generated_music": "🎵 Generated Music (Sample {n})",
176
+ "send_to_src_btn": "🔗 Send To Src Audio",
177
+ "save_btn": "💾 Save",
178
+ "score_btn": "📊 Score",
179
+ "lrc_btn": "🎵 LRC",
180
+ "quality_score_label": "Quality Score (Sample {n})",
181
+ "quality_score_placeholder": "Click 'Score' to calculate perplexity-based quality score",
182
+ "codes_label": "LM Codes (Sample {n})",
183
+ "lrc_label": "Lyrics Timestamps (Sample {n})",
184
+ "lrc_placeholder": "Click 'LRC' to generate timestamps",
185
+ "details_accordion": "📊 Score & LRC & LM Codes",
186
+ "generation_status": "Generation Status",
187
+ "current_batch": "Current Batch",
188
+ "batch_indicator": "Batch {current} / {total}",
189
+ "next_batch_status": "Next Batch Status",
190
+ "prev_btn": "◀ Previous",
191
+ "next_btn": "Next ▶",
192
+ "restore_params_btn": "↙️ Apply These Settings to UI (Restore Batch Parameters)",
193
+ "batch_results_title": "📁 Batch Results & Generation Details",
194
+ "all_files_label": "📁 All Generated Files (Download)",
195
+ "generation_details": "Generation Details"
196
+ },
197
+ "messages": {
198
+ "no_audio_to_save": "❌ No audio to save",
199
+ "save_success": "✅ Saved audio and metadata to {filename}",
200
+ "save_failed": "❌ Failed to save: {error}",
201
+ "no_file_selected": "⚠️ No file selected",
202
+ "params_loaded": "✅ Parameters loaded from {filename}",
203
+ "invalid_json": "❌ Invalid JSON file: {error}",
204
+ "load_error": "❌ Error loading file: {error}",
205
+ "example_loaded": "📁 Loaded example from {filename}",
206
+ "example_failed": "Failed to parse JSON file {filename}: {error}",
207
+ "example_error": "Error loading example: {error}",
208
+ "lm_generated": "🤖 Generated example using LM",
209
+ "lm_fallback": "Failed to generate example using LM, falling back to examples directory",
210
+ "lm_not_initialized": "❌ 5Hz LM not initialized. Please initialize it first.",
211
+ "autogen_enabled": "🔄 AutoGen enabled - next batch will generate after this",
212
+ "batch_ready": "✅ Batch {n} ready! Click 'Next' to view.",
213
+ "batch_generating": "🔄 Starting background generation for Batch {n}...",
214
+ "batch_failed": "❌ Background generation failed: {error}",
215
+ "viewing_batch": "✅ Viewing Batch {n}",
216
+ "at_first_batch": "Already at first batch",
217
+ "at_last_batch": "No next batch available",
218
+ "batch_not_found": "Batch {n} not found in queue",
219
+ "no_batch_data": "No batch data found to restore.",
220
+ "params_restored": "✅ UI Parameters restored from Batch {n}",
221
+ "scoring_failed": "❌ Error: Batch data not found",
222
+ "no_codes": "❌ No audio codes available. Please generate music first.",
223
+ "score_failed": "❌ Scoring failed: {error}",
224
+ "score_error": "❌ Error calculating score: {error}",
225
+ "lrc_no_batch_data": "❌ No batch data found. Please generate music first.",
226
+ "lrc_no_extra_outputs": "❌ No extra outputs found. Condition tensors not available.",
227
+ "lrc_missing_tensors": "❌ Missing required tensors for LRC generation.",
228
+ "lrc_sample_not_exist": "❌ Sample does not exist in current batch.",
229
+ "lrc_empty_result": "⚠️ LRC generation produced empty result.",
230
+ "empty_query": "⚠️ Please enter a music description.",
231
+ "sample_creation_failed": "❌ Failed to create sample. Please try again.",
232
+ "sample_created": "✅ Sample created! Review the caption and lyrics, then click Generate Music.",
233
+ "simple_examples_not_found": "⚠️ Simple mode examples directory not found.",
234
+ "simple_examples_empty": "⚠️ No example files found in simple mode examples.",
235
+ "simple_example_loaded": "🎲 Loaded random example from {filename}",
236
+ "format_success": "✅ Caption and lyrics formatted successfully",
237
+ "format_failed": "❌ Format failed: {error}",
238
+ "skipping_metas_cot": "⚡ Skipping Phase 1 metas COT (sample already formatted)",
239
+ "invalid_timesteps_format": "⚠️ Invalid timesteps format. Using default schedule.",
240
+ "timesteps_out_of_range": "⚠️ Timesteps must be in range [0, 1]. Using default schedule.",
241
+ "timesteps_count_mismatch": "⚠️ Timesteps count ({actual}) differs from inference_steps ({expected}). Using timesteps count."
242
+ }
243
+ }
code/acestep/gradio_ui/i18n/ja.json ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "app": {
3
+ "title": "🎛️ ACE-Step V1.5 プレイグラウンド💡",
4
+ "subtitle": "オープンソース音楽生成の限界を押し広げる"
5
+ },
6
+ "dataset": {
7
+ "title": "📊 データセットエクスプローラー",
8
+ "dataset_label": "データセット",
9
+ "dataset_info": "探索するデータセットを選択",
10
+ "import_btn": "📥 データセットをインポート",
11
+ "search_type_label": "検索タイプ",
12
+ "search_type_info": "アイテムの検索方法",
13
+ "search_value_label": "検索値",
14
+ "search_value_placeholder": "キーまたはインデックスを入力(空白の場合はランダム)",
15
+ "search_value_info": "キー: 完全一致、インデックス: 0からデータセットサイズ-1",
16
+ "instruction_label": "📝 指示",
17
+ "instruction_placeholder": "利用可能な指示がありません",
18
+ "metadata_title": "📋 アイテムメタデータ (JSON)",
19
+ "metadata_label": "完全なアイテム情報",
20
+ "source_audio": "ソースオーディオ",
21
+ "target_audio": "ターゲットオーディオ",
22
+ "reference_audio": "リファレンスオーディオ",
23
+ "get_item_btn": "🔍 アイテムを取得",
24
+ "use_src_checkbox": "データセットのソースオーディオを使用",
25
+ "use_src_info": "データセットのソースオーディオを使用する場合はチェック",
26
+ "data_status_label": "📊 データステータス",
27
+ "data_status_default": "❌ データセットがインポートされていません",
28
+ "autofill_btn": "📋 生成フォームを自動入力"
29
+ },
30
+ "service": {
31
+ "title": "🔧 サービス設定",
32
+ "checkpoint_label": "チェックポイントファイル",
33
+ "checkpoint_info": "訓練済みモデルのチェックポイントファイルを選択(フルパスまたはファイル名)",
34
+ "refresh_btn": "🔄 更新",
35
+ "model_path_label": "メインモデルパス",
36
+ "model_path_info": "モデル設定ディレクトリを選択(チェックポイントから自動スキャン)",
37
+ "device_label": "デバイス",
38
+ "device_info": "処理デバイス(自動検出を推奨)",
39
+ "lm_model_path_label": "5Hz LM モデルパス",
40
+ "lm_model_path_info": "5Hz LMモデルチェックポイントを選択(チェックポイントから自動スキャン)",
41
+ "backend_label": "5Hz LM バックエンド",
42
+ "backend_info": "5Hz LMのバックエンドを選択: vllm(高速)またはpt(PyTorch、より互換性あり)",
43
+ "init_llm_label": "5Hz LM を初期化",
44
+ "init_llm_info": "サービス初期化中に5Hz LMを初期化する場合はチェック",
45
+ "flash_attention_label": "Flash Attention を使用",
46
+ "flash_attention_info_enabled": "推論を高速化するためにflash attentionを有効にする(flash_attnパッケージが必要)",
47
+ "flash_attention_info_disabled": "Flash attentionは利用できません(flash_attnパッケージがインストールされていません)",
48
+ "offload_cpu_label": "CPUにオフロード",
49
+ "offload_cpu_info": "使用していない時にモデルをCPUにオフロードしてGPUメモリを節約",
50
+ "offload_dit_cpu_label": "DiTをCPUにオフロード",
51
+ "offload_dit_cpu_info": "DiTをCPUにオフロード(CPUへのオフロードが必要)",
52
+ "init_btn": "サービスを初期化",
53
+ "status_label": "ステータス",
54
+ "language_label": "UI言語",
55
+ "language_info": "インターフェース言語を選択"
56
+ },
57
+ "generation": {
58
+ "required_inputs": "📝 必須入力",
59
+ "task_type_label": "タスクタイプ",
60
+ "task_type_info": "生成のタスクタイプを選択",
61
+ "instruction_label": "指示",
62
+ "instruction_info": "指示はタスクタイプに基づいて自動生成されます",
63
+ "load_btn": "読み込む",
64
+ "track_name_label": "トラック名",
65
+ "track_name_info": "lego/extractタスクのトラック名を選択",
66
+ "track_classes_label": "トラック名",
67
+ "track_classes_info": "completeタスクの複数のトラッククラスを選択",
68
+ "audio_uploads": "🎵 オーディオアップロード",
69
+ "reference_audio": "リファレンスオーディオ(オプション)",
70
+ "source_audio": "ソースオーディオ(オプション)",
71
+ "convert_codes_btn": "コードに変換",
72
+ "lm_codes_hints": "🎼 LM コードヒント",
73
+ "lm_codes_label": "LM コードヒント",
74
+ "lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
75
+ "lm_codes_info": "text2music生成用のLMコードヒントを貼り付け",
76
+ "lm_codes_sample": "LM コードヒント(サンプル {n})",
77
+ "lm_codes_sample_info": "サンプル{n}のコード",
78
+ "transcribe_btn": "転写",
79
+ "repainting_controls": "🎨 再描画コントロール(秒)",
80
+ "repainting_start": "再描画開始",
81
+ "repainting_end": "再描画終了",
82
+ "mode_label": "生成モード",
83
+ "mode_info": "シンプル:自然言語で音楽を説明��カスタム:キャプションと歌詞を完全にコントロール。",
84
+ "mode_simple": "シンプル",
85
+ "mode_custom": "カスタム",
86
+ "simple_query_label": "曲の説明",
87
+ "simple_query_placeholder": "作成したい音楽を説明してください。例:'静かな夜のための優しいベンガルのラブソング'。空欄の場合はランダムなサンプルが生成されます。",
88
+ "simple_query_info": "生成したい音楽の自然言語の説明を入力",
89
+ "simple_vocal_language_label": "ボーカル言語(オプション)",
90
+ "simple_vocal_language_info": "歌詞の希望言語を選択。任意の言語の場合は'unknown'を使用。",
91
+ "create_sample_btn": "サンプル作成",
92
+ "caption_title": "📝 音楽キャプション",
93
+ "caption_label": "音楽キャプション(オプション)",
94
+ "caption_placeholder": "柔らかいボーカルを伴う穏やかなアコースティックギターのメロディー...",
95
+ "caption_info": "スタイル、ジャンル、楽器、ムードを説明",
96
+ "lyrics_title": "📝 歌詞",
97
+ "lyrics_label": "歌詞(オプション)",
98
+ "lyrics_placeholder": "[バース1]\\n星空の下で\\nとても生きていると感じる...",
99
+ "lyrics_info": "構造を持つ曲の歌詞",
100
+ "instrumental_label": "インストゥルメンタル",
101
+ "format_btn": "フォーマット",
102
+ "optional_params": "⚙️ オプションパラメータ",
103
+ "vocal_language_label": "ボーカル言語(オプション)",
104
+ "vocal_language_info": "インストには`unknown`を使用",
105
+ "bpm_label": "BPM(オプション)",
106
+ "bpm_info": "空白の場合はN/A",
107
+ "keyscale_label": "キースケール(オプション)",
108
+ "keyscale_placeholder": "空白の場合はN/A",
109
+ "keyscale_info": "A-G, #/♭, メジャー/マイナー",
110
+ "timesig_label": "拍子記号(オプション)",
111
+ "timesig_info": "2/4, 3/4, 4/4...",
112
+ "duration_label": "オーディオ長(秒)",
113
+ "duration_info": "ランダムの場合は-1を使用",
114
+ "batch_size_label": "バッチサイズ",
115
+ "batch_size_info": "生成するオーディオの数(最大8)",
116
+ "advanced_settings": "🔧 詳細設定",
117
+ "inference_steps_label": "DiT 推論ステップ",
118
+ "inference_steps_info": "Turbo: 最大8、Base: 最大200",
119
+ "guidance_scale_label": "DiT ガイダンススケール(baseモデルのみサポート)",
120
+ "guidance_scale_info": "値が高いほどテキストに忠実に従う",
121
+ "seed_label": "シード",
122
+ "seed_info": "バッチにはカンマ区切りの値を使用",
123
+ "random_seed_label": "ランダムシード",
124
+ "random_seed_info": "有効にすると自動的にシードを生成",
125
+ "audio_format_label": "オーディオフォーマット",
126
+ "audio_format_info": "保存ファイルのオーディオフォーマット",
127
+ "use_adg_label": "ADG を使用",
128
+ "use_adg_info": "角度ドメインガイダンスを有効化",
129
+ "shift_label": "シフト",
130
+ "shift_info": "baseモデル用タイムステップシフト係数 (範囲 1.0~5.0、デフォルト 3.0)。turboモデルには無効。",
131
+ "infer_method_label": "推論方法",
132
+ "infer_method_info": "拡散推論方法。ODE (オイラー) は高速、SDE (確率的) は異なる結果を生成する可能性があります。",
133
+ "custom_timesteps_label": "カスタムタイムステップ",
134
+ "custom_timesteps_info": "オプション:1.0から0.0へのカンマ区切り値(例:'0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')。推論ステップとシフトを上書きします。",
135
+ "cfg_interval_start": "CFG 間隔開始",
136
+ "cfg_interval_end": "CFG 間隔終了",
137
+ "lm_params_title": "🤖 LM 生成パラメータ",
138
+ "lm_temperature_label": "LM 温度",
139
+ "lm_temperature_info": "5Hz LM温度(高いほどランダム)",
140
+ "lm_cfg_scale_label": "LM CFG スケール",
141
+ "lm_cfg_scale_info": "5Hz LM CFG (1.0 = CFGなし)",
142
+ "lm_top_k_label": "LM Top-K",
143
+ "lm_top_k_info": "Top-K (0 = 無効)",
144
+ "lm_top_p_label": "LM Top-P",
145
+ "lm_top_p_info": "Top-P (1.0 = 無効)",
146
+ "lm_negative_prompt_label": "LM ネガティブプロンプト",
147
+ "lm_negative_prompt_placeholder": "CFGのネガティブプロンプトを入力(デフォルト: NO USER INPUT)",
148
+ "lm_negative_prompt_info": "ネガティブプロンプト(LM CFGスケール > 1.0の場合に使用)",
149
+ "cot_metas_label": "CoT メタデータ",
150
+ "cot_metas_info": "LMを使用してCoTメタデータを生成(チェックを外すとLM CoT生成をスキップ)",
151
+ "cot_language_label": "CoT 言語",
152
+ "cot_language_info": "CoTで言語を生成(思考の連鎖)",
153
+ "constrained_debug_label": "制約付きデコーディングデバッグ",
154
+ "constrained_debug_info": "制約付きデコーディングのデバッグログを有効化(チェックすると詳細ログを表示)",
155
+ "auto_score_label": "自動スコアリング",
156
+ "auto_score_info": "生成���れたすべてのオーディオの品質スコアを自動計算",
157
+ "auto_lrc_label": "自動 LRC",
158
+ "auto_lrc_info": "生成されたすべてのオーディオのLRC歌詞タイムスタンプを自動生成",
159
+ "lm_batch_chunk_label": "LM バッチチャンクサイズ",
160
+ "lm_batch_chunk_info": "LMバッチチャンクあたりの最大アイテム数(デフォルト: 8、GPUメモリによる制限)",
161
+ "codes_strength_label": "LM コード強度",
162
+ "codes_strength_info": "LM生成コードを使用するデノイジングステップ数を制御",
163
+ "cover_strength_label": "オーディオカバー強度",
164
+ "cover_strength_info": "カバーモードを使用するデノイジングステップ数を制御",
165
+ "score_sensitivity_label": "品質スコア感度",
166
+ "score_sensitivity_info": "低い = より敏感(デフォルト: 1.0)。PMIが[0,1]にマッピングする方法を調整",
167
+ "think_label": "思考",
168
+ "parallel_thinking_label": "並列思考",
169
+ "generate_btn": "🎵 音楽を生成",
170
+ "autogen_label": "自動生成",
171
+ "caption_rewrite_label": "キャプション書き換え"
172
+ },
173
+ "results": {
174
+ "title": "🎵 結果",
175
+ "generated_music": "🎵 生成された音楽(サンプル {n})",
176
+ "send_to_src_btn": "🔗 ソースオーディオに送信",
177
+ "save_btn": "💾 保存",
178
+ "score_btn": "📊 スコア",
179
+ "lrc_btn": "🎵 LRC",
180
+ "quality_score_label": "品質スコア(サンプル {n})",
181
+ "quality_score_placeholder": "'スコア'をクリックしてパープレキシティベースの品質スコアを計算",
182
+ "codes_label": "LM コード(サンプル {n})",
183
+ "lrc_label": "歌詞タイムスタンプ(サンプル {n})",
184
+ "lrc_placeholder": "'LRC'をクリックしてタイムスタンプを生成",
185
+ "details_accordion": "📊 スコア & LRC & LM コード",
186
+ "generation_status": "生成ステータス",
187
+ "current_batch": "現在のバッチ",
188
+ "batch_indicator": "バッチ {current} / {total}",
189
+ "next_batch_status": "次のバッチステータス",
190
+ "prev_btn": "◀ 前へ",
191
+ "next_btn": "次へ ▶",
192
+ "restore_params_btn": "↙️ これらの設定をUIに適用(バッチパラメータを復元)",
193
+ "batch_results_title": "📁 バッチ結果と生成詳細",
194
+ "all_files_label": "📁 すべての生成ファイル(ダウンロード)",
195
+ "generation_details": "生成詳細"
196
+ },
197
+ "messages": {
198
+ "no_audio_to_save": "❌ 保存するオーディオがありません",
199
+ "save_success": "✅ オーディオとメタデータを {filename} に保存しました",
200
+ "save_failed": "❌ 保存に失敗しました: {error}",
201
+ "no_file_selected": "⚠️ ファイルが選択されていません",
202
+ "params_loaded": "✅ {filename} からパラメータを読み込みました",
203
+ "invalid_json": "❌ 無効なJSONファイル: {error}",
204
+ "load_error": "❌ ファイルの読み込みエラー: {error}",
205
+ "example_loaded": "📁 {filename} からサンプルを読み込みました",
206
+ "example_failed": "JSONファイル {filename} の解析に失敗しました: {error}",
207
+ "example_error": "サンプル読み込みエラー: {error}",
208
+ "lm_generated": "🤖 LMを使用してサンプルを生成しました",
209
+ "lm_fallback": "LMを使用したサンプル生成に失敗、サンプルディレクトリにフォールバック",
210
+ "lm_not_initialized": "❌ 5Hz LMが初期化されていません。最初に初期化してください。",
211
+ "autogen_enabled": "🔄 自動生成が有効 - このあと次のバッチを生成します",
212
+ "batch_ready": "✅ バッチ {n} の準備完了!'次へ'をクリックして表示。",
213
+ "batch_generating": "🔄 バッチ {n} のバックグラウンド生成を開始...",
214
+ "batch_failed": "❌ バックグラウンド生成に失敗しました: {error}",
215
+ "viewing_batch": "✅ バッチ {n} を表示中",
216
+ "at_first_batch": "すでに最初のバッチです",
217
+ "at_last_batch": "次のバッチはありません",
218
+ "batch_not_found": "キューにバッチ {n} が見つかりません",
219
+ "no_batch_data": "復元するバッチデータがありません。",
220
+ "params_restored": "✅ バッチ {n} からUIパラメータを復元しました",
221
+ "scoring_failed": "❌ エラー: バッチデータが見つかりません",
222
+ "no_codes": "❌ 利用可能なオーディオコードがありません。最初に音楽を生成してください。",
223
+ "score_failed": "❌ スコアリングに失敗しました: {error}",
224
+ "score_error": "❌ スコア計算エラー: {error}",
225
+ "lrc_no_batch_data": "❌ バッチデータが見つかりません。最初に音楽を生成してください。",
226
+ "lrc_no_extra_outputs": "❌ 追加出力が見つかりません。条件テンソルが利用できません。",
227
+ "lrc_missing_tensors": "❌ LRC生成に必要なテンソルがありません。",
228
+ "lrc_sample_not_exist": "❌ 現在のバッチにサンプルが存在しません。",
229
+ "lrc_empty_result": "⚠️ LRC生成の結果が空です。",
230
+ "empty_query": "⚠️ 音楽の説明を入力してください。",
231
+ "sample_creation_failed": "❌ サンプルの作成に失敗しました。もう一度お試しください。",
232
+ "sample_created": "✅ サンプルが作成されました!キャプションと歌詞を確認して、音楽を生成をクリックしてください。",
233
+ "simple_examples_not_found": "⚠️ シンプルモードサンプルディレクトリが見つかりません。",
234
+ "simple_examples_empty": "⚠️ シンプルモードサンプルにファイルがありません。",
235
+ "simple_example_loaded": "🎲 {filename} からランダムサンプルを読み込みました",
236
+ "format_success": "✅ キャプションと歌詞のフォーマットに成功しました",
237
+ "format_failed": "❌ フォーマットに失敗しました: {error}",
238
+ "skipping_metas_cot": "⚡ Phase 1 メタデータ COT をスキップ(サンプルは既にフォーマット済み)",
239
+ "invalid_timesteps_format": "⚠️ タイムステップ形式が無効です。デフォルトスケジュールを使用します。",
240
+ "timesteps_out_of_range": "⚠️ タイムステップは [0, 1] の範囲内である必要があります。デフォルトスケジュールを使用します。",
241
+ "timesteps_count_mismatch": "⚠️ タイムステップ数 ({actual}) が推論ステップ数 ({expected}) と異なります。タイムステップ数を使用します。"
242
+ }
243
+ }
code/acestep/gradio_ui/i18n/zh.json ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "app": {
3
+ "title": "🎛️ ACE-Step V1.5 演练场💡",
4
+ "subtitle": "推动开源音乐生成的边界"
5
+ },
6
+ "dataset": {
7
+ "title": "📊 数据集浏览器",
8
+ "dataset_label": "数据集",
9
+ "dataset_info": "选择要浏览的数据集",
10
+ "import_btn": "📥 导入数据集",
11
+ "search_type_label": "搜索类型",
12
+ "search_type_info": "如何查找项目",
13
+ "search_value_label": "搜索值",
14
+ "search_value_placeholder": "输入键或索引(留空表示随机)",
15
+ "search_value_info": "键: 精确匹配, 索引: 0到数据集大小-1",
16
+ "instruction_label": "📝 指令",
17
+ "instruction_placeholder": "无可用指令",
18
+ "metadata_title": "📋 项目元数据 (JSON)",
19
+ "metadata_label": "完整项目信息",
20
+ "source_audio": "源音频",
21
+ "target_audio": "目标音频",
22
+ "reference_audio": "参考音频",
23
+ "get_item_btn": "🔍 获取项目",
24
+ "use_src_checkbox": "使用数据集中的源音频",
25
+ "use_src_info": "勾选以使用数据集中的源音频",
26
+ "data_status_label": "📊 数据状态",
27
+ "data_status_default": "❌ 未导入数据集",
28
+ "autofill_btn": "📋 自动填充生成表单"
29
+ },
30
+ "service": {
31
+ "title": "🔧 服务配置",
32
+ "checkpoint_label": "检查点文件",
33
+ "checkpoint_info": "选择训练好的模型检查点文件(完整路径或文件名)",
34
+ "refresh_btn": "🔄 刷新",
35
+ "model_path_label": "主模型路径",
36
+ "model_path_info": "选择模型配置目录(从检查点自动扫描)",
37
+ "device_label": "设备",
38
+ "device_info": "处理设备(建议自动检测)",
39
+ "lm_model_path_label": "5Hz LM 模型路径",
40
+ "lm_model_path_info": "选择5Hz LM模型检查点(从检查点自动扫描)",
41
+ "backend_label": "5Hz LM 后端",
42
+ "backend_info": "选择5Hz LM的后端: vllm(更快)或pt(PyTorch, 更兼容)",
43
+ "init_llm_label": "初始化 5Hz LM",
44
+ "init_llm_info": "勾选以在服务初始化期间初始化5Hz LM",
45
+ "flash_attention_label": "使用Flash Attention",
46
+ "flash_attention_info_enabled": "启用flash attention以加快推理速度(需要flash_attn包)",
47
+ "flash_attention_info_disabled": "Flash attention不可用(未安装flash_attn包)",
48
+ "offload_cpu_label": "卸载到CPU",
49
+ "offload_cpu_info": "不使用时将模型卸载到CPU以节省GPU内存",
50
+ "offload_dit_cpu_label": "将DiT卸载到CPU",
51
+ "offload_dit_cpu_info": "将DiT卸载到CPU(需要启用卸载到CPU)",
52
+ "init_btn": "初始化服务",
53
+ "status_label": "状态",
54
+ "language_label": "界面语言",
55
+ "language_info": "选择界面语言"
56
+ },
57
+ "generation": {
58
+ "required_inputs": "📝 必需输入",
59
+ "task_type_label": "任务类型",
60
+ "task_type_info": "选择生成的任务类型",
61
+ "instruction_label": "指令",
62
+ "instruction_info": "指令根据任务类型自动生成",
63
+ "load_btn": "加载",
64
+ "track_name_label": "音轨名称",
65
+ "track_name_info": "为lego/extract任务选择音轨名称",
66
+ "track_classes_label": "音轨名称",
67
+ "track_classes_info": "为complete任务选择多个音轨类别",
68
+ "audio_uploads": "🎵 音频上传",
69
+ "reference_audio": "参考音频(可选)",
70
+ "source_audio": "源音频(可选)",
71
+ "convert_codes_btn": "转换为代码",
72
+ "lm_codes_hints": "🎼 LM 代码提示",
73
+ "lm_codes_label": "LM 代码提示",
74
+ "lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
75
+ "lm_codes_info": "粘贴用于text2music生成的LM代码提示",
76
+ "lm_codes_sample": "LM 代码提示(样本 {n})",
77
+ "lm_codes_sample_info": "样本{n}的代码",
78
+ "transcribe_btn": "转录",
79
+ "repainting_controls": "🎨 重绘控制(秒)",
80
+ "repainting_start": "重绘开始",
81
+ "repainting_end": "重绘结束",
82
+ "mode_label": "生成模式",
83
+ "mode_info": "简单模式:用自然语言描述音乐。自定义模式:完全控制描述和歌词。",
84
+ "mode_simple": "简单",
85
+ "mode_custom": "自定义",
86
+ "simple_query_label": "歌曲描述",
87
+ "simple_query_placeholder": "描述你想创作的音乐,例如:'给我生成一首暗黑的戏剧古风,歌词要华丽'。留空则随机生成样本。",
88
+ "simple_query_info": "输入你想生成的音乐的自然语言描述",
89
+ "simple_vocal_language_label": "人声语言(可选)",
90
+ "simple_vocal_language_info": "选择歌词的首选语言。使用 'unknown' 表示任意语言。",
91
+ "create_sample_btn": "创建样本",
92
+ "caption_title": "📝 音乐描述",
93
+ "caption_label": "音乐描述(可选)",
94
+ "caption_placeholder": "一段平和的原声吉他旋律,配有柔和的人声...",
95
+ "caption_info": "描述风格、流派、乐器和情绪",
96
+ "lyrics_title": "📝 歌词",
97
+ "lyrics_label": "歌词(可选)",
98
+ "lyrics_placeholder": "[第一段]\\n在星空下\\n我感到如此活跃...",
99
+ "lyrics_info": "带有结构的歌曲歌词",
100
+ "instrumental_label": "纯音乐",
101
+ "format_btn": "格式化",
102
+ "optional_params": "⚙️ 可选参数",
103
+ "vocal_language_label": "人声语言(可选)",
104
+ "vocal_language_info": "纯音乐使用 `unknown`",
105
+ "bpm_label": "BPM(可选)",
106
+ "bpm_info": "留空表示N/A",
107
+ "keyscale_label": "调性(可选)",
108
+ "keyscale_placeholder": "留空表示N/A",
109
+ "keyscale_info": "A-G, #/♭, 大调/小调",
110
+ "timesig_label": "拍号(可选)",
111
+ "timesig_info": "2/4, 3/4, 4/4...",
112
+ "duration_label": "音频时长(秒)",
113
+ "duration_info": "使用-1表示随机",
114
+ "batch_size_label": "批量大小",
115
+ "batch_size_info": "要生成的音频数量(最多8个)",
116
+ "advanced_settings": "🔧 高级设置",
117
+ "inference_steps_label": "DiT 推理步数",
118
+ "inference_steps_info": "Turbo: 最多8, Base: 最多200",
119
+ "guidance_scale_label": "DiT 引导比例(仅支持base模型)",
120
+ "guidance_scale_info": "更高的值更紧密地遵循文本",
121
+ "seed_label": "种子",
122
+ "seed_info": "批量使用逗号分隔的值",
123
+ "random_seed_label": "随机种子",
124
+ "random_seed_info": "启用以自动生成种子",
125
+ "audio_format_label": "音频格式",
126
+ "audio_format_info": "保存文件的音频格式",
127
+ "use_adg_label": "使用 ADG",
128
+ "use_adg_info": "启用角域引导",
129
+ "shift_label": "Shift",
130
+ "shift_info": "时间步偏移因子,仅对 base 模型生效 (范围 1.0~5.0,默认 3.0)。对 turbo 模型无效。",
131
+ "infer_method_label": "推理方法",
132
+ "infer_method_info": "扩散推理方法。ODE (欧拉) 更快,SDE (随机) 可能产生不同结果。",
133
+ "custom_timesteps_label": "自定义时间步",
134
+ "custom_timesteps_info": "可选:从 1.0 到 0.0 的逗号分隔值(例如 '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')。会覆盖推理步数和 shift 设置。",
135
+ "cfg_interval_start": "CFG 间隔开始",
136
+ "cfg_interval_end": "CFG 间隔结束",
137
+ "lm_params_title": "🤖 LM 生成参数",
138
+ "lm_temperature_label": "LM 温度",
139
+ "lm_temperature_info": "5Hz LM温度(越高越随机)",
140
+ "lm_cfg_scale_label": "LM CFG 比例",
141
+ "lm_cfg_scale_info": "5Hz LM CFG (1.0 = 无CFG)",
142
+ "lm_top_k_label": "LM Top-K",
143
+ "lm_top_k_info": "Top-K (0 = 禁用)",
144
+ "lm_top_p_label": "LM Top-P",
145
+ "lm_top_p_info": "Top-P (1.0 = 禁用)",
146
+ "lm_negative_prompt_label": "LM 负面提示",
147
+ "lm_negative_prompt_placeholder": "输入CFG的负面提示(默认: NO USER INPUT)",
148
+ "lm_negative_prompt_info": "负面提示(当LM CFG比例 > 1.0时使用)",
149
+ "cot_metas_label": "CoT 元数据",
150
+ "cot_metas_info": "使用LM生成CoT元数据(取消勾选以跳过LM CoT生成)",
151
+ "cot_language_label": "CoT 语言",
152
+ "cot_language_info": "在CoT中生成语言(思维链)",
153
+ "constrained_debug_label": "约束解码调试",
154
+ "constrained_debug_info": "启用约束解码的调试日志(勾选以查看详细日志)",
155
+ "auto_score_label": "自动评分",
156
+ "auto_score_info": "自动计算所有生成音频的质量分数",
157
+ "auto_lrc_label": "自动 LRC",
158
+ "auto_lrc_info": "自动为所有生成的音频生成LRC歌词时间戳",
159
+ "lm_batch_chunk_label": "LM 批量块大小",
160
+ "lm_batch_chunk_info": "每个LM批量块的最大项目数(默认: 8, 受GPU内存限制)",
161
+ "codes_strength_label": "LM 代码强度",
162
+ "codes_strength_info": "控制使用LM生成代码的去噪步骤数量",
163
+ "cover_strength_label": "音频覆盖强度",
164
+ "cover_strength_info": "控制使用覆盖模式的去噪步骤数量",
165
+ "score_sensitivity_label": "质量评分敏感度",
166
+ "score_sensitivity_info": "更低 = 更敏感(默认: 1.0). 调整PMI如何映射到[0,1]",
167
+ "think_label": "思考",
168
+ "parallel_thinking_label": "并行思考",
169
+ "generate_btn": "🎵 生成音乐",
170
+ "autogen_label": "自动生成",
171
+ "caption_rewrite_label": "描述重写"
172
+ },
173
+ "results": {
174
+ "title": "🎵 结果",
175
+ "generated_music": "🎵 生成的音乐(样本 {n})",
176
+ "send_to_src_btn": "🔗 发送到源音频",
177
+ "save_btn": "💾 保存",
178
+ "score_btn": "📊 评分",
179
+ "lrc_btn": "🎵 LRC",
180
+ "quality_score_label": "质量分数(样本 {n})",
181
+ "quality_score_placeholder": "点击'评分'以计算基于困惑度的质量分数",
182
+ "codes_label": "LM 代码(样本 {n})",
183
+ "lrc_label": "歌词时间戳(样本 {n})",
184
+ "lrc_placeholder": "点击'LRC'生成时间戳",
185
+ "details_accordion": "📊 评分与LRC与LM代码",
186
+ "generation_status": "生成状态",
187
+ "current_batch": "当前批次",
188
+ "batch_indicator": "批次 {current} / {total}",
189
+ "next_batch_status": "下一批次状态",
190
+ "prev_btn": "◀ 上一个",
191
+ "next_btn": "下一个 ▶",
192
+ "restore_params_btn": "↙️ 将这些设置应用到UI(恢复批次参数)",
193
+ "batch_results_title": "📁 批量结果和生成详情",
194
+ "all_files_label": "📁 所有生成的文件(下载)",
195
+ "generation_details": "生成详情"
196
+ },
197
+ "messages": {
198
+ "no_audio_to_save": "��� 没有要保存的音频",
199
+ "save_success": "✅ 已将音频和元数据保存到 {filename}",
200
+ "save_failed": "❌ 保存失败: {error}",
201
+ "no_file_selected": "⚠️ 未选择文件",
202
+ "params_loaded": "✅ 已从 {filename} 加载参数",
203
+ "invalid_json": "❌ 无效的JSON文件: {error}",
204
+ "load_error": "❌ 加载文件时出错: {error}",
205
+ "example_loaded": "📁 已从 {filename} 加载示例",
206
+ "example_failed": "解析JSON文件 {filename} 失败: {error}",
207
+ "example_error": "加载示例时出错: {error}",
208
+ "lm_generated": "🤖 使用LM生成的示例",
209
+ "lm_fallback": "使用LM生成示例失败,回退到示例目录",
210
+ "lm_not_initialized": "❌ 5Hz LM未初始化。请先初始化它。",
211
+ "autogen_enabled": "🔄 已启用自动生成 - 下一批次将在此之后生成",
212
+ "batch_ready": "✅ 批次 {n} 就绪!点击'下一个'查看。",
213
+ "batch_generating": "🔄 开始为批次 {n} 进行后台生成...",
214
+ "batch_failed": "❌ 后台生成失败: {error}",
215
+ "viewing_batch": "✅ 查看批次 {n}",
216
+ "at_first_batch": "已在第一批次",
217
+ "at_last_batch": "没有下一批次可用",
218
+ "batch_not_found": "在队列中未找到批次 {n}",
219
+ "no_batch_data": "没有要恢复的批次数据。",
220
+ "params_restored": "✅ 已从批次 {n} 恢复UI参数",
221
+ "scoring_failed": "❌ 错误: 未找到批次数据",
222
+ "no_codes": "❌ 没有可用的音频代码。请先生成音乐。",
223
+ "score_failed": "❌ 评分失败: {error}",
224
+ "score_error": "❌ 计算分数时出错: {error}",
225
+ "lrc_no_batch_data": "❌ 未找到批次数据。请先生成音乐。",
226
+ "lrc_no_extra_outputs": "❌ 未找到额外输出。条件张量不可用。",
227
+ "lrc_missing_tensors": "❌ 缺少LRC生成所需的张量。",
228
+ "lrc_sample_not_exist": "❌ 当前批次中不存在该样本。",
229
+ "lrc_empty_result": "⚠️ LRC生成结果为空。",
230
+ "empty_query": "⚠️ 请输入音乐描述。",
231
+ "sample_creation_failed": "❌ 创建样本失败。请重试。",
232
+ "sample_created": "✅ 样本已创建!检查描述和歌词,然后点击生成音乐。",
233
+ "simple_examples_not_found": "⚠️ 未找到简单模式示例目录。",
234
+ "simple_examples_empty": "⚠️ 简单模式示例中没有示例文件。",
235
+ "simple_example_loaded": "🎲 已从 {filename} 加载随机示例",
236
+ "format_success": "✅ 描述和歌词格式化成功",
237
+ "format_failed": "❌ 格式化失败: {error}",
238
+ "skipping_metas_cot": "⚡ 跳过 Phase 1 元数据 COT(样本已格式化)",
239
+ "invalid_timesteps_format": "⚠️ 时间步格式无效,使用默认调度。",
240
+ "timesteps_out_of_range": "⚠️ 时间步必须在 [0, 1] 范围内,使用默认调度。",
241
+ "timesteps_count_mismatch": "⚠️ 时间步数量 ({actual}) 与推理步数 ({expected}) 不匹配,将使用时间步数量。"
242
+ }
243
+ }
code/acestep/gradio_ui/interfaces/__init__.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Components Module
3
+ Contains all Gradio interface component definitions and layouts
4
+ """
5
+ import gradio as gr
6
+ from acestep.gradio_ui.i18n import get_i18n, t
7
+ from acestep.gradio_ui.interfaces.dataset import create_dataset_section
8
+ from acestep.gradio_ui.interfaces.generation import create_generation_section
9
+ from acestep.gradio_ui.interfaces.result import create_results_section
10
+ from acestep.gradio_ui.interfaces.training import create_training_section
11
+ from acestep.gradio_ui.events import setup_event_handlers, setup_training_event_handlers
12
+
13
+
14
+ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=None, language='en') -> gr.Blocks:
15
+ """
16
+ Create Gradio interface
17
+
18
+ Args:
19
+ dit_handler: DiT handler instance
20
+ llm_handler: LM handler instance
21
+ dataset_handler: Dataset handler instance
22
+ init_params: Dictionary containing initialization parameters and state.
23
+ If None, service will not be pre-initialized.
24
+ language: UI language code ('en', 'zh', 'ja', default: 'en')
25
+
26
+ Returns:
27
+ Gradio Blocks instance
28
+ """
29
+ # Initialize i18n with selected language
30
+ i18n = get_i18n(language)
31
+
32
+ with gr.Blocks(
33
+ title=t("app.title"),
34
+ theme=gr.themes.Soft(),
35
+ css="""
36
+ .main-header {
37
+ text-align: center;
38
+ margin-bottom: 2rem;
39
+ }
40
+ .section-header {
41
+ background: linear-gradient(90deg, #4CAF50, #45a049);
42
+ color: white;
43
+ padding: 10px;
44
+ border-radius: 5px;
45
+ margin: 10px 0;
46
+ }
47
+ .lm-hints-row {
48
+ align-items: stretch;
49
+ }
50
+ .lm-hints-col {
51
+ display: flex;
52
+ }
53
+ .lm-hints-col > div {
54
+ flex: 1;
55
+ display: flex;
56
+ }
57
+ .lm-hints-btn button {
58
+ height: 100%;
59
+ width: 100%;
60
+ }
61
+ """
62
+ ) as demo:
63
+
64
+ gr.HTML(f"""
65
+ <div class="main-header">
66
+ <h1>{t("app.title")}</h1>
67
+ <p>{t("app.subtitle")}</p>
68
+ </div>
69
+ """)
70
+
71
+ # Dataset Explorer Section
72
+ dataset_section = create_dataset_section(dataset_handler)
73
+
74
+ # Generation Section (pass init_params and language to support pre-initialization)
75
+ generation_section = create_generation_section(dit_handler, llm_handler, init_params=init_params, language=language)
76
+
77
+ # Results Section
78
+ results_section = create_results_section(dit_handler)
79
+
80
+ # Training Section (LoRA training and dataset builder)
81
+ # Pass init_params to support hiding in service mode
82
+ training_section = create_training_section(dit_handler, llm_handler, init_params=init_params)
83
+
84
+ # Connect event handlers
85
+ setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section)
86
+
87
+ # Connect training event handlers
88
+ setup_training_event_handlers(demo, dit_handler, llm_handler, training_section)
89
+
90
+ return demo
code/acestep/gradio_ui/interfaces/dataset.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Dataset Section Module
3
+ Contains dataset explorer section component definitions
4
+ """
5
+ import gradio as gr
6
+
7
+
8
+ def create_dataset_section(dataset_handler) -> dict:
9
+ """Create dataset explorer section"""
10
+ with gr.Accordion("📊 Dataset Explorer", open=False, visible=False):
11
+ with gr.Row(equal_height=True):
12
+ dataset_type = gr.Dropdown(
13
+ choices=["train", "test"],
14
+ value="train",
15
+ label="Dataset",
16
+ info="Choose dataset to explore",
17
+ scale=2
18
+ )
19
+ import_dataset_btn = gr.Button("📥 Import Dataset", variant="primary", scale=1)
20
+
21
+ search_type = gr.Dropdown(
22
+ choices=["keys", "idx", "random"],
23
+ value="random",
24
+ label="Search Type",
25
+ info="How to find items",
26
+ scale=1
27
+ )
28
+ search_value = gr.Textbox(
29
+ label="Search Value",
30
+ placeholder="Enter keys or index (leave empty for random)",
31
+ info="Keys: exact match, Index: 0 to dataset size-1",
32
+ scale=2
33
+ )
34
+
35
+ instruction_display = gr.Textbox(
36
+ label="📝 Instruction",
37
+ interactive=False,
38
+ placeholder="No instruction available",
39
+ lines=1
40
+ )
41
+
42
+ repaint_viz_plot = gr.Plot()
43
+
44
+ with gr.Accordion("📋 Item Metadata (JSON)", open=False):
45
+ item_info_json = gr.Code(
46
+ label="Complete Item Information",
47
+ language="json",
48
+ interactive=False,
49
+ lines=15
50
+ )
51
+
52
+ with gr.Row(equal_height=True):
53
+ item_src_audio = gr.Audio(
54
+ label="Source Audio",
55
+ type="filepath",
56
+ interactive=False,
57
+ scale=8
58
+ )
59
+ get_item_btn = gr.Button("🔍 Get Item", variant="secondary", interactive=False, scale=2)
60
+
61
+ with gr.Row(equal_height=True):
62
+ item_target_audio = gr.Audio(
63
+ label="Target Audio",
64
+ type="filepath",
65
+ interactive=False,
66
+ scale=8
67
+ )
68
+ item_refer_audio = gr.Audio(
69
+ label="Reference Audio",
70
+ type="filepath",
71
+ interactive=False,
72
+ scale=2
73
+ )
74
+
75
+ with gr.Row():
76
+ use_src_checkbox = gr.Checkbox(
77
+ label="Use Source Audio from Dataset",
78
+ value=True,
79
+ info="Check to use the source audio from dataset"
80
+ )
81
+
82
+ data_status = gr.Textbox(label="📊 Data Status", interactive=False, value="❌ No dataset imported")
83
+ auto_fill_btn = gr.Button("📋 Auto-fill Generation Form", variant="primary")
84
+
85
+ return {
86
+ "dataset_type": dataset_type,
87
+ "import_dataset_btn": import_dataset_btn,
88
+ "search_type": search_type,
89
+ "search_value": search_value,
90
+ "instruction_display": instruction_display,
91
+ "repaint_viz_plot": repaint_viz_plot,
92
+ "item_info_json": item_info_json,
93
+ "item_src_audio": item_src_audio,
94
+ "get_item_btn": get_item_btn,
95
+ "item_target_audio": item_target_audio,
96
+ "item_refer_audio": item_refer_audio,
97
+ "use_src_checkbox": use_src_checkbox,
98
+ "data_status": data_status,
99
+ "auto_fill_btn": auto_fill_btn,
100
+ }
101
+
code/acestep/gradio_ui/interfaces/generation.py ADDED
@@ -0,0 +1,766 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Generation Section Module
3
+ Contains generation section component definitions
4
+ """
5
+ import gradio as gr
6
+ from acestep.constants import (
7
+ VALID_LANGUAGES,
8
+ TRACK_NAMES,
9
+ TASK_TYPES_TURBO,
10
+ TASK_TYPES_BASE,
11
+ DEFAULT_DIT_INSTRUCTION,
12
+ )
13
+ from acestep.gradio_ui.i18n import t
14
+
15
+
16
+ def create_generation_section(dit_handler, llm_handler, init_params=None, language='en') -> dict:
17
+ """Create generation section
18
+
19
+ Args:
20
+ dit_handler: DiT handler instance
21
+ llm_handler: LM handler instance
22
+ init_params: Dictionary containing initialization parameters and state.
23
+ If None, service will not be pre-initialized.
24
+ language: UI language code ('en', 'zh', 'ja')
25
+ """
26
+ # Check if service is pre-initialized
27
+ service_pre_initialized = init_params is not None and init_params.get('pre_initialized', False)
28
+
29
+ # Check if running in service mode (restricted UI)
30
+ service_mode = init_params is not None and init_params.get('service_mode', False)
31
+
32
+ # Get current language from init_params if available
33
+ current_language = init_params.get('language', language) if init_params else language
34
+
35
+ with gr.Group():
36
+ # Service Configuration - collapse if pre-initialized, hide if in service mode
37
+ accordion_open = not service_pre_initialized
38
+ accordion_visible = not service_pre_initialized # Hide when running in service mode
39
+ with gr.Accordion(t("service.title"), open=accordion_open, visible=accordion_visible) as service_config_accordion:
40
+ # Language selector at the top
41
+ with gr.Row():
42
+ language_dropdown = gr.Dropdown(
43
+ choices=[
44
+ ("English", "en"),
45
+ ("中文", "zh"),
46
+ ("日本語", "ja"),
47
+ ],
48
+ value=current_language,
49
+ label=t("service.language_label"),
50
+ info=t("service.language_info"),
51
+ scale=1,
52
+ )
53
+
54
+ # Dropdown options section - all dropdowns grouped together
55
+ with gr.Row(equal_height=True):
56
+ with gr.Column(scale=4):
57
+ # Set checkpoint value from init_params if pre-initialized
58
+ checkpoint_value = init_params.get('checkpoint') if service_pre_initialized else None
59
+ checkpoint_dropdown = gr.Dropdown(
60
+ label=t("service.checkpoint_label"),
61
+ choices=dit_handler.get_available_checkpoints(),
62
+ value=checkpoint_value,
63
+ info=t("service.checkpoint_info")
64
+ )
65
+ with gr.Column(scale=1, min_width=90):
66
+ refresh_btn = gr.Button(t("service.refresh_btn"), size="sm")
67
+
68
+ with gr.Row():
69
+ # Get available acestep-v15- model list
70
+ available_models = dit_handler.get_available_acestep_v15_models()
71
+ default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)
72
+
73
+ # Set config_path value from init_params if pre-initialized
74
+ config_path_value = init_params.get('config_path', default_model) if service_pre_initialized else default_model
75
+ config_path = gr.Dropdown(
76
+ label=t("service.model_path_label"),
77
+ choices=available_models,
78
+ value=config_path_value,
79
+ info=t("service.model_path_info")
80
+ )
81
+ # Set device value from init_params if pre-initialized
82
+ device_value = init_params.get('device', 'auto') if service_pre_initialized else 'auto'
83
+ device = gr.Dropdown(
84
+ choices=["auto", "cuda", "cpu"],
85
+ value=device_value,
86
+ label=t("service.device_label"),
87
+ info=t("service.device_info")
88
+ )
89
+
90
+ with gr.Row():
91
+ # Get available 5Hz LM model list
92
+ available_lm_models = llm_handler.get_available_5hz_lm_models()
93
+ default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)
94
+
95
+ # Set lm_model_path value from init_params if pre-initialized
96
+ lm_model_path_value = init_params.get('lm_model_path', default_lm_model) if service_pre_initialized else default_lm_model
97
+ lm_model_path = gr.Dropdown(
98
+ label=t("service.lm_model_path_label"),
99
+ choices=available_lm_models,
100
+ value=lm_model_path_value,
101
+ info=t("service.lm_model_path_info")
102
+ )
103
+ # Set backend value from init_params if pre-initialized
104
+ backend_value = init_params.get('backend', 'vllm') if service_pre_initialized else 'vllm'
105
+ backend_dropdown = gr.Dropdown(
106
+ choices=["vllm", "pt"],
107
+ value=backend_value,
108
+ label=t("service.backend_label"),
109
+ info=t("service.backend_info")
110
+ )
111
+
112
+ # Checkbox options section - all checkboxes grouped together
113
+ with gr.Row():
114
+ # Set init_llm value from init_params if pre-initialized
115
+ init_llm_value = init_params.get('init_llm', True) if service_pre_initialized else True
116
+ init_llm_checkbox = gr.Checkbox(
117
+ label=t("service.init_llm_label"),
118
+ value=init_llm_value,
119
+ info=t("service.init_llm_info"),
120
+ )
121
+ # Auto-detect flash attention availability
122
+ flash_attn_available = dit_handler.is_flash_attention_available()
123
+ # Set use_flash_attention value from init_params if pre-initialized
124
+ use_flash_attention_value = init_params.get('use_flash_attention', flash_attn_available) if service_pre_initialized else flash_attn_available
125
+ use_flash_attention_checkbox = gr.Checkbox(
126
+ label=t("service.flash_attention_label"),
127
+ value=use_flash_attention_value,
128
+ interactive=flash_attn_available,
129
+ info=t("service.flash_attention_info_enabled") if flash_attn_available else t("service.flash_attention_info_disabled")
130
+ )
131
+ # Set offload_to_cpu value from init_params if pre-initialized
132
+ offload_to_cpu_value = init_params.get('offload_to_cpu', False) if service_pre_initialized else False
133
+ offload_to_cpu_checkbox = gr.Checkbox(
134
+ label=t("service.offload_cpu_label"),
135
+ value=offload_to_cpu_value,
136
+ info=t("service.offload_cpu_info")
137
+ )
138
+ # Set offload_dit_to_cpu value from init_params if pre-initialized
139
+ offload_dit_to_cpu_value = init_params.get('offload_dit_to_cpu', False) if service_pre_initialized else False
140
+ offload_dit_to_cpu_checkbox = gr.Checkbox(
141
+ label=t("service.offload_dit_cpu_label"),
142
+ value=offload_dit_to_cpu_value,
143
+ info=t("service.offload_dit_cpu_info")
144
+ )
145
+
146
+ init_btn = gr.Button(t("service.init_btn"), variant="primary", size="lg")
147
+ # Set init_status value from init_params if pre-initialized
148
+ init_status_value = init_params.get('init_status', '') if service_pre_initialized else ''
149
+ init_status = gr.Textbox(label=t("service.status_label"), interactive=False, lines=3, value=init_status_value)
150
+
151
+ # LoRA Configuration Section
152
+ gr.HTML("<hr><h4>🔧 LoRA Adapter</h4>")
153
+ with gr.Row():
154
+ lora_path = gr.Textbox(
155
+ label="LoRA Path",
156
+ placeholder="./lora_output/final/adapter",
157
+ info="Path to trained LoRA adapter directory",
158
+ scale=3,
159
+ )
160
+ load_lora_btn = gr.Button("📥 Load LoRA", variant="secondary", scale=1)
161
+ unload_lora_btn = gr.Button("🗑️ Unload", variant="secondary", scale=1)
162
+ with gr.Row():
163
+ use_lora_checkbox = gr.Checkbox(
164
+ label="Use LoRA",
165
+ value=False,
166
+ info="Enable LoRA adapter for inference",
167
+ scale=1,
168
+ )
169
+ lora_status = gr.Textbox(
170
+ label="LoRA Status",
171
+ value="No LoRA loaded",
172
+ interactive=False,
173
+ scale=2,
174
+ )
175
+
176
+ # Inputs
177
+ with gr.Row():
178
+ with gr.Column(scale=2):
179
+ with gr.Accordion(t("generation.required_inputs"), open=True):
180
+ # Task type
181
+ # Determine initial task_type choices based on actual model in use
182
+ # When service is pre-initialized, use config_path from init_params
183
+ actual_model = init_params.get('config_path', default_model) if service_pre_initialized else default_model
184
+ actual_model_lower = (actual_model or "").lower()
185
+ if "turbo" in actual_model_lower:
186
+ initial_task_choices = TASK_TYPES_TURBO
187
+ else:
188
+ initial_task_choices = TASK_TYPES_BASE
189
+
190
+ with gr.Row(equal_height=True):
191
+ with gr.Column(scale=2):
192
+ task_type = gr.Dropdown(
193
+ choices=initial_task_choices,
194
+ value="text2music",
195
+ label=t("generation.task_type_label"),
196
+ info=t("generation.task_type_info"),
197
+ )
198
+ with gr.Column(scale=7):
199
+ instruction_display_gen = gr.Textbox(
200
+ label=t("generation.instruction_label"),
201
+ value=DEFAULT_DIT_INSTRUCTION,
202
+ interactive=False,
203
+ lines=1,
204
+ info=t("generation.instruction_info"),
205
+ )
206
+ with gr.Column(scale=1, min_width=100):
207
+ load_file = gr.UploadButton(
208
+ t("generation.load_btn"),
209
+ file_types=[".json"],
210
+ file_count="single",
211
+ variant="secondary",
212
+ size="sm",
213
+ )
214
+
215
+ track_name = gr.Dropdown(
216
+ choices=TRACK_NAMES,
217
+ value=None,
218
+ label=t("generation.track_name_label"),
219
+ info=t("generation.track_name_info"),
220
+ visible=False
221
+ )
222
+
223
+ complete_track_classes = gr.CheckboxGroup(
224
+ choices=TRACK_NAMES,
225
+ label=t("generation.track_classes_label"),
226
+ info=t("generation.track_classes_info"),
227
+ visible=False
228
+ )
229
+
230
+ # Audio uploads
231
+ audio_uploads_accordion = gr.Accordion(t("generation.audio_uploads"), open=False)
232
+ with audio_uploads_accordion:
233
+ with gr.Row(equal_height=True):
234
+ with gr.Column(scale=2):
235
+ reference_audio = gr.Audio(
236
+ label=t("generation.reference_audio"),
237
+ type="filepath",
238
+ )
239
+ with gr.Column(scale=7):
240
+ src_audio = gr.Audio(
241
+ label=t("generation.source_audio"),
242
+ type="filepath",
243
+ )
244
+ with gr.Column(scale=1, min_width=80):
245
+ convert_src_to_codes_btn = gr.Button(
246
+ t("generation.convert_codes_btn"),
247
+ variant="secondary",
248
+ size="sm"
249
+ )
250
+
251
+ # Audio Codes for text2music - single input for transcription or cover task
252
+ with gr.Accordion(t("generation.lm_codes_hints"), open=False, visible=True) as text2music_audio_codes_group:
253
+ with gr.Row(equal_height=True):
254
+ text2music_audio_code_string = gr.Textbox(
255
+ label=t("generation.lm_codes_label"),
256
+ placeholder=t("generation.lm_codes_placeholder"),
257
+ lines=6,
258
+ info=t("generation.lm_codes_info"),
259
+ scale=9,
260
+ )
261
+ transcribe_btn = gr.Button(
262
+ t("generation.transcribe_btn"),
263
+ variant="secondary",
264
+ size="sm",
265
+ scale=1,
266
+ )
267
+
268
+ # Repainting controls
269
+ with gr.Group(visible=False) as repainting_group:
270
+ gr.HTML(f"<h5>{t('generation.repainting_controls')}</h5>")
271
+ with gr.Row():
272
+ repainting_start = gr.Number(
273
+ label=t("generation.repainting_start"),
274
+ value=0.0,
275
+ step=0.1,
276
+ )
277
+ repainting_end = gr.Number(
278
+ label=t("generation.repainting_end"),
279
+ value=-1,
280
+ minimum=-1,
281
+ step=0.1,
282
+ )
283
+
284
+ # Simple/Custom Mode Toggle
285
+ # In service mode: only Custom mode, hide the toggle
286
+ with gr.Row(visible=not service_mode):
287
+ generation_mode = gr.Radio(
288
+ choices=[
289
+ (t("generation.mode_simple"), "simple"),
290
+ (t("generation.mode_custom"), "custom"),
291
+ ],
292
+ value="custom" if service_mode else "simple",
293
+ label=t("generation.mode_label"),
294
+ info=t("generation.mode_info"),
295
+ )
296
+
297
+ # Simple Mode Components - hidden in service mode
298
+ with gr.Group(visible=not service_mode) as simple_mode_group:
299
+ with gr.Row(equal_height=True):
300
+ simple_query_input = gr.Textbox(
301
+ label=t("generation.simple_query_label"),
302
+ placeholder=t("generation.simple_query_placeholder"),
303
+ lines=2,
304
+ info=t("generation.simple_query_info"),
305
+ scale=12,
306
+ )
307
+
308
+ with gr.Column(scale=1, min_width=100):
309
+ random_desc_btn = gr.Button(
310
+ "🎲",
311
+ variant="secondary",
312
+ size="sm",
313
+ scale=2
314
+ )
315
+
316
+ with gr.Row(equal_height=True):
317
+ with gr.Column(scale=1, variant="compact"):
318
+ simple_instrumental_checkbox = gr.Checkbox(
319
+ label=t("generation.instrumental_label"),
320
+ value=False,
321
+ )
322
+ with gr.Column(scale=18):
323
+ create_sample_btn = gr.Button(
324
+ t("generation.create_sample_btn"),
325
+ variant="primary",
326
+ size="lg",
327
+ )
328
+ with gr.Column(scale=1, variant="compact"):
329
+ simple_vocal_language = gr.Dropdown(
330
+ choices=VALID_LANGUAGES,
331
+ value="unknown",
332
+ allow_custom_value=True,
333
+ label=t("generation.simple_vocal_language_label"),
334
+ interactive=True,
335
+ )
336
+
337
+ # State to track if sample has been created in Simple mode
338
+ simple_sample_created = gr.State(value=False)
339
+
340
+ # Music Caption - wrapped in accordion that can be collapsed in Simple mode
341
+ # In service mode: auto-expand
342
+ with gr.Accordion(t("generation.caption_title"), open=service_mode) as caption_accordion:
343
+ with gr.Row(equal_height=True):
344
+ captions = gr.Textbox(
345
+ label=t("generation.caption_label"),
346
+ placeholder=t("generation.caption_placeholder"),
347
+ lines=3,
348
+ info=t("generation.caption_info"),
349
+ scale=12,
350
+ )
351
+ with gr.Column(scale=1, min_width=100):
352
+ sample_btn = gr.Button(
353
+ "🎲",
354
+ variant="secondary",
355
+ size="sm",
356
+ scale=2,
357
+ )
358
+ # Lyrics - wrapped in accordion that can be collapsed in Simple mode
359
+ # In service mode: auto-expand
360
+ with gr.Accordion(t("generation.lyrics_title"), open=service_mode) as lyrics_accordion:
361
+ lyrics = gr.Textbox(
362
+ label=t("generation.lyrics_label"),
363
+ placeholder=t("generation.lyrics_placeholder"),
364
+ lines=8,
365
+ info=t("generation.lyrics_info")
366
+ )
367
+
368
+ with gr.Row(variant="compact", equal_height=True):
369
+ instrumental_checkbox = gr.Checkbox(
370
+ label=t("generation.instrumental_label"),
371
+ value=False,
372
+ scale=1,
373
+ min_width=120,
374
+ container=True,
375
+ )
376
+
377
+ # 中间:语言选择 (Dropdown)
378
+ # 移除 gr.HTML hack,直接使用 label 参数,Gradio 会自动处理对齐
379
+ vocal_language = gr.Dropdown(
380
+ choices=VALID_LANGUAGES,
381
+ value="unknown",
382
+ label=t("generation.vocal_language_label"),
383
+ show_label=False,
384
+ container=True,
385
+ allow_custom_value=True,
386
+ scale=3,
387
+ )
388
+
389
+ # 右侧:格式化按钮 (Button)
390
+ # 放在同一行最右侧,操作更顺手
391
+ format_btn = gr.Button(
392
+ t("generation.format_btn"),
393
+ variant="secondary",
394
+ scale=1,
395
+ min_width=80,
396
+ )
397
+
398
+ # Optional Parameters
399
+ # In service mode: auto-expand
400
+ with gr.Accordion(t("generation.optional_params"), open=service_mode) as optional_params_accordion:
401
+ with gr.Row():
402
+ bpm = gr.Number(
403
+ label=t("generation.bpm_label"),
404
+ value=None,
405
+ step=1,
406
+ info=t("generation.bpm_info")
407
+ )
408
+ key_scale = gr.Textbox(
409
+ label=t("generation.keyscale_label"),
410
+ placeholder=t("generation.keyscale_placeholder"),
411
+ value="",
412
+ info=t("generation.keyscale_info")
413
+ )
414
+ time_signature = gr.Dropdown(
415
+ choices=["2", "3", "4", "N/A", ""],
416
+ value="",
417
+ label=t("generation.timesig_label"),
418
+ allow_custom_value=True,
419
+ info=t("generation.timesig_info")
420
+ )
421
+ audio_duration = gr.Number(
422
+ label=t("generation.duration_label"),
423
+ value=-1,
424
+ minimum=-1,
425
+ maximum=600.0,
426
+ step=0.1,
427
+ info=t("generation.duration_info")
428
+ )
429
+ batch_size_input = gr.Number(
430
+ label=t("generation.batch_size_label"),
431
+ value=2,
432
+ minimum=1,
433
+ maximum=8,
434
+ step=1,
435
+ info=t("generation.batch_size_info"),
436
+ interactive=not service_mode # Fixed in service mode
437
+ )
438
+
439
+ # Advanced Settings
440
+ # Default UI settings use turbo mode (max 20 steps, default 8, show shift with default 3)
441
+ # These will be updated after model initialization based on handler.is_turbo_model()
442
+ with gr.Accordion(t("generation.advanced_settings"), open=False):
443
+ with gr.Row():
444
+ inference_steps = gr.Slider(
445
+ minimum=1,
446
+ maximum=20,
447
+ value=8,
448
+ step=1,
449
+ label=t("generation.inference_steps_label"),
450
+ info=t("generation.inference_steps_info")
451
+ )
452
+ guidance_scale = gr.Slider(
453
+ minimum=1.0,
454
+ maximum=15.0,
455
+ value=7.0,
456
+ step=0.1,
457
+ label=t("generation.guidance_scale_label"),
458
+ info=t("generation.guidance_scale_info"),
459
+ visible=False
460
+ )
461
+ with gr.Column():
462
+ seed = gr.Textbox(
463
+ label=t("generation.seed_label"),
464
+ value="-1",
465
+ info=t("generation.seed_info")
466
+ )
467
+ random_seed_checkbox = gr.Checkbox(
468
+ label=t("generation.random_seed_label"),
469
+ value=True,
470
+ info=t("generation.random_seed_info")
471
+ )
472
+ audio_format = gr.Dropdown(
473
+ choices=["mp3", "flac"],
474
+ value="mp3",
475
+ label=t("generation.audio_format_label"),
476
+ info=t("generation.audio_format_info"),
477
+ interactive=not service_mode # Fixed in service mode
478
+ )
479
+
480
+ with gr.Row():
481
+ use_adg = gr.Checkbox(
482
+ label=t("generation.use_adg_label"),
483
+ value=False,
484
+ info=t("generation.use_adg_info"),
485
+ visible=False
486
+ )
487
+ shift = gr.Slider(
488
+ minimum=1.0,
489
+ maximum=5.0,
490
+ value=3.0,
491
+ step=0.1,
492
+ label=t("generation.shift_label"),
493
+ info=t("generation.shift_info"),
494
+ visible=True
495
+ )
496
+ infer_method = gr.Dropdown(
497
+ choices=["ode", "sde"],
498
+ value="ode",
499
+ label=t("generation.infer_method_label"),
500
+ info=t("generation.infer_method_info"),
501
+ )
502
+
503
+ with gr.Row():
504
+ custom_timesteps = gr.Textbox(
505
+ label=t("generation.custom_timesteps_label"),
506
+ placeholder="0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0",
507
+ value="",
508
+ info=t("generation.custom_timesteps_info"),
509
+ )
510
+
511
+ with gr.Row():
512
+ cfg_interval_start = gr.Slider(
513
+ minimum=0.0,
514
+ maximum=1.0,
515
+ value=0.0,
516
+ step=0.01,
517
+ label=t("generation.cfg_interval_start"),
518
+ visible=False
519
+ )
520
+ cfg_interval_end = gr.Slider(
521
+ minimum=0.0,
522
+ maximum=1.0,
523
+ value=1.0,
524
+ step=0.01,
525
+ label=t("generation.cfg_interval_end"),
526
+ visible=False
527
+ )
528
+
529
+ # LM (Language Model) Parameters
530
+ gr.HTML(f"<h4>{t('generation.lm_params_title')}</h4>")
531
+ with gr.Row():
532
+ lm_temperature = gr.Slider(
533
+ label=t("generation.lm_temperature_label"),
534
+ minimum=0.0,
535
+ maximum=2.0,
536
+ value=0.85,
537
+ step=0.1,
538
+ scale=1,
539
+ info=t("generation.lm_temperature_info")
540
+ )
541
+ lm_cfg_scale = gr.Slider(
542
+ label=t("generation.lm_cfg_scale_label"),
543
+ minimum=1.0,
544
+ maximum=3.0,
545
+ value=2.0,
546
+ step=0.1,
547
+ scale=1,
548
+ info=t("generation.lm_cfg_scale_info")
549
+ )
550
+ lm_top_k = gr.Slider(
551
+ label=t("generation.lm_top_k_label"),
552
+ minimum=0,
553
+ maximum=100,
554
+ value=0,
555
+ step=1,
556
+ scale=1,
557
+ info=t("generation.lm_top_k_info")
558
+ )
559
+ lm_top_p = gr.Slider(
560
+ label=t("generation.lm_top_p_label"),
561
+ minimum=0.0,
562
+ maximum=1.0,
563
+ value=0.9,
564
+ step=0.01,
565
+ scale=1,
566
+ info=t("generation.lm_top_p_info")
567
+ )
568
+
569
+ with gr.Row():
570
+ lm_negative_prompt = gr.Textbox(
571
+ label=t("generation.lm_negative_prompt_label"),
572
+ value="NO USER INPUT",
573
+ placeholder=t("generation.lm_negative_prompt_placeholder"),
574
+ info=t("generation.lm_negative_prompt_info"),
575
+ lines=2,
576
+ scale=2,
577
+ )
578
+
579
+ with gr.Row():
580
+ use_cot_metas = gr.Checkbox(
581
+ label=t("generation.cot_metas_label"),
582
+ value=True,
583
+ info=t("generation.cot_metas_info"),
584
+ scale=1,
585
+ )
586
+ use_cot_language = gr.Checkbox(
587
+ label=t("generation.cot_language_label"),
588
+ value=True,
589
+ info=t("generation.cot_language_info"),
590
+ scale=1,
591
+ )
592
+ constrained_decoding_debug = gr.Checkbox(
593
+ label=t("generation.constrained_debug_label"),
594
+ value=False,
595
+ info=t("generation.constrained_debug_info"),
596
+ scale=1,
597
+ interactive=not service_mode # Fixed in service mode
598
+ )
599
+
600
+ with gr.Row():
601
+ auto_score = gr.Checkbox(
602
+ label=t("generation.auto_score_label"),
603
+ value=False,
604
+ info=t("generation.auto_score_info"),
605
+ scale=1,
606
+ interactive=not service_mode # Fixed in service mode
607
+ )
608
+ auto_lrc = gr.Checkbox(
609
+ label=t("generation.auto_lrc_label"),
610
+ value=False,
611
+ info=t("generation.auto_lrc_info"),
612
+ scale=1,
613
+ interactive=not service_mode # Fixed in service mode
614
+ )
615
+ lm_batch_chunk_size = gr.Number(
616
+ label=t("generation.lm_batch_chunk_label"),
617
+ value=8,
618
+ minimum=1,
619
+ maximum=32,
620
+ step=1,
621
+ info=t("generation.lm_batch_chunk_info"),
622
+ scale=1,
623
+ interactive=not service_mode # Fixed in service mode
624
+ )
625
+
626
+ with gr.Row():
627
+ audio_cover_strength = gr.Slider(
628
+ minimum=0.0,
629
+ maximum=1.0,
630
+ value=1.0,
631
+ step=0.01,
632
+ label=t("generation.codes_strength_label"),
633
+ info=t("generation.codes_strength_info"),
634
+ scale=1,
635
+ )
636
+ score_scale = gr.Slider(
637
+ minimum=0.01,
638
+ maximum=1.0,
639
+ value=0.5,
640
+ step=0.01,
641
+ label=t("generation.score_sensitivity_label"),
642
+ info=t("generation.score_sensitivity_info"),
643
+ scale=1,
644
+ visible=not service_mode # Hidden in service mode
645
+ )
646
+
647
+ # Set generate_btn to interactive if service is pre-initialized
648
+ generate_btn_interactive = init_params.get('enable_generate', False) if service_pre_initialized else False
649
+ with gr.Row(equal_height=True):
650
+ with gr.Column(scale=1, variant="compact"):
651
+ think_checkbox = gr.Checkbox(
652
+ label=t("generation.think_label"),
653
+ value=True,
654
+ scale=1,
655
+ )
656
+ allow_lm_batch = gr.Checkbox(
657
+ label=t("generation.parallel_thinking_label"),
658
+ value=True,
659
+ scale=1,
660
+ )
661
+ with gr.Column(scale=18):
662
+ generate_btn = gr.Button(t("generation.generate_btn"), variant="primary", size="lg", interactive=generate_btn_interactive)
663
+ with gr.Column(scale=1, variant="compact"):
664
+ autogen_checkbox = gr.Checkbox(
665
+ label=t("generation.autogen_label"),
666
+ value=False, # Default to False for both service and local modes
667
+ scale=1,
668
+ interactive=not service_mode # Not selectable in service mode
669
+ )
670
+ use_cot_caption = gr.Checkbox(
671
+ label=t("generation.caption_rewrite_label"),
672
+ value=True,
673
+ scale=1,
674
+ )
675
+
676
+ return {
677
+ "service_config_accordion": service_config_accordion,
678
+ "language_dropdown": language_dropdown,
679
+ "checkpoint_dropdown": checkpoint_dropdown,
680
+ "refresh_btn": refresh_btn,
681
+ "config_path": config_path,
682
+ "device": device,
683
+ "init_btn": init_btn,
684
+ "init_status": init_status,
685
+ "lm_model_path": lm_model_path,
686
+ "init_llm_checkbox": init_llm_checkbox,
687
+ "backend_dropdown": backend_dropdown,
688
+ "use_flash_attention_checkbox": use_flash_attention_checkbox,
689
+ "offload_to_cpu_checkbox": offload_to_cpu_checkbox,
690
+ "offload_dit_to_cpu_checkbox": offload_dit_to_cpu_checkbox,
691
+ # LoRA components
692
+ "lora_path": lora_path,
693
+ "load_lora_btn": load_lora_btn,
694
+ "unload_lora_btn": unload_lora_btn,
695
+ "use_lora_checkbox": use_lora_checkbox,
696
+ "lora_status": lora_status,
697
+ "task_type": task_type,
698
+ "instruction_display_gen": instruction_display_gen,
699
+ "track_name": track_name,
700
+ "complete_track_classes": complete_track_classes,
701
+ "audio_uploads_accordion": audio_uploads_accordion,
702
+ "reference_audio": reference_audio,
703
+ "src_audio": src_audio,
704
+ "convert_src_to_codes_btn": convert_src_to_codes_btn,
705
+ "text2music_audio_code_string": text2music_audio_code_string,
706
+ "transcribe_btn": transcribe_btn,
707
+ "text2music_audio_codes_group": text2music_audio_codes_group,
708
+ "lm_temperature": lm_temperature,
709
+ "lm_cfg_scale": lm_cfg_scale,
710
+ "lm_top_k": lm_top_k,
711
+ "lm_top_p": lm_top_p,
712
+ "lm_negative_prompt": lm_negative_prompt,
713
+ "use_cot_metas": use_cot_metas,
714
+ "use_cot_caption": use_cot_caption,
715
+ "use_cot_language": use_cot_language,
716
+ "repainting_group": repainting_group,
717
+ "repainting_start": repainting_start,
718
+ "repainting_end": repainting_end,
719
+ "audio_cover_strength": audio_cover_strength,
720
+ # Simple/Custom Mode Components
721
+ "generation_mode": generation_mode,
722
+ "simple_mode_group": simple_mode_group,
723
+ "simple_query_input": simple_query_input,
724
+ "random_desc_btn": random_desc_btn,
725
+ "simple_instrumental_checkbox": simple_instrumental_checkbox,
726
+ "simple_vocal_language": simple_vocal_language,
727
+ "create_sample_btn": create_sample_btn,
728
+ "simple_sample_created": simple_sample_created,
729
+ "caption_accordion": caption_accordion,
730
+ "lyrics_accordion": lyrics_accordion,
731
+ "optional_params_accordion": optional_params_accordion,
732
+ # Existing components
733
+ "captions": captions,
734
+ "sample_btn": sample_btn,
735
+ "load_file": load_file,
736
+ "lyrics": lyrics,
737
+ "vocal_language": vocal_language,
738
+ "bpm": bpm,
739
+ "key_scale": key_scale,
740
+ "time_signature": time_signature,
741
+ "audio_duration": audio_duration,
742
+ "batch_size_input": batch_size_input,
743
+ "inference_steps": inference_steps,
744
+ "guidance_scale": guidance_scale,
745
+ "seed": seed,
746
+ "random_seed_checkbox": random_seed_checkbox,
747
+ "use_adg": use_adg,
748
+ "cfg_interval_start": cfg_interval_start,
749
+ "cfg_interval_end": cfg_interval_end,
750
+ "shift": shift,
751
+ "infer_method": infer_method,
752
+ "custom_timesteps": custom_timesteps,
753
+ "audio_format": audio_format,
754
+ "think_checkbox": think_checkbox,
755
+ "autogen_checkbox": autogen_checkbox,
756
+ "generate_btn": generate_btn,
757
+ "instrumental_checkbox": instrumental_checkbox,
758
+ "format_btn": format_btn,
759
+ "constrained_decoding_debug": constrained_decoding_debug,
760
+ "score_scale": score_scale,
761
+ "allow_lm_batch": allow_lm_batch,
762
+ "auto_score": auto_score,
763
+ "auto_lrc": auto_lrc,
764
+ "lm_batch_chunk_size": lm_batch_chunk_size,
765
+ }
766
+
code/acestep/gradio_ui/interfaces/result.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Results Section Module
3
+ Contains results display section component definitions
4
+ """
5
+ import gradio as gr
6
+ from acestep.gradio_ui.i18n import t
7
+
8
+
9
+ def create_results_section(dit_handler) -> dict:
10
+ """Create results display section"""
11
+ with gr.Accordion(t("results.title"), open=True):
12
+ # Hidden state to store LM-generated metadata
13
+ lm_metadata_state = gr.State(value=None)
14
+
15
+ # Hidden state to track if caption/metadata is from formatted source (LM/transcription)
16
+ is_format_caption_state = gr.State(value=False)
17
+
18
+ # Batch management states
19
+ current_batch_index = gr.State(value=0) # Currently displayed batch index
20
+ total_batches = gr.State(value=1) # Total number of batches generated
21
+ batch_queue = gr.State(value={}) # Dictionary storing all batch data
22
+ generation_params_state = gr.State(value={}) # Store generation parameters for next batches
23
+ is_generating_background = gr.State(value=False) # Background generation flag
24
+
25
+ # All audio components in one row with dynamic visibility
26
+ with gr.Row():
27
+ with gr.Column(visible=True) as audio_col_1:
28
+ generated_audio_1 = gr.Audio(
29
+ label=t("results.generated_music", n=1),
30
+ type="filepath",
31
+ interactive=False,
32
+ buttons=[]
33
+ )
34
+ with gr.Row(equal_height=True):
35
+ send_to_src_btn_1 = gr.Button(
36
+ t("results.send_to_src_btn"),
37
+ variant="secondary",
38
+ size="sm",
39
+ scale=1
40
+ )
41
+ save_btn_1 = gr.Button(
42
+ t("results.save_btn"),
43
+ variant="primary",
44
+ size="sm",
45
+ scale=1
46
+ )
47
+ score_btn_1 = gr.Button(
48
+ t("results.score_btn"),
49
+ variant="secondary",
50
+ size="sm",
51
+ scale=1
52
+ )
53
+ lrc_btn_1 = gr.Button(
54
+ t("results.lrc_btn"),
55
+ variant="secondary",
56
+ size="sm",
57
+ scale=1
58
+ )
59
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_1:
60
+ codes_display_1 = gr.Textbox(
61
+ label=t("results.codes_label", n=1),
62
+ interactive=False,
63
+ buttons=["copy"],
64
+ lines=4,
65
+ max_lines=4,
66
+ visible=True
67
+ )
68
+ score_display_1 = gr.Textbox(
69
+ label=t("results.quality_score_label", n=1),
70
+ interactive=False,
71
+ buttons=["copy"],
72
+ lines=6,
73
+ max_lines=6,
74
+ visible=True
75
+ )
76
+ lrc_display_1 = gr.Textbox(
77
+ label=t("results.lrc_label", n=1),
78
+ interactive=True,
79
+ buttons=["copy"],
80
+ lines=8,
81
+ max_lines=8,
82
+ visible=True
83
+ )
84
+ with gr.Column(visible=True) as audio_col_2:
85
+ generated_audio_2 = gr.Audio(
86
+ label=t("results.generated_music", n=2),
87
+ type="filepath",
88
+ interactive=False,
89
+ buttons=[]
90
+ )
91
+ with gr.Row(equal_height=True):
92
+ send_to_src_btn_2 = gr.Button(
93
+ t("results.send_to_src_btn"),
94
+ variant="secondary",
95
+ size="sm",
96
+ scale=1
97
+ )
98
+ save_btn_2 = gr.Button(
99
+ t("results.save_btn"),
100
+ variant="primary",
101
+ size="sm",
102
+ scale=1
103
+ )
104
+ score_btn_2 = gr.Button(
105
+ t("results.score_btn"),
106
+ variant="secondary",
107
+ size="sm",
108
+ scale=1
109
+ )
110
+ lrc_btn_2 = gr.Button(
111
+ t("results.lrc_btn"),
112
+ variant="secondary",
113
+ size="sm",
114
+ scale=1
115
+ )
116
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_2:
117
+ codes_display_2 = gr.Textbox(
118
+ label=t("results.codes_label", n=2),
119
+ interactive=False,
120
+ buttons=["copy"],
121
+ lines=4,
122
+ max_lines=4,
123
+ visible=True
124
+ )
125
+ score_display_2 = gr.Textbox(
126
+ label=t("results.quality_score_label", n=2),
127
+ interactive=False,
128
+ buttons=["copy"],
129
+ lines=6,
130
+ max_lines=6,
131
+ visible=True
132
+ )
133
+ lrc_display_2 = gr.Textbox(
134
+ label=t("results.lrc_label", n=2),
135
+ interactive=True,
136
+ buttons=["copy"],
137
+ lines=8,
138
+ max_lines=8,
139
+ visible=True
140
+ )
141
+ with gr.Column(visible=False) as audio_col_3:
142
+ generated_audio_3 = gr.Audio(
143
+ label=t("results.generated_music", n=3),
144
+ type="filepath",
145
+ interactive=False,
146
+ buttons=[]
147
+ )
148
+ with gr.Row(equal_height=True):
149
+ send_to_src_btn_3 = gr.Button(
150
+ t("results.send_to_src_btn"),
151
+ variant="secondary",
152
+ size="sm",
153
+ scale=1
154
+ )
155
+ save_btn_3 = gr.Button(
156
+ t("results.save_btn"),
157
+ variant="primary",
158
+ size="sm",
159
+ scale=1
160
+ )
161
+ score_btn_3 = gr.Button(
162
+ t("results.score_btn"),
163
+ variant="secondary",
164
+ size="sm",
165
+ scale=1
166
+ )
167
+ lrc_btn_3 = gr.Button(
168
+ t("results.lrc_btn"),
169
+ variant="secondary",
170
+ size="sm",
171
+ scale=1
172
+ )
173
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_3:
174
+ codes_display_3 = gr.Textbox(
175
+ label=t("results.codes_label", n=3),
176
+ interactive=False,
177
+ buttons=["copy"],
178
+ lines=4,
179
+ max_lines=4,
180
+ visible=True
181
+ )
182
+ score_display_3 = gr.Textbox(
183
+ label=t("results.quality_score_label", n=3),
184
+ interactive=False,
185
+ buttons=["copy"],
186
+ lines=6,
187
+ max_lines=6,
188
+ visible=True
189
+ )
190
+ lrc_display_3 = gr.Textbox(
191
+ label=t("results.lrc_label", n=3),
192
+ interactive=True,
193
+ buttons=["copy"],
194
+ lines=8,
195
+ max_lines=8,
196
+ visible=True
197
+ )
198
+ with gr.Column(visible=False) as audio_col_4:
199
+ generated_audio_4 = gr.Audio(
200
+ label=t("results.generated_music", n=4),
201
+ type="filepath",
202
+ interactive=False,
203
+ buttons=[]
204
+ )
205
+ with gr.Row(equal_height=True):
206
+ send_to_src_btn_4 = gr.Button(
207
+ t("results.send_to_src_btn"),
208
+ variant="secondary",
209
+ size="sm",
210
+ scale=1
211
+ )
212
+ save_btn_4 = gr.Button(
213
+ t("results.save_btn"),
214
+ variant="primary",
215
+ size="sm",
216
+ scale=1
217
+ )
218
+ score_btn_4 = gr.Button(
219
+ t("results.score_btn"),
220
+ variant="secondary",
221
+ size="sm",
222
+ scale=1
223
+ )
224
+ lrc_btn_4 = gr.Button(
225
+ t("results.lrc_btn"),
226
+ variant="secondary",
227
+ size="sm",
228
+ scale=1
229
+ )
230
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_4:
231
+ codes_display_4 = gr.Textbox(
232
+ label=t("results.codes_label", n=4),
233
+ interactive=False,
234
+ buttons=["copy"],
235
+ lines=4,
236
+ max_lines=4,
237
+ visible=True
238
+ )
239
+ score_display_4 = gr.Textbox(
240
+ label=t("results.quality_score_label", n=4),
241
+ interactive=False,
242
+ buttons=["copy"],
243
+ lines=6,
244
+ max_lines=6,
245
+ visible=True
246
+ )
247
+ lrc_display_4 = gr.Textbox(
248
+ label=t("results.lrc_label", n=4),
249
+ interactive=True,
250
+ buttons=["copy"],
251
+ lines=8,
252
+ max_lines=8,
253
+ visible=True
254
+ )
255
+
256
+ # Second row for batch size 5-8 (initially hidden)
257
+ with gr.Row(visible=False) as audio_row_5_8:
258
+ with gr.Column() as audio_col_5:
259
+ generated_audio_5 = gr.Audio(
260
+ label=t("results.generated_music", n=5),
261
+ type="filepath",
262
+ interactive=False,
263
+ buttons=[]
264
+ )
265
+ with gr.Row(equal_height=True):
266
+ send_to_src_btn_5 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
267
+ save_btn_5 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
268
+ score_btn_5 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
269
+ lrc_btn_5 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
270
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_5:
271
+ codes_display_5 = gr.Textbox(
272
+ label=t("results.codes_label", n=5),
273
+ interactive=False,
274
+ buttons=["copy"],
275
+ lines=4,
276
+ max_lines=4,
277
+ visible=True
278
+ )
279
+ score_display_5 = gr.Textbox(
280
+ label=t("results.quality_score_label", n=5),
281
+ interactive=False,
282
+ buttons=["copy"],
283
+ lines=6,
284
+ max_lines=6,
285
+ visible=True
286
+ )
287
+ lrc_display_5 = gr.Textbox(
288
+ label=t("results.lrc_label", n=5),
289
+ interactive=True,
290
+ buttons=["copy"],
291
+ lines=8,
292
+ max_lines=8,
293
+ visible=True
294
+ )
295
+ with gr.Column() as audio_col_6:
296
+ generated_audio_6 = gr.Audio(
297
+ label=t("results.generated_music", n=6),
298
+ type="filepath",
299
+ interactive=False,
300
+ buttons=[]
301
+ )
302
+ with gr.Row(equal_height=True):
303
+ send_to_src_btn_6 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
304
+ save_btn_6 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
305
+ score_btn_6 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
306
+ lrc_btn_6 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
307
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_6:
308
+ codes_display_6 = gr.Textbox(
309
+ label=t("results.codes_label", n=6),
310
+ interactive=False,
311
+ buttons=["copy"],
312
+ lines=4,
313
+ max_lines=4,
314
+ visible=True
315
+ )
316
+ score_display_6 = gr.Textbox(
317
+ label=t("results.quality_score_label", n=6),
318
+ interactive=False,
319
+ buttons=["copy"],
320
+ lines=6,
321
+ max_lines=6,
322
+ visible=True
323
+ )
324
+ lrc_display_6 = gr.Textbox(
325
+ label=t("results.lrc_label", n=6),
326
+ interactive=True,
327
+ buttons=["copy"],
328
+ lines=8,
329
+ max_lines=8,
330
+ visible=True
331
+ )
332
+ with gr.Column() as audio_col_7:
333
+ generated_audio_7 = gr.Audio(
334
+ label=t("results.generated_music", n=7),
335
+ type="filepath",
336
+ interactive=False,
337
+ buttons=[]
338
+ )
339
+ with gr.Row(equal_height=True):
340
+ send_to_src_btn_7 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
341
+ save_btn_7 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
342
+ score_btn_7 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
343
+ lrc_btn_7 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
344
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_7:
345
+ codes_display_7 = gr.Textbox(
346
+ label=t("results.codes_label", n=7),
347
+ interactive=False,
348
+ buttons=["copy"],
349
+ lines=4,
350
+ max_lines=4,
351
+ visible=True
352
+ )
353
+ score_display_7 = gr.Textbox(
354
+ label=t("results.quality_score_label", n=7),
355
+ interactive=False,
356
+ buttons=["copy"],
357
+ lines=6,
358
+ max_lines=6,
359
+ visible=True
360
+ )
361
+ lrc_display_7 = gr.Textbox(
362
+ label=t("results.lrc_label", n=7),
363
+ interactive=True,
364
+ buttons=["copy"],
365
+ lines=8,
366
+ max_lines=8,
367
+ visible=True
368
+ )
369
+ with gr.Column() as audio_col_8:
370
+ generated_audio_8 = gr.Audio(
371
+ label=t("results.generated_music", n=8),
372
+ type="filepath",
373
+ interactive=False,
374
+ buttons=[]
375
+ )
376
+ with gr.Row(equal_height=True):
377
+ send_to_src_btn_8 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
378
+ save_btn_8 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
379
+ score_btn_8 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
380
+ lrc_btn_8 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
381
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_8:
382
+ codes_display_8 = gr.Textbox(
383
+ label=t("results.codes_label", n=8),
384
+ interactive=False,
385
+ buttons=["copy"],
386
+ lines=4,
387
+ max_lines=4,
388
+ visible=True
389
+ )
390
+ score_display_8 = gr.Textbox(
391
+ label=t("results.quality_score_label", n=8),
392
+ interactive=False,
393
+ buttons=["copy"],
394
+ lines=6,
395
+ max_lines=6,
396
+ visible=True
397
+ )
398
+ lrc_display_8 = gr.Textbox(
399
+ label=t("results.lrc_label", n=8),
400
+ interactive=True,
401
+ buttons=["copy"],
402
+ lines=8,
403
+ max_lines=8,
404
+ visible=True
405
+ )
406
+
407
+ status_output = gr.Textbox(label=t("results.generation_status"), interactive=False)
408
+
409
+ # Batch navigation controls
410
+ with gr.Row(equal_height=True):
411
+ prev_batch_btn = gr.Button(
412
+ t("results.prev_btn"),
413
+ variant="secondary",
414
+ interactive=False,
415
+ scale=1,
416
+ size="sm"
417
+ )
418
+ batch_indicator = gr.Textbox(
419
+ label=t("results.current_batch"),
420
+ value=t("results.batch_indicator", current=1, total=1),
421
+ interactive=False,
422
+ scale=3
423
+ )
424
+ next_batch_status = gr.Textbox(
425
+ label=t("results.next_batch_status"),
426
+ value="",
427
+ interactive=False,
428
+ scale=3
429
+ )
430
+ next_batch_btn = gr.Button(
431
+ t("results.next_btn"),
432
+ variant="primary",
433
+ interactive=False,
434
+ scale=1,
435
+ size="sm"
436
+ )
437
+
438
+ # One-click restore parameters button
439
+ restore_params_btn = gr.Button(
440
+ t("results.restore_params_btn"),
441
+ variant="secondary",
442
+ interactive=False, # Initially disabled, enabled after generation
443
+ size="sm"
444
+ )
445
+
446
+ with gr.Accordion(t("results.batch_results_title"), open=False):
447
+ generated_audio_batch = gr.File(
448
+ label=t("results.all_files_label"),
449
+ file_count="multiple",
450
+ interactive=False
451
+ )
452
+ generation_info = gr.Markdown(label=t("results.generation_details"))
453
+
454
+ return {
455
+ "lm_metadata_state": lm_metadata_state,
456
+ "is_format_caption_state": is_format_caption_state,
457
+ "current_batch_index": current_batch_index,
458
+ "total_batches": total_batches,
459
+ "batch_queue": batch_queue,
460
+ "generation_params_state": generation_params_state,
461
+ "is_generating_background": is_generating_background,
462
+ "status_output": status_output,
463
+ "prev_batch_btn": prev_batch_btn,
464
+ "batch_indicator": batch_indicator,
465
+ "next_batch_btn": next_batch_btn,
466
+ "next_batch_status": next_batch_status,
467
+ "restore_params_btn": restore_params_btn,
468
+ "generated_audio_1": generated_audio_1,
469
+ "generated_audio_2": generated_audio_2,
470
+ "generated_audio_3": generated_audio_3,
471
+ "generated_audio_4": generated_audio_4,
472
+ "generated_audio_5": generated_audio_5,
473
+ "generated_audio_6": generated_audio_6,
474
+ "generated_audio_7": generated_audio_7,
475
+ "generated_audio_8": generated_audio_8,
476
+ "audio_row_5_8": audio_row_5_8,
477
+ "audio_col_1": audio_col_1,
478
+ "audio_col_2": audio_col_2,
479
+ "audio_col_3": audio_col_3,
480
+ "audio_col_4": audio_col_4,
481
+ "audio_col_5": audio_col_5,
482
+ "audio_col_6": audio_col_6,
483
+ "audio_col_7": audio_col_7,
484
+ "audio_col_8": audio_col_8,
485
+ "send_to_src_btn_1": send_to_src_btn_1,
486
+ "send_to_src_btn_2": send_to_src_btn_2,
487
+ "send_to_src_btn_3": send_to_src_btn_3,
488
+ "send_to_src_btn_4": send_to_src_btn_4,
489
+ "send_to_src_btn_5": send_to_src_btn_5,
490
+ "send_to_src_btn_6": send_to_src_btn_6,
491
+ "send_to_src_btn_7": send_to_src_btn_7,
492
+ "send_to_src_btn_8": send_to_src_btn_8,
493
+ "save_btn_1": save_btn_1,
494
+ "save_btn_2": save_btn_2,
495
+ "save_btn_3": save_btn_3,
496
+ "save_btn_4": save_btn_4,
497
+ "save_btn_5": save_btn_5,
498
+ "save_btn_6": save_btn_6,
499
+ "save_btn_7": save_btn_7,
500
+ "save_btn_8": save_btn_8,
501
+ "score_btn_1": score_btn_1,
502
+ "score_btn_2": score_btn_2,
503
+ "score_btn_3": score_btn_3,
504
+ "score_btn_4": score_btn_4,
505
+ "score_btn_5": score_btn_5,
506
+ "score_btn_6": score_btn_6,
507
+ "score_btn_7": score_btn_7,
508
+ "score_btn_8": score_btn_8,
509
+ "score_display_1": score_display_1,
510
+ "score_display_2": score_display_2,
511
+ "score_display_3": score_display_3,
512
+ "score_display_4": score_display_4,
513
+ "score_display_5": score_display_5,
514
+ "score_display_6": score_display_6,
515
+ "score_display_7": score_display_7,
516
+ "score_display_8": score_display_8,
517
+ "codes_display_1": codes_display_1,
518
+ "codes_display_2": codes_display_2,
519
+ "codes_display_3": codes_display_3,
520
+ "codes_display_4": codes_display_4,
521
+ "codes_display_5": codes_display_5,
522
+ "codes_display_6": codes_display_6,
523
+ "codes_display_7": codes_display_7,
524
+ "codes_display_8": codes_display_8,
525
+ "lrc_btn_1": lrc_btn_1,
526
+ "lrc_btn_2": lrc_btn_2,
527
+ "lrc_btn_3": lrc_btn_3,
528
+ "lrc_btn_4": lrc_btn_4,
529
+ "lrc_btn_5": lrc_btn_5,
530
+ "lrc_btn_6": lrc_btn_6,
531
+ "lrc_btn_7": lrc_btn_7,
532
+ "lrc_btn_8": lrc_btn_8,
533
+ "lrc_display_1": lrc_display_1,
534
+ "lrc_display_2": lrc_display_2,
535
+ "lrc_display_3": lrc_display_3,
536
+ "lrc_display_4": lrc_display_4,
537
+ "lrc_display_5": lrc_display_5,
538
+ "lrc_display_6": lrc_display_6,
539
+ "lrc_display_7": lrc_display_7,
540
+ "lrc_display_8": lrc_display_8,
541
+ "details_accordion_1": details_accordion_1,
542
+ "details_accordion_2": details_accordion_2,
543
+ "details_accordion_3": details_accordion_3,
544
+ "details_accordion_4": details_accordion_4,
545
+ "details_accordion_5": details_accordion_5,
546
+ "details_accordion_6": details_accordion_6,
547
+ "details_accordion_7": details_accordion_7,
548
+ "details_accordion_8": details_accordion_8,
549
+ "generated_audio_batch": generated_audio_batch,
550
+ "generation_info": generation_info,
551
+ }
552
+
code/acestep/gradio_ui/interfaces/training.py ADDED
@@ -0,0 +1,562 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Training Tab Module
3
+
4
+ Contains the dataset builder and LoRA training interface components.
5
+ """
6
+
7
+ import os
8
+ import gradio as gr
9
+ from acestep.gradio_ui.i18n import t
10
+
11
+
12
+ def create_training_section(dit_handler, llm_handler, init_params=None) -> dict:
13
+ """Create the training tab section with dataset builder and training controls.
14
+
15
+ Args:
16
+ dit_handler: DiT handler instance
17
+ llm_handler: LLM handler instance
18
+ init_params: Dictionary containing initialization parameters and state.
19
+ If None, service will not be pre-initialized.
20
+
21
+ Returns:
22
+ Dictionary of Gradio components for event handling
23
+ """
24
+ # Check if running in service mode (hide training tab)
25
+ service_mode = init_params is not None and init_params.get('service_mode', False)
26
+
27
+ with gr.Tab("🎓 LoRA Training", visible=not service_mode):
28
+ gr.HTML("""
29
+ <div style="text-align: center; padding: 10px; margin-bottom: 15px;">
30
+ <h2>🎵 LoRA Training for ACE-Step</h2>
31
+ <p>Build datasets from your audio files and train custom LoRA adapters</p>
32
+ </div>
33
+ """)
34
+
35
+ with gr.Tabs():
36
+ # ==================== Dataset Builder Tab ====================
37
+ with gr.Tab("📁 Dataset Builder"):
38
+ # ========== Load Existing OR Scan New ==========
39
+ gr.HTML("""
40
+ <div style="padding: 10px; margin-bottom: 10px; border: 1px solid #4a4a6a; border-radius: 8px; background: linear-gradient(135deg, #2a2a4a 0%, #1a1a3a 100%);">
41
+ <h3 style="margin: 0 0 5px 0;">🚀 Quick Start</h3>
42
+ <p style="margin: 0; color: #aaa;">Choose one: <b>Load existing dataset</b> OR <b>Scan new directory</b></p>
43
+ </div>
44
+ """)
45
+
46
+ with gr.Row():
47
+ with gr.Column(scale=1):
48
+ gr.HTML("<h4>📂 Load Existing Dataset</h4>")
49
+ with gr.Row():
50
+ load_json_path = gr.Textbox(
51
+ label="Dataset JSON Path",
52
+ placeholder="./datasets/my_lora_dataset.json",
53
+ info="Load a previously saved dataset",
54
+ scale=3,
55
+ )
56
+ load_json_btn = gr.Button("📂 Load", variant="primary", scale=1)
57
+ load_json_status = gr.Textbox(
58
+ label="Load Status",
59
+ interactive=False,
60
+ )
61
+
62
+ with gr.Column(scale=1):
63
+ gr.HTML("<h4>🔍 Scan New Directory</h4>")
64
+ with gr.Row():
65
+ audio_directory = gr.Textbox(
66
+ label="Audio Directory Path",
67
+ placeholder="/path/to/your/audio/folder",
68
+ info="Scan for audio files (wav, mp3, flac, ogg, opus)",
69
+ scale=3,
70
+ )
71
+ scan_btn = gr.Button("🔍 Scan", variant="secondary", scale=1)
72
+ scan_status = gr.Textbox(
73
+ label="Scan Status",
74
+ interactive=False,
75
+ )
76
+
77
+ gr.HTML("<hr>")
78
+
79
+ with gr.Row():
80
+ with gr.Column(scale=2):
81
+
82
+ # Audio files table
83
+ audio_files_table = gr.Dataframe(
84
+ headers=["#", "Filename", "Duration", "Labeled", "BPM", "Key", "Caption"],
85
+ datatype=["number", "str", "str", "str", "str", "str", "str"],
86
+ label="Found Audio Files",
87
+ interactive=False,
88
+ wrap=True,
89
+ )
90
+
91
+ with gr.Column(scale=1):
92
+ gr.HTML("<h3>⚙️ Dataset Settings</h3>")
93
+
94
+ dataset_name = gr.Textbox(
95
+ label="Dataset Name",
96
+ value="my_lora_dataset",
97
+ placeholder="Enter dataset name",
98
+ )
99
+
100
+ all_instrumental = gr.Checkbox(
101
+ label="All Instrumental",
102
+ value=True,
103
+ info="Check if all tracks are instrumental (no vocals)",
104
+ )
105
+
106
+ need_lyrics = gr.Checkbox(
107
+ label="Transcribe Lyrics",
108
+ value=False,
109
+ info="Attempt to transcribe lyrics (slower)",
110
+ interactive=False, # Disabled for now
111
+ )
112
+
113
+ custom_tag = gr.Textbox(
114
+ label="Custom Activation Tag",
115
+ placeholder="e.g., 8bit_retro, my_style",
116
+ info="Unique tag to activate this LoRA's style",
117
+ )
118
+
119
+ tag_position = gr.Radio(
120
+ choices=[
121
+ ("Prepend (tag, caption)", "prepend"),
122
+ ("Append (caption, tag)", "append"),
123
+ ("Replace caption", "replace"),
124
+ ],
125
+ value="replace",
126
+ label="Tag Position",
127
+ info="Where to place the custom tag in the caption",
128
+ )
129
+
130
+ gr.HTML("<hr><h3>🤖 Step 2: Auto-Label with AI</h3>")
131
+
132
+ with gr.Row():
133
+ with gr.Column(scale=3):
134
+ gr.Markdown("""
135
+ Click the button below to automatically generate metadata for all audio files using AI:
136
+ - **Caption**: Music style, genre, mood description
137
+ - **BPM**: Beats per minute
138
+ - **Key**: Musical key (e.g., C Major, Am)
139
+ - **Time Signature**: 4/4, 3/4, etc.
140
+ """)
141
+ skip_metas = gr.Checkbox(
142
+ label="Skip Metas (No LLM)",
143
+ value=False,
144
+ info="Skip AI labeling. BPM/Key/Time Signature will be N/A, Language will be 'unknown' for instrumental",
145
+ )
146
+ with gr.Column(scale=1):
147
+ auto_label_btn = gr.Button(
148
+ "🏷️ Auto-Label All",
149
+ variant="primary",
150
+ size="lg",
151
+ )
152
+
153
+ label_progress = gr.Textbox(
154
+ label="Labeling Progress",
155
+ interactive=False,
156
+ lines=2,
157
+ )
158
+
159
+ gr.HTML("<hr><h3>👀 Step 3: Preview & Edit</h3>")
160
+
161
+ with gr.Row():
162
+ with gr.Column(scale=1):
163
+ sample_selector = gr.Slider(
164
+ minimum=0,
165
+ maximum=0,
166
+ step=1,
167
+ value=0,
168
+ label="Select Sample #",
169
+ info="Choose a sample to preview and edit",
170
+ )
171
+
172
+ preview_audio = gr.Audio(
173
+ label="Audio Preview",
174
+ type="filepath",
175
+ interactive=False,
176
+ )
177
+
178
+ preview_filename = gr.Textbox(
179
+ label="Filename",
180
+ interactive=False,
181
+ )
182
+
183
+ with gr.Column(scale=2):
184
+ with gr.Row():
185
+ edit_caption = gr.Textbox(
186
+ label="Caption",
187
+ lines=3,
188
+ placeholder="Music description...",
189
+ )
190
+
191
+ with gr.Row():
192
+ edit_lyrics = gr.Textbox(
193
+ label="Lyrics",
194
+ lines=4,
195
+ placeholder="[Verse 1]\nLyrics here...\n\n[Chorus]\n...",
196
+ )
197
+
198
+ with gr.Row():
199
+ edit_bpm = gr.Number(
200
+ label="BPM",
201
+ precision=0,
202
+ )
203
+ edit_keyscale = gr.Textbox(
204
+ label="Key",
205
+ placeholder="C Major",
206
+ )
207
+ edit_timesig = gr.Dropdown(
208
+ choices=["", "2", "3", "4", "6"],
209
+ label="Time Signature",
210
+ )
211
+ edit_duration = gr.Number(
212
+ label="Duration (s)",
213
+ precision=1,
214
+ interactive=False,
215
+ )
216
+
217
+ with gr.Row():
218
+ edit_language = gr.Dropdown(
219
+ choices=["instrumental", "en", "zh", "ja", "ko", "es", "fr", "de", "pt", "ru", "unknown"],
220
+ value="instrumental",
221
+ label="Language",
222
+ )
223
+ edit_instrumental = gr.Checkbox(
224
+ label="Instrumental",
225
+ value=True,
226
+ )
227
+ save_edit_btn = gr.Button("💾 Save Changes", variant="secondary")
228
+
229
+ edit_status = gr.Textbox(
230
+ label="Edit Status",
231
+ interactive=False,
232
+ )
233
+
234
+ gr.HTML("<hr><h3>💾 Step 4: Save Dataset</h3>")
235
+
236
+ with gr.Row():
237
+ with gr.Column(scale=3):
238
+ save_path = gr.Textbox(
239
+ label="Save Path",
240
+ value="./datasets/my_lora_dataset.json",
241
+ placeholder="./datasets/dataset_name.json",
242
+ info="Path where the dataset JSON will be saved",
243
+ )
244
+ with gr.Column(scale=1):
245
+ save_dataset_btn = gr.Button(
246
+ "💾 Save Dataset",
247
+ variant="primary",
248
+ size="lg",
249
+ )
250
+
251
+ save_status = gr.Textbox(
252
+ label="Save Status",
253
+ interactive=False,
254
+ lines=2,
255
+ )
256
+
257
+ gr.HTML("<hr><h3>⚡ Step 5: Preprocess to Tensors</h3>")
258
+
259
+ gr.Markdown("""
260
+ **Preprocessing converts your dataset to pre-computed tensors for fast training.**
261
+
262
+ You can either:
263
+ - Use the dataset from Steps 1-4 above, **OR**
264
+ - Load an existing dataset JSON file (if you've already saved one)
265
+ """)
266
+
267
+ with gr.Row():
268
+ with gr.Column(scale=3):
269
+ load_existing_dataset_path = gr.Textbox(
270
+ label="Load Existing Dataset (Optional)",
271
+ placeholder="./datasets/my_lora_dataset.json",
272
+ info="Path to a previously saved dataset JSON file",
273
+ )
274
+ with gr.Column(scale=1):
275
+ load_existing_dataset_btn = gr.Button(
276
+ "📂 Load Dataset",
277
+ variant="secondary",
278
+ size="lg",
279
+ )
280
+
281
+ load_existing_status = gr.Textbox(
282
+ label="Load Status",
283
+ interactive=False,
284
+ )
285
+
286
+ gr.Markdown("""
287
+ This step:
288
+ - Encodes audio to VAE latents
289
+ - Encodes captions and lyrics to text embeddings
290
+ - Runs the condition encoder
291
+ - Saves all tensors to `.pt` files
292
+
293
+ ⚠️ **This requires the model to be loaded and may take a few minutes.**
294
+ """)
295
+
296
+ with gr.Row():
297
+ with gr.Column(scale=3):
298
+ preprocess_output_dir = gr.Textbox(
299
+ label="Tensor Output Directory",
300
+ value="./datasets/preprocessed_tensors",
301
+ placeholder="./datasets/preprocessed_tensors",
302
+ info="Directory to save preprocessed tensor files",
303
+ )
304
+ with gr.Column(scale=1):
305
+ preprocess_btn = gr.Button(
306
+ "⚡ Preprocess",
307
+ variant="primary",
308
+ size="lg",
309
+ )
310
+
311
+ preprocess_progress = gr.Textbox(
312
+ label="Preprocessing Progress",
313
+ interactive=False,
314
+ lines=3,
315
+ )
316
+
317
+ # ==================== Training Tab ====================
318
+ with gr.Tab("🚀 Train LoRA"):
319
+ with gr.Row():
320
+ with gr.Column(scale=2):
321
+ gr.HTML("<h3>📊 Preprocessed Dataset Selection</h3>")
322
+
323
+ gr.Markdown("""
324
+ Select the directory containing preprocessed tensor files (`.pt` files).
325
+ These are created in the "Dataset Builder" tab using the "Preprocess" button.
326
+ """)
327
+
328
+ training_tensor_dir = gr.Textbox(
329
+ label="Preprocessed Tensors Directory",
330
+ placeholder="./datasets/preprocessed_tensors",
331
+ value="./datasets/preprocessed_tensors",
332
+ info="Directory containing preprocessed .pt tensor files",
333
+ )
334
+
335
+ load_dataset_btn = gr.Button("📂 Load Dataset", variant="secondary")
336
+
337
+ training_dataset_info = gr.Textbox(
338
+ label="Dataset Info",
339
+ interactive=False,
340
+ lines=3,
341
+ )
342
+
343
+ with gr.Column(scale=1):
344
+ gr.HTML("<h3>⚙️ LoRA Settings</h3>")
345
+
346
+ lora_rank = gr.Slider(
347
+ minimum=4,
348
+ maximum=256,
349
+ step=4,
350
+ value=64,
351
+ label="LoRA Rank (r)",
352
+ info="Higher = more capacity, more memory",
353
+ )
354
+
355
+ lora_alpha = gr.Slider(
356
+ minimum=4,
357
+ maximum=512,
358
+ step=4,
359
+ value=128,
360
+ label="LoRA Alpha",
361
+ info="Scaling factor (typically 2x rank)",
362
+ )
363
+
364
+ lora_dropout = gr.Slider(
365
+ minimum=0.0,
366
+ maximum=0.5,
367
+ step=0.05,
368
+ value=0.1,
369
+ label="LoRA Dropout",
370
+ )
371
+
372
+ gr.HTML("<hr><h3>🎛️ Training Parameters</h3>")
373
+
374
+ with gr.Row():
375
+ learning_rate = gr.Number(
376
+ label="Learning Rate",
377
+ value=1e-4,
378
+ info="Start with 1e-4, adjust if needed",
379
+ )
380
+
381
+ train_epochs = gr.Slider(
382
+ minimum=100,
383
+ maximum=4000,
384
+ step=100,
385
+ value=500,
386
+ label="Max Epochs",
387
+ )
388
+
389
+ train_batch_size = gr.Slider(
390
+ minimum=1,
391
+ maximum=8,
392
+ step=1,
393
+ value=1,
394
+ label="Batch Size",
395
+ info="Increase if you have enough VRAM",
396
+ )
397
+
398
+ gradient_accumulation = gr.Slider(
399
+ minimum=1,
400
+ maximum=16,
401
+ step=1,
402
+ value=1,
403
+ label="Gradient Accumulation",
404
+ info="Effective batch = batch_size × accumulation",
405
+ )
406
+
407
+ with gr.Row():
408
+ save_every_n_epochs = gr.Slider(
409
+ minimum=50,
410
+ maximum=1000,
411
+ step=50,
412
+ value=200,
413
+ label="Save Every N Epochs",
414
+ )
415
+
416
+ training_shift = gr.Slider(
417
+ minimum=1.0,
418
+ maximum=5.0,
419
+ step=0.5,
420
+ value=3.0,
421
+ label="Shift",
422
+ info="Timestep shift for turbo model",
423
+ )
424
+
425
+ training_seed = gr.Number(
426
+ label="Seed",
427
+ value=42,
428
+ precision=0,
429
+ )
430
+
431
+ with gr.Row():
432
+ lora_output_dir = gr.Textbox(
433
+ label="Output Directory",
434
+ value="./lora_output",
435
+ placeholder="./lora_output",
436
+ info="Directory to save trained LoRA weights",
437
+ )
438
+
439
+ gr.HTML("<hr>")
440
+
441
+ with gr.Row():
442
+ with gr.Column(scale=1):
443
+ start_training_btn = gr.Button(
444
+ "🚀 Start Training",
445
+ variant="primary",
446
+ size="lg",
447
+ )
448
+ with gr.Column(scale=1):
449
+ stop_training_btn = gr.Button(
450
+ "⏹️ Stop Training",
451
+ variant="stop",
452
+ size="lg",
453
+ )
454
+
455
+ training_progress = gr.Textbox(
456
+ label="Training Progress",
457
+ interactive=False,
458
+ lines=2,
459
+ )
460
+
461
+ with gr.Row():
462
+ training_log = gr.Textbox(
463
+ label="Training Log",
464
+ interactive=False,
465
+ lines=10,
466
+ max_lines=15,
467
+ scale=1,
468
+ )
469
+ training_loss_plot = gr.LinePlot(
470
+ x="step",
471
+ y="loss",
472
+ title="Training Loss",
473
+ x_title="Step",
474
+ y_title="Loss",
475
+ scale=1,
476
+ )
477
+
478
+ gr.HTML("<hr><h3>📦 Export LoRA</h3>")
479
+
480
+ with gr.Row():
481
+ export_path = gr.Textbox(
482
+ label="Export Path",
483
+ value="./lora_output/final_lora",
484
+ placeholder="./lora_output/my_lora",
485
+ )
486
+ export_lora_btn = gr.Button("📦 Export LoRA", variant="secondary")
487
+
488
+ export_status = gr.Textbox(
489
+ label="Export Status",
490
+ interactive=False,
491
+ )
492
+
493
+ # Store dataset builder state
494
+ dataset_builder_state = gr.State(None)
495
+ training_state = gr.State({"is_training": False, "should_stop": False})
496
+
497
+ return {
498
+ # Dataset Builder - Load or Scan
499
+ "load_json_path": load_json_path,
500
+ "load_json_btn": load_json_btn,
501
+ "load_json_status": load_json_status,
502
+ "audio_directory": audio_directory,
503
+ "scan_btn": scan_btn,
504
+ "scan_status": scan_status,
505
+ "audio_files_table": audio_files_table,
506
+ "dataset_name": dataset_name,
507
+ "all_instrumental": all_instrumental,
508
+ "need_lyrics": need_lyrics,
509
+ "custom_tag": custom_tag,
510
+ "tag_position": tag_position,
511
+ "skip_metas": skip_metas,
512
+ "auto_label_btn": auto_label_btn,
513
+ "label_progress": label_progress,
514
+ "sample_selector": sample_selector,
515
+ "preview_audio": preview_audio,
516
+ "preview_filename": preview_filename,
517
+ "edit_caption": edit_caption,
518
+ "edit_lyrics": edit_lyrics,
519
+ "edit_bpm": edit_bpm,
520
+ "edit_keyscale": edit_keyscale,
521
+ "edit_timesig": edit_timesig,
522
+ "edit_duration": edit_duration,
523
+ "edit_language": edit_language,
524
+ "edit_instrumental": edit_instrumental,
525
+ "save_edit_btn": save_edit_btn,
526
+ "edit_status": edit_status,
527
+ "save_path": save_path,
528
+ "save_dataset_btn": save_dataset_btn,
529
+ "save_status": save_status,
530
+ # Preprocessing
531
+ "load_existing_dataset_path": load_existing_dataset_path,
532
+ "load_existing_dataset_btn": load_existing_dataset_btn,
533
+ "load_existing_status": load_existing_status,
534
+ "preprocess_output_dir": preprocess_output_dir,
535
+ "preprocess_btn": preprocess_btn,
536
+ "preprocess_progress": preprocess_progress,
537
+ "dataset_builder_state": dataset_builder_state,
538
+ # Training
539
+ "training_tensor_dir": training_tensor_dir,
540
+ "load_dataset_btn": load_dataset_btn,
541
+ "training_dataset_info": training_dataset_info,
542
+ "lora_rank": lora_rank,
543
+ "lora_alpha": lora_alpha,
544
+ "lora_dropout": lora_dropout,
545
+ "learning_rate": learning_rate,
546
+ "train_epochs": train_epochs,
547
+ "train_batch_size": train_batch_size,
548
+ "gradient_accumulation": gradient_accumulation,
549
+ "save_every_n_epochs": save_every_n_epochs,
550
+ "training_shift": training_shift,
551
+ "training_seed": training_seed,
552
+ "lora_output_dir": lora_output_dir,
553
+ "start_training_btn": start_training_btn,
554
+ "stop_training_btn": stop_training_btn,
555
+ "training_progress": training_progress,
556
+ "training_log": training_log,
557
+ "training_loss_plot": training_loss_plot,
558
+ "export_path": export_path,
559
+ "export_lora_btn": export_lora_btn,
560
+ "export_status": export_status,
561
+ "training_state": training_state,
562
+ }
code/acestep/handler.py ADDED
The diff for this file is too large to render. See raw diff
 
code/acestep/inference.py ADDED
@@ -0,0 +1,1164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ACE-Step Inference API Module
3
+
4
+ This module provides a standardized inference interface for music generation,
5
+ designed for third-party integration. It offers both a simplified API and
6
+ backward-compatible Gradio UI support.
7
+ """
8
+
9
+ import math
10
+ import os
11
+ import tempfile
12
+ from typing import Optional, Union, List, Dict, Any, Tuple
13
+ from dataclasses import dataclass, field, asdict
14
+ from loguru import logger
15
+
16
+ from acestep.audio_utils import AudioSaver, generate_uuid_from_params
17
+
18
+
19
+ @dataclass
20
+ class GenerationParams:
21
+ """Configuration for music generation parameters.
22
+
23
+ Attributes:
24
+ # Text Inputs
25
+ caption: A short text prompt describing the desired music (main prompt). < 512 characters
26
+ lyrics: Lyrics for the music. Use "[Instrumental]" for instrumental songs. < 4096 characters
27
+ instrumental: If True, generate instrumental music regardless of lyrics.
28
+
29
+ # Music Metadata
30
+ bpm: BPM (beats per minute), e.g., 120. Set to None for automatic estimation. 30 ~ 300
31
+ keyscale: Musical key (e.g., "C Major", "Am"). Leave empty for auto-detection. A-G, #/♭, major/minor
32
+ timesignature: Time signature (2 for '2/4', 3 for '3/4', 4 for '4/4', 6 for '6/8'). Leave empty for auto-detection.
33
+ vocal_language: Language code for vocals, e.g., "en", "zh", "ja", or "unknown". see acestep/constants.py:VALID_LANGUAGES
34
+ duration: Target audio length in seconds. If <0 or None, model chooses automatically. 10 ~ 600
35
+
36
+ # Generation Parameters
37
+ inference_steps: Number of diffusion steps (e.g., 8 for turbo, 32–100 for base model).
38
+ guidance_scale: CFG (classifier-free guidance) strength. Higher means following the prompt more strictly. Only support for non-turbo model.
39
+ seed: Integer seed for reproducibility. -1 means use random seed each time.
40
+
41
+ # Advanced DiT Parameters
42
+ use_adg: Whether to use Adaptive Dual Guidance (only works for base model).
43
+ cfg_interval_start: Start ratio (0.0–1.0) to apply CFG.
44
+ cfg_interval_end: End ratio (0.0–1.0) to apply CFG.
45
+ shift: Timestep shift factor (default 1.0). When != 1.0, applies t = shift * t / (1 + (shift - 1) * t) to timesteps.
46
+
47
+ # Task-Specific Parameters
48
+ task_type: Type of generation task. One of: "text2music", "cover", "repaint", "lego", "extract", "complete".
49
+ reference_audio: Path to a reference audio file for style transfer or cover tasks.
50
+ src_audio: Path to a source audio file for audio-to-audio tasks.
51
+ audio_codes: Audio semantic codes as a string (advanced use, for code-control generation).
52
+ repainting_start: For repaint/lego tasks: start time in seconds for region to repaint.
53
+ repainting_end: For repaint/lego tasks: end time in seconds for region to repaint (-1 for until end).
54
+ audio_cover_strength: Strength of reference audio/codes influence (range 0.0–1.0). set smaller (0.2) for style transfer tasks.
55
+ instruction: Optional task instruction prompt. If empty, auto-generated by system.
56
+
57
+ # 5Hz Language Model Parameters for CoT reasoning
58
+ thinking: If True, enable 5Hz Language Model "Chain-of-Thought" reasoning for semantic/music metadata and codes.
59
+ lm_temperature: Sampling temperature for the LLM (0.0–2.0). Higher = more creative/varied results.
60
+ lm_cfg_scale: Classifier-free guidance scale for the LLM.
61
+ lm_top_k: LLM top-k sampling (0 = disabled).
62
+ lm_top_p: LLM top-p nucleus sampling (1.0 = disabled).
63
+ lm_negative_prompt: Negative prompt to use for LLM (for control).
64
+ use_cot_metas: Whether to let LLM generate music metadata via CoT reasoning.
65
+ use_cot_caption: Whether to let LLM rewrite or format the input caption via CoT reasoning.
66
+ use_cot_language: Whether to let LLM detect vocal language via CoT.
67
+ """
68
+ # Required Inputs
69
+ task_type: str = "text2music"
70
+ instruction: str = "Fill the audio semantic mask based on the given conditions:"
71
+
72
+ # Audio Uploads
73
+ reference_audio: Optional[str] = None
74
+ src_audio: Optional[str] = None
75
+
76
+ # LM Codes Hints
77
+ audio_codes: str = ""
78
+
79
+ # Text Inputs
80
+ caption: str = ""
81
+ lyrics: str = ""
82
+ instrumental: bool = False
83
+
84
+ # Metadata
85
+ vocal_language: str = "unknown"
86
+ bpm: Optional[int] = None
87
+ keyscale: str = ""
88
+ timesignature: str = ""
89
+ duration: float = -1.0
90
+
91
+ # Advanced Settings
92
+ inference_steps: int = 8
93
+ seed: int = -1
94
+ guidance_scale: float = 7.0
95
+ use_adg: bool = False
96
+ cfg_interval_start: float = 0.0
97
+ cfg_interval_end: float = 1.0
98
+ shift: float = 1.0
99
+ infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
100
+ # Custom timesteps (parsed from string like "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
101
+ # If provided, overrides inference_steps and shift
102
+ timesteps: Optional[List[float]] = None
103
+
104
+ repainting_start: float = 0.0
105
+ repainting_end: float = -1
106
+ audio_cover_strength: float = 1.0
107
+
108
+ # 5Hz Language Model Parameters
109
+ thinking: bool = True
110
+ lm_temperature: float = 0.85
111
+ lm_cfg_scale: float = 2.0
112
+ lm_top_k: int = 0
113
+ lm_top_p: float = 0.9
114
+ lm_negative_prompt: str = "NO USER INPUT"
115
+ use_cot_metas: bool = True
116
+ use_cot_caption: bool = True
117
+ use_cot_lyrics: bool = False # TODO: not used yet
118
+ use_cot_language: bool = True
119
+ use_constrained_decoding: bool = True
120
+
121
+ cot_bpm: Optional[int] = None
122
+ cot_keyscale: str = ""
123
+ cot_timesignature: str = ""
124
+ cot_duration: Optional[float] = None
125
+ cot_vocal_language: str = "unknown"
126
+ cot_caption: str = ""
127
+ cot_lyrics: str = ""
128
+
129
+ def to_dict(self) -> Dict[str, Any]:
130
+ """Convert config to dictionary for JSON serialization."""
131
+ return asdict(self)
132
+
133
+
134
+ @dataclass
135
+ class GenerationConfig:
136
+ """Configuration for music generation.
137
+
138
+ Attributes:
139
+ batch_size: Number of audio samples to generate
140
+ allow_lm_batch: Whether to allow batch processing in LM
141
+ use_random_seed: Whether to use random seed
142
+ seeds: Seed(s) for batch generation. Can be:
143
+ - None: Use random seeds (when use_random_seed=True) or params.seed (when use_random_seed=False)
144
+ - List[int]: List of seeds, will be padded with random seeds if fewer than batch_size
145
+ - int: Single seed value (will be converted to list and padded)
146
+ lm_batch_chunk_size: Batch chunk size for LM processing
147
+ constrained_decoding_debug: Whether to enable constrained decoding debug
148
+ audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac"
149
+ """
150
+ batch_size: int = 2
151
+ allow_lm_batch: bool = False
152
+ use_random_seed: bool = True
153
+ seeds: Optional[List[int]] = None
154
+ lm_batch_chunk_size: int = 8
155
+ constrained_decoding_debug: bool = False
156
+ audio_format: str = "flac" # Default to FLAC for fast saving
157
+
158
+ def to_dict(self) -> Dict[str, Any]:
159
+ """Convert config to dictionary for JSON serialization."""
160
+ return asdict(self)
161
+
162
+
163
+ @dataclass
164
+ class GenerationResult:
165
+ """Result of music generation.
166
+
167
+ Attributes:
168
+ # Audio Outputs
169
+ audios: List of audio dictionaries with paths, keys, params
170
+ status_message: Status message from generation
171
+ extra_outputs: Extra outputs from generation
172
+ success: Whether generation completed successfully
173
+ error: Error message if generation failed
174
+ """
175
+
176
+ # Audio Outputs
177
+ audios: List[Dict[str, Any]] = field(default_factory=list)
178
+ # Generation Information
179
+ status_message: str = ""
180
+ extra_outputs: Dict[str, Any] = field(default_factory=dict)
181
+ # Success Status
182
+ success: bool = True
183
+ error: Optional[str] = None
184
+
185
+ def to_dict(self) -> Dict[str, Any]:
186
+ """Convert result to dictionary for JSON serialization."""
187
+ return asdict(self)
188
+
189
+
190
+ @dataclass
191
+ class UnderstandResult:
192
+ """Result of music understanding from audio codes.
193
+
194
+ Attributes:
195
+ # Metadata Fields
196
+ caption: Generated caption describing the music
197
+ lyrics: Generated or extracted lyrics
198
+ bpm: Beats per minute (None if not detected)
199
+ duration: Duration in seconds (None if not detected)
200
+ keyscale: Musical key (e.g., "C Major")
201
+ language: Vocal language code (e.g., "en", "zh")
202
+ timesignature: Time signature (e.g., "4/4")
203
+
204
+ # Status
205
+ status_message: Status message from understanding
206
+ success: Whether understanding completed successfully
207
+ error: Error message if understanding failed
208
+ """
209
+ # Metadata Fields
210
+ caption: str = ""
211
+ lyrics: str = ""
212
+ bpm: Optional[int] = None
213
+ duration: Optional[float] = None
214
+ keyscale: str = ""
215
+ language: str = ""
216
+ timesignature: str = ""
217
+
218
+ # Status
219
+ status_message: str = ""
220
+ success: bool = True
221
+ error: Optional[str] = None
222
+
223
+ def to_dict(self) -> Dict[str, Any]:
224
+ """Convert result to dictionary for JSON serialization."""
225
+ return asdict(self)
226
+
227
+
228
+ def _update_metadata_from_lm(
229
+ metadata: Dict[str, Any],
230
+ bpm: Optional[int],
231
+ key_scale: str,
232
+ time_signature: str,
233
+ audio_duration: Optional[float],
234
+ vocal_language: str,
235
+ caption: str,
236
+ lyrics: str,
237
+ ) -> Tuple[Optional[int], str, str, Optional[float]]:
238
+ """Update metadata fields from LM output if not provided by user."""
239
+
240
+ if bpm is None and metadata.get('bpm'):
241
+ bpm_value = metadata.get('bpm')
242
+ if bpm_value not in ["N/A", ""]:
243
+ try:
244
+ bpm = int(bpm_value)
245
+ except (ValueError, TypeError):
246
+ pass
247
+
248
+ if not key_scale and metadata.get('keyscale'):
249
+ key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
250
+ if key_scale_value != "N/A":
251
+ key_scale = key_scale_value
252
+
253
+ if not time_signature and metadata.get('timesignature'):
254
+ time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
255
+ if time_signature_value != "N/A":
256
+ time_signature = time_signature_value
257
+
258
+ if audio_duration is None:
259
+ audio_duration_value = metadata.get('duration', -1)
260
+ if audio_duration_value not in ["N/A", ""]:
261
+ try:
262
+ audio_duration = float(audio_duration_value)
263
+ except (ValueError, TypeError):
264
+ pass
265
+
266
+ if not vocal_language and metadata.get('vocal_language'):
267
+ vocal_language = metadata.get('vocal_language')
268
+ if not caption and metadata.get('caption'):
269
+ caption = metadata.get('caption')
270
+ if not lyrics and metadata.get('lyrics'):
271
+ lyrics = metadata.get('lyrics')
272
+ return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
273
+
274
+
275
+ def generate_music(
276
+ dit_handler,
277
+ llm_handler,
278
+ params: GenerationParams,
279
+ config: GenerationConfig,
280
+ save_dir: Optional[str] = None,
281
+ progress=None,
282
+ ) -> GenerationResult:
283
+ """Generate music using ACE-Step model with optional LM reasoning.
284
+
285
+ Args:
286
+ dit_handler: Initialized DiT model handler (AceStepHandler instance)
287
+ llm_handler: Initialized LLM handler (LLMHandler instance)
288
+ params: Generation parameters (GenerationParams instance)
289
+ config: Generation configuration (GenerationConfig instance)
290
+
291
+ Returns:
292
+ GenerationResult with generated audio files and metadata
293
+ """
294
+ try:
295
+ # Phase 1: LM-based metadata and code generation (if enabled)
296
+ audio_code_string_to_use = params.audio_codes
297
+ lm_generated_metadata = None
298
+ lm_generated_audio_codes_list = []
299
+ lm_total_time_costs = {
300
+ "phase1_time": 0.0,
301
+ "phase2_time": 0.0,
302
+ "total_time": 0.0,
303
+ }
304
+
305
+ # Extract mutable copies of metadata (will be updated by LM if needed)
306
+ bpm = params.bpm
307
+ key_scale = params.keyscale
308
+ time_signature = params.timesignature
309
+ audio_duration = params.duration
310
+ dit_input_caption = params.caption
311
+ dit_input_vocal_language = params.vocal_language
312
+ dit_input_lyrics = params.lyrics
313
+ # Determine if we need to generate audio codes
314
+ # If user has provided audio_codes, we don't need to generate them
315
+ # Otherwise, check if we need audio codes (lm_dit mode) or just metas (dit mode)
316
+ user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip())
317
+
318
+ # Determine infer_type: use "llm_dit" if we need audio codes, "dit" if only metas needed
319
+ # For now, we use "llm_dit" if batch mode or if user hasn't provided codes
320
+ # Use "dit" if user has provided codes (only need metas) or if explicitly only need metas
321
+ # Note: This logic can be refined based on specific requirements
322
+ need_audio_codes = not user_provided_audio_codes
323
+
324
+ # Determine if we should use chunk-based LM generation (always use chunks for consistency)
325
+ # Determine actual batch size for chunk processing
326
+ actual_batch_size = config.batch_size if config.batch_size is not None else 1
327
+
328
+ # Prepare seeds for batch generation
329
+ # Use config.seed if provided, otherwise fallback to params.seed
330
+ # Convert config.seed (None, int, or List[int]) to format that prepare_seeds accepts
331
+ seed_for_generation = ""
332
+ if config.seeds is not None and len(config.seeds) > 0:
333
+ if isinstance(config.seeds, list):
334
+ # Convert List[int] to comma-separated string
335
+ seed_for_generation = ",".join(str(s) for s in config.seeds)
336
+
337
+ # Use dit_handler.prepare_seeds to handle seed list generation and padding
338
+ # This will handle all the logic: padding with random seeds if needed, etc.
339
+ actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)
340
+
341
+ # LM-based Chain-of-Thought reasoning
342
+ # Skip LM for cover/repaint tasks - these tasks use reference/src audio directly
343
+ # and don't need LM to generate audio codes
344
+ skip_lm_tasks = {"cover", "repaint"}
345
+
346
+ # Determine if we should use LLM
347
+ # LLM is needed for:
348
+ # 1. thinking=True: generate audio codes via LM
349
+ # 2. use_cot_caption=True: enhance/generate caption via CoT
350
+ # 3. use_cot_language=True: detect vocal language via CoT
351
+ # 4. use_cot_metas=True: fill missing metadata via CoT
352
+ need_lm_for_cot = params.use_cot_caption or params.use_cot_language or params.use_cot_metas
353
+ use_lm = (params.thinking or need_lm_for_cot) and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks
354
+ lm_status = []
355
+
356
+ if params.task_type in skip_lm_tasks:
357
+ logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly")
358
+
359
+ logger.info(f"[generate_music] LLM usage decision: thinking={params.thinking}, "
360
+ f"use_cot_caption={params.use_cot_caption}, use_cot_language={params.use_cot_language}, "
361
+ f"use_cot_metas={params.use_cot_metas}, need_lm_for_cot={need_lm_for_cot}, "
362
+ f"llm_initialized={llm_handler.llm_initialized if llm_handler else False}, use_lm={use_lm}")
363
+
364
+ if use_lm:
365
+ # Convert sampling parameters - handle None values safely
366
+ top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
367
+ top_p_value = None if not params.lm_top_p or params.lm_top_p >= 1.0 else params.lm_top_p
368
+
369
+ # Build user_metadata from user-provided values
370
+ user_metadata = {}
371
+ if bpm is not None:
372
+ try:
373
+ bpm_value = float(bpm)
374
+ if bpm_value > 0:
375
+ user_metadata['bpm'] = int(bpm_value)
376
+ except (ValueError, TypeError):
377
+ pass
378
+
379
+ if key_scale and key_scale.strip():
380
+ key_scale_clean = key_scale.strip()
381
+ if key_scale_clean.lower() not in ["n/a", ""]:
382
+ user_metadata['keyscale'] = key_scale_clean
383
+
384
+ if time_signature and time_signature.strip():
385
+ time_sig_clean = time_signature.strip()
386
+ if time_sig_clean.lower() not in ["n/a", ""]:
387
+ user_metadata['timesignature'] = time_sig_clean
388
+
389
+ if audio_duration is not None:
390
+ try:
391
+ duration_value = float(audio_duration)
392
+ if duration_value > 0:
393
+ user_metadata['duration'] = int(duration_value)
394
+ except (ValueError, TypeError):
395
+ pass
396
+
397
+ user_metadata_to_pass = user_metadata if user_metadata else None
398
+
399
+ # Determine infer_type based on whether we need audio codes
400
+ # - "llm_dit": generates both metas and audio codes (two-phase internally)
401
+ # - "dit": generates only metas (single phase)
402
+ infer_type = "llm_dit" if need_audio_codes and params.thinking else "dit"
403
+
404
+ # Use chunk size from config, or default to batch_size if not set
405
+ max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size
406
+ num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)
407
+
408
+ all_metadata_list = []
409
+ all_audio_codes_list = []
410
+
411
+ for chunk_idx in range(num_chunks):
412
+ chunk_start = chunk_idx * max_inference_batch_size
413
+ chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
414
+ chunk_size = chunk_end - chunk_start
415
+ chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None
416
+
417
+ logger.info(f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
418
+ f"(size: {chunk_size}, seeds: {chunk_seeds})")
419
+
420
+ # Use the determined infer_type
421
+ # - "llm_dit" will internally run two phases (metas + codes)
422
+ # - "dit" will only run phase 1 (metas only)
423
+ result = llm_handler.generate_with_stop_condition(
424
+ caption=params.caption or "",
425
+ lyrics=params.lyrics or "",
426
+ infer_type=infer_type,
427
+ temperature=params.lm_temperature,
428
+ cfg_scale=params.lm_cfg_scale,
429
+ negative_prompt=params.lm_negative_prompt,
430
+ top_k=top_k_value,
431
+ top_p=top_p_value,
432
+ user_metadata=user_metadata_to_pass,
433
+ use_cot_caption=params.use_cot_caption,
434
+ use_cot_language=params.use_cot_language,
435
+ use_cot_metas=params.use_cot_metas,
436
+ use_constrained_decoding=params.use_constrained_decoding,
437
+ constrained_decoding_debug=config.constrained_decoding_debug,
438
+ batch_size=chunk_size,
439
+ seeds=chunk_seeds,
440
+ progress=progress,
441
+ )
442
+
443
+ # Check if LM generation failed
444
+ if not result.get("success", False):
445
+ error_msg = result.get("error", "Unknown LM error")
446
+ lm_status.append(f"❌ LM Error: {error_msg}")
447
+ # Return early with error
448
+ return GenerationResult(
449
+ audios=[],
450
+ status_message=f"❌ LM generation failed: {error_msg}",
451
+ extra_outputs={},
452
+ success=False,
453
+ error=error_msg,
454
+ )
455
+
456
+ # Extract metadata and audio_codes from result dict
457
+ if chunk_size > 1:
458
+ metadata_list = result.get("metadata", [])
459
+ audio_codes_list = result.get("audio_codes", [])
460
+ all_metadata_list.extend(metadata_list)
461
+ all_audio_codes_list.extend(audio_codes_list)
462
+ else:
463
+ metadata = result.get("metadata", {})
464
+ audio_codes = result.get("audio_codes", "")
465
+ all_metadata_list.append(metadata)
466
+ all_audio_codes_list.append(audio_codes)
467
+
468
+ # Collect time costs from LM extra_outputs
469
+ lm_extra = result.get("extra_outputs", {})
470
+ lm_chunk_time_costs = lm_extra.get("time_costs", {})
471
+ if lm_chunk_time_costs:
472
+ # Accumulate time costs from all chunks
473
+ for key in ["phase1_time", "phase2_time", "total_time"]:
474
+ if key in lm_chunk_time_costs:
475
+ lm_total_time_costs[key] += lm_chunk_time_costs[key]
476
+
477
+ time_str = ", ".join([f"{k}: {v:.2f}s" for k, v in lm_chunk_time_costs.items()])
478
+ lm_status.append(f"✅ LM chunk {chunk_idx+1}: {time_str}")
479
+
480
+ lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
481
+ lm_generated_audio_codes_list = all_audio_codes_list
482
+
483
+ # Set audio_code_string_to_use based on infer_type
484
+ if infer_type == "llm_dit":
485
+ # If batch mode, use list; otherwise use single string
486
+ if actual_batch_size > 1:
487
+ audio_code_string_to_use = all_audio_codes_list
488
+ else:
489
+ audio_code_string_to_use = all_audio_codes_list[0] if all_audio_codes_list else ""
490
+ else:
491
+ # For "dit" mode, keep user-provided codes or empty
492
+ audio_code_string_to_use = params.audio_codes
493
+
494
+ # Update metadata from LM if not provided by user
495
+ if lm_generated_metadata:
496
+ bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics = _update_metadata_from_lm(
497
+ metadata=lm_generated_metadata,
498
+ bpm=bpm,
499
+ key_scale=key_scale,
500
+ time_signature=time_signature,
501
+ audio_duration=audio_duration,
502
+ vocal_language=dit_input_vocal_language,
503
+ caption=dit_input_caption,
504
+ lyrics=dit_input_lyrics)
505
+ if not params.bpm:
506
+ params.cot_bpm = bpm
507
+ if not params.keyscale:
508
+ params.cot_keyscale = key_scale
509
+ if not params.timesignature:
510
+ params.cot_timesignature = time_signature
511
+ if not params.duration:
512
+ params.cot_duration = audio_duration
513
+ if not params.vocal_language:
514
+ params.cot_vocal_language = vocal_language
515
+ if not params.caption:
516
+ params.cot_caption = caption
517
+ if not params.lyrics:
518
+ params.cot_lyrics = lyrics
519
+
520
+ # set cot caption and language if needed
521
+ if params.use_cot_caption:
522
+ dit_input_caption = lm_generated_metadata.get("caption", dit_input_caption)
523
+ if params.use_cot_language:
524
+ dit_input_vocal_language = lm_generated_metadata.get("vocal_language", dit_input_vocal_language)
525
+
526
+ # Phase 2: DiT music generation
527
+ # Use seed_for_generation (from config.seed or params.seed) instead of params.seed for actual generation
528
+ result = dit_handler.generate_music(
529
+ captions=dit_input_caption,
530
+ lyrics=dit_input_lyrics,
531
+ bpm=bpm,
532
+ key_scale=key_scale,
533
+ time_signature=time_signature,
534
+ vocal_language=dit_input_vocal_language,
535
+ inference_steps=params.inference_steps,
536
+ guidance_scale=params.guidance_scale,
537
+ use_random_seed=config.use_random_seed,
538
+ seed=seed_for_generation, # Use config.seed (or params.seed fallback) instead of params.seed directly
539
+ reference_audio=params.reference_audio,
540
+ audio_duration=audio_duration,
541
+ batch_size=config.batch_size if config.batch_size is not None else 1,
542
+ src_audio=params.src_audio,
543
+ audio_code_string=audio_code_string_to_use,
544
+ repainting_start=params.repainting_start,
545
+ repainting_end=params.repainting_end,
546
+ instruction=params.instruction,
547
+ audio_cover_strength=params.audio_cover_strength,
548
+ task_type=params.task_type,
549
+ use_adg=params.use_adg,
550
+ cfg_interval_start=params.cfg_interval_start,
551
+ cfg_interval_end=params.cfg_interval_end,
552
+ shift=params.shift,
553
+ infer_method=params.infer_method,
554
+ timesteps=params.timesteps,
555
+ progress=progress,
556
+ )
557
+
558
+ # Check if generation failed
559
+ if not result.get("success", False):
560
+ return GenerationResult(
561
+ audios=[],
562
+ status_message=result.get("status_message", ""),
563
+ extra_outputs={},
564
+ success=False,
565
+ error=result.get("error"),
566
+ )
567
+
568
+ # Extract results from dit_handler.generate_music dict
569
+ dit_audios = result.get("audios", [])
570
+ status_message = result.get("status_message", "")
571
+ dit_extra_outputs = result.get("extra_outputs", {})
572
+
573
+ # Use the seed list already prepared above (from config.seed or params.seed fallback)
574
+ # actual_seed_list was computed earlier using dit_handler.prepare_seeds
575
+ seed_list = actual_seed_list
576
+
577
+ # Get base params dictionary
578
+ base_params_dict = params.to_dict()
579
+
580
+ # Save audio files using AudioSaver (format from config)
581
+ audio_format = config.audio_format if config.audio_format else "flac"
582
+ audio_saver = AudioSaver(default_format=audio_format)
583
+
584
+ # Use handler's temp_dir for saving files
585
+ if save_dir is not None:
586
+ os.makedirs(save_dir, exist_ok=True)
587
+
588
+ # Build audios list for GenerationResult with params and save files
589
+ # Audio saving and UUID generation handled here, outside of handler
590
+ audios = []
591
+ for idx, dit_audio in enumerate(dit_audios):
592
+ # Create a copy of params dict for this audio
593
+ audio_params = base_params_dict.copy()
594
+
595
+ # Update audio-specific values
596
+ audio_params["seed"] = seed_list[idx] if idx < len(seed_list) else None
597
+
598
+ # Add audio codes if batch mode
599
+ if lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list):
600
+ audio_params["audio_codes"] = lm_generated_audio_codes_list[idx]
601
+
602
+ # Get audio tensor and metadata
603
+ audio_tensor = dit_audio.get("tensor")
604
+ sample_rate = dit_audio.get("sample_rate", 48000)
605
+
606
+ # Generate UUID for this audio (moved from handler)
607
+ batch_seed = seed_list[idx] if idx < len(seed_list) else seed_list[0] if seed_list else -1
608
+ audio_code_str = lm_generated_audio_codes_list[idx] if (
609
+ lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list)) else audio_code_string_to_use
610
+ if isinstance(audio_code_str, list):
611
+ audio_code_str = audio_code_str[idx] if idx < len(audio_code_str) else ""
612
+
613
+ audio_key = generate_uuid_from_params(audio_params)
614
+
615
+ # Save audio file (handled outside handler)
616
+ audio_path = None
617
+ if audio_tensor is not None and save_dir is not None:
618
+ try:
619
+ audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}")
620
+ audio_path = audio_saver.save_audio(audio_tensor,
621
+ audio_file,
622
+ sample_rate=sample_rate,
623
+ format=audio_format,
624
+ channels_first=True)
625
+ except Exception as e:
626
+ logger.error(f"[generate_music] Failed to save audio file: {e}")
627
+ audio_path = "" # Fallback to empty path
628
+
629
+ audio_dict = {
630
+ "path": audio_path or "", # File path (saved here, not in handler)
631
+ "tensor": audio_tensor, # Audio tensor [channels, samples], CPU, float32
632
+ "key": audio_key,
633
+ "sample_rate": sample_rate,
634
+ "params": audio_params,
635
+ }
636
+
637
+ audios.append(audio_dict)
638
+
639
+ # Merge extra_outputs: include dit_extra_outputs (latents, masks) and add LM metadata
640
+ extra_outputs = dit_extra_outputs.copy()
641
+ extra_outputs["lm_metadata"] = lm_generated_metadata
642
+
643
+ # Merge time_costs from both LM and DiT into a unified dictionary
644
+ unified_time_costs = {}
645
+
646
+ # Add LM time costs (if LM was used)
647
+ if use_lm and lm_total_time_costs:
648
+ for key, value in lm_total_time_costs.items():
649
+ unified_time_costs[f"lm_{key}"] = value
650
+
651
+ # Add DiT time costs (if available)
652
+ dit_time_costs = dit_extra_outputs.get("time_costs", {})
653
+ if dit_time_costs:
654
+ for key, value in dit_time_costs.items():
655
+ unified_time_costs[f"dit_{key}"] = value
656
+
657
+ # Calculate total pipeline time
658
+ if unified_time_costs:
659
+ lm_total = unified_time_costs.get("lm_total_time", 0.0)
660
+ dit_total = unified_time_costs.get("dit_total_time_cost", 0.0)
661
+ unified_time_costs["pipeline_total_time"] = lm_total + dit_total
662
+
663
+ # Update extra_outputs with unified time_costs
664
+ extra_outputs["time_costs"] = unified_time_costs
665
+
666
+ if lm_status:
667
+ status_message = "\n".join(lm_status) + "\n" + status_message
668
+ else:
669
+ status_message = status_message
670
+ # Create and return GenerationResult
671
+ return GenerationResult(
672
+ audios=audios,
673
+ status_message=status_message,
674
+ extra_outputs=extra_outputs,
675
+ success=True,
676
+ error=None,
677
+ )
678
+
679
+ except Exception as e:
680
+ logger.exception("Music generation failed")
681
+ return GenerationResult(
682
+ audios=[],
683
+ status_message=f"Error: {str(e)}",
684
+ extra_outputs={},
685
+ success=False,
686
+ error=str(e),
687
+ )
688
+
689
+
690
+ def understand_music(
691
+ llm_handler,
692
+ audio_codes: str,
693
+ temperature: float = 0.85,
694
+ top_k: Optional[int] = None,
695
+ top_p: Optional[float] = None,
696
+ repetition_penalty: float = 1.0,
697
+ use_constrained_decoding: bool = True,
698
+ constrained_decoding_debug: bool = False,
699
+ ) -> UnderstandResult:
700
+ """Understand music from audio codes using the 5Hz Language Model.
701
+
702
+ This function analyzes audio semantic codes and generates metadata about the music,
703
+ including caption, lyrics, BPM, duration, key scale, language, and time signature.
704
+
705
+ If audio_codes is empty or "NO USER INPUT", the LM will generate a sample example
706
+ instead of analyzing existing codes.
707
+
708
+ Note: cfg_scale and negative_prompt are not supported in understand mode.
709
+
710
+ Args:
711
+ llm_handler: Initialized LLM handler (LLMHandler instance)
712
+ audio_codes: String of audio code tokens (e.g., "<|audio_code_123|><|audio_code_456|>...")
713
+ Use empty string or "NO USER INPUT" to generate a sample example.
714
+ temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
715
+ top_k: Top-K sampling (None or 0 = disabled)
716
+ top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
717
+ repetition_penalty: Repetition penalty (1.0 = no penalty)
718
+ use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
719
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
720
+
721
+ Returns:
722
+ UnderstandResult with parsed metadata fields and status
723
+
724
+ Example:
725
+ >>> result = understand_music(llm_handler, audio_codes="<|audio_code_123|>...")
726
+ >>> if result.success:
727
+ ... print(f"Caption: {result.caption}")
728
+ ... print(f"BPM: {result.bpm}")
729
+ ... print(f"Lyrics: {result.lyrics}")
730
+ """
731
+ # Check if LLM is initialized
732
+ if not llm_handler.llm_initialized:
733
+ return UnderstandResult(
734
+ status_message="5Hz LM not initialized. Please initialize it first.",
735
+ success=False,
736
+ error="LLM not initialized",
737
+ )
738
+
739
+ # If codes are empty, use "NO USER INPUT" to generate a sample example
740
+ if not audio_codes or not audio_codes.strip():
741
+ audio_codes = "NO USER INPUT"
742
+
743
+ try:
744
+ # Call LLM understanding
745
+ metadata, status = llm_handler.understand_audio_from_codes(
746
+ audio_codes=audio_codes,
747
+ temperature=temperature,
748
+ top_k=top_k,
749
+ top_p=top_p,
750
+ repetition_penalty=repetition_penalty,
751
+ use_constrained_decoding=use_constrained_decoding,
752
+ constrained_decoding_debug=constrained_decoding_debug,
753
+ )
754
+
755
+ # Check if LLM returned empty metadata (error case)
756
+ if not metadata:
757
+ return UnderstandResult(
758
+ status_message=status or "Failed to understand audio codes",
759
+ success=False,
760
+ error=status or "Empty metadata returned",
761
+ )
762
+
763
+ # Extract and convert fields
764
+ caption = metadata.get('caption', '')
765
+ lyrics = metadata.get('lyrics', '')
766
+ keyscale = metadata.get('keyscale', '')
767
+ language = metadata.get('language', metadata.get('vocal_language', ''))
768
+ timesignature = metadata.get('timesignature', '')
769
+
770
+ # Convert BPM to int
771
+ bpm = None
772
+ bpm_value = metadata.get('bpm')
773
+ if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
774
+ try:
775
+ bpm = int(bpm_value)
776
+ except (ValueError, TypeError):
777
+ pass
778
+
779
+ # Convert duration to float
780
+ duration = None
781
+ duration_value = metadata.get('duration')
782
+ if duration_value is not None and duration_value != 'N/A' and duration_value != '':
783
+ try:
784
+ duration = float(duration_value)
785
+ except (ValueError, TypeError):
786
+ pass
787
+
788
+ # Clean up N/A values
789
+ if keyscale == 'N/A':
790
+ keyscale = ''
791
+ if language == 'N/A':
792
+ language = ''
793
+ if timesignature == 'N/A':
794
+ timesignature = ''
795
+
796
+ return UnderstandResult(
797
+ caption=caption,
798
+ lyrics=lyrics,
799
+ bpm=bpm,
800
+ duration=duration,
801
+ keyscale=keyscale,
802
+ language=language,
803
+ timesignature=timesignature,
804
+ status_message=status,
805
+ success=True,
806
+ error=None,
807
+ )
808
+
809
+ except Exception as e:
810
+ logger.exception("Music understanding failed")
811
+ return UnderstandResult(
812
+ status_message=f"Error: {str(e)}",
813
+ success=False,
814
+ error=str(e),
815
+ )
816
+
817
+
818
+ @dataclass
819
+ class CreateSampleResult:
820
+ """Result of creating a music sample from a natural language query.
821
+
822
+ This is used by the "Simple Mode" / "Inspiration Mode" feature where users
823
+ provide a natural language description and the LLM generates a complete
824
+ sample with caption, lyrics, and metadata.
825
+
826
+ Attributes:
827
+ # Metadata Fields
828
+ caption: Generated detailed music description/caption
829
+ lyrics: Generated lyrics (or "[Instrumental]" for instrumental music)
830
+ bpm: Beats per minute (None if not generated)
831
+ duration: Duration in seconds (None if not generated)
832
+ keyscale: Musical key (e.g., "C Major")
833
+ language: Vocal language code (e.g., "en", "zh")
834
+ timesignature: Time signature (e.g., "4")
835
+ instrumental: Whether this is an instrumental piece
836
+
837
+ # Status
838
+ status_message: Status message from sample creation
839
+ success: Whether sample creation completed successfully
840
+ error: Error message if sample creation failed
841
+ """
842
+ # Metadata Fields
843
+ caption: str = ""
844
+ lyrics: str = ""
845
+ bpm: Optional[int] = None
846
+ duration: Optional[float] = None
847
+ keyscale: str = ""
848
+ language: str = ""
849
+ timesignature: str = ""
850
+ instrumental: bool = False
851
+
852
+ # Status
853
+ status_message: str = ""
854
+ success: bool = True
855
+ error: Optional[str] = None
856
+
857
+ def to_dict(self) -> Dict[str, Any]:
858
+ """Convert result to dictionary for JSON serialization."""
859
+ return asdict(self)
860
+
861
+
862
+ def create_sample(
863
+ llm_handler,
864
+ query: str,
865
+ instrumental: bool = False,
866
+ vocal_language: Optional[str] = None,
867
+ temperature: float = 0.85,
868
+ top_k: Optional[int] = None,
869
+ top_p: Optional[float] = None,
870
+ repetition_penalty: float = 1.0,
871
+ use_constrained_decoding: bool = True,
872
+ constrained_decoding_debug: bool = False,
873
+ ) -> CreateSampleResult:
874
+ """Create a music sample from a natural language query using the 5Hz Language Model.
875
+
876
+ This is the "Simple Mode" / "Inspiration Mode" feature that takes a user's natural
877
+ language description of music and generates a complete sample including:
878
+ - Detailed caption/description
879
+ - Lyrics (unless instrumental)
880
+ - Metadata (BPM, duration, key, language, time signature)
881
+
882
+ Note: cfg_scale and negative_prompt are not supported in create_sample mode.
883
+
884
+ Args:
885
+ llm_handler: Initialized LLM handler (LLMHandler instance)
886
+ query: User's natural language music description (e.g., "a soft Bengali love song")
887
+ instrumental: Whether to generate instrumental music (no vocals)
888
+ vocal_language: Allowed vocal language for constrained decoding (e.g., "en", "zh").
889
+ If provided, the model will be constrained to generate lyrics in this language.
890
+ If None or "unknown", no language constraint is applied.
891
+ temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
892
+ top_k: Top-K sampling (None or 0 = disabled)
893
+ top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
894
+ repetition_penalty: Repetition penalty (1.0 = no penalty)
895
+ use_constrained_decoding: Whether to use FSM-based constrained decoding
896
+ constrained_decoding_debug: Whether to enable debug logging
897
+
898
+ Returns:
899
+ CreateSampleResult with generated sample fields and status
900
+
901
+ Example:
902
+ >>> result = create_sample(llm_handler, "a soft Bengali love song for a quiet evening", vocal_language="bn")
903
+ >>> if result.success:
904
+ ... print(f"Caption: {result.caption}")
905
+ ... print(f"Lyrics: {result.lyrics}")
906
+ ... print(f"BPM: {result.bpm}")
907
+ """
908
+ # Check if LLM is initialized
909
+ if not llm_handler.llm_initialized:
910
+ return CreateSampleResult(
911
+ status_message="5Hz LM not initialized. Please initialize it first.",
912
+ success=False,
913
+ error="LLM not initialized",
914
+ )
915
+
916
+ try:
917
+ # Call LLM to create sample
918
+ metadata, status = llm_handler.create_sample_from_query(
919
+ query=query,
920
+ instrumental=instrumental,
921
+ vocal_language=vocal_language,
922
+ temperature=temperature,
923
+ top_k=top_k,
924
+ top_p=top_p,
925
+ repetition_penalty=repetition_penalty,
926
+ use_constrained_decoding=use_constrained_decoding,
927
+ constrained_decoding_debug=constrained_decoding_debug,
928
+ )
929
+
930
+ # Check if LLM returned empty metadata (error case)
931
+ if not metadata:
932
+ return CreateSampleResult(
933
+ status_message=status or "Failed to create sample",
934
+ success=False,
935
+ error=status or "Empty metadata returned",
936
+ )
937
+
938
+ # Extract and convert fields
939
+ caption = metadata.get('caption', '')
940
+ lyrics = metadata.get('lyrics', '')
941
+ keyscale = metadata.get('keyscale', '')
942
+ language = metadata.get('language', metadata.get('vocal_language', ''))
943
+ timesignature = metadata.get('timesignature', '')
944
+ is_instrumental = metadata.get('instrumental', instrumental)
945
+
946
+ # Convert BPM to int
947
+ bpm = None
948
+ bpm_value = metadata.get('bpm')
949
+ if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
950
+ try:
951
+ bpm = int(bpm_value)
952
+ except (ValueError, TypeError):
953
+ pass
954
+
955
+ # Convert duration to float
956
+ duration = None
957
+ duration_value = metadata.get('duration')
958
+ if duration_value is not None and duration_value != 'N/A' and duration_value != '':
959
+ try:
960
+ duration = float(duration_value)
961
+ except (ValueError, TypeError):
962
+ pass
963
+
964
+ # Clean up N/A values
965
+ if keyscale == 'N/A':
966
+ keyscale = ''
967
+ if language == 'N/A':
968
+ language = ''
969
+ if timesignature == 'N/A':
970
+ timesignature = ''
971
+
972
+ return CreateSampleResult(
973
+ caption=caption,
974
+ lyrics=lyrics,
975
+ bpm=bpm,
976
+ duration=duration,
977
+ keyscale=keyscale,
978
+ language=language,
979
+ timesignature=timesignature,
980
+ instrumental=is_instrumental,
981
+ status_message=status,
982
+ success=True,
983
+ error=None,
984
+ )
985
+
986
+ except Exception as e:
987
+ logger.exception("Sample creation failed")
988
+ return CreateSampleResult(
989
+ status_message=f"Error: {str(e)}",
990
+ success=False,
991
+ error=str(e),
992
+ )
993
+
994
+
995
+ @dataclass
996
+ class FormatSampleResult:
997
+ """Result of formatting user-provided caption and lyrics.
998
+
999
+ This is used by the "Format" feature where users provide caption and lyrics,
1000
+ and the LLM formats them into structured music metadata and an enhanced description.
1001
+
1002
+ Attributes:
1003
+ # Metadata Fields
1004
+ caption: Enhanced/formatted music description/caption
1005
+ lyrics: Formatted lyrics (may be same as input or reformatted)
1006
+ bpm: Beats per minute (None if not detected)
1007
+ duration: Duration in seconds (None if not detected)
1008
+ keyscale: Musical key (e.g., "C Major")
1009
+ language: Vocal language code (e.g., "en", "zh")
1010
+ timesignature: Time signature (e.g., "4")
1011
+
1012
+ # Status
1013
+ status_message: Status message from formatting
1014
+ success: Whether formatting completed successfully
1015
+ error: Error message if formatting failed
1016
+ """
1017
+ # Metadata Fields
1018
+ caption: str = ""
1019
+ lyrics: str = ""
1020
+ bpm: Optional[int] = None
1021
+ duration: Optional[float] = None
1022
+ keyscale: str = ""
1023
+ language: str = ""
1024
+ timesignature: str = ""
1025
+
1026
+ # Status
1027
+ status_message: str = ""
1028
+ success: bool = True
1029
+ error: Optional[str] = None
1030
+
1031
+ def to_dict(self) -> Dict[str, Any]:
1032
+ """Convert result to dictionary for JSON serialization."""
1033
+ return asdict(self)
1034
+
1035
+
1036
+ def format_sample(
1037
+ llm_handler,
1038
+ caption: str,
1039
+ lyrics: str,
1040
+ user_metadata: Optional[Dict[str, Any]] = None,
1041
+ temperature: float = 0.85,
1042
+ top_k: Optional[int] = None,
1043
+ top_p: Optional[float] = None,
1044
+ repetition_penalty: float = 1.0,
1045
+ use_constrained_decoding: bool = True,
1046
+ constrained_decoding_debug: bool = False,
1047
+ ) -> FormatSampleResult:
1048
+ """Format user-provided caption and lyrics using the 5Hz Language Model.
1049
+
1050
+ This function takes user input (caption and lyrics) and generates structured
1051
+ music metadata including an enhanced caption, BPM, duration, key, language,
1052
+ and time signature.
1053
+
1054
+ If user_metadata is provided, those values will be used to constrain the
1055
+ decoding, ensuring the output matches user-specified values.
1056
+
1057
+ Note: cfg_scale and negative_prompt are not supported in format mode.
1058
+
1059
+ Args:
1060
+ llm_handler: Initialized LLM handler (LLMHandler instance)
1061
+ caption: User's caption/description (e.g., "Latin pop, reggaeton")
1062
+ lyrics: User's lyrics with structure tags
1063
+ user_metadata: Optional dict with user-provided metadata to constrain decoding.
1064
+ Supported keys: bpm, duration, keyscale, timesignature, language
1065
+ temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
1066
+ top_k: Top-K sampling (None or 0 = disabled)
1067
+ top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
1068
+ repetition_penalty: Repetition penalty (1.0 = no penalty)
1069
+ use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
1070
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
1071
+
1072
+ Returns:
1073
+ FormatSampleResult with formatted metadata fields and status
1074
+
1075
+ Example:
1076
+ >>> result = format_sample(llm_handler, "Latin pop, reggaeton", "[Verse 1]\\nHola mundo...")
1077
+ >>> if result.success:
1078
+ ... print(f"Caption: {result.caption}")
1079
+ ... print(f"BPM: {result.bpm}")
1080
+ ... print(f"Lyrics: {result.lyrics}")
1081
+ """
1082
+ # Check if LLM is initialized
1083
+ if not llm_handler.llm_initialized:
1084
+ return FormatSampleResult(
1085
+ status_message="5Hz LM not initialized. Please initialize it first.",
1086
+ success=False,
1087
+ error="LLM not initialized",
1088
+ )
1089
+
1090
+ try:
1091
+ # Call LLM formatting
1092
+ metadata, status = llm_handler.format_sample_from_input(
1093
+ caption=caption,
1094
+ lyrics=lyrics,
1095
+ user_metadata=user_metadata,
1096
+ temperature=temperature,
1097
+ top_k=top_k,
1098
+ top_p=top_p,
1099
+ repetition_penalty=repetition_penalty,
1100
+ use_constrained_decoding=use_constrained_decoding,
1101
+ constrained_decoding_debug=constrained_decoding_debug,
1102
+ )
1103
+
1104
+ # Check if LLM returned empty metadata (error case)
1105
+ if not metadata:
1106
+ return FormatSampleResult(
1107
+ status_message=status or "Failed to format input",
1108
+ success=False,
1109
+ error=status or "Empty metadata returned",
1110
+ )
1111
+
1112
+ # Extract and convert fields
1113
+ result_caption = metadata.get('caption', '')
1114
+ result_lyrics = metadata.get('lyrics', lyrics) # Fall back to input lyrics
1115
+ keyscale = metadata.get('keyscale', '')
1116
+ language = metadata.get('language', metadata.get('vocal_language', ''))
1117
+ timesignature = metadata.get('timesignature', '')
1118
+
1119
+ # Convert BPM to int
1120
+ bpm = None
1121
+ bpm_value = metadata.get('bpm')
1122
+ if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
1123
+ try:
1124
+ bpm = int(bpm_value)
1125
+ except (ValueError, TypeError):
1126
+ pass
1127
+
1128
+ # Convert duration to float
1129
+ duration = None
1130
+ duration_value = metadata.get('duration')
1131
+ if duration_value is not None and duration_value != 'N/A' and duration_value != '':
1132
+ try:
1133
+ duration = float(duration_value)
1134
+ except (ValueError, TypeError):
1135
+ pass
1136
+
1137
+ # Clean up N/A values
1138
+ if keyscale == 'N/A':
1139
+ keyscale = ''
1140
+ if language == 'N/A':
1141
+ language = ''
1142
+ if timesignature == 'N/A':
1143
+ timesignature = ''
1144
+
1145
+ return FormatSampleResult(
1146
+ caption=result_caption,
1147
+ lyrics=result_lyrics,
1148
+ bpm=bpm,
1149
+ duration=duration,
1150
+ keyscale=keyscale,
1151
+ language=language,
1152
+ timesignature=timesignature,
1153
+ status_message=status,
1154
+ success=True,
1155
+ error=None,
1156
+ )
1157
+
1158
+ except Exception as e:
1159
+ logger.exception("Format sample failed")
1160
+ return FormatSampleResult(
1161
+ status_message=f"Error: {str(e)}",
1162
+ success=False,
1163
+ error=str(e),
1164
+ )
code/acestep/llm_inference.py ADDED
The diff for this file is too large to render. See raw diff
 
code/acestep/local_cache.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Local cache module to replace Redis
2
+
3
+ Uses diskcache as backend, provides Redis-compatible API.
4
+ Supports persistent storage and TTL expiration.
5
+ """
6
+
7
+ import json
8
+ import os
9
+ from typing import Any, Optional
10
+ from threading import Lock
11
+
12
+ try:
13
+ from diskcache import Cache
14
+ HAS_DISKCACHE = True
15
+ except ImportError:
16
+ HAS_DISKCACHE = False
17
+
18
+
19
+ class LocalCache:
20
+ """
21
+ Local cache implementation with Redis-compatible API.
22
+ Uses diskcache as backend, supports persistence and TTL.
23
+ """
24
+
25
+ _instance = None
26
+ _lock = Lock()
27
+
28
+ def __new__(cls, cache_dir: Optional[str] = None):
29
+ """Singleton pattern"""
30
+ if cls._instance is None:
31
+ with cls._lock:
32
+ if cls._instance is None:
33
+ cls._instance = super().__new__(cls)
34
+ cls._instance._initialized = False
35
+ return cls._instance
36
+
37
+ def __init__(self, cache_dir: Optional[str] = None):
38
+ if getattr(self, '_initialized', False):
39
+ return
40
+
41
+ if not HAS_DISKCACHE:
42
+ raise ImportError(
43
+ "diskcache not installed. Run: pip install diskcache"
44
+ )
45
+
46
+ if cache_dir is None:
47
+ cache_dir = os.path.join(
48
+ os.path.dirname(os.path.dirname(__file__)),
49
+ ".cache",
50
+ "local_redis"
51
+ )
52
+
53
+ os.makedirs(cache_dir, exist_ok=True)
54
+ self._cache = Cache(cache_dir)
55
+ self._initialized = True
56
+
57
+ def set(self, name: str, value: Any, ex: Optional[int] = None) -> bool:
58
+ """
59
+ Set key-value pair
60
+
61
+ Args:
62
+ name: Key name
63
+ value: Value (auto-serialize dict/list)
64
+ ex: Expiration time (seconds)
65
+
66
+ Returns:
67
+ bool: Success status
68
+ """
69
+ if isinstance(value, (dict, list)):
70
+ value = json.dumps(value, ensure_ascii=False)
71
+ self._cache.set(name, value, expire=ex)
72
+ return True
73
+
74
+ def get(self, name: str) -> Optional[str]:
75
+ """Get value"""
76
+ return self._cache.get(name)
77
+
78
+ def delete(self, name: str) -> int:
79
+ """Delete key, returns number of deleted items"""
80
+ return 1 if self._cache.delete(name) else 0
81
+
82
+ def exists(self, name: str) -> bool:
83
+ """Check if key exists"""
84
+ return name in self._cache
85
+
86
+ def keys(self, pattern: str = "*") -> list:
87
+ """
88
+ Get list of matching keys
89
+ Note: Simplified implementation, only supports prefix and full matching
90
+ """
91
+ if pattern == "*":
92
+ return list(self._cache.iterkeys())
93
+
94
+ prefix = pattern.rstrip("*")
95
+ return [k for k in self._cache.iterkeys() if k.startswith(prefix)]
96
+
97
+ def expire(self, name: str, seconds: int) -> bool:
98
+ """Set key expiration time"""
99
+ value = self._cache.get(name)
100
+ if value is not None:
101
+ self._cache.set(name, value, expire=seconds)
102
+ return True
103
+ return False
104
+
105
+ def ttl(self, name: str) -> int:
106
+ """
107
+ Get remaining time to live (seconds)
108
+ Note: diskcache does not directly support TTL queries
109
+ """
110
+ if name in self._cache:
111
+ return -1 # Exists but TTL unknown
112
+ return -2 # Key does not exist
113
+
114
+ def close(self):
115
+ """Close cache connection"""
116
+ if hasattr(self, '_cache'):
117
+ self._cache.close()
118
+
119
+
120
+ # Lazily initialized global instance
121
+ _local_cache: Optional[LocalCache] = None
122
+
123
+
124
+ def get_local_cache(cache_dir: Optional[str] = None) -> LocalCache:
125
+ """Get local cache instance"""
126
+ global _local_cache
127
+ if _local_cache is None:
128
+ _local_cache = LocalCache(cache_dir)
129
+ return _local_cache
code/acestep/test_time_scaling.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test-Time Scaling Module
3
+ Implements perplexity-based scoring for generated audio codes
4
+ """
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from typing import Tuple, Optional, Dict, Any, List
8
+ from loguru import logger
9
+ import yaml
10
+ import math
11
+ import re
12
+
13
+
14
+ def pmi_score(log_prob_conditional: float, log_prob_unconditional: float) -> float:
15
+ """
16
+ Calculate Pointwise Mutual Information (PMI) score.
17
+
18
+ PMI = log P(condition|codes) - log P(condition)
19
+ = log [P(codes|condition) / P(codes)]
20
+
21
+ This removes the bias from P(condition) and measures how much the codes
22
+ improve our ability to predict the condition.
23
+
24
+ Args:
25
+ log_prob_conditional: Average log probability of condition given codes
26
+ log_prob_unconditional: Average log probability of condition without codes
27
+
28
+ Returns:
29
+ PMI score (higher is better, can be positive or negative)
30
+ - Positive: codes improve prediction → good match
31
+ - Zero: codes don't help → no correlation
32
+ - Negative: codes hurt prediction → poor match
33
+ """
34
+ return log_prob_conditional - log_prob_unconditional
35
+
36
+
37
+ def pmi_to_normalized_score(pmi: float, scale: float = 0.1) -> float:
38
+ """
39
+ Convert PMI score to normalized [0, 1] range using sigmoid function.
40
+
41
+ score = sigmoid(PMI / scale) = 1 / (1 + exp(-PMI / scale))
42
+
43
+ Args:
44
+ pmi: PMI score (can be positive or negative)
45
+ scale: Scale parameter to control sensitivity (default 0.1)
46
+ - Smaller scale: more sensitive to PMI changes
47
+ - Larger scale: less sensitive to PMI changes
48
+
49
+ Returns:
50
+ Normalized score in [0, 1] range, where:
51
+ - PMI > 0 → score > 0.5 (good match)
52
+ - PMI = 0 → score = 0.5 (neutral)
53
+ - PMI < 0 → score < 0.5 (poor match)
54
+
55
+ Examples (scale=1.0):
56
+ PMI=2.0 → score≈0.88 (excellent)
57
+ PMI=1.0 → score≈0.73 (good)
58
+ PMI=0.0 → score=0.50 (neutral)
59
+ PMI=-1.0 → score≈0.27 (poor)
60
+ PMI=-2.0 → score≈0.12 (bad)
61
+ """
62
+ return 1.0 / (1.0 + math.exp(-pmi / scale))
63
+
64
+
65
+ def _get_logits_and_target_for_scoring(llm_handler, formatted_prompt: str,
66
+ target_text: str) -> Tuple[torch.Tensor, torch.Tensor]:
67
+ """
68
+ Args:
69
+ llm_handler: The handler containing the model and tokenizer.
70
+ formatted_prompt: The input context.
71
+ target_text: The text we want to calculate probability/recall for.
72
+
73
+ Returns:
74
+ Tuple of (target_logits, target_ids)
75
+ - target_logits: Logits used to predict the target tokens.
76
+ - target_ids: The ground truth token IDs of the target.
77
+ """
78
+ model = llm_handler.get_hf_model_for_scoring()
79
+ tokenizer = llm_handler.llm_tokenizer
80
+ device = llm_handler.device if llm_handler.llm_backend == "pt" else next(model.parameters()).device
81
+
82
+ # 1. Tokenize prompt ONLY to get its length (used for slicing later).
83
+ # We must ensure special tokens are added to count the offset correctly.
84
+ prompt_tokens_temp = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=True)
85
+ prompt_len = prompt_tokens_temp['input_ids'].shape[1]
86
+
87
+ # 2. Tokenize the FULL text (Prompt + Target).
88
+ # This ensures subword merging at boundaries is handled correctly by the tokenizer.
89
+ full_text = formatted_prompt + target_text
90
+ full_tokens = tokenizer(full_text, return_tensors="pt", padding=False, truncation=True, add_special_tokens=True).to(device)
91
+
92
+ input_ids = full_tokens['input_ids']
93
+
94
+ # Safety check: if target was empty or truncated entirely
95
+ if input_ids.shape[1] <= prompt_len:
96
+ return torch.empty(0, device=device), torch.empty(0, device=device)
97
+
98
+ # 3. Forward Pass (Teacher Forcing)
99
+ with torch.no_grad():
100
+ with llm_handler._load_model_context():
101
+ outputs = model(input_ids=input_ids, attention_mask=full_tokens['attention_mask'])
102
+ all_logits = outputs.logits # [1, seq_len, vocab_size]
103
+
104
+ # 4. Extract Logits and Labels
105
+ # We need to predict `input_ids[i]`. The logit for this is at `all_logits[i-1]`.
106
+ # Target starts at index `prompt_len`.
107
+ # So we need logits from `prompt_len - 1` up to the second to last position.
108
+
109
+ target_logits = all_logits[0, prompt_len - 1:-1, :] # [target_len, vocab_size]
110
+ target_ids = input_ids[0, prompt_len:] # [target_len]
111
+
112
+ return target_logits, target_ids
113
+
114
+
115
+ # ==============================================================================
116
+ # Scoring Logic
117
+ # ==============================================================================
118
+
119
+
120
+ def _calculate_topk_recall(llm_handler,
121
+ formatted_prompt: str,
122
+ target_text: str,
123
+ topk: int = 10) -> Tuple[float, Dict[int, float]]:
124
+ """
125
+ Calculate top-k recall for target text given prompt.
126
+ Checks if the ground truth token is within the top-k probabilities at each step.
127
+ """
128
+ # Use the fixed helper to get aligned logits/labels
129
+ pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)
130
+
131
+ if target_ids.shape[0] == 0:
132
+ return 0.0, {}
133
+
134
+ target_len = target_ids.shape[0]
135
+
136
+ # Get top-k indices for all positions at once
137
+ # topk_indices: [target_len, topk]
138
+ _, topk_indices = torch.topk(pred_logits, k=min(topk, pred_logits.shape[-1]), dim=-1)
139
+
140
+ recall_per_k = {}
141
+ position_scores = []
142
+
143
+ # Convert to list for faster CPU iteration
144
+ target_ids_list = target_ids.tolist()
145
+ topk_indices_list = topk_indices.tolist()
146
+
147
+ for k in range(1, topk + 1):
148
+ hits = 0
149
+ for pos in range(target_len):
150
+ gt_token = target_ids_list[pos]
151
+ # Check the top-k slice
152
+ topk_at_pos = topk_indices_list[pos][:k]
153
+
154
+ if gt_token in topk_at_pos:
155
+ hits += 1
156
+ # Calculate position-weighted score only once (when k=topk)
157
+ if k == topk:
158
+ rank = topk_at_pos.index(gt_token) + 1
159
+ # Rank 1 = 1.0, Rank k = small positive
160
+ position_weight = 1.0 - (rank - 1) / topk
161
+ position_scores.append(position_weight)
162
+
163
+ recall_per_k[k] = hits / target_len if target_len > 0 else 0.0
164
+
165
+ # Fill scores for positions where GT was NOT in top-k
166
+ while len(position_scores) < target_len:
167
+ position_scores.append(0.0)
168
+
169
+ average_recall = sum(position_scores) / len(position_scores) if position_scores else 0.0
170
+
171
+ return average_recall, recall_per_k
172
+
173
+
174
+ def _calculate_metadata_recall(llm_handler,
175
+ formatted_prompt: str,
176
+ fields_dict: Dict[str, Any],
177
+ topk: int = 10) -> Dict[str, float]:
178
+ """
179
+ Args:
180
+ fields_dict: Dictionary of {field_name: field_value}
181
+ """
182
+ if not fields_dict:
183
+ return {}
184
+
185
+ field_scores = {}
186
+
187
+ for field_name in sorted(fields_dict.keys()):
188
+ # Construct target text for this specific field
189
+ # e.g. <think>\nbpm: 120\n</think>\n
190
+ field_yaml = yaml.dump({field_name: fields_dict[field_name]}, allow_unicode=True, sort_keys=True).strip()
191
+ field_target_text = f"<think>\n{field_yaml}\n</think>\n"
192
+
193
+ # Calculate recall using the robust logic
194
+ avg_score, _ = _calculate_topk_recall(llm_handler, formatted_prompt, field_target_text, topk=topk)
195
+
196
+ field_scores[field_name] = avg_score
197
+ logger.debug(f"Recall for {field_name}: {avg_score:.4f}")
198
+
199
+ return field_scores
200
+
201
+
202
+ def _calculate_log_prob(
203
+ llm_handler,
204
+ formatted_prompt: str,
205
+ target_text: str,
206
+ temperature: float = 1.0 # Kept for API compatibility, but ignored for scoring
207
+ ) -> float:
208
+ """
209
+ Calculate average log probability of target text given prompt.
210
+ """
211
+ pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)
212
+
213
+ if target_ids.shape[0] == 0:
214
+ return float('-inf')
215
+
216
+ # FIX: Do not divide by temperature.
217
+ # Log-probability for PMI/Perplexity should be exact.
218
+
219
+ # Calculate log probabilities (log_softmax)
220
+ log_probs = F.log_softmax(pred_logits, dim=-1) # [target_len, vocab_size]
221
+
222
+ # Gather log probabilities of the ground truth tokens
223
+ target_log_probs = log_probs[torch.arange(target_ids.shape[0]), target_ids]
224
+
225
+ # Return average log probability
226
+ mean_log_prob = target_log_probs.mean().item()
227
+
228
+ return mean_log_prob
229
+
230
+
231
+ def calculate_reward_score(
232
+ scores: Dict[str, float],
233
+ weights_config: Optional[Dict[str, float]] = None
234
+ ) -> Tuple[float, str]:
235
+ """
236
+ Reward Model Calculator: Computes a final reward based on user priorities.
237
+
238
+ Priority Logic:
239
+ 1. Caption (Highest): The overall vibe/style must match.
240
+ 2. Lyrics (Medium): Content accuracy is important but secondary to vibe.
241
+ 3. Metadata (Lowest): Technical constraints (BPM, Key) allow for slight deviations.
242
+
243
+ Strategy: Dynamic Weighted Sum
244
+ - Metadata fields are aggregated into a single 'metadata' score first.
245
+ - Weights are dynamically renormalized if any component (e.g., lyrics) is missing.
246
+
247
+ Args:
248
+ scores: Dictionary of raw scores (0.0 - 1.0) from the evaluation module.
249
+ weights_config: Optional custom weights. Defaults to:
250
+ Caption (50%), Lyrics (30%), Metadata (20%).
251
+
252
+ Returns:
253
+ final_reward: The calculated reward score (0.0 - 1.0).
254
+ explanation: A formatted string explaining how the score was derived.
255
+ """
256
+
257
+ # 1. Default Preference Configuration
258
+ # These weights determine the relative importance of each component.
259
+ if weights_config is None:
260
+ weights_config = {
261
+ 'caption': 0.50, # High priority: Style/Vibe
262
+ 'lyrics': 0.30, # Medium priority: Content
263
+ 'metadata': 0.20 # Low priority: Technical details
264
+ }
265
+
266
+ # 2. Extract and Group Scores
267
+ # Caption and Lyrics are standalone high-level features.
268
+ caption_score = scores.get('caption')
269
+ lyrics_score = scores.get('lyrics')
270
+
271
+ # Metadata fields (bpm, key, duration, etc.) are aggregated.
272
+ # We treat them as a single "Technical Score" to prevent them from
273
+ # diluting the weight of Caption/Lyrics simply by having many fields.
274
+ meta_scores_list = [
275
+ val for key, val in scores.items()
276
+ if key not in ['caption', 'lyrics']
277
+ ]
278
+
279
+ # Calculate average of all metadata fields (if any exist)
280
+ meta_aggregate_score = None
281
+ if meta_scores_list:
282
+ meta_aggregate_score = sum(meta_scores_list) / len(meta_scores_list)
283
+
284
+ # 3. specific Active Components & Dynamic Weighting
285
+ # We only include components that actually exist in this generation.
286
+ active_components = {}
287
+
288
+ if caption_score is not None:
289
+ active_components['caption'] = (caption_score, weights_config['caption'])
290
+
291
+ if lyrics_score is not None:
292
+ active_components['lyrics'] = (lyrics_score, weights_config['lyrics'])
293
+
294
+ if meta_aggregate_score is not None:
295
+ active_components['metadata'] = (meta_aggregate_score, weights_config['metadata'])
296
+
297
+ # 4. Calculate Final Weighted Score
298
+ total_base_weight = sum(w for _, w in active_components.values())
299
+ total_score = 0.0
300
+
301
+ breakdown_lines = []
302
+
303
+ if total_base_weight == 0:
304
+ return 0.0, "❌ No valid scores available to calculate reward."
305
+
306
+ # Sort by weight (importance) for display
307
+ sorted_components = sorted(active_components.items(), key=lambda x: x[1][1], reverse=True)
308
+
309
+ for name, (score, base_weight) in sorted_components:
310
+ # Renormalize weight: If lyrics are missing, caption/metadata weights scale up proportionately.
311
+ normalized_weight = base_weight / total_base_weight
312
+ weighted_contribution = score * normalized_weight
313
+ total_score += weighted_contribution
314
+
315
+ breakdown_lines.append(
316
+ f" • {name.title():<8} | Score: {score:.4f} | Weight: {normalized_weight:.2f} "
317
+ f"-> Contrib: +{weighted_contribution:.4f}"
318
+ )
319
+
320
+ return total_score, "\n".join(breakdown_lines)
321
+
322
+ # ==============================================================================
323
+ # Main Public API
324
+ # ==============================================================================
325
+
326
+
327
+ def calculate_pmi_score_per_condition(
328
+ llm_handler,
329
+ audio_codes: str,
330
+ caption: str = "",
331
+ lyrics: str = "",
332
+ metadata: Optional[Dict[str, Any]] = None,
333
+ temperature: float = 1.0,
334
+ topk: int = 10,
335
+ score_scale: float = 0.1,
336
+ ) -> Tuple[Dict[str, float], float, str]:
337
+ """
338
+ Calculate quality score separately for each condition.
339
+ - Metadata: Uses Top-k Recall.
340
+ - Caption/Lyrics: Uses PMI (Normalized).
341
+ """
342
+ if not llm_handler.llm_initialized:
343
+ return {}, 0.0, "❌ LLM not initialized"
344
+
345
+ if not audio_codes or not audio_codes.strip():
346
+ return {}, 0.0, "❌ No audio codes provided"
347
+
348
+ if "caption" not in metadata:
349
+ metadata['caption'] = caption
350
+
351
+ formatted_prompt = llm_handler.build_formatted_prompt_for_understanding(audio_codes=audio_codes, is_negative_prompt=False)
352
+ prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
353
+ try:
354
+ # 1. Calculate Recall for Metadata Fields
355
+ if metadata and isinstance(metadata, dict):
356
+ scores = {}
357
+ # Define which fields use which metric
358
+ metadata_recall_keys = ['bpm', 'duration', 'genres', 'keyscale', 'language', 'timesignature']
359
+ metadata_pmi_keys = ['caption']
360
+ for key in metadata_recall_keys:
361
+ if key in metadata and metadata[key] is not None:
362
+ recall_metadata = {key: metadata[key]}
363
+ field_scores = _calculate_metadata_recall(llm_handler, formatted_prompt, recall_metadata, topk=topk)
364
+ scores.update(field_scores)
365
+
366
+ # 2. Calculate PMI for Caption
367
+ for key in metadata_pmi_keys:
368
+ if key in metadata and metadata[key] is not None:
369
+ cot_yaml = yaml.dump({key: metadata[key]}, allow_unicode=True, sort_keys=True).strip()
370
+ target_text = f"<think>\n{cot_yaml}\n</think>\n"
371
+
372
+ log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
373
+ log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
374
+
375
+ pmi_normalized = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
376
+ scores[key] = pmi_normalized
377
+
378
+ # 3. Calculate PMI for Lyrics
379
+ if lyrics:
380
+ target_text = f"<think>\n</think>\n# Lyric\n{lyrics}\n"
381
+
382
+ log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
383
+
384
+ prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
385
+ log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
386
+
387
+ scores['lyrics'] = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
388
+
389
+ if not scores:
390
+ return {}, 0.0, "❌ No conditions to evaluate"
391
+
392
+ # 4. Global Score
393
+ global_score = sum(scores.values()) / len(scores)
394
+ global_score, breakdown_lines = calculate_reward_score(scores)
395
+
396
+ # Status Message
397
+ status_lines = [breakdown_lines, "\n✅ Per-condition scores (0-1):"]
398
+ for key, score in sorted(scores.items()):
399
+ metric = "Top-k Recall" if key in metadata_recall_keys else "PMI (Norm)"
400
+ status_lines.append(f" {key}: {score:.4f} ({metric})")
401
+ status = "\n".join(status_lines)
402
+ logger.info(f"Calculated scores: {global_score:.4f}\n{status}")
403
+ return scores, global_score, status
404
+
405
+ except Exception as e:
406
+ import traceback
407
+ error_msg = f"❌ Error: {str(e)}"
408
+ logger.error(error_msg)
409
+ logger.error(traceback.format_exc())
410
+ return {}, float('-inf'), error_msg
code/acestep/third_parts/nano-vllm/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Xingkai Yu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
code/acestep/third_parts/nano-vllm/README.md ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <img width="300" src="assets/logo.png">
3
+ </p>
4
+
5
+ <p align="center">
6
+ <a href="https://trendshift.io/repositories/15323" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15323" alt="GeeeekExplorer%2Fnano-vllm | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
7
+ </p>
8
+
9
+ # Nano-vLLM
10
+
11
+ A lightweight vLLM implementation built from scratch.
12
+
13
+ ## Key Features
14
+
15
+ * 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
16
+ * 📖 **Readable codebase** - Clean implementation in ~ 1,200 lines of Python code
17
+ * ⚡ **Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc.
18
+
19
+ ## Installation
20
+
21
+ ```bash
22
+ pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
23
+ ```
24
+
25
+ ## Model Download
26
+
27
+ To download the model weights manually, use the following command:
28
+ ```bash
29
+ huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
30
+ --local-dir ~/huggingface/Qwen3-0.6B/ \
31
+ --local-dir-use-symlinks False
32
+ ```
33
+
34
+ ## Quick Start
35
+
36
+ See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method:
37
+ ```python
38
+ from nanovllm import LLM, SamplingParams
39
+ llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
40
+ sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
41
+ prompts = ["Hello, Nano-vLLM."]
42
+ outputs = llm.generate(prompts, sampling_params)
43
+ outputs[0]["text"]
44
+ ```
45
+
46
+ ## Benchmark
47
+
48
+ See `bench.py` for benchmark.
49
+
50
+ **Test Configuration:**
51
+ - Hardware: RTX 4070 Laptop (8GB)
52
+ - Model: Qwen3-0.6B
53
+ - Total Requests: 256 sequences
54
+ - Input Length: Randomly sampled between 100–1024 tokens
55
+ - Output Length: Randomly sampled between 100–1024 tokens
56
+
57
+ **Performance Results:**
58
+ | Inference Engine | Output Tokens | Time (s) | Throughput (tokens/s) |
59
+ |----------------|-------------|----------|-----------------------|
60
+ | vLLM | 133,966 | 98.37 | 1361.84 |
61
+ | Nano-vLLM | 133,966 | 93.41 | 1434.13 |
62
+
63
+
64
+ ## Star History
65
+
66
+ [![Star History Chart](https://api.star-history.com/svg?repos=GeeeekExplorer/nano-vllm&type=Date)](https://www.star-history.com/#GeeeekExplorer/nano-vllm&Date)
code/acestep/third_parts/nano-vllm/assets/logo.png ADDED

Git LFS Details

  • SHA256: 03ec4039dc248e97e9943694d3ccfb52c1a73a6dab94c4cd6fd4288e08de98c8
  • Pointer size: 131 Bytes
  • Size of remote file: 397 kB
code/acestep/third_parts/nano-vllm/bench.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from random import randint, seed
4
+ from nanovllm import LLM, SamplingParams
5
+ # from vllm import LLM, SamplingParams
6
+
7
+
8
+ def main():
9
+ seed(0)
10
+ num_seqs = 256
11
+ max_input_len = 1024
12
+ max_ouput_len = 1024
13
+
14
+ path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
15
+ llm = LLM(path, enforce_eager=False, max_model_len=4096)
16
+
17
+ prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
18
+ sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)]
19
+ # uncomment the following line for vllm
20
+ # prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
21
+
22
+ llm.generate(["Benchmark: "], SamplingParams())
23
+ t = time.time()
24
+ llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
25
+ t = (time.time() - t)
26
+ total_tokens = sum(sp.max_tokens for sp in sampling_params)
27
+ throughput = total_tokens / t
28
+ print(f"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
29
+
30
+
31
+ if __name__ == "__main__":
32
+ main()
code/acestep/third_parts/nano-vllm/example.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from nanovllm import LLM, SamplingParams
3
+ from transformers import AutoTokenizer
4
+
5
+
6
+ def main():
7
+ path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
8
+ tokenizer = AutoTokenizer.from_pretrained(path)
9
+ llm = LLM(path, enforce_eager=True, tensor_parallel_size=1)
10
+
11
+ sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
12
+ prompts = [
13
+ "introduce yourself",
14
+ "list all prime numbers within 100",
15
+ ]
16
+ prompts = [
17
+ tokenizer.apply_chat_template(
18
+ [{"role": "user", "content": prompt}],
19
+ tokenize=False,
20
+ add_generation_prompt=True,
21
+ )
22
+ for prompt in prompts
23
+ ]
24
+ outputs = llm.generate(prompts, sampling_params)
25
+
26
+ for prompt, output in zip(prompts, outputs):
27
+ print("\n")
28
+ print(f"Prompt: {prompt!r}")
29
+ print(f"Completion: {output['text']!r}")
30
+
31
+
32
+ if __name__ == "__main__":
33
+ main()
code/acestep/third_parts/nano-vllm/nanovllm/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from nanovllm.llm import LLM
2
+ from nanovllm.sampling_params import SamplingParams
code/acestep/third_parts/nano-vllm/nanovllm/config.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+ from transformers import AutoConfig
4
+
5
+
6
+ @dataclass
7
+ class Config:
8
+ model: str
9
+ max_num_batched_tokens: int = 16384
10
+ max_num_seqs: int = 512
11
+ max_model_len: int = 4096
12
+ gpu_memory_utilization: float = 0.9
13
+ tensor_parallel_size: int = 1
14
+ enforce_eager: bool = False
15
+ hf_config: AutoConfig | None = None
16
+ eos: int = -1
17
+ kvcache_block_size: int = 256
18
+ num_kvcache_blocks: int = -1
19
+
20
+ def __post_init__(self):
21
+ assert os.path.isdir(self.model)
22
+ assert self.kvcache_block_size % 256 == 0
23
+ assert 1 <= self.tensor_parallel_size <= 8
24
+ self.hf_config = AutoConfig.from_pretrained(self.model)
25
+ self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
26
+ assert self.max_num_batched_tokens >= self.max_model_len
code/acestep/third_parts/nano-vllm/nanovllm/engine/block_manager.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ import xxhash
3
+ import numpy as np
4
+
5
+ from nanovllm.engine.sequence import Sequence
6
+
7
+
8
+ class Block:
9
+
10
+ def __init__(self, block_id):
11
+ self.block_id = block_id
12
+ self.ref_count = 0
13
+ self.hash = -1
14
+ self.token_ids = []
15
+
16
+ def update(self, hash: int, token_ids: list[int]):
17
+ self.hash = hash
18
+ self.token_ids = token_ids
19
+
20
+ def reset(self):
21
+ self.ref_count = 1
22
+ self.hash = -1
23
+ self.token_ids = []
24
+
25
+
26
+ class BlockManager:
27
+
28
+ def __init__(self, num_blocks: int, block_size: int):
29
+ self.block_size = block_size
30
+ self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
31
+ self.hash_to_block_id: dict[int, int] = dict()
32
+ self.free_block_ids: deque[int] = deque(range(num_blocks))
33
+ self.used_block_ids: set[int] = set()
34
+
35
+ @classmethod
36
+ def compute_hash(cls, token_ids: list[int], prefix: int = -1):
37
+ h = xxhash.xxh64()
38
+ if prefix != -1:
39
+ h.update(prefix.to_bytes(8, "little"))
40
+ h.update(np.array(token_ids).tobytes())
41
+ return h.intdigest()
42
+
43
+ def _allocate_block(self, block_id: int) -> Block:
44
+ block = self.blocks[block_id]
45
+ assert block.ref_count == 0
46
+ block.reset()
47
+ self.free_block_ids.remove(block_id)
48
+ self.used_block_ids.add(block_id)
49
+ return self.blocks[block_id]
50
+
51
+ def _deallocate_block(self, block_id: int) -> Block:
52
+ assert self.blocks[block_id].ref_count == 0
53
+ self.used_block_ids.remove(block_id)
54
+ self.free_block_ids.append(block_id)
55
+
56
+ def can_allocate(self, seq: Sequence) -> bool:
57
+ return len(self.free_block_ids) >= seq.num_blocks
58
+
59
+ def allocate(self, seq: Sequence):
60
+ assert not seq.block_table
61
+ h = -1
62
+ cache_miss = False
63
+ for i in range(seq.num_blocks):
64
+ token_ids = seq.block(i)
65
+ h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
66
+ block_id = self.hash_to_block_id.get(h, -1)
67
+ if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
68
+ cache_miss = True
69
+ if cache_miss:
70
+ block_id = self.free_block_ids[0]
71
+ block = self._allocate_block(block_id)
72
+ else:
73
+ seq.num_cached_tokens += self.block_size
74
+ if block_id in self.used_block_ids:
75
+ block = self.blocks[block_id]
76
+ block.ref_count += 1
77
+ else:
78
+ block = self._allocate_block(block_id)
79
+ if h != -1:
80
+ block.update(h, token_ids)
81
+ self.hash_to_block_id[h] = block_id
82
+ seq.block_table.append(block_id)
83
+
84
+ def deallocate(self, seq: Sequence):
85
+ for block_id in reversed(seq.block_table):
86
+ block = self.blocks[block_id]
87
+ block.ref_count -= 1
88
+ if block.ref_count == 0:
89
+ self._deallocate_block(block_id)
90
+ seq.num_cached_tokens = 0
91
+ seq.block_table.clear()
92
+
93
+ def can_append(self, seq: Sequence) -> bool:
94
+ return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
95
+
96
+ def may_append(self, seq: Sequence):
97
+ block_table = seq.block_table
98
+ last_block = self.blocks[block_table[-1]]
99
+ if len(seq) % self.block_size == 1:
100
+ assert last_block.hash != -1
101
+ block_id = self.free_block_ids[0]
102
+ self._allocate_block(block_id)
103
+ block_table.append(block_id)
104
+ elif len(seq) % self.block_size == 0:
105
+ assert last_block.hash == -1
106
+ token_ids = seq.block(seq.num_blocks-1)
107
+ prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
108
+ h = self.compute_hash(token_ids, prefix)
109
+ last_block.update(h, token_ids)
110
+ self.hash_to_block_id[h] = last_block.block_id
111
+ else:
112
+ assert last_block.hash == -1
code/acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import atexit
2
+ from dataclasses import fields
3
+ from time import perf_counter
4
+ from tqdm.auto import tqdm
5
+ from transformers import AutoTokenizer
6
+ import torch.multiprocessing as mp
7
+
8
+ from nanovllm.config import Config
9
+ from nanovllm.sampling_params import SamplingParams
10
+ from nanovllm.engine.sequence import Sequence
11
+ from nanovllm.engine.scheduler import Scheduler
12
+ from nanovllm.engine.model_runner import ModelRunner
13
+
14
+
15
+ class LLMEngine:
16
+
17
+ def __init__(self, model, **kwargs):
18
+ config_fields = {field.name for field in fields(Config)}
19
+ config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
20
+ config = Config(model, **config_kwargs)
21
+ self.ps = []
22
+ self.events = []
23
+ ctx = mp.get_context("spawn")
24
+ for i in range(1, config.tensor_parallel_size):
25
+ event = ctx.Event()
26
+ process = ctx.Process(target=ModelRunner, args=(config, i, event))
27
+ process.start()
28
+ self.ps.append(process)
29
+ self.events.append(event)
30
+ self.model_runner = ModelRunner(config, 0, self.events)
31
+ tokenizer = kwargs.get("tokenizer", None)
32
+ if tokenizer is not None:
33
+ self.tokenizer = tokenizer
34
+ else:
35
+ self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
36
+ config.eos = self.tokenizer.eos_token_id
37
+ self.scheduler = Scheduler(config)
38
+ atexit.register(self.exit)
39
+
40
+ def exit(self):
41
+ self.model_runner.call("exit")
42
+ del self.model_runner
43
+ for p in self.ps:
44
+ p.join()
45
+
46
+ def add_request(self, prompt: str | list[int], sampling_params: SamplingParams, unconditional_prompt: str | list[int] | None = None):
47
+ if isinstance(prompt, str):
48
+ prompt = self.tokenizer.encode(prompt)
49
+ # For CFG: if cfg_scale > 1.0, create both conditional and unconditional sequences
50
+ if sampling_params.cfg_scale > 1.0:
51
+ if unconditional_prompt is None:
52
+ # Try to construct unconditional prompt by replacing user input with "NO USER INPUT"
53
+ # This is a fallback - ideally users should provide unconditional_prompt
54
+ if isinstance(prompt, list):
55
+ # For now, just use the same prompt (user should provide unconditional_prompt)
56
+ # TODO: Implement automatic "NO USER INPUT" replacement if possible
57
+ unconditional_prompt = prompt
58
+ else:
59
+ unconditional_prompt = prompt
60
+ if isinstance(unconditional_prompt, str):
61
+ unconditional_prompt = self.tokenizer.encode(unconditional_prompt)
62
+ # Create unconditional sequence first (so we can reference it from conditional)
63
+ uncond_seq = Sequence(unconditional_prompt, sampling_params, is_unconditional=True)
64
+ # Create conditional sequence with reference to unconditional
65
+ cond_seq = Sequence(prompt, sampling_params, is_unconditional=False, conditional_seq=uncond_seq)
66
+ uncond_seq.paired_seq = cond_seq # Link them bidirectionally
67
+ # Add both sequences to scheduler
68
+ self.scheduler.add(cond_seq)
69
+ self.scheduler.add(uncond_seq)
70
+ else:
71
+ seq = Sequence(prompt, sampling_params)
72
+ self.scheduler.add(seq)
73
+
74
+ def step(self):
75
+ seqs, is_prefill = self.scheduler.schedule()
76
+ token_ids = self.model_runner.call("run", seqs, is_prefill)
77
+ self.scheduler.postprocess(seqs, token_ids)
78
+ # Only output conditional sequences (unconditional sequences are just for CFG computation)
79
+ output_seqs = [seq for seq in seqs if seq.is_finished and (seq.cfg_scale <= 1.0 or not seq.is_unconditional)]
80
+ outputs = [(seq.seq_id, seq.completion_token_ids) for seq in output_seqs]
81
+ num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len([s for s in seqs if not s.is_unconditional])
82
+ return outputs, num_tokens
83
+
84
+ def is_finished(self):
85
+ return self.scheduler.is_finished()
86
+
87
+ def generate(
88
+ self,
89
+ prompts: list[str] | list[list[int]],
90
+ sampling_params: SamplingParams | list[SamplingParams],
91
+ use_tqdm: bool = True,
92
+ unconditional_prompts: list[str] | list[list[int]] | None = None,
93
+ ) -> list[str]:
94
+ if use_tqdm:
95
+ pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
96
+ if not isinstance(sampling_params, list):
97
+ sampling_params = [sampling_params] * len(prompts)
98
+ if unconditional_prompts is None:
99
+ unconditional_prompts = [None] * len(prompts)
100
+ for prompt, sp, uncond_prompt in zip(prompts, sampling_params, unconditional_prompts):
101
+ self.add_request(prompt, sp, uncond_prompt)
102
+ outputs = {}
103
+ prefill_throughput = decode_throughput = 0.
104
+ while not self.is_finished():
105
+ t = perf_counter()
106
+ output, num_tokens = self.step()
107
+ if use_tqdm:
108
+ if num_tokens > 0:
109
+ prefill_throughput = num_tokens / (perf_counter() - t)
110
+ else:
111
+ decode_throughput = -num_tokens / (perf_counter() - t)
112
+ pbar.set_postfix({
113
+ "Prefill": f"{int(prefill_throughput)}tok/s",
114
+ "Decode": f"{int(decode_throughput)}tok/s",
115
+ })
116
+ for seq_id, token_ids in output:
117
+ outputs[seq_id] = token_ids
118
+ if use_tqdm:
119
+ pbar.update(1)
120
+ outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
121
+ outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
122
+ if use_tqdm:
123
+ pbar.close()
124
+ return outputs
code/acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py ADDED
@@ -0,0 +1,529 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import torch
3
+ import torch.distributed as dist
4
+ from multiprocessing.synchronize import Event
5
+ from multiprocessing.shared_memory import SharedMemory
6
+ import sys
7
+
8
+ from nanovllm.config import Config
9
+ from nanovllm.engine.sequence import Sequence
10
+ from nanovllm.models.qwen3 import Qwen3ForCausalLM
11
+ from nanovllm.layers.sampler import Sampler
12
+ from nanovllm.utils.context import set_context, get_context, reset_context
13
+ from nanovllm.utils.loader import load_model
14
+
15
+ import socket
16
+
17
+
18
+ def find_available_port(start_port: int = 2333, max_attempts: int = 100) -> int:
19
+ """Find an available port starting from start_port.
20
+
21
+ Args:
22
+ start_port: The starting port number to check
23
+ max_attempts: Maximum number of ports to try
24
+
25
+ Returns:
26
+ An available port number
27
+
28
+ Raises:
29
+ RuntimeError: If no available port is found within max_attempts
30
+ """
31
+ for i in range(max_attempts):
32
+ port = start_port + i
33
+ try:
34
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
35
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
36
+ s.bind(('localhost', port))
37
+ return port
38
+ except OSError:
39
+ # Port is in use, try next one
40
+ continue
41
+ raise RuntimeError(f"Could not find an available port starting from {start_port} after {max_attempts} attempts")
42
+
43
+
44
+ class ModelRunner:
45
+
46
+ def __init__(self, config: Config, rank: int, event: Event | list[Event]):
47
+ # Enable capturing scalar outputs to avoid graph breaks from Tensor.item() calls
48
+ torch._dynamo.config.capture_scalar_outputs = True
49
+
50
+ self.config = config
51
+ hf_config = config.hf_config
52
+ self.block_size = config.kvcache_block_size
53
+ self.enforce_eager = config.enforce_eager
54
+ self.world_size = config.tensor_parallel_size
55
+ self.rank = rank
56
+ self.event = event
57
+ dist_port = find_available_port()
58
+ print(f"[debug]dist_port: {dist_port}")
59
+ # Use gloo backend on Windows, nccl on Linux/other platforms
60
+ backend = "gloo" if sys.platform == "win32" else "nccl"
61
+ dist.init_process_group(backend, f"tcp://127.0.0.1:{dist_port}", world_size=self.world_size, rank=rank)
62
+ torch.cuda.set_device(rank)
63
+ default_dtype = torch.get_default_dtype()
64
+ # Use dtype instead of deprecated torch_dtype
65
+ config_dtype = getattr(hf_config, 'dtype', getattr(hf_config, 'torch_dtype', torch.float32))
66
+ torch.set_default_dtype(config_dtype)
67
+ torch.set_default_device("cuda")
68
+ self.model = Qwen3ForCausalLM(hf_config)
69
+ load_model(self.model, config.model)
70
+ self.sampler = Sampler()
71
+
72
+ # Pre-allocate buffers for sampling (optimization: avoid repeated tensor creation)
73
+ # Must be called before warmup_model() since it uses these buffers
74
+ self._allocate_sample_buffers()
75
+
76
+ self.warmup_model()
77
+ self.allocate_kv_cache()
78
+ if not self.enforce_eager:
79
+ self.capture_cudagraph()
80
+
81
+ torch.set_default_device("cpu")
82
+ torch.set_default_dtype(default_dtype)
83
+
84
+ if self.world_size > 1:
85
+ if rank == 0:
86
+ self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
87
+ dist.barrier()
88
+ else:
89
+ dist.barrier()
90
+ self.shm = SharedMemory(name="nanovllm")
91
+ self.loop()
92
+
93
+ def _allocate_sample_buffers(self):
94
+ """Pre-allocate reusable buffers for sampling to avoid repeated tensor creation."""
95
+ max_bs = self.config.max_num_seqs
96
+ max_tokens = self.config.max_num_batched_tokens
97
+ max_num_blocks = (self.config.max_model_len + self.block_size - 1) // self.block_size
98
+
99
+ # Pre-allocate pinned memory buffers on CPU for fast transfer
100
+ # Must explicitly specify device="cpu" since default device may be "cuda"
101
+ self._cpu_temperatures = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
102
+ self._cpu_cfg_scales = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
103
+ self._cpu_top_ks = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
104
+ self._cpu_top_ps = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
105
+ self._cpu_repetition_penalties = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
106
+
107
+ # Pre-allocate decode buffers on CPU with pinned memory
108
+ self._cpu_input_ids = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
109
+ self._cpu_positions = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
110
+ self._cpu_slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
111
+ self._cpu_context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
112
+
113
+ # Pre-allocate prefill buffers on CPU with pinned memory (optimization to avoid repeated tensor creation)
114
+ self._cpu_prefill_input_ids = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
115
+ self._cpu_prefill_positions = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
116
+ self._cpu_prefill_cu_seqlens = torch.zeros(max_bs + 1, dtype=torch.int32, device="cpu", pin_memory=True)
117
+ self._cpu_prefill_slot_mapping = torch.zeros(max_tokens, dtype=torch.int32, device="cpu", pin_memory=True)
118
+
119
+ # Pre-allocate block tables buffer (shared by both decode and prefill)
120
+ self._cpu_block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32, device="cpu", pin_memory=True)
121
+
122
+ # Pre-allocate buffer for sequence token IDs (used in logits processor and sampler)
123
+ # Max length is max_model_len since sequences can be that long
124
+ self._seq_token_ids_buffer = torch.zeros(max_bs, self.config.max_model_len, dtype=torch.int64, device="cpu", pin_memory=True)
125
+
126
+ def exit(self):
127
+ if self.world_size > 1:
128
+ self.shm.close()
129
+ dist.barrier()
130
+ if self.rank == 0:
131
+ self.shm.unlink()
132
+ if not self.enforce_eager:
133
+ del self.graphs, self.graph_pool
134
+ torch.cuda.synchronize()
135
+ dist.destroy_process_group()
136
+
137
+ def loop(self):
138
+ while True:
139
+ method_name, args = self.read_shm()
140
+ self.call(method_name, *args)
141
+ if method_name == "exit":
142
+ break
143
+
144
+ def read_shm(self):
145
+ assert self.world_size > 1 and self.rank > 0
146
+ self.event.wait()
147
+ n = int.from_bytes(self.shm.buf[0:4], "little")
148
+ method_name, *args = pickle.loads(self.shm.buf[4:n+4])
149
+ self.event.clear()
150
+ return method_name, args
151
+
152
+ def write_shm(self, method_name, *args):
153
+ assert self.world_size > 1 and self.rank == 0
154
+ data = pickle.dumps([method_name, *args])
155
+ n = len(data)
156
+ self.shm.buf[0:4] = n.to_bytes(4, "little")
157
+ self.shm.buf[4:n+4] = data
158
+ for event in self.event:
159
+ event.set()
160
+
161
+ def call(self, method_name, *args):
162
+ if self.world_size > 1 and self.rank == 0:
163
+ self.write_shm(method_name, *args)
164
+ method = getattr(self, method_name, None)
165
+ return method(*args)
166
+
167
+ def warmup_model(self):
168
+ torch.cuda.empty_cache()
169
+ torch.cuda.reset_peak_memory_stats()
170
+ max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
171
+ num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
172
+ seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
173
+ self.run(seqs, True)
174
+ torch.cuda.empty_cache()
175
+
176
+ def allocate_kv_cache(self):
177
+ config = self.config
178
+ hf_config = config.hf_config
179
+ free, total = torch.cuda.mem_get_info()
180
+ current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
181
+ num_kv_heads = hf_config.num_key_value_heads // self.world_size
182
+ head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
183
+ # Use dtype instead of deprecated torch_dtype
184
+ config_dtype = getattr(hf_config, 'dtype', getattr(hf_config, 'torch_dtype', torch.float32))
185
+ block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * config_dtype.itemsize
186
+
187
+ # Calculate available memory for KV cache
188
+ # After warmup_model, empty_cache has been called, so current represents model memory only
189
+ # Use free memory but respect the gpu_memory_utilization limit
190
+ target_total_usage = total * config.gpu_memory_utilization
191
+ available_for_kv_cache = min(free * 0.9, target_total_usage - current)
192
+
193
+ # Ensure we have positive memory available
194
+ if available_for_kv_cache <= 0:
195
+ available_for_kv_cache = free * 0.5 # Fallback to 50% of free memory
196
+
197
+ config.num_kvcache_blocks = max(1, int(available_for_kv_cache) // block_bytes)
198
+ if config.num_kvcache_blocks <= 0:
199
+ raise RuntimeError(
200
+ f"Insufficient GPU memory for KV cache. "
201
+ f"Free: {free / 1024**3:.2f} GB, Current: {current / 1024**3:.2f} GB, "
202
+ f"Available for KV: {available_for_kv_cache / 1024**3:.2f} GB, "
203
+ f"Block size: {block_bytes / 1024**2:.2f} MB"
204
+ )
205
+ self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
206
+ layer_id = 0
207
+ for module in self.model.modules():
208
+ if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
209
+ module.k_cache = self.kv_cache[0, layer_id]
210
+ module.v_cache = self.kv_cache[1, layer_id]
211
+ layer_id += 1
212
+
213
+ def prepare_block_tables(self, seqs: list[Sequence]):
214
+ max_len = max(len(seq.block_table) for seq in seqs)
215
+ block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
216
+ block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
217
+ return block_tables
218
+
219
+ def prepare_prefill(self, seqs: list[Sequence]):
220
+ input_ids = []
221
+ positions = []
222
+ cu_seqlens_q = [0]
223
+ cu_seqlens_k = [0]
224
+ max_seqlen_q = 0
225
+ max_seqlen_k = 0
226
+ slot_mapping = []
227
+ block_tables = None
228
+ for seq in seqs:
229
+ seqlen = len(seq)
230
+ input_ids.extend(seq[seq.num_cached_tokens:])
231
+ positions.extend(list(range(seq.num_cached_tokens, seqlen)))
232
+ seqlen_q = seqlen - seq.num_cached_tokens
233
+ seqlen_k = seqlen
234
+ cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
235
+ cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
236
+ max_seqlen_q = max(seqlen_q, max_seqlen_q)
237
+ max_seqlen_k = max(seqlen_k, max_seqlen_k)
238
+ if not seq.block_table: # warmup
239
+ continue
240
+ for i in range(seq.num_cached_blocks, seq.num_blocks):
241
+ start = seq.block_table[i] * self.block_size
242
+ if i != seq.num_blocks - 1:
243
+ end = start + self.block_size
244
+ else:
245
+ end = start + seq.last_block_num_tokens
246
+ slot_mapping.extend(list(range(start, end)))
247
+ if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
248
+ block_tables = self.prepare_block_tables(seqs)
249
+ input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
250
+ positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
251
+ cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
252
+ cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
253
+ slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
254
+ set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
255
+ return input_ids, positions
256
+
257
+ def prepare_decode(self, seqs: list[Sequence]):
258
+ """Optimized decode preparation using pre-allocated buffers."""
259
+ bs = len(seqs)
260
+
261
+ # Use pre-allocated CPU buffers
262
+ for i, seq in enumerate(seqs):
263
+ self._cpu_input_ids[i] = seq.last_token
264
+ self._cpu_positions[i] = len(seq) - 1
265
+ self._cpu_context_lens[i] = len(seq)
266
+ self._cpu_slot_mapping[i] = seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1
267
+
268
+ # Transfer to GPU using sliced views
269
+ input_ids = self._cpu_input_ids[:bs].cuda(non_blocking=True)
270
+ positions = self._cpu_positions[:bs].cuda(non_blocking=True)
271
+ slot_mapping = self._cpu_slot_mapping[:bs].cuda(non_blocking=True)
272
+ context_lens = self._cpu_context_lens[:bs].cuda(non_blocking=True)
273
+ block_tables = self.prepare_block_tables(seqs)
274
+ set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
275
+ return input_ids, positions
276
+
277
+ def prepare_sample(self, seqs: list[Sequence], is_cfg_batch: bool = False):
278
+ """Optimized sample preparation using pre-allocated buffers."""
279
+ if is_cfg_batch:
280
+ num_seqs = len(seqs) // 2
281
+ target_seqs = seqs[:num_seqs]
282
+ else:
283
+ num_seqs = len(seqs)
284
+ target_seqs = seqs
285
+
286
+ # Fill pre-allocated CPU buffers
287
+ top_ks_is_zero = True
288
+ top_ps_is_one = True
289
+ repetition_penalties_is_one = True
290
+ for i, seq in enumerate(target_seqs):
291
+ self._cpu_temperatures[i] = seq.temperature
292
+ self._cpu_cfg_scales[i] = seq.cfg_scale
293
+ self._cpu_top_ks[i] = seq.top_k if seq.top_k is not None else 0
294
+ if seq.top_k is not None and seq.top_k > 0:
295
+ top_ks_is_zero = False
296
+ self._cpu_top_ps[i] = seq.top_p if seq.top_p is not None else 1.0
297
+ if seq.top_p is not None and seq.top_p == 1.0:
298
+ top_ps_is_one = False
299
+ self._cpu_repetition_penalties[i] = seq.repetition_penalty if seq.repetition_penalty is not None else 1.0
300
+ if seq.repetition_penalty is not None and seq.repetition_penalty == 1.0:
301
+ repetition_penalties_is_one = False
302
+
303
+ # Transfer to GPU using sliced views (single batched transfer)
304
+ temperatures = self._cpu_temperatures[:num_seqs].cuda(non_blocking=True)
305
+ cfg_scales = self._cpu_cfg_scales[:num_seqs].cuda(non_blocking=True)
306
+ top_ks = self._cpu_top_ks[:num_seqs].cuda(non_blocking=True) if not top_ks_is_zero else None
307
+ top_ps = self._cpu_top_ps[:num_seqs].cuda(non_blocking=True) if not top_ps_is_one else None
308
+ repetition_penalties = self._cpu_repetition_penalties[:num_seqs].cuda(non_blocking=True) if not repetition_penalties_is_one else None
309
+
310
+ return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties
311
+
312
+ @torch.inference_mode()
313
+ def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
314
+ if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
315
+ return self.model.compute_logits(self.model(input_ids, positions))
316
+ else:
317
+ bs = input_ids.size(0)
318
+ context = get_context()
319
+
320
+ # Check if block_tables size exceeds pre-allocated buffer size
321
+ # This can happen when conditional and unconditional sequences have different lengths
322
+ # in CFG mode, causing block_tables to have more columns than expected
323
+ max_num_blocks = self.graph_vars["block_tables"].size(1)
324
+ if context.block_tables.size(1) > max_num_blocks:
325
+ # Fall back to eager mode when block_tables is too large for CUDA graph
326
+ return self.model.compute_logits(self.model(input_ids, positions))
327
+
328
+ graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
329
+ graph_vars = self.graph_vars
330
+ graph_vars["input_ids"][:bs] = input_ids
331
+ graph_vars["positions"][:bs] = positions
332
+ graph_vars["slot_mapping"].fill_(-1)
333
+ graph_vars["slot_mapping"][:bs] = context.slot_mapping
334
+ graph_vars["context_lens"].zero_()
335
+ graph_vars["context_lens"][:bs] = context.context_lens
336
+ # Clear block_tables first to ensure no stale data from previous runs
337
+ graph_vars["block_tables"][:bs].fill_(-1)
338
+ graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
339
+ graph.replay()
340
+ return self.model.compute_logits(graph_vars["outputs"][:bs])
341
+
342
+ def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
343
+ """Run model forward and sampling. For CFG sequences, batch is structured as:
344
+ [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
345
+ where uncond_seqi is the paired unconditional sequence of cond_seqi."""
346
+ # Check if this is a CFG batch (contains paired conditional and unconditional sequences)
347
+ is_cfg_batch = seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None
348
+ if is_cfg_batch:
349
+ # CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
350
+ num_cond = len(seqs) // 2
351
+ cond_seqs = seqs[:num_cond]
352
+ # uncond_seqs = seqs[num_cond:]
353
+
354
+ # Prepare inputs for both conditional and unconditional (they're already in the batch)
355
+ input_ids, positions = (self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs))
356
+ sample_params = self.prepare_sample(seqs, is_cfg_batch=True) if self.rank == 0 else None
357
+ if sample_params is not None:
358
+ temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
359
+ else:
360
+ temperatures = cfg_scales = top_ks = top_ps = repetition_penalties = None
361
+
362
+ # Run model forward (processes entire batch: cond + uncond)
363
+ logits_all = self.run_model(input_ids, positions, is_prefill)
364
+ reset_context()
365
+
366
+ if self.rank == 0:
367
+ # Split logits: first half is conditional, second half is unconditional
368
+ logits_cond = logits_all[:num_cond]
369
+ logits_uncond = logits_all[num_cond:]
370
+
371
+ # Apply repetition penalty to conditional logits (before CFG)
372
+ if repetition_penalties is not None:
373
+ for i, seq in enumerate(cond_seqs):
374
+ penalty = repetition_penalties[i].item()
375
+ if penalty != 1.0:
376
+ # Only penalize completion tokens (not prompt tokens)
377
+ completion_tokens = torch.tensor(seq.completion_token_ids, device=logits_cond.device)
378
+ if len(completion_tokens) > 0:
379
+ # Create token mask: mark tokens that appeared in completion
380
+ token_mask = torch.zeros(logits_cond.shape[1], dtype=torch.bool, device=logits_cond.device)
381
+ token_mask[completion_tokens] = True
382
+
383
+ # Apply standard repetition penalty formula (matching transformers implementation):
384
+ # For tokens in completion: if score < 0 then score * penalty, else score / penalty
385
+ penalty_scores = torch.where(
386
+ logits_cond[i] < 0,
387
+ logits_cond[i] * penalty,
388
+ logits_cond[i] / penalty
389
+ )
390
+ # Only apply penalty to tokens that appeared in completion
391
+ logits_cond[i] = torch.where(token_mask, penalty_scores, logits_cond[i])
392
+
393
+ # Apply CFG formula: logits_cfg = logits_uncond + cfg_scale * (logits_cond - logits_uncond)
394
+ cfg_scales_tensor = cfg_scales.unsqueeze(1) # [num_cond, 1]
395
+ logits_cfg = logits_uncond + cfg_scales_tensor * (logits_cond - logits_uncond)
396
+
397
+ # Apply logits processor for constrained decoding (if any sequence has one)
398
+ for i, seq in enumerate(cond_seqs):
399
+ if seq.logits_processor is not None:
400
+ # Create input_ids tensor for this sequence
401
+ seq_input_ids = torch.tensor([seq.token_ids], device=logits_cfg.device)
402
+ # Apply processor to this sequence's logits
403
+ logits_cfg[i:i+1] = seq.logits_processor(seq_input_ids, logits_cfg[i:i+1])
404
+
405
+ # Prepare input_ids for sampler (for repetition penalty, though we already applied it)
406
+ # cond_input_ids = torch.tensor([seq.token_ids for seq in cond_seqs], device=logits_cfg.device)
407
+
408
+ # Sample from CFG logits
409
+ token_ids_cfg = self.sampler(
410
+ logits_cfg,
411
+ temperatures,
412
+ top_ks=top_ks if top_ks is not None else None,
413
+ top_ps=top_ps if top_ps is not None else None,
414
+ repetition_penalties=None, # Already applied above
415
+ # input_ids=cond_input_ids,
416
+ ).tolist()
417
+
418
+ # Update logits processor state after sampling
419
+ for i, seq in enumerate(cond_seqs):
420
+ if seq.logits_processor_update_state is not None:
421
+ seq.logits_processor_update_state(token_ids_cfg[i])
422
+
423
+ # Return token_ids (will be applied to both conditional and unconditional sequences)
424
+ return token_ids_cfg
425
+ else:
426
+ return None
427
+ else:
428
+ # Normal batch (non-CFG)
429
+ input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
430
+ else self.prepare_decode(seqs))
431
+ sample_params = self.prepare_sample(seqs, is_cfg_batch=False) if self.rank == 0 else None
432
+ if sample_params is not None:
433
+ temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
434
+ else:
435
+ temperatures = cfg_scales = top_ks = top_ps = repetition_penalties = None
436
+ logits = self.run_model(input_ids, positions, is_prefill)
437
+ reset_context()
438
+
439
+ if self.rank == 0:
440
+ # Apply repetition penalty to logits
441
+ if repetition_penalties is not None:
442
+ for i, seq in enumerate(seqs):
443
+ penalty = repetition_penalties[i].item()
444
+ if penalty != 1.0:
445
+ # Only penalize completion tokens (not prompt tokens)
446
+ completion_tokens = torch.tensor(seq.completion_token_ids, device=logits.device)
447
+ if len(completion_tokens) > 0:
448
+ # Create token mask: mark tokens that appeared in completion
449
+ token_mask = torch.zeros(logits.shape[1], dtype=torch.bool, device=logits.device)
450
+ token_mask[completion_tokens] = True
451
+
452
+ # Apply standard repetition penalty formula (matching transformers implementation):
453
+ # For tokens in completion: if score < 0 then score * penalty, else score / penalty
454
+ penalty_scores = torch.where(
455
+ logits[i] < 0,
456
+ logits[i] * penalty,
457
+ logits[i] / penalty
458
+ )
459
+ # Only apply penalty to tokens that appeared in completion
460
+ logits[i] = torch.where(token_mask, penalty_scores, logits[i])
461
+
462
+ # Apply logits processor for constrained decoding (if any sequence has one)
463
+ # Clone logits to avoid in-place update issues in inference mode
464
+ logits = logits.clone()
465
+ for i, seq in enumerate(seqs):
466
+ if seq.logits_processor is not None:
467
+ # Create input_ids tensor for this sequence
468
+ seq_input_ids = torch.tensor([seq.token_ids], device=logits.device)
469
+ # Apply processor to this sequence's logits (clone to avoid inference mode issues)
470
+ processed = seq.logits_processor(seq_input_ids, logits[i:i+1].clone())
471
+ logits[i] = processed[0]
472
+
473
+ # Prepare input_ids for sampler
474
+ # seq_input_ids = torch.tensor([seq.token_ids for seq in seqs], device=logits.device)
475
+
476
+ token_ids = self.sampler(
477
+ logits,
478
+ temperatures,
479
+ top_ks=top_ks if top_ks is not None else None,
480
+ top_ps=top_ps if top_ps is not None else None,
481
+ repetition_penalties=None, # Already applied above
482
+ # input_ids=seq_input_ids,
483
+ ).tolist()
484
+
485
+ # Update logits processor state after sampling
486
+ for i, seq in enumerate(seqs):
487
+ if seq.logits_processor_update_state is not None:
488
+ seq.logits_processor_update_state(token_ids[i])
489
+
490
+ return token_ids
491
+ else:
492
+ return None
493
+
494
+ @torch.inference_mode()
495
+ def capture_cudagraph(self):
496
+ config = self.config
497
+ hf_config = config.hf_config
498
+ max_bs = min(self.config.max_num_seqs, 512)
499
+ max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
500
+ input_ids = torch.zeros(max_bs, dtype=torch.int64)
501
+ positions = torch.zeros(max_bs, dtype=torch.int64)
502
+ slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
503
+ context_lens = torch.zeros(max_bs, dtype=torch.int32)
504
+ block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
505
+ outputs = torch.zeros(max_bs, hf_config.hidden_size)
506
+ self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
507
+ self.graphs = {}
508
+ self.graph_pool = None
509
+
510
+ for bs in reversed(self.graph_bs):
511
+ graph = torch.cuda.CUDAGraph()
512
+ set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
513
+ outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
514
+ with torch.cuda.graph(graph, self.graph_pool):
515
+ outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
516
+ if self.graph_pool is None:
517
+ self.graph_pool = graph.pool()
518
+ self.graphs[bs] = graph
519
+ torch.cuda.synchronize()
520
+ reset_context()
521
+
522
+ self.graph_vars = dict(
523
+ input_ids=input_ids,
524
+ positions=positions,
525
+ slot_mapping=slot_mapping,
526
+ context_lens=context_lens,
527
+ block_tables=block_tables,
528
+ outputs=outputs,
529
+ )
code/acestep/third_parts/nano-vllm/nanovllm/engine/scheduler.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+
3
+ from nanovllm.config import Config
4
+ from nanovllm.engine.sequence import Sequence, SequenceStatus
5
+ from nanovllm.engine.block_manager import BlockManager
6
+
7
+
8
+ class Scheduler:
9
+
10
+ def __init__(self, config: Config):
11
+ self.max_num_seqs = config.max_num_seqs
12
+ self.max_num_batched_tokens = config.max_num_batched_tokens
13
+ self.eos = config.eos
14
+ self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
15
+ self.waiting: deque[Sequence] = deque()
16
+ self.running: deque[Sequence] = deque()
17
+
18
+ def is_finished(self):
19
+ return not self.waiting and not self.running
20
+
21
+ def add(self, seq: Sequence):
22
+ self.waiting.append(seq)
23
+
24
+ def schedule(self) -> tuple[list[Sequence], bool]:
25
+ # prefill
26
+ scheduled_seqs = []
27
+ num_seqs = 0
28
+ num_batched_tokens = 0
29
+ processed_seqs = set() # Track processed sequences to handle CFG pairs
30
+
31
+ while self.waiting and num_seqs < self.max_num_seqs:
32
+ seq = self.waiting[0]
33
+
34
+ # For CFG sequences, ensure conditional and unconditional are scheduled together
35
+ if seq.cfg_scale > 1.0 and seq.paired_seq is not None and not seq.is_unconditional:
36
+ # This is a conditional sequence, need to schedule its paired unconditional sequence too
37
+ paired_seq = seq.paired_seq
38
+ if paired_seq.status != SequenceStatus.WAITING:
39
+ # Paired sequence not in waiting, skip this conditional sequence for now
40
+ break
41
+
42
+ # Calculate tokens for both sequences
43
+ total_tokens = (len(seq) - seq.num_cached_tokens) + (len(paired_seq) - paired_seq.num_cached_tokens)
44
+ can_allocate_both = (self.block_manager.can_allocate(seq) and
45
+ self.block_manager.can_allocate(paired_seq))
46
+
47
+ if num_batched_tokens + total_tokens > self.max_num_batched_tokens or not can_allocate_both:
48
+ break
49
+
50
+ # Schedule both sequences: conditional first, then unconditional
51
+ for s in [seq, paired_seq]:
52
+ num_seqs += 1
53
+ self.block_manager.allocate(s)
54
+ num_batched_tokens += len(s) - s.num_cached_tokens
55
+ s.status = SequenceStatus.RUNNING
56
+ self.waiting.remove(s)
57
+ self.running.append(s)
58
+ scheduled_seqs.append(s)
59
+ processed_seqs.add(s.seq_id)
60
+ else:
61
+ # Normal sequence or unconditional sequence (already processed with its conditional)
62
+ if seq.seq_id in processed_seqs:
63
+ # Skip if already processed as part of a CFG pair
64
+ self.waiting.popleft()
65
+ continue
66
+
67
+ if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
68
+ break
69
+ num_seqs += 1
70
+ self.block_manager.allocate(seq)
71
+ num_batched_tokens += len(seq) - seq.num_cached_tokens
72
+ seq.status = SequenceStatus.RUNNING
73
+ self.waiting.popleft()
74
+ self.running.append(seq)
75
+ scheduled_seqs.append(seq)
76
+
77
+ if scheduled_seqs:
78
+ # For CFG batches, ensure conditional sequences come before their unconditional pairs
79
+ cfg_cond_seqs = [s for s in scheduled_seqs if s.cfg_scale > 1.0 and not s.is_unconditional]
80
+ cfg_uncond_seqs = [s for s in scheduled_seqs if s.is_unconditional]
81
+ non_cfg_seqs = [s for s in scheduled_seqs if s.cfg_scale <= 1.0]
82
+
83
+ # Reorder: non-CFG, then CFG conditional, then CFG unconditional
84
+ scheduled_seqs = non_cfg_seqs + cfg_cond_seqs + cfg_uncond_seqs
85
+ return scheduled_seqs, True
86
+
87
+ # decode
88
+ processed_seqs = set()
89
+ temp_running = list(self.running) # Work with a copy
90
+
91
+ while temp_running and num_seqs < self.max_num_seqs:
92
+ seq = temp_running.pop(0)
93
+
94
+ # For CFG sequences, ensure conditional and unconditional are scheduled together
95
+ if seq.cfg_scale > 1.0 and seq.paired_seq is not None and not seq.is_unconditional:
96
+ paired_seq = seq.paired_seq
97
+ if paired_seq not in temp_running:
98
+ # Paired sequence not available, skip for now
99
+ continue
100
+
101
+ # Remove paired_seq from temp_running
102
+ temp_running.remove(paired_seq)
103
+
104
+ # Check if both can append
105
+ can_append_both = (self.block_manager.can_append(seq) and
106
+ self.block_manager.can_append(paired_seq))
107
+
108
+ if not can_append_both:
109
+ # Try preempting other sequences
110
+ preempted = False
111
+ while not can_append_both and temp_running:
112
+ other_seq = temp_running.pop(0)
113
+ if other_seq != seq and other_seq != paired_seq:
114
+ self.preempt(other_seq)
115
+ can_append_both = (self.block_manager.can_append(seq) and
116
+ self.block_manager.can_append(paired_seq))
117
+ preempted = True
118
+ else:
119
+ temp_running.append(other_seq)
120
+ break
121
+
122
+ if not can_append_both:
123
+ # Can't schedule this pair right now
124
+ temp_running.append(seq)
125
+ temp_running.append(paired_seq)
126
+ continue
127
+
128
+ # Schedule both sequences
129
+ for s in [seq, paired_seq]:
130
+ num_seqs += 1
131
+ self.block_manager.may_append(s)
132
+ scheduled_seqs.append(s)
133
+ processed_seqs.add(s.seq_id)
134
+ # Remove from actual running list if scheduled
135
+ if s in self.running:
136
+ self.running.remove(s)
137
+ else:
138
+ # Normal sequence or unconditional (already processed)
139
+ if seq.seq_id in processed_seqs:
140
+ continue
141
+
142
+ while not self.block_manager.can_append(seq):
143
+ if temp_running:
144
+ other_seq = temp_running.pop(0)
145
+ if other_seq != seq:
146
+ self.preempt(other_seq)
147
+ else:
148
+ temp_running.append(other_seq)
149
+ break
150
+ else:
151
+ self.preempt(seq)
152
+ if seq in self.running:
153
+ self.running.remove(seq)
154
+ break
155
+ else:
156
+ num_seqs += 1
157
+ self.block_manager.may_append(seq)
158
+ scheduled_seqs.append(seq)
159
+ if seq in self.running:
160
+ self.running.remove(seq)
161
+
162
+ assert scheduled_seqs
163
+
164
+ # For CFG batches in decode, ensure conditional sequences come before unconditional
165
+ cfg_cond_seqs = [s for s in scheduled_seqs if s.cfg_scale > 1.0 and not s.is_unconditional]
166
+ cfg_uncond_seqs = [s for s in scheduled_seqs if s.is_unconditional]
167
+ non_cfg_seqs = [s for s in scheduled_seqs if s.cfg_scale <= 1.0]
168
+ scheduled_seqs = non_cfg_seqs + cfg_cond_seqs + cfg_uncond_seqs
169
+
170
+ self.running.extendleft(reversed(scheduled_seqs))
171
+ return scheduled_seqs, False
172
+
173
+ def preempt(self, seq: Sequence):
174
+ seq.status = SequenceStatus.WAITING
175
+ self.block_manager.deallocate(seq)
176
+ self.waiting.appendleft(seq)
177
+
178
+ def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
179
+ # Check if this is a CFG batch
180
+ is_cfg_batch = False
181
+ if len(seqs) > 0 and seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None:
182
+ num_cond = len(seqs) // 2
183
+ is_cfg_batch = (num_cond > 0 and
184
+ not seqs[0].is_unconditional and
185
+ seqs[num_cond].is_unconditional)
186
+
187
+ if is_cfg_batch:
188
+ # CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
189
+ # token_ids correspond to conditional sequences only (sampled from CFG logits)
190
+ num_cond = len(seqs) // 2
191
+ cond_seqs = seqs[:num_cond]
192
+ uncond_seqs = seqs[num_cond:]
193
+
194
+ # Apply the same sampled token to both conditional and unconditional sequences
195
+ for i, (cond_seq, uncond_seq, token_id) in enumerate(zip(cond_seqs, uncond_seqs, token_ids)):
196
+ cond_seq.append_token(token_id)
197
+ uncond_seq.append_token(token_id) # Same token for unconditional
198
+
199
+ # Check if either sequence is finished
200
+ cond_finished = ((not cond_seq.ignore_eos and token_id == self.eos) or
201
+ cond_seq.num_completion_tokens == cond_seq.max_tokens)
202
+ uncond_finished = ((not uncond_seq.ignore_eos and token_id == self.eos) or
203
+ uncond_seq.num_completion_tokens == uncond_seq.max_tokens)
204
+
205
+ if cond_finished or uncond_finished:
206
+ # Mark both as finished
207
+ cond_seq.status = SequenceStatus.FINISHED
208
+ uncond_seq.status = SequenceStatus.FINISHED
209
+ self.block_manager.deallocate(cond_seq)
210
+ self.block_manager.deallocate(uncond_seq)
211
+ if cond_seq in self.running:
212
+ self.running.remove(cond_seq)
213
+ if uncond_seq in self.running:
214
+ self.running.remove(uncond_seq)
215
+ else:
216
+ # Normal batch
217
+ for seq, token_id in zip(seqs, token_ids):
218
+ seq.append_token(token_id)
219
+ if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
220
+ seq.status = SequenceStatus.FINISHED
221
+ self.block_manager.deallocate(seq)
222
+ self.running.remove(seq)
code/acestep/third_parts/nano-vllm/nanovllm/engine/sequence.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import copy
2
+ from enum import Enum, auto
3
+ from itertools import count
4
+ from typing import Optional, Callable, Any
5
+
6
+ from nanovllm.sampling_params import SamplingParams
7
+
8
+
9
+ class SequenceStatus(Enum):
10
+ WAITING = auto()
11
+ RUNNING = auto()
12
+ FINISHED = auto()
13
+
14
+
15
+ class Sequence:
16
+ block_size = 256
17
+ counter = count()
18
+
19
+ def __init__(self, token_ids: list[int], sampling_params = SamplingParams(), is_unconditional: bool = False, conditional_seq = None):
20
+ self.seq_id = next(Sequence.counter)
21
+ self.status = SequenceStatus.WAITING
22
+ self.token_ids = copy(token_ids)
23
+ self.last_token = token_ids[-1]
24
+ self.num_tokens = len(self.token_ids)
25
+ self.num_prompt_tokens = len(token_ids)
26
+ self.num_cached_tokens = 0
27
+ self.block_table = []
28
+ self.temperature = sampling_params.temperature
29
+ self.max_tokens = sampling_params.max_tokens
30
+ self.ignore_eos = sampling_params.ignore_eos
31
+ self.cfg_scale = sampling_params.cfg_scale
32
+ self.top_k = sampling_params.top_k
33
+ self.top_p = sampling_params.top_p
34
+ self.repetition_penalty = sampling_params.repetition_penalty
35
+ # For CFG: mark if this is an unconditional sequence
36
+ self.is_unconditional = is_unconditional
37
+ # For CFG: reference to the corresponding conditional sequence (if this is unconditional)
38
+ # For conditional sequences, this points to the unconditional sequence
39
+ self.paired_seq = conditional_seq # For conditional seq, points to uncond; for uncond seq, points to cond
40
+ # For constrained decoding: logits processor and state update callback
41
+ self.logits_processor: Optional[Any] = sampling_params.logits_processor
42
+ self.logits_processor_update_state: Optional[Callable[[int], None]] = sampling_params.logits_processor_update_state
43
+
44
+ def __len__(self):
45
+ return self.num_tokens
46
+
47
+ def __getitem__(self, key):
48
+ return self.token_ids[key]
49
+
50
+ @property
51
+ def is_finished(self):
52
+ return self.status == SequenceStatus.FINISHED
53
+
54
+ @property
55
+ def num_completion_tokens(self):
56
+ return self.num_tokens - self.num_prompt_tokens
57
+
58
+ @property
59
+ def prompt_token_ids(self):
60
+ return self.token_ids[:self.num_prompt_tokens]
61
+
62
+ @property
63
+ def completion_token_ids(self):
64
+ return self.token_ids[self.num_prompt_tokens:]
65
+
66
+ @property
67
+ def num_cached_blocks(self):
68
+ return self.num_cached_tokens // self.block_size
69
+
70
+ @property
71
+ def num_blocks(self):
72
+ return (self.num_tokens + self.block_size - 1) // self.block_size
73
+
74
+ @property
75
+ def last_block_num_tokens(self):
76
+ return self.num_tokens - (self.num_blocks - 1) * self.block_size
77
+
78
+ def block(self, i):
79
+ assert 0 <= i < self.num_blocks
80
+ return self.token_ids[i*self.block_size: (i+1)*self.block_size]
81
+
82
+ def append_token(self, token_id: int):
83
+ self.token_ids.append(token_id)
84
+ self.last_token = token_id
85
+ self.num_tokens += 1
86
+
87
+ def __getstate__(self):
88
+ return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
89
+ self.token_ids if self.num_completion_tokens == 0 else self.last_token)
90
+
91
+ def __setstate__(self, state):
92
+ self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
93
+ if self.num_completion_tokens == 0:
94
+ self.token_ids = state[-1]
95
+ else:
96
+ self.last_token = state[-1]
code/acestep/third_parts/nano-vllm/nanovllm/layers/activation.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class SiluAndMul(nn.Module):
7
+
8
+ def __init__(self):
9
+ super().__init__()
10
+
11
+ @torch.compile
12
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
13
+ x, y = x.chunk(2, -1)
14
+ return F.silu(x) * y
code/acestep/third_parts/nano-vllm/nanovllm/layers/attention.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import triton
4
+ import triton.language as tl
5
+
6
+ from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
7
+ from nanovllm.utils.context import get_context
8
+
9
+
10
+ @triton.jit
11
+ def store_kvcache_kernel(
12
+ key_ptr,
13
+ key_stride,
14
+ value_ptr,
15
+ value_stride,
16
+ k_cache_ptr,
17
+ v_cache_ptr,
18
+ slot_mapping_ptr,
19
+ D: tl.constexpr,
20
+ ):
21
+ idx = tl.program_id(0)
22
+ slot = tl.load(slot_mapping_ptr + idx)
23
+ if slot == -1: return
24
+ key_offsets = idx * key_stride + tl.arange(0, D)
25
+ value_offsets = idx * value_stride + tl.arange(0, D)
26
+ key = tl.load(key_ptr + key_offsets)
27
+ value = tl.load(value_ptr + value_offsets)
28
+ cache_offsets = slot * D + tl.arange(0, D)
29
+ tl.store(k_cache_ptr + cache_offsets, key)
30
+ tl.store(v_cache_ptr + cache_offsets, value)
31
+
32
+
33
+ def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
34
+ N, num_heads, head_dim = key.shape
35
+ D = num_heads * head_dim
36
+ assert key.stride(-1) == 1 and value.stride(-1) == 1
37
+ assert key.stride(1) == head_dim and value.stride(1) == head_dim
38
+ assert k_cache.stride(1) == D and v_cache.stride(1) == D
39
+ assert slot_mapping.numel() == N
40
+ store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
41
+
42
+
43
+ class Attention(nn.Module):
44
+
45
+ def __init__(
46
+ self,
47
+ num_heads,
48
+ head_dim,
49
+ scale,
50
+ num_kv_heads,
51
+ ):
52
+ super().__init__()
53
+ self.num_heads = num_heads
54
+ self.head_dim = head_dim
55
+ self.scale = scale
56
+ self.num_kv_heads = num_kv_heads
57
+ self.k_cache = self.v_cache = torch.tensor([])
58
+
59
+ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
60
+ context = get_context()
61
+ k_cache, v_cache = self.k_cache, self.v_cache
62
+ if k_cache.numel() and v_cache.numel():
63
+ store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
64
+ if context.is_prefill:
65
+ if context.block_tables is not None: # prefix cache
66
+ k, v = k_cache, v_cache
67
+ o = flash_attn_varlen_func(q, k, v,
68
+ max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
69
+ max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
70
+ softmax_scale=self.scale, causal=True, block_table=context.block_tables)
71
+ else: # decode
72
+ o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
73
+ cache_seqlens=context.context_lens, block_table=context.block_tables,
74
+ softmax_scale=self.scale, causal=True)
75
+ return o
code/acestep/third_parts/nano-vllm/nanovllm/layers/embed_head.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ import torch.distributed as dist
5
+
6
+ from nanovllm.utils.context import get_context
7
+
8
+
9
+ class VocabParallelEmbedding(nn.Module):
10
+
11
+ def __init__(
12
+ self,
13
+ num_embeddings: int,
14
+ embedding_dim: int,
15
+ ):
16
+ super().__init__()
17
+ self.tp_rank = dist.get_rank()
18
+ self.tp_size = dist.get_world_size()
19
+ assert num_embeddings % self.tp_size == 0
20
+ self.num_embeddings = num_embeddings
21
+ self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
22
+ self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
23
+ self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
24
+ self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
25
+ self.weight.weight_loader = self.weight_loader
26
+
27
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
28
+ param_data = param.data
29
+ shard_size = param_data.size(0)
30
+ start_idx = self.tp_rank * shard_size
31
+ loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
32
+ param_data.copy_(loaded_weight)
33
+
34
+ def forward(self, x: torch.Tensor):
35
+ if self.tp_size > 1:
36
+ mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
37
+ x = mask * (x - self.vocab_start_idx)
38
+ y = F.embedding(x, self.weight)
39
+ if self.tp_size > 1:
40
+ y = mask.unsqueeze(1) * y
41
+ dist.all_reduce(y)
42
+ return y
43
+
44
+
45
+ class ParallelLMHead(VocabParallelEmbedding):
46
+
47
+ def __init__(
48
+ self,
49
+ num_embeddings: int,
50
+ embedding_dim: int,
51
+ bias: bool = False,
52
+ ):
53
+ assert not bias
54
+ super().__init__(num_embeddings, embedding_dim)
55
+
56
+ def forward(self, x: torch.Tensor):
57
+ context = get_context()
58
+ if context.is_prefill:
59
+ last_indices = context.cu_seqlens_q[1:] - 1
60
+ x = x[last_indices].contiguous()
61
+ logits = F.linear(x, self.weight)
62
+ if self.tp_size > 1:
63
+ all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
64
+ dist.gather(logits, all_logits, 0)
65
+ logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
66
+ return logits
code/acestep/third_parts/nano-vllm/nanovllm/layers/layernorm.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class RMSNorm(nn.Module):
6
+
7
+ def __init__(
8
+ self,
9
+ hidden_size: int,
10
+ eps: float = 1e-6,
11
+ ) -> None:
12
+ super().__init__()
13
+ self.eps = eps
14
+ self.weight = nn.Parameter(torch.ones(hidden_size))
15
+
16
+ @torch.compile
17
+ def rms_forward(
18
+ self,
19
+ x: torch.Tensor,
20
+ ) -> torch.Tensor:
21
+ orig_dtype = x.dtype
22
+ x = x.float()
23
+ var = x.pow(2).mean(dim=-1, keepdim=True)
24
+ x.mul_(torch.rsqrt(var + self.eps))
25
+ x = x.to(orig_dtype).mul_(self.weight)
26
+ return x
27
+
28
+ @torch.compile
29
+ def add_rms_forward(
30
+ self,
31
+ x: torch.Tensor,
32
+ residual: torch.Tensor,
33
+ ) -> tuple[torch.Tensor, torch.Tensor]:
34
+ orig_dtype = x.dtype
35
+ x = x.float().add_(residual.float())
36
+ residual = x.to(orig_dtype)
37
+ var = x.pow(2).mean(dim=-1, keepdim=True)
38
+ x.mul_(torch.rsqrt(var + self.eps))
39
+ x = x.to(orig_dtype).mul_(self.weight)
40
+ return x, residual
41
+
42
+ def forward(
43
+ self,
44
+ x: torch.Tensor,
45
+ residual: torch.Tensor | None = None,
46
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
47
+ if residual is None:
48
+ return self.rms_forward(x)
49
+ else:
50
+ return self.add_rms_forward(x, residual)
code/acestep/third_parts/nano-vllm/nanovllm/layers/linear.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ import torch.distributed as dist
5
+
6
+
7
+ def divide(numerator, denominator):
8
+ assert numerator % denominator == 0
9
+ return numerator // denominator
10
+
11
+
12
+ class LinearBase(nn.Module):
13
+
14
+ def __init__(
15
+ self,
16
+ input_size: int,
17
+ output_size: int,
18
+ bias: bool = False,
19
+ tp_dim: int | None = None,
20
+ ):
21
+ super().__init__()
22
+ self.tp_dim = tp_dim
23
+ self.tp_rank = dist.get_rank()
24
+ self.tp_size = dist.get_world_size()
25
+ self.weight = nn.Parameter(torch.empty(output_size, input_size))
26
+ self.weight.weight_loader = self.weight_loader
27
+ if bias:
28
+ self.bias = nn.Parameter(torch.empty(output_size))
29
+ self.bias.weight_loader = self.weight_loader
30
+ else:
31
+ self.register_parameter("bias", None)
32
+
33
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
34
+ raise NotImplementedError
35
+
36
+
37
+ class ReplicatedLinear(LinearBase):
38
+
39
+ def __init__(
40
+ self,
41
+ input_size: int,
42
+ output_size: int,
43
+ bias: bool = False,
44
+ ):
45
+ super().__init__(input_size, output_size, bias)
46
+
47
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
48
+ param.data.copy_(loaded_weight)
49
+
50
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
51
+ return F.linear(x, self.weight, self.bias)
52
+
53
+
54
+ class ColumnParallelLinear(LinearBase):
55
+
56
+ def __init__(
57
+ self,
58
+ input_size: int,
59
+ output_size: int,
60
+ bias: bool = False,
61
+ ):
62
+ tp_size = dist.get_world_size()
63
+ super().__init__(input_size, divide(output_size, tp_size), bias, 0)
64
+
65
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
66
+ param_data = param.data
67
+ shard_size = param_data.size(self.tp_dim)
68
+ start_idx = self.tp_rank * shard_size
69
+ loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
70
+ param_data.copy_(loaded_weight)
71
+
72
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
73
+ return F.linear(x, self.weight, self.bias)
74
+
75
+
76
+ class MergedColumnParallelLinear(ColumnParallelLinear):
77
+
78
+ def __init__(
79
+ self,
80
+ input_size: int,
81
+ output_sizes: list[int],
82
+ bias: bool = False,
83
+ ):
84
+ self.output_sizes = output_sizes
85
+ super().__init__(input_size, sum(output_sizes), bias)
86
+
87
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
88
+ param_data = param.data
89
+ shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
90
+ shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
91
+ param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
92
+ loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
93
+ param_data.copy_(loaded_weight)
94
+
95
+
96
+ class QKVParallelLinear(ColumnParallelLinear):
97
+
98
+ def __init__(
99
+ self,
100
+ hidden_size: int,
101
+ head_size: int,
102
+ total_num_heads: int,
103
+ total_num_kv_heads: int | None = None,
104
+ bias: bool = False,
105
+ ):
106
+ tp_size = dist.get_world_size()
107
+ total_num_kv_heads = total_num_kv_heads or total_num_heads
108
+ self.head_size = head_size
109
+ self.num_heads = divide(total_num_heads, tp_size)
110
+ self.num_kv_heads = divide(total_num_kv_heads, tp_size)
111
+ output_size = (total_num_heads + 2 * total_num_kv_heads) * self.head_size
112
+ super().__init__(hidden_size, output_size, bias)
113
+
114
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
115
+ param_data = param.data
116
+ assert loaded_shard_id in ["q", "k", "v"]
117
+ if loaded_shard_id == "q":
118
+ shard_size = self.num_heads * self.head_size
119
+ shard_offset = 0
120
+ elif loaded_shard_id == "k":
121
+ shard_size = self.num_kv_heads * self.head_size
122
+ shard_offset = self.num_heads * self.head_size
123
+ else:
124
+ shard_size = self.num_kv_heads * self.head_size
125
+ shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
126
+ param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
127
+ loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
128
+ param_data.copy_(loaded_weight)
129
+
130
+
131
+ class RowParallelLinear(LinearBase):
132
+
133
+ def __init__(
134
+ self,
135
+ input_size: int,
136
+ output_size: int,
137
+ bias: bool = False,
138
+ ):
139
+ tp_size = dist.get_world_size()
140
+ super().__init__(divide(input_size, tp_size), output_size, bias, 1)
141
+
142
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
143
+ param_data = param.data
144
+ shard_size = param_data.size(self.tp_dim)
145
+ start_idx = self.tp_rank * shard_size
146
+ loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
147
+ param_data.copy_(loaded_weight)
148
+
149
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
150
+ y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
151
+ if self.tp_size > 1:
152
+ dist.all_reduce(y)
153
+ return y