Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +6 -0
- spaces/Ace-Step-v1.5/.env.example +4 -0
- spaces/Ace-Step-v1.5/.gitattributes +38 -0
- spaces/Ace-Step-v1.5/.gitignore +228 -0
- spaces/Ace-Step-v1.5/Dockerfile +57 -0
- spaces/Ace-Step-v1.5/LICENSE +21 -0
- spaces/Ace-Step-v1.5/README.md +229 -0
- spaces/Ace-Step-v1.5/acestep/__init__.py +1 -0
- spaces/Ace-Step-v1.5/acestep/acestep_v15_pipeline.py +303 -0
- spaces/Ace-Step-v1.5/acestep/api_server.py +1700 -0
- spaces/Ace-Step-v1.5/acestep/audio_utils.py +378 -0
- spaces/Ace-Step-v1.5/acestep/constants.py +109 -0
- spaces/Ace-Step-v1.5/acestep/constrained_logits_processor.py +0 -0
- spaces/Ace-Step-v1.5/acestep/dataset_handler.py +37 -0
- spaces/Ace-Step-v1.5/acestep/dit_alignment_score.py +870 -0
- spaces/Ace-Step-v1.5/acestep/genres_vocab.txt +0 -0
- spaces/Ace-Step-v1.5/acestep/gradio_ui/__init__.py +1 -0
- spaces/Ace-Step-v1.5/acestep/gradio_ui/events/__init__.py +1310 -0
- spaces/Ace-Step-v1.5/acestep/gradio_ui/events/generation_handlers.py +1054 -0
- spaces/Ace-Step-v1.5/acestep/gradio_ui/events/results_handlers.py +0 -0
- spaces/Ace-Step-v1.5/acestep/gradio_ui/events/training_handlers.py +644 -0
- spaces/Ace-Step-v1.5/acestep/gradio_ui/i18n.py +152 -0
- spaces/Ace-Step-v1.5/acestep/gradio_ui/i18n/en.json +245 -0
- spaces/Ace-Step-v1.5/acestep/gradio_ui/i18n/ja.json +245 -0
- spaces/Ace-Step-v1.5/acestep/gradio_ui/i18n/zh.json +245 -0
- spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/__init__.py +98 -0
- spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/dataset.py +101 -0
- spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/generation.py +693 -0
- spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/result.py +598 -0
- spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/training.py +562 -0
- spaces/Ace-Step-v1.5/acestep/handler.py +0 -0
- spaces/Ace-Step-v1.5/acestep/inference.py +1182 -0
- spaces/Ace-Step-v1.5/acestep/llm_inference.py +0 -0
- spaces/Ace-Step-v1.5/acestep/local_cache.py +129 -0
- spaces/Ace-Step-v1.5/acestep/test_time_scaling.py +410 -0
- spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/LICENSE +21 -0
- spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/README.md +66 -0
- spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/assets/logo.png +3 -0
- spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/bench.py +32 -0
- spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/example.py +33 -0
- spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/__init__.py +2 -0
- spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/config.py +26 -0
- spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/block_manager.py +112 -0
- spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py +124 -0
- spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py +529 -0
- spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/scheduler.py +222 -0
- spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/sequence.py +96 -0
- spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/layers/activation.py +14 -0
- spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/layers/attention.py +75 -0
- spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/layers/embed_head.py +66 -0
.gitattributes
CHANGED
|
@@ -39,3 +39,9 @@ code/assets/acestudio_logo.png filter=lfs diff=lfs merge=lfs -text
|
|
| 39 |
code/assets/application_map.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
code/assets/model_zoo.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
code/assets/orgnization_logos.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
code/assets/application_map.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
code/assets/model_zoo.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
code/assets/orgnization_logos.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/assets/logo.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
spaces/Ace-Step-v1.5/assets/ACE-Step_framework.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
spaces/Ace-Step-v1.5/assets/acestudio_logo.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
spaces/Ace-Step-v1.5/assets/application_map.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
spaces/Ace-Step-v1.5/assets/model_zoo.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
spaces/Ace-Step-v1.5/assets/orgnization_logos.png filter=lfs diff=lfs merge=lfs -text
|
spaces/Ace-Step-v1.5/.env.example
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ACESTEP_CONFIG_PATH=acestep-v15-turbo
|
| 2 |
+
ACESTEP_LM_MODEL_PATH=acestep-5Hz-lm-1.7B
|
| 3 |
+
ACESTEP_DEVICE=auto
|
| 4 |
+
ACESTEP_LM_BACKEND=vllm
|
spaces/Ace-Step-v1.5/.gitattributes
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
spaces/Ace-Step-v1.5/.gitignore
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data/
|
| 2 |
+
*.mp3
|
| 3 |
+
*.wav
|
| 4 |
+
|
| 5 |
+
# Byte-compiled / optimized / DLL files
|
| 6 |
+
__pycache__/
|
| 7 |
+
*.py[codz]
|
| 8 |
+
*$py.class
|
| 9 |
+
|
| 10 |
+
# C extensions
|
| 11 |
+
*.so
|
| 12 |
+
|
| 13 |
+
# Distribution / packaging
|
| 14 |
+
.Python
|
| 15 |
+
build/
|
| 16 |
+
develop-eggs/
|
| 17 |
+
dist/
|
| 18 |
+
downloads/
|
| 19 |
+
eggs/
|
| 20 |
+
.eggs/
|
| 21 |
+
lib/
|
| 22 |
+
lib64/
|
| 23 |
+
parts/
|
| 24 |
+
sdist/
|
| 25 |
+
var/
|
| 26 |
+
wheels/
|
| 27 |
+
share/python-wheels/
|
| 28 |
+
*.egg-info/
|
| 29 |
+
.installed.cfg
|
| 30 |
+
*.egg
|
| 31 |
+
MANIFEST
|
| 32 |
+
|
| 33 |
+
# PyInstaller
|
| 34 |
+
# Usually these files are written by a python script from a template
|
| 35 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 36 |
+
*.manifest
|
| 37 |
+
*.spec
|
| 38 |
+
|
| 39 |
+
# Installer logs
|
| 40 |
+
pip-log.txt
|
| 41 |
+
pip-delete-this-directory.txt
|
| 42 |
+
|
| 43 |
+
# Unit test / coverage reports
|
| 44 |
+
htmlcov/
|
| 45 |
+
.tox/
|
| 46 |
+
.nox/
|
| 47 |
+
.coverage
|
| 48 |
+
.coverage.*
|
| 49 |
+
.cache
|
| 50 |
+
nosetests.xml
|
| 51 |
+
coverage.xml
|
| 52 |
+
*.cover
|
| 53 |
+
*.py.cover
|
| 54 |
+
.hypothesis/
|
| 55 |
+
.pytest_cache/
|
| 56 |
+
cover/
|
| 57 |
+
|
| 58 |
+
# Translations
|
| 59 |
+
*.mo
|
| 60 |
+
*.pot
|
| 61 |
+
|
| 62 |
+
# Django stuff:
|
| 63 |
+
*.log
|
| 64 |
+
local_settings.py
|
| 65 |
+
db.sqlite3
|
| 66 |
+
db.sqlite3-journal
|
| 67 |
+
|
| 68 |
+
# Flask stuff:
|
| 69 |
+
instance/
|
| 70 |
+
.webassets-cache
|
| 71 |
+
|
| 72 |
+
# Scrapy stuff:
|
| 73 |
+
.scrapy
|
| 74 |
+
|
| 75 |
+
# Sphinx documentation
|
| 76 |
+
docs/_build/
|
| 77 |
+
|
| 78 |
+
# PyBuilder
|
| 79 |
+
.pybuilder/
|
| 80 |
+
target/
|
| 81 |
+
|
| 82 |
+
# Jupyter Notebook
|
| 83 |
+
.ipynb_checkpoints
|
| 84 |
+
|
| 85 |
+
# IPython
|
| 86 |
+
profile_default/
|
| 87 |
+
ipython_config.py
|
| 88 |
+
|
| 89 |
+
# pyenv
|
| 90 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 91 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 92 |
+
# .python-version
|
| 93 |
+
|
| 94 |
+
# pipenv
|
| 95 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 96 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 97 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 98 |
+
# install all needed dependencies.
|
| 99 |
+
#Pipfile.lock
|
| 100 |
+
|
| 101 |
+
# UV
|
| 102 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 103 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 104 |
+
# commonly ignored for libraries.
|
| 105 |
+
uv.lock
|
| 106 |
+
|
| 107 |
+
# poetry
|
| 108 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 109 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 110 |
+
# commonly ignored for libraries.
|
| 111 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 112 |
+
#poetry.lock
|
| 113 |
+
#poetry.toml
|
| 114 |
+
|
| 115 |
+
# pdm
|
| 116 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 117 |
+
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
| 118 |
+
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
| 119 |
+
#pdm.lock
|
| 120 |
+
#pdm.toml
|
| 121 |
+
.pdm-python
|
| 122 |
+
.pdm-build/
|
| 123 |
+
|
| 124 |
+
# pixi
|
| 125 |
+
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
| 126 |
+
#pixi.lock
|
| 127 |
+
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
| 128 |
+
# in the .venv directory. It is recommended not to include this directory in version control.
|
| 129 |
+
.pixi
|
| 130 |
+
|
| 131 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 132 |
+
__pypackages__/
|
| 133 |
+
|
| 134 |
+
# Celery stuff
|
| 135 |
+
celerybeat-schedule
|
| 136 |
+
celerybeat.pid
|
| 137 |
+
|
| 138 |
+
# SageMath parsed files
|
| 139 |
+
*.sage.py
|
| 140 |
+
|
| 141 |
+
# Environments
|
| 142 |
+
.env
|
| 143 |
+
.envrc
|
| 144 |
+
.venv
|
| 145 |
+
env/
|
| 146 |
+
venv/
|
| 147 |
+
ENV/
|
| 148 |
+
env.bak/
|
| 149 |
+
venv.bak/
|
| 150 |
+
|
| 151 |
+
# Spyder project settings
|
| 152 |
+
.spyderproject
|
| 153 |
+
.spyproject
|
| 154 |
+
|
| 155 |
+
# Rope project settings
|
| 156 |
+
.ropeproject
|
| 157 |
+
|
| 158 |
+
# mkdocs documentation
|
| 159 |
+
/site
|
| 160 |
+
|
| 161 |
+
# mypy
|
| 162 |
+
.mypy_cache/
|
| 163 |
+
.dmypy.json
|
| 164 |
+
dmypy.json
|
| 165 |
+
|
| 166 |
+
# Pyre type checker
|
| 167 |
+
.pyre/
|
| 168 |
+
|
| 169 |
+
# pytype static type analyzer
|
| 170 |
+
.pytype/
|
| 171 |
+
|
| 172 |
+
# Cython debug symbols
|
| 173 |
+
cython_debug/
|
| 174 |
+
|
| 175 |
+
# PyCharm
|
| 176 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 177 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 178 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 179 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 180 |
+
#.idea/
|
| 181 |
+
|
| 182 |
+
# Abstra
|
| 183 |
+
# Abstra is an AI-powered process automation framework.
|
| 184 |
+
# Ignore directories containing user credentials, local state, and settings.
|
| 185 |
+
# Learn more at https://abstra.io/docs
|
| 186 |
+
.abstra/
|
| 187 |
+
|
| 188 |
+
# Visual Studio Code
|
| 189 |
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
| 190 |
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
| 191 |
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
| 192 |
+
# you could uncomment the following to ignore the entire vscode folder
|
| 193 |
+
# .vscode/
|
| 194 |
+
|
| 195 |
+
# Ruff stuff:
|
| 196 |
+
.ruff_cache/
|
| 197 |
+
|
| 198 |
+
# PyPI configuration file
|
| 199 |
+
.pypirc
|
| 200 |
+
|
| 201 |
+
# Cursor
|
| 202 |
+
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
|
| 203 |
+
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
| 204 |
+
# refer to https://docs.cursor.com/context/ignore-files
|
| 205 |
+
.cursorignore
|
| 206 |
+
.cursorindexingignore
|
| 207 |
+
|
| 208 |
+
# Marimo
|
| 209 |
+
marimo/_static/
|
| 210 |
+
marimo/_lsp/
|
| 211 |
+
__marimo__/
|
| 212 |
+
tests/
|
| 213 |
+
checkpoints/
|
| 214 |
+
playground.ipynb
|
| 215 |
+
.history/
|
| 216 |
+
upload_checkpoints.sh
|
| 217 |
+
checkpoints.7z
|
| 218 |
+
README_old.md
|
| 219 |
+
discord_bot/
|
| 220 |
+
feishu_bot/
|
| 221 |
+
tmp*
|
| 222 |
+
torchinductor_root/
|
| 223 |
+
scripts/
|
| 224 |
+
checkpoints_legacy/
|
| 225 |
+
lora_output/
|
| 226 |
+
datasets/
|
| 227 |
+
python_embeded/
|
| 228 |
+
checkpoints_pack/
|
spaces/Ace-Step-v1.5/Dockerfile
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# HuggingFace Space Docker SDK
|
| 2 |
+
# Use slim Python image - HuggingFace GPU Spaces provide CUDA runtime
|
| 3 |
+
FROM python:3.11-slim
|
| 4 |
+
|
| 5 |
+
# Set environment variables
|
| 6 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 7 |
+
PYTHONUNBUFFERED=1 \
|
| 8 |
+
DEBIAN_FRONTEND=noninteractive \
|
| 9 |
+
TORCHAUDIO_USE_TORCHCODEC=0
|
| 10 |
+
|
| 11 |
+
# Install system dependencies
|
| 12 |
+
# build-essential is required for triton to compile CUDA kernels
|
| 13 |
+
# ffmpeg and libav* dev packages are required for torchaudio's ffmpeg backend
|
| 14 |
+
# Note: torchaudio's ffmpeg backend needs shared libraries, not just the ffmpeg binary
|
| 15 |
+
RUN apt-get update && \
|
| 16 |
+
apt-get install -y --no-install-recommends git libsndfile1 build-essential && \
|
| 17 |
+
apt-get install -y ffmpeg libavcodec-dev libavformat-dev libavutil-dev libswresample-dev && \
|
| 18 |
+
rm -rf /var/lib/apt/lists/*
|
| 19 |
+
|
| 20 |
+
# Set up a new user named "user" with user ID 1000 (HuggingFace Space requirement)
|
| 21 |
+
RUN useradd -m -u 1000 user
|
| 22 |
+
|
| 23 |
+
# Create /data directory with proper permissions for persistent storage
|
| 24 |
+
RUN mkdir -p /data && chown user:user /data && chmod 755 /data
|
| 25 |
+
|
| 26 |
+
# Set environment variables for user
|
| 27 |
+
ENV HOME=/home/user \
|
| 28 |
+
PATH=/home/user/.local/bin:$PATH \
|
| 29 |
+
GRADIO_SERVER_NAME=0.0.0.0 \
|
| 30 |
+
GRADIO_SERVER_PORT=7860
|
| 31 |
+
|
| 32 |
+
# Set the working directory
|
| 33 |
+
WORKDIR $HOME/app
|
| 34 |
+
|
| 35 |
+
# Copy requirements first for better Docker layer caching
|
| 36 |
+
COPY --chown=user:user requirements.txt .
|
| 37 |
+
|
| 38 |
+
# Copy the local nano-vllm package
|
| 39 |
+
COPY --chown=user:user acestep/third_parts/nano-vllm ./acestep/third_parts/nano-vllm
|
| 40 |
+
|
| 41 |
+
# Switch to user before installing packages
|
| 42 |
+
USER user
|
| 43 |
+
|
| 44 |
+
# Install dependencies from requirements.txt (includes PyTorch with CUDA from --extra-index-url)
|
| 45 |
+
RUN pip install --no-cache-dir --user -r requirements.txt
|
| 46 |
+
|
| 47 |
+
# Install nano-vllm with --no-deps since all dependencies are already installed
|
| 48 |
+
RUN pip install --no-deps ./acestep/third_parts/nano-vllm
|
| 49 |
+
|
| 50 |
+
# Copy the rest of the application
|
| 51 |
+
COPY --chown=user:user . .
|
| 52 |
+
|
| 53 |
+
# Expose port
|
| 54 |
+
EXPOSE 7860
|
| 55 |
+
|
| 56 |
+
# Run the application
|
| 57 |
+
CMD ["python", "app.py"]
|
spaces/Ace-Step-v1.5/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 ACEStep
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
spaces/Ace-Step-v1.5/README.md
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: ACE-Step v1.5
|
| 3 |
+
emoji: 🎵
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
pinned: false
|
| 9 |
+
license: mit
|
| 10 |
+
short_description: Music Generation Foundation Model v1.5
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
<h1 align="center">ACE-Step 1.5</h1>
|
| 14 |
+
<h1 align="center">Pushing the Boundaries of Open-Source Music Generation</h1>
|
| 15 |
+
<p align="center">
|
| 16 |
+
<a href="https://ace-step.github.io/ace-step-v1.5.github.io/">Project</a> |
|
| 17 |
+
<a href="https://huggingface.co/collections/ACE-Step/ace-step-15">Hugging Face</a> |
|
| 18 |
+
<a href="https://modelscope.cn/models/ACE-Step/ACE-Step-v1-5">ModelScope</a> |
|
| 19 |
+
<a href="https://huggingface.co/spaces/ACE-Step/Ace-Step-v1.5">Space Demo</a> |
|
| 20 |
+
<a href="https://discord.gg/PeWDxrkdj7">Discord</a> |
|
| 21 |
+
<a href="https://arxiv.org/abs/2506.00045">Technical Report</a>
|
| 22 |
+
</p>
|
| 23 |
+
|
| 24 |
+
<p align="center">
|
| 25 |
+
<img src="./assets/orgnization_logos.png" width="100%" alt="StepFun Logo">
|
| 26 |
+
</p>
|
| 27 |
+
|
| 28 |
+
## Table of Contents
|
| 29 |
+
|
| 30 |
+
- [✨ Features](#-features)
|
| 31 |
+
- [📦 Installation](#-installation)
|
| 32 |
+
- [🚀 Usage](#-usage)
|
| 33 |
+
- [🔨 Train](#-train)
|
| 34 |
+
- [🏗️ Architecture](#️-architecture)
|
| 35 |
+
- [🦁 Model Zoo](#-model-zoo)
|
| 36 |
+
|
| 37 |
+
## 📝 Abstract
|
| 38 |
+
We present ACE-Step v1.5, a highly efficient foundation model that democratizes commercial-grade music production on consumer hardware. Optimized for local deployment (<4GB VRAM), the model accelerates generation by over 100× compared to traditional pure LM architectures, producing superior high-fidelity audio in seconds characterized by coherent semantics and exceptional melodies. At its core lies a novel hybrid architecture where the Language Model (LM) functions as an omni-capable planner: it transforms simple user queries into comprehensive song blueprints—scaling from short loops to 10-minute compositions—while synthesizing metadata, lyrics, and captions via Chain-of-Thought to guide the Diffusion Transformer (DiT). Uniquely, this alignment is achieved through intrinsic reinforcement learning relying solely on the model’s internal mechanisms, thereby eliminating the biases inherent in external reward models or human preferences. Beyond standard synthesis, ACE-Step v1.5 unifies precise stylistic control with versatile editing capabilities—such as cover generation, repainting, and vocal-to-BGM conversion—while maintaining strict adherence to prompts across 50+ languages.
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
## ✨ Features
|
| 42 |
+
|
| 43 |
+
<p align="center">
|
| 44 |
+
<img src="./assets/application_map.png" width="100%" alt="ACE-Step Framework">
|
| 45 |
+
</p>
|
| 46 |
+
|
| 47 |
+
### ⚡ Performance
|
| 48 |
+
- ✅ **Ultra-Fast Generation** — 0.5s to 10s generation time on A100 (depending on think mode & diffusion steps)
|
| 49 |
+
- ✅ **Flexible Duration** — Supports 10 seconds to 10 minutes (600s) audio generation
|
| 50 |
+
- ✅ **Batch Generation** — Generate up to 8 songs simultaneously
|
| 51 |
+
|
| 52 |
+
### 🎵 Generation Quality
|
| 53 |
+
- ✅ **Commercial-Grade Output** — Quality between Suno v4.5 and Suno v5
|
| 54 |
+
- ✅ **Rich Style Support** — 1000+ instruments and styles with fine-grained timbre description
|
| 55 |
+
- ✅ **Multi-Language Lyrics** — Supports 50+ languages with lyrics prompt for structure & style control
|
| 56 |
+
|
| 57 |
+
### 🎛️ Versatility & Control
|
| 58 |
+
|
| 59 |
+
| Feature | Description |
|
| 60 |
+
|---------|-------------|
|
| 61 |
+
| ✅ Reference Audio Input | Use reference audio to guide generation style |
|
| 62 |
+
| ✅ Cover Generation | Create covers from existing audio |
|
| 63 |
+
| ✅ Repaint & Edit | Selective local audio editing and regeneration |
|
| 64 |
+
| ✅ Track Separation | Separate audio into individual stems |
|
| 65 |
+
| ✅ Multi-Track Generation | Add layers like Suno Studio's "Add Layer" feature |
|
| 66 |
+
| ✅ Vocal2BGM | Auto-generate accompaniment for vocal tracks |
|
| 67 |
+
| ✅ Metadata Control | Control duration, BPM, key/scale, time signature |
|
| 68 |
+
| ✅ Simple Mode | Generate full songs from simple descriptions |
|
| 69 |
+
| ✅ Query Rewriting | Auto LM expansion of tags and lyrics |
|
| 70 |
+
| ✅ Audio Understanding | Extract BPM, key/scale, time signature & caption from audio |
|
| 71 |
+
| ✅ LRC Generation | Auto-generate lyric timestamps for generated music |
|
| 72 |
+
| ✅ LoRA Training | One-click annotation & training in Gradio. 8 songs, 1 hour on 3090 (12GB VRAM) |
|
| 73 |
+
| ✅ Quality Scoring | Automatic quality assessment for generated audio |
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
## 📦 Installation
|
| 78 |
+
|
| 79 |
+
> **Requirements:** Python 3.11, CUDA GPU recommended (works on CPU/MPS but slower)
|
| 80 |
+
|
| 81 |
+
### 1. Install uv (Package Manager)
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
# macOS / Linux
|
| 85 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 86 |
+
|
| 87 |
+
# Windows (PowerShell)
|
| 88 |
+
powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
### 2. Clone & Install
|
| 92 |
+
|
| 93 |
+
```bash
|
| 94 |
+
git clone https://github.com/ACE-Step/ACE-Step-1.5.git
|
| 95 |
+
cd ACE-Step-1.5
|
| 96 |
+
uv sync
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
### 3. Launch
|
| 100 |
+
|
| 101 |
+
#### 🖥️ Gradio Web UI (Recommended)
|
| 102 |
+
|
| 103 |
+
```bash
|
| 104 |
+
uv run acestep
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
Open http://localhost:7860 in your browser. Models will be downloaded automatically on first run.
|
| 108 |
+
|
| 109 |
+
#### 🌐 REST API Server
|
| 110 |
+
|
| 111 |
+
```bash
|
| 112 |
+
uv run acestep-api
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
API runs at http://localhost:8001. See [API Documentation](./docs/en/API.md) for endpoints.
|
| 116 |
+
|
| 117 |
+
### Command Line Options
|
| 118 |
+
|
| 119 |
+
**Gradio UI (`acestep`):**
|
| 120 |
+
|
| 121 |
+
| Option | Default | Description |
|
| 122 |
+
|--------|---------|-------------|
|
| 123 |
+
| `--port` | 7860 | Server port |
|
| 124 |
+
| `--server-name` | 127.0.0.1 | Server address (use `0.0.0.0` for network access) |
|
| 125 |
+
| `--share` | false | Create public Gradio link |
|
| 126 |
+
| `--language` | en | UI language: `en`, `zh`, `ja` |
|
| 127 |
+
| `--init_service` | false | Auto-initialize models on startup |
|
| 128 |
+
| `--config_path` | auto | DiT model (e.g., `acestep-v15-turbo`, `acestep-v15-turbo-shift3`) |
|
| 129 |
+
| `--lm_model_path` | auto | LM model (e.g., `acestep-5Hz-lm-0.6B`, `acestep-5Hz-lm-1.7B`) |
|
| 130 |
+
| `--offload_to_cpu` | auto | CPU offload (auto-enabled if VRAM < 16GB) |
|
| 131 |
+
|
| 132 |
+
**Examples:**
|
| 133 |
+
|
| 134 |
+
```bash
|
| 135 |
+
# Public access with Chinese UI
|
| 136 |
+
uv run acestep --server-name 0.0.0.0 --share --language zh
|
| 137 |
+
|
| 138 |
+
# Pre-initialize models on startup
|
| 139 |
+
uv run acestep --init_service true --config_path acestep-v15-turbo
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
### Development
|
| 143 |
+
|
| 144 |
+
```bash
|
| 145 |
+
# Add dependencies
|
| 146 |
+
uv add package-name
|
| 147 |
+
uv add --dev package-name
|
| 148 |
+
|
| 149 |
+
# Update all dependencies
|
| 150 |
+
uv sync --upgrade
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
## 🚀 Usage
|
| 154 |
+
|
| 155 |
+
We provide multiple ways to use ACE-Step:
|
| 156 |
+
|
| 157 |
+
| Method | Description | Documentation |
|
| 158 |
+
|--------|-------------|---------------|
|
| 159 |
+
| 🖥️ **Gradio Web UI** | Interactive web interface for music generation | [Gradio Guide](./docs/en/GRADIO_GUIDE.md) |
|
| 160 |
+
| 🐍 **Python API** | Programmatic access for integration | [Inference API](./docs/en/INFERENCE.md) |
|
| 161 |
+
| 🌐 **REST API** | HTTP-based async API for services | [REST API](./docs/en/API.md) |
|
| 162 |
+
|
| 163 |
+
**📚 Documentation available in:** [English](./docs/en/) | [中文](./docs/zh/) | [日本語](./docs/ja/)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
## 🔨 Train
|
| 167 |
+
|
| 168 |
+
See the **LoRA Training** tab in Gradio UI for one-click training, or check [Gradio Guide - LoRA Training](./docs/en/GRADIO_GUIDE.md#lora-training) for details.
|
| 169 |
+
|
| 170 |
+
## 🏗️ Architecture
|
| 171 |
+
|
| 172 |
+
<p align="center">
|
| 173 |
+
<img src="./assets/ACE-Step_framework.png" width="100%" alt="ACE-Step Framework">
|
| 174 |
+
</p>
|
| 175 |
+
|
| 176 |
+
## 🦁 Model Zoo
|
| 177 |
+
|
| 178 |
+
<p align="center">
|
| 179 |
+
<img src="./assets/model_zoo.png" width="100%" alt="Model Zoo">
|
| 180 |
+
</p>
|
| 181 |
+
|
| 182 |
+
### DiT Models
|
| 183 |
+
|
| 184 |
+
| DiT Model | Pre-Training | SFT | RL | CFG | Step | Refer audio | Text2Music | Cover | Repaint | Extract | Lego | Complete | Quality | Diversity | Fine-Tunability | Hugging Face |
|
| 185 |
+
|-----------|:------------:|:---:|:--:|:---:|:----:|:-----------:|:----------:|:-----:|:-------:|:-------:|:----:|:--------:|:-------:|:---------:|:---------------:|--------------|
|
| 186 |
+
| `acestep-v15-base` | ✅ | ❌ | ❌ | ✅ | 50 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | Medium | High | Easy | [Link](https://huggingface.co/ACE-Step/acestep-v15-base) |
|
| 187 |
+
| `acestep-v15-sft` | ✅ | ✅ | ❌ | ✅ | 50 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | High | Medium | Easy | [Link](https://huggingface.co/ACE-Step/acestep-v15-sft) |
|
| 188 |
+
| `acestep-v15-turbo` | ✅ | ✅ | ❌ | ❌ | 8 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | Very High | Medium | Medium | [Link](https://huggingface.co/ACE-Step/Ace-Step1.5) |
|
| 189 |
+
| `acestep-v15-turbo-rl` | ✅ | ✅ | ✅ | ❌ | 8 | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | Very High | Medium | Medium | To be released |
|
| 190 |
+
|
| 191 |
+
### LM Models
|
| 192 |
+
|
| 193 |
+
| LM Model | Pretrain from | Pre-Training | SFT | RL | CoT metas | Query rewrite | Audio Understanding | Composition Capability | Copy Melody | Hugging Face |
|
| 194 |
+
|----------|---------------|:------------:|:---:|:--:|:---------:|:-------------:|:-------------------:|:----------------------:|:-----------:|--------------|
|
| 195 |
+
| `acestep-5Hz-lm-0.6B` | Qwen3-0.6B | ✅ | ✅ | ✅ | ✅ | ✅ | Medium | Medium | Weak | ✅ |
|
| 196 |
+
| `acestep-5Hz-lm-1.7B` | Qwen3-1.7B | ✅ | ✅ | ✅ | ✅ | ✅ | Medium | Medium | Medium | ✅ |
|
| 197 |
+
| `acestep-5Hz-lm-4B` | Qwen3-4B | ✅ | ✅ | ✅ | ✅ | ✅ | Strong | Strong | Strong | To be released |
|
| 198 |
+
|
| 199 |
+
## 📜 License & Disclaimer
|
| 200 |
+
|
| 201 |
+
This project is licensed under [MIT](./LICENSE)
|
| 202 |
+
|
| 203 |
+
ACE-Step enables original music generation across diverse genres, with applications in creative production, education, and entertainment. While designed to support positive and artistic use cases, we acknowledge potential risks such as unintentional copyright infringement due to stylistic similarity, inappropriate blending of cultural elements, and misuse for generating harmful content. To ensure responsible use, we encourage users to verify the originality of generated works, clearly disclose AI involvement, and obtain appropriate permissions when adapting protected styles or materials. By using ACE-Step, you agree to uphold these principles and respect artistic integrity, cultural diversity, and legal compliance. The authors are not responsible for any misuse of the model, including but not limited to copyright violations, cultural insensitivity, or the generation of harmful content.
|
| 204 |
+
|
| 205 |
+
🔔 Important Notice
|
| 206 |
+
The only official website for the ACE-Step project is our GitHub Pages site.
|
| 207 |
+
We do not operate any other websites.
|
| 208 |
+
🚫 Fake domains include but are not limited to:
|
| 209 |
+
ac\*\*p.com, a\*\*p.org, a\*\*\*c.org
|
| 210 |
+
⚠️ Please be cautious. Do not visit, trust, or make payments on any of those sites.
|
| 211 |
+
|
| 212 |
+
## 🙏 Acknowledgements
|
| 213 |
+
|
| 214 |
+
This project is co-led by ACE Studio and StepFun.
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
## 📖 Citation
|
| 218 |
+
|
| 219 |
+
If you find this project useful for your research, please consider citing:
|
| 220 |
+
|
| 221 |
+
```BibTeX
|
| 222 |
+
@misc{gong2026acestep,
|
| 223 |
+
title={ACE-Step 1.5: Pushing the Boundaries of Open-Source Music Generation},
|
| 224 |
+
author={Junmin Gong, Song Yulin, Wenxiao Zhao, Sen Wang, Shengyuan Xu, Jing Guo},
|
| 225 |
+
howpublished={\url{https://github.com/ace-step/ACE-Step-1.5}},
|
| 226 |
+
year={2026},
|
| 227 |
+
note={GitHub repository}
|
| 228 |
+
}
|
| 229 |
+
```
|
spaces/Ace-Step-v1.5/acestep/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""ACE-Step package."""
|
spaces/Ace-Step-v1.5/acestep/acestep_v15_pipeline.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ACE-Step V1.5 Pipeline
|
| 3 |
+
Handler wrapper connecting model and UI
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
# Load environment variables from .env file in project root
|
| 9 |
+
# This allows configuration without hardcoding values
|
| 10 |
+
# Falls back to .env.example if .env is not found
|
| 11 |
+
try:
|
| 12 |
+
from dotenv import load_dotenv
|
| 13 |
+
# Get project root directory
|
| 14 |
+
_current_file = os.path.abspath(__file__)
|
| 15 |
+
_project_root = os.path.dirname(os.path.dirname(_current_file))
|
| 16 |
+
_env_path = os.path.join(_project_root, '.env')
|
| 17 |
+
_env_example_path = os.path.join(_project_root, '.env.example')
|
| 18 |
+
|
| 19 |
+
if os.path.exists(_env_path):
|
| 20 |
+
load_dotenv(_env_path)
|
| 21 |
+
print(f"Loaded configuration from {_env_path}")
|
| 22 |
+
elif os.path.exists(_env_example_path):
|
| 23 |
+
load_dotenv(_env_example_path)
|
| 24 |
+
print(f"Loaded configuration from {_env_example_path} (fallback)")
|
| 25 |
+
except ImportError:
|
| 26 |
+
# python-dotenv not installed, skip loading .env
|
| 27 |
+
pass
|
| 28 |
+
|
| 29 |
+
# Clear proxy settings that may affect Gradio
|
| 30 |
+
for proxy_var in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY']:
|
| 31 |
+
os.environ.pop(proxy_var, None)
|
| 32 |
+
|
| 33 |
+
try:
|
| 34 |
+
# When executed as a module: `python -m acestep.acestep_v15_pipeline`
|
| 35 |
+
from .handler import AceStepHandler
|
| 36 |
+
from .llm_inference import LLMHandler
|
| 37 |
+
from .dataset_handler import DatasetHandler
|
| 38 |
+
from .gradio_ui import create_gradio_interface
|
| 39 |
+
except ImportError:
|
| 40 |
+
# When executed as a script: `python acestep/acestep_v15_pipeline.py`
|
| 41 |
+
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 42 |
+
if project_root not in sys.path:
|
| 43 |
+
sys.path.insert(0, project_root)
|
| 44 |
+
from acestep.handler import AceStepHandler
|
| 45 |
+
from acestep.llm_inference import LLMHandler
|
| 46 |
+
from acestep.dataset_handler import DatasetHandler
|
| 47 |
+
from acestep.gradio_ui import create_gradio_interface
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def create_demo(init_params=None, language='en'):
|
| 51 |
+
"""
|
| 52 |
+
Create Gradio demo interface
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
init_params: Dictionary containing initialization parameters and state.
|
| 56 |
+
If None, service will not be pre-initialized.
|
| 57 |
+
Keys: 'pre_initialized' (bool), 'checkpoint', 'config_path', 'device',
|
| 58 |
+
'init_llm', 'lm_model_path', 'backend', 'use_flash_attention',
|
| 59 |
+
'offload_to_cpu', 'offload_dit_to_cpu', 'init_status',
|
| 60 |
+
'dit_handler', 'llm_handler' (initialized handlers if pre-initialized),
|
| 61 |
+
'language' (UI language code)
|
| 62 |
+
language: UI language code ('en', 'zh', 'ja', default: 'en')
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
Gradio Blocks instance
|
| 66 |
+
"""
|
| 67 |
+
# Get persistent storage path from init_params (for HuggingFace Space)
|
| 68 |
+
persistent_storage_path = None
|
| 69 |
+
if init_params:
|
| 70 |
+
persistent_storage_path = init_params.get('persistent_storage_path')
|
| 71 |
+
|
| 72 |
+
# Use pre-initialized handlers if available, otherwise create new ones
|
| 73 |
+
if init_params and init_params.get('pre_initialized') and 'dit_handler' in init_params:
|
| 74 |
+
dit_handler = init_params['dit_handler']
|
| 75 |
+
llm_handler = init_params['llm_handler']
|
| 76 |
+
else:
|
| 77 |
+
dit_handler = AceStepHandler(persistent_storage_path=persistent_storage_path)
|
| 78 |
+
llm_handler = LLMHandler(persistent_storage_path=persistent_storage_path)
|
| 79 |
+
|
| 80 |
+
dataset_handler = DatasetHandler() # Dataset handler
|
| 81 |
+
|
| 82 |
+
# Create Gradio interface with all handlers and initialization parameters
|
| 83 |
+
demo = create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=init_params, language=language)
|
| 84 |
+
|
| 85 |
+
return demo
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def get_gpu_memory_gb():
|
| 89 |
+
"""
|
| 90 |
+
Get GPU memory in GB. Returns 0 if no GPU is available.
|
| 91 |
+
"""
|
| 92 |
+
try:
|
| 93 |
+
import torch
|
| 94 |
+
if torch.cuda.is_available():
|
| 95 |
+
# Get total memory of the first GPU in GB
|
| 96 |
+
total_memory = torch.cuda.get_device_properties(0).total_memory
|
| 97 |
+
memory_gb = total_memory / (1024**3) # Convert bytes to GB
|
| 98 |
+
return memory_gb
|
| 99 |
+
else:
|
| 100 |
+
return 0
|
| 101 |
+
except Exception as e:
|
| 102 |
+
print(f"Warning: Failed to detect GPU memory: {e}", file=sys.stderr)
|
| 103 |
+
return 0
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def main():
|
| 107 |
+
"""Main entry function"""
|
| 108 |
+
import argparse
|
| 109 |
+
|
| 110 |
+
# Detect GPU memory to auto-configure offload settings
|
| 111 |
+
gpu_memory_gb = get_gpu_memory_gb()
|
| 112 |
+
auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < 16
|
| 113 |
+
|
| 114 |
+
if auto_offload:
|
| 115 |
+
print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (< 16GB)")
|
| 116 |
+
print("Auto-enabling CPU offload to reduce GPU memory usage")
|
| 117 |
+
elif gpu_memory_gb > 0:
|
| 118 |
+
print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (>= 16GB)")
|
| 119 |
+
print("CPU offload disabled by default")
|
| 120 |
+
else:
|
| 121 |
+
print("No GPU detected, running on CPU")
|
| 122 |
+
|
| 123 |
+
parser = argparse.ArgumentParser(description="Gradio Demo for ACE-Step V1.5")
|
| 124 |
+
parser.add_argument("--port", type=int, default=7860, help="Port to run the gradio server on")
|
| 125 |
+
parser.add_argument("--share", action="store_true", help="Create a public link")
|
| 126 |
+
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
|
| 127 |
+
parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Server name (default: 127.0.0.1, use 0.0.0.0 for all interfaces)")
|
| 128 |
+
parser.add_argument("--language", type=str, default="en", choices=["en", "zh", "ja"], help="UI language: en (English), zh (中文), ja (日本語)")
|
| 129 |
+
|
| 130 |
+
# Service mode argument
|
| 131 |
+
parser.add_argument("--service_mode", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False,
|
| 132 |
+
help="Enable service mode (default: False). When enabled, uses preset models and restricts UI options.")
|
| 133 |
+
|
| 134 |
+
# Service initialization arguments
|
| 135 |
+
parser.add_argument("--init_service", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Initialize service on startup (default: False)")
|
| 136 |
+
parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint file path (optional, for display purposes)")
|
| 137 |
+
parser.add_argument("--config_path", type=str, default=None, help="Main model path (e.g., 'acestep-v15-turbo')")
|
| 138 |
+
parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"], help="Processing device (default: auto)")
|
| 139 |
+
parser.add_argument("--init_llm", type=lambda x: x.lower() in ['true', '1', 'yes'], default=True, help="Initialize 5Hz LM (default: True)")
|
| 140 |
+
parser.add_argument("--lm_model_path", type=str, default=None, help="5Hz LM model path (e.g., 'acestep-5Hz-lm-0.6B')")
|
| 141 |
+
parser.add_argument("--backend", type=str, default="vllm", choices=["vllm", "pt"], help="5Hz LM backend (default: vllm)")
|
| 142 |
+
parser.add_argument("--use_flash_attention", type=lambda x: x.lower() in ['true', '1', 'yes'], default=None, help="Use flash attention (default: auto-detect)")
|
| 143 |
+
parser.add_argument("--offload_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=auto_offload, help=f"Offload models to CPU (default: {'True' if auto_offload else 'False'}, auto-detected based on GPU VRAM)")
|
| 144 |
+
parser.add_argument("--offload_dit_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Offload DiT to CPU (default: False)")
|
| 145 |
+
|
| 146 |
+
args = parser.parse_args()
|
| 147 |
+
|
| 148 |
+
# Service mode defaults (can be configured via .env file)
|
| 149 |
+
if args.service_mode:
|
| 150 |
+
print("Service mode enabled - applying preset configurations...")
|
| 151 |
+
# Force init_service in service mode
|
| 152 |
+
args.init_service = True
|
| 153 |
+
# Default DiT model for service mode (from env or fallback)
|
| 154 |
+
if args.config_path is None:
|
| 155 |
+
args.config_path = os.environ.get(
|
| 156 |
+
"SERVICE_MODE_DIT_MODEL",
|
| 157 |
+
"acestep-v15-turbo-fix-inst-shift-dynamic"
|
| 158 |
+
)
|
| 159 |
+
# Default LM model for service mode (from env or fallback)
|
| 160 |
+
if args.lm_model_path is None:
|
| 161 |
+
args.lm_model_path = os.environ.get(
|
| 162 |
+
"SERVICE_MODE_LM_MODEL",
|
| 163 |
+
"acestep-5Hz-lm-1.7B-v4-fix"
|
| 164 |
+
)
|
| 165 |
+
# Backend for service mode (from env or fallback to vllm)
|
| 166 |
+
args.backend = os.environ.get("SERVICE_MODE_BACKEND", "vllm")
|
| 167 |
+
print(f" DiT model: {args.config_path}")
|
| 168 |
+
print(f" LM model: {args.lm_model_path}")
|
| 169 |
+
print(f" Backend: {args.backend}")
|
| 170 |
+
|
| 171 |
+
try:
|
| 172 |
+
init_params = None
|
| 173 |
+
|
| 174 |
+
# If init_service is True, perform initialization before creating UI
|
| 175 |
+
if args.init_service:
|
| 176 |
+
print("Initializing service from command line...")
|
| 177 |
+
|
| 178 |
+
# Create handler instances for initialization
|
| 179 |
+
dit_handler = AceStepHandler()
|
| 180 |
+
llm_handler = LLMHandler()
|
| 181 |
+
|
| 182 |
+
# Auto-select config_path if not provided
|
| 183 |
+
if args.config_path is None:
|
| 184 |
+
available_models = dit_handler.get_available_acestep_v15_models()
|
| 185 |
+
if available_models:
|
| 186 |
+
args.config_path = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else available_models[0]
|
| 187 |
+
print(f"Auto-selected config_path: {args.config_path}")
|
| 188 |
+
else:
|
| 189 |
+
print("Error: No available models found. Please specify --config_path", file=sys.stderr)
|
| 190 |
+
sys.exit(1)
|
| 191 |
+
|
| 192 |
+
# Get project root (same logic as in handler)
|
| 193 |
+
current_file = os.path.abspath(__file__)
|
| 194 |
+
project_root = os.path.dirname(os.path.dirname(current_file))
|
| 195 |
+
|
| 196 |
+
# Determine flash attention setting
|
| 197 |
+
use_flash_attention = args.use_flash_attention
|
| 198 |
+
if use_flash_attention is None:
|
| 199 |
+
use_flash_attention = dit_handler.is_flash_attention_available()
|
| 200 |
+
|
| 201 |
+
# Initialize DiT handler
|
| 202 |
+
print(f"Initializing DiT model: {args.config_path} on {args.device}...")
|
| 203 |
+
init_status, enable_generate = dit_handler.initialize_service(
|
| 204 |
+
project_root=project_root,
|
| 205 |
+
config_path=args.config_path,
|
| 206 |
+
device=args.device,
|
| 207 |
+
use_flash_attention=use_flash_attention,
|
| 208 |
+
compile_model=False,
|
| 209 |
+
offload_to_cpu=args.offload_to_cpu,
|
| 210 |
+
offload_dit_to_cpu=args.offload_dit_to_cpu
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
if not enable_generate:
|
| 214 |
+
print(f"Error initializing DiT model: {init_status}", file=sys.stderr)
|
| 215 |
+
sys.exit(1)
|
| 216 |
+
|
| 217 |
+
print(f"DiT model initialized successfully")
|
| 218 |
+
|
| 219 |
+
# Initialize LM handler if requested
|
| 220 |
+
lm_status = ""
|
| 221 |
+
if args.init_llm:
|
| 222 |
+
if args.lm_model_path is None:
|
| 223 |
+
# Try to get default LM model
|
| 224 |
+
available_lm_models = llm_handler.get_available_5hz_lm_models()
|
| 225 |
+
if available_lm_models:
|
| 226 |
+
args.lm_model_path = available_lm_models[0]
|
| 227 |
+
print(f"Using default LM model: {args.lm_model_path}")
|
| 228 |
+
else:
|
| 229 |
+
print("Warning: No LM models available, skipping LM initialization", file=sys.stderr)
|
| 230 |
+
args.init_llm = False
|
| 231 |
+
|
| 232 |
+
if args.init_llm and args.lm_model_path:
|
| 233 |
+
checkpoint_dir = os.path.join(project_root, "checkpoints")
|
| 234 |
+
print(f"Initializing 5Hz LM: {args.lm_model_path} on {args.device}...")
|
| 235 |
+
lm_status, lm_success = llm_handler.initialize(
|
| 236 |
+
checkpoint_dir=checkpoint_dir,
|
| 237 |
+
lm_model_path=args.lm_model_path,
|
| 238 |
+
backend=args.backend,
|
| 239 |
+
device=args.device,
|
| 240 |
+
offload_to_cpu=args.offload_to_cpu,
|
| 241 |
+
dtype=dit_handler.dtype
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
if lm_success:
|
| 245 |
+
print(f"5Hz LM initialized successfully")
|
| 246 |
+
init_status += f"\n{lm_status}"
|
| 247 |
+
else:
|
| 248 |
+
print(f"Warning: 5Hz LM initialization failed: {lm_status}", file=sys.stderr)
|
| 249 |
+
init_status += f"\n{lm_status}"
|
| 250 |
+
|
| 251 |
+
# Prepare initialization parameters for UI
|
| 252 |
+
init_params = {
|
| 253 |
+
'pre_initialized': True,
|
| 254 |
+
'service_mode': args.service_mode,
|
| 255 |
+
'checkpoint': args.checkpoint,
|
| 256 |
+
'config_path': args.config_path,
|
| 257 |
+
'device': args.device,
|
| 258 |
+
'init_llm': args.init_llm,
|
| 259 |
+
'lm_model_path': args.lm_model_path,
|
| 260 |
+
'backend': args.backend,
|
| 261 |
+
'use_flash_attention': use_flash_attention,
|
| 262 |
+
'offload_to_cpu': args.offload_to_cpu,
|
| 263 |
+
'offload_dit_to_cpu': args.offload_dit_to_cpu,
|
| 264 |
+
'init_status': init_status,
|
| 265 |
+
'enable_generate': enable_generate,
|
| 266 |
+
'dit_handler': dit_handler,
|
| 267 |
+
'llm_handler': llm_handler,
|
| 268 |
+
'language': args.language
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
print("Service initialization completed successfully!")
|
| 272 |
+
|
| 273 |
+
# Create and launch demo
|
| 274 |
+
print(f"Creating Gradio interface with language: {args.language}...")
|
| 275 |
+
demo = create_demo(init_params=init_params, language=args.language)
|
| 276 |
+
|
| 277 |
+
# Enable queue for multi-user support
|
| 278 |
+
# This ensures proper request queuing and prevents concurrent generation conflicts
|
| 279 |
+
print("Enabling queue for multi-user support...")
|
| 280 |
+
demo.queue(
|
| 281 |
+
max_size=20, # Maximum queue size (adjust based on your needs)
|
| 282 |
+
status_update_rate="auto", # Update rate for queue status
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
print(f"Launching server on {args.server_name}:{args.port}...")
|
| 286 |
+
demo.launch(
|
| 287 |
+
server_name=args.server_name,
|
| 288 |
+
server_port=args.port,
|
| 289 |
+
share=args.share,
|
| 290 |
+
debug=args.debug,
|
| 291 |
+
show_error=True,
|
| 292 |
+
prevent_thread_lock=False, # Keep thread locked to maintain server running
|
| 293 |
+
inbrowser=False, # Don't auto-open browser
|
| 294 |
+
)
|
| 295 |
+
except Exception as e:
|
| 296 |
+
print(f"Error launching Gradio: {e}", file=sys.stderr)
|
| 297 |
+
import traceback
|
| 298 |
+
traceback.print_exc()
|
| 299 |
+
sys.exit(1)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
if __name__ == "__main__":
|
| 303 |
+
main()
|
spaces/Ace-Step-v1.5/acestep/api_server.py
ADDED
|
@@ -0,0 +1,1700 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI server for ACE-Step V1.5.
|
| 2 |
+
|
| 3 |
+
Endpoints:
|
| 4 |
+
- POST /release_task Create music generation task
|
| 5 |
+
- POST /query_result Batch query task results
|
| 6 |
+
- POST /v1/music/random Create random sample task
|
| 7 |
+
- GET /v1/models List available models
|
| 8 |
+
- GET /v1/audio Download audio file
|
| 9 |
+
- GET /health Health check
|
| 10 |
+
|
| 11 |
+
NOTE:
|
| 12 |
+
- In-memory queue and job store -> run uvicorn with workers=1.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import asyncio
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
import time
|
| 22 |
+
import traceback
|
| 23 |
+
import tempfile
|
| 24 |
+
import urllib.parse
|
| 25 |
+
from collections import deque
|
| 26 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 27 |
+
from contextlib import asynccontextmanager
|
| 28 |
+
from dataclasses import dataclass
|
| 29 |
+
from pathlib import Path
|
| 30 |
+
from threading import Lock
|
| 31 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 32 |
+
from uuid import uuid4
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
from dotenv import load_dotenv
|
| 36 |
+
except ImportError: # Optional dependency
|
| 37 |
+
load_dotenv = None # type: ignore
|
| 38 |
+
|
| 39 |
+
from fastapi import FastAPI, HTTPException, Request
|
| 40 |
+
from pydantic import BaseModel, Field
|
| 41 |
+
from starlette.datastructures import UploadFile as StarletteUploadFile
|
| 42 |
+
|
| 43 |
+
from acestep.handler import AceStepHandler
|
| 44 |
+
from acestep.llm_inference import LLMHandler
|
| 45 |
+
from acestep.constants import (
|
| 46 |
+
DEFAULT_DIT_INSTRUCTION,
|
| 47 |
+
DEFAULT_LM_INSTRUCTION,
|
| 48 |
+
TASK_INSTRUCTIONS,
|
| 49 |
+
)
|
| 50 |
+
from acestep.inference import (
|
| 51 |
+
GenerationParams,
|
| 52 |
+
GenerationConfig,
|
| 53 |
+
generate_music,
|
| 54 |
+
create_sample,
|
| 55 |
+
format_sample,
|
| 56 |
+
)
|
| 57 |
+
from acestep.gradio_ui.events.results_handlers import _build_generation_info
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# =============================================================================
|
| 61 |
+
# Constants
|
| 62 |
+
# =============================================================================
|
| 63 |
+
|
| 64 |
+
RESULT_KEY_PREFIX = "ace_step_v1.5_"
|
| 65 |
+
RESULT_EXPIRE_SECONDS = 7 * 24 * 60 * 60 # 7 days
|
| 66 |
+
TASK_TIMEOUT_SECONDS = 3600 # 1 hour
|
| 67 |
+
STATUS_MAP = {"queued": 0, "running": 0, "succeeded": 1, "failed": 2}
|
| 68 |
+
|
| 69 |
+
LM_DEFAULT_TEMPERATURE = 0.85
|
| 70 |
+
LM_DEFAULT_CFG_SCALE = 2.5
|
| 71 |
+
LM_DEFAULT_TOP_P = 0.9
|
| 72 |
+
|
| 73 |
+
# Parameter aliases for request parsing
|
| 74 |
+
PARAM_ALIASES = {
|
| 75 |
+
"prompt": ["prompt"],
|
| 76 |
+
"sample_mode": ["sample_mode", "sampleMode"],
|
| 77 |
+
"sample_query": ["sample_query", "sampleQuery", "description", "desc"],
|
| 78 |
+
"use_format": ["use_format", "useFormat", "format"],
|
| 79 |
+
"model": ["model", "dit_model", "ditModel"],
|
| 80 |
+
"key_scale": ["key_scale", "keyscale", "keyScale"],
|
| 81 |
+
"time_signature": ["time_signature", "timesignature", "timeSignature"],
|
| 82 |
+
"audio_duration": ["audio_duration", "duration", "audioDuration", "target_duration", "targetDuration"],
|
| 83 |
+
"vocal_language": ["vocal_language", "vocalLanguage"],
|
| 84 |
+
"inference_steps": ["inference_steps", "inferenceSteps"],
|
| 85 |
+
"guidance_scale": ["guidance_scale", "guidanceScale"],
|
| 86 |
+
"use_random_seed": ["use_random_seed", "useRandomSeed"],
|
| 87 |
+
"audio_code_string": ["audio_code_string", "audioCodeString"],
|
| 88 |
+
"audio_cover_strength": ["audio_cover_strength", "audioCoverStrength"],
|
| 89 |
+
"task_type": ["task_type", "taskType"],
|
| 90 |
+
"infer_method": ["infer_method", "inferMethod"],
|
| 91 |
+
"use_tiled_decode": ["use_tiled_decode", "useTiledDecode"],
|
| 92 |
+
"constrained_decoding": ["constrained_decoding", "constrainedDecoding", "constrained"],
|
| 93 |
+
"constrained_decoding_debug": ["constrained_decoding_debug", "constrainedDecodingDebug"],
|
| 94 |
+
"use_cot_caption": ["use_cot_caption", "cot_caption", "cot-caption"],
|
| 95 |
+
"use_cot_language": ["use_cot_language", "cot_language", "cot-language"],
|
| 96 |
+
"is_format_caption": ["is_format_caption", "isFormatCaption"],
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _parse_description_hints(description: str) -> tuple[Optional[str], bool]:
|
| 101 |
+
"""
|
| 102 |
+
Parse a description string to extract language code and instrumental flag.
|
| 103 |
+
|
| 104 |
+
This function analyzes user descriptions like "Pop rock. English" or "piano solo"
|
| 105 |
+
to detect:
|
| 106 |
+
- Language: Maps language names to ISO codes (e.g., "English" -> "en")
|
| 107 |
+
- Instrumental: Detects patterns indicating instrumental/no-vocal music
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
description: User's natural language music description
|
| 111 |
+
|
| 112 |
+
Returns:
|
| 113 |
+
(language_code, is_instrumental) tuple:
|
| 114 |
+
- language_code: ISO language code (e.g., "en", "zh") or None if not detected
|
| 115 |
+
- is_instrumental: True if description indicates instrumental music
|
| 116 |
+
"""
|
| 117 |
+
import re
|
| 118 |
+
|
| 119 |
+
if not description:
|
| 120 |
+
return None, False
|
| 121 |
+
|
| 122 |
+
description_lower = description.lower().strip()
|
| 123 |
+
|
| 124 |
+
# Language mapping: input patterns -> ISO code
|
| 125 |
+
language_mapping = {
|
| 126 |
+
'english': 'en', 'en': 'en',
|
| 127 |
+
'chinese': 'zh', '中文': 'zh', 'zh': 'zh', 'mandarin': 'zh',
|
| 128 |
+
'japanese': 'ja', '日本語': 'ja', 'ja': 'ja',
|
| 129 |
+
'korean': 'ko', '한국어': 'ko', 'ko': 'ko',
|
| 130 |
+
'spanish': 'es', 'español': 'es', 'es': 'es',
|
| 131 |
+
'french': 'fr', 'français': 'fr', 'fr': 'fr',
|
| 132 |
+
'german': 'de', 'deutsch': 'de', 'de': 'de',
|
| 133 |
+
'italian': 'it', 'italiano': 'it', 'it': 'it',
|
| 134 |
+
'portuguese': 'pt', 'português': 'pt', 'pt': 'pt',
|
| 135 |
+
'russian': 'ru', 'русский': 'ru', 'ru': 'ru',
|
| 136 |
+
'bengali': 'bn', 'bn': 'bn',
|
| 137 |
+
'hindi': 'hi', 'hi': 'hi',
|
| 138 |
+
'arabic': 'ar', 'ar': 'ar',
|
| 139 |
+
'thai': 'th', 'th': 'th',
|
| 140 |
+
'vietnamese': 'vi', 'vi': 'vi',
|
| 141 |
+
'indonesian': 'id', 'id': 'id',
|
| 142 |
+
'turkish': 'tr', 'tr': 'tr',
|
| 143 |
+
'dutch': 'nl', 'nl': 'nl',
|
| 144 |
+
'polish': 'pl', 'pl': 'pl',
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
# Detect language
|
| 148 |
+
detected_language = None
|
| 149 |
+
for lang_name, lang_code in language_mapping.items():
|
| 150 |
+
if len(lang_name) <= 2:
|
| 151 |
+
pattern = r'(?:^|\s|[.,;:!?])' + re.escape(lang_name) + r'(?:$|\s|[.,;:!?])'
|
| 152 |
+
else:
|
| 153 |
+
pattern = r'\b' + re.escape(lang_name) + r'\b'
|
| 154 |
+
|
| 155 |
+
if re.search(pattern, description_lower):
|
| 156 |
+
detected_language = lang_code
|
| 157 |
+
break
|
| 158 |
+
|
| 159 |
+
# Detect instrumental
|
| 160 |
+
is_instrumental = False
|
| 161 |
+
if 'instrumental' in description_lower:
|
| 162 |
+
is_instrumental = True
|
| 163 |
+
elif 'pure music' in description_lower or 'pure instrument' in description_lower:
|
| 164 |
+
is_instrumental = True
|
| 165 |
+
elif description_lower.endswith(' solo') or description_lower == 'solo':
|
| 166 |
+
is_instrumental = True
|
| 167 |
+
|
| 168 |
+
return detected_language, is_instrumental
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
JobStatus = Literal["queued", "running", "succeeded", "failed"]
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
class GenerateMusicRequest(BaseModel):
|
| 175 |
+
prompt: str = Field(default="", description="Text prompt describing the music")
|
| 176 |
+
lyrics: str = Field(default="", description="Lyric text")
|
| 177 |
+
|
| 178 |
+
# New API semantics:
|
| 179 |
+
# - thinking=True: use 5Hz LM to generate audio codes (lm-dit behavior)
|
| 180 |
+
# - thinking=False: do not use LM to generate codes (dit behavior)
|
| 181 |
+
# Regardless of thinking, if some metas are missing, server may use LM to fill them.
|
| 182 |
+
thinking: bool = False
|
| 183 |
+
# Sample-mode requests auto-generate caption/lyrics/metas via LM (no user prompt).
|
| 184 |
+
sample_mode: bool = False
|
| 185 |
+
# Description for sample mode: auto-generate caption/lyrics from description query
|
| 186 |
+
sample_query: str = Field(default="", description="Query/description for sample mode (use create_sample)")
|
| 187 |
+
# Whether to use format_sample() to enhance input caption/lyrics
|
| 188 |
+
use_format: bool = Field(default=False, description="Use format_sample() to enhance input (default: False)")
|
| 189 |
+
# Model name for multi-model support (select which DiT model to use)
|
| 190 |
+
model: Optional[str] = Field(default=None, description="Model name to use (e.g., 'acestep-v15-turbo')")
|
| 191 |
+
|
| 192 |
+
bpm: Optional[int] = None
|
| 193 |
+
# Accept common client keys via manual parsing (see RequestParser).
|
| 194 |
+
key_scale: str = ""
|
| 195 |
+
time_signature: str = ""
|
| 196 |
+
vocal_language: str = "en"
|
| 197 |
+
inference_steps: int = 8
|
| 198 |
+
guidance_scale: float = 7.0
|
| 199 |
+
use_random_seed: bool = True
|
| 200 |
+
seed: int = -1
|
| 201 |
+
|
| 202 |
+
reference_audio_path: Optional[str] = None
|
| 203 |
+
src_audio_path: Optional[str] = None
|
| 204 |
+
audio_duration: Optional[float] = None
|
| 205 |
+
batch_size: Optional[int] = None
|
| 206 |
+
|
| 207 |
+
audio_code_string: str = ""
|
| 208 |
+
|
| 209 |
+
repainting_start: float = 0.0
|
| 210 |
+
repainting_end: Optional[float] = None
|
| 211 |
+
|
| 212 |
+
instruction: str = DEFAULT_DIT_INSTRUCTION
|
| 213 |
+
audio_cover_strength: float = 1.0
|
| 214 |
+
task_type: str = "text2music"
|
| 215 |
+
|
| 216 |
+
use_adg: bool = False
|
| 217 |
+
cfg_interval_start: float = 0.0
|
| 218 |
+
cfg_interval_end: float = 1.0
|
| 219 |
+
infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
|
| 220 |
+
shift: float = Field(
|
| 221 |
+
default=3.0,
|
| 222 |
+
description="Timestep shift factor (range 1.0~5.0, default 3.0). Only effective for base models, not turbo models."
|
| 223 |
+
)
|
| 224 |
+
timesteps: Optional[str] = Field(
|
| 225 |
+
default=None,
|
| 226 |
+
description="Custom timesteps (comma-separated, e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference_steps and shift."
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
audio_format: str = "mp3"
|
| 230 |
+
use_tiled_decode: bool = True
|
| 231 |
+
|
| 232 |
+
# 5Hz LM (server-side): used for metadata completion and (when thinking=True) codes generation.
|
| 233 |
+
lm_model_path: Optional[str] = None # e.g. "acestep-5Hz-lm-0.6B"
|
| 234 |
+
lm_backend: Literal["vllm", "pt"] = "vllm"
|
| 235 |
+
|
| 236 |
+
constrained_decoding: bool = True
|
| 237 |
+
constrained_decoding_debug: bool = False
|
| 238 |
+
use_cot_caption: bool = True
|
| 239 |
+
use_cot_language: bool = True
|
| 240 |
+
is_format_caption: bool = False
|
| 241 |
+
|
| 242 |
+
lm_temperature: float = 0.85
|
| 243 |
+
lm_cfg_scale: float = 2.5
|
| 244 |
+
lm_top_k: Optional[int] = None
|
| 245 |
+
lm_top_p: Optional[float] = 0.9
|
| 246 |
+
lm_repetition_penalty: float = 1.0
|
| 247 |
+
lm_negative_prompt: str = "NO USER INPUT"
|
| 248 |
+
|
| 249 |
+
class Config:
|
| 250 |
+
allow_population_by_field_name = True
|
| 251 |
+
allow_population_by_alias = True
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class CreateJobResponse(BaseModel):
|
| 255 |
+
task_id: str
|
| 256 |
+
status: JobStatus
|
| 257 |
+
queue_position: int = 0 # 1-based best-effort position when queued
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class JobResult(BaseModel):
|
| 261 |
+
first_audio_path: Optional[str] = None
|
| 262 |
+
second_audio_path: Optional[str] = None
|
| 263 |
+
audio_paths: list[str] = Field(default_factory=list)
|
| 264 |
+
|
| 265 |
+
generation_info: str = ""
|
| 266 |
+
status_message: str = ""
|
| 267 |
+
seed_value: str = ""
|
| 268 |
+
|
| 269 |
+
metas: Dict[str, Any] = Field(default_factory=dict)
|
| 270 |
+
bpm: Optional[int] = None
|
| 271 |
+
duration: Optional[float] = None
|
| 272 |
+
genres: Optional[str] = None
|
| 273 |
+
keyscale: Optional[str] = None
|
| 274 |
+
timesignature: Optional[str] = None
|
| 275 |
+
|
| 276 |
+
# Model information
|
| 277 |
+
lm_model: Optional[str] = None
|
| 278 |
+
dit_model: Optional[str] = None
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class JobResponse(BaseModel):
|
| 282 |
+
job_id: str
|
| 283 |
+
status: JobStatus
|
| 284 |
+
created_at: float
|
| 285 |
+
started_at: Optional[float] = None
|
| 286 |
+
finished_at: Optional[float] = None
|
| 287 |
+
|
| 288 |
+
# queue observability
|
| 289 |
+
queue_position: int = 0
|
| 290 |
+
eta_seconds: Optional[float] = None
|
| 291 |
+
avg_job_seconds: Optional[float] = None
|
| 292 |
+
|
| 293 |
+
result: Optional[JobResult] = None
|
| 294 |
+
error: Optional[str] = None
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
@dataclass
|
| 298 |
+
class _JobRecord:
|
| 299 |
+
job_id: str
|
| 300 |
+
status: JobStatus
|
| 301 |
+
created_at: float
|
| 302 |
+
started_at: Optional[float] = None
|
| 303 |
+
finished_at: Optional[float] = None
|
| 304 |
+
result: Optional[Dict[str, Any]] = None
|
| 305 |
+
error: Optional[str] = None
|
| 306 |
+
env: str = "development"
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class _JobStore:
|
| 310 |
+
def __init__(self) -> None:
|
| 311 |
+
self._lock = Lock()
|
| 312 |
+
self._jobs: Dict[str, _JobRecord] = {}
|
| 313 |
+
|
| 314 |
+
def create(self) -> _JobRecord:
|
| 315 |
+
job_id = str(uuid4())
|
| 316 |
+
rec = _JobRecord(job_id=job_id, status="queued", created_at=time.time())
|
| 317 |
+
with self._lock:
|
| 318 |
+
self._jobs[job_id] = rec
|
| 319 |
+
return rec
|
| 320 |
+
|
| 321 |
+
def create_with_id(self, job_id: str, env: str = "development") -> _JobRecord:
|
| 322 |
+
"""Create job record with specified ID"""
|
| 323 |
+
rec = _JobRecord(
|
| 324 |
+
job_id=job_id,
|
| 325 |
+
status="queued",
|
| 326 |
+
created_at=time.time(),
|
| 327 |
+
env=env
|
| 328 |
+
)
|
| 329 |
+
with self._lock:
|
| 330 |
+
self._jobs[job_id] = rec
|
| 331 |
+
return rec
|
| 332 |
+
|
| 333 |
+
def get(self, job_id: str) -> Optional[_JobRecord]:
|
| 334 |
+
with self._lock:
|
| 335 |
+
return self._jobs.get(job_id)
|
| 336 |
+
|
| 337 |
+
def mark_running(self, job_id: str) -> None:
|
| 338 |
+
with self._lock:
|
| 339 |
+
rec = self._jobs[job_id]
|
| 340 |
+
rec.status = "running"
|
| 341 |
+
rec.started_at = time.time()
|
| 342 |
+
|
| 343 |
+
def mark_succeeded(self, job_id: str, result: Dict[str, Any]) -> None:
|
| 344 |
+
with self._lock:
|
| 345 |
+
rec = self._jobs[job_id]
|
| 346 |
+
rec.status = "succeeded"
|
| 347 |
+
rec.finished_at = time.time()
|
| 348 |
+
rec.result = result
|
| 349 |
+
rec.error = None
|
| 350 |
+
|
| 351 |
+
def mark_failed(self, job_id: str, error: str) -> None:
|
| 352 |
+
with self._lock:
|
| 353 |
+
rec = self._jobs[job_id]
|
| 354 |
+
rec.status = "failed"
|
| 355 |
+
rec.finished_at = time.time()
|
| 356 |
+
rec.result = None
|
| 357 |
+
rec.error = error
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def _env_bool(name: str, default: bool) -> bool:
|
| 361 |
+
v = os.getenv(name)
|
| 362 |
+
if v is None:
|
| 363 |
+
return default
|
| 364 |
+
return v.strip().lower() in {"1", "true", "yes", "y", "on"}
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def _get_project_root() -> str:
|
| 368 |
+
current_file = os.path.abspath(__file__)
|
| 369 |
+
return os.path.dirname(os.path.dirname(current_file))
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def _get_model_name(config_path: str) -> str:
|
| 373 |
+
"""
|
| 374 |
+
Extract model name from config_path.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
config_path: Path like "acestep-v15-turbo" or "/path/to/acestep-v15-turbo"
|
| 378 |
+
|
| 379 |
+
Returns:
|
| 380 |
+
Model name (last directory name from config_path)
|
| 381 |
+
"""
|
| 382 |
+
if not config_path:
|
| 383 |
+
return ""
|
| 384 |
+
normalized = config_path.rstrip("/\\")
|
| 385 |
+
return os.path.basename(normalized)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def _load_project_env() -> None:
|
| 389 |
+
if load_dotenv is None:
|
| 390 |
+
return
|
| 391 |
+
try:
|
| 392 |
+
project_root = _get_project_root()
|
| 393 |
+
env_path = os.path.join(project_root, ".env")
|
| 394 |
+
if os.path.exists(env_path):
|
| 395 |
+
load_dotenv(env_path, override=False)
|
| 396 |
+
except Exception:
|
| 397 |
+
# Optional best-effort: continue even if .env loading fails.
|
| 398 |
+
pass
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
_load_project_env()
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def _to_int(v: Any, default: Optional[int] = None) -> Optional[int]:
|
| 405 |
+
if v is None:
|
| 406 |
+
return default
|
| 407 |
+
if isinstance(v, int):
|
| 408 |
+
return v
|
| 409 |
+
s = str(v).strip()
|
| 410 |
+
if s == "":
|
| 411 |
+
return default
|
| 412 |
+
try:
|
| 413 |
+
return int(s)
|
| 414 |
+
except Exception:
|
| 415 |
+
return default
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def _to_float(v: Any, default: Optional[float] = None) -> Optional[float]:
|
| 419 |
+
if v is None:
|
| 420 |
+
return default
|
| 421 |
+
if isinstance(v, float):
|
| 422 |
+
return v
|
| 423 |
+
s = str(v).strip()
|
| 424 |
+
if s == "":
|
| 425 |
+
return default
|
| 426 |
+
try:
|
| 427 |
+
return float(s)
|
| 428 |
+
except Exception:
|
| 429 |
+
return default
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def _to_bool(v: Any, default: bool = False) -> bool:
|
| 433 |
+
if v is None:
|
| 434 |
+
return default
|
| 435 |
+
if isinstance(v, bool):
|
| 436 |
+
return v
|
| 437 |
+
s = str(v).strip().lower()
|
| 438 |
+
if s == "":
|
| 439 |
+
return default
|
| 440 |
+
return s in {"1", "true", "yes", "y", "on"}
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def _map_status(status: str) -> int:
|
| 444 |
+
"""Map job status string to integer code."""
|
| 445 |
+
return STATUS_MAP.get(status, 2)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def _parse_timesteps(s: Optional[str]) -> Optional[List[float]]:
|
| 449 |
+
"""Parse comma-separated timesteps string to list of floats."""
|
| 450 |
+
if not s or not s.strip():
|
| 451 |
+
return None
|
| 452 |
+
try:
|
| 453 |
+
return [float(t.strip()) for t in s.split(",") if t.strip()]
|
| 454 |
+
except (ValueError, Exception):
|
| 455 |
+
return None
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
class RequestParser:
|
| 459 |
+
"""Parse request parameters from multiple sources with alias support."""
|
| 460 |
+
|
| 461 |
+
def __init__(self, raw: dict):
|
| 462 |
+
self._raw = dict(raw) if raw else {}
|
| 463 |
+
self._param_obj = self._parse_json(self._raw.get("param_obj"))
|
| 464 |
+
self._metas = self._find_metas()
|
| 465 |
+
|
| 466 |
+
def _parse_json(self, v) -> dict:
|
| 467 |
+
if isinstance(v, dict):
|
| 468 |
+
return v
|
| 469 |
+
if isinstance(v, str) and v.strip():
|
| 470 |
+
try:
|
| 471 |
+
return json.loads(v)
|
| 472 |
+
except Exception:
|
| 473 |
+
pass
|
| 474 |
+
return {}
|
| 475 |
+
|
| 476 |
+
def _find_metas(self) -> dict:
|
| 477 |
+
for key in ("metas", "meta", "metadata", "user_metadata", "userMetadata"):
|
| 478 |
+
v = self._raw.get(key)
|
| 479 |
+
if v:
|
| 480 |
+
return self._parse_json(v)
|
| 481 |
+
return {}
|
| 482 |
+
|
| 483 |
+
def get(self, name: str, default=None):
|
| 484 |
+
"""Get parameter by canonical name from all sources."""
|
| 485 |
+
aliases = PARAM_ALIASES.get(name, [name])
|
| 486 |
+
for source in (self._raw, self._param_obj, self._metas):
|
| 487 |
+
for alias in aliases:
|
| 488 |
+
v = source.get(alias)
|
| 489 |
+
if v is not None:
|
| 490 |
+
return v
|
| 491 |
+
return default
|
| 492 |
+
|
| 493 |
+
def str(self, name: str, default: str = "") -> str:
|
| 494 |
+
v = self.get(name)
|
| 495 |
+
return str(v) if v is not None else default
|
| 496 |
+
|
| 497 |
+
def int(self, name: str, default: Optional[int] = None) -> Optional[int]:
|
| 498 |
+
return _to_int(self.get(name), default)
|
| 499 |
+
|
| 500 |
+
def float(self, name: str, default: Optional[float] = None) -> Optional[float]:
|
| 501 |
+
return _to_float(self.get(name), default)
|
| 502 |
+
|
| 503 |
+
def bool(self, name: str, default: bool = False) -> bool:
|
| 504 |
+
return _to_bool(self.get(name), default)
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
async def _save_upload_to_temp(upload: StarletteUploadFile, *, prefix: str) -> str:
|
| 508 |
+
suffix = Path(upload.filename or "").suffix
|
| 509 |
+
fd, path = tempfile.mkstemp(prefix=f"{prefix}_", suffix=suffix)
|
| 510 |
+
os.close(fd)
|
| 511 |
+
try:
|
| 512 |
+
with open(path, "wb") as f:
|
| 513 |
+
while True:
|
| 514 |
+
chunk = await upload.read(1024 * 1024)
|
| 515 |
+
if not chunk:
|
| 516 |
+
break
|
| 517 |
+
f.write(chunk)
|
| 518 |
+
except Exception:
|
| 519 |
+
try:
|
| 520 |
+
os.remove(path)
|
| 521 |
+
except Exception:
|
| 522 |
+
pass
|
| 523 |
+
raise
|
| 524 |
+
finally:
|
| 525 |
+
try:
|
| 526 |
+
await upload.close()
|
| 527 |
+
except Exception:
|
| 528 |
+
pass
|
| 529 |
+
return path
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
def create_app() -> FastAPI:
|
| 533 |
+
store = _JobStore()
|
| 534 |
+
|
| 535 |
+
QUEUE_MAXSIZE = int(os.getenv("ACESTEP_QUEUE_MAXSIZE", "200"))
|
| 536 |
+
WORKER_COUNT = int(os.getenv("ACESTEP_QUEUE_WORKERS", "1")) # Single GPU recommended
|
| 537 |
+
|
| 538 |
+
INITIAL_AVG_JOB_SECONDS = float(os.getenv("ACESTEP_AVG_JOB_SECONDS", "5.0"))
|
| 539 |
+
AVG_WINDOW = int(os.getenv("ACESTEP_AVG_WINDOW", "50"))
|
| 540 |
+
|
| 541 |
+
def _path_to_audio_url(path: str) -> str:
|
| 542 |
+
"""Convert local file path to downloadable relative URL"""
|
| 543 |
+
if not path:
|
| 544 |
+
return path
|
| 545 |
+
if path.startswith("http://") or path.startswith("https://"):
|
| 546 |
+
return path
|
| 547 |
+
encoded_path = urllib.parse.quote(path, safe="")
|
| 548 |
+
return f"/v1/audio?path={encoded_path}"
|
| 549 |
+
|
| 550 |
+
@asynccontextmanager
|
| 551 |
+
async def lifespan(app: FastAPI):
|
| 552 |
+
# Clear proxy env that may affect downstream libs
|
| 553 |
+
for proxy_var in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]:
|
| 554 |
+
os.environ.pop(proxy_var, None)
|
| 555 |
+
|
| 556 |
+
# Ensure compilation/temp caches do not fill up small default /tmp.
|
| 557 |
+
# Triton/Inductor (and the system compiler) can create large temporary files.
|
| 558 |
+
project_root = _get_project_root()
|
| 559 |
+
cache_root = os.path.join(project_root, ".cache", "acestep")
|
| 560 |
+
tmp_root = (os.getenv("ACESTEP_TMPDIR") or os.path.join(cache_root, "tmp")).strip()
|
| 561 |
+
triton_cache_root = (os.getenv("TRITON_CACHE_DIR") or os.path.join(cache_root, "triton")).strip()
|
| 562 |
+
inductor_cache_root = (os.getenv("TORCHINDUCTOR_CACHE_DIR") or os.path.join(cache_root, "torchinductor")).strip()
|
| 563 |
+
|
| 564 |
+
for p in [cache_root, tmp_root, triton_cache_root, inductor_cache_root]:
|
| 565 |
+
try:
|
| 566 |
+
os.makedirs(p, exist_ok=True)
|
| 567 |
+
except Exception:
|
| 568 |
+
# Best-effort: do not block startup if directory creation fails.
|
| 569 |
+
pass
|
| 570 |
+
|
| 571 |
+
# Respect explicit user overrides; if ACESTEP_TMPDIR is set, it should win.
|
| 572 |
+
if os.getenv("ACESTEP_TMPDIR"):
|
| 573 |
+
os.environ["TMPDIR"] = tmp_root
|
| 574 |
+
os.environ["TEMP"] = tmp_root
|
| 575 |
+
os.environ["TMP"] = tmp_root
|
| 576 |
+
else:
|
| 577 |
+
os.environ.setdefault("TMPDIR", tmp_root)
|
| 578 |
+
os.environ.setdefault("TEMP", tmp_root)
|
| 579 |
+
os.environ.setdefault("TMP", tmp_root)
|
| 580 |
+
|
| 581 |
+
os.environ.setdefault("TRITON_CACHE_DIR", triton_cache_root)
|
| 582 |
+
os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", inductor_cache_root)
|
| 583 |
+
|
| 584 |
+
handler = AceStepHandler()
|
| 585 |
+
llm_handler = LLMHandler()
|
| 586 |
+
init_lock = asyncio.Lock()
|
| 587 |
+
app.state._initialized = False
|
| 588 |
+
app.state._init_error = None
|
| 589 |
+
app.state._init_lock = init_lock
|
| 590 |
+
|
| 591 |
+
app.state.llm_handler = llm_handler
|
| 592 |
+
app.state._llm_initialized = False
|
| 593 |
+
app.state._llm_init_error = None
|
| 594 |
+
app.state._llm_init_lock = Lock()
|
| 595 |
+
|
| 596 |
+
# Multi-model support: secondary DiT handlers
|
| 597 |
+
handler2 = None
|
| 598 |
+
handler3 = None
|
| 599 |
+
config_path2 = os.getenv("ACESTEP_CONFIG_PATH2", "").strip()
|
| 600 |
+
config_path3 = os.getenv("ACESTEP_CONFIG_PATH3", "").strip()
|
| 601 |
+
|
| 602 |
+
if config_path2:
|
| 603 |
+
handler2 = AceStepHandler()
|
| 604 |
+
if config_path3:
|
| 605 |
+
handler3 = AceStepHandler()
|
| 606 |
+
|
| 607 |
+
app.state.handler2 = handler2
|
| 608 |
+
app.state.handler3 = handler3
|
| 609 |
+
app.state._initialized2 = False
|
| 610 |
+
app.state._initialized3 = False
|
| 611 |
+
app.state._config_path = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo")
|
| 612 |
+
app.state._config_path2 = config_path2
|
| 613 |
+
app.state._config_path3 = config_path3
|
| 614 |
+
|
| 615 |
+
max_workers = int(os.getenv("ACESTEP_API_WORKERS", "1"))
|
| 616 |
+
executor = ThreadPoolExecutor(max_workers=max_workers)
|
| 617 |
+
|
| 618 |
+
# Queue & observability
|
| 619 |
+
app.state.job_queue = asyncio.Queue(maxsize=QUEUE_MAXSIZE) # (job_id, req)
|
| 620 |
+
app.state.pending_ids = deque() # queued job_ids
|
| 621 |
+
app.state.pending_lock = asyncio.Lock()
|
| 622 |
+
|
| 623 |
+
# temp files per job (from multipart uploads)
|
| 624 |
+
app.state.job_temp_files = {} # job_id -> list[path]
|
| 625 |
+
app.state.job_temp_files_lock = asyncio.Lock()
|
| 626 |
+
|
| 627 |
+
# stats
|
| 628 |
+
app.state.stats_lock = asyncio.Lock()
|
| 629 |
+
app.state.recent_durations = deque(maxlen=AVG_WINDOW)
|
| 630 |
+
app.state.avg_job_seconds = INITIAL_AVG_JOB_SECONDS
|
| 631 |
+
|
| 632 |
+
app.state.handler = handler
|
| 633 |
+
app.state.executor = executor
|
| 634 |
+
app.state.job_store = store
|
| 635 |
+
app.state._python_executable = sys.executable
|
| 636 |
+
|
| 637 |
+
# Temporary directory for saving generated audio files
|
| 638 |
+
app.state.temp_audio_dir = os.path.join(tmp_root, "api_audio")
|
| 639 |
+
os.makedirs(app.state.temp_audio_dir, exist_ok=True)
|
| 640 |
+
|
| 641 |
+
# Initialize local cache
|
| 642 |
+
try:
|
| 643 |
+
from acestep.local_cache import get_local_cache
|
| 644 |
+
local_cache_dir = os.path.join(cache_root, "local_redis")
|
| 645 |
+
app.state.local_cache = get_local_cache(local_cache_dir)
|
| 646 |
+
except ImportError:
|
| 647 |
+
app.state.local_cache = None
|
| 648 |
+
|
| 649 |
+
async def _ensure_initialized() -> None:
|
| 650 |
+
h: AceStepHandler = app.state.handler
|
| 651 |
+
|
| 652 |
+
if getattr(app.state, "_initialized", False):
|
| 653 |
+
return
|
| 654 |
+
if getattr(app.state, "_init_error", None):
|
| 655 |
+
raise RuntimeError(app.state._init_error)
|
| 656 |
+
|
| 657 |
+
async with app.state._init_lock:
|
| 658 |
+
if getattr(app.state, "_initialized", False):
|
| 659 |
+
return
|
| 660 |
+
if getattr(app.state, "_init_error", None):
|
| 661 |
+
raise RuntimeError(app.state._init_error)
|
| 662 |
+
|
| 663 |
+
project_root = _get_project_root()
|
| 664 |
+
config_path = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo")
|
| 665 |
+
device = os.getenv("ACESTEP_DEVICE", "auto")
|
| 666 |
+
|
| 667 |
+
use_flash_attention = _env_bool("ACESTEP_USE_FLASH_ATTENTION", True)
|
| 668 |
+
offload_to_cpu = _env_bool("ACESTEP_OFFLOAD_TO_CPU", False)
|
| 669 |
+
offload_dit_to_cpu = _env_bool("ACESTEP_OFFLOAD_DIT_TO_CPU", False)
|
| 670 |
+
|
| 671 |
+
# Initialize primary model
|
| 672 |
+
status_msg, ok = h.initialize_service(
|
| 673 |
+
project_root=project_root,
|
| 674 |
+
config_path=config_path,
|
| 675 |
+
device=device,
|
| 676 |
+
use_flash_attention=use_flash_attention,
|
| 677 |
+
compile_model=False,
|
| 678 |
+
offload_to_cpu=offload_to_cpu,
|
| 679 |
+
offload_dit_to_cpu=offload_dit_to_cpu,
|
| 680 |
+
)
|
| 681 |
+
if not ok:
|
| 682 |
+
app.state._init_error = status_msg
|
| 683 |
+
raise RuntimeError(status_msg)
|
| 684 |
+
app.state._initialized = True
|
| 685 |
+
|
| 686 |
+
# Initialize secondary model if configured
|
| 687 |
+
if app.state.handler2 and app.state._config_path2:
|
| 688 |
+
try:
|
| 689 |
+
status_msg2, ok2 = app.state.handler2.initialize_service(
|
| 690 |
+
project_root=project_root,
|
| 691 |
+
config_path=app.state._config_path2,
|
| 692 |
+
device=device,
|
| 693 |
+
use_flash_attention=use_flash_attention,
|
| 694 |
+
compile_model=False,
|
| 695 |
+
offload_to_cpu=offload_to_cpu,
|
| 696 |
+
offload_dit_to_cpu=offload_dit_to_cpu,
|
| 697 |
+
)
|
| 698 |
+
app.state._initialized2 = ok2
|
| 699 |
+
if ok2:
|
| 700 |
+
print(f"[API Server] Secondary model loaded: {_get_model_name(app.state._config_path2)}")
|
| 701 |
+
else:
|
| 702 |
+
print(f"[API Server] Warning: Secondary model failed to load: {status_msg2}")
|
| 703 |
+
except Exception as e:
|
| 704 |
+
print(f"[API Server] Warning: Failed to initialize secondary model: {e}")
|
| 705 |
+
app.state._initialized2 = False
|
| 706 |
+
|
| 707 |
+
# Initialize third model if configured
|
| 708 |
+
if app.state.handler3 and app.state._config_path3:
|
| 709 |
+
try:
|
| 710 |
+
status_msg3, ok3 = app.state.handler3.initialize_service(
|
| 711 |
+
project_root=project_root,
|
| 712 |
+
config_path=app.state._config_path3,
|
| 713 |
+
device=device,
|
| 714 |
+
use_flash_attention=use_flash_attention,
|
| 715 |
+
compile_model=False,
|
| 716 |
+
offload_to_cpu=offload_to_cpu,
|
| 717 |
+
offload_dit_to_cpu=offload_dit_to_cpu,
|
| 718 |
+
)
|
| 719 |
+
app.state._initialized3 = ok3
|
| 720 |
+
if ok3:
|
| 721 |
+
print(f"[API Server] Third model loaded: {_get_model_name(app.state._config_path3)}")
|
| 722 |
+
else:
|
| 723 |
+
print(f"[API Server] Warning: Third model failed to load: {status_msg3}")
|
| 724 |
+
except Exception as e:
|
| 725 |
+
print(f"[API Server] Warning: Failed to initialize third model: {e}")
|
| 726 |
+
app.state._initialized3 = False
|
| 727 |
+
|
| 728 |
+
async def _cleanup_job_temp_files(job_id: str) -> None:
|
| 729 |
+
async with app.state.job_temp_files_lock:
|
| 730 |
+
paths = app.state.job_temp_files.pop(job_id, [])
|
| 731 |
+
for p in paths:
|
| 732 |
+
try:
|
| 733 |
+
os.remove(p)
|
| 734 |
+
except Exception:
|
| 735 |
+
pass
|
| 736 |
+
|
| 737 |
+
def _update_local_cache(job_id: str, result: Optional[Dict], status: str) -> None:
|
| 738 |
+
"""Update local cache with job result"""
|
| 739 |
+
local_cache = getattr(app.state, 'local_cache', None)
|
| 740 |
+
if not local_cache:
|
| 741 |
+
return
|
| 742 |
+
|
| 743 |
+
rec = store.get(job_id)
|
| 744 |
+
env = getattr(rec, 'env', 'development') if rec else 'development'
|
| 745 |
+
create_time = rec.created_at if rec else time.time()
|
| 746 |
+
|
| 747 |
+
status_int = _map_status(status)
|
| 748 |
+
|
| 749 |
+
if status == "succeeded" and result:
|
| 750 |
+
audio_paths = result.get("audio_paths", [])
|
| 751 |
+
# Final prompt/lyrics (may be modified by thinking/format)
|
| 752 |
+
final_prompt = result.get("prompt", "")
|
| 753 |
+
final_lyrics = result.get("lyrics", "")
|
| 754 |
+
# Original user input from metas
|
| 755 |
+
metas_raw = result.get("metas", {}) or {}
|
| 756 |
+
original_prompt = metas_raw.get("prompt", "")
|
| 757 |
+
original_lyrics = metas_raw.get("lyrics", "")
|
| 758 |
+
# metas contains original input + other metadata
|
| 759 |
+
metas = {
|
| 760 |
+
"bpm": metas_raw.get("bpm"),
|
| 761 |
+
"duration": metas_raw.get("duration"),
|
| 762 |
+
"genres": metas_raw.get("genres", ""),
|
| 763 |
+
"keyscale": metas_raw.get("keyscale", ""),
|
| 764 |
+
"timesignature": metas_raw.get("timesignature", ""),
|
| 765 |
+
"prompt": original_prompt,
|
| 766 |
+
"lyrics": original_lyrics,
|
| 767 |
+
}
|
| 768 |
+
# Extra fields for Discord bot
|
| 769 |
+
generation_info = result.get("generation_info", "")
|
| 770 |
+
seed_value = result.get("seed_value", "")
|
| 771 |
+
lm_model = result.get("lm_model", "")
|
| 772 |
+
dit_model = result.get("dit_model", "")
|
| 773 |
+
|
| 774 |
+
if audio_paths:
|
| 775 |
+
result_data = [
|
| 776 |
+
{
|
| 777 |
+
"file": p,
|
| 778 |
+
"wave": "",
|
| 779 |
+
"status": status_int,
|
| 780 |
+
"create_time": int(create_time),
|
| 781 |
+
"env": env,
|
| 782 |
+
"prompt": final_prompt,
|
| 783 |
+
"lyrics": final_lyrics,
|
| 784 |
+
"metas": metas,
|
| 785 |
+
"generation_info": generation_info,
|
| 786 |
+
"seed_value": seed_value,
|
| 787 |
+
"lm_model": lm_model,
|
| 788 |
+
"dit_model": dit_model,
|
| 789 |
+
}
|
| 790 |
+
for p in audio_paths
|
| 791 |
+
]
|
| 792 |
+
else:
|
| 793 |
+
result_data = [{
|
| 794 |
+
"file": "",
|
| 795 |
+
"wave": "",
|
| 796 |
+
"status": status_int,
|
| 797 |
+
"create_time": int(create_time),
|
| 798 |
+
"env": env,
|
| 799 |
+
"prompt": final_prompt,
|
| 800 |
+
"lyrics": final_lyrics,
|
| 801 |
+
"metas": metas,
|
| 802 |
+
"generation_info": generation_info,
|
| 803 |
+
"seed_value": seed_value,
|
| 804 |
+
"lm_model": lm_model,
|
| 805 |
+
"dit_model": dit_model,
|
| 806 |
+
}]
|
| 807 |
+
else:
|
| 808 |
+
result_data = [{"file": "", "wave": "", "status": status_int, "create_time": int(create_time), "env": env}]
|
| 809 |
+
|
| 810 |
+
result_key = f"{RESULT_KEY_PREFIX}{job_id}"
|
| 811 |
+
local_cache.set(result_key, result_data, ex=RESULT_EXPIRE_SECONDS)
|
| 812 |
+
|
| 813 |
+
async def _run_one_job(job_id: str, req: GenerateMusicRequest) -> None:
|
| 814 |
+
job_store: _JobStore = app.state.job_store
|
| 815 |
+
llm: LLMHandler = app.state.llm_handler
|
| 816 |
+
executor: ThreadPoolExecutor = app.state.executor
|
| 817 |
+
|
| 818 |
+
await _ensure_initialized()
|
| 819 |
+
job_store.mark_running(job_id)
|
| 820 |
+
|
| 821 |
+
# Select DiT handler based on user's model choice
|
| 822 |
+
# Default: use primary handler
|
| 823 |
+
selected_handler: AceStepHandler = app.state.handler
|
| 824 |
+
selected_model_name = _get_model_name(app.state._config_path)
|
| 825 |
+
|
| 826 |
+
if req.model:
|
| 827 |
+
model_matched = False
|
| 828 |
+
|
| 829 |
+
# Check if it matches the second model
|
| 830 |
+
if app.state.handler2 and getattr(app.state, "_initialized2", False):
|
| 831 |
+
model2_name = _get_model_name(app.state._config_path2)
|
| 832 |
+
if req.model == model2_name:
|
| 833 |
+
selected_handler = app.state.handler2
|
| 834 |
+
selected_model_name = model2_name
|
| 835 |
+
model_matched = True
|
| 836 |
+
print(f"[API Server] Job {job_id}: Using second model: {model2_name}")
|
| 837 |
+
|
| 838 |
+
# Check if it matches the third model
|
| 839 |
+
if not model_matched and app.state.handler3 and getattr(app.state, "_initialized3", False):
|
| 840 |
+
model3_name = _get_model_name(app.state._config_path3)
|
| 841 |
+
if req.model == model3_name:
|
| 842 |
+
selected_handler = app.state.handler3
|
| 843 |
+
selected_model_name = model3_name
|
| 844 |
+
model_matched = True
|
| 845 |
+
print(f"[API Server] Job {job_id}: Using third model: {model3_name}")
|
| 846 |
+
|
| 847 |
+
if not model_matched:
|
| 848 |
+
available_models = [_get_model_name(app.state._config_path)]
|
| 849 |
+
if app.state.handler2 and getattr(app.state, "_initialized2", False):
|
| 850 |
+
available_models.append(_get_model_name(app.state._config_path2))
|
| 851 |
+
if app.state.handler3 and getattr(app.state, "_initialized3", False):
|
| 852 |
+
available_models.append(_get_model_name(app.state._config_path3))
|
| 853 |
+
print(f"[API Server] Job {job_id}: Model '{req.model}' not found in {available_models}, using primary: {selected_model_name}")
|
| 854 |
+
|
| 855 |
+
# Use selected handler for generation
|
| 856 |
+
h: AceStepHandler = selected_handler
|
| 857 |
+
|
| 858 |
+
def _blocking_generate() -> Dict[str, Any]:
|
| 859 |
+
"""Generate music using unified inference logic from acestep.inference"""
|
| 860 |
+
|
| 861 |
+
def _ensure_llm_ready() -> None:
|
| 862 |
+
"""Ensure LLM handler is initialized when needed"""
|
| 863 |
+
with app.state._llm_init_lock:
|
| 864 |
+
initialized = getattr(app.state, "_llm_initialized", False)
|
| 865 |
+
had_error = getattr(app.state, "_llm_init_error", None)
|
| 866 |
+
if initialized or had_error is not None:
|
| 867 |
+
return
|
| 868 |
+
|
| 869 |
+
project_root = _get_project_root()
|
| 870 |
+
checkpoint_dir = os.path.join(project_root, "checkpoints")
|
| 871 |
+
lm_model_path = (req.lm_model_path or os.getenv("ACESTEP_LM_MODEL_PATH") or "acestep-5Hz-lm-0.6B").strip()
|
| 872 |
+
backend = (req.lm_backend or os.getenv("ACESTEP_LM_BACKEND") or "vllm").strip().lower()
|
| 873 |
+
if backend not in {"vllm", "pt"}:
|
| 874 |
+
backend = "vllm"
|
| 875 |
+
|
| 876 |
+
lm_device = os.getenv("ACESTEP_LM_DEVICE", os.getenv("ACESTEP_DEVICE", "auto"))
|
| 877 |
+
lm_offload = _env_bool("ACESTEP_LM_OFFLOAD_TO_CPU", False)
|
| 878 |
+
|
| 879 |
+
status, ok = llm.initialize(
|
| 880 |
+
checkpoint_dir=checkpoint_dir,
|
| 881 |
+
lm_model_path=lm_model_path,
|
| 882 |
+
backend=backend,
|
| 883 |
+
device=lm_device,
|
| 884 |
+
offload_to_cpu=lm_offload,
|
| 885 |
+
dtype=h.dtype,
|
| 886 |
+
)
|
| 887 |
+
if not ok:
|
| 888 |
+
app.state._llm_init_error = status
|
| 889 |
+
else:
|
| 890 |
+
app.state._llm_initialized = True
|
| 891 |
+
|
| 892 |
+
def _normalize_metas(meta: Dict[str, Any]) -> Dict[str, Any]:
|
| 893 |
+
"""Ensure a stable `metas` dict (keys always present)."""
|
| 894 |
+
meta = meta or {}
|
| 895 |
+
out: Dict[str, Any] = dict(meta)
|
| 896 |
+
|
| 897 |
+
# Normalize key aliases
|
| 898 |
+
if "keyscale" not in out and "key_scale" in out:
|
| 899 |
+
out["keyscale"] = out.get("key_scale")
|
| 900 |
+
if "timesignature" not in out and "time_signature" in out:
|
| 901 |
+
out["timesignature"] = out.get("time_signature")
|
| 902 |
+
|
| 903 |
+
# Ensure required keys exist
|
| 904 |
+
for k in ["bpm", "duration", "genres", "keyscale", "timesignature"]:
|
| 905 |
+
if out.get(k) in (None, ""):
|
| 906 |
+
out[k] = "N/A"
|
| 907 |
+
return out
|
| 908 |
+
|
| 909 |
+
# Normalize LM sampling parameters
|
| 910 |
+
lm_top_k = req.lm_top_k if req.lm_top_k and req.lm_top_k > 0 else 0
|
| 911 |
+
lm_top_p = req.lm_top_p if req.lm_top_p and req.lm_top_p < 1.0 else 0.9
|
| 912 |
+
|
| 913 |
+
# Determine if LLM is needed
|
| 914 |
+
thinking = bool(req.thinking)
|
| 915 |
+
sample_mode = bool(req.sample_mode)
|
| 916 |
+
has_sample_query = bool(req.sample_query and req.sample_query.strip())
|
| 917 |
+
use_format = bool(req.use_format)
|
| 918 |
+
use_cot_caption = bool(req.use_cot_caption)
|
| 919 |
+
use_cot_language = bool(req.use_cot_language)
|
| 920 |
+
|
| 921 |
+
# LLM is needed for:
|
| 922 |
+
# - thinking mode (LM generates audio codes)
|
| 923 |
+
# - sample_mode (LM generates random caption/lyrics/metas)
|
| 924 |
+
# - sample_query/description (LM generates from description)
|
| 925 |
+
# - use_format (LM enhances caption/lyrics)
|
| 926 |
+
# - use_cot_caption or use_cot_language (LM enhances metadata)
|
| 927 |
+
need_llm = thinking or sample_mode or has_sample_query or use_format or use_cot_caption or use_cot_language
|
| 928 |
+
|
| 929 |
+
# Ensure LLM is ready if needed
|
| 930 |
+
if need_llm:
|
| 931 |
+
_ensure_llm_ready()
|
| 932 |
+
if getattr(app.state, "_llm_init_error", None):
|
| 933 |
+
raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
|
| 934 |
+
|
| 935 |
+
# Handle sample mode or description: generate caption/lyrics/metas via LM
|
| 936 |
+
caption = req.prompt
|
| 937 |
+
lyrics = req.lyrics
|
| 938 |
+
bpm = req.bpm
|
| 939 |
+
key_scale = req.key_scale
|
| 940 |
+
time_signature = req.time_signature
|
| 941 |
+
audio_duration = req.audio_duration
|
| 942 |
+
|
| 943 |
+
# Save original user input for metas
|
| 944 |
+
original_prompt = req.prompt or ""
|
| 945 |
+
original_lyrics = req.lyrics or ""
|
| 946 |
+
|
| 947 |
+
if sample_mode or has_sample_query:
|
| 948 |
+
if has_sample_query:
|
| 949 |
+
# Use create_sample() with description query
|
| 950 |
+
parsed_language, parsed_instrumental = _parse_description_hints(req.sample_query)
|
| 951 |
+
|
| 952 |
+
# Determine vocal_language with priority:
|
| 953 |
+
# 1. User-specified vocal_language (if not default "en")
|
| 954 |
+
# 2. Language parsed from description
|
| 955 |
+
# 3. None (no constraint)
|
| 956 |
+
if req.vocal_language and req.vocal_language not in ("en", "unknown", ""):
|
| 957 |
+
sample_language = req.vocal_language
|
| 958 |
+
else:
|
| 959 |
+
sample_language = parsed_language
|
| 960 |
+
|
| 961 |
+
sample_result = create_sample(
|
| 962 |
+
llm_handler=llm,
|
| 963 |
+
query=req.sample_query,
|
| 964 |
+
instrumental=parsed_instrumental,
|
| 965 |
+
vocal_language=sample_language,
|
| 966 |
+
temperature=req.lm_temperature,
|
| 967 |
+
top_k=lm_top_k if lm_top_k > 0 else None,
|
| 968 |
+
top_p=lm_top_p if lm_top_p < 1.0 else None,
|
| 969 |
+
use_constrained_decoding=req.constrained_decoding,
|
| 970 |
+
)
|
| 971 |
+
|
| 972 |
+
if not sample_result.success:
|
| 973 |
+
raise RuntimeError(f"create_sample failed: {sample_result.error or sample_result.status_message}")
|
| 974 |
+
|
| 975 |
+
# Use generated sample data
|
| 976 |
+
caption = sample_result.caption
|
| 977 |
+
lyrics = sample_result.lyrics
|
| 978 |
+
bpm = sample_result.bpm
|
| 979 |
+
key_scale = sample_result.keyscale
|
| 980 |
+
time_signature = sample_result.timesignature
|
| 981 |
+
audio_duration = sample_result.duration
|
| 982 |
+
else:
|
| 983 |
+
# Original sample_mode behavior: random generation
|
| 984 |
+
sample_metadata, sample_status = llm.understand_audio_from_codes(
|
| 985 |
+
audio_codes="NO USER INPUT",
|
| 986 |
+
temperature=req.lm_temperature,
|
| 987 |
+
top_k=lm_top_k if lm_top_k > 0 else None,
|
| 988 |
+
top_p=lm_top_p if lm_top_p < 1.0 else None,
|
| 989 |
+
repetition_penalty=req.lm_repetition_penalty,
|
| 990 |
+
use_constrained_decoding=req.constrained_decoding,
|
| 991 |
+
constrained_decoding_debug=req.constrained_decoding_debug,
|
| 992 |
+
)
|
| 993 |
+
|
| 994 |
+
if not sample_metadata or str(sample_status).startswith("❌"):
|
| 995 |
+
raise RuntimeError(f"Sample generation failed: {sample_status}")
|
| 996 |
+
|
| 997 |
+
# Use generated values with fallback defaults
|
| 998 |
+
caption = sample_metadata.get("caption", "")
|
| 999 |
+
lyrics = sample_metadata.get("lyrics", "")
|
| 1000 |
+
bpm = _to_int(sample_metadata.get("bpm"), None) or _to_int(os.getenv("ACESTEP_SAMPLE_DEFAULT_BPM", "120"), 120)
|
| 1001 |
+
key_scale = sample_metadata.get("keyscale", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_KEY", "C Major")
|
| 1002 |
+
time_signature = sample_metadata.get("timesignature", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_TIMESIGNATURE", "4/4")
|
| 1003 |
+
audio_duration = _to_float(sample_metadata.get("duration"), None) or _to_float(os.getenv("ACESTEP_SAMPLE_DEFAULT_DURATION_SECONDS", "120"), 120.0)
|
| 1004 |
+
|
| 1005 |
+
# Apply format_sample() if use_format is True and caption/lyrics are provided
|
| 1006 |
+
format_has_duration = False
|
| 1007 |
+
|
| 1008 |
+
if req.use_format and (caption or lyrics):
|
| 1009 |
+
_ensure_llm_ready()
|
| 1010 |
+
if getattr(app.state, "_llm_init_error", None):
|
| 1011 |
+
raise RuntimeError(f"5Hz LM init failed (needed for format): {app.state._llm_init_error}")
|
| 1012 |
+
|
| 1013 |
+
# Build user_metadata from request params (matching bot.py behavior)
|
| 1014 |
+
user_metadata_for_format = {}
|
| 1015 |
+
if bpm is not None:
|
| 1016 |
+
user_metadata_for_format['bpm'] = bpm
|
| 1017 |
+
if audio_duration is not None and audio_duration > 0:
|
| 1018 |
+
user_metadata_for_format['duration'] = int(audio_duration)
|
| 1019 |
+
if key_scale:
|
| 1020 |
+
user_metadata_for_format['keyscale'] = key_scale
|
| 1021 |
+
if time_signature:
|
| 1022 |
+
user_metadata_for_format['timesignature'] = time_signature
|
| 1023 |
+
if req.vocal_language and req.vocal_language != "unknown":
|
| 1024 |
+
user_metadata_for_format['language'] = req.vocal_language
|
| 1025 |
+
|
| 1026 |
+
format_result = format_sample(
|
| 1027 |
+
llm_handler=llm,
|
| 1028 |
+
caption=caption,
|
| 1029 |
+
lyrics=lyrics,
|
| 1030 |
+
user_metadata=user_metadata_for_format if user_metadata_for_format else None,
|
| 1031 |
+
temperature=req.lm_temperature,
|
| 1032 |
+
top_k=lm_top_k if lm_top_k > 0 else None,
|
| 1033 |
+
top_p=lm_top_p if lm_top_p < 1.0 else None,
|
| 1034 |
+
use_constrained_decoding=req.constrained_decoding,
|
| 1035 |
+
)
|
| 1036 |
+
|
| 1037 |
+
if format_result.success:
|
| 1038 |
+
# Extract all formatted data (matching bot.py behavior)
|
| 1039 |
+
caption = format_result.caption or caption
|
| 1040 |
+
lyrics = format_result.lyrics or lyrics
|
| 1041 |
+
if format_result.duration:
|
| 1042 |
+
audio_duration = format_result.duration
|
| 1043 |
+
format_has_duration = True
|
| 1044 |
+
if format_result.bpm:
|
| 1045 |
+
bpm = format_result.bpm
|
| 1046 |
+
if format_result.keyscale:
|
| 1047 |
+
key_scale = format_result.keyscale
|
| 1048 |
+
if format_result.timesignature:
|
| 1049 |
+
time_signature = format_result.timesignature
|
| 1050 |
+
|
| 1051 |
+
# Parse timesteps string to list of floats if provided
|
| 1052 |
+
parsed_timesteps = _parse_timesteps(req.timesteps)
|
| 1053 |
+
|
| 1054 |
+
# Determine actual inference steps (timesteps override inference_steps)
|
| 1055 |
+
actual_inference_steps = len(parsed_timesteps) if parsed_timesteps else req.inference_steps
|
| 1056 |
+
|
| 1057 |
+
# Auto-select instruction based on task_type if user didn't provide custom instruction
|
| 1058 |
+
# This matches gradio behavior which uses TASK_INSTRUCTIONS for each task type
|
| 1059 |
+
instruction_to_use = req.instruction
|
| 1060 |
+
if instruction_to_use == DEFAULT_DIT_INSTRUCTION and req.task_type in TASK_INSTRUCTIONS:
|
| 1061 |
+
instruction_to_use = TASK_INSTRUCTIONS[req.task_type]
|
| 1062 |
+
|
| 1063 |
+
# Build GenerationParams using unified interface
|
| 1064 |
+
# Note: thinking controls LM code generation, sample_mode only affects CoT metas
|
| 1065 |
+
params = GenerationParams(
|
| 1066 |
+
task_type=req.task_type,
|
| 1067 |
+
instruction=instruction_to_use,
|
| 1068 |
+
reference_audio=req.reference_audio_path,
|
| 1069 |
+
src_audio=req.src_audio_path,
|
| 1070 |
+
audio_codes=req.audio_code_string,
|
| 1071 |
+
caption=caption,
|
| 1072 |
+
lyrics=lyrics,
|
| 1073 |
+
instrumental=False,
|
| 1074 |
+
vocal_language=req.vocal_language,
|
| 1075 |
+
bpm=bpm,
|
| 1076 |
+
keyscale=key_scale,
|
| 1077 |
+
timesignature=time_signature,
|
| 1078 |
+
duration=audio_duration if audio_duration else -1.0,
|
| 1079 |
+
inference_steps=req.inference_steps,
|
| 1080 |
+
seed=req.seed,
|
| 1081 |
+
guidance_scale=req.guidance_scale,
|
| 1082 |
+
use_adg=req.use_adg,
|
| 1083 |
+
cfg_interval_start=req.cfg_interval_start,
|
| 1084 |
+
cfg_interval_end=req.cfg_interval_end,
|
| 1085 |
+
shift=req.shift,
|
| 1086 |
+
infer_method=req.infer_method,
|
| 1087 |
+
timesteps=parsed_timesteps,
|
| 1088 |
+
repainting_start=req.repainting_start,
|
| 1089 |
+
repainting_end=req.repainting_end if req.repainting_end else -1,
|
| 1090 |
+
audio_cover_strength=req.audio_cover_strength,
|
| 1091 |
+
# LM parameters
|
| 1092 |
+
thinking=thinking, # Use LM for code generation when thinking=True
|
| 1093 |
+
lm_temperature=req.lm_temperature,
|
| 1094 |
+
lm_cfg_scale=req.lm_cfg_scale,
|
| 1095 |
+
lm_top_k=lm_top_k,
|
| 1096 |
+
lm_top_p=lm_top_p,
|
| 1097 |
+
lm_negative_prompt=req.lm_negative_prompt,
|
| 1098 |
+
# use_cot_metas logic:
|
| 1099 |
+
# - sample_mode: metas already generated, skip Phase 1
|
| 1100 |
+
# - format with duration: metas already generated, skip Phase 1
|
| 1101 |
+
# - format without duration: need Phase 1 to generate duration
|
| 1102 |
+
# - no format: need Phase 1 to generate all metas
|
| 1103 |
+
use_cot_metas=not sample_mode and not format_has_duration,
|
| 1104 |
+
use_cot_caption=req.use_cot_caption,
|
| 1105 |
+
use_cot_language=req.use_cot_language,
|
| 1106 |
+
use_constrained_decoding=req.constrained_decoding,
|
| 1107 |
+
)
|
| 1108 |
+
|
| 1109 |
+
# Build GenerationConfig - default to 2 audios like gradio_ui
|
| 1110 |
+
batch_size = req.batch_size if req.batch_size is not None else 2
|
| 1111 |
+
config = GenerationConfig(
|
| 1112 |
+
batch_size=batch_size,
|
| 1113 |
+
use_random_seed=req.use_random_seed,
|
| 1114 |
+
seeds=None, # Let unified logic handle seed generation
|
| 1115 |
+
audio_format=req.audio_format,
|
| 1116 |
+
constrained_decoding_debug=req.constrained_decoding_debug,
|
| 1117 |
+
)
|
| 1118 |
+
|
| 1119 |
+
# Check LLM initialization status
|
| 1120 |
+
llm_is_initialized = getattr(app.state, "_llm_initialized", False)
|
| 1121 |
+
llm_to_pass = llm if llm_is_initialized else None
|
| 1122 |
+
|
| 1123 |
+
# Generate music using unified interface
|
| 1124 |
+
result = generate_music(
|
| 1125 |
+
dit_handler=h,
|
| 1126 |
+
llm_handler=llm_to_pass,
|
| 1127 |
+
params=params,
|
| 1128 |
+
config=config,
|
| 1129 |
+
save_dir=app.state.temp_audio_dir,
|
| 1130 |
+
progress=None,
|
| 1131 |
+
)
|
| 1132 |
+
|
| 1133 |
+
if not result.success:
|
| 1134 |
+
raise RuntimeError(f"Music generation failed: {result.error or result.status_message}")
|
| 1135 |
+
|
| 1136 |
+
# Extract results
|
| 1137 |
+
audio_paths = [audio["path"] for audio in result.audios if audio.get("path")]
|
| 1138 |
+
first_audio = audio_paths[0] if len(audio_paths) > 0 else None
|
| 1139 |
+
second_audio = audio_paths[1] if len(audio_paths) > 1 else None
|
| 1140 |
+
|
| 1141 |
+
# Get metadata from LM or CoT results
|
| 1142 |
+
lm_metadata = result.extra_outputs.get("lm_metadata", {})
|
| 1143 |
+
metas_out = _normalize_metas(lm_metadata)
|
| 1144 |
+
|
| 1145 |
+
# Update metas with actual values used
|
| 1146 |
+
if params.cot_bpm:
|
| 1147 |
+
metas_out["bpm"] = params.cot_bpm
|
| 1148 |
+
elif bpm:
|
| 1149 |
+
metas_out["bpm"] = bpm
|
| 1150 |
+
|
| 1151 |
+
if params.cot_duration:
|
| 1152 |
+
metas_out["duration"] = params.cot_duration
|
| 1153 |
+
elif audio_duration:
|
| 1154 |
+
metas_out["duration"] = audio_duration
|
| 1155 |
+
|
| 1156 |
+
if params.cot_keyscale:
|
| 1157 |
+
metas_out["keyscale"] = params.cot_keyscale
|
| 1158 |
+
elif key_scale:
|
| 1159 |
+
metas_out["keyscale"] = key_scale
|
| 1160 |
+
|
| 1161 |
+
if params.cot_timesignature:
|
| 1162 |
+
metas_out["timesignature"] = params.cot_timesignature
|
| 1163 |
+
elif time_signature:
|
| 1164 |
+
metas_out["timesignature"] = time_signature
|
| 1165 |
+
|
| 1166 |
+
# Store original user input in metas (not the final/modified values)
|
| 1167 |
+
metas_out["prompt"] = original_prompt
|
| 1168 |
+
metas_out["lyrics"] = original_lyrics
|
| 1169 |
+
|
| 1170 |
+
# Extract seed values for response (comma-separated for multiple audios)
|
| 1171 |
+
seed_values = []
|
| 1172 |
+
for audio in result.audios:
|
| 1173 |
+
audio_params = audio.get("params", {})
|
| 1174 |
+
seed = audio_params.get("seed")
|
| 1175 |
+
if seed is not None:
|
| 1176 |
+
seed_values.append(str(seed))
|
| 1177 |
+
seed_value = ",".join(seed_values) if seed_values else ""
|
| 1178 |
+
|
| 1179 |
+
# Build generation_info using the helper function (like gradio_ui)
|
| 1180 |
+
time_costs = result.extra_outputs.get("time_costs", {})
|
| 1181 |
+
generation_info = _build_generation_info(
|
| 1182 |
+
lm_metadata=lm_metadata,
|
| 1183 |
+
time_costs=time_costs,
|
| 1184 |
+
seed_value=seed_value,
|
| 1185 |
+
inference_steps=req.inference_steps,
|
| 1186 |
+
num_audios=len(result.audios),
|
| 1187 |
+
)
|
| 1188 |
+
|
| 1189 |
+
def _none_if_na_str(v: Any) -> Optional[str]:
|
| 1190 |
+
if v is None:
|
| 1191 |
+
return None
|
| 1192 |
+
s = str(v).strip()
|
| 1193 |
+
if s in {"", "N/A"}:
|
| 1194 |
+
return None
|
| 1195 |
+
return s
|
| 1196 |
+
|
| 1197 |
+
# Get model information
|
| 1198 |
+
lm_model_name = os.getenv("ACESTEP_LM_MODEL_PATH", "acestep-5Hz-lm-0.6B")
|
| 1199 |
+
# Use selected_model_name (set at the beginning of _run_one_job)
|
| 1200 |
+
dit_model_name = selected_model_name
|
| 1201 |
+
|
| 1202 |
+
return {
|
| 1203 |
+
"first_audio_path": _path_to_audio_url(first_audio) if first_audio else None,
|
| 1204 |
+
"second_audio_path": _path_to_audio_url(second_audio) if second_audio else None,
|
| 1205 |
+
"audio_paths": [_path_to_audio_url(p) for p in audio_paths],
|
| 1206 |
+
"generation_info": generation_info,
|
| 1207 |
+
"status_message": result.status_message,
|
| 1208 |
+
"seed_value": seed_value,
|
| 1209 |
+
# Final prompt/lyrics (may be modified by thinking/format)
|
| 1210 |
+
"prompt": caption or "",
|
| 1211 |
+
"lyrics": lyrics or "",
|
| 1212 |
+
# metas contains original user input + other metadata
|
| 1213 |
+
"metas": metas_out,
|
| 1214 |
+
"bpm": metas_out.get("bpm") if isinstance(metas_out.get("bpm"), int) else None,
|
| 1215 |
+
"duration": metas_out.get("duration") if isinstance(metas_out.get("duration"), (int, float)) else None,
|
| 1216 |
+
"genres": _none_if_na_str(metas_out.get("genres")),
|
| 1217 |
+
"keyscale": _none_if_na_str(metas_out.get("keyscale")),
|
| 1218 |
+
"timesignature": _none_if_na_str(metas_out.get("timesignature")),
|
| 1219 |
+
"lm_model": lm_model_name,
|
| 1220 |
+
"dit_model": dit_model_name,
|
| 1221 |
+
}
|
| 1222 |
+
|
| 1223 |
+
t0 = time.time()
|
| 1224 |
+
try:
|
| 1225 |
+
loop = asyncio.get_running_loop()
|
| 1226 |
+
result = await loop.run_in_executor(executor, _blocking_generate)
|
| 1227 |
+
job_store.mark_succeeded(job_id, result)
|
| 1228 |
+
|
| 1229 |
+
# Update local cache
|
| 1230 |
+
_update_local_cache(job_id, result, "succeeded")
|
| 1231 |
+
except Exception:
|
| 1232 |
+
job_store.mark_failed(job_id, traceback.format_exc())
|
| 1233 |
+
|
| 1234 |
+
# Update local cache
|
| 1235 |
+
_update_local_cache(job_id, None, "failed")
|
| 1236 |
+
finally:
|
| 1237 |
+
dt = max(0.0, time.time() - t0)
|
| 1238 |
+
async with app.state.stats_lock:
|
| 1239 |
+
app.state.recent_durations.append(dt)
|
| 1240 |
+
if app.state.recent_durations:
|
| 1241 |
+
app.state.avg_job_seconds = sum(app.state.recent_durations) / len(app.state.recent_durations)
|
| 1242 |
+
|
| 1243 |
+
async def _queue_worker(worker_idx: int) -> None:
|
| 1244 |
+
while True:
|
| 1245 |
+
job_id, req = await app.state.job_queue.get()
|
| 1246 |
+
try:
|
| 1247 |
+
async with app.state.pending_lock:
|
| 1248 |
+
try:
|
| 1249 |
+
app.state.pending_ids.remove(job_id)
|
| 1250 |
+
except ValueError:
|
| 1251 |
+
pass
|
| 1252 |
+
|
| 1253 |
+
await _run_one_job(job_id, req)
|
| 1254 |
+
finally:
|
| 1255 |
+
await _cleanup_job_temp_files(job_id)
|
| 1256 |
+
app.state.job_queue.task_done()
|
| 1257 |
+
|
| 1258 |
+
worker_count = max(1, WORKER_COUNT)
|
| 1259 |
+
workers = [asyncio.create_task(_queue_worker(i)) for i in range(worker_count)]
|
| 1260 |
+
app.state.worker_tasks = workers
|
| 1261 |
+
|
| 1262 |
+
try:
|
| 1263 |
+
yield
|
| 1264 |
+
finally:
|
| 1265 |
+
for t in workers:
|
| 1266 |
+
t.cancel()
|
| 1267 |
+
executor.shutdown(wait=False, cancel_futures=True)
|
| 1268 |
+
|
| 1269 |
+
app = FastAPI(title="ACE-Step API", version="1.0", lifespan=lifespan)
|
| 1270 |
+
|
| 1271 |
+
async def _queue_position(job_id: str) -> int:
|
| 1272 |
+
async with app.state.pending_lock:
|
| 1273 |
+
try:
|
| 1274 |
+
return list(app.state.pending_ids).index(job_id) + 1
|
| 1275 |
+
except ValueError:
|
| 1276 |
+
return 0
|
| 1277 |
+
|
| 1278 |
+
async def _eta_seconds_for_position(pos: int) -> Optional[float]:
|
| 1279 |
+
if pos <= 0:
|
| 1280 |
+
return None
|
| 1281 |
+
async with app.state.stats_lock:
|
| 1282 |
+
avg = float(getattr(app.state, "avg_job_seconds", INITIAL_AVG_JOB_SECONDS))
|
| 1283 |
+
return pos * avg
|
| 1284 |
+
|
| 1285 |
+
@app.post("/release_task", response_model=CreateJobResponse)
|
| 1286 |
+
async def create_music_generate_job(request: Request) -> CreateJobResponse:
|
| 1287 |
+
content_type = (request.headers.get("content-type") or "").lower()
|
| 1288 |
+
temp_files: list[str] = []
|
| 1289 |
+
|
| 1290 |
+
def _build_request(p: RequestParser, **kwargs) -> GenerateMusicRequest:
|
| 1291 |
+
"""Build GenerateMusicRequest from parsed parameters."""
|
| 1292 |
+
return GenerateMusicRequest(
|
| 1293 |
+
prompt=p.str("prompt"),
|
| 1294 |
+
lyrics=p.str("lyrics"),
|
| 1295 |
+
thinking=p.bool("thinking"),
|
| 1296 |
+
sample_mode=p.bool("sample_mode"),
|
| 1297 |
+
sample_query=p.str("sample_query"),
|
| 1298 |
+
use_format=p.bool("use_format"),
|
| 1299 |
+
model=p.str("model") or None,
|
| 1300 |
+
bpm=p.int("bpm"),
|
| 1301 |
+
key_scale=p.str("key_scale"),
|
| 1302 |
+
time_signature=p.str("time_signature"),
|
| 1303 |
+
audio_duration=p.float("audio_duration"),
|
| 1304 |
+
vocal_language=p.str("vocal_language", "en"),
|
| 1305 |
+
inference_steps=p.int("inference_steps", 8),
|
| 1306 |
+
guidance_scale=p.float("guidance_scale", 7.0),
|
| 1307 |
+
use_random_seed=p.bool("use_random_seed", True),
|
| 1308 |
+
seed=p.int("seed", -1),
|
| 1309 |
+
batch_size=p.int("batch_size"),
|
| 1310 |
+
audio_code_string=p.str("audio_code_string"),
|
| 1311 |
+
repainting_start=p.float("repainting_start", 0.0),
|
| 1312 |
+
repainting_end=p.float("repainting_end"),
|
| 1313 |
+
instruction=p.str("instruction", DEFAULT_DIT_INSTRUCTION),
|
| 1314 |
+
audio_cover_strength=p.float("audio_cover_strength", 1.0),
|
| 1315 |
+
task_type=p.str("task_type", "text2music"),
|
| 1316 |
+
use_adg=p.bool("use_adg"),
|
| 1317 |
+
cfg_interval_start=p.float("cfg_interval_start", 0.0),
|
| 1318 |
+
cfg_interval_end=p.float("cfg_interval_end", 1.0),
|
| 1319 |
+
infer_method=p.str("infer_method", "ode"),
|
| 1320 |
+
shift=p.float("shift", 3.0),
|
| 1321 |
+
audio_format=p.str("audio_format", "mp3"),
|
| 1322 |
+
use_tiled_decode=p.bool("use_tiled_decode", True),
|
| 1323 |
+
lm_model_path=p.str("lm_model_path") or None,
|
| 1324 |
+
lm_backend=p.str("lm_backend", "vllm"),
|
| 1325 |
+
lm_temperature=p.float("lm_temperature", LM_DEFAULT_TEMPERATURE),
|
| 1326 |
+
lm_cfg_scale=p.float("lm_cfg_scale", LM_DEFAULT_CFG_SCALE),
|
| 1327 |
+
lm_top_k=p.int("lm_top_k"),
|
| 1328 |
+
lm_top_p=p.float("lm_top_p", LM_DEFAULT_TOP_P),
|
| 1329 |
+
lm_repetition_penalty=p.float("lm_repetition_penalty", 1.0),
|
| 1330 |
+
lm_negative_prompt=p.str("lm_negative_prompt", "NO USER INPUT"),
|
| 1331 |
+
constrained_decoding=p.bool("constrained_decoding", True),
|
| 1332 |
+
constrained_decoding_debug=p.bool("constrained_decoding_debug"),
|
| 1333 |
+
use_cot_caption=p.bool("use_cot_caption", True),
|
| 1334 |
+
use_cot_language=p.bool("use_cot_language", True),
|
| 1335 |
+
is_format_caption=p.bool("is_format_caption"),
|
| 1336 |
+
**kwargs,
|
| 1337 |
+
)
|
| 1338 |
+
|
| 1339 |
+
if content_type.startswith("application/json"):
|
| 1340 |
+
body = await request.json()
|
| 1341 |
+
if not isinstance(body, dict):
|
| 1342 |
+
raise HTTPException(status_code=400, detail="JSON payload must be an object")
|
| 1343 |
+
req = _build_request(RequestParser(body))
|
| 1344 |
+
|
| 1345 |
+
elif content_type.endswith("+json"):
|
| 1346 |
+
body = await request.json()
|
| 1347 |
+
if not isinstance(body, dict):
|
| 1348 |
+
raise HTTPException(status_code=400, detail="JSON payload must be an object")
|
| 1349 |
+
req = _build_request(RequestParser(body))
|
| 1350 |
+
|
| 1351 |
+
elif content_type.startswith("multipart/form-data"):
|
| 1352 |
+
form = await request.form()
|
| 1353 |
+
|
| 1354 |
+
ref_up = form.get("reference_audio")
|
| 1355 |
+
src_up = form.get("src_audio")
|
| 1356 |
+
|
| 1357 |
+
reference_audio_path = None
|
| 1358 |
+
src_audio_path = None
|
| 1359 |
+
|
| 1360 |
+
if isinstance(ref_up, StarletteUploadFile):
|
| 1361 |
+
reference_audio_path = await _save_upload_to_temp(ref_up, prefix="reference_audio")
|
| 1362 |
+
temp_files.append(reference_audio_path)
|
| 1363 |
+
else:
|
| 1364 |
+
reference_audio_path = str(form.get("reference_audio_path") or "").strip() or None
|
| 1365 |
+
|
| 1366 |
+
if isinstance(src_up, StarletteUploadFile):
|
| 1367 |
+
src_audio_path = await _save_upload_to_temp(src_up, prefix="src_audio")
|
| 1368 |
+
temp_files.append(src_audio_path)
|
| 1369 |
+
else:
|
| 1370 |
+
src_audio_path = str(form.get("src_audio_path") or "").strip() or None
|
| 1371 |
+
|
| 1372 |
+
req = _build_request(
|
| 1373 |
+
RequestParser(dict(form)),
|
| 1374 |
+
reference_audio_path=reference_audio_path,
|
| 1375 |
+
src_audio_path=src_audio_path,
|
| 1376 |
+
)
|
| 1377 |
+
|
| 1378 |
+
elif content_type.startswith("application/x-www-form-urlencoded"):
|
| 1379 |
+
form = await request.form()
|
| 1380 |
+
reference_audio_path = str(form.get("reference_audio_path") or "").strip() or None
|
| 1381 |
+
src_audio_path = str(form.get("src_audio_path") or "").strip() or None
|
| 1382 |
+
req = _build_request(
|
| 1383 |
+
RequestParser(dict(form)),
|
| 1384 |
+
reference_audio_path=reference_audio_path,
|
| 1385 |
+
src_audio_path=src_audio_path,
|
| 1386 |
+
)
|
| 1387 |
+
|
| 1388 |
+
else:
|
| 1389 |
+
raw = await request.body()
|
| 1390 |
+
raw_stripped = raw.lstrip()
|
| 1391 |
+
# Best-effort: accept missing/incorrect Content-Type if payload is valid JSON.
|
| 1392 |
+
if raw_stripped.startswith(b"{") or raw_stripped.startswith(b"["):
|
| 1393 |
+
try:
|
| 1394 |
+
body = json.loads(raw.decode("utf-8"))
|
| 1395 |
+
if isinstance(body, dict):
|
| 1396 |
+
req = _build_request(RequestParser(body))
|
| 1397 |
+
else:
|
| 1398 |
+
raise HTTPException(status_code=400, detail="JSON payload must be an object")
|
| 1399 |
+
except HTTPException:
|
| 1400 |
+
raise
|
| 1401 |
+
except Exception:
|
| 1402 |
+
raise HTTPException(
|
| 1403 |
+
status_code=400,
|
| 1404 |
+
detail="Invalid JSON body (hint: set 'Content-Type: application/json')",
|
| 1405 |
+
)
|
| 1406 |
+
# Best-effort: parse key=value bodies even if Content-Type is missing.
|
| 1407 |
+
elif raw_stripped and b"=" in raw:
|
| 1408 |
+
parsed = urllib.parse.parse_qs(raw.decode("utf-8"), keep_blank_values=True)
|
| 1409 |
+
flat = {k: (v[0] if isinstance(v, list) and v else v) for k, v in parsed.items()}
|
| 1410 |
+
reference_audio_path = str(flat.get("reference_audio_path") or "").strip() or None
|
| 1411 |
+
src_audio_path = str(flat.get("src_audio_path") or "").strip() or None
|
| 1412 |
+
req = _build_request(
|
| 1413 |
+
RequestParser(flat),
|
| 1414 |
+
reference_audio_path=reference_audio_path,
|
| 1415 |
+
src_audio_path=src_audio_path,
|
| 1416 |
+
)
|
| 1417 |
+
else:
|
| 1418 |
+
raise HTTPException(
|
| 1419 |
+
status_code=415,
|
| 1420 |
+
detail=(
|
| 1421 |
+
f"Unsupported Content-Type: {content_type or '(missing)'}; "
|
| 1422 |
+
"use application/json, application/x-www-form-urlencoded, or multipart/form-data"
|
| 1423 |
+
),
|
| 1424 |
+
)
|
| 1425 |
+
|
| 1426 |
+
rec = store.create()
|
| 1427 |
+
|
| 1428 |
+
q: asyncio.Queue = app.state.job_queue
|
| 1429 |
+
if q.full():
|
| 1430 |
+
for p in temp_files:
|
| 1431 |
+
try:
|
| 1432 |
+
os.remove(p)
|
| 1433 |
+
except Exception:
|
| 1434 |
+
pass
|
| 1435 |
+
raise HTTPException(status_code=429, detail="Server busy: queue is full")
|
| 1436 |
+
|
| 1437 |
+
if temp_files:
|
| 1438 |
+
async with app.state.job_temp_files_lock:
|
| 1439 |
+
app.state.job_temp_files[rec.job_id] = temp_files
|
| 1440 |
+
|
| 1441 |
+
async with app.state.pending_lock:
|
| 1442 |
+
app.state.pending_ids.append(rec.job_id)
|
| 1443 |
+
position = len(app.state.pending_ids)
|
| 1444 |
+
|
| 1445 |
+
await q.put((rec.job_id, req))
|
| 1446 |
+
return CreateJobResponse(task_id=rec.job_id, status="queued", queue_position=position)
|
| 1447 |
+
|
| 1448 |
+
@app.post("/v1/music/random", response_model=CreateJobResponse)
|
| 1449 |
+
async def create_random_sample_job(request: Request) -> CreateJobResponse:
|
| 1450 |
+
"""Create a sample-mode job that auto-generates caption/lyrics via LM."""
|
| 1451 |
+
|
| 1452 |
+
thinking_value: Any = None
|
| 1453 |
+
content_type = (request.headers.get("content-type") or "").lower()
|
| 1454 |
+
body_dict: Dict[str, Any] = {}
|
| 1455 |
+
|
| 1456 |
+
if "json" in content_type:
|
| 1457 |
+
try:
|
| 1458 |
+
payload = await request.json()
|
| 1459 |
+
if isinstance(payload, dict):
|
| 1460 |
+
body_dict = payload
|
| 1461 |
+
except Exception:
|
| 1462 |
+
body_dict = {}
|
| 1463 |
+
|
| 1464 |
+
if not body_dict and request.query_params:
|
| 1465 |
+
body_dict = dict(request.query_params)
|
| 1466 |
+
|
| 1467 |
+
thinking_value = body_dict.get("thinking")
|
| 1468 |
+
if thinking_value is None:
|
| 1469 |
+
thinking_value = body_dict.get("Thinking")
|
| 1470 |
+
|
| 1471 |
+
thinking_flag = _to_bool(thinking_value, True)
|
| 1472 |
+
|
| 1473 |
+
req = GenerateMusicRequest(
|
| 1474 |
+
caption="",
|
| 1475 |
+
lyrics="",
|
| 1476 |
+
thinking=thinking_flag,
|
| 1477 |
+
sample_mode=True,
|
| 1478 |
+
)
|
| 1479 |
+
|
| 1480 |
+
rec = store.create()
|
| 1481 |
+
q: asyncio.Queue = app.state.job_queue
|
| 1482 |
+
if q.full():
|
| 1483 |
+
raise HTTPException(status_code=429, detail="Server busy: queue is full")
|
| 1484 |
+
|
| 1485 |
+
async with app.state.pending_lock:
|
| 1486 |
+
app.state.pending_ids.append(rec.job_id)
|
| 1487 |
+
position = len(app.state.pending_ids)
|
| 1488 |
+
|
| 1489 |
+
await q.put((rec.job_id, req))
|
| 1490 |
+
return CreateJobResponse(task_id=rec.job_id, status="queued", queue_position=position)
|
| 1491 |
+
|
| 1492 |
+
@app.post("/query_result")
|
| 1493 |
+
async def query_result(request: Request) -> List[Dict[str, Any]]:
|
| 1494 |
+
"""Batch query job results"""
|
| 1495 |
+
content_type = (request.headers.get("content-type") or "").lower()
|
| 1496 |
+
|
| 1497 |
+
if "json" in content_type:
|
| 1498 |
+
body = await request.json()
|
| 1499 |
+
else:
|
| 1500 |
+
form = await request.form()
|
| 1501 |
+
body = {k: v for k, v in form.items()}
|
| 1502 |
+
|
| 1503 |
+
task_id_list_str = body.get("task_id_list", "[]")
|
| 1504 |
+
|
| 1505 |
+
# Parse task ID list
|
| 1506 |
+
if isinstance(task_id_list_str, list):
|
| 1507 |
+
task_id_list = task_id_list_str
|
| 1508 |
+
else:
|
| 1509 |
+
try:
|
| 1510 |
+
task_id_list = json.loads(task_id_list_str)
|
| 1511 |
+
except Exception:
|
| 1512 |
+
task_id_list = []
|
| 1513 |
+
|
| 1514 |
+
local_cache = getattr(app.state, 'local_cache', None)
|
| 1515 |
+
data_list = []
|
| 1516 |
+
current_time = time.time()
|
| 1517 |
+
|
| 1518 |
+
for task_id in task_id_list:
|
| 1519 |
+
result_key = f"{RESULT_KEY_PREFIX}{task_id}"
|
| 1520 |
+
|
| 1521 |
+
# Read from local cache first
|
| 1522 |
+
if local_cache:
|
| 1523 |
+
data = local_cache.get(result_key)
|
| 1524 |
+
if data:
|
| 1525 |
+
try:
|
| 1526 |
+
data_json = json.loads(data)
|
| 1527 |
+
except Exception:
|
| 1528 |
+
data_json = []
|
| 1529 |
+
|
| 1530 |
+
if len(data_json) <= 0:
|
| 1531 |
+
data_list.append({"task_id": task_id, "result": data, "status": 2})
|
| 1532 |
+
else:
|
| 1533 |
+
status = data_json[0].get("status")
|
| 1534 |
+
create_time = data_json[0].get("create_time", 0)
|
| 1535 |
+
if status == 0 and (current_time - create_time) > TASK_TIMEOUT_SECONDS:
|
| 1536 |
+
data_list.append({"task_id": task_id, "result": data, "status": 2})
|
| 1537 |
+
else:
|
| 1538 |
+
data_list.append({
|
| 1539 |
+
"task_id": task_id,
|
| 1540 |
+
"result": data,
|
| 1541 |
+
"status": int(status) if status is not None else 1,
|
| 1542 |
+
})
|
| 1543 |
+
continue
|
| 1544 |
+
|
| 1545 |
+
# Fallback to job_store query
|
| 1546 |
+
rec = store.get(task_id)
|
| 1547 |
+
if rec:
|
| 1548 |
+
env = getattr(rec, 'env', 'development')
|
| 1549 |
+
create_time = rec.created_at
|
| 1550 |
+
status_int = _map_status(rec.status)
|
| 1551 |
+
|
| 1552 |
+
if rec.result and rec.status == "succeeded":
|
| 1553 |
+
audio_paths = rec.result.get("audio_paths", [])
|
| 1554 |
+
metas = rec.result.get("metas", {}) or {}
|
| 1555 |
+
result_data = [
|
| 1556 |
+
{
|
| 1557 |
+
"file": p, "wave": "", "status": status_int,
|
| 1558 |
+
"create_time": int(create_time), "env": env,
|
| 1559 |
+
"prompt": metas.get("caption", ""),
|
| 1560 |
+
"lyrics": metas.get("lyrics", ""),
|
| 1561 |
+
"metas": {
|
| 1562 |
+
"bpm": metas.get("bpm"),
|
| 1563 |
+
"duration": metas.get("duration"),
|
| 1564 |
+
"genres": metas.get("genres", ""),
|
| 1565 |
+
"keyscale": metas.get("keyscale", ""),
|
| 1566 |
+
"timesignature": metas.get("timesignature", ""),
|
| 1567 |
+
}
|
| 1568 |
+
}
|
| 1569 |
+
for p in audio_paths
|
| 1570 |
+
] if audio_paths else [{
|
| 1571 |
+
"file": "", "wave": "", "status": status_int,
|
| 1572 |
+
"create_time": int(create_time), "env": env,
|
| 1573 |
+
"prompt": metas.get("caption", ""),
|
| 1574 |
+
"lyrics": metas.get("lyrics", ""),
|
| 1575 |
+
"metas": {
|
| 1576 |
+
"bpm": metas.get("bpm"),
|
| 1577 |
+
"duration": metas.get("duration"),
|
| 1578 |
+
"genres": metas.get("genres", ""),
|
| 1579 |
+
"keyscale": metas.get("keyscale", ""),
|
| 1580 |
+
"timesignature": metas.get("timesignature", ""),
|
| 1581 |
+
}
|
| 1582 |
+
}]
|
| 1583 |
+
else:
|
| 1584 |
+
result_data = [{
|
| 1585 |
+
"file": "", "wave": "", "status": status_int,
|
| 1586 |
+
"create_time": int(create_time), "env": env,
|
| 1587 |
+
"prompt": "", "lyrics": "",
|
| 1588 |
+
"metas": {}
|
| 1589 |
+
}]
|
| 1590 |
+
|
| 1591 |
+
data_list.append({
|
| 1592 |
+
"task_id": task_id,
|
| 1593 |
+
"result": json.dumps(result_data, ensure_ascii=False),
|
| 1594 |
+
"status": status_int,
|
| 1595 |
+
})
|
| 1596 |
+
else:
|
| 1597 |
+
data_list.append({"task_id": task_id, "result": "[]", "status": 0})
|
| 1598 |
+
|
| 1599 |
+
return data_list
|
| 1600 |
+
|
| 1601 |
+
@app.get("/health")
|
| 1602 |
+
async def health_check():
|
| 1603 |
+
"""Health check endpoint for service status."""
|
| 1604 |
+
return {
|
| 1605 |
+
"status": "ok",
|
| 1606 |
+
"service": "ACE-Step API",
|
| 1607 |
+
"version": "1.0",
|
| 1608 |
+
}
|
| 1609 |
+
|
| 1610 |
+
@app.get("/v1/models")
|
| 1611 |
+
async def list_models():
|
| 1612 |
+
"""List available DiT models."""
|
| 1613 |
+
models = []
|
| 1614 |
+
|
| 1615 |
+
# Primary model (always available if initialized)
|
| 1616 |
+
if getattr(app.state, "_initialized", False):
|
| 1617 |
+
primary_model = _get_model_name(app.state._config_path)
|
| 1618 |
+
if primary_model:
|
| 1619 |
+
models.append({
|
| 1620 |
+
"name": primary_model,
|
| 1621 |
+
"is_default": True,
|
| 1622 |
+
})
|
| 1623 |
+
|
| 1624 |
+
# Secondary model
|
| 1625 |
+
if getattr(app.state, "_initialized2", False) and app.state._config_path2:
|
| 1626 |
+
secondary_model = _get_model_name(app.state._config_path2)
|
| 1627 |
+
if secondary_model:
|
| 1628 |
+
models.append({
|
| 1629 |
+
"name": secondary_model,
|
| 1630 |
+
"is_default": False,
|
| 1631 |
+
})
|
| 1632 |
+
|
| 1633 |
+
# Third model
|
| 1634 |
+
if getattr(app.state, "_initialized3", False) and app.state._config_path3:
|
| 1635 |
+
third_model = _get_model_name(app.state._config_path3)
|
| 1636 |
+
if third_model:
|
| 1637 |
+
models.append({
|
| 1638 |
+
"name": third_model,
|
| 1639 |
+
"is_default": False,
|
| 1640 |
+
})
|
| 1641 |
+
|
| 1642 |
+
return {
|
| 1643 |
+
"models": models,
|
| 1644 |
+
"default_model": models[0]["name"] if models else None,
|
| 1645 |
+
}
|
| 1646 |
+
|
| 1647 |
+
@app.get("/v1/audio")
|
| 1648 |
+
async def get_audio(path: str):
|
| 1649 |
+
"""Serve audio file by path."""
|
| 1650 |
+
from fastapi.responses import FileResponse
|
| 1651 |
+
|
| 1652 |
+
if not os.path.exists(path):
|
| 1653 |
+
raise HTTPException(status_code=404, detail=f"Audio file not found: {path}")
|
| 1654 |
+
|
| 1655 |
+
ext = os.path.splitext(path)[1].lower()
|
| 1656 |
+
media_types = {
|
| 1657 |
+
".mp3": "audio/mpeg",
|
| 1658 |
+
".wav": "audio/wav",
|
| 1659 |
+
".flac": "audio/flac",
|
| 1660 |
+
".ogg": "audio/ogg",
|
| 1661 |
+
}
|
| 1662 |
+
media_type = media_types.get(ext, "audio/mpeg")
|
| 1663 |
+
|
| 1664 |
+
return FileResponse(path, media_type=media_type)
|
| 1665 |
+
|
| 1666 |
+
return app
|
| 1667 |
+
|
| 1668 |
+
|
| 1669 |
+
app = create_app()
|
| 1670 |
+
|
| 1671 |
+
|
| 1672 |
+
def main() -> None:
|
| 1673 |
+
import argparse
|
| 1674 |
+
import uvicorn
|
| 1675 |
+
|
| 1676 |
+
parser = argparse.ArgumentParser(description="ACE-Step API server")
|
| 1677 |
+
parser.add_argument(
|
| 1678 |
+
"--host",
|
| 1679 |
+
default=os.getenv("ACESTEP_API_HOST", "127.0.0.1"),
|
| 1680 |
+
help="Bind host (default from ACESTEP_API_HOST or 127.0.0.1)",
|
| 1681 |
+
)
|
| 1682 |
+
parser.add_argument(
|
| 1683 |
+
"--port",
|
| 1684 |
+
type=int,
|
| 1685 |
+
default=int(os.getenv("ACESTEP_API_PORT", "8001")),
|
| 1686 |
+
help="Bind port (default from ACESTEP_API_PORT or 8001)",
|
| 1687 |
+
)
|
| 1688 |
+
args = parser.parse_args()
|
| 1689 |
+
|
| 1690 |
+
# IMPORTANT: in-memory queue/store -> workers MUST be 1
|
| 1691 |
+
uvicorn.run(
|
| 1692 |
+
"acestep.api_server:app",
|
| 1693 |
+
host=str(args.host),
|
| 1694 |
+
port=int(args.port),
|
| 1695 |
+
reload=False,
|
| 1696 |
+
workers=1,
|
| 1697 |
+
)
|
| 1698 |
+
|
| 1699 |
+
if __name__ == "__main__":
|
| 1700 |
+
main()
|
spaces/Ace-Step-v1.5/acestep/audio_utils.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Audio saving and transcoding utility module
|
| 3 |
+
|
| 4 |
+
Independent audio file operations outside of handler, supporting:
|
| 5 |
+
- Save audio tensor/numpy to files (default FLAC format, fast)
|
| 6 |
+
- Format conversion (FLAC/WAV/MP3)
|
| 7 |
+
- Batch processing
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
# Disable torchcodec backend to avoid CUDA dependency issues on HuggingFace Space
|
| 13 |
+
# This forces torchaudio to use ffmpeg/sox/soundfile backends instead
|
| 14 |
+
os.environ["TORCHAUDIO_USE_TORCHCODEC"] = "0"
|
| 15 |
+
|
| 16 |
+
import hashlib
|
| 17 |
+
import json
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Union, Optional, List, Tuple
|
| 20 |
+
import torch
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torchaudio
|
| 23 |
+
from loguru import logger
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class AudioSaver:
|
| 27 |
+
"""Audio saving and transcoding utility class"""
|
| 28 |
+
|
| 29 |
+
def __init__(self, default_format: str = "flac"):
|
| 30 |
+
"""
|
| 31 |
+
Initialize audio saver
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
default_format: Default save format ('flac', 'wav', 'mp3')
|
| 35 |
+
"""
|
| 36 |
+
self.default_format = default_format.lower()
|
| 37 |
+
if self.default_format not in ["flac", "wav", "mp3"]:
|
| 38 |
+
logger.warning(f"Unsupported format {default_format}, using 'flac'")
|
| 39 |
+
self.default_format = "flac"
|
| 40 |
+
|
| 41 |
+
def save_audio(
|
| 42 |
+
self,
|
| 43 |
+
audio_data: Union[torch.Tensor, np.ndarray],
|
| 44 |
+
output_path: Union[str, Path],
|
| 45 |
+
sample_rate: int = 48000,
|
| 46 |
+
format: Optional[str] = None,
|
| 47 |
+
channels_first: bool = True,
|
| 48 |
+
) -> str:
|
| 49 |
+
"""
|
| 50 |
+
Save audio data to file
|
| 51 |
+
|
| 52 |
+
Args:
|
| 53 |
+
audio_data: Audio data, torch.Tensor [channels, samples] or numpy.ndarray
|
| 54 |
+
output_path: Output file path (extension can be omitted)
|
| 55 |
+
sample_rate: Sample rate
|
| 56 |
+
format: Audio format ('flac', 'wav', 'mp3'), defaults to default_format
|
| 57 |
+
channels_first: If True, tensor format is [channels, samples], else [samples, channels]
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
Actual saved file path
|
| 61 |
+
"""
|
| 62 |
+
format = (format or self.default_format).lower()
|
| 63 |
+
if format not in ["flac", "wav", "mp3"]:
|
| 64 |
+
logger.warning(f"Unsupported format {format}, using {self.default_format}")
|
| 65 |
+
format = self.default_format
|
| 66 |
+
|
| 67 |
+
# Ensure output path has correct extension
|
| 68 |
+
output_path = Path(output_path)
|
| 69 |
+
if output_path.suffix.lower() not in ['.flac', '.wav', '.mp3']:
|
| 70 |
+
output_path = output_path.with_suffix(f'.{format}')
|
| 71 |
+
|
| 72 |
+
# Convert to torch tensor
|
| 73 |
+
if isinstance(audio_data, np.ndarray):
|
| 74 |
+
if channels_first:
|
| 75 |
+
# numpy [samples, channels] -> tensor [channels, samples]
|
| 76 |
+
audio_tensor = torch.from_numpy(audio_data.T).float()
|
| 77 |
+
else:
|
| 78 |
+
# numpy [samples, channels] -> tensor [samples, channels] -> [channels, samples]
|
| 79 |
+
audio_tensor = torch.from_numpy(audio_data).float()
|
| 80 |
+
if audio_tensor.dim() == 2 and audio_tensor.shape[0] < audio_tensor.shape[1]:
|
| 81 |
+
audio_tensor = audio_tensor.T
|
| 82 |
+
else:
|
| 83 |
+
# torch tensor
|
| 84 |
+
audio_tensor = audio_data.cpu().float()
|
| 85 |
+
if not channels_first and audio_tensor.dim() == 2:
|
| 86 |
+
# [samples, channels] -> [channels, samples]
|
| 87 |
+
if audio_tensor.shape[0] > audio_tensor.shape[1]:
|
| 88 |
+
audio_tensor = audio_tensor.T
|
| 89 |
+
|
| 90 |
+
# Ensure memory is contiguous
|
| 91 |
+
audio_tensor = audio_tensor.contiguous()
|
| 92 |
+
|
| 93 |
+
# Select backend and save
|
| 94 |
+
try:
|
| 95 |
+
if format == "mp3":
|
| 96 |
+
# MP3 uses ffmpeg backend
|
| 97 |
+
torchaudio.save(
|
| 98 |
+
str(output_path),
|
| 99 |
+
audio_tensor,
|
| 100 |
+
sample_rate,
|
| 101 |
+
channels_first=True,
|
| 102 |
+
backend='ffmpeg',
|
| 103 |
+
)
|
| 104 |
+
elif format in ["flac", "wav"]:
|
| 105 |
+
# FLAC and WAV use soundfile backend (fastest)
|
| 106 |
+
torchaudio.save(
|
| 107 |
+
str(output_path),
|
| 108 |
+
audio_tensor,
|
| 109 |
+
sample_rate,
|
| 110 |
+
channels_first=True,
|
| 111 |
+
backend='soundfile',
|
| 112 |
+
)
|
| 113 |
+
else:
|
| 114 |
+
# Other formats use default backend
|
| 115 |
+
torchaudio.save(
|
| 116 |
+
str(output_path),
|
| 117 |
+
audio_tensor,
|
| 118 |
+
sample_rate,
|
| 119 |
+
channels_first=True,
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
|
| 123 |
+
return str(output_path)
|
| 124 |
+
|
| 125 |
+
except Exception as e:
|
| 126 |
+
try:
|
| 127 |
+
import soundfile as sf
|
| 128 |
+
audio_np = audio_tensor.transpose(0, 1).numpy() # -> [samples, channels]
|
| 129 |
+
sf.write(str(output_path), audio_np, sample_rate, format=format.upper())
|
| 130 |
+
logger.debug(f"[AudioSaver] Fallback soundfile Saved audio to {output_path} ({format}, {sample_rate}Hz)")
|
| 131 |
+
return str(output_path)
|
| 132 |
+
except Exception as e:
|
| 133 |
+
logger.error(f"[AudioSaver] Failed to save audio: {e}")
|
| 134 |
+
raise
|
| 135 |
+
|
| 136 |
+
def _load_audio_file(self, audio_file: Union[str, Path]) -> Tuple[torch.Tensor, int]:
|
| 137 |
+
"""
|
| 138 |
+
Load audio file with ffmpeg backend, fallback to soundfile if failed.
|
| 139 |
+
|
| 140 |
+
This handles CUDA dependency issues with torchcodec on HuggingFace Space.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
audio_file: Path to the audio file
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
Tuple of (audio_tensor, sample_rate)
|
| 147 |
+
|
| 148 |
+
Raises:
|
| 149 |
+
FileNotFoundError: If the audio file doesn't exist
|
| 150 |
+
Exception: If all methods fail to load the audio
|
| 151 |
+
"""
|
| 152 |
+
audio_file = str(audio_file)
|
| 153 |
+
|
| 154 |
+
# Check if file exists first
|
| 155 |
+
if not Path(audio_file).exists():
|
| 156 |
+
raise FileNotFoundError(f"Audio file not found: {audio_file}")
|
| 157 |
+
|
| 158 |
+
# Try torchaudio with explicit ffmpeg backend first
|
| 159 |
+
try:
|
| 160 |
+
audio, sr = torchaudio.load(audio_file, backend="ffmpeg")
|
| 161 |
+
return audio, sr
|
| 162 |
+
except Exception as e:
|
| 163 |
+
logger.debug(f"[AudioSaver._load_audio_file] ffmpeg backend failed: {e}, trying soundfile fallback")
|
| 164 |
+
|
| 165 |
+
# Fallback: use soundfile directly (most compatible)
|
| 166 |
+
try:
|
| 167 |
+
import soundfile as sf
|
| 168 |
+
audio_np, sr = sf.read(audio_file)
|
| 169 |
+
# soundfile returns [samples, channels] or [samples], convert to [channels, samples]
|
| 170 |
+
audio = torch.from_numpy(audio_np).float()
|
| 171 |
+
if audio.dim() == 1:
|
| 172 |
+
# Mono: [samples] -> [1, samples]
|
| 173 |
+
audio = audio.unsqueeze(0)
|
| 174 |
+
else:
|
| 175 |
+
# Stereo: [samples, channels] -> [channels, samples]
|
| 176 |
+
audio = audio.T
|
| 177 |
+
return audio, sr
|
| 178 |
+
except Exception as e:
|
| 179 |
+
logger.error(f"[AudioSaver._load_audio_file] All methods failed to load audio: {audio_file}, error: {e}")
|
| 180 |
+
raise
|
| 181 |
+
|
| 182 |
+
def convert_audio(
|
| 183 |
+
self,
|
| 184 |
+
input_path: Union[str, Path],
|
| 185 |
+
output_path: Union[str, Path],
|
| 186 |
+
output_format: str,
|
| 187 |
+
remove_input: bool = False,
|
| 188 |
+
) -> str:
|
| 189 |
+
"""
|
| 190 |
+
Convert audio format
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
input_path: Input audio file path
|
| 194 |
+
output_path: Output audio file path
|
| 195 |
+
output_format: Target format ('flac', 'wav', 'mp3')
|
| 196 |
+
remove_input: Whether to delete input file
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
Output file path
|
| 200 |
+
"""
|
| 201 |
+
input_path = Path(input_path)
|
| 202 |
+
output_path = Path(output_path)
|
| 203 |
+
|
| 204 |
+
if not input_path.exists():
|
| 205 |
+
raise FileNotFoundError(f"Input file not found: {input_path}")
|
| 206 |
+
|
| 207 |
+
# Load audio with fallback backends
|
| 208 |
+
audio_tensor, sample_rate = self._load_audio_file(input_path)
|
| 209 |
+
|
| 210 |
+
# Save as new format
|
| 211 |
+
output_path = self.save_audio(
|
| 212 |
+
audio_tensor,
|
| 213 |
+
output_path,
|
| 214 |
+
sample_rate=sample_rate,
|
| 215 |
+
format=output_format,
|
| 216 |
+
channels_first=True
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
# Delete input file if needed
|
| 220 |
+
if remove_input:
|
| 221 |
+
input_path.unlink()
|
| 222 |
+
logger.debug(f"[AudioSaver] Removed input file: {input_path}")
|
| 223 |
+
|
| 224 |
+
return output_path
|
| 225 |
+
|
| 226 |
+
def save_batch(
|
| 227 |
+
self,
|
| 228 |
+
audio_batch: Union[List[torch.Tensor], torch.Tensor],
|
| 229 |
+
output_dir: Union[str, Path],
|
| 230 |
+
file_prefix: str = "audio",
|
| 231 |
+
sample_rate: int = 48000,
|
| 232 |
+
format: Optional[str] = None,
|
| 233 |
+
channels_first: bool = True,
|
| 234 |
+
) -> List[str]:
|
| 235 |
+
"""
|
| 236 |
+
Save audio batch
|
| 237 |
+
|
| 238 |
+
Args:
|
| 239 |
+
audio_batch: Audio batch, List[tensor] or tensor [batch, channels, samples]
|
| 240 |
+
output_dir: Output directory
|
| 241 |
+
file_prefix: File prefix
|
| 242 |
+
sample_rate: Sample rate
|
| 243 |
+
format: Audio format
|
| 244 |
+
channels_first: Tensor format flag
|
| 245 |
+
|
| 246 |
+
Returns:
|
| 247 |
+
List of saved file paths
|
| 248 |
+
"""
|
| 249 |
+
output_dir = Path(output_dir)
|
| 250 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 251 |
+
|
| 252 |
+
# Process batch
|
| 253 |
+
if isinstance(audio_batch, torch.Tensor) and audio_batch.dim() == 3:
|
| 254 |
+
# [batch, channels, samples]
|
| 255 |
+
audio_list = [audio_batch[i] for i in range(audio_batch.shape[0])]
|
| 256 |
+
elif isinstance(audio_batch, list):
|
| 257 |
+
audio_list = audio_batch
|
| 258 |
+
else:
|
| 259 |
+
audio_list = [audio_batch]
|
| 260 |
+
|
| 261 |
+
saved_paths = []
|
| 262 |
+
for i, audio in enumerate(audio_list):
|
| 263 |
+
output_path = output_dir / f"{file_prefix}_{i:04d}"
|
| 264 |
+
saved_path = self.save_audio(
|
| 265 |
+
audio,
|
| 266 |
+
output_path,
|
| 267 |
+
sample_rate=sample_rate,
|
| 268 |
+
format=format,
|
| 269 |
+
channels_first=channels_first
|
| 270 |
+
)
|
| 271 |
+
saved_paths.append(saved_path)
|
| 272 |
+
|
| 273 |
+
return saved_paths
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def get_audio_file_hash(audio_file) -> str:
|
| 277 |
+
"""
|
| 278 |
+
Get hash identifier for an audio file.
|
| 279 |
+
|
| 280 |
+
Args:
|
| 281 |
+
audio_file: Path to audio file (str) or file-like object
|
| 282 |
+
|
| 283 |
+
Returns:
|
| 284 |
+
Hash string or empty string
|
| 285 |
+
"""
|
| 286 |
+
if audio_file is None:
|
| 287 |
+
return ""
|
| 288 |
+
|
| 289 |
+
try:
|
| 290 |
+
if isinstance(audio_file, str):
|
| 291 |
+
if os.path.exists(audio_file):
|
| 292 |
+
with open(audio_file, 'rb') as f:
|
| 293 |
+
return hashlib.md5(f.read()).hexdigest()
|
| 294 |
+
return hashlib.md5(audio_file.encode('utf-8')).hexdigest()
|
| 295 |
+
elif hasattr(audio_file, 'name'):
|
| 296 |
+
return hashlib.md5(str(audio_file.name).encode('utf-8')).hexdigest()
|
| 297 |
+
return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
|
| 298 |
+
except Exception:
|
| 299 |
+
return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def generate_uuid_from_params(params_dict) -> str:
|
| 303 |
+
"""
|
| 304 |
+
Generate deterministic UUID from generation parameters.
|
| 305 |
+
Same parameters will always generate the same UUID.
|
| 306 |
+
|
| 307 |
+
Args:
|
| 308 |
+
params_dict: Dictionary of parameters
|
| 309 |
+
|
| 310 |
+
Returns:
|
| 311 |
+
UUID string
|
| 312 |
+
"""
|
| 313 |
+
|
| 314 |
+
params_json = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
|
| 315 |
+
hash_obj = hashlib.sha256(params_json.encode('utf-8'))
|
| 316 |
+
hash_hex = hash_obj.hexdigest()
|
| 317 |
+
uuid_str = f"{hash_hex[0:8]}-{hash_hex[8:12]}-{hash_hex[12:16]}-{hash_hex[16:20]}-{hash_hex[20:32]}"
|
| 318 |
+
return uuid_str
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def generate_uuid_from_audio_data(
|
| 322 |
+
audio_data: Union[torch.Tensor, np.ndarray],
|
| 323 |
+
seed: Optional[int] = None
|
| 324 |
+
) -> str:
|
| 325 |
+
"""
|
| 326 |
+
Generate UUID from audio data (for caching/deduplication)
|
| 327 |
+
|
| 328 |
+
Args:
|
| 329 |
+
audio_data: Audio data
|
| 330 |
+
seed: Optional seed value
|
| 331 |
+
|
| 332 |
+
Returns:
|
| 333 |
+
UUID string
|
| 334 |
+
"""
|
| 335 |
+
if isinstance(audio_data, torch.Tensor):
|
| 336 |
+
# Convert to numpy and calculate hash
|
| 337 |
+
audio_np = audio_data.cpu().numpy()
|
| 338 |
+
else:
|
| 339 |
+
audio_np = audio_data
|
| 340 |
+
|
| 341 |
+
# Calculate data hash
|
| 342 |
+
data_hash = hashlib.md5(audio_np.tobytes()).hexdigest()
|
| 343 |
+
|
| 344 |
+
if seed is not None:
|
| 345 |
+
combined = f"{data_hash}_{seed}"
|
| 346 |
+
return hashlib.md5(combined.encode()).hexdigest()
|
| 347 |
+
|
| 348 |
+
return data_hash
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
# Global default instance
|
| 352 |
+
_default_saver = AudioSaver(default_format="flac")
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def save_audio(
|
| 356 |
+
audio_data: Union[torch.Tensor, np.ndarray],
|
| 357 |
+
output_path: Union[str, Path],
|
| 358 |
+
sample_rate: int = 48000,
|
| 359 |
+
format: Optional[str] = None,
|
| 360 |
+
channels_first: bool = True,
|
| 361 |
+
) -> str:
|
| 362 |
+
"""
|
| 363 |
+
Convenience function: save audio (using default configuration)
|
| 364 |
+
|
| 365 |
+
Args:
|
| 366 |
+
audio_data: Audio data
|
| 367 |
+
output_path: Output path
|
| 368 |
+
sample_rate: Sample rate
|
| 369 |
+
format: Format (default flac)
|
| 370 |
+
channels_first: Tensor format flag
|
| 371 |
+
|
| 372 |
+
Returns:
|
| 373 |
+
Saved file path
|
| 374 |
+
"""
|
| 375 |
+
return _default_saver.save_audio(
|
| 376 |
+
audio_data, output_path, sample_rate, format, channels_first
|
| 377 |
+
)
|
| 378 |
+
|
spaces/Ace-Step-v1.5/acestep/constants.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Constants for ACE-Step
|
| 3 |
+
Centralized constants used across the codebase
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# ==============================================================================
|
| 7 |
+
# Language Constants
|
| 8 |
+
# ==============================================================================
|
| 9 |
+
|
| 10 |
+
VALID_LANGUAGES = [
|
| 11 |
+
'ar', 'az', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en',
|
| 12 |
+
'es', 'fa', 'fi', 'fr', 'he', 'hi', 'hr', 'ht', 'hu', 'id',
|
| 13 |
+
'is', 'it', 'ja', 'ko', 'la', 'lt', 'ms', 'ne', 'nl', 'no',
|
| 14 |
+
'pa', 'pl', 'pt', 'ro', 'ru', 'sa', 'sk', 'sr', 'sv', 'sw',
|
| 15 |
+
'ta', 'te', 'th', 'tl', 'tr', 'uk', 'ur', 'vi', 'yue', 'zh',
|
| 16 |
+
'unknown'
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ==============================================================================
|
| 21 |
+
# Keyscale Constants
|
| 22 |
+
# ==============================================================================
|
| 23 |
+
|
| 24 |
+
KEYSCALE_NOTES = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
|
| 25 |
+
KEYSCALE_ACCIDENTALS = ['', '#', 'b', '♯', '♭'] # empty + ASCII sharp/flat + Unicode sharp/flat
|
| 26 |
+
KEYSCALE_MODES = ['major', 'minor']
|
| 27 |
+
|
| 28 |
+
# Generate all valid keyscales: 7 notes × 5 accidentals × 2 modes = 70 combinations
|
| 29 |
+
VALID_KEYSCALES = set()
|
| 30 |
+
for note in KEYSCALE_NOTES:
|
| 31 |
+
for acc in KEYSCALE_ACCIDENTALS:
|
| 32 |
+
for mode in KEYSCALE_MODES:
|
| 33 |
+
VALID_KEYSCALES.add(f"{note}{acc} {mode}")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ==============================================================================
|
| 37 |
+
# Metadata Range Constants
|
| 38 |
+
# ==============================================================================
|
| 39 |
+
|
| 40 |
+
# BPM (Beats Per Minute) range
|
| 41 |
+
BPM_MIN = 30
|
| 42 |
+
BPM_MAX = 300
|
| 43 |
+
|
| 44 |
+
# Duration range (in seconds)
|
| 45 |
+
DURATION_MIN = 10
|
| 46 |
+
DURATION_MAX = 600
|
| 47 |
+
|
| 48 |
+
# Valid time signatures
|
| 49 |
+
VALID_TIME_SIGNATURES = [2, 3, 4, 6]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ==============================================================================
|
| 53 |
+
# Task Type Constants
|
| 54 |
+
# ==============================================================================
|
| 55 |
+
|
| 56 |
+
TASK_TYPES = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
|
| 57 |
+
|
| 58 |
+
# Task types available for turbo models (subset)
|
| 59 |
+
TASK_TYPES_TURBO = ["text2music", "repaint", "cover"]
|
| 60 |
+
|
| 61 |
+
# Task types available for base models (full set)
|
| 62 |
+
TASK_TYPES_BASE = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# ==============================================================================
|
| 66 |
+
# Instruction Constants
|
| 67 |
+
# ==============================================================================
|
| 68 |
+
|
| 69 |
+
# Default instructions
|
| 70 |
+
DEFAULT_DIT_INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
|
| 71 |
+
DEFAULT_LM_INSTRUCTION = "Generate audio semantic tokens based on the given conditions:"
|
| 72 |
+
DEFAULT_LM_UNDERSTAND_INSTRUCTION = "Understand the given musical conditions and describe the audio semantics accordingly:"
|
| 73 |
+
DEFAULT_LM_INSPIRED_INSTRUCTION = "Expand the user's input into a more detailed and specific musical description:"
|
| 74 |
+
DEFAULT_LM_REWRITE_INSTRUCTION = "Format the user's input into a more detailed and specific musical description:"
|
| 75 |
+
|
| 76 |
+
# Instruction templates for each task type
|
| 77 |
+
# Note: Some instructions use placeholders like {TRACK_NAME} or {TRACK_CLASSES}
|
| 78 |
+
# These should be formatted using .format() or f-strings when used
|
| 79 |
+
TASK_INSTRUCTIONS = {
|
| 80 |
+
"text2music": "Fill the audio semantic mask based on the given conditions:",
|
| 81 |
+
"repaint": "Repaint the mask area based on the given conditions:",
|
| 82 |
+
"cover": "Generate audio semantic tokens based on the given conditions:",
|
| 83 |
+
"extract": "Extract the {TRACK_NAME} track from the audio:",
|
| 84 |
+
"extract_default": "Extract the track from the audio:",
|
| 85 |
+
"lego": "Generate the {TRACK_NAME} track based on the audio context:",
|
| 86 |
+
"lego_default": "Generate the track based on the audio context:",
|
| 87 |
+
"complete": "Complete the input track with {TRACK_CLASSES}:",
|
| 88 |
+
"complete_default": "Complete the input track:",
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# ==============================================================================
|
| 93 |
+
# Track/Instrument Constants
|
| 94 |
+
# ==============================================================================
|
| 95 |
+
|
| 96 |
+
TRACK_NAMES = [
|
| 97 |
+
"woodwinds", "brass", "fx", "synth", "strings", "percussion",
|
| 98 |
+
"keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
SFT_GEN_PROMPT = """# Instruction
|
| 102 |
+
{}
|
| 103 |
+
|
| 104 |
+
# Caption
|
| 105 |
+
{}
|
| 106 |
+
|
| 107 |
+
# Metas
|
| 108 |
+
{}<|endoftext|>
|
| 109 |
+
"""
|
spaces/Ace-Step-v1.5/acestep/constrained_logits_processor.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
spaces/Ace-Step-v1.5/acestep/dataset_handler.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dataset Handler
|
| 3 |
+
Handles dataset import and exploration functionality
|
| 4 |
+
"""
|
| 5 |
+
from typing import Optional, Tuple, Any, Dict
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DatasetHandler:
|
| 9 |
+
"""Dataset Handler for Dataset Explorer functionality"""
|
| 10 |
+
|
| 11 |
+
def __init__(self):
|
| 12 |
+
"""Initialize dataset handler"""
|
| 13 |
+
self.dataset = None
|
| 14 |
+
self.dataset_imported = False
|
| 15 |
+
|
| 16 |
+
def import_dataset(self, dataset_type: str) -> str:
|
| 17 |
+
"""
|
| 18 |
+
Import dataset (temporarily disabled)
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
dataset_type: Type of dataset to import (e.g., "train", "test")
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
Status message string
|
| 25 |
+
"""
|
| 26 |
+
self.dataset_imported = False
|
| 27 |
+
return f"⚠️ Dataset import is currently disabled. Text2MusicDataset dependency not available."
|
| 28 |
+
|
| 29 |
+
def get_item_data(self, *args, **kwargs) -> Tuple:
|
| 30 |
+
"""
|
| 31 |
+
Get dataset item (temporarily disabled)
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
Tuple of placeholder values matching the expected return format
|
| 35 |
+
"""
|
| 36 |
+
return "", "", "", "", "", None, None, None, "❌ Dataset not available", "", 0, "", None, None, None, {}, "text2music"
|
| 37 |
+
|
spaces/Ace-Step-v1.5/acestep/dit_alignment_score.py
ADDED
|
@@ -0,0 +1,870 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DiT Alignment Score Module
|
| 3 |
+
|
| 4 |
+
This module provides lyrics-to-audio alignment using cross-attention matrices
|
| 5 |
+
from DiT model for generating LRC timestamps.
|
| 6 |
+
|
| 7 |
+
Refactored from lyrics_alignment_infos.py for integration with ACE-Step.
|
| 8 |
+
"""
|
| 9 |
+
import numba
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from dataclasses import dataclass, asdict
|
| 14 |
+
from typing import List, Dict, Any, Optional, Tuple, Union
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ================= Data Classes =================
|
| 18 |
+
@dataclass
|
| 19 |
+
class TokenTimestamp:
|
| 20 |
+
"""Stores per-token timing information."""
|
| 21 |
+
token_id: int
|
| 22 |
+
text: str
|
| 23 |
+
start: float
|
| 24 |
+
end: float
|
| 25 |
+
probability: float
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
|
| 29 |
+
class SentenceTimestamp:
|
| 30 |
+
"""Stores per-sentence timing information with token list."""
|
| 31 |
+
text: str
|
| 32 |
+
start: float
|
| 33 |
+
end: float
|
| 34 |
+
tokens: List[TokenTimestamp]
|
| 35 |
+
confidence: float
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ================= DTW Algorithm (Numba Optimized) =================
|
| 39 |
+
@numba.jit(nopython=True)
|
| 40 |
+
def dtw_cpu(x: np.ndarray):
|
| 41 |
+
"""
|
| 42 |
+
Dynamic Time Warping algorithm optimized with Numba.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
x: Cost matrix of shape [N, M]
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
Tuple of (text_indices, time_indices) arrays
|
| 49 |
+
"""
|
| 50 |
+
N, M = x.shape
|
| 51 |
+
# Use float32 for memory efficiency
|
| 52 |
+
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
|
| 53 |
+
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
|
| 54 |
+
cost[0, 0] = 0
|
| 55 |
+
|
| 56 |
+
for j in range(1, M + 1):
|
| 57 |
+
for i in range(1, N + 1):
|
| 58 |
+
c0 = cost[i - 1, j - 1]
|
| 59 |
+
c1 = cost[i - 1, j]
|
| 60 |
+
c2 = cost[i, j - 1]
|
| 61 |
+
|
| 62 |
+
if c0 < c1 and c0 < c2:
|
| 63 |
+
c, t = c0, 0
|
| 64 |
+
elif c1 < c0 and c1 < c2:
|
| 65 |
+
c, t = c1, 1
|
| 66 |
+
else:
|
| 67 |
+
c, t = c2, 2
|
| 68 |
+
|
| 69 |
+
cost[i, j] = x[i - 1, j - 1] + c
|
| 70 |
+
trace[i, j] = t
|
| 71 |
+
|
| 72 |
+
return _backtrace(trace, N, M)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@numba.jit(nopython=True)
|
| 76 |
+
def _backtrace(trace: np.ndarray, N: int, M: int):
|
| 77 |
+
"""
|
| 78 |
+
Optimized backtrace function for DTW.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
trace: Trace matrix of shape (N+1, M+1)
|
| 82 |
+
N, M: Original matrix dimensions
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
Path array of shape (2, path_len) - first row is text indices, second is time indices
|
| 86 |
+
"""
|
| 87 |
+
# Boundary handling
|
| 88 |
+
trace[0, :] = 2
|
| 89 |
+
trace[:, 0] = 1
|
| 90 |
+
|
| 91 |
+
# Pre-allocate array, max path length is N+M
|
| 92 |
+
max_path_len = N + M
|
| 93 |
+
path = np.zeros((2, max_path_len), dtype=np.int32)
|
| 94 |
+
|
| 95 |
+
i, j = N, M
|
| 96 |
+
path_idx = max_path_len - 1
|
| 97 |
+
|
| 98 |
+
while i > 0 or j > 0:
|
| 99 |
+
path[0, path_idx] = i - 1 # text index
|
| 100 |
+
path[1, path_idx] = j - 1 # time index
|
| 101 |
+
path_idx -= 1
|
| 102 |
+
|
| 103 |
+
t = trace[i, j]
|
| 104 |
+
if t == 0:
|
| 105 |
+
i -= 1
|
| 106 |
+
j -= 1
|
| 107 |
+
elif t == 1:
|
| 108 |
+
i -= 1
|
| 109 |
+
elif t == 2:
|
| 110 |
+
j -= 1
|
| 111 |
+
else:
|
| 112 |
+
break
|
| 113 |
+
|
| 114 |
+
actual_len = max_path_len - path_idx - 1
|
| 115 |
+
return path[:, path_idx + 1:max_path_len]
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# ================= Utility Functions =================
|
| 119 |
+
def median_filter(x: torch.Tensor, filter_width: int) -> torch.Tensor:
|
| 120 |
+
"""
|
| 121 |
+
Apply median filter to tensor.
|
| 122 |
+
|
| 123 |
+
Args:
|
| 124 |
+
x: Input tensor
|
| 125 |
+
filter_width: Width of median filter
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
Filtered tensor
|
| 129 |
+
"""
|
| 130 |
+
pad_width = filter_width // 2
|
| 131 |
+
if x.shape[-1] <= pad_width:
|
| 132 |
+
return x
|
| 133 |
+
if x.ndim == 2:
|
| 134 |
+
x = x[None, :]
|
| 135 |
+
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
|
| 136 |
+
result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
|
| 137 |
+
if result.ndim > 2:
|
| 138 |
+
result = result.squeeze(0)
|
| 139 |
+
return result
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# ================= Main Aligner Class =================
|
| 143 |
+
class MusicStampsAligner:
|
| 144 |
+
"""
|
| 145 |
+
Aligner class for generating lyrics timestamps from cross-attention matrices.
|
| 146 |
+
|
| 147 |
+
Uses bidirectional consensus denoising and DTW for alignment.
|
| 148 |
+
"""
|
| 149 |
+
|
| 150 |
+
def __init__(self, tokenizer):
|
| 151 |
+
"""
|
| 152 |
+
Initialize the aligner.
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
tokenizer: Text tokenizer for decoding tokens
|
| 156 |
+
"""
|
| 157 |
+
self.tokenizer = tokenizer
|
| 158 |
+
|
| 159 |
+
def _apply_bidirectional_consensus(
|
| 160 |
+
self,
|
| 161 |
+
weights_stack: torch.Tensor,
|
| 162 |
+
violence_level: float,
|
| 163 |
+
medfilt_width: int
|
| 164 |
+
) -> tuple:
|
| 165 |
+
"""
|
| 166 |
+
Core denoising logic using bidirectional consensus.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
weights_stack: Attention weights [Heads, Tokens, Frames]
|
| 170 |
+
violence_level: Denoising strength coefficient
|
| 171 |
+
medfilt_width: Median filter width
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
Tuple of (calc_matrix, energy_matrix) as numpy arrays
|
| 175 |
+
"""
|
| 176 |
+
# A. Bidirectional Consensus
|
| 177 |
+
row_prob = F.softmax(weights_stack, dim=-1) # Token -> Frame
|
| 178 |
+
col_prob = F.softmax(weights_stack, dim=-2) # Frame -> Token
|
| 179 |
+
processed = row_prob * col_prob
|
| 180 |
+
|
| 181 |
+
# 1. Row suppression (kill horizontal crossing lines)
|
| 182 |
+
row_medians = torch.quantile(processed, 0.5, dim=-1, keepdim=True)
|
| 183 |
+
processed = processed - (violence_level * row_medians)
|
| 184 |
+
processed = torch.relu(processed)
|
| 185 |
+
|
| 186 |
+
# 2. Column suppression (kill vertical crossing lines)
|
| 187 |
+
col_medians = torch.quantile(processed, 0.5, dim=-2, keepdim=True)
|
| 188 |
+
processed = processed - (violence_level * col_medians)
|
| 189 |
+
processed = torch.relu(processed)
|
| 190 |
+
|
| 191 |
+
# C. Power sharpening
|
| 192 |
+
processed = processed ** 2
|
| 193 |
+
|
| 194 |
+
# Energy matrix for confidence
|
| 195 |
+
energy_matrix = processed.mean(dim=0).cpu().numpy()
|
| 196 |
+
|
| 197 |
+
# D. Z-Score normalization
|
| 198 |
+
std, mean = torch.std_mean(processed, unbiased=False)
|
| 199 |
+
weights_processed = (processed - mean) / (std + 1e-9)
|
| 200 |
+
|
| 201 |
+
# E. Median filtering
|
| 202 |
+
weights_processed = median_filter(weights_processed, filter_width=medfilt_width)
|
| 203 |
+
calc_matrix = weights_processed.mean(dim=0).numpy()
|
| 204 |
+
|
| 205 |
+
return calc_matrix, energy_matrix
|
| 206 |
+
|
| 207 |
+
def _preprocess_attention(
|
| 208 |
+
self,
|
| 209 |
+
attention_matrix: torch.Tensor,
|
| 210 |
+
custom_config: Dict[int, List[int]],
|
| 211 |
+
violence_level: float,
|
| 212 |
+
medfilt_width: int = 7
|
| 213 |
+
) -> tuple:
|
| 214 |
+
"""
|
| 215 |
+
Preprocess attention matrix for alignment.
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
attention_matrix: Attention tensor [Layers, Heads, Tokens, Frames]
|
| 219 |
+
custom_config: Dict mapping layer indices to head indices
|
| 220 |
+
violence_level: Denoising strength
|
| 221 |
+
medfilt_width: Median filter width
|
| 222 |
+
|
| 223 |
+
Returns:
|
| 224 |
+
Tuple of (calc_matrix, energy_matrix, visual_matrix)
|
| 225 |
+
"""
|
| 226 |
+
if not isinstance(attention_matrix, torch.Tensor):
|
| 227 |
+
weights = torch.tensor(attention_matrix)
|
| 228 |
+
else:
|
| 229 |
+
weights = attention_matrix.clone()
|
| 230 |
+
|
| 231 |
+
weights = weights.cpu().float()
|
| 232 |
+
|
| 233 |
+
selected_tensors = []
|
| 234 |
+
for layer_idx, head_indices in custom_config.items():
|
| 235 |
+
for head_idx in head_indices:
|
| 236 |
+
if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
|
| 237 |
+
head_matrix = weights[layer_idx, head_idx]
|
| 238 |
+
selected_tensors.append(head_matrix)
|
| 239 |
+
|
| 240 |
+
if not selected_tensors:
|
| 241 |
+
return None, None, None
|
| 242 |
+
|
| 243 |
+
# Stack selected heads: [Heads, Tokens, Frames]
|
| 244 |
+
weights_stack = torch.stack(selected_tensors, dim=0)
|
| 245 |
+
visual_matrix = weights_stack.mean(dim=0).numpy()
|
| 246 |
+
|
| 247 |
+
calc_matrix, energy_matrix = self._apply_bidirectional_consensus(
|
| 248 |
+
weights_stack, violence_level, medfilt_width
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
return calc_matrix, energy_matrix, visual_matrix
|
| 252 |
+
|
| 253 |
+
def stamps_align_info(
|
| 254 |
+
self,
|
| 255 |
+
attention_matrix: torch.Tensor,
|
| 256 |
+
lyrics_tokens: List[int],
|
| 257 |
+
total_duration_seconds: float,
|
| 258 |
+
custom_config: Dict[int, List[int]],
|
| 259 |
+
return_matrices: bool = False,
|
| 260 |
+
violence_level: float = 2.0,
|
| 261 |
+
medfilt_width: int = 1
|
| 262 |
+
) -> Dict[str, Any]:
|
| 263 |
+
"""
|
| 264 |
+
Get alignment information from attention matrix.
|
| 265 |
+
|
| 266 |
+
Args:
|
| 267 |
+
attention_matrix: Cross-attention tensor [Layers, Heads, Tokens, Frames]
|
| 268 |
+
lyrics_tokens: List of lyrics token IDs
|
| 269 |
+
total_duration_seconds: Total audio duration in seconds
|
| 270 |
+
custom_config: Dict mapping layer indices to head indices
|
| 271 |
+
return_matrices: Whether to return intermediate matrices
|
| 272 |
+
violence_level: Denoising strength
|
| 273 |
+
medfilt_width: Median filter width
|
| 274 |
+
|
| 275 |
+
Returns:
|
| 276 |
+
Dict containing calc_matrix, lyrics_tokens, total_duration_seconds,
|
| 277 |
+
and optionally energy_matrix and vis_matrix
|
| 278 |
+
"""
|
| 279 |
+
calc_matrix, energy_matrix, visual_matrix = self._preprocess_attention(
|
| 280 |
+
attention_matrix, custom_config, violence_level, medfilt_width
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
if calc_matrix is None:
|
| 284 |
+
return {
|
| 285 |
+
"calc_matrix": None,
|
| 286 |
+
"lyrics_tokens": lyrics_tokens,
|
| 287 |
+
"total_duration_seconds": total_duration_seconds,
|
| 288 |
+
"error": "No valid attention heads found"
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
return_dict = {
|
| 292 |
+
"calc_matrix": calc_matrix,
|
| 293 |
+
"lyrics_tokens": lyrics_tokens,
|
| 294 |
+
"total_duration_seconds": total_duration_seconds
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
if return_matrices:
|
| 298 |
+
return_dict['energy_matrix'] = energy_matrix
|
| 299 |
+
return_dict['vis_matrix'] = visual_matrix
|
| 300 |
+
|
| 301 |
+
return return_dict
|
| 302 |
+
|
| 303 |
+
def _decode_tokens_incrementally(self, token_ids: List[int]) -> List[str]:
|
| 304 |
+
"""
|
| 305 |
+
Decode tokens incrementally to properly handle multi-byte UTF-8 characters.
|
| 306 |
+
|
| 307 |
+
For Chinese and other multi-byte characters, the tokenizer may split them
|
| 308 |
+
into multiple byte-level tokens. Decoding each token individually produces
|
| 309 |
+
invalid UTF-8 sequences (showing as �). This method uses byte-level comparison
|
| 310 |
+
to correctly track which characters each token contributes.
|
| 311 |
+
|
| 312 |
+
Args:
|
| 313 |
+
token_ids: List of token IDs
|
| 314 |
+
|
| 315 |
+
Returns:
|
| 316 |
+
List of decoded text for each token position
|
| 317 |
+
"""
|
| 318 |
+
decoded_tokens = []
|
| 319 |
+
prev_bytes = b""
|
| 320 |
+
|
| 321 |
+
for i in range(len(token_ids)):
|
| 322 |
+
# Decode tokens from start to current position
|
| 323 |
+
current_text = self.tokenizer.decode(token_ids[:i+1], skip_special_tokens=False)
|
| 324 |
+
current_bytes = current_text.encode('utf-8', errors='surrogatepass')
|
| 325 |
+
|
| 326 |
+
# The contribution of current token is the new bytes added
|
| 327 |
+
if len(current_bytes) >= len(prev_bytes):
|
| 328 |
+
new_bytes = current_bytes[len(prev_bytes):]
|
| 329 |
+
# Try to decode the new bytes; if incomplete, use empty string
|
| 330 |
+
try:
|
| 331 |
+
token_text = new_bytes.decode('utf-8')
|
| 332 |
+
except UnicodeDecodeError:
|
| 333 |
+
# Incomplete UTF-8 sequence, this token doesn't complete a character
|
| 334 |
+
token_text = ""
|
| 335 |
+
else:
|
| 336 |
+
# Edge case: current decode is shorter (shouldn't happen normally)
|
| 337 |
+
token_text = ""
|
| 338 |
+
|
| 339 |
+
decoded_tokens.append(token_text)
|
| 340 |
+
prev_bytes = current_bytes
|
| 341 |
+
|
| 342 |
+
return decoded_tokens
|
| 343 |
+
|
| 344 |
+
def token_timestamps(
|
| 345 |
+
self,
|
| 346 |
+
calc_matrix: np.ndarray,
|
| 347 |
+
lyrics_tokens: List[int],
|
| 348 |
+
total_duration_seconds: float
|
| 349 |
+
) -> List[TokenTimestamp]:
|
| 350 |
+
"""
|
| 351 |
+
Generate per-token timestamps using DTW.
|
| 352 |
+
|
| 353 |
+
Args:
|
| 354 |
+
calc_matrix: Processed attention matrix [Tokens, Frames]
|
| 355 |
+
lyrics_tokens: List of token IDs
|
| 356 |
+
total_duration_seconds: Total audio duration
|
| 357 |
+
|
| 358 |
+
Returns:
|
| 359 |
+
List of TokenTimestamp objects
|
| 360 |
+
"""
|
| 361 |
+
n_frames = calc_matrix.shape[-1]
|
| 362 |
+
text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float64))
|
| 363 |
+
|
| 364 |
+
seconds_per_frame = total_duration_seconds / n_frames
|
| 365 |
+
alignment_results = []
|
| 366 |
+
|
| 367 |
+
# Use incremental decoding to properly handle multi-byte UTF-8 characters
|
| 368 |
+
decoded_tokens = self._decode_tokens_incrementally(lyrics_tokens)
|
| 369 |
+
|
| 370 |
+
for i in range(len(lyrics_tokens)):
|
| 371 |
+
mask = (text_indices == i)
|
| 372 |
+
|
| 373 |
+
if not np.any(mask):
|
| 374 |
+
start = alignment_results[-1].end if alignment_results else 0.0
|
| 375 |
+
end = start
|
| 376 |
+
token_conf = 0.0
|
| 377 |
+
else:
|
| 378 |
+
times = time_indices[mask] * seconds_per_frame
|
| 379 |
+
start = times[0]
|
| 380 |
+
end = times[-1]
|
| 381 |
+
token_conf = 0.0
|
| 382 |
+
|
| 383 |
+
if end < start:
|
| 384 |
+
end = start
|
| 385 |
+
|
| 386 |
+
alignment_results.append(TokenTimestamp(
|
| 387 |
+
token_id=lyrics_tokens[i],
|
| 388 |
+
text=decoded_tokens[i],
|
| 389 |
+
start=float(start),
|
| 390 |
+
end=float(end),
|
| 391 |
+
probability=token_conf
|
| 392 |
+
))
|
| 393 |
+
|
| 394 |
+
return alignment_results
|
| 395 |
+
|
| 396 |
+
def _decode_sentence_from_tokens(self, tokens: List[TokenTimestamp]) -> str:
|
| 397 |
+
"""
|
| 398 |
+
Decode a sentence by decoding all token IDs together.
|
| 399 |
+
This avoids UTF-8 encoding issues from joining individual token texts.
|
| 400 |
+
|
| 401 |
+
Args:
|
| 402 |
+
tokens: List of TokenTimestamp objects
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
Properly decoded sentence text
|
| 406 |
+
"""
|
| 407 |
+
token_ids = [t.token_id for t in tokens]
|
| 408 |
+
return self.tokenizer.decode(token_ids, skip_special_tokens=False)
|
| 409 |
+
|
| 410 |
+
def sentence_timestamps(
|
| 411 |
+
self,
|
| 412 |
+
token_alignment: List[TokenTimestamp]
|
| 413 |
+
) -> List[SentenceTimestamp]:
|
| 414 |
+
"""
|
| 415 |
+
Group token timestamps into sentence timestamps.
|
| 416 |
+
|
| 417 |
+
Args:
|
| 418 |
+
token_alignment: List of TokenTimestamp objects
|
| 419 |
+
|
| 420 |
+
Returns:
|
| 421 |
+
List of SentenceTimestamp objects
|
| 422 |
+
"""
|
| 423 |
+
results = []
|
| 424 |
+
current_tokens = []
|
| 425 |
+
|
| 426 |
+
for token in token_alignment:
|
| 427 |
+
current_tokens.append(token)
|
| 428 |
+
|
| 429 |
+
if '\n' in token.text:
|
| 430 |
+
# Decode all token IDs together to avoid UTF-8 issues
|
| 431 |
+
full_text = self._decode_sentence_from_tokens(current_tokens)
|
| 432 |
+
|
| 433 |
+
if full_text.strip():
|
| 434 |
+
valid_scores = [t.probability for t in current_tokens if t.probability > 0]
|
| 435 |
+
sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
|
| 436 |
+
|
| 437 |
+
results.append(SentenceTimestamp(
|
| 438 |
+
text=full_text.strip(),
|
| 439 |
+
start=round(current_tokens[0].start, 3),
|
| 440 |
+
end=round(current_tokens[-1].end, 3),
|
| 441 |
+
tokens=list(current_tokens),
|
| 442 |
+
confidence=sent_conf
|
| 443 |
+
))
|
| 444 |
+
|
| 445 |
+
current_tokens = []
|
| 446 |
+
|
| 447 |
+
# Handle last sentence
|
| 448 |
+
if current_tokens:
|
| 449 |
+
# Decode all token IDs together to avoid UTF-8 issues
|
| 450 |
+
full_text = self._decode_sentence_from_tokens(current_tokens)
|
| 451 |
+
if full_text.strip():
|
| 452 |
+
valid_scores = [t.probability for t in current_tokens if t.probability > 0]
|
| 453 |
+
sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
|
| 454 |
+
|
| 455 |
+
results.append(SentenceTimestamp(
|
| 456 |
+
text=full_text.strip(),
|
| 457 |
+
start=round(current_tokens[0].start, 3),
|
| 458 |
+
end=round(current_tokens[-1].end, 3),
|
| 459 |
+
tokens=list(current_tokens),
|
| 460 |
+
confidence=sent_conf
|
| 461 |
+
))
|
| 462 |
+
|
| 463 |
+
# Normalize confidence scores
|
| 464 |
+
if results:
|
| 465 |
+
all_scores = [s.confidence for s in results]
|
| 466 |
+
min_score = min(all_scores)
|
| 467 |
+
max_score = max(all_scores)
|
| 468 |
+
score_range = max_score - min_score
|
| 469 |
+
|
| 470 |
+
if score_range > 1e-9:
|
| 471 |
+
for s in results:
|
| 472 |
+
normalized_score = (s.confidence - min_score) / score_range
|
| 473 |
+
s.confidence = round(normalized_score, 2)
|
| 474 |
+
else:
|
| 475 |
+
for s in results:
|
| 476 |
+
s.confidence = round(s.confidence, 2)
|
| 477 |
+
|
| 478 |
+
return results
|
| 479 |
+
|
| 480 |
+
def format_lrc(
|
| 481 |
+
self,
|
| 482 |
+
sentence_timestamps: List[SentenceTimestamp],
|
| 483 |
+
include_end_time: bool = False
|
| 484 |
+
) -> str:
|
| 485 |
+
"""
|
| 486 |
+
Format sentence timestamps as LRC lyrics format.
|
| 487 |
+
|
| 488 |
+
Args:
|
| 489 |
+
sentence_timestamps: List of SentenceTimestamp objects
|
| 490 |
+
include_end_time: Whether to include end time (enhanced LRC format)
|
| 491 |
+
|
| 492 |
+
Returns:
|
| 493 |
+
LRC formatted string
|
| 494 |
+
"""
|
| 495 |
+
lines = []
|
| 496 |
+
|
| 497 |
+
for sentence in sentence_timestamps:
|
| 498 |
+
# Convert seconds to mm:ss.xx format
|
| 499 |
+
start_minutes = int(sentence.start // 60)
|
| 500 |
+
start_seconds = sentence.start % 60
|
| 501 |
+
|
| 502 |
+
if include_end_time:
|
| 503 |
+
end_minutes = int(sentence.end // 60)
|
| 504 |
+
end_seconds = sentence.end % 60
|
| 505 |
+
timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}][{end_minutes:02d}:{end_seconds:05.2f}]"
|
| 506 |
+
else:
|
| 507 |
+
timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}]"
|
| 508 |
+
|
| 509 |
+
# Clean the text (remove structural tags like [verse], [chorus])
|
| 510 |
+
text = sentence.text
|
| 511 |
+
|
| 512 |
+
lines.append(f"{timestamp}{text}")
|
| 513 |
+
|
| 514 |
+
return "\n".join(lines)
|
| 515 |
+
|
| 516 |
+
def get_timestamps_and_lrc(
|
| 517 |
+
self,
|
| 518 |
+
calc_matrix: np.ndarray,
|
| 519 |
+
lyrics_tokens: List[int],
|
| 520 |
+
total_duration_seconds: float
|
| 521 |
+
) -> Dict[str, Any]:
|
| 522 |
+
"""
|
| 523 |
+
Convenience method to get both timestamps and LRC in one call.
|
| 524 |
+
|
| 525 |
+
Args:
|
| 526 |
+
calc_matrix: Processed attention matrix
|
| 527 |
+
lyrics_tokens: List of token IDs
|
| 528 |
+
total_duration_seconds: Total audio duration
|
| 529 |
+
|
| 530 |
+
Returns:
|
| 531 |
+
Dict containing token_timestamps, sentence_timestamps, and lrc_text
|
| 532 |
+
"""
|
| 533 |
+
token_stamps = self.token_timestamps(
|
| 534 |
+
calc_matrix=calc_matrix,
|
| 535 |
+
lyrics_tokens=lyrics_tokens,
|
| 536 |
+
total_duration_seconds=total_duration_seconds
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
sentence_stamps = self.sentence_timestamps(token_stamps)
|
| 540 |
+
lrc_text = self.format_lrc(sentence_stamps)
|
| 541 |
+
|
| 542 |
+
return {
|
| 543 |
+
"token_timestamps": token_stamps,
|
| 544 |
+
"sentence_timestamps": sentence_stamps,
|
| 545 |
+
"lrc_text": lrc_text
|
| 546 |
+
}
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
class MusicLyricScorer:
|
| 550 |
+
"""
|
| 551 |
+
Scorer class for evaluating lyrics-to-audio alignment quality.
|
| 552 |
+
|
| 553 |
+
Focuses on calculating alignment quality metrics (Coverage, Monotonicity, Confidence)
|
| 554 |
+
using tensor operations for potential differentiability or GPU acceleration.
|
| 555 |
+
"""
|
| 556 |
+
|
| 557 |
+
def __init__(self, tokenizer: Any):
|
| 558 |
+
"""
|
| 559 |
+
Initialize the aligner.
|
| 560 |
+
|
| 561 |
+
Args:
|
| 562 |
+
tokenizer: Tokenizer instance (must implement .decode()).
|
| 563 |
+
"""
|
| 564 |
+
self.tokenizer = tokenizer
|
| 565 |
+
|
| 566 |
+
def _generate_token_type_mask(self, token_ids: List[int]) -> np.ndarray:
|
| 567 |
+
"""
|
| 568 |
+
Generate a mask distinguishing lyrics (1) from structural tags (0).
|
| 569 |
+
Uses self.tokenizer to decode tokens.
|
| 570 |
+
|
| 571 |
+
Args:
|
| 572 |
+
token_ids: List of token IDs.
|
| 573 |
+
|
| 574 |
+
Returns:
|
| 575 |
+
Numpy array of shape [len(token_ids)] with 1 or 0.
|
| 576 |
+
"""
|
| 577 |
+
decoded_tokens = [self.tokenizer.decode([tid]) for tid in token_ids]
|
| 578 |
+
mask = np.ones(len(token_ids), dtype=np.int32)
|
| 579 |
+
in_bracket = False
|
| 580 |
+
|
| 581 |
+
for i, token_str in enumerate(decoded_tokens):
|
| 582 |
+
if '[' in token_str:
|
| 583 |
+
in_bracket = True
|
| 584 |
+
if in_bracket:
|
| 585 |
+
mask[i] = 0
|
| 586 |
+
if ']' in token_str:
|
| 587 |
+
in_bracket = False
|
| 588 |
+
mask[i] = 0
|
| 589 |
+
return mask
|
| 590 |
+
|
| 591 |
+
def _preprocess_attention(
|
| 592 |
+
self,
|
| 593 |
+
attention_matrix: Union[torch.Tensor, np.ndarray],
|
| 594 |
+
custom_config: Dict[int, List[int]],
|
| 595 |
+
medfilt_width: int = 1
|
| 596 |
+
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[torch.Tensor]]:
|
| 597 |
+
"""
|
| 598 |
+
Extracts and normalizes the attention matrix.
|
| 599 |
+
|
| 600 |
+
Logic V4: Uses Min-Max normalization to highlight energy differences.
|
| 601 |
+
|
| 602 |
+
Args:
|
| 603 |
+
attention_matrix: Raw attention tensor [Layers, Heads, Tokens, Frames].
|
| 604 |
+
custom_config: Config mapping layers to heads.
|
| 605 |
+
medfilt_width: Width for median filtering.
|
| 606 |
+
|
| 607 |
+
Returns:
|
| 608 |
+
Tuple of (calc_matrix, energy_matrix, avg_weights_tensor).
|
| 609 |
+
"""
|
| 610 |
+
# 1. Prepare Tensor
|
| 611 |
+
if not isinstance(attention_matrix, torch.Tensor):
|
| 612 |
+
weights = torch.tensor(attention_matrix)
|
| 613 |
+
else:
|
| 614 |
+
weights = attention_matrix.clone()
|
| 615 |
+
weights = weights.cpu().float()
|
| 616 |
+
|
| 617 |
+
# 2. Select Heads based on config
|
| 618 |
+
selected_tensors = []
|
| 619 |
+
for layer_idx, head_indices in custom_config.items():
|
| 620 |
+
for head_idx in head_indices:
|
| 621 |
+
if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
|
| 622 |
+
selected_tensors.append(weights[layer_idx, head_idx])
|
| 623 |
+
|
| 624 |
+
if not selected_tensors:
|
| 625 |
+
return None, None, None
|
| 626 |
+
|
| 627 |
+
weights_stack = torch.stack(selected_tensors, dim=0)
|
| 628 |
+
|
| 629 |
+
# 3. Average Heads
|
| 630 |
+
avg_weights = weights_stack.mean(dim=0) # [Tokens, Frames]
|
| 631 |
+
|
| 632 |
+
# 4. Preprocessing Logic
|
| 633 |
+
# Min-Max normalization preserving energy distribution
|
| 634 |
+
# Median filter is applied to the energy matrix
|
| 635 |
+
energy_tensor = median_filter(avg_weights, filter_width=medfilt_width)
|
| 636 |
+
energy_matrix = energy_tensor.numpy()
|
| 637 |
+
|
| 638 |
+
e_min, e_max = energy_matrix.min(), energy_matrix.max()
|
| 639 |
+
|
| 640 |
+
if e_max - e_min > 1e-9:
|
| 641 |
+
energy_matrix = (energy_matrix - e_min) / (e_max - e_min)
|
| 642 |
+
else:
|
| 643 |
+
energy_matrix = np.zeros_like(energy_matrix)
|
| 644 |
+
|
| 645 |
+
# Contrast enhancement for DTW pathfinding
|
| 646 |
+
# calc_matrix is used for pathfinding, energy_matrix for scoring
|
| 647 |
+
calc_matrix = energy_matrix ** 2
|
| 648 |
+
|
| 649 |
+
return calc_matrix, energy_matrix, avg_weights
|
| 650 |
+
|
| 651 |
+
def _compute_alignment_metrics(
|
| 652 |
+
self,
|
| 653 |
+
energy_matrix: torch.Tensor,
|
| 654 |
+
path_coords: torch.Tensor,
|
| 655 |
+
type_mask: torch.Tensor,
|
| 656 |
+
time_weight: float = 0.01,
|
| 657 |
+
overlap_frames: float = 9.0,
|
| 658 |
+
instrumental_weight: float = 1.0
|
| 659 |
+
) -> Tuple[float, float, float]:
|
| 660 |
+
"""
|
| 661 |
+
Core metric calculation logic using high-precision Tensor operations.
|
| 662 |
+
|
| 663 |
+
Args:
|
| 664 |
+
energy_matrix: Normalized energy [Rows, Cols].
|
| 665 |
+
path_coords: DTW path coordinates [Steps, 2].
|
| 666 |
+
type_mask: Token type mask [Rows] (1=Lyrics, 0=Tags).
|
| 667 |
+
time_weight: Minimum energy threshold for monotonicity.
|
| 668 |
+
overlap_frames: Allowed overlap for monotonicity check.
|
| 669 |
+
instrumental_weight: Weight for non-lyric tokens in confidence calc.
|
| 670 |
+
|
| 671 |
+
Returns:
|
| 672 |
+
Tuple of (coverage, monotonicity, confidence).
|
| 673 |
+
"""
|
| 674 |
+
# Ensure high precision for internal calculation
|
| 675 |
+
energy_matrix = energy_matrix.to(dtype=torch.float64)
|
| 676 |
+
path_coords = path_coords.long()
|
| 677 |
+
type_mask = type_mask.long()
|
| 678 |
+
|
| 679 |
+
device = energy_matrix.device
|
| 680 |
+
rows, cols = energy_matrix.shape
|
| 681 |
+
|
| 682 |
+
is_lyrics_row = (type_mask == 1)
|
| 683 |
+
|
| 684 |
+
# ================= A. Coverage Score =================
|
| 685 |
+
# Ratio of lyric lines that have significant energy peak
|
| 686 |
+
row_max_energies = energy_matrix.max(dim=1).values
|
| 687 |
+
total_sung_rows = is_lyrics_row.sum().double()
|
| 688 |
+
|
| 689 |
+
coverage_threshold = 0.1
|
| 690 |
+
valid_sung_mask = is_lyrics_row & (row_max_energies > coverage_threshold)
|
| 691 |
+
valid_sung_rows = valid_sung_mask.sum().double()
|
| 692 |
+
|
| 693 |
+
if total_sung_rows > 0:
|
| 694 |
+
coverage_score = valid_sung_rows / total_sung_rows
|
| 695 |
+
else:
|
| 696 |
+
coverage_score = torch.tensor(1.0, device=device, dtype=torch.float64)
|
| 697 |
+
|
| 698 |
+
# ================= B. Monotonicity Score =================
|
| 699 |
+
# Check if the "center of mass" of lyric lines moves forward in time
|
| 700 |
+
col_indices = torch.arange(cols, device=device, dtype=torch.float64)
|
| 701 |
+
|
| 702 |
+
# Zero out low energy noise
|
| 703 |
+
weights = torch.where(
|
| 704 |
+
energy_matrix > time_weight,
|
| 705 |
+
energy_matrix,
|
| 706 |
+
torch.zeros_like(energy_matrix)
|
| 707 |
+
)
|
| 708 |
+
|
| 709 |
+
sum_w = weights.sum(dim=1)
|
| 710 |
+
sum_t = (weights * col_indices).sum(dim=1)
|
| 711 |
+
|
| 712 |
+
# Calculate centroids
|
| 713 |
+
centroids = torch.full((rows,), -1.0, device=device, dtype=torch.float64)
|
| 714 |
+
valid_w_mask = sum_w > 1e-9
|
| 715 |
+
centroids[valid_w_mask] = sum_t[valid_w_mask] / sum_w[valid_w_mask]
|
| 716 |
+
|
| 717 |
+
# Extract sequence of valid lyrics centroids
|
| 718 |
+
valid_sequence_mask = is_lyrics_row & (centroids >= 0)
|
| 719 |
+
sung_centroids = centroids[valid_sequence_mask]
|
| 720 |
+
|
| 721 |
+
cnt = sung_centroids.shape[0]
|
| 722 |
+
if cnt > 1:
|
| 723 |
+
curr_c = sung_centroids[:-1]
|
| 724 |
+
next_c = sung_centroids[1:]
|
| 725 |
+
|
| 726 |
+
# Check non-decreasing order with overlap tolerance
|
| 727 |
+
non_decreasing = (next_c >= (curr_c - overlap_frames)).double().sum()
|
| 728 |
+
pairs = torch.tensor(cnt - 1, device=device, dtype=torch.float64)
|
| 729 |
+
monotonicity_score = non_decreasing / pairs
|
| 730 |
+
else:
|
| 731 |
+
monotonicity_score = torch.tensor(1.0, device=device, dtype=torch.float64)
|
| 732 |
+
|
| 733 |
+
# ================= C. Path Confidence =================
|
| 734 |
+
# Average energy along the optimal path
|
| 735 |
+
if path_coords.shape[0] > 0:
|
| 736 |
+
p_rows = path_coords[:, 0]
|
| 737 |
+
p_cols = path_coords[:, 1]
|
| 738 |
+
|
| 739 |
+
path_energies = energy_matrix[p_rows, p_cols]
|
| 740 |
+
step_weights = torch.ones_like(path_energies)
|
| 741 |
+
|
| 742 |
+
# Lower weight for instrumental/tag steps
|
| 743 |
+
is_inst_step = (type_mask[p_rows] == 0)
|
| 744 |
+
step_weights[is_inst_step] = instrumental_weight
|
| 745 |
+
|
| 746 |
+
total_energy = (path_energies * step_weights).sum()
|
| 747 |
+
total_steps = step_weights.sum()
|
| 748 |
+
|
| 749 |
+
if total_steps > 0:
|
| 750 |
+
path_confidence = total_energy / total_steps
|
| 751 |
+
else:
|
| 752 |
+
path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
|
| 753 |
+
else:
|
| 754 |
+
path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
|
| 755 |
+
|
| 756 |
+
return coverage_score.item(), monotonicity_score.item(), path_confidence.item()
|
| 757 |
+
|
| 758 |
+
def lyrics_alignment_info(
|
| 759 |
+
self,
|
| 760 |
+
attention_matrix: Union[torch.Tensor, np.ndarray],
|
| 761 |
+
token_ids: List[int],
|
| 762 |
+
custom_config: Dict[int, List[int]],
|
| 763 |
+
return_matrices: bool = False,
|
| 764 |
+
medfilt_width: int = 1
|
| 765 |
+
) -> Dict[str, Any]:
|
| 766 |
+
"""
|
| 767 |
+
Generates alignment path and processed matrices.
|
| 768 |
+
|
| 769 |
+
Args:
|
| 770 |
+
attention_matrix: Input attention tensor.
|
| 771 |
+
token_ids: Corresponding token IDs.
|
| 772 |
+
custom_config: Layer/Head configuration.
|
| 773 |
+
return_matrices: If True, returns matrices in the output.
|
| 774 |
+
medfilt_width: Median filter width.
|
| 775 |
+
|
| 776 |
+
Returns:
|
| 777 |
+
Dict or AlignmentInfo object containing path and masks.
|
| 778 |
+
"""
|
| 779 |
+
calc_matrix, energy_matrix, vis_matrix = self._preprocess_attention(
|
| 780 |
+
attention_matrix, custom_config, medfilt_width
|
| 781 |
+
)
|
| 782 |
+
|
| 783 |
+
if calc_matrix is None:
|
| 784 |
+
return {
|
| 785 |
+
"calc_matrix": None,
|
| 786 |
+
"error": "No valid attention heads found"
|
| 787 |
+
}
|
| 788 |
+
|
| 789 |
+
# 1. Generate Semantic Mask (1=Lyrics, 0=Tags)
|
| 790 |
+
# Uses self.tokenizer internally
|
| 791 |
+
type_mask = self._generate_token_type_mask(token_ids)
|
| 792 |
+
|
| 793 |
+
# Safety check for shape mismatch
|
| 794 |
+
if len(type_mask) != energy_matrix.shape[0]:
|
| 795 |
+
# Fallback to all lyrics if shapes don't align
|
| 796 |
+
type_mask = np.ones(energy_matrix.shape[0], dtype=np.int32)
|
| 797 |
+
|
| 798 |
+
# 2. DTW Pathfinding
|
| 799 |
+
# Using negative calc_matrix because DTW minimizes cost
|
| 800 |
+
text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float32))
|
| 801 |
+
path_coords = np.stack([text_indices, time_indices], axis=1)
|
| 802 |
+
|
| 803 |
+
return_dict = {
|
| 804 |
+
"path_coords": path_coords,
|
| 805 |
+
"type_mask": type_mask,
|
| 806 |
+
"energy_matrix": energy_matrix
|
| 807 |
+
}
|
| 808 |
+
if return_matrices:
|
| 809 |
+
return_dict['calc_matrix'] = calc_matrix
|
| 810 |
+
return_dict['vis_matrix'] = vis_matrix
|
| 811 |
+
|
| 812 |
+
return return_dict
|
| 813 |
+
|
| 814 |
+
def calculate_score(
|
| 815 |
+
self,
|
| 816 |
+
energy_matrix: Union[torch.Tensor, np.ndarray],
|
| 817 |
+
type_mask: Union[torch.Tensor, np.ndarray],
|
| 818 |
+
path_coords: Union[torch.Tensor, np.ndarray],
|
| 819 |
+
time_weight: float = 0.01,
|
| 820 |
+
overlap_frames: float = 9.0,
|
| 821 |
+
instrumental_weight: float = 1.0
|
| 822 |
+
) -> Dict[str, Any]:
|
| 823 |
+
"""
|
| 824 |
+
Calculates the final alignment score based on pre-computed components.
|
| 825 |
+
|
| 826 |
+
Args:
|
| 827 |
+
energy_matrix: Processed energy matrix.
|
| 828 |
+
type_mask: Token type mask.
|
| 829 |
+
path_coords: DTW path coordinates.
|
| 830 |
+
time_weight: Minimum energy threshold for monotonicity.
|
| 831 |
+
overlap_frames: Allowed backward movement frames.
|
| 832 |
+
instrumental_weight: Weight for non-lyric path steps.
|
| 833 |
+
|
| 834 |
+
Returns:
|
| 835 |
+
AlignmentScore object containing individual metrics and final score.
|
| 836 |
+
"""
|
| 837 |
+
# Ensure Inputs are Tensors on the correct device
|
| 838 |
+
if not isinstance(energy_matrix, torch.Tensor):
|
| 839 |
+
energy_matrix = torch.tensor(energy_matrix, device='cuda', dtype=torch.float32)
|
| 840 |
+
|
| 841 |
+
device = energy_matrix.device
|
| 842 |
+
|
| 843 |
+
if not isinstance(type_mask, torch.Tensor):
|
| 844 |
+
type_mask = torch.tensor(type_mask, device=device, dtype=torch.long)
|
| 845 |
+
else:
|
| 846 |
+
type_mask = type_mask.to(device=device, dtype=torch.long)
|
| 847 |
+
|
| 848 |
+
if not isinstance(path_coords, torch.Tensor):
|
| 849 |
+
path_coords = torch.tensor(path_coords, device=device, dtype=torch.long)
|
| 850 |
+
else:
|
| 851 |
+
path_coords = path_coords.to(device=device, dtype=torch.long)
|
| 852 |
+
|
| 853 |
+
# Compute Metrics
|
| 854 |
+
coverage, monotonicity, confidence = self._compute_alignment_metrics(
|
| 855 |
+
energy_matrix=energy_matrix,
|
| 856 |
+
path_coords=path_coords,
|
| 857 |
+
type_mask=type_mask,
|
| 858 |
+
time_weight=time_weight,
|
| 859 |
+
overlap_frames=overlap_frames,
|
| 860 |
+
instrumental_weight=instrumental_weight
|
| 861 |
+
)
|
| 862 |
+
|
| 863 |
+
# Final Score Calculation
|
| 864 |
+
# (Cov^2 * Mono^2 * Conf)
|
| 865 |
+
final_score = (coverage ** 2) * (monotonicity ** 2) * confidence
|
| 866 |
+
final_score = float(np.clip(final_score, 0.0, 1.0))
|
| 867 |
+
|
| 868 |
+
return {
|
| 869 |
+
"lyrics_score": round(final_score, 4)
|
| 870 |
+
}
|
spaces/Ace-Step-v1.5/acestep/genres_vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
spaces/Ace-Step-v1.5/acestep/gradio_ui/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from acestep.gradio_ui.interfaces import create_gradio_interface
|
spaces/Ace-Step-v1.5/acestep/gradio_ui/events/__init__.py
ADDED
|
@@ -0,0 +1,1310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI Event Handlers Module
|
| 3 |
+
Main entry point for setting up all event handlers
|
| 4 |
+
"""
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
# Import handler modules
|
| 9 |
+
from . import generation_handlers as gen_h
|
| 10 |
+
from . import results_handlers as res_h
|
| 11 |
+
from . import training_handlers as train_h
|
| 12 |
+
from acestep.gradio_ui.i18n import t
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section, init_params=None):
|
| 16 |
+
"""Setup event handlers connecting UI components and business logic
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
init_params: Dictionary containing initialization parameters including:
|
| 20 |
+
- dit_handler_2: Optional second DiT handler for multi-model setup
|
| 21 |
+
- available_dit_models: List of available DiT model names
|
| 22 |
+
- config_path: Primary model config path
|
| 23 |
+
- config_path_2: Secondary model config path (if available)
|
| 24 |
+
"""
|
| 25 |
+
# Get secondary DiT handler from init_params (for multi-model support)
|
| 26 |
+
dit_handler_2 = init_params.get('dit_handler_2') if init_params else None
|
| 27 |
+
config_path_1 = init_params.get('config_path', '') if init_params else ''
|
| 28 |
+
config_path_2 = init_params.get('config_path_2', '') if init_params else ''
|
| 29 |
+
|
| 30 |
+
# ========== Dataset Handlers ==========
|
| 31 |
+
dataset_section["import_dataset_btn"].click(
|
| 32 |
+
fn=dataset_handler.import_dataset,
|
| 33 |
+
inputs=[dataset_section["dataset_type"]],
|
| 34 |
+
outputs=[dataset_section["data_status"]]
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
# ========== Service Initialization ==========
|
| 38 |
+
generation_section["refresh_btn"].click(
|
| 39 |
+
fn=lambda: gen_h.refresh_checkpoints(dit_handler),
|
| 40 |
+
outputs=[generation_section["checkpoint_dropdown"]]
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
generation_section["config_path"].change(
|
| 44 |
+
fn=gen_h.update_model_type_settings,
|
| 45 |
+
inputs=[generation_section["config_path"]],
|
| 46 |
+
outputs=[
|
| 47 |
+
generation_section["inference_steps"],
|
| 48 |
+
generation_section["guidance_scale"],
|
| 49 |
+
generation_section["use_adg"],
|
| 50 |
+
generation_section["shift"],
|
| 51 |
+
generation_section["cfg_interval_start"],
|
| 52 |
+
generation_section["cfg_interval_end"],
|
| 53 |
+
generation_section["task_type"],
|
| 54 |
+
]
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
generation_section["init_btn"].click(
|
| 58 |
+
fn=lambda *args: gen_h.init_service_wrapper(dit_handler, llm_handler, *args),
|
| 59 |
+
inputs=[
|
| 60 |
+
generation_section["checkpoint_dropdown"],
|
| 61 |
+
generation_section["config_path"],
|
| 62 |
+
generation_section["device"],
|
| 63 |
+
generation_section["init_llm_checkbox"],
|
| 64 |
+
generation_section["lm_model_path"],
|
| 65 |
+
generation_section["backend_dropdown"],
|
| 66 |
+
generation_section["use_flash_attention_checkbox"],
|
| 67 |
+
generation_section["offload_to_cpu_checkbox"],
|
| 68 |
+
generation_section["offload_dit_to_cpu_checkbox"],
|
| 69 |
+
],
|
| 70 |
+
outputs=[
|
| 71 |
+
generation_section["init_status"],
|
| 72 |
+
generation_section["generate_btn"],
|
| 73 |
+
generation_section["service_config_accordion"],
|
| 74 |
+
# Model type settings (updated based on actual loaded model)
|
| 75 |
+
generation_section["inference_steps"],
|
| 76 |
+
generation_section["guidance_scale"],
|
| 77 |
+
generation_section["use_adg"],
|
| 78 |
+
generation_section["shift"],
|
| 79 |
+
generation_section["cfg_interval_start"],
|
| 80 |
+
generation_section["cfg_interval_end"],
|
| 81 |
+
generation_section["task_type"],
|
| 82 |
+
]
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# ========== LoRA Handlers ==========
|
| 86 |
+
generation_section["load_lora_btn"].click(
|
| 87 |
+
fn=dit_handler.load_lora,
|
| 88 |
+
inputs=[generation_section["lora_path"]],
|
| 89 |
+
outputs=[generation_section["lora_status"]]
|
| 90 |
+
).then(
|
| 91 |
+
# Update checkbox to enabled state after loading
|
| 92 |
+
fn=lambda: gr.update(value=True),
|
| 93 |
+
outputs=[generation_section["use_lora_checkbox"]]
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
generation_section["unload_lora_btn"].click(
|
| 97 |
+
fn=dit_handler.unload_lora,
|
| 98 |
+
outputs=[generation_section["lora_status"]]
|
| 99 |
+
).then(
|
| 100 |
+
# Update checkbox to disabled state after unloading
|
| 101 |
+
fn=lambda: gr.update(value=False),
|
| 102 |
+
outputs=[generation_section["use_lora_checkbox"]]
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
generation_section["use_lora_checkbox"].change(
|
| 106 |
+
fn=dit_handler.set_use_lora,
|
| 107 |
+
inputs=[generation_section["use_lora_checkbox"]],
|
| 108 |
+
outputs=[generation_section["lora_status"]]
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
# ========== UI Visibility Updates ==========
|
| 112 |
+
generation_section["init_llm_checkbox"].change(
|
| 113 |
+
fn=gen_h.update_negative_prompt_visibility,
|
| 114 |
+
inputs=[generation_section["init_llm_checkbox"]],
|
| 115 |
+
outputs=[generation_section["lm_negative_prompt"]]
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
generation_section["init_llm_checkbox"].change(
|
| 119 |
+
fn=gen_h.update_audio_cover_strength_visibility,
|
| 120 |
+
inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"]],
|
| 121 |
+
outputs=[generation_section["audio_cover_strength"]]
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
generation_section["task_type"].change(
|
| 125 |
+
fn=gen_h.update_audio_cover_strength_visibility,
|
| 126 |
+
inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"]],
|
| 127 |
+
outputs=[generation_section["audio_cover_strength"]]
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
generation_section["batch_size_input"].change(
|
| 131 |
+
fn=gen_h.update_audio_components_visibility,
|
| 132 |
+
inputs=[generation_section["batch_size_input"]],
|
| 133 |
+
outputs=[
|
| 134 |
+
results_section["audio_col_1"],
|
| 135 |
+
results_section["audio_col_2"],
|
| 136 |
+
results_section["audio_col_3"],
|
| 137 |
+
results_section["audio_col_4"],
|
| 138 |
+
results_section["audio_row_5_8"],
|
| 139 |
+
results_section["audio_col_5"],
|
| 140 |
+
results_section["audio_col_6"],
|
| 141 |
+
results_section["audio_col_7"],
|
| 142 |
+
results_section["audio_col_8"],
|
| 143 |
+
]
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
# ========== Audio Conversion ==========
|
| 147 |
+
generation_section["convert_src_to_codes_btn"].click(
|
| 148 |
+
fn=lambda src: gen_h.convert_src_audio_to_codes_wrapper(dit_handler, src),
|
| 149 |
+
inputs=[generation_section["src_audio"]],
|
| 150 |
+
outputs=[generation_section["text2music_audio_code_string"]]
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
# ========== Instruction UI Updates ==========
|
| 154 |
+
for trigger in [generation_section["task_type"], generation_section["track_name"], generation_section["complete_track_classes"]]:
|
| 155 |
+
trigger.change(
|
| 156 |
+
fn=lambda *args: gen_h.update_instruction_ui(dit_handler, *args),
|
| 157 |
+
inputs=[
|
| 158 |
+
generation_section["task_type"],
|
| 159 |
+
generation_section["track_name"],
|
| 160 |
+
generation_section["complete_track_classes"],
|
| 161 |
+
generation_section["text2music_audio_code_string"],
|
| 162 |
+
generation_section["init_llm_checkbox"]
|
| 163 |
+
],
|
| 164 |
+
outputs=[
|
| 165 |
+
generation_section["instruction_display_gen"],
|
| 166 |
+
generation_section["track_name"],
|
| 167 |
+
generation_section["complete_track_classes"],
|
| 168 |
+
generation_section["audio_cover_strength"],
|
| 169 |
+
generation_section["repainting_group"],
|
| 170 |
+
generation_section["text2music_audio_codes_group"],
|
| 171 |
+
]
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
# ========== Sample/Transcribe Handlers ==========
|
| 175 |
+
# Load random example from ./examples/text2music directory
|
| 176 |
+
generation_section["sample_btn"].click(
|
| 177 |
+
fn=lambda task: gen_h.load_random_example(task) + (True,),
|
| 178 |
+
inputs=[
|
| 179 |
+
generation_section["task_type"],
|
| 180 |
+
],
|
| 181 |
+
outputs=[
|
| 182 |
+
generation_section["captions"],
|
| 183 |
+
generation_section["lyrics"],
|
| 184 |
+
generation_section["think_checkbox"],
|
| 185 |
+
generation_section["bpm"],
|
| 186 |
+
generation_section["audio_duration"],
|
| 187 |
+
generation_section["key_scale"],
|
| 188 |
+
generation_section["vocal_language"],
|
| 189 |
+
generation_section["time_signature"],
|
| 190 |
+
results_section["is_format_caption_state"]
|
| 191 |
+
]
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
generation_section["text2music_audio_code_string"].change(
|
| 195 |
+
fn=gen_h.update_transcribe_button_text,
|
| 196 |
+
inputs=[generation_section["text2music_audio_code_string"]],
|
| 197 |
+
outputs=[generation_section["transcribe_btn"]]
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
generation_section["transcribe_btn"].click(
|
| 201 |
+
fn=lambda codes, debug: gen_h.transcribe_audio_codes(llm_handler, codes, debug),
|
| 202 |
+
inputs=[
|
| 203 |
+
generation_section["text2music_audio_code_string"],
|
| 204 |
+
generation_section["constrained_decoding_debug"]
|
| 205 |
+
],
|
| 206 |
+
outputs=[
|
| 207 |
+
results_section["status_output"],
|
| 208 |
+
generation_section["captions"],
|
| 209 |
+
generation_section["lyrics"],
|
| 210 |
+
generation_section["bpm"],
|
| 211 |
+
generation_section["audio_duration"],
|
| 212 |
+
generation_section["key_scale"],
|
| 213 |
+
generation_section["vocal_language"],
|
| 214 |
+
generation_section["time_signature"],
|
| 215 |
+
results_section["is_format_caption_state"]
|
| 216 |
+
]
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
# ========== Reset Format Caption Flag ==========
|
| 220 |
+
for trigger in [generation_section["captions"], generation_section["lyrics"], generation_section["bpm"],
|
| 221 |
+
generation_section["key_scale"], generation_section["time_signature"],
|
| 222 |
+
generation_section["vocal_language"], generation_section["audio_duration"]]:
|
| 223 |
+
trigger.change(
|
| 224 |
+
fn=gen_h.reset_format_caption_flag,
|
| 225 |
+
inputs=[],
|
| 226 |
+
outputs=[results_section["is_format_caption_state"]]
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
# ========== Audio Uploads Accordion ==========
|
| 230 |
+
for trigger in [generation_section["reference_audio"], generation_section["src_audio"]]:
|
| 231 |
+
trigger.change(
|
| 232 |
+
fn=gen_h.update_audio_uploads_accordion,
|
| 233 |
+
inputs=[generation_section["reference_audio"], generation_section["src_audio"]],
|
| 234 |
+
outputs=[generation_section["audio_uploads_accordion"]]
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
# ========== Instrumental Checkbox ==========
|
| 238 |
+
generation_section["instrumental_checkbox"].change(
|
| 239 |
+
fn=gen_h.handle_instrumental_checkbox,
|
| 240 |
+
inputs=[generation_section["instrumental_checkbox"], generation_section["lyrics"]],
|
| 241 |
+
outputs=[generation_section["lyrics"]]
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
# ========== Format Button ==========
|
| 245 |
+
# Note: cfg_scale and negative_prompt are not supported in format mode
|
| 246 |
+
generation_section["format_btn"].click(
|
| 247 |
+
fn=lambda caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug: gen_h.handle_format_sample(
|
| 248 |
+
llm_handler, caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug
|
| 249 |
+
),
|
| 250 |
+
inputs=[
|
| 251 |
+
generation_section["captions"],
|
| 252 |
+
generation_section["lyrics"],
|
| 253 |
+
generation_section["bpm"],
|
| 254 |
+
generation_section["audio_duration"],
|
| 255 |
+
generation_section["key_scale"],
|
| 256 |
+
generation_section["time_signature"],
|
| 257 |
+
generation_section["lm_temperature"],
|
| 258 |
+
generation_section["lm_top_k"],
|
| 259 |
+
generation_section["lm_top_p"],
|
| 260 |
+
generation_section["constrained_decoding_debug"],
|
| 261 |
+
],
|
| 262 |
+
outputs=[
|
| 263 |
+
generation_section["captions"],
|
| 264 |
+
generation_section["lyrics"],
|
| 265 |
+
generation_section["bpm"],
|
| 266 |
+
generation_section["audio_duration"],
|
| 267 |
+
generation_section["key_scale"],
|
| 268 |
+
generation_section["vocal_language"],
|
| 269 |
+
generation_section["time_signature"],
|
| 270 |
+
results_section["is_format_caption_state"],
|
| 271 |
+
results_section["status_output"],
|
| 272 |
+
]
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
# ========== Generation Mode Toggle (Simple/Custom/Cover/Repaint) ==========
|
| 276 |
+
generation_section["generation_mode"].change(
|
| 277 |
+
fn=gen_h.handle_generation_mode_change,
|
| 278 |
+
inputs=[generation_section["generation_mode"]],
|
| 279 |
+
outputs=[
|
| 280 |
+
generation_section["simple_mode_group"],
|
| 281 |
+
generation_section["custom_mode_content"],
|
| 282 |
+
generation_section["cover_mode_group"],
|
| 283 |
+
generation_section["repainting_group"],
|
| 284 |
+
generation_section["task_type"],
|
| 285 |
+
generation_section["generate_btn"],
|
| 286 |
+
generation_section["simple_sample_created"],
|
| 287 |
+
generation_section["src_audio_group"],
|
| 288 |
+
generation_section["audio_cover_strength"],
|
| 289 |
+
generation_section["think_checkbox"], # Disable thinking for cover/repaint modes
|
| 290 |
+
]
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
# ========== Process Source Audio Button ==========
|
| 294 |
+
# Combines Convert to Codes + Transcribe in one step
|
| 295 |
+
generation_section["process_src_btn"].click(
|
| 296 |
+
fn=lambda src, debug: gen_h.process_source_audio(dit_handler, llm_handler, src, debug),
|
| 297 |
+
inputs=[
|
| 298 |
+
generation_section["src_audio"],
|
| 299 |
+
generation_section["constrained_decoding_debug"]
|
| 300 |
+
],
|
| 301 |
+
outputs=[
|
| 302 |
+
generation_section["text2music_audio_code_string"],
|
| 303 |
+
results_section["status_output"],
|
| 304 |
+
generation_section["captions"],
|
| 305 |
+
generation_section["lyrics"],
|
| 306 |
+
generation_section["bpm"],
|
| 307 |
+
generation_section["audio_duration"],
|
| 308 |
+
generation_section["key_scale"],
|
| 309 |
+
generation_section["vocal_language"],
|
| 310 |
+
generation_section["time_signature"],
|
| 311 |
+
results_section["is_format_caption_state"],
|
| 312 |
+
]
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
# ========== Simple Mode Instrumental Checkbox ==========
|
| 316 |
+
# When instrumental is checked, disable vocal language and set to ["unknown"]
|
| 317 |
+
generation_section["simple_instrumental_checkbox"].change(
|
| 318 |
+
fn=gen_h.handle_simple_instrumental_change,
|
| 319 |
+
inputs=[generation_section["simple_instrumental_checkbox"]],
|
| 320 |
+
outputs=[generation_section["simple_vocal_language"]]
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
# ========== Random Description Button ==========
|
| 324 |
+
generation_section["random_desc_btn"].click(
|
| 325 |
+
fn=gen_h.load_random_simple_description,
|
| 326 |
+
inputs=[],
|
| 327 |
+
outputs=[
|
| 328 |
+
generation_section["simple_query_input"],
|
| 329 |
+
generation_section["simple_instrumental_checkbox"],
|
| 330 |
+
generation_section["simple_vocal_language"],
|
| 331 |
+
]
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
# ========== Create Sample Button (Simple Mode) ==========
|
| 335 |
+
# Note: cfg_scale and negative_prompt are not supported in create_sample mode
|
| 336 |
+
generation_section["create_sample_btn"].click(
|
| 337 |
+
fn=lambda query, instrumental, vocal_lang, temp, top_k, top_p, debug: gen_h.handle_create_sample(
|
| 338 |
+
llm_handler, query, instrumental, vocal_lang, temp, top_k, top_p, debug
|
| 339 |
+
),
|
| 340 |
+
inputs=[
|
| 341 |
+
generation_section["simple_query_input"],
|
| 342 |
+
generation_section["simple_instrumental_checkbox"],
|
| 343 |
+
generation_section["simple_vocal_language"],
|
| 344 |
+
generation_section["lm_temperature"],
|
| 345 |
+
generation_section["lm_top_k"],
|
| 346 |
+
generation_section["lm_top_p"],
|
| 347 |
+
generation_section["constrained_decoding_debug"],
|
| 348 |
+
],
|
| 349 |
+
outputs=[
|
| 350 |
+
generation_section["captions"],
|
| 351 |
+
generation_section["lyrics"],
|
| 352 |
+
generation_section["bpm"],
|
| 353 |
+
generation_section["audio_duration"],
|
| 354 |
+
generation_section["key_scale"],
|
| 355 |
+
generation_section["vocal_language"],
|
| 356 |
+
generation_section["simple_vocal_language"],
|
| 357 |
+
generation_section["time_signature"],
|
| 358 |
+
generation_section["instrumental_checkbox"],
|
| 359 |
+
generation_section["caption_accordion"],
|
| 360 |
+
generation_section["lyrics_accordion"],
|
| 361 |
+
generation_section["generate_btn"],
|
| 362 |
+
generation_section["simple_sample_created"],
|
| 363 |
+
generation_section["think_checkbox"],
|
| 364 |
+
results_section["is_format_caption_state"],
|
| 365 |
+
results_section["status_output"],
|
| 366 |
+
]
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
# ========== Load/Save Metadata ==========
|
| 370 |
+
generation_section["load_file"].upload(
|
| 371 |
+
fn=gen_h.load_metadata,
|
| 372 |
+
inputs=[generation_section["load_file"]],
|
| 373 |
+
outputs=[
|
| 374 |
+
generation_section["task_type"],
|
| 375 |
+
generation_section["captions"],
|
| 376 |
+
generation_section["lyrics"],
|
| 377 |
+
generation_section["vocal_language"],
|
| 378 |
+
generation_section["bpm"],
|
| 379 |
+
generation_section["key_scale"],
|
| 380 |
+
generation_section["time_signature"],
|
| 381 |
+
generation_section["audio_duration"],
|
| 382 |
+
generation_section["batch_size_input"],
|
| 383 |
+
generation_section["inference_steps"],
|
| 384 |
+
generation_section["guidance_scale"],
|
| 385 |
+
generation_section["seed"],
|
| 386 |
+
generation_section["random_seed_checkbox"],
|
| 387 |
+
generation_section["use_adg"],
|
| 388 |
+
generation_section["cfg_interval_start"],
|
| 389 |
+
generation_section["cfg_interval_end"],
|
| 390 |
+
generation_section["shift"],
|
| 391 |
+
generation_section["infer_method"],
|
| 392 |
+
generation_section["custom_timesteps"],
|
| 393 |
+
generation_section["audio_format"],
|
| 394 |
+
generation_section["lm_temperature"],
|
| 395 |
+
generation_section["lm_cfg_scale"],
|
| 396 |
+
generation_section["lm_top_k"],
|
| 397 |
+
generation_section["lm_top_p"],
|
| 398 |
+
generation_section["lm_negative_prompt"],
|
| 399 |
+
generation_section["use_cot_metas"], # Added: use_cot_metas
|
| 400 |
+
generation_section["use_cot_caption"],
|
| 401 |
+
generation_section["use_cot_language"],
|
| 402 |
+
generation_section["audio_cover_strength"],
|
| 403 |
+
generation_section["think_checkbox"],
|
| 404 |
+
generation_section["text2music_audio_code_string"],
|
| 405 |
+
generation_section["repainting_start"],
|
| 406 |
+
generation_section["repainting_end"],
|
| 407 |
+
generation_section["track_name"],
|
| 408 |
+
generation_section["complete_track_classes"],
|
| 409 |
+
generation_section["instrumental_checkbox"], # Added: instrumental_checkbox
|
| 410 |
+
results_section["is_format_caption_state"]
|
| 411 |
+
]
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
# Save buttons for all 8 audio outputs
|
| 415 |
+
download_existing_js = """(current_audio, batch_files) => {
|
| 416 |
+
// Debug: print what the input actually is
|
| 417 |
+
console.log("👉 [Debug] Current Audio Input:", current_audio);
|
| 418 |
+
|
| 419 |
+
// 1. Safety check
|
| 420 |
+
if (!current_audio) {
|
| 421 |
+
console.warn("⚠️ No audio selected or audio is empty.");
|
| 422 |
+
return;
|
| 423 |
+
}
|
| 424 |
+
if (!batch_files || !Array.isArray(batch_files)) {
|
| 425 |
+
console.warn("⚠️ Batch file list is empty/not ready.");
|
| 426 |
+
return;
|
| 427 |
+
}
|
| 428 |
+
|
| 429 |
+
// 2. Smartly extract path string
|
| 430 |
+
let pathString = "";
|
| 431 |
+
|
| 432 |
+
if (typeof current_audio === "string") {
|
| 433 |
+
// Case A: direct path string received
|
| 434 |
+
pathString = current_audio;
|
| 435 |
+
} else if (typeof current_audio === "object") {
|
| 436 |
+
// Case B: an object is received, try common properties
|
| 437 |
+
// Gradio file objects usually have path, url, or name
|
| 438 |
+
pathString = current_audio.path || current_audio.name || current_audio.url || "";
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
if (!pathString) {
|
| 442 |
+
console.error("❌ Error: Could not extract a valid path string from input.", current_audio);
|
| 443 |
+
return;
|
| 444 |
+
}
|
| 445 |
+
|
| 446 |
+
// 3. Extract Key (UUID)
|
| 447 |
+
// Path could be /tmp/.../uuid.mp3 or url like /file=.../uuid.mp3
|
| 448 |
+
let filename = pathString.split(/[\\\\/]/).pop(); // get the filename
|
| 449 |
+
let key = filename.split('.')[0]; // get UUID without extension
|
| 450 |
+
|
| 451 |
+
console.log(`🔑 Key extracted: ${key}`);
|
| 452 |
+
|
| 453 |
+
// 4. Find matching file(s) in the list
|
| 454 |
+
let targets = batch_files.filter(f => {
|
| 455 |
+
// Also extract names from batch_files objects
|
| 456 |
+
// f usually contains name (backend path) and orig_name (download name)
|
| 457 |
+
const fPath = f.name || f.path || "";
|
| 458 |
+
return fPath.includes(key);
|
| 459 |
+
});
|
| 460 |
+
|
| 461 |
+
if (targets.length === 0) {
|
| 462 |
+
console.warn("❌ No matching files found in batch list for key:", key);
|
| 463 |
+
alert("Batch list does not contain this file yet. Please wait for generation to finish.");
|
| 464 |
+
return;
|
| 465 |
+
}
|
| 466 |
+
|
| 467 |
+
// 5. Trigger download(s)
|
| 468 |
+
console.log(`🎯 Found ${targets.length} files to download.`);
|
| 469 |
+
targets.forEach((f, index) => {
|
| 470 |
+
setTimeout(() => {
|
| 471 |
+
const a = document.createElement('a');
|
| 472 |
+
// Prefer url (frontend-accessible link), otherwise try data
|
| 473 |
+
a.href = f.url || f.data;
|
| 474 |
+
a.download = f.orig_name || "download";
|
| 475 |
+
a.style.display = 'none';
|
| 476 |
+
document.body.appendChild(a);
|
| 477 |
+
a.click();
|
| 478 |
+
document.body.removeChild(a);
|
| 479 |
+
}, index * 1000); // 300ms interval to avoid browser blocking
|
| 480 |
+
});
|
| 481 |
+
}
|
| 482 |
+
"""
|
| 483 |
+
for btn_idx in range(1, 9):
|
| 484 |
+
results_section[f"save_btn_{btn_idx}"].click(
|
| 485 |
+
fn=None,
|
| 486 |
+
inputs=[
|
| 487 |
+
results_section[f"generated_audio_{btn_idx}"],
|
| 488 |
+
results_section["generated_audio_batch"],
|
| 489 |
+
],
|
| 490 |
+
js=download_existing_js # Run the above JS
|
| 491 |
+
)
|
| 492 |
+
# ========== Send to Cover Handlers ==========
|
| 493 |
+
def send_to_cover_handler(audio_file, lm_metadata):
|
| 494 |
+
"""Send audio to cover mode and switch to cover"""
|
| 495 |
+
if audio_file is None:
|
| 496 |
+
return (gr.skip(),) * 11
|
| 497 |
+
return (
|
| 498 |
+
audio_file, # src_audio
|
| 499 |
+
gr.skip(), # bpm
|
| 500 |
+
gr.skip(), # captions
|
| 501 |
+
gr.skip(), # lyrics
|
| 502 |
+
gr.skip(), # audio_duration
|
| 503 |
+
gr.skip(), # key_scale
|
| 504 |
+
gr.skip(), # vocal_language
|
| 505 |
+
gr.skip(), # time_signature
|
| 506 |
+
gr.skip(), # is_format_caption_state
|
| 507 |
+
"cover", # generation_mode - switch to cover
|
| 508 |
+
"cover", # task_type - set to cover
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
for btn_idx in range(1, 9):
|
| 512 |
+
results_section[f"send_to_cover_btn_{btn_idx}"].click(
|
| 513 |
+
fn=send_to_cover_handler,
|
| 514 |
+
inputs=[
|
| 515 |
+
results_section[f"generated_audio_{btn_idx}"],
|
| 516 |
+
results_section["lm_metadata_state"]
|
| 517 |
+
],
|
| 518 |
+
outputs=[
|
| 519 |
+
generation_section["src_audio"],
|
| 520 |
+
generation_section["bpm"],
|
| 521 |
+
generation_section["captions"],
|
| 522 |
+
generation_section["lyrics"],
|
| 523 |
+
generation_section["audio_duration"],
|
| 524 |
+
generation_section["key_scale"],
|
| 525 |
+
generation_section["vocal_language"],
|
| 526 |
+
generation_section["time_signature"],
|
| 527 |
+
results_section["is_format_caption_state"],
|
| 528 |
+
generation_section["generation_mode"],
|
| 529 |
+
generation_section["task_type"],
|
| 530 |
+
]
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
# ========== Send to Repaint Handlers ==========
|
| 534 |
+
def send_to_repaint_handler(audio_file, lm_metadata):
|
| 535 |
+
"""Send audio to repaint mode and switch to repaint"""
|
| 536 |
+
if audio_file is None:
|
| 537 |
+
return (gr.skip(),) * 11
|
| 538 |
+
return (
|
| 539 |
+
audio_file, # src_audio
|
| 540 |
+
gr.skip(), # bpm
|
| 541 |
+
gr.skip(), # captions
|
| 542 |
+
gr.skip(), # lyrics
|
| 543 |
+
gr.skip(), # audio_duration
|
| 544 |
+
gr.skip(), # key_scale
|
| 545 |
+
gr.skip(), # vocal_language
|
| 546 |
+
gr.skip(), # time_signature
|
| 547 |
+
gr.skip(), # is_format_caption_state
|
| 548 |
+
"repaint", # generation_mode - switch to repaint
|
| 549 |
+
"repaint", # task_type - set to repaint
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
for btn_idx in range(1, 9):
|
| 553 |
+
results_section[f"send_to_repaint_btn_{btn_idx}"].click(
|
| 554 |
+
fn=send_to_repaint_handler,
|
| 555 |
+
inputs=[
|
| 556 |
+
results_section[f"generated_audio_{btn_idx}"],
|
| 557 |
+
results_section["lm_metadata_state"]
|
| 558 |
+
],
|
| 559 |
+
outputs=[
|
| 560 |
+
generation_section["src_audio"],
|
| 561 |
+
generation_section["bpm"],
|
| 562 |
+
generation_section["captions"],
|
| 563 |
+
generation_section["lyrics"],
|
| 564 |
+
generation_section["audio_duration"],
|
| 565 |
+
generation_section["key_scale"],
|
| 566 |
+
generation_section["vocal_language"],
|
| 567 |
+
generation_section["time_signature"],
|
| 568 |
+
results_section["is_format_caption_state"],
|
| 569 |
+
generation_section["generation_mode"],
|
| 570 |
+
generation_section["task_type"],
|
| 571 |
+
]
|
| 572 |
+
)
|
| 573 |
+
|
| 574 |
+
# ========== Score Calculation Handlers ==========
|
| 575 |
+
# Use default argument to capture btn_idx value at definition time (Python closure fix)
|
| 576 |
+
def make_score_handler(idx):
|
| 577 |
+
return lambda scale, batch_idx, queue: res_h.calculate_score_handler_with_selection(
|
| 578 |
+
dit_handler, llm_handler, idx, scale, batch_idx, queue
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
for btn_idx in range(1, 9):
|
| 582 |
+
results_section[f"score_btn_{btn_idx}"].click(
|
| 583 |
+
fn=make_score_handler(btn_idx),
|
| 584 |
+
inputs=[
|
| 585 |
+
generation_section["score_scale"],
|
| 586 |
+
results_section["current_batch_index"],
|
| 587 |
+
results_section["batch_queue"],
|
| 588 |
+
],
|
| 589 |
+
outputs=[
|
| 590 |
+
results_section[f"score_display_{btn_idx}"],
|
| 591 |
+
results_section[f"details_accordion_{btn_idx}"],
|
| 592 |
+
results_section["batch_queue"]
|
| 593 |
+
]
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
# ========== LRC Timestamp Handlers ==========
|
| 597 |
+
# Use default argument to capture btn_idx value at definition time (Python closure fix)
|
| 598 |
+
def make_lrc_handler(idx):
|
| 599 |
+
return lambda batch_idx, queue, vocal_lang, infer_steps: res_h.generate_lrc_handler(
|
| 600 |
+
dit_handler, idx, batch_idx, queue, vocal_lang, infer_steps
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
for btn_idx in range(1, 9):
|
| 604 |
+
results_section[f"lrc_btn_{btn_idx}"].click(
|
| 605 |
+
fn=make_lrc_handler(btn_idx),
|
| 606 |
+
inputs=[
|
| 607 |
+
results_section["current_batch_index"],
|
| 608 |
+
results_section["batch_queue"],
|
| 609 |
+
generation_section["vocal_language"],
|
| 610 |
+
generation_section["inference_steps"],
|
| 611 |
+
],
|
| 612 |
+
outputs=[
|
| 613 |
+
results_section[f"lrc_display_{btn_idx}"],
|
| 614 |
+
results_section[f"details_accordion_{btn_idx}"],
|
| 615 |
+
# NOTE: Removed generated_audio output!
|
| 616 |
+
# Audio subtitles are now updated via lrc_display.change() event.
|
| 617 |
+
results_section["batch_queue"]
|
| 618 |
+
]
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
def generation_wrapper(selected_model, generation_mode, simple_query_input, simple_vocal_language, *args):
|
| 622 |
+
"""Wrapper that selects the appropriate DiT handler based on model selection"""
|
| 623 |
+
# Convert args to list for modification
|
| 624 |
+
args_list = list(args)
|
| 625 |
+
|
| 626 |
+
# args order (after simple mode params):
|
| 627 |
+
# captions (0), lyrics (1), bpm (2), key_scale (3), time_signature (4), vocal_language (5),
|
| 628 |
+
# inference_steps (6), guidance_scale (7), random_seed_checkbox (8), seed (9),
|
| 629 |
+
# reference_audio (10), audio_duration (11), batch_size_input (12), src_audio (13),
|
| 630 |
+
# text2music_audio_code_string (14), repainting_start (15), repainting_end (16),
|
| 631 |
+
# instruction_display_gen (17), audio_cover_strength (18), task_type (19), ...
|
| 632 |
+
# ... lm_temperature (27), think_checkbox (28), ...
|
| 633 |
+
# ... instrumental_checkbox (at position after all regular params)
|
| 634 |
+
|
| 635 |
+
src_audio = args_list[13] if len(args_list) > 13 else None
|
| 636 |
+
task_type = args_list[19] if len(args_list) > 19 else "text2music"
|
| 637 |
+
|
| 638 |
+
# Validate: Cover and Repaint modes require source audio
|
| 639 |
+
if task_type in ["cover", "repaint"] and src_audio is None:
|
| 640 |
+
raise gr.Error(f"Source Audio is required for {task_type.capitalize()} mode. Please upload an audio file.")
|
| 641 |
+
|
| 642 |
+
# Handle Simple mode: first create sample, then generate
|
| 643 |
+
if generation_mode == "simple":
|
| 644 |
+
# Get instrumental from the main checkbox (args[-6] based on input order)
|
| 645 |
+
# The instrumental_checkbox is passed after all the regular generation params
|
| 646 |
+
instrumental = args_list[-6] if len(args_list) > 6 else False # instrumental_checkbox position
|
| 647 |
+
lm_temperature = args_list[27] if len(args_list) > 27 else 0.85
|
| 648 |
+
lm_top_k = args_list[30] if len(args_list) > 30 else 0
|
| 649 |
+
lm_top_p = args_list[31] if len(args_list) > 31 else 0.9
|
| 650 |
+
constrained_decoding_debug = args_list[38] if len(args_list) > 38 else False
|
| 651 |
+
|
| 652 |
+
# Call create_sample to generate caption/lyrics/metadata
|
| 653 |
+
from acestep.inference import create_sample
|
| 654 |
+
|
| 655 |
+
top_k_value = None if not lm_top_k or lm_top_k == 0 else int(lm_top_k)
|
| 656 |
+
top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p
|
| 657 |
+
|
| 658 |
+
result = create_sample(
|
| 659 |
+
llm_handler=llm_handler,
|
| 660 |
+
query=simple_query_input,
|
| 661 |
+
instrumental=instrumental,
|
| 662 |
+
vocal_language=simple_vocal_language,
|
| 663 |
+
temperature=lm_temperature,
|
| 664 |
+
top_k=top_k_value,
|
| 665 |
+
top_p=top_p_value,
|
| 666 |
+
use_constrained_decoding=True,
|
| 667 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
if not result.success:
|
| 671 |
+
raise gr.Error(f"Failed to create sample: {result.status_message}")
|
| 672 |
+
|
| 673 |
+
# Update args with generated data
|
| 674 |
+
args_list[0] = result.caption # captions
|
| 675 |
+
args_list[1] = result.lyrics # lyrics
|
| 676 |
+
args_list[2] = result.bpm # bpm
|
| 677 |
+
args_list[3] = result.keyscale # key_scale
|
| 678 |
+
args_list[4] = result.timesignature # time_signature
|
| 679 |
+
args_list[5] = result.language # vocal_language
|
| 680 |
+
if result.duration and result.duration > 0:
|
| 681 |
+
args_list[11] = result.duration # audio_duration
|
| 682 |
+
# Enable thinking for Simple mode
|
| 683 |
+
args_list[28] = True # think_checkbox
|
| 684 |
+
# Mark as formatted caption (LM-generated sample)
|
| 685 |
+
args_list[36] = True # is_format_caption_state
|
| 686 |
+
|
| 687 |
+
# Determine which handler to use
|
| 688 |
+
active_handler = dit_handler # Default to primary handler
|
| 689 |
+
if dit_handler_2 is not None and selected_model == config_path_2:
|
| 690 |
+
active_handler = dit_handler_2
|
| 691 |
+
yield from res_h.generate_with_batch_management(active_handler, llm_handler, *args_list)
|
| 692 |
+
|
| 693 |
+
# ========== Generation Handler ==========
|
| 694 |
+
generation_section["generate_btn"].click(
|
| 695 |
+
fn=generation_wrapper,
|
| 696 |
+
inputs=[
|
| 697 |
+
generation_section["dit_model_selector"], # Model selection input
|
| 698 |
+
generation_section["generation_mode"], # For Simple mode detection
|
| 699 |
+
generation_section["simple_query_input"], # Simple mode query
|
| 700 |
+
generation_section["simple_vocal_language"], # Simple mode vocal language
|
| 701 |
+
generation_section["captions"],
|
| 702 |
+
generation_section["lyrics"],
|
| 703 |
+
generation_section["bpm"],
|
| 704 |
+
generation_section["key_scale"],
|
| 705 |
+
generation_section["time_signature"],
|
| 706 |
+
generation_section["vocal_language"],
|
| 707 |
+
generation_section["inference_steps"],
|
| 708 |
+
generation_section["guidance_scale"],
|
| 709 |
+
generation_section["random_seed_checkbox"],
|
| 710 |
+
generation_section["seed"],
|
| 711 |
+
generation_section["reference_audio"],
|
| 712 |
+
generation_section["audio_duration"],
|
| 713 |
+
generation_section["batch_size_input"],
|
| 714 |
+
generation_section["src_audio"],
|
| 715 |
+
generation_section["text2music_audio_code_string"],
|
| 716 |
+
generation_section["repainting_start"],
|
| 717 |
+
generation_section["repainting_end"],
|
| 718 |
+
generation_section["instruction_display_gen"],
|
| 719 |
+
generation_section["audio_cover_strength"],
|
| 720 |
+
generation_section["task_type"],
|
| 721 |
+
generation_section["use_adg"],
|
| 722 |
+
generation_section["cfg_interval_start"],
|
| 723 |
+
generation_section["cfg_interval_end"],
|
| 724 |
+
generation_section["shift"],
|
| 725 |
+
generation_section["infer_method"],
|
| 726 |
+
generation_section["custom_timesteps"],
|
| 727 |
+
generation_section["audio_format"],
|
| 728 |
+
generation_section["lm_temperature"],
|
| 729 |
+
generation_section["think_checkbox"],
|
| 730 |
+
generation_section["lm_cfg_scale"],
|
| 731 |
+
generation_section["lm_top_k"],
|
| 732 |
+
generation_section["lm_top_p"],
|
| 733 |
+
generation_section["lm_negative_prompt"],
|
| 734 |
+
generation_section["use_cot_metas"],
|
| 735 |
+
generation_section["use_cot_caption"],
|
| 736 |
+
generation_section["use_cot_language"],
|
| 737 |
+
results_section["is_format_caption_state"],
|
| 738 |
+
generation_section["constrained_decoding_debug"],
|
| 739 |
+
generation_section["allow_lm_batch"],
|
| 740 |
+
generation_section["auto_score"],
|
| 741 |
+
generation_section["auto_lrc"],
|
| 742 |
+
generation_section["score_scale"],
|
| 743 |
+
generation_section["lm_batch_chunk_size"],
|
| 744 |
+
generation_section["track_name"],
|
| 745 |
+
generation_section["complete_track_classes"],
|
| 746 |
+
generation_section["autogen_checkbox"],
|
| 747 |
+
results_section["current_batch_index"],
|
| 748 |
+
results_section["total_batches"],
|
| 749 |
+
results_section["batch_queue"],
|
| 750 |
+
results_section["generation_params_state"],
|
| 751 |
+
],
|
| 752 |
+
outputs=[
|
| 753 |
+
results_section["generated_audio_1"],
|
| 754 |
+
results_section["generated_audio_2"],
|
| 755 |
+
results_section["generated_audio_3"],
|
| 756 |
+
results_section["generated_audio_4"],
|
| 757 |
+
results_section["generated_audio_5"],
|
| 758 |
+
results_section["generated_audio_6"],
|
| 759 |
+
results_section["generated_audio_7"],
|
| 760 |
+
results_section["generated_audio_8"],
|
| 761 |
+
results_section["generated_audio_batch"],
|
| 762 |
+
results_section["generation_info"],
|
| 763 |
+
results_section["status_output"],
|
| 764 |
+
generation_section["seed"],
|
| 765 |
+
results_section["score_display_1"],
|
| 766 |
+
results_section["score_display_2"],
|
| 767 |
+
results_section["score_display_3"],
|
| 768 |
+
results_section["score_display_4"],
|
| 769 |
+
results_section["score_display_5"],
|
| 770 |
+
results_section["score_display_6"],
|
| 771 |
+
results_section["score_display_7"],
|
| 772 |
+
results_section["score_display_8"],
|
| 773 |
+
results_section["codes_display_1"],
|
| 774 |
+
results_section["codes_display_2"],
|
| 775 |
+
results_section["codes_display_3"],
|
| 776 |
+
results_section["codes_display_4"],
|
| 777 |
+
results_section["codes_display_5"],
|
| 778 |
+
results_section["codes_display_6"],
|
| 779 |
+
results_section["codes_display_7"],
|
| 780 |
+
results_section["codes_display_8"],
|
| 781 |
+
results_section["details_accordion_1"],
|
| 782 |
+
results_section["details_accordion_2"],
|
| 783 |
+
results_section["details_accordion_3"],
|
| 784 |
+
results_section["details_accordion_4"],
|
| 785 |
+
results_section["details_accordion_5"],
|
| 786 |
+
results_section["details_accordion_6"],
|
| 787 |
+
results_section["details_accordion_7"],
|
| 788 |
+
results_section["details_accordion_8"],
|
| 789 |
+
results_section["lrc_display_1"],
|
| 790 |
+
results_section["lrc_display_2"],
|
| 791 |
+
results_section["lrc_display_3"],
|
| 792 |
+
results_section["lrc_display_4"],
|
| 793 |
+
results_section["lrc_display_5"],
|
| 794 |
+
results_section["lrc_display_6"],
|
| 795 |
+
results_section["lrc_display_7"],
|
| 796 |
+
results_section["lrc_display_8"],
|
| 797 |
+
results_section["lm_metadata_state"],
|
| 798 |
+
results_section["is_format_caption_state"],
|
| 799 |
+
results_section["current_batch_index"],
|
| 800 |
+
results_section["total_batches"],
|
| 801 |
+
results_section["batch_queue"],
|
| 802 |
+
results_section["generation_params_state"],
|
| 803 |
+
results_section["batch_indicator"],
|
| 804 |
+
results_section["prev_batch_btn"],
|
| 805 |
+
results_section["next_batch_btn"],
|
| 806 |
+
results_section["next_batch_status"],
|
| 807 |
+
results_section["restore_params_btn"],
|
| 808 |
+
]
|
| 809 |
+
).then(
|
| 810 |
+
fn=lambda selected_model, *args: res_h.generate_next_batch_background(
|
| 811 |
+
dit_handler_2 if (dit_handler_2 is not None and selected_model == config_path_2) else dit_handler,
|
| 812 |
+
llm_handler, *args
|
| 813 |
+
),
|
| 814 |
+
inputs=[
|
| 815 |
+
generation_section["dit_model_selector"], # Model selection input
|
| 816 |
+
generation_section["autogen_checkbox"],
|
| 817 |
+
results_section["generation_params_state"],
|
| 818 |
+
results_section["current_batch_index"],
|
| 819 |
+
results_section["total_batches"],
|
| 820 |
+
results_section["batch_queue"],
|
| 821 |
+
results_section["is_format_caption_state"],
|
| 822 |
+
],
|
| 823 |
+
outputs=[
|
| 824 |
+
results_section["batch_queue"],
|
| 825 |
+
results_section["total_batches"],
|
| 826 |
+
results_section["next_batch_status"],
|
| 827 |
+
results_section["next_batch_btn"],
|
| 828 |
+
]
|
| 829 |
+
)
|
| 830 |
+
|
| 831 |
+
# ========== Batch Navigation Handlers ==========
|
| 832 |
+
results_section["prev_batch_btn"].click(
|
| 833 |
+
fn=res_h.navigate_to_previous_batch,
|
| 834 |
+
inputs=[
|
| 835 |
+
results_section["current_batch_index"],
|
| 836 |
+
results_section["batch_queue"],
|
| 837 |
+
],
|
| 838 |
+
outputs=[
|
| 839 |
+
results_section["generated_audio_1"],
|
| 840 |
+
results_section["generated_audio_2"],
|
| 841 |
+
results_section["generated_audio_3"],
|
| 842 |
+
results_section["generated_audio_4"],
|
| 843 |
+
results_section["generated_audio_5"],
|
| 844 |
+
results_section["generated_audio_6"],
|
| 845 |
+
results_section["generated_audio_7"],
|
| 846 |
+
results_section["generated_audio_8"],
|
| 847 |
+
results_section["generated_audio_batch"],
|
| 848 |
+
results_section["generation_info"],
|
| 849 |
+
results_section["current_batch_index"],
|
| 850 |
+
results_section["batch_indicator"],
|
| 851 |
+
results_section["prev_batch_btn"],
|
| 852 |
+
results_section["next_batch_btn"],
|
| 853 |
+
results_section["status_output"],
|
| 854 |
+
results_section["score_display_1"],
|
| 855 |
+
results_section["score_display_2"],
|
| 856 |
+
results_section["score_display_3"],
|
| 857 |
+
results_section["score_display_4"],
|
| 858 |
+
results_section["score_display_5"],
|
| 859 |
+
results_section["score_display_6"],
|
| 860 |
+
results_section["score_display_7"],
|
| 861 |
+
results_section["score_display_8"],
|
| 862 |
+
results_section["codes_display_1"],
|
| 863 |
+
results_section["codes_display_2"],
|
| 864 |
+
results_section["codes_display_3"],
|
| 865 |
+
results_section["codes_display_4"],
|
| 866 |
+
results_section["codes_display_5"],
|
| 867 |
+
results_section["codes_display_6"],
|
| 868 |
+
results_section["codes_display_7"],
|
| 869 |
+
results_section["codes_display_8"],
|
| 870 |
+
results_section["lrc_display_1"],
|
| 871 |
+
results_section["lrc_display_2"],
|
| 872 |
+
results_section["lrc_display_3"],
|
| 873 |
+
results_section["lrc_display_4"],
|
| 874 |
+
results_section["lrc_display_5"],
|
| 875 |
+
results_section["lrc_display_6"],
|
| 876 |
+
results_section["lrc_display_7"],
|
| 877 |
+
results_section["lrc_display_8"],
|
| 878 |
+
results_section["details_accordion_1"],
|
| 879 |
+
results_section["details_accordion_2"],
|
| 880 |
+
results_section["details_accordion_3"],
|
| 881 |
+
results_section["details_accordion_4"],
|
| 882 |
+
results_section["details_accordion_5"],
|
| 883 |
+
results_section["details_accordion_6"],
|
| 884 |
+
results_section["details_accordion_7"],
|
| 885 |
+
results_section["details_accordion_8"],
|
| 886 |
+
results_section["restore_params_btn"],
|
| 887 |
+
]
|
| 888 |
+
)
|
| 889 |
+
|
| 890 |
+
results_section["next_batch_btn"].click(
|
| 891 |
+
fn=res_h.capture_current_params,
|
| 892 |
+
inputs=[
|
| 893 |
+
generation_section["captions"],
|
| 894 |
+
generation_section["lyrics"],
|
| 895 |
+
generation_section["bpm"],
|
| 896 |
+
generation_section["key_scale"],
|
| 897 |
+
generation_section["time_signature"],
|
| 898 |
+
generation_section["vocal_language"],
|
| 899 |
+
generation_section["inference_steps"],
|
| 900 |
+
generation_section["guidance_scale"],
|
| 901 |
+
generation_section["random_seed_checkbox"],
|
| 902 |
+
generation_section["seed"],
|
| 903 |
+
generation_section["reference_audio"],
|
| 904 |
+
generation_section["audio_duration"],
|
| 905 |
+
generation_section["batch_size_input"],
|
| 906 |
+
generation_section["src_audio"],
|
| 907 |
+
generation_section["text2music_audio_code_string"],
|
| 908 |
+
generation_section["repainting_start"],
|
| 909 |
+
generation_section["repainting_end"],
|
| 910 |
+
generation_section["instruction_display_gen"],
|
| 911 |
+
generation_section["audio_cover_strength"],
|
| 912 |
+
generation_section["task_type"],
|
| 913 |
+
generation_section["use_adg"],
|
| 914 |
+
generation_section["cfg_interval_start"],
|
| 915 |
+
generation_section["cfg_interval_end"],
|
| 916 |
+
generation_section["shift"],
|
| 917 |
+
generation_section["infer_method"],
|
| 918 |
+
generation_section["custom_timesteps"],
|
| 919 |
+
generation_section["audio_format"],
|
| 920 |
+
generation_section["lm_temperature"],
|
| 921 |
+
generation_section["think_checkbox"],
|
| 922 |
+
generation_section["lm_cfg_scale"],
|
| 923 |
+
generation_section["lm_top_k"],
|
| 924 |
+
generation_section["lm_top_p"],
|
| 925 |
+
generation_section["lm_negative_prompt"],
|
| 926 |
+
generation_section["use_cot_metas"],
|
| 927 |
+
generation_section["use_cot_caption"],
|
| 928 |
+
generation_section["use_cot_language"],
|
| 929 |
+
generation_section["constrained_decoding_debug"],
|
| 930 |
+
generation_section["allow_lm_batch"],
|
| 931 |
+
generation_section["auto_score"],
|
| 932 |
+
generation_section["auto_lrc"],
|
| 933 |
+
generation_section["score_scale"],
|
| 934 |
+
generation_section["lm_batch_chunk_size"],
|
| 935 |
+
generation_section["track_name"],
|
| 936 |
+
generation_section["complete_track_classes"],
|
| 937 |
+
],
|
| 938 |
+
outputs=[results_section["generation_params_state"]]
|
| 939 |
+
).then(
|
| 940 |
+
fn=res_h.navigate_to_next_batch,
|
| 941 |
+
inputs=[
|
| 942 |
+
generation_section["autogen_checkbox"],
|
| 943 |
+
results_section["current_batch_index"],
|
| 944 |
+
results_section["total_batches"],
|
| 945 |
+
results_section["batch_queue"],
|
| 946 |
+
],
|
| 947 |
+
outputs=[
|
| 948 |
+
results_section["generated_audio_1"],
|
| 949 |
+
results_section["generated_audio_2"],
|
| 950 |
+
results_section["generated_audio_3"],
|
| 951 |
+
results_section["generated_audio_4"],
|
| 952 |
+
results_section["generated_audio_5"],
|
| 953 |
+
results_section["generated_audio_6"],
|
| 954 |
+
results_section["generated_audio_7"],
|
| 955 |
+
results_section["generated_audio_8"],
|
| 956 |
+
results_section["generated_audio_batch"],
|
| 957 |
+
results_section["generation_info"],
|
| 958 |
+
results_section["current_batch_index"],
|
| 959 |
+
results_section["batch_indicator"],
|
| 960 |
+
results_section["prev_batch_btn"],
|
| 961 |
+
results_section["next_batch_btn"],
|
| 962 |
+
results_section["status_output"],
|
| 963 |
+
results_section["next_batch_status"],
|
| 964 |
+
results_section["score_display_1"],
|
| 965 |
+
results_section["score_display_2"],
|
| 966 |
+
results_section["score_display_3"],
|
| 967 |
+
results_section["score_display_4"],
|
| 968 |
+
results_section["score_display_5"],
|
| 969 |
+
results_section["score_display_6"],
|
| 970 |
+
results_section["score_display_7"],
|
| 971 |
+
results_section["score_display_8"],
|
| 972 |
+
results_section["codes_display_1"],
|
| 973 |
+
results_section["codes_display_2"],
|
| 974 |
+
results_section["codes_display_3"],
|
| 975 |
+
results_section["codes_display_4"],
|
| 976 |
+
results_section["codes_display_5"],
|
| 977 |
+
results_section["codes_display_6"],
|
| 978 |
+
results_section["codes_display_7"],
|
| 979 |
+
results_section["codes_display_8"],
|
| 980 |
+
results_section["lrc_display_1"],
|
| 981 |
+
results_section["lrc_display_2"],
|
| 982 |
+
results_section["lrc_display_3"],
|
| 983 |
+
results_section["lrc_display_4"],
|
| 984 |
+
results_section["lrc_display_5"],
|
| 985 |
+
results_section["lrc_display_6"],
|
| 986 |
+
results_section["lrc_display_7"],
|
| 987 |
+
results_section["lrc_display_8"],
|
| 988 |
+
results_section["details_accordion_1"],
|
| 989 |
+
results_section["details_accordion_2"],
|
| 990 |
+
results_section["details_accordion_3"],
|
| 991 |
+
results_section["details_accordion_4"],
|
| 992 |
+
results_section["details_accordion_5"],
|
| 993 |
+
results_section["details_accordion_6"],
|
| 994 |
+
results_section["details_accordion_7"],
|
| 995 |
+
results_section["details_accordion_8"],
|
| 996 |
+
results_section["restore_params_btn"],
|
| 997 |
+
]
|
| 998 |
+
).then(
|
| 999 |
+
fn=lambda selected_model, *args: res_h.generate_next_batch_background(
|
| 1000 |
+
dit_handler_2 if (dit_handler_2 is not None and selected_model == config_path_2) else dit_handler,
|
| 1001 |
+
llm_handler, *args
|
| 1002 |
+
),
|
| 1003 |
+
inputs=[
|
| 1004 |
+
generation_section["dit_model_selector"], # Model selection input
|
| 1005 |
+
generation_section["autogen_checkbox"],
|
| 1006 |
+
results_section["generation_params_state"],
|
| 1007 |
+
results_section["current_batch_index"],
|
| 1008 |
+
results_section["total_batches"],
|
| 1009 |
+
results_section["batch_queue"],
|
| 1010 |
+
results_section["is_format_caption_state"],
|
| 1011 |
+
],
|
| 1012 |
+
outputs=[
|
| 1013 |
+
results_section["batch_queue"],
|
| 1014 |
+
results_section["total_batches"],
|
| 1015 |
+
results_section["next_batch_status"],
|
| 1016 |
+
results_section["next_batch_btn"],
|
| 1017 |
+
]
|
| 1018 |
+
)
|
| 1019 |
+
|
| 1020 |
+
# ========== Restore Parameters Handler ==========
|
| 1021 |
+
results_section["restore_params_btn"].click(
|
| 1022 |
+
fn=res_h.restore_batch_parameters,
|
| 1023 |
+
inputs=[
|
| 1024 |
+
results_section["current_batch_index"],
|
| 1025 |
+
results_section["batch_queue"]
|
| 1026 |
+
],
|
| 1027 |
+
outputs=[
|
| 1028 |
+
generation_section["text2music_audio_code_string"],
|
| 1029 |
+
generation_section["captions"],
|
| 1030 |
+
generation_section["lyrics"],
|
| 1031 |
+
generation_section["bpm"],
|
| 1032 |
+
generation_section["key_scale"],
|
| 1033 |
+
generation_section["time_signature"],
|
| 1034 |
+
generation_section["vocal_language"],
|
| 1035 |
+
generation_section["audio_duration"],
|
| 1036 |
+
generation_section["batch_size_input"],
|
| 1037 |
+
generation_section["inference_steps"],
|
| 1038 |
+
generation_section["lm_temperature"],
|
| 1039 |
+
generation_section["lm_cfg_scale"],
|
| 1040 |
+
generation_section["lm_top_k"],
|
| 1041 |
+
generation_section["lm_top_p"],
|
| 1042 |
+
generation_section["think_checkbox"],
|
| 1043 |
+
generation_section["use_cot_caption"],
|
| 1044 |
+
generation_section["use_cot_language"],
|
| 1045 |
+
generation_section["allow_lm_batch"],
|
| 1046 |
+
generation_section["track_name"],
|
| 1047 |
+
generation_section["complete_track_classes"],
|
| 1048 |
+
]
|
| 1049 |
+
)
|
| 1050 |
+
|
| 1051 |
+
# ========== LRC Display Change Handlers ==========
|
| 1052 |
+
# NEW APPROACH: Use lrc_display.change() to update audio subtitles
|
| 1053 |
+
# This decouples audio value updates from subtitle updates, avoiding flickering.
|
| 1054 |
+
#
|
| 1055 |
+
# When lrc_display text changes (from generate, LRC button, or manual edit):
|
| 1056 |
+
# 1. lrc_display.change() is triggered
|
| 1057 |
+
# 2. update_audio_subtitles_from_lrc() parses LRC and updates audio subtitles
|
| 1058 |
+
# 3. Audio value is NEVER updated here - only subtitles
|
| 1059 |
+
for lrc_idx in range(1, 9):
|
| 1060 |
+
results_section[f"lrc_display_{lrc_idx}"].change(
|
| 1061 |
+
fn=res_h.update_audio_subtitles_from_lrc,
|
| 1062 |
+
inputs=[
|
| 1063 |
+
results_section[f"lrc_display_{lrc_idx}"],
|
| 1064 |
+
# audio_duration not needed - parse_lrc_to_subtitles calculates end time from timestamps
|
| 1065 |
+
],
|
| 1066 |
+
outputs=[
|
| 1067 |
+
results_section[f"generated_audio_{lrc_idx}"], # Only updates subtitles, not value
|
| 1068 |
+
]
|
| 1069 |
+
)
|
| 1070 |
+
|
| 1071 |
+
|
| 1072 |
+
def setup_training_event_handlers(demo, dit_handler, llm_handler, training_section):
|
| 1073 |
+
"""Setup event handlers for the training tab (dataset builder and LoRA training)"""
|
| 1074 |
+
|
| 1075 |
+
# ========== Load Existing Dataset (Top Section) ==========
|
| 1076 |
+
|
| 1077 |
+
# Load existing dataset JSON at the top of Dataset Builder
|
| 1078 |
+
training_section["load_json_btn"].click(
|
| 1079 |
+
fn=train_h.load_existing_dataset_for_preprocess,
|
| 1080 |
+
inputs=[
|
| 1081 |
+
training_section["load_json_path"],
|
| 1082 |
+
training_section["dataset_builder_state"],
|
| 1083 |
+
],
|
| 1084 |
+
outputs=[
|
| 1085 |
+
training_section["load_json_status"],
|
| 1086 |
+
training_section["audio_files_table"],
|
| 1087 |
+
training_section["sample_selector"],
|
| 1088 |
+
training_section["dataset_builder_state"],
|
| 1089 |
+
# Also update preview fields with first sample
|
| 1090 |
+
training_section["preview_audio"],
|
| 1091 |
+
training_section["preview_filename"],
|
| 1092 |
+
training_section["edit_caption"],
|
| 1093 |
+
training_section["edit_lyrics"],
|
| 1094 |
+
training_section["edit_bpm"],
|
| 1095 |
+
training_section["edit_keyscale"],
|
| 1096 |
+
training_section["edit_timesig"],
|
| 1097 |
+
training_section["edit_duration"],
|
| 1098 |
+
training_section["edit_language"],
|
| 1099 |
+
training_section["edit_instrumental"],
|
| 1100 |
+
]
|
| 1101 |
+
)
|
| 1102 |
+
|
| 1103 |
+
# ========== Dataset Builder Handlers ==========
|
| 1104 |
+
|
| 1105 |
+
# Scan directory for audio files
|
| 1106 |
+
training_section["scan_btn"].click(
|
| 1107 |
+
fn=lambda dir, name, tag, pos, instr, state: train_h.scan_directory(
|
| 1108 |
+
dir, name, tag, pos, instr, state
|
| 1109 |
+
),
|
| 1110 |
+
inputs=[
|
| 1111 |
+
training_section["audio_directory"],
|
| 1112 |
+
training_section["dataset_name"],
|
| 1113 |
+
training_section["custom_tag"],
|
| 1114 |
+
training_section["tag_position"],
|
| 1115 |
+
training_section["all_instrumental"],
|
| 1116 |
+
training_section["dataset_builder_state"],
|
| 1117 |
+
],
|
| 1118 |
+
outputs=[
|
| 1119 |
+
training_section["audio_files_table"],
|
| 1120 |
+
training_section["scan_status"],
|
| 1121 |
+
training_section["sample_selector"],
|
| 1122 |
+
training_section["dataset_builder_state"],
|
| 1123 |
+
]
|
| 1124 |
+
)
|
| 1125 |
+
|
| 1126 |
+
# Auto-label all samples
|
| 1127 |
+
training_section["auto_label_btn"].click(
|
| 1128 |
+
fn=lambda state, skip: train_h.auto_label_all(dit_handler, llm_handler, state, skip),
|
| 1129 |
+
inputs=[
|
| 1130 |
+
training_section["dataset_builder_state"],
|
| 1131 |
+
training_section["skip_metas"],
|
| 1132 |
+
],
|
| 1133 |
+
outputs=[
|
| 1134 |
+
training_section["audio_files_table"],
|
| 1135 |
+
training_section["label_progress"],
|
| 1136 |
+
training_section["dataset_builder_state"],
|
| 1137 |
+
]
|
| 1138 |
+
)
|
| 1139 |
+
|
| 1140 |
+
# Sample selector change - update preview
|
| 1141 |
+
training_section["sample_selector"].change(
|
| 1142 |
+
fn=train_h.get_sample_preview,
|
| 1143 |
+
inputs=[
|
| 1144 |
+
training_section["sample_selector"],
|
| 1145 |
+
training_section["dataset_builder_state"],
|
| 1146 |
+
],
|
| 1147 |
+
outputs=[
|
| 1148 |
+
training_section["preview_audio"],
|
| 1149 |
+
training_section["preview_filename"],
|
| 1150 |
+
training_section["edit_caption"],
|
| 1151 |
+
training_section["edit_lyrics"],
|
| 1152 |
+
training_section["edit_bpm"],
|
| 1153 |
+
training_section["edit_keyscale"],
|
| 1154 |
+
training_section["edit_timesig"],
|
| 1155 |
+
training_section["edit_duration"],
|
| 1156 |
+
training_section["edit_language"],
|
| 1157 |
+
training_section["edit_instrumental"],
|
| 1158 |
+
]
|
| 1159 |
+
)
|
| 1160 |
+
|
| 1161 |
+
# Save sample edit
|
| 1162 |
+
training_section["save_edit_btn"].click(
|
| 1163 |
+
fn=train_h.save_sample_edit,
|
| 1164 |
+
inputs=[
|
| 1165 |
+
training_section["sample_selector"],
|
| 1166 |
+
training_section["edit_caption"],
|
| 1167 |
+
training_section["edit_lyrics"],
|
| 1168 |
+
training_section["edit_bpm"],
|
| 1169 |
+
training_section["edit_keyscale"],
|
| 1170 |
+
training_section["edit_timesig"],
|
| 1171 |
+
training_section["edit_language"],
|
| 1172 |
+
training_section["edit_instrumental"],
|
| 1173 |
+
training_section["dataset_builder_state"],
|
| 1174 |
+
],
|
| 1175 |
+
outputs=[
|
| 1176 |
+
training_section["audio_files_table"],
|
| 1177 |
+
training_section["edit_status"],
|
| 1178 |
+
training_section["dataset_builder_state"],
|
| 1179 |
+
]
|
| 1180 |
+
)
|
| 1181 |
+
|
| 1182 |
+
# Update settings when changed
|
| 1183 |
+
for trigger in [training_section["custom_tag"], training_section["tag_position"], training_section["all_instrumental"]]:
|
| 1184 |
+
trigger.change(
|
| 1185 |
+
fn=train_h.update_settings,
|
| 1186 |
+
inputs=[
|
| 1187 |
+
training_section["custom_tag"],
|
| 1188 |
+
training_section["tag_position"],
|
| 1189 |
+
training_section["all_instrumental"],
|
| 1190 |
+
training_section["dataset_builder_state"],
|
| 1191 |
+
],
|
| 1192 |
+
outputs=[training_section["dataset_builder_state"]]
|
| 1193 |
+
)
|
| 1194 |
+
|
| 1195 |
+
# Save dataset
|
| 1196 |
+
training_section["save_dataset_btn"].click(
|
| 1197 |
+
fn=train_h.save_dataset,
|
| 1198 |
+
inputs=[
|
| 1199 |
+
training_section["save_path"],
|
| 1200 |
+
training_section["dataset_name"],
|
| 1201 |
+
training_section["dataset_builder_state"],
|
| 1202 |
+
],
|
| 1203 |
+
outputs=[training_section["save_status"]]
|
| 1204 |
+
)
|
| 1205 |
+
|
| 1206 |
+
# ========== Preprocess Handlers ==========
|
| 1207 |
+
|
| 1208 |
+
# Load existing dataset JSON for preprocessing
|
| 1209 |
+
# This also updates the preview section so users can view/edit samples
|
| 1210 |
+
training_section["load_existing_dataset_btn"].click(
|
| 1211 |
+
fn=train_h.load_existing_dataset_for_preprocess,
|
| 1212 |
+
inputs=[
|
| 1213 |
+
training_section["load_existing_dataset_path"],
|
| 1214 |
+
training_section["dataset_builder_state"],
|
| 1215 |
+
],
|
| 1216 |
+
outputs=[
|
| 1217 |
+
training_section["load_existing_status"],
|
| 1218 |
+
training_section["audio_files_table"],
|
| 1219 |
+
training_section["sample_selector"],
|
| 1220 |
+
training_section["dataset_builder_state"],
|
| 1221 |
+
# Also update preview fields with first sample
|
| 1222 |
+
training_section["preview_audio"],
|
| 1223 |
+
training_section["preview_filename"],
|
| 1224 |
+
training_section["edit_caption"],
|
| 1225 |
+
training_section["edit_lyrics"],
|
| 1226 |
+
training_section["edit_bpm"],
|
| 1227 |
+
training_section["edit_keyscale"],
|
| 1228 |
+
training_section["edit_timesig"],
|
| 1229 |
+
training_section["edit_duration"],
|
| 1230 |
+
training_section["edit_language"],
|
| 1231 |
+
training_section["edit_instrumental"],
|
| 1232 |
+
]
|
| 1233 |
+
)
|
| 1234 |
+
|
| 1235 |
+
# Preprocess dataset to tensor files
|
| 1236 |
+
training_section["preprocess_btn"].click(
|
| 1237 |
+
fn=lambda output_dir, state: train_h.preprocess_dataset(
|
| 1238 |
+
output_dir, dit_handler, state
|
| 1239 |
+
),
|
| 1240 |
+
inputs=[
|
| 1241 |
+
training_section["preprocess_output_dir"],
|
| 1242 |
+
training_section["dataset_builder_state"],
|
| 1243 |
+
],
|
| 1244 |
+
outputs=[training_section["preprocess_progress"]]
|
| 1245 |
+
)
|
| 1246 |
+
|
| 1247 |
+
# ========== Training Tab Handlers ==========
|
| 1248 |
+
|
| 1249 |
+
# Load preprocessed tensor dataset
|
| 1250 |
+
training_section["load_dataset_btn"].click(
|
| 1251 |
+
fn=train_h.load_training_dataset,
|
| 1252 |
+
inputs=[training_section["training_tensor_dir"]],
|
| 1253 |
+
outputs=[training_section["training_dataset_info"]]
|
| 1254 |
+
)
|
| 1255 |
+
|
| 1256 |
+
# Start training from preprocessed tensors
|
| 1257 |
+
def training_wrapper(tensor_dir, r, a, d, lr, ep, bs, ga, se, sh, sd, od, ts):
|
| 1258 |
+
try:
|
| 1259 |
+
for progress, log, plot, state in train_h.start_training(
|
| 1260 |
+
tensor_dir, dit_handler, r, a, d, lr, ep, bs, ga, se, sh, sd, od, ts
|
| 1261 |
+
):
|
| 1262 |
+
yield progress, log, plot, state
|
| 1263 |
+
except Exception as e:
|
| 1264 |
+
logger.exception("Training wrapper error")
|
| 1265 |
+
yield f"❌ Error: {str(e)}", str(e), None, ts
|
| 1266 |
+
|
| 1267 |
+
training_section["start_training_btn"].click(
|
| 1268 |
+
fn=training_wrapper,
|
| 1269 |
+
inputs=[
|
| 1270 |
+
training_section["training_tensor_dir"],
|
| 1271 |
+
training_section["lora_rank"],
|
| 1272 |
+
training_section["lora_alpha"],
|
| 1273 |
+
training_section["lora_dropout"],
|
| 1274 |
+
training_section["learning_rate"],
|
| 1275 |
+
training_section["train_epochs"],
|
| 1276 |
+
training_section["train_batch_size"],
|
| 1277 |
+
training_section["gradient_accumulation"],
|
| 1278 |
+
training_section["save_every_n_epochs"],
|
| 1279 |
+
training_section["training_shift"],
|
| 1280 |
+
training_section["training_seed"],
|
| 1281 |
+
training_section["lora_output_dir"],
|
| 1282 |
+
training_section["training_state"],
|
| 1283 |
+
],
|
| 1284 |
+
outputs=[
|
| 1285 |
+
training_section["training_progress"],
|
| 1286 |
+
training_section["training_log"],
|
| 1287 |
+
training_section["training_loss_plot"],
|
| 1288 |
+
training_section["training_state"],
|
| 1289 |
+
]
|
| 1290 |
+
)
|
| 1291 |
+
|
| 1292 |
+
# Stop training
|
| 1293 |
+
training_section["stop_training_btn"].click(
|
| 1294 |
+
fn=train_h.stop_training,
|
| 1295 |
+
inputs=[training_section["training_state"]],
|
| 1296 |
+
outputs=[
|
| 1297 |
+
training_section["training_progress"],
|
| 1298 |
+
training_section["training_state"],
|
| 1299 |
+
]
|
| 1300 |
+
)
|
| 1301 |
+
|
| 1302 |
+
# Export LoRA
|
| 1303 |
+
training_section["export_lora_btn"].click(
|
| 1304 |
+
fn=train_h.export_lora,
|
| 1305 |
+
inputs=[
|
| 1306 |
+
training_section["export_path"],
|
| 1307 |
+
training_section["lora_output_dir"],
|
| 1308 |
+
],
|
| 1309 |
+
outputs=[training_section["export_status"]]
|
| 1310 |
+
)
|
spaces/Ace-Step-v1.5/acestep/gradio_ui/events/generation_handlers.py
ADDED
|
@@ -0,0 +1,1054 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Generation Input Handlers Module
|
| 3 |
+
Contains event handlers and helper functions related to generation inputs
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import random
|
| 8 |
+
import glob
|
| 9 |
+
import gradio as gr
|
| 10 |
+
from typing import Optional, List, Tuple
|
| 11 |
+
from acestep.constants import (
|
| 12 |
+
TASK_TYPES_TURBO,
|
| 13 |
+
TASK_TYPES_BASE,
|
| 14 |
+
)
|
| 15 |
+
from acestep.gradio_ui.i18n import t
|
| 16 |
+
from acestep.inference import understand_music, create_sample, format_sample
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def parse_and_validate_timesteps(
|
| 20 |
+
timesteps_str: str,
|
| 21 |
+
inference_steps: int
|
| 22 |
+
) -> Tuple[Optional[List[float]], bool, str]:
|
| 23 |
+
"""
|
| 24 |
+
Parse timesteps string and validate.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
timesteps_str: Comma-separated timesteps string (e.g., "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
|
| 28 |
+
inference_steps: Expected number of inference steps
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
Tuple of (parsed_timesteps, has_warning, warning_message)
|
| 32 |
+
- parsed_timesteps: List of float timesteps, or None if invalid/empty
|
| 33 |
+
- has_warning: Whether a warning was shown
|
| 34 |
+
- warning_message: Description of the warning
|
| 35 |
+
"""
|
| 36 |
+
if not timesteps_str or not timesteps_str.strip():
|
| 37 |
+
return None, False, ""
|
| 38 |
+
|
| 39 |
+
# Parse comma-separated values
|
| 40 |
+
values = [v.strip() for v in timesteps_str.split(",") if v.strip()]
|
| 41 |
+
|
| 42 |
+
if not values:
|
| 43 |
+
return None, False, ""
|
| 44 |
+
|
| 45 |
+
# Handle optional trailing 0
|
| 46 |
+
if values[-1] != "0":
|
| 47 |
+
values.append("0")
|
| 48 |
+
|
| 49 |
+
try:
|
| 50 |
+
timesteps = [float(v) for v in values]
|
| 51 |
+
except ValueError:
|
| 52 |
+
gr.Warning(t("messages.invalid_timesteps_format"))
|
| 53 |
+
return None, True, "Invalid format"
|
| 54 |
+
|
| 55 |
+
# Validate range [0, 1]
|
| 56 |
+
if any(ts < 0 or ts > 1 for ts in timesteps):
|
| 57 |
+
gr.Warning(t("messages.timesteps_out_of_range"))
|
| 58 |
+
return None, True, "Out of range"
|
| 59 |
+
|
| 60 |
+
# Check if count matches inference_steps
|
| 61 |
+
actual_steps = len(timesteps) - 1
|
| 62 |
+
if actual_steps != inference_steps:
|
| 63 |
+
gr.Warning(t("messages.timesteps_count_mismatch", actual=actual_steps, expected=inference_steps))
|
| 64 |
+
return timesteps, True, f"Using {actual_steps} steps from timesteps"
|
| 65 |
+
|
| 66 |
+
return timesteps, False, ""
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def load_metadata(file_obj):
|
| 70 |
+
"""Load generation parameters from a JSON file"""
|
| 71 |
+
if file_obj is None:
|
| 72 |
+
gr.Warning(t("messages.no_file_selected"))
|
| 73 |
+
return [None] * 36 + [False] # Return None for all fields, False for is_format_caption
|
| 74 |
+
|
| 75 |
+
try:
|
| 76 |
+
# Read the uploaded file
|
| 77 |
+
if hasattr(file_obj, 'name'):
|
| 78 |
+
filepath = file_obj.name
|
| 79 |
+
else:
|
| 80 |
+
filepath = file_obj
|
| 81 |
+
|
| 82 |
+
with open(filepath, 'r', encoding='utf-8') as f:
|
| 83 |
+
metadata = json.load(f)
|
| 84 |
+
|
| 85 |
+
# Extract all fields
|
| 86 |
+
task_type = metadata.get('task_type', 'text2music')
|
| 87 |
+
captions = metadata.get('caption', '')
|
| 88 |
+
lyrics = metadata.get('lyrics', '')
|
| 89 |
+
vocal_language = metadata.get('vocal_language', 'unknown')
|
| 90 |
+
|
| 91 |
+
# Convert bpm
|
| 92 |
+
bpm_value = metadata.get('bpm')
|
| 93 |
+
if bpm_value is not None and bpm_value != "N/A":
|
| 94 |
+
try:
|
| 95 |
+
bpm = int(bpm_value) if bpm_value else None
|
| 96 |
+
except:
|
| 97 |
+
bpm = None
|
| 98 |
+
else:
|
| 99 |
+
bpm = None
|
| 100 |
+
|
| 101 |
+
key_scale = metadata.get('keyscale', '')
|
| 102 |
+
time_signature = metadata.get('timesignature', '')
|
| 103 |
+
|
| 104 |
+
# Convert duration
|
| 105 |
+
duration_value = metadata.get('duration', -1)
|
| 106 |
+
if duration_value is not None and duration_value != "N/A":
|
| 107 |
+
try:
|
| 108 |
+
audio_duration = float(duration_value)
|
| 109 |
+
except:
|
| 110 |
+
audio_duration = -1
|
| 111 |
+
else:
|
| 112 |
+
audio_duration = -1
|
| 113 |
+
|
| 114 |
+
batch_size = metadata.get('batch_size', 2)
|
| 115 |
+
inference_steps = metadata.get('inference_steps', 8)
|
| 116 |
+
guidance_scale = metadata.get('guidance_scale', 7.0)
|
| 117 |
+
seed = metadata.get('seed', '-1')
|
| 118 |
+
random_seed = False # Always set to False when loading to enable reproducibility with saved seed
|
| 119 |
+
use_adg = metadata.get('use_adg', False)
|
| 120 |
+
cfg_interval_start = metadata.get('cfg_interval_start', 0.0)
|
| 121 |
+
cfg_interval_end = metadata.get('cfg_interval_end', 1.0)
|
| 122 |
+
audio_format = metadata.get('audio_format', 'mp3')
|
| 123 |
+
lm_temperature = metadata.get('lm_temperature', 0.85)
|
| 124 |
+
lm_cfg_scale = metadata.get('lm_cfg_scale', 2.0)
|
| 125 |
+
lm_top_k = metadata.get('lm_top_k', 0)
|
| 126 |
+
lm_top_p = metadata.get('lm_top_p', 0.9)
|
| 127 |
+
lm_negative_prompt = metadata.get('lm_negative_prompt', 'NO USER INPUT')
|
| 128 |
+
use_cot_metas = metadata.get('use_cot_metas', True) # Added: read use_cot_metas
|
| 129 |
+
use_cot_caption = metadata.get('use_cot_caption', True)
|
| 130 |
+
use_cot_language = metadata.get('use_cot_language', True)
|
| 131 |
+
audio_cover_strength = metadata.get('audio_cover_strength', 1.0)
|
| 132 |
+
think = metadata.get('thinking', True) # Fixed: read 'thinking' not 'think'
|
| 133 |
+
audio_codes = metadata.get('audio_codes', '')
|
| 134 |
+
repainting_start = metadata.get('repainting_start', 0.0)
|
| 135 |
+
repainting_end = metadata.get('repainting_end', -1)
|
| 136 |
+
track_name = metadata.get('track_name')
|
| 137 |
+
complete_track_classes = metadata.get('complete_track_classes', [])
|
| 138 |
+
shift = metadata.get('shift', 3.0) # Default 3.0 for base models
|
| 139 |
+
infer_method = metadata.get('infer_method', 'ode') # Default 'ode' for diffusion inference
|
| 140 |
+
custom_timesteps = metadata.get('timesteps', '') # Custom timesteps (stored as 'timesteps' in JSON)
|
| 141 |
+
if custom_timesteps is None:
|
| 142 |
+
custom_timesteps = ''
|
| 143 |
+
instrumental = metadata.get('instrumental', False) # Added: read instrumental
|
| 144 |
+
|
| 145 |
+
gr.Info(t("messages.params_loaded", filename=os.path.basename(filepath)))
|
| 146 |
+
|
| 147 |
+
return (
|
| 148 |
+
task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature,
|
| 149 |
+
audio_duration, batch_size, inference_steps, guidance_scale, seed, random_seed,
|
| 150 |
+
use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method,
|
| 151 |
+
custom_timesteps, # Added: custom_timesteps (between infer_method and audio_format)
|
| 152 |
+
audio_format, lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
|
| 153 |
+
use_cot_metas, use_cot_caption, use_cot_language, audio_cover_strength,
|
| 154 |
+
think, audio_codes, repainting_start, repainting_end,
|
| 155 |
+
track_name, complete_track_classes, instrumental,
|
| 156 |
+
True # Set is_format_caption to True when loading from file
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
except json.JSONDecodeError as e:
|
| 160 |
+
gr.Warning(t("messages.invalid_json", error=str(e)))
|
| 161 |
+
return [None] * 36 + [False]
|
| 162 |
+
except Exception as e:
|
| 163 |
+
gr.Warning(t("messages.load_error", error=str(e)))
|
| 164 |
+
return [None] * 36 + [False]
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def load_random_example(task_type: str):
|
| 168 |
+
"""Load a random example from the task-specific examples directory
|
| 169 |
+
|
| 170 |
+
Args:
|
| 171 |
+
task_type: The task type (e.g., "text2music")
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
Tuple of (caption, lyrics, think, bpm, duration, keyscale, language, timesignature) for updating UI components
|
| 175 |
+
"""
|
| 176 |
+
try:
|
| 177 |
+
# Get the project root directory
|
| 178 |
+
current_file = os.path.abspath(__file__)
|
| 179 |
+
# This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
|
| 180 |
+
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
|
| 181 |
+
|
| 182 |
+
# Construct the examples directory path
|
| 183 |
+
examples_dir = os.path.join(project_root, "examples", task_type)
|
| 184 |
+
|
| 185 |
+
# Check if directory exists
|
| 186 |
+
if not os.path.exists(examples_dir):
|
| 187 |
+
gr.Warning(f"Examples directory not found: examples/{task_type}/")
|
| 188 |
+
return "", "", True, None, None, "", "", ""
|
| 189 |
+
|
| 190 |
+
# Find all JSON files in the directory
|
| 191 |
+
json_files = glob.glob(os.path.join(examples_dir, "*.json"))
|
| 192 |
+
|
| 193 |
+
if not json_files:
|
| 194 |
+
gr.Warning(f"No JSON files found in examples/{task_type}/")
|
| 195 |
+
return "", "", True, None, None, "", "", ""
|
| 196 |
+
|
| 197 |
+
# Randomly select one file
|
| 198 |
+
selected_file = random.choice(json_files)
|
| 199 |
+
|
| 200 |
+
# Read and parse JSON
|
| 201 |
+
try:
|
| 202 |
+
with open(selected_file, 'r', encoding='utf-8') as f:
|
| 203 |
+
data = json.load(f)
|
| 204 |
+
|
| 205 |
+
# Extract caption (prefer 'caption', fallback to 'prompt')
|
| 206 |
+
caption_value = data.get('caption', data.get('prompt', ''))
|
| 207 |
+
if not isinstance(caption_value, str):
|
| 208 |
+
caption_value = str(caption_value) if caption_value else ''
|
| 209 |
+
|
| 210 |
+
# Extract lyrics
|
| 211 |
+
lyrics_value = data.get('lyrics', '')
|
| 212 |
+
if not isinstance(lyrics_value, str):
|
| 213 |
+
lyrics_value = str(lyrics_value) if lyrics_value else ''
|
| 214 |
+
|
| 215 |
+
# Extract think (default to True if not present)
|
| 216 |
+
think_value = data.get('think', True)
|
| 217 |
+
if not isinstance(think_value, bool):
|
| 218 |
+
think_value = True
|
| 219 |
+
|
| 220 |
+
# Extract optional metadata fields
|
| 221 |
+
bpm_value = None
|
| 222 |
+
if 'bpm' in data and data['bpm'] not in [None, "N/A", ""]:
|
| 223 |
+
try:
|
| 224 |
+
bpm_value = int(data['bpm'])
|
| 225 |
+
except (ValueError, TypeError):
|
| 226 |
+
pass
|
| 227 |
+
|
| 228 |
+
duration_value = None
|
| 229 |
+
if 'duration' in data and data['duration'] not in [None, "N/A", ""]:
|
| 230 |
+
try:
|
| 231 |
+
duration_value = float(data['duration'])
|
| 232 |
+
except (ValueError, TypeError):
|
| 233 |
+
pass
|
| 234 |
+
|
| 235 |
+
keyscale_value = data.get('keyscale', '')
|
| 236 |
+
if keyscale_value in [None, "N/A"]:
|
| 237 |
+
keyscale_value = ''
|
| 238 |
+
|
| 239 |
+
language_value = data.get('language', '')
|
| 240 |
+
if language_value in [None, "N/A"]:
|
| 241 |
+
language_value = ''
|
| 242 |
+
|
| 243 |
+
timesignature_value = data.get('timesignature', '')
|
| 244 |
+
if timesignature_value in [None, "N/A"]:
|
| 245 |
+
timesignature_value = ''
|
| 246 |
+
|
| 247 |
+
gr.Info(t("messages.example_loaded", filename=os.path.basename(selected_file)))
|
| 248 |
+
return caption_value, lyrics_value, think_value, bpm_value, duration_value, keyscale_value, language_value, timesignature_value
|
| 249 |
+
|
| 250 |
+
except json.JSONDecodeError as e:
|
| 251 |
+
gr.Warning(t("messages.example_failed", filename=os.path.basename(selected_file), error=str(e)))
|
| 252 |
+
return "", "", True, None, None, "", "", ""
|
| 253 |
+
except Exception as e:
|
| 254 |
+
gr.Warning(t("messages.example_error", error=str(e)))
|
| 255 |
+
return "", "", True, None, None, "", "", ""
|
| 256 |
+
|
| 257 |
+
except Exception as e:
|
| 258 |
+
gr.Warning(t("messages.example_error", error=str(e)))
|
| 259 |
+
return "", "", True, None, None, "", "", ""
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def sample_example_smart(llm_handler, task_type: str, constrained_decoding_debug: bool = False):
|
| 263 |
+
"""Smart sample function that uses LM if initialized, otherwise falls back to examples
|
| 264 |
+
|
| 265 |
+
This is a Gradio wrapper that uses the understand_music API from acestep.inference
|
| 266 |
+
to generate examples when LM is available.
|
| 267 |
+
|
| 268 |
+
Args:
|
| 269 |
+
llm_handler: LLM handler instance
|
| 270 |
+
task_type: The task type (e.g., "text2music")
|
| 271 |
+
constrained_decoding_debug: Whether to enable debug logging for constrained decoding
|
| 272 |
+
|
| 273 |
+
Returns:
|
| 274 |
+
Tuple of (caption, lyrics, think, bpm, duration, keyscale, language, timesignature) for updating UI components
|
| 275 |
+
"""
|
| 276 |
+
# Check if LM is initialized
|
| 277 |
+
if llm_handler.llm_initialized:
|
| 278 |
+
# Use LM to generate example via understand_music API
|
| 279 |
+
try:
|
| 280 |
+
result = understand_music(
|
| 281 |
+
llm_handler=llm_handler,
|
| 282 |
+
audio_codes="NO USER INPUT", # Empty input triggers example generation
|
| 283 |
+
temperature=0.85,
|
| 284 |
+
use_constrained_decoding=True,
|
| 285 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
if result.success:
|
| 289 |
+
gr.Info(t("messages.lm_generated"))
|
| 290 |
+
return (
|
| 291 |
+
result.caption,
|
| 292 |
+
result.lyrics,
|
| 293 |
+
True, # Always enable think when using LM-generated examples
|
| 294 |
+
result.bpm,
|
| 295 |
+
result.duration,
|
| 296 |
+
result.keyscale,
|
| 297 |
+
result.language,
|
| 298 |
+
result.timesignature,
|
| 299 |
+
)
|
| 300 |
+
else:
|
| 301 |
+
gr.Warning(t("messages.lm_fallback"))
|
| 302 |
+
return load_random_example(task_type)
|
| 303 |
+
|
| 304 |
+
except Exception as e:
|
| 305 |
+
gr.Warning(t("messages.lm_fallback"))
|
| 306 |
+
return load_random_example(task_type)
|
| 307 |
+
else:
|
| 308 |
+
# LM not initialized, use examples directory
|
| 309 |
+
return load_random_example(task_type)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def load_random_simple_description():
|
| 313 |
+
"""Load a random description from the simple_mode examples directory.
|
| 314 |
+
|
| 315 |
+
Returns:
|
| 316 |
+
Tuple of (description, instrumental, vocal_language) for updating UI components
|
| 317 |
+
"""
|
| 318 |
+
try:
|
| 319 |
+
# Get the project root directory
|
| 320 |
+
current_file = os.path.abspath(__file__)
|
| 321 |
+
# This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
|
| 322 |
+
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
|
| 323 |
+
|
| 324 |
+
# Construct the examples directory path
|
| 325 |
+
examples_dir = os.path.join(project_root, "examples", "simple_mode")
|
| 326 |
+
|
| 327 |
+
# Check if directory exists
|
| 328 |
+
if not os.path.exists(examples_dir):
|
| 329 |
+
gr.Warning(t("messages.simple_examples_not_found"))
|
| 330 |
+
return gr.update(), gr.update(), gr.update()
|
| 331 |
+
|
| 332 |
+
# Find all JSON files in the directory
|
| 333 |
+
json_files = glob.glob(os.path.join(examples_dir, "*.json"))
|
| 334 |
+
|
| 335 |
+
if not json_files:
|
| 336 |
+
gr.Warning(t("messages.simple_examples_empty"))
|
| 337 |
+
return gr.update(), gr.update(), gr.update()
|
| 338 |
+
|
| 339 |
+
# Randomly select one file
|
| 340 |
+
selected_file = random.choice(json_files)
|
| 341 |
+
|
| 342 |
+
# Read and parse JSON
|
| 343 |
+
try:
|
| 344 |
+
with open(selected_file, 'r', encoding='utf-8') as f:
|
| 345 |
+
data = json.load(f)
|
| 346 |
+
|
| 347 |
+
# Extract fields
|
| 348 |
+
description = data.get('description', '')
|
| 349 |
+
instrumental = data.get('instrumental', False)
|
| 350 |
+
vocal_language = data.get('vocal_language', 'unknown')
|
| 351 |
+
|
| 352 |
+
# Ensure vocal_language is a string
|
| 353 |
+
if isinstance(vocal_language, list):
|
| 354 |
+
vocal_language = vocal_language[0] if vocal_language else 'unknown'
|
| 355 |
+
|
| 356 |
+
gr.Info(t("messages.simple_example_loaded", filename=os.path.basename(selected_file)))
|
| 357 |
+
return description, instrumental, vocal_language
|
| 358 |
+
|
| 359 |
+
except json.JSONDecodeError as e:
|
| 360 |
+
gr.Warning(t("messages.example_failed", filename=os.path.basename(selected_file), error=str(e)))
|
| 361 |
+
return gr.update(), gr.update(), gr.update()
|
| 362 |
+
except Exception as e:
|
| 363 |
+
gr.Warning(t("messages.example_error", error=str(e)))
|
| 364 |
+
return gr.update(), gr.update(), gr.update()
|
| 365 |
+
|
| 366 |
+
except Exception as e:
|
| 367 |
+
gr.Warning(t("messages.example_error", error=str(e)))
|
| 368 |
+
return gr.update(), gr.update(), gr.update()
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
def refresh_checkpoints(dit_handler):
|
| 372 |
+
"""Refresh available checkpoints"""
|
| 373 |
+
choices = dit_handler.get_available_checkpoints()
|
| 374 |
+
return gr.update(choices=choices)
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def update_model_type_settings(config_path):
|
| 378 |
+
"""Update UI settings based on model type (fallback when handler not initialized yet)
|
| 379 |
+
|
| 380 |
+
Note: This is used as a fallback when the user changes config_path dropdown
|
| 381 |
+
before initializing the model. The actual settings are determined by the
|
| 382 |
+
handler's is_turbo_model() method after initialization.
|
| 383 |
+
"""
|
| 384 |
+
if config_path is None:
|
| 385 |
+
config_path = ""
|
| 386 |
+
config_path_lower = config_path.lower()
|
| 387 |
+
|
| 388 |
+
# Determine is_turbo based on config_path string
|
| 389 |
+
# This is a heuristic fallback - actual model type is determined after loading
|
| 390 |
+
if "turbo" in config_path_lower:
|
| 391 |
+
is_turbo = True
|
| 392 |
+
elif "base" in config_path_lower:
|
| 393 |
+
is_turbo = False
|
| 394 |
+
else:
|
| 395 |
+
# Default to turbo settings for unknown model types
|
| 396 |
+
is_turbo = True
|
| 397 |
+
|
| 398 |
+
return get_model_type_ui_settings(is_turbo)
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
|
| 402 |
+
"""Wrapper for service initialization, returns status, button state, accordion state, and model type settings"""
|
| 403 |
+
# Initialize DiT handler
|
| 404 |
+
status, enable = dit_handler.initialize_service(
|
| 405 |
+
checkpoint, config_path, device,
|
| 406 |
+
use_flash_attention=use_flash_attention, compile_model=False,
|
| 407 |
+
offload_to_cpu=offload_to_cpu, offload_dit_to_cpu=offload_dit_to_cpu
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
# Initialize LM handler if requested
|
| 411 |
+
if init_llm:
|
| 412 |
+
# Get checkpoint directory
|
| 413 |
+
current_file = os.path.abspath(__file__)
|
| 414 |
+
# This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
|
| 415 |
+
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
|
| 416 |
+
checkpoint_dir = os.path.join(project_root, "checkpoints")
|
| 417 |
+
|
| 418 |
+
lm_status, lm_success = llm_handler.initialize(
|
| 419 |
+
checkpoint_dir=checkpoint_dir,
|
| 420 |
+
lm_model_path=lm_model_path,
|
| 421 |
+
backend=backend,
|
| 422 |
+
device=device,
|
| 423 |
+
offload_to_cpu=offload_to_cpu,
|
| 424 |
+
dtype=dit_handler.dtype
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
if lm_success:
|
| 428 |
+
status += f"\n{lm_status}"
|
| 429 |
+
else:
|
| 430 |
+
status += f"\n{lm_status}"
|
| 431 |
+
# Don't fail the entire initialization if LM fails, but log it
|
| 432 |
+
# Keep enable as is (DiT initialization result) even if LM fails
|
| 433 |
+
|
| 434 |
+
# Check if model is initialized - if so, collapse the accordion
|
| 435 |
+
is_model_initialized = dit_handler.model is not None
|
| 436 |
+
accordion_state = gr.Accordion(open=not is_model_initialized)
|
| 437 |
+
|
| 438 |
+
# Get model type settings based on actual loaded model
|
| 439 |
+
is_turbo = dit_handler.is_turbo_model()
|
| 440 |
+
model_type_settings = get_model_type_ui_settings(is_turbo)
|
| 441 |
+
|
| 442 |
+
return (
|
| 443 |
+
status,
|
| 444 |
+
gr.update(interactive=enable),
|
| 445 |
+
accordion_state,
|
| 446 |
+
*model_type_settings
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def get_model_type_ui_settings(is_turbo: bool):
|
| 451 |
+
"""Get UI settings based on whether the model is turbo or base"""
|
| 452 |
+
if is_turbo:
|
| 453 |
+
# Turbo model: max 20 steps, default 8, show shift with default 3.0, only show text2music/repaint/cover
|
| 454 |
+
return (
|
| 455 |
+
gr.update(value=8, maximum=20, minimum=1), # inference_steps
|
| 456 |
+
gr.update(visible=False), # guidance_scale
|
| 457 |
+
gr.update(visible=False), # use_adg
|
| 458 |
+
gr.update(value=3.0, visible=True), # shift (show with default 3.0)
|
| 459 |
+
gr.update(visible=False), # cfg_interval_start
|
| 460 |
+
gr.update(visible=False), # cfg_interval_end
|
| 461 |
+
gr.update(choices=TASK_TYPES_TURBO), # task_type
|
| 462 |
+
)
|
| 463 |
+
else:
|
| 464 |
+
# Base model: max 200 steps, default 32, show CFG/ADG/shift, show all task types
|
| 465 |
+
return (
|
| 466 |
+
gr.update(value=32, maximum=200, minimum=1), # inference_steps
|
| 467 |
+
gr.update(visible=True), # guidance_scale
|
| 468 |
+
gr.update(visible=True), # use_adg
|
| 469 |
+
gr.update(value=3.0, visible=True), # shift (effective for base, default 3.0)
|
| 470 |
+
gr.update(visible=True), # cfg_interval_start
|
| 471 |
+
gr.update(visible=True), # cfg_interval_end
|
| 472 |
+
gr.update(choices=TASK_TYPES_BASE), # task_type
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def update_negative_prompt_visibility(init_llm_checked):
|
| 477 |
+
"""Update negative prompt visibility: show if Initialize 5Hz LM checkbox is checked"""
|
| 478 |
+
return gr.update(visible=init_llm_checked)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def update_audio_cover_strength_visibility(task_type_value, init_llm_checked):
|
| 482 |
+
"""Update audio_cover_strength visibility and label"""
|
| 483 |
+
# Show if task is cover OR if LM is initialized (but NOT for repaint mode)
|
| 484 |
+
# Repaint mode never shows this control
|
| 485 |
+
is_repaint = task_type_value == "repaint"
|
| 486 |
+
is_cover = task_type_value == "cover"
|
| 487 |
+
is_visible = is_cover or (init_llm_checked and not is_repaint)
|
| 488 |
+
|
| 489 |
+
# Change label based on context
|
| 490 |
+
if init_llm_checked and not is_cover:
|
| 491 |
+
label = "LM codes strength"
|
| 492 |
+
info = "Control how many denoising steps use LM-generated codes"
|
| 493 |
+
else:
|
| 494 |
+
label = "Audio Cover Strength"
|
| 495 |
+
info = "Control how many denoising steps use cover mode"
|
| 496 |
+
|
| 497 |
+
return gr.update(visible=is_visible, label=label, info=info)
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def convert_src_audio_to_codes_wrapper(dit_handler, src_audio):
|
| 501 |
+
"""Wrapper for converting src audio to codes"""
|
| 502 |
+
codes_string = dit_handler.convert_src_audio_to_codes(src_audio)
|
| 503 |
+
return codes_string
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def update_instruction_ui(
|
| 507 |
+
dit_handler,
|
| 508 |
+
task_type_value: str,
|
| 509 |
+
track_name_value: Optional[str],
|
| 510 |
+
complete_track_classes_value: list,
|
| 511 |
+
audio_codes_content: str = "",
|
| 512 |
+
init_llm_checked: bool = False
|
| 513 |
+
) -> tuple:
|
| 514 |
+
"""Update instruction and UI visibility based on task type."""
|
| 515 |
+
instruction = dit_handler.generate_instruction(
|
| 516 |
+
task_type=task_type_value,
|
| 517 |
+
track_name=track_name_value,
|
| 518 |
+
complete_track_classes=complete_track_classes_value
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
# Show track_name for lego and extract
|
| 522 |
+
track_name_visible = task_type_value in ["lego", "extract"]
|
| 523 |
+
# Show complete_track_classes for complete
|
| 524 |
+
complete_visible = task_type_value == "complete"
|
| 525 |
+
# Show audio_cover_strength for cover OR when LM is initialized (but NOT for repaint)
|
| 526 |
+
is_repaint = task_type_value == "repaint"
|
| 527 |
+
is_cover = task_type_value == "cover"
|
| 528 |
+
audio_cover_strength_visible = is_cover or (init_llm_checked and not is_repaint)
|
| 529 |
+
# Determine label and info based on context
|
| 530 |
+
if init_llm_checked and not is_cover:
|
| 531 |
+
audio_cover_strength_label = "LM codes strength"
|
| 532 |
+
audio_cover_strength_info = "Control how many denoising steps use LM-generated codes"
|
| 533 |
+
else:
|
| 534 |
+
audio_cover_strength_label = "Audio Cover Strength"
|
| 535 |
+
audio_cover_strength_info = "Control how many denoising steps use cover mode"
|
| 536 |
+
# Show repainting controls for repaint and lego
|
| 537 |
+
repainting_visible = task_type_value in ["repaint", "lego"]
|
| 538 |
+
# Show text2music_audio_codes if task is text2music OR if it has content
|
| 539 |
+
# This allows it to stay visible even if user switches task type but has codes
|
| 540 |
+
has_audio_codes = audio_codes_content and str(audio_codes_content).strip()
|
| 541 |
+
text2music_audio_codes_visible = task_type_value == "text2music" or has_audio_codes
|
| 542 |
+
|
| 543 |
+
return (
|
| 544 |
+
instruction, # instruction_display_gen
|
| 545 |
+
gr.update(visible=track_name_visible), # track_name
|
| 546 |
+
gr.update(visible=complete_visible), # complete_track_classes
|
| 547 |
+
gr.update(visible=audio_cover_strength_visible, label=audio_cover_strength_label, info=audio_cover_strength_info), # audio_cover_strength
|
| 548 |
+
gr.update(visible=repainting_visible), # repainting_group
|
| 549 |
+
gr.update(visible=text2music_audio_codes_visible), # text2music_audio_codes_group
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
def transcribe_audio_codes(llm_handler, audio_code_string, constrained_decoding_debug):
|
| 554 |
+
"""
|
| 555 |
+
Transcribe audio codes to metadata using LLM understanding.
|
| 556 |
+
If audio_code_string is empty, generate a sample example instead.
|
| 557 |
+
|
| 558 |
+
This is a Gradio wrapper around the understand_music API in acestep.inference.
|
| 559 |
+
|
| 560 |
+
Args:
|
| 561 |
+
llm_handler: LLM handler instance
|
| 562 |
+
audio_code_string: String containing audio codes (or empty for example generation)
|
| 563 |
+
constrained_decoding_debug: Whether to enable debug logging for constrained decoding
|
| 564 |
+
|
| 565 |
+
Returns:
|
| 566 |
+
Tuple of (status_message, caption, lyrics, bpm, duration, keyscale, language, timesignature, is_format_caption)
|
| 567 |
+
"""
|
| 568 |
+
# Call the inference API
|
| 569 |
+
result = understand_music(
|
| 570 |
+
llm_handler=llm_handler,
|
| 571 |
+
audio_codes=audio_code_string,
|
| 572 |
+
use_constrained_decoding=True,
|
| 573 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
# Handle error case with localized message
|
| 577 |
+
if not result.success:
|
| 578 |
+
# Use localized error message for LLM not initialized
|
| 579 |
+
if result.error == "LLM not initialized":
|
| 580 |
+
return t("messages.lm_not_initialized"), "", "", None, None, "", "", "", False
|
| 581 |
+
return result.status_message, "", "", None, None, "", "", "", False
|
| 582 |
+
|
| 583 |
+
return (
|
| 584 |
+
result.status_message,
|
| 585 |
+
result.caption,
|
| 586 |
+
result.lyrics,
|
| 587 |
+
result.bpm,
|
| 588 |
+
result.duration,
|
| 589 |
+
result.keyscale,
|
| 590 |
+
result.language,
|
| 591 |
+
result.timesignature,
|
| 592 |
+
True # Set is_format_caption to True (from Transcribe/LM understanding)
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
def update_transcribe_button_text(audio_code_string):
|
| 597 |
+
"""
|
| 598 |
+
Update the transcribe button text based on input content.
|
| 599 |
+
If empty: "Generate Example"
|
| 600 |
+
If has content: "Transcribe"
|
| 601 |
+
"""
|
| 602 |
+
if not audio_code_string or not audio_code_string.strip():
|
| 603 |
+
return gr.update(value="Generate Example")
|
| 604 |
+
else:
|
| 605 |
+
return gr.update(value="Transcribe")
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def reset_format_caption_flag():
|
| 609 |
+
"""Reset is_format_caption to False when user manually edits caption/metadata"""
|
| 610 |
+
return False
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
def update_audio_uploads_accordion(reference_audio, src_audio):
|
| 614 |
+
"""Update Audio Uploads visibility based on whether audio files are present"""
|
| 615 |
+
has_audio = (reference_audio is not None) or (src_audio is not None)
|
| 616 |
+
return gr.update(visible=has_audio)
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
def handle_instrumental_checkbox(instrumental_checked, current_lyrics):
|
| 620 |
+
"""
|
| 621 |
+
Handle instrumental checkbox changes.
|
| 622 |
+
When checked: if no lyrics, fill with [Instrumental]
|
| 623 |
+
When unchecked: if lyrics is [Instrumental], clear it
|
| 624 |
+
"""
|
| 625 |
+
if instrumental_checked:
|
| 626 |
+
# If checked and no lyrics, fill with [Instrumental]
|
| 627 |
+
if not current_lyrics or not current_lyrics.strip():
|
| 628 |
+
return "[Instrumental]"
|
| 629 |
+
else:
|
| 630 |
+
# Has lyrics, don't change
|
| 631 |
+
return current_lyrics
|
| 632 |
+
else:
|
| 633 |
+
# If unchecked and lyrics is exactly [Instrumental], clear it
|
| 634 |
+
if current_lyrics and current_lyrics.strip() == "[Instrumental]":
|
| 635 |
+
return ""
|
| 636 |
+
else:
|
| 637 |
+
# Has other lyrics, don't change
|
| 638 |
+
return current_lyrics
|
| 639 |
+
|
| 640 |
+
|
| 641 |
+
def handle_simple_instrumental_change(is_instrumental: bool):
|
| 642 |
+
"""
|
| 643 |
+
Handle simple mode instrumental checkbox changes.
|
| 644 |
+
When checked: set vocal_language to "unknown" and disable editing.
|
| 645 |
+
When unchecked: enable vocal_language editing.
|
| 646 |
+
|
| 647 |
+
Args:
|
| 648 |
+
is_instrumental: Whether instrumental checkbox is checked
|
| 649 |
+
|
| 650 |
+
Returns:
|
| 651 |
+
gr.update for simple_vocal_language dropdown
|
| 652 |
+
"""
|
| 653 |
+
if is_instrumental:
|
| 654 |
+
return gr.update(value="unknown", interactive=False)
|
| 655 |
+
else:
|
| 656 |
+
return gr.update(interactive=True)
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
def update_audio_components_visibility(batch_size):
|
| 660 |
+
"""Show/hide individual audio components based on batch size (1-8)
|
| 661 |
+
|
| 662 |
+
Row 1: Components 1-4 (batch_size 1-4)
|
| 663 |
+
Row 2: Components 5-8 (batch_size 5-8)
|
| 664 |
+
"""
|
| 665 |
+
# Clamp batch size to 1-8 range for UI
|
| 666 |
+
batch_size = min(max(int(batch_size), 1), 8)
|
| 667 |
+
|
| 668 |
+
# Row 1 columns (1-4)
|
| 669 |
+
updates_row1 = (
|
| 670 |
+
gr.update(visible=True), # audio_col_1: always visible
|
| 671 |
+
gr.update(visible=batch_size >= 2), # audio_col_2
|
| 672 |
+
gr.update(visible=batch_size >= 3), # audio_col_3
|
| 673 |
+
gr.update(visible=batch_size >= 4), # audio_col_4
|
| 674 |
+
)
|
| 675 |
+
|
| 676 |
+
# Row 2 container and columns (5-8)
|
| 677 |
+
show_row_5_8 = batch_size >= 5
|
| 678 |
+
updates_row2 = (
|
| 679 |
+
gr.update(visible=show_row_5_8), # audio_row_5_8 (container)
|
| 680 |
+
gr.update(visible=batch_size >= 5), # audio_col_5
|
| 681 |
+
gr.update(visible=batch_size >= 6), # audio_col_6
|
| 682 |
+
gr.update(visible=batch_size >= 7), # audio_col_7
|
| 683 |
+
gr.update(visible=batch_size >= 8), # audio_col_8
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
return updates_row1 + updates_row2
|
| 687 |
+
|
| 688 |
+
|
| 689 |
+
def handle_generation_mode_change(mode: str):
|
| 690 |
+
"""
|
| 691 |
+
Handle generation mode change between Simple, Custom, Cover, and Repaint modes.
|
| 692 |
+
|
| 693 |
+
Modes:
|
| 694 |
+
- Simple: Show simple mode group, hide others
|
| 695 |
+
- Custom: Show custom content (prompt), hide others
|
| 696 |
+
- Cover: Show src_audio_group + custom content + LM codes strength
|
| 697 |
+
- Repaint: Show src_audio_group + custom content + repaint time controls (hide LM codes strength)
|
| 698 |
+
|
| 699 |
+
Args:
|
| 700 |
+
mode: "simple", "custom", "cover", or "repaint"
|
| 701 |
+
|
| 702 |
+
Returns:
|
| 703 |
+
Tuple of updates for:
|
| 704 |
+
- simple_mode_group (visibility)
|
| 705 |
+
- custom_mode_content (visibility)
|
| 706 |
+
- cover_mode_group (visibility) - legacy, always hidden
|
| 707 |
+
- repainting_group (visibility)
|
| 708 |
+
- task_type (value)
|
| 709 |
+
- generate_btn (interactive state)
|
| 710 |
+
- simple_sample_created (reset state)
|
| 711 |
+
- src_audio_group (visibility) - shown for cover and repaint
|
| 712 |
+
- audio_cover_strength (visibility) - shown only for cover mode
|
| 713 |
+
- think_checkbox (value and interactive) - disabled for cover/repaint modes
|
| 714 |
+
"""
|
| 715 |
+
is_simple = mode == "simple"
|
| 716 |
+
is_custom = mode == "custom"
|
| 717 |
+
is_cover = mode == "cover"
|
| 718 |
+
is_repaint = mode == "repaint"
|
| 719 |
+
|
| 720 |
+
# Map mode to task_type
|
| 721 |
+
task_type_map = {
|
| 722 |
+
"simple": "text2music",
|
| 723 |
+
"custom": "text2music",
|
| 724 |
+
"cover": "cover",
|
| 725 |
+
"repaint": "repaint",
|
| 726 |
+
}
|
| 727 |
+
task_type_value = task_type_map.get(mode, "text2music")
|
| 728 |
+
|
| 729 |
+
# think_checkbox: disabled and set to False for cover/repaint modes
|
| 730 |
+
# (these modes don't use LM thinking, they use source audio codes)
|
| 731 |
+
if is_cover or is_repaint:
|
| 732 |
+
think_checkbox_update = gr.update(value=False, interactive=False)
|
| 733 |
+
else:
|
| 734 |
+
think_checkbox_update = gr.update(value=True, interactive=True)
|
| 735 |
+
|
| 736 |
+
return (
|
| 737 |
+
gr.update(visible=is_simple), # simple_mode_group
|
| 738 |
+
gr.update(visible=not is_simple), # custom_mode_content - visible for custom/cover/repaint
|
| 739 |
+
gr.update(visible=False), # cover_mode_group - legacy, always hidden
|
| 740 |
+
gr.update(visible=is_repaint), # repainting_group - time range controls
|
| 741 |
+
gr.update(value=task_type_value), # task_type
|
| 742 |
+
gr.update(interactive=True), # generate_btn - always enabled (Simple mode does create+generate in one step)
|
| 743 |
+
False, # simple_sample_created - reset to False on mode change
|
| 744 |
+
gr.update(visible=is_cover or is_repaint), # src_audio_group - shown for cover and repaint
|
| 745 |
+
gr.update(visible=is_cover), # audio_cover_strength - only shown for cover mode
|
| 746 |
+
think_checkbox_update, # think_checkbox - disabled for cover/repaint modes
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
def process_source_audio(dit_handler, llm_handler, src_audio, constrained_decoding_debug):
|
| 751 |
+
"""
|
| 752 |
+
Process source audio: convert to codes and then transcribe.
|
| 753 |
+
This combines convert_src_audio_to_codes_wrapper + transcribe_audio_codes.
|
| 754 |
+
|
| 755 |
+
Args:
|
| 756 |
+
dit_handler: DiT handler instance for audio code conversion
|
| 757 |
+
llm_handler: LLM handler instance for transcription
|
| 758 |
+
src_audio: Path to source audio file
|
| 759 |
+
constrained_decoding_debug: Whether to enable debug logging
|
| 760 |
+
|
| 761 |
+
Returns:
|
| 762 |
+
Tuple of (audio_codes, status_message, caption, lyrics, bpm, duration, keyscale, language, timesignature, is_format_caption)
|
| 763 |
+
"""
|
| 764 |
+
if src_audio is None:
|
| 765 |
+
return ("", "No audio file provided", "", "", None, None, "", "", "", False)
|
| 766 |
+
|
| 767 |
+
# Step 1: Convert audio to codes
|
| 768 |
+
try:
|
| 769 |
+
codes_string = dit_handler.convert_src_audio_to_codes(src_audio)
|
| 770 |
+
if not codes_string:
|
| 771 |
+
return ("", "Failed to convert audio to codes", "", "", None, None, "", "", "", False)
|
| 772 |
+
except Exception as e:
|
| 773 |
+
return ("", f"Error converting audio: {str(e)}", "", "", None, None, "", "", "", False)
|
| 774 |
+
|
| 775 |
+
# Step 2: Transcribe the codes
|
| 776 |
+
result = understand_music(
|
| 777 |
+
llm_handler=llm_handler,
|
| 778 |
+
audio_codes=codes_string,
|
| 779 |
+
use_constrained_decoding=True,
|
| 780 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 781 |
+
)
|
| 782 |
+
|
| 783 |
+
# Handle error case
|
| 784 |
+
if not result.success:
|
| 785 |
+
if result.error == "LLM not initialized":
|
| 786 |
+
return (codes_string, t("messages.lm_not_initialized"), "", "", None, None, "", "", "", False)
|
| 787 |
+
return (codes_string, result.status_message, "", "", None, None, "", "", "", False)
|
| 788 |
+
|
| 789 |
+
return (
|
| 790 |
+
codes_string,
|
| 791 |
+
result.status_message,
|
| 792 |
+
result.caption,
|
| 793 |
+
result.lyrics,
|
| 794 |
+
result.bpm,
|
| 795 |
+
result.duration,
|
| 796 |
+
result.keyscale,
|
| 797 |
+
result.language,
|
| 798 |
+
result.timesignature,
|
| 799 |
+
True # Set is_format_caption to True
|
| 800 |
+
)
|
| 801 |
+
|
| 802 |
+
|
| 803 |
+
def handle_create_sample(
|
| 804 |
+
llm_handler,
|
| 805 |
+
query: str,
|
| 806 |
+
instrumental: bool,
|
| 807 |
+
vocal_language: str,
|
| 808 |
+
lm_temperature: float,
|
| 809 |
+
lm_top_k: int,
|
| 810 |
+
lm_top_p: float,
|
| 811 |
+
constrained_decoding_debug: bool = False,
|
| 812 |
+
):
|
| 813 |
+
"""
|
| 814 |
+
Handle the Create Sample button click in Simple mode.
|
| 815 |
+
|
| 816 |
+
Creates a sample from the user's query using the LLM, then populates
|
| 817 |
+
the caption, lyrics, and metadata fields.
|
| 818 |
+
|
| 819 |
+
Note: cfg_scale and negative_prompt are not supported in create_sample mode.
|
| 820 |
+
|
| 821 |
+
Args:
|
| 822 |
+
llm_handler: LLM handler instance
|
| 823 |
+
query: User's natural language music description
|
| 824 |
+
instrumental: Whether to generate instrumental music
|
| 825 |
+
vocal_language: Preferred vocal language for constrained decoding
|
| 826 |
+
lm_temperature: LLM temperature for generation
|
| 827 |
+
lm_top_k: LLM top-k sampling
|
| 828 |
+
lm_top_p: LLM top-p sampling
|
| 829 |
+
constrained_decoding_debug: Whether to enable debug logging
|
| 830 |
+
|
| 831 |
+
Returns:
|
| 832 |
+
Tuple of updates for:
|
| 833 |
+
- captions
|
| 834 |
+
- lyrics
|
| 835 |
+
- bpm
|
| 836 |
+
- audio_duration
|
| 837 |
+
- key_scale
|
| 838 |
+
- vocal_language
|
| 839 |
+
- time_signature
|
| 840 |
+
- instrumental_checkbox
|
| 841 |
+
- caption_accordion (open)
|
| 842 |
+
- lyrics_accordion (open)
|
| 843 |
+
- generate_btn (interactive)
|
| 844 |
+
- simple_sample_created (True)
|
| 845 |
+
- think_checkbox (True)
|
| 846 |
+
- is_format_caption_state (True)
|
| 847 |
+
- status_output
|
| 848 |
+
"""
|
| 849 |
+
# Check if LLM is initialized
|
| 850 |
+
if not llm_handler.llm_initialized:
|
| 851 |
+
gr.Warning(t("messages.lm_not_initialized"))
|
| 852 |
+
return (
|
| 853 |
+
gr.update(), # captions - no change
|
| 854 |
+
gr.update(), # lyrics - no change
|
| 855 |
+
gr.update(), # bpm - no change
|
| 856 |
+
gr.update(), # audio_duration - no change
|
| 857 |
+
gr.update(), # key_scale - no change
|
| 858 |
+
gr.update(), # vocal_language - no change
|
| 859 |
+
gr.update(), # time_signature - no change
|
| 860 |
+
gr.update(), # instrumental_checkbox - no change
|
| 861 |
+
gr.update(), # caption_accordion - no change
|
| 862 |
+
gr.update(), # lyrics_accordion - no change
|
| 863 |
+
gr.update(interactive=False), # generate_btn - keep disabled
|
| 864 |
+
False, # simple_sample_created - still False
|
| 865 |
+
gr.update(), # think_checkbox - no change
|
| 866 |
+
gr.update(), # is_format_caption_state - no change
|
| 867 |
+
t("messages.lm_not_initialized"), # status_output
|
| 868 |
+
)
|
| 869 |
+
|
| 870 |
+
# Convert LM parameters
|
| 871 |
+
top_k_value = None if not lm_top_k or lm_top_k == 0 else int(lm_top_k)
|
| 872 |
+
top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p
|
| 873 |
+
|
| 874 |
+
# Call create_sample API
|
| 875 |
+
# Note: cfg_scale and negative_prompt are not supported in create_sample mode
|
| 876 |
+
result = create_sample(
|
| 877 |
+
llm_handler=llm_handler,
|
| 878 |
+
query=query,
|
| 879 |
+
instrumental=instrumental,
|
| 880 |
+
vocal_language=vocal_language,
|
| 881 |
+
temperature=lm_temperature,
|
| 882 |
+
top_k=top_k_value,
|
| 883 |
+
top_p=top_p_value,
|
| 884 |
+
use_constrained_decoding=True,
|
| 885 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 886 |
+
)
|
| 887 |
+
|
| 888 |
+
# Handle error
|
| 889 |
+
if not result.success:
|
| 890 |
+
gr.Warning(result.status_message or t("messages.sample_creation_failed"))
|
| 891 |
+
return (
|
| 892 |
+
gr.update(), # captions - no change
|
| 893 |
+
gr.update(), # lyrics - no change
|
| 894 |
+
gr.update(), # bpm - no change
|
| 895 |
+
gr.update(), # audio_duration - no change
|
| 896 |
+
gr.update(), # key_scale - no change
|
| 897 |
+
gr.update(), # vocal_language - no change
|
| 898 |
+
gr.update(), # simple vocal_language - no change
|
| 899 |
+
gr.update(), # time_signature - no change
|
| 900 |
+
gr.update(), # instrumental_checkbox - no change
|
| 901 |
+
gr.update(), # caption_accordion - no change
|
| 902 |
+
gr.update(), # lyrics_accordion - no change
|
| 903 |
+
gr.update(interactive=False), # generate_btn - keep disabled
|
| 904 |
+
False, # simple_sample_created - still False
|
| 905 |
+
gr.update(), # think_checkbox - no change
|
| 906 |
+
gr.update(), # is_format_caption_state - no change
|
| 907 |
+
result.status_message or t("messages.sample_creation_failed"), # status_output
|
| 908 |
+
)
|
| 909 |
+
|
| 910 |
+
# Success - populate fields
|
| 911 |
+
gr.Info(t("messages.sample_created"))
|
| 912 |
+
|
| 913 |
+
return (
|
| 914 |
+
result.caption, # captions
|
| 915 |
+
result.lyrics, # lyrics
|
| 916 |
+
result.bpm, # bpm
|
| 917 |
+
result.duration if result.duration and result.duration > 0 else -1, # audio_duration
|
| 918 |
+
result.keyscale, # key_scale
|
| 919 |
+
result.language, # vocal_language
|
| 920 |
+
result.language, # simple vocal_language
|
| 921 |
+
result.timesignature, # time_signature
|
| 922 |
+
result.instrumental, # instrumental_checkbox
|
| 923 |
+
gr.Accordion(open=True), # caption_accordion - expand
|
| 924 |
+
gr.Accordion(open=True), # lyrics_accordion - expand
|
| 925 |
+
gr.update(interactive=True), # generate_btn - enable
|
| 926 |
+
True, # simple_sample_created - True
|
| 927 |
+
True, # think_checkbox - enable thinking
|
| 928 |
+
True, # is_format_caption_state - True (LM-generated)
|
| 929 |
+
result.status_message, # status_output
|
| 930 |
+
)
|
| 931 |
+
|
| 932 |
+
|
| 933 |
+
def handle_format_sample(
|
| 934 |
+
llm_handler,
|
| 935 |
+
caption: str,
|
| 936 |
+
lyrics: str,
|
| 937 |
+
bpm,
|
| 938 |
+
audio_duration,
|
| 939 |
+
key_scale: str,
|
| 940 |
+
time_signature: str,
|
| 941 |
+
lm_temperature: float,
|
| 942 |
+
lm_top_k: int,
|
| 943 |
+
lm_top_p: float,
|
| 944 |
+
constrained_decoding_debug: bool = False,
|
| 945 |
+
):
|
| 946 |
+
"""
|
| 947 |
+
Handle the Format button click to format caption and lyrics.
|
| 948 |
+
|
| 949 |
+
Takes user-provided caption and lyrics, and uses the LLM to generate
|
| 950 |
+
structured music metadata and an enhanced description.
|
| 951 |
+
|
| 952 |
+
Note: cfg_scale and negative_prompt are not supported in format mode.
|
| 953 |
+
|
| 954 |
+
Args:
|
| 955 |
+
llm_handler: LLM handler instance
|
| 956 |
+
caption: User's caption/description
|
| 957 |
+
lyrics: User's lyrics
|
| 958 |
+
bpm: User-provided BPM (optional, for constrained decoding)
|
| 959 |
+
audio_duration: User-provided duration (optional, for constrained decoding)
|
| 960 |
+
key_scale: User-provided key scale (optional, for constrained decoding)
|
| 961 |
+
time_signature: User-provided time signature (optional, for constrained decoding)
|
| 962 |
+
lm_temperature: LLM temperature for generation
|
| 963 |
+
lm_top_k: LLM top-k sampling
|
| 964 |
+
lm_top_p: LLM top-p sampling
|
| 965 |
+
constrained_decoding_debug: Whether to enable debug logging
|
| 966 |
+
|
| 967 |
+
Returns:
|
| 968 |
+
Tuple of updates for:
|
| 969 |
+
- captions
|
| 970 |
+
- lyrics
|
| 971 |
+
- bpm
|
| 972 |
+
- audio_duration
|
| 973 |
+
- key_scale
|
| 974 |
+
- vocal_language
|
| 975 |
+
- time_signature
|
| 976 |
+
- is_format_caption_state
|
| 977 |
+
- status_output
|
| 978 |
+
"""
|
| 979 |
+
# Check if LLM is initialized
|
| 980 |
+
if not llm_handler.llm_initialized:
|
| 981 |
+
gr.Warning(t("messages.lm_not_initialized"))
|
| 982 |
+
return (
|
| 983 |
+
gr.update(), # captions - no change
|
| 984 |
+
gr.update(), # lyrics - no change
|
| 985 |
+
gr.update(), # bpm - no change
|
| 986 |
+
gr.update(), # audio_duration - no change
|
| 987 |
+
gr.update(), # key_scale - no change
|
| 988 |
+
gr.update(), # vocal_language - no change
|
| 989 |
+
gr.update(), # time_signature - no change
|
| 990 |
+
gr.update(), # is_format_caption_state - no change
|
| 991 |
+
t("messages.lm_not_initialized"), # status_output
|
| 992 |
+
)
|
| 993 |
+
|
| 994 |
+
# Build user_metadata from provided values for constrained decoding
|
| 995 |
+
user_metadata = {}
|
| 996 |
+
if bpm is not None and bpm > 0:
|
| 997 |
+
user_metadata['bpm'] = int(bpm)
|
| 998 |
+
if audio_duration is not None and audio_duration > 0:
|
| 999 |
+
user_metadata['duration'] = int(audio_duration)
|
| 1000 |
+
if key_scale and key_scale.strip():
|
| 1001 |
+
user_metadata['keyscale'] = key_scale.strip()
|
| 1002 |
+
if time_signature and time_signature.strip():
|
| 1003 |
+
user_metadata['timesignature'] = time_signature.strip()
|
| 1004 |
+
|
| 1005 |
+
# Only pass user_metadata if we have at least one field
|
| 1006 |
+
user_metadata_to_pass = user_metadata if user_metadata else None
|
| 1007 |
+
|
| 1008 |
+
# Convert LM parameters
|
| 1009 |
+
top_k_value = None if not lm_top_k or lm_top_k == 0 else int(lm_top_k)
|
| 1010 |
+
top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p
|
| 1011 |
+
|
| 1012 |
+
# Call format_sample API
|
| 1013 |
+
result = format_sample(
|
| 1014 |
+
llm_handler=llm_handler,
|
| 1015 |
+
caption=caption,
|
| 1016 |
+
lyrics=lyrics,
|
| 1017 |
+
user_metadata=user_metadata_to_pass,
|
| 1018 |
+
temperature=lm_temperature,
|
| 1019 |
+
top_k=top_k_value,
|
| 1020 |
+
top_p=top_p_value,
|
| 1021 |
+
use_constrained_decoding=True,
|
| 1022 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 1023 |
+
)
|
| 1024 |
+
|
| 1025 |
+
# Handle error
|
| 1026 |
+
if not result.success:
|
| 1027 |
+
gr.Warning(result.status_message or t("messages.format_failed"))
|
| 1028 |
+
return (
|
| 1029 |
+
gr.update(), # captions - no change
|
| 1030 |
+
gr.update(), # lyrics - no change
|
| 1031 |
+
gr.update(), # bpm - no change
|
| 1032 |
+
gr.update(), # audio_duration - no change
|
| 1033 |
+
gr.update(), # key_scale - no change
|
| 1034 |
+
gr.update(), # vocal_language - no change
|
| 1035 |
+
gr.update(), # time_signature - no change
|
| 1036 |
+
gr.update(), # is_format_caption_state - no change
|
| 1037 |
+
result.status_message or t("messages.format_failed"), # status_output
|
| 1038 |
+
)
|
| 1039 |
+
|
| 1040 |
+
# Success - populate fields
|
| 1041 |
+
gr.Info(t("messages.format_success"))
|
| 1042 |
+
|
| 1043 |
+
return (
|
| 1044 |
+
result.caption, # captions
|
| 1045 |
+
result.lyrics, # lyrics
|
| 1046 |
+
result.bpm, # bpm
|
| 1047 |
+
result.duration if result.duration and result.duration > 0 else -1, # audio_duration
|
| 1048 |
+
result.keyscale, # key_scale
|
| 1049 |
+
result.language, # vocal_language
|
| 1050 |
+
result.timesignature, # time_signature
|
| 1051 |
+
True, # is_format_caption_state - True (LM-formatted)
|
| 1052 |
+
result.status_message, # status_output
|
| 1053 |
+
)
|
| 1054 |
+
|
spaces/Ace-Step-v1.5/acestep/gradio_ui/events/results_handlers.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
spaces/Ace-Step-v1.5/acestep/gradio_ui/events/training_handlers.py
ADDED
|
@@ -0,0 +1,644 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Event Handlers for Training Tab
|
| 3 |
+
|
| 4 |
+
Contains all event handler functions for the dataset builder and training UI.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
from typing import Any, Dict, List, Tuple, Optional
|
| 10 |
+
from loguru import logger
|
| 11 |
+
import gradio as gr
|
| 12 |
+
|
| 13 |
+
from acestep.training.dataset_builder import DatasetBuilder, AudioSample
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def create_dataset_builder() -> DatasetBuilder:
|
| 17 |
+
"""Create a new DatasetBuilder instance."""
|
| 18 |
+
return DatasetBuilder()
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def scan_directory(
|
| 22 |
+
audio_dir: str,
|
| 23 |
+
dataset_name: str,
|
| 24 |
+
custom_tag: str,
|
| 25 |
+
tag_position: str,
|
| 26 |
+
all_instrumental: bool,
|
| 27 |
+
builder_state: Optional[DatasetBuilder],
|
| 28 |
+
) -> Tuple[Any, str, Any, DatasetBuilder]:
|
| 29 |
+
"""Scan a directory for audio files.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
Tuple of (table_data, status, slider_update, builder_state)
|
| 33 |
+
"""
|
| 34 |
+
if not audio_dir or not audio_dir.strip():
|
| 35 |
+
return [], "❌ Please enter a directory path", gr.Slider(maximum=0, value=0), builder_state
|
| 36 |
+
|
| 37 |
+
# Create or use existing builder
|
| 38 |
+
builder = builder_state if builder_state else DatasetBuilder()
|
| 39 |
+
|
| 40 |
+
# Set metadata before scanning
|
| 41 |
+
builder.metadata.name = dataset_name
|
| 42 |
+
builder.metadata.custom_tag = custom_tag
|
| 43 |
+
builder.metadata.tag_position = tag_position
|
| 44 |
+
builder.metadata.all_instrumental = all_instrumental
|
| 45 |
+
|
| 46 |
+
# Scan directory
|
| 47 |
+
samples, status = builder.scan_directory(audio_dir.strip())
|
| 48 |
+
|
| 49 |
+
if not samples:
|
| 50 |
+
return [], status, gr.Slider(maximum=0, value=0), builder
|
| 51 |
+
|
| 52 |
+
# Set instrumental and tag for all samples
|
| 53 |
+
builder.set_all_instrumental(all_instrumental)
|
| 54 |
+
if custom_tag:
|
| 55 |
+
builder.set_custom_tag(custom_tag, tag_position)
|
| 56 |
+
|
| 57 |
+
# Get table data
|
| 58 |
+
table_data = builder.get_samples_dataframe_data()
|
| 59 |
+
|
| 60 |
+
# Calculate slider max and return as Slider update
|
| 61 |
+
slider_max = max(0, len(samples) - 1)
|
| 62 |
+
|
| 63 |
+
return table_data, status, gr.Slider(maximum=slider_max, value=0), builder
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def auto_label_all(
|
| 67 |
+
dit_handler,
|
| 68 |
+
llm_handler,
|
| 69 |
+
builder_state: Optional[DatasetBuilder],
|
| 70 |
+
skip_metas: bool = False,
|
| 71 |
+
progress=None,
|
| 72 |
+
) -> Tuple[List[List[Any]], str, DatasetBuilder]:
|
| 73 |
+
"""Auto-label all samples in the dataset.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
dit_handler: DiT handler for audio processing
|
| 77 |
+
llm_handler: LLM handler for caption generation
|
| 78 |
+
builder_state: Dataset builder state
|
| 79 |
+
skip_metas: If True, skip LLM labeling. BPM/Key/TimeSig = N/A, Language = unknown for instrumental
|
| 80 |
+
progress: Progress callback
|
| 81 |
+
|
| 82 |
+
Returns:
|
| 83 |
+
Tuple of (table_data, status, builder_state)
|
| 84 |
+
"""
|
| 85 |
+
if builder_state is None:
|
| 86 |
+
return [], "❌ Please scan a directory first", builder_state
|
| 87 |
+
|
| 88 |
+
if not builder_state.samples:
|
| 89 |
+
return [], "❌ No samples to label. Please scan a directory first.", builder_state
|
| 90 |
+
|
| 91 |
+
# If skip_metas is True, just set default values without LLM
|
| 92 |
+
if skip_metas:
|
| 93 |
+
for sample in builder_state.samples:
|
| 94 |
+
sample.bpm = None # Will display as N/A
|
| 95 |
+
sample.keyscale = "N/A"
|
| 96 |
+
sample.timesignature = "N/A"
|
| 97 |
+
# For instrumental, language should be "unknown"
|
| 98 |
+
if sample.is_instrumental:
|
| 99 |
+
sample.language = "unknown"
|
| 100 |
+
else:
|
| 101 |
+
sample.language = "unknown"
|
| 102 |
+
# Use custom tag as caption if set, otherwise use filename
|
| 103 |
+
if builder_state.metadata.custom_tag:
|
| 104 |
+
sample.caption = builder_state.metadata.custom_tag
|
| 105 |
+
else:
|
| 106 |
+
sample.caption = sample.filename
|
| 107 |
+
|
| 108 |
+
table_data = builder_state.get_samples_dataframe_data()
|
| 109 |
+
return table_data, f"✅ Skipped AI labeling. {len(builder_state.samples)} samples set with default values.", builder_state
|
| 110 |
+
|
| 111 |
+
# Check if handlers are initialized
|
| 112 |
+
if dit_handler is None or dit_handler.model is None:
|
| 113 |
+
return builder_state.get_samples_dataframe_data(), "❌ Model not initialized. Please initialize the service first.", builder_state
|
| 114 |
+
|
| 115 |
+
if llm_handler is None or not llm_handler.llm_initialized:
|
| 116 |
+
return builder_state.get_samples_dataframe_data(), "❌ LLM not initialized. Please initialize the service with LLM enabled.", builder_state
|
| 117 |
+
|
| 118 |
+
def progress_callback(msg):
|
| 119 |
+
if progress:
|
| 120 |
+
try:
|
| 121 |
+
progress(msg)
|
| 122 |
+
except:
|
| 123 |
+
pass
|
| 124 |
+
|
| 125 |
+
# Label all samples
|
| 126 |
+
samples, status = builder_state.label_all_samples(
|
| 127 |
+
dit_handler=dit_handler,
|
| 128 |
+
llm_handler=llm_handler,
|
| 129 |
+
progress_callback=progress_callback,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Get updated table data
|
| 133 |
+
table_data = builder_state.get_samples_dataframe_data()
|
| 134 |
+
|
| 135 |
+
return table_data, status, builder_state
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def get_sample_preview(
|
| 139 |
+
sample_idx: int,
|
| 140 |
+
builder_state: Optional[DatasetBuilder],
|
| 141 |
+
) -> Tuple[str, str, str, str, Optional[int], str, str, float, str, bool]:
|
| 142 |
+
"""Get preview data for a specific sample.
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
Tuple of (audio_path, filename, caption, lyrics, bpm, keyscale, timesig, duration, language, instrumental)
|
| 146 |
+
"""
|
| 147 |
+
if builder_state is None or not builder_state.samples:
|
| 148 |
+
return None, "", "", "", None, "", "", 0.0, "instrumental", True
|
| 149 |
+
|
| 150 |
+
idx = int(sample_idx)
|
| 151 |
+
if idx < 0 or idx >= len(builder_state.samples):
|
| 152 |
+
return None, "", "", "", None, "", "", 0.0, "instrumental", True
|
| 153 |
+
|
| 154 |
+
sample = builder_state.samples[idx]
|
| 155 |
+
|
| 156 |
+
return (
|
| 157 |
+
sample.audio_path,
|
| 158 |
+
sample.filename,
|
| 159 |
+
sample.caption,
|
| 160 |
+
sample.lyrics,
|
| 161 |
+
sample.bpm,
|
| 162 |
+
sample.keyscale,
|
| 163 |
+
sample.timesignature,
|
| 164 |
+
sample.duration,
|
| 165 |
+
sample.language,
|
| 166 |
+
sample.is_instrumental,
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def save_sample_edit(
|
| 171 |
+
sample_idx: int,
|
| 172 |
+
caption: str,
|
| 173 |
+
lyrics: str,
|
| 174 |
+
bpm: Optional[int],
|
| 175 |
+
keyscale: str,
|
| 176 |
+
timesig: str,
|
| 177 |
+
language: str,
|
| 178 |
+
is_instrumental: bool,
|
| 179 |
+
builder_state: Optional[DatasetBuilder],
|
| 180 |
+
) -> Tuple[List[List[Any]], str, DatasetBuilder]:
|
| 181 |
+
"""Save edits to a sample.
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
Tuple of (table_data, status, builder_state)
|
| 185 |
+
"""
|
| 186 |
+
if builder_state is None:
|
| 187 |
+
return [], "❌ No dataset loaded", builder_state
|
| 188 |
+
|
| 189 |
+
idx = int(sample_idx)
|
| 190 |
+
|
| 191 |
+
# Update sample
|
| 192 |
+
sample, status = builder_state.update_sample(
|
| 193 |
+
idx,
|
| 194 |
+
caption=caption,
|
| 195 |
+
lyrics=lyrics if not is_instrumental else "[Instrumental]",
|
| 196 |
+
bpm=int(bpm) if bpm else None,
|
| 197 |
+
keyscale=keyscale,
|
| 198 |
+
timesignature=timesig,
|
| 199 |
+
language="instrumental" if is_instrumental else language,
|
| 200 |
+
is_instrumental=is_instrumental,
|
| 201 |
+
labeled=True,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
# Get updated table data
|
| 205 |
+
table_data = builder_state.get_samples_dataframe_data()
|
| 206 |
+
|
| 207 |
+
return table_data, status, builder_state
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def update_settings(
|
| 211 |
+
custom_tag: str,
|
| 212 |
+
tag_position: str,
|
| 213 |
+
all_instrumental: bool,
|
| 214 |
+
builder_state: Optional[DatasetBuilder],
|
| 215 |
+
) -> DatasetBuilder:
|
| 216 |
+
"""Update dataset settings.
|
| 217 |
+
|
| 218 |
+
Returns:
|
| 219 |
+
Updated builder_state
|
| 220 |
+
"""
|
| 221 |
+
if builder_state is None:
|
| 222 |
+
return builder_state
|
| 223 |
+
|
| 224 |
+
if custom_tag:
|
| 225 |
+
builder_state.set_custom_tag(custom_tag, tag_position)
|
| 226 |
+
|
| 227 |
+
builder_state.set_all_instrumental(all_instrumental)
|
| 228 |
+
|
| 229 |
+
return builder_state
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def save_dataset(
|
| 233 |
+
save_path: str,
|
| 234 |
+
dataset_name: str,
|
| 235 |
+
builder_state: Optional[DatasetBuilder],
|
| 236 |
+
) -> str:
|
| 237 |
+
"""Save the dataset to a JSON file.
|
| 238 |
+
|
| 239 |
+
Returns:
|
| 240 |
+
Status message
|
| 241 |
+
"""
|
| 242 |
+
if builder_state is None:
|
| 243 |
+
return "❌ No dataset to save. Please scan a directory first."
|
| 244 |
+
|
| 245 |
+
if not builder_state.samples:
|
| 246 |
+
return "❌ No samples in dataset."
|
| 247 |
+
|
| 248 |
+
if not save_path or not save_path.strip():
|
| 249 |
+
return "❌ Please enter a save path."
|
| 250 |
+
|
| 251 |
+
# Check if any samples are labeled
|
| 252 |
+
labeled_count = builder_state.get_labeled_count()
|
| 253 |
+
if labeled_count == 0:
|
| 254 |
+
return "⚠️ Warning: No samples have been labeled. Consider auto-labeling first.\nSaving anyway..."
|
| 255 |
+
|
| 256 |
+
return builder_state.save_dataset(save_path.strip(), dataset_name)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
def load_existing_dataset_for_preprocess(
|
| 260 |
+
dataset_path: str,
|
| 261 |
+
builder_state: Optional[DatasetBuilder],
|
| 262 |
+
) -> Tuple[str, Any, Any, DatasetBuilder, str, str, str, str, Optional[int], str, str, float, str, bool]:
|
| 263 |
+
"""Load an existing dataset JSON file for preprocessing.
|
| 264 |
+
|
| 265 |
+
This allows users to load a previously saved dataset and proceed to preprocessing
|
| 266 |
+
without having to re-scan and re-label.
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
Tuple of (status, table_data, slider_update, builder_state,
|
| 270 |
+
audio_path, filename, caption, lyrics, bpm, keyscale, timesig, duration, language, instrumental)
|
| 271 |
+
"""
|
| 272 |
+
empty_preview = (None, "", "", "", None, "", "", 0.0, "instrumental", True)
|
| 273 |
+
|
| 274 |
+
if not dataset_path or not dataset_path.strip():
|
| 275 |
+
return ("❌ Please enter a dataset path", [], gr.Slider(maximum=0, value=0), builder_state) + empty_preview
|
| 276 |
+
|
| 277 |
+
dataset_path = dataset_path.strip()
|
| 278 |
+
|
| 279 |
+
if not os.path.exists(dataset_path):
|
| 280 |
+
return (f"❌ Dataset not found: {dataset_path}", [], gr.Slider(maximum=0, value=0), builder_state) + empty_preview
|
| 281 |
+
|
| 282 |
+
# Create new builder (don't reuse old state when loading a file)
|
| 283 |
+
builder = DatasetBuilder()
|
| 284 |
+
|
| 285 |
+
# Load the dataset
|
| 286 |
+
samples, status = builder.load_dataset(dataset_path)
|
| 287 |
+
|
| 288 |
+
if not samples:
|
| 289 |
+
return (status, [], gr.Slider(maximum=0, value=0), builder) + empty_preview
|
| 290 |
+
|
| 291 |
+
# Get table data
|
| 292 |
+
table_data = builder.get_samples_dataframe_data()
|
| 293 |
+
|
| 294 |
+
# Calculate slider max
|
| 295 |
+
slider_max = max(0, len(samples) - 1)
|
| 296 |
+
|
| 297 |
+
# Create info text
|
| 298 |
+
labeled_count = builder.get_labeled_count()
|
| 299 |
+
info = f"✅ Loaded dataset: {builder.metadata.name}\n"
|
| 300 |
+
info += f"📊 Samples: {len(samples)} ({labeled_count} labeled)\n"
|
| 301 |
+
info += f"🏷️ Custom Tag: {builder.metadata.custom_tag or '(none)'}\n"
|
| 302 |
+
info += "📝 Ready for preprocessing! You can also edit samples below."
|
| 303 |
+
|
| 304 |
+
# Get first sample preview
|
| 305 |
+
first_sample = builder.samples[0]
|
| 306 |
+
preview = (
|
| 307 |
+
first_sample.audio_path,
|
| 308 |
+
first_sample.filename,
|
| 309 |
+
first_sample.caption,
|
| 310 |
+
first_sample.lyrics,
|
| 311 |
+
first_sample.bpm,
|
| 312 |
+
first_sample.keyscale,
|
| 313 |
+
first_sample.timesignature,
|
| 314 |
+
first_sample.duration,
|
| 315 |
+
first_sample.language,
|
| 316 |
+
first_sample.is_instrumental,
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
return (info, table_data, gr.Slider(maximum=slider_max, value=0), builder) + preview
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def preprocess_dataset(
|
| 323 |
+
output_dir: str,
|
| 324 |
+
dit_handler,
|
| 325 |
+
builder_state: Optional[DatasetBuilder],
|
| 326 |
+
progress=None,
|
| 327 |
+
) -> str:
|
| 328 |
+
"""Preprocess dataset to tensor files for fast training.
|
| 329 |
+
|
| 330 |
+
This converts audio files to VAE latents and text to embeddings.
|
| 331 |
+
|
| 332 |
+
Returns:
|
| 333 |
+
Status message
|
| 334 |
+
"""
|
| 335 |
+
if builder_state is None:
|
| 336 |
+
return "❌ No dataset loaded. Please scan a directory first."
|
| 337 |
+
|
| 338 |
+
if not builder_state.samples:
|
| 339 |
+
return "❌ No samples in dataset."
|
| 340 |
+
|
| 341 |
+
labeled_count = builder_state.get_labeled_count()
|
| 342 |
+
if labeled_count == 0:
|
| 343 |
+
return "❌ No labeled samples. Please auto-label or manually label samples first."
|
| 344 |
+
|
| 345 |
+
if not output_dir or not output_dir.strip():
|
| 346 |
+
return "❌ Please enter an output directory."
|
| 347 |
+
|
| 348 |
+
if dit_handler is None or dit_handler.model is None:
|
| 349 |
+
return "❌ Model not initialized. Please initialize the service first."
|
| 350 |
+
|
| 351 |
+
def progress_callback(msg):
|
| 352 |
+
if progress:
|
| 353 |
+
try:
|
| 354 |
+
progress(msg)
|
| 355 |
+
except:
|
| 356 |
+
pass
|
| 357 |
+
|
| 358 |
+
# Run preprocessing
|
| 359 |
+
output_paths, status = builder_state.preprocess_to_tensors(
|
| 360 |
+
dit_handler=dit_handler,
|
| 361 |
+
output_dir=output_dir.strip(),
|
| 362 |
+
progress_callback=progress_callback,
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
return status
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def load_training_dataset(
|
| 369 |
+
tensor_dir: str,
|
| 370 |
+
) -> str:
|
| 371 |
+
"""Load a preprocessed tensor dataset for training.
|
| 372 |
+
|
| 373 |
+
Returns:
|
| 374 |
+
Info text about the dataset
|
| 375 |
+
"""
|
| 376 |
+
if not tensor_dir or not tensor_dir.strip():
|
| 377 |
+
return "❌ Please enter a tensor directory path"
|
| 378 |
+
|
| 379 |
+
tensor_dir = tensor_dir.strip()
|
| 380 |
+
|
| 381 |
+
if not os.path.exists(tensor_dir):
|
| 382 |
+
return f"❌ Directory not found: {tensor_dir}"
|
| 383 |
+
|
| 384 |
+
if not os.path.isdir(tensor_dir):
|
| 385 |
+
return f"❌ Not a directory: {tensor_dir}"
|
| 386 |
+
|
| 387 |
+
# Check for manifest
|
| 388 |
+
manifest_path = os.path.join(tensor_dir, "manifest.json")
|
| 389 |
+
if os.path.exists(manifest_path):
|
| 390 |
+
try:
|
| 391 |
+
with open(manifest_path, 'r') as f:
|
| 392 |
+
manifest = json.load(f)
|
| 393 |
+
|
| 394 |
+
num_samples = manifest.get("num_samples", 0)
|
| 395 |
+
metadata = manifest.get("metadata", {})
|
| 396 |
+
name = metadata.get("name", "Unknown")
|
| 397 |
+
custom_tag = metadata.get("custom_tag", "")
|
| 398 |
+
|
| 399 |
+
info = f"✅ Loaded preprocessed dataset: {name}\n"
|
| 400 |
+
info += f"📊 Samples: {num_samples} preprocessed tensors\n"
|
| 401 |
+
info += f"🏷️ Custom Tag: {custom_tag or '(none)'}"
|
| 402 |
+
|
| 403 |
+
return info
|
| 404 |
+
except Exception as e:
|
| 405 |
+
logger.warning(f"Failed to read manifest: {e}")
|
| 406 |
+
|
| 407 |
+
# Fallback: count .pt files
|
| 408 |
+
pt_files = [f for f in os.listdir(tensor_dir) if f.endswith('.pt')]
|
| 409 |
+
|
| 410 |
+
if not pt_files:
|
| 411 |
+
return f"❌ No .pt tensor files found in {tensor_dir}"
|
| 412 |
+
|
| 413 |
+
info = f"✅ Found {len(pt_files)} tensor files in {tensor_dir}\n"
|
| 414 |
+
info += "⚠️ No manifest.json found - using all .pt files"
|
| 415 |
+
|
| 416 |
+
return info
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
# Training handlers
|
| 420 |
+
|
| 421 |
+
import time
|
| 422 |
+
import re
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def _format_duration(seconds):
|
| 426 |
+
"""Format seconds to human readable string."""
|
| 427 |
+
seconds = int(seconds)
|
| 428 |
+
if seconds < 60:
|
| 429 |
+
return f"{seconds}s"
|
| 430 |
+
elif seconds < 3600:
|
| 431 |
+
return f"{seconds // 60}m {seconds % 60}s"
|
| 432 |
+
else:
|
| 433 |
+
return f"{seconds // 3600}h {(seconds % 3600) // 60}m"
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def start_training(
|
| 437 |
+
tensor_dir: str,
|
| 438 |
+
dit_handler,
|
| 439 |
+
lora_rank: int,
|
| 440 |
+
lora_alpha: int,
|
| 441 |
+
lora_dropout: float,
|
| 442 |
+
learning_rate: float,
|
| 443 |
+
train_epochs: int,
|
| 444 |
+
train_batch_size: int,
|
| 445 |
+
gradient_accumulation: int,
|
| 446 |
+
save_every_n_epochs: int,
|
| 447 |
+
training_shift: float,
|
| 448 |
+
training_seed: int,
|
| 449 |
+
lora_output_dir: str,
|
| 450 |
+
training_state: Dict,
|
| 451 |
+
progress=None,
|
| 452 |
+
):
|
| 453 |
+
"""Start LoRA training from preprocessed tensors.
|
| 454 |
+
|
| 455 |
+
This is a generator function that yields progress updates.
|
| 456 |
+
"""
|
| 457 |
+
if not tensor_dir or not tensor_dir.strip():
|
| 458 |
+
yield "❌ Please enter a tensor directory path", "", None, training_state
|
| 459 |
+
return
|
| 460 |
+
|
| 461 |
+
tensor_dir = tensor_dir.strip()
|
| 462 |
+
|
| 463 |
+
if not os.path.exists(tensor_dir):
|
| 464 |
+
yield f"❌ Tensor directory not found: {tensor_dir}", "", None, training_state
|
| 465 |
+
return
|
| 466 |
+
|
| 467 |
+
if dit_handler is None or dit_handler.model is None:
|
| 468 |
+
yield "❌ Model not initialized. Please initialize the service first.", "", None, training_state
|
| 469 |
+
return
|
| 470 |
+
|
| 471 |
+
# Check for required training dependencies
|
| 472 |
+
try:
|
| 473 |
+
from lightning.fabric import Fabric
|
| 474 |
+
from peft import get_peft_model, LoraConfig
|
| 475 |
+
except ImportError as e:
|
| 476 |
+
yield f"❌ Missing required packages: {e}\nPlease install: pip install peft lightning", "", None, training_state
|
| 477 |
+
return
|
| 478 |
+
|
| 479 |
+
training_state["is_training"] = True
|
| 480 |
+
training_state["should_stop"] = False
|
| 481 |
+
|
| 482 |
+
try:
|
| 483 |
+
from acestep.training.trainer import LoRATrainer
|
| 484 |
+
from acestep.training.configs import LoRAConfig as LoRAConfigClass, TrainingConfig
|
| 485 |
+
|
| 486 |
+
# Create configs
|
| 487 |
+
lora_config = LoRAConfigClass(
|
| 488 |
+
r=lora_rank,
|
| 489 |
+
alpha=lora_alpha,
|
| 490 |
+
dropout=lora_dropout,
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
training_config = TrainingConfig(
|
| 494 |
+
shift=training_shift,
|
| 495 |
+
learning_rate=learning_rate,
|
| 496 |
+
batch_size=train_batch_size,
|
| 497 |
+
gradient_accumulation_steps=gradient_accumulation,
|
| 498 |
+
max_epochs=train_epochs,
|
| 499 |
+
save_every_n_epochs=save_every_n_epochs,
|
| 500 |
+
seed=training_seed,
|
| 501 |
+
output_dir=lora_output_dir,
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
import pandas as pd
|
| 505 |
+
|
| 506 |
+
# Initialize training log and loss history
|
| 507 |
+
log_lines = []
|
| 508 |
+
loss_data = pd.DataFrame({"step": [0], "loss": [0.0]})
|
| 509 |
+
|
| 510 |
+
# Start timer
|
| 511 |
+
start_time = time.time()
|
| 512 |
+
|
| 513 |
+
yield f"🚀 Starting training from {tensor_dir}...", "", loss_data, training_state
|
| 514 |
+
|
| 515 |
+
# Create trainer
|
| 516 |
+
trainer = LoRATrainer(
|
| 517 |
+
dit_handler=dit_handler,
|
| 518 |
+
lora_config=lora_config,
|
| 519 |
+
training_config=training_config,
|
| 520 |
+
)
|
| 521 |
+
|
| 522 |
+
# Collect loss history
|
| 523 |
+
step_list = []
|
| 524 |
+
loss_list = []
|
| 525 |
+
|
| 526 |
+
# Train with progress updates using preprocessed tensors
|
| 527 |
+
for step, loss, status in trainer.train_from_preprocessed(tensor_dir, training_state):
|
| 528 |
+
# Calculate elapsed time and ETA
|
| 529 |
+
elapsed_seconds = time.time() - start_time
|
| 530 |
+
time_info = f"⏱️ Elapsed: {_format_duration(elapsed_seconds)}"
|
| 531 |
+
|
| 532 |
+
# Parse "Epoch x/y" from status to calculate ETA
|
| 533 |
+
match = re.search(r"Epoch\s+(\d+)/(\d+)", str(status))
|
| 534 |
+
if match:
|
| 535 |
+
current_ep = int(match.group(1))
|
| 536 |
+
total_ep = int(match.group(2))
|
| 537 |
+
if current_ep > 0:
|
| 538 |
+
eta_seconds = (elapsed_seconds / current_ep) * (total_ep - current_ep)
|
| 539 |
+
time_info += f" | ETA: ~{_format_duration(eta_seconds)}"
|
| 540 |
+
|
| 541 |
+
# Display status with time info
|
| 542 |
+
display_status = f"{status}\n{time_info}"
|
| 543 |
+
|
| 544 |
+
# Terminal log
|
| 545 |
+
log_msg = f"[{_format_duration(elapsed_seconds)}] Step {step}: {status}"
|
| 546 |
+
logger.info(log_msg)
|
| 547 |
+
|
| 548 |
+
# Add to UI log
|
| 549 |
+
log_lines.append(status)
|
| 550 |
+
if len(log_lines) > 15:
|
| 551 |
+
log_lines = log_lines[-15:]
|
| 552 |
+
log_text = "\n".join(log_lines)
|
| 553 |
+
|
| 554 |
+
# Track loss for plot (only valid values)
|
| 555 |
+
if step > 0 and loss is not None and loss == loss: # Check for NaN
|
| 556 |
+
step_list.append(step)
|
| 557 |
+
loss_list.append(float(loss))
|
| 558 |
+
loss_data = pd.DataFrame({"step": step_list, "loss": loss_list})
|
| 559 |
+
|
| 560 |
+
yield display_status, log_text, loss_data, training_state
|
| 561 |
+
|
| 562 |
+
if training_state.get("should_stop", False):
|
| 563 |
+
logger.info("⏹️ Training stopped by user")
|
| 564 |
+
log_lines.append("⏹️ Training stopped by user")
|
| 565 |
+
yield f"⏹️ Stopped ({time_info})", "\n".join(log_lines[-15:]), loss_data, training_state
|
| 566 |
+
break
|
| 567 |
+
|
| 568 |
+
total_time = time.time() - start_time
|
| 569 |
+
training_state["is_training"] = False
|
| 570 |
+
completion_msg = f"✅ Training completed! Total time: {_format_duration(total_time)}"
|
| 571 |
+
|
| 572 |
+
logger.info(completion_msg)
|
| 573 |
+
log_lines.append(completion_msg)
|
| 574 |
+
|
| 575 |
+
yield completion_msg, "\n".join(log_lines[-15:]), loss_data, training_state
|
| 576 |
+
|
| 577 |
+
except Exception as e:
|
| 578 |
+
logger.exception("Training error")
|
| 579 |
+
training_state["is_training"] = False
|
| 580 |
+
import pandas as pd
|
| 581 |
+
empty_df = pd.DataFrame({"step": [], "loss": []})
|
| 582 |
+
yield f"❌ Error: {str(e)}", str(e), empty_df, training_state
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def stop_training(training_state: Dict) -> Tuple[str, Dict]:
|
| 586 |
+
"""Stop the current training process.
|
| 587 |
+
|
| 588 |
+
Returns:
|
| 589 |
+
Tuple of (status, training_state)
|
| 590 |
+
"""
|
| 591 |
+
if not training_state.get("is_training", False):
|
| 592 |
+
return "⚠️ No training in progress", training_state
|
| 593 |
+
|
| 594 |
+
training_state["should_stop"] = True
|
| 595 |
+
return "⏹️ Stopping training...", training_state
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
def export_lora(
|
| 599 |
+
export_path: str,
|
| 600 |
+
lora_output_dir: str,
|
| 601 |
+
) -> str:
|
| 602 |
+
"""Export the trained LoRA weights.
|
| 603 |
+
|
| 604 |
+
Returns:
|
| 605 |
+
Status message
|
| 606 |
+
"""
|
| 607 |
+
if not export_path or not export_path.strip():
|
| 608 |
+
return "❌ Please enter an export path"
|
| 609 |
+
|
| 610 |
+
# Check if there's a trained model to export
|
| 611 |
+
final_dir = os.path.join(lora_output_dir, "final")
|
| 612 |
+
checkpoint_dir = os.path.join(lora_output_dir, "checkpoints")
|
| 613 |
+
|
| 614 |
+
# Prefer final, fallback to checkpoints
|
| 615 |
+
if os.path.exists(final_dir):
|
| 616 |
+
source_path = final_dir
|
| 617 |
+
elif os.path.exists(checkpoint_dir):
|
| 618 |
+
# Find the latest checkpoint
|
| 619 |
+
checkpoints = [d for d in os.listdir(checkpoint_dir) if d.startswith("epoch_")]
|
| 620 |
+
if not checkpoints:
|
| 621 |
+
return "❌ No checkpoints found"
|
| 622 |
+
|
| 623 |
+
checkpoints.sort(key=lambda x: int(x.split("_")[1]))
|
| 624 |
+
latest = checkpoints[-1]
|
| 625 |
+
source_path = os.path.join(checkpoint_dir, latest)
|
| 626 |
+
else:
|
| 627 |
+
return f"❌ No trained model found in {lora_output_dir}"
|
| 628 |
+
|
| 629 |
+
try:
|
| 630 |
+
import shutil
|
| 631 |
+
|
| 632 |
+
export_path = export_path.strip()
|
| 633 |
+
os.makedirs(os.path.dirname(export_path) if os.path.dirname(export_path) else ".", exist_ok=True)
|
| 634 |
+
|
| 635 |
+
if os.path.exists(export_path):
|
| 636 |
+
shutil.rmtree(export_path)
|
| 637 |
+
|
| 638 |
+
shutil.copytree(source_path, export_path)
|
| 639 |
+
|
| 640 |
+
return f"✅ LoRA exported to {export_path}"
|
| 641 |
+
|
| 642 |
+
except Exception as e:
|
| 643 |
+
logger.exception("Export error")
|
| 644 |
+
return f"❌ Export failed: {str(e)}"
|
spaces/Ace-Step-v1.5/acestep/gradio_ui/i18n.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Internationalization (i18n) module for Gradio UI
|
| 3 |
+
Supports multiple languages with easy translation management
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
from typing import Dict, Optional
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class I18n:
|
| 11 |
+
"""Internationalization handler"""
|
| 12 |
+
|
| 13 |
+
def __init__(self, default_language: str = "en"):
|
| 14 |
+
"""
|
| 15 |
+
Initialize i18n handler
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
default_language: Default language code (en, zh, ja, etc.)
|
| 19 |
+
"""
|
| 20 |
+
self.current_language = default_language
|
| 21 |
+
self.translations: Dict[str, Dict[str, str]] = {}
|
| 22 |
+
self._load_all_translations()
|
| 23 |
+
|
| 24 |
+
def _load_all_translations(self):
|
| 25 |
+
"""Load all translation files from i18n directory"""
|
| 26 |
+
current_file = os.path.abspath(__file__)
|
| 27 |
+
module_dir = os.path.dirname(current_file)
|
| 28 |
+
i18n_dir = os.path.join(module_dir, "i18n")
|
| 29 |
+
|
| 30 |
+
if not os.path.exists(i18n_dir):
|
| 31 |
+
# Create i18n directory if it doesn't exist
|
| 32 |
+
os.makedirs(i18n_dir)
|
| 33 |
+
return
|
| 34 |
+
|
| 35 |
+
# Load all JSON files in i18n directory
|
| 36 |
+
for filename in os.listdir(i18n_dir):
|
| 37 |
+
if filename.endswith(".json"):
|
| 38 |
+
lang_code = filename[:-5] # Remove .json extension
|
| 39 |
+
filepath = os.path.join(i18n_dir, filename)
|
| 40 |
+
try:
|
| 41 |
+
with open(filepath, 'r', encoding='utf-8') as f:
|
| 42 |
+
self.translations[lang_code] = json.load(f)
|
| 43 |
+
except Exception as e:
|
| 44 |
+
print(f"Error loading translation file {filename}: {e}")
|
| 45 |
+
|
| 46 |
+
def set_language(self, language: str):
|
| 47 |
+
"""Set current language"""
|
| 48 |
+
if language in self.translations:
|
| 49 |
+
self.current_language = language
|
| 50 |
+
else:
|
| 51 |
+
print(f"Warning: Language '{language}' not found, using default")
|
| 52 |
+
|
| 53 |
+
def t(self, key: str, **kwargs) -> str:
|
| 54 |
+
"""
|
| 55 |
+
Translate a key to current language
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
key: Translation key (dot-separated for nested keys)
|
| 59 |
+
**kwargs: Optional format parameters
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
Translated string
|
| 63 |
+
"""
|
| 64 |
+
# Get translation from current language
|
| 65 |
+
translation = self._get_nested_value(
|
| 66 |
+
self.translations.get(self.current_language, {}),
|
| 67 |
+
key
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# Fallback to English if not found
|
| 71 |
+
if translation is None:
|
| 72 |
+
translation = self._get_nested_value(
|
| 73 |
+
self.translations.get('en', {}),
|
| 74 |
+
key
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# Final fallback to key itself
|
| 78 |
+
if translation is None:
|
| 79 |
+
translation = key
|
| 80 |
+
|
| 81 |
+
# Apply formatting if kwargs provided
|
| 82 |
+
if kwargs:
|
| 83 |
+
try:
|
| 84 |
+
translation = translation.format(**kwargs)
|
| 85 |
+
except KeyError:
|
| 86 |
+
pass
|
| 87 |
+
|
| 88 |
+
return translation
|
| 89 |
+
|
| 90 |
+
def _get_nested_value(self, data: dict, key: str) -> Optional[str]:
|
| 91 |
+
"""
|
| 92 |
+
Get nested dictionary value using dot notation
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
data: Dictionary to search
|
| 96 |
+
key: Dot-separated key (e.g., "section.subsection.key")
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
Value if found, None otherwise
|
| 100 |
+
"""
|
| 101 |
+
keys = key.split('.')
|
| 102 |
+
current = data
|
| 103 |
+
|
| 104 |
+
for k in keys:
|
| 105 |
+
if isinstance(current, dict) and k in current:
|
| 106 |
+
current = current[k]
|
| 107 |
+
else:
|
| 108 |
+
return None
|
| 109 |
+
|
| 110 |
+
return current if isinstance(current, str) else None
|
| 111 |
+
|
| 112 |
+
def get_available_languages(self) -> list:
|
| 113 |
+
"""Get list of available language codes"""
|
| 114 |
+
return list(self.translations.keys())
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# Global i18n instance
|
| 118 |
+
_i18n_instance: Optional[I18n] = None
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def get_i18n(language: Optional[str] = None) -> I18n:
|
| 122 |
+
"""
|
| 123 |
+
Get global i18n instance
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
language: Optional language to set
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
I18n instance
|
| 130 |
+
"""
|
| 131 |
+
global _i18n_instance
|
| 132 |
+
|
| 133 |
+
if _i18n_instance is None:
|
| 134 |
+
_i18n_instance = I18n(default_language=language or "en")
|
| 135 |
+
elif language is not None:
|
| 136 |
+
_i18n_instance.set_language(language)
|
| 137 |
+
|
| 138 |
+
return _i18n_instance
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def t(key: str, **kwargs) -> str:
|
| 142 |
+
"""
|
| 143 |
+
Convenience function for translation
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
key: Translation key
|
| 147 |
+
**kwargs: Optional format parameters
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
Translated string
|
| 151 |
+
"""
|
| 152 |
+
return get_i18n().t(key, **kwargs)
|
spaces/Ace-Step-v1.5/acestep/gradio_ui/i18n/en.json
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"app": {
|
| 3 |
+
"title": "🎛️ ACE-Step V1.5 Playground💡",
|
| 4 |
+
"subtitle": "Pushing the Boundaries of Open-Source Music Generation"
|
| 5 |
+
},
|
| 6 |
+
"dataset": {
|
| 7 |
+
"title": "📊 Dataset Explorer",
|
| 8 |
+
"dataset_label": "Dataset",
|
| 9 |
+
"dataset_info": "Choose dataset to explore",
|
| 10 |
+
"import_btn": "📥 Import Dataset",
|
| 11 |
+
"search_type_label": "Search Type",
|
| 12 |
+
"search_type_info": "How to find items",
|
| 13 |
+
"search_value_label": "Search Value",
|
| 14 |
+
"search_value_placeholder": "Enter keys or index (leave empty for random)",
|
| 15 |
+
"search_value_info": "Keys: exact match, Index: 0 to dataset size-1",
|
| 16 |
+
"instruction_label": "📝 Instruction",
|
| 17 |
+
"instruction_placeholder": "No instruction available",
|
| 18 |
+
"metadata_title": "📋 Item Metadata (JSON)",
|
| 19 |
+
"metadata_label": "Complete Item Information",
|
| 20 |
+
"source_audio": "Source Audio",
|
| 21 |
+
"target_audio": "Target Audio",
|
| 22 |
+
"reference_audio": "Reference Audio",
|
| 23 |
+
"get_item_btn": "🔍 Get Item",
|
| 24 |
+
"use_src_checkbox": "Use Source Audio from Dataset",
|
| 25 |
+
"use_src_info": "Check to use the source audio from dataset",
|
| 26 |
+
"data_status_label": "📊 Data Status",
|
| 27 |
+
"data_status_default": "❌ No dataset imported",
|
| 28 |
+
"autofill_btn": "📋 Auto-fill Generation Form"
|
| 29 |
+
},
|
| 30 |
+
"service": {
|
| 31 |
+
"title": "🔧 Service Configuration",
|
| 32 |
+
"checkpoint_label": "Checkpoint File",
|
| 33 |
+
"checkpoint_info": "Select a trained model checkpoint file (full path or filename)",
|
| 34 |
+
"refresh_btn": "🔄 Refresh",
|
| 35 |
+
"model_path_label": "Main Model Path",
|
| 36 |
+
"model_path_info": "Select the model configuration directory (auto-scanned from checkpoints)",
|
| 37 |
+
"device_label": "Device",
|
| 38 |
+
"device_info": "Processing device (auto-detect recommended)",
|
| 39 |
+
"lm_model_path_label": "5Hz LM Model Path",
|
| 40 |
+
"lm_model_path_info": "Select the 5Hz LM model checkpoint (auto-scanned from checkpoints)",
|
| 41 |
+
"backend_label": "5Hz LM Backend",
|
| 42 |
+
"backend_info": "Select backend for 5Hz LM: vllm (faster) or pt (PyTorch, more compatible)",
|
| 43 |
+
"init_llm_label": "Initialize 5Hz LM",
|
| 44 |
+
"init_llm_info": "Check to initialize 5Hz LM during service initialization",
|
| 45 |
+
"flash_attention_label": "Use Flash Attention",
|
| 46 |
+
"flash_attention_info_enabled": "Enable flash attention for faster inference (requires flash_attn package)",
|
| 47 |
+
"flash_attention_info_disabled": "Flash attention not available (flash_attn package not installed)",
|
| 48 |
+
"offload_cpu_label": "Offload to CPU",
|
| 49 |
+
"offload_cpu_info": "Offload models to CPU when not in use to save GPU memory",
|
| 50 |
+
"offload_dit_cpu_label": "Offload DiT to CPU",
|
| 51 |
+
"offload_dit_cpu_info": "Offload DiT to CPU (needs Offload to CPU)",
|
| 52 |
+
"init_btn": "Initialize Service",
|
| 53 |
+
"status_label": "Status",
|
| 54 |
+
"language_label": "UI Language",
|
| 55 |
+
"language_info": "Select interface language"
|
| 56 |
+
},
|
| 57 |
+
"generation": {
|
| 58 |
+
"required_inputs": "📝 Required Inputs",
|
| 59 |
+
"task_type_label": "Task Type",
|
| 60 |
+
"task_type_info": "Select the task type for generation",
|
| 61 |
+
"instruction_label": "Instruction",
|
| 62 |
+
"instruction_info": "Instruction is automatically generated based on task type",
|
| 63 |
+
"load_btn": "Load",
|
| 64 |
+
"track_name_label": "Track Name",
|
| 65 |
+
"track_name_info": "Select track name for lego/extract tasks",
|
| 66 |
+
"track_classes_label": "Track Names",
|
| 67 |
+
"track_classes_info": "Select multiple track classes for complete task",
|
| 68 |
+
"audio_uploads": "🎵 Audio Uploads",
|
| 69 |
+
"reference_audio": "Reference Audio (optional)",
|
| 70 |
+
"source_audio": "Source Audio (optional)",
|
| 71 |
+
"convert_codes_btn": "Convert to Codes",
|
| 72 |
+
"lm_codes_hints": "🎼 LM Codes Hints",
|
| 73 |
+
"lm_codes_label": "LM Codes Hints",
|
| 74 |
+
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
|
| 75 |
+
"lm_codes_info": "Paste LM codes hints for text2music generation",
|
| 76 |
+
"lm_codes_sample": "LM Codes Hints (Sample {n})",
|
| 77 |
+
"lm_codes_sample_info": "Codes for sample {n}",
|
| 78 |
+
"transcribe_btn": "Transcribe",
|
| 79 |
+
"repainting_controls": "🎨 Repainting Controls (seconds)",
|
| 80 |
+
"repainting_start": "Repainting Start",
|
| 81 |
+
"repainting_end": "Repainting End",
|
| 82 |
+
"mode_label": "Generation Mode",
|
| 83 |
+
"mode_info": "Simple: describe music in natural language. Custom: full control over caption and lyrics.",
|
| 84 |
+
"mode_simple": "Simple",
|
| 85 |
+
"mode_custom": "Custom",
|
| 86 |
+
"simple_query_label": "Song Description",
|
| 87 |
+
"simple_query_placeholder": "Describe the music you want to create, e.g., 'a soft Bengali love song for a quiet evening'. Leave empty for a random sample.",
|
| 88 |
+
"simple_query_info": "Enter a natural language description of the music you want to generate",
|
| 89 |
+
"simple_vocal_language_label": "Vocal Language (optional)",
|
| 90 |
+
"simple_vocal_language_info": "Select preferred language(s) for lyrics. Use 'unknown' for any language.",
|
| 91 |
+
"create_sample_btn": "Create Sample",
|
| 92 |
+
"caption_title": "📝 Music Caption",
|
| 93 |
+
"caption_label": "Music Caption (optional)",
|
| 94 |
+
"caption_placeholder": "A peaceful acoustic guitar melody with soft vocals...",
|
| 95 |
+
"caption_info": "Describe the style, genre, instruments, and mood",
|
| 96 |
+
"lyrics_title": "📝 Lyrics",
|
| 97 |
+
"lyrics_label": "Lyrics (optional)",
|
| 98 |
+
"lyrics_placeholder": "[Verse 1]\\nUnder the starry night\\nI feel so alive...",
|
| 99 |
+
"lyrics_info": "Song lyrics with structure",
|
| 100 |
+
"instrumental_label": "Instrumental",
|
| 101 |
+
"format_btn": "Format",
|
| 102 |
+
"optional_params": "⚙️ Optional Parameters",
|
| 103 |
+
"vocal_language_label": "Vocal Language (optional)",
|
| 104 |
+
"vocal_language_info": "use `unknown` for inst",
|
| 105 |
+
"bpm_label": "BPM (optional)",
|
| 106 |
+
"bpm_info": "leave empty for N/A",
|
| 107 |
+
"keyscale_label": "KeyScale (optional)",
|
| 108 |
+
"keyscale_placeholder": "Leave empty for N/A",
|
| 109 |
+
"keyscale_info": "A-G, #/♭, major/minor",
|
| 110 |
+
"timesig_label": "Time Signature (optional)",
|
| 111 |
+
"timesig_info": "2/4, 3/4, 4/4...",
|
| 112 |
+
"duration_label": "Audio Duration (seconds)",
|
| 113 |
+
"duration_info": "Use -1 for random",
|
| 114 |
+
"batch_size_label": "Batch Size",
|
| 115 |
+
"batch_size_info": "Number of audio to generate (max 8)",
|
| 116 |
+
"advanced_settings": "🔧 Advanced Settings",
|
| 117 |
+
"inference_steps_label": "DiT Inference Steps",
|
| 118 |
+
"inference_steps_info": "Turbo: max 8, Base: max 200",
|
| 119 |
+
"guidance_scale_label": "DiT Guidance Scale (Only support for base model)",
|
| 120 |
+
"guidance_scale_info": "Higher values follow text more closely",
|
| 121 |
+
"seed_label": "Seed",
|
| 122 |
+
"seed_info": "Use comma-separated values for batches",
|
| 123 |
+
"random_seed_label": "Random Seed",
|
| 124 |
+
"random_seed_info": "Enable to auto-generate seeds",
|
| 125 |
+
"audio_format_label": "Audio Format",
|
| 126 |
+
"audio_format_info": "Audio format for saved files",
|
| 127 |
+
"use_adg_label": "Use ADG",
|
| 128 |
+
"use_adg_info": "Enable Angle Domain Guidance",
|
| 129 |
+
"shift_label": "Shift",
|
| 130 |
+
"shift_info": "Timestep shift factor for base models (range 1.0~5.0, default 3.0). Not effective for turbo models.",
|
| 131 |
+
"infer_method_label": "Inference Method",
|
| 132 |
+
"infer_method_info": "Diffusion inference method. ODE (Euler) is faster, SDE (stochastic) may produce different results.",
|
| 133 |
+
"custom_timesteps_label": "Custom Timesteps",
|
| 134 |
+
"custom_timesteps_info": "Optional: comma-separated values from 1.0 to 0.0 (e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference steps and shift.",
|
| 135 |
+
"cfg_interval_start": "CFG Interval Start",
|
| 136 |
+
"cfg_interval_end": "CFG Interval End",
|
| 137 |
+
"lm_params_title": "🤖 LM Generation Parameters",
|
| 138 |
+
"lm_temperature_label": "LM Temperature",
|
| 139 |
+
"lm_temperature_info": "5Hz LM temperature (higher = more random)",
|
| 140 |
+
"lm_cfg_scale_label": "LM CFG Scale",
|
| 141 |
+
"lm_cfg_scale_info": "5Hz LM CFG (1.0 = no CFG)",
|
| 142 |
+
"lm_top_k_label": "LM Top-K",
|
| 143 |
+
"lm_top_k_info": "Top-K (0 = disabled)",
|
| 144 |
+
"lm_top_p_label": "LM Top-P",
|
| 145 |
+
"lm_top_p_info": "Top-P (1.0 = disabled)",
|
| 146 |
+
"lm_negative_prompt_label": "LM Negative Prompt",
|
| 147 |
+
"lm_negative_prompt_placeholder": "Enter negative prompt for CFG (default: NO USER INPUT)",
|
| 148 |
+
"lm_negative_prompt_info": "Negative prompt (use when LM CFG Scale > 1.0)",
|
| 149 |
+
"cot_metas_label": "CoT Metas",
|
| 150 |
+
"cot_metas_info": "Use LM to generate CoT metadata (uncheck to skip LM CoT generation)",
|
| 151 |
+
"cot_language_label": "CoT Language",
|
| 152 |
+
"cot_language_info": "Generate language in CoT (chain-of-thought)",
|
| 153 |
+
"constrained_debug_label": "Constrained Decoding Debug",
|
| 154 |
+
"constrained_debug_info": "Enable debug logging for constrained decoding (check to see detailed logs)",
|
| 155 |
+
"auto_score_label": "Auto Score",
|
| 156 |
+
"auto_score_info": "Automatically calculate quality scores for all generated audios",
|
| 157 |
+
"auto_lrc_label": "Auto LRC",
|
| 158 |
+
"auto_lrc_info": "Automatically generate LRC lyrics timestamps for all generated audios",
|
| 159 |
+
"lm_batch_chunk_label": "LM Batch Chunk Size",
|
| 160 |
+
"lm_batch_chunk_info": "Max items per LM batch chunk (default: 8, limited by GPU memory)",
|
| 161 |
+
"codes_strength_label": "LM Codes Strength",
|
| 162 |
+
"codes_strength_info": "Control how many denoising steps use LM-generated codes",
|
| 163 |
+
"cover_strength_label": "Audio Cover Strength",
|
| 164 |
+
"cover_strength_info": "Control how many denoising steps use cover mode",
|
| 165 |
+
"score_sensitivity_label": "Quality Score Sensitivity",
|
| 166 |
+
"score_sensitivity_info": "Lower = more sensitive (default: 1.0). Adjusts how PMI maps to [0,1]",
|
| 167 |
+
"think_label": "Think",
|
| 168 |
+
"parallel_thinking_label": "ParallelThinking",
|
| 169 |
+
"generate_btn": "🎵 Generate Music",
|
| 170 |
+
"autogen_label": "AutoGen",
|
| 171 |
+
"caption_rewrite_label": "CaptionRewrite"
|
| 172 |
+
},
|
| 173 |
+
"results": {
|
| 174 |
+
"title": "🎵 Results",
|
| 175 |
+
"generated_music": "🎵 Generated Music (Sample {n})",
|
| 176 |
+
"send_to_src_btn": "🔗 Send To Src Audio",
|
| 177 |
+
"send_to_cover_btn": "🔗 Send To Cover",
|
| 178 |
+
"send_to_repaint_btn": "🔗 Send To Repaint",
|
| 179 |
+
"save_btn": "💾 Save",
|
| 180 |
+
"score_btn": "📊 Score",
|
| 181 |
+
"lrc_btn": "🎵 LRC",
|
| 182 |
+
"quality_score_label": "Quality Score (Sample {n})",
|
| 183 |
+
"quality_score_placeholder": "Click 'Score' to calculate perplexity-based quality score",
|
| 184 |
+
"codes_label": "LM Codes (Sample {n})",
|
| 185 |
+
"lrc_label": "Lyrics Timestamps (Sample {n})",
|
| 186 |
+
"lrc_placeholder": "Click 'LRC' to generate timestamps",
|
| 187 |
+
"details_accordion": "📊 Score & LRC & LM Codes",
|
| 188 |
+
"generation_status": "Generation Status",
|
| 189 |
+
"current_batch": "Current Batch",
|
| 190 |
+
"batch_indicator": "Batch {current} / {total}",
|
| 191 |
+
"next_batch_status": "Next Batch Status",
|
| 192 |
+
"prev_btn": "◀ Previous",
|
| 193 |
+
"next_btn": "Next ▶",
|
| 194 |
+
"restore_params_btn": "↙️ Apply These Settings to UI (Restore Batch Parameters)",
|
| 195 |
+
"batch_results_title": "📁 Batch Results & Generation Details",
|
| 196 |
+
"all_files_label": "📁 All Generated Files (Download)",
|
| 197 |
+
"generation_details": "Generation Details"
|
| 198 |
+
},
|
| 199 |
+
"messages": {
|
| 200 |
+
"no_audio_to_save": "❌ No audio to save",
|
| 201 |
+
"save_success": "✅ Saved audio and metadata to {filename}",
|
| 202 |
+
"save_failed": "❌ Failed to save: {error}",
|
| 203 |
+
"no_file_selected": "⚠️ No file selected",
|
| 204 |
+
"params_loaded": "✅ Parameters loaded from {filename}",
|
| 205 |
+
"invalid_json": "❌ Invalid JSON file: {error}",
|
| 206 |
+
"load_error": "❌ Error loading file: {error}",
|
| 207 |
+
"example_loaded": "📁 Loaded example from {filename}",
|
| 208 |
+
"example_failed": "Failed to parse JSON file {filename}: {error}",
|
| 209 |
+
"example_error": "Error loading example: {error}",
|
| 210 |
+
"lm_generated": "🤖 Generated example using LM",
|
| 211 |
+
"lm_fallback": "Failed to generate example using LM, falling back to examples directory",
|
| 212 |
+
"lm_not_initialized": "❌ 5Hz LM not initialized. Please initialize it first.",
|
| 213 |
+
"autogen_enabled": "🔄 AutoGen enabled - next batch will generate after this",
|
| 214 |
+
"batch_ready": "✅ Batch {n} ready! Click 'Next' to view.",
|
| 215 |
+
"batch_generating": "🔄 Starting background generation for Batch {n}...",
|
| 216 |
+
"batch_failed": "❌ Background generation failed: {error}",
|
| 217 |
+
"viewing_batch": "✅ Viewing Batch {n}",
|
| 218 |
+
"at_first_batch": "Already at first batch",
|
| 219 |
+
"at_last_batch": "No next batch available",
|
| 220 |
+
"batch_not_found": "Batch {n} not found in queue",
|
| 221 |
+
"no_batch_data": "No batch data found to restore.",
|
| 222 |
+
"params_restored": "✅ UI Parameters restored from Batch {n}",
|
| 223 |
+
"scoring_failed": "❌ Error: Batch data not found",
|
| 224 |
+
"no_codes": "❌ No audio codes available. Please generate music first.",
|
| 225 |
+
"score_failed": "❌ Scoring failed: {error}",
|
| 226 |
+
"score_error": "❌ Error calculating score: {error}",
|
| 227 |
+
"lrc_no_batch_data": "❌ No batch data found. Please generate music first.",
|
| 228 |
+
"lrc_no_extra_outputs": "❌ No extra outputs found. Condition tensors not available.",
|
| 229 |
+
"lrc_missing_tensors": "❌ Missing required tensors for LRC generation.",
|
| 230 |
+
"lrc_sample_not_exist": "❌ Sample does not exist in current batch.",
|
| 231 |
+
"lrc_empty_result": "⚠️ LRC generation produced empty result.",
|
| 232 |
+
"empty_query": "⚠️ Please enter a music description.",
|
| 233 |
+
"sample_creation_failed": "❌ Failed to create sample. Please try again.",
|
| 234 |
+
"sample_created": "✅ Sample created! Review the caption and lyrics, then click Generate Music.",
|
| 235 |
+
"simple_examples_not_found": "⚠️ Simple mode examples directory not found.",
|
| 236 |
+
"simple_examples_empty": "⚠️ No example files found in simple mode examples.",
|
| 237 |
+
"simple_example_loaded": "🎲 Loaded random example from {filename}",
|
| 238 |
+
"format_success": "✅ Caption and lyrics formatted successfully",
|
| 239 |
+
"format_failed": "❌ Format failed: {error}",
|
| 240 |
+
"skipping_metas_cot": "⚡ Skipping Phase 1 metas COT (sample already formatted)",
|
| 241 |
+
"invalid_timesteps_format": "⚠️ Invalid timesteps format. Using default schedule.",
|
| 242 |
+
"timesteps_out_of_range": "⚠️ Timesteps must be in range [0, 1]. Using default schedule.",
|
| 243 |
+
"timesteps_count_mismatch": "⚠️ Timesteps count ({actual}) differs from inference_steps ({expected}). Using timesteps count."
|
| 244 |
+
}
|
| 245 |
+
}
|
spaces/Ace-Step-v1.5/acestep/gradio_ui/i18n/ja.json
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"app": {
|
| 3 |
+
"title": "🎛️ ACE-Step V1.5 プレイグラウンド💡",
|
| 4 |
+
"subtitle": "オープンソース音楽生成の限界を押し広げる"
|
| 5 |
+
},
|
| 6 |
+
"dataset": {
|
| 7 |
+
"title": "📊 データセットエクスプローラー",
|
| 8 |
+
"dataset_label": "データセット",
|
| 9 |
+
"dataset_info": "探索するデータセットを選択",
|
| 10 |
+
"import_btn": "📥 データセットをインポート",
|
| 11 |
+
"search_type_label": "検索タイプ",
|
| 12 |
+
"search_type_info": "アイテムの検索方法",
|
| 13 |
+
"search_value_label": "検索値",
|
| 14 |
+
"search_value_placeholder": "キーまたはインデックスを入力(空白の場合はランダム)",
|
| 15 |
+
"search_value_info": "キー: 完全一致、インデックス: 0からデータセットサイズ-1",
|
| 16 |
+
"instruction_label": "📝 指示",
|
| 17 |
+
"instruction_placeholder": "利用可能な指示がありません",
|
| 18 |
+
"metadata_title": "📋 アイテムメタデータ (JSON)",
|
| 19 |
+
"metadata_label": "完全なアイテム情報",
|
| 20 |
+
"source_audio": "ソースオーディオ",
|
| 21 |
+
"target_audio": "ターゲットオーディオ",
|
| 22 |
+
"reference_audio": "リファレンスオーディオ",
|
| 23 |
+
"get_item_btn": "🔍 アイテムを取得",
|
| 24 |
+
"use_src_checkbox": "データセットのソースオーディオを使用",
|
| 25 |
+
"use_src_info": "データセットのソースオーディオを使用する場合はチェック",
|
| 26 |
+
"data_status_label": "📊 データステータス",
|
| 27 |
+
"data_status_default": "❌ データセットがインポートされていません",
|
| 28 |
+
"autofill_btn": "📋 生成フォームを自動入力"
|
| 29 |
+
},
|
| 30 |
+
"service": {
|
| 31 |
+
"title": "🔧 サービス設定",
|
| 32 |
+
"checkpoint_label": "チェックポイントファイル",
|
| 33 |
+
"checkpoint_info": "訓練済みモデルのチェックポイントファイルを選択(フルパスまたはファイル名)",
|
| 34 |
+
"refresh_btn": "🔄 更新",
|
| 35 |
+
"model_path_label": "メインモデルパス",
|
| 36 |
+
"model_path_info": "モデル設定ディレクトリを選択(チェックポイントから自動スキャン)",
|
| 37 |
+
"device_label": "デバイス",
|
| 38 |
+
"device_info": "処理デバイス(自動検出を推奨)",
|
| 39 |
+
"lm_model_path_label": "5Hz LM モデルパス",
|
| 40 |
+
"lm_model_path_info": "5Hz LMモデルチェックポイントを選択(チェックポイントから自動スキャン)",
|
| 41 |
+
"backend_label": "5Hz LM バックエンド",
|
| 42 |
+
"backend_info": "5Hz LMのバックエンドを選択: vllm(高速)またはpt(PyTorch、より互換性あり)",
|
| 43 |
+
"init_llm_label": "5Hz LM を初期化",
|
| 44 |
+
"init_llm_info": "サービス初期化中に5Hz LMを初期化する場合はチェック",
|
| 45 |
+
"flash_attention_label": "Flash Attention を使用",
|
| 46 |
+
"flash_attention_info_enabled": "推論を高速化するためにflash attentionを有効にする(flash_attnパッケージが必要)",
|
| 47 |
+
"flash_attention_info_disabled": "Flash attentionは利用できません(flash_attnパッケージがインストールされていません)",
|
| 48 |
+
"offload_cpu_label": "CPUにオフロード",
|
| 49 |
+
"offload_cpu_info": "使用していない時にモデルをCPUにオフロードしてGPUメモリを節約",
|
| 50 |
+
"offload_dit_cpu_label": "DiTをCPUにオフロード",
|
| 51 |
+
"offload_dit_cpu_info": "DiTをCPUにオフロード(CPUへのオフロードが必要)",
|
| 52 |
+
"init_btn": "サービスを初期化",
|
| 53 |
+
"status_label": "ステータス",
|
| 54 |
+
"language_label": "UI言語",
|
| 55 |
+
"language_info": "インターフェース言語を選択"
|
| 56 |
+
},
|
| 57 |
+
"generation": {
|
| 58 |
+
"required_inputs": "📝 必須入力",
|
| 59 |
+
"task_type_label": "タスクタイプ",
|
| 60 |
+
"task_type_info": "生成のタスクタイプを選択",
|
| 61 |
+
"instruction_label": "指示",
|
| 62 |
+
"instruction_info": "指示はタスクタイプに基づいて自動生成されます",
|
| 63 |
+
"load_btn": "読み込む",
|
| 64 |
+
"track_name_label": "トラック名",
|
| 65 |
+
"track_name_info": "lego/extractタスクのトラック名を選択",
|
| 66 |
+
"track_classes_label": "トラック名",
|
| 67 |
+
"track_classes_info": "completeタスクの複数のトラッククラスを選択",
|
| 68 |
+
"audio_uploads": "🎵 オーディオアップロード",
|
| 69 |
+
"reference_audio": "リファレンスオーディオ(オプション)",
|
| 70 |
+
"source_audio": "ソースオーディオ(オプション)",
|
| 71 |
+
"convert_codes_btn": "コードに変換",
|
| 72 |
+
"lm_codes_hints": "🎼 LM コードヒント",
|
| 73 |
+
"lm_codes_label": "LM コードヒント",
|
| 74 |
+
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
|
| 75 |
+
"lm_codes_info": "text2music生成用のLMコードヒントを貼り付け",
|
| 76 |
+
"lm_codes_sample": "LM コードヒント(サンプル {n})",
|
| 77 |
+
"lm_codes_sample_info": "サンプル{n}のコード",
|
| 78 |
+
"transcribe_btn": "転写",
|
| 79 |
+
"repainting_controls": "🎨 再描画コントロール(秒)",
|
| 80 |
+
"repainting_start": "再描画開始",
|
| 81 |
+
"repainting_end": "再描画終了",
|
| 82 |
+
"mode_label": "生成モード",
|
| 83 |
+
"mode_info": "シンプル:自然言語で音楽を説明��カスタム:キャプションと歌詞を完全にコントロール。",
|
| 84 |
+
"mode_simple": "シンプル",
|
| 85 |
+
"mode_custom": "カスタム",
|
| 86 |
+
"simple_query_label": "曲の説明",
|
| 87 |
+
"simple_query_placeholder": "作成したい音楽を説明してください。例:'静かな夜のための優しいベンガルのラブソング'。空欄の場合はランダムなサンプルが生成されます。",
|
| 88 |
+
"simple_query_info": "生成したい音楽の自然言語の説明を入力",
|
| 89 |
+
"simple_vocal_language_label": "ボーカル言語(オプション)",
|
| 90 |
+
"simple_vocal_language_info": "歌詞の希望言語を選択。任意の言語の場合は'unknown'を使用。",
|
| 91 |
+
"create_sample_btn": "サンプル作成",
|
| 92 |
+
"caption_title": "📝 音楽キャプション",
|
| 93 |
+
"caption_label": "音楽キャプション(オプション)",
|
| 94 |
+
"caption_placeholder": "柔らかいボーカルを伴う穏やかなアコースティックギターのメロディー...",
|
| 95 |
+
"caption_info": "スタイル、ジャンル、楽器、ムードを説明",
|
| 96 |
+
"lyrics_title": "📝 歌詞",
|
| 97 |
+
"lyrics_label": "歌詞(オプション)",
|
| 98 |
+
"lyrics_placeholder": "[バース1]\\n星空の下で\\nとても生きていると感じる...",
|
| 99 |
+
"lyrics_info": "構造を持つ曲の歌詞",
|
| 100 |
+
"instrumental_label": "インストゥルメンタル",
|
| 101 |
+
"format_btn": "フォーマット",
|
| 102 |
+
"optional_params": "⚙️ オプションパラメータ",
|
| 103 |
+
"vocal_language_label": "ボーカル言語(オプション)",
|
| 104 |
+
"vocal_language_info": "インストには`unknown`を使用",
|
| 105 |
+
"bpm_label": "BPM(オプション)",
|
| 106 |
+
"bpm_info": "空白の場合はN/A",
|
| 107 |
+
"keyscale_label": "キースケール(オプション)",
|
| 108 |
+
"keyscale_placeholder": "空白の場合はN/A",
|
| 109 |
+
"keyscale_info": "A-G, #/♭, メジャー/マイナー",
|
| 110 |
+
"timesig_label": "拍子記号(オプション)",
|
| 111 |
+
"timesig_info": "2/4, 3/4, 4/4...",
|
| 112 |
+
"duration_label": "オーディオ長(秒)",
|
| 113 |
+
"duration_info": "ランダムの場合は-1を使用",
|
| 114 |
+
"batch_size_label": "バッチサイズ",
|
| 115 |
+
"batch_size_info": "生成するオーディオの数(最大8)",
|
| 116 |
+
"advanced_settings": "🔧 詳細設定",
|
| 117 |
+
"inference_steps_label": "DiT 推論ステップ",
|
| 118 |
+
"inference_steps_info": "Turbo: 最大8、Base: 最大200",
|
| 119 |
+
"guidance_scale_label": "DiT ガイダンススケール(baseモデルのみサポート)",
|
| 120 |
+
"guidance_scale_info": "値が高いほどテキストに忠実に従う",
|
| 121 |
+
"seed_label": "シード",
|
| 122 |
+
"seed_info": "バッチにはカンマ区切りの値を使用",
|
| 123 |
+
"random_seed_label": "ランダムシード",
|
| 124 |
+
"random_seed_info": "有効にすると自動的にシードを生成",
|
| 125 |
+
"audio_format_label": "オーディオフォーマット",
|
| 126 |
+
"audio_format_info": "保存ファイルのオーディオフォーマット",
|
| 127 |
+
"use_adg_label": "ADG を使用",
|
| 128 |
+
"use_adg_info": "角度ドメインガイダンスを有効化",
|
| 129 |
+
"shift_label": "シフト",
|
| 130 |
+
"shift_info": "baseモデル用タイムステップシフト係数 (範囲 1.0~5.0、デフォルト 3.0)。turboモデルには無効。",
|
| 131 |
+
"infer_method_label": "推論方法",
|
| 132 |
+
"infer_method_info": "拡散推論方法。ODE (オイラー) は高速、SDE (確率的) は異なる結果を生成する可能性があります。",
|
| 133 |
+
"custom_timesteps_label": "カスタムタイムステップ",
|
| 134 |
+
"custom_timesteps_info": "オプション:1.0から0.0へのカンマ区切り値(例:'0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')。推論ステップとシフトを上書きします。",
|
| 135 |
+
"cfg_interval_start": "CFG 間隔開始",
|
| 136 |
+
"cfg_interval_end": "CFG 間隔終了",
|
| 137 |
+
"lm_params_title": "🤖 LM 生成パラメータ",
|
| 138 |
+
"lm_temperature_label": "LM 温度",
|
| 139 |
+
"lm_temperature_info": "5Hz LM温度(高いほどランダム)",
|
| 140 |
+
"lm_cfg_scale_label": "LM CFG スケール",
|
| 141 |
+
"lm_cfg_scale_info": "5Hz LM CFG (1.0 = CFGなし)",
|
| 142 |
+
"lm_top_k_label": "LM Top-K",
|
| 143 |
+
"lm_top_k_info": "Top-K (0 = 無効)",
|
| 144 |
+
"lm_top_p_label": "LM Top-P",
|
| 145 |
+
"lm_top_p_info": "Top-P (1.0 = 無効)",
|
| 146 |
+
"lm_negative_prompt_label": "LM ネガティブプロンプト",
|
| 147 |
+
"lm_negative_prompt_placeholder": "CFGのネガティブプロンプトを入力(デフォルト: NO USER INPUT)",
|
| 148 |
+
"lm_negative_prompt_info": "ネガティブプロンプト(LM CFGスケール > 1.0の場合に使用)",
|
| 149 |
+
"cot_metas_label": "CoT メタデータ",
|
| 150 |
+
"cot_metas_info": "LMを使用してCoTメタデータを生成(チェックを外すとLM CoT生成をスキップ)",
|
| 151 |
+
"cot_language_label": "CoT 言語",
|
| 152 |
+
"cot_language_info": "CoTで言語を生成(思考の連鎖)",
|
| 153 |
+
"constrained_debug_label": "制約付きデコーディングデバッグ",
|
| 154 |
+
"constrained_debug_info": "制約付きデコーディングのデバッグログを有効化(チェックすると詳細ログを表示)",
|
| 155 |
+
"auto_score_label": "自動スコアリング",
|
| 156 |
+
"auto_score_info": "生成���れたすべてのオーディオの品質スコアを自動計算",
|
| 157 |
+
"auto_lrc_label": "自動 LRC",
|
| 158 |
+
"auto_lrc_info": "生成されたすべてのオーディオのLRC歌詞タイムスタンプを自動生成",
|
| 159 |
+
"lm_batch_chunk_label": "LM バッチチャンクサイズ",
|
| 160 |
+
"lm_batch_chunk_info": "LMバッチチャンクあたりの最大アイテム数(デフォルト: 8、GPUメモリによる制限)",
|
| 161 |
+
"codes_strength_label": "LM コード強度",
|
| 162 |
+
"codes_strength_info": "LM生成コードを使用するデノイジングステップ数を制御",
|
| 163 |
+
"cover_strength_label": "オーディオカバー強度",
|
| 164 |
+
"cover_strength_info": "カバーモードを使用するデノイジングステップ数を制御",
|
| 165 |
+
"score_sensitivity_label": "品質スコア感度",
|
| 166 |
+
"score_sensitivity_info": "低い = より敏感(デフォルト: 1.0)。PMIが[0,1]にマッピングする方法を調整",
|
| 167 |
+
"think_label": "思考",
|
| 168 |
+
"parallel_thinking_label": "並列思考",
|
| 169 |
+
"generate_btn": "🎵 音楽を生成",
|
| 170 |
+
"autogen_label": "自動生成",
|
| 171 |
+
"caption_rewrite_label": "キャプション書き換え"
|
| 172 |
+
},
|
| 173 |
+
"results": {
|
| 174 |
+
"title": "🎵 結果",
|
| 175 |
+
"generated_music": "🎵 生成された音楽(サンプル {n})",
|
| 176 |
+
"send_to_src_btn": "🔗 ソースオーディオに送信",
|
| 177 |
+
"send_to_cover_btn": "🔗 Send To Cover",
|
| 178 |
+
"send_to_repaint_btn": "🔗 Send To Repaint",
|
| 179 |
+
"save_btn": "💾 保存",
|
| 180 |
+
"score_btn": "📊 スコア",
|
| 181 |
+
"lrc_btn": "🎵 LRC",
|
| 182 |
+
"quality_score_label": "品質スコア(サンプル {n})",
|
| 183 |
+
"quality_score_placeholder": "'スコア'をクリックしてパープレキシティベースの品質スコアを計算",
|
| 184 |
+
"codes_label": "LM コード(サンプル {n})",
|
| 185 |
+
"lrc_label": "歌詞タイムスタンプ(サンプル {n})",
|
| 186 |
+
"lrc_placeholder": "'LRC'をクリックしてタイムスタンプを生成",
|
| 187 |
+
"details_accordion": "📊 スコア & LRC & LM コード",
|
| 188 |
+
"generation_status": "生成ステータス",
|
| 189 |
+
"current_batch": "現在のバッチ",
|
| 190 |
+
"batch_indicator": "バッチ {current} / {total}",
|
| 191 |
+
"next_batch_status": "次のバッチステータス",
|
| 192 |
+
"prev_btn": "◀ 前へ",
|
| 193 |
+
"next_btn": "次へ ▶",
|
| 194 |
+
"restore_params_btn": "↙️ これらの設定をUIに適用(バッチパラメータを復元)",
|
| 195 |
+
"batch_results_title": "📁 バッチ結果と生成詳細",
|
| 196 |
+
"all_files_label": "📁 すべての生成ファイル(ダウンロード)",
|
| 197 |
+
"generation_details": "生成詳細"
|
| 198 |
+
},
|
| 199 |
+
"messages": {
|
| 200 |
+
"no_audio_to_save": "❌ 保存するオーディオがありません",
|
| 201 |
+
"save_success": "✅ オーディオとメタデータを {filename} に保存しました",
|
| 202 |
+
"save_failed": "❌ 保存に失敗しました: {error}",
|
| 203 |
+
"no_file_selected": "⚠️ ファイルが選択されていません",
|
| 204 |
+
"params_loaded": "✅ {filename} からパラメータを読み込みました",
|
| 205 |
+
"invalid_json": "❌ 無効なJSONファイル: {error}",
|
| 206 |
+
"load_error": "❌ ファイルの読み込みエラー: {error}",
|
| 207 |
+
"example_loaded": "📁 {filename} からサンプルを読み込みました",
|
| 208 |
+
"example_failed": "JSONファイル {filename} の解析に失敗しました: {error}",
|
| 209 |
+
"example_error": "サンプル読み込みエラー: {error}",
|
| 210 |
+
"lm_generated": "🤖 LMを使用してサンプルを生成しました",
|
| 211 |
+
"lm_fallback": "LMを使用したサンプル生成に失敗、サンプルディレクトリにフォールバック",
|
| 212 |
+
"lm_not_initialized": "❌ 5Hz LMが初期化されていません。最初に初期化してください。",
|
| 213 |
+
"autogen_enabled": "🔄 自動生成が有効 - このあと次のバッチを生成します",
|
| 214 |
+
"batch_ready": "✅ バッチ {n} の準備完了!'次へ'をクリックして表示。",
|
| 215 |
+
"batch_generating": "🔄 バッチ {n} のバックグラウンド生成を開始...",
|
| 216 |
+
"batch_failed": "❌ バックグラウンド生成に失敗しました: {error}",
|
| 217 |
+
"viewing_batch": "✅ バッチ {n} を表示中",
|
| 218 |
+
"at_first_batch": "すでに最初のバッチです",
|
| 219 |
+
"at_last_batch": "次のバッチはありません",
|
| 220 |
+
"batch_not_found": "キューにバッチ {n} が見つかりません",
|
| 221 |
+
"no_batch_data": "復元するバッチデータがありません。",
|
| 222 |
+
"params_restored": "✅ バッチ {n} からUIパラメータを復元しました",
|
| 223 |
+
"scoring_failed": "❌ エラー: バッチデータが見つかりません",
|
| 224 |
+
"no_codes": "❌ 利用可能なオーディオコードがありません。最初に音楽を生成してください。",
|
| 225 |
+
"score_failed": "❌ スコアリングに失敗しました: {error}",
|
| 226 |
+
"score_error": "❌ スコア計算エラー: {error}",
|
| 227 |
+
"lrc_no_batch_data": "❌ バッチデータが見つかりません。最初に音楽を生成してください。",
|
| 228 |
+
"lrc_no_extra_outputs": "❌ 追加出力が見つかりません。条件テンソルが利用できません。",
|
| 229 |
+
"lrc_missing_tensors": "❌ LRC生成に必要なテンソルがありません。",
|
| 230 |
+
"lrc_sample_not_exist": "❌ 現在のバッチにサンプルが存在しません。",
|
| 231 |
+
"lrc_empty_result": "⚠️ LRC生成の結果が空です。",
|
| 232 |
+
"empty_query": "⚠️ 音楽の説明を入力してください。",
|
| 233 |
+
"sample_creation_failed": "❌ サンプルの作成に失敗しました。もう一度お試しください。",
|
| 234 |
+
"sample_created": "✅ サンプルが作成されました!キャプションと歌詞を確認して、音楽を生成をクリックしてください。",
|
| 235 |
+
"simple_examples_not_found": "⚠️ シンプルモードサンプルディレクトリが見つかりません。",
|
| 236 |
+
"simple_examples_empty": "⚠️ シンプルモードサンプルにファイルがありません。",
|
| 237 |
+
"simple_example_loaded": "🎲 {filename} からランダムサンプルを読み込みました",
|
| 238 |
+
"format_success": "✅ キャプションと歌詞のフォーマットに成功しました",
|
| 239 |
+
"format_failed": "❌ フォーマットに失敗しました: {error}",
|
| 240 |
+
"skipping_metas_cot": "⚡ Phase 1 メタデータ COT をスキップ(サンプルは既にフォーマット済み)",
|
| 241 |
+
"invalid_timesteps_format": "⚠️ タイムステップ形式が無効です。デフォルトスケジュールを使用します。",
|
| 242 |
+
"timesteps_out_of_range": "⚠️ タイムステップは [0, 1] の範囲内である必要があります。デフォルトスケジュールを使用します。",
|
| 243 |
+
"timesteps_count_mismatch": "⚠️ タイムステップ数 ({actual}) が推論ステップ数 ({expected}) と異なります。タイムステップ数を使用します。"
|
| 244 |
+
}
|
| 245 |
+
}
|
spaces/Ace-Step-v1.5/acestep/gradio_ui/i18n/zh.json
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"app": {
|
| 3 |
+
"title": "🎛️ ACE-Step V1.5 演练场💡",
|
| 4 |
+
"subtitle": "推动开源音乐生成的边界"
|
| 5 |
+
},
|
| 6 |
+
"dataset": {
|
| 7 |
+
"title": "📊 数据集浏览器",
|
| 8 |
+
"dataset_label": "数据集",
|
| 9 |
+
"dataset_info": "选择要浏览的数据集",
|
| 10 |
+
"import_btn": "📥 导入数据集",
|
| 11 |
+
"search_type_label": "搜索类型",
|
| 12 |
+
"search_type_info": "如何查找项目",
|
| 13 |
+
"search_value_label": "搜索值",
|
| 14 |
+
"search_value_placeholder": "输入键或索引(留空表示随机)",
|
| 15 |
+
"search_value_info": "键: 精确匹配, 索引: 0到数据集大小-1",
|
| 16 |
+
"instruction_label": "📝 指令",
|
| 17 |
+
"instruction_placeholder": "无可用指令",
|
| 18 |
+
"metadata_title": "📋 项目元数据 (JSON)",
|
| 19 |
+
"metadata_label": "完整项目信息",
|
| 20 |
+
"source_audio": "源音频",
|
| 21 |
+
"target_audio": "目标音频",
|
| 22 |
+
"reference_audio": "参考音频",
|
| 23 |
+
"get_item_btn": "🔍 获取项目",
|
| 24 |
+
"use_src_checkbox": "使用数据集中的源音频",
|
| 25 |
+
"use_src_info": "勾选以使用数据集中的源音频",
|
| 26 |
+
"data_status_label": "📊 数据状态",
|
| 27 |
+
"data_status_default": "❌ 未导入数据集",
|
| 28 |
+
"autofill_btn": "📋 自动填充生成表单"
|
| 29 |
+
},
|
| 30 |
+
"service": {
|
| 31 |
+
"title": "🔧 服务配置",
|
| 32 |
+
"checkpoint_label": "检查点文件",
|
| 33 |
+
"checkpoint_info": "选择训练好的模型检查点文件(完整路径或文件名)",
|
| 34 |
+
"refresh_btn": "🔄 刷新",
|
| 35 |
+
"model_path_label": "主模型路径",
|
| 36 |
+
"model_path_info": "选择模型配置目录(从检查点自动扫描)",
|
| 37 |
+
"device_label": "设备",
|
| 38 |
+
"device_info": "处理设备(建议自动检测)",
|
| 39 |
+
"lm_model_path_label": "5Hz LM 模型路径",
|
| 40 |
+
"lm_model_path_info": "选择5Hz LM模型检查点(从检查点自动扫描)",
|
| 41 |
+
"backend_label": "5Hz LM 后端",
|
| 42 |
+
"backend_info": "选择5Hz LM的后端: vllm(更快)或pt(PyTorch, 更兼容)",
|
| 43 |
+
"init_llm_label": "初始化 5Hz LM",
|
| 44 |
+
"init_llm_info": "勾选以在服务初始化期间初始化5Hz LM",
|
| 45 |
+
"flash_attention_label": "使用Flash Attention",
|
| 46 |
+
"flash_attention_info_enabled": "启用flash attention以加快推理速度(需要flash_attn包)",
|
| 47 |
+
"flash_attention_info_disabled": "Flash attention不可用(未安装flash_attn包)",
|
| 48 |
+
"offload_cpu_label": "卸载到CPU",
|
| 49 |
+
"offload_cpu_info": "不使用时将模型卸载到CPU以节省GPU内存",
|
| 50 |
+
"offload_dit_cpu_label": "将DiT卸载到CPU",
|
| 51 |
+
"offload_dit_cpu_info": "将DiT卸载到CPU(需要启用卸载到CPU)",
|
| 52 |
+
"init_btn": "初始化服务",
|
| 53 |
+
"status_label": "状态",
|
| 54 |
+
"language_label": "界面语言",
|
| 55 |
+
"language_info": "选择界面语言"
|
| 56 |
+
},
|
| 57 |
+
"generation": {
|
| 58 |
+
"required_inputs": "📝 必需输入",
|
| 59 |
+
"task_type_label": "任务类型",
|
| 60 |
+
"task_type_info": "选择生成的任务类型",
|
| 61 |
+
"instruction_label": "指令",
|
| 62 |
+
"instruction_info": "指令根据任务类型自动生成",
|
| 63 |
+
"load_btn": "加载",
|
| 64 |
+
"track_name_label": "音轨名称",
|
| 65 |
+
"track_name_info": "为lego/extract任务选择音轨名称",
|
| 66 |
+
"track_classes_label": "音轨名称",
|
| 67 |
+
"track_classes_info": "为complete任务选择多个音轨类别",
|
| 68 |
+
"audio_uploads": "🎵 音频上传",
|
| 69 |
+
"reference_audio": "参考音频(可选)",
|
| 70 |
+
"source_audio": "源音频(可选)",
|
| 71 |
+
"convert_codes_btn": "转换为代码",
|
| 72 |
+
"lm_codes_hints": "🎼 LM 代码提示",
|
| 73 |
+
"lm_codes_label": "LM 代码提示",
|
| 74 |
+
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
|
| 75 |
+
"lm_codes_info": "粘贴用于text2music生成的LM代码提示",
|
| 76 |
+
"lm_codes_sample": "LM 代码提示(样本 {n})",
|
| 77 |
+
"lm_codes_sample_info": "样本{n}的代码",
|
| 78 |
+
"transcribe_btn": "转录",
|
| 79 |
+
"repainting_controls": "🎨 重绘控制(秒)",
|
| 80 |
+
"repainting_start": "重绘开始",
|
| 81 |
+
"repainting_end": "重绘结束",
|
| 82 |
+
"mode_label": "生成模式",
|
| 83 |
+
"mode_info": "简单模式:用自然语言描述音乐。自定义模式:完全控制描述和歌词。",
|
| 84 |
+
"mode_simple": "简单",
|
| 85 |
+
"mode_custom": "自定义",
|
| 86 |
+
"simple_query_label": "歌曲描述",
|
| 87 |
+
"simple_query_placeholder": "描述你想创作的音乐,例如:'给我生成一首暗黑的戏剧古风,歌词要华丽'。留空则随机生成样本。",
|
| 88 |
+
"simple_query_info": "输入你想生成的音乐的自然语言描述",
|
| 89 |
+
"simple_vocal_language_label": "人声语言(可选)",
|
| 90 |
+
"simple_vocal_language_info": "选择歌词的首选语言。使用 'unknown' 表示任意语言。",
|
| 91 |
+
"create_sample_btn": "创建样本",
|
| 92 |
+
"caption_title": "📝 音乐描述",
|
| 93 |
+
"caption_label": "音乐描述(可选)",
|
| 94 |
+
"caption_placeholder": "一段平和的原声吉他旋律,配有柔和的人声...",
|
| 95 |
+
"caption_info": "描述风格、流派、乐器和情绪",
|
| 96 |
+
"lyrics_title": "📝 歌词",
|
| 97 |
+
"lyrics_label": "歌词(可选)",
|
| 98 |
+
"lyrics_placeholder": "[第一段]\\n在星空下\\n我感到如此活跃...",
|
| 99 |
+
"lyrics_info": "带有结构的歌曲歌词",
|
| 100 |
+
"instrumental_label": "纯音乐",
|
| 101 |
+
"format_btn": "格式化",
|
| 102 |
+
"optional_params": "⚙️ 可选参数",
|
| 103 |
+
"vocal_language_label": "人声语言(可选)",
|
| 104 |
+
"vocal_language_info": "纯音乐使用 `unknown`",
|
| 105 |
+
"bpm_label": "BPM(可选)",
|
| 106 |
+
"bpm_info": "留空表示N/A",
|
| 107 |
+
"keyscale_label": "调性(可选)",
|
| 108 |
+
"keyscale_placeholder": "留空表示N/A",
|
| 109 |
+
"keyscale_info": "A-G, #/♭, 大调/小调",
|
| 110 |
+
"timesig_label": "拍号(可选)",
|
| 111 |
+
"timesig_info": "2/4, 3/4, 4/4...",
|
| 112 |
+
"duration_label": "音频时长(秒)",
|
| 113 |
+
"duration_info": "使用-1表示随机",
|
| 114 |
+
"batch_size_label": "批量大小",
|
| 115 |
+
"batch_size_info": "要生成的音频数量(最多8个)",
|
| 116 |
+
"advanced_settings": "🔧 高级设置",
|
| 117 |
+
"inference_steps_label": "DiT 推理步数",
|
| 118 |
+
"inference_steps_info": "Turbo: 最多8, Base: 最多200",
|
| 119 |
+
"guidance_scale_label": "DiT 引导比例(仅支持base模型)",
|
| 120 |
+
"guidance_scale_info": "更高的值更紧密地遵循文本",
|
| 121 |
+
"seed_label": "种子",
|
| 122 |
+
"seed_info": "批量使用逗号分隔的值",
|
| 123 |
+
"random_seed_label": "随机种子",
|
| 124 |
+
"random_seed_info": "启用以自动生成种子",
|
| 125 |
+
"audio_format_label": "音频格式",
|
| 126 |
+
"audio_format_info": "保存文件的音频格式",
|
| 127 |
+
"use_adg_label": "使用 ADG",
|
| 128 |
+
"use_adg_info": "启用角域引导",
|
| 129 |
+
"shift_label": "Shift",
|
| 130 |
+
"shift_info": "时间步偏移因子,仅对 base 模型生效 (范围 1.0~5.0,默认 3.0)。对 turbo 模型无效。",
|
| 131 |
+
"infer_method_label": "推理方法",
|
| 132 |
+
"infer_method_info": "扩散推理方法。ODE (欧拉) 更快,SDE (随机) 可能产生不同结果。",
|
| 133 |
+
"custom_timesteps_label": "自定义时间步",
|
| 134 |
+
"custom_timesteps_info": "可选:从 1.0 到 0.0 的逗号分隔值(例如 '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')。会覆盖推理步数和 shift 设置。",
|
| 135 |
+
"cfg_interval_start": "CFG 间隔开始",
|
| 136 |
+
"cfg_interval_end": "CFG 间隔结束",
|
| 137 |
+
"lm_params_title": "🤖 LM 生成参数",
|
| 138 |
+
"lm_temperature_label": "LM 温度",
|
| 139 |
+
"lm_temperature_info": "5Hz LM温度(越高越随机)",
|
| 140 |
+
"lm_cfg_scale_label": "LM CFG 比例",
|
| 141 |
+
"lm_cfg_scale_info": "5Hz LM CFG (1.0 = 无CFG)",
|
| 142 |
+
"lm_top_k_label": "LM Top-K",
|
| 143 |
+
"lm_top_k_info": "Top-K (0 = 禁用)",
|
| 144 |
+
"lm_top_p_label": "LM Top-P",
|
| 145 |
+
"lm_top_p_info": "Top-P (1.0 = 禁用)",
|
| 146 |
+
"lm_negative_prompt_label": "LM 负面提示",
|
| 147 |
+
"lm_negative_prompt_placeholder": "输入CFG的负面提示(默认: NO USER INPUT)",
|
| 148 |
+
"lm_negative_prompt_info": "负面提示(当LM CFG比例 > 1.0时使用)",
|
| 149 |
+
"cot_metas_label": "CoT 元数据",
|
| 150 |
+
"cot_metas_info": "使用LM生成CoT元数据(取消勾选以跳过LM CoT生成)",
|
| 151 |
+
"cot_language_label": "CoT 语言",
|
| 152 |
+
"cot_language_info": "在CoT中生成语言(思维链)",
|
| 153 |
+
"constrained_debug_label": "约束解码调试",
|
| 154 |
+
"constrained_debug_info": "启用约束解码的调试日志(勾选以查看详细日志)",
|
| 155 |
+
"auto_score_label": "自动评分",
|
| 156 |
+
"auto_score_info": "自动计算所有生成音频的质量分数",
|
| 157 |
+
"auto_lrc_label": "自动 LRC",
|
| 158 |
+
"auto_lrc_info": "自动为所有生成的音频生成LRC歌词时间戳",
|
| 159 |
+
"lm_batch_chunk_label": "LM 批量块大小",
|
| 160 |
+
"lm_batch_chunk_info": "每个LM批量块的最大项目数(默认: 8, 受GPU内存限制)",
|
| 161 |
+
"codes_strength_label": "LM 代码强度",
|
| 162 |
+
"codes_strength_info": "控制使用LM生成代码的去噪步骤数量",
|
| 163 |
+
"cover_strength_label": "音频覆盖强度",
|
| 164 |
+
"cover_strength_info": "控制使用覆盖模式的去噪步骤数量",
|
| 165 |
+
"score_sensitivity_label": "质量评分敏感度",
|
| 166 |
+
"score_sensitivity_info": "更低 = 更敏感(默认: 1.0). 调整PMI如何映射到[0,1]",
|
| 167 |
+
"think_label": "思考",
|
| 168 |
+
"parallel_thinking_label": "并行思考",
|
| 169 |
+
"generate_btn": "🎵 生成音乐",
|
| 170 |
+
"autogen_label": "自动生成",
|
| 171 |
+
"caption_rewrite_label": "描述重写"
|
| 172 |
+
},
|
| 173 |
+
"results": {
|
| 174 |
+
"title": "🎵 结果",
|
| 175 |
+
"generated_music": "🎵 生成的音乐(样本 {n})",
|
| 176 |
+
"send_to_src_btn": "🔗 发送到源音频",
|
| 177 |
+
"send_to_cover_btn": "🔗 Send To Cover",
|
| 178 |
+
"send_to_repaint_btn": "🔗 Send To Repaint",
|
| 179 |
+
"save_btn": "💾 保存",
|
| 180 |
+
"score_btn": "📊 评分",
|
| 181 |
+
"lrc_btn": "🎵 LRC",
|
| 182 |
+
"quality_score_label": "质量分数(样本 {n})",
|
| 183 |
+
"quality_score_placeholder": "点击'评分'以计算基于困惑度的质量分数",
|
| 184 |
+
"codes_label": "LM 代码(样本 {n})",
|
| 185 |
+
"lrc_label": "歌词时间戳(样本 {n})",
|
| 186 |
+
"lrc_placeholder": "点击'LRC'生成时间戳",
|
| 187 |
+
"details_accordion": "📊 评分与LRC与LM代码",
|
| 188 |
+
"generation_status": "生成状态",
|
| 189 |
+
"current_batch": "当前批次",
|
| 190 |
+
"batch_indicator": "批次 {current} / {total}",
|
| 191 |
+
"next_batch_status": "下一批次状态",
|
| 192 |
+
"prev_btn": "◀ 上一个",
|
| 193 |
+
"next_btn": "下一个 ▶",
|
| 194 |
+
"restore_params_btn": "↙️ 将这些设置应用到UI(恢复批次参数)",
|
| 195 |
+
"batch_results_title": "📁 批量结果和生成详情",
|
| 196 |
+
"all_files_label": "📁 所有生成的文件(��载)",
|
| 197 |
+
"generation_details": "生成详情"
|
| 198 |
+
},
|
| 199 |
+
"messages": {
|
| 200 |
+
"no_audio_to_save": "❌ 没有要保存的音频",
|
| 201 |
+
"save_success": "✅ 已将音频和元数据保存到 {filename}",
|
| 202 |
+
"save_failed": "❌ 保存失败: {error}",
|
| 203 |
+
"no_file_selected": "⚠️ 未选择文件",
|
| 204 |
+
"params_loaded": "✅ 已从 {filename} 加载参数",
|
| 205 |
+
"invalid_json": "❌ 无效的JSON文件: {error}",
|
| 206 |
+
"load_error": "❌ 加载文件时出错: {error}",
|
| 207 |
+
"example_loaded": "📁 已从 {filename} 加载示例",
|
| 208 |
+
"example_failed": "解析JSON文件 {filename} 失败: {error}",
|
| 209 |
+
"example_error": "加载示例时出错: {error}",
|
| 210 |
+
"lm_generated": "🤖 使用LM生成的示例",
|
| 211 |
+
"lm_fallback": "使用LM生成示例失败,回退到示例目录",
|
| 212 |
+
"lm_not_initialized": "❌ 5Hz LM未初始化。请先初始化它。",
|
| 213 |
+
"autogen_enabled": "🔄 已启用自动生成 - 下一批次将在此之后生成",
|
| 214 |
+
"batch_ready": "✅ 批次 {n} 就绪!点击'下一个'查看。",
|
| 215 |
+
"batch_generating": "🔄 开始为批次 {n} 进行后台生成...",
|
| 216 |
+
"batch_failed": "❌ 后台生成失败: {error}",
|
| 217 |
+
"viewing_batch": "✅ 查看批次 {n}",
|
| 218 |
+
"at_first_batch": "已在第一批次",
|
| 219 |
+
"at_last_batch": "没有下一批次可用",
|
| 220 |
+
"batch_not_found": "在队列中未找到批次 {n}",
|
| 221 |
+
"no_batch_data": "没有要恢复的批次数据。",
|
| 222 |
+
"params_restored": "✅ 已从批次 {n} 恢复UI参数",
|
| 223 |
+
"scoring_failed": "❌ 错误: 未找到批次数据",
|
| 224 |
+
"no_codes": "❌ 没有可用的音频代码。请先生成音乐。",
|
| 225 |
+
"score_failed": "❌ 评分失败: {error}",
|
| 226 |
+
"score_error": "❌ 计算分数时出错: {error}",
|
| 227 |
+
"lrc_no_batch_data": "❌ 未找到批次数据。请先生成音乐。",
|
| 228 |
+
"lrc_no_extra_outputs": "❌ 未找到额外输出。条件张量不可用。",
|
| 229 |
+
"lrc_missing_tensors": "❌ 缺少LRC生成所需的张量。",
|
| 230 |
+
"lrc_sample_not_exist": "❌ 当前批次中不存在该样本。",
|
| 231 |
+
"lrc_empty_result": "⚠️ LRC生成结果为空。",
|
| 232 |
+
"empty_query": "⚠️ 请输入音乐描述。",
|
| 233 |
+
"sample_creation_failed": "❌ 创建样本失败。请重试。",
|
| 234 |
+
"sample_created": "✅ 样本已创建!检查描述和歌词,然后点击生成音乐。",
|
| 235 |
+
"simple_examples_not_found": "⚠️ 未找到简单模式示例目录。",
|
| 236 |
+
"simple_examples_empty": "⚠️ 简单模式示例中没有示例文件。",
|
| 237 |
+
"simple_example_loaded": "🎲 已从 {filename} 加载随机示例",
|
| 238 |
+
"format_success": "✅ 描述和歌词格式化成功",
|
| 239 |
+
"format_failed": "❌ 格式化失败: {error}",
|
| 240 |
+
"skipping_metas_cot": "⚡ 跳过 Phase 1 元数据 COT(样本已格式化)",
|
| 241 |
+
"invalid_timesteps_format": "⚠️ 时间步格式无效,使用默认调度。",
|
| 242 |
+
"timesteps_out_of_range": "⚠️ 时间步必须在 [0, 1] 范围内,使用默认调度。",
|
| 243 |
+
"timesteps_count_mismatch": "⚠️ 时间步数量 ({actual}) 与推理步数 ({expected}) 不匹配,将使用时间步数量。"
|
| 244 |
+
}
|
| 245 |
+
}
|
spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/__init__.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI Components Module
|
| 3 |
+
Contains all Gradio interface component definitions and layouts
|
| 4 |
+
"""
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from acestep.gradio_ui.i18n import get_i18n, t
|
| 7 |
+
from acestep.gradio_ui.interfaces.dataset import create_dataset_section
|
| 8 |
+
from acestep.gradio_ui.interfaces.generation import create_generation_section
|
| 9 |
+
from acestep.gradio_ui.interfaces.result import create_results_section
|
| 10 |
+
from acestep.gradio_ui.interfaces.training import create_training_section
|
| 11 |
+
from acestep.gradio_ui.events import setup_event_handlers, setup_training_event_handlers
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=None, language='en') -> gr.Blocks:
|
| 15 |
+
"""
|
| 16 |
+
Create Gradio interface
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
dit_handler: DiT handler instance
|
| 20 |
+
llm_handler: LM handler instance
|
| 21 |
+
dataset_handler: Dataset handler instance
|
| 22 |
+
init_params: Dictionary containing initialization parameters and state.
|
| 23 |
+
If None, service will not be pre-initialized.
|
| 24 |
+
language: UI language code ('en', 'zh', 'ja', default: 'en')
|
| 25 |
+
|
| 26 |
+
Returns:
|
| 27 |
+
Gradio Blocks instance
|
| 28 |
+
"""
|
| 29 |
+
# Initialize i18n with selected language
|
| 30 |
+
i18n = get_i18n(language)
|
| 31 |
+
|
| 32 |
+
with gr.Blocks(
|
| 33 |
+
title=t("app.title"),
|
| 34 |
+
theme=gr.themes.Soft(),
|
| 35 |
+
css="""
|
| 36 |
+
.main-header {
|
| 37 |
+
text-align: center;
|
| 38 |
+
margin-bottom: 2rem;
|
| 39 |
+
}
|
| 40 |
+
.section-header {
|
| 41 |
+
background: linear-gradient(90deg, #4CAF50, #45a049);
|
| 42 |
+
color: white;
|
| 43 |
+
padding: 10px;
|
| 44 |
+
border-radius: 5px;
|
| 45 |
+
margin: 10px 0;
|
| 46 |
+
}
|
| 47 |
+
.lm-hints-row {
|
| 48 |
+
align-items: stretch;
|
| 49 |
+
}
|
| 50 |
+
.lm-hints-col {
|
| 51 |
+
display: flex;
|
| 52 |
+
}
|
| 53 |
+
.lm-hints-col > div {
|
| 54 |
+
flex: 1;
|
| 55 |
+
display: flex;
|
| 56 |
+
}
|
| 57 |
+
.lm-hints-btn button {
|
| 58 |
+
height: 100%;
|
| 59 |
+
width: 100%;
|
| 60 |
+
}
|
| 61 |
+
"""
|
| 62 |
+
) as demo:
|
| 63 |
+
|
| 64 |
+
gr.HTML(f"""
|
| 65 |
+
<div class="main-header">
|
| 66 |
+
<h1>{t("app.title")}</h1>
|
| 67 |
+
<p>{t("app.subtitle")}</p>
|
| 68 |
+
<p style="margin-top: 0.5rem;">
|
| 69 |
+
<a href="https://ace-step.github.io/ace-step-v1.5.github.io/" target="_blank">Project</a> |
|
| 70 |
+
<a href="https://huggingface.co/collections/ACE-Step/ace-step-15" target="_blank">Hugging Face</a> |
|
| 71 |
+
<a href="https://modelscope.cn/models/ACE-Step/ACE-Step-v1-5" target="_blank">ModelScope</a> |
|
| 72 |
+
<a href="https://huggingface.co/spaces/ACE-Step/Ace-Step-v1.5" target="_blank">Space Demo</a> |
|
| 73 |
+
<a href="https://discord.gg/PeWDxrkdj7" target="_blank">Discord</a> |
|
| 74 |
+
<a href="https://arxiv.org/abs/2506.00045" target="_blank">Technical Report</a>
|
| 75 |
+
</p>
|
| 76 |
+
</div>
|
| 77 |
+
""")
|
| 78 |
+
|
| 79 |
+
# Dataset Explorer Section
|
| 80 |
+
dataset_section = create_dataset_section(dataset_handler)
|
| 81 |
+
|
| 82 |
+
# Generation Section (pass init_params and language to support pre-initialization)
|
| 83 |
+
generation_section = create_generation_section(dit_handler, llm_handler, init_params=init_params, language=language)
|
| 84 |
+
|
| 85 |
+
# Results Section
|
| 86 |
+
results_section = create_results_section(dit_handler)
|
| 87 |
+
|
| 88 |
+
# Training Section (LoRA training and dataset builder)
|
| 89 |
+
# Pass init_params to support hiding in service mode
|
| 90 |
+
training_section = create_training_section(dit_handler, llm_handler, init_params=init_params)
|
| 91 |
+
|
| 92 |
+
# Connect event handlers (pass init_params for multi-model support)
|
| 93 |
+
setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section, init_params=init_params)
|
| 94 |
+
|
| 95 |
+
# Connect training event handlers
|
| 96 |
+
setup_training_event_handlers(demo, dit_handler, llm_handler, training_section)
|
| 97 |
+
|
| 98 |
+
return demo
|
spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/dataset.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI Dataset Section Module
|
| 3 |
+
Contains dataset explorer section component definitions
|
| 4 |
+
"""
|
| 5 |
+
import gradio as gr
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def create_dataset_section(dataset_handler) -> dict:
|
| 9 |
+
"""Create dataset explorer section"""
|
| 10 |
+
with gr.Accordion("📊 Dataset Explorer", open=False, visible=False):
|
| 11 |
+
with gr.Row(equal_height=True):
|
| 12 |
+
dataset_type = gr.Dropdown(
|
| 13 |
+
choices=["train", "test"],
|
| 14 |
+
value="train",
|
| 15 |
+
label="Dataset",
|
| 16 |
+
info="Choose dataset to explore",
|
| 17 |
+
scale=2
|
| 18 |
+
)
|
| 19 |
+
import_dataset_btn = gr.Button("📥 Import Dataset", variant="primary", scale=1)
|
| 20 |
+
|
| 21 |
+
search_type = gr.Dropdown(
|
| 22 |
+
choices=["keys", "idx", "random"],
|
| 23 |
+
value="random",
|
| 24 |
+
label="Search Type",
|
| 25 |
+
info="How to find items",
|
| 26 |
+
scale=1
|
| 27 |
+
)
|
| 28 |
+
search_value = gr.Textbox(
|
| 29 |
+
label="Search Value",
|
| 30 |
+
placeholder="Enter keys or index (leave empty for random)",
|
| 31 |
+
info="Keys: exact match, Index: 0 to dataset size-1",
|
| 32 |
+
scale=2
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
instruction_display = gr.Textbox(
|
| 36 |
+
label="📝 Instruction",
|
| 37 |
+
interactive=False,
|
| 38 |
+
placeholder="No instruction available",
|
| 39 |
+
lines=1
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
repaint_viz_plot = gr.Plot()
|
| 43 |
+
|
| 44 |
+
with gr.Accordion("📋 Item Metadata (JSON)", open=False):
|
| 45 |
+
item_info_json = gr.Code(
|
| 46 |
+
label="Complete Item Information",
|
| 47 |
+
language="json",
|
| 48 |
+
interactive=False,
|
| 49 |
+
lines=15
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
with gr.Row(equal_height=True):
|
| 53 |
+
item_src_audio = gr.Audio(
|
| 54 |
+
label="Source Audio",
|
| 55 |
+
type="filepath",
|
| 56 |
+
interactive=False,
|
| 57 |
+
scale=8
|
| 58 |
+
)
|
| 59 |
+
get_item_btn = gr.Button("🔍 Get Item", variant="secondary", interactive=False, scale=2)
|
| 60 |
+
|
| 61 |
+
with gr.Row(equal_height=True):
|
| 62 |
+
item_target_audio = gr.Audio(
|
| 63 |
+
label="Target Audio",
|
| 64 |
+
type="filepath",
|
| 65 |
+
interactive=False,
|
| 66 |
+
scale=8
|
| 67 |
+
)
|
| 68 |
+
item_refer_audio = gr.Audio(
|
| 69 |
+
label="Reference Audio",
|
| 70 |
+
type="filepath",
|
| 71 |
+
interactive=False,
|
| 72 |
+
scale=2
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
with gr.Row():
|
| 76 |
+
use_src_checkbox = gr.Checkbox(
|
| 77 |
+
label="Use Source Audio from Dataset",
|
| 78 |
+
value=True,
|
| 79 |
+
info="Check to use the source audio from dataset"
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
data_status = gr.Textbox(label="📊 Data Status", interactive=False, value="❌ No dataset imported")
|
| 83 |
+
auto_fill_btn = gr.Button("📋 Auto-fill Generation Form", variant="primary")
|
| 84 |
+
|
| 85 |
+
return {
|
| 86 |
+
"dataset_type": dataset_type,
|
| 87 |
+
"import_dataset_btn": import_dataset_btn,
|
| 88 |
+
"search_type": search_type,
|
| 89 |
+
"search_value": search_value,
|
| 90 |
+
"instruction_display": instruction_display,
|
| 91 |
+
"repaint_viz_plot": repaint_viz_plot,
|
| 92 |
+
"item_info_json": item_info_json,
|
| 93 |
+
"item_src_audio": item_src_audio,
|
| 94 |
+
"get_item_btn": get_item_btn,
|
| 95 |
+
"item_target_audio": item_target_audio,
|
| 96 |
+
"item_refer_audio": item_refer_audio,
|
| 97 |
+
"use_src_checkbox": use_src_checkbox,
|
| 98 |
+
"data_status": data_status,
|
| 99 |
+
"auto_fill_btn": auto_fill_btn,
|
| 100 |
+
}
|
| 101 |
+
|
spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/generation.py
ADDED
|
@@ -0,0 +1,693 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI Generation Section Module
|
| 3 |
+
Contains generation section component definitions - Simplified UI
|
| 4 |
+
"""
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from acestep.constants import (
|
| 7 |
+
VALID_LANGUAGES,
|
| 8 |
+
TRACK_NAMES,
|
| 9 |
+
TASK_TYPES_TURBO,
|
| 10 |
+
TASK_TYPES_BASE,
|
| 11 |
+
DEFAULT_DIT_INSTRUCTION,
|
| 12 |
+
)
|
| 13 |
+
from acestep.gradio_ui.i18n import t
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def create_generation_section(dit_handler, llm_handler, init_params=None, language='en') -> dict:
|
| 17 |
+
"""Create generation section with simplified UI
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
dit_handler: DiT handler instance
|
| 21 |
+
llm_handler: LM handler instance
|
| 22 |
+
init_params: Dictionary containing initialization parameters and state.
|
| 23 |
+
If None, service will not be pre-initialized.
|
| 24 |
+
language: UI language code ('en', 'zh', 'ja')
|
| 25 |
+
"""
|
| 26 |
+
# Check if service is pre-initialized
|
| 27 |
+
service_pre_initialized = init_params is not None and init_params.get('pre_initialized', False)
|
| 28 |
+
|
| 29 |
+
# Check if running in service mode (restricted UI)
|
| 30 |
+
service_mode = init_params is not None and init_params.get('service_mode', False)
|
| 31 |
+
|
| 32 |
+
# Get current language from init_params if available
|
| 33 |
+
current_language = init_params.get('language', language) if init_params else language
|
| 34 |
+
|
| 35 |
+
# Get available models
|
| 36 |
+
available_dit_models = init_params.get('available_dit_models', []) if init_params else []
|
| 37 |
+
current_model_value = init_params.get('config_path', '') if init_params else ''
|
| 38 |
+
show_model_selector = len(available_dit_models) > 1
|
| 39 |
+
|
| 40 |
+
with gr.Group():
|
| 41 |
+
# ==================== Service Configuration (Hidden in service mode) ====================
|
| 42 |
+
accordion_open = not service_pre_initialized
|
| 43 |
+
accordion_visible = not service_pre_initialized
|
| 44 |
+
with gr.Accordion(t("service.title"), open=accordion_open, visible=accordion_visible) as service_config_accordion:
|
| 45 |
+
# Language selector at the top
|
| 46 |
+
with gr.Row():
|
| 47 |
+
language_dropdown = gr.Dropdown(
|
| 48 |
+
choices=[
|
| 49 |
+
("English", "en"),
|
| 50 |
+
("中文", "zh"),
|
| 51 |
+
("日本語", "ja"),
|
| 52 |
+
],
|
| 53 |
+
value=current_language,
|
| 54 |
+
label=t("service.language_label"),
|
| 55 |
+
info=t("service.language_info"),
|
| 56 |
+
scale=1,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
with gr.Row(equal_height=True):
|
| 60 |
+
with gr.Column(scale=4):
|
| 61 |
+
checkpoint_value = init_params.get('checkpoint') if service_pre_initialized else None
|
| 62 |
+
checkpoint_dropdown = gr.Dropdown(
|
| 63 |
+
label=t("service.checkpoint_label"),
|
| 64 |
+
choices=dit_handler.get_available_checkpoints(),
|
| 65 |
+
value=checkpoint_value,
|
| 66 |
+
info=t("service.checkpoint_info")
|
| 67 |
+
)
|
| 68 |
+
with gr.Column(scale=1, min_width=90):
|
| 69 |
+
refresh_btn = gr.Button(t("service.refresh_btn"), size="sm")
|
| 70 |
+
|
| 71 |
+
with gr.Row():
|
| 72 |
+
available_models = dit_handler.get_available_acestep_v15_models()
|
| 73 |
+
default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)
|
| 74 |
+
config_path_value = init_params.get('config_path', default_model) if service_pre_initialized else default_model
|
| 75 |
+
config_path = gr.Dropdown(
|
| 76 |
+
label=t("service.model_path_label"),
|
| 77 |
+
choices=available_models,
|
| 78 |
+
value=config_path_value,
|
| 79 |
+
info=t("service.model_path_info")
|
| 80 |
+
)
|
| 81 |
+
device_value = init_params.get('device', 'auto') if service_pre_initialized else 'auto'
|
| 82 |
+
device = gr.Dropdown(
|
| 83 |
+
choices=["auto", "cuda", "cpu"],
|
| 84 |
+
value=device_value,
|
| 85 |
+
label=t("service.device_label"),
|
| 86 |
+
info=t("service.device_info")
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
with gr.Row():
|
| 90 |
+
available_lm_models = llm_handler.get_available_5hz_lm_models()
|
| 91 |
+
default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)
|
| 92 |
+
lm_model_path_value = init_params.get('lm_model_path', default_lm_model) if service_pre_initialized else default_lm_model
|
| 93 |
+
lm_model_path = gr.Dropdown(
|
| 94 |
+
label=t("service.lm_model_path_label"),
|
| 95 |
+
choices=available_lm_models,
|
| 96 |
+
value=lm_model_path_value,
|
| 97 |
+
info=t("service.lm_model_path_info")
|
| 98 |
+
)
|
| 99 |
+
backend_value = init_params.get('backend', 'vllm') if service_pre_initialized else 'vllm'
|
| 100 |
+
backend_dropdown = gr.Dropdown(
|
| 101 |
+
choices=["vllm", "pt"],
|
| 102 |
+
value=backend_value,
|
| 103 |
+
label=t("service.backend_label"),
|
| 104 |
+
info=t("service.backend_info")
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
with gr.Row():
|
| 108 |
+
init_llm_value = init_params.get('init_llm', True) if service_pre_initialized else True
|
| 109 |
+
init_llm_checkbox = gr.Checkbox(
|
| 110 |
+
label=t("service.init_llm_label"),
|
| 111 |
+
value=init_llm_value,
|
| 112 |
+
info=t("service.init_llm_info"),
|
| 113 |
+
)
|
| 114 |
+
flash_attn_available = dit_handler.is_flash_attention_available()
|
| 115 |
+
use_flash_attention_value = init_params.get('use_flash_attention', flash_attn_available) if service_pre_initialized else flash_attn_available
|
| 116 |
+
use_flash_attention_checkbox = gr.Checkbox(
|
| 117 |
+
label=t("service.flash_attention_label"),
|
| 118 |
+
value=use_flash_attention_value,
|
| 119 |
+
interactive=flash_attn_available,
|
| 120 |
+
info=t("service.flash_attention_info_enabled") if flash_attn_available else t("service.flash_attention_info_disabled")
|
| 121 |
+
)
|
| 122 |
+
offload_to_cpu_value = init_params.get('offload_to_cpu', False) if service_pre_initialized else False
|
| 123 |
+
offload_to_cpu_checkbox = gr.Checkbox(
|
| 124 |
+
label=t("service.offload_cpu_label"),
|
| 125 |
+
value=offload_to_cpu_value,
|
| 126 |
+
info=t("service.offload_cpu_info")
|
| 127 |
+
)
|
| 128 |
+
offload_dit_to_cpu_value = init_params.get('offload_dit_to_cpu', False) if service_pre_initialized else False
|
| 129 |
+
offload_dit_to_cpu_checkbox = gr.Checkbox(
|
| 130 |
+
label=t("service.offload_dit_cpu_label"),
|
| 131 |
+
value=offload_dit_to_cpu_value,
|
| 132 |
+
info=t("service.offload_dit_cpu_info")
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
init_btn = gr.Button(t("service.init_btn"), variant="primary", size="lg")
|
| 136 |
+
init_status_value = init_params.get('init_status', '') if service_pre_initialized else ''
|
| 137 |
+
init_status = gr.Textbox(label=t("service.status_label"), interactive=False, lines=3, value=init_status_value)
|
| 138 |
+
|
| 139 |
+
# LoRA Configuration Section
|
| 140 |
+
gr.HTML("<hr><h4>🔧 LoRA Adapter</h4>")
|
| 141 |
+
with gr.Row():
|
| 142 |
+
lora_path = gr.Textbox(
|
| 143 |
+
label="LoRA Path",
|
| 144 |
+
placeholder="./lora_output/final/adapter",
|
| 145 |
+
info="Path to trained LoRA adapter directory",
|
| 146 |
+
scale=3,
|
| 147 |
+
)
|
| 148 |
+
load_lora_btn = gr.Button("📥 Load LoRA", variant="secondary", scale=1)
|
| 149 |
+
unload_lora_btn = gr.Button("🗑️ Unload", variant="secondary", scale=1)
|
| 150 |
+
with gr.Row():
|
| 151 |
+
use_lora_checkbox = gr.Checkbox(
|
| 152 |
+
label="Use LoRA",
|
| 153 |
+
value=False,
|
| 154 |
+
info="Enable LoRA adapter for inference",
|
| 155 |
+
scale=1,
|
| 156 |
+
)
|
| 157 |
+
lora_status = gr.Textbox(
|
| 158 |
+
label="LoRA Status",
|
| 159 |
+
value="No LoRA loaded",
|
| 160 |
+
interactive=False,
|
| 161 |
+
scale=2,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
# ==================== Model Selector (Top, only when multiple models) ====================
|
| 165 |
+
with gr.Row(visible=show_model_selector):
|
| 166 |
+
dit_model_selector = gr.Dropdown(
|
| 167 |
+
choices=available_dit_models,
|
| 168 |
+
value=current_model_value,
|
| 169 |
+
label="models",
|
| 170 |
+
scale=1,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# Hidden dropdown when only one model (for event handler compatibility)
|
| 174 |
+
if not show_model_selector:
|
| 175 |
+
dit_model_selector = gr.Dropdown(
|
| 176 |
+
choices=available_dit_models if available_dit_models else [current_model_value],
|
| 177 |
+
value=current_model_value,
|
| 178 |
+
visible=False,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# ==================== Generation Mode (4 modes) ====================
|
| 182 |
+
gr.HTML("<div style='background: #4a5568; color: white; padding: 8px 16px; border-radius: 4px; font-weight: bold;'>Generation Mode</div>")
|
| 183 |
+
with gr.Row():
|
| 184 |
+
generation_mode = gr.Radio(
|
| 185 |
+
choices=[
|
| 186 |
+
("Simple", "simple"),
|
| 187 |
+
("Custom", "custom"),
|
| 188 |
+
("Cover", "cover"),
|
| 189 |
+
("Repaint", "repaint"),
|
| 190 |
+
],
|
| 191 |
+
value="custom",
|
| 192 |
+
label="",
|
| 193 |
+
show_label=False,
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# ==================== Simple Mode Group ====================
|
| 197 |
+
with gr.Column(visible=False) as simple_mode_group:
|
| 198 |
+
# Row: Song Description + Vocal Language + Random button
|
| 199 |
+
with gr.Row(equal_height=True):
|
| 200 |
+
simple_query_input = gr.Textbox(
|
| 201 |
+
label=t("generation.simple_query_label"),
|
| 202 |
+
placeholder=t("generation.simple_query_placeholder"),
|
| 203 |
+
lines=2,
|
| 204 |
+
info=t("generation.simple_query_info"),
|
| 205 |
+
scale=10,
|
| 206 |
+
)
|
| 207 |
+
simple_vocal_language = gr.Dropdown(
|
| 208 |
+
choices=VALID_LANGUAGES,
|
| 209 |
+
value="unknown",
|
| 210 |
+
allow_custom_value=True,
|
| 211 |
+
label=t("generation.simple_vocal_language_label"),
|
| 212 |
+
interactive=True,
|
| 213 |
+
info="use unknown for instrumental",
|
| 214 |
+
scale=2,
|
| 215 |
+
)
|
| 216 |
+
with gr.Column(scale=1, min_width=60):
|
| 217 |
+
random_desc_btn = gr.Button(
|
| 218 |
+
"🎲",
|
| 219 |
+
variant="secondary",
|
| 220 |
+
size="lg",
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# Hidden components (kept for compatibility but not shown)
|
| 224 |
+
simple_instrumental_checkbox = gr.Checkbox(
|
| 225 |
+
label=t("generation.instrumental_label"),
|
| 226 |
+
value=False,
|
| 227 |
+
visible=False,
|
| 228 |
+
)
|
| 229 |
+
create_sample_btn = gr.Button(
|
| 230 |
+
t("generation.create_sample_btn"),
|
| 231 |
+
variant="primary",
|
| 232 |
+
size="lg",
|
| 233 |
+
visible=False,
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
# State to track if sample has been created in Simple mode
|
| 237 |
+
simple_sample_created = gr.State(value=False)
|
| 238 |
+
|
| 239 |
+
# ==================== Source Audio (for Cover/Repaint) ====================
|
| 240 |
+
# This is shown above the main content for Cover and Repaint modes
|
| 241 |
+
with gr.Column(visible=False) as src_audio_group:
|
| 242 |
+
with gr.Row(equal_height=True):
|
| 243 |
+
# Source Audio - scale=10 to match (refer_audio=2 + prompt/lyrics=8)
|
| 244 |
+
src_audio = gr.Audio(
|
| 245 |
+
label="Source Audio",
|
| 246 |
+
type="filepath",
|
| 247 |
+
scale=10,
|
| 248 |
+
)
|
| 249 |
+
# Process button - scale=1 to align with random button
|
| 250 |
+
with gr.Column(scale=1, min_width=80):
|
| 251 |
+
process_src_btn = gr.Button(
|
| 252 |
+
"Analyze",
|
| 253 |
+
variant="secondary",
|
| 254 |
+
size="lg",
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# Hidden Audio Codes storage (needed internally but not displayed)
|
| 258 |
+
text2music_audio_code_string = gr.Textbox(
|
| 259 |
+
label="Audio Codes",
|
| 260 |
+
visible=False,
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
# ==================== Custom/Cover/Repaint Mode Content ====================
|
| 264 |
+
with gr.Column() as custom_mode_content:
|
| 265 |
+
with gr.Row(equal_height=True):
|
| 266 |
+
# Left: Reference Audio
|
| 267 |
+
with gr.Column(scale=2, min_width=200):
|
| 268 |
+
reference_audio = gr.Audio(
|
| 269 |
+
label="Reference Audio (optional)",
|
| 270 |
+
type="filepath",
|
| 271 |
+
show_label=True,
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# Middle: Prompt + Lyrics + Format button
|
| 275 |
+
with gr.Column(scale=8):
|
| 276 |
+
# Row 1: Prompt and Lyrics
|
| 277 |
+
with gr.Row(equal_height=True):
|
| 278 |
+
captions = gr.Textbox(
|
| 279 |
+
label="Prompt",
|
| 280 |
+
placeholder="Describe the music style, mood, instruments...",
|
| 281 |
+
lines=12,
|
| 282 |
+
max_lines=12,
|
| 283 |
+
scale=1,
|
| 284 |
+
)
|
| 285 |
+
lyrics = gr.Textbox(
|
| 286 |
+
label="Lyrics",
|
| 287 |
+
placeholder="Enter lyrics here... Use [Verse], [Chorus] etc. for structure",
|
| 288 |
+
lines=12,
|
| 289 |
+
max_lines=12,
|
| 290 |
+
scale=1,
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
# Row 2: Format button (only below Prompt and Lyrics)
|
| 294 |
+
format_btn = gr.Button(
|
| 295 |
+
"Format",
|
| 296 |
+
variant="secondary",
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
# Right: Random button
|
| 300 |
+
with gr.Column(scale=1, min_width=60):
|
| 301 |
+
sample_btn = gr.Button(
|
| 302 |
+
"🎲",
|
| 303 |
+
variant="secondary",
|
| 304 |
+
size="lg",
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
# Placeholder for removed audio_uploads_accordion (for compatibility)
|
| 308 |
+
audio_uploads_accordion = gr.Column(visible=False)
|
| 309 |
+
|
| 310 |
+
# Legacy cover_mode_group (hidden, for backward compatibility)
|
| 311 |
+
cover_mode_group = gr.Column(visible=False)
|
| 312 |
+
# Legacy convert button (hidden, for backward compatibility)
|
| 313 |
+
convert_src_to_codes_btn = gr.Button("Convert to Codes", visible=False)
|
| 314 |
+
|
| 315 |
+
# ==================== Repaint Mode: Source + Time Range ====================
|
| 316 |
+
with gr.Column(visible=False) as repainting_group:
|
| 317 |
+
with gr.Row():
|
| 318 |
+
repainting_start = gr.Number(
|
| 319 |
+
label="Start (seconds)",
|
| 320 |
+
value=0.0,
|
| 321 |
+
step=0.1,
|
| 322 |
+
scale=1,
|
| 323 |
+
)
|
| 324 |
+
repainting_end = gr.Number(
|
| 325 |
+
label="End (seconds, -1 for end)",
|
| 326 |
+
value=-1,
|
| 327 |
+
minimum=-1,
|
| 328 |
+
step=0.1,
|
| 329 |
+
scale=1,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
# ==================== Optional Parameters ====================
|
| 333 |
+
with gr.Accordion("⚙️ Optional Parameters", open=False, visible=False) as optional_params_accordion:
|
| 334 |
+
pass
|
| 335 |
+
|
| 336 |
+
# ==================== Advanced Settings ====================
|
| 337 |
+
with gr.Accordion("🔧 Advanced Settings", open=False) as advanced_options_accordion:
|
| 338 |
+
with gr.Row():
|
| 339 |
+
bpm = gr.Number(
|
| 340 |
+
label="BPM (optional)",
|
| 341 |
+
value=0,
|
| 342 |
+
step=1,
|
| 343 |
+
info="leave empty for N/A",
|
| 344 |
+
scale=1,
|
| 345 |
+
)
|
| 346 |
+
key_scale = gr.Textbox(
|
| 347 |
+
label="Key Signature (optional)",
|
| 348 |
+
placeholder="Leave empty for N/A",
|
| 349 |
+
value="",
|
| 350 |
+
info="A-G, #/♭, major/minor",
|
| 351 |
+
scale=1,
|
| 352 |
+
)
|
| 353 |
+
time_signature = gr.Dropdown(
|
| 354 |
+
choices=["", "2", "3", "4"],
|
| 355 |
+
value="",
|
| 356 |
+
label="Time Signature (optional)",
|
| 357 |
+
allow_custom_value=True,
|
| 358 |
+
info="2/4, 3/4, 4/4...",
|
| 359 |
+
scale=1,
|
| 360 |
+
)
|
| 361 |
+
audio_duration = gr.Number(
|
| 362 |
+
label="Audio Duration (seconds)",
|
| 363 |
+
value=-1,
|
| 364 |
+
minimum=-1,
|
| 365 |
+
maximum=600.0,
|
| 366 |
+
step=1,
|
| 367 |
+
info="Use -1 for random",
|
| 368 |
+
scale=1,
|
| 369 |
+
)
|
| 370 |
+
vocal_language = gr.Dropdown(
|
| 371 |
+
choices=VALID_LANGUAGES,
|
| 372 |
+
value="unknown",
|
| 373 |
+
label="Vocal Language",
|
| 374 |
+
allow_custom_value=True,
|
| 375 |
+
info="use `unknown` for instrumental",
|
| 376 |
+
scale=1,
|
| 377 |
+
)
|
| 378 |
+
batch_size_input = gr.Number(
|
| 379 |
+
label="batch size",
|
| 380 |
+
info="max 8",
|
| 381 |
+
value=2,
|
| 382 |
+
minimum=1,
|
| 383 |
+
maximum=8,
|
| 384 |
+
step=1,
|
| 385 |
+
scale=1,
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
# Row 1: DiT Inference Steps, Seed, Audio Format
|
| 389 |
+
with gr.Row():
|
| 390 |
+
inference_steps = gr.Slider(
|
| 391 |
+
minimum=1,
|
| 392 |
+
maximum=20,
|
| 393 |
+
value=8,
|
| 394 |
+
step=1,
|
| 395 |
+
label="DiT Inference Steps",
|
| 396 |
+
info="Turbo: max 8, Base: max 200",
|
| 397 |
+
)
|
| 398 |
+
seed = gr.Textbox(
|
| 399 |
+
label="Seed",
|
| 400 |
+
value="-1",
|
| 401 |
+
info="Use comma-separated values for batches",
|
| 402 |
+
)
|
| 403 |
+
audio_format = gr.Dropdown(
|
| 404 |
+
choices=["mp3", "flac"],
|
| 405 |
+
value="mp3",
|
| 406 |
+
label="Audio Format",
|
| 407 |
+
info="Audio format for saved files",
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
# Row 2: Shift, Random Seed, Inference Method
|
| 411 |
+
with gr.Row():
|
| 412 |
+
shift = gr.Slider(
|
| 413 |
+
minimum=1.0,
|
| 414 |
+
maximum=5.0,
|
| 415 |
+
value=3.0,
|
| 416 |
+
step=0.1,
|
| 417 |
+
label="Shift",
|
| 418 |
+
info="Timestep shift factor for base models (range 1.0-5.0, default 3.0). Not effective for turbo models.",
|
| 419 |
+
)
|
| 420 |
+
random_seed_checkbox = gr.Checkbox(
|
| 421 |
+
label="Random Seed",
|
| 422 |
+
value=True,
|
| 423 |
+
info="Enable to auto-generate seeds",
|
| 424 |
+
)
|
| 425 |
+
infer_method = gr.Dropdown(
|
| 426 |
+
choices=["ode", "sde"],
|
| 427 |
+
value="ode",
|
| 428 |
+
label="Inference Method",
|
| 429 |
+
info="Diffusion inference method. ODE (Euler) is faster, SDE (stochastic) may produce different results.",
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
# Row 3: Custom Timesteps (full width)
|
| 433 |
+
custom_timesteps = gr.Textbox(
|
| 434 |
+
label="Custom Timesteps",
|
| 435 |
+
placeholder="0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0",
|
| 436 |
+
value="",
|
| 437 |
+
info="Optional: comma-separated values from 1.0 to 0.0 (e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference steps and shift.",
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
# Section: LM Generation Parameters
|
| 441 |
+
gr.HTML("<h4>🎵 LM Generation Parameters</h4>")
|
| 442 |
+
|
| 443 |
+
# Row 4: LM Temperature, LM CFG Scale, LM Top-K, LM Top-P
|
| 444 |
+
with gr.Row():
|
| 445 |
+
lm_temperature = gr.Slider(
|
| 446 |
+
minimum=0.0,
|
| 447 |
+
maximum=2.0,
|
| 448 |
+
value=0.85,
|
| 449 |
+
step=0.05,
|
| 450 |
+
label="LM Temperature",
|
| 451 |
+
info="5Hz LM temperature (higher = more random)",
|
| 452 |
+
)
|
| 453 |
+
lm_cfg_scale = gr.Slider(
|
| 454 |
+
minimum=1.0,
|
| 455 |
+
maximum=3.0,
|
| 456 |
+
value=2.0,
|
| 457 |
+
step=0.1,
|
| 458 |
+
label="LM CFG Scale",
|
| 459 |
+
info="5Hz LM CFG (1.0 = no CFG)",
|
| 460 |
+
)
|
| 461 |
+
lm_top_k = gr.Slider(
|
| 462 |
+
minimum=0,
|
| 463 |
+
maximum=100,
|
| 464 |
+
value=0,
|
| 465 |
+
step=1,
|
| 466 |
+
label="LM Top-K",
|
| 467 |
+
info="Top-k (0 = disabled)",
|
| 468 |
+
)
|
| 469 |
+
lm_top_p = gr.Slider(
|
| 470 |
+
minimum=0.0,
|
| 471 |
+
maximum=1.0,
|
| 472 |
+
value=0.9,
|
| 473 |
+
step=0.01,
|
| 474 |
+
label="LM Top-P",
|
| 475 |
+
info="Top-p (1.0 = disabled)",
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
# Row 5: LM Negative Prompt (full width)
|
| 479 |
+
lm_negative_prompt = gr.Textbox(
|
| 480 |
+
label="LM Negative Prompt",
|
| 481 |
+
value="NO USER INPUT",
|
| 482 |
+
placeholder="Things to avoid in generation...",
|
| 483 |
+
lines=2,
|
| 484 |
+
info="Negative prompt (use when LM CFG Scale > 1.0)",
|
| 485 |
+
)
|
| 486 |
+
# audio_cover_strength remains hidden for now
|
| 487 |
+
audio_cover_strength = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, visible=False)
|
| 488 |
+
|
| 489 |
+
# Note: audio_duration, bpm, key_scale, time_signature are now visible in Optional Parameters
|
| 490 |
+
# ==================== Generate Button Row ====================
|
| 491 |
+
generate_btn_interactive = init_params.get('enable_generate', False) if service_pre_initialized else False
|
| 492 |
+
with gr.Row(equal_height=True):
|
| 493 |
+
# Left: Thinking and Instrumental checkboxes
|
| 494 |
+
with gr.Column(scale=1, min_width=120):
|
| 495 |
+
think_checkbox = gr.Checkbox(
|
| 496 |
+
label="Thinking",
|
| 497 |
+
value=True,
|
| 498 |
+
)
|
| 499 |
+
instrumental_checkbox = gr.Checkbox(
|
| 500 |
+
label="Instrumental",
|
| 501 |
+
value=False,
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
# Center: Generate button
|
| 505 |
+
with gr.Column(scale=4):
|
| 506 |
+
generate_btn = gr.Button(
|
| 507 |
+
"🎵 Generate Music",
|
| 508 |
+
variant="primary",
|
| 509 |
+
size="lg",
|
| 510 |
+
interactive=generate_btn_interactive,
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
# Right: auto_score, auto_lrc
|
| 514 |
+
with gr.Column(scale=1, min_width=120):
|
| 515 |
+
auto_score = gr.Checkbox(
|
| 516 |
+
label="Get Scores",
|
| 517 |
+
value=False,
|
| 518 |
+
)
|
| 519 |
+
auto_lrc = gr.Checkbox(
|
| 520 |
+
label="Get LRC",
|
| 521 |
+
value=False,
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
# ==================== Hidden Components (for internal use) ====================
|
| 525 |
+
# These are needed for event handlers but not shown in UI
|
| 526 |
+
|
| 527 |
+
# Task type (set automatically based on generation_mode)
|
| 528 |
+
actual_model = init_params.get('config_path', 'acestep-v15-turbo') if service_pre_initialized else 'acestep-v15-turbo'
|
| 529 |
+
actual_model_lower = (actual_model or "").lower()
|
| 530 |
+
if "turbo" in actual_model_lower:
|
| 531 |
+
initial_task_choices = TASK_TYPES_TURBO
|
| 532 |
+
else:
|
| 533 |
+
initial_task_choices = TASK_TYPES_BASE
|
| 534 |
+
|
| 535 |
+
task_type = gr.Dropdown(
|
| 536 |
+
choices=initial_task_choices,
|
| 537 |
+
value="text2music",
|
| 538 |
+
visible=False,
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
instruction_display_gen = gr.Textbox(
|
| 542 |
+
value=DEFAULT_DIT_INSTRUCTION,
|
| 543 |
+
visible=False,
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
+
track_name = gr.Dropdown(
|
| 547 |
+
choices=TRACK_NAMES,
|
| 548 |
+
value=None,
|
| 549 |
+
visible=False,
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
complete_track_classes = gr.CheckboxGroup(
|
| 553 |
+
choices=TRACK_NAMES,
|
| 554 |
+
visible=False,
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
# Note: lyrics, vocal_language, instrumental_checkbox, format_btn are now visible in custom_mode_content
|
| 558 |
+
|
| 559 |
+
# Hidden advanced settings (keep defaults)
|
| 560 |
+
# Note: Most parameters are now visible in Advanced Settings section above
|
| 561 |
+
guidance_scale = gr.Slider(value=7.0, visible=False)
|
| 562 |
+
use_adg = gr.Checkbox(value=False, visible=False)
|
| 563 |
+
cfg_interval_start = gr.Slider(value=0.0, visible=False)
|
| 564 |
+
cfg_interval_end = gr.Slider(value=1.0, visible=False)
|
| 565 |
+
|
| 566 |
+
# LM parameters (remaining hidden ones)
|
| 567 |
+
use_cot_metas = gr.Checkbox(value=True, visible=False)
|
| 568 |
+
use_cot_caption = gr.Checkbox(value=True, visible=False)
|
| 569 |
+
use_cot_language = gr.Checkbox(value=True, visible=False)
|
| 570 |
+
constrained_decoding_debug = gr.Checkbox(value=False, visible=False)
|
| 571 |
+
allow_lm_batch = gr.Checkbox(value=True, visible=False)
|
| 572 |
+
lm_batch_chunk_size = gr.Number(value=8, visible=False)
|
| 573 |
+
score_scale = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, visible=False)
|
| 574 |
+
autogen_checkbox = gr.Checkbox(value=False, visible=False)
|
| 575 |
+
|
| 576 |
+
# Transcribe button (hidden)
|
| 577 |
+
transcribe_btn = gr.Button(value="Transcribe", visible=False)
|
| 578 |
+
text2music_audio_codes_group = gr.Group(visible=False)
|
| 579 |
+
|
| 580 |
+
# Note: format_btn is now visible in custom_mode_content
|
| 581 |
+
|
| 582 |
+
# Load file button (hidden for now)
|
| 583 |
+
load_file = gr.UploadButton(
|
| 584 |
+
label="Load",
|
| 585 |
+
file_types=[".json"],
|
| 586 |
+
file_count="single",
|
| 587 |
+
visible=False,
|
| 588 |
+
)
|
| 589 |
+
|
| 590 |
+
# Caption/Lyrics accordions (not used in new UI but needed for compatibility)
|
| 591 |
+
caption_accordion = gr.Accordion("Caption", visible=False)
|
| 592 |
+
lyrics_accordion = gr.Accordion("Lyrics", visible=False)
|
| 593 |
+
# Note: optional_params_accordion is now visible above
|
| 594 |
+
|
| 595 |
+
return {
|
| 596 |
+
"service_config_accordion": service_config_accordion,
|
| 597 |
+
"language_dropdown": language_dropdown,
|
| 598 |
+
"checkpoint_dropdown": checkpoint_dropdown,
|
| 599 |
+
"refresh_btn": refresh_btn,
|
| 600 |
+
"config_path": config_path,
|
| 601 |
+
"device": device,
|
| 602 |
+
"init_btn": init_btn,
|
| 603 |
+
"init_status": init_status,
|
| 604 |
+
"lm_model_path": lm_model_path,
|
| 605 |
+
"init_llm_checkbox": init_llm_checkbox,
|
| 606 |
+
"backend_dropdown": backend_dropdown,
|
| 607 |
+
"use_flash_attention_checkbox": use_flash_attention_checkbox,
|
| 608 |
+
"offload_to_cpu_checkbox": offload_to_cpu_checkbox,
|
| 609 |
+
"offload_dit_to_cpu_checkbox": offload_dit_to_cpu_checkbox,
|
| 610 |
+
# LoRA components
|
| 611 |
+
"lora_path": lora_path,
|
| 612 |
+
"load_lora_btn": load_lora_btn,
|
| 613 |
+
"unload_lora_btn": unload_lora_btn,
|
| 614 |
+
"use_lora_checkbox": use_lora_checkbox,
|
| 615 |
+
"lora_status": lora_status,
|
| 616 |
+
# DiT model selector
|
| 617 |
+
"dit_model_selector": dit_model_selector,
|
| 618 |
+
"task_type": task_type,
|
| 619 |
+
"instruction_display_gen": instruction_display_gen,
|
| 620 |
+
"track_name": track_name,
|
| 621 |
+
"complete_track_classes": complete_track_classes,
|
| 622 |
+
"audio_uploads_accordion": audio_uploads_accordion,
|
| 623 |
+
"reference_audio": reference_audio,
|
| 624 |
+
"src_audio": src_audio,
|
| 625 |
+
"convert_src_to_codes_btn": convert_src_to_codes_btn,
|
| 626 |
+
"text2music_audio_code_string": text2music_audio_code_string,
|
| 627 |
+
"transcribe_btn": transcribe_btn,
|
| 628 |
+
"text2music_audio_codes_group": text2music_audio_codes_group,
|
| 629 |
+
"lm_temperature": lm_temperature,
|
| 630 |
+
"lm_cfg_scale": lm_cfg_scale,
|
| 631 |
+
"lm_top_k": lm_top_k,
|
| 632 |
+
"lm_top_p": lm_top_p,
|
| 633 |
+
"lm_negative_prompt": lm_negative_prompt,
|
| 634 |
+
"use_cot_metas": use_cot_metas,
|
| 635 |
+
"use_cot_caption": use_cot_caption,
|
| 636 |
+
"use_cot_language": use_cot_language,
|
| 637 |
+
"repainting_group": repainting_group,
|
| 638 |
+
"repainting_start": repainting_start,
|
| 639 |
+
"repainting_end": repainting_end,
|
| 640 |
+
"audio_cover_strength": audio_cover_strength,
|
| 641 |
+
# Generation mode components
|
| 642 |
+
"generation_mode": generation_mode,
|
| 643 |
+
"simple_mode_group": simple_mode_group,
|
| 644 |
+
"simple_query_input": simple_query_input,
|
| 645 |
+
"random_desc_btn": random_desc_btn,
|
| 646 |
+
"simple_instrumental_checkbox": simple_instrumental_checkbox,
|
| 647 |
+
"simple_vocal_language": simple_vocal_language,
|
| 648 |
+
"create_sample_btn": create_sample_btn,
|
| 649 |
+
"simple_sample_created": simple_sample_created,
|
| 650 |
+
"caption_accordion": caption_accordion,
|
| 651 |
+
"lyrics_accordion": lyrics_accordion,
|
| 652 |
+
"optional_params_accordion": optional_params_accordion,
|
| 653 |
+
# Custom mode components
|
| 654 |
+
"custom_mode_content": custom_mode_content,
|
| 655 |
+
"cover_mode_group": cover_mode_group,
|
| 656 |
+
# Source audio group for Cover/Repaint
|
| 657 |
+
"src_audio_group": src_audio_group,
|
| 658 |
+
"process_src_btn": process_src_btn,
|
| 659 |
+
"advanced_options_accordion": advanced_options_accordion,
|
| 660 |
+
# Existing components
|
| 661 |
+
"captions": captions,
|
| 662 |
+
"sample_btn": sample_btn,
|
| 663 |
+
"load_file": load_file,
|
| 664 |
+
"lyrics": lyrics,
|
| 665 |
+
"vocal_language": vocal_language,
|
| 666 |
+
"bpm": bpm,
|
| 667 |
+
"key_scale": key_scale,
|
| 668 |
+
"time_signature": time_signature,
|
| 669 |
+
"audio_duration": audio_duration,
|
| 670 |
+
"batch_size_input": batch_size_input,
|
| 671 |
+
"inference_steps": inference_steps,
|
| 672 |
+
"guidance_scale": guidance_scale,
|
| 673 |
+
"seed": seed,
|
| 674 |
+
"random_seed_checkbox": random_seed_checkbox,
|
| 675 |
+
"use_adg": use_adg,
|
| 676 |
+
"cfg_interval_start": cfg_interval_start,
|
| 677 |
+
"cfg_interval_end": cfg_interval_end,
|
| 678 |
+
"shift": shift,
|
| 679 |
+
"infer_method": infer_method,
|
| 680 |
+
"custom_timesteps": custom_timesteps,
|
| 681 |
+
"audio_format": audio_format,
|
| 682 |
+
"think_checkbox": think_checkbox,
|
| 683 |
+
"autogen_checkbox": autogen_checkbox,
|
| 684 |
+
"generate_btn": generate_btn,
|
| 685 |
+
"instrumental_checkbox": instrumental_checkbox,
|
| 686 |
+
"format_btn": format_btn,
|
| 687 |
+
"constrained_decoding_debug": constrained_decoding_debug,
|
| 688 |
+
"score_scale": score_scale,
|
| 689 |
+
"allow_lm_batch": allow_lm_batch,
|
| 690 |
+
"auto_score": auto_score,
|
| 691 |
+
"auto_lrc": auto_lrc,
|
| 692 |
+
"lm_batch_chunk_size": lm_batch_chunk_size,
|
| 693 |
+
}
|
spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/result.py
ADDED
|
@@ -0,0 +1,598 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI Results Section Module
|
| 3 |
+
Contains results display section component definitions
|
| 4 |
+
"""
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from acestep.gradio_ui.i18n import t
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def create_results_section(dit_handler) -> dict:
|
| 10 |
+
"""Create results display section"""
|
| 11 |
+
with gr.Accordion(t("results.title"), open=True):
|
| 12 |
+
# Hidden state to store LM-generated metadata
|
| 13 |
+
lm_metadata_state = gr.State(value=None)
|
| 14 |
+
|
| 15 |
+
# Hidden state to track if caption/metadata is from formatted source (LM/transcription)
|
| 16 |
+
is_format_caption_state = gr.State(value=False)
|
| 17 |
+
|
| 18 |
+
# Batch management states
|
| 19 |
+
current_batch_index = gr.State(value=0) # Currently displayed batch index
|
| 20 |
+
total_batches = gr.State(value=1) # Total number of batches generated
|
| 21 |
+
batch_queue = gr.State(value={}) # Dictionary storing all batch data
|
| 22 |
+
generation_params_state = gr.State(value={}) # Store generation parameters for next batches
|
| 23 |
+
is_generating_background = gr.State(value=False) # Background generation flag
|
| 24 |
+
|
| 25 |
+
# All audio components in one row with dynamic visibility
|
| 26 |
+
with gr.Row():
|
| 27 |
+
with gr.Column(visible=True) as audio_col_1:
|
| 28 |
+
generated_audio_1 = gr.Audio(
|
| 29 |
+
label=t("results.generated_music", n=1),
|
| 30 |
+
type="filepath",
|
| 31 |
+
interactive=False,
|
| 32 |
+
buttons=[]
|
| 33 |
+
)
|
| 34 |
+
with gr.Row(equal_height=True):
|
| 35 |
+
send_to_cover_btn_1 = gr.Button(
|
| 36 |
+
t("results.send_to_cover_btn"),
|
| 37 |
+
variant="secondary",
|
| 38 |
+
size="sm",
|
| 39 |
+
scale=1
|
| 40 |
+
)
|
| 41 |
+
send_to_repaint_btn_1 = gr.Button(
|
| 42 |
+
t("results.send_to_repaint_btn"),
|
| 43 |
+
variant="secondary",
|
| 44 |
+
size="sm",
|
| 45 |
+
scale=1
|
| 46 |
+
)
|
| 47 |
+
save_btn_1 = gr.Button(
|
| 48 |
+
t("results.save_btn"),
|
| 49 |
+
variant="primary",
|
| 50 |
+
size="sm",
|
| 51 |
+
scale=1
|
| 52 |
+
)
|
| 53 |
+
score_btn_1 = gr.Button(
|
| 54 |
+
t("results.score_btn"),
|
| 55 |
+
variant="secondary",
|
| 56 |
+
size="sm",
|
| 57 |
+
scale=1,
|
| 58 |
+
visible=False
|
| 59 |
+
)
|
| 60 |
+
lrc_btn_1 = gr.Button(
|
| 61 |
+
t("results.lrc_btn"),
|
| 62 |
+
variant="secondary",
|
| 63 |
+
size="sm",
|
| 64 |
+
scale=1,
|
| 65 |
+
visible=False
|
| 66 |
+
)
|
| 67 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_1:
|
| 68 |
+
score_display_1 = gr.Textbox(
|
| 69 |
+
label=t("results.quality_score_label", n=1),
|
| 70 |
+
interactive=False,
|
| 71 |
+
buttons=["copy"],
|
| 72 |
+
lines=6,
|
| 73 |
+
max_lines=6,
|
| 74 |
+
visible=True
|
| 75 |
+
)
|
| 76 |
+
lrc_display_1 = gr.Textbox(
|
| 77 |
+
label=t("results.lrc_label", n=1),
|
| 78 |
+
interactive=True,
|
| 79 |
+
buttons=["copy"],
|
| 80 |
+
lines=8,
|
| 81 |
+
max_lines=8,
|
| 82 |
+
visible=True
|
| 83 |
+
)
|
| 84 |
+
codes_display_1 = gr.Textbox(
|
| 85 |
+
label=t("results.codes_label", n=1),
|
| 86 |
+
interactive=False,
|
| 87 |
+
buttons=["copy"],
|
| 88 |
+
lines=4,
|
| 89 |
+
max_lines=4,
|
| 90 |
+
visible=True
|
| 91 |
+
)
|
| 92 |
+
with gr.Column(visible=True) as audio_col_2:
|
| 93 |
+
generated_audio_2 = gr.Audio(
|
| 94 |
+
label=t("results.generated_music", n=2),
|
| 95 |
+
type="filepath",
|
| 96 |
+
interactive=False,
|
| 97 |
+
buttons=[]
|
| 98 |
+
)
|
| 99 |
+
with gr.Row(equal_height=True):
|
| 100 |
+
send_to_cover_btn_2 = gr.Button(
|
| 101 |
+
t("results.send_to_cover_btn"),
|
| 102 |
+
variant="secondary",
|
| 103 |
+
size="sm",
|
| 104 |
+
scale=1
|
| 105 |
+
)
|
| 106 |
+
send_to_repaint_btn_2 = gr.Button(
|
| 107 |
+
t("results.send_to_repaint_btn"),
|
| 108 |
+
variant="secondary",
|
| 109 |
+
size="sm",
|
| 110 |
+
scale=1
|
| 111 |
+
)
|
| 112 |
+
save_btn_2 = gr.Button(
|
| 113 |
+
t("results.save_btn"),
|
| 114 |
+
variant="primary",
|
| 115 |
+
size="sm",
|
| 116 |
+
scale=1
|
| 117 |
+
)
|
| 118 |
+
score_btn_2 = gr.Button(
|
| 119 |
+
t("results.score_btn"),
|
| 120 |
+
variant="secondary",
|
| 121 |
+
size="sm",
|
| 122 |
+
scale=1,
|
| 123 |
+
visible=False
|
| 124 |
+
)
|
| 125 |
+
lrc_btn_2 = gr.Button(
|
| 126 |
+
t("results.lrc_btn"),
|
| 127 |
+
variant="secondary",
|
| 128 |
+
size="sm",
|
| 129 |
+
scale=1,
|
| 130 |
+
visible=False
|
| 131 |
+
)
|
| 132 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_2:
|
| 133 |
+
score_display_2 = gr.Textbox(
|
| 134 |
+
label=t("results.quality_score_label", n=2),
|
| 135 |
+
interactive=False,
|
| 136 |
+
buttons=["copy"],
|
| 137 |
+
lines=6,
|
| 138 |
+
max_lines=6,
|
| 139 |
+
visible=True
|
| 140 |
+
)
|
| 141 |
+
lrc_display_2 = gr.Textbox(
|
| 142 |
+
label=t("results.lrc_label", n=2),
|
| 143 |
+
interactive=True,
|
| 144 |
+
buttons=["copy"],
|
| 145 |
+
lines=8,
|
| 146 |
+
max_lines=8,
|
| 147 |
+
visible=True
|
| 148 |
+
)
|
| 149 |
+
codes_display_2 = gr.Textbox(
|
| 150 |
+
label=t("results.codes_label", n=2),
|
| 151 |
+
interactive=False,
|
| 152 |
+
buttons=["copy"],
|
| 153 |
+
lines=4,
|
| 154 |
+
max_lines=4,
|
| 155 |
+
visible=True
|
| 156 |
+
)
|
| 157 |
+
with gr.Column(visible=False) as audio_col_3:
|
| 158 |
+
generated_audio_3 = gr.Audio(
|
| 159 |
+
label=t("results.generated_music", n=3),
|
| 160 |
+
type="filepath",
|
| 161 |
+
interactive=False,
|
| 162 |
+
buttons=[]
|
| 163 |
+
)
|
| 164 |
+
with gr.Row(equal_height=True):
|
| 165 |
+
send_to_cover_btn_3 = gr.Button(
|
| 166 |
+
t("results.send_to_cover_btn"),
|
| 167 |
+
variant="secondary",
|
| 168 |
+
size="sm",
|
| 169 |
+
scale=1
|
| 170 |
+
)
|
| 171 |
+
send_to_repaint_btn_3 = gr.Button(
|
| 172 |
+
t("results.send_to_repaint_btn"),
|
| 173 |
+
variant="secondary",
|
| 174 |
+
size="sm",
|
| 175 |
+
scale=1
|
| 176 |
+
)
|
| 177 |
+
save_btn_3 = gr.Button(
|
| 178 |
+
t("results.save_btn"),
|
| 179 |
+
variant="primary",
|
| 180 |
+
size="sm",
|
| 181 |
+
scale=1
|
| 182 |
+
)
|
| 183 |
+
score_btn_3 = gr.Button(
|
| 184 |
+
t("results.score_btn"),
|
| 185 |
+
variant="secondary",
|
| 186 |
+
size="sm",
|
| 187 |
+
scale=1,
|
| 188 |
+
visible=False
|
| 189 |
+
)
|
| 190 |
+
lrc_btn_3 = gr.Button(
|
| 191 |
+
t("results.lrc_btn"),
|
| 192 |
+
variant="secondary",
|
| 193 |
+
size="sm",
|
| 194 |
+
scale=1,
|
| 195 |
+
visible=False
|
| 196 |
+
)
|
| 197 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_3:
|
| 198 |
+
score_display_3 = gr.Textbox(
|
| 199 |
+
label=t("results.quality_score_label", n=3),
|
| 200 |
+
interactive=False,
|
| 201 |
+
buttons=["copy"],
|
| 202 |
+
lines=6,
|
| 203 |
+
max_lines=6,
|
| 204 |
+
visible=True
|
| 205 |
+
)
|
| 206 |
+
lrc_display_3 = gr.Textbox(
|
| 207 |
+
label=t("results.lrc_label", n=3),
|
| 208 |
+
interactive=True,
|
| 209 |
+
buttons=["copy"],
|
| 210 |
+
lines=8,
|
| 211 |
+
max_lines=8,
|
| 212 |
+
visible=True
|
| 213 |
+
)
|
| 214 |
+
codes_display_3 = gr.Textbox(
|
| 215 |
+
label=t("results.codes_label", n=3),
|
| 216 |
+
interactive=False,
|
| 217 |
+
buttons=["copy"],
|
| 218 |
+
lines=4,
|
| 219 |
+
max_lines=4,
|
| 220 |
+
visible=True
|
| 221 |
+
)
|
| 222 |
+
with gr.Column(visible=False) as audio_col_4:
|
| 223 |
+
generated_audio_4 = gr.Audio(
|
| 224 |
+
label=t("results.generated_music", n=4),
|
| 225 |
+
type="filepath",
|
| 226 |
+
interactive=False,
|
| 227 |
+
buttons=[]
|
| 228 |
+
)
|
| 229 |
+
with gr.Row(equal_height=True):
|
| 230 |
+
send_to_cover_btn_4 = gr.Button(
|
| 231 |
+
t("results.send_to_cover_btn"),
|
| 232 |
+
variant="secondary",
|
| 233 |
+
size="sm",
|
| 234 |
+
scale=1
|
| 235 |
+
)
|
| 236 |
+
send_to_repaint_btn_4 = gr.Button(
|
| 237 |
+
t("results.send_to_repaint_btn"),
|
| 238 |
+
variant="secondary",
|
| 239 |
+
size="sm",
|
| 240 |
+
scale=1
|
| 241 |
+
)
|
| 242 |
+
save_btn_4 = gr.Button(
|
| 243 |
+
t("results.save_btn"),
|
| 244 |
+
variant="primary",
|
| 245 |
+
size="sm",
|
| 246 |
+
scale=1
|
| 247 |
+
)
|
| 248 |
+
score_btn_4 = gr.Button(
|
| 249 |
+
t("results.score_btn"),
|
| 250 |
+
variant="secondary",
|
| 251 |
+
size="sm",
|
| 252 |
+
scale=1,
|
| 253 |
+
visible=False
|
| 254 |
+
)
|
| 255 |
+
lrc_btn_4 = gr.Button(
|
| 256 |
+
t("results.lrc_btn"),
|
| 257 |
+
variant="secondary",
|
| 258 |
+
size="sm",
|
| 259 |
+
scale=1,
|
| 260 |
+
visible=False
|
| 261 |
+
)
|
| 262 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_4:
|
| 263 |
+
score_display_4 = gr.Textbox(
|
| 264 |
+
label=t("results.quality_score_label", n=4),
|
| 265 |
+
interactive=False,
|
| 266 |
+
buttons=["copy"],
|
| 267 |
+
lines=6,
|
| 268 |
+
max_lines=6,
|
| 269 |
+
visible=True
|
| 270 |
+
)
|
| 271 |
+
lrc_display_4 = gr.Textbox(
|
| 272 |
+
label=t("results.lrc_label", n=4),
|
| 273 |
+
interactive=True,
|
| 274 |
+
buttons=["copy"],
|
| 275 |
+
lines=8,
|
| 276 |
+
max_lines=8,
|
| 277 |
+
visible=True
|
| 278 |
+
)
|
| 279 |
+
codes_display_4 = gr.Textbox(
|
| 280 |
+
label=t("results.codes_label", n=4),
|
| 281 |
+
interactive=False,
|
| 282 |
+
buttons=["copy"],
|
| 283 |
+
lines=4,
|
| 284 |
+
max_lines=4,
|
| 285 |
+
visible=True
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
# Second row for batch size 5-8 (initially hidden)
|
| 289 |
+
with gr.Row(visible=False) as audio_row_5_8:
|
| 290 |
+
with gr.Column() as audio_col_5:
|
| 291 |
+
generated_audio_5 = gr.Audio(
|
| 292 |
+
label=t("results.generated_music", n=5),
|
| 293 |
+
type="filepath",
|
| 294 |
+
interactive=False,
|
| 295 |
+
buttons=[]
|
| 296 |
+
)
|
| 297 |
+
with gr.Row(equal_height=True):
|
| 298 |
+
send_to_cover_btn_5 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
|
| 299 |
+
send_to_repaint_btn_5 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
|
| 300 |
+
save_btn_5 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 301 |
+
score_btn_5 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 302 |
+
lrc_btn_5 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 303 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_5:
|
| 304 |
+
score_display_5 = gr.Textbox(
|
| 305 |
+
label=t("results.quality_score_label", n=5),
|
| 306 |
+
interactive=False,
|
| 307 |
+
buttons=["copy"],
|
| 308 |
+
lines=6,
|
| 309 |
+
max_lines=6,
|
| 310 |
+
visible=True
|
| 311 |
+
)
|
| 312 |
+
lrc_display_5 = gr.Textbox(
|
| 313 |
+
label=t("results.lrc_label", n=5),
|
| 314 |
+
interactive=True,
|
| 315 |
+
buttons=["copy"],
|
| 316 |
+
lines=8,
|
| 317 |
+
max_lines=8,
|
| 318 |
+
visible=True
|
| 319 |
+
)
|
| 320 |
+
codes_display_5 = gr.Textbox(
|
| 321 |
+
label=t("results.codes_label", n=5),
|
| 322 |
+
interactive=False,
|
| 323 |
+
buttons=["copy"],
|
| 324 |
+
lines=4,
|
| 325 |
+
max_lines=4,
|
| 326 |
+
visible=True
|
| 327 |
+
)
|
| 328 |
+
with gr.Column() as audio_col_6:
|
| 329 |
+
generated_audio_6 = gr.Audio(
|
| 330 |
+
label=t("results.generated_music", n=6),
|
| 331 |
+
type="filepath",
|
| 332 |
+
interactive=False,
|
| 333 |
+
buttons=[]
|
| 334 |
+
)
|
| 335 |
+
with gr.Row(equal_height=True):
|
| 336 |
+
send_to_cover_btn_6 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
|
| 337 |
+
send_to_repaint_btn_6 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
|
| 338 |
+
save_btn_6 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 339 |
+
score_btn_6 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 340 |
+
lrc_btn_6 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 341 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_6:
|
| 342 |
+
score_display_6 = gr.Textbox(
|
| 343 |
+
label=t("results.quality_score_label", n=6),
|
| 344 |
+
interactive=False,
|
| 345 |
+
buttons=["copy"],
|
| 346 |
+
lines=6,
|
| 347 |
+
max_lines=6,
|
| 348 |
+
visible=True
|
| 349 |
+
)
|
| 350 |
+
lrc_display_6 = gr.Textbox(
|
| 351 |
+
label=t("results.lrc_label", n=6),
|
| 352 |
+
interactive=True,
|
| 353 |
+
buttons=["copy"],
|
| 354 |
+
lines=8,
|
| 355 |
+
max_lines=8,
|
| 356 |
+
visible=True
|
| 357 |
+
)
|
| 358 |
+
codes_display_6 = gr.Textbox(
|
| 359 |
+
label=t("results.codes_label", n=6),
|
| 360 |
+
interactive=False,
|
| 361 |
+
buttons=["copy"],
|
| 362 |
+
lines=4,
|
| 363 |
+
max_lines=4,
|
| 364 |
+
visible=True
|
| 365 |
+
)
|
| 366 |
+
with gr.Column() as audio_col_7:
|
| 367 |
+
generated_audio_7 = gr.Audio(
|
| 368 |
+
label=t("results.generated_music", n=7),
|
| 369 |
+
type="filepath",
|
| 370 |
+
interactive=False,
|
| 371 |
+
buttons=[]
|
| 372 |
+
)
|
| 373 |
+
with gr.Row(equal_height=True):
|
| 374 |
+
send_to_cover_btn_7 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
|
| 375 |
+
send_to_repaint_btn_7 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
|
| 376 |
+
save_btn_7 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 377 |
+
score_btn_7 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 378 |
+
lrc_btn_7 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 379 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_7:
|
| 380 |
+
score_display_7 = gr.Textbox(
|
| 381 |
+
label=t("results.quality_score_label", n=7),
|
| 382 |
+
interactive=False,
|
| 383 |
+
buttons=["copy"],
|
| 384 |
+
lines=6,
|
| 385 |
+
max_lines=6,
|
| 386 |
+
visible=True
|
| 387 |
+
)
|
| 388 |
+
lrc_display_7 = gr.Textbox(
|
| 389 |
+
label=t("results.lrc_label", n=7),
|
| 390 |
+
interactive=True,
|
| 391 |
+
buttons=["copy"],
|
| 392 |
+
lines=8,
|
| 393 |
+
max_lines=8,
|
| 394 |
+
visible=True
|
| 395 |
+
)
|
| 396 |
+
codes_display_7 = gr.Textbox(
|
| 397 |
+
label=t("results.codes_label", n=7),
|
| 398 |
+
interactive=False,
|
| 399 |
+
buttons=["copy"],
|
| 400 |
+
lines=4,
|
| 401 |
+
max_lines=4,
|
| 402 |
+
visible=True
|
| 403 |
+
)
|
| 404 |
+
with gr.Column() as audio_col_8:
|
| 405 |
+
generated_audio_8 = gr.Audio(
|
| 406 |
+
label=t("results.generated_music", n=8),
|
| 407 |
+
type="filepath",
|
| 408 |
+
interactive=False,
|
| 409 |
+
buttons=[]
|
| 410 |
+
)
|
| 411 |
+
with gr.Row(equal_height=True):
|
| 412 |
+
send_to_cover_btn_8 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
|
| 413 |
+
send_to_repaint_btn_8 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
|
| 414 |
+
save_btn_8 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
|
| 415 |
+
score_btn_8 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 416 |
+
lrc_btn_8 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
|
| 417 |
+
with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_8:
|
| 418 |
+
score_display_8 = gr.Textbox(
|
| 419 |
+
label=t("results.quality_score_label", n=8),
|
| 420 |
+
interactive=False,
|
| 421 |
+
buttons=["copy"],
|
| 422 |
+
lines=6,
|
| 423 |
+
max_lines=6,
|
| 424 |
+
visible=True
|
| 425 |
+
)
|
| 426 |
+
lrc_display_8 = gr.Textbox(
|
| 427 |
+
label=t("results.lrc_label", n=8),
|
| 428 |
+
interactive=True,
|
| 429 |
+
buttons=["copy"],
|
| 430 |
+
lines=8,
|
| 431 |
+
max_lines=8,
|
| 432 |
+
visible=True
|
| 433 |
+
)
|
| 434 |
+
codes_display_8 = gr.Textbox(
|
| 435 |
+
label=t("results.codes_label", n=8),
|
| 436 |
+
interactive=False,
|
| 437 |
+
buttons=["copy"],
|
| 438 |
+
lines=4,
|
| 439 |
+
max_lines=4,
|
| 440 |
+
visible=True
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
status_output = gr.Textbox(label=t("results.generation_status"), interactive=False)
|
| 444 |
+
|
| 445 |
+
# Batch navigation controls (hidden for simplified UI)
|
| 446 |
+
with gr.Row(equal_height=True, visible=False):
|
| 447 |
+
prev_batch_btn = gr.Button(
|
| 448 |
+
t("results.prev_btn"),
|
| 449 |
+
variant="secondary",
|
| 450 |
+
interactive=False,
|
| 451 |
+
scale=1,
|
| 452 |
+
size="sm"
|
| 453 |
+
)
|
| 454 |
+
batch_indicator = gr.Textbox(
|
| 455 |
+
label=t("results.current_batch"),
|
| 456 |
+
value=t("results.batch_indicator", current=1, total=1),
|
| 457 |
+
interactive=False,
|
| 458 |
+
scale=3
|
| 459 |
+
)
|
| 460 |
+
next_batch_status = gr.Textbox(
|
| 461 |
+
label=t("results.next_batch_status"),
|
| 462 |
+
value="",
|
| 463 |
+
interactive=False,
|
| 464 |
+
scale=3
|
| 465 |
+
)
|
| 466 |
+
next_batch_btn = gr.Button(
|
| 467 |
+
t("results.next_btn"),
|
| 468 |
+
variant="primary",
|
| 469 |
+
interactive=False,
|
| 470 |
+
scale=1,
|
| 471 |
+
size="sm"
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
# One-click restore parameters button (hidden for simplified UI)
|
| 475 |
+
restore_params_btn = gr.Button(
|
| 476 |
+
t("results.restore_params_btn"),
|
| 477 |
+
variant="secondary",
|
| 478 |
+
interactive=False,
|
| 479 |
+
size="sm",
|
| 480 |
+
visible=False
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
with gr.Accordion(t("results.batch_results_title"), open=True):
|
| 484 |
+
generated_audio_batch = gr.File(
|
| 485 |
+
label=t("results.all_files_label"),
|
| 486 |
+
file_count="multiple",
|
| 487 |
+
interactive=False,
|
| 488 |
+
visible=False
|
| 489 |
+
)
|
| 490 |
+
generation_info = gr.Markdown(label=t("results.generation_details"))
|
| 491 |
+
|
| 492 |
+
return {
|
| 493 |
+
"lm_metadata_state": lm_metadata_state,
|
| 494 |
+
"is_format_caption_state": is_format_caption_state,
|
| 495 |
+
"current_batch_index": current_batch_index,
|
| 496 |
+
"total_batches": total_batches,
|
| 497 |
+
"batch_queue": batch_queue,
|
| 498 |
+
"generation_params_state": generation_params_state,
|
| 499 |
+
"is_generating_background": is_generating_background,
|
| 500 |
+
"status_output": status_output,
|
| 501 |
+
"prev_batch_btn": prev_batch_btn,
|
| 502 |
+
"batch_indicator": batch_indicator,
|
| 503 |
+
"next_batch_btn": next_batch_btn,
|
| 504 |
+
"next_batch_status": next_batch_status,
|
| 505 |
+
"restore_params_btn": restore_params_btn,
|
| 506 |
+
"generated_audio_1": generated_audio_1,
|
| 507 |
+
"generated_audio_2": generated_audio_2,
|
| 508 |
+
"generated_audio_3": generated_audio_3,
|
| 509 |
+
"generated_audio_4": generated_audio_4,
|
| 510 |
+
"generated_audio_5": generated_audio_5,
|
| 511 |
+
"generated_audio_6": generated_audio_6,
|
| 512 |
+
"generated_audio_7": generated_audio_7,
|
| 513 |
+
"generated_audio_8": generated_audio_8,
|
| 514 |
+
"audio_row_5_8": audio_row_5_8,
|
| 515 |
+
"audio_col_1": audio_col_1,
|
| 516 |
+
"audio_col_2": audio_col_2,
|
| 517 |
+
"audio_col_3": audio_col_3,
|
| 518 |
+
"audio_col_4": audio_col_4,
|
| 519 |
+
"audio_col_5": audio_col_5,
|
| 520 |
+
"audio_col_6": audio_col_6,
|
| 521 |
+
"audio_col_7": audio_col_7,
|
| 522 |
+
"audio_col_8": audio_col_8,
|
| 523 |
+
"send_to_cover_btn_1": send_to_cover_btn_1,
|
| 524 |
+
"send_to_cover_btn_2": send_to_cover_btn_2,
|
| 525 |
+
"send_to_cover_btn_3": send_to_cover_btn_3,
|
| 526 |
+
"send_to_cover_btn_4": send_to_cover_btn_4,
|
| 527 |
+
"send_to_cover_btn_5": send_to_cover_btn_5,
|
| 528 |
+
"send_to_cover_btn_6": send_to_cover_btn_6,
|
| 529 |
+
"send_to_cover_btn_7": send_to_cover_btn_7,
|
| 530 |
+
"send_to_cover_btn_8": send_to_cover_btn_8,
|
| 531 |
+
"send_to_repaint_btn_1": send_to_repaint_btn_1,
|
| 532 |
+
"send_to_repaint_btn_2": send_to_repaint_btn_2,
|
| 533 |
+
"send_to_repaint_btn_3": send_to_repaint_btn_3,
|
| 534 |
+
"send_to_repaint_btn_4": send_to_repaint_btn_4,
|
| 535 |
+
"send_to_repaint_btn_5": send_to_repaint_btn_5,
|
| 536 |
+
"send_to_repaint_btn_6": send_to_repaint_btn_6,
|
| 537 |
+
"send_to_repaint_btn_7": send_to_repaint_btn_7,
|
| 538 |
+
"send_to_repaint_btn_8": send_to_repaint_btn_8,
|
| 539 |
+
"save_btn_1": save_btn_1,
|
| 540 |
+
"save_btn_2": save_btn_2,
|
| 541 |
+
"save_btn_3": save_btn_3,
|
| 542 |
+
"save_btn_4": save_btn_4,
|
| 543 |
+
"save_btn_5": save_btn_5,
|
| 544 |
+
"save_btn_6": save_btn_6,
|
| 545 |
+
"save_btn_7": save_btn_7,
|
| 546 |
+
"save_btn_8": save_btn_8,
|
| 547 |
+
"score_btn_1": score_btn_1,
|
| 548 |
+
"score_btn_2": score_btn_2,
|
| 549 |
+
"score_btn_3": score_btn_3,
|
| 550 |
+
"score_btn_4": score_btn_4,
|
| 551 |
+
"score_btn_5": score_btn_5,
|
| 552 |
+
"score_btn_6": score_btn_6,
|
| 553 |
+
"score_btn_7": score_btn_7,
|
| 554 |
+
"score_btn_8": score_btn_8,
|
| 555 |
+
"score_display_1": score_display_1,
|
| 556 |
+
"score_display_2": score_display_2,
|
| 557 |
+
"score_display_3": score_display_3,
|
| 558 |
+
"score_display_4": score_display_4,
|
| 559 |
+
"score_display_5": score_display_5,
|
| 560 |
+
"score_display_6": score_display_6,
|
| 561 |
+
"score_display_7": score_display_7,
|
| 562 |
+
"score_display_8": score_display_8,
|
| 563 |
+
"codes_display_1": codes_display_1,
|
| 564 |
+
"codes_display_2": codes_display_2,
|
| 565 |
+
"codes_display_3": codes_display_3,
|
| 566 |
+
"codes_display_4": codes_display_4,
|
| 567 |
+
"codes_display_5": codes_display_5,
|
| 568 |
+
"codes_display_6": codes_display_6,
|
| 569 |
+
"codes_display_7": codes_display_7,
|
| 570 |
+
"codes_display_8": codes_display_8,
|
| 571 |
+
"lrc_btn_1": lrc_btn_1,
|
| 572 |
+
"lrc_btn_2": lrc_btn_2,
|
| 573 |
+
"lrc_btn_3": lrc_btn_3,
|
| 574 |
+
"lrc_btn_4": lrc_btn_4,
|
| 575 |
+
"lrc_btn_5": lrc_btn_5,
|
| 576 |
+
"lrc_btn_6": lrc_btn_6,
|
| 577 |
+
"lrc_btn_7": lrc_btn_7,
|
| 578 |
+
"lrc_btn_8": lrc_btn_8,
|
| 579 |
+
"lrc_display_1": lrc_display_1,
|
| 580 |
+
"lrc_display_2": lrc_display_2,
|
| 581 |
+
"lrc_display_3": lrc_display_3,
|
| 582 |
+
"lrc_display_4": lrc_display_4,
|
| 583 |
+
"lrc_display_5": lrc_display_5,
|
| 584 |
+
"lrc_display_6": lrc_display_6,
|
| 585 |
+
"lrc_display_7": lrc_display_7,
|
| 586 |
+
"lrc_display_8": lrc_display_8,
|
| 587 |
+
"details_accordion_1": details_accordion_1,
|
| 588 |
+
"details_accordion_2": details_accordion_2,
|
| 589 |
+
"details_accordion_3": details_accordion_3,
|
| 590 |
+
"details_accordion_4": details_accordion_4,
|
| 591 |
+
"details_accordion_5": details_accordion_5,
|
| 592 |
+
"details_accordion_6": details_accordion_6,
|
| 593 |
+
"details_accordion_7": details_accordion_7,
|
| 594 |
+
"details_accordion_8": details_accordion_8,
|
| 595 |
+
"generated_audio_batch": generated_audio_batch,
|
| 596 |
+
"generation_info": generation_info,
|
| 597 |
+
}
|
| 598 |
+
|
spaces/Ace-Step-v1.5/acestep/gradio_ui/interfaces/training.py
ADDED
|
@@ -0,0 +1,562 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI Training Tab Module
|
| 3 |
+
|
| 4 |
+
Contains the dataset builder and LoRA training interface components.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import gradio as gr
|
| 9 |
+
from acestep.gradio_ui.i18n import t
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def create_training_section(dit_handler, llm_handler, init_params=None) -> dict:
|
| 13 |
+
"""Create the training tab section with dataset builder and training controls.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
dit_handler: DiT handler instance
|
| 17 |
+
llm_handler: LLM handler instance
|
| 18 |
+
init_params: Dictionary containing initialization parameters and state.
|
| 19 |
+
If None, service will not be pre-initialized.
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
Dictionary of Gradio components for event handling
|
| 23 |
+
"""
|
| 24 |
+
# Check if running in service mode (hide training tab)
|
| 25 |
+
service_mode = init_params is not None and init_params.get('service_mode', False)
|
| 26 |
+
|
| 27 |
+
with gr.Tab("🎓 LoRA Training", visible=not service_mode):
|
| 28 |
+
gr.HTML("""
|
| 29 |
+
<div style="text-align: center; padding: 10px; margin-bottom: 15px;">
|
| 30 |
+
<h2>🎵 LoRA Training for ACE-Step</h2>
|
| 31 |
+
<p>Build datasets from your audio files and train custom LoRA adapters</p>
|
| 32 |
+
</div>
|
| 33 |
+
""")
|
| 34 |
+
|
| 35 |
+
with gr.Tabs():
|
| 36 |
+
# ==================== Dataset Builder Tab ====================
|
| 37 |
+
with gr.Tab("📁 Dataset Builder"):
|
| 38 |
+
# ========== Load Existing OR Scan New ==========
|
| 39 |
+
gr.HTML("""
|
| 40 |
+
<div style="padding: 10px; margin-bottom: 10px; border: 1px solid #4a4a6a; border-radius: 8px; background: linear-gradient(135deg, #2a2a4a 0%, #1a1a3a 100%);">
|
| 41 |
+
<h3 style="margin: 0 0 5px 0;">🚀 Quick Start</h3>
|
| 42 |
+
<p style="margin: 0; color: #aaa;">Choose one: <b>Load existing dataset</b> OR <b>Scan new directory</b></p>
|
| 43 |
+
</div>
|
| 44 |
+
""")
|
| 45 |
+
|
| 46 |
+
with gr.Row():
|
| 47 |
+
with gr.Column(scale=1):
|
| 48 |
+
gr.HTML("<h4>📂 Load Existing Dataset</h4>")
|
| 49 |
+
with gr.Row():
|
| 50 |
+
load_json_path = gr.Textbox(
|
| 51 |
+
label="Dataset JSON Path",
|
| 52 |
+
placeholder="./datasets/my_lora_dataset.json",
|
| 53 |
+
info="Load a previously saved dataset",
|
| 54 |
+
scale=3,
|
| 55 |
+
)
|
| 56 |
+
load_json_btn = gr.Button("📂 Load", variant="primary", scale=1)
|
| 57 |
+
load_json_status = gr.Textbox(
|
| 58 |
+
label="Load Status",
|
| 59 |
+
interactive=False,
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
with gr.Column(scale=1):
|
| 63 |
+
gr.HTML("<h4>🔍 Scan New Directory</h4>")
|
| 64 |
+
with gr.Row():
|
| 65 |
+
audio_directory = gr.Textbox(
|
| 66 |
+
label="Audio Directory Path",
|
| 67 |
+
placeholder="/path/to/your/audio/folder",
|
| 68 |
+
info="Scan for audio files (wav, mp3, flac, ogg, opus)",
|
| 69 |
+
scale=3,
|
| 70 |
+
)
|
| 71 |
+
scan_btn = gr.Button("🔍 Scan", variant="secondary", scale=1)
|
| 72 |
+
scan_status = gr.Textbox(
|
| 73 |
+
label="Scan Status",
|
| 74 |
+
interactive=False,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
gr.HTML("<hr>")
|
| 78 |
+
|
| 79 |
+
with gr.Row():
|
| 80 |
+
with gr.Column(scale=2):
|
| 81 |
+
|
| 82 |
+
# Audio files table
|
| 83 |
+
audio_files_table = gr.Dataframe(
|
| 84 |
+
headers=["#", "Filename", "Duration", "Labeled", "BPM", "Key", "Caption"],
|
| 85 |
+
datatype=["number", "str", "str", "str", "str", "str", "str"],
|
| 86 |
+
label="Found Audio Files",
|
| 87 |
+
interactive=False,
|
| 88 |
+
wrap=True,
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
with gr.Column(scale=1):
|
| 92 |
+
gr.HTML("<h3>⚙️ Dataset Settings</h3>")
|
| 93 |
+
|
| 94 |
+
dataset_name = gr.Textbox(
|
| 95 |
+
label="Dataset Name",
|
| 96 |
+
value="my_lora_dataset",
|
| 97 |
+
placeholder="Enter dataset name",
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
all_instrumental = gr.Checkbox(
|
| 101 |
+
label="All Instrumental",
|
| 102 |
+
value=True,
|
| 103 |
+
info="Check if all tracks are instrumental (no vocals)",
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
need_lyrics = gr.Checkbox(
|
| 107 |
+
label="Transcribe Lyrics",
|
| 108 |
+
value=False,
|
| 109 |
+
info="Attempt to transcribe lyrics (slower)",
|
| 110 |
+
interactive=False, # Disabled for now
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
custom_tag = gr.Textbox(
|
| 114 |
+
label="Custom Activation Tag",
|
| 115 |
+
placeholder="e.g., 8bit_retro, my_style",
|
| 116 |
+
info="Unique tag to activate this LoRA's style",
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
tag_position = gr.Radio(
|
| 120 |
+
choices=[
|
| 121 |
+
("Prepend (tag, caption)", "prepend"),
|
| 122 |
+
("Append (caption, tag)", "append"),
|
| 123 |
+
("Replace caption", "replace"),
|
| 124 |
+
],
|
| 125 |
+
value="replace",
|
| 126 |
+
label="Tag Position",
|
| 127 |
+
info="Where to place the custom tag in the caption",
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
gr.HTML("<hr><h3>🤖 Step 2: Auto-Label with AI</h3>")
|
| 131 |
+
|
| 132 |
+
with gr.Row():
|
| 133 |
+
with gr.Column(scale=3):
|
| 134 |
+
gr.Markdown("""
|
| 135 |
+
Click the button below to automatically generate metadata for all audio files using AI:
|
| 136 |
+
- **Caption**: Music style, genre, mood description
|
| 137 |
+
- **BPM**: Beats per minute
|
| 138 |
+
- **Key**: Musical key (e.g., C Major, Am)
|
| 139 |
+
- **Time Signature**: 4/4, 3/4, etc.
|
| 140 |
+
""")
|
| 141 |
+
skip_metas = gr.Checkbox(
|
| 142 |
+
label="Skip Metas (No LLM)",
|
| 143 |
+
value=False,
|
| 144 |
+
info="Skip AI labeling. BPM/Key/Time Signature will be N/A, Language will be 'unknown' for instrumental",
|
| 145 |
+
)
|
| 146 |
+
with gr.Column(scale=1):
|
| 147 |
+
auto_label_btn = gr.Button(
|
| 148 |
+
"🏷️ Auto-Label All",
|
| 149 |
+
variant="primary",
|
| 150 |
+
size="lg",
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
label_progress = gr.Textbox(
|
| 154 |
+
label="Labeling Progress",
|
| 155 |
+
interactive=False,
|
| 156 |
+
lines=2,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
gr.HTML("<hr><h3>👀 Step 3: Preview & Edit</h3>")
|
| 160 |
+
|
| 161 |
+
with gr.Row():
|
| 162 |
+
with gr.Column(scale=1):
|
| 163 |
+
sample_selector = gr.Slider(
|
| 164 |
+
minimum=0,
|
| 165 |
+
maximum=0,
|
| 166 |
+
step=1,
|
| 167 |
+
value=0,
|
| 168 |
+
label="Select Sample #",
|
| 169 |
+
info="Choose a sample to preview and edit",
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
preview_audio = gr.Audio(
|
| 173 |
+
label="Audio Preview",
|
| 174 |
+
type="filepath",
|
| 175 |
+
interactive=False,
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
preview_filename = gr.Textbox(
|
| 179 |
+
label="Filename",
|
| 180 |
+
interactive=False,
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
with gr.Column(scale=2):
|
| 184 |
+
with gr.Row():
|
| 185 |
+
edit_caption = gr.Textbox(
|
| 186 |
+
label="Caption",
|
| 187 |
+
lines=3,
|
| 188 |
+
placeholder="Music description...",
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
with gr.Row():
|
| 192 |
+
edit_lyrics = gr.Textbox(
|
| 193 |
+
label="Lyrics",
|
| 194 |
+
lines=4,
|
| 195 |
+
placeholder="[Verse 1]\nLyrics here...\n\n[Chorus]\n...",
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
with gr.Row():
|
| 199 |
+
edit_bpm = gr.Number(
|
| 200 |
+
label="BPM",
|
| 201 |
+
precision=0,
|
| 202 |
+
)
|
| 203 |
+
edit_keyscale = gr.Textbox(
|
| 204 |
+
label="Key",
|
| 205 |
+
placeholder="C Major",
|
| 206 |
+
)
|
| 207 |
+
edit_timesig = gr.Dropdown(
|
| 208 |
+
choices=["", "2", "3", "4", "6"],
|
| 209 |
+
label="Time Signature",
|
| 210 |
+
)
|
| 211 |
+
edit_duration = gr.Number(
|
| 212 |
+
label="Duration (s)",
|
| 213 |
+
precision=1,
|
| 214 |
+
interactive=False,
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
with gr.Row():
|
| 218 |
+
edit_language = gr.Dropdown(
|
| 219 |
+
choices=["instrumental", "en", "zh", "ja", "ko", "es", "fr", "de", "pt", "ru", "unknown"],
|
| 220 |
+
value="instrumental",
|
| 221 |
+
label="Language",
|
| 222 |
+
)
|
| 223 |
+
edit_instrumental = gr.Checkbox(
|
| 224 |
+
label="Instrumental",
|
| 225 |
+
value=True,
|
| 226 |
+
)
|
| 227 |
+
save_edit_btn = gr.Button("💾 Save Changes", variant="secondary")
|
| 228 |
+
|
| 229 |
+
edit_status = gr.Textbox(
|
| 230 |
+
label="Edit Status",
|
| 231 |
+
interactive=False,
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
gr.HTML("<hr><h3>💾 Step 4: Save Dataset</h3>")
|
| 235 |
+
|
| 236 |
+
with gr.Row():
|
| 237 |
+
with gr.Column(scale=3):
|
| 238 |
+
save_path = gr.Textbox(
|
| 239 |
+
label="Save Path",
|
| 240 |
+
value="./datasets/my_lora_dataset.json",
|
| 241 |
+
placeholder="./datasets/dataset_name.json",
|
| 242 |
+
info="Path where the dataset JSON will be saved",
|
| 243 |
+
)
|
| 244 |
+
with gr.Column(scale=1):
|
| 245 |
+
save_dataset_btn = gr.Button(
|
| 246 |
+
"💾 Save Dataset",
|
| 247 |
+
variant="primary",
|
| 248 |
+
size="lg",
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
save_status = gr.Textbox(
|
| 252 |
+
label="Save Status",
|
| 253 |
+
interactive=False,
|
| 254 |
+
lines=2,
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
gr.HTML("<hr><h3>⚡ Step 5: Preprocess to Tensors</h3>")
|
| 258 |
+
|
| 259 |
+
gr.Markdown("""
|
| 260 |
+
**Preprocessing converts your dataset to pre-computed tensors for fast training.**
|
| 261 |
+
|
| 262 |
+
You can either:
|
| 263 |
+
- Use the dataset from Steps 1-4 above, **OR**
|
| 264 |
+
- Load an existing dataset JSON file (if you've already saved one)
|
| 265 |
+
""")
|
| 266 |
+
|
| 267 |
+
with gr.Row():
|
| 268 |
+
with gr.Column(scale=3):
|
| 269 |
+
load_existing_dataset_path = gr.Textbox(
|
| 270 |
+
label="Load Existing Dataset (Optional)",
|
| 271 |
+
placeholder="./datasets/my_lora_dataset.json",
|
| 272 |
+
info="Path to a previously saved dataset JSON file",
|
| 273 |
+
)
|
| 274 |
+
with gr.Column(scale=1):
|
| 275 |
+
load_existing_dataset_btn = gr.Button(
|
| 276 |
+
"📂 Load Dataset",
|
| 277 |
+
variant="secondary",
|
| 278 |
+
size="lg",
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
load_existing_status = gr.Textbox(
|
| 282 |
+
label="Load Status",
|
| 283 |
+
interactive=False,
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
gr.Markdown("""
|
| 287 |
+
This step:
|
| 288 |
+
- Encodes audio to VAE latents
|
| 289 |
+
- Encodes captions and lyrics to text embeddings
|
| 290 |
+
- Runs the condition encoder
|
| 291 |
+
- Saves all tensors to `.pt` files
|
| 292 |
+
|
| 293 |
+
⚠️ **This requires the model to be loaded and may take a few minutes.**
|
| 294 |
+
""")
|
| 295 |
+
|
| 296 |
+
with gr.Row():
|
| 297 |
+
with gr.Column(scale=3):
|
| 298 |
+
preprocess_output_dir = gr.Textbox(
|
| 299 |
+
label="Tensor Output Directory",
|
| 300 |
+
value="./datasets/preprocessed_tensors",
|
| 301 |
+
placeholder="./datasets/preprocessed_tensors",
|
| 302 |
+
info="Directory to save preprocessed tensor files",
|
| 303 |
+
)
|
| 304 |
+
with gr.Column(scale=1):
|
| 305 |
+
preprocess_btn = gr.Button(
|
| 306 |
+
"⚡ Preprocess",
|
| 307 |
+
variant="primary",
|
| 308 |
+
size="lg",
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
preprocess_progress = gr.Textbox(
|
| 312 |
+
label="Preprocessing Progress",
|
| 313 |
+
interactive=False,
|
| 314 |
+
lines=3,
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
# ==================== Training Tab ====================
|
| 318 |
+
with gr.Tab("🚀 Train LoRA"):
|
| 319 |
+
with gr.Row():
|
| 320 |
+
with gr.Column(scale=2):
|
| 321 |
+
gr.HTML("<h3>📊 Preprocessed Dataset Selection</h3>")
|
| 322 |
+
|
| 323 |
+
gr.Markdown("""
|
| 324 |
+
Select the directory containing preprocessed tensor files (`.pt` files).
|
| 325 |
+
These are created in the "Dataset Builder" tab using the "Preprocess" button.
|
| 326 |
+
""")
|
| 327 |
+
|
| 328 |
+
training_tensor_dir = gr.Textbox(
|
| 329 |
+
label="Preprocessed Tensors Directory",
|
| 330 |
+
placeholder="./datasets/preprocessed_tensors",
|
| 331 |
+
value="./datasets/preprocessed_tensors",
|
| 332 |
+
info="Directory containing preprocessed .pt tensor files",
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
load_dataset_btn = gr.Button("📂 Load Dataset", variant="secondary")
|
| 336 |
+
|
| 337 |
+
training_dataset_info = gr.Textbox(
|
| 338 |
+
label="Dataset Info",
|
| 339 |
+
interactive=False,
|
| 340 |
+
lines=3,
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
with gr.Column(scale=1):
|
| 344 |
+
gr.HTML("<h3>⚙️ LoRA Settings</h3>")
|
| 345 |
+
|
| 346 |
+
lora_rank = gr.Slider(
|
| 347 |
+
minimum=4,
|
| 348 |
+
maximum=256,
|
| 349 |
+
step=4,
|
| 350 |
+
value=64,
|
| 351 |
+
label="LoRA Rank (r)",
|
| 352 |
+
info="Higher = more capacity, more memory",
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
lora_alpha = gr.Slider(
|
| 356 |
+
minimum=4,
|
| 357 |
+
maximum=512,
|
| 358 |
+
step=4,
|
| 359 |
+
value=128,
|
| 360 |
+
label="LoRA Alpha",
|
| 361 |
+
info="Scaling factor (typically 2x rank)",
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
lora_dropout = gr.Slider(
|
| 365 |
+
minimum=0.0,
|
| 366 |
+
maximum=0.5,
|
| 367 |
+
step=0.05,
|
| 368 |
+
value=0.1,
|
| 369 |
+
label="LoRA Dropout",
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
gr.HTML("<hr><h3>🎛️ Training Parameters</h3>")
|
| 373 |
+
|
| 374 |
+
with gr.Row():
|
| 375 |
+
learning_rate = gr.Number(
|
| 376 |
+
label="Learning Rate",
|
| 377 |
+
value=1e-4,
|
| 378 |
+
info="Start with 1e-4, adjust if needed",
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
train_epochs = gr.Slider(
|
| 382 |
+
minimum=100,
|
| 383 |
+
maximum=4000,
|
| 384 |
+
step=100,
|
| 385 |
+
value=500,
|
| 386 |
+
label="Max Epochs",
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
train_batch_size = gr.Slider(
|
| 390 |
+
minimum=1,
|
| 391 |
+
maximum=8,
|
| 392 |
+
step=1,
|
| 393 |
+
value=1,
|
| 394 |
+
label="Batch Size",
|
| 395 |
+
info="Increase if you have enough VRAM",
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
gradient_accumulation = gr.Slider(
|
| 399 |
+
minimum=1,
|
| 400 |
+
maximum=16,
|
| 401 |
+
step=1,
|
| 402 |
+
value=1,
|
| 403 |
+
label="Gradient Accumulation",
|
| 404 |
+
info="Effective batch = batch_size × accumulation",
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
with gr.Row():
|
| 408 |
+
save_every_n_epochs = gr.Slider(
|
| 409 |
+
minimum=50,
|
| 410 |
+
maximum=1000,
|
| 411 |
+
step=50,
|
| 412 |
+
value=200,
|
| 413 |
+
label="Save Every N Epochs",
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
training_shift = gr.Slider(
|
| 417 |
+
minimum=1.0,
|
| 418 |
+
maximum=5.0,
|
| 419 |
+
step=0.5,
|
| 420 |
+
value=3.0,
|
| 421 |
+
label="Shift",
|
| 422 |
+
info="Timestep shift for turbo model",
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
training_seed = gr.Number(
|
| 426 |
+
label="Seed",
|
| 427 |
+
value=42,
|
| 428 |
+
precision=0,
|
| 429 |
+
)
|
| 430 |
+
|
| 431 |
+
with gr.Row():
|
| 432 |
+
lora_output_dir = gr.Textbox(
|
| 433 |
+
label="Output Directory",
|
| 434 |
+
value="./lora_output",
|
| 435 |
+
placeholder="./lora_output",
|
| 436 |
+
info="Directory to save trained LoRA weights",
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
gr.HTML("<hr>")
|
| 440 |
+
|
| 441 |
+
with gr.Row():
|
| 442 |
+
with gr.Column(scale=1):
|
| 443 |
+
start_training_btn = gr.Button(
|
| 444 |
+
"🚀 Start Training",
|
| 445 |
+
variant="primary",
|
| 446 |
+
size="lg",
|
| 447 |
+
)
|
| 448 |
+
with gr.Column(scale=1):
|
| 449 |
+
stop_training_btn = gr.Button(
|
| 450 |
+
"⏹️ Stop Training",
|
| 451 |
+
variant="stop",
|
| 452 |
+
size="lg",
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
training_progress = gr.Textbox(
|
| 456 |
+
label="Training Progress",
|
| 457 |
+
interactive=False,
|
| 458 |
+
lines=2,
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
with gr.Row():
|
| 462 |
+
training_log = gr.Textbox(
|
| 463 |
+
label="Training Log",
|
| 464 |
+
interactive=False,
|
| 465 |
+
lines=10,
|
| 466 |
+
max_lines=15,
|
| 467 |
+
scale=1,
|
| 468 |
+
)
|
| 469 |
+
training_loss_plot = gr.LinePlot(
|
| 470 |
+
x="step",
|
| 471 |
+
y="loss",
|
| 472 |
+
title="Training Loss",
|
| 473 |
+
x_title="Step",
|
| 474 |
+
y_title="Loss",
|
| 475 |
+
scale=1,
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
gr.HTML("<hr><h3>📦 Export LoRA</h3>")
|
| 479 |
+
|
| 480 |
+
with gr.Row():
|
| 481 |
+
export_path = gr.Textbox(
|
| 482 |
+
label="Export Path",
|
| 483 |
+
value="./lora_output/final_lora",
|
| 484 |
+
placeholder="./lora_output/my_lora",
|
| 485 |
+
)
|
| 486 |
+
export_lora_btn = gr.Button("📦 Export LoRA", variant="secondary")
|
| 487 |
+
|
| 488 |
+
export_status = gr.Textbox(
|
| 489 |
+
label="Export Status",
|
| 490 |
+
interactive=False,
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
# Store dataset builder state
|
| 494 |
+
dataset_builder_state = gr.State(None)
|
| 495 |
+
training_state = gr.State({"is_training": False, "should_stop": False})
|
| 496 |
+
|
| 497 |
+
return {
|
| 498 |
+
# Dataset Builder - Load or Scan
|
| 499 |
+
"load_json_path": load_json_path,
|
| 500 |
+
"load_json_btn": load_json_btn,
|
| 501 |
+
"load_json_status": load_json_status,
|
| 502 |
+
"audio_directory": audio_directory,
|
| 503 |
+
"scan_btn": scan_btn,
|
| 504 |
+
"scan_status": scan_status,
|
| 505 |
+
"audio_files_table": audio_files_table,
|
| 506 |
+
"dataset_name": dataset_name,
|
| 507 |
+
"all_instrumental": all_instrumental,
|
| 508 |
+
"need_lyrics": need_lyrics,
|
| 509 |
+
"custom_tag": custom_tag,
|
| 510 |
+
"tag_position": tag_position,
|
| 511 |
+
"skip_metas": skip_metas,
|
| 512 |
+
"auto_label_btn": auto_label_btn,
|
| 513 |
+
"label_progress": label_progress,
|
| 514 |
+
"sample_selector": sample_selector,
|
| 515 |
+
"preview_audio": preview_audio,
|
| 516 |
+
"preview_filename": preview_filename,
|
| 517 |
+
"edit_caption": edit_caption,
|
| 518 |
+
"edit_lyrics": edit_lyrics,
|
| 519 |
+
"edit_bpm": edit_bpm,
|
| 520 |
+
"edit_keyscale": edit_keyscale,
|
| 521 |
+
"edit_timesig": edit_timesig,
|
| 522 |
+
"edit_duration": edit_duration,
|
| 523 |
+
"edit_language": edit_language,
|
| 524 |
+
"edit_instrumental": edit_instrumental,
|
| 525 |
+
"save_edit_btn": save_edit_btn,
|
| 526 |
+
"edit_status": edit_status,
|
| 527 |
+
"save_path": save_path,
|
| 528 |
+
"save_dataset_btn": save_dataset_btn,
|
| 529 |
+
"save_status": save_status,
|
| 530 |
+
# Preprocessing
|
| 531 |
+
"load_existing_dataset_path": load_existing_dataset_path,
|
| 532 |
+
"load_existing_dataset_btn": load_existing_dataset_btn,
|
| 533 |
+
"load_existing_status": load_existing_status,
|
| 534 |
+
"preprocess_output_dir": preprocess_output_dir,
|
| 535 |
+
"preprocess_btn": preprocess_btn,
|
| 536 |
+
"preprocess_progress": preprocess_progress,
|
| 537 |
+
"dataset_builder_state": dataset_builder_state,
|
| 538 |
+
# Training
|
| 539 |
+
"training_tensor_dir": training_tensor_dir,
|
| 540 |
+
"load_dataset_btn": load_dataset_btn,
|
| 541 |
+
"training_dataset_info": training_dataset_info,
|
| 542 |
+
"lora_rank": lora_rank,
|
| 543 |
+
"lora_alpha": lora_alpha,
|
| 544 |
+
"lora_dropout": lora_dropout,
|
| 545 |
+
"learning_rate": learning_rate,
|
| 546 |
+
"train_epochs": train_epochs,
|
| 547 |
+
"train_batch_size": train_batch_size,
|
| 548 |
+
"gradient_accumulation": gradient_accumulation,
|
| 549 |
+
"save_every_n_epochs": save_every_n_epochs,
|
| 550 |
+
"training_shift": training_shift,
|
| 551 |
+
"training_seed": training_seed,
|
| 552 |
+
"lora_output_dir": lora_output_dir,
|
| 553 |
+
"start_training_btn": start_training_btn,
|
| 554 |
+
"stop_training_btn": stop_training_btn,
|
| 555 |
+
"training_progress": training_progress,
|
| 556 |
+
"training_log": training_log,
|
| 557 |
+
"training_loss_plot": training_loss_plot,
|
| 558 |
+
"export_path": export_path,
|
| 559 |
+
"export_lora_btn": export_lora_btn,
|
| 560 |
+
"export_status": export_status,
|
| 561 |
+
"training_state": training_state,
|
| 562 |
+
}
|
spaces/Ace-Step-v1.5/acestep/handler.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
spaces/Ace-Step-v1.5/acestep/inference.py
ADDED
|
@@ -0,0 +1,1182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ACE-Step Inference API Module
|
| 3 |
+
|
| 4 |
+
This module provides a standardized inference interface for music generation,
|
| 5 |
+
designed for third-party integration. It offers both a simplified API and
|
| 6 |
+
backward-compatible Gradio UI support.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
import os
|
| 11 |
+
import tempfile
|
| 12 |
+
from typing import Optional, Union, List, Dict, Any, Tuple
|
| 13 |
+
from dataclasses import dataclass, field, asdict
|
| 14 |
+
from loguru import logger
|
| 15 |
+
|
| 16 |
+
from acestep.audio_utils import AudioSaver, generate_uuid_from_params
|
| 17 |
+
|
| 18 |
+
# HuggingFace Space environment detection
|
| 19 |
+
IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
|
| 20 |
+
|
| 21 |
+
def _get_spaces_gpu_decorator(duration=180):
|
| 22 |
+
"""
|
| 23 |
+
Get the @spaces.GPU decorator if running in HuggingFace Space environment.
|
| 24 |
+
Returns identity decorator if not in Space environment.
|
| 25 |
+
"""
|
| 26 |
+
if IS_HUGGINGFACE_SPACE:
|
| 27 |
+
try:
|
| 28 |
+
import spaces
|
| 29 |
+
return spaces.GPU(duration=duration)
|
| 30 |
+
except ImportError:
|
| 31 |
+
logger.warning("spaces package not found, GPU decorator disabled")
|
| 32 |
+
return lambda func: func
|
| 33 |
+
return lambda func: func
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class GenerationParams:
|
| 38 |
+
"""Configuration for music generation parameters.
|
| 39 |
+
|
| 40 |
+
Attributes:
|
| 41 |
+
# Text Inputs
|
| 42 |
+
caption: A short text prompt describing the desired music (main prompt). < 512 characters
|
| 43 |
+
lyrics: Lyrics for the music. Use "[Instrumental]" for instrumental songs. < 4096 characters
|
| 44 |
+
instrumental: If True, generate instrumental music regardless of lyrics.
|
| 45 |
+
|
| 46 |
+
# Music Metadata
|
| 47 |
+
bpm: BPM (beats per minute), e.g., 120. Set to None for automatic estimation. 30 ~ 300
|
| 48 |
+
keyscale: Musical key (e.g., "C Major", "Am"). Leave empty for auto-detection. A-G, #/♭, major/minor
|
| 49 |
+
timesignature: Time signature (2 for '2/4', 3 for '3/4', 4 for '4/4', 6 for '6/8'). Leave empty for auto-detection.
|
| 50 |
+
vocal_language: Language code for vocals, e.g., "en", "zh", "ja", or "unknown". see acestep/constants.py:VALID_LANGUAGES
|
| 51 |
+
duration: Target audio length in seconds. If <0 or None, model chooses automatically. 10 ~ 600
|
| 52 |
+
|
| 53 |
+
# Generation Parameters
|
| 54 |
+
inference_steps: Number of diffusion steps (e.g., 8 for turbo, 32–100 for base model).
|
| 55 |
+
guidance_scale: CFG (classifier-free guidance) strength. Higher means following the prompt more strictly. Only support for non-turbo model.
|
| 56 |
+
seed: Integer seed for reproducibility. -1 means use random seed each time.
|
| 57 |
+
|
| 58 |
+
# Advanced DiT Parameters
|
| 59 |
+
use_adg: Whether to use Adaptive Dual Guidance (only works for base model).
|
| 60 |
+
cfg_interval_start: Start ratio (0.0–1.0) to apply CFG.
|
| 61 |
+
cfg_interval_end: End ratio (0.0–1.0) to apply CFG.
|
| 62 |
+
shift: Timestep shift factor (default 1.0). When != 1.0, applies t = shift * t / (1 + (shift - 1) * t) to timesteps.
|
| 63 |
+
|
| 64 |
+
# Task-Specific Parameters
|
| 65 |
+
task_type: Type of generation task. One of: "text2music", "cover", "repaint", "lego", "extract", "complete".
|
| 66 |
+
reference_audio: Path to a reference audio file for style transfer or cover tasks.
|
| 67 |
+
src_audio: Path to a source audio file for audio-to-audio tasks.
|
| 68 |
+
audio_codes: Audio semantic codes as a string (advanced use, for code-control generation).
|
| 69 |
+
repainting_start: For repaint/lego tasks: start time in seconds for region to repaint.
|
| 70 |
+
repainting_end: For repaint/lego tasks: end time in seconds for region to repaint (-1 for until end).
|
| 71 |
+
audio_cover_strength: Strength of reference audio/codes influence (range 0.0–1.0). set smaller (0.2) for style transfer tasks.
|
| 72 |
+
instruction: Optional task instruction prompt. If empty, auto-generated by system.
|
| 73 |
+
|
| 74 |
+
# 5Hz Language Model Parameters for CoT reasoning
|
| 75 |
+
thinking: If True, enable 5Hz Language Model "Chain-of-Thought" reasoning for semantic/music metadata and codes.
|
| 76 |
+
lm_temperature: Sampling temperature for the LLM (0.0–2.0). Higher = more creative/varied results.
|
| 77 |
+
lm_cfg_scale: Classifier-free guidance scale for the LLM.
|
| 78 |
+
lm_top_k: LLM top-k sampling (0 = disabled).
|
| 79 |
+
lm_top_p: LLM top-p nucleus sampling (1.0 = disabled).
|
| 80 |
+
lm_negative_prompt: Negative prompt to use for LLM (for control).
|
| 81 |
+
use_cot_metas: Whether to let LLM generate music metadata via CoT reasoning.
|
| 82 |
+
use_cot_caption: Whether to let LLM rewrite or format the input caption via CoT reasoning.
|
| 83 |
+
use_cot_language: Whether to let LLM detect vocal language via CoT.
|
| 84 |
+
"""
|
| 85 |
+
# Required Inputs
|
| 86 |
+
task_type: str = "text2music"
|
| 87 |
+
instruction: str = "Fill the audio semantic mask based on the given conditions:"
|
| 88 |
+
|
| 89 |
+
# Audio Uploads
|
| 90 |
+
reference_audio: Optional[str] = None
|
| 91 |
+
src_audio: Optional[str] = None
|
| 92 |
+
|
| 93 |
+
# LM Codes Hints
|
| 94 |
+
audio_codes: str = ""
|
| 95 |
+
|
| 96 |
+
# Text Inputs
|
| 97 |
+
caption: str = ""
|
| 98 |
+
lyrics: str = ""
|
| 99 |
+
instrumental: bool = False
|
| 100 |
+
|
| 101 |
+
# Metadata
|
| 102 |
+
vocal_language: str = "unknown"
|
| 103 |
+
bpm: Optional[int] = None
|
| 104 |
+
keyscale: str = ""
|
| 105 |
+
timesignature: str = ""
|
| 106 |
+
duration: float = -1.0
|
| 107 |
+
|
| 108 |
+
# Advanced Settings
|
| 109 |
+
inference_steps: int = 8
|
| 110 |
+
seed: int = -1
|
| 111 |
+
guidance_scale: float = 7.0
|
| 112 |
+
use_adg: bool = False
|
| 113 |
+
cfg_interval_start: float = 0.0
|
| 114 |
+
cfg_interval_end: float = 1.0
|
| 115 |
+
shift: float = 1.0
|
| 116 |
+
infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
|
| 117 |
+
# Custom timesteps (parsed from string like "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
|
| 118 |
+
# If provided, overrides inference_steps and shift
|
| 119 |
+
timesteps: Optional[List[float]] = None
|
| 120 |
+
|
| 121 |
+
repainting_start: float = 0.0
|
| 122 |
+
repainting_end: float = -1
|
| 123 |
+
audio_cover_strength: float = 1.0
|
| 124 |
+
|
| 125 |
+
# 5Hz Language Model Parameters
|
| 126 |
+
thinking: bool = True
|
| 127 |
+
lm_temperature: float = 0.85
|
| 128 |
+
lm_cfg_scale: float = 2.0
|
| 129 |
+
lm_top_k: int = 0
|
| 130 |
+
lm_top_p: float = 0.9
|
| 131 |
+
lm_negative_prompt: str = "NO USER INPUT"
|
| 132 |
+
use_cot_metas: bool = True
|
| 133 |
+
use_cot_caption: bool = True
|
| 134 |
+
use_cot_lyrics: bool = False # TODO: not used yet
|
| 135 |
+
use_cot_language: bool = True
|
| 136 |
+
use_constrained_decoding: bool = True
|
| 137 |
+
|
| 138 |
+
cot_bpm: Optional[int] = None
|
| 139 |
+
cot_keyscale: str = ""
|
| 140 |
+
cot_timesignature: str = ""
|
| 141 |
+
cot_duration: Optional[float] = None
|
| 142 |
+
cot_vocal_language: str = "unknown"
|
| 143 |
+
cot_caption: str = ""
|
| 144 |
+
cot_lyrics: str = ""
|
| 145 |
+
|
| 146 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 147 |
+
"""Convert config to dictionary for JSON serialization."""
|
| 148 |
+
return asdict(self)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
@dataclass
|
| 152 |
+
class GenerationConfig:
|
| 153 |
+
"""Configuration for music generation.
|
| 154 |
+
|
| 155 |
+
Attributes:
|
| 156 |
+
batch_size: Number of audio samples to generate
|
| 157 |
+
allow_lm_batch: Whether to allow batch processing in LM
|
| 158 |
+
use_random_seed: Whether to use random seed
|
| 159 |
+
seeds: Seed(s) for batch generation. Can be:
|
| 160 |
+
- None: Use random seeds (when use_random_seed=True) or params.seed (when use_random_seed=False)
|
| 161 |
+
- List[int]: List of seeds, will be padded with random seeds if fewer than batch_size
|
| 162 |
+
- int: Single seed value (will be converted to list and padded)
|
| 163 |
+
lm_batch_chunk_size: Batch chunk size for LM processing
|
| 164 |
+
constrained_decoding_debug: Whether to enable constrained decoding debug
|
| 165 |
+
audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac"
|
| 166 |
+
"""
|
| 167 |
+
batch_size: int = 2
|
| 168 |
+
allow_lm_batch: bool = False
|
| 169 |
+
use_random_seed: bool = True
|
| 170 |
+
seeds: Optional[List[int]] = None
|
| 171 |
+
lm_batch_chunk_size: int = 8
|
| 172 |
+
constrained_decoding_debug: bool = False
|
| 173 |
+
audio_format: str = "flac" # Default to FLAC for fast saving
|
| 174 |
+
|
| 175 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 176 |
+
"""Convert config to dictionary for JSON serialization."""
|
| 177 |
+
return asdict(self)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
@dataclass
|
| 181 |
+
class GenerationResult:
|
| 182 |
+
"""Result of music generation.
|
| 183 |
+
|
| 184 |
+
Attributes:
|
| 185 |
+
# Audio Outputs
|
| 186 |
+
audios: List of audio dictionaries with paths, keys, params
|
| 187 |
+
status_message: Status message from generation
|
| 188 |
+
extra_outputs: Extra outputs from generation
|
| 189 |
+
success: Whether generation completed successfully
|
| 190 |
+
error: Error message if generation failed
|
| 191 |
+
"""
|
| 192 |
+
|
| 193 |
+
# Audio Outputs
|
| 194 |
+
audios: List[Dict[str, Any]] = field(default_factory=list)
|
| 195 |
+
# Generation Information
|
| 196 |
+
status_message: str = ""
|
| 197 |
+
extra_outputs: Dict[str, Any] = field(default_factory=dict)
|
| 198 |
+
# Success Status
|
| 199 |
+
success: bool = True
|
| 200 |
+
error: Optional[str] = None
|
| 201 |
+
|
| 202 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 203 |
+
"""Convert result to dictionary for JSON serialization."""
|
| 204 |
+
return asdict(self)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
@dataclass
|
| 208 |
+
class UnderstandResult:
|
| 209 |
+
"""Result of music understanding from audio codes.
|
| 210 |
+
|
| 211 |
+
Attributes:
|
| 212 |
+
# Metadata Fields
|
| 213 |
+
caption: Generated caption describing the music
|
| 214 |
+
lyrics: Generated or extracted lyrics
|
| 215 |
+
bpm: Beats per minute (None if not detected)
|
| 216 |
+
duration: Duration in seconds (None if not detected)
|
| 217 |
+
keyscale: Musical key (e.g., "C Major")
|
| 218 |
+
language: Vocal language code (e.g., "en", "zh")
|
| 219 |
+
timesignature: Time signature (e.g., "4/4")
|
| 220 |
+
|
| 221 |
+
# Status
|
| 222 |
+
status_message: Status message from understanding
|
| 223 |
+
success: Whether understanding completed successfully
|
| 224 |
+
error: Error message if understanding failed
|
| 225 |
+
"""
|
| 226 |
+
# Metadata Fields
|
| 227 |
+
caption: str = ""
|
| 228 |
+
lyrics: str = ""
|
| 229 |
+
bpm: Optional[int] = None
|
| 230 |
+
duration: Optional[float] = None
|
| 231 |
+
keyscale: str = ""
|
| 232 |
+
language: str = ""
|
| 233 |
+
timesignature: str = ""
|
| 234 |
+
|
| 235 |
+
# Status
|
| 236 |
+
status_message: str = ""
|
| 237 |
+
success: bool = True
|
| 238 |
+
error: Optional[str] = None
|
| 239 |
+
|
| 240 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 241 |
+
"""Convert result to dictionary for JSON serialization."""
|
| 242 |
+
return asdict(self)
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
def _update_metadata_from_lm(
|
| 246 |
+
metadata: Dict[str, Any],
|
| 247 |
+
bpm: Optional[int],
|
| 248 |
+
key_scale: str,
|
| 249 |
+
time_signature: str,
|
| 250 |
+
audio_duration: Optional[float],
|
| 251 |
+
vocal_language: str,
|
| 252 |
+
caption: str,
|
| 253 |
+
lyrics: str,
|
| 254 |
+
) -> Tuple[Optional[int], str, str, Optional[float]]:
|
| 255 |
+
"""Update metadata fields from LM output if not provided by user."""
|
| 256 |
+
|
| 257 |
+
if bpm is None and metadata.get('bpm'):
|
| 258 |
+
bpm_value = metadata.get('bpm')
|
| 259 |
+
if bpm_value not in ["N/A", ""]:
|
| 260 |
+
try:
|
| 261 |
+
bpm = int(bpm_value)
|
| 262 |
+
except (ValueError, TypeError):
|
| 263 |
+
pass
|
| 264 |
+
|
| 265 |
+
if not key_scale and metadata.get('keyscale'):
|
| 266 |
+
key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
|
| 267 |
+
if key_scale_value != "N/A":
|
| 268 |
+
key_scale = key_scale_value
|
| 269 |
+
|
| 270 |
+
if not time_signature and metadata.get('timesignature'):
|
| 271 |
+
time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
|
| 272 |
+
if time_signature_value != "N/A":
|
| 273 |
+
time_signature = time_signature_value
|
| 274 |
+
|
| 275 |
+
if audio_duration is None or audio_duration <= 0:
|
| 276 |
+
audio_duration_value = metadata.get('duration', -1)
|
| 277 |
+
if audio_duration_value not in ["N/A", ""]:
|
| 278 |
+
try:
|
| 279 |
+
audio_duration = float(audio_duration_value)
|
| 280 |
+
except (ValueError, TypeError):
|
| 281 |
+
pass
|
| 282 |
+
|
| 283 |
+
if not vocal_language and metadata.get('vocal_language'):
|
| 284 |
+
vocal_language = metadata.get('vocal_language')
|
| 285 |
+
if not caption and metadata.get('caption'):
|
| 286 |
+
caption = metadata.get('caption')
|
| 287 |
+
if not lyrics and metadata.get('lyrics'):
|
| 288 |
+
lyrics = metadata.get('lyrics')
|
| 289 |
+
return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
@_get_spaces_gpu_decorator(duration=180)
|
| 293 |
+
def generate_music(
|
| 294 |
+
dit_handler,
|
| 295 |
+
llm_handler,
|
| 296 |
+
params: GenerationParams,
|
| 297 |
+
config: GenerationConfig,
|
| 298 |
+
save_dir: Optional[str] = None,
|
| 299 |
+
progress=None,
|
| 300 |
+
) -> GenerationResult:
|
| 301 |
+
"""Generate music using ACE-Step model with optional LM reasoning.
|
| 302 |
+
|
| 303 |
+
Args:
|
| 304 |
+
dit_handler: Initialized DiT model handler (AceStepHandler instance)
|
| 305 |
+
llm_handler: Initialized LLM handler (LLMHandler instance)
|
| 306 |
+
params: Generation parameters (GenerationParams instance)
|
| 307 |
+
config: Generation configuration (GenerationConfig instance)
|
| 308 |
+
|
| 309 |
+
Returns:
|
| 310 |
+
GenerationResult with generated audio files and metadata
|
| 311 |
+
"""
|
| 312 |
+
try:
|
| 313 |
+
# Phase 1: LM-based metadata and code generation (if enabled)
|
| 314 |
+
audio_code_string_to_use = params.audio_codes
|
| 315 |
+
lm_generated_metadata = None
|
| 316 |
+
lm_generated_audio_codes_list = []
|
| 317 |
+
lm_total_time_costs = {
|
| 318 |
+
"phase1_time": 0.0,
|
| 319 |
+
"phase2_time": 0.0,
|
| 320 |
+
"total_time": 0.0,
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
# Extract mutable copies of metadata (will be updated by LM if needed)
|
| 324 |
+
bpm = params.bpm
|
| 325 |
+
key_scale = params.keyscale
|
| 326 |
+
time_signature = params.timesignature
|
| 327 |
+
audio_duration = params.duration
|
| 328 |
+
dit_input_caption = params.caption
|
| 329 |
+
dit_input_vocal_language = params.vocal_language
|
| 330 |
+
dit_input_lyrics = params.lyrics
|
| 331 |
+
# Determine if we need to generate audio codes
|
| 332 |
+
# If user has provided audio_codes, we don't need to generate them
|
| 333 |
+
# Otherwise, check if we need audio codes (lm_dit mode) or just metas (dit mode)
|
| 334 |
+
user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip())
|
| 335 |
+
|
| 336 |
+
# Determine infer_type: use "llm_dit" if we need audio codes, "dit" if only metas needed
|
| 337 |
+
# For now, we use "llm_dit" if batch mode or if user hasn't provided codes
|
| 338 |
+
# Use "dit" if user has provided codes (only need metas) or if explicitly only need metas
|
| 339 |
+
# Note: This logic can be refined based on specific requirements
|
| 340 |
+
need_audio_codes = not user_provided_audio_codes
|
| 341 |
+
|
| 342 |
+
# Determine if we should use chunk-based LM generation (always use chunks for consistency)
|
| 343 |
+
# Determine actual batch size for chunk processing
|
| 344 |
+
actual_batch_size = config.batch_size if config.batch_size is not None else 1
|
| 345 |
+
|
| 346 |
+
# Prepare seeds for batch generation
|
| 347 |
+
# Use config.seed if provided, otherwise fallback to params.seed
|
| 348 |
+
# Convert config.seed (None, int, or List[int]) to format that prepare_seeds accepts
|
| 349 |
+
seed_for_generation = ""
|
| 350 |
+
if config.seeds is not None and len(config.seeds) > 0:
|
| 351 |
+
if isinstance(config.seeds, list):
|
| 352 |
+
# Convert List[int] to comma-separated string
|
| 353 |
+
seed_for_generation = ",".join(str(s) for s in config.seeds)
|
| 354 |
+
|
| 355 |
+
# Use dit_handler.prepare_seeds to handle seed list generation and padding
|
| 356 |
+
# This will handle all the logic: padding with random seeds if needed, etc.
|
| 357 |
+
actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)
|
| 358 |
+
|
| 359 |
+
# LM-based Chain-of-Thought reasoning
|
| 360 |
+
# Skip LM for cover/repaint tasks - these tasks use reference/src audio directly
|
| 361 |
+
# and don't need LM to generate audio codes
|
| 362 |
+
skip_lm_tasks = {"cover", "repaint"}
|
| 363 |
+
|
| 364 |
+
# Determine if we should use LLM
|
| 365 |
+
# LLM is needed for:
|
| 366 |
+
# 1. thinking=True: generate audio codes via LM
|
| 367 |
+
# 2. use_cot_caption=True: enhance/generate caption via CoT
|
| 368 |
+
# 3. use_cot_language=True: detect vocal language via CoT
|
| 369 |
+
# 4. use_cot_metas=True: fill missing metadata via CoT
|
| 370 |
+
need_lm_for_cot = params.use_cot_caption or params.use_cot_language or params.use_cot_metas
|
| 371 |
+
use_lm = (params.thinking or need_lm_for_cot) and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks
|
| 372 |
+
lm_status = []
|
| 373 |
+
|
| 374 |
+
if params.task_type in skip_lm_tasks:
|
| 375 |
+
logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly")
|
| 376 |
+
|
| 377 |
+
logger.info(f"[generate_music] LLM usage decision: thinking={params.thinking}, "
|
| 378 |
+
f"use_cot_caption={params.use_cot_caption}, use_cot_language={params.use_cot_language}, "
|
| 379 |
+
f"use_cot_metas={params.use_cot_metas}, need_lm_for_cot={need_lm_for_cot}, "
|
| 380 |
+
f"llm_initialized={llm_handler.llm_initialized if llm_handler else False}, use_lm={use_lm}")
|
| 381 |
+
|
| 382 |
+
if use_lm:
|
| 383 |
+
# Convert sampling parameters - handle None values safely
|
| 384 |
+
top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
|
| 385 |
+
top_p_value = None if not params.lm_top_p or params.lm_top_p >= 1.0 else params.lm_top_p
|
| 386 |
+
|
| 387 |
+
# Build user_metadata from user-provided values
|
| 388 |
+
user_metadata = {}
|
| 389 |
+
if bpm is not None:
|
| 390 |
+
try:
|
| 391 |
+
bpm_value = float(bpm)
|
| 392 |
+
if bpm_value > 0:
|
| 393 |
+
user_metadata['bpm'] = int(bpm_value)
|
| 394 |
+
except (ValueError, TypeError):
|
| 395 |
+
pass
|
| 396 |
+
|
| 397 |
+
if key_scale and key_scale.strip():
|
| 398 |
+
key_scale_clean = key_scale.strip()
|
| 399 |
+
if key_scale_clean.lower() not in ["n/a", ""]:
|
| 400 |
+
user_metadata['keyscale'] = key_scale_clean
|
| 401 |
+
|
| 402 |
+
if time_signature and time_signature.strip():
|
| 403 |
+
time_sig_clean = time_signature.strip()
|
| 404 |
+
if time_sig_clean.lower() not in ["n/a", ""]:
|
| 405 |
+
user_metadata['timesignature'] = time_sig_clean
|
| 406 |
+
|
| 407 |
+
if audio_duration is not None:
|
| 408 |
+
try:
|
| 409 |
+
duration_value = float(audio_duration)
|
| 410 |
+
if duration_value > 0:
|
| 411 |
+
user_metadata['duration'] = int(duration_value)
|
| 412 |
+
except (ValueError, TypeError):
|
| 413 |
+
pass
|
| 414 |
+
|
| 415 |
+
user_metadata_to_pass = user_metadata if user_metadata else None
|
| 416 |
+
|
| 417 |
+
# Determine infer_type based on whether we need audio codes
|
| 418 |
+
# - "llm_dit": generates both metas and audio codes (two-phase internally)
|
| 419 |
+
# - "dit": generates only metas (single phase)
|
| 420 |
+
infer_type = "llm_dit" if need_audio_codes and params.thinking else "dit"
|
| 421 |
+
|
| 422 |
+
# Use chunk size from config, or default to batch_size if not set
|
| 423 |
+
max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size
|
| 424 |
+
num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)
|
| 425 |
+
|
| 426 |
+
all_metadata_list = []
|
| 427 |
+
all_audio_codes_list = []
|
| 428 |
+
|
| 429 |
+
for chunk_idx in range(num_chunks):
|
| 430 |
+
chunk_start = chunk_idx * max_inference_batch_size
|
| 431 |
+
chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
|
| 432 |
+
chunk_size = chunk_end - chunk_start
|
| 433 |
+
chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None
|
| 434 |
+
|
| 435 |
+
logger.info(f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
|
| 436 |
+
f"(size: {chunk_size}, seeds: {chunk_seeds})")
|
| 437 |
+
|
| 438 |
+
# Use the determined infer_type
|
| 439 |
+
# - "llm_dit" will internally run two phases (metas + codes)
|
| 440 |
+
# - "dit" will only run phase 1 (metas only)
|
| 441 |
+
result = llm_handler.generate_with_stop_condition(
|
| 442 |
+
caption=params.caption or "",
|
| 443 |
+
lyrics=params.lyrics or "",
|
| 444 |
+
infer_type=infer_type,
|
| 445 |
+
temperature=params.lm_temperature,
|
| 446 |
+
cfg_scale=params.lm_cfg_scale,
|
| 447 |
+
negative_prompt=params.lm_negative_prompt,
|
| 448 |
+
top_k=top_k_value,
|
| 449 |
+
top_p=top_p_value,
|
| 450 |
+
user_metadata=user_metadata_to_pass,
|
| 451 |
+
use_cot_caption=params.use_cot_caption,
|
| 452 |
+
use_cot_language=params.use_cot_language,
|
| 453 |
+
use_cot_metas=params.use_cot_metas,
|
| 454 |
+
use_constrained_decoding=params.use_constrained_decoding,
|
| 455 |
+
constrained_decoding_debug=config.constrained_decoding_debug,
|
| 456 |
+
batch_size=chunk_size,
|
| 457 |
+
seeds=chunk_seeds,
|
| 458 |
+
progress=progress,
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
# Check if LM generation failed
|
| 462 |
+
if not result.get("success", False):
|
| 463 |
+
error_msg = result.get("error", "Unknown LM error")
|
| 464 |
+
lm_status.append(f"❌ LM Error: {error_msg}")
|
| 465 |
+
# Return early with error
|
| 466 |
+
return GenerationResult(
|
| 467 |
+
audios=[],
|
| 468 |
+
status_message=f"❌ LM generation failed: {error_msg}",
|
| 469 |
+
extra_outputs={},
|
| 470 |
+
success=False,
|
| 471 |
+
error=error_msg,
|
| 472 |
+
)
|
| 473 |
+
|
| 474 |
+
# Extract metadata and audio_codes from result dict
|
| 475 |
+
if chunk_size > 1:
|
| 476 |
+
metadata_list = result.get("metadata", [])
|
| 477 |
+
audio_codes_list = result.get("audio_codes", [])
|
| 478 |
+
all_metadata_list.extend(metadata_list)
|
| 479 |
+
all_audio_codes_list.extend(audio_codes_list)
|
| 480 |
+
else:
|
| 481 |
+
metadata = result.get("metadata", {})
|
| 482 |
+
audio_codes = result.get("audio_codes", "")
|
| 483 |
+
all_metadata_list.append(metadata)
|
| 484 |
+
all_audio_codes_list.append(audio_codes)
|
| 485 |
+
|
| 486 |
+
# Collect time costs from LM extra_outputs
|
| 487 |
+
lm_extra = result.get("extra_outputs", {})
|
| 488 |
+
lm_chunk_time_costs = lm_extra.get("time_costs", {})
|
| 489 |
+
if lm_chunk_time_costs:
|
| 490 |
+
# Accumulate time costs from all chunks
|
| 491 |
+
for key in ["phase1_time", "phase2_time", "total_time"]:
|
| 492 |
+
if key in lm_chunk_time_costs:
|
| 493 |
+
lm_total_time_costs[key] += lm_chunk_time_costs[key]
|
| 494 |
+
|
| 495 |
+
time_str = ", ".join([f"{k}: {v:.2f}s" for k, v in lm_chunk_time_costs.items()])
|
| 496 |
+
lm_status.append(f"✅ LM chunk {chunk_idx+1}: {time_str}")
|
| 497 |
+
|
| 498 |
+
lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
|
| 499 |
+
lm_generated_audio_codes_list = all_audio_codes_list
|
| 500 |
+
|
| 501 |
+
# Set audio_code_string_to_use based on infer_type
|
| 502 |
+
if infer_type == "llm_dit":
|
| 503 |
+
# If batch mode, use list; otherwise use single string
|
| 504 |
+
if actual_batch_size > 1:
|
| 505 |
+
audio_code_string_to_use = all_audio_codes_list
|
| 506 |
+
else:
|
| 507 |
+
audio_code_string_to_use = all_audio_codes_list[0] if all_audio_codes_list else ""
|
| 508 |
+
else:
|
| 509 |
+
# For "dit" mode, keep user-provided codes or empty
|
| 510 |
+
audio_code_string_to_use = params.audio_codes
|
| 511 |
+
|
| 512 |
+
# Update metadata from LM if not provided by user
|
| 513 |
+
if lm_generated_metadata:
|
| 514 |
+
bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics = _update_metadata_from_lm(
|
| 515 |
+
metadata=lm_generated_metadata,
|
| 516 |
+
bpm=bpm,
|
| 517 |
+
key_scale=key_scale,
|
| 518 |
+
time_signature=time_signature,
|
| 519 |
+
audio_duration=audio_duration,
|
| 520 |
+
vocal_language=dit_input_vocal_language,
|
| 521 |
+
caption=dit_input_caption,
|
| 522 |
+
lyrics=dit_input_lyrics)
|
| 523 |
+
if not params.bpm:
|
| 524 |
+
params.cot_bpm = bpm
|
| 525 |
+
if not params.keyscale:
|
| 526 |
+
params.cot_keyscale = key_scale
|
| 527 |
+
if not params.timesignature:
|
| 528 |
+
params.cot_timesignature = time_signature
|
| 529 |
+
if not params.duration:
|
| 530 |
+
params.cot_duration = audio_duration
|
| 531 |
+
if not params.vocal_language:
|
| 532 |
+
params.cot_vocal_language = vocal_language
|
| 533 |
+
if not params.caption:
|
| 534 |
+
params.cot_caption = caption
|
| 535 |
+
if not params.lyrics:
|
| 536 |
+
params.cot_lyrics = lyrics
|
| 537 |
+
|
| 538 |
+
# set cot caption and language if needed
|
| 539 |
+
if params.use_cot_caption:
|
| 540 |
+
dit_input_caption = lm_generated_metadata.get("caption", dit_input_caption)
|
| 541 |
+
if params.use_cot_language:
|
| 542 |
+
dit_input_vocal_language = lm_generated_metadata.get("vocal_language", dit_input_vocal_language)
|
| 543 |
+
|
| 544 |
+
# Phase 2: DiT music generation
|
| 545 |
+
# Use seed_for_generation (from config.seed or params.seed) instead of params.seed for actual generation
|
| 546 |
+
result = dit_handler.generate_music(
|
| 547 |
+
captions=dit_input_caption,
|
| 548 |
+
lyrics=dit_input_lyrics,
|
| 549 |
+
bpm=bpm,
|
| 550 |
+
key_scale=key_scale,
|
| 551 |
+
time_signature=time_signature,
|
| 552 |
+
vocal_language=dit_input_vocal_language,
|
| 553 |
+
inference_steps=params.inference_steps,
|
| 554 |
+
guidance_scale=params.guidance_scale,
|
| 555 |
+
use_random_seed=config.use_random_seed,
|
| 556 |
+
seed=seed_for_generation, # Use config.seed (or params.seed fallback) instead of params.seed directly
|
| 557 |
+
reference_audio=params.reference_audio,
|
| 558 |
+
audio_duration=audio_duration,
|
| 559 |
+
batch_size=config.batch_size if config.batch_size is not None else 1,
|
| 560 |
+
src_audio=params.src_audio,
|
| 561 |
+
audio_code_string=audio_code_string_to_use,
|
| 562 |
+
repainting_start=params.repainting_start,
|
| 563 |
+
repainting_end=params.repainting_end,
|
| 564 |
+
instruction=params.instruction,
|
| 565 |
+
audio_cover_strength=params.audio_cover_strength,
|
| 566 |
+
task_type=params.task_type,
|
| 567 |
+
use_adg=params.use_adg,
|
| 568 |
+
cfg_interval_start=params.cfg_interval_start,
|
| 569 |
+
cfg_interval_end=params.cfg_interval_end,
|
| 570 |
+
shift=params.shift,
|
| 571 |
+
infer_method=params.infer_method,
|
| 572 |
+
timesteps=params.timesteps,
|
| 573 |
+
progress=progress,
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
# Check if generation failed
|
| 577 |
+
if not result.get("success", False):
|
| 578 |
+
return GenerationResult(
|
| 579 |
+
audios=[],
|
| 580 |
+
status_message=result.get("status_message", ""),
|
| 581 |
+
extra_outputs={},
|
| 582 |
+
success=False,
|
| 583 |
+
error=result.get("error"),
|
| 584 |
+
)
|
| 585 |
+
|
| 586 |
+
# Extract results from dit_handler.generate_music dict
|
| 587 |
+
dit_audios = result.get("audios", [])
|
| 588 |
+
status_message = result.get("status_message", "")
|
| 589 |
+
dit_extra_outputs = result.get("extra_outputs", {})
|
| 590 |
+
|
| 591 |
+
# Use the seed list already prepared above (from config.seed or params.seed fallback)
|
| 592 |
+
# actual_seed_list was computed earlier using dit_handler.prepare_seeds
|
| 593 |
+
seed_list = actual_seed_list
|
| 594 |
+
|
| 595 |
+
# Get base params dictionary
|
| 596 |
+
base_params_dict = params.to_dict()
|
| 597 |
+
|
| 598 |
+
# Save audio files using AudioSaver (format from config)
|
| 599 |
+
audio_format = config.audio_format if config.audio_format else "flac"
|
| 600 |
+
audio_saver = AudioSaver(default_format=audio_format)
|
| 601 |
+
|
| 602 |
+
# Use handler's temp_dir for saving files
|
| 603 |
+
if save_dir is not None:
|
| 604 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 605 |
+
|
| 606 |
+
# Build audios list for GenerationResult with params and save files
|
| 607 |
+
# Audio saving and UUID generation handled here, outside of handler
|
| 608 |
+
audios = []
|
| 609 |
+
for idx, dit_audio in enumerate(dit_audios):
|
| 610 |
+
# Create a copy of params dict for this audio
|
| 611 |
+
audio_params = base_params_dict.copy()
|
| 612 |
+
|
| 613 |
+
# Update audio-specific values
|
| 614 |
+
audio_params["seed"] = seed_list[idx] if idx < len(seed_list) else None
|
| 615 |
+
|
| 616 |
+
# Add audio codes if batch mode
|
| 617 |
+
if lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list):
|
| 618 |
+
audio_params["audio_codes"] = lm_generated_audio_codes_list[idx]
|
| 619 |
+
|
| 620 |
+
# Get audio tensor and metadata
|
| 621 |
+
audio_tensor = dit_audio.get("tensor")
|
| 622 |
+
sample_rate = dit_audio.get("sample_rate", 48000)
|
| 623 |
+
|
| 624 |
+
# Generate UUID for this audio (moved from handler)
|
| 625 |
+
batch_seed = seed_list[idx] if idx < len(seed_list) else seed_list[0] if seed_list else -1
|
| 626 |
+
audio_code_str = lm_generated_audio_codes_list[idx] if (
|
| 627 |
+
lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list)) else audio_code_string_to_use
|
| 628 |
+
if isinstance(audio_code_str, list):
|
| 629 |
+
audio_code_str = audio_code_str[idx] if idx < len(audio_code_str) else ""
|
| 630 |
+
|
| 631 |
+
audio_key = generate_uuid_from_params(audio_params)
|
| 632 |
+
|
| 633 |
+
# Save audio file (handled outside handler)
|
| 634 |
+
audio_path = None
|
| 635 |
+
if audio_tensor is not None and save_dir is not None:
|
| 636 |
+
try:
|
| 637 |
+
audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}")
|
| 638 |
+
audio_path = audio_saver.save_audio(audio_tensor,
|
| 639 |
+
audio_file,
|
| 640 |
+
sample_rate=sample_rate,
|
| 641 |
+
format=audio_format,
|
| 642 |
+
channels_first=True)
|
| 643 |
+
except Exception as e:
|
| 644 |
+
logger.error(f"[generate_music] Failed to save audio file: {e}")
|
| 645 |
+
audio_path = "" # Fallback to empty path
|
| 646 |
+
|
| 647 |
+
audio_dict = {
|
| 648 |
+
"path": audio_path or "", # File path (saved here, not in handler)
|
| 649 |
+
"tensor": audio_tensor, # Audio tensor [channels, samples], CPU, float32
|
| 650 |
+
"key": audio_key,
|
| 651 |
+
"sample_rate": sample_rate,
|
| 652 |
+
"params": audio_params,
|
| 653 |
+
}
|
| 654 |
+
|
| 655 |
+
audios.append(audio_dict)
|
| 656 |
+
|
| 657 |
+
# Merge extra_outputs: include dit_extra_outputs (latents, masks) and add LM metadata
|
| 658 |
+
extra_outputs = dit_extra_outputs.copy()
|
| 659 |
+
extra_outputs["lm_metadata"] = lm_generated_metadata
|
| 660 |
+
|
| 661 |
+
# Merge time_costs from both LM and DiT into a unified dictionary
|
| 662 |
+
unified_time_costs = {}
|
| 663 |
+
|
| 664 |
+
# Add LM time costs (if LM was used)
|
| 665 |
+
if use_lm and lm_total_time_costs:
|
| 666 |
+
for key, value in lm_total_time_costs.items():
|
| 667 |
+
unified_time_costs[f"lm_{key}"] = value
|
| 668 |
+
|
| 669 |
+
# Add DiT time costs (if available)
|
| 670 |
+
dit_time_costs = dit_extra_outputs.get("time_costs", {})
|
| 671 |
+
if dit_time_costs:
|
| 672 |
+
for key, value in dit_time_costs.items():
|
| 673 |
+
unified_time_costs[f"dit_{key}"] = value
|
| 674 |
+
|
| 675 |
+
# Calculate total pipeline time
|
| 676 |
+
if unified_time_costs:
|
| 677 |
+
lm_total = unified_time_costs.get("lm_total_time", 0.0)
|
| 678 |
+
dit_total = unified_time_costs.get("dit_total_time_cost", 0.0)
|
| 679 |
+
unified_time_costs["pipeline_total_time"] = lm_total + dit_total
|
| 680 |
+
|
| 681 |
+
# Update extra_outputs with unified time_costs
|
| 682 |
+
extra_outputs["time_costs"] = unified_time_costs
|
| 683 |
+
|
| 684 |
+
if lm_status:
|
| 685 |
+
status_message = "\n".join(lm_status) + "\n" + status_message
|
| 686 |
+
else:
|
| 687 |
+
status_message = status_message
|
| 688 |
+
# Create and return GenerationResult
|
| 689 |
+
return GenerationResult(
|
| 690 |
+
audios=audios,
|
| 691 |
+
status_message=status_message,
|
| 692 |
+
extra_outputs=extra_outputs,
|
| 693 |
+
success=True,
|
| 694 |
+
error=None,
|
| 695 |
+
)
|
| 696 |
+
|
| 697 |
+
except Exception as e:
|
| 698 |
+
logger.exception("Music generation failed")
|
| 699 |
+
return GenerationResult(
|
| 700 |
+
audios=[],
|
| 701 |
+
status_message=f"Error: {str(e)}",
|
| 702 |
+
extra_outputs={},
|
| 703 |
+
success=False,
|
| 704 |
+
error=str(e),
|
| 705 |
+
)
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
def understand_music(
|
| 709 |
+
llm_handler,
|
| 710 |
+
audio_codes: str,
|
| 711 |
+
temperature: float = 0.85,
|
| 712 |
+
top_k: Optional[int] = None,
|
| 713 |
+
top_p: Optional[float] = None,
|
| 714 |
+
repetition_penalty: float = 1.0,
|
| 715 |
+
use_constrained_decoding: bool = True,
|
| 716 |
+
constrained_decoding_debug: bool = False,
|
| 717 |
+
) -> UnderstandResult:
|
| 718 |
+
"""Understand music from audio codes using the 5Hz Language Model.
|
| 719 |
+
|
| 720 |
+
This function analyzes audio semantic codes and generates metadata about the music,
|
| 721 |
+
including caption, lyrics, BPM, duration, key scale, language, and time signature.
|
| 722 |
+
|
| 723 |
+
If audio_codes is empty or "NO USER INPUT", the LM will generate a sample example
|
| 724 |
+
instead of analyzing existing codes.
|
| 725 |
+
|
| 726 |
+
Note: cfg_scale and negative_prompt are not supported in understand mode.
|
| 727 |
+
|
| 728 |
+
Args:
|
| 729 |
+
llm_handler: Initialized LLM handler (LLMHandler instance)
|
| 730 |
+
audio_codes: String of audio code tokens (e.g., "<|audio_code_123|><|audio_code_456|>...")
|
| 731 |
+
Use empty string or "NO USER INPUT" to generate a sample example.
|
| 732 |
+
temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
|
| 733 |
+
top_k: Top-K sampling (None or 0 = disabled)
|
| 734 |
+
top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
|
| 735 |
+
repetition_penalty: Repetition penalty (1.0 = no penalty)
|
| 736 |
+
use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
|
| 737 |
+
constrained_decoding_debug: Whether to enable debug logging for constrained decoding
|
| 738 |
+
|
| 739 |
+
Returns:
|
| 740 |
+
UnderstandResult with parsed metadata fields and status
|
| 741 |
+
|
| 742 |
+
Example:
|
| 743 |
+
>>> result = understand_music(llm_handler, audio_codes="<|audio_code_123|>...")
|
| 744 |
+
>>> if result.success:
|
| 745 |
+
... print(f"Caption: {result.caption}")
|
| 746 |
+
... print(f"BPM: {result.bpm}")
|
| 747 |
+
... print(f"Lyrics: {result.lyrics}")
|
| 748 |
+
"""
|
| 749 |
+
# Check if LLM is initialized
|
| 750 |
+
if not llm_handler.llm_initialized:
|
| 751 |
+
return UnderstandResult(
|
| 752 |
+
status_message="5Hz LM not initialized. Please initialize it first.",
|
| 753 |
+
success=False,
|
| 754 |
+
error="LLM not initialized",
|
| 755 |
+
)
|
| 756 |
+
|
| 757 |
+
# If codes are empty, use "NO USER INPUT" to generate a sample example
|
| 758 |
+
if not audio_codes or not audio_codes.strip():
|
| 759 |
+
audio_codes = "NO USER INPUT"
|
| 760 |
+
|
| 761 |
+
try:
|
| 762 |
+
# Call LLM understanding
|
| 763 |
+
metadata, status = llm_handler.understand_audio_from_codes(
|
| 764 |
+
audio_codes=audio_codes,
|
| 765 |
+
temperature=temperature,
|
| 766 |
+
top_k=top_k,
|
| 767 |
+
top_p=top_p,
|
| 768 |
+
repetition_penalty=repetition_penalty,
|
| 769 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 770 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 771 |
+
)
|
| 772 |
+
|
| 773 |
+
# Check if LLM returned empty metadata (error case)
|
| 774 |
+
if not metadata:
|
| 775 |
+
return UnderstandResult(
|
| 776 |
+
status_message=status or "Failed to understand audio codes",
|
| 777 |
+
success=False,
|
| 778 |
+
error=status or "Empty metadata returned",
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
# Extract and convert fields
|
| 782 |
+
caption = metadata.get('caption', '')
|
| 783 |
+
lyrics = metadata.get('lyrics', '')
|
| 784 |
+
keyscale = metadata.get('keyscale', '')
|
| 785 |
+
language = metadata.get('language', metadata.get('vocal_language', ''))
|
| 786 |
+
timesignature = metadata.get('timesignature', '')
|
| 787 |
+
|
| 788 |
+
# Convert BPM to int
|
| 789 |
+
bpm = None
|
| 790 |
+
bpm_value = metadata.get('bpm')
|
| 791 |
+
if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
|
| 792 |
+
try:
|
| 793 |
+
bpm = int(bpm_value)
|
| 794 |
+
except (ValueError, TypeError):
|
| 795 |
+
pass
|
| 796 |
+
|
| 797 |
+
# Convert duration to float
|
| 798 |
+
duration = None
|
| 799 |
+
duration_value = metadata.get('duration')
|
| 800 |
+
if duration_value is not None and duration_value != 'N/A' and duration_value != '':
|
| 801 |
+
try:
|
| 802 |
+
duration = float(duration_value)
|
| 803 |
+
except (ValueError, TypeError):
|
| 804 |
+
pass
|
| 805 |
+
|
| 806 |
+
# Clean up N/A values
|
| 807 |
+
if keyscale == 'N/A':
|
| 808 |
+
keyscale = ''
|
| 809 |
+
if language == 'N/A':
|
| 810 |
+
language = ''
|
| 811 |
+
if timesignature == 'N/A':
|
| 812 |
+
timesignature = ''
|
| 813 |
+
|
| 814 |
+
return UnderstandResult(
|
| 815 |
+
caption=caption,
|
| 816 |
+
lyrics=lyrics,
|
| 817 |
+
bpm=bpm,
|
| 818 |
+
duration=duration,
|
| 819 |
+
keyscale=keyscale,
|
| 820 |
+
language=language,
|
| 821 |
+
timesignature=timesignature,
|
| 822 |
+
status_message=status,
|
| 823 |
+
success=True,
|
| 824 |
+
error=None,
|
| 825 |
+
)
|
| 826 |
+
|
| 827 |
+
except Exception as e:
|
| 828 |
+
logger.exception("Music understanding failed")
|
| 829 |
+
return UnderstandResult(
|
| 830 |
+
status_message=f"Error: {str(e)}",
|
| 831 |
+
success=False,
|
| 832 |
+
error=str(e),
|
| 833 |
+
)
|
| 834 |
+
|
| 835 |
+
|
| 836 |
+
@dataclass
|
| 837 |
+
class CreateSampleResult:
|
| 838 |
+
"""Result of creating a music sample from a natural language query.
|
| 839 |
+
|
| 840 |
+
This is used by the "Simple Mode" / "Inspiration Mode" feature where users
|
| 841 |
+
provide a natural language description and the LLM generates a complete
|
| 842 |
+
sample with caption, lyrics, and metadata.
|
| 843 |
+
|
| 844 |
+
Attributes:
|
| 845 |
+
# Metadata Fields
|
| 846 |
+
caption: Generated detailed music description/caption
|
| 847 |
+
lyrics: Generated lyrics (or "[Instrumental]" for instrumental music)
|
| 848 |
+
bpm: Beats per minute (None if not generated)
|
| 849 |
+
duration: Duration in seconds (None if not generated)
|
| 850 |
+
keyscale: Musical key (e.g., "C Major")
|
| 851 |
+
language: Vocal language code (e.g., "en", "zh")
|
| 852 |
+
timesignature: Time signature (e.g., "4")
|
| 853 |
+
instrumental: Whether this is an instrumental piece
|
| 854 |
+
|
| 855 |
+
# Status
|
| 856 |
+
status_message: Status message from sample creation
|
| 857 |
+
success: Whether sample creation completed successfully
|
| 858 |
+
error: Error message if sample creation failed
|
| 859 |
+
"""
|
| 860 |
+
# Metadata Fields
|
| 861 |
+
caption: str = ""
|
| 862 |
+
lyrics: str = ""
|
| 863 |
+
bpm: Optional[int] = None
|
| 864 |
+
duration: Optional[float] = None
|
| 865 |
+
keyscale: str = ""
|
| 866 |
+
language: str = ""
|
| 867 |
+
timesignature: str = ""
|
| 868 |
+
instrumental: bool = False
|
| 869 |
+
|
| 870 |
+
# Status
|
| 871 |
+
status_message: str = ""
|
| 872 |
+
success: bool = True
|
| 873 |
+
error: Optional[str] = None
|
| 874 |
+
|
| 875 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 876 |
+
"""Convert result to dictionary for JSON serialization."""
|
| 877 |
+
return asdict(self)
|
| 878 |
+
|
| 879 |
+
|
| 880 |
+
def create_sample(
|
| 881 |
+
llm_handler,
|
| 882 |
+
query: str,
|
| 883 |
+
instrumental: bool = False,
|
| 884 |
+
vocal_language: Optional[str] = None,
|
| 885 |
+
temperature: float = 0.85,
|
| 886 |
+
top_k: Optional[int] = None,
|
| 887 |
+
top_p: Optional[float] = None,
|
| 888 |
+
repetition_penalty: float = 1.0,
|
| 889 |
+
use_constrained_decoding: bool = True,
|
| 890 |
+
constrained_decoding_debug: bool = False,
|
| 891 |
+
) -> CreateSampleResult:
|
| 892 |
+
"""Create a music sample from a natural language query using the 5Hz Language Model.
|
| 893 |
+
|
| 894 |
+
This is the "Simple Mode" / "Inspiration Mode" feature that takes a user's natural
|
| 895 |
+
language description of music and generates a complete sample including:
|
| 896 |
+
- Detailed caption/description
|
| 897 |
+
- Lyrics (unless instrumental)
|
| 898 |
+
- Metadata (BPM, duration, key, language, time signature)
|
| 899 |
+
|
| 900 |
+
Note: cfg_scale and negative_prompt are not supported in create_sample mode.
|
| 901 |
+
|
| 902 |
+
Args:
|
| 903 |
+
llm_handler: Initialized LLM handler (LLMHandler instance)
|
| 904 |
+
query: User's natural language music description (e.g., "a soft Bengali love song")
|
| 905 |
+
instrumental: Whether to generate instrumental music (no vocals)
|
| 906 |
+
vocal_language: Allowed vocal language for constrained decoding (e.g., "en", "zh").
|
| 907 |
+
If provided, the model will be constrained to generate lyrics in this language.
|
| 908 |
+
If None or "unknown", no language constraint is applied.
|
| 909 |
+
temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
|
| 910 |
+
top_k: Top-K sampling (None or 0 = disabled)
|
| 911 |
+
top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
|
| 912 |
+
repetition_penalty: Repetition penalty (1.0 = no penalty)
|
| 913 |
+
use_constrained_decoding: Whether to use FSM-based constrained decoding
|
| 914 |
+
constrained_decoding_debug: Whether to enable debug logging
|
| 915 |
+
|
| 916 |
+
Returns:
|
| 917 |
+
CreateSampleResult with generated sample fields and status
|
| 918 |
+
|
| 919 |
+
Example:
|
| 920 |
+
>>> result = create_sample(llm_handler, "a soft Bengali love song for a quiet evening", vocal_language="bn")
|
| 921 |
+
>>> if result.success:
|
| 922 |
+
... print(f"Caption: {result.caption}")
|
| 923 |
+
... print(f"Lyrics: {result.lyrics}")
|
| 924 |
+
... print(f"BPM: {result.bpm}")
|
| 925 |
+
"""
|
| 926 |
+
# Check if LLM is initialized
|
| 927 |
+
if not llm_handler.llm_initialized:
|
| 928 |
+
return CreateSampleResult(
|
| 929 |
+
status_message="5Hz LM not initialized. Please initialize it first.",
|
| 930 |
+
success=False,
|
| 931 |
+
error="LLM not initialized",
|
| 932 |
+
)
|
| 933 |
+
|
| 934 |
+
try:
|
| 935 |
+
# Call LLM to create sample
|
| 936 |
+
metadata, status = llm_handler.create_sample_from_query(
|
| 937 |
+
query=query,
|
| 938 |
+
instrumental=instrumental,
|
| 939 |
+
vocal_language=vocal_language,
|
| 940 |
+
temperature=temperature,
|
| 941 |
+
top_k=top_k,
|
| 942 |
+
top_p=top_p,
|
| 943 |
+
repetition_penalty=repetition_penalty,
|
| 944 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 945 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 946 |
+
)
|
| 947 |
+
|
| 948 |
+
# Check if LLM returned empty metadata (error case)
|
| 949 |
+
if not metadata:
|
| 950 |
+
return CreateSampleResult(
|
| 951 |
+
status_message=status or "Failed to create sample",
|
| 952 |
+
success=False,
|
| 953 |
+
error=status or "Empty metadata returned",
|
| 954 |
+
)
|
| 955 |
+
|
| 956 |
+
# Extract and convert fields
|
| 957 |
+
caption = metadata.get('caption', '')
|
| 958 |
+
lyrics = metadata.get('lyrics', '')
|
| 959 |
+
keyscale = metadata.get('keyscale', '')
|
| 960 |
+
language = metadata.get('language', metadata.get('vocal_language', ''))
|
| 961 |
+
timesignature = metadata.get('timesignature', '')
|
| 962 |
+
is_instrumental = metadata.get('instrumental', instrumental)
|
| 963 |
+
|
| 964 |
+
# Convert BPM to int
|
| 965 |
+
bpm = None
|
| 966 |
+
bpm_value = metadata.get('bpm')
|
| 967 |
+
if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
|
| 968 |
+
try:
|
| 969 |
+
bpm = int(bpm_value)
|
| 970 |
+
except (ValueError, TypeError):
|
| 971 |
+
pass
|
| 972 |
+
|
| 973 |
+
# Convert duration to float
|
| 974 |
+
duration = None
|
| 975 |
+
duration_value = metadata.get('duration')
|
| 976 |
+
if duration_value is not None and duration_value != 'N/A' and duration_value != '':
|
| 977 |
+
try:
|
| 978 |
+
duration = float(duration_value)
|
| 979 |
+
except (ValueError, TypeError):
|
| 980 |
+
pass
|
| 981 |
+
|
| 982 |
+
# Clean up N/A values
|
| 983 |
+
if keyscale == 'N/A':
|
| 984 |
+
keyscale = ''
|
| 985 |
+
if language == 'N/A':
|
| 986 |
+
language = ''
|
| 987 |
+
if timesignature == 'N/A':
|
| 988 |
+
timesignature = ''
|
| 989 |
+
|
| 990 |
+
return CreateSampleResult(
|
| 991 |
+
caption=caption,
|
| 992 |
+
lyrics=lyrics,
|
| 993 |
+
bpm=bpm,
|
| 994 |
+
duration=duration,
|
| 995 |
+
keyscale=keyscale,
|
| 996 |
+
language=language,
|
| 997 |
+
timesignature=timesignature,
|
| 998 |
+
instrumental=is_instrumental,
|
| 999 |
+
status_message=status,
|
| 1000 |
+
success=True,
|
| 1001 |
+
error=None,
|
| 1002 |
+
)
|
| 1003 |
+
|
| 1004 |
+
except Exception as e:
|
| 1005 |
+
logger.exception("Sample creation failed")
|
| 1006 |
+
return CreateSampleResult(
|
| 1007 |
+
status_message=f"Error: {str(e)}",
|
| 1008 |
+
success=False,
|
| 1009 |
+
error=str(e),
|
| 1010 |
+
)
|
| 1011 |
+
|
| 1012 |
+
|
| 1013 |
+
@dataclass
|
| 1014 |
+
class FormatSampleResult:
|
| 1015 |
+
"""Result of formatting user-provided caption and lyrics.
|
| 1016 |
+
|
| 1017 |
+
This is used by the "Format" feature where users provide caption and lyrics,
|
| 1018 |
+
and the LLM formats them into structured music metadata and an enhanced description.
|
| 1019 |
+
|
| 1020 |
+
Attributes:
|
| 1021 |
+
# Metadata Fields
|
| 1022 |
+
caption: Enhanced/formatted music description/caption
|
| 1023 |
+
lyrics: Formatted lyrics (may be same as input or reformatted)
|
| 1024 |
+
bpm: Beats per minute (None if not detected)
|
| 1025 |
+
duration: Duration in seconds (None if not detected)
|
| 1026 |
+
keyscale: Musical key (e.g., "C Major")
|
| 1027 |
+
language: Vocal language code (e.g., "en", "zh")
|
| 1028 |
+
timesignature: Time signature (e.g., "4")
|
| 1029 |
+
|
| 1030 |
+
# Status
|
| 1031 |
+
status_message: Status message from formatting
|
| 1032 |
+
success: Whether formatting completed successfully
|
| 1033 |
+
error: Error message if formatting failed
|
| 1034 |
+
"""
|
| 1035 |
+
# Metadata Fields
|
| 1036 |
+
caption: str = ""
|
| 1037 |
+
lyrics: str = ""
|
| 1038 |
+
bpm: Optional[int] = None
|
| 1039 |
+
duration: Optional[float] = None
|
| 1040 |
+
keyscale: str = ""
|
| 1041 |
+
language: str = ""
|
| 1042 |
+
timesignature: str = ""
|
| 1043 |
+
|
| 1044 |
+
# Status
|
| 1045 |
+
status_message: str = ""
|
| 1046 |
+
success: bool = True
|
| 1047 |
+
error: Optional[str] = None
|
| 1048 |
+
|
| 1049 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 1050 |
+
"""Convert result to dictionary for JSON serialization."""
|
| 1051 |
+
return asdict(self)
|
| 1052 |
+
|
| 1053 |
+
|
| 1054 |
+
def format_sample(
|
| 1055 |
+
llm_handler,
|
| 1056 |
+
caption: str,
|
| 1057 |
+
lyrics: str,
|
| 1058 |
+
user_metadata: Optional[Dict[str, Any]] = None,
|
| 1059 |
+
temperature: float = 0.85,
|
| 1060 |
+
top_k: Optional[int] = None,
|
| 1061 |
+
top_p: Optional[float] = None,
|
| 1062 |
+
repetition_penalty: float = 1.0,
|
| 1063 |
+
use_constrained_decoding: bool = True,
|
| 1064 |
+
constrained_decoding_debug: bool = False,
|
| 1065 |
+
) -> FormatSampleResult:
|
| 1066 |
+
"""Format user-provided caption and lyrics using the 5Hz Language Model.
|
| 1067 |
+
|
| 1068 |
+
This function takes user input (caption and lyrics) and generates structured
|
| 1069 |
+
music metadata including an enhanced caption, BPM, duration, key, language,
|
| 1070 |
+
and time signature.
|
| 1071 |
+
|
| 1072 |
+
If user_metadata is provided, those values will be used to constrain the
|
| 1073 |
+
decoding, ensuring the output matches user-specified values.
|
| 1074 |
+
|
| 1075 |
+
Note: cfg_scale and negative_prompt are not supported in format mode.
|
| 1076 |
+
|
| 1077 |
+
Args:
|
| 1078 |
+
llm_handler: Initialized LLM handler (LLMHandler instance)
|
| 1079 |
+
caption: User's caption/description (e.g., "Latin pop, reggaeton")
|
| 1080 |
+
lyrics: User's lyrics with structure tags
|
| 1081 |
+
user_metadata: Optional dict with user-provided metadata to constrain decoding.
|
| 1082 |
+
Supported keys: bpm, duration, keyscale, timesignature, language
|
| 1083 |
+
temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
|
| 1084 |
+
top_k: Top-K sampling (None or 0 = disabled)
|
| 1085 |
+
top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
|
| 1086 |
+
repetition_penalty: Repetition penalty (1.0 = no penalty)
|
| 1087 |
+
use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
|
| 1088 |
+
constrained_decoding_debug: Whether to enable debug logging for constrained decoding
|
| 1089 |
+
|
| 1090 |
+
Returns:
|
| 1091 |
+
FormatSampleResult with formatted metadata fields and status
|
| 1092 |
+
|
| 1093 |
+
Example:
|
| 1094 |
+
>>> result = format_sample(llm_handler, "Latin pop, reggaeton", "[Verse 1]\\nHola mundo...")
|
| 1095 |
+
>>> if result.success:
|
| 1096 |
+
... print(f"Caption: {result.caption}")
|
| 1097 |
+
... print(f"BPM: {result.bpm}")
|
| 1098 |
+
... print(f"Lyrics: {result.lyrics}")
|
| 1099 |
+
"""
|
| 1100 |
+
# Check if LLM is initialized
|
| 1101 |
+
if not llm_handler.llm_initialized:
|
| 1102 |
+
return FormatSampleResult(
|
| 1103 |
+
status_message="5Hz LM not initialized. Please initialize it first.",
|
| 1104 |
+
success=False,
|
| 1105 |
+
error="LLM not initialized",
|
| 1106 |
+
)
|
| 1107 |
+
|
| 1108 |
+
try:
|
| 1109 |
+
# Call LLM formatting
|
| 1110 |
+
metadata, status = llm_handler.format_sample_from_input(
|
| 1111 |
+
caption=caption,
|
| 1112 |
+
lyrics=lyrics,
|
| 1113 |
+
user_metadata=user_metadata,
|
| 1114 |
+
temperature=temperature,
|
| 1115 |
+
top_k=top_k,
|
| 1116 |
+
top_p=top_p,
|
| 1117 |
+
repetition_penalty=repetition_penalty,
|
| 1118 |
+
use_constrained_decoding=use_constrained_decoding,
|
| 1119 |
+
constrained_decoding_debug=constrained_decoding_debug,
|
| 1120 |
+
)
|
| 1121 |
+
|
| 1122 |
+
# Check if LLM returned empty metadata (error case)
|
| 1123 |
+
if not metadata:
|
| 1124 |
+
return FormatSampleResult(
|
| 1125 |
+
status_message=status or "Failed to format input",
|
| 1126 |
+
success=False,
|
| 1127 |
+
error=status or "Empty metadata returned",
|
| 1128 |
+
)
|
| 1129 |
+
|
| 1130 |
+
# Extract and convert fields
|
| 1131 |
+
result_caption = metadata.get('caption', '')
|
| 1132 |
+
result_lyrics = metadata.get('lyrics', lyrics) # Fall back to input lyrics
|
| 1133 |
+
keyscale = metadata.get('keyscale', '')
|
| 1134 |
+
language = metadata.get('language', metadata.get('vocal_language', ''))
|
| 1135 |
+
timesignature = metadata.get('timesignature', '')
|
| 1136 |
+
|
| 1137 |
+
# Convert BPM to int
|
| 1138 |
+
bpm = None
|
| 1139 |
+
bpm_value = metadata.get('bpm')
|
| 1140 |
+
if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
|
| 1141 |
+
try:
|
| 1142 |
+
bpm = int(bpm_value)
|
| 1143 |
+
except (ValueError, TypeError):
|
| 1144 |
+
pass
|
| 1145 |
+
|
| 1146 |
+
# Convert duration to float
|
| 1147 |
+
duration = None
|
| 1148 |
+
duration_value = metadata.get('duration')
|
| 1149 |
+
if duration_value is not None and duration_value != 'N/A' and duration_value != '':
|
| 1150 |
+
try:
|
| 1151 |
+
duration = float(duration_value)
|
| 1152 |
+
except (ValueError, TypeError):
|
| 1153 |
+
pass
|
| 1154 |
+
|
| 1155 |
+
# Clean up N/A values
|
| 1156 |
+
if keyscale == 'N/A':
|
| 1157 |
+
keyscale = ''
|
| 1158 |
+
if language == 'N/A':
|
| 1159 |
+
language = ''
|
| 1160 |
+
if timesignature == 'N/A':
|
| 1161 |
+
timesignature = ''
|
| 1162 |
+
|
| 1163 |
+
return FormatSampleResult(
|
| 1164 |
+
caption=result_caption,
|
| 1165 |
+
lyrics=result_lyrics,
|
| 1166 |
+
bpm=bpm,
|
| 1167 |
+
duration=duration,
|
| 1168 |
+
keyscale=keyscale,
|
| 1169 |
+
language=language,
|
| 1170 |
+
timesignature=timesignature,
|
| 1171 |
+
status_message=status,
|
| 1172 |
+
success=True,
|
| 1173 |
+
error=None,
|
| 1174 |
+
)
|
| 1175 |
+
|
| 1176 |
+
except Exception as e:
|
| 1177 |
+
logger.exception("Format sample failed")
|
| 1178 |
+
return FormatSampleResult(
|
| 1179 |
+
status_message=f"Error: {str(e)}",
|
| 1180 |
+
success=False,
|
| 1181 |
+
error=str(e),
|
| 1182 |
+
)
|
spaces/Ace-Step-v1.5/acestep/llm_inference.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
spaces/Ace-Step-v1.5/acestep/local_cache.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Local cache module to replace Redis
|
| 2 |
+
|
| 3 |
+
Uses diskcache as backend, provides Redis-compatible API.
|
| 4 |
+
Supports persistent storage and TTL expiration.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
from typing import Any, Optional
|
| 10 |
+
from threading import Lock
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from diskcache import Cache
|
| 14 |
+
HAS_DISKCACHE = True
|
| 15 |
+
except ImportError:
|
| 16 |
+
HAS_DISKCACHE = False
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class LocalCache:
|
| 20 |
+
"""
|
| 21 |
+
Local cache implementation with Redis-compatible API.
|
| 22 |
+
Uses diskcache as backend, supports persistence and TTL.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
_instance = None
|
| 26 |
+
_lock = Lock()
|
| 27 |
+
|
| 28 |
+
def __new__(cls, cache_dir: Optional[str] = None):
|
| 29 |
+
"""Singleton pattern"""
|
| 30 |
+
if cls._instance is None:
|
| 31 |
+
with cls._lock:
|
| 32 |
+
if cls._instance is None:
|
| 33 |
+
cls._instance = super().__new__(cls)
|
| 34 |
+
cls._instance._initialized = False
|
| 35 |
+
return cls._instance
|
| 36 |
+
|
| 37 |
+
def __init__(self, cache_dir: Optional[str] = None):
|
| 38 |
+
if getattr(self, '_initialized', False):
|
| 39 |
+
return
|
| 40 |
+
|
| 41 |
+
if not HAS_DISKCACHE:
|
| 42 |
+
raise ImportError(
|
| 43 |
+
"diskcache not installed. Run: pip install diskcache"
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
if cache_dir is None:
|
| 47 |
+
cache_dir = os.path.join(
|
| 48 |
+
os.path.dirname(os.path.dirname(__file__)),
|
| 49 |
+
".cache",
|
| 50 |
+
"local_redis"
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 54 |
+
self._cache = Cache(cache_dir)
|
| 55 |
+
self._initialized = True
|
| 56 |
+
|
| 57 |
+
def set(self, name: str, value: Any, ex: Optional[int] = None) -> bool:
|
| 58 |
+
"""
|
| 59 |
+
Set key-value pair
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
name: Key name
|
| 63 |
+
value: Value (auto-serialize dict/list)
|
| 64 |
+
ex: Expiration time (seconds)
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
bool: Success status
|
| 68 |
+
"""
|
| 69 |
+
if isinstance(value, (dict, list)):
|
| 70 |
+
value = json.dumps(value, ensure_ascii=False)
|
| 71 |
+
self._cache.set(name, value, expire=ex)
|
| 72 |
+
return True
|
| 73 |
+
|
| 74 |
+
def get(self, name: str) -> Optional[str]:
|
| 75 |
+
"""Get value"""
|
| 76 |
+
return self._cache.get(name)
|
| 77 |
+
|
| 78 |
+
def delete(self, name: str) -> int:
|
| 79 |
+
"""Delete key, returns number of deleted items"""
|
| 80 |
+
return 1 if self._cache.delete(name) else 0
|
| 81 |
+
|
| 82 |
+
def exists(self, name: str) -> bool:
|
| 83 |
+
"""Check if key exists"""
|
| 84 |
+
return name in self._cache
|
| 85 |
+
|
| 86 |
+
def keys(self, pattern: str = "*") -> list:
|
| 87 |
+
"""
|
| 88 |
+
Get list of matching keys
|
| 89 |
+
Note: Simplified implementation, only supports prefix and full matching
|
| 90 |
+
"""
|
| 91 |
+
if pattern == "*":
|
| 92 |
+
return list(self._cache.iterkeys())
|
| 93 |
+
|
| 94 |
+
prefix = pattern.rstrip("*")
|
| 95 |
+
return [k for k in self._cache.iterkeys() if k.startswith(prefix)]
|
| 96 |
+
|
| 97 |
+
def expire(self, name: str, seconds: int) -> bool:
|
| 98 |
+
"""Set key expiration time"""
|
| 99 |
+
value = self._cache.get(name)
|
| 100 |
+
if value is not None:
|
| 101 |
+
self._cache.set(name, value, expire=seconds)
|
| 102 |
+
return True
|
| 103 |
+
return False
|
| 104 |
+
|
| 105 |
+
def ttl(self, name: str) -> int:
|
| 106 |
+
"""
|
| 107 |
+
Get remaining time to live (seconds)
|
| 108 |
+
Note: diskcache does not directly support TTL queries
|
| 109 |
+
"""
|
| 110 |
+
if name in self._cache:
|
| 111 |
+
return -1 # Exists but TTL unknown
|
| 112 |
+
return -2 # Key does not exist
|
| 113 |
+
|
| 114 |
+
def close(self):
|
| 115 |
+
"""Close cache connection"""
|
| 116 |
+
if hasattr(self, '_cache'):
|
| 117 |
+
self._cache.close()
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# Lazily initialized global instance
|
| 121 |
+
_local_cache: Optional[LocalCache] = None
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def get_local_cache(cache_dir: Optional[str] = None) -> LocalCache:
|
| 125 |
+
"""Get local cache instance"""
|
| 126 |
+
global _local_cache
|
| 127 |
+
if _local_cache is None:
|
| 128 |
+
_local_cache = LocalCache(cache_dir)
|
| 129 |
+
return _local_cache
|
spaces/Ace-Step-v1.5/acestep/test_time_scaling.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test-Time Scaling Module
|
| 3 |
+
Implements perplexity-based scoring for generated audio codes
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from typing import Tuple, Optional, Dict, Any, List
|
| 8 |
+
from loguru import logger
|
| 9 |
+
import yaml
|
| 10 |
+
import math
|
| 11 |
+
import re
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def pmi_score(log_prob_conditional: float, log_prob_unconditional: float) -> float:
|
| 15 |
+
"""
|
| 16 |
+
Calculate Pointwise Mutual Information (PMI) score.
|
| 17 |
+
|
| 18 |
+
PMI = log P(condition|codes) - log P(condition)
|
| 19 |
+
= log [P(codes|condition) / P(codes)]
|
| 20 |
+
|
| 21 |
+
This removes the bias from P(condition) and measures how much the codes
|
| 22 |
+
improve our ability to predict the condition.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
log_prob_conditional: Average log probability of condition given codes
|
| 26 |
+
log_prob_unconditional: Average log probability of condition without codes
|
| 27 |
+
|
| 28 |
+
Returns:
|
| 29 |
+
PMI score (higher is better, can be positive or negative)
|
| 30 |
+
- Positive: codes improve prediction → good match
|
| 31 |
+
- Zero: codes don't help → no correlation
|
| 32 |
+
- Negative: codes hurt prediction → poor match
|
| 33 |
+
"""
|
| 34 |
+
return log_prob_conditional - log_prob_unconditional
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def pmi_to_normalized_score(pmi: float, scale: float = 0.1) -> float:
|
| 38 |
+
"""
|
| 39 |
+
Convert PMI score to normalized [0, 1] range using sigmoid function.
|
| 40 |
+
|
| 41 |
+
score = sigmoid(PMI / scale) = 1 / (1 + exp(-PMI / scale))
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
pmi: PMI score (can be positive or negative)
|
| 45 |
+
scale: Scale parameter to control sensitivity (default 0.1)
|
| 46 |
+
- Smaller scale: more sensitive to PMI changes
|
| 47 |
+
- Larger scale: less sensitive to PMI changes
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
Normalized score in [0, 1] range, where:
|
| 51 |
+
- PMI > 0 → score > 0.5 (good match)
|
| 52 |
+
- PMI = 0 → score = 0.5 (neutral)
|
| 53 |
+
- PMI < 0 → score < 0.5 (poor match)
|
| 54 |
+
|
| 55 |
+
Examples (scale=1.0):
|
| 56 |
+
PMI=2.0 → score≈0.88 (excellent)
|
| 57 |
+
PMI=1.0 → score≈0.73 (good)
|
| 58 |
+
PMI=0.0 → score=0.50 (neutral)
|
| 59 |
+
PMI=-1.0 → score≈0.27 (poor)
|
| 60 |
+
PMI=-2.0 → score≈0.12 (bad)
|
| 61 |
+
"""
|
| 62 |
+
return 1.0 / (1.0 + math.exp(-pmi / scale))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _get_logits_and_target_for_scoring(llm_handler, formatted_prompt: str,
|
| 66 |
+
target_text: str) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 67 |
+
"""
|
| 68 |
+
Args:
|
| 69 |
+
llm_handler: The handler containing the model and tokenizer.
|
| 70 |
+
formatted_prompt: The input context.
|
| 71 |
+
target_text: The text we want to calculate probability/recall for.
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
Tuple of (target_logits, target_ids)
|
| 75 |
+
- target_logits: Logits used to predict the target tokens.
|
| 76 |
+
- target_ids: The ground truth token IDs of the target.
|
| 77 |
+
"""
|
| 78 |
+
model = llm_handler.get_hf_model_for_scoring()
|
| 79 |
+
tokenizer = llm_handler.llm_tokenizer
|
| 80 |
+
device = llm_handler.device if llm_handler.llm_backend == "pt" else next(model.parameters()).device
|
| 81 |
+
|
| 82 |
+
# 1. Tokenize prompt ONLY to get its length (used for slicing later).
|
| 83 |
+
# We must ensure special tokens are added to count the offset correctly.
|
| 84 |
+
prompt_tokens_temp = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=True)
|
| 85 |
+
prompt_len = prompt_tokens_temp['input_ids'].shape[1]
|
| 86 |
+
|
| 87 |
+
# 2. Tokenize the FULL text (Prompt + Target).
|
| 88 |
+
# This ensures subword merging at boundaries is handled correctly by the tokenizer.
|
| 89 |
+
full_text = formatted_prompt + target_text
|
| 90 |
+
full_tokens = tokenizer(full_text, return_tensors="pt", padding=False, truncation=True, add_special_tokens=True).to(device)
|
| 91 |
+
|
| 92 |
+
input_ids = full_tokens['input_ids']
|
| 93 |
+
|
| 94 |
+
# Safety check: if target was empty or truncated entirely
|
| 95 |
+
if input_ids.shape[1] <= prompt_len:
|
| 96 |
+
return torch.empty(0, device=device), torch.empty(0, device=device)
|
| 97 |
+
|
| 98 |
+
# 3. Forward Pass (Teacher Forcing)
|
| 99 |
+
with torch.no_grad():
|
| 100 |
+
with llm_handler._load_model_context():
|
| 101 |
+
outputs = model(input_ids=input_ids, attention_mask=full_tokens['attention_mask'])
|
| 102 |
+
all_logits = outputs.logits # [1, seq_len, vocab_size]
|
| 103 |
+
|
| 104 |
+
# 4. Extract Logits and Labels
|
| 105 |
+
# We need to predict `input_ids[i]`. The logit for this is at `all_logits[i-1]`.
|
| 106 |
+
# Target starts at index `prompt_len`.
|
| 107 |
+
# So we need logits from `prompt_len - 1` up to the second to last position.
|
| 108 |
+
|
| 109 |
+
target_logits = all_logits[0, prompt_len - 1:-1, :] # [target_len, vocab_size]
|
| 110 |
+
target_ids = input_ids[0, prompt_len:] # [target_len]
|
| 111 |
+
|
| 112 |
+
return target_logits, target_ids
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# ==============================================================================
|
| 116 |
+
# Scoring Logic
|
| 117 |
+
# ==============================================================================
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _calculate_topk_recall(llm_handler,
|
| 121 |
+
formatted_prompt: str,
|
| 122 |
+
target_text: str,
|
| 123 |
+
topk: int = 10) -> Tuple[float, Dict[int, float]]:
|
| 124 |
+
"""
|
| 125 |
+
Calculate top-k recall for target text given prompt.
|
| 126 |
+
Checks if the ground truth token is within the top-k probabilities at each step.
|
| 127 |
+
"""
|
| 128 |
+
# Use the fixed helper to get aligned logits/labels
|
| 129 |
+
pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)
|
| 130 |
+
|
| 131 |
+
if target_ids.shape[0] == 0:
|
| 132 |
+
return 0.0, {}
|
| 133 |
+
|
| 134 |
+
target_len = target_ids.shape[0]
|
| 135 |
+
|
| 136 |
+
# Get top-k indices for all positions at once
|
| 137 |
+
# topk_indices: [target_len, topk]
|
| 138 |
+
_, topk_indices = torch.topk(pred_logits, k=min(topk, pred_logits.shape[-1]), dim=-1)
|
| 139 |
+
|
| 140 |
+
recall_per_k = {}
|
| 141 |
+
position_scores = []
|
| 142 |
+
|
| 143 |
+
# Convert to list for faster CPU iteration
|
| 144 |
+
target_ids_list = target_ids.tolist()
|
| 145 |
+
topk_indices_list = topk_indices.tolist()
|
| 146 |
+
|
| 147 |
+
for k in range(1, topk + 1):
|
| 148 |
+
hits = 0
|
| 149 |
+
for pos in range(target_len):
|
| 150 |
+
gt_token = target_ids_list[pos]
|
| 151 |
+
# Check the top-k slice
|
| 152 |
+
topk_at_pos = topk_indices_list[pos][:k]
|
| 153 |
+
|
| 154 |
+
if gt_token in topk_at_pos:
|
| 155 |
+
hits += 1
|
| 156 |
+
# Calculate position-weighted score only once (when k=topk)
|
| 157 |
+
if k == topk:
|
| 158 |
+
rank = topk_at_pos.index(gt_token) + 1
|
| 159 |
+
# Rank 1 = 1.0, Rank k = small positive
|
| 160 |
+
position_weight = 1.0 - (rank - 1) / topk
|
| 161 |
+
position_scores.append(position_weight)
|
| 162 |
+
|
| 163 |
+
recall_per_k[k] = hits / target_len if target_len > 0 else 0.0
|
| 164 |
+
|
| 165 |
+
# Fill scores for positions where GT was NOT in top-k
|
| 166 |
+
while len(position_scores) < target_len:
|
| 167 |
+
position_scores.append(0.0)
|
| 168 |
+
|
| 169 |
+
average_recall = sum(position_scores) / len(position_scores) if position_scores else 0.0
|
| 170 |
+
|
| 171 |
+
return average_recall, recall_per_k
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def _calculate_metadata_recall(llm_handler,
|
| 175 |
+
formatted_prompt: str,
|
| 176 |
+
fields_dict: Dict[str, Any],
|
| 177 |
+
topk: int = 10) -> Dict[str, float]:
|
| 178 |
+
"""
|
| 179 |
+
Args:
|
| 180 |
+
fields_dict: Dictionary of {field_name: field_value}
|
| 181 |
+
"""
|
| 182 |
+
if not fields_dict:
|
| 183 |
+
return {}
|
| 184 |
+
|
| 185 |
+
field_scores = {}
|
| 186 |
+
|
| 187 |
+
for field_name in sorted(fields_dict.keys()):
|
| 188 |
+
# Construct target text for this specific field
|
| 189 |
+
# e.g. <think>\nbpm: 120\n</think>\n
|
| 190 |
+
field_yaml = yaml.dump({field_name: fields_dict[field_name]}, allow_unicode=True, sort_keys=True).strip()
|
| 191 |
+
field_target_text = f"<think>\n{field_yaml}\n</think>\n"
|
| 192 |
+
|
| 193 |
+
# Calculate recall using the robust logic
|
| 194 |
+
avg_score, _ = _calculate_topk_recall(llm_handler, formatted_prompt, field_target_text, topk=topk)
|
| 195 |
+
|
| 196 |
+
field_scores[field_name] = avg_score
|
| 197 |
+
logger.debug(f"Recall for {field_name}: {avg_score:.4f}")
|
| 198 |
+
|
| 199 |
+
return field_scores
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def _calculate_log_prob(
|
| 203 |
+
llm_handler,
|
| 204 |
+
formatted_prompt: str,
|
| 205 |
+
target_text: str,
|
| 206 |
+
temperature: float = 1.0 # Kept for API compatibility, but ignored for scoring
|
| 207 |
+
) -> float:
|
| 208 |
+
"""
|
| 209 |
+
Calculate average log probability of target text given prompt.
|
| 210 |
+
"""
|
| 211 |
+
pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)
|
| 212 |
+
|
| 213 |
+
if target_ids.shape[0] == 0:
|
| 214 |
+
return float('-inf')
|
| 215 |
+
|
| 216 |
+
# FIX: Do not divide by temperature.
|
| 217 |
+
# Log-probability for PMI/Perplexity should be exact.
|
| 218 |
+
|
| 219 |
+
# Calculate log probabilities (log_softmax)
|
| 220 |
+
log_probs = F.log_softmax(pred_logits, dim=-1) # [target_len, vocab_size]
|
| 221 |
+
|
| 222 |
+
# Gather log probabilities of the ground truth tokens
|
| 223 |
+
target_log_probs = log_probs[torch.arange(target_ids.shape[0]), target_ids]
|
| 224 |
+
|
| 225 |
+
# Return average log probability
|
| 226 |
+
mean_log_prob = target_log_probs.mean().item()
|
| 227 |
+
|
| 228 |
+
return mean_log_prob
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def calculate_reward_score(
|
| 232 |
+
scores: Dict[str, float],
|
| 233 |
+
weights_config: Optional[Dict[str, float]] = None
|
| 234 |
+
) -> Tuple[float, str]:
|
| 235 |
+
"""
|
| 236 |
+
Reward Model Calculator: Computes a final reward based on user priorities.
|
| 237 |
+
|
| 238 |
+
Priority Logic:
|
| 239 |
+
1. Caption (Highest): The overall vibe/style must match.
|
| 240 |
+
2. Lyrics (Medium): Content accuracy is important but secondary to vibe.
|
| 241 |
+
3. Metadata (Lowest): Technical constraints (BPM, Key) allow for slight deviations.
|
| 242 |
+
|
| 243 |
+
Strategy: Dynamic Weighted Sum
|
| 244 |
+
- Metadata fields are aggregated into a single 'metadata' score first.
|
| 245 |
+
- Weights are dynamically renormalized if any component (e.g., lyrics) is missing.
|
| 246 |
+
|
| 247 |
+
Args:
|
| 248 |
+
scores: Dictionary of raw scores (0.0 - 1.0) from the evaluation module.
|
| 249 |
+
weights_config: Optional custom weights. Defaults to:
|
| 250 |
+
Caption (50%), Lyrics (30%), Metadata (20%).
|
| 251 |
+
|
| 252 |
+
Returns:
|
| 253 |
+
final_reward: The calculated reward score (0.0 - 1.0).
|
| 254 |
+
explanation: A formatted string explaining how the score was derived.
|
| 255 |
+
"""
|
| 256 |
+
|
| 257 |
+
# 1. Default Preference Configuration
|
| 258 |
+
# These weights determine the relative importance of each component.
|
| 259 |
+
if weights_config is None:
|
| 260 |
+
weights_config = {
|
| 261 |
+
'caption': 0.50, # High priority: Style/Vibe
|
| 262 |
+
'lyrics': 0.30, # Medium priority: Content
|
| 263 |
+
'metadata': 0.20 # Low priority: Technical details
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
# 2. Extract and Group Scores
|
| 267 |
+
# Caption and Lyrics are standalone high-level features.
|
| 268 |
+
caption_score = scores.get('caption')
|
| 269 |
+
lyrics_score = scores.get('lyrics')
|
| 270 |
+
|
| 271 |
+
# Metadata fields (bpm, key, duration, etc.) are aggregated.
|
| 272 |
+
# We treat them as a single "Technical Score" to prevent them from
|
| 273 |
+
# diluting the weight of Caption/Lyrics simply by having many fields.
|
| 274 |
+
meta_scores_list = [
|
| 275 |
+
val for key, val in scores.items()
|
| 276 |
+
if key not in ['caption', 'lyrics']
|
| 277 |
+
]
|
| 278 |
+
|
| 279 |
+
# Calculate average of all metadata fields (if any exist)
|
| 280 |
+
meta_aggregate_score = None
|
| 281 |
+
if meta_scores_list:
|
| 282 |
+
meta_aggregate_score = sum(meta_scores_list) / len(meta_scores_list)
|
| 283 |
+
|
| 284 |
+
# 3. specific Active Components & Dynamic Weighting
|
| 285 |
+
# We only include components that actually exist in this generation.
|
| 286 |
+
active_components = {}
|
| 287 |
+
|
| 288 |
+
if caption_score is not None:
|
| 289 |
+
active_components['caption'] = (caption_score, weights_config['caption'])
|
| 290 |
+
|
| 291 |
+
if lyrics_score is not None:
|
| 292 |
+
active_components['lyrics'] = (lyrics_score, weights_config['lyrics'])
|
| 293 |
+
|
| 294 |
+
if meta_aggregate_score is not None:
|
| 295 |
+
active_components['metadata'] = (meta_aggregate_score, weights_config['metadata'])
|
| 296 |
+
|
| 297 |
+
# 4. Calculate Final Weighted Score
|
| 298 |
+
total_base_weight = sum(w for _, w in active_components.values())
|
| 299 |
+
total_score = 0.0
|
| 300 |
+
|
| 301 |
+
breakdown_lines = []
|
| 302 |
+
|
| 303 |
+
if total_base_weight == 0:
|
| 304 |
+
return 0.0, "❌ No valid scores available to calculate reward."
|
| 305 |
+
|
| 306 |
+
# Sort by weight (importance) for display
|
| 307 |
+
sorted_components = sorted(active_components.items(), key=lambda x: x[1][1], reverse=True)
|
| 308 |
+
|
| 309 |
+
for name, (score, base_weight) in sorted_components:
|
| 310 |
+
# Renormalize weight: If lyrics are missing, caption/metadata weights scale up proportionately.
|
| 311 |
+
normalized_weight = base_weight / total_base_weight
|
| 312 |
+
weighted_contribution = score * normalized_weight
|
| 313 |
+
total_score += weighted_contribution
|
| 314 |
+
|
| 315 |
+
breakdown_lines.append(
|
| 316 |
+
f" • {name.title():<8} | Score: {score:.4f} | Weight: {normalized_weight:.2f} "
|
| 317 |
+
f"-> Contrib: +{weighted_contribution:.4f}"
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
return total_score, "\n".join(breakdown_lines)
|
| 321 |
+
|
| 322 |
+
# ==============================================================================
|
| 323 |
+
# Main Public API
|
| 324 |
+
# ==============================================================================
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def calculate_pmi_score_per_condition(
|
| 328 |
+
llm_handler,
|
| 329 |
+
audio_codes: str,
|
| 330 |
+
caption: str = "",
|
| 331 |
+
lyrics: str = "",
|
| 332 |
+
metadata: Optional[Dict[str, Any]] = None,
|
| 333 |
+
temperature: float = 1.0,
|
| 334 |
+
topk: int = 10,
|
| 335 |
+
score_scale: float = 0.1,
|
| 336 |
+
) -> Tuple[Dict[str, float], float, str]:
|
| 337 |
+
"""
|
| 338 |
+
Calculate quality score separately for each condition.
|
| 339 |
+
- Metadata: Uses Top-k Recall.
|
| 340 |
+
- Caption/Lyrics: Uses PMI (Normalized).
|
| 341 |
+
"""
|
| 342 |
+
if not llm_handler.llm_initialized:
|
| 343 |
+
return {}, 0.0, "❌ LLM not initialized"
|
| 344 |
+
|
| 345 |
+
if not audio_codes or not audio_codes.strip():
|
| 346 |
+
return {}, 0.0, "❌ No audio codes provided"
|
| 347 |
+
|
| 348 |
+
if "caption" not in metadata:
|
| 349 |
+
metadata['caption'] = caption
|
| 350 |
+
|
| 351 |
+
formatted_prompt = llm_handler.build_formatted_prompt_for_understanding(audio_codes=audio_codes, is_negative_prompt=False)
|
| 352 |
+
prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
|
| 353 |
+
try:
|
| 354 |
+
# 1. Calculate Recall for Metadata Fields
|
| 355 |
+
if metadata and isinstance(metadata, dict):
|
| 356 |
+
scores = {}
|
| 357 |
+
# Define which fields use which metric
|
| 358 |
+
metadata_recall_keys = ['bpm', 'duration', 'genres', 'keyscale', 'language', 'timesignature']
|
| 359 |
+
metadata_pmi_keys = ['caption']
|
| 360 |
+
for key in metadata_recall_keys:
|
| 361 |
+
if key in metadata and metadata[key] is not None:
|
| 362 |
+
recall_metadata = {key: metadata[key]}
|
| 363 |
+
field_scores = _calculate_metadata_recall(llm_handler, formatted_prompt, recall_metadata, topk=topk)
|
| 364 |
+
scores.update(field_scores)
|
| 365 |
+
|
| 366 |
+
# 2. Calculate PMI for Caption
|
| 367 |
+
for key in metadata_pmi_keys:
|
| 368 |
+
if key in metadata and metadata[key] is not None:
|
| 369 |
+
cot_yaml = yaml.dump({key: metadata[key]}, allow_unicode=True, sort_keys=True).strip()
|
| 370 |
+
target_text = f"<think>\n{cot_yaml}\n</think>\n"
|
| 371 |
+
|
| 372 |
+
log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
|
| 373 |
+
log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
|
| 374 |
+
|
| 375 |
+
pmi_normalized = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
|
| 376 |
+
scores[key] = pmi_normalized
|
| 377 |
+
|
| 378 |
+
# 3. Calculate PMI for Lyrics
|
| 379 |
+
if lyrics:
|
| 380 |
+
target_text = f"<think>\n</think>\n# Lyric\n{lyrics}\n"
|
| 381 |
+
|
| 382 |
+
log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
|
| 383 |
+
|
| 384 |
+
prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
|
| 385 |
+
log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
|
| 386 |
+
|
| 387 |
+
scores['lyrics'] = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
|
| 388 |
+
|
| 389 |
+
if not scores:
|
| 390 |
+
return {}, 0.0, "❌ No conditions to evaluate"
|
| 391 |
+
|
| 392 |
+
# 4. Global Score
|
| 393 |
+
global_score = sum(scores.values()) / len(scores)
|
| 394 |
+
global_score, breakdown_lines = calculate_reward_score(scores)
|
| 395 |
+
|
| 396 |
+
# Status Message
|
| 397 |
+
status_lines = [breakdown_lines, "\n✅ Per-condition scores (0-1):"]
|
| 398 |
+
for key, score in sorted(scores.items()):
|
| 399 |
+
metric = "Top-k Recall" if key in metadata_recall_keys else "PMI (Norm)"
|
| 400 |
+
status_lines.append(f" {key}: {score:.4f} ({metric})")
|
| 401 |
+
status = "\n".join(status_lines)
|
| 402 |
+
logger.info(f"Calculated scores: {global_score:.4f}\n{status}")
|
| 403 |
+
return scores, global_score, status
|
| 404 |
+
|
| 405 |
+
except Exception as e:
|
| 406 |
+
import traceback
|
| 407 |
+
error_msg = f"❌ Error: {str(e)}"
|
| 408 |
+
logger.error(error_msg)
|
| 409 |
+
logger.error(traceback.format_exc())
|
| 410 |
+
return {}, float('-inf'), error_msg
|
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Xingkai Yu
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/README.md
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="center">
|
| 2 |
+
<img width="300" src="assets/logo.png">
|
| 3 |
+
</p>
|
| 4 |
+
|
| 5 |
+
<p align="center">
|
| 6 |
+
<a href="https://trendshift.io/repositories/15323" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15323" alt="GeeeekExplorer%2Fnano-vllm | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
| 7 |
+
</p>
|
| 8 |
+
|
| 9 |
+
# Nano-vLLM
|
| 10 |
+
|
| 11 |
+
A lightweight vLLM implementation built from scratch.
|
| 12 |
+
|
| 13 |
+
## Key Features
|
| 14 |
+
|
| 15 |
+
* 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
|
| 16 |
+
* 📖 **Readable codebase** - Clean implementation in ~ 1,200 lines of Python code
|
| 17 |
+
* ⚡ **Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc.
|
| 18 |
+
|
| 19 |
+
## Installation
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
## Model Download
|
| 26 |
+
|
| 27 |
+
To download the model weights manually, use the following command:
|
| 28 |
+
```bash
|
| 29 |
+
huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
|
| 30 |
+
--local-dir ~/huggingface/Qwen3-0.6B/ \
|
| 31 |
+
--local-dir-use-symlinks False
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
## Quick Start
|
| 35 |
+
|
| 36 |
+
See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method:
|
| 37 |
+
```python
|
| 38 |
+
from nanovllm import LLM, SamplingParams
|
| 39 |
+
llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
|
| 40 |
+
sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
|
| 41 |
+
prompts = ["Hello, Nano-vLLM."]
|
| 42 |
+
outputs = llm.generate(prompts, sampling_params)
|
| 43 |
+
outputs[0]["text"]
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
## Benchmark
|
| 47 |
+
|
| 48 |
+
See `bench.py` for benchmark.
|
| 49 |
+
|
| 50 |
+
**Test Configuration:**
|
| 51 |
+
- Hardware: RTX 4070 Laptop (8GB)
|
| 52 |
+
- Model: Qwen3-0.6B
|
| 53 |
+
- Total Requests: 256 sequences
|
| 54 |
+
- Input Length: Randomly sampled between 100–1024 tokens
|
| 55 |
+
- Output Length: Randomly sampled between 100–1024 tokens
|
| 56 |
+
|
| 57 |
+
**Performance Results:**
|
| 58 |
+
| Inference Engine | Output Tokens | Time (s) | Throughput (tokens/s) |
|
| 59 |
+
|----------------|-------------|----------|-----------------------|
|
| 60 |
+
| vLLM | 133,966 | 98.37 | 1361.84 |
|
| 61 |
+
| Nano-vLLM | 133,966 | 93.41 | 1434.13 |
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
## Star History
|
| 65 |
+
|
| 66 |
+
[](https://www.star-history.com/#GeeeekExplorer/nano-vllm&Date)
|
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/assets/logo.png
ADDED
|
Git LFS Details
|
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/bench.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
from random import randint, seed
|
| 4 |
+
from nanovllm import LLM, SamplingParams
|
| 5 |
+
# from vllm import LLM, SamplingParams
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def main():
|
| 9 |
+
seed(0)
|
| 10 |
+
num_seqs = 256
|
| 11 |
+
max_input_len = 1024
|
| 12 |
+
max_ouput_len = 1024
|
| 13 |
+
|
| 14 |
+
path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
|
| 15 |
+
llm = LLM(path, enforce_eager=False, max_model_len=4096)
|
| 16 |
+
|
| 17 |
+
prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
|
| 18 |
+
sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)]
|
| 19 |
+
# uncomment the following line for vllm
|
| 20 |
+
# prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
|
| 21 |
+
|
| 22 |
+
llm.generate(["Benchmark: "], SamplingParams())
|
| 23 |
+
t = time.time()
|
| 24 |
+
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
|
| 25 |
+
t = (time.time() - t)
|
| 26 |
+
total_tokens = sum(sp.max_tokens for sp in sampling_params)
|
| 27 |
+
throughput = total_tokens / t
|
| 28 |
+
print(f"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
if __name__ == "__main__":
|
| 32 |
+
main()
|
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/example.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from nanovllm import LLM, SamplingParams
|
| 3 |
+
from transformers import AutoTokenizer
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def main():
|
| 7 |
+
path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
|
| 8 |
+
tokenizer = AutoTokenizer.from_pretrained(path)
|
| 9 |
+
llm = LLM(path, enforce_eager=True, tensor_parallel_size=1)
|
| 10 |
+
|
| 11 |
+
sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
|
| 12 |
+
prompts = [
|
| 13 |
+
"introduce yourself",
|
| 14 |
+
"list all prime numbers within 100",
|
| 15 |
+
]
|
| 16 |
+
prompts = [
|
| 17 |
+
tokenizer.apply_chat_template(
|
| 18 |
+
[{"role": "user", "content": prompt}],
|
| 19 |
+
tokenize=False,
|
| 20 |
+
add_generation_prompt=True,
|
| 21 |
+
)
|
| 22 |
+
for prompt in prompts
|
| 23 |
+
]
|
| 24 |
+
outputs = llm.generate(prompts, sampling_params)
|
| 25 |
+
|
| 26 |
+
for prompt, output in zip(prompts, outputs):
|
| 27 |
+
print("\n")
|
| 28 |
+
print(f"Prompt: {prompt!r}")
|
| 29 |
+
print(f"Completion: {output['text']!r}")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if __name__ == "__main__":
|
| 33 |
+
main()
|
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from nanovllm.llm import LLM
|
| 2 |
+
from nanovllm.sampling_params import SamplingParams
|
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/config.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from transformers import AutoConfig
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@dataclass
|
| 7 |
+
class Config:
|
| 8 |
+
model: str
|
| 9 |
+
max_num_batched_tokens: int = 16384
|
| 10 |
+
max_num_seqs: int = 512
|
| 11 |
+
max_model_len: int = 4096
|
| 12 |
+
gpu_memory_utilization: float = 0.9
|
| 13 |
+
tensor_parallel_size: int = 1
|
| 14 |
+
enforce_eager: bool = False
|
| 15 |
+
hf_config: AutoConfig | None = None
|
| 16 |
+
eos: int = -1
|
| 17 |
+
kvcache_block_size: int = 256
|
| 18 |
+
num_kvcache_blocks: int = -1
|
| 19 |
+
|
| 20 |
+
def __post_init__(self):
|
| 21 |
+
assert os.path.isdir(self.model)
|
| 22 |
+
assert self.kvcache_block_size % 256 == 0
|
| 23 |
+
assert 1 <= self.tensor_parallel_size <= 8
|
| 24 |
+
self.hf_config = AutoConfig.from_pretrained(self.model)
|
| 25 |
+
self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
|
| 26 |
+
assert self.max_num_batched_tokens >= self.max_model_len
|
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/block_manager.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import deque
|
| 2 |
+
import xxhash
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from nanovllm.engine.sequence import Sequence
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Block:
|
| 9 |
+
|
| 10 |
+
def __init__(self, block_id):
|
| 11 |
+
self.block_id = block_id
|
| 12 |
+
self.ref_count = 0
|
| 13 |
+
self.hash = -1
|
| 14 |
+
self.token_ids = []
|
| 15 |
+
|
| 16 |
+
def update(self, hash: int, token_ids: list[int]):
|
| 17 |
+
self.hash = hash
|
| 18 |
+
self.token_ids = token_ids
|
| 19 |
+
|
| 20 |
+
def reset(self):
|
| 21 |
+
self.ref_count = 1
|
| 22 |
+
self.hash = -1
|
| 23 |
+
self.token_ids = []
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class BlockManager:
|
| 27 |
+
|
| 28 |
+
def __init__(self, num_blocks: int, block_size: int):
|
| 29 |
+
self.block_size = block_size
|
| 30 |
+
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
|
| 31 |
+
self.hash_to_block_id: dict[int, int] = dict()
|
| 32 |
+
self.free_block_ids: deque[int] = deque(range(num_blocks))
|
| 33 |
+
self.used_block_ids: set[int] = set()
|
| 34 |
+
|
| 35 |
+
@classmethod
|
| 36 |
+
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
|
| 37 |
+
h = xxhash.xxh64()
|
| 38 |
+
if prefix != -1:
|
| 39 |
+
h.update(prefix.to_bytes(8, "little"))
|
| 40 |
+
h.update(np.array(token_ids).tobytes())
|
| 41 |
+
return h.intdigest()
|
| 42 |
+
|
| 43 |
+
def _allocate_block(self, block_id: int) -> Block:
|
| 44 |
+
block = self.blocks[block_id]
|
| 45 |
+
assert block.ref_count == 0
|
| 46 |
+
block.reset()
|
| 47 |
+
self.free_block_ids.remove(block_id)
|
| 48 |
+
self.used_block_ids.add(block_id)
|
| 49 |
+
return self.blocks[block_id]
|
| 50 |
+
|
| 51 |
+
def _deallocate_block(self, block_id: int) -> Block:
|
| 52 |
+
assert self.blocks[block_id].ref_count == 0
|
| 53 |
+
self.used_block_ids.remove(block_id)
|
| 54 |
+
self.free_block_ids.append(block_id)
|
| 55 |
+
|
| 56 |
+
def can_allocate(self, seq: Sequence) -> bool:
|
| 57 |
+
return len(self.free_block_ids) >= seq.num_blocks
|
| 58 |
+
|
| 59 |
+
def allocate(self, seq: Sequence):
|
| 60 |
+
assert not seq.block_table
|
| 61 |
+
h = -1
|
| 62 |
+
cache_miss = False
|
| 63 |
+
for i in range(seq.num_blocks):
|
| 64 |
+
token_ids = seq.block(i)
|
| 65 |
+
h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
|
| 66 |
+
block_id = self.hash_to_block_id.get(h, -1)
|
| 67 |
+
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
|
| 68 |
+
cache_miss = True
|
| 69 |
+
if cache_miss:
|
| 70 |
+
block_id = self.free_block_ids[0]
|
| 71 |
+
block = self._allocate_block(block_id)
|
| 72 |
+
else:
|
| 73 |
+
seq.num_cached_tokens += self.block_size
|
| 74 |
+
if block_id in self.used_block_ids:
|
| 75 |
+
block = self.blocks[block_id]
|
| 76 |
+
block.ref_count += 1
|
| 77 |
+
else:
|
| 78 |
+
block = self._allocate_block(block_id)
|
| 79 |
+
if h != -1:
|
| 80 |
+
block.update(h, token_ids)
|
| 81 |
+
self.hash_to_block_id[h] = block_id
|
| 82 |
+
seq.block_table.append(block_id)
|
| 83 |
+
|
| 84 |
+
def deallocate(self, seq: Sequence):
|
| 85 |
+
for block_id in reversed(seq.block_table):
|
| 86 |
+
block = self.blocks[block_id]
|
| 87 |
+
block.ref_count -= 1
|
| 88 |
+
if block.ref_count == 0:
|
| 89 |
+
self._deallocate_block(block_id)
|
| 90 |
+
seq.num_cached_tokens = 0
|
| 91 |
+
seq.block_table.clear()
|
| 92 |
+
|
| 93 |
+
def can_append(self, seq: Sequence) -> bool:
|
| 94 |
+
return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
|
| 95 |
+
|
| 96 |
+
def may_append(self, seq: Sequence):
|
| 97 |
+
block_table = seq.block_table
|
| 98 |
+
last_block = self.blocks[block_table[-1]]
|
| 99 |
+
if len(seq) % self.block_size == 1:
|
| 100 |
+
assert last_block.hash != -1
|
| 101 |
+
block_id = self.free_block_ids[0]
|
| 102 |
+
self._allocate_block(block_id)
|
| 103 |
+
block_table.append(block_id)
|
| 104 |
+
elif len(seq) % self.block_size == 0:
|
| 105 |
+
assert last_block.hash == -1
|
| 106 |
+
token_ids = seq.block(seq.num_blocks-1)
|
| 107 |
+
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
|
| 108 |
+
h = self.compute_hash(token_ids, prefix)
|
| 109 |
+
last_block.update(h, token_ids)
|
| 110 |
+
self.hash_to_block_id[h] = last_block.block_id
|
| 111 |
+
else:
|
| 112 |
+
assert last_block.hash == -1
|
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import atexit
|
| 2 |
+
from dataclasses import fields
|
| 3 |
+
from time import perf_counter
|
| 4 |
+
from tqdm.auto import tqdm
|
| 5 |
+
from transformers import AutoTokenizer
|
| 6 |
+
import torch.multiprocessing as mp
|
| 7 |
+
|
| 8 |
+
from nanovllm.config import Config
|
| 9 |
+
from nanovllm.sampling_params import SamplingParams
|
| 10 |
+
from nanovllm.engine.sequence import Sequence
|
| 11 |
+
from nanovllm.engine.scheduler import Scheduler
|
| 12 |
+
from nanovllm.engine.model_runner import ModelRunner
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LLMEngine:
|
| 16 |
+
|
| 17 |
+
def __init__(self, model, **kwargs):
|
| 18 |
+
config_fields = {field.name for field in fields(Config)}
|
| 19 |
+
config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
|
| 20 |
+
config = Config(model, **config_kwargs)
|
| 21 |
+
self.ps = []
|
| 22 |
+
self.events = []
|
| 23 |
+
ctx = mp.get_context("spawn")
|
| 24 |
+
for i in range(1, config.tensor_parallel_size):
|
| 25 |
+
event = ctx.Event()
|
| 26 |
+
process = ctx.Process(target=ModelRunner, args=(config, i, event))
|
| 27 |
+
process.start()
|
| 28 |
+
self.ps.append(process)
|
| 29 |
+
self.events.append(event)
|
| 30 |
+
self.model_runner = ModelRunner(config, 0, self.events)
|
| 31 |
+
tokenizer = kwargs.get("tokenizer", None)
|
| 32 |
+
if tokenizer is not None:
|
| 33 |
+
self.tokenizer = tokenizer
|
| 34 |
+
else:
|
| 35 |
+
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
|
| 36 |
+
config.eos = self.tokenizer.eos_token_id
|
| 37 |
+
self.scheduler = Scheduler(config)
|
| 38 |
+
atexit.register(self.exit)
|
| 39 |
+
|
| 40 |
+
def exit(self):
|
| 41 |
+
self.model_runner.call("exit")
|
| 42 |
+
del self.model_runner
|
| 43 |
+
for p in self.ps:
|
| 44 |
+
p.join()
|
| 45 |
+
|
| 46 |
+
def add_request(self, prompt: str | list[int], sampling_params: SamplingParams, unconditional_prompt: str | list[int] | None = None):
|
| 47 |
+
if isinstance(prompt, str):
|
| 48 |
+
prompt = self.tokenizer.encode(prompt)
|
| 49 |
+
# For CFG: if cfg_scale > 1.0, create both conditional and unconditional sequences
|
| 50 |
+
if sampling_params.cfg_scale > 1.0:
|
| 51 |
+
if unconditional_prompt is None:
|
| 52 |
+
# Try to construct unconditional prompt by replacing user input with "NO USER INPUT"
|
| 53 |
+
# This is a fallback - ideally users should provide unconditional_prompt
|
| 54 |
+
if isinstance(prompt, list):
|
| 55 |
+
# For now, just use the same prompt (user should provide unconditional_prompt)
|
| 56 |
+
# TODO: Implement automatic "NO USER INPUT" replacement if possible
|
| 57 |
+
unconditional_prompt = prompt
|
| 58 |
+
else:
|
| 59 |
+
unconditional_prompt = prompt
|
| 60 |
+
if isinstance(unconditional_prompt, str):
|
| 61 |
+
unconditional_prompt = self.tokenizer.encode(unconditional_prompt)
|
| 62 |
+
# Create unconditional sequence first (so we can reference it from conditional)
|
| 63 |
+
uncond_seq = Sequence(unconditional_prompt, sampling_params, is_unconditional=True)
|
| 64 |
+
# Create conditional sequence with reference to unconditional
|
| 65 |
+
cond_seq = Sequence(prompt, sampling_params, is_unconditional=False, conditional_seq=uncond_seq)
|
| 66 |
+
uncond_seq.paired_seq = cond_seq # Link them bidirectionally
|
| 67 |
+
# Add both sequences to scheduler
|
| 68 |
+
self.scheduler.add(cond_seq)
|
| 69 |
+
self.scheduler.add(uncond_seq)
|
| 70 |
+
else:
|
| 71 |
+
seq = Sequence(prompt, sampling_params)
|
| 72 |
+
self.scheduler.add(seq)
|
| 73 |
+
|
| 74 |
+
def step(self):
|
| 75 |
+
seqs, is_prefill = self.scheduler.schedule()
|
| 76 |
+
token_ids = self.model_runner.call("run", seqs, is_prefill)
|
| 77 |
+
self.scheduler.postprocess(seqs, token_ids)
|
| 78 |
+
# Only output conditional sequences (unconditional sequences are just for CFG computation)
|
| 79 |
+
output_seqs = [seq for seq in seqs if seq.is_finished and (seq.cfg_scale <= 1.0 or not seq.is_unconditional)]
|
| 80 |
+
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in output_seqs]
|
| 81 |
+
num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len([s for s in seqs if not s.is_unconditional])
|
| 82 |
+
return outputs, num_tokens
|
| 83 |
+
|
| 84 |
+
def is_finished(self):
|
| 85 |
+
return self.scheduler.is_finished()
|
| 86 |
+
|
| 87 |
+
def generate(
|
| 88 |
+
self,
|
| 89 |
+
prompts: list[str] | list[list[int]],
|
| 90 |
+
sampling_params: SamplingParams | list[SamplingParams],
|
| 91 |
+
use_tqdm: bool = True,
|
| 92 |
+
unconditional_prompts: list[str] | list[list[int]] | None = None,
|
| 93 |
+
) -> list[str]:
|
| 94 |
+
if use_tqdm:
|
| 95 |
+
pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
|
| 96 |
+
if not isinstance(sampling_params, list):
|
| 97 |
+
sampling_params = [sampling_params] * len(prompts)
|
| 98 |
+
if unconditional_prompts is None:
|
| 99 |
+
unconditional_prompts = [None] * len(prompts)
|
| 100 |
+
for prompt, sp, uncond_prompt in zip(prompts, sampling_params, unconditional_prompts):
|
| 101 |
+
self.add_request(prompt, sp, uncond_prompt)
|
| 102 |
+
outputs = {}
|
| 103 |
+
prefill_throughput = decode_throughput = 0.
|
| 104 |
+
while not self.is_finished():
|
| 105 |
+
t = perf_counter()
|
| 106 |
+
output, num_tokens = self.step()
|
| 107 |
+
if use_tqdm:
|
| 108 |
+
if num_tokens > 0:
|
| 109 |
+
prefill_throughput = num_tokens / (perf_counter() - t)
|
| 110 |
+
else:
|
| 111 |
+
decode_throughput = -num_tokens / (perf_counter() - t)
|
| 112 |
+
pbar.set_postfix({
|
| 113 |
+
"Prefill": f"{int(prefill_throughput)}tok/s",
|
| 114 |
+
"Decode": f"{int(decode_throughput)}tok/s",
|
| 115 |
+
})
|
| 116 |
+
for seq_id, token_ids in output:
|
| 117 |
+
outputs[seq_id] = token_ids
|
| 118 |
+
if use_tqdm:
|
| 119 |
+
pbar.update(1)
|
| 120 |
+
outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
|
| 121 |
+
outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
|
| 122 |
+
if use_tqdm:
|
| 123 |
+
pbar.close()
|
| 124 |
+
return outputs
|
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py
ADDED
|
@@ -0,0 +1,529 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import torch
|
| 3 |
+
import torch.distributed as dist
|
| 4 |
+
from multiprocessing.synchronize import Event
|
| 5 |
+
from multiprocessing.shared_memory import SharedMemory
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
from nanovllm.config import Config
|
| 9 |
+
from nanovllm.engine.sequence import Sequence
|
| 10 |
+
from nanovllm.models.qwen3 import Qwen3ForCausalLM
|
| 11 |
+
from nanovllm.layers.sampler import Sampler
|
| 12 |
+
from nanovllm.utils.context import set_context, get_context, reset_context
|
| 13 |
+
from nanovllm.utils.loader import load_model
|
| 14 |
+
|
| 15 |
+
import socket
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def find_available_port(start_port: int = 2333, max_attempts: int = 100) -> int:
|
| 19 |
+
"""Find an available port starting from start_port.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
start_port: The starting port number to check
|
| 23 |
+
max_attempts: Maximum number of ports to try
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
An available port number
|
| 27 |
+
|
| 28 |
+
Raises:
|
| 29 |
+
RuntimeError: If no available port is found within max_attempts
|
| 30 |
+
"""
|
| 31 |
+
for i in range(max_attempts):
|
| 32 |
+
port = start_port + i
|
| 33 |
+
try:
|
| 34 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
| 35 |
+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
| 36 |
+
s.bind(('localhost', port))
|
| 37 |
+
return port
|
| 38 |
+
except OSError:
|
| 39 |
+
# Port is in use, try next one
|
| 40 |
+
continue
|
| 41 |
+
raise RuntimeError(f"Could not find an available port starting from {start_port} after {max_attempts} attempts")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ModelRunner:
|
| 45 |
+
|
| 46 |
+
def __init__(self, config: Config, rank: int, event: Event | list[Event]):
|
| 47 |
+
# Enable capturing scalar outputs to avoid graph breaks from Tensor.item() calls
|
| 48 |
+
torch._dynamo.config.capture_scalar_outputs = True
|
| 49 |
+
|
| 50 |
+
self.config = config
|
| 51 |
+
hf_config = config.hf_config
|
| 52 |
+
self.block_size = config.kvcache_block_size
|
| 53 |
+
self.enforce_eager = config.enforce_eager
|
| 54 |
+
self.world_size = config.tensor_parallel_size
|
| 55 |
+
self.rank = rank
|
| 56 |
+
self.event = event
|
| 57 |
+
dist_port = find_available_port()
|
| 58 |
+
print(f"[debug]dist_port: {dist_port}")
|
| 59 |
+
# Use gloo backend on Windows, nccl on Linux/other platforms
|
| 60 |
+
backend = "gloo" if sys.platform == "win32" else "nccl"
|
| 61 |
+
dist.init_process_group(backend, f"tcp://127.0.0.1:{dist_port}", world_size=self.world_size, rank=rank)
|
| 62 |
+
torch.cuda.set_device(rank)
|
| 63 |
+
default_dtype = torch.get_default_dtype()
|
| 64 |
+
# Use dtype instead of deprecated torch_dtype
|
| 65 |
+
config_dtype = getattr(hf_config, 'dtype', getattr(hf_config, 'torch_dtype', torch.float32))
|
| 66 |
+
torch.set_default_dtype(config_dtype)
|
| 67 |
+
torch.set_default_device("cuda")
|
| 68 |
+
self.model = Qwen3ForCausalLM(hf_config)
|
| 69 |
+
load_model(self.model, config.model)
|
| 70 |
+
self.sampler = Sampler()
|
| 71 |
+
|
| 72 |
+
# Pre-allocate buffers for sampling (optimization: avoid repeated tensor creation)
|
| 73 |
+
# Must be called before warmup_model() since it uses these buffers
|
| 74 |
+
self._allocate_sample_buffers()
|
| 75 |
+
|
| 76 |
+
self.warmup_model()
|
| 77 |
+
self.allocate_kv_cache()
|
| 78 |
+
if not self.enforce_eager:
|
| 79 |
+
self.capture_cudagraph()
|
| 80 |
+
|
| 81 |
+
torch.set_default_device("cpu")
|
| 82 |
+
torch.set_default_dtype(default_dtype)
|
| 83 |
+
|
| 84 |
+
if self.world_size > 1:
|
| 85 |
+
if rank == 0:
|
| 86 |
+
self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
|
| 87 |
+
dist.barrier()
|
| 88 |
+
else:
|
| 89 |
+
dist.barrier()
|
| 90 |
+
self.shm = SharedMemory(name="nanovllm")
|
| 91 |
+
self.loop()
|
| 92 |
+
|
| 93 |
+
def _allocate_sample_buffers(self):
|
| 94 |
+
"""Pre-allocate reusable buffers for sampling to avoid repeated tensor creation."""
|
| 95 |
+
max_bs = self.config.max_num_seqs
|
| 96 |
+
max_tokens = self.config.max_num_batched_tokens
|
| 97 |
+
max_num_blocks = (self.config.max_model_len + self.block_size - 1) // self.block_size
|
| 98 |
+
|
| 99 |
+
# Pre-allocate pinned memory buffers on CPU for fast transfer
|
| 100 |
+
# Must explicitly specify device="cpu" since default device may be "cuda"
|
| 101 |
+
self._cpu_temperatures = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
|
| 102 |
+
self._cpu_cfg_scales = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
|
| 103 |
+
self._cpu_top_ks = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
|
| 104 |
+
self._cpu_top_ps = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
|
| 105 |
+
self._cpu_repetition_penalties = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
|
| 106 |
+
|
| 107 |
+
# Pre-allocate decode buffers on CPU with pinned memory
|
| 108 |
+
self._cpu_input_ids = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
|
| 109 |
+
self._cpu_positions = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
|
| 110 |
+
self._cpu_slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
|
| 111 |
+
self._cpu_context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
|
| 112 |
+
|
| 113 |
+
# Pre-allocate prefill buffers on CPU with pinned memory (optimization to avoid repeated tensor creation)
|
| 114 |
+
self._cpu_prefill_input_ids = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
|
| 115 |
+
self._cpu_prefill_positions = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
|
| 116 |
+
self._cpu_prefill_cu_seqlens = torch.zeros(max_bs + 1, dtype=torch.int32, device="cpu", pin_memory=True)
|
| 117 |
+
self._cpu_prefill_slot_mapping = torch.zeros(max_tokens, dtype=torch.int32, device="cpu", pin_memory=True)
|
| 118 |
+
|
| 119 |
+
# Pre-allocate block tables buffer (shared by both decode and prefill)
|
| 120 |
+
self._cpu_block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32, device="cpu", pin_memory=True)
|
| 121 |
+
|
| 122 |
+
# Pre-allocate buffer for sequence token IDs (used in logits processor and sampler)
|
| 123 |
+
# Max length is max_model_len since sequences can be that long
|
| 124 |
+
self._seq_token_ids_buffer = torch.zeros(max_bs, self.config.max_model_len, dtype=torch.int64, device="cpu", pin_memory=True)
|
| 125 |
+
|
| 126 |
+
def exit(self):
|
| 127 |
+
if self.world_size > 1:
|
| 128 |
+
self.shm.close()
|
| 129 |
+
dist.barrier()
|
| 130 |
+
if self.rank == 0:
|
| 131 |
+
self.shm.unlink()
|
| 132 |
+
if not self.enforce_eager:
|
| 133 |
+
del self.graphs, self.graph_pool
|
| 134 |
+
torch.cuda.synchronize()
|
| 135 |
+
dist.destroy_process_group()
|
| 136 |
+
|
| 137 |
+
def loop(self):
|
| 138 |
+
while True:
|
| 139 |
+
method_name, args = self.read_shm()
|
| 140 |
+
self.call(method_name, *args)
|
| 141 |
+
if method_name == "exit":
|
| 142 |
+
break
|
| 143 |
+
|
| 144 |
+
def read_shm(self):
|
| 145 |
+
assert self.world_size > 1 and self.rank > 0
|
| 146 |
+
self.event.wait()
|
| 147 |
+
n = int.from_bytes(self.shm.buf[0:4], "little")
|
| 148 |
+
method_name, *args = pickle.loads(self.shm.buf[4:n+4])
|
| 149 |
+
self.event.clear()
|
| 150 |
+
return method_name, args
|
| 151 |
+
|
| 152 |
+
def write_shm(self, method_name, *args):
|
| 153 |
+
assert self.world_size > 1 and self.rank == 0
|
| 154 |
+
data = pickle.dumps([method_name, *args])
|
| 155 |
+
n = len(data)
|
| 156 |
+
self.shm.buf[0:4] = n.to_bytes(4, "little")
|
| 157 |
+
self.shm.buf[4:n+4] = data
|
| 158 |
+
for event in self.event:
|
| 159 |
+
event.set()
|
| 160 |
+
|
| 161 |
+
def call(self, method_name, *args):
|
| 162 |
+
if self.world_size > 1 and self.rank == 0:
|
| 163 |
+
self.write_shm(method_name, *args)
|
| 164 |
+
method = getattr(self, method_name, None)
|
| 165 |
+
return method(*args)
|
| 166 |
+
|
| 167 |
+
def warmup_model(self):
|
| 168 |
+
torch.cuda.empty_cache()
|
| 169 |
+
torch.cuda.reset_peak_memory_stats()
|
| 170 |
+
max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
|
| 171 |
+
num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
|
| 172 |
+
seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
|
| 173 |
+
self.run(seqs, True)
|
| 174 |
+
torch.cuda.empty_cache()
|
| 175 |
+
|
| 176 |
+
def allocate_kv_cache(self):
|
| 177 |
+
config = self.config
|
| 178 |
+
hf_config = config.hf_config
|
| 179 |
+
free, total = torch.cuda.mem_get_info()
|
| 180 |
+
current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
|
| 181 |
+
num_kv_heads = hf_config.num_key_value_heads // self.world_size
|
| 182 |
+
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
|
| 183 |
+
# Use dtype instead of deprecated torch_dtype
|
| 184 |
+
config_dtype = getattr(hf_config, 'dtype', getattr(hf_config, 'torch_dtype', torch.float32))
|
| 185 |
+
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * config_dtype.itemsize
|
| 186 |
+
|
| 187 |
+
# Calculate available memory for KV cache
|
| 188 |
+
# After warmup_model, empty_cache has been called, so current represents model memory only
|
| 189 |
+
# Use free memory but respect the gpu_memory_utilization limit
|
| 190 |
+
target_total_usage = total * config.gpu_memory_utilization
|
| 191 |
+
available_for_kv_cache = min(free * 0.9, target_total_usage - current)
|
| 192 |
+
|
| 193 |
+
# Ensure we have positive memory available
|
| 194 |
+
if available_for_kv_cache <= 0:
|
| 195 |
+
available_for_kv_cache = free * 0.5 # Fallback to 50% of free memory
|
| 196 |
+
|
| 197 |
+
config.num_kvcache_blocks = max(1, int(available_for_kv_cache) // block_bytes)
|
| 198 |
+
if config.num_kvcache_blocks <= 0:
|
| 199 |
+
raise RuntimeError(
|
| 200 |
+
f"Insufficient GPU memory for KV cache. "
|
| 201 |
+
f"Free: {free / 1024**3:.2f} GB, Current: {current / 1024**3:.2f} GB, "
|
| 202 |
+
f"Available for KV: {available_for_kv_cache / 1024**3:.2f} GB, "
|
| 203 |
+
f"Block size: {block_bytes / 1024**2:.2f} MB"
|
| 204 |
+
)
|
| 205 |
+
self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
|
| 206 |
+
layer_id = 0
|
| 207 |
+
for module in self.model.modules():
|
| 208 |
+
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
|
| 209 |
+
module.k_cache = self.kv_cache[0, layer_id]
|
| 210 |
+
module.v_cache = self.kv_cache[1, layer_id]
|
| 211 |
+
layer_id += 1
|
| 212 |
+
|
| 213 |
+
def prepare_block_tables(self, seqs: list[Sequence]):
|
| 214 |
+
max_len = max(len(seq.block_table) for seq in seqs)
|
| 215 |
+
block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
|
| 216 |
+
block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
| 217 |
+
return block_tables
|
| 218 |
+
|
| 219 |
+
def prepare_prefill(self, seqs: list[Sequence]):
|
| 220 |
+
input_ids = []
|
| 221 |
+
positions = []
|
| 222 |
+
cu_seqlens_q = [0]
|
| 223 |
+
cu_seqlens_k = [0]
|
| 224 |
+
max_seqlen_q = 0
|
| 225 |
+
max_seqlen_k = 0
|
| 226 |
+
slot_mapping = []
|
| 227 |
+
block_tables = None
|
| 228 |
+
for seq in seqs:
|
| 229 |
+
seqlen = len(seq)
|
| 230 |
+
input_ids.extend(seq[seq.num_cached_tokens:])
|
| 231 |
+
positions.extend(list(range(seq.num_cached_tokens, seqlen)))
|
| 232 |
+
seqlen_q = seqlen - seq.num_cached_tokens
|
| 233 |
+
seqlen_k = seqlen
|
| 234 |
+
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
|
| 235 |
+
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
|
| 236 |
+
max_seqlen_q = max(seqlen_q, max_seqlen_q)
|
| 237 |
+
max_seqlen_k = max(seqlen_k, max_seqlen_k)
|
| 238 |
+
if not seq.block_table: # warmup
|
| 239 |
+
continue
|
| 240 |
+
for i in range(seq.num_cached_blocks, seq.num_blocks):
|
| 241 |
+
start = seq.block_table[i] * self.block_size
|
| 242 |
+
if i != seq.num_blocks - 1:
|
| 243 |
+
end = start + self.block_size
|
| 244 |
+
else:
|
| 245 |
+
end = start + seq.last_block_num_tokens
|
| 246 |
+
slot_mapping.extend(list(range(start, end)))
|
| 247 |
+
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
|
| 248 |
+
block_tables = self.prepare_block_tables(seqs)
|
| 249 |
+
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
| 250 |
+
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
| 251 |
+
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
| 252 |
+
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
| 253 |
+
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
| 254 |
+
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
|
| 255 |
+
return input_ids, positions
|
| 256 |
+
|
| 257 |
+
def prepare_decode(self, seqs: list[Sequence]):
|
| 258 |
+
"""Optimized decode preparation using pre-allocated buffers."""
|
| 259 |
+
bs = len(seqs)
|
| 260 |
+
|
| 261 |
+
# Use pre-allocated CPU buffers
|
| 262 |
+
for i, seq in enumerate(seqs):
|
| 263 |
+
self._cpu_input_ids[i] = seq.last_token
|
| 264 |
+
self._cpu_positions[i] = len(seq) - 1
|
| 265 |
+
self._cpu_context_lens[i] = len(seq)
|
| 266 |
+
self._cpu_slot_mapping[i] = seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1
|
| 267 |
+
|
| 268 |
+
# Transfer to GPU using sliced views
|
| 269 |
+
input_ids = self._cpu_input_ids[:bs].cuda(non_blocking=True)
|
| 270 |
+
positions = self._cpu_positions[:bs].cuda(non_blocking=True)
|
| 271 |
+
slot_mapping = self._cpu_slot_mapping[:bs].cuda(non_blocking=True)
|
| 272 |
+
context_lens = self._cpu_context_lens[:bs].cuda(non_blocking=True)
|
| 273 |
+
block_tables = self.prepare_block_tables(seqs)
|
| 274 |
+
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
|
| 275 |
+
return input_ids, positions
|
| 276 |
+
|
| 277 |
+
def prepare_sample(self, seqs: list[Sequence], is_cfg_batch: bool = False):
|
| 278 |
+
"""Optimized sample preparation using pre-allocated buffers."""
|
| 279 |
+
if is_cfg_batch:
|
| 280 |
+
num_seqs = len(seqs) // 2
|
| 281 |
+
target_seqs = seqs[:num_seqs]
|
| 282 |
+
else:
|
| 283 |
+
num_seqs = len(seqs)
|
| 284 |
+
target_seqs = seqs
|
| 285 |
+
|
| 286 |
+
# Fill pre-allocated CPU buffers
|
| 287 |
+
top_ks_is_zero = True
|
| 288 |
+
top_ps_is_one = True
|
| 289 |
+
repetition_penalties_is_one = True
|
| 290 |
+
for i, seq in enumerate(target_seqs):
|
| 291 |
+
self._cpu_temperatures[i] = seq.temperature
|
| 292 |
+
self._cpu_cfg_scales[i] = seq.cfg_scale
|
| 293 |
+
self._cpu_top_ks[i] = seq.top_k if seq.top_k is not None else 0
|
| 294 |
+
if seq.top_k is not None and seq.top_k > 0:
|
| 295 |
+
top_ks_is_zero = False
|
| 296 |
+
self._cpu_top_ps[i] = seq.top_p if seq.top_p is not None else 1.0
|
| 297 |
+
if seq.top_p is not None and seq.top_p == 1.0:
|
| 298 |
+
top_ps_is_one = False
|
| 299 |
+
self._cpu_repetition_penalties[i] = seq.repetition_penalty if seq.repetition_penalty is not None else 1.0
|
| 300 |
+
if seq.repetition_penalty is not None and seq.repetition_penalty == 1.0:
|
| 301 |
+
repetition_penalties_is_one = False
|
| 302 |
+
|
| 303 |
+
# Transfer to GPU using sliced views (single batched transfer)
|
| 304 |
+
temperatures = self._cpu_temperatures[:num_seqs].cuda(non_blocking=True)
|
| 305 |
+
cfg_scales = self._cpu_cfg_scales[:num_seqs].cuda(non_blocking=True)
|
| 306 |
+
top_ks = self._cpu_top_ks[:num_seqs].cuda(non_blocking=True) if not top_ks_is_zero else None
|
| 307 |
+
top_ps = self._cpu_top_ps[:num_seqs].cuda(non_blocking=True) if not top_ps_is_one else None
|
| 308 |
+
repetition_penalties = self._cpu_repetition_penalties[:num_seqs].cuda(non_blocking=True) if not repetition_penalties_is_one else None
|
| 309 |
+
|
| 310 |
+
return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties
|
| 311 |
+
|
| 312 |
+
@torch.inference_mode()
|
| 313 |
+
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
|
| 314 |
+
if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
|
| 315 |
+
return self.model.compute_logits(self.model(input_ids, positions))
|
| 316 |
+
else:
|
| 317 |
+
bs = input_ids.size(0)
|
| 318 |
+
context = get_context()
|
| 319 |
+
|
| 320 |
+
# Check if block_tables size exceeds pre-allocated buffer size
|
| 321 |
+
# This can happen when conditional and unconditional sequences have different lengths
|
| 322 |
+
# in CFG mode, causing block_tables to have more columns than expected
|
| 323 |
+
max_num_blocks = self.graph_vars["block_tables"].size(1)
|
| 324 |
+
if context.block_tables.size(1) > max_num_blocks:
|
| 325 |
+
# Fall back to eager mode when block_tables is too large for CUDA graph
|
| 326 |
+
return self.model.compute_logits(self.model(input_ids, positions))
|
| 327 |
+
|
| 328 |
+
graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
|
| 329 |
+
graph_vars = self.graph_vars
|
| 330 |
+
graph_vars["input_ids"][:bs] = input_ids
|
| 331 |
+
graph_vars["positions"][:bs] = positions
|
| 332 |
+
graph_vars["slot_mapping"].fill_(-1)
|
| 333 |
+
graph_vars["slot_mapping"][:bs] = context.slot_mapping
|
| 334 |
+
graph_vars["context_lens"].zero_()
|
| 335 |
+
graph_vars["context_lens"][:bs] = context.context_lens
|
| 336 |
+
# Clear block_tables first to ensure no stale data from previous runs
|
| 337 |
+
graph_vars["block_tables"][:bs].fill_(-1)
|
| 338 |
+
graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
|
| 339 |
+
graph.replay()
|
| 340 |
+
return self.model.compute_logits(graph_vars["outputs"][:bs])
|
| 341 |
+
|
| 342 |
+
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
|
| 343 |
+
"""Run model forward and sampling. For CFG sequences, batch is structured as:
|
| 344 |
+
[cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
|
| 345 |
+
where uncond_seqi is the paired unconditional sequence of cond_seqi."""
|
| 346 |
+
# Check if this is a CFG batch (contains paired conditional and unconditional sequences)
|
| 347 |
+
is_cfg_batch = seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None
|
| 348 |
+
if is_cfg_batch:
|
| 349 |
+
# CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
|
| 350 |
+
num_cond = len(seqs) // 2
|
| 351 |
+
cond_seqs = seqs[:num_cond]
|
| 352 |
+
# uncond_seqs = seqs[num_cond:]
|
| 353 |
+
|
| 354 |
+
# Prepare inputs for both conditional and unconditional (they're already in the batch)
|
| 355 |
+
input_ids, positions = (self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs))
|
| 356 |
+
sample_params = self.prepare_sample(seqs, is_cfg_batch=True) if self.rank == 0 else None
|
| 357 |
+
if sample_params is not None:
|
| 358 |
+
temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
|
| 359 |
+
else:
|
| 360 |
+
temperatures = cfg_scales = top_ks = top_ps = repetition_penalties = None
|
| 361 |
+
|
| 362 |
+
# Run model forward (processes entire batch: cond + uncond)
|
| 363 |
+
logits_all = self.run_model(input_ids, positions, is_prefill)
|
| 364 |
+
reset_context()
|
| 365 |
+
|
| 366 |
+
if self.rank == 0:
|
| 367 |
+
# Split logits: first half is conditional, second half is unconditional
|
| 368 |
+
logits_cond = logits_all[:num_cond]
|
| 369 |
+
logits_uncond = logits_all[num_cond:]
|
| 370 |
+
|
| 371 |
+
# Apply repetition penalty to conditional logits (before CFG)
|
| 372 |
+
if repetition_penalties is not None:
|
| 373 |
+
for i, seq in enumerate(cond_seqs):
|
| 374 |
+
penalty = repetition_penalties[i].item()
|
| 375 |
+
if penalty != 1.0:
|
| 376 |
+
# Only penalize completion tokens (not prompt tokens)
|
| 377 |
+
completion_tokens = torch.tensor(seq.completion_token_ids, device=logits_cond.device)
|
| 378 |
+
if len(completion_tokens) > 0:
|
| 379 |
+
# Create token mask: mark tokens that appeared in completion
|
| 380 |
+
token_mask = torch.zeros(logits_cond.shape[1], dtype=torch.bool, device=logits_cond.device)
|
| 381 |
+
token_mask[completion_tokens] = True
|
| 382 |
+
|
| 383 |
+
# Apply standard repetition penalty formula (matching transformers implementation):
|
| 384 |
+
# For tokens in completion: if score < 0 then score * penalty, else score / penalty
|
| 385 |
+
penalty_scores = torch.where(
|
| 386 |
+
logits_cond[i] < 0,
|
| 387 |
+
logits_cond[i] * penalty,
|
| 388 |
+
logits_cond[i] / penalty
|
| 389 |
+
)
|
| 390 |
+
# Only apply penalty to tokens that appeared in completion
|
| 391 |
+
logits_cond[i] = torch.where(token_mask, penalty_scores, logits_cond[i])
|
| 392 |
+
|
| 393 |
+
# Apply CFG formula: logits_cfg = logits_uncond + cfg_scale * (logits_cond - logits_uncond)
|
| 394 |
+
cfg_scales_tensor = cfg_scales.unsqueeze(1) # [num_cond, 1]
|
| 395 |
+
logits_cfg = logits_uncond + cfg_scales_tensor * (logits_cond - logits_uncond)
|
| 396 |
+
|
| 397 |
+
# Apply logits processor for constrained decoding (if any sequence has one)
|
| 398 |
+
for i, seq in enumerate(cond_seqs):
|
| 399 |
+
if seq.logits_processor is not None:
|
| 400 |
+
# Create input_ids tensor for this sequence
|
| 401 |
+
seq_input_ids = torch.tensor([seq.token_ids], device=logits_cfg.device)
|
| 402 |
+
# Apply processor to this sequence's logits
|
| 403 |
+
logits_cfg[i:i+1] = seq.logits_processor(seq_input_ids, logits_cfg[i:i+1])
|
| 404 |
+
|
| 405 |
+
# Prepare input_ids for sampler (for repetition penalty, though we already applied it)
|
| 406 |
+
# cond_input_ids = torch.tensor([seq.token_ids for seq in cond_seqs], device=logits_cfg.device)
|
| 407 |
+
|
| 408 |
+
# Sample from CFG logits
|
| 409 |
+
token_ids_cfg = self.sampler(
|
| 410 |
+
logits_cfg,
|
| 411 |
+
temperatures,
|
| 412 |
+
top_ks=top_ks if top_ks is not None else None,
|
| 413 |
+
top_ps=top_ps if top_ps is not None else None,
|
| 414 |
+
repetition_penalties=None, # Already applied above
|
| 415 |
+
# input_ids=cond_input_ids,
|
| 416 |
+
).tolist()
|
| 417 |
+
|
| 418 |
+
# Update logits processor state after sampling
|
| 419 |
+
for i, seq in enumerate(cond_seqs):
|
| 420 |
+
if seq.logits_processor_update_state is not None:
|
| 421 |
+
seq.logits_processor_update_state(token_ids_cfg[i])
|
| 422 |
+
|
| 423 |
+
# Return token_ids (will be applied to both conditional and unconditional sequences)
|
| 424 |
+
return token_ids_cfg
|
| 425 |
+
else:
|
| 426 |
+
return None
|
| 427 |
+
else:
|
| 428 |
+
# Normal batch (non-CFG)
|
| 429 |
+
input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
|
| 430 |
+
else self.prepare_decode(seqs))
|
| 431 |
+
sample_params = self.prepare_sample(seqs, is_cfg_batch=False) if self.rank == 0 else None
|
| 432 |
+
if sample_params is not None:
|
| 433 |
+
temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
|
| 434 |
+
else:
|
| 435 |
+
temperatures = cfg_scales = top_ks = top_ps = repetition_penalties = None
|
| 436 |
+
logits = self.run_model(input_ids, positions, is_prefill)
|
| 437 |
+
reset_context()
|
| 438 |
+
|
| 439 |
+
if self.rank == 0:
|
| 440 |
+
# Apply repetition penalty to logits
|
| 441 |
+
if repetition_penalties is not None:
|
| 442 |
+
for i, seq in enumerate(seqs):
|
| 443 |
+
penalty = repetition_penalties[i].item()
|
| 444 |
+
if penalty != 1.0:
|
| 445 |
+
# Only penalize completion tokens (not prompt tokens)
|
| 446 |
+
completion_tokens = torch.tensor(seq.completion_token_ids, device=logits.device)
|
| 447 |
+
if len(completion_tokens) > 0:
|
| 448 |
+
# Create token mask: mark tokens that appeared in completion
|
| 449 |
+
token_mask = torch.zeros(logits.shape[1], dtype=torch.bool, device=logits.device)
|
| 450 |
+
token_mask[completion_tokens] = True
|
| 451 |
+
|
| 452 |
+
# Apply standard repetition penalty formula (matching transformers implementation):
|
| 453 |
+
# For tokens in completion: if score < 0 then score * penalty, else score / penalty
|
| 454 |
+
penalty_scores = torch.where(
|
| 455 |
+
logits[i] < 0,
|
| 456 |
+
logits[i] * penalty,
|
| 457 |
+
logits[i] / penalty
|
| 458 |
+
)
|
| 459 |
+
# Only apply penalty to tokens that appeared in completion
|
| 460 |
+
logits[i] = torch.where(token_mask, penalty_scores, logits[i])
|
| 461 |
+
|
| 462 |
+
# Apply logits processor for constrained decoding (if any sequence has one)
|
| 463 |
+
# Clone logits to avoid in-place update issues in inference mode
|
| 464 |
+
logits = logits.clone()
|
| 465 |
+
for i, seq in enumerate(seqs):
|
| 466 |
+
if seq.logits_processor is not None:
|
| 467 |
+
# Create input_ids tensor for this sequence
|
| 468 |
+
seq_input_ids = torch.tensor([seq.token_ids], device=logits.device)
|
| 469 |
+
# Apply processor to this sequence's logits (clone to avoid inference mode issues)
|
| 470 |
+
processed = seq.logits_processor(seq_input_ids, logits[i:i+1].clone())
|
| 471 |
+
logits[i] = processed[0]
|
| 472 |
+
|
| 473 |
+
# Prepare input_ids for sampler
|
| 474 |
+
# seq_input_ids = torch.tensor([seq.token_ids for seq in seqs], device=logits.device)
|
| 475 |
+
|
| 476 |
+
token_ids = self.sampler(
|
| 477 |
+
logits,
|
| 478 |
+
temperatures,
|
| 479 |
+
top_ks=top_ks if top_ks is not None else None,
|
| 480 |
+
top_ps=top_ps if top_ps is not None else None,
|
| 481 |
+
repetition_penalties=None, # Already applied above
|
| 482 |
+
# input_ids=seq_input_ids,
|
| 483 |
+
).tolist()
|
| 484 |
+
|
| 485 |
+
# Update logits processor state after sampling
|
| 486 |
+
for i, seq in enumerate(seqs):
|
| 487 |
+
if seq.logits_processor_update_state is not None:
|
| 488 |
+
seq.logits_processor_update_state(token_ids[i])
|
| 489 |
+
|
| 490 |
+
return token_ids
|
| 491 |
+
else:
|
| 492 |
+
return None
|
| 493 |
+
|
| 494 |
+
@torch.inference_mode()
|
| 495 |
+
def capture_cudagraph(self):
|
| 496 |
+
config = self.config
|
| 497 |
+
hf_config = config.hf_config
|
| 498 |
+
max_bs = min(self.config.max_num_seqs, 512)
|
| 499 |
+
max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
|
| 500 |
+
input_ids = torch.zeros(max_bs, dtype=torch.int64)
|
| 501 |
+
positions = torch.zeros(max_bs, dtype=torch.int64)
|
| 502 |
+
slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
|
| 503 |
+
context_lens = torch.zeros(max_bs, dtype=torch.int32)
|
| 504 |
+
block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
|
| 505 |
+
outputs = torch.zeros(max_bs, hf_config.hidden_size)
|
| 506 |
+
self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
|
| 507 |
+
self.graphs = {}
|
| 508 |
+
self.graph_pool = None
|
| 509 |
+
|
| 510 |
+
for bs in reversed(self.graph_bs):
|
| 511 |
+
graph = torch.cuda.CUDAGraph()
|
| 512 |
+
set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
|
| 513 |
+
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
|
| 514 |
+
with torch.cuda.graph(graph, self.graph_pool):
|
| 515 |
+
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
|
| 516 |
+
if self.graph_pool is None:
|
| 517 |
+
self.graph_pool = graph.pool()
|
| 518 |
+
self.graphs[bs] = graph
|
| 519 |
+
torch.cuda.synchronize()
|
| 520 |
+
reset_context()
|
| 521 |
+
|
| 522 |
+
self.graph_vars = dict(
|
| 523 |
+
input_ids=input_ids,
|
| 524 |
+
positions=positions,
|
| 525 |
+
slot_mapping=slot_mapping,
|
| 526 |
+
context_lens=context_lens,
|
| 527 |
+
block_tables=block_tables,
|
| 528 |
+
outputs=outputs,
|
| 529 |
+
)
|
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/scheduler.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import deque
|
| 2 |
+
|
| 3 |
+
from nanovllm.config import Config
|
| 4 |
+
from nanovllm.engine.sequence import Sequence, SequenceStatus
|
| 5 |
+
from nanovllm.engine.block_manager import BlockManager
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Scheduler:
|
| 9 |
+
|
| 10 |
+
def __init__(self, config: Config):
|
| 11 |
+
self.max_num_seqs = config.max_num_seqs
|
| 12 |
+
self.max_num_batched_tokens = config.max_num_batched_tokens
|
| 13 |
+
self.eos = config.eos
|
| 14 |
+
self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
|
| 15 |
+
self.waiting: deque[Sequence] = deque()
|
| 16 |
+
self.running: deque[Sequence] = deque()
|
| 17 |
+
|
| 18 |
+
def is_finished(self):
|
| 19 |
+
return not self.waiting and not self.running
|
| 20 |
+
|
| 21 |
+
def add(self, seq: Sequence):
|
| 22 |
+
self.waiting.append(seq)
|
| 23 |
+
|
| 24 |
+
def schedule(self) -> tuple[list[Sequence], bool]:
|
| 25 |
+
# prefill
|
| 26 |
+
scheduled_seqs = []
|
| 27 |
+
num_seqs = 0
|
| 28 |
+
num_batched_tokens = 0
|
| 29 |
+
processed_seqs = set() # Track processed sequences to handle CFG pairs
|
| 30 |
+
|
| 31 |
+
while self.waiting and num_seqs < self.max_num_seqs:
|
| 32 |
+
seq = self.waiting[0]
|
| 33 |
+
|
| 34 |
+
# For CFG sequences, ensure conditional and unconditional are scheduled together
|
| 35 |
+
if seq.cfg_scale > 1.0 and seq.paired_seq is not None and not seq.is_unconditional:
|
| 36 |
+
# This is a conditional sequence, need to schedule its paired unconditional sequence too
|
| 37 |
+
paired_seq = seq.paired_seq
|
| 38 |
+
if paired_seq.status != SequenceStatus.WAITING:
|
| 39 |
+
# Paired sequence not in waiting, skip this conditional sequence for now
|
| 40 |
+
break
|
| 41 |
+
|
| 42 |
+
# Calculate tokens for both sequences
|
| 43 |
+
total_tokens = (len(seq) - seq.num_cached_tokens) + (len(paired_seq) - paired_seq.num_cached_tokens)
|
| 44 |
+
can_allocate_both = (self.block_manager.can_allocate(seq) and
|
| 45 |
+
self.block_manager.can_allocate(paired_seq))
|
| 46 |
+
|
| 47 |
+
if num_batched_tokens + total_tokens > self.max_num_batched_tokens or not can_allocate_both:
|
| 48 |
+
break
|
| 49 |
+
|
| 50 |
+
# Schedule both sequences: conditional first, then unconditional
|
| 51 |
+
for s in [seq, paired_seq]:
|
| 52 |
+
num_seqs += 1
|
| 53 |
+
self.block_manager.allocate(s)
|
| 54 |
+
num_batched_tokens += len(s) - s.num_cached_tokens
|
| 55 |
+
s.status = SequenceStatus.RUNNING
|
| 56 |
+
self.waiting.remove(s)
|
| 57 |
+
self.running.append(s)
|
| 58 |
+
scheduled_seqs.append(s)
|
| 59 |
+
processed_seqs.add(s.seq_id)
|
| 60 |
+
else:
|
| 61 |
+
# Normal sequence or unconditional sequence (already processed with its conditional)
|
| 62 |
+
if seq.seq_id in processed_seqs:
|
| 63 |
+
# Skip if already processed as part of a CFG pair
|
| 64 |
+
self.waiting.popleft()
|
| 65 |
+
continue
|
| 66 |
+
|
| 67 |
+
if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
|
| 68 |
+
break
|
| 69 |
+
num_seqs += 1
|
| 70 |
+
self.block_manager.allocate(seq)
|
| 71 |
+
num_batched_tokens += len(seq) - seq.num_cached_tokens
|
| 72 |
+
seq.status = SequenceStatus.RUNNING
|
| 73 |
+
self.waiting.popleft()
|
| 74 |
+
self.running.append(seq)
|
| 75 |
+
scheduled_seqs.append(seq)
|
| 76 |
+
|
| 77 |
+
if scheduled_seqs:
|
| 78 |
+
# For CFG batches, ensure conditional sequences come before their unconditional pairs
|
| 79 |
+
cfg_cond_seqs = [s for s in scheduled_seqs if s.cfg_scale > 1.0 and not s.is_unconditional]
|
| 80 |
+
cfg_uncond_seqs = [s for s in scheduled_seqs if s.is_unconditional]
|
| 81 |
+
non_cfg_seqs = [s for s in scheduled_seqs if s.cfg_scale <= 1.0]
|
| 82 |
+
|
| 83 |
+
# Reorder: non-CFG, then CFG conditional, then CFG unconditional
|
| 84 |
+
scheduled_seqs = non_cfg_seqs + cfg_cond_seqs + cfg_uncond_seqs
|
| 85 |
+
return scheduled_seqs, True
|
| 86 |
+
|
| 87 |
+
# decode
|
| 88 |
+
processed_seqs = set()
|
| 89 |
+
temp_running = list(self.running) # Work with a copy
|
| 90 |
+
|
| 91 |
+
while temp_running and num_seqs < self.max_num_seqs:
|
| 92 |
+
seq = temp_running.pop(0)
|
| 93 |
+
|
| 94 |
+
# For CFG sequences, ensure conditional and unconditional are scheduled together
|
| 95 |
+
if seq.cfg_scale > 1.0 and seq.paired_seq is not None and not seq.is_unconditional:
|
| 96 |
+
paired_seq = seq.paired_seq
|
| 97 |
+
if paired_seq not in temp_running:
|
| 98 |
+
# Paired sequence not available, skip for now
|
| 99 |
+
continue
|
| 100 |
+
|
| 101 |
+
# Remove paired_seq from temp_running
|
| 102 |
+
temp_running.remove(paired_seq)
|
| 103 |
+
|
| 104 |
+
# Check if both can append
|
| 105 |
+
can_append_both = (self.block_manager.can_append(seq) and
|
| 106 |
+
self.block_manager.can_append(paired_seq))
|
| 107 |
+
|
| 108 |
+
if not can_append_both:
|
| 109 |
+
# Try preempting other sequences
|
| 110 |
+
preempted = False
|
| 111 |
+
while not can_append_both and temp_running:
|
| 112 |
+
other_seq = temp_running.pop(0)
|
| 113 |
+
if other_seq != seq and other_seq != paired_seq:
|
| 114 |
+
self.preempt(other_seq)
|
| 115 |
+
can_append_both = (self.block_manager.can_append(seq) and
|
| 116 |
+
self.block_manager.can_append(paired_seq))
|
| 117 |
+
preempted = True
|
| 118 |
+
else:
|
| 119 |
+
temp_running.append(other_seq)
|
| 120 |
+
break
|
| 121 |
+
|
| 122 |
+
if not can_append_both:
|
| 123 |
+
# Can't schedule this pair right now
|
| 124 |
+
temp_running.append(seq)
|
| 125 |
+
temp_running.append(paired_seq)
|
| 126 |
+
continue
|
| 127 |
+
|
| 128 |
+
# Schedule both sequences
|
| 129 |
+
for s in [seq, paired_seq]:
|
| 130 |
+
num_seqs += 1
|
| 131 |
+
self.block_manager.may_append(s)
|
| 132 |
+
scheduled_seqs.append(s)
|
| 133 |
+
processed_seqs.add(s.seq_id)
|
| 134 |
+
# Remove from actual running list if scheduled
|
| 135 |
+
if s in self.running:
|
| 136 |
+
self.running.remove(s)
|
| 137 |
+
else:
|
| 138 |
+
# Normal sequence or unconditional (already processed)
|
| 139 |
+
if seq.seq_id in processed_seqs:
|
| 140 |
+
continue
|
| 141 |
+
|
| 142 |
+
while not self.block_manager.can_append(seq):
|
| 143 |
+
if temp_running:
|
| 144 |
+
other_seq = temp_running.pop(0)
|
| 145 |
+
if other_seq != seq:
|
| 146 |
+
self.preempt(other_seq)
|
| 147 |
+
else:
|
| 148 |
+
temp_running.append(other_seq)
|
| 149 |
+
break
|
| 150 |
+
else:
|
| 151 |
+
self.preempt(seq)
|
| 152 |
+
if seq in self.running:
|
| 153 |
+
self.running.remove(seq)
|
| 154 |
+
break
|
| 155 |
+
else:
|
| 156 |
+
num_seqs += 1
|
| 157 |
+
self.block_manager.may_append(seq)
|
| 158 |
+
scheduled_seqs.append(seq)
|
| 159 |
+
if seq in self.running:
|
| 160 |
+
self.running.remove(seq)
|
| 161 |
+
|
| 162 |
+
assert scheduled_seqs
|
| 163 |
+
|
| 164 |
+
# For CFG batches in decode, ensure conditional sequences come before unconditional
|
| 165 |
+
cfg_cond_seqs = [s for s in scheduled_seqs if s.cfg_scale > 1.0 and not s.is_unconditional]
|
| 166 |
+
cfg_uncond_seqs = [s for s in scheduled_seqs if s.is_unconditional]
|
| 167 |
+
non_cfg_seqs = [s for s in scheduled_seqs if s.cfg_scale <= 1.0]
|
| 168 |
+
scheduled_seqs = non_cfg_seqs + cfg_cond_seqs + cfg_uncond_seqs
|
| 169 |
+
|
| 170 |
+
self.running.extendleft(reversed(scheduled_seqs))
|
| 171 |
+
return scheduled_seqs, False
|
| 172 |
+
|
| 173 |
+
def preempt(self, seq: Sequence):
|
| 174 |
+
seq.status = SequenceStatus.WAITING
|
| 175 |
+
self.block_manager.deallocate(seq)
|
| 176 |
+
self.waiting.appendleft(seq)
|
| 177 |
+
|
| 178 |
+
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
|
| 179 |
+
# Check if this is a CFG batch
|
| 180 |
+
is_cfg_batch = False
|
| 181 |
+
if len(seqs) > 0 and seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None:
|
| 182 |
+
num_cond = len(seqs) // 2
|
| 183 |
+
is_cfg_batch = (num_cond > 0 and
|
| 184 |
+
not seqs[0].is_unconditional and
|
| 185 |
+
seqs[num_cond].is_unconditional)
|
| 186 |
+
|
| 187 |
+
if is_cfg_batch:
|
| 188 |
+
# CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
|
| 189 |
+
# token_ids correspond to conditional sequences only (sampled from CFG logits)
|
| 190 |
+
num_cond = len(seqs) // 2
|
| 191 |
+
cond_seqs = seqs[:num_cond]
|
| 192 |
+
uncond_seqs = seqs[num_cond:]
|
| 193 |
+
|
| 194 |
+
# Apply the same sampled token to both conditional and unconditional sequences
|
| 195 |
+
for i, (cond_seq, uncond_seq, token_id) in enumerate(zip(cond_seqs, uncond_seqs, token_ids)):
|
| 196 |
+
cond_seq.append_token(token_id)
|
| 197 |
+
uncond_seq.append_token(token_id) # Same token for unconditional
|
| 198 |
+
|
| 199 |
+
# Check if either sequence is finished
|
| 200 |
+
cond_finished = ((not cond_seq.ignore_eos and token_id == self.eos) or
|
| 201 |
+
cond_seq.num_completion_tokens == cond_seq.max_tokens)
|
| 202 |
+
uncond_finished = ((not uncond_seq.ignore_eos and token_id == self.eos) or
|
| 203 |
+
uncond_seq.num_completion_tokens == uncond_seq.max_tokens)
|
| 204 |
+
|
| 205 |
+
if cond_finished or uncond_finished:
|
| 206 |
+
# Mark both as finished
|
| 207 |
+
cond_seq.status = SequenceStatus.FINISHED
|
| 208 |
+
uncond_seq.status = SequenceStatus.FINISHED
|
| 209 |
+
self.block_manager.deallocate(cond_seq)
|
| 210 |
+
self.block_manager.deallocate(uncond_seq)
|
| 211 |
+
if cond_seq in self.running:
|
| 212 |
+
self.running.remove(cond_seq)
|
| 213 |
+
if uncond_seq in self.running:
|
| 214 |
+
self.running.remove(uncond_seq)
|
| 215 |
+
else:
|
| 216 |
+
# Normal batch
|
| 217 |
+
for seq, token_id in zip(seqs, token_ids):
|
| 218 |
+
seq.append_token(token_id)
|
| 219 |
+
if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
|
| 220 |
+
seq.status = SequenceStatus.FINISHED
|
| 221 |
+
self.block_manager.deallocate(seq)
|
| 222 |
+
self.running.remove(seq)
|
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/engine/sequence.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import copy
|
| 2 |
+
from enum import Enum, auto
|
| 3 |
+
from itertools import count
|
| 4 |
+
from typing import Optional, Callable, Any
|
| 5 |
+
|
| 6 |
+
from nanovllm.sampling_params import SamplingParams
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SequenceStatus(Enum):
|
| 10 |
+
WAITING = auto()
|
| 11 |
+
RUNNING = auto()
|
| 12 |
+
FINISHED = auto()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Sequence:
|
| 16 |
+
block_size = 256
|
| 17 |
+
counter = count()
|
| 18 |
+
|
| 19 |
+
def __init__(self, token_ids: list[int], sampling_params = SamplingParams(), is_unconditional: bool = False, conditional_seq = None):
|
| 20 |
+
self.seq_id = next(Sequence.counter)
|
| 21 |
+
self.status = SequenceStatus.WAITING
|
| 22 |
+
self.token_ids = copy(token_ids)
|
| 23 |
+
self.last_token = token_ids[-1]
|
| 24 |
+
self.num_tokens = len(self.token_ids)
|
| 25 |
+
self.num_prompt_tokens = len(token_ids)
|
| 26 |
+
self.num_cached_tokens = 0
|
| 27 |
+
self.block_table = []
|
| 28 |
+
self.temperature = sampling_params.temperature
|
| 29 |
+
self.max_tokens = sampling_params.max_tokens
|
| 30 |
+
self.ignore_eos = sampling_params.ignore_eos
|
| 31 |
+
self.cfg_scale = sampling_params.cfg_scale
|
| 32 |
+
self.top_k = sampling_params.top_k
|
| 33 |
+
self.top_p = sampling_params.top_p
|
| 34 |
+
self.repetition_penalty = sampling_params.repetition_penalty
|
| 35 |
+
# For CFG: mark if this is an unconditional sequence
|
| 36 |
+
self.is_unconditional = is_unconditional
|
| 37 |
+
# For CFG: reference to the corresponding conditional sequence (if this is unconditional)
|
| 38 |
+
# For conditional sequences, this points to the unconditional sequence
|
| 39 |
+
self.paired_seq = conditional_seq # For conditional seq, points to uncond; for uncond seq, points to cond
|
| 40 |
+
# For constrained decoding: logits processor and state update callback
|
| 41 |
+
self.logits_processor: Optional[Any] = sampling_params.logits_processor
|
| 42 |
+
self.logits_processor_update_state: Optional[Callable[[int], None]] = sampling_params.logits_processor_update_state
|
| 43 |
+
|
| 44 |
+
def __len__(self):
|
| 45 |
+
return self.num_tokens
|
| 46 |
+
|
| 47 |
+
def __getitem__(self, key):
|
| 48 |
+
return self.token_ids[key]
|
| 49 |
+
|
| 50 |
+
@property
|
| 51 |
+
def is_finished(self):
|
| 52 |
+
return self.status == SequenceStatus.FINISHED
|
| 53 |
+
|
| 54 |
+
@property
|
| 55 |
+
def num_completion_tokens(self):
|
| 56 |
+
return self.num_tokens - self.num_prompt_tokens
|
| 57 |
+
|
| 58 |
+
@property
|
| 59 |
+
def prompt_token_ids(self):
|
| 60 |
+
return self.token_ids[:self.num_prompt_tokens]
|
| 61 |
+
|
| 62 |
+
@property
|
| 63 |
+
def completion_token_ids(self):
|
| 64 |
+
return self.token_ids[self.num_prompt_tokens:]
|
| 65 |
+
|
| 66 |
+
@property
|
| 67 |
+
def num_cached_blocks(self):
|
| 68 |
+
return self.num_cached_tokens // self.block_size
|
| 69 |
+
|
| 70 |
+
@property
|
| 71 |
+
def num_blocks(self):
|
| 72 |
+
return (self.num_tokens + self.block_size - 1) // self.block_size
|
| 73 |
+
|
| 74 |
+
@property
|
| 75 |
+
def last_block_num_tokens(self):
|
| 76 |
+
return self.num_tokens - (self.num_blocks - 1) * self.block_size
|
| 77 |
+
|
| 78 |
+
def block(self, i):
|
| 79 |
+
assert 0 <= i < self.num_blocks
|
| 80 |
+
return self.token_ids[i*self.block_size: (i+1)*self.block_size]
|
| 81 |
+
|
| 82 |
+
def append_token(self, token_id: int):
|
| 83 |
+
self.token_ids.append(token_id)
|
| 84 |
+
self.last_token = token_id
|
| 85 |
+
self.num_tokens += 1
|
| 86 |
+
|
| 87 |
+
def __getstate__(self):
|
| 88 |
+
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
|
| 89 |
+
self.token_ids if self.num_completion_tokens == 0 else self.last_token)
|
| 90 |
+
|
| 91 |
+
def __setstate__(self, state):
|
| 92 |
+
self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
|
| 93 |
+
if self.num_completion_tokens == 0:
|
| 94 |
+
self.token_ids = state[-1]
|
| 95 |
+
else:
|
| 96 |
+
self.last_token = state[-1]
|
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/layers/activation.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SiluAndMul(nn.Module):
|
| 7 |
+
|
| 8 |
+
def __init__(self):
|
| 9 |
+
super().__init__()
|
| 10 |
+
|
| 11 |
+
@torch.compile
|
| 12 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 13 |
+
x, y = x.chunk(2, -1)
|
| 14 |
+
return F.silu(x) * y
|
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/layers/attention.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import triton
|
| 4 |
+
import triton.language as tl
|
| 5 |
+
|
| 6 |
+
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
| 7 |
+
from nanovllm.utils.context import get_context
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@triton.jit
|
| 11 |
+
def store_kvcache_kernel(
|
| 12 |
+
key_ptr,
|
| 13 |
+
key_stride,
|
| 14 |
+
value_ptr,
|
| 15 |
+
value_stride,
|
| 16 |
+
k_cache_ptr,
|
| 17 |
+
v_cache_ptr,
|
| 18 |
+
slot_mapping_ptr,
|
| 19 |
+
D: tl.constexpr,
|
| 20 |
+
):
|
| 21 |
+
idx = tl.program_id(0)
|
| 22 |
+
slot = tl.load(slot_mapping_ptr + idx)
|
| 23 |
+
if slot == -1: return
|
| 24 |
+
key_offsets = idx * key_stride + tl.arange(0, D)
|
| 25 |
+
value_offsets = idx * value_stride + tl.arange(0, D)
|
| 26 |
+
key = tl.load(key_ptr + key_offsets)
|
| 27 |
+
value = tl.load(value_ptr + value_offsets)
|
| 28 |
+
cache_offsets = slot * D + tl.arange(0, D)
|
| 29 |
+
tl.store(k_cache_ptr + cache_offsets, key)
|
| 30 |
+
tl.store(v_cache_ptr + cache_offsets, value)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
|
| 34 |
+
N, num_heads, head_dim = key.shape
|
| 35 |
+
D = num_heads * head_dim
|
| 36 |
+
assert key.stride(-1) == 1 and value.stride(-1) == 1
|
| 37 |
+
assert key.stride(1) == head_dim and value.stride(1) == head_dim
|
| 38 |
+
assert k_cache.stride(1) == D and v_cache.stride(1) == D
|
| 39 |
+
assert slot_mapping.numel() == N
|
| 40 |
+
store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class Attention(nn.Module):
|
| 44 |
+
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
num_heads,
|
| 48 |
+
head_dim,
|
| 49 |
+
scale,
|
| 50 |
+
num_kv_heads,
|
| 51 |
+
):
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.num_heads = num_heads
|
| 54 |
+
self.head_dim = head_dim
|
| 55 |
+
self.scale = scale
|
| 56 |
+
self.num_kv_heads = num_kv_heads
|
| 57 |
+
self.k_cache = self.v_cache = torch.tensor([])
|
| 58 |
+
|
| 59 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
| 60 |
+
context = get_context()
|
| 61 |
+
k_cache, v_cache = self.k_cache, self.v_cache
|
| 62 |
+
if k_cache.numel() and v_cache.numel():
|
| 63 |
+
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
| 64 |
+
if context.is_prefill:
|
| 65 |
+
if context.block_tables is not None: # prefix cache
|
| 66 |
+
k, v = k_cache, v_cache
|
| 67 |
+
o = flash_attn_varlen_func(q, k, v,
|
| 68 |
+
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
| 69 |
+
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
| 70 |
+
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
| 71 |
+
else: # decode
|
| 72 |
+
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
|
| 73 |
+
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
| 74 |
+
softmax_scale=self.scale, causal=True)
|
| 75 |
+
return o
|
spaces/Ace-Step-v1.5/acestep/third_parts/nano-vllm/nanovllm/layers/embed_head.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
|
| 6 |
+
from nanovllm.utils.context import get_context
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class VocabParallelEmbedding(nn.Module):
|
| 10 |
+
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
num_embeddings: int,
|
| 14 |
+
embedding_dim: int,
|
| 15 |
+
):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.tp_rank = dist.get_rank()
|
| 18 |
+
self.tp_size = dist.get_world_size()
|
| 19 |
+
assert num_embeddings % self.tp_size == 0
|
| 20 |
+
self.num_embeddings = num_embeddings
|
| 21 |
+
self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
|
| 22 |
+
self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
|
| 23 |
+
self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
|
| 24 |
+
self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
|
| 25 |
+
self.weight.weight_loader = self.weight_loader
|
| 26 |
+
|
| 27 |
+
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
| 28 |
+
param_data = param.data
|
| 29 |
+
shard_size = param_data.size(0)
|
| 30 |
+
start_idx = self.tp_rank * shard_size
|
| 31 |
+
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
|
| 32 |
+
param_data.copy_(loaded_weight)
|
| 33 |
+
|
| 34 |
+
def forward(self, x: torch.Tensor):
|
| 35 |
+
if self.tp_size > 1:
|
| 36 |
+
mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
|
| 37 |
+
x = mask * (x - self.vocab_start_idx)
|
| 38 |
+
y = F.embedding(x, self.weight)
|
| 39 |
+
if self.tp_size > 1:
|
| 40 |
+
y = mask.unsqueeze(1) * y
|
| 41 |
+
dist.all_reduce(y)
|
| 42 |
+
return y
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class ParallelLMHead(VocabParallelEmbedding):
|
| 46 |
+
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
num_embeddings: int,
|
| 50 |
+
embedding_dim: int,
|
| 51 |
+
bias: bool = False,
|
| 52 |
+
):
|
| 53 |
+
assert not bias
|
| 54 |
+
super().__init__(num_embeddings, embedding_dim)
|
| 55 |
+
|
| 56 |
+
def forward(self, x: torch.Tensor):
|
| 57 |
+
context = get_context()
|
| 58 |
+
if context.is_prefill:
|
| 59 |
+
last_indices = context.cu_seqlens_q[1:] - 1
|
| 60 |
+
x = x[last_indices].contiguous()
|
| 61 |
+
logits = F.linear(x, self.weight)
|
| 62 |
+
if self.tp_size > 1:
|
| 63 |
+
all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
|
| 64 |
+
dist.gather(logits, all_logits, 0)
|
| 65 |
+
logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
|
| 66 |
+
return logits
|