mrfakename commited on
Commit
acbeaa4
·
verified ·
1 Parent(s): 9f5c8f7

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. spaces/Ace-Step-v1.5/.env.example +4 -0
  3. spaces/Ace-Step-v1.5/.gitattributes +38 -0
  4. spaces/Ace-Step-v1.5/.gitignore +228 -0
  5. spaces/Ace-Step-v1.5/Dockerfile +57 -0
  6. spaces/Ace-Step-v1.5/LICENSE +21 -0
  7. spaces/Ace-Step-v1.5/README.md +229 -0
  8. spaces/Ace-Step-v1.5/acestep/__init__.py +1 -0
  9. spaces/Ace-Step-v1.5/acestep/acestep_v15_pipeline.py +303 -0
  10. spaces/Ace-Step-v1.5/acestep/api_server.py +1700 -0
  11. spaces/Ace-Step-v1.5/acestep/audio_utils.py +378 -0
  12. spaces/Ace-Step-v1.5/acestep/constants.py +109 -0
  13. spaces/Ace-Step-v1.5/acestep/constrained_logits_processor.py +0 -0
  14. spaces/Ace-Step-v1.5/acestep/dataset_handler.py +37 -0
  15. spaces/Ace-Step-v1.5/acestep/dit_alignment_score.py +870 -0
  16. spaces/Ace-Step-v1.5/acestep/genres_vocab.txt +0 -0
  17. spaces/Ace-Step-v1.5/acestep/gradio_ui/__init__.py +1 -0
  18. spaces/Ace-Step-v1.5/acestep/gradio_ui/events/__init__.py +1310 -0
  19. spaces/Ace-Step-v1.5/acestep/gradio_ui/events/generation_handlers.py +1054 -0
  20. spaces/Ace-Step-v1.5/acestep/gradio_ui/events/results_handlers.py +0 -0
  21. spaces/Ace-Step-v1.5/acestep/gradio_ui/events/training_handlers.py +644 -0
  22. spaces/Ace-Step-v1.5/acestep/gradio_ui/i18n.py +152 -0
  23. spaces/Ace-Step-v1.5/acestep/gradio_ui/i18n/en.json +245 -0
  24. spaces/Ace-Step-v1.5/acestep/gradio_ui/i18n/ja.json +245 -0
  25. spaces/Ace-Step-v1.5/acestep/gradio_ui/i18n/zh.json +245 -0
  26. spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/__init__.py +98 -0
  27. spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/dataset.py +101 -0
  28. spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/generation.py +693 -0
  29. spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/result.py +598 -0
  30. spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/training.py +562 -0
  31. spaces/Ace-Step-v1.5/acestep/handler.py +0 -0
  32. spaces/Ace-Step-v1.5/acestep/inference.py +1182 -0
  33. spaces/Ace-Step-v1.5/acestep/llm_inference.py +0 -0
  34. spaces/Ace-Step-v1.5/acestep/local_cache.py +129 -0
  35. spaces/Ace-Step-v1.5/acestep/test_time_scaling.py +410 -0
  36. spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/LICENSE +21 -0
  37. spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/README.md +66 -0
  38. spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/assets/logo.png +3 -0
  39. spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/bench.py +32 -0
  40. spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/example.py +33 -0
  41. spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/__init__.py +2 -0
  42. spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/config.py +26 -0
  43. spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/block_manager.py +112 -0
  44. spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py +124 -0
  45. spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py +529 -0
  46. spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/scheduler.py +222 -0
  47. spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/sequence.py +96 -0
  48. spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/layers/activation.py +14 -0
  49. spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/layers/attention.py +75 -0
  50. spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/layers/embed_head.py +66 -0
.gitattributes CHANGED
@@ -39,3 +39,9 @@ code/assets/acestudio_logo.png filter=lfs diff=lfs merge=lfs -text
39
  code/assets/application_map.png filter=lfs diff=lfs merge=lfs -text
40
  code/assets/model_zoo.png filter=lfs diff=lfs merge=lfs -text
41
  code/assets/orgnization_logos.png filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
39
  code/assets/application_map.png filter=lfs diff=lfs merge=lfs -text
40
  code/assets/model_zoo.png filter=lfs diff=lfs merge=lfs -text
41
  code/assets/orgnization_logos.png filter=lfs diff=lfs merge=lfs -text
42
+ spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/assets/logo.png filter=lfs diff=lfs merge=lfs -text
43
+ spaces/Ace-Step-v1.5/assets/ACE-Step_framework.png filter=lfs diff=lfs merge=lfs -text
44
+ spaces/Ace-Step-v1.5/assets/acestudio_logo.png filter=lfs diff=lfs merge=lfs -text
45
+ spaces/Ace-Step-v1.5/assets/application_map.png filter=lfs diff=lfs merge=lfs -text
46
+ spaces/Ace-Step-v1.5/assets/model_zoo.png filter=lfs diff=lfs merge=lfs -text
47
+ spaces/Ace-Step-v1.5/assets/orgnization_logos.png filter=lfs diff=lfs merge=lfs -text
spaces/Ace-Step-v1.5/.env.example ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ ACESTEP_CONFIG_PATH=acestep-v15-turbo
2
+ ACESTEP_LM_MODEL_PATH=acestep-5Hz-lm-1.7B
3
+ ACESTEP_DEVICE=auto
4
+ ACESTEP_LM_BACKEND=vllm
spaces/Ace-Step-v1.5/.gitattributes ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.jpg filter=lfs diff=lfs merge=lfs -text
38
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
spaces/Ace-Step-v1.5/.gitignore ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data/
2
+ *.mp3
3
+ *.wav
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[codz]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py.cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+ cover/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+
78
+ # PyBuilder
79
+ .pybuilder/
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ # For a library or package, you might want to ignore these files since the code is
91
+ # intended to run in multiple environments; otherwise, check them in:
92
+ # .python-version
93
+
94
+ # pipenv
95
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
97
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
98
+ # install all needed dependencies.
99
+ #Pipfile.lock
100
+
101
+ # UV
102
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
103
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
104
+ # commonly ignored for libraries.
105
+ uv.lock
106
+
107
+ # poetry
108
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
109
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
110
+ # commonly ignored for libraries.
111
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
112
+ #poetry.lock
113
+ #poetry.toml
114
+
115
+ # pdm
116
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
117
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
118
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
119
+ #pdm.lock
120
+ #pdm.toml
121
+ .pdm-python
122
+ .pdm-build/
123
+
124
+ # pixi
125
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
126
+ #pixi.lock
127
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
128
+ # in the .venv directory. It is recommended not to include this directory in version control.
129
+ .pixi
130
+
131
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
132
+ __pypackages__/
133
+
134
+ # Celery stuff
135
+ celerybeat-schedule
136
+ celerybeat.pid
137
+
138
+ # SageMath parsed files
139
+ *.sage.py
140
+
141
+ # Environments
142
+ .env
143
+ .envrc
144
+ .venv
145
+ env/
146
+ venv/
147
+ ENV/
148
+ env.bak/
149
+ venv.bak/
150
+
151
+ # Spyder project settings
152
+ .spyderproject
153
+ .spyproject
154
+
155
+ # Rope project settings
156
+ .ropeproject
157
+
158
+ # mkdocs documentation
159
+ /site
160
+
161
+ # mypy
162
+ .mypy_cache/
163
+ .dmypy.json
164
+ dmypy.json
165
+
166
+ # Pyre type checker
167
+ .pyre/
168
+
169
+ # pytype static type analyzer
170
+ .pytype/
171
+
172
+ # Cython debug symbols
173
+ cython_debug/
174
+
175
+ # PyCharm
176
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
177
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
178
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
179
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
180
+ #.idea/
181
+
182
+ # Abstra
183
+ # Abstra is an AI-powered process automation framework.
184
+ # Ignore directories containing user credentials, local state, and settings.
185
+ # Learn more at https://abstra.io/docs
186
+ .abstra/
187
+
188
+ # Visual Studio Code
189
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
190
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
191
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
192
+ # you could uncomment the following to ignore the entire vscode folder
193
+ # .vscode/
194
+
195
+ # Ruff stuff:
196
+ .ruff_cache/
197
+
198
+ # PyPI configuration file
199
+ .pypirc
200
+
201
+ # Cursor
202
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
203
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
204
+ # refer to https://docs.cursor.com/context/ignore-files
205
+ .cursorignore
206
+ .cursorindexingignore
207
+
208
+ # Marimo
209
+ marimo/_static/
210
+ marimo/_lsp/
211
+ __marimo__/
212
+ tests/
213
+ checkpoints/
214
+ playground.ipynb
215
+ .history/
216
+ upload_checkpoints.sh
217
+ checkpoints.7z
218
+ README_old.md
219
+ discord_bot/
220
+ feishu_bot/
221
+ tmp*
222
+ torchinductor_root/
223
+ scripts/
224
+ checkpoints_legacy/
225
+ lora_output/
226
+ datasets/
227
+ python_embeded/
228
+ checkpoints_pack/
spaces/Ace-Step-v1.5/Dockerfile ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HuggingFace Space Docker SDK
2
+ # Use slim Python image - HuggingFace GPU Spaces provide CUDA runtime
3
+ FROM python:3.11-slim
4
+
5
+ # Set environment variables
6
+ ENV PYTHONDONTWRITEBYTECODE=1 \
7
+ PYTHONUNBUFFERED=1 \
8
+ DEBIAN_FRONTEND=noninteractive \
9
+ TORCHAUDIO_USE_TORCHCODEC=0
10
+
11
+ # Install system dependencies
12
+ # build-essential is required for triton to compile CUDA kernels
13
+ # ffmpeg and libav* dev packages are required for torchaudio's ffmpeg backend
14
+ # Note: torchaudio's ffmpeg backend needs shared libraries, not just the ffmpeg binary
15
+ RUN apt-get update && \
16
+ apt-get install -y --no-install-recommends git libsndfile1 build-essential && \
17
+ apt-get install -y ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswresample-dev && \
18
+ rm -rf /var/lib/apt/lists/*
19
+
20
+ # Set up a new user named "user" with user ID 1000 (HuggingFace Space requirement)
21
+ RUN useradd -m -u 1000 user
22
+
23
+ # Create /data directory with proper permissions for persistent storage
24
+ RUN mkdir -p /data && chown user:user /data && chmod 755 /data
25
+
26
+ # Set environment variables for user
27
+ ENV HOME=/home/user \
28
+ PATH=/home/user/.local/bin:$PATH \
29
+ GRADIO_SERVER_NAME=0.0.0.0 \
30
+ GRADIO_SERVER_PORT=7860
31
+
32
+ # Set the working directory
33
+ WORKDIR $HOME/app
34
+
35
+ # Copy requirements first for better Docker layer caching
36
+ COPY --chown=user:user requirements.txt .
37
+
38
+ # Copy the local nano-vllm package
39
+ COPY --chown=user:user acestep/third_parts/nano-vllm ./acestep/third_parts/nano-vllm
40
+
41
+ # Switch to user before installing packages
42
+ USER user
43
+
44
+ # Install dependencies from requirements.txt (includes PyTorch with CUDA from --extra-index-url)
45
+ RUN pip install --no-cache-dir --user -r requirements.txt
46
+
47
+ # Install nano-vllm with --no-deps since all dependencies are already installed
48
+ RUN pip install --no-deps ./acestep/third_parts/nano-vllm
49
+
50
+ # Copy the rest of the application
51
+ COPY --chown=user:user . .
52
+
53
+ # Expose port
54
+ EXPOSE 7860
55
+
56
+ # Run the application
57
+ CMD ["python", "app.py"]
spaces/Ace-Step-v1.5/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 ACEStep
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
spaces/Ace-Step-v1.5/README.md ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: ACE-Step v1.5
3
+ emoji: 🎵
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: docker
7
+ app_port: 7860
8
+ pinned: false
9
+ license: mit
10
+ short_description: Music Generation Foundation Model v1.5
11
+ ---
12
+
13
+ <h1 align="center">ACE-Step 1.5</h1>
14
+ <h1 align="center">Pushing the Boundaries of Open-Source Music Generation</h1>
15
+ <p align="center">
16
+ <a href="https://ace-step.github.io/ace-step-v1.5.github.io/">Project</a> |
17
+ <a href="https://huggingface.co/collections/ACE-Step/ace-step-15">Hugging Face</a> |
18
+ <a href="https://modelscope.cn/models/ACE-Step/ACE-Step-v1-5">ModelScope</a> |
19
+ <a href="https://huggingface.co/spaces/ACE-Step/Ace-Step-v1.5">Space Demo</a> |
20
+ <a href="https://discord.gg/PeWDxrkdj7">Discord</a> |
21
+ <a href="https://arxiv.org/abs/2506.00045">Technical Report</a>
22
+ </p>
23
+
24
+ <p align="center">
25
+ <img src="./assets/orgnization_logos.png" width="100%" alt="StepFun Logo">
26
+ </p>
27
+
28
+ ## Table of Contents
29
+
30
+ - [✨ Features](#-features)
31
+ - [📦 Installation](#-installation)
32
+ - [🚀 Usage](#-usage)
33
+ - [🔨 Train](#-train)
34
+ - [🏗️ Architecture](#️-architecture)
35
+ - [🦁 Model Zoo](#-model-zoo)
36
+
37
+ ## 📝 Abstract
38
+ We present ACE-Step v1.5, a highly efficient foundation model that democratizes commercial-grade music production on consumer hardware. Optimized for local deployment (<4GB VRAM), the model accelerates generation by over 100× compared to traditional pure LM architectures, producing superior high-fidelity audio in seconds characterized by coherent semantics and exceptional melodies. At its core lies a novel hybrid architecture where the Language Model (LM) functions as an omni-capable planner: it transforms simple user queries into comprehensive song blueprints—scaling from short loops to 10-minute compositions—while synthesizing metadata, lyrics, and captions via Chain-of-Thought to guide the Diffusion Transformer (DiT). Uniquely, this alignment is achieved through intrinsic reinforcement learning relying solely on the model’s internal mechanisms, thereby eliminating the biases inherent in external reward models or human preferences. Beyond standard synthesis, ACE-Step v1.5 unifies precise stylistic control with versatile editing capabilities—such as cover generation, repainting, and vocal-to-BGM conversion—while maintaining strict adherence to prompts across 50+ languages.
39
+
40
+
41
+ ## ✨ Features
42
+
43
+ <p align="center">
44
+ <img src="./assets/application_map.png" width="100%" alt="ACE-Step Framework">
45
+ </p>
46
+
47
+ ### ⚡ Performance
48
+ - ✅ **Ultra-Fast Generation** — 0.5s to 10s generation time on A100 (depending on think mode & diffusion steps)
49
+ - ✅ **Flexible Duration** — Supports 10 seconds to 10 minutes (600s) audio generation
50
+ - ✅ **Batch Generation** — Generate up to 8 songs simultaneously
51
+
52
+ ### 🎵 Generation Quality
53
+ - ✅ **Commercial-Grade Output** — Quality between Suno v4.5 and Suno v5
54
+ - ✅ **Rich Style Support** — 1000+ instruments and styles with fine-grained timbre description
55
+ - ✅ **Multi-Language Lyrics** — Supports 50+ languages with lyrics prompt for structure & style control
56
+
57
+ ### 🎛️ Versatility & Control
58
+
59
+ | Feature | Description |
60
+ |---------|-------------|
61
+ | ✅ Reference Audio Input | Use reference audio to guide generation style |
62
+ | ✅ Cover Generation | Create covers from existing audio |
63
+ | ✅ Repaint & Edit | Selective local audio editing and regeneration |
64
+ | ✅ Track Separation | Separate audio into individual stems |
65
+ | ✅ Multi-Track Generation | Add layers like Suno Studio's "Add Layer" feature |
66
+ | ✅ Vocal2BGM | Auto-generate accompaniment for vocal tracks |
67
+ | ✅ Metadata Control | Control duration, BPM, key/scale, time signature |
68
+ | ✅ Simple Mode | Generate full songs from simple descriptions |
69
+ | ✅ Query Rewriting | Auto LM expansion of tags and lyrics |
70
+ | ✅ Audio Understanding | Extract BPM, key/scale, time signature & caption from audio |
71
+ | ✅ LRC Generation | Auto-generate lyric timestamps for generated music |
72
+ | ✅ LoRA Training | One-click annotation & training in Gradio. 8 songs, 1 hour on 3090 (12GB VRAM) |
73
+ | ✅ Quality Scoring | Automatic quality assessment for generated audio |
74
+
75
+
76
+
77
+ ## 📦 Installation
78
+
79
+ > **Requirements:** Python 3.11, CUDA GPU recommended (works on CPU/MPS but slower)
80
+
81
+ ### 1. Install uv (Package Manager)
82
+
83
+ ```bash
84
+ # macOS / Linux
85
+ curl -LsSf https://astral.sh/uv/install.sh | sh
86
+
87
+ # Windows (PowerShell)
88
+ powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
89
+ ```
90
+
91
+ ### 2. Clone & Install
92
+
93
+ ```bash
94
+ git clone https://github.com/ACE-Step/ACE-Step-1.5.git
95
+ cd ACE-Step-1.5
96
+ uv sync
97
+ ```
98
+
99
+ ### 3. Launch
100
+
101
+ #### 🖥️ Gradio Web UI (Recommended)
102
+
103
+ ```bash
104
+ uv run acestep
105
+ ```
106
+
107
+ Open http://localhost:7860 in your browser. Models will be downloaded automatically on first run.
108
+
109
+ #### 🌐 REST API Server
110
+
111
+ ```bash
112
+ uv run acestep-api
113
+ ```
114
+
115
+ API runs at http://localhost:8001. See [API Documentation](./docs/en/API.md) for endpoints.
116
+
117
+ ### Command Line Options
118
+
119
+ **Gradio UI (`acestep`):**
120
+
121
+ | Option | Default | Description |
122
+ |--------|---------|-------------|
123
+ | `--port` | 7860 | Server port |
124
+ | `--server-name` | 127.0.0.1 | Server address (use `0.0.0.0` for network access) |
125
+ | `--share` | false | Create public Gradio link |
126
+ | `--language` | en | UI language: `en`, `zh`, `ja` |
127
+ | `--init_service` | false | Auto-initialize models on startup |
128
+ | `--config_path` | auto | DiT model (e.g., `acestep-v15-turbo`, `acestep-v15-turbo-shift3`) |
129
+ | `--lm_model_path` | auto | LM model (e.g., `acestep-5Hz-lm-0.6B`, `acestep-5Hz-lm-1.7B`) |
130
+ | `--offload_to_cpu` | auto | CPU offload (auto-enabled if VRAM < 16GB) |
131
+
132
+ **Examples:**
133
+
134
+ ```bash
135
+ # Public access with Chinese UI
136
+ uv run acestep --server-name 0.0.0.0 --share --language zh
137
+
138
+ # Pre-initialize models on startup
139
+ uv run acestep --init_service true --config_path acestep-v15-turbo
140
+ ```
141
+
142
+ ### Development
143
+
144
+ ```bash
145
+ # Add dependencies
146
+ uv add package-name
147
+ uv add --dev package-name
148
+
149
+ # Update all dependencies
150
+ uv sync --upgrade
151
+ ```
152
+
153
+ ## 🚀 Usage
154
+
155
+ We provide multiple ways to use ACE-Step:
156
+
157
+ | Method | Description | Documentation |
158
+ |--------|-------------|---------------|
159
+ | 🖥️ **Gradio Web UI** | Interactive web interface for music generation | [Gradio Guide](./docs/en/GRADIO_GUIDE.md) |
160
+ | 🐍 **Python API** | Programmatic access for integration | [Inference API](./docs/en/INFERENCE.md) |
161
+ | 🌐 **REST API** | HTTP-based async API for services | [REST API](./docs/en/API.md) |
162
+
163
+ **📚 Documentation available in:** [English](./docs/en/) | [中文](./docs/zh/) | [日本語](./docs/ja/)
164
+
165
+
166
+ ## 🔨 Train
167
+
168
+ See the **LoRA Training** tab in Gradio UI for one-click training, or check [Gradio Guide - LoRA Training](./docs/en/GRADIO_GUIDE.md#lora-training) for details.
169
+
170
+ ## 🏗️ Architecture
171
+
172
+ <p align="center">
173
+ <img src="./assets/ACE-Step_framework.png" width="100%" alt="ACE-Step Framework">
174
+ </p>
175
+
176
+ ## 🦁 Model Zoo
177
+
178
+ <p align="center">
179
+ <img src="./assets/model_zoo.png" width="100%" alt="Model Zoo">
180
+ </p>
181
+
182
+ ### DiT Models
183
+
184
+ | DiT Model | Pre-Training | SFT | RL | CFG | Step | Refer audio | Text2Music | Cover | Repaint | Extract | Lego | Complete | Quality | Diversity | Fine-Tunability | Hugging Face |
185
+ |-----------|:------------:|:---:|:--:|:---:|:----:|:-----------:|:----------:|:-----:|:-------:|:-------:|:----:|:--------:|:-------:|:---------:|:---------------:|--------------|
186
+ | `acestep-v15-base` | ✅ | ❌ | ❌ | ✅ | 50 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | Medium | High | Easy | [Link](https://huggingface.co/ACE-Step/acestep-v15-base) |
187
+ | `acestep-v15-sft` | ✅ | ✅ | ❌ | ✅ | 50 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | High | Medium | Easy | [Link](https://huggingface.co/ACE-Step/acestep-v15-sft) |
188
+ | `acestep-v15-turbo` | ✅ | ✅ | ❌ | ❌ | 8 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | Very High | Medium | Medium | [Link](https://huggingface.co/ACE-Step/Ace-Step1.5) |
189
+ | `acestep-v15-turbo-rl` | ✅ | ✅ | ✅ | ❌ | 8 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | Very High | Medium | Medium | To be released |
190
+
191
+ ### LM Models
192
+
193
+ | LM Model | Pretrain from | Pre-Training | SFT | RL | CoT metas | Query rewrite | Audio Understanding | Composition Capability | Copy Melody | Hugging Face |
194
+ |----------|---------------|:------------:|:---:|:--:|:---------:|:-------------:|:-------------------:|:----------------------:|:-----------:|--------------|
195
+ | `acestep-5Hz-lm-0.6B` | Qwen3-0.6B | ✅ | ✅ | ✅ | ✅ | ✅ | Medium | Medium | Weak | ✅ |
196
+ | `acestep-5Hz-lm-1.7B` | Qwen3-1.7B | ✅ | ✅ | ✅ | ✅ | ✅ | Medium | Medium | Medium | ✅ |
197
+ | `acestep-5Hz-lm-4B` | Qwen3-4B | ✅ | ✅ | ✅ | ✅ | ✅ | Strong | Strong | Strong | To be released |
198
+
199
+ ## 📜 License & Disclaimer
200
+
201
+ This project is licensed under [MIT](./LICENSE)
202
+
203
+ ACE-Step enables original music generation across diverse genres, with applications in creative production, education, and entertainment. While designed to support positive and artistic use cases, we acknowledge potential risks such as unintentional copyright infringement due to stylistic similarity, inappropriate blending of cultural elements, and misuse for generating harmful content. To ensure responsible use, we encourage users to verify the originality of generated works, clearly disclose AI involvement, and obtain appropriate permissions when adapting protected styles or materials. By using ACE-Step, you agree to uphold these principles and respect artistic integrity, cultural diversity, and legal compliance. The authors are not responsible for any misuse of the model, including but not limited to copyright violations, cultural insensitivity, or the generation of harmful content.
204
+
205
+ 🔔 Important Notice
206
+ The only official website for the ACE-Step project is our GitHub Pages site.
207
+ We do not operate any other websites.
208
+ 🚫 Fake domains include but are not limited to:
209
+ ac\*\*p.com, a\*\*p.org, a\*\*\*c.org
210
+ ⚠️ Please be cautious. Do not visit, trust, or make payments on any of those sites.
211
+
212
+ ## 🙏 Acknowledgements
213
+
214
+ This project is co-led by ACE Studio and StepFun.
215
+
216
+
217
+ ## 📖 Citation
218
+
219
+ If you find this project useful for your research, please consider citing:
220
+
221
+ ```BibTeX
222
+ @misc{gong2026acestep,
223
+ title={ACE-Step 1.5: Pushing the Boundaries of Open-Source Music Generation},
224
+ author={Junmin Gong, Song Yulin, Wenxiao Zhao, Sen Wang, Shengyuan Xu, Jing Guo},
225
+ howpublished={\url{https://github.com/ace-step/ACE-Step-1.5}},
226
+ year={2026},
227
+ note={GitHub repository}
228
+ }
229
+ ```
spaces/Ace-Step-v1.5/acestep/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """ACE-Step package."""
spaces/Ace-Step-v1.5/acestep/acestep_v15_pipeline.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ACE-Step V1.5 Pipeline
3
+ Handler wrapper connecting model and UI
4
+ """
5
+ import os
6
+ import sys
7
+
8
+ # Load environment variables from .env file in project root
9
+ # This allows configuration without hardcoding values
10
+ # Falls back to .env.example if .env is not found
11
+ try:
12
+ from dotenv import load_dotenv
13
+ # Get project root directory
14
+ _current_file = os.path.abspath(__file__)
15
+ _project_root = os.path.dirname(os.path.dirname(_current_file))
16
+ _env_path = os.path.join(_project_root, '.env')
17
+ _env_example_path = os.path.join(_project_root, '.env.example')
18
+
19
+ if os.path.exists(_env_path):
20
+ load_dotenv(_env_path)
21
+ print(f"Loaded configuration from {_env_path}")
22
+ elif os.path.exists(_env_example_path):
23
+ load_dotenv(_env_example_path)
24
+ print(f"Loaded configuration from {_env_example_path} (fallback)")
25
+ except ImportError:
26
+ # python-dotenv not installed, skip loading .env
27
+ pass
28
+
29
+ # Clear proxy settings that may affect Gradio
30
+ for proxy_var in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY']:
31
+ os.environ.pop(proxy_var, None)
32
+
33
+ try:
34
+ # When executed as a module: `python -m acestep.acestep_v15_pipeline`
35
+ from .handler import AceStepHandler
36
+ from .llm_inference import LLMHandler
37
+ from .dataset_handler import DatasetHandler
38
+ from .gradio_ui import create_gradio_interface
39
+ except ImportError:
40
+ # When executed as a script: `python acestep/acestep_v15_pipeline.py`
41
+ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
42
+ if project_root not in sys.path:
43
+ sys.path.insert(0, project_root)
44
+ from acestep.handler import AceStepHandler
45
+ from acestep.llm_inference import LLMHandler
46
+ from acestep.dataset_handler import DatasetHandler
47
+ from acestep.gradio_ui import create_gradio_interface
48
+
49
+
50
+ def create_demo(init_params=None, language='en'):
51
+ """
52
+ Create Gradio demo interface
53
+
54
+ Args:
55
+ init_params: Dictionary containing initialization parameters and state.
56
+ If None, service will not be pre-initialized.
57
+ Keys: 'pre_initialized' (bool), 'checkpoint', 'config_path', 'device',
58
+ 'init_llm', 'lm_model_path', 'backend', 'use_flash_attention',
59
+ 'offload_to_cpu', 'offload_dit_to_cpu', 'init_status',
60
+ 'dit_handler', 'llm_handler' (initialized handlers if pre-initialized),
61
+ 'language' (UI language code)
62
+ language: UI language code ('en', 'zh', 'ja', default: 'en')
63
+
64
+ Returns:
65
+ Gradio Blocks instance
66
+ """
67
+ # Get persistent storage path from init_params (for HuggingFace Space)
68
+ persistent_storage_path = None
69
+ if init_params:
70
+ persistent_storage_path = init_params.get('persistent_storage_path')
71
+
72
+ # Use pre-initialized handlers if available, otherwise create new ones
73
+ if init_params and init_params.get('pre_initialized') and 'dit_handler' in init_params:
74
+ dit_handler = init_params['dit_handler']
75
+ llm_handler = init_params['llm_handler']
76
+ else:
77
+ dit_handler = AceStepHandler(persistent_storage_path=persistent_storage_path)
78
+ llm_handler = LLMHandler(persistent_storage_path=persistent_storage_path)
79
+
80
+ dataset_handler = DatasetHandler() # Dataset handler
81
+
82
+ # Create Gradio interface with all handlers and initialization parameters
83
+ demo = create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=init_params, language=language)
84
+
85
+ return demo
86
+
87
+
88
+ def get_gpu_memory_gb():
89
+ """
90
+ Get GPU memory in GB. Returns 0 if no GPU is available.
91
+ """
92
+ try:
93
+ import torch
94
+ if torch.cuda.is_available():
95
+ # Get total memory of the first GPU in GB
96
+ total_memory = torch.cuda.get_device_properties(0).total_memory
97
+ memory_gb = total_memory / (1024**3) # Convert bytes to GB
98
+ return memory_gb
99
+ else:
100
+ return 0
101
+ except Exception as e:
102
+ print(f"Warning: Failed to detect GPU memory: {e}", file=sys.stderr)
103
+ return 0
104
+
105
+
106
+ def main():
107
+ """Main entry function"""
108
+ import argparse
109
+
110
+ # Detect GPU memory to auto-configure offload settings
111
+ gpu_memory_gb = get_gpu_memory_gb()
112
+ auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < 16
113
+
114
+ if auto_offload:
115
+ print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (< 16GB)")
116
+ print("Auto-enabling CPU offload to reduce GPU memory usage")
117
+ elif gpu_memory_gb > 0:
118
+ print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (>= 16GB)")
119
+ print("CPU offload disabled by default")
120
+ else:
121
+ print("No GPU detected, running on CPU")
122
+
123
+ parser = argparse.ArgumentParser(description="Gradio Demo for ACE-Step V1.5")
124
+ parser.add_argument("--port", type=int, default=7860, help="Port to run the gradio server on")
125
+ parser.add_argument("--share", action="store_true", help="Create a public link")
126
+ parser.add_argument("--debug", action="store_true", help="Enable debug mode")
127
+ parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Server name (default: 127.0.0.1, use 0.0.0.0 for all interfaces)")
128
+ parser.add_argument("--language", type=str, default="en", choices=["en", "zh", "ja"], help="UI language: en (English), zh (中文), ja (日本語)")
129
+
130
+ # Service mode argument
131
+ parser.add_argument("--service_mode", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False,
132
+ help="Enable service mode (default: False). When enabled, uses preset models and restricts UI options.")
133
+
134
+ # Service initialization arguments
135
+ parser.add_argument("--init_service", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Initialize service on startup (default: False)")
136
+ parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint file path (optional, for display purposes)")
137
+ parser.add_argument("--config_path", type=str, default=None, help="Main model path (e.g., 'acestep-v15-turbo')")
138
+ parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"], help="Processing device (default: auto)")
139
+ parser.add_argument("--init_llm", type=lambda x: x.lower() in ['true', '1', 'yes'], default=True, help="Initialize 5Hz LM (default: True)")
140
+ parser.add_argument("--lm_model_path", type=str, default=None, help="5Hz LM model path (e.g., 'acestep-5Hz-lm-0.6B')")
141
+ parser.add_argument("--backend", type=str, default="vllm", choices=["vllm", "pt"], help="5Hz LM backend (default: vllm)")
142
+ parser.add_argument("--use_flash_attention", type=lambda x: x.lower() in ['true', '1', 'yes'], default=None, help="Use flash attention (default: auto-detect)")
143
+ parser.add_argument("--offload_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=auto_offload, help=f"Offload models to CPU (default: {'True' if auto_offload else 'False'}, auto-detected based on GPU VRAM)")
144
+ parser.add_argument("--offload_dit_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Offload DiT to CPU (default: False)")
145
+
146
+ args = parser.parse_args()
147
+
148
+ # Service mode defaults (can be configured via .env file)
149
+ if args.service_mode:
150
+ print("Service mode enabled - applying preset configurations...")
151
+ # Force init_service in service mode
152
+ args.init_service = True
153
+ # Default DiT model for service mode (from env or fallback)
154
+ if args.config_path is None:
155
+ args.config_path = os.environ.get(
156
+ "SERVICE_MODE_DIT_MODEL",
157
+ "acestep-v15-turbo-fix-inst-shift-dynamic"
158
+ )
159
+ # Default LM model for service mode (from env or fallback)
160
+ if args.lm_model_path is None:
161
+ args.lm_model_path = os.environ.get(
162
+ "SERVICE_MODE_LM_MODEL",
163
+ "acestep-5Hz-lm-1.7B-v4-fix"
164
+ )
165
+ # Backend for service mode (from env or fallback to vllm)
166
+ args.backend = os.environ.get("SERVICE_MODE_BACKEND", "vllm")
167
+ print(f" DiT model: {args.config_path}")
168
+ print(f" LM model: {args.lm_model_path}")
169
+ print(f" Backend: {args.backend}")
170
+
171
+ try:
172
+ init_params = None
173
+
174
+ # If init_service is True, perform initialization before creating UI
175
+ if args.init_service:
176
+ print("Initializing service from command line...")
177
+
178
+ # Create handler instances for initialization
179
+ dit_handler = AceStepHandler()
180
+ llm_handler = LLMHandler()
181
+
182
+ # Auto-select config_path if not provided
183
+ if args.config_path is None:
184
+ available_models = dit_handler.get_available_acestep_v15_models()
185
+ if available_models:
186
+ args.config_path = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else available_models[0]
187
+ print(f"Auto-selected config_path: {args.config_path}")
188
+ else:
189
+ print("Error: No available models found. Please specify --config_path", file=sys.stderr)
190
+ sys.exit(1)
191
+
192
+ # Get project root (same logic as in handler)
193
+ current_file = os.path.abspath(__file__)
194
+ project_root = os.path.dirname(os.path.dirname(current_file))
195
+
196
+ # Determine flash attention setting
197
+ use_flash_attention = args.use_flash_attention
198
+ if use_flash_attention is None:
199
+ use_flash_attention = dit_handler.is_flash_attention_available()
200
+
201
+ # Initialize DiT handler
202
+ print(f"Initializing DiT model: {args.config_path} on {args.device}...")
203
+ init_status, enable_generate = dit_handler.initialize_service(
204
+ project_root=project_root,
205
+ config_path=args.config_path,
206
+ device=args.device,
207
+ use_flash_attention=use_flash_attention,
208
+ compile_model=False,
209
+ offload_to_cpu=args.offload_to_cpu,
210
+ offload_dit_to_cpu=args.offload_dit_to_cpu
211
+ )
212
+
213
+ if not enable_generate:
214
+ print(f"Error initializing DiT model: {init_status}", file=sys.stderr)
215
+ sys.exit(1)
216
+
217
+ print(f"DiT model initialized successfully")
218
+
219
+ # Initialize LM handler if requested
220
+ lm_status = ""
221
+ if args.init_llm:
222
+ if args.lm_model_path is None:
223
+ # Try to get default LM model
224
+ available_lm_models = llm_handler.get_available_5hz_lm_models()
225
+ if available_lm_models:
226
+ args.lm_model_path = available_lm_models[0]
227
+ print(f"Using default LM model: {args.lm_model_path}")
228
+ else:
229
+ print("Warning: No LM models available, skipping LM initialization", file=sys.stderr)
230
+ args.init_llm = False
231
+
232
+ if args.init_llm and args.lm_model_path:
233
+ checkpoint_dir = os.path.join(project_root, "checkpoints")
234
+ print(f"Initializing 5Hz LM: {args.lm_model_path} on {args.device}...")
235
+ lm_status, lm_success = llm_handler.initialize(
236
+ checkpoint_dir=checkpoint_dir,
237
+ lm_model_path=args.lm_model_path,
238
+ backend=args.backend,
239
+ device=args.device,
240
+ offload_to_cpu=args.offload_to_cpu,
241
+ dtype=dit_handler.dtype
242
+ )
243
+
244
+ if lm_success:
245
+ print(f"5Hz LM initialized successfully")
246
+ init_status += f"\n{lm_status}"
247
+ else:
248
+ print(f"Warning: 5Hz LM initialization failed: {lm_status}", file=sys.stderr)
249
+ init_status += f"\n{lm_status}"
250
+
251
+ # Prepare initialization parameters for UI
252
+ init_params = {
253
+ 'pre_initialized': True,
254
+ 'service_mode': args.service_mode,
255
+ 'checkpoint': args.checkpoint,
256
+ 'config_path': args.config_path,
257
+ 'device': args.device,
258
+ 'init_llm': args.init_llm,
259
+ 'lm_model_path': args.lm_model_path,
260
+ 'backend': args.backend,
261
+ 'use_flash_attention': use_flash_attention,
262
+ 'offload_to_cpu': args.offload_to_cpu,
263
+ 'offload_dit_to_cpu': args.offload_dit_to_cpu,
264
+ 'init_status': init_status,
265
+ 'enable_generate': enable_generate,
266
+ 'dit_handler': dit_handler,
267
+ 'llm_handler': llm_handler,
268
+ 'language': args.language
269
+ }
270
+
271
+ print("Service initialization completed successfully!")
272
+
273
+ # Create and launch demo
274
+ print(f"Creating Gradio interface with language: {args.language}...")
275
+ demo = create_demo(init_params=init_params, language=args.language)
276
+
277
+ # Enable queue for multi-user support
278
+ # This ensures proper request queuing and prevents concurrent generation conflicts
279
+ print("Enabling queue for multi-user support...")
280
+ demo.queue(
281
+ max_size=20, # Maximum queue size (adjust based on your needs)
282
+ status_update_rate="auto", # Update rate for queue status
283
+ )
284
+
285
+ print(f"Launching server on {args.server_name}:{args.port}...")
286
+ demo.launch(
287
+ server_name=args.server_name,
288
+ server_port=args.port,
289
+ share=args.share,
290
+ debug=args.debug,
291
+ show_error=True,
292
+ prevent_thread_lock=False, # Keep thread locked to maintain server running
293
+ inbrowser=False, # Don't auto-open browser
294
+ )
295
+ except Exception as e:
296
+ print(f"Error launching Gradio: {e}", file=sys.stderr)
297
+ import traceback
298
+ traceback.print_exc()
299
+ sys.exit(1)
300
+
301
+
302
+ if __name__ == "__main__":
303
+ main()
spaces/Ace-Step-v1.5/acestep/api_server.py ADDED
@@ -0,0 +1,1700 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI server for ACE-Step V1.5.
2
+
3
+ Endpoints:
4
+ - POST /release_task Create music generation task
5
+ - POST /query_result Batch query task results
6
+ - POST /v1/music/random Create random sample task
7
+ - GET /v1/models List available models
8
+ - GET /v1/audio Download audio file
9
+ - GET /health Health check
10
+
11
+ NOTE:
12
+ - In-memory queue and job store -> run uvicorn with workers=1.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import asyncio
18
+ import json
19
+ import os
20
+ import sys
21
+ import time
22
+ import traceback
23
+ import tempfile
24
+ import urllib.parse
25
+ from collections import deque
26
+ from concurrent.futures import ThreadPoolExecutor
27
+ from contextlib import asynccontextmanager
28
+ from dataclasses import dataclass
29
+ from pathlib import Path
30
+ from threading import Lock
31
+ from typing import Any, Dict, List, Literal, Optional
32
+ from uuid import uuid4
33
+
34
+ try:
35
+ from dotenv import load_dotenv
36
+ except ImportError: # Optional dependency
37
+ load_dotenv = None # type: ignore
38
+
39
+ from fastapi import FastAPI, HTTPException, Request
40
+ from pydantic import BaseModel, Field
41
+ from starlette.datastructures import UploadFile as StarletteUploadFile
42
+
43
+ from acestep.handler import AceStepHandler
44
+ from acestep.llm_inference import LLMHandler
45
+ from acestep.constants import (
46
+ DEFAULT_DIT_INSTRUCTION,
47
+ DEFAULT_LM_INSTRUCTION,
48
+ TASK_INSTRUCTIONS,
49
+ )
50
+ from acestep.inference import (
51
+ GenerationParams,
52
+ GenerationConfig,
53
+ generate_music,
54
+ create_sample,
55
+ format_sample,
56
+ )
57
+ from acestep.gradio_ui.events.results_handlers import _build_generation_info
58
+
59
+
60
+ # =============================================================================
61
+ # Constants
62
+ # =============================================================================
63
+
64
+ RESULT_KEY_PREFIX = "ace_step_v1.5_"
65
+ RESULT_EXPIRE_SECONDS = 7 * 24 * 60 * 60 # 7 days
66
+ TASK_TIMEOUT_SECONDS = 3600 # 1 hour
67
+ STATUS_MAP = {"queued": 0, "running": 0, "succeeded": 1, "failed": 2}
68
+
69
+ LM_DEFAULT_TEMPERATURE = 0.85
70
+ LM_DEFAULT_CFG_SCALE = 2.5
71
+ LM_DEFAULT_TOP_P = 0.9
72
+
73
+ # Parameter aliases for request parsing
74
+ PARAM_ALIASES = {
75
+ "prompt": ["prompt"],
76
+ "sample_mode": ["sample_mode", "sampleMode"],
77
+ "sample_query": ["sample_query", "sampleQuery", "description", "desc"],
78
+ "use_format": ["use_format", "useFormat", "format"],
79
+ "model": ["model", "dit_model", "ditModel"],
80
+ "key_scale": ["key_scale", "keyscale", "keyScale"],
81
+ "time_signature": ["time_signature", "timesignature", "timeSignature"],
82
+ "audio_duration": ["audio_duration", "duration", "audioDuration", "target_duration", "targetDuration"],
83
+ "vocal_language": ["vocal_language", "vocalLanguage"],
84
+ "inference_steps": ["inference_steps", "inferenceSteps"],
85
+ "guidance_scale": ["guidance_scale", "guidanceScale"],
86
+ "use_random_seed": ["use_random_seed", "useRandomSeed"],
87
+ "audio_code_string": ["audio_code_string", "audioCodeString"],
88
+ "audio_cover_strength": ["audio_cover_strength", "audioCoverStrength"],
89
+ "task_type": ["task_type", "taskType"],
90
+ "infer_method": ["infer_method", "inferMethod"],
91
+ "use_tiled_decode": ["use_tiled_decode", "useTiledDecode"],
92
+ "constrained_decoding": ["constrained_decoding", "constrainedDecoding", "constrained"],
93
+ "constrained_decoding_debug": ["constrained_decoding_debug", "constrainedDecodingDebug"],
94
+ "use_cot_caption": ["use_cot_caption", "cot_caption", "cot-caption"],
95
+ "use_cot_language": ["use_cot_language", "cot_language", "cot-language"],
96
+ "is_format_caption": ["is_format_caption", "isFormatCaption"],
97
+ }
98
+
99
+
100
+ def _parse_description_hints(description: str) -> tuple[Optional[str], bool]:
101
+ """
102
+ Parse a description string to extract language code and instrumental flag.
103
+
104
+ This function analyzes user descriptions like "Pop rock. English" or "piano solo"
105
+ to detect:
106
+ - Language: Maps language names to ISO codes (e.g., "English" -> "en")
107
+ - Instrumental: Detects patterns indicating instrumental/no-vocal music
108
+
109
+ Args:
110
+ description: User's natural language music description
111
+
112
+ Returns:
113
+ (language_code, is_instrumental) tuple:
114
+ - language_code: ISO language code (e.g., "en", "zh") or None if not detected
115
+ - is_instrumental: True if description indicates instrumental music
116
+ """
117
+ import re
118
+
119
+ if not description:
120
+ return None, False
121
+
122
+ description_lower = description.lower().strip()
123
+
124
+ # Language mapping: input patterns -> ISO code
125
+ language_mapping = {
126
+ 'english': 'en', 'en': 'en',
127
+ 'chinese': 'zh', '中文': 'zh', 'zh': 'zh', 'mandarin': 'zh',
128
+ 'japanese': 'ja', '日本語': 'ja', 'ja': 'ja',
129
+ 'korean': 'ko', '한국어': 'ko', 'ko': 'ko',
130
+ 'spanish': 'es', 'español': 'es', 'es': 'es',
131
+ 'french': 'fr', 'français': 'fr', 'fr': 'fr',
132
+ 'german': 'de', 'deutsch': 'de', 'de': 'de',
133
+ 'italian': 'it', 'italiano': 'it', 'it': 'it',
134
+ 'portuguese': 'pt', 'português': 'pt', 'pt': 'pt',
135
+ 'russian': 'ru', 'русский': 'ru', 'ru': 'ru',
136
+ 'bengali': 'bn', 'bn': 'bn',
137
+ 'hindi': 'hi', 'hi': 'hi',
138
+ 'arabic': 'ar', 'ar': 'ar',
139
+ 'thai': 'th', 'th': 'th',
140
+ 'vietnamese': 'vi', 'vi': 'vi',
141
+ 'indonesian': 'id', 'id': 'id',
142
+ 'turkish': 'tr', 'tr': 'tr',
143
+ 'dutch': 'nl', 'nl': 'nl',
144
+ 'polish': 'pl', 'pl': 'pl',
145
+ }
146
+
147
+ # Detect language
148
+ detected_language = None
149
+ for lang_name, lang_code in language_mapping.items():
150
+ if len(lang_name) <= 2:
151
+ pattern = r'(?:^|\s|[.,;:!?])' + re.escape(lang_name) + r'(?:$|\s|[.,;:!?])'
152
+ else:
153
+ pattern = r'\b' + re.escape(lang_name) + r'\b'
154
+
155
+ if re.search(pattern, description_lower):
156
+ detected_language = lang_code
157
+ break
158
+
159
+ # Detect instrumental
160
+ is_instrumental = False
161
+ if 'instrumental' in description_lower:
162
+ is_instrumental = True
163
+ elif 'pure music' in description_lower or 'pure instrument' in description_lower:
164
+ is_instrumental = True
165
+ elif description_lower.endswith(' solo') or description_lower == 'solo':
166
+ is_instrumental = True
167
+
168
+ return detected_language, is_instrumental
169
+
170
+
171
+ JobStatus = Literal["queued", "running", "succeeded", "failed"]
172
+
173
+
174
+ class GenerateMusicRequest(BaseModel):
175
+ prompt: str = Field(default="", description="Text prompt describing the music")
176
+ lyrics: str = Field(default="", description="Lyric text")
177
+
178
+ # New API semantics:
179
+ # - thinking=True: use 5Hz LM to generate audio codes (lm-dit behavior)
180
+ # - thinking=False: do not use LM to generate codes (dit behavior)
181
+ # Regardless of thinking, if some metas are missing, server may use LM to fill them.
182
+ thinking: bool = False
183
+ # Sample-mode requests auto-generate caption/lyrics/metas via LM (no user prompt).
184
+ sample_mode: bool = False
185
+ # Description for sample mode: auto-generate caption/lyrics from description query
186
+ sample_query: str = Field(default="", description="Query/description for sample mode (use create_sample)")
187
+ # Whether to use format_sample() to enhance input caption/lyrics
188
+ use_format: bool = Field(default=False, description="Use format_sample() to enhance input (default: False)")
189
+ # Model name for multi-model support (select which DiT model to use)
190
+ model: Optional[str] = Field(default=None, description="Model name to use (e.g., 'acestep-v15-turbo')")
191
+
192
+ bpm: Optional[int] = None
193
+ # Accept common client keys via manual parsing (see RequestParser).
194
+ key_scale: str = ""
195
+ time_signature: str = ""
196
+ vocal_language: str = "en"
197
+ inference_steps: int = 8
198
+ guidance_scale: float = 7.0
199
+ use_random_seed: bool = True
200
+ seed: int = -1
201
+
202
+ reference_audio_path: Optional[str] = None
203
+ src_audio_path: Optional[str] = None
204
+ audio_duration: Optional[float] = None
205
+ batch_size: Optional[int] = None
206
+
207
+ audio_code_string: str = ""
208
+
209
+ repainting_start: float = 0.0
210
+ repainting_end: Optional[float] = None
211
+
212
+ instruction: str = DEFAULT_DIT_INSTRUCTION
213
+ audio_cover_strength: float = 1.0
214
+ task_type: str = "text2music"
215
+
216
+ use_adg: bool = False
217
+ cfg_interval_start: float = 0.0
218
+ cfg_interval_end: float = 1.0
219
+ infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
220
+ shift: float = Field(
221
+ default=3.0,
222
+ description="Timestep shift factor (range 1.0~5.0, default 3.0). Only effective for base models, not turbo models."
223
+ )
224
+ timesteps: Optional[str] = Field(
225
+ default=None,
226
+ description="Custom timesteps (comma-separated, e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference_steps and shift."
227
+ )
228
+
229
+ audio_format: str = "mp3"
230
+ use_tiled_decode: bool = True
231
+
232
+ # 5Hz LM (server-side): used for metadata completion and (when thinking=True) codes generation.
233
+ lm_model_path: Optional[str] = None # e.g. "acestep-5Hz-lm-0.6B"
234
+ lm_backend: Literal["vllm", "pt"] = "vllm"
235
+
236
+ constrained_decoding: bool = True
237
+ constrained_decoding_debug: bool = False
238
+ use_cot_caption: bool = True
239
+ use_cot_language: bool = True
240
+ is_format_caption: bool = False
241
+
242
+ lm_temperature: float = 0.85
243
+ lm_cfg_scale: float = 2.5
244
+ lm_top_k: Optional[int] = None
245
+ lm_top_p: Optional[float] = 0.9
246
+ lm_repetition_penalty: float = 1.0
247
+ lm_negative_prompt: str = "NO USER INPUT"
248
+
249
+ class Config:
250
+ allow_population_by_field_name = True
251
+ allow_population_by_alias = True
252
+
253
+
254
+ class CreateJobResponse(BaseModel):
255
+ task_id: str
256
+ status: JobStatus
257
+ queue_position: int = 0 # 1-based best-effort position when queued
258
+
259
+
260
+ class JobResult(BaseModel):
261
+ first_audio_path: Optional[str] = None
262
+ second_audio_path: Optional[str] = None
263
+ audio_paths: list[str] = Field(default_factory=list)
264
+
265
+ generation_info: str = ""
266
+ status_message: str = ""
267
+ seed_value: str = ""
268
+
269
+ metas: Dict[str, Any] = Field(default_factory=dict)
270
+ bpm: Optional[int] = None
271
+ duration: Optional[float] = None
272
+ genres: Optional[str] = None
273
+ keyscale: Optional[str] = None
274
+ timesignature: Optional[str] = None
275
+
276
+ # Model information
277
+ lm_model: Optional[str] = None
278
+ dit_model: Optional[str] = None
279
+
280
+
281
+ class JobResponse(BaseModel):
282
+ job_id: str
283
+ status: JobStatus
284
+ created_at: float
285
+ started_at: Optional[float] = None
286
+ finished_at: Optional[float] = None
287
+
288
+ # queue observability
289
+ queue_position: int = 0
290
+ eta_seconds: Optional[float] = None
291
+ avg_job_seconds: Optional[float] = None
292
+
293
+ result: Optional[JobResult] = None
294
+ error: Optional[str] = None
295
+
296
+
297
+ @dataclass
298
+ class _JobRecord:
299
+ job_id: str
300
+ status: JobStatus
301
+ created_at: float
302
+ started_at: Optional[float] = None
303
+ finished_at: Optional[float] = None
304
+ result: Optional[Dict[str, Any]] = None
305
+ error: Optional[str] = None
306
+ env: str = "development"
307
+
308
+
309
+ class _JobStore:
310
+ def __init__(self) -> None:
311
+ self._lock = Lock()
312
+ self._jobs: Dict[str, _JobRecord] = {}
313
+
314
+ def create(self) -> _JobRecord:
315
+ job_id = str(uuid4())
316
+ rec = _JobRecord(job_id=job_id, status="queued", created_at=time.time())
317
+ with self._lock:
318
+ self._jobs[job_id] = rec
319
+ return rec
320
+
321
+ def create_with_id(self, job_id: str, env: str = "development") -> _JobRecord:
322
+ """Create job record with specified ID"""
323
+ rec = _JobRecord(
324
+ job_id=job_id,
325
+ status="queued",
326
+ created_at=time.time(),
327
+ env=env
328
+ )
329
+ with self._lock:
330
+ self._jobs[job_id] = rec
331
+ return rec
332
+
333
+ def get(self, job_id: str) -> Optional[_JobRecord]:
334
+ with self._lock:
335
+ return self._jobs.get(job_id)
336
+
337
+ def mark_running(self, job_id: str) -> None:
338
+ with self._lock:
339
+ rec = self._jobs[job_id]
340
+ rec.status = "running"
341
+ rec.started_at = time.time()
342
+
343
+ def mark_succeeded(self, job_id: str, result: Dict[str, Any]) -> None:
344
+ with self._lock:
345
+ rec = self._jobs[job_id]
346
+ rec.status = "succeeded"
347
+ rec.finished_at = time.time()
348
+ rec.result = result
349
+ rec.error = None
350
+
351
+ def mark_failed(self, job_id: str, error: str) -> None:
352
+ with self._lock:
353
+ rec = self._jobs[job_id]
354
+ rec.status = "failed"
355
+ rec.finished_at = time.time()
356
+ rec.result = None
357
+ rec.error = error
358
+
359
+
360
+ def _env_bool(name: str, default: bool) -> bool:
361
+ v = os.getenv(name)
362
+ if v is None:
363
+ return default
364
+ return v.strip().lower() in {"1", "true", "yes", "y", "on"}
365
+
366
+
367
+ def _get_project_root() -> str:
368
+ current_file = os.path.abspath(__file__)
369
+ return os.path.dirname(os.path.dirname(current_file))
370
+
371
+
372
+ def _get_model_name(config_path: str) -> str:
373
+ """
374
+ Extract model name from config_path.
375
+
376
+ Args:
377
+ config_path: Path like "acestep-v15-turbo" or "/path/to/acestep-v15-turbo"
378
+
379
+ Returns:
380
+ Model name (last directory name from config_path)
381
+ """
382
+ if not config_path:
383
+ return ""
384
+ normalized = config_path.rstrip("/\\")
385
+ return os.path.basename(normalized)
386
+
387
+
388
+ def _load_project_env() -> None:
389
+ if load_dotenv is None:
390
+ return
391
+ try:
392
+ project_root = _get_project_root()
393
+ env_path = os.path.join(project_root, ".env")
394
+ if os.path.exists(env_path):
395
+ load_dotenv(env_path, override=False)
396
+ except Exception:
397
+ # Optional best-effort: continue even if .env loading fails.
398
+ pass
399
+
400
+
401
+ _load_project_env()
402
+
403
+
404
+ def _to_int(v: Any, default: Optional[int] = None) -> Optional[int]:
405
+ if v is None:
406
+ return default
407
+ if isinstance(v, int):
408
+ return v
409
+ s = str(v).strip()
410
+ if s == "":
411
+ return default
412
+ try:
413
+ return int(s)
414
+ except Exception:
415
+ return default
416
+
417
+
418
+ def _to_float(v: Any, default: Optional[float] = None) -> Optional[float]:
419
+ if v is None:
420
+ return default
421
+ if isinstance(v, float):
422
+ return v
423
+ s = str(v).strip()
424
+ if s == "":
425
+ return default
426
+ try:
427
+ return float(s)
428
+ except Exception:
429
+ return default
430
+
431
+
432
+ def _to_bool(v: Any, default: bool = False) -> bool:
433
+ if v is None:
434
+ return default
435
+ if isinstance(v, bool):
436
+ return v
437
+ s = str(v).strip().lower()
438
+ if s == "":
439
+ return default
440
+ return s in {"1", "true", "yes", "y", "on"}
441
+
442
+
443
+ def _map_status(status: str) -> int:
444
+ """Map job status string to integer code."""
445
+ return STATUS_MAP.get(status, 2)
446
+
447
+
448
+ def _parse_timesteps(s: Optional[str]) -> Optional[List[float]]:
449
+ """Parse comma-separated timesteps string to list of floats."""
450
+ if not s or not s.strip():
451
+ return None
452
+ try:
453
+ return [float(t.strip()) for t in s.split(",") if t.strip()]
454
+ except (ValueError, Exception):
455
+ return None
456
+
457
+
458
+ class RequestParser:
459
+ """Parse request parameters from multiple sources with alias support."""
460
+
461
+ def __init__(self, raw: dict):
462
+ self._raw = dict(raw) if raw else {}
463
+ self._param_obj = self._parse_json(self._raw.get("param_obj"))
464
+ self._metas = self._find_metas()
465
+
466
+ def _parse_json(self, v) -> dict:
467
+ if isinstance(v, dict):
468
+ return v
469
+ if isinstance(v, str) and v.strip():
470
+ try:
471
+ return json.loads(v)
472
+ except Exception:
473
+ pass
474
+ return {}
475
+
476
+ def _find_metas(self) -> dict:
477
+ for key in ("metas", "meta", "metadata", "user_metadata", "userMetadata"):
478
+ v = self._raw.get(key)
479
+ if v:
480
+ return self._parse_json(v)
481
+ return {}
482
+
483
+ def get(self, name: str, default=None):
484
+ """Get parameter by canonical name from all sources."""
485
+ aliases = PARAM_ALIASES.get(name, [name])
486
+ for source in (self._raw, self._param_obj, self._metas):
487
+ for alias in aliases:
488
+ v = source.get(alias)
489
+ if v is not None:
490
+ return v
491
+ return default
492
+
493
+ def str(self, name: str, default: str = "") -> str:
494
+ v = self.get(name)
495
+ return str(v) if v is not None else default
496
+
497
+ def int(self, name: str, default: Optional[int] = None) -> Optional[int]:
498
+ return _to_int(self.get(name), default)
499
+
500
+ def float(self, name: str, default: Optional[float] = None) -> Optional[float]:
501
+ return _to_float(self.get(name), default)
502
+
503
+ def bool(self, name: str, default: bool = False) -> bool:
504
+ return _to_bool(self.get(name), default)
505
+
506
+
507
+ async def _save_upload_to_temp(upload: StarletteUploadFile, *, prefix: str) -> str:
508
+ suffix = Path(upload.filename or "").suffix
509
+ fd, path = tempfile.mkstemp(prefix=f"{prefix}_", suffix=suffix)
510
+ os.close(fd)
511
+ try:
512
+ with open(path, "wb") as f:
513
+ while True:
514
+ chunk = await upload.read(1024 * 1024)
515
+ if not chunk:
516
+ break
517
+ f.write(chunk)
518
+ except Exception:
519
+ try:
520
+ os.remove(path)
521
+ except Exception:
522
+ pass
523
+ raise
524
+ finally:
525
+ try:
526
+ await upload.close()
527
+ except Exception:
528
+ pass
529
+ return path
530
+
531
+
532
+ def create_app() -> FastAPI:
533
+ store = _JobStore()
534
+
535
+ QUEUE_MAXSIZE = int(os.getenv("ACESTEP_QUEUE_MAXSIZE", "200"))
536
+ WORKER_COUNT = int(os.getenv("ACESTEP_QUEUE_WORKERS", "1")) # Single GPU recommended
537
+
538
+ INITIAL_AVG_JOB_SECONDS = float(os.getenv("ACESTEP_AVG_JOB_SECONDS", "5.0"))
539
+ AVG_WINDOW = int(os.getenv("ACESTEP_AVG_WINDOW", "50"))
540
+
541
+ def _path_to_audio_url(path: str) -> str:
542
+ """Convert local file path to downloadable relative URL"""
543
+ if not path:
544
+ return path
545
+ if path.startswith("http://") or path.startswith("https://"):
546
+ return path
547
+ encoded_path = urllib.parse.quote(path, safe="")
548
+ return f"/v1/audio?path={encoded_path}"
549
+
550
+ @asynccontextmanager
551
+ async def lifespan(app: FastAPI):
552
+ # Clear proxy env that may affect downstream libs
553
+ for proxy_var in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]:
554
+ os.environ.pop(proxy_var, None)
555
+
556
+ # Ensure compilation/temp caches do not fill up small default /tmp.
557
+ # Triton/Inductor (and the system compiler) can create large temporary files.
558
+ project_root = _get_project_root()
559
+ cache_root = os.path.join(project_root, ".cache", "acestep")
560
+ tmp_root = (os.getenv("ACESTEP_TMPDIR") or os.path.join(cache_root, "tmp")).strip()
561
+ triton_cache_root = (os.getenv("TRITON_CACHE_DIR") or os.path.join(cache_root, "triton")).strip()
562
+ inductor_cache_root = (os.getenv("TORCHINDUCTOR_CACHE_DIR") or os.path.join(cache_root, "torchinductor")).strip()
563
+
564
+ for p in [cache_root, tmp_root, triton_cache_root, inductor_cache_root]:
565
+ try:
566
+ os.makedirs(p, exist_ok=True)
567
+ except Exception:
568
+ # Best-effort: do not block startup if directory creation fails.
569
+ pass
570
+
571
+ # Respect explicit user overrides; if ACESTEP_TMPDIR is set, it should win.
572
+ if os.getenv("ACESTEP_TMPDIR"):
573
+ os.environ["TMPDIR"] = tmp_root
574
+ os.environ["TEMP"] = tmp_root
575
+ os.environ["TMP"] = tmp_root
576
+ else:
577
+ os.environ.setdefault("TMPDIR", tmp_root)
578
+ os.environ.setdefault("TEMP", tmp_root)
579
+ os.environ.setdefault("TMP", tmp_root)
580
+
581
+ os.environ.setdefault("TRITON_CACHE_DIR", triton_cache_root)
582
+ os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", inductor_cache_root)
583
+
584
+ handler = AceStepHandler()
585
+ llm_handler = LLMHandler()
586
+ init_lock = asyncio.Lock()
587
+ app.state._initialized = False
588
+ app.state._init_error = None
589
+ app.state._init_lock = init_lock
590
+
591
+ app.state.llm_handler = llm_handler
592
+ app.state._llm_initialized = False
593
+ app.state._llm_init_error = None
594
+ app.state._llm_init_lock = Lock()
595
+
596
+ # Multi-model support: secondary DiT handlers
597
+ handler2 = None
598
+ handler3 = None
599
+ config_path2 = os.getenv("ACESTEP_CONFIG_PATH2", "").strip()
600
+ config_path3 = os.getenv("ACESTEP_CONFIG_PATH3", "").strip()
601
+
602
+ if config_path2:
603
+ handler2 = AceStepHandler()
604
+ if config_path3:
605
+ handler3 = AceStepHandler()
606
+
607
+ app.state.handler2 = handler2
608
+ app.state.handler3 = handler3
609
+ app.state._initialized2 = False
610
+ app.state._initialized3 = False
611
+ app.state._config_path = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo")
612
+ app.state._config_path2 = config_path2
613
+ app.state._config_path3 = config_path3
614
+
615
+ max_workers = int(os.getenv("ACESTEP_API_WORKERS", "1"))
616
+ executor = ThreadPoolExecutor(max_workers=max_workers)
617
+
618
+ # Queue & observability
619
+ app.state.job_queue = asyncio.Queue(maxsize=QUEUE_MAXSIZE) # (job_id, req)
620
+ app.state.pending_ids = deque() # queued job_ids
621
+ app.state.pending_lock = asyncio.Lock()
622
+
623
+ # temp files per job (from multipart uploads)
624
+ app.state.job_temp_files = {} # job_id -> list[path]
625
+ app.state.job_temp_files_lock = asyncio.Lock()
626
+
627
+ # stats
628
+ app.state.stats_lock = asyncio.Lock()
629
+ app.state.recent_durations = deque(maxlen=AVG_WINDOW)
630
+ app.state.avg_job_seconds = INITIAL_AVG_JOB_SECONDS
631
+
632
+ app.state.handler = handler
633
+ app.state.executor = executor
634
+ app.state.job_store = store
635
+ app.state._python_executable = sys.executable
636
+
637
+ # Temporary directory for saving generated audio files
638
+ app.state.temp_audio_dir = os.path.join(tmp_root, "api_audio")
639
+ os.makedirs(app.state.temp_audio_dir, exist_ok=True)
640
+
641
+ # Initialize local cache
642
+ try:
643
+ from acestep.local_cache import get_local_cache
644
+ local_cache_dir = os.path.join(cache_root, "local_redis")
645
+ app.state.local_cache = get_local_cache(local_cache_dir)
646
+ except ImportError:
647
+ app.state.local_cache = None
648
+
649
+ async def _ensure_initialized() -> None:
650
+ h: AceStepHandler = app.state.handler
651
+
652
+ if getattr(app.state, "_initialized", False):
653
+ return
654
+ if getattr(app.state, "_init_error", None):
655
+ raise RuntimeError(app.state._init_error)
656
+
657
+ async with app.state._init_lock:
658
+ if getattr(app.state, "_initialized", False):
659
+ return
660
+ if getattr(app.state, "_init_error", None):
661
+ raise RuntimeError(app.state._init_error)
662
+
663
+ project_root = _get_project_root()
664
+ config_path = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo")
665
+ device = os.getenv("ACESTEP_DEVICE", "auto")
666
+
667
+ use_flash_attention = _env_bool("ACESTEP_USE_FLASH_ATTENTION", True)
668
+ offload_to_cpu = _env_bool("ACESTEP_OFFLOAD_TO_CPU", False)
669
+ offload_dit_to_cpu = _env_bool("ACESTEP_OFFLOAD_DIT_TO_CPU", False)
670
+
671
+ # Initialize primary model
672
+ status_msg, ok = h.initialize_service(
673
+ project_root=project_root,
674
+ config_path=config_path,
675
+ device=device,
676
+ use_flash_attention=use_flash_attention,
677
+ compile_model=False,
678
+ offload_to_cpu=offload_to_cpu,
679
+ offload_dit_to_cpu=offload_dit_to_cpu,
680
+ )
681
+ if not ok:
682
+ app.state._init_error = status_msg
683
+ raise RuntimeError(status_msg)
684
+ app.state._initialized = True
685
+
686
+ # Initialize secondary model if configured
687
+ if app.state.handler2 and app.state._config_path2:
688
+ try:
689
+ status_msg2, ok2 = app.state.handler2.initialize_service(
690
+ project_root=project_root,
691
+ config_path=app.state._config_path2,
692
+ device=device,
693
+ use_flash_attention=use_flash_attention,
694
+ compile_model=False,
695
+ offload_to_cpu=offload_to_cpu,
696
+ offload_dit_to_cpu=offload_dit_to_cpu,
697
+ )
698
+ app.state._initialized2 = ok2
699
+ if ok2:
700
+ print(f"[API Server] Secondary model loaded: {_get_model_name(app.state._config_path2)}")
701
+ else:
702
+ print(f"[API Server] Warning: Secondary model failed to load: {status_msg2}")
703
+ except Exception as e:
704
+ print(f"[API Server] Warning: Failed to initialize secondary model: {e}")
705
+ app.state._initialized2 = False
706
+
707
+ # Initialize third model if configured
708
+ if app.state.handler3 and app.state._config_path3:
709
+ try:
710
+ status_msg3, ok3 = app.state.handler3.initialize_service(
711
+ project_root=project_root,
712
+ config_path=app.state._config_path3,
713
+ device=device,
714
+ use_flash_attention=use_flash_attention,
715
+ compile_model=False,
716
+ offload_to_cpu=offload_to_cpu,
717
+ offload_dit_to_cpu=offload_dit_to_cpu,
718
+ )
719
+ app.state._initialized3 = ok3
720
+ if ok3:
721
+ print(f"[API Server] Third model loaded: {_get_model_name(app.state._config_path3)}")
722
+ else:
723
+ print(f"[API Server] Warning: Third model failed to load: {status_msg3}")
724
+ except Exception as e:
725
+ print(f"[API Server] Warning: Failed to initialize third model: {e}")
726
+ app.state._initialized3 = False
727
+
728
+ async def _cleanup_job_temp_files(job_id: str) -> None:
729
+ async with app.state.job_temp_files_lock:
730
+ paths = app.state.job_temp_files.pop(job_id, [])
731
+ for p in paths:
732
+ try:
733
+ os.remove(p)
734
+ except Exception:
735
+ pass
736
+
737
+ def _update_local_cache(job_id: str, result: Optional[Dict], status: str) -> None:
738
+ """Update local cache with job result"""
739
+ local_cache = getattr(app.state, 'local_cache', None)
740
+ if not local_cache:
741
+ return
742
+
743
+ rec = store.get(job_id)
744
+ env = getattr(rec, 'env', 'development') if rec else 'development'
745
+ create_time = rec.created_at if rec else time.time()
746
+
747
+ status_int = _map_status(status)
748
+
749
+ if status == "succeeded" and result:
750
+ audio_paths = result.get("audio_paths", [])
751
+ # Final prompt/lyrics (may be modified by thinking/format)
752
+ final_prompt = result.get("prompt", "")
753
+ final_lyrics = result.get("lyrics", "")
754
+ # Original user input from metas
755
+ metas_raw = result.get("metas", {}) or {}
756
+ original_prompt = metas_raw.get("prompt", "")
757
+ original_lyrics = metas_raw.get("lyrics", "")
758
+ # metas contains original input + other metadata
759
+ metas = {
760
+ "bpm": metas_raw.get("bpm"),
761
+ "duration": metas_raw.get("duration"),
762
+ "genres": metas_raw.get("genres", ""),
763
+ "keyscale": metas_raw.get("keyscale", ""),
764
+ "timesignature": metas_raw.get("timesignature", ""),
765
+ "prompt": original_prompt,
766
+ "lyrics": original_lyrics,
767
+ }
768
+ # Extra fields for Discord bot
769
+ generation_info = result.get("generation_info", "")
770
+ seed_value = result.get("seed_value", "")
771
+ lm_model = result.get("lm_model", "")
772
+ dit_model = result.get("dit_model", "")
773
+
774
+ if audio_paths:
775
+ result_data = [
776
+ {
777
+ "file": p,
778
+ "wave": "",
779
+ "status": status_int,
780
+ "create_time": int(create_time),
781
+ "env": env,
782
+ "prompt": final_prompt,
783
+ "lyrics": final_lyrics,
784
+ "metas": metas,
785
+ "generation_info": generation_info,
786
+ "seed_value": seed_value,
787
+ "lm_model": lm_model,
788
+ "dit_model": dit_model,
789
+ }
790
+ for p in audio_paths
791
+ ]
792
+ else:
793
+ result_data = [{
794
+ "file": "",
795
+ "wave": "",
796
+ "status": status_int,
797
+ "create_time": int(create_time),
798
+ "env": env,
799
+ "prompt": final_prompt,
800
+ "lyrics": final_lyrics,
801
+ "metas": metas,
802
+ "generation_info": generation_info,
803
+ "seed_value": seed_value,
804
+ "lm_model": lm_model,
805
+ "dit_model": dit_model,
806
+ }]
807
+ else:
808
+ result_data = [{"file": "", "wave": "", "status": status_int, "create_time": int(create_time), "env": env}]
809
+
810
+ result_key = f"{RESULT_KEY_PREFIX}{job_id}"
811
+ local_cache.set(result_key, result_data, ex=RESULT_EXPIRE_SECONDS)
812
+
813
+ async def _run_one_job(job_id: str, req: GenerateMusicRequest) -> None:
814
+ job_store: _JobStore = app.state.job_store
815
+ llm: LLMHandler = app.state.llm_handler
816
+ executor: ThreadPoolExecutor = app.state.executor
817
+
818
+ await _ensure_initialized()
819
+ job_store.mark_running(job_id)
820
+
821
+ # Select DiT handler based on user's model choice
822
+ # Default: use primary handler
823
+ selected_handler: AceStepHandler = app.state.handler
824
+ selected_model_name = _get_model_name(app.state._config_path)
825
+
826
+ if req.model:
827
+ model_matched = False
828
+
829
+ # Check if it matches the second model
830
+ if app.state.handler2 and getattr(app.state, "_initialized2", False):
831
+ model2_name = _get_model_name(app.state._config_path2)
832
+ if req.model == model2_name:
833
+ selected_handler = app.state.handler2
834
+ selected_model_name = model2_name
835
+ model_matched = True
836
+ print(f"[API Server] Job {job_id}: Using second model: {model2_name}")
837
+
838
+ # Check if it matches the third model
839
+ if not model_matched and app.state.handler3 and getattr(app.state, "_initialized3", False):
840
+ model3_name = _get_model_name(app.state._config_path3)
841
+ if req.model == model3_name:
842
+ selected_handler = app.state.handler3
843
+ selected_model_name = model3_name
844
+ model_matched = True
845
+ print(f"[API Server] Job {job_id}: Using third model: {model3_name}")
846
+
847
+ if not model_matched:
848
+ available_models = [_get_model_name(app.state._config_path)]
849
+ if app.state.handler2 and getattr(app.state, "_initialized2", False):
850
+ available_models.append(_get_model_name(app.state._config_path2))
851
+ if app.state.handler3 and getattr(app.state, "_initialized3", False):
852
+ available_models.append(_get_model_name(app.state._config_path3))
853
+ print(f"[API Server] Job {job_id}: Model '{req.model}' not found in {available_models}, using primary: {selected_model_name}")
854
+
855
+ # Use selected handler for generation
856
+ h: AceStepHandler = selected_handler
857
+
858
+ def _blocking_generate() -> Dict[str, Any]:
859
+ """Generate music using unified inference logic from acestep.inference"""
860
+
861
+ def _ensure_llm_ready() -> None:
862
+ """Ensure LLM handler is initialized when needed"""
863
+ with app.state._llm_init_lock:
864
+ initialized = getattr(app.state, "_llm_initialized", False)
865
+ had_error = getattr(app.state, "_llm_init_error", None)
866
+ if initialized or had_error is not None:
867
+ return
868
+
869
+ project_root = _get_project_root()
870
+ checkpoint_dir = os.path.join(project_root, "checkpoints")
871
+ lm_model_path = (req.lm_model_path or os.getenv("ACESTEP_LM_MODEL_PATH") or "acestep-5Hz-lm-0.6B").strip()
872
+ backend = (req.lm_backend or os.getenv("ACESTEP_LM_BACKEND") or "vllm").strip().lower()
873
+ if backend not in {"vllm", "pt"}:
874
+ backend = "vllm"
875
+
876
+ lm_device = os.getenv("ACESTEP_LM_DEVICE", os.getenv("ACESTEP_DEVICE", "auto"))
877
+ lm_offload = _env_bool("ACESTEP_LM_OFFLOAD_TO_CPU", False)
878
+
879
+ status, ok = llm.initialize(
880
+ checkpoint_dir=checkpoint_dir,
881
+ lm_model_path=lm_model_path,
882
+ backend=backend,
883
+ device=lm_device,
884
+ offload_to_cpu=lm_offload,
885
+ dtype=h.dtype,
886
+ )
887
+ if not ok:
888
+ app.state._llm_init_error = status
889
+ else:
890
+ app.state._llm_initialized = True
891
+
892
+ def _normalize_metas(meta: Dict[str, Any]) -> Dict[str, Any]:
893
+ """Ensure a stable `metas` dict (keys always present)."""
894
+ meta = meta or {}
895
+ out: Dict[str, Any] = dict(meta)
896
+
897
+ # Normalize key aliases
898
+ if "keyscale" not in out and "key_scale" in out:
899
+ out["keyscale"] = out.get("key_scale")
900
+ if "timesignature" not in out and "time_signature" in out:
901
+ out["timesignature"] = out.get("time_signature")
902
+
903
+ # Ensure required keys exist
904
+ for k in ["bpm", "duration", "genres", "keyscale", "timesignature"]:
905
+ if out.get(k) in (None, ""):
906
+ out[k] = "N/A"
907
+ return out
908
+
909
+ # Normalize LM sampling parameters
910
+ lm_top_k = req.lm_top_k if req.lm_top_k and req.lm_top_k > 0 else 0
911
+ lm_top_p = req.lm_top_p if req.lm_top_p and req.lm_top_p < 1.0 else 0.9
912
+
913
+ # Determine if LLM is needed
914
+ thinking = bool(req.thinking)
915
+ sample_mode = bool(req.sample_mode)
916
+ has_sample_query = bool(req.sample_query and req.sample_query.strip())
917
+ use_format = bool(req.use_format)
918
+ use_cot_caption = bool(req.use_cot_caption)
919
+ use_cot_language = bool(req.use_cot_language)
920
+
921
+ # LLM is needed for:
922
+ # - thinking mode (LM generates audio codes)
923
+ # - sample_mode (LM generates random caption/lyrics/metas)
924
+ # - sample_query/description (LM generates from description)
925
+ # - use_format (LM enhances caption/lyrics)
926
+ # - use_cot_caption or use_cot_language (LM enhances metadata)
927
+ need_llm = thinking or sample_mode or has_sample_query or use_format or use_cot_caption or use_cot_language
928
+
929
+ # Ensure LLM is ready if needed
930
+ if need_llm:
931
+ _ensure_llm_ready()
932
+ if getattr(app.state, "_llm_init_error", None):
933
+ raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
934
+
935
+ # Handle sample mode or description: generate caption/lyrics/metas via LM
936
+ caption = req.prompt
937
+ lyrics = req.lyrics
938
+ bpm = req.bpm
939
+ key_scale = req.key_scale
940
+ time_signature = req.time_signature
941
+ audio_duration = req.audio_duration
942
+
943
+ # Save original user input for metas
944
+ original_prompt = req.prompt or ""
945
+ original_lyrics = req.lyrics or ""
946
+
947
+ if sample_mode or has_sample_query:
948
+ if has_sample_query:
949
+ # Use create_sample() with description query
950
+ parsed_language, parsed_instrumental = _parse_description_hints(req.sample_query)
951
+
952
+ # Determine vocal_language with priority:
953
+ # 1. User-specified vocal_language (if not default "en")
954
+ # 2. Language parsed from description
955
+ # 3. None (no constraint)
956
+ if req.vocal_language and req.vocal_language not in ("en", "unknown", ""):
957
+ sample_language = req.vocal_language
958
+ else:
959
+ sample_language = parsed_language
960
+
961
+ sample_result = create_sample(
962
+ llm_handler=llm,
963
+ query=req.sample_query,
964
+ instrumental=parsed_instrumental,
965
+ vocal_language=sample_language,
966
+ temperature=req.lm_temperature,
967
+ top_k=lm_top_k if lm_top_k > 0 else None,
968
+ top_p=lm_top_p if lm_top_p < 1.0 else None,
969
+ use_constrained_decoding=req.constrained_decoding,
970
+ )
971
+
972
+ if not sample_result.success:
973
+ raise RuntimeError(f"create_sample failed: {sample_result.error or sample_result.status_message}")
974
+
975
+ # Use generated sample data
976
+ caption = sample_result.caption
977
+ lyrics = sample_result.lyrics
978
+ bpm = sample_result.bpm
979
+ key_scale = sample_result.keyscale
980
+ time_signature = sample_result.timesignature
981
+ audio_duration = sample_result.duration
982
+ else:
983
+ # Original sample_mode behavior: random generation
984
+ sample_metadata, sample_status = llm.understand_audio_from_codes(
985
+ audio_codes="NO USER INPUT",
986
+ temperature=req.lm_temperature,
987
+ top_k=lm_top_k if lm_top_k > 0 else None,
988
+ top_p=lm_top_p if lm_top_p < 1.0 else None,
989
+ repetition_penalty=req.lm_repetition_penalty,
990
+ use_constrained_decoding=req.constrained_decoding,
991
+ constrained_decoding_debug=req.constrained_decoding_debug,
992
+ )
993
+
994
+ if not sample_metadata or str(sample_status).startswith("❌"):
995
+ raise RuntimeError(f"Sample generation failed: {sample_status}")
996
+
997
+ # Use generated values with fallback defaults
998
+ caption = sample_metadata.get("caption", "")
999
+ lyrics = sample_metadata.get("lyrics", "")
1000
+ bpm = _to_int(sample_metadata.get("bpm"), None) or _to_int(os.getenv("ACESTEP_SAMPLE_DEFAULT_BPM", "120"), 120)
1001
+ key_scale = sample_metadata.get("keyscale", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_KEY", "C Major")
1002
+ time_signature = sample_metadata.get("timesignature", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_TIMESIGNATURE", "4/4")
1003
+ audio_duration = _to_float(sample_metadata.get("duration"), None) or _to_float(os.getenv("ACESTEP_SAMPLE_DEFAULT_DURATION_SECONDS", "120"), 120.0)
1004
+
1005
+ # Apply format_sample() if use_format is True and caption/lyrics are provided
1006
+ format_has_duration = False
1007
+
1008
+ if req.use_format and (caption or lyrics):
1009
+ _ensure_llm_ready()
1010
+ if getattr(app.state, "_llm_init_error", None):
1011
+ raise RuntimeError(f"5Hz LM init failed (needed for format): {app.state._llm_init_error}")
1012
+
1013
+ # Build user_metadata from request params (matching bot.py behavior)
1014
+ user_metadata_for_format = {}
1015
+ if bpm is not None:
1016
+ user_metadata_for_format['bpm'] = bpm
1017
+ if audio_duration is not None and audio_duration > 0:
1018
+ user_metadata_for_format['duration'] = int(audio_duration)
1019
+ if key_scale:
1020
+ user_metadata_for_format['keyscale'] = key_scale
1021
+ if time_signature:
1022
+ user_metadata_for_format['timesignature'] = time_signature
1023
+ if req.vocal_language and req.vocal_language != "unknown":
1024
+ user_metadata_for_format['language'] = req.vocal_language
1025
+
1026
+ format_result = format_sample(
1027
+ llm_handler=llm,
1028
+ caption=caption,
1029
+ lyrics=lyrics,
1030
+ user_metadata=user_metadata_for_format if user_metadata_for_format else None,
1031
+ temperature=req.lm_temperature,
1032
+ top_k=lm_top_k if lm_top_k > 0 else None,
1033
+ top_p=lm_top_p if lm_top_p < 1.0 else None,
1034
+ use_constrained_decoding=req.constrained_decoding,
1035
+ )
1036
+
1037
+ if format_result.success:
1038
+ # Extract all formatted data (matching bot.py behavior)
1039
+ caption = format_result.caption or caption
1040
+ lyrics = format_result.lyrics or lyrics
1041
+ if format_result.duration:
1042
+ audio_duration = format_result.duration
1043
+ format_has_duration = True
1044
+ if format_result.bpm:
1045
+ bpm = format_result.bpm
1046
+ if format_result.keyscale:
1047
+ key_scale = format_result.keyscale
1048
+ if format_result.timesignature:
1049
+ time_signature = format_result.timesignature
1050
+
1051
+ # Parse timesteps string to list of floats if provided
1052
+ parsed_timesteps = _parse_timesteps(req.timesteps)
1053
+
1054
+ # Determine actual inference steps (timesteps override inference_steps)
1055
+ actual_inference_steps = len(parsed_timesteps) if parsed_timesteps else req.inference_steps
1056
+
1057
+ # Auto-select instruction based on task_type if user didn't provide custom instruction
1058
+ # This matches gradio behavior which uses TASK_INSTRUCTIONS for each task type
1059
+ instruction_to_use = req.instruction
1060
+ if instruction_to_use == DEFAULT_DIT_INSTRUCTION and req.task_type in TASK_INSTRUCTIONS:
1061
+ instruction_to_use = TASK_INSTRUCTIONS[req.task_type]
1062
+
1063
+ # Build GenerationParams using unified interface
1064
+ # Note: thinking controls LM code generation, sample_mode only affects CoT metas
1065
+ params = GenerationParams(
1066
+ task_type=req.task_type,
1067
+ instruction=instruction_to_use,
1068
+ reference_audio=req.reference_audio_path,
1069
+ src_audio=req.src_audio_path,
1070
+ audio_codes=req.audio_code_string,
1071
+ caption=caption,
1072
+ lyrics=lyrics,
1073
+ instrumental=False,
1074
+ vocal_language=req.vocal_language,
1075
+ bpm=bpm,
1076
+ keyscale=key_scale,
1077
+ timesignature=time_signature,
1078
+ duration=audio_duration if audio_duration else -1.0,
1079
+ inference_steps=req.inference_steps,
1080
+ seed=req.seed,
1081
+ guidance_scale=req.guidance_scale,
1082
+ use_adg=req.use_adg,
1083
+ cfg_interval_start=req.cfg_interval_start,
1084
+ cfg_interval_end=req.cfg_interval_end,
1085
+ shift=req.shift,
1086
+ infer_method=req.infer_method,
1087
+ timesteps=parsed_timesteps,
1088
+ repainting_start=req.repainting_start,
1089
+ repainting_end=req.repainting_end if req.repainting_end else -1,
1090
+ audio_cover_strength=req.audio_cover_strength,
1091
+ # LM parameters
1092
+ thinking=thinking, # Use LM for code generation when thinking=True
1093
+ lm_temperature=req.lm_temperature,
1094
+ lm_cfg_scale=req.lm_cfg_scale,
1095
+ lm_top_k=lm_top_k,
1096
+ lm_top_p=lm_top_p,
1097
+ lm_negative_prompt=req.lm_negative_prompt,
1098
+ # use_cot_metas logic:
1099
+ # - sample_mode: metas already generated, skip Phase 1
1100
+ # - format with duration: metas already generated, skip Phase 1
1101
+ # - format without duration: need Phase 1 to generate duration
1102
+ # - no format: need Phase 1 to generate all metas
1103
+ use_cot_metas=not sample_mode and not format_has_duration,
1104
+ use_cot_caption=req.use_cot_caption,
1105
+ use_cot_language=req.use_cot_language,
1106
+ use_constrained_decoding=req.constrained_decoding,
1107
+ )
1108
+
1109
+ # Build GenerationConfig - default to 2 audios like gradio_ui
1110
+ batch_size = req.batch_size if req.batch_size is not None else 2
1111
+ config = GenerationConfig(
1112
+ batch_size=batch_size,
1113
+ use_random_seed=req.use_random_seed,
1114
+ seeds=None, # Let unified logic handle seed generation
1115
+ audio_format=req.audio_format,
1116
+ constrained_decoding_debug=req.constrained_decoding_debug,
1117
+ )
1118
+
1119
+ # Check LLM initialization status
1120
+ llm_is_initialized = getattr(app.state, "_llm_initialized", False)
1121
+ llm_to_pass = llm if llm_is_initialized else None
1122
+
1123
+ # Generate music using unified interface
1124
+ result = generate_music(
1125
+ dit_handler=h,
1126
+ llm_handler=llm_to_pass,
1127
+ params=params,
1128
+ config=config,
1129
+ save_dir=app.state.temp_audio_dir,
1130
+ progress=None,
1131
+ )
1132
+
1133
+ if not result.success:
1134
+ raise RuntimeError(f"Music generation failed: {result.error or result.status_message}")
1135
+
1136
+ # Extract results
1137
+ audio_paths = [audio["path"] for audio in result.audios if audio.get("path")]
1138
+ first_audio = audio_paths[0] if len(audio_paths) > 0 else None
1139
+ second_audio = audio_paths[1] if len(audio_paths) > 1 else None
1140
+
1141
+ # Get metadata from LM or CoT results
1142
+ lm_metadata = result.extra_outputs.get("lm_metadata", {})
1143
+ metas_out = _normalize_metas(lm_metadata)
1144
+
1145
+ # Update metas with actual values used
1146
+ if params.cot_bpm:
1147
+ metas_out["bpm"] = params.cot_bpm
1148
+ elif bpm:
1149
+ metas_out["bpm"] = bpm
1150
+
1151
+ if params.cot_duration:
1152
+ metas_out["duration"] = params.cot_duration
1153
+ elif audio_duration:
1154
+ metas_out["duration"] = audio_duration
1155
+
1156
+ if params.cot_keyscale:
1157
+ metas_out["keyscale"] = params.cot_keyscale
1158
+ elif key_scale:
1159
+ metas_out["keyscale"] = key_scale
1160
+
1161
+ if params.cot_timesignature:
1162
+ metas_out["timesignature"] = params.cot_timesignature
1163
+ elif time_signature:
1164
+ metas_out["timesignature"] = time_signature
1165
+
1166
+ # Store original user input in metas (not the final/modified values)
1167
+ metas_out["prompt"] = original_prompt
1168
+ metas_out["lyrics"] = original_lyrics
1169
+
1170
+ # Extract seed values for response (comma-separated for multiple audios)
1171
+ seed_values = []
1172
+ for audio in result.audios:
1173
+ audio_params = audio.get("params", {})
1174
+ seed = audio_params.get("seed")
1175
+ if seed is not None:
1176
+ seed_values.append(str(seed))
1177
+ seed_value = ",".join(seed_values) if seed_values else ""
1178
+
1179
+ # Build generation_info using the helper function (like gradio_ui)
1180
+ time_costs = result.extra_outputs.get("time_costs", {})
1181
+ generation_info = _build_generation_info(
1182
+ lm_metadata=lm_metadata,
1183
+ time_costs=time_costs,
1184
+ seed_value=seed_value,
1185
+ inference_steps=req.inference_steps,
1186
+ num_audios=len(result.audios),
1187
+ )
1188
+
1189
+ def _none_if_na_str(v: Any) -> Optional[str]:
1190
+ if v is None:
1191
+ return None
1192
+ s = str(v).strip()
1193
+ if s in {"", "N/A"}:
1194
+ return None
1195
+ return s
1196
+
1197
+ # Get model information
1198
+ lm_model_name = os.getenv("ACESTEP_LM_MODEL_PATH", "acestep-5Hz-lm-0.6B")
1199
+ # Use selected_model_name (set at the beginning of _run_one_job)
1200
+ dit_model_name = selected_model_name
1201
+
1202
+ return {
1203
+ "first_audio_path": _path_to_audio_url(first_audio) if first_audio else None,
1204
+ "second_audio_path": _path_to_audio_url(second_audio) if second_audio else None,
1205
+ "audio_paths": [_path_to_audio_url(p) for p in audio_paths],
1206
+ "generation_info": generation_info,
1207
+ "status_message": result.status_message,
1208
+ "seed_value": seed_value,
1209
+ # Final prompt/lyrics (may be modified by thinking/format)
1210
+ "prompt": caption or "",
1211
+ "lyrics": lyrics or "",
1212
+ # metas contains original user input + other metadata
1213
+ "metas": metas_out,
1214
+ "bpm": metas_out.get("bpm") if isinstance(metas_out.get("bpm"), int) else None,
1215
+ "duration": metas_out.get("duration") if isinstance(metas_out.get("duration"), (int, float)) else None,
1216
+ "genres": _none_if_na_str(metas_out.get("genres")),
1217
+ "keyscale": _none_if_na_str(metas_out.get("keyscale")),
1218
+ "timesignature": _none_if_na_str(metas_out.get("timesignature")),
1219
+ "lm_model": lm_model_name,
1220
+ "dit_model": dit_model_name,
1221
+ }
1222
+
1223
+ t0 = time.time()
1224
+ try:
1225
+ loop = asyncio.get_running_loop()
1226
+ result = await loop.run_in_executor(executor, _blocking_generate)
1227
+ job_store.mark_succeeded(job_id, result)
1228
+
1229
+ # Update local cache
1230
+ _update_local_cache(job_id, result, "succeeded")
1231
+ except Exception:
1232
+ job_store.mark_failed(job_id, traceback.format_exc())
1233
+
1234
+ # Update local cache
1235
+ _update_local_cache(job_id, None, "failed")
1236
+ finally:
1237
+ dt = max(0.0, time.time() - t0)
1238
+ async with app.state.stats_lock:
1239
+ app.state.recent_durations.append(dt)
1240
+ if app.state.recent_durations:
1241
+ app.state.avg_job_seconds = sum(app.state.recent_durations) / len(app.state.recent_durations)
1242
+
1243
+ async def _queue_worker(worker_idx: int) -> None:
1244
+ while True:
1245
+ job_id, req = await app.state.job_queue.get()
1246
+ try:
1247
+ async with app.state.pending_lock:
1248
+ try:
1249
+ app.state.pending_ids.remove(job_id)
1250
+ except ValueError:
1251
+ pass
1252
+
1253
+ await _run_one_job(job_id, req)
1254
+ finally:
1255
+ await _cleanup_job_temp_files(job_id)
1256
+ app.state.job_queue.task_done()
1257
+
1258
+ worker_count = max(1, WORKER_COUNT)
1259
+ workers = [asyncio.create_task(_queue_worker(i)) for i in range(worker_count)]
1260
+ app.state.worker_tasks = workers
1261
+
1262
+ try:
1263
+ yield
1264
+ finally:
1265
+ for t in workers:
1266
+ t.cancel()
1267
+ executor.shutdown(wait=False, cancel_futures=True)
1268
+
1269
+ app = FastAPI(title="ACE-Step API", version="1.0", lifespan=lifespan)
1270
+
1271
+ async def _queue_position(job_id: str) -> int:
1272
+ async with app.state.pending_lock:
1273
+ try:
1274
+ return list(app.state.pending_ids).index(job_id) + 1
1275
+ except ValueError:
1276
+ return 0
1277
+
1278
+ async def _eta_seconds_for_position(pos: int) -> Optional[float]:
1279
+ if pos <= 0:
1280
+ return None
1281
+ async with app.state.stats_lock:
1282
+ avg = float(getattr(app.state, "avg_job_seconds", INITIAL_AVG_JOB_SECONDS))
1283
+ return pos * avg
1284
+
1285
+ @app.post("/release_task", response_model=CreateJobResponse)
1286
+ async def create_music_generate_job(request: Request) -> CreateJobResponse:
1287
+ content_type = (request.headers.get("content-type") or "").lower()
1288
+ temp_files: list[str] = []
1289
+
1290
+ def _build_request(p: RequestParser, **kwargs) -> GenerateMusicRequest:
1291
+ """Build GenerateMusicRequest from parsed parameters."""
1292
+ return GenerateMusicRequest(
1293
+ prompt=p.str("prompt"),
1294
+ lyrics=p.str("lyrics"),
1295
+ thinking=p.bool("thinking"),
1296
+ sample_mode=p.bool("sample_mode"),
1297
+ sample_query=p.str("sample_query"),
1298
+ use_format=p.bool("use_format"),
1299
+ model=p.str("model") or None,
1300
+ bpm=p.int("bpm"),
1301
+ key_scale=p.str("key_scale"),
1302
+ time_signature=p.str("time_signature"),
1303
+ audio_duration=p.float("audio_duration"),
1304
+ vocal_language=p.str("vocal_language", "en"),
1305
+ inference_steps=p.int("inference_steps", 8),
1306
+ guidance_scale=p.float("guidance_scale", 7.0),
1307
+ use_random_seed=p.bool("use_random_seed", True),
1308
+ seed=p.int("seed", -1),
1309
+ batch_size=p.int("batch_size"),
1310
+ audio_code_string=p.str("audio_code_string"),
1311
+ repainting_start=p.float("repainting_start", 0.0),
1312
+ repainting_end=p.float("repainting_end"),
1313
+ instruction=p.str("instruction", DEFAULT_DIT_INSTRUCTION),
1314
+ audio_cover_strength=p.float("audio_cover_strength", 1.0),
1315
+ task_type=p.str("task_type", "text2music"),
1316
+ use_adg=p.bool("use_adg"),
1317
+ cfg_interval_start=p.float("cfg_interval_start", 0.0),
1318
+ cfg_interval_end=p.float("cfg_interval_end", 1.0),
1319
+ infer_method=p.str("infer_method", "ode"),
1320
+ shift=p.float("shift", 3.0),
1321
+ audio_format=p.str("audio_format", "mp3"),
1322
+ use_tiled_decode=p.bool("use_tiled_decode", True),
1323
+ lm_model_path=p.str("lm_model_path") or None,
1324
+ lm_backend=p.str("lm_backend", "vllm"),
1325
+ lm_temperature=p.float("lm_temperature", LM_DEFAULT_TEMPERATURE),
1326
+ lm_cfg_scale=p.float("lm_cfg_scale", LM_DEFAULT_CFG_SCALE),
1327
+ lm_top_k=p.int("lm_top_k"),
1328
+ lm_top_p=p.float("lm_top_p", LM_DEFAULT_TOP_P),
1329
+ lm_repetition_penalty=p.float("lm_repetition_penalty", 1.0),
1330
+ lm_negative_prompt=p.str("lm_negative_prompt", "NO USER INPUT"),
1331
+ constrained_decoding=p.bool("constrained_decoding", True),
1332
+ constrained_decoding_debug=p.bool("constrained_decoding_debug"),
1333
+ use_cot_caption=p.bool("use_cot_caption", True),
1334
+ use_cot_language=p.bool("use_cot_language", True),
1335
+ is_format_caption=p.bool("is_format_caption"),
1336
+ **kwargs,
1337
+ )
1338
+
1339
+ if content_type.startswith("application/json"):
1340
+ body = await request.json()
1341
+ if not isinstance(body, dict):
1342
+ raise HTTPException(status_code=400, detail="JSON payload must be an object")
1343
+ req = _build_request(RequestParser(body))
1344
+
1345
+ elif content_type.endswith("+json"):
1346
+ body = await request.json()
1347
+ if not isinstance(body, dict):
1348
+ raise HTTPException(status_code=400, detail="JSON payload must be an object")
1349
+ req = _build_request(RequestParser(body))
1350
+
1351
+ elif content_type.startswith("multipart/form-data"):
1352
+ form = await request.form()
1353
+
1354
+ ref_up = form.get("reference_audio")
1355
+ src_up = form.get("src_audio")
1356
+
1357
+ reference_audio_path = None
1358
+ src_audio_path = None
1359
+
1360
+ if isinstance(ref_up, StarletteUploadFile):
1361
+ reference_audio_path = await _save_upload_to_temp(ref_up, prefix="reference_audio")
1362
+ temp_files.append(reference_audio_path)
1363
+ else:
1364
+ reference_audio_path = str(form.get("reference_audio_path") or "").strip() or None
1365
+
1366
+ if isinstance(src_up, StarletteUploadFile):
1367
+ src_audio_path = await _save_upload_to_temp(src_up, prefix="src_audio")
1368
+ temp_files.append(src_audio_path)
1369
+ else:
1370
+ src_audio_path = str(form.get("src_audio_path") or "").strip() or None
1371
+
1372
+ req = _build_request(
1373
+ RequestParser(dict(form)),
1374
+ reference_audio_path=reference_audio_path,
1375
+ src_audio_path=src_audio_path,
1376
+ )
1377
+
1378
+ elif content_type.startswith("application/x-www-form-urlencoded"):
1379
+ form = await request.form()
1380
+ reference_audio_path = str(form.get("reference_audio_path") or "").strip() or None
1381
+ src_audio_path = str(form.get("src_audio_path") or "").strip() or None
1382
+ req = _build_request(
1383
+ RequestParser(dict(form)),
1384
+ reference_audio_path=reference_audio_path,
1385
+ src_audio_path=src_audio_path,
1386
+ )
1387
+
1388
+ else:
1389
+ raw = await request.body()
1390
+ raw_stripped = raw.lstrip()
1391
+ # Best-effort: accept missing/incorrect Content-Type if payload is valid JSON.
1392
+ if raw_stripped.startswith(b"{") or raw_stripped.startswith(b"["):
1393
+ try:
1394
+ body = json.loads(raw.decode("utf-8"))
1395
+ if isinstance(body, dict):
1396
+ req = _build_request(RequestParser(body))
1397
+ else:
1398
+ raise HTTPException(status_code=400, detail="JSON payload must be an object")
1399
+ except HTTPException:
1400
+ raise
1401
+ except Exception:
1402
+ raise HTTPException(
1403
+ status_code=400,
1404
+ detail="Invalid JSON body (hint: set 'Content-Type: application/json')",
1405
+ )
1406
+ # Best-effort: parse key=value bodies even if Content-Type is missing.
1407
+ elif raw_stripped and b"=" in raw:
1408
+ parsed = urllib.parse.parse_qs(raw.decode("utf-8"), keep_blank_values=True)
1409
+ flat = {k: (v[0] if isinstance(v, list) and v else v) for k, v in parsed.items()}
1410
+ reference_audio_path = str(flat.get("reference_audio_path") or "").strip() or None
1411
+ src_audio_path = str(flat.get("src_audio_path") or "").strip() or None
1412
+ req = _build_request(
1413
+ RequestParser(flat),
1414
+ reference_audio_path=reference_audio_path,
1415
+ src_audio_path=src_audio_path,
1416
+ )
1417
+ else:
1418
+ raise HTTPException(
1419
+ status_code=415,
1420
+ detail=(
1421
+ f"Unsupported Content-Type: {content_type or '(missing)'}; "
1422
+ "use application/json, application/x-www-form-urlencoded, or multipart/form-data"
1423
+ ),
1424
+ )
1425
+
1426
+ rec = store.create()
1427
+
1428
+ q: asyncio.Queue = app.state.job_queue
1429
+ if q.full():
1430
+ for p in temp_files:
1431
+ try:
1432
+ os.remove(p)
1433
+ except Exception:
1434
+ pass
1435
+ raise HTTPException(status_code=429, detail="Server busy: queue is full")
1436
+
1437
+ if temp_files:
1438
+ async with app.state.job_temp_files_lock:
1439
+ app.state.job_temp_files[rec.job_id] = temp_files
1440
+
1441
+ async with app.state.pending_lock:
1442
+ app.state.pending_ids.append(rec.job_id)
1443
+ position = len(app.state.pending_ids)
1444
+
1445
+ await q.put((rec.job_id, req))
1446
+ return CreateJobResponse(task_id=rec.job_id, status="queued", queue_position=position)
1447
+
1448
+ @app.post("/v1/music/random", response_model=CreateJobResponse)
1449
+ async def create_random_sample_job(request: Request) -> CreateJobResponse:
1450
+ """Create a sample-mode job that auto-generates caption/lyrics via LM."""
1451
+
1452
+ thinking_value: Any = None
1453
+ content_type = (request.headers.get("content-type") or "").lower()
1454
+ body_dict: Dict[str, Any] = {}
1455
+
1456
+ if "json" in content_type:
1457
+ try:
1458
+ payload = await request.json()
1459
+ if isinstance(payload, dict):
1460
+ body_dict = payload
1461
+ except Exception:
1462
+ body_dict = {}
1463
+
1464
+ if not body_dict and request.query_params:
1465
+ body_dict = dict(request.query_params)
1466
+
1467
+ thinking_value = body_dict.get("thinking")
1468
+ if thinking_value is None:
1469
+ thinking_value = body_dict.get("Thinking")
1470
+
1471
+ thinking_flag = _to_bool(thinking_value, True)
1472
+
1473
+ req = GenerateMusicRequest(
1474
+ caption="",
1475
+ lyrics="",
1476
+ thinking=thinking_flag,
1477
+ sample_mode=True,
1478
+ )
1479
+
1480
+ rec = store.create()
1481
+ q: asyncio.Queue = app.state.job_queue
1482
+ if q.full():
1483
+ raise HTTPException(status_code=429, detail="Server busy: queue is full")
1484
+
1485
+ async with app.state.pending_lock:
1486
+ app.state.pending_ids.append(rec.job_id)
1487
+ position = len(app.state.pending_ids)
1488
+
1489
+ await q.put((rec.job_id, req))
1490
+ return CreateJobResponse(task_id=rec.job_id, status="queued", queue_position=position)
1491
+
1492
+ @app.post("/query_result")
1493
+ async def query_result(request: Request) -> List[Dict[str, Any]]:
1494
+ """Batch query job results"""
1495
+ content_type = (request.headers.get("content-type") or "").lower()
1496
+
1497
+ if "json" in content_type:
1498
+ body = await request.json()
1499
+ else:
1500
+ form = await request.form()
1501
+ body = {k: v for k, v in form.items()}
1502
+
1503
+ task_id_list_str = body.get("task_id_list", "[]")
1504
+
1505
+ # Parse task ID list
1506
+ if isinstance(task_id_list_str, list):
1507
+ task_id_list = task_id_list_str
1508
+ else:
1509
+ try:
1510
+ task_id_list = json.loads(task_id_list_str)
1511
+ except Exception:
1512
+ task_id_list = []
1513
+
1514
+ local_cache = getattr(app.state, 'local_cache', None)
1515
+ data_list = []
1516
+ current_time = time.time()
1517
+
1518
+ for task_id in task_id_list:
1519
+ result_key = f"{RESULT_KEY_PREFIX}{task_id}"
1520
+
1521
+ # Read from local cache first
1522
+ if local_cache:
1523
+ data = local_cache.get(result_key)
1524
+ if data:
1525
+ try:
1526
+ data_json = json.loads(data)
1527
+ except Exception:
1528
+ data_json = []
1529
+
1530
+ if len(data_json) <= 0:
1531
+ data_list.append({"task_id": task_id, "result": data, "status": 2})
1532
+ else:
1533
+ status = data_json[0].get("status")
1534
+ create_time = data_json[0].get("create_time", 0)
1535
+ if status == 0 and (current_time - create_time) > TASK_TIMEOUT_SECONDS:
1536
+ data_list.append({"task_id": task_id, "result": data, "status": 2})
1537
+ else:
1538
+ data_list.append({
1539
+ "task_id": task_id,
1540
+ "result": data,
1541
+ "status": int(status) if status is not None else 1,
1542
+ })
1543
+ continue
1544
+
1545
+ # Fallback to job_store query
1546
+ rec = store.get(task_id)
1547
+ if rec:
1548
+ env = getattr(rec, 'env', 'development')
1549
+ create_time = rec.created_at
1550
+ status_int = _map_status(rec.status)
1551
+
1552
+ if rec.result and rec.status == "succeeded":
1553
+ audio_paths = rec.result.get("audio_paths", [])
1554
+ metas = rec.result.get("metas", {}) or {}
1555
+ result_data = [
1556
+ {
1557
+ "file": p, "wave": "", "status": status_int,
1558
+ "create_time": int(create_time), "env": env,
1559
+ "prompt": metas.get("caption", ""),
1560
+ "lyrics": metas.get("lyrics", ""),
1561
+ "metas": {
1562
+ "bpm": metas.get("bpm"),
1563
+ "duration": metas.get("duration"),
1564
+ "genres": metas.get("genres", ""),
1565
+ "keyscale": metas.get("keyscale", ""),
1566
+ "timesignature": metas.get("timesignature", ""),
1567
+ }
1568
+ }
1569
+ for p in audio_paths
1570
+ ] if audio_paths else [{
1571
+ "file": "", "wave": "", "status": status_int,
1572
+ "create_time": int(create_time), "env": env,
1573
+ "prompt": metas.get("caption", ""),
1574
+ "lyrics": metas.get("lyrics", ""),
1575
+ "metas": {
1576
+ "bpm": metas.get("bpm"),
1577
+ "duration": metas.get("duration"),
1578
+ "genres": metas.get("genres", ""),
1579
+ "keyscale": metas.get("keyscale", ""),
1580
+ "timesignature": metas.get("timesignature", ""),
1581
+ }
1582
+ }]
1583
+ else:
1584
+ result_data = [{
1585
+ "file": "", "wave": "", "status": status_int,
1586
+ "create_time": int(create_time), "env": env,
1587
+ "prompt": "", "lyrics": "",
1588
+ "metas": {}
1589
+ }]
1590
+
1591
+ data_list.append({
1592
+ "task_id": task_id,
1593
+ "result": json.dumps(result_data, ensure_ascii=False),
1594
+ "status": status_int,
1595
+ })
1596
+ else:
1597
+ data_list.append({"task_id": task_id, "result": "[]", "status": 0})
1598
+
1599
+ return data_list
1600
+
1601
+ @app.get("/health")
1602
+ async def health_check():
1603
+ """Health check endpoint for service status."""
1604
+ return {
1605
+ "status": "ok",
1606
+ "service": "ACE-Step API",
1607
+ "version": "1.0",
1608
+ }
1609
+
1610
+ @app.get("/v1/models")
1611
+ async def list_models():
1612
+ """List available DiT models."""
1613
+ models = []
1614
+
1615
+ # Primary model (always available if initialized)
1616
+ if getattr(app.state, "_initialized", False):
1617
+ primary_model = _get_model_name(app.state._config_path)
1618
+ if primary_model:
1619
+ models.append({
1620
+ "name": primary_model,
1621
+ "is_default": True,
1622
+ })
1623
+
1624
+ # Secondary model
1625
+ if getattr(app.state, "_initialized2", False) and app.state._config_path2:
1626
+ secondary_model = _get_model_name(app.state._config_path2)
1627
+ if secondary_model:
1628
+ models.append({
1629
+ "name": secondary_model,
1630
+ "is_default": False,
1631
+ })
1632
+
1633
+ # Third model
1634
+ if getattr(app.state, "_initialized3", False) and app.state._config_path3:
1635
+ third_model = _get_model_name(app.state._config_path3)
1636
+ if third_model:
1637
+ models.append({
1638
+ "name": third_model,
1639
+ "is_default": False,
1640
+ })
1641
+
1642
+ return {
1643
+ "models": models,
1644
+ "default_model": models[0]["name"] if models else None,
1645
+ }
1646
+
1647
+ @app.get("/v1/audio")
1648
+ async def get_audio(path: str):
1649
+ """Serve audio file by path."""
1650
+ from fastapi.responses import FileResponse
1651
+
1652
+ if not os.path.exists(path):
1653
+ raise HTTPException(status_code=404, detail=f"Audio file not found: {path}")
1654
+
1655
+ ext = os.path.splitext(path)[1].lower()
1656
+ media_types = {
1657
+ ".mp3": "audio/mpeg",
1658
+ ".wav": "audio/wav",
1659
+ ".flac": "audio/flac",
1660
+ ".ogg": "audio/ogg",
1661
+ }
1662
+ media_type = media_types.get(ext, "audio/mpeg")
1663
+
1664
+ return FileResponse(path, media_type=media_type)
1665
+
1666
+ return app
1667
+
1668
+
1669
+ app = create_app()
1670
+
1671
+
1672
+ def main() -> None:
1673
+ import argparse
1674
+ import uvicorn
1675
+
1676
+ parser = argparse.ArgumentParser(description="ACE-Step API server")
1677
+ parser.add_argument(
1678
+ "--host",
1679
+ default=os.getenv("ACESTEP_API_HOST", "127.0.0.1"),
1680
+ help="Bind host (default from ACESTEP_API_HOST or 127.0.0.1)",
1681
+ )
1682
+ parser.add_argument(
1683
+ "--port",
1684
+ type=int,
1685
+ default=int(os.getenv("ACESTEP_API_PORT", "8001")),
1686
+ help="Bind port (default from ACESTEP_API_PORT or 8001)",
1687
+ )
1688
+ args = parser.parse_args()
1689
+
1690
+ # IMPORTANT: in-memory queue/store -> workers MUST be 1
1691
+ uvicorn.run(
1692
+ "acestep.api_server:app",
1693
+ host=str(args.host),
1694
+ port=int(args.port),
1695
+ reload=False,
1696
+ workers=1,
1697
+ )
1698
+
1699
+ if __name__ == "__main__":
1700
+ main()
spaces/Ace-Step-v1.5/acestep/audio_utils.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Audio saving and transcoding utility module
3
+
4
+ Independent audio file operations outside of handler, supporting:
5
+ - Save audio tensor/numpy to files (default FLAC format, fast)
6
+ - Format conversion (FLAC/WAV/MP3)
7
+ - Batch processing
8
+ """
9
+
10
+ import os
11
+
12
+ # Disable torchcodec backend to avoid CUDA dependency issues on HuggingFace Space
13
+ # This forces torchaudio to use ffmpeg/sox/soundfile backends instead
14
+ os.environ["TORCHAUDIO_USE_TORCHCODEC"] = "0"
15
+
16
+ import hashlib
17
+ import json
18
+ from pathlib import Path
19
+ from typing import Union, Optional, List, Tuple
20
+ import torch
21
+ import numpy as np
22
+ import torchaudio
23
+ from loguru import logger
24
+
25
+
26
+ class AudioSaver:
27
+ """Audio saving and transcoding utility class"""
28
+
29
+ def __init__(self, default_format: str = "flac"):
30
+ """
31
+ Initialize audio saver
32
+
33
+ Args:
34
+ default_format: Default save format ('flac', 'wav', 'mp3')
35
+ """
36
+ self.default_format = default_format.lower()
37
+ if self.default_format not in ["flac", "wav", "mp3"]:
38
+ logger.warning(f"Unsupported format {default_format}, using 'flac'")
39
+ self.default_format = "flac"
40
+
41
+ def save_audio(
42
+ self,
43
+ audio_data: Union[torch.Tensor, np.ndarray],
44
+ output_path: Union[str, Path],
45
+ sample_rate: int = 48000,
46
+ format: Optional[str] = None,
47
+ channels_first: bool = True,
48
+ ) -> str:
49
+ """
50
+ Save audio data to file
51
+
52
+ Args:
53
+ audio_data: Audio data, torch.Tensor [channels, samples] or numpy.ndarray
54
+ output_path: Output file path (extension can be omitted)
55
+ sample_rate: Sample rate
56
+ format: Audio format ('flac', 'wav', 'mp3'), defaults to default_format
57
+ channels_first: If True, tensor format is [channels, samples], else [samples, channels]
58
+
59
+ Returns:
60
+ Actual saved file path
61
+ """
62
+ format = (format or self.default_format).lower()
63
+ if format not in ["flac", "wav", "mp3"]:
64
+ logger.warning(f"Unsupported format {format}, using {self.default_format}")
65
+ format = self.default_format
66
+
67
+ # Ensure output path has correct extension
68
+ output_path = Path(output_path)
69
+ if output_path.suffix.lower() not in ['.flac', '.wav', '.mp3']:
70
+ output_path = output_path.with_suffix(f'.{format}')
71
+
72
+ # Convert to torch tensor
73
+ if isinstance(audio_data, np.ndarray):
74
+ if channels_first:
75
+ # numpy [samples, channels] -> tensor [channels, samples]
76
+ audio_tensor = torch.from_numpy(audio_data.T).float()
77
+ else:
78
+ # numpy [samples, channels] -> tensor [samples, channels] -> [channels, samples]
79
+ audio_tensor = torch.from_numpy(audio_data).float()
80
+ if audio_tensor.dim() == 2 and audio_tensor.shape[0] < audio_tensor.shape[1]:
81
+ audio_tensor = audio_tensor.T
82
+ else:
83
+ # torch tensor
84
+ audio_tensor = audio_data.cpu().float()
85
+ if not channels_first and audio_tensor.dim() == 2:
86
+ # [samples, channels] -> [channels, samples]
87
+ if audio_tensor.shape[0] > audio_tensor.shape[1]:
88
+ audio_tensor = audio_tensor.T
89
+
90
+ # Ensure memory is contiguous
91
+ audio_tensor = audio_tensor.contiguous()
92
+
93
+ # Select backend and save
94
+ try:
95
+ if format == "mp3":
96
+ # MP3 uses ffmpeg backend
97
+ torchaudio.save(
98
+ str(output_path),
99
+ audio_tensor,
100
+ sample_rate,
101
+ channels_first=True,
102
+ backend='ffmpeg',
103
+ )
104
+ elif format in ["flac", "wav"]:
105
+ # FLAC and WAV use soundfile backend (fastest)
106
+ torchaudio.save(
107
+ str(output_path),
108
+ audio_tensor,
109
+ sample_rate,
110
+ channels_first=True,
111
+ backend='soundfile',
112
+ )
113
+ else:
114
+ # Other formats use default backend
115
+ torchaudio.save(
116
+ str(output_path),
117
+ audio_tensor,
118
+ sample_rate,
119
+ channels_first=True,
120
+ )
121
+
122
+ logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
123
+ return str(output_path)
124
+
125
+ except Exception as e:
126
+ try:
127
+ import soundfile as sf
128
+ audio_np = audio_tensor.transpose(0, 1).numpy() # -> [samples, channels]
129
+ sf.write(str(output_path), audio_np, sample_rate, format=format.upper())
130
+ logger.debug(f"[AudioSaver] Fallback soundfile Saved audio to {output_path} ({format}, {sample_rate}Hz)")
131
+ return str(output_path)
132
+ except Exception as e:
133
+ logger.error(f"[AudioSaver] Failed to save audio: {e}")
134
+ raise
135
+
136
+ def _load_audio_file(self, audio_file: Union[str, Path]) -> Tuple[torch.Tensor, int]:
137
+ """
138
+ Load audio file with ffmpeg backend, fallback to soundfile if failed.
139
+
140
+ This handles CUDA dependency issues with torchcodec on HuggingFace Space.
141
+
142
+ Args:
143
+ audio_file: Path to the audio file
144
+
145
+ Returns:
146
+ Tuple of (audio_tensor, sample_rate)
147
+
148
+ Raises:
149
+ FileNotFoundError: If the audio file doesn't exist
150
+ Exception: If all methods fail to load the audio
151
+ """
152
+ audio_file = str(audio_file)
153
+
154
+ # Check if file exists first
155
+ if not Path(audio_file).exists():
156
+ raise FileNotFoundError(f"Audio file not found: {audio_file}")
157
+
158
+ # Try torchaudio with explicit ffmpeg backend first
159
+ try:
160
+ audio, sr = torchaudio.load(audio_file, backend="ffmpeg")
161
+ return audio, sr
162
+ except Exception as e:
163
+ logger.debug(f"[AudioSaver._load_audio_file] ffmpeg backend failed: {e}, trying soundfile fallback")
164
+
165
+ # Fallback: use soundfile directly (most compatible)
166
+ try:
167
+ import soundfile as sf
168
+ audio_np, sr = sf.read(audio_file)
169
+ # soundfile returns [samples, channels] or [samples], convert to [channels, samples]
170
+ audio = torch.from_numpy(audio_np).float()
171
+ if audio.dim() == 1:
172
+ # Mono: [samples] -> [1, samples]
173
+ audio = audio.unsqueeze(0)
174
+ else:
175
+ # Stereo: [samples, channels] -> [channels, samples]
176
+ audio = audio.T
177
+ return audio, sr
178
+ except Exception as e:
179
+ logger.error(f"[AudioSaver._load_audio_file] All methods failed to load audio: {audio_file}, error: {e}")
180
+ raise
181
+
182
+ def convert_audio(
183
+ self,
184
+ input_path: Union[str, Path],
185
+ output_path: Union[str, Path],
186
+ output_format: str,
187
+ remove_input: bool = False,
188
+ ) -> str:
189
+ """
190
+ Convert audio format
191
+
192
+ Args:
193
+ input_path: Input audio file path
194
+ output_path: Output audio file path
195
+ output_format: Target format ('flac', 'wav', 'mp3')
196
+ remove_input: Whether to delete input file
197
+
198
+ Returns:
199
+ Output file path
200
+ """
201
+ input_path = Path(input_path)
202
+ output_path = Path(output_path)
203
+
204
+ if not input_path.exists():
205
+ raise FileNotFoundError(f"Input file not found: {input_path}")
206
+
207
+ # Load audio with fallback backends
208
+ audio_tensor, sample_rate = self._load_audio_file(input_path)
209
+
210
+ # Save as new format
211
+ output_path = self.save_audio(
212
+ audio_tensor,
213
+ output_path,
214
+ sample_rate=sample_rate,
215
+ format=output_format,
216
+ channels_first=True
217
+ )
218
+
219
+ # Delete input file if needed
220
+ if remove_input:
221
+ input_path.unlink()
222
+ logger.debug(f"[AudioSaver] Removed input file: {input_path}")
223
+
224
+ return output_path
225
+
226
+ def save_batch(
227
+ self,
228
+ audio_batch: Union[List[torch.Tensor], torch.Tensor],
229
+ output_dir: Union[str, Path],
230
+ file_prefix: str = "audio",
231
+ sample_rate: int = 48000,
232
+ format: Optional[str] = None,
233
+ channels_first: bool = True,
234
+ ) -> List[str]:
235
+ """
236
+ Save audio batch
237
+
238
+ Args:
239
+ audio_batch: Audio batch, List[tensor] or tensor [batch, channels, samples]
240
+ output_dir: Output directory
241
+ file_prefix: File prefix
242
+ sample_rate: Sample rate
243
+ format: Audio format
244
+ channels_first: Tensor format flag
245
+
246
+ Returns:
247
+ List of saved file paths
248
+ """
249
+ output_dir = Path(output_dir)
250
+ output_dir.mkdir(parents=True, exist_ok=True)
251
+
252
+ # Process batch
253
+ if isinstance(audio_batch, torch.Tensor) and audio_batch.dim() == 3:
254
+ # [batch, channels, samples]
255
+ audio_list = [audio_batch[i] for i in range(audio_batch.shape[0])]
256
+ elif isinstance(audio_batch, list):
257
+ audio_list = audio_batch
258
+ else:
259
+ audio_list = [audio_batch]
260
+
261
+ saved_paths = []
262
+ for i, audio in enumerate(audio_list):
263
+ output_path = output_dir / f"{file_prefix}_{i:04d}"
264
+ saved_path = self.save_audio(
265
+ audio,
266
+ output_path,
267
+ sample_rate=sample_rate,
268
+ format=format,
269
+ channels_first=channels_first
270
+ )
271
+ saved_paths.append(saved_path)
272
+
273
+ return saved_paths
274
+
275
+
276
+ def get_audio_file_hash(audio_file) -> str:
277
+ """
278
+ Get hash identifier for an audio file.
279
+
280
+ Args:
281
+ audio_file: Path to audio file (str) or file-like object
282
+
283
+ Returns:
284
+ Hash string or empty string
285
+ """
286
+ if audio_file is None:
287
+ return ""
288
+
289
+ try:
290
+ if isinstance(audio_file, str):
291
+ if os.path.exists(audio_file):
292
+ with open(audio_file, 'rb') as f:
293
+ return hashlib.md5(f.read()).hexdigest()
294
+ return hashlib.md5(audio_file.encode('utf-8')).hexdigest()
295
+ elif hasattr(audio_file, 'name'):
296
+ return hashlib.md5(str(audio_file.name).encode('utf-8')).hexdigest()
297
+ return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
298
+ except Exception:
299
+ return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
300
+
301
+
302
+ def generate_uuid_from_params(params_dict) -> str:
303
+ """
304
+ Generate deterministic UUID from generation parameters.
305
+ Same parameters will always generate the same UUID.
306
+
307
+ Args:
308
+ params_dict: Dictionary of parameters
309
+
310
+ Returns:
311
+ UUID string
312
+ """
313
+
314
+ params_json = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
315
+ hash_obj = hashlib.sha256(params_json.encode('utf-8'))
316
+ hash_hex = hash_obj.hexdigest()
317
+ uuid_str = f"{hash_hex[0:8]}-{hash_hex[8:12]}-{hash_hex[12:16]}-{hash_hex[16:20]}-{hash_hex[20:32]}"
318
+ return uuid_str
319
+
320
+
321
+ def generate_uuid_from_audio_data(
322
+ audio_data: Union[torch.Tensor, np.ndarray],
323
+ seed: Optional[int] = None
324
+ ) -> str:
325
+ """
326
+ Generate UUID from audio data (for caching/deduplication)
327
+
328
+ Args:
329
+ audio_data: Audio data
330
+ seed: Optional seed value
331
+
332
+ Returns:
333
+ UUID string
334
+ """
335
+ if isinstance(audio_data, torch.Tensor):
336
+ # Convert to numpy and calculate hash
337
+ audio_np = audio_data.cpu().numpy()
338
+ else:
339
+ audio_np = audio_data
340
+
341
+ # Calculate data hash
342
+ data_hash = hashlib.md5(audio_np.tobytes()).hexdigest()
343
+
344
+ if seed is not None:
345
+ combined = f"{data_hash}_{seed}"
346
+ return hashlib.md5(combined.encode()).hexdigest()
347
+
348
+ return data_hash
349
+
350
+
351
+ # Global default instance
352
+ _default_saver = AudioSaver(default_format="flac")
353
+
354
+
355
+ def save_audio(
356
+ audio_data: Union[torch.Tensor, np.ndarray],
357
+ output_path: Union[str, Path],
358
+ sample_rate: int = 48000,
359
+ format: Optional[str] = None,
360
+ channels_first: bool = True,
361
+ ) -> str:
362
+ """
363
+ Convenience function: save audio (using default configuration)
364
+
365
+ Args:
366
+ audio_data: Audio data
367
+ output_path: Output path
368
+ sample_rate: Sample rate
369
+ format: Format (default flac)
370
+ channels_first: Tensor format flag
371
+
372
+ Returns:
373
+ Saved file path
374
+ """
375
+ return _default_saver.save_audio(
376
+ audio_data, output_path, sample_rate, format, channels_first
377
+ )
378
+
spaces/Ace-Step-v1.5/acestep/constants.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Constants for ACE-Step
3
+ Centralized constants used across the codebase
4
+ """
5
+
6
+ # ==============================================================================
7
+ # Language Constants
8
+ # ==============================================================================
9
+
10
+ VALID_LANGUAGES = [
11
+ 'ar', 'az', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en',
12
+ 'es', 'fa', 'fi', 'fr', 'he', 'hi', 'hr', 'ht', 'hu', 'id',
13
+ 'is', 'it', 'ja', 'ko', 'la', 'lt', 'ms', 'ne', 'nl', 'no',
14
+ 'pa', 'pl', 'pt', 'ro', 'ru', 'sa', 'sk', 'sr', 'sv', 'sw',
15
+ 'ta', 'te', 'th', 'tl', 'tr', 'uk', 'ur', 'vi', 'yue', 'zh',
16
+ 'unknown'
17
+ ]
18
+
19
+
20
+ # ==============================================================================
21
+ # Keyscale Constants
22
+ # ==============================================================================
23
+
24
+ KEYSCALE_NOTES = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
25
+ KEYSCALE_ACCIDENTALS = ['', '#', 'b', '♯', '♭'] # empty + ASCII sharp/flat + Unicode sharp/flat
26
+ KEYSCALE_MODES = ['major', 'minor']
27
+
28
+ # Generate all valid keyscales: 7 notes × 5 accidentals × 2 modes = 70 combinations
29
+ VALID_KEYSCALES = set()
30
+ for note in KEYSCALE_NOTES:
31
+ for acc in KEYSCALE_ACCIDENTALS:
32
+ for mode in KEYSCALE_MODES:
33
+ VALID_KEYSCALES.add(f"{note}{acc} {mode}")
34
+
35
+
36
+ # ==============================================================================
37
+ # Metadata Range Constants
38
+ # ==============================================================================
39
+
40
+ # BPM (Beats Per Minute) range
41
+ BPM_MIN = 30
42
+ BPM_MAX = 300
43
+
44
+ # Duration range (in seconds)
45
+ DURATION_MIN = 10
46
+ DURATION_MAX = 600
47
+
48
+ # Valid time signatures
49
+ VALID_TIME_SIGNATURES = [2, 3, 4, 6]
50
+
51
+
52
+ # ==============================================================================
53
+ # Task Type Constants
54
+ # ==============================================================================
55
+
56
+ TASK_TYPES = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
57
+
58
+ # Task types available for turbo models (subset)
59
+ TASK_TYPES_TURBO = ["text2music", "repaint", "cover"]
60
+
61
+ # Task types available for base models (full set)
62
+ TASK_TYPES_BASE = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
63
+
64
+
65
+ # ==============================================================================
66
+ # Instruction Constants
67
+ # ==============================================================================
68
+
69
+ # Default instructions
70
+ DEFAULT_DIT_INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
71
+ DEFAULT_LM_INSTRUCTION = "Generate audio semantic tokens based on the given conditions:"
72
+ DEFAULT_LM_UNDERSTAND_INSTRUCTION = "Understand the given musical conditions and describe the audio semantics accordingly:"
73
+ DEFAULT_LM_INSPIRED_INSTRUCTION = "Expand the user's input into a more detailed and specific musical description:"
74
+ DEFAULT_LM_REWRITE_INSTRUCTION = "Format the user's input into a more detailed and specific musical description:"
75
+
76
+ # Instruction templates for each task type
77
+ # Note: Some instructions use placeholders like {TRACK_NAME} or {TRACK_CLASSES}
78
+ # These should be formatted using .format() or f-strings when used
79
+ TASK_INSTRUCTIONS = {
80
+ "text2music": "Fill the audio semantic mask based on the given conditions:",
81
+ "repaint": "Repaint the mask area based on the given conditions:",
82
+ "cover": "Generate audio semantic tokens based on the given conditions:",
83
+ "extract": "Extract the {TRACK_NAME} track from the audio:",
84
+ "extract_default": "Extract the track from the audio:",
85
+ "lego": "Generate the {TRACK_NAME} track based on the audio context:",
86
+ "lego_default": "Generate the track based on the audio context:",
87
+ "complete": "Complete the input track with {TRACK_CLASSES}:",
88
+ "complete_default": "Complete the input track:",
89
+ }
90
+
91
+
92
+ # ==============================================================================
93
+ # Track/Instrument Constants
94
+ # ==============================================================================
95
+
96
+ TRACK_NAMES = [
97
+ "woodwinds", "brass", "fx", "synth", "strings", "percussion",
98
+ "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"
99
+ ]
100
+
101
+ SFT_GEN_PROMPT = """# Instruction
102
+ {}
103
+
104
+ # Caption
105
+ {}
106
+
107
+ # Metas
108
+ {}<|endoftext|>
109
+ """
spaces/Ace-Step-v1.5/acestep/constrained_logits_processor.py ADDED
The diff for this file is too large to render. See raw diff
 
spaces/Ace-Step-v1.5/acestep/dataset_handler.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset Handler
3
+ Handles dataset import and exploration functionality
4
+ """
5
+ from typing import Optional, Tuple, Any, Dict
6
+
7
+
8
+ class DatasetHandler:
9
+ """Dataset Handler for Dataset Explorer functionality"""
10
+
11
+ def __init__(self):
12
+ """Initialize dataset handler"""
13
+ self.dataset = None
14
+ self.dataset_imported = False
15
+
16
+ def import_dataset(self, dataset_type: str) -> str:
17
+ """
18
+ Import dataset (temporarily disabled)
19
+
20
+ Args:
21
+ dataset_type: Type of dataset to import (e.g., "train", "test")
22
+
23
+ Returns:
24
+ Status message string
25
+ """
26
+ self.dataset_imported = False
27
+ return f"⚠️ Dataset import is currently disabled. Text2MusicDataset dependency not available."
28
+
29
+ def get_item_data(self, *args, **kwargs) -> Tuple:
30
+ """
31
+ Get dataset item (temporarily disabled)
32
+
33
+ Returns:
34
+ Tuple of placeholder values matching the expected return format
35
+ """
36
+ return "", "", "", "", "", None, None, None, "❌ Dataset not available", "", 0, "", None, None, None, {}, "text2music"
37
+
spaces/Ace-Step-v1.5/acestep/dit_alignment_score.py ADDED
@@ -0,0 +1,870 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DiT Alignment Score Module
3
+
4
+ This module provides lyrics-to-audio alignment using cross-attention matrices
5
+ from DiT model for generating LRC timestamps.
6
+
7
+ Refactored from lyrics_alignment_infos.py for integration with ACE-Step.
8
+ """
9
+ import numba
10
+ import torch
11
+ import numpy as np
12
+ import torch.nn.functional as F
13
+ from dataclasses import dataclass, asdict
14
+ from typing import List, Dict, Any, Optional, Tuple, Union
15
+
16
+
17
+ # ================= Data Classes =================
18
+ @dataclass
19
+ class TokenTimestamp:
20
+ """Stores per-token timing information."""
21
+ token_id: int
22
+ text: str
23
+ start: float
24
+ end: float
25
+ probability: float
26
+
27
+
28
+ @dataclass
29
+ class SentenceTimestamp:
30
+ """Stores per-sentence timing information with token list."""
31
+ text: str
32
+ start: float
33
+ end: float
34
+ tokens: List[TokenTimestamp]
35
+ confidence: float
36
+
37
+
38
+ # ================= DTW Algorithm (Numba Optimized) =================
39
+ @numba.jit(nopython=True)
40
+ def dtw_cpu(x: np.ndarray):
41
+ """
42
+ Dynamic Time Warping algorithm optimized with Numba.
43
+
44
+ Args:
45
+ x: Cost matrix of shape [N, M]
46
+
47
+ Returns:
48
+ Tuple of (text_indices, time_indices) arrays
49
+ """
50
+ N, M = x.shape
51
+ # Use float32 for memory efficiency
52
+ cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
53
+ trace = -np.ones((N + 1, M + 1), dtype=np.float32)
54
+ cost[0, 0] = 0
55
+
56
+ for j in range(1, M + 1):
57
+ for i in range(1, N + 1):
58
+ c0 = cost[i - 1, j - 1]
59
+ c1 = cost[i - 1, j]
60
+ c2 = cost[i, j - 1]
61
+
62
+ if c0 < c1 and c0 < c2:
63
+ c, t = c0, 0
64
+ elif c1 < c0 and c1 < c2:
65
+ c, t = c1, 1
66
+ else:
67
+ c, t = c2, 2
68
+
69
+ cost[i, j] = x[i - 1, j - 1] + c
70
+ trace[i, j] = t
71
+
72
+ return _backtrace(trace, N, M)
73
+
74
+
75
+ @numba.jit(nopython=True)
76
+ def _backtrace(trace: np.ndarray, N: int, M: int):
77
+ """
78
+ Optimized backtrace function for DTW.
79
+
80
+ Args:
81
+ trace: Trace matrix of shape (N+1, M+1)
82
+ N, M: Original matrix dimensions
83
+
84
+ Returns:
85
+ Path array of shape (2, path_len) - first row is text indices, second is time indices
86
+ """
87
+ # Boundary handling
88
+ trace[0, :] = 2
89
+ trace[:, 0] = 1
90
+
91
+ # Pre-allocate array, max path length is N+M
92
+ max_path_len = N + M
93
+ path = np.zeros((2, max_path_len), dtype=np.int32)
94
+
95
+ i, j = N, M
96
+ path_idx = max_path_len - 1
97
+
98
+ while i > 0 or j > 0:
99
+ path[0, path_idx] = i - 1 # text index
100
+ path[1, path_idx] = j - 1 # time index
101
+ path_idx -= 1
102
+
103
+ t = trace[i, j]
104
+ if t == 0:
105
+ i -= 1
106
+ j -= 1
107
+ elif t == 1:
108
+ i -= 1
109
+ elif t == 2:
110
+ j -= 1
111
+ else:
112
+ break
113
+
114
+ actual_len = max_path_len - path_idx - 1
115
+ return path[:, path_idx + 1:max_path_len]
116
+
117
+
118
+ # ================= Utility Functions =================
119
+ def median_filter(x: torch.Tensor, filter_width: int) -> torch.Tensor:
120
+ """
121
+ Apply median filter to tensor.
122
+
123
+ Args:
124
+ x: Input tensor
125
+ filter_width: Width of median filter
126
+
127
+ Returns:
128
+ Filtered tensor
129
+ """
130
+ pad_width = filter_width // 2
131
+ if x.shape[-1] <= pad_width:
132
+ return x
133
+ if x.ndim == 2:
134
+ x = x[None, :]
135
+ x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
136
+ result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
137
+ if result.ndim > 2:
138
+ result = result.squeeze(0)
139
+ return result
140
+
141
+
142
+ # ================= Main Aligner Class =================
143
+ class MusicStampsAligner:
144
+ """
145
+ Aligner class for generating lyrics timestamps from cross-attention matrices.
146
+
147
+ Uses bidirectional consensus denoising and DTW for alignment.
148
+ """
149
+
150
+ def __init__(self, tokenizer):
151
+ """
152
+ Initialize the aligner.
153
+
154
+ Args:
155
+ tokenizer: Text tokenizer for decoding tokens
156
+ """
157
+ self.tokenizer = tokenizer
158
+
159
+ def _apply_bidirectional_consensus(
160
+ self,
161
+ weights_stack: torch.Tensor,
162
+ violence_level: float,
163
+ medfilt_width: int
164
+ ) -> tuple:
165
+ """
166
+ Core denoising logic using bidirectional consensus.
167
+
168
+ Args:
169
+ weights_stack: Attention weights [Heads, Tokens, Frames]
170
+ violence_level: Denoising strength coefficient
171
+ medfilt_width: Median filter width
172
+
173
+ Returns:
174
+ Tuple of (calc_matrix, energy_matrix) as numpy arrays
175
+ """
176
+ # A. Bidirectional Consensus
177
+ row_prob = F.softmax(weights_stack, dim=-1) # Token -> Frame
178
+ col_prob = F.softmax(weights_stack, dim=-2) # Frame -> Token
179
+ processed = row_prob * col_prob
180
+
181
+ # 1. Row suppression (kill horizontal crossing lines)
182
+ row_medians = torch.quantile(processed, 0.5, dim=-1, keepdim=True)
183
+ processed = processed - (violence_level * row_medians)
184
+ processed = torch.relu(processed)
185
+
186
+ # 2. Column suppression (kill vertical crossing lines)
187
+ col_medians = torch.quantile(processed, 0.5, dim=-2, keepdim=True)
188
+ processed = processed - (violence_level * col_medians)
189
+ processed = torch.relu(processed)
190
+
191
+ # C. Power sharpening
192
+ processed = processed ** 2
193
+
194
+ # Energy matrix for confidence
195
+ energy_matrix = processed.mean(dim=0).cpu().numpy()
196
+
197
+ # D. Z-Score normalization
198
+ std, mean = torch.std_mean(processed, unbiased=False)
199
+ weights_processed = (processed - mean) / (std + 1e-9)
200
+
201
+ # E. Median filtering
202
+ weights_processed = median_filter(weights_processed, filter_width=medfilt_width)
203
+ calc_matrix = weights_processed.mean(dim=0).numpy()
204
+
205
+ return calc_matrix, energy_matrix
206
+
207
+ def _preprocess_attention(
208
+ self,
209
+ attention_matrix: torch.Tensor,
210
+ custom_config: Dict[int, List[int]],
211
+ violence_level: float,
212
+ medfilt_width: int = 7
213
+ ) -> tuple:
214
+ """
215
+ Preprocess attention matrix for alignment.
216
+
217
+ Args:
218
+ attention_matrix: Attention tensor [Layers, Heads, Tokens, Frames]
219
+ custom_config: Dict mapping layer indices to head indices
220
+ violence_level: Denoising strength
221
+ medfilt_width: Median filter width
222
+
223
+ Returns:
224
+ Tuple of (calc_matrix, energy_matrix, visual_matrix)
225
+ """
226
+ if not isinstance(attention_matrix, torch.Tensor):
227
+ weights = torch.tensor(attention_matrix)
228
+ else:
229
+ weights = attention_matrix.clone()
230
+
231
+ weights = weights.cpu().float()
232
+
233
+ selected_tensors = []
234
+ for layer_idx, head_indices in custom_config.items():
235
+ for head_idx in head_indices:
236
+ if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
237
+ head_matrix = weights[layer_idx, head_idx]
238
+ selected_tensors.append(head_matrix)
239
+
240
+ if not selected_tensors:
241
+ return None, None, None
242
+
243
+ # Stack selected heads: [Heads, Tokens, Frames]
244
+ weights_stack = torch.stack(selected_tensors, dim=0)
245
+ visual_matrix = weights_stack.mean(dim=0).numpy()
246
+
247
+ calc_matrix, energy_matrix = self._apply_bidirectional_consensus(
248
+ weights_stack, violence_level, medfilt_width
249
+ )
250
+
251
+ return calc_matrix, energy_matrix, visual_matrix
252
+
253
+ def stamps_align_info(
254
+ self,
255
+ attention_matrix: torch.Tensor,
256
+ lyrics_tokens: List[int],
257
+ total_duration_seconds: float,
258
+ custom_config: Dict[int, List[int]],
259
+ return_matrices: bool = False,
260
+ violence_level: float = 2.0,
261
+ medfilt_width: int = 1
262
+ ) -> Dict[str, Any]:
263
+ """
264
+ Get alignment information from attention matrix.
265
+
266
+ Args:
267
+ attention_matrix: Cross-attention tensor [Layers, Heads, Tokens, Frames]
268
+ lyrics_tokens: List of lyrics token IDs
269
+ total_duration_seconds: Total audio duration in seconds
270
+ custom_config: Dict mapping layer indices to head indices
271
+ return_matrices: Whether to return intermediate matrices
272
+ violence_level: Denoising strength
273
+ medfilt_width: Median filter width
274
+
275
+ Returns:
276
+ Dict containing calc_matrix, lyrics_tokens, total_duration_seconds,
277
+ and optionally energy_matrix and vis_matrix
278
+ """
279
+ calc_matrix, energy_matrix, visual_matrix = self._preprocess_attention(
280
+ attention_matrix, custom_config, violence_level, medfilt_width
281
+ )
282
+
283
+ if calc_matrix is None:
284
+ return {
285
+ "calc_matrix": None,
286
+ "lyrics_tokens": lyrics_tokens,
287
+ "total_duration_seconds": total_duration_seconds,
288
+ "error": "No valid attention heads found"
289
+ }
290
+
291
+ return_dict = {
292
+ "calc_matrix": calc_matrix,
293
+ "lyrics_tokens": lyrics_tokens,
294
+ "total_duration_seconds": total_duration_seconds
295
+ }
296
+
297
+ if return_matrices:
298
+ return_dict['energy_matrix'] = energy_matrix
299
+ return_dict['vis_matrix'] = visual_matrix
300
+
301
+ return return_dict
302
+
303
+ def _decode_tokens_incrementally(self, token_ids: List[int]) -> List[str]:
304
+ """
305
+ Decode tokens incrementally to properly handle multi-byte UTF-8 characters.
306
+
307
+ For Chinese and other multi-byte characters, the tokenizer may split them
308
+ into multiple byte-level tokens. Decoding each token individually produces
309
+ invalid UTF-8 sequences (showing as �). This method uses byte-level comparison
310
+ to correctly track which characters each token contributes.
311
+
312
+ Args:
313
+ token_ids: List of token IDs
314
+
315
+ Returns:
316
+ List of decoded text for each token position
317
+ """
318
+ decoded_tokens = []
319
+ prev_bytes = b""
320
+
321
+ for i in range(len(token_ids)):
322
+ # Decode tokens from start to current position
323
+ current_text = self.tokenizer.decode(token_ids[:i+1], skip_special_tokens=False)
324
+ current_bytes = current_text.encode('utf-8', errors='surrogatepass')
325
+
326
+ # The contribution of current token is the new bytes added
327
+ if len(current_bytes) >= len(prev_bytes):
328
+ new_bytes = current_bytes[len(prev_bytes):]
329
+ # Try to decode the new bytes; if incomplete, use empty string
330
+ try:
331
+ token_text = new_bytes.decode('utf-8')
332
+ except UnicodeDecodeError:
333
+ # Incomplete UTF-8 sequence, this token doesn't complete a character
334
+ token_text = ""
335
+ else:
336
+ # Edge case: current decode is shorter (shouldn't happen normally)
337
+ token_text = ""
338
+
339
+ decoded_tokens.append(token_text)
340
+ prev_bytes = current_bytes
341
+
342
+ return decoded_tokens
343
+
344
+ def token_timestamps(
345
+ self,
346
+ calc_matrix: np.ndarray,
347
+ lyrics_tokens: List[int],
348
+ total_duration_seconds: float
349
+ ) -> List[TokenTimestamp]:
350
+ """
351
+ Generate per-token timestamps using DTW.
352
+
353
+ Args:
354
+ calc_matrix: Processed attention matrix [Tokens, Frames]
355
+ lyrics_tokens: List of token IDs
356
+ total_duration_seconds: Total audio duration
357
+
358
+ Returns:
359
+ List of TokenTimestamp objects
360
+ """
361
+ n_frames = calc_matrix.shape[-1]
362
+ text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float64))
363
+
364
+ seconds_per_frame = total_duration_seconds / n_frames
365
+ alignment_results = []
366
+
367
+ # Use incremental decoding to properly handle multi-byte UTF-8 characters
368
+ decoded_tokens = self._decode_tokens_incrementally(lyrics_tokens)
369
+
370
+ for i in range(len(lyrics_tokens)):
371
+ mask = (text_indices == i)
372
+
373
+ if not np.any(mask):
374
+ start = alignment_results[-1].end if alignment_results else 0.0
375
+ end = start
376
+ token_conf = 0.0
377
+ else:
378
+ times = time_indices[mask] * seconds_per_frame
379
+ start = times[0]
380
+ end = times[-1]
381
+ token_conf = 0.0
382
+
383
+ if end < start:
384
+ end = start
385
+
386
+ alignment_results.append(TokenTimestamp(
387
+ token_id=lyrics_tokens[i],
388
+ text=decoded_tokens[i],
389
+ start=float(start),
390
+ end=float(end),
391
+ probability=token_conf
392
+ ))
393
+
394
+ return alignment_results
395
+
396
+ def _decode_sentence_from_tokens(self, tokens: List[TokenTimestamp]) -> str:
397
+ """
398
+ Decode a sentence by decoding all token IDs together.
399
+ This avoids UTF-8 encoding issues from joining individual token texts.
400
+
401
+ Args:
402
+ tokens: List of TokenTimestamp objects
403
+
404
+ Returns:
405
+ Properly decoded sentence text
406
+ """
407
+ token_ids = [t.token_id for t in tokens]
408
+ return self.tokenizer.decode(token_ids, skip_special_tokens=False)
409
+
410
+ def sentence_timestamps(
411
+ self,
412
+ token_alignment: List[TokenTimestamp]
413
+ ) -> List[SentenceTimestamp]:
414
+ """
415
+ Group token timestamps into sentence timestamps.
416
+
417
+ Args:
418
+ token_alignment: List of TokenTimestamp objects
419
+
420
+ Returns:
421
+ List of SentenceTimestamp objects
422
+ """
423
+ results = []
424
+ current_tokens = []
425
+
426
+ for token in token_alignment:
427
+ current_tokens.append(token)
428
+
429
+ if '\n' in token.text:
430
+ # Decode all token IDs together to avoid UTF-8 issues
431
+ full_text = self._decode_sentence_from_tokens(current_tokens)
432
+
433
+ if full_text.strip():
434
+ valid_scores = [t.probability for t in current_tokens if t.probability > 0]
435
+ sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
436
+
437
+ results.append(SentenceTimestamp(
438
+ text=full_text.strip(),
439
+ start=round(current_tokens[0].start, 3),
440
+ end=round(current_tokens[-1].end, 3),
441
+ tokens=list(current_tokens),
442
+ confidence=sent_conf
443
+ ))
444
+
445
+ current_tokens = []
446
+
447
+ # Handle last sentence
448
+ if current_tokens:
449
+ # Decode all token IDs together to avoid UTF-8 issues
450
+ full_text = self._decode_sentence_from_tokens(current_tokens)
451
+ if full_text.strip():
452
+ valid_scores = [t.probability for t in current_tokens if t.probability > 0]
453
+ sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
454
+
455
+ results.append(SentenceTimestamp(
456
+ text=full_text.strip(),
457
+ start=round(current_tokens[0].start, 3),
458
+ end=round(current_tokens[-1].end, 3),
459
+ tokens=list(current_tokens),
460
+ confidence=sent_conf
461
+ ))
462
+
463
+ # Normalize confidence scores
464
+ if results:
465
+ all_scores = [s.confidence for s in results]
466
+ min_score = min(all_scores)
467
+ max_score = max(all_scores)
468
+ score_range = max_score - min_score
469
+
470
+ if score_range > 1e-9:
471
+ for s in results:
472
+ normalized_score = (s.confidence - min_score) / score_range
473
+ s.confidence = round(normalized_score, 2)
474
+ else:
475
+ for s in results:
476
+ s.confidence = round(s.confidence, 2)
477
+
478
+ return results
479
+
480
+ def format_lrc(
481
+ self,
482
+ sentence_timestamps: List[SentenceTimestamp],
483
+ include_end_time: bool = False
484
+ ) -> str:
485
+ """
486
+ Format sentence timestamps as LRC lyrics format.
487
+
488
+ Args:
489
+ sentence_timestamps: List of SentenceTimestamp objects
490
+ include_end_time: Whether to include end time (enhanced LRC format)
491
+
492
+ Returns:
493
+ LRC formatted string
494
+ """
495
+ lines = []
496
+
497
+ for sentence in sentence_timestamps:
498
+ # Convert seconds to mm:ss.xx format
499
+ start_minutes = int(sentence.start // 60)
500
+ start_seconds = sentence.start % 60
501
+
502
+ if include_end_time:
503
+ end_minutes = int(sentence.end // 60)
504
+ end_seconds = sentence.end % 60
505
+ timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}][{end_minutes:02d}:{end_seconds:05.2f}]"
506
+ else:
507
+ timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}]"
508
+
509
+ # Clean the text (remove structural tags like [verse], [chorus])
510
+ text = sentence.text
511
+
512
+ lines.append(f"{timestamp}{text}")
513
+
514
+ return "\n".join(lines)
515
+
516
+ def get_timestamps_and_lrc(
517
+ self,
518
+ calc_matrix: np.ndarray,
519
+ lyrics_tokens: List[int],
520
+ total_duration_seconds: float
521
+ ) -> Dict[str, Any]:
522
+ """
523
+ Convenience method to get both timestamps and LRC in one call.
524
+
525
+ Args:
526
+ calc_matrix: Processed attention matrix
527
+ lyrics_tokens: List of token IDs
528
+ total_duration_seconds: Total audio duration
529
+
530
+ Returns:
531
+ Dict containing token_timestamps, sentence_timestamps, and lrc_text
532
+ """
533
+ token_stamps = self.token_timestamps(
534
+ calc_matrix=calc_matrix,
535
+ lyrics_tokens=lyrics_tokens,
536
+ total_duration_seconds=total_duration_seconds
537
+ )
538
+
539
+ sentence_stamps = self.sentence_timestamps(token_stamps)
540
+ lrc_text = self.format_lrc(sentence_stamps)
541
+
542
+ return {
543
+ "token_timestamps": token_stamps,
544
+ "sentence_timestamps": sentence_stamps,
545
+ "lrc_text": lrc_text
546
+ }
547
+
548
+
549
+ class MusicLyricScorer:
550
+ """
551
+ Scorer class for evaluating lyrics-to-audio alignment quality.
552
+
553
+ Focuses on calculating alignment quality metrics (Coverage, Monotonicity, Confidence)
554
+ using tensor operations for potential differentiability or GPU acceleration.
555
+ """
556
+
557
+ def __init__(self, tokenizer: Any):
558
+ """
559
+ Initialize the aligner.
560
+
561
+ Args:
562
+ tokenizer: Tokenizer instance (must implement .decode()).
563
+ """
564
+ self.tokenizer = tokenizer
565
+
566
+ def _generate_token_type_mask(self, token_ids: List[int]) -> np.ndarray:
567
+ """
568
+ Generate a mask distinguishing lyrics (1) from structural tags (0).
569
+ Uses self.tokenizer to decode tokens.
570
+
571
+ Args:
572
+ token_ids: List of token IDs.
573
+
574
+ Returns:
575
+ Numpy array of shape [len(token_ids)] with 1 or 0.
576
+ """
577
+ decoded_tokens = [self.tokenizer.decode([tid]) for tid in token_ids]
578
+ mask = np.ones(len(token_ids), dtype=np.int32)
579
+ in_bracket = False
580
+
581
+ for i, token_str in enumerate(decoded_tokens):
582
+ if '[' in token_str:
583
+ in_bracket = True
584
+ if in_bracket:
585
+ mask[i] = 0
586
+ if ']' in token_str:
587
+ in_bracket = False
588
+ mask[i] = 0
589
+ return mask
590
+
591
+ def _preprocess_attention(
592
+ self,
593
+ attention_matrix: Union[torch.Tensor, np.ndarray],
594
+ custom_config: Dict[int, List[int]],
595
+ medfilt_width: int = 1
596
+ ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[torch.Tensor]]:
597
+ """
598
+ Extracts and normalizes the attention matrix.
599
+
600
+ Logic V4: Uses Min-Max normalization to highlight energy differences.
601
+
602
+ Args:
603
+ attention_matrix: Raw attention tensor [Layers, Heads, Tokens, Frames].
604
+ custom_config: Config mapping layers to heads.
605
+ medfilt_width: Width for median filtering.
606
+
607
+ Returns:
608
+ Tuple of (calc_matrix, energy_matrix, avg_weights_tensor).
609
+ """
610
+ # 1. Prepare Tensor
611
+ if not isinstance(attention_matrix, torch.Tensor):
612
+ weights = torch.tensor(attention_matrix)
613
+ else:
614
+ weights = attention_matrix.clone()
615
+ weights = weights.cpu().float()
616
+
617
+ # 2. Select Heads based on config
618
+ selected_tensors = []
619
+ for layer_idx, head_indices in custom_config.items():
620
+ for head_idx in head_indices:
621
+ if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
622
+ selected_tensors.append(weights[layer_idx, head_idx])
623
+
624
+ if not selected_tensors:
625
+ return None, None, None
626
+
627
+ weights_stack = torch.stack(selected_tensors, dim=0)
628
+
629
+ # 3. Average Heads
630
+ avg_weights = weights_stack.mean(dim=0) # [Tokens, Frames]
631
+
632
+ # 4. Preprocessing Logic
633
+ # Min-Max normalization preserving energy distribution
634
+ # Median filter is applied to the energy matrix
635
+ energy_tensor = median_filter(avg_weights, filter_width=medfilt_width)
636
+ energy_matrix = energy_tensor.numpy()
637
+
638
+ e_min, e_max = energy_matrix.min(), energy_matrix.max()
639
+
640
+ if e_max - e_min > 1e-9:
641
+ energy_matrix = (energy_matrix - e_min) / (e_max - e_min)
642
+ else:
643
+ energy_matrix = np.zeros_like(energy_matrix)
644
+
645
+ # Contrast enhancement for DTW pathfinding
646
+ # calc_matrix is used for pathfinding, energy_matrix for scoring
647
+ calc_matrix = energy_matrix ** 2
648
+
649
+ return calc_matrix, energy_matrix, avg_weights
650
+
651
+ def _compute_alignment_metrics(
652
+ self,
653
+ energy_matrix: torch.Tensor,
654
+ path_coords: torch.Tensor,
655
+ type_mask: torch.Tensor,
656
+ time_weight: float = 0.01,
657
+ overlap_frames: float = 9.0,
658
+ instrumental_weight: float = 1.0
659
+ ) -> Tuple[float, float, float]:
660
+ """
661
+ Core metric calculation logic using high-precision Tensor operations.
662
+
663
+ Args:
664
+ energy_matrix: Normalized energy [Rows, Cols].
665
+ path_coords: DTW path coordinates [Steps, 2].
666
+ type_mask: Token type mask [Rows] (1=Lyrics, 0=Tags).
667
+ time_weight: Minimum energy threshold for monotonicity.
668
+ overlap_frames: Allowed overlap for monotonicity check.
669
+ instrumental_weight: Weight for non-lyric tokens in confidence calc.
670
+
671
+ Returns:
672
+ Tuple of (coverage, monotonicity, confidence).
673
+ """
674
+ # Ensure high precision for internal calculation
675
+ energy_matrix = energy_matrix.to(dtype=torch.float64)
676
+ path_coords = path_coords.long()
677
+ type_mask = type_mask.long()
678
+
679
+ device = energy_matrix.device
680
+ rows, cols = energy_matrix.shape
681
+
682
+ is_lyrics_row = (type_mask == 1)
683
+
684
+ # ================= A. Coverage Score =================
685
+ # Ratio of lyric lines that have significant energy peak
686
+ row_max_energies = energy_matrix.max(dim=1).values
687
+ total_sung_rows = is_lyrics_row.sum().double()
688
+
689
+ coverage_threshold = 0.1
690
+ valid_sung_mask = is_lyrics_row & (row_max_energies > coverage_threshold)
691
+ valid_sung_rows = valid_sung_mask.sum().double()
692
+
693
+ if total_sung_rows > 0:
694
+ coverage_score = valid_sung_rows / total_sung_rows
695
+ else:
696
+ coverage_score = torch.tensor(1.0, device=device, dtype=torch.float64)
697
+
698
+ # ================= B. Monotonicity Score =================
699
+ # Check if the "center of mass" of lyric lines moves forward in time
700
+ col_indices = torch.arange(cols, device=device, dtype=torch.float64)
701
+
702
+ # Zero out low energy noise
703
+ weights = torch.where(
704
+ energy_matrix > time_weight,
705
+ energy_matrix,
706
+ torch.zeros_like(energy_matrix)
707
+ )
708
+
709
+ sum_w = weights.sum(dim=1)
710
+ sum_t = (weights * col_indices).sum(dim=1)
711
+
712
+ # Calculate centroids
713
+ centroids = torch.full((rows,), -1.0, device=device, dtype=torch.float64)
714
+ valid_w_mask = sum_w > 1e-9
715
+ centroids[valid_w_mask] = sum_t[valid_w_mask] / sum_w[valid_w_mask]
716
+
717
+ # Extract sequence of valid lyrics centroids
718
+ valid_sequence_mask = is_lyrics_row & (centroids >= 0)
719
+ sung_centroids = centroids[valid_sequence_mask]
720
+
721
+ cnt = sung_centroids.shape[0]
722
+ if cnt > 1:
723
+ curr_c = sung_centroids[:-1]
724
+ next_c = sung_centroids[1:]
725
+
726
+ # Check non-decreasing order with overlap tolerance
727
+ non_decreasing = (next_c >= (curr_c - overlap_frames)).double().sum()
728
+ pairs = torch.tensor(cnt - 1, device=device, dtype=torch.float64)
729
+ monotonicity_score = non_decreasing / pairs
730
+ else:
731
+ monotonicity_score = torch.tensor(1.0, device=device, dtype=torch.float64)
732
+
733
+ # ================= C. Path Confidence =================
734
+ # Average energy along the optimal path
735
+ if path_coords.shape[0] > 0:
736
+ p_rows = path_coords[:, 0]
737
+ p_cols = path_coords[:, 1]
738
+
739
+ path_energies = energy_matrix[p_rows, p_cols]
740
+ step_weights = torch.ones_like(path_energies)
741
+
742
+ # Lower weight for instrumental/tag steps
743
+ is_inst_step = (type_mask[p_rows] == 0)
744
+ step_weights[is_inst_step] = instrumental_weight
745
+
746
+ total_energy = (path_energies * step_weights).sum()
747
+ total_steps = step_weights.sum()
748
+
749
+ if total_steps > 0:
750
+ path_confidence = total_energy / total_steps
751
+ else:
752
+ path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
753
+ else:
754
+ path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
755
+
756
+ return coverage_score.item(), monotonicity_score.item(), path_confidence.item()
757
+
758
+ def lyrics_alignment_info(
759
+ self,
760
+ attention_matrix: Union[torch.Tensor, np.ndarray],
761
+ token_ids: List[int],
762
+ custom_config: Dict[int, List[int]],
763
+ return_matrices: bool = False,
764
+ medfilt_width: int = 1
765
+ ) -> Dict[str, Any]:
766
+ """
767
+ Generates alignment path and processed matrices.
768
+
769
+ Args:
770
+ attention_matrix: Input attention tensor.
771
+ token_ids: Corresponding token IDs.
772
+ custom_config: Layer/Head configuration.
773
+ return_matrices: If True, returns matrices in the output.
774
+ medfilt_width: Median filter width.
775
+
776
+ Returns:
777
+ Dict or AlignmentInfo object containing path and masks.
778
+ """
779
+ calc_matrix, energy_matrix, vis_matrix = self._preprocess_attention(
780
+ attention_matrix, custom_config, medfilt_width
781
+ )
782
+
783
+ if calc_matrix is None:
784
+ return {
785
+ "calc_matrix": None,
786
+ "error": "No valid attention heads found"
787
+ }
788
+
789
+ # 1. Generate Semantic Mask (1=Lyrics, 0=Tags)
790
+ # Uses self.tokenizer internally
791
+ type_mask = self._generate_token_type_mask(token_ids)
792
+
793
+ # Safety check for shape mismatch
794
+ if len(type_mask) != energy_matrix.shape[0]:
795
+ # Fallback to all lyrics if shapes don't align
796
+ type_mask = np.ones(energy_matrix.shape[0], dtype=np.int32)
797
+
798
+ # 2. DTW Pathfinding
799
+ # Using negative calc_matrix because DTW minimizes cost
800
+ text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float32))
801
+ path_coords = np.stack([text_indices, time_indices], axis=1)
802
+
803
+ return_dict = {
804
+ "path_coords": path_coords,
805
+ "type_mask": type_mask,
806
+ "energy_matrix": energy_matrix
807
+ }
808
+ if return_matrices:
809
+ return_dict['calc_matrix'] = calc_matrix
810
+ return_dict['vis_matrix'] = vis_matrix
811
+
812
+ return return_dict
813
+
814
+ def calculate_score(
815
+ self,
816
+ energy_matrix: Union[torch.Tensor, np.ndarray],
817
+ type_mask: Union[torch.Tensor, np.ndarray],
818
+ path_coords: Union[torch.Tensor, np.ndarray],
819
+ time_weight: float = 0.01,
820
+ overlap_frames: float = 9.0,
821
+ instrumental_weight: float = 1.0
822
+ ) -> Dict[str, Any]:
823
+ """
824
+ Calculates the final alignment score based on pre-computed components.
825
+
826
+ Args:
827
+ energy_matrix: Processed energy matrix.
828
+ type_mask: Token type mask.
829
+ path_coords: DTW path coordinates.
830
+ time_weight: Minimum energy threshold for monotonicity.
831
+ overlap_frames: Allowed backward movement frames.
832
+ instrumental_weight: Weight for non-lyric path steps.
833
+
834
+ Returns:
835
+ AlignmentScore object containing individual metrics and final score.
836
+ """
837
+ # Ensure Inputs are Tensors on the correct device
838
+ if not isinstance(energy_matrix, torch.Tensor):
839
+ energy_matrix = torch.tensor(energy_matrix, device='cuda', dtype=torch.float32)
840
+
841
+ device = energy_matrix.device
842
+
843
+ if not isinstance(type_mask, torch.Tensor):
844
+ type_mask = torch.tensor(type_mask, device=device, dtype=torch.long)
845
+ else:
846
+ type_mask = type_mask.to(device=device, dtype=torch.long)
847
+
848
+ if not isinstance(path_coords, torch.Tensor):
849
+ path_coords = torch.tensor(path_coords, device=device, dtype=torch.long)
850
+ else:
851
+ path_coords = path_coords.to(device=device, dtype=torch.long)
852
+
853
+ # Compute Metrics
854
+ coverage, monotonicity, confidence = self._compute_alignment_metrics(
855
+ energy_matrix=energy_matrix,
856
+ path_coords=path_coords,
857
+ type_mask=type_mask,
858
+ time_weight=time_weight,
859
+ overlap_frames=overlap_frames,
860
+ instrumental_weight=instrumental_weight
861
+ )
862
+
863
+ # Final Score Calculation
864
+ # (Cov^2 * Mono^2 * Conf)
865
+ final_score = (coverage ** 2) * (monotonicity ** 2) * confidence
866
+ final_score = float(np.clip(final_score, 0.0, 1.0))
867
+
868
+ return {
869
+ "lyrics_score": round(final_score, 4)
870
+ }
spaces/Ace-Step-v1.5/acestep/genres_vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
spaces/Ace-Step-v1.5/acestep/gradio_ui/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from acestep.gradio_ui.interfaces import create_gradio_interface
spaces/Ace-Step-v1.5/acestep/gradio_ui/events/__init__.py ADDED
@@ -0,0 +1,1310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Event Handlers Module
3
+ Main entry point for setting up all event handlers
4
+ """
5
+ import gradio as gr
6
+ from typing import Optional
7
+
8
+ # Import handler modules
9
+ from . import generation_handlers as gen_h
10
+ from . import results_handlers as res_h
11
+ from . import training_handlers as train_h
12
+ from acestep.gradio_ui.i18n import t
13
+
14
+
15
+ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section, init_params=None):
16
+ """Setup event handlers connecting UI components and business logic
17
+
18
+ Args:
19
+ init_params: Dictionary containing initialization parameters including:
20
+ - dit_handler_2: Optional second DiT handler for multi-model setup
21
+ - available_dit_models: List of available DiT model names
22
+ - config_path: Primary model config path
23
+ - config_path_2: Secondary model config path (if available)
24
+ """
25
+ # Get secondary DiT handler from init_params (for multi-model support)
26
+ dit_handler_2 = init_params.get('dit_handler_2') if init_params else None
27
+ config_path_1 = init_params.get('config_path', '') if init_params else ''
28
+ config_path_2 = init_params.get('config_path_2', '') if init_params else ''
29
+
30
+ # ========== Dataset Handlers ==========
31
+ dataset_section["import_dataset_btn"].click(
32
+ fn=dataset_handler.import_dataset,
33
+ inputs=[dataset_section["dataset_type"]],
34
+ outputs=[dataset_section["data_status"]]
35
+ )
36
+
37
+ # ========== Service Initialization ==========
38
+ generation_section["refresh_btn"].click(
39
+ fn=lambda: gen_h.refresh_checkpoints(dit_handler),
40
+ outputs=[generation_section["checkpoint_dropdown"]]
41
+ )
42
+
43
+ generation_section["config_path"].change(
44
+ fn=gen_h.update_model_type_settings,
45
+ inputs=[generation_section["config_path"]],
46
+ outputs=[
47
+ generation_section["inference_steps"],
48
+ generation_section["guidance_scale"],
49
+ generation_section["use_adg"],
50
+ generation_section["shift"],
51
+ generation_section["cfg_interval_start"],
52
+ generation_section["cfg_interval_end"],
53
+ generation_section["task_type"],
54
+ ]
55
+ )
56
+
57
+ generation_section["init_btn"].click(
58
+ fn=lambda *args: gen_h.init_service_wrapper(dit_handler, llm_handler, *args),
59
+ inputs=[
60
+ generation_section["checkpoint_dropdown"],
61
+ generation_section["config_path"],
62
+ generation_section["device"],
63
+ generation_section["init_llm_checkbox"],
64
+ generation_section["lm_model_path"],
65
+ generation_section["backend_dropdown"],
66
+ generation_section["use_flash_attention_checkbox"],
67
+ generation_section["offload_to_cpu_checkbox"],
68
+ generation_section["offload_dit_to_cpu_checkbox"],
69
+ ],
70
+ outputs=[
71
+ generation_section["init_status"],
72
+ generation_section["generate_btn"],
73
+ generation_section["service_config_accordion"],
74
+ # Model type settings (updated based on actual loaded model)
75
+ generation_section["inference_steps"],
76
+ generation_section["guidance_scale"],
77
+ generation_section["use_adg"],
78
+ generation_section["shift"],
79
+ generation_section["cfg_interval_start"],
80
+ generation_section["cfg_interval_end"],
81
+ generation_section["task_type"],
82
+ ]
83
+ )
84
+
85
+ # ========== LoRA Handlers ==========
86
+ generation_section["load_lora_btn"].click(
87
+ fn=dit_handler.load_lora,
88
+ inputs=[generation_section["lora_path"]],
89
+ outputs=[generation_section["lora_status"]]
90
+ ).then(
91
+ # Update checkbox to enabled state after loading
92
+ fn=lambda: gr.update(value=True),
93
+ outputs=[generation_section["use_lora_checkbox"]]
94
+ )
95
+
96
+ generation_section["unload_lora_btn"].click(
97
+ fn=dit_handler.unload_lora,
98
+ outputs=[generation_section["lora_status"]]
99
+ ).then(
100
+ # Update checkbox to disabled state after unloading
101
+ fn=lambda: gr.update(value=False),
102
+ outputs=[generation_section["use_lora_checkbox"]]
103
+ )
104
+
105
+ generation_section["use_lora_checkbox"].change(
106
+ fn=dit_handler.set_use_lora,
107
+ inputs=[generation_section["use_lora_checkbox"]],
108
+ outputs=[generation_section["lora_status"]]
109
+ )
110
+
111
+ # ========== UI Visibility Updates ==========
112
+ generation_section["init_llm_checkbox"].change(
113
+ fn=gen_h.update_negative_prompt_visibility,
114
+ inputs=[generation_section["init_llm_checkbox"]],
115
+ outputs=[generation_section["lm_negative_prompt"]]
116
+ )
117
+
118
+ generation_section["init_llm_checkbox"].change(
119
+ fn=gen_h.update_audio_cover_strength_visibility,
120
+ inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"]],
121
+ outputs=[generation_section["audio_cover_strength"]]
122
+ )
123
+
124
+ generation_section["task_type"].change(
125
+ fn=gen_h.update_audio_cover_strength_visibility,
126
+ inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"]],
127
+ outputs=[generation_section["audio_cover_strength"]]
128
+ )
129
+
130
+ generation_section["batch_size_input"].change(
131
+ fn=gen_h.update_audio_components_visibility,
132
+ inputs=[generation_section["batch_size_input"]],
133
+ outputs=[
134
+ results_section["audio_col_1"],
135
+ results_section["audio_col_2"],
136
+ results_section["audio_col_3"],
137
+ results_section["audio_col_4"],
138
+ results_section["audio_row_5_8"],
139
+ results_section["audio_col_5"],
140
+ results_section["audio_col_6"],
141
+ results_section["audio_col_7"],
142
+ results_section["audio_col_8"],
143
+ ]
144
+ )
145
+
146
+ # ========== Audio Conversion ==========
147
+ generation_section["convert_src_to_codes_btn"].click(
148
+ fn=lambda src: gen_h.convert_src_audio_to_codes_wrapper(dit_handler, src),
149
+ inputs=[generation_section["src_audio"]],
150
+ outputs=[generation_section["text2music_audio_code_string"]]
151
+ )
152
+
153
+ # ========== Instruction UI Updates ==========
154
+ for trigger in [generation_section["task_type"], generation_section["track_name"], generation_section["complete_track_classes"]]:
155
+ trigger.change(
156
+ fn=lambda *args: gen_h.update_instruction_ui(dit_handler, *args),
157
+ inputs=[
158
+ generation_section["task_type"],
159
+ generation_section["track_name"],
160
+ generation_section["complete_track_classes"],
161
+ generation_section["text2music_audio_code_string"],
162
+ generation_section["init_llm_checkbox"]
163
+ ],
164
+ outputs=[
165
+ generation_section["instruction_display_gen"],
166
+ generation_section["track_name"],
167
+ generation_section["complete_track_classes"],
168
+ generation_section["audio_cover_strength"],
169
+ generation_section["repainting_group"],
170
+ generation_section["text2music_audio_codes_group"],
171
+ ]
172
+ )
173
+
174
+ # ========== Sample/Transcribe Handlers ==========
175
+ # Load random example from ./examples/text2music directory
176
+ generation_section["sample_btn"].click(
177
+ fn=lambda task: gen_h.load_random_example(task) + (True,),
178
+ inputs=[
179
+ generation_section["task_type"],
180
+ ],
181
+ outputs=[
182
+ generation_section["captions"],
183
+ generation_section["lyrics"],
184
+ generation_section["think_checkbox"],
185
+ generation_section["bpm"],
186
+ generation_section["audio_duration"],
187
+ generation_section["key_scale"],
188
+ generation_section["vocal_language"],
189
+ generation_section["time_signature"],
190
+ results_section["is_format_caption_state"]
191
+ ]
192
+ )
193
+
194
+ generation_section["text2music_audio_code_string"].change(
195
+ fn=gen_h.update_transcribe_button_text,
196
+ inputs=[generation_section["text2music_audio_code_string"]],
197
+ outputs=[generation_section["transcribe_btn"]]
198
+ )
199
+
200
+ generation_section["transcribe_btn"].click(
201
+ fn=lambda codes, debug: gen_h.transcribe_audio_codes(llm_handler, codes, debug),
202
+ inputs=[
203
+ generation_section["text2music_audio_code_string"],
204
+ generation_section["constrained_decoding_debug"]
205
+ ],
206
+ outputs=[
207
+ results_section["status_output"],
208
+ generation_section["captions"],
209
+ generation_section["lyrics"],
210
+ generation_section["bpm"],
211
+ generation_section["audio_duration"],
212
+ generation_section["key_scale"],
213
+ generation_section["vocal_language"],
214
+ generation_section["time_signature"],
215
+ results_section["is_format_caption_state"]
216
+ ]
217
+ )
218
+
219
+ # ========== Reset Format Caption Flag ==========
220
+ for trigger in [generation_section["captions"], generation_section["lyrics"], generation_section["bpm"],
221
+ generation_section["key_scale"], generation_section["time_signature"],
222
+ generation_section["vocal_language"], generation_section["audio_duration"]]:
223
+ trigger.change(
224
+ fn=gen_h.reset_format_caption_flag,
225
+ inputs=[],
226
+ outputs=[results_section["is_format_caption_state"]]
227
+ )
228
+
229
+ # ========== Audio Uploads Accordion ==========
230
+ for trigger in [generation_section["reference_audio"], generation_section["src_audio"]]:
231
+ trigger.change(
232
+ fn=gen_h.update_audio_uploads_accordion,
233
+ inputs=[generation_section["reference_audio"], generation_section["src_audio"]],
234
+ outputs=[generation_section["audio_uploads_accordion"]]
235
+ )
236
+
237
+ # ========== Instrumental Checkbox ==========
238
+ generation_section["instrumental_checkbox"].change(
239
+ fn=gen_h.handle_instrumental_checkbox,
240
+ inputs=[generation_section["instrumental_checkbox"], generation_section["lyrics"]],
241
+ outputs=[generation_section["lyrics"]]
242
+ )
243
+
244
+ # ========== Format Button ==========
245
+ # Note: cfg_scale and negative_prompt are not supported in format mode
246
+ generation_section["format_btn"].click(
247
+ fn=lambda caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug: gen_h.handle_format_sample(
248
+ llm_handler, caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug
249
+ ),
250
+ inputs=[
251
+ generation_section["captions"],
252
+ generation_section["lyrics"],
253
+ generation_section["bpm"],
254
+ generation_section["audio_duration"],
255
+ generation_section["key_scale"],
256
+ generation_section["time_signature"],
257
+ generation_section["lm_temperature"],
258
+ generation_section["lm_top_k"],
259
+ generation_section["lm_top_p"],
260
+ generation_section["constrained_decoding_debug"],
261
+ ],
262
+ outputs=[
263
+ generation_section["captions"],
264
+ generation_section["lyrics"],
265
+ generation_section["bpm"],
266
+ generation_section["audio_duration"],
267
+ generation_section["key_scale"],
268
+ generation_section["vocal_language"],
269
+ generation_section["time_signature"],
270
+ results_section["is_format_caption_state"],
271
+ results_section["status_output"],
272
+ ]
273
+ )
274
+
275
+ # ========== Generation Mode Toggle (Simple/Custom/Cover/Repaint) ==========
276
+ generation_section["generation_mode"].change(
277
+ fn=gen_h.handle_generation_mode_change,
278
+ inputs=[generation_section["generation_mode"]],
279
+ outputs=[
280
+ generation_section["simple_mode_group"],
281
+ generation_section["custom_mode_content"],
282
+ generation_section["cover_mode_group"],
283
+ generation_section["repainting_group"],
284
+ generation_section["task_type"],
285
+ generation_section["generate_btn"],
286
+ generation_section["simple_sample_created"],
287
+ generation_section["src_audio_group"],
288
+ generation_section["audio_cover_strength"],
289
+ generation_section["think_checkbox"], # Disable thinking for cover/repaint modes
290
+ ]
291
+ )
292
+
293
+ # ========== Process Source Audio Button ==========
294
+ # Combines Convert to Codes + Transcribe in one step
295
+ generation_section["process_src_btn"].click(
296
+ fn=lambda src, debug: gen_h.process_source_audio(dit_handler, llm_handler, src, debug),
297
+ inputs=[
298
+ generation_section["src_audio"],
299
+ generation_section["constrained_decoding_debug"]
300
+ ],
301
+ outputs=[
302
+ generation_section["text2music_audio_code_string"],
303
+ results_section["status_output"],
304
+ generation_section["captions"],
305
+ generation_section["lyrics"],
306
+ generation_section["bpm"],
307
+ generation_section["audio_duration"],
308
+ generation_section["key_scale"],
309
+ generation_section["vocal_language"],
310
+ generation_section["time_signature"],
311
+ results_section["is_format_caption_state"],
312
+ ]
313
+ )
314
+
315
+ # ========== Simple Mode Instrumental Checkbox ==========
316
+ # When instrumental is checked, disable vocal language and set to ["unknown"]
317
+ generation_section["simple_instrumental_checkbox"].change(
318
+ fn=gen_h.handle_simple_instrumental_change,
319
+ inputs=[generation_section["simple_instrumental_checkbox"]],
320
+ outputs=[generation_section["simple_vocal_language"]]
321
+ )
322
+
323
+ # ========== Random Description Button ==========
324
+ generation_section["random_desc_btn"].click(
325
+ fn=gen_h.load_random_simple_description,
326
+ inputs=[],
327
+ outputs=[
328
+ generation_section["simple_query_input"],
329
+ generation_section["simple_instrumental_checkbox"],
330
+ generation_section["simple_vocal_language"],
331
+ ]
332
+ )
333
+
334
+ # ========== Create Sample Button (Simple Mode) ==========
335
+ # Note: cfg_scale and negative_prompt are not supported in create_sample mode
336
+ generation_section["create_sample_btn"].click(
337
+ fn=lambda query, instrumental, vocal_lang, temp, top_k, top_p, debug: gen_h.handle_create_sample(
338
+ llm_handler, query, instrumental, vocal_lang, temp, top_k, top_p, debug
339
+ ),
340
+ inputs=[
341
+ generation_section["simple_query_input"],
342
+ generation_section["simple_instrumental_checkbox"],
343
+ generation_section["simple_vocal_language"],
344
+ generation_section["lm_temperature"],
345
+ generation_section["lm_top_k"],
346
+ generation_section["lm_top_p"],
347
+ generation_section["constrained_decoding_debug"],
348
+ ],
349
+ outputs=[
350
+ generation_section["captions"],
351
+ generation_section["lyrics"],
352
+ generation_section["bpm"],
353
+ generation_section["audio_duration"],
354
+ generation_section["key_scale"],
355
+ generation_section["vocal_language"],
356
+ generation_section["simple_vocal_language"],
357
+ generation_section["time_signature"],
358
+ generation_section["instrumental_checkbox"],
359
+ generation_section["caption_accordion"],
360
+ generation_section["lyrics_accordion"],
361
+ generation_section["generate_btn"],
362
+ generation_section["simple_sample_created"],
363
+ generation_section["think_checkbox"],
364
+ results_section["is_format_caption_state"],
365
+ results_section["status_output"],
366
+ ]
367
+ )
368
+
369
+ # ========== Load/Save Metadata ==========
370
+ generation_section["load_file"].upload(
371
+ fn=gen_h.load_metadata,
372
+ inputs=[generation_section["load_file"]],
373
+ outputs=[
374
+ generation_section["task_type"],
375
+ generation_section["captions"],
376
+ generation_section["lyrics"],
377
+ generation_section["vocal_language"],
378
+ generation_section["bpm"],
379
+ generation_section["key_scale"],
380
+ generation_section["time_signature"],
381
+ generation_section["audio_duration"],
382
+ generation_section["batch_size_input"],
383
+ generation_section["inference_steps"],
384
+ generation_section["guidance_scale"],
385
+ generation_section["seed"],
386
+ generation_section["random_seed_checkbox"],
387
+ generation_section["use_adg"],
388
+ generation_section["cfg_interval_start"],
389
+ generation_section["cfg_interval_end"],
390
+ generation_section["shift"],
391
+ generation_section["infer_method"],
392
+ generation_section["custom_timesteps"],
393
+ generation_section["audio_format"],
394
+ generation_section["lm_temperature"],
395
+ generation_section["lm_cfg_scale"],
396
+ generation_section["lm_top_k"],
397
+ generation_section["lm_top_p"],
398
+ generation_section["lm_negative_prompt"],
399
+ generation_section["use_cot_metas"], # Added: use_cot_metas
400
+ generation_section["use_cot_caption"],
401
+ generation_section["use_cot_language"],
402
+ generation_section["audio_cover_strength"],
403
+ generation_section["think_checkbox"],
404
+ generation_section["text2music_audio_code_string"],
405
+ generation_section["repainting_start"],
406
+ generation_section["repainting_end"],
407
+ generation_section["track_name"],
408
+ generation_section["complete_track_classes"],
409
+ generation_section["instrumental_checkbox"], # Added: instrumental_checkbox
410
+ results_section["is_format_caption_state"]
411
+ ]
412
+ )
413
+
414
+ # Save buttons for all 8 audio outputs
415
+ download_existing_js = """(current_audio, batch_files) => {
416
+ // Debug: print what the input actually is
417
+ console.log("👉 [Debug] Current Audio Input:", current_audio);
418
+
419
+ // 1. Safety check
420
+ if (!current_audio) {
421
+ console.warn("⚠️ No audio selected or audio is empty.");
422
+ return;
423
+ }
424
+ if (!batch_files || !Array.isArray(batch_files)) {
425
+ console.warn("⚠️ Batch file list is empty/not ready.");
426
+ return;
427
+ }
428
+
429
+ // 2. Smartly extract path string
430
+ let pathString = "";
431
+
432
+ if (typeof current_audio === "string") {
433
+ // Case A: direct path string received
434
+ pathString = current_audio;
435
+ } else if (typeof current_audio === "object") {
436
+ // Case B: an object is received, try common properties
437
+ // Gradio file objects usually have path, url, or name
438
+ pathString = current_audio.path || current_audio.name || current_audio.url || "";
439
+ }
440
+
441
+ if (!pathString) {
442
+ console.error("❌ Error: Could not extract a valid path string from input.", current_audio);
443
+ return;
444
+ }
445
+
446
+ // 3. Extract Key (UUID)
447
+ // Path could be /tmp/.../uuid.mp3 or url like /file=.../uuid.mp3
448
+ let filename = pathString.split(/[\\\\/]/).pop(); // get the filename
449
+ let key = filename.split('.')[0]; // get UUID without extension
450
+
451
+ console.log(`🔑 Key extracted: ${key}`);
452
+
453
+ // 4. Find matching file(s) in the list
454
+ let targets = batch_files.filter(f => {
455
+ // Also extract names from batch_files objects
456
+ // f usually contains name (backend path) and orig_name (download name)
457
+ const fPath = f.name || f.path || "";
458
+ return fPath.includes(key);
459
+ });
460
+
461
+ if (targets.length === 0) {
462
+ console.warn("❌ No matching files found in batch list for key:", key);
463
+ alert("Batch list does not contain this file yet. Please wait for generation to finish.");
464
+ return;
465
+ }
466
+
467
+ // 5. Trigger download(s)
468
+ console.log(`🎯 Found ${targets.length} files to download.`);
469
+ targets.forEach((f, index) => {
470
+ setTimeout(() => {
471
+ const a = document.createElement('a');
472
+ // Prefer url (frontend-accessible link), otherwise try data
473
+ a.href = f.url || f.data;
474
+ a.download = f.orig_name || "download";
475
+ a.style.display = 'none';
476
+ document.body.appendChild(a);
477
+ a.click();
478
+ document.body.removeChild(a);
479
+ }, index * 1000); // 300ms interval to avoid browser blocking
480
+ });
481
+ }
482
+ """
483
+ for btn_idx in range(1, 9):
484
+ results_section[f"save_btn_{btn_idx}"].click(
485
+ fn=None,
486
+ inputs=[
487
+ results_section[f"generated_audio_{btn_idx}"],
488
+ results_section["generated_audio_batch"],
489
+ ],
490
+ js=download_existing_js # Run the above JS
491
+ )
492
+ # ========== Send to Cover Handlers ==========
493
+ def send_to_cover_handler(audio_file, lm_metadata):
494
+ """Send audio to cover mode and switch to cover"""
495
+ if audio_file is None:
496
+ return (gr.skip(),) * 11
497
+ return (
498
+ audio_file, # src_audio
499
+ gr.skip(), # bpm
500
+ gr.skip(), # captions
501
+ gr.skip(), # lyrics
502
+ gr.skip(), # audio_duration
503
+ gr.skip(), # key_scale
504
+ gr.skip(), # vocal_language
505
+ gr.skip(), # time_signature
506
+ gr.skip(), # is_format_caption_state
507
+ "cover", # generation_mode - switch to cover
508
+ "cover", # task_type - set to cover
509
+ )
510
+
511
+ for btn_idx in range(1, 9):
512
+ results_section[f"send_to_cover_btn_{btn_idx}"].click(
513
+ fn=send_to_cover_handler,
514
+ inputs=[
515
+ results_section[f"generated_audio_{btn_idx}"],
516
+ results_section["lm_metadata_state"]
517
+ ],
518
+ outputs=[
519
+ generation_section["src_audio"],
520
+ generation_section["bpm"],
521
+ generation_section["captions"],
522
+ generation_section["lyrics"],
523
+ generation_section["audio_duration"],
524
+ generation_section["key_scale"],
525
+ generation_section["vocal_language"],
526
+ generation_section["time_signature"],
527
+ results_section["is_format_caption_state"],
528
+ generation_section["generation_mode"],
529
+ generation_section["task_type"],
530
+ ]
531
+ )
532
+
533
+ # ========== Send to Repaint Handlers ==========
534
+ def send_to_repaint_handler(audio_file, lm_metadata):
535
+ """Send audio to repaint mode and switch to repaint"""
536
+ if audio_file is None:
537
+ return (gr.skip(),) * 11
538
+ return (
539
+ audio_file, # src_audio
540
+ gr.skip(), # bpm
541
+ gr.skip(), # captions
542
+ gr.skip(), # lyrics
543
+ gr.skip(), # audio_duration
544
+ gr.skip(), # key_scale
545
+ gr.skip(), # vocal_language
546
+ gr.skip(), # time_signature
547
+ gr.skip(), # is_format_caption_state
548
+ "repaint", # generation_mode - switch to repaint
549
+ "repaint", # task_type - set to repaint
550
+ )
551
+
552
+ for btn_idx in range(1, 9):
553
+ results_section[f"send_to_repaint_btn_{btn_idx}"].click(
554
+ fn=send_to_repaint_handler,
555
+ inputs=[
556
+ results_section[f"generated_audio_{btn_idx}"],
557
+ results_section["lm_metadata_state"]
558
+ ],
559
+ outputs=[
560
+ generation_section["src_audio"],
561
+ generation_section["bpm"],
562
+ generation_section["captions"],
563
+ generation_section["lyrics"],
564
+ generation_section["audio_duration"],
565
+ generation_section["key_scale"],
566
+ generation_section["vocal_language"],
567
+ generation_section["time_signature"],
568
+ results_section["is_format_caption_state"],
569
+ generation_section["generation_mode"],
570
+ generation_section["task_type"],
571
+ ]
572
+ )
573
+
574
+ # ========== Score Calculation Handlers ==========
575
+ # Use default argument to capture btn_idx value at definition time (Python closure fix)
576
+ def make_score_handler(idx):
577
+ return lambda scale, batch_idx, queue: res_h.calculate_score_handler_with_selection(
578
+ dit_handler, llm_handler, idx, scale, batch_idx, queue
579
+ )
580
+
581
+ for btn_idx in range(1, 9):
582
+ results_section[f"score_btn_{btn_idx}"].click(
583
+ fn=make_score_handler(btn_idx),
584
+ inputs=[
585
+ generation_section["score_scale"],
586
+ results_section["current_batch_index"],
587
+ results_section["batch_queue"],
588
+ ],
589
+ outputs=[
590
+ results_section[f"score_display_{btn_idx}"],
591
+ results_section[f"details_accordion_{btn_idx}"],
592
+ results_section["batch_queue"]
593
+ ]
594
+ )
595
+
596
+ # ========== LRC Timestamp Handlers ==========
597
+ # Use default argument to capture btn_idx value at definition time (Python closure fix)
598
+ def make_lrc_handler(idx):
599
+ return lambda batch_idx, queue, vocal_lang, infer_steps: res_h.generate_lrc_handler(
600
+ dit_handler, idx, batch_idx, queue, vocal_lang, infer_steps
601
+ )
602
+
603
+ for btn_idx in range(1, 9):
604
+ results_section[f"lrc_btn_{btn_idx}"].click(
605
+ fn=make_lrc_handler(btn_idx),
606
+ inputs=[
607
+ results_section["current_batch_index"],
608
+ results_section["batch_queue"],
609
+ generation_section["vocal_language"],
610
+ generation_section["inference_steps"],
611
+ ],
612
+ outputs=[
613
+ results_section[f"lrc_display_{btn_idx}"],
614
+ results_section[f"details_accordion_{btn_idx}"],
615
+ # NOTE: Removed generated_audio output!
616
+ # Audio subtitles are now updated via lrc_display.change() event.
617
+ results_section["batch_queue"]
618
+ ]
619
+ )
620
+
621
+ def generation_wrapper(selected_model, generation_mode, simple_query_input, simple_vocal_language, *args):
622
+ """Wrapper that selects the appropriate DiT handler based on model selection"""
623
+ # Convert args to list for modification
624
+ args_list = list(args)
625
+
626
+ # args order (after simple mode params):
627
+ # captions (0), lyrics (1), bpm (2), key_scale (3), time_signature (4), vocal_language (5),
628
+ # inference_steps (6), guidance_scale (7), random_seed_checkbox (8), seed (9),
629
+ # reference_audio (10), audio_duration (11), batch_size_input (12), src_audio (13),
630
+ # text2music_audio_code_string (14), repainting_start (15), repainting_end (16),
631
+ # instruction_display_gen (17), audio_cover_strength (18), task_type (19), ...
632
+ # ... lm_temperature (27), think_checkbox (28), ...
633
+ # ... instrumental_checkbox (at position after all regular params)
634
+
635
+ src_audio = args_list[13] if len(args_list) > 13 else None
636
+ task_type = args_list[19] if len(args_list) > 19 else "text2music"
637
+
638
+ # Validate: Cover and Repaint modes require source audio
639
+ if task_type in ["cover", "repaint"] and src_audio is None:
640
+ raise gr.Error(f"Source Audio is required for {task_type.capitalize()} mode. Please upload an audio file.")
641
+
642
+ # Handle Simple mode: first create sample, then generate
643
+ if generation_mode == "simple":
644
+ # Get instrumental from the main checkbox (args[-6] based on input order)
645
+ # The instrumental_checkbox is passed after all the regular generation params
646
+ instrumental = args_list[-6] if len(args_list) > 6 else False # instrumental_checkbox position
647
+ lm_temperature = args_list[27] if len(args_list) > 27 else 0.85
648
+ lm_top_k = args_list[30] if len(args_list) > 30 else 0
649
+ lm_top_p = args_list[31] if len(args_list) > 31 else 0.9
650
+ constrained_decoding_debug = args_list[38] if len(args_list) > 38 else False
651
+
652
+ # Call create_sample to generate caption/lyrics/metadata
653
+ from acestep.inference import create_sample
654
+
655
+ top_k_value = None if not lm_top_k or lm_top_k == 0 else int(lm_top_k)
656
+ top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p
657
+
658
+ result = create_sample(
659
+ llm_handler=llm_handler,
660
+ query=simple_query_input,
661
+ instrumental=instrumental,
662
+ vocal_language=simple_vocal_language,
663
+ temperature=lm_temperature,
664
+ top_k=top_k_value,
665
+ top_p=top_p_value,
666
+ use_constrained_decoding=True,
667
+ constrained_decoding_debug=constrained_decoding_debug,
668
+ )
669
+
670
+ if not result.success:
671
+ raise gr.Error(f"Failed to create sample: {result.status_message}")
672
+
673
+ # Update args with generated data
674
+ args_list[0] = result.caption # captions
675
+ args_list[1] = result.lyrics # lyrics
676
+ args_list[2] = result.bpm # bpm
677
+ args_list[3] = result.keyscale # key_scale
678
+ args_list[4] = result.timesignature # time_signature
679
+ args_list[5] = result.language # vocal_language
680
+ if result.duration and result.duration > 0:
681
+ args_list[11] = result.duration # audio_duration
682
+ # Enable thinking for Simple mode
683
+ args_list[28] = True # think_checkbox
684
+ # Mark as formatted caption (LM-generated sample)
685
+ args_list[36] = True # is_format_caption_state
686
+
687
+ # Determine which handler to use
688
+ active_handler = dit_handler # Default to primary handler
689
+ if dit_handler_2 is not None and selected_model == config_path_2:
690
+ active_handler = dit_handler_2
691
+ yield from res_h.generate_with_batch_management(active_handler, llm_handler, *args_list)
692
+
693
+ # ========== Generation Handler ==========
694
+ generation_section["generate_btn"].click(
695
+ fn=generation_wrapper,
696
+ inputs=[
697
+ generation_section["dit_model_selector"], # Model selection input
698
+ generation_section["generation_mode"], # For Simple mode detection
699
+ generation_section["simple_query_input"], # Simple mode query
700
+ generation_section["simple_vocal_language"], # Simple mode vocal language
701
+ generation_section["captions"],
702
+ generation_section["lyrics"],
703
+ generation_section["bpm"],
704
+ generation_section["key_scale"],
705
+ generation_section["time_signature"],
706
+ generation_section["vocal_language"],
707
+ generation_section["inference_steps"],
708
+ generation_section["guidance_scale"],
709
+ generation_section["random_seed_checkbox"],
710
+ generation_section["seed"],
711
+ generation_section["reference_audio"],
712
+ generation_section["audio_duration"],
713
+ generation_section["batch_size_input"],
714
+ generation_section["src_audio"],
715
+ generation_section["text2music_audio_code_string"],
716
+ generation_section["repainting_start"],
717
+ generation_section["repainting_end"],
718
+ generation_section["instruction_display_gen"],
719
+ generation_section["audio_cover_strength"],
720
+ generation_section["task_type"],
721
+ generation_section["use_adg"],
722
+ generation_section["cfg_interval_start"],
723
+ generation_section["cfg_interval_end"],
724
+ generation_section["shift"],
725
+ generation_section["infer_method"],
726
+ generation_section["custom_timesteps"],
727
+ generation_section["audio_format"],
728
+ generation_section["lm_temperature"],
729
+ generation_section["think_checkbox"],
730
+ generation_section["lm_cfg_scale"],
731
+ generation_section["lm_top_k"],
732
+ generation_section["lm_top_p"],
733
+ generation_section["lm_negative_prompt"],
734
+ generation_section["use_cot_metas"],
735
+ generation_section["use_cot_caption"],
736
+ generation_section["use_cot_language"],
737
+ results_section["is_format_caption_state"],
738
+ generation_section["constrained_decoding_debug"],
739
+ generation_section["allow_lm_batch"],
740
+ generation_section["auto_score"],
741
+ generation_section["auto_lrc"],
742
+ generation_section["score_scale"],
743
+ generation_section["lm_batch_chunk_size"],
744
+ generation_section["track_name"],
745
+ generation_section["complete_track_classes"],
746
+ generation_section["autogen_checkbox"],
747
+ results_section["current_batch_index"],
748
+ results_section["total_batches"],
749
+ results_section["batch_queue"],
750
+ results_section["generation_params_state"],
751
+ ],
752
+ outputs=[
753
+ results_section["generated_audio_1"],
754
+ results_section["generated_audio_2"],
755
+ results_section["generated_audio_3"],
756
+ results_section["generated_audio_4"],
757
+ results_section["generated_audio_5"],
758
+ results_section["generated_audio_6"],
759
+ results_section["generated_audio_7"],
760
+ results_section["generated_audio_8"],
761
+ results_section["generated_audio_batch"],
762
+ results_section["generation_info"],
763
+ results_section["status_output"],
764
+ generation_section["seed"],
765
+ results_section["score_display_1"],
766
+ results_section["score_display_2"],
767
+ results_section["score_display_3"],
768
+ results_section["score_display_4"],
769
+ results_section["score_display_5"],
770
+ results_section["score_display_6"],
771
+ results_section["score_display_7"],
772
+ results_section["score_display_8"],
773
+ results_section["codes_display_1"],
774
+ results_section["codes_display_2"],
775
+ results_section["codes_display_3"],
776
+ results_section["codes_display_4"],
777
+ results_section["codes_display_5"],
778
+ results_section["codes_display_6"],
779
+ results_section["codes_display_7"],
780
+ results_section["codes_display_8"],
781
+ results_section["details_accordion_1"],
782
+ results_section["details_accordion_2"],
783
+ results_section["details_accordion_3"],
784
+ results_section["details_accordion_4"],
785
+ results_section["details_accordion_5"],
786
+ results_section["details_accordion_6"],
787
+ results_section["details_accordion_7"],
788
+ results_section["details_accordion_8"],
789
+ results_section["lrc_display_1"],
790
+ results_section["lrc_display_2"],
791
+ results_section["lrc_display_3"],
792
+ results_section["lrc_display_4"],
793
+ results_section["lrc_display_5"],
794
+ results_section["lrc_display_6"],
795
+ results_section["lrc_display_7"],
796
+ results_section["lrc_display_8"],
797
+ results_section["lm_metadata_state"],
798
+ results_section["is_format_caption_state"],
799
+ results_section["current_batch_index"],
800
+ results_section["total_batches"],
801
+ results_section["batch_queue"],
802
+ results_section["generation_params_state"],
803
+ results_section["batch_indicator"],
804
+ results_section["prev_batch_btn"],
805
+ results_section["next_batch_btn"],
806
+ results_section["next_batch_status"],
807
+ results_section["restore_params_btn"],
808
+ ]
809
+ ).then(
810
+ fn=lambda selected_model, *args: res_h.generate_next_batch_background(
811
+ dit_handler_2 if (dit_handler_2 is not None and selected_model == config_path_2) else dit_handler,
812
+ llm_handler, *args
813
+ ),
814
+ inputs=[
815
+ generation_section["dit_model_selector"], # Model selection input
816
+ generation_section["autogen_checkbox"],
817
+ results_section["generation_params_state"],
818
+ results_section["current_batch_index"],
819
+ results_section["total_batches"],
820
+ results_section["batch_queue"],
821
+ results_section["is_format_caption_state"],
822
+ ],
823
+ outputs=[
824
+ results_section["batch_queue"],
825
+ results_section["total_batches"],
826
+ results_section["next_batch_status"],
827
+ results_section["next_batch_btn"],
828
+ ]
829
+ )
830
+
831
+ # ========== Batch Navigation Handlers ==========
832
+ results_section["prev_batch_btn"].click(
833
+ fn=res_h.navigate_to_previous_batch,
834
+ inputs=[
835
+ results_section["current_batch_index"],
836
+ results_section["batch_queue"],
837
+ ],
838
+ outputs=[
839
+ results_section["generated_audio_1"],
840
+ results_section["generated_audio_2"],
841
+ results_section["generated_audio_3"],
842
+ results_section["generated_audio_4"],
843
+ results_section["generated_audio_5"],
844
+ results_section["generated_audio_6"],
845
+ results_section["generated_audio_7"],
846
+ results_section["generated_audio_8"],
847
+ results_section["generated_audio_batch"],
848
+ results_section["generation_info"],
849
+ results_section["current_batch_index"],
850
+ results_section["batch_indicator"],
851
+ results_section["prev_batch_btn"],
852
+ results_section["next_batch_btn"],
853
+ results_section["status_output"],
854
+ results_section["score_display_1"],
855
+ results_section["score_display_2"],
856
+ results_section["score_display_3"],
857
+ results_section["score_display_4"],
858
+ results_section["score_display_5"],
859
+ results_section["score_display_6"],
860
+ results_section["score_display_7"],
861
+ results_section["score_display_8"],
862
+ results_section["codes_display_1"],
863
+ results_section["codes_display_2"],
864
+ results_section["codes_display_3"],
865
+ results_section["codes_display_4"],
866
+ results_section["codes_display_5"],
867
+ results_section["codes_display_6"],
868
+ results_section["codes_display_7"],
869
+ results_section["codes_display_8"],
870
+ results_section["lrc_display_1"],
871
+ results_section["lrc_display_2"],
872
+ results_section["lrc_display_3"],
873
+ results_section["lrc_display_4"],
874
+ results_section["lrc_display_5"],
875
+ results_section["lrc_display_6"],
876
+ results_section["lrc_display_7"],
877
+ results_section["lrc_display_8"],
878
+ results_section["details_accordion_1"],
879
+ results_section["details_accordion_2"],
880
+ results_section["details_accordion_3"],
881
+ results_section["details_accordion_4"],
882
+ results_section["details_accordion_5"],
883
+ results_section["details_accordion_6"],
884
+ results_section["details_accordion_7"],
885
+ results_section["details_accordion_8"],
886
+ results_section["restore_params_btn"],
887
+ ]
888
+ )
889
+
890
+ results_section["next_batch_btn"].click(
891
+ fn=res_h.capture_current_params,
892
+ inputs=[
893
+ generation_section["captions"],
894
+ generation_section["lyrics"],
895
+ generation_section["bpm"],
896
+ generation_section["key_scale"],
897
+ generation_section["time_signature"],
898
+ generation_section["vocal_language"],
899
+ generation_section["inference_steps"],
900
+ generation_section["guidance_scale"],
901
+ generation_section["random_seed_checkbox"],
902
+ generation_section["seed"],
903
+ generation_section["reference_audio"],
904
+ generation_section["audio_duration"],
905
+ generation_section["batch_size_input"],
906
+ generation_section["src_audio"],
907
+ generation_section["text2music_audio_code_string"],
908
+ generation_section["repainting_start"],
909
+ generation_section["repainting_end"],
910
+ generation_section["instruction_display_gen"],
911
+ generation_section["audio_cover_strength"],
912
+ generation_section["task_type"],
913
+ generation_section["use_adg"],
914
+ generation_section["cfg_interval_start"],
915
+ generation_section["cfg_interval_end"],
916
+ generation_section["shift"],
917
+ generation_section["infer_method"],
918
+ generation_section["custom_timesteps"],
919
+ generation_section["audio_format"],
920
+ generation_section["lm_temperature"],
921
+ generation_section["think_checkbox"],
922
+ generation_section["lm_cfg_scale"],
923
+ generation_section["lm_top_k"],
924
+ generation_section["lm_top_p"],
925
+ generation_section["lm_negative_prompt"],
926
+ generation_section["use_cot_metas"],
927
+ generation_section["use_cot_caption"],
928
+ generation_section["use_cot_language"],
929
+ generation_section["constrained_decoding_debug"],
930
+ generation_section["allow_lm_batch"],
931
+ generation_section["auto_score"],
932
+ generation_section["auto_lrc"],
933
+ generation_section["score_scale"],
934
+ generation_section["lm_batch_chunk_size"],
935
+ generation_section["track_name"],
936
+ generation_section["complete_track_classes"],
937
+ ],
938
+ outputs=[results_section["generation_params_state"]]
939
+ ).then(
940
+ fn=res_h.navigate_to_next_batch,
941
+ inputs=[
942
+ generation_section["autogen_checkbox"],
943
+ results_section["current_batch_index"],
944
+ results_section["total_batches"],
945
+ results_section["batch_queue"],
946
+ ],
947
+ outputs=[
948
+ results_section["generated_audio_1"],
949
+ results_section["generated_audio_2"],
950
+ results_section["generated_audio_3"],
951
+ results_section["generated_audio_4"],
952
+ results_section["generated_audio_5"],
953
+ results_section["generated_audio_6"],
954
+ results_section["generated_audio_7"],
955
+ results_section["generated_audio_8"],
956
+ results_section["generated_audio_batch"],
957
+ results_section["generation_info"],
958
+ results_section["current_batch_index"],
959
+ results_section["batch_indicator"],
960
+ results_section["prev_batch_btn"],
961
+ results_section["next_batch_btn"],
962
+ results_section["status_output"],
963
+ results_section["next_batch_status"],
964
+ results_section["score_display_1"],
965
+ results_section["score_display_2"],
966
+ results_section["score_display_3"],
967
+ results_section["score_display_4"],
968
+ results_section["score_display_5"],
969
+ results_section["score_display_6"],
970
+ results_section["score_display_7"],
971
+ results_section["score_display_8"],
972
+ results_section["codes_display_1"],
973
+ results_section["codes_display_2"],
974
+ results_section["codes_display_3"],
975
+ results_section["codes_display_4"],
976
+ results_section["codes_display_5"],
977
+ results_section["codes_display_6"],
978
+ results_section["codes_display_7"],
979
+ results_section["codes_display_8"],
980
+ results_section["lrc_display_1"],
981
+ results_section["lrc_display_2"],
982
+ results_section["lrc_display_3"],
983
+ results_section["lrc_display_4"],
984
+ results_section["lrc_display_5"],
985
+ results_section["lrc_display_6"],
986
+ results_section["lrc_display_7"],
987
+ results_section["lrc_display_8"],
988
+ results_section["details_accordion_1"],
989
+ results_section["details_accordion_2"],
990
+ results_section["details_accordion_3"],
991
+ results_section["details_accordion_4"],
992
+ results_section["details_accordion_5"],
993
+ results_section["details_accordion_6"],
994
+ results_section["details_accordion_7"],
995
+ results_section["details_accordion_8"],
996
+ results_section["restore_params_btn"],
997
+ ]
998
+ ).then(
999
+ fn=lambda selected_model, *args: res_h.generate_next_batch_background(
1000
+ dit_handler_2 if (dit_handler_2 is not None and selected_model == config_path_2) else dit_handler,
1001
+ llm_handler, *args
1002
+ ),
1003
+ inputs=[
1004
+ generation_section["dit_model_selector"], # Model selection input
1005
+ generation_section["autogen_checkbox"],
1006
+ results_section["generation_params_state"],
1007
+ results_section["current_batch_index"],
1008
+ results_section["total_batches"],
1009
+ results_section["batch_queue"],
1010
+ results_section["is_format_caption_state"],
1011
+ ],
1012
+ outputs=[
1013
+ results_section["batch_queue"],
1014
+ results_section["total_batches"],
1015
+ results_section["next_batch_status"],
1016
+ results_section["next_batch_btn"],
1017
+ ]
1018
+ )
1019
+
1020
+ # ========== Restore Parameters Handler ==========
1021
+ results_section["restore_params_btn"].click(
1022
+ fn=res_h.restore_batch_parameters,
1023
+ inputs=[
1024
+ results_section["current_batch_index"],
1025
+ results_section["batch_queue"]
1026
+ ],
1027
+ outputs=[
1028
+ generation_section["text2music_audio_code_string"],
1029
+ generation_section["captions"],
1030
+ generation_section["lyrics"],
1031
+ generation_section["bpm"],
1032
+ generation_section["key_scale"],
1033
+ generation_section["time_signature"],
1034
+ generation_section["vocal_language"],
1035
+ generation_section["audio_duration"],
1036
+ generation_section["batch_size_input"],
1037
+ generation_section["inference_steps"],
1038
+ generation_section["lm_temperature"],
1039
+ generation_section["lm_cfg_scale"],
1040
+ generation_section["lm_top_k"],
1041
+ generation_section["lm_top_p"],
1042
+ generation_section["think_checkbox"],
1043
+ generation_section["use_cot_caption"],
1044
+ generation_section["use_cot_language"],
1045
+ generation_section["allow_lm_batch"],
1046
+ generation_section["track_name"],
1047
+ generation_section["complete_track_classes"],
1048
+ ]
1049
+ )
1050
+
1051
+ # ========== LRC Display Change Handlers ==========
1052
+ # NEW APPROACH: Use lrc_display.change() to update audio subtitles
1053
+ # This decouples audio value updates from subtitle updates, avoiding flickering.
1054
+ #
1055
+ # When lrc_display text changes (from generate, LRC button, or manual edit):
1056
+ # 1. lrc_display.change() is triggered
1057
+ # 2. update_audio_subtitles_from_lrc() parses LRC and updates audio subtitles
1058
+ # 3. Audio value is NEVER updated here - only subtitles
1059
+ for lrc_idx in range(1, 9):
1060
+ results_section[f"lrc_display_{lrc_idx}"].change(
1061
+ fn=res_h.update_audio_subtitles_from_lrc,
1062
+ inputs=[
1063
+ results_section[f"lrc_display_{lrc_idx}"],
1064
+ # audio_duration not needed - parse_lrc_to_subtitles calculates end time from timestamps
1065
+ ],
1066
+ outputs=[
1067
+ results_section[f"generated_audio_{lrc_idx}"], # Only updates subtitles, not value
1068
+ ]
1069
+ )
1070
+
1071
+
1072
+ def setup_training_event_handlers(demo, dit_handler, llm_handler, training_section):
1073
+ """Setup event handlers for the training tab (dataset builder and LoRA training)"""
1074
+
1075
+ # ========== Load Existing Dataset (Top Section) ==========
1076
+
1077
+ # Load existing dataset JSON at the top of Dataset Builder
1078
+ training_section["load_json_btn"].click(
1079
+ fn=train_h.load_existing_dataset_for_preprocess,
1080
+ inputs=[
1081
+ training_section["load_json_path"],
1082
+ training_section["dataset_builder_state"],
1083
+ ],
1084
+ outputs=[
1085
+ training_section["load_json_status"],
1086
+ training_section["audio_files_table"],
1087
+ training_section["sample_selector"],
1088
+ training_section["dataset_builder_state"],
1089
+ # Also update preview fields with first sample
1090
+ training_section["preview_audio"],
1091
+ training_section["preview_filename"],
1092
+ training_section["edit_caption"],
1093
+ training_section["edit_lyrics"],
1094
+ training_section["edit_bpm"],
1095
+ training_section["edit_keyscale"],
1096
+ training_section["edit_timesig"],
1097
+ training_section["edit_duration"],
1098
+ training_section["edit_language"],
1099
+ training_section["edit_instrumental"],
1100
+ ]
1101
+ )
1102
+
1103
+ # ========== Dataset Builder Handlers ==========
1104
+
1105
+ # Scan directory for audio files
1106
+ training_section["scan_btn"].click(
1107
+ fn=lambda dir, name, tag, pos, instr, state: train_h.scan_directory(
1108
+ dir, name, tag, pos, instr, state
1109
+ ),
1110
+ inputs=[
1111
+ training_section["audio_directory"],
1112
+ training_section["dataset_name"],
1113
+ training_section["custom_tag"],
1114
+ training_section["tag_position"],
1115
+ training_section["all_instrumental"],
1116
+ training_section["dataset_builder_state"],
1117
+ ],
1118
+ outputs=[
1119
+ training_section["audio_files_table"],
1120
+ training_section["scan_status"],
1121
+ training_section["sample_selector"],
1122
+ training_section["dataset_builder_state"],
1123
+ ]
1124
+ )
1125
+
1126
+ # Auto-label all samples
1127
+ training_section["auto_label_btn"].click(
1128
+ fn=lambda state, skip: train_h.auto_label_all(dit_handler, llm_handler, state, skip),
1129
+ inputs=[
1130
+ training_section["dataset_builder_state"],
1131
+ training_section["skip_metas"],
1132
+ ],
1133
+ outputs=[
1134
+ training_section["audio_files_table"],
1135
+ training_section["label_progress"],
1136
+ training_section["dataset_builder_state"],
1137
+ ]
1138
+ )
1139
+
1140
+ # Sample selector change - update preview
1141
+ training_section["sample_selector"].change(
1142
+ fn=train_h.get_sample_preview,
1143
+ inputs=[
1144
+ training_section["sample_selector"],
1145
+ training_section["dataset_builder_state"],
1146
+ ],
1147
+ outputs=[
1148
+ training_section["preview_audio"],
1149
+ training_section["preview_filename"],
1150
+ training_section["edit_caption"],
1151
+ training_section["edit_lyrics"],
1152
+ training_section["edit_bpm"],
1153
+ training_section["edit_keyscale"],
1154
+ training_section["edit_timesig"],
1155
+ training_section["edit_duration"],
1156
+ training_section["edit_language"],
1157
+ training_section["edit_instrumental"],
1158
+ ]
1159
+ )
1160
+
1161
+ # Save sample edit
1162
+ training_section["save_edit_btn"].click(
1163
+ fn=train_h.save_sample_edit,
1164
+ inputs=[
1165
+ training_section["sample_selector"],
1166
+ training_section["edit_caption"],
1167
+ training_section["edit_lyrics"],
1168
+ training_section["edit_bpm"],
1169
+ training_section["edit_keyscale"],
1170
+ training_section["edit_timesig"],
1171
+ training_section["edit_language"],
1172
+ training_section["edit_instrumental"],
1173
+ training_section["dataset_builder_state"],
1174
+ ],
1175
+ outputs=[
1176
+ training_section["audio_files_table"],
1177
+ training_section["edit_status"],
1178
+ training_section["dataset_builder_state"],
1179
+ ]
1180
+ )
1181
+
1182
+ # Update settings when changed
1183
+ for trigger in [training_section["custom_tag"], training_section["tag_position"], training_section["all_instrumental"]]:
1184
+ trigger.change(
1185
+ fn=train_h.update_settings,
1186
+ inputs=[
1187
+ training_section["custom_tag"],
1188
+ training_section["tag_position"],
1189
+ training_section["all_instrumental"],
1190
+ training_section["dataset_builder_state"],
1191
+ ],
1192
+ outputs=[training_section["dataset_builder_state"]]
1193
+ )
1194
+
1195
+ # Save dataset
1196
+ training_section["save_dataset_btn"].click(
1197
+ fn=train_h.save_dataset,
1198
+ inputs=[
1199
+ training_section["save_path"],
1200
+ training_section["dataset_name"],
1201
+ training_section["dataset_builder_state"],
1202
+ ],
1203
+ outputs=[training_section["save_status"]]
1204
+ )
1205
+
1206
+ # ========== Preprocess Handlers ==========
1207
+
1208
+ # Load existing dataset JSON for preprocessing
1209
+ # This also updates the preview section so users can view/edit samples
1210
+ training_section["load_existing_dataset_btn"].click(
1211
+ fn=train_h.load_existing_dataset_for_preprocess,
1212
+ inputs=[
1213
+ training_section["load_existing_dataset_path"],
1214
+ training_section["dataset_builder_state"],
1215
+ ],
1216
+ outputs=[
1217
+ training_section["load_existing_status"],
1218
+ training_section["audio_files_table"],
1219
+ training_section["sample_selector"],
1220
+ training_section["dataset_builder_state"],
1221
+ # Also update preview fields with first sample
1222
+ training_section["preview_audio"],
1223
+ training_section["preview_filename"],
1224
+ training_section["edit_caption"],
1225
+ training_section["edit_lyrics"],
1226
+ training_section["edit_bpm"],
1227
+ training_section["edit_keyscale"],
1228
+ training_section["edit_timesig"],
1229
+ training_section["edit_duration"],
1230
+ training_section["edit_language"],
1231
+ training_section["edit_instrumental"],
1232
+ ]
1233
+ )
1234
+
1235
+ # Preprocess dataset to tensor files
1236
+ training_section["preprocess_btn"].click(
1237
+ fn=lambda output_dir, state: train_h.preprocess_dataset(
1238
+ output_dir, dit_handler, state
1239
+ ),
1240
+ inputs=[
1241
+ training_section["preprocess_output_dir"],
1242
+ training_section["dataset_builder_state"],
1243
+ ],
1244
+ outputs=[training_section["preprocess_progress"]]
1245
+ )
1246
+
1247
+ # ========== Training Tab Handlers ==========
1248
+
1249
+ # Load preprocessed tensor dataset
1250
+ training_section["load_dataset_btn"].click(
1251
+ fn=train_h.load_training_dataset,
1252
+ inputs=[training_section["training_tensor_dir"]],
1253
+ outputs=[training_section["training_dataset_info"]]
1254
+ )
1255
+
1256
+ # Start training from preprocessed tensors
1257
+ def training_wrapper(tensor_dir, r, a, d, lr, ep, bs, ga, se, sh, sd, od, ts):
1258
+ try:
1259
+ for progress, log, plot, state in train_h.start_training(
1260
+ tensor_dir, dit_handler, r, a, d, lr, ep, bs, ga, se, sh, sd, od, ts
1261
+ ):
1262
+ yield progress, log, plot, state
1263
+ except Exception as e:
1264
+ logger.exception("Training wrapper error")
1265
+ yield f"❌ Error: {str(e)}", str(e), None, ts
1266
+
1267
+ training_section["start_training_btn"].click(
1268
+ fn=training_wrapper,
1269
+ inputs=[
1270
+ training_section["training_tensor_dir"],
1271
+ training_section["lora_rank"],
1272
+ training_section["lora_alpha"],
1273
+ training_section["lora_dropout"],
1274
+ training_section["learning_rate"],
1275
+ training_section["train_epochs"],
1276
+ training_section["train_batch_size"],
1277
+ training_section["gradient_accumulation"],
1278
+ training_section["save_every_n_epochs"],
1279
+ training_section["training_shift"],
1280
+ training_section["training_seed"],
1281
+ training_section["lora_output_dir"],
1282
+ training_section["training_state"],
1283
+ ],
1284
+ outputs=[
1285
+ training_section["training_progress"],
1286
+ training_section["training_log"],
1287
+ training_section["training_loss_plot"],
1288
+ training_section["training_state"],
1289
+ ]
1290
+ )
1291
+
1292
+ # Stop training
1293
+ training_section["stop_training_btn"].click(
1294
+ fn=train_h.stop_training,
1295
+ inputs=[training_section["training_state"]],
1296
+ outputs=[
1297
+ training_section["training_progress"],
1298
+ training_section["training_state"],
1299
+ ]
1300
+ )
1301
+
1302
+ # Export LoRA
1303
+ training_section["export_lora_btn"].click(
1304
+ fn=train_h.export_lora,
1305
+ inputs=[
1306
+ training_section["export_path"],
1307
+ training_section["lora_output_dir"],
1308
+ ],
1309
+ outputs=[training_section["export_status"]]
1310
+ )
spaces/Ace-Step-v1.5/acestep/gradio_ui/events/generation_handlers.py ADDED
@@ -0,0 +1,1054 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generation Input Handlers Module
3
+ Contains event handlers and helper functions related to generation inputs
4
+ """
5
+ import os
6
+ import json
7
+ import random
8
+ import glob
9
+ import gradio as gr
10
+ from typing import Optional, List, Tuple
11
+ from acestep.constants import (
12
+ TASK_TYPES_TURBO,
13
+ TASK_TYPES_BASE,
14
+ )
15
+ from acestep.gradio_ui.i18n import t
16
+ from acestep.inference import understand_music, create_sample, format_sample
17
+
18
+
19
+ def parse_and_validate_timesteps(
20
+ timesteps_str: str,
21
+ inference_steps: int
22
+ ) -> Tuple[Optional[List[float]], bool, str]:
23
+ """
24
+ Parse timesteps string and validate.
25
+
26
+ Args:
27
+ timesteps_str: Comma-separated timesteps string (e.g., "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
28
+ inference_steps: Expected number of inference steps
29
+
30
+ Returns:
31
+ Tuple of (parsed_timesteps, has_warning, warning_message)
32
+ - parsed_timesteps: List of float timesteps, or None if invalid/empty
33
+ - has_warning: Whether a warning was shown
34
+ - warning_message: Description of the warning
35
+ """
36
+ if not timesteps_str or not timesteps_str.strip():
37
+ return None, False, ""
38
+
39
+ # Parse comma-separated values
40
+ values = [v.strip() for v in timesteps_str.split(",") if v.strip()]
41
+
42
+ if not values:
43
+ return None, False, ""
44
+
45
+ # Handle optional trailing 0
46
+ if values[-1] != "0":
47
+ values.append("0")
48
+
49
+ try:
50
+ timesteps = [float(v) for v in values]
51
+ except ValueError:
52
+ gr.Warning(t("messages.invalid_timesteps_format"))
53
+ return None, True, "Invalid format"
54
+
55
+ # Validate range [0, 1]
56
+ if any(ts < 0 or ts > 1 for ts in timesteps):
57
+ gr.Warning(t("messages.timesteps_out_of_range"))
58
+ return None, True, "Out of range"
59
+
60
+ # Check if count matches inference_steps
61
+ actual_steps = len(timesteps) - 1
62
+ if actual_steps != inference_steps:
63
+ gr.Warning(t("messages.timesteps_count_mismatch", actual=actual_steps, expected=inference_steps))
64
+ return timesteps, True, f"Using {actual_steps} steps from timesteps"
65
+
66
+ return timesteps, False, ""
67
+
68
+
69
+ def load_metadata(file_obj):
70
+ """Load generation parameters from a JSON file"""
71
+ if file_obj is None:
72
+ gr.Warning(t("messages.no_file_selected"))
73
+ return [None] * 36 + [False] # Return None for all fields, False for is_format_caption
74
+
75
+ try:
76
+ # Read the uploaded file
77
+ if hasattr(file_obj, 'name'):
78
+ filepath = file_obj.name
79
+ else:
80
+ filepath = file_obj
81
+
82
+ with open(filepath, 'r', encoding='utf-8') as f:
83
+ metadata = json.load(f)
84
+
85
+ # Extract all fields
86
+ task_type = metadata.get('task_type', 'text2music')
87
+ captions = metadata.get('caption', '')
88
+ lyrics = metadata.get('lyrics', '')
89
+ vocal_language = metadata.get('vocal_language', 'unknown')
90
+
91
+ # Convert bpm
92
+ bpm_value = metadata.get('bpm')
93
+ if bpm_value is not None and bpm_value != "N/A":
94
+ try:
95
+ bpm = int(bpm_value) if bpm_value else None
96
+ except:
97
+ bpm = None
98
+ else:
99
+ bpm = None
100
+
101
+ key_scale = metadata.get('keyscale', '')
102
+ time_signature = metadata.get('timesignature', '')
103
+
104
+ # Convert duration
105
+ duration_value = metadata.get('duration', -1)
106
+ if duration_value is not None and duration_value != "N/A":
107
+ try:
108
+ audio_duration = float(duration_value)
109
+ except:
110
+ audio_duration = -1
111
+ else:
112
+ audio_duration = -1
113
+
114
+ batch_size = metadata.get('batch_size', 2)
115
+ inference_steps = metadata.get('inference_steps', 8)
116
+ guidance_scale = metadata.get('guidance_scale', 7.0)
117
+ seed = metadata.get('seed', '-1')
118
+ random_seed = False # Always set to False when loading to enable reproducibility with saved seed
119
+ use_adg = metadata.get('use_adg', False)
120
+ cfg_interval_start = metadata.get('cfg_interval_start', 0.0)
121
+ cfg_interval_end = metadata.get('cfg_interval_end', 1.0)
122
+ audio_format = metadata.get('audio_format', 'mp3')
123
+ lm_temperature = metadata.get('lm_temperature', 0.85)
124
+ lm_cfg_scale = metadata.get('lm_cfg_scale', 2.0)
125
+ lm_top_k = metadata.get('lm_top_k', 0)
126
+ lm_top_p = metadata.get('lm_top_p', 0.9)
127
+ lm_negative_prompt = metadata.get('lm_negative_prompt', 'NO USER INPUT')
128
+ use_cot_metas = metadata.get('use_cot_metas', True) # Added: read use_cot_metas
129
+ use_cot_caption = metadata.get('use_cot_caption', True)
130
+ use_cot_language = metadata.get('use_cot_language', True)
131
+ audio_cover_strength = metadata.get('audio_cover_strength', 1.0)
132
+ think = metadata.get('thinking', True) # Fixed: read 'thinking' not 'think'
133
+ audio_codes = metadata.get('audio_codes', '')
134
+ repainting_start = metadata.get('repainting_start', 0.0)
135
+ repainting_end = metadata.get('repainting_end', -1)
136
+ track_name = metadata.get('track_name')
137
+ complete_track_classes = metadata.get('complete_track_classes', [])
138
+ shift = metadata.get('shift', 3.0) # Default 3.0 for base models
139
+ infer_method = metadata.get('infer_method', 'ode') # Default 'ode' for diffusion inference
140
+ custom_timesteps = metadata.get('timesteps', '') # Custom timesteps (stored as 'timesteps' in JSON)
141
+ if custom_timesteps is None:
142
+ custom_timesteps = ''
143
+ instrumental = metadata.get('instrumental', False) # Added: read instrumental
144
+
145
+ gr.Info(t("messages.params_loaded", filename=os.path.basename(filepath)))
146
+
147
+ return (
148
+ task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature,
149
+ audio_duration, batch_size, inference_steps, guidance_scale, seed, random_seed,
150
+ use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method,
151
+ custom_timesteps, # Added: custom_timesteps (between infer_method and audio_format)
152
+ audio_format, lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
153
+ use_cot_metas, use_cot_caption, use_cot_language, audio_cover_strength,
154
+ think, audio_codes, repainting_start, repainting_end,
155
+ track_name, complete_track_classes, instrumental,
156
+ True # Set is_format_caption to True when loading from file
157
+ )
158
+
159
+ except json.JSONDecodeError as e:
160
+ gr.Warning(t("messages.invalid_json", error=str(e)))
161
+ return [None] * 36 + [False]
162
+ except Exception as e:
163
+ gr.Warning(t("messages.load_error", error=str(e)))
164
+ return [None] * 36 + [False]
165
+
166
+
167
+ def load_random_example(task_type: str):
168
+ """Load a random example from the task-specific examples directory
169
+
170
+ Args:
171
+ task_type: The task type (e.g., "text2music")
172
+
173
+ Returns:
174
+ Tuple of (caption, lyrics, think, bpm, duration, keyscale, language, timesignature) for updating UI components
175
+ """
176
+ try:
177
+ # Get the project root directory
178
+ current_file = os.path.abspath(__file__)
179
+ # This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
180
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
181
+
182
+ # Construct the examples directory path
183
+ examples_dir = os.path.join(project_root, "examples", task_type)
184
+
185
+ # Check if directory exists
186
+ if not os.path.exists(examples_dir):
187
+ gr.Warning(f"Examples directory not found: examples/{task_type}/")
188
+ return "", "", True, None, None, "", "", ""
189
+
190
+ # Find all JSON files in the directory
191
+ json_files = glob.glob(os.path.join(examples_dir, "*.json"))
192
+
193
+ if not json_files:
194
+ gr.Warning(f"No JSON files found in examples/{task_type}/")
195
+ return "", "", True, None, None, "", "", ""
196
+
197
+ # Randomly select one file
198
+ selected_file = random.choice(json_files)
199
+
200
+ # Read and parse JSON
201
+ try:
202
+ with open(selected_file, 'r', encoding='utf-8') as f:
203
+ data = json.load(f)
204
+
205
+ # Extract caption (prefer 'caption', fallback to 'prompt')
206
+ caption_value = data.get('caption', data.get('prompt', ''))
207
+ if not isinstance(caption_value, str):
208
+ caption_value = str(caption_value) if caption_value else ''
209
+
210
+ # Extract lyrics
211
+ lyrics_value = data.get('lyrics', '')
212
+ if not isinstance(lyrics_value, str):
213
+ lyrics_value = str(lyrics_value) if lyrics_value else ''
214
+
215
+ # Extract think (default to True if not present)
216
+ think_value = data.get('think', True)
217
+ if not isinstance(think_value, bool):
218
+ think_value = True
219
+
220
+ # Extract optional metadata fields
221
+ bpm_value = None
222
+ if 'bpm' in data and data['bpm'] not in [None, "N/A", ""]:
223
+ try:
224
+ bpm_value = int(data['bpm'])
225
+ except (ValueError, TypeError):
226
+ pass
227
+
228
+ duration_value = None
229
+ if 'duration' in data and data['duration'] not in [None, "N/A", ""]:
230
+ try:
231
+ duration_value = float(data['duration'])
232
+ except (ValueError, TypeError):
233
+ pass
234
+
235
+ keyscale_value = data.get('keyscale', '')
236
+ if keyscale_value in [None, "N/A"]:
237
+ keyscale_value = ''
238
+
239
+ language_value = data.get('language', '')
240
+ if language_value in [None, "N/A"]:
241
+ language_value = ''
242
+
243
+ timesignature_value = data.get('timesignature', '')
244
+ if timesignature_value in [None, "N/A"]:
245
+ timesignature_value = ''
246
+
247
+ gr.Info(t("messages.example_loaded", filename=os.path.basename(selected_file)))
248
+ return caption_value, lyrics_value, think_value, bpm_value, duration_value, keyscale_value, language_value, timesignature_value
249
+
250
+ except json.JSONDecodeError as e:
251
+ gr.Warning(t("messages.example_failed", filename=os.path.basename(selected_file), error=str(e)))
252
+ return "", "", True, None, None, "", "", ""
253
+ except Exception as e:
254
+ gr.Warning(t("messages.example_error", error=str(e)))
255
+ return "", "", True, None, None, "", "", ""
256
+
257
+ except Exception as e:
258
+ gr.Warning(t("messages.example_error", error=str(e)))
259
+ return "", "", True, None, None, "", "", ""
260
+
261
+
262
+ def sample_example_smart(llm_handler, task_type: str, constrained_decoding_debug: bool = False):
263
+ """Smart sample function that uses LM if initialized, otherwise falls back to examples
264
+
265
+ This is a Gradio wrapper that uses the understand_music API from acestep.inference
266
+ to generate examples when LM is available.
267
+
268
+ Args:
269
+ llm_handler: LLM handler instance
270
+ task_type: The task type (e.g., "text2music")
271
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
272
+
273
+ Returns:
274
+ Tuple of (caption, lyrics, think, bpm, duration, keyscale, language, timesignature) for updating UI components
275
+ """
276
+ # Check if LM is initialized
277
+ if llm_handler.llm_initialized:
278
+ # Use LM to generate example via understand_music API
279
+ try:
280
+ result = understand_music(
281
+ llm_handler=llm_handler,
282
+ audio_codes="NO USER INPUT", # Empty input triggers example generation
283
+ temperature=0.85,
284
+ use_constrained_decoding=True,
285
+ constrained_decoding_debug=constrained_decoding_debug,
286
+ )
287
+
288
+ if result.success:
289
+ gr.Info(t("messages.lm_generated"))
290
+ return (
291
+ result.caption,
292
+ result.lyrics,
293
+ True, # Always enable think when using LM-generated examples
294
+ result.bpm,
295
+ result.duration,
296
+ result.keyscale,
297
+ result.language,
298
+ result.timesignature,
299
+ )
300
+ else:
301
+ gr.Warning(t("messages.lm_fallback"))
302
+ return load_random_example(task_type)
303
+
304
+ except Exception as e:
305
+ gr.Warning(t("messages.lm_fallback"))
306
+ return load_random_example(task_type)
307
+ else:
308
+ # LM not initialized, use examples directory
309
+ return load_random_example(task_type)
310
+
311
+
312
+ def load_random_simple_description():
313
+ """Load a random description from the simple_mode examples directory.
314
+
315
+ Returns:
316
+ Tuple of (description, instrumental, vocal_language) for updating UI components
317
+ """
318
+ try:
319
+ # Get the project root directory
320
+ current_file = os.path.abspath(__file__)
321
+ # This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
322
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
323
+
324
+ # Construct the examples directory path
325
+ examples_dir = os.path.join(project_root, "examples", "simple_mode")
326
+
327
+ # Check if directory exists
328
+ if not os.path.exists(examples_dir):
329
+ gr.Warning(t("messages.simple_examples_not_found"))
330
+ return gr.update(), gr.update(), gr.update()
331
+
332
+ # Find all JSON files in the directory
333
+ json_files = glob.glob(os.path.join(examples_dir, "*.json"))
334
+
335
+ if not json_files:
336
+ gr.Warning(t("messages.simple_examples_empty"))
337
+ return gr.update(), gr.update(), gr.update()
338
+
339
+ # Randomly select one file
340
+ selected_file = random.choice(json_files)
341
+
342
+ # Read and parse JSON
343
+ try:
344
+ with open(selected_file, 'r', encoding='utf-8') as f:
345
+ data = json.load(f)
346
+
347
+ # Extract fields
348
+ description = data.get('description', '')
349
+ instrumental = data.get('instrumental', False)
350
+ vocal_language = data.get('vocal_language', 'unknown')
351
+
352
+ # Ensure vocal_language is a string
353
+ if isinstance(vocal_language, list):
354
+ vocal_language = vocal_language[0] if vocal_language else 'unknown'
355
+
356
+ gr.Info(t("messages.simple_example_loaded", filename=os.path.basename(selected_file)))
357
+ return description, instrumental, vocal_language
358
+
359
+ except json.JSONDecodeError as e:
360
+ gr.Warning(t("messages.example_failed", filename=os.path.basename(selected_file), error=str(e)))
361
+ return gr.update(), gr.update(), gr.update()
362
+ except Exception as e:
363
+ gr.Warning(t("messages.example_error", error=str(e)))
364
+ return gr.update(), gr.update(), gr.update()
365
+
366
+ except Exception as e:
367
+ gr.Warning(t("messages.example_error", error=str(e)))
368
+ return gr.update(), gr.update(), gr.update()
369
+
370
+
371
+ def refresh_checkpoints(dit_handler):
372
+ """Refresh available checkpoints"""
373
+ choices = dit_handler.get_available_checkpoints()
374
+ return gr.update(choices=choices)
375
+
376
+
377
+ def update_model_type_settings(config_path):
378
+ """Update UI settings based on model type (fallback when handler not initialized yet)
379
+
380
+ Note: This is used as a fallback when the user changes config_path dropdown
381
+ before initializing the model. The actual settings are determined by the
382
+ handler's is_turbo_model() method after initialization.
383
+ """
384
+ if config_path is None:
385
+ config_path = ""
386
+ config_path_lower = config_path.lower()
387
+
388
+ # Determine is_turbo based on config_path string
389
+ # This is a heuristic fallback - actual model type is determined after loading
390
+ if "turbo" in config_path_lower:
391
+ is_turbo = True
392
+ elif "base" in config_path_lower:
393
+ is_turbo = False
394
+ else:
395
+ # Default to turbo settings for unknown model types
396
+ is_turbo = True
397
+
398
+ return get_model_type_ui_settings(is_turbo)
399
+
400
+
401
+ def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
402
+ """Wrapper for service initialization, returns status, button state, accordion state, and model type settings"""
403
+ # Initialize DiT handler
404
+ status, enable = dit_handler.initialize_service(
405
+ checkpoint, config_path, device,
406
+ use_flash_attention=use_flash_attention, compile_model=False,
407
+ offload_to_cpu=offload_to_cpu, offload_dit_to_cpu=offload_dit_to_cpu
408
+ )
409
+
410
+ # Initialize LM handler if requested
411
+ if init_llm:
412
+ # Get checkpoint directory
413
+ current_file = os.path.abspath(__file__)
414
+ # This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
415
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
416
+ checkpoint_dir = os.path.join(project_root, "checkpoints")
417
+
418
+ lm_status, lm_success = llm_handler.initialize(
419
+ checkpoint_dir=checkpoint_dir,
420
+ lm_model_path=lm_model_path,
421
+ backend=backend,
422
+ device=device,
423
+ offload_to_cpu=offload_to_cpu,
424
+ dtype=dit_handler.dtype
425
+ )
426
+
427
+ if lm_success:
428
+ status += f"\n{lm_status}"
429
+ else:
430
+ status += f"\n{lm_status}"
431
+ # Don't fail the entire initialization if LM fails, but log it
432
+ # Keep enable as is (DiT initialization result) even if LM fails
433
+
434
+ # Check if model is initialized - if so, collapse the accordion
435
+ is_model_initialized = dit_handler.model is not None
436
+ accordion_state = gr.Accordion(open=not is_model_initialized)
437
+
438
+ # Get model type settings based on actual loaded model
439
+ is_turbo = dit_handler.is_turbo_model()
440
+ model_type_settings = get_model_type_ui_settings(is_turbo)
441
+
442
+ return (
443
+ status,
444
+ gr.update(interactive=enable),
445
+ accordion_state,
446
+ *model_type_settings
447
+ )
448
+
449
+
450
+ def get_model_type_ui_settings(is_turbo: bool):
451
+ """Get UI settings based on whether the model is turbo or base"""
452
+ if is_turbo:
453
+ # Turbo model: max 20 steps, default 8, show shift with default 3.0, only show text2music/repaint/cover
454
+ return (
455
+ gr.update(value=8, maximum=20, minimum=1), # inference_steps
456
+ gr.update(visible=False), # guidance_scale
457
+ gr.update(visible=False), # use_adg
458
+ gr.update(value=3.0, visible=True), # shift (show with default 3.0)
459
+ gr.update(visible=False), # cfg_interval_start
460
+ gr.update(visible=False), # cfg_interval_end
461
+ gr.update(choices=TASK_TYPES_TURBO), # task_type
462
+ )
463
+ else:
464
+ # Base model: max 200 steps, default 32, show CFG/ADG/shift, show all task types
465
+ return (
466
+ gr.update(value=32, maximum=200, minimum=1), # inference_steps
467
+ gr.update(visible=True), # guidance_scale
468
+ gr.update(visible=True), # use_adg
469
+ gr.update(value=3.0, visible=True), # shift (effective for base, default 3.0)
470
+ gr.update(visible=True), # cfg_interval_start
471
+ gr.update(visible=True), # cfg_interval_end
472
+ gr.update(choices=TASK_TYPES_BASE), # task_type
473
+ )
474
+
475
+
476
+ def update_negative_prompt_visibility(init_llm_checked):
477
+ """Update negative prompt visibility: show if Initialize 5Hz LM checkbox is checked"""
478
+ return gr.update(visible=init_llm_checked)
479
+
480
+
481
+ def update_audio_cover_strength_visibility(task_type_value, init_llm_checked):
482
+ """Update audio_cover_strength visibility and label"""
483
+ # Show if task is cover OR if LM is initialized (but NOT for repaint mode)
484
+ # Repaint mode never shows this control
485
+ is_repaint = task_type_value == "repaint"
486
+ is_cover = task_type_value == "cover"
487
+ is_visible = is_cover or (init_llm_checked and not is_repaint)
488
+
489
+ # Change label based on context
490
+ if init_llm_checked and not is_cover:
491
+ label = "LM codes strength"
492
+ info = "Control how many denoising steps use LM-generated codes"
493
+ else:
494
+ label = "Audio Cover Strength"
495
+ info = "Control how many denoising steps use cover mode"
496
+
497
+ return gr.update(visible=is_visible, label=label, info=info)
498
+
499
+
500
+ def convert_src_audio_to_codes_wrapper(dit_handler, src_audio):
501
+ """Wrapper for converting src audio to codes"""
502
+ codes_string = dit_handler.convert_src_audio_to_codes(src_audio)
503
+ return codes_string
504
+
505
+
506
+ def update_instruction_ui(
507
+ dit_handler,
508
+ task_type_value: str,
509
+ track_name_value: Optional[str],
510
+ complete_track_classes_value: list,
511
+ audio_codes_content: str = "",
512
+ init_llm_checked: bool = False
513
+ ) -> tuple:
514
+ """Update instruction and UI visibility based on task type."""
515
+ instruction = dit_handler.generate_instruction(
516
+ task_type=task_type_value,
517
+ track_name=track_name_value,
518
+ complete_track_classes=complete_track_classes_value
519
+ )
520
+
521
+ # Show track_name for lego and extract
522
+ track_name_visible = task_type_value in ["lego", "extract"]
523
+ # Show complete_track_classes for complete
524
+ complete_visible = task_type_value == "complete"
525
+ # Show audio_cover_strength for cover OR when LM is initialized (but NOT for repaint)
526
+ is_repaint = task_type_value == "repaint"
527
+ is_cover = task_type_value == "cover"
528
+ audio_cover_strength_visible = is_cover or (init_llm_checked and not is_repaint)
529
+ # Determine label and info based on context
530
+ if init_llm_checked and not is_cover:
531
+ audio_cover_strength_label = "LM codes strength"
532
+ audio_cover_strength_info = "Control how many denoising steps use LM-generated codes"
533
+ else:
534
+ audio_cover_strength_label = "Audio Cover Strength"
535
+ audio_cover_strength_info = "Control how many denoising steps use cover mode"
536
+ # Show repainting controls for repaint and lego
537
+ repainting_visible = task_type_value in ["repaint", "lego"]
538
+ # Show text2music_audio_codes if task is text2music OR if it has content
539
+ # This allows it to stay visible even if user switches task type but has codes
540
+ has_audio_codes = audio_codes_content and str(audio_codes_content).strip()
541
+ text2music_audio_codes_visible = task_type_value == "text2music" or has_audio_codes
542
+
543
+ return (
544
+ instruction, # instruction_display_gen
545
+ gr.update(visible=track_name_visible), # track_name
546
+ gr.update(visible=complete_visible), # complete_track_classes
547
+ gr.update(visible=audio_cover_strength_visible, label=audio_cover_strength_label, info=audio_cover_strength_info), # audio_cover_strength
548
+ gr.update(visible=repainting_visible), # repainting_group
549
+ gr.update(visible=text2music_audio_codes_visible), # text2music_audio_codes_group
550
+ )
551
+
552
+
553
+ def transcribe_audio_codes(llm_handler, audio_code_string, constrained_decoding_debug):
554
+ """
555
+ Transcribe audio codes to metadata using LLM understanding.
556
+ If audio_code_string is empty, generate a sample example instead.
557
+
558
+ This is a Gradio wrapper around the understand_music API in acestep.inference.
559
+
560
+ Args:
561
+ llm_handler: LLM handler instance
562
+ audio_code_string: String containing audio codes (or empty for example generation)
563
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
564
+
565
+ Returns:
566
+ Tuple of (status_message, caption, lyrics, bpm, duration, keyscale, language, timesignature, is_format_caption)
567
+ """
568
+ # Call the inference API
569
+ result = understand_music(
570
+ llm_handler=llm_handler,
571
+ audio_codes=audio_code_string,
572
+ use_constrained_decoding=True,
573
+ constrained_decoding_debug=constrained_decoding_debug,
574
+ )
575
+
576
+ # Handle error case with localized message
577
+ if not result.success:
578
+ # Use localized error message for LLM not initialized
579
+ if result.error == "LLM not initialized":
580
+ return t("messages.lm_not_initialized"), "", "", None, None, "", "", "", False
581
+ return result.status_message, "", "", None, None, "", "", "", False
582
+
583
+ return (
584
+ result.status_message,
585
+ result.caption,
586
+ result.lyrics,
587
+ result.bpm,
588
+ result.duration,
589
+ result.keyscale,
590
+ result.language,
591
+ result.timesignature,
592
+ True # Set is_format_caption to True (from Transcribe/LM understanding)
593
+ )
594
+
595
+
596
+ def update_transcribe_button_text(audio_code_string):
597
+ """
598
+ Update the transcribe button text based on input content.
599
+ If empty: "Generate Example"
600
+ If has content: "Transcribe"
601
+ """
602
+ if not audio_code_string or not audio_code_string.strip():
603
+ return gr.update(value="Generate Example")
604
+ else:
605
+ return gr.update(value="Transcribe")
606
+
607
+
608
+ def reset_format_caption_flag():
609
+ """Reset is_format_caption to False when user manually edits caption/metadata"""
610
+ return False
611
+
612
+
613
+ def update_audio_uploads_accordion(reference_audio, src_audio):
614
+ """Update Audio Uploads visibility based on whether audio files are present"""
615
+ has_audio = (reference_audio is not None) or (src_audio is not None)
616
+ return gr.update(visible=has_audio)
617
+
618
+
619
+ def handle_instrumental_checkbox(instrumental_checked, current_lyrics):
620
+ """
621
+ Handle instrumental checkbox changes.
622
+ When checked: if no lyrics, fill with [Instrumental]
623
+ When unchecked: if lyrics is [Instrumental], clear it
624
+ """
625
+ if instrumental_checked:
626
+ # If checked and no lyrics, fill with [Instrumental]
627
+ if not current_lyrics or not current_lyrics.strip():
628
+ return "[Instrumental]"
629
+ else:
630
+ # Has lyrics, don't change
631
+ return current_lyrics
632
+ else:
633
+ # If unchecked and lyrics is exactly [Instrumental], clear it
634
+ if current_lyrics and current_lyrics.strip() == "[Instrumental]":
635
+ return ""
636
+ else:
637
+ # Has other lyrics, don't change
638
+ return current_lyrics
639
+
640
+
641
+ def handle_simple_instrumental_change(is_instrumental: bool):
642
+ """
643
+ Handle simple mode instrumental checkbox changes.
644
+ When checked: set vocal_language to "unknown" and disable editing.
645
+ When unchecked: enable vocal_language editing.
646
+
647
+ Args:
648
+ is_instrumental: Whether instrumental checkbox is checked
649
+
650
+ Returns:
651
+ gr.update for simple_vocal_language dropdown
652
+ """
653
+ if is_instrumental:
654
+ return gr.update(value="unknown", interactive=False)
655
+ else:
656
+ return gr.update(interactive=True)
657
+
658
+
659
+ def update_audio_components_visibility(batch_size):
660
+ """Show/hide individual audio components based on batch size (1-8)
661
+
662
+ Row 1: Components 1-4 (batch_size 1-4)
663
+ Row 2: Components 5-8 (batch_size 5-8)
664
+ """
665
+ # Clamp batch size to 1-8 range for UI
666
+ batch_size = min(max(int(batch_size), 1), 8)
667
+
668
+ # Row 1 columns (1-4)
669
+ updates_row1 = (
670
+ gr.update(visible=True), # audio_col_1: always visible
671
+ gr.update(visible=batch_size >= 2), # audio_col_2
672
+ gr.update(visible=batch_size >= 3), # audio_col_3
673
+ gr.update(visible=batch_size >= 4), # audio_col_4
674
+ )
675
+
676
+ # Row 2 container and columns (5-8)
677
+ show_row_5_8 = batch_size >= 5
678
+ updates_row2 = (
679
+ gr.update(visible=show_row_5_8), # audio_row_5_8 (container)
680
+ gr.update(visible=batch_size >= 5), # audio_col_5
681
+ gr.update(visible=batch_size >= 6), # audio_col_6
682
+ gr.update(visible=batch_size >= 7), # audio_col_7
683
+ gr.update(visible=batch_size >= 8), # audio_col_8
684
+ )
685
+
686
+ return updates_row1 + updates_row2
687
+
688
+
689
+ def handle_generation_mode_change(mode: str):
690
+ """
691
+ Handle generation mode change between Simple, Custom, Cover, and Repaint modes.
692
+
693
+ Modes:
694
+ - Simple: Show simple mode group, hide others
695
+ - Custom: Show custom content (prompt), hide others
696
+ - Cover: Show src_audio_group + custom content + LM codes strength
697
+ - Repaint: Show src_audio_group + custom content + repaint time controls (hide LM codes strength)
698
+
699
+ Args:
700
+ mode: "simple", "custom", "cover", or "repaint"
701
+
702
+ Returns:
703
+ Tuple of updates for:
704
+ - simple_mode_group (visibility)
705
+ - custom_mode_content (visibility)
706
+ - cover_mode_group (visibility) - legacy, always hidden
707
+ - repainting_group (visibility)
708
+ - task_type (value)
709
+ - generate_btn (interactive state)
710
+ - simple_sample_created (reset state)
711
+ - src_audio_group (visibility) - shown for cover and repaint
712
+ - audio_cover_strength (visibility) - shown only for cover mode
713
+ - think_checkbox (value and interactive) - disabled for cover/repaint modes
714
+ """
715
+ is_simple = mode == "simple"
716
+ is_custom = mode == "custom"
717
+ is_cover = mode == "cover"
718
+ is_repaint = mode == "repaint"
719
+
720
+ # Map mode to task_type
721
+ task_type_map = {
722
+ "simple": "text2music",
723
+ "custom": "text2music",
724
+ "cover": "cover",
725
+ "repaint": "repaint",
726
+ }
727
+ task_type_value = task_type_map.get(mode, "text2music")
728
+
729
+ # think_checkbox: disabled and set to False for cover/repaint modes
730
+ # (these modes don't use LM thinking, they use source audio codes)
731
+ if is_cover or is_repaint:
732
+ think_checkbox_update = gr.update(value=False, interactive=False)
733
+ else:
734
+ think_checkbox_update = gr.update(value=True, interactive=True)
735
+
736
+ return (
737
+ gr.update(visible=is_simple), # simple_mode_group
738
+ gr.update(visible=not is_simple), # custom_mode_content - visible for custom/cover/repaint
739
+ gr.update(visible=False), # cover_mode_group - legacy, always hidden
740
+ gr.update(visible=is_repaint), # repainting_group - time range controls
741
+ gr.update(value=task_type_value), # task_type
742
+ gr.update(interactive=True), # generate_btn - always enabled (Simple mode does create+generate in one step)
743
+ False, # simple_sample_created - reset to False on mode change
744
+ gr.update(visible=is_cover or is_repaint), # src_audio_group - shown for cover and repaint
745
+ gr.update(visible=is_cover), # audio_cover_strength - only shown for cover mode
746
+ think_checkbox_update, # think_checkbox - disabled for cover/repaint modes
747
+ )
748
+
749
+
750
+ def process_source_audio(dit_handler, llm_handler, src_audio, constrained_decoding_debug):
751
+ """
752
+ Process source audio: convert to codes and then transcribe.
753
+ This combines convert_src_audio_to_codes_wrapper + transcribe_audio_codes.
754
+
755
+ Args:
756
+ dit_handler: DiT handler instance for audio code conversion
757
+ llm_handler: LLM handler instance for transcription
758
+ src_audio: Path to source audio file
759
+ constrained_decoding_debug: Whether to enable debug logging
760
+
761
+ Returns:
762
+ Tuple of (audio_codes, status_message, caption, lyrics, bpm, duration, keyscale, language, timesignature, is_format_caption)
763
+ """
764
+ if src_audio is None:
765
+ return ("", "No audio file provided", "", "", None, None, "", "", "", False)
766
+
767
+ # Step 1: Convert audio to codes
768
+ try:
769
+ codes_string = dit_handler.convert_src_audio_to_codes(src_audio)
770
+ if not codes_string:
771
+ return ("", "Failed to convert audio to codes", "", "", None, None, "", "", "", False)
772
+ except Exception as e:
773
+ return ("", f"Error converting audio: {str(e)}", "", "", None, None, "", "", "", False)
774
+
775
+ # Step 2: Transcribe the codes
776
+ result = understand_music(
777
+ llm_handler=llm_handler,
778
+ audio_codes=codes_string,
779
+ use_constrained_decoding=True,
780
+ constrained_decoding_debug=constrained_decoding_debug,
781
+ )
782
+
783
+ # Handle error case
784
+ if not result.success:
785
+ if result.error == "LLM not initialized":
786
+ return (codes_string, t("messages.lm_not_initialized"), "", "", None, None, "", "", "", False)
787
+ return (codes_string, result.status_message, "", "", None, None, "", "", "", False)
788
+
789
+ return (
790
+ codes_string,
791
+ result.status_message,
792
+ result.caption,
793
+ result.lyrics,
794
+ result.bpm,
795
+ result.duration,
796
+ result.keyscale,
797
+ result.language,
798
+ result.timesignature,
799
+ True # Set is_format_caption to True
800
+ )
801
+
802
+
803
+ def handle_create_sample(
804
+ llm_handler,
805
+ query: str,
806
+ instrumental: bool,
807
+ vocal_language: str,
808
+ lm_temperature: float,
809
+ lm_top_k: int,
810
+ lm_top_p: float,
811
+ constrained_decoding_debug: bool = False,
812
+ ):
813
+ """
814
+ Handle the Create Sample button click in Simple mode.
815
+
816
+ Creates a sample from the user's query using the LLM, then populates
817
+ the caption, lyrics, and metadata fields.
818
+
819
+ Note: cfg_scale and negative_prompt are not supported in create_sample mode.
820
+
821
+ Args:
822
+ llm_handler: LLM handler instance
823
+ query: User's natural language music description
824
+ instrumental: Whether to generate instrumental music
825
+ vocal_language: Preferred vocal language for constrained decoding
826
+ lm_temperature: LLM temperature for generation
827
+ lm_top_k: LLM top-k sampling
828
+ lm_top_p: LLM top-p sampling
829
+ constrained_decoding_debug: Whether to enable debug logging
830
+
831
+ Returns:
832
+ Tuple of updates for:
833
+ - captions
834
+ - lyrics
835
+ - bpm
836
+ - audio_duration
837
+ - key_scale
838
+ - vocal_language
839
+ - time_signature
840
+ - instrumental_checkbox
841
+ - caption_accordion (open)
842
+ - lyrics_accordion (open)
843
+ - generate_btn (interactive)
844
+ - simple_sample_created (True)
845
+ - think_checkbox (True)
846
+ - is_format_caption_state (True)
847
+ - status_output
848
+ """
849
+ # Check if LLM is initialized
850
+ if not llm_handler.llm_initialized:
851
+ gr.Warning(t("messages.lm_not_initialized"))
852
+ return (
853
+ gr.update(), # captions - no change
854
+ gr.update(), # lyrics - no change
855
+ gr.update(), # bpm - no change
856
+ gr.update(), # audio_duration - no change
857
+ gr.update(), # key_scale - no change
858
+ gr.update(), # vocal_language - no change
859
+ gr.update(), # time_signature - no change
860
+ gr.update(), # instrumental_checkbox - no change
861
+ gr.update(), # caption_accordion - no change
862
+ gr.update(), # lyrics_accordion - no change
863
+ gr.update(interactive=False), # generate_btn - keep disabled
864
+ False, # simple_sample_created - still False
865
+ gr.update(), # think_checkbox - no change
866
+ gr.update(), # is_format_caption_state - no change
867
+ t("messages.lm_not_initialized"), # status_output
868
+ )
869
+
870
+ # Convert LM parameters
871
+ top_k_value = None if not lm_top_k or lm_top_k == 0 else int(lm_top_k)
872
+ top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p
873
+
874
+ # Call create_sample API
875
+ # Note: cfg_scale and negative_prompt are not supported in create_sample mode
876
+ result = create_sample(
877
+ llm_handler=llm_handler,
878
+ query=query,
879
+ instrumental=instrumental,
880
+ vocal_language=vocal_language,
881
+ temperature=lm_temperature,
882
+ top_k=top_k_value,
883
+ top_p=top_p_value,
884
+ use_constrained_decoding=True,
885
+ constrained_decoding_debug=constrained_decoding_debug,
886
+ )
887
+
888
+ # Handle error
889
+ if not result.success:
890
+ gr.Warning(result.status_message or t("messages.sample_creation_failed"))
891
+ return (
892
+ gr.update(), # captions - no change
893
+ gr.update(), # lyrics - no change
894
+ gr.update(), # bpm - no change
895
+ gr.update(), # audio_duration - no change
896
+ gr.update(), # key_scale - no change
897
+ gr.update(), # vocal_language - no change
898
+ gr.update(), # simple vocal_language - no change
899
+ gr.update(), # time_signature - no change
900
+ gr.update(), # instrumental_checkbox - no change
901
+ gr.update(), # caption_accordion - no change
902
+ gr.update(), # lyrics_accordion - no change
903
+ gr.update(interactive=False), # generate_btn - keep disabled
904
+ False, # simple_sample_created - still False
905
+ gr.update(), # think_checkbox - no change
906
+ gr.update(), # is_format_caption_state - no change
907
+ result.status_message or t("messages.sample_creation_failed"), # status_output
908
+ )
909
+
910
+ # Success - populate fields
911
+ gr.Info(t("messages.sample_created"))
912
+
913
+ return (
914
+ result.caption, # captions
915
+ result.lyrics, # lyrics
916
+ result.bpm, # bpm
917
+ result.duration if result.duration and result.duration > 0 else -1, # audio_duration
918
+ result.keyscale, # key_scale
919
+ result.language, # vocal_language
920
+ result.language, # simple vocal_language
921
+ result.timesignature, # time_signature
922
+ result.instrumental, # instrumental_checkbox
923
+ gr.Accordion(open=True), # caption_accordion - expand
924
+ gr.Accordion(open=True), # lyrics_accordion - expand
925
+ gr.update(interactive=True), # generate_btn - enable
926
+ True, # simple_sample_created - True
927
+ True, # think_checkbox - enable thinking
928
+ True, # is_format_caption_state - True (LM-generated)
929
+ result.status_message, # status_output
930
+ )
931
+
932
+
933
+ def handle_format_sample(
934
+ llm_handler,
935
+ caption: str,
936
+ lyrics: str,
937
+ bpm,
938
+ audio_duration,
939
+ key_scale: str,
940
+ time_signature: str,
941
+ lm_temperature: float,
942
+ lm_top_k: int,
943
+ lm_top_p: float,
944
+ constrained_decoding_debug: bool = False,
945
+ ):
946
+ """
947
+ Handle the Format button click to format caption and lyrics.
948
+
949
+ Takes user-provided caption and lyrics, and uses the LLM to generate
950
+ structured music metadata and an enhanced description.
951
+
952
+ Note: cfg_scale and negative_prompt are not supported in format mode.
953
+
954
+ Args:
955
+ llm_handler: LLM handler instance
956
+ caption: User's caption/description
957
+ lyrics: User's lyrics
958
+ bpm: User-provided BPM (optional, for constrained decoding)
959
+ audio_duration: User-provided duration (optional, for constrained decoding)
960
+ key_scale: User-provided key scale (optional, for constrained decoding)
961
+ time_signature: User-provided time signature (optional, for constrained decoding)
962
+ lm_temperature: LLM temperature for generation
963
+ lm_top_k: LLM top-k sampling
964
+ lm_top_p: LLM top-p sampling
965
+ constrained_decoding_debug: Whether to enable debug logging
966
+
967
+ Returns:
968
+ Tuple of updates for:
969
+ - captions
970
+ - lyrics
971
+ - bpm
972
+ - audio_duration
973
+ - key_scale
974
+ - vocal_language
975
+ - time_signature
976
+ - is_format_caption_state
977
+ - status_output
978
+ """
979
+ # Check if LLM is initialized
980
+ if not llm_handler.llm_initialized:
981
+ gr.Warning(t("messages.lm_not_initialized"))
982
+ return (
983
+ gr.update(), # captions - no change
984
+ gr.update(), # lyrics - no change
985
+ gr.update(), # bpm - no change
986
+ gr.update(), # audio_duration - no change
987
+ gr.update(), # key_scale - no change
988
+ gr.update(), # vocal_language - no change
989
+ gr.update(), # time_signature - no change
990
+ gr.update(), # is_format_caption_state - no change
991
+ t("messages.lm_not_initialized"), # status_output
992
+ )
993
+
994
+ # Build user_metadata from provided values for constrained decoding
995
+ user_metadata = {}
996
+ if bpm is not None and bpm > 0:
997
+ user_metadata['bpm'] = int(bpm)
998
+ if audio_duration is not None and audio_duration > 0:
999
+ user_metadata['duration'] = int(audio_duration)
1000
+ if key_scale and key_scale.strip():
1001
+ user_metadata['keyscale'] = key_scale.strip()
1002
+ if time_signature and time_signature.strip():
1003
+ user_metadata['timesignature'] = time_signature.strip()
1004
+
1005
+ # Only pass user_metadata if we have at least one field
1006
+ user_metadata_to_pass = user_metadata if user_metadata else None
1007
+
1008
+ # Convert LM parameters
1009
+ top_k_value = None if not lm_top_k or lm_top_k == 0 else int(lm_top_k)
1010
+ top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p
1011
+
1012
+ # Call format_sample API
1013
+ result = format_sample(
1014
+ llm_handler=llm_handler,
1015
+ caption=caption,
1016
+ lyrics=lyrics,
1017
+ user_metadata=user_metadata_to_pass,
1018
+ temperature=lm_temperature,
1019
+ top_k=top_k_value,
1020
+ top_p=top_p_value,
1021
+ use_constrained_decoding=True,
1022
+ constrained_decoding_debug=constrained_decoding_debug,
1023
+ )
1024
+
1025
+ # Handle error
1026
+ if not result.success:
1027
+ gr.Warning(result.status_message or t("messages.format_failed"))
1028
+ return (
1029
+ gr.update(), # captions - no change
1030
+ gr.update(), # lyrics - no change
1031
+ gr.update(), # bpm - no change
1032
+ gr.update(), # audio_duration - no change
1033
+ gr.update(), # key_scale - no change
1034
+ gr.update(), # vocal_language - no change
1035
+ gr.update(), # time_signature - no change
1036
+ gr.update(), # is_format_caption_state - no change
1037
+ result.status_message or t("messages.format_failed"), # status_output
1038
+ )
1039
+
1040
+ # Success - populate fields
1041
+ gr.Info(t("messages.format_success"))
1042
+
1043
+ return (
1044
+ result.caption, # captions
1045
+ result.lyrics, # lyrics
1046
+ result.bpm, # bpm
1047
+ result.duration if result.duration and result.duration > 0 else -1, # audio_duration
1048
+ result.keyscale, # key_scale
1049
+ result.language, # vocal_language
1050
+ result.timesignature, # time_signature
1051
+ True, # is_format_caption_state - True (LM-formatted)
1052
+ result.status_message, # status_output
1053
+ )
1054
+
spaces/Ace-Step-v1.5/acestep/gradio_ui/events/results_handlers.py ADDED
The diff for this file is too large to render. See raw diff
 
spaces/Ace-Step-v1.5/acestep/gradio_ui/events/training_handlers.py ADDED
@@ -0,0 +1,644 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Event Handlers for Training Tab
3
+
4
+ Contains all event handler functions for the dataset builder and training UI.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ from typing import Any, Dict, List, Tuple, Optional
10
+ from loguru import logger
11
+ import gradio as gr
12
+
13
+ from acestep.training.dataset_builder import DatasetBuilder, AudioSample
14
+
15
+
16
+ def create_dataset_builder() -> DatasetBuilder:
17
+ """Create a new DatasetBuilder instance."""
18
+ return DatasetBuilder()
19
+
20
+
21
+ def scan_directory(
22
+ audio_dir: str,
23
+ dataset_name: str,
24
+ custom_tag: str,
25
+ tag_position: str,
26
+ all_instrumental: bool,
27
+ builder_state: Optional[DatasetBuilder],
28
+ ) -> Tuple[Any, str, Any, DatasetBuilder]:
29
+ """Scan a directory for audio files.
30
+
31
+ Returns:
32
+ Tuple of (table_data, status, slider_update, builder_state)
33
+ """
34
+ if not audio_dir or not audio_dir.strip():
35
+ return [], "❌ Please enter a directory path", gr.Slider(maximum=0, value=0), builder_state
36
+
37
+ # Create or use existing builder
38
+ builder = builder_state if builder_state else DatasetBuilder()
39
+
40
+ # Set metadata before scanning
41
+ builder.metadata.name = dataset_name
42
+ builder.metadata.custom_tag = custom_tag
43
+ builder.metadata.tag_position = tag_position
44
+ builder.metadata.all_instrumental = all_instrumental
45
+
46
+ # Scan directory
47
+ samples, status = builder.scan_directory(audio_dir.strip())
48
+
49
+ if not samples:
50
+ return [], status, gr.Slider(maximum=0, value=0), builder
51
+
52
+ # Set instrumental and tag for all samples
53
+ builder.set_all_instrumental(all_instrumental)
54
+ if custom_tag:
55
+ builder.set_custom_tag(custom_tag, tag_position)
56
+
57
+ # Get table data
58
+ table_data = builder.get_samples_dataframe_data()
59
+
60
+ # Calculate slider max and return as Slider update
61
+ slider_max = max(0, len(samples) - 1)
62
+
63
+ return table_data, status, gr.Slider(maximum=slider_max, value=0), builder
64
+
65
+
66
+ def auto_label_all(
67
+ dit_handler,
68
+ llm_handler,
69
+ builder_state: Optional[DatasetBuilder],
70
+ skip_metas: bool = False,
71
+ progress=None,
72
+ ) -> Tuple[List[List[Any]], str, DatasetBuilder]:
73
+ """Auto-label all samples in the dataset.
74
+
75
+ Args:
76
+ dit_handler: DiT handler for audio processing
77
+ llm_handler: LLM handler for caption generation
78
+ builder_state: Dataset builder state
79
+ skip_metas: If True, skip LLM labeling. BPM/Key/TimeSig = N/A, Language = unknown for instrumental
80
+ progress: Progress callback
81
+
82
+ Returns:
83
+ Tuple of (table_data, status, builder_state)
84
+ """
85
+ if builder_state is None:
86
+ return [], "❌ Please scan a directory first", builder_state
87
+
88
+ if not builder_state.samples:
89
+ return [], "❌ No samples to label. Please scan a directory first.", builder_state
90
+
91
+ # If skip_metas is True, just set default values without LLM
92
+ if skip_metas:
93
+ for sample in builder_state.samples:
94
+ sample.bpm = None # Will display as N/A
95
+ sample.keyscale = "N/A"
96
+ sample.timesignature = "N/A"
97
+ # For instrumental, language should be "unknown"
98
+ if sample.is_instrumental:
99
+ sample.language = "unknown"
100
+ else:
101
+ sample.language = "unknown"
102
+ # Use custom tag as caption if set, otherwise use filename
103
+ if builder_state.metadata.custom_tag:
104
+ sample.caption = builder_state.metadata.custom_tag
105
+ else:
106
+ sample.caption = sample.filename
107
+
108
+ table_data = builder_state.get_samples_dataframe_data()
109
+ return table_data, f"✅ Skipped AI labeling. {len(builder_state.samples)} samples set with default values.", builder_state
110
+
111
+ # Check if handlers are initialized
112
+ if dit_handler is None or dit_handler.model is None:
113
+ return builder_state.get_samples_dataframe_data(), "❌ Model not initialized. Please initialize the service first.", builder_state
114
+
115
+ if llm_handler is None or not llm_handler.llm_initialized:
116
+ return builder_state.get_samples_dataframe_data(), "❌ LLM not initialized. Please initialize the service with LLM enabled.", builder_state
117
+
118
+ def progress_callback(msg):
119
+ if progress:
120
+ try:
121
+ progress(msg)
122
+ except:
123
+ pass
124
+
125
+ # Label all samples
126
+ samples, status = builder_state.label_all_samples(
127
+ dit_handler=dit_handler,
128
+ llm_handler=llm_handler,
129
+ progress_callback=progress_callback,
130
+ )
131
+
132
+ # Get updated table data
133
+ table_data = builder_state.get_samples_dataframe_data()
134
+
135
+ return table_data, status, builder_state
136
+
137
+
138
+ def get_sample_preview(
139
+ sample_idx: int,
140
+ builder_state: Optional[DatasetBuilder],
141
+ ) -> Tuple[str, str, str, str, Optional[int], str, str, float, str, bool]:
142
+ """Get preview data for a specific sample.
143
+
144
+ Returns:
145
+ Tuple of (audio_path, filename, caption, lyrics, bpm, keyscale, timesig, duration, language, instrumental)
146
+ """
147
+ if builder_state is None or not builder_state.samples:
148
+ return None, "", "", "", None, "", "", 0.0, "instrumental", True
149
+
150
+ idx = int(sample_idx)
151
+ if idx < 0 or idx >= len(builder_state.samples):
152
+ return None, "", "", "", None, "", "", 0.0, "instrumental", True
153
+
154
+ sample = builder_state.samples[idx]
155
+
156
+ return (
157
+ sample.audio_path,
158
+ sample.filename,
159
+ sample.caption,
160
+ sample.lyrics,
161
+ sample.bpm,
162
+ sample.keyscale,
163
+ sample.timesignature,
164
+ sample.duration,
165
+ sample.language,
166
+ sample.is_instrumental,
167
+ )
168
+
169
+
170
+ def save_sample_edit(
171
+ sample_idx: int,
172
+ caption: str,
173
+ lyrics: str,
174
+ bpm: Optional[int],
175
+ keyscale: str,
176
+ timesig: str,
177
+ language: str,
178
+ is_instrumental: bool,
179
+ builder_state: Optional[DatasetBuilder],
180
+ ) -> Tuple[List[List[Any]], str, DatasetBuilder]:
181
+ """Save edits to a sample.
182
+
183
+ Returns:
184
+ Tuple of (table_data, status, builder_state)
185
+ """
186
+ if builder_state is None:
187
+ return [], "❌ No dataset loaded", builder_state
188
+
189
+ idx = int(sample_idx)
190
+
191
+ # Update sample
192
+ sample, status = builder_state.update_sample(
193
+ idx,
194
+ caption=caption,
195
+ lyrics=lyrics if not is_instrumental else "[Instrumental]",
196
+ bpm=int(bpm) if bpm else None,
197
+ keyscale=keyscale,
198
+ timesignature=timesig,
199
+ language="instrumental" if is_instrumental else language,
200
+ is_instrumental=is_instrumental,
201
+ labeled=True,
202
+ )
203
+
204
+ # Get updated table data
205
+ table_data = builder_state.get_samples_dataframe_data()
206
+
207
+ return table_data, status, builder_state
208
+
209
+
210
+ def update_settings(
211
+ custom_tag: str,
212
+ tag_position: str,
213
+ all_instrumental: bool,
214
+ builder_state: Optional[DatasetBuilder],
215
+ ) -> DatasetBuilder:
216
+ """Update dataset settings.
217
+
218
+ Returns:
219
+ Updated builder_state
220
+ """
221
+ if builder_state is None:
222
+ return builder_state
223
+
224
+ if custom_tag:
225
+ builder_state.set_custom_tag(custom_tag, tag_position)
226
+
227
+ builder_state.set_all_instrumental(all_instrumental)
228
+
229
+ return builder_state
230
+
231
+
232
+ def save_dataset(
233
+ save_path: str,
234
+ dataset_name: str,
235
+ builder_state: Optional[DatasetBuilder],
236
+ ) -> str:
237
+ """Save the dataset to a JSON file.
238
+
239
+ Returns:
240
+ Status message
241
+ """
242
+ if builder_state is None:
243
+ return "❌ No dataset to save. Please scan a directory first."
244
+
245
+ if not builder_state.samples:
246
+ return "❌ No samples in dataset."
247
+
248
+ if not save_path or not save_path.strip():
249
+ return "❌ Please enter a save path."
250
+
251
+ # Check if any samples are labeled
252
+ labeled_count = builder_state.get_labeled_count()
253
+ if labeled_count == 0:
254
+ return "⚠️ Warning: No samples have been labeled. Consider auto-labeling first.\nSaving anyway..."
255
+
256
+ return builder_state.save_dataset(save_path.strip(), dataset_name)
257
+
258
+
259
+ def load_existing_dataset_for_preprocess(
260
+ dataset_path: str,
261
+ builder_state: Optional[DatasetBuilder],
262
+ ) -> Tuple[str, Any, Any, DatasetBuilder, str, str, str, str, Optional[int], str, str, float, str, bool]:
263
+ """Load an existing dataset JSON file for preprocessing.
264
+
265
+ This allows users to load a previously saved dataset and proceed to preprocessing
266
+ without having to re-scan and re-label.
267
+
268
+ Returns:
269
+ Tuple of (status, table_data, slider_update, builder_state,
270
+ audio_path, filename, caption, lyrics, bpm, keyscale, timesig, duration, language, instrumental)
271
+ """
272
+ empty_preview = (None, "", "", "", None, "", "", 0.0, "instrumental", True)
273
+
274
+ if not dataset_path or not dataset_path.strip():
275
+ return ("❌ Please enter a dataset path", [], gr.Slider(maximum=0, value=0), builder_state) + empty_preview
276
+
277
+ dataset_path = dataset_path.strip()
278
+
279
+ if not os.path.exists(dataset_path):
280
+ return (f"❌ Dataset not found: {dataset_path}", [], gr.Slider(maximum=0, value=0), builder_state) + empty_preview
281
+
282
+ # Create new builder (don't reuse old state when loading a file)
283
+ builder = DatasetBuilder()
284
+
285
+ # Load the dataset
286
+ samples, status = builder.load_dataset(dataset_path)
287
+
288
+ if not samples:
289
+ return (status, [], gr.Slider(maximum=0, value=0), builder) + empty_preview
290
+
291
+ # Get table data
292
+ table_data = builder.get_samples_dataframe_data()
293
+
294
+ # Calculate slider max
295
+ slider_max = max(0, len(samples) - 1)
296
+
297
+ # Create info text
298
+ labeled_count = builder.get_labeled_count()
299
+ info = f"✅ Loaded dataset: {builder.metadata.name}\n"
300
+ info += f"📊 Samples: {len(samples)} ({labeled_count} labeled)\n"
301
+ info += f"🏷️ Custom Tag: {builder.metadata.custom_tag or '(none)'}\n"
302
+ info += "📝 Ready for preprocessing! You can also edit samples below."
303
+
304
+ # Get first sample preview
305
+ first_sample = builder.samples[0]
306
+ preview = (
307
+ first_sample.audio_path,
308
+ first_sample.filename,
309
+ first_sample.caption,
310
+ first_sample.lyrics,
311
+ first_sample.bpm,
312
+ first_sample.keyscale,
313
+ first_sample.timesignature,
314
+ first_sample.duration,
315
+ first_sample.language,
316
+ first_sample.is_instrumental,
317
+ )
318
+
319
+ return (info, table_data, gr.Slider(maximum=slider_max, value=0), builder) + preview
320
+
321
+
322
+ def preprocess_dataset(
323
+ output_dir: str,
324
+ dit_handler,
325
+ builder_state: Optional[DatasetBuilder],
326
+ progress=None,
327
+ ) -> str:
328
+ """Preprocess dataset to tensor files for fast training.
329
+
330
+ This converts audio files to VAE latents and text to embeddings.
331
+
332
+ Returns:
333
+ Status message
334
+ """
335
+ if builder_state is None:
336
+ return "❌ No dataset loaded. Please scan a directory first."
337
+
338
+ if not builder_state.samples:
339
+ return "❌ No samples in dataset."
340
+
341
+ labeled_count = builder_state.get_labeled_count()
342
+ if labeled_count == 0:
343
+ return "❌ No labeled samples. Please auto-label or manually label samples first."
344
+
345
+ if not output_dir or not output_dir.strip():
346
+ return "❌ Please enter an output directory."
347
+
348
+ if dit_handler is None or dit_handler.model is None:
349
+ return "❌ Model not initialized. Please initialize the service first."
350
+
351
+ def progress_callback(msg):
352
+ if progress:
353
+ try:
354
+ progress(msg)
355
+ except:
356
+ pass
357
+
358
+ # Run preprocessing
359
+ output_paths, status = builder_state.preprocess_to_tensors(
360
+ dit_handler=dit_handler,
361
+ output_dir=output_dir.strip(),
362
+ progress_callback=progress_callback,
363
+ )
364
+
365
+ return status
366
+
367
+
368
+ def load_training_dataset(
369
+ tensor_dir: str,
370
+ ) -> str:
371
+ """Load a preprocessed tensor dataset for training.
372
+
373
+ Returns:
374
+ Info text about the dataset
375
+ """
376
+ if not tensor_dir or not tensor_dir.strip():
377
+ return "❌ Please enter a tensor directory path"
378
+
379
+ tensor_dir = tensor_dir.strip()
380
+
381
+ if not os.path.exists(tensor_dir):
382
+ return f"❌ Directory not found: {tensor_dir}"
383
+
384
+ if not os.path.isdir(tensor_dir):
385
+ return f"❌ Not a directory: {tensor_dir}"
386
+
387
+ # Check for manifest
388
+ manifest_path = os.path.join(tensor_dir, "manifest.json")
389
+ if os.path.exists(manifest_path):
390
+ try:
391
+ with open(manifest_path, 'r') as f:
392
+ manifest = json.load(f)
393
+
394
+ num_samples = manifest.get("num_samples", 0)
395
+ metadata = manifest.get("metadata", {})
396
+ name = metadata.get("name", "Unknown")
397
+ custom_tag = metadata.get("custom_tag", "")
398
+
399
+ info = f"✅ Loaded preprocessed dataset: {name}\n"
400
+ info += f"📊 Samples: {num_samples} preprocessed tensors\n"
401
+ info += f"🏷️ Custom Tag: {custom_tag or '(none)'}"
402
+
403
+ return info
404
+ except Exception as e:
405
+ logger.warning(f"Failed to read manifest: {e}")
406
+
407
+ # Fallback: count .pt files
408
+ pt_files = [f for f in os.listdir(tensor_dir) if f.endswith('.pt')]
409
+
410
+ if not pt_files:
411
+ return f"❌ No .pt tensor files found in {tensor_dir}"
412
+
413
+ info = f"✅ Found {len(pt_files)} tensor files in {tensor_dir}\n"
414
+ info += "⚠️ No manifest.json found - using all .pt files"
415
+
416
+ return info
417
+
418
+
419
+ # Training handlers
420
+
421
+ import time
422
+ import re
423
+
424
+
425
+ def _format_duration(seconds):
426
+ """Format seconds to human readable string."""
427
+ seconds = int(seconds)
428
+ if seconds < 60:
429
+ return f"{seconds}s"
430
+ elif seconds < 3600:
431
+ return f"{seconds // 60}m {seconds % 60}s"
432
+ else:
433
+ return f"{seconds // 3600}h {(seconds % 3600) // 60}m"
434
+
435
+
436
+ def start_training(
437
+ tensor_dir: str,
438
+ dit_handler,
439
+ lora_rank: int,
440
+ lora_alpha: int,
441
+ lora_dropout: float,
442
+ learning_rate: float,
443
+ train_epochs: int,
444
+ train_batch_size: int,
445
+ gradient_accumulation: int,
446
+ save_every_n_epochs: int,
447
+ training_shift: float,
448
+ training_seed: int,
449
+ lora_output_dir: str,
450
+ training_state: Dict,
451
+ progress=None,
452
+ ):
453
+ """Start LoRA training from preprocessed tensors.
454
+
455
+ This is a generator function that yields progress updates.
456
+ """
457
+ if not tensor_dir or not tensor_dir.strip():
458
+ yield "❌ Please enter a tensor directory path", "", None, training_state
459
+ return
460
+
461
+ tensor_dir = tensor_dir.strip()
462
+
463
+ if not os.path.exists(tensor_dir):
464
+ yield f"❌ Tensor directory not found: {tensor_dir}", "", None, training_state
465
+ return
466
+
467
+ if dit_handler is None or dit_handler.model is None:
468
+ yield "❌ Model not initialized. Please initialize the service first.", "", None, training_state
469
+ return
470
+
471
+ # Check for required training dependencies
472
+ try:
473
+ from lightning.fabric import Fabric
474
+ from peft import get_peft_model, LoraConfig
475
+ except ImportError as e:
476
+ yield f"❌ Missing required packages: {e}\nPlease install: pip install peft lightning", "", None, training_state
477
+ return
478
+
479
+ training_state["is_training"] = True
480
+ training_state["should_stop"] = False
481
+
482
+ try:
483
+ from acestep.training.trainer import LoRATrainer
484
+ from acestep.training.configs import LoRAConfig as LoRAConfigClass, TrainingConfig
485
+
486
+ # Create configs
487
+ lora_config = LoRAConfigClass(
488
+ r=lora_rank,
489
+ alpha=lora_alpha,
490
+ dropout=lora_dropout,
491
+ )
492
+
493
+ training_config = TrainingConfig(
494
+ shift=training_shift,
495
+ learning_rate=learning_rate,
496
+ batch_size=train_batch_size,
497
+ gradient_accumulation_steps=gradient_accumulation,
498
+ max_epochs=train_epochs,
499
+ save_every_n_epochs=save_every_n_epochs,
500
+ seed=training_seed,
501
+ output_dir=lora_output_dir,
502
+ )
503
+
504
+ import pandas as pd
505
+
506
+ # Initialize training log and loss history
507
+ log_lines = []
508
+ loss_data = pd.DataFrame({"step": [0], "loss": [0.0]})
509
+
510
+ # Start timer
511
+ start_time = time.time()
512
+
513
+ yield f"🚀 Starting training from {tensor_dir}...", "", loss_data, training_state
514
+
515
+ # Create trainer
516
+ trainer = LoRATrainer(
517
+ dit_handler=dit_handler,
518
+ lora_config=lora_config,
519
+ training_config=training_config,
520
+ )
521
+
522
+ # Collect loss history
523
+ step_list = []
524
+ loss_list = []
525
+
526
+ # Train with progress updates using preprocessed tensors
527
+ for step, loss, status in trainer.train_from_preprocessed(tensor_dir, training_state):
528
+ # Calculate elapsed time and ETA
529
+ elapsed_seconds = time.time() - start_time
530
+ time_info = f"⏱️ Elapsed: {_format_duration(elapsed_seconds)}"
531
+
532
+ # Parse "Epoch x/y" from status to calculate ETA
533
+ match = re.search(r"Epoch\s+(\d+)/(\d+)", str(status))
534
+ if match:
535
+ current_ep = int(match.group(1))
536
+ total_ep = int(match.group(2))
537
+ if current_ep > 0:
538
+ eta_seconds = (elapsed_seconds / current_ep) * (total_ep - current_ep)
539
+ time_info += f" | ETA: ~{_format_duration(eta_seconds)}"
540
+
541
+ # Display status with time info
542
+ display_status = f"{status}\n{time_info}"
543
+
544
+ # Terminal log
545
+ log_msg = f"[{_format_duration(elapsed_seconds)}] Step {step}: {status}"
546
+ logger.info(log_msg)
547
+
548
+ # Add to UI log
549
+ log_lines.append(status)
550
+ if len(log_lines) > 15:
551
+ log_lines = log_lines[-15:]
552
+ log_text = "\n".join(log_lines)
553
+
554
+ # Track loss for plot (only valid values)
555
+ if step > 0 and loss is not None and loss == loss: # Check for NaN
556
+ step_list.append(step)
557
+ loss_list.append(float(loss))
558
+ loss_data = pd.DataFrame({"step": step_list, "loss": loss_list})
559
+
560
+ yield display_status, log_text, loss_data, training_state
561
+
562
+ if training_state.get("should_stop", False):
563
+ logger.info("⏹️ Training stopped by user")
564
+ log_lines.append("⏹️ Training stopped by user")
565
+ yield f"⏹️ Stopped ({time_info})", "\n".join(log_lines[-15:]), loss_data, training_state
566
+ break
567
+
568
+ total_time = time.time() - start_time
569
+ training_state["is_training"] = False
570
+ completion_msg = f"✅ Training completed! Total time: {_format_duration(total_time)}"
571
+
572
+ logger.info(completion_msg)
573
+ log_lines.append(completion_msg)
574
+
575
+ yield completion_msg, "\n".join(log_lines[-15:]), loss_data, training_state
576
+
577
+ except Exception as e:
578
+ logger.exception("Training error")
579
+ training_state["is_training"] = False
580
+ import pandas as pd
581
+ empty_df = pd.DataFrame({"step": [], "loss": []})
582
+ yield f"❌ Error: {str(e)}", str(e), empty_df, training_state
583
+
584
+
585
+ def stop_training(training_state: Dict) -> Tuple[str, Dict]:
586
+ """Stop the current training process.
587
+
588
+ Returns:
589
+ Tuple of (status, training_state)
590
+ """
591
+ if not training_state.get("is_training", False):
592
+ return "⚠️ No training in progress", training_state
593
+
594
+ training_state["should_stop"] = True
595
+ return "⏹️ Stopping training...", training_state
596
+
597
+
598
+ def export_lora(
599
+ export_path: str,
600
+ lora_output_dir: str,
601
+ ) -> str:
602
+ """Export the trained LoRA weights.
603
+
604
+ Returns:
605
+ Status message
606
+ """
607
+ if not export_path or not export_path.strip():
608
+ return "❌ Please enter an export path"
609
+
610
+ # Check if there's a trained model to export
611
+ final_dir = os.path.join(lora_output_dir, "final")
612
+ checkpoint_dir = os.path.join(lora_output_dir, "checkpoints")
613
+
614
+ # Prefer final, fallback to checkpoints
615
+ if os.path.exists(final_dir):
616
+ source_path = final_dir
617
+ elif os.path.exists(checkpoint_dir):
618
+ # Find the latest checkpoint
619
+ checkpoints = [d for d in os.listdir(checkpoint_dir) if d.startswith("epoch_")]
620
+ if not checkpoints:
621
+ return "❌ No checkpoints found"
622
+
623
+ checkpoints.sort(key=lambda x: int(x.split("_")[1]))
624
+ latest = checkpoints[-1]
625
+ source_path = os.path.join(checkpoint_dir, latest)
626
+ else:
627
+ return f"❌ No trained model found in {lora_output_dir}"
628
+
629
+ try:
630
+ import shutil
631
+
632
+ export_path = export_path.strip()
633
+ os.makedirs(os.path.dirname(export_path) if os.path.dirname(export_path) else ".", exist_ok=True)
634
+
635
+ if os.path.exists(export_path):
636
+ shutil.rmtree(export_path)
637
+
638
+ shutil.copytree(source_path, export_path)
639
+
640
+ return f"✅ LoRA exported to {export_path}"
641
+
642
+ except Exception as e:
643
+ logger.exception("Export error")
644
+ return f"❌ Export failed: {str(e)}"
spaces/Ace-Step-v1.5/acestep/gradio_ui/i18n.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Internationalization (i18n) module for Gradio UI
3
+ Supports multiple languages with easy translation management
4
+ """
5
+ import os
6
+ import json
7
+ from typing import Dict, Optional
8
+
9
+
10
+ class I18n:
11
+ """Internationalization handler"""
12
+
13
+ def __init__(self, default_language: str = "en"):
14
+ """
15
+ Initialize i18n handler
16
+
17
+ Args:
18
+ default_language: Default language code (en, zh, ja, etc.)
19
+ """
20
+ self.current_language = default_language
21
+ self.translations: Dict[str, Dict[str, str]] = {}
22
+ self._load_all_translations()
23
+
24
+ def _load_all_translations(self):
25
+ """Load all translation files from i18n directory"""
26
+ current_file = os.path.abspath(__file__)
27
+ module_dir = os.path.dirname(current_file)
28
+ i18n_dir = os.path.join(module_dir, "i18n")
29
+
30
+ if not os.path.exists(i18n_dir):
31
+ # Create i18n directory if it doesn't exist
32
+ os.makedirs(i18n_dir)
33
+ return
34
+
35
+ # Load all JSON files in i18n directory
36
+ for filename in os.listdir(i18n_dir):
37
+ if filename.endswith(".json"):
38
+ lang_code = filename[:-5] # Remove .json extension
39
+ filepath = os.path.join(i18n_dir, filename)
40
+ try:
41
+ with open(filepath, 'r', encoding='utf-8') as f:
42
+ self.translations[lang_code] = json.load(f)
43
+ except Exception as e:
44
+ print(f"Error loading translation file {filename}: {e}")
45
+
46
+ def set_language(self, language: str):
47
+ """Set current language"""
48
+ if language in self.translations:
49
+ self.current_language = language
50
+ else:
51
+ print(f"Warning: Language '{language}' not found, using default")
52
+
53
+ def t(self, key: str, **kwargs) -> str:
54
+ """
55
+ Translate a key to current language
56
+
57
+ Args:
58
+ key: Translation key (dot-separated for nested keys)
59
+ **kwargs: Optional format parameters
60
+
61
+ Returns:
62
+ Translated string
63
+ """
64
+ # Get translation from current language
65
+ translation = self._get_nested_value(
66
+ self.translations.get(self.current_language, {}),
67
+ key
68
+ )
69
+
70
+ # Fallback to English if not found
71
+ if translation is None:
72
+ translation = self._get_nested_value(
73
+ self.translations.get('en', {}),
74
+ key
75
+ )
76
+
77
+ # Final fallback to key itself
78
+ if translation is None:
79
+ translation = key
80
+
81
+ # Apply formatting if kwargs provided
82
+ if kwargs:
83
+ try:
84
+ translation = translation.format(**kwargs)
85
+ except KeyError:
86
+ pass
87
+
88
+ return translation
89
+
90
+ def _get_nested_value(self, data: dict, key: str) -> Optional[str]:
91
+ """
92
+ Get nested dictionary value using dot notation
93
+
94
+ Args:
95
+ data: Dictionary to search
96
+ key: Dot-separated key (e.g., "section.subsection.key")
97
+
98
+ Returns:
99
+ Value if found, None otherwise
100
+ """
101
+ keys = key.split('.')
102
+ current = data
103
+
104
+ for k in keys:
105
+ if isinstance(current, dict) and k in current:
106
+ current = current[k]
107
+ else:
108
+ return None
109
+
110
+ return current if isinstance(current, str) else None
111
+
112
+ def get_available_languages(self) -> list:
113
+ """Get list of available language codes"""
114
+ return list(self.translations.keys())
115
+
116
+
117
+ # Global i18n instance
118
+ _i18n_instance: Optional[I18n] = None
119
+
120
+
121
+ def get_i18n(language: Optional[str] = None) -> I18n:
122
+ """
123
+ Get global i18n instance
124
+
125
+ Args:
126
+ language: Optional language to set
127
+
128
+ Returns:
129
+ I18n instance
130
+ """
131
+ global _i18n_instance
132
+
133
+ if _i18n_instance is None:
134
+ _i18n_instance = I18n(default_language=language or "en")
135
+ elif language is not None:
136
+ _i18n_instance.set_language(language)
137
+
138
+ return _i18n_instance
139
+
140
+
141
+ def t(key: str, **kwargs) -> str:
142
+ """
143
+ Convenience function for translation
144
+
145
+ Args:
146
+ key: Translation key
147
+ **kwargs: Optional format parameters
148
+
149
+ Returns:
150
+ Translated string
151
+ """
152
+ return get_i18n().t(key, **kwargs)
spaces/Ace-Step-v1.5/acestep/gradio_ui/i18n/en.json ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "app": {
3
+ "title": "🎛️ ACE-Step V1.5 Playground💡",
4
+ "subtitle": "Pushing the Boundaries of Open-Source Music Generation"
5
+ },
6
+ "dataset": {
7
+ "title": "📊 Dataset Explorer",
8
+ "dataset_label": "Dataset",
9
+ "dataset_info": "Choose dataset to explore",
10
+ "import_btn": "📥 Import Dataset",
11
+ "search_type_label": "Search Type",
12
+ "search_type_info": "How to find items",
13
+ "search_value_label": "Search Value",
14
+ "search_value_placeholder": "Enter keys or index (leave empty for random)",
15
+ "search_value_info": "Keys: exact match, Index: 0 to dataset size-1",
16
+ "instruction_label": "📝 Instruction",
17
+ "instruction_placeholder": "No instruction available",
18
+ "metadata_title": "📋 Item Metadata (JSON)",
19
+ "metadata_label": "Complete Item Information",
20
+ "source_audio": "Source Audio",
21
+ "target_audio": "Target Audio",
22
+ "reference_audio": "Reference Audio",
23
+ "get_item_btn": "🔍 Get Item",
24
+ "use_src_checkbox": "Use Source Audio from Dataset",
25
+ "use_src_info": "Check to use the source audio from dataset",
26
+ "data_status_label": "📊 Data Status",
27
+ "data_status_default": "❌ No dataset imported",
28
+ "autofill_btn": "📋 Auto-fill Generation Form"
29
+ },
30
+ "service": {
31
+ "title": "🔧 Service Configuration",
32
+ "checkpoint_label": "Checkpoint File",
33
+ "checkpoint_info": "Select a trained model checkpoint file (full path or filename)",
34
+ "refresh_btn": "🔄 Refresh",
35
+ "model_path_label": "Main Model Path",
36
+ "model_path_info": "Select the model configuration directory (auto-scanned from checkpoints)",
37
+ "device_label": "Device",
38
+ "device_info": "Processing device (auto-detect recommended)",
39
+ "lm_model_path_label": "5Hz LM Model Path",
40
+ "lm_model_path_info": "Select the 5Hz LM model checkpoint (auto-scanned from checkpoints)",
41
+ "backend_label": "5Hz LM Backend",
42
+ "backend_info": "Select backend for 5Hz LM: vllm (faster) or pt (PyTorch, more compatible)",
43
+ "init_llm_label": "Initialize 5Hz LM",
44
+ "init_llm_info": "Check to initialize 5Hz LM during service initialization",
45
+ "flash_attention_label": "Use Flash Attention",
46
+ "flash_attention_info_enabled": "Enable flash attention for faster inference (requires flash_attn package)",
47
+ "flash_attention_info_disabled": "Flash attention not available (flash_attn package not installed)",
48
+ "offload_cpu_label": "Offload to CPU",
49
+ "offload_cpu_info": "Offload models to CPU when not in use to save GPU memory",
50
+ "offload_dit_cpu_label": "Offload DiT to CPU",
51
+ "offload_dit_cpu_info": "Offload DiT to CPU (needs Offload to CPU)",
52
+ "init_btn": "Initialize Service",
53
+ "status_label": "Status",
54
+ "language_label": "UI Language",
55
+ "language_info": "Select interface language"
56
+ },
57
+ "generation": {
58
+ "required_inputs": "📝 Required Inputs",
59
+ "task_type_label": "Task Type",
60
+ "task_type_info": "Select the task type for generation",
61
+ "instruction_label": "Instruction",
62
+ "instruction_info": "Instruction is automatically generated based on task type",
63
+ "load_btn": "Load",
64
+ "track_name_label": "Track Name",
65
+ "track_name_info": "Select track name for lego/extract tasks",
66
+ "track_classes_label": "Track Names",
67
+ "track_classes_info": "Select multiple track classes for complete task",
68
+ "audio_uploads": "🎵 Audio Uploads",
69
+ "reference_audio": "Reference Audio (optional)",
70
+ "source_audio": "Source Audio (optional)",
71
+ "convert_codes_btn": "Convert to Codes",
72
+ "lm_codes_hints": "🎼 LM Codes Hints",
73
+ "lm_codes_label": "LM Codes Hints",
74
+ "lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
75
+ "lm_codes_info": "Paste LM codes hints for text2music generation",
76
+ "lm_codes_sample": "LM Codes Hints (Sample {n})",
77
+ "lm_codes_sample_info": "Codes for sample {n}",
78
+ "transcribe_btn": "Transcribe",
79
+ "repainting_controls": "🎨 Repainting Controls (seconds)",
80
+ "repainting_start": "Repainting Start",
81
+ "repainting_end": "Repainting End",
82
+ "mode_label": "Generation Mode",
83
+ "mode_info": "Simple: describe music in natural language. Custom: full control over caption and lyrics.",
84
+ "mode_simple": "Simple",
85
+ "mode_custom": "Custom",
86
+ "simple_query_label": "Song Description",
87
+ "simple_query_placeholder": "Describe the music you want to create, e.g., 'a soft Bengali love song for a quiet evening'. Leave empty for a random sample.",
88
+ "simple_query_info": "Enter a natural language description of the music you want to generate",
89
+ "simple_vocal_language_label": "Vocal Language (optional)",
90
+ "simple_vocal_language_info": "Select preferred language(s) for lyrics. Use 'unknown' for any language.",
91
+ "create_sample_btn": "Create Sample",
92
+ "caption_title": "📝 Music Caption",
93
+ "caption_label": "Music Caption (optional)",
94
+ "caption_placeholder": "A peaceful acoustic guitar melody with soft vocals...",
95
+ "caption_info": "Describe the style, genre, instruments, and mood",
96
+ "lyrics_title": "📝 Lyrics",
97
+ "lyrics_label": "Lyrics (optional)",
98
+ "lyrics_placeholder": "[Verse 1]\\nUnder the starry night\\nI feel so alive...",
99
+ "lyrics_info": "Song lyrics with structure",
100
+ "instrumental_label": "Instrumental",
101
+ "format_btn": "Format",
102
+ "optional_params": "⚙️ Optional Parameters",
103
+ "vocal_language_label": "Vocal Language (optional)",
104
+ "vocal_language_info": "use `unknown` for inst",
105
+ "bpm_label": "BPM (optional)",
106
+ "bpm_info": "leave empty for N/A",
107
+ "keyscale_label": "KeyScale (optional)",
108
+ "keyscale_placeholder": "Leave empty for N/A",
109
+ "keyscale_info": "A-G, #/♭, major/minor",
110
+ "timesig_label": "Time Signature (optional)",
111
+ "timesig_info": "2/4, 3/4, 4/4...",
112
+ "duration_label": "Audio Duration (seconds)",
113
+ "duration_info": "Use -1 for random",
114
+ "batch_size_label": "Batch Size",
115
+ "batch_size_info": "Number of audio to generate (max 8)",
116
+ "advanced_settings": "🔧 Advanced Settings",
117
+ "inference_steps_label": "DiT Inference Steps",
118
+ "inference_steps_info": "Turbo: max 8, Base: max 200",
119
+ "guidance_scale_label": "DiT Guidance Scale (Only support for base model)",
120
+ "guidance_scale_info": "Higher values follow text more closely",
121
+ "seed_label": "Seed",
122
+ "seed_info": "Use comma-separated values for batches",
123
+ "random_seed_label": "Random Seed",
124
+ "random_seed_info": "Enable to auto-generate seeds",
125
+ "audio_format_label": "Audio Format",
126
+ "audio_format_info": "Audio format for saved files",
127
+ "use_adg_label": "Use ADG",
128
+ "use_adg_info": "Enable Angle Domain Guidance",
129
+ "shift_label": "Shift",
130
+ "shift_info": "Timestep shift factor for base models (range 1.0~5.0, default 3.0). Not effective for turbo models.",
131
+ "infer_method_label": "Inference Method",
132
+ "infer_method_info": "Diffusion inference method. ODE (Euler) is faster, SDE (stochastic) may produce different results.",
133
+ "custom_timesteps_label": "Custom Timesteps",
134
+ "custom_timesteps_info": "Optional: comma-separated values from 1.0 to 0.0 (e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference steps and shift.",
135
+ "cfg_interval_start": "CFG Interval Start",
136
+ "cfg_interval_end": "CFG Interval End",
137
+ "lm_params_title": "🤖 LM Generation Parameters",
138
+ "lm_temperature_label": "LM Temperature",
139
+ "lm_temperature_info": "5Hz LM temperature (higher = more random)",
140
+ "lm_cfg_scale_label": "LM CFG Scale",
141
+ "lm_cfg_scale_info": "5Hz LM CFG (1.0 = no CFG)",
142
+ "lm_top_k_label": "LM Top-K",
143
+ "lm_top_k_info": "Top-K (0 = disabled)",
144
+ "lm_top_p_label": "LM Top-P",
145
+ "lm_top_p_info": "Top-P (1.0 = disabled)",
146
+ "lm_negative_prompt_label": "LM Negative Prompt",
147
+ "lm_negative_prompt_placeholder": "Enter negative prompt for CFG (default: NO USER INPUT)",
148
+ "lm_negative_prompt_info": "Negative prompt (use when LM CFG Scale > 1.0)",
149
+ "cot_metas_label": "CoT Metas",
150
+ "cot_metas_info": "Use LM to generate CoT metadata (uncheck to skip LM CoT generation)",
151
+ "cot_language_label": "CoT Language",
152
+ "cot_language_info": "Generate language in CoT (chain-of-thought)",
153
+ "constrained_debug_label": "Constrained Decoding Debug",
154
+ "constrained_debug_info": "Enable debug logging for constrained decoding (check to see detailed logs)",
155
+ "auto_score_label": "Auto Score",
156
+ "auto_score_info": "Automatically calculate quality scores for all generated audios",
157
+ "auto_lrc_label": "Auto LRC",
158
+ "auto_lrc_info": "Automatically generate LRC lyrics timestamps for all generated audios",
159
+ "lm_batch_chunk_label": "LM Batch Chunk Size",
160
+ "lm_batch_chunk_info": "Max items per LM batch chunk (default: 8, limited by GPU memory)",
161
+ "codes_strength_label": "LM Codes Strength",
162
+ "codes_strength_info": "Control how many denoising steps use LM-generated codes",
163
+ "cover_strength_label": "Audio Cover Strength",
164
+ "cover_strength_info": "Control how many denoising steps use cover mode",
165
+ "score_sensitivity_label": "Quality Score Sensitivity",
166
+ "score_sensitivity_info": "Lower = more sensitive (default: 1.0). Adjusts how PMI maps to [0,1]",
167
+ "think_label": "Think",
168
+ "parallel_thinking_label": "ParallelThinking",
169
+ "generate_btn": "🎵 Generate Music",
170
+ "autogen_label": "AutoGen",
171
+ "caption_rewrite_label": "CaptionRewrite"
172
+ },
173
+ "results": {
174
+ "title": "🎵 Results",
175
+ "generated_music": "🎵 Generated Music (Sample {n})",
176
+ "send_to_src_btn": "🔗 Send To Src Audio",
177
+ "send_to_cover_btn": "🔗 Send To Cover",
178
+ "send_to_repaint_btn": "🔗 Send To Repaint",
179
+ "save_btn": "💾 Save",
180
+ "score_btn": "📊 Score",
181
+ "lrc_btn": "🎵 LRC",
182
+ "quality_score_label": "Quality Score (Sample {n})",
183
+ "quality_score_placeholder": "Click 'Score' to calculate perplexity-based quality score",
184
+ "codes_label": "LM Codes (Sample {n})",
185
+ "lrc_label": "Lyrics Timestamps (Sample {n})",
186
+ "lrc_placeholder": "Click 'LRC' to generate timestamps",
187
+ "details_accordion": "📊 Score & LRC & LM Codes",
188
+ "generation_status": "Generation Status",
189
+ "current_batch": "Current Batch",
190
+ "batch_indicator": "Batch {current} / {total}",
191
+ "next_batch_status": "Next Batch Status",
192
+ "prev_btn": "◀ Previous",
193
+ "next_btn": "Next ▶",
194
+ "restore_params_btn": "↙️ Apply These Settings to UI (Restore Batch Parameters)",
195
+ "batch_results_title": "📁 Batch Results & Generation Details",
196
+ "all_files_label": "📁 All Generated Files (Download)",
197
+ "generation_details": "Generation Details"
198
+ },
199
+ "messages": {
200
+ "no_audio_to_save": "❌ No audio to save",
201
+ "save_success": "✅ Saved audio and metadata to {filename}",
202
+ "save_failed": "❌ Failed to save: {error}",
203
+ "no_file_selected": "⚠️ No file selected",
204
+ "params_loaded": "✅ Parameters loaded from {filename}",
205
+ "invalid_json": "❌ Invalid JSON file: {error}",
206
+ "load_error": "❌ Error loading file: {error}",
207
+ "example_loaded": "📁 Loaded example from {filename}",
208
+ "example_failed": "Failed to parse JSON file {filename}: {error}",
209
+ "example_error": "Error loading example: {error}",
210
+ "lm_generated": "🤖 Generated example using LM",
211
+ "lm_fallback": "Failed to generate example using LM, falling back to examples directory",
212
+ "lm_not_initialized": "❌ 5Hz LM not initialized. Please initialize it first.",
213
+ "autogen_enabled": "🔄 AutoGen enabled - next batch will generate after this",
214
+ "batch_ready": "✅ Batch {n} ready! Click 'Next' to view.",
215
+ "batch_generating": "🔄 Starting background generation for Batch {n}...",
216
+ "batch_failed": "❌ Background generation failed: {error}",
217
+ "viewing_batch": "✅ Viewing Batch {n}",
218
+ "at_first_batch": "Already at first batch",
219
+ "at_last_batch": "No next batch available",
220
+ "batch_not_found": "Batch {n} not found in queue",
221
+ "no_batch_data": "No batch data found to restore.",
222
+ "params_restored": "✅ UI Parameters restored from Batch {n}",
223
+ "scoring_failed": "❌ Error: Batch data not found",
224
+ "no_codes": "❌ No audio codes available. Please generate music first.",
225
+ "score_failed": "❌ Scoring failed: {error}",
226
+ "score_error": "❌ Error calculating score: {error}",
227
+ "lrc_no_batch_data": "❌ No batch data found. Please generate music first.",
228
+ "lrc_no_extra_outputs": "❌ No extra outputs found. Condition tensors not available.",
229
+ "lrc_missing_tensors": "❌ Missing required tensors for LRC generation.",
230
+ "lrc_sample_not_exist": "❌ Sample does not exist in current batch.",
231
+ "lrc_empty_result": "⚠️ LRC generation produced empty result.",
232
+ "empty_query": "⚠️ Please enter a music description.",
233
+ "sample_creation_failed": "❌ Failed to create sample. Please try again.",
234
+ "sample_created": "✅ Sample created! Review the caption and lyrics, then click Generate Music.",
235
+ "simple_examples_not_found": "⚠️ Simple mode examples directory not found.",
236
+ "simple_examples_empty": "⚠️ No example files found in simple mode examples.",
237
+ "simple_example_loaded": "🎲 Loaded random example from {filename}",
238
+ "format_success": "✅ Caption and lyrics formatted successfully",
239
+ "format_failed": "❌ Format failed: {error}",
240
+ "skipping_metas_cot": "⚡ Skipping Phase 1 metas COT (sample already formatted)",
241
+ "invalid_timesteps_format": "⚠️ Invalid timesteps format. Using default schedule.",
242
+ "timesteps_out_of_range": "⚠️ Timesteps must be in range [0, 1]. Using default schedule.",
243
+ "timesteps_count_mismatch": "⚠️ Timesteps count ({actual}) differs from inference_steps ({expected}). Using timesteps count."
244
+ }
245
+ }
spaces/Ace-Step-v1.5/acestep/gradio_ui/i18n/ja.json ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "app": {
3
+ "title": "🎛️ ACE-Step V1.5 プレイグラウンド💡",
4
+ "subtitle": "オープンソース音楽生成の限界を押し広げる"
5
+ },
6
+ "dataset": {
7
+ "title": "📊 データセットエクスプローラー",
8
+ "dataset_label": "データセット",
9
+ "dataset_info": "探索するデータセットを選択",
10
+ "import_btn": "📥 データセットをインポート",
11
+ "search_type_label": "検索タイプ",
12
+ "search_type_info": "アイテムの検索方法",
13
+ "search_value_label": "検索値",
14
+ "search_value_placeholder": "キーまたはインデックスを入力(空白の場合はランダム)",
15
+ "search_value_info": "キー: 完全一致、インデックス: 0からデータセットサイズ-1",
16
+ "instruction_label": "📝 指示",
17
+ "instruction_placeholder": "利用可能な指示がありません",
18
+ "metadata_title": "📋 アイテムメタデータ (JSON)",
19
+ "metadata_label": "完全なアイテム情報",
20
+ "source_audio": "ソースオーディオ",
21
+ "target_audio": "ターゲットオーディオ",
22
+ "reference_audio": "リファレンスオーディオ",
23
+ "get_item_btn": "🔍 アイテムを取得",
24
+ "use_src_checkbox": "データセットのソースオーディオを使用",
25
+ "use_src_info": "データセットのソースオーディオを使用する場合はチェック",
26
+ "data_status_label": "📊 データステータス",
27
+ "data_status_default": "❌ データセットがインポートされていません",
28
+ "autofill_btn": "📋 生成フォームを自動入力"
29
+ },
30
+ "service": {
31
+ "title": "🔧 サービス設定",
32
+ "checkpoint_label": "チェックポイントファイル",
33
+ "checkpoint_info": "訓練済みモデルのチェックポイントファイルを選択(フルパスまたはファイル名)",
34
+ "refresh_btn": "🔄 更新",
35
+ "model_path_label": "メインモデルパス",
36
+ "model_path_info": "モデル設定ディレクトリを選択(チェックポイントから自動スキャン)",
37
+ "device_label": "デバイス",
38
+ "device_info": "処理デバイス(自動検出を推奨)",
39
+ "lm_model_path_label": "5Hz LM モデルパス",
40
+ "lm_model_path_info": "5Hz LMモデルチェックポイントを選択(チェックポイントから自動スキャン)",
41
+ "backend_label": "5Hz LM バックエンド",
42
+ "backend_info": "5Hz LMのバックエンドを選択: vllm(高速)またはpt(PyTorch、より互換性あり)",
43
+ "init_llm_label": "5Hz LM を初期化",
44
+ "init_llm_info": "サービス初期化中に5Hz LMを初期化する場合はチェック",
45
+ "flash_attention_label": "Flash Attention を使用",
46
+ "flash_attention_info_enabled": "推論を高速化するためにflash attentionを有効にする(flash_attnパッケージが必要)",
47
+ "flash_attention_info_disabled": "Flash attentionは利用できません(flash_attnパッケージがインストールされていません)",
48
+ "offload_cpu_label": "CPUにオフロード",
49
+ "offload_cpu_info": "使用していない時にモデルをCPUにオフロードしてGPUメモリを節約",
50
+ "offload_dit_cpu_label": "DiTをCPUにオフロード",
51
+ "offload_dit_cpu_info": "DiTをCPUにオフロード(CPUへのオフロードが必要)",
52
+ "init_btn": "サービスを初期化",
53
+ "status_label": "ステータス",
54
+ "language_label": "UI言語",
55
+ "language_info": "インターフェース言語を選択"
56
+ },
57
+ "generation": {
58
+ "required_inputs": "📝 必須入力",
59
+ "task_type_label": "タスクタイプ",
60
+ "task_type_info": "生成のタスクタイプを選択",
61
+ "instruction_label": "指示",
62
+ "instruction_info": "指示はタスクタイプに基づいて自動生成されます",
63
+ "load_btn": "読み込む",
64
+ "track_name_label": "トラック名",
65
+ "track_name_info": "lego/extractタスクのトラック名を選択",
66
+ "track_classes_label": "トラック名",
67
+ "track_classes_info": "completeタスクの複数のトラッククラスを選択",
68
+ "audio_uploads": "🎵 オーディオアップロード",
69
+ "reference_audio": "リファレンスオーディオ(オプション)",
70
+ "source_audio": "ソースオーディオ(オプション)",
71
+ "convert_codes_btn": "コードに変換",
72
+ "lm_codes_hints": "🎼 LM コードヒント",
73
+ "lm_codes_label": "LM コードヒント",
74
+ "lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
75
+ "lm_codes_info": "text2music生成用のLMコードヒントを貼り付け",
76
+ "lm_codes_sample": "LM コードヒント(サンプル {n})",
77
+ "lm_codes_sample_info": "サンプル{n}のコード",
78
+ "transcribe_btn": "転写",
79
+ "repainting_controls": "🎨 再描画コントロール(秒)",
80
+ "repainting_start": "再描画開始",
81
+ "repainting_end": "再描画終了",
82
+ "mode_label": "生成モード",
83
+ "mode_info": "シンプル:自然言語で音楽を説明��カスタム:キャプションと歌詞を完全にコントロール。",
84
+ "mode_simple": "シンプル",
85
+ "mode_custom": "カスタム",
86
+ "simple_query_label": "曲の説明",
87
+ "simple_query_placeholder": "作成したい音楽を説明してください。例:'静かな夜のための優しいベンガルのラブソング'。空欄の場合はランダムなサンプルが生成されます。",
88
+ "simple_query_info": "生成したい音楽の自然言語の説明を入力",
89
+ "simple_vocal_language_label": "ボーカル言語(オプション)",
90
+ "simple_vocal_language_info": "歌詞の希望言語を選択。任意の言語の場合は'unknown'を使用。",
91
+ "create_sample_btn": "サンプル作成",
92
+ "caption_title": "📝 音楽キャプション",
93
+ "caption_label": "音楽キャプション(オプション)",
94
+ "caption_placeholder": "柔らかいボーカルを伴う穏やかなアコースティックギターのメロディー...",
95
+ "caption_info": "スタイル、ジャンル、楽器、ムードを説明",
96
+ "lyrics_title": "📝 歌詞",
97
+ "lyrics_label": "歌詞(オプション)",
98
+ "lyrics_placeholder": "[バース1]\\n星空の下で\\nとても生きていると感じる...",
99
+ "lyrics_info": "構造を持つ曲の歌詞",
100
+ "instrumental_label": "インストゥルメンタル",
101
+ "format_btn": "フォーマット",
102
+ "optional_params": "⚙️ オプションパラメータ",
103
+ "vocal_language_label": "ボーカル言語(オプション)",
104
+ "vocal_language_info": "インストには`unknown`を使用",
105
+ "bpm_label": "BPM(オプション)",
106
+ "bpm_info": "空白の場合はN/A",
107
+ "keyscale_label": "キースケール(オプション)",
108
+ "keyscale_placeholder": "空白の場合はN/A",
109
+ "keyscale_info": "A-G, #/♭, メジャー/マイナー",
110
+ "timesig_label": "拍子記号(オプション)",
111
+ "timesig_info": "2/4, 3/4, 4/4...",
112
+ "duration_label": "オーディオ長(秒)",
113
+ "duration_info": "ランダムの場合は-1を使用",
114
+ "batch_size_label": "バッチサイズ",
115
+ "batch_size_info": "生成するオーディオの数(最大8)",
116
+ "advanced_settings": "🔧 詳細設定",
117
+ "inference_steps_label": "DiT 推論ステップ",
118
+ "inference_steps_info": "Turbo: 最大8、Base: 最大200",
119
+ "guidance_scale_label": "DiT ガイダンススケール(baseモデルのみサポート)",
120
+ "guidance_scale_info": "値が高いほどテキストに忠実に従う",
121
+ "seed_label": "シード",
122
+ "seed_info": "バッチにはカンマ区切りの値を使用",
123
+ "random_seed_label": "ランダムシード",
124
+ "random_seed_info": "有効にすると自動的にシードを生成",
125
+ "audio_format_label": "オーディオフォーマット",
126
+ "audio_format_info": "保存ファイルのオーディオフォーマット",
127
+ "use_adg_label": "ADG を使用",
128
+ "use_adg_info": "角度ドメインガイダンスを有効化",
129
+ "shift_label": "シフト",
130
+ "shift_info": "baseモデル用タイムステップシフト係数 (範囲 1.0~5.0、デフォルト 3.0)。turboモデルには無効。",
131
+ "infer_method_label": "推論方法",
132
+ "infer_method_info": "拡散推論方法。ODE (オイラー) は高速、SDE (確率的) は異なる結果を生成する可能性があります。",
133
+ "custom_timesteps_label": "カスタムタイムステップ",
134
+ "custom_timesteps_info": "オプション:1.0から0.0へのカンマ区切り値(例:'0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')。推論ステップとシフトを上書きします。",
135
+ "cfg_interval_start": "CFG 間隔開始",
136
+ "cfg_interval_end": "CFG 間隔終了",
137
+ "lm_params_title": "🤖 LM 生成パラメータ",
138
+ "lm_temperature_label": "LM 温度",
139
+ "lm_temperature_info": "5Hz LM温度(高いほどランダム)",
140
+ "lm_cfg_scale_label": "LM CFG スケール",
141
+ "lm_cfg_scale_info": "5Hz LM CFG (1.0 = CFGなし)",
142
+ "lm_top_k_label": "LM Top-K",
143
+ "lm_top_k_info": "Top-K (0 = 無効)",
144
+ "lm_top_p_label": "LM Top-P",
145
+ "lm_top_p_info": "Top-P (1.0 = 無効)",
146
+ "lm_negative_prompt_label": "LM ネガティブプロンプト",
147
+ "lm_negative_prompt_placeholder": "CFGのネガティブプロンプトを入力(デフォルト: NO USER INPUT)",
148
+ "lm_negative_prompt_info": "ネガティブプロンプト(LM CFGスケール > 1.0の場合に使用)",
149
+ "cot_metas_label": "CoT メタデータ",
150
+ "cot_metas_info": "LMを使用してCoTメタデータを生成(チェックを外すとLM CoT生成をスキップ)",
151
+ "cot_language_label": "CoT 言語",
152
+ "cot_language_info": "CoTで言語を生成(思考の連鎖)",
153
+ "constrained_debug_label": "制約付きデコーディングデバッグ",
154
+ "constrained_debug_info": "制約付きデコーディングのデバッグログを有効化(チェックすると詳細ログを表示)",
155
+ "auto_score_label": "自動スコアリング",
156
+ "auto_score_info": "生成���れたすべてのオーディオの品質スコアを自動計算",
157
+ "auto_lrc_label": "自動 LRC",
158
+ "auto_lrc_info": "生成されたすべてのオーディオのLRC歌詞タイムスタンプを自動生成",
159
+ "lm_batch_chunk_label": "LM バッチチャンクサイズ",
160
+ "lm_batch_chunk_info": "LMバッチチャンクあたりの最大アイテム数(デフォルト: 8、GPUメモリによる制限)",
161
+ "codes_strength_label": "LM コード強度",
162
+ "codes_strength_info": "LM生成コードを使用するデノイジングステップ数を制御",
163
+ "cover_strength_label": "オーディオカバー強度",
164
+ "cover_strength_info": "カバーモードを使用するデノイジングステップ数を制御",
165
+ "score_sensitivity_label": "品質スコア感度",
166
+ "score_sensitivity_info": "低い = より敏感(デフォルト: 1.0)。PMIが[0,1]にマッピングする方法を調整",
167
+ "think_label": "思考",
168
+ "parallel_thinking_label": "並列思考",
169
+ "generate_btn": "🎵 音楽を生成",
170
+ "autogen_label": "自動生成",
171
+ "caption_rewrite_label": "キャプション書き換え"
172
+ },
173
+ "results": {
174
+ "title": "🎵 結果",
175
+ "generated_music": "🎵 生成された音楽(サンプル {n})",
176
+ "send_to_src_btn": "🔗 ソースオーディオに送信",
177
+ "send_to_cover_btn": "🔗 Send To Cover",
178
+ "send_to_repaint_btn": "🔗 Send To Repaint",
179
+ "save_btn": "💾 保存",
180
+ "score_btn": "📊 スコア",
181
+ "lrc_btn": "🎵 LRC",
182
+ "quality_score_label": "品質スコア(サンプル {n})",
183
+ "quality_score_placeholder": "'スコア'をクリックしてパープレキシティベースの品質スコアを計算",
184
+ "codes_label": "LM コード(サンプル {n})",
185
+ "lrc_label": "歌詞タイムスタンプ(サンプル {n})",
186
+ "lrc_placeholder": "'LRC'をクリックしてタイムスタンプを生成",
187
+ "details_accordion": "📊 スコア & LRC & LM コード",
188
+ "generation_status": "生成ステータス",
189
+ "current_batch": "現在のバッチ",
190
+ "batch_indicator": "バッチ {current} / {total}",
191
+ "next_batch_status": "次のバッチステータス",
192
+ "prev_btn": "◀ 前へ",
193
+ "next_btn": "次へ ▶",
194
+ "restore_params_btn": "↙️ これらの設定をUIに適用(バッチパラメータを復元)",
195
+ "batch_results_title": "📁 バッチ結果と生成詳細",
196
+ "all_files_label": "📁 すべての生成ファイル(ダウンロード)",
197
+ "generation_details": "生成詳細"
198
+ },
199
+ "messages": {
200
+ "no_audio_to_save": "❌ 保存するオーディオがありません",
201
+ "save_success": "✅ オーディオとメタデータを {filename} に保存しました",
202
+ "save_failed": "❌ 保存に失敗しました: {error}",
203
+ "no_file_selected": "⚠️ ファイルが選択されていません",
204
+ "params_loaded": "✅ {filename} からパラメータを読み込みました",
205
+ "invalid_json": "❌ 無効なJSONファイル: {error}",
206
+ "load_error": "❌ ファイルの読み込みエラー: {error}",
207
+ "example_loaded": "📁 {filename} からサンプルを読み込みました",
208
+ "example_failed": "JSONファイル {filename} の解析に失敗しました: {error}",
209
+ "example_error": "サンプル読み込みエラー: {error}",
210
+ "lm_generated": "🤖 LMを使用してサンプルを生成しました",
211
+ "lm_fallback": "LMを使用したサンプル生成に失敗、サンプルディレクトリにフォールバック",
212
+ "lm_not_initialized": "❌ 5Hz LMが初期化されていません。最初に初期化してください。",
213
+ "autogen_enabled": "🔄 自動生成が有効 - このあと次のバッチを生成します",
214
+ "batch_ready": "✅ バッチ {n} の準備完了!'次へ'をクリックして表示。",
215
+ "batch_generating": "🔄 バッチ {n} のバックグラウンド生成を開始...",
216
+ "batch_failed": "❌ バックグラウンド生成に失敗しました: {error}",
217
+ "viewing_batch": "✅ バッチ {n} を表示中",
218
+ "at_first_batch": "すでに最初のバッチです",
219
+ "at_last_batch": "次のバッチはありません",
220
+ "batch_not_found": "キューにバッチ {n} が見つかりません",
221
+ "no_batch_data": "復元するバッチデータがありません。",
222
+ "params_restored": "✅ バッチ {n} からUIパラメータを復元しました",
223
+ "scoring_failed": "❌ エラー: バッチデータが見つかりません",
224
+ "no_codes": "❌ 利用可能なオーディオコードがありません。最初に音楽を生成してください。",
225
+ "score_failed": "❌ スコアリングに失敗しました: {error}",
226
+ "score_error": "❌ スコア計算エラー: {error}",
227
+ "lrc_no_batch_data": "❌ バッチデータが見つかりません。最初に音楽を生成してください。",
228
+ "lrc_no_extra_outputs": "❌ 追加出力が見つかりません。条件テンソルが利用できません。",
229
+ "lrc_missing_tensors": "❌ LRC生成に必要なテンソルがありません。",
230
+ "lrc_sample_not_exist": "❌ 現在のバッチにサンプルが存在しません。",
231
+ "lrc_empty_result": "⚠️ LRC生成の結果が空です。",
232
+ "empty_query": "⚠️ 音楽の説明を入力してください。",
233
+ "sample_creation_failed": "❌ サンプルの作成に失敗しました。もう一度お試しください。",
234
+ "sample_created": "✅ サンプルが作成されました!キャプションと歌詞を確認して、音楽を生成をクリックしてください。",
235
+ "simple_examples_not_found": "⚠️ シンプルモードサンプルディレクトリが見つかりません。",
236
+ "simple_examples_empty": "⚠️ シンプルモードサンプルにファイルがありません。",
237
+ "simple_example_loaded": "🎲 {filename} からランダムサンプルを読み込みました",
238
+ "format_success": "✅ キャプションと歌詞のフォーマットに成功しました",
239
+ "format_failed": "❌ フォーマットに失敗しました: {error}",
240
+ "skipping_metas_cot": "⚡ Phase 1 メタデータ COT をスキップ(サンプルは既にフォーマット済み)",
241
+ "invalid_timesteps_format": "⚠️ タイムステップ形式が無効です。デフォルトスケジュールを使用します。",
242
+ "timesteps_out_of_range": "⚠️ タイムステップは [0, 1] の範囲内である必要があります。デフォルトスケジュールを使用します。",
243
+ "timesteps_count_mismatch": "⚠️ タイムステップ数 ({actual}) が推論ステップ数 ({expected}) と異なります。タイムステップ数を使用します。"
244
+ }
245
+ }
spaces/Ace-Step-v1.5/acestep/gradio_ui/i18n/zh.json ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "app": {
3
+ "title": "🎛️ ACE-Step V1.5 演练场💡",
4
+ "subtitle": "推动开源音乐生成的边界"
5
+ },
6
+ "dataset": {
7
+ "title": "📊 数据集浏览器",
8
+ "dataset_label": "数据集",
9
+ "dataset_info": "选择要浏览的数据集",
10
+ "import_btn": "📥 导入数据集",
11
+ "search_type_label": "搜索类型",
12
+ "search_type_info": "如何查找项目",
13
+ "search_value_label": "搜索值",
14
+ "search_value_placeholder": "输入键或索引(留空表示随机)",
15
+ "search_value_info": "键: 精确匹配, 索引: 0到数据集大小-1",
16
+ "instruction_label": "📝 指令",
17
+ "instruction_placeholder": "无可用指令",
18
+ "metadata_title": "📋 项目元数据 (JSON)",
19
+ "metadata_label": "完整项目信息",
20
+ "source_audio": "源音频",
21
+ "target_audio": "目标音频",
22
+ "reference_audio": "参考音频",
23
+ "get_item_btn": "🔍 获取项目",
24
+ "use_src_checkbox": "使用数据集中的源音频",
25
+ "use_src_info": "勾选以使用数据集中的源音频",
26
+ "data_status_label": "📊 数据状态",
27
+ "data_status_default": "❌ 未导入数据集",
28
+ "autofill_btn": "📋 自动填充生成表单"
29
+ },
30
+ "service": {
31
+ "title": "🔧 服务配置",
32
+ "checkpoint_label": "检查点文件",
33
+ "checkpoint_info": "选择训练好的模型检查点文件(完整路径或文件名)",
34
+ "refresh_btn": "🔄 刷新",
35
+ "model_path_label": "主模型路径",
36
+ "model_path_info": "选择模型配置目录(从检查点自动扫描)",
37
+ "device_label": "设备",
38
+ "device_info": "处理设备(建议自动检测)",
39
+ "lm_model_path_label": "5Hz LM 模型路径",
40
+ "lm_model_path_info": "选择5Hz LM模型检查点(从检查点自动扫描)",
41
+ "backend_label": "5Hz LM 后端",
42
+ "backend_info": "选择5Hz LM的后端: vllm(更快)或pt(PyTorch, 更兼容)",
43
+ "init_llm_label": "初始化 5Hz LM",
44
+ "init_llm_info": "勾选以在服务初始化期间初始化5Hz LM",
45
+ "flash_attention_label": "使用Flash Attention",
46
+ "flash_attention_info_enabled": "启用flash attention以加快推理速度(需要flash_attn包)",
47
+ "flash_attention_info_disabled": "Flash attention不可用(未安装flash_attn包)",
48
+ "offload_cpu_label": "卸载到CPU",
49
+ "offload_cpu_info": "不使用时将模型卸载到CPU以节省GPU内存",
50
+ "offload_dit_cpu_label": "将DiT卸载到CPU",
51
+ "offload_dit_cpu_info": "将DiT卸载到CPU(需要启用卸载到CPU)",
52
+ "init_btn": "初始化服务",
53
+ "status_label": "状态",
54
+ "language_label": "界面语言",
55
+ "language_info": "选择界面语言"
56
+ },
57
+ "generation": {
58
+ "required_inputs": "📝 必需输入",
59
+ "task_type_label": "任务类型",
60
+ "task_type_info": "选择生成的任务类型",
61
+ "instruction_label": "指令",
62
+ "instruction_info": "指令根据任务类型自动生成",
63
+ "load_btn": "加载",
64
+ "track_name_label": "音轨名称",
65
+ "track_name_info": "为lego/extract任务选择音轨名称",
66
+ "track_classes_label": "音轨名称",
67
+ "track_classes_info": "为complete任务选择多个音轨类别",
68
+ "audio_uploads": "🎵 音频上传",
69
+ "reference_audio": "参考音频(可选)",
70
+ "source_audio": "源音频(可选)",
71
+ "convert_codes_btn": "转换为代码",
72
+ "lm_codes_hints": "🎼 LM 代码提示",
73
+ "lm_codes_label": "LM 代码提示",
74
+ "lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
75
+ "lm_codes_info": "粘贴用于text2music生成的LM代码提示",
76
+ "lm_codes_sample": "LM 代码提示(样本 {n})",
77
+ "lm_codes_sample_info": "样本{n}的代码",
78
+ "transcribe_btn": "转录",
79
+ "repainting_controls": "🎨 重绘控制(秒)",
80
+ "repainting_start": "重绘开始",
81
+ "repainting_end": "重绘结束",
82
+ "mode_label": "生成模式",
83
+ "mode_info": "简单模式:用自然语言描述音乐。自定义模式:完全控制描述和歌词。",
84
+ "mode_simple": "简单",
85
+ "mode_custom": "自定义",
86
+ "simple_query_label": "歌曲描述",
87
+ "simple_query_placeholder": "描述你想创作的音乐,例如:'给我生成一首暗黑的戏剧古风,歌词要华丽'。留空则随机生成样本。",
88
+ "simple_query_info": "输入你想生成的音乐的自然语言描述",
89
+ "simple_vocal_language_label": "人声语言(可选)",
90
+ "simple_vocal_language_info": "选择歌词的首选语言。使用 'unknown' 表示任意语言。",
91
+ "create_sample_btn": "创建样本",
92
+ "caption_title": "📝 音乐描述",
93
+ "caption_label": "音乐描述(可选)",
94
+ "caption_placeholder": "一段平和的原声吉他旋律,配有柔和的人声...",
95
+ "caption_info": "描述风格、流派、乐器和情绪",
96
+ "lyrics_title": "📝 歌词",
97
+ "lyrics_label": "歌词(可选)",
98
+ "lyrics_placeholder": "[第一段]\\n在星空下\\n我感到如此活跃...",
99
+ "lyrics_info": "带有结构的歌曲歌词",
100
+ "instrumental_label": "纯音乐",
101
+ "format_btn": "格式化",
102
+ "optional_params": "⚙️ 可选参数",
103
+ "vocal_language_label": "人声语言(可选)",
104
+ "vocal_language_info": "纯音乐使用 `unknown`",
105
+ "bpm_label": "BPM(可选)",
106
+ "bpm_info": "留空表示N/A",
107
+ "keyscale_label": "调性(可选)",
108
+ "keyscale_placeholder": "留空表示N/A",
109
+ "keyscale_info": "A-G, #/♭, 大调/小调",
110
+ "timesig_label": "拍号(可选)",
111
+ "timesig_info": "2/4, 3/4, 4/4...",
112
+ "duration_label": "音频时长(秒)",
113
+ "duration_info": "使用-1表示随机",
114
+ "batch_size_label": "批量大小",
115
+ "batch_size_info": "要生成的音频数量(最多8个)",
116
+ "advanced_settings": "🔧 高级设置",
117
+ "inference_steps_label": "DiT 推理步数",
118
+ "inference_steps_info": "Turbo: 最多8, Base: 最多200",
119
+ "guidance_scale_label": "DiT 引导比例(仅支持base模型)",
120
+ "guidance_scale_info": "更高的值更紧密地遵循文本",
121
+ "seed_label": "种子",
122
+ "seed_info": "批量使用逗号分隔的值",
123
+ "random_seed_label": "随机种子",
124
+ "random_seed_info": "启用以自动生成种子",
125
+ "audio_format_label": "音频格式",
126
+ "audio_format_info": "保存文件的音频格式",
127
+ "use_adg_label": "使用 ADG",
128
+ "use_adg_info": "启用角域引导",
129
+ "shift_label": "Shift",
130
+ "shift_info": "时间步偏移因子,仅对 base 模型生效 (范围 1.0~5.0,默认 3.0)。对 turbo 模型无效。",
131
+ "infer_method_label": "推理方法",
132
+ "infer_method_info": "扩散推理方法。ODE (欧拉) 更快,SDE (随机) 可能产生不同结果。",
133
+ "custom_timesteps_label": "自定义时间步",
134
+ "custom_timesteps_info": "可选:从 1.0 到 0.0 的逗号分隔值(例如 '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')。会覆盖推理步数和 shift 设置。",
135
+ "cfg_interval_start": "CFG 间隔开始",
136
+ "cfg_interval_end": "CFG 间隔结束",
137
+ "lm_params_title": "🤖 LM 生成参数",
138
+ "lm_temperature_label": "LM 温度",
139
+ "lm_temperature_info": "5Hz LM温度(越高越随机)",
140
+ "lm_cfg_scale_label": "LM CFG 比例",
141
+ "lm_cfg_scale_info": "5Hz LM CFG (1.0 = 无CFG)",
142
+ "lm_top_k_label": "LM Top-K",
143
+ "lm_top_k_info": "Top-K (0 = 禁用)",
144
+ "lm_top_p_label": "LM Top-P",
145
+ "lm_top_p_info": "Top-P (1.0 = 禁用)",
146
+ "lm_negative_prompt_label": "LM 负面提示",
147
+ "lm_negative_prompt_placeholder": "输入CFG的负面提示(默认: NO USER INPUT)",
148
+ "lm_negative_prompt_info": "负面提示(当LM CFG比例 > 1.0时使用)",
149
+ "cot_metas_label": "CoT 元数据",
150
+ "cot_metas_info": "使用LM生成CoT元数据(取消勾选以跳过LM CoT生成)",
151
+ "cot_language_label": "CoT 语言",
152
+ "cot_language_info": "在CoT中生成语言(思维链)",
153
+ "constrained_debug_label": "约束解码调试",
154
+ "constrained_debug_info": "启用约束解码的调试日志(勾选以查看详细日志)",
155
+ "auto_score_label": "自动评分",
156
+ "auto_score_info": "自动计算所有生成音频的质量分数",
157
+ "auto_lrc_label": "自动 LRC",
158
+ "auto_lrc_info": "自动为所有生成的音频生成LRC歌词时间戳",
159
+ "lm_batch_chunk_label": "LM 批量块大小",
160
+ "lm_batch_chunk_info": "每个LM批量块的最大项目数(默认: 8, 受GPU内存限制)",
161
+ "codes_strength_label": "LM 代码强度",
162
+ "codes_strength_info": "控制使用LM生成代码的去噪步骤数量",
163
+ "cover_strength_label": "音频覆盖强度",
164
+ "cover_strength_info": "控制使用覆盖模式的去噪步骤数量",
165
+ "score_sensitivity_label": "质量评分敏感度",
166
+ "score_sensitivity_info": "更低 = 更敏感(默认: 1.0). 调整PMI如何映射到[0,1]",
167
+ "think_label": "思考",
168
+ "parallel_thinking_label": "并行思考",
169
+ "generate_btn": "🎵 生成音乐",
170
+ "autogen_label": "自动生成",
171
+ "caption_rewrite_label": "描述重写"
172
+ },
173
+ "results": {
174
+ "title": "🎵 结果",
175
+ "generated_music": "🎵 生成的音乐(样本 {n})",
176
+ "send_to_src_btn": "🔗 发送到源音频",
177
+ "send_to_cover_btn": "🔗 Send To Cover",
178
+ "send_to_repaint_btn": "🔗 Send To Repaint",
179
+ "save_btn": "💾 保存",
180
+ "score_btn": "📊 评分",
181
+ "lrc_btn": "🎵 LRC",
182
+ "quality_score_label": "质量分数(样本 {n})",
183
+ "quality_score_placeholder": "点击'评分'以计算基于困惑度的质量分数",
184
+ "codes_label": "LM 代码(样本 {n})",
185
+ "lrc_label": "歌词时间戳(样本 {n})",
186
+ "lrc_placeholder": "点击'LRC'生成时间戳",
187
+ "details_accordion": "📊 评分与LRC与LM代码",
188
+ "generation_status": "生成状态",
189
+ "current_batch": "当前批次",
190
+ "batch_indicator": "批次 {current} / {total}",
191
+ "next_batch_status": "下一批次状态",
192
+ "prev_btn": "◀ 上一个",
193
+ "next_btn": "下一个 ▶",
194
+ "restore_params_btn": "↙️ 将这些设置应用到UI(恢复批次参数)",
195
+ "batch_results_title": "📁 批量结果和生成详情",
196
+ "all_files_label": "📁 所有生成的文件(��载)",
197
+ "generation_details": "生成详情"
198
+ },
199
+ "messages": {
200
+ "no_audio_to_save": "❌ 没有要保存的音频",
201
+ "save_success": "✅ 已将音频和元数据保存到 {filename}",
202
+ "save_failed": "❌ 保存失败: {error}",
203
+ "no_file_selected": "⚠️ 未选择文件",
204
+ "params_loaded": "✅ 已从 {filename} 加载参数",
205
+ "invalid_json": "❌ 无效的JSON文件: {error}",
206
+ "load_error": "❌ 加载文件时出错: {error}",
207
+ "example_loaded": "📁 已从 {filename} 加载示例",
208
+ "example_failed": "解析JSON文件 {filename} 失败: {error}",
209
+ "example_error": "加载示例时出错: {error}",
210
+ "lm_generated": "🤖 使用LM生成的示例",
211
+ "lm_fallback": "使用LM生成示例失败,回退到示例目录",
212
+ "lm_not_initialized": "❌ 5Hz LM未初始化。请先初始化它。",
213
+ "autogen_enabled": "🔄 已启用自动生成 - 下一批次将在此之后生成",
214
+ "batch_ready": "✅ 批次 {n} 就绪!点击'下一个'查看。",
215
+ "batch_generating": "🔄 开始为批次 {n} 进行后台生成...",
216
+ "batch_failed": "❌ 后台生成失败: {error}",
217
+ "viewing_batch": "✅ 查看批次 {n}",
218
+ "at_first_batch": "已在第一批次",
219
+ "at_last_batch": "没有下一批次可用",
220
+ "batch_not_found": "在队列中未找到批次 {n}",
221
+ "no_batch_data": "没有要恢复的批次数据。",
222
+ "params_restored": "✅ 已从批次 {n} 恢复UI参数",
223
+ "scoring_failed": "❌ 错误: 未找到批次数据",
224
+ "no_codes": "❌ 没有可用的音频代码。请先生成音乐。",
225
+ "score_failed": "❌ 评分失败: {error}",
226
+ "score_error": "❌ 计算分数时出错: {error}",
227
+ "lrc_no_batch_data": "❌ 未找到批次数据。请先生成音乐。",
228
+ "lrc_no_extra_outputs": "❌ 未找到额外输出。条件张量不可用。",
229
+ "lrc_missing_tensors": "❌ 缺少LRC生成所需的张量。",
230
+ "lrc_sample_not_exist": "❌ 当前批次中不存在该样本。",
231
+ "lrc_empty_result": "⚠️ LRC生成结果为空。",
232
+ "empty_query": "⚠️ 请输入音乐描述。",
233
+ "sample_creation_failed": "❌ 创建样本失败。请重试。",
234
+ "sample_created": "✅ 样本已创建!检查描述和歌词,然后点击生成音乐。",
235
+ "simple_examples_not_found": "⚠️ 未找到简单模式示例目录。",
236
+ "simple_examples_empty": "⚠️ 简单模式示例中没有示例文件。",
237
+ "simple_example_loaded": "🎲 已从 {filename} 加载随机示例",
238
+ "format_success": "✅ 描述和歌词格式化成功",
239
+ "format_failed": "❌ 格式化失败: {error}",
240
+ "skipping_metas_cot": "⚡ 跳过 Phase 1 元数据 COT(样本已格式化)",
241
+ "invalid_timesteps_format": "⚠️ 时间步格式无效,使用默认调度。",
242
+ "timesteps_out_of_range": "⚠️ 时间步必须在 [0, 1] 范围内,使用默认调度。",
243
+ "timesteps_count_mismatch": "⚠️ 时间步数量 ({actual}) 与推理步数 ({expected}) 不匹配,将使用时间步数量。"
244
+ }
245
+ }
spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/__init__.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Components Module
3
+ Contains all Gradio interface component definitions and layouts
4
+ """
5
+ import gradio as gr
6
+ from acestep.gradio_ui.i18n import get_i18n, t
7
+ from acestep.gradio_ui.interfaces.dataset import create_dataset_section
8
+ from acestep.gradio_ui.interfaces.generation import create_generation_section
9
+ from acestep.gradio_ui.interfaces.result import create_results_section
10
+ from acestep.gradio_ui.interfaces.training import create_training_section
11
+ from acestep.gradio_ui.events import setup_event_handlers, setup_training_event_handlers
12
+
13
+
14
+ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=None, language='en') -> gr.Blocks:
15
+ """
16
+ Create Gradio interface
17
+
18
+ Args:
19
+ dit_handler: DiT handler instance
20
+ llm_handler: LM handler instance
21
+ dataset_handler: Dataset handler instance
22
+ init_params: Dictionary containing initialization parameters and state.
23
+ If None, service will not be pre-initialized.
24
+ language: UI language code ('en', 'zh', 'ja', default: 'en')
25
+
26
+ Returns:
27
+ Gradio Blocks instance
28
+ """
29
+ # Initialize i18n with selected language
30
+ i18n = get_i18n(language)
31
+
32
+ with gr.Blocks(
33
+ title=t("app.title"),
34
+ theme=gr.themes.Soft(),
35
+ css="""
36
+ .main-header {
37
+ text-align: center;
38
+ margin-bottom: 2rem;
39
+ }
40
+ .section-header {
41
+ background: linear-gradient(90deg, #4CAF50, #45a049);
42
+ color: white;
43
+ padding: 10px;
44
+ border-radius: 5px;
45
+ margin: 10px 0;
46
+ }
47
+ .lm-hints-row {
48
+ align-items: stretch;
49
+ }
50
+ .lm-hints-col {
51
+ display: flex;
52
+ }
53
+ .lm-hints-col > div {
54
+ flex: 1;
55
+ display: flex;
56
+ }
57
+ .lm-hints-btn button {
58
+ height: 100%;
59
+ width: 100%;
60
+ }
61
+ """
62
+ ) as demo:
63
+
64
+ gr.HTML(f"""
65
+ <div class="main-header">
66
+ <h1>{t("app.title")}</h1>
67
+ <p>{t("app.subtitle")}</p>
68
+ <p style="margin-top: 0.5rem;">
69
+ <a href="https://ace-step.github.io/ace-step-v1.5.github.io/" target="_blank">Project</a> |
70
+ <a href="https://huggingface.co/collections/ACE-Step/ace-step-15" target="_blank">Hugging Face</a> |
71
+ <a href="https://modelscope.cn/models/ACE-Step/ACE-Step-v1-5" target="_blank">ModelScope</a> |
72
+ <a href="https://huggingface.co/spaces/ACE-Step/Ace-Step-v1.5" target="_blank">Space Demo</a> |
73
+ <a href="https://discord.gg/PeWDxrkdj7" target="_blank">Discord</a> |
74
+ <a href="https://arxiv.org/abs/2506.00045" target="_blank">Technical Report</a>
75
+ </p>
76
+ </div>
77
+ """)
78
+
79
+ # Dataset Explorer Section
80
+ dataset_section = create_dataset_section(dataset_handler)
81
+
82
+ # Generation Section (pass init_params and language to support pre-initialization)
83
+ generation_section = create_generation_section(dit_handler, llm_handler, init_params=init_params, language=language)
84
+
85
+ # Results Section
86
+ results_section = create_results_section(dit_handler)
87
+
88
+ # Training Section (LoRA training and dataset builder)
89
+ # Pass init_params to support hiding in service mode
90
+ training_section = create_training_section(dit_handler, llm_handler, init_params=init_params)
91
+
92
+ # Connect event handlers (pass init_params for multi-model support)
93
+ setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section, init_params=init_params)
94
+
95
+ # Connect training event handlers
96
+ setup_training_event_handlers(demo, dit_handler, llm_handler, training_section)
97
+
98
+ return demo
spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/dataset.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Dataset Section Module
3
+ Contains dataset explorer section component definitions
4
+ """
5
+ import gradio as gr
6
+
7
+
8
+ def create_dataset_section(dataset_handler) -> dict:
9
+ """Create dataset explorer section"""
10
+ with gr.Accordion("📊 Dataset Explorer", open=False, visible=False):
11
+ with gr.Row(equal_height=True):
12
+ dataset_type = gr.Dropdown(
13
+ choices=["train", "test"],
14
+ value="train",
15
+ label="Dataset",
16
+ info="Choose dataset to explore",
17
+ scale=2
18
+ )
19
+ import_dataset_btn = gr.Button("📥 Import Dataset", variant="primary", scale=1)
20
+
21
+ search_type = gr.Dropdown(
22
+ choices=["keys", "idx", "random"],
23
+ value="random",
24
+ label="Search Type",
25
+ info="How to find items",
26
+ scale=1
27
+ )
28
+ search_value = gr.Textbox(
29
+ label="Search Value",
30
+ placeholder="Enter keys or index (leave empty for random)",
31
+ info="Keys: exact match, Index: 0 to dataset size-1",
32
+ scale=2
33
+ )
34
+
35
+ instruction_display = gr.Textbox(
36
+ label="📝 Instruction",
37
+ interactive=False,
38
+ placeholder="No instruction available",
39
+ lines=1
40
+ )
41
+
42
+ repaint_viz_plot = gr.Plot()
43
+
44
+ with gr.Accordion("📋 Item Metadata (JSON)", open=False):
45
+ item_info_json = gr.Code(
46
+ label="Complete Item Information",
47
+ language="json",
48
+ interactive=False,
49
+ lines=15
50
+ )
51
+
52
+ with gr.Row(equal_height=True):
53
+ item_src_audio = gr.Audio(
54
+ label="Source Audio",
55
+ type="filepath",
56
+ interactive=False,
57
+ scale=8
58
+ )
59
+ get_item_btn = gr.Button("🔍 Get Item", variant="secondary", interactive=False, scale=2)
60
+
61
+ with gr.Row(equal_height=True):
62
+ item_target_audio = gr.Audio(
63
+ label="Target Audio",
64
+ type="filepath",
65
+ interactive=False,
66
+ scale=8
67
+ )
68
+ item_refer_audio = gr.Audio(
69
+ label="Reference Audio",
70
+ type="filepath",
71
+ interactive=False,
72
+ scale=2
73
+ )
74
+
75
+ with gr.Row():
76
+ use_src_checkbox = gr.Checkbox(
77
+ label="Use Source Audio from Dataset",
78
+ value=True,
79
+ info="Check to use the source audio from dataset"
80
+ )
81
+
82
+ data_status = gr.Textbox(label="📊 Data Status", interactive=False, value="❌ No dataset imported")
83
+ auto_fill_btn = gr.Button("📋 Auto-fill Generation Form", variant="primary")
84
+
85
+ return {
86
+ "dataset_type": dataset_type,
87
+ "import_dataset_btn": import_dataset_btn,
88
+ "search_type": search_type,
89
+ "search_value": search_value,
90
+ "instruction_display": instruction_display,
91
+ "repaint_viz_plot": repaint_viz_plot,
92
+ "item_info_json": item_info_json,
93
+ "item_src_audio": item_src_audio,
94
+ "get_item_btn": get_item_btn,
95
+ "item_target_audio": item_target_audio,
96
+ "item_refer_audio": item_refer_audio,
97
+ "use_src_checkbox": use_src_checkbox,
98
+ "data_status": data_status,
99
+ "auto_fill_btn": auto_fill_btn,
100
+ }
101
+
spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/generation.py ADDED
@@ -0,0 +1,693 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Generation Section Module
3
+ Contains generation section component definitions - Simplified UI
4
+ """
5
+ import gradio as gr
6
+ from acestep.constants import (
7
+ VALID_LANGUAGES,
8
+ TRACK_NAMES,
9
+ TASK_TYPES_TURBO,
10
+ TASK_TYPES_BASE,
11
+ DEFAULT_DIT_INSTRUCTION,
12
+ )
13
+ from acestep.gradio_ui.i18n import t
14
+
15
+
16
+ def create_generation_section(dit_handler, llm_handler, init_params=None, language='en') -> dict:
17
+ """Create generation section with simplified UI
18
+
19
+ Args:
20
+ dit_handler: DiT handler instance
21
+ llm_handler: LM handler instance
22
+ init_params: Dictionary containing initialization parameters and state.
23
+ If None, service will not be pre-initialized.
24
+ language: UI language code ('en', 'zh', 'ja')
25
+ """
26
+ # Check if service is pre-initialized
27
+ service_pre_initialized = init_params is not None and init_params.get('pre_initialized', False)
28
+
29
+ # Check if running in service mode (restricted UI)
30
+ service_mode = init_params is not None and init_params.get('service_mode', False)
31
+
32
+ # Get current language from init_params if available
33
+ current_language = init_params.get('language', language) if init_params else language
34
+
35
+ # Get available models
36
+ available_dit_models = init_params.get('available_dit_models', []) if init_params else []
37
+ current_model_value = init_params.get('config_path', '') if init_params else ''
38
+ show_model_selector = len(available_dit_models) > 1
39
+
40
+ with gr.Group():
41
+ # ==================== Service Configuration (Hidden in service mode) ====================
42
+ accordion_open = not service_pre_initialized
43
+ accordion_visible = not service_pre_initialized
44
+ with gr.Accordion(t("service.title"), open=accordion_open, visible=accordion_visible) as service_config_accordion:
45
+ # Language selector at the top
46
+ with gr.Row():
47
+ language_dropdown = gr.Dropdown(
48
+ choices=[
49
+ ("English", "en"),
50
+ ("中文", "zh"),
51
+ ("日本語", "ja"),
52
+ ],
53
+ value=current_language,
54
+ label=t("service.language_label"),
55
+ info=t("service.language_info"),
56
+ scale=1,
57
+ )
58
+
59
+ with gr.Row(equal_height=True):
60
+ with gr.Column(scale=4):
61
+ checkpoint_value = init_params.get('checkpoint') if service_pre_initialized else None
62
+ checkpoint_dropdown = gr.Dropdown(
63
+ label=t("service.checkpoint_label"),
64
+ choices=dit_handler.get_available_checkpoints(),
65
+ value=checkpoint_value,
66
+ info=t("service.checkpoint_info")
67
+ )
68
+ with gr.Column(scale=1, min_width=90):
69
+ refresh_btn = gr.Button(t("service.refresh_btn"), size="sm")
70
+
71
+ with gr.Row():
72
+ available_models = dit_handler.get_available_acestep_v15_models()
73
+ default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)
74
+ config_path_value = init_params.get('config_path', default_model) if service_pre_initialized else default_model
75
+ config_path = gr.Dropdown(
76
+ label=t("service.model_path_label"),
77
+ choices=available_models,
78
+ value=config_path_value,
79
+ info=t("service.model_path_info")
80
+ )
81
+ device_value = init_params.get('device', 'auto') if service_pre_initialized else 'auto'
82
+ device = gr.Dropdown(
83
+ choices=["auto", "cuda", "cpu"],
84
+ value=device_value,
85
+ label=t("service.device_label"),
86
+ info=t("service.device_info")
87
+ )
88
+
89
+ with gr.Row():
90
+ available_lm_models = llm_handler.get_available_5hz_lm_models()
91
+ default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)
92
+ lm_model_path_value = init_params.get('lm_model_path', default_lm_model) if service_pre_initialized else default_lm_model
93
+ lm_model_path = gr.Dropdown(
94
+ label=t("service.lm_model_path_label"),
95
+ choices=available_lm_models,
96
+ value=lm_model_path_value,
97
+ info=t("service.lm_model_path_info")
98
+ )
99
+ backend_value = init_params.get('backend', 'vllm') if service_pre_initialized else 'vllm'
100
+ backend_dropdown = gr.Dropdown(
101
+ choices=["vllm", "pt"],
102
+ value=backend_value,
103
+ label=t("service.backend_label"),
104
+ info=t("service.backend_info")
105
+ )
106
+
107
+ with gr.Row():
108
+ init_llm_value = init_params.get('init_llm', True) if service_pre_initialized else True
109
+ init_llm_checkbox = gr.Checkbox(
110
+ label=t("service.init_llm_label"),
111
+ value=init_llm_value,
112
+ info=t("service.init_llm_info"),
113
+ )
114
+ flash_attn_available = dit_handler.is_flash_attention_available()
115
+ use_flash_attention_value = init_params.get('use_flash_attention', flash_attn_available) if service_pre_initialized else flash_attn_available
116
+ use_flash_attention_checkbox = gr.Checkbox(
117
+ label=t("service.flash_attention_label"),
118
+ value=use_flash_attention_value,
119
+ interactive=flash_attn_available,
120
+ info=t("service.flash_attention_info_enabled") if flash_attn_available else t("service.flash_attention_info_disabled")
121
+ )
122
+ offload_to_cpu_value = init_params.get('offload_to_cpu', False) if service_pre_initialized else False
123
+ offload_to_cpu_checkbox = gr.Checkbox(
124
+ label=t("service.offload_cpu_label"),
125
+ value=offload_to_cpu_value,
126
+ info=t("service.offload_cpu_info")
127
+ )
128
+ offload_dit_to_cpu_value = init_params.get('offload_dit_to_cpu', False) if service_pre_initialized else False
129
+ offload_dit_to_cpu_checkbox = gr.Checkbox(
130
+ label=t("service.offload_dit_cpu_label"),
131
+ value=offload_dit_to_cpu_value,
132
+ info=t("service.offload_dit_cpu_info")
133
+ )
134
+
135
+ init_btn = gr.Button(t("service.init_btn"), variant="primary", size="lg")
136
+ init_status_value = init_params.get('init_status', '') if service_pre_initialized else ''
137
+ init_status = gr.Textbox(label=t("service.status_label"), interactive=False, lines=3, value=init_status_value)
138
+
139
+ # LoRA Configuration Section
140
+ gr.HTML("<hr><h4>🔧 LoRA Adapter</h4>")
141
+ with gr.Row():
142
+ lora_path = gr.Textbox(
143
+ label="LoRA Path",
144
+ placeholder="./lora_output/final/adapter",
145
+ info="Path to trained LoRA adapter directory",
146
+ scale=3,
147
+ )
148
+ load_lora_btn = gr.Button("📥 Load LoRA", variant="secondary", scale=1)
149
+ unload_lora_btn = gr.Button("🗑️ Unload", variant="secondary", scale=1)
150
+ with gr.Row():
151
+ use_lora_checkbox = gr.Checkbox(
152
+ label="Use LoRA",
153
+ value=False,
154
+ info="Enable LoRA adapter for inference",
155
+ scale=1,
156
+ )
157
+ lora_status = gr.Textbox(
158
+ label="LoRA Status",
159
+ value="No LoRA loaded",
160
+ interactive=False,
161
+ scale=2,
162
+ )
163
+
164
+ # ==================== Model Selector (Top, only when multiple models) ====================
165
+ with gr.Row(visible=show_model_selector):
166
+ dit_model_selector = gr.Dropdown(
167
+ choices=available_dit_models,
168
+ value=current_model_value,
169
+ label="models",
170
+ scale=1,
171
+ )
172
+
173
+ # Hidden dropdown when only one model (for event handler compatibility)
174
+ if not show_model_selector:
175
+ dit_model_selector = gr.Dropdown(
176
+ choices=available_dit_models if available_dit_models else [current_model_value],
177
+ value=current_model_value,
178
+ visible=False,
179
+ )
180
+
181
+ # ==================== Generation Mode (4 modes) ====================
182
+ gr.HTML("<div style='background: #4a5568; color: white; padding: 8px 16px; border-radius: 4px; font-weight: bold;'>Generation Mode</div>")
183
+ with gr.Row():
184
+ generation_mode = gr.Radio(
185
+ choices=[
186
+ ("Simple", "simple"),
187
+ ("Custom", "custom"),
188
+ ("Cover", "cover"),
189
+ ("Repaint", "repaint"),
190
+ ],
191
+ value="custom",
192
+ label="",
193
+ show_label=False,
194
+ )
195
+
196
+ # ==================== Simple Mode Group ====================
197
+ with gr.Column(visible=False) as simple_mode_group:
198
+ # Row: Song Description + Vocal Language + Random button
199
+ with gr.Row(equal_height=True):
200
+ simple_query_input = gr.Textbox(
201
+ label=t("generation.simple_query_label"),
202
+ placeholder=t("generation.simple_query_placeholder"),
203
+ lines=2,
204
+ info=t("generation.simple_query_info"),
205
+ scale=10,
206
+ )
207
+ simple_vocal_language = gr.Dropdown(
208
+ choices=VALID_LANGUAGES,
209
+ value="unknown",
210
+ allow_custom_value=True,
211
+ label=t("generation.simple_vocal_language_label"),
212
+ interactive=True,
213
+ info="use unknown for instrumental",
214
+ scale=2,
215
+ )
216
+ with gr.Column(scale=1, min_width=60):
217
+ random_desc_btn = gr.Button(
218
+ "🎲",
219
+ variant="secondary",
220
+ size="lg",
221
+ )
222
+
223
+ # Hidden components (kept for compatibility but not shown)
224
+ simple_instrumental_checkbox = gr.Checkbox(
225
+ label=t("generation.instrumental_label"),
226
+ value=False,
227
+ visible=False,
228
+ )
229
+ create_sample_btn = gr.Button(
230
+ t("generation.create_sample_btn"),
231
+ variant="primary",
232
+ size="lg",
233
+ visible=False,
234
+ )
235
+
236
+ # State to track if sample has been created in Simple mode
237
+ simple_sample_created = gr.State(value=False)
238
+
239
+ # ==================== Source Audio (for Cover/Repaint) ====================
240
+ # This is shown above the main content for Cover and Repaint modes
241
+ with gr.Column(visible=False) as src_audio_group:
242
+ with gr.Row(equal_height=True):
243
+ # Source Audio - scale=10 to match (refer_audio=2 + prompt/lyrics=8)
244
+ src_audio = gr.Audio(
245
+ label="Source Audio",
246
+ type="filepath",
247
+ scale=10,
248
+ )
249
+ # Process button - scale=1 to align with random button
250
+ with gr.Column(scale=1, min_width=80):
251
+ process_src_btn = gr.Button(
252
+ "Analyze",
253
+ variant="secondary",
254
+ size="lg",
255
+ )
256
+
257
+ # Hidden Audio Codes storage (needed internally but not displayed)
258
+ text2music_audio_code_string = gr.Textbox(
259
+ label="Audio Codes",
260
+ visible=False,
261
+ )
262
+
263
+ # ==================== Custom/Cover/Repaint Mode Content ====================
264
+ with gr.Column() as custom_mode_content:
265
+ with gr.Row(equal_height=True):
266
+ # Left: Reference Audio
267
+ with gr.Column(scale=2, min_width=200):
268
+ reference_audio = gr.Audio(
269
+ label="Reference Audio (optional)",
270
+ type="filepath",
271
+ show_label=True,
272
+ )
273
+
274
+ # Middle: Prompt + Lyrics + Format button
275
+ with gr.Column(scale=8):
276
+ # Row 1: Prompt and Lyrics
277
+ with gr.Row(equal_height=True):
278
+ captions = gr.Textbox(
279
+ label="Prompt",
280
+ placeholder="Describe the music style, mood, instruments...",
281
+ lines=12,
282
+ max_lines=12,
283
+ scale=1,
284
+ )
285
+ lyrics = gr.Textbox(
286
+ label="Lyrics",
287
+ placeholder="Enter lyrics here... Use [Verse], [Chorus] etc. for structure",
288
+ lines=12,
289
+ max_lines=12,
290
+ scale=1,
291
+ )
292
+
293
+ # Row 2: Format button (only below Prompt and Lyrics)
294
+ format_btn = gr.Button(
295
+ "Format",
296
+ variant="secondary",
297
+ )
298
+
299
+ # Right: Random button
300
+ with gr.Column(scale=1, min_width=60):
301
+ sample_btn = gr.Button(
302
+ "🎲",
303
+ variant="secondary",
304
+ size="lg",
305
+ )
306
+
307
+ # Placeholder for removed audio_uploads_accordion (for compatibility)
308
+ audio_uploads_accordion = gr.Column(visible=False)
309
+
310
+ # Legacy cover_mode_group (hidden, for backward compatibility)
311
+ cover_mode_group = gr.Column(visible=False)
312
+ # Legacy convert button (hidden, for backward compatibility)
313
+ convert_src_to_codes_btn = gr.Button("Convert to Codes", visible=False)
314
+
315
+ # ==================== Repaint Mode: Source + Time Range ====================
316
+ with gr.Column(visible=False) as repainting_group:
317
+ with gr.Row():
318
+ repainting_start = gr.Number(
319
+ label="Start (seconds)",
320
+ value=0.0,
321
+ step=0.1,
322
+ scale=1,
323
+ )
324
+ repainting_end = gr.Number(
325
+ label="End (seconds, -1 for end)",
326
+ value=-1,
327
+ minimum=-1,
328
+ step=0.1,
329
+ scale=1,
330
+ )
331
+
332
+ # ==================== Optional Parameters ====================
333
+ with gr.Accordion("⚙️ Optional Parameters", open=False, visible=False) as optional_params_accordion:
334
+ pass
335
+
336
+ # ==================== Advanced Settings ====================
337
+ with gr.Accordion("🔧 Advanced Settings", open=False) as advanced_options_accordion:
338
+ with gr.Row():
339
+ bpm = gr.Number(
340
+ label="BPM (optional)",
341
+ value=0,
342
+ step=1,
343
+ info="leave empty for N/A",
344
+ scale=1,
345
+ )
346
+ key_scale = gr.Textbox(
347
+ label="Key Signature (optional)",
348
+ placeholder="Leave empty for N/A",
349
+ value="",
350
+ info="A-G, #/♭, major/minor",
351
+ scale=1,
352
+ )
353
+ time_signature = gr.Dropdown(
354
+ choices=["", "2", "3", "4"],
355
+ value="",
356
+ label="Time Signature (optional)",
357
+ allow_custom_value=True,
358
+ info="2/4, 3/4, 4/4...",
359
+ scale=1,
360
+ )
361
+ audio_duration = gr.Number(
362
+ label="Audio Duration (seconds)",
363
+ value=-1,
364
+ minimum=-1,
365
+ maximum=600.0,
366
+ step=1,
367
+ info="Use -1 for random",
368
+ scale=1,
369
+ )
370
+ vocal_language = gr.Dropdown(
371
+ choices=VALID_LANGUAGES,
372
+ value="unknown",
373
+ label="Vocal Language",
374
+ allow_custom_value=True,
375
+ info="use `unknown` for instrumental",
376
+ scale=1,
377
+ )
378
+ batch_size_input = gr.Number(
379
+ label="batch size",
380
+ info="max 8",
381
+ value=2,
382
+ minimum=1,
383
+ maximum=8,
384
+ step=1,
385
+ scale=1,
386
+ )
387
+
388
+ # Row 1: DiT Inference Steps, Seed, Audio Format
389
+ with gr.Row():
390
+ inference_steps = gr.Slider(
391
+ minimum=1,
392
+ maximum=20,
393
+ value=8,
394
+ step=1,
395
+ label="DiT Inference Steps",
396
+ info="Turbo: max 8, Base: max 200",
397
+ )
398
+ seed = gr.Textbox(
399
+ label="Seed",
400
+ value="-1",
401
+ info="Use comma-separated values for batches",
402
+ )
403
+ audio_format = gr.Dropdown(
404
+ choices=["mp3", "flac"],
405
+ value="mp3",
406
+ label="Audio Format",
407
+ info="Audio format for saved files",
408
+ )
409
+
410
+ # Row 2: Shift, Random Seed, Inference Method
411
+ with gr.Row():
412
+ shift = gr.Slider(
413
+ minimum=1.0,
414
+ maximum=5.0,
415
+ value=3.0,
416
+ step=0.1,
417
+ label="Shift",
418
+ info="Timestep shift factor for base models (range 1.0-5.0, default 3.0). Not effective for turbo models.",
419
+ )
420
+ random_seed_checkbox = gr.Checkbox(
421
+ label="Random Seed",
422
+ value=True,
423
+ info="Enable to auto-generate seeds",
424
+ )
425
+ infer_method = gr.Dropdown(
426
+ choices=["ode", "sde"],
427
+ value="ode",
428
+ label="Inference Method",
429
+ info="Diffusion inference method. ODE (Euler) is faster, SDE (stochastic) may produce different results.",
430
+ )
431
+
432
+ # Row 3: Custom Timesteps (full width)
433
+ custom_timesteps = gr.Textbox(
434
+ label="Custom Timesteps",
435
+ placeholder="0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0",
436
+ value="",
437
+ info="Optional: comma-separated values from 1.0 to 0.0 (e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference steps and shift.",
438
+ )
439
+
440
+ # Section: LM Generation Parameters
441
+ gr.HTML("<h4>🎵 LM Generation Parameters</h4>")
442
+
443
+ # Row 4: LM Temperature, LM CFG Scale, LM Top-K, LM Top-P
444
+ with gr.Row():
445
+ lm_temperature = gr.Slider(
446
+ minimum=0.0,
447
+ maximum=2.0,
448
+ value=0.85,
449
+ step=0.05,
450
+ label="LM Temperature",
451
+ info="5Hz LM temperature (higher = more random)",
452
+ )
453
+ lm_cfg_scale = gr.Slider(
454
+ minimum=1.0,
455
+ maximum=3.0,
456
+ value=2.0,
457
+ step=0.1,
458
+ label="LM CFG Scale",
459
+ info="5Hz LM CFG (1.0 = no CFG)",
460
+ )
461
+ lm_top_k = gr.Slider(
462
+ minimum=0,
463
+ maximum=100,
464
+ value=0,
465
+ step=1,
466
+ label="LM Top-K",
467
+ info="Top-k (0 = disabled)",
468
+ )
469
+ lm_top_p = gr.Slider(
470
+ minimum=0.0,
471
+ maximum=1.0,
472
+ value=0.9,
473
+ step=0.01,
474
+ label="LM Top-P",
475
+ info="Top-p (1.0 = disabled)",
476
+ )
477
+
478
+ # Row 5: LM Negative Prompt (full width)
479
+ lm_negative_prompt = gr.Textbox(
480
+ label="LM Negative Prompt",
481
+ value="NO USER INPUT",
482
+ placeholder="Things to avoid in generation...",
483
+ lines=2,
484
+ info="Negative prompt (use when LM CFG Scale > 1.0)",
485
+ )
486
+ # audio_cover_strength remains hidden for now
487
+ audio_cover_strength = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, visible=False)
488
+
489
+ # Note: audio_duration, bpm, key_scale, time_signature are now visible in Optional Parameters
490
+ # ==================== Generate Button Row ====================
491
+ generate_btn_interactive = init_params.get('enable_generate', False) if service_pre_initialized else False
492
+ with gr.Row(equal_height=True):
493
+ # Left: Thinking and Instrumental checkboxes
494
+ with gr.Column(scale=1, min_width=120):
495
+ think_checkbox = gr.Checkbox(
496
+ label="Thinking",
497
+ value=True,
498
+ )
499
+ instrumental_checkbox = gr.Checkbox(
500
+ label="Instrumental",
501
+ value=False,
502
+ )
503
+
504
+ # Center: Generate button
505
+ with gr.Column(scale=4):
506
+ generate_btn = gr.Button(
507
+ "🎵 Generate Music",
508
+ variant="primary",
509
+ size="lg",
510
+ interactive=generate_btn_interactive,
511
+ )
512
+
513
+ # Right: auto_score, auto_lrc
514
+ with gr.Column(scale=1, min_width=120):
515
+ auto_score = gr.Checkbox(
516
+ label="Get Scores",
517
+ value=False,
518
+ )
519
+ auto_lrc = gr.Checkbox(
520
+ label="Get LRC",
521
+ value=False,
522
+ )
523
+
524
+ # ==================== Hidden Components (for internal use) ====================
525
+ # These are needed for event handlers but not shown in UI
526
+
527
+ # Task type (set automatically based on generation_mode)
528
+ actual_model = init_params.get('config_path', 'acestep-v15-turbo') if service_pre_initialized else 'acestep-v15-turbo'
529
+ actual_model_lower = (actual_model or "").lower()
530
+ if "turbo" in actual_model_lower:
531
+ initial_task_choices = TASK_TYPES_TURBO
532
+ else:
533
+ initial_task_choices = TASK_TYPES_BASE
534
+
535
+ task_type = gr.Dropdown(
536
+ choices=initial_task_choices,
537
+ value="text2music",
538
+ visible=False,
539
+ )
540
+
541
+ instruction_display_gen = gr.Textbox(
542
+ value=DEFAULT_DIT_INSTRUCTION,
543
+ visible=False,
544
+ )
545
+
546
+ track_name = gr.Dropdown(
547
+ choices=TRACK_NAMES,
548
+ value=None,
549
+ visible=False,
550
+ )
551
+
552
+ complete_track_classes = gr.CheckboxGroup(
553
+ choices=TRACK_NAMES,
554
+ visible=False,
555
+ )
556
+
557
+ # Note: lyrics, vocal_language, instrumental_checkbox, format_btn are now visible in custom_mode_content
558
+
559
+ # Hidden advanced settings (keep defaults)
560
+ # Note: Most parameters are now visible in Advanced Settings section above
561
+ guidance_scale = gr.Slider(value=7.0, visible=False)
562
+ use_adg = gr.Checkbox(value=False, visible=False)
563
+ cfg_interval_start = gr.Slider(value=0.0, visible=False)
564
+ cfg_interval_end = gr.Slider(value=1.0, visible=False)
565
+
566
+ # LM parameters (remaining hidden ones)
567
+ use_cot_metas = gr.Checkbox(value=True, visible=False)
568
+ use_cot_caption = gr.Checkbox(value=True, visible=False)
569
+ use_cot_language = gr.Checkbox(value=True, visible=False)
570
+ constrained_decoding_debug = gr.Checkbox(value=False, visible=False)
571
+ allow_lm_batch = gr.Checkbox(value=True, visible=False)
572
+ lm_batch_chunk_size = gr.Number(value=8, visible=False)
573
+ score_scale = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, visible=False)
574
+ autogen_checkbox = gr.Checkbox(value=False, visible=False)
575
+
576
+ # Transcribe button (hidden)
577
+ transcribe_btn = gr.Button(value="Transcribe", visible=False)
578
+ text2music_audio_codes_group = gr.Group(visible=False)
579
+
580
+ # Note: format_btn is now visible in custom_mode_content
581
+
582
+ # Load file button (hidden for now)
583
+ load_file = gr.UploadButton(
584
+ label="Load",
585
+ file_types=[".json"],
586
+ file_count="single",
587
+ visible=False,
588
+ )
589
+
590
+ # Caption/Lyrics accordions (not used in new UI but needed for compatibility)
591
+ caption_accordion = gr.Accordion("Caption", visible=False)
592
+ lyrics_accordion = gr.Accordion("Lyrics", visible=False)
593
+ # Note: optional_params_accordion is now visible above
594
+
595
+ return {
596
+ "service_config_accordion": service_config_accordion,
597
+ "language_dropdown": language_dropdown,
598
+ "checkpoint_dropdown": checkpoint_dropdown,
599
+ "refresh_btn": refresh_btn,
600
+ "config_path": config_path,
601
+ "device": device,
602
+ "init_btn": init_btn,
603
+ "init_status": init_status,
604
+ "lm_model_path": lm_model_path,
605
+ "init_llm_checkbox": init_llm_checkbox,
606
+ "backend_dropdown": backend_dropdown,
607
+ "use_flash_attention_checkbox": use_flash_attention_checkbox,
608
+ "offload_to_cpu_checkbox": offload_to_cpu_checkbox,
609
+ "offload_dit_to_cpu_checkbox": offload_dit_to_cpu_checkbox,
610
+ # LoRA components
611
+ "lora_path": lora_path,
612
+ "load_lora_btn": load_lora_btn,
613
+ "unload_lora_btn": unload_lora_btn,
614
+ "use_lora_checkbox": use_lora_checkbox,
615
+ "lora_status": lora_status,
616
+ # DiT model selector
617
+ "dit_model_selector": dit_model_selector,
618
+ "task_type": task_type,
619
+ "instruction_display_gen": instruction_display_gen,
620
+ "track_name": track_name,
621
+ "complete_track_classes": complete_track_classes,
622
+ "audio_uploads_accordion": audio_uploads_accordion,
623
+ "reference_audio": reference_audio,
624
+ "src_audio": src_audio,
625
+ "convert_src_to_codes_btn": convert_src_to_codes_btn,
626
+ "text2music_audio_code_string": text2music_audio_code_string,
627
+ "transcribe_btn": transcribe_btn,
628
+ "text2music_audio_codes_group": text2music_audio_codes_group,
629
+ "lm_temperature": lm_temperature,
630
+ "lm_cfg_scale": lm_cfg_scale,
631
+ "lm_top_k": lm_top_k,
632
+ "lm_top_p": lm_top_p,
633
+ "lm_negative_prompt": lm_negative_prompt,
634
+ "use_cot_metas": use_cot_metas,
635
+ "use_cot_caption": use_cot_caption,
636
+ "use_cot_language": use_cot_language,
637
+ "repainting_group": repainting_group,
638
+ "repainting_start": repainting_start,
639
+ "repainting_end": repainting_end,
640
+ "audio_cover_strength": audio_cover_strength,
641
+ # Generation mode components
642
+ "generation_mode": generation_mode,
643
+ "simple_mode_group": simple_mode_group,
644
+ "simple_query_input": simple_query_input,
645
+ "random_desc_btn": random_desc_btn,
646
+ "simple_instrumental_checkbox": simple_instrumental_checkbox,
647
+ "simple_vocal_language": simple_vocal_language,
648
+ "create_sample_btn": create_sample_btn,
649
+ "simple_sample_created": simple_sample_created,
650
+ "caption_accordion": caption_accordion,
651
+ "lyrics_accordion": lyrics_accordion,
652
+ "optional_params_accordion": optional_params_accordion,
653
+ # Custom mode components
654
+ "custom_mode_content": custom_mode_content,
655
+ "cover_mode_group": cover_mode_group,
656
+ # Source audio group for Cover/Repaint
657
+ "src_audio_group": src_audio_group,
658
+ "process_src_btn": process_src_btn,
659
+ "advanced_options_accordion": advanced_options_accordion,
660
+ # Existing components
661
+ "captions": captions,
662
+ "sample_btn": sample_btn,
663
+ "load_file": load_file,
664
+ "lyrics": lyrics,
665
+ "vocal_language": vocal_language,
666
+ "bpm": bpm,
667
+ "key_scale": key_scale,
668
+ "time_signature": time_signature,
669
+ "audio_duration": audio_duration,
670
+ "batch_size_input": batch_size_input,
671
+ "inference_steps": inference_steps,
672
+ "guidance_scale": guidance_scale,
673
+ "seed": seed,
674
+ "random_seed_checkbox": random_seed_checkbox,
675
+ "use_adg": use_adg,
676
+ "cfg_interval_start": cfg_interval_start,
677
+ "cfg_interval_end": cfg_interval_end,
678
+ "shift": shift,
679
+ "infer_method": infer_method,
680
+ "custom_timesteps": custom_timesteps,
681
+ "audio_format": audio_format,
682
+ "think_checkbox": think_checkbox,
683
+ "autogen_checkbox": autogen_checkbox,
684
+ "generate_btn": generate_btn,
685
+ "instrumental_checkbox": instrumental_checkbox,
686
+ "format_btn": format_btn,
687
+ "constrained_decoding_debug": constrained_decoding_debug,
688
+ "score_scale": score_scale,
689
+ "allow_lm_batch": allow_lm_batch,
690
+ "auto_score": auto_score,
691
+ "auto_lrc": auto_lrc,
692
+ "lm_batch_chunk_size": lm_batch_chunk_size,
693
+ }
spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/result.py ADDED
@@ -0,0 +1,598 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Results Section Module
3
+ Contains results display section component definitions
4
+ """
5
+ import gradio as gr
6
+ from acestep.gradio_ui.i18n import t
7
+
8
+
9
+ def create_results_section(dit_handler) -> dict:
10
+ """Create results display section"""
11
+ with gr.Accordion(t("results.title"), open=True):
12
+ # Hidden state to store LM-generated metadata
13
+ lm_metadata_state = gr.State(value=None)
14
+
15
+ # Hidden state to track if caption/metadata is from formatted source (LM/transcription)
16
+ is_format_caption_state = gr.State(value=False)
17
+
18
+ # Batch management states
19
+ current_batch_index = gr.State(value=0) # Currently displayed batch index
20
+ total_batches = gr.State(value=1) # Total number of batches generated
21
+ batch_queue = gr.State(value={}) # Dictionary storing all batch data
22
+ generation_params_state = gr.State(value={}) # Store generation parameters for next batches
23
+ is_generating_background = gr.State(value=False) # Background generation flag
24
+
25
+ # All audio components in one row with dynamic visibility
26
+ with gr.Row():
27
+ with gr.Column(visible=True) as audio_col_1:
28
+ generated_audio_1 = gr.Audio(
29
+ label=t("results.generated_music", n=1),
30
+ type="filepath",
31
+ interactive=False,
32
+ buttons=[]
33
+ )
34
+ with gr.Row(equal_height=True):
35
+ send_to_cover_btn_1 = gr.Button(
36
+ t("results.send_to_cover_btn"),
37
+ variant="secondary",
38
+ size="sm",
39
+ scale=1
40
+ )
41
+ send_to_repaint_btn_1 = gr.Button(
42
+ t("results.send_to_repaint_btn"),
43
+ variant="secondary",
44
+ size="sm",
45
+ scale=1
46
+ )
47
+ save_btn_1 = gr.Button(
48
+ t("results.save_btn"),
49
+ variant="primary",
50
+ size="sm",
51
+ scale=1
52
+ )
53
+ score_btn_1 = gr.Button(
54
+ t("results.score_btn"),
55
+ variant="secondary",
56
+ size="sm",
57
+ scale=1,
58
+ visible=False
59
+ )
60
+ lrc_btn_1 = gr.Button(
61
+ t("results.lrc_btn"),
62
+ variant="secondary",
63
+ size="sm",
64
+ scale=1,
65
+ visible=False
66
+ )
67
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_1:
68
+ score_display_1 = gr.Textbox(
69
+ label=t("results.quality_score_label", n=1),
70
+ interactive=False,
71
+ buttons=["copy"],
72
+ lines=6,
73
+ max_lines=6,
74
+ visible=True
75
+ )
76
+ lrc_display_1 = gr.Textbox(
77
+ label=t("results.lrc_label", n=1),
78
+ interactive=True,
79
+ buttons=["copy"],
80
+ lines=8,
81
+ max_lines=8,
82
+ visible=True
83
+ )
84
+ codes_display_1 = gr.Textbox(
85
+ label=t("results.codes_label", n=1),
86
+ interactive=False,
87
+ buttons=["copy"],
88
+ lines=4,
89
+ max_lines=4,
90
+ visible=True
91
+ )
92
+ with gr.Column(visible=True) as audio_col_2:
93
+ generated_audio_2 = gr.Audio(
94
+ label=t("results.generated_music", n=2),
95
+ type="filepath",
96
+ interactive=False,
97
+ buttons=[]
98
+ )
99
+ with gr.Row(equal_height=True):
100
+ send_to_cover_btn_2 = gr.Button(
101
+ t("results.send_to_cover_btn"),
102
+ variant="secondary",
103
+ size="sm",
104
+ scale=1
105
+ )
106
+ send_to_repaint_btn_2 = gr.Button(
107
+ t("results.send_to_repaint_btn"),
108
+ variant="secondary",
109
+ size="sm",
110
+ scale=1
111
+ )
112
+ save_btn_2 = gr.Button(
113
+ t("results.save_btn"),
114
+ variant="primary",
115
+ size="sm",
116
+ scale=1
117
+ )
118
+ score_btn_2 = gr.Button(
119
+ t("results.score_btn"),
120
+ variant="secondary",
121
+ size="sm",
122
+ scale=1,
123
+ visible=False
124
+ )
125
+ lrc_btn_2 = gr.Button(
126
+ t("results.lrc_btn"),
127
+ variant="secondary",
128
+ size="sm",
129
+ scale=1,
130
+ visible=False
131
+ )
132
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_2:
133
+ score_display_2 = gr.Textbox(
134
+ label=t("results.quality_score_label", n=2),
135
+ interactive=False,
136
+ buttons=["copy"],
137
+ lines=6,
138
+ max_lines=6,
139
+ visible=True
140
+ )
141
+ lrc_display_2 = gr.Textbox(
142
+ label=t("results.lrc_label", n=2),
143
+ interactive=True,
144
+ buttons=["copy"],
145
+ lines=8,
146
+ max_lines=8,
147
+ visible=True
148
+ )
149
+ codes_display_2 = gr.Textbox(
150
+ label=t("results.codes_label", n=2),
151
+ interactive=False,
152
+ buttons=["copy"],
153
+ lines=4,
154
+ max_lines=4,
155
+ visible=True
156
+ )
157
+ with gr.Column(visible=False) as audio_col_3:
158
+ generated_audio_3 = gr.Audio(
159
+ label=t("results.generated_music", n=3),
160
+ type="filepath",
161
+ interactive=False,
162
+ buttons=[]
163
+ )
164
+ with gr.Row(equal_height=True):
165
+ send_to_cover_btn_3 = gr.Button(
166
+ t("results.send_to_cover_btn"),
167
+ variant="secondary",
168
+ size="sm",
169
+ scale=1
170
+ )
171
+ send_to_repaint_btn_3 = gr.Button(
172
+ t("results.send_to_repaint_btn"),
173
+ variant="secondary",
174
+ size="sm",
175
+ scale=1
176
+ )
177
+ save_btn_3 = gr.Button(
178
+ t("results.save_btn"),
179
+ variant="primary",
180
+ size="sm",
181
+ scale=1
182
+ )
183
+ score_btn_3 = gr.Button(
184
+ t("results.score_btn"),
185
+ variant="secondary",
186
+ size="sm",
187
+ scale=1,
188
+ visible=False
189
+ )
190
+ lrc_btn_3 = gr.Button(
191
+ t("results.lrc_btn"),
192
+ variant="secondary",
193
+ size="sm",
194
+ scale=1,
195
+ visible=False
196
+ )
197
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_3:
198
+ score_display_3 = gr.Textbox(
199
+ label=t("results.quality_score_label", n=3),
200
+ interactive=False,
201
+ buttons=["copy"],
202
+ lines=6,
203
+ max_lines=6,
204
+ visible=True
205
+ )
206
+ lrc_display_3 = gr.Textbox(
207
+ label=t("results.lrc_label", n=3),
208
+ interactive=True,
209
+ buttons=["copy"],
210
+ lines=8,
211
+ max_lines=8,
212
+ visible=True
213
+ )
214
+ codes_display_3 = gr.Textbox(
215
+ label=t("results.codes_label", n=3),
216
+ interactive=False,
217
+ buttons=["copy"],
218
+ lines=4,
219
+ max_lines=4,
220
+ visible=True
221
+ )
222
+ with gr.Column(visible=False) as audio_col_4:
223
+ generated_audio_4 = gr.Audio(
224
+ label=t("results.generated_music", n=4),
225
+ type="filepath",
226
+ interactive=False,
227
+ buttons=[]
228
+ )
229
+ with gr.Row(equal_height=True):
230
+ send_to_cover_btn_4 = gr.Button(
231
+ t("results.send_to_cover_btn"),
232
+ variant="secondary",
233
+ size="sm",
234
+ scale=1
235
+ )
236
+ send_to_repaint_btn_4 = gr.Button(
237
+ t("results.send_to_repaint_btn"),
238
+ variant="secondary",
239
+ size="sm",
240
+ scale=1
241
+ )
242
+ save_btn_4 = gr.Button(
243
+ t("results.save_btn"),
244
+ variant="primary",
245
+ size="sm",
246
+ scale=1
247
+ )
248
+ score_btn_4 = gr.Button(
249
+ t("results.score_btn"),
250
+ variant="secondary",
251
+ size="sm",
252
+ scale=1,
253
+ visible=False
254
+ )
255
+ lrc_btn_4 = gr.Button(
256
+ t("results.lrc_btn"),
257
+ variant="secondary",
258
+ size="sm",
259
+ scale=1,
260
+ visible=False
261
+ )
262
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_4:
263
+ score_display_4 = gr.Textbox(
264
+ label=t("results.quality_score_label", n=4),
265
+ interactive=False,
266
+ buttons=["copy"],
267
+ lines=6,
268
+ max_lines=6,
269
+ visible=True
270
+ )
271
+ lrc_display_4 = gr.Textbox(
272
+ label=t("results.lrc_label", n=4),
273
+ interactive=True,
274
+ buttons=["copy"],
275
+ lines=8,
276
+ max_lines=8,
277
+ visible=True
278
+ )
279
+ codes_display_4 = gr.Textbox(
280
+ label=t("results.codes_label", n=4),
281
+ interactive=False,
282
+ buttons=["copy"],
283
+ lines=4,
284
+ max_lines=4,
285
+ visible=True
286
+ )
287
+
288
+ # Second row for batch size 5-8 (initially hidden)
289
+ with gr.Row(visible=False) as audio_row_5_8:
290
+ with gr.Column() as audio_col_5:
291
+ generated_audio_5 = gr.Audio(
292
+ label=t("results.generated_music", n=5),
293
+ type="filepath",
294
+ interactive=False,
295
+ buttons=[]
296
+ )
297
+ with gr.Row(equal_height=True):
298
+ send_to_cover_btn_5 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
299
+ send_to_repaint_btn_5 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
300
+ save_btn_5 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
301
+ score_btn_5 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
302
+ lrc_btn_5 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
303
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_5:
304
+ score_display_5 = gr.Textbox(
305
+ label=t("results.quality_score_label", n=5),
306
+ interactive=False,
307
+ buttons=["copy"],
308
+ lines=6,
309
+ max_lines=6,
310
+ visible=True
311
+ )
312
+ lrc_display_5 = gr.Textbox(
313
+ label=t("results.lrc_label", n=5),
314
+ interactive=True,
315
+ buttons=["copy"],
316
+ lines=8,
317
+ max_lines=8,
318
+ visible=True
319
+ )
320
+ codes_display_5 = gr.Textbox(
321
+ label=t("results.codes_label", n=5),
322
+ interactive=False,
323
+ buttons=["copy"],
324
+ lines=4,
325
+ max_lines=4,
326
+ visible=True
327
+ )
328
+ with gr.Column() as audio_col_6:
329
+ generated_audio_6 = gr.Audio(
330
+ label=t("results.generated_music", n=6),
331
+ type="filepath",
332
+ interactive=False,
333
+ buttons=[]
334
+ )
335
+ with gr.Row(equal_height=True):
336
+ send_to_cover_btn_6 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
337
+ send_to_repaint_btn_6 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
338
+ save_btn_6 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
339
+ score_btn_6 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
340
+ lrc_btn_6 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
341
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_6:
342
+ score_display_6 = gr.Textbox(
343
+ label=t("results.quality_score_label", n=6),
344
+ interactive=False,
345
+ buttons=["copy"],
346
+ lines=6,
347
+ max_lines=6,
348
+ visible=True
349
+ )
350
+ lrc_display_6 = gr.Textbox(
351
+ label=t("results.lrc_label", n=6),
352
+ interactive=True,
353
+ buttons=["copy"],
354
+ lines=8,
355
+ max_lines=8,
356
+ visible=True
357
+ )
358
+ codes_display_6 = gr.Textbox(
359
+ label=t("results.codes_label", n=6),
360
+ interactive=False,
361
+ buttons=["copy"],
362
+ lines=4,
363
+ max_lines=4,
364
+ visible=True
365
+ )
366
+ with gr.Column() as audio_col_7:
367
+ generated_audio_7 = gr.Audio(
368
+ label=t("results.generated_music", n=7),
369
+ type="filepath",
370
+ interactive=False,
371
+ buttons=[]
372
+ )
373
+ with gr.Row(equal_height=True):
374
+ send_to_cover_btn_7 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
375
+ send_to_repaint_btn_7 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
376
+ save_btn_7 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
377
+ score_btn_7 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
378
+ lrc_btn_7 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
379
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_7:
380
+ score_display_7 = gr.Textbox(
381
+ label=t("results.quality_score_label", n=7),
382
+ interactive=False,
383
+ buttons=["copy"],
384
+ lines=6,
385
+ max_lines=6,
386
+ visible=True
387
+ )
388
+ lrc_display_7 = gr.Textbox(
389
+ label=t("results.lrc_label", n=7),
390
+ interactive=True,
391
+ buttons=["copy"],
392
+ lines=8,
393
+ max_lines=8,
394
+ visible=True
395
+ )
396
+ codes_display_7 = gr.Textbox(
397
+ label=t("results.codes_label", n=7),
398
+ interactive=False,
399
+ buttons=["copy"],
400
+ lines=4,
401
+ max_lines=4,
402
+ visible=True
403
+ )
404
+ with gr.Column() as audio_col_8:
405
+ generated_audio_8 = gr.Audio(
406
+ label=t("results.generated_music", n=8),
407
+ type="filepath",
408
+ interactive=False,
409
+ buttons=[]
410
+ )
411
+ with gr.Row(equal_height=True):
412
+ send_to_cover_btn_8 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
413
+ send_to_repaint_btn_8 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
414
+ save_btn_8 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
415
+ score_btn_8 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
416
+ lrc_btn_8 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
417
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_8:
418
+ score_display_8 = gr.Textbox(
419
+ label=t("results.quality_score_label", n=8),
420
+ interactive=False,
421
+ buttons=["copy"],
422
+ lines=6,
423
+ max_lines=6,
424
+ visible=True
425
+ )
426
+ lrc_display_8 = gr.Textbox(
427
+ label=t("results.lrc_label", n=8),
428
+ interactive=True,
429
+ buttons=["copy"],
430
+ lines=8,
431
+ max_lines=8,
432
+ visible=True
433
+ )
434
+ codes_display_8 = gr.Textbox(
435
+ label=t("results.codes_label", n=8),
436
+ interactive=False,
437
+ buttons=["copy"],
438
+ lines=4,
439
+ max_lines=4,
440
+ visible=True
441
+ )
442
+
443
+ status_output = gr.Textbox(label=t("results.generation_status"), interactive=False)
444
+
445
+ # Batch navigation controls (hidden for simplified UI)
446
+ with gr.Row(equal_height=True, visible=False):
447
+ prev_batch_btn = gr.Button(
448
+ t("results.prev_btn"),
449
+ variant="secondary",
450
+ interactive=False,
451
+ scale=1,
452
+ size="sm"
453
+ )
454
+ batch_indicator = gr.Textbox(
455
+ label=t("results.current_batch"),
456
+ value=t("results.batch_indicator", current=1, total=1),
457
+ interactive=False,
458
+ scale=3
459
+ )
460
+ next_batch_status = gr.Textbox(
461
+ label=t("results.next_batch_status"),
462
+ value="",
463
+ interactive=False,
464
+ scale=3
465
+ )
466
+ next_batch_btn = gr.Button(
467
+ t("results.next_btn"),
468
+ variant="primary",
469
+ interactive=False,
470
+ scale=1,
471
+ size="sm"
472
+ )
473
+
474
+ # One-click restore parameters button (hidden for simplified UI)
475
+ restore_params_btn = gr.Button(
476
+ t("results.restore_params_btn"),
477
+ variant="secondary",
478
+ interactive=False,
479
+ size="sm",
480
+ visible=False
481
+ )
482
+
483
+ with gr.Accordion(t("results.batch_results_title"), open=True):
484
+ generated_audio_batch = gr.File(
485
+ label=t("results.all_files_label"),
486
+ file_count="multiple",
487
+ interactive=False,
488
+ visible=False
489
+ )
490
+ generation_info = gr.Markdown(label=t("results.generation_details"))
491
+
492
+ return {
493
+ "lm_metadata_state": lm_metadata_state,
494
+ "is_format_caption_state": is_format_caption_state,
495
+ "current_batch_index": current_batch_index,
496
+ "total_batches": total_batches,
497
+ "batch_queue": batch_queue,
498
+ "generation_params_state": generation_params_state,
499
+ "is_generating_background": is_generating_background,
500
+ "status_output": status_output,
501
+ "prev_batch_btn": prev_batch_btn,
502
+ "batch_indicator": batch_indicator,
503
+ "next_batch_btn": next_batch_btn,
504
+ "next_batch_status": next_batch_status,
505
+ "restore_params_btn": restore_params_btn,
506
+ "generated_audio_1": generated_audio_1,
507
+ "generated_audio_2": generated_audio_2,
508
+ "generated_audio_3": generated_audio_3,
509
+ "generated_audio_4": generated_audio_4,
510
+ "generated_audio_5": generated_audio_5,
511
+ "generated_audio_6": generated_audio_6,
512
+ "generated_audio_7": generated_audio_7,
513
+ "generated_audio_8": generated_audio_8,
514
+ "audio_row_5_8": audio_row_5_8,
515
+ "audio_col_1": audio_col_1,
516
+ "audio_col_2": audio_col_2,
517
+ "audio_col_3": audio_col_3,
518
+ "audio_col_4": audio_col_4,
519
+ "audio_col_5": audio_col_5,
520
+ "audio_col_6": audio_col_6,
521
+ "audio_col_7": audio_col_7,
522
+ "audio_col_8": audio_col_8,
523
+ "send_to_cover_btn_1": send_to_cover_btn_1,
524
+ "send_to_cover_btn_2": send_to_cover_btn_2,
525
+ "send_to_cover_btn_3": send_to_cover_btn_3,
526
+ "send_to_cover_btn_4": send_to_cover_btn_4,
527
+ "send_to_cover_btn_5": send_to_cover_btn_5,
528
+ "send_to_cover_btn_6": send_to_cover_btn_6,
529
+ "send_to_cover_btn_7": send_to_cover_btn_7,
530
+ "send_to_cover_btn_8": send_to_cover_btn_8,
531
+ "send_to_repaint_btn_1": send_to_repaint_btn_1,
532
+ "send_to_repaint_btn_2": send_to_repaint_btn_2,
533
+ "send_to_repaint_btn_3": send_to_repaint_btn_3,
534
+ "send_to_repaint_btn_4": send_to_repaint_btn_4,
535
+ "send_to_repaint_btn_5": send_to_repaint_btn_5,
536
+ "send_to_repaint_btn_6": send_to_repaint_btn_6,
537
+ "send_to_repaint_btn_7": send_to_repaint_btn_7,
538
+ "send_to_repaint_btn_8": send_to_repaint_btn_8,
539
+ "save_btn_1": save_btn_1,
540
+ "save_btn_2": save_btn_2,
541
+ "save_btn_3": save_btn_3,
542
+ "save_btn_4": save_btn_4,
543
+ "save_btn_5": save_btn_5,
544
+ "save_btn_6": save_btn_6,
545
+ "save_btn_7": save_btn_7,
546
+ "save_btn_8": save_btn_8,
547
+ "score_btn_1": score_btn_1,
548
+ "score_btn_2": score_btn_2,
549
+ "score_btn_3": score_btn_3,
550
+ "score_btn_4": score_btn_4,
551
+ "score_btn_5": score_btn_5,
552
+ "score_btn_6": score_btn_6,
553
+ "score_btn_7": score_btn_7,
554
+ "score_btn_8": score_btn_8,
555
+ "score_display_1": score_display_1,
556
+ "score_display_2": score_display_2,
557
+ "score_display_3": score_display_3,
558
+ "score_display_4": score_display_4,
559
+ "score_display_5": score_display_5,
560
+ "score_display_6": score_display_6,
561
+ "score_display_7": score_display_7,
562
+ "score_display_8": score_display_8,
563
+ "codes_display_1": codes_display_1,
564
+ "codes_display_2": codes_display_2,
565
+ "codes_display_3": codes_display_3,
566
+ "codes_display_4": codes_display_4,
567
+ "codes_display_5": codes_display_5,
568
+ "codes_display_6": codes_display_6,
569
+ "codes_display_7": codes_display_7,
570
+ "codes_display_8": codes_display_8,
571
+ "lrc_btn_1": lrc_btn_1,
572
+ "lrc_btn_2": lrc_btn_2,
573
+ "lrc_btn_3": lrc_btn_3,
574
+ "lrc_btn_4": lrc_btn_4,
575
+ "lrc_btn_5": lrc_btn_5,
576
+ "lrc_btn_6": lrc_btn_6,
577
+ "lrc_btn_7": lrc_btn_7,
578
+ "lrc_btn_8": lrc_btn_8,
579
+ "lrc_display_1": lrc_display_1,
580
+ "lrc_display_2": lrc_display_2,
581
+ "lrc_display_3": lrc_display_3,
582
+ "lrc_display_4": lrc_display_4,
583
+ "lrc_display_5": lrc_display_5,
584
+ "lrc_display_6": lrc_display_6,
585
+ "lrc_display_7": lrc_display_7,
586
+ "lrc_display_8": lrc_display_8,
587
+ "details_accordion_1": details_accordion_1,
588
+ "details_accordion_2": details_accordion_2,
589
+ "details_accordion_3": details_accordion_3,
590
+ "details_accordion_4": details_accordion_4,
591
+ "details_accordion_5": details_accordion_5,
592
+ "details_accordion_6": details_accordion_6,
593
+ "details_accordion_7": details_accordion_7,
594
+ "details_accordion_8": details_accordion_8,
595
+ "generated_audio_batch": generated_audio_batch,
596
+ "generation_info": generation_info,
597
+ }
598
+
spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/training.py ADDED
@@ -0,0 +1,562 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Training Tab Module
3
+
4
+ Contains the dataset builder and LoRA training interface components.
5
+ """
6
+
7
+ import os
8
+ import gradio as gr
9
+ from acestep.gradio_ui.i18n import t
10
+
11
+
12
+ def create_training_section(dit_handler, llm_handler, init_params=None) -> dict:
13
+ """Create the training tab section with dataset builder and training controls.
14
+
15
+ Args:
16
+ dit_handler: DiT handler instance
17
+ llm_handler: LLM handler instance
18
+ init_params: Dictionary containing initialization parameters and state.
19
+ If None, service will not be pre-initialized.
20
+
21
+ Returns:
22
+ Dictionary of Gradio components for event handling
23
+ """
24
+ # Check if running in service mode (hide training tab)
25
+ service_mode = init_params is not None and init_params.get('service_mode', False)
26
+
27
+ with gr.Tab("🎓 LoRA Training", visible=not service_mode):
28
+ gr.HTML("""
29
+ <div style="text-align: center; padding: 10px; margin-bottom: 15px;">
30
+ <h2>🎵 LoRA Training for ACE-Step</h2>
31
+ <p>Build datasets from your audio files and train custom LoRA adapters</p>
32
+ </div>
33
+ """)
34
+
35
+ with gr.Tabs():
36
+ # ==================== Dataset Builder Tab ====================
37
+ with gr.Tab("📁 Dataset Builder"):
38
+ # ========== Load Existing OR Scan New ==========
39
+ gr.HTML("""
40
+ <div style="padding: 10px; margin-bottom: 10px; border: 1px solid #4a4a6a; border-radius: 8px; background: linear-gradient(135deg, #2a2a4a 0%, #1a1a3a 100%);">
41
+ <h3 style="margin: 0 0 5px 0;">🚀 Quick Start</h3>
42
+ <p style="margin: 0; color: #aaa;">Choose one: <b>Load existing dataset</b> OR <b>Scan new directory</b></p>
43
+ </div>
44
+ """)
45
+
46
+ with gr.Row():
47
+ with gr.Column(scale=1):
48
+ gr.HTML("<h4>📂 Load Existing Dataset</h4>")
49
+ with gr.Row():
50
+ load_json_path = gr.Textbox(
51
+ label="Dataset JSON Path",
52
+ placeholder="./datasets/my_lora_dataset.json",
53
+ info="Load a previously saved dataset",
54
+ scale=3,
55
+ )
56
+ load_json_btn = gr.Button("📂 Load", variant="primary", scale=1)
57
+ load_json_status = gr.Textbox(
58
+ label="Load Status",
59
+ interactive=False,
60
+ )
61
+
62
+ with gr.Column(scale=1):
63
+ gr.HTML("<h4>🔍 Scan New Directory</h4>")
64
+ with gr.Row():
65
+ audio_directory = gr.Textbox(
66
+ label="Audio Directory Path",
67
+ placeholder="/path/to/your/audio/folder",
68
+ info="Scan for audio files (wav, mp3, flac, ogg, opus)",
69
+ scale=3,
70
+ )
71
+ scan_btn = gr.Button("🔍 Scan", variant="secondary", scale=1)
72
+ scan_status = gr.Textbox(
73
+ label="Scan Status",
74
+ interactive=False,
75
+ )
76
+
77
+ gr.HTML("<hr>")
78
+
79
+ with gr.Row():
80
+ with gr.Column(scale=2):
81
+
82
+ # Audio files table
83
+ audio_files_table = gr.Dataframe(
84
+ headers=["#", "Filename", "Duration", "Labeled", "BPM", "Key", "Caption"],
85
+ datatype=["number", "str", "str", "str", "str", "str", "str"],
86
+ label="Found Audio Files",
87
+ interactive=False,
88
+ wrap=True,
89
+ )
90
+
91
+ with gr.Column(scale=1):
92
+ gr.HTML("<h3>⚙️ Dataset Settings</h3>")
93
+
94
+ dataset_name = gr.Textbox(
95
+ label="Dataset Name",
96
+ value="my_lora_dataset",
97
+ placeholder="Enter dataset name",
98
+ )
99
+
100
+ all_instrumental = gr.Checkbox(
101
+ label="All Instrumental",
102
+ value=True,
103
+ info="Check if all tracks are instrumental (no vocals)",
104
+ )
105
+
106
+ need_lyrics = gr.Checkbox(
107
+ label="Transcribe Lyrics",
108
+ value=False,
109
+ info="Attempt to transcribe lyrics (slower)",
110
+ interactive=False, # Disabled for now
111
+ )
112
+
113
+ custom_tag = gr.Textbox(
114
+ label="Custom Activation Tag",
115
+ placeholder="e.g., 8bit_retro, my_style",
116
+ info="Unique tag to activate this LoRA's style",
117
+ )
118
+
119
+ tag_position = gr.Radio(
120
+ choices=[
121
+ ("Prepend (tag, caption)", "prepend"),
122
+ ("Append (caption, tag)", "append"),
123
+ ("Replace caption", "replace"),
124
+ ],
125
+ value="replace",
126
+ label="Tag Position",
127
+ info="Where to place the custom tag in the caption",
128
+ )
129
+
130
+ gr.HTML("<hr><h3>🤖 Step 2: Auto-Label with AI</h3>")
131
+
132
+ with gr.Row():
133
+ with gr.Column(scale=3):
134
+ gr.Markdown("""
135
+ Click the button below to automatically generate metadata for all audio files using AI:
136
+ - **Caption**: Music style, genre, mood description
137
+ - **BPM**: Beats per minute
138
+ - **Key**: Musical key (e.g., C Major, Am)
139
+ - **Time Signature**: 4/4, 3/4, etc.
140
+ """)
141
+ skip_metas = gr.Checkbox(
142
+ label="Skip Metas (No LLM)",
143
+ value=False,
144
+ info="Skip AI labeling. BPM/Key/Time Signature will be N/A, Language will be 'unknown' for instrumental",
145
+ )
146
+ with gr.Column(scale=1):
147
+ auto_label_btn = gr.Button(
148
+ "🏷️ Auto-Label All",
149
+ variant="primary",
150
+ size="lg",
151
+ )
152
+
153
+ label_progress = gr.Textbox(
154
+ label="Labeling Progress",
155
+ interactive=False,
156
+ lines=2,
157
+ )
158
+
159
+ gr.HTML("<hr><h3>👀 Step 3: Preview & Edit</h3>")
160
+
161
+ with gr.Row():
162
+ with gr.Column(scale=1):
163
+ sample_selector = gr.Slider(
164
+ minimum=0,
165
+ maximum=0,
166
+ step=1,
167
+ value=0,
168
+ label="Select Sample #",
169
+ info="Choose a sample to preview and edit",
170
+ )
171
+
172
+ preview_audio = gr.Audio(
173
+ label="Audio Preview",
174
+ type="filepath",
175
+ interactive=False,
176
+ )
177
+
178
+ preview_filename = gr.Textbox(
179
+ label="Filename",
180
+ interactive=False,
181
+ )
182
+
183
+ with gr.Column(scale=2):
184
+ with gr.Row():
185
+ edit_caption = gr.Textbox(
186
+ label="Caption",
187
+ lines=3,
188
+ placeholder="Music description...",
189
+ )
190
+
191
+ with gr.Row():
192
+ edit_lyrics = gr.Textbox(
193
+ label="Lyrics",
194
+ lines=4,
195
+ placeholder="[Verse 1]\nLyrics here...\n\n[Chorus]\n...",
196
+ )
197
+
198
+ with gr.Row():
199
+ edit_bpm = gr.Number(
200
+ label="BPM",
201
+ precision=0,
202
+ )
203
+ edit_keyscale = gr.Textbox(
204
+ label="Key",
205
+ placeholder="C Major",
206
+ )
207
+ edit_timesig = gr.Dropdown(
208
+ choices=["", "2", "3", "4", "6"],
209
+ label="Time Signature",
210
+ )
211
+ edit_duration = gr.Number(
212
+ label="Duration (s)",
213
+ precision=1,
214
+ interactive=False,
215
+ )
216
+
217
+ with gr.Row():
218
+ edit_language = gr.Dropdown(
219
+ choices=["instrumental", "en", "zh", "ja", "ko", "es", "fr", "de", "pt", "ru", "unknown"],
220
+ value="instrumental",
221
+ label="Language",
222
+ )
223
+ edit_instrumental = gr.Checkbox(
224
+ label="Instrumental",
225
+ value=True,
226
+ )
227
+ save_edit_btn = gr.Button("💾 Save Changes", variant="secondary")
228
+
229
+ edit_status = gr.Textbox(
230
+ label="Edit Status",
231
+ interactive=False,
232
+ )
233
+
234
+ gr.HTML("<hr><h3>💾 Step 4: Save Dataset</h3>")
235
+
236
+ with gr.Row():
237
+ with gr.Column(scale=3):
238
+ save_path = gr.Textbox(
239
+ label="Save Path",
240
+ value="./datasets/my_lora_dataset.json",
241
+ placeholder="./datasets/dataset_name.json",
242
+ info="Path where the dataset JSON will be saved",
243
+ )
244
+ with gr.Column(scale=1):
245
+ save_dataset_btn = gr.Button(
246
+ "💾 Save Dataset",
247
+ variant="primary",
248
+ size="lg",
249
+ )
250
+
251
+ save_status = gr.Textbox(
252
+ label="Save Status",
253
+ interactive=False,
254
+ lines=2,
255
+ )
256
+
257
+ gr.HTML("<hr><h3>⚡ Step 5: Preprocess to Tensors</h3>")
258
+
259
+ gr.Markdown("""
260
+ **Preprocessing converts your dataset to pre-computed tensors for fast training.**
261
+
262
+ You can either:
263
+ - Use the dataset from Steps 1-4 above, **OR**
264
+ - Load an existing dataset JSON file (if you've already saved one)
265
+ """)
266
+
267
+ with gr.Row():
268
+ with gr.Column(scale=3):
269
+ load_existing_dataset_path = gr.Textbox(
270
+ label="Load Existing Dataset (Optional)",
271
+ placeholder="./datasets/my_lora_dataset.json",
272
+ info="Path to a previously saved dataset JSON file",
273
+ )
274
+ with gr.Column(scale=1):
275
+ load_existing_dataset_btn = gr.Button(
276
+ "📂 Load Dataset",
277
+ variant="secondary",
278
+ size="lg",
279
+ )
280
+
281
+ load_existing_status = gr.Textbox(
282
+ label="Load Status",
283
+ interactive=False,
284
+ )
285
+
286
+ gr.Markdown("""
287
+ This step:
288
+ - Encodes audio to VAE latents
289
+ - Encodes captions and lyrics to text embeddings
290
+ - Runs the condition encoder
291
+ - Saves all tensors to `.pt` files
292
+
293
+ ⚠️ **This requires the model to be loaded and may take a few minutes.**
294
+ """)
295
+
296
+ with gr.Row():
297
+ with gr.Column(scale=3):
298
+ preprocess_output_dir = gr.Textbox(
299
+ label="Tensor Output Directory",
300
+ value="./datasets/preprocessed_tensors",
301
+ placeholder="./datasets/preprocessed_tensors",
302
+ info="Directory to save preprocessed tensor files",
303
+ )
304
+ with gr.Column(scale=1):
305
+ preprocess_btn = gr.Button(
306
+ "⚡ Preprocess",
307
+ variant="primary",
308
+ size="lg",
309
+ )
310
+
311
+ preprocess_progress = gr.Textbox(
312
+ label="Preprocessing Progress",
313
+ interactive=False,
314
+ lines=3,
315
+ )
316
+
317
+ # ==================== Training Tab ====================
318
+ with gr.Tab("🚀 Train LoRA"):
319
+ with gr.Row():
320
+ with gr.Column(scale=2):
321
+ gr.HTML("<h3>📊 Preprocessed Dataset Selection</h3>")
322
+
323
+ gr.Markdown("""
324
+ Select the directory containing preprocessed tensor files (`.pt` files).
325
+ These are created in the "Dataset Builder" tab using the "Preprocess" button.
326
+ """)
327
+
328
+ training_tensor_dir = gr.Textbox(
329
+ label="Preprocessed Tensors Directory",
330
+ placeholder="./datasets/preprocessed_tensors",
331
+ value="./datasets/preprocessed_tensors",
332
+ info="Directory containing preprocessed .pt tensor files",
333
+ )
334
+
335
+ load_dataset_btn = gr.Button("📂 Load Dataset", variant="secondary")
336
+
337
+ training_dataset_info = gr.Textbox(
338
+ label="Dataset Info",
339
+ interactive=False,
340
+ lines=3,
341
+ )
342
+
343
+ with gr.Column(scale=1):
344
+ gr.HTML("<h3>⚙️ LoRA Settings</h3>")
345
+
346
+ lora_rank = gr.Slider(
347
+ minimum=4,
348
+ maximum=256,
349
+ step=4,
350
+ value=64,
351
+ label="LoRA Rank (r)",
352
+ info="Higher = more capacity, more memory",
353
+ )
354
+
355
+ lora_alpha = gr.Slider(
356
+ minimum=4,
357
+ maximum=512,
358
+ step=4,
359
+ value=128,
360
+ label="LoRA Alpha",
361
+ info="Scaling factor (typically 2x rank)",
362
+ )
363
+
364
+ lora_dropout = gr.Slider(
365
+ minimum=0.0,
366
+ maximum=0.5,
367
+ step=0.05,
368
+ value=0.1,
369
+ label="LoRA Dropout",
370
+ )
371
+
372
+ gr.HTML("<hr><h3>🎛️ Training Parameters</h3>")
373
+
374
+ with gr.Row():
375
+ learning_rate = gr.Number(
376
+ label="Learning Rate",
377
+ value=1e-4,
378
+ info="Start with 1e-4, adjust if needed",
379
+ )
380
+
381
+ train_epochs = gr.Slider(
382
+ minimum=100,
383
+ maximum=4000,
384
+ step=100,
385
+ value=500,
386
+ label="Max Epochs",
387
+ )
388
+
389
+ train_batch_size = gr.Slider(
390
+ minimum=1,
391
+ maximum=8,
392
+ step=1,
393
+ value=1,
394
+ label="Batch Size",
395
+ info="Increase if you have enough VRAM",
396
+ )
397
+
398
+ gradient_accumulation = gr.Slider(
399
+ minimum=1,
400
+ maximum=16,
401
+ step=1,
402
+ value=1,
403
+ label="Gradient Accumulation",
404
+ info="Effective batch = batch_size × accumulation",
405
+ )
406
+
407
+ with gr.Row():
408
+ save_every_n_epochs = gr.Slider(
409
+ minimum=50,
410
+ maximum=1000,
411
+ step=50,
412
+ value=200,
413
+ label="Save Every N Epochs",
414
+ )
415
+
416
+ training_shift = gr.Slider(
417
+ minimum=1.0,
418
+ maximum=5.0,
419
+ step=0.5,
420
+ value=3.0,
421
+ label="Shift",
422
+ info="Timestep shift for turbo model",
423
+ )
424
+
425
+ training_seed = gr.Number(
426
+ label="Seed",
427
+ value=42,
428
+ precision=0,
429
+ )
430
+
431
+ with gr.Row():
432
+ lora_output_dir = gr.Textbox(
433
+ label="Output Directory",
434
+ value="./lora_output",
435
+ placeholder="./lora_output",
436
+ info="Directory to save trained LoRA weights",
437
+ )
438
+
439
+ gr.HTML("<hr>")
440
+
441
+ with gr.Row():
442
+ with gr.Column(scale=1):
443
+ start_training_btn = gr.Button(
444
+ "🚀 Start Training",
445
+ variant="primary",
446
+ size="lg",
447
+ )
448
+ with gr.Column(scale=1):
449
+ stop_training_btn = gr.Button(
450
+ "⏹️ Stop Training",
451
+ variant="stop",
452
+ size="lg",
453
+ )
454
+
455
+ training_progress = gr.Textbox(
456
+ label="Training Progress",
457
+ interactive=False,
458
+ lines=2,
459
+ )
460
+
461
+ with gr.Row():
462
+ training_log = gr.Textbox(
463
+ label="Training Log",
464
+ interactive=False,
465
+ lines=10,
466
+ max_lines=15,
467
+ scale=1,
468
+ )
469
+ training_loss_plot = gr.LinePlot(
470
+ x="step",
471
+ y="loss",
472
+ title="Training Loss",
473
+ x_title="Step",
474
+ y_title="Loss",
475
+ scale=1,
476
+ )
477
+
478
+ gr.HTML("<hr><h3>📦 Export LoRA</h3>")
479
+
480
+ with gr.Row():
481
+ export_path = gr.Textbox(
482
+ label="Export Path",
483
+ value="./lora_output/final_lora",
484
+ placeholder="./lora_output/my_lora",
485
+ )
486
+ export_lora_btn = gr.Button("📦 Export LoRA", variant="secondary")
487
+
488
+ export_status = gr.Textbox(
489
+ label="Export Status",
490
+ interactive=False,
491
+ )
492
+
493
+ # Store dataset builder state
494
+ dataset_builder_state = gr.State(None)
495
+ training_state = gr.State({"is_training": False, "should_stop": False})
496
+
497
+ return {
498
+ # Dataset Builder - Load or Scan
499
+ "load_json_path": load_json_path,
500
+ "load_json_btn": load_json_btn,
501
+ "load_json_status": load_json_status,
502
+ "audio_directory": audio_directory,
503
+ "scan_btn": scan_btn,
504
+ "scan_status": scan_status,
505
+ "audio_files_table": audio_files_table,
506
+ "dataset_name": dataset_name,
507
+ "all_instrumental": all_instrumental,
508
+ "need_lyrics": need_lyrics,
509
+ "custom_tag": custom_tag,
510
+ "tag_position": tag_position,
511
+ "skip_metas": skip_metas,
512
+ "auto_label_btn": auto_label_btn,
513
+ "label_progress": label_progress,
514
+ "sample_selector": sample_selector,
515
+ "preview_audio": preview_audio,
516
+ "preview_filename": preview_filename,
517
+ "edit_caption": edit_caption,
518
+ "edit_lyrics": edit_lyrics,
519
+ "edit_bpm": edit_bpm,
520
+ "edit_keyscale": edit_keyscale,
521
+ "edit_timesig": edit_timesig,
522
+ "edit_duration": edit_duration,
523
+ "edit_language": edit_language,
524
+ "edit_instrumental": edit_instrumental,
525
+ "save_edit_btn": save_edit_btn,
526
+ "edit_status": edit_status,
527
+ "save_path": save_path,
528
+ "save_dataset_btn": save_dataset_btn,
529
+ "save_status": save_status,
530
+ # Preprocessing
531
+ "load_existing_dataset_path": load_existing_dataset_path,
532
+ "load_existing_dataset_btn": load_existing_dataset_btn,
533
+ "load_existing_status": load_existing_status,
534
+ "preprocess_output_dir": preprocess_output_dir,
535
+ "preprocess_btn": preprocess_btn,
536
+ "preprocess_progress": preprocess_progress,
537
+ "dataset_builder_state": dataset_builder_state,
538
+ # Training
539
+ "training_tensor_dir": training_tensor_dir,
540
+ "load_dataset_btn": load_dataset_btn,
541
+ "training_dataset_info": training_dataset_info,
542
+ "lora_rank": lora_rank,
543
+ "lora_alpha": lora_alpha,
544
+ "lora_dropout": lora_dropout,
545
+ "learning_rate": learning_rate,
546
+ "train_epochs": train_epochs,
547
+ "train_batch_size": train_batch_size,
548
+ "gradient_accumulation": gradient_accumulation,
549
+ "save_every_n_epochs": save_every_n_epochs,
550
+ "training_shift": training_shift,
551
+ "training_seed": training_seed,
552
+ "lora_output_dir": lora_output_dir,
553
+ "start_training_btn": start_training_btn,
554
+ "stop_training_btn": stop_training_btn,
555
+ "training_progress": training_progress,
556
+ "training_log": training_log,
557
+ "training_loss_plot": training_loss_plot,
558
+ "export_path": export_path,
559
+ "export_lora_btn": export_lora_btn,
560
+ "export_status": export_status,
561
+ "training_state": training_state,
562
+ }
spaces/Ace-Step-v1.5/acestep/handler.py ADDED
The diff for this file is too large to render. See raw diff
 
spaces/Ace-Step-v1.5/acestep/inference.py ADDED
@@ -0,0 +1,1182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ACE-Step Inference API Module
3
+
4
+ This module provides a standardized inference interface for music generation,
5
+ designed for third-party integration. It offers both a simplified API and
6
+ backward-compatible Gradio UI support.
7
+ """
8
+
9
+ import math
10
+ import os
11
+ import tempfile
12
+ from typing import Optional, Union, List, Dict, Any, Tuple
13
+ from dataclasses import dataclass, field, asdict
14
+ from loguru import logger
15
+
16
+ from acestep.audio_utils import AudioSaver, generate_uuid_from_params
17
+
18
+ # HuggingFace Space environment detection
19
+ IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
20
+
21
+ def _get_spaces_gpu_decorator(duration=180):
22
+ """
23
+ Get the @spaces.GPU decorator if running in HuggingFace Space environment.
24
+ Returns identity decorator if not in Space environment.
25
+ """
26
+ if IS_HUGGINGFACE_SPACE:
27
+ try:
28
+ import spaces
29
+ return spaces.GPU(duration=duration)
30
+ except ImportError:
31
+ logger.warning("spaces package not found, GPU decorator disabled")
32
+ return lambda func: func
33
+ return lambda func: func
34
+
35
+
36
+ @dataclass
37
+ class GenerationParams:
38
+ """Configuration for music generation parameters.
39
+
40
+ Attributes:
41
+ # Text Inputs
42
+ caption: A short text prompt describing the desired music (main prompt). < 512 characters
43
+ lyrics: Lyrics for the music. Use "[Instrumental]" for instrumental songs. < 4096 characters
44
+ instrumental: If True, generate instrumental music regardless of lyrics.
45
+
46
+ # Music Metadata
47
+ bpm: BPM (beats per minute), e.g., 120. Set to None for automatic estimation. 30 ~ 300
48
+ keyscale: Musical key (e.g., "C Major", "Am"). Leave empty for auto-detection. A-G, #/♭, major/minor
49
+ timesignature: Time signature (2 for '2/4', 3 for '3/4', 4 for '4/4', 6 for '6/8'). Leave empty for auto-detection.
50
+ vocal_language: Language code for vocals, e.g., "en", "zh", "ja", or "unknown". see acestep/constants.py:VALID_LANGUAGES
51
+ duration: Target audio length in seconds. If <0 or None, model chooses automatically. 10 ~ 600
52
+
53
+ # Generation Parameters
54
+ inference_steps: Number of diffusion steps (e.g., 8 for turbo, 32–100 for base model).
55
+ guidance_scale: CFG (classifier-free guidance) strength. Higher means following the prompt more strictly. Only support for non-turbo model.
56
+ seed: Integer seed for reproducibility. -1 means use random seed each time.
57
+
58
+ # Advanced DiT Parameters
59
+ use_adg: Whether to use Adaptive Dual Guidance (only works for base model).
60
+ cfg_interval_start: Start ratio (0.0–1.0) to apply CFG.
61
+ cfg_interval_end: End ratio (0.0–1.0) to apply CFG.
62
+ shift: Timestep shift factor (default 1.0). When != 1.0, applies t = shift * t / (1 + (shift - 1) * t) to timesteps.
63
+
64
+ # Task-Specific Parameters
65
+ task_type: Type of generation task. One of: "text2music", "cover", "repaint", "lego", "extract", "complete".
66
+ reference_audio: Path to a reference audio file for style transfer or cover tasks.
67
+ src_audio: Path to a source audio file for audio-to-audio tasks.
68
+ audio_codes: Audio semantic codes as a string (advanced use, for code-control generation).
69
+ repainting_start: For repaint/lego tasks: start time in seconds for region to repaint.
70
+ repainting_end: For repaint/lego tasks: end time in seconds for region to repaint (-1 for until end).
71
+ audio_cover_strength: Strength of reference audio/codes influence (range 0.0–1.0). set smaller (0.2) for style transfer tasks.
72
+ instruction: Optional task instruction prompt. If empty, auto-generated by system.
73
+
74
+ # 5Hz Language Model Parameters for CoT reasoning
75
+ thinking: If True, enable 5Hz Language Model "Chain-of-Thought" reasoning for semantic/music metadata and codes.
76
+ lm_temperature: Sampling temperature for the LLM (0.0–2.0). Higher = more creative/varied results.
77
+ lm_cfg_scale: Classifier-free guidance scale for the LLM.
78
+ lm_top_k: LLM top-k sampling (0 = disabled).
79
+ lm_top_p: LLM top-p nucleus sampling (1.0 = disabled).
80
+ lm_negative_prompt: Negative prompt to use for LLM (for control).
81
+ use_cot_metas: Whether to let LLM generate music metadata via CoT reasoning.
82
+ use_cot_caption: Whether to let LLM rewrite or format the input caption via CoT reasoning.
83
+ use_cot_language: Whether to let LLM detect vocal language via CoT.
84
+ """
85
+ # Required Inputs
86
+ task_type: str = "text2music"
87
+ instruction: str = "Fill the audio semantic mask based on the given conditions:"
88
+
89
+ # Audio Uploads
90
+ reference_audio: Optional[str] = None
91
+ src_audio: Optional[str] = None
92
+
93
+ # LM Codes Hints
94
+ audio_codes: str = ""
95
+
96
+ # Text Inputs
97
+ caption: str = ""
98
+ lyrics: str = ""
99
+ instrumental: bool = False
100
+
101
+ # Metadata
102
+ vocal_language: str = "unknown"
103
+ bpm: Optional[int] = None
104
+ keyscale: str = ""
105
+ timesignature: str = ""
106
+ duration: float = -1.0
107
+
108
+ # Advanced Settings
109
+ inference_steps: int = 8
110
+ seed: int = -1
111
+ guidance_scale: float = 7.0
112
+ use_adg: bool = False
113
+ cfg_interval_start: float = 0.0
114
+ cfg_interval_end: float = 1.0
115
+ shift: float = 1.0
116
+ infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
117
+ # Custom timesteps (parsed from string like "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
118
+ # If provided, overrides inference_steps and shift
119
+ timesteps: Optional[List[float]] = None
120
+
121
+ repainting_start: float = 0.0
122
+ repainting_end: float = -1
123
+ audio_cover_strength: float = 1.0
124
+
125
+ # 5Hz Language Model Parameters
126
+ thinking: bool = True
127
+ lm_temperature: float = 0.85
128
+ lm_cfg_scale: float = 2.0
129
+ lm_top_k: int = 0
130
+ lm_top_p: float = 0.9
131
+ lm_negative_prompt: str = "NO USER INPUT"
132
+ use_cot_metas: bool = True
133
+ use_cot_caption: bool = True
134
+ use_cot_lyrics: bool = False # TODO: not used yet
135
+ use_cot_language: bool = True
136
+ use_constrained_decoding: bool = True
137
+
138
+ cot_bpm: Optional[int] = None
139
+ cot_keyscale: str = ""
140
+ cot_timesignature: str = ""
141
+ cot_duration: Optional[float] = None
142
+ cot_vocal_language: str = "unknown"
143
+ cot_caption: str = ""
144
+ cot_lyrics: str = ""
145
+
146
+ def to_dict(self) -> Dict[str, Any]:
147
+ """Convert config to dictionary for JSON serialization."""
148
+ return asdict(self)
149
+
150
+
151
+ @dataclass
152
+ class GenerationConfig:
153
+ """Configuration for music generation.
154
+
155
+ Attributes:
156
+ batch_size: Number of audio samples to generate
157
+ allow_lm_batch: Whether to allow batch processing in LM
158
+ use_random_seed: Whether to use random seed
159
+ seeds: Seed(s) for batch generation. Can be:
160
+ - None: Use random seeds (when use_random_seed=True) or params.seed (when use_random_seed=False)
161
+ - List[int]: List of seeds, will be padded with random seeds if fewer than batch_size
162
+ - int: Single seed value (will be converted to list and padded)
163
+ lm_batch_chunk_size: Batch chunk size for LM processing
164
+ constrained_decoding_debug: Whether to enable constrained decoding debug
165
+ audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac"
166
+ """
167
+ batch_size: int = 2
168
+ allow_lm_batch: bool = False
169
+ use_random_seed: bool = True
170
+ seeds: Optional[List[int]] = None
171
+ lm_batch_chunk_size: int = 8
172
+ constrained_decoding_debug: bool = False
173
+ audio_format: str = "flac" # Default to FLAC for fast saving
174
+
175
+ def to_dict(self) -> Dict[str, Any]:
176
+ """Convert config to dictionary for JSON serialization."""
177
+ return asdict(self)
178
+
179
+
180
+ @dataclass
181
+ class GenerationResult:
182
+ """Result of music generation.
183
+
184
+ Attributes:
185
+ # Audio Outputs
186
+ audios: List of audio dictionaries with paths, keys, params
187
+ status_message: Status message from generation
188
+ extra_outputs: Extra outputs from generation
189
+ success: Whether generation completed successfully
190
+ error: Error message if generation failed
191
+ """
192
+
193
+ # Audio Outputs
194
+ audios: List[Dict[str, Any]] = field(default_factory=list)
195
+ # Generation Information
196
+ status_message: str = ""
197
+ extra_outputs: Dict[str, Any] = field(default_factory=dict)
198
+ # Success Status
199
+ success: bool = True
200
+ error: Optional[str] = None
201
+
202
+ def to_dict(self) -> Dict[str, Any]:
203
+ """Convert result to dictionary for JSON serialization."""
204
+ return asdict(self)
205
+
206
+
207
+ @dataclass
208
+ class UnderstandResult:
209
+ """Result of music understanding from audio codes.
210
+
211
+ Attributes:
212
+ # Metadata Fields
213
+ caption: Generated caption describing the music
214
+ lyrics: Generated or extracted lyrics
215
+ bpm: Beats per minute (None if not detected)
216
+ duration: Duration in seconds (None if not detected)
217
+ keyscale: Musical key (e.g., "C Major")
218
+ language: Vocal language code (e.g., "en", "zh")
219
+ timesignature: Time signature (e.g., "4/4")
220
+
221
+ # Status
222
+ status_message: Status message from understanding
223
+ success: Whether understanding completed successfully
224
+ error: Error message if understanding failed
225
+ """
226
+ # Metadata Fields
227
+ caption: str = ""
228
+ lyrics: str = ""
229
+ bpm: Optional[int] = None
230
+ duration: Optional[float] = None
231
+ keyscale: str = ""
232
+ language: str = ""
233
+ timesignature: str = ""
234
+
235
+ # Status
236
+ status_message: str = ""
237
+ success: bool = True
238
+ error: Optional[str] = None
239
+
240
+ def to_dict(self) -> Dict[str, Any]:
241
+ """Convert result to dictionary for JSON serialization."""
242
+ return asdict(self)
243
+
244
+
245
+ def _update_metadata_from_lm(
246
+ metadata: Dict[str, Any],
247
+ bpm: Optional[int],
248
+ key_scale: str,
249
+ time_signature: str,
250
+ audio_duration: Optional[float],
251
+ vocal_language: str,
252
+ caption: str,
253
+ lyrics: str,
254
+ ) -> Tuple[Optional[int], str, str, Optional[float]]:
255
+ """Update metadata fields from LM output if not provided by user."""
256
+
257
+ if bpm is None and metadata.get('bpm'):
258
+ bpm_value = metadata.get('bpm')
259
+ if bpm_value not in ["N/A", ""]:
260
+ try:
261
+ bpm = int(bpm_value)
262
+ except (ValueError, TypeError):
263
+ pass
264
+
265
+ if not key_scale and metadata.get('keyscale'):
266
+ key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
267
+ if key_scale_value != "N/A":
268
+ key_scale = key_scale_value
269
+
270
+ if not time_signature and metadata.get('timesignature'):
271
+ time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
272
+ if time_signature_value != "N/A":
273
+ time_signature = time_signature_value
274
+
275
+ if audio_duration is None or audio_duration <= 0:
276
+ audio_duration_value = metadata.get('duration', -1)
277
+ if audio_duration_value not in ["N/A", ""]:
278
+ try:
279
+ audio_duration = float(audio_duration_value)
280
+ except (ValueError, TypeError):
281
+ pass
282
+
283
+ if not vocal_language and metadata.get('vocal_language'):
284
+ vocal_language = metadata.get('vocal_language')
285
+ if not caption and metadata.get('caption'):
286
+ caption = metadata.get('caption')
287
+ if not lyrics and metadata.get('lyrics'):
288
+ lyrics = metadata.get('lyrics')
289
+ return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
290
+
291
+
292
+ @_get_spaces_gpu_decorator(duration=180)
293
+ def generate_music(
294
+ dit_handler,
295
+ llm_handler,
296
+ params: GenerationParams,
297
+ config: GenerationConfig,
298
+ save_dir: Optional[str] = None,
299
+ progress=None,
300
+ ) -> GenerationResult:
301
+ """Generate music using ACE-Step model with optional LM reasoning.
302
+
303
+ Args:
304
+ dit_handler: Initialized DiT model handler (AceStepHandler instance)
305
+ llm_handler: Initialized LLM handler (LLMHandler instance)
306
+ params: Generation parameters (GenerationParams instance)
307
+ config: Generation configuration (GenerationConfig instance)
308
+
309
+ Returns:
310
+ GenerationResult with generated audio files and metadata
311
+ """
312
+ try:
313
+ # Phase 1: LM-based metadata and code generation (if enabled)
314
+ audio_code_string_to_use = params.audio_codes
315
+ lm_generated_metadata = None
316
+ lm_generated_audio_codes_list = []
317
+ lm_total_time_costs = {
318
+ "phase1_time": 0.0,
319
+ "phase2_time": 0.0,
320
+ "total_time": 0.0,
321
+ }
322
+
323
+ # Extract mutable copies of metadata (will be updated by LM if needed)
324
+ bpm = params.bpm
325
+ key_scale = params.keyscale
326
+ time_signature = params.timesignature
327
+ audio_duration = params.duration
328
+ dit_input_caption = params.caption
329
+ dit_input_vocal_language = params.vocal_language
330
+ dit_input_lyrics = params.lyrics
331
+ # Determine if we need to generate audio codes
332
+ # If user has provided audio_codes, we don't need to generate them
333
+ # Otherwise, check if we need audio codes (lm_dit mode) or just metas (dit mode)
334
+ user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip())
335
+
336
+ # Determine infer_type: use "llm_dit" if we need audio codes, "dit" if only metas needed
337
+ # For now, we use "llm_dit" if batch mode or if user hasn't provided codes
338
+ # Use "dit" if user has provided codes (only need metas) or if explicitly only need metas
339
+ # Note: This logic can be refined based on specific requirements
340
+ need_audio_codes = not user_provided_audio_codes
341
+
342
+ # Determine if we should use chunk-based LM generation (always use chunks for consistency)
343
+ # Determine actual batch size for chunk processing
344
+ actual_batch_size = config.batch_size if config.batch_size is not None else 1
345
+
346
+ # Prepare seeds for batch generation
347
+ # Use config.seed if provided, otherwise fallback to params.seed
348
+ # Convert config.seed (None, int, or List[int]) to format that prepare_seeds accepts
349
+ seed_for_generation = ""
350
+ if config.seeds is not None and len(config.seeds) > 0:
351
+ if isinstance(config.seeds, list):
352
+ # Convert List[int] to comma-separated string
353
+ seed_for_generation = ",".join(str(s) for s in config.seeds)
354
+
355
+ # Use dit_handler.prepare_seeds to handle seed list generation and padding
356
+ # This will handle all the logic: padding with random seeds if needed, etc.
357
+ actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)
358
+
359
+ # LM-based Chain-of-Thought reasoning
360
+ # Skip LM for cover/repaint tasks - these tasks use reference/src audio directly
361
+ # and don't need LM to generate audio codes
362
+ skip_lm_tasks = {"cover", "repaint"}
363
+
364
+ # Determine if we should use LLM
365
+ # LLM is needed for:
366
+ # 1. thinking=True: generate audio codes via LM
367
+ # 2. use_cot_caption=True: enhance/generate caption via CoT
368
+ # 3. use_cot_language=True: detect vocal language via CoT
369
+ # 4. use_cot_metas=True: fill missing metadata via CoT
370
+ need_lm_for_cot = params.use_cot_caption or params.use_cot_language or params.use_cot_metas
371
+ use_lm = (params.thinking or need_lm_for_cot) and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks
372
+ lm_status = []
373
+
374
+ if params.task_type in skip_lm_tasks:
375
+ logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly")
376
+
377
+ logger.info(f"[generate_music] LLM usage decision: thinking={params.thinking}, "
378
+ f"use_cot_caption={params.use_cot_caption}, use_cot_language={params.use_cot_language}, "
379
+ f"use_cot_metas={params.use_cot_metas}, need_lm_for_cot={need_lm_for_cot}, "
380
+ f"llm_initialized={llm_handler.llm_initialized if llm_handler else False}, use_lm={use_lm}")
381
+
382
+ if use_lm:
383
+ # Convert sampling parameters - handle None values safely
384
+ top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
385
+ top_p_value = None if not params.lm_top_p or params.lm_top_p >= 1.0 else params.lm_top_p
386
+
387
+ # Build user_metadata from user-provided values
388
+ user_metadata = {}
389
+ if bpm is not None:
390
+ try:
391
+ bpm_value = float(bpm)
392
+ if bpm_value > 0:
393
+ user_metadata['bpm'] = int(bpm_value)
394
+ except (ValueError, TypeError):
395
+ pass
396
+
397
+ if key_scale and key_scale.strip():
398
+ key_scale_clean = key_scale.strip()
399
+ if key_scale_clean.lower() not in ["n/a", ""]:
400
+ user_metadata['keyscale'] = key_scale_clean
401
+
402
+ if time_signature and time_signature.strip():
403
+ time_sig_clean = time_signature.strip()
404
+ if time_sig_clean.lower() not in ["n/a", ""]:
405
+ user_metadata['timesignature'] = time_sig_clean
406
+
407
+ if audio_duration is not None:
408
+ try:
409
+ duration_value = float(audio_duration)
410
+ if duration_value > 0:
411
+ user_metadata['duration'] = int(duration_value)
412
+ except (ValueError, TypeError):
413
+ pass
414
+
415
+ user_metadata_to_pass = user_metadata if user_metadata else None
416
+
417
+ # Determine infer_type based on whether we need audio codes
418
+ # - "llm_dit": generates both metas and audio codes (two-phase internally)
419
+ # - "dit": generates only metas (single phase)
420
+ infer_type = "llm_dit" if need_audio_codes and params.thinking else "dit"
421
+
422
+ # Use chunk size from config, or default to batch_size if not set
423
+ max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size
424
+ num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)
425
+
426
+ all_metadata_list = []
427
+ all_audio_codes_list = []
428
+
429
+ for chunk_idx in range(num_chunks):
430
+ chunk_start = chunk_idx * max_inference_batch_size
431
+ chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
432
+ chunk_size = chunk_end - chunk_start
433
+ chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None
434
+
435
+ logger.info(f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
436
+ f"(size: {chunk_size}, seeds: {chunk_seeds})")
437
+
438
+ # Use the determined infer_type
439
+ # - "llm_dit" will internally run two phases (metas + codes)
440
+ # - "dit" will only run phase 1 (metas only)
441
+ result = llm_handler.generate_with_stop_condition(
442
+ caption=params.caption or "",
443
+ lyrics=params.lyrics or "",
444
+ infer_type=infer_type,
445
+ temperature=params.lm_temperature,
446
+ cfg_scale=params.lm_cfg_scale,
447
+ negative_prompt=params.lm_negative_prompt,
448
+ top_k=top_k_value,
449
+ top_p=top_p_value,
450
+ user_metadata=user_metadata_to_pass,
451
+ use_cot_caption=params.use_cot_caption,
452
+ use_cot_language=params.use_cot_language,
453
+ use_cot_metas=params.use_cot_metas,
454
+ use_constrained_decoding=params.use_constrained_decoding,
455
+ constrained_decoding_debug=config.constrained_decoding_debug,
456
+ batch_size=chunk_size,
457
+ seeds=chunk_seeds,
458
+ progress=progress,
459
+ )
460
+
461
+ # Check if LM generation failed
462
+ if not result.get("success", False):
463
+ error_msg = result.get("error", "Unknown LM error")
464
+ lm_status.append(f"❌ LM Error: {error_msg}")
465
+ # Return early with error
466
+ return GenerationResult(
467
+ audios=[],
468
+ status_message=f"❌ LM generation failed: {error_msg}",
469
+ extra_outputs={},
470
+ success=False,
471
+ error=error_msg,
472
+ )
473
+
474
+ # Extract metadata and audio_codes from result dict
475
+ if chunk_size > 1:
476
+ metadata_list = result.get("metadata", [])
477
+ audio_codes_list = result.get("audio_codes", [])
478
+ all_metadata_list.extend(metadata_list)
479
+ all_audio_codes_list.extend(audio_codes_list)
480
+ else:
481
+ metadata = result.get("metadata", {})
482
+ audio_codes = result.get("audio_codes", "")
483
+ all_metadata_list.append(metadata)
484
+ all_audio_codes_list.append(audio_codes)
485
+
486
+ # Collect time costs from LM extra_outputs
487
+ lm_extra = result.get("extra_outputs", {})
488
+ lm_chunk_time_costs = lm_extra.get("time_costs", {})
489
+ if lm_chunk_time_costs:
490
+ # Accumulate time costs from all chunks
491
+ for key in ["phase1_time", "phase2_time", "total_time"]:
492
+ if key in lm_chunk_time_costs:
493
+ lm_total_time_costs[key] += lm_chunk_time_costs[key]
494
+
495
+ time_str = ", ".join([f"{k}: {v:.2f}s" for k, v in lm_chunk_time_costs.items()])
496
+ lm_status.append(f"✅ LM chunk {chunk_idx+1}: {time_str}")
497
+
498
+ lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
499
+ lm_generated_audio_codes_list = all_audio_codes_list
500
+
501
+ # Set audio_code_string_to_use based on infer_type
502
+ if infer_type == "llm_dit":
503
+ # If batch mode, use list; otherwise use single string
504
+ if actual_batch_size > 1:
505
+ audio_code_string_to_use = all_audio_codes_list
506
+ else:
507
+ audio_code_string_to_use = all_audio_codes_list[0] if all_audio_codes_list else ""
508
+ else:
509
+ # For "dit" mode, keep user-provided codes or empty
510
+ audio_code_string_to_use = params.audio_codes
511
+
512
+ # Update metadata from LM if not provided by user
513
+ if lm_generated_metadata:
514
+ bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics = _update_metadata_from_lm(
515
+ metadata=lm_generated_metadata,
516
+ bpm=bpm,
517
+ key_scale=key_scale,
518
+ time_signature=time_signature,
519
+ audio_duration=audio_duration,
520
+ vocal_language=dit_input_vocal_language,
521
+ caption=dit_input_caption,
522
+ lyrics=dit_input_lyrics)
523
+ if not params.bpm:
524
+ params.cot_bpm = bpm
525
+ if not params.keyscale:
526
+ params.cot_keyscale = key_scale
527
+ if not params.timesignature:
528
+ params.cot_timesignature = time_signature
529
+ if not params.duration:
530
+ params.cot_duration = audio_duration
531
+ if not params.vocal_language:
532
+ params.cot_vocal_language = vocal_language
533
+ if not params.caption:
534
+ params.cot_caption = caption
535
+ if not params.lyrics:
536
+ params.cot_lyrics = lyrics
537
+
538
+ # set cot caption and language if needed
539
+ if params.use_cot_caption:
540
+ dit_input_caption = lm_generated_metadata.get("caption", dit_input_caption)
541
+ if params.use_cot_language:
542
+ dit_input_vocal_language = lm_generated_metadata.get("vocal_language", dit_input_vocal_language)
543
+
544
+ # Phase 2: DiT music generation
545
+ # Use seed_for_generation (from config.seed or params.seed) instead of params.seed for actual generation
546
+ result = dit_handler.generate_music(
547
+ captions=dit_input_caption,
548
+ lyrics=dit_input_lyrics,
549
+ bpm=bpm,
550
+ key_scale=key_scale,
551
+ time_signature=time_signature,
552
+ vocal_language=dit_input_vocal_language,
553
+ inference_steps=params.inference_steps,
554
+ guidance_scale=params.guidance_scale,
555
+ use_random_seed=config.use_random_seed,
556
+ seed=seed_for_generation, # Use config.seed (or params.seed fallback) instead of params.seed directly
557
+ reference_audio=params.reference_audio,
558
+ audio_duration=audio_duration,
559
+ batch_size=config.batch_size if config.batch_size is not None else 1,
560
+ src_audio=params.src_audio,
561
+ audio_code_string=audio_code_string_to_use,
562
+ repainting_start=params.repainting_start,
563
+ repainting_end=params.repainting_end,
564
+ instruction=params.instruction,
565
+ audio_cover_strength=params.audio_cover_strength,
566
+ task_type=params.task_type,
567
+ use_adg=params.use_adg,
568
+ cfg_interval_start=params.cfg_interval_start,
569
+ cfg_interval_end=params.cfg_interval_end,
570
+ shift=params.shift,
571
+ infer_method=params.infer_method,
572
+ timesteps=params.timesteps,
573
+ progress=progress,
574
+ )
575
+
576
+ # Check if generation failed
577
+ if not result.get("success", False):
578
+ return GenerationResult(
579
+ audios=[],
580
+ status_message=result.get("status_message", ""),
581
+ extra_outputs={},
582
+ success=False,
583
+ error=result.get("error"),
584
+ )
585
+
586
+ # Extract results from dit_handler.generate_music dict
587
+ dit_audios = result.get("audios", [])
588
+ status_message = result.get("status_message", "")
589
+ dit_extra_outputs = result.get("extra_outputs", {})
590
+
591
+ # Use the seed list already prepared above (from config.seed or params.seed fallback)
592
+ # actual_seed_list was computed earlier using dit_handler.prepare_seeds
593
+ seed_list = actual_seed_list
594
+
595
+ # Get base params dictionary
596
+ base_params_dict = params.to_dict()
597
+
598
+ # Save audio files using AudioSaver (format from config)
599
+ audio_format = config.audio_format if config.audio_format else "flac"
600
+ audio_saver = AudioSaver(default_format=audio_format)
601
+
602
+ # Use handler's temp_dir for saving files
603
+ if save_dir is not None:
604
+ os.makedirs(save_dir, exist_ok=True)
605
+
606
+ # Build audios list for GenerationResult with params and save files
607
+ # Audio saving and UUID generation handled here, outside of handler
608
+ audios = []
609
+ for idx, dit_audio in enumerate(dit_audios):
610
+ # Create a copy of params dict for this audio
611
+ audio_params = base_params_dict.copy()
612
+
613
+ # Update audio-specific values
614
+ audio_params["seed"] = seed_list[idx] if idx < len(seed_list) else None
615
+
616
+ # Add audio codes if batch mode
617
+ if lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list):
618
+ audio_params["audio_codes"] = lm_generated_audio_codes_list[idx]
619
+
620
+ # Get audio tensor and metadata
621
+ audio_tensor = dit_audio.get("tensor")
622
+ sample_rate = dit_audio.get("sample_rate", 48000)
623
+
624
+ # Generate UUID for this audio (moved from handler)
625
+ batch_seed = seed_list[idx] if idx < len(seed_list) else seed_list[0] if seed_list else -1
626
+ audio_code_str = lm_generated_audio_codes_list[idx] if (
627
+ lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list)) else audio_code_string_to_use
628
+ if isinstance(audio_code_str, list):
629
+ audio_code_str = audio_code_str[idx] if idx < len(audio_code_str) else ""
630
+
631
+ audio_key = generate_uuid_from_params(audio_params)
632
+
633
+ # Save audio file (handled outside handler)
634
+ audio_path = None
635
+ if audio_tensor is not None and save_dir is not None:
636
+ try:
637
+ audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}")
638
+ audio_path = audio_saver.save_audio(audio_tensor,
639
+ audio_file,
640
+ sample_rate=sample_rate,
641
+ format=audio_format,
642
+ channels_first=True)
643
+ except Exception as e:
644
+ logger.error(f"[generate_music] Failed to save audio file: {e}")
645
+ audio_path = "" # Fallback to empty path
646
+
647
+ audio_dict = {
648
+ "path": audio_path or "", # File path (saved here, not in handler)
649
+ "tensor": audio_tensor, # Audio tensor [channels, samples], CPU, float32
650
+ "key": audio_key,
651
+ "sample_rate": sample_rate,
652
+ "params": audio_params,
653
+ }
654
+
655
+ audios.append(audio_dict)
656
+
657
+ # Merge extra_outputs: include dit_extra_outputs (latents, masks) and add LM metadata
658
+ extra_outputs = dit_extra_outputs.copy()
659
+ extra_outputs["lm_metadata"] = lm_generated_metadata
660
+
661
+ # Merge time_costs from both LM and DiT into a unified dictionary
662
+ unified_time_costs = {}
663
+
664
+ # Add LM time costs (if LM was used)
665
+ if use_lm and lm_total_time_costs:
666
+ for key, value in lm_total_time_costs.items():
667
+ unified_time_costs[f"lm_{key}"] = value
668
+
669
+ # Add DiT time costs (if available)
670
+ dit_time_costs = dit_extra_outputs.get("time_costs", {})
671
+ if dit_time_costs:
672
+ for key, value in dit_time_costs.items():
673
+ unified_time_costs[f"dit_{key}"] = value
674
+
675
+ # Calculate total pipeline time
676
+ if unified_time_costs:
677
+ lm_total = unified_time_costs.get("lm_total_time", 0.0)
678
+ dit_total = unified_time_costs.get("dit_total_time_cost", 0.0)
679
+ unified_time_costs["pipeline_total_time"] = lm_total + dit_total
680
+
681
+ # Update extra_outputs with unified time_costs
682
+ extra_outputs["time_costs"] = unified_time_costs
683
+
684
+ if lm_status:
685
+ status_message = "\n".join(lm_status) + "\n" + status_message
686
+ else:
687
+ status_message = status_message
688
+ # Create and return GenerationResult
689
+ return GenerationResult(
690
+ audios=audios,
691
+ status_message=status_message,
692
+ extra_outputs=extra_outputs,
693
+ success=True,
694
+ error=None,
695
+ )
696
+
697
+ except Exception as e:
698
+ logger.exception("Music generation failed")
699
+ return GenerationResult(
700
+ audios=[],
701
+ status_message=f"Error: {str(e)}",
702
+ extra_outputs={},
703
+ success=False,
704
+ error=str(e),
705
+ )
706
+
707
+
708
+ def understand_music(
709
+ llm_handler,
710
+ audio_codes: str,
711
+ temperature: float = 0.85,
712
+ top_k: Optional[int] = None,
713
+ top_p: Optional[float] = None,
714
+ repetition_penalty: float = 1.0,
715
+ use_constrained_decoding: bool = True,
716
+ constrained_decoding_debug: bool = False,
717
+ ) -> UnderstandResult:
718
+ """Understand music from audio codes using the 5Hz Language Model.
719
+
720
+ This function analyzes audio semantic codes and generates metadata about the music,
721
+ including caption, lyrics, BPM, duration, key scale, language, and time signature.
722
+
723
+ If audio_codes is empty or "NO USER INPUT", the LM will generate a sample example
724
+ instead of analyzing existing codes.
725
+
726
+ Note: cfg_scale and negative_prompt are not supported in understand mode.
727
+
728
+ Args:
729
+ llm_handler: Initialized LLM handler (LLMHandler instance)
730
+ audio_codes: String of audio code tokens (e.g., "<|audio_code_123|><|audio_code_456|>...")
731
+ Use empty string or "NO USER INPUT" to generate a sample example.
732
+ temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
733
+ top_k: Top-K sampling (None or 0 = disabled)
734
+ top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
735
+ repetition_penalty: Repetition penalty (1.0 = no penalty)
736
+ use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
737
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
738
+
739
+ Returns:
740
+ UnderstandResult with parsed metadata fields and status
741
+
742
+ Example:
743
+ >>> result = understand_music(llm_handler, audio_codes="<|audio_code_123|>...")
744
+ >>> if result.success:
745
+ ... print(f"Caption: {result.caption}")
746
+ ... print(f"BPM: {result.bpm}")
747
+ ... print(f"Lyrics: {result.lyrics}")
748
+ """
749
+ # Check if LLM is initialized
750
+ if not llm_handler.llm_initialized:
751
+ return UnderstandResult(
752
+ status_message="5Hz LM not initialized. Please initialize it first.",
753
+ success=False,
754
+ error="LLM not initialized",
755
+ )
756
+
757
+ # If codes are empty, use "NO USER INPUT" to generate a sample example
758
+ if not audio_codes or not audio_codes.strip():
759
+ audio_codes = "NO USER INPUT"
760
+
761
+ try:
762
+ # Call LLM understanding
763
+ metadata, status = llm_handler.understand_audio_from_codes(
764
+ audio_codes=audio_codes,
765
+ temperature=temperature,
766
+ top_k=top_k,
767
+ top_p=top_p,
768
+ repetition_penalty=repetition_penalty,
769
+ use_constrained_decoding=use_constrained_decoding,
770
+ constrained_decoding_debug=constrained_decoding_debug,
771
+ )
772
+
773
+ # Check if LLM returned empty metadata (error case)
774
+ if not metadata:
775
+ return UnderstandResult(
776
+ status_message=status or "Failed to understand audio codes",
777
+ success=False,
778
+ error=status or "Empty metadata returned",
779
+ )
780
+
781
+ # Extract and convert fields
782
+ caption = metadata.get('caption', '')
783
+ lyrics = metadata.get('lyrics', '')
784
+ keyscale = metadata.get('keyscale', '')
785
+ language = metadata.get('language', metadata.get('vocal_language', ''))
786
+ timesignature = metadata.get('timesignature', '')
787
+
788
+ # Convert BPM to int
789
+ bpm = None
790
+ bpm_value = metadata.get('bpm')
791
+ if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
792
+ try:
793
+ bpm = int(bpm_value)
794
+ except (ValueError, TypeError):
795
+ pass
796
+
797
+ # Convert duration to float
798
+ duration = None
799
+ duration_value = metadata.get('duration')
800
+ if duration_value is not None and duration_value != 'N/A' and duration_value != '':
801
+ try:
802
+ duration = float(duration_value)
803
+ except (ValueError, TypeError):
804
+ pass
805
+
806
+ # Clean up N/A values
807
+ if keyscale == 'N/A':
808
+ keyscale = ''
809
+ if language == 'N/A':
810
+ language = ''
811
+ if timesignature == 'N/A':
812
+ timesignature = ''
813
+
814
+ return UnderstandResult(
815
+ caption=caption,
816
+ lyrics=lyrics,
817
+ bpm=bpm,
818
+ duration=duration,
819
+ keyscale=keyscale,
820
+ language=language,
821
+ timesignature=timesignature,
822
+ status_message=status,
823
+ success=True,
824
+ error=None,
825
+ )
826
+
827
+ except Exception as e:
828
+ logger.exception("Music understanding failed")
829
+ return UnderstandResult(
830
+ status_message=f"Error: {str(e)}",
831
+ success=False,
832
+ error=str(e),
833
+ )
834
+
835
+
836
+ @dataclass
837
+ class CreateSampleResult:
838
+ """Result of creating a music sample from a natural language query.
839
+
840
+ This is used by the "Simple Mode" / "Inspiration Mode" feature where users
841
+ provide a natural language description and the LLM generates a complete
842
+ sample with caption, lyrics, and metadata.
843
+
844
+ Attributes:
845
+ # Metadata Fields
846
+ caption: Generated detailed music description/caption
847
+ lyrics: Generated lyrics (or "[Instrumental]" for instrumental music)
848
+ bpm: Beats per minute (None if not generated)
849
+ duration: Duration in seconds (None if not generated)
850
+ keyscale: Musical key (e.g., "C Major")
851
+ language: Vocal language code (e.g., "en", "zh")
852
+ timesignature: Time signature (e.g., "4")
853
+ instrumental: Whether this is an instrumental piece
854
+
855
+ # Status
856
+ status_message: Status message from sample creation
857
+ success: Whether sample creation completed successfully
858
+ error: Error message if sample creation failed
859
+ """
860
+ # Metadata Fields
861
+ caption: str = ""
862
+ lyrics: str = ""
863
+ bpm: Optional[int] = None
864
+ duration: Optional[float] = None
865
+ keyscale: str = ""
866
+ language: str = ""
867
+ timesignature: str = ""
868
+ instrumental: bool = False
869
+
870
+ # Status
871
+ status_message: str = ""
872
+ success: bool = True
873
+ error: Optional[str] = None
874
+
875
+ def to_dict(self) -> Dict[str, Any]:
876
+ """Convert result to dictionary for JSON serialization."""
877
+ return asdict(self)
878
+
879
+
880
+ def create_sample(
881
+ llm_handler,
882
+ query: str,
883
+ instrumental: bool = False,
884
+ vocal_language: Optional[str] = None,
885
+ temperature: float = 0.85,
886
+ top_k: Optional[int] = None,
887
+ top_p: Optional[float] = None,
888
+ repetition_penalty: float = 1.0,
889
+ use_constrained_decoding: bool = True,
890
+ constrained_decoding_debug: bool = False,
891
+ ) -> CreateSampleResult:
892
+ """Create a music sample from a natural language query using the 5Hz Language Model.
893
+
894
+ This is the "Simple Mode" / "Inspiration Mode" feature that takes a user's natural
895
+ language description of music and generates a complete sample including:
896
+ - Detailed caption/description
897
+ - Lyrics (unless instrumental)
898
+ - Metadata (BPM, duration, key, language, time signature)
899
+
900
+ Note: cfg_scale and negative_prompt are not supported in create_sample mode.
901
+
902
+ Args:
903
+ llm_handler: Initialized LLM handler (LLMHandler instance)
904
+ query: User's natural language music description (e.g., "a soft Bengali love song")
905
+ instrumental: Whether to generate instrumental music (no vocals)
906
+ vocal_language: Allowed vocal language for constrained decoding (e.g., "en", "zh").
907
+ If provided, the model will be constrained to generate lyrics in this language.
908
+ If None or "unknown", no language constraint is applied.
909
+ temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
910
+ top_k: Top-K sampling (None or 0 = disabled)
911
+ top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
912
+ repetition_penalty: Repetition penalty (1.0 = no penalty)
913
+ use_constrained_decoding: Whether to use FSM-based constrained decoding
914
+ constrained_decoding_debug: Whether to enable debug logging
915
+
916
+ Returns:
917
+ CreateSampleResult with generated sample fields and status
918
+
919
+ Example:
920
+ >>> result = create_sample(llm_handler, "a soft Bengali love song for a quiet evening", vocal_language="bn")
921
+ >>> if result.success:
922
+ ... print(f"Caption: {result.caption}")
923
+ ... print(f"Lyrics: {result.lyrics}")
924
+ ... print(f"BPM: {result.bpm}")
925
+ """
926
+ # Check if LLM is initialized
927
+ if not llm_handler.llm_initialized:
928
+ return CreateSampleResult(
929
+ status_message="5Hz LM not initialized. Please initialize it first.",
930
+ success=False,
931
+ error="LLM not initialized",
932
+ )
933
+
934
+ try:
935
+ # Call LLM to create sample
936
+ metadata, status = llm_handler.create_sample_from_query(
937
+ query=query,
938
+ instrumental=instrumental,
939
+ vocal_language=vocal_language,
940
+ temperature=temperature,
941
+ top_k=top_k,
942
+ top_p=top_p,
943
+ repetition_penalty=repetition_penalty,
944
+ use_constrained_decoding=use_constrained_decoding,
945
+ constrained_decoding_debug=constrained_decoding_debug,
946
+ )
947
+
948
+ # Check if LLM returned empty metadata (error case)
949
+ if not metadata:
950
+ return CreateSampleResult(
951
+ status_message=status or "Failed to create sample",
952
+ success=False,
953
+ error=status or "Empty metadata returned",
954
+ )
955
+
956
+ # Extract and convert fields
957
+ caption = metadata.get('caption', '')
958
+ lyrics = metadata.get('lyrics', '')
959
+ keyscale = metadata.get('keyscale', '')
960
+ language = metadata.get('language', metadata.get('vocal_language', ''))
961
+ timesignature = metadata.get('timesignature', '')
962
+ is_instrumental = metadata.get('instrumental', instrumental)
963
+
964
+ # Convert BPM to int
965
+ bpm = None
966
+ bpm_value = metadata.get('bpm')
967
+ if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
968
+ try:
969
+ bpm = int(bpm_value)
970
+ except (ValueError, TypeError):
971
+ pass
972
+
973
+ # Convert duration to float
974
+ duration = None
975
+ duration_value = metadata.get('duration')
976
+ if duration_value is not None and duration_value != 'N/A' and duration_value != '':
977
+ try:
978
+ duration = float(duration_value)
979
+ except (ValueError, TypeError):
980
+ pass
981
+
982
+ # Clean up N/A values
983
+ if keyscale == 'N/A':
984
+ keyscale = ''
985
+ if language == 'N/A':
986
+ language = ''
987
+ if timesignature == 'N/A':
988
+ timesignature = ''
989
+
990
+ return CreateSampleResult(
991
+ caption=caption,
992
+ lyrics=lyrics,
993
+ bpm=bpm,
994
+ duration=duration,
995
+ keyscale=keyscale,
996
+ language=language,
997
+ timesignature=timesignature,
998
+ instrumental=is_instrumental,
999
+ status_message=status,
1000
+ success=True,
1001
+ error=None,
1002
+ )
1003
+
1004
+ except Exception as e:
1005
+ logger.exception("Sample creation failed")
1006
+ return CreateSampleResult(
1007
+ status_message=f"Error: {str(e)}",
1008
+ success=False,
1009
+ error=str(e),
1010
+ )
1011
+
1012
+
1013
+ @dataclass
1014
+ class FormatSampleResult:
1015
+ """Result of formatting user-provided caption and lyrics.
1016
+
1017
+ This is used by the "Format" feature where users provide caption and lyrics,
1018
+ and the LLM formats them into structured music metadata and an enhanced description.
1019
+
1020
+ Attributes:
1021
+ # Metadata Fields
1022
+ caption: Enhanced/formatted music description/caption
1023
+ lyrics: Formatted lyrics (may be same as input or reformatted)
1024
+ bpm: Beats per minute (None if not detected)
1025
+ duration: Duration in seconds (None if not detected)
1026
+ keyscale: Musical key (e.g., "C Major")
1027
+ language: Vocal language code (e.g., "en", "zh")
1028
+ timesignature: Time signature (e.g., "4")
1029
+
1030
+ # Status
1031
+ status_message: Status message from formatting
1032
+ success: Whether formatting completed successfully
1033
+ error: Error message if formatting failed
1034
+ """
1035
+ # Metadata Fields
1036
+ caption: str = ""
1037
+ lyrics: str = ""
1038
+ bpm: Optional[int] = None
1039
+ duration: Optional[float] = None
1040
+ keyscale: str = ""
1041
+ language: str = ""
1042
+ timesignature: str = ""
1043
+
1044
+ # Status
1045
+ status_message: str = ""
1046
+ success: bool = True
1047
+ error: Optional[str] = None
1048
+
1049
+ def to_dict(self) -> Dict[str, Any]:
1050
+ """Convert result to dictionary for JSON serialization."""
1051
+ return asdict(self)
1052
+
1053
+
1054
+ def format_sample(
1055
+ llm_handler,
1056
+ caption: str,
1057
+ lyrics: str,
1058
+ user_metadata: Optional[Dict[str, Any]] = None,
1059
+ temperature: float = 0.85,
1060
+ top_k: Optional[int] = None,
1061
+ top_p: Optional[float] = None,
1062
+ repetition_penalty: float = 1.0,
1063
+ use_constrained_decoding: bool = True,
1064
+ constrained_decoding_debug: bool = False,
1065
+ ) -> FormatSampleResult:
1066
+ """Format user-provided caption and lyrics using the 5Hz Language Model.
1067
+
1068
+ This function takes user input (caption and lyrics) and generates structured
1069
+ music metadata including an enhanced caption, BPM, duration, key, language,
1070
+ and time signature.
1071
+
1072
+ If user_metadata is provided, those values will be used to constrain the
1073
+ decoding, ensuring the output matches user-specified values.
1074
+
1075
+ Note: cfg_scale and negative_prompt are not supported in format mode.
1076
+
1077
+ Args:
1078
+ llm_handler: Initialized LLM handler (LLMHandler instance)
1079
+ caption: User's caption/description (e.g., "Latin pop, reggaeton")
1080
+ lyrics: User's lyrics with structure tags
1081
+ user_metadata: Optional dict with user-provided metadata to constrain decoding.
1082
+ Supported keys: bpm, duration, keyscale, timesignature, language
1083
+ temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
1084
+ top_k: Top-K sampling (None or 0 = disabled)
1085
+ top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
1086
+ repetition_penalty: Repetition penalty (1.0 = no penalty)
1087
+ use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
1088
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
1089
+
1090
+ Returns:
1091
+ FormatSampleResult with formatted metadata fields and status
1092
+
1093
+ Example:
1094
+ >>> result = format_sample(llm_handler, "Latin pop, reggaeton", "[Verse 1]\\nHola mundo...")
1095
+ >>> if result.success:
1096
+ ... print(f"Caption: {result.caption}")
1097
+ ... print(f"BPM: {result.bpm}")
1098
+ ... print(f"Lyrics: {result.lyrics}")
1099
+ """
1100
+ # Check if LLM is initialized
1101
+ if not llm_handler.llm_initialized:
1102
+ return FormatSampleResult(
1103
+ status_message="5Hz LM not initialized. Please initialize it first.",
1104
+ success=False,
1105
+ error="LLM not initialized",
1106
+ )
1107
+
1108
+ try:
1109
+ # Call LLM formatting
1110
+ metadata, status = llm_handler.format_sample_from_input(
1111
+ caption=caption,
1112
+ lyrics=lyrics,
1113
+ user_metadata=user_metadata,
1114
+ temperature=temperature,
1115
+ top_k=top_k,
1116
+ top_p=top_p,
1117
+ repetition_penalty=repetition_penalty,
1118
+ use_constrained_decoding=use_constrained_decoding,
1119
+ constrained_decoding_debug=constrained_decoding_debug,
1120
+ )
1121
+
1122
+ # Check if LLM returned empty metadata (error case)
1123
+ if not metadata:
1124
+ return FormatSampleResult(
1125
+ status_message=status or "Failed to format input",
1126
+ success=False,
1127
+ error=status or "Empty metadata returned",
1128
+ )
1129
+
1130
+ # Extract and convert fields
1131
+ result_caption = metadata.get('caption', '')
1132
+ result_lyrics = metadata.get('lyrics', lyrics) # Fall back to input lyrics
1133
+ keyscale = metadata.get('keyscale', '')
1134
+ language = metadata.get('language', metadata.get('vocal_language', ''))
1135
+ timesignature = metadata.get('timesignature', '')
1136
+
1137
+ # Convert BPM to int
1138
+ bpm = None
1139
+ bpm_value = metadata.get('bpm')
1140
+ if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
1141
+ try:
1142
+ bpm = int(bpm_value)
1143
+ except (ValueError, TypeError):
1144
+ pass
1145
+
1146
+ # Convert duration to float
1147
+ duration = None
1148
+ duration_value = metadata.get('duration')
1149
+ if duration_value is not None and duration_value != 'N/A' and duration_value != '':
1150
+ try:
1151
+ duration = float(duration_value)
1152
+ except (ValueError, TypeError):
1153
+ pass
1154
+
1155
+ # Clean up N/A values
1156
+ if keyscale == 'N/A':
1157
+ keyscale = ''
1158
+ if language == 'N/A':
1159
+ language = ''
1160
+ if timesignature == 'N/A':
1161
+ timesignature = ''
1162
+
1163
+ return FormatSampleResult(
1164
+ caption=result_caption,
1165
+ lyrics=result_lyrics,
1166
+ bpm=bpm,
1167
+ duration=duration,
1168
+ keyscale=keyscale,
1169
+ language=language,
1170
+ timesignature=timesignature,
1171
+ status_message=status,
1172
+ success=True,
1173
+ error=None,
1174
+ )
1175
+
1176
+ except Exception as e:
1177
+ logger.exception("Format sample failed")
1178
+ return FormatSampleResult(
1179
+ status_message=f"Error: {str(e)}",
1180
+ success=False,
1181
+ error=str(e),
1182
+ )
spaces/Ace-Step-v1.5/acestep/llm_inference.py ADDED
The diff for this file is too large to render. See raw diff
 
spaces/Ace-Step-v1.5/acestep/local_cache.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Local cache module to replace Redis
2
+
3
+ Uses diskcache as backend, provides Redis-compatible API.
4
+ Supports persistent storage and TTL expiration.
5
+ """
6
+
7
+ import json
8
+ import os
9
+ from typing import Any, Optional
10
+ from threading import Lock
11
+
12
+ try:
13
+ from diskcache import Cache
14
+ HAS_DISKCACHE = True
15
+ except ImportError:
16
+ HAS_DISKCACHE = False
17
+
18
+
19
+ class LocalCache:
20
+ """
21
+ Local cache implementation with Redis-compatible API.
22
+ Uses diskcache as backend, supports persistence and TTL.
23
+ """
24
+
25
+ _instance = None
26
+ _lock = Lock()
27
+
28
+ def __new__(cls, cache_dir: Optional[str] = None):
29
+ """Singleton pattern"""
30
+ if cls._instance is None:
31
+ with cls._lock:
32
+ if cls._instance is None:
33
+ cls._instance = super().__new__(cls)
34
+ cls._instance._initialized = False
35
+ return cls._instance
36
+
37
+ def __init__(self, cache_dir: Optional[str] = None):
38
+ if getattr(self, '_initialized', False):
39
+ return
40
+
41
+ if not HAS_DISKCACHE:
42
+ raise ImportError(
43
+ "diskcache not installed. Run: pip install diskcache"
44
+ )
45
+
46
+ if cache_dir is None:
47
+ cache_dir = os.path.join(
48
+ os.path.dirname(os.path.dirname(__file__)),
49
+ ".cache",
50
+ "local_redis"
51
+ )
52
+
53
+ os.makedirs(cache_dir, exist_ok=True)
54
+ self._cache = Cache(cache_dir)
55
+ self._initialized = True
56
+
57
+ def set(self, name: str, value: Any, ex: Optional[int] = None) -> bool:
58
+ """
59
+ Set key-value pair
60
+
61
+ Args:
62
+ name: Key name
63
+ value: Value (auto-serialize dict/list)
64
+ ex: Expiration time (seconds)
65
+
66
+ Returns:
67
+ bool: Success status
68
+ """
69
+ if isinstance(value, (dict, list)):
70
+ value = json.dumps(value, ensure_ascii=False)
71
+ self._cache.set(name, value, expire=ex)
72
+ return True
73
+
74
+ def get(self, name: str) -> Optional[str]:
75
+ """Get value"""
76
+ return self._cache.get(name)
77
+
78
+ def delete(self, name: str) -> int:
79
+ """Delete key, returns number of deleted items"""
80
+ return 1 if self._cache.delete(name) else 0
81
+
82
+ def exists(self, name: str) -> bool:
83
+ """Check if key exists"""
84
+ return name in self._cache
85
+
86
+ def keys(self, pattern: str = "*") -> list:
87
+ """
88
+ Get list of matching keys
89
+ Note: Simplified implementation, only supports prefix and full matching
90
+ """
91
+ if pattern == "*":
92
+ return list(self._cache.iterkeys())
93
+
94
+ prefix = pattern.rstrip("*")
95
+ return [k for k in self._cache.iterkeys() if k.startswith(prefix)]
96
+
97
+ def expire(self, name: str, seconds: int) -> bool:
98
+ """Set key expiration time"""
99
+ value = self._cache.get(name)
100
+ if value is not None:
101
+ self._cache.set(name, value, expire=seconds)
102
+ return True
103
+ return False
104
+
105
+ def ttl(self, name: str) -> int:
106
+ """
107
+ Get remaining time to live (seconds)
108
+ Note: diskcache does not directly support TTL queries
109
+ """
110
+ if name in self._cache:
111
+ return -1 # Exists but TTL unknown
112
+ return -2 # Key does not exist
113
+
114
+ def close(self):
115
+ """Close cache connection"""
116
+ if hasattr(self, '_cache'):
117
+ self._cache.close()
118
+
119
+
120
+ # Lazily initialized global instance
121
+ _local_cache: Optional[LocalCache] = None
122
+
123
+
124
+ def get_local_cache(cache_dir: Optional[str] = None) -> LocalCache:
125
+ """Get local cache instance"""
126
+ global _local_cache
127
+ if _local_cache is None:
128
+ _local_cache = LocalCache(cache_dir)
129
+ return _local_cache
spaces/Ace-Step-v1.5/acestep/test_time_scaling.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test-Time Scaling Module
3
+ Implements perplexity-based scoring for generated audio codes
4
+ """
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from typing import Tuple, Optional, Dict, Any, List
8
+ from loguru import logger
9
+ import yaml
10
+ import math
11
+ import re
12
+
13
+
14
+ def pmi_score(log_prob_conditional: float, log_prob_unconditional: float) -> float:
15
+ """
16
+ Calculate Pointwise Mutual Information (PMI) score.
17
+
18
+ PMI = log P(condition|codes) - log P(condition)
19
+ = log [P(codes|condition) / P(codes)]
20
+
21
+ This removes the bias from P(condition) and measures how much the codes
22
+ improve our ability to predict the condition.
23
+
24
+ Args:
25
+ log_prob_conditional: Average log probability of condition given codes
26
+ log_prob_unconditional: Average log probability of condition without codes
27
+
28
+ Returns:
29
+ PMI score (higher is better, can be positive or negative)
30
+ - Positive: codes improve prediction → good match
31
+ - Zero: codes don't help → no correlation
32
+ - Negative: codes hurt prediction → poor match
33
+ """
34
+ return log_prob_conditional - log_prob_unconditional
35
+
36
+
37
+ def pmi_to_normalized_score(pmi: float, scale: float = 0.1) -> float:
38
+ """
39
+ Convert PMI score to normalized [0, 1] range using sigmoid function.
40
+
41
+ score = sigmoid(PMI / scale) = 1 / (1 + exp(-PMI / scale))
42
+
43
+ Args:
44
+ pmi: PMI score (can be positive or negative)
45
+ scale: Scale parameter to control sensitivity (default 0.1)
46
+ - Smaller scale: more sensitive to PMI changes
47
+ - Larger scale: less sensitive to PMI changes
48
+
49
+ Returns:
50
+ Normalized score in [0, 1] range, where:
51
+ - PMI > 0 → score > 0.5 (good match)
52
+ - PMI = 0 → score = 0.5 (neutral)
53
+ - PMI < 0 → score < 0.5 (poor match)
54
+
55
+ Examples (scale=1.0):
56
+ PMI=2.0 → score≈0.88 (excellent)
57
+ PMI=1.0 → score≈0.73 (good)
58
+ PMI=0.0 → score=0.50 (neutral)
59
+ PMI=-1.0 → score≈0.27 (poor)
60
+ PMI=-2.0 → score≈0.12 (bad)
61
+ """
62
+ return 1.0 / (1.0 + math.exp(-pmi / scale))
63
+
64
+
65
+ def _get_logits_and_target_for_scoring(llm_handler, formatted_prompt: str,
66
+ target_text: str) -> Tuple[torch.Tensor, torch.Tensor]:
67
+ """
68
+ Args:
69
+ llm_handler: The handler containing the model and tokenizer.
70
+ formatted_prompt: The input context.
71
+ target_text: The text we want to calculate probability/recall for.
72
+
73
+ Returns:
74
+ Tuple of (target_logits, target_ids)
75
+ - target_logits: Logits used to predict the target tokens.
76
+ - target_ids: The ground truth token IDs of the target.
77
+ """
78
+ model = llm_handler.get_hf_model_for_scoring()
79
+ tokenizer = llm_handler.llm_tokenizer
80
+ device = llm_handler.device if llm_handler.llm_backend == "pt" else next(model.parameters()).device
81
+
82
+ # 1. Tokenize prompt ONLY to get its length (used for slicing later).
83
+ # We must ensure special tokens are added to count the offset correctly.
84
+ prompt_tokens_temp = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=True)
85
+ prompt_len = prompt_tokens_temp['input_ids'].shape[1]
86
+
87
+ # 2. Tokenize the FULL text (Prompt + Target).
88
+ # This ensures subword merging at boundaries is handled correctly by the tokenizer.
89
+ full_text = formatted_prompt + target_text
90
+ full_tokens = tokenizer(full_text, return_tensors="pt", padding=False, truncation=True, add_special_tokens=True).to(device)
91
+
92
+ input_ids = full_tokens['input_ids']
93
+
94
+ # Safety check: if target was empty or truncated entirely
95
+ if input_ids.shape[1] <= prompt_len:
96
+ return torch.empty(0, device=device), torch.empty(0, device=device)
97
+
98
+ # 3. Forward Pass (Teacher Forcing)
99
+ with torch.no_grad():
100
+ with llm_handler._load_model_context():
101
+ outputs = model(input_ids=input_ids, attention_mask=full_tokens['attention_mask'])
102
+ all_logits = outputs.logits # [1, seq_len, vocab_size]
103
+
104
+ # 4. Extract Logits and Labels
105
+ # We need to predict `input_ids[i]`. The logit for this is at `all_logits[i-1]`.
106
+ # Target starts at index `prompt_len`.
107
+ # So we need logits from `prompt_len - 1` up to the second to last position.
108
+
109
+ target_logits = all_logits[0, prompt_len - 1:-1, :] # [target_len, vocab_size]
110
+ target_ids = input_ids[0, prompt_len:] # [target_len]
111
+
112
+ return target_logits, target_ids
113
+
114
+
115
+ # ==============================================================================
116
+ # Scoring Logic
117
+ # ==============================================================================
118
+
119
+
120
+ def _calculate_topk_recall(llm_handler,
121
+ formatted_prompt: str,
122
+ target_text: str,
123
+ topk: int = 10) -> Tuple[float, Dict[int, float]]:
124
+ """
125
+ Calculate top-k recall for target text given prompt.
126
+ Checks if the ground truth token is within the top-k probabilities at each step.
127
+ """
128
+ # Use the fixed helper to get aligned logits/labels
129
+ pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)
130
+
131
+ if target_ids.shape[0] == 0:
132
+ return 0.0, {}
133
+
134
+ target_len = target_ids.shape[0]
135
+
136
+ # Get top-k indices for all positions at once
137
+ # topk_indices: [target_len, topk]
138
+ _, topk_indices = torch.topk(pred_logits, k=min(topk, pred_logits.shape[-1]), dim=-1)
139
+
140
+ recall_per_k = {}
141
+ position_scores = []
142
+
143
+ # Convert to list for faster CPU iteration
144
+ target_ids_list = target_ids.tolist()
145
+ topk_indices_list = topk_indices.tolist()
146
+
147
+ for k in range(1, topk + 1):
148
+ hits = 0
149
+ for pos in range(target_len):
150
+ gt_token = target_ids_list[pos]
151
+ # Check the top-k slice
152
+ topk_at_pos = topk_indices_list[pos][:k]
153
+
154
+ if gt_token in topk_at_pos:
155
+ hits += 1
156
+ # Calculate position-weighted score only once (when k=topk)
157
+ if k == topk:
158
+ rank = topk_at_pos.index(gt_token) + 1
159
+ # Rank 1 = 1.0, Rank k = small positive
160
+ position_weight = 1.0 - (rank - 1) / topk
161
+ position_scores.append(position_weight)
162
+
163
+ recall_per_k[k] = hits / target_len if target_len > 0 else 0.0
164
+
165
+ # Fill scores for positions where GT was NOT in top-k
166
+ while len(position_scores) < target_len:
167
+ position_scores.append(0.0)
168
+
169
+ average_recall = sum(position_scores) / len(position_scores) if position_scores else 0.0
170
+
171
+ return average_recall, recall_per_k
172
+
173
+
174
+ def _calculate_metadata_recall(llm_handler,
175
+ formatted_prompt: str,
176
+ fields_dict: Dict[str, Any],
177
+ topk: int = 10) -> Dict[str, float]:
178
+ """
179
+ Args:
180
+ fields_dict: Dictionary of {field_name: field_value}
181
+ """
182
+ if not fields_dict:
183
+ return {}
184
+
185
+ field_scores = {}
186
+
187
+ for field_name in sorted(fields_dict.keys()):
188
+ # Construct target text for this specific field
189
+ # e.g. <think>\nbpm: 120\n</think>\n
190
+ field_yaml = yaml.dump({field_name: fields_dict[field_name]}, allow_unicode=True, sort_keys=True).strip()
191
+ field_target_text = f"<think>\n{field_yaml}\n</think>\n"
192
+
193
+ # Calculate recall using the robust logic
194
+ avg_score, _ = _calculate_topk_recall(llm_handler, formatted_prompt, field_target_text, topk=topk)
195
+
196
+ field_scores[field_name] = avg_score
197
+ logger.debug(f"Recall for {field_name}: {avg_score:.4f}")
198
+
199
+ return field_scores
200
+
201
+
202
+ def _calculate_log_prob(
203
+ llm_handler,
204
+ formatted_prompt: str,
205
+ target_text: str,
206
+ temperature: float = 1.0 # Kept for API compatibility, but ignored for scoring
207
+ ) -> float:
208
+ """
209
+ Calculate average log probability of target text given prompt.
210
+ """
211
+ pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)
212
+
213
+ if target_ids.shape[0] == 0:
214
+ return float('-inf')
215
+
216
+ # FIX: Do not divide by temperature.
217
+ # Log-probability for PMI/Perplexity should be exact.
218
+
219
+ # Calculate log probabilities (log_softmax)
220
+ log_probs = F.log_softmax(pred_logits, dim=-1) # [target_len, vocab_size]
221
+
222
+ # Gather log probabilities of the ground truth tokens
223
+ target_log_probs = log_probs[torch.arange(target_ids.shape[0]), target_ids]
224
+
225
+ # Return average log probability
226
+ mean_log_prob = target_log_probs.mean().item()
227
+
228
+ return mean_log_prob
229
+
230
+
231
+ def calculate_reward_score(
232
+ scores: Dict[str, float],
233
+ weights_config: Optional[Dict[str, float]] = None
234
+ ) -> Tuple[float, str]:
235
+ """
236
+ Reward Model Calculator: Computes a final reward based on user priorities.
237
+
238
+ Priority Logic:
239
+ 1. Caption (Highest): The overall vibe/style must match.
240
+ 2. Lyrics (Medium): Content accuracy is important but secondary to vibe.
241
+ 3. Metadata (Lowest): Technical constraints (BPM, Key) allow for slight deviations.
242
+
243
+ Strategy: Dynamic Weighted Sum
244
+ - Metadata fields are aggregated into a single 'metadata' score first.
245
+ - Weights are dynamically renormalized if any component (e.g., lyrics) is missing.
246
+
247
+ Args:
248
+ scores: Dictionary of raw scores (0.0 - 1.0) from the evaluation module.
249
+ weights_config: Optional custom weights. Defaults to:
250
+ Caption (50%), Lyrics (30%), Metadata (20%).
251
+
252
+ Returns:
253
+ final_reward: The calculated reward score (0.0 - 1.0).
254
+ explanation: A formatted string explaining how the score was derived.
255
+ """
256
+
257
+ # 1. Default Preference Configuration
258
+ # These weights determine the relative importance of each component.
259
+ if weights_config is None:
260
+ weights_config = {
261
+ 'caption': 0.50, # High priority: Style/Vibe
262
+ 'lyrics': 0.30, # Medium priority: Content
263
+ 'metadata': 0.20 # Low priority: Technical details
264
+ }
265
+
266
+ # 2. Extract and Group Scores
267
+ # Caption and Lyrics are standalone high-level features.
268
+ caption_score = scores.get('caption')
269
+ lyrics_score = scores.get('lyrics')
270
+
271
+ # Metadata fields (bpm, key, duration, etc.) are aggregated.
272
+ # We treat them as a single "Technical Score" to prevent them from
273
+ # diluting the weight of Caption/Lyrics simply by having many fields.
274
+ meta_scores_list = [
275
+ val for key, val in scores.items()
276
+ if key not in ['caption', 'lyrics']
277
+ ]
278
+
279
+ # Calculate average of all metadata fields (if any exist)
280
+ meta_aggregate_score = None
281
+ if meta_scores_list:
282
+ meta_aggregate_score = sum(meta_scores_list) / len(meta_scores_list)
283
+
284
+ # 3. specific Active Components & Dynamic Weighting
285
+ # We only include components that actually exist in this generation.
286
+ active_components = {}
287
+
288
+ if caption_score is not None:
289
+ active_components['caption'] = (caption_score, weights_config['caption'])
290
+
291
+ if lyrics_score is not None:
292
+ active_components['lyrics'] = (lyrics_score, weights_config['lyrics'])
293
+
294
+ if meta_aggregate_score is not None:
295
+ active_components['metadata'] = (meta_aggregate_score, weights_config['metadata'])
296
+
297
+ # 4. Calculate Final Weighted Score
298
+ total_base_weight = sum(w for _, w in active_components.values())
299
+ total_score = 0.0
300
+
301
+ breakdown_lines = []
302
+
303
+ if total_base_weight == 0:
304
+ return 0.0, "❌ No valid scores available to calculate reward."
305
+
306
+ # Sort by weight (importance) for display
307
+ sorted_components = sorted(active_components.items(), key=lambda x: x[1][1], reverse=True)
308
+
309
+ for name, (score, base_weight) in sorted_components:
310
+ # Renormalize weight: If lyrics are missing, caption/metadata weights scale up proportionately.
311
+ normalized_weight = base_weight / total_base_weight
312
+ weighted_contribution = score * normalized_weight
313
+ total_score += weighted_contribution
314
+
315
+ breakdown_lines.append(
316
+ f" • {name.title():<8} | Score: {score:.4f} | Weight: {normalized_weight:.2f} "
317
+ f"-> Contrib: +{weighted_contribution:.4f}"
318
+ )
319
+
320
+ return total_score, "\n".join(breakdown_lines)
321
+
322
+ # ==============================================================================
323
+ # Main Public API
324
+ # ==============================================================================
325
+
326
+
327
+ def calculate_pmi_score_per_condition(
328
+ llm_handler,
329
+ audio_codes: str,
330
+ caption: str = "",
331
+ lyrics: str = "",
332
+ metadata: Optional[Dict[str, Any]] = None,
333
+ temperature: float = 1.0,
334
+ topk: int = 10,
335
+ score_scale: float = 0.1,
336
+ ) -> Tuple[Dict[str, float], float, str]:
337
+ """
338
+ Calculate quality score separately for each condition.
339
+ - Metadata: Uses Top-k Recall.
340
+ - Caption/Lyrics: Uses PMI (Normalized).
341
+ """
342
+ if not llm_handler.llm_initialized:
343
+ return {}, 0.0, "❌ LLM not initialized"
344
+
345
+ if not audio_codes or not audio_codes.strip():
346
+ return {}, 0.0, "❌ No audio codes provided"
347
+
348
+ if "caption" not in metadata:
349
+ metadata['caption'] = caption
350
+
351
+ formatted_prompt = llm_handler.build_formatted_prompt_for_understanding(audio_codes=audio_codes, is_negative_prompt=False)
352
+ prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
353
+ try:
354
+ # 1. Calculate Recall for Metadata Fields
355
+ if metadata and isinstance(metadata, dict):
356
+ scores = {}
357
+ # Define which fields use which metric
358
+ metadata_recall_keys = ['bpm', 'duration', 'genres', 'keyscale', 'language', 'timesignature']
359
+ metadata_pmi_keys = ['caption']
360
+ for key in metadata_recall_keys:
361
+ if key in metadata and metadata[key] is not None:
362
+ recall_metadata = {key: metadata[key]}
363
+ field_scores = _calculate_metadata_recall(llm_handler, formatted_prompt, recall_metadata, topk=topk)
364
+ scores.update(field_scores)
365
+
366
+ # 2. Calculate PMI for Caption
367
+ for key in metadata_pmi_keys:
368
+ if key in metadata and metadata[key] is not None:
369
+ cot_yaml = yaml.dump({key: metadata[key]}, allow_unicode=True, sort_keys=True).strip()
370
+ target_text = f"<think>\n{cot_yaml}\n</think>\n"
371
+
372
+ log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
373
+ log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
374
+
375
+ pmi_normalized = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
376
+ scores[key] = pmi_normalized
377
+
378
+ # 3. Calculate PMI for Lyrics
379
+ if lyrics:
380
+ target_text = f"<think>\n</think>\n# Lyric\n{lyrics}\n"
381
+
382
+ log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
383
+
384
+ prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
385
+ log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
386
+
387
+ scores['lyrics'] = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
388
+
389
+ if not scores:
390
+ return {}, 0.0, "❌ No conditions to evaluate"
391
+
392
+ # 4. Global Score
393
+ global_score = sum(scores.values()) / len(scores)
394
+ global_score, breakdown_lines = calculate_reward_score(scores)
395
+
396
+ # Status Message
397
+ status_lines = [breakdown_lines, "\n✅ Per-condition scores (0-1):"]
398
+ for key, score in sorted(scores.items()):
399
+ metric = "Top-k Recall" if key in metadata_recall_keys else "PMI (Norm)"
400
+ status_lines.append(f" {key}: {score:.4f} ({metric})")
401
+ status = "\n".join(status_lines)
402
+ logger.info(f"Calculated scores: {global_score:.4f}\n{status}")
403
+ return scores, global_score, status
404
+
405
+ except Exception as e:
406
+ import traceback
407
+ error_msg = f"❌ Error: {str(e)}"
408
+ logger.error(error_msg)
409
+ logger.error(traceback.format_exc())
410
+ return {}, float('-inf'), error_msg
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Xingkai Yu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/README.md ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <img width="300" src="assets/logo.png">
3
+ </p>
4
+
5
+ <p align="center">
6
+ <a href="https://trendshift.io/repositories/15323" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15323" alt="GeeeekExplorer%2Fnano-vllm | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
7
+ </p>
8
+
9
+ # Nano-vLLM
10
+
11
+ A lightweight vLLM implementation built from scratch.
12
+
13
+ ## Key Features
14
+
15
+ * 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
16
+ * 📖 **Readable codebase** - Clean implementation in ~ 1,200 lines of Python code
17
+ * ⚡ **Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc.
18
+
19
+ ## Installation
20
+
21
+ ```bash
22
+ pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
23
+ ```
24
+
25
+ ## Model Download
26
+
27
+ To download the model weights manually, use the following command:
28
+ ```bash
29
+ huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
30
+ --local-dir ~/huggingface/Qwen3-0.6B/ \
31
+ --local-dir-use-symlinks False
32
+ ```
33
+
34
+ ## Quick Start
35
+
36
+ See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method:
37
+ ```python
38
+ from nanovllm import LLM, SamplingParams
39
+ llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
40
+ sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
41
+ prompts = ["Hello, Nano-vLLM."]
42
+ outputs = llm.generate(prompts, sampling_params)
43
+ outputs[0]["text"]
44
+ ```
45
+
46
+ ## Benchmark
47
+
48
+ See `bench.py` for benchmark.
49
+
50
+ **Test Configuration:**
51
+ - Hardware: RTX 4070 Laptop (8GB)
52
+ - Model: Qwen3-0.6B
53
+ - Total Requests: 256 sequences
54
+ - Input Length: Randomly sampled between 100–1024 tokens
55
+ - Output Length: Randomly sampled between 100–1024 tokens
56
+
57
+ **Performance Results:**
58
+ | Inference Engine | Output Tokens | Time (s) | Throughput (tokens/s) |
59
+ |----------------|-------------|----------|-----------------------|
60
+ | vLLM | 133,966 | 98.37 | 1361.84 |
61
+ | Nano-vLLM | 133,966 | 93.41 | 1434.13 |
62
+
63
+
64
+ ## Star History
65
+
66
+ [![Star History Chart](https://api.star-history.com/svg?repos=GeeeekExplorer/nano-vllm&type=Date)](https://www.star-history.com/#GeeeekExplorer/nano-vllm&Date)
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/assets/logo.png ADDED

Git LFS Details

  • SHA256: 03ec4039dc248e97e9943694d3ccfb52c1a73a6dab94c4cd6fd4288e08de98c8
  • Pointer size: 131 Bytes
  • Size of remote file: 397 kB
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/bench.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from random import randint, seed
4
+ from nanovllm import LLM, SamplingParams
5
+ # from vllm import LLM, SamplingParams
6
+
7
+
8
+ def main():
9
+ seed(0)
10
+ num_seqs = 256
11
+ max_input_len = 1024
12
+ max_ouput_len = 1024
13
+
14
+ path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
15
+ llm = LLM(path, enforce_eager=False, max_model_len=4096)
16
+
17
+ prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
18
+ sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)]
19
+ # uncomment the following line for vllm
20
+ # prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
21
+
22
+ llm.generate(["Benchmark: "], SamplingParams())
23
+ t = time.time()
24
+ llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
25
+ t = (time.time() - t)
26
+ total_tokens = sum(sp.max_tokens for sp in sampling_params)
27
+ throughput = total_tokens / t
28
+ print(f"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
29
+
30
+
31
+ if __name__ == "__main__":
32
+ main()
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/example.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from nanovllm import LLM, SamplingParams
3
+ from transformers import AutoTokenizer
4
+
5
+
6
+ def main():
7
+ path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
8
+ tokenizer = AutoTokenizer.from_pretrained(path)
9
+ llm = LLM(path, enforce_eager=True, tensor_parallel_size=1)
10
+
11
+ sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
12
+ prompts = [
13
+ "introduce yourself",
14
+ "list all prime numbers within 100",
15
+ ]
16
+ prompts = [
17
+ tokenizer.apply_chat_template(
18
+ [{"role": "user", "content": prompt}],
19
+ tokenize=False,
20
+ add_generation_prompt=True,
21
+ )
22
+ for prompt in prompts
23
+ ]
24
+ outputs = llm.generate(prompts, sampling_params)
25
+
26
+ for prompt, output in zip(prompts, outputs):
27
+ print("\n")
28
+ print(f"Prompt: {prompt!r}")
29
+ print(f"Completion: {output['text']!r}")
30
+
31
+
32
+ if __name__ == "__main__":
33
+ main()
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from nanovllm.llm import LLM
2
+ from nanovllm.sampling_params import SamplingParams
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/config.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+ from transformers import AutoConfig
4
+
5
+
6
+ @dataclass
7
+ class Config:
8
+ model: str
9
+ max_num_batched_tokens: int = 16384
10
+ max_num_seqs: int = 512
11
+ max_model_len: int = 4096
12
+ gpu_memory_utilization: float = 0.9
13
+ tensor_parallel_size: int = 1
14
+ enforce_eager: bool = False
15
+ hf_config: AutoConfig | None = None
16
+ eos: int = -1
17
+ kvcache_block_size: int = 256
18
+ num_kvcache_blocks: int = -1
19
+
20
+ def __post_init__(self):
21
+ assert os.path.isdir(self.model)
22
+ assert self.kvcache_block_size % 256 == 0
23
+ assert 1 <= self.tensor_parallel_size <= 8
24
+ self.hf_config = AutoConfig.from_pretrained(self.model)
25
+ self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
26
+ assert self.max_num_batched_tokens >= self.max_model_len
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/block_manager.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ import xxhash
3
+ import numpy as np
4
+
5
+ from nanovllm.engine.sequence import Sequence
6
+
7
+
8
+ class Block:
9
+
10
+ def __init__(self, block_id):
11
+ self.block_id = block_id
12
+ self.ref_count = 0
13
+ self.hash = -1
14
+ self.token_ids = []
15
+
16
+ def update(self, hash: int, token_ids: list[int]):
17
+ self.hash = hash
18
+ self.token_ids = token_ids
19
+
20
+ def reset(self):
21
+ self.ref_count = 1
22
+ self.hash = -1
23
+ self.token_ids = []
24
+
25
+
26
+ class BlockManager:
27
+
28
+ def __init__(self, num_blocks: int, block_size: int):
29
+ self.block_size = block_size
30
+ self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
31
+ self.hash_to_block_id: dict[int, int] = dict()
32
+ self.free_block_ids: deque[int] = deque(range(num_blocks))
33
+ self.used_block_ids: set[int] = set()
34
+
35
+ @classmethod
36
+ def compute_hash(cls, token_ids: list[int], prefix: int = -1):
37
+ h = xxhash.xxh64()
38
+ if prefix != -1:
39
+ h.update(prefix.to_bytes(8, "little"))
40
+ h.update(np.array(token_ids).tobytes())
41
+ return h.intdigest()
42
+
43
+ def _allocate_block(self, block_id: int) -> Block:
44
+ block = self.blocks[block_id]
45
+ assert block.ref_count == 0
46
+ block.reset()
47
+ self.free_block_ids.remove(block_id)
48
+ self.used_block_ids.add(block_id)
49
+ return self.blocks[block_id]
50
+
51
+ def _deallocate_block(self, block_id: int) -> Block:
52
+ assert self.blocks[block_id].ref_count == 0
53
+ self.used_block_ids.remove(block_id)
54
+ self.free_block_ids.append(block_id)
55
+
56
+ def can_allocate(self, seq: Sequence) -> bool:
57
+ return len(self.free_block_ids) >= seq.num_blocks
58
+
59
+ def allocate(self, seq: Sequence):
60
+ assert not seq.block_table
61
+ h = -1
62
+ cache_miss = False
63
+ for i in range(seq.num_blocks):
64
+ token_ids = seq.block(i)
65
+ h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
66
+ block_id = self.hash_to_block_id.get(h, -1)
67
+ if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
68
+ cache_miss = True
69
+ if cache_miss:
70
+ block_id = self.free_block_ids[0]
71
+ block = self._allocate_block(block_id)
72
+ else:
73
+ seq.num_cached_tokens += self.block_size
74
+ if block_id in self.used_block_ids:
75
+ block = self.blocks[block_id]
76
+ block.ref_count += 1
77
+ else:
78
+ block = self._allocate_block(block_id)
79
+ if h != -1:
80
+ block.update(h, token_ids)
81
+ self.hash_to_block_id[h] = block_id
82
+ seq.block_table.append(block_id)
83
+
84
+ def deallocate(self, seq: Sequence):
85
+ for block_id in reversed(seq.block_table):
86
+ block = self.blocks[block_id]
87
+ block.ref_count -= 1
88
+ if block.ref_count == 0:
89
+ self._deallocate_block(block_id)
90
+ seq.num_cached_tokens = 0
91
+ seq.block_table.clear()
92
+
93
+ def can_append(self, seq: Sequence) -> bool:
94
+ return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
95
+
96
+ def may_append(self, seq: Sequence):
97
+ block_table = seq.block_table
98
+ last_block = self.blocks[block_table[-1]]
99
+ if len(seq) % self.block_size == 1:
100
+ assert last_block.hash != -1
101
+ block_id = self.free_block_ids[0]
102
+ self._allocate_block(block_id)
103
+ block_table.append(block_id)
104
+ elif len(seq) % self.block_size == 0:
105
+ assert last_block.hash == -1
106
+ token_ids = seq.block(seq.num_blocks-1)
107
+ prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
108
+ h = self.compute_hash(token_ids, prefix)
109
+ last_block.update(h, token_ids)
110
+ self.hash_to_block_id[h] = last_block.block_id
111
+ else:
112
+ assert last_block.hash == -1
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import atexit
2
+ from dataclasses import fields
3
+ from time import perf_counter
4
+ from tqdm.auto import tqdm
5
+ from transformers import AutoTokenizer
6
+ import torch.multiprocessing as mp
7
+
8
+ from nanovllm.config import Config
9
+ from nanovllm.sampling_params import SamplingParams
10
+ from nanovllm.engine.sequence import Sequence
11
+ from nanovllm.engine.scheduler import Scheduler
12
+ from nanovllm.engine.model_runner import ModelRunner
13
+
14
+
15
+ class LLMEngine:
16
+
17
+ def __init__(self, model, **kwargs):
18
+ config_fields = {field.name for field in fields(Config)}
19
+ config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
20
+ config = Config(model, **config_kwargs)
21
+ self.ps = []
22
+ self.events = []
23
+ ctx = mp.get_context("spawn")
24
+ for i in range(1, config.tensor_parallel_size):
25
+ event = ctx.Event()
26
+ process = ctx.Process(target=ModelRunner, args=(config, i, event))
27
+ process.start()
28
+ self.ps.append(process)
29
+ self.events.append(event)
30
+ self.model_runner = ModelRunner(config, 0, self.events)
31
+ tokenizer = kwargs.get("tokenizer", None)
32
+ if tokenizer is not None:
33
+ self.tokenizer = tokenizer
34
+ else:
35
+ self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
36
+ config.eos = self.tokenizer.eos_token_id
37
+ self.scheduler = Scheduler(config)
38
+ atexit.register(self.exit)
39
+
40
+ def exit(self):
41
+ self.model_runner.call("exit")
42
+ del self.model_runner
43
+ for p in self.ps:
44
+ p.join()
45
+
46
+ def add_request(self, prompt: str | list[int], sampling_params: SamplingParams, unconditional_prompt: str | list[int] | None = None):
47
+ if isinstance(prompt, str):
48
+ prompt = self.tokenizer.encode(prompt)
49
+ # For CFG: if cfg_scale > 1.0, create both conditional and unconditional sequences
50
+ if sampling_params.cfg_scale > 1.0:
51
+ if unconditional_prompt is None:
52
+ # Try to construct unconditional prompt by replacing user input with "NO USER INPUT"
53
+ # This is a fallback - ideally users should provide unconditional_prompt
54
+ if isinstance(prompt, list):
55
+ # For now, just use the same prompt (user should provide unconditional_prompt)
56
+ # TODO: Implement automatic "NO USER INPUT" replacement if possible
57
+ unconditional_prompt = prompt
58
+ else:
59
+ unconditional_prompt = prompt
60
+ if isinstance(unconditional_prompt, str):
61
+ unconditional_prompt = self.tokenizer.encode(unconditional_prompt)
62
+ # Create unconditional sequence first (so we can reference it from conditional)
63
+ uncond_seq = Sequence(unconditional_prompt, sampling_params, is_unconditional=True)
64
+ # Create conditional sequence with reference to unconditional
65
+ cond_seq = Sequence(prompt, sampling_params, is_unconditional=False, conditional_seq=uncond_seq)
66
+ uncond_seq.paired_seq = cond_seq # Link them bidirectionally
67
+ # Add both sequences to scheduler
68
+ self.scheduler.add(cond_seq)
69
+ self.scheduler.add(uncond_seq)
70
+ else:
71
+ seq = Sequence(prompt, sampling_params)
72
+ self.scheduler.add(seq)
73
+
74
+ def step(self):
75
+ seqs, is_prefill = self.scheduler.schedule()
76
+ token_ids = self.model_runner.call("run", seqs, is_prefill)
77
+ self.scheduler.postprocess(seqs, token_ids)
78
+ # Only output conditional sequences (unconditional sequences are just for CFG computation)
79
+ output_seqs = [seq for seq in seqs if seq.is_finished and (seq.cfg_scale <= 1.0 or not seq.is_unconditional)]
80
+ outputs = [(seq.seq_id, seq.completion_token_ids) for seq in output_seqs]
81
+ num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len([s for s in seqs if not s.is_unconditional])
82
+ return outputs, num_tokens
83
+
84
+ def is_finished(self):
85
+ return self.scheduler.is_finished()
86
+
87
+ def generate(
88
+ self,
89
+ prompts: list[str] | list[list[int]],
90
+ sampling_params: SamplingParams | list[SamplingParams],
91
+ use_tqdm: bool = True,
92
+ unconditional_prompts: list[str] | list[list[int]] | None = None,
93
+ ) -> list[str]:
94
+ if use_tqdm:
95
+ pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
96
+ if not isinstance(sampling_params, list):
97
+ sampling_params = [sampling_params] * len(prompts)
98
+ if unconditional_prompts is None:
99
+ unconditional_prompts = [None] * len(prompts)
100
+ for prompt, sp, uncond_prompt in zip(prompts, sampling_params, unconditional_prompts):
101
+ self.add_request(prompt, sp, uncond_prompt)
102
+ outputs = {}
103
+ prefill_throughput = decode_throughput = 0.
104
+ while not self.is_finished():
105
+ t = perf_counter()
106
+ output, num_tokens = self.step()
107
+ if use_tqdm:
108
+ if num_tokens > 0:
109
+ prefill_throughput = num_tokens / (perf_counter() - t)
110
+ else:
111
+ decode_throughput = -num_tokens / (perf_counter() - t)
112
+ pbar.set_postfix({
113
+ "Prefill": f"{int(prefill_throughput)}tok/s",
114
+ "Decode": f"{int(decode_throughput)}tok/s",
115
+ })
116
+ for seq_id, token_ids in output:
117
+ outputs[seq_id] = token_ids
118
+ if use_tqdm:
119
+ pbar.update(1)
120
+ outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
121
+ outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
122
+ if use_tqdm:
123
+ pbar.close()
124
+ return outputs
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py ADDED
@@ -0,0 +1,529 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import torch
3
+ import torch.distributed as dist
4
+ from multiprocessing.synchronize import Event
5
+ from multiprocessing.shared_memory import SharedMemory
6
+ import sys
7
+
8
+ from nanovllm.config import Config
9
+ from nanovllm.engine.sequence import Sequence
10
+ from nanovllm.models.qwen3 import Qwen3ForCausalLM
11
+ from nanovllm.layers.sampler import Sampler
12
+ from nanovllm.utils.context import set_context, get_context, reset_context
13
+ from nanovllm.utils.loader import load_model
14
+
15
+ import socket
16
+
17
+
18
+ def find_available_port(start_port: int = 2333, max_attempts: int = 100) -> int:
19
+ """Find an available port starting from start_port.
20
+
21
+ Args:
22
+ start_port: The starting port number to check
23
+ max_attempts: Maximum number of ports to try
24
+
25
+ Returns:
26
+ An available port number
27
+
28
+ Raises:
29
+ RuntimeError: If no available port is found within max_attempts
30
+ """
31
+ for i in range(max_attempts):
32
+ port = start_port + i
33
+ try:
34
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
35
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
36
+ s.bind(('localhost', port))
37
+ return port
38
+ except OSError:
39
+ # Port is in use, try next one
40
+ continue
41
+ raise RuntimeError(f"Could not find an available port starting from {start_port} after {max_attempts} attempts")
42
+
43
+
44
+ class ModelRunner:
45
+
46
+ def __init__(self, config: Config, rank: int, event: Event | list[Event]):
47
+ # Enable capturing scalar outputs to avoid graph breaks from Tensor.item() calls
48
+ torch._dynamo.config.capture_scalar_outputs = True
49
+
50
+ self.config = config
51
+ hf_config = config.hf_config
52
+ self.block_size = config.kvcache_block_size
53
+ self.enforce_eager = config.enforce_eager
54
+ self.world_size = config.tensor_parallel_size
55
+ self.rank = rank
56
+ self.event = event
57
+ dist_port = find_available_port()
58
+ print(f"[debug]dist_port: {dist_port}")
59
+ # Use gloo backend on Windows, nccl on Linux/other platforms
60
+ backend = "gloo" if sys.platform == "win32" else "nccl"
61
+ dist.init_process_group(backend, f"tcp://127.0.0.1:{dist_port}", world_size=self.world_size, rank=rank)
62
+ torch.cuda.set_device(rank)
63
+ default_dtype = torch.get_default_dtype()
64
+ # Use dtype instead of deprecated torch_dtype
65
+ config_dtype = getattr(hf_config, 'dtype', getattr(hf_config, 'torch_dtype', torch.float32))
66
+ torch.set_default_dtype(config_dtype)
67
+ torch.set_default_device("cuda")
68
+ self.model = Qwen3ForCausalLM(hf_config)
69
+ load_model(self.model, config.model)
70
+ self.sampler = Sampler()
71
+
72
+ # Pre-allocate buffers for sampling (optimization: avoid repeated tensor creation)
73
+ # Must be called before warmup_model() since it uses these buffers
74
+ self._allocate_sample_buffers()
75
+
76
+ self.warmup_model()
77
+ self.allocate_kv_cache()
78
+ if not self.enforce_eager:
79
+ self.capture_cudagraph()
80
+
81
+ torch.set_default_device("cpu")
82
+ torch.set_default_dtype(default_dtype)
83
+
84
+ if self.world_size > 1:
85
+ if rank == 0:
86
+ self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
87
+ dist.barrier()
88
+ else:
89
+ dist.barrier()
90
+ self.shm = SharedMemory(name="nanovllm")
91
+ self.loop()
92
+
93
+ def _allocate_sample_buffers(self):
94
+ """Pre-allocate reusable buffers for sampling to avoid repeated tensor creation."""
95
+ max_bs = self.config.max_num_seqs
96
+ max_tokens = self.config.max_num_batched_tokens
97
+ max_num_blocks = (self.config.max_model_len + self.block_size - 1) // self.block_size
98
+
99
+ # Pre-allocate pinned memory buffers on CPU for fast transfer
100
+ # Must explicitly specify device="cpu" since default device may be "cuda"
101
+ self._cpu_temperatures = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
102
+ self._cpu_cfg_scales = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
103
+ self._cpu_top_ks = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
104
+ self._cpu_top_ps = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
105
+ self._cpu_repetition_penalties = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
106
+
107
+ # Pre-allocate decode buffers on CPU with pinned memory
108
+ self._cpu_input_ids = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
109
+ self._cpu_positions = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
110
+ self._cpu_slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
111
+ self._cpu_context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
112
+
113
+ # Pre-allocate prefill buffers on CPU with pinned memory (optimization to avoid repeated tensor creation)
114
+ self._cpu_prefill_input_ids = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
115
+ self._cpu_prefill_positions = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
116
+ self._cpu_prefill_cu_seqlens = torch.zeros(max_bs + 1, dtype=torch.int32, device="cpu", pin_memory=True)
117
+ self._cpu_prefill_slot_mapping = torch.zeros(max_tokens, dtype=torch.int32, device="cpu", pin_memory=True)
118
+
119
+ # Pre-allocate block tables buffer (shared by both decode and prefill)
120
+ self._cpu_block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32, device="cpu", pin_memory=True)
121
+
122
+ # Pre-allocate buffer for sequence token IDs (used in logits processor and sampler)
123
+ # Max length is max_model_len since sequences can be that long
124
+ self._seq_token_ids_buffer = torch.zeros(max_bs, self.config.max_model_len, dtype=torch.int64, device="cpu", pin_memory=True)
125
+
126
+ def exit(self):
127
+ if self.world_size > 1:
128
+ self.shm.close()
129
+ dist.barrier()
130
+ if self.rank == 0:
131
+ self.shm.unlink()
132
+ if not self.enforce_eager:
133
+ del self.graphs, self.graph_pool
134
+ torch.cuda.synchronize()
135
+ dist.destroy_process_group()
136
+
137
+ def loop(self):
138
+ while True:
139
+ method_name, args = self.read_shm()
140
+ self.call(method_name, *args)
141
+ if method_name == "exit":
142
+ break
143
+
144
+ def read_shm(self):
145
+ assert self.world_size > 1 and self.rank > 0
146
+ self.event.wait()
147
+ n = int.from_bytes(self.shm.buf[0:4], "little")
148
+ method_name, *args = pickle.loads(self.shm.buf[4:n+4])
149
+ self.event.clear()
150
+ return method_name, args
151
+
152
+ def write_shm(self, method_name, *args):
153
+ assert self.world_size > 1 and self.rank == 0
154
+ data = pickle.dumps([method_name, *args])
155
+ n = len(data)
156
+ self.shm.buf[0:4] = n.to_bytes(4, "little")
157
+ self.shm.buf[4:n+4] = data
158
+ for event in self.event:
159
+ event.set()
160
+
161
+ def call(self, method_name, *args):
162
+ if self.world_size > 1 and self.rank == 0:
163
+ self.write_shm(method_name, *args)
164
+ method = getattr(self, method_name, None)
165
+ return method(*args)
166
+
167
+ def warmup_model(self):
168
+ torch.cuda.empty_cache()
169
+ torch.cuda.reset_peak_memory_stats()
170
+ max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
171
+ num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
172
+ seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
173
+ self.run(seqs, True)
174
+ torch.cuda.empty_cache()
175
+
176
+ def allocate_kv_cache(self):
177
+ config = self.config
178
+ hf_config = config.hf_config
179
+ free, total = torch.cuda.mem_get_info()
180
+ current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
181
+ num_kv_heads = hf_config.num_key_value_heads // self.world_size
182
+ head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
183
+ # Use dtype instead of deprecated torch_dtype
184
+ config_dtype = getattr(hf_config, 'dtype', getattr(hf_config, 'torch_dtype', torch.float32))
185
+ block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * config_dtype.itemsize
186
+
187
+ # Calculate available memory for KV cache
188
+ # After warmup_model, empty_cache has been called, so current represents model memory only
189
+ # Use free memory but respect the gpu_memory_utilization limit
190
+ target_total_usage = total * config.gpu_memory_utilization
191
+ available_for_kv_cache = min(free * 0.9, target_total_usage - current)
192
+
193
+ # Ensure we have positive memory available
194
+ if available_for_kv_cache <= 0:
195
+ available_for_kv_cache = free * 0.5 # Fallback to 50% of free memory
196
+
197
+ config.num_kvcache_blocks = max(1, int(available_for_kv_cache) // block_bytes)
198
+ if config.num_kvcache_blocks <= 0:
199
+ raise RuntimeError(
200
+ f"Insufficient GPU memory for KV cache. "
201
+ f"Free: {free / 1024**3:.2f} GB, Current: {current / 1024**3:.2f} GB, "
202
+ f"Available for KV: {available_for_kv_cache / 1024**3:.2f} GB, "
203
+ f"Block size: {block_bytes / 1024**2:.2f} MB"
204
+ )
205
+ self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
206
+ layer_id = 0
207
+ for module in self.model.modules():
208
+ if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
209
+ module.k_cache = self.kv_cache[0, layer_id]
210
+ module.v_cache = self.kv_cache[1, layer_id]
211
+ layer_id += 1
212
+
213
+ def prepare_block_tables(self, seqs: list[Sequence]):
214
+ max_len = max(len(seq.block_table) for seq in seqs)
215
+ block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
216
+ block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
217
+ return block_tables
218
+
219
+ def prepare_prefill(self, seqs: list[Sequence]):
220
+ input_ids = []
221
+ positions = []
222
+ cu_seqlens_q = [0]
223
+ cu_seqlens_k = [0]
224
+ max_seqlen_q = 0
225
+ max_seqlen_k = 0
226
+ slot_mapping = []
227
+ block_tables = None
228
+ for seq in seqs:
229
+ seqlen = len(seq)
230
+ input_ids.extend(seq[seq.num_cached_tokens:])
231
+ positions.extend(list(range(seq.num_cached_tokens, seqlen)))
232
+ seqlen_q = seqlen - seq.num_cached_tokens
233
+ seqlen_k = seqlen
234
+ cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
235
+ cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
236
+ max_seqlen_q = max(seqlen_q, max_seqlen_q)
237
+ max_seqlen_k = max(seqlen_k, max_seqlen_k)
238
+ if not seq.block_table: # warmup
239
+ continue
240
+ for i in range(seq.num_cached_blocks, seq.num_blocks):
241
+ start = seq.block_table[i] * self.block_size
242
+ if i != seq.num_blocks - 1:
243
+ end = start + self.block_size
244
+ else:
245
+ end = start + seq.last_block_num_tokens
246
+ slot_mapping.extend(list(range(start, end)))
247
+ if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
248
+ block_tables = self.prepare_block_tables(seqs)
249
+ input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
250
+ positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
251
+ cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
252
+ cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
253
+ slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
254
+ set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
255
+ return input_ids, positions
256
+
257
+ def prepare_decode(self, seqs: list[Sequence]):
258
+ """Optimized decode preparation using pre-allocated buffers."""
259
+ bs = len(seqs)
260
+
261
+ # Use pre-allocated CPU buffers
262
+ for i, seq in enumerate(seqs):
263
+ self._cpu_input_ids[i] = seq.last_token
264
+ self._cpu_positions[i] = len(seq) - 1
265
+ self._cpu_context_lens[i] = len(seq)
266
+ self._cpu_slot_mapping[i] = seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1
267
+
268
+ # Transfer to GPU using sliced views
269
+ input_ids = self._cpu_input_ids[:bs].cuda(non_blocking=True)
270
+ positions = self._cpu_positions[:bs].cuda(non_blocking=True)
271
+ slot_mapping = self._cpu_slot_mapping[:bs].cuda(non_blocking=True)
272
+ context_lens = self._cpu_context_lens[:bs].cuda(non_blocking=True)
273
+ block_tables = self.prepare_block_tables(seqs)
274
+ set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
275
+ return input_ids, positions
276
+
277
+ def prepare_sample(self, seqs: list[Sequence], is_cfg_batch: bool = False):
278
+ """Optimized sample preparation using pre-allocated buffers."""
279
+ if is_cfg_batch:
280
+ num_seqs = len(seqs) // 2
281
+ target_seqs = seqs[:num_seqs]
282
+ else:
283
+ num_seqs = len(seqs)
284
+ target_seqs = seqs
285
+
286
+ # Fill pre-allocated CPU buffers
287
+ top_ks_is_zero = True
288
+ top_ps_is_one = True
289
+ repetition_penalties_is_one = True
290
+ for i, seq in enumerate(target_seqs):
291
+ self._cpu_temperatures[i] = seq.temperature
292
+ self._cpu_cfg_scales[i] = seq.cfg_scale
293
+ self._cpu_top_ks[i] = seq.top_k if seq.top_k is not None else 0
294
+ if seq.top_k is not None and seq.top_k > 0:
295
+ top_ks_is_zero = False
296
+ self._cpu_top_ps[i] = seq.top_p if seq.top_p is not None else 1.0
297
+ if seq.top_p is not None and seq.top_p == 1.0:
298
+ top_ps_is_one = False
299
+ self._cpu_repetition_penalties[i] = seq.repetition_penalty if seq.repetition_penalty is not None else 1.0
300
+ if seq.repetition_penalty is not None and seq.repetition_penalty == 1.0:
301
+ repetition_penalties_is_one = False
302
+
303
+ # Transfer to GPU using sliced views (single batched transfer)
304
+ temperatures = self._cpu_temperatures[:num_seqs].cuda(non_blocking=True)
305
+ cfg_scales = self._cpu_cfg_scales[:num_seqs].cuda(non_blocking=True)
306
+ top_ks = self._cpu_top_ks[:num_seqs].cuda(non_blocking=True) if not top_ks_is_zero else None
307
+ top_ps = self._cpu_top_ps[:num_seqs].cuda(non_blocking=True) if not top_ps_is_one else None
308
+ repetition_penalties = self._cpu_repetition_penalties[:num_seqs].cuda(non_blocking=True) if not repetition_penalties_is_one else None
309
+
310
+ return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties
311
+
312
+ @torch.inference_mode()
313
+ def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
314
+ if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
315
+ return self.model.compute_logits(self.model(input_ids, positions))
316
+ else:
317
+ bs = input_ids.size(0)
318
+ context = get_context()
319
+
320
+ # Check if block_tables size exceeds pre-allocated buffer size
321
+ # This can happen when conditional and unconditional sequences have different lengths
322
+ # in CFG mode, causing block_tables to have more columns than expected
323
+ max_num_blocks = self.graph_vars["block_tables"].size(1)
324
+ if context.block_tables.size(1) > max_num_blocks:
325
+ # Fall back to eager mode when block_tables is too large for CUDA graph
326
+ return self.model.compute_logits(self.model(input_ids, positions))
327
+
328
+ graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
329
+ graph_vars = self.graph_vars
330
+ graph_vars["input_ids"][:bs] = input_ids
331
+ graph_vars["positions"][:bs] = positions
332
+ graph_vars["slot_mapping"].fill_(-1)
333
+ graph_vars["slot_mapping"][:bs] = context.slot_mapping
334
+ graph_vars["context_lens"].zero_()
335
+ graph_vars["context_lens"][:bs] = context.context_lens
336
+ # Clear block_tables first to ensure no stale data from previous runs
337
+ graph_vars["block_tables"][:bs].fill_(-1)
338
+ graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
339
+ graph.replay()
340
+ return self.model.compute_logits(graph_vars["outputs"][:bs])
341
+
342
+ def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
343
+ """Run model forward and sampling. For CFG sequences, batch is structured as:
344
+ [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
345
+ where uncond_seqi is the paired unconditional sequence of cond_seqi."""
346
+ # Check if this is a CFG batch (contains paired conditional and unconditional sequences)
347
+ is_cfg_batch = seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None
348
+ if is_cfg_batch:
349
+ # CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
350
+ num_cond = len(seqs) // 2
351
+ cond_seqs = seqs[:num_cond]
352
+ # uncond_seqs = seqs[num_cond:]
353
+
354
+ # Prepare inputs for both conditional and unconditional (they're already in the batch)
355
+ input_ids, positions = (self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs))
356
+ sample_params = self.prepare_sample(seqs, is_cfg_batch=True) if self.rank == 0 else None
357
+ if sample_params is not None:
358
+ temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
359
+ else:
360
+ temperatures = cfg_scales = top_ks = top_ps = repetition_penalties = None
361
+
362
+ # Run model forward (processes entire batch: cond + uncond)
363
+ logits_all = self.run_model(input_ids, positions, is_prefill)
364
+ reset_context()
365
+
366
+ if self.rank == 0:
367
+ # Split logits: first half is conditional, second half is unconditional
368
+ logits_cond = logits_all[:num_cond]
369
+ logits_uncond = logits_all[num_cond:]
370
+
371
+ # Apply repetition penalty to conditional logits (before CFG)
372
+ if repetition_penalties is not None:
373
+ for i, seq in enumerate(cond_seqs):
374
+ penalty = repetition_penalties[i].item()
375
+ if penalty != 1.0:
376
+ # Only penalize completion tokens (not prompt tokens)
377
+ completion_tokens = torch.tensor(seq.completion_token_ids, device=logits_cond.device)
378
+ if len(completion_tokens) > 0:
379
+ # Create token mask: mark tokens that appeared in completion
380
+ token_mask = torch.zeros(logits_cond.shape[1], dtype=torch.bool, device=logits_cond.device)
381
+ token_mask[completion_tokens] = True
382
+
383
+ # Apply standard repetition penalty formula (matching transformers implementation):
384
+ # For tokens in completion: if score < 0 then score * penalty, else score / penalty
385
+ penalty_scores = torch.where(
386
+ logits_cond[i] < 0,
387
+ logits_cond[i] * penalty,
388
+ logits_cond[i] / penalty
389
+ )
390
+ # Only apply penalty to tokens that appeared in completion
391
+ logits_cond[i] = torch.where(token_mask, penalty_scores, logits_cond[i])
392
+
393
+ # Apply CFG formula: logits_cfg = logits_uncond + cfg_scale * (logits_cond - logits_uncond)
394
+ cfg_scales_tensor = cfg_scales.unsqueeze(1) # [num_cond, 1]
395
+ logits_cfg = logits_uncond + cfg_scales_tensor * (logits_cond - logits_uncond)
396
+
397
+ # Apply logits processor for constrained decoding (if any sequence has one)
398
+ for i, seq in enumerate(cond_seqs):
399
+ if seq.logits_processor is not None:
400
+ # Create input_ids tensor for this sequence
401
+ seq_input_ids = torch.tensor([seq.token_ids], device=logits_cfg.device)
402
+ # Apply processor to this sequence's logits
403
+ logits_cfg[i:i+1] = seq.logits_processor(seq_input_ids, logits_cfg[i:i+1])
404
+
405
+ # Prepare input_ids for sampler (for repetition penalty, though we already applied it)
406
+ # cond_input_ids = torch.tensor([seq.token_ids for seq in cond_seqs], device=logits_cfg.device)
407
+
408
+ # Sample from CFG logits
409
+ token_ids_cfg = self.sampler(
410
+ logits_cfg,
411
+ temperatures,
412
+ top_ks=top_ks if top_ks is not None else None,
413
+ top_ps=top_ps if top_ps is not None else None,
414
+ repetition_penalties=None, # Already applied above
415
+ # input_ids=cond_input_ids,
416
+ ).tolist()
417
+
418
+ # Update logits processor state after sampling
419
+ for i, seq in enumerate(cond_seqs):
420
+ if seq.logits_processor_update_state is not None:
421
+ seq.logits_processor_update_state(token_ids_cfg[i])
422
+
423
+ # Return token_ids (will be applied to both conditional and unconditional sequences)
424
+ return token_ids_cfg
425
+ else:
426
+ return None
427
+ else:
428
+ # Normal batch (non-CFG)
429
+ input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
430
+ else self.prepare_decode(seqs))
431
+ sample_params = self.prepare_sample(seqs, is_cfg_batch=False) if self.rank == 0 else None
432
+ if sample_params is not None:
433
+ temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
434
+ else:
435
+ temperatures = cfg_scales = top_ks = top_ps = repetition_penalties = None
436
+ logits = self.run_model(input_ids, positions, is_prefill)
437
+ reset_context()
438
+
439
+ if self.rank == 0:
440
+ # Apply repetition penalty to logits
441
+ if repetition_penalties is not None:
442
+ for i, seq in enumerate(seqs):
443
+ penalty = repetition_penalties[i].item()
444
+ if penalty != 1.0:
445
+ # Only penalize completion tokens (not prompt tokens)
446
+ completion_tokens = torch.tensor(seq.completion_token_ids, device=logits.device)
447
+ if len(completion_tokens) > 0:
448
+ # Create token mask: mark tokens that appeared in completion
449
+ token_mask = torch.zeros(logits.shape[1], dtype=torch.bool, device=logits.device)
450
+ token_mask[completion_tokens] = True
451
+
452
+ # Apply standard repetition penalty formula (matching transformers implementation):
453
+ # For tokens in completion: if score < 0 then score * penalty, else score / penalty
454
+ penalty_scores = torch.where(
455
+ logits[i] < 0,
456
+ logits[i] * penalty,
457
+ logits[i] / penalty
458
+ )
459
+ # Only apply penalty to tokens that appeared in completion
460
+ logits[i] = torch.where(token_mask, penalty_scores, logits[i])
461
+
462
+ # Apply logits processor for constrained decoding (if any sequence has one)
463
+ # Clone logits to avoid in-place update issues in inference mode
464
+ logits = logits.clone()
465
+ for i, seq in enumerate(seqs):
466
+ if seq.logits_processor is not None:
467
+ # Create input_ids tensor for this sequence
468
+ seq_input_ids = torch.tensor([seq.token_ids], device=logits.device)
469
+ # Apply processor to this sequence's logits (clone to avoid inference mode issues)
470
+ processed = seq.logits_processor(seq_input_ids, logits[i:i+1].clone())
471
+ logits[i] = processed[0]
472
+
473
+ # Prepare input_ids for sampler
474
+ # seq_input_ids = torch.tensor([seq.token_ids for seq in seqs], device=logits.device)
475
+
476
+ token_ids = self.sampler(
477
+ logits,
478
+ temperatures,
479
+ top_ks=top_ks if top_ks is not None else None,
480
+ top_ps=top_ps if top_ps is not None else None,
481
+ repetition_penalties=None, # Already applied above
482
+ # input_ids=seq_input_ids,
483
+ ).tolist()
484
+
485
+ # Update logits processor state after sampling
486
+ for i, seq in enumerate(seqs):
487
+ if seq.logits_processor_update_state is not None:
488
+ seq.logits_processor_update_state(token_ids[i])
489
+
490
+ return token_ids
491
+ else:
492
+ return None
493
+
494
+ @torch.inference_mode()
495
+ def capture_cudagraph(self):
496
+ config = self.config
497
+ hf_config = config.hf_config
498
+ max_bs = min(self.config.max_num_seqs, 512)
499
+ max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
500
+ input_ids = torch.zeros(max_bs, dtype=torch.int64)
501
+ positions = torch.zeros(max_bs, dtype=torch.int64)
502
+ slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
503
+ context_lens = torch.zeros(max_bs, dtype=torch.int32)
504
+ block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
505
+ outputs = torch.zeros(max_bs, hf_config.hidden_size)
506
+ self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
507
+ self.graphs = {}
508
+ self.graph_pool = None
509
+
510
+ for bs in reversed(self.graph_bs):
511
+ graph = torch.cuda.CUDAGraph()
512
+ set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
513
+ outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
514
+ with torch.cuda.graph(graph, self.graph_pool):
515
+ outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
516
+ if self.graph_pool is None:
517
+ self.graph_pool = graph.pool()
518
+ self.graphs[bs] = graph
519
+ torch.cuda.synchronize()
520
+ reset_context()
521
+
522
+ self.graph_vars = dict(
523
+ input_ids=input_ids,
524
+ positions=positions,
525
+ slot_mapping=slot_mapping,
526
+ context_lens=context_lens,
527
+ block_tables=block_tables,
528
+ outputs=outputs,
529
+ )
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/scheduler.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+
3
+ from nanovllm.config import Config
4
+ from nanovllm.engine.sequence import Sequence, SequenceStatus
5
+ from nanovllm.engine.block_manager import BlockManager
6
+
7
+
8
+ class Scheduler:
9
+
10
+ def __init__(self, config: Config):
11
+ self.max_num_seqs = config.max_num_seqs
12
+ self.max_num_batched_tokens = config.max_num_batched_tokens
13
+ self.eos = config.eos
14
+ self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
15
+ self.waiting: deque[Sequence] = deque()
16
+ self.running: deque[Sequence] = deque()
17
+
18
+ def is_finished(self):
19
+ return not self.waiting and not self.running
20
+
21
+ def add(self, seq: Sequence):
22
+ self.waiting.append(seq)
23
+
24
+ def schedule(self) -> tuple[list[Sequence], bool]:
25
+ # prefill
26
+ scheduled_seqs = []
27
+ num_seqs = 0
28
+ num_batched_tokens = 0
29
+ processed_seqs = set() # Track processed sequences to handle CFG pairs
30
+
31
+ while self.waiting and num_seqs < self.max_num_seqs:
32
+ seq = self.waiting[0]
33
+
34
+ # For CFG sequences, ensure conditional and unconditional are scheduled together
35
+ if seq.cfg_scale > 1.0 and seq.paired_seq is not None and not seq.is_unconditional:
36
+ # This is a conditional sequence, need to schedule its paired unconditional sequence too
37
+ paired_seq = seq.paired_seq
38
+ if paired_seq.status != SequenceStatus.WAITING:
39
+ # Paired sequence not in waiting, skip this conditional sequence for now
40
+ break
41
+
42
+ # Calculate tokens for both sequences
43
+ total_tokens = (len(seq) - seq.num_cached_tokens) + (len(paired_seq) - paired_seq.num_cached_tokens)
44
+ can_allocate_both = (self.block_manager.can_allocate(seq) and
45
+ self.block_manager.can_allocate(paired_seq))
46
+
47
+ if num_batched_tokens + total_tokens > self.max_num_batched_tokens or not can_allocate_both:
48
+ break
49
+
50
+ # Schedule both sequences: conditional first, then unconditional
51
+ for s in [seq, paired_seq]:
52
+ num_seqs += 1
53
+ self.block_manager.allocate(s)
54
+ num_batched_tokens += len(s) - s.num_cached_tokens
55
+ s.status = SequenceStatus.RUNNING
56
+ self.waiting.remove(s)
57
+ self.running.append(s)
58
+ scheduled_seqs.append(s)
59
+ processed_seqs.add(s.seq_id)
60
+ else:
61
+ # Normal sequence or unconditional sequence (already processed with its conditional)
62
+ if seq.seq_id in processed_seqs:
63
+ # Skip if already processed as part of a CFG pair
64
+ self.waiting.popleft()
65
+ continue
66
+
67
+ if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
68
+ break
69
+ num_seqs += 1
70
+ self.block_manager.allocate(seq)
71
+ num_batched_tokens += len(seq) - seq.num_cached_tokens
72
+ seq.status = SequenceStatus.RUNNING
73
+ self.waiting.popleft()
74
+ self.running.append(seq)
75
+ scheduled_seqs.append(seq)
76
+
77
+ if scheduled_seqs:
78
+ # For CFG batches, ensure conditional sequences come before their unconditional pairs
79
+ cfg_cond_seqs = [s for s in scheduled_seqs if s.cfg_scale > 1.0 and not s.is_unconditional]
80
+ cfg_uncond_seqs = [s for s in scheduled_seqs if s.is_unconditional]
81
+ non_cfg_seqs = [s for s in scheduled_seqs if s.cfg_scale <= 1.0]
82
+
83
+ # Reorder: non-CFG, then CFG conditional, then CFG unconditional
84
+ scheduled_seqs = non_cfg_seqs + cfg_cond_seqs + cfg_uncond_seqs
85
+ return scheduled_seqs, True
86
+
87
+ # decode
88
+ processed_seqs = set()
89
+ temp_running = list(self.running) # Work with a copy
90
+
91
+ while temp_running and num_seqs < self.max_num_seqs:
92
+ seq = temp_running.pop(0)
93
+
94
+ # For CFG sequences, ensure conditional and unconditional are scheduled together
95
+ if seq.cfg_scale > 1.0 and seq.paired_seq is not None and not seq.is_unconditional:
96
+ paired_seq = seq.paired_seq
97
+ if paired_seq not in temp_running:
98
+ # Paired sequence not available, skip for now
99
+ continue
100
+
101
+ # Remove paired_seq from temp_running
102
+ temp_running.remove(paired_seq)
103
+
104
+ # Check if both can append
105
+ can_append_both = (self.block_manager.can_append(seq) and
106
+ self.block_manager.can_append(paired_seq))
107
+
108
+ if not can_append_both:
109
+ # Try preempting other sequences
110
+ preempted = False
111
+ while not can_append_both and temp_running:
112
+ other_seq = temp_running.pop(0)
113
+ if other_seq != seq and other_seq != paired_seq:
114
+ self.preempt(other_seq)
115
+ can_append_both = (self.block_manager.can_append(seq) and
116
+ self.block_manager.can_append(paired_seq))
117
+ preempted = True
118
+ else:
119
+ temp_running.append(other_seq)
120
+ break
121
+
122
+ if not can_append_both:
123
+ # Can't schedule this pair right now
124
+ temp_running.append(seq)
125
+ temp_running.append(paired_seq)
126
+ continue
127
+
128
+ # Schedule both sequences
129
+ for s in [seq, paired_seq]:
130
+ num_seqs += 1
131
+ self.block_manager.may_append(s)
132
+ scheduled_seqs.append(s)
133
+ processed_seqs.add(s.seq_id)
134
+ # Remove from actual running list if scheduled
135
+ if s in self.running:
136
+ self.running.remove(s)
137
+ else:
138
+ # Normal sequence or unconditional (already processed)
139
+ if seq.seq_id in processed_seqs:
140
+ continue
141
+
142
+ while not self.block_manager.can_append(seq):
143
+ if temp_running:
144
+ other_seq = temp_running.pop(0)
145
+ if other_seq != seq:
146
+ self.preempt(other_seq)
147
+ else:
148
+ temp_running.append(other_seq)
149
+ break
150
+ else:
151
+ self.preempt(seq)
152
+ if seq in self.running:
153
+ self.running.remove(seq)
154
+ break
155
+ else:
156
+ num_seqs += 1
157
+ self.block_manager.may_append(seq)
158
+ scheduled_seqs.append(seq)
159
+ if seq in self.running:
160
+ self.running.remove(seq)
161
+
162
+ assert scheduled_seqs
163
+
164
+ # For CFG batches in decode, ensure conditional sequences come before unconditional
165
+ cfg_cond_seqs = [s for s in scheduled_seqs if s.cfg_scale > 1.0 and not s.is_unconditional]
166
+ cfg_uncond_seqs = [s for s in scheduled_seqs if s.is_unconditional]
167
+ non_cfg_seqs = [s for s in scheduled_seqs if s.cfg_scale <= 1.0]
168
+ scheduled_seqs = non_cfg_seqs + cfg_cond_seqs + cfg_uncond_seqs
169
+
170
+ self.running.extendleft(reversed(scheduled_seqs))
171
+ return scheduled_seqs, False
172
+
173
+ def preempt(self, seq: Sequence):
174
+ seq.status = SequenceStatus.WAITING
175
+ self.block_manager.deallocate(seq)
176
+ self.waiting.appendleft(seq)
177
+
178
+ def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
179
+ # Check if this is a CFG batch
180
+ is_cfg_batch = False
181
+ if len(seqs) > 0 and seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None:
182
+ num_cond = len(seqs) // 2
183
+ is_cfg_batch = (num_cond > 0 and
184
+ not seqs[0].is_unconditional and
185
+ seqs[num_cond].is_unconditional)
186
+
187
+ if is_cfg_batch:
188
+ # CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
189
+ # token_ids correspond to conditional sequences only (sampled from CFG logits)
190
+ num_cond = len(seqs) // 2
191
+ cond_seqs = seqs[:num_cond]
192
+ uncond_seqs = seqs[num_cond:]
193
+
194
+ # Apply the same sampled token to both conditional and unconditional sequences
195
+ for i, (cond_seq, uncond_seq, token_id) in enumerate(zip(cond_seqs, uncond_seqs, token_ids)):
196
+ cond_seq.append_token(token_id)
197
+ uncond_seq.append_token(token_id) # Same token for unconditional
198
+
199
+ # Check if either sequence is finished
200
+ cond_finished = ((not cond_seq.ignore_eos and token_id == self.eos) or
201
+ cond_seq.num_completion_tokens == cond_seq.max_tokens)
202
+ uncond_finished = ((not uncond_seq.ignore_eos and token_id == self.eos) or
203
+ uncond_seq.num_completion_tokens == uncond_seq.max_tokens)
204
+
205
+ if cond_finished or uncond_finished:
206
+ # Mark both as finished
207
+ cond_seq.status = SequenceStatus.FINISHED
208
+ uncond_seq.status = SequenceStatus.FINISHED
209
+ self.block_manager.deallocate(cond_seq)
210
+ self.block_manager.deallocate(uncond_seq)
211
+ if cond_seq in self.running:
212
+ self.running.remove(cond_seq)
213
+ if uncond_seq in self.running:
214
+ self.running.remove(uncond_seq)
215
+ else:
216
+ # Normal batch
217
+ for seq, token_id in zip(seqs, token_ids):
218
+ seq.append_token(token_id)
219
+ if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
220
+ seq.status = SequenceStatus.FINISHED
221
+ self.block_manager.deallocate(seq)
222
+ self.running.remove(seq)
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/sequence.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import copy
2
+ from enum import Enum, auto
3
+ from itertools import count
4
+ from typing import Optional, Callable, Any
5
+
6
+ from nanovllm.sampling_params import SamplingParams
7
+
8
+
9
+ class SequenceStatus(Enum):
10
+ WAITING = auto()
11
+ RUNNING = auto()
12
+ FINISHED = auto()
13
+
14
+
15
+ class Sequence:
16
+ block_size = 256
17
+ counter = count()
18
+
19
+ def __init__(self, token_ids: list[int], sampling_params = SamplingParams(), is_unconditional: bool = False, conditional_seq = None):
20
+ self.seq_id = next(Sequence.counter)
21
+ self.status = SequenceStatus.WAITING
22
+ self.token_ids = copy(token_ids)
23
+ self.last_token = token_ids[-1]
24
+ self.num_tokens = len(self.token_ids)
25
+ self.num_prompt_tokens = len(token_ids)
26
+ self.num_cached_tokens = 0
27
+ self.block_table = []
28
+ self.temperature = sampling_params.temperature
29
+ self.max_tokens = sampling_params.max_tokens
30
+ self.ignore_eos = sampling_params.ignore_eos
31
+ self.cfg_scale = sampling_params.cfg_scale
32
+ self.top_k = sampling_params.top_k
33
+ self.top_p = sampling_params.top_p
34
+ self.repetition_penalty = sampling_params.repetition_penalty
35
+ # For CFG: mark if this is an unconditional sequence
36
+ self.is_unconditional = is_unconditional
37
+ # For CFG: reference to the corresponding conditional sequence (if this is unconditional)
38
+ # For conditional sequences, this points to the unconditional sequence
39
+ self.paired_seq = conditional_seq # For conditional seq, points to uncond; for uncond seq, points to cond
40
+ # For constrained decoding: logits processor and state update callback
41
+ self.logits_processor: Optional[Any] = sampling_params.logits_processor
42
+ self.logits_processor_update_state: Optional[Callable[[int], None]] = sampling_params.logits_processor_update_state
43
+
44
+ def __len__(self):
45
+ return self.num_tokens
46
+
47
+ def __getitem__(self, key):
48
+ return self.token_ids[key]
49
+
50
+ @property
51
+ def is_finished(self):
52
+ return self.status == SequenceStatus.FINISHED
53
+
54
+ @property
55
+ def num_completion_tokens(self):
56
+ return self.num_tokens - self.num_prompt_tokens
57
+
58
+ @property
59
+ def prompt_token_ids(self):
60
+ return self.token_ids[:self.num_prompt_tokens]
61
+
62
+ @property
63
+ def completion_token_ids(self):
64
+ return self.token_ids[self.num_prompt_tokens:]
65
+
66
+ @property
67
+ def num_cached_blocks(self):
68
+ return self.num_cached_tokens // self.block_size
69
+
70
+ @property
71
+ def num_blocks(self):
72
+ return (self.num_tokens + self.block_size - 1) // self.block_size
73
+
74
+ @property
75
+ def last_block_num_tokens(self):
76
+ return self.num_tokens - (self.num_blocks - 1) * self.block_size
77
+
78
+ def block(self, i):
79
+ assert 0 <= i < self.num_blocks
80
+ return self.token_ids[i*self.block_size: (i+1)*self.block_size]
81
+
82
+ def append_token(self, token_id: int):
83
+ self.token_ids.append(token_id)
84
+ self.last_token = token_id
85
+ self.num_tokens += 1
86
+
87
+ def __getstate__(self):
88
+ return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
89
+ self.token_ids if self.num_completion_tokens == 0 else self.last_token)
90
+
91
+ def __setstate__(self, state):
92
+ self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
93
+ if self.num_completion_tokens == 0:
94
+ self.token_ids = state[-1]
95
+ else:
96
+ self.last_token = state[-1]
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/layers/activation.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class SiluAndMul(nn.Module):
7
+
8
+ def __init__(self):
9
+ super().__init__()
10
+
11
+ @torch.compile
12
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
13
+ x, y = x.chunk(2, -1)
14
+ return F.silu(x) * y
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/layers/attention.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import triton
4
+ import triton.language as tl
5
+
6
+ from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
7
+ from nanovllm.utils.context import get_context
8
+
9
+
10
+ @triton.jit
11
+ def store_kvcache_kernel(
12
+ key_ptr,
13
+ key_stride,
14
+ value_ptr,
15
+ value_stride,
16
+ k_cache_ptr,
17
+ v_cache_ptr,
18
+ slot_mapping_ptr,
19
+ D: tl.constexpr,
20
+ ):
21
+ idx = tl.program_id(0)
22
+ slot = tl.load(slot_mapping_ptr + idx)
23
+ if slot == -1: return
24
+ key_offsets = idx * key_stride + tl.arange(0, D)
25
+ value_offsets = idx * value_stride + tl.arange(0, D)
26
+ key = tl.load(key_ptr + key_offsets)
27
+ value = tl.load(value_ptr + value_offsets)
28
+ cache_offsets = slot * D + tl.arange(0, D)
29
+ tl.store(k_cache_ptr + cache_offsets, key)
30
+ tl.store(v_cache_ptr + cache_offsets, value)
31
+
32
+
33
+ def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
34
+ N, num_heads, head_dim = key.shape
35
+ D = num_heads * head_dim
36
+ assert key.stride(-1) == 1 and value.stride(-1) == 1
37
+ assert key.stride(1) == head_dim and value.stride(1) == head_dim
38
+ assert k_cache.stride(1) == D and v_cache.stride(1) == D
39
+ assert slot_mapping.numel() == N
40
+ store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
41
+
42
+
43
+ class Attention(nn.Module):
44
+
45
+ def __init__(
46
+ self,
47
+ num_heads,
48
+ head_dim,
49
+ scale,
50
+ num_kv_heads,
51
+ ):
52
+ super().__init__()
53
+ self.num_heads = num_heads
54
+ self.head_dim = head_dim
55
+ self.scale = scale
56
+ self.num_kv_heads = num_kv_heads
57
+ self.k_cache = self.v_cache = torch.tensor([])
58
+
59
+ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
60
+ context = get_context()
61
+ k_cache, v_cache = self.k_cache, self.v_cache
62
+ if k_cache.numel() and v_cache.numel():
63
+ store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
64
+ if context.is_prefill:
65
+ if context.block_tables is not None: # prefix cache
66
+ k, v = k_cache, v_cache
67
+ o = flash_attn_varlen_func(q, k, v,
68
+ max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
69
+ max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
70
+ softmax_scale=self.scale, causal=True, block_table=context.block_tables)
71
+ else: # decode
72
+ o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
73
+ cache_seqlens=context.context_lens, block_table=context.block_tables,
74
+ softmax_scale=self.scale, causal=True)
75
+ return o
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/layers/embed_head.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ import torch.distributed as dist
5
+
6
+ from nanovllm.utils.context import get_context
7
+
8
+
9
+ class VocabParallelEmbedding(nn.Module):
10
+
11
+ def __init__(
12
+ self,
13
+ num_embeddings: int,
14
+ embedding_dim: int,
15
+ ):
16
+ super().__init__()
17
+ self.tp_rank = dist.get_rank()
18
+ self.tp_size = dist.get_world_size()
19
+ assert num_embeddings % self.tp_size == 0
20
+ self.num_embeddings = num_embeddings
21
+ self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
22
+ self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
23
+ self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
24
+ self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
25
+ self.weight.weight_loader = self.weight_loader
26
+
27
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
28
+ param_data = param.data
29
+ shard_size = param_data.size(0)
30
+ start_idx = self.tp_rank * shard_size
31
+ loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
32
+ param_data.copy_(loaded_weight)
33
+
34
+ def forward(self, x: torch.Tensor):
35
+ if self.tp_size > 1:
36
+ mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
37
+ x = mask * (x - self.vocab_start_idx)
38
+ y = F.embedding(x, self.weight)
39
+ if self.tp_size > 1:
40
+ y = mask.unsqueeze(1) * y
41
+ dist.all_reduce(y)
42
+ return y
43
+
44
+
45
+ class ParallelLMHead(VocabParallelEmbedding):
46
+
47
+ def __init__(
48
+ self,
49
+ num_embeddings: int,
50
+ embedding_dim: int,
51
+ bias: bool = False,
52
+ ):
53
+ assert not bias
54
+ super().__init__(num_embeddings, embedding_dim)
55
+
56
+ def forward(self, x: torch.Tensor):
57
+ context = get_context()
58
+ if context.is_prefill:
59
+ last_indices = context.cu_seqlens_q[1:] - 1
60
+ x = x[last_indices].contiguous()
61
+ logits = F.linear(x, self.weight)
62
+ if self.tp_size > 1:
63
+ all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
64
+ dist.gather(logits, all_logits, 0)
65
+ logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
66
+ return logits