Spaces:
Running on Zero
Running on Zero
ACE-Step Custom committed on
Commit ·
a602628
0
Parent(s):
Deploy ACE-Step Custom Edition with bug fixes
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +3 -0
- .gitignore +67 -0
- .python-version +1 -0
- DEPLOYMENT.md +367 -0
- DEPLOYMENT_CHECKLIST.txt +226 -0
- DEPLOY_QUICK.md +126 -0
- Dockerfile +35 -0
- LICENSE +28 -0
- QUICKSTART.md +115 -0
- README.md +73 -0
- README_HF.md +73 -0
- README_PROJECT.md +116 -0
- acestep/__init__.py +1 -0
- acestep/acestep_v15_pipeline.py +411 -0
- acestep/api_server.py +0 -0
- acestep/audio_utils.py +354 -0
- acestep/constants.py +193 -0
- acestep/constrained_logits_processor.py +0 -0
- acestep/dataset_handler.py +83 -0
- acestep/debug_utils.py +122 -0
- acestep/dit_alignment_score.py +877 -0
- acestep/genres_vocab.txt +0 -0
- acestep/gpu_config.py +549 -0
- acestep/gradio_ui/__init__.py +1 -0
- acestep/gradio_ui/api_routes.py +564 -0
- acestep/gradio_ui/events/__init__.py +1254 -0
- acestep/gradio_ui/events/generation_handlers.py +1050 -0
- acestep/gradio_ui/events/results_handlers.py +0 -0
- acestep/gradio_ui/events/training_handlers.py +829 -0
- acestep/gradio_ui/i18n.py +152 -0
- acestep/gradio_ui/i18n/en.json +354 -0
- acestep/gradio_ui/i18n/he.json +352 -0
- acestep/gradio_ui/i18n/ja.json +354 -0
- acestep/gradio_ui/i18n/zh.json +350 -0
- acestep/gradio_ui/interfaces/__init__.py +94 -0
- acestep/gradio_ui/interfaces/dataset.py +101 -0
- acestep/gradio_ui/interfaces/generation.py +824 -0
- acestep/gradio_ui/interfaces/result.py +552 -0
- acestep/gradio_ui/interfaces/training.py +625 -0
- acestep/handler.py +0 -0
- acestep/inference.py +1310 -0
- acestep/llm_inference.py +0 -0
- acestep/local_cache.py +129 -0
- acestep/model_downloader.py +634 -0
- acestep/openrouter_adapter.py +773 -0
- acestep/openrouter_models.py +244 -0
- acestep/test_time_scaling.py +410 -0
- acestep/third_parts/nano-vllm/LICENSE +21 -0
- acestep/third_parts/nano-vllm/README.md +66 -0
- acestep/third_parts/nano-vllm/bench.py +32 -0
.gitattributes
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
build/
|
| 8 |
+
develop-eggs/
|
| 9 |
+
dist/
|
| 10 |
+
downloads/
|
| 11 |
+
eggs/
|
| 12 |
+
.eggs/
|
| 13 |
+
lib/
|
| 14 |
+
lib64/
|
| 15 |
+
parts/
|
| 16 |
+
sdist/
|
| 17 |
+
var/
|
| 18 |
+
wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
|
| 23 |
+
# Virtual environments
|
| 24 |
+
venv/
|
| 25 |
+
env/
|
| 26 |
+
ENV/
|
| 27 |
+
|
| 28 |
+
# IDE
|
| 29 |
+
.vscode/
|
| 30 |
+
.idea/
|
| 31 |
+
*.swp
|
| 32 |
+
*.swo
|
| 33 |
+
*~
|
| 34 |
+
|
| 35 |
+
# OS
|
| 36 |
+
.DS_Store
|
| 37 |
+
Thumbs.db
|
| 38 |
+
|
| 39 |
+
# Application
|
| 40 |
+
outputs/
|
| 41 |
+
timelines/
|
| 42 |
+
lora_training/prepared_data/
|
| 43 |
+
lora_training/models/
|
| 44 |
+
logs/
|
| 45 |
+
models/
|
| 46 |
+
*.wav
|
| 47 |
+
*.mp3
|
| 48 |
+
*.flac
|
| 49 |
+
*.ogg
|
| 50 |
+
|
| 51 |
+
# Config (keep example)
|
| 52 |
+
config.yaml
|
| 53 |
+
|
| 54 |
+
# Jupyter
|
| 55 |
+
.ipynb_checkpoints/
|
| 56 |
+
*.ipynb
|
| 57 |
+
|
| 58 |
+
# Model cache
|
| 59 |
+
.cache/
|
| 60 |
+
huggingface/
|
| 61 |
+
|
| 62 |
+
# Environment
|
| 63 |
+
.env
|
| 64 |
+
.env.local
|
| 65 |
+
|
| 66 |
+
# Test outputs
|
| 67 |
+
test_outputs/
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
python-3.11.0
|
DEPLOYMENT.md
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# HuggingFace Spaces Deployment Guide
|
| 2 |
+
|
| 3 |
+
## Quick Deploy to HuggingFace Spaces
|
| 4 |
+
|
| 5 |
+
### Prerequisites
|
| 6 |
+
- HuggingFace account (create at https://huggingface.co/join)
|
| 7 |
+
- Git installed on your machine
|
| 8 |
+
- Git LFS installed (for large files)
|
| 9 |
+
|
| 10 |
+
### Method 1: Web Upload (Easiest)
|
| 11 |
+
|
| 12 |
+
1. **Create New Space**
|
| 13 |
+
- Go to https://huggingface.co/new-space
|
| 14 |
+
- Name: `ace-step-custom` (or your choice)
|
| 15 |
+
- License: MIT
|
| 16 |
+
- SDK: Gradio
|
| 17 |
+
- Hardware: A10G Small (or better)
|
| 18 |
+
- Click "Create Space"
|
| 19 |
+
|
| 20 |
+
2. **Upload Files**
|
| 21 |
+
- Click "Files and versions" tab
|
| 22 |
+
- Click "Add file" → "Upload files"
|
| 23 |
+
- Upload all files from `d:\2025-vibe-coding\ACE-Step-Custom\`:
|
| 24 |
+
- `app.py`
|
| 25 |
+
- `requirements.txt`
|
| 26 |
+
- `config.yaml`
|
| 27 |
+
- `README.md` (with YAML frontmatter)
|
| 28 |
+
- `LICENSE`
|
| 29 |
+
- `.gitignore`
|
| 30 |
+
- Entire `src/` directory
|
| 31 |
+
- Entire `scripts/` directory
|
| 32 |
+
- Commit changes
|
| 33 |
+
|
| 34 |
+
3. **Configure Space**
|
| 35 |
+
- Go to "Settings" tab
|
| 36 |
+
- Set Python version: 3.10 (the Spaces default; to pin it, set `python_version` in the README.md YAML frontmatter)
|
| 37 |
+
- Enable GPU: A10G Small (minimum) or A100 (recommended)
|
| 38 |
+
- Set timeout: 30 minutes (for long generations)
|
| 39 |
+
- Save settings
|
| 40 |
+
|
| 41 |
+
4. **Wait for Build**
|
| 42 |
+
- Space will automatically build and deploy
|
| 43 |
+
- First build takes 5-10 minutes
|
| 44 |
+
- Model will download on first run (~7GB)
|
| 45 |
+
|
| 46 |
+
### Method 2: Git Push (For Developers)
|
| 47 |
+
|
| 48 |
+
1. **Create Space on HuggingFace**
|
| 49 |
+
- Go to https://huggingface.co/new-space
|
| 50 |
+
- Create space as above
|
| 51 |
+
|
| 52 |
+
2. **Clone and Push**
|
| 53 |
+
```powershell
|
| 54 |
+
# Navigate to project
|
| 55 |
+
cd d:\2025-vibe-coding\ACE-Step-Custom
|
| 56 |
+
|
| 57 |
+
# Initialize git (if not already)
|
| 58 |
+
git init
|
| 59 |
+
git add .
|
| 60 |
+
git commit -m "Initial commit"
|
| 61 |
+
|
| 62 |
+
# Add HuggingFace remote
|
| 63 |
+
git remote add hf https://huggingface.co/spaces/YOUR_USERNAME/ace-step-custom
|
| 64 |
+
|
| 65 |
+
# Push to HuggingFace
|
| 66 |
+
git push hf main
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
3. **Configure Git LFS for Large Files**
|
| 70 |
+
```powershell
|
| 71 |
+
git lfs install
|
| 72 |
+
git lfs track "*.wav"
|
| 73 |
+
git lfs track "*.pth"
|
| 74 |
+
git lfs track "*.bin"
|
| 75 |
+
git lfs track "models/**"
|
| 76 |
+
git add .gitattributes
|
| 77 |
+
git commit -m "Add LFS tracking"
|
| 78 |
+
git push hf main
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
### Method 3: HuggingFace CLI (Fastest)
|
| 82 |
+
|
| 83 |
+
1. **Install HuggingFace CLI**
|
| 84 |
+
```powershell
|
| 85 |
+
pip install huggingface_hub
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
2. **Login**
|
| 89 |
+
```powershell
|
| 90 |
+
huggingface-cli login
|
| 91 |
+
# Enter your HuggingFace token
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
3. **Create and Upload**
|
| 95 |
+
```powershell
|
| 96 |
+
cd d:\2025-vibe-coding\ACE-Step-Custom
|
| 97 |
+
|
| 98 |
+
# Create space
|
| 99 |
+
huggingface-cli repo create ace-step-custom --type space --space_sdk gradio
|
| 100 |
+
|
| 101 |
+
# Upload files
|
| 102 |
+
huggingface-cli upload YOUR_USERNAME/ace-step-custom . --repo-type space
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
## Space Configuration
|
| 106 |
+
|
| 107 |
+
### Hardware Recommendations
|
| 108 |
+
|
| 109 |
+
| GPU | VRAM | Cost | Performance | Recommended For |
|
| 110 |
+
|-----|------|------|-------------|-----------------|
|
| 111 |
+
| CPU | - | Free | Very Slow | Testing only |
|
| 112 |
+
| T4 Small | 16GB | ~$0.60/hr | Slow | Light testing |
|
| 113 |
+
| **A10G Small** | **24GB** | **~$1.05/hr** | **Good** | **Recommended** |
|
| 114 |
+
| A10G Large | 24GB | ~$3.15/hr | Good | Production |
|
| 115 |
+
| A100 Large | 40GB | ~$4.13/hr | Excellent | Best quality |
|
| 116 |
+
|
| 117 |
+
**Recommendation:** Start with A10G Small for testing, upgrade to A100 for production.
|
| 118 |
+
|
| 119 |
+
### Environment Variables (Optional)
|
| 120 |
+
|
| 121 |
+
In Space settings, you can add:
|
| 122 |
+
|
| 123 |
+
```
|
| 124 |
+
GRADIO_SERVER_NAME=0.0.0.0
|
| 125 |
+
GRADIO_SERVER_PORT=7860
|
| 126 |
+
HF_HOME=/data/huggingface
|
| 127 |
+
TORCH_HOME=/data/torch
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
### Secrets (If Needed)
|
| 131 |
+
|
| 132 |
+
For API keys or sensitive data:
|
| 133 |
+
- Go to Space Settings → Repository secrets
|
| 134 |
+
- Add secrets like `HF_TOKEN`, `API_KEY`, etc.
|
| 135 |
+
- Access in code: `os.environ.get("SECRET_NAME")`
|
| 136 |
+
|
| 137 |
+
## Post-Deployment Setup
|
| 138 |
+
|
| 139 |
+
### First Launch
|
| 140 |
+
|
| 141 |
+
1. **Wait for Model Download**
|
| 142 |
+
- First launch downloads ACE-Step model (~7GB)
|
| 143 |
+
- Takes 5-10 minutes depending on connection
|
| 144 |
+
- Model cached for subsequent runs
|
| 145 |
+
|
| 146 |
+
2. **Test Basic Generation**
|
| 147 |
+
- Go to Tab 1 (Standard ACE-Step)
|
| 148 |
+
- Enter simple prompt: "Happy pop song"
|
| 149 |
+
- Set duration to 10 seconds
|
| 150 |
+
- Click Generate
|
| 151 |
+
|
| 152 |
+
3. **Test Timeline**
|
| 153 |
+
- Go to Tab 2 (Timeline Workflow)
|
| 154 |
+
- Enter prompt and lyrics
|
| 155 |
+
- Set context length to 30s
|
| 156 |
+
- Generate first clip
|
| 157 |
+
|
| 158 |
+
4. **Test LoRA Training**
|
| 159 |
+
- Go to Tab 3 (LoRA Training)
|
| 160 |
+
- Upload 2-3 test audio files
|
| 161 |
+
- Run quick training (2-3 epochs)
|
| 162 |
+
|
| 163 |
+
### Monitoring
|
| 164 |
+
|
| 165 |
+
**View Logs:**
|
| 166 |
+
- Click "Logs" tab in your Space
|
| 167 |
+
- Monitor for errors or warnings
|
| 168 |
+
- Check GPU usage and memory
|
| 169 |
+
|
| 170 |
+
**Performance Metrics:**
|
| 171 |
+
- Generation time
|
| 172 |
+
- Memory usage
|
| 173 |
+
- Error rate
|
| 174 |
+
- User feedback
|
| 175 |
+
|
| 176 |
+
### Troubleshooting
|
| 177 |
+
|
| 178 |
+
**Space Not Building:**
|
| 179 |
+
- Check requirements.txt for conflicts
|
| 180 |
+
- Verify Python 3.10 compatibility
|
| 181 |
+
- Check logs for specific errors
|
| 182 |
+
|
| 183 |
+
**Out of Memory:**
|
| 184 |
+
- Upgrade to larger GPU
|
| 185 |
+
- Reduce batch size in LoRA training
|
| 186 |
+
- Limit generation duration
|
| 187 |
+
|
| 188 |
+
**Model Not Loading:**
|
| 189 |
+
- Check HuggingFace Hub access
|
| 190 |
+
- Verify model ID in config.yaml
|
| 191 |
+
- Check internet connectivity
|
| 192 |
+
|
| 193 |
+
**Slow Performance:**
|
| 194 |
+
- Upgrade GPU tier
|
| 195 |
+
- Reduce concurrent users
|
| 196 |
+
- Optimize generation parameters
|
| 197 |
+
|
| 198 |
+
## Optimization Tips
|
| 199 |
+
|
| 200 |
+
### Reduce Startup Time
|
| 201 |
+
|
| 202 |
+
1. **Cache Models**
|
| 203 |
+
```python
|
| 204 |
+
# In app.py, add before importing huggingface_hub/transformers and before model loading (needs `import os`):
|
| 205 |
+
os.environ["HF_HOME"] = "/data/huggingface"
|
| 206 |
+
```
|
| 207 |
+
|
| 208 |
+
2. **Preload on Startup**
|
| 209 |
+
- Models download on first run
|
| 210 |
+
- Cached for subsequent uses
|
| 211 |
+
- Consider pre-downloading to Space
|
| 212 |
+
|
| 213 |
+
### Improve Response Time
|
| 214 |
+
|
| 215 |
+
1. **Use Queuing**
|
| 216 |
+
- Gradio automatically queues requests
|
| 217 |
+
- Set `max_size` in `app.queue()` (call it before `app.launch()`)
|
| 218 |
+
|
| 219 |
+
2. **Optimize Generation**
|
| 220 |
+
- Lower default duration
|
| 221 |
+
- Reduce sampling steps
|
| 222 |
+
- Use FP16 precision
|
| 223 |
+
|
| 224 |
+
### Cost Optimization
|
| 225 |
+
|
| 226 |
+
1. **Auto-Sleep**
|
| 227 |
+
- Space sleeps after inactivity
|
| 228 |
+
- Wakes on first request
|
| 229 |
+
- Configure in Space settings
|
| 230 |
+
|
| 231 |
+
2. **Usage Limits**
|
| 232 |
+
- Set max concurrent users
|
| 233 |
+
- Limit generation duration
|
| 234 |
+
- Add rate limiting if needed
|
| 235 |
+
|
| 236 |
+
## Going Live
|
| 237 |
+
|
| 238 |
+
### Before Public Release
|
| 239 |
+
|
| 240 |
+
- [ ] Test all three tabs thoroughly
|
| 241 |
+
- [ ] Verify LoRA training works
|
| 242 |
+
- [ ] Test with different prompts and styles
|
| 243 |
+
- [ ] Check error handling
|
| 244 |
+
- [ ] Review logs for issues
|
| 245 |
+
- [ ] Test on mobile devices
|
| 246 |
+
- [ ] Add usage examples
|
| 247 |
+
- [ ] Create demo video
|
| 248 |
+
|
| 249 |
+
### Public Space Settings
|
| 250 |
+
|
| 251 |
+
1. **Enable Discussions**
|
| 252 |
+
- Let users report issues
|
| 253 |
+
- Gather feedback
|
| 254 |
+
|
| 255 |
+
2. **Add Examples**
|
| 256 |
+
- Create example prompts
|
| 257 |
+
- Show best practices
|
| 258 |
+
- Include sample outputs
|
| 259 |
+
|
| 260 |
+
3. **Update README**
|
| 261 |
+
- Clear usage instructions
|
| 262 |
+
- Feature highlights
|
| 263 |
+
- Limitations and known issues
|
| 264 |
+
|
| 265 |
+
4. **Pin Space**
|
| 266 |
+
- Makes it discoverable
|
| 267 |
+
- Shows on your profile
|
| 268 |
+
|
| 269 |
+
## Maintenance
|
| 270 |
+
|
| 271 |
+
### Regular Updates
|
| 272 |
+
|
| 273 |
+
```powershell
|
| 274 |
+
# Update code
|
| 275 |
+
cd d:\2025-vibe-coding\ACE-Step-Custom
|
| 276 |
+
git add .
|
| 277 |
+
git commit -m "Update description"
|
| 278 |
+
git push hf main
|
| 279 |
+
```
|
| 280 |
+
|
| 281 |
+
### Monitor Usage
|
| 282 |
+
|
| 283 |
+
- Check Space analytics
|
| 284 |
+
- Review user feedback
|
| 285 |
+
- Monitor error rates
|
| 286 |
+
- Track popular features
|
| 287 |
+
|
| 288 |
+
### Scaling
|
| 289 |
+
|
| 290 |
+
**If Space Gets Popular:**
|
| 291 |
+
1. Upgrade GPU tier
|
| 292 |
+
2. Add request queuing
|
| 293 |
+
3. Implement caching
|
| 294 |
+
4. Consider duplicating the Space for load balancing
|
| 295 |
+
|
| 296 |
+
## Support & Community
|
| 297 |
+
|
| 298 |
+
### Get Help
|
| 299 |
+
|
| 300 |
+
- HuggingFace Forums: https://discuss.huggingface.co/
|
| 301 |
+
- Discord: https://discord.gg/huggingface
|
| 302 |
+
- Docs: https://huggingface.co/docs/hub/spaces
|
| 303 |
+
|
| 304 |
+
### Share Your Space
|
| 305 |
+
|
| 306 |
+
- Post on Twitter/X with #HuggingFace #ACEStep
|
| 307 |
+
- Share in AI music communities
|
| 308 |
+
- Add to your portfolio
|
| 309 |
+
- Write a blog post about it
|
| 310 |
+
|
| 311 |
+
## Advanced Configuration
|
| 312 |
+
|
| 313 |
+
### Custom Domain (Pro)
|
| 314 |
+
|
| 315 |
+
HuggingFace Pro users can set custom domains:
|
| 316 |
+
1. Go to Space settings
|
| 317 |
+
2. Add custom domain
|
| 318 |
+
3. Configure DNS
|
| 319 |
+
|
| 320 |
+
### Persistent Storage
|
| 321 |
+
|
| 322 |
+
For saving user data (requires persistent storage to be enabled in Space settings, which mounts `/data`):
|
| 323 |
+
```python
|
| 324 |
+
import os
|
| 325 |
+
PERSIST_DIR = os.environ.get("SPACE_ID", "local")
|
| 326 |
+
# Save to /data/{SPACE_ID}/
|
| 327 |
+
```
|
| 328 |
+
|
| 329 |
+
### Analytics Integration
|
| 330 |
+
|
| 331 |
+
Add Google Analytics or similar:
|
| 332 |
+
```python
|
| 333 |
+
# In app.py
|
| 334 |
+
analytics_code = """
|
| 335 |
+
<script async src="https://www.googletagmanager.com/gtag/js?id=YOUR-ID"></script>
|
| 336 |
+
<script>
|
| 337 |
+
window.dataLayer = window.dataLayer || [];
|
| 338 |
+
function gtag(){dataLayer.push(arguments);}
|
| 339 |
+
gtag('js', new Date());
|
| 340 |
+
gtag('config', 'YOUR-ID');
|
| 341 |
+
</script>
|
| 342 |
+
"""
|
| 343 |
+
```
|
| 344 |
+
|
| 345 |
+
## Success Checklist
|
| 346 |
+
|
| 347 |
+
Before announcing your Space:
|
| 348 |
+
|
| 349 |
+
- ✅ All features working
|
| 350 |
+
- ✅ Clear documentation
|
| 351 |
+
- ✅ Example outputs included
|
| 352 |
+
- ✅ Error handling robust
|
| 353 |
+
- ✅ Performance optimized
|
| 354 |
+
- ✅ Mobile-friendly UI
|
| 355 |
+
- ✅ Clear limitations stated
|
| 356 |
+
- ✅ License properly attributed
|
| 357 |
+
- ✅ Usage guidelines clear
|
| 358 |
+
- ✅ Contact/support info provided
|
| 359 |
+
|
| 360 |
+
## Your Space URL
|
| 361 |
+
|
| 362 |
+
After deployment, your Space will be available at:
|
| 363 |
+
```
|
| 364 |
+
https://huggingface.co/spaces/YOUR_USERNAME/ace-step-custom
|
| 365 |
+
```
|
| 366 |
+
|
| 367 |
+
Share it with the world! 🎵🚀
|
DEPLOYMENT_CHECKLIST.txt
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
⚠️ IMPORTANT: Follow these steps in order ⚠️
|
| 2 |
+
|
| 3 |
+
═══════════════════════════════════════════════════════════════
|
| 4 |
+
🚀 HuggingFace Spaces Deployment - Step by Step
|
| 5 |
+
═══════════════════════════════════════════════════════════════
|
| 6 |
+
|
| 7 |
+
📋 PREREQUISITES
|
| 8 |
+
═══════════════════════════════════════════════════════════════
|
| 9 |
+
☐ HuggingFace account created: https://huggingface.co/join
|
| 10 |
+
☐ HuggingFace token obtained: https://huggingface.co/settings/tokens
|
| 11 |
+
(Create new token with "write" access)
|
| 12 |
+
☐ HuggingFace CLI installed (already done ✓)
|
| 13 |
+
|
| 14 |
+
═══════════════════════════════════════════════════════════════
|
| 15 |
+
🎯 DEPLOYMENT STEPS
|
| 16 |
+
═══════════════════════════════════════════════════════════════
|
| 17 |
+
|
| 18 |
+
Choose ONE method below:
|
| 19 |
+
|
| 20 |
+
┌─────────────────────────────────────────────────────────────┐
|
| 21 |
+
│ METHOD 1: AUTOMATED SCRIPT (EASIEST) ⭐ │
|
| 22 |
+
└─────────────────────────────────────────────────────────────┘
|
| 23 |
+
|
| 24 |
+
1. Open PowerShell in this directory:
|
| 25 |
+
d:\2025-vibe-coding\ACE-Step-Custom
|
| 26 |
+
|
| 27 |
+
2. Run the deployment script:
|
| 28 |
+
.\deploy_hf.bat
|
| 29 |
+
|
| 30 |
+
3. Follow the prompts:
|
| 31 |
+
- Login with your HF token
|
| 32 |
+
- Enter Space name (e.g., "ace-step-custom")
|
| 33 |
+
- Wait for upload
|
| 34 |
+
|
| 35 |
+
┌─────────────────────────────────────────────────────────────┐
|
| 36 |
+
│ METHOD 2: MANUAL CLI (FOR DEVELOPERS) │
|
| 37 |
+
└─────────────────────────────────────────────────────────────┘
|
| 38 |
+
|
| 39 |
+
1. Login to HuggingFace:
|
| 40 |
+
huggingface-cli login
|
| 41 |
+
[Paste your token]
|
| 42 |
+
|
| 43 |
+
2. Create the Space:
|
| 44 |
+
huggingface-cli repo create ace-step-custom --type space --space_sdk gradio
|
| 45 |
+
|
| 46 |
+
3. Upload files:
|
| 47 |
+
huggingface-cli upload YOUR_USERNAME/ace-step-custom . --repo-type space
|
| 48 |
+
|
| 49 |
+
┌─────────────────────────────────────────────────────────────┐
|
| 50 |
+
│ METHOD 3: WEB INTERFACE (NO CLI NEEDED) │
|
| 51 |
+
└─────────────────────────────────────────────────────────────┘
|
| 52 |
+
|
| 53 |
+
1. Go to: https://huggingface.co/new-space
|
| 54 |
+
|
| 55 |
+
2. Fill in Space details:
|
| 56 |
+
Name: ace-step-custom
|
| 57 |
+
License: MIT
|
| 58 |
+
SDK: Gradio
|
| 59 |
+
Hardware: A10G Small
|
| 60 |
+
|
| 61 |
+
3. Click "Create Space"
|
| 62 |
+
|
| 63 |
+
4. Click "Files and versions" → "Add file" → "Upload files"
|
| 64 |
+
|
| 65 |
+
5. Upload these files/folders:
|
| 66 |
+
✓ app.py
|
| 67 |
+
✓ requirements.txt
|
| 68 |
+
✓ config.yaml
|
| 69 |
+
✓ README.md
|
| 70 |
+
✓ LICENSE
|
| 71 |
+
✓ .gitignore
|
| 72 |
+
✓ src/ (entire folder)
|
| 73 |
+
✓ scripts/ (entire folder)
|
| 74 |
+
|
| 75 |
+
6. Commit changes
|
| 76 |
+
|
| 77 |
+
═══════════════════════════════════════════════════════════════
|
| 78 |
+
⚙️ POST-DEPLOYMENT CONFIGURATION
|
| 79 |
+
═══════════════════════════════════════════════════════════════
|
| 80 |
+
|
| 81 |
+
After upload, configure your Space:
|
| 82 |
+
|
| 83 |
+
1. ☐ Go to your Space URL:
|
| 84 |
+
https://huggingface.co/spaces/YOUR_USERNAME/ace-step-custom
|
| 85 |
+
|
| 86 |
+
2. ☐ Click "Settings" tab
|
| 87 |
+
|
| 88 |
+
3. ☐ Configure Hardware:
|
| 89 |
+
- Select: "A10G Small" (24GB VRAM) - MINIMUM
|
| 90 |
+
- Or: "A100 Large" (40GB VRAM) - RECOMMENDED
|
| 91 |
+
- Click "Save"
|
| 92 |
+
|
| 93 |
+
4. ☐ Set Python version: 3.10 (should be automatic)
|
| 94 |
+
|
| 95 |
+
5. ☐ Set timeout: 30 minutes (optional, for long generations)
|
| 96 |
+
|
| 97 |
+
6. ☐ Enable Discussions (optional, for user feedback)
|
| 98 |
+
|
| 99 |
+
═══════════════════════════════════════════════════════════════
|
| 100 |
+
⏱️ BUILD & TESTING
|
| 101 |
+
═══════════════════════════════════════════════════════════════
|
| 102 |
+
|
| 103 |
+
1. ☐ Wait for build to complete:
|
| 104 |
+
- Click "Logs" tab to monitor
|
| 105 |
+
- First build: 5-10 minutes
|
| 106 |
+
- Model download: ~7GB (first run only)
|
| 107 |
+
|
| 108 |
+
2. ☐ Space will show "Running" when ready
|
| 109 |
+
|
| 110 |
+
3. ☐ Test Tab 1 (Standard ACE-Step):
|
| 111 |
+
- Enter prompt: "Happy pop song with piano"
|
| 112 |
+
- Set duration: 10 seconds
|
| 113 |
+
- Click "Generate"
|
| 114 |
+
- Verify audio plays
|
| 115 |
+
|
| 116 |
+
4. ☐ Test Tab 2 (Timeline Workflow):
|
| 117 |
+
- Enter prompt and lyrics
|
| 118 |
+
- Set context length: 30 seconds
|
| 119 |
+
- Click "Generate Clip"
|
| 120 |
+
- Verify timeline updates
|
| 121 |
+
|
| 122 |
+
5. ☐ Test Tab 3 (LoRA Training):
|
| 123 |
+
- Upload 2-3 test audio files
|
| 124 |
+
- Set epochs to 2
|
| 125 |
+
- Click "Start Training"
|
| 126 |
+
- Verify progress updates
|
| 127 |
+
|
| 128 |
+
═══════════════════════════════════════════════════════════════
|
| 129 |
+
💰 COST MANAGEMENT
|
| 130 |
+
═══════════════════════════════════════════════════════════════
|
| 131 |
+
|
| 132 |
+
GPU Costs:
|
| 133 |
+
- A10G Small (24GB): ~$1.05/hour ⭐ RECOMMENDED
|
| 134 |
+
- A100 Large (40GB): ~$4.13/hour
|
| 135 |
+
|
| 136 |
+
Auto-Sleep:
|
| 137 |
+
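✓ Space sleeps automatically after 48 hours of inactivity on free hardware; on paid GPU hardware, set a sleep time in Settings (otherwise it keeps running and billing)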
|
| 138 |
+
✓ Wakes up on first request (30-60 second startup)
|
| 139 |
+
✓ No charges while sleeping
|
| 140 |
+
|
| 141 |
+
Testing Budget:
|
| 142 |
+
- Initial testing: ~$5-10
|
| 143 |
+
- Active use: ~$10-50/month
|
| 144 |
+
- Production: Scale as needed
|
| 145 |
+
|
| 146 |
+
═══════════════════════════════════════════════════════════════
|
| 147 |
+
🐛 TROUBLESHOOTING
|
| 148 |
+
═══════════════════════════════════════════════════════════════
|
| 149 |
+
|
| 150 |
+
Problem: Space won't start
|
| 151 |
+
Solution:
|
| 152 |
+
- Check "Logs" tab for errors
|
| 153 |
+
- Verify all files uploaded
|
| 154 |
+
- Ensure README.md has YAML frontmatter
|
| 155 |
+
|
| 156 |
+
Problem: Out of memory error
|
| 157 |
+
Solution:
|
| 158 |
+
- Upgrade to A100 Large
|
| 159 |
+
- Reduce generation duration in UI
|
| 160 |
+
- Lower batch size in LoRA training
|
| 161 |
+
|
| 162 |
+
Problem: Slow generation
|
| 163 |
+
Solution:
|
| 164 |
+
- Verify GPU is enabled (not CPU)
|
| 165 |
+
- Check Space isn't sleeping
|
| 166 |
+
- Reduce sampling steps in config
|
| 167 |
+
|
| 168 |
+
Problem: Model download fails
|
| 169 |
+
Solution:
|
| 170 |
+
- Check HuggingFace Hub status
|
| 171 |
+
- Verify internet connectivity
|
| 172 |
+
- Wait and retry
|
| 173 |
+
|
| 174 |
+
═══════════════════════════════════════════════════════════════
|
| 175 |
+
✅ SUCCESS CHECKLIST
|
| 176 |
+
═══════════════════════════════════════════════════════════════
|
| 177 |
+
|
| 178 |
+
Before announcing your Space:
|
| 179 |
+
|
| 180 |
+
☐ All three tabs tested and working
|
| 181 |
+
☐ Example generations added to README
|
| 182 |
+
☐ Clear usage instructions visible
|
| 183 |
+
☐ GPU enabled (A10G Small minimum)
|
| 184 |
+
☐ Error handling tested
|
| 185 |
+
☐ Mobile view checked
|
| 186 |
+
☐ Discussions enabled
|
| 187 |
+
☐ License properly displayed
|
| 188 |
+
☐ Contact/support info added
|
| 189 |
+
☐ Share link works
|
| 190 |
+
|
| 191 |
+
═══════════════════════════════════════════════════════════════
|
| 192 |
+
🎉 GO LIVE!
|
| 193 |
+
═══════════════════════════════════════════════════════════════
|
| 194 |
+
|
| 195 |
+
Your Space URL:
|
| 196 |
+
https://huggingface.co/spaces/YOUR_USERNAME/ace-step-custom
|
| 197 |
+
|
| 198 |
+
Share it:
|
| 199 |
+
□ Twitter/X: "Just deployed ACE-Step 1.5 Custom on @huggingface! 🎵
|
| 200 |
+
Check it out: [your-url] #AIMusic #HuggingFace #ACEStep"
|
| 201 |
+
□ LinkedIn post
|
| 202 |
+
□ Reddit (r/MachineLearning, r/artificial, r/WeAreTheMusicMakers)
|
| 203 |
+
□ Discord communities
|
| 204 |
+
□ Personal blog/portfolio
|
| 205 |
+
|
| 206 |
+
═══════════════════════════════════════════════════════════════
|
| 207 |
+
📚 ADDITIONAL RESOURCES
|
| 208 |
+
═══════════════════════════════════════════════════════════════
|
| 209 |
+
|
| 210 |
+
Documentation:
|
| 211 |
+
- DEPLOY_QUICK.md - Quick reference
|
| 212 |
+
- DEPLOYMENT.md - Complete guide
|
| 213 |
+
- README.md - Project documentation
|
| 214 |
+
|
| 215 |
+
Support:
|
| 216 |
+
- HuggingFace Docs: https://huggingface.co/docs/hub/spaces
|
| 217 |
+
- HuggingFace Discord: https://discord.gg/huggingface
|
| 218 |
+
- GitHub Issues: [your-repo-url]
|
| 219 |
+
|
| 220 |
+
═══════════════════════════════════════════════════════════════
|
| 221 |
+
|
| 222 |
+
Ready to deploy? 🚀
|
| 223 |
+
|
| 224 |
+
Run: .\deploy_hf.bat
|
| 225 |
+
|
| 226 |
+
═══════════════════════════════════════════════════════════════
|
DEPLOY_QUICK.md
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Quick Deployment to HuggingFace Spaces
|
| 2 |
+
|
| 3 |
+
## Prerequisites
|
| 4 |
+
✅ HuggingFace account: https://huggingface.co/join
|
| 5 |
+
✅ HuggingFace token: https://huggingface.co/settings/tokens
|
| 6 |
+
|
| 7 |
+
## Fastest Method (Windows)
|
| 8 |
+
|
| 9 |
+
Run the deployment script:
|
| 10 |
+
|
| 11 |
+
```powershell
|
| 12 |
+
cd d:\2025-vibe-coding\ACE-Step-Custom
|
| 13 |
+
.\deploy_hf.bat
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
The script will:
|
| 17 |
+
1. Install HuggingFace CLI (if needed)
|
| 18 |
+
2. Login to your account
|
| 19 |
+
3. Create new Space
|
| 20 |
+
4. Upload all files
|
| 21 |
+
5. Provide your Space URL
|
| 22 |
+
|
| 23 |
+
## Fastest Method (Linux/Mac)
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
cd /path/to/ACE-Step-Custom
|
| 27 |
+
chmod +x deploy_hf.sh
|
| 28 |
+
./deploy_hf.sh
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## Manual Deployment (If Script Fails)
|
| 32 |
+
|
| 33 |
+
### 1. Install HuggingFace CLI
|
| 34 |
+
```powershell
|
| 35 |
+
pip install huggingface_hub
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
### 2. Login
|
| 39 |
+
```powershell
|
| 40 |
+
huggingface-cli login
|
| 41 |
+
```
|
| 42 |
+
Enter your token from: https://huggingface.co/settings/tokens
|
| 43 |
+
|
| 44 |
+
### 3. Create Space
|
| 45 |
+
```powershell
|
| 46 |
+
huggingface-cli repo create ace-step-custom --type space --space_sdk gradio
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
### 4. Upload Files
|
| 50 |
+
```powershell
|
| 51 |
+
cd d:\2025-vibe-coding\ACE-Step-Custom
|
| 52 |
+
huggingface-cli upload YOUR_USERNAME/ace-step-custom . --repo-type space
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
Replace `YOUR_USERNAME` with your HuggingFace username.
|
| 56 |
+
|
| 57 |
+
## After Upload
|
| 58 |
+
|
| 59 |
+
### 1. Configure GPU
|
| 60 |
+
- Go to your Space: https://huggingface.co/spaces/YOUR_USERNAME/ace-step-custom
|
| 61 |
+
- Click "Settings" tab
|
| 62 |
+
- Under "Hardware", select: **A10G Small** (recommended)
|
| 63 |
+
- Click "Save"
|
| 64 |
+
|
| 65 |
+
### 2. Wait for Build
|
| 66 |
+
- Space will build automatically (5-10 minutes)
|
| 67 |
+
- Check "Logs" tab for progress
|
| 68 |
+
- Model downloads on first run (~7GB)
|
| 69 |
+
|
| 70 |
+
### 3. Test Your Space
|
| 71 |
+
1. Open Space URL
|
| 72 |
+
2. Test Tab 1: Generate 10-second clip
|
| 73 |
+
3. Test Tab 2: Generate timeline clip
|
| 74 |
+
4. Test Tab 3: Upload test audio
|
| 75 |
+
|
| 76 |
+
## Troubleshooting
|
| 77 |
+
|
| 78 |
+
**Login Failed:**
|
| 79 |
+
```powershell
|
| 80 |
+
# Make sure you copied the full token
|
| 81 |
+
huggingface-cli whoami # Check if logged in
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
**Upload Failed:**
|
| 85 |
+
```powershell
|
| 86 |
+
# Try with explicit exclusions
|
| 87 |
+
huggingface-cli upload YOUR_USERNAME/ace-step-custom . --repo-type space --exclude "*.pyc" --exclude "outputs/*" --exclude "__pycache__/*"
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
**Space Not Starting:**
|
| 91 |
+
- Check "Logs" tab for errors
|
| 92 |
+
- Verify requirements.txt is uploaded
|
| 93 |
+
- Ensure README.md has correct YAML frontmatter
|
| 94 |
+
|
| 95 |
+
**Out of Memory:**
|
| 96 |
+
- Upgrade GPU in Settings
|
| 97 |
+
- Start with A10G Small minimum
|
| 98 |
+
|
| 99 |
+
## Your Space URL
|
| 100 |
+
|
| 101 |
+
After deployment:
|
| 102 |
+
```
|
| 103 |
+
https://huggingface.co/spaces/YOUR_USERNAME/ace-step-custom
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
## Cost Estimate
|
| 107 |
+
|
| 108 |
+
- **A10G Small (24GB):** ~$1.05/hour
|
| 109 |
+
- **Auto-sleep:** Space sleeps when inactive (no charge while sleeping); on paid GPU hardware, set a sleep time in Settings so it actually sleeps
|
| 110 |
+
- **Testing:** Budget ~$5-10 for initial testing
|
| 111 |
+
|
| 112 |
+
## Need Help?
|
| 113 |
+
|
| 114 |
+
See full guide: [DEPLOYMENT.md](DEPLOYMENT.md)
|
| 115 |
+
|
| 116 |
+
## Next Steps
|
| 117 |
+
|
| 118 |
+
1. ✅ Deploy Space
|
| 119 |
+
2. ✅ Test all features
|
| 120 |
+
3. ✅ Enable Discussions in Settings
|
| 121 |
+
4. ✅ Add example outputs to README
|
| 122 |
+
5. ✅ Share your Space!
|
| 123 |
+
|
| 124 |
+
---
|
| 125 |
+
|
| 126 |
+
🎵 Happy testing! Your Space will be live in minutes! 🚀
|
Dockerfile
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CUDA 12.1 runtime image on Ubuntu 22.04
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04

# Set working directory
WORKDIR /app

# Install system dependencies.
# DEBIAN_FRONTEND=noninteractive is scoped to this single command so that any
# package that would otherwise ask a configuration question (for example
# tzdata, when it is pulled in as a dependency) cannot stall the
# non-interactive image build.
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
    python3.10 \
    python3-pip \
    git \
    ffmpeg \
    libsndfile1 \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements on their own first, so the dependency layer below is only
# rebuilt when requirements.txt changes
COPY requirements.txt .

# Install Python dependencies
RUN pip3 install --no-cache-dir -r requirements.txt

# Copy application files
COPY . .

# Create necessary directories
RUN mkdir -p outputs timelines lora_training logs models

# Expose Gradio port
EXPOSE 7860

# Set environment variables: listen on all interfaces, on the exposed port
ENV GRADIO_SERVER_NAME="0.0.0.0"
ENV GRADIO_SERVER_PORT=7860

# Run the application
CMD ["python3", "app.py"]
|
LICENSE
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 Gamahea Development Team
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
This project uses ACE-Step, which is subject to its own license:
|
| 26 |
+
https://github.com/ace-step/ACE-Step
|
| 27 |
+
|
| 28 |
+
Please refer to the original ACE-Step repository for their licensing terms.
|
QUICKSTART.md
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ACE-Step 1.5 Custom Edition - Quick Start Guide
|
| 2 |
+
|
| 3 |
+
## Installation
|
| 4 |
+
|
| 5 |
+
### Option 1: Local Setup
|
| 6 |
+
|
| 7 |
+
1. **Clone the repository**
|
| 8 |
+
```bash
|
| 9 |
+
git clone https://github.com/yourusername/ace-step-custom.git
|
| 10 |
+
cd ace-step-custom
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
2. **Create virtual environment**
|
| 14 |
+
```bash
|
| 15 |
+
python -m venv venv
|
| 16 |
+
|
| 17 |
+
# On Windows:
|
| 18 |
+
venv\Scripts\activate
|
| 19 |
+
|
| 20 |
+
# On Linux/Mac:
|
| 21 |
+
source venv/bin/activate
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
3. **Run setup**
|
| 25 |
+
```bash
|
| 26 |
+
python scripts/setup.py
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
4. **Download model**
|
| 30 |
+
```bash
|
| 31 |
+
python scripts/download_model.py
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
5. **Launch application**
|
| 35 |
+
```bash
|
| 36 |
+
python app.py
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
6. **Open browser to** `http://localhost:7860`
|
| 40 |
+
|
| 41 |
+
### Option 2: HuggingFace Spaces
|
| 42 |
+
|
| 43 |
+
1. Create new Space on HuggingFace
|
| 44 |
+
2. Upload all project files
|
| 45 |
+
3. Set Space configuration:
|
| 46 |
+
- SDK: `gradio`
|
| 47 |
+
- Python: `3.10`
|
| 48 |
+
- GPU: `A10G` (or better)
|
| 49 |
+
4. Space will auto-deploy
|
| 50 |
+
|
| 51 |
+
## Usage
|
| 52 |
+
|
| 53 |
+
### Tab 1: Standard ACE-Step
|
| 54 |
+
|
| 55 |
+
Standard interface with all original ACE-Step features:
|
| 56 |
+
- Text-to-music generation
|
| 57 |
+
- Variation generation
|
| 58 |
+
- Repainting sections
|
| 59 |
+
- Lyric editing
|
| 60 |
+
|
| 61 |
+
### Tab 2: Timeline Workflow
|
| 62 |
+
|
| 63 |
+
Advanced timeline-based generation:
|
| 64 |
+
1. Enter prompt and lyrics
|
| 65 |
+
2. Set context length (0-120s)
|
| 66 |
+
3. Click "Generate" for 32s clips
|
| 67 |
+
4. Clips auto-blend into timeline
|
| 68 |
+
5. Use "Extend" to continue
|
| 69 |
+
6. Use "Inpaint" to edit regions
|
| 70 |
+
|
| 71 |
+
### Tab 3: LoRA Training
|
| 72 |
+
|
| 73 |
+
Train custom models:
|
| 74 |
+
1. Upload audio files (10+ recommended)
|
| 75 |
+
2. Set training parameters
|
| 76 |
+
3. Click "Start Training"
|
| 77 |
+
4. Download trained model
|
| 78 |
+
5. Use in Tab 1 or Tab 2
|
| 79 |
+
|
| 80 |
+
## Tips
|
| 81 |
+
|
| 82 |
+
- **First time:** Start with Standard tab to understand basics
|
| 83 |
+
- **For longer songs:** Use Timeline tab with context length 30-60s
|
| 84 |
+
- **For custom styles:** Train LoRA with 20+ similar audio files
|
| 85 |
+
- **GPU recommended:** 8GB+ VRAM for best performance
|
| 86 |
+
- **CPU mode:** Works but slower, use shorter durations
|
| 87 |
+
|
| 88 |
+
## Troubleshooting
|
| 89 |
+
|
| 90 |
+
### Out of Memory
|
| 91 |
+
- Reduce batch size in LoRA training
|
| 92 |
+
- Use shorter audio durations
|
| 93 |
+
- Close other GPU applications
|
| 94 |
+
|
| 95 |
+
### Poor Quality
|
| 96 |
+
- Increase context length
|
| 97 |
+
- Try different seeds
|
| 98 |
+
- Adjust temperature (0.6-0.8 is usually good)
|
| 99 |
+
|
| 100 |
+
### Blend Artifacts
|
| 101 |
+
- Reduce lead-in/lead-out durations
|
| 102 |
+
- Ensure consistent style across clips
|
| 103 |
+
- Use lower context length for more variety
|
| 104 |
+
|
| 105 |
+
## Support
|
| 106 |
+
|
| 107 |
+
- GitHub Issues: [Report bugs here]
|
| 108 |
+
- Documentation: See `docs/` directory
|
| 109 |
+
- Examples: See `examples/` directory
|
| 110 |
+
|
| 111 |
+
## Credits
|
| 112 |
+
|
| 113 |
+
Based on ACE-Step by ACE Studio and StepFun
|
| 114 |
+
- Website: https://ace-step.github.io/
|
| 115 |
+
- Paper: https://arxiv.org/abs/2506.00045
|
README.md
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: ACE-Step 1.5 Custom Edition
|
| 3 |
+
emoji: 🎵
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.0.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
python_version: 3.11
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# ACE-Step 1.5 Custom Edition
|
| 15 |
+
|
| 16 |
+
A comprehensive music generation system built on ACE-Step 1.5.
|
| 17 |
+
|
| 18 |
+
## 🌟 Features
|
| 19 |
+
|
| 20 |
+
### 1. Standard ACE-Step Interface
|
| 21 |
+
Full-featured standard ACE-Step 1.5 GUI with all original capabilities including:
|
| 22 |
+
- Text-to-music generation with style control
|
| 23 |
+
- Variation generation
|
| 24 |
+
- Section repainting
|
| 25 |
+
- Lyric editing
|
| 26 |
+
|
| 27 |
+
### 2. Custom Timeline Workflow
|
| 28 |
+
Advanced timeline-based generation system:
|
| 29 |
+
- Generate 32-second clips with seamless blending
|
| 30 |
+
- Adjustable context length (0-120 seconds) for style consistency
|
| 31 |
+
- Master timeline with visual representation
|
| 32 |
+
- Extend, inpaint, and remix capabilities
|
| 33 |
+
- Automatic crossfading between clips
|
| 34 |
+
|
| 35 |
+
### 3. LoRA Training Studio
|
| 36 |
+
Complete training interface for custom models:
|
| 37 |
+
- Upload and preprocess audio files
|
| 38 |
+
- Configure training parameters
|
| 39 |
+
- Train specialized models for voices, instruments, or styles
|
| 40 |
+
- Download and reuse trained models
|
| 41 |
+
- Continue training from existing LoRAs
|
| 42 |
+
|
| 43 |
+
## 🚀 Quick Start
|
| 44 |
+
|
| 45 |
+
1. **Standard Generation**: Use Tab 1 for traditional text-to-music
|
| 46 |
+
2. **Timeline Creation**: Use Tab 2 to build longer songs with consistent style
|
| 47 |
+
3. **Custom Training**: Use Tab 3 to create specialized models
|
| 48 |
+
|
| 49 |
+
## 💡 Tips
|
| 50 |
+
|
| 51 |
+
- Start with context length of 30-60s for best results
|
| 52 |
+
- For custom voices, train LoRA with 20+ audio samples
|
| 53 |
+
- Adjust temperature between 0.6 and 0.8 to balance quality and creativity
|
| 54 |
+
- Use "Extend" in Timeline mode to continue your song
|
| 55 |
+
|
| 56 |
+
## 🎯 Use Cases
|
| 57 |
+
|
| 58 |
+
- **Musicians**: Create backing tracks and song ideas
|
| 59 |
+
- **Content Creators**: Generate royalty-free music for videos
|
| 60 |
+
- **Game Developers**: Create adaptive game soundtracks
|
| 61 |
+
- **AI Researchers**: Experiment with music generation and LoRA training
|
| 62 |
+
|
| 63 |
+
## 📚 Documentation
|
| 64 |
+
|
| 65 |
+
See the repository for full documentation and examples.
|
| 66 |
+
|
| 67 |
+
## 🙏 Credits
|
| 68 |
+
|
| 69 |
+
Built on top of [ACE-Step](https://ace-step.github.io/) by ACE Studio and StepFun.
|
| 70 |
+
|
| 71 |
+
## ⚠️ Note
|
| 72 |
+
|
| 73 |
+
This is a custom implementation focusing on enhanced workflows and training capabilities. Generation quality depends on the base ACE-Step model and your usage patterns.
|
README_HF.md
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: ACE-Step 1.5 Custom Edition
|
| 3 |
+
emoji: 🎵
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 4.0.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
python_version: 3.11
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# ACE-Step 1.5 Custom Edition
|
| 15 |
+
|
| 16 |
+
A comprehensive music generation system built on ACE-Step 1.5.
|
| 17 |
+
|
| 18 |
+
## 🌟 Features
|
| 19 |
+
|
| 20 |
+
### 1. Standard ACE-Step Interface
|
| 21 |
+
Full-featured standard ACE-Step 1.5 GUI with all original capabilities including:
|
| 22 |
+
- Text-to-music generation with style control
|
| 23 |
+
- Variation generation
|
| 24 |
+
- Section repainting
|
| 25 |
+
- Lyric editing
|
| 26 |
+
|
| 27 |
+
### 2. Custom Timeline Workflow
|
| 28 |
+
Advanced timeline-based generation system:
|
| 29 |
+
- Generate 32-second clips with seamless blending
|
| 30 |
+
- Adjustable context length (0-120 seconds) for style consistency
|
| 31 |
+
- Master timeline with visual representation
|
| 32 |
+
- Extend, inpaint, and remix capabilities
|
| 33 |
+
- Automatic crossfading between clips
|
| 34 |
+
|
| 35 |
+
### 3. LoRA Training Studio
|
| 36 |
+
Complete training interface for custom models:
|
| 37 |
+
- Upload and preprocess audio files
|
| 38 |
+
- Configure training parameters
|
| 39 |
+
- Train specialized models for voices, instruments, or styles
|
| 40 |
+
- Download and reuse trained models
|
| 41 |
+
- Continue training from existing LoRAs
|
| 42 |
+
|
| 43 |
+
## 🚀 Quick Start
|
| 44 |
+
|
| 45 |
+
1. **Standard Generation**: Use Tab 1 for traditional text-to-music
|
| 46 |
+
2. **Timeline Creation**: Use Tab 2 to build longer songs with consistent style
|
| 47 |
+
3. **Custom Training**: Use Tab 3 to create specialized models
|
| 48 |
+
|
| 49 |
+
## 💡 Tips
|
| 50 |
+
|
| 51 |
+
- Start with context length of 30-60s for best results
|
| 52 |
+
- For custom voices, train LoRA with 20+ audio samples
|
| 53 |
+
- Adjust temperature between 0.6 and 0.8 to balance quality and creativity
|
| 54 |
+
- Use "Extend" in Timeline mode to continue your song
|
| 55 |
+
|
| 56 |
+
## 🎯 Use Cases
|
| 57 |
+
|
| 58 |
+
- **Musicians**: Create backing tracks and song ideas
|
| 59 |
+
- **Content Creators**: Generate royalty-free music for videos
|
| 60 |
+
- **Game Developers**: Create adaptive game soundtracks
|
| 61 |
+
- **AI Researchers**: Experiment with music generation and LoRA training
|
| 62 |
+
|
| 63 |
+
## 📚 Documentation
|
| 64 |
+
|
| 65 |
+
See the repository for full documentation and examples.
|
| 66 |
+
|
| 67 |
+
## 🙏 Credits
|
| 68 |
+
|
| 69 |
+
Built on top of [ACE-Step](https://ace-step.github.io/) by ACE Studio and StepFun.
|
| 70 |
+
|
| 71 |
+
## ⚠️ Note
|
| 72 |
+
|
| 73 |
+
This is a custom implementation focusing on enhanced workflows and training capabilities. Generation quality depends on the base ACE-Step model and your usage patterns.
|
README_PROJECT.md
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: ACE-Step 1.5 Custom Edition
|
| 3 |
+
emoji: 🎵
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.9.1
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
python_version: "3.11"
|
| 12 |
+
hardware: zero-gpu-medium
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
# ACE-Step 1.5 Custom Edition
|
| 16 |
+
|
| 17 |
+
A fully-featured implementation of ACE-Step 1.5 with custom GUI and workflow capabilities for local use and HuggingFace Space deployment.
|
| 18 |
+
|
| 19 |
+
## Features
|
| 20 |
+
|
| 21 |
+
### 🎵 Three Main Interfaces
|
| 22 |
+
|
| 23 |
+
1. **Standard ACE-Step GUI**: Full-featured standard ACE-Step 1.5 interface with all original capabilities
|
| 24 |
+
2. **Custom Timeline Workflow**: Advanced timeline-based generation with:
|
| 25 |
+
- 32-second clip generation (2s lead-in + 28s main + 2s lead-out)
|
| 26 |
+
- Seamless clip blending for continuous music
|
| 27 |
+
- Context Length slider (0-120 seconds) for style guidance
|
| 28 |
+
- Master timeline with extend, inpaint, and remix capabilities
|
| 29 |
+
3. **LoRA Training Studio**: Complete LoRA training interface with:
|
| 30 |
+
- Audio file upload and preprocessing
|
| 31 |
+
- Custom training configuration
|
| 32 |
+
- Model download/upload for continued training
|
| 33 |
+
|
| 34 |
+
## Architecture
|
| 35 |
+
|
| 36 |
+
- **Base Model**: ACE-Step v1.5 Turbo
|
| 37 |
+
- **Framework**: Gradio 5.9.1, PyTorch
|
| 38 |
+
- **Deployment**: Local execution + HuggingFace Spaces
|
| 39 |
+
- **Audio Processing**: DiT + VAE + 5Hz Language Model
|
| 40 |
+
|
| 41 |
+
## Installation
|
| 42 |
+
|
| 43 |
+
### Local Setup
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
# Clone the repository
|
| 47 |
+
git clone https://github.com/yourusername/ace-step-custom.git
|
| 48 |
+
cd ace-step-custom
|
| 49 |
+
|
| 50 |
+
# Create virtual environment
|
| 51 |
+
python -m venv venv
|
| 52 |
+
source venv/bin/activate # On Windows: venv\Scripts\activate
|
| 53 |
+
|
| 54 |
+
# Install dependencies
|
| 55 |
+
pip install -r requirements.txt
|
| 56 |
+
|
| 57 |
+
# Download ACE-Step model
|
| 58 |
+
python scripts/download_model.py
|
| 59 |
+
|
| 60 |
+
# Run the application
|
| 61 |
+
python app.py
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
### HuggingFace Space Deployment
|
| 65 |
+
|
| 66 |
+
1. Create a new Space on HuggingFace
|
| 67 |
+
2. Upload all files to the Space
|
| 68 |
+
3. Set Space to use GPU (recommended: H200 or A100)
|
| 69 |
+
4. The app will automatically download models and start
|
| 70 |
+
|
| 71 |
+
## Usage
|
| 72 |
+
|
| 73 |
+
### Standard Mode
|
| 74 |
+
Use the first tab for standard ACE-Step generation with all original features.
|
| 75 |
+
|
| 76 |
+
### Timeline Mode
|
| 77 |
+
1. Enter your prompt/lyrics
|
| 78 |
+
2. Adjust Context Length (how far back to reference previous clips)
|
| 79 |
+
3. Click "Generate" to create 32-second clips
|
| 80 |
+
4. Clips automatically blend and add to timeline
|
| 81 |
+
5. Use "Extend" to continue the song, or use the other timeline options (inpaint, remix) to create variations
|
| 82 |
+
|
| 83 |
+
### LoRA Training
|
| 84 |
+
1. Upload audio files for training
|
| 85 |
+
2. Configure training parameters
|
| 86 |
+
3. Train custom LoRA models
|
| 87 |
+
4. Download and reuse for continued training
|
| 88 |
+
|
| 89 |
+
## System Requirements
|
| 90 |
+
|
| 91 |
+
### Minimum
|
| 92 |
+
- GPU: 8GB VRAM (with optimizations)
|
| 93 |
+
- RAM: 16GB
|
| 94 |
+
- Storage: 20GB
|
| 95 |
+
|
| 96 |
+
### Recommended
|
| 97 |
+
- GPU: 16GB+ VRAM (A100, H200, or consumer GPUs)
|
| 98 |
+
- RAM: 32GB
|
| 99 |
+
- Storage: 50GB
|
| 100 |
+
|
| 101 |
+
## Technical Details
|
| 102 |
+
|
| 103 |
+
- **Audio Format**: 48kHz, stereo
|
| 104 |
+
- **Generation Speed**: ~8 inference steps (turbo model)
|
| 105 |
+
- **Context Window**: Up to 120 seconds for style guidance
|
| 106 |
+
- **Blend Regions**: 2-second crossfade between clips
|
| 107 |
+
|
| 108 |
+
## Credits
|
| 109 |
+
|
| 110 |
+
Based on ACE-Step 1.5 by ACE Studio and StepFun
|
| 111 |
+
- GitHub: https://github.com/ace-step/ACE-Step-1.5
|
| 112 |
+
- Original Demo: https://huggingface.co/spaces/ACE-Step/ACE-Step
|
| 113 |
+
|
| 114 |
+
## License
|
| 115 |
+
|
| 116 |
+
MIT License (see LICENSE file)
|
acestep/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""ACE-Step package."""
|
acestep/acestep_v15_pipeline.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ACE-Step V1.5 Pipeline
|
| 3 |
+
Handler wrapper connecting model and UI
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
# Load environment variables from .env file at most once per process to avoid
# epoch-boundary stalls (e.g. on Windows when Gradio yields during training)
_env_loaded = False  # module-level so we never reload .env in the same process
try:
    from dotenv import load_dotenv
    if not _env_loaded:
        # The project root is the parent of the directory that holds this file.
        _current_file = os.path.abspath(__file__)
        _project_root = os.path.dirname(os.path.dirname(_current_file))
        _env_path = os.path.join(_project_root, '.env')
        _env_example_path = os.path.join(_project_root, '.env.example')
        # Prefer the real .env; fall back to .env.example when only that exists.
        if os.path.exists(_env_path):
            load_dotenv(_env_path)
            print(f"Loaded configuration from {_env_path}")
        elif os.path.exists(_env_example_path):
            load_dotenv(_env_example_path)
            print(f"Loaded configuration from {_env_example_path} (fallback)")
        _env_loaded = True
except ImportError:
    # python-dotenv not installed, skip loading .env
    pass

# Clear proxy settings that may affect Gradio.
# Each variable is removed in both its lower-case and upper-case spelling.
for proxy_var in (
    'http_proxy', 'https_proxy', 'all_proxy',
    'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY',
):
    os.environ.pop(proxy_var, None)
|
| 32 |
+
|
| 33 |
+
try:
|
| 34 |
+
# When executed as a module: `python -m acestep.acestep_v15_pipeline`
|
| 35 |
+
from .handler import AceStepHandler
|
| 36 |
+
from .llm_inference import LLMHandler
|
| 37 |
+
from .dataset_handler import DatasetHandler
|
| 38 |
+
from .gradio_ui import create_gradio_interface
|
| 39 |
+
from .gpu_config import get_gpu_config, get_gpu_memory_gb, print_gpu_config_info, set_global_gpu_config, VRAM_16GB_MIN_GB
|
| 40 |
+
except ImportError:
|
| 41 |
+
# When executed as a script: `python acestep/acestep_v15_pipeline.py`
|
| 42 |
+
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 43 |
+
if project_root not in sys.path:
|
| 44 |
+
sys.path.insert(0, project_root)
|
| 45 |
+
from acestep.handler import AceStepHandler
|
| 46 |
+
from acestep.llm_inference import LLMHandler
|
| 47 |
+
from acestep.dataset_handler import DatasetHandler
|
| 48 |
+
from acestep.gradio_ui import create_gradio_interface
|
| 49 |
+
from acestep.gpu_config import get_gpu_config, get_gpu_memory_gb, print_gpu_config_info, set_global_gpu_config, VRAM_16GB_MIN_GB
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def create_demo(init_params=None, language='en'):
    """
    Create Gradio demo interface

    Args:
        init_params: Dictionary containing initialization parameters and state.
            If None, service will not be pre-initialized.
            Keys: 'pre_initialized' (bool), 'checkpoint', 'config_path', 'device',
                  'init_llm', 'lm_model_path', 'backend', 'use_flash_attention',
                  'offload_to_cpu', 'offload_dit_to_cpu', 'init_status',
                  'dit_handler', 'llm_handler' (initialized handlers if pre-initialized),
                  'language' (UI language code)
        language: UI language code (e.g. 'en', 'zh', 'ja'; default: 'en')

    Returns:
        Gradio Blocks instance
    """
    # Reuse the caller's handlers only when it declares them pre-initialized
    # and actually supplies the DiT handler.
    use_pre_initialized = bool(
        init_params
        and init_params.get('pre_initialized')
        and 'dit_handler' in init_params
    )

    if use_pre_initialized:
        dit_handler = init_params['dit_handler']
        # Only 'dit_handler' is checked above, so 'llm_handler' may be absent.
        # Fall back to a fresh LM handler instead of raising KeyError; a value
        # that was passed explicitly (even None) is still used unchanged.
        if 'llm_handler' in init_params:
            llm_handler = init_params['llm_handler']
        else:
            llm_handler = LLMHandler()
    else:
        dit_handler = AceStepHandler()  # DiT handler
        llm_handler = LLMHandler()  # LM handler

    dataset_handler = DatasetHandler()  # Dataset handler

    # Create Gradio interface with all handlers and initialization parameters
    demo = create_gradio_interface(
        dit_handler,
        llm_handler,
        dataset_handler,
        init_params=init_params,
        language=language,
    )

    return demo
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def _parse_bool_flag(value):
    """Interpret a command-line string as a boolean ('true', '1', 'yes' are truthy, case-insensitive)."""
    return value.lower() in ['true', '1', 'yes']


def main():
    """
    Main entry function

    Detects the GPU configuration, parses command-line arguments (with
    environment-variable fallbacks), optionally pre-initializes the DiT and
    5Hz LM handlers, builds the Gradio interface and launches the server.
    When --enable-api is set, API routes are attached after launch and the
    main thread is kept alive until interrupted.

    Exits the process with status 1 when no model can be selected, when DiT
    initialization fails, or when launching raises an exception.
    """
    import argparse
    import time

    # Detect GPU memory and get configuration
    gpu_config = get_gpu_config()
    set_global_gpu_config(gpu_config)  # Set global config for use across modules

    gpu_memory_gb = gpu_config.gpu_memory_gb
    # Offload by default only when a GPU was detected and it is below the 16GB threshold
    auto_offload = 0 < gpu_memory_gb < VRAM_16GB_MIN_GB

    # Print GPU configuration info
    print(f"\n{'='*60}")
    print("GPU Configuration Detected:")
    print(f"{'='*60}")
    print(f" GPU Memory: {gpu_memory_gb:.2f} GB")
    print(f" Configuration Tier: {gpu_config.tier}")
    print(f" Max Duration (with LM): {gpu_config.max_duration_with_lm}s ({gpu_config.max_duration_with_lm // 60} min)")
    print(f" Max Duration (without LM): {gpu_config.max_duration_without_lm}s ({gpu_config.max_duration_without_lm // 60} min)")
    print(f" Max Batch Size (with LM): {gpu_config.max_batch_size_with_lm}")
    print(f" Max Batch Size (without LM): {gpu_config.max_batch_size_without_lm}")
    print(f" Default LM Init: {gpu_config.init_lm_default}")
    print(f" Available LM Models: {gpu_config.available_lm_models or 'None'}")
    print(f"{'='*60}\n")

    if auto_offload:
        print("Auto-enabling CPU offload (GPU < 16GB)")
    elif gpu_memory_gb > 0:
        print("CPU offload disabled by default (GPU >= 16GB)")
    else:
        print("No GPU detected, running on CPU")

    # Define local outputs directory
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    output_dir = os.path.join(project_root, "gradio_outputs")
    # Normalize path to use forward slashes for Gradio 6 compatibility on Windows
    output_dir = output_dir.replace("\\", "/")
    os.makedirs(output_dir, exist_ok=True)
    print(f"Output directory: {output_dir}")

    parser = argparse.ArgumentParser(description="Gradio Demo for ACE-Step V1.5")
    parser.add_argument("--port", type=int, default=7860, help="Port to run the gradio server on")
    parser.add_argument("--share", action="store_true", help="Create a public link")
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")
    parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Server name (default: 127.0.0.1, use 0.0.0.0 for all interfaces)")
    parser.add_argument("--language", type=str, default="en", choices=["en", "zh", "he", "ja"], help="UI language: en (English), zh (中文), he (עברית), ja (日本語)")
    parser.add_argument(
        "--allowed-path",
        action="append",
        default=[],
        help="Additional allowed file paths for Gradio (repeatable).",
    )

    # Service mode argument
    parser.add_argument("--service_mode", type=_parse_bool_flag, default=False,
                        help="Enable service mode (default: False). When enabled, uses preset models and restricts UI options.")

    # Service initialization arguments
    parser.add_argument("--init_service", type=_parse_bool_flag, default=False, help="Initialize service on startup (default: False)")
    parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint file path (optional, for display purposes)")
    parser.add_argument("--config_path", type=str, default=None, help="Main model path (e.g., 'acestep-v15-turbo')")
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "mps", "xpu", "cpu"], help="Processing device (default: auto)")
    parser.add_argument("--init_llm", type=_parse_bool_flag, default=None, help="Initialize 5Hz LM (default: auto based on GPU memory)")
    parser.add_argument("--lm_model_path", type=str, default=None, help="5Hz LM model path (e.g., 'acestep-5Hz-lm-0.6B')")
    parser.add_argument("--backend", type=str, default="vllm", choices=["vllm", "pt", "mlx"], help="5Hz LM backend (default: vllm, use 'mlx' for native Apple Silicon acceleration)")
    parser.add_argument("--use_flash_attention", type=_parse_bool_flag, default=None, help="Use flash attention (default: auto-detect)")
    parser.add_argument("--offload_to_cpu", type=_parse_bool_flag, default=auto_offload, help=f"Offload models to CPU (default: {'True' if auto_offload else 'False'}, auto-detected based on GPU VRAM)")
    parser.add_argument("--offload_dit_to_cpu", type=_parse_bool_flag, default=False, help="Offload DiT to CPU (default: False)")
    parser.add_argument("--download-source", type=str, default=None, choices=["huggingface", "modelscope", "auto"], help="Preferred model download source (default: auto-detect based on network)")

    # API mode argument
    parser.add_argument("--enable-api", action="store_true", help="Enable API endpoints (default: False)")

    # Authentication arguments
    parser.add_argument("--auth-username", type=str, default=None, help="Username for Gradio authentication")
    parser.add_argument("--auth-password", type=str, default=None, help="Password for Gradio authentication")
    parser.add_argument("--api-key", type=str, default=None, help="API key for API endpoints authentication")

    args = parser.parse_args()

    # Enable API requires init_service
    if args.enable_api:
        args.init_service = True

    # Load config from .env if not specified
    if args.config_path is None:
        args.config_path = os.environ.get("ACESTEP_CONFIG_PATH")
    if args.lm_model_path is None:
        args.lm_model_path = os.environ.get("ACESTEP_LM_MODEL_PATH")
    # NOTE(review): unlike the two settings above, a non-empty ACESTEP_LM_BACKEND
    # replaces --backend even when it was given on the command line, and the
    # value is not checked against the argparse choices - confirm this is intended.
    if os.environ.get("ACESTEP_LM_BACKEND"):
        args.backend = os.environ.get("ACESTEP_LM_BACKEND")

    # Service mode defaults (can be configured via .env file)
    if args.service_mode:
        print("Service mode enabled - applying preset configurations...")
        # Force init_service in service mode
        args.init_service = True
        # Default DiT model for service mode (from env or fallback)
        if args.config_path is None:
            args.config_path = os.environ.get(
                "SERVICE_MODE_DIT_MODEL",
                "acestep-v15-turbo-fix-inst-shift-dynamic"
            )
        # Default LM model for service mode (from env or fallback)
        if args.lm_model_path is None:
            args.lm_model_path = os.environ.get(
                "SERVICE_MODE_LM_MODEL",
                "acestep-5Hz-lm-1.7B-v4-fix"
            )
        # Backend for service mode (from env or fallback to vllm)
        args.backend = os.environ.get("SERVICE_MODE_BACKEND", "vllm")
        print(f" DiT model: {args.config_path}")
        print(f" LM model: {args.lm_model_path}")
        print(f" Backend: {args.backend}")

    # Auto-enable CPU offload for tier6 GPUs (16-24GB) when using the 4B LM model
    # The 4B LM (~8GB) + DiT (~4.7GB) + VAE + text encoder exceeds 16-20GB with activations
    if not args.offload_to_cpu and args.lm_model_path and "4B" in args.lm_model_path:
        if 0 < gpu_memory_gb <= 24:
            args.offload_to_cpu = True
            print(f"Auto-enabling CPU offload (4B LM model requires offloading on {gpu_memory_gb:.0f}GB GPU)")

    try:
        init_params = None
        dit_handler = None
        llm_handler = None

        # If init_service is True, perform initialization before creating UI
        if args.init_service:
            print("Initializing service from command line...")

            # Create handler instances for initialization
            dit_handler = AceStepHandler()
            llm_handler = LLMHandler()

            # Auto-select config_path if not provided
            if args.config_path is None:
                available_models = dit_handler.get_available_acestep_v15_models()
                if available_models:
                    args.config_path = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else available_models[0]
                    print(f"Auto-selected config_path: {args.config_path}")
                else:
                    print("Error: No available models found. Please specify --config_path", file=sys.stderr)
                    sys.exit(1)

            # project_root computed above is reused here (same expression as before).

            # Determine flash attention setting
            use_flash_attention = args.use_flash_attention
            if use_flash_attention is None:
                use_flash_attention = dit_handler.is_flash_attention_available(args.device)

            # Determine download source preference
            prefer_source = None
            if args.download_source and args.download_source != "auto":
                prefer_source = args.download_source
                print(f"Using preferred download source: {prefer_source}")

            # Initialize DiT handler
            print(f"Initializing DiT model: {args.config_path} on {args.device}...")
            init_status, enable_generate = dit_handler.initialize_service(
                project_root=project_root,
                config_path=args.config_path,
                device=args.device,
                use_flash_attention=use_flash_attention,
                compile_model=False,
                offload_to_cpu=args.offload_to_cpu,
                offload_dit_to_cpu=args.offload_dit_to_cpu,
                prefer_source=prefer_source
            )

            if not enable_generate:
                print(f"Error initializing DiT model: {init_status}", file=sys.stderr)
                sys.exit(1)

            print("DiT model initialized successfully")

            # Initialize LM handler if requested
            # Auto-determine init_llm based on GPU config if not explicitly set
            if args.init_llm is None:
                args.init_llm = gpu_config.init_lm_default
                print(f"Auto-setting init_llm to {args.init_llm} based on GPU configuration")

            lm_status = ""
            if args.init_llm:
                if args.lm_model_path is None:
                    # Try to get default LM model
                    available_lm_models = llm_handler.get_available_5hz_lm_models()
                    if available_lm_models:
                        args.lm_model_path = available_lm_models[0]
                        print(f"Using default LM model: {args.lm_model_path}")
                    else:
                        print("Warning: No LM models available, skipping LM initialization", file=sys.stderr)
                        args.init_llm = False

                if args.init_llm and args.lm_model_path:
                    checkpoint_dir = os.path.join(project_root, "checkpoints")
                    print(f"Initializing 5Hz LM: {args.lm_model_path} on {args.device}...")
                    lm_status, lm_success = llm_handler.initialize(
                        checkpoint_dir=checkpoint_dir,
                        lm_model_path=args.lm_model_path,
                        backend=args.backend,
                        device=args.device,
                        offload_to_cpu=args.offload_to_cpu,
                        dtype=None,
                    )

                    # An LM failure is reported but does not abort startup
                    if lm_success:
                        print("5Hz LM initialized successfully")
                    else:
                        print(f"Warning: 5Hz LM initialization failed: {lm_status}", file=sys.stderr)
                    init_status += f"\n{lm_status}"

            # Prepare initialization parameters for UI
            init_params = {
                'pre_initialized': True,
                'service_mode': args.service_mode,
                'checkpoint': args.checkpoint,
                'config_path': args.config_path,
                'device': args.device,
                'init_llm': args.init_llm,
                'lm_model_path': args.lm_model_path,
                'backend': args.backend,
                'use_flash_attention': use_flash_attention,
                'offload_to_cpu': args.offload_to_cpu,
                'offload_dit_to_cpu': args.offload_dit_to_cpu,
                'init_status': init_status,
                'enable_generate': enable_generate,
                'dit_handler': dit_handler,
                'llm_handler': llm_handler,
                'language': args.language,
                'gpu_config': gpu_config,  # Pass GPU config to UI
                'output_dir': output_dir,  # Pass output dir to UI
            }

            print("Service initialization completed successfully!")

        # Create and launch demo
        print(f"Creating Gradio interface with language: {args.language}...")

        # If not using init_service, still pass gpu_config to init_params
        if init_params is None:
            init_params = {
                'gpu_config': gpu_config,
                'language': args.language,
                'output_dir': output_dir,  # Pass output dir to UI
            }

        demo = create_demo(init_params=init_params, language=args.language)

        # Enable queue for multi-user support
        # This ensures proper request queuing and prevents concurrent generation conflicts
        print("Enabling queue for multi-user support...")
        demo.queue(
            max_size=20,  # Maximum queue size (adjust based on your needs)
            status_update_rate="auto",  # Update rate for queue status
            default_concurrency_limit=1,  # Prevents VRAM saturation
        )

        print(f"Launching server on {args.server_name}:{args.port}...")

        # Setup authentication if provided
        auth = None
        if args.auth_username and args.auth_password:
            auth = (args.auth_username, args.auth_password)
            print("Authentication enabled")
        elif args.auth_username or args.auth_password:
            # Half-configured credentials would otherwise start an open server silently
            print(
                "Warning: both --auth-username and --auth-password are required; "
                "authentication is NOT enabled",
                file=sys.stderr,
            )

        allowed_paths = [output_dir]
        for p in args.allowed_path:
            if p and p not in allowed_paths:
                allowed_paths.append(p)

        # Enable API endpoints if requested
        if args.enable_api:
            print("Enabling API endpoints...")
            from acestep.gradio_ui.api_routes import setup_api_routes

            # Launch Gradio first with prevent_thread_lock=True
            demo.launch(
                server_name=args.server_name,
                server_port=args.port,
                share=args.share,
                debug=args.debug,
                show_error=True,
                prevent_thread_lock=True,  # Don't block, so we can add routes
                inbrowser=False,
                auth=auth,
                allowed_paths=allowed_paths,  # include output_dir + user-provided
            )

            # Now add API routes to Gradio's FastAPI app (app is available after launch)
            setup_api_routes(demo, dit_handler, llm_handler, api_key=args.api_key)

            if args.api_key:
                print("API authentication enabled")
            print("API endpoints enabled: /health, /v1/models, /release_task, /query_result, /create_random_sample, /format_lyrics")

            # Keep the main thread alive
            try:
                while True:
                    time.sleep(1)
            except KeyboardInterrupt:
                print("\nShutting down...")
        else:
            demo.launch(
                server_name=args.server_name,
                server_port=args.port,
                share=args.share,
                debug=args.debug,
                show_error=True,
                prevent_thread_lock=False,
                inbrowser=False,
                auth=auth,
                allowed_paths=allowed_paths,  # include output_dir + user-provided
            )
    except Exception as e:
        # sys.exit() above raises SystemExit, which is not an Exception and passes through
        print(f"Error launching Gradio: {e}", file=sys.stderr)
        import traceback
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
|
acestep/api_server.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
acestep/audio_utils.py
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Audio saving and transcoding utility module
|
| 3 |
+
|
| 4 |
+
Independent audio file operations outside of handler, supporting:
|
| 5 |
+
- Save audio tensor/numpy to files (default FLAC format, fast)
|
| 6 |
+
- Format conversion (FLAC/WAV/MP3)
|
| 7 |
+
- Batch processing
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
import hashlib
|
| 12 |
+
import json
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Union, Optional, List, Tuple
|
| 15 |
+
import torch
|
| 16 |
+
import numpy as np
|
| 17 |
+
import torchaudio
|
| 18 |
+
from loguru import logger
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class AudioSaver:
    """Audio saving and transcoding utility class"""

    # Formats accepted by save_audio / convert_audio
    SUPPORTED_FORMATS = ("flac", "wav", "mp3")

    def __init__(self, default_format: str = "flac"):
        """
        Initialize audio saver

        Args:
            default_format: Default save format ('flac', 'wav', 'mp3').
                Unsupported values are replaced by 'flac' with a warning.
        """
        self.default_format = default_format.lower()
        if self.default_format not in self.SUPPORTED_FORMATS:
            logger.warning(f"Unsupported format {default_format}, using 'flac'")
            self.default_format = "flac"

    @staticmethod
    def _resolve_output_path(output_path: Union[str, Path], format: str) -> Path:
        """Return output_path as a Path, appending '.<format>' unless it already ends in an audio suffix."""
        output_path = Path(output_path)
        if output_path.suffix.lower() not in ('.flac', '.wav', '.mp3'):
            # Append instead of with_suffix(): a dotted name such as
            # "v1.5_take_0000" must not be truncated to "v1.<format>".
            output_path = output_path.with_name(f"{output_path.name}.{format}")
        return output_path

    @staticmethod
    def _to_channels_first_tensor(
        audio_data: Union[torch.Tensor, np.ndarray],
        channels_first: bool,
    ) -> torch.Tensor:
        """Convert audio_data to a contiguous float32 CPU tensor laid out as [channels, samples]."""
        if isinstance(audio_data, np.ndarray):
            if channels_first:
                # Legacy behaviour kept for existing callers: NumPy input is
                # transposed here, i.e. treated as [samples, channels].
                # NOTE(review): this contradicts the flag's documented meaning
                # for tensors - confirm what callers pass before changing it.
                audio_tensor = torch.from_numpy(audio_data.T).float()
            else:
                # numpy [samples, channels] -> tensor [channels, samples]
                audio_tensor = torch.from_numpy(audio_data).float()
                if audio_tensor.dim() == 2 and audio_tensor.shape[0] > audio_tensor.shape[1]:
                    audio_tensor = audio_tensor.T
        else:
            # detach() so tensors that still require grad can be exported
            audio_tensor = audio_data.detach().cpu().float()
            if not channels_first and audio_tensor.dim() == 2:
                # [samples, channels] -> [channels, samples]
                if audio_tensor.shape[0] > audio_tensor.shape[1]:
                    audio_tensor = audio_tensor.T

        if audio_tensor.dim() == 1:
            # Mono signal without a channel axis -> [1, samples]
            audio_tensor = audio_tensor.unsqueeze(0)

        # Ensure memory is contiguous
        return audio_tensor.contiguous()

    def save_audio(
        self,
        audio_data: Union[torch.Tensor, np.ndarray],
        output_path: Union[str, Path],
        sample_rate: int = 48000,
        format: Optional[str] = None,
        channels_first: bool = True,
    ) -> str:
        """
        Save audio data to file

        Args:
            audio_data: Audio data, torch.Tensor [channels, samples] or numpy.ndarray
            output_path: Output file path (extension can be omitted)
            sample_rate: Sample rate
            format: Audio format ('flac', 'wav', 'mp3'), defaults to default_format
            channels_first: If True, tensor format is [channels, samples], else [samples, channels]

        Returns:
            Actual saved file path

        Raises:
            Exception: re-raised from the soundfile fallback when both the
                torchaudio save and the fallback fail.
        """
        format = (format or self.default_format).lower()
        if format not in self.SUPPORTED_FORMATS:
            logger.warning(f"Unsupported format {format}, using {self.default_format}")
            format = self.default_format

        # Ensure output path has correct extension
        # NOTE(review): an existing audio suffix is kept even when it differs
        # from `format`; confirm how the backends pick the container then.
        output_path = self._resolve_output_path(output_path, format)

        audio_tensor = self._to_channels_first_tensor(audio_data, channels_first)

        # MP3 goes through the ffmpeg backend, FLAC and WAV through soundfile
        backend = 'ffmpeg' if format == "mp3" else 'soundfile'

        try:
            torchaudio.save(
                str(output_path),
                audio_tensor,
                sample_rate,
                channels_first=True,
                backend=backend,
            )
            logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
            return str(output_path)
        except Exception as primary_error:
            # Keep the first failure visible before trying the fallback writer
            logger.debug(f"[AudioSaver] torchaudio save failed ({primary_error}), trying soundfile fallback")
            try:
                import soundfile as sf
                audio_np = audio_tensor.transpose(0, 1).numpy()  # -> [samples, channels]
                sf.write(str(output_path), audio_np, sample_rate, format=format.upper())
                logger.debug(f"[AudioSaver] Fallback soundfile Saved audio to {output_path} ({format}, {sample_rate}Hz)")
                return str(output_path)
            except Exception as fallback_error:
                logger.error(f"[AudioSaver] Failed to save audio: {fallback_error}")
                raise

    def convert_audio(
        self,
        input_path: Union[str, Path],
        output_path: Union[str, Path],
        output_format: str,
        remove_input: bool = False,
    ) -> str:
        """
        Convert audio format

        Args:
            input_path: Input audio file path
            output_path: Output audio file path
            output_format: Target format ('flac', 'wav', 'mp3')
            remove_input: Whether to delete input file

        Returns:
            Output file path

        Raises:
            FileNotFoundError: If input_path does not exist.
        """
        input_path = Path(input_path)
        output_path = Path(output_path)

        if not input_path.exists():
            raise FileNotFoundError(f"Input file not found: {input_path}")

        # Load audio
        audio_tensor, sample_rate = torchaudio.load(str(input_path))

        # Save as new format
        saved_path = self.save_audio(
            audio_tensor,
            output_path,
            sample_rate=sample_rate,
            format=output_format,
            channels_first=True
        )

        # Delete input file if needed
        if remove_input:
            if Path(saved_path).resolve() == input_path.resolve():
                # Converting in place: removing the "input" would delete the result
                logger.warning(f"[AudioSaver] Input and output are the same file, not removing: {input_path}")
            else:
                input_path.unlink()
                logger.debug(f"[AudioSaver] Removed input file: {input_path}")

        return saved_path

    def save_batch(
        self,
        audio_batch: Union[List[torch.Tensor], torch.Tensor],
        output_dir: Union[str, Path],
        file_prefix: str = "audio",
        sample_rate: int = 48000,
        format: Optional[str] = None,
        channels_first: bool = True,
    ) -> List[str]:
        """
        Save audio batch

        Args:
            audio_batch: Audio batch, List[tensor] or tensor [batch, channels, samples]
            output_dir: Output directory (created if missing)
            file_prefix: File prefix; files are named "<prefix>_<index:04d>"
            sample_rate: Sample rate
            format: Audio format
            channels_first: Tensor format flag

        Returns:
            List of saved file paths
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Process batch
        if isinstance(audio_batch, torch.Tensor) and audio_batch.dim() == 3:
            # [batch, channels, samples]
            audio_list = list(audio_batch)
        elif isinstance(audio_batch, list):
            audio_list = audio_batch
        else:
            # Anything else is treated as a single item
            audio_list = [audio_batch]

        return [
            self.save_audio(
                audio,
                output_dir / f"{file_prefix}_{i:04d}",
                sample_rate=sample_rate,
                format=format,
                channels_first=channels_first
            )
            for i, audio in enumerate(audio_list)
        ]
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def get_audio_file_hash(audio_file) -> str:
    """
    Get hash identifier for an audio file.

    A string naming an existing path is hashed by file content; any other
    string is hashed as text. Objects with a ``name`` attribute are hashed by
    that name, everything else by its ``str()`` form.

    Args:
        audio_file: Path to audio file (str) or file-like object

    Returns:
        MD5 hex digest, or empty string when audio_file is None
    """
    if audio_file is None:
        return ""

    try:
        if isinstance(audio_file, str):
            if os.path.exists(audio_file):
                # Hash in fixed-size chunks so large audio files are not
                # loaded into memory in one piece.
                digest = hashlib.md5()
                with open(audio_file, 'rb') as f:
                    for chunk in iter(lambda: f.read(1024 * 1024), b""):
                        digest.update(chunk)
                return digest.hexdigest()
            return hashlib.md5(audio_file.encode('utf-8')).hexdigest()
        elif hasattr(audio_file, 'name'):
            return hashlib.md5(str(audio_file.name).encode('utf-8')).hexdigest()
        return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
    except Exception:
        # Best effort: fall back to hashing the textual form (e.g. unreadable path)
        return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def generate_uuid_from_params(params_dict) -> str:
    """
    Derive a deterministic UUID-shaped identifier from generation parameters.

    The parameters are serialized to JSON with sorted keys, hashed with
    SHA-256, and the first 32 hex digits are grouped as 8-4-4-4-12, so equal
    parameters always map to the same identifier.

    Args:
        params_dict: Dictionary of parameters (must be JSON-serializable)

    Returns:
        UUID string
    """
    canonical = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
    digest = hashlib.sha256(canonical.encode('utf-8')).hexdigest()

    # Slice boundaries of the 8-4-4-4-12 layout
    bounds = (0, 8, 12, 16, 20, 32)
    groups = [digest[start:end] for start, end in zip(bounds, bounds[1:])]
    return "-".join(groups)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def generate_uuid_from_audio_data(
    audio_data: Union[torch.Tensor, np.ndarray],
    seed: Optional[int] = None
) -> str:
    """
    Generate UUID from audio data (for caching/deduplication)

    The identifier is the MD5 hex digest of the raw sample bytes. When a seed
    is given, the digest of "<data_hash>_<seed>" is returned instead.

    Args:
        audio_data: Audio data
        seed: Optional seed value

    Returns:
        UUID string (32 hex characters)
    """
    if isinstance(audio_data, torch.Tensor):
        # detach() first: .numpy() raises on tensors that still require grad
        audio_np = audio_data.detach().cpu().numpy()
    else:
        audio_np = audio_data

    # Calculate data hash
    data_hash = hashlib.md5(audio_np.tobytes()).hexdigest()

    if seed is None:
        return data_hash

    combined = f"{data_hash}_{seed}"
    return hashlib.md5(combined.encode()).hexdigest()
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
# Global default instance (FLAC by default), used by the module-level save_audio() wrapper
_default_saver = AudioSaver(default_format="flac")
|
| 302 |
+
|
| 303 |
+
# Default thresholds at or below which a signal counts as silent
SILENT_RMS_THRESHOLD = 1e-5
SILENT_PEAK_THRESHOLD = 1e-5


def is_audio_silent(
    audio_data: Union[torch.Tensor, np.ndarray],
    rms_threshold: float = SILENT_RMS_THRESHOLD,
    peak_threshold: float = SILENT_PEAK_THRESHOLD,
    channels_first: bool = True,
) -> Tuple[bool, float, float]:
    """
    Check if audio is silent or near-silent (e.g. zeroed conditioning output).

    Args:
        audio_data: Audio samples; None and empty input count as silent
        rms_threshold: Maximum RMS for the signal to count as silent
        peak_threshold: Maximum absolute peak for the signal to count as silent
        channels_first: Accepted for signature compatibility; not used, because
            the signal is flattened before measuring

    Returns:
        (is_silent, rms, peak) where rms/peak are computed over the full signal.
    """
    if audio_data is None:
        return True, 0.0, 0.0

    if isinstance(audio_data, np.ndarray):
        x = np.asarray(audio_data, dtype=np.float64).ravel()
    else:
        # detach() first: .numpy() raises on tensors that still require grad
        x = audio_data.detach().cpu().float().numpy().ravel()

    if x.size == 0:
        return True, 0.0, 0.0

    rms = float(np.sqrt(np.mean(x * x)))
    peak = float(np.max(np.abs(x)))
    # Both measures must be within their thresholds
    is_silent = rms <= rms_threshold and peak <= peak_threshold
    return is_silent, rms, peak
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def save_audio(
    audio_data: Union[torch.Tensor, np.ndarray],
    output_path: Union[str, Path],
    sample_rate: int = 48000,
    format: Optional[str] = None,
    channels_first: bool = True,
) -> str:
    """
    Module-level shortcut that saves audio through the shared default saver.

    Args:
        audio_data: Audio data
        output_path: Output path
        sample_rate: Sample rate
        format: Format (the default saver's format is used when omitted)
        channels_first: Tensor format flag

    Returns:
        Saved file path
    """
    saved_path = _default_saver.save_audio(
        audio_data,
        output_path,
        sample_rate,
        format,
        channels_first,
    )
    return saved_path
|
| 354 |
+
|
acestep/constants.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Constants for ACE-Step
|
| 3 |
+
Centralized constants used across the codebase
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
# ==============================================================================
# Language Constants
# ==============================================================================

# Supported languages for vocal generation and language detection
# Covers major world languages with good TTS support in the underlying model
# 'unknown' is used when language cannot be determined automatically
VALID_LANGUAGES = [
    'ar', 'az', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en',
    'es', 'fa', 'fi', 'fr', 'he', 'hi', 'hr', 'ht', 'hu', 'id',
    'is', 'it', 'ja', 'ko', 'la', 'lt', 'ms', 'ne', 'nl', 'no',
    'pa', 'pl', 'pt', 'ro', 'ru', 'sa', 'sk', 'sr', 'sv', 'sw',
    'ta', 'te', 'th', 'tl', 'tr', 'uk', 'ur', 'vi', 'yue', 'zh',
    'unknown'
]


# ==============================================================================
# Keyscale Constants
# ==============================================================================

# Musical note names using standard Western notation
KEYSCALE_NOTES = ['A', 'B', 'C', 'D', 'E', 'F', 'G']

# Supported accidentals: natural, ASCII sharp/flat, Unicode sharp/flat
KEYSCALE_ACCIDENTALS = ['', '#', 'b', '♯', '♭']  # empty + ASCII sharp/flat + Unicode sharp/flat

# Major and minor scale modes
KEYSCALE_MODES = ['major', 'minor']

# Generate all valid keyscales: 7 notes × 5 accidentals × 2 modes = 70 combinations
# Examples: "C major", "F# minor", "B♭ major"
# Built with a comprehension so the loop variables do not end up as public
# module-level names of this constants module.
VALID_KEYSCALES = {
    f"{note}{acc} {mode}"
    for note in KEYSCALE_NOTES
    for acc in KEYSCALE_ACCIDENTALS
    for mode in KEYSCALE_MODES
}


# ==============================================================================
# Metadata Range Constants
# ==============================================================================

# BPM (Beats Per Minute) range - covers most musical styles
# 30 BPM: Very slow ballads, ambient music
# 300 BPM: Fast electronic dance music, extreme metal
BPM_MIN = 30
BPM_MAX = 300

# Duration range (in seconds) - balances quality vs. computational cost
# 10s: Short loops, musical excerpts
# 600s: Full songs, extended compositions (10 minutes)
DURATION_MIN = 10
DURATION_MAX = 600

# Valid time signatures - common musical meter patterns
# 2: 2/4 time (marches, polka)
# 3: 3/4 time (waltzes, ballads)
# 4: 4/4 time (most pop, rock, hip-hop)
# 6: 6/8 time (compound time, folk dances)
VALID_TIME_SIGNATURES = [2, 3, 4, 6]


# ==============================================================================
# Task Type Constants
# ==============================================================================

# All supported generation tasks across different model variants
TASK_TYPES = ["text2music", "repaint", "cover", "extract", "lego", "complete"]

# Task types available for turbo models (optimized subset for speed)
# - text2music: Generate from text descriptions
# - repaint: Selective audio editing/regeneration
# - cover: Style transfer using reference audio
TASK_TYPES_TURBO = ["text2music", "repaint", "cover"]

# Task types available for base models (full feature set)
# Additional tasks requiring more computational resources:
# - extract: Separate individual tracks/stems from audio
# - lego: Multi-track generation (add layers)
# - complete: Automatic completion of partial audio
TASK_TYPES_BASE = ["text2music", "repaint", "cover", "extract", "lego", "complete"]


# ==============================================================================
# Instruction Constants
# ==============================================================================

# Default instructions
DEFAULT_DIT_INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
DEFAULT_LM_INSTRUCTION = "Generate audio semantic tokens based on the given conditions:"
DEFAULT_LM_UNDERSTAND_INSTRUCTION = "Understand the given musical conditions and describe the audio semantics accordingly:"
DEFAULT_LM_INSPIRED_INSTRUCTION = "Expand the user's input into a more detailed and specific musical description:"
DEFAULT_LM_REWRITE_INSTRUCTION = "Format the user's input into a more detailed and specific musical description:"

# Instruction templates for each task type
# Note: Some instructions use placeholders like {TRACK_NAME} or {TRACK_CLASSES}
# These should be formatted using .format() or f-strings when used
TASK_INSTRUCTIONS = {
    "text2music": "Fill the audio semantic mask based on the given conditions:",
    "repaint": "Repaint the mask area based on the given conditions:",
    "cover": "Generate audio semantic tokens based on the given conditions:",
    "extract": "Extract the {TRACK_NAME} track from the audio:",
    "extract_default": "Extract the track from the audio:",
    "lego": "Generate the {TRACK_NAME} track based on the audio context:",
    "lego_default": "Generate the track based on the audio context:",
    "complete": "Complete the input track with {TRACK_CLASSES}:",
    "complete_default": "Complete the input track:",
}


# ==============================================================================
# Track/Instrument Constants
# ==============================================================================

# Supported instrumental track types for multi-track generation and extraction
# Organized by instrument families for logical grouping:
# - Wind instruments: woodwinds, brass
# - Electronic: fx (effects), synth (synthesizer)
# - String instruments: strings, guitar, bass
# - Rhythm section: percussion, drums, keyboard
# - Vocals: backing_vocals, vocals (lead vocals)
TRACK_NAMES = [
    "woodwinds", "brass", "fx", "synth", "strings", "percussion",
    "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"
]

# Template for SFT (Supervised Fine-Tuning) model prompts
# Used to format inputs for the language model with instruction, caption, and metadata
# The three positional {} slots are, in order: instruction, caption, metas.
SFT_GEN_PROMPT = """# Instruction
{}

# Caption
{}

# Metas
{}<|endoftext|>
"""


# ==============================================================================
# GPU Memory Configuration Constants
# ==============================================================================

# GPU tier thresholds (in GB)
GPU_TIER_THRESHOLDS = {
    "tier1": 4,   # <= 4GB
    "tier2": 6,   # 4-6GB
    "tier3": 8,   # 6-8GB
    "tier4": 12,  # 8-12GB
    "tier5": 16,  # 12-16GB
    "tier6": 24,  # 16-24GB
    # "unlimited" for >= 24GB
}

# LM model memory requirements (in GB)
LM_MODEL_MEMORY_GB = {
    "0.6B": 3.0,
    "1.7B": 8.0,
    "4B": 12.0,
}

# LM model names mapping
LM_MODEL_NAMES = {
    "0.6B": "acestep-5Hz-lm-0.6B",
    "1.7B": "acestep-5Hz-lm-1.7B",
    "4B": "acestep-5Hz-lm-4B",
}


# ==============================================================================
# Debug Constants
# ==============================================================================

# Tensor debug mode (values: "OFF" | "ON" | "VERBOSE")
TENSOR_DEBUG_MODE = "OFF"

# Placeholder debug switches for other main functionality (default "OFF")
# Update names/usage as features adopt them.
DEBUG_API_SERVER = "OFF"
DEBUG_INFERENCE = "OFF"
DEBUG_TRAINING = "OFF"
DEBUG_DATASET = "OFF"
DEBUG_AUDIO = "OFF"
DEBUG_LLM = "OFF"
DEBUG_UI = "OFF"
DEBUG_MODEL_LOADING = "OFF"
DEBUG_GPU = "OFF"
|
acestep/constrained_logits_processor.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
acestep/dataset_handler.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Dataset Handler Module
|
| 3 |
+
|
| 4 |
+
Handles dataset import and exploration functionality for ACE-Step training.
|
| 5 |
+
This module provides a placeholder implementation for dataset operations
|
| 6 |
+
when the full training dataset dependencies are not available.
|
| 7 |
+
|
| 8 |
+
Note: Full dataset functionality requires Text2MusicDataset which may not be
|
| 9 |
+
included in the basic installation to reduce dependencies.
|
| 10 |
+
"""
|
| 11 |
+
from typing import Optional, Tuple, Any, Dict
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DatasetHandler:
    """
    Placeholder backend for the Dataset Explorer.

    Exposes the dataset import / item lookup interface but never loads any
    data: importing reports that the feature is disabled, and item lookups
    return a fixed tuple of empty defaults.
    """

    def __init__(self):
        """Start with no dataset attached and the import flag cleared."""
        self.dataset = None
        self.dataset_imported = False

    def import_dataset(self, dataset_type: str) -> str:
        """
        Report that dataset import is disabled.

        Args:
            dataset_type: Requested dataset type (e.g., "train", "test",
                "validation"); not used by this placeholder.

        Returns:
            Status message stating that dataset import is disabled.

        Note:
            The import flag is reset to False on every call.
        """
        self.dataset_imported = False
        status = "⚠️ Dataset import is currently disabled. Text2MusicDataset dependency not available."
        return status

    def get_item_data(self, *args, **kwargs) -> Tuple:
        """
        Return the fixed "no dataset" item tuple.

        Args:
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            A 17-element tuple in this order:
            (caption, lyrics, language, bpm, keyscale, ref_audio, src_audio, codes,
            status_msg, instruction, duration, timesig, audio1, audio2, audio3,
            metadata, task_type)
        """
        # caption, lyrics, language, bpm, keyscale
        text_defaults = ("", "", "", "", "")
        # ref_audio, src_audio, codes
        missing_inputs = (None, None, None)
        # status_msg, instruction, duration, timesig
        status_fields = ("❌ Dataset not available", "", 0, "")
        # audio1, audio2, audio3
        missing_previews = (None, None, None)
        # metadata (a fresh dict on every call), task_type
        trailer = ({}, "text2music")
        return text_defaults + missing_inputs + status_fields + missing_previews + trailer
|
| 83 |
+
|
acestep/debug_utils.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Debug helpers (global).
|
| 3 |
+
"""
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from typing import Optional, Callable, Union
|
| 8 |
+
|
| 9 |
+
from acestep.constants import (
|
| 10 |
+
TENSOR_DEBUG_MODE,
|
| 11 |
+
DEBUG_API_SERVER,
|
| 12 |
+
DEBUG_INFERENCE,
|
| 13 |
+
DEBUG_TRAINING,
|
| 14 |
+
DEBUG_DATASET,
|
| 15 |
+
DEBUG_AUDIO,
|
| 16 |
+
DEBUG_LLM,
|
| 17 |
+
DEBUG_UI,
|
| 18 |
+
DEBUG_MODEL_LOADING,
|
| 19 |
+
DEBUG_GPU,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _normalize_mode(mode: str) -> str:
    """Return *mode* stripped and upper-cased; ``None``/empty becomes ``""``."""
    return (mode or "").strip().upper()


def is_debug_enabled(mode: str) -> bool:
    """
    Tell whether *mode* switches debugging on.

    Any non-empty value other than "OFF" (case and surrounding whitespace
    ignored) enables debugging. A missing mode (``None``, ``""`` or only
    whitespace) is treated as disabled instead of silently enabling output.
    """
    return _normalize_mode(mode) not in ("", "OFF")


def is_debug_verbose(mode: str) -> bool:
    """Tell whether *mode* is "VERBOSE" (case and surrounding whitespace ignored)."""
    return _normalize_mode(mode) == "VERBOSE"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def debug_log(message: Union[str, Callable[[], str]], *, mode: str = TENSOR_DEBUG_MODE, prefix: str = "debug") -> None:
    """
    Print a timestamped debug line when *mode* is enabled.

    Args:
        message: The text to print, or a zero-argument callable producing it.
            A callable is only invoked when the mode is enabled.
        mode: Debug mode string checked with ``is_debug_enabled``.
        prefix: Tag printed in square brackets at the start of the line.
    """
    if is_debug_enabled(mode):
        text = message() if callable(message) else message
        # Millisecond precision: drop the last three digits of the microseconds.
        stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
        print(f"[{prefix}] {stamp} {text}", flush=True)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Central registry of the debug switches, keyed by lower-case subsystem name.
# The values are read from the constants once, when this module is imported.
DEBUG_SWITCHES = dict(
    tensor=TENSOR_DEBUG_MODE,
    api_server=DEBUG_API_SERVER,
    inference=DEBUG_INFERENCE,
    training=DEBUG_TRAINING,
    dataset=DEBUG_DATASET,
    audio=DEBUG_AUDIO,
    llm=DEBUG_LLM,
    ui=DEBUG_UI,
    model_loading=DEBUG_MODEL_LOADING,
    gpu=DEBUG_GPU,
)


def get_debug_mode(name: str, default: str = "OFF") -> str:
    """
    Look up the debug mode registered for a subsystem.

    The name is stripped and lower-cased before the lookup; an unknown or
    missing name yields *default*.
    """
    key = (name or "").strip().lower()
    return DEBUG_SWITCHES.get(key, default)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def debug_log_for(name: str, message: Union[str, Callable[[], str]], *, prefix: str | None = None) -> None:
    """
    Log *message* under the debug mode registered for subsystem *name*.

    The line is tagged with *prefix* when given, otherwise with *name*.
    """
    debug_log(message, mode=get_debug_mode(name), prefix=prefix or name)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def debug_start_for(name: str, label: str) -> Optional[float]:
    """
    Begin a timed section for subsystem *name*.

    Returns whatever ``debug_start`` returns for that subsystem's mode, using
    *name* as the log prefix.
    """
    return debug_start(label, mode=get_debug_mode(name), prefix=name)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def debug_end_for(name: str, label: str, start_ts: Optional[float]) -> None:
    """
    Close a timed section for subsystem *name*.

    Forwards *label* and *start_ts* to ``debug_end`` with that subsystem's
    mode and *name* as the log prefix.
    """
    debug_end(label, start_ts, mode=get_debug_mode(name), prefix=name)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def debug_log_verbose_for(name: str, message: Union[str, Callable[[], str]], *, prefix: str | None = None) -> None:
    """
    Log *message* for subsystem *name*, but only when its mode is VERBOSE.

    The line is tagged with *prefix* when given, otherwise with *name*.
    """
    mode = get_debug_mode(name)
    if is_debug_verbose(mode):
        debug_log(message, mode=mode, prefix=prefix or name)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def debug_start_verbose_for(name: str, label: str) -> Optional[float]:
    """
    Begin a timed section for subsystem *name* only when its mode is VERBOSE.

    Returns None when the subsystem is not in VERBOSE mode, otherwise the
    result of ``debug_start``.
    """
    mode = get_debug_mode(name)
    if is_debug_verbose(mode):
        return debug_start(label, mode=mode, prefix=name)
    return None
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def debug_end_verbose_for(name: str, label: str, start_ts: Optional[float]) -> None:
    """
    Close a timed section for subsystem *name* only when its mode is VERBOSE.

    Does nothing in any other mode.
    """
    mode = get_debug_mode(name)
    if is_debug_verbose(mode):
        debug_end(label, start_ts, mode=mode, prefix=name)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def debug_start(name: str, *, mode: str = TENSOR_DEBUG_MODE, prefix: str = "debug") -> Optional[float]:
    """
    Log a START line and hand back a perf-counter timestamp.

    Returns None without logging when *mode* is not enabled; the timestamp is
    taken after the START line has been emitted.
    """
    if is_debug_enabled(mode):
        debug_log(f"START {name}", mode=mode, prefix=prefix)
        from time import perf_counter
        return perf_counter()
    return None
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def debug_end(name: str, start_ts: Optional[float], *, mode: str = TENSOR_DEBUG_MODE, prefix: str = "debug") -> None:
    """
    Log an END line with the milliseconds elapsed since *start_ts*.

    Nothing is logged when *start_ts* is None or *mode* is not enabled.
    """
    if start_ts is not None and is_debug_enabled(mode):
        from time import perf_counter
        elapsed_ms = (perf_counter() - start_ts) * 1000.0
        debug_log(f"END {name} ({elapsed_ms:.1f} ms)", mode=mode, prefix=prefix)
|
acestep/dit_alignment_score.py
ADDED
|
@@ -0,0 +1,877 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
DiT Alignment Score Module
|
| 3 |
+
|
| 4 |
+
This module provides lyrics-to-audio alignment using cross-attention matrices
|
| 5 |
+
from DiT model for generating LRC timestamps.
|
| 6 |
+
|
| 7 |
+
Refactored from lyrics_alignment_infos.py for integration with ACE-Step.
|
| 8 |
+
"""
|
| 9 |
+
import numba
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from dataclasses import dataclass, asdict
|
| 14 |
+
from typing import List, Dict, Any, Optional, Tuple, Union
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ================= Data Classes =================
|
| 18 |
+
@dataclass
class TokenTimestamp:
    """Timing record for a single token."""
    # Identifier of the token
    token_id: int
    # Text of the token
    text: str
    # Start and end of the token (presumably seconds - confirm against the aligner)
    start: float
    end: float
    # Probability attached to the token
    probability: float
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
class SentenceTimestamp:
    """Timing record for a sentence, together with its per-token records."""
    # Text of the sentence
    text: str
    # Start and end of the sentence (presumably seconds - confirm against the aligner)
    start: float
    end: float
    # Per-token timing records belonging to this sentence
    tokens: List[TokenTimestamp]
    # Confidence value attached to the sentence
    confidence: float
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ================= DTW Algorithm (Numba Optimized) =================
|
| 39 |
+
@numba.jit(nopython=True)
def dtw_cpu(x: np.ndarray):
    """
    Dynamic Time Warping algorithm optimized with Numba.

    Fills an accumulated-cost table over the cost matrix and records, for
    every cell, which predecessor was chosen; the path itself is then
    recovered by ``_backtrace``.

    Args:
        x: Cost matrix of shape [N, M]

    Returns:
        The result of ``_backtrace``: a 2-row integer array whose first row
        holds the row indices of ``x`` on the path (text indices) and whose
        second row holds the column indices (time indices).
    """
    N, M = x.shape
    # Use float32 for memory efficiency
    # Both tables carry one extra leading row/column for the boundary.
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)
    cost[0, 0] = 0

    for j in range(1, M + 1):
        for i in range(1, N + 1):
            # Accumulated cost of the three possible predecessors.
            c0 = cost[i - 1, j - 1]  # diagonal
            c1 = cost[i - 1, j]      # previous row, same column
            c2 = cost[i, j - 1]      # same row, previous column

            # Trace codes: 0 = diagonal, 1 = step in i only, 2 = step in j only.
            # A predecessor wins only if strictly cheaper than both others;
            # every tie falls through to the last branch (code 2).
            if c0 < c1 and c0 < c2:
                c, t = c0, 0
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1
            else:
                c, t = c2, 2

            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    return _backtrace(trace, N, M)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@numba.jit(nopython=True)
def _backtrace(trace: np.ndarray, N: int, M: int):
    """
    Optimized backtrace function for DTW.

    Walks from cell (N, M) back towards (0, 0) following the step codes in
    ``trace`` (0 = diagonal, 1 = decrement i, 2 = decrement j) and collects
    the visited cells.

    Args:
        trace: Trace matrix of shape (N+1, M+1); its first row and first
            column are overwritten in place.
        N, M: Original matrix dimensions

    Returns:
        Path array of shape (2, path_len) - first row is text indices, second is time indices
    """
    # Boundary handling
    # On the first row only j can still decrease, on the first column only i.
    trace[0, :] = 2
    trace[:, 0] = 1

    # Pre-allocate array, max path length is N+M
    max_path_len = N + M
    path = np.zeros((2, max_path_len), dtype=np.int32)

    i, j = N, M
    # The buffer is filled from its last slot backwards, so the slice returned
    # at the end is already ordered from the start of the path to its end.
    path_idx = max_path_len - 1

    while i > 0 or j > 0:
        # Table indices are offset by one relative to the cost matrix.
        path[0, path_idx] = i - 1  # text index
        path[1, path_idx] = j - 1  # time index
        path_idx -= 1

        t = trace[i, j]
        if t == 0:
            i -= 1
            j -= 1
        elif t == 1:
            i -= 1
        elif t == 2:
            j -= 1
        else:
            # Any other code (e.g. the -1 fill value) stops the walk.
            break

    # NOTE(review): actual_len is computed but never used below.
    actual_len = max_path_len - path_idx - 1
    return path[:, path_idx + 1:max_path_len]
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
# ================= Utility Functions =================
|
| 119 |
+
def median_filter(x: torch.Tensor, filter_width: int) -> torch.Tensor:
    """
    Apply a sliding median filter along the last dimension of a tensor.

    The input is reflect-padded by ``filter_width // 2`` on both sides of the
    last dimension, so for an odd ``filter_width`` the output has the same
    shape as the input.

    Args:
        x: Input tensor with at least 2 dimensions
        filter_width: Width of median filter

    Returns:
        Filtered tensor with the same number of dimensions as the input. The
        input is returned unchanged when its last dimension is not longer
        than ``filter_width // 2`` (too short to reflect-pad).
    """
    pad_width = filter_width // 2
    if x.shape[-1] <= pad_width:
        return x
    # Padding the last two dims in reflect mode needs at least a 3-D tensor,
    # so a 2-D input temporarily gets a leading axis.
    added_leading_dim = x.ndim == 2
    if added_leading_dim:
        x = x[None, :]
    x = F.pad(x, (pad_width, pad_width, 0, 0), mode="reflect")
    # Sort every window and pick its middle element.
    result = x.unfold(-1, filter_width, 1).sort()[0][..., pad_width]
    # Remove only the axis added above. Squeezing unconditionally would also
    # drop a genuine leading axis of size 1 (e.g. a single attention head).
    if added_leading_dim:
        result = result.squeeze(0)
    return result
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# ================= Main Aligner Class =================
|
| 143 |
+
class MusicStampsAligner:
|
| 144 |
+
"""
|
| 145 |
+
Aligner class for generating lyrics timestamps from cross-attention matrices.
|
| 146 |
+
|
| 147 |
+
Uses bidirectional consensus denoising and DTW for alignment.
|
| 148 |
+
"""
|
| 149 |
+
|
| 150 |
+
def __init__(self, tokenizer):
|
| 151 |
+
"""
|
| 152 |
+
Initialize the aligner.
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
tokenizer: Text tokenizer for decoding tokens
|
| 156 |
+
"""
|
| 157 |
+
self.tokenizer = tokenizer
|
| 158 |
+
|
| 159 |
+
def _apply_bidirectional_consensus(
    self,
    weights_stack: torch.Tensor,
    violence_level: float,
    medfilt_width: int
) -> tuple:
    """
    Core denoising logic using bidirectional consensus.

    Args:
        weights_stack: Attention weights [Heads, Tokens, Frames]
        violence_level: Denoising strength coefficient; scales the medians
            subtracted in the suppression steps
        medfilt_width: Median filter width

    Returns:
        Tuple of (calc_matrix, energy_matrix) as numpy arrays, each averaged
        over the leading (head) axis
    """
    # A. Bidirectional Consensus: keep only what both softmax directions agree on
    row_prob = F.softmax(weights_stack, dim=-1)  # Token -> Frame
    col_prob = F.softmax(weights_stack, dim=-2)  # Frame -> Token
    processed = row_prob * col_prob

    # B1. Row suppression (kill horizontal crossing lines)
    row_medians = torch.quantile(processed, 0.5, dim=-1, keepdim=True)
    processed = torch.relu(processed - (violence_level * row_medians))

    # B2. Column suppression (kill vertical crossing lines)
    col_medians = torch.quantile(processed, 0.5, dim=-2, keepdim=True)
    processed = torch.relu(processed - (violence_level * col_medians))

    # C. Power sharpening
    processed = processed ** 2

    # Energy matrix for confidence
    energy_matrix = processed.mean(dim=0).cpu().numpy()

    # D. Z-Score normalization over the whole stack
    std, mean = torch.std_mean(processed, unbiased=False)
    weights_processed = (processed - mean) / (std + 1e-9)

    # E. Median filtering
    weights_processed = median_filter(weights_processed, filter_width=medfilt_width)
    if weights_processed.ndim < processed.ndim:
        # median_filter can squeeze away a leading axis of size 1 (a single
        # selected head). Restore it so the mean below averages over heads
        # rather than over tokens.
        weights_processed = weights_processed.unsqueeze(0)
    # .cpu() mirrors the energy matrix conversion so a stack living on an
    # accelerator can be converted as well.
    calc_matrix = weights_processed.mean(dim=0).cpu().numpy()

    return calc_matrix, energy_matrix
|
| 206 |
+
|
| 207 |
+
def _preprocess_attention(
|
| 208 |
+
self,
|
| 209 |
+
attention_matrix: torch.Tensor,
|
| 210 |
+
custom_config: Dict[int, List[int]],
|
| 211 |
+
violence_level: float,
|
| 212 |
+
medfilt_width: int = 7
|
| 213 |
+
) -> tuple:
|
| 214 |
+
"""
|
| 215 |
+
Preprocess attention matrix for alignment.
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
attention_matrix: Attention tensor [Layers, Heads, Tokens, Frames]
|
| 219 |
+
custom_config: Dict mapping layer indices to head indices
|
| 220 |
+
violence_level: Denoising strength
|
| 221 |
+
medfilt_width: Median filter width
|
| 222 |
+
|
| 223 |
+
Returns:
|
| 224 |
+
Tuple of (calc_matrix, energy_matrix, visual_matrix)
|
| 225 |
+
"""
|
| 226 |
+
if not isinstance(attention_matrix, torch.Tensor):
|
| 227 |
+
weights = torch.tensor(attention_matrix)
|
| 228 |
+
else:
|
| 229 |
+
weights = attention_matrix.clone()
|
| 230 |
+
|
| 231 |
+
weights = weights.cpu().float()
|
| 232 |
+
|
| 233 |
+
selected_tensors = []
|
| 234 |
+
for layer_idx, head_indices in custom_config.items():
|
| 235 |
+
for head_idx in head_indices:
|
| 236 |
+
if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
|
| 237 |
+
head_matrix = weights[layer_idx, head_idx]
|
| 238 |
+
selected_tensors.append(head_matrix)
|
| 239 |
+
|
| 240 |
+
if not selected_tensors:
|
| 241 |
+
return None, None, None
|
| 242 |
+
|
| 243 |
+
# Stack selected heads: [Heads, Tokens, Frames]
|
| 244 |
+
weights_stack = torch.stack(selected_tensors, dim=0)
|
| 245 |
+
visual_matrix = weights_stack.mean(dim=0).numpy()
|
| 246 |
+
|
| 247 |
+
calc_matrix, energy_matrix = self._apply_bidirectional_consensus(
|
| 248 |
+
weights_stack, violence_level, medfilt_width
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
return calc_matrix, energy_matrix, visual_matrix
|
| 252 |
+
|
| 253 |
+
def stamps_align_info(
    self,
    attention_matrix: torch.Tensor,
    lyrics_tokens: List[int],
    total_duration_seconds: float,
    custom_config: Dict[int, List[int]],
    return_matrices: bool = False,
    violence_level: float = 2.0,
    medfilt_width: int = 1
) -> Dict[str, Any]:
    """
    Build the alignment inputs from a cross-attention matrix.

    Args:
        attention_matrix: Cross-attention tensor [Layers, Heads, Tokens, Frames]
        lyrics_tokens: List of lyrics token IDs
        total_duration_seconds: Total audio duration in seconds
        custom_config: Dict mapping layer indices to head indices
        return_matrices: Whether to return intermediate matrices
        violence_level: Denoising strength
        medfilt_width: Median filter width

    Returns:
        Dict with ``calc_matrix``, ``lyrics_tokens`` and
        ``total_duration_seconds``. When preprocessing finds no usable head,
        ``calc_matrix`` is None and an ``error`` message is added. When
        ``return_matrices`` is set, ``energy_matrix`` and ``vis_matrix`` are
        included as well.
    """
    calc_matrix, energy_matrix, visual_matrix = self._preprocess_attention(
        attention_matrix, custom_config, violence_level, medfilt_width
    )

    info = {
        "calc_matrix": calc_matrix,
        "lyrics_tokens": lyrics_tokens,
        "total_duration_seconds": total_duration_seconds,
    }

    # Failure path: nothing to align, report why.
    if calc_matrix is None:
        info["error"] = "No valid attention heads found"
        return info

    if return_matrices:
        info['energy_matrix'] = energy_matrix
        info['vis_matrix'] = visual_matrix

    return info
|
| 302 |
+
|
| 303 |
+
def _decode_tokens_incrementally(self, token_ids: List[int]) -> List[str]:
|
| 304 |
+
"""
|
| 305 |
+
Decode tokens incrementally to properly handle multi-byte UTF-8 characters.
|
| 306 |
+
|
| 307 |
+
For Chinese and other multi-byte characters, the tokenizer may split them
|
| 308 |
+
into multiple byte-level tokens. Decoding each token individually produces
|
| 309 |
+
invalid UTF-8 sequences (showing as �). This method uses byte-level comparison
|
| 310 |
+
to correctly track which characters each token contributes.
|
| 311 |
+
|
| 312 |
+
Args:
|
| 313 |
+
token_ids: List of token IDs
|
| 314 |
+
|
| 315 |
+
Returns:
|
| 316 |
+
List of decoded text for each token position
|
| 317 |
+
"""
|
| 318 |
+
decoded_tokens = []
|
| 319 |
+
prev_bytes = b""
|
| 320 |
+
|
| 321 |
+
for i in range(len(token_ids)):
|
| 322 |
+
# Decode tokens from start to current position
|
| 323 |
+
current_text = self.tokenizer.decode(token_ids[:i+1], skip_special_tokens=False)
|
| 324 |
+
current_bytes = current_text.encode('utf-8', errors='surrogatepass')
|
| 325 |
+
|
| 326 |
+
# The contribution of current token is the new bytes added
|
| 327 |
+
if len(current_bytes) >= len(prev_bytes):
|
| 328 |
+
new_bytes = current_bytes[len(prev_bytes):]
|
| 329 |
+
# Try to decode the new bytes; if incomplete, use empty string
|
| 330 |
+
try:
|
| 331 |
+
token_text = new_bytes.decode('utf-8')
|
| 332 |
+
except UnicodeDecodeError:
|
| 333 |
+
# Incomplete UTF-8 sequence, this token doesn't complete a character
|
| 334 |
+
token_text = ""
|
| 335 |
+
else:
|
| 336 |
+
# Edge case: current decode is shorter (shouldn't happen normally)
|
| 337 |
+
token_text = ""
|
| 338 |
+
|
| 339 |
+
decoded_tokens.append(token_text)
|
| 340 |
+
prev_bytes = current_bytes
|
| 341 |
+
|
| 342 |
+
return decoded_tokens
|
| 343 |
+
|
| 344 |
+
def token_timestamps(
    self,
    calc_matrix: np.ndarray,
    lyrics_tokens: List[int],
    total_duration_seconds: float
) -> List[TokenTimestamp]:
    """
    Generate per-token timestamps using DTW.

    Args:
        calc_matrix: Processed attention matrix [Tokens, Frames]
        lyrics_tokens: List of token IDs
        total_duration_seconds: Total audio duration

    Returns:
        List of TokenTimestamp objects, one per entry of ``lyrics_tokens``.
        ``probability`` is always 0.0 here.
    """
    frame_count = calc_matrix.shape[-1]

    # DTW minimises cost, so the (higher-is-better) matrix is negated.
    text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float64))

    frame_duration = total_duration_seconds / frame_count
    stamps = []

    # Use incremental decoding to properly handle multi-byte UTF-8 characters
    token_texts = self._decode_tokens_incrementally(lyrics_tokens)

    for position, token_id in enumerate(lyrics_tokens):
        on_path = (text_indices == position)

        if np.any(on_path):
            # First and last frame of the path segment assigned to this token.
            path_times = time_indices[on_path] * frame_duration
            start = path_times[0]
            end = path_times[-1]
        else:
            # Token never visited by the path: zero-length span placed at the
            # end of the previous token (or at 0.0 for the first token).
            start = stamps[-1].end if stamps else 0.0
            end = start

        if end < start:
            end = start

        stamps.append(TokenTimestamp(
            token_id=token_id,
            text=token_texts[position],
            start=float(start),
            end=float(end),
            probability=0.0
        ))

    return stamps
|
| 395 |
+
|
| 396 |
+
def _decode_sentence_from_tokens(self, tokens: List[TokenTimestamp]) -> str:
|
| 397 |
+
"""
|
| 398 |
+
Decode a sentence by decoding all token IDs together.
|
| 399 |
+
This avoids UTF-8 encoding issues from joining individual token texts.
|
| 400 |
+
|
| 401 |
+
Args:
|
| 402 |
+
tokens: List of TokenTimestamp objects
|
| 403 |
+
|
| 404 |
+
Returns:
|
| 405 |
+
Properly decoded sentence text
|
| 406 |
+
"""
|
| 407 |
+
token_ids = [t.token_id for t in tokens]
|
| 408 |
+
return self.tokenizer.decode(token_ids, skip_special_tokens=False)
|
| 409 |
+
|
| 410 |
+
def sentence_timestamps(
    self,
    token_alignment: List[TokenTimestamp]
) -> List[SentenceTimestamp]:
    """
    Group token timestamps into sentence timestamps.

    A sentence ends at every token whose text contains a newline; tokens left
    over after the last newline form a final sentence. Sentences whose decoded
    text is blank are dropped. Confidences are min-max rescaled across all
    sentences when they differ, and rounded to two decimals either way.

    Args:
        token_alignment: List of TokenTimestamp objects

    Returns:
        List of SentenceTimestamp objects
    """
    sentences = []

    def emit(group):
        # Decode all token IDs together to avoid UTF-8 issues
        decoded = self._decode_sentence_from_tokens(group)
        if not decoded.strip():
            return

        # Only strictly positive token probabilities contribute to the mean.
        positive = [tok.probability for tok in group if tok.probability > 0]
        mean_conf = sum(positive) / len(positive) if positive else 0.0

        sentences.append(SentenceTimestamp(
            text=decoded.strip(),
            start=round(group[0].start, 3),
            end=round(group[-1].end, 3),
            tokens=list(group),
            confidence=mean_conf
        ))

    pending = []
    for token in token_alignment:
        pending.append(token)
        if '\n' in token.text:
            emit(pending)
            pending = []

    # Handle last sentence
    if pending:
        emit(pending)

    # Normalize confidence scores
    if sentences:
        raw_scores = [item.confidence for item in sentences]
        lowest = min(raw_scores)
        highest = max(raw_scores)
        spread = highest - lowest
        rescale = spread > 1e-9

        for item in sentences:
            value = (item.confidence - lowest) / spread if rescale else item.confidence
            item.confidence = round(value, 2)

    return sentences
|
| 479 |
+
|
| 480 |
+
def format_lrc(
    self,
    sentence_timestamps: List[SentenceTimestamp],
    include_end_time: bool = False
) -> str:
    """
    Format sentence timestamps as LRC lyrics format.

    Each sentence becomes one line ``[mm:ss.xx]text`` (or
    ``[mm:ss.xx][mm:ss.xx]text`` with the end time). The sentence text is
    emitted unchanged; structural tags such as ``[verse]`` are not removed.

    Args:
        sentence_timestamps: List of SentenceTimestamp objects
        include_end_time: Whether to include end time (enhanced LRC format)

    Returns:
        LRC formatted string (lines joined with a newline; empty string for
        an empty input)
    """
    def lrc_tag(total_seconds: float) -> str:
        # Convert seconds to [mm:ss.xx].
        minutes = int(total_seconds // 60)
        seconds_text = f"{total_seconds % 60:05.2f}"
        # Rounding to two decimals can turn e.g. 59.996 into "60.00", which
        # is not a valid seconds field: carry it into the minutes instead.
        if seconds_text == "60.00":
            minutes += 1
            seconds_text = "00.00"
        return f"[{minutes:02d}:{seconds_text}]"

    lines = []
    for sentence in sentence_timestamps:
        timestamp = lrc_tag(sentence.start)
        if include_end_time:
            timestamp += lrc_tag(sentence.end)
        lines.append(f"{timestamp}{sentence.text}")

    return "\n".join(lines)
|
| 515 |
+
|
| 516 |
+
def get_timestamps_and_lrc(
    self,
    calc_matrix: np.ndarray,
    lyrics_tokens: List[int],
    total_duration_seconds: float
) -> Dict[str, Any]:
    """
    Run the full timestamp pipeline in one call.

    Chains ``token_timestamps`` -> ``sentence_timestamps`` -> ``format_lrc``
    (the latter with its default arguments).

    Args:
        calc_matrix: Processed attention matrix
        lyrics_tokens: List of token IDs
        total_duration_seconds: Total audio duration

    Returns:
        Dict containing token_timestamps, sentence_timestamps, and lrc_text
    """
    per_token = self.token_timestamps(
        calc_matrix=calc_matrix,
        lyrics_tokens=lyrics_tokens,
        total_duration_seconds=total_duration_seconds
    )
    per_sentence = self.sentence_timestamps(per_token)
    lrc = self.format_lrc(per_sentence)

    result = {}
    result["token_timestamps"] = per_token
    result["sentence_timestamps"] = per_sentence
    result["lrc_text"] = lrc
    return result
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
class MusicLyricScorer:
    """
    Scorer class for evaluating lyrics-to-audio alignment quality.

    Focuses on calculating alignment quality metrics (Coverage, Monotonicity, Confidence)
    using tensor operations for potential differentiability or GPU acceleration.

    Typical flow: ``lyrics_alignment_info`` produces ``energy_matrix``,
    ``type_mask`` and ``path_coords``; ``calculate_score`` turns those into a
    single score.
    """

    def __init__(self, tokenizer: Any):
        """
        Initialize the scorer.

        Args:
            tokenizer: Tokenizer instance (must implement .decode()).
        """
        self.tokenizer = tokenizer

    def _generate_token_type_mask(self, token_ids: List[int]) -> np.ndarray:
        """
        Generate a mask distinguishing lyrics (1) from structural tags (0).
        Uses self.tokenizer to decode each token individually.

        Every token from one containing '[' up to and including the next one
        containing ']' is marked 0. A token containing ']' is marked 0 even
        without a preceding '['.

        Args:
            token_ids: List of token IDs.

        Returns:
            Numpy array of shape [len(token_ids)] with 1 or 0.
        """
        decoded_tokens = [self.tokenizer.decode([tid]) for tid in token_ids]
        mask = np.ones(len(token_ids), dtype=np.int32)
        in_bracket = False

        for i, token_str in enumerate(decoded_tokens):
            if '[' in token_str:
                in_bracket = True
            if in_bracket:
                mask[i] = 0
            if ']' in token_str:
                in_bracket = False
                mask[i] = 0
        return mask

    def _preprocess_attention(
        self,
        attention_matrix: Union[torch.Tensor, np.ndarray],
        custom_config: Dict[int, List[int]],
        medfilt_width: int = 1
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[torch.Tensor]]:
        """
        Extracts and normalizes the attention matrix.

        Logic V4: Uses Min-Max normalization to highlight energy differences.

        Args:
            attention_matrix: Raw attention tensor [Layers, Heads, Tokens, Frames].
            custom_config: Config mapping layers to heads.
            medfilt_width: Width for median filtering.

        Returns:
            Tuple of (calc_matrix, energy_matrix, avg_weights_tensor), or
            (None, None, None) when no configured (layer, head) pair exists
            in the input.
        """
        # 1. Prepare Tensor (private float32 CPU copy)
        if not isinstance(attention_matrix, torch.Tensor):
            weights = torch.tensor(attention_matrix)
        else:
            weights = attention_matrix.clone()
        weights = weights.cpu().float()

        # 2. Select Heads based on config, skipping out-of-range pairs
        selected_tensors = []
        for layer_idx, head_indices in custom_config.items():
            for head_idx in head_indices:
                if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
                    selected_tensors.append(weights[layer_idx, head_idx])

        if not selected_tensors:
            return None, None, None

        weights_stack = torch.stack(selected_tensors, dim=0)

        # 3. Average Heads
        avg_weights = weights_stack.mean(dim=0)  # [Tokens, Frames]

        # 4. Preprocessing Logic
        # Median filter is applied first, then Min-Max normalization, which
        # preserves the energy distribution.
        # NOTE(review): median_filter is defined elsewhere in this module.
        energy_tensor = median_filter(avg_weights, filter_width=medfilt_width)
        energy_matrix = energy_tensor.numpy()

        e_min, e_max = energy_matrix.min(), energy_matrix.max()

        if e_max - e_min > 1e-9:
            energy_matrix = (energy_matrix - e_min) / (e_max - e_min)
        else:
            # Flat input carries no alignment information.
            energy_matrix = np.zeros_like(energy_matrix)

        # Contrast enhancement for DTW pathfinding
        # calc_matrix is used for pathfinding, energy_matrix for scoring
        calc_matrix = energy_matrix ** 2

        return calc_matrix, energy_matrix, avg_weights

    def _compute_alignment_metrics(
        self,
        energy_matrix: torch.Tensor,
        path_coords: torch.Tensor,
        type_mask: torch.Tensor,
        time_weight: float = 0.01,
        overlap_frames: float = 9.0,
        instrumental_weight: float = 1.0
    ) -> Tuple[float, float, float]:
        """
        Core metric calculation logic using high-precision Tensor operations.

        Args:
            energy_matrix: Normalized energy [Rows, Cols].
            path_coords: DTW path coordinates [Steps, 2].
            type_mask: Token type mask [Rows] (1=Lyrics, 0=Tags).
            time_weight: Minimum energy threshold for monotonicity.
            overlap_frames: Allowed overlap for monotonicity check.
            instrumental_weight: Weight for non-lyric tokens in confidence calc.

        Returns:
            Tuple of (coverage, monotonicity, confidence) as Python floats.
        """
        # The metrics are computed in float64, which the MPS backend does not
        # support, so tensors living on MPS are brought back to the CPU first.
        if energy_matrix.device.type == "mps":
            energy_matrix = energy_matrix.cpu()
        device = energy_matrix.device

        # Ensure high precision for internal calculation and keep all inputs
        # on the same device as the energy matrix.
        energy_matrix = energy_matrix.to(dtype=torch.float64)
        path_coords = path_coords.to(device=device).long()
        type_mask = type_mask.to(device=device).long()

        rows, cols = energy_matrix.shape

        is_lyrics_row = (type_mask == 1)

        # ================= A. Coverage Score =================
        # Ratio of lyric rows that have a significant energy peak
        row_max_energies = energy_matrix.max(dim=1).values
        total_sung_rows = is_lyrics_row.sum().double()

        coverage_threshold = 0.1
        valid_sung_mask = is_lyrics_row & (row_max_energies > coverage_threshold)
        valid_sung_rows = valid_sung_mask.sum().double()

        if total_sung_rows > 0:
            coverage_score = valid_sung_rows / total_sung_rows
        else:
            # No lyric rows at all: nothing can be uncovered.
            coverage_score = torch.tensor(1.0, device=device, dtype=torch.float64)

        # ================= B. Monotonicity Score =================
        # Check if the "center of mass" of lyric rows moves forward in time
        col_indices = torch.arange(cols, device=device, dtype=torch.float64)

        # Zero out low energy noise
        weights = torch.where(
            energy_matrix > time_weight,
            energy_matrix,
            torch.zeros_like(energy_matrix)
        )

        sum_w = weights.sum(dim=1)
        sum_t = (weights * col_indices).sum(dim=1)

        # Calculate centroids; -1 marks rows without any usable energy
        centroids = torch.full((rows,), -1.0, device=device, dtype=torch.float64)
        valid_w_mask = sum_w > 1e-9
        centroids[valid_w_mask] = sum_t[valid_w_mask] / sum_w[valid_w_mask]

        # Extract sequence of valid lyrics centroids
        valid_sequence_mask = is_lyrics_row & (centroids >= 0)
        sung_centroids = centroids[valid_sequence_mask]

        cnt = sung_centroids.shape[0]
        if cnt > 1:
            curr_c = sung_centroids[:-1]
            next_c = sung_centroids[1:]

            # Check non-decreasing order with overlap tolerance
            non_decreasing = (next_c >= (curr_c - overlap_frames)).double().sum()
            pairs = torch.tensor(cnt - 1, device=device, dtype=torch.float64)
            monotonicity_score = non_decreasing / pairs
        else:
            monotonicity_score = torch.tensor(1.0, device=device, dtype=torch.float64)

        # ================= C. Path Confidence =================
        # Weighted average energy along the optimal path
        if path_coords.shape[0] > 0:
            p_rows = path_coords[:, 0]
            p_cols = path_coords[:, 1]

            path_energies = energy_matrix[p_rows, p_cols]
            step_weights = torch.ones_like(path_energies)

            # Steps on instrumental/tag rows use instrumental_weight
            # (the default of 1.0 weights them like lyric steps).
            is_inst_step = (type_mask[p_rows] == 0)
            step_weights[is_inst_step] = instrumental_weight

            total_energy = (path_energies * step_weights).sum()
            total_steps = step_weights.sum()

            if total_steps > 0:
                path_confidence = total_energy / total_steps
            else:
                path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
        else:
            path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)

        return coverage_score.item(), monotonicity_score.item(), path_confidence.item()

    def lyrics_alignment_info(
        self,
        attention_matrix: Union[torch.Tensor, np.ndarray],
        token_ids: List[int],
        custom_config: Dict[int, List[int]],
        return_matrices: bool = False,
        medfilt_width: int = 1
    ) -> Dict[str, Any]:
        """
        Generates alignment path and processed matrices.

        Args:
            attention_matrix: Input attention tensor.
            token_ids: Corresponding token IDs.
            custom_config: Layer/Head configuration.
            return_matrices: If True, returns matrices in the output.
            medfilt_width: Median filter width.

        Returns:
            Dict with ``path_coords``, ``type_mask`` and ``energy_matrix``
            (plus ``calc_matrix`` and ``vis_matrix`` when ``return_matrices``
            is set). When no valid attention head is found the dict instead
            holds ``calc_matrix`` = None and an ``error`` message.
        """
        calc_matrix, energy_matrix, vis_matrix = self._preprocess_attention(
            attention_matrix, custom_config, medfilt_width
        )

        if calc_matrix is None:
            return {
                "calc_matrix": None,
                "error": "No valid attention heads found"
            }

        # 1. Generate Semantic Mask (1=Lyrics, 0=Tags)
        # Uses self.tokenizer internally
        type_mask = self._generate_token_type_mask(token_ids)

        # Safety check for shape mismatch
        if len(type_mask) != energy_matrix.shape[0]:
            # Fallback to all lyrics if shapes don't align
            type_mask = np.ones(energy_matrix.shape[0], dtype=np.int32)

        # 2. DTW Pathfinding
        # Using negative calc_matrix because DTW minimizes cost
        text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float32))
        path_coords = np.stack([text_indices, time_indices], axis=1)

        return_dict = {
            "path_coords": path_coords,
            "type_mask": type_mask,
            "energy_matrix": energy_matrix
        }
        if return_matrices:
            return_dict['calc_matrix'] = calc_matrix
            return_dict['vis_matrix'] = vis_matrix

        return return_dict

    def calculate_score(
        self,
        energy_matrix: Union[torch.Tensor, np.ndarray],
        type_mask: Union[torch.Tensor, np.ndarray],
        path_coords: Union[torch.Tensor, np.ndarray],
        time_weight: float = 0.01,
        overlap_frames: float = 9.0,
        instrumental_weight: float = 1.0
    ) -> Dict[str, Any]:
        """
        Calculates the final alignment score based on pre-computed components.

        Args:
            energy_matrix: Processed energy matrix.
            type_mask: Token type mask.
            path_coords: DTW path coordinates.
            time_weight: Minimum energy threshold for monotonicity.
            overlap_frames: Allowed backward movement frames.
            instrumental_weight: Weight for non-lyric path steps.

        Returns:
            Dict with a single key ``lyrics_score``: the final score
            ``coverage^2 * monotonicity^2 * confidence`` clipped to [0, 1]
            and rounded to four decimals.
        """
        # Ensure Inputs are Tensors on the correct device
        if not isinstance(energy_matrix, torch.Tensor):
            # Use CUDA when available, otherwise the CPU. MPS is deliberately
            # not selected: the metrics run in float64, which the MPS backend
            # does not support.
            _score_device = "cuda" if torch.cuda.is_available() else "cpu"
            energy_matrix = torch.tensor(energy_matrix, device=_score_device, dtype=torch.float32)

        device = energy_matrix.device

        if not isinstance(type_mask, torch.Tensor):
            type_mask = torch.tensor(type_mask, device=device, dtype=torch.long)
        else:
            type_mask = type_mask.to(device=device, dtype=torch.long)

        if not isinstance(path_coords, torch.Tensor):
            path_coords = torch.tensor(path_coords, device=device, dtype=torch.long)
        else:
            path_coords = path_coords.to(device=device, dtype=torch.long)

        # Compute Metrics
        coverage, monotonicity, confidence = self._compute_alignment_metrics(
            energy_matrix=energy_matrix,
            path_coords=path_coords,
            type_mask=type_mask,
            time_weight=time_weight,
            overlap_frames=overlap_frames,
            instrumental_weight=instrumental_weight
        )

        # Final Score Calculation
        # (Cov^2 * Mono^2 * Conf)
        final_score = (coverage ** 2) * (monotonicity ** 2) * confidence
        final_score = float(np.clip(final_score, 0.0, 1.0))

        return {
            "lyrics_score": round(final_score, 4)
        }
|
acestep/genres_vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
acestep/gpu_config.py
ADDED
|
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GPU Configuration Module
|
| 3 |
+
Centralized GPU memory detection and adaptive configuration management
|
| 4 |
+
|
| 5 |
+
Debug Mode:
|
| 6 |
+
Set environment variable MAX_CUDA_VRAM to simulate different GPU memory sizes.
|
| 7 |
+
Example: MAX_CUDA_VRAM=8 python acestep # Simulates 8GB GPU
|
| 8 |
+
|
| 9 |
+
For MPS testing, use MAX_MPS_VRAM to simulate MPS memory.
|
| 10 |
+
Example: MAX_MPS_VRAM=16 python acestep # Simulates 16GB MPS
|
| 11 |
+
|
| 12 |
+
This is useful for testing GPU tier configurations on high-end hardware.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Optional, List, Dict, Tuple
|
| 19 |
+
from loguru import logger
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Names of the environment variables used for debugging/testing different GPU
# memory configurations (simulated CUDA and MPS memory size respectively).
DEBUG_MAX_CUDA_VRAM_ENV = "MAX_CUDA_VRAM"
DEBUG_MAX_MPS_VRAM_ENV = "MAX_MPS_VRAM"

# Tolerance for 16GB detection: reported VRAM like 15.5GB is effectively 16GB hardware.
# Real-world 16GB GPUs often report 15.7-15.9GB due to system/driver reservations.
VRAM_16GB_TOLERANCE_GB = 0.5
VRAM_16GB_MIN_GB = 16.0 - VRAM_16GB_TOLERANCE_GB  # treat as 16GB class if >= this (15.5)

# PyTorch installation URLs for diagnostics (CUDA and ROCm wheel indexes)
PYTORCH_CUDA_INSTALL_URL = "https://download.pytorch.org/whl/cu121"
PYTORCH_ROCM_INSTALL_URL = "https://download.pytorch.org/whl/rocm6.0"
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
class GPUConfig:
    """
    GPU configuration based on available memory.

    Bundles the memory tier together with the duration, batch-size and
    language-model (LM) limits that apply to it. Limits come in pairs: one
    value for when the LM is initialized and one for when it is not.
    """
    tier: str  # "tier1", "tier2", etc. or "unlimited"
    gpu_memory_gb: float  # GPU memory size in GB

    # Duration limits (in seconds)
    max_duration_with_lm: int  # When LM is initialized
    max_duration_without_lm: int  # When LM is not initialized

    # Batch size limits
    max_batch_size_with_lm: int  # When LM is initialized
    max_batch_size_without_lm: int  # When LM is not initialized

    # LM configuration
    init_lm_default: bool  # Whether to initialize LM by default
    available_lm_models: List[str]  # Available LM models for this tier

    # LM memory allocation (GB) for each model size
    lm_memory_gb: Dict[str, float]  # e.g., {"0.6B": 3, "1.7B": 8, "4B": 12}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# GPU tier configurations
|
| 59 |
+
# Per-tier limits, keyed by the tier names returned by get_gpu_tier().
# Durations are in seconds; "lm_memory_gb" maps LM model size -> GB to allocate.
GPU_TIER_CONFIGS = {
    # Up to 4GB (also used for CPU mode): no LM, 3-minute clips
    "tier1": {
        "max_duration_with_lm": 180,
        "max_duration_without_lm": 180,
        "max_batch_size_with_lm": 1,
        "max_batch_size_without_lm": 1,
        "init_lm_default": False,
        "available_lm_models": [],
        "lm_memory_gb": {},
    },
    # Above 4GB, up to 6GB: no LM, 6-minute clips
    "tier2": {
        "max_duration_with_lm": 360,
        "max_duration_without_lm": 360,
        "max_batch_size_with_lm": 1,
        "max_batch_size_without_lm": 1,
        "init_lm_default": False,
        "available_lm_models": [],
        "lm_memory_gb": {},
    },
    # Above 6GB, up to 8GB: smallest LM is available but off by default (limited memory)
    "tier3": {
        "max_duration_with_lm": 240,  # 4 minutes with LM
        "max_duration_without_lm": 360,  # 6 minutes without LM
        "max_batch_size_with_lm": 1,
        "max_batch_size_without_lm": 2,
        "init_lm_default": False,
        "available_lm_models": ["acestep-5Hz-lm-0.6B"],
        "lm_memory_gb": {"0.6B": 3},
    },
    # Above 8GB, up to 12GB: smallest LM, still off by default
    "tier4": {
        "max_duration_with_lm": 240,  # 4 minutes with LM
        "max_duration_without_lm": 360,  # 6 minutes without LM
        "max_batch_size_with_lm": 2,
        "max_batch_size_without_lm": 4,
        "init_lm_default": False,
        "available_lm_models": ["acestep-5Hz-lm-0.6B"],
        "lm_memory_gb": {"0.6B": 3},
    },
    # Above 12GB, below the 16GB-class threshold: LM on by default, up to 1.7B
    "tier5": {
        "max_duration_with_lm": 240,  # 4 minutes with LM
        "max_duration_without_lm": 360,  # 6 minutes without LM
        "max_batch_size_with_lm": 2,
        "max_batch_size_without_lm": 4,
        "init_lm_default": True,
        "available_lm_models": ["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B"],
        "lm_memory_gb": {"0.6B": 3, "1.7B": 8},
    },
    # 16GB class, up to and including 24GB: all LM sizes, 8-minute clips
    "tier6": {
        "max_duration_with_lm": 480,
        "max_duration_without_lm": 480,
        "max_batch_size_with_lm": 4,
        "max_batch_size_without_lm": 8,
        "init_lm_default": True,
        "available_lm_models": ["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B", "acestep-5Hz-lm-4B"],
        "lm_memory_gb": {"0.6B": 3, "1.7B": 8, "4B": 12},
    },
    # Above 24GB: all LM sizes, 10-minute clips (max supported)
    "unlimited": {
        "max_duration_with_lm": 600,
        "max_duration_without_lm": 600,
        "max_batch_size_with_lm": 8,
        "max_batch_size_without_lm": 8,
        "init_lm_default": True,
        "available_lm_models": ["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B", "acestep-5Hz-lm-4B"],
        "lm_memory_gb": {"0.6B": 3, "1.7B": 8, "4B": 12},
    },
}
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _read_debug_vram_override(env_name: str, device_label: str) -> Optional[float]:
    """Return the simulated memory size in GB requested through ``env_name``.

    Args:
        env_name: Name of the environment variable holding the override.
        device_label: Word used in the log message ("GPU" or "MPS").

    Returns:
        The parsed value, or None when the variable is unset or unusable
        (not a number, or NaN). Unusable values are logged and ignored.
    """
    import math  # function-scope import: only this helper needs it

    raw_value = os.environ.get(env_name)
    if raw_value is None:
        return None

    try:
        simulated_gb = float(raw_value)
    except ValueError:
        simulated_gb = None

    # float() happily parses "nan"; NaN fails every "<=" tier comparison, so it
    # would silently select the largest tier. Treat it like any other bad value.
    if simulated_gb is None or math.isnan(simulated_gb):
        logger.warning(f"Invalid {env_name} value: {raw_value}, ignoring")
        return None

    logger.warning(f"⚠️ DEBUG MODE: Simulating {device_label} memory as {simulated_gb:.1f}GB (set via {env_name} environment variable)")
    return simulated_gb


def get_gpu_memory_gb() -> float:
    """
    Get GPU memory in GB. Returns 0.0 if no GPU is available or detection fails.

    Detection order: debug overrides, CUDA/ROCm, XPU, then MPS. For MPS the
    value comes from torch.mps when exposed, otherwise it is estimated as 75%
    of system memory via ``sysctl``, with 8.0 as the last-resort fallback.

    Debug Mode:
        Set environment variable MAX_CUDA_VRAM to override the detected GPU memory.
        Example: MAX_CUDA_VRAM=8 python acestep  # Simulates 8GB GPU

        For MPS testing, set MAX_MPS_VRAM to override MPS memory detection.
        Example: MAX_MPS_VRAM=16 python acestep  # Simulates 16GB MPS

        MAX_CUDA_VRAM takes precedence when both are set. Values that are not
        numbers (including NaN) are ignored with a warning.

    Returns:
        Memory of the first device in GB (bytes / 1024**3), or the simulated value.
    """
    # Debug overrides win over real detection
    for env_name, device_label in (
        (DEBUG_MAX_CUDA_VRAM_ENV, "GPU"),
        (DEBUG_MAX_MPS_VRAM_ENV, "MPS"),
    ):
        simulated_gb = _read_debug_vram_override(env_name, device_label)
        if simulated_gb is not None:
            return simulated_gb

    try:
        import torch

        if torch.cuda.is_available():
            # Total memory of the first GPU, converted from bytes to GB
            total_memory = torch.cuda.get_device_properties(0).total_memory
            memory_gb = total_memory / (1024**3)
            device_name = torch.cuda.get_device_name(0)
            # ROCm builds report through the CUDA API; torch.version.hip tells them apart
            is_rocm = hasattr(torch.version, 'hip') and torch.version.hip is not None
            if is_rocm:
                logger.info(f"ROCm GPU detected: {device_name} ({memory_gb:.1f} GB, HIP {torch.version.hip})")
            else:
                logger.info(f"CUDA GPU detected: {device_name} ({memory_gb:.1f} GB)")
            return memory_gb
        elif hasattr(torch, 'xpu') and torch.xpu.is_available():
            # Total memory of the first XPU, converted from bytes to GB
            total_memory = torch.xpu.get_device_properties(0).total_memory
            memory_gb = total_memory / (1024**3)
            return memory_gb
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            mps_module = getattr(torch, "mps", None)
            try:
                if mps_module is not None and hasattr(mps_module, "recommended_max_memory"):
                    total_memory = mps_module.recommended_max_memory()
                    memory_gb = total_memory / (1024**3)
                    return memory_gb
                if mps_module is not None and hasattr(mps_module, "get_device_properties"):
                    props = mps_module.get_device_properties(0)
                    total_memory = getattr(props, "total_memory", None)
                    if total_memory:
                        memory_gb = total_memory / (1024**3)
                        return memory_gb
            except Exception as e:
                logger.warning(f"Failed to detect MPS memory: {e}")

            # Fallback: estimate from system unified memory (Apple Silicon shares CPU/GPU RAM)
            try:
                import subprocess
                result = subprocess.run(
                    ["sysctl", "-n", "hw.memsize"],
                    capture_output=True, text=True, timeout=5
                )
                total_system_bytes = int(result.stdout.strip())
                # MPS can use up to ~75% of unified memory for GPU workloads
                memory_gb = (total_system_bytes / (1024**3)) * 0.75
                return memory_gb
            except Exception:
                logger.warning(f"MPS available but total memory not exposed. Set {DEBUG_MAX_MPS_VRAM_ENV} to enable tiering.")
                # Conservative fallback for M1/M2
                return 8.0
        else:
            # No GPU detected - provide diagnostic information
            _log_gpu_diagnostic_info(torch)
            return 0.0
    except Exception as e:
        # Best effort: a missing or broken torch install is reported as "no GPU"
        logger.warning(f"Failed to detect GPU memory: {e}")
        return 0.0
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def _log_gpu_diagnostic_info(torch_module):
    """
    Emit a troubleshooting report (at WARNING level) for the "no GPU detected" case.

    The report depends on the PyTorch build: ROCm (torch.version.hip is set),
    CUDA (torch.version.cuda is set), or CPU-only (neither is set).

    Args:
        torch_module: The torch module to inspect for build information
    """
    version_info = torch_module.version
    hip_version = getattr(version_info, 'hip', None)
    cuda_version = getattr(version_info, 'cuda', None)

    banner = "=" * 80
    report = [banner, "⚠️ GPU NOT DETECTED - DIAGNOSTIC INFORMATION", banner]

    if hip_version is not None:
        # ROCm build: the runtime is present in PyTorch but no device is visible
        report += [
            "✓ PyTorch ROCm build detected",
            f" HIP version: {hip_version}",
            "",
            "❌ torch.cuda.is_available() returned False",
            "",
            "Common causes for AMD/ROCm GPUs:",
            " 1. ROCm drivers not installed or not properly configured",
            " 2. GPU not supported by installed ROCm version",
            " 3. Missing or incorrect HSA_OVERRIDE_GFX_VERSION environment variable",
            " 4. ROCm runtime libraries not in system path",
            "",
        ]

        # Report the state of the override variable named in cause 3
        hsa_override = os.environ.get('HSA_OVERRIDE_GFX_VERSION')
        if hsa_override:
            report.append(f" HSA_OVERRIDE_GFX_VERSION is set to: {hsa_override}")
        else:
            report += [
                " ⚠️ HSA_OVERRIDE_GFX_VERSION is not set",
                " For RDNA3 GPUs (RX 7000 series, RX 9000 series):",
                " - RX 7900 XT/XTX, RX 9070 XT: set HSA_OVERRIDE_GFX_VERSION=11.0.0",
                " - RX 7800 XT, RX 7700 XT: set HSA_OVERRIDE_GFX_VERSION=11.0.1",
                " - RX 7600: set HSA_OVERRIDE_GFX_VERSION=11.0.2",
            ]

        report += [
            "",
            "Troubleshooting steps:",
            " 1. Verify ROCm installation:",
            " rocm-smi # Should list your GPU",
            " 2. Check PyTorch ROCm build:",
            # Plain string on purpose: the braces are part of the command shown to the user
            " python -c \"import torch; print(f'ROCm: {torch.version.hip}')\"",
            " 3. Set HSA_OVERRIDE_GFX_VERSION for your GPU (see above)",
            " 4. On Windows: Use start_gradio_ui_rocm.bat which sets required env vars",
            " 5. See docs/en/ACE-Step1.5-Rocm-Manual-Linux.md for Linux setup",
            " 6. See requirements-rocm.txt for Windows ROCm setup instructions",
        ]
    elif cuda_version is not None:
        # CUDA build without a usable NVIDIA device
        report += [
            "✓ PyTorch CUDA build detected",
            f" CUDA version: {cuda_version}",
            "",
            "❌ torch.cuda.is_available() returned False",
            "",
            "Common causes for NVIDIA GPUs:",
            " 1. NVIDIA drivers not installed",
            " 2. CUDA runtime not installed or version mismatch",
            " 3. GPU not supported by installed CUDA version",
            "",
            "Troubleshooting steps:",
            " 1. Verify NVIDIA driver installation:",
            " nvidia-smi # Should list your GPU",
            " 2. Check CUDA version compatibility",
            " 3. Reinstall PyTorch with CUDA support:",
            f" pip install torch --index-url {PYTORCH_CUDA_INSTALL_URL}",
        ]
    else:
        # Neither HIP nor CUDA version is reported: CPU-only wheel
        report += [
            "⚠️ PyTorch build type: CPU-only",
            "",
            "You have installed a CPU-only version of PyTorch!",
            "",
            "For NVIDIA GPUs:",
            f" pip install torch --index-url {PYTORCH_CUDA_INSTALL_URL}",
            "",
            "For AMD GPUs with ROCm:",
            " Windows: See requirements-rocm.txt for detailed instructions",
            f" Linux: pip install torch --index-url {PYTORCH_ROCM_INSTALL_URL}",
            "",
            "For more information, see README.md section 'AMD / ROCm GPUs'",
        ]

    report.append(banner)
    for line in report:
        logger.warning(line)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def get_gpu_tier(gpu_memory_gb: float) -> str:
    """
    Map a GPU memory size to its tier name.

    Args:
        gpu_memory_gb: GPU memory in GB. Zero or negative means CPU mode.

    Returns:
        Tier string: "tier1", "tier2", "tier3", "tier4", "tier5", "tier6", or "unlimited"
    """
    # Inclusive upper bounds of the small tiers; CPU mode (<= 0) shares tier1 limits
    small_tiers = ((4, "tier1"), (6, "tier2"), (8, "tier3"), (12, "tier4"))
    for upper_bound_gb, tier_name in small_tiers:
        if gpu_memory_gb <= upper_bound_gb:
            return tier_name

    if gpu_memory_gb < VRAM_16GB_MIN_GB:
        return "tier5"

    # Written as a negated "<=" so any value that fails the comparison falls through here
    if not gpu_memory_gb <= 24:
        return "unlimited"

    # Slightly under 16GB but within tolerance: still counted as a 16GB-class GPU
    if gpu_memory_gb < 16.0:
        logger.info(f"Detected {gpu_memory_gb:.2f}GB VRAM — treating as 16GB class GPU")
    return "tier6"
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def get_gpu_config(gpu_memory_gb: Optional[float] = None) -> GPUConfig:
    """
    Get GPU configuration based on detected or provided GPU memory.

    Args:
        gpu_memory_gb: GPU memory in GB. If None, will be auto-detected.

    Returns:
        GPUConfig object with all configuration parameters. Its
        ``available_lm_models`` list and ``lm_memory_gb`` dict are fresh copies,
        so modifying them does not alter GPU_TIER_CONFIGS.
    """
    if gpu_memory_gb is None:
        gpu_memory_gb = get_gpu_memory_gb()

    tier = get_gpu_tier(gpu_memory_gb)
    tier_settings = GPU_TIER_CONFIGS[tier]

    return GPUConfig(
        tier=tier,
        gpu_memory_gb=gpu_memory_gb,
        max_duration_with_lm=tier_settings["max_duration_with_lm"],
        max_duration_without_lm=tier_settings["max_duration_without_lm"],
        max_batch_size_with_lm=tier_settings["max_batch_size_with_lm"],
        max_batch_size_without_lm=tier_settings["max_batch_size_without_lm"],
        init_lm_default=tier_settings["init_lm_default"],
        # Copy the containers: handing out the table's own list/dict would let
        # a caller's in-place edit change the tier definition for everyone.
        available_lm_models=list(tier_settings["available_lm_models"]),
        lm_memory_gb=dict(tier_settings["lm_memory_gb"]),
    )
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def get_lm_model_size(model_path: str) -> str:
    """
    Identify the LM model size from a model path by substring search.

    Args:
        model_path: Model path string (e.g., "acestep-5Hz-lm-0.6B")

    Returns:
        Model size string: "0.6B", "1.7B", or "4B". The sizes are tried in
        that order and the first one contained in ``model_path`` wins; when
        none matches, the smallest size "0.6B" is assumed.
    """
    for size_tag in ("0.6B", "1.7B", "4B"):
        if size_tag in model_path:
            return size_tag
    # Unknown naming scheme: assume the smallest model
    return "0.6B"
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def get_lm_gpu_memory_ratio(model_path: str, total_gpu_memory_gb: float) -> Tuple[float, float]:
    """
    Calculate GPU memory utilization ratio for LM model.

    The ratio is the model's target allocation divided by the total GPU
    memory, clamped to at most 0.9 and to at least 0.2 (GPUs with >= 24GB)
    or 0.1 (smaller GPUs).

    Args:
        model_path: LM model path (e.g., "acestep-5Hz-lm-0.6B")
        total_gpu_memory_gb: Total GPU memory in GB. Zero or negative (no GPU
            detected) yields the 0.9 cap instead of raising ZeroDivisionError.

    Returns:
        Tuple of (gpu_memory_utilization_ratio, target_memory_gb)
    """
    model_size = get_lm_model_size(model_path)

    # Target memory allocation (GB) for each model size
    target_memory = {
        "0.6B": 3.0,
        "1.7B": 8.0,
        "4B": 12.0,
    }
    target_gb = target_memory.get(model_size, 3.0)

    # No usable memory figure: the target cannot fit, so report the cap. This
    # is also the value the formula approaches as the total shrinks toward zero.
    if total_gpu_memory_gb <= 0:
        return 0.9, target_gb

    # Large GPUs (>= 24GB) get a higher floor so the model is not starved;
    # smaller GPUs are limited more strictly.
    min_ratio = 0.2 if total_gpu_memory_gb >= 24 else 0.1
    ratio = min(0.9, max(min_ratio, target_gb / total_gpu_memory_gb))

    return ratio, target_gb
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def check_duration_limit(
    duration: float,
    gpu_config: GPUConfig,
    lm_initialized: bool
) -> Tuple[bool, str]:
    """
    Validate a requested duration against the limit of the given GPU configuration.

    Args:
        duration: Requested duration in seconds
        gpu_config: Current GPU configuration
        lm_initialized: Whether LM is initialized (selects which limit applies)

    Returns:
        Tuple of (is_valid, warning_message). The message is empty when the
        duration is within the limit. This function only reports; it does not
        clamp the value itself.
    """
    if lm_initialized:
        limit_s, lm_label = gpu_config.max_duration_with_lm, "with"
    else:
        limit_s, lm_label = gpu_config.max_duration_without_lm, "without"

    # Guard clause: anything not strictly above the limit is accepted
    if not duration > limit_s:
        return True, ""

    warning_msg = (
        f"⚠️ Requested duration ({duration:.0f}s) exceeds the limit for your GPU "
        f"({gpu_config.gpu_memory_gb:.1f}GB). Maximum allowed: {limit_s}s "
        f"({lm_label} LM). "
        f"Duration will be clamped to {limit_s}s."
    )
    return False, warning_msg
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
def check_batch_size_limit(
    batch_size: int,
    gpu_config: GPUConfig,
    lm_initialized: bool
) -> Tuple[bool, str]:
    """
    Validate a requested batch size against the limit of the given GPU configuration.

    Args:
        batch_size: Requested batch size
        gpu_config: Current GPU configuration
        lm_initialized: Whether LM is initialized (selects which limit applies)

    Returns:
        Tuple of (is_valid, warning_message). The message is empty when the
        batch size is within the limit. This function only reports; it does
        not clamp the value itself.
    """
    if lm_initialized:
        limit, lm_label = gpu_config.max_batch_size_with_lm, "with"
    else:
        limit, lm_label = gpu_config.max_batch_size_without_lm, "without"

    # Guard clause: anything not strictly above the limit is accepted
    if not batch_size > limit:
        return True, ""

    warning_msg = (
        f"⚠️ Requested batch size ({batch_size}) exceeds the limit for your GPU "
        f"({gpu_config.gpu_memory_gb:.1f}GB). Maximum allowed: {limit} "
        f"({lm_label} LM). "
        f"Batch size will be clamped to {limit}."
    )
    return False, warning_msg
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
def is_lm_model_supported(model_path: str, gpu_config: GPUConfig) -> Tuple[bool, str]:
    """
    Tell whether an LM model may be used with the given GPU configuration.

    A model is supported when its size tag (as extracted by get_lm_model_size)
    occurs in the name of at least one model available for the tier.

    Args:
        model_path: LM model path
        gpu_config: Current GPU configuration

    Returns:
        Tuple of (is_supported, warning_message); the message is empty when supported
    """
    tier_models = gpu_config.available_lm_models

    # Tiers without any LM model cannot run an LM at all
    if not tier_models:
        return False, (
            f"⚠️ Your GPU ({gpu_config.gpu_memory_gb:.1f}GB) does not have enough memory "
            "to run any LM model. Please disable LM initialization."
        )

    model_size = get_lm_model_size(model_path)
    if any(model_size in candidate for candidate in tier_models):
        return True, ""

    return False, (
        f"⚠️ LM model {model_path} ({model_size}) is not supported for your GPU "
        f"({gpu_config.gpu_memory_gb:.1f}GB). Available models: {', '.join(tier_models)}"
    )
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
def get_recommended_lm_model(gpu_config: GPUConfig) -> Optional[str]:
    """
    Pick the LM model to recommend for the given GPU configuration.

    Args:
        gpu_config: Current GPU configuration

    Returns:
        The last entry of ``gpu_config.available_lm_models`` (the list is
        expected to be ordered smallest to largest), or None if the tier has
        no LM models
    """
    tier_models = gpu_config.available_lm_models
    return tier_models[-1] if tier_models else None
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def print_gpu_config_info(gpu_config: GPUConfig):
    """Log a summary of ``gpu_config`` at INFO level, one setting per line."""
    with_lm_s = gpu_config.max_duration_with_lm
    without_lm_s = gpu_config.max_duration_without_lm

    summary_lines = (
        "GPU Configuration:",
        f" - GPU Memory: {gpu_config.gpu_memory_gb:.1f} GB",
        f" - Tier: {gpu_config.tier}",
        f" - Max Duration (with LM): {with_lm_s}s ({with_lm_s // 60} min)",
        f" - Max Duration (without LM): {without_lm_s}s ({without_lm_s // 60} min)",
        f" - Max Batch Size (with LM): {gpu_config.max_batch_size_with_lm}",
        f" - Max Batch Size (without LM): {gpu_config.max_batch_size_without_lm}",
        f" - Init LM by Default: {gpu_config.init_lm_default}",
        # An empty model list is shown as the word None
        f" - Available LM Models: {gpu_config.available_lm_models or 'None'}",
    )
    for line in summary_lines:
        logger.info(line)
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
# Global GPU config instance (initialized lazily)
|
| 535 |
+
_global_gpu_config: Optional[GPUConfig] = None
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def get_global_gpu_config() -> GPUConfig:
    """Return the module-wide GPU configuration, creating it on first use.

    The first call runs get_gpu_config() (which auto-detects GPU memory) and
    caches the result; later calls return the cached object.
    """
    global _global_gpu_config
    if _global_gpu_config is not None:
        return _global_gpu_config
    _global_gpu_config = get_gpu_config()
    return _global_gpu_config
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def set_global_gpu_config(config: GPUConfig):
    """Replace the module-wide GPU configuration with ``config``.

    Args:
        config: Configuration that get_global_gpu_config() will return from now on
    """
    global _global_gpu_config
    _global_gpu_config = config
|
acestep/gradio_ui/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from acestep.gradio_ui.interfaces import create_gradio_interface
|
acestep/gradio_ui/api_routes.py
ADDED
|
@@ -0,0 +1,564 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio API Routes Module
|
| 3 |
+
Add API endpoints compatible with api_server.py and CustomAceStep to Gradio application
|
| 4 |
+
"""
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import random
|
| 8 |
+
import time
|
| 9 |
+
from typing import Any, Dict, List, Optional
|
| 10 |
+
from uuid import uuid4
|
| 11 |
+
|
| 12 |
+
from fastapi import APIRouter, HTTPException, Request, Depends, Header
|
| 13 |
+
from fastapi.responses import FileResponse
|
| 14 |
+
|
| 15 |
+
# Global results directory inside project root
|
| 16 |
+
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 17 |
+
DEFAULT_RESULTS_DIR = os.path.join(PROJECT_ROOT, "gradio_outputs").replace("\\", "/")
|
| 18 |
+
os.makedirs(DEFAULT_RESULTS_DIR, exist_ok=True)
|
| 19 |
+
|
| 20 |
+
# API Key storage (set via setup_api_routes)
|
| 21 |
+
_api_key: Optional[str] = None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def set_api_key(key: Optional[str]):
    """Store the API key that the verification helpers in this module compare against.

    Args:
        key: The expected key, or None to disable authentication
    """
    global _api_key
    _api_key = key
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _wrap_response(data: Any, code: int = 200, error: Optional[str] = None) -> Dict[str, Any]:
    """Build the standard response envelope compatible with CustomAceStep.

    Args:
        data: Payload placed under "data"
        code: Status code placed under "code"
        error: Error text placed under "error", or None

    Returns:
        Dict with keys "data", "code", "error", "timestamp" (current time in
        whole milliseconds) and "extra" (always None)
    """
    now_ms = int(time.time() * 1000)
    envelope: Dict[str, Any] = {"data": data, "code": code, "error": error}
    envelope["timestamp"] = now_ms
    envelope["extra"] = None
    return envelope
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def verify_token_from_request(body: dict, authorization: Optional[str] = None) -> Optional[str]:
    """
    Verify API key from request body (ai_token) or Authorization header.

    The body's "ai_token" is checked first; the Authorization header (with or
    without a "Bearer " prefix) is only consulted when no ai_token is given.

    Args:
        body: Parsed request body. Anything that is not a dict is treated as
            carrying no ai_token.
        authorization: Raw Authorization header value, if any

    Returns:
        The token if valid, None if no auth required.

    Raises:
        HTTPException: 401 when the token is wrong or missing while a key is configured
    """
    import hmac  # function-scope import: only needed for the comparison below

    expected_key = _api_key
    if expected_key is None:
        return None  # No auth required

    def _matches(candidate: object) -> bool:
        # Constant-time comparison so response timing does not reveal how much
        # of a guessed key was correct. Compared as bytes because
        # compare_digest only accepts str arguments when they are pure ASCII.
        if isinstance(candidate, str) and isinstance(expected_key, str):
            return hmac.compare_digest(
                candidate.encode("utf-8", "surrogatepass"),
                expected_key.encode("utf-8", "surrogatepass"),
            )
        # Non-string values keep plain equality semantics
        return candidate == expected_key

    # Try ai_token from body first; a non-dict body (e.g. a JSON list) has none
    ai_token = body.get("ai_token") if isinstance(body, dict) else None
    if ai_token:
        if _matches(ai_token):
            return ai_token
        raise HTTPException(status_code=401, detail="Invalid ai_token")

    # Fallback to Authorization header
    if authorization:
        if authorization.startswith("Bearer "):
            token = authorization[7:]
        else:
            token = authorization
        if _matches(token):
            return token
        raise HTTPException(status_code=401, detail="Invalid API key")

    # No token provided but auth is required
    raise HTTPException(status_code=401, detail="Missing ai_token or Authorization header")
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
async def verify_api_key(authorization: Optional[str] = Header(None)):
    """Verify API key from Authorization header (legacy, for non-body endpoints)

    Accepts the key either bare or in "Bearer <key>" form. Does nothing when
    no API key is configured.

    Raises:
        HTTPException: 401 when the header is missing or the key does not match
    """
    import hmac  # function-scope import: only needed for the comparison below

    expected_key = _api_key
    if expected_key is None:
        return  # No auth required

    if not authorization:
        raise HTTPException(status_code=401, detail="Missing Authorization header")

    # Support "Bearer <key>" format
    if authorization.startswith("Bearer "):
        token = authorization[7:]
    else:
        token = authorization

    # Constant-time comparison so response timing does not reveal how much of
    # a guessed key was correct. Compared as bytes because compare_digest only
    # accepts str arguments when they are pure ASCII.
    if isinstance(expected_key, str):
        key_ok = hmac.compare_digest(
            token.encode("utf-8", "surrogatepass"),
            expected_key.encode("utf-8", "surrogatepass"),
        )
    else:
        key_ok = token == expected_key

    if not key_ok:
        raise HTTPException(status_code=401, detail="Invalid API key")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# Use diskcache to store results
|
| 89 |
+
try:
    # Preferred backend: on-disk cache stored next to this module
    import diskcache

    _cache_dir = os.path.join(os.path.dirname(__file__), ".cache", "api_results")
    os.makedirs(_cache_dir, exist_ok=True)
    _result_cache = diskcache.Cache(_cache_dir)
    DISKCACHE_AVAILABLE = True
except ImportError:
    # diskcache is not installed: fall back to a plain in-process dict
    DISKCACHE_AVAILABLE = False
    _result_cache = {}
|
| 98 |
+
|
| 99 |
+
RESULT_EXPIRE_SECONDS = 7 * 24 * 60 * 60 # 7 days expiration
|
| 100 |
+
RESULT_KEY_PREFIX = "ace_step_v1.5_"
|
| 101 |
+
|
| 102 |
+
# =============================================================================
|
| 103 |
+
# Example Data for Random Sample
|
| 104 |
+
# =============================================================================
|
| 105 |
+
|
| 106 |
+
def _get_project_root() -> str:
    """Return the project root: the directory three levels above this file's absolute path."""
    location = os.path.abspath(__file__)
    for _ in range(3):
        location = os.path.dirname(location)
    return location
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _load_all_examples(sample_mode: str = "simple_mode") -> List[Dict[str, Any]]:
    """Load all example JSON files from examples directory

    Args:
        sample_mode: "simple_mode" reads ``examples/simple_mode``; any other
            value reads ``examples/text2music``.

    Returns:
        All examples found under the chosen directory. A file whose top level
        is a list contributes each element, a dict contributes itself, and
        other JSON values are ignored. Returns an empty list when the
        directory does not exist. Files that cannot be read or parsed are
        skipped with a logged warning.
    """
    import logging  # function-scope import: only used to report skipped files

    subdir = "simple_mode" if sample_mode == "simple_mode" else "text2music"
    examples_dir = os.path.join(_get_project_root(), "examples", subdir)

    if not os.path.isdir(examples_dir):
        return []

    all_examples: List[Dict[str, Any]] = []
    for filename in os.listdir(examples_dir):
        if not filename.endswith(".json"):
            continue
        filepath = os.path.join(examples_dir, filename)
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception as exc:
            # Best effort: one broken example must not prevent the rest from
            # loading, but say which file was dropped instead of hiding it.
            logging.getLogger(__name__).warning(
                "Skipping unreadable example file %s: %s", filepath, exc
            )
            continue

        if isinstance(data, list):
            all_examples.extend(data)
        elif isinstance(data, dict):
            all_examples.append(data)
    return all_examples
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# Pre-load example data
|
| 139 |
+
SIMPLE_EXAMPLE_DATA = _load_all_examples("simple_mode")
|
| 140 |
+
CUSTOM_EXAMPLE_DATA = _load_all_examples("custom_mode")
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def store_result(task_id: str, result: dict, status: str = "succeeded") -> None:
    """Store a task result in the result cache.

    The entry is written under ``RESULT_KEY_PREFIX + task_id`` and wraps the
    payload together with its creation timestamp and status.

    Args:
        task_id: Identifier of the task the result belongs to.
        result: Result payload to store.
        status: Status string saved alongside the payload.
    """
    data = {
        "result": result,
        "created_at": time.time(),
        "status": status,
    }
    key = f"{RESULT_KEY_PREFIX}{task_id}"
    if DISKCACHE_AVAILABLE:
        # diskcache drops the entry itself once the expiry has elapsed.
        _result_cache.set(key, data, expire=RESULT_EXPIRE_SECONDS)
    else:
        # Fallback when diskcache could not be imported: a plain dict, which
        # does not evict entries on its own.
        _result_cache[key] = data
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def get_result(task_id: str) -> Optional[dict]:
    """Fetch a stored task result from the result cache.

    Args:
        task_id: Identifier of the task to look up.

    Returns:
        The stored entry (a dict with ``result``, ``created_at`` and
        ``status`` keys as written by ``store_result``), or ``None`` when
        there is no entry or it has expired.
    """
    key = f"{RESULT_KEY_PREFIX}{task_id}"
    data = _result_cache.get(key)
    if DISKCACHE_AVAILABLE or data is None:
        # diskcache applies the expiry given at write time by itself.
        return data

    # The plain-dict fallback never evicts, so apply the same expiry window
    # here using the timestamp recorded when the entry was stored.
    created_at = data.get("created_at") if isinstance(data, dict) else None
    if created_at is not None and time.time() - created_at > RESULT_EXPIRE_SECONDS:
        _result_cache.pop(key, None)
        return None
    return data
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# Router that collects the API endpoints declared below; the setup functions
# at the end of this module attach it to the application.
router = APIRouter()
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
@router.get("/health")
async def health_check():
    """Health check endpoint.

    Returns a wrapped response carrying a static service descriptor. The
    route declares no API-key dependency.
    """
    return _wrap_response({
        "status": "ok",
        "service": "ACE-Step Gradio API",
        "version": "1.0",
    })
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
@router.get("/v1/models")
async def list_models(request: Request, _: None = Depends(verify_api_key)):
    """List available DiT models.

    Only the currently loaded model is reported. Its name is the last path
    component of the handler's ``config_path``, or ``"unknown"`` when no
    usable path is set.

    Returns:
        A wrapped response with ``models`` (zero or one entry) and
        ``default_model`` (the entry's name, or ``None`` when nothing is
        loaded).
    """
    dit_handler = request.app.state.dit_handler

    models = []
    if dit_handler and dit_handler.model is not None:
        # Name the loaded model after the final component of its config path.
        config_path = getattr(dit_handler, "config_path", "") or ""
        # A path made only of separators strips down to "", so fall back to
        # "unknown" instead of reporting an empty name.
        model_name = os.path.basename(config_path.rstrip("/\\")) or "unknown"
        models.append({
            "name": model_name,
            "is_default": True,
        })

    return _wrap_response({
        "models": models,
        "default_model": models[0]["name"] if models else None,
    })
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
@router.get("/v1/audio")
async def get_audio(path: str, _: None = Depends(verify_api_key)):
    """Download an audio file from the results directory.

    Args:
        path: Path of the requested file, taken from the query string.

    Returns:
        A ``FileResponse`` for the file. The media type is chosen from the
        file extension and defaults to ``audio/mpeg`` for unknown extensions.

    Raises:
        HTTPException: 403 when the resolved path lies outside
            ``DEFAULT_RESULTS_DIR``; 404 when it is not an existing regular
            file.
    """
    # Security: resolve symlinks and ".." first, then require the result to
    # stay inside the results directory to prevent path traversal.
    resolved_path = os.path.realpath(path)
    allowed_dir = os.path.realpath(DEFAULT_RESULTS_DIR)
    try:
        common = os.path.commonpath([resolved_path, allowed_dir])
    except ValueError:
        # Raised e.g. when the two paths are on different Windows drives.
        common = None
    if common is None or os.path.normcase(common) != os.path.normcase(allowed_dir):
        raise HTTPException(status_code=403, detail="Access denied: path outside allowed directory")
    # isfile rather than exists: the results directory itself (or any
    # sub-directory) passes the containment check but cannot be served.
    if not os.path.isfile(resolved_path):
        raise HTTPException(status_code=404, detail="Audio file not found")

    ext = os.path.splitext(resolved_path)[1].lower()
    media_types = {
        ".mp3": "audio/mpeg",
        ".wav": "audio/wav",
        ".flac": "audio/flac",
        ".ogg": "audio/ogg",
    }
    media_type = media_types.get(ext, "audio/mpeg")

    return FileResponse(resolved_path, media_type=media_type)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
@router.post("/create_random_sample")
async def create_random_sample(request: Request, authorization: Optional[str] = Header(None)):
    """Get random sample parameters from pre-loaded example data.

    The body may be JSON (any content type containing "json") or form data.
    ``sample_type`` selects the pool: ``"simple_mode"`` (the default) uses
    ``SIMPLE_EXAMPLE_DATA``, anything else uses ``CUSTOM_EXAMPLE_DATA``.

    Returns:
        A wrapped response holding one randomly chosen example, or a wrapped
        error when the body is malformed or the selected pool is empty.
    """
    content_type = (request.headers.get("content-type") or "").lower()

    if "json" in content_type:
        try:
            body = await request.json()
        except ValueError:
            # Malformed JSON would otherwise surface as an unhandled error.
            return _wrap_response(None, code=400, error="Invalid JSON body")
        if not isinstance(body, dict):
            # The code below reads the body with .get(), so it must be an object.
            return _wrap_response(None, code=400, error="JSON body must be an object")
    else:
        form = await request.form()
        body = dict(form.items())

    verify_token_from_request(body, authorization)
    sample_type = body.get("sample_type", "simple_mode") or "simple_mode"

    example_data = SIMPLE_EXAMPLE_DATA if sample_type == "simple_mode" else CUSTOM_EXAMPLE_DATA

    if not example_data:
        return _wrap_response(None, code=500, error="No example data available")

    return _wrap_response(random.choice(example_data))
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
@router.post("/query_result")
async def query_result(request: Request, authorization: Optional[str] = Header(None)):
    """Batch query task results.

    The body may be JSON or form data and carries ``task_id_list``, either as
    a list or as a JSON-encoded string of a list. Anything that does not
    decode to a list is treated as an empty list.

    Returns:
        A wrapped list with one item per requested task id. ``status`` is 1
        with the JSON-encoded result when a succeeded result is stored, and
        0 with ``"[]"`` otherwise.
    """
    content_type = (request.headers.get("content-type") or "").lower()

    if "json" in content_type:
        try:
            body = await request.json()
        except ValueError:
            # Malformed JSON would otherwise surface as an unhandled error.
            return _wrap_response(None, code=400, error="Invalid JSON body")
        if not isinstance(body, dict):
            # The code below reads the body with .get(), so it must be an object.
            return _wrap_response(None, code=400, error="JSON body must be an object")
    else:
        form = await request.form()
        body = dict(form.items())

    verify_token_from_request(body, authorization)
    task_ids = body.get("task_id_list", [])

    if isinstance(task_ids, str):
        try:
            task_ids = json.loads(task_ids)
        except (ValueError, RecursionError):
            task_ids = []
    if not isinstance(task_ids, list):
        # e.g. the string "123" decodes to an int, which cannot be iterated.
        task_ids = []

    results = []
    for task_id in task_ids:
        data = get_result(task_id)
        if data and data.get("status") == "succeeded":
            results.append({
                "task_id": task_id,
                "status": 1,
                "result": json.dumps(data["result"], ensure_ascii=False),
            })
        else:
            results.append({
                "task_id": task_id,
                "status": 0,
                "result": "[]",
            })

    return _wrap_response(results)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
@router.post("/format_input")
async def format_input(request: Request, authorization: Optional[str] = Header(None)):
    """Format and enhance lyrics/caption via LLM.

    The body may be JSON or form data with ``prompt``, ``lyrics`` and an
    optional ``temperature`` (default 0.85).

    Returns:
        A wrapped response with the formatted caption and lyrics plus the
        bpm, key/scale, time signature, duration and vocal language reported
        by ``format_sample``; or a wrapped error when the LLM is not
        initialized, the request is malformed, or formatting fails.
    """
    llm_handler = request.app.state.llm_handler

    if not llm_handler or not llm_handler.llm_initialized:
        return _wrap_response(None, code=500, error="LLM not initialized")

    content_type = (request.headers.get("content-type") or "").lower()
    if "json" in content_type:
        try:
            body = await request.json()
        except ValueError:
            # Malformed JSON would otherwise surface as an unhandled error.
            return _wrap_response(None, code=400, error="Invalid JSON body")
        if not isinstance(body, dict):
            # The code below reads the body with .get(), so it must be an object.
            return _wrap_response(None, code=400, error="JSON body must be an object")
    else:
        form = await request.form()
        body = dict(form.items())

    verify_token_from_request(body, authorization)

    caption = body.get("prompt", "") or ""
    lyrics = body.get("lyrics", "") or ""
    # A JSON null or an empty form field means "use the default"; float(None)
    # or float("") would raise before the try block below is reached.
    raw_temperature = body.get("temperature")
    try:
        temperature = 0.85 if raw_temperature in (None, "") else float(raw_temperature)
    except (TypeError, ValueError):
        return _wrap_response(None, code=400, error="Invalid temperature")

    from acestep.inference import format_sample

    try:
        # NOTE(review): format_sample is called directly inside this
        # coroutine; if it blocks, other requests wait until it returns.
        result = format_sample(
            llm_handler=llm_handler,
            caption=caption,
            lyrics=lyrics,
            temperature=temperature,
            use_constrained_decoding=True,
        )

        if not result.success:
            return _wrap_response(None, code=500, error=result.status_message)

        return _wrap_response({
            "caption": result.caption or caption,
            "lyrics": result.lyrics or lyrics,
            "bpm": result.bpm,
            "key_scale": result.keyscale,
            "time_signature": result.timesignature,
            "duration": result.duration,
            "vocal_language": result.language or "unknown",
        })
    except Exception as e:
        # Endpoint boundary: report any formatting failure in the response.
        return _wrap_response(None, code=500, error=str(e))
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
@router.post("/release_task")
async def release_task(request: Request, authorization: Optional[str] = Header(None)):
    """Create a music generation task and run it before responding.

    The body may be JSON or form data. Every parameter is looked up first in
    the body and then in the optional ``param_obj`` (an object or a
    JSON-encoded string), under its primary name and its aliases.

    Flow:
        1. With ``sample_mode`` or a non-empty ``sample_query``, the LLM
           generates caption, lyrics and metadata via ``create_sample``.
        2. Otherwise, with ``use_format`` and an initialized LLM, the given
           caption and lyrics are enhanced via ``format_sample``.
        3. ``generate_music`` runs and the resulting audio paths are stored
           under a fresh task id for ``/query_result``.

    Returns:
        A wrapped response with the ``task_id`` and status ``"succeeded"``.

    Raises:
        HTTPException: 400 for a malformed body; 500 when the DiT model is
            not initialized, when sample mode is requested without an
            initialized LLM, or when any generation step fails.
    """
    dit_handler = request.app.state.dit_handler
    llm_handler = request.app.state.llm_handler

    if not dit_handler or dit_handler.model is None:
        raise HTTPException(status_code=500, detail="DiT model not initialized")

    content_type = (request.headers.get("content-type") or "").lower()
    if "json" in content_type:
        try:
            body = await request.json()
        except ValueError as exc:
            # Malformed JSON would otherwise surface as an unhandled error.
            raise HTTPException(status_code=400, detail="Invalid JSON body") from exc
        if not isinstance(body, dict):
            raise HTTPException(status_code=400, detail="JSON body must be an object")
    else:
        form = await request.form()
        body = dict(form.items())

    verify_token_from_request(body, authorization)
    task_id = str(uuid4())

    from acestep.inference import (
        generate_music,
        GenerationParams,
        GenerationConfig,
        create_sample,
        format_sample,
    )

    # Parse param_obj if provided
    param_obj = body.get("param_obj", {})
    if isinstance(param_obj, str):
        try:
            param_obj = json.loads(param_obj)
        except (ValueError, RecursionError):
            param_obj = {}
    if not isinstance(param_obj, dict):
        # null, lists and scalars cannot be used for keyed lookups below.
        param_obj = {}

    def get_param(key, *aliases, default=None):
        """Return the first non-None value for key/aliases from body, then param_obj."""
        for k in (key, *aliases):
            if k in body and body[k] is not None:
                return body[k]
            if k in param_obj and param_obj[k] is not None:
                return param_obj[k]
        return default

    def to_bool(val, default=False):
        """Interpret booleans, "true"/"1"/"yes" strings and other truthy values."""
        if val is None:
            return default
        if isinstance(val, bool):
            return val
        if isinstance(val, str):
            return val.lower() in ("true", "1", "yes")
        return bool(val)

    try:
        # Get sample_mode and sample_query parameters
        sample_mode = to_bool(get_param("sample_mode", "sampleMode"), False)
        sample_query = get_param("sample_query", "sampleQuery", "description", "desc", default="") or ""
        use_format = to_bool(get_param("use_format", "useFormat"), False)
        has_sample_query = bool(sample_query and sample_query.strip())

        # Get base parameters
        caption = get_param("prompt", "caption", default="") or ""
        lyrics = get_param("lyrics", default="") or ""
        vocal_language = get_param("vocal_language", "language", default="en") or "en"
        lm_temperature = float(get_param("lm_temperature", "temperature", default=0.85) or 0.85)

        # Process sample_mode: use LLM to auto-generate caption/lyrics/metas
        if sample_mode or has_sample_query:
            if not llm_handler or not llm_handler.llm_initialized:
                raise HTTPException(status_code=500, detail="sample_mode requires LLM to be initialized")

            query = sample_query if has_sample_query else "NO USER INPUT"
            sample_result = create_sample(
                llm_handler=llm_handler,
                query=query,
                vocal_language=vocal_language if vocal_language not in ("en", "unknown", "") else None,
                temperature=lm_temperature,
            )

            if not sample_result.success:
                raise HTTPException(status_code=500, detail=sample_result.error or sample_result.status_message)

            # Use generated values, falling back to what the caller supplied
            caption = sample_result.caption or caption
            lyrics = sample_result.lyrics or lyrics
            # Override metas from sample result if available
            sample_bpm = sample_result.bpm
            sample_duration = sample_result.duration
            sample_keyscale = sample_result.keyscale
            sample_timesignature = sample_result.timesignature
            sample_language = sample_result.language or vocal_language
        else:
            sample_bpm = None
            sample_duration = None
            sample_keyscale = None
            sample_timesignature = None
            sample_language = vocal_language

        # Process use_format: enhance caption/lyrics via LLM
        if use_format and not sample_mode and not has_sample_query:
            if llm_handler and llm_handler.llm_initialized:
                format_result = format_sample(
                    llm_handler=llm_handler,
                    caption=caption,
                    lyrics=lyrics,
                    temperature=lm_temperature,
                )
                if format_result.success:
                    caption = format_result.caption or caption
                    lyrics = format_result.lyrics or lyrics
                    if format_result.bpm:
                        sample_bpm = format_result.bpm
                    if format_result.duration:
                        sample_duration = format_result.duration
                    if format_result.keyscale:
                        sample_keyscale = format_result.keyscale
                    if format_result.timesignature:
                        sample_timesignature = format_result.timesignature
                    if format_result.language:
                        sample_language = format_result.language

        # An explicit seed of 0 must be kept; only a missing or empty value
        # falls back to -1.
        raw_seed = get_param("seed")
        seed = -1 if raw_seed in (None, "") else int(raw_seed)

        # Build generation params with alias support. Form-encoded requests
        # deliver every value as a string, so the step count is coerced here.
        # NOTE(review): bpm and duration are still passed through as received;
        # confirm GenerationParams accepts strings for them.
        params = GenerationParams(
            task_type=get_param("task_type", default="text2music"),
            caption=caption,
            lyrics=lyrics,
            bpm=sample_bpm or get_param("bpm"),
            keyscale=sample_keyscale or get_param("key_scale", "keyscale", "key", default=""),
            timesignature=sample_timesignature or get_param("time_signature", "timesignature", default=""),
            duration=sample_duration or get_param("audio_duration", "duration", default=-1),
            vocal_language=sample_language,
            inference_steps=int(get_param("inference_steps", default=8) or 8),
            guidance_scale=float(get_param("guidance_scale", default=7.0) or 7.0),
            seed=seed,
            thinking=to_bool(get_param("thinking"), False),
            lm_temperature=lm_temperature,
            lm_cfg_scale=float(get_param("lm_cfg_scale", default=2.0) or 2.0),
            lm_negative_prompt=get_param("lm_negative_prompt", default="NO USER INPUT") or "NO USER INPUT",
        )

        config = GenerationConfig(
            batch_size=int(get_param("batch_size", default=2) or 2),
            # to_bool so that a form value such as "false" is not truthy.
            use_random_seed=to_bool(get_param("use_random_seed"), True),
            audio_format=get_param("audio_format", default="mp3"),
        )

        # Get output directory. The task id suffix keeps requests that arrive
        # within the same second from sharing a directory.
        save_dir = os.path.join(
            DEFAULT_RESULTS_DIR, f"api_{int(time.time())}_{task_id[:8]}"
        ).replace("\\", "/")
        os.makedirs(save_dir, exist_ok=True)

        # Call generation function
        # NOTE(review): generate_music is called directly inside this
        # coroutine; if it blocks, other requests wait until it returns.
        result = generate_music(
            dit_handler=dit_handler,
            llm_handler=llm_handler if llm_handler and llm_handler.llm_initialized else None,
            params=params,
            config=config,
            save_dir=save_dir,
        )

        if not result.success:
            raise HTTPException(status_code=500, detail=result.error or result.status_message)

        # Extract audio paths
        audio_paths = [a["path"] for a in result.audios if a.get("path")]

        # Build result data with download URLs
        from urllib.parse import urlencode
        result_data = [{
            "file": p,
            "url": f"/v1/audio?{urlencode({'path': p})}",
            "status": 1,
            "create_time": int(time.time()),
        } for p in audio_paths]

        # Store result
        store_result(task_id, result_data)

        return _wrap_response({"task_id": task_id, "status": "succeeded"})

    except HTTPException:
        # Already carries the intended status code and detail.
        raise
    except Exception as e:
        # Endpoint boundary: report any other failure as a 500 with its message.
        raise HTTPException(status_code=500, detail=str(e))
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
def setup_api_routes_to_app(app, dit_handler, llm_handler, api_key: Optional[str] = None):
    """
    Mount API routes to a FastAPI application (for use with gr.mount_gradio_app)

    Registers the API key, stores both handlers on ``app.state`` (where the
    endpoints read them from) and includes this module's router.

    Args:
        app: FastAPI application instance
        dit_handler: DiT handler
        llm_handler: LLM handler
        api_key: Optional API key for authentication
    """
    set_api_key(api_key)
    app.state.dit_handler = dit_handler
    app.state.llm_handler = llm_handler
    app.include_router(router)
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def setup_api_routes(demo, dit_handler, llm_handler, api_key: Optional[str] = None):
    """
    Mount API routes to Gradio application

    Performs the same setup as ``setup_api_routes_to_app`` on ``demo.app``
    and additionally registers a ``/info`` endpoint that answers with a
    minimal JSON body when API schema generation fails.

    Args:
        demo: Gradio Blocks instance
        dit_handler: DiT handler
        llm_handler: LLM handler
        api_key: Optional API key for authentication
    """
    app = demo.app
    # Shared setup: API key, handlers on app.state, router inclusion.
    setup_api_routes_to_app(app, dit_handler, llm_handler, api_key=api_key)

    # Override the /info endpoint to handle schema generation errors gracefully
    # NOTE(review): routes are matched in registration order, so if demo.app
    # already has a /info route this later registration may be shadowed;
    # confirm the override actually takes effect.
    from fastapi.responses import JSONResponse

    @app.get("/info")
    async def custom_api_info():
        """Custom API info endpoint with error handling for schema generation issues"""
        try:
            # Try to get the original API info
            from gradio import utils
            api_info = utils.safe_deepcopy(demo.get_api_info())
            return JSONResponse(content=api_info)
        except (TypeError, AttributeError, KeyError) as e:
            # If schema generation fails, return a minimal response
            return JSONResponse(content={
                "error": "API schema generation not available",
                "message": "Custom API routes are available at /health, /v1/models, /release_task, /query_result, /create_random_sample, /format_input",
                "detail": str(e)
            })
|
acestep/gradio_ui/events/__init__.py
ADDED
|
@@ -0,0 +1,1254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI Event Handlers Module
|
| 3 |
+
Main entry point for setting up all event handlers
|
| 4 |
+
"""
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from loguru import logger
|
| 8 |
+
|
| 9 |
+
# Import handler modules
|
| 10 |
+
from . import generation_handlers as gen_h
|
| 11 |
+
from . import results_handlers as res_h
|
| 12 |
+
from . import training_handlers as train_h
|
| 13 |
+
from acestep.gradio_ui.i18n import t
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section):
|
| 17 |
+
"""Setup event handlers connecting UI components and business logic"""
|
| 18 |
+
|
| 19 |
+
# ========== Dataset Handlers ==========
|
| 20 |
+
dataset_section["import_dataset_btn"].click(
|
| 21 |
+
fn=dataset_handler.import_dataset,
|
| 22 |
+
inputs=[dataset_section["dataset_type"]],
|
| 23 |
+
outputs=[dataset_section["data_status"]]
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
# ========== Service Initialization ==========
|
| 27 |
+
generation_section["refresh_btn"].click(
|
| 28 |
+
fn=lambda: gen_h.refresh_checkpoints(dit_handler),
|
| 29 |
+
outputs=[generation_section["checkpoint_dropdown"]]
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
generation_section["config_path"].change(
|
| 33 |
+
fn=gen_h.update_model_type_settings,
|
| 34 |
+
inputs=[generation_section["config_path"]],
|
| 35 |
+
outputs=[
|
| 36 |
+
generation_section["inference_steps"],
|
| 37 |
+
generation_section["guidance_scale"],
|
| 38 |
+
generation_section["use_adg"],
|
| 39 |
+
generation_section["shift"],
|
| 40 |
+
generation_section["cfg_interval_start"],
|
| 41 |
+
generation_section["cfg_interval_end"],
|
| 42 |
+
generation_section["task_type"],
|
| 43 |
+
]
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
generation_section["init_btn"].click(
|
| 47 |
+
fn=lambda *args: gen_h.init_service_wrapper(dit_handler, llm_handler, *args),
|
| 48 |
+
inputs=[
|
| 49 |
+
generation_section["checkpoint_dropdown"],
|
| 50 |
+
generation_section["config_path"],
|
| 51 |
+
generation_section["device"],
|
| 52 |
+
generation_section["init_llm_checkbox"],
|
| 53 |
+
generation_section["lm_model_path"],
|
| 54 |
+
generation_section["backend_dropdown"],
|
| 55 |
+
generation_section["use_flash_attention_checkbox"],
|
| 56 |
+
generation_section["offload_to_cpu_checkbox"],
|
| 57 |
+
generation_section["offload_dit_to_cpu_checkbox"],
|
| 58 |
+
generation_section["compile_model_checkbox"],
|
| 59 |
+
generation_section["quantization_checkbox"],
|
| 60 |
+
],
|
| 61 |
+
outputs=[
|
| 62 |
+
generation_section["init_status"],
|
| 63 |
+
generation_section["generate_btn"],
|
| 64 |
+
generation_section["service_config_accordion"],
|
| 65 |
+
# Model type settings (updated based on actual loaded model)
|
| 66 |
+
generation_section["inference_steps"],
|
| 67 |
+
generation_section["guidance_scale"],
|
| 68 |
+
generation_section["use_adg"],
|
| 69 |
+
generation_section["shift"],
|
| 70 |
+
generation_section["cfg_interval_start"],
|
| 71 |
+
generation_section["cfg_interval_end"],
|
| 72 |
+
generation_section["task_type"],
|
| 73 |
+
]
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# ========== LoRA Handlers ==========
|
| 77 |
+
generation_section["load_lora_btn"].click(
|
| 78 |
+
fn=dit_handler.load_lora,
|
| 79 |
+
inputs=[generation_section["lora_path"]],
|
| 80 |
+
outputs=[generation_section["lora_status"]]
|
| 81 |
+
).then(
|
| 82 |
+
# Update checkbox to enabled state after loading
|
| 83 |
+
fn=lambda: gr.update(value=True),
|
| 84 |
+
outputs=[generation_section["use_lora_checkbox"]]
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
generation_section["unload_lora_btn"].click(
|
| 88 |
+
fn=dit_handler.unload_lora,
|
| 89 |
+
outputs=[generation_section["lora_status"]]
|
| 90 |
+
).then(
|
| 91 |
+
# Update checkbox to disabled state after unloading
|
| 92 |
+
fn=lambda: gr.update(value=False),
|
| 93 |
+
outputs=[generation_section["use_lora_checkbox"]]
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
generation_section["use_lora_checkbox"].change(
|
| 97 |
+
fn=dit_handler.set_use_lora,
|
| 98 |
+
inputs=[generation_section["use_lora_checkbox"]],
|
| 99 |
+
outputs=[generation_section["lora_status"]]
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
generation_section["lora_scale_slider"].change(
|
| 103 |
+
fn=dit_handler.set_lora_scale,
|
| 104 |
+
inputs=[generation_section["lora_scale_slider"]],
|
| 105 |
+
outputs=[generation_section["lora_status"]]
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
# ========== UI Visibility Updates ==========
|
| 109 |
+
generation_section["init_llm_checkbox"].change(
|
| 110 |
+
fn=gen_h.update_negative_prompt_visibility,
|
| 111 |
+
inputs=[generation_section["init_llm_checkbox"]],
|
| 112 |
+
outputs=[generation_section["lm_negative_prompt"]]
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
generation_section["init_llm_checkbox"].change(
|
| 116 |
+
fn=gen_h.update_audio_cover_strength_visibility,
|
| 117 |
+
inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"], generation_section["reference_audio"]],
|
| 118 |
+
outputs=[generation_section["audio_cover_strength"]]
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
generation_section["task_type"].change(
|
| 122 |
+
fn=gen_h.update_audio_cover_strength_visibility,
|
| 123 |
+
inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"], generation_section["reference_audio"]],
|
| 124 |
+
outputs=[generation_section["audio_cover_strength"]]
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
generation_section["reference_audio"].change(
|
| 128 |
+
fn=gen_h.update_audio_cover_strength_visibility,
|
| 129 |
+
inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"], generation_section["reference_audio"]],
|
| 130 |
+
outputs=[generation_section["audio_cover_strength"]]
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
generation_section["batch_size_input"].change(
|
| 134 |
+
fn=gen_h.update_audio_components_visibility,
|
| 135 |
+
inputs=[generation_section["batch_size_input"]],
|
| 136 |
+
outputs=[
|
| 137 |
+
results_section["audio_col_1"],
|
| 138 |
+
results_section["audio_col_2"],
|
| 139 |
+
results_section["audio_col_3"],
|
| 140 |
+
results_section["audio_col_4"],
|
| 141 |
+
results_section["audio_row_5_8"],
|
| 142 |
+
results_section["audio_col_5"],
|
| 143 |
+
results_section["audio_col_6"],
|
| 144 |
+
results_section["audio_col_7"],
|
| 145 |
+
results_section["audio_col_8"],
|
| 146 |
+
]
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
# ========== Audio Conversion ==========
|
| 150 |
+
generation_section["convert_src_to_codes_btn"].click(
|
| 151 |
+
fn=lambda src: gen_h.convert_src_audio_to_codes_wrapper(dit_handler, src),
|
| 152 |
+
inputs=[generation_section["src_audio"]],
|
| 153 |
+
outputs=[generation_section["text2music_audio_code_string"]]
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
# ========== Instruction UI Updates ==========
|
| 157 |
+
for trigger in [generation_section["task_type"], generation_section["track_name"], generation_section["complete_track_classes"], generation_section["reference_audio"]]:
|
| 158 |
+
trigger.change(
|
| 159 |
+
fn=lambda *args: gen_h.update_instruction_ui(dit_handler, *args),
|
| 160 |
+
inputs=[
|
| 161 |
+
generation_section["task_type"],
|
| 162 |
+
generation_section["track_name"],
|
| 163 |
+
generation_section["complete_track_classes"],
|
| 164 |
+
generation_section["text2music_audio_code_string"],
|
| 165 |
+
generation_section["init_llm_checkbox"],
|
| 166 |
+
generation_section["reference_audio"],
|
| 167 |
+
],
|
| 168 |
+
outputs=[
|
| 169 |
+
generation_section["instruction_display_gen"],
|
| 170 |
+
generation_section["track_name"],
|
| 171 |
+
generation_section["complete_track_classes"],
|
| 172 |
+
generation_section["audio_cover_strength"],
|
| 173 |
+
generation_section["repainting_group"],
|
| 174 |
+
generation_section["text2music_audio_codes_group"],
|
| 175 |
+
]
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
# ========== Sample/Transcribe Handlers ==========
|
| 179 |
+
# Load random example from ./examples/text2music directory
|
| 180 |
+
generation_section["sample_btn"].click(
|
| 181 |
+
fn=lambda task: gen_h.load_random_example(task, llm_handler) + (True,),
|
| 182 |
+
inputs=[
|
| 183 |
+
generation_section["task_type"],
|
| 184 |
+
],
|
| 185 |
+
outputs=[
|
| 186 |
+
generation_section["captions"],
|
| 187 |
+
generation_section["lyrics"],
|
| 188 |
+
generation_section["think_checkbox"],
|
| 189 |
+
generation_section["bpm"],
|
| 190 |
+
generation_section["audio_duration"],
|
| 191 |
+
generation_section["key_scale"],
|
| 192 |
+
generation_section["vocal_language"],
|
| 193 |
+
generation_section["time_signature"],
|
| 194 |
+
results_section["is_format_caption_state"]
|
| 195 |
+
]
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
generation_section["text2music_audio_code_string"].change(
|
| 199 |
+
fn=gen_h.update_transcribe_button_text,
|
| 200 |
+
inputs=[generation_section["text2music_audio_code_string"]],
|
| 201 |
+
outputs=[generation_section["transcribe_btn"]]
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
generation_section["transcribe_btn"].click(
|
| 205 |
+
fn=lambda codes, debug: gen_h.transcribe_audio_codes(llm_handler, codes, debug),
|
| 206 |
+
inputs=[
|
| 207 |
+
generation_section["text2music_audio_code_string"],
|
| 208 |
+
generation_section["constrained_decoding_debug"]
|
| 209 |
+
],
|
| 210 |
+
outputs=[
|
| 211 |
+
results_section["status_output"],
|
| 212 |
+
generation_section["captions"],
|
| 213 |
+
generation_section["lyrics"],
|
| 214 |
+
generation_section["bpm"],
|
| 215 |
+
generation_section["audio_duration"],
|
| 216 |
+
generation_section["key_scale"],
|
| 217 |
+
generation_section["vocal_language"],
|
| 218 |
+
generation_section["time_signature"],
|
| 219 |
+
results_section["is_format_caption_state"]
|
| 220 |
+
]
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# ========== Reset Format Caption Flag ==========
|
| 224 |
+
for trigger in [generation_section["captions"], generation_section["lyrics"], generation_section["bpm"],
|
| 225 |
+
generation_section["key_scale"], generation_section["time_signature"],
|
| 226 |
+
generation_section["vocal_language"], generation_section["audio_duration"]]:
|
| 227 |
+
trigger.change(
|
| 228 |
+
fn=gen_h.reset_format_caption_flag,
|
| 229 |
+
inputs=[],
|
| 230 |
+
outputs=[results_section["is_format_caption_state"]]
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
# ========== Audio Uploads Accordion ==========
|
| 234 |
+
for trigger in [generation_section["reference_audio"], generation_section["src_audio"]]:
|
| 235 |
+
trigger.change(
|
| 236 |
+
fn=gen_h.update_audio_uploads_accordion,
|
| 237 |
+
inputs=[generation_section["reference_audio"], generation_section["src_audio"]],
|
| 238 |
+
outputs=[generation_section["audio_uploads_accordion"]]
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
# ========== Instrumental Checkbox ==========
|
| 242 |
+
generation_section["instrumental_checkbox"].change(
|
| 243 |
+
fn=gen_h.handle_instrumental_checkbox,
|
| 244 |
+
inputs=[generation_section["instrumental_checkbox"], generation_section["lyrics"]],
|
| 245 |
+
outputs=[generation_section["lyrics"]]
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# ========== Format Button ==========
|
| 249 |
+
# Note: cfg_scale and negative_prompt are not supported in format mode
|
| 250 |
+
generation_section["format_btn"].click(
|
| 251 |
+
fn=lambda caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug: gen_h.handle_format_sample(
|
| 252 |
+
llm_handler, caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug
|
| 253 |
+
),
|
| 254 |
+
inputs=[
|
| 255 |
+
generation_section["captions"],
|
| 256 |
+
generation_section["lyrics"],
|
| 257 |
+
generation_section["bpm"],
|
| 258 |
+
generation_section["audio_duration"],
|
| 259 |
+
generation_section["key_scale"],
|
| 260 |
+
generation_section["time_signature"],
|
| 261 |
+
generation_section["lm_temperature"],
|
| 262 |
+
generation_section["lm_top_k"],
|
| 263 |
+
generation_section["lm_top_p"],
|
| 264 |
+
generation_section["constrained_decoding_debug"],
|
| 265 |
+
],
|
| 266 |
+
outputs=[
|
| 267 |
+
generation_section["captions"],
|
| 268 |
+
generation_section["lyrics"],
|
| 269 |
+
generation_section["bpm"],
|
| 270 |
+
generation_section["audio_duration"],
|
| 271 |
+
generation_section["key_scale"],
|
| 272 |
+
generation_section["vocal_language"],
|
| 273 |
+
generation_section["time_signature"],
|
| 274 |
+
results_section["is_format_caption_state"],
|
| 275 |
+
results_section["status_output"],
|
| 276 |
+
]
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
# ========== Simple/Custom Mode Toggle ==========
|
| 280 |
+
generation_section["generation_mode"].change(
|
| 281 |
+
fn=gen_h.handle_generation_mode_change,
|
| 282 |
+
inputs=[generation_section["generation_mode"]],
|
| 283 |
+
outputs=[
|
| 284 |
+
generation_section["simple_mode_group"],
|
| 285 |
+
generation_section["caption_accordion"],
|
| 286 |
+
generation_section["lyrics_accordion"],
|
| 287 |
+
generation_section["generate_btn"],
|
| 288 |
+
generation_section["simple_sample_created"],
|
| 289 |
+
generation_section["optional_params_accordion"],
|
| 290 |
+
]
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
# ========== Simple Mode Instrumental Checkbox ==========
|
| 294 |
+
# When instrumental is checked, disable vocal language and set to ["unknown"]
|
| 295 |
+
generation_section["simple_instrumental_checkbox"].change(
|
| 296 |
+
fn=gen_h.handle_simple_instrumental_change,
|
| 297 |
+
inputs=[generation_section["simple_instrumental_checkbox"]],
|
| 298 |
+
outputs=[generation_section["simple_vocal_language"]]
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
# ========== Random Description Button ==========
|
| 302 |
+
generation_section["random_desc_btn"].click(
|
| 303 |
+
fn=gen_h.load_random_simple_description,
|
| 304 |
+
inputs=[],
|
| 305 |
+
outputs=[
|
| 306 |
+
generation_section["simple_query_input"],
|
| 307 |
+
generation_section["simple_instrumental_checkbox"],
|
| 308 |
+
generation_section["simple_vocal_language"],
|
| 309 |
+
]
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# ========== Create Sample Button (Simple Mode) ==========
|
| 313 |
+
# Note: cfg_scale and negative_prompt are not supported in create_sample mode
|
| 314 |
+
generation_section["create_sample_btn"].click(
|
| 315 |
+
fn=lambda query, instrumental, vocal_lang, temp, top_k, top_p, debug: gen_h.handle_create_sample(
|
| 316 |
+
llm_handler, query, instrumental, vocal_lang, temp, top_k, top_p, debug
|
| 317 |
+
),
|
| 318 |
+
inputs=[
|
| 319 |
+
generation_section["simple_query_input"],
|
| 320 |
+
generation_section["simple_instrumental_checkbox"],
|
| 321 |
+
generation_section["simple_vocal_language"],
|
| 322 |
+
generation_section["lm_temperature"],
|
| 323 |
+
generation_section["lm_top_k"],
|
| 324 |
+
generation_section["lm_top_p"],
|
| 325 |
+
generation_section["constrained_decoding_debug"],
|
| 326 |
+
],
|
| 327 |
+
outputs=[
|
| 328 |
+
generation_section["captions"],
|
| 329 |
+
generation_section["lyrics"],
|
| 330 |
+
generation_section["bpm"],
|
| 331 |
+
generation_section["audio_duration"],
|
| 332 |
+
generation_section["key_scale"],
|
| 333 |
+
generation_section["vocal_language"],
|
| 334 |
+
generation_section["simple_vocal_language"],
|
| 335 |
+
generation_section["time_signature"],
|
| 336 |
+
generation_section["instrumental_checkbox"],
|
| 337 |
+
generation_section["caption_accordion"],
|
| 338 |
+
generation_section["lyrics_accordion"],
|
| 339 |
+
generation_section["generate_btn"],
|
| 340 |
+
generation_section["simple_sample_created"],
|
| 341 |
+
generation_section["think_checkbox"],
|
| 342 |
+
results_section["is_format_caption_state"],
|
| 343 |
+
results_section["status_output"],
|
| 344 |
+
]
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
# ========== Load/Save Metadata ==========
|
| 348 |
+
generation_section["load_file"].upload(
|
| 349 |
+
fn=lambda file_obj: gen_h.load_metadata(file_obj, llm_handler),
|
| 350 |
+
inputs=[generation_section["load_file"]],
|
| 351 |
+
outputs=[
|
| 352 |
+
generation_section["task_type"],
|
| 353 |
+
generation_section["captions"],
|
| 354 |
+
generation_section["lyrics"],
|
| 355 |
+
generation_section["vocal_language"],
|
| 356 |
+
generation_section["bpm"],
|
| 357 |
+
generation_section["key_scale"],
|
| 358 |
+
generation_section["time_signature"],
|
| 359 |
+
generation_section["audio_duration"],
|
| 360 |
+
generation_section["batch_size_input"],
|
| 361 |
+
generation_section["inference_steps"],
|
| 362 |
+
generation_section["guidance_scale"],
|
| 363 |
+
generation_section["seed"],
|
| 364 |
+
generation_section["random_seed_checkbox"],
|
| 365 |
+
generation_section["use_adg"],
|
| 366 |
+
generation_section["cfg_interval_start"],
|
| 367 |
+
generation_section["cfg_interval_end"],
|
| 368 |
+
generation_section["shift"],
|
| 369 |
+
generation_section["infer_method"],
|
| 370 |
+
generation_section["custom_timesteps"],
|
| 371 |
+
generation_section["audio_format"],
|
| 372 |
+
generation_section["lm_temperature"],
|
| 373 |
+
generation_section["lm_cfg_scale"],
|
| 374 |
+
generation_section["lm_top_k"],
|
| 375 |
+
generation_section["lm_top_p"],
|
| 376 |
+
generation_section["lm_negative_prompt"],
|
| 377 |
+
generation_section["use_cot_metas"], # Added: use_cot_metas
|
| 378 |
+
generation_section["use_cot_caption"],
|
| 379 |
+
generation_section["use_cot_language"],
|
| 380 |
+
generation_section["audio_cover_strength"],
|
| 381 |
+
generation_section["think_checkbox"],
|
| 382 |
+
generation_section["text2music_audio_code_string"],
|
| 383 |
+
generation_section["repainting_start"],
|
| 384 |
+
generation_section["repainting_end"],
|
| 385 |
+
generation_section["track_name"],
|
| 386 |
+
generation_section["complete_track_classes"],
|
| 387 |
+
generation_section["instrumental_checkbox"], # Added: instrumental_checkbox
|
| 388 |
+
results_section["is_format_caption_state"]
|
| 389 |
+
]
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
# Save buttons for all 8 audio outputs
|
| 393 |
+
download_existing_js = """(current_audio, batch_files) => {
|
| 394 |
+
// Debug: print what the input actually is
|
| 395 |
+
console.log("👉 [Debug] Current Audio Input:", current_audio);
|
| 396 |
+
|
| 397 |
+
// 1. Safety check
|
| 398 |
+
if (!current_audio) {
|
| 399 |
+
console.warn("⚠️ No audio selected or audio is empty.");
|
| 400 |
+
return;
|
| 401 |
+
}
|
| 402 |
+
if (!batch_files || !Array.isArray(batch_files)) {
|
| 403 |
+
console.warn("⚠️ Batch file list is empty/not ready.");
|
| 404 |
+
return;
|
| 405 |
+
}
|
| 406 |
+
|
| 407 |
+
// 2. Smartly extract path string
|
| 408 |
+
let pathString = "";
|
| 409 |
+
|
| 410 |
+
if (typeof current_audio === "string") {
|
| 411 |
+
// Case A: direct path string received
|
| 412 |
+
pathString = current_audio;
|
| 413 |
+
} else if (typeof current_audio === "object") {
|
| 414 |
+
// Case B: an object is received, try common properties
|
| 415 |
+
// Gradio file objects usually have path, url, or name
|
| 416 |
+
pathString = current_audio.path || current_audio.name || current_audio.url || "";
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
if (!pathString) {
|
| 420 |
+
console.error("❌ Error: Could not extract a valid path string from input.", current_audio);
|
| 421 |
+
return;
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
// 3. Extract Key (UUID)
|
| 425 |
+
// Path could be /tmp/.../uuid.mp3 or url like /file=.../uuid.mp3
|
| 426 |
+
let filename = pathString.split(/[\\\\/]/).pop(); // get the filename
|
| 427 |
+
let key = filename.split('.')[0]; // get UUID without extension
|
| 428 |
+
|
| 429 |
+
console.log(`🔑 Key extracted: ${key}`);
|
| 430 |
+
|
| 431 |
+
// 4. Find matching file(s) in the list
|
| 432 |
+
let targets = batch_files.filter(f => {
|
| 433 |
+
// Also extract names from batch_files objects
|
| 434 |
+
// f usually contains name (backend path) and orig_name (download name)
|
| 435 |
+
const fPath = f.name || f.path || "";
|
| 436 |
+
return fPath.includes(key);
|
| 437 |
+
});
|
| 438 |
+
|
| 439 |
+
if (targets.length === 0) {
|
| 440 |
+
console.warn("❌ No matching files found in batch list for key:", key);
|
| 441 |
+
alert("Batch list does not contain this file yet. Please wait for generation to finish.");
|
| 442 |
+
return;
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
// 5. Trigger download(s)
|
| 446 |
+
console.log(`🎯 Found ${targets.length} files to download.`);
|
| 447 |
+
targets.forEach((f, index) => {
|
| 448 |
+
setTimeout(() => {
|
| 449 |
+
const a = document.createElement('a');
|
| 450 |
+
// Prefer url (frontend-accessible link), otherwise try data
|
| 451 |
+
a.href = f.url || f.data;
|
| 452 |
+
a.download = f.orig_name || "download";
|
| 453 |
+
a.style.display = 'none';
|
| 454 |
+
document.body.appendChild(a);
|
| 455 |
+
a.click();
|
| 456 |
+
document.body.removeChild(a);
|
| 457 |
+
}, index * 1000); // 1000ms (1s) interval between downloads to avoid browser blocking
|
| 458 |
+
});
|
| 459 |
+
}
|
| 460 |
+
"""
|
| 461 |
+
for btn_idx in range(1, 9):
|
| 462 |
+
results_section[f"save_btn_{btn_idx}"].click(
|
| 463 |
+
fn=None,
|
| 464 |
+
inputs=[
|
| 465 |
+
results_section[f"generated_audio_{btn_idx}"],
|
| 466 |
+
results_section["generated_audio_batch"],
|
| 467 |
+
],
|
| 468 |
+
js=download_existing_js # Run the above JS
|
| 469 |
+
)
|
| 470 |
+
# ========== Send to SRC Handlers ==========
|
| 471 |
+
for btn_idx in range(1, 9):
|
| 472 |
+
results_section[f"send_to_src_btn_{btn_idx}"].click(
|
| 473 |
+
fn=res_h.send_audio_to_src_with_metadata,
|
| 474 |
+
inputs=[
|
| 475 |
+
results_section[f"generated_audio_{btn_idx}"],
|
| 476 |
+
results_section["lm_metadata_state"]
|
| 477 |
+
],
|
| 478 |
+
outputs=[
|
| 479 |
+
generation_section["src_audio"],
|
| 480 |
+
generation_section["bpm"],
|
| 481 |
+
generation_section["captions"],
|
| 482 |
+
generation_section["lyrics"],
|
| 483 |
+
generation_section["audio_duration"],
|
| 484 |
+
generation_section["key_scale"],
|
| 485 |
+
generation_section["vocal_language"],
|
| 486 |
+
generation_section["time_signature"],
|
| 487 |
+
results_section["is_format_caption_state"]
|
| 488 |
+
]
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
# ========== Score Calculation Handlers ==========
|
| 492 |
+
# Use a factory function so each button's handler binds its own btn_idx (avoids the late-binding closure pitfall in the loop)
|
| 493 |
+
def make_score_handler(idx):
|
| 494 |
+
return lambda scale, batch_idx, queue: res_h.calculate_score_handler_with_selection(
|
| 495 |
+
dit_handler, llm_handler, idx, scale, batch_idx, queue
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
+
for btn_idx in range(1, 9):
|
| 499 |
+
results_section[f"score_btn_{btn_idx}"].click(
|
| 500 |
+
fn=make_score_handler(btn_idx),
|
| 501 |
+
inputs=[
|
| 502 |
+
generation_section["score_scale"],
|
| 503 |
+
results_section["current_batch_index"],
|
| 504 |
+
results_section["batch_queue"],
|
| 505 |
+
],
|
| 506 |
+
outputs=[
|
| 507 |
+
results_section[f"score_display_{btn_idx}"],
|
| 508 |
+
results_section[f"details_accordion_{btn_idx}"],
|
| 509 |
+
results_section["batch_queue"]
|
| 510 |
+
]
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
# ========== LRC Timestamp Handlers ==========
|
| 514 |
+
# Use a factory function so each button's handler binds its own btn_idx (avoids the late-binding closure pitfall in the loop)
|
| 515 |
+
def make_lrc_handler(idx):
|
| 516 |
+
return lambda batch_idx, queue, vocal_lang, infer_steps: res_h.generate_lrc_handler(
|
| 517 |
+
dit_handler, idx, batch_idx, queue, vocal_lang, infer_steps
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
for btn_idx in range(1, 9):
|
| 521 |
+
results_section[f"lrc_btn_{btn_idx}"].click(
|
| 522 |
+
fn=make_lrc_handler(btn_idx),
|
| 523 |
+
inputs=[
|
| 524 |
+
results_section["current_batch_index"],
|
| 525 |
+
results_section["batch_queue"],
|
| 526 |
+
generation_section["vocal_language"],
|
| 527 |
+
generation_section["inference_steps"],
|
| 528 |
+
],
|
| 529 |
+
outputs=[
|
| 530 |
+
results_section[f"lrc_display_{btn_idx}"],
|
| 531 |
+
results_section[f"details_accordion_{btn_idx}"],
|
| 532 |
+
# NOTE: Removed generated_audio output!
|
| 533 |
+
# Audio subtitles are now updated via lrc_display.change() event.
|
| 534 |
+
results_section["batch_queue"]
|
| 535 |
+
]
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
def generation_wrapper(*args):
    """Stream results of ``res_h.generate_with_batch_management`` to the caller.

    The two model handlers from the enclosing scope are prepended to the
    positional ``args`` supplied by the event; every item the underlying
    generator produces is yielded unchanged.
    """
    batch_stream = res_h.generate_with_batch_management(
        dit_handler, llm_handler, *args
    )
    yield from batch_stream
|
| 540 |
+
# ========== Generation Handler ==========
|
| 541 |
+
generation_section["generate_btn"].click(
|
| 542 |
+
fn=res_h.clear_audio_outputs_for_new_generation,
|
| 543 |
+
outputs=[
|
| 544 |
+
results_section["generated_audio_1"],
|
| 545 |
+
results_section["generated_audio_2"],
|
| 546 |
+
results_section["generated_audio_3"],
|
| 547 |
+
results_section["generated_audio_4"],
|
| 548 |
+
results_section["generated_audio_5"],
|
| 549 |
+
results_section["generated_audio_6"],
|
| 550 |
+
results_section["generated_audio_7"],
|
| 551 |
+
results_section["generated_audio_8"],
|
| 552 |
+
results_section["generated_audio_batch"],
|
| 553 |
+
],
|
| 554 |
+
).then(
|
| 555 |
+
fn=generation_wrapper,
|
| 556 |
+
inputs=[
|
| 557 |
+
generation_section["captions"],
|
| 558 |
+
generation_section["lyrics"],
|
| 559 |
+
generation_section["bpm"],
|
| 560 |
+
generation_section["key_scale"],
|
| 561 |
+
generation_section["time_signature"],
|
| 562 |
+
generation_section["vocal_language"],
|
| 563 |
+
generation_section["inference_steps"],
|
| 564 |
+
generation_section["guidance_scale"],
|
| 565 |
+
generation_section["random_seed_checkbox"],
|
| 566 |
+
generation_section["seed"],
|
| 567 |
+
generation_section["reference_audio"],
|
| 568 |
+
generation_section["audio_duration"],
|
| 569 |
+
generation_section["batch_size_input"],
|
| 570 |
+
generation_section["src_audio"],
|
| 571 |
+
generation_section["text2music_audio_code_string"],
|
| 572 |
+
generation_section["repainting_start"],
|
| 573 |
+
generation_section["repainting_end"],
|
| 574 |
+
generation_section["instruction_display_gen"],
|
| 575 |
+
generation_section["audio_cover_strength"],
|
| 576 |
+
generation_section["task_type"],
|
| 577 |
+
generation_section["use_adg"],
|
| 578 |
+
generation_section["cfg_interval_start"],
|
| 579 |
+
generation_section["cfg_interval_end"],
|
| 580 |
+
generation_section["shift"],
|
| 581 |
+
generation_section["infer_method"],
|
| 582 |
+
generation_section["custom_timesteps"],
|
| 583 |
+
generation_section["audio_format"],
|
| 584 |
+
generation_section["lm_temperature"],
|
| 585 |
+
generation_section["think_checkbox"],
|
| 586 |
+
generation_section["lm_cfg_scale"],
|
| 587 |
+
generation_section["lm_top_k"],
|
| 588 |
+
generation_section["lm_top_p"],
|
| 589 |
+
generation_section["lm_negative_prompt"],
|
| 590 |
+
generation_section["use_cot_metas"],
|
| 591 |
+
generation_section["use_cot_caption"],
|
| 592 |
+
generation_section["use_cot_language"],
|
| 593 |
+
results_section["is_format_caption_state"],
|
| 594 |
+
generation_section["constrained_decoding_debug"],
|
| 595 |
+
generation_section["allow_lm_batch"],
|
| 596 |
+
generation_section["auto_score"],
|
| 597 |
+
generation_section["auto_lrc"],
|
| 598 |
+
generation_section["score_scale"],
|
| 599 |
+
generation_section["lm_batch_chunk_size"],
|
| 600 |
+
generation_section["track_name"],
|
| 601 |
+
generation_section["complete_track_classes"],
|
| 602 |
+
generation_section["autogen_checkbox"],
|
| 603 |
+
results_section["current_batch_index"],
|
| 604 |
+
results_section["total_batches"],
|
| 605 |
+
results_section["batch_queue"],
|
| 606 |
+
results_section["generation_params_state"],
|
| 607 |
+
],
|
| 608 |
+
outputs=[
|
| 609 |
+
results_section["generated_audio_1"],
|
| 610 |
+
results_section["generated_audio_2"],
|
| 611 |
+
results_section["generated_audio_3"],
|
| 612 |
+
results_section["generated_audio_4"],
|
| 613 |
+
results_section["generated_audio_5"],
|
| 614 |
+
results_section["generated_audio_6"],
|
| 615 |
+
results_section["generated_audio_7"],
|
| 616 |
+
results_section["generated_audio_8"],
|
| 617 |
+
results_section["generated_audio_batch"],
|
| 618 |
+
results_section["generation_info"],
|
| 619 |
+
results_section["status_output"],
|
| 620 |
+
generation_section["seed"],
|
| 621 |
+
results_section["score_display_1"],
|
| 622 |
+
results_section["score_display_2"],
|
| 623 |
+
results_section["score_display_3"],
|
| 624 |
+
results_section["score_display_4"],
|
| 625 |
+
results_section["score_display_5"],
|
| 626 |
+
results_section["score_display_6"],
|
| 627 |
+
results_section["score_display_7"],
|
| 628 |
+
results_section["score_display_8"],
|
| 629 |
+
results_section["codes_display_1"],
|
| 630 |
+
results_section["codes_display_2"],
|
| 631 |
+
results_section["codes_display_3"],
|
| 632 |
+
results_section["codes_display_4"],
|
| 633 |
+
results_section["codes_display_5"],
|
| 634 |
+
results_section["codes_display_6"],
|
| 635 |
+
results_section["codes_display_7"],
|
| 636 |
+
results_section["codes_display_8"],
|
| 637 |
+
results_section["details_accordion_1"],
|
| 638 |
+
results_section["details_accordion_2"],
|
| 639 |
+
results_section["details_accordion_3"],
|
| 640 |
+
results_section["details_accordion_4"],
|
| 641 |
+
results_section["details_accordion_5"],
|
| 642 |
+
results_section["details_accordion_6"],
|
| 643 |
+
results_section["details_accordion_7"],
|
| 644 |
+
results_section["details_accordion_8"],
|
| 645 |
+
results_section["lrc_display_1"],
|
| 646 |
+
results_section["lrc_display_2"],
|
| 647 |
+
results_section["lrc_display_3"],
|
| 648 |
+
results_section["lrc_display_4"],
|
| 649 |
+
results_section["lrc_display_5"],
|
| 650 |
+
results_section["lrc_display_6"],
|
| 651 |
+
results_section["lrc_display_7"],
|
| 652 |
+
results_section["lrc_display_8"],
|
| 653 |
+
results_section["lm_metadata_state"],
|
| 654 |
+
results_section["is_format_caption_state"],
|
| 655 |
+
results_section["current_batch_index"],
|
| 656 |
+
results_section["total_batches"],
|
| 657 |
+
results_section["batch_queue"],
|
| 658 |
+
results_section["generation_params_state"],
|
| 659 |
+
results_section["batch_indicator"],
|
| 660 |
+
results_section["prev_batch_btn"],
|
| 661 |
+
results_section["next_batch_btn"],
|
| 662 |
+
results_section["next_batch_status"],
|
| 663 |
+
results_section["restore_params_btn"],
|
| 664 |
+
],
|
| 665 |
+
).then(
|
| 666 |
+
fn=lambda *args: res_h.generate_next_batch_background(dit_handler, llm_handler, *args),
|
| 667 |
+
inputs=[
|
| 668 |
+
generation_section["autogen_checkbox"],
|
| 669 |
+
results_section["generation_params_state"],
|
| 670 |
+
results_section["current_batch_index"],
|
| 671 |
+
results_section["total_batches"],
|
| 672 |
+
results_section["batch_queue"],
|
| 673 |
+
results_section["is_format_caption_state"],
|
| 674 |
+
],
|
| 675 |
+
outputs=[
|
| 676 |
+
results_section["batch_queue"],
|
| 677 |
+
results_section["total_batches"],
|
| 678 |
+
results_section["next_batch_status"],
|
| 679 |
+
results_section["next_batch_btn"],
|
| 680 |
+
]
|
| 681 |
+
)
|
| 682 |
+
|
| 683 |
+
# ========== Batch Navigation Handlers ==========
|
| 684 |
+
results_section["prev_batch_btn"].click(
|
| 685 |
+
fn=res_h.navigate_to_previous_batch,
|
| 686 |
+
inputs=[
|
| 687 |
+
results_section["current_batch_index"],
|
| 688 |
+
results_section["batch_queue"],
|
| 689 |
+
],
|
| 690 |
+
outputs=[
|
| 691 |
+
results_section["generated_audio_1"],
|
| 692 |
+
results_section["generated_audio_2"],
|
| 693 |
+
results_section["generated_audio_3"],
|
| 694 |
+
results_section["generated_audio_4"],
|
| 695 |
+
results_section["generated_audio_5"],
|
| 696 |
+
results_section["generated_audio_6"],
|
| 697 |
+
results_section["generated_audio_7"],
|
| 698 |
+
results_section["generated_audio_8"],
|
| 699 |
+
results_section["generated_audio_batch"],
|
| 700 |
+
results_section["generation_info"],
|
| 701 |
+
results_section["current_batch_index"],
|
| 702 |
+
results_section["batch_indicator"],
|
| 703 |
+
results_section["prev_batch_btn"],
|
| 704 |
+
results_section["next_batch_btn"],
|
| 705 |
+
results_section["status_output"],
|
| 706 |
+
results_section["score_display_1"],
|
| 707 |
+
results_section["score_display_2"],
|
| 708 |
+
results_section["score_display_3"],
|
| 709 |
+
results_section["score_display_4"],
|
| 710 |
+
results_section["score_display_5"],
|
| 711 |
+
results_section["score_display_6"],
|
| 712 |
+
results_section["score_display_7"],
|
| 713 |
+
results_section["score_display_8"],
|
| 714 |
+
results_section["codes_display_1"],
|
| 715 |
+
results_section["codes_display_2"],
|
| 716 |
+
results_section["codes_display_3"],
|
| 717 |
+
results_section["codes_display_4"],
|
| 718 |
+
results_section["codes_display_5"],
|
| 719 |
+
results_section["codes_display_6"],
|
| 720 |
+
results_section["codes_display_7"],
|
| 721 |
+
results_section["codes_display_8"],
|
| 722 |
+
results_section["lrc_display_1"],
|
| 723 |
+
results_section["lrc_display_2"],
|
| 724 |
+
results_section["lrc_display_3"],
|
| 725 |
+
results_section["lrc_display_4"],
|
| 726 |
+
results_section["lrc_display_5"],
|
| 727 |
+
results_section["lrc_display_6"],
|
| 728 |
+
results_section["lrc_display_7"],
|
| 729 |
+
results_section["lrc_display_8"],
|
| 730 |
+
results_section["details_accordion_1"],
|
| 731 |
+
results_section["details_accordion_2"],
|
| 732 |
+
results_section["details_accordion_3"],
|
| 733 |
+
results_section["details_accordion_4"],
|
| 734 |
+
results_section["details_accordion_5"],
|
| 735 |
+
results_section["details_accordion_6"],
|
| 736 |
+
results_section["details_accordion_7"],
|
| 737 |
+
results_section["details_accordion_8"],
|
| 738 |
+
results_section["restore_params_btn"],
|
| 739 |
+
]
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
results_section["next_batch_btn"].click(
|
| 743 |
+
fn=res_h.capture_current_params,
|
| 744 |
+
inputs=[
|
| 745 |
+
generation_section["captions"],
|
| 746 |
+
generation_section["lyrics"],
|
| 747 |
+
generation_section["bpm"],
|
| 748 |
+
generation_section["key_scale"],
|
| 749 |
+
generation_section["time_signature"],
|
| 750 |
+
generation_section["vocal_language"],
|
| 751 |
+
generation_section["inference_steps"],
|
| 752 |
+
generation_section["guidance_scale"],
|
| 753 |
+
generation_section["random_seed_checkbox"],
|
| 754 |
+
generation_section["seed"],
|
| 755 |
+
generation_section["reference_audio"],
|
| 756 |
+
generation_section["audio_duration"],
|
| 757 |
+
generation_section["batch_size_input"],
|
| 758 |
+
generation_section["src_audio"],
|
| 759 |
+
generation_section["text2music_audio_code_string"],
|
| 760 |
+
generation_section["repainting_start"],
|
| 761 |
+
generation_section["repainting_end"],
|
| 762 |
+
generation_section["instruction_display_gen"],
|
| 763 |
+
generation_section["audio_cover_strength"],
|
| 764 |
+
generation_section["task_type"],
|
| 765 |
+
generation_section["use_adg"],
|
| 766 |
+
generation_section["cfg_interval_start"],
|
| 767 |
+
generation_section["cfg_interval_end"],
|
| 768 |
+
generation_section["shift"],
|
| 769 |
+
generation_section["infer_method"],
|
| 770 |
+
generation_section["custom_timesteps"],
|
| 771 |
+
generation_section["audio_format"],
|
| 772 |
+
generation_section["lm_temperature"],
|
| 773 |
+
generation_section["think_checkbox"],
|
| 774 |
+
generation_section["lm_cfg_scale"],
|
| 775 |
+
generation_section["lm_top_k"],
|
| 776 |
+
generation_section["lm_top_p"],
|
| 777 |
+
generation_section["lm_negative_prompt"],
|
| 778 |
+
generation_section["use_cot_metas"],
|
| 779 |
+
generation_section["use_cot_caption"],
|
| 780 |
+
generation_section["use_cot_language"],
|
| 781 |
+
generation_section["constrained_decoding_debug"],
|
| 782 |
+
generation_section["allow_lm_batch"],
|
| 783 |
+
generation_section["auto_score"],
|
| 784 |
+
generation_section["auto_lrc"],
|
| 785 |
+
generation_section["score_scale"],
|
| 786 |
+
generation_section["lm_batch_chunk_size"],
|
| 787 |
+
generation_section["track_name"],
|
| 788 |
+
generation_section["complete_track_classes"],
|
| 789 |
+
],
|
| 790 |
+
outputs=[results_section["generation_params_state"]]
|
| 791 |
+
).then(
|
| 792 |
+
fn=res_h.navigate_to_next_batch,
|
| 793 |
+
inputs=[
|
| 794 |
+
generation_section["autogen_checkbox"],
|
| 795 |
+
results_section["current_batch_index"],
|
| 796 |
+
results_section["total_batches"],
|
| 797 |
+
results_section["batch_queue"],
|
| 798 |
+
],
|
| 799 |
+
outputs=[
|
| 800 |
+
results_section["generated_audio_1"],
|
| 801 |
+
results_section["generated_audio_2"],
|
| 802 |
+
results_section["generated_audio_3"],
|
| 803 |
+
results_section["generated_audio_4"],
|
| 804 |
+
results_section["generated_audio_5"],
|
| 805 |
+
results_section["generated_audio_6"],
|
| 806 |
+
results_section["generated_audio_7"],
|
| 807 |
+
results_section["generated_audio_8"],
|
| 808 |
+
results_section["generated_audio_batch"],
|
| 809 |
+
results_section["generation_info"],
|
| 810 |
+
results_section["current_batch_index"],
|
| 811 |
+
results_section["batch_indicator"],
|
| 812 |
+
results_section["prev_batch_btn"],
|
| 813 |
+
results_section["next_batch_btn"],
|
| 814 |
+
results_section["status_output"],
|
| 815 |
+
results_section["next_batch_status"],
|
| 816 |
+
results_section["score_display_1"],
|
| 817 |
+
results_section["score_display_2"],
|
| 818 |
+
results_section["score_display_3"],
|
| 819 |
+
results_section["score_display_4"],
|
| 820 |
+
results_section["score_display_5"],
|
| 821 |
+
results_section["score_display_6"],
|
| 822 |
+
results_section["score_display_7"],
|
| 823 |
+
results_section["score_display_8"],
|
| 824 |
+
results_section["codes_display_1"],
|
| 825 |
+
results_section["codes_display_2"],
|
| 826 |
+
results_section["codes_display_3"],
|
| 827 |
+
results_section["codes_display_4"],
|
| 828 |
+
results_section["codes_display_5"],
|
| 829 |
+
results_section["codes_display_6"],
|
| 830 |
+
results_section["codes_display_7"],
|
| 831 |
+
results_section["codes_display_8"],
|
| 832 |
+
results_section["lrc_display_1"],
|
| 833 |
+
results_section["lrc_display_2"],
|
| 834 |
+
results_section["lrc_display_3"],
|
| 835 |
+
results_section["lrc_display_4"],
|
| 836 |
+
results_section["lrc_display_5"],
|
| 837 |
+
results_section["lrc_display_6"],
|
| 838 |
+
results_section["lrc_display_7"],
|
| 839 |
+
results_section["lrc_display_8"],
|
| 840 |
+
results_section["details_accordion_1"],
|
| 841 |
+
results_section["details_accordion_2"],
|
| 842 |
+
results_section["details_accordion_3"],
|
| 843 |
+
results_section["details_accordion_4"],
|
| 844 |
+
results_section["details_accordion_5"],
|
| 845 |
+
results_section["details_accordion_6"],
|
| 846 |
+
results_section["details_accordion_7"],
|
| 847 |
+
results_section["details_accordion_8"],
|
| 848 |
+
results_section["restore_params_btn"],
|
| 849 |
+
]
|
| 850 |
+
).then(
|
| 851 |
+
fn=lambda *args: res_h.generate_next_batch_background(dit_handler, llm_handler, *args),
|
| 852 |
+
inputs=[
|
| 853 |
+
generation_section["autogen_checkbox"],
|
| 854 |
+
results_section["generation_params_state"],
|
| 855 |
+
results_section["current_batch_index"],
|
| 856 |
+
results_section["total_batches"],
|
| 857 |
+
results_section["batch_queue"],
|
| 858 |
+
results_section["is_format_caption_state"],
|
| 859 |
+
],
|
| 860 |
+
outputs=[
|
| 861 |
+
results_section["batch_queue"],
|
| 862 |
+
results_section["total_batches"],
|
| 863 |
+
results_section["next_batch_status"],
|
| 864 |
+
results_section["next_batch_btn"],
|
| 865 |
+
]
|
| 866 |
+
)
|
| 867 |
+
|
| 868 |
+
# ========== Restore Parameters Handler ==========
|
| 869 |
+
results_section["restore_params_btn"].click(
|
| 870 |
+
fn=res_h.restore_batch_parameters,
|
| 871 |
+
inputs=[
|
| 872 |
+
results_section["current_batch_index"],
|
| 873 |
+
results_section["batch_queue"]
|
| 874 |
+
],
|
| 875 |
+
outputs=[
|
| 876 |
+
generation_section["text2music_audio_code_string"],
|
| 877 |
+
generation_section["captions"],
|
| 878 |
+
generation_section["lyrics"],
|
| 879 |
+
generation_section["bpm"],
|
| 880 |
+
generation_section["key_scale"],
|
| 881 |
+
generation_section["time_signature"],
|
| 882 |
+
generation_section["vocal_language"],
|
| 883 |
+
generation_section["audio_duration"],
|
| 884 |
+
generation_section["batch_size_input"],
|
| 885 |
+
generation_section["inference_steps"],
|
| 886 |
+
generation_section["lm_temperature"],
|
| 887 |
+
generation_section["lm_cfg_scale"],
|
| 888 |
+
generation_section["lm_top_k"],
|
| 889 |
+
generation_section["lm_top_p"],
|
| 890 |
+
generation_section["think_checkbox"],
|
| 891 |
+
generation_section["use_cot_caption"],
|
| 892 |
+
generation_section["use_cot_language"],
|
| 893 |
+
generation_section["allow_lm_batch"],
|
| 894 |
+
generation_section["track_name"],
|
| 895 |
+
generation_section["complete_track_classes"],
|
| 896 |
+
]
|
| 897 |
+
)
|
| 898 |
+
|
| 899 |
+
# ========== LRC Display Change Handlers ==========
|
| 900 |
+
# NEW APPROACH: Use lrc_display.change() to update audio subtitles
|
| 901 |
+
# This decouples audio value updates from subtitle updates, avoiding flickering.
|
| 902 |
+
#
|
| 903 |
+
# When lrc_display text changes (from generate, LRC button, or manual edit):
|
| 904 |
+
# 1. lrc_display.change() is triggered
|
| 905 |
+
# 2. update_audio_subtitles_from_lrc() parses LRC and updates audio subtitles
|
| 906 |
+
# 3. Audio value is NEVER updated here - only subtitles
|
| 907 |
+
for lrc_idx in range(1, 9):
|
| 908 |
+
results_section[f"lrc_display_{lrc_idx}"].change(
|
| 909 |
+
fn=res_h.update_audio_subtitles_from_lrc,
|
| 910 |
+
inputs=[
|
| 911 |
+
results_section[f"lrc_display_{lrc_idx}"],
|
| 912 |
+
# audio_duration not needed - parse_lrc_to_subtitles calculates end time from timestamps
|
| 913 |
+
],
|
| 914 |
+
outputs=[
|
| 915 |
+
results_section[f"generated_audio_{lrc_idx}"], # Only updates subtitles, not value
|
| 916 |
+
]
|
| 917 |
+
)
|
| 918 |
+
|
| 919 |
+
|
| 920 |
+
def setup_training_event_handlers(demo, dit_handler, llm_handler, training_section):
    """Wire up the Gradio events of the training tab (dataset builder and LoRA training).

    Args:
        demo: Not referenced by this function; kept so existing callers keep working.
        dit_handler: Forwarded to the auto-label, preprocess and training handlers.
        llm_handler: Forwarded to the auto-label handler.
        training_section: Mapping from component name (str) to the component
            whose events are registered here.

    Returns:
        None. The function only registers event listeners.
    """
    ui = training_section

    # Per-sample preview/edit widgets. Order matters: handler return values are
    # mapped onto ``outputs`` positionally, so this tuple is shared by every
    # listener that refreshes the preview instead of being re-typed each time.
    preview_fields = (
        ui["preview_audio"],
        ui["preview_filename"],
        ui["edit_caption"],
        ui["edit_genre"],
        ui["prompt_override"],
        ui["edit_lyrics"],
        ui["edit_bpm"],
        ui["edit_keyscale"],
        ui["edit_timesig"],
        ui["edit_duration"],
        ui["edit_language"],
        ui["edit_instrumental"],
        ui["raw_lyrics_display"],
        ui["has_raw_lyrics_state"],
    )

    # Dataset-level settings refreshed when a dataset JSON is loaded.
    dataset_settings = (
        ui["dataset_name"],
        ui["custom_tag"],
        ui["tag_position"],
        ui["all_instrumental"],
        ui["genre_ratio"],
    )

    def _wire_dataset_loader(button_key, path_key, status_key):
        """Register a "load dataset JSON" button that also refreshes preview and settings."""
        ui[button_key].click(
            fn=train_h.load_existing_dataset_for_preprocess,
            inputs=[
                ui[path_key],
                ui["dataset_builder_state"],
            ],
            outputs=[
                ui[status_key],
                ui["audio_files_table"],
                ui["sample_selector"],
                ui["dataset_builder_state"],
                # Also update preview fields with first sample
                *preview_fields,
                # Update dataset-level settings
                *dataset_settings,
            ],
        ).then(
            # bool(): same coercion as the auto-label chain, so a None/empty
            # state value is always treated as "no raw lyrics".
            fn=lambda has_raw: gr.update(visible=bool(has_raw)),
            inputs=[ui["has_raw_lyrics_state"]],
            outputs=[ui["raw_lyrics_display"]],
        )

    # ========== Load Existing Dataset (Top Section) ==========

    # Load existing dataset JSON at the top of Dataset Builder
    _wire_dataset_loader("load_json_btn", "load_json_path", "load_json_status")

    # ========== Dataset Builder Handlers ==========

    # Scan directory for audio files
    ui["scan_btn"].click(
        fn=lambda dir, name, tag, pos, instr, state: train_h.scan_directory(
            dir, name, tag, pos, instr, state
        ),
        inputs=[
            ui["audio_directory"],
            ui["dataset_name"],
            ui["custom_tag"],
            ui["tag_position"],
            ui["all_instrumental"],
            ui["dataset_builder_state"],
        ],
        outputs=[
            ui["audio_files_table"],
            ui["scan_status"],
            ui["sample_selector"],
            ui["dataset_builder_state"],
        ],
    )

    # Auto-label all samples
    ui["auto_label_btn"].click(
        fn=lambda state, skip, fmt_lyrics, trans_lyrics, only_unlab: train_h.auto_label_all(
            dit_handler, llm_handler, state, skip, fmt_lyrics, trans_lyrics, only_unlab
        ),
        inputs=[
            ui["dataset_builder_state"],
            ui["skip_metas"],
            ui["format_lyrics"],
            ui["transcribe_lyrics"],
            ui["only_unlabeled"],
        ],
        outputs=[
            ui["audio_files_table"],
            ui["label_progress"],
            ui["dataset_builder_state"],
        ],
    ).then(
        # Refresh preview/edit fields after labeling completes
        fn=train_h.get_sample_preview,
        inputs=[
            ui["sample_selector"],
            ui["dataset_builder_state"],
        ],
        outputs=[*preview_fields],
    ).then(
        # Append a confirmation line to whatever status the labeling step left.
        fn=lambda status: f"{status or '✅ Auto-label complete.'}\n✅ Preview refreshed.",
        inputs=[ui["label_progress"]],
        outputs=[ui["label_progress"]],
    ).then(
        fn=lambda has_raw: gr.update(visible=bool(has_raw)),
        inputs=[ui["has_raw_lyrics_state"]],
        outputs=[ui["raw_lyrics_display"]],
    )

    # Mutual exclusion: format_lyrics and transcribe_lyrics cannot both be True
    ui["format_lyrics"].change(
        fn=lambda fmt: gr.update(value=False) if fmt else gr.update(),
        inputs=[ui["format_lyrics"]],
        outputs=[ui["transcribe_lyrics"]],
    )

    ui["transcribe_lyrics"].change(
        fn=lambda trans: gr.update(value=False) if trans else gr.update(),
        inputs=[ui["transcribe_lyrics"]],
        outputs=[ui["format_lyrics"]],
    )

    # Sample selector change - update preview
    ui["sample_selector"].change(
        fn=train_h.get_sample_preview,
        inputs=[
            ui["sample_selector"],
            ui["dataset_builder_state"],
        ],
        outputs=[*preview_fields],
    ).then(
        # Show/hide raw lyrics panel based on whether raw lyrics exist
        fn=lambda has_raw: gr.update(visible=bool(has_raw)),
        inputs=[ui["has_raw_lyrics_state"]],
        outputs=[ui["raw_lyrics_display"]],
    )

    # Save sample edit
    ui["save_edit_btn"].click(
        fn=train_h.save_sample_edit,
        inputs=[
            ui["sample_selector"],
            ui["edit_caption"],
            ui["edit_genre"],
            ui["prompt_override"],
            ui["edit_lyrics"],
            ui["edit_bpm"],
            ui["edit_keyscale"],
            ui["edit_timesig"],
            ui["edit_language"],
            ui["edit_instrumental"],
            ui["dataset_builder_state"],
        ],
        outputs=[
            ui["audio_files_table"],
            ui["edit_status"],
            ui["dataset_builder_state"],
        ],
    )

    # Update settings when changed (including genre_ratio)
    settings_triggers = [
        ui["custom_tag"],
        ui["tag_position"],
        ui["all_instrumental"],
        ui["genre_ratio"],
    ]
    for trigger in settings_triggers:
        trigger.change(
            fn=train_h.update_settings,
            inputs=[
                ui["custom_tag"],
                ui["tag_position"],
                ui["all_instrumental"],
                ui["genre_ratio"],
                ui["dataset_builder_state"],
            ],
            outputs=[ui["dataset_builder_state"]],
        )

    # Save dataset
    ui["save_dataset_btn"].click(
        fn=train_h.save_dataset,
        inputs=[
            ui["save_path"],
            ui["dataset_name"],
            ui["dataset_builder_state"],
        ],
        outputs=[
            ui["save_status"],
            ui["save_path"],
        ],
    )

    # ========== Preprocess Handlers ==========

    # Load existing dataset JSON for preprocessing
    # This also updates the preview section so users can view/edit samples
    _wire_dataset_loader(
        "load_existing_dataset_btn",
        "load_existing_dataset_path",
        "load_existing_status",
    )

    # Preprocess dataset to tensor files
    ui["preprocess_btn"].click(
        fn=lambda output_dir, state: train_h.preprocess_dataset(
            output_dir, dit_handler, state
        ),
        inputs=[
            ui["preprocess_output_dir"],
            ui["dataset_builder_state"],
        ],
        outputs=[ui["preprocess_progress"]],
    )

    # ========== Training Tab Handlers ==========

    # Load preprocessed tensor dataset
    ui["load_dataset_btn"].click(
        fn=train_h.load_training_dataset,
        inputs=[ui["training_tensor_dir"]],
        outputs=[ui["training_dataset_info"]],
    )

    # Start training from preprocessed tensors
    def training_wrapper(tensor_dir, r, a, d, lr, ep, bs, ga, se, sh, sd, od, rc, ts):
        """Relay ``train_h.start_training`` updates to the UI, reporting failures.

        The positional parameters mirror the ``inputs`` list of the start
        button below and are forwarded unchanged, with ``dit_handler``
        inserted after ``tensor_dir``. Yields
        ``(progress, log_msg, plot, state)`` tuples.
        """
        from loguru import logger

        # Fall back to a fresh state when the incoming one is not a dict.
        if not isinstance(ts, dict):
            ts = {"is_training": False, "should_stop": False}
        try:
            for progress, log_msg, plot, state in train_h.start_training(
                tensor_dir, dit_handler, r, a, d, lr, ep, bs, ga, se, sh, sd, od, rc, ts
            ):
                yield progress, log_msg, plot, state
        except Exception as e:
            # UI boundary: log the traceback and surface the error in the
            # progress/log outputs instead of letting the event fail.
            logger.exception("Training wrapper error")
            yield f"❌ Error: {e}", str(e), None, ts

    ui["start_training_btn"].click(
        fn=training_wrapper,
        inputs=[
            ui["training_tensor_dir"],
            ui["lora_rank"],
            ui["lora_alpha"],
            ui["lora_dropout"],
            ui["learning_rate"],
            ui["train_epochs"],
            ui["train_batch_size"],
            ui["gradient_accumulation"],
            ui["save_every_n_epochs"],
            ui["training_shift"],
            ui["training_seed"],
            ui["lora_output_dir"],
            ui["resume_checkpoint_dir"],
            ui["training_state"],
        ],
        outputs=[
            ui["training_progress"],
            ui["training_log"],
            ui["training_loss_plot"],
            ui["training_state"],
        ],
    )

    # Stop training
    ui["stop_training_btn"].click(
        fn=train_h.stop_training,
        inputs=[ui["training_state"]],
        outputs=[
            ui["training_progress"],
            ui["training_state"],
        ],
    )

    # Export LoRA
    ui["export_lora_btn"].click(
        fn=train_h.export_lora,
        inputs=[
            ui["export_path"],
            ui["lora_output_dir"],
        ],
        outputs=[ui["export_status"]],
    )
|
acestep/gradio_ui/events/generation_handlers.py
ADDED
|
@@ -0,0 +1,1050 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Generation Input Handlers Module
|
| 3 |
+
Contains event handlers and helper functions related to generation inputs
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import random
|
| 8 |
+
import glob
|
| 9 |
+
import gradio as gr
|
| 10 |
+
from typing import Optional, List, Tuple
|
| 11 |
+
from acestep.constants import (
|
| 12 |
+
TASK_TYPES_TURBO,
|
| 13 |
+
TASK_TYPES_BASE,
|
| 14 |
+
)
|
| 15 |
+
from acestep.gradio_ui.i18n import t
|
| 16 |
+
from acestep.inference import understand_music, create_sample, format_sample
|
| 17 |
+
from acestep.gpu_config import get_global_gpu_config
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def clamp_duration_to_gpu_limit(duration_value: Optional[float], llm_handler=None) -> Optional[float]:
    """
    Cap a requested duration at the maximum allowed by the GPU configuration.

    Args:
        duration_value: Requested duration in seconds. ``None`` and values
            that are zero or negative are returned unchanged.
        llm_handler: Optional LLM handler. When it is provided and its
            ``llm_initialized`` flag is truthy, the "with LM" limit is used;
            otherwise the "without LM" limit applies.

    Returns:
        The limit converted to ``float`` when the request exceeds it,
        otherwise ``duration_value`` exactly as passed in.
    """
    # Unset or non-positive durations are passed through without a lookup.
    if duration_value is None or duration_value <= 0:
        return duration_value

    config = get_global_gpu_config()
    if llm_handler and llm_handler.llm_initialized:
        limit = config.max_duration_with_lm
    else:
        limit = config.max_duration_without_lm

    return float(limit) if duration_value > limit else duration_value
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def parse_and_validate_timesteps(
    timesteps_str: str,
    inference_steps: int
) -> Tuple[Optional[List[float]], bool, str]:
    """
    Parse a comma-separated timesteps string and validate it.

    Args:
        timesteps_str: Comma-separated timesteps string
            (e.g., "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0").
            The trailing 0 is optional and is appended when missing.
        inference_steps: Expected number of inference steps.

    Returns:
        Tuple of (parsed_timesteps, has_warning, warning_message)
        - parsed_timesteps: List of float timesteps, or None if invalid/empty
        - has_warning: Whether a warning was shown
        - warning_message: Description of the warning
    """
    if not timesteps_str or not timesteps_str.strip():
        return None, False, ""

    # Split on commas, dropping empty fragments such as those from "a,,b".
    values = [v.strip() for v in timesteps_str.split(",") if v.strip()]
    if not values:
        return None, False, ""

    try:
        timesteps = [float(v) for v in values]
    except ValueError:
        gr.Warning(t("messages.invalid_timesteps_format"))
        return None, True, "Invalid format"

    # Handle the optional trailing 0. The comparison is numeric so that
    # spellings like "0.0" or "0.00" are recognised as the final zero and do
    # not get a duplicate appended (which would inflate the step count).
    if timesteps[-1] != 0.0:
        timesteps.append(0.0)

    # Validate range [0, 1]. Written as a chained comparison so that NaN,
    # which compares false against everything, is rejected as well.
    if not all(0.0 <= ts <= 1.0 for ts in timesteps):
        gr.Warning(t("messages.timesteps_out_of_range"))
        return None, True, "Out of range"

    # N steps are described by N + 1 timesteps (including the final 0).
    actual_steps = len(timesteps) - 1
    if actual_steps != inference_steps:
        gr.Warning(t("messages.timesteps_count_mismatch", actual=actual_steps, expected=inference_steps))
        return timesteps, True, f"Using {actual_steps} steps from timesteps"

    return timesteps, False, ""
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def load_metadata(file_obj, llm_handler=None):
    """Load generation parameters from a JSON file.

    Args:
        file_obj: Uploaded file object (its ``name`` attribute is used as the
            path) or, when it has no ``name`` attribute, the path itself.
        llm_handler: LLM handler instance (optional, for GPU duration limit check)

    Returns:
        On success, a 37-item tuple: 36 parameter values in the order of the
        return statement below, followed by ``True`` (is_format_caption).
        On any failure, a list of 36 ``None`` values followed by ``False``;
        a warning is raised in the UI in that case.
    """
    # None for every field, False for is_format_caption.
    failure = [None] * 36 + [False]

    if file_obj is None:
        gr.Warning(t("messages.no_file_selected"))
        return failure

    try:
        # Read the uploaded file
        filepath = getattr(file_obj, 'name', file_obj)

        with open(filepath, 'r', encoding='utf-8') as f:
            metadata = json.load(f)

        # Extract all fields
        task_type = metadata.get('task_type', 'text2music')
        captions = metadata.get('caption', '')
        lyrics = metadata.get('lyrics', '')
        vocal_language = metadata.get('vocal_language', 'unknown')

        # Convert bpm; anything missing, "N/A", falsy or unparsable becomes None.
        bpm = None
        bpm_value = metadata.get('bpm')
        if bpm_value is not None and bpm_value != "N/A":
            try:
                bpm = int(bpm_value) if bpm_value else None
            except Exception:
                # Was a bare ``except:``; narrowed so KeyboardInterrupt and
                # SystemExit are no longer swallowed.
                bpm = None

        key_scale = metadata.get('keyscale', '')
        time_signature = metadata.get('timesignature', '')

        # Convert duration; -1 is used when it is missing or unusable.
        audio_duration = -1
        duration_value = metadata.get('duration', -1)
        if duration_value is not None and duration_value != "N/A":
            try:
                audio_duration = float(duration_value)
                # Clamp duration to GPU memory limit
                audio_duration = clamp_duration_to_gpu_limit(audio_duration, llm_handler)
            except Exception:
                # Was a bare ``except:``; see the note on bpm above.
                audio_duration = -1

        batch_size = metadata.get('batch_size', 2)
        inference_steps = metadata.get('inference_steps', 8)
        guidance_scale = metadata.get('guidance_scale', 7.0)
        seed = metadata.get('seed', '-1')
        random_seed = False  # Always False when loading, so the saved seed is reproduced
        use_adg = metadata.get('use_adg', False)
        cfg_interval_start = metadata.get('cfg_interval_start', 0.0)
        cfg_interval_end = metadata.get('cfg_interval_end', 1.0)
        audio_format = metadata.get('audio_format', 'mp3')
        lm_temperature = metadata.get('lm_temperature', 0.85)
        lm_cfg_scale = metadata.get('lm_cfg_scale', 2.0)
        lm_top_k = metadata.get('lm_top_k', 0)
        lm_top_p = metadata.get('lm_top_p', 0.9)
        lm_negative_prompt = metadata.get('lm_negative_prompt', 'NO USER INPUT')
        use_cot_metas = metadata.get('use_cot_metas', True)
        use_cot_caption = metadata.get('use_cot_caption', True)
        use_cot_language = metadata.get('use_cot_language', True)
        audio_cover_strength = metadata.get('audio_cover_strength', 1.0)
        think = metadata.get('thinking', True)  # The JSON key is 'thinking', not 'think'
        audio_codes = metadata.get('audio_codes', '')
        repainting_start = metadata.get('repainting_start', 0.0)
        repainting_end = metadata.get('repainting_end', -1)
        track_name = metadata.get('track_name')
        complete_track_classes = metadata.get('complete_track_classes', [])
        shift = metadata.get('shift', 3.0)  # Default 3.0 for base models
        infer_method = metadata.get('infer_method', 'ode')  # Default 'ode' for diffusion inference
        # Custom timesteps are stored under 'timesteps' in the JSON; an explicit
        # null is normalised to an empty string.
        custom_timesteps = metadata.get('timesteps', '')
        if custom_timesteps is None:
            custom_timesteps = ''
        instrumental = metadata.get('instrumental', False)

        gr.Info(t("messages.params_loaded", filename=os.path.basename(filepath)))

        return (
            task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature,
            audio_duration, batch_size, inference_steps, guidance_scale, seed, random_seed,
            use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method,
            custom_timesteps,  # sits between infer_method and audio_format
            audio_format, lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
            use_cot_metas, use_cot_caption, use_cot_language, audio_cover_strength,
            think, audio_codes, repainting_start, repainting_end,
            track_name, complete_track_classes, instrumental,
            True  # Set is_format_caption to True when loading from file
        )

    except json.JSONDecodeError as e:
        gr.Warning(t("messages.invalid_json", error=str(e)))
        return failure
    except Exception as e:
        gr.Warning(t("messages.load_error", error=str(e)))
        return failure
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def load_random_example(task_type: str, llm_handler=None):
    """Pick a random JSON example for the given task and unpack it for the UI.

    Args:
        task_type: The task type (e.g., "text2music"); also the name of the
            sub-directory of ``examples/`` that is searched.
        llm_handler: LLM handler instance (optional), forwarded to the GPU
            duration clamp.

    Returns:
        Tuple of (caption, lyrics, think, bpm, duration, keyscale, language,
        timesignature). When nothing can be loaded the tuple is
        ("", "", True, None, None, "", "", "") and a warning is shown.
    """
    # Returned whenever no example could be loaded.
    empty_result = ("", "", True, None, None, "", "", "")

    def _as_text(value):
        # Strings pass through; other truthy values are stringified, falsy ones become ''.
        if isinstance(value, str):
            return value
        return str(value) if value else ''

    def _blank_if_unset(value):
        # None and the "N/A" placeholder are shown as an empty field.
        return '' if value in (None, "N/A") else value

    try:
        # This file is in acestep/gradio_ui/events/, so the project root is
        # four dirname() steps above it.
        project_root = os.path.abspath(__file__)
        for _ in range(4):
            project_root = os.path.dirname(project_root)

        examples_dir = os.path.join(project_root, "examples", task_type)
        if not os.path.exists(examples_dir):
            gr.Warning(f"Examples directory not found: examples/{task_type}/")
            return empty_result

        candidates = glob.glob(os.path.join(examples_dir, "*.json"))
        if not candidates:
            gr.Warning(f"No JSON files found in examples/{task_type}/")
            return empty_result

        selected_file = random.choice(candidates)

        try:
            with open(selected_file, 'r', encoding='utf-8') as fh:
                data = json.load(fh)

            # 'caption' is preferred, 'prompt' is the fallback key.
            caption_value = _as_text(data.get('caption', data.get('prompt', '')))
            lyrics_value = _as_text(data.get('lyrics', ''))

            # Anything that is not a real bool is treated as True.
            think_value = data.get('think', True)
            if not isinstance(think_value, bool):
                think_value = True

            # Optional numeric metadata: left as None when absent or unparsable.
            bpm_value = None
            if 'bpm' in data and data['bpm'] not in (None, "N/A", ""):
                try:
                    bpm_value = int(data['bpm'])
                except (ValueError, TypeError):
                    pass

            duration_value = None
            if 'duration' in data and data['duration'] not in (None, "N/A", ""):
                try:
                    duration_value = float(data['duration'])
                    # Clamp duration to GPU memory limit
                    duration_value = clamp_duration_to_gpu_limit(duration_value, llm_handler)
                except (ValueError, TypeError):
                    pass

            keyscale_value = _blank_if_unset(data.get('keyscale', ''))
            language_value = _blank_if_unset(data.get('language', ''))
            timesignature_value = _blank_if_unset(data.get('timesignature', ''))

            gr.Info(t("messages.example_loaded", filename=os.path.basename(selected_file)))
            return (
                caption_value,
                lyrics_value,
                think_value,
                bpm_value,
                duration_value,
                keyscale_value,
                language_value,
                timesignature_value,
            )

        except json.JSONDecodeError as e:
            gr.Warning(t("messages.example_failed", filename=os.path.basename(selected_file), error=str(e)))
            return empty_result
        except Exception as e:
            gr.Warning(t("messages.example_error", error=str(e)))
            return empty_result

    except Exception as e:
        gr.Warning(t("messages.example_error", error=str(e)))
        return empty_result
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def sample_example_smart(llm_handler, task_type: str, constrained_decoding_debug: bool = False):
    """Smart sample function that uses LM if initialized, otherwise falls back to examples

    This is a Gradio wrapper that uses the understand_music API from acestep.inference
    to generate examples when LM is available. If the LM is not initialized, the
    LM call reports failure, or it raises, a random example from the examples
    directory is returned instead.

    Args:
        llm_handler: LLM handler instance
        task_type: The task type (e.g., "text2music")
        constrained_decoding_debug: Whether to enable debug logging for constrained decoding

    Returns:
        Tuple of (caption, lyrics, think, bpm, duration, keyscale, language, timesignature) for updating UI components
    """
    # LM not initialized: use the examples directory.
    if not llm_handler.llm_initialized:
        return load_random_example(task_type, llm_handler)

    try:
        result = understand_music(
            llm_handler=llm_handler,
            audio_codes="NO USER INPUT",  # Empty input triggers example generation
            temperature=0.85,
            use_constrained_decoding=True,
            constrained_decoding_debug=constrained_decoding_debug,
        )

        if result.success:
            gr.Info(t("messages.lm_generated"))
            # Clamp duration to GPU memory limit
            clamped_duration = clamp_duration_to_gpu_limit(result.duration, llm_handler)
            return (
                result.caption,
                result.lyrics,
                True,  # Always enable think when using LM-generated examples
                result.bpm,
                clamped_duration,
                result.keyscale,
                result.language,
                result.timesignature,
            )

        gr.Warning(t("messages.lm_fallback"))
    except Exception:
        gr.Warning(t("messages.lm_fallback"))

    # The handler is forwarded so the fallback example's duration is clamped
    # with the same (LM-initialized) limit as an LM-generated one would be.
    return load_random_example(task_type, llm_handler)
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def load_random_simple_description():
    """Pick a random JSON file from ``examples/simple_mode`` and return its fields.

    Returns:
        Tuple of (description, instrumental, vocal_language) for updating UI
        components. When no example can be loaded, a warning is shown and
        three no-op ``gr.update()`` objects are returned instead.
    """
    def _no_change():
        # One untouched update per output component.
        return gr.update(), gr.update(), gr.update()

    try:
        # This file is in acestep/gradio_ui/events/, so the project root is
        # four dirname() steps above it.
        project_root = os.path.abspath(__file__)
        for _ in range(4):
            project_root = os.path.dirname(project_root)

        examples_dir = os.path.join(project_root, "examples", "simple_mode")
        if not os.path.exists(examples_dir):
            gr.Warning(t("messages.simple_examples_not_found"))
            return _no_change()

        candidates = glob.glob(os.path.join(examples_dir, "*.json"))
        if not candidates:
            gr.Warning(t("messages.simple_examples_empty"))
            return _no_change()

        selected_file = random.choice(candidates)

        try:
            with open(selected_file, 'r', encoding='utf-8') as fh:
                data = json.load(fh)

            description = data.get('description', '')
            instrumental = data.get('instrumental', False)
            vocal_language = data.get('vocal_language', 'unknown')

            # A list of languages is reduced to its first entry ('unknown' if empty).
            if isinstance(vocal_language, list):
                vocal_language = vocal_language[0] if vocal_language else 'unknown'

            gr.Info(t("messages.simple_example_loaded", filename=os.path.basename(selected_file)))
            return description, instrumental, vocal_language

        except json.JSONDecodeError as e:
            gr.Warning(t("messages.example_failed", filename=os.path.basename(selected_file), error=str(e)))
            return _no_change()
        except Exception as e:
            gr.Warning(t("messages.example_error", error=str(e)))
            return _no_change()

    except Exception as e:
        gr.Warning(t("messages.example_error", error=str(e)))
        return _no_change()
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
def refresh_checkpoints(dit_handler):
    """Ask the handler for its available checkpoints and return them as new choices."""
    return gr.update(choices=dit_handler.get_available_checkpoints())
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
def update_model_type_settings(config_path):
    """Derive UI settings from the config path alone (pre-initialization fallback).

    Note: This is used as a fallback when the user changes the config_path
    dropdown before initializing the model. After initialization the settings
    come from the handler's is_turbo_model() method instead.
    """
    name = (config_path if config_path is not None else "").lower()

    # Heuristic on the path string: "turbo" wins over "base", and a path that
    # mentions neither defaults to turbo settings.
    is_turbo = "turbo" in name or "base" not in name

    return get_model_type_ui_settings(is_turbo)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu, compile_model, quantization):
    """Initialize the DiT service (and optionally the LM) and build the UI updates.

    Returns:
        Tuple of (status text, button interactivity update, accordion state,
        *model type settings from get_model_type_ui_settings).
    """
    # The quantization checkbox maps to "int8_weight_only" when ticked, None otherwise.
    status, enable = dit_handler.initialize_service(
        checkpoint, config_path, device,
        use_flash_attention=use_flash_attention, compile_model=compile_model,
        offload_to_cpu=offload_to_cpu, offload_dit_to_cpu=offload_dit_to_cpu,
        quantization="int8_weight_only" if quantization else None
    )

    if init_llm:
        # This file is in acestep/gradio_ui/events/, so the project root is
        # four dirname() steps above it; checkpoints live directly under it.
        project_root = os.path.abspath(__file__)
        for _ in range(4):
            project_root = os.path.dirname(project_root)

        lm_status, _lm_success = llm_handler.initialize(
            checkpoint_dir=os.path.join(project_root, "checkpoints"),
            lm_model_path=lm_model_path,
            backend=backend,
            device=device,
            offload_to_cpu=offload_to_cpu,
            dtype=None,
        )

        # The LM status line is appended whether or not the LM came up.
        # An LM failure does not change `enable`, which reflects the DiT
        # initialization result only.
        status += f"\n{lm_status}"

    # Collapse the accordion once a model is loaded.
    accordion_state = gr.Accordion(open=dit_handler.model is None)

    # Settings follow the model that was actually loaded.
    model_type_settings = get_model_type_ui_settings(dit_handler.is_turbo_model())

    return (
        status,
        gr.update(interactive=enable),
        accordion_state,
        *model_type_settings
    )
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def get_model_type_ui_settings(is_turbo: bool):
    """Build the seven component updates that differ between turbo and base models.

    Turbo: steps default 8 (max 20), CFG/ADG controls hidden, turbo task list.
    Base: steps default 32 (max 200), CFG/ADG controls shown, base task list.
    The shift control is shown with value 3.0 in both cases.
    """
    if is_turbo:
        default_steps, max_steps = 8, 20
        cfg_controls_visible = False
        task_choices = TASK_TYPES_TURBO
    else:
        default_steps, max_steps = 32, 200
        cfg_controls_visible = True
        task_choices = TASK_TYPES_BASE

    return (
        gr.update(value=default_steps, maximum=max_steps, minimum=1),  # inference_steps
        gr.update(visible=cfg_controls_visible),  # guidance_scale
        gr.update(visible=cfg_controls_visible),  # use_adg
        gr.update(value=3.0, visible=True),  # shift
        gr.update(visible=cfg_controls_visible),  # cfg_interval_start
        gr.update(visible=cfg_controls_visible),  # cfg_interval_end
        gr.update(choices=task_choices),  # task_type
    )
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
def update_negative_prompt_visibility(init_llm_checked):
    """Make the negative prompt visible exactly when the Initialize 5Hz LM checkbox is checked."""
    show_negative_prompt = init_llm_checked
    return gr.update(visible=show_negative_prompt)
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
def _has_reference_audio(reference_audio) -> bool:
|
| 523 |
+
"""True if reference_audio has a usable value (Gradio Audio returns path string or (path, sr))."""
|
| 524 |
+
if reference_audio is None:
|
| 525 |
+
return False
|
| 526 |
+
if isinstance(reference_audio, str):
|
| 527 |
+
return bool(reference_audio.strip())
|
| 528 |
+
if isinstance(reference_audio, (list, tuple)) and reference_audio:
|
| 529 |
+
return bool(reference_audio[0])
|
| 530 |
+
return False
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def update_audio_cover_strength_visibility(task_type_value, init_llm_checked, reference_audio=None):
    """Return the visibility/label/info update for the audio_cover_strength control.

    The control is shown for the cover task, when the LM checkbox is checked,
    or when reference audio is present. The label follows the same priority:
    cover, then LM codes, then Similarity/Denoise; the cover wording is also
    the default when none of the three applies.
    """
    has_reference = _has_reference_audio(reference_audio)
    is_visible = (task_type_value == "cover") or init_llm_checked or has_reference

    # Translation keys for the cover wording (explicit cover task and default).
    label_key = "generation.cover_strength_label"
    info_key = "generation.cover_strength_info"
    if task_type_value != "cover":
        if init_llm_checked:
            label_key = "generation.codes_strength_label"
            info_key = "generation.codes_strength_info"
        elif has_reference:
            label_key = "generation.similarity_denoise_label"
            info_key = "generation.similarity_denoise_info"

    label = t(label_key)
    info = t(info_key)
    return gr.update(visible=is_visible, label=label, info=info)
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
def convert_src_audio_to_codes_wrapper(dit_handler, src_audio):
    """Delegate to ``dit_handler.convert_src_audio_to_codes`` and return its result."""
    return dit_handler.convert_src_audio_to_codes(src_audio)
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def update_instruction_ui(
    dit_handler,
    task_type_value: str,
    track_name_value: Optional[str],
    complete_track_classes_value: list,
    audio_codes_content: str = "",
    init_llm_checked: bool = False,
    reference_audio=None,
) -> tuple:
    """Update the instruction text and task-dependent UI visibility.

    Args:
        dit_handler: Handler exposing ``generate_instruction``.
        task_type_value: Currently selected task type.
        track_name_value: Selected track name, forwarded to the handler.
        complete_track_classes_value: Selected track classes, forwarded to the handler.
        audio_codes_content: Current content of the audio-codes textbox.
        init_llm_checked: Whether the LM checkbox/state is set.
        reference_audio: Value of the reference audio input, if any.

    Returns:
        Tuple of six items, in output order:
            - instruction text (instruction_display_gen)
            - track_name visibility update
            - complete_track_classes visibility update
            - audio_cover_strength update (visible, label, info)
            - repainting_group visibility update
            - text2music_audio_codes_group visibility update
    """
    instruction = dit_handler.generate_instruction(
        task_type=task_type_value,
        track_name=track_name_value,
        complete_track_classes=complete_track_classes_value,
    )

    # track_name is only relevant for lego and extract
    track_name_visible = task_type_value in ("lego", "extract")
    # complete_track_classes is only relevant for complete
    complete_visible = task_type_value == "complete"
    # repainting controls are shown for repaint and lego
    repainting_visible = task_type_value in ("repaint", "lego")

    # Keep the codes group visible for text2music, or whenever it still holds
    # content after the user switched task type. Coerce to a real bool: the
    # bare `and` expression would otherwise hand the codes string itself to
    # gr.update(visible=...).
    has_audio_codes = bool(audio_codes_content and str(audio_codes_content).strip())
    text2music_audio_codes_visible = task_type_value == "text2music" or has_audio_codes

    return (
        instruction,  # instruction_display_gen
        gr.update(visible=track_name_visible),  # track_name
        gr.update(visible=complete_visible),  # complete_track_classes
        # Same visibility/label/info rules as the standalone handler, so the
        # two code paths cannot drift apart.
        update_audio_cover_strength_visibility(
            task_type_value, init_llm_checked, reference_audio
        ),  # audio_cover_strength
        gr.update(visible=repainting_visible),  # repainting_group
        gr.update(visible=text2music_audio_codes_visible),  # text2music_audio_codes_group
    )
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
def transcribe_audio_codes(llm_handler, audio_code_string, constrained_decoding_debug):
    """Gradio wrapper around ``understand_music`` that turns audio codes into metadata.

    The call always enables constrained decoding. According to the original
    note, an empty ``audio_code_string`` makes the API generate a sample
    example instead (behaviour lives in ``understand_music`` - not visible here).

    Args:
        llm_handler: LLM handler instance
        audio_code_string: String containing audio codes (or empty for example generation)
        constrained_decoding_debug: Whether to enable debug logging for constrained decoding

    Returns:
        Tuple of (status_message, caption, lyrics, bpm, duration, keyscale,
        language, timesignature, is_format_caption). On failure every field
        except the status is blank/None and is_format_caption is False.
    """
    outcome = understand_music(
        llm_handler=llm_handler,
        audio_codes=audio_code_string,
        use_constrained_decoding=True,
        constrained_decoding_debug=constrained_decoding_debug,
    )

    if not outcome.success:
        # The "LLM not initialized" error gets a localized message; any other
        # failure surfaces the status reported by the API.
        if outcome.error == "LLM not initialized":
            status = t("messages.lm_not_initialized")
        else:
            status = outcome.status_message
        return status, "", "", None, None, "", "", "", False

    # Keep the reported duration within the GPU memory limit
    duration = clamp_duration_to_gpu_limit(outcome.duration, llm_handler)

    return (
        outcome.status_message,
        outcome.caption,
        outcome.lyrics,
        outcome.bpm,
        duration,
        outcome.keyscale,
        outcome.language,
        outcome.timesignature,
        True,  # is_format_caption: the values come from LM understanding
    )
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
def update_transcribe_button_text(audio_code_string):
    """Pick the transcribe button caption from the audio-codes input.

    Returns a ``gr.update`` whose value is "Transcribe" when the input holds
    non-whitespace content and "Generate Example" otherwise.
    """
    has_codes = bool(audio_code_string) and bool(audio_code_string.strip())
    return gr.update(value="Transcribe" if has_codes else "Generate Example")
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
def reset_format_caption_flag():
    """Return False, the is_format_caption value to store after a manual caption/metadata edit."""
    return False
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
def update_audio_uploads_accordion(reference_audio, src_audio):
    """Open the Audio Uploads accordion when either audio input is not None; close it otherwise."""
    nothing_uploaded = reference_audio is None and src_audio is None
    return gr.Accordion(open=not nothing_uploaded)
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
def handle_instrumental_checkbox(instrumental_checked, current_lyrics):
    """Keep the lyrics box in sync with the instrumental checkbox.

    Checked: blank (empty/whitespace/None) lyrics become "[Instrumental]";
    existing lyrics are returned untouched.
    Unchecked: lyrics that are exactly "[Instrumental]" (ignoring surrounding
    whitespace) are cleared; anything else is returned untouched.
    """
    stripped = current_lyrics.strip() if current_lyrics else ""

    if instrumental_checked:
        return current_lyrics if stripped else "[Instrumental]"
    return "" if stripped == "[Instrumental]" else current_lyrics
|
| 702 |
+
|
| 703 |
+
|
| 704 |
+
def handle_simple_instrumental_change(is_instrumental: bool):
    """React to the simple-mode instrumental checkbox.

    Args:
        is_instrumental: Whether instrumental checkbox is checked

    Returns:
        gr.update for the simple_vocal_language dropdown: value forced to
        "unknown" and editing disabled when checked; editing re-enabled
        (value left alone) when unchecked.
    """
    if not is_instrumental:
        return gr.update(interactive=True)
    return gr.update(value="unknown", interactive=False)
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
def update_audio_components_visibility(batch_size):
    """Show or hide the per-sample audio columns for a batch size of 1-8.

    Returns nine updates in order: audio_col_1..audio_col_4, the
    audio_row_5_8 container, then audio_col_5..audio_col_8. Column ``k`` is
    visible when the clamped batch size is at least ``k``; the second-row
    container is visible from batch size 5 upwards.
    """
    # The UI only has eight slots, so clamp into 1..8
    count = max(1, min(int(batch_size), 8))

    # Row 1: columns 1-4 (column 1 is always visible because count >= 1)
    first_row = tuple(gr.update(visible=count >= slot) for slot in range(1, 5))

    # Row 2: container first, then columns 5-8
    second_row = (gr.update(visible=count >= 5),) + tuple(
        gr.update(visible=count >= slot) for slot in range(5, 9)
    )

    return first_row + second_row
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
def handle_generation_mode_change(mode: str):
    """Switch the generation UI between Simple and Custom modes.

    Simple mode (``mode == "simple"``):
        - simple mode group is shown
        - caption, lyrics and optional-parameters accordions are collapsed
        - generate button is disabled until a sample is created

    Any other mode (Custom):
        - simple mode group is hidden
        - caption, lyrics and optional-parameters accordions are expanded
        - generate button is enabled

    Args:
        mode: "simple" or "custom"

    Returns:
        Tuple of updates for:
            - simple_mode_group (visibility)
            - caption_accordion (open state)
            - lyrics_accordion (open state)
            - generate_btn (interactive state)
            - simple_sample_created (always reset to False)
            - optional_params_accordion (open state)
    """
    custom_mode = mode != "simple"

    return (
        gr.update(visible=not custom_mode),  # simple_mode_group
        gr.Accordion(open=custom_mode),  # caption_accordion
        gr.Accordion(open=custom_mode),  # lyrics_accordion
        gr.update(interactive=custom_mode),  # generate_btn
        False,  # simple_sample_created - reset on every mode change
        gr.Accordion(open=custom_mode),  # optional_params_accordion
    )
|
| 790 |
+
|
| 791 |
+
|
| 792 |
+
def _create_sample_failure_outputs(status_message):
    """Build the 16-item output tuple used when no sample could be created.

    Every field is left unchanged except the generate button (kept disabled),
    the simple_sample_created flag (False) and the status text. Both failure
    paths share this builder so their arity always matches the success path.
    """
    return (
        gr.update(),  # captions - no change
        gr.update(),  # lyrics - no change
        gr.update(),  # bpm - no change
        gr.update(),  # audio_duration - no change
        gr.update(),  # key_scale - no change
        gr.update(),  # vocal_language - no change
        gr.update(),  # simple vocal_language - no change
        gr.update(),  # time_signature - no change
        gr.update(),  # instrumental_checkbox - no change
        gr.update(),  # caption_accordion - no change
        gr.update(),  # lyrics_accordion - no change
        gr.update(interactive=False),  # generate_btn - keep disabled
        False,  # simple_sample_created - still False
        gr.update(),  # think_checkbox - no change
        gr.update(),  # is_format_caption_state - no change
        status_message,  # status_output
    )


def handle_create_sample(
    llm_handler,
    query: str,
    instrumental: bool,
    vocal_language: str,
    lm_temperature: float,
    lm_top_k: int,
    lm_top_p: float,
    constrained_decoding_debug: bool = False,
):
    """
    Handle the Create Sample button click in Simple mode.

    Creates a sample from the user's query using the LLM, then populates
    the caption, lyrics, and metadata fields.

    Note: cfg_scale and negative_prompt are not supported in create_sample mode.

    Args:
        llm_handler: LLM handler instance
        query: User's natural language music description
        instrumental: Whether to generate instrumental music
        vocal_language: Preferred vocal language for constrained decoding
        lm_temperature: LLM temperature for generation
        lm_top_k: LLM top-k sampling (falsy/0 disables it)
        lm_top_p: LLM top-p sampling (falsy or >= 1.0 disables it)
        constrained_decoding_debug: Whether to enable debug logging

    Returns:
        Tuple of 16 updates (same arity on every path) for:
            - captions
            - lyrics
            - bpm
            - audio_duration
            - key_scale
            - vocal_language
            - simple vocal_language
            - time_signature
            - instrumental_checkbox
            - caption_accordion (open)
            - lyrics_accordion (open)
            - generate_btn (interactive)
            - simple_sample_created (True)
            - think_checkbox (True)
            - is_format_caption_state (True)
            - status_output
    """
    # Without an initialized LLM nothing can be generated.
    if not llm_handler.llm_initialized:
        message = t("messages.lm_not_initialized")
        gr.Warning(message)
        # This path previously returned 15 values (no "simple vocal_language"
        # slot), one fewer than the other two paths.
        return _create_sample_failure_outputs(message)

    # Convert LM parameters: a falsy top-k and a top-p of 1.0 or more mean "disabled"
    top_k_value = int(lm_top_k) if lm_top_k else None
    top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p

    # Note: cfg_scale and negative_prompt are not supported in create_sample mode
    result = create_sample(
        llm_handler=llm_handler,
        query=query,
        instrumental=instrumental,
        vocal_language=vocal_language,
        temperature=lm_temperature,
        top_k=top_k_value,
        top_p=top_p_value,
        use_constrained_decoding=True,
        constrained_decoding_debug=constrained_decoding_debug,
    )

    if not result.success:
        message = result.status_message or t("messages.sample_creation_failed")
        gr.Warning(message)
        return _create_sample_failure_outputs(message)

    # Success - populate fields
    gr.Info(t("messages.sample_created"))

    # Clamp duration to GPU memory limit; -1 stands for "no usable duration"
    clamped_duration = clamp_duration_to_gpu_limit(result.duration, llm_handler)
    audio_duration_value = clamped_duration if clamped_duration and clamped_duration > 0 else -1

    return (
        result.caption,  # captions
        result.lyrics,  # lyrics
        result.bpm,  # bpm
        audio_duration_value,  # audio_duration
        result.keyscale,  # key_scale
        result.language,  # vocal_language
        result.language,  # simple vocal_language
        result.timesignature,  # time_signature
        result.instrumental,  # instrumental_checkbox
        gr.Accordion(open=True),  # caption_accordion - expand
        gr.Accordion(open=True),  # lyrics_accordion - expand
        gr.update(interactive=True),  # generate_btn - enable
        True,  # simple_sample_created - True
        True,  # think_checkbox - enable thinking
        True,  # is_format_caption_state - True (LM-generated)
        result.status_message,  # status_output
    )
|
| 924 |
+
|
| 925 |
+
|
| 926 |
+
def _positive_int_or_none(value):
    """Return ``value`` truncated to int when it is a positive finite number (or numeric string), else None."""
    if value is None:
        return None
    try:
        number = float(value)
    except (TypeError, ValueError):
        return None
    # NaN fails the comparison; infinity cannot be converted to int.
    if not 0 < number < float("inf"):
        return None
    return int(number)


def _format_sample_failure_outputs(status_message):
    """Build the 9-item output tuple that leaves every field unchanged and only sets the status text."""
    # captions, lyrics, bpm, audio_duration, key_scale, vocal_language,
    # time_signature, is_format_caption_state - all "no change"
    unchanged = tuple(gr.update() for _ in range(8))
    return unchanged + (status_message,)


def handle_format_sample(
    llm_handler,
    caption: str,
    lyrics: str,
    bpm,
    audio_duration,
    key_scale: str,
    time_signature: str,
    lm_temperature: float,
    lm_top_k: int,
    lm_top_p: float,
    constrained_decoding_debug: bool = False,
):
    """
    Handle the Format button click to format caption and lyrics.

    Takes user-provided caption and lyrics, and uses the LLM to generate
    structured music metadata and an enhanced description.

    Note: cfg_scale and negative_prompt are not supported in format mode.

    Args:
        llm_handler: LLM handler instance
        caption: User's caption/description
        lyrics: User's lyrics
        bpm: User-provided BPM (optional, for constrained decoding)
        audio_duration: User-provided duration (optional, for constrained decoding)
        key_scale: User-provided key scale (optional, for constrained decoding)
        time_signature: User-provided time signature (optional, for constrained decoding)
        lm_temperature: LLM temperature for generation
        lm_top_k: LLM top-k sampling (falsy/0 disables it)
        lm_top_p: LLM top-p sampling (falsy or >= 1.0 disables it)
        constrained_decoding_debug: Whether to enable debug logging

    Returns:
        Tuple of updates for:
            - captions
            - lyrics
            - bpm
            - audio_duration
            - key_scale
            - vocal_language
            - time_signature
            - is_format_caption_state
            - status_output
    """
    # Without an initialized LLM nothing can be formatted.
    if not llm_handler.llm_initialized:
        message = t("messages.lm_not_initialized")
        gr.Warning(message)
        return _format_sample_failure_outputs(message)

    # Build user_metadata from provided values for constrained decoding.
    # Numeric hints go through one float-based conversion, so a value such as
    # "30.5" passes the positivity check and the int conversion alike
    # (int("30.5") alone would raise ValueError).
    user_metadata = {}
    bpm_hint = _positive_int_or_none(bpm)
    if bpm_hint is not None:
        user_metadata['bpm'] = bpm_hint
    duration_hint = _positive_int_or_none(audio_duration)
    if duration_hint is not None:
        user_metadata['duration'] = duration_hint
    if key_scale and key_scale.strip():
        user_metadata['keyscale'] = key_scale.strip()
    if time_signature and time_signature.strip():
        user_metadata['timesignature'] = time_signature.strip()

    # Only pass user_metadata if we have at least one field
    user_metadata_to_pass = user_metadata if user_metadata else None

    # Convert LM parameters: a falsy top-k and a top-p of 1.0 or more mean "disabled"
    top_k_value = int(lm_top_k) if lm_top_k else None
    top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p

    result = format_sample(
        llm_handler=llm_handler,
        caption=caption,
        lyrics=lyrics,
        user_metadata=user_metadata_to_pass,
        temperature=lm_temperature,
        top_k=top_k_value,
        top_p=top_p_value,
        use_constrained_decoding=True,
        constrained_decoding_debug=constrained_decoding_debug,
    )

    if not result.success:
        message = result.status_message or t("messages.format_failed")
        gr.Warning(message)
        return _format_sample_failure_outputs(message)

    # Success - populate fields
    gr.Info(t("messages.format_success"))

    # Clamp duration to GPU memory limit; -1 stands for "no usable duration"
    clamped_duration = clamp_duration_to_gpu_limit(result.duration, llm_handler)
    audio_duration_value = clamped_duration if clamped_duration and clamped_duration > 0 else -1

    return (
        result.caption,  # captions
        result.lyrics,  # lyrics
        result.bpm,  # bpm
        audio_duration_value,  # audio_duration
        result.keyscale,  # key_scale
        result.language,  # vocal_language
        result.timesignature,  # time_signature
        True,  # is_format_caption_state - True (LM-formatted)
        result.status_message,  # status_output
    )
|
acestep/gradio_ui/events/results_handlers.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
acestep/gradio_ui/events/training_handlers.py
ADDED
|
@@ -0,0 +1,829 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Event Handlers for Training Tab
|
| 3 |
+
|
| 4 |
+
Contains all event handler functions for the dataset builder and training UI.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
from typing import Any, Dict, List, Tuple, Optional
|
| 10 |
+
from loguru import logger
|
| 11 |
+
import gradio as gr
|
| 12 |
+
|
| 13 |
+
from acestep.training.dataset_builder import DatasetBuilder, AudioSample
|
| 14 |
+
from acestep.debug_utils import debug_log_for, debug_start_for, debug_end_for
|
| 15 |
+
from acestep.gpu_config import get_global_gpu_config
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def create_dataset_builder() -> DatasetBuilder:
    """Construct and return a new DatasetBuilder with default arguments."""
    fresh_builder = DatasetBuilder()
    return fresh_builder
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _safe_slider(max_value: int, value: int = 0, visible: Optional[bool] = None) -> gr.Slider:
    """Return a gr.Slider whose maximum is at least 1, avoiding a zero-width range.

    ``value`` is capped at the (adjusted) maximum. ``visible`` is only passed
    to the slider when the caller supplies it.
    """
    upper_bound = max(1, int(max_value))
    capped_value = min(int(value), upper_bound)

    if visible is None:
        return gr.Slider(maximum=upper_bound, value=capped_value)
    return gr.Slider(maximum=upper_bound, value=capped_value, visible=visible)
|
| 30 |
+
|
| 31 |
+
def scan_directory(
    audio_dir: str,
    dataset_name: str,
    custom_tag: str,
    tag_position: str,
    all_instrumental: bool,
    builder_state: Optional[DatasetBuilder],
) -> Tuple[Any, str, Any, DatasetBuilder]:
    """Scan a directory for audio files and refresh the dataset table.

    Args:
        audio_dir: Directory path entered by the user.
        dataset_name: Name stored in the builder metadata.
        custom_tag: Optional tag applied to all samples when non-empty.
        tag_position: Position argument forwarded with the custom tag.
        all_instrumental: Instrumental flag applied to all samples.
        builder_state: Existing builder to reuse, or None/falsy to create one.

    Returns:
        Tuple of (table_data, status, slider_update, builder_state)
    """
    if not audio_dir or not audio_dir.strip():
        # Status prefix restored to the error emoji; the text previously
        # carried a U+FFFD replacement character.
        return [], "❌ Please enter a directory path", _safe_slider(0, value=0, visible=False), builder_state

    # Create or use existing builder
    builder = builder_state if builder_state else DatasetBuilder()

    # Set metadata before scanning
    builder.metadata.name = dataset_name
    builder.metadata.custom_tag = custom_tag
    builder.metadata.tag_position = tag_position
    builder.metadata.all_instrumental = all_instrumental

    # Scan directory
    samples, status = builder.scan_directory(audio_dir.strip())

    if not samples:
        return [], status, _safe_slider(0, value=0, visible=False), builder

    # Set instrumental and tag for all samples
    builder.set_all_instrumental(all_instrumental)
    if custom_tag:
        builder.set_custom_tag(custom_tag, tag_position)

    # Get table data
    table_data = builder.get_samples_dataframe_data()

    # Highest valid sample index; the slider is hidden when there is only one sample
    slider_max = max(0, len(samples) - 1)

    return table_data, status, _safe_slider(slider_max, value=0, visible=len(samples) > 1), builder
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def auto_label_all(
    dit_handler,
    llm_handler,
    builder_state: Optional[DatasetBuilder],
    skip_metas: bool = False,
    format_lyrics: bool = False,
    transcribe_lyrics: bool = False,
    only_unlabeled: bool = False,
    progress=None,
) -> Tuple[List[List[Any]], str, DatasetBuilder]:
    """Auto-label all samples in the dataset.

    Args:
        dit_handler: DiT handler for audio processing
        llm_handler: LLM handler for caption generation
        builder_state: Dataset builder state
        skip_metas: If True, skip generating BPM/Key/TimeSig but still generate caption/genre
        format_lyrics: If True, use LLM to format user-provided lyrics from .txt files
        transcribe_lyrics: If True, use LLM to transcribe lyrics from audio (ignores .txt files)
        only_unlabeled: If True, only label samples without caption
        progress: Progress callback

    Returns:
        Tuple of (table_data, status, builder_state). On the precondition
        failures the first two items are plain values; after labeling they
        are wrapped in ``gr.update`` to force a UI refresh.
    """
    # Status prefixes restored to the error emoji; the strings previously
    # carried a U+FFFD replacement character.
    if builder_state is None:
        return [], "❌ Please scan a directory first", builder_state

    if not builder_state.samples:
        return [], "❌ No samples to label. Please scan a directory first.", builder_state

    # Check if handlers are initialized
    if dit_handler is None or dit_handler.model is None:
        return builder_state.get_samples_dataframe_data(), "❌ Model not initialized. Please initialize the service first.", builder_state

    if llm_handler is None or not llm_handler.llm_initialized:
        return builder_state.get_samples_dataframe_data(), "❌ LLM not initialized. Please initialize the service with LLM enabled.", builder_state

    def progress_callback(msg):
        # Progress reporting is best-effort: a failing callback must not abort
        # labeling. Catch Exception only, so KeyboardInterrupt/SystemExit
        # still propagate, and leave a trace in the log.
        if progress:
            try:
                progress(msg)
            except Exception as exc:
                logger.debug(f"Progress callback failed: {exc}")

    # Label all samples (skip_metas only skips BPM/Key/TimeSig, still generates caption/genre)
    samples, status = builder_state.label_all_samples(
        dit_handler=dit_handler,
        llm_handler=llm_handler,
        format_lyrics=format_lyrics,
        transcribe_lyrics=transcribe_lyrics,
        skip_metas=skip_metas,
        only_unlabeled=only_unlabeled,
        progress_callback=progress_callback,
    )

    # Get updated table data
    table_data = builder_state.get_samples_dataframe_data()

    # Force UI refresh for table and status
    return gr.update(value=table_data), gr.update(value=status), builder_state
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def get_sample_preview(
    sample_idx: int,
    builder_state: Optional[DatasetBuilder],
):
    """Build the preview field values for one sample of the dataset.

    Returns:
        Tuple of (audio_path, filename, caption, genre, prompt_override, lyrics, bpm, keyscale, timesig,
        duration, language, instrumental, raw_lyrics, raw_lyrics_visible).
        A fixed placeholder tuple is returned when there is no builder, no
        samples, no index, or the index is out of range.
    """
    blank_preview = (None, "", "", "", "Use Global Ratio", "", None, "", "", 0.0, "instrumental", True, "", False)

    if builder_state is None or not builder_state.samples or sample_idx is None:
        return blank_preview

    position = int(sample_idx)
    if not 0 <= position < len(builder_state.samples):
        return blank_preview

    entry = builder_state.samples[position]

    # The raw-lyrics panel is only shown when the sample reports raw lyrics.
    raw_available = entry.has_raw_lyrics()

    # Map the stored prompt_override value onto the dropdown label.
    stored_override = entry.prompt_override
    if stored_override == "genre":
        dropdown_label = "Genre"
    elif stored_override == "caption":
        dropdown_label = "Caption"
    else:
        dropdown_label = "Use Global Ratio"

    # Fall back to the formatted lyrics when the plain lyrics are empty.
    shown_lyrics = entry.lyrics or entry.formatted_lyrics
    shown_raw = entry.raw_lyrics if raw_available else ""

    return (
        entry.audio_path,
        entry.filename,
        entry.caption,
        entry.genre,
        dropdown_label,
        shown_lyrics,
        entry.bpm,
        entry.keyscale,
        entry.timesignature,
        entry.duration,
        entry.language,
        entry.is_instrumental,
        shown_raw,
        raw_available,
    )
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def save_sample_edit(
    sample_idx: int,
    caption: str,
    genre: str,
    prompt_override: str,
    lyrics: str,
    bpm: Optional[int],
    keyscale: str,
    timesig: str,
    language: str,
    is_instrumental: bool,
    builder_state: Optional[DatasetBuilder],
) -> Tuple[List[List[Any]], str, DatasetBuilder]:
    """Save edits to a sample.

    Args:
        sample_idx: Index of the sample to update (converted with ``int``).
        caption: New caption text.
        genre: New genre text.
        prompt_override: Dropdown label: "Genre", "Caption", or anything else
            for "Use Global Ratio" (stored as ``None``).
        lyrics: New lyrics; replaced by "[Instrumental]" when ``is_instrumental``.
        bpm: New BPM; falsy values are stored as ``None``.
        keyscale: New key/scale text.
        timesig: New time signature text.
        language: New language; forced to "unknown" when ``is_instrumental``.
        is_instrumental: Whether the sample is instrumental.
        builder_state: Dataset builder state.

    Returns:
        Tuple of (table_data, status, builder_state)
    """
    if builder_state is None:
        return [], "� No dataset loaded", builder_state

    # The slider can report no value; mirror get_sample_preview's None guard
    # instead of letting int(None) raise.
    if sample_idx is None:
        return builder_state.get_samples_dataframe_data(), "❌ No sample selected", builder_state

    idx = int(sample_idx)

    # Convert dropdown choice to prompt_override value
    if prompt_override == "Genre":
        override_value = "genre"
    elif prompt_override == "Caption":
        override_value = "caption"
    else:
        override_value = None  # Use Global Ratio

    # Instrumental samples get the placeholder lyric and no formatted lyrics.
    updated_lyrics = lyrics if not is_instrumental else "[Instrumental]"
    updated_formatted = updated_lyrics if updated_lyrics and updated_lyrics != "[Instrumental]" else ""
    _, status = builder_state.update_sample(
        idx,
        caption=caption,
        genre=genre,
        prompt_override=override_value,
        lyrics=updated_lyrics,
        formatted_lyrics=updated_formatted,
        bpm=int(bpm) if bpm else None,
        keyscale=keyscale,
        timesignature=timesig,
        language="unknown" if is_instrumental else language,
        is_instrumental=is_instrumental,
        labeled=True,
    )

    # Get updated table data
    table_data = builder_state.get_samples_dataframe_data()

    return table_data, status, builder_state
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def update_settings(
    custom_tag: str,
    tag_position: str,
    all_instrumental: bool,
    genre_ratio: int,
    builder_state: Optional[DatasetBuilder],
) -> DatasetBuilder:
    """Apply the dataset-level settings to the builder.

    A non-empty ``custom_tag`` is applied together with ``tag_position``; an
    empty one leaves the builder's current tag untouched. The instrumental
    flag and the genre ratio (coerced to ``int``) are always written.

    Returns:
        The same builder_state that was passed in (``None`` stays ``None``)
    """
    if builder_state is not None:
        if custom_tag:
            builder_state.set_custom_tag(custom_tag, tag_position)

        builder_state.set_all_instrumental(all_instrumental)
        builder_state.metadata.genre_ratio = int(genre_ratio)

    return builder_state
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def save_dataset(
    save_path: str,
    dataset_name: str,
    builder_state: Optional[DatasetBuilder],
) -> Tuple[str, Any]:
    """Save the dataset to a JSON file.

    The path is stripped and gets a ``.json`` suffix when it lacks one. When
    no sample has been labeled the dataset is still written, and the status
    is prefixed with a warning.

    Args:
        save_path: Destination path entered by the user.
        dataset_name: Name forwarded to ``builder_state.save_dataset``.
        builder_state: Dataset builder state.

    Returns:
        Tuple of (status message, update for the save-path field). The update
        carries the normalized path once saving was attempted and is a no-op
        ``gr.update()`` on validation failures.
    """
    if builder_state is None:
        return "� No dataset to save. Please scan a directory first.", gr.update()

    if not builder_state.samples:
        return "� No samples in dataset.", gr.update()

    if not save_path or not save_path.strip():
        return "� Please enter a save path.", gr.update()

    save_path = save_path.strip()
    if not save_path.lower().endswith(".json"):
        save_path = save_path + ".json"

    status = builder_state.save_dataset(save_path, dataset_name)

    # The warning promises "Saving anyway...", so the save above must happen
    # even when nothing is labeled; the warning is only prepended.
    if builder_state.get_labeled_count() == 0:
        warning = "�️ Warning: No samples have been labeled. Consider auto-labeling first.\nSaving anyway..."
        status = f"{warning}\n{status}"

    return status, gr.update(value=save_path)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def load_existing_dataset_for_preprocess(
    dataset_path: str,
    builder_state: Optional[DatasetBuilder],
):
    """Load an existing dataset JSON file for preprocessing.

    This allows users to load a previously saved dataset and proceed to preprocessing
    without having to re-scan and re-label.

    Args:
        dataset_path: Path of the dataset JSON file.
        builder_state: Current builder state; handed back unchanged when the
            path is empty or does not exist.

    Returns:
        Tuple of (status, table_data, slider_update, builder_state,
        audio_path, filename, caption, genre, prompt_override,
        lyrics, bpm, keyscale, timesig, duration, language, instrumental,
        raw_lyrics, has_raw) followed by five settings updates
        (name, custom_tag, tag_position, all_instrumental, genre_ratio).
    """
    # Empty preview: (audio_path, filename, caption, genre, prompt_override, lyrics, bpm, keyscale, timesig, duration, language, instrumental, raw_lyrics, has_raw)
    empty_preview = (None, "", "", "", "Use Global Ratio", "", None, "", "", 0.0, "instrumental", True, "", False)

    def _failure(status, state):
        # Failure result: empty table, hidden slider, blank preview and five
        # no-op updates so the settings widgets keep their current values.
        unchanged = tuple(gr.update() for _ in range(5))
        return (status, [], _safe_slider(0, value=0, visible=False), state) + empty_preview + unchanged

    if not dataset_path or not dataset_path.strip():
        return _failure("� Please enter a dataset path", builder_state)

    dataset_path = dataset_path.strip()
    debug_log_for("dataset", f"UI load_existing_dataset_for_preprocess: path='{dataset_path}'")

    if not os.path.exists(dataset_path):
        return _failure(f"� Dataset not found: {dataset_path}", builder_state)

    # Create new builder (don't reuse old state when loading a file)
    builder = DatasetBuilder()

    # Load the dataset
    t0 = debug_start_for("dataset", "load_dataset")
    samples, status = builder.load_dataset(dataset_path)
    debug_end_for("dataset", "load_dataset", t0)

    if not samples:
        return _failure(status, builder)

    # Get table data
    table_data = builder.get_samples_dataframe_data()

    # Calculate slider max
    slider_max = max(0, len(samples) - 1)

    # Create info text
    labeled_count = builder.get_labeled_count()
    info = f"� Loaded dataset: {builder.metadata.name}\n"
    info += f"� Samples: {len(samples)} ({labeled_count} labeled)\n"
    info += f"���️ Custom Tag: {builder.metadata.custom_tag or '(none)'}\n"
    info += "� Ready for preprocessing! You can also edit samples below."
    if any((s.formatted_lyrics and not s.lyrics) for s in builder.samples):
        info += "\n�️ Showing formatted lyrics where lyrics are empty."

    # First-sample preview: reuse the shared helper so this view and the
    # slider-driven preview cannot drift apart.
    preview = get_sample_preview(0, builder)

    updates = (
        gr.update(value=builder.metadata.name),
        gr.update(value=builder.metadata.custom_tag),
        gr.update(value=builder.metadata.tag_position),
        gr.update(value=builder.metadata.all_instrumental),
        gr.update(value=builder.metadata.genre_ratio),
    )

    return (info, table_data, _safe_slider(slider_max, value=0, visible=len(samples) > 1), builder) + preview + updates
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def preprocess_dataset(
    output_dir: str,
    dit_handler,
    builder_state: Optional[DatasetBuilder],
    progress=None,
) -> str:
    """Preprocess dataset to tensor files for fast training.

    This converts audio files to VAE latents and text to embeddings.

    Args:
        output_dir: Directory the tensors are written to (stripped before use).
        dit_handler: DiT handler; must expose a non-None ``model``.
        builder_state: Dataset builder state holding the labeled samples.
        progress: Optional progress callback, called with each progress message.

    Returns:
        Status message
    """
    if builder_state is None:
        return "� No dataset loaded. Please scan a directory first."

    if not builder_state.samples:
        return "� No samples in dataset."

    labeled_count = builder_state.get_labeled_count()
    if labeled_count == 0:
        return "� No labeled samples. Please auto-label or manually label samples first."

    if not output_dir or not output_dir.strip():
        return "� Please enter an output directory."

    if dit_handler is None or dit_handler.model is None:
        return "� Model not initialized. Please initialize the service first."

    def progress_callback(msg):
        # Progress reporting is best-effort: a failing UI callback must not
        # abort preprocessing. Only ``Exception`` is caught so
        # KeyboardInterrupt / SystemExit still propagate.
        if not progress:
            return
        try:
            progress(msg)
        except Exception as exc:
            logger.debug(f"Progress callback failed: {exc}")

    # Run preprocessing
    t0 = debug_start_for("dataset", "preprocess_to_tensors")
    _, status = builder_state.preprocess_to_tensors(
        dit_handler=dit_handler,
        output_dir=output_dir.strip(),
        progress_callback=progress_callback,
    )
    debug_end_for("dataset", "preprocess_to_tensors", t0)

    return status
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def load_training_dataset(
    tensor_dir: str,
) -> str:
    """Load a preprocessed tensor dataset for training.

    Only describes the directory: it reads ``manifest.json`` when present and
    otherwise counts the ``.pt`` files. No tensors are loaded here.

    Args:
        tensor_dir: Directory containing the preprocessed tensors.

    Returns:
        Info text about the dataset
    """
    if not tensor_dir or not tensor_dir.strip():
        return "� Please enter a tensor directory path"

    tensor_dir = tensor_dir.strip()

    if not os.path.exists(tensor_dir):
        return f"� Directory not found: {tensor_dir}"

    if not os.path.isdir(tensor_dir):
        return f"� Not a directory: {tensor_dir}"

    # Check for manifest
    manifest_path = os.path.join(tensor_dir, "manifest.json")
    if os.path.exists(manifest_path):
        try:
            # JSON is UTF-8 by specification; do not depend on the platform
            # default encoding (dataset names / tags may be non-ASCII).
            with open(manifest_path, "r", encoding="utf-8") as f:
                manifest = json.load(f)

            num_samples = manifest.get("num_samples", 0)
            metadata = manifest.get("metadata", {})
            name = metadata.get("name", "Unknown")
            custom_tag = metadata.get("custom_tag", "")

            info = f"� Loaded preprocessed dataset: {name}\n"
            info += f"� Samples: {num_samples} preprocessed tensors\n"
            info += f"���️ Custom Tag: {custom_tag or '(none)'}"

            return info
        except Exception as e:
            # An unreadable manifest is not fatal: fall through to the
            # file-count fallback below.
            logger.warning(f"Failed to read manifest: {e}")

    # Fallback: count .pt files
    pt_files = [f for f in os.listdir(tensor_dir) if f.endswith('.pt')]

    if not pt_files:
        return f"� No .pt tensor files found in {tensor_dir}"

    info = f"� Found {len(pt_files)} tensor files in {tensor_dir}\n"
    info += "�️ No manifest.json found - using all .pt files"

    return info
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
# Training handlers
|
| 502 |
+
|
| 503 |
+
import time
|
| 504 |
+
import re
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
def _format_duration(seconds):
|
| 508 |
+
"""Format seconds to human readable string."""
|
| 509 |
+
seconds = int(seconds)
|
| 510 |
+
if seconds < 60:
|
| 511 |
+
return f"{seconds}s"
|
| 512 |
+
elif seconds < 3600:
|
| 513 |
+
return f"{seconds // 60}m {seconds % 60}s"
|
| 514 |
+
else:
|
| 515 |
+
return f"{seconds // 3600}h {(seconds % 3600) // 60}m"
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
def _training_loader_settings(device_type: str) -> Dict[str, Any]:
    """Return the dataloader / mixed-precision keyword defaults for a device type."""
    if device_type in ("cuda", "xpu"):
        return {
            "num_workers": 4,
            "pin_memory": True,
            "prefetch_factor": 2,
            "persistent_workers": True,
            # Only CUDA gets an explicit pin-memory device.
            "pin_memory_device": "cuda" if device_type == "cuda" else None,
            "mixed_precision": "bf16",
        }
    if device_type == "mps":
        return {
            "num_workers": 0,
            "pin_memory": False,
            "prefetch_factor": 2,
            "persistent_workers": False,
            "pin_memory_device": None,
            "mixed_precision": "fp16",
        }
    # Any other device (e.g. CPU): scale workers with the core count, capped at 4.
    cpu_count = os.cpu_count() or 2
    num_workers = min(4, max(1, cpu_count // 2))
    return {
        "num_workers": num_workers,
        "pin_memory": False,
        "prefetch_factor": 2,
        "persistent_workers": num_workers > 0,
        "pin_memory_device": None,
        "mixed_precision": "fp32",
    }


def start_training(
    tensor_dir: str,
    dit_handler,
    lora_rank: int,
    lora_alpha: int,
    lora_dropout: float,
    learning_rate: float,
    train_epochs: int,
    train_batch_size: int,
    gradient_accumulation: int,
    save_every_n_epochs: int,
    training_shift: float,
    training_seed: int,
    lora_output_dir: str,
    resume_checkpoint_dir: str,
    training_state: Dict,
    progress=None,
):
    """Start LoRA training from preprocessed tensors.

    This is a generator function that yields progress updates.

    Args:
        tensor_dir: Directory with the preprocessed tensors.
        dit_handler: DiT handler; must expose a non-None ``model``. If its
            ``quantization`` attribute is set, the handler is first switched
            to its training preset.
        lora_rank, lora_alpha, lora_dropout: LoRA hyper-parameters.
        learning_rate, train_epochs, train_batch_size, gradient_accumulation,
        save_every_n_epochs, training_shift, training_seed: Training settings
            forwarded to ``TrainingConfig``.
        lora_output_dir: Output directory forwarded to ``TrainingConfig``.
        resume_checkpoint_dir: Optional checkpoint directory to resume from;
            blank means start fresh.
        training_state: Mutable dict shared with the UI. ``is_training`` and
            ``should_stop`` are written here; ``should_stop`` is polled after
            every trainer update.
        progress: Accepted for interface compatibility; not used by this function.

    Yields:
        Tuples of (status, log_text, loss_data, training_state), where
        ``loss_data`` is ``None`` before training starts and a pandas
        DataFrame with ``step`` / ``loss`` columns afterwards.
    """
    if not tensor_dir or not tensor_dir.strip():
        yield "� Please enter a tensor directory path", "", None, training_state
        return

    tensor_dir = tensor_dir.strip()

    if not os.path.exists(tensor_dir):
        yield f"� Tensor directory not found: {tensor_dir}", "", None, training_state
        return

    if dit_handler is None or dit_handler.model is None:
        yield "� Model not initialized. Please initialize the service first.", "", None, training_state
        return

    # Training preset: LoRA training must run on non-quantized DiT.
    if getattr(dit_handler, "quantization", None) is not None:
        gpu_config = get_global_gpu_config()
        if gpu_config.gpu_memory_gb <= 0:
            yield (
                "WARNING: CPU-only training detected. Using best-effort training path "
                "(non-quantized DiT). Performance will be sub-optimal.",
                "",
                None,
                training_state,
            )
        elif gpu_config.tier in {"tier1", "tier2", "tier3", "tier4"}:
            yield (
                f"WARNING: Low VRAM tier detected ({gpu_config.gpu_memory_gb:.1f} GB, {gpu_config.tier}). "
                "Using best-effort training path (non-quantized DiT). Performance may be sub-optimal.",
                "",
                None,
                training_state,
            )

        yield "Switching model to training preset (disable quantization)...", "", None, training_state
        if hasattr(dit_handler, "switch_to_training_preset"):
            switch_status, switched = dit_handler.switch_to_training_preset()
            if not switched:
                yield f"� {switch_status}", "", None, training_state
                return
            yield f"� {switch_status}", "", None, training_state
        else:
            yield "� Training requires non-quantized DiT, and auto-switch is unavailable in this build.", "", None, training_state
            return

    # Check for required training dependencies (imported only to verify availability)
    try:
        from lightning.fabric import Fabric
        from peft import get_peft_model, LoraConfig
    except ImportError as e:
        yield f"� Missing required packages: {e}\nPlease install: pip install peft lightning", "", None, training_state
        return

    training_state["is_training"] = True
    training_state["should_stop"] = False

    try:
        from acestep.training.trainer import LoRATrainer
        from acestep.training.configs import LoRAConfig as LoRAConfigClass, TrainingConfig

        # Create configs
        lora_config = LoRAConfigClass(
            r=lora_rank,
            alpha=lora_alpha,
            dropout=lora_dropout,
        )

        # The handler's device may be a device object (with .type) or a string like "cuda:0".
        device_attr = getattr(dit_handler, "device", "")
        if hasattr(device_attr, "type"):
            device_type = str(device_attr.type).lower()
        else:
            device_type = str(device_attr).split(":", 1)[0].lower()

        # Use device-tuned dataloader defaults while preserving CUDA acceleration.
        loader_settings = _training_loader_settings(device_type)

        logger.info(
            f"Training loader config: device={device_type}, workers={loader_settings['num_workers']}, "
            f"pin_memory={loader_settings['pin_memory']}, pin_memory_device={loader_settings['pin_memory_device']}, "
            f"persistent_workers={loader_settings['persistent_workers']}"
        )
        training_config = TrainingConfig(
            shift=training_shift,
            learning_rate=learning_rate,
            batch_size=train_batch_size,
            gradient_accumulation_steps=gradient_accumulation,
            max_epochs=train_epochs,
            save_every_n_epochs=save_every_n_epochs,
            seed=training_seed,
            output_dir=lora_output_dir,
            **loader_settings,
        )

        import pandas as pd

        # Initialize training log and loss history
        log_lines = []
        loss_data = pd.DataFrame({"step": [0], "loss": [0.0]})

        # Start timer
        start_time = time.time()

        yield f"� Starting training from {tensor_dir}...", "", loss_data, training_state

        # Create trainer
        trainer = LoRATrainer(
            dit_handler=dit_handler,
            lora_config=lora_config,
            training_config=training_config,
        )

        # Collect loss history
        step_list = []
        loss_list = []
        training_failed = False
        failure_message = ""
        stopped_by_user = False

        # Train with progress updates using preprocessed tensors
        resume_from = resume_checkpoint_dir.strip() if resume_checkpoint_dir and resume_checkpoint_dir.strip() else None
        for step, loss, status in trainer.train_from_preprocessed(tensor_dir, training_state, resume_from=resume_from):
            status_text = str(status)
            status_lower = status_text.lower()
            # Failures are reported through the status text; the first prefix
            # is the mis-encoded form of the second and is matched as well.
            if (
                status_text.startswith(("âŒ", "❌"))
                or "training failed" in status_lower
                or "error:" in status_lower
                or "module not found" in status_lower
            ):
                training_failed = True
                failure_message = status_text

            # Calculate elapsed time and ETA
            elapsed_seconds = time.time() - start_time
            time_info = f"⏱️ Elapsed: {_format_duration(elapsed_seconds)}"

            # Parse "Epoch x/y" from status to calculate ETA
            match = re.search(r"Epoch\s+(\d+)/(\d+)", status_text)
            if match:
                current_ep = int(match.group(1))
                total_ep = int(match.group(2))
                if current_ep > 0:
                    eta_seconds = (elapsed_seconds / current_ep) * (total_ep - current_ep)
                    time_info += f" | ETA: ~{_format_duration(eta_seconds)}"

            # Display status with time info
            display_status = f"{status}\n{time_info}"

            # Terminal log
            log_msg = f"[{_format_duration(elapsed_seconds)}] Step {step}: {status}"
            logger.info(log_msg)

            # Add to UI log (keep only the 15 most recent lines)
            log_lines.append(status)
            if len(log_lines) > 15:
                log_lines = log_lines[-15:]
            log_text = "\n".join(log_lines)

            # Track loss for plot (only valid values)
            if step > 0 and loss is not None and loss == loss:  # loss == loss is False for NaN
                step_list.append(step)
                loss_list.append(float(loss))
                loss_data = pd.DataFrame({"step": step_list, "loss": loss_list})

            yield display_status, log_text, loss_data, training_state

            if training_state.get("should_stop", False):
                logger.info("⏹️ Training stopped by user")
                log_lines.append("⏹️ Training stopped by user")
                yield f"⏹️ Stopped ({time_info})", "\n".join(log_lines[-15:]), loss_data, training_state
                stopped_by_user = True
                break

        total_time = time.time() - start_time
        training_state["is_training"] = False
        if training_failed:
            final_msg = f"{failure_message}\nElapsed: {_format_duration(total_time)}"
            logger.warning(final_msg)
            log_lines.append(failure_message)
            yield final_msg, "\n".join(log_lines[-15:]), loss_data, training_state
            return

        if stopped_by_user:
            # The "Stopped" status was already yielded; do not overwrite it
            # with a misleading completion message.
            return

        completion_msg = f"� Training completed! Total time: {_format_duration(total_time)}"

        logger.info(completion_msg)
        log_lines.append(completion_msg)

        yield completion_msg, "\n".join(log_lines[-15:]), loss_data, training_state

    except Exception as e:
        logger.exception("Training error")
        training_state["is_training"] = False
        import pandas as pd
        empty_df = pd.DataFrame({"step": [], "loss": []})
        yield f"� Error: {str(e)}", str(e), empty_df, training_state
    finally:
        # Also covers the generator being closed mid-run (GeneratorExit is
        # not an Exception), which would otherwise leave the flag stuck.
        training_state["is_training"] = False
|
| 765 |
+
|
| 766 |
+
|
| 767 |
+
def stop_training(training_state: Dict) -> Tuple[str, Dict]:
    """Request that the running training loop stop.

    Sets ``should_stop`` on the shared state when ``is_training`` is truthy;
    otherwise the state is left untouched.

    Returns:
        Tuple of (status, training_state)
    """
    if training_state.get("is_training", False):
        training_state["should_stop"] = True
        return "⏹️ Stopping training...", training_state

    return "�️ No training in progress", training_state
|
| 778 |
+
|
| 779 |
+
|
| 780 |
+
def export_lora(
    export_path: str,
    lora_output_dir: str,
) -> str:
    """Export the trained LoRA weights.

    The source is ``<lora_output_dir>/final`` when it exists, otherwise the
    highest-numbered ``epoch_<n>`` directory under
    ``<lora_output_dir>/checkpoints``. An existing ``export_path`` is replaced.

    Args:
        export_path: Destination directory (stripped before use).
        lora_output_dir: Training output directory to export from.

    Returns:
        Status message
    """
    if not export_path or not export_path.strip():
        return "� Please enter an export path"

    # Check if there's a trained model to export
    final_dir = os.path.join(lora_output_dir, "final")
    checkpoint_dir = os.path.join(lora_output_dir, "checkpoints")

    # Prefer final, fallback to checkpoints
    if os.path.exists(final_dir):
        source_path = final_dir
    elif os.path.exists(checkpoint_dir):
        # Find the latest checkpoint; entries whose suffix is not an integer
        # (e.g. "epoch_best") are skipped instead of raising.
        numbered = []
        for entry in os.listdir(checkpoint_dir):
            if not entry.startswith("epoch_"):
                continue
            try:
                numbered.append((int(entry.split("_")[1]), entry))
            except ValueError:
                continue
        if not numbered:
            return "� No checkpoints found"

        source_path = os.path.join(checkpoint_dir, max(numbered)[1])
    else:
        return f"� No trained model found in {lora_output_dir}"

    try:
        import shutil

        export_path = export_path.strip()

        # The destination is wiped before copying, so it must not be the
        # source itself, a parent of it, or a directory inside it.
        source_real = os.path.realpath(source_path)
        export_real = os.path.realpath(export_path)
        try:
            shared = os.path.commonpath([source_real, export_real])
        except ValueError:
            shared = None  # e.g. different drives: the paths cannot overlap
        if shared in (source_real, export_real):
            return f"❌ Export path overlaps the trained model directory ({source_path}); choose a different location"

        os.makedirs(os.path.dirname(export_path) if os.path.dirname(export_path) else ".", exist_ok=True)

        if os.path.exists(export_path):
            shutil.rmtree(export_path)

        shutil.copytree(source_path, export_path)

        return f"� LoRA exported to {export_path}"

    except Exception as e:
        logger.exception("Export error")
        return f"� Export failed: {str(e)}"
|
| 827 |
+
|
| 828 |
+
|
| 829 |
+
|
acestep/gradio_ui/i18n.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Internationalization (i18n) module for Gradio UI
|
| 3 |
+
Supports multiple languages with easy translation management
|
| 4 |
+
"""
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
from typing import Dict, Optional
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class I18n:
|
| 11 |
+
"""Internationalization handler"""
|
| 12 |
+
|
| 13 |
+
def __init__(self, default_language: str = "en"):
|
| 14 |
+
"""
|
| 15 |
+
Initialize i18n handler
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
default_language: Default language code (en, zh, ja, etc.)
|
| 19 |
+
"""
|
| 20 |
+
self.current_language = default_language
|
| 21 |
+
self.translations: Dict[str, Dict[str, str]] = {}
|
| 22 |
+
self._load_all_translations()
|
| 23 |
+
|
| 24 |
+
def _load_all_translations(self):
|
| 25 |
+
"""Load all translation files from i18n directory"""
|
| 26 |
+
current_file = os.path.abspath(__file__)
|
| 27 |
+
module_dir = os.path.dirname(current_file)
|
| 28 |
+
i18n_dir = os.path.join(module_dir, "i18n")
|
| 29 |
+
|
| 30 |
+
if not os.path.exists(i18n_dir):
|
| 31 |
+
# Create i18n directory if it doesn't exist
|
| 32 |
+
os.makedirs(i18n_dir)
|
| 33 |
+
return
|
| 34 |
+
|
| 35 |
+
# Load all JSON files in i18n directory
|
| 36 |
+
for filename in os.listdir(i18n_dir):
|
| 37 |
+
if filename.endswith(".json"):
|
| 38 |
+
lang_code = filename[:-5] # Remove .json extension
|
| 39 |
+
filepath = os.path.join(i18n_dir, filename)
|
| 40 |
+
try:
|
| 41 |
+
with open(filepath, 'r', encoding='utf-8') as f:
|
| 42 |
+
self.translations[lang_code] = json.load(f)
|
| 43 |
+
except Exception as e:
|
| 44 |
+
print(f"Error loading translation file {filename}: {e}")
|
| 45 |
+
|
| 46 |
+
def set_language(self, language: str):
|
| 47 |
+
"""Set current language"""
|
| 48 |
+
if language in self.translations:
|
| 49 |
+
self.current_language = language
|
| 50 |
+
else:
|
| 51 |
+
print(f"Warning: Language '{language}' not found, using default")
|
| 52 |
+
|
| 53 |
+
def t(self, key: str, **kwargs) -> str:
|
| 54 |
+
"""
|
| 55 |
+
Translate a key to current language
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
key: Translation key (dot-separated for nested keys)
|
| 59 |
+
**kwargs: Optional format parameters
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
Translated string
|
| 63 |
+
"""
|
| 64 |
+
# Get translation from current language
|
| 65 |
+
translation = self._get_nested_value(
|
| 66 |
+
self.translations.get(self.current_language, {}),
|
| 67 |
+
key
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# Fallback to English if not found
|
| 71 |
+
if translation is None:
|
| 72 |
+
translation = self._get_nested_value(
|
| 73 |
+
self.translations.get('en', {}),
|
| 74 |
+
key
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
# Final fallback to key itself
|
| 78 |
+
if translation is None:
|
| 79 |
+
translation = key
|
| 80 |
+
|
| 81 |
+
# Apply formatting if kwargs provided
|
| 82 |
+
if kwargs:
|
| 83 |
+
try:
|
| 84 |
+
translation = translation.format(**kwargs)
|
| 85 |
+
except KeyError:
|
| 86 |
+
pass
|
| 87 |
+
|
| 88 |
+
return translation
|
| 89 |
+
|
| 90 |
+
def _get_nested_value(self, data: dict, key: str) -> Optional[str]:
    """
    Resolve a dot-separated key against a nested dictionary

    Args:
        data: Dictionary to search
        key: Dot-separated key (e.g., "section.subsection.key")

    Returns:
        The string stored at that path, or None when any path segment is
        missing, an intermediate node is not a dict, or the final value
        is not a string
    """
    node = data

    for segment in key.split('.'):
        # Stop as soon as the path cannot be followed any further.
        if not isinstance(node, dict) or segment not in node:
            return None
        node = node[segment]

    # Only leaf strings count as translations.
    if isinstance(node, str):
        return node
    return None
|
| 111 |
+
|
| 112 |
+
def get_available_languages(self) -> list:
    """Return the language codes that currently have loaded translations"""
    # Iterating the dict yields its keys.
    return list(self.translations)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# Module-wide I18n singleton; stays None until get_i18n() is first called
_i18n_instance: Optional[I18n] = None


def get_i18n(language: Optional[str] = None) -> I18n:
    """
    Return the shared I18n instance, creating it on first use

    Args:
        language: Optional language code. On the first call it becomes the
            default language of the new instance ("en" when omitted); on
            later calls a non-None value is passed to ``set_language``.

    Returns:
        I18n instance
    """
    global _i18n_instance

    # Already created: optionally switch language and hand it back.
    if _i18n_instance is not None:
        if language is not None:
            _i18n_instance.set_language(language)
        return _i18n_instance

    _i18n_instance = I18n(default_language=language or "en")
    return _i18n_instance
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def t(key: str, **kwargs) -> str:
    """
    Translate ``key`` through the shared I18n instance

    Args:
        key: Translation key
        **kwargs: Optional format parameters

    Returns:
        Translated string
    """
    translator = get_i18n()
    return translator.t(key, **kwargs)
|
acestep/gradio_ui/i18n/en.json
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"app": {
|
| 3 |
+
"title": "🎛️ ACE-Step V1.5 Playground💡",
|
| 4 |
+
"subtitle": "Pushing the Boundaries of Open-Source Music Generation"
|
| 5 |
+
},
|
| 6 |
+
"dataset": {
|
| 7 |
+
"title": "📊 Dataset Explorer",
|
| 8 |
+
"dataset_label": "Dataset",
|
| 9 |
+
"dataset_info": "Choose dataset to explore",
|
| 10 |
+
"import_btn": "📥 Import Dataset",
|
| 11 |
+
"search_type_label": "Search Type",
|
| 12 |
+
"search_type_info": "How to find items",
|
| 13 |
+
"search_value_label": "Search Value",
|
| 14 |
+
"search_value_placeholder": "Enter keys or index (leave empty for random)",
|
| 15 |
+
"search_value_info": "Keys: exact match, Index: 0 to dataset size-1",
|
| 16 |
+
"instruction_label": "📝 Instruction",
|
| 17 |
+
"instruction_placeholder": "No instruction available",
|
| 18 |
+
"metadata_title": "📋 Item Metadata (JSON)",
|
| 19 |
+
"metadata_label": "Complete Item Information",
|
| 20 |
+
"source_audio": "Source Audio",
|
| 21 |
+
"target_audio": "Target Audio",
|
| 22 |
+
"reference_audio": "Reference Audio",
|
| 23 |
+
"get_item_btn": "🔍 Get Item",
|
| 24 |
+
"use_src_checkbox": "Use Source Audio from Dataset",
|
| 25 |
+
"use_src_info": "Check to use the source audio from dataset",
|
| 26 |
+
"data_status_label": "📊 Data Status",
|
| 27 |
+
"data_status_default": "❌ No dataset imported",
|
| 28 |
+
"autofill_btn": "📋 Auto-fill Generation Form"
|
| 29 |
+
},
|
| 30 |
+
"service": {
|
| 31 |
+
"title": "🔧 Service Configuration",
|
| 32 |
+
"checkpoint_label": "Checkpoint File",
|
| 33 |
+
"checkpoint_info": "Select a trained model checkpoint file (full path or filename)",
|
| 34 |
+
"refresh_btn": "🔄 Refresh",
|
| 35 |
+
"model_path_label": "Main Model Path",
|
| 36 |
+
"model_path_info": "Select the model configuration directory (auto-scanned from checkpoints)",
|
| 37 |
+
"device_label": "Device",
|
| 38 |
+
"device_info": "Processing device (auto-detect recommended)",
|
| 39 |
+
"lm_model_path_label": "5Hz LM Model Path",
|
| 40 |
+
"lm_model_path_info": "Select the 5Hz LM model checkpoint (auto-scanned from checkpoints)",
|
| 41 |
+
"backend_label": "5Hz LM Backend",
|
| 42 |
+
"backend_info": "Select backend for 5Hz LM: vllm (faster) or pt (PyTorch, more compatible)",
|
| 43 |
+
"init_llm_label": "Initialize 5Hz LM",
|
| 44 |
+
"init_llm_info": "Check to initialize 5Hz LM during service initialization",
|
| 45 |
+
"flash_attention_label": "Use Flash Attention",
|
| 46 |
+
"flash_attention_info_enabled": "Enable flash attention for faster inference (requires flash_attn package)",
|
| 47 |
+
"flash_attention_info_disabled": "Flash attention not available (flash_attn package not installed)",
|
| 48 |
+
"offload_cpu_label": "Offload to CPU",
|
| 49 |
+
"offload_cpu_info": "Offload models to CPU when not in use to save GPU memory",
|
| 50 |
+
"offload_dit_cpu_label": "Offload DiT to CPU",
|
| 51 |
+
"offload_dit_cpu_info": "Offload DiT to CPU (needs Offload to CPU)",
|
| 52 |
+
"compile_model_label": "Compile Model",
|
| 53 |
+
"compile_model_info": "Use torch.compile to optimize model (required for quantization)",
|
| 54 |
+
"quantization_label": "INT8 Quantization",
|
| 55 |
+
"quantization_info": "Enable INT8 weight-only quantization to reduce VRAM usage (requires Compile Model)",
|
| 56 |
+
"init_btn": "Initialize Service",
|
| 57 |
+
"status_label": "Status",
|
| 58 |
+
"language_label": "UI Language",
|
| 59 |
+
"language_info": "Select interface language"
|
| 60 |
+
},
|
| 61 |
+
"generation": {
|
| 62 |
+
"required_inputs": "📝 Required Inputs",
|
| 63 |
+
"task_type_label": "Task Type",
|
| 64 |
+
"task_type_info": "Select the task type for generation",
|
| 65 |
+
"instruction_label": "Instruction",
|
| 66 |
+
"instruction_info": "Instruction is automatically generated based on task type",
|
| 67 |
+
"load_btn": "Load",
|
| 68 |
+
"track_name_label": "Track Name",
|
| 69 |
+
"track_name_info": "Select track name for lego/extract tasks",
|
| 70 |
+
"track_classes_label": "Track Names",
|
| 71 |
+
"track_classes_info": "Select multiple track classes for complete task",
|
| 72 |
+
"audio_uploads": "🎵 Audio Uploads",
|
| 73 |
+
"reference_audio": "Reference Audio (optional)",
|
| 74 |
+
"source_audio": "Source Audio (optional)",
|
| 75 |
+
"convert_codes_btn": "Convert to Codes",
|
| 76 |
+
"lm_codes_hints": "🎼 LM Codes Hints",
|
| 77 |
+
"lm_codes_label": "LM Codes Hints",
|
| 78 |
+
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
|
| 79 |
+
"lm_codes_info": "Paste LM codes hints for text2music generation",
|
| 80 |
+
"lm_codes_sample": "LM Codes Hints (Sample {n})",
|
| 81 |
+
"lm_codes_sample_info": "Codes for sample {n}",
|
| 82 |
+
"transcribe_btn": "Transcribe",
|
| 83 |
+
"repainting_controls": "🎨 Repainting Controls (seconds)",
|
| 84 |
+
"repainting_start": "Repainting Start",
|
| 85 |
+
"repainting_end": "Repainting End",
|
| 86 |
+
"mode_label": "Generation Mode",
|
| 87 |
+
"mode_info": "Simple: describe music in natural language. Custom: full control over caption and lyrics.",
|
| 88 |
+
"mode_simple": "Simple",
|
| 89 |
+
"mode_custom": "Custom",
|
| 90 |
+
"simple_query_label": "Song Description",
|
| 91 |
+
"simple_query_placeholder": "Describe the music you want to create, e.g., 'a soft Bengali love song for a quiet evening'. Leave empty for a random sample.",
|
| 92 |
+
"simple_query_info": "Enter a natural language description of the music you want to generate",
|
| 93 |
+
"simple_vocal_language_label": "Vocal Language (optional)",
|
| 94 |
+
"simple_vocal_language_info": "Select preferred language(s) for lyrics. Use 'unknown' for any language.",
|
| 95 |
+
"create_sample_btn": "Create Sample",
|
| 96 |
+
"caption_title": "📝 Music Caption",
|
| 97 |
+
"caption_label": "Music Caption (optional)",
|
| 98 |
+
"caption_placeholder": "A peaceful acoustic guitar melody with soft vocals...",
|
| 99 |
+
"caption_info": "Describe the style, genre, instruments, and mood",
|
| 100 |
+
"lyrics_title": "📝 Lyrics",
|
| 101 |
+
"lyrics_label": "Lyrics (optional)",
|
| 102 |
+
"lyrics_placeholder": "[Verse 1]\\nUnder the starry night\\nI feel so alive...",
|
| 103 |
+
"lyrics_info": "Song lyrics with structure",
|
| 104 |
+
"instrumental_label": "Instrumental",
|
| 105 |
+
"format_btn": "Format",
|
| 106 |
+
"optional_params": "⚙️ Optional Parameters",
|
| 107 |
+
"vocal_language_label": "Vocal Language (optional)",
|
| 108 |
+
"vocal_language_info": "Use `unknown` for instrumental tracks",
|
| 109 |
+
"bpm_label": "BPM (optional)",
|
| 110 |
+
"bpm_info": "leave empty for N/A",
|
| 111 |
+
"keyscale_label": "KeyScale (optional)",
|
| 112 |
+
"keyscale_placeholder": "Leave empty for N/A",
|
| 113 |
+
"keyscale_info": "A-G, #/♭, major/minor",
|
| 114 |
+
"timesig_label": "Time Signature (optional)",
|
| 115 |
+
"timesig_info": "2/4, 3/4, 4/4...",
|
| 116 |
+
"duration_label": "Audio Duration (seconds)",
|
| 117 |
+
"duration_info": "Use -1 for random",
|
| 118 |
+
"batch_size_label": "Batch Size",
|
| 119 |
+
"batch_size_info": "Number of audio clips to generate (max 8)",
|
| 120 |
+
"advanced_settings": "🔧 Advanced Settings",
|
| 121 |
+
"inference_steps_label": "DiT Inference Steps",
|
| 122 |
+
"inference_steps_info": "Turbo: max 8, Base: max 200",
|
| 123 |
+
"guidance_scale_label": "DiT Guidance Scale (only supported for base models)",
|
| 124 |
+
"guidance_scale_info": "Higher values follow text more closely",
|
| 125 |
+
"seed_label": "Seed",
|
| 126 |
+
"seed_info": "Use comma-separated values for batches",
|
| 127 |
+
"random_seed_label": "Random Seed",
|
| 128 |
+
"random_seed_info": "Enable to auto-generate seeds",
|
| 129 |
+
"audio_format_label": "Audio Format",
|
| 130 |
+
"audio_format_info": "Audio format for saved files",
|
| 131 |
+
"use_adg_label": "Use ADG",
|
| 132 |
+
"use_adg_info": "Enable Angle Domain Guidance",
|
| 133 |
+
"shift_label": "Shift",
|
| 134 |
+
"shift_info": "Timestep shift factor for base models (range 1.0~5.0, default 3.0). Not effective for turbo models.",
|
| 135 |
+
"infer_method_label": "Inference Method",
|
| 136 |
+
"infer_method_info": "Diffusion inference method. ODE (Euler) is faster, SDE (stochastic) may produce different results.",
|
| 137 |
+
"custom_timesteps_label": "Custom Timesteps",
|
| 138 |
+
"custom_timesteps_info": "Optional: comma-separated values from 1.0 to 0.0 (e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference steps and shift.",
|
| 139 |
+
"cfg_interval_start": "CFG Interval Start",
|
| 140 |
+
"cfg_interval_end": "CFG Interval End",
|
| 141 |
+
"lm_params_title": "🤖 LM Generation Parameters",
|
| 142 |
+
"lm_temperature_label": "LM Temperature",
|
| 143 |
+
"lm_temperature_info": "5Hz LM temperature (higher = more random)",
|
| 144 |
+
"lm_cfg_scale_label": "LM CFG Scale",
|
| 145 |
+
"lm_cfg_scale_info": "5Hz LM CFG (1.0 = no CFG)",
|
| 146 |
+
"lm_top_k_label": "LM Top-K",
|
| 147 |
+
"lm_top_k_info": "Top-K (0 = disabled)",
|
| 148 |
+
"lm_top_p_label": "LM Top-P",
|
| 149 |
+
"lm_top_p_info": "Top-P (1.0 = disabled)",
|
| 150 |
+
"lm_negative_prompt_label": "LM Negative Prompt",
|
| 151 |
+
"lm_negative_prompt_placeholder": "Enter negative prompt for CFG (default: NO USER INPUT)",
|
| 152 |
+
"lm_negative_prompt_info": "Negative prompt (use when LM CFG Scale > 1.0)",
|
| 153 |
+
"cot_metas_label": "CoT Metas",
|
| 154 |
+
"cot_metas_info": "Use LM to generate CoT metadata (uncheck to skip LM CoT generation)",
|
| 155 |
+
"cot_language_label": "CoT Language",
|
| 156 |
+
"cot_language_info": "Generate language in CoT (chain-of-thought)",
|
| 157 |
+
"constrained_debug_label": "Constrained Decoding Debug",
|
| 158 |
+
"constrained_debug_info": "Enable debug logging for constrained decoding (check to see detailed logs)",
|
| 159 |
+
"auto_score_label": "Auto Score",
|
| 160 |
+
"auto_score_info": "Automatically calculate quality scores for all generated audios",
|
| 161 |
+
"auto_lrc_label": "Auto LRC",
|
| 162 |
+
"auto_lrc_info": "Automatically generate LRC lyrics timestamps for all generated audios",
|
| 163 |
+
"lm_batch_chunk_label": "LM Batch Chunk Size",
|
| 164 |
+
"lm_batch_chunk_info": "Max items per LM batch chunk (default: 8, limited by GPU memory)",
|
| 165 |
+
"codes_strength_label": "LM Codes Strength",
|
| 166 |
+
"codes_strength_info": "Control how many denoising steps use LM-generated codes",
|
| 167 |
+
"similarity_denoise_label": "Similarity / Denoise",
|
| 168 |
+
"similarity_denoise_info": "Controls how closely the output follows the reference audio. Higher values preserve more structure.",
|
| 169 |
+
"cover_strength_label": "Audio Cover Strength",
|
| 170 |
+
"cover_strength_info": "Control how many denoising steps use cover mode",
|
| 171 |
+
"score_sensitivity_label": "Quality Score Sensitivity",
|
| 172 |
+
"score_sensitivity_info": "Lower = more sensitive (default: 1.0). Adjusts how PMI maps to [0,1]",
|
| 173 |
+
"think_label": "Think",
|
| 174 |
+
"parallel_thinking_label": "ParallelThinking",
|
| 175 |
+
"generate_btn": "🎵 Generate Music",
|
| 176 |
+
"autogen_label": "AutoGen",
|
| 177 |
+
"caption_rewrite_label": "CaptionRewrite"
|
| 178 |
+
},
|
| 179 |
+
"results": {
|
| 180 |
+
"title": "🎵 Results",
|
| 181 |
+
"generated_music": "🎵 Generated Music (Sample {n})",
|
| 182 |
+
"send_to_src_btn": "🔗 Send To Src Audio",
|
| 183 |
+
"save_btn": "💾 Save",
|
| 184 |
+
"score_btn": "📊 Score",
|
| 185 |
+
"lrc_btn": "🎵 LRC",
|
| 186 |
+
"quality_score_label": "Quality Score (Sample {n})",
|
| 187 |
+
"quality_score_placeholder": "Click 'Score' to calculate perplexity-based quality score",
|
| 188 |
+
"codes_label": "LM Codes (Sample {n})",
|
| 189 |
+
"lrc_label": "Lyrics Timestamps (Sample {n})",
|
| 190 |
+
"lrc_placeholder": "Click 'LRC' to generate timestamps",
|
| 191 |
+
"details_accordion": "📊 Score & LRC & LM Codes",
|
| 192 |
+
"generation_status": "Generation Status",
|
| 193 |
+
"current_batch": "Current Batch",
|
| 194 |
+
"batch_indicator": "Batch {current} / {total}",
|
| 195 |
+
"next_batch_status": "Next Batch Status",
|
| 196 |
+
"prev_btn": "◀ Previous",
|
| 197 |
+
"next_btn": "Next ▶",
|
| 198 |
+
"restore_params_btn": "↙️ Apply These Settings to UI (Restore Batch Parameters)",
|
| 199 |
+
"batch_results_title": "👇 Click here to view batch results & generation details",
|
| 200 |
+
"all_files_label": "📁 All Generated Files (Download)",
|
| 201 |
+
"generation_details": "Generation Details"
|
| 202 |
+
},
|
| 203 |
+
"messages": {
|
| 204 |
+
"no_audio_to_save": "❌ No audio to save",
|
| 205 |
+
"save_success": "✅ Saved audio and metadata to {filename}",
|
| 206 |
+
"save_failed": "❌ Failed to save: {error}",
|
| 207 |
+
"no_file_selected": "⚠️ No file selected",
|
| 208 |
+
"params_loaded": "✅ Parameters loaded from {filename}",
|
| 209 |
+
"invalid_json": "❌ Invalid JSON file: {error}",
|
| 210 |
+
"load_error": "❌ Error loading file: {error}",
|
| 211 |
+
"example_loaded": "📁 Loaded example from {filename}",
|
| 212 |
+
"example_failed": "Failed to parse JSON file {filename}: {error}",
|
| 213 |
+
"example_error": "Error loading example: {error}",
|
| 214 |
+
"lm_generated": "🤖 Generated example using LM",
|
| 215 |
+
"lm_fallback": "Failed to generate example using LM, falling back to examples directory",
|
| 216 |
+
"lm_not_initialized": "❌ 5Hz LM not initialized. Please initialize it first.",
|
| 217 |
+
"autogen_enabled": "🔄 AutoGen enabled - next batch will generate after this",
|
| 218 |
+
"batch_ready": "✅ Batch {n} ready! Click 'Next' to view.",
|
| 219 |
+
"batch_generating": "🔄 Starting background generation for Batch {n}...",
|
| 220 |
+
"batch_failed": "❌ Background generation failed: {error}",
|
| 221 |
+
"viewing_batch": "✅ Viewing Batch {n}",
|
| 222 |
+
"at_first_batch": "Already at first batch",
|
| 223 |
+
"at_last_batch": "No next batch available",
|
| 224 |
+
"batch_not_found": "Batch {n} not found in queue",
|
| 225 |
+
"no_batch_data": "No batch data found to restore.",
|
| 226 |
+
"params_restored": "✅ UI Parameters restored from Batch {n}",
|
| 227 |
+
"scoring_failed": "❌ Error: Batch data not found",
|
| 228 |
+
"no_codes": "❌ No audio codes available. Please generate music first.",
|
| 229 |
+
"score_failed": "❌ Scoring failed: {error}",
|
| 230 |
+
"score_error": "❌ Error calculating score: {error}",
|
| 231 |
+
"lrc_no_batch_data": "❌ No batch data found. Please generate music first.",
|
| 232 |
+
"lrc_no_extra_outputs": "❌ No extra outputs found. Condition tensors not available.",
|
| 233 |
+
"lrc_missing_tensors": "❌ Missing required tensors for LRC generation.",
|
| 234 |
+
"lrc_sample_not_exist": "❌ Sample does not exist in current batch.",
|
| 235 |
+
"lrc_empty_result": "⚠️ LRC generation produced empty result.",
|
| 236 |
+
"empty_query": "⚠️ Please enter a music description.",
|
| 237 |
+
"sample_creation_failed": "❌ Failed to create sample. Please try again.",
|
| 238 |
+
"sample_created": "✅ Sample created! Review the caption and lyrics, then click Generate Music.",
|
| 239 |
+
"simple_examples_not_found": "⚠️ Simple mode examples directory not found.",
|
| 240 |
+
"simple_examples_empty": "⚠️ No example files found in simple mode examples.",
|
| 241 |
+
"simple_example_loaded": "🎲 Loaded random example from {filename}",
|
| 242 |
+
"format_success": "✅ Caption and lyrics formatted successfully",
|
| 243 |
+
"format_failed": "❌ Format failed: {error}",
|
| 244 |
+
"skipping_metas_cot": "⚡ Skipping Phase 1 metas COT (sample already formatted)",
|
| 245 |
+
"invalid_timesteps_format": "⚠️ Invalid timesteps format. Using default schedule.",
|
| 246 |
+
"timesteps_out_of_range": "⚠️ Timesteps must be in range [0, 1]. Using default schedule.",
|
| 247 |
+
"timesteps_count_mismatch": "⚠️ Timesteps count ({actual}) differs from inference_steps ({expected}). Using timesteps count."
|
| 248 |
+
},
|
| 249 |
+
"training": {
|
| 250 |
+
"tab_title": "🎓 LoRA Training",
|
| 251 |
+
"tab_dataset_builder": "📁 Dataset Builder",
|
| 252 |
+
"tab_train_lora": "🚀 Train LoRA",
|
| 253 |
+
"quick_start_title": "🚀 Quick Start",
|
| 254 |
+
"load_dataset_label": "Dataset JSON Path",
|
| 255 |
+
"load_dataset_info": "Load a previously saved dataset",
|
| 256 |
+
"load_btn": "📂 Load",
|
| 257 |
+
"load_status": "Load Status",
|
| 258 |
+
"scan_label": "Audio Directory Path",
|
| 259 |
+
"scan_info": "Scan for audio files (wav, mp3, flac, ogg, opus)",
|
| 260 |
+
"scan_btn": "🔍 Scan",
|
| 261 |
+
"scan_status": "Scan Status",
|
| 262 |
+
"found_audio_files": "Found Audio Files",
|
| 263 |
+
"dataset_name": "Dataset Name",
|
| 264 |
+
"dataset_name_placeholder": "Enter dataset name",
|
| 265 |
+
"dataset_settings_header": "Dataset Settings",
|
| 266 |
+
"tag_prepend": "Prepend (tag, caption)",
|
| 267 |
+
"tag_append": "Append (caption, tag)",
|
| 268 |
+
"tag_replace": "Replace caption",
|
| 269 |
+
"step2_title": "Step 2: Auto-Label with AI",
|
| 270 |
+
"step3_title": "Step 3: Preview & Edit",
|
| 271 |
+
"step4_title": "Step 4: Save Dataset",
|
| 272 |
+
"step5_title": "Step 5: Preprocess to Tensors",
|
| 273 |
+
"all_instrumental": "All Instrumental",
|
| 274 |
+
"all_instrumental_info": "Check if all tracks are instrumental (no vocals)",
|
| 275 |
+
"custom_tag": "Custom Activation Tag",
|
| 276 |
+
"custom_tag_info": "Unique tag to activate this LoRA's style",
|
| 277 |
+
"tag_position": "Tag Position",
|
| 278 |
+
"tag_position_info": "Where to place the custom tag in the caption",
|
| 279 |
+
"genre_ratio": "Genre Ratio (%)",
|
| 280 |
+
"genre_ratio_info": "0%=all Caption, 100%=all Genre. Per-sample override takes priority.",
|
| 281 |
+
"skip_metas": "Skip BPM/Key/Time Signature",
|
| 282 |
+
"skip_metas_info": "Skip BPM/Key/Time Signature generation. Caption and Genre are still generated by LLM.",
|
| 283 |
+
"only_unlabeled": "Only Unlabeled",
|
| 284 |
+
"only_unlabeled_info": "Only label samples without caption (useful for resuming failed labeling)",
|
| 285 |
+
"auto_label_btn": "🏷️ Auto-Label All",
|
| 286 |
+
"label_progress": "Labeling Progress",
|
| 287 |
+
"select_sample": "Select Sample #",
|
| 288 |
+
"select_sample_info": "Choose a sample to preview and edit",
|
| 289 |
+
"audio_preview": "Audio Preview",
|
| 290 |
+
"filename": "Filename",
|
| 291 |
+
"caption": "Caption",
|
| 292 |
+
"genre": "Genre",
|
| 293 |
+
"prompt_override_label": "Prompt Override (this sample)",
|
| 294 |
+
"prompt_override_info": "Override global ratio for this sample",
|
| 295 |
+
"lyrics_editable_label": "Lyrics (editable, used for training)",
|
| 296 |
+
"raw_lyrics_label": "Raw Lyrics (from .txt file)",
|
| 297 |
+
"no_lyrics_placeholder": "(no .txt lyrics file)",
|
| 298 |
+
"bpm": "BPM",
|
| 299 |
+
"key_label": "Key",
|
| 300 |
+
"key_placeholder": "C Major",
|
| 301 |
+
"time_sig": "Time Signature",
|
| 302 |
+
"duration_s": "Duration (s)",
|
| 303 |
+
"language": "Language",
|
| 304 |
+
"instrumental": "Instrumental",
|
| 305 |
+
"save_changes_btn": "💾 Save Changes",
|
| 306 |
+
"edit_status": "Edit Status",
|
| 307 |
+
"save_path": "Save Path",
|
| 308 |
+
"save_path_info": "Path where the dataset JSON will be saved",
|
| 309 |
+
"save_dataset_btn": "💾 Save Dataset",
|
| 310 |
+
"save_status": "Save Status",
|
| 311 |
+
"load_existing_label": "Load Existing Dataset (Optional)",
|
| 312 |
+
"load_existing_info": "Path to a previously saved dataset JSON file",
|
| 313 |
+
"load_dataset_btn": "📂 Load Dataset",
|
| 314 |
+
"tensor_output_dir": "Tensor Output Directory",
|
| 315 |
+
"tensor_output_info": "Directory to save preprocessed tensor files",
|
| 316 |
+
"preprocess_btn": "⚡ Preprocess",
|
| 317 |
+
"preprocess_progress": "Preprocessing Progress",
|
| 318 |
+
"preprocessed_tensors_dir": "Preprocessed Tensors Directory",
|
| 319 |
+
"preprocessed_tensors_info": "Directory containing preprocessed .pt tensor files",
|
| 320 |
+
"train_section_tensors": "Preprocessed Dataset Selection",
|
| 321 |
+
"train_section_lora": "LoRA Settings",
|
| 322 |
+
"train_section_params": "Training Parameters",
|
| 323 |
+
"dataset_info": "Dataset Info",
|
| 324 |
+
"lora_rank": "LoRA Rank (r)",
|
| 325 |
+
"lora_rank_info": "Higher = more capacity, more memory",
|
| 326 |
+
"lora_alpha": "LoRA Alpha",
|
| 327 |
+
"lora_alpha_info": "Scaling factor (typically 2x rank)",
|
| 328 |
+
"lora_dropout": "LoRA Dropout",
|
| 329 |
+
"learning_rate": "Learning Rate",
|
| 330 |
+
"learning_rate_info": "Start with 3e-4, adjust if needed",
|
| 331 |
+
"max_epochs": "Max Epochs",
|
| 332 |
+
"batch_size": "Batch Size",
|
| 333 |
+
"batch_size_info": "Increase if you have enough VRAM",
|
| 334 |
+
"gradient_accumulation": "Gradient Accumulation",
|
| 335 |
+
"gradient_accumulation_info": "Effective batch = batch_size × accumulation",
|
| 336 |
+
"save_every_n_epochs": "Save Every N Epochs",
|
| 337 |
+
"shift": "Shift",
|
| 338 |
+
"shift_info": "Timestep shift for turbo model",
|
| 339 |
+
"seed": "Seed",
|
| 340 |
+
"output_dir": "Output Directory",
|
| 341 |
+
"output_dir_info": "Directory to save trained LoRA weights",
|
| 342 |
+
"start_training_btn": "🚀 Start Training",
|
| 343 |
+
"stop_training_btn": "⏹️ Stop Training",
|
| 344 |
+
"training_progress": "Training Progress",
|
| 345 |
+
"training_log": "Training Log",
|
| 346 |
+
"training_loss_title": "Training Loss",
|
| 347 |
+
"step": "Step",
|
| 348 |
+
"loss": "Loss",
|
| 349 |
+
"export_header": "Export LoRA",
|
| 350 |
+
"export_path": "Export Path",
|
| 351 |
+
"export_lora_btn": "📦 Export LoRA",
|
| 352 |
+
"export_status": "Export Status"
|
| 353 |
+
}
|
| 354 |
+
}
|
acestep/gradio_ui/i18n/he.json
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"app": {
|
| 3 |
+
"title": "🎛️ סביבת העבודה ACE-Step V1.5 Playground💡",
|
| 4 |
+
"subtitle": "פורצים את גבולות יצירת המוזיקה בקוד פתוח"
|
| 5 |
+
},
|
| 6 |
+
"dataset": {
|
| 7 |
+
"title": "📊 סייר מערכי נתונים (Dataset Explorer)",
|
| 8 |
+
"dataset_label": "מערך נתונים",
|
| 9 |
+
"dataset_info": "בחר מערך נתונים לחקירה",
|
| 10 |
+
"import_btn": "📥 ייבוא מערך נתונים",
|
| 11 |
+
"search_type_label": "סוג חיפוש",
|
| 12 |
+
"search_type_info": "כיצד למצוא פריטים",
|
| 13 |
+
"search_value_label": "ערך חיפוש",
|
| 14 |
+
"search_value_placeholder": "הזן מפתחות או אינדקס (השאר ריק לבחירה אקראית)",
|
| 15 |
+
"search_value_info": "מפתחות: התאמה מדויקת, אינדקס: 0 עד גודל המערך פחות 1",
|
| 16 |
+
"instruction_label": "📝 הנחיה (Instruction)",
|
| 17 |
+
"instruction_placeholder": "אין הנחיה זמינה",
|
| 18 |
+
"metadata_title": "📋 מטא-דאטה של הפריט (JSON)",
|
| 19 |
+
"metadata_label": "מידע מלא על הפריט",
|
| 20 |
+
"source_audio": "אודיו מקור",
|
| 21 |
+
"target_audio": "אודיו יעד",
|
| 22 |
+
"reference_audio": "אודיו ייחוס",
|
| 23 |
+
"get_item_btn": "🔍 קבל פריט",
|
| 24 |
+
"use_src_checkbox": "השתמש באודיו מקור ממערך הנתונים",
|
| 25 |
+
"use_src_info": "סמן כדי להשתמש באודיו המקור מתוך מערך הנתונים",
|
| 26 |
+
"data_status_label": "📊 מצב נתונים",
|
| 27 |
+
"data_status_default": "❌ לא יובא מערך נתונים",
|
| 28 |
+
"autofill_btn": "📋 מילוי אוטומטי של טופס היצירה"
|
| 29 |
+
},
|
| 30 |
+
"service": {
|
| 31 |
+
"title": "🔧 הגדרות שירות",
|
| 32 |
+
"checkpoint_label": "קובץ נקודת ביקורת (Checkpoint)",
|
| 33 |
+
"checkpoint_info": "בחר קובץ נקודת ביקורת של מודל מאומן (נתיב מלא או שם קובץ)",
|
| 34 |
+
"refresh_btn": "🔄 רענון",
|
| 35 |
+
"model_path_label": "נתיב מודל ראשי",
|
| 36 |
+
"model_path_info": "בחר את ספריית הגדרות המודל (נסרק אוטומטית מנקודות הביקורת)",
|
| 37 |
+
"device_label": "מכשיר (Device)",
|
| 38 |
+
"device_info": "מכשיר עיבוד (מומלץ זיהוי אוטומטי)",
|
| 39 |
+
"lm_model_path_label": "נתיב מודל 5Hz LM",
|
| 40 |
+
"lm_model_path_info": "בחר את קובץ נקודת הביקורת של מודל ה-5Hz LM",
|
| 41 |
+
"backend_label": "מנוע (Backend) 5Hz LM",
|
| 42 |
+
"backend_info": "בחר מנוע עבור 5Hz LM: vllm (מהיר יותר) או pt (PyTorch, תואם יותר)",
|
| 43 |
+
"init_llm_label": "אתחול 5Hz LM",
|
| 44 |
+
"init_llm_info": "סמן כדי לאתחל את ה-5Hz LM במהלך אתחול השירות",
|
| 45 |
+
"flash_attention_label": "השתמש ב-Flash Attention",
|
| 46 |
+
"flash_attention_info_enabled": "הפעל Flash Attention להסקה מהירה יותר (דורש חבילת flash_attn)",
|
| 47 |
+
"flash_attention_info_disabled": "Flash Attention אינו זמין (חבילת flash_attn לא מותקנת)",
|
| 48 |
+
"offload_cpu_label": "העברה ל-CPU (Offload)",
|
| 49 |
+
"offload_cpu_info": "העבר מודלים ל-CPU כשאינם בשימוש כדי לחסוך בזיכרון גרפי (VRAM)",
|
| 50 |
+
"offload_dit_cpu_label": "העברת DiT ל-CPU",
|
| 51 |
+
"offload_dit_cpu_info": "העבר DiT ל-CPU (דורש 'העברה ל-CPU')",
|
| 52 |
+
"compile_model_label": "הידור מודל (Compile)",
|
| 53 |
+
"compile_model_info": "השתמש ב-torch.compile לאופטימיזציה של המודל (נדרש עבור קוונטיזציה)",
|
| 54 |
+
"quantization_label": "קוונטיזציה INT8",
|
| 55 |
+
"quantization_info": "הפעל קוונטיזציה של משקולות בלבד (INT8) להפחתת שימוש ב-VRAM (דורש הידור מודל)",
|
| 56 |
+
"init_btn": "אתחול שירות",
|
| 57 |
+
"status_label": "מצב",
|
| 58 |
+
"language_label": "שפת ממשק",
|
| 59 |
+
"language_info": "בחר את שפת הממשק"
|
| 60 |
+
},
|
| 61 |
+
"generation": {
|
| 62 |
+
"required_inputs": "📝 קלטים נדרשים",
|
| 63 |
+
"task_type_label": "סוג משימה",
|
| 64 |
+
"task_type_info": "בחר את סוג המשימה ליצירה",
|
| 65 |
+
"instruction_label": "הנחיה",
|
| 66 |
+
"instruction_info": "ההנחיה נוצרת אוטומטית בהתאם לסוג המשימה",
|
| 67 |
+
"load_btn": "טעינה",
|
| 68 |
+
"track_name_label": "שם רצועה",
|
| 69 |
+
"track_name_info": "בחר שם רצועה עבור משימות lego/extract",
|
| 70 |
+
"track_classes_label": "שמות רצועות",
|
| 71 |
+
"track_classes_info": "בחר מספר מחלקות רצועה עבור משימה מלאה",
|
| 72 |
+
"audio_uploads": "🎵 העלאות אודיו",
|
| 73 |
+
"reference_audio": "אודיו ייחוס (אופציונלי)",
|
| 74 |
+
"source_audio": "אודיו מקור (אופציונלי)",
|
| 75 |
+
"convert_codes_btn": "המר לקודים",
|
| 76 |
+
"lm_codes_hints": "🎼 רמזי קודי LM",
|
| 77 |
+
"lm_codes_label": "רמזי קודי LM",
|
| 78 |
+
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
|
| 79 |
+
"lm_codes_info": "הדבק רמזי קודי LM עבור יצירת טקסט למוזיקה (text2music)",
|
| 80 |
+
"lm_codes_sample": "רמזי קודי LM (דגימה {n})",
|
| 81 |
+
"lm_codes_sample_info": "קודים עבור דגימה {n}",
|
| 82 |
+
"transcribe_btn": "תמלול",
|
| 83 |
+
"repainting_controls": "🎨 בקרת צביעה מחדש (בשניות)",
|
| 84 |
+
"repainting_start": "תחילת צביעה מחדש",
|
| 85 |
+
"repainting_end": "סיום צביעה מחדש",
|
| 86 |
+
"mode_label": "מצב יצירה",
|
| 87 |
+
"mode_info": "פשוט: תאר מוזיקה בשפה טבעית. מותאם אישית: שליטה מלאה בתיאור ומילים.",
|
| 88 |
+
"mode_simple": "פשוט",
|
| 89 |
+
"mode_custom": "מותאם אישית",
|
| 90 |
+
"simple_query_label": "תיאור השיר",
|
| 91 |
+
"simple_query_placeholder": "תאר את המוזיקה שברצונך ליצור, למשל: 'שיר אהבה אקוסטי שקט לערב רגוע'. השאר ריק לדגימה אקראית.",
|
| 92 |
+
"simple_query_info": "הזן תיאור בשפה טבעית של המוזיקה שברצונך ליצור",
|
| 93 |
+
"simple_vocal_language_label": "שפת שירה (אופציונלי)",
|
| 94 |
+
"simple_vocal_language_info": "בחר שפות מועדפות למילים. השתמש ב-'unknown' לכל שפה.",
|
| 95 |
+
"create_sample_btn": "צור דגימה",
|
| 96 |
+
"caption_title": "📝 תיאור מוזיקלי (Caption)",
|
| 97 |
+
"caption_label": "תיאור מוזיקלי (אופציונלי)",
|
| 98 |
+
"caption_placeholder": "מנגינת גיטרה אקוסטית שלווה עם שירה רכה...",
|
| 99 |
+
"caption_info": "תאר את הסגנון, הז'אנר, הכלים והאווירה",
|
| 100 |
+
"lyrics_title": "📝 מילים",
|
| 101 |
+
"lyrics_label": "מילים (אופציונלי)",
|
| 102 |
+
"lyrics_placeholder": "[בית 1]\\nתחת שמי הלילה...\\nאני מרגיש חי...",
|
| 103 |
+
"lyrics_info": "מילות השיר עם מבנה",
|
| 104 |
+
"instrumental_label": "אינסטרומנטלי (ללא שירה)",
|
| 105 |
+
"format_btn": "פרמוט",
|
| 106 |
+
"optional_params": "⚙️ פרמטרים אופציונליים",
|
| 107 |
+
"vocal_language_label": "שפת שירה (אופציונלי)",
|
| 108 |
+
"vocal_language_info": "השתמש ב-`unknown` לקטעים כליים",
|
| 109 |
+
"bpm_label": "קצב (BPM) (אופציונלי)",
|
| 110 |
+
"bpm_info": "השאר ריק אם לא ידוע",
|
| 111 |
+
"keyscale_label": "סולם (KeyScale) (אופציונלי)",
|
| 112 |
+
"keyscale_placeholder": "השאר ריק אם לא ידוע",
|
| 113 |
+
"keyscale_info": "A-G, #/♭, מז'ור/מינור",
|
| 114 |
+
"timesig_label": "משקל מוזיקלי (אופציונלי)",
|
| 115 |
+
"timesig_info": "2/4, 3/4, 4/4...",
|
| 116 |
+
"duration_label": "אורך אודיו (שניות)",
|
| 117 |
+
"duration_info": "השתמש ב-1- לאקראי",
|
| 118 |
+
"batch_size_label": "גודל מנה (Batch Size)",
|
| 119 |
+
"batch_size_info": "מספר קטעי אודיו ליצירה (מקסימום 8)",
|
| 120 |
+
"advanced_settings": "🔧 הגדרות מתקדמות",
|
| 121 |
+
"inference_steps_label": "צעדי הסקה של DiT",
|
| 122 |
+
"inference_steps_info": "Turbo: מקסימום 8, Base: מקסימום 200",
|
| 123 |
+
"guidance_scale_label": "קנה מידה להנחיה (רק למודל base)",
|
| 124 |
+
"guidance_scale_info": "ערכים גבוהים יותר נצמדים יותר לטקסט",
|
| 125 |
+
"seed_label": "גרעין (Seed)",
|
| 126 |
+
"seed_info": "השתמש בערכים מופרדים בפסיקים עבור מנות",
|
| 127 |
+
"random_seed_label": "גרעין אקראי",
|
| 128 |
+
"random_seed_info": "אפשר ליצירה אוטומטית של גרעינים",
|
| 129 |
+
"audio_format_label": "פורמט אודיו",
|
| 130 |
+
"audio_format_info": "פורמט האודיו עבור הקבצים שיישמרו",
|
| 131 |
+
"use_adg_label": "השתמש ב-ADG",
|
| 132 |
+
"use_adg_info": "הפעל Angle Domain Guidance",
|
| 133 |
+
"shift_label": "Shift",
|
| 134 |
+
"shift_info": "פקטור הסטת צעדי זמן למודלי base (טווח 1.0~5.0, ברירת מחדל 3.0). לא משפיע על מודלי turbo.",
|
| 135 |
+
"infer_method_label": "שיטת הסקה",
|
| 136 |
+
"infer_method_info": "שיטת הסקת הדיפוזיה. ODE (Euler) מהירה יותר, SDE (stochastic) עשויה להפיק תוצאות שונות.",
|
| 137 |
+
"custom_timesteps_label": "צעדי זמן מותאמים אישית",
|
| 138 |
+
"custom_timesteps_info": "אופציונלי: ערכים מופרדים בפסיקים מ-1.0 עד 0.0. דורס את צעדי ההסקה וה-shift.",
|
| 139 |
+
"cfg_interval_start": "תחילת מרווח CFG",
|
| 140 |
+
"cfg_interval_end": "סיום מרווח CFG",
|
| 141 |
+
"lm_params_title": "🤖 פרמטרי יצירת LM",
|
| 142 |
+
"lm_temperature_label": "טמפרטורת LM",
|
| 143 |
+
"lm_temperature_info": "טמפרטורת 5Hz LM (גבוה יותר = אקראי יותר)",
|
| 144 |
+
"lm_cfg_scale_label": "קנה מידה LM CFG",
|
| 145 |
+
"lm_cfg_scale_info": "5Hz LM CFG (1.0 = ללא CFG)",
|
| 146 |
+
"lm_top_k_label": "LM Top-K",
|
| 147 |
+
"lm_top_k_info": "Top-K (0 = מושבת)",
|
| 148 |
+
"lm_top_p_label": "LM Top-P",
|
| 149 |
+
"lm_top_p_info": "Top-P (1.0 = מושבת)",
|
| 150 |
+
"lm_negative_prompt_label": "הנחיה שלילית ל-LM",
|
| 151 |
+
"lm_negative_prompt_placeholder": "הזן הנחיה שלילית עבור CFG",
|
| 152 |
+
"lm_negative_prompt_info": "הנחיה שלילית (בשימוש כאשר LM CFG Scale > 1.0)",
|
| 153 |
+
"cot_metas_label": "CoT Metas",
|
| 154 |
+
"cot_metas_info": "השתמש ב-LM ליצירת מטא-דאטה CoT (בטל סימון כדי לדלג)",
|
| 155 |
+
"cot_language_label": "שפת CoT",
|
| 156 |
+
"cot_language_info": "יצירת שפה ב-CoT (שרשרת מחשבה)",
|
| 157 |
+
"constrained_debug_label": "ניקוי באגים של פענוח מוגבל",
|
| 158 |
+
"constrained_debug_info": "הפעל לוגים של ניקוי באגים עבור פענוח מוגבל",
|
| 159 |
+
"auto_score_label": "דירוג אוטומטי",
|
| 160 |
+
"auto_score_info": "חשב אוטומטית ציוני איכות לכל קטעי האודיו שנוצרו",
|
| 161 |
+
"auto_lrc_label": "LRC אוטומטי",
|
| 162 |
+
"auto_lrc_info": "צור אוטומטית חותמות זמן למילים (LRC) לכל קטעי האודיו",
|
| 163 |
+
"lm_batch_chunk_label": "גודל מקטע מנת LM",
|
| 164 |
+
"lm_batch_chunk_info": "מקסימום פריטים למקטע מנת LM (ברירת מחדל: 8, מוגבל ע\"י זיכרון ה-GPU)",
|
| 165 |
+
"codes_strength_label": "חוזק קודי LM",
|
| 166 |
+
"codes_strength_info": "שליטה בכמות צעדי הניקוי מרעש המשתמשים בקודים שנוצרו ע\"י ה-LM",
|
| 167 |
+
"cover_strength_label": "חוזק כיסוי אודיו (Audio Cover)",
|
| 168 |
+
"cover_strength_info": "שליטה בכמות צעדי הניקוי מרעש המשתמשים במצב כיסוי",
|
| 169 |
+
"score_sensitivity_label": "רגישות ציון איכות",
|
| 170 |
+
"score_sensitivity_info": "נמוך יותר = רגיש יותר (ברירת מחדל: 1.0)",
|
| 171 |
+
"think_label": "חשיבה (Think)",
|
| 172 |
+
"parallel_thinking_label": "חשיבה מקבילית",
|
| 173 |
+
"generate_btn": "🎵 צור מוזיקה",
|
| 174 |
+
"autogen_label": "יצירה אוטומטית",
|
| 175 |
+
"caption_rewrite_label": "שכתוב תיאור"
|
| 176 |
+
},
|
| 177 |
+
"results": {
|
| 178 |
+
"title": "🎵 תוצאות",
|
| 179 |
+
"generated_music": "🎵 מוזיקה שנוצרה (דגימה {n})",
|
| 180 |
+
"send_to_src_btn": "🔗 שלח לאודיו מקור",
|
| 181 |
+
"save_btn": "💾 שמירה",
|
| 182 |
+
"score_btn": "📊 דירוג",
|
| 183 |
+
"lrc_btn": "🎵 LRC",
|
| 184 |
+
"quality_score_label": "ציון איכות (דגימה {n})",
|
| 185 |
+
"quality_score_placeholder": "לחץ על 'דירוג' לחישוב ציון איכות מבוסס מורכבות (Perplexity)",
|
| 186 |
+
"codes_label": "קודי LM (דגימה {n})",
|
| 187 |
+
"lrc_label": "חותמות זמן למילים (דגימה {n})",
|
| 188 |
+
"lrc_placeholder": "לחץ על 'LRC' ליצירת חותמות זמן",
|
| 189 |
+
"details_accordion": "📊 דירוג, LRC וקודי LM",
|
| 190 |
+
"generation_status": "מצב יצירה",
|
| 191 |
+
"current_batch": "מנה נוכחית",
|
| 192 |
+
"batch_indicator": "מנה {current} / {total}",
|
| 193 |
+
"next_batch_status": "מצב המנה הבאה",
|
| 194 |
+
"prev_btn": "◀ הקודם",
|
| 195 |
+
"next_btn": "הבא ▶",
|
| 196 |
+
"restore_params_btn": "↙️ החל הגדרות אלו על הממשק (שחזור פרמטרי מנה)",
|
| 197 |
+
"batch_results_title": "📁 תוצאות המנה ופרטי יצירה",
|
| 198 |
+
"all_files_label": "📁 כל הקבצים שנוצרו (הורדה)",
|
| 199 |
+
"generation_details": "פרטי יצירה"
|
| 200 |
+
},
|
| 201 |
+
"messages": {
|
| 202 |
+
"no_audio_to_save": "❌ אין אודיו לשמירה",
|
| 203 |
+
"save_success": "✅ האודיו והמטא-דאטה נשמרו ב-{filename}",
|
| 204 |
+
"save_failed": "❌ השמירה נכשלה: {error}",
|
| 205 |
+
"no_file_selected": "⚠️ לא נבחר קובץ",
|
| 206 |
+
"params_loaded": "✅ הפרמטרים נטענו מ-{filename}",
|
| 207 |
+
"invalid_json": "❌ קובץ JSON לא תקין: {error}",
|
| 208 |
+
"load_error": "❌ שגיאה בטעינת הקובץ: {error}",
|
| 209 |
+
"example_loaded": "📁 נטען דגם מ-{filename}",
|
| 210 |
+
"example_failed": "נכשל ניתוח קובץ ה-JSON ב-{filename}: {error}",
|
| 211 |
+
"example_error": "שגיאה בטעינת הדגם: {error}",
|
| 212 |
+
"lm_generated": "🤖 נוצר דגם באמצעות ה-LM",
|
| 213 |
+
"lm_fallback": "יצירת דגם באמצעות ה-LM נכשלה, חוזר לשימוש בספריית הדגמים",
|
| 214 |
+
"lm_not_initialized": "❌ 5Hz LM לא מאותחל. נא לאתחל אותו תחילה.",
|
| 215 |
+
"autogen_enabled": "🔄 יצירה אוטומטית הופעלה - המנה הבאה תיווצר לאחר מכן",
|
| 216 |
+
"batch_ready": "✅ מנה {n} מוכנה! לחץ על 'הבא' לצפייה.",
|
| 217 |
+
"batch_generating": "🔄 מתחיל יצירת רקע עבור מנה {n}...",
|
| 218 |
+
"batch_failed": "❌ יצירת הרקע נכשלה: {error}",
|
| 219 |
+
"viewing_batch": "✅ צופה במנה {n}",
|
| 220 |
+
"at_first_batch": "נמצא כבר במנה הראשונה",
|
| 221 |
+
"at_last_batch": "אין מנה באה זמינה",
|
| 222 |
+
"batch_not_found": "מנה {n} לא נמצאה בתור",
|
| 223 |
+
"no_batch_data": "לא נמצאו נתוני מנה לשחזור.",
|
| 224 |
+
"params_restored": "✅ פרמטרי הממשק שוחזרו ממנה {n}",
|
| 225 |
+
"scoring_failed": "❌ שגיאה: נתוני המנה לא נמצאו",
|
| 226 |
+
"no_codes": "❌ אין קודי אודיו זמינים. נא ליצור מוזיקה תחילה.",
|
| 227 |
+
"score_failed": "❌ הדירוג נכשל: {error}",
|
| 228 |
+
"score_error": "❌ שגיאה בחישוב הציון: {error}",
|
| 229 |
+
"lrc_no_batch_data": "❌ לא נמצאו נתוני מנה. נא ליצור מוזיקה תחילה.",
|
| 230 |
+
"lrc_no_extra_outputs": "❌ לא נמצאו פלטים נוספים. טנזורי התניה אינם זמינים.",
|
| 231 |
+
"lrc_missing_tensors": "❌ חסרים טנזורים נדרשים ליצירת LRC.",
|
| 232 |
+
"lrc_sample_not_exist": "❌ הדגימה אינה קיימת במנה הנוכחית.",
|
| 233 |
+
"lrc_empty_result": "⚠️ יצירת ה-LRC הפיקה תוצאה ריקה.",
|
| 234 |
+
"empty_query": "⚠️ נא להזין תיאור מוזיקלי.",
|
| 235 |
+
"sample_creation_failed": "❌ יצירת הדגימה נכשלה. נא לנסות שוב.",
|
| 236 |
+
"sample_created": "✅ הדגימה נוצרה! בדוק את התיאור והמילים, ולאחר מכן לחץ על 'צור מוזיקה'.",
|
| 237 |
+
"simple_examples_not_found": "⚠️ ספריית הדגמים של המצב הפשוט לא נמצאה.",
|
| 238 |
+
"simple_examples_empty": "⚠️ לא נמצאו קבצי דוגמה במצב פשוט.",
|
| 239 |
+
"simple_example_loaded": "🎲 נטענה דוגמה אקראית מ-{filename}",
|
| 240 |
+
"format_success": "✅ התיאור והמילים פורמטו בהצלחה",
|
| 241 |
+
"format_failed": "❌ הפירמוט נכשל: {error}",
|
| 242 |
+
"skipping_metas_cot": "⚡ מדלג על שלב 1 של מטא-דאטה COT (הדגימה כבר מפורמטת)",
|
| 243 |
+
"invalid_timesteps_format": "⚠️ פורמט צעדי זמן לא תקין. משתמש בלוח זמנים כברירת מחדל.",
|
| 244 |
+
"timesteps_out_of_range": "⚠️ צעדי הזמן חייבים להיות בטווח [0, 1]. משתמש בלוח זמנים כברירת מחדל.",
|
| 245 |
+
"timesteps_count_mismatch": "⚠️ מספר צעדי הזמן ({actual}) שונה מצעדי ההסקה ({expected}). משתמש במספר צעדי הזמן."
|
| 246 |
+
},
|
| 247 |
+
"training": {
|
| 248 |
+
"tab_title": "🎓 אימון LoRA",
|
| 249 |
+
"tab_dataset_builder": "📁 בונה מערך נתונים",
|
| 250 |
+
"tab_train_lora": "🚀 אימון LoRA",
|
| 251 |
+
"quick_start_title": "🚀 התחלה מהירה",
|
| 252 |
+
"load_dataset_label": "נתיב קובץ JSON של מערך הנתונים",
|
| 253 |
+
"load_dataset_info": "טעינת מערך נתונים שנשמר בעבר",
|
| 254 |
+
"load_btn": "📂 טעינה",
|
| 255 |
+
"load_status": "מצב טעינה",
|
| 256 |
+
"scan_label": "נתיב ספריית אודיו",
|
| 257 |
+
"scan_info": "סריקה אחר קבצי אודיו (wav, mp3, flac, ogg, opus)",
|
| 258 |
+
"scan_btn": "🔍 סריקה",
|
| 259 |
+
"scan_status": "מצב סריקה",
|
| 260 |
+
"found_audio_files": "קבצי אודיו שנמצאו",
|
| 261 |
+
"dataset_name": "שם מערך הנתונים",
|
| 262 |
+
"dataset_name_placeholder": "הזן שם למערך הנתונים",
|
| 263 |
+
"dataset_settings_header": "הגדרות מערך נתונים",
|
| 264 |
+
"tag_prepend": "הוספה בהתחלה (תגית, תיאור)",
|
| 265 |
+
"tag_append": "הוספה בסוף (תיאור, תגית)",
|
| 266 |
+
"tag_replace": "החלפת התיאור",
|
| 267 |
+
"step2_title": "שלב 2: תיוג אוטומטי באמצעות AI",
|
| 268 |
+
"step3_title": "שלב 3: תצוגה מקדימה ועריכה",
|
| 269 |
+
"step4_title": "שלב 4: שמירת מערך הנתונים",
|
| 270 |
+
"step5_title": "שלב 5: עיבוד מקדים לטנזורים (Tensors)",
|
| 271 |
+
"all_instrumental": "הכל אינסטרומנטלי",
|
| 272 |
+
"all_instrumental_info": "סמן אם כל הרצועות הן כליות (ללא שירה)",
|
| 273 |
+
"custom_tag": "תגית הפעלה מותאמת אישית",
|
| 274 |
+
"custom_tag_info": "תגית ייחודית להפעלת הסגנון של LoRA זו",
|
| 275 |
+
"tag_position": "מיקום התגית",
|
| 276 |
+
"tag_position_info": "היכן למקם את התגית המותאמת אישית בתוך התיאור",
|
| 277 |
+
"genre_ratio": "יחס ז'אנר (%)",
|
| 278 |
+
"genre_ratio_info": "0% = הכל תיאור, 100% = הכל ז'אנר. הגדרה פר-דגימה קודמת להגדרת הכלל.",
|
| 279 |
+
"skip_metas": "דלג על BPM/סולם/משקל",
|
| 280 |
+
"skip_metas_info": "דלג על יצירת BPM/סולם/משקל. התיאור והז'אנר עדיין ייווצרו על ידי ה-LLM.",
|
| 281 |
+
"only_unlabeled": "רק כאלו ללא תיוג",
|
| 282 |
+
"only_unlabeled_info": "תייג רק דגימות ללא תיאור (שימושי להמשך תיוג שנכשל)",
|
| 283 |
+
"auto_label_btn": "🏷️ תיוג אוטומטי של הכל",
|
| 284 |
+
"label_progress": "התקדמות התיוג",
|
| 285 |
+
"select_sample": "בחר דגימה #",
|
| 286 |
+
"select_sample_info": "בחר דגימה לצפייה ועריכה",
|
| 287 |
+
"audio_preview": "תצוגה מקדימה של אודיו",
|
| 288 |
+
"filename": "שם קובץ",
|
| 289 |
+
"caption": "תיאור",
|
| 290 |
+
"genre": "ז'אנר",
|
| 291 |
+
"prompt_override_label": "דריסת פרומפט (לדגימה זו)",
|
| 292 |
+
"prompt_override_info": "דריסת היחס הכללי עבור דגימה זו",
|
| 293 |
+
"lyrics_editable_label": "מילים (ניתן לעריכה, משמש לאימון)",
|
| 294 |
+
"raw_lyrics_label": "מילים גולמיות (מתוך קובץ .txt)",
|
| 295 |
+
"no_lyrics_placeholder": "(אין קובץ מילים .txt)",
|
| 296 |
+
"bpm": "BPM",
|
| 297 |
+
"key_label": "סולם (Key)",
|
| 298 |
+
"key_placeholder": "C Major",
|
| 299 |
+
"time_sig": "משקל מוזיקלי",
|
| 300 |
+
"duration_s": "משך (שניות)",
|
| 301 |
+
"language": "שפה",
|
| 302 |
+
"instrumental": "אינסטרומנטלי",
|
| 303 |
+
"save_changes_btn": "💾 שמירת שינויים",
|
| 304 |
+
"edit_status": "מצב עריכה",
|
| 305 |
+
"save_path": "נתיב שמירה",
|
| 306 |
+
"save_path_info": "הנתיב שבו יישמר קובץ ה-JSON של מערך הנתונים",
|
| 307 |
+
"save_dataset_btn": "💾 שמירת מערך נתונים",
|
| 308 |
+
"save_status": "מצב שמירה",
|
| 309 |
+
"load_existing_label": "טעינת מערך נתונים קיים (אופציונלי)",
|
| 310 |
+
"load_existing_info": "נתיב לקובץ JSON של מערך נתונים שנשמר בעבר",
|
| 311 |
+
"load_dataset_btn": "📂 טעינת מערך נתונים",
|
| 312 |
+
"tensor_output_dir": "ספריית פלט של טנזורים",
|
| 313 |
+
"tensor_output_info": "הספרייה לשמירת קבצי טנזור שעברו עיבוד מקדים",
|
| 314 |
+
"preprocess_btn": "⚡ עיבוד מקדים",
|
| 315 |
+
"preprocess_progress": "התקדמות עיבוד מקדים",
|
| 316 |
+
"preprocessed_tensors_dir": "ספריית טנזורים מעובדים",
|
| 317 |
+
"preprocessed_tensors_info": "ספרייה המכילה קבצי .pt של טנזורים מעובדים",
|
| 318 |
+
"train_section_tensors": "בחירת מערך נתונים מעובד",
|
| 319 |
+
"train_section_lora": "הגדרות LoRA",
|
| 320 |
+
"train_section_params": "פרמטרי אימון",
|
| 321 |
+
"dataset_info": "מידע על מערך הנתונים",
|
| 322 |
+
"lora_rank": "דרגת LoRA (Rank)",
|
| 323 |
+
"lora_rank_info": "גבוה יותר = יותר קיבולת, יותר זיכרון",
|
| 324 |
+
"lora_alpha": "LoRA Alpha",
|
| 325 |
+
"lora_alpha_info": "פקטור קנה מידה (בדרך כלל פי 2 מה-Rank)",
|
| 326 |
+
"lora_dropout": "LoRA Dropout",
|
| 327 |
+
"learning_rate": "קצב למידה (Learning Rate)",
|
| 328 |
+
"learning_rate_info": "התחל עם 3e-4, שנה במידת הצורך",
|
| 329 |
+
"max_epochs": "מקסימום תקופות (Epochs)",
|
| 330 |
+
"batch_size": "גודל מנה (Batch Size)",
|
| 331 |
+
"batch_size_info": "הגדל אם יש לך מספיק זיכרון גרפי (VRAM)",
|
| 332 |
+
"gradient_accumulation": "צבירת גרדיאנטים (Accumulation)",
|
| 333 |
+
"gradient_accumulation_info": "גודל מנה אפקטיבי = גודל מנה × צבירה",
|
| 334 |
+
"save_every_n_epochs": "שמור כל N תקופות (Epochs)",
|
| 335 |
+
"shift": "Shift (הסטה)",
|
| 336 |
+
"shift_info": "הסטת צעדי זמן עבור מודל turbo",
|
| 337 |
+
"seed": "גרעין (Seed)",
|
| 338 |
+
"output_dir": "ספריית פלט",
|
| 339 |
+
"output_dir_info": "ספרייה לשמירת משקולות ה-LoRA המאומנות",
|
| 340 |
+
"start_training_btn": "🚀 התחלת אימון",
|
| 341 |
+
"stop_training_btn": "⏹️ עצירת אימון",
|
| 342 |
+
"training_progress": "התקדמות האימון",
|
| 343 |
+
"training_log": "יומן אימון",
|
| 344 |
+
"training_loss_title": "הפסד אימון (Training Loss)",
|
| 345 |
+
"step": "צעד",
|
| 346 |
+
"loss": "הפסד (Loss)",
|
| 347 |
+
"export_header": "ייצוא LoRA",
|
| 348 |
+
"export_path": "נתיב ייצוא",
|
| 349 |
+
"export_lora_btn": "📦 ייצוא LoRA",
|
| 350 |
+
"export_status": "מצב ייצוא"
|
| 351 |
+
}
|
| 352 |
+
}
|
acestep/gradio_ui/i18n/ja.json
ADDED
|
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"app": {
|
| 3 |
+
"title": "🎛️ ACE-Step V1.5 プレイグラウンド💡",
|
| 4 |
+
"subtitle": "オープンソース音楽生成の限界を押し広げる"
|
| 5 |
+
},
|
| 6 |
+
"dataset": {
|
| 7 |
+
"title": "📊 データセットエクスプローラー",
|
| 8 |
+
"dataset_label": "データセット",
|
| 9 |
+
"dataset_info": "探索するデータセットを選択",
|
| 10 |
+
"import_btn": "📥 データセットをインポート",
|
| 11 |
+
"search_type_label": "検索タイプ",
|
| 12 |
+
"search_type_info": "アイテムの検索方法",
|
| 13 |
+
"search_value_label": "検索値",
|
| 14 |
+
"search_value_placeholder": "キーまたはインデックスを入力(空白の場合はランダム)",
|
| 15 |
+
"search_value_info": "キー: 完全一致、インデックス: 0からデータセットサイズ-1",
|
| 16 |
+
"instruction_label": "📝 指示",
|
| 17 |
+
"instruction_placeholder": "利用可能な指示がありません",
|
| 18 |
+
"metadata_title": "📋 アイテムメタデータ (JSON)",
|
| 19 |
+
"metadata_label": "完全なアイテム情報",
|
| 20 |
+
"source_audio": "ソースオーディオ",
|
| 21 |
+
"target_audio": "ターゲットオーディオ",
|
| 22 |
+
"reference_audio": "リファレンスオーディオ",
|
| 23 |
+
"get_item_btn": "🔍 アイテムを取得",
|
| 24 |
+
"use_src_checkbox": "データセットのソースオーディオを使用",
|
| 25 |
+
"use_src_info": "データセットのソースオーディオを使用する場合はチェック",
|
| 26 |
+
"data_status_label": "📊 データステータス",
|
| 27 |
+
"data_status_default": "❌ データセットがインポートされていません",
|
| 28 |
+
"autofill_btn": "📋 生成フォームを自動入力"
|
| 29 |
+
},
|
| 30 |
+
"service": {
|
| 31 |
+
"title": "🔧 サービス設定",
|
| 32 |
+
"checkpoint_label": "チェックポイントファイル",
|
| 33 |
+
"checkpoint_info": "訓練済みモデルのチェックポイントファイルを選択(フルパスまたはファイル名)",
|
| 34 |
+
"refresh_btn": "🔄 更新",
|
| 35 |
+
"model_path_label": "メインモデルパス",
|
| 36 |
+
"model_path_info": "モデル設定ディレクトリを選択(チェックポイントから自動スキャン)",
|
| 37 |
+
"device_label": "デバイス",
|
| 38 |
+
"device_info": "処理デバイス(自動検出を推奨)",
|
| 39 |
+
"lm_model_path_label": "5Hz LM モデルパス",
|
| 40 |
+
"lm_model_path_info": "5Hz LMモデルチェックポイントを選択(チェックポイントから自動スキャン)",
|
| 41 |
+
"backend_label": "5Hz LM バックエンド",
|
| 42 |
+
"backend_info": "5Hz LMのバックエンドを選択: vllm(高速)またはpt(PyTorch、より互換性あり)",
|
| 43 |
+
"init_llm_label": "5Hz LM を初期化",
|
| 44 |
+
"init_llm_info": "サービス初期化中に5Hz LMを初期化する場合はチェック",
|
| 45 |
+
"flash_attention_label": "Flash Attention を使用",
|
| 46 |
+
"flash_attention_info_enabled": "推論を高速化するためにflash attentionを有効にする(flash_attnパッケージが必要)",
|
| 47 |
+
"flash_attention_info_disabled": "Flash attentionは利用できません(flash_attnパッケージがインストールされていません)",
|
| 48 |
+
"offload_cpu_label": "CPUにオフロード",
|
| 49 |
+
"offload_cpu_info": "使用していない時にモデルをCPUにオフロードしてGPUメモリを節約",
|
| 50 |
+
"offload_dit_cpu_label": "DiTをCPUにオフロード",
|
| 51 |
+
"offload_dit_cpu_info": "DiTをCPUにオフロード(CPUへのオフロードが必要)",
|
| 52 |
+
"compile_model_label": "モデルをコンパイル",
|
| 53 |
+
"compile_model_info": "torch.compileでモデルを最適化(量子化に必要)",
|
| 54 |
+
"quantization_label": "INT8 量子化",
|
| 55 |
+
"quantization_info": "INT8重み量子化を有効にしてVRAMを節約(モデルのコンパイルが必要)",
|
| 56 |
+
"init_btn": "サービスを初期化",
|
| 57 |
+
"status_label": "ステータス",
|
| 58 |
+
"language_label": "UI言語",
|
| 59 |
+
"language_info": "インターフェース言語を選択"
|
| 60 |
+
},
|
| 61 |
+
"generation": {
|
| 62 |
+
"required_inputs": "📝 必須入力",
|
| 63 |
+
"task_type_label": "タスクタイプ",
|
| 64 |
+
"task_type_info": "生成のタスクタイプを選択",
|
| 65 |
+
"instruction_label": "指示",
|
| 66 |
+
"instruction_info": "指示はタスクタイプに基づいて自動生成されます",
|
| 67 |
+
"load_btn": "読み込む",
|
| 68 |
+
"track_name_label": "トラック名",
|
| 69 |
+
"track_name_info": "lego/extractタスクのトラック名を選択",
|
| 70 |
+
"track_classes_label": "トラック名",
|
| 71 |
+
"track_classes_info": "completeタスクの複数のトラッククラスを選択",
|
| 72 |
+
"audio_uploads": "🎵 オーディオアップロード",
|
| 73 |
+
"reference_audio": "リファレンスオーディオ(オプション)",
|
| 74 |
+
"source_audio": "ソースオーディオ(オプション)",
|
| 75 |
+
"convert_codes_btn": "コードに変換",
|
| 76 |
+
"lm_codes_hints": "🎼 LM コードヒント",
|
| 77 |
+
"lm_codes_label": "LM コードヒント",
|
| 78 |
+
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
|
| 79 |
+
"lm_codes_info": "text2music生成用のLMコードヒントを貼り付け",
|
| 80 |
+
"lm_codes_sample": "LM コードヒント(サンプル {n})",
|
| 81 |
+
"lm_codes_sample_info": "サンプル{n}のコード",
|
| 82 |
+
"transcribe_btn": "転写",
|
| 83 |
+
"repainting_controls": "🎨 再描画コントロール(秒)",
|
| 84 |
+
"repainting_start": "再描画開始",
|
| 85 |
+
"repainting_end": "再描画終了",
|
| 86 |
+
"mode_label": "生成モード",
|
| 87 |
+
"mode_info": "シンプル:自然言語で音楽を説明。カスタム:キャプションと歌詞を完全にコントロール。",
|
| 88 |
+
"mode_simple": "シンプル",
|
| 89 |
+
"mode_custom": "カスタム",
|
| 90 |
+
"simple_query_label": "曲の説明",
|
| 91 |
+
"simple_query_placeholder": "作成したい音楽を説明してください。例:'静かな夜のための優しいベンガルのラブソング'。空欄の場合はランダムなサンプルが生成されます。",
|
| 92 |
+
"simple_query_info": "生成したい音楽の自然言語の説明を入力",
|
| 93 |
+
"simple_vocal_language_label": "ボーカル言語(オプション)",
|
| 94 |
+
"simple_vocal_language_info": "歌詞の希望言語を選択。任意の言語の場合は'unknown'を使用。",
|
| 95 |
+
"create_sample_btn": "サンプル作成",
|
| 96 |
+
"caption_title": "📝 音楽キャプション",
|
| 97 |
+
"caption_label": "音楽キャプション(オプション)",
|
| 98 |
+
"caption_placeholder": "柔らかいボーカルを伴う穏やかなアコースティックギターのメロディー...",
|
| 99 |
+
"caption_info": "スタイル、ジャンル、楽器、ムードを説明",
|
| 100 |
+
"lyrics_title": "📝 歌詞",
|
| 101 |
+
"lyrics_label": "歌詞(オプション)",
|
| 102 |
+
"lyrics_placeholder": "[バース1]\\n星空の下で\\nとても生きていると感じる...",
|
| 103 |
+
"lyrics_info": "構造を持つ曲の歌詞",
|
| 104 |
+
"instrumental_label": "インストゥルメンタル",
|
| 105 |
+
"format_btn": "フォーマット",
|
| 106 |
+
"optional_params": "⚙️ オプションパラメータ",
|
| 107 |
+
"vocal_language_label": "ボーカル言語(オプション)",
|
| 108 |
+
"vocal_language_info": "インストには`unknown`を使用",
|
| 109 |
+
"bpm_label": "BPM(オプション)",
|
| 110 |
+
"bpm_info": "空白の場合はN/A",
|
| 111 |
+
"keyscale_label": "キースケール(オプション)",
|
| 112 |
+
"keyscale_placeholder": "空白の場合はN/A",
|
| 113 |
+
"keyscale_info": "A-G, #/♭, メジャー/マイナー",
|
| 114 |
+
"timesig_label": "拍子記号(オプション)",
|
| 115 |
+
"timesig_info": "2/4, 3/4, 4/4...",
|
| 116 |
+
"duration_label": "オーディオ長(秒)",
|
| 117 |
+
"duration_info": "ランダムの場合は-1を使用",
|
| 118 |
+
"batch_size_label": "バッチサイズ",
|
| 119 |
+
"batch_size_info": "生成するオーディオの数(最大8)",
|
| 120 |
+
"advanced_settings": "🔧 詳細設定",
|
| 121 |
+
"inference_steps_label": "DiT 推論ステップ",
|
| 122 |
+
"inference_steps_info": "Turbo: 最大8、Base: 最大200",
|
| 123 |
+
"guidance_scale_label": "DiT ガイダンススケール(baseモデルのみサポート)",
|
| 124 |
+
"guidance_scale_info": "値が高いほどテキストに忠実に従う",
|
| 125 |
+
"seed_label": "シード",
|
| 126 |
+
"seed_info": "バッチにはカンマ区切りの値を使用",
|
| 127 |
+
"random_seed_label": "ランダムシード",
|
| 128 |
+
"random_seed_info": "有効にすると自動的にシードを生成",
|
| 129 |
+
"audio_format_label": "オーディオフォーマット",
|
| 130 |
+
"audio_format_info": "保存ファイルのオーディオフォーマット",
|
| 131 |
+
"use_adg_label": "ADG を使用",
|
| 132 |
+
"use_adg_info": "角度ドメインガイダンスを有効化",
|
| 133 |
+
"shift_label": "シフト",
|
| 134 |
+
"shift_info": "baseモデル用タイムステップシフト係数 (範囲 1.0~5.0、デフォルト 3.0)。turboモデルには無効。",
|
| 135 |
+
"infer_method_label": "推論方法",
|
| 136 |
+
"infer_method_info": "拡散推論方法。ODE (オイラー) は高速、SDE (確率的) は異なる結果を生成する可能性があります。",
|
| 137 |
+
"custom_timesteps_label": "カスタムタイムステップ",
|
| 138 |
+
"custom_timesteps_info": "オプション:1.0から0.0へのカンマ区切り値(例:'0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')。推論ステップとシフトを上書きします。",
|
| 139 |
+
"cfg_interval_start": "CFG 間隔開始",
|
| 140 |
+
"cfg_interval_end": "CFG 間隔終了",
|
| 141 |
+
"lm_params_title": "🤖 LM 生成パラメータ",
|
| 142 |
+
"lm_temperature_label": "LM 温度",
|
| 143 |
+
"lm_temperature_info": "5Hz LM温度(高いほどランダム)",
|
| 144 |
+
"lm_cfg_scale_label": "LM CFG スケール",
|
| 145 |
+
"lm_cfg_scale_info": "5Hz LM CFG (1.0 = CFGなし)",
|
| 146 |
+
"lm_top_k_label": "LM Top-K",
|
| 147 |
+
"lm_top_k_info": "Top-K (0 = 無効)",
|
| 148 |
+
"lm_top_p_label": "LM Top-P",
|
| 149 |
+
"lm_top_p_info": "Top-P (1.0 = 無効)",
|
| 150 |
+
"lm_negative_prompt_label": "LM ネガティブプロンプト",
|
| 151 |
+
"lm_negative_prompt_placeholder": "CFGのネガティブプロンプトを入力(デフォルト: NO USER INPUT)",
|
| 152 |
+
"lm_negative_prompt_info": "ネガティブプロンプト(LM CFGスケール > 1.0の場合に使用)",
|
| 153 |
+
"cot_metas_label": "CoT メタデータ",
|
| 154 |
+
"cot_metas_info": "LMを使用してCoTメタデータを生成(チェックを外すとLM CoT生成をスキップ)",
|
| 155 |
+
"cot_language_label": "CoT 言語",
|
| 156 |
+
"cot_language_info": "CoTで言語を生成(思考の連鎖)",
|
| 157 |
+
"constrained_debug_label": "制約付きデコーディングデバッグ",
|
| 158 |
+
"constrained_debug_info": "制約付きデコーディングのデバッグログを有効化(チェックすると詳細ログを表示)",
|
| 159 |
+
"auto_score_label": "自動スコアリング",
|
| 160 |
+
"auto_score_info": "生成されたすべてのオーディオの品質スコアを自動計算",
|
| 161 |
+
"auto_lrc_label": "自動 LRC",
|
| 162 |
+
"auto_lrc_info": "生成されたすべてのオーディオのLRC歌詞タイムスタンプを自動生成",
|
| 163 |
+
"lm_batch_chunk_label": "LM バッチチャンクサイズ",
|
| 164 |
+
"lm_batch_chunk_info": "LMバッチチャンクあたりの最大アイテム数(デフォルト: 8、GPUメモリによる制限)",
|
| 165 |
+
"codes_strength_label": "LM コード強度",
|
| 166 |
+
"codes_strength_info": "LM生成コードを使用するデノイジングステップ数を制御",
|
| 167 |
+
"similarity_denoise_label": "類似度 / ノイズ除去",
|
| 168 |
+
"similarity_denoise_info": "出力が参照オーディオにどれだけ忠実かを制御します。高い値ほど構造を保持します。",
|
| 169 |
+
"cover_strength_label": "オーディオカバー強度",
|
| 170 |
+
"cover_strength_info": "カバーモードを使用するデノイジングステップ数を制御",
|
| 171 |
+
"score_sensitivity_label": "品質スコア感度",
|
| 172 |
+
"score_sensitivity_info": "低い = より敏感(デフォルト: 1.0)。PMIが[0,1]にマッピングする方法を調整",
|
| 173 |
+
"think_label": "思考",
|
| 174 |
+
"parallel_thinking_label": "並列思考",
|
| 175 |
+
"generate_btn": "🎵 音楽を生成",
|
| 176 |
+
"autogen_label": "自動生成",
|
| 177 |
+
"caption_rewrite_label": "キャプション書き換え"
|
| 178 |
+
},
|
| 179 |
+
"results": {
|
| 180 |
+
"title": "🎵 結果",
|
| 181 |
+
"generated_music": "🎵 生成された音楽(サンプル {n})",
|
| 182 |
+
"send_to_src_btn": "🔗 ソースオーディオに送信",
|
| 183 |
+
"save_btn": "💾 保存",
|
| 184 |
+
"score_btn": "📊 スコア",
|
| 185 |
+
"lrc_btn": "🎵 LRC",
|
| 186 |
+
"quality_score_label": "品質スコア(サンプル {n})",
|
| 187 |
+
"quality_score_placeholder": "'スコア'をクリックしてパープレキシティベースの品質スコアを計算",
|
| 188 |
+
"codes_label": "LM コード(サンプル {n})",
|
| 189 |
+
"lrc_label": "歌詞タイムスタンプ(サンプル {n})",
|
| 190 |
+
"lrc_placeholder": "'LRC'をクリックしてタイムスタンプを生成",
|
| 191 |
+
"details_accordion": "📊 スコア & LRC & LM コード",
|
| 192 |
+
"generation_status": "生成ステータス",
|
| 193 |
+
"current_batch": "現在のバッチ",
|
| 194 |
+
"batch_indicator": "バッチ {current} / {total}",
|
| 195 |
+
"next_batch_status": "次のバッチステータス",
|
| 196 |
+
"prev_btn": "◀ 前へ",
|
| 197 |
+
"next_btn": "次へ ▶",
|
| 198 |
+
"restore_params_btn": "↙️ これらの設定をUIに適用(バッチパラメータを復元)",
|
| 199 |
+
"batch_results_title": "👇 クリックしてバッチ結果と生成詳細を表示",
|
| 200 |
+
"all_files_label": "📁 すべての生成ファイル(ダウンロード)",
|
| 201 |
+
"generation_details": "生成詳細"
|
| 202 |
+
},
|
| 203 |
+
"messages": {
|
| 204 |
+
"no_audio_to_save": "❌ 保存するオーディオがありません",
|
| 205 |
+
"save_success": "✅ オーディオとメタデータを {filename} に保存しました",
|
| 206 |
+
"save_failed": "❌ 保存に失敗しました: {error}",
|
| 207 |
+
"no_file_selected": "⚠️ ファイルが選択されていません",
|
| 208 |
+
"params_loaded": "✅ {filename} からパラメータを読み込みました",
|
| 209 |
+
"invalid_json": "❌ 無効なJSONファイル: {error}",
|
| 210 |
+
"load_error": "❌ ファイルの読み込みエラー: {error}",
|
| 211 |
+
"example_loaded": "📁 {filename} からサンプルを読み込みました",
|
| 212 |
+
"example_failed": "JSONファイル {filename} の解析に失敗しました: {error}",
|
| 213 |
+
"example_error": "サンプル読み込みエラー: {error}",
|
| 214 |
+
"lm_generated": "🤖 LMを使用してサンプルを生成しました",
|
| 215 |
+
"lm_fallback": "LMを使用したサンプル生成に失敗、サンプルディレクトリにフォールバック",
|
| 216 |
+
"lm_not_initialized": "❌ 5Hz LMが初期化されていません。最初に初期化してください。",
|
| 217 |
+
"autogen_enabled": "🔄 自動生成が有効 - このあと次のバッチを生成します",
|
| 218 |
+
"batch_ready": "✅ バッチ {n} の準備完了!'次へ'をクリックして表示。",
|
| 219 |
+
"batch_generating": "🔄 バッチ {n} のバックグラウンド生成を開始...",
|
| 220 |
+
"batch_failed": "❌ バックグラウンド生成に失敗しました: {error}",
|
| 221 |
+
"viewing_batch": "✅ バッチ {n} を表示中",
|
| 222 |
+
"at_first_batch": "すでに最初のバッチです",
|
| 223 |
+
"at_last_batch": "次のバッチはありません",
|
| 224 |
+
"batch_not_found": "キューにバッチ {n} が見つかりません",
|
| 225 |
+
"no_batch_data": "復元するバッチデータがありません。",
|
| 226 |
+
"params_restored": "✅ バッチ {n} からUIパラメータを復元しました",
|
| 227 |
+
"scoring_failed": "❌ エラー: バッチデータが見つかりません",
|
| 228 |
+
"no_codes": "❌ 利用可能なオーディオコードがありません。最初に音楽を生成してください。",
|
| 229 |
+
"score_failed": "❌ スコアリングに失敗しました: {error}",
|
| 230 |
+
"score_error": "❌ スコア計算エラー: {error}",
|
| 231 |
+
"lrc_no_batch_data": "❌ バッチデータが見つかりません。最初に音楽を生成してください。",
|
| 232 |
+
"lrc_no_extra_outputs": "❌ 追加出力が見つかりません。条件テンソルが利用できません。",
|
| 233 |
+
"lrc_missing_tensors": "❌ LRC生成に必要なテンソルがありません。",
|
| 234 |
+
"lrc_sample_not_exist": "❌ 現在のバッチにサンプルが存在しません。",
|
| 235 |
+
"lrc_empty_result": "⚠️ LRC生成の結果が空です。",
|
| 236 |
+
"empty_query": "⚠️ 音楽の説明を入力してください。",
|
| 237 |
+
"sample_creation_failed": "❌ サンプルの作成に失敗しました。もう一度お試しください。",
|
| 238 |
+
"sample_created": "✅ サンプルが作成されました!キャプションと歌詞を確認して、'音楽を生成'をクリックしてください。",
|
| 239 |
+
"simple_examples_not_found": "⚠️ シンプルモードサンプルディレクトリが見つかりません。",
|
| 240 |
+
"simple_examples_empty": "⚠️ シンプルモードサンプルにファイルがありません。",
|
| 241 |
+
"simple_example_loaded": "🎲 {filename} からランダムサンプルを読み込みました",
|
| 242 |
+
"format_success": "✅ キャプションと歌詞のフォーマットに成功しました",
|
| 243 |
+
"format_failed": "❌ フォーマットに失敗しました: {error}",
|
| 244 |
+
"skipping_metas_cot": "⚡ Phase 1 メタデータ COT をスキップ(サンプルは既にフォーマット済み)",
|
| 245 |
+
"invalid_timesteps_format": "⚠️ タイムステップ形式が無効です。デフォルトスケジュールを使用します。",
|
| 246 |
+
"timesteps_out_of_range": "⚠️ タイムステップは [0, 1] の範囲内である必要があります。デフォルトスケジュールを使用します。",
|
| 247 |
+
"timesteps_count_mismatch": "⚠️ タイムステップ数 ({actual}) が推論ステップ数 ({expected}) と異なります。タイムステップ数を使用します。"
|
| 248 |
+
},
|
| 249 |
+
"training": {
|
| 250 |
+
"tab_title": "🎓 LoRA トレーニング",
|
| 251 |
+
"tab_dataset_builder": "📁 データセットビルダー",
|
| 252 |
+
"tab_train_lora": "🚀 LoRA をトレーニング",
|
| 253 |
+
"quick_start_title": "🚀 クイックスタート",
|
| 254 |
+
"load_dataset_label": "データセット JSON パス",
|
| 255 |
+
"load_dataset_info": "以前保存したデータセットを読み込む",
|
| 256 |
+
"load_btn": "📂 読み込み",
|
| 257 |
+
"load_status": "読み込み状態",
|
| 258 |
+
"scan_label": "オーディオディレクトリパス",
|
| 259 |
+
"scan_info": "オーディオファイルをスキャン(wav、mp3、flac、ogg、opus)",
|
| 260 |
+
"scan_btn": "🔍 スキャン",
|
| 261 |
+
"scan_status": "スキャン状態",
|
| 262 |
+
"found_audio_files": "見つかったオーディオファイル",
|
| 263 |
+
"dataset_name": "データセット名",
|
| 264 |
+
"dataset_name_placeholder": "データセット名を入力",
|
| 265 |
+
"dataset_settings_header": "データセット設定",
|
| 266 |
+
"tag_prepend": "前置(タグ、キャプション)",
|
| 267 |
+
"tag_append": "後置(キャプション、タグ)",
|
| 268 |
+
"tag_replace": "キャプションを置換",
|
| 269 |
+
"step2_title": "ステップ 2: AI で自動ラベル",
|
| 270 |
+
"step3_title": "ステップ 3: プレビューと編集",
|
| 271 |
+
"step4_title": "ステップ 4: データセットを保存",
|
| 272 |
+
"step5_title": "ステップ 5: テンソルに前処理",
|
| 273 |
+
"all_instrumental": "すべてインストゥルメンタル",
|
| 274 |
+
"all_instrumental_info": "すべてのトラックがインストゥルメンタル(ボーカルなし)の場合にチェック",
|
| 275 |
+
"custom_tag": "カスタムアクティベーションタグ",
|
| 276 |
+
"custom_tag_info": "この LoRA のスタイルを有効にする一意のタグ",
|
| 277 |
+
"tag_position": "タグの位置",
|
| 278 |
+
"tag_position_info": "キャプション内でカスタムタグを配置する位置",
|
| 279 |
+
"genre_ratio": "ジャンル比率 (%)",
|
| 280 |
+
"genre_ratio_info": "0%=すべてキャプション、100%=すべてジャンル。サンプル単位の上書きが優先。",
|
| 281 |
+
"skip_metas": "BPM/キー/拍子をスキップ",
|
| 282 |
+
"skip_metas_info": "BPM/キー/拍子の生成をスキップ。キャプションとジャンルは LM が生成。",
|
| 283 |
+
"only_unlabeled": "未ラベルのみ",
|
| 284 |
+
"only_unlabeled_info": "キャプションのないサンプルのみラベル付け(失敗したラベル付けの再開に便利)",
|
| 285 |
+
"auto_label_btn": "🏷️ 一括自動ラベル",
|
| 286 |
+
"label_progress": "ラベル付け進捗",
|
| 287 |
+
"select_sample": "サンプル # を選択",
|
| 288 |
+
"select_sample_info": "プレビューと編集するサンプルを選択",
|
| 289 |
+
"audio_preview": "オーディオプレビュー",
|
| 290 |
+
"filename": "ファイル名",
|
| 291 |
+
"caption": "キャプション",
|
| 292 |
+
"genre": "ジャンル",
|
| 293 |
+
"prompt_override_label": "プロンプト上書き(このサンプル)",
|
| 294 |
+
"prompt_override_info": "このサンプルのグローバル比率を上書き",
|
| 295 |
+
"lyrics_editable_label": "歌詞(編集可、トレーニング用)",
|
| 296 |
+
"raw_lyrics_label": "生歌詞(.txt ファイルから)",
|
| 297 |
+
"no_lyrics_placeholder": "(.txt 歌詞ファイルなし)",
|
| 298 |
+
"bpm": "BPM",
|
| 299 |
+
"key_label": "キー",
|
| 300 |
+
"key_placeholder": "C Major",
|
| 301 |
+
"time_sig": "拍子",
|
| 302 |
+
"duration_s": "長さ (秒)",
|
| 303 |
+
"language": "言語",
|
| 304 |
+
"instrumental": "インストゥルメンタル",
|
| 305 |
+
"save_changes_btn": "💾 変更を保存",
|
| 306 |
+
"edit_status": "編集状態",
|
| 307 |
+
"save_path": "保存パス",
|
| 308 |
+
"save_path_info": "データセット JSON の保存先パス",
|
| 309 |
+
"save_dataset_btn": "💾 データセットを保存",
|
| 310 |
+
"save_status": "保存状態",
|
| 311 |
+
"load_existing_label": "既存データセットを読み込み(任意)",
|
| 312 |
+
"load_existing_info": "以前保存したデータセット JSON ファイルのパス",
|
| 313 |
+
"load_dataset_btn": "📂 データセットを読み込み",
|
| 314 |
+
"tensor_output_dir": "テンソル出力ディレクトリ",
|
| 315 |
+
"tensor_output_info": "前処理済みテンソルファイルの保存先ディレクトリ",
|
| 316 |
+
"preprocess_btn": "⚡ 前処理",
|
| 317 |
+
"preprocess_progress": "前処理進捗",
|
| 318 |
+
"preprocessed_tensors_dir": "前処理済みテンソルディレクトリ",
|
| 319 |
+
"preprocessed_tensors_info": "前処理済み .pt テンソルファイルを含むディレクトリ",
|
| 320 |
+
"train_section_tensors": "前処理済みデータセット選択",
|
| 321 |
+
"train_section_lora": "LoRA 設定",
|
| 322 |
+
"train_section_params": "トレーニングパラメータ",
|
| 323 |
+
"dataset_info": "データセット情報",
|
| 324 |
+
"lora_rank": "LoRA ランク (r)",
|
| 325 |
+
"lora_rank_info": "高いほど容量は増えるがメモリ使用量も増加",
|
| 326 |
+
"lora_alpha": "LoRA Alpha",
|
| 327 |
+
"lora_alpha_info": "スケーリング係数(通常はランクの2倍)",
|
| 328 |
+
"lora_dropout": "LoRA Dropout",
|
| 329 |
+
"learning_rate": "学習率",
|
| 330 |
+
"learning_rate_info": "3e-4 から始め、必要に応じて調整",
|
| 331 |
+
"max_epochs": "最大エポック数",
|
| 332 |
+
"batch_size": "バッチサイズ",
|
| 333 |
+
"batch_size_info": "VRAM に余裕があれば増やせます",
|
| 334 |
+
"gradient_accumulation": "勾配累積",
|
| 335 |
+
"gradient_accumulation_info": "実効バッチ = batch_size × 累積",
|
| 336 |
+
"save_every_n_epochs": "N エポックごとに保存",
|
| 337 |
+
"shift": "Shift",
|
| 338 |
+
"shift_info": "ターボモデル用タイムステップシフト",
|
| 339 |
+
"seed": "シード",
|
| 340 |
+
"output_dir": "出力ディレクトリ",
|
| 341 |
+
"output_dir_info": "トレーニング済み LoRA 重みの保存先ディレクトリ",
|
| 342 |
+
"start_training_btn": "🚀 トレーニング開始",
|
| 343 |
+
"stop_training_btn": "⏹️ トレーニング停止",
|
| 344 |
+
"training_progress": "トレーニング進捗",
|
| 345 |
+
"training_log": "トレーニングログ",
|
| 346 |
+
"training_loss_title": "トレーニング損失",
|
| 347 |
+
"step": "ステップ",
|
| 348 |
+
"loss": "損失",
|
| 349 |
+
"export_header": "LoRA をエクスポート",
|
| 350 |
+
"export_path": "エクスポートパス",
|
| 351 |
+
"export_lora_btn": "📦 LoRA をエクスポート",
|
| 352 |
+
"export_status": "エクスポート状態"
|
| 353 |
+
}
|
| 354 |
+
}
|
acestep/gradio_ui/i18n/zh.json
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"app": {
|
| 3 |
+
"title": "🎛️ ACE-Step V1.5 演练场💡",
|
| 4 |
+
"subtitle": "推动开源音乐生成的边界"
|
| 5 |
+
},
|
| 6 |
+
"dataset": {
|
| 7 |
+
"title": "📊 数据集浏览器",
|
| 8 |
+
"dataset_label": "数据集",
|
| 9 |
+
"dataset_info": "选择要浏览的数据集",
|
| 10 |
+
"import_btn": "📥 导入数据集",
|
| 11 |
+
"search_type_label": "搜索类型",
|
| 12 |
+
"search_type_info": "如何查找项目",
|
| 13 |
+
"search_value_label": "搜索值",
|
| 14 |
+
"search_value_placeholder": "输入键或索引(留空表示随机)",
|
| 15 |
+
"search_value_info": "键: 精确匹配, 索引: 0到数据集大小-1",
|
| 16 |
+
"instruction_label": "📝 指令",
|
| 17 |
+
"instruction_placeholder": "无可用指令",
|
| 18 |
+
"metadata_title": "📋 项目元数据 (JSON)",
|
| 19 |
+
"metadata_label": "完整项目信息",
|
| 20 |
+
"source_audio": "源音频",
|
| 21 |
+
"target_audio": "目标音频",
|
| 22 |
+
"reference_audio": "参考音频",
|
| 23 |
+
"get_item_btn": "🔍 获取项目",
|
| 24 |
+
"use_src_checkbox": "使用数据集中的源音频",
|
| 25 |
+
"use_src_info": "勾选以使用数据集中的源音频",
|
| 26 |
+
"data_status_label": "📊 数据状态",
|
| 27 |
+
"data_status_default": "❌ 未导入数据集",
|
| 28 |
+
"autofill_btn": "📋 自动填充生成表单"
|
| 29 |
+
},
|
| 30 |
+
"service": {
|
| 31 |
+
"title": "🔧 服务配置",
|
| 32 |
+
"checkpoint_label": "检查点文件",
|
| 33 |
+
"checkpoint_info": "选择训练好的模型检查点文件(完整路径或文件名)",
|
| 34 |
+
"refresh_btn": "🔄 刷新",
|
| 35 |
+
"model_path_label": "主模型路径",
|
| 36 |
+
"model_path_info": "选择模型配置目录(从检查点自动扫描)",
|
| 37 |
+
"device_label": "设备",
|
| 38 |
+
"device_info": "处理设备(建议自动检测)",
|
| 39 |
+
"lm_model_path_label": "5Hz LM 模型路径",
|
| 40 |
+
"lm_model_path_info": "选择5Hz LM模型检查点(从检查点自动扫描)",
|
| 41 |
+
"backend_label": "5Hz LM 后端",
|
| 42 |
+
"backend_info": "选择5Hz LM的后端: vllm(更快)或pt(PyTorch, 更兼容)",
|
| 43 |
+
"init_llm_label": "初始化 5Hz LM",
|
| 44 |
+
"init_llm_info": "勾选以在服务初始化期间初始化5Hz LM",
|
| 45 |
+
"flash_attention_label": "使用Flash Attention",
|
| 46 |
+
"flash_attention_info_enabled": "启用flash attention以加快推理速度(需要flash_attn包)",
|
| 47 |
+
"flash_attention_info_disabled": "Flash attention不可用(未安装flash_attn包)",
|
| 48 |
+
"offload_cpu_label": "卸载到CPU",
|
| 49 |
+
"offload_cpu_info": "不使用时将模型卸载到CPU以节省GPU内存",
|
| 50 |
+
"offload_dit_cpu_label": "将DiT卸载到CPU",
|
| 51 |
+
"offload_dit_cpu_info": "将DiT卸载到CPU(需要启用卸载到CPU)",
|
| 52 |
+
"compile_model_label": "编译模型",
|
| 53 |
+
"compile_model_info": "使用 torch.compile 优化模型(量化必需)",
|
| 54 |
+
"quantization_label": "INT8 量化",
|
| 55 |
+
"quantization_info": "启用 INT8 仅权重量化以减少显存占用(需要启用编译模型)",
|
| 56 |
+
"init_btn": "初始化服务",
|
| 57 |
+
"status_label": "状态",
|
| 58 |
+
"language_label": "界面语言",
|
| 59 |
+
"language_info": "选择界面语言"
|
| 60 |
+
},
|
| 61 |
+
"generation": {
|
| 62 |
+
"required_inputs": "📝 必需输入",
|
| 63 |
+
"task_type_label": "任务类型",
|
| 64 |
+
"task_type_info": "选择生成的任务类型",
|
| 65 |
+
"instruction_label": "指令",
|
| 66 |
+
"instruction_info": "指令根据任务类型自动生成",
|
| 67 |
+
"load_btn": "加载",
|
| 68 |
+
"track_name_label": "音轨名称",
|
| 69 |
+
"track_name_info": "为lego/extract任务选择音轨名称",
|
| 70 |
+
"track_classes_label": "音轨名称",
|
| 71 |
+
"track_classes_info": "为complete任务选择多个音轨类别",
|
| 72 |
+
"audio_uploads": "🎵 音频上传",
|
| 73 |
+
"reference_audio": "参考音频(可选)",
|
| 74 |
+
"source_audio": "源音频(可选)",
|
| 75 |
+
"convert_codes_btn": "转换为代码",
|
| 76 |
+
"lm_codes_hints": "🎼 LM 代码提示",
|
| 77 |
+
"lm_codes_label": "LM 代码提示",
|
| 78 |
+
"lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
|
| 79 |
+
"lm_codes_info": "粘贴用于text2music生成的LM代码提示",
|
| 80 |
+
"lm_codes_sample": "LM 代码提示(样本 {n})",
|
| 81 |
+
"lm_codes_sample_info": "样本{n}的代码",
|
| 82 |
+
"transcribe_btn": "转录",
|
| 83 |
+
"repainting_controls": "🎨 重绘控制(秒)",
|
| 84 |
+
"repainting_start": "重绘开始",
|
| 85 |
+
"repainting_end": "重绘结束",
|
| 86 |
+
"mode_label": "生成模式",
|
| 87 |
+
"mode_info": "简单模式:用自然语言描述音乐。自定义模式:完全控制描述和歌词。",
|
| 88 |
+
"mode_simple": "简单",
|
| 89 |
+
"mode_custom": "自定义",
|
| 90 |
+
"simple_query_label": "歌曲描述",
|
| 91 |
+
"simple_query_placeholder": "描述你想创作的音乐,例如:'给我生成一首暗黑的戏剧古风,歌词要华丽'。留空则随机生成样本。",
|
| 92 |
+
"simple_query_info": "输入你想生成的音乐的自然语言描述",
|
| 93 |
+
"simple_vocal_language_label": "人声语言(可选)",
|
| 94 |
+
"simple_vocal_language_info": "选择歌词的首选语言。使用 'unknown' 表示任意语言。",
|
| 95 |
+
"create_sample_btn": "创建样本",
|
| 96 |
+
"caption_title": "📝 音乐描述",
|
| 97 |
+
"caption_label": "音乐描述(可选)",
|
| 98 |
+
"caption_placeholder": "一段平和的原声吉他旋律,配有柔和的人声...",
|
| 99 |
+
"caption_info": "描述风格、流派、乐器、情绪",
|
| 100 |
+
"lyrics_title": "📝 歌词",
|
| 101 |
+
"lyrics_label": "歌词(可选)",
|
| 102 |
+
"lyrics_placeholder": "[第一段]\\n在星空下\\n我感到如此活跃...",
|
| 103 |
+
"lyrics_info": "带有结构的歌曲歌词",
|
| 104 |
+
"instrumental_label": "纯音乐",
|
| 105 |
+
"format_btn": "格式化",
|
| 106 |
+
"optional_params": "⚙️ 可选参数",
|
| 107 |
+
"vocal_language_label": "人声语言(可选)",
|
| 108 |
+
"vocal_language_info": "纯音乐使用 `unknown`",
|
| 109 |
+
"bpm_label": "BPM(可选)",
|
| 110 |
+
"bpm_info": "留空表示N/A",
|
| 111 |
+
"keyscale_label": "调性(可选)",
|
| 112 |
+
"keyscale_placeholder": "留空表示N/A",
|
| 113 |
+
"keyscale_info": "A-G, #/♭, 大调/小调",
|
| 114 |
+
"timesig_label": "拍号(可选)",
|
| 115 |
+
"timesig_info": "2/4, 3/4, 4/4...",
|
| 116 |
+
"duration_label": "音频时长(秒)",
|
| 117 |
+
"duration_info": "使用-1表示随机",
|
| 118 |
+
"batch_size_label": "批量大小",
|
| 119 |
+
"batch_size_info": "要生成的音频数量(最多8个)",
|
| 120 |
+
"advanced_settings": "🔧 高级设置",
|
| 121 |
+
"inference_steps_label": "DiT 推理步数",
|
| 122 |
+
"inference_steps_info": "Turbo: 最多8, Base: 最多200",
|
| 123 |
+
"guidance_scale_label": "DiT 引导比例(仅支持base模型)",
|
| 124 |
+
"guidance_scale_info": "更高的值更紧密地遵循文本",
|
| 125 |
+
"seed_label": "种子",
|
| 126 |
+
"seed_info": "批量使用逗号分隔的值",
|
| 127 |
+
"random_seed_label": "随机种子",
|
| 128 |
+
"random_seed_info": "启用以自动生成种子",
|
| 129 |
+
"audio_format_label": "音频格式",
|
| 130 |
+
"audio_format_info": "保存文件的音频格式",
|
| 131 |
+
"use_adg_label": "使用 ADG",
|
| 132 |
+
"use_adg_info": "启用角域引导",
|
| 133 |
+
"shift_label": "Shift",
|
| 134 |
+
"shift_info": "时间步偏移因子,仅对 base 模型生效 (范围 1.0~5.0,默认 3.0)。对 turbo 模型无效。",
|
| 135 |
+
"infer_method_label": "推理方法",
|
| 136 |
+
"infer_method_info": "扩散推理方法。ODE (欧拉) 更快,SDE (随机) 可能产生不同结果。",
|
| 137 |
+
"custom_timesteps_label": "自定义时间步",
|
| 138 |
+
"custom_timesteps_info": "可选:从 1.0 到 0.0 的逗号分隔值(例如 '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')。会覆盖推理步数和 shift 设置。",
|
| 139 |
+
"cfg_interval_start": "CFG 间隔开始",
|
| 140 |
+
"cfg_interval_end": "CFG 间隔结束",
|
| 141 |
+
"lm_params_title": "🤖 LM 生成参数",
|
| 142 |
+
"lm_temperature_label": "LM 温度",
|
| 143 |
+
"lm_temperature_info": "5Hz LM温度(越高越随机)",
|
| 144 |
+
"lm_cfg_scale_label": "LM CFG 比例",
|
| 145 |
+
"lm_cfg_scale_info": "5Hz LM CFG (1.0 = 无CFG)",
|
| 146 |
+
"lm_top_k_label": "LM Top-K",
|
| 147 |
+
"lm_top_k_info": "Top-K (0 = 禁用)",
|
| 148 |
+
"lm_top_p_label": "LM Top-P",
|
| 149 |
+
"lm_top_p_info": "Top-P (1.0 = 禁用)",
|
| 150 |
+
"lm_negative_prompt_label": "LM 负面提示",
|
| 151 |
+
"lm_negative_prompt_placeholder": "输入CFG的负面提示(默认: NO USER INPUT)",
|
| 152 |
+
"lm_negative_prompt_info": "负面提示(当LM CFG比例 > 1.0时使用)",
|
| 153 |
+
"cot_metas_label": "CoT 元数据",
|
| 154 |
+
"cot_metas_info": "使用LM生成CoT元数据(取消勾选以跳过LM CoT生成)",
|
| 155 |
+
"cot_language_label": "CoT 语言",
|
| 156 |
+
"cot_language_info": "在CoT中生成语言(思维链)",
|
| 157 |
+
"constrained_debug_label": "约束解码调试",
|
| 158 |
+
"constrained_debug_info": "启用约束解码的调试日志(勾选以查看详细日志)",
|
| 159 |
+
"auto_score_label": "自动评分",
|
| 160 |
+
"auto_score_info": "自动计算所有生成音频的质量分数",
|
| 161 |
+
"auto_lrc_label": "自动 LRC",
|
| 162 |
+
"auto_lrc_info": "自动为所有生成的音频生成LRC歌词时间戳",
|
| 163 |
+
"lm_batch_chunk_label": "LM 批量块大小",
|
| 164 |
+
"lm_batch_chunk_info": "每个LM批量块的最大项目数(默认: 8, 受GPU内存限制)",
|
| 165 |
+
"codes_strength_label": "LM 代码强度",
|
| 166 |
+
"codes_strength_info": "控制使用LM生成代码的去噪步骤数量",
|
| 167 |
+
"similarity_denoise_label": "相似度 / 降噪",
|
| 168 |
+
"similarity_denoise_info": "控制输出与参考音频的贴合程度。数值越高保留越多结构。",
|
| 169 |
+
"cover_strength_label": "音频覆盖强度",
|
| 170 |
+
"cover_strength_info": "控制使用覆盖模式的去噪步骤数量",
|
| 171 |
+
"score_sensitivity_label": "质量评分敏感度",
|
| 172 |
+
"score_sensitivity_info": "更低 = 更敏感(默认: 1.0). 调整PMI如何映射到[0,1]",
|
| 173 |
+
"think_label": "思考",
|
| 174 |
+
"parallel_thinking_label": "并行思考",
|
| 175 |
+
"generate_btn": "🎵 生成音乐",
|
| 176 |
+
"autogen_label": "自动生成",
|
| 177 |
+
"caption_rewrite_label": "描述重写"
|
| 178 |
+
},
|
| 179 |
+
"results": {
|
| 180 |
+
"title": "🎵 结果",
|
| 181 |
+
"generated_music": "🎵 生成的音乐(样本 {n})",
|
| 182 |
+
"send_to_src_btn": "🔗 发送到源音频",
|
| 183 |
+
"save_btn": "💾 保存",
|
| 184 |
+
"score_btn": "📊 评分",
|
| 185 |
+
"lrc_btn": "🎵 LRC",
|
| 186 |
+
"quality_score_label": "质量分数(样本 {n})",
|
| 187 |
+
"quality_score_placeholder": "点击'评分'以计算基于困惑度的质量分数",
|
| 188 |
+
"codes_label": "LM 代码(样本 {n})",
|
| 189 |
+
"lrc_label": "歌词时间戳(样本 {n})",
|
| 190 |
+
"lrc_placeholder": "点击'LRC'生成时间戳",
|
| 191 |
+
"details_accordion": "📊 评分与LRC与LM代码",
|
| 192 |
+
"generation_status": "生成状态",
|
| 193 |
+
"current_batch": "当前批次",
|
| 194 |
+
"batch_indicator": "批次 {current} / {total}",
|
| 195 |
+
"next_batch_status": "下一批次状态",
|
| 196 |
+
"prev_btn": "◀ 上一个",
|
| 197 |
+
"next_btn": "下一个 ▶",
|
| 198 |
+
"restore_params_btn": "↙️ 将这些设置应用到UI(恢复批次参数)",
|
| 199 |
+
"batch_results_title": "👇 点击查看批量结果和生成详情",
|
| 200 |
+
"all_files_label": "📁 所有生成的文件(下载)",
|
| 201 |
+
"generation_details": "生成详情"
|
| 202 |
+
},
|
| 203 |
+
"messages": {
|
| 204 |
+
"no_audio_to_save": "❌ 没有要保存的音频",
|
| 205 |
+
"save_success": "✅ 已将音频和元数据保存到 {filename}",
|
| 206 |
+
"save_failed": "❌ 保存失败: {error}",
|
| 207 |
+
"no_file_selected": "⚠️ 未选择文件",
|
| 208 |
+
"params_loaded": "✅ 已从 {filename} 加载参数",
|
| 209 |
+
"invalid_json": "❌ 无效的JSON文件: {error}",
|
| 210 |
+
"load_error": "❌ 加载文件时出错: {error}",
|
| 211 |
+
"example_loaded": "📁 已从 {filename} 加载示例",
|
| 212 |
+
"example_failed": "解析JSON文件 {filename} 失败: {error}",
|
| 213 |
+
"example_error": "加载示例时出错: {error}",
|
| 214 |
+
"lm_generated": "🤖 使用LM生成的示例",
|
| 215 |
+
"lm_fallback": "使用LM生成示例失败,回退到示例目录",
|
| 216 |
+
"lm_not_initialized": "❌ 5Hz LM未初始化。请先初始化它。",
|
| 217 |
+
"autogen_enabled": "🔄 已启用自动生成 - 下一批次将在此之后生成",
|
| 218 |
+
"batch_ready": "✅ 批次 {n} 就绪!点击'下一个'查看。",
|
| 219 |
+
"batch_generating": "🔄 开始为批次 {n} 进行后台生成...",
|
| 220 |
+
"batch_failed": "❌ 后台生成失败: {error}",
|
| 221 |
+
"viewing_batch": "✅ 查看批次 {n}",
|
| 222 |
+
"at_first_batch": "已在第一批次",
|
| 223 |
+
"at_last_batch": "没有下一批次可用",
|
| 224 |
+
"batch_not_found": "在队列中未找到批次 {n}",
|
| 225 |
+
"no_batch_data": "没有要恢复的批次数据。",
|
| 226 |
+
"params_restored": "✅ 已从批次 {n} 恢复UI参数",
|
| 227 |
+
"scoring_failed": "❌ 错误: 未找到批次数据",
|
| 228 |
+
"no_codes": "❌ 没有可用的音频代码。请先生成音乐。",
|
| 229 |
+
"score_failed": "❌ 评分失败: {error}",
|
| 230 |
+
"score_error": "❌ 计算分数时出错: {error}",
|
| 231 |
+
"lrc_no_batch_data": "❌ 未找到批次数据。请先生成音乐。",
|
| 232 |
+
"lrc_no_extra_outputs": "❌ 未找到额外输出。条件张量不可用。",
|
| 233 |
+
"lrc_missing_tensors": "❌ 缺少LRC生成所需的张量。",
|
| 234 |
+
"lrc_sample_not_exist": "❌ 当前批次中不存在该样本。",
|
| 235 |
+
"lrc_empty_result": "⚠️ LRC生成结果为空。",
|
| 236 |
+
"empty_query": "⚠️ 请输入音乐描述。",
|
| 237 |
+
"sample_creation_failed": "❌ 创建样本失败。请重试。",
|
| 238 |
+
"sample_created": "✅ 样本已创建!检查描述和歌词,然后点击生成音乐。",
|
| 239 |
+
"simple_examples_not_found": "⚠️ 未找到简单模式示例目录。",
|
| 240 |
+
"simple_examples_empty": "⚠️ 简单模式示例中没有示例文件。",
|
| 241 |
+
"simple_example_loaded": "🎲 已从 {filename} 加载随机示例",
|
| 242 |
+
"format_success": "✅ 描述和歌词格式化成功",
|
| 243 |
+
"format_failed": "❌ 格式化失败: {error}",
|
| 244 |
+
"skipping_metas_cot": "⚡ 跳过 Phase 1 元数据 COT(样本已格式化)",
|
| 245 |
+
"invalid_timesteps_format": "⚠️ 时间步格式无效,使用默认调度。",
|
| 246 |
+
"timesteps_out_of_range": "⚠️ 时间步必须在 [0, 1] 范围内,使用默认调度。",
|
| 247 |
+
"timesteps_count_mismatch": "⚠️ 时间步数量 ({actual}) 与推理步数 ({expected}) 不匹配,将使用时间步数量。"
|
| 248 |
+
},
|
| 249 |
+
"training": {
|
| 250 |
+
"tab_title": "🎓 LoRA 训练",
|
| 251 |
+
"tab_dataset_builder": "📁 数据集构建",
|
| 252 |
+
"tab_train_lora": "🚀 训练 LoRA",
|
| 253 |
+
"quick_start_title": "🚀 快速开始",
|
| 254 |
+
"load_dataset_label": "数据集 JSON 路径",
|
| 255 |
+
"load_btn": "📂 加载",
|
| 256 |
+
"load_status": "加载状态",
|
| 257 |
+
"scan_label": "音频目录路径",
|
| 258 |
+
"scan_info": "扫描音频文件(wav、mp3、flac、ogg、opus)",
|
| 259 |
+
"scan_btn": "🔍 扫描",
|
| 260 |
+
"scan_status": "扫描状态",
|
| 261 |
+
"found_audio_files": "已找到的音频文件",
|
| 262 |
+
"dataset_name": "数据集名称",
|
| 263 |
+
"dataset_name_placeholder": "输入数据集名称",
|
| 264 |
+
"dataset_settings_header": "数据集设置",
|
| 265 |
+
"tag_prepend": "前置(标签,描述)",
|
| 266 |
+
"tag_append": "后置(描述,标签)",
|
| 267 |
+
"tag_replace": "替换描述",
|
| 268 |
+
"step2_title": "步骤 2:AI 自动标注",
|
| 269 |
+
"step3_title": "步骤 3:预览与编辑",
|
| 270 |
+
"step4_title": "步骤 4:保存数据集",
|
| 271 |
+
"step5_title": "步骤 5:预处理为张量",
|
| 272 |
+
"all_instrumental": "全部为纯音乐",
|
| 273 |
+
"all_instrumental_info": "勾选表示所有曲目均为纯音乐(无人声)",
|
| 274 |
+
"custom_tag": "自定义激活标签",
|
| 275 |
+
"custom_tag_info": "用于激活此 LoRA 风格的唯一标签",
|
| 276 |
+
"tag_position": "标签位置",
|
| 277 |
+
"tag_position_info": "在描述中放置自定义标签的位置",
|
| 278 |
+
"genre_ratio": "风格比例 (%)",
|
| 279 |
+
"genre_ratio_info": "0%=全部描述,100%=全部风格。单样本覆盖优先。",
|
| 280 |
+
"skip_metas": "跳过 BPM/调性/拍号",
|
| 281 |
+
"skip_metas_info": "跳过 BPM/调性/拍号生成。描述和风格仍由 LM 生成。",
|
| 282 |
+
"only_unlabeled": "仅未标注",
|
| 283 |
+
"only_unlabeled_info": "仅标注无描述的样本(用于继续失败的标注)",
|
| 284 |
+
"auto_label_btn": "🏷️ 自动标注全部",
|
| 285 |
+
"label_progress": "标注进度",
|
| 286 |
+
"select_sample": "选择样本 #",
|
| 287 |
+
"select_sample_info": "选择要预览和编辑的样本",
|
| 288 |
+
"audio_preview": "音频预览",
|
| 289 |
+
"filename": "文件名",
|
| 290 |
+
"caption": "描述",
|
| 291 |
+
"genre": "风格",
|
| 292 |
+
"prompt_override_label": "提示覆盖(本样本)",
|
| 293 |
+
"prompt_override_info": "覆盖本样本的全局比例",
|
| 294 |
+
"lyrics_editable_label": "歌词(可编辑,用于训练)",
|
| 295 |
+
"raw_lyrics_label": "原始歌词(来自 .txt 文件)",
|
| 296 |
+
"no_lyrics_placeholder": "(无 .txt 歌词文件)",
|
| 297 |
+
"bpm": "BPM",
|
| 298 |
+
"key_label": "调性",
|
| 299 |
+
"key_placeholder": "C 大调",
|
| 300 |
+
"time_sig": "拍号",
|
| 301 |
+
"duration_s": "时长 (秒)",
|
| 302 |
+
"language": "语言",
|
| 303 |
+
"instrumental": "纯音乐",
|
| 304 |
+
"save_changes_btn": "💾 保存更改",
|
| 305 |
+
"edit_status": "编辑状态",
|
| 306 |
+
"save_path": "保存路径",
|
| 307 |
+
"save_path_info": "数据集 JSON 的保存路径",
|
| 308 |
+
"save_dataset_btn": "💾 保存数据集",
|
| 309 |
+
"save_status": "保存状态",
|
| 310 |
+
"load_existing_label": "加载已有数据集(可选)",
|
| 311 |
+
"load_existing_info": "之前保存的数据集 JSON 文件路径",
|
| 312 |
+
"load_dataset_btn": "📂 加载数据集",
|
| 313 |
+
"tensor_output_dir": "张量输出目录",
|
| 314 |
+
"tensor_output_info": "保存预处理张量文件的目录",
|
| 315 |
+
"preprocess_btn": "⚡ 预处理",
|
| 316 |
+
"preprocess_progress": "预处理进度",
|
| 317 |
+
"preprocessed_tensors_dir": "预处理张量目录",
|
| 318 |
+
"preprocessed_tensors_info": "包含预处理 .pt 张量文件的目录",
|
| 319 |
+
"dataset_info": "数据集信息",
|
| 320 |
+
"lora_rank": "LoRA 秩 (r)",
|
| 321 |
+
"lora_rank_info": "越高容量越大,显存占用越多",
|
| 322 |
+
"lora_alpha": "LoRA Alpha",
|
| 323 |
+
"lora_alpha_info": "缩放因子(通常为 2× 秩)",
|
| 324 |
+
"lora_dropout": "LoRA Dropout",
|
| 325 |
+
"learning_rate": "学习率",
|
| 326 |
+
"learning_rate_info": "建议从 3e-4 开始,按需调整",
|
| 327 |
+
"max_epochs": "最大轮数",
|
| 328 |
+
"batch_size": "批大小",
|
| 329 |
+
"batch_size_info": "显存充足时可增大",
|
| 330 |
+
"gradient_accumulation": "梯度累积",
|
| 331 |
+
"gradient_accumulation_info": "有效批大小 = batch_size × 累积步数",
|
| 332 |
+
"save_every_n_epochs": "每 N 轮保存",
|
| 333 |
+
"shift": "Shift",
|
| 334 |
+
"shift_info": "Turbo 模型时间步偏移",
|
| 335 |
+
"seed": "随机种子",
|
| 336 |
+
"output_dir": "输出目录",
|
| 337 |
+
"output_dir_info": "保存训练后 LoRA 权重的目录",
|
| 338 |
+
"start_training_btn": "🚀 开始训练",
|
| 339 |
+
"stop_training_btn": "⏹️ 停止训练",
|
| 340 |
+
"training_progress": "训练进度",
|
| 341 |
+
"training_log": "训练日志",
|
| 342 |
+
"training_loss_title": "训练损失",
|
| 343 |
+
"step": "步数",
|
| 344 |
+
"loss": "损失",
|
| 345 |
+
"export_header": "导出 LoRA",
|
| 346 |
+
"export_path": "导出路径",
|
| 347 |
+
"export_lora_btn": "📦 导出 LoRA",
|
| 348 |
+
"export_status": "导出状态"
|
| 349 |
+
}
|
| 350 |
+
}
|
acestep/gradio_ui/interfaces/__init__.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI Components Module
|
| 3 |
+
Contains all Gradio interface component definitions and layouts
|
| 4 |
+
"""
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from acestep.gradio_ui.i18n import get_i18n, t
|
| 7 |
+
from acestep.gradio_ui.interfaces.dataset import create_dataset_section
|
| 8 |
+
from acestep.gradio_ui.interfaces.generation import create_generation_section
|
| 9 |
+
from acestep.gradio_ui.interfaces.result import create_results_section
|
| 10 |
+
from acestep.gradio_ui.interfaces.training import create_training_section
|
| 11 |
+
from acestep.gradio_ui.events import setup_event_handlers, setup_training_event_handlers
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=None, language='en') -> gr.Blocks:
    """
    Assemble the full Gradio application and connect its event handlers.

    The page is built in a fixed order: header, dataset explorer, generation
    section, results section, training section, and finally the generation
    and training event wiring.

    Args:
        dit_handler: DiT handler instance
        llm_handler: LM handler instance
        dataset_handler: Dataset handler instance
        init_params: Optional dictionary of initialization parameters and
            state, forwarded to the generation and training sections.
            If None, service will not be pre-initialized.
        language: UI language code, e.g. 'en', 'zh' or 'ja' (default: 'en')

    Returns:
        The constructed Gradio Blocks instance
    """
    # The returned object is not used below. The call is kept ahead of every
    # t() lookup; presumably it selects the active language - confirm in i18n.py.
    _active_i18n = get_i18n(language)

    custom_css = """
    .main-header {
        text-align: center;
        margin-bottom: 2rem;
    }
    .section-header {
        background: linear-gradient(90deg, #4CAF50, #45a049);
        color: white;
        padding: 10px;
        border-radius: 5px;
        margin: 10px 0;
    }
    .lm-hints-row {
        align-items: stretch;
    }
    .lm-hints-col {
        display: flex;
    }
    .lm-hints-col > div {
        flex: 1;
        display: flex;
    }
    .lm-hints-btn button {
        height: 100%;
        width: 100%;
    }
    /* Position Audio time labels lower to avoid scrollbar overlap */
    .component-wrapper > .timestamps {
        transform: translateY(15px);
    }
    """

    with gr.Blocks(title=t("app.title"), theme=gr.themes.Soft(), css=custom_css) as demo:
        # Page header: translated title and subtitle.
        heading = t("app.title")
        tagline = t("app.subtitle")
        gr.HTML(f"""
        <div class="main-header">
            <h1>{heading}</h1>
            <p>{tagline}</p>
        </div>
        """)

        # UI sections, created in on-page order.
        dataset_section = create_dataset_section(dataset_handler)
        generation_section = create_generation_section(
            dit_handler,
            llm_handler,
            init_params=init_params,  # supports pre-initialization
            language=language,
        )
        results_section = create_results_section(dit_handler)
        # init_params is passed so the training section can be hidden in service mode.
        training_section = create_training_section(
            dit_handler,
            llm_handler,
            init_params=init_params,
        )

        # Event wiring happens after every section exists.
        setup_event_handlers(
            demo,
            dit_handler,
            llm_handler,
            dataset_handler,
            dataset_section,
            generation_section,
            results_section,
        )
        setup_training_event_handlers(demo, dit_handler, llm_handler, training_section)

    return demo
|
acestep/gradio_ui/interfaces/dataset.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI Dataset Section Module
|
| 3 |
+
Contains dataset explorer section component definitions
|
| 4 |
+
"""
|
| 5 |
+
import gradio as gr
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def create_dataset_section(dataset_handler) -> dict:
    """Build the dataset explorer accordion and return its components.

    The outer accordion is created collapsed and with ``visible=False``.

    Args:
        dataset_handler: Not referenced inside this function.

    Returns:
        Dict mapping component names (e.g. ``"dataset_type"``,
        ``"get_item_btn"``, ``"auto_fill_btn"``) to the Gradio components
        created here, in creation order.
    """
    # NOTE(review): the nesting of the `with` blocks below was reconstructed
    # from a whitespace-stripped diff view - confirm against the repository file.

    def _readonly_audio(caption, width):
        # Non-interactive audio player using the "filepath" value type.
        return gr.Audio(label=caption, type="filepath", interactive=False, scale=width)

    ui = {}

    with gr.Accordion("📊 Dataset Explorer", open=False, visible=False):
        # Top row: dataset selection, import button and item lookup controls.
        with gr.Row(equal_height=True):
            ui["dataset_type"] = gr.Dropdown(
                choices=["train", "test"],
                value="train",
                label="Dataset",
                info="Choose dataset to explore",
                scale=2,
            )
            ui["import_dataset_btn"] = gr.Button("📥 Import Dataset", variant="primary", scale=1)
            ui["search_type"] = gr.Dropdown(
                choices=["keys", "idx", "random"],
                value="random",
                label="Search Type",
                info="How to find items",
                scale=1,
            )
            ui["search_value"] = gr.Textbox(
                label="Search Value",
                placeholder="Enter keys or index (leave empty for random)",
                info="Keys: exact match, Index: 0 to dataset size-1",
                scale=2,
            )

        ui["instruction_display"] = gr.Textbox(
            label="📝 Instruction",
            interactive=False,
            placeholder="No instruction available",
            lines=1,
        )

        ui["repaint_viz_plot"] = gr.Plot()

        with gr.Accordion("📋 Item Metadata (JSON)", open=False):
            ui["item_info_json"] = gr.Code(
                label="Complete Item Information",
                language="json",
                interactive=False,
                lines=15,
            )

        with gr.Row(equal_height=True):
            ui["item_src_audio"] = _readonly_audio("Source Audio", 8)
            ui["get_item_btn"] = gr.Button("🔍 Get Item", variant="secondary", interactive=False, scale=2)

        with gr.Row(equal_height=True):
            ui["item_target_audio"] = _readonly_audio("Target Audio", 8)
            ui["item_refer_audio"] = _readonly_audio("Reference Audio", 2)

        with gr.Row():
            ui["use_src_checkbox"] = gr.Checkbox(
                label="Use Source Audio from Dataset",
                value=True,
                info="Check to use the source audio from dataset",
            )

        ui["data_status"] = gr.Textbox(label="📊 Data Status", interactive=False, value="❌ No dataset imported")
        ui["auto_fill_btn"] = gr.Button("📋 Auto-fill Generation Form", variant="primary")

    return ui
|
| 101 |
+
|
acestep/gradio_ui/interfaces/generation.py
ADDED
|
@@ -0,0 +1,824 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI Generation Section Module
|
| 3 |
+
Contains generation section component definitions
|
| 4 |
+
"""
|
| 5 |
+
import sys
|
| 6 |
+
import gradio as gr
|
| 7 |
+
from acestep.constants import (
|
| 8 |
+
VALID_LANGUAGES,
|
| 9 |
+
TRACK_NAMES,
|
| 10 |
+
TASK_TYPES_TURBO,
|
| 11 |
+
TASK_TYPES_BASE,
|
| 12 |
+
DEFAULT_DIT_INSTRUCTION,
|
| 13 |
+
)
|
| 14 |
+
from acestep.gradio_ui.i18n import t
|
| 15 |
+
from acestep.gpu_config import get_global_gpu_config, GPUConfig
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def create_generation_section(dit_handler, llm_handler, init_params=None, language='en') -> dict:
|
| 19 |
+
"""Create generation section
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
dit_handler: DiT handler instance
|
| 23 |
+
llm_handler: LM handler instance
|
| 24 |
+
init_params: Dictionary containing initialization parameters and state.
|
| 25 |
+
If None, service will not be pre-initialized.
|
| 26 |
+
language: UI language code ('en', 'zh', 'ja')
|
| 27 |
+
"""
|
| 28 |
+
# Check if service is pre-initialized
|
| 29 |
+
service_pre_initialized = init_params is not None and init_params.get('pre_initialized', False)
|
| 30 |
+
|
| 31 |
+
# Check if running in service mode (restricted UI)
|
| 32 |
+
service_mode = init_params is not None and init_params.get('service_mode', False)
|
| 33 |
+
|
| 34 |
+
# Get current language from init_params if available
|
| 35 |
+
current_language = init_params.get('language', language) if init_params else language
|
| 36 |
+
|
| 37 |
+
# Get GPU configuration
|
| 38 |
+
gpu_config: GPUConfig = init_params.get('gpu_config') if init_params else None
|
| 39 |
+
if gpu_config is None:
|
| 40 |
+
gpu_config = get_global_gpu_config()
|
| 41 |
+
|
| 42 |
+
# Determine if LM is initialized (for setting appropriate limits)
|
| 43 |
+
lm_initialized = init_params.get('init_llm', False) if init_params else False
|
| 44 |
+
|
| 45 |
+
# Calculate UI limits based on GPU config and LM state
|
| 46 |
+
max_duration = gpu_config.max_duration_with_lm if lm_initialized else gpu_config.max_duration_without_lm
|
| 47 |
+
max_batch_size = gpu_config.max_batch_size_with_lm if lm_initialized else gpu_config.max_batch_size_without_lm
|
| 48 |
+
default_batch_size = min(2, max_batch_size) # Default to 2 or max if lower
|
| 49 |
+
init_lm_default = gpu_config.init_lm_default
|
| 50 |
+
|
| 51 |
+
# Determine default offload setting
|
| 52 |
+
# If XPU is detected, default offload to False (keep models on device)
|
| 53 |
+
# Otherwise default to True (offload to CPU to save VRAM)
|
| 54 |
+
default_offload = True
|
| 55 |
+
try:
|
| 56 |
+
import torch
|
| 57 |
+
if hasattr(torch, 'xpu') and torch.xpu.is_available():
|
| 58 |
+
default_offload = False
|
| 59 |
+
except ImportError:
|
| 60 |
+
pass
|
| 61 |
+
|
| 62 |
+
with gr.Group():
|
| 63 |
+
# Service Configuration - collapse if pre-initialized, hide if in service mode
|
| 64 |
+
accordion_open = not service_pre_initialized
|
| 65 |
+
accordion_visible = not service_pre_initialized # Hide when running in service mode
|
| 66 |
+
with gr.Accordion(t("service.title"), open=accordion_open, visible=accordion_visible) as service_config_accordion:
|
| 67 |
+
# Language selector at the top
|
| 68 |
+
with gr.Row():
|
| 69 |
+
language_dropdown = gr.Dropdown(
|
| 70 |
+
choices=[
|
| 71 |
+
("English", "en"),
|
| 72 |
+
("中文", "zh"),
|
| 73 |
+
("日本語", "ja"),
|
| 74 |
+
],
|
| 75 |
+
value=current_language,
|
| 76 |
+
label=t("service.language_label"),
|
| 77 |
+
info=t("service.language_info"),
|
| 78 |
+
scale=1,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# Dropdown options section - all dropdowns grouped together
|
| 82 |
+
with gr.Row(equal_height=True):
|
| 83 |
+
with gr.Column(scale=4):
|
| 84 |
+
# Set checkpoint value from init_params if pre-initialized
|
| 85 |
+
checkpoint_value = init_params.get('checkpoint') if service_pre_initialized else None
|
| 86 |
+
checkpoint_dropdown = gr.Dropdown(
|
| 87 |
+
label=t("service.checkpoint_label"),
|
| 88 |
+
choices=dit_handler.get_available_checkpoints(),
|
| 89 |
+
value=checkpoint_value,
|
| 90 |
+
info=t("service.checkpoint_info")
|
| 91 |
+
)
|
| 92 |
+
with gr.Column(scale=1, min_width=90):
|
| 93 |
+
refresh_btn = gr.Button(t("service.refresh_btn"), size="sm")
|
| 94 |
+
|
| 95 |
+
with gr.Row():
|
| 96 |
+
# Get available acestep-v15- model list
|
| 97 |
+
available_models = dit_handler.get_available_acestep_v15_models()
|
| 98 |
+
default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)
|
| 99 |
+
|
| 100 |
+
# Set config_path value from init_params if pre-initialized
|
| 101 |
+
config_path_value = init_params.get('config_path', default_model) if service_pre_initialized else default_model
|
| 102 |
+
config_path = gr.Dropdown(
|
| 103 |
+
label=t("service.model_path_label"),
|
| 104 |
+
choices=available_models,
|
| 105 |
+
value=config_path_value,
|
| 106 |
+
info=t("service.model_path_info")
|
| 107 |
+
)
|
| 108 |
+
# Set device value from init_params if pre-initialized
|
| 109 |
+
device_value = init_params.get('device', 'auto') if service_pre_initialized else 'auto'
|
| 110 |
+
device = gr.Dropdown(
|
| 111 |
+
choices=["auto", "cuda", "mps", "xpu", "cpu"],
|
| 112 |
+
value=device_value,
|
| 113 |
+
label=t("service.device_label"),
|
| 114 |
+
info=t("service.device_info")
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
with gr.Row():
|
| 118 |
+
# Get available 5Hz LM model list
|
| 119 |
+
available_lm_models = llm_handler.get_available_5hz_lm_models()
|
| 120 |
+
default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)
|
| 121 |
+
|
| 122 |
+
# Set lm_model_path value from init_params if pre-initialized
|
| 123 |
+
lm_model_path_value = init_params.get('lm_model_path', default_lm_model) if service_pre_initialized else default_lm_model
|
| 124 |
+
lm_model_path = gr.Dropdown(
|
| 125 |
+
label=t("service.lm_model_path_label"),
|
| 126 |
+
choices=available_lm_models,
|
| 127 |
+
value=lm_model_path_value,
|
| 128 |
+
info=t("service.lm_model_path_info")
|
| 129 |
+
)
|
| 130 |
+
# Set backend value from init_params if pre-initialized
|
| 131 |
+
backend_value = init_params.get('backend', 'vllm') if service_pre_initialized else 'vllm'
|
| 132 |
+
backend_dropdown = gr.Dropdown(
|
| 133 |
+
choices=["vllm", "pt", "mlx"],
|
| 134 |
+
value=backend_value,
|
| 135 |
+
label=t("service.backend_label"),
|
| 136 |
+
info=t("service.backend_info")
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
# Checkbox options section - all checkboxes grouped together
|
| 140 |
+
with gr.Row():
|
| 141 |
+
# Set init_llm value from init_params if pre-initialized, otherwise use GPU config default
|
| 142 |
+
init_llm_value = init_params.get('init_llm', init_lm_default) if service_pre_initialized else init_lm_default
|
| 143 |
+
init_llm_checkbox = gr.Checkbox(
|
| 144 |
+
label=t("service.init_llm_label"),
|
| 145 |
+
value=init_llm_value,
|
| 146 |
+
info=t("service.init_llm_info"),
|
| 147 |
+
)
|
| 148 |
+
# Auto-detect flash attention availability
|
| 149 |
+
flash_attn_available = dit_handler.is_flash_attention_available(device_value)
|
| 150 |
+
# Set use_flash_attention value from init_params if pre-initialized
|
| 151 |
+
use_flash_attention_value = init_params.get('use_flash_attention', flash_attn_available) if service_pre_initialized else flash_attn_available
|
| 152 |
+
use_flash_attention_checkbox = gr.Checkbox(
|
| 153 |
+
label=t("service.flash_attention_label"),
|
| 154 |
+
value=use_flash_attention_value,
|
| 155 |
+
interactive=flash_attn_available,
|
| 156 |
+
info=t("service.flash_attention_info_enabled") if flash_attn_available else t("service.flash_attention_info_disabled")
|
| 157 |
+
)
|
| 158 |
+
# Set offload_to_cpu value from init_params if pre-initialized (falls back to default_offload: False when XPU is available, True otherwise)
|
| 159 |
+
offload_to_cpu_value = init_params.get('offload_to_cpu', default_offload) if service_pre_initialized else default_offload
|
| 160 |
+
offload_to_cpu_checkbox = gr.Checkbox(
|
| 161 |
+
label=t("service.offload_cpu_label"),
|
| 162 |
+
value=offload_to_cpu_value,
|
| 163 |
+
info=t("service.offload_cpu_info")
|
| 164 |
+
)
|
| 165 |
+
# Set offload_dit_to_cpu value from init_params if pre-initialized (falls back to default_offload: False when XPU is available, True otherwise)
|
| 166 |
+
offload_dit_to_cpu_value = init_params.get('offload_dit_to_cpu', default_offload) if service_pre_initialized else default_offload
|
| 167 |
+
offload_dit_to_cpu_checkbox = gr.Checkbox(
|
| 168 |
+
label=t("service.offload_dit_cpu_label"),
|
| 169 |
+
value=offload_dit_to_cpu_value,
|
| 170 |
+
info=t("service.offload_dit_cpu_info")
|
| 171 |
+
)
|
| 172 |
+
# Set compile_model value from init_params if pre-initialized (default True)
|
| 173 |
+
compile_model_value = init_params.get('compile_model', True) if service_pre_initialized else True
|
| 174 |
+
compile_model_checkbox = gr.Checkbox(
|
| 175 |
+
label=t("service.compile_model_label"),
|
| 176 |
+
value=compile_model_value,
|
| 177 |
+
info=t("service.compile_model_info")
|
| 178 |
+
)
|
| 179 |
+
# Set quantization value from init_params if pre-initialized.
|
| 180 |
+
# Default to False on macOS to avoid torchao incompatibilities.
|
| 181 |
+
default_quantization = False if sys.platform == "darwin" else True
|
| 182 |
+
quantization_value = init_params.get('quantization', default_quantization) if service_pre_initialized else default_quantization
|
| 183 |
+
quantization_checkbox = gr.Checkbox(
|
| 184 |
+
label=t("service.quantization_label"),
|
| 185 |
+
value=quantization_value,
|
| 186 |
+
info=t("service.quantization_info")
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
init_btn = gr.Button(t("service.init_btn"), variant="primary", size="lg")
|
| 190 |
+
# Set init_status value from init_params if pre-initialized
|
| 191 |
+
init_status_value = init_params.get('init_status', '') if service_pre_initialized else ''
|
| 192 |
+
init_status = gr.Textbox(label=t("service.status_label"), interactive=False, lines=3, value=init_status_value)
|
| 193 |
+
|
| 194 |
+
# LoRA Configuration Section
|
| 195 |
+
gr.HTML("<hr><h4>🔧 LoRA Adapter</h4>")
|
| 196 |
+
with gr.Row():
|
| 197 |
+
lora_path = gr.Textbox(
|
| 198 |
+
label="LoRA Path",
|
| 199 |
+
placeholder="./lora_output/final/adapter",
|
| 200 |
+
info="Path to trained LoRA adapter directory",
|
| 201 |
+
scale=3,
|
| 202 |
+
)
|
| 203 |
+
load_lora_btn = gr.Button("📥 Load LoRA", variant="secondary", scale=1)
|
| 204 |
+
unload_lora_btn = gr.Button("🗑️ Unload", variant="secondary", scale=1)
|
| 205 |
+
with gr.Row():
|
| 206 |
+
use_lora_checkbox = gr.Checkbox(
|
| 207 |
+
label="Use LoRA",
|
| 208 |
+
value=False,
|
| 209 |
+
info="Enable LoRA adapter for inference",
|
| 210 |
+
scale=1,
|
| 211 |
+
)
|
| 212 |
+
lora_scale_slider = gr.Slider(
|
| 213 |
+
minimum=0.0,
|
| 214 |
+
maximum=1.0,
|
| 215 |
+
value=1.0,
|
| 216 |
+
step=0.05,
|
| 217 |
+
label="LoRA Scale",
|
| 218 |
+
info="LoRA influence strength (0=disabled, 1=full)",
|
| 219 |
+
scale=2,
|
| 220 |
+
)
|
| 221 |
+
lora_status = gr.Textbox(
|
| 222 |
+
label="LoRA Status",
|
| 223 |
+
value="No LoRA loaded",
|
| 224 |
+
interactive=False,
|
| 225 |
+
scale=2,
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
# Inputs
|
| 229 |
+
with gr.Row():
|
| 230 |
+
with gr.Column(scale=2):
|
| 231 |
+
with gr.Accordion(t("generation.required_inputs"), open=True):
|
| 232 |
+
# Task type
|
| 233 |
+
# Determine initial task_type choices based on actual model in use
|
| 234 |
+
# When service is pre-initialized, use config_path from init_params
|
| 235 |
+
actual_model = init_params.get('config_path', default_model) if service_pre_initialized else default_model
|
| 236 |
+
actual_model_lower = (actual_model or "").lower()
|
| 237 |
+
if "turbo" in actual_model_lower:
|
| 238 |
+
initial_task_choices = TASK_TYPES_TURBO
|
| 239 |
+
else:
|
| 240 |
+
initial_task_choices = TASK_TYPES_BASE
|
| 241 |
+
|
| 242 |
+
with gr.Row(equal_height=True):
|
| 243 |
+
with gr.Column(scale=2):
|
| 244 |
+
task_type = gr.Dropdown(
|
| 245 |
+
choices=initial_task_choices,
|
| 246 |
+
value="text2music",
|
| 247 |
+
label=t("generation.task_type_label"),
|
| 248 |
+
info=t("generation.task_type_info"),
|
| 249 |
+
)
|
| 250 |
+
with gr.Column(scale=7):
|
| 251 |
+
instruction_display_gen = gr.Textbox(
|
| 252 |
+
label=t("generation.instruction_label"),
|
| 253 |
+
value=DEFAULT_DIT_INSTRUCTION,
|
| 254 |
+
interactive=False,
|
| 255 |
+
lines=1,
|
| 256 |
+
info=t("generation.instruction_info"),
|
| 257 |
+
)
|
| 258 |
+
with gr.Column(scale=1, min_width=100):
|
| 259 |
+
load_file = gr.UploadButton(
|
| 260 |
+
t("generation.load_btn"),
|
| 261 |
+
file_types=[".json"],
|
| 262 |
+
file_count="single",
|
| 263 |
+
variant="secondary",
|
| 264 |
+
size="sm",
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
track_name = gr.Dropdown(
|
| 268 |
+
choices=TRACK_NAMES,
|
| 269 |
+
value=None,
|
| 270 |
+
label=t("generation.track_name_label"),
|
| 271 |
+
info=t("generation.track_name_info"),
|
| 272 |
+
visible=False
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
complete_track_classes = gr.CheckboxGroup(
|
| 276 |
+
choices=TRACK_NAMES,
|
| 277 |
+
label=t("generation.track_classes_label"),
|
| 278 |
+
info=t("generation.track_classes_info"),
|
| 279 |
+
visible=False
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
# Audio uploads
|
| 283 |
+
audio_uploads_accordion = gr.Accordion(t("generation.audio_uploads"), open=False)
|
| 284 |
+
with audio_uploads_accordion:
|
| 285 |
+
with gr.Row(equal_height=True):
|
| 286 |
+
with gr.Column(scale=2):
|
| 287 |
+
reference_audio = gr.Audio(
|
| 288 |
+
label=t("generation.reference_audio"),
|
| 289 |
+
type="filepath",
|
| 290 |
+
)
|
| 291 |
+
with gr.Column(scale=7):
|
| 292 |
+
src_audio = gr.Audio(
|
| 293 |
+
label=t("generation.source_audio"),
|
| 294 |
+
type="filepath",
|
| 295 |
+
)
|
| 296 |
+
with gr.Column(scale=1, min_width=80):
|
| 297 |
+
convert_src_to_codes_btn = gr.Button(
|
| 298 |
+
t("generation.convert_codes_btn"),
|
| 299 |
+
variant="secondary",
|
| 300 |
+
size="sm"
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
# Audio Codes for text2music - single input for transcription or cover task
|
| 304 |
+
with gr.Accordion(t("generation.lm_codes_hints"), open=False, visible=True) as text2music_audio_codes_group:
|
| 305 |
+
with gr.Row(equal_height=True):
|
| 306 |
+
text2music_audio_code_string = gr.Textbox(
|
| 307 |
+
label=t("generation.lm_codes_label"),
|
| 308 |
+
placeholder=t("generation.lm_codes_placeholder"),
|
| 309 |
+
lines=6,
|
| 310 |
+
info=t("generation.lm_codes_info"),
|
| 311 |
+
scale=9,
|
| 312 |
+
)
|
| 313 |
+
transcribe_btn = gr.Button(
|
| 314 |
+
t("generation.transcribe_btn"),
|
| 315 |
+
variant="secondary",
|
| 316 |
+
size="sm",
|
| 317 |
+
scale=1,
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
# Repainting controls
|
| 321 |
+
with gr.Group(visible=False) as repainting_group:
|
| 322 |
+
gr.HTML(f"<h5>{t('generation.repainting_controls')}</h5>")
|
| 323 |
+
with gr.Row():
|
| 324 |
+
repainting_start = gr.Number(
|
| 325 |
+
label=t("generation.repainting_start"),
|
| 326 |
+
value=0.0,
|
| 327 |
+
step=0.1,
|
| 328 |
+
)
|
| 329 |
+
repainting_end = gr.Number(
|
| 330 |
+
label=t("generation.repainting_end"),
|
| 331 |
+
value=-1,
|
| 332 |
+
minimum=-1,
|
| 333 |
+
step=0.1,
|
| 334 |
+
)
|
| 335 |
+
|
| 336 |
+
# Simple/Custom Mode Toggle
|
| 337 |
+
# In service mode: only Custom mode, hide the toggle
|
| 338 |
+
with gr.Row(visible=not service_mode):
|
| 339 |
+
generation_mode = gr.Radio(
|
| 340 |
+
choices=[
|
| 341 |
+
(t("generation.mode_simple"), "simple"),
|
| 342 |
+
(t("generation.mode_custom"), "custom"),
|
| 343 |
+
],
|
| 344 |
+
value="custom" if service_mode else "simple",
|
| 345 |
+
label=t("generation.mode_label"),
|
| 346 |
+
info=t("generation.mode_info"),
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
# Simple Mode Components - hidden in service mode
|
| 350 |
+
with gr.Group(visible=not service_mode) as simple_mode_group:
|
| 351 |
+
with gr.Row(equal_height=True):
|
| 352 |
+
simple_query_input = gr.Textbox(
|
| 353 |
+
label=t("generation.simple_query_label"),
|
| 354 |
+
placeholder=t("generation.simple_query_placeholder"),
|
| 355 |
+
lines=2,
|
| 356 |
+
info=t("generation.simple_query_info"),
|
| 357 |
+
scale=12,
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
with gr.Column(scale=1, min_width=100):
|
| 361 |
+
random_desc_btn = gr.Button(
|
| 362 |
+
"🎲",
|
| 363 |
+
variant="secondary",
|
| 364 |
+
size="sm",
|
| 365 |
+
scale=2
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
with gr.Row(equal_height=True):
|
| 369 |
+
with gr.Column(scale=1, variant="compact"):
|
| 370 |
+
simple_instrumental_checkbox = gr.Checkbox(
|
| 371 |
+
label=t("generation.instrumental_label"),
|
| 372 |
+
value=False,
|
| 373 |
+
)
|
| 374 |
+
with gr.Column(scale=18):
|
| 375 |
+
create_sample_btn = gr.Button(
|
| 376 |
+
t("generation.create_sample_btn"),
|
| 377 |
+
variant="primary",
|
| 378 |
+
size="lg",
|
| 379 |
+
)
|
| 380 |
+
with gr.Column(scale=1, variant="compact"):
|
| 381 |
+
simple_vocal_language = gr.Dropdown(
|
| 382 |
+
choices=VALID_LANGUAGES,
|
| 383 |
+
value="unknown",
|
| 384 |
+
allow_custom_value=True,
|
| 385 |
+
label=t("generation.simple_vocal_language_label"),
|
| 386 |
+
interactive=True,
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
# State to track if sample has been created in Simple mode
|
| 390 |
+
simple_sample_created = gr.State(value=False)
|
| 391 |
+
|
| 392 |
+
# Music Caption - wrapped in accordion that can be collapsed in Simple mode
|
| 393 |
+
# Default to expanded for better UX
|
| 394 |
+
with gr.Accordion(t("generation.caption_title"), open=True) as caption_accordion:
|
| 395 |
+
with gr.Row(equal_height=True):
|
| 396 |
+
captions = gr.Textbox(
|
| 397 |
+
label=t("generation.caption_label"),
|
| 398 |
+
placeholder=t("generation.caption_placeholder"),
|
| 399 |
+
lines=3,
|
| 400 |
+
info=t("generation.caption_info"),
|
| 401 |
+
scale=12,
|
| 402 |
+
)
|
| 403 |
+
with gr.Column(scale=1, min_width=100):
|
| 404 |
+
sample_btn = gr.Button(
|
| 405 |
+
"🎲",
|
| 406 |
+
variant="secondary",
|
| 407 |
+
size="sm",
|
| 408 |
+
scale=2,
|
| 409 |
+
)
|
| 410 |
+
# Lyrics - wrapped in accordion that can be collapsed in Simple mode
|
| 411 |
+
# Default to expanded for better UX
|
| 412 |
+
with gr.Accordion(t("generation.lyrics_title"), open=True) as lyrics_accordion:
|
| 413 |
+
lyrics = gr.Textbox(
|
| 414 |
+
label=t("generation.lyrics_label"),
|
| 415 |
+
placeholder=t("generation.lyrics_placeholder"),
|
| 416 |
+
lines=8,
|
| 417 |
+
info=t("generation.lyrics_info")
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
with gr.Row(variant="compact", equal_height=True):
|
| 421 |
+
instrumental_checkbox = gr.Checkbox(
|
| 422 |
+
label=t("generation.instrumental_label"),
|
| 423 |
+
value=False,
|
| 424 |
+
scale=1,
|
| 425 |
+
min_width=120,
|
| 426 |
+
container=True,
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
# Middle: vocal language selector (Dropdown)
|
| 430 |
+
# Removed the gr.HTML hack; pass the label parameter directly and let Gradio handle alignment
|
| 431 |
+
vocal_language = gr.Dropdown(
|
| 432 |
+
choices=VALID_LANGUAGES,
|
| 433 |
+
value="unknown",
|
| 434 |
+
label=t("generation.vocal_language_label"),
|
| 435 |
+
show_label=False,
|
| 436 |
+
container=True,
|
| 437 |
+
allow_custom_value=True,
|
| 438 |
+
scale=3,
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
# Right: format button (Button)
|
| 442 |
+
# Placed at the far right of the same row so it is easier to reach
|
| 443 |
+
format_btn = gr.Button(
|
| 444 |
+
t("generation.format_btn"),
|
| 445 |
+
variant="secondary",
|
| 446 |
+
scale=1,
|
| 447 |
+
min_width=80,
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
# Optional Parameters
|
| 451 |
+
# In service mode: auto-expand
|
| 452 |
+
with gr.Accordion(t("generation.optional_params"), open=service_mode) as optional_params_accordion:
|
| 453 |
+
with gr.Row():
|
| 454 |
+
bpm = gr.Number(
|
| 455 |
+
label=t("generation.bpm_label"),
|
| 456 |
+
value=None,
|
| 457 |
+
step=1,
|
| 458 |
+
info=t("generation.bpm_info")
|
| 459 |
+
)
|
| 460 |
+
key_scale = gr.Textbox(
|
| 461 |
+
label=t("generation.keyscale_label"),
|
| 462 |
+
placeholder=t("generation.keyscale_placeholder"),
|
| 463 |
+
value="",
|
| 464 |
+
info=t("generation.keyscale_info")
|
| 465 |
+
)
|
| 466 |
+
time_signature = gr.Dropdown(
|
| 467 |
+
choices=["", "2", "3", "4", "6", "N/A"],
|
| 468 |
+
value="",
|
| 469 |
+
label=t("generation.timesig_label"),
|
| 470 |
+
allow_custom_value=True,
|
| 471 |
+
info=t("generation.timesig_info")
|
| 472 |
+
)
|
| 473 |
+
audio_duration = gr.Number(
|
| 474 |
+
label=t("generation.duration_label"),
|
| 475 |
+
value=-1,
|
| 476 |
+
minimum=-1,
|
| 477 |
+
maximum=float(max_duration),
|
| 478 |
+
step=0.1,
|
| 479 |
+
info=t("generation.duration_info") + f" (Max: {max_duration}s / {max_duration // 60} min)"
|
| 480 |
+
)
|
| 481 |
+
batch_size_input = gr.Number(
|
| 482 |
+
label=t("generation.batch_size_label"),
|
| 483 |
+
value=default_batch_size,
|
| 484 |
+
minimum=1,
|
| 485 |
+
maximum=max_batch_size,
|
| 486 |
+
step=1,
|
| 487 |
+
info=t("generation.batch_size_info") + f" (Max: {max_batch_size})",
|
| 488 |
+
interactive=not service_mode # Fixed in service mode
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
# Advanced Settings
|
| 492 |
+
# Default UI settings use turbo mode (max 20 steps, default 8, show shift with default 3)
|
| 493 |
+
# These will be updated after model initialization based on handler.is_turbo_model()
|
| 494 |
+
with gr.Accordion(t("generation.advanced_settings"), open=False):
|
| 495 |
+
with gr.Row():
|
| 496 |
+
inference_steps = gr.Slider(
|
| 497 |
+
minimum=1,
|
| 498 |
+
maximum=20,
|
| 499 |
+
value=8,
|
| 500 |
+
step=1,
|
| 501 |
+
label=t("generation.inference_steps_label"),
|
| 502 |
+
info=t("generation.inference_steps_info")
|
| 503 |
+
)
|
| 504 |
+
guidance_scale = gr.Slider(
|
| 505 |
+
minimum=1.0,
|
| 506 |
+
maximum=15.0,
|
| 507 |
+
value=7.0,
|
| 508 |
+
step=0.1,
|
| 509 |
+
label=t("generation.guidance_scale_label"),
|
| 510 |
+
info=t("generation.guidance_scale_info"),
|
| 511 |
+
visible=False
|
| 512 |
+
)
|
| 513 |
+
with gr.Column():
|
| 514 |
+
seed = gr.Textbox(
|
| 515 |
+
label=t("generation.seed_label"),
|
| 516 |
+
value="-1",
|
| 517 |
+
info=t("generation.seed_info")
|
| 518 |
+
)
|
| 519 |
+
random_seed_checkbox = gr.Checkbox(
|
| 520 |
+
label=t("generation.random_seed_label"),
|
| 521 |
+
value=True,
|
| 522 |
+
info=t("generation.random_seed_info")
|
| 523 |
+
)
|
| 524 |
+
audio_format = gr.Dropdown(
|
| 525 |
+
choices=["mp3", "flac"],
|
| 526 |
+
value="mp3",
|
| 527 |
+
label=t("generation.audio_format_label"),
|
| 528 |
+
info=t("generation.audio_format_info"),
|
| 529 |
+
interactive=not service_mode # Fixed in service mode
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
with gr.Row():
|
| 533 |
+
use_adg = gr.Checkbox(
|
| 534 |
+
label=t("generation.use_adg_label"),
|
| 535 |
+
value=False,
|
| 536 |
+
info=t("generation.use_adg_info"),
|
| 537 |
+
visible=False
|
| 538 |
+
)
|
| 539 |
+
shift = gr.Slider(
|
| 540 |
+
minimum=1.0,
|
| 541 |
+
maximum=5.0,
|
| 542 |
+
value=3.0,
|
| 543 |
+
step=0.1,
|
| 544 |
+
label=t("generation.shift_label"),
|
| 545 |
+
info=t("generation.shift_info"),
|
| 546 |
+
visible=True
|
| 547 |
+
)
|
| 548 |
+
infer_method = gr.Dropdown(
|
| 549 |
+
choices=["ode", "sde"],
|
| 550 |
+
value="ode",
|
| 551 |
+
label=t("generation.infer_method_label"),
|
| 552 |
+
info=t("generation.infer_method_info"),
|
| 553 |
+
)
|
| 554 |
+
|
| 555 |
+
with gr.Row():
|
| 556 |
+
custom_timesteps = gr.Textbox(
|
| 557 |
+
label=t("generation.custom_timesteps_label"),
|
| 558 |
+
placeholder="0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0",
|
| 559 |
+
value="",
|
| 560 |
+
info=t("generation.custom_timesteps_info"),
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
with gr.Row():
|
| 564 |
+
cfg_interval_start = gr.Slider(
|
| 565 |
+
minimum=0.0,
|
| 566 |
+
maximum=1.0,
|
| 567 |
+
value=0.0,
|
| 568 |
+
step=0.01,
|
| 569 |
+
label=t("generation.cfg_interval_start"),
|
| 570 |
+
visible=False
|
| 571 |
+
)
|
| 572 |
+
cfg_interval_end = gr.Slider(
|
| 573 |
+
minimum=0.0,
|
| 574 |
+
maximum=1.0,
|
| 575 |
+
value=1.0,
|
| 576 |
+
step=0.01,
|
| 577 |
+
label=t("generation.cfg_interval_end"),
|
| 578 |
+
visible=False
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
# LM (Language Model) Parameters
|
| 582 |
+
gr.HTML(f"<h4>{t('generation.lm_params_title')}</h4>")
|
| 583 |
+
with gr.Row():
|
| 584 |
+
lm_temperature = gr.Slider(
|
| 585 |
+
label=t("generation.lm_temperature_label"),
|
| 586 |
+
minimum=0.0,
|
| 587 |
+
maximum=2.0,
|
| 588 |
+
value=0.85,
|
| 589 |
+
step=0.1,
|
| 590 |
+
scale=1,
|
| 591 |
+
info=t("generation.lm_temperature_info")
|
| 592 |
+
)
|
| 593 |
+
lm_cfg_scale = gr.Slider(
|
| 594 |
+
label=t("generation.lm_cfg_scale_label"),
|
| 595 |
+
minimum=1.0,
|
| 596 |
+
maximum=3.0,
|
| 597 |
+
value=2.0,
|
| 598 |
+
step=0.1,
|
| 599 |
+
scale=1,
|
| 600 |
+
info=t("generation.lm_cfg_scale_info")
|
| 601 |
+
)
|
| 602 |
+
lm_top_k = gr.Slider(
|
| 603 |
+
label=t("generation.lm_top_k_label"),
|
| 604 |
+
minimum=0,
|
| 605 |
+
maximum=100,
|
| 606 |
+
value=0,
|
| 607 |
+
step=1,
|
| 608 |
+
scale=1,
|
| 609 |
+
info=t("generation.lm_top_k_info")
|
| 610 |
+
)
|
| 611 |
+
lm_top_p = gr.Slider(
|
| 612 |
+
label=t("generation.lm_top_p_label"),
|
| 613 |
+
minimum=0.0,
|
| 614 |
+
maximum=1.0,
|
| 615 |
+
value=0.9,
|
| 616 |
+
step=0.01,
|
| 617 |
+
scale=1,
|
| 618 |
+
info=t("generation.lm_top_p_info")
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
with gr.Row():
|
| 622 |
+
lm_negative_prompt = gr.Textbox(
|
| 623 |
+
label=t("generation.lm_negative_prompt_label"),
|
| 624 |
+
value="NO USER INPUT",
|
| 625 |
+
placeholder=t("generation.lm_negative_prompt_placeholder"),
|
| 626 |
+
info=t("generation.lm_negative_prompt_info"),
|
| 627 |
+
lines=2,
|
| 628 |
+
scale=2,
|
| 629 |
+
)
|
| 630 |
+
|
| 631 |
+
with gr.Row():
|
| 632 |
+
use_cot_metas = gr.Checkbox(
|
| 633 |
+
label=t("generation.cot_metas_label"),
|
| 634 |
+
value=True,
|
| 635 |
+
info=t("generation.cot_metas_info"),
|
| 636 |
+
scale=1,
|
| 637 |
+
)
|
| 638 |
+
use_cot_language = gr.Checkbox(
|
| 639 |
+
label=t("generation.cot_language_label"),
|
| 640 |
+
value=True,
|
| 641 |
+
info=t("generation.cot_language_info"),
|
| 642 |
+
scale=1,
|
| 643 |
+
)
|
| 644 |
+
constrained_decoding_debug = gr.Checkbox(
|
| 645 |
+
label=t("generation.constrained_debug_label"),
|
| 646 |
+
value=False,
|
| 647 |
+
info=t("generation.constrained_debug_info"),
|
| 648 |
+
scale=1,
|
| 649 |
+
interactive=not service_mode # Fixed in service mode
|
| 650 |
+
)
|
| 651 |
+
|
| 652 |
+
with gr.Row():
|
| 653 |
+
auto_score = gr.Checkbox(
|
| 654 |
+
label=t("generation.auto_score_label"),
|
| 655 |
+
value=False,
|
| 656 |
+
info=t("generation.auto_score_info"),
|
| 657 |
+
scale=1,
|
| 658 |
+
interactive=not service_mode # Fixed in service mode
|
| 659 |
+
)
|
| 660 |
+
auto_lrc = gr.Checkbox(
|
| 661 |
+
label=t("generation.auto_lrc_label"),
|
| 662 |
+
value=False,
|
| 663 |
+
info=t("generation.auto_lrc_info"),
|
| 664 |
+
scale=1,
|
| 665 |
+
interactive=not service_mode # Fixed in service mode
|
| 666 |
+
)
|
| 667 |
+
lm_batch_chunk_size = gr.Number(
|
| 668 |
+
label=t("generation.lm_batch_chunk_label"),
|
| 669 |
+
value=8,
|
| 670 |
+
minimum=1,
|
| 671 |
+
maximum=32,
|
| 672 |
+
step=1,
|
| 673 |
+
info=t("generation.lm_batch_chunk_info"),
|
| 674 |
+
scale=1,
|
| 675 |
+
interactive=not service_mode # Fixed in service mode
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
with gr.Row():
|
| 679 |
+
audio_cover_strength = gr.Slider(
|
| 680 |
+
minimum=0.0,
|
| 681 |
+
maximum=1.0,
|
| 682 |
+
value=1.0,
|
| 683 |
+
step=0.01,
|
| 684 |
+
label=t("generation.codes_strength_label"),
|
| 685 |
+
info=t("generation.codes_strength_info"),
|
| 686 |
+
scale=1,
|
| 687 |
+
)
|
| 688 |
+
score_scale = gr.Slider(
|
| 689 |
+
minimum=0.01,
|
| 690 |
+
maximum=1.0,
|
| 691 |
+
value=0.5,
|
| 692 |
+
step=0.01,
|
| 693 |
+
label=t("generation.score_sensitivity_label"),
|
| 694 |
+
info=t("generation.score_sensitivity_info"),
|
| 695 |
+
scale=1,
|
| 696 |
+
visible=not service_mode # Hidden in service mode
|
| 697 |
+
)
|
| 698 |
+
|
| 699 |
+
# Set generate_btn to interactive if service is pre-initialized
|
| 700 |
+
generate_btn_interactive = init_params.get('enable_generate', False) if service_pre_initialized else False
|
| 701 |
+
with gr.Row(equal_height=True):
|
| 702 |
+
with gr.Column(scale=1, variant="compact"):
|
| 703 |
+
think_checkbox = gr.Checkbox(
|
| 704 |
+
label=t("generation.think_label"),
|
| 705 |
+
value=True,
|
| 706 |
+
scale=1,
|
| 707 |
+
)
|
| 708 |
+
allow_lm_batch = gr.Checkbox(
|
| 709 |
+
label=t("generation.parallel_thinking_label"),
|
| 710 |
+
value=True,
|
| 711 |
+
scale=1,
|
| 712 |
+
)
|
| 713 |
+
with gr.Column(scale=18):
|
| 714 |
+
generate_btn = gr.Button(t("generation.generate_btn"), variant="primary", size="lg", interactive=generate_btn_interactive)
|
| 715 |
+
with gr.Column(scale=1, variant="compact"):
|
| 716 |
+
autogen_checkbox = gr.Checkbox(
|
| 717 |
+
label=t("generation.autogen_label"),
|
| 718 |
+
value=False, # Default to False for both service and local modes
|
| 719 |
+
scale=1,
|
| 720 |
+
interactive=not service_mode # Not selectable in service mode
|
| 721 |
+
)
|
| 722 |
+
use_cot_caption = gr.Checkbox(
|
| 723 |
+
label=t("generation.caption_rewrite_label"),
|
| 724 |
+
value=True,
|
| 725 |
+
scale=1,
|
| 726 |
+
)
|
| 727 |
+
|
| 728 |
+
return {
|
| 729 |
+
"service_config_accordion": service_config_accordion,
|
| 730 |
+
"language_dropdown": language_dropdown,
|
| 731 |
+
"checkpoint_dropdown": checkpoint_dropdown,
|
| 732 |
+
"refresh_btn": refresh_btn,
|
| 733 |
+
"config_path": config_path,
|
| 734 |
+
"device": device,
|
| 735 |
+
"init_btn": init_btn,
|
| 736 |
+
"init_status": init_status,
|
| 737 |
+
"lm_model_path": lm_model_path,
|
| 738 |
+
"init_llm_checkbox": init_llm_checkbox,
|
| 739 |
+
"backend_dropdown": backend_dropdown,
|
| 740 |
+
"use_flash_attention_checkbox": use_flash_attention_checkbox,
|
| 741 |
+
"offload_to_cpu_checkbox": offload_to_cpu_checkbox,
|
| 742 |
+
"offload_dit_to_cpu_checkbox": offload_dit_to_cpu_checkbox,
|
| 743 |
+
"compile_model_checkbox": compile_model_checkbox,
|
| 744 |
+
"quantization_checkbox": quantization_checkbox,
|
| 745 |
+
# LoRA components
|
| 746 |
+
"lora_path": lora_path,
|
| 747 |
+
"load_lora_btn": load_lora_btn,
|
| 748 |
+
"unload_lora_btn": unload_lora_btn,
|
| 749 |
+
"use_lora_checkbox": use_lora_checkbox,
|
| 750 |
+
"lora_scale_slider": lora_scale_slider,
|
| 751 |
+
"lora_status": lora_status,
|
| 752 |
+
"task_type": task_type,
|
| 753 |
+
"instruction_display_gen": instruction_display_gen,
|
| 754 |
+
"track_name": track_name,
|
| 755 |
+
"complete_track_classes": complete_track_classes,
|
| 756 |
+
"audio_uploads_accordion": audio_uploads_accordion,
|
| 757 |
+
"reference_audio": reference_audio,
|
| 758 |
+
"src_audio": src_audio,
|
| 759 |
+
"convert_src_to_codes_btn": convert_src_to_codes_btn,
|
| 760 |
+
"text2music_audio_code_string": text2music_audio_code_string,
|
| 761 |
+
"transcribe_btn": transcribe_btn,
|
| 762 |
+
"text2music_audio_codes_group": text2music_audio_codes_group,
|
| 763 |
+
"lm_temperature": lm_temperature,
|
| 764 |
+
"lm_cfg_scale": lm_cfg_scale,
|
| 765 |
+
"lm_top_k": lm_top_k,
|
| 766 |
+
"lm_top_p": lm_top_p,
|
| 767 |
+
"lm_negative_prompt": lm_negative_prompt,
|
| 768 |
+
"use_cot_metas": use_cot_metas,
|
| 769 |
+
"use_cot_caption": use_cot_caption,
|
| 770 |
+
"use_cot_language": use_cot_language,
|
| 771 |
+
"repainting_group": repainting_group,
|
| 772 |
+
"repainting_start": repainting_start,
|
| 773 |
+
"repainting_end": repainting_end,
|
| 774 |
+
"audio_cover_strength": audio_cover_strength,
|
| 775 |
+
# Simple/Custom Mode Components
|
| 776 |
+
"generation_mode": generation_mode,
|
| 777 |
+
"simple_mode_group": simple_mode_group,
|
| 778 |
+
"simple_query_input": simple_query_input,
|
| 779 |
+
"random_desc_btn": random_desc_btn,
|
| 780 |
+
"simple_instrumental_checkbox": simple_instrumental_checkbox,
|
| 781 |
+
"simple_vocal_language": simple_vocal_language,
|
| 782 |
+
"create_sample_btn": create_sample_btn,
|
| 783 |
+
"simple_sample_created": simple_sample_created,
|
| 784 |
+
"caption_accordion": caption_accordion,
|
| 785 |
+
"lyrics_accordion": lyrics_accordion,
|
| 786 |
+
"optional_params_accordion": optional_params_accordion,
|
| 787 |
+
# Existing components
|
| 788 |
+
"captions": captions,
|
| 789 |
+
"sample_btn": sample_btn,
|
| 790 |
+
"load_file": load_file,
|
| 791 |
+
"lyrics": lyrics,
|
| 792 |
+
"vocal_language": vocal_language,
|
| 793 |
+
"bpm": bpm,
|
| 794 |
+
"key_scale": key_scale,
|
| 795 |
+
"time_signature": time_signature,
|
| 796 |
+
"audio_duration": audio_duration,
|
| 797 |
+
"batch_size_input": batch_size_input,
|
| 798 |
+
"inference_steps": inference_steps,
|
| 799 |
+
"guidance_scale": guidance_scale,
|
| 800 |
+
"seed": seed,
|
| 801 |
+
"random_seed_checkbox": random_seed_checkbox,
|
| 802 |
+
"use_adg": use_adg,
|
| 803 |
+
"cfg_interval_start": cfg_interval_start,
|
| 804 |
+
"cfg_interval_end": cfg_interval_end,
|
| 805 |
+
"shift": shift,
|
| 806 |
+
"infer_method": infer_method,
|
| 807 |
+
"custom_timesteps": custom_timesteps,
|
| 808 |
+
"audio_format": audio_format,
|
| 809 |
+
"think_checkbox": think_checkbox,
|
| 810 |
+
"autogen_checkbox": autogen_checkbox,
|
| 811 |
+
"generate_btn": generate_btn,
|
| 812 |
+
"instrumental_checkbox": instrumental_checkbox,
|
| 813 |
+
"format_btn": format_btn,
|
| 814 |
+
"constrained_decoding_debug": constrained_decoding_debug,
|
| 815 |
+
"score_scale": score_scale,
|
| 816 |
+
"allow_lm_batch": allow_lm_batch,
|
| 817 |
+
"auto_score": auto_score,
|
| 818 |
+
"auto_lrc": auto_lrc,
|
| 819 |
+
"lm_batch_chunk_size": lm_batch_chunk_size,
|
| 820 |
+
# GPU config values for validation
|
| 821 |
+
"gpu_config": gpu_config,
|
| 822 |
+
"max_duration": max_duration,
|
| 823 |
+
"max_batch_size": max_batch_size,
|
| 824 |
+
}
|
acestep/gradio_ui/interfaces/result.py
ADDED
|
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI Results Section Module
|
| 3 |
+
Contains results display section component definitions
|
| 4 |
+
"""
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from acestep.gradio_ui.i18n import t
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Result slots are numbered 1..8: slots 1-4 live in the first row, slots 5-8
# in a second row that starts hidden.
_FIRST_ROW_VISIBILITY = ((1, True), (2, True), (3, False), (4, False))
_SECOND_ROW_SLOTS = (5, 6, 7, 8)

# Per-slot component groups, in the order they appear in the returned dict.
# Each name doubles as the key prefix: "<name>_<slot number>".
_SLOT_GROUPS = (
    "generated_audio",
    "audio_col",
    "send_to_src_btn",
    "save_btn",
    "score_btn",
    "score_display",
    "codes_display",
    "lrc_btn",
    "lrc_display",
    "details_accordion",
)


def _create_result_slot(n: int, **column_kwargs) -> dict:
    """Build the components of result slot ``n`` and return them by group name.

    A slot is one column holding an audio player, a row of four action
    buttons and a collapsed details accordion with three textboxes.

    Args:
        n: Slot number, passed to the translated labels as ``n``.
        **column_kwargs: Forwarded unchanged to ``gr.Column`` (used for the
            per-slot ``visible`` flag; empty means the Gradio defaults).

    Returns:
        Dict keyed by the names in ``_SLOT_GROUPS``.
    """
    with gr.Column(**column_kwargs) as audio_col:
        generated_audio = gr.Audio(
            label=t("results.generated_music", n=n),
            type="filepath",
            interactive=False,
            buttons=[],
        )
        with gr.Row(equal_height=True):
            send_to_src_btn = gr.Button(
                t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1
            )
            save_btn = gr.Button(
                t("results.save_btn"), variant="primary", size="sm", scale=1
            )
            score_btn = gr.Button(
                t("results.score_btn"), variant="secondary", size="sm", scale=1
            )
            lrc_btn = gr.Button(
                t("results.lrc_btn"), variant="secondary", size="sm", scale=1
            )
        with gr.Accordion(
            t("results.details_accordion"), open=False, visible=True
        ) as details_accordion:
            codes_display = gr.Textbox(
                label=t("results.codes_label", n=n),
                interactive=False,
                buttons=["copy"],
                lines=4,
                max_lines=4,
                visible=True,
            )
            score_display = gr.Textbox(
                label=t("results.quality_score_label", n=n),
                interactive=False,
                buttons=["copy"],
                lines=6,
                max_lines=6,
                visible=True,
            )
            # The LRC box is the only editable textbox in the slot.
            lrc_display = gr.Textbox(
                label=t("results.lrc_label", n=n),
                interactive=True,
                buttons=["copy"],
                lines=8,
                max_lines=8,
                visible=True,
            )

    return {
        "generated_audio": generated_audio,
        "audio_col": audio_col,
        "send_to_src_btn": send_to_src_btn,
        "save_btn": save_btn,
        "score_btn": score_btn,
        "score_display": score_display,
        "codes_display": codes_display,
        "lrc_btn": lrc_btn,
        "lrc_display": lrc_display,
        "details_accordion": details_accordion,
    }


def create_results_section(dit_handler) -> dict:
    """Create the results display section.

    Builds, inside a "results" accordion: hidden state holders, eight result
    slots (two rows of four), a status textbox, batch navigation controls, a
    restore-parameters button and a collapsed accordion with the batch file
    list and generation details.

    Args:
        dit_handler: Accepted for interface compatibility; this function does
            not read it.

    Returns:
        Dict mapping component names to the created Gradio components. Slot
        components are keyed ``"<group>_<n>"`` for ``n`` in 1..8 (for example
        ``"generated_audio_3"`` or ``"lrc_display_8"``).
    """
    # NOTE(review): the block nesting below was reconstructed from a flattened
    # diff view without indentation -- confirm against the repository copy.
    slots = {}

    with gr.Accordion(t("results.title"), open=True):
        # Hidden state to store LM-generated metadata
        lm_metadata_state = gr.State(value=None)

        # Hidden state to track if caption/metadata is from formatted source (LM/transcription)
        is_format_caption_state = gr.State(value=False)

        # Batch management states
        current_batch_index = gr.State(value=0)  # Currently displayed batch index
        total_batches = gr.State(value=1)  # Total number of batches generated
        batch_queue = gr.State(value={})  # Dictionary storing all batch data
        generation_params_state = gr.State(value={})  # Store generation parameters for next batches
        is_generating_background = gr.State(value=False)  # Background generation flag

        # All audio components in one row with dynamic visibility
        with gr.Row():
            for n, visible in _FIRST_ROW_VISIBILITY:
                slots[n] = _create_result_slot(n, visible=visible)

        # Second row for batch size 5-8 (initially hidden)
        with gr.Row(visible=False) as audio_row_5_8:
            for n in _SECOND_ROW_SLOTS:
                # Columns in this row use the Gradio defaults; only the row is hidden.
                slots[n] = _create_result_slot(n)

        status_output = gr.Textbox(label=t("results.generation_status"), interactive=False)

        # Batch navigation controls
        with gr.Row(equal_height=True):
            prev_batch_btn = gr.Button(
                t("results.prev_btn"),
                variant="secondary",
                interactive=False,
                scale=1,
                size="sm",
            )
            batch_indicator = gr.Textbox(
                label=t("results.current_batch"),
                value=t("results.batch_indicator", current=1, total=1),
                interactive=False,
                scale=3,
            )
            next_batch_status = gr.Textbox(
                label=t("results.next_batch_status"),
                value="",
                interactive=False,
                scale=3,
            )
            next_batch_btn = gr.Button(
                t("results.next_btn"),
                variant="primary",
                interactive=False,
                scale=1,
                size="sm",
            )

        # One-click restore parameters button
        restore_params_btn = gr.Button(
            t("results.restore_params_btn"),
            variant="secondary",
            interactive=False,  # Initially disabled, enabled after generation
            size="sm",
        )

        with gr.Accordion(t("results.batch_results_title"), open=False):
            generated_audio_batch = gr.File(
                label=t("results.all_files_label"),
                file_count="multiple",
                interactive=False,
            )
            generation_info = gr.Markdown(label=t("results.generation_details"))

    components = {
        "lm_metadata_state": lm_metadata_state,
        "is_format_caption_state": is_format_caption_state,
        "current_batch_index": current_batch_index,
        "total_batches": total_batches,
        "batch_queue": batch_queue,
        "generation_params_state": generation_params_state,
        "is_generating_background": is_generating_background,
        "status_output": status_output,
        "prev_batch_btn": prev_batch_btn,
        "batch_indicator": batch_indicator,
        "next_batch_btn": next_batch_btn,
        "next_batch_status": next_batch_status,
        "restore_params_btn": restore_params_btn,
    }

    # Flatten the slots group by group ("generated_audio_1".."generated_audio_8",
    # then "audio_col_1".., ...), matching the original hand-written key order.
    slot_numbers = sorted(slots)
    for group in _SLOT_GROUPS:
        for n in slot_numbers:
            components[f"{group}_{n}"] = slots[n][group]
        if group == "generated_audio":
            # The second-row container sits right after the audio players.
            components["audio_row_5_8"] = audio_row_5_8

    components["generated_audio_batch"] = generated_audio_batch
    components["generation_info"] = generation_info
    return components
|
| 552 |
+
|
acestep/gradio_ui/interfaces/training.py
ADDED
|
@@ -0,0 +1,625 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI Training Tab Module
|
| 3 |
+
|
| 4 |
+
Contains the dataset builder and LoRA training interface components.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import gradio as gr
|
| 9 |
+
from acestep.gradio_ui.i18n import t
|
| 10 |
+
from acestep.constants import DEBUG_TRAINING
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def create_training_section(dit_handler, llm_handler, init_params=None) -> dict:
    """Build the LoRA training tab: dataset builder plus training controls.

    Args:
        dit_handler: DiT handler instance. Accepted as part of the section
            builder signature; this function does not reference it.
        llm_handler: LLM handler instance. Likewise not referenced here.
        init_params: Optional dictionary of initialization parameters. Only
            its ``service_mode`` entry is read; a truthy value hides the
            whole training tab. ``None`` means "not in service mode".

    Returns:
        Dictionary mapping component names to the Gradio components created
        here, so callers can attach event handlers to them.
    """
    # Service mode hides the training tab entirely.
    service_mode = False if init_params is None else init_params.get("service_mode", False)

    # Any DEBUG_TRAINING value other than "OFF" (case/whitespace-insensitive)
    # relaxes the epoch slider so single-epoch runs are possible.
    debug_flag = str(DEBUG_TRAINING).strip().upper()
    if debug_flag == "OFF":
        epoch_min, epoch_step, epoch_default = 100, 100, 1000
    else:
        epoch_min, epoch_step, epoch_default = 1, 1, 1

    def heading(icon, label_key, ruled=True):
        """Render an <h3> section title, optionally preceded by an <hr>."""
        rule = "<hr>" if ruled else ""
        return gr.HTML(f"{rule}<h3>{icon} {t(label_key)}</h3>")

    def readonly_box(label_key, **extra):
        """Create a non-editable textbox used for status/progress output."""
        return gr.Textbox(label=t(label_key), interactive=False, **extra)

    with gr.Tab(t("training.tab_title"), visible=not service_mode):
        gr.HTML(
            "\n"
            '<div style="text-align: center; padding: 10px; margin-bottom: 15px;">\n'
            "<h2>🎵 LoRA Training for ACE-Step</h2>\n"
            "<p>Build datasets from your audio files and train custom LoRA adapters</p>\n"
            "</div>\n"
        )

        with gr.Tabs():
            # ---------------- Dataset Builder tab ----------------
            with gr.Tab(t("training.tab_dataset_builder")):
                # Entry point: either load a saved dataset or scan a folder.
                quick_start_title = t("training.quick_start_title")
                gr.HTML(
                    "\n"
                    '<div style="padding: 10px; margin-bottom: 10px; border: 1px solid #4a4a6a; '
                    'border-radius: 8px; background: linear-gradient(135deg, #2a2a4a 0%, #1a1a3a 100%);">\n'
                    f'<h3 style="margin: 0 0 5px 0;">{quick_start_title}</h3>\n'
                    '<p style="margin: 0; color: #aaa;">Choose one: <b>Load existing dataset</b> '
                    "OR <b>Scan new directory</b></p>\n"
                    "</div>\n"
                )

                with gr.Row():
                    with gr.Column(scale=1):
                        gr.HTML("<h4>📂 Load Existing Dataset</h4>")
                        with gr.Row():
                            load_json_path = gr.Textbox(
                                label=t("training.load_dataset_label"),
                                placeholder="./datasets/my_lora_dataset.json",
                                info=t("training.load_dataset_info"),
                                scale=3,
                            )
                            load_json_btn = gr.Button(
                                t("training.load_btn"),
                                variant="primary",
                                scale=1,
                            )
                        load_json_status = readonly_box("training.load_status")

                    with gr.Column(scale=1):
                        gr.HTML("<h4>🔍 Scan New Directory</h4>")
                        with gr.Row():
                            audio_directory = gr.Textbox(
                                label=t("training.scan_label"),
                                placeholder="/path/to/your/audio/folder",
                                info=t("training.scan_info"),
                                scale=3,
                            )
                            scan_btn = gr.Button(
                                t("training.scan_btn"),
                                variant="secondary",
                                scale=1,
                            )
                        scan_status = readonly_box("training.scan_status")

                gr.HTML("<hr>")

                with gr.Row():
                    with gr.Column(scale=2):
                        # Read-only overview of every audio file in the dataset.
                        table_headers = [
                            "#", "Filename", "Duration", "Lyrics",
                            "Labeled", "BPM", "Key", "Caption",
                        ]
                        table_types = ["number", "str", "str", "str", "str", "str", "str", "str"]
                        audio_files_table = gr.Dataframe(
                            headers=table_headers,
                            datatype=table_types,
                            label=t("training.found_audio_files"),
                            interactive=False,
                            wrap=True,
                        )

                    with gr.Column(scale=1):
                        heading("⚙️", "training.dataset_settings_header", ruled=False)

                        dataset_name = gr.Textbox(
                            label=t("training.dataset_name"),
                            value="my_lora_dataset",
                            placeholder=t("training.dataset_name_placeholder"),
                        )

                        all_instrumental = gr.Checkbox(
                            label=t("training.all_instrumental"),
                            value=True,
                            info=t("training.all_instrumental_info"),
                        )

                        # The two LM lyric options are shown but disabled
                        # until the required model update is available.
                        format_lyrics = gr.Checkbox(
                            label="Format Lyrics (LM)",
                            value=False,
                            info="Use LM to format/structure user-provided lyrics from .txt files (coming soon)",
                            interactive=False,
                        )

                        transcribe_lyrics = gr.Checkbox(
                            label="Transcribe Lyrics (LM)",
                            value=False,
                            info="Use LM to transcribe lyrics from audio (coming soon)",
                            interactive=False,
                        )

                        custom_tag = gr.Textbox(
                            label=t("training.custom_tag"),
                            placeholder="e.g., 8bit_retro, my_style",
                            info=t("training.custom_tag_info"),
                        )

                        tag_choices = [
                            (t("training.tag_prepend"), "prepend"),
                            (t("training.tag_append"), "append"),
                            (t("training.tag_replace"), "replace"),
                        ]
                        tag_position = gr.Radio(
                            choices=tag_choices,
                            value="replace",
                            label=t("training.tag_position"),
                            info=t("training.tag_position_info"),
                        )

                        genre_ratio = gr.Slider(
                            minimum=0,
                            maximum=100,
                            step=10,
                            value=0,
                            label=t("training.genre_ratio"),
                            info=t("training.genre_ratio_info"),
                        )

                heading("🤖", "training.step2_title")

                with gr.Row():
                    with gr.Column(scale=3):
                        gr.Markdown(
                            "\n"
                            "Click the button below to automatically generate metadata "
                            "for all audio files using AI:\n"
                            "- **Caption**: Music style, genre, mood description\n"
                            "- **BPM**: Beats per minute\n"
                            "- **Key**: Musical key (e.g., C Major, Am)\n"
                            "- **Time Signature**: 4/4, 3/4, etc.\n"
                        )
                        skip_metas = gr.Checkbox(
                            label=t("training.skip_metas"),
                            value=False,
                            info=t("training.skip_metas_info"),
                        )
                        only_unlabeled = gr.Checkbox(
                            label=t("training.only_unlabeled"),
                            value=False,
                            info=t("training.only_unlabeled_info"),
                        )
                    with gr.Column(scale=1):
                        auto_label_btn = gr.Button(
                            t("training.auto_label_btn"),
                            variant="primary",
                            size="lg",
                        )

                label_progress = readonly_box("training.label_progress", lines=2)

                heading("👀", "training.step3_title")

                with gr.Row():
                    with gr.Column(scale=1):
                        # Starts as a 0..0 range; the slider is created with
                        # no samples to select yet.
                        sample_selector = gr.Slider(
                            minimum=0,
                            maximum=0,
                            step=1,
                            value=0,
                            label=t("training.select_sample"),
                            info=t("training.select_sample_info"),
                        )

                        preview_audio = gr.Audio(
                            label=t("training.audio_preview"),
                            type="filepath",
                            interactive=False,
                        )

                        preview_filename = readonly_box("training.filename")

                    with gr.Column(scale=2):
                        with gr.Row():
                            edit_caption = gr.Textbox(
                                label=t("training.caption"),
                                lines=3,
                                placeholder="Music description...",
                            )

                        with gr.Row():
                            edit_genre = gr.Textbox(
                                label=t("training.genre"),
                                lines=1,
                                placeholder="pop, electronic, dance...",
                            )
                            prompt_override = gr.Dropdown(
                                choices=["Use Global Ratio", "Caption", "Genre"],
                                value="Use Global Ratio",
                                label=t("training.prompt_override_label"),
                                info=t("training.prompt_override_info"),
                            )

                        with gr.Row():
                            edit_lyrics = gr.Textbox(
                                label=t("training.lyrics_editable_label"),
                                lines=6,
                                placeholder="[Verse 1]\nLyrics here...\n\n[Chorus]\n...",
                            )
                            # Read-only copy of the raw lyrics; hidden until
                            # a sample actually has raw lyrics to show.
                            raw_lyrics_display = gr.Textbox(
                                label=t("training.raw_lyrics_label"),
                                lines=6,
                                placeholder=t("training.no_lyrics_placeholder"),
                                interactive=False,
                                visible=False,
                            )
                            # Tracks whether the raw-lyrics box should be visible.
                            has_raw_lyrics_state = gr.State(False)

                        with gr.Row():
                            edit_bpm = gr.Number(
                                label=t("training.bpm"),
                                precision=0,
                            )
                            edit_keyscale = gr.Textbox(
                                label=t("training.key_label"),
                                placeholder=t("training.key_placeholder"),
                            )
                            edit_timesig = gr.Dropdown(
                                choices=["", "2", "3", "4", "6", "N/A"],
                                label=t("training.time_sig"),
                            )
                            edit_duration = gr.Number(
                                label=t("training.duration_s"),
                                precision=1,
                                interactive=False,
                            )

                        with gr.Row():
                            language_codes = [
                                "instrumental", "en", "zh", "ja", "ko",
                                "es", "fr", "de", "pt", "ru", "unknown",
                            ]
                            edit_language = gr.Dropdown(
                                choices=language_codes,
                                value="instrumental",
                                label=t("training.language"),
                            )
                            edit_instrumental = gr.Checkbox(
                                label=t("training.instrumental"),
                                value=True,
                            )
                            save_edit_btn = gr.Button(
                                t("training.save_changes_btn"),
                                variant="secondary",
                            )

                        edit_status = readonly_box("training.edit_status")

                heading("💾", "training.step4_title")

                with gr.Row():
                    with gr.Column(scale=3):
                        save_path = gr.Textbox(
                            label=t("training.save_path"),
                            value="./datasets/my_lora_dataset.json",
                            placeholder="./datasets/dataset_name.json",
                            info=t("training.save_path_info"),
                        )
                    with gr.Column(scale=1):
                        save_dataset_btn = gr.Button(
                            t("training.save_dataset_btn"),
                            variant="primary",
                            size="lg",
                        )

                save_status = readonly_box("training.save_status", lines=2)

                heading("⚡", "training.step5_title")

                gr.Markdown(
                    "\n"
                    "**Preprocessing converts your dataset to pre-computed tensors for fast training.**\n"
                    "\n"
                    "You can either:\n"
                    "- Use the dataset from Steps 1-4 above, **OR**\n"
                    "- Load an existing dataset JSON file (if you've already saved one)\n"
                )

                with gr.Row():
                    with gr.Column(scale=3):
                        load_existing_dataset_path = gr.Textbox(
                            label=t("training.load_existing_label"),
                            placeholder="./datasets/my_lora_dataset.json",
                            info=t("training.load_existing_info"),
                        )
                    with gr.Column(scale=1):
                        load_existing_dataset_btn = gr.Button(
                            t("training.load_dataset_btn"),
                            variant="secondary",
                            size="lg",
                        )

                load_existing_status = readonly_box("training.load_status")

                gr.Markdown(
                    "\n"
                    "This step:\n"
                    "- Encodes audio to VAE latents\n"
                    "- Encodes captions and lyrics to text embeddings\n"
                    "- Runs the condition encoder\n"
                    "- Saves all tensors to `.pt` files\n"
                    "\n"
                    "⚠️ **This requires the model to be loaded and may take a few minutes.**\n"
                )

                with gr.Row():
                    with gr.Column(scale=3):
                        preprocess_output_dir = gr.Textbox(
                            label=t("training.tensor_output_dir"),
                            value="./datasets/preprocessed_tensors",
                            placeholder="./datasets/preprocessed_tensors",
                            info=t("training.tensor_output_info"),
                        )
                    with gr.Column(scale=1):
                        preprocess_btn = gr.Button(
                            t("training.preprocess_btn"),
                            variant="primary",
                            size="lg",
                        )

                preprocess_progress = readonly_box("training.preprocess_progress", lines=3)

            # ---------------- Train LoRA tab ----------------
            with gr.Tab(t("training.tab_train_lora")):
                with gr.Row():
                    with gr.Column(scale=2):
                        heading("📊", "training.train_section_tensors", ruled=False)

                        gr.Markdown(
                            "\n"
                            "Select the directory containing preprocessed tensor files (`.pt` files).\n"
                            'These are created in the "Dataset Builder" tab using the "Preprocess" button.\n'
                        )

                        training_tensor_dir = gr.Textbox(
                            label=t("training.preprocessed_tensors_dir"),
                            placeholder="./datasets/preprocessed_tensors",
                            value="./datasets/preprocessed_tensors",
                            info=t("training.preprocessed_tensors_info"),
                        )

                        load_dataset_btn = gr.Button(
                            t("training.load_dataset_btn"),
                            variant="secondary",
                        )

                        training_dataset_info = readonly_box("training.dataset_info", lines=3)

                    with gr.Column(scale=1):
                        heading("⚙️", "training.train_section_lora", ruled=False)

                        lora_rank = gr.Slider(
                            minimum=4,
                            maximum=256,
                            step=4,
                            value=64,
                            label=t("training.lora_rank"),
                            info=t("training.lora_rank_info"),
                        )

                        lora_alpha = gr.Slider(
                            minimum=4,
                            maximum=512,
                            step=4,
                            value=128,
                            label=t("training.lora_alpha"),
                            info=t("training.lora_alpha_info"),
                        )

                        lora_dropout = gr.Slider(
                            minimum=0.0,
                            maximum=0.5,
                            step=0.05,
                            value=0.1,
                            label=t("training.lora_dropout"),
                        )

                heading("🎛️", "training.train_section_params")

                with gr.Row():
                    learning_rate = gr.Number(
                        label=t("training.learning_rate"),
                        value=3e-4,
                        info=t("training.learning_rate_info"),
                    )

                    # Bounds depend on DEBUG_TRAINING (see top of function).
                    train_epochs = gr.Slider(
                        minimum=epoch_min,
                        maximum=4000,
                        step=epoch_step,
                        value=epoch_default,
                        label=t("training.max_epochs"),
                    )

                    train_batch_size = gr.Slider(
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=1,
                        label=t("training.batch_size"),
                        info=t("training.batch_size_info"),
                    )

                    gradient_accumulation = gr.Slider(
                        minimum=1,
                        maximum=16,
                        step=1,
                        value=1,
                        label=t("training.gradient_accumulation"),
                        info=t("training.gradient_accumulation_info"),
                    )

                with gr.Row():
                    save_every_n_epochs = gr.Slider(
                        minimum=50,
                        maximum=1000,
                        step=50,
                        value=200,
                        label=t("training.save_every_n_epochs"),
                    )

                    training_shift = gr.Slider(
                        minimum=1.0,
                        maximum=5.0,
                        step=0.5,
                        value=3.0,
                        label=t("training.shift"),
                        info=t("training.shift_info"),
                    )

                    training_seed = gr.Number(
                        label=t("training.seed"),
                        value=42,
                        precision=0,
                    )

                with gr.Row():
                    lora_output_dir = gr.Textbox(
                        label=t("training.output_dir"),
                        value="./lora_output",
                        placeholder="./lora_output",
                        info=t("training.output_dir_info"),
                    )

                with gr.Row():
                    resume_checkpoint_dir = gr.Textbox(
                        label="Resume Checkpoint (optional)",
                        placeholder="./lora_output/checkpoints/epoch_200",
                        info="Directory of a saved LoRA checkpoint to resume from",
                    )

                gr.HTML("<hr>")

                with gr.Row():
                    with gr.Column(scale=1):
                        start_training_btn = gr.Button(
                            t("training.start_training_btn"),
                            variant="primary",
                            size="lg",
                        )
                    with gr.Column(scale=1):
                        stop_training_btn = gr.Button(
                            t("training.stop_training_btn"),
                            variant="stop",
                            size="lg",
                        )

                training_progress = readonly_box("training.training_progress", lines=2)

                with gr.Row():
                    training_log = readonly_box(
                        "training.training_log",
                        lines=10,
                        max_lines=15,
                        scale=1,
                    )
                    training_loss_plot = gr.LinePlot(
                        x="step",
                        y="loss",
                        title=t("training.training_loss_title"),
                        x_title=t("training.step"),
                        y_title=t("training.loss"),
                        scale=1,
                    )

                heading("📦", "training.export_header")

                with gr.Row():
                    export_path = gr.Textbox(
                        label=t("training.export_path"),
                        value="./lora_output/final_lora",
                        placeholder="./lora_output/my_lora",
                    )
                    export_lora_btn = gr.Button(
                        t("training.export_lora_btn"),
                        variant="secondary",
                    )

                export_status = readonly_box("training.export_status")

        # Non-visual state holders for the dataset builder and training run.
        dataset_builder_state = gr.State(None)
        training_state = gr.State({"is_training": False, "should_stop": False})

    components = {
        # Dataset builder: load / scan
        "load_json_path": load_json_path,
        "load_json_btn": load_json_btn,
        "load_json_status": load_json_status,
        "audio_directory": audio_directory,
        "scan_btn": scan_btn,
        "scan_status": scan_status,
        "audio_files_table": audio_files_table,
        # Dataset builder: settings and auto-labeling
        "dataset_name": dataset_name,
        "all_instrumental": all_instrumental,
        "format_lyrics": format_lyrics,
        "transcribe_lyrics": transcribe_lyrics,
        "custom_tag": custom_tag,
        "tag_position": tag_position,
        "skip_metas": skip_metas,
        "only_unlabeled": only_unlabeled,
        "auto_label_btn": auto_label_btn,
        "label_progress": label_progress,
        # Dataset builder: per-sample preview and editing
        "sample_selector": sample_selector,
        "preview_audio": preview_audio,
        "preview_filename": preview_filename,
        "edit_caption": edit_caption,
        "edit_genre": edit_genre,
        "prompt_override": prompt_override,
        "genre_ratio": genre_ratio,
        "edit_lyrics": edit_lyrics,
        "raw_lyrics_display": raw_lyrics_display,
        "has_raw_lyrics_state": has_raw_lyrics_state,
        "edit_bpm": edit_bpm,
        "edit_keyscale": edit_keyscale,
        "edit_timesig": edit_timesig,
        "edit_duration": edit_duration,
        "edit_language": edit_language,
        "edit_instrumental": edit_instrumental,
        "save_edit_btn": save_edit_btn,
        "edit_status": edit_status,
        "save_path": save_path,
        "save_dataset_btn": save_dataset_btn,
        "save_status": save_status,
        # Preprocessing
        "load_existing_dataset_path": load_existing_dataset_path,
        "load_existing_dataset_btn": load_existing_dataset_btn,
        "load_existing_status": load_existing_status,
        "preprocess_output_dir": preprocess_output_dir,
        "preprocess_btn": preprocess_btn,
        "preprocess_progress": preprocess_progress,
        "dataset_builder_state": dataset_builder_state,
        # Training
        "training_tensor_dir": training_tensor_dir,
        "load_dataset_btn": load_dataset_btn,
        "training_dataset_info": training_dataset_info,
        "lora_rank": lora_rank,
        "lora_alpha": lora_alpha,
        "lora_dropout": lora_dropout,
        "learning_rate": learning_rate,
        "train_epochs": train_epochs,
        "train_batch_size": train_batch_size,
        "gradient_accumulation": gradient_accumulation,
        "save_every_n_epochs": save_every_n_epochs,
        "training_shift": training_shift,
        "training_seed": training_seed,
        "lora_output_dir": lora_output_dir,
        "resume_checkpoint_dir": resume_checkpoint_dir,
        "start_training_btn": start_training_btn,
        "stop_training_btn": stop_training_btn,
        "training_progress": training_progress,
        "training_log": training_log,
        "training_loss_plot": training_loss_plot,
        "export_path": export_path,
        "export_lora_btn": export_lora_btn,
        "export_status": export_status,
        "training_state": training_state,
    }
    return components
|
acestep/handler.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
acestep/inference.py
ADDED
|
@@ -0,0 +1,1310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ACE-Step Inference API Module
|
| 3 |
+
|
| 4 |
+
This module provides a standardized inference interface for music generation,
|
| 5 |
+
designed for third-party integration. It offers both a simplified API and
|
| 6 |
+
backward-compatible Gradio UI support.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import math
|
| 10 |
+
import os
|
| 11 |
+
import tempfile
|
| 12 |
+
import shutil
|
| 13 |
+
import subprocess
|
| 14 |
+
import sys
|
| 15 |
+
from typing import Optional, Union, List, Dict, Any, Tuple
|
| 16 |
+
from dataclasses import dataclass, field, asdict
|
| 17 |
+
from loguru import logger
|
| 18 |
+
|
| 19 |
+
from acestep.audio_utils import AudioSaver, generate_uuid_from_params, is_audio_silent
|
| 20 |
+
from acestep.constants import TASK_INSTRUCTIONS
|
| 21 |
+
from acestep.gpu_config import get_gpu_config
|
| 22 |
+
|
| 23 |
+
# HuggingFace Space environment detection
|
| 24 |
+
IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
|
| 25 |
+
|
| 26 |
+
def _get_spaces_gpu_decorator(duration=180):
    """Return the decorator used to request a GPU for a function.

    When ``IS_HUGGINGFACE_SPACE`` is true and the ``spaces`` package can be
    imported, this returns ``spaces.GPU(duration=duration)``. In every other
    case (not running in a Space, or ``spaces`` is not installed) it returns
    an identity decorator that hands the decorated function back unchanged.

    Args:
        duration: Value forwarded to ``spaces.GPU`` as its ``duration``
            keyword argument.

    Returns:
        A callable that accepts a function and returns a function.
    """
    def _identity(func):
        # No-op decorator used whenever the Space GPU decorator is unavailable.
        return func

    if not IS_HUGGINGFACE_SPACE:
        return _identity

    try:
        # Imported lazily so this module still loads where ``spaces`` is absent.
        import spaces
        return spaces.GPU(duration=duration)
    except ImportError:
        logger.warning("spaces package not found, GPU decorator disabled")
        return _identity
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass
class GenerationParams:
    """Configuration for music generation parameters.

    Attributes:
        # Text Inputs
        caption: A short text prompt describing the desired music (main prompt). < 512 characters
        lyrics: Lyrics for the music. Use "[Instrumental]" for instrumental songs. < 4096 characters
        instrumental: If True, generate instrumental music regardless of lyrics.

        # Music Metadata
        bpm: BPM (beats per minute), e.g., 120. Set to None for automatic estimation. 30 ~ 300
        keyscale: Musical key (e.g., "C Major", "Am"). Leave empty for auto-detection. A-G, #/♭, major/minor
        timesignature: Time signature (2 for '2/4', 3 for '3/4', 4 for '4/4', 6 for '6/8'). Leave empty for auto-detection.
        vocal_language: Language code for vocals, e.g., "en", "zh", "ja", or "unknown". See acestep/constants.py:VALID_LANGUAGES
        duration: Target audio length in seconds. If <0 or None, model chooses automatically. 10 ~ 600

        # Generation Parameters
        inference_steps: Number of diffusion steps (e.g., 8 for turbo, 32–100 for base model).
        guidance_scale: CFG (classifier-free guidance) strength. Higher means following the prompt more strictly. Only supported for the non-turbo model.
        seed: Integer seed for reproducibility. -1 means use a random seed each time.

        # Advanced DiT Parameters
        use_adg: Whether to use Adaptive Dual Guidance (only works for base model).
        cfg_interval_start: Start ratio (0.0–1.0) to apply CFG.
        cfg_interval_end: End ratio (0.0–1.0) to apply CFG.
        shift: Timestep shift factor (default 1.0). When != 1.0, applies t = shift * t / (1 + (shift - 1) * t) to timesteps.
        infer_method: Diffusion inference method, "ode" or "sde".
        timesteps: Optional custom timestep list (parsed from a string such as
            "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0"). If provided, overrides
            inference_steps and shift.

        # Task-Specific Parameters
        task_type: Type of generation task. One of: "text2music", "cover", "repaint", "lego", "extract", "complete".
        reference_audio: Path to a reference audio file for style transfer or cover tasks.
        src_audio: Path to a source audio file for audio-to-audio tasks.
        audio_codes: Audio semantic codes as a string (advanced use, for code-control generation).
        repainting_start: For repaint/lego tasks: start time in seconds for region to repaint.
        repainting_end: For repaint/lego tasks: end time in seconds for region to repaint (-1 for until end).
        audio_cover_strength: Strength of reference audio/codes influence (range 0.0–1.0). Set smaller (0.2) for style transfer tasks.
        instruction: Optional task instruction prompt. If empty, auto-generated by system.

        # 5Hz Language Model Parameters for CoT reasoning
        thinking: If True, enable 5Hz Language Model "Chain-of-Thought" reasoning for semantic/music metadata and codes.
        lm_temperature: Sampling temperature for the LLM (0.0–2.0). Higher = more creative/varied results.
        lm_cfg_scale: Classifier-free guidance scale for the LLM.
        lm_top_k: LLM top-k sampling (0 = disabled).
        lm_top_p: LLM top-p nucleus sampling (1.0 = disabled).
        lm_negative_prompt: Negative prompt to use for LLM (for control).
        use_cot_metas: Whether to let LLM generate music metadata via CoT reasoning.
        use_cot_caption: Whether to let LLM rewrite or format the input caption via CoT reasoning.
        use_cot_lyrics: Reserved flag; marked as not used yet.
        use_cot_language: Whether to let LLM detect vocal language via CoT.
        use_constrained_decoding: Flag for constrained decoding in the LM
            (NOTE(review): exact effect is defined by the LM handler — confirm there).

        # CoT result fields
        cot_bpm, cot_keyscale, cot_timesignature, cot_duration,
        cot_vocal_language, cot_caption, cot_lyrics: Presumably hold the
            values produced by the LM's CoT step, kept separately from the
            user-supplied fields above — TODO confirm against callers.
    """
    # Required Inputs
    task_type: str = "text2music"
    instruction: str = "Fill the audio semantic mask based on the given conditions:"

    # Audio Uploads
    reference_audio: Optional[str] = None
    src_audio: Optional[str] = None

    # LM Codes Hints
    audio_codes: str = ""

    # Text Inputs
    caption: str = ""
    lyrics: str = ""
    instrumental: bool = False

    # Metadata
    vocal_language: str = "unknown"
    bpm: Optional[int] = None
    keyscale: str = ""
    timesignature: str = ""
    duration: float = -1.0

    # Advanced Settings
    inference_steps: int = 8
    seed: int = -1
    guidance_scale: float = 7.0
    use_adg: bool = False
    cfg_interval_start: float = 0.0
    cfg_interval_end: float = 1.0
    shift: float = 1.0
    infer_method: str = "ode"  # "ode" or "sde" - diffusion inference method
    # Custom timesteps (parsed from string like "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
    # If provided, overrides inference_steps and shift
    timesteps: Optional[List[float]] = None

    repainting_start: float = 0.0
    repainting_end: float = -1
    audio_cover_strength: float = 1.0

    # 5Hz Language Model Parameters
    thinking: bool = True
    lm_temperature: float = 0.85
    lm_cfg_scale: float = 2.0
    lm_top_k: int = 0
    lm_top_p: float = 0.9
    lm_negative_prompt: str = "NO USER INPUT"
    use_cot_metas: bool = True
    use_cot_caption: bool = True
    use_cot_lyrics: bool = False  # TODO: not used yet
    use_cot_language: bool = True
    use_constrained_decoding: bool = True

    # CoT result fields (see class docstring)
    cot_bpm: Optional[int] = None
    cot_keyscale: str = ""
    cot_timesignature: str = ""
    cot_duration: Optional[float] = None
    cot_vocal_language: str = "unknown"
    cot_caption: str = ""
    cot_lyrics: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to dictionary for JSON serialization."""
        return asdict(self)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
@dataclass
class GenerationConfig:
    """Configuration for music generation.

    Attributes:
        batch_size: Number of audio samples to generate
        allow_lm_batch: Whether to allow batch processing in LM
        use_random_seed: Whether to use random seed
        seeds: Seed(s) for batch generation. Can be:
            - None: Use random seeds (when use_random_seed=True) or params.seed (when use_random_seed=False)
            - List[int]: List of seeds, will be padded with random seeds if fewer than batch_size
            - int: Single seed value (will be converted to list and padded)
        lm_batch_chunk_size: Batch chunk size for LM processing
        constrained_decoding_debug: Whether to enable constrained decoding debug
        audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac"
    """
    batch_size: int = 2
    allow_lm_batch: bool = False
    use_random_seed: bool = True
    # Annotated to admit a bare int as well as a list, matching the documented
    # contract above; the value itself is stored as given.
    seeds: Optional[Union[int, List[int]]] = None
    lm_batch_chunk_size: int = 8
    constrained_decoding_debug: bool = False
    audio_format: str = "flac"  # Default to FLAC for fast saving

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to dictionary for JSON serialization."""
        return asdict(self)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
@dataclass
class GenerationResult:
    """Result of music generation.

    Attributes:
        audios: List of audio dictionaries with paths, keys, params
        status_message: Status message from generation
        extra_outputs: Extra outputs from generation
        success: Whether generation completed successfully
        error: Error message if generation failed
    """

    # Audio Outputs
    # default_factory gives every instance its own list instead of a shared one.
    audios: List[Dict[str, Any]] = field(default_factory=list)
    # Generation Information
    status_message: str = ""
    extra_outputs: Dict[str, Any] = field(default_factory=dict)
    # Success Status
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert result to dictionary for JSON serialization."""
        return asdict(self)
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
@dataclass
class UnderstandResult:
    """Result of music understanding from audio codes.

    Attributes:
        # Metadata Fields
        caption: Generated caption describing the music
        lyrics: Generated or extracted lyrics
        bpm: Beats per minute (None if not detected)
        duration: Duration in seconds (None if not detected)
        keyscale: Musical key (e.g., "C Major")
        language: Vocal language code (e.g., "en", "zh")
        timesignature: Time signature (e.g., "4/4")

        # Status
        status_message: Status message from understanding
        success: Whether understanding completed successfully
        error: Error message if understanding failed
    """
    # Metadata Fields
    caption: str = ""
    lyrics: str = ""
    bpm: Optional[int] = None
    duration: Optional[float] = None
    keyscale: str = ""
    language: str = ""
    timesignature: str = ""

    # Status
    status_message: str = ""
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert result to dictionary for JSON serialization."""
        return asdict(self)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def _update_metadata_from_lm(
    metadata: Dict[str, Any],
    bpm: Optional[int],
    key_scale: str,
    time_signature: str,
    audio_duration: Optional[float],
    vocal_language: str,
    caption: str,
    lyrics: str,
) -> Tuple[Optional[int], str, str, Optional[float], str, str, str]:
    """Fill in metadata fields from LM output where the user left them unset.

    A field counts as unset when it is ``None`` (``bpm``), empty/falsy
    (``key_scale``, ``time_signature``, ``vocal_language``, ``caption``,
    ``lyrics``), or ``None``/non-positive (``audio_duration``). Values the
    user did provide are returned unchanged.

    Args:
        metadata: Mapping of LM-produced metadata. Keys read here are
            ``bpm``, ``keyscale`` (or ``key_scale``), ``timesignature``
            (or ``time_signature``), ``duration``, ``vocal_language``,
            ``caption`` and ``lyrics``.
        bpm: User-supplied BPM, or None.
        key_scale: User-supplied key/scale, or empty.
        time_signature: User-supplied time signature, or empty.
        audio_duration: User-supplied duration in seconds, or None / <= 0.
        vocal_language: User-supplied vocal language, or empty.
        caption: User-supplied caption, or empty.
        lyrics: User-supplied lyrics, or empty.

    Returns:
        Tuple ``(bpm, key_scale, time_signature, audio_duration,
        vocal_language, caption, lyrics)`` after the fill-in.
    """
    if bpm is None:
        bpm_value = metadata.get('bpm')
        # "N/A" and "" are placeholder values, not a real tempo.
        if bpm_value and bpm_value not in ("N/A", ""):
            try:
                bpm = int(bpm_value)
            except (ValueError, TypeError):
                # Unparseable LM output: leave bpm unset.
                pass

    if not key_scale:
        # Accept either spelling of the key; previously the 'key_scale'
        # fallback could never be reached because 'keyscale' was required
        # to be present first.
        key_scale_value = metadata.get('keyscale') or metadata.get('key_scale')
        if key_scale_value and key_scale_value != "N/A":
            key_scale = key_scale_value

    if not time_signature:
        # Same dual-spelling handling as for the key above.
        time_signature_value = metadata.get('timesignature') or metadata.get('time_signature')
        if time_signature_value and time_signature_value != "N/A":
            time_signature = time_signature_value

    if audio_duration is None or audio_duration <= 0:
        # When the LM gave no duration the -1 default is converted too, so an
        # unset duration comes back as -1.0 (existing behaviour, kept as-is).
        audio_duration_value = metadata.get('duration', -1)
        if audio_duration_value not in ("N/A", ""):
            try:
                audio_duration = float(audio_duration_value)
            except (ValueError, TypeError):
                # Unparseable LM output: keep the caller's value.
                pass

    if not vocal_language and metadata.get('vocal_language'):
        vocal_language = metadata.get('vocal_language')
    if not caption and metadata.get('caption'):
        caption = metadata.get('caption')
    if not lyrics and metadata.get('lyrics'):
        lyrics = metadata.get('lyrics')
    return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
@_get_spaces_gpu_decorator(duration=180)
|
| 298 |
+
def generate_music(
|
| 299 |
+
dit_handler,
|
| 300 |
+
llm_handler,
|
| 301 |
+
params: GenerationParams,
|
| 302 |
+
config: GenerationConfig,
|
| 303 |
+
save_dir: Optional[str] = None,
|
| 304 |
+
progress=None,
|
| 305 |
+
) -> GenerationResult:
|
| 306 |
+
"""Generate music using ACE-Step model with optional LM reasoning.
|
| 307 |
+
|
| 308 |
+
Args:
|
| 309 |
+
dit_handler: Initialized DiT model handler (AceStepHandler instance)
|
| 310 |
+
llm_handler: Initialized LLM handler (LLMHandler instance)
|
| 311 |
+
params: Generation parameters (GenerationParams instance)
|
| 312 |
+
config: Generation configuration (GenerationConfig instance)
|
| 313 |
+
|
| 314 |
+
Returns:
|
| 315 |
+
GenerationResult with generated audio files and metadata
|
| 316 |
+
"""
|
| 317 |
+
try:
|
| 318 |
+
# Phase 1: LM-based metadata and code generation (if enabled)
|
| 319 |
+
audio_code_string_to_use = params.audio_codes
|
| 320 |
+
lm_generated_metadata = None
|
| 321 |
+
lm_generated_audio_codes_list = []
|
| 322 |
+
lm_total_time_costs = {
|
| 323 |
+
"phase1_time": 0.0,
|
| 324 |
+
"phase2_time": 0.0,
|
| 325 |
+
"total_time": 0.0,
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
# Extract mutable copies of metadata (will be updated by LM if needed)
|
| 329 |
+
bpm = params.bpm
|
| 330 |
+
key_scale = params.keyscale
|
| 331 |
+
time_signature = params.timesignature
|
| 332 |
+
audio_duration = params.duration
|
| 333 |
+
dit_input_caption = params.caption
|
| 334 |
+
dit_input_vocal_language = params.vocal_language
|
| 335 |
+
dit_input_lyrics = params.lyrics
|
| 336 |
+
# Determine if we need to generate audio codes
|
| 337 |
+
# If user has provided audio_codes, we don't need to generate them
|
| 338 |
+
# Otherwise, check if we need audio codes (lm_dit mode) or just metas (dit mode)
|
| 339 |
+
user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip())
|
| 340 |
+
|
| 341 |
+
# Safety: cover task without any source audio or codes produces silence.
|
| 342 |
+
if params.task_type == "cover":
|
| 343 |
+
no_src_audio = not (params.reference_audio or params.src_audio)
|
| 344 |
+
if no_src_audio and not user_provided_audio_codes:
|
| 345 |
+
logger.warning("Cover task requested without source audio or audio codes. Falling back to text2music.")
|
| 346 |
+
params.task_type = "text2music"
|
| 347 |
+
if params.instruction == TASK_INSTRUCTIONS.get("cover"):
|
| 348 |
+
params.instruction = TASK_INSTRUCTIONS.get("text2music", params.instruction)
|
| 349 |
+
|
| 350 |
+
# Determine infer_type: use "llm_dit" if we need audio codes, "dit" if only metas needed
|
| 351 |
+
# For now, we use "llm_dit" if batch mode or if user hasn't provided codes
|
| 352 |
+
# Use "dit" if user has provided codes (only need metas) or if explicitly only need metas
|
| 353 |
+
# Note: This logic can be refined based on specific requirements
|
| 354 |
+
need_audio_codes = not user_provided_audio_codes
|
| 355 |
+
|
| 356 |
+
# Determine if we should use chunk-based LM generation (always use chunks for consistency)
|
| 357 |
+
# Determine actual batch size for chunk processing
|
| 358 |
+
actual_batch_size = config.batch_size if config.batch_size is not None else 1
|
| 359 |
+
|
| 360 |
+
# Prepare seeds for batch generation
|
| 361 |
+
# Use config.seed if provided, otherwise fallback to params.seed
|
| 362 |
+
# Convert config.seed (None, int, or List[int]) to format that prepare_seeds accepts
|
| 363 |
+
seed_for_generation = ""
|
| 364 |
+
# Original code (commented out because it crashes on int seeds):
|
| 365 |
+
# if config.seeds is not None and len(config.seeds) > 0:
|
| 366 |
+
# if isinstance(config.seeds, list):
|
| 367 |
+
# # Convert List[int] to comma-separated string
|
| 368 |
+
# seed_for_generation = ",".join(str(s) for s in config.seeds)
|
| 369 |
+
|
| 370 |
+
if config.seeds is not None:
|
| 371 |
+
if isinstance(config.seeds, list) and len(config.seeds) > 0:
|
| 372 |
+
# Convert List[int] to comma-separated string
|
| 373 |
+
seed_for_generation = ",".join(str(s) for s in config.seeds)
|
| 374 |
+
elif isinstance(config.seeds, int):
|
| 375 |
+
# Fix: Explicitly handle single integer seeds by converting to string.
|
| 376 |
+
# Previously, this would crash because 'len()' was called on an int.
|
| 377 |
+
seed_for_generation = str(config.seeds)
|
| 378 |
+
|
| 379 |
+
# Use dit_handler.prepare_seeds to handle seed list generation and padding
|
| 380 |
+
# This will handle all the logic: padding with random seeds if needed, etc.
|
| 381 |
+
actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)
|
| 382 |
+
|
| 383 |
+
# LM-based Chain-of-Thought reasoning
|
| 384 |
+
# Skip LM for cover/repaint tasks - these tasks use reference/src audio directly
|
| 385 |
+
# and don't need LM to generate audio codes
|
| 386 |
+
skip_lm_tasks = {"cover", "repaint"}
|
| 387 |
+
|
| 388 |
+
# Determine if we should use LLM
|
| 389 |
+
# LLM is needed for:
|
| 390 |
+
# 1. thinking=True: generate audio codes via LM
|
| 391 |
+
# 2. use_cot_caption=True: enhance/generate caption via CoT
|
| 392 |
+
# 3. use_cot_language=True: detect vocal language via CoT
|
| 393 |
+
# 4. use_cot_metas=True: fill missing metadata via CoT
|
| 394 |
+
need_lm_for_cot = params.use_cot_caption or params.use_cot_language or params.use_cot_metas
|
| 395 |
+
use_lm = (params.thinking or need_lm_for_cot) and llm_handler is not None and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks
|
| 396 |
+
lm_status = []
|
| 397 |
+
|
| 398 |
+
if params.task_type in skip_lm_tasks:
|
| 399 |
+
logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly")
|
| 400 |
+
|
| 401 |
+
logger.info(f"[generate_music] LLM usage decision: thinking={params.thinking}, "
|
| 402 |
+
f"use_cot_caption={params.use_cot_caption}, use_cot_language={params.use_cot_language}, "
|
| 403 |
+
f"use_cot_metas={params.use_cot_metas}, need_lm_for_cot={need_lm_for_cot}, "
|
| 404 |
+
f"llm_initialized={llm_handler.llm_initialized if llm_handler else False}, use_lm={use_lm}")
|
| 405 |
+
|
| 406 |
+
def _infer_audio_duration_seconds(audio_path: str) -> Optional[float]:
    """Best-effort duration inference for common audio formats.

    Probes are tried in order and the first one that yields a duration wins:
    ``torchaudio.info``, then ``soundfile.info``, then (on macOS only, when
    the ``afinfo`` executable is on PATH) the output of ``afinfo``.

    Args:
        audio_path: Path to the audio file. Falsy values return None.

    Returns:
        Duration in seconds, or None when no probe produced one.
    """
    if not audio_path:
        return None

    # Try torchaudio (supports more formats when ffmpeg backend is available)
    try:
        import torchaudio
        info = torchaudio.info(audio_path)
        if info and info.num_frames and info.sample_rate:
            return float(info.num_frames) / float(info.sample_rate)
    except Exception:
        # Deliberately swallowed: a missing package or unsupported file just
        # means the next probe gets a turn.
        pass

    # Try soundfile (fast for wav/flac)
    try:
        import soundfile as sf
        info = sf.info(audio_path)
        if info and info.frames and info.samplerate:
            return float(info.frames) / float(info.samplerate)
    except Exception:
        # Same best-effort policy as above.
        pass

    # macOS fallback: use afinfo for m4a/aac
    if sys.platform == "darwin" and shutil.which("afinfo"):
        try:
            result = subprocess.run(
                ["afinfo", audio_path],
                check=False,
                capture_output=True,
                text=True,
            )
            for line in (result.stdout or "").splitlines():
                if "duration:" not in line:
                    continue
                # Only look at what follows the label, and accept the number
                # with or without a trailing "s" so that both
                # "duration: 183.165s" and "duration: 183.165000 sec" parse.
                tail = line.partition("duration:")[2]
                for token in tail.split():
                    try:
                        return float(token.rstrip("s"))
                    except ValueError:
                        continue
        except Exception:
            # afinfo could not be run or its output was unusable.
            pass
    return None
|
| 449 |
+
|
| 450 |
+
# Clamp duration and batch size to GPU limits (applies to non-Gradio callers too)
|
| 451 |
+
try:
|
| 452 |
+
# If duration not provided, try to infer from source audio to enable safe clamping.
|
| 453 |
+
if (audio_duration is None or float(audio_duration) <= 0) and (params.src_audio or params.reference_audio):
|
| 454 |
+
audio_path = params.src_audio or params.reference_audio
|
| 455 |
+
try:
|
| 456 |
+
inferred = _infer_audio_duration_seconds(audio_path)
|
| 457 |
+
if inferred and inferred > 0:
|
| 458 |
+
audio_duration = inferred
|
| 459 |
+
params.duration = inferred
|
| 460 |
+
logger.info(f"[generate_music] Inferred duration from audio file: {inferred:.2f}s")
|
| 461 |
+
except Exception as e:
|
| 462 |
+
logger.warning(f"[generate_music] Failed to infer duration from audio file: {e}")
|
| 463 |
+
|
| 464 |
+
gpu_config = get_gpu_config()
|
| 465 |
+
max_duration = gpu_config.max_duration_with_lm if use_lm else gpu_config.max_duration_without_lm
|
| 466 |
+
if audio_duration is not None and float(audio_duration) > 0 and float(audio_duration) > max_duration:
|
| 467 |
+
logger.warning(f"[generate_music] Duration {audio_duration}s exceeds GPU limit {max_duration}s. Clamping.")
|
| 468 |
+
audio_duration = float(max_duration)
|
| 469 |
+
params.duration = float(max_duration)
|
| 470 |
+
|
| 471 |
+
max_batch = gpu_config.max_batch_size_with_lm if use_lm else gpu_config.max_batch_size_without_lm
|
| 472 |
+
if config.batch_size is not None and config.batch_size > max_batch:
|
| 473 |
+
logger.warning(f"[generate_music] Batch size {config.batch_size} exceeds GPU limit {max_batch}. Clamping.")
|
| 474 |
+
config.batch_size = max_batch
|
| 475 |
+
|
| 476 |
+
# Extra safety for MPS: large durations can OOM with batch > 1
|
| 477 |
+
if (
|
| 478 |
+
hasattr(dit_handler, "device")
|
| 479 |
+
and dit_handler.device == "mps"
|
| 480 |
+
and audio_duration is not None
|
| 481 |
+
and float(audio_duration) > 180
|
| 482 |
+
and config.batch_size is not None
|
| 483 |
+
and config.batch_size > 1
|
| 484 |
+
):
|
| 485 |
+
logger.warning("[generate_music] MPS with long duration detected; reducing batch size to 1 to avoid OOM.")
|
| 486 |
+
config.batch_size = 1
|
| 487 |
+
except Exception as e:
|
| 488 |
+
logger.warning(f"[generate_music] Failed to clamp duration/batch to GPU limits: {e}")
|
| 489 |
+
|
| 490 |
+
if use_lm:
|
| 491 |
+
# Convert sampling parameters - handle None values safely
|
| 492 |
+
top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
|
| 493 |
+
top_p_value = None if not params.lm_top_p or params.lm_top_p >= 1.0 else params.lm_top_p
|
| 494 |
+
|
| 495 |
+
# Build user_metadata from user-provided values
|
| 496 |
+
user_metadata = {}
|
| 497 |
+
if bpm is not None:
|
| 498 |
+
try:
|
| 499 |
+
bpm_value = float(bpm)
|
| 500 |
+
if bpm_value > 0:
|
| 501 |
+
user_metadata['bpm'] = int(bpm_value)
|
| 502 |
+
except (ValueError, TypeError):
|
| 503 |
+
pass
|
| 504 |
+
|
| 505 |
+
if key_scale and key_scale.strip():
|
| 506 |
+
key_scale_clean = key_scale.strip()
|
| 507 |
+
if key_scale_clean.lower() not in ["n/a", ""]:
|
| 508 |
+
user_metadata['keyscale'] = key_scale_clean
|
| 509 |
+
|
| 510 |
+
if time_signature and time_signature.strip():
|
| 511 |
+
time_sig_clean = time_signature.strip()
|
| 512 |
+
if time_sig_clean.lower() not in ["n/a", ""]:
|
| 513 |
+
user_metadata['timesignature'] = time_sig_clean
|
| 514 |
+
|
| 515 |
+
if audio_duration is not None:
|
| 516 |
+
try:
|
| 517 |
+
duration_value = float(audio_duration)
|
| 518 |
+
if duration_value > 0:
|
| 519 |
+
user_metadata['duration'] = int(duration_value)
|
| 520 |
+
except (ValueError, TypeError):
|
| 521 |
+
pass
|
| 522 |
+
|
| 523 |
+
user_metadata_to_pass = user_metadata if user_metadata else None
|
| 524 |
+
|
| 525 |
+
# Determine infer_type based on whether we need audio codes
|
| 526 |
+
# - "llm_dit": generates both metas and audio codes (two-phase internally)
|
| 527 |
+
# - "dit": generates only metas (single phase)
|
| 528 |
+
infer_type = "llm_dit" if need_audio_codes and params.thinking else "dit"
|
| 529 |
+
|
| 530 |
+
# Use chunk size from config, or default to batch_size if not set
|
| 531 |
+
max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size
|
| 532 |
+
num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)
|
| 533 |
+
|
| 534 |
+
all_metadata_list = []
|
| 535 |
+
all_audio_codes_list = []
|
| 536 |
+
|
| 537 |
+
for chunk_idx in range(num_chunks):
|
| 538 |
+
chunk_start = chunk_idx * max_inference_batch_size
|
| 539 |
+
chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
|
| 540 |
+
chunk_size = chunk_end - chunk_start
|
| 541 |
+
chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None
|
| 542 |
+
|
| 543 |
+
logger.info(f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
|
| 544 |
+
f"(size: {chunk_size}, seeds: {chunk_seeds})")
|
| 545 |
+
|
| 546 |
+
# Use the determined infer_type
|
| 547 |
+
# - "llm_dit" will internally run two phases (metas + codes)
|
| 548 |
+
# - "dit" will only run phase 1 (metas only)
|
| 549 |
+
result = llm_handler.generate_with_stop_condition(
|
| 550 |
+
caption=params.caption or "",
|
| 551 |
+
lyrics=params.lyrics or "",
|
| 552 |
+
infer_type=infer_type,
|
| 553 |
+
temperature=params.lm_temperature,
|
| 554 |
+
cfg_scale=params.lm_cfg_scale,
|
| 555 |
+
negative_prompt=params.lm_negative_prompt,
|
| 556 |
+
top_k=top_k_value,
|
| 557 |
+
top_p=top_p_value,
|
| 558 |
+
target_duration=audio_duration, # Pass duration to limit audio codes generation
|
| 559 |
+
user_metadata=user_metadata_to_pass,
|
| 560 |
+
use_cot_caption=params.use_cot_caption,
|
| 561 |
+
use_cot_language=params.use_cot_language,
|
| 562 |
+
use_cot_metas=params.use_cot_metas,
|
| 563 |
+
use_constrained_decoding=params.use_constrained_decoding,
|
| 564 |
+
constrained_decoding_debug=config.constrained_decoding_debug,
|
| 565 |
+
batch_size=chunk_size,
|
| 566 |
+
seeds=chunk_seeds,
|
| 567 |
+
progress=progress,
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
# Check if LM generation failed
|
| 571 |
+
if not result.get("success", False):
|
| 572 |
+
error_msg = result.get("error", "Unknown LM error")
|
| 573 |
+
lm_status.append(f"❌ LM Error: {error_msg}")
|
| 574 |
+
# Return early with error
|
| 575 |
+
return GenerationResult(
|
| 576 |
+
audios=[],
|
| 577 |
+
status_message=f"❌ LM generation failed: {error_msg}",
|
| 578 |
+
extra_outputs={},
|
| 579 |
+
success=False,
|
| 580 |
+
error=error_msg,
|
| 581 |
+
)
|
| 582 |
+
|
| 583 |
+
# Extract metadata and audio_codes from result dict
|
| 584 |
+
if chunk_size > 1:
|
| 585 |
+
metadata_list = result.get("metadata", [])
|
| 586 |
+
audio_codes_list = result.get("audio_codes", [])
|
| 587 |
+
all_metadata_list.extend(metadata_list)
|
| 588 |
+
all_audio_codes_list.extend(audio_codes_list)
|
| 589 |
+
else:
|
| 590 |
+
metadata = result.get("metadata", {})
|
| 591 |
+
audio_codes = result.get("audio_codes", "")
|
| 592 |
+
all_metadata_list.append(metadata)
|
| 593 |
+
all_audio_codes_list.append(audio_codes)
|
| 594 |
+
|
| 595 |
+
# Collect time costs from LM extra_outputs
|
| 596 |
+
lm_extra = result.get("extra_outputs", {})
|
| 597 |
+
lm_chunk_time_costs = lm_extra.get("time_costs", {})
|
| 598 |
+
if lm_chunk_time_costs:
|
| 599 |
+
# Accumulate time costs from all chunks
|
| 600 |
+
for key in ["phase1_time", "phase2_time", "total_time"]:
|
| 601 |
+
if key in lm_chunk_time_costs:
|
| 602 |
+
lm_total_time_costs[key] += lm_chunk_time_costs[key]
|
| 603 |
+
|
| 604 |
+
time_str = ", ".join([f"{k}: {v:.2f}s" for k, v in lm_chunk_time_costs.items()])
|
| 605 |
+
lm_status.append(f"✅ LM chunk {chunk_idx+1}: {time_str}")
|
| 606 |
+
|
| 607 |
+
lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
|
| 608 |
+
lm_generated_audio_codes_list = all_audio_codes_list
|
| 609 |
+
|
| 610 |
+
# Set audio_code_string_to_use based on infer_type
|
| 611 |
+
if infer_type == "llm_dit":
|
| 612 |
+
# If batch mode, use list; otherwise use single string
|
| 613 |
+
if actual_batch_size > 1:
|
| 614 |
+
audio_code_string_to_use = all_audio_codes_list
|
| 615 |
+
else:
|
| 616 |
+
audio_code_string_to_use = all_audio_codes_list[0] if all_audio_codes_list else ""
|
| 617 |
+
else:
|
| 618 |
+
# For "dit" mode, keep user-provided codes or empty
|
| 619 |
+
audio_code_string_to_use = params.audio_codes
|
| 620 |
+
|
| 621 |
+
# Update metadata from LM if not provided by user
|
| 622 |
+
if lm_generated_metadata:
|
| 623 |
+
bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics = _update_metadata_from_lm(
|
| 624 |
+
metadata=lm_generated_metadata,
|
| 625 |
+
bpm=bpm,
|
| 626 |
+
key_scale=key_scale,
|
| 627 |
+
time_signature=time_signature,
|
| 628 |
+
audio_duration=audio_duration,
|
| 629 |
+
vocal_language=dit_input_vocal_language,
|
| 630 |
+
caption=dit_input_caption,
|
| 631 |
+
lyrics=dit_input_lyrics)
|
| 632 |
+
if not params.bpm:
|
| 633 |
+
params.cot_bpm = bpm
|
| 634 |
+
if not params.keyscale:
|
| 635 |
+
params.cot_keyscale = key_scale
|
| 636 |
+
if not params.timesignature:
|
| 637 |
+
params.cot_timesignature = time_signature
|
| 638 |
+
if not params.duration:
|
| 639 |
+
params.cot_duration = audio_duration
|
| 640 |
+
if not params.vocal_language:
|
| 641 |
+
params.cot_vocal_language = vocal_language
|
| 642 |
+
if not params.caption:
|
| 643 |
+
params.cot_caption = caption
|
| 644 |
+
if not params.lyrics:
|
| 645 |
+
params.cot_lyrics = lyrics
|
| 646 |
+
dit_input_lyrics = lyrics
|
| 647 |
+
|
| 648 |
+
# set cot caption and language if needed
|
| 649 |
+
if params.use_cot_caption:
|
| 650 |
+
dit_input_caption = lm_generated_metadata.get("caption", dit_input_caption)
|
| 651 |
+
if params.use_cot_language:
|
| 652 |
+
dit_input_vocal_language = lm_generated_metadata.get("vocal_language", dit_input_vocal_language)
|
| 653 |
+
|
| 654 |
+
# Phase 2: DiT music generation
|
| 655 |
+
# Use seed_for_generation (from config.seed or params.seed) instead of params.seed for actual generation
|
| 656 |
+
result = dit_handler.generate_music(
|
| 657 |
+
captions=dit_input_caption,
|
| 658 |
+
lyrics=dit_input_lyrics,
|
| 659 |
+
bpm=bpm,
|
| 660 |
+
key_scale=key_scale,
|
| 661 |
+
time_signature=time_signature,
|
| 662 |
+
vocal_language=dit_input_vocal_language,
|
| 663 |
+
inference_steps=params.inference_steps,
|
| 664 |
+
guidance_scale=params.guidance_scale,
|
| 665 |
+
use_random_seed=config.use_random_seed,
|
| 666 |
+
seed=seed_for_generation, # Use config.seed (or params.seed fallback) instead of params.seed directly
|
| 667 |
+
reference_audio=params.reference_audio,
|
| 668 |
+
audio_duration=audio_duration,
|
| 669 |
+
batch_size=config.batch_size if config.batch_size is not None else 1,
|
| 670 |
+
src_audio=params.src_audio,
|
| 671 |
+
audio_code_string=audio_code_string_to_use,
|
| 672 |
+
repainting_start=params.repainting_start,
|
| 673 |
+
repainting_end=params.repainting_end,
|
| 674 |
+
instruction=params.instruction,
|
| 675 |
+
audio_cover_strength=params.audio_cover_strength,
|
| 676 |
+
task_type=params.task_type,
|
| 677 |
+
use_adg=params.use_adg,
|
| 678 |
+
cfg_interval_start=params.cfg_interval_start,
|
| 679 |
+
cfg_interval_end=params.cfg_interval_end,
|
| 680 |
+
shift=params.shift,
|
| 681 |
+
infer_method=params.infer_method,
|
| 682 |
+
timesteps=params.timesteps,
|
| 683 |
+
progress=progress,
|
| 684 |
+
)
|
| 685 |
+
|
| 686 |
+
# Check if generation failed
|
| 687 |
+
if not result.get("success", False):
|
| 688 |
+
return GenerationResult(
|
| 689 |
+
audios=[],
|
| 690 |
+
status_message=result.get("status_message", ""),
|
| 691 |
+
extra_outputs={},
|
| 692 |
+
success=False,
|
| 693 |
+
error=result.get("error"),
|
| 694 |
+
)
|
| 695 |
+
|
| 696 |
+
# Extract results from dit_handler.generate_music dict
|
| 697 |
+
dit_audios = result.get("audios", [])
|
| 698 |
+
status_message = result.get("status_message", "")
|
| 699 |
+
dit_extra_outputs = result.get("extra_outputs", {})
|
| 700 |
+
|
| 701 |
+
# Use the seed list already prepared above (from config.seed or params.seed fallback)
|
| 702 |
+
# actual_seed_list was computed earlier using dit_handler.prepare_seeds
|
| 703 |
+
seed_list = actual_seed_list
|
| 704 |
+
|
| 705 |
+
# Get base params dictionary
|
| 706 |
+
base_params_dict = params.to_dict()
|
| 707 |
+
|
| 708 |
+
# Save audio files using AudioSaver (format from config)
|
| 709 |
+
audio_format = config.audio_format if config.audio_format else "flac"
|
| 710 |
+
audio_saver = AudioSaver(default_format=audio_format)
|
| 711 |
+
|
| 712 |
+
# Use handler's temp_dir for saving files
|
| 713 |
+
if save_dir is not None:
|
| 714 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 715 |
+
|
| 716 |
+
# Build audios list for GenerationResult with params and save files
|
| 717 |
+
# Audio saving and UUID generation handled here, outside of handler
|
| 718 |
+
audios = []
|
| 719 |
+
silent_warnings = []
|
| 720 |
+
for idx, dit_audio in enumerate(dit_audios):
|
| 721 |
+
# Create a copy of params dict for this audio
|
| 722 |
+
audio_params = base_params_dict.copy()
|
| 723 |
+
|
| 724 |
+
# Update audio-specific values
|
| 725 |
+
audio_params["seed"] = seed_list[idx] if idx < len(seed_list) else None
|
| 726 |
+
|
| 727 |
+
# Add audio codes if batch mode
|
| 728 |
+
if lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list):
|
| 729 |
+
audio_params["audio_codes"] = lm_generated_audio_codes_list[idx]
|
| 730 |
+
|
| 731 |
+
# Get audio tensor and metadata
|
| 732 |
+
audio_tensor = dit_audio.get("tensor")
|
| 733 |
+
sample_rate = dit_audio.get("sample_rate", 48000)
|
| 734 |
+
|
| 735 |
+
# Generate UUID for this audio (moved from handler)
|
| 736 |
+
batch_seed = seed_list[idx] if idx < len(seed_list) else seed_list[0] if seed_list else -1
|
| 737 |
+
audio_code_str = lm_generated_audio_codes_list[idx] if (
|
| 738 |
+
lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list)) else audio_code_string_to_use
|
| 739 |
+
if isinstance(audio_code_str, list):
|
| 740 |
+
audio_code_str = audio_code_str[idx] if idx < len(audio_code_str) else ""
|
| 741 |
+
|
| 742 |
+
audio_key = generate_uuid_from_params(audio_params)
|
| 743 |
+
|
| 744 |
+
silent_check = False
|
| 745 |
+
if audio_tensor is not None:
|
| 746 |
+
silent_check, rms_val, peak_val = is_audio_silent(audio_tensor, channels_first=True)
|
| 747 |
+
if silent_check:
|
| 748 |
+
logger.warning(
|
| 749 |
+
f"[generate_music] Silent output detected (idx={idx}, RMS={rms_val:.2e}, peak={peak_val:.2e}). "
|
| 750 |
+
"Likely cause: LLM backend returned empty conditioning, or incompatible torch/triton/flash-attn. "
|
| 751 |
+
"Suggest running with --backend pt."
|
| 752 |
+
)
|
| 753 |
+
silent_warnings.append(
|
| 754 |
+
f"Output {idx + 1}: silent or near-silent (RMS≈{rms_val:.2e}). "
|
| 755 |
+
"Likely causes: LLM backend failure, incompatible torch/triton/flash-attn, or CPU/fallback path. "
|
| 756 |
+
"Try running with --backend pt."
|
| 757 |
+
)
|
| 758 |
+
|
| 759 |
+
audio_path = None
|
| 760 |
+
if audio_tensor is not None and save_dir is not None and not silent_check:
|
| 761 |
+
try:
|
| 762 |
+
audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}")
|
| 763 |
+
audio_path = audio_saver.save_audio(audio_tensor,
|
| 764 |
+
audio_file,
|
| 765 |
+
sample_rate=sample_rate,
|
| 766 |
+
format=audio_format,
|
| 767 |
+
channels_first=True)
|
| 768 |
+
except Exception as e:
|
| 769 |
+
logger.error(f"[generate_music] Failed to save audio file: {e}")
|
| 770 |
+
audio_path = ""
|
| 771 |
+
|
| 772 |
+
audio_dict = {
|
| 773 |
+
"path": audio_path or "",
|
| 774 |
+
"tensor": audio_tensor,
|
| 775 |
+
"key": audio_key,
|
| 776 |
+
"sample_rate": sample_rate,
|
| 777 |
+
"params": audio_params,
|
| 778 |
+
"silent": silent_check,
|
| 779 |
+
}
|
| 780 |
+
|
| 781 |
+
audios.append(audio_dict)
|
| 782 |
+
|
| 783 |
+
# Merge extra_outputs: include dit_extra_outputs (latents, masks) and add LM metadata
|
| 784 |
+
extra_outputs = dit_extra_outputs.copy()
|
| 785 |
+
extra_outputs["lm_metadata"] = lm_generated_metadata
|
| 786 |
+
|
| 787 |
+
# Merge time_costs from both LM and DiT into a unified dictionary
|
| 788 |
+
unified_time_costs = {}
|
| 789 |
+
|
| 790 |
+
# Add LM time costs (if LM was used)
|
| 791 |
+
if use_lm and lm_total_time_costs:
|
| 792 |
+
for key, value in lm_total_time_costs.items():
|
| 793 |
+
unified_time_costs[f"lm_{key}"] = value
|
| 794 |
+
|
| 795 |
+
# Add DiT time costs (if available)
|
| 796 |
+
dit_time_costs = dit_extra_outputs.get("time_costs", {})
|
| 797 |
+
if dit_time_costs:
|
| 798 |
+
for key, value in dit_time_costs.items():
|
| 799 |
+
unified_time_costs[f"dit_{key}"] = value
|
| 800 |
+
|
| 801 |
+
# Calculate total pipeline time
|
| 802 |
+
if unified_time_costs:
|
| 803 |
+
lm_total = unified_time_costs.get("lm_total_time", 0.0)
|
| 804 |
+
dit_total = unified_time_costs.get("dit_total_time_cost", 0.0)
|
| 805 |
+
unified_time_costs["pipeline_total_time"] = lm_total + dit_total
|
| 806 |
+
|
| 807 |
+
# Update extra_outputs with unified time_costs
|
| 808 |
+
extra_outputs["time_costs"] = unified_time_costs
|
| 809 |
+
|
| 810 |
+
if lm_status:
|
| 811 |
+
status_message = "\n".join(lm_status) + "\n" + status_message
|
| 812 |
+
else:
|
| 813 |
+
status_message = status_message
|
| 814 |
+
if silent_warnings:
|
| 815 |
+
status_message = "⚠️ Silent output detected:\n" + "\n".join(silent_warnings) + "\n\nSuggested fix: try running with --backend pt\n\n" + (status_message or "")
|
| 816 |
+
# Create and return GenerationResult
|
| 817 |
+
return GenerationResult(
|
| 818 |
+
audios=audios,
|
| 819 |
+
status_message=status_message,
|
| 820 |
+
extra_outputs=extra_outputs,
|
| 821 |
+
success=True,
|
| 822 |
+
error=None,
|
| 823 |
+
)
|
| 824 |
+
|
| 825 |
+
except Exception as e:
|
| 826 |
+
logger.exception("Music generation failed")
|
| 827 |
+
return GenerationResult(
|
| 828 |
+
audios=[],
|
| 829 |
+
status_message=f"Error: {str(e)}",
|
| 830 |
+
extra_outputs={},
|
| 831 |
+
success=False,
|
| 832 |
+
error=str(e),
|
| 833 |
+
)
|
| 834 |
+
|
| 835 |
+
|
| 836 |
+
def understand_music(
    llm_handler,
    audio_codes: str,
    temperature: float = 0.85,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    use_constrained_decoding: bool = True,
    constrained_decoding_debug: bool = False,
) -> UnderstandResult:
    """Understand music from audio codes using the 5Hz Language Model.

    This function analyzes audio semantic codes and generates metadata about the music,
    including caption, lyrics, BPM, duration, key scale, language, and time signature.

    If audio_codes is empty or "NO USER INPUT", the LM will generate a sample example
    instead of analyzing existing codes.

    Note: cfg_scale and negative_prompt are not supported in understand mode.

    Args:
        llm_handler: Initialized LLM handler (LLMHandler instance). ``None`` or a
            handler whose ``llm_initialized`` flag is falsy yields a failed result
            instead of raising.
        audio_codes: String of audio code tokens (e.g., "<|audio_code_123|><|audio_code_456|>...")
            Use empty string or "NO USER INPUT" to generate a sample example.
        temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
        top_k: Top-K sampling (None or 0 = disabled)
        top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
        repetition_penalty: Repetition penalty (1.0 = no penalty)
        use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
        constrained_decoding_debug: Whether to enable debug logging for constrained decoding

    Returns:
        UnderstandResult with parsed metadata fields and status. Failures (handler
        not ready, empty metadata, or any exception raised while calling the
        handler) are reported through ``success=False`` and ``error`` rather than
        by raising.

    Example:
        >>> result = understand_music(llm_handler, audio_codes="<|audio_code_123|>...")
        >>> if result.success:
        ...     print(f"Caption: {result.caption}")
        ...     print(f"BPM: {result.bpm}")
        ...     print(f"Lyrics: {result.lyrics}")
    """
    # Check if LLM is initialized. getattr also covers llm_handler=None, which
    # would otherwise raise AttributeError before any result could be returned.
    if not getattr(llm_handler, "llm_initialized", False):
        return UnderstandResult(
            status_message="5Hz LM not initialized. Please initialize it first.",
            success=False,
            error="LLM not initialized",
        )

    # If codes are empty, use "NO USER INPUT" to generate a sample example
    if not audio_codes or not audio_codes.strip():
        audio_codes = "NO USER INPUT"

    try:
        # Call LLM understanding
        metadata, status = llm_handler.understand_audio_from_codes(
            audio_codes=audio_codes,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            use_constrained_decoding=use_constrained_decoding,
            constrained_decoding_debug=constrained_decoding_debug,
        )

        # Check if LLM returned empty metadata (error case)
        if not metadata:
            return UnderstandResult(
                status_message=status or "Failed to understand audio codes",
                success=False,
                error=status or "Empty metadata returned",
            )

        # Extract and convert fields
        caption = metadata.get('caption', '')
        lyrics = metadata.get('lyrics', '')
        keyscale = metadata.get('keyscale', '')
        language = metadata.get('language', metadata.get('vocal_language', ''))
        timesignature = metadata.get('timesignature', '')

        # Convert BPM to int. int("120.0") raises ValueError, so fall back to
        # parsing through float; otherwise a float-formatted BPM string would
        # be silently discarded.
        bpm = None
        bpm_value = metadata.get('bpm')
        if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
            try:
                bpm = int(bpm_value)
            except (ValueError, TypeError, OverflowError):
                try:
                    bpm = int(float(bpm_value))
                except (ValueError, TypeError, OverflowError):
                    pass

        # Convert duration to float
        duration = None
        duration_value = metadata.get('duration')
        if duration_value is not None and duration_value != 'N/A' and duration_value != '':
            try:
                duration = float(duration_value)
            except (ValueError, TypeError):
                pass

        # Clean up N/A placeholders (case-insensitive, ignoring surrounding
        # whitespace) so "n/a" is treated the same as "N/A".
        def _blank_if_na(value):
            if isinstance(value, str) and value.strip().lower() == 'n/a':
                return ''
            return value

        keyscale = _blank_if_na(keyscale)
        language = _blank_if_na(language)
        timesignature = _blank_if_na(timesignature)

        return UnderstandResult(
            caption=caption,
            lyrics=lyrics,
            bpm=bpm,
            duration=duration,
            keyscale=keyscale,
            language=language,
            timesignature=timesignature,
            status_message=status,
            success=True,
            error=None,
        )

    except Exception as e:
        logger.exception("Music understanding failed")
        return UnderstandResult(
            status_message=f"Error: {str(e)}",
            success=False,
            error=str(e),
        )
|
| 962 |
+
|
| 963 |
+
|
| 964 |
+
@dataclass
class CreateSampleResult:
    """Outcome of turning a natural language query into a music sample.

    Produced by the "Simple Mode" / "Inspiration Mode" flow, in which the user
    describes the music in plain language and the LLM fills in a caption,
    lyrics and metadata.

    Attributes:
        caption: Generated detailed music description/caption.
        lyrics: Generated lyrics (or "[Instrumental]" for instrumental music).
        bpm: Beats per minute (None if not generated).
        duration: Duration in seconds (None if not generated).
        keyscale: Musical key (e.g., "C Major").
        language: Vocal language code (e.g., "en", "zh").
        timesignature: Time signature (e.g., "4").
        instrumental: Whether this is an instrumental piece.
        status_message: Status message from sample creation.
        success: Whether sample creation completed successfully.
        error: Error message if sample creation failed.
    """

    # --- generated content and metadata ---
    caption: str = ""
    lyrics: str = ""
    bpm: Optional[int] = None
    duration: Optional[float] = None
    keyscale: str = ""
    language: str = ""
    timesignature: str = ""
    instrumental: bool = False

    # --- outcome of the call ---
    status_message: str = ""
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dict suitable for JSON serialization."""
        as_mapping = asdict(self)
        return as_mapping
|
| 1006 |
+
|
| 1007 |
+
|
| 1008 |
+
def create_sample(
    llm_handler,
    query: str,
    instrumental: bool = False,
    vocal_language: Optional[str] = None,
    temperature: float = 0.85,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    use_constrained_decoding: bool = True,
    constrained_decoding_debug: bool = False,
) -> CreateSampleResult:
    """Create a music sample from a natural language query using the 5Hz Language Model.

    This is the "Simple Mode" / "Inspiration Mode" feature that takes a user's natural
    language description of music and generates a complete sample including:
    - Detailed caption/description
    - Lyrics (unless instrumental)
    - Metadata (BPM, duration, key, language, time signature)

    Note: cfg_scale and negative_prompt are not supported in create_sample mode.

    Args:
        llm_handler: Initialized LLM handler (LLMHandler instance). ``None`` or a
            handler whose ``llm_initialized`` flag is falsy yields a failed result
            instead of raising.
        query: User's natural language music description (e.g., "a soft Bengali love song")
        instrumental: Whether to generate instrumental music (no vocals). Also used as
            the fallback when the returned metadata has no 'instrumental' entry.
        vocal_language: Allowed vocal language for constrained decoding (e.g., "en", "zh").
            If provided, the model will be constrained to generate lyrics in this language.
            If None or "unknown", no language constraint is applied.
        temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
        top_k: Top-K sampling (None or 0 = disabled)
        top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
        repetition_penalty: Repetition penalty (1.0 = no penalty)
        use_constrained_decoding: Whether to use FSM-based constrained decoding
        constrained_decoding_debug: Whether to enable debug logging

    Returns:
        CreateSampleResult with generated sample fields and status. Failures (handler
        not ready, empty metadata, or any exception raised while calling the
        handler) are reported through ``success=False`` and ``error`` rather than
        by raising.

    Example:
        >>> result = create_sample(llm_handler, "a soft Bengali love song for a quiet evening", vocal_language="bn")
        >>> if result.success:
        ...     print(f"Caption: {result.caption}")
        ...     print(f"Lyrics: {result.lyrics}")
        ...     print(f"BPM: {result.bpm}")
    """
    # Check if LLM is initialized. getattr also covers llm_handler=None, which
    # would otherwise raise AttributeError before any result could be returned.
    if not getattr(llm_handler, "llm_initialized", False):
        return CreateSampleResult(
            status_message="5Hz LM not initialized. Please initialize it first.",
            success=False,
            error="LLM not initialized",
        )

    try:
        # Call LLM to create sample
        metadata, status = llm_handler.create_sample_from_query(
            query=query,
            instrumental=instrumental,
            vocal_language=vocal_language,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            use_constrained_decoding=use_constrained_decoding,
            constrained_decoding_debug=constrained_decoding_debug,
        )

        # Check if LLM returned empty metadata (error case)
        if not metadata:
            return CreateSampleResult(
                status_message=status or "Failed to create sample",
                success=False,
                error=status or "Empty metadata returned",
            )

        # Extract and convert fields
        caption = metadata.get('caption', '')
        lyrics = metadata.get('lyrics', '')
        keyscale = metadata.get('keyscale', '')
        language = metadata.get('language', metadata.get('vocal_language', ''))
        timesignature = metadata.get('timesignature', '')
        is_instrumental = metadata.get('instrumental', instrumental)

        # Convert BPM to int. int("120.0") raises ValueError, so fall back to
        # parsing through float; otherwise a float-formatted BPM string would
        # be silently discarded.
        bpm = None
        bpm_value = metadata.get('bpm')
        if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
            try:
                bpm = int(bpm_value)
            except (ValueError, TypeError, OverflowError):
                try:
                    bpm = int(float(bpm_value))
                except (ValueError, TypeError, OverflowError):
                    pass

        # Convert duration to float
        duration = None
        duration_value = metadata.get('duration')
        if duration_value is not None and duration_value != 'N/A' and duration_value != '':
            try:
                duration = float(duration_value)
            except (ValueError, TypeError):
                pass

        # Clean up N/A placeholders (case-insensitive, ignoring surrounding
        # whitespace) so "n/a" is treated the same as "N/A".
        def _blank_if_na(value):
            if isinstance(value, str) and value.strip().lower() == 'n/a':
                return ''
            return value

        keyscale = _blank_if_na(keyscale)
        language = _blank_if_na(language)
        timesignature = _blank_if_na(timesignature)

        return CreateSampleResult(
            caption=caption,
            lyrics=lyrics,
            bpm=bpm,
            duration=duration,
            keyscale=keyscale,
            language=language,
            timesignature=timesignature,
            instrumental=is_instrumental,
            status_message=status,
            success=True,
            error=None,
        )

    except Exception as e:
        logger.exception("Sample creation failed")
        return CreateSampleResult(
            status_message=f"Error: {str(e)}",
            success=False,
            error=str(e),
        )
|
| 1139 |
+
|
| 1140 |
+
|
| 1141 |
+
@dataclass
class FormatSampleResult:
    """Outcome of the "Format" feature.

    The user supplies a caption and lyrics; the LLM turns them into structured
    music metadata plus an enhanced description. This container carries both
    the generated metadata and the status of the call.

    Attributes:
        caption: Enhanced/formatted music description.
        lyrics: Formatted lyrics (may equal the input lyrics).
        bpm: Beats per minute, or None when not detected.
        duration: Duration in seconds, or None when not detected.
        keyscale: Musical key (e.g., "C Major").
        language: Vocal language code (e.g., "en", "zh").
        timesignature: Time signature (e.g., "4").
        status_message: Status message reported by the formatting step.
        success: Whether formatting completed successfully.
        error: Error message when formatting failed, otherwise None.
    """

    # --- Generated metadata ---
    caption: str = ""
    lyrics: str = ""
    bpm: Optional[int] = None
    duration: Optional[float] = None
    keyscale: str = ""
    language: str = ""
    timesignature: str = ""

    # --- Call status ---
    status_message: str = ""
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dictionary ready for JSON serialization."""
        as_mapping: Dict[str, Any] = asdict(self)
        return as_mapping
|
| 1180 |
+
|
| 1181 |
+
|
| 1182 |
+
def _format_sample_optional_int(value):
    """Return *value* as an int, or None when it is absent, 'N/A', empty or non-numeric."""
    if value is None or value == 'N/A' or value == '':
        return None
    try:
        return int(value)
    except (ValueError, TypeError):
        pass
    try:
        # int() rejects float-formatted strings such as "120.0"; go through float
        # so such a value is kept instead of being silently dropped.
        return int(float(value))
    except (ValueError, TypeError, OverflowError):
        return None


def _format_sample_optional_float(value):
    """Return *value* as a float, or None when it is absent, 'N/A', empty or non-numeric."""
    if value is None or value == 'N/A' or value == '':
        return None
    try:
        return float(value)
    except (ValueError, TypeError):
        return None


def _format_sample_blank_if_na(value):
    """Map the 'N/A' placeholder to an empty string; pass anything else through."""
    return '' if value == 'N/A' else value


def format_sample(
    llm_handler,
    caption: str,
    lyrics: str,
    user_metadata: Optional[Dict[str, Any]] = None,
    temperature: float = 0.85,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    use_constrained_decoding: bool = True,
    constrained_decoding_debug: bool = False,
) -> FormatSampleResult:
    """Format user-provided caption and lyrics using the 5Hz Language Model.

    This function takes user input (caption and lyrics) and generates structured
    music metadata including an enhanced caption, BPM, duration, key, language,
    and time signature.

    If user_metadata is provided, those values will be used to constrain the
    decoding, ensuring the output matches user-specified values.

    Note: cfg_scale and negative_prompt are not supported in format mode.

    Args:
        llm_handler: Initialized LLM handler (LLMHandler instance)
        caption: User's caption/description (e.g., "Latin pop, reggaeton")
        lyrics: User's lyrics with structure tags
        user_metadata: Optional dict with user-provided metadata to constrain decoding.
            Supported keys: bpm, duration, keyscale, timesignature, language
        temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
        top_k: Top-K sampling (None or 0 = disabled)
        top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
        repetition_penalty: Repetition penalty (1.0 = no penalty)
        use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
        constrained_decoding_debug: Whether to enable debug logging for constrained decoding

    Returns:
        FormatSampleResult with formatted metadata fields and status. Failures
        (uninitialized LLM, empty metadata, any exception raised while
        formatting) are reported through ``success=False`` and ``error``
        rather than raised.

    Example:
        >>> result = format_sample(llm_handler, "Latin pop, reggaeton", "[Verse 1]\\nHola mundo...")
        >>> if result.success:
        ...     print(f"Caption: {result.caption}")
        ...     print(f"BPM: {result.bpm}")
        ...     print(f"Lyrics: {result.lyrics}")
    """
    # Check if LLM is initialized
    if not llm_handler.llm_initialized:
        return FormatSampleResult(
            status_message="5Hz LM not initialized. Please initialize it first.",
            success=False,
            error="LLM not initialized",
        )

    try:
        # Call LLM formatting
        metadata, status = llm_handler.format_sample_from_input(
            caption=caption,
            lyrics=lyrics,
            user_metadata=user_metadata,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            use_constrained_decoding=use_constrained_decoding,
            constrained_decoding_debug=constrained_decoding_debug,
        )

        # Empty metadata is how the handler signals failure
        if not metadata:
            return FormatSampleResult(
                status_message=status or "Failed to format input",
                success=False,
                error=status or "Empty metadata returned",
            )

        return FormatSampleResult(
            caption=metadata.get('caption', ''),
            # Fall back to the input lyrics when the LLM returned none
            lyrics=metadata.get('lyrics', lyrics),
            bpm=_format_sample_optional_int(metadata.get('bpm')),
            duration=_format_sample_optional_float(metadata.get('duration')),
            keyscale=_format_sample_blank_if_na(metadata.get('keyscale', '')),
            # 'vocal_language' is accepted as an alternative key for the language
            language=_format_sample_blank_if_na(
                metadata.get('language', metadata.get('vocal_language', ''))
            ),
            timesignature=_format_sample_blank_if_na(metadata.get('timesignature', '')),
            status_message=status,
            success=True,
            error=None,
        )

    except Exception as e:
        logger.exception("Format sample failed")
        return FormatSampleResult(
            status_message=f"Error: {str(e)}",
            success=False,
            error=str(e),
        )
|
acestep/llm_inference.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
acestep/local_cache.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Local cache module to replace Redis
|
| 2 |
+
|
| 3 |
+
Uses diskcache as backend, provides Redis-compatible API.
|
| 4 |
+
Supports persistent storage and TTL expiration.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
from typing import Any, Optional
|
| 10 |
+
from threading import Lock
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
from diskcache import Cache
|
| 14 |
+
HAS_DISKCACHE = True
|
| 15 |
+
except ImportError:
|
| 16 |
+
HAS_DISKCACHE = False
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class LocalCache:
    """
    Local cache implementation with a Redis-like API.

    Uses diskcache as backend, supports persistence and TTL. The class is a
    process-wide singleton: every construction returns the same instance, and
    only the first construction's ``cache_dir`` is used.
    """

    _instance = None
    _lock = Lock()

    def __new__(cls, cache_dir: Optional[str] = None):
        """Return the single shared instance, creating it under a lock if needed."""
        if cls._instance is None:
            with cls._lock:
                # Re-check inside the lock: another thread may have created it.
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self, cache_dir: Optional[str] = None):
        """Open the on-disk cache the first time the singleton is constructed.

        Args:
            cache_dir: Directory for the cache files. Defaults to
                ``<parent of this package>/.cache/local_redis``.

        Raises:
            ImportError: If diskcache is not installed.
        """
        # __init__ runs on every LocalCache(...) call; only the first one does work.
        if getattr(self, '_initialized', False):
            return

        if not HAS_DISKCACHE:
            raise ImportError(
                "diskcache not installed. Run: pip install diskcache"
            )

        if cache_dir is None:
            cache_dir = os.path.join(
                os.path.dirname(os.path.dirname(__file__)),
                ".cache",
                "local_redis"
            )

        os.makedirs(cache_dir, exist_ok=True)
        self._cache = Cache(cache_dir)
        self._initialized = True

    def set(self, name: str, value: Any, ex: Optional[int] = None) -> bool:
        """
        Set key-value pair

        Args:
            name: Key name
            value: Value (dict/list values are stored as a JSON string)
            ex: Expiration time (seconds)

        Returns:
            bool: Always True
        """
        if isinstance(value, (dict, list)):
            value = json.dumps(value, ensure_ascii=False)
        self._cache.set(name, value, expire=ex)
        return True

    def get(self, name: str) -> Optional[str]:
        """Get value, or None when the key is absent.

        dict/list values come back as the JSON string written by ``set``.
        """
        return self._cache.get(name)

    def delete(self, name: str) -> int:
        """Delete key, returns number of deleted items"""
        return 1 if self._cache.delete(name) else 0

    def exists(self, name: str) -> bool:
        """Check if key exists"""
        return name in self._cache

    def keys(self, pattern: str = "*") -> list:
        """
        Get list of keys matching a glob pattern.

        ``"*"`` returns every key. Any other pattern is matched
        case-sensitively with shell-style wildcards (``*``, ``?``, ``[...]``),
        so a pattern without wildcards matches only that exact key. Non-string
        keys are only returned for ``"*"``.
        """
        from fnmatch import fnmatchcase

        if pattern == "*":
            return list(self._cache.iterkeys())

        return [
            key for key in self._cache.iterkeys()
            if isinstance(key, str) and fnmatchcase(key, pattern)
        ]

    def expire(self, name: str, seconds: int) -> bool:
        """Set key expiration time; returns False when the key does not exist."""
        # touch() updates only the expiry, so the value is not read and
        # rewritten in two separate steps.
        return bool(self._cache.touch(name, expire=seconds))

    def ttl(self, name: str) -> int:
        """
        Get remaining time to live (seconds)

        Returns:
            -2 if the key does not exist, -1 if it exists without an
            expiration, otherwise the remaining whole seconds (never negative).
        """
        import time

        missing = object()
        value, expire_time = self._cache.get(name, default=missing, expire_time=True)
        if value is missing:
            return -2  # Key does not exist
        if expire_time is None:
            return -1  # Exists, no expiration set
        return max(0, int(expire_time - time.time()))

    def close(self):
        """Close cache connection"""
        # _cache is absent when __init__ raised before opening the backend.
        if hasattr(self, '_cache'):
            self._cache.close()
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# Module-level cache handle; stays None until get_local_cache() is first called.
_local_cache: Optional[LocalCache] = None


def get_local_cache(cache_dir: Optional[str] = None) -> LocalCache:
    """Return the module-level LocalCache, creating it on the first call.

    Args:
        cache_dir: Passed to LocalCache when the instance is first created;
            not consulted once the module-level instance already exists.
    """
    global _local_cache
    if _local_cache is not None:
        return _local_cache
    _local_cache = LocalCache(cache_dir)
    return _local_cache
|
acestep/model_downloader.py
ADDED
|
@@ -0,0 +1,634 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ACE-Step Model Downloader
|
| 3 |
+
|
| 4 |
+
This module provides functionality to download models from HuggingFace Hub or ModelScope.
|
| 5 |
+
It supports automatic downloading when models are not found locally,
|
| 6 |
+
with intelligent fallback between download sources.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
import argparse
|
| 12 |
+
from typing import Optional, List, Dict, Tuple
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
from loguru import logger
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# =============================================================================
|
| 19 |
+
# Network Detection & Smart Download
|
| 20 |
+
# =============================================================================
|
| 21 |
+
|
| 22 |
+
def _can_access_google(timeout: float = 3.0) -> bool:
    """
    Check if Google is accessible (to determine HuggingFace vs ModelScope).

    Opens a TCP connection to www.google.com:443 and closes it immediately;
    no data is sent.

    Args:
        timeout: Connection timeout in seconds

    Returns:
        True if the connection succeeds, False on any socket/OS error
        (including timeouts and name-resolution failures)
    """
    import socket

    try:
        # create_connection tries every address the name resolves to and the
        # context manager guarantees the socket is closed.
        with socket.create_connection(("www.google.com", 443), timeout=timeout):
            return True
    except OSError:
        # socket.timeout and socket.error are both OSError (sub)classes.
        return False
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _download_from_huggingface_internal(
    repo_id: str,
    local_dir: Path,
    token: Optional[str] = None,
) -> None:
    """
    Fetch a repository snapshot from HuggingFace Hub into ``local_dir``.

    Args:
        repo_id: HuggingFace repository ID (e.g., "ACE-Step/Ace-Step1.5")
        local_dir: Destination directory for the downloaded files
        token: HuggingFace token for private repos (optional)

    Raises:
        Exception: Whatever the underlying download raises on failure
    """
    # Imported lazily so the module loads even when huggingface_hub is absent.
    from huggingface_hub import snapshot_download as hf_snapshot_download

    logger.info(f"[Model Download] Downloading from HuggingFace: {repo_id} -> {local_dir}")

    destination = str(local_dir)
    hf_snapshot_download(
        repo_id=repo_id,
        local_dir=destination,
        local_dir_use_symlinks=False,
        token=token,
    )
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _download_from_modelscope_internal(
    repo_id: str,
    local_dir: Path,
) -> None:
    """
    Fetch a repository snapshot from ModelScope into ``local_dir``.

    Args:
        repo_id: ModelScope repository ID (e.g., "ACE-Step/Ace-Step1.5")
        local_dir: Destination directory for the downloaded files

    Raises:
        Exception: Whatever the underlying download raises on failure
    """
    # Imported lazily so the module loads even when modelscope is absent.
    from modelscope import snapshot_download as ms_snapshot_download

    logger.info(f"[Model Download] Downloading from ModelScope: {repo_id} -> {local_dir}")

    destination = str(local_dir)
    ms_snapshot_download(
        model_id=repo_id,
        local_dir=destination,
    )
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _smart_download(
    repo_id: str,
    local_dir: Path,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Download a repository, trying one hub first and the other as fallback.

    The first hub is the caller's preference when given; otherwise it is
    HuggingFace if Google is reachable and ModelScope if not. When the first
    hub raises, the other hub is tried once.

    Args:
        repo_id: Repository ID (same format for both HF and ModelScope)
        local_dir: Local directory to save the model (created if missing)
        token: HuggingFace token for private repos (optional)
        prefer_source: "huggingface", "modelscope", or any other value
            (including None) for auto-detect

    Returns:
        Tuple of (success, message)
    """
    # The destination is created up front, before any download is attempted.
    local_dir.mkdir(parents=True, exist_ok=True)

    if prefer_source == "huggingface":
        hf_first = True
        logger.info("[Model Download] User preference: HuggingFace Hub")
    elif prefer_source == "modelscope":
        hf_first = False
        logger.info("[Model Download] User preference: ModelScope")
    else:
        hf_first = _can_access_google()
        logger.info(f"[Model Download] Auto-detected: {'HuggingFace Hub' if hf_first else 'ModelScope'}")

    def fetch_hf() -> None:
        _download_from_huggingface_internal(repo_id, local_dir, token)

    def fetch_ms() -> None:
        _download_from_modelscope_internal(repo_id, local_dir)

    # (hub label, short label, error tag, fetcher)
    hf_source = ("HuggingFace Hub", "HuggingFace", "HF", fetch_hf)
    ms_source = ("ModelScope", "ModelScope", "MS", fetch_ms)
    primary, fallback = (hf_source, ms_source) if hf_first else (ms_source, hf_source)

    logger.info(f"[Model Download] Using {primary[0]}...")
    try:
        primary[3]()
    except Exception as first_error:
        logger.warning(f"[Model Download] {primary[1]} download failed: {first_error}")
        logger.info(f"[Model Download] Falling back to {fallback[0]}...")
        try:
            fallback[3]()
        except Exception as second_error:
            error_msg = (
                f"Both {primary[1]} and {fallback[1]} downloads failed. "
                f"{primary[2]}: {first_error}, {fallback[2]}: {second_error}"
            )
            logger.error(error_msg)
            return False, error_msg
        return True, f"Successfully downloaded from {fallback[1]}: {repo_id}"
    return True, f"Successfully downloaded from {primary[1]}: {repo_id}"
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# =============================================================================
# Model Registry
# =============================================================================
# Repository holding the core components (vae, text_encoder, default DiT)
MAIN_MODEL_REPO = "ACE-Step/Ace-Step1.5"

# Sub-models that can be downloaded separately into the checkpoints directory.
# Each local directory name maps to the repo "ACE-Step/<name>".
SUBMODEL_REGISTRY: Dict[str, str] = {
    submodel: f"ACE-Step/{submodel}"
    for submodel in (
        # LM models
        "acestep-5Hz-lm-0.6B",
        "acestep-5Hz-lm-4B",
        # DiT models
        "acestep-v15-turbo-shift3",
        "acestep-v15-sft",
        "acestep-v15-base",
        "acestep-v15-turbo-shift1",
        "acestep-v15-turbo-continuous",
    )
}

# Directories delivered by the main model repo (ACE-Step/Ace-Step1.5)
MAIN_MODEL_COMPONENTS = [
    "acestep-v15-turbo",       # Default DiT model
    "vae",                     # VAE for audio encoding/decoding
    "Qwen3-Embedding-0.6B",    # Text encoder
    "acestep-5Hz-lm-1.7B",     # Default LM model (1.7B)
]

# Default LM model (shipped inside the main model repo)
DEFAULT_LM_MODEL = "acestep-5Hz-lm-1.7B"
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def get_project_root() -> Path:
    """Return the directory two levels above this file's resolved path."""
    # parents[0] is the directory holding this file, parents[1] is its parent.
    return Path(__file__).resolve().parents[1]
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def get_checkpoints_dir(custom_dir: Optional[str] = None) -> Path:
    """Return ``custom_dir`` as a Path, or ``<project root>/checkpoints`` when it is empty/None."""
    if not custom_dir:
        return get_project_root() / "checkpoints"
    return Path(custom_dir)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def check_main_model_exists(checkpoints_dir: Optional[Path] = None) -> bool:
    """
    Report whether every main-model component is present on disk.

    Args:
        checkpoints_dir: Directory to inspect; the default checkpoints
            directory is used when None.

    Returns:
        True if a path exists for each entry of MAIN_MODEL_COMPONENTS,
        False as soon as one is missing.
    """
    base_dir = get_checkpoints_dir() if checkpoints_dir is None else checkpoints_dir
    return all((base_dir / component).exists() for component in MAIN_MODEL_COMPONENTS)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def check_model_exists(model_name: str, checkpoints_dir: Optional[Path] = None) -> bool:
    """
    Check if a specific model exists in the checkpoints directory.

    Args:
        model_name: Name of the model to check
        checkpoints_dir: Custom checkpoints directory (optional)

    Returns:
        True if ``checkpoints_dir / model_name`` exists and, when it is a
        directory, contains at least one entry. False otherwise.
    """
    if checkpoints_dir is None:
        checkpoints_dir = get_checkpoints_dir()

    model_path = checkpoints_dir / model_name
    if not model_path.exists():
        return False
    if model_path.is_dir():
        # _smart_download creates the target directory before fetching, so a
        # failed download can leave an empty directory that is not a model.
        return any(model_path.iterdir())
    return True
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def list_available_models() -> Dict[str, str]:
    """
    Enumerate every model that can be downloaded.

    Returns:
        Mapping of local name to repository ID: the entry "main" for the main
        model repo, followed by all entries of SUBMODEL_REGISTRY.
    """
    catalog: Dict[str, str] = {"main": MAIN_MODEL_REPO}
    catalog.update(SUBMODEL_REGISTRY)
    return catalog
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def download_main_model(
    checkpoints_dir: Optional[Path] = None,
    force: bool = False,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Download the main ACE-Step model from HuggingFace or ModelScope.

    The main model includes:
    - acestep-v15-turbo (default DiT model)
    - vae (audio encoder/decoder)
    - Qwen3-Embedding-0.6B (text encoder)
    - acestep-5Hz-lm-1.7B (default LM model)

    Args:
        checkpoints_dir: Custom checkpoints directory (optional)
        force: Download even when the main model is already present
        token: HuggingFace token for private repos (optional)
        prefer_source: Preferred download source ("huggingface", "modelscope", or None for auto-detect)

    Returns:
        Tuple of (success, message)
    """
    target_dir = get_checkpoints_dir() if checkpoints_dir is None else checkpoints_dir

    # The directory is created even when nothing ends up being downloaded.
    target_dir.mkdir(parents=True, exist_ok=True)

    already_present = (not force) and check_main_model_exists(target_dir)
    if already_present:
        return True, f"Main model already exists at {target_dir}"

    for line in (
        f"Downloading main model from {MAIN_MODEL_REPO}...",
        f"Destination: {target_dir}",
        "This may take a while depending on your internet connection...",
    ):
        print(line)

    # Delegate to the smart downloader, which handles source selection and fallback.
    return _smart_download(MAIN_MODEL_REPO, target_dir, token, prefer_source)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def download_submodel(
    model_name: str,
    checkpoints_dir: Optional[Path] = None,
    force: bool = False,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Download a specific sub-model from HuggingFace or ModelScope.

    Args:
        model_name: Name of the model to download (must be in SUBMODEL_REGISTRY)
        checkpoints_dir: Custom checkpoints directory (optional)
        force: Force re-download even if model exists
        token: HuggingFace token for private repos (optional)
        prefer_source: Preferred download source ("huggingface", "modelscope", or None for auto-detect)

    Returns:
        Tuple of (success, message). An unknown ``model_name`` yields
        ``(False, ...)`` listing the available names.
    """
    if model_name not in SUBMODEL_REGISTRY:
        available = ", ".join(SUBMODEL_REGISTRY.keys())
        return False, f"Unknown model '{model_name}'. Available models: {available}"

    if checkpoints_dir is None:
        checkpoints_dir = get_checkpoints_dir()

    # Ensure checkpoints directory exists
    checkpoints_dir.mkdir(parents=True, exist_ok=True)

    model_path = checkpoints_dir / model_name

    # _smart_download creates the target directory before fetching, so an
    # earlier failed attempt can leave an empty directory behind. Treat that
    # as "not downloaded" so a retry is not blocked.
    already_present = model_path.exists() and (
        not model_path.is_dir() or any(model_path.iterdir())
    )
    if not force and already_present:
        return True, f"Model '{model_name}' already exists at {model_path}"

    repo_id = SUBMODEL_REGISTRY[model_name]

    print(f"Downloading {model_name} from {repo_id}...")
    print(f"Destination: {model_path}")

    # Use smart download with automatic fallback
    return _smart_download(repo_id, model_path, token, prefer_source)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def download_all_models(
    checkpoints_dir: Optional[Path] = None,
    force: bool = False,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, List[str]]:
    """
    Download all available models.

    The main model is downloaded first, then every entry of
    SUBMODEL_REGISTRY. A failure does not stop the remaining downloads.

    Args:
        checkpoints_dir: Custom checkpoints directory (optional)
        force: Force re-download even if models exist
        token: HuggingFace token for private repos (optional)
        prefer_source: Preferred download source ("huggingface", "modelscope", or None for auto-detect)

    Returns:
        Tuple of (all_success, list of messages), one message per download
        in the order attempted
    """
    if checkpoints_dir is None:
        checkpoints_dir = get_checkpoints_dir()

    messages: List[str] = []
    all_success = True

    # Download main model first
    success, msg = download_main_model(
        checkpoints_dir, force, token, prefer_source=prefer_source
    )
    messages.append(msg)
    all_success = all_success and success

    # Download all sub-models
    for model_name in SUBMODEL_REGISTRY:
        success, msg = download_submodel(
            model_name, checkpoints_dir, force, token, prefer_source=prefer_source
        )
        messages.append(msg)
        all_success = all_success and success

    return all_success, messages
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def ensure_main_model(
    checkpoints_dir: Optional[Path] = None,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Make sure the main model is present, fetching it only when missing.

    Intended for use during initialization: when the main model is already
    on disk nothing is downloaded.

    Args:
        checkpoints_dir: Custom checkpoints directory (optional)
        token: HuggingFace token for private repos (optional)
        prefer_source: Preferred download source ("huggingface", "modelscope", or None for auto-detect)

    Returns:
        Tuple of (success, message)
    """
    target_dir = get_checkpoints_dir() if checkpoints_dir is None else checkpoints_dir

    if not check_main_model_exists(target_dir):
        banner = "=" * 60
        print("\n" + banner)
        print("Main model not found. Starting automatic download...")
        print(banner + "\n")
        return download_main_model(target_dir, token=token, prefer_source=prefer_source)

    return True, "Main model is available"
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
def ensure_lm_model(
    model_name: Optional[str] = None,
    checkpoints_dir: Optional[Path] = None,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Ensure an LM model is available, downloading if necessary.

    A name that is not a registry key is treated as a possible variant: the
    first registry entry whose name contains "lm" and contains the requested
    name (case-insensitive) is used instead.

    Args:
        model_name: Name of the LM model (None or empty uses DEFAULT_LM_MODEL)
        checkpoints_dir: Custom checkpoints directory (optional)
        token: HuggingFace token for private repos (optional)
        prefer_source: Preferred download source ("huggingface", "modelscope", or None for auto-detect)

    Returns:
        Tuple of (success, message)
    """
    # An empty name would substring-match every registry entry below, so it
    # is treated the same as "not specified".
    if not model_name:
        model_name = DEFAULT_LM_MODEL

    if checkpoints_dir is None:
        checkpoints_dir = get_checkpoints_dir()

    if check_model_exists(model_name, checkpoints_dir):
        return True, f"LM model '{model_name}' is available"

    # Check if this is a known LM model
    if model_name not in SUBMODEL_REGISTRY:
        # Check if it might be a variant name
        requested = model_name.lower()
        resolved = next(
            (
                known
                for known in SUBMODEL_REGISTRY
                if "lm" in known.lower() and requested in known.lower()
            ),
            None,
        )
        if resolved is None:
            return False, f"Unknown LM model: {model_name}"
        model_name = resolved

        # The existence check above used the unresolved name; repeat it for
        # the registry name so an already-present model is not announced as
        # missing.
        if check_model_exists(model_name, checkpoints_dir):
            return True, f"LM model '{model_name}' is available"

    print("\n" + "=" * 60)
    print(f"LM model '{model_name}' not found. Starting automatic download...")
    print("=" * 60 + "\n")

    return download_submodel(model_name, checkpoints_dir, token=token, prefer_source=prefer_source)
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def ensure_dit_model(
    model_name: str,
    checkpoints_dir: Optional[Path] = None,
    token: Optional[str] = None,
    prefer_source: Optional[str] = None,
) -> Tuple[bool, str]:
    """
    Make sure a DiT model is present, fetching it only when missing.

    Args:
        model_name: Name of the DiT model
        checkpoints_dir: Custom checkpoints directory (optional)
        token: HuggingFace token for private repos (optional)
        prefer_source: Preferred download source ("huggingface", "modelscope", or None for auto-detect)

    Returns:
        Tuple of (success, message)
    """
    target_dir = get_checkpoints_dir() if checkpoints_dir is None else checkpoints_dir

    if check_model_exists(model_name, target_dir):
        return True, f"DiT model '{model_name}' is available"

    # The default turbo model ships inside the main model download.
    if model_name == "acestep-v15-turbo":
        return ensure_main_model(target_dir, token, prefer_source)

    # Anything else must be a registered sub-model.
    if model_name not in SUBMODEL_REGISTRY:
        return False, f"Unknown DiT model: {model_name}"

    banner = "=" * 60
    print("\n" + banner)
    print(f"DiT model '{model_name}' not found. Starting automatic download...")
    print(banner + "\n")
    return download_submodel(model_name, target_dir, token=token, prefer_source=prefer_source)
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def print_model_list():
    """Write the downloadable models to stdout, grouped as main / LM / DiT."""
    banner = "=" * 60

    # Registry entries are split on whether the name contains "lm".
    lm_entries = []
    dit_entries = []
    for name, repo in SUBMODEL_REGISTRY.items():
        bucket = lm_entries if "lm" in name.lower() else dit_entries
        bucket.append((name, repo))

    print("\nAvailable Models for Download:")
    print(banner)
    print("\nSupported Sources: HuggingFace Hub <-> ModelScope (auto-fallback)")

    print("\n[Main Model]")
    print(f" main -> {MAIN_MODEL_REPO}")
    print(" Contains: vae, Qwen3-Embedding-0.6B, acestep-v15-turbo, acestep-5Hz-lm-1.7B")

    print("\n[Optional LM Models]")
    for name, repo in lm_entries:
        print(f" {name} -> {repo}")

    print("\n[Optional DiT Models]")
    for name, repo in dit_entries:
        print(f" {name} -> {repo}")

    print("\n" + banner)
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
def _build_download_parser():
    """Create the argparse parser used by the model-download CLI."""
    parser = argparse.ArgumentParser(
        description="Download ACE-Step models with automatic fallback (HuggingFace <-> ModelScope)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
acestep-download # Download main model (includes LM 1.7B)
acestep-download --all # Download all available models
acestep-download --model acestep-v15-sft # Download a specific model
acestep-download --list # List all available models

Network Detection:
Automatically detects network environment and chooses the best download source:
- Google accessible -> HuggingFace (fallback to ModelScope)
- Google blocked -> ModelScope (fallback to HuggingFace)

Alternative using huggingface-cli:
huggingface-cli download ACE-Step/Ace-Step1.5 --local-dir ./checkpoints
huggingface-cli download ACE-Step/acestep-5Hz-lm-0.6B --local-dir ./checkpoints/acestep-5Hz-lm-0.6B
"""
    )

    parser.add_argument(
        "--model", "-m",
        type=str,
        help="Specific model to download (use --list to see available models)"
    )
    parser.add_argument(
        "--all", "-a",
        action="store_true",
        help="Download all available models"
    )
    parser.add_argument(
        "--list", "-l",
        action="store_true",
        help="List all available models"
    )
    parser.add_argument(
        "--dir", "-d",
        type=str,
        default=None,
        help="Custom checkpoints directory (default: ./checkpoints)"
    )
    parser.add_argument(
        "--force", "-f",
        action="store_true",
        help="Force re-download even if model exists"
    )
    parser.add_argument(
        "--token", "-t",
        type=str,
        default=None,
        help="HuggingFace token for private repos"
    )
    parser.add_argument(
        "--skip-main",
        action="store_true",
        help="Skip downloading the main model (only download specified sub-model)"
    )
    return parser


def main():
    """
    Command-line entry point for downloading models.

    Option precedence is --list, then --all, then --model; with none of
    them the main model is downloaded.

    Returns:
        Process exit code: 0 on success, 1 on any failure or unknown model.
    """
    args = _build_download_parser().parse_args()

    # --list only prints and exits.
    if args.list:
        print_model_list()
        return 0

    checkpoints_dir = get_checkpoints_dir(args.dir) if args.dir else get_checkpoints_dir()
    print(f"Checkpoints directory: {checkpoints_dir}")

    # --all: fetch everything and report each outcome.
    if args.all:
        ok, reports = download_all_models(checkpoints_dir, args.force, args.token)
        for report in reports:
            print(report)
        return 0 if ok else 1

    # --model: fetch one named model.
    if args.model:
        requested = args.model
        if requested != "main" and requested not in SUBMODEL_REGISTRY:
            print(f"Unknown model: {args.model}")
            print("Use --list to see available models")
            return 1

        if requested == "main":
            ok, report = download_main_model(checkpoints_dir, args.force, args.token)
        else:
            # A sub-model needs the main model unless --skip-main was given.
            if not args.skip_main and not check_main_model_exists(checkpoints_dir):
                print("Main model not found. Downloading main model first...")
                main_ok, main_report = download_main_model(checkpoints_dir, args.force, args.token)
                print(main_report)
                if not main_ok:
                    return 1
            ok, report = download_submodel(requested, checkpoints_dir, args.force, args.token)

        print(report)
        return 0 if ok else 1

    # No selector given: download the main model (includes default LM 1.7B).
    print("Downloading main model (includes vae, text encoder, DiT, and LM 1.7B)...")
    ok, report = download_main_model(checkpoints_dir, args.force, args.token)
    print(report)

    if not ok:
        return 1

    print("\nDownload complete!")
    print(f"Models are available at: {checkpoints_dir}")
    return 0
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
# Script entry: the exit status is whatever main() returns.
if __name__ == "__main__":
    raise SystemExit(main())
|
acestep/openrouter_adapter.py
ADDED
|
@@ -0,0 +1,773 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenRouter API adapter for ACE-Step music generation.
|
| 2 |
+
|
| 3 |
+
This module provides OpenRouter-compatible endpoints that wrap the ACE-Step
|
| 4 |
+
music generation API, mounted as a sub-router on the main api_server.
|
| 5 |
+
|
| 6 |
+
All generation requests go through the shared asyncio.Queue, ensuring unified
|
| 7 |
+
GPU scheduling with release_task.
|
| 8 |
+
|
| 9 |
+
Endpoints:
|
| 10 |
+
- POST /v1/chat/completions - Generate music via chat completion format
|
| 11 |
+
- GET /v1/models - List available models (OpenRouter format)
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import asyncio
|
| 17 |
+
import base64
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import re
|
| 21 |
+
import tempfile
|
| 22 |
+
import time
|
| 23 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 24 |
+
from uuid import uuid4
|
| 25 |
+
|
| 26 |
+
from fastapi import APIRouter, HTTPException, Request
|
| 27 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
| 28 |
+
|
| 29 |
+
from acestep.openrouter_models import (
|
| 30 |
+
AudioConfig,
|
| 31 |
+
ChatCompletionRequest,
|
| 32 |
+
ModelInfo,
|
| 33 |
+
ModelPricing,
|
| 34 |
+
ModelsResponse,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# =============================================================================
|
| 39 |
+
# Constants
|
| 40 |
+
# =============================================================================
|
| 41 |
+
|
| 42 |
+
# Prefix placed before the internal model name in model IDs.
MODEL_PREFIX = "acestep"
# Audio format used when the request does not specify one.
DEFAULT_AUDIO_FORMAT = "mp3"

# Generation timeout for non-streaming requests (seconds); read once at
# import time from ACESTEP_GENERATION_TIMEOUT.
GENERATION_TIMEOUT = int(os.getenv("ACESTEP_GENERATION_TIMEOUT", "600"))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# =============================================================================
|
| 50 |
+
# Helper Functions
|
| 51 |
+
# =============================================================================
|
| 52 |
+
|
| 53 |
+
def _generate_completion_id() -> str:
    """Return a fresh "chatcmpl-" ID followed by 24 hex characters of a UUID4."""
    suffix = uuid4().hex[:24]
    return f"chatcmpl-{suffix}"
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _get_model_id(model_name: str) -> str:
    """Build the public model ID: MODEL_PREFIX, a slash, then the internal name."""
    model_id = f"{MODEL_PREFIX}/{model_name}"
    return model_id
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _parse_model_name(model_id: str) -> str:
    """Return the part of *model_id* after its first "/", or the ID itself if it has none."""
    _, separator, internal_name = model_id.partition("/")
    return internal_name if separator else model_id
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _audio_to_base64_url(audio_path: str, audio_format: str = "mp3") -> str:
    """Convert an audio file to a base64 ``data:`` URL.

    Args:
        audio_path: Path of the audio file to encode.
        audio_format: Format name used to pick the MIME type. Case, a
            leading dot and surrounding whitespace are ignored; None or an
            unrecognised value falls back to "audio/mpeg".

    Returns:
        "data:<mime>;base64,<payload>", or "" when *audio_path* is empty or
        is not an existing regular file.
    """
    # isfile (not exists) so that a directory path yields "" instead of
    # raising from open().
    if not audio_path or not os.path.isfile(audio_path):
        return ""

    mime_types = {
        "mp3": "audio/mpeg",
        "wav": "audio/wav",
        "flac": "audio/flac",
        "ogg": "audio/ogg",
        "m4a": "audio/mp4",
        "aac": "audio/aac",
    }
    # Normalise so ".WAV" and None are handled like "wav" and the default.
    format_key = (audio_format or "").strip().lstrip(".").lower()
    mime_type = mime_types.get(format_key, "audio/mpeg")

    with open(audio_path, "rb") as f:
        audio_data = f.read()

    b64_data = base64.b64encode(audio_data).decode("utf-8")
    return f"data:{mime_type};base64,{b64_data}"
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def _format_lm_content(result: Dict[str, Any]) -> str:
    """Format a generation result as a text block with metadata and lyrics.

    Args:
        result: Generation result. Reads the optional keys "metas" (mapping),
            "lyrics" and "prompt"; a missing or None "metas"/"lyrics" is
            treated as empty.

    Returns:
        A "## Metadata" section and/or a "## Lyrics" section separated by a
        blank line, or "Music generated successfully." when neither section
        has content.
    """
    # `or` (rather than a .get default) also covers keys present with None.
    metas = result.get("metas") or {}
    lyrics = result.get("lyrics") or ""

    parts = []

    # Add metadata section
    meta_lines = []
    caption = metas.get("prompt") or metas.get("caption") or result.get("prompt", "")
    if caption:
        meta_lines.append(f"**Caption:** {caption}")

    # (metas key, label, unit suffix); empty and "N/A" values are omitted.
    labelled_fields = (
        ("bpm", "BPM", ""),
        ("duration", "Duration", "s"),
        ("keyscale", "Key", ""),
        ("timesignature", "Time Signature", ""),
    )
    for key, label, unit in labelled_fields:
        value = metas.get(key)
        if value and value != "N/A":
            meta_lines.append(f"**{label}:** {value}{unit}")

    if meta_lines:
        parts.append("## Metadata\n" + "\n".join(meta_lines))

    # Add lyrics section, skipping the instrumental placeholders
    if lyrics and lyrics.strip() and lyrics.strip().lower() not in ("[inst]", "[instrumental]"):
        parts.append(f"## Lyrics\n{lyrics}")

    if parts:
        return "\n\n".join(parts)
    return "Music generated successfully."
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def _base64_to_temp_file(b64_data: str, audio_format: str = "mp3") -> str:
    """Decode base64 audio data into a new temporary file.

    Args:
        b64_data: Base64 payload, optionally a data URL; everything up to
            and including the first comma is discarded.
        audio_format: Used as the file extension. Only ASCII letters and
            digits are kept; None or a value with none left becomes "mp3".

    Returns:
        Path of the created file. The caller is responsible for deleting it.

    Raises:
        ValueError: If the payload is not valid base64 (binascii.Error).
        OSError: If the temporary file cannot be created or written.
    """
    if "," in b64_data:
        b64_data = b64_data.split(",", 1)[1]

    audio_bytes = base64.b64decode(b64_data)

    # audio_format comes from the request body. mkstemp joins the suffix
    # onto the path verbatim, so a separator in it could place the file
    # outside the temp directory; keep alphanumerics only.
    extension = re.sub(r"[^A-Za-z0-9]", "", audio_format or "") or "mp3"
    fd, path = tempfile.mkstemp(suffix=f".{extension}", prefix="openrouter_audio_")

    try:
        # Write through the descriptor mkstemp returned instead of closing
        # it and reopening the path.
        with os.fdopen(fd, "wb") as f:
            f.write(audio_bytes)
    except Exception:
        # Do not leave a partial file behind when the write fails.
        try:
            os.remove(path)
        except OSError:
            pass
        raise

    return path
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _extract_tagged_content(text: str) -> Tuple[Optional[str], Optional[str], str]:
    """
    Pull the first <prompt>...</prompt> and <lyrics>...</lyrics> spans out of *text*.

    Tag matching is case-insensitive and spans may cover several lines.

    Returns:
        (prompt, lyrics, remaining_text); prompt / lyrics are None when the
        corresponding tag is absent, otherwise the stripped tag content.
        remaining_text is *text* with the matched spans removed.
    """
    flags = re.DOTALL | re.IGNORECASE
    patterns = (r'<prompt>(.*?)</prompt>', r'<lyrics>(.*?)</lyrics>')

    captured: List[Optional[str]] = []
    leftover = text
    for pattern in patterns:
        # Every search runs on the untouched input, not on the leftover.
        found = re.search(pattern, text, flags)
        if found is None:
            captured.append(None)
            continue
        captured.append(found.group(1).strip())
        leftover = leftover.replace(found.group(0), '').strip()

    prompt, lyrics = captured
    return prompt, lyrics, leftover
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def _looks_like_lyrics(text: str) -> bool:
    """Guess whether *text* is song lyrics.

    True when the text contains a section marker such as "[verse" or
    "[chorus" (case-insensitive), or when it has at least four non-blank
    lines averaging under 60 characters each.
    """
    if not text:
        return False

    section_markers = (
        "[verse", "[chorus", "[bridge", "[intro", "[outro",
        "[hook", "[pre-chorus", "[refrain", "[inst",
    )
    lowered = text.lower()
    if any(marker in lowered for marker in section_markers):
        return True

    # Fallback: several short lines read like verse rather than prose.
    stripped_lines = (raw.strip() for raw in text.split("\n"))
    nonblank = [line for line in stripped_lines if line]
    if len(nonblank) < 4:
        return False

    mean_length = sum(map(len, nonblank)) / len(nonblank)
    return mean_length < 60
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def _is_instrumental(lyrics: str) -> bool:
    """Report whether *lyrics* means "no vocals".

    True for empty or whitespace-only lyrics and for the placeholders
    "[inst]" / "[instrumental]" (case-insensitive, surrounding whitespace
    ignored).
    """
    normalized = lyrics.strip().lower() if lyrics else ""
    return normalized in ("", "[inst]", "[instrumental]")
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def _parse_messages(messages: List[Any]) -> Tuple[str, str, List[str], Optional[str], Optional[str]]:
    """
    Parse chat messages to extract prompt, lyrics, sample_query and audio references.

    Supports two modes:
    1. Tagged mode: Use <prompt>...</prompt> and <lyrics>...</lyrics> tags
    2. Heuristic mode: Auto-detect based on content structure

    Only "system" and "user" messages are read. A user message's content is
    either a string or a list of parts; each part may be a dict or an object
    exposing the same fields as attributes ("type", "text", "input_audio").

    Multiple input_audio blocks are collected in order (like multiple images).
    Each one is decoded into a temporary file; a block that fails to decode
    is logged and skipped.

    Args:
        messages: Objects exposing ``role`` and ``content`` attributes.

    Returns:
        (prompt, lyrics, audio_paths, system_instruction, sample_query).
        When no tags were used, no lyrics were detected and some text was
        given, that text is returned as sample_query and prompt is "".
    """
    prompt_parts: List[str] = []
    audio_paths: List[str] = []
    lyrics = ""
    has_tags = False
    system_instruction = None

    def _field(obj: Any, key: str, default: Any = None) -> Any:
        """Read *key* from a dict part or the same-named attribute of an object part."""
        if isinstance(obj, dict):
            return obj.get(key, default)
        return getattr(obj, key, default)

    def _absorb_text(raw: Optional[str]) -> None:
        """Route one piece of user text into prompt_parts or lyrics."""
        nonlocal lyrics, has_tags
        # A text part may carry None; treat it as empty instead of failing.
        text = (raw or "").strip()
        tagged_prompt, tagged_lyrics, remaining = _extract_tagged_content(text)
        if tagged_prompt is not None or tagged_lyrics is not None:
            has_tags = True
            if tagged_prompt:
                prompt_parts.append(tagged_prompt)
            if tagged_lyrics:
                lyrics = tagged_lyrics
            if remaining:
                prompt_parts.append(remaining)
        elif _looks_like_lyrics(text):
            lyrics = text
        else:
            prompt_parts.append(text)

    def _absorb_audio(audio_data: Any) -> None:
        """Decode one input_audio payload to a temp file and record its path."""
        if not audio_data:
            return
        b64_data = _field(audio_data, "data", "")
        if not b64_data:
            return
        audio_format = _field(audio_data, "format", "mp3") or "mp3"
        try:
            audio_paths.append(_base64_to_temp_file(b64_data, audio_format))
        except Exception:
            # Deliberately best-effort: one bad attachment must not fail the
            # whole request, but the drop is recorded rather than silent.
            import logging

            logging.getLogger(__name__).warning(
                "Skipping input_audio block that could not be decoded",
                exc_info=True,
            )

    for msg in messages:
        role = msg.role
        content = msg.content

        if role == "system":
            if isinstance(content, str):
                system_instruction = content
            continue

        if role != "user":
            continue

        if isinstance(content, str):
            _absorb_text(content)
            continue

        if not isinstance(content, list):
            continue

        for part in content:
            if not isinstance(part, dict) and not hasattr(part, "type"):
                continue
            part_type = _field(part, "type", "")
            if part_type == "text":
                _absorb_text(_field(part, "text", ""))
            elif part_type == "input_audio":
                _absorb_audio(_field(part, "input_audio"))

    prompt = " ".join(prompt_parts).strip()

    # Use sample mode when: no tags, no lyrics detected, and we have text input
    sample_query = None
    if not has_tags and not lyrics and prompt:
        sample_query = prompt
        prompt = ""

    return prompt, lyrics, audio_paths, system_instruction, sample_query
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def _to_generate_music_request(
    req: ChatCompletionRequest,
    prompt: str,
    lyrics: str,
    sample_query: Optional[str],
    reference_audio_path: Optional[str],
    src_audio_path: Optional[str],
):
    """
    Build an api_server GenerateMusicRequest from a ChatCompletionRequest.

    The audio paths are passed through as given (empty values become None).
    A request whose task_type is "text2music" and that comes with a
    reference audio path is submitted as "music_continuation".

    GenerateMusicRequest is imported inside the function to avoid a circular
    dependency with api_server.
    """
    from acestep.api_server import GenerateMusicRequest

    def _fallback(value, default):
        """Return *default* only when *value* is None."""
        return default if value is None else value

    audio_config = req.audio_config or AudioConfig()

    # Lyrics: the explicit request field wins over lyrics parsed from the
    # messages; an instrumental request without any lyrics gets "[inst]".
    instrumental = _fallback(audio_config.instrumental, False)
    final_lyrics = req.lyrics if req.lyrics else lyrics
    if instrumental and not final_lyrics:
        final_lyrics = "[inst]"

    # Sample mode: explicit flag, otherwise on whenever a sample query was
    # detected in the messages.
    sample_mode = req.sample_mode or bool(sample_query)
    final_sample_query = sample_query or ""

    # Seed is passed through as-is (int or comma-separated string);
    # handler.prepare_seeds() handles both formats. No seed means random.
    seed_missing = req.seed is None
    seed = -1 if seed_missing else req.seed

    # Task type: explicit value from the request, except the text2music
    # auto-switch described in the docstring.
    task_type = req.task_type
    if task_type == "text2music" and reference_audio_path:
        task_type = "music_continuation"

    return GenerateMusicRequest(
        # Text input
        prompt=prompt,
        lyrics=final_lyrics,
        sample_query=final_sample_query,
        sample_mode=sample_mode,
        # Music metadata
        bpm=audio_config.bpm,
        key_scale=audio_config.key_scale or "",
        time_signature=audio_config.time_signature or "",
        audio_duration=audio_config.duration or None,
        vocal_language=audio_config.vocal_language or "en",
        # LM parameters
        lm_temperature=_fallback(req.temperature, 0.85),
        lm_top_p=_fallback(req.top_p, 0.9),
        lm_top_k=_fallback(req.top_k, 0),
        thinking=_fallback(req.thinking, False),
        # Generation parameters
        inference_steps=8,
        guidance_scale=_fallback(req.guidance_scale, 7.0),
        seed=seed,
        use_random_seed=seed_missing,
        batch_size=_fallback(req.batch_size, 1),
        # Task type
        task_type=task_type,
        # Audio paths
        reference_audio_path=reference_audio_path or None,
        src_audio_path=src_audio_path or None,
        # Audio editing
        repainting_start=req.repainting_start,
        repainting_end=req.repainting_end,
        audio_cover_strength=req.audio_cover_strength,
        # Format / CoT control
        use_format=req.use_format,
        use_cot_caption=req.use_cot_caption,
        use_cot_language=req.use_cot_language,
        # Model selection
        model=_parse_model_name(req.model),
        # Audio format
        audio_format=(audio_config.format or DEFAULT_AUDIO_FORMAT),
    )
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def _build_openrouter_response(
    rec: Any,
    model_id: str,
    audio_format: str,
) -> JSONResponse:
    """Turn a finished job record into a non-streaming chat-completion reply.

    Args:
        rec: Job record; its ``status``, ``result`` and ``error`` attributes are read.
        model_id: Value echoed back in the response's ``model`` field.
        audio_format: Format passed to ``_audio_to_base64_url`` when encoding.

    Returns:
        A ``JSONResponse`` carrying a ``chat.completion`` object. The message's
        ``audio`` entry is ``None`` when no audio file could be encoded.

    Raises:
        HTTPException: 500 when the job did not succeed or has no result.
    """
    if not (rec.status == "succeeded" and rec.result):
        raise HTTPException(status_code=500, detail=rec.error or "Generation failed")

    payload = rec.result
    completion_id = _generate_completion_id()
    created_timestamp = int(time.time())
    text_content = _format_lm_content(payload)

    # Only the first generated file is attached, and only if it still exists on disk.
    audio_entries = None
    candidate_paths = payload.get("raw_audio_paths", [])
    first_path = candidate_paths[0] if candidate_paths else None
    if first_path and os.path.exists(first_path):
        data_url = _audio_to_base64_url(first_path, audio_format)
        if data_url:
            audio_entries = [{
                "type": "audio_url",
                "audio_url": {"url": data_url},
            }]

    assistant_message = {
        "role": "assistant",
        "content": text_content,
        "audio": audio_entries,
    }
    # Token usage is always reported as zero here.
    usage = {
        "prompt_tokens": 0,
        "completion_tokens": 0,
        "total_tokens": 0,
    }
    return JSONResponse(content={
        "id": completion_id,
        "object": "chat.completion",
        "created": created_timestamp,
        "model": model_id,
        "choices": [{
            "index": 0,
            "message": assistant_message,
            "finish_reason": "stop",
        }],
        "usage": usage,
    })
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
async def _openrouter_stream_generator(
    rec: Any,
    model_id: str,
    audio_format: str,
):
    """
    Async generator producing the SSE body for a streaming completion.

    Messages are read from ``rec.progress_queue``. Whenever no message arrives
    within 2 seconds a "." heartbeat chunk is emitted so the stream keeps
    producing output while generation is in progress.

    Message types handled (``msg["type"]``):
        "result": emit the formatted LM text, then the first audio file (if any).
        "error":  emit the error text, a finish_reason="error" chunk and [DONE].
        "done":   leave the loop and emit the final stop chunk and [DONE].
    Any other type is ignored.
    """
    completion_id = _generate_completion_id()
    created_timestamp = int(time.time())

    def _sse_chunk(
        content: Optional[str] = None,
        role: Optional[str] = None,
        audio: Optional[Any] = None,
        finish_reason: Optional[str] = None,
    ) -> str:
        # Delta keys are inserted in the fixed order role, content, audio.
        delta = {}
        candidates = (
            ("role", role, bool(role)),
            ("content", content, content is not None),
            ("audio", audio, audio is not None),
        )
        for key, value, include in candidates:
            if include:
                delta[key] = value

        envelope = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created_timestamp,
            "model": model_id,
            "choices": [{
                "index": 0,
                "delta": delta,
                "finish_reason": finish_reason,
            }],
        }
        return f"data: {json.dumps(envelope)}\n\n"

    # Opening chunk announces the assistant role.
    yield _sse_chunk(role="assistant", content="Generating music")
    await asyncio.sleep(0)

    # Poll the queue, emitting a heartbeat on every 2-second timeout.
    while True:
        try:
            event = await asyncio.wait_for(rec.progress_queue.get(), timeout=2.0)
        except asyncio.TimeoutError:
            yield _sse_chunk(content=".")
            await asyncio.sleep(0)
            continue

        kind = event.get("type")

        if kind == "done":
            break

        if kind == "error":
            yield _sse_chunk(content=f"\n\nError: {event.get('content', 'Unknown error')}")
            yield _sse_chunk(finish_reason="error")
            yield "data: [DONE]\n\n"
            return

        if kind != "result":
            continue

        payload = event.get("result", {})

        # Text part first.
        yield _sse_chunk(content=f"\n\n{_format_lm_content(payload)}")
        await asyncio.sleep(0)

        # Then the first audio file, provided it exists and encodes successfully.
        candidate_paths = payload.get("raw_audio_paths", [])
        first_path = candidate_paths[0] if candidate_paths else None
        if not (first_path and os.path.exists(first_path)):
            continue
        data_url = _audio_to_base64_url(first_path, audio_format)
        if not data_url:
            continue
        yield _sse_chunk(audio=[{
            "type": "audio_url",
            "audio_url": {"url": data_url},
        }])
        await asyncio.sleep(0)

    # Normal termination.
    yield _sse_chunk(finish_reason="stop")
    yield "data: [DONE]\n\n"
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
# =============================================================================
|
| 569 |
+
# Router Factory
|
| 570 |
+
# =============================================================================
|
| 571 |
+
|
| 572 |
+
def create_openrouter_router(app_state_getter) -> APIRouter:
    """
    Create OpenRouter-compatible API router.

    Registers two endpoints:
        GET  /v1/models            - list the models whose init flag is set
        POST /v1/chat/completions  - submit a generation job (SSE or JSON reply)

    Args:
        app_state_getter: Callable that returns the FastAPI app.state object

    Returns:
        APIRouter with OpenRouter-compatible endpoints
    """
    router = APIRouter(tags=["OpenRouter Compatible"])

    # (initialized-flag attribute, config-path attribute) for each model slot.
    model_slots = (
        ("_initialized", "_config_path"),
        ("_initialized2", "_config_path2"),
        ("_initialized3", "_config_path3"),
    )

    # Task types whose first uploaded audio is the source to edit rather than
    # a style reference.
    src_audio_task_types = frozenset({"cover", "repaint", "lego", "extract", "complete"})

    def _get_model_name_from_path(config_path: str) -> str:
        """Extract model name from config path (its last path component)."""
        if not config_path:
            return ""
        normalized = config_path.rstrip("/\\")
        return os.path.basename(normalized)

    def _model_info(model_name: str, created: int) -> ModelInfo:
        """Build the ModelInfo entry shared by every model slot."""
        return ModelInfo(
            id=_get_model_id(model_name),
            name=f"ACE-Step {model_name}",
            created=created,
            input_modalities=["text", "audio"],
            output_modalities=["audio", "text"],
            context_length=4096,
            max_output_length=300,
            pricing=ModelPricing(prompt="0", completion="0", request="0"),
            description="AI music generation model",
        )

    @router.get("/v1/models", response_model=ModelsResponse)
    async def list_models():
        """List available models in OpenRouter format."""
        state = app_state_getter()
        # Reported creation time is fixed at 30 days before the request.
        created_timestamp = int(time.time()) - 86400 * 30

        models = []
        for init_attr, path_attr in model_slots:
            if not getattr(state, init_attr, False):
                continue
            # getattr with a default so a slot without a config path is
            # skipped instead of raising AttributeError.
            model_name = _get_model_name_from_path(getattr(state, path_attr, ""))
            if model_name:
                models.append(_model_info(model_name, created_timestamp))

        return ModelsResponse(data=models)

    @router.post("/v1/chat/completions")
    async def chat_completions(request: Request):
        """
        OpenRouter-compatible chat completions endpoint for music generation.

        Submits the request to the shared asyncio.Queue and waits for completion.
        Supports both streaming (SSE) and non-streaming responses.

        Raises:
            HTTPException: 503 if the model is not initialized, 400 for an
                invalid or empty request, 429 if the job queue is full, 504 if
                a non-streaming job exceeds GENERATION_TIMEOUT.
        """
        state = app_state_getter()

        # Check initialization
        if not getattr(state, "_initialized", False):
            raise HTTPException(
                status_code=503,
                detail=f"Model not initialized. init_error={getattr(state, '_init_error', None)}"
            )

        # Parse request
        try:
            body = await request.json()
            req = ChatCompletionRequest(**body)
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Invalid request format: {str(e)}") from e

        # Check queue capacity BEFORE parsing messages: the audio paths returned
        # by _parse_messages are only registered for cleanup once a job record
        # exists, so rejecting after parsing would leave them untracked.
        job_queue = state.job_queue
        if job_queue.full():
            raise HTTPException(status_code=429, detail="Server busy: queue is full")

        # Parse messages for text, audio, and system instruction
        prompt, lyrics, audio_paths, system_instruction, sample_query = _parse_messages(req.messages)

        # When lyrics or sample_mode is explicitly provided, the message text role
        # is already known, so the auto-detection results are overridden.
        # _parse_messages may have put raw text into prompt or sample_query;
        # recover it as raw_text for re-assignment.
        if req.lyrics or req.sample_mode:
            raw_text = prompt or sample_query or ""
            if req.lyrics:
                # lyrics provided -> message text is the prompt
                prompt = raw_text
                lyrics = req.lyrics
                sample_query = None
            else:
                # sample_mode -> message text is the sample_query
                prompt = ""
                lyrics = ""
                sample_query = raw_text

        if not prompt and not lyrics and not sample_query and not req.sample_mode and not audio_paths:
            raise HTTPException(
                status_code=400,
                detail="No valid prompt, lyrics, sample query, or input audio found in request"
            )

        # Route audio paths based on task_type.
        # Multiple input_audio blocks are supported (like multiple images).
        #
        # For cover / repaint / lego / extract / complete:
        #   audio[0] -> src_audio        (primary: the audio to edit / cover)
        #   audio[1] -> reference_audio  (optional: style conditioning)
        #
        # For text2music (default):
        #   audio[0] -> reference_audio  (style conditioning)
        reference_audio_path = None
        src_audio_path = None
        if audio_paths:
            if req.task_type in src_audio_task_types:
                src_audio_path = audio_paths[0]
                if len(audio_paths) > 1:
                    reference_audio_path = audio_paths[1]
            else:
                reference_audio_path = audio_paths[0]

        # Convert to GenerateMusicRequest
        gen_request = _to_generate_music_request(
            req, prompt, lyrics, sample_query, reference_audio_path, src_audio_path
        )

        # Get audio format
        audio_config = req.audio_config or AudioConfig()
        audio_format = audio_config.format or DEFAULT_AUDIO_FORMAT

        # Create job record
        rec = state.job_store.create()

        # Track temp files from base64 audio uploads
        if audio_paths:
            async with state.job_temp_files_lock:
                state.job_temp_files.setdefault(rec.job_id, []).extend(audio_paths)

        # Attach the completion channel the caller will wait on before the job
        # becomes visible to the worker: a progress queue for SSE, an event
        # for the blocking JSON reply.
        if req.stream:
            rec.progress_queue = asyncio.Queue()
        else:
            rec.done_event = asyncio.Event()

        async with state.pending_lock:
            state.pending_ids.append(rec.job_id)

        await job_queue.put((rec.job_id, gen_request))

        if req.stream:
            return StreamingResponse(
                _openrouter_stream_generator(rec, req.model, audio_format),
                media_type="text/event-stream",
            )

        # Non-streaming: wait for completion with timeout
        try:
            await asyncio.wait_for(rec.done_event.wait(), timeout=GENERATION_TIMEOUT)
        except asyncio.TimeoutError:
            raise HTTPException(status_code=504, detail="Generation timeout")

        return _build_openrouter_response(rec, req.model, audio_format)

    return router
|
acestep/openrouter_models.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""OpenRouter API compatible Pydantic models for ACE-Step.
|
| 2 |
+
|
| 3 |
+
This module defines request/response models that conform to OpenRouter's
|
| 4 |
+
chat completions API specification for audio generation.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import Any, Dict, List, Literal, Optional, Union
|
| 10 |
+
from pydantic import BaseModel, Field
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# =============================================================================
|
| 14 |
+
# Request Models
|
| 15 |
+
# =============================================================================
|
| 16 |
+
|
| 17 |
+
class AudioInputContent(BaseModel):
    """Audio input content in base64 format.

    Carries the encoded audio bytes together with a format label.
    """
    # Required: the audio bytes, base64-encoded.
    data: str = Field(..., description="Base64-encoded audio data")
    # Format label for the payload; "mp3" when the client omits it.
    format: str = Field(default="mp3", description="Audio format (mp3, wav, flac, etc.)")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class TextContent(BaseModel):
    """Text content block (a message content part with ``type == "text"``)."""
    # Discriminator; only the literal "text" validates.
    type: Literal["text"] = "text"
    # Required text payload.
    text: str = Field(..., description="Text content")
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class AudioContent(BaseModel):
    """Audio input content block (a message content part with ``type == "input_audio"``)."""
    # Discriminator; only the literal "input_audio" validates.
    type: Literal["input_audio"] = "input_audio"
    # Required base64 payload and its format label.
    input_audio: AudioInputContent
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# Union type for one element of a multi-part message content list: a typed
# text block, a typed audio block, or a plain dict (presumably for part
# shapes not modelled above; confirm against the message parser).
ContentPart = Union[TextContent, AudioContent, Dict[str, Any]]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class ChatMessage(BaseModel):
    """A single message in the chat conversation."""
    # Required; restricted to the three listed roles.
    role: Literal["system", "user", "assistant"] = Field(..., description="Message role")
    # Required; either a plain string or a list of content parts.
    content: Union[str, List[ContentPart]] = Field(..., description="Message content")
    # Optional author name; None when not supplied.
    name: Optional[str] = Field(default=None, description="Optional name for the message author")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class AudioConfig(BaseModel):
    """Audio generation configuration.

    Every field except ``format`` defaults to None, meaning "not specified".
    """
    # Target length in seconds; None leaves the duration unspecified.
    duration: Optional[float] = Field(default=None, description="Target audio duration in seconds")
    # Output container/codec label; "mp3" when omitted.
    format: str = Field(default="mp3", description="Output audio format")
    # ACE-Step specific parameters
    bpm: Optional[int] = Field(default=None, description="Beats per minute")
    key_scale: Optional[str] = Field(default=None, description="Musical key and scale")
    time_signature: Optional[str] = Field(default=None, description="Time signature (e.g., 4/4)")
    vocal_language: Optional[str] = Field(default=None, description="Vocal language code")
    instrumental: Optional[bool] = Field(default=None, description="Generate instrumental only")
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class ChatCompletionRequest(BaseModel):
    """OpenRouter-compatible chat completion request.

    Combines the standard chat-completion fields (model, messages, sampling
    parameters, stream) with ACE-Step extensions for music generation.
    Only ``model`` and ``messages`` are required. Unknown extra fields are
    accepted rather than rejected (see ``Config``).
    """
    model: str = Field(..., description="Model ID to use")
    messages: List[ChatMessage] = Field(..., description="List of messages")

    # Modalities
    modalities: Optional[List[str]] = Field(
        default=None,
        description="Output modalities (e.g., ['audio', 'text'])"
    )

    # Audio configuration
    audio_config: Optional[AudioConfig] = Field(
        default=None,
        description="Audio generation configuration"
    )

    # Standard OpenAI parameters
    # Range constraints: temperature in [0, 2], top_p in [0, 1], top_k >= 0,
    # max_tokens >= 1. None means "not specified".
    temperature: Optional[float] = Field(default=None, ge=0, le=2)
    top_p: Optional[float] = Field(default=None, ge=0, le=1)
    top_k: Optional[int] = Field(default=None, ge=0)
    max_tokens: Optional[int] = Field(default=None, ge=1)
    stream: bool = Field(default=False, description="Enable streaming response")
    stop: Optional[Union[str, List[str]]] = Field(default=None)
    # Accepts an int or a string so several seeds can be given comma-separated.
    seed: Optional[Union[int, str]] = Field(default=None, description="Seed(s) for reproducibility. Comma-separated for batch (e.g. '42,123,456')")

    # ACE-Step specific parameters (extended)
    thinking: Optional[bool] = Field(default=None, description="Use LM for audio code generation")
    guidance_scale: Optional[float] = Field(default=None, description="Classifier-free guidance scale")
    batch_size: Optional[int] = Field(default=None, description="Number of audio samples to generate")

    # ACE-Step direct fields (bypass message parsing / audio_config)
    lyrics: str = Field(default="", description="Direct lyrics input (bypass message parsing)")
    sample_mode: bool = Field(default=False, description="Auto-generate caption/lyrics/metas via LM; user message becomes the query")
    use_format: bool = Field(default=False, description="Use format_sample to enhance caption/lyrics")
    use_cot_caption: bool = Field(default=True, description="Use CoT for caption rewriting")
    use_cot_language: bool = Field(default=True, description="Use CoT for language detection")

    # Task type
    # NOTE: declared as a free-form str; the listed values are not enforced here.
    task_type: str = Field(default="text2music", description="Task type: text2music, cover, repaint, extract, lego, complete")

    # Audio editing parameters
    repainting_start: float = Field(default=0.0, description="Repainting region start (seconds)")
    repainting_end: Optional[float] = Field(default=None, description="Repainting region end (seconds)")
    # NOTE: the 0.0~1.0 range in the description is not enforced by a constraint.
    audio_cover_strength: float = Field(default=1.0, description="Audio cover strength (0.0~1.0)")

    class Config:
        extra = "allow"  # Allow additional fields for forward compatibility
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# =============================================================================
|
| 109 |
+
# Response Models
|
| 110 |
+
# =============================================================================
|
| 111 |
+
|
| 112 |
+
class AudioOutputUrl(BaseModel):
    """Audio output URL (base64 data URL)."""
    # Required; the generated audio embedded as a base64 data URL.
    url: str = Field(..., description="Base64 data URL of the audio")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class AudioOutput(BaseModel):
    """Audio output content block (one generated audio item)."""
    # Discriminator; only the literal "audio_url" validates.
    type: Literal["audio_url"] = "audio_url"
    # Required wrapper holding the data URL.
    audio_url: AudioOutputUrl
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class AssistantMessage(BaseModel):
    """Assistant response message: optional text plus optional audio items."""
    # Fixed to "assistant".
    role: Literal["assistant"] = "assistant"
    # Text part of the reply; None when absent.
    content: Optional[str] = Field(default=None, description="Text content")
    # Generated audio items; None when no audio is attached.
    audio: Optional[List[AudioOutput]] = Field(default=None, description="Generated audio files")
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class Choice(BaseModel):
    """A single completion choice."""
    # Position of this choice in the response's ``choices`` list.
    index: int = Field(default=0)
    # Required assistant message for this choice.
    message: AssistantMessage
    # Why generation ended; restricted to the four listed values.
    finish_reason: Literal["stop", "length", "content_filter", "error"] = "stop"
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class Usage(BaseModel):
    """Token usage statistics (every counter defaults to 0)."""
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class ChatCompletionResponse(BaseModel):
    """OpenRouter-compatible chat completion response (non-streaming)."""
    id: str = Field(..., description="Unique completion ID")
    # Fixed object tag for a complete (non-chunked) response.
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(..., description="Unix timestamp")
    model: str = Field(..., description="Model ID used")
    choices: List[Choice] = Field(..., description="Completion choices")
    # A fresh all-zero Usage is created when the caller supplies none.
    usage: Usage = Field(default_factory=Usage)

    # Extended metadata
    system_fingerprint: Optional[str] = Field(default=None)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# =============================================================================
|
| 158 |
+
# Streaming Response Models
|
| 159 |
+
# =============================================================================
|
| 160 |
+
|
| 161 |
+
class DeltaContent(BaseModel):
    """Delta content for streaming.

    All fields are optional and default to None, so a chunk can carry any
    subset of them.
    """
    role: Optional[Literal["assistant"]] = None
    content: Optional[str] = None
    audio: Optional[List[AudioOutput]] = None
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class StreamChoice(BaseModel):
    """Streaming choice: one delta plus an optional finish reason."""
    index: int = 0
    # Required incremental payload for this chunk.
    delta: DeltaContent
    # None by default; otherwise restricted to the four listed values.
    finish_reason: Optional[Literal["stop", "length", "content_filter", "error"]] = None
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class ChatCompletionChunk(BaseModel):
    """Streaming chunk response."""
    # Completion ID (required).
    id: str
    # Fixed object tag identifying a streaming chunk.
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    # Creation time as an integer timestamp (required).
    created: int
    # Model ID (required).
    model: str
    # Streaming choices carried by this chunk (required).
    choices: List[StreamChoice]
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
# =============================================================================
|
| 185 |
+
# Models Endpoint Response
|
| 186 |
+
# =============================================================================
|
| 187 |
+
|
| 188 |
+
class ModelPricing(BaseModel):
    """Model pricing information.

    Prices are carried as strings; every field defaults to "0".
    """
    prompt: str = Field(default="0", description="Price per prompt token in USD")
    completion: str = Field(default="0", description="Price per completion token in USD")
    request: str = Field(default="0", description="Price per request in USD")
    image: str = Field(default="0", description="Price per image in USD")
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class ModelInfo(BaseModel):
    """OpenRouter-compatible model information (one entry of the models list)."""
    id: str = Field(..., description="Model identifier")
    name: str = Field(..., description="Display name")
    created: int = Field(..., description="Unix timestamp of creation")

    # Modalities
    # default_factory gives each instance its own list instead of a shared one.
    input_modalities: List[str] = Field(
        default_factory=lambda: ["text"],
        description="Supported input modalities"
    )
    output_modalities: List[str] = Field(
        default_factory=lambda: ["audio", "text"],
        description="Supported output modalities"
    )

    # Limits
    context_length: int = Field(default=4096, description="Maximum context length")
    # Measured in seconds of audio, not tokens.
    max_output_length: int = Field(default=300, description="Maximum output length in seconds")

    # Pricing
    # Defaults to a ModelPricing whose fields are all "0".
    pricing: ModelPricing = Field(default_factory=ModelPricing)

    # Metadata
    description: Optional[str] = Field(default=None)
    architecture: Optional[Dict[str, Any]] = Field(default=None)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class ModelsResponse(BaseModel):
    """Response for /v1/models endpoint."""
    # Fixed object tag for a list response.
    object: Literal["list"] = "list"
    # Model entries; an empty list when none are supplied.
    data: List[ModelInfo] = Field(default_factory=list)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# =============================================================================
|
| 231 |
+
# Error Response
|
| 232 |
+
# =============================================================================
|
| 233 |
+
|
| 234 |
+
class ErrorDetail(BaseModel):
    """Error detail information."""
    # Required error message text.
    message: str
    # Error category; "invalid_request_error" unless overridden.
    type: str = "invalid_request_error"
    # Optional fields, None by default.
    param: Optional[str] = None
    code: Optional[str] = None
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class ErrorResponse(BaseModel):
    """OpenRouter-compatible error response: wraps one ErrorDetail under ``error``."""
    error: ErrorDetail
|
acestep/test_time_scaling.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test-Time Scaling Module
|
| 3 |
+
Implements perplexity-based scoring for generated audio codes
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from typing import Tuple, Optional, Dict, Any, List
|
| 8 |
+
from loguru import logger
|
| 9 |
+
import yaml
|
| 10 |
+
import math
|
| 11 |
+
import re
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def pmi_score(log_prob_conditional: float, log_prob_unconditional: float) -> float:
    """
    Compute the Pointwise Mutual Information (PMI) between codes and condition.

    PMI = log P(condition|codes) - log P(condition)
        = log [P(codes|condition) / P(codes)]

    Subtracting the unconditional term removes the bias from P(condition), so
    the score reflects how much the codes help predict the condition.

    Args:
        log_prob_conditional: Average log probability of condition given codes
        log_prob_unconditional: Average log probability of condition without codes

    Returns:
        The difference of the two log probabilities (higher is better):
        - Positive: codes improve prediction (good match)
        - Zero: codes don't help (no correlation)
        - Negative: codes hurt prediction (poor match)
    """
    pmi = log_prob_conditional - log_prob_unconditional
    return pmi
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def pmi_to_normalized_score(pmi: float, scale: float = 0.1) -> float:
    """
    Convert PMI score to normalized [0, 1] range using sigmoid function.

    score = sigmoid(PMI / scale) = 1 / (1 + exp(-PMI / scale))

    The sigmoid is evaluated in a numerically stable form: the exponential is
    only ever taken of a non-positive number, so strongly negative PMI values
    (e.g. PMI=-100 with the default scale) yield a score near 0 instead of
    raising OverflowError from math.exp.

    Args:
        pmi: PMI score (can be positive or negative)
        scale: Scale parameter to control sensitivity (default 0.1)
            - Smaller scale: more sensitive to PMI changes
            - Larger scale: less sensitive to PMI changes

    Returns:
        Normalized score in [0, 1] range, where:
        - PMI > 0 → score > 0.5 (good match)
        - PMI = 0 → score = 0.5 (neutral)
        - PMI < 0 → score < 0.5 (poor match)

    Raises:
        ZeroDivisionError: if scale is 0.

    Examples (scale=1.0):
        PMI=2.0  → score≈0.88 (excellent)
        PMI=1.0  → score≈0.73 (good)
        PMI=0.0  → score=0.50 (neutral)
        PMI=-1.0 → score≈0.27 (poor)
        PMI=-2.0 → score≈0.12 (bad)
    """
    x = pmi / scale
    if x >= 0:
        # exp(-x) <= 1 here, so it cannot overflow.
        return 1.0 / (1.0 + math.exp(-x))
    # For negative x use the equivalent form exp(x) / (1 + exp(x)); exp(x) < 1
    # and underflows harmlessly to 0.0 for very negative x.
    z = math.exp(x)
    return z / (1.0 + z)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def _get_logits_and_target_for_scoring(llm_handler, formatted_prompt: str,
                                       target_text: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run one teacher-forced forward pass and return the logits aligned with the target tokens.

    Args:
        llm_handler: The handler containing the model and tokenizer.
        formatted_prompt: The input context.
        target_text: The text we want to calculate probability/recall for.

    Returns:
        Tuple of (target_logits, target_ids)
            - target_logits: Logits used to predict the target tokens.
            - target_ids: The ground truth token IDs of the target.
        Both are empty tensors when the tokenized prompt+target is not longer
        than the tokenized prompt alone (empty or fully truncated target).
    """
    scoring_model = llm_handler.get_hf_model_for_scoring()
    tok = llm_handler.llm_tokenizer
    if llm_handler.llm_backend == "pt":
        device = llm_handler.device
    else:
        device = next(scoring_model.parameters()).device

    # The prompt is tokenized on its own only to learn how many leading
    # positions belong to it; special tokens are included so the offset
    # matches the full encoding below.
    prompt_only = tok(formatted_prompt, return_tensors="pt", add_special_tokens=True)
    n_prompt = prompt_only['input_ids'].shape[1]

    # Prompt and target are encoded together so the tokenizer handles any
    # subword merging at the boundary itself.
    encoded = tok(
        formatted_prompt + target_text,
        return_tensors="pt",
        padding=False,
        truncation=True,
        add_special_tokens=True,
    ).to(device)
    input_ids = encoded['input_ids']

    # Nothing left to score: target was empty or truncated away entirely.
    if input_ids.shape[1] <= n_prompt:
        return torch.empty(0, device=device), torch.empty(0, device=device)

    # Teacher-forced forward pass, without gradients, inside the handler's
    # model-loading context.
    with torch.no_grad(), llm_handler._load_model_context():
        result = scoring_model(input_ids=input_ids, attention_mask=encoded['attention_mask'])
        sequence_logits = result.logits  # [1, seq_len, vocab_size]

    # The logit at position i-1 predicts token i. The target begins at index
    # n_prompt, so take logits from n_prompt - 1 through the second-to-last
    # position and pair them with the ids from n_prompt onward.
    aligned_logits = sequence_logits[0, n_prompt - 1:-1, :]  # [target_len, vocab_size]
    gold_ids = input_ids[0, n_prompt:]  # [target_len]
    return aligned_logits, gold_ids
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# ==============================================================================
|
| 116 |
+
# Scoring Logic
|
| 117 |
+
# ==============================================================================
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _calculate_topk_recall(llm_handler,
                           formatted_prompt: str,
                           target_text: str,
                           topk: int = 10) -> Tuple[float, Dict[int, float]]:
    """
    Calculate top-k recall for target text given prompt.

    For every target position, checks where the ground truth token ranks among
    the ``topk`` highest-scoring predictions at that step.

    Args:
        llm_handler: Handler passed through to the logits helper.
        formatted_prompt: The input context.
        target_text: The text whose tokens are scored.
        topk: Number of top predictions considered per position.

    Returns:
        Tuple of (average_recall, recall_per_k)
            - average_recall: Mean rank-weighted score over all target positions.
              A hit at rank r contributes ``1 - (r - 1) / topk``; a miss contributes 0.
            - recall_per_k: For each k in 1..topk, the fraction of positions whose
              ground truth token is within the top k predictions.
        Returns ``(0.0, {})`` when there are no target tokens or ``topk < 1``.
    """
    if topk < 1:
        # No candidates to compare against; avoids an invalid k for torch.topk.
        return 0.0, {}

    pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)

    target_len = target_ids.shape[0]
    if target_len == 0:
        return 0.0, {}

    # topk_indices: [target_len, min(topk, vocab_size)], best prediction first.
    _, topk_indices = torch.topk(pred_logits, k=min(topk, pred_logits.shape[-1]), dim=-1)

    # Plain Python lists are faster to iterate than tensor elements.
    target_ids_list = target_ids.tolist()
    topk_indices_list = topk_indices.tolist()

    # Find each ground-truth rank once, then derive every recall@k from the
    # rank histogram instead of rescanning all positions for each k.
    hits_at_rank = [0] * (topk + 1)
    weighted_sum = 0.0
    for gt_token, candidates in zip(target_ids_list, topk_indices_list):
        if gt_token not in candidates:
            continue  # miss: contributes 0 to both metrics
        rank = candidates.index(gt_token) + 1
        hits_at_rank[rank] += 1
        # Rank 1 = 1.0, Rank topk = small positive
        weighted_sum += 1.0 - (rank - 1) / topk

    recall_per_k = {}
    cumulative_hits = 0
    for k in range(1, topk + 1):
        cumulative_hits += hits_at_rank[k]
        recall_per_k[k] = cumulative_hits / target_len

    average_recall = weighted_sum / target_len

    return average_recall, recall_per_k
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def _calculate_metadata_recall(llm_handler,
|
| 175 |
+
formatted_prompt: str,
|
| 176 |
+
fields_dict: Dict[str, Any],
|
| 177 |
+
topk: int = 10) -> Dict[str, float]:
|
| 178 |
+
"""
|
| 179 |
+
Args:
|
| 180 |
+
fields_dict: Dictionary of {field_name: field_value}
|
| 181 |
+
"""
|
| 182 |
+
if not fields_dict:
|
| 183 |
+
return {}
|
| 184 |
+
|
| 185 |
+
field_scores = {}
|
| 186 |
+
|
| 187 |
+
for field_name in sorted(fields_dict.keys()):
|
| 188 |
+
# Construct target text for this specific field
|
| 189 |
+
# e.g. <think>\nbpm: 120\n</think>\n
|
| 190 |
+
field_yaml = yaml.dump({field_name: fields_dict[field_name]}, allow_unicode=True, sort_keys=True).strip()
|
| 191 |
+
field_target_text = f"<think>\n{field_yaml}\n</think>\n"
|
| 192 |
+
|
| 193 |
+
# Calculate recall using the robust logic
|
| 194 |
+
avg_score, _ = _calculate_topk_recall(llm_handler, formatted_prompt, field_target_text, topk=topk)
|
| 195 |
+
|
| 196 |
+
field_scores[field_name] = avg_score
|
| 197 |
+
logger.debug(f"Recall for {field_name}: {avg_score:.4f}")
|
| 198 |
+
|
| 199 |
+
return field_scores
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def _calculate_log_prob(
    llm_handler,
    formatted_prompt: str,
    target_text: str,
    temperature: float = 1.0  # Kept for API compatibility, but ignored for scoring
) -> float:
    """
    Return the mean per-token log probability of ``target_text`` given the prompt.

    Args:
        llm_handler: Handler forwarded to the logits helper.
        formatted_prompt: The input context.
        target_text: The text whose likelihood is measured.
        temperature: Unused; the logits are deliberately not temperature-scaled
            so the log probabilities stay exact for PMI.

    Returns:
        Average log probability over the target tokens, or ``-inf`` when there
        are no target tokens to score.
    """
    logits, gold_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)

    n_target = gold_ids.shape[0]
    if n_target == 0:
        return float('-inf')

    # Normalize over the vocabulary axis: [target_len, vocab_size]
    token_log_probs = F.log_softmax(logits, dim=-1)

    # Pick, for every position, the log probability of its ground truth token.
    row_index = torch.arange(n_target)
    gold_log_probs = token_log_probs[row_index, gold_ids]

    return gold_log_probs.mean().item()
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def calculate_reward_score(
    scores: Dict[str, float],
    weights_config: Optional[Dict[str, float]] = None
) -> Tuple[float, str]:
    """
    Reward Model Calculator: Computes a final reward based on user priorities.

    Priority Logic:
        1. Caption (Highest): The overall vibe/style must match.
        2. Lyrics (Medium): Content accuracy is important but secondary to vibe.
        3. Metadata (Lowest): Technical constraints (BPM, Key) allow for slight deviations.

    Strategy: Dynamic Weighted Sum
        - Metadata fields are aggregated into a single 'metadata' score first.
        - Weights are dynamically renormalized if any component (e.g., lyrics) is missing.

    Args:
        scores: Dictionary of raw scores (0.0 - 1.0) from the evaluation module.
            Every key other than 'caption' and 'lyrics' is treated as a metadata
            field; metadata entries whose value is None are ignored.
        weights_config: Optional custom weights keyed by 'caption', 'lyrics' and
            'metadata'. Any key that is not supplied falls back to its default:
            Caption (50%), Lyrics (30%), Metadata (20%).

    Returns:
        final_reward: The calculated reward score (0.0 - 1.0).
        explanation: A formatted string explaining how the score was derived,
            or an error message when no component could be scored.
    """
    # 1. Preference configuration: relative importance of each component.
    # Caller-supplied weights override the defaults key by key, so a partial
    # configuration no longer fails on a missing key.
    weights = {
        'caption': 0.50,   # High priority: Style/Vibe
        'lyrics': 0.30,    # Medium priority: Content
        'metadata': 0.20,  # Low priority: Technical details
    }
    if weights_config:
        weights.update(weights_config)

    # 2. Extract and group scores.
    # Caption and Lyrics are standalone high-level features.
    caption_score = scores.get('caption')
    lyrics_score = scores.get('lyrics')

    # Metadata fields (bpm, key, duration, etc.) are averaged into a single
    # "technical" score so that having many fields does not dilute the weight
    # of Caption/Lyrics.
    meta_scores_list = [
        val for key, val in scores.items()
        if key not in ('caption', 'lyrics') and val is not None
    ]
    meta_aggregate_score = None
    if meta_scores_list:
        meta_aggregate_score = sum(meta_scores_list) / len(meta_scores_list)

    # 3. Keep only the components that actually exist in this generation.
    active_components = {}
    if caption_score is not None:
        active_components['caption'] = (caption_score, weights['caption'])
    if lyrics_score is not None:
        active_components['lyrics'] = (lyrics_score, weights['lyrics'])
    if meta_aggregate_score is not None:
        active_components['metadata'] = (meta_aggregate_score, weights['metadata'])

    # 4. Calculate the final weighted score.
    total_base_weight = sum(w for _, w in active_components.values())
    if total_base_weight == 0:
        return 0.0, "❌ No valid scores available to calculate reward."

    total_score = 0.0
    breakdown_lines = []

    # Sort by weight (importance) for display.
    sorted_components = sorted(active_components.items(), key=lambda x: x[1][1], reverse=True)

    for name, (score, base_weight) in sorted_components:
        # Renormalize: if a component is missing, the remaining weights scale
        # up proportionately so they still sum to 1.
        normalized_weight = base_weight / total_base_weight
        weighted_contribution = score * normalized_weight
        total_score += weighted_contribution

        breakdown_lines.append(
            f" • {name.title():<8} | Score: {score:.4f} | Weight: {normalized_weight:.2f} "
            f"-> Contrib: +{weighted_contribution:.4f}"
        )

    return total_score, "\n".join(breakdown_lines)
|
| 321 |
+
|
| 322 |
+
# ==============================================================================
|
| 323 |
+
# Main Public API
|
| 324 |
+
# ==============================================================================
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def calculate_pmi_score_per_condition(
    llm_handler,
    audio_codes: str,
    caption: str = "",
    lyrics: str = "",
    metadata: Optional[Dict[str, Any]] = None,
    temperature: float = 1.0,
    topk: int = 10,
    score_scale: float = 0.1,
) -> Tuple[Dict[str, float], float, str]:
    """
    Calculate quality score separately for each condition.
        - Metadata: Uses Top-k Recall.
        - Caption/Lyrics: Uses PMI (Normalized).

    Args:
        llm_handler: Handler providing the LLM; must report ``llm_initialized``.
        audio_codes: Audio code string the conditions are evaluated against.
        caption: Caption to score. Used only when non-empty and ``metadata``
            does not already contain a 'caption' entry.
        lyrics: Lyrics to score; skipped when empty.
        metadata: Optional mapping of metadata fields. May be None. The mapping
            passed in is not modified.
        temperature: Accepted for interface compatibility; not used here.
        topk: Number of top predictions considered for the recall metric.
        score_scale: Sigmoid scale used when normalizing PMI values.

    Returns:
        Tuple of (scores, global_score, status)
            - scores: Per-condition scores keyed by condition name.
            - global_score: Weighted reward from ``calculate_reward_score``;
              0.0 when nothing could be evaluated, ``-inf`` when scoring raised.
            - status: Human-readable breakdown, or an error message.
    """
    if not llm_handler.llm_initialized:
        return {}, 0.0, "❌ LLM not initialized"

    if not audio_codes or not audio_codes.strip():
        return {}, 0.0, "❌ No audio codes provided"

    # Work on a private copy: tolerates metadata=None (the default) and avoids
    # writing the caption back into the caller's dictionary.
    conditions = dict(metadata) if isinstance(metadata, dict) else {}
    # An empty caption carries no condition to evaluate, so it is not injected.
    if caption and "caption" not in conditions:
        conditions['caption'] = caption

    # Which fields use which metric.
    metadata_recall_keys = ['bpm', 'duration', 'genres', 'keyscale', 'language', 'timesignature']
    metadata_pmi_keys = ['caption']

    formatted_prompt = llm_handler.build_formatted_prompt_for_understanding(audio_codes=audio_codes, is_negative_prompt=False)
    # The unconditional prompt is built once and shared by caption and lyrics.
    prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)

    def _normalized_pmi(target_text: str) -> float:
        # PMI = log P(target | codes) - log P(target | no codes), squashed to [0, 1].
        log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
        log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
        return pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)

    try:
        scores = {}

        # 1. Calculate Recall for Metadata Fields
        for key in metadata_recall_keys:
            value = conditions.get(key)
            if value is not None:
                field_scores = _calculate_metadata_recall(llm_handler, formatted_prompt, {key: value}, topk=topk)
                scores.update(field_scores)

        # 2. Calculate PMI for Caption
        for key in metadata_pmi_keys:
            value = conditions.get(key)
            if value is not None:
                cot_yaml = yaml.dump({key: value}, allow_unicode=True, sort_keys=True).strip()
                scores[key] = _normalized_pmi(f"<think>\n{cot_yaml}\n</think>\n")

        # 3. Calculate PMI for Lyrics
        if lyrics:
            scores['lyrics'] = _normalized_pmi(f"<think>\n</think>\n# Lyric\n{lyrics}\n")

        if not scores:
            return {}, 0.0, "❌ No conditions to evaluate"

        # 4. Global Score
        global_score, breakdown_lines = calculate_reward_score(scores)

        # Status Message
        status_lines = [breakdown_lines, "\n✅ Per-condition scores (0-1):"]
        for key, score in sorted(scores.items()):
            metric = "Top-k Recall" if key in metadata_recall_keys else "PMI (Norm)"
            status_lines.append(f" {key}: {score:.4f} ({metric})")
        status = "\n".join(status_lines)
        logger.info(f"Calculated scores: {global_score:.4f}\n{status}")
        return scores, global_score, status

    except Exception as e:
        # Scoring is best-effort: log the failure and report it in the status
        # string instead of propagating.
        import traceback
        error_msg = f"❌ Error: {str(e)}"
        logger.error(error_msg)
        logger.error(traceback.format_exc())
        return {}, float('-inf'), error_msg
|
acestep/third_parts/nano-vllm/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Xingkai Yu
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
acestep/third_parts/nano-vllm/README.md
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="center">
|
| 2 |
+
<img width="300" src="assets/logo.png">
|
| 3 |
+
</p>
|
| 4 |
+
|
| 5 |
+
<p align="center">
|
| 6 |
+
<a href="https://trendshift.io/repositories/15323" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15323" alt="GeeeekExplorer%2Fnano-vllm | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
| 7 |
+
</p>
|
| 8 |
+
|
| 9 |
+
# Nano-vLLM
|
| 10 |
+
|
| 11 |
+
A lightweight vLLM implementation built from scratch.
|
| 12 |
+
|
| 13 |
+
## Key Features
|
| 14 |
+
|
| 15 |
+
* 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
|
| 16 |
+
* 📖 **Readable codebase** - Clean implementation in ~ 1,200 lines of Python code
|
| 17 |
+
* ⚡ **Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc.
|
| 18 |
+
|
| 19 |
+
## Installation
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
## Model Download
|
| 26 |
+
|
| 27 |
+
To download the model weights manually, use the following command:
|
| 28 |
+
```bash
|
| 29 |
+
huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
|
| 30 |
+
--local-dir ~/huggingface/Qwen3-0.6B/ \
|
| 31 |
+
--local-dir-use-symlinks False
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
## Quick Start
|
| 35 |
+
|
| 36 |
+
See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method:
|
| 37 |
+
```python
|
| 38 |
+
from nanovllm import LLM, SamplingParams
|
| 39 |
+
llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
|
| 40 |
+
sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
|
| 41 |
+
prompts = ["Hello, Nano-vLLM."]
|
| 42 |
+
outputs = llm.generate(prompts, sampling_params)
|
| 43 |
+
outputs[0]["text"]
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
## Benchmark
|
| 47 |
+
|
| 48 |
+
See `bench.py` for benchmark.
|
| 49 |
+
|
| 50 |
+
**Test Configuration:**
|
| 51 |
+
- Hardware: RTX 4070 Laptop (8GB)
|
| 52 |
+
- Model: Qwen3-0.6B
|
| 53 |
+
- Total Requests: 256 sequences
|
| 54 |
+
- Input Length: Randomly sampled between 100–1024 tokens
|
| 55 |
+
- Output Length: Randomly sampled between 100–1024 tokens
|
| 56 |
+
|
| 57 |
+
**Performance Results:**
|
| 58 |
+
| Inference Engine | Output Tokens | Time (s) | Throughput (tokens/s) |
|
| 59 |
+
|----------------|-------------|----------|-----------------------|
|
| 60 |
+
| vLLM | 133,966 | 98.37 | 1361.84 |
|
| 61 |
+
| Nano-vLLM | 133,966 | 93.41 | 1434.13 |
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
## Star History
|
| 65 |
+
|
| 66 |
+
[Star History Chart](https://www.star-history.com/#GeeeekExplorer/nano-vllm&Date)
|
acestep/third_parts/nano-vllm/bench.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
from random import randint, seed
|
| 4 |
+
from nanovllm import LLM, SamplingParams
|
| 5 |
+
# from vllm import LLM, SamplingParams
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def main():
    """Run a fixed-seed throughput benchmark and print tokens per second.

    Generates 256 random token-id prompts and per-request sampling settings,
    performs one warm-up generation, then times a single batched
    ``llm.generate`` call. Throughput is computed from the requested
    ``max_tokens`` of every request (``ignore_eos=True`` is set on each).
    """
    seed(0)
    num_seqs = 256
    max_input_len = 1024
    max_output_len = 1024

    model_dir = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
    llm = LLM(model_dir, enforce_eager=False, max_model_len=4096)

    # Draw each prompt's length first, then its token ids, so the sequence of
    # random draws stays the same for a given seed.
    prompt_token_ids = []
    for _ in range(num_seqs):
        prompt_len = randint(100, max_input_len)
        prompt_token_ids.append([randint(0, 10000) for _ in range(prompt_len)])

    sampling_params = [
        SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_output_len))
        for _ in range(num_seqs)
    ]
    # uncomment the following line for vllm
    # prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]

    # Warm-up request, excluded from the timed section.
    llm.generate(["Benchmark: "], SamplingParams())

    started = time.time()
    llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
    elapsed = time.time() - started

    total_tokens = sum(sp.max_tokens for sp in sampling_params)
    throughput = total_tokens / elapsed
    print(f"Total: {total_tokens}tok, Time: {elapsed:.2f}s, Throughput: {throughput:.2f}tok/s")


if __name__ == "__main__":
    main()
|