ACE-Step Custom commited on
Commit
a602628
·
0 Parent(s):

Deploy ACE-Step Custom Edition with bug fixes

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. .gitignore +67 -0
  3. .python-version +1 -0
  4. DEPLOYMENT.md +367 -0
  5. DEPLOYMENT_CHECKLIST.txt +226 -0
  6. DEPLOY_QUICK.md +126 -0
  7. Dockerfile +35 -0
  8. LICENSE +28 -0
  9. QUICKSTART.md +115 -0
  10. README.md +73 -0
  11. README_HF.md +73 -0
  12. README_PROJECT.md +116 -0
  13. acestep/__init__.py +1 -0
  14. acestep/acestep_v15_pipeline.py +411 -0
  15. acestep/api_server.py +0 -0
  16. acestep/audio_utils.py +354 -0
  17. acestep/constants.py +193 -0
  18. acestep/constrained_logits_processor.py +0 -0
  19. acestep/dataset_handler.py +83 -0
  20. acestep/debug_utils.py +122 -0
  21. acestep/dit_alignment_score.py +877 -0
  22. acestep/genres_vocab.txt +0 -0
  23. acestep/gpu_config.py +549 -0
  24. acestep/gradio_ui/__init__.py +1 -0
  25. acestep/gradio_ui/api_routes.py +564 -0
  26. acestep/gradio_ui/events/__init__.py +1254 -0
  27. acestep/gradio_ui/events/generation_handlers.py +1050 -0
  28. acestep/gradio_ui/events/results_handlers.py +0 -0
  29. acestep/gradio_ui/events/training_handlers.py +829 -0
  30. acestep/gradio_ui/i18n.py +152 -0
  31. acestep/gradio_ui/i18n/en.json +354 -0
  32. acestep/gradio_ui/i18n/he.json +352 -0
  33. acestep/gradio_ui/i18n/ja.json +354 -0
  34. acestep/gradio_ui/i18n/zh.json +350 -0
  35. acestep/gradio_ui/interfaces/__init__.py +94 -0
  36. acestep/gradio_ui/interfaces/dataset.py +101 -0
  37. acestep/gradio_ui/interfaces/generation.py +824 -0
  38. acestep/gradio_ui/interfaces/result.py +552 -0
  39. acestep/gradio_ui/interfaces/training.py +625 -0
  40. acestep/handler.py +0 -0
  41. acestep/inference.py +1310 -0
  42. acestep/llm_inference.py +0 -0
  43. acestep/local_cache.py +129 -0
  44. acestep/model_downloader.py +634 -0
  45. acestep/openrouter_adapter.py +773 -0
  46. acestep/openrouter_models.py +244 -0
  47. acestep/test_time_scaling.py +410 -0
  48. acestep/third_parts/nano-vllm/LICENSE +21 -0
  49. acestep/third_parts/nano-vllm/README.md +66 -0
  50. acestep/third_parts/nano-vllm/bench.py +32 -0
.gitattributes ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ *.png filter=lfs diff=lfs merge=lfs -text
2
+ *.jpg filter=lfs diff=lfs merge=lfs -text
3
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual environments
24
+ venv/
25
+ env/
26
+ ENV/
27
+
28
+ # IDE
29
+ .vscode/
30
+ .idea/
31
+ *.swp
32
+ *.swo
33
+ *~
34
+
35
+ # OS
36
+ .DS_Store
37
+ Thumbs.db
38
+
39
+ # Application
40
+ outputs/
41
+ timelines/
42
+ lora_training/prepared_data/
43
+ lora_training/models/
44
+ logs/
45
+ models/
46
+ *.wav
47
+ *.mp3
48
+ *.flac
49
+ *.ogg
50
+
51
+ # Config (keep example)
52
+ config.yaml
53
+
54
+ # Jupyter
55
+ .ipynb_checkpoints/
56
+ *.ipynb
57
+
58
+ # Model cache
59
+ .cache/
60
+ huggingface/
61
+
62
+ # Environment
63
+ .env
64
+ .env.local
65
+
66
+ # Test outputs
67
+ test_outputs/
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ python-3.11.0
DEPLOYMENT.md ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HuggingFace Spaces Deployment Guide
2
+
3
+ ## Quick Deploy to HuggingFace Spaces
4
+
5
+ ### Prerequisites
6
+ - HuggingFace account (create at https://huggingface.co/join)
7
+ - Git installed on your machine
8
+ - Git LFS installed (for large files)
9
+
10
+ ### Method 1: Web Upload (Easiest)
11
+
12
+ 1. **Create New Space**
13
+ - Go to https://huggingface.co/new-space
14
+ - Name: `ace-step-custom` (or your choice)
15
+ - License: MIT
16
+ - SDK: Gradio
17
+ - Hardware: A10G Small (or better)
18
+ - Click "Create Space"
19
+
20
+ 2. **Upload Files**
21
+ - Click "Files and versions" tab
22
+ - Click "Add file" → "Upload files"
23
+ - Upload all files from `d:\2025-vibe-coding\ACE-Step-Custom\`:
24
+ - `app.py`
25
+ - `requirements.txt`
26
+ - `config.yaml`
27
+ - `README.md` (with YAML frontmatter)
28
+ - `LICENSE`
29
+ - `.gitignore`
30
+ - Entire `src/` directory
31
+ - Entire `scripts/` directory
32
+ - Commit changes
33
+
34
+ 3. **Configure Space**
35
+ - Go to "Settings" tab
36
+ - Set Python version: 3.10
37
+ - Enable GPU: A10G Small (minimum) or A100 (recommended)
38
+ - Set timeout: 30 minutes (for long generations)
39
+ - Save settings
40
+
41
+ 4. **Wait for Build**
42
+ - Space will automatically build and deploy
43
+ - First build takes 5-10 minutes
44
+ - Model will download on first run (~7GB)
45
+
46
+ ### Method 2: Git Push (For Developers)
47
+
48
+ 1. **Create Space on HuggingFace**
49
+ - Go to https://huggingface.co/new-space
50
+ - Create space as above
51
+
52
+ 2. **Clone and Push**
53
+ ```powershell
54
+ # Navigate to project
55
+ cd d:\2025-vibe-coding\ACE-Step-Custom
56
+
57
+ # Initialize git (if not already)
58
+ git init
59
+ git add .
60
+ git commit -m "Initial commit"
61
+
62
+ # Add HuggingFace remote
63
+ git remote add hf https://huggingface.co/spaces/YOUR_USERNAME/ace-step-custom
64
+
65
+ # Push to HuggingFace
66
+ git push hf main
67
+ ```
68
+
69
+ 3. **Configure Git LFS for Large Files**
70
+ ```powershell
71
+ git lfs install
72
+ git lfs track "*.wav"
73
+ git lfs track "*.pth"
74
+ git lfs track "*.bin"
75
+ git lfs track "models/**"
76
+ git add .gitattributes
77
+ git commit -m "Add LFS tracking"
78
+ git push hf main
79
+ ```
80
+
81
+ ### Method 3: HuggingFace CLI (Fastest)
82
+
83
+ 1. **Install HuggingFace CLI**
84
+ ```powershell
85
+ pip install huggingface_hub
86
+ ```
87
+
88
+ 2. **Login**
89
+ ```powershell
90
+ huggingface-cli login
91
+ # Enter your HuggingFace token
92
+ ```
93
+
94
+ 3. **Create and Upload**
95
+ ```powershell
96
+ cd d:\2025-vibe-coding\ACE-Step-Custom
97
+
98
+ # Create space
99
+ huggingface-cli repo create ace-step-custom --type space --space_sdk gradio
100
+
101
+ # Upload files
102
+ huggingface-cli upload YOUR_USERNAME/ace-step-custom . --repo-type space
103
+ ```
104
+
105
+ ## Space Configuration
106
+
107
+ ### Hardware Recommendations
108
+
109
+ | GPU | VRAM | Cost | Performance | Recommended For |
110
+ |-----|------|------|-------------|-----------------|
111
+ | CPU | - | Free | Very Slow | Testing only |
112
+ | T4 Small | 16GB | ~$0.60/hr | Slow | Light testing |
113
+ | **A10G Small** | **24GB** | **~$1.05/hr** | **Good** | **Recommended** |
114
+ | A10G Large | 24GB | ~$3.15/hr | Good | Production |
115
+ | A100 Large | 40GB | ~$4.13/hr | Excellent | Best quality |
116
+
117
+ **Recommendation:** Start with A10G Small for testing, upgrade to A100 for production.
118
+
119
+ ### Environment Variables (Optional)
120
+
121
+ In Space settings, you can add:
122
+
123
+ ```
124
+ GRADIO_SERVER_NAME=0.0.0.0
125
+ GRADIO_SERVER_PORT=7860
126
+ HF_HOME=/data/huggingface
127
+ TORCH_HOME=/data/torch
128
+ ```
129
+
130
+ ### Secrets (If Needed)
131
+
132
+ For API keys or sensitive data:
133
+ - Go to Space Settings → Repository secrets
134
+ - Add secrets like `HF_TOKEN`, `API_KEY`, etc.
135
+ - Access in code: `os.environ.get("SECRET_NAME")`
136
+
137
+ ## Post-Deployment Setup
138
+
139
+ ### First Launch
140
+
141
+ 1. **Wait for Model Download**
142
+ - First launch downloads ACE-Step model (~7GB)
143
+ - Takes 5-10 minutes depending on connection
144
+ - Model cached for subsequent runs
145
+
146
+ 2. **Test Basic Generation**
147
+ - Go to Tab 1 (Standard ACE-Step)
148
+ - Enter simple prompt: "Happy pop song"
149
+ - Set duration to 10 seconds
150
+ - Click Generate
151
+
152
+ 3. **Test Timeline**
153
+ - Go to Tab 2 (Timeline Workflow)
154
+ - Enter prompt and lyrics
155
+ - Set context length to 30s
156
+ - Generate first clip
157
+
158
+ 4. **Test LoRA Training**
159
+ - Go to Tab 3 (LoRA Training)
160
+ - Upload 2-3 test audio files
161
+ - Run quick training (2-3 epochs)
162
+
163
+ ### Monitoring
164
+
165
+ **View Logs:**
166
+ - Click "Logs" tab in your Space
167
+ - Monitor for errors or warnings
168
+ - Check GPU usage and memory
169
+
170
+ **Performance Metrics:**
171
+ - Generation time
172
+ - Memory usage
173
+ - Error rate
174
+ - User feedback
175
+
176
+ ### Troubleshooting
177
+
178
+ **Space Not Building:**
179
+ - Check requirements.txt for conflicts
180
+ - Verify Python 3.10 compatibility
181
+ - Check logs for specific errors
182
+
183
+ **Out of Memory:**
184
+ - Upgrade to larger GPU
185
+ - Reduce batch size in LoRA training
186
+ - Limit generation duration
187
+
188
+ **Model Not Loading:**
189
+ - Check HuggingFace Hub access
190
+ - Verify model ID in config.yaml
191
+ - Check internet connectivity
192
+
193
+ **Slow Performance:**
194
+ - Upgrade GPU tier
195
+ - Reduce concurrent users
196
+ - Optimize generation parameters
197
+
198
+ ## Optimization Tips
199
+
200
+ ### Reduce Startup Time
201
+
202
+ 1. **Cache Models**
203
+ ```python
204
+ # In app.py, add before model loading:
205
+ os.environ["HF_HOME"] = "/data/huggingface"
206
+ ```
207
+
208
+ 2. **Preload on Startup**
209
+ - Models download on first run
210
+ - Cached for subsequent uses
211
+ - Consider pre-downloading to Space
212
+
213
+ ### Improve Response Time
214
+
215
+ 1. **Use Queuing**
216
+ - Gradio automatically queues requests
217
+ - Set `max_size` in `app.launch()`
218
+
219
+ 2. **Optimize Generation**
220
+ - Lower default duration
221
+ - Reduce sampling steps
222
+ - Use FP16 precision
223
+
224
+ ### Cost Optimization
225
+
226
+ 1. **Auto-Sleep**
227
+ - Space sleeps after inactivity
228
+ - Wakes on first request
229
+ - Configure in Space settings
230
+
231
+ 2. **Usage Limits**
232
+ - Set max concurrent users
233
+ - Limit generation duration
234
+ - Add rate limiting if needed
235
+
236
+ ## Going Live
237
+
238
+ ### Before Public Release
239
+
240
+ - [ ] Test all three tabs thoroughly
241
+ - [ ] Verify LoRA training works
242
+ - [ ] Test with different prompts and styles
243
+ - [ ] Check error handling
244
+ - [ ] Review logs for issues
245
+ - [ ] Test on mobile devices
246
+ - [ ] Add usage examples
247
+ - [ ] Create demo video
248
+
249
+ ### Public Space Settings
250
+
251
+ 1. **Enable Discussions**
252
+ - Let users report issues
253
+ - Gather feedback
254
+
255
+ 2. **Add Examples**
256
+ - Create example prompts
257
+ - Show best practices
258
+ - Include sample outputs
259
+
260
+ 3. **Update README**
261
+ - Clear usage instructions
262
+ - Feature highlights
263
+ - Limitations and known issues
264
+
265
+ 4. **Pin Space**
266
+ - Makes it discoverable
267
+ - Shows on your profile
268
+
269
+ ## Maintenance
270
+
271
+ ### Regular Updates
272
+
273
+ ```powershell
274
+ # Update code
275
+ cd d:\2025-vibe-coding\ACE-Step-Custom
276
+ git add .
277
+ git commit -m "Update description"
278
+ git push hf main
279
+ ```
280
+
281
+ ### Monitor Usage
282
+
283
+ - Check Space analytics
284
+ - Review user feedback
285
+ - Monitor error rates
286
+ - Track popular features
287
+
288
+ ### Scaling
289
+
290
+ **If Space Gets Popular:**
291
+ 1. Upgrade GPU tier
292
+ 2. Add request queuing
293
+ 3. Implement caching
294
+ 4. Consider duplicate Spaces for load balancing
295
+
296
+ ## Support & Community
297
+
298
+ ### Get Help
299
+
300
+ - HuggingFace Forums: https://discuss.huggingface.co/
301
+ - Discord: https://discord.gg/huggingface
302
+ - Docs: https://huggingface.co/docs/hub/spaces
303
+
304
+ ### Share Your Space
305
+
306
+ - Post on Twitter/X with #HuggingFace #ACEStep
307
+ - Share in AI music communities
308
+ - Add to your portfolio
309
+ - Write blog post about it
310
+
311
+ ## Advanced Configuration
312
+
313
+ ### Custom Domain (Pro)
314
+
315
+ HuggingFace Pro users can set custom domains:
316
+ 1. Go to Space settings
317
+ 2. Add custom domain
318
+ 3. Configure DNS
319
+
320
+ ### Persistent Storage
321
+
322
+ For saving user data:
323
+ ```python
324
+ import os
325
+ PERSIST_DIR = os.environ.get("SPACE_ID", "local")
326
+ # Save to /data/{SPACE_ID}/
327
+ ```
328
+
329
+ ### Analytics Integration
330
+
331
+ Add Google Analytics or similar:
332
+ ```python
333
+ # In app.py
334
+ analytics_code = """
335
+ <script async src="https://www.googletagmanager.com/gtag/js?id=YOUR-ID"></script>
336
+ <script>
337
+ window.dataLayer = window.dataLayer || [];
338
+ function gtag(){dataLayer.push(arguments);}
339
+ gtag('js', new Date());
340
+ gtag('config', 'YOUR-ID');
341
+ </script>
342
+ """
343
+ ```
344
+
345
+ ## Success Checklist
346
+
347
+ Before announcing your Space:
348
+
349
+ - ✅ All features working
350
+ - ✅ Clear documentation
351
+ - ✅ Example outputs included
352
+ - ✅ Error handling robust
353
+ - ✅ Performance optimized
354
+ - ✅ Mobile-friendly UI
355
+ - ✅ Clear limitations stated
356
+ - ✅ License properly attributed
357
+ - ✅ Usage guidelines clear
358
+ - ✅ Contact/support info provided
359
+
360
+ ## Your Space URL
361
+
362
+ After deployment, your Space will be available at:
363
+ ```
364
+ https://huggingface.co/spaces/YOUR_USERNAME/ace-step-custom
365
+ ```
366
+
367
+ Share it with the world! 🎵🚀
DEPLOYMENT_CHECKLIST.txt ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ⚠️ IMPORTANT: Follow these steps in order ⚠️
2
+
3
+ ═══════════════════════════════════════════════════════════════
4
+ 🚀 HuggingFace Spaces Deployment - Step by Step
5
+ ═══════════════════════════════════════════════════════════════
6
+
7
+ 📋 PREREQUISITES
8
+ ═══════════════════════════════════════════════════════════════
9
+ ☐ HuggingFace account created: https://huggingface.co/join
10
+ ☐ HuggingFace token obtained: https://huggingface.co/settings/tokens
11
+ (Create new token with "write" access)
12
+ ☐ HuggingFace CLI installed (already done ✓)
13
+
14
+ ═══════════════════════════════════════════════════════════════
15
+ 🎯 DEPLOYMENT STEPS
16
+ ═══════════════════════════════════════════════════════════════
17
+
18
+ Choose ONE method below:
19
+
20
+ ┌─────────────────────────────────────────────────────────────┐
21
+ │ METHOD 1: AUTOMATED SCRIPT (EASIEST) ⭐ │
22
+ └─────────────────────────────────────────────────────────────┘
23
+
24
+ 1. Open PowerShell in this directory:
25
+ d:\2025-vibe-coding\ACE-Step-Custom
26
+
27
+ 2. Run the deployment script:
28
+ .\deploy_hf.bat
29
+
30
+ 3. Follow the prompts:
31
+ - Login with your HF token
32
+ - Enter Space name (e.g., "ace-step-custom")
33
+ - Wait for upload
34
+
35
+ ┌─────────────────────────────────────────────────────────────┐
36
+ │ METHOD 2: MANUAL CLI (FOR DEVELOPERS) │
37
+ └─────────────────────────────────────────────────────────────┘
38
+
39
+ 1. Login to HuggingFace:
40
+ huggingface-cli login
41
+ [Paste your token]
42
+
43
+ 2. Create the Space:
44
+ huggingface-cli repo create ace-step-custom --type space --space_sdk gradio
45
+
46
+ 3. Upload files:
47
+ huggingface-cli upload YOUR_USERNAME/ace-step-custom . --repo-type space
48
+
49
+ ┌─────────────────────────────────────────────────────────────┐
50
+ │ METHOD 3: WEB INTERFACE (NO CLI NEEDED) │
51
+ └─────────────────────────────────────────────────────────────┘
52
+
53
+ 1. Go to: https://huggingface.co/new-space
54
+
55
+ 2. Fill in Space details:
56
+ Name: ace-step-custom
57
+ License: MIT
58
+ SDK: Gradio
59
+ Hardware: A10G Small
60
+
61
+ 3. Click "Create Space"
62
+
63
+ 4. Click "Files and versions" → "Add file" → "Upload files"
64
+
65
+ 5. Upload these files/folders:
66
+ ✓ app.py
67
+ ✓ requirements.txt
68
+ ✓ config.yaml
69
+ ✓ README.md
70
+ ✓ LICENSE
71
+ ✓ .gitignore
72
+ ✓ src/ (entire folder)
73
+ ✓ scripts/ (entire folder)
74
+
75
+ 6. Commit changes
76
+
77
+ ═══════════════════════════════════════════════════════════════
78
+ ⚙️ POST-DEPLOYMENT CONFIGURATION
79
+ ═══════════════════════════════════════════════════════════════
80
+
81
+ After upload, configure your Space:
82
+
83
+ 1. ☐ Go to your Space URL:
84
+ https://huggingface.co/spaces/YOUR_USERNAME/ace-step-custom
85
+
86
+ 2. ☐ Click "Settings" tab
87
+
88
+ 3. ☐ Configure Hardware:
89
+ - Select: "A10G Small" (24GB VRAM) - MINIMUM
90
+ - Or: "A100 Large" (40GB VRAM) - RECOMMENDED
91
+ - Click "Save"
92
+
93
+ 4. ☐ Set Python version: 3.10 (should be automatic)
94
+
95
+ 5. ☐ Set timeout: 30 minutes (optional, for long generations)
96
+
97
+ 6. ☐ Enable Discussions (optional, for user feedback)
98
+
99
+ ═══════════════════════════════════════════════════════════════
100
+ ⏱️ BUILD & TESTING
101
+ ═══════════════════════════════════════════════════════════════
102
+
103
+ 1. ☐ Wait for build to complete:
104
+ - Click "Logs" tab to monitor
105
+ - First build: 5-10 minutes
106
+ - Model download: ~7GB (first run only)
107
+
108
+ 2. ☐ Space will show "Running" when ready
109
+
110
+ 3. ☐ Test Tab 1 (Standard ACE-Step):
111
+ - Enter prompt: "Happy pop song with piano"
112
+ - Set duration: 10 seconds
113
+ - Click "Generate"
114
+ - Verify audio plays
115
+
116
+ 4. ☐ Test Tab 2 (Timeline Workflow):
117
+ - Enter prompt and lyrics
118
+ - Set context length: 30 seconds
119
+ - Click "Generate Clip"
120
+ - Verify timeline updates
121
+
122
+ 5. ☐ Test Tab 3 (LoRA Training):
123
+ - Upload 2-3 test audio files
124
+ - Set epochs to 2
125
+ - Click "Start Training"
126
+ - Verify progress updates
127
+
128
+ ═══════════════════════════════════════════════════════════════
129
+ 💰 COST MANAGEMENT
130
+ ═══════════════════════════════════════════════════════════════
131
+
132
+ GPU Costs:
133
+ - A10G Small (24GB): ~$1.05/hour ⭐ RECOMMENDED
134
+ - A100 Large (40GB): ~$4.13/hour
135
+
136
+ Auto-Sleep:
137
+ ✓ Space sleeps automatically after 48 hours of inactivity
138
+ ✓ Wakes up on first request (30-60 second startup)
139
+ ✓ No charges while sleeping
140
+
141
+ Testing Budget:
142
+ - Initial testing: ~$5-10
143
+ - Active use: ~$10-50/month
144
+ - Production: Scale as needed
145
+
146
+ ═══════════════════════════════════════════════════════════════
147
+ 🐛 TROUBLESHOOTING
148
+ ═══════════════════════════════════════════════════════════════
149
+
150
+ Problem: Space won't start
151
+ Solution:
152
+ - Check "Logs" tab for errors
153
+ - Verify all files uploaded
154
+ - Ensure README.md has YAML frontmatter
155
+
156
+ Problem: Out of memory error
157
+ Solution:
158
+ - Upgrade to A100 Large
159
+ - Reduce generation duration in UI
160
+ - Lower batch size in LoRA training
161
+
162
+ Problem: Slow generation
163
+ Solution:
164
+ - Verify GPU is enabled (not CPU)
165
+ - Check Space isn't sleeping
166
+ - Reduce sampling steps in config
167
+
168
+ Problem: Model download fails
169
+ Solution:
170
+ - Check HuggingFace Hub status
171
+ - Verify internet connectivity
172
+ - Wait and retry
173
+
174
+ ═══════════════════════════════════════════════════════════════
175
+ ✅ SUCCESS CHECKLIST
176
+ ═══════════════════════════════════════════════════════════════
177
+
178
+ Before announcing your Space:
179
+
180
+ ☐ All three tabs tested and working
181
+ ☐ Example generations added to README
182
+ ☐ Clear usage instructions visible
183
+ ☐ GPU enabled (A10G Small minimum)
184
+ ☐ Error handling tested
185
+ ☐ Mobile view checked
186
+ ☐ Discussions enabled
187
+ ☐ License properly displayed
188
+ ☐ Contact/support info added
189
+ ☐ Share link works
190
+
191
+ ═══════════════════════════════════════════════════════════════
192
+ 🎉 GO LIVE!
193
+ ═══════════════════════════════════════════════════════════════
194
+
195
+ Your Space URL:
196
+ https://huggingface.co/spaces/YOUR_USERNAME/ace-step-custom
197
+
198
+ Share it:
199
+ □ Twitter/X: "Just deployed ACE-Step 1.5 Custom on @huggingface! 🎵
200
+ Check it out: [your-url] #AIMusic #HuggingFace #ACEStep"
201
+ □ LinkedIn post
202
+ □ Reddit (r/MachineLearning, r/artificial, r/WeAreTheMusicMakers)
203
+ □ Discord communities
204
+ □ Personal blog/portfolio
205
+
206
+ ═══════════════════════════════════════════════════════════════
207
+ 📚 ADDITIONAL RESOURCES
208
+ ═══════════════════════════════════════════════════════════════
209
+
210
+ Documentation:
211
+ - DEPLOY_QUICK.md - Quick reference
212
+ - DEPLOYMENT.md - Complete guide
213
+ - README.md - Project documentation
214
+
215
+ Support:
216
+ - HuggingFace Docs: https://huggingface.co/docs/hub/spaces
217
+ - HuggingFace Discord: https://discord.gg/huggingface
218
+ - GitHub Issues: [your-repo-url]
219
+
220
+ ═══════════════════════════════════════════════════════════════
221
+
222
+ Ready to deploy? 🚀
223
+
224
+ Run: .\deploy_hf.bat
225
+
226
+ ═══════════════════════════════════════════════════════════════
DEPLOY_QUICK.md ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Quick Deployment to HuggingFace Spaces
2
+
3
+ ## Prerequisites
4
+ ✅ HuggingFace account: https://huggingface.co/join
5
+ ✅ HuggingFace token: https://huggingface.co/settings/tokens
6
+
7
+ ## Fastest Method (Windows)
8
+
9
+ Run the deployment script:
10
+
11
+ ```powershell
12
+ cd d:\2025-vibe-coding\ACE-Step-Custom
13
+ .\deploy_hf.bat
14
+ ```
15
+
16
+ The script will:
17
+ 1. Install HuggingFace CLI (if needed)
18
+ 2. Login to your account
19
+ 3. Create new Space
20
+ 4. Upload all files
21
+ 5. Provide your Space URL
22
+
23
+ ## Fastest Method (Linux/Mac)
24
+
25
+ ```bash
26
+ cd /path/to/ACE-Step-Custom
27
+ chmod +x deploy_hf.sh
28
+ ./deploy_hf.sh
29
+ ```
30
+
31
+ ## Manual Deployment (If Script Fails)
32
+
33
+ ### 1. Install HuggingFace CLI
34
+ ```powershell
35
+ pip install huggingface_hub
36
+ ```
37
+
38
+ ### 2. Login
39
+ ```powershell
40
+ huggingface-cli login
41
+ ```
42
+ Enter your token from: https://huggingface.co/settings/tokens
43
+
44
+ ### 3. Create Space
45
+ ```powershell
46
+ huggingface-cli repo create ace-step-custom --type space --space_sdk gradio
47
+ ```
48
+
49
+ ### 4. Upload Files
50
+ ```powershell
51
+ cd d:\2025-vibe-coding\ACE-Step-Custom
52
+ huggingface-cli upload YOUR_USERNAME/ace-step-custom . --repo-type space
53
+ ```
54
+
55
+ Replace `YOUR_USERNAME` with your HuggingFace username.
56
+
57
+ ## After Upload
58
+
59
+ ### 1. Configure GPU
60
+ - Go to your Space: https://huggingface.co/spaces/YOUR_USERNAME/ace-step-custom
61
+ - Click "Settings" tab
62
+ - Under "Hardware", select: **A10G Small** (recommended)
63
+ - Click "Save"
64
+
65
+ ### 2. Wait for Build
66
+ - Space will build automatically (5-10 minutes)
67
+ - Check "Logs" tab for progress
68
+ - Model downloads on first run (~7GB)
69
+
70
+ ### 3. Test Your Space
71
+ 1. Open Space URL
72
+ 2. Test Tab 1: Generate 10-second clip
73
+ 3. Test Tab 2: Generate timeline clip
74
+ 4. Test Tab 3: Upload test audio
75
+
76
+ ## Troubleshooting
77
+
78
+ **Login Failed:**
79
+ ```powershell
80
+ # Make sure you copied the full token
81
+ huggingface-cli whoami # Check if logged in
82
+ ```
83
+
84
+ **Upload Failed:**
85
+ ```powershell
86
+ # Try with explicit exclusions
87
+ huggingface-cli upload YOUR_USERNAME/ace-step-custom . --repo-type space --exclude "*.pyc" --exclude "outputs/*" --exclude "__pycache__/*"
88
+ ```
89
+
90
+ **Space Not Starting:**
91
+ - Check "Logs" tab for errors
92
+ - Verify requirements.txt is uploaded
93
+ - Ensure README.md has correct YAML frontmatter
94
+
95
+ **Out of Memory:**
96
+ - Upgrade GPU in Settings
97
+ - Start with A10G Small minimum
98
+
99
+ ## Your Space URL
100
+
101
+ After deployment:
102
+ ```
103
+ https://huggingface.co/spaces/YOUR_USERNAME/ace-step-custom
104
+ ```
105
+
106
+ ## Cost Estimate
107
+
108
+ - **A10G Small (24GB):** ~$1.05/hour
109
+ - **Auto-sleep:** Space sleeps when inactive (no charge)
110
+ - **Testing:** Budget ~$5-10 for initial testing
111
+
112
+ ## Need Help?
113
+
114
+ See full guide: [DEPLOYMENT.md](DEPLOYMENT.md)
115
+
116
+ ## Next Steps
117
+
118
+ 1. ✅ Deploy Space
119
+ 2. ✅ Test all features
120
+ 3. ✅ Enable Discussions in Settings
121
+ 4. ✅ Add example outputs to README
122
+ 5. ✅ Share your Space!
123
+
124
+ ---
125
+
126
+ 🎵 Happy testing! Your Space will be live in minutes! 🚀
Dockerfile ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
2
+
3
+ # Set working directory
4
+ WORKDIR /app
5
+
6
+ # Install system dependencies
7
+ RUN apt-get update && apt-get install -y \
8
+ python3.10 \
9
+ python3-pip \
10
+ git \
11
+ ffmpeg \
12
+ libsndfile1 \
13
+ && rm -rf /var/lib/apt/lists/*
14
+
15
+ # Copy requirements
16
+ COPY requirements.txt .
17
+
18
+ # Install Python dependencies
19
+ RUN pip3 install --no-cache-dir -r requirements.txt
20
+
21
+ # Copy application files
22
+ COPY . .
23
+
24
+ # Create necessary directories
25
+ RUN mkdir -p outputs timelines lora_training logs models
26
+
27
+ # Expose Gradio port
28
+ EXPOSE 7860
29
+
30
+ # Set environment variables
31
+ ENV GRADIO_SERVER_NAME="0.0.0.0"
32
+ ENV GRADIO_SERVER_PORT=7860
33
+
34
+ # Run the application
35
+ CMD ["python3", "app.py"]
LICENSE ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Gamahea Development Team
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
22
+
23
+ ---
24
+
25
+ This project uses ACE-Step, which is subject to its own license:
26
+ https://github.com/ace-step/ACE-Step
27
+
28
+ Please refer to the original ACE-Step repository for their licensing terms.
QUICKSTART.md ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ACE-Step 1.5 Custom Edition - Quick Start Guide
2
+
3
+ ## Installation
4
+
5
+ ### Option 1: Local Setup
6
+
7
+ 1. **Clone the repository**
8
+ ```bash
9
+ git clone https://github.com/yourusername/ace-step-custom.git
10
+ cd ace-step-custom
11
+ ```
12
+
13
+ 2. **Create virtual environment**
14
+ ```bash
15
+ python -m venv venv
16
+
17
+ # On Windows:
18
+ venv\Scripts\activate
19
+
20
+ # On Linux/Mac:
21
+ source venv/bin/activate
22
+ ```
23
+
24
+ 3. **Run setup**
25
+ ```bash
26
+ python scripts/setup.py
27
+ ```
28
+
29
+ 4. **Download model**
30
+ ```bash
31
+ python scripts/download_model.py
32
+ ```
33
+
34
+ 5. **Launch application**
35
+ ```bash
36
+ python app.py
37
+ ```
38
+
39
+ 6. **Open browser to** `http://localhost:7860`
40
+
41
+ ### Option 2: HuggingFace Spaces
42
+
43
+ 1. Create new Space on HuggingFace
44
+ 2. Upload all project files
45
+ 3. Set Space configuration:
46
+ - SDK: `gradio`
47
+ - Python: `3.10`
48
+ - GPU: `A10G` (or better)
49
+ 4. Space will auto-deploy
50
+
51
+ ## Usage
52
+
53
+ ### Tab 1: Standard ACE-Step
54
+
55
+ Standard interface with all original ACE-Step features:
56
+ - Text-to-music generation
57
+ - Variation generation
58
+ - Repainting sections
59
+ - Lyric editing
60
+
61
+ ### Tab 2: Timeline Workflow
62
+
63
+ Advanced timeline-based generation:
64
+ 1. Enter prompt and lyrics
65
+ 2. Set context length (0-120s)
66
+ 3. Click "Generate" for 32s clips
67
+ 4. Clips auto-blend into timeline
68
+ 5. Use "Extend" to continue
69
+ 6. Use "Inpaint" to edit regions
70
+
71
+ ### Tab 3: LoRA Training
72
+
73
+ Train custom models:
74
+ 1. Upload audio files (10+ recommended)
75
+ 2. Set training parameters
76
+ 3. Click "Start Training"
77
+ 4. Download trained model
78
+ 5. Use in Tab 1 or Tab 2
79
+
80
+ ## Tips
81
+
82
+ - **First time:** Start with Standard tab to understand basics
83
+ - **For longer songs:** Use Timeline tab with context length 30-60s
84
+ - **For custom styles:** Train LoRA with 20+ similar audio files
85
+ - **GPU recommended:** 8GB+ VRAM for best performance
86
+ - **CPU mode:** Works but slower, use shorter durations
87
+
88
+ ## Troubleshooting
89
+
90
+ ### Out of Memory
91
+ - Reduce batch size in LoRA training
92
+ - Use shorter audio durations
93
+ - Close other GPU applications
94
+
95
+ ### Poor Quality
96
+ - Increase context length
97
+ - Try different seeds
98
+ - Adjust temperature (0.6-0.8 is usually good)
99
+
100
+ ### Blend Artifacts
101
+ - Reduce lead-in/lead-out durations
102
+ - Ensure consistent style across clips
103
+ - Use lower context length for more variety
104
+
105
+ ## Support
106
+
107
+ - GitHub Issues: [Report bugs here]
108
+ - Documentation: See `docs/` directory
109
+ - Examples: See `examples/` directory
110
+
111
+ ## Credits
112
+
113
+ Based on ACE-Step by ACE Studio and Step Fun
114
+ - Website: https://ace-step.github.io/
115
+ - Paper: https://arxiv.org/abs/2506.00045
README.md ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: ACE-Step 1.5 Custom Edition
3
+ emoji: 🎵
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 4.0.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ python_version: 3.11
12
+ ---
13
+
14
+ # ACE-Step 1.5 Custom Edition
15
+
16
+ A comprehensive music generation system built on ACE-Step 1.5, featuring:
17
+
18
+ ## 🌟 Features
19
+
20
+ ### 1. Standard ACE-Step Interface
21
+ Full-featured standard ACE-Step 1.5 GUI with all original capabilities including:
22
+ - Text-to-music generation with style control
23
+ - Variation generation
24
+ - Section repainting
25
+ - Lyric editing
26
+
27
+ ### 2. Custom Timeline Workflow
28
+ Advanced timeline-based generation system:
29
+ - Generate 32-second clips with seamless blending
30
+ - Adjustable context length (0-120 seconds) for style consistency
31
+ - Master timeline with visual representation
32
+ - Extend, inpaint, and remix capabilities
33
+ - Automatic crossfading between clips
34
+
35
+ ### 3. LoRA Training Studio
36
+ Complete training interface for custom models:
37
+ - Upload and preprocess audio files
38
+ - Configure training parameters
39
+ - Train specialized models for voices, instruments, or styles
40
+ - Download and reuse trained models
41
+ - Continue training from existing LoRAs
42
+
43
+ ## 🚀 Quick Start
44
+
45
+ 1. **Standard Generation**: Use Tab 1 for traditional text-to-music
46
+ 2. **Timeline Creation**: Use Tab 2 to build longer songs with consistent style
47
+ 3. **Custom Training**: Use Tab 3 to create specialized models
48
+
49
+ ## 💡 Tips
50
+
51
+ - Start with context length of 30-60s for best results
52
+ - For custom voices, train LoRA with 20+ audio samples
53
+ - Adjust temperature between 0.6-0.8 for quality vs creativity
54
+ - Use "Extend" in Timeline mode to continue your song
55
+
56
+ ## 🎯 Use Cases
57
+
58
+ - **Musicians**: Create backing tracks and song ideas
59
+ - **Content Creators**: Generate royalty-free music for videos
60
+ - **Game Developers**: Create adaptive game soundtracks
61
+ - **AI Researchers**: Experiment with music generation and LoRA training
62
+
63
+ ## 📚 Documentation
64
+
65
+ See the repository for full documentation and examples.
66
+
67
+ ## 🙏 Credits
68
+
69
+ Built on top of [ACE-Step](https://ace-step.github.io/) by ACE Studio and Step Fun.
70
+
71
+ ## ⚠️ Note
72
+
73
+ This is a custom implementation focusing on enhanced workflows and training capabilities. Generation quality depends on the base ACE-Step model and your usage patterns.
README_HF.md ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: ACE-Step 1.5 Custom Edition
3
+ emoji: 🎵
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 4.0.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ python_version: 3.11
12
+ ---
13
+
14
+ # ACE-Step 1.5 Custom Edition
15
+
16
+ A comprehensive music generation system built on ACE-Step 1.5, featuring:
17
+
18
+ ## 🌟 Features
19
+
20
+ ### 1. Standard ACE-Step Interface
21
+ Full-featured standard ACE-Step 1.5 GUI with all original capabilities including:
22
+ - Text-to-music generation with style control
23
+ - Variation generation
24
+ - Section repainting
25
+ - Lyric editing
26
+
27
+ ### 2. Custom Timeline Workflow
28
+ Advanced timeline-based generation system:
29
+ - Generate 32-second clips with seamless blending
30
+ - Adjustable context length (0-120 seconds) for style consistency
31
+ - Master timeline with visual representation
32
+ - Extend, inpaint, and remix capabilities
33
+ - Automatic crossfading between clips
34
+
35
+ ### 3. LoRA Training Studio
36
+ Complete training interface for custom models:
37
+ - Upload and preprocess audio files
38
+ - Configure training parameters
39
+ - Train specialized models for voices, instruments, or styles
40
+ - Download and reuse trained models
41
+ - Continue training from existing LoRAs
42
+
43
+ ## 🚀 Quick Start
44
+
45
+ 1. **Standard Generation**: Use Tab 1 for traditional text-to-music
46
+ 2. **Timeline Creation**: Use Tab 2 to build longer songs with consistent style
47
+ 3. **Custom Training**: Use Tab 3 to create specialized models
48
+
49
+ ## 💡 Tips
50
+
51
+ - Start with context length of 30-60s for best results
52
+ - For custom voices, train LoRA with 20+ audio samples
53
+ - Adjust temperature between 0.6-0.8 for quality vs creativity
54
+ - Use "Extend" in Timeline mode to continue your song
55
+
56
+ ## 🎯 Use Cases
57
+
58
+ - **Musicians**: Create backing tracks and song ideas
59
+ - **Content Creators**: Generate royalty-free music for videos
60
+ - **Game Developers**: Create adaptive game soundtracks
61
+ - **AI Researchers**: Experiment with music generation and LoRA training
62
+
63
+ ## 📚 Documentation
64
+
65
+ See the repository for full documentation and examples.
66
+
67
+ ## 🙏 Credits
68
+
69
+ Built on top of [ACE-Step](https://ace-step.github.io/) by ACE Studio and Step Fun.
70
+
71
+ ## ⚠️ Note
72
+
73
+ This is a custom implementation focusing on enhanced workflows and training capabilities. Generation quality depends on the base ACE-Step model and your usage patterns.
README_PROJECT.md ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: ACE-Step 1.5 Custom Edition
3
+ emoji: 🎵
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 5.9.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ python_version: "3.11"
12
+ hardware: zero-gpu-medium
13
+ ---
14
+
15
+ # ACE-Step 1.5 Custom Edition
16
+
17
+ A fully-featured implementation of ACE-Step 1.5 with custom GUI and workflow capabilities for local use and HuggingFace Space deployment.
18
+
19
+ ## Features
20
+
21
+ ### 🎵 Three Main Interfaces
22
+
23
+ 1. **Standard ACE-Step GUI**: Full-featured standard ACE-Step 1.5 interface with all original capabilities
24
+ 2. **Custom Timeline Workflow**: Advanced timeline-based generation with:
25
+ - 32-second clip generation (2s lead-in + 28s main + 2s lead-out)
26
+ - Seamless clip blending for continuous music
27
+ - Context Length slider (0-120 seconds) for style guidance
28
+ - Master timeline with extend, inpaint, and remix capabilities
29
+ 3. **LoRA Training Studio**: Complete LoRA training interface with:
30
+ - Audio file upload and preprocessing
31
+ - Custom training configuration
32
+ - Model download/upload for continued training
33
+
34
+ ## Architecture
35
+
36
+ - **Base Model**: ACE-Step v1.5 Turbo
37
+ - **Framework**: Gradio 5.9.1, PyTorch
38
+ - **Deployment**: Local execution + HuggingFace Spaces
39
+ - **Audio Processing**: DiT + VAE + 5Hz Language Model
40
+
41
+ ## Installation
42
+
43
+ ### Local Setup
44
+
45
+ ```bash
46
+ # Clone the repository
47
+ git clone https://github.com/yourusername/ace-step-custom.git
48
+ cd ace-step-custom
49
+
50
+ # Create virtual environment
51
+ python -m venv venv
52
+ source venv/bin/activate # On Windows: venv\Scripts\activate
53
+
54
+ # Install dependencies
55
+ pip install -r requirements.txt
56
+
57
+ # Download ACE-Step model
58
+ python scripts/download_model.py
59
+
60
+ # Run the application
61
+ python app.py
62
+ ```
63
+
64
+ ### HuggingFace Space Deployment
65
+
66
+ 1. Create a new Space on HuggingFace
67
+ 2. Upload all files to the Space
68
+ 3. Set Space to use GPU (recommended: H200 or A100)
69
+ 4. The app will automatically download models and start
70
+
71
+ ## Usage
72
+
73
+ ### Standard Mode
74
+ Use the first tab for standard ACE-Step generation with all original features.
75
+
76
+ ### Timeline Mode
77
+ 1. Enter your prompt/lyrics
78
+ 2. Adjust Context Length (how far back to reference previous clips)
79
+ 3. Click "Generate" to create 32-second clips
80
+ 4. Clips automatically blend and add to timeline
81
+ 5. Use "Extend" to continue the song or other options for variations
82
+
83
+ ### LoRA Training
84
+ 1. Upload audio files for training
85
+ 2. Configure training parameters
86
+ 3. Train custom LoRA models
87
+ 4. Download and reuse for continued training
88
+
89
+ ## System Requirements
90
+
91
+ ### Minimum
92
+ - GPU: 8GB VRAM (with optimizations)
93
+ - RAM: 16GB
94
+ - Storage: 20GB
95
+
96
+ ### Recommended
97
+ - GPU: 16GB+ VRAM (A100, H200, or consumer GPUs)
98
+ - RAM: 32GB
99
+ - Storage: 50GB
100
+
101
+ ## Technical Details
102
+
103
+ - **Audio Format**: 48kHz, stereo
104
+ - **Generation Speed**: ~8 inference steps (turbo model)
105
+ - **Context Window**: Up to 120 seconds for style guidance
106
+ - **Blend Regions**: 2-second crossfade between clips
107
+
108
+ ## Credits
109
+
110
+ Based on ACE-Step 1.5 by ACE Studio
111
+ - GitHub: https://github.com/ace-step/ACE-Step-1.5
112
+ - Original Demo: https://huggingface.co/spaces/ACE-Step/ACE-Step
113
+
114
+ ## License
115
+
116
+ MIT License (see LICENSE file)
acestep/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """ACE-Step package."""
acestep/acestep_v15_pipeline.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ACE-Step V1.5 Pipeline
3
+ Handler wrapper connecting model and UI
4
+ """
5
+ import os
6
+ import sys
7
+
8
+ # Load environment variables from .env file at most once per process to avoid
9
+ # epoch-boundary stalls (e.g. on Windows when Gradio yields during training)
10
+ _env_loaded = False # module-level so we never reload .env in the same process
11
+ try:
12
+ from dotenv import load_dotenv
13
+ if not _env_loaded:
14
+ _current_file = os.path.abspath(__file__)
15
+ _project_root = os.path.dirname(os.path.dirname(_current_file))
16
+ _env_path = os.path.join(_project_root, '.env')
17
+ _env_example_path = os.path.join(_project_root, '.env.example')
18
+ if os.path.exists(_env_path):
19
+ load_dotenv(_env_path)
20
+ print(f"Loaded configuration from {_env_path}")
21
+ elif os.path.exists(_env_example_path):
22
+ load_dotenv(_env_example_path)
23
+ print(f"Loaded configuration from {_env_example_path} (fallback)")
24
+ _env_loaded = True
25
+ except ImportError:
26
+ # python-dotenv not installed, skip loading .env
27
+ pass
28
+
29
+ # Clear proxy settings that may affect Gradio
30
+ for proxy_var in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY']:
31
+ os.environ.pop(proxy_var, None)
32
+
33
+ try:
34
+ # When executed as a module: `python -m acestep.acestep_v15_pipeline`
35
+ from .handler import AceStepHandler
36
+ from .llm_inference import LLMHandler
37
+ from .dataset_handler import DatasetHandler
38
+ from .gradio_ui import create_gradio_interface
39
+ from .gpu_config import get_gpu_config, get_gpu_memory_gb, print_gpu_config_info, set_global_gpu_config, VRAM_16GB_MIN_GB
40
+ except ImportError:
41
+ # When executed as a script: `python acestep/acestep_v15_pipeline.py`
42
+ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
43
+ if project_root not in sys.path:
44
+ sys.path.insert(0, project_root)
45
+ from acestep.handler import AceStepHandler
46
+ from acestep.llm_inference import LLMHandler
47
+ from acestep.dataset_handler import DatasetHandler
48
+ from acestep.gradio_ui import create_gradio_interface
49
+ from acestep.gpu_config import get_gpu_config, get_gpu_memory_gb, print_gpu_config_info, set_global_gpu_config, VRAM_16GB_MIN_GB
50
+
51
+
52
+ def create_demo(init_params=None, language='en'):
53
+ """
54
+ Create Gradio demo interface
55
+
56
+ Args:
57
+ init_params: Dictionary containing initialization parameters and state.
58
+ If None, service will not be pre-initialized.
59
+ Keys: 'pre_initialized' (bool), 'checkpoint', 'config_path', 'device',
60
+ 'init_llm', 'lm_model_path', 'backend', 'use_flash_attention',
61
+ 'offload_to_cpu', 'offload_dit_to_cpu', 'init_status',
62
+ 'dit_handler', 'llm_handler' (initialized handlers if pre-initialized),
63
+ 'language' (UI language code)
64
+ language: UI language code ('en', 'zh', 'ja', default: 'en')
65
+
66
+ Returns:
67
+ Gradio Blocks instance
68
+ """
69
+ # Use pre-initialized handlers if available, otherwise create new ones
70
+ if init_params and init_params.get('pre_initialized') and 'dit_handler' in init_params:
71
+ dit_handler = init_params['dit_handler']
72
+ llm_handler = init_params['llm_handler']
73
+ else:
74
+ dit_handler = AceStepHandler() # DiT handler
75
+ llm_handler = LLMHandler() # LM handler
76
+
77
+ dataset_handler = DatasetHandler() # Dataset handler
78
+
79
+ # Create Gradio interface with all handlers and initialization parameters
80
+ demo = create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=init_params, language=language)
81
+
82
+ return demo
83
+
84
+
85
+ def main():
86
+ """Main entry function"""
87
+ import argparse
88
+
89
+ # Detect GPU memory and get configuration
90
+ gpu_config = get_gpu_config()
91
+ set_global_gpu_config(gpu_config) # Set global config for use across modules
92
+
93
+ gpu_memory_gb = gpu_config.gpu_memory_gb
94
+ auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < VRAM_16GB_MIN_GB
95
+
96
+ # Print GPU configuration info
97
+ print(f"\n{'='*60}")
98
+ print("GPU Configuration Detected:")
99
+ print(f"{'='*60}")
100
+ print(f" GPU Memory: {gpu_memory_gb:.2f} GB")
101
+ print(f" Configuration Tier: {gpu_config.tier}")
102
+ print(f" Max Duration (with LM): {gpu_config.max_duration_with_lm}s ({gpu_config.max_duration_with_lm // 60} min)")
103
+ print(f" Max Duration (without LM): {gpu_config.max_duration_without_lm}s ({gpu_config.max_duration_without_lm // 60} min)")
104
+ print(f" Max Batch Size (with LM): {gpu_config.max_batch_size_with_lm}")
105
+ print(f" Max Batch Size (without LM): {gpu_config.max_batch_size_without_lm}")
106
+ print(f" Default LM Init: {gpu_config.init_lm_default}")
107
+ print(f" Available LM Models: {gpu_config.available_lm_models or 'None'}")
108
+ print(f"{'='*60}\n")
109
+
110
+ if auto_offload:
111
+ print(f"Auto-enabling CPU offload (GPU < 16GB)")
112
+ elif gpu_memory_gb > 0:
113
+ print(f"CPU offload disabled by default (GPU >= 16GB)")
114
+ else:
115
+ print("No GPU detected, running on CPU")
116
+
117
+ # Define local outputs directory
118
+ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
119
+ output_dir = os.path.join(project_root, "gradio_outputs")
120
+ # Normalize path to use forward slashes for Gradio 6 compatibility on Windows
121
+ output_dir = output_dir.replace("\\", "/")
122
+ os.makedirs(output_dir, exist_ok=True)
123
+ print(f"Output directory: {output_dir}")
124
+
125
+ parser = argparse.ArgumentParser(description="Gradio Demo for ACE-Step V1.5")
126
+ parser.add_argument("--port", type=int, default=7860, help="Port to run the gradio server on")
127
+ parser.add_argument("--share", action="store_true", help="Create a public link")
128
+ parser.add_argument("--debug", action="store_true", help="Enable debug mode")
129
+ parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Server name (default: 127.0.0.1, use 0.0.0.0 for all interfaces)")
130
+ parser.add_argument("--language", type=str, default="en", choices=["en", "zh", "he", "ja"], help="UI language: en (English), zh (中文), he (עברית), ja (日本語)")
131
+ parser.add_argument(
132
+ "--allowed-path",
133
+ action="append",
134
+ default=[],
135
+ help="Additional allowed file paths for Gradio (repeatable).",
136
+ )
137
+
138
+ # Service mode argument
139
+ parser.add_argument("--service_mode", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False,
140
+ help="Enable service mode (default: False). When enabled, uses preset models and restricts UI options.")
141
+
142
+ # Service initialization arguments
143
+ parser.add_argument("--init_service", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Initialize service on startup (default: False)")
144
+ parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint file path (optional, for display purposes)")
145
+ parser.add_argument("--config_path", type=str, default=None, help="Main model path (e.g., 'acestep-v15-turbo')")
146
+ parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "mps", "xpu", "cpu"], help="Processing device (default: auto)")
147
+ parser.add_argument("--init_llm", type=lambda x: x.lower() in ['true', '1', 'yes'], default=None, help="Initialize 5Hz LM (default: auto based on GPU memory)")
148
+ parser.add_argument("--lm_model_path", type=str, default=None, help="5Hz LM model path (e.g., 'acestep-5Hz-lm-0.6B')")
149
+ parser.add_argument("--backend", type=str, default="vllm", choices=["vllm", "pt", "mlx"], help="5Hz LM backend (default: vllm, use 'mlx' for native Apple Silicon acceleration)")
150
+ parser.add_argument("--use_flash_attention", type=lambda x: x.lower() in ['true', '1', 'yes'], default=None, help="Use flash attention (default: auto-detect)")
151
+ parser.add_argument("--offload_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=auto_offload, help=f"Offload models to CPU (default: {'True' if auto_offload else 'False'}, auto-detected based on GPU VRAM)")
152
+ parser.add_argument("--offload_dit_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Offload DiT to CPU (default: False)")
153
+ parser.add_argument("--download-source", type=str, default=None, choices=["huggingface", "modelscope", "auto"], help="Preferred model download source (default: auto-detect based on network)")
154
+
155
+ # API mode argument
156
+ parser.add_argument("--enable-api", action="store_true", help="Enable API endpoints (default: False)")
157
+
158
+ # Authentication arguments
159
+ parser.add_argument("--auth-username", type=str, default=None, help="Username for Gradio authentication")
160
+ parser.add_argument("--auth-password", type=str, default=None, help="Password for Gradio authentication")
161
+ parser.add_argument("--api-key", type=str, default=None, help="API key for API endpoints authentication")
162
+
163
+ args = parser.parse_args()
164
+
165
+ # Enable API requires init_service
166
+ if args.enable_api:
167
+ args.init_service = True
168
+ # Load config from .env if not specified
169
+ if args.config_path is None:
170
+ args.config_path = os.environ.get("ACESTEP_CONFIG_PATH")
171
+ if args.lm_model_path is None:
172
+ args.lm_model_path = os.environ.get("ACESTEP_LM_MODEL_PATH")
173
+ if os.environ.get("ACESTEP_LM_BACKEND"):
174
+ args.backend = os.environ.get("ACESTEP_LM_BACKEND")
175
+
176
+ # Service mode defaults (can be configured via .env file)
177
+ if args.service_mode:
178
+ print("Service mode enabled - applying preset configurations...")
179
+ # Force init_service in service mode
180
+ args.init_service = True
181
+ # Default DiT model for service mode (from env or fallback)
182
+ if args.config_path is None:
183
+ args.config_path = os.environ.get(
184
+ "SERVICE_MODE_DIT_MODEL",
185
+ "acestep-v15-turbo-fix-inst-shift-dynamic"
186
+ )
187
+ # Default LM model for service mode (from env or fallback)
188
+ if args.lm_model_path is None:
189
+ args.lm_model_path = os.environ.get(
190
+ "SERVICE_MODE_LM_MODEL",
191
+ "acestep-5Hz-lm-1.7B-v4-fix"
192
+ )
193
+ # Backend for service mode (from env or fallback to vllm)
194
+ args.backend = os.environ.get("SERVICE_MODE_BACKEND", "vllm")
195
+ print(f" DiT model: {args.config_path}")
196
+ print(f" LM model: {args.lm_model_path}")
197
+ print(f" Backend: {args.backend}")
198
+
199
+ # Auto-enable CPU offload for tier6 GPUs (16-24GB) when using the 4B LM model
200
+ # The 4B LM (~8GB) + DiT (~4.7GB) + VAE + text encoder exceeds 16-20GB with activations
201
+ if not args.offload_to_cpu and args.lm_model_path and "4B" in args.lm_model_path:
202
+ if 0 < gpu_memory_gb <= 24:
203
+ args.offload_to_cpu = True
204
+ print(f"Auto-enabling CPU offload (4B LM model requires offloading on {gpu_memory_gb:.0f}GB GPU)")
205
+
206
+ try:
207
+ init_params = None
208
+ dit_handler = None
209
+ llm_handler = None
210
+
211
+ # If init_service is True, perform initialization before creating UI
212
+ if args.init_service:
213
+ print("Initializing service from command line...")
214
+
215
+ # Create handler instances for initialization
216
+ dit_handler = AceStepHandler()
217
+ llm_handler = LLMHandler()
218
+
219
+ # Auto-select config_path if not provided
220
+ if args.config_path is None:
221
+ available_models = dit_handler.get_available_acestep_v15_models()
222
+ if available_models:
223
+ args.config_path = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else available_models[0]
224
+ print(f"Auto-selected config_path: {args.config_path}")
225
+ else:
226
+ print("Error: No available models found. Please specify --config_path", file=sys.stderr)
227
+ sys.exit(1)
228
+
229
+ # Get project root (same logic as in handler)
230
+ current_file = os.path.abspath(__file__)
231
+ project_root = os.path.dirname(os.path.dirname(current_file))
232
+
233
+ # Determine flash attention setting
234
+ use_flash_attention = args.use_flash_attention
235
+ if use_flash_attention is None:
236
+ use_flash_attention = dit_handler.is_flash_attention_available(args.device)
237
+
238
+ # Determine download source preference
239
+ prefer_source = None
240
+ if args.download_source and args.download_source != "auto":
241
+ prefer_source = args.download_source
242
+ print(f"Using preferred download source: {prefer_source}")
243
+
244
+ # Initialize DiT handler
245
+ print(f"Initializing DiT model: {args.config_path} on {args.device}...")
246
+ init_status, enable_generate = dit_handler.initialize_service(
247
+ project_root=project_root,
248
+ config_path=args.config_path,
249
+ device=args.device,
250
+ use_flash_attention=use_flash_attention,
251
+ compile_model=False,
252
+ offload_to_cpu=args.offload_to_cpu,
253
+ offload_dit_to_cpu=args.offload_dit_to_cpu,
254
+ prefer_source=prefer_source
255
+ )
256
+
257
+ if not enable_generate:
258
+ print(f"Error initializing DiT model: {init_status}", file=sys.stderr)
259
+ sys.exit(1)
260
+
261
+ print(f"DiT model initialized successfully")
262
+
263
+ # Initialize LM handler if requested
264
+ # Auto-determine init_llm based on GPU config if not explicitly set
265
+ if args.init_llm is None:
266
+ args.init_llm = gpu_config.init_lm_default
267
+ print(f"Auto-setting init_llm to {args.init_llm} based on GPU configuration")
268
+
269
+ lm_status = ""
270
+ if args.init_llm:
271
+ if args.lm_model_path is None:
272
+ # Try to get default LM model
273
+ available_lm_models = llm_handler.get_available_5hz_lm_models()
274
+ if available_lm_models:
275
+ args.lm_model_path = available_lm_models[0]
276
+ print(f"Using default LM model: {args.lm_model_path}")
277
+ else:
278
+ print("Warning: No LM models available, skipping LM initialization", file=sys.stderr)
279
+ args.init_llm = False
280
+
281
+ if args.init_llm and args.lm_model_path:
282
+ checkpoint_dir = os.path.join(project_root, "checkpoints")
283
+ print(f"Initializing 5Hz LM: {args.lm_model_path} on {args.device}...")
284
+ lm_status, lm_success = llm_handler.initialize(
285
+ checkpoint_dir=checkpoint_dir,
286
+ lm_model_path=args.lm_model_path,
287
+ backend=args.backend,
288
+ device=args.device,
289
+ offload_to_cpu=args.offload_to_cpu,
290
+ dtype=None,
291
+ )
292
+
293
+ if lm_success:
294
+ print(f"5Hz LM initialized successfully")
295
+ init_status += f"\n{lm_status}"
296
+ else:
297
+ print(f"Warning: 5Hz LM initialization failed: {lm_status}", file=sys.stderr)
298
+ init_status += f"\n{lm_status}"
299
+
300
+ # Prepare initialization parameters for UI
301
+ init_params = {
302
+ 'pre_initialized': True,
303
+ 'service_mode': args.service_mode,
304
+ 'checkpoint': args.checkpoint,
305
+ 'config_path': args.config_path,
306
+ 'device': args.device,
307
+ 'init_llm': args.init_llm,
308
+ 'lm_model_path': args.lm_model_path,
309
+ 'backend': args.backend,
310
+ 'use_flash_attention': use_flash_attention,
311
+ 'offload_to_cpu': args.offload_to_cpu,
312
+ 'offload_dit_to_cpu': args.offload_dit_to_cpu,
313
+ 'init_status': init_status,
314
+ 'enable_generate': enable_generate,
315
+ 'dit_handler': dit_handler,
316
+ 'llm_handler': llm_handler,
317
+ 'language': args.language,
318
+ 'gpu_config': gpu_config, # Pass GPU config to UI
319
+ 'output_dir': output_dir, # Pass output dir to UI
320
+ }
321
+
322
+ print("Service initialization completed successfully!")
323
+
324
+ # Create and launch demo
325
+ print(f"Creating Gradio interface with language: {args.language}...")
326
+
327
+ # If not using init_service, still pass gpu_config to init_params
328
+ if init_params is None:
329
+ init_params = {
330
+ 'gpu_config': gpu_config,
331
+ 'language': args.language,
332
+ 'output_dir': output_dir, # Pass output dir to UI
333
+ }
334
+
335
+ demo = create_demo(init_params=init_params, language=args.language)
336
+
337
+ # Enable queue for multi-user support
338
+ # This ensures proper request queuing and prevents concurrent generation conflicts
339
+ print("Enabling queue for multi-user support...")
340
+ demo.queue(
341
+ max_size=20, # Maximum queue size (adjust based on your needs)
342
+ status_update_rate="auto", # Update rate for queue status
343
+ default_concurrency_limit=1, # Prevents VRAM saturation
344
+ )
345
+
346
+ print(f"Launching server on {args.server_name}:{args.port}...")
347
+
348
+ # Setup authentication if provided
349
+ auth = None
350
+ if args.auth_username and args.auth_password:
351
+ auth = (args.auth_username, args.auth_password)
352
+ print("Authentication enabled")
353
+
354
+ allowed_paths = [output_dir]
355
+ for p in args.allowed_path:
356
+ if p and p not in allowed_paths:
357
+ allowed_paths.append(p)
358
+
359
+ # Enable API endpoints if requested
360
+ if args.enable_api:
361
+ print("Enabling API endpoints...")
362
+ from acestep.gradio_ui.api_routes import setup_api_routes
363
+
364
+ # Launch Gradio first with prevent_thread_lock=True
365
+ demo.launch(
366
+ server_name=args.server_name,
367
+ server_port=args.port,
368
+ share=args.share,
369
+ debug=args.debug,
370
+ show_error=True,
371
+ prevent_thread_lock=True, # Don't block, so we can add routes
372
+ inbrowser=False,
373
+ auth=auth,
374
+ allowed_paths=allowed_paths, # include output_dir + user-provided
375
+ )
376
+
377
+ # Now add API routes to Gradio's FastAPI app (app is available after launch)
378
+ setup_api_routes(demo, dit_handler, llm_handler, api_key=args.api_key)
379
+
380
+ if args.api_key:
381
+ print("API authentication enabled")
382
+ print("API endpoints enabled: /health, /v1/models, /release_task, /query_result, /create_random_sample, /format_lyrics")
383
+
384
+ # Keep the main thread alive
385
+ try:
386
+ while True:
387
+ import time
388
+ time.sleep(1)
389
+ except KeyboardInterrupt:
390
+ print("\nShutting down...")
391
+ else:
392
+ demo.launch(
393
+ server_name=args.server_name,
394
+ server_port=args.port,
395
+ share=args.share,
396
+ debug=args.debug,
397
+ show_error=True,
398
+ prevent_thread_lock=False,
399
+ inbrowser=False,
400
+ auth=auth,
401
+ allowed_paths=allowed_paths, # include output_dir + user-provided
402
+ )
403
+ except Exception as e:
404
+ print(f"Error launching Gradio: {e}", file=sys.stderr)
405
+ import traceback
406
+ traceback.print_exc()
407
+ sys.exit(1)
408
+
409
+
410
+ if __name__ == "__main__":
411
+ main()
acestep/api_server.py ADDED
The diff for this file is too large to render. See raw diff
 
acestep/audio_utils.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Audio saving and transcoding utility module
3
+
4
+ Independent audio file operations outside of handler, supporting:
5
+ - Save audio tensor/numpy to files (default FLAC format, fast)
6
+ - Format conversion (FLAC/WAV/MP3)
7
+ - Batch processing
8
+ """
9
+
10
+ import os
11
+ import hashlib
12
+ import json
13
+ from pathlib import Path
14
+ from typing import Union, Optional, List, Tuple
15
+ import torch
16
+ import numpy as np
17
+ import torchaudio
18
+ from loguru import logger
19
+
20
+
21
+ class AudioSaver:
22
+ """Audio saving and transcoding utility class"""
23
+
24
+ def __init__(self, default_format: str = "flac"):
25
+ """
26
+ Initialize audio saver
27
+
28
+ Args:
29
+ default_format: Default save format ('flac', 'wav', 'mp3')
30
+ """
31
+ self.default_format = default_format.lower()
32
+ if self.default_format not in ["flac", "wav", "mp3"]:
33
+ logger.warning(f"Unsupported format {default_format}, using 'flac'")
34
+ self.default_format = "flac"
35
+
36
+ def save_audio(
37
+ self,
38
+ audio_data: Union[torch.Tensor, np.ndarray],
39
+ output_path: Union[str, Path],
40
+ sample_rate: int = 48000,
41
+ format: Optional[str] = None,
42
+ channels_first: bool = True,
43
+ ) -> str:
44
+ """
45
+ Save audio data to file
46
+
47
+ Args:
48
+ audio_data: Audio data, torch.Tensor [channels, samples] or numpy.ndarray
49
+ output_path: Output file path (extension can be omitted)
50
+ sample_rate: Sample rate
51
+ format: Audio format ('flac', 'wav', 'mp3'), defaults to default_format
52
+ channels_first: If True, tensor format is [channels, samples], else [samples, channels]
53
+
54
+ Returns:
55
+ Actual saved file path
56
+ """
57
+ format = (format or self.default_format).lower()
58
+ if format not in ["flac", "wav", "mp3"]:
59
+ logger.warning(f"Unsupported format {format}, using {self.default_format}")
60
+ format = self.default_format
61
+
62
+ # Ensure output path has correct extension
63
+ output_path = Path(output_path)
64
+ if output_path.suffix.lower() not in ['.flac', '.wav', '.mp3']:
65
+ output_path = output_path.with_suffix(f'.{format}')
66
+
67
+ # Convert to torch tensor
68
+ if isinstance(audio_data, np.ndarray):
69
+ if channels_first:
70
+ # numpy [samples, channels] -> tensor [channels, samples]
71
+ audio_tensor = torch.from_numpy(audio_data.T).float()
72
+ else:
73
+ # numpy [samples, channels] -> tensor [samples, channels] -> [channels, samples]
74
+ audio_tensor = torch.from_numpy(audio_data).float()
75
+ if audio_tensor.dim() == 2 and audio_tensor.shape[0] < audio_tensor.shape[1]:
76
+ audio_tensor = audio_tensor.T
77
+ else:
78
+ # torch tensor
79
+ audio_tensor = audio_data.cpu().float()
80
+ if not channels_first and audio_tensor.dim() == 2:
81
+ # [samples, channels] -> [channels, samples]
82
+ if audio_tensor.shape[0] > audio_tensor.shape[1]:
83
+ audio_tensor = audio_tensor.T
84
+
85
+ # Ensure memory is contiguous
86
+ audio_tensor = audio_tensor.contiguous()
87
+
88
+ # Select backend and save
89
+ try:
90
+ if format == "mp3":
91
+ # MP3 uses ffmpeg backend
92
+ torchaudio.save(
93
+ str(output_path),
94
+ audio_tensor,
95
+ sample_rate,
96
+ channels_first=True,
97
+ backend='ffmpeg',
98
+ )
99
+ elif format in ["flac", "wav"]:
100
+ # FLAC and WAV use soundfile backend (fastest)
101
+ torchaudio.save(
102
+ str(output_path),
103
+ audio_tensor,
104
+ sample_rate,
105
+ channels_first=True,
106
+ backend='soundfile',
107
+ )
108
+ else:
109
+ # Other formats use default backend
110
+ torchaudio.save(
111
+ str(output_path),
112
+ audio_tensor,
113
+ sample_rate,
114
+ channels_first=True,
115
+ )
116
+
117
+ logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
118
+ return str(output_path)
119
+
120
+ except Exception as e:
121
+ try:
122
+ import soundfile as sf
123
+ audio_np = audio_tensor.transpose(0, 1).numpy() # -> [samples, channels]
124
+ sf.write(str(output_path), audio_np, sample_rate, format=format.upper())
125
+ logger.debug(f"[AudioSaver] Fallback soundfile Saved audio to {output_path} ({format}, {sample_rate}Hz)")
126
+ return str(output_path)
127
+ except Exception as e:
128
+ logger.error(f"[AudioSaver] Failed to save audio: {e}")
129
+ raise
130
+
131
+ def convert_audio(
132
+ self,
133
+ input_path: Union[str, Path],
134
+ output_path: Union[str, Path],
135
+ output_format: str,
136
+ remove_input: bool = False,
137
+ ) -> str:
138
+ """
139
+ Convert audio format
140
+
141
+ Args:
142
+ input_path: Input audio file path
143
+ output_path: Output audio file path
144
+ output_format: Target format ('flac', 'wav', 'mp3')
145
+ remove_input: Whether to delete input file
146
+
147
+ Returns:
148
+ Output file path
149
+ """
150
+ input_path = Path(input_path)
151
+ output_path = Path(output_path)
152
+
153
+ if not input_path.exists():
154
+ raise FileNotFoundError(f"Input file not found: {input_path}")
155
+
156
+ # Load audio
157
+ audio_tensor, sample_rate = torchaudio.load(str(input_path))
158
+
159
+ # Save as new format
160
+ output_path = self.save_audio(
161
+ audio_tensor,
162
+ output_path,
163
+ sample_rate=sample_rate,
164
+ format=output_format,
165
+ channels_first=True
166
+ )
167
+
168
+ # Delete input file if needed
169
+ if remove_input:
170
+ input_path.unlink()
171
+ logger.debug(f"[AudioSaver] Removed input file: {input_path}")
172
+
173
+ return output_path
174
+
175
+ def save_batch(
176
+ self,
177
+ audio_batch: Union[List[torch.Tensor], torch.Tensor],
178
+ output_dir: Union[str, Path],
179
+ file_prefix: str = "audio",
180
+ sample_rate: int = 48000,
181
+ format: Optional[str] = None,
182
+ channels_first: bool = True,
183
+ ) -> List[str]:
184
+ """
185
+ Save audio batch
186
+
187
+ Args:
188
+ audio_batch: Audio batch, List[tensor] or tensor [batch, channels, samples]
189
+ output_dir: Output directory
190
+ file_prefix: File prefix
191
+ sample_rate: Sample rate
192
+ format: Audio format
193
+ channels_first: Tensor format flag
194
+
195
+ Returns:
196
+ List of saved file paths
197
+ """
198
+ output_dir = Path(output_dir)
199
+ output_dir.mkdir(parents=True, exist_ok=True)
200
+
201
+ # Process batch
202
+ if isinstance(audio_batch, torch.Tensor) and audio_batch.dim() == 3:
203
+ # [batch, channels, samples]
204
+ audio_list = [audio_batch[i] for i in range(audio_batch.shape[0])]
205
+ elif isinstance(audio_batch, list):
206
+ audio_list = audio_batch
207
+ else:
208
+ audio_list = [audio_batch]
209
+
210
+ saved_paths = []
211
+ for i, audio in enumerate(audio_list):
212
+ output_path = output_dir / f"{file_prefix}_{i:04d}"
213
+ saved_path = self.save_audio(
214
+ audio,
215
+ output_path,
216
+ sample_rate=sample_rate,
217
+ format=format,
218
+ channels_first=channels_first
219
+ )
220
+ saved_paths.append(saved_path)
221
+
222
+ return saved_paths
223
+
224
+
225
+ def get_audio_file_hash(audio_file) -> str:
226
+ """
227
+ Get hash identifier for an audio file.
228
+
229
+ Args:
230
+ audio_file: Path to audio file (str) or file-like object
231
+
232
+ Returns:
233
+ Hash string or empty string
234
+ """
235
+ if audio_file is None:
236
+ return ""
237
+
238
+ try:
239
+ if isinstance(audio_file, str):
240
+ if os.path.exists(audio_file):
241
+ with open(audio_file, 'rb') as f:
242
+ return hashlib.md5(f.read()).hexdigest()
243
+ return hashlib.md5(audio_file.encode('utf-8')).hexdigest()
244
+ elif hasattr(audio_file, 'name'):
245
+ return hashlib.md5(str(audio_file.name).encode('utf-8')).hexdigest()
246
+ return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
247
+ except Exception:
248
+ return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
249
+
250
+
251
+ def generate_uuid_from_params(params_dict) -> str:
252
+ """
253
+ Generate deterministic UUID from generation parameters.
254
+ Same parameters will always generate the same UUID.
255
+
256
+ Args:
257
+ params_dict: Dictionary of parameters
258
+
259
+ Returns:
260
+ UUID string
261
+ """
262
+
263
+ params_json = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
264
+ hash_obj = hashlib.sha256(params_json.encode('utf-8'))
265
+ hash_hex = hash_obj.hexdigest()
266
+ uuid_str = f"{hash_hex[0:8]}-{hash_hex[8:12]}-{hash_hex[12:16]}-{hash_hex[16:20]}-{hash_hex[20:32]}"
267
+ return uuid_str
268
+
269
+
270
+ def generate_uuid_from_audio_data(
271
+ audio_data: Union[torch.Tensor, np.ndarray],
272
+ seed: Optional[int] = None
273
+ ) -> str:
274
+ """
275
+ Generate UUID from audio data (for caching/deduplication)
276
+
277
+ Args:
278
+ audio_data: Audio data
279
+ seed: Optional seed value
280
+
281
+ Returns:
282
+ UUID string
283
+ """
284
+ if isinstance(audio_data, torch.Tensor):
285
+ # Convert to numpy and calculate hash
286
+ audio_np = audio_data.cpu().numpy()
287
+ else:
288
+ audio_np = audio_data
289
+
290
+ # Calculate data hash
291
+ data_hash = hashlib.md5(audio_np.tobytes()).hexdigest()
292
+
293
+ if seed is not None:
294
+ combined = f"{data_hash}_{seed}"
295
+ return hashlib.md5(combined.encode()).hexdigest()
296
+
297
+ return data_hash
298
+
299
+
300
+ # Global default instance
301
+ _default_saver = AudioSaver(default_format="flac")
302
+
303
+ SILENT_RMS_THRESHOLD = 1e-5
304
+ SILENT_PEAK_THRESHOLD = 1e-5
305
+
306
+
307
+ def is_audio_silent(
308
+ audio_data: Union[torch.Tensor, np.ndarray],
309
+ rms_threshold: float = SILENT_RMS_THRESHOLD,
310
+ peak_threshold: float = SILENT_PEAK_THRESHOLD,
311
+ channels_first: bool = True,
312
+ ) -> Tuple[bool, float, float]:
313
+ """
314
+ Check if audio is silent or near-silent (e.g. zeroed conditioning output).
315
+ Returns (is_silent, rms, peak) where rms/peak are computed over the full signal.
316
+ """
317
+ if audio_data is None:
318
+ return True, 0.0, 0.0
319
+ if isinstance(audio_data, np.ndarray):
320
+ x = np.asarray(audio_data, dtype=np.float64).ravel()
321
+ else:
322
+ x = audio_data.cpu().float().numpy().ravel()
323
+ if x.size == 0:
324
+ return True, 0.0, 0.0
325
+ rms = float(np.sqrt(np.mean(x * x)))
326
+ peak = float(np.max(np.abs(x)))
327
+ is_silent = rms <= rms_threshold and peak <= peak_threshold
328
+ return is_silent, rms, peak
329
+
330
+
331
+ def save_audio(
332
+ audio_data: Union[torch.Tensor, np.ndarray],
333
+ output_path: Union[str, Path],
334
+ sample_rate: int = 48000,
335
+ format: Optional[str] = None,
336
+ channels_first: bool = True,
337
+ ) -> str:
338
+ """
339
+ Convenience function: save audio (using default configuration)
340
+
341
+ Args:
342
+ audio_data: Audio data
343
+ output_path: Output path
344
+ sample_rate: Sample rate
345
+ format: Format (default flac)
346
+ channels_first: Tensor format flag
347
+
348
+ Returns:
349
+ Saved file path
350
+ """
351
+ return _default_saver.save_audio(
352
+ audio_data, output_path, sample_rate, format, channels_first
353
+ )
354
+
acestep/constants.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Constants for ACE-Step
3
+ Centralized constants used across the codebase
4
+ """
5
+
6
+ # ==============================================================================
7
+ # Language Constants
8
+ # ==============================================================================
9
+
10
+ # Supported languages for vocal generation and language detection
11
+ # Covers major world languages with good TTS support in the underlying model
12
+ # 'unknown' is used when language cannot be determined automatically
13
+ VALID_LANGUAGES = [
14
+ 'ar', 'az', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en',
15
+ 'es', 'fa', 'fi', 'fr', 'he', 'hi', 'hr', 'ht', 'hu', 'id',
16
+ 'is', 'it', 'ja', 'ko', 'la', 'lt', 'ms', 'ne', 'nl', 'no',
17
+ 'pa', 'pl', 'pt', 'ro', 'ru', 'sa', 'sk', 'sr', 'sv', 'sw',
18
+ 'ta', 'te', 'th', 'tl', 'tr', 'uk', 'ur', 'vi', 'yue', 'zh',
19
+ 'unknown'
20
+ ]
21
+
22
+
23
+ # ==============================================================================
24
+ # Keyscale Constants
25
+ # ==============================================================================
26
+
27
+ # Musical note names using standard Western notation
28
+ KEYSCALE_NOTES = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
29
+
30
+ # Supported accidentals: natural, ASCII sharp/flat, Unicode sharp/flat
31
+ KEYSCALE_ACCIDENTALS = ['', '#', 'b', '♯', '♭'] # empty + ASCII sharp/flat + Unicode sharp/flat
32
+
33
+ # Major and minor scale modes
34
+ KEYSCALE_MODES = ['major', 'minor']
35
+
36
+ # Generate all valid keyscales: 7 notes × 5 accidentals × 2 modes = 70 combinations
37
+ # Examples: "C major", "F# minor", "B♭ major"
38
+ VALID_KEYSCALES = set()
39
+ for note in KEYSCALE_NOTES:
40
+ for acc in KEYSCALE_ACCIDENTALS:
41
+ for mode in KEYSCALE_MODES:
42
+ VALID_KEYSCALES.add(f"{note}{acc} {mode}")
43
+
44
+
45
+ # ==============================================================================
46
+ # Metadata Range Constants
47
+ # ==============================================================================
48
+
49
+ # BPM (Beats Per Minute) range - covers most musical styles
50
+ # 30 BPM: Very slow ballads, ambient music
51
+ # 300 BPM: Fast electronic dance music, extreme metal
52
+ BPM_MIN = 30
53
+ BPM_MAX = 300
54
+
55
+ # Duration range (in seconds) - balances quality vs. computational cost
56
+ # 10s: Short loops, musical excerpts
57
+ # 600s: Full songs, extended compositions (10 minutes)
58
+ DURATION_MIN = 10
59
+ DURATION_MAX = 600
60
+
61
+ # Valid time signatures - common musical meter patterns
62
+ # 2: 2/4 time (marches, polka)
63
+ # 3: 3/4 time (waltzes, ballads)
64
+ # 4: 4/4 time (most pop, rock, hip-hop)
65
+ # 6: 6/8 time (compound time, folk dances)
66
+ VALID_TIME_SIGNATURES = [2, 3, 4, 6]
67
+
68
+
69
+ # ==============================================================================
70
+ # Task Type Constants
71
+ # ==============================================================================
72
+
73
+ # All supported generation tasks across different model variants
74
+ TASK_TYPES = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
75
+
76
+ # Task types available for turbo models (optimized subset for speed)
77
+ # - text2music: Generate from text descriptions
78
+ # - repaint: Selective audio editing/regeneration
79
+ # - cover: Style transfer using reference audio
80
+ TASK_TYPES_TURBO = ["text2music", "repaint", "cover"]
81
+
82
+ # Task types available for base models (full feature set)
83
+ # Additional tasks requiring more computational resources:
84
+ # - extract: Separate individual tracks/stems from audio
85
+ # - lego: Multi-track generation (add layers)
86
+ # - complete: Automatic completion of partial audio
87
+ TASK_TYPES_BASE = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
88
+
89
+
90
+ # ==============================================================================
91
+ # Instruction Constants
92
+ # ==============================================================================
93
+
94
+ # Default instructions
95
+ DEFAULT_DIT_INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
96
+ DEFAULT_LM_INSTRUCTION = "Generate audio semantic tokens based on the given conditions:"
97
+ DEFAULT_LM_UNDERSTAND_INSTRUCTION = "Understand the given musical conditions and describe the audio semantics accordingly:"
98
+ DEFAULT_LM_INSPIRED_INSTRUCTION = "Expand the user's input into a more detailed and specific musical description:"
99
+ DEFAULT_LM_REWRITE_INSTRUCTION = "Format the user's input into a more detailed and specific musical description:"
100
+
101
+ # Instruction templates for each task type
102
+ # Note: Some instructions use placeholders like {TRACK_NAME} or {TRACK_CLASSES}
103
+ # These should be formatted using .format() or f-strings when used
104
+ TASK_INSTRUCTIONS = {
105
+ "text2music": "Fill the audio semantic mask based on the given conditions:",
106
+ "repaint": "Repaint the mask area based on the given conditions:",
107
+ "cover": "Generate audio semantic tokens based on the given conditions:",
108
+ "extract": "Extract the {TRACK_NAME} track from the audio:",
109
+ "extract_default": "Extract the track from the audio:",
110
+ "lego": "Generate the {TRACK_NAME} track based on the audio context:",
111
+ "lego_default": "Generate the track based on the audio context:",
112
+ "complete": "Complete the input track with {TRACK_CLASSES}:",
113
+ "complete_default": "Complete the input track:",
114
+ }
115
+
116
+
117
+ # ==============================================================================
118
+ # Track/Instrument Constants
119
+ # ==============================================================================
120
+
121
+ # Supported instrumental track types for multi-track generation and extraction
122
+ # Organized by instrument families for logical grouping:
123
+ # - Wind instruments: woodwinds, brass
124
+ # - Electronic: fx (effects), synth (synthesizer)
125
+ # - String instruments: strings, guitar, bass
126
+ # - Rhythm section: percussion, drums, keyboard
127
+ # - Vocals: backing_vocals, vocals (lead vocals)
128
+ TRACK_NAMES = [
129
+ "woodwinds", "brass", "fx", "synth", "strings", "percussion",
130
+ "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"
131
+ ]
132
+
133
+ # Template for SFT (Supervised Fine-Tuning) model prompts
134
+ # Used to format inputs for the language model with instruction, caption, and metadata
135
+ SFT_GEN_PROMPT = """# Instruction
136
+ {}
137
+
138
+ # Caption
139
+ {}
140
+
141
+ # Metas
142
+ {}<|endoftext|>
143
+ """
144
+
145
+
146
+ # ==============================================================================
147
+ # GPU Memory Configuration Constants
148
+ # ==============================================================================
149
+
150
+ # GPU tier thresholds (in GB)
151
+ GPU_TIER_THRESHOLDS = {
152
+ "tier1": 4, # <= 4GB
153
+ "tier2": 6, # 4-6GB
154
+ "tier3": 8, # 6-8GB
155
+ "tier4": 12, # 8-12GB
156
+ "tier5": 16, # 12-16GB
157
+ "tier6": 24, # 16-24GB
158
+ # "unlimited" for >= 24GB
159
+ }
160
+
161
+ # LM model memory requirements (in GB)
162
+ LM_MODEL_MEMORY_GB = {
163
+ "0.6B": 3.0,
164
+ "1.7B": 8.0,
165
+ "4B": 12.0,
166
+ }
167
+
168
+ # LM model names mapping
169
+ LM_MODEL_NAMES = {
170
+ "0.6B": "acestep-5Hz-lm-0.6B",
171
+ "1.7B": "acestep-5Hz-lm-1.7B",
172
+ "4B": "acestep-5Hz-lm-4B",
173
+ }
174
+
175
+
176
+ # ==============================================================================
177
+ # Debug Constants
178
+ # ==============================================================================
179
+
180
+ # Tensor debug mode (values: "OFF" | "ON" | "VERBOSE")
181
+ TENSOR_DEBUG_MODE = "OFF"
182
+
183
+ # Placeholder debug switches for other main functionality (default "OFF")
184
+ # Update names/usage as features adopt them.
185
+ DEBUG_API_SERVER = "OFF"
186
+ DEBUG_INFERENCE = "OFF"
187
+ DEBUG_TRAINING = "OFF"
188
+ DEBUG_DATASET = "OFF"
189
+ DEBUG_AUDIO = "OFF"
190
+ DEBUG_LLM = "OFF"
191
+ DEBUG_UI = "OFF"
192
+ DEBUG_MODEL_LOADING = "OFF"
193
+ DEBUG_GPU = "OFF"
acestep/constrained_logits_processor.py ADDED
The diff for this file is too large to render. See raw diff
 
acestep/dataset_handler.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset Handler Module
3
+
4
+ Handles dataset import and exploration functionality for ACE-Step training.
5
+ This module provides a placeholder implementation for dataset operations
6
+ when the full training dataset dependencies are not available.
7
+
8
+ Note: Full dataset functionality requires Text2MusicDataset which may not be
9
+ included in the basic installation to reduce dependencies.
10
+ """
11
+ from typing import Optional, Tuple, Any, Dict
12
+
13
+
14
+ class DatasetHandler:
15
+ """
16
+ Dataset Handler for Dataset Explorer functionality.
17
+
18
+ Provides interface for dataset import and exploration features in the Gradio UI.
19
+ When training dependencies are not available, returns appropriate fallback responses.
20
+ """
21
+
22
+ def __init__(self):
23
+ """Initialize dataset handler with empty state"""
24
+ self.dataset = None
25
+ self.dataset_imported = False
26
+
27
+ def import_dataset(self, dataset_type: str) -> str:
28
+ """
29
+ Import dataset (currently disabled in base installation)
30
+
31
+ Args:
32
+ dataset_type: Type of dataset to import (e.g., "train", "test", "validation")
33
+
34
+ Returns:
35
+ Status message indicating dataset import is disabled
36
+
37
+ Note:
38
+ This is a placeholder implementation. Full dataset support requires:
39
+ - Text2MusicDataset dependency
40
+ - Training data files
41
+ - Additional configuration
42
+ """
43
+ self.dataset_imported = False
44
+ return f"⚠️ Dataset import is currently disabled. Text2MusicDataset dependency not available."
45
+
46
+ def get_item_data(self, *args, **kwargs) -> Tuple:
47
+ """
48
+ Get dataset item data (placeholder implementation)
49
+
50
+ Args:
51
+ *args: Variable arguments (ignored in placeholder)
52
+ **kwargs: Keyword arguments (ignored in placeholder)
53
+
54
+ Returns:
55
+ Tuple of placeholder values matching the expected return format:
56
+ (caption, lyrics, language, bpm, keyscale, ref_audio, src_audio, codes,
57
+ status_msg, instruction, duration, timesig, audio1, audio2, audio3,
58
+ metadata, task_type)
59
+
60
+ Note:
61
+ Returns empty/default values since dataset is not available.
62
+ Real implementation would return actual dataset samples.
63
+ """
64
+ return (
65
+ "", # caption: empty string
66
+ "", # lyrics: empty string
67
+ "", # language: empty string
68
+ "", # bpm: empty string
69
+ "", # keyscale: empty string
70
+ None, # ref_audio: no audio file
71
+ None, # src_audio: no audio file
72
+ None, # codes: no audio codes
73
+ "❌ Dataset not available", # status_msg: error indicator
74
+ "", # instruction: empty string
75
+ 0, # duration: zero
76
+ "", # timesig: empty string
77
+ None, # audio1: no audio
78
+ None, # audio2: no audio
79
+ None, # audio3: no audio
80
+ {}, # metadata: empty dict
81
+ "text2music" # task_type: default task
82
+ )
83
+
acestep/debug_utils.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Debug helpers (global).
3
+ """
4
+ from __future__ import annotations
5
+
6
+ from datetime import datetime
7
+ from typing import Optional, Callable, Union
8
+
9
+ from acestep.constants import (
10
+ TENSOR_DEBUG_MODE,
11
+ DEBUG_API_SERVER,
12
+ DEBUG_INFERENCE,
13
+ DEBUG_TRAINING,
14
+ DEBUG_DATASET,
15
+ DEBUG_AUDIO,
16
+ DEBUG_LLM,
17
+ DEBUG_UI,
18
+ DEBUG_MODEL_LOADING,
19
+ DEBUG_GPU,
20
+ )
21
+
22
+
23
+ def _normalize_mode(mode: str) -> str:
24
+ return (mode or "").strip().upper()
25
+
26
+
27
+ def is_debug_enabled(mode: str) -> bool:
28
+ return _normalize_mode(mode) != "OFF"
29
+
30
+
31
+ def is_debug_verbose(mode: str) -> bool:
32
+ return _normalize_mode(mode) == "VERBOSE"
33
+
34
+
35
+ def debug_log(message: Union[str, Callable[[], str]], *, mode: str = TENSOR_DEBUG_MODE, prefix: str = "debug") -> None:
36
+ """Emit a timestamped debug log line if the mode is enabled."""
37
+ if not is_debug_enabled(mode):
38
+ return
39
+ if callable(message):
40
+ message = message()
41
+ ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
42
+ print(f"[{prefix}] {ts} {message}", flush=True)
43
+
44
+
45
+ # Placeholder debug switches registry (for centralized access)
46
+ DEBUG_SWITCHES = {
47
+ "tensor": TENSOR_DEBUG_MODE,
48
+ "api_server": DEBUG_API_SERVER,
49
+ "inference": DEBUG_INFERENCE,
50
+ "training": DEBUG_TRAINING,
51
+ "dataset": DEBUG_DATASET,
52
+ "audio": DEBUG_AUDIO,
53
+ "llm": DEBUG_LLM,
54
+ "ui": DEBUG_UI,
55
+ "model_loading": DEBUG_MODEL_LOADING,
56
+ "gpu": DEBUG_GPU,
57
+ }
58
+
59
+
60
+ def get_debug_mode(name: str, default: str = "OFF") -> str:
61
+ """Fetch a placeholder debug mode by name."""
62
+ return DEBUG_SWITCHES.get((name or "").strip().lower(), default)
63
+
64
+
65
+ def debug_log_for(name: str, message: Union[str, Callable[[], str]], *, prefix: str | None = None) -> None:
66
+ """Emit a timestamped debug log for a named subsystem."""
67
+ mode = get_debug_mode(name)
68
+ debug_log(message, mode=mode, prefix=prefix or name)
69
+
70
+
71
+ def debug_start_for(name: str, label: str) -> Optional[float]:
72
+ """Start timing for a named subsystem."""
73
+ mode = get_debug_mode(name)
74
+ return debug_start(label, mode=mode, prefix=name)
75
+
76
+
77
+ def debug_end_for(name: str, label: str, start_ts: Optional[float]) -> None:
78
+ """End timing for a named subsystem."""
79
+ mode = get_debug_mode(name)
80
+ debug_end(label, start_ts, mode=mode, prefix=name)
81
+
82
+
83
+ def debug_log_verbose_for(name: str, message: Union[str, Callable[[], str]], *, prefix: str | None = None) -> None:
84
+ """Emit a timestamped debug log only in VERBOSE mode for a named subsystem."""
85
+ mode = get_debug_mode(name)
86
+ if not is_debug_verbose(mode):
87
+ return
88
+ debug_log(message, mode=mode, prefix=prefix or name)
89
+
90
+
91
+ def debug_start_verbose_for(name: str, label: str) -> Optional[float]:
92
+ """Start timing only in VERBOSE mode for a named subsystem."""
93
+ mode = get_debug_mode(name)
94
+ if not is_debug_verbose(mode):
95
+ return None
96
+ return debug_start(label, mode=mode, prefix=name)
97
+
98
+
99
+ def debug_end_verbose_for(name: str, label: str, start_ts: Optional[float]) -> None:
100
+ """End timing only in VERBOSE mode for a named subsystem."""
101
+ mode = get_debug_mode(name)
102
+ if not is_debug_verbose(mode):
103
+ return
104
+ debug_end(label, start_ts, mode=mode, prefix=name)
105
+
106
+
107
+ def debug_start(name: str, *, mode: str = TENSOR_DEBUG_MODE, prefix: str = "debug") -> Optional[float]:
108
+ """Return a start timestamp (perf counter) if enabled, otherwise None."""
109
+ if not is_debug_enabled(mode):
110
+ return None
111
+ debug_log(f"START {name}", mode=mode, prefix=prefix)
112
+ from time import perf_counter
113
+ return perf_counter()
114
+
115
+
116
+ def debug_end(name: str, start_ts: Optional[float], *, mode: str = TENSOR_DEBUG_MODE, prefix: str = "debug") -> None:
117
+ """Emit an END log with elapsed ms if enabled and start_ts is present."""
118
+ if start_ts is None or not is_debug_enabled(mode):
119
+ return
120
+ from time import perf_counter
121
+ elapsed_ms = (perf_counter() - start_ts) * 1000.0
122
+ debug_log(f"END {name} ({elapsed_ms:.1f} ms)", mode=mode, prefix=prefix)
acestep/dit_alignment_score.py ADDED
@@ -0,0 +1,877 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DiT Alignment Score Module
3
+
4
+ This module provides lyrics-to-audio alignment using cross-attention matrices
5
+ from DiT model for generating LRC timestamps.
6
+
7
+ Refactored from lyrics_alignment_infos.py for integration with ACE-Step.
8
+ """
9
+ import numba
10
+ import torch
11
+ import numpy as np
12
+ import torch.nn.functional as F
13
+ from dataclasses import dataclass, asdict
14
+ from typing import List, Dict, Any, Optional, Tuple, Union
15
+
16
+
17
+ # ================= Data Classes =================
18
+ @dataclass
19
+ class TokenTimestamp:
20
+ """Stores per-token timing information."""
21
+ token_id: int
22
+ text: str
23
+ start: float
24
+ end: float
25
+ probability: float
26
+
27
+
28
+ @dataclass
29
+ class SentenceTimestamp:
30
+ """Stores per-sentence timing information with token list."""
31
+ text: str
32
+ start: float
33
+ end: float
34
+ tokens: List[TokenTimestamp]
35
+ confidence: float
36
+
37
+
38
+ # ================= DTW Algorithm (Numba Optimized) =================
39
+ @numba.jit(nopython=True)
40
+ def dtw_cpu(x: np.ndarray):
41
+ """
42
+ Dynamic Time Warping algorithm optimized with Numba.
43
+
44
+ Args:
45
+ x: Cost matrix of shape [N, M]
46
+
47
+ Returns:
48
+ Tuple of (text_indices, time_indices) arrays
49
+ """
50
+ N, M = x.shape
51
+ # Use float32 for memory efficiency
52
+ cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
53
+ trace = -np.ones((N + 1, M + 1), dtype=np.float32)
54
+ cost[0, 0] = 0
55
+
56
+ for j in range(1, M + 1):
57
+ for i in range(1, N + 1):
58
+ c0 = cost[i - 1, j - 1]
59
+ c1 = cost[i - 1, j]
60
+ c2 = cost[i, j - 1]
61
+
62
+ if c0 < c1 and c0 < c2:
63
+ c, t = c0, 0
64
+ elif c1 < c0 and c1 < c2:
65
+ c, t = c1, 1
66
+ else:
67
+ c, t = c2, 2
68
+
69
+ cost[i, j] = x[i - 1, j - 1] + c
70
+ trace[i, j] = t
71
+
72
+ return _backtrace(trace, N, M)
73
+
74
+
75
+ @numba.jit(nopython=True)
76
+ def _backtrace(trace: np.ndarray, N: int, M: int):
77
+ """
78
+ Optimized backtrace function for DTW.
79
+
80
+ Args:
81
+ trace: Trace matrix of shape (N+1, M+1)
82
+ N, M: Original matrix dimensions
83
+
84
+ Returns:
85
+ Path array of shape (2, path_len) - first row is text indices, second is time indices
86
+ """
87
+ # Boundary handling
88
+ trace[0, :] = 2
89
+ trace[:, 0] = 1
90
+
91
+ # Pre-allocate array, max path length is N+M
92
+ max_path_len = N + M
93
+ path = np.zeros((2, max_path_len), dtype=np.int32)
94
+
95
+ i, j = N, M
96
+ path_idx = max_path_len - 1
97
+
98
+ while i > 0 or j > 0:
99
+ path[0, path_idx] = i - 1 # text index
100
+ path[1, path_idx] = j - 1 # time index
101
+ path_idx -= 1
102
+
103
+ t = trace[i, j]
104
+ if t == 0:
105
+ i -= 1
106
+ j -= 1
107
+ elif t == 1:
108
+ i -= 1
109
+ elif t == 2:
110
+ j -= 1
111
+ else:
112
+ break
113
+
114
+ actual_len = max_path_len - path_idx - 1
115
+ return path[:, path_idx + 1:max_path_len]
116
+
117
+
118
+ # ================= Utility Functions =================
119
+ def median_filter(x: torch.Tensor, filter_width: int) -> torch.Tensor:
120
+ """
121
+ Apply median filter to tensor.
122
+
123
+ Args:
124
+ x: Input tensor
125
+ filter_width: Width of median filter
126
+
127
+ Returns:
128
+ Filtered tensor
129
+ """
130
+ pad_width = filter_width // 2
131
+ if x.shape[-1] <= pad_width:
132
+ return x
133
+ if x.ndim == 2:
134
+ x = x[None, :]
135
+ x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
136
+ result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
137
+ if result.ndim > 2:
138
+ result = result.squeeze(0)
139
+ return result
140
+
141
+
142
+ # ================= Main Aligner Class =================
143
+ class MusicStampsAligner:
144
+ """
145
+ Aligner class for generating lyrics timestamps from cross-attention matrices.
146
+
147
+ Uses bidirectional consensus denoising and DTW for alignment.
148
+ """
149
+
150
+ def __init__(self, tokenizer):
151
+ """
152
+ Initialize the aligner.
153
+
154
+ Args:
155
+ tokenizer: Text tokenizer for decoding tokens
156
+ """
157
+ self.tokenizer = tokenizer
158
+
159
+ def _apply_bidirectional_consensus(
160
+ self,
161
+ weights_stack: torch.Tensor,
162
+ violence_level: float,
163
+ medfilt_width: int
164
+ ) -> tuple:
165
+ """
166
+ Core denoising logic using bidirectional consensus.
167
+
168
+ Args:
169
+ weights_stack: Attention weights [Heads, Tokens, Frames]
170
+ violence_level: Denoising strength coefficient
171
+ medfilt_width: Median filter width
172
+
173
+ Returns:
174
+ Tuple of (calc_matrix, energy_matrix) as numpy arrays
175
+ """
176
+ # A. Bidirectional Consensus
177
+ row_prob = F.softmax(weights_stack, dim=-1) # Token -> Frame
178
+ col_prob = F.softmax(weights_stack, dim=-2) # Frame -> Token
179
+ processed = row_prob * col_prob
180
+
181
+ # 1. Row suppression (kill horizontal crossing lines)
182
+ row_medians = torch.quantile(processed, 0.5, dim=-1, keepdim=True)
183
+ processed = processed - (violence_level * row_medians)
184
+ processed = torch.relu(processed)
185
+
186
+ # 2. Column suppression (kill vertical crossing lines)
187
+ col_medians = torch.quantile(processed, 0.5, dim=-2, keepdim=True)
188
+ processed = processed - (violence_level * col_medians)
189
+ processed = torch.relu(processed)
190
+
191
+ # C. Power sharpening
192
+ processed = processed ** 2
193
+
194
+ # Energy matrix for confidence
195
+ energy_matrix = processed.mean(dim=0).cpu().numpy()
196
+
197
+ # D. Z-Score normalization
198
+ std, mean = torch.std_mean(processed, unbiased=False)
199
+ weights_processed = (processed - mean) / (std + 1e-9)
200
+
201
+ # E. Median filtering
202
+ weights_processed = median_filter(weights_processed, filter_width=medfilt_width)
203
+ calc_matrix = weights_processed.mean(dim=0).numpy()
204
+
205
+ return calc_matrix, energy_matrix
206
+
207
+ def _preprocess_attention(
208
+ self,
209
+ attention_matrix: torch.Tensor,
210
+ custom_config: Dict[int, List[int]],
211
+ violence_level: float,
212
+ medfilt_width: int = 7
213
+ ) -> tuple:
214
+ """
215
+ Preprocess attention matrix for alignment.
216
+
217
+ Args:
218
+ attention_matrix: Attention tensor [Layers, Heads, Tokens, Frames]
219
+ custom_config: Dict mapping layer indices to head indices
220
+ violence_level: Denoising strength
221
+ medfilt_width: Median filter width
222
+
223
+ Returns:
224
+ Tuple of (calc_matrix, energy_matrix, visual_matrix)
225
+ """
226
+ if not isinstance(attention_matrix, torch.Tensor):
227
+ weights = torch.tensor(attention_matrix)
228
+ else:
229
+ weights = attention_matrix.clone()
230
+
231
+ weights = weights.cpu().float()
232
+
233
+ selected_tensors = []
234
+ for layer_idx, head_indices in custom_config.items():
235
+ for head_idx in head_indices:
236
+ if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
237
+ head_matrix = weights[layer_idx, head_idx]
238
+ selected_tensors.append(head_matrix)
239
+
240
+ if not selected_tensors:
241
+ return None, None, None
242
+
243
+ # Stack selected heads: [Heads, Tokens, Frames]
244
+ weights_stack = torch.stack(selected_tensors, dim=0)
245
+ visual_matrix = weights_stack.mean(dim=0).numpy()
246
+
247
+ calc_matrix, energy_matrix = self._apply_bidirectional_consensus(
248
+ weights_stack, violence_level, medfilt_width
249
+ )
250
+
251
+ return calc_matrix, energy_matrix, visual_matrix
252
+
253
+ def stamps_align_info(
254
+ self,
255
+ attention_matrix: torch.Tensor,
256
+ lyrics_tokens: List[int],
257
+ total_duration_seconds: float,
258
+ custom_config: Dict[int, List[int]],
259
+ return_matrices: bool = False,
260
+ violence_level: float = 2.0,
261
+ medfilt_width: int = 1
262
+ ) -> Dict[str, Any]:
263
+ """
264
+ Get alignment information from attention matrix.
265
+
266
+ Args:
267
+ attention_matrix: Cross-attention tensor [Layers, Heads, Tokens, Frames]
268
+ lyrics_tokens: List of lyrics token IDs
269
+ total_duration_seconds: Total audio duration in seconds
270
+ custom_config: Dict mapping layer indices to head indices
271
+ return_matrices: Whether to return intermediate matrices
272
+ violence_level: Denoising strength
273
+ medfilt_width: Median filter width
274
+
275
+ Returns:
276
+ Dict containing calc_matrix, lyrics_tokens, total_duration_seconds,
277
+ and optionally energy_matrix and vis_matrix
278
+ """
279
+ calc_matrix, energy_matrix, visual_matrix = self._preprocess_attention(
280
+ attention_matrix, custom_config, violence_level, medfilt_width
281
+ )
282
+
283
+ if calc_matrix is None:
284
+ return {
285
+ "calc_matrix": None,
286
+ "lyrics_tokens": lyrics_tokens,
287
+ "total_duration_seconds": total_duration_seconds,
288
+ "error": "No valid attention heads found"
289
+ }
290
+
291
+ return_dict = {
292
+ "calc_matrix": calc_matrix,
293
+ "lyrics_tokens": lyrics_tokens,
294
+ "total_duration_seconds": total_duration_seconds
295
+ }
296
+
297
+ if return_matrices:
298
+ return_dict['energy_matrix'] = energy_matrix
299
+ return_dict['vis_matrix'] = visual_matrix
300
+
301
+ return return_dict
302
+
303
+ def _decode_tokens_incrementally(self, token_ids: List[int]) -> List[str]:
304
+ """
305
+ Decode tokens incrementally to properly handle multi-byte UTF-8 characters.
306
+
307
+ For Chinese and other multi-byte characters, the tokenizer may split them
308
+ into multiple byte-level tokens. Decoding each token individually produces
309
+ invalid UTF-8 sequences (showing as �). This method uses byte-level comparison
310
+ to correctly track which characters each token contributes.
311
+
312
+ Args:
313
+ token_ids: List of token IDs
314
+
315
+ Returns:
316
+ List of decoded text for each token position
317
+ """
318
+ decoded_tokens = []
319
+ prev_bytes = b""
320
+
321
+ for i in range(len(token_ids)):
322
+ # Decode tokens from start to current position
323
+ current_text = self.tokenizer.decode(token_ids[:i+1], skip_special_tokens=False)
324
+ current_bytes = current_text.encode('utf-8', errors='surrogatepass')
325
+
326
+ # The contribution of current token is the new bytes added
327
+ if len(current_bytes) >= len(prev_bytes):
328
+ new_bytes = current_bytes[len(prev_bytes):]
329
+ # Try to decode the new bytes; if incomplete, use empty string
330
+ try:
331
+ token_text = new_bytes.decode('utf-8')
332
+ except UnicodeDecodeError:
333
+ # Incomplete UTF-8 sequence, this token doesn't complete a character
334
+ token_text = ""
335
+ else:
336
+ # Edge case: current decode is shorter (shouldn't happen normally)
337
+ token_text = ""
338
+
339
+ decoded_tokens.append(token_text)
340
+ prev_bytes = current_bytes
341
+
342
+ return decoded_tokens
343
+
344
+ def token_timestamps(
345
+ self,
346
+ calc_matrix: np.ndarray,
347
+ lyrics_tokens: List[int],
348
+ total_duration_seconds: float
349
+ ) -> List[TokenTimestamp]:
350
+ """
351
+ Generate per-token timestamps using DTW.
352
+
353
+ Args:
354
+ calc_matrix: Processed attention matrix [Tokens, Frames]
355
+ lyrics_tokens: List of token IDs
356
+ total_duration_seconds: Total audio duration
357
+
358
+ Returns:
359
+ List of TokenTimestamp objects
360
+ """
361
+ n_frames = calc_matrix.shape[-1]
362
+ text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float64))
363
+
364
+ seconds_per_frame = total_duration_seconds / n_frames
365
+ alignment_results = []
366
+
367
+ # Use incremental decoding to properly handle multi-byte UTF-8 characters
368
+ decoded_tokens = self._decode_tokens_incrementally(lyrics_tokens)
369
+
370
+ for i in range(len(lyrics_tokens)):
371
+ mask = (text_indices == i)
372
+
373
+ if not np.any(mask):
374
+ start = alignment_results[-1].end if alignment_results else 0.0
375
+ end = start
376
+ token_conf = 0.0
377
+ else:
378
+ times = time_indices[mask] * seconds_per_frame
379
+ start = times[0]
380
+ end = times[-1]
381
+ token_conf = 0.0
382
+
383
+ if end < start:
384
+ end = start
385
+
386
+ alignment_results.append(TokenTimestamp(
387
+ token_id=lyrics_tokens[i],
388
+ text=decoded_tokens[i],
389
+ start=float(start),
390
+ end=float(end),
391
+ probability=token_conf
392
+ ))
393
+
394
+ return alignment_results
395
+
396
+ def _decode_sentence_from_tokens(self, tokens: List[TokenTimestamp]) -> str:
397
+ """
398
+ Decode a sentence by decoding all token IDs together.
399
+ This avoids UTF-8 encoding issues from joining individual token texts.
400
+
401
+ Args:
402
+ tokens: List of TokenTimestamp objects
403
+
404
+ Returns:
405
+ Properly decoded sentence text
406
+ """
407
+ token_ids = [t.token_id for t in tokens]
408
+ return self.tokenizer.decode(token_ids, skip_special_tokens=False)
409
+
410
+ def sentence_timestamps(
411
+ self,
412
+ token_alignment: List[TokenTimestamp]
413
+ ) -> List[SentenceTimestamp]:
414
+ """
415
+ Group token timestamps into sentence timestamps.
416
+
417
+ Args:
418
+ token_alignment: List of TokenTimestamp objects
419
+
420
+ Returns:
421
+ List of SentenceTimestamp objects
422
+ """
423
+ results = []
424
+ current_tokens = []
425
+
426
+ for token in token_alignment:
427
+ current_tokens.append(token)
428
+
429
+ if '\n' in token.text:
430
+ # Decode all token IDs together to avoid UTF-8 issues
431
+ full_text = self._decode_sentence_from_tokens(current_tokens)
432
+
433
+ if full_text.strip():
434
+ valid_scores = [t.probability for t in current_tokens if t.probability > 0]
435
+ sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
436
+
437
+ results.append(SentenceTimestamp(
438
+ text=full_text.strip(),
439
+ start=round(current_tokens[0].start, 3),
440
+ end=round(current_tokens[-1].end, 3),
441
+ tokens=list(current_tokens),
442
+ confidence=sent_conf
443
+ ))
444
+
445
+ current_tokens = []
446
+
447
+ # Handle last sentence
448
+ if current_tokens:
449
+ # Decode all token IDs together to avoid UTF-8 issues
450
+ full_text = self._decode_sentence_from_tokens(current_tokens)
451
+ if full_text.strip():
452
+ valid_scores = [t.probability for t in current_tokens if t.probability > 0]
453
+ sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
454
+
455
+ results.append(SentenceTimestamp(
456
+ text=full_text.strip(),
457
+ start=round(current_tokens[0].start, 3),
458
+ end=round(current_tokens[-1].end, 3),
459
+ tokens=list(current_tokens),
460
+ confidence=sent_conf
461
+ ))
462
+
463
+ # Normalize confidence scores
464
+ if results:
465
+ all_scores = [s.confidence for s in results]
466
+ min_score = min(all_scores)
467
+ max_score = max(all_scores)
468
+ score_range = max_score - min_score
469
+
470
+ if score_range > 1e-9:
471
+ for s in results:
472
+ normalized_score = (s.confidence - min_score) / score_range
473
+ s.confidence = round(normalized_score, 2)
474
+ else:
475
+ for s in results:
476
+ s.confidence = round(s.confidence, 2)
477
+
478
+ return results
479
+
480
+ def format_lrc(
481
+ self,
482
+ sentence_timestamps: List[SentenceTimestamp],
483
+ include_end_time: bool = False
484
+ ) -> str:
485
+ """
486
+ Format sentence timestamps as LRC lyrics format.
487
+
488
+ Args:
489
+ sentence_timestamps: List of SentenceTimestamp objects
490
+ include_end_time: Whether to include end time (enhanced LRC format)
491
+
492
+ Returns:
493
+ LRC formatted string
494
+ """
495
+ lines = []
496
+
497
+ for sentence in sentence_timestamps:
498
+ # Convert seconds to mm:ss.xx format
499
+ start_minutes = int(sentence.start // 60)
500
+ start_seconds = sentence.start % 60
501
+
502
+ if include_end_time:
503
+ end_minutes = int(sentence.end // 60)
504
+ end_seconds = sentence.end % 60
505
+ timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}][{end_minutes:02d}:{end_seconds:05.2f}]"
506
+ else:
507
+ timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}]"
508
+
509
+ # Clean the text (remove structural tags like [verse], [chorus])
510
+ text = sentence.text
511
+
512
+ lines.append(f"{timestamp}{text}")
513
+
514
+ return "\n".join(lines)
515
+
516
+ def get_timestamps_and_lrc(
517
+ self,
518
+ calc_matrix: np.ndarray,
519
+ lyrics_tokens: List[int],
520
+ total_duration_seconds: float
521
+ ) -> Dict[str, Any]:
522
+ """
523
+ Convenience method to get both timestamps and LRC in one call.
524
+
525
+ Args:
526
+ calc_matrix: Processed attention matrix
527
+ lyrics_tokens: List of token IDs
528
+ total_duration_seconds: Total audio duration
529
+
530
+ Returns:
531
+ Dict containing token_timestamps, sentence_timestamps, and lrc_text
532
+ """
533
+ token_stamps = self.token_timestamps(
534
+ calc_matrix=calc_matrix,
535
+ lyrics_tokens=lyrics_tokens,
536
+ total_duration_seconds=total_duration_seconds
537
+ )
538
+
539
+ sentence_stamps = self.sentence_timestamps(token_stamps)
540
+ lrc_text = self.format_lrc(sentence_stamps)
541
+
542
+ return {
543
+ "token_timestamps": token_stamps,
544
+ "sentence_timestamps": sentence_stamps,
545
+ "lrc_text": lrc_text
546
+ }
547
+
548
+
549
+ class MusicLyricScorer:
550
+ """
551
+ Scorer class for evaluating lyrics-to-audio alignment quality.
552
+
553
+ Focuses on calculating alignment quality metrics (Coverage, Monotonicity, Confidence)
554
+ using tensor operations for potential differentiability or GPU acceleration.
555
+ """
556
+
557
+ def __init__(self, tokenizer: Any):
558
+ """
559
+ Initialize the aligner.
560
+
561
+ Args:
562
+ tokenizer: Tokenizer instance (must implement .decode()).
563
+ """
564
+ self.tokenizer = tokenizer
565
+
566
+ def _generate_token_type_mask(self, token_ids: List[int]) -> np.ndarray:
567
+ """
568
+ Generate a mask distinguishing lyrics (1) from structural tags (0).
569
+ Uses self.tokenizer to decode tokens.
570
+
571
+ Args:
572
+ token_ids: List of token IDs.
573
+
574
+ Returns:
575
+ Numpy array of shape [len(token_ids)] with 1 or 0.
576
+ """
577
+ decoded_tokens = [self.tokenizer.decode([tid]) for tid in token_ids]
578
+ mask = np.ones(len(token_ids), dtype=np.int32)
579
+ in_bracket = False
580
+
581
+ for i, token_str in enumerate(decoded_tokens):
582
+ if '[' in token_str:
583
+ in_bracket = True
584
+ if in_bracket:
585
+ mask[i] = 0
586
+ if ']' in token_str:
587
+ in_bracket = False
588
+ mask[i] = 0
589
+ return mask
590
+
591
+ def _preprocess_attention(
592
+ self,
593
+ attention_matrix: Union[torch.Tensor, np.ndarray],
594
+ custom_config: Dict[int, List[int]],
595
+ medfilt_width: int = 1
596
+ ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[torch.Tensor]]:
597
+ """
598
+ Extracts and normalizes the attention matrix.
599
+
600
+ Logic V4: Uses Min-Max normalization to highlight energy differences.
601
+
602
+ Args:
603
+ attention_matrix: Raw attention tensor [Layers, Heads, Tokens, Frames].
604
+ custom_config: Config mapping layers to heads.
605
+ medfilt_width: Width for median filtering.
606
+
607
+ Returns:
608
+ Tuple of (calc_matrix, energy_matrix, avg_weights_tensor).
609
+ """
610
+ # 1. Prepare Tensor
611
+ if not isinstance(attention_matrix, torch.Tensor):
612
+ weights = torch.tensor(attention_matrix)
613
+ else:
614
+ weights = attention_matrix.clone()
615
+ weights = weights.cpu().float()
616
+
617
+ # 2. Select Heads based on config
618
+ selected_tensors = []
619
+ for layer_idx, head_indices in custom_config.items():
620
+ for head_idx in head_indices:
621
+ if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
622
+ selected_tensors.append(weights[layer_idx, head_idx])
623
+
624
+ if not selected_tensors:
625
+ return None, None, None
626
+
627
+ weights_stack = torch.stack(selected_tensors, dim=0)
628
+
629
+ # 3. Average Heads
630
+ avg_weights = weights_stack.mean(dim=0) # [Tokens, Frames]
631
+
632
+ # 4. Preprocessing Logic
633
+ # Min-Max normalization preserving energy distribution
634
+ # Median filter is applied to the energy matrix
635
+ energy_tensor = median_filter(avg_weights, filter_width=medfilt_width)
636
+ energy_matrix = energy_tensor.numpy()
637
+
638
+ e_min, e_max = energy_matrix.min(), energy_matrix.max()
639
+
640
+ if e_max - e_min > 1e-9:
641
+ energy_matrix = (energy_matrix - e_min) / (e_max - e_min)
642
+ else:
643
+ energy_matrix = np.zeros_like(energy_matrix)
644
+
645
+ # Contrast enhancement for DTW pathfinding
646
+ # calc_matrix is used for pathfinding, energy_matrix for scoring
647
+ calc_matrix = energy_matrix ** 2
648
+
649
+ return calc_matrix, energy_matrix, avg_weights
650
+
651
+ def _compute_alignment_metrics(
652
+ self,
653
+ energy_matrix: torch.Tensor,
654
+ path_coords: torch.Tensor,
655
+ type_mask: torch.Tensor,
656
+ time_weight: float = 0.01,
657
+ overlap_frames: float = 9.0,
658
+ instrumental_weight: float = 1.0
659
+ ) -> Tuple[float, float, float]:
660
+ """
661
+ Core metric calculation logic using high-precision Tensor operations.
662
+
663
+ Args:
664
+ energy_matrix: Normalized energy [Rows, Cols].
665
+ path_coords: DTW path coordinates [Steps, 2].
666
+ type_mask: Token type mask [Rows] (1=Lyrics, 0=Tags).
667
+ time_weight: Minimum energy threshold for monotonicity.
668
+ overlap_frames: Allowed overlap for monotonicity check.
669
+ instrumental_weight: Weight for non-lyric tokens in confidence calc.
670
+
671
+ Returns:
672
+ Tuple of (coverage, monotonicity, confidence).
673
+ """
674
+ # Ensure high precision for internal calculation
675
+ energy_matrix = energy_matrix.to(dtype=torch.float64)
676
+ path_coords = path_coords.long()
677
+ type_mask = type_mask.long()
678
+
679
+ device = energy_matrix.device
680
+ rows, cols = energy_matrix.shape
681
+
682
+ is_lyrics_row = (type_mask == 1)
683
+
684
+ # ================= A. Coverage Score =================
685
+ # Ratio of lyric lines that have significant energy peak
686
+ row_max_energies = energy_matrix.max(dim=1).values
687
+ total_sung_rows = is_lyrics_row.sum().double()
688
+
689
+ coverage_threshold = 0.1
690
+ valid_sung_mask = is_lyrics_row & (row_max_energies > coverage_threshold)
691
+ valid_sung_rows = valid_sung_mask.sum().double()
692
+
693
+ if total_sung_rows > 0:
694
+ coverage_score = valid_sung_rows / total_sung_rows
695
+ else:
696
+ coverage_score = torch.tensor(1.0, device=device, dtype=torch.float64)
697
+
698
+ # ================= B. Monotonicity Score =================
699
+ # Check if the "center of mass" of lyric lines moves forward in time
700
+ col_indices = torch.arange(cols, device=device, dtype=torch.float64)
701
+
702
+ # Zero out low energy noise
703
+ weights = torch.where(
704
+ energy_matrix > time_weight,
705
+ energy_matrix,
706
+ torch.zeros_like(energy_matrix)
707
+ )
708
+
709
+ sum_w = weights.sum(dim=1)
710
+ sum_t = (weights * col_indices).sum(dim=1)
711
+
712
+ # Calculate centroids
713
+ centroids = torch.full((rows,), -1.0, device=device, dtype=torch.float64)
714
+ valid_w_mask = sum_w > 1e-9
715
+ centroids[valid_w_mask] = sum_t[valid_w_mask] / sum_w[valid_w_mask]
716
+
717
+ # Extract sequence of valid lyrics centroids
718
+ valid_sequence_mask = is_lyrics_row & (centroids >= 0)
719
+ sung_centroids = centroids[valid_sequence_mask]
720
+
721
+ cnt = sung_centroids.shape[0]
722
+ if cnt > 1:
723
+ curr_c = sung_centroids[:-1]
724
+ next_c = sung_centroids[1:]
725
+
726
+ # Check non-decreasing order with overlap tolerance
727
+ non_decreasing = (next_c >= (curr_c - overlap_frames)).double().sum()
728
+ pairs = torch.tensor(cnt - 1, device=device, dtype=torch.float64)
729
+ monotonicity_score = non_decreasing / pairs
730
+ else:
731
+ monotonicity_score = torch.tensor(1.0, device=device, dtype=torch.float64)
732
+
733
+ # ================= C. Path Confidence =================
734
+ # Average energy along the optimal path
735
+ if path_coords.shape[0] > 0:
736
+ p_rows = path_coords[:, 0]
737
+ p_cols = path_coords[:, 1]
738
+
739
+ path_energies = energy_matrix[p_rows, p_cols]
740
+ step_weights = torch.ones_like(path_energies)
741
+
742
+ # Lower weight for instrumental/tag steps
743
+ is_inst_step = (type_mask[p_rows] == 0)
744
+ step_weights[is_inst_step] = instrumental_weight
745
+
746
+ total_energy = (path_energies * step_weights).sum()
747
+ total_steps = step_weights.sum()
748
+
749
+ if total_steps > 0:
750
+ path_confidence = total_energy / total_steps
751
+ else:
752
+ path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
753
+ else:
754
+ path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
755
+
756
+ return coverage_score.item(), monotonicity_score.item(), path_confidence.item()
757
+
758
+ def lyrics_alignment_info(
759
+ self,
760
+ attention_matrix: Union[torch.Tensor, np.ndarray],
761
+ token_ids: List[int],
762
+ custom_config: Dict[int, List[int]],
763
+ return_matrices: bool = False,
764
+ medfilt_width: int = 1
765
+ ) -> Dict[str, Any]:
766
+ """
767
+ Generates alignment path and processed matrices.
768
+
769
+ Args:
770
+ attention_matrix: Input attention tensor.
771
+ token_ids: Corresponding token IDs.
772
+ custom_config: Layer/Head configuration.
773
+ return_matrices: If True, returns matrices in the output.
774
+ medfilt_width: Median filter width.
775
+
776
+ Returns:
777
+ Dict or AlignmentInfo object containing path and masks.
778
+ """
779
+ calc_matrix, energy_matrix, vis_matrix = self._preprocess_attention(
780
+ attention_matrix, custom_config, medfilt_width
781
+ )
782
+
783
+ if calc_matrix is None:
784
+ return {
785
+ "calc_matrix": None,
786
+ "error": "No valid attention heads found"
787
+ }
788
+
789
+ # 1. Generate Semantic Mask (1=Lyrics, 0=Tags)
790
+ # Uses self.tokenizer internally
791
+ type_mask = self._generate_token_type_mask(token_ids)
792
+
793
+ # Safety check for shape mismatch
794
+ if len(type_mask) != energy_matrix.shape[0]:
795
+ # Fallback to all lyrics if shapes don't align
796
+ type_mask = np.ones(energy_matrix.shape[0], dtype=np.int32)
797
+
798
+ # 2. DTW Pathfinding
799
+ # Using negative calc_matrix because DTW minimizes cost
800
+ text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float32))
801
+ path_coords = np.stack([text_indices, time_indices], axis=1)
802
+
803
+ return_dict = {
804
+ "path_coords": path_coords,
805
+ "type_mask": type_mask,
806
+ "energy_matrix": energy_matrix
807
+ }
808
+ if return_matrices:
809
+ return_dict['calc_matrix'] = calc_matrix
810
+ return_dict['vis_matrix'] = vis_matrix
811
+
812
+ return return_dict
813
+
814
+ def calculate_score(
815
+ self,
816
+ energy_matrix: Union[torch.Tensor, np.ndarray],
817
+ type_mask: Union[torch.Tensor, np.ndarray],
818
+ path_coords: Union[torch.Tensor, np.ndarray],
819
+ time_weight: float = 0.01,
820
+ overlap_frames: float = 9.0,
821
+ instrumental_weight: float = 1.0
822
+ ) -> Dict[str, Any]:
823
+ """
824
+ Calculates the final alignment score based on pre-computed components.
825
+
826
+ Args:
827
+ energy_matrix: Processed energy matrix.
828
+ type_mask: Token type mask.
829
+ path_coords: DTW path coordinates.
830
+ time_weight: Minimum energy threshold for monotonicity.
831
+ overlap_frames: Allowed backward movement frames.
832
+ instrumental_weight: Weight for non-lyric path steps.
833
+
834
+ Returns:
835
+ AlignmentScore object containing individual metrics and final score.
836
+ """
837
+ # Ensure Inputs are Tensors on the correct device
838
+ if not isinstance(energy_matrix, torch.Tensor):
839
+ # Use available accelerator device; fallback to CPU if none
840
+ if torch.cuda.is_available():
841
+ _score_device = "cuda"
842
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
843
+ _score_device = "mps"
844
+ else:
845
+ _score_device = "cpu"
846
+ energy_matrix = torch.tensor(energy_matrix, device=_score_device, dtype=torch.float32)
847
+
848
+ device = energy_matrix.device
849
+
850
+ if not isinstance(type_mask, torch.Tensor):
851
+ type_mask = torch.tensor(type_mask, device=device, dtype=torch.long)
852
+ else:
853
+ type_mask = type_mask.to(device=device, dtype=torch.long)
854
+
855
+ if not isinstance(path_coords, torch.Tensor):
856
+ path_coords = torch.tensor(path_coords, device=device, dtype=torch.long)
857
+ else:
858
+ path_coords = path_coords.to(device=device, dtype=torch.long)
859
+
860
+ # Compute Metrics
861
+ coverage, monotonicity, confidence = self._compute_alignment_metrics(
862
+ energy_matrix=energy_matrix,
863
+ path_coords=path_coords,
864
+ type_mask=type_mask,
865
+ time_weight=time_weight,
866
+ overlap_frames=overlap_frames,
867
+ instrumental_weight=instrumental_weight
868
+ )
869
+
870
+ # Final Score Calculation
871
+ # (Cov^2 * Mono^2 * Conf)
872
+ final_score = (coverage ** 2) * (monotonicity ** 2) * confidence
873
+ final_score = float(np.clip(final_score, 0.0, 1.0))
874
+
875
+ return {
876
+ "lyrics_score": round(final_score, 4)
877
+ }
acestep/genres_vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
acestep/gpu_config.py ADDED
@@ -0,0 +1,549 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GPU Configuration Module
3
+ Centralized GPU memory detection and adaptive configuration management
4
+
5
+ Debug Mode:
6
+ Set environment variable MAX_CUDA_VRAM to simulate different GPU memory sizes.
7
+ Example: MAX_CUDA_VRAM=8 python acestep # Simulates 8GB GPU
8
+
9
+ For MPS testing, use MAX_MPS_VRAM to simulate MPS memory.
10
+ Example: MAX_MPS_VRAM=16 python acestep # Simulates 16GB MPS
11
+
12
+ This is useful for testing GPU tier configurations on high-end hardware.
13
+ """
14
+
15
+ import os
16
+ import sys
17
+ from dataclasses import dataclass
18
+ from typing import Optional, List, Dict, Tuple
19
+ from loguru import logger
20
+
21
+
22
+ # Environment variable for debugging/testing different GPU memory configurations
23
+ DEBUG_MAX_CUDA_VRAM_ENV = "MAX_CUDA_VRAM"
24
+ DEBUG_MAX_MPS_VRAM_ENV = "MAX_MPS_VRAM"
25
+
26
+ # Tolerance for 16GB detection: reported VRAM like 15.5GB is effectively 16GB hardware
27
+ # Real-world 16GB GPUs often report 15.7-15.9GB due to system/driver reservations
28
+ VRAM_16GB_TOLERANCE_GB = 0.5
29
+ VRAM_16GB_MIN_GB = 16.0 - VRAM_16GB_TOLERANCE_GB # treat as 16GB class if >= this
30
+
31
+ # PyTorch installation URLs for diagnostics
32
+ PYTORCH_CUDA_INSTALL_URL = "https://download.pytorch.org/whl/cu121"
33
+ PYTORCH_ROCM_INSTALL_URL = "https://download.pytorch.org/whl/rocm6.0"
34
+
35
+
36
+ @dataclass
37
+ class GPUConfig:
38
+ """GPU configuration based on available memory"""
39
+ tier: str # "tier1", "tier2", etc. or "unlimited"
40
+ gpu_memory_gb: float
41
+
42
+ # Duration limits (in seconds)
43
+ max_duration_with_lm: int # When LM is initialized
44
+ max_duration_without_lm: int # When LM is not initialized
45
+
46
+ # Batch size limits
47
+ max_batch_size_with_lm: int
48
+ max_batch_size_without_lm: int
49
+
50
+ # LM configuration
51
+ init_lm_default: bool # Whether to initialize LM by default
52
+ available_lm_models: List[str] # Available LM models for this tier
53
+
54
+ # LM memory allocation (GB) for each model size
55
+ lm_memory_gb: Dict[str, float] # e.g., {"0.6B": 3, "1.7B": 8, "4B": 12}
56
+
57
+
58
+ # GPU tier configurations
59
+ GPU_TIER_CONFIGS = {
60
+ "tier1": { # <= 4GB
61
+ "max_duration_with_lm": 180, # 3 minutes
62
+ "max_duration_without_lm": 180, # 3 minutes
63
+ "max_batch_size_with_lm": 1,
64
+ "max_batch_size_without_lm": 1,
65
+ "init_lm_default": False,
66
+ "available_lm_models": [],
67
+ "lm_memory_gb": {},
68
+ },
69
+ "tier2": { # 4-6GB
70
+ "max_duration_with_lm": 360, # 6 minutes
71
+ "max_duration_without_lm": 360, # 6 minutes
72
+ "max_batch_size_with_lm": 1,
73
+ "max_batch_size_without_lm": 1,
74
+ "init_lm_default": False,
75
+ "available_lm_models": [],
76
+ "lm_memory_gb": {},
77
+ },
78
+ "tier3": { # 6-8GB
79
+ "max_duration_with_lm": 240, # 4 minutes with LM
80
+ "max_duration_without_lm": 360, # 6 minutes without LM
81
+ "max_batch_size_with_lm": 1,
82
+ "max_batch_size_without_lm": 2,
83
+ "init_lm_default": False, # Don't init by default due to limited memory
84
+ "available_lm_models": ["acestep-5Hz-lm-0.6B"],
85
+ "lm_memory_gb": {"0.6B": 3},
86
+ },
87
+ "tier4": { # 8-12GB
88
+ "max_duration_with_lm": 240, # 4 minutes with LM
89
+ "max_duration_without_lm": 360, # 6 minutes without LM
90
+ "max_batch_size_with_lm": 2,
91
+ "max_batch_size_without_lm": 4,
92
+ "init_lm_default": False, # Don't init by default
93
+ "available_lm_models": ["acestep-5Hz-lm-0.6B"],
94
+ "lm_memory_gb": {"0.6B": 3},
95
+ },
96
+ "tier5": { # 12-16GB
97
+ "max_duration_with_lm": 240, # 4 minutes with LM
98
+ "max_duration_without_lm": 360, # 6 minutes without LM
99
+ "max_batch_size_with_lm": 2,
100
+ "max_batch_size_without_lm": 4,
101
+ "init_lm_default": True,
102
+ "available_lm_models": ["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B"],
103
+ "lm_memory_gb": {"0.6B": 3, "1.7B": 8},
104
+ },
105
+ "tier6": { # 16-24GB
106
+ "max_duration_with_lm": 480, # 8 minutes
107
+ "max_duration_without_lm": 480, # 8 minutes
108
+ "max_batch_size_with_lm": 4,
109
+ "max_batch_size_without_lm": 8,
110
+ "init_lm_default": True,
111
+ "available_lm_models": ["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B", "acestep-5Hz-lm-4B"],
112
+ "lm_memory_gb": {"0.6B": 3, "1.7B": 8, "4B": 12},
113
+ },
114
+ "unlimited": { # >= 24GB
115
+ "max_duration_with_lm": 600, # 10 minutes (max supported)
116
+ "max_duration_without_lm": 600, # 10 minutes
117
+ "max_batch_size_with_lm": 8,
118
+ "max_batch_size_without_lm": 8,
119
+ "init_lm_default": True,
120
+ "available_lm_models": ["acestep-5Hz-lm-0.6B", "acestep-5Hz-lm-1.7B", "acestep-5Hz-lm-4B"],
121
+ "lm_memory_gb": {"0.6B": 3, "1.7B": 8, "4B": 12},
122
+ },
123
+ }
124
+
125
+
126
+ def get_gpu_memory_gb() -> float:
127
+ """
128
+ Get GPU memory in GB. Returns 0 if no GPU is available.
129
+
130
+ Debug Mode:
131
+ Set environment variable MAX_CUDA_VRAM to override the detected GPU memory.
132
+ Example: MAX_CUDA_VRAM=8 python acestep # Simulates 8GB GPU
133
+
134
+ For MPS testing, set MAX_MPS_VRAM to override MPS memory detection.
135
+ Example: MAX_MPS_VRAM=16 python acestep # Simulates 16GB MPS
136
+
137
+ This allows testing different GPU tier configurations on high-end hardware.
138
+ """
139
+ # Check for debug override first
140
+ debug_vram = os.environ.get(DEBUG_MAX_CUDA_VRAM_ENV)
141
+ if debug_vram is not None:
142
+ try:
143
+ simulated_gb = float(debug_vram)
144
+ logger.warning(f"⚠️ DEBUG MODE: Simulating GPU memory as {simulated_gb:.1f}GB (set via {DEBUG_MAX_CUDA_VRAM_ENV} environment variable)")
145
+ return simulated_gb
146
+ except ValueError:
147
+ logger.warning(f"Invalid {DEBUG_MAX_CUDA_VRAM_ENV} value: {debug_vram}, ignoring")
148
+ debug_mps_vram = os.environ.get(DEBUG_MAX_MPS_VRAM_ENV)
149
+ if debug_mps_vram is not None:
150
+ try:
151
+ simulated_gb = float(debug_mps_vram)
152
+ logger.warning(f"⚠️ DEBUG MODE: Simulating MPS memory as {simulated_gb:.1f}GB (set via {DEBUG_MAX_MPS_VRAM_ENV} environment variable)")
153
+ return simulated_gb
154
+ except ValueError:
155
+ logger.warning(f"Invalid {DEBUG_MAX_MPS_VRAM_ENV} value: {debug_mps_vram}, ignoring")
156
+
157
+ try:
158
+ import torch
159
+ if torch.cuda.is_available():
160
+ # Get total memory of the first GPU in GB
161
+ total_memory = torch.cuda.get_device_properties(0).total_memory
162
+ memory_gb = total_memory / (1024**3) # Convert bytes to GB
163
+ device_name = torch.cuda.get_device_name(0)
164
+ is_rocm = hasattr(torch.version, 'hip') and torch.version.hip is not None
165
+ if is_rocm:
166
+ logger.info(f"ROCm GPU detected: {device_name} ({memory_gb:.1f} GB, HIP {torch.version.hip})")
167
+ else:
168
+ logger.info(f"CUDA GPU detected: {device_name} ({memory_gb:.1f} GB)")
169
+ return memory_gb
170
+ elif hasattr(torch, 'xpu') and torch.xpu.is_available():
171
+ # Get total memory of the first XPU in GB
172
+ total_memory = torch.xpu.get_device_properties(0).total_memory
173
+ memory_gb = total_memory / (1024**3) # Convert bytes to GB
174
+ return memory_gb
175
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
176
+ mps_module = getattr(torch, "mps", None)
177
+ try:
178
+ if mps_module is not None and hasattr(mps_module, "recommended_max_memory"):
179
+ total_memory = mps_module.recommended_max_memory()
180
+ memory_gb = total_memory / (1024**3) # Convert bytes to GB
181
+ return memory_gb
182
+ if mps_module is not None and hasattr(mps_module, "get_device_properties"):
183
+ props = mps_module.get_device_properties(0)
184
+ total_memory = getattr(props, "total_memory", None)
185
+ if total_memory:
186
+ memory_gb = total_memory / (1024**3)
187
+ return memory_gb
188
+ except Exception as e:
189
+ logger.warning(f"Failed to detect MPS memory: {e}")
190
+
191
+ # Fallback: estimate from system unified memory (Apple Silicon shares CPU/GPU RAM)
192
+ try:
193
+ import subprocess
194
+ result = subprocess.run(
195
+ ["sysctl", "-n", "hw.memsize"],
196
+ capture_output=True, text=True, timeout=5
197
+ )
198
+ total_system_bytes = int(result.stdout.strip())
199
+ # MPS can use up to ~75% of unified memory for GPU workloads
200
+ memory_gb = (total_system_bytes / (1024**3)) * 0.75
201
+ return memory_gb
202
+ except Exception:
203
+ logger.warning(f"MPS available but total memory not exposed. Set {DEBUG_MAX_MPS_VRAM_ENV} to enable tiering.")
204
+ # Conservative fallback for M1/M2
205
+ return 8.0
206
+ else:
207
+ # No GPU detected - provide diagnostic information
208
+ _log_gpu_diagnostic_info(torch)
209
+ return 0
210
+ except Exception as e:
211
+ logger.warning(f"Failed to detect GPU memory: {e}")
212
+ return 0
213
+
214
+
215
+ def _log_gpu_diagnostic_info(torch_module):
216
+ """
217
+ Log diagnostic information when GPU is not detected to help users troubleshoot.
218
+
219
+ Args:
220
+ torch_module: The torch module to inspect for build information
221
+ """
222
+ logger.warning("=" * 80)
223
+ logger.warning("⚠️ GPU NOT DETECTED - DIAGNOSTIC INFORMATION")
224
+ logger.warning("=" * 80)
225
+
226
+ # Check PyTorch build type
227
+ is_rocm_build = hasattr(torch_module.version, 'hip') and torch_module.version.hip is not None
228
+ is_cuda_build = hasattr(torch_module.version, 'cuda') and torch_module.version.cuda is not None
229
+
230
+ if is_rocm_build:
231
+ logger.warning("✓ PyTorch ROCm build detected")
232
+ logger.warning(f" HIP version: {torch_module.version.hip}")
233
+ logger.warning("")
234
+ logger.warning("❌ torch.cuda.is_available() returned False")
235
+ logger.warning("")
236
+ logger.warning("Common causes for AMD/ROCm GPUs:")
237
+ logger.warning(" 1. ROCm drivers not installed or not properly configured")
238
+ logger.warning(" 2. GPU not supported by installed ROCm version")
239
+ logger.warning(" 3. Missing or incorrect HSA_OVERRIDE_GFX_VERSION environment variable")
240
+ logger.warning(" 4. ROCm runtime libraries not in system path")
241
+ logger.warning("")
242
+
243
+ # Check for common environment variables
244
+ hsa_override = os.environ.get('HSA_OVERRIDE_GFX_VERSION')
245
+ if hsa_override:
246
+ logger.warning(f" HSA_OVERRIDE_GFX_VERSION is set to: {hsa_override}")
247
+ else:
248
+ logger.warning(" ⚠️ HSA_OVERRIDE_GFX_VERSION is not set")
249
+ logger.warning(" For RDNA3 GPUs (RX 7000 series, RX 9000 series):")
250
+ logger.warning(" - RX 7900 XT/XTX, RX 9070 XT: set HSA_OVERRIDE_GFX_VERSION=11.0.0")
251
+ logger.warning(" - RX 7800 XT, RX 7700 XT: set HSA_OVERRIDE_GFX_VERSION=11.0.1")
252
+ logger.warning(" - RX 7600: set HSA_OVERRIDE_GFX_VERSION=11.0.2")
253
+
254
+ logger.warning("")
255
+ logger.warning("Troubleshooting steps:")
256
+ logger.warning(" 1. Verify ROCm installation:")
257
+ logger.warning(" rocm-smi # Should list your GPU")
258
+ logger.warning(" 2. Check PyTorch ROCm build:")
259
+ logger.warning(" python -c \"import torch; print(f'ROCm: {torch.version.hip}')\"")
260
+ logger.warning(" 3. Set HSA_OVERRIDE_GFX_VERSION for your GPU (see above)")
261
+ logger.warning(" 4. On Windows: Use start_gradio_ui_rocm.bat which sets required env vars")
262
+ logger.warning(" 5. See docs/en/ACE-Step1.5-Rocm-Manual-Linux.md for Linux setup")
263
+ logger.warning(" 6. See requirements-rocm.txt for Windows ROCm setup instructions")
264
+
265
+ elif is_cuda_build:
266
+ logger.warning("✓ PyTorch CUDA build detected")
267
+ logger.warning(f" CUDA version: {torch_module.version.cuda}")
268
+ logger.warning("")
269
+ logger.warning("❌ torch.cuda.is_available() returned False")
270
+ logger.warning("")
271
+ logger.warning("Common causes for NVIDIA GPUs:")
272
+ logger.warning(" 1. NVIDIA drivers not installed")
273
+ logger.warning(" 2. CUDA runtime not installed or version mismatch")
274
+ logger.warning(" 3. GPU not supported by installed CUDA version")
275
+ logger.warning("")
276
+ logger.warning("Troubleshooting steps:")
277
+ logger.warning(" 1. Verify NVIDIA driver installation:")
278
+ logger.warning(" nvidia-smi # Should list your GPU")
279
+ logger.warning(" 2. Check CUDA version compatibility")
280
+ logger.warning(" 3. Reinstall PyTorch with CUDA support:")
281
+ logger.warning(f" pip install torch --index-url {PYTORCH_CUDA_INSTALL_URL}")
282
+
283
+ else:
284
+ logger.warning("⚠️ PyTorch build type: CPU-only")
285
+ logger.warning("")
286
+ logger.warning("You have installed a CPU-only version of PyTorch!")
287
+ logger.warning("")
288
+ logger.warning("For NVIDIA GPUs:")
289
+ logger.warning(f" pip install torch --index-url {PYTORCH_CUDA_INSTALL_URL}")
290
+ logger.warning("")
291
+ logger.warning("For AMD GPUs with ROCm:")
292
+ logger.warning(" Windows: See requirements-rocm.txt for detailed instructions")
293
+ logger.warning(f" Linux: pip install torch --index-url {PYTORCH_ROCM_INSTALL_URL}")
294
+ logger.warning("")
295
+ logger.warning("For more information, see README.md section 'AMD / ROCm GPUs'")
296
+
297
+ logger.warning("=" * 80)
298
+
299
+
300
+ def get_gpu_tier(gpu_memory_gb: float) -> str:
301
+ """
302
+ Determine GPU tier based on available memory.
303
+
304
+ Args:
305
+ gpu_memory_gb: GPU memory in GB
306
+
307
+ Returns:
308
+ Tier string: "tier1", "tier2", "tier3", "tier4", "tier5", "tier6", or "unlimited"
309
+ """
310
+ if gpu_memory_gb <= 0:
311
+ # CPU mode - use tier1 limits
312
+ return "tier1"
313
+ elif gpu_memory_gb <= 4:
314
+ return "tier1"
315
+ elif gpu_memory_gb <= 6:
316
+ return "tier2"
317
+ elif gpu_memory_gb <= 8:
318
+ return "tier3"
319
+ elif gpu_memory_gb <= 12:
320
+ return "tier4"
321
+ elif gpu_memory_gb < VRAM_16GB_MIN_GB:
322
+ return "tier5"
323
+ elif gpu_memory_gb <= 24:
324
+ if gpu_memory_gb < 16.0:
325
+ logger.info(f"Detected {gpu_memory_gb:.2f}GB VRAM — treating as 16GB class GPU")
326
+ return "tier6"
327
+ else:
328
+ return "unlimited"
329
+
330
+
331
+ def get_gpu_config(gpu_memory_gb: Optional[float] = None) -> GPUConfig:
332
+ """
333
+ Get GPU configuration based on detected or provided GPU memory.
334
+
335
+ Args:
336
+ gpu_memory_gb: GPU memory in GB. If None, will be auto-detected.
337
+
338
+ Returns:
339
+ GPUConfig object with all configuration parameters
340
+ """
341
+ if gpu_memory_gb is None:
342
+ gpu_memory_gb = get_gpu_memory_gb()
343
+
344
+ tier = get_gpu_tier(gpu_memory_gb)
345
+ config = GPU_TIER_CONFIGS[tier]
346
+
347
+ return GPUConfig(
348
+ tier=tier,
349
+ gpu_memory_gb=gpu_memory_gb,
350
+ max_duration_with_lm=config["max_duration_with_lm"],
351
+ max_duration_without_lm=config["max_duration_without_lm"],
352
+ max_batch_size_with_lm=config["max_batch_size_with_lm"],
353
+ max_batch_size_without_lm=config["max_batch_size_without_lm"],
354
+ init_lm_default=config["init_lm_default"],
355
+ available_lm_models=config["available_lm_models"],
356
+ lm_memory_gb=config["lm_memory_gb"],
357
+ )
358
+
359
+
360
+ def get_lm_model_size(model_path: str) -> str:
361
+ """
362
+ Extract LM model size from model path.
363
+
364
+ Args:
365
+ model_path: Model path string (e.g., "acestep-5Hz-lm-0.6B")
366
+
367
+ Returns:
368
+ Model size string: "0.6B", "1.7B", or "4B"
369
+ """
370
+ if "0.6B" in model_path:
371
+ return "0.6B"
372
+ elif "1.7B" in model_path:
373
+ return "1.7B"
374
+ elif "4B" in model_path:
375
+ return "4B"
376
+ else:
377
+ # Default to smallest model assumption
378
+ return "0.6B"
379
+
380
+
381
+ def get_lm_gpu_memory_ratio(model_path: str, total_gpu_memory_gb: float) -> Tuple[float, float]:
382
+ """
383
+ Calculate GPU memory utilization ratio for LM model.
384
+
385
+ Args:
386
+ model_path: LM model path (e.g., "acestep-5Hz-lm-0.6B")
387
+ total_gpu_memory_gb: Total GPU memory in GB
388
+
389
+ Returns:
390
+ Tuple of (gpu_memory_utilization_ratio, target_memory_gb)
391
+ """
392
+ model_size = get_lm_model_size(model_path)
393
+
394
+ # Target memory allocation for each model size
395
+ target_memory = {
396
+ "0.6B": 3.0,
397
+ "1.7B": 8.0,
398
+ "4B": 12.0,
399
+ }
400
+
401
+ target_gb = target_memory.get(model_size, 3.0)
402
+
403
+ # For large GPUs (>=24GB), don't restrict memory too much
404
+ if total_gpu_memory_gb >= 24:
405
+ # Use a reasonable ratio that allows the model to run efficiently
406
+ ratio = min(0.9, max(0.2, target_gb / total_gpu_memory_gb))
407
+ else:
408
+ # For smaller GPUs, strictly limit memory usage
409
+ ratio = min(0.9, max(0.1, target_gb / total_gpu_memory_gb))
410
+
411
+ return ratio, target_gb
412
+
413
+
414
+ def check_duration_limit(
415
+ duration: float,
416
+ gpu_config: GPUConfig,
417
+ lm_initialized: bool
418
+ ) -> Tuple[bool, str]:
419
+ """
420
+ Check if requested duration is within limits for current GPU configuration.
421
+
422
+ Args:
423
+ duration: Requested duration in seconds
424
+ gpu_config: Current GPU configuration
425
+ lm_initialized: Whether LM is initialized
426
+
427
+ Returns:
428
+ Tuple of (is_valid, warning_message)
429
+ """
430
+ max_duration = gpu_config.max_duration_with_lm if lm_initialized else gpu_config.max_duration_without_lm
431
+
432
+ if duration > max_duration:
433
+ warning_msg = (
434
+ f"⚠️ Requested duration ({duration:.0f}s) exceeds the limit for your GPU "
435
+ f"({gpu_config.gpu_memory_gb:.1f}GB). Maximum allowed: {max_duration}s "
436
+ f"({'with' if lm_initialized else 'without'} LM). "
437
+ f"Duration will be clamped to {max_duration}s."
438
+ )
439
+ return False, warning_msg
440
+
441
+ return True, ""
442
+
443
+
444
+ def check_batch_size_limit(
445
+ batch_size: int,
446
+ gpu_config: GPUConfig,
447
+ lm_initialized: bool
448
+ ) -> Tuple[bool, str]:
449
+ """
450
+ Check if requested batch size is within limits for current GPU configuration.
451
+
452
+ Args:
453
+ batch_size: Requested batch size
454
+ gpu_config: Current GPU configuration
455
+ lm_initialized: Whether LM is initialized
456
+
457
+ Returns:
458
+ Tuple of (is_valid, warning_message)
459
+ """
460
+ max_batch_size = gpu_config.max_batch_size_with_lm if lm_initialized else gpu_config.max_batch_size_without_lm
461
+
462
+ if batch_size > max_batch_size:
463
+ warning_msg = (
464
+ f"⚠️ Requested batch size ({batch_size}) exceeds the limit for your GPU "
465
+ f"({gpu_config.gpu_memory_gb:.1f}GB). Maximum allowed: {max_batch_size} "
466
+ f"({'with' if lm_initialized else 'without'} LM). "
467
+ f"Batch size will be clamped to {max_batch_size}."
468
+ )
469
+ return False, warning_msg
470
+
471
+ return True, ""
472
+
473
+
474
+ def is_lm_model_supported(model_path: str, gpu_config: GPUConfig) -> Tuple[bool, str]:
475
+ """
476
+ Check if the specified LM model is supported for current GPU configuration.
477
+
478
+ Args:
479
+ model_path: LM model path
480
+ gpu_config: Current GPU configuration
481
+
482
+ Returns:
483
+ Tuple of (is_supported, warning_message)
484
+ """
485
+ if not gpu_config.available_lm_models:
486
+ return False, (
487
+ f"⚠️ Your GPU ({gpu_config.gpu_memory_gb:.1f}GB) does not have enough memory "
488
+ f"to run any LM model. Please disable LM initialization."
489
+ )
490
+
491
+ model_size = get_lm_model_size(model_path)
492
+
493
+ # Check if model size is in available models
494
+ for available_model in gpu_config.available_lm_models:
495
+ if model_size in available_model:
496
+ return True, ""
497
+
498
+ return False, (
499
+ f"⚠️ LM model {model_path} ({model_size}) is not supported for your GPU "
500
+ f"({gpu_config.gpu_memory_gb:.1f}GB). Available models: {', '.join(gpu_config.available_lm_models)}"
501
+ )
502
+
503
+
504
+ def get_recommended_lm_model(gpu_config: GPUConfig) -> Optional[str]:
505
+ """
506
+ Get recommended LM model for current GPU configuration.
507
+
508
+ Args:
509
+ gpu_config: Current GPU configuration
510
+
511
+ Returns:
512
+ Recommended LM model path, or None if LM is not supported
513
+ """
514
+ if not gpu_config.available_lm_models:
515
+ return None
516
+
517
+ # Return the largest available model (last in the list)
518
+ return gpu_config.available_lm_models[-1]
519
+
520
+
521
+ def print_gpu_config_info(gpu_config: GPUConfig):
522
+ """Print GPU configuration information for debugging."""
523
+ logger.info(f"GPU Configuration:")
524
+ logger.info(f" - GPU Memory: {gpu_config.gpu_memory_gb:.1f} GB")
525
+ logger.info(f" - Tier: {gpu_config.tier}")
526
+ logger.info(f" - Max Duration (with LM): {gpu_config.max_duration_with_lm}s ({gpu_config.max_duration_with_lm // 60} min)")
527
+ logger.info(f" - Max Duration (without LM): {gpu_config.max_duration_without_lm}s ({gpu_config.max_duration_without_lm // 60} min)")
528
+ logger.info(f" - Max Batch Size (with LM): {gpu_config.max_batch_size_with_lm}")
529
+ logger.info(f" - Max Batch Size (without LM): {gpu_config.max_batch_size_without_lm}")
530
+ logger.info(f" - Init LM by Default: {gpu_config.init_lm_default}")
531
+ logger.info(f" - Available LM Models: {gpu_config.available_lm_models or 'None'}")
532
+
533
+
534
+ # Global GPU config instance (initialized lazily)
535
+ _global_gpu_config: Optional[GPUConfig] = None
536
+
537
+
538
+ def get_global_gpu_config() -> GPUConfig:
539
+ """Get the global GPU configuration, initializing if necessary."""
540
+ global _global_gpu_config
541
+ if _global_gpu_config is None:
542
+ _global_gpu_config = get_gpu_config()
543
+ return _global_gpu_config
544
+
545
+
546
+ def set_global_gpu_config(config: GPUConfig):
547
+ """Set the global GPU configuration."""
548
+ global _global_gpu_config
549
+ _global_gpu_config = config
acestep/gradio_ui/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from acestep.gradio_ui.interfaces import create_gradio_interface
acestep/gradio_ui/api_routes.py ADDED
@@ -0,0 +1,564 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio API Routes Module
3
+ Add API endpoints compatible with api_server.py and CustomAceStep to Gradio application
4
+ """
5
+ import json
6
+ import os
7
+ import random
8
+ import time
9
+ from typing import Any, Dict, List, Optional
10
+ from uuid import uuid4
11
+
12
+ from fastapi import APIRouter, HTTPException, Request, Depends, Header
13
+ from fastapi.responses import FileResponse
14
+
15
+ # Global results directory inside project root
16
+ PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
17
+ DEFAULT_RESULTS_DIR = os.path.join(PROJECT_ROOT, "gradio_outputs").replace("\\", "/")
18
+ os.makedirs(DEFAULT_RESULTS_DIR, exist_ok=True)
19
+
20
+ # API Key storage (set via setup_api_routes)
21
+ _api_key: Optional[str] = None
22
+
23
+
24
+ def set_api_key(key: Optional[str]):
25
+ """Set the API key for authentication"""
26
+ global _api_key
27
+ _api_key = key
28
+
29
+
30
+ def _wrap_response(data: Any, code: int = 200, error: Optional[str] = None) -> Dict[str, Any]:
31
+ """Wrap response data in standard format compatible with CustomAceStep."""
32
+ return {
33
+ "data": data,
34
+ "code": code,
35
+ "error": error,
36
+ "timestamp": int(time.time() * 1000),
37
+ "extra": None,
38
+ }
39
+
40
+
41
+ def verify_token_from_request(body: dict, authorization: Optional[str] = None) -> Optional[str]:
42
+ """
43
+ Verify API key from request body (ai_token) or Authorization header.
44
+ Returns the token if valid, None if no auth required.
45
+ """
46
+ if _api_key is None:
47
+ return None # No auth required
48
+
49
+ # Try ai_token from body first
50
+ ai_token = body.get("ai_token") if body else None
51
+ if ai_token:
52
+ if ai_token == _api_key:
53
+ return ai_token
54
+ raise HTTPException(status_code=401, detail="Invalid ai_token")
55
+
56
+ # Fallback to Authorization header
57
+ if authorization:
58
+ if authorization.startswith("Bearer "):
59
+ token = authorization[7:]
60
+ else:
61
+ token = authorization
62
+ if token == _api_key:
63
+ return token
64
+ raise HTTPException(status_code=401, detail="Invalid API key")
65
+
66
+ # No token provided but auth is required
67
+ raise HTTPException(status_code=401, detail="Missing ai_token or Authorization header")
68
+
69
+
70
+ async def verify_api_key(authorization: Optional[str] = Header(None)):
71
+ """Verify API key from Authorization header (legacy, for non-body endpoints)"""
72
+ if _api_key is None:
73
+ return # No auth required
74
+
75
+ if not authorization:
76
+ raise HTTPException(status_code=401, detail="Missing Authorization header")
77
+
78
+ # Support "Bearer <key>" format
79
+ if authorization.startswith("Bearer "):
80
+ token = authorization[7:]
81
+ else:
82
+ token = authorization
83
+
84
+ if token != _api_key:
85
+ raise HTTPException(status_code=401, detail="Invalid API key")
86
+
87
+
88
+ # Use diskcache to store results
89
+ try:
90
+ import diskcache
91
+ _cache_dir = os.path.join(os.path.dirname(__file__), ".cache", "api_results")
92
+ os.makedirs(_cache_dir, exist_ok=True)
93
+ _result_cache = diskcache.Cache(_cache_dir)
94
+ DISKCACHE_AVAILABLE = True
95
+ except ImportError:
96
+ _result_cache = {}
97
+ DISKCACHE_AVAILABLE = False
98
+
99
+ RESULT_EXPIRE_SECONDS = 7 * 24 * 60 * 60 # 7 days expiration
100
+ RESULT_KEY_PREFIX = "ace_step_v1.5_"
101
+
102
+ # =============================================================================
103
+ # Example Data for Random Sample
104
+ # =============================================================================
105
+
106
+ def _get_project_root() -> str:
107
+ """Get project root directory"""
108
+ return os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
109
+
110
+
111
+ def _load_all_examples(sample_mode: str = "simple_mode") -> List[Dict[str, Any]]:
112
+ """Load all example JSON files from examples directory"""
113
+ project_root = _get_project_root()
114
+ if sample_mode == "simple_mode":
115
+ examples_dir = os.path.join(project_root, "examples", "simple_mode")
116
+ else:
117
+ examples_dir = os.path.join(project_root, "examples", "text2music")
118
+
119
+ if not os.path.isdir(examples_dir):
120
+ return []
121
+
122
+ all_examples = []
123
+ for filename in os.listdir(examples_dir):
124
+ if filename.endswith(".json"):
125
+ filepath = os.path.join(examples_dir, filename)
126
+ try:
127
+ with open(filepath, "r", encoding="utf-8") as f:
128
+ data = json.load(f)
129
+ if isinstance(data, list):
130
+ all_examples.extend(data)
131
+ elif isinstance(data, dict):
132
+ all_examples.append(data)
133
+ except Exception:
134
+ pass
135
+ return all_examples
136
+
137
+
138
+ # Pre-load example data
139
+ SIMPLE_EXAMPLE_DATA = _load_all_examples("simple_mode")
140
+ CUSTOM_EXAMPLE_DATA = _load_all_examples("custom_mode")
141
+
142
+
143
+ def store_result(task_id: str, result: dict, status: str = "succeeded"):
144
+ """Store result to diskcache"""
145
+ data = {
146
+ "result": result,
147
+ "created_at": time.time(),
148
+ "status": status
149
+ }
150
+ key = f"{RESULT_KEY_PREFIX}{task_id}"
151
+ if DISKCACHE_AVAILABLE:
152
+ _result_cache.set(key, data, expire=RESULT_EXPIRE_SECONDS)
153
+ else:
154
+ _result_cache[key] = data
155
+
156
+
157
+ def get_result(task_id: str) -> Optional[dict]:
158
+ """Get result from diskcache"""
159
+ key = f"{RESULT_KEY_PREFIX}{task_id}"
160
+ if DISKCACHE_AVAILABLE:
161
+ return _result_cache.get(key)
162
+ else:
163
+ return _result_cache.get(key)
164
+
165
+
166
+ router = APIRouter()
167
+
168
+
169
+ @router.get("/health")
170
+ async def health_check():
171
+ """Health check endpoint"""
172
+ return _wrap_response({
173
+ "status": "ok",
174
+ "service": "ACE-Step Gradio API",
175
+ "version": "1.0",
176
+ })
177
+
178
+
179
+ @router.get("/v1/models")
180
+ async def list_models(request: Request, _: None = Depends(verify_api_key)):
181
+ """List available DiT models"""
182
+ dit_handler = request.app.state.dit_handler
183
+
184
+ models = []
185
+ if dit_handler and dit_handler.model is not None:
186
+ # Get current loaded model name
187
+ config_path = getattr(dit_handler, 'config_path', '') or ''
188
+ model_name = os.path.basename(config_path.rstrip("/\\")) if config_path else "unknown"
189
+ models.append({
190
+ "name": model_name,
191
+ "is_default": True,
192
+ })
193
+
194
+ return _wrap_response({
195
+ "models": models,
196
+ "default_model": models[0]["name"] if models else None,
197
+ })
198
+
199
+
200
+ @router.get("/v1/audio")
201
+ async def get_audio(path: str, _: None = Depends(verify_api_key)):
202
+ """Download audio file"""
203
+ # Security: Validate path is within allowed directory to prevent path traversal
204
+ resolved_path = os.path.realpath(path)
205
+ allowed_dir = os.path.realpath(DEFAULT_RESULTS_DIR)
206
+ if not resolved_path.startswith(allowed_dir + os.sep) and resolved_path != allowed_dir:
207
+ raise HTTPException(status_code=403, detail="Access denied: path outside allowed directory")
208
+ if not os.path.exists(resolved_path):
209
+ raise HTTPException(status_code=404, detail="Audio file not found")
210
+
211
+ ext = os.path.splitext(resolved_path)[1].lower()
212
+ media_types = {
213
+ ".mp3": "audio/mpeg",
214
+ ".wav": "audio/wav",
215
+ ".flac": "audio/flac",
216
+ ".ogg": "audio/ogg",
217
+ }
218
+ media_type = media_types.get(ext, "audio/mpeg")
219
+
220
+ return FileResponse(resolved_path, media_type=media_type)
221
+
222
+
223
+ @router.post("/create_random_sample")
224
+ async def create_random_sample(request: Request, authorization: Optional[str] = Header(None)):
225
+ """Get random sample parameters from pre-loaded example data"""
226
+ content_type = (request.headers.get("content-type") or "").lower()
227
+
228
+ if "json" in content_type:
229
+ body = await request.json()
230
+ else:
231
+ form = await request.form()
232
+ body = {k: v for k, v in form.items()}
233
+
234
+ verify_token_from_request(body, authorization)
235
+ sample_type = body.get("sample_type", "simple_mode") or "simple_mode"
236
+
237
+ if sample_type == "simple_mode":
238
+ example_data = SIMPLE_EXAMPLE_DATA
239
+ else:
240
+ example_data = CUSTOM_EXAMPLE_DATA
241
+
242
+ if not example_data:
243
+ return _wrap_response(None, code=500, error="No example data available")
244
+
245
+ random_example = random.choice(example_data)
246
+ return _wrap_response(random_example)
247
+
248
+
249
+ @router.post("/query_result")
250
+ async def query_result(request: Request, authorization: Optional[str] = Header(None)):
251
+ """Batch query task results"""
252
+ content_type = (request.headers.get("content-type") or "").lower()
253
+
254
+ if "json" in content_type:
255
+ body = await request.json()
256
+ else:
257
+ form = await request.form()
258
+ body = {k: v for k, v in form.items()}
259
+
260
+ verify_token_from_request(body, authorization)
261
+ task_ids = body.get("task_id_list", [])
262
+
263
+ if isinstance(task_ids, str):
264
+ try:
265
+ task_ids = json.loads(task_ids)
266
+ except Exception:
267
+ task_ids = []
268
+
269
+ results = []
270
+ for task_id in task_ids:
271
+ data = get_result(task_id)
272
+ if data and data.get("status") == "succeeded":
273
+ results.append({
274
+ "task_id": task_id,
275
+ "status": 1,
276
+ "result": json.dumps(data["result"], ensure_ascii=False)
277
+ })
278
+ else:
279
+ results.append({
280
+ "task_id": task_id,
281
+ "status": 0,
282
+ "result": "[]"
283
+ })
284
+
285
+ return _wrap_response(results)
286
+
287
+
288
+ @router.post("/format_input")
289
+ async def format_input(request: Request, authorization: Optional[str] = Header(None)):
290
+ """Format and enhance lyrics/caption via LLM"""
291
+ llm_handler = request.app.state.llm_handler
292
+
293
+ if not llm_handler or not llm_handler.llm_initialized:
294
+ return _wrap_response(None, code=500, error="LLM not initialized")
295
+
296
+ content_type = (request.headers.get("content-type") or "").lower()
297
+ if "json" in content_type:
298
+ body = await request.json()
299
+ else:
300
+ form = await request.form()
301
+ body = {k: v for k, v in form.items()}
302
+
303
+ verify_token_from_request(body, authorization)
304
+
305
+ caption = body.get("prompt", "") or ""
306
+ lyrics = body.get("lyrics", "") or ""
307
+ temperature = float(body.get("temperature", 0.85))
308
+
309
+ from acestep.inference import format_sample
310
+
311
+ try:
312
+ result = format_sample(
313
+ llm_handler=llm_handler,
314
+ caption=caption,
315
+ lyrics=lyrics,
316
+ temperature=temperature,
317
+ use_constrained_decoding=True,
318
+ )
319
+
320
+ if not result.success:
321
+ return _wrap_response(None, code=500, error=result.status_message)
322
+
323
+ return _wrap_response({
324
+ "caption": result.caption or caption,
325
+ "lyrics": result.lyrics or lyrics,
326
+ "bpm": result.bpm,
327
+ "key_scale": result.keyscale,
328
+ "time_signature": result.timesignature,
329
+ "duration": result.duration,
330
+ "vocal_language": result.language or "unknown",
331
+ })
332
+ except Exception as e:
333
+ return _wrap_response(None, code=500, error=str(e))
334
+
335
+
336
+ @router.post("/release_task")
337
+ async def release_task(request: Request, authorization: Optional[str] = Header(None)):
338
+ """Create music generation task"""
339
+ dit_handler = request.app.state.dit_handler
340
+ llm_handler = request.app.state.llm_handler
341
+
342
+ if not dit_handler or dit_handler.model is None:
343
+ raise HTTPException(status_code=500, detail="DiT model not initialized")
344
+
345
+ content_type = (request.headers.get("content-type") or "").lower()
346
+ if "json" in content_type:
347
+ body = await request.json()
348
+ else:
349
+ form = await request.form()
350
+ body = {k: v for k, v in form.items()}
351
+
352
+ verify_token_from_request(body, authorization)
353
+ task_id = str(uuid4())
354
+
355
+ from acestep.inference import generate_music, GenerationParams, GenerationConfig, create_sample, format_sample
356
+
357
+ # Parse param_obj if provided
358
+ param_obj = body.get("param_obj", {})
359
+ if isinstance(param_obj, str):
360
+ try:
361
+ param_obj = json.loads(param_obj)
362
+ except Exception:
363
+ param_obj = {}
364
+
365
+ # Helper to get param with aliases
366
+ def get_param(key, *aliases, default=None):
367
+ for k in [key] + list(aliases):
368
+ if k in body and body[k] is not None:
369
+ return body[k]
370
+ if k in param_obj and param_obj[k] is not None:
371
+ return param_obj[k]
372
+ return default
373
+
374
+ def to_bool(val, default=False):
375
+ if val is None:
376
+ return default
377
+ if isinstance(val, bool):
378
+ return val
379
+ if isinstance(val, str):
380
+ return val.lower() in ("true", "1", "yes")
381
+ return bool(val)
382
+
383
+ try:
384
+ # Get sample_mode and sample_query parameters
385
+ sample_mode = to_bool(get_param("sample_mode", "sampleMode"), False)
386
+ sample_query = get_param("sample_query", "sampleQuery", "description", "desc", default="") or ""
387
+ use_format = to_bool(get_param("use_format", "useFormat"), False)
388
+ has_sample_query = bool(sample_query and sample_query.strip())
389
+
390
+ # Get base parameters
391
+ caption = get_param("prompt", "caption", default="") or ""
392
+ lyrics = get_param("lyrics", default="") or ""
393
+ vocal_language = get_param("vocal_language", "language", default="en") or "en"
394
+ lm_temperature = float(get_param("lm_temperature", "temperature", default=0.85) or 0.85)
395
+
396
+ # Process sample_mode: use LLM to auto-generate caption/lyrics/metas
397
+ if sample_mode or has_sample_query:
398
+ if not llm_handler or not llm_handler.llm_initialized:
399
+ raise HTTPException(status_code=500, detail="sample_mode requires LLM to be initialized")
400
+
401
+ query = sample_query if has_sample_query else "NO USER INPUT"
402
+ sample_result = create_sample(
403
+ llm_handler=llm_handler,
404
+ query=query,
405
+ vocal_language=vocal_language if vocal_language not in ("en", "unknown", "") else None,
406
+ temperature=lm_temperature,
407
+ )
408
+
409
+ if not sample_result.success:
410
+ raise HTTPException(status_code=500, detail=sample_result.error or sample_result.status_message)
411
+
412
+ # Use generated values
413
+ caption = sample_result.caption or caption
414
+ lyrics = sample_result.lyrics or lyrics
415
+ # Override metas from sample result if available
416
+ sample_bpm = sample_result.bpm
417
+ sample_duration = sample_result.duration
418
+ sample_keyscale = sample_result.keyscale
419
+ sample_timesignature = sample_result.timesignature
420
+ sample_language = sample_result.language or vocal_language
421
+ else:
422
+ sample_bpm = None
423
+ sample_duration = None
424
+ sample_keyscale = None
425
+ sample_timesignature = None
426
+ sample_language = vocal_language
427
+
428
+ # Process use_format: enhance caption/lyrics via LLM
429
+ if use_format and not sample_mode and not has_sample_query:
430
+ if llm_handler and llm_handler.llm_initialized:
431
+ format_result = format_sample(
432
+ llm_handler=llm_handler,
433
+ caption=caption,
434
+ lyrics=lyrics,
435
+ temperature=lm_temperature,
436
+ )
437
+ if format_result.success:
438
+ caption = format_result.caption or caption
439
+ lyrics = format_result.lyrics or lyrics
440
+ if format_result.bpm:
441
+ sample_bpm = format_result.bpm
442
+ if format_result.duration:
443
+ sample_duration = format_result.duration
444
+ if format_result.keyscale:
445
+ sample_keyscale = format_result.keyscale
446
+ if format_result.timesignature:
447
+ sample_timesignature = format_result.timesignature
448
+ if format_result.language:
449
+ sample_language = format_result.language
450
+
451
+ # Build generation params with alias support
452
+ params = GenerationParams(
453
+ task_type=get_param("task_type", default="text2music"),
454
+ caption=caption,
455
+ lyrics=lyrics,
456
+ bpm=sample_bpm or get_param("bpm"),
457
+ keyscale=sample_keyscale or get_param("key_scale", "keyscale", "key", default=""),
458
+ timesignature=sample_timesignature or get_param("time_signature", "timesignature", default=""),
459
+ duration=sample_duration or get_param("audio_duration", "duration", default=-1),
460
+ vocal_language=sample_language,
461
+ inference_steps=get_param("inference_steps", default=8),
462
+ guidance_scale=float(get_param("guidance_scale", default=7.0) or 7.0),
463
+ seed=int(get_param("seed", default=-1) or -1),
464
+ thinking=to_bool(get_param("thinking"), False),
465
+ lm_temperature=lm_temperature,
466
+ lm_cfg_scale=float(get_param("lm_cfg_scale", default=2.0) or 2.0),
467
+ lm_negative_prompt=get_param("lm_negative_prompt", default="NO USER INPUT") or "NO USER INPUT",
468
+ )
469
+
470
+ config = GenerationConfig(
471
+ batch_size=get_param("batch_size", default=2),
472
+ use_random_seed=get_param("use_random_seed", default=True),
473
+ audio_format=get_param("audio_format", default="mp3"),
474
+ )
475
+
476
+ # Get output directory
477
+ save_dir = os.path.join(DEFAULT_RESULTS_DIR, f"api_{int(time.time())}").replace("\\", "/")
478
+ os.makedirs(save_dir, exist_ok=True)
479
+
480
+ # Call generation function
481
+ result = generate_music(
482
+ dit_handler=dit_handler,
483
+ llm_handler=llm_handler if llm_handler and llm_handler.llm_initialized else None,
484
+ params=params,
485
+ config=config,
486
+ save_dir=save_dir,
487
+ )
488
+
489
+ if not result.success:
490
+ raise HTTPException(status_code=500, detail=result.error or result.status_message)
491
+
492
+ # Extract audio paths
493
+ audio_paths = [a["path"] for a in result.audios if a.get("path")]
494
+
495
+ # Build result data with download URLs
496
+ from urllib.parse import urlencode
497
+ result_data = [{
498
+ "file": p,
499
+ "url": f"/v1/audio?{urlencode({'path': p})}",
500
+ "status": 1,
501
+ "create_time": int(time.time()),
502
+ } for p in audio_paths]
503
+
504
+ # Store result
505
+ store_result(task_id, result_data)
506
+
507
+ return _wrap_response({"task_id": task_id, "status": "succeeded"})
508
+
509
+ except HTTPException:
510
+ raise
511
+ except Exception as e:
512
+ raise HTTPException(status_code=500, detail=str(e))
513
+
514
+
515
+ def setup_api_routes_to_app(app, dit_handler, llm_handler, api_key: Optional[str] = None):
516
+ """
517
+ Mount API routes to a FastAPI application (for use with gr.mount_gradio_app)
518
+
519
+ Args:
520
+ app: FastAPI application instance
521
+ dit_handler: DiT handler
522
+ llm_handler: LLM handler
523
+ api_key: Optional API key for authentication
524
+ """
525
+ set_api_key(api_key)
526
+ app.state.dit_handler = dit_handler
527
+ app.state.llm_handler = llm_handler
528
+ app.include_router(router)
529
+
530
+
531
+ def setup_api_routes(demo, dit_handler, llm_handler, api_key: Optional[str] = None):
532
+ """
533
+ Mount API routes to Gradio application
534
+
535
+ Args:
536
+ demo: Gradio Blocks instance
537
+ dit_handler: DiT handler
538
+ llm_handler: LLM handler
539
+ api_key: Optional API key for authentication
540
+ """
541
+ set_api_key(api_key)
542
+ app = demo.app
543
+ app.state.dit_handler = dit_handler
544
+ app.state.llm_handler = llm_handler
545
+ app.include_router(router)
546
+
547
+ # Override the /info endpoint to handle schema generation errors gracefully
548
+ from fastapi.responses import JSONResponse
549
+
550
+ @app.get("/info")
551
+ async def custom_api_info():
552
+ """Custom API info endpoint with error handling for schema generation issues"""
553
+ try:
554
+ # Try to get the original API info
555
+ from gradio import utils
556
+ api_info = utils.safe_deepcopy(demo.get_api_info())
557
+ return JSONResponse(content=api_info)
558
+ except (TypeError, AttributeError, KeyError) as e:
559
+ # If schema generation fails, return a minimal response
560
+ return JSONResponse(content={
561
+ "error": "API schema generation not available",
562
+ "message": "Custom API routes are available at /health, /v1/models, /release_task, /query_result, /create_random_sample, /format_input",
563
+ "detail": str(e)
564
+ })
acestep/gradio_ui/events/__init__.py ADDED
@@ -0,0 +1,1254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Event Handlers Module
3
+ Main entry point for setting up all event handlers
4
+ """
5
+ import gradio as gr
6
+ from typing import Optional
7
+ from loguru import logger
8
+
9
+ # Import handler modules
10
+ from . import generation_handlers as gen_h
11
+ from . import results_handlers as res_h
12
+ from . import training_handlers as train_h
13
+ from acestep.gradio_ui.i18n import t
14
+
15
+
16
+ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section):
17
+ """Setup event handlers connecting UI components and business logic"""
18
+
19
+ # ========== Dataset Handlers ==========
20
+ dataset_section["import_dataset_btn"].click(
21
+ fn=dataset_handler.import_dataset,
22
+ inputs=[dataset_section["dataset_type"]],
23
+ outputs=[dataset_section["data_status"]]
24
+ )
25
+
26
+ # ========== Service Initialization ==========
27
+ generation_section["refresh_btn"].click(
28
+ fn=lambda: gen_h.refresh_checkpoints(dit_handler),
29
+ outputs=[generation_section["checkpoint_dropdown"]]
30
+ )
31
+
32
+ generation_section["config_path"].change(
33
+ fn=gen_h.update_model_type_settings,
34
+ inputs=[generation_section["config_path"]],
35
+ outputs=[
36
+ generation_section["inference_steps"],
37
+ generation_section["guidance_scale"],
38
+ generation_section["use_adg"],
39
+ generation_section["shift"],
40
+ generation_section["cfg_interval_start"],
41
+ generation_section["cfg_interval_end"],
42
+ generation_section["task_type"],
43
+ ]
44
+ )
45
+
46
+ generation_section["init_btn"].click(
47
+ fn=lambda *args: gen_h.init_service_wrapper(dit_handler, llm_handler, *args),
48
+ inputs=[
49
+ generation_section["checkpoint_dropdown"],
50
+ generation_section["config_path"],
51
+ generation_section["device"],
52
+ generation_section["init_llm_checkbox"],
53
+ generation_section["lm_model_path"],
54
+ generation_section["backend_dropdown"],
55
+ generation_section["use_flash_attention_checkbox"],
56
+ generation_section["offload_to_cpu_checkbox"],
57
+ generation_section["offload_dit_to_cpu_checkbox"],
58
+ generation_section["compile_model_checkbox"],
59
+ generation_section["quantization_checkbox"],
60
+ ],
61
+ outputs=[
62
+ generation_section["init_status"],
63
+ generation_section["generate_btn"],
64
+ generation_section["service_config_accordion"],
65
+ # Model type settings (updated based on actual loaded model)
66
+ generation_section["inference_steps"],
67
+ generation_section["guidance_scale"],
68
+ generation_section["use_adg"],
69
+ generation_section["shift"],
70
+ generation_section["cfg_interval_start"],
71
+ generation_section["cfg_interval_end"],
72
+ generation_section["task_type"],
73
+ ]
74
+ )
75
+
76
+ # ========== LoRA Handlers ==========
77
+ generation_section["load_lora_btn"].click(
78
+ fn=dit_handler.load_lora,
79
+ inputs=[generation_section["lora_path"]],
80
+ outputs=[generation_section["lora_status"]]
81
+ ).then(
82
+ # Update checkbox to enabled state after loading
83
+ fn=lambda: gr.update(value=True),
84
+ outputs=[generation_section["use_lora_checkbox"]]
85
+ )
86
+
87
+ generation_section["unload_lora_btn"].click(
88
+ fn=dit_handler.unload_lora,
89
+ outputs=[generation_section["lora_status"]]
90
+ ).then(
91
+ # Update checkbox to disabled state after unloading
92
+ fn=lambda: gr.update(value=False),
93
+ outputs=[generation_section["use_lora_checkbox"]]
94
+ )
95
+
96
+ generation_section["use_lora_checkbox"].change(
97
+ fn=dit_handler.set_use_lora,
98
+ inputs=[generation_section["use_lora_checkbox"]],
99
+ outputs=[generation_section["lora_status"]]
100
+ )
101
+
102
+ generation_section["lora_scale_slider"].change(
103
+ fn=dit_handler.set_lora_scale,
104
+ inputs=[generation_section["lora_scale_slider"]],
105
+ outputs=[generation_section["lora_status"]]
106
+ )
107
+
108
+ # ========== UI Visibility Updates ==========
109
+ generation_section["init_llm_checkbox"].change(
110
+ fn=gen_h.update_negative_prompt_visibility,
111
+ inputs=[generation_section["init_llm_checkbox"]],
112
+ outputs=[generation_section["lm_negative_prompt"]]
113
+ )
114
+
115
+ generation_section["init_llm_checkbox"].change(
116
+ fn=gen_h.update_audio_cover_strength_visibility,
117
+ inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"], generation_section["reference_audio"]],
118
+ outputs=[generation_section["audio_cover_strength"]]
119
+ )
120
+
121
+ generation_section["task_type"].change(
122
+ fn=gen_h.update_audio_cover_strength_visibility,
123
+ inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"], generation_section["reference_audio"]],
124
+ outputs=[generation_section["audio_cover_strength"]]
125
+ )
126
+
127
+ generation_section["reference_audio"].change(
128
+ fn=gen_h.update_audio_cover_strength_visibility,
129
+ inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"], generation_section["reference_audio"]],
130
+ outputs=[generation_section["audio_cover_strength"]]
131
+ )
132
+
133
+ generation_section["batch_size_input"].change(
134
+ fn=gen_h.update_audio_components_visibility,
135
+ inputs=[generation_section["batch_size_input"]],
136
+ outputs=[
137
+ results_section["audio_col_1"],
138
+ results_section["audio_col_2"],
139
+ results_section["audio_col_3"],
140
+ results_section["audio_col_4"],
141
+ results_section["audio_row_5_8"],
142
+ results_section["audio_col_5"],
143
+ results_section["audio_col_6"],
144
+ results_section["audio_col_7"],
145
+ results_section["audio_col_8"],
146
+ ]
147
+ )
148
+
149
+ # ========== Audio Conversion ==========
150
+ generation_section["convert_src_to_codes_btn"].click(
151
+ fn=lambda src: gen_h.convert_src_audio_to_codes_wrapper(dit_handler, src),
152
+ inputs=[generation_section["src_audio"]],
153
+ outputs=[generation_section["text2music_audio_code_string"]]
154
+ )
155
+
156
+ # ========== Instruction UI Updates ==========
157
+ for trigger in [generation_section["task_type"], generation_section["track_name"], generation_section["complete_track_classes"], generation_section["reference_audio"]]:
158
+ trigger.change(
159
+ fn=lambda *args: gen_h.update_instruction_ui(dit_handler, *args),
160
+ inputs=[
161
+ generation_section["task_type"],
162
+ generation_section["track_name"],
163
+ generation_section["complete_track_classes"],
164
+ generation_section["text2music_audio_code_string"],
165
+ generation_section["init_llm_checkbox"],
166
+ generation_section["reference_audio"],
167
+ ],
168
+ outputs=[
169
+ generation_section["instruction_display_gen"],
170
+ generation_section["track_name"],
171
+ generation_section["complete_track_classes"],
172
+ generation_section["audio_cover_strength"],
173
+ generation_section["repainting_group"],
174
+ generation_section["text2music_audio_codes_group"],
175
+ ]
176
+ )
177
+
178
+ # ========== Sample/Transcribe Handlers ==========
179
+ # Load random example from ./examples/text2music directory
180
+ generation_section["sample_btn"].click(
181
+ fn=lambda task: gen_h.load_random_example(task, llm_handler) + (True,),
182
+ inputs=[
183
+ generation_section["task_type"],
184
+ ],
185
+ outputs=[
186
+ generation_section["captions"],
187
+ generation_section["lyrics"],
188
+ generation_section["think_checkbox"],
189
+ generation_section["bpm"],
190
+ generation_section["audio_duration"],
191
+ generation_section["key_scale"],
192
+ generation_section["vocal_language"],
193
+ generation_section["time_signature"],
194
+ results_section["is_format_caption_state"]
195
+ ]
196
+ )
197
+
198
+ generation_section["text2music_audio_code_string"].change(
199
+ fn=gen_h.update_transcribe_button_text,
200
+ inputs=[generation_section["text2music_audio_code_string"]],
201
+ outputs=[generation_section["transcribe_btn"]]
202
+ )
203
+
204
+ generation_section["transcribe_btn"].click(
205
+ fn=lambda codes, debug: gen_h.transcribe_audio_codes(llm_handler, codes, debug),
206
+ inputs=[
207
+ generation_section["text2music_audio_code_string"],
208
+ generation_section["constrained_decoding_debug"]
209
+ ],
210
+ outputs=[
211
+ results_section["status_output"],
212
+ generation_section["captions"],
213
+ generation_section["lyrics"],
214
+ generation_section["bpm"],
215
+ generation_section["audio_duration"],
216
+ generation_section["key_scale"],
217
+ generation_section["vocal_language"],
218
+ generation_section["time_signature"],
219
+ results_section["is_format_caption_state"]
220
+ ]
221
+ )
222
+
223
+ # ========== Reset Format Caption Flag ==========
224
+ for trigger in [generation_section["captions"], generation_section["lyrics"], generation_section["bpm"],
225
+ generation_section["key_scale"], generation_section["time_signature"],
226
+ generation_section["vocal_language"], generation_section["audio_duration"]]:
227
+ trigger.change(
228
+ fn=gen_h.reset_format_caption_flag,
229
+ inputs=[],
230
+ outputs=[results_section["is_format_caption_state"]]
231
+ )
232
+
233
+ # ========== Audio Uploads Accordion ==========
234
+ for trigger in [generation_section["reference_audio"], generation_section["src_audio"]]:
235
+ trigger.change(
236
+ fn=gen_h.update_audio_uploads_accordion,
237
+ inputs=[generation_section["reference_audio"], generation_section["src_audio"]],
238
+ outputs=[generation_section["audio_uploads_accordion"]]
239
+ )
240
+
241
+ # ========== Instrumental Checkbox ==========
242
+ generation_section["instrumental_checkbox"].change(
243
+ fn=gen_h.handle_instrumental_checkbox,
244
+ inputs=[generation_section["instrumental_checkbox"], generation_section["lyrics"]],
245
+ outputs=[generation_section["lyrics"]]
246
+ )
247
+
248
+ # ========== Format Button ==========
249
+ # Note: cfg_scale and negative_prompt are not supported in format mode
250
+ generation_section["format_btn"].click(
251
+ fn=lambda caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug: gen_h.handle_format_sample(
252
+ llm_handler, caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug
253
+ ),
254
+ inputs=[
255
+ generation_section["captions"],
256
+ generation_section["lyrics"],
257
+ generation_section["bpm"],
258
+ generation_section["audio_duration"],
259
+ generation_section["key_scale"],
260
+ generation_section["time_signature"],
261
+ generation_section["lm_temperature"],
262
+ generation_section["lm_top_k"],
263
+ generation_section["lm_top_p"],
264
+ generation_section["constrained_decoding_debug"],
265
+ ],
266
+ outputs=[
267
+ generation_section["captions"],
268
+ generation_section["lyrics"],
269
+ generation_section["bpm"],
270
+ generation_section["audio_duration"],
271
+ generation_section["key_scale"],
272
+ generation_section["vocal_language"],
273
+ generation_section["time_signature"],
274
+ results_section["is_format_caption_state"],
275
+ results_section["status_output"],
276
+ ]
277
+ )
278
+
279
+ # ========== Simple/Custom Mode Toggle ==========
280
+ generation_section["generation_mode"].change(
281
+ fn=gen_h.handle_generation_mode_change,
282
+ inputs=[generation_section["generation_mode"]],
283
+ outputs=[
284
+ generation_section["simple_mode_group"],
285
+ generation_section["caption_accordion"],
286
+ generation_section["lyrics_accordion"],
287
+ generation_section["generate_btn"],
288
+ generation_section["simple_sample_created"],
289
+ generation_section["optional_params_accordion"],
290
+ ]
291
+ )
292
+
293
+ # ========== Simple Mode Instrumental Checkbox ==========
294
+ # When instrumental is checked, disable vocal language and set to ["unknown"]
295
+ generation_section["simple_instrumental_checkbox"].change(
296
+ fn=gen_h.handle_simple_instrumental_change,
297
+ inputs=[generation_section["simple_instrumental_checkbox"]],
298
+ outputs=[generation_section["simple_vocal_language"]]
299
+ )
300
+
301
+ # ========== Random Description Button ==========
302
+ generation_section["random_desc_btn"].click(
303
+ fn=gen_h.load_random_simple_description,
304
+ inputs=[],
305
+ outputs=[
306
+ generation_section["simple_query_input"],
307
+ generation_section["simple_instrumental_checkbox"],
308
+ generation_section["simple_vocal_language"],
309
+ ]
310
+ )
311
+
312
+ # ========== Create Sample Button (Simple Mode) ==========
313
+ # Note: cfg_scale and negative_prompt are not supported in create_sample mode
314
+ generation_section["create_sample_btn"].click(
315
+ fn=lambda query, instrumental, vocal_lang, temp, top_k, top_p, debug: gen_h.handle_create_sample(
316
+ llm_handler, query, instrumental, vocal_lang, temp, top_k, top_p, debug
317
+ ),
318
+ inputs=[
319
+ generation_section["simple_query_input"],
320
+ generation_section["simple_instrumental_checkbox"],
321
+ generation_section["simple_vocal_language"],
322
+ generation_section["lm_temperature"],
323
+ generation_section["lm_top_k"],
324
+ generation_section["lm_top_p"],
325
+ generation_section["constrained_decoding_debug"],
326
+ ],
327
+ outputs=[
328
+ generation_section["captions"],
329
+ generation_section["lyrics"],
330
+ generation_section["bpm"],
331
+ generation_section["audio_duration"],
332
+ generation_section["key_scale"],
333
+ generation_section["vocal_language"],
334
+ generation_section["simple_vocal_language"],
335
+ generation_section["time_signature"],
336
+ generation_section["instrumental_checkbox"],
337
+ generation_section["caption_accordion"],
338
+ generation_section["lyrics_accordion"],
339
+ generation_section["generate_btn"],
340
+ generation_section["simple_sample_created"],
341
+ generation_section["think_checkbox"],
342
+ results_section["is_format_caption_state"],
343
+ results_section["status_output"],
344
+ ]
345
+ )
346
+
347
+ # ========== Load/Save Metadata ==========
348
+ generation_section["load_file"].upload(
349
+ fn=lambda file_obj: gen_h.load_metadata(file_obj, llm_handler),
350
+ inputs=[generation_section["load_file"]],
351
+ outputs=[
352
+ generation_section["task_type"],
353
+ generation_section["captions"],
354
+ generation_section["lyrics"],
355
+ generation_section["vocal_language"],
356
+ generation_section["bpm"],
357
+ generation_section["key_scale"],
358
+ generation_section["time_signature"],
359
+ generation_section["audio_duration"],
360
+ generation_section["batch_size_input"],
361
+ generation_section["inference_steps"],
362
+ generation_section["guidance_scale"],
363
+ generation_section["seed"],
364
+ generation_section["random_seed_checkbox"],
365
+ generation_section["use_adg"],
366
+ generation_section["cfg_interval_start"],
367
+ generation_section["cfg_interval_end"],
368
+ generation_section["shift"],
369
+ generation_section["infer_method"],
370
+ generation_section["custom_timesteps"],
371
+ generation_section["audio_format"],
372
+ generation_section["lm_temperature"],
373
+ generation_section["lm_cfg_scale"],
374
+ generation_section["lm_top_k"],
375
+ generation_section["lm_top_p"],
376
+ generation_section["lm_negative_prompt"],
377
+ generation_section["use_cot_metas"], # Added: use_cot_metas
378
+ generation_section["use_cot_caption"],
379
+ generation_section["use_cot_language"],
380
+ generation_section["audio_cover_strength"],
381
+ generation_section["think_checkbox"],
382
+ generation_section["text2music_audio_code_string"],
383
+ generation_section["repainting_start"],
384
+ generation_section["repainting_end"],
385
+ generation_section["track_name"],
386
+ generation_section["complete_track_classes"],
387
+ generation_section["instrumental_checkbox"], # Added: instrumental_checkbox
388
+ results_section["is_format_caption_state"]
389
+ ]
390
+ )
391
+
392
+ # Save buttons for all 8 audio outputs
393
+ download_existing_js = """(current_audio, batch_files) => {
394
+ // Debug: print what the input actually is
395
+ console.log("👉 [Debug] Current Audio Input:", current_audio);
396
+
397
+ // 1. Safety check
398
+ if (!current_audio) {
399
+ console.warn("⚠️ No audio selected or audio is empty.");
400
+ return;
401
+ }
402
+ if (!batch_files || !Array.isArray(batch_files)) {
403
+ console.warn("⚠️ Batch file list is empty/not ready.");
404
+ return;
405
+ }
406
+
407
+ // 2. Smartly extract path string
408
+ let pathString = "";
409
+
410
+ if (typeof current_audio === "string") {
411
+ // Case A: direct path string received
412
+ pathString = current_audio;
413
+ } else if (typeof current_audio === "object") {
414
+ // Case B: an object is received, try common properties
415
+ // Gradio file objects usually have path, url, or name
416
+ pathString = current_audio.path || current_audio.name || current_audio.url || "";
417
+ }
418
+
419
+ if (!pathString) {
420
+ console.error("❌ Error: Could not extract a valid path string from input.", current_audio);
421
+ return;
422
+ }
423
+
424
+ // 3. Extract Key (UUID)
425
+ // Path could be /tmp/.../uuid.mp3 or url like /file=.../uuid.mp3
426
+ let filename = pathString.split(/[\\\\/]/).pop(); // get the filename
427
+ let key = filename.split('.')[0]; // get UUID without extension
428
+
429
+ console.log(`🔑 Key extracted: ${key}`);
430
+
431
+ // 4. Find matching file(s) in the list
432
+ let targets = batch_files.filter(f => {
433
+ // Also extract names from batch_files objects
434
+ // f usually contains name (backend path) and orig_name (download name)
435
+ const fPath = f.name || f.path || "";
436
+ return fPath.includes(key);
437
+ });
438
+
439
+ if (targets.length === 0) {
440
+ console.warn("❌ No matching files found in batch list for key:", key);
441
+ alert("Batch list does not contain this file yet. Please wait for generation to finish.");
442
+ return;
443
+ }
444
+
445
+ // 5. Trigger download(s)
446
+ console.log(`🎯 Found ${targets.length} files to download.`);
447
+ targets.forEach((f, index) => {
448
+ setTimeout(() => {
449
+ const a = document.createElement('a');
450
+ // Prefer url (frontend-accessible link), otherwise try data
451
+ a.href = f.url || f.data;
452
+ a.download = f.orig_name || "download";
453
+ a.style.display = 'none';
454
+ document.body.appendChild(a);
455
+ a.click();
456
+ document.body.removeChild(a);
457
+ }, index * 1000); // 300ms interval to avoid browser blocking
458
+ });
459
+ }
460
+ """
461
+ for btn_idx in range(1, 9):
462
+ results_section[f"save_btn_{btn_idx}"].click(
463
+ fn=None,
464
+ inputs=[
465
+ results_section[f"generated_audio_{btn_idx}"],
466
+ results_section["generated_audio_batch"],
467
+ ],
468
+ js=download_existing_js # Run the above JS
469
+ )
470
+ # ========== Send to SRC Handlers ==========
471
+ for btn_idx in range(1, 9):
472
+ results_section[f"send_to_src_btn_{btn_idx}"].click(
473
+ fn=res_h.send_audio_to_src_with_metadata,
474
+ inputs=[
475
+ results_section[f"generated_audio_{btn_idx}"],
476
+ results_section["lm_metadata_state"]
477
+ ],
478
+ outputs=[
479
+ generation_section["src_audio"],
480
+ generation_section["bpm"],
481
+ generation_section["captions"],
482
+ generation_section["lyrics"],
483
+ generation_section["audio_duration"],
484
+ generation_section["key_scale"],
485
+ generation_section["vocal_language"],
486
+ generation_section["time_signature"],
487
+ results_section["is_format_caption_state"]
488
+ ]
489
+ )
490
+
491
+ # ========== Score Calculation Handlers ==========
492
+ # Use default argument to capture btn_idx value at definition time (Python closure fix)
493
+ def make_score_handler(idx):
494
+ return lambda scale, batch_idx, queue: res_h.calculate_score_handler_with_selection(
495
+ dit_handler, llm_handler, idx, scale, batch_idx, queue
496
+ )
497
+
498
+ for btn_idx in range(1, 9):
499
+ results_section[f"score_btn_{btn_idx}"].click(
500
+ fn=make_score_handler(btn_idx),
501
+ inputs=[
502
+ generation_section["score_scale"],
503
+ results_section["current_batch_index"],
504
+ results_section["batch_queue"],
505
+ ],
506
+ outputs=[
507
+ results_section[f"score_display_{btn_idx}"],
508
+ results_section[f"details_accordion_{btn_idx}"],
509
+ results_section["batch_queue"]
510
+ ]
511
+ )
512
+
513
+ # ========== LRC Timestamp Handlers ==========
514
+ # Use default argument to capture btn_idx value at definition time (Python closure fix)
515
+ def make_lrc_handler(idx):
516
+ return lambda batch_idx, queue, vocal_lang, infer_steps: res_h.generate_lrc_handler(
517
+ dit_handler, idx, batch_idx, queue, vocal_lang, infer_steps
518
+ )
519
+
520
+ for btn_idx in range(1, 9):
521
+ results_section[f"lrc_btn_{btn_idx}"].click(
522
+ fn=make_lrc_handler(btn_idx),
523
+ inputs=[
524
+ results_section["current_batch_index"],
525
+ results_section["batch_queue"],
526
+ generation_section["vocal_language"],
527
+ generation_section["inference_steps"],
528
+ ],
529
+ outputs=[
530
+ results_section[f"lrc_display_{btn_idx}"],
531
+ results_section[f"details_accordion_{btn_idx}"],
532
+ # NOTE: Removed generated_audio output!
533
+ # Audio subtitles are now updated via lrc_display.change() event.
534
+ results_section["batch_queue"]
535
+ ]
536
+ )
537
+
538
+ def generation_wrapper(*args):
539
+ yield from res_h.generate_with_batch_management(dit_handler, llm_handler, *args)
540
+ # ========== Generation Handler ==========
541
+ generation_section["generate_btn"].click(
542
+ fn=res_h.clear_audio_outputs_for_new_generation,
543
+ outputs=[
544
+ results_section["generated_audio_1"],
545
+ results_section["generated_audio_2"],
546
+ results_section["generated_audio_3"],
547
+ results_section["generated_audio_4"],
548
+ results_section["generated_audio_5"],
549
+ results_section["generated_audio_6"],
550
+ results_section["generated_audio_7"],
551
+ results_section["generated_audio_8"],
552
+ results_section["generated_audio_batch"],
553
+ ],
554
+ ).then(
555
+ fn=generation_wrapper,
556
+ inputs=[
557
+ generation_section["captions"],
558
+ generation_section["lyrics"],
559
+ generation_section["bpm"],
560
+ generation_section["key_scale"],
561
+ generation_section["time_signature"],
562
+ generation_section["vocal_language"],
563
+ generation_section["inference_steps"],
564
+ generation_section["guidance_scale"],
565
+ generation_section["random_seed_checkbox"],
566
+ generation_section["seed"],
567
+ generation_section["reference_audio"],
568
+ generation_section["audio_duration"],
569
+ generation_section["batch_size_input"],
570
+ generation_section["src_audio"],
571
+ generation_section["text2music_audio_code_string"],
572
+ generation_section["repainting_start"],
573
+ generation_section["repainting_end"],
574
+ generation_section["instruction_display_gen"],
575
+ generation_section["audio_cover_strength"],
576
+ generation_section["task_type"],
577
+ generation_section["use_adg"],
578
+ generation_section["cfg_interval_start"],
579
+ generation_section["cfg_interval_end"],
580
+ generation_section["shift"],
581
+ generation_section["infer_method"],
582
+ generation_section["custom_timesteps"],
583
+ generation_section["audio_format"],
584
+ generation_section["lm_temperature"],
585
+ generation_section["think_checkbox"],
586
+ generation_section["lm_cfg_scale"],
587
+ generation_section["lm_top_k"],
588
+ generation_section["lm_top_p"],
589
+ generation_section["lm_negative_prompt"],
590
+ generation_section["use_cot_metas"],
591
+ generation_section["use_cot_caption"],
592
+ generation_section["use_cot_language"],
593
+ results_section["is_format_caption_state"],
594
+ generation_section["constrained_decoding_debug"],
595
+ generation_section["allow_lm_batch"],
596
+ generation_section["auto_score"],
597
+ generation_section["auto_lrc"],
598
+ generation_section["score_scale"],
599
+ generation_section["lm_batch_chunk_size"],
600
+ generation_section["track_name"],
601
+ generation_section["complete_track_classes"],
602
+ generation_section["autogen_checkbox"],
603
+ results_section["current_batch_index"],
604
+ results_section["total_batches"],
605
+ results_section["batch_queue"],
606
+ results_section["generation_params_state"],
607
+ ],
608
+ outputs=[
609
+ results_section["generated_audio_1"],
610
+ results_section["generated_audio_2"],
611
+ results_section["generated_audio_3"],
612
+ results_section["generated_audio_4"],
613
+ results_section["generated_audio_5"],
614
+ results_section["generated_audio_6"],
615
+ results_section["generated_audio_7"],
616
+ results_section["generated_audio_8"],
617
+ results_section["generated_audio_batch"],
618
+ results_section["generation_info"],
619
+ results_section["status_output"],
620
+ generation_section["seed"],
621
+ results_section["score_display_1"],
622
+ results_section["score_display_2"],
623
+ results_section["score_display_3"],
624
+ results_section["score_display_4"],
625
+ results_section["score_display_5"],
626
+ results_section["score_display_6"],
627
+ results_section["score_display_7"],
628
+ results_section["score_display_8"],
629
+ results_section["codes_display_1"],
630
+ results_section["codes_display_2"],
631
+ results_section["codes_display_3"],
632
+ results_section["codes_display_4"],
633
+ results_section["codes_display_5"],
634
+ results_section["codes_display_6"],
635
+ results_section["codes_display_7"],
636
+ results_section["codes_display_8"],
637
+ results_section["details_accordion_1"],
638
+ results_section["details_accordion_2"],
639
+ results_section["details_accordion_3"],
640
+ results_section["details_accordion_4"],
641
+ results_section["details_accordion_5"],
642
+ results_section["details_accordion_6"],
643
+ results_section["details_accordion_7"],
644
+ results_section["details_accordion_8"],
645
+ results_section["lrc_display_1"],
646
+ results_section["lrc_display_2"],
647
+ results_section["lrc_display_3"],
648
+ results_section["lrc_display_4"],
649
+ results_section["lrc_display_5"],
650
+ results_section["lrc_display_6"],
651
+ results_section["lrc_display_7"],
652
+ results_section["lrc_display_8"],
653
+ results_section["lm_metadata_state"],
654
+ results_section["is_format_caption_state"],
655
+ results_section["current_batch_index"],
656
+ results_section["total_batches"],
657
+ results_section["batch_queue"],
658
+ results_section["generation_params_state"],
659
+ results_section["batch_indicator"],
660
+ results_section["prev_batch_btn"],
661
+ results_section["next_batch_btn"],
662
+ results_section["next_batch_status"],
663
+ results_section["restore_params_btn"],
664
+ ],
665
+ ).then(
666
+ fn=lambda *args: res_h.generate_next_batch_background(dit_handler, llm_handler, *args),
667
+ inputs=[
668
+ generation_section["autogen_checkbox"],
669
+ results_section["generation_params_state"],
670
+ results_section["current_batch_index"],
671
+ results_section["total_batches"],
672
+ results_section["batch_queue"],
673
+ results_section["is_format_caption_state"],
674
+ ],
675
+ outputs=[
676
+ results_section["batch_queue"],
677
+ results_section["total_batches"],
678
+ results_section["next_batch_status"],
679
+ results_section["next_batch_btn"],
680
+ ]
681
+ )
682
+
683
+ # ========== Batch Navigation Handlers ==========
684
+ results_section["prev_batch_btn"].click(
685
+ fn=res_h.navigate_to_previous_batch,
686
+ inputs=[
687
+ results_section["current_batch_index"],
688
+ results_section["batch_queue"],
689
+ ],
690
+ outputs=[
691
+ results_section["generated_audio_1"],
692
+ results_section["generated_audio_2"],
693
+ results_section["generated_audio_3"],
694
+ results_section["generated_audio_4"],
695
+ results_section["generated_audio_5"],
696
+ results_section["generated_audio_6"],
697
+ results_section["generated_audio_7"],
698
+ results_section["generated_audio_8"],
699
+ results_section["generated_audio_batch"],
700
+ results_section["generation_info"],
701
+ results_section["current_batch_index"],
702
+ results_section["batch_indicator"],
703
+ results_section["prev_batch_btn"],
704
+ results_section["next_batch_btn"],
705
+ results_section["status_output"],
706
+ results_section["score_display_1"],
707
+ results_section["score_display_2"],
708
+ results_section["score_display_3"],
709
+ results_section["score_display_4"],
710
+ results_section["score_display_5"],
711
+ results_section["score_display_6"],
712
+ results_section["score_display_7"],
713
+ results_section["score_display_8"],
714
+ results_section["codes_display_1"],
715
+ results_section["codes_display_2"],
716
+ results_section["codes_display_3"],
717
+ results_section["codes_display_4"],
718
+ results_section["codes_display_5"],
719
+ results_section["codes_display_6"],
720
+ results_section["codes_display_7"],
721
+ results_section["codes_display_8"],
722
+ results_section["lrc_display_1"],
723
+ results_section["lrc_display_2"],
724
+ results_section["lrc_display_3"],
725
+ results_section["lrc_display_4"],
726
+ results_section["lrc_display_5"],
727
+ results_section["lrc_display_6"],
728
+ results_section["lrc_display_7"],
729
+ results_section["lrc_display_8"],
730
+ results_section["details_accordion_1"],
731
+ results_section["details_accordion_2"],
732
+ results_section["details_accordion_3"],
733
+ results_section["details_accordion_4"],
734
+ results_section["details_accordion_5"],
735
+ results_section["details_accordion_6"],
736
+ results_section["details_accordion_7"],
737
+ results_section["details_accordion_8"],
738
+ results_section["restore_params_btn"],
739
+ ]
740
+ )
741
+
742
+ results_section["next_batch_btn"].click(
743
+ fn=res_h.capture_current_params,
744
+ inputs=[
745
+ generation_section["captions"],
746
+ generation_section["lyrics"],
747
+ generation_section["bpm"],
748
+ generation_section["key_scale"],
749
+ generation_section["time_signature"],
750
+ generation_section["vocal_language"],
751
+ generation_section["inference_steps"],
752
+ generation_section["guidance_scale"],
753
+ generation_section["random_seed_checkbox"],
754
+ generation_section["seed"],
755
+ generation_section["reference_audio"],
756
+ generation_section["audio_duration"],
757
+ generation_section["batch_size_input"],
758
+ generation_section["src_audio"],
759
+ generation_section["text2music_audio_code_string"],
760
+ generation_section["repainting_start"],
761
+ generation_section["repainting_end"],
762
+ generation_section["instruction_display_gen"],
763
+ generation_section["audio_cover_strength"],
764
+ generation_section["task_type"],
765
+ generation_section["use_adg"],
766
+ generation_section["cfg_interval_start"],
767
+ generation_section["cfg_interval_end"],
768
+ generation_section["shift"],
769
+ generation_section["infer_method"],
770
+ generation_section["custom_timesteps"],
771
+ generation_section["audio_format"],
772
+ generation_section["lm_temperature"],
773
+ generation_section["think_checkbox"],
774
+ generation_section["lm_cfg_scale"],
775
+ generation_section["lm_top_k"],
776
+ generation_section["lm_top_p"],
777
+ generation_section["lm_negative_prompt"],
778
+ generation_section["use_cot_metas"],
779
+ generation_section["use_cot_caption"],
780
+ generation_section["use_cot_language"],
781
+ generation_section["constrained_decoding_debug"],
782
+ generation_section["allow_lm_batch"],
783
+ generation_section["auto_score"],
784
+ generation_section["auto_lrc"],
785
+ generation_section["score_scale"],
786
+ generation_section["lm_batch_chunk_size"],
787
+ generation_section["track_name"],
788
+ generation_section["complete_track_classes"],
789
+ ],
790
+ outputs=[results_section["generation_params_state"]]
791
+ ).then(
792
+ fn=res_h.navigate_to_next_batch,
793
+ inputs=[
794
+ generation_section["autogen_checkbox"],
795
+ results_section["current_batch_index"],
796
+ results_section["total_batches"],
797
+ results_section["batch_queue"],
798
+ ],
799
+ outputs=[
800
+ results_section["generated_audio_1"],
801
+ results_section["generated_audio_2"],
802
+ results_section["generated_audio_3"],
803
+ results_section["generated_audio_4"],
804
+ results_section["generated_audio_5"],
805
+ results_section["generated_audio_6"],
806
+ results_section["generated_audio_7"],
807
+ results_section["generated_audio_8"],
808
+ results_section["generated_audio_batch"],
809
+ results_section["generation_info"],
810
+ results_section["current_batch_index"],
811
+ results_section["batch_indicator"],
812
+ results_section["prev_batch_btn"],
813
+ results_section["next_batch_btn"],
814
+ results_section["status_output"],
815
+ results_section["next_batch_status"],
816
+ results_section["score_display_1"],
817
+ results_section["score_display_2"],
818
+ results_section["score_display_3"],
819
+ results_section["score_display_4"],
820
+ results_section["score_display_5"],
821
+ results_section["score_display_6"],
822
+ results_section["score_display_7"],
823
+ results_section["score_display_8"],
824
+ results_section["codes_display_1"],
825
+ results_section["codes_display_2"],
826
+ results_section["codes_display_3"],
827
+ results_section["codes_display_4"],
828
+ results_section["codes_display_5"],
829
+ results_section["codes_display_6"],
830
+ results_section["codes_display_7"],
831
+ results_section["codes_display_8"],
832
+ results_section["lrc_display_1"],
833
+ results_section["lrc_display_2"],
834
+ results_section["lrc_display_3"],
835
+ results_section["lrc_display_4"],
836
+ results_section["lrc_display_5"],
837
+ results_section["lrc_display_6"],
838
+ results_section["lrc_display_7"],
839
+ results_section["lrc_display_8"],
840
+ results_section["details_accordion_1"],
841
+ results_section["details_accordion_2"],
842
+ results_section["details_accordion_3"],
843
+ results_section["details_accordion_4"],
844
+ results_section["details_accordion_5"],
845
+ results_section["details_accordion_6"],
846
+ results_section["details_accordion_7"],
847
+ results_section["details_accordion_8"],
848
+ results_section["restore_params_btn"],
849
+ ]
850
+ ).then(
851
+ fn=lambda *args: res_h.generate_next_batch_background(dit_handler, llm_handler, *args),
852
+ inputs=[
853
+ generation_section["autogen_checkbox"],
854
+ results_section["generation_params_state"],
855
+ results_section["current_batch_index"],
856
+ results_section["total_batches"],
857
+ results_section["batch_queue"],
858
+ results_section["is_format_caption_state"],
859
+ ],
860
+ outputs=[
861
+ results_section["batch_queue"],
862
+ results_section["total_batches"],
863
+ results_section["next_batch_status"],
864
+ results_section["next_batch_btn"],
865
+ ]
866
+ )
867
+
868
+ # ========== Restore Parameters Handler ==========
869
+ results_section["restore_params_btn"].click(
870
+ fn=res_h.restore_batch_parameters,
871
+ inputs=[
872
+ results_section["current_batch_index"],
873
+ results_section["batch_queue"]
874
+ ],
875
+ outputs=[
876
+ generation_section["text2music_audio_code_string"],
877
+ generation_section["captions"],
878
+ generation_section["lyrics"],
879
+ generation_section["bpm"],
880
+ generation_section["key_scale"],
881
+ generation_section["time_signature"],
882
+ generation_section["vocal_language"],
883
+ generation_section["audio_duration"],
884
+ generation_section["batch_size_input"],
885
+ generation_section["inference_steps"],
886
+ generation_section["lm_temperature"],
887
+ generation_section["lm_cfg_scale"],
888
+ generation_section["lm_top_k"],
889
+ generation_section["lm_top_p"],
890
+ generation_section["think_checkbox"],
891
+ generation_section["use_cot_caption"],
892
+ generation_section["use_cot_language"],
893
+ generation_section["allow_lm_batch"],
894
+ generation_section["track_name"],
895
+ generation_section["complete_track_classes"],
896
+ ]
897
+ )
898
+
899
+ # ========== LRC Display Change Handlers ==========
900
+ # NEW APPROACH: Use lrc_display.change() to update audio subtitles
901
+ # This decouples audio value updates from subtitle updates, avoiding flickering.
902
+ #
903
+ # When lrc_display text changes (from generate, LRC button, or manual edit):
904
+ # 1. lrc_display.change() is triggered
905
+ # 2. update_audio_subtitles_from_lrc() parses LRC and updates audio subtitles
906
+ # 3. Audio value is NEVER updated here - only subtitles
907
+ for lrc_idx in range(1, 9):
908
+ results_section[f"lrc_display_{lrc_idx}"].change(
909
+ fn=res_h.update_audio_subtitles_from_lrc,
910
+ inputs=[
911
+ results_section[f"lrc_display_{lrc_idx}"],
912
+ # audio_duration not needed - parse_lrc_to_subtitles calculates end time from timestamps
913
+ ],
914
+ outputs=[
915
+ results_section[f"generated_audio_{lrc_idx}"], # Only updates subtitles, not value
916
+ ]
917
+ )
918
+
919
+
920
+ def setup_training_event_handlers(demo, dit_handler, llm_handler, training_section):
921
+ """Setup event handlers for the training tab (dataset builder and LoRA training)"""
922
+
923
+ # ========== Load Existing Dataset (Top Section) ==========
924
+
925
+ # Load existing dataset JSON at the top of Dataset Builder
926
+ training_section["load_json_btn"].click(
927
+ fn=train_h.load_existing_dataset_for_preprocess,
928
+ inputs=[
929
+ training_section["load_json_path"],
930
+ training_section["dataset_builder_state"],
931
+ ],
932
+ outputs=[
933
+ training_section["load_json_status"],
934
+ training_section["audio_files_table"],
935
+ training_section["sample_selector"],
936
+ training_section["dataset_builder_state"],
937
+ # Also update preview fields with first sample
938
+ training_section["preview_audio"],
939
+ training_section["preview_filename"],
940
+ training_section["edit_caption"],
941
+ training_section["edit_genre"],
942
+ training_section["prompt_override"],
943
+ training_section["edit_lyrics"],
944
+ training_section["edit_bpm"],
945
+ training_section["edit_keyscale"],
946
+ training_section["edit_timesig"],
947
+ training_section["edit_duration"],
948
+ training_section["edit_language"],
949
+ training_section["edit_instrumental"],
950
+ training_section["raw_lyrics_display"],
951
+ training_section["has_raw_lyrics_state"],
952
+ # Update dataset-level settings
953
+ training_section["dataset_name"],
954
+ training_section["custom_tag"],
955
+ training_section["tag_position"],
956
+ training_section["all_instrumental"],
957
+ training_section["genre_ratio"],
958
+ ]
959
+ ).then(
960
+ fn=lambda has_raw: gr.update(visible=has_raw),
961
+ inputs=[training_section["has_raw_lyrics_state"]],
962
+ outputs=[training_section["raw_lyrics_display"]],
963
+ )
964
+
965
+ # ========== Dataset Builder Handlers ==========
966
+
967
+ # Scan directory for audio files
968
+ training_section["scan_btn"].click(
969
+ fn=lambda dir, name, tag, pos, instr, state: train_h.scan_directory(
970
+ dir, name, tag, pos, instr, state
971
+ ),
972
+ inputs=[
973
+ training_section["audio_directory"],
974
+ training_section["dataset_name"],
975
+ training_section["custom_tag"],
976
+ training_section["tag_position"],
977
+ training_section["all_instrumental"],
978
+ training_section["dataset_builder_state"],
979
+ ],
980
+ outputs=[
981
+ training_section["audio_files_table"],
982
+ training_section["scan_status"],
983
+ training_section["sample_selector"],
984
+ training_section["dataset_builder_state"],
985
+ ]
986
+ )
987
+
988
+ # Auto-label all samples
989
+ training_section["auto_label_btn"].click(
990
+ fn=lambda state, skip, fmt_lyrics, trans_lyrics, only_unlab: train_h.auto_label_all(
991
+ dit_handler, llm_handler, state, skip, fmt_lyrics, trans_lyrics, only_unlab
992
+ ),
993
+ inputs=[
994
+ training_section["dataset_builder_state"],
995
+ training_section["skip_metas"],
996
+ training_section["format_lyrics"],
997
+ training_section["transcribe_lyrics"],
998
+ training_section["only_unlabeled"],
999
+ ],
1000
+ outputs=[
1001
+ training_section["audio_files_table"],
1002
+ training_section["label_progress"],
1003
+ training_section["dataset_builder_state"],
1004
+ ]
1005
+ ).then(
1006
+ # Refresh preview/edit fields after labeling completes
1007
+ fn=train_h.get_sample_preview,
1008
+ inputs=[
1009
+ training_section["sample_selector"],
1010
+ training_section["dataset_builder_state"],
1011
+ ],
1012
+ outputs=[
1013
+ training_section["preview_audio"],
1014
+ training_section["preview_filename"],
1015
+ training_section["edit_caption"],
1016
+ training_section["edit_genre"],
1017
+ training_section["prompt_override"],
1018
+ training_section["edit_lyrics"],
1019
+ training_section["edit_bpm"],
1020
+ training_section["edit_keyscale"],
1021
+ training_section["edit_timesig"],
1022
+ training_section["edit_duration"],
1023
+ training_section["edit_language"],
1024
+ training_section["edit_instrumental"],
1025
+ training_section["raw_lyrics_display"],
1026
+ training_section["has_raw_lyrics_state"],
1027
+ ]
1028
+ ).then(
1029
+ fn=lambda status: f"{status or '✅ Auto-label complete.'}\n✅ Preview refreshed.",
1030
+ inputs=[training_section["label_progress"]],
1031
+ outputs=[training_section["label_progress"]],
1032
+ ).then(
1033
+ fn=lambda has_raw: gr.update(visible=bool(has_raw)),
1034
+ inputs=[training_section["has_raw_lyrics_state"]],
1035
+ outputs=[training_section["raw_lyrics_display"]],
1036
+ )
1037
+
1038
+ # Mutual exclusion: format_lyrics and transcribe_lyrics cannot both be True
1039
+ training_section["format_lyrics"].change(
1040
+ fn=lambda fmt: gr.update(value=False) if fmt else gr.update(),
1041
+ inputs=[training_section["format_lyrics"]],
1042
+ outputs=[training_section["transcribe_lyrics"]]
1043
+ )
1044
+
1045
+ training_section["transcribe_lyrics"].change(
1046
+ fn=lambda trans: gr.update(value=False) if trans else gr.update(),
1047
+ inputs=[training_section["transcribe_lyrics"]],
1048
+ outputs=[training_section["format_lyrics"]]
1049
+ )
1050
+
1051
+ # Sample selector change - update preview
1052
+ training_section["sample_selector"].change(
1053
+ fn=train_h.get_sample_preview,
1054
+ inputs=[
1055
+ training_section["sample_selector"],
1056
+ training_section["dataset_builder_state"],
1057
+ ],
1058
+ outputs=[
1059
+ training_section["preview_audio"],
1060
+ training_section["preview_filename"],
1061
+ training_section["edit_caption"],
1062
+ training_section["edit_genre"],
1063
+ training_section["prompt_override"],
1064
+ training_section["edit_lyrics"],
1065
+ training_section["edit_bpm"],
1066
+ training_section["edit_keyscale"],
1067
+ training_section["edit_timesig"],
1068
+ training_section["edit_duration"],
1069
+ training_section["edit_language"],
1070
+ training_section["edit_instrumental"],
1071
+ training_section["raw_lyrics_display"],
1072
+ training_section["has_raw_lyrics_state"],
1073
+ ]
1074
+ ).then(
1075
+ # Show/hide raw lyrics panel based on whether raw lyrics exist
1076
+ fn=lambda has_raw: gr.update(visible=has_raw),
1077
+ inputs=[training_section["has_raw_lyrics_state"]],
1078
+ outputs=[training_section["raw_lyrics_display"]],
1079
+ )
1080
+
1081
+ # Save sample edit
1082
+ training_section["save_edit_btn"].click(
1083
+ fn=train_h.save_sample_edit,
1084
+ inputs=[
1085
+ training_section["sample_selector"],
1086
+ training_section["edit_caption"],
1087
+ training_section["edit_genre"],
1088
+ training_section["prompt_override"],
1089
+ training_section["edit_lyrics"],
1090
+ training_section["edit_bpm"],
1091
+ training_section["edit_keyscale"],
1092
+ training_section["edit_timesig"],
1093
+ training_section["edit_language"],
1094
+ training_section["edit_instrumental"],
1095
+ training_section["dataset_builder_state"],
1096
+ ],
1097
+ outputs=[
1098
+ training_section["audio_files_table"],
1099
+ training_section["edit_status"],
1100
+ training_section["dataset_builder_state"],
1101
+ ]
1102
+ )
1103
+
1104
+ # Update settings when changed (including genre_ratio)
1105
+ for trigger in [training_section["custom_tag"], training_section["tag_position"], training_section["all_instrumental"], training_section["genre_ratio"]]:
1106
+ trigger.change(
1107
+ fn=train_h.update_settings,
1108
+ inputs=[
1109
+ training_section["custom_tag"],
1110
+ training_section["tag_position"],
1111
+ training_section["all_instrumental"],
1112
+ training_section["genre_ratio"],
1113
+ training_section["dataset_builder_state"],
1114
+ ],
1115
+ outputs=[training_section["dataset_builder_state"]]
1116
+ )
1117
+
1118
+ # Save dataset
1119
+ training_section["save_dataset_btn"].click(
1120
+ fn=train_h.save_dataset,
1121
+ inputs=[
1122
+ training_section["save_path"],
1123
+ training_section["dataset_name"],
1124
+ training_section["dataset_builder_state"],
1125
+ ],
1126
+ outputs=[
1127
+ training_section["save_status"],
1128
+ training_section["save_path"],
1129
+ ]
1130
+ )
1131
+
1132
+ # ========== Preprocess Handlers ==========
1133
+
1134
+ # Load existing dataset JSON for preprocessing
1135
+ # This also updates the preview section so users can view/edit samples
1136
+ training_section["load_existing_dataset_btn"].click(
1137
+ fn=train_h.load_existing_dataset_for_preprocess,
1138
+ inputs=[
1139
+ training_section["load_existing_dataset_path"],
1140
+ training_section["dataset_builder_state"],
1141
+ ],
1142
+ outputs=[
1143
+ training_section["load_existing_status"],
1144
+ training_section["audio_files_table"],
1145
+ training_section["sample_selector"],
1146
+ training_section["dataset_builder_state"],
1147
+ # Also update preview fields with first sample
1148
+ training_section["preview_audio"],
1149
+ training_section["preview_filename"],
1150
+ training_section["edit_caption"],
1151
+ training_section["edit_genre"],
1152
+ training_section["prompt_override"],
1153
+ training_section["edit_lyrics"],
1154
+ training_section["edit_bpm"],
1155
+ training_section["edit_keyscale"],
1156
+ training_section["edit_timesig"],
1157
+ training_section["edit_duration"],
1158
+ training_section["edit_language"],
1159
+ training_section["edit_instrumental"],
1160
+ training_section["raw_lyrics_display"],
1161
+ training_section["has_raw_lyrics_state"],
1162
+ # Update dataset-level settings
1163
+ training_section["dataset_name"],
1164
+ training_section["custom_tag"],
1165
+ training_section["tag_position"],
1166
+ training_section["all_instrumental"],
1167
+ training_section["genre_ratio"],
1168
+ ]
1169
+ ).then(
1170
+ fn=lambda has_raw: gr.update(visible=has_raw),
1171
+ inputs=[training_section["has_raw_lyrics_state"]],
1172
+ outputs=[training_section["raw_lyrics_display"]],
1173
+ )
1174
+
1175
+ # Preprocess dataset to tensor files
1176
+ training_section["preprocess_btn"].click(
1177
+ fn=lambda output_dir, state: train_h.preprocess_dataset(
1178
+ output_dir, dit_handler, state
1179
+ ),
1180
+ inputs=[
1181
+ training_section["preprocess_output_dir"],
1182
+ training_section["dataset_builder_state"],
1183
+ ],
1184
+ outputs=[training_section["preprocess_progress"]]
1185
+ )
1186
+
1187
+ # ========== Training Tab Handlers ==========
1188
+
1189
+ # Load preprocessed tensor dataset
1190
+ training_section["load_dataset_btn"].click(
1191
+ fn=train_h.load_training_dataset,
1192
+ inputs=[training_section["training_tensor_dir"]],
1193
+ outputs=[training_section["training_dataset_info"]]
1194
+ )
1195
+
1196
+ # Start training from preprocessed tensors
1197
+ def training_wrapper(tensor_dir, r, a, d, lr, ep, bs, ga, se, sh, sd, od, rc, ts):
1198
+ from loguru import logger
1199
+ if not isinstance(ts, dict):
1200
+ ts = {"is_training": False, "should_stop": False}
1201
+ try:
1202
+ for progress, log_msg, plot, state in train_h.start_training(
1203
+ tensor_dir, dit_handler, r, a, d, lr, ep, bs, ga, se, sh, sd, od, rc, ts
1204
+ ):
1205
+ yield progress, log_msg, plot, state
1206
+ except Exception as e:
1207
+ logger.exception("Training wrapper error")
1208
+ yield f"❌ Error: {str(e)}", str(e), None, ts
1209
+
1210
+ training_section["start_training_btn"].click(
1211
+ fn=training_wrapper,
1212
+ inputs=[
1213
+ training_section["training_tensor_dir"],
1214
+ training_section["lora_rank"],
1215
+ training_section["lora_alpha"],
1216
+ training_section["lora_dropout"],
1217
+ training_section["learning_rate"],
1218
+ training_section["train_epochs"],
1219
+ training_section["train_batch_size"],
1220
+ training_section["gradient_accumulation"],
1221
+ training_section["save_every_n_epochs"],
1222
+ training_section["training_shift"],
1223
+ training_section["training_seed"],
1224
+ training_section["lora_output_dir"],
1225
+ training_section["resume_checkpoint_dir"],
1226
+ training_section["training_state"],
1227
+ ],
1228
+ outputs=[
1229
+ training_section["training_progress"],
1230
+ training_section["training_log"],
1231
+ training_section["training_loss_plot"],
1232
+ training_section["training_state"],
1233
+ ]
1234
+ )
1235
+
1236
+ # Stop training
1237
+ training_section["stop_training_btn"].click(
1238
+ fn=train_h.stop_training,
1239
+ inputs=[training_section["training_state"]],
1240
+ outputs=[
1241
+ training_section["training_progress"],
1242
+ training_section["training_state"],
1243
+ ]
1244
+ )
1245
+
1246
+ # Export LoRA
1247
+ training_section["export_lora_btn"].click(
1248
+ fn=train_h.export_lora,
1249
+ inputs=[
1250
+ training_section["export_path"],
1251
+ training_section["lora_output_dir"],
1252
+ ],
1253
+ outputs=[training_section["export_status"]]
1254
+ )
acestep/gradio_ui/events/generation_handlers.py ADDED
@@ -0,0 +1,1050 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generation Input Handlers Module
3
+ Contains event handlers and helper functions related to generation inputs
4
+ """
5
+ import os
6
+ import json
7
+ import random
8
+ import glob
9
+ import gradio as gr
10
+ from typing import Optional, List, Tuple
11
+ from acestep.constants import (
12
+ TASK_TYPES_TURBO,
13
+ TASK_TYPES_BASE,
14
+ )
15
+ from acestep.gradio_ui.i18n import t
16
+ from acestep.inference import understand_music, create_sample, format_sample
17
+ from acestep.gpu_config import get_global_gpu_config
18
+
19
+
20
+ def clamp_duration_to_gpu_limit(duration_value: Optional[float], llm_handler=None) -> Optional[float]:
21
+ """
22
+ Clamp duration value to GPU memory limit.
23
+
24
+ Args:
25
+ duration_value: Duration in seconds (can be None or -1 for no limit)
26
+ llm_handler: LLM handler instance (to check if LM is initialized)
27
+
28
+ Returns:
29
+ Clamped duration value, or original value if within limits
30
+ """
31
+ if duration_value is None or duration_value <= 0:
32
+ return duration_value
33
+
34
+ gpu_config = get_global_gpu_config()
35
+ lm_initialized = llm_handler.llm_initialized if llm_handler else False
36
+ max_duration = gpu_config.max_duration_with_lm if lm_initialized else gpu_config.max_duration_without_lm
37
+
38
+ if duration_value > max_duration:
39
+ return float(max_duration)
40
+
41
+ return duration_value
42
+
43
+
44
+ def parse_and_validate_timesteps(
45
+ timesteps_str: str,
46
+ inference_steps: int
47
+ ) -> Tuple[Optional[List[float]], bool, str]:
48
+ """
49
+ Parse timesteps string and validate.
50
+
51
+ Args:
52
+ timesteps_str: Comma-separated timesteps string (e.g., "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
53
+ inference_steps: Expected number of inference steps
54
+
55
+ Returns:
56
+ Tuple of (parsed_timesteps, has_warning, warning_message)
57
+ - parsed_timesteps: List of float timesteps, or None if invalid/empty
58
+ - has_warning: Whether a warning was shown
59
+ - warning_message: Description of the warning
60
+ """
61
+ if not timesteps_str or not timesteps_str.strip():
62
+ return None, False, ""
63
+
64
+ # Parse comma-separated values
65
+ values = [v.strip() for v in timesteps_str.split(",") if v.strip()]
66
+
67
+ if not values:
68
+ return None, False, ""
69
+
70
+ # Handle optional trailing 0
71
+ if values[-1] != "0":
72
+ values.append("0")
73
+
74
+ try:
75
+ timesteps = [float(v) for v in values]
76
+ except ValueError:
77
+ gr.Warning(t("messages.invalid_timesteps_format"))
78
+ return None, True, "Invalid format"
79
+
80
+ # Validate range [0, 1]
81
+ if any(ts < 0 or ts > 1 for ts in timesteps):
82
+ gr.Warning(t("messages.timesteps_out_of_range"))
83
+ return None, True, "Out of range"
84
+
85
+ # Check if count matches inference_steps
86
+ actual_steps = len(timesteps) - 1
87
+ if actual_steps != inference_steps:
88
+ gr.Warning(t("messages.timesteps_count_mismatch", actual=actual_steps, expected=inference_steps))
89
+ return timesteps, True, f"Using {actual_steps} steps from timesteps"
90
+
91
+ return timesteps, False, ""
92
+
93
+
94
+ def load_metadata(file_obj, llm_handler=None):
95
+ """Load generation parameters from a JSON file
96
+
97
+ Args:
98
+ file_obj: Uploaded file object
99
+ llm_handler: LLM handler instance (optional, for GPU duration limit check)
100
+ """
101
+ if file_obj is None:
102
+ gr.Warning(t("messages.no_file_selected"))
103
+ return [None] * 36 + [False] # Return None for all fields, False for is_format_caption
104
+
105
+ try:
106
+ # Read the uploaded file
107
+ if hasattr(file_obj, 'name'):
108
+ filepath = file_obj.name
109
+ else:
110
+ filepath = file_obj
111
+
112
+ with open(filepath, 'r', encoding='utf-8') as f:
113
+ metadata = json.load(f)
114
+
115
+ # Extract all fields
116
+ task_type = metadata.get('task_type', 'text2music')
117
+ captions = metadata.get('caption', '')
118
+ lyrics = metadata.get('lyrics', '')
119
+ vocal_language = metadata.get('vocal_language', 'unknown')
120
+
121
+ # Convert bpm
122
+ bpm_value = metadata.get('bpm')
123
+ if bpm_value is not None and bpm_value != "N/A":
124
+ try:
125
+ bpm = int(bpm_value) if bpm_value else None
126
+ except:
127
+ bpm = None
128
+ else:
129
+ bpm = None
130
+
131
+ key_scale = metadata.get('keyscale', '')
132
+ time_signature = metadata.get('timesignature', '')
133
+
134
+ # Convert duration
135
+ duration_value = metadata.get('duration', -1)
136
+ if duration_value is not None and duration_value != "N/A":
137
+ try:
138
+ audio_duration = float(duration_value)
139
+ # Clamp duration to GPU memory limit
140
+ audio_duration = clamp_duration_to_gpu_limit(audio_duration, llm_handler)
141
+ except:
142
+ audio_duration = -1
143
+ else:
144
+ audio_duration = -1
145
+
146
+ batch_size = metadata.get('batch_size', 2)
147
+ inference_steps = metadata.get('inference_steps', 8)
148
+ guidance_scale = metadata.get('guidance_scale', 7.0)
149
+ seed = metadata.get('seed', '-1')
150
+ random_seed = False # Always set to False when loading to enable reproducibility with saved seed
151
+ use_adg = metadata.get('use_adg', False)
152
+ cfg_interval_start = metadata.get('cfg_interval_start', 0.0)
153
+ cfg_interval_end = metadata.get('cfg_interval_end', 1.0)
154
+ audio_format = metadata.get('audio_format', 'mp3')
155
+ lm_temperature = metadata.get('lm_temperature', 0.85)
156
+ lm_cfg_scale = metadata.get('lm_cfg_scale', 2.0)
157
+ lm_top_k = metadata.get('lm_top_k', 0)
158
+ lm_top_p = metadata.get('lm_top_p', 0.9)
159
+ lm_negative_prompt = metadata.get('lm_negative_prompt', 'NO USER INPUT')
160
+ use_cot_metas = metadata.get('use_cot_metas', True) # Added: read use_cot_metas
161
+ use_cot_caption = metadata.get('use_cot_caption', True)
162
+ use_cot_language = metadata.get('use_cot_language', True)
163
+ audio_cover_strength = metadata.get('audio_cover_strength', 1.0)
164
+ think = metadata.get('thinking', True) # Fixed: read 'thinking' not 'think'
165
+ audio_codes = metadata.get('audio_codes', '')
166
+ repainting_start = metadata.get('repainting_start', 0.0)
167
+ repainting_end = metadata.get('repainting_end', -1)
168
+ track_name = metadata.get('track_name')
169
+ complete_track_classes = metadata.get('complete_track_classes', [])
170
+ shift = metadata.get('shift', 3.0) # Default 3.0 for base models
171
+ infer_method = metadata.get('infer_method', 'ode') # Default 'ode' for diffusion inference
172
+ custom_timesteps = metadata.get('timesteps', '') # Custom timesteps (stored as 'timesteps' in JSON)
173
+ if custom_timesteps is None:
174
+ custom_timesteps = ''
175
+ instrumental = metadata.get('instrumental', False) # Added: read instrumental
176
+
177
+ gr.Info(t("messages.params_loaded", filename=os.path.basename(filepath)))
178
+
179
+ return (
180
+ task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature,
181
+ audio_duration, batch_size, inference_steps, guidance_scale, seed, random_seed,
182
+ use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method,
183
+ custom_timesteps, # Added: custom_timesteps (between infer_method and audio_format)
184
+ audio_format, lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
185
+ use_cot_metas, use_cot_caption, use_cot_language, audio_cover_strength,
186
+ think, audio_codes, repainting_start, repainting_end,
187
+ track_name, complete_track_classes, instrumental,
188
+ True # Set is_format_caption to True when loading from file
189
+ )
190
+
191
+ except json.JSONDecodeError as e:
192
+ gr.Warning(t("messages.invalid_json", error=str(e)))
193
+ return [None] * 36 + [False]
194
+ except Exception as e:
195
+ gr.Warning(t("messages.load_error", error=str(e)))
196
+ return [None] * 36 + [False]
197
+
198
+
199
+ def load_random_example(task_type: str, llm_handler=None):
200
+ """Load a random example from the task-specific examples directory
201
+
202
+ Args:
203
+ task_type: The task type (e.g., "text2music")
204
+ llm_handler: LLM handler instance (optional, for GPU duration limit check)
205
+
206
+ Returns:
207
+ Tuple of (caption, lyrics, think, bpm, duration, keyscale, language, timesignature) for updating UI components
208
+ """
209
+ try:
210
+ # Get the project root directory
211
+ current_file = os.path.abspath(__file__)
212
+ # This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
213
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
214
+
215
+ # Construct the examples directory path
216
+ examples_dir = os.path.join(project_root, "examples", task_type)
217
+
218
+ # Check if directory exists
219
+ if not os.path.exists(examples_dir):
220
+ gr.Warning(f"Examples directory not found: examples/{task_type}/")
221
+ return "", "", True, None, None, "", "", ""
222
+
223
+ # Find all JSON files in the directory
224
+ json_files = glob.glob(os.path.join(examples_dir, "*.json"))
225
+
226
+ if not json_files:
227
+ gr.Warning(f"No JSON files found in examples/{task_type}/")
228
+ return "", "", True, None, None, "", "", ""
229
+
230
+ # Randomly select one file
231
+ selected_file = random.choice(json_files)
232
+
233
+ # Read and parse JSON
234
+ try:
235
+ with open(selected_file, 'r', encoding='utf-8') as f:
236
+ data = json.load(f)
237
+
238
+ # Extract caption (prefer 'caption', fallback to 'prompt')
239
+ caption_value = data.get('caption', data.get('prompt', ''))
240
+ if not isinstance(caption_value, str):
241
+ caption_value = str(caption_value) if caption_value else ''
242
+
243
+ # Extract lyrics
244
+ lyrics_value = data.get('lyrics', '')
245
+ if not isinstance(lyrics_value, str):
246
+ lyrics_value = str(lyrics_value) if lyrics_value else ''
247
+
248
+ # Extract think (default to True if not present)
249
+ think_value = data.get('think', True)
250
+ if not isinstance(think_value, bool):
251
+ think_value = True
252
+
253
+ # Extract optional metadata fields
254
+ bpm_value = None
255
+ if 'bpm' in data and data['bpm'] not in [None, "N/A", ""]:
256
+ try:
257
+ bpm_value = int(data['bpm'])
258
+ except (ValueError, TypeError):
259
+ pass
260
+
261
+ duration_value = None
262
+ if 'duration' in data and data['duration'] not in [None, "N/A", ""]:
263
+ try:
264
+ duration_value = float(data['duration'])
265
+ # Clamp duration to GPU memory limit
266
+ duration_value = clamp_duration_to_gpu_limit(duration_value, llm_handler)
267
+ except (ValueError, TypeError):
268
+ pass
269
+
270
+ keyscale_value = data.get('keyscale', '')
271
+ if keyscale_value in [None, "N/A"]:
272
+ keyscale_value = ''
273
+
274
+ language_value = data.get('language', '')
275
+ if language_value in [None, "N/A"]:
276
+ language_value = ''
277
+
278
+ timesignature_value = data.get('timesignature', '')
279
+ if timesignature_value in [None, "N/A"]:
280
+ timesignature_value = ''
281
+
282
+ gr.Info(t("messages.example_loaded", filename=os.path.basename(selected_file)))
283
+ return caption_value, lyrics_value, think_value, bpm_value, duration_value, keyscale_value, language_value, timesignature_value
284
+
285
+ except json.JSONDecodeError as e:
286
+ gr.Warning(t("messages.example_failed", filename=os.path.basename(selected_file), error=str(e)))
287
+ return "", "", True, None, None, "", "", ""
288
+ except Exception as e:
289
+ gr.Warning(t("messages.example_error", error=str(e)))
290
+ return "", "", True, None, None, "", "", ""
291
+
292
+ except Exception as e:
293
+ gr.Warning(t("messages.example_error", error=str(e)))
294
+ return "", "", True, None, None, "", "", ""
295
+
296
+
297
+ def sample_example_smart(llm_handler, task_type: str, constrained_decoding_debug: bool = False):
298
+ """Smart sample function that uses LM if initialized, otherwise falls back to examples
299
+
300
+ This is a Gradio wrapper that uses the understand_music API from acestep.inference
301
+ to generate examples when LM is available.
302
+
303
+ Args:
304
+ llm_handler: LLM handler instance
305
+ task_type: The task type (e.g., "text2music")
306
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
307
+
308
+ Returns:
309
+ Tuple of (caption, lyrics, think, bpm, duration, keyscale, language, timesignature) for updating UI components
310
+ """
311
+ # Check if LM is initialized
312
+ if llm_handler.llm_initialized:
313
+ # Use LM to generate example via understand_music API
314
+ try:
315
+ result = understand_music(
316
+ llm_handler=llm_handler,
317
+ audio_codes="NO USER INPUT", # Empty input triggers example generation
318
+ temperature=0.85,
319
+ use_constrained_decoding=True,
320
+ constrained_decoding_debug=constrained_decoding_debug,
321
+ )
322
+
323
+ if result.success:
324
+ gr.Info(t("messages.lm_generated"))
325
+ # Clamp duration to GPU memory limit
326
+ clamped_duration = clamp_duration_to_gpu_limit(result.duration, llm_handler)
327
+ return (
328
+ result.caption,
329
+ result.lyrics,
330
+ True, # Always enable think when using LM-generated examples
331
+ result.bpm,
332
+ clamped_duration,
333
+ result.keyscale,
334
+ result.language,
335
+ result.timesignature,
336
+ )
337
+ else:
338
+ gr.Warning(t("messages.lm_fallback"))
339
+ return load_random_example(task_type)
340
+
341
+ except Exception as e:
342
+ gr.Warning(t("messages.lm_fallback"))
343
+ return load_random_example(task_type)
344
+ else:
345
+ # LM not initialized, use examples directory
346
+ return load_random_example(task_type)
347
+
348
+
349
+ def load_random_simple_description():
350
+ """Load a random description from the simple_mode examples directory.
351
+
352
+ Returns:
353
+ Tuple of (description, instrumental, vocal_language) for updating UI components
354
+ """
355
+ try:
356
+ # Get the project root directory
357
+ current_file = os.path.abspath(__file__)
358
+ # This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
359
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
360
+
361
+ # Construct the examples directory path
362
+ examples_dir = os.path.join(project_root, "examples", "simple_mode")
363
+
364
+ # Check if directory exists
365
+ if not os.path.exists(examples_dir):
366
+ gr.Warning(t("messages.simple_examples_not_found"))
367
+ return gr.update(), gr.update(), gr.update()
368
+
369
+ # Find all JSON files in the directory
370
+ json_files = glob.glob(os.path.join(examples_dir, "*.json"))
371
+
372
+ if not json_files:
373
+ gr.Warning(t("messages.simple_examples_empty"))
374
+ return gr.update(), gr.update(), gr.update()
375
+
376
+ # Randomly select one file
377
+ selected_file = random.choice(json_files)
378
+
379
+ # Read and parse JSON
380
+ try:
381
+ with open(selected_file, 'r', encoding='utf-8') as f:
382
+ data = json.load(f)
383
+
384
+ # Extract fields
385
+ description = data.get('description', '')
386
+ instrumental = data.get('instrumental', False)
387
+ vocal_language = data.get('vocal_language', 'unknown')
388
+
389
+ # Ensure vocal_language is a string
390
+ if isinstance(vocal_language, list):
391
+ vocal_language = vocal_language[0] if vocal_language else 'unknown'
392
+
393
+ gr.Info(t("messages.simple_example_loaded", filename=os.path.basename(selected_file)))
394
+ return description, instrumental, vocal_language
395
+
396
+ except json.JSONDecodeError as e:
397
+ gr.Warning(t("messages.example_failed", filename=os.path.basename(selected_file), error=str(e)))
398
+ return gr.update(), gr.update(), gr.update()
399
+ except Exception as e:
400
+ gr.Warning(t("messages.example_error", error=str(e)))
401
+ return gr.update(), gr.update(), gr.update()
402
+
403
+ except Exception as e:
404
+ gr.Warning(t("messages.example_error", error=str(e)))
405
+ return gr.update(), gr.update(), gr.update()
406
+
407
+
408
+ def refresh_checkpoints(dit_handler):
409
+ """Refresh available checkpoints"""
410
+ choices = dit_handler.get_available_checkpoints()
411
+ return gr.update(choices=choices)
412
+
413
+
414
+ def update_model_type_settings(config_path):
415
+ """Update UI settings based on model type (fallback when handler not initialized yet)
416
+
417
+ Note: This is used as a fallback when the user changes config_path dropdown
418
+ before initializing the model. The actual settings are determined by the
419
+ handler's is_turbo_model() method after initialization.
420
+ """
421
+ if config_path is None:
422
+ config_path = ""
423
+ config_path_lower = config_path.lower()
424
+
425
+ # Determine is_turbo based on config_path string
426
+ # This is a heuristic fallback - actual model type is determined after loading
427
+ if "turbo" in config_path_lower:
428
+ is_turbo = True
429
+ elif "base" in config_path_lower:
430
+ is_turbo = False
431
+ else:
432
+ # Default to turbo settings for unknown model types
433
+ is_turbo = True
434
+
435
+ return get_model_type_ui_settings(is_turbo)
436
+
437
+
438
+ def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu, compile_model, quantization):
439
+ """Wrapper for service initialization, returns status, button state, accordion state, and model type settings"""
440
+ # Convert quantization checkbox to value (int8_weight_only if checked, None if not)
441
+ quant_value = "int8_weight_only" if quantization else None
442
+
443
+ # Initialize DiT handler
444
+ status, enable = dit_handler.initialize_service(
445
+ checkpoint, config_path, device,
446
+ use_flash_attention=use_flash_attention, compile_model=compile_model,
447
+ offload_to_cpu=offload_to_cpu, offload_dit_to_cpu=offload_dit_to_cpu,
448
+ quantization=quant_value
449
+ )
450
+
451
+ # Initialize LM handler if requested
452
+ if init_llm:
453
+ # Get checkpoint directory
454
+ current_file = os.path.abspath(__file__)
455
+ # This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
456
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
457
+ checkpoint_dir = os.path.join(project_root, "checkpoints")
458
+
459
+ lm_status, lm_success = llm_handler.initialize(
460
+ checkpoint_dir=checkpoint_dir,
461
+ lm_model_path=lm_model_path,
462
+ backend=backend,
463
+ device=device,
464
+ offload_to_cpu=offload_to_cpu,
465
+ dtype=None,
466
+ )
467
+
468
+ if lm_success:
469
+ status += f"\n{lm_status}"
470
+ else:
471
+ status += f"\n{lm_status}"
472
+ # Don't fail the entire initialization if LM fails, but log it
473
+ # Keep enable as is (DiT initialization result) even if LM fails
474
+
475
+ # Check if model is initialized - if so, collapse the accordion
476
+ is_model_initialized = dit_handler.model is not None
477
+ accordion_state = gr.Accordion(open=not is_model_initialized)
478
+
479
+ # Get model type settings based on actual loaded model
480
+ is_turbo = dit_handler.is_turbo_model()
481
+ model_type_settings = get_model_type_ui_settings(is_turbo)
482
+
483
+ return (
484
+ status,
485
+ gr.update(interactive=enable),
486
+ accordion_state,
487
+ *model_type_settings
488
+ )
489
+
490
+
491
+ def get_model_type_ui_settings(is_turbo: bool):
492
+ """Get UI settings based on whether the model is turbo or base"""
493
+ if is_turbo:
494
+ # Turbo model: max 20 steps, default 8, show shift with default 3.0, only show text2music/repaint/cover
495
+ return (
496
+ gr.update(value=8, maximum=20, minimum=1), # inference_steps
497
+ gr.update(visible=False), # guidance_scale
498
+ gr.update(visible=False), # use_adg
499
+ gr.update(value=3.0, visible=True), # shift (show with default 3.0)
500
+ gr.update(visible=False), # cfg_interval_start
501
+ gr.update(visible=False), # cfg_interval_end
502
+ gr.update(choices=TASK_TYPES_TURBO), # task_type
503
+ )
504
+ else:
505
+ # Base model: max 200 steps, default 32, show CFG/ADG/shift, show all task types
506
+ return (
507
+ gr.update(value=32, maximum=200, minimum=1), # inference_steps
508
+ gr.update(visible=True), # guidance_scale
509
+ gr.update(visible=True), # use_adg
510
+ gr.update(value=3.0, visible=True), # shift (effective for base, default 3.0)
511
+ gr.update(visible=True), # cfg_interval_start
512
+ gr.update(visible=True), # cfg_interval_end
513
+ gr.update(choices=TASK_TYPES_BASE), # task_type
514
+ )
515
+
516
+
517
+ def update_negative_prompt_visibility(init_llm_checked):
518
+ """Update negative prompt visibility: show if Initialize 5Hz LM checkbox is checked"""
519
+ return gr.update(visible=init_llm_checked)
520
+
521
+
522
+ def _has_reference_audio(reference_audio) -> bool:
523
+ """True if reference_audio has a usable value (Gradio Audio returns path string or (path, sr))."""
524
+ if reference_audio is None:
525
+ return False
526
+ if isinstance(reference_audio, str):
527
+ return bool(reference_audio.strip())
528
+ if isinstance(reference_audio, (list, tuple)) and reference_audio:
529
+ return bool(reference_audio[0])
530
+ return False
531
+
532
+
533
+ def update_audio_cover_strength_visibility(task_type_value, init_llm_checked, reference_audio=None):
534
+ """Update audio_cover_strength visibility and label. Show Similarity/Denoise when reference audio is present."""
535
+ has_reference = _has_reference_audio(reference_audio)
536
+ # Show if task is cover, LM is initialized, or reference audio is present (audio-conditioned generation)
537
+ is_visible = (task_type_value == "cover") or init_llm_checked or has_reference
538
+ # Label priority: cover -> LM codes -> Similarity/Denoise (reference audio)
539
+ if task_type_value == "cover":
540
+ label = t("generation.cover_strength_label")
541
+ info = t("generation.cover_strength_info")
542
+ elif init_llm_checked:
543
+ label = t("generation.codes_strength_label")
544
+ info = t("generation.codes_strength_info")
545
+ elif has_reference:
546
+ label = t("generation.similarity_denoise_label")
547
+ info = t("generation.similarity_denoise_info")
548
+ else:
549
+ label = t("generation.cover_strength_label")
550
+ info = t("generation.cover_strength_info")
551
+ return gr.update(visible=is_visible, label=label, info=info)
552
+
553
+
554
+ def convert_src_audio_to_codes_wrapper(dit_handler, src_audio):
555
+ """Wrapper for converting src audio to codes"""
556
+ codes_string = dit_handler.convert_src_audio_to_codes(src_audio)
557
+ return codes_string
558
+
559
+
560
+ def update_instruction_ui(
561
+ dit_handler,
562
+ task_type_value: str,
563
+ track_name_value: Optional[str],
564
+ complete_track_classes_value: list,
565
+ audio_codes_content: str = "",
566
+ init_llm_checked: bool = False,
567
+ reference_audio=None,
568
+ ) -> tuple:
569
+ """Update instruction and UI visibility based on task type."""
570
+ instruction = dit_handler.generate_instruction(
571
+ task_type=task_type_value,
572
+ track_name=track_name_value,
573
+ complete_track_classes=complete_track_classes_value
574
+ )
575
+
576
+ # Show track_name for lego and extract
577
+ track_name_visible = task_type_value in ["lego", "extract"]
578
+ # Show complete_track_classes for complete
579
+ complete_visible = task_type_value == "complete"
580
+ # Show audio_cover_strength for cover, LM initialized, or reference audio present
581
+ has_reference = _has_reference_audio(reference_audio)
582
+ audio_cover_strength_visible = (task_type_value == "cover") or init_llm_checked or has_reference
583
+ # Label priority: cover -> LM codes -> Similarity/Denoise (reference audio)
584
+ if task_type_value == "cover":
585
+ audio_cover_strength_label = t("generation.cover_strength_label")
586
+ audio_cover_strength_info = t("generation.cover_strength_info")
587
+ elif init_llm_checked:
588
+ audio_cover_strength_label = t("generation.codes_strength_label")
589
+ audio_cover_strength_info = t("generation.codes_strength_info")
590
+ elif has_reference:
591
+ audio_cover_strength_label = t("generation.similarity_denoise_label")
592
+ audio_cover_strength_info = t("generation.similarity_denoise_info")
593
+ else:
594
+ audio_cover_strength_label = t("generation.cover_strength_label")
595
+ audio_cover_strength_info = t("generation.cover_strength_info")
596
+ # Show repainting controls for repaint and lego
597
+ repainting_visible = task_type_value in ["repaint", "lego"]
598
+ # Show text2music_audio_codes if task is text2music OR if it has content
599
+ # This allows it to stay visible even if user switches task type but has codes
600
+ has_audio_codes = audio_codes_content and str(audio_codes_content).strip()
601
+ text2music_audio_codes_visible = task_type_value == "text2music" or has_audio_codes
602
+
603
+ return (
604
+ instruction, # instruction_display_gen
605
+ gr.update(visible=track_name_visible), # track_name
606
+ gr.update(visible=complete_visible), # complete_track_classes
607
+ gr.update(visible=audio_cover_strength_visible, label=audio_cover_strength_label, info=audio_cover_strength_info), # audio_cover_strength
608
+ gr.update(visible=repainting_visible), # repainting_group
609
+ gr.update(visible=text2music_audio_codes_visible), # text2music_audio_codes_group
610
+ )
611
+
612
+
613
+ def transcribe_audio_codes(llm_handler, audio_code_string, constrained_decoding_debug):
614
+ """
615
+ Transcribe audio codes to metadata using LLM understanding.
616
+ If audio_code_string is empty, generate a sample example instead.
617
+
618
+ This is a Gradio wrapper around the understand_music API in acestep.inference.
619
+
620
+ Args:
621
+ llm_handler: LLM handler instance
622
+ audio_code_string: String containing audio codes (or empty for example generation)
623
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
624
+
625
+ Returns:
626
+ Tuple of (status_message, caption, lyrics, bpm, duration, keyscale, language, timesignature, is_format_caption)
627
+ """
628
+ # Call the inference API
629
+ result = understand_music(
630
+ llm_handler=llm_handler,
631
+ audio_codes=audio_code_string,
632
+ use_constrained_decoding=True,
633
+ constrained_decoding_debug=constrained_decoding_debug,
634
+ )
635
+
636
+ # Handle error case with localized message
637
+ if not result.success:
638
+ # Use localized error message for LLM not initialized
639
+ if result.error == "LLM not initialized":
640
+ return t("messages.lm_not_initialized"), "", "", None, None, "", "", "", False
641
+ return result.status_message, "", "", None, None, "", "", "", False
642
+
643
+ # Clamp duration to GPU memory limit
644
+ clamped_duration = clamp_duration_to_gpu_limit(result.duration, llm_handler)
645
+
646
+ return (
647
+ result.status_message,
648
+ result.caption,
649
+ result.lyrics,
650
+ result.bpm,
651
+ clamped_duration,
652
+ result.keyscale,
653
+ result.language,
654
+ result.timesignature,
655
+ True # Set is_format_caption to True (from Transcribe/LM understanding)
656
+ )
657
+
658
+
659
+ def update_transcribe_button_text(audio_code_string):
660
+ """
661
+ Update the transcribe button text based on input content.
662
+ If empty: "Generate Example"
663
+ If has content: "Transcribe"
664
+ """
665
+ if not audio_code_string or not audio_code_string.strip():
666
+ return gr.update(value="Generate Example")
667
+ else:
668
+ return gr.update(value="Transcribe")
669
+
670
+
671
+ def reset_format_caption_flag():
672
+ """Reset is_format_caption to False when user manually edits caption/metadata"""
673
+ return False
674
+
675
+
676
+ def update_audio_uploads_accordion(reference_audio, src_audio):
677
+ """Update Audio Uploads accordion open state based on whether audio files are present"""
678
+ has_audio = (reference_audio is not None) or (src_audio is not None)
679
+ return gr.Accordion(open=has_audio)
680
+
681
+
682
+ def handle_instrumental_checkbox(instrumental_checked, current_lyrics):
683
+ """
684
+ Handle instrumental checkbox changes.
685
+ When checked: if no lyrics, fill with [Instrumental]
686
+ When unchecked: if lyrics is [Instrumental], clear it
687
+ """
688
+ if instrumental_checked:
689
+ # If checked and no lyrics, fill with [Instrumental]
690
+ if not current_lyrics or not current_lyrics.strip():
691
+ return "[Instrumental]"
692
+ else:
693
+ # Has lyrics, don't change
694
+ return current_lyrics
695
+ else:
696
+ # If unchecked and lyrics is exactly [Instrumental], clear it
697
+ if current_lyrics and current_lyrics.strip() == "[Instrumental]":
698
+ return ""
699
+ else:
700
+ # Has other lyrics, don't change
701
+ return current_lyrics
702
+
703
+
704
+ def handle_simple_instrumental_change(is_instrumental: bool):
705
+ """
706
+ Handle simple mode instrumental checkbox changes.
707
+ When checked: set vocal_language to "unknown" and disable editing.
708
+ When unchecked: enable vocal_language editing.
709
+
710
+ Args:
711
+ is_instrumental: Whether instrumental checkbox is checked
712
+
713
+ Returns:
714
+ gr.update for simple_vocal_language dropdown
715
+ """
716
+ if is_instrumental:
717
+ return gr.update(value="unknown", interactive=False)
718
+ else:
719
+ return gr.update(interactive=True)
720
+
721
+
722
+ def update_audio_components_visibility(batch_size):
723
+ """Show/hide individual audio components based on batch size (1-8)
724
+
725
+ Row 1: Components 1-4 (batch_size 1-4)
726
+ Row 2: Components 5-8 (batch_size 5-8)
727
+ """
728
+ # Clamp batch size to 1-8 range for UI
729
+ batch_size = min(max(int(batch_size), 1), 8)
730
+
731
+ # Row 1 columns (1-4)
732
+ updates_row1 = (
733
+ gr.update(visible=True), # audio_col_1: always visible
734
+ gr.update(visible=batch_size >= 2), # audio_col_2
735
+ gr.update(visible=batch_size >= 3), # audio_col_3
736
+ gr.update(visible=batch_size >= 4), # audio_col_4
737
+ )
738
+
739
+ # Row 2 container and columns (5-8)
740
+ show_row_5_8 = batch_size >= 5
741
+ updates_row2 = (
742
+ gr.update(visible=show_row_5_8), # audio_row_5_8 (container)
743
+ gr.update(visible=batch_size >= 5), # audio_col_5
744
+ gr.update(visible=batch_size >= 6), # audio_col_6
745
+ gr.update(visible=batch_size >= 7), # audio_col_7
746
+ gr.update(visible=batch_size >= 8), # audio_col_8
747
+ )
748
+
749
+ return updates_row1 + updates_row2
750
+
751
+
752
+ def handle_generation_mode_change(mode: str):
753
+ """
754
+ Handle generation mode change between Simple and Custom modes.
755
+
756
+ In Simple mode:
757
+ - Show simple mode group (query input, instrumental checkbox, create button)
758
+ - Collapse caption and lyrics accordions
759
+ - Hide optional parameters accordion
760
+ - Disable generate button until sample is created
761
+
762
+ In Custom mode:
763
+ - Hide simple mode group
764
+ - Expand caption and lyrics accordions
765
+ - Show optional parameters accordion
766
+ - Enable generate button
767
+
768
+ Args:
769
+ mode: "simple" or "custom"
770
+
771
+ Returns:
772
+ Tuple of updates for:
773
+ - simple_mode_group (visibility)
774
+ - caption_accordion (open state)
775
+ - lyrics_accordion (open state)
776
+ - generate_btn (interactive state)
777
+ - simple_sample_created (reset state)
778
+ - optional_params_accordion (visibility)
779
+ """
780
+ is_simple = mode == "simple"
781
+
782
+ return (
783
+ gr.update(visible=is_simple), # simple_mode_group
784
+ gr.Accordion(open=not is_simple), # caption_accordion - collapsed in simple, open in custom
785
+ gr.Accordion(open=not is_simple), # lyrics_accordion - collapsed in simple, open in custom
786
+ gr.update(interactive=not is_simple), # generate_btn - disabled in simple until sample created
787
+ False, # simple_sample_created - reset to False on mode change
788
+ gr.Accordion(open=not is_simple), # optional_params_accordion - hidden in simple mode
789
+ )
790
+
791
+
792
+ def handle_create_sample(
793
+ llm_handler,
794
+ query: str,
795
+ instrumental: bool,
796
+ vocal_language: str,
797
+ lm_temperature: float,
798
+ lm_top_k: int,
799
+ lm_top_p: float,
800
+ constrained_decoding_debug: bool = False,
801
+ ):
802
+ """
803
+ Handle the Create Sample button click in Simple mode.
804
+
805
+ Creates a sample from the user's query using the LLM, then populates
806
+ the caption, lyrics, and metadata fields.
807
+
808
+ Note: cfg_scale and negative_prompt are not supported in create_sample mode.
809
+
810
+ Args:
811
+ llm_handler: LLM handler instance
812
+ query: User's natural language music description
813
+ instrumental: Whether to generate instrumental music
814
+ vocal_language: Preferred vocal language for constrained decoding
815
+ lm_temperature: LLM temperature for generation
816
+ lm_top_k: LLM top-k sampling
817
+ lm_top_p: LLM top-p sampling
818
+ constrained_decoding_debug: Whether to enable debug logging
819
+
820
+ Returns:
821
+ Tuple of updates for:
822
+ - captions
823
+ - lyrics
824
+ - bpm
825
+ - audio_duration
826
+ - key_scale
827
+ - vocal_language
828
+ - time_signature
829
+ - instrumental_checkbox
830
+ - caption_accordion (open)
831
+ - lyrics_accordion (open)
832
+ - generate_btn (interactive)
833
+ - simple_sample_created (True)
834
+ - think_checkbox (True)
835
+ - is_format_caption_state (True)
836
+ - status_output
837
+ """
838
+ # Check if LLM is initialized
839
+ if not llm_handler.llm_initialized:
840
+ gr.Warning(t("messages.lm_not_initialized"))
841
+ return (
842
+ gr.update(), # captions - no change
843
+ gr.update(), # lyrics - no change
844
+ gr.update(), # bpm - no change
845
+ gr.update(), # audio_duration - no change
846
+ gr.update(), # key_scale - no change
847
+ gr.update(), # vocal_language - no change
848
+ gr.update(), # time_signature - no change
849
+ gr.update(), # instrumental_checkbox - no change
850
+ gr.update(), # caption_accordion - no change
851
+ gr.update(), # lyrics_accordion - no change
852
+ gr.update(interactive=False), # generate_btn - keep disabled
853
+ False, # simple_sample_created - still False
854
+ gr.update(), # think_checkbox - no change
855
+ gr.update(), # is_format_caption_state - no change
856
+ t("messages.lm_not_initialized"), # status_output
857
+ )
858
+
859
+ # Convert LM parameters
860
+ top_k_value = None if not lm_top_k or lm_top_k == 0 else int(lm_top_k)
861
+ top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p
862
+
863
+ # Call create_sample API
864
+ # Note: cfg_scale and negative_prompt are not supported in create_sample mode
865
+ result = create_sample(
866
+ llm_handler=llm_handler,
867
+ query=query,
868
+ instrumental=instrumental,
869
+ vocal_language=vocal_language,
870
+ temperature=lm_temperature,
871
+ top_k=top_k_value,
872
+ top_p=top_p_value,
873
+ use_constrained_decoding=True,
874
+ constrained_decoding_debug=constrained_decoding_debug,
875
+ )
876
+
877
+ # Handle error
878
+ if not result.success:
879
+ gr.Warning(result.status_message or t("messages.sample_creation_failed"))
880
+ return (
881
+ gr.update(), # captions - no change
882
+ gr.update(), # lyrics - no change
883
+ gr.update(), # bpm - no change
884
+ gr.update(), # audio_duration - no change
885
+ gr.update(), # key_scale - no change
886
+ gr.update(), # vocal_language - no change
887
+ gr.update(), # simple vocal_language - no change
888
+ gr.update(), # time_signature - no change
889
+ gr.update(), # instrumental_checkbox - no change
890
+ gr.update(), # caption_accordion - no change
891
+ gr.update(), # lyrics_accordion - no change
892
+ gr.update(interactive=False), # generate_btn - keep disabled
893
+ False, # simple_sample_created - still False
894
+ gr.update(), # think_checkbox - no change
895
+ gr.update(), # is_format_caption_state - no change
896
+ result.status_message or t("messages.sample_creation_failed"), # status_output
897
+ )
898
+
899
+ # Success - populate fields
900
+ gr.Info(t("messages.sample_created"))
901
+
902
+ # Clamp duration to GPU memory limit
903
+ clamped_duration = clamp_duration_to_gpu_limit(result.duration, llm_handler)
904
+ audio_duration_value = clamped_duration if clamped_duration and clamped_duration > 0 else -1
905
+
906
+ return (
907
+ result.caption, # captions
908
+ result.lyrics, # lyrics
909
+ result.bpm, # bpm
910
+ audio_duration_value, # audio_duration
911
+ result.keyscale, # key_scale
912
+ result.language, # vocal_language
913
+ result.language, # simple vocal_language
914
+ result.timesignature, # time_signature
915
+ result.instrumental, # instrumental_checkbox
916
+ gr.Accordion(open=True), # caption_accordion - expand
917
+ gr.Accordion(open=True), # lyrics_accordion - expand
918
+ gr.update(interactive=True), # generate_btn - enable
919
+ True, # simple_sample_created - True
920
+ True, # think_checkbox - enable thinking
921
+ True, # is_format_caption_state - True (LM-generated)
922
+ result.status_message, # status_output
923
+ )
924
+
925
+
926
+ def handle_format_sample(
927
+ llm_handler,
928
+ caption: str,
929
+ lyrics: str,
930
+ bpm,
931
+ audio_duration,
932
+ key_scale: str,
933
+ time_signature: str,
934
+ lm_temperature: float,
935
+ lm_top_k: int,
936
+ lm_top_p: float,
937
+ constrained_decoding_debug: bool = False,
938
+ ):
939
+ """
940
+ Handle the Format button click to format caption and lyrics.
941
+
942
+ Takes user-provided caption and lyrics, and uses the LLM to generate
943
+ structured music metadata and an enhanced description.
944
+
945
+ Note: cfg_scale and negative_prompt are not supported in format mode.
946
+
947
+ Args:
948
+ llm_handler: LLM handler instance
949
+ caption: User's caption/description
950
+ lyrics: User's lyrics
951
+ bpm: User-provided BPM (optional, for constrained decoding)
952
+ audio_duration: User-provided duration (optional, for constrained decoding)
953
+ key_scale: User-provided key scale (optional, for constrained decoding)
954
+ time_signature: User-provided time signature (optional, for constrained decoding)
955
+ lm_temperature: LLM temperature for generation
956
+ lm_top_k: LLM top-k sampling
957
+ lm_top_p: LLM top-p sampling
958
+ constrained_decoding_debug: Whether to enable debug logging
959
+
960
+ Returns:
961
+ Tuple of updates for:
962
+ - captions
963
+ - lyrics
964
+ - bpm
965
+ - audio_duration
966
+ - key_scale
967
+ - vocal_language
968
+ - time_signature
969
+ - is_format_caption_state
970
+ - status_output
971
+ """
972
+ # Check if LLM is initialized
973
+ if not llm_handler.llm_initialized:
974
+ gr.Warning(t("messages.lm_not_initialized"))
975
+ return (
976
+ gr.update(), # captions - no change
977
+ gr.update(), # lyrics - no change
978
+ gr.update(), # bpm - no change
979
+ gr.update(), # audio_duration - no change
980
+ gr.update(), # key_scale - no change
981
+ gr.update(), # vocal_language - no change
982
+ gr.update(), # time_signature - no change
983
+ gr.update(), # is_format_caption_state - no change
984
+ t("messages.lm_not_initialized"), # status_output
985
+ )
986
+
987
+ # Build user_metadata from provided values for constrained decoding
988
+ user_metadata = {}
989
+ if bpm is not None and bpm > 0:
990
+ user_metadata['bpm'] = int(bpm)
991
+ if audio_duration is not None and float(audio_duration) > 0:
992
+ user_metadata['duration'] = int(audio_duration)
993
+ if key_scale and key_scale.strip():
994
+ user_metadata['keyscale'] = key_scale.strip()
995
+ if time_signature and time_signature.strip():
996
+ user_metadata['timesignature'] = time_signature.strip()
997
+
998
+ # Only pass user_metadata if we have at least one field
999
+ user_metadata_to_pass = user_metadata if user_metadata else None
1000
+
1001
+ # Convert LM parameters
1002
+ top_k_value = None if not lm_top_k or lm_top_k == 0 else int(lm_top_k)
1003
+ top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p
1004
+
1005
+ # Call format_sample API
1006
+ result = format_sample(
1007
+ llm_handler=llm_handler,
1008
+ caption=caption,
1009
+ lyrics=lyrics,
1010
+ user_metadata=user_metadata_to_pass,
1011
+ temperature=lm_temperature,
1012
+ top_k=top_k_value,
1013
+ top_p=top_p_value,
1014
+ use_constrained_decoding=True,
1015
+ constrained_decoding_debug=constrained_decoding_debug,
1016
+ )
1017
+
1018
+ # Handle error
1019
+ if not result.success:
1020
+ gr.Warning(result.status_message or t("messages.format_failed"))
1021
+ return (
1022
+ gr.update(), # captions - no change
1023
+ gr.update(), # lyrics - no change
1024
+ gr.update(), # bpm - no change
1025
+ gr.update(), # audio_duration - no change
1026
+ gr.update(), # key_scale - no change
1027
+ gr.update(), # vocal_language - no change
1028
+ gr.update(), # time_signature - no change
1029
+ gr.update(), # is_format_caption_state - no change
1030
+ result.status_message or t("messages.format_failed"), # status_output
1031
+ )
1032
+
1033
+ # Success - populate fields
1034
+ gr.Info(t("messages.format_success"))
1035
+
1036
+ # Clamp duration to GPU memory limit
1037
+ clamped_duration = clamp_duration_to_gpu_limit(result.duration, llm_handler)
1038
+ audio_duration_value = clamped_duration if clamped_duration and clamped_duration > 0 else -1
1039
+
1040
+ return (
1041
+ result.caption, # captions
1042
+ result.lyrics, # lyrics
1043
+ result.bpm, # bpm
1044
+ audio_duration_value, # audio_duration
1045
+ result.keyscale, # key_scale
1046
+ result.language, # vocal_language
1047
+ result.timesignature, # time_signature
1048
+ True, # is_format_caption_state - True (LM-formatted)
1049
+ result.status_message, # status_output
1050
+ )
acestep/gradio_ui/events/results_handlers.py ADDED
The diff for this file is too large to render. See raw diff
 
acestep/gradio_ui/events/training_handlers.py ADDED
@@ -0,0 +1,829 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Event Handlers for Training Tab
3
+
4
+ Contains all event handler functions for the dataset builder and training UI.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ from typing import Any, Dict, List, Tuple, Optional
10
+ from loguru import logger
11
+ import gradio as gr
12
+
13
+ from acestep.training.dataset_builder import DatasetBuilder, AudioSample
14
+ from acestep.debug_utils import debug_log_for, debug_start_for, debug_end_for
15
+ from acestep.gpu_config import get_global_gpu_config
16
+
17
+
18
+ def create_dataset_builder() -> DatasetBuilder:
19
+ """Create a new DatasetBuilder instance."""
20
+ return DatasetBuilder()
21
+
22
+
23
+ def _safe_slider(max_value: int, value: int = 0, visible: Optional[bool] = None) -> gr.Slider:
24
+ """Create a slider with a non-zero range to avoid Gradio math errors."""
25
+ max_value = max(1, int(max_value))
26
+ kwargs = {"maximum": max_value, "value": min(int(value), max_value)}
27
+ if visible is not None:
28
+ kwargs["visible"] = visible
29
+ return gr.Slider(**kwargs)
30
+
31
+ def scan_directory(
32
+ audio_dir: str,
33
+ dataset_name: str,
34
+ custom_tag: str,
35
+ tag_position: str,
36
+ all_instrumental: bool,
37
+ builder_state: Optional[DatasetBuilder],
38
+ ) -> Tuple[Any, str, Any, DatasetBuilder]:
39
+ """Scan a directory for audio files.
40
+
41
+ Returns:
42
+ Tuple of (table_data, status, slider_update, builder_state)
43
+ """
44
+ if not audio_dir or not audio_dir.strip():
45
+ return [], "� Please enter a directory path", _safe_slider(0, value=0, visible=False), builder_state
46
+
47
+ # Create or use existing builder
48
+ builder = builder_state if builder_state else DatasetBuilder()
49
+
50
+ # Set metadata before scanning
51
+ builder.metadata.name = dataset_name
52
+ builder.metadata.custom_tag = custom_tag
53
+ builder.metadata.tag_position = tag_position
54
+ builder.metadata.all_instrumental = all_instrumental
55
+
56
+ # Scan directory
57
+ samples, status = builder.scan_directory(audio_dir.strip())
58
+
59
+ if not samples:
60
+ return [], status, _safe_slider(0, value=0, visible=False), builder
61
+
62
+ # Set instrumental and tag for all samples
63
+ builder.set_all_instrumental(all_instrumental)
64
+ if custom_tag:
65
+ builder.set_custom_tag(custom_tag, tag_position)
66
+
67
+ # Get table data
68
+ table_data = builder.get_samples_dataframe_data()
69
+
70
+ # Calculate slider max and return as Slider update
71
+ slider_max = max(0, len(samples) - 1)
72
+
73
+ return table_data, status, _safe_slider(slider_max, value=0, visible=len(samples) > 1), builder
74
+
75
+
76
+ def auto_label_all(
77
+ dit_handler,
78
+ llm_handler,
79
+ builder_state: Optional[DatasetBuilder],
80
+ skip_metas: bool = False,
81
+ format_lyrics: bool = False,
82
+ transcribe_lyrics: bool = False,
83
+ only_unlabeled: bool = False,
84
+ progress=None,
85
+ ) -> Tuple[List[List[Any]], str, DatasetBuilder]:
86
+ """Auto-label all samples in the dataset.
87
+
88
+ Args:
89
+ dit_handler: DiT handler for audio processing
90
+ llm_handler: LLM handler for caption generation
91
+ builder_state: Dataset builder state
92
+ skip_metas: If True, skip generating BPM/Key/TimeSig but still generate caption/genre
93
+ format_lyrics: If True, use LLM to format user-provided lyrics from .txt files
94
+ transcribe_lyrics: If True, use LLM to transcribe lyrics from audio (ignores .txt files)
95
+ only_unlabeled: If True, only label samples without caption
96
+ progress: Progress callback
97
+
98
+ Returns:
99
+ Tuple of (table_data, status, builder_state)
100
+ """
101
+ if builder_state is None:
102
+ return [], "� Please scan a directory first", builder_state
103
+
104
+ if not builder_state.samples:
105
+ return [], "� No samples to label. Please scan a directory first.", builder_state
106
+
107
+ # Check if handlers are initialized
108
+ if dit_handler is None or dit_handler.model is None:
109
+ return builder_state.get_samples_dataframe_data(), "� Model not initialized. Please initialize the service first.", builder_state
110
+
111
+ if llm_handler is None or not llm_handler.llm_initialized:
112
+ return builder_state.get_samples_dataframe_data(), "� LLM not initialized. Please initialize the service with LLM enabled.", builder_state
113
+
114
+ def progress_callback(msg):
115
+ if progress:
116
+ try:
117
+ progress(msg)
118
+ except:
119
+ pass
120
+
121
+ # Label all samples (skip_metas only skips BPM/Key/TimeSig, still generates caption/genre)
122
+ samples, status = builder_state.label_all_samples(
123
+ dit_handler=dit_handler,
124
+ llm_handler=llm_handler,
125
+ format_lyrics=format_lyrics,
126
+ transcribe_lyrics=transcribe_lyrics,
127
+ skip_metas=skip_metas,
128
+ only_unlabeled=only_unlabeled,
129
+ progress_callback=progress_callback,
130
+ )
131
+
132
+ # Get updated table data
133
+ table_data = builder_state.get_samples_dataframe_data()
134
+
135
+ # Force UI refresh for table and status
136
+ return gr.update(value=table_data), gr.update(value=status), builder_state
137
+
138
+
139
+ def get_sample_preview(
140
+ sample_idx: int,
141
+ builder_state: Optional[DatasetBuilder],
142
+ ):
143
+ """Get preview data for a specific sample.
144
+
145
+ Returns:
146
+ Tuple of (audio_path, filename, caption, genre, prompt_override, lyrics, bpm, keyscale, timesig,
147
+ duration, language, instrumental, raw_lyrics, raw_lyrics_visible)
148
+ """
149
+ empty = (None, "", "", "", "Use Global Ratio", "", None, "", "", 0.0, "instrumental", True, "", False)
150
+
151
+ if builder_state is None or not builder_state.samples:
152
+ return empty
153
+
154
+ if sample_idx is None:
155
+ return empty
156
+
157
+ idx = int(sample_idx)
158
+ if idx < 0 or idx >= len(builder_state.samples):
159
+ return empty
160
+
161
+ sample = builder_state.samples[idx]
162
+
163
+ # Show raw lyrics panel only when raw lyrics exist
164
+ has_raw = sample.has_raw_lyrics()
165
+
166
+ # Convert prompt_override to dropdown choice
167
+ if sample.prompt_override == "genre":
168
+ override_choice = "Genre"
169
+ elif sample.prompt_override == "caption":
170
+ override_choice = "Caption"
171
+ else:
172
+ override_choice = "Use Global Ratio"
173
+
174
+ display_lyrics = sample.lyrics if sample.lyrics else sample.formatted_lyrics
175
+
176
+ return (
177
+ sample.audio_path,
178
+ sample.filename,
179
+ sample.caption,
180
+ sample.genre,
181
+ override_choice,
182
+ display_lyrics,
183
+ sample.bpm,
184
+ sample.keyscale,
185
+ sample.timesignature,
186
+ sample.duration,
187
+ sample.language,
188
+ sample.is_instrumental,
189
+ sample.raw_lyrics if has_raw else "",
190
+ has_raw,
191
+ )
192
+
193
+
194
+ def save_sample_edit(
195
+ sample_idx: int,
196
+ caption: str,
197
+ genre: str,
198
+ prompt_override: str,
199
+ lyrics: str,
200
+ bpm: Optional[int],
201
+ keyscale: str,
202
+ timesig: str,
203
+ language: str,
204
+ is_instrumental: bool,
205
+ builder_state: Optional[DatasetBuilder],
206
+ ) -> Tuple[List[List[Any]], str, DatasetBuilder]:
207
+ """Save edits to a sample.
208
+
209
+ Returns:
210
+ Tuple of (table_data, status, builder_state)
211
+ """
212
+ if builder_state is None:
213
+ return [], "� No dataset loaded", builder_state
214
+
215
+ idx = int(sample_idx)
216
+
217
+ # Convert dropdown choice to prompt_override value
218
+ if prompt_override == "Genre":
219
+ override_value = "genre"
220
+ elif prompt_override == "Caption":
221
+ override_value = "caption"
222
+ else:
223
+ override_value = None # Use Global Ratio
224
+
225
+ # Update sample
226
+ updated_lyrics = lyrics if not is_instrumental else "[Instrumental]"
227
+ updated_formatted = updated_lyrics if updated_lyrics and updated_lyrics != "[Instrumental]" else ""
228
+ sample, status = builder_state.update_sample(
229
+ idx,
230
+ caption=caption,
231
+ genre=genre,
232
+ prompt_override=override_value,
233
+ lyrics=updated_lyrics,
234
+ formatted_lyrics=updated_formatted,
235
+ bpm=int(bpm) if bpm else None,
236
+ keyscale=keyscale,
237
+ timesignature=timesig,
238
+ language="unknown" if is_instrumental else language,
239
+ is_instrumental=is_instrumental,
240
+ labeled=True,
241
+ )
242
+
243
+ # Get updated table data
244
+ table_data = builder_state.get_samples_dataframe_data()
245
+
246
+ return table_data, status, builder_state
247
+
248
+
249
+ def update_settings(
250
+ custom_tag: str,
251
+ tag_position: str,
252
+ all_instrumental: bool,
253
+ genre_ratio: int,
254
+ builder_state: Optional[DatasetBuilder],
255
+ ) -> DatasetBuilder:
256
+ """Update dataset settings.
257
+
258
+ Returns:
259
+ Updated builder_state
260
+ """
261
+ if builder_state is None:
262
+ return builder_state
263
+
264
+ if custom_tag:
265
+ builder_state.set_custom_tag(custom_tag, tag_position)
266
+
267
+ builder_state.set_all_instrumental(all_instrumental)
268
+ builder_state.metadata.genre_ratio = int(genre_ratio)
269
+
270
+ return builder_state
271
+
272
+
273
+ def save_dataset(
274
+ save_path: str,
275
+ dataset_name: str,
276
+ builder_state: Optional[DatasetBuilder],
277
+ ) -> Tuple[str, Any]:
278
+ """Save the dataset to a JSON file.
279
+
280
+ Returns:
281
+ Status message
282
+ """
283
+ if builder_state is None:
284
+ return "� No dataset to save. Please scan a directory first.", gr.update()
285
+
286
+ if not builder_state.samples:
287
+ return "� No samples in dataset.", gr.update()
288
+
289
+ if not save_path or not save_path.strip():
290
+ return "� Please enter a save path.", gr.update()
291
+
292
+ save_path = save_path.strip()
293
+ if not save_path.lower().endswith(".json"):
294
+ save_path = save_path + ".json"
295
+
296
+ # Check if any samples are labeled
297
+ labeled_count = builder_state.get_labeled_count()
298
+ if labeled_count == 0:
299
+ return "�️ Warning: No samples have been labeled. Consider auto-labeling first.\nSaving anyway...", gr.update(value=save_path)
300
+
301
+ return builder_state.save_dataset(save_path, dataset_name), gr.update(value=save_path)
302
+
303
+
304
+ def load_existing_dataset_for_preprocess(
305
+ dataset_path: str,
306
+ builder_state: Optional[DatasetBuilder],
307
+ ):
308
+ """Load an existing dataset JSON file for preprocessing.
309
+
310
+ This allows users to load a previously saved dataset and proceed to preprocessing
311
+ without having to re-scan and re-label.
312
+
313
+ Returns:
314
+ Tuple of (status, table_data, slider_update, builder_state,
315
+ audio_path, filename, caption, genre, prompt_override,
316
+ lyrics, bpm, keyscale, timesig, duration, language, instrumental,
317
+ raw_lyrics, has_raw)
318
+ """
319
+ # Empty preview: (audio_path, filename, caption, genre, prompt_override, lyrics, bpm, keyscale, timesig, duration, language, instrumental, raw_lyrics, has_raw)
320
+ empty_preview = (None, "", "", "", "Use Global Ratio", "", None, "", "", 0.0, "instrumental", True, "", False)
321
+
322
+ if not dataset_path or not dataset_path.strip():
323
+ updates = (gr.update(), gr.update(), gr.update(), gr.update(), gr.update())
324
+ return ("� Please enter a dataset path", [], _safe_slider(0, value=0, visible=False), builder_state) + empty_preview + updates
325
+
326
+ dataset_path = dataset_path.strip()
327
+ debug_log_for("dataset", f"UI load_existing_dataset_for_preprocess: path='{dataset_path}'")
328
+
329
+ if not os.path.exists(dataset_path):
330
+ updates = (gr.update(), gr.update(), gr.update(), gr.update(), gr.update())
331
+ return (f"� Dataset not found: {dataset_path}", [], _safe_slider(0, value=0, visible=False), builder_state) + empty_preview + updates
332
+
333
+ # Create new builder (don't reuse old state when loading a file)
334
+ builder = DatasetBuilder()
335
+
336
+ # Load the dataset
337
+ t0 = debug_start_for("dataset", "load_dataset")
338
+ samples, status = builder.load_dataset(dataset_path)
339
+ debug_end_for("dataset", "load_dataset", t0)
340
+
341
+ if not samples:
342
+ updates = (gr.update(), gr.update(), gr.update(), gr.update(), gr.update())
343
+ return (status, [], _safe_slider(0, value=0, visible=False), builder) + empty_preview + updates
344
+
345
+ # Get table data
346
+ table_data = builder.get_samples_dataframe_data()
347
+
348
+ # Calculate slider max
349
+ slider_max = max(0, len(samples) - 1)
350
+
351
+ # Create info text
352
+ labeled_count = builder.get_labeled_count()
353
+ info = f"� Loaded dataset: {builder.metadata.name}\n"
354
+ info += f"� Samples: {len(samples)} ({labeled_count} labeled)\n"
355
+ info += f"���️ Custom Tag: {builder.metadata.custom_tag or '(none)'}\n"
356
+ info += "� Ready for preprocessing! You can also edit samples below."
357
+ if any((s.formatted_lyrics and not s.lyrics) for s in builder.samples):
358
+ info += "\n�️ Showing formatted lyrics where lyrics are empty."
359
+
360
+ # Get first sample preview
361
+ first_sample = builder.samples[0]
362
+ has_raw = first_sample.has_raw_lyrics()
363
+
364
+ # Convert prompt_override to dropdown choice
365
+ if first_sample.prompt_override == "genre":
366
+ override_choice = "Genre"
367
+ elif first_sample.prompt_override == "caption":
368
+ override_choice = "Caption"
369
+ else:
370
+ override_choice = "Use Global Ratio"
371
+
372
+ display_lyrics = first_sample.lyrics if first_sample.lyrics else first_sample.formatted_lyrics
373
+
374
+ preview = (
375
+ first_sample.audio_path,
376
+ first_sample.filename,
377
+ first_sample.caption,
378
+ first_sample.genre,
379
+ override_choice,
380
+ display_lyrics,
381
+ first_sample.bpm,
382
+ first_sample.keyscale,
383
+ first_sample.timesignature,
384
+ first_sample.duration,
385
+ first_sample.language,
386
+ first_sample.is_instrumental,
387
+ first_sample.raw_lyrics if has_raw else "",
388
+ has_raw,
389
+ )
390
+
391
+ updates = (
392
+ gr.update(value=builder.metadata.name),
393
+ gr.update(value=builder.metadata.custom_tag),
394
+ gr.update(value=builder.metadata.tag_position),
395
+ gr.update(value=builder.metadata.all_instrumental),
396
+ gr.update(value=builder.metadata.genre_ratio),
397
+ )
398
+
399
+ return (info, table_data, _safe_slider(slider_max, value=0, visible=len(samples) > 1), builder) + preview + updates
400
+
401
+
402
+ def preprocess_dataset(
403
+ output_dir: str,
404
+ dit_handler,
405
+ builder_state: Optional[DatasetBuilder],
406
+ progress=None,
407
+ ) -> str:
408
+ """Preprocess dataset to tensor files for fast training.
409
+
410
+ This converts audio files to VAE latents and text to embeddings.
411
+
412
+ Returns:
413
+ Status message
414
+ """
415
+ if builder_state is None:
416
+ return "� No dataset loaded. Please scan a directory first."
417
+
418
+ if not builder_state.samples:
419
+ return "� No samples in dataset."
420
+
421
+ labeled_count = builder_state.get_labeled_count()
422
+ if labeled_count == 0:
423
+ return "� No labeled samples. Please auto-label or manually label samples first."
424
+
425
+ if not output_dir or not output_dir.strip():
426
+ return "� Please enter an output directory."
427
+
428
+ if dit_handler is None or dit_handler.model is None:
429
+ return "� Model not initialized. Please initialize the service first."
430
+
431
+ def progress_callback(msg):
432
+ if progress:
433
+ try:
434
+ progress(msg)
435
+ except:
436
+ pass
437
+
438
+ # Run preprocessing
439
+ t0 = debug_start_for("dataset", "preprocess_to_tensors")
440
+ output_paths, status = builder_state.preprocess_to_tensors(
441
+ dit_handler=dit_handler,
442
+ output_dir=output_dir.strip(),
443
+ progress_callback=progress_callback,
444
+ )
445
+ debug_end_for("dataset", "preprocess_to_tensors", t0)
446
+
447
+ return status
448
+
449
+
450
+ def load_training_dataset(
451
+ tensor_dir: str,
452
+ ) -> str:
453
+ """Load a preprocessed tensor dataset for training.
454
+
455
+ Returns:
456
+ Info text about the dataset
457
+ """
458
+ if not tensor_dir or not tensor_dir.strip():
459
+ return "� Please enter a tensor directory path"
460
+
461
+ tensor_dir = tensor_dir.strip()
462
+
463
+ if not os.path.exists(tensor_dir):
464
+ return f"� Directory not found: {tensor_dir}"
465
+
466
+ if not os.path.isdir(tensor_dir):
467
+ return f"� Not a directory: {tensor_dir}"
468
+
469
+ # Check for manifest
470
+ manifest_path = os.path.join(tensor_dir, "manifest.json")
471
+ if os.path.exists(manifest_path):
472
+ try:
473
+ with open(manifest_path, 'r') as f:
474
+ manifest = json.load(f)
475
+
476
+ num_samples = manifest.get("num_samples", 0)
477
+ metadata = manifest.get("metadata", {})
478
+ name = metadata.get("name", "Unknown")
479
+ custom_tag = metadata.get("custom_tag", "")
480
+
481
+ info = f"� Loaded preprocessed dataset: {name}\n"
482
+ info += f"� Samples: {num_samples} preprocessed tensors\n"
483
+ info += f"���️ Custom Tag: {custom_tag or '(none)'}"
484
+
485
+ return info
486
+ except Exception as e:
487
+ logger.warning(f"Failed to read manifest: {e}")
488
+
489
+ # Fallback: count .pt files
490
+ pt_files = [f for f in os.listdir(tensor_dir) if f.endswith('.pt')]
491
+
492
+ if not pt_files:
493
+ return f"� No .pt tensor files found in {tensor_dir}"
494
+
495
+ info = f"� Found {len(pt_files)} tensor files in {tensor_dir}\n"
496
+ info += "�️ No manifest.json found - using all .pt files"
497
+
498
+ return info
499
+
500
+
501
+ # Training handlers
502
+
503
+ import time
504
+ import re
505
+
506
+
507
+ def _format_duration(seconds):
508
+ """Format seconds to human readable string."""
509
+ seconds = int(seconds)
510
+ if seconds < 60:
511
+ return f"{seconds}s"
512
+ elif seconds < 3600:
513
+ return f"{seconds // 60}m {seconds % 60}s"
514
+ else:
515
+ return f"{seconds // 3600}h {(seconds % 3600) // 60}m"
516
+
517
+
518
+ def start_training(
519
+ tensor_dir: str,
520
+ dit_handler,
521
+ lora_rank: int,
522
+ lora_alpha: int,
523
+ lora_dropout: float,
524
+ learning_rate: float,
525
+ train_epochs: int,
526
+ train_batch_size: int,
527
+ gradient_accumulation: int,
528
+ save_every_n_epochs: int,
529
+ training_shift: float,
530
+ training_seed: int,
531
+ lora_output_dir: str,
532
+ resume_checkpoint_dir: str,
533
+ training_state: Dict,
534
+ progress=None,
535
+ ):
536
+ """Start LoRA training from preprocessed tensors.
537
+
538
+ This is a generator function that yields progress updates.
539
+ """
540
+ if not tensor_dir or not tensor_dir.strip():
541
+ yield "� Please enter a tensor directory path", "", None, training_state
542
+ return
543
+
544
+ tensor_dir = tensor_dir.strip()
545
+
546
+ if not os.path.exists(tensor_dir):
547
+ yield f"� Tensor directory not found: {tensor_dir}", "", None, training_state
548
+ return
549
+
550
+ if dit_handler is None or dit_handler.model is None:
551
+ yield "� Model not initialized. Please initialize the service first.", "", None, training_state
552
+ return
553
+
554
+ # Training preset: LoRA training must run on non-quantized DiT.
555
+ if getattr(dit_handler, "quantization", None) is not None:
556
+ gpu_config = get_global_gpu_config()
557
+ if gpu_config.gpu_memory_gb <= 0:
558
+ yield (
559
+ "WARNING: CPU-only training detected. Using best-effort training path "
560
+ "(non-quantized DiT). Performance will be sub-optimal.",
561
+ "",
562
+ None,
563
+ training_state,
564
+ )
565
+ elif gpu_config.tier in {"tier1", "tier2", "tier3", "tier4"}:
566
+ yield (
567
+ f"WARNING: Low VRAM tier detected ({gpu_config.gpu_memory_gb:.1f} GB, {gpu_config.tier}). "
568
+ "Using best-effort training path (non-quantized DiT). Performance may be sub-optimal.",
569
+ "",
570
+ None,
571
+ training_state,
572
+ )
573
+
574
+ yield "Switching model to training preset (disable quantization)...", "", None, training_state
575
+ if hasattr(dit_handler, "switch_to_training_preset"):
576
+ switch_status, switched = dit_handler.switch_to_training_preset()
577
+ if not switched:
578
+ yield f"� {switch_status}", "", None, training_state
579
+ return
580
+ yield f"� {switch_status}", "", None, training_state
581
+ else:
582
+ yield "� Training requires non-quantized DiT, and auto-switch is unavailable in this build.", "", None, training_state
583
+ return
584
+
585
+ # Check for required training dependencies
586
+ try:
587
+ from lightning.fabric import Fabric
588
+ from peft import get_peft_model, LoraConfig
589
+ except ImportError as e:
590
+ yield f"� Missing required packages: {e}\nPlease install: pip install peft lightning", "", None, training_state
591
+ return
592
+
593
+ training_state["is_training"] = True
594
+ training_state["should_stop"] = False
595
+
596
+ try:
597
+ from acestep.training.trainer import LoRATrainer
598
+ from acestep.training.configs import LoRAConfig as LoRAConfigClass, TrainingConfig
599
+
600
+ # Create configs
601
+ lora_config = LoRAConfigClass(
602
+ r=lora_rank,
603
+ alpha=lora_alpha,
604
+ dropout=lora_dropout,
605
+ )
606
+
607
+ device_attr = getattr(dit_handler, "device", "")
608
+ if hasattr(device_attr, "type"):
609
+ device_type = str(device_attr.type).lower()
610
+ else:
611
+ device_type = str(device_attr).split(":", 1)[0].lower()
612
+
613
+ # Use device-tuned dataloader defaults while preserving CUDA acceleration.
614
+ if device_type == "cuda":
615
+ num_workers = 4
616
+ pin_memory = True
617
+ prefetch_factor = 2
618
+ persistent_workers = True
619
+ pin_memory_device = "cuda"
620
+ mixed_precision = "bf16"
621
+ elif device_type == "xpu":
622
+ num_workers = 4
623
+ pin_memory = True
624
+ prefetch_factor = 2
625
+ persistent_workers = True
626
+ pin_memory_device = None
627
+ mixed_precision = "bf16"
628
+ elif device_type == "mps":
629
+ num_workers = 0
630
+ pin_memory = False
631
+ prefetch_factor = 2
632
+ persistent_workers = False
633
+ pin_memory_device = None
634
+ mixed_precision = "fp16"
635
+ else:
636
+ cpu_count = os.cpu_count() or 2
637
+ num_workers = min(4, max(1, cpu_count // 2))
638
+ pin_memory = False
639
+ prefetch_factor = 2
640
+ persistent_workers = num_workers > 0
641
+ pin_memory_device = None
642
+ mixed_precision = "fp32"
643
+
644
+ logger.info(
645
+ f"Training loader config: device={device_type}, workers={num_workers}, "
646
+ f"pin_memory={pin_memory}, pin_memory_device={pin_memory_device}, "
647
+ f"persistent_workers={persistent_workers}"
648
+ )
649
+ training_config = TrainingConfig(
650
+ shift=training_shift,
651
+ learning_rate=learning_rate,
652
+ batch_size=train_batch_size,
653
+ gradient_accumulation_steps=gradient_accumulation,
654
+ max_epochs=train_epochs,
655
+ save_every_n_epochs=save_every_n_epochs,
656
+ seed=training_seed,
657
+ output_dir=lora_output_dir,
658
+ num_workers=num_workers,
659
+ pin_memory=pin_memory,
660
+ prefetch_factor=prefetch_factor,
661
+ persistent_workers=persistent_workers,
662
+ pin_memory_device=pin_memory_device,
663
+ mixed_precision=mixed_precision,
664
+ )
665
+
666
+ import pandas as pd
667
+
668
+ # Initialize training log and loss history
669
+ log_lines = []
670
+ loss_data = pd.DataFrame({"step": [0], "loss": [0.0]})
671
+
672
+ # Start timer
673
+ start_time = time.time()
674
+
675
+ yield f"� Starting training from {tensor_dir}...", "", loss_data, training_state
676
+
677
+ # Create trainer
678
+ trainer = LoRATrainer(
679
+ dit_handler=dit_handler,
680
+ lora_config=lora_config,
681
+ training_config=training_config,
682
+ )
683
+
684
+ # Collect loss history
685
+ step_list = []
686
+ loss_list = []
687
+ training_failed = False
688
+ failure_message = ""
689
+
690
+ # Train with progress updates using preprocessed tensors
691
+ resume_from = resume_checkpoint_dir.strip() if resume_checkpoint_dir and resume_checkpoint_dir.strip() else None
692
+ for step, loss, status in trainer.train_from_preprocessed(tensor_dir, training_state, resume_from=resume_from):
693
+ status_text = str(status)
694
+ status_lower = status_text.lower()
695
+ if (
696
+ status_text.startswith("❌")
697
+ or status_text.startswith("❌")
698
+ or "training failed" in status_lower
699
+ or "error:" in status_lower
700
+ or "module not found" in status_lower
701
+ ):
702
+ training_failed = True
703
+ failure_message = status_text
704
+ # Calculate elapsed time and ETA
705
+ elapsed_seconds = time.time() - start_time
706
+ time_info = f"⏱️ Elapsed: {_format_duration(elapsed_seconds)}"
707
+
708
+ # Parse "Epoch x/y" from status to calculate ETA
709
+ match = re.search(r"Epoch\s+(\d+)/(\d+)", str(status))
710
+ if match:
711
+ current_ep = int(match.group(1))
712
+ total_ep = int(match.group(2))
713
+ if current_ep > 0:
714
+ eta_seconds = (elapsed_seconds / current_ep) * (total_ep - current_ep)
715
+ time_info += f" | ETA: ~{_format_duration(eta_seconds)}"
716
+
717
+ # Display status with time info
718
+ display_status = f"{status}\n{time_info}"
719
+
720
+ # Terminal log
721
+ log_msg = f"[{_format_duration(elapsed_seconds)}] Step {step}: {status}"
722
+ logger.info(log_msg)
723
+
724
+ # Add to UI log
725
+ log_lines.append(status)
726
+ if len(log_lines) > 15:
727
+ log_lines = log_lines[-15:]
728
+ log_text = "\n".join(log_lines)
729
+
730
+ # Track loss for plot (only valid values)
731
+ if step > 0 and loss is not None and loss == loss: # Check for NaN
732
+ step_list.append(step)
733
+ loss_list.append(float(loss))
734
+ loss_data = pd.DataFrame({"step": step_list, "loss": loss_list})
735
+
736
+ yield display_status, log_text, loss_data, training_state
737
+
738
+ if training_state.get("should_stop", False):
739
+ logger.info("⏹️ Training stopped by user")
740
+ log_lines.append("⏹️ Training stopped by user")
741
+ yield f"⏹️ Stopped ({time_info})", "\n".join(log_lines[-15:]), loss_data, training_state
742
+ break
743
+
744
+ total_time = time.time() - start_time
745
+ training_state["is_training"] = False
746
+ if training_failed:
747
+ final_msg = f"{failure_message}\nElapsed: {_format_duration(total_time)}"
748
+ logger.warning(final_msg)
749
+ log_lines.append(failure_message)
750
+ yield final_msg, "\n".join(log_lines[-15:]), loss_data, training_state
751
+ return
752
+ completion_msg = f"� Training completed! Total time: {_format_duration(total_time)}"
753
+
754
+ logger.info(completion_msg)
755
+ log_lines.append(completion_msg)
756
+
757
+ yield completion_msg, "\n".join(log_lines[-15:]), loss_data, training_state
758
+
759
+ except Exception as e:
760
+ logger.exception("Training error")
761
+ training_state["is_training"] = False
762
+ import pandas as pd
763
+ empty_df = pd.DataFrame({"step": [], "loss": []})
764
+ yield f"� Error: {str(e)}", str(e), empty_df, training_state
765
+
766
+
767
+ def stop_training(training_state: Dict) -> Tuple[str, Dict]:
768
+ """Stop the current training process.
769
+
770
+ Returns:
771
+ Tuple of (status, training_state)
772
+ """
773
+ if not training_state.get("is_training", False):
774
+ return "�️ No training in progress", training_state
775
+
776
+ training_state["should_stop"] = True
777
+ return "⏹️ Stopping training...", training_state
778
+
779
+
780
+ def export_lora(
781
+ export_path: str,
782
+ lora_output_dir: str,
783
+ ) -> str:
784
+ """Export the trained LoRA weights.
785
+
786
+ Returns:
787
+ Status message
788
+ """
789
+ if not export_path or not export_path.strip():
790
+ return "� Please enter an export path"
791
+
792
+ # Check if there's a trained model to export
793
+ final_dir = os.path.join(lora_output_dir, "final")
794
+ checkpoint_dir = os.path.join(lora_output_dir, "checkpoints")
795
+
796
+ # Prefer final, fallback to checkpoints
797
+ if os.path.exists(final_dir):
798
+ source_path = final_dir
799
+ elif os.path.exists(checkpoint_dir):
800
+ # Find the latest checkpoint
801
+ checkpoints = [d for d in os.listdir(checkpoint_dir) if d.startswith("epoch_")]
802
+ if not checkpoints:
803
+ return "� No checkpoints found"
804
+
805
+ checkpoints.sort(key=lambda x: int(x.split("_")[1]))
806
+ latest = checkpoints[-1]
807
+ source_path = os.path.join(checkpoint_dir, latest)
808
+ else:
809
+ return f"� No trained model found in {lora_output_dir}"
810
+
811
+ try:
812
+ import shutil
813
+
814
+ export_path = export_path.strip()
815
+ os.makedirs(os.path.dirname(export_path) if os.path.dirname(export_path) else ".", exist_ok=True)
816
+
817
+ if os.path.exists(export_path):
818
+ shutil.rmtree(export_path)
819
+
820
+ shutil.copytree(source_path, export_path)
821
+
822
+ return f"� LoRA exported to {export_path}"
823
+
824
+ except Exception as e:
825
+ logger.exception("Export error")
826
+ return f"� Export failed: {str(e)}"
827
+
828
+
829
+
acestep/gradio_ui/i18n.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Internationalization (i18n) module for Gradio UI
3
+ Supports multiple languages with easy translation management
4
+ """
5
+ import os
6
+ import json
7
+ from typing import Dict, Optional
8
+
9
+
10
+ class I18n:
11
+ """Internationalization handler"""
12
+
13
+ def __init__(self, default_language: str = "en"):
14
+ """
15
+ Initialize i18n handler
16
+
17
+ Args:
18
+ default_language: Default language code (en, zh, ja, etc.)
19
+ """
20
+ self.current_language = default_language
21
+ self.translations: Dict[str, Dict[str, str]] = {}
22
+ self._load_all_translations()
23
+
24
+ def _load_all_translations(self):
25
+ """Load all translation files from i18n directory"""
26
+ current_file = os.path.abspath(__file__)
27
+ module_dir = os.path.dirname(current_file)
28
+ i18n_dir = os.path.join(module_dir, "i18n")
29
+
30
+ if not os.path.exists(i18n_dir):
31
+ # Create i18n directory if it doesn't exist
32
+ os.makedirs(i18n_dir)
33
+ return
34
+
35
+ # Load all JSON files in i18n directory
36
+ for filename in os.listdir(i18n_dir):
37
+ if filename.endswith(".json"):
38
+ lang_code = filename[:-5] # Remove .json extension
39
+ filepath = os.path.join(i18n_dir, filename)
40
+ try:
41
+ with open(filepath, 'r', encoding='utf-8') as f:
42
+ self.translations[lang_code] = json.load(f)
43
+ except Exception as e:
44
+ print(f"Error loading translation file {filename}: {e}")
45
+
46
+ def set_language(self, language: str):
47
+ """Set current language"""
48
+ if language in self.translations:
49
+ self.current_language = language
50
+ else:
51
+ print(f"Warning: Language '{language}' not found, using default")
52
+
53
+ def t(self, key: str, **kwargs) -> str:
54
+ """
55
+ Translate a key to current language
56
+
57
+ Args:
58
+ key: Translation key (dot-separated for nested keys)
59
+ **kwargs: Optional format parameters
60
+
61
+ Returns:
62
+ Translated string
63
+ """
64
+ # Get translation from current language
65
+ translation = self._get_nested_value(
66
+ self.translations.get(self.current_language, {}),
67
+ key
68
+ )
69
+
70
+ # Fallback to English if not found
71
+ if translation is None:
72
+ translation = self._get_nested_value(
73
+ self.translations.get('en', {}),
74
+ key
75
+ )
76
+
77
+ # Final fallback to key itself
78
+ if translation is None:
79
+ translation = key
80
+
81
+ # Apply formatting if kwargs provided
82
+ if kwargs:
83
+ try:
84
+ translation = translation.format(**kwargs)
85
+ except KeyError:
86
+ pass
87
+
88
+ return translation
89
+
90
+ def _get_nested_value(self, data: dict, key: str) -> Optional[str]:
91
+ """
92
+ Get nested dictionary value using dot notation
93
+
94
+ Args:
95
+ data: Dictionary to search
96
+ key: Dot-separated key (e.g., "section.subsection.key")
97
+
98
+ Returns:
99
+ Value if found, None otherwise
100
+ """
101
+ keys = key.split('.')
102
+ current = data
103
+
104
+ for k in keys:
105
+ if isinstance(current, dict) and k in current:
106
+ current = current[k]
107
+ else:
108
+ return None
109
+
110
+ return current if isinstance(current, str) else None
111
+
112
+ def get_available_languages(self) -> list:
113
+ """Get list of available language codes"""
114
+ return list(self.translations.keys())
115
+
116
+
117
+ # Global i18n instance
118
+ _i18n_instance: Optional[I18n] = None
119
+
120
+
121
+ def get_i18n(language: Optional[str] = None) -> I18n:
122
+ """
123
+ Get global i18n instance
124
+
125
+ Args:
126
+ language: Optional language to set
127
+
128
+ Returns:
129
+ I18n instance
130
+ """
131
+ global _i18n_instance
132
+
133
+ if _i18n_instance is None:
134
+ _i18n_instance = I18n(default_language=language or "en")
135
+ elif language is not None:
136
+ _i18n_instance.set_language(language)
137
+
138
+ return _i18n_instance
139
+
140
+
141
+ def t(key: str, **kwargs) -> str:
142
+ """
143
+ Convenience function for translation
144
+
145
+ Args:
146
+ key: Translation key
147
+ **kwargs: Optional format parameters
148
+
149
+ Returns:
150
+ Translated string
151
+ """
152
+ return get_i18n().t(key, **kwargs)
acestep/gradio_ui/i18n/en.json ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "app": {
3
+ "title": "🎛️ ACE-Step V1.5 Playground💡",
4
+ "subtitle": "Pushing the Boundaries of Open-Source Music Generation"
5
+ },
6
+ "dataset": {
7
+ "title": "📊 Dataset Explorer",
8
+ "dataset_label": "Dataset",
9
+ "dataset_info": "Choose dataset to explore",
10
+ "import_btn": "📥 Import Dataset",
11
+ "search_type_label": "Search Type",
12
+ "search_type_info": "How to find items",
13
+ "search_value_label": "Search Value",
14
+ "search_value_placeholder": "Enter keys or index (leave empty for random)",
15
+ "search_value_info": "Keys: exact match, Index: 0 to dataset size-1",
16
+ "instruction_label": "📝 Instruction",
17
+ "instruction_placeholder": "No instruction available",
18
+ "metadata_title": "📋 Item Metadata (JSON)",
19
+ "metadata_label": "Complete Item Information",
20
+ "source_audio": "Source Audio",
21
+ "target_audio": "Target Audio",
22
+ "reference_audio": "Reference Audio",
23
+ "get_item_btn": "🔍 Get Item",
24
+ "use_src_checkbox": "Use Source Audio from Dataset",
25
+ "use_src_info": "Check to use the source audio from dataset",
26
+ "data_status_label": "📊 Data Status",
27
+ "data_status_default": "❌ No dataset imported",
28
+ "autofill_btn": "📋 Auto-fill Generation Form"
29
+ },
30
+ "service": {
31
+ "title": "🔧 Service Configuration",
32
+ "checkpoint_label": "Checkpoint File",
33
+ "checkpoint_info": "Select a trained model checkpoint file (full path or filename)",
34
+ "refresh_btn": "🔄 Refresh",
35
+ "model_path_label": "Main Model Path",
36
+ "model_path_info": "Select the model configuration directory (auto-scanned from checkpoints)",
37
+ "device_label": "Device",
38
+ "device_info": "Processing device (auto-detect recommended)",
39
+ "lm_model_path_label": "5Hz LM Model Path",
40
+ "lm_model_path_info": "Select the 5Hz LM model checkpoint (auto-scanned from checkpoints)",
41
+ "backend_label": "5Hz LM Backend",
42
+ "backend_info": "Select backend for 5Hz LM: vllm (faster) or pt (PyTorch, more compatible)",
43
+ "init_llm_label": "Initialize 5Hz LM",
44
+ "init_llm_info": "Check to initialize 5Hz LM during service initialization",
45
+ "flash_attention_label": "Use Flash Attention",
46
+ "flash_attention_info_enabled": "Enable flash attention for faster inference (requires flash_attn package)",
47
+ "flash_attention_info_disabled": "Flash attention not available (flash_attn package not installed)",
48
+ "offload_cpu_label": "Offload to CPU",
49
+ "offload_cpu_info": "Offload models to CPU when not in use to save GPU memory",
50
+ "offload_dit_cpu_label": "Offload DiT to CPU",
51
+ "offload_dit_cpu_info": "Offload DiT to CPU (needs Offload to CPU)",
52
+ "compile_model_label": "Compile Model",
53
+ "compile_model_info": "Use torch.compile to optimize model (required for quantization)",
54
+ "quantization_label": "INT8 Quantization",
55
+ "quantization_info": "Enable INT8 weight-only quantization to reduce VRAM usage (requires Compile Model)",
56
+ "init_btn": "Initialize Service",
57
+ "status_label": "Status",
58
+ "language_label": "UI Language",
59
+ "language_info": "Select interface language"
60
+ },
61
+ "generation": {
62
+ "required_inputs": "📝 Required Inputs",
63
+ "task_type_label": "Task Type",
64
+ "task_type_info": "Select the task type for generation",
65
+ "instruction_label": "Instruction",
66
+ "instruction_info": "Instruction is automatically generated based on task type",
67
+ "load_btn": "Load",
68
+ "track_name_label": "Track Name",
69
+ "track_name_info": "Select track name for lego/extract tasks",
70
+ "track_classes_label": "Track Names",
71
+ "track_classes_info": "Select multiple track classes for complete task",
72
+ "audio_uploads": "🎵 Audio Uploads",
73
+ "reference_audio": "Reference Audio (optional)",
74
+ "source_audio": "Source Audio (optional)",
75
+ "convert_codes_btn": "Convert to Codes",
76
+ "lm_codes_hints": "🎼 LM Codes Hints",
77
+ "lm_codes_label": "LM Codes Hints",
78
+ "lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
79
+ "lm_codes_info": "Paste LM codes hints for text2music generation",
80
+ "lm_codes_sample": "LM Codes Hints (Sample {n})",
81
+ "lm_codes_sample_info": "Codes for sample {n}",
82
+ "transcribe_btn": "Transcribe",
83
+ "repainting_controls": "🎨 Repainting Controls (seconds)",
84
+ "repainting_start": "Repainting Start",
85
+ "repainting_end": "Repainting End",
86
+ "mode_label": "Generation Mode",
87
+ "mode_info": "Simple: describe music in natural language. Custom: full control over caption and lyrics.",
88
+ "mode_simple": "Simple",
89
+ "mode_custom": "Custom",
90
+ "simple_query_label": "Song Description",
91
+ "simple_query_placeholder": "Describe the music you want to create, e.g., 'a soft Bengali love song for a quiet evening'. Leave empty for a random sample.",
92
+ "simple_query_info": "Enter a natural language description of the music you want to generate",
93
+ "simple_vocal_language_label": "Vocal Language (optional)",
94
+ "simple_vocal_language_info": "Select preferred language(s) for lyrics. Use 'unknown' for any language.",
95
+ "create_sample_btn": "Create Sample",
96
+ "caption_title": "📝 Music Caption",
97
+ "caption_label": "Music Caption (optional)",
98
+ "caption_placeholder": "A peaceful acoustic guitar melody with soft vocals...",
99
+ "caption_info": "Describe the style, genre, instruments, and mood",
100
+ "lyrics_title": "📝 Lyrics",
101
+ "lyrics_label": "Lyrics (optional)",
102
+ "lyrics_placeholder": "[Verse 1]\\nUnder the starry night\\nI feel so alive...",
103
+ "lyrics_info": "Song lyrics with structure",
104
+ "instrumental_label": "Instrumental",
105
+ "format_btn": "Format",
106
+ "optional_params": "⚙️ Optional Parameters",
107
+ "vocal_language_label": "Vocal Language (optional)",
108
+ "vocal_language_info": "use `unknown` for inst",
109
+ "bpm_label": "BPM (optional)",
110
+ "bpm_info": "leave empty for N/A",
111
+ "keyscale_label": "KeyScale (optional)",
112
+ "keyscale_placeholder": "Leave empty for N/A",
113
+ "keyscale_info": "A-G, #/♭, major/minor",
114
+ "timesig_label": "Time Signature (optional)",
115
+ "timesig_info": "2/4, 3/4, 4/4...",
116
+ "duration_label": "Audio Duration (seconds)",
117
+ "duration_info": "Use -1 for random",
118
+ "batch_size_label": "Batch Size",
119
+ "batch_size_info": "Number of audio to generate (max 8)",
120
+ "advanced_settings": "🔧 Advanced Settings",
121
+ "inference_steps_label": "DiT Inference Steps",
122
+ "inference_steps_info": "Turbo: max 8, Base: max 200",
123
+ "guidance_scale_label": "DiT Guidance Scale (Only support for base model)",
124
+ "guidance_scale_info": "Higher values follow text more closely",
125
+ "seed_label": "Seed",
126
+ "seed_info": "Use comma-separated values for batches",
127
+ "random_seed_label": "Random Seed",
128
+ "random_seed_info": "Enable to auto-generate seeds",
129
+ "audio_format_label": "Audio Format",
130
+ "audio_format_info": "Audio format for saved files",
131
+ "use_adg_label": "Use ADG",
132
+ "use_adg_info": "Enable Angle Domain Guidance",
133
+ "shift_label": "Shift",
134
+ "shift_info": "Timestep shift factor for base models (range 1.0~5.0, default 3.0). Not effective for turbo models.",
135
+ "infer_method_label": "Inference Method",
136
+ "infer_method_info": "Diffusion inference method. ODE (Euler) is faster, SDE (stochastic) may produce different results.",
137
+ "custom_timesteps_label": "Custom Timesteps",
138
+ "custom_timesteps_info": "Optional: comma-separated values from 1.0 to 0.0 (e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference steps and shift.",
139
+ "cfg_interval_start": "CFG Interval Start",
140
+ "cfg_interval_end": "CFG Interval End",
141
+ "lm_params_title": "🤖 LM Generation Parameters",
142
+ "lm_temperature_label": "LM Temperature",
143
+ "lm_temperature_info": "5Hz LM temperature (higher = more random)",
144
+ "lm_cfg_scale_label": "LM CFG Scale",
145
+ "lm_cfg_scale_info": "5Hz LM CFG (1.0 = no CFG)",
146
+ "lm_top_k_label": "LM Top-K",
147
+ "lm_top_k_info": "Top-K (0 = disabled)",
148
+ "lm_top_p_label": "LM Top-P",
149
+ "lm_top_p_info": "Top-P (1.0 = disabled)",
150
+ "lm_negative_prompt_label": "LM Negative Prompt",
151
+ "lm_negative_prompt_placeholder": "Enter negative prompt for CFG (default: NO USER INPUT)",
152
+ "lm_negative_prompt_info": "Negative prompt (use when LM CFG Scale > 1.0)",
153
+ "cot_metas_label": "CoT Metas",
154
+ "cot_metas_info": "Use LM to generate CoT metadata (uncheck to skip LM CoT generation)",
155
+ "cot_language_label": "CoT Language",
156
+ "cot_language_info": "Generate language in CoT (chain-of-thought)",
157
+ "constrained_debug_label": "Constrained Decoding Debug",
158
+ "constrained_debug_info": "Enable debug logging for constrained decoding (check to see detailed logs)",
159
+ "auto_score_label": "Auto Score",
160
+ "auto_score_info": "Automatically calculate quality scores for all generated audios",
161
+ "auto_lrc_label": "Auto LRC",
162
+ "auto_lrc_info": "Automatically generate LRC lyrics timestamps for all generated audios",
163
+ "lm_batch_chunk_label": "LM Batch Chunk Size",
164
+ "lm_batch_chunk_info": "Max items per LM batch chunk (default: 8, limited by GPU memory)",
165
+ "codes_strength_label": "LM Codes Strength",
166
+ "codes_strength_info": "Control how many denoising steps use LM-generated codes",
167
+ "similarity_denoise_label": "Similarity / Denoise",
168
+ "similarity_denoise_info": "Controls how closely the output follows the reference audio. Higher values preserve more structure.",
169
+ "cover_strength_label": "Audio Cover Strength",
170
+ "cover_strength_info": "Control how many denoising steps use cover mode",
171
+ "score_sensitivity_label": "Quality Score Sensitivity",
172
+ "score_sensitivity_info": "Lower = more sensitive (default: 1.0). Adjusts how PMI maps to [0,1]",
173
+ "think_label": "Think",
174
+ "parallel_thinking_label": "ParallelThinking",
175
+ "generate_btn": "🎵 Generate Music",
176
+ "autogen_label": "AutoGen",
177
+ "caption_rewrite_label": "CaptionRewrite"
178
+ },
179
+ "results": {
180
+ "title": "🎵 Results",
181
+ "generated_music": "🎵 Generated Music (Sample {n})",
182
+ "send_to_src_btn": "🔗 Send To Src Audio",
183
+ "save_btn": "💾 Save",
184
+ "score_btn": "📊 Score",
185
+ "lrc_btn": "🎵 LRC",
186
+ "quality_score_label": "Quality Score (Sample {n})",
187
+ "quality_score_placeholder": "Click 'Score' to calculate perplexity-based quality score",
188
+ "codes_label": "LM Codes (Sample {n})",
189
+ "lrc_label": "Lyrics Timestamps (Sample {n})",
190
+ "lrc_placeholder": "Click 'LRC' to generate timestamps",
191
+ "details_accordion": "📊 Score & LRC & LM Codes",
192
+ "generation_status": "Generation Status",
193
+ "current_batch": "Current Batch",
194
+ "batch_indicator": "Batch {current} / {total}",
195
+ "next_batch_status": "Next Batch Status",
196
+ "prev_btn": "◀ Previous",
197
+ "next_btn": "Next ▶",
198
+ "restore_params_btn": "↙️ Apply These Settings to UI (Restore Batch Parameters)",
199
+ "batch_results_title": "👇 Click here to view batch results & generation details",
200
+ "all_files_label": "📁 All Generated Files (Download)",
201
+ "generation_details": "Generation Details"
202
+ },
203
+ "messages": {
204
+ "no_audio_to_save": "❌ No audio to save",
205
+ "save_success": "✅ Saved audio and metadata to {filename}",
206
+ "save_failed": "❌ Failed to save: {error}",
207
+ "no_file_selected": "⚠️ No file selected",
208
+ "params_loaded": "✅ Parameters loaded from {filename}",
209
+ "invalid_json": "❌ Invalid JSON file: {error}",
210
+ "load_error": "❌ Error loading file: {error}",
211
+ "example_loaded": "📁 Loaded example from {filename}",
212
+ "example_failed": "Failed to parse JSON file {filename}: {error}",
213
+ "example_error": "Error loading example: {error}",
214
+ "lm_generated": "🤖 Generated example using LM",
215
+ "lm_fallback": "Failed to generate example using LM, falling back to examples directory",
216
+ "lm_not_initialized": "❌ 5Hz LM not initialized. Please initialize it first.",
217
+ "autogen_enabled": "🔄 AutoGen enabled - next batch will generate after this",
218
+ "batch_ready": "✅ Batch {n} ready! Click 'Next' to view.",
219
+ "batch_generating": "🔄 Starting background generation for Batch {n}...",
220
+ "batch_failed": "❌ Background generation failed: {error}",
221
+ "viewing_batch": "✅ Viewing Batch {n}",
222
+ "at_first_batch": "Already at first batch",
223
+ "at_last_batch": "No next batch available",
224
+ "batch_not_found": "Batch {n} not found in queue",
225
+ "no_batch_data": "No batch data found to restore.",
226
+ "params_restored": "✅ UI Parameters restored from Batch {n}",
227
+ "scoring_failed": "❌ Error: Batch data not found",
228
+ "no_codes": "❌ No audio codes available. Please generate music first.",
229
+ "score_failed": "❌ Scoring failed: {error}",
230
+ "score_error": "❌ Error calculating score: {error}",
231
+ "lrc_no_batch_data": "❌ No batch data found. Please generate music first.",
232
+ "lrc_no_extra_outputs": "❌ No extra outputs found. Condition tensors not available.",
233
+ "lrc_missing_tensors": "❌ Missing required tensors for LRC generation.",
234
+ "lrc_sample_not_exist": "❌ Sample does not exist in current batch.",
235
+ "lrc_empty_result": "⚠️ LRC generation produced empty result.",
236
+ "empty_query": "⚠️ Please enter a music description.",
237
+ "sample_creation_failed": "❌ Failed to create sample. Please try again.",
238
+ "sample_created": "✅ Sample created! Review the caption and lyrics, then click Generate Music.",
239
+ "simple_examples_not_found": "⚠️ Simple mode examples directory not found.",
240
+ "simple_examples_empty": "⚠️ No example files found in simple mode examples.",
241
+ "simple_example_loaded": "🎲 Loaded random example from {filename}",
242
+ "format_success": "✅ Caption and lyrics formatted successfully",
243
+ "format_failed": "❌ Format failed: {error}",
244
+ "skipping_metas_cot": "⚡ Skipping Phase 1 metas COT (sample already formatted)",
245
+ "invalid_timesteps_format": "⚠️ Invalid timesteps format. Using default schedule.",
246
+ "timesteps_out_of_range": "⚠️ Timesteps must be in range [0, 1]. Using default schedule.",
247
+ "timesteps_count_mismatch": "⚠️ Timesteps count ({actual}) differs from inference_steps ({expected}). Using timesteps count."
248
+ },
249
+ "training": {
250
+ "tab_title": "🎓 LoRA Training",
251
+ "tab_dataset_builder": "📁 Dataset Builder",
252
+ "tab_train_lora": "🚀 Train LoRA",
253
+ "quick_start_title": "🚀 Quick Start",
254
+ "load_dataset_label": "Dataset JSON Path",
255
+ "load_dataset_info": "Load a previously saved dataset",
256
+ "load_btn": "📂 Load",
257
+ "load_status": "Load Status",
258
+ "scan_label": "Audio Directory Path",
259
+ "scan_info": "Scan for audio files (wav, mp3, flac, ogg, opus)",
260
+ "scan_btn": "🔍 Scan",
261
+ "scan_status": "Scan Status",
262
+ "found_audio_files": "Found Audio Files",
263
+ "dataset_name": "Dataset Name",
264
+ "dataset_name_placeholder": "Enter dataset name",
265
+ "dataset_settings_header": "Dataset Settings",
266
+ "tag_prepend": "Prepend (tag, caption)",
267
+ "tag_append": "Append (caption, tag)",
268
+ "tag_replace": "Replace caption",
269
+ "step2_title": "Step 2: Auto-Label with AI",
270
+ "step3_title": "Step 3: Preview & Edit",
271
+ "step4_title": "Step 4: Save Dataset",
272
+ "step5_title": "Step 5: Preprocess to Tensors",
273
+ "all_instrumental": "All Instrumental",
274
+ "all_instrumental_info": "Check if all tracks are instrumental (no vocals)",
275
+ "custom_tag": "Custom Activation Tag",
276
+ "custom_tag_info": "Unique tag to activate this LoRA's style",
277
+ "tag_position": "Tag Position",
278
+ "tag_position_info": "Where to place the custom tag in the caption",
279
+ "genre_ratio": "Genre Ratio (%)",
280
+ "genre_ratio_info": "0%=all Caption, 100%=all Genre. Per-sample override takes priority.",
281
+ "skip_metas": "Skip BPM/Key/Time Signature",
282
+ "skip_metas_info": "Skip BPM/Key/Time Signature generation. Caption and Genre are still generated by LLM.",
283
+ "only_unlabeled": "Only Unlabeled",
284
+ "only_unlabeled_info": "Only label samples without caption (useful for resuming failed labeling)",
285
+ "auto_label_btn": "🏷️ Auto-Label All",
286
+ "label_progress": "Labeling Progress",
287
+ "select_sample": "Select Sample #",
288
+ "select_sample_info": "Choose a sample to preview and edit",
289
+ "audio_preview": "Audio Preview",
290
+ "filename": "Filename",
291
+ "caption": "Caption",
292
+ "genre": "Genre",
293
+ "prompt_override_label": "Prompt Override (this sample)",
294
+ "prompt_override_info": "Override global ratio for this sample",
295
+ "lyrics_editable_label": "Lyrics (editable, used for training)",
296
+ "raw_lyrics_label": "Raw Lyrics (from .txt file)",
297
+ "no_lyrics_placeholder": "(no .txt lyrics file)",
298
+ "bpm": "BPM",
299
+ "key_label": "Key",
300
+ "key_placeholder": "C Major",
301
+ "time_sig": "Time Signature",
302
+ "duration_s": "Duration (s)",
303
+ "language": "Language",
304
+ "instrumental": "Instrumental",
305
+ "save_changes_btn": "💾 Save Changes",
306
+ "edit_status": "Edit Status",
307
+ "save_path": "Save Path",
308
+ "save_path_info": "Path where the dataset JSON will be saved",
309
+ "save_dataset_btn": "💾 Save Dataset",
310
+ "save_status": "Save Status",
311
+ "load_existing_label": "Load Existing Dataset (Optional)",
312
+ "load_existing_info": "Path to a previously saved dataset JSON file",
313
+ "load_dataset_btn": "📂 Load Dataset",
314
+ "tensor_output_dir": "Tensor Output Directory",
315
+ "tensor_output_info": "Directory to save preprocessed tensor files",
316
+ "preprocess_btn": "⚡ Preprocess",
317
+ "preprocess_progress": "Preprocessing Progress",
318
+ "preprocessed_tensors_dir": "Preprocessed Tensors Directory",
319
+ "preprocessed_tensors_info": "Directory containing preprocessed .pt tensor files",
320
+ "train_section_tensors": "Preprocessed Dataset Selection",
321
+ "train_section_lora": "LoRA Settings",
322
+ "train_section_params": "Training Parameters",
323
+ "dataset_info": "Dataset Info",
324
+ "lora_rank": "LoRA Rank (r)",
325
+ "lora_rank_info": "Higher = more capacity, more memory",
326
+ "lora_alpha": "LoRA Alpha",
327
+ "lora_alpha_info": "Scaling factor (typically 2x rank)",
328
+ "lora_dropout": "LoRA Dropout",
329
+ "learning_rate": "Learning Rate",
330
+ "learning_rate_info": "Start with 3e-4, adjust if needed",
331
+ "max_epochs": "Max Epochs",
332
+ "batch_size": "Batch Size",
333
+ "batch_size_info": "Increase if you have enough VRAM",
334
+ "gradient_accumulation": "Gradient Accumulation",
335
+ "gradient_accumulation_info": "Effective batch = batch_size × accumulation",
336
+ "save_every_n_epochs": "Save Every N Epochs",
337
+ "shift": "Shift",
338
+ "shift_info": "Timestep shift for turbo model",
339
+ "seed": "Seed",
340
+ "output_dir": "Output Directory",
341
+ "output_dir_info": "Directory to save trained LoRA weights",
342
+ "start_training_btn": "🚀 Start Training",
343
+ "stop_training_btn": "⏹️ Stop Training",
344
+ "training_progress": "Training Progress",
345
+ "training_log": "Training Log",
346
+ "training_loss_title": "Training Loss",
347
+ "step": "Step",
348
+ "loss": "Loss",
349
+ "export_header": "Export LoRA",
350
+ "export_path": "Export Path",
351
+ "export_lora_btn": "📦 Export LoRA",
352
+ "export_status": "Export Status"
353
+ }
354
+ }
acestep/gradio_ui/i18n/he.json ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "app": {
3
+ "title": "🎛️ סביבת העבודה ACE-Step V1.5 Playground💡",
4
+ "subtitle": "פורצים את גבולות יצירת המוזיקה בקוד פתוח"
5
+ },
6
+ "dataset": {
7
+ "title": "📊 סייר מערכי נתונים (Dataset Explorer)",
8
+ "dataset_label": "מערך נתונים",
9
+ "dataset_info": "בחר מערך נתונים לחקירה",
10
+ "import_btn": "📥 ייבוא מערך נתונים",
11
+ "search_type_label": "סוג חיפוש",
12
+ "search_type_info": "כיצד למצוא פריטים",
13
+ "search_value_label": "ערך חיפוש",
14
+ "search_value_placeholder": "הזן מפתחות או אינדקס (השאר ריק לבחירה אקראית)",
15
+ "search_value_info": "מפתחות: התאמה מדויקת, אינדקס: 0 עד גודל המערך פחות 1",
16
+ "instruction_label": "📝 הנחיה (Instruction)",
17
+ "instruction_placeholder": "אין הנחיה זמינה",
18
+ "metadata_title": "📋 מטא-דאטה של הפריט (JSON)",
19
+ "metadata_label": "מידע מלא על הפריט",
20
+ "source_audio": "אודיו מקור",
21
+ "target_audio": "אודיו יעד",
22
+ "reference_audio": "אודיו ייחוס",
23
+ "get_item_btn": "🔍 קבל פריט",
24
+ "use_src_checkbox": "השתמש באודיו מקור ממערך הנתונים",
25
+ "use_src_info": "סמן כדי להשתמש באודיו המקור מתוך מערך הנתונים",
26
+ "data_status_label": "📊 מצב נתונים",
27
+ "data_status_default": "❌ לא יובא מערך נתונים",
28
+ "autofill_btn": "📋 מילוי אוטומטי של טופס היצירה"
29
+ },
30
+ "service": {
31
+ "title": "🔧 הגדרות שירות",
32
+ "checkpoint_label": "קובץ נקודת ביקורת (Checkpoint)",
33
+ "checkpoint_info": "בחר קובץ נקודת ביקורת של מודל מאומן (נתיב מלא או שם קובץ)",
34
+ "refresh_btn": "🔄 רענון",
35
+ "model_path_label": "נתיב מודל ראשי",
36
+ "model_path_info": "בחר את ספריית הגדרות המודל (נסרק אוטומטית מנקודות הביקורת)",
37
+ "device_label": "מכשיר (Device)",
38
+ "device_info": "מכשיר עיבוד (מומלץ זיהוי אוטומטי)",
39
+ "lm_model_path_label": "נתיב מודל 5Hz LM",
40
+ "lm_model_path_info": "בחר את קובץ נקודת הביקורת של מודל ה-5Hz LM",
41
+ "backend_label": "מנוע (Backend) 5Hz LM",
42
+ "backend_info": "בחר מנוע עבור 5Hz LM: vllm (מהיר יותר) או pt (PyTorch, תואם יותר)",
43
+ "init_llm_label": "אתחול 5Hz LM",
44
+ "init_llm_info": "סמן כדי לאתחל את ה-5Hz LM במהלך אתחול השירות",
45
+ "flash_attention_label": "השתמש ב-Flash Attention",
46
+ "flash_attention_info_enabled": "הפעל Flash Attention להסקה מהירה יותר (דורש חבילת flash_attn)",
47
+ "flash_attention_info_disabled": "Flash Attention אינו זמין (חבילת flash_attn לא מותקנת)",
48
+ "offload_cpu_label": "העברה ל-CPU (Offload)",
49
+ "offload_cpu_info": "העבר מודלים ל-CPU כשאינם בשימוש כדי לחסוך בזיכרון גרפי (VRAM)",
50
+ "offload_dit_cpu_label": "העברת DiT ל-CPU",
51
+ "offload_dit_cpu_info": "העבר DiT ל-CPU (דורש 'העברה ל-CPU')",
52
+ "compile_model_label": "הידור מודל (Compile)",
53
+ "compile_model_info": "השתמש ב-torch.compile לאופטימיזציה של המודל (נדרש עבור קוונטיזציה)",
54
+ "quantization_label": "קוונטיזציה INT8",
55
+ "quantization_info": "הפעל קוונטיזציה של משקולות בלבד (INT8) להפחתת שימוש ב-VRAM (דורש הידור מודל)",
56
+ "init_btn": "אתחול שירות",
57
+ "status_label": "מצב",
58
+ "language_label": "שפת ממשק",
59
+ "language_info": "בחר את שפת הממשק"
60
+ },
61
+ "generation": {
62
+ "required_inputs": "📝 קלטים נדרשים",
63
+ "task_type_label": "סוג משימה",
64
+ "task_type_info": "בחר את סוג המשימה ליצירה",
65
+ "instruction_label": "הנחיה",
66
+ "instruction_info": "ההנחיה נוצרת אוטומטית בהתאם לסוג המשימה",
67
+ "load_btn": "טעינה",
68
+ "track_name_label": "שם רצועה",
69
+ "track_name_info": "בחר שם רצועה עבור משימות lego/extract",
70
+ "track_classes_label": "שמות רצועות",
71
+ "track_classes_info": "בחר מספר מחלקות רצועה עבור משימה מלאה",
72
+ "audio_uploads": "🎵 העלאות אודיו",
73
+ "reference_audio": "אודיו ייחוס (אופציונלי)",
74
+ "source_audio": "אודיו מקור (אופציונלי)",
75
+ "convert_codes_btn": "המר לקודים",
76
+ "lm_codes_hints": "🎼 רמזי קודי LM",
77
+ "lm_codes_label": "רמזי קודי LM",
78
+ "lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
79
+ "lm_codes_info": "הדבק רמזי קודי LM עבור יצירת טקסט למוזיקה (text2music)",
80
+ "lm_codes_sample": "רמזי קודי LM (דגימה {n})",
81
+ "lm_codes_sample_info": "קודים עבור דגימה {n}",
82
+ "transcribe_btn": "תמלול",
83
+ "repainting_controls": "🎨 בקרת צביעה מחדש (בשניות)",
84
+ "repainting_start": "תחילת צביעה מחדש",
85
+ "repainting_end": "סיום צביעה מחדש",
86
+ "mode_label": "מצב יצירה",
87
+ "mode_info": "פשוט: תאר מוזיקה בשפה טבעית. מותאם אישית: שליטה מלאה בתיאור ומילים.",
88
+ "mode_simple": "פשוט",
89
+ "mode_custom": "מותאם אישית",
90
+ "simple_query_label": "תיאור השיר",
91
+ "simple_query_placeholder": "תאר את המוזיקה שברצונך ליצור, למשל: 'שיר אהבה אקוסטי שקט לערב רגוע'. השאר ריק לדגימה אקראית.",
92
+ "simple_query_info": "הזן תיאור בשפה טבעית של המוזיקה שברצונך ליצור",
93
+ "simple_vocal_language_label": "שפת שירה (אופציונלי)",
94
+ "simple_vocal_language_info": "בחר שפות מועדפות למילים. השתמש ב-'unknown' לכל שפה.",
95
+ "create_sample_btn": "צור דגימה",
96
+ "caption_title": "📝 תיאור מוזיקלי (Caption)",
97
+ "caption_label": "תיאור מוזיקלי (אופציונלי)",
98
+ "caption_placeholder": "מנגינת גיטרה אקוסטית שלווה עם שירה רכה...",
99
+ "caption_info": "תאר את הסגנון, הז'אנר, הכלים והאווירה",
100
+ "lyrics_title": "📝 מילים",
101
+ "lyrics_label": "מילים (אופציונלי)",
102
+ "lyrics_placeholder": "[בית 1]\\nתחת שמי הלילה...\\nאני מרגיש חי...",
103
+ "lyrics_info": "מילות השיר עם מבנה",
104
+ "instrumental_label": "אינסטרומנטלי (ללא שירה)",
105
+ "format_btn": "פרמוט",
106
+ "optional_params": "⚙️ פרמטרים אופציונליים",
107
+ "vocal_language_label": "שפת שירה (אופציונלי)",
108
+ "vocal_language_info": "השתמש ב-`unknown` לקטעים כליים",
109
+ "bpm_label": "קצב (BPM) (אופציונלי)",
110
+ "bpm_info": "השאר ריק אם לא ידוע",
111
+ "keyscale_label": "סולם (KeyScale) (אופציונלי)",
112
+ "keyscale_placeholder": "השאר ריק אם לא ידוע",
113
+ "keyscale_info": "A-G, #/♭, מז'ור/מינור",
114
+ "timesig_label": "משקל מוזיקלי (אופציונלי)",
115
+ "timesig_info": "2/4, 3/4, 4/4...",
116
+ "duration_label": "אורך אודיו (שניות)",
117
+ "duration_info": "השתמש ב-1- לאקראי",
118
+ "batch_size_label": "גודל מנה (Batch Size)",
119
+ "batch_size_info": "מספר קטעי אודיו ליצירה (מקסימום 8)",
120
+ "advanced_settings": "🔧 הגדרות מתקדמות",
121
+ "inference_steps_label": "צעדי הסקה של DiT",
122
+ "inference_steps_info": "Turbo: מקסימום 8, Base: מקסימום 200",
123
+ "guidance_scale_label": "קנה מידה להנחיה (רק למודל base)",
124
+ "guidance_scale_info": "ערכים גבוהים יותר נצמדים יותר לטקסט",
125
+ "seed_label": "גרעין (Seed)",
126
+ "seed_info": "השתמש בערכים מופרדים בפסיקים עבור מנות",
127
+ "random_seed_label": "גרעין אקראי",
128
+ "random_seed_info": "אפשר ליצירה אוטומטית של גרעינים",
129
+ "audio_format_label": "פורמט אודיו",
130
+ "audio_format_info": "פורמט האודיו עבור הקבצים שיישמרו",
131
+ "use_adg_label": "השתמש ב-ADG",
132
+ "use_adg_info": "הפעל Angle Domain Guidance",
133
+ "shift_label": "Shift",
134
+ "shift_info": "פקטור הסטת צעדי זמן למודלי base (טווח 1.0~5.0, ברירת מחדל 3.0). לא משפיע על מודלי turbo.",
135
+ "infer_method_label": "שיטת הסקה",
136
+ "infer_method_info": "שיטת הסקת הדיפוזיה. ODE (Euler) מהירה יותר, SDE (stochastic) עשויה להפיק תוצאות שונות.",
137
+ "custom_timesteps_label": "צעדי זמן מותאמים אישית",
138
+ "custom_timesteps_info": "אופציונלי: ערכים מופרדים בפסיקים מ-1.0 עד 0.0. דורס את צעדי ההסקה וה-shift.",
139
+ "cfg_interval_start": "תחילת מרווח CFG",
140
+ "cfg_interval_end": "סיום מרווח CFG",
141
+ "lm_params_title": "🤖 פרמטרי יצירת LM",
142
+ "lm_temperature_label": "טמפרטורת LM",
143
+ "lm_temperature_info": "טמפרטורת 5Hz LM (גבוה יותר = אקראי יותר)",
144
+ "lm_cfg_scale_label": "קנה מידה LM CFG",
145
+ "lm_cfg_scale_info": "5Hz LM CFG (1.0 = ללא CFG)",
146
+ "lm_top_k_label": "LM Top-K",
147
+ "lm_top_k_info": "Top-K (0 = מושבת)",
148
+ "lm_top_p_label": "LM Top-P",
149
+ "lm_top_p_info": "Top-P (1.0 = מושבת)",
150
+ "lm_negative_prompt_label": "הנחיה שלילית ל-LM",
151
+ "lm_negative_prompt_placeholder": "הזן הנחיה שלילית עבור CFG",
152
+ "lm_negative_prompt_info": "הנחיה שלילית (בשימוש כאשר LM CFG Scale > 1.0)",
153
+ "cot_metas_label": "CoT Metas",
154
+ "cot_metas_info": "השתמש ב-LM ליצירת מטא-דאטה CoT (בטל סימון כדי לדלג)",
155
+ "cot_language_label": "שפת CoT",
156
+ "cot_language_info": "יצירת שפה ב-CoT (שרשרת מחשבה)",
157
+ "constrained_debug_label": "ניקוי באגים של פענוח מוגבל",
158
+ "constrained_debug_info": "הפעל לוגים של ניקוי באגים עבור פענוח מוגבל",
159
+ "auto_score_label": "דירוג אוטומטי",
160
+ "auto_score_info": "חשב אוטומטית ציוני איכות לכל קטעי האודיו שנוצרו",
161
+ "auto_lrc_label": "LRC אוטומטי",
162
+ "auto_lrc_info": "צור אוטומטית חותמות זמן למילים (LRC) לכל קטעי האודיו",
163
+ "lm_batch_chunk_label": "גודל מקטע מנת LM",
164
+ "lm_batch_chunk_info": "מקסימום פריטים למקטע מנת LM (ברירת מחדל: 8, מוגבל ע\"י זיכרון ה-GPU)",
165
+ "codes_strength_label": "חוזק קודי LM",
166
+ "codes_strength_info": "שליטה בכמות צעדי הניקוי מרעש המשתמשים בקודים שנוצרו ע\"י ה-LM",
167
+ "cover_strength_label": "חוזק כיסוי אודיו (Audio Cover)",
168
+ "cover_strength_info": "שליטה בכמות צעדי הניקוי מרעש המשתמשים במצב כיסוי",
169
+ "score_sensitivity_label": "רגישות ציון איכות",
170
+ "score_sensitivity_info": "נמוך יותר = רגיש יותר (ברירת מחדל: 1.0)",
171
+ "think_label": "חשיבה (Think)",
172
+ "parallel_thinking_label": "חשיבה מקבילית",
173
+ "generate_btn": "🎵 צור מוזיקה",
174
+ "autogen_label": "יצירה אוטומטית",
175
+ "caption_rewrite_label": "שכתוב תיאור"
176
+ },
177
+ "results": {
178
+ "title": "🎵 תוצאות",
179
+ "generated_music": "🎵 מוזיקה שנוצרה (דגימה {n})",
180
+ "send_to_src_btn": "🔗 שלח לאודיו מקור",
181
+ "save_btn": "💾 שמירה",
182
+ "score_btn": "📊 דירוג",
183
+ "lrc_btn": "🎵 LRC",
184
+ "quality_score_label": "ציון איכות (דגימה {n})",
185
+ "quality_score_placeholder": "לחץ על 'דירוג' לחישוב ציון איכות מבוסס מורכבות (Perplexity)",
186
+ "codes_label": "קודי LM (דגימה {n})",
187
+ "lrc_label": "חותמות זמן למילים (דגימה {n})",
188
+ "lrc_placeholder": "לחץ על 'LRC' ליצירת חותמות זמן",
189
+ "details_accordion": "📊 דירוג, LRC וקודי LM",
190
+ "generation_status": "מצב יצירה",
191
+ "current_batch": "מנה נוכחית",
192
+ "batch_indicator": "מנה {current} / {total}",
193
+ "next_batch_status": "מצב המנה הבאה",
194
+ "prev_btn": "◀ הקודם",
195
+ "next_btn": "הבא ▶",
196
+ "restore_params_btn": "↙️ החל הגדרות אלו על הממשק (שחזור פרמטרי מנה)",
197
+ "batch_results_title": "📁 תוצאות המנה ופרטי יצירה",
198
+ "all_files_label": "📁 כל הקבצים שנוצרו (הורדה)",
199
+ "generation_details": "פרטי יצירה"
200
+ },
201
+ "messages": {
202
+ "no_audio_to_save": "❌ אין אודיו לשמירה",
203
+ "save_success": "✅ האודיו והמטא-דאטה נשמרו ב-{filename}",
204
+ "save_failed": "❌ השמירה נכשלה: {error}",
205
+ "no_file_selected": "⚠️ לא נבחר קובץ",
206
+ "params_loaded": "✅ הפרמטרים נטענו מ-{filename}",
207
+ "invalid_json": "❌ קובץ JSON לא תקין: {error}",
208
+ "load_error": "❌ שגיאה בטעינת הקובץ: {error}",
209
+ "example_loaded": "📁 נטען דגם מ-{filename}",
210
+ "example_failed": "נכשל ניתוח קובץ ה-JSON ב-{filename}: {error}",
211
+ "example_error": "שגיאה בטעינת הדגם: {error}",
212
+ "lm_generated": "🤖 נוצר דגם באמצעות ה-LM",
213
+ "lm_fallback": "יצירת דגם באמצעות ה-LM נכשלה, חוזר לשימוש בספריית הדגמים",
214
+ "lm_not_initialized": "❌ 5Hz LM לא מאותחל. נא לאתחל אותו תחילה.",
215
+ "autogen_enabled": "🔄 יצירה אוטומטית הופעלה - המנה הבאה תיווצר לאחר מכן",
216
+ "batch_ready": "✅ מנה {n} מוכנה! לחץ על 'הבא' לצפייה.",
217
+ "batch_generating": "🔄 מתחיל יצירת רקע עבור מנה {n}...",
218
+ "batch_failed": "❌ יצירת הרקע נכשלה: {error}",
219
+ "viewing_batch": "✅ צופה במנה {n}",
220
+ "at_first_batch": "נמצא כבר במנה הראשונה",
221
+ "at_last_batch": "אין מנה באה זמינה",
222
+ "batch_not_found": "מנה {n} לא נמצאה בתור",
223
+ "no_batch_data": "לא נמצאו נתוני מנה לשחזור.",
224
+ "params_restored": "✅ פרמטרי הממשק שוחזרו ממנה {n}",
225
+ "scoring_failed": "❌ שגיאה: נתוני המנה לא נמצאו",
226
+ "no_codes": "❌ אין קודי אודיו זמינים. נא ליצור מוזיקה תחילה.",
227
+ "score_failed": "❌ הדירוג נכשל: {error}",
228
+ "score_error": "❌ שגיאה בחישוב הציון: {error}",
229
+ "lrc_no_batch_data": "❌ לא נמצאו נתוני מנה. נא ליצור מוזיקה תחילה.",
230
+ "lrc_no_extra_outputs": "❌ לא נמצאו פלטים נוספים. טנזורי התניה אינם זמינים.",
231
+ "lrc_missing_tensors": "❌ חסרים טנזורים נדרשים ליצירת LRC.",
232
+ "lrc_sample_not_exist": "❌ הדגימה אינה קיימת במנה הנוכחית.",
233
+ "lrc_empty_result": "⚠️ יצירת ה-LRC הפיקה תוצאה ריקה.",
234
+ "empty_query": "⚠️ נא להזין תיאור מוזיקלי.",
235
+ "sample_creation_failed": "❌ יצירת הדגימה נכשלה. נא לנסות שוב.",
236
+ "sample_created": "✅ הדגימה נוצרה! בדוק את התיאור והמילים, ולאחר מכן לחץ על 'צור מוזיקה'.",
237
+ "simple_examples_not_found": "⚠️ ספריית הדגמים של המצב הפשוט לא נמצאה.",
238
+ "simple_examples_empty": "⚠️ לא נמצאו קבצי דוגמה במצב פשוט.",
239
+ "simple_example_loaded": "🎲 נטענה דוגמה אקראית מ-{filename}",
240
+ "format_success": "✅ התיאור והמילים פורמטו בהצלחה",
241
+ "format_failed": "❌ הפירמוט נכשל: {error}",
242
+ "skipping_metas_cot": "⚡ מדלג על שלב 1 של מטא-דאטה COT (הדגימה כבר מפורמטת)",
243
+ "invalid_timesteps_format": "⚠️ פורמט צעדי זמן לא תקין. משתמש בלוח זמנים כברירת מחדל.",
244
+ "timesteps_out_of_range": "⚠️ צעדי הזמן חייבים להיות בטווח [0, 1]. משתמש בלוח זמנים כברירת מחדל.",
245
+ "timesteps_count_mismatch": "⚠️ מספר צעדי הזמן ({actual}) שונה מצעדי ההסקה ({expected}). משתמש במספר צעדי הזמן."
246
+ },
247
+ "training": {
248
+ "tab_title": "🎓 אימון LoRA",
249
+ "tab_dataset_builder": "📁 בונה מערך נתונים",
250
+ "tab_train_lora": "🚀 אימון LoRA",
251
+ "quick_start_title": "🚀 התחלה מהירה",
252
+ "load_dataset_label": "נתיב קובץ JSON של מערך הנתונים",
253
+ "load_dataset_info": "טעינת מערך נתונים שנשמר בעבר",
254
+ "load_btn": "📂 טעינה",
255
+ "load_status": "מצב טעינה",
256
+ "scan_label": "נתיב ספריית אודיו",
257
+ "scan_info": "סריקה אחר קבצי אודיו (wav, mp3, flac, ogg, opus)",
258
+ "scan_btn": "🔍 סריקה",
259
+ "scan_status": "מצב סריקה",
260
+ "found_audio_files": "קבצי אודיו שנמצאו",
261
+ "dataset_name": "שם מערך הנתונים",
262
+ "dataset_name_placeholder": "הזן שם למערך הנתונים",
263
+ "dataset_settings_header": "הגדרות מערך נתונים",
264
+ "tag_prepend": "הוספה בהתחלה (תגית, תיאור)",
265
+ "tag_append": "הוספה בסוף (תיאור, תגית)",
266
+ "tag_replace": "החלפת התיאור",
267
+ "step2_title": "שלב 2: תיוג אוטומטי באמצעות AI",
268
+ "step3_title": "שלב 3: תצוגה מקדימה ועריכה",
269
+ "step4_title": "שלב 4: שמירת מערך הנתונים",
270
+ "step5_title": "שלב 5: עיבוד מקדים לטנזורים (Tensors)",
271
+ "all_instrumental": "הכל אינסטרומנטלי",
272
+ "all_instrumental_info": "סמן אם כל הרצועות הן כליות (ללא שירה)",
273
+ "custom_tag": "תגית הפעלה מותאמת אישית",
274
+ "custom_tag_info": "תגית ייחודית להפעלת הסגנון של LoRA זו",
275
+ "tag_position": "מיקום התגית",
276
+ "tag_position_info": "היכן למקם את התגית המותאמת אישית בתוך התיאור",
277
+ "genre_ratio": "יחס ז'אנר (%)",
278
+ "genre_ratio_info": "0% = הכל תיאור, 100% = הכל ז'אנר. הגדרה פר-דגימה קודמת להגדרת הכלל.",
279
+ "skip_metas": "דלג על BPM/סולם/משקל",
280
+ "skip_metas_info": "דלג על יצירת BPM/סולם/משקל. התיאור והז'אנר עדיין ייווצרו על ידי ה-LLM.",
281
+ "only_unlabeled": "רק כאלו ללא תיוג",
282
+ "only_unlabeled_info": "תייג רק דגימות ללא תיאור (שימושי להמשך תיוג שנכשל)",
283
+ "auto_label_btn": "🏷️ תיוג אוטומטי של הכל",
284
+ "label_progress": "התקדמות התיוג",
285
+ "select_sample": "בחר דגימה #",
286
+ "select_sample_info": "בחר דגימה לצפייה ועריכה",
287
+ "audio_preview": "תצוגה מקדימה של אודיו",
288
+ "filename": "שם קובץ",
289
+ "caption": "תיאור",
290
+ "genre": "ז'אנר",
291
+ "prompt_override_label": "דריסת פרומפט (לדגימה זו)",
292
+ "prompt_override_info": "דריסת היחס הכללי עבור דגימה ז��",
293
+ "lyrics_editable_label": "מילים (ניתן לעריכה, משמש לאימון)",
294
+ "raw_lyrics_label": "מילים גולמיות (מתוך קובץ .txt)",
295
+ "no_lyrics_placeholder": "(אין קובץ מילים .txt)",
296
+ "bpm": "BPM",
297
+ "key_label": "סולם (Key)",
298
+ "key_placeholder": "C Major",
299
+ "time_sig": "משקל מוזיקלי",
300
+ "duration_s": "משך (שניות)",
301
+ "language": "שפה",
302
+ "instrumental": "אינסטרומנטלי",
303
+ "save_changes_btn": "💾 שמירת שינויים",
304
+ "edit_status": "מצב עריכה",
305
+ "save_path": "נתיב שמירה",
306
+ "save_path_info": "הנתיב שבו יישמר קובץ ה-JSON של מערך הנתונים",
307
+ "save_dataset_btn": "💾 שמירת מערך נתונים",
308
+ "save_status": "מצב שמירה",
309
+ "load_existing_label": "טעינת מערך נתונים קיים (אופציונלי)",
310
+ "load_existing_info": "נתיב לקובץ JSON של מערך נתונים שנשמר בעבר",
311
+ "load_dataset_btn": "📂 טעינת מערך נתונים",
312
+ "tensor_output_dir": "ספריית פלט של טנזורים",
313
+ "tensor_output_info": "הספרייה לשמירת קבצי טנזור שעברו עיבוד מקדים",
314
+ "preprocess_btn": "⚡ עיבוד מקדים",
315
+ "preprocess_progress": "התקדמות עיבוד מקדים",
316
+ "preprocessed_tensors_dir": "ספריית טנזורים מעובדים",
317
+ "preprocessed_tensors_info": "ספרייה המכילה קבצי .pt של טנזורים מעובדים",
318
+ "train_section_tensors": "בחירת מערך נתונים מעובד",
319
+ "train_section_lora": "הגדרות LoRA",
320
+ "train_section_params": "פרמטרי אימון",
321
+ "dataset_info": "מידע על מערך הנתונים",
322
+ "lora_rank": "דרגת LoRA (Rank)",
323
+ "lora_rank_info": "גבוה יותר = יותר קיבולת, יותר זיכרון",
324
+ "lora_alpha": "LoRA Alpha",
325
+ "lora_alpha_info": "פקטור קנה מידה (בדרך כלל פי 2 מה-Rank)",
326
+ "lora_dropout": "LoRA Dropout",
327
+ "learning_rate": "קצב למידה (Learning Rate)",
328
+ "learning_rate_info": "התחל עם 3e-4, שנה במידת הצורך",
329
+ "max_epochs": "מקסימום תקופות (Epochs)",
330
+ "batch_size": "גודל מנה (Batch Size)",
331
+ "batch_size_info": "הגדל אם יש לך מספיק זיכרון גרפי (VRAM)",
332
+ "gradient_accumulation": "צבירת גרדיאנטים (Accumulation)",
333
+ "gradient_accumulation_info": "גודל מנה אפקטיבי = גודל מנה × צבירה",
334
+ "save_every_n_epochs": "שמור כל N תקופות (Epochs)",
335
+ "shift": "Shift (הסטה)",
336
+ "shift_info": "הסטת צעדי זמן עבור מודל turbo",
337
+ "seed": "גרעין (Seed)",
338
+ "output_dir": "ספריית פלט",
339
+ "output_dir_info": "ספרייה לשמירת משקולות ה-LoRA המאומנות",
340
+ "start_training_btn": "🚀 התחלת אימון",
341
+ "stop_training_btn": "⏹️ עצירת אימון",
342
+ "training_progress": "התקדמות האימון",
343
+ "training_log": "יומן אימון",
344
+ "training_loss_title": "הפסד אימון (Training Loss)",
345
+ "step": "צעד",
346
+ "loss": "הפסד (Loss)",
347
+ "export_header": "ייצוא LoRA",
348
+ "export_path": "נתיב ייצוא",
349
+ "export_lora_btn": "📦 ייצוא LoRA",
350
+ "export_status": "מצב ייצוא"
351
+ }
352
+ }
acestep/gradio_ui/i18n/ja.json ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "app": {
3
+ "title": "🎛️ ACE-Step V1.5 プレイグラウンド💡",
4
+ "subtitle": "オープンソース音楽生成の限界を押し広げる"
5
+ },
6
+ "dataset": {
7
+ "title": "📊 データセットエクスプローラー",
8
+ "dataset_label": "データセット",
9
+ "dataset_info": "探索するデータセットを選択",
10
+ "import_btn": "📥 データセットをインポート",
11
+ "search_type_label": "検索タイプ",
12
+ "search_type_info": "アイテムの検索方法",
13
+ "search_value_label": "検索値",
14
+ "search_value_placeholder": "キーまたはインデックスを入力(空白の場合はランダム)",
15
+ "search_value_info": "キー: 完全一致、インデックス: 0からデータセットサイズ-1",
16
+ "instruction_label": "📝 指示",
17
+ "instruction_placeholder": "利用可能な指示がありません",
18
+ "metadata_title": "📋 アイテムメタデータ (JSON)",
19
+ "metadata_label": "完全なアイテム情報",
20
+ "source_audio": "ソースオーディオ",
21
+ "target_audio": "ターゲットオーディオ",
22
+ "reference_audio": "リファレンスオーディオ",
23
+ "get_item_btn": "🔍 アイテムを取得",
24
+ "use_src_checkbox": "データセットのソースオーディオを使用",
25
+ "use_src_info": "データセットのソースオーディオを使用する場合はチェック",
26
+ "data_status_label": "📊 データステータス",
27
+ "data_status_default": "❌ データセットがインポートされていません",
28
+ "autofill_btn": "📋 生成フォームを自動入力"
29
+ },
30
+ "service": {
31
+ "title": "🔧 サービス設定",
32
+ "checkpoint_label": "チェックポイントファイル",
33
+ "checkpoint_info": "訓練済みモデルのチェックポイントファイルを選択(フルパスまたはファイル名)",
34
+ "refresh_btn": "🔄 更新",
35
+ "model_path_label": "メインモデルパス",
36
+ "model_path_info": "モデル設定ディレクトリを選択(チェックポイントから自動スキャン)",
37
+ "device_label": "デバイス",
38
+ "device_info": "処理デバイス(自動検出を推奨)",
39
+ "lm_model_path_label": "5Hz LM モデルパス",
40
+ "lm_model_path_info": "5Hz LMモデルチェックポイントを選択(チェックポイントから自動スキャン)",
41
+ "backend_label": "5Hz LM バックエンド",
42
+ "backend_info": "5Hz LMのバックエンドを選択: vllm(高速)またはpt(PyTorch、より互換性あり)",
43
+ "init_llm_label": "5Hz LM を初期化",
44
+ "init_llm_info": "サービス初期化中に5Hz LMを初期化する場合はチェック",
45
+ "flash_attention_label": "Flash Attention を使用",
46
+ "flash_attention_info_enabled": "推論を高速化するためにflash attentionを有効にする(flash_attnパッケージが必要)",
47
+ "flash_attention_info_disabled": "Flash attentionは利用できません(flash_attnパッケージがインストールされていません)",
48
+ "offload_cpu_label": "CPUにオフロード",
49
+ "offload_cpu_info": "使用していない時にモデルをCPUにオフロードしてGPUメモリを節約",
50
+ "offload_dit_cpu_label": "DiTをCPUにオフロード",
51
+ "offload_dit_cpu_info": "DiTをCPUにオフロード(CPUへのオフロードが必要)",
52
+ "compile_model_label": "モデルをコンパイル",
53
+ "compile_model_info": "torch.compileでモデルを最適化(量子化に必要)",
54
+ "quantization_label": "INT8 量子化",
55
+ "quantization_info": "INT8重み量子化を有効にしてVRAMを節約(モデルのコンパイルが必要)",
56
+ "init_btn": "サービスを初期化",
57
+ "status_label": "ステータス",
58
+ "language_label": "UI言語",
59
+ "language_info": "インターフェース言語を選択"
60
+ },
61
+ "generation": {
62
+ "required_inputs": "📝 必須入力",
63
+ "task_type_label": "タスクタイプ",
64
+ "task_type_info": "生成のタスクタイプを選択",
65
+ "instruction_label": "指示",
66
+ "instruction_info": "指示はタスクタイプに基づいて自動生成されます",
67
+ "load_btn": "読み込む",
68
+ "track_name_label": "トラック名",
69
+ "track_name_info": "lego/extractタスクのトラック名を選択",
70
+ "track_classes_label": "トラック名",
71
+ "track_classes_info": "completeタスクの複数のトラッククラスを選択",
72
+ "audio_uploads": "🎵 オーディオアップロード",
73
+ "reference_audio": "リファレンスオーディオ(オプション)",
74
+ "source_audio": "ソースオーディオ(オプション)",
75
+ "convert_codes_btn": "コードに変換",
76
+ "lm_codes_hints": "🎼 LM コードヒント",
77
+ "lm_codes_label": "LM コードヒント",
78
+ "lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
79
+ "lm_codes_info": "text2music生成用のLMコードヒントを貼り付け",
80
+ "lm_codes_sample": "LM コードヒント(サンプル {n})",
81
+ "lm_codes_sample_info": "サ���プル{n}のコード",
82
+ "transcribe_btn": "転写",
83
+ "repainting_controls": "🎨 再描画コントロール(秒)",
84
+ "repainting_start": "再描画開始",
85
+ "repainting_end": "再描画終了",
86
+ "mode_label": "生成モード",
87
+ "mode_info": "シンプル:自然言語で音楽を説明。カスタム:キャプションと歌詞を完全にコントロール。",
88
+ "mode_simple": "シンプル",
89
+ "mode_custom": "カスタム",
90
+ "simple_query_label": "曲の説明",
91
+ "simple_query_placeholder": "作成したい音楽を説明してください。例:'静かな夜のための優しいベンガルのラブソング'。空欄の場合はランダムなサンプルが生成されます。",
92
+ "simple_query_info": "生成したい音楽の自然言語の説明を入力",
93
+ "simple_vocal_language_label": "ボーカル言語(オプション)",
94
+ "simple_vocal_language_info": "歌詞の希望言語を選択。任意の言語の場合は'unknown'を使用。",
95
+ "create_sample_btn": "サンプル作成",
96
+ "caption_title": "📝 音楽キャプション",
97
+ "caption_label": "音楽キャプション(オプション)",
98
+ "caption_placeholder": "柔らかいボーカルを伴う穏やかなアコースティックギターのメロディー...",
99
+ "caption_info": "スタイル、ジャンル、楽器、ムードを説明",
100
+ "lyrics_title": "📝 歌詞",
101
+ "lyrics_label": "歌詞(オプション)",
102
+ "lyrics_placeholder": "[バース1]\\n星空の下で\\nとても生きていると感じる...",
103
+ "lyrics_info": "構造を持つ曲の歌詞",
104
+ "instrumental_label": "インストゥルメンタル",
105
+ "format_btn": "フォーマット",
106
+ "optional_params": "⚙️ オプションパラメータ",
107
+ "vocal_language_label": "ボーカル言語(オプション)",
108
+ "vocal_language_info": "インストには`unknown`を使用",
109
+ "bpm_label": "BPM(オプション)",
110
+ "bpm_info": "空白の場合はN/A",
111
+ "keyscale_label": "キースケール(オプション)",
112
+ "keyscale_placeholder": "空白の場合はN/A",
113
+ "keyscale_info": "A-G, #/♭, メジャー/マイナー",
114
+ "timesig_label": "拍子記号(オプション)",
115
+ "timesig_info": "2/4, 3/4, 4/4...",
116
+ "duration_label": "オーディオ長(秒)",
117
+ "duration_info": "ランダムの場合は-1を使用",
118
+ "batch_size_label": "バッチサイズ",
119
+ "batch_size_info": "生成するオーディオの数(最大8)",
120
+ "advanced_settings": "🔧 詳細設定",
121
+ "inference_steps_label": "DiT 推論ステップ",
122
+ "inference_steps_info": "Turbo: 最大8、Base: 最大200",
123
+ "guidance_scale_label": "DiT ガイダンススケール(baseモデルのみサポート)",
124
+ "guidance_scale_info": "値が高いほどテキストに忠実に従う",
125
+ "seed_label": "シード",
126
+ "seed_info": "バッチにはカンマ区切りの値を使用",
127
+ "random_seed_label": "ランダムシード",
128
+ "random_seed_info": "有効にすると自動的にシードを生成",
129
+ "audio_format_label": "オーディオフォーマット",
130
+ "audio_format_info": "保存ファイルのオーディオフォーマット",
131
+ "use_adg_label": "ADG を使用",
132
+ "use_adg_info": "角度ドメインガイダンスを有効化",
133
+ "shift_label": "シフト",
134
+ "shift_info": "baseモデル用タイムステップシフト係数 (範囲 1.0~5.0、デフォルト 3.0)。turboモデルには無効。",
135
+ "infer_method_label": "推論方法",
136
+ "infer_method_info": "拡散推論方法。ODE (オイラー) は高速、SDE (確率的) は異なる結果を生成する可能性があります。",
137
+ "custom_timesteps_label": "カスタムタイムステップ",
138
+ "custom_timesteps_info": "オプション:1.0から0.0へのカンマ区切り値(例:'0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')。推論ステップとシフトを上書きします。",
139
+ "cfg_interval_start": "CFG 間隔開始",
140
+ "cfg_interval_end": "CFG 間隔終了",
141
+ "lm_params_title": "🤖 LM 生成パラメータ",
142
+ "lm_temperature_label": "LM 温度",
143
+ "lm_temperature_info": "5Hz LM温度(高いほどランダム)",
144
+ "lm_cfg_scale_label": "LM CFG スケール",
145
+ "lm_cfg_scale_info": "5Hz LM CFG (1.0 = CFGなし)",
146
+ "lm_top_k_label": "LM Top-K",
147
+ "lm_top_k_info": "Top-K (0 = 無効)",
148
+ "lm_top_p_label": "LM Top-P",
149
+ "lm_top_p_info": "Top-P (1.0 = 無効)",
150
+ "lm_negative_prompt_label": "LM ネガティブプロンプト",
151
+ "lm_negative_prompt_placeholder": "CFGのネガティブプロンプトを入力(デフォルト: NO USER INPUT)",
152
+ "lm_negative_prompt_info": "ネガティブプロンプト(LM CFGスケール > 1.0の場合に使用)",
153
+ "cot_metas_label": "CoT メタデータ",
154
+ "cot_metas_info": "LMを使用してCoTメタデータを生成(チェックを外すとLM CoT生成をスキップ)",
155
+ "cot_language_label": "CoT 言語",
156
+ "cot_language_info": "CoTで言語を生成(思考の連鎖)",
157
+ "constrained_debug_label": "制約付きデコーディングデバッグ",
158
+ "constrained_debug_info": "制約付きデコーディングのデバッグログを有効化(チェックすると詳細ログを表示)",
159
+ "auto_score_label": "自動スコアリング",
160
+ "auto_score_info": "生成されたすべてのオーディオの品質スコアを自動計算",
161
+ "auto_lrc_label": "自動 LRC",
162
+ "auto_lrc_info": "生成されたすべてのオーディオのLRC歌詞タイムスタンプを自動生成",
163
+ "lm_batch_chunk_label": "LM バッチチャンクサイズ",
164
+ "lm_batch_chunk_info": "LMバッチチャンクあたりの最大アイテム数(デフォルト: 8、GPUメモリによる制限)",
165
+ "codes_strength_label": "LM コード強度",
166
+ "codes_strength_info": "LM生成コードを使用するデノイジングステップ数を制御",
167
+ "similarity_denoise_label": "類似度 / ノイズ除去",
168
+ "similarity_denoise_info": "出力が参照オーディオにどれだけ忠実かを制御します。高い値ほど構造を保持します。",
169
+ "cover_strength_label": "オーディオカバー強度",
170
+ "cover_strength_info": "カバーモードを使用するデノイジングステップ数を制御",
171
+ "score_sensitivity_label": "品質スコア感度",
172
+ "score_sensitivity_info": "低い = より敏感(デフォルト: 1.0)。PMIが[0,1]にマッピングする方法を調整",
173
+ "think_label": "思考",
174
+ "parallel_thinking_label": "並列思考",
175
+ "generate_btn": "🎵 音楽を生成",
176
+ "autogen_label": "自動生成",
177
+ "caption_rewrite_label": "キャプション書き換え"
178
+ },
179
+ "results": {
180
+ "title": "🎵 結果",
181
+ "generated_music": "🎵 生成された音楽(サンプル {n})",
182
+ "send_to_src_btn": "🔗 ソースオーディオに送信",
183
+ "save_btn": "💾 保存",
184
+ "score_btn": "📊 スコア",
185
+ "lrc_btn": "🎵 LRC",
186
+ "quality_score_label": "品質スコア(サンプル {n})",
187
+ "quality_score_placeholder": "'スコア'をクリックしてパープレキシティベースの品質スコアを計算",
188
+ "codes_label": "LM コード(サンプル {n})",
189
+ "lrc_label": "歌詞タイムスタンプ(サンプル {n})",
190
+ "lrc_placeholder": "'LRC'をクリックしてタイムスタンプを生成",
191
+ "details_accordion": "📊 スコア & LRC & LM コード",
192
+ "generation_status": "生成ステータス",
193
+ "current_batch": "現在のバッチ",
194
+ "batch_indicator": "バッチ {current} / {total}",
195
+ "next_batch_status": "次のバッチステータス",
196
+ "prev_btn": "◀ 前へ",
197
+ "next_btn": "次へ ▶",
198
+ "restore_params_btn": "↙️ これらの設定をUIに適用(バッチパラメータを復元)",
199
+ "batch_results_title": "👇 クリックしてバッチ結果と生成詳細を表示",
200
+ "all_files_label": "📁 すべての生成ファイル(ダウンロード)",
201
+ "generation_details": "生成詳細"
202
+ },
203
+ "messages": {
204
+ "no_audio_to_save": "❌ 保存するオーディオがありません",
205
+ "save_success": "✅ オーディオとメタデータを {filename} に保存しました",
206
+ "save_failed": "❌ 保存に失敗しました: {error}",
207
+ "no_file_selected": "⚠️ ファイルが選択されていません",
208
+ "params_loaded": "✅ {filename} からパラメータを読み込みました",
209
+ "invalid_json": "❌ 無効なJSONファイル: {error}",
210
+ "load_error": "❌ ファイルの読み込みエラー: {error}",
211
+ "example_loaded": "📁 {filename} からサンプルを読み込みました",
212
+ "example_failed": "JSONファイル {filename} の解析に失敗しました: {error}",
213
+ "example_error": "サンプル読み込みエラー: {error}",
214
+ "lm_generated": "🤖 LMを使用してサンプルを生成しました",
215
+ "lm_fallback": "LMを使用したサンプル生成に失敗、サンプルディレクトリにフォールバック",
216
+ "lm_not_initialized": "❌ 5Hz LMが初期化されていません。最初に初期化してください。",
217
+ "autogen_enabled": "🔄 自動生成が有効 - このあと次のバッチを生成します",
218
+ "batch_ready": "✅ バッチ {n} の準備完了!'次へ'をクリックして表示。",
219
+ "batch_generating": "🔄 バッチ {n} のバックグラウンド生成を開始...",
220
+ "batch_failed": "❌ バックグラウンド生成に失敗しました: {error}",
221
+ "viewing_batch": "✅ バッチ {n} を表示中",
222
+ "at_first_batch": "すでに最初のバッチです",
223
+ "at_last_batch": "次のバッチはありません",
224
+ "batch_not_found": "キューにバッチ {n} が見つかりません",
225
+ "no_batch_data": "復元するバッチデータがありません。",
226
+ "params_restored": "✅ バッチ {n} からUIパラメータを復元しました",
227
+ "scoring_failed": "❌ エラー: バッチデータが見つかりません",
228
+ "no_codes": "❌ 利用可能なオーディオ��ードがありません。最初に音楽を生成してください。",
229
+ "score_failed": "❌ スコアリングに失敗しました: {error}",
230
+ "score_error": "❌ スコア計算エラー: {error}",
231
+ "lrc_no_batch_data": "❌ バッチデータが見つかりません。最初に音楽を生成してください。",
232
+ "lrc_no_extra_outputs": "❌ 追加出力が見つかりません。条件テンソルが利用できません。",
233
+ "lrc_missing_tensors": "❌ LRC生成に必要なテンソルがありません。",
234
+ "lrc_sample_not_exist": "❌ 現在のバッチにサンプルが存在しません。",
235
+ "lrc_empty_result": "⚠️ LRC生成の結果が空です。",
236
+ "empty_query": "⚠️ 音楽の説明を入力してください。",
237
+ "sample_creation_failed": "❌ サンプルの作成に失敗しました。もう一度お試しください。",
238
+ "sample_created": "✅ サンプルが作成されました!キャプションと歌詞を確認して、音楽を生成をクリックしてください。",
239
+ "simple_examples_not_found": "⚠️ シンプルモードサンプルディレクトリが見つかりません。",
240
+ "simple_examples_empty": "⚠️ シンプルモードサンプルにファイルがありません。",
241
+ "simple_example_loaded": "🎲 {filename} からランダムサンプルを読み込みました",
242
+ "format_success": "✅ キャプションと歌詞のフォーマットに成功しました",
243
+ "format_failed": "❌ フォーマットに失敗しました: {error}",
244
+ "skipping_metas_cot": "⚡ Phase 1 メタデータ COT をスキップ(サンプルは既にフォーマット済み)",
245
+ "invalid_timesteps_format": "⚠️ タイムステップ形式が無効です。デフォルトスケジュールを使用します。",
246
+ "timesteps_out_of_range": "⚠️ タイムステップは [0, 1] の範囲内である必要があります。デフォルトスケジュールを使用します。",
247
+ "timesteps_count_mismatch": "⚠️ タイムステップ数 ({actual}) が推論ステップ数 ({expected}) と異なります。タイムステップ数を使用します。"
248
+ },
249
+ "training": {
250
+ "tab_title": "🎓 LoRA トレーニング",
251
+ "tab_dataset_builder": "📁 データセットビルダー",
252
+ "tab_train_lora": "🚀 LoRA をトレーニング",
253
+ "quick_start_title": "🚀 クイックスタート",
254
+ "load_dataset_label": "データセット JSON パス",
255
+ "load_dataset_info": "以前保存したデータセットを読み込む",
256
+ "load_btn": "📂 読み込み",
257
+ "load_status": "読み込み状態",
258
+ "scan_label": "オーディオディレクトリパス",
259
+ "scan_info": "オーディオファイルをスキャン(wav、mp3、flac、ogg、opus)",
260
+ "scan_btn": "🔍 スキャン",
261
+ "scan_status": "スキャン状態",
262
+ "found_audio_files": "見つかったオーディオファイル",
263
+ "dataset_name": "データセット名",
264
+ "dataset_name_placeholder": "データセット名を入力",
265
+ "dataset_settings_header": "データセット設定",
266
+ "tag_prepend": "前置(タグ、キャプション)",
267
+ "tag_append": "後置(キャプション、タグ)",
268
+ "tag_replace": "キャプションを置換",
269
+ "step2_title": "ステップ 2: AI で自動ラベル",
270
+ "step3_title": "ステップ 3: プレビューと編集",
271
+ "step4_title": "ステップ 4: データセットを保存",
272
+ "step5_title": "ステップ 5: テンソルに前処理",
273
+ "all_instrumental": "すべてインストゥルメンタル",
274
+ "all_instrumental_info": "すべてのトラックがインストゥルメンタル(ボーカルなし)の場合にチェック",
275
+ "custom_tag": "カスタムアクティベーションタグ",
276
+ "custom_tag_info": "この LoRA のスタイルを有効にする一意のタグ",
277
+ "tag_position": "タグの位置",
278
+ "tag_position_info": "キャプション内でカスタムタグを配置する位置",
279
+ "genre_ratio": "ジャンル比率 (%)",
280
+ "genre_ratio_info": "0%=すべてキャプション、100%=すべてジャンル。サンプル単位の上書きが優先。",
281
+ "skip_metas": "BPM/キー/拍子をスキップ",
282
+ "skip_metas_info": "BPM/キー/拍子の生成をスキップ。キャプションとジャンルは LM が生成。",
283
+ "only_unlabeled": "未ラベルのみ",
284
+ "only_unlabeled_info": "キャプションのないサンプルのみラベル付け(失敗したラベル付けの再開に便利)",
285
+ "auto_label_btn": "🏷️ 一括自動ラベル",
286
+ "label_progress": "ラベル付け進捗",
287
+ "select_sample": "サンプル # を選択",
288
+ "select_sample_info": "プレビューと編集するサンプルを選択",
289
+ "audio_preview": "オーディオプレビュー",
290
+ "filename": "ファイル名",
291
+ "caption": "キャプション",
292
+ "genre": "ジャンル",
293
+ "prompt_override_label": "プロンプト上書き(このサ���プル)",
294
+ "prompt_override_info": "このサンプルのグローバル比率を上書き",
295
+ "lyrics_editable_label": "歌詞(編集可、トレーニング用)",
296
+ "raw_lyrics_label": "生歌詞(.txt ファイルから)",
297
+ "no_lyrics_placeholder": "(.txt 歌詞ファイルなし)",
298
+ "bpm": "BPM",
299
+ "key_label": "キー",
300
+ "key_placeholder": "C Major",
301
+ "time_sig": "拍子",
302
+ "duration_s": "長さ (秒)",
303
+ "language": "言語",
304
+ "instrumental": "インストゥルメンタル",
305
+ "save_changes_btn": "💾 変更を保存",
306
+ "edit_status": "編集状態",
307
+ "save_path": "保存パス",
308
+ "save_path_info": "データセット JSON の保存先パス",
309
+ "save_dataset_btn": "💾 データセットを保存",
310
+ "save_status": "保存状態",
311
+ "load_existing_label": "既存データセットを読み込み(任意)",
312
+ "load_existing_info": "以前保存したデータセット JSON ファイルのパス",
313
+ "load_dataset_btn": "📂 データセットを読み込み",
314
+ "tensor_output_dir": "テンソル出力ディレクトリ",
315
+ "tensor_output_info": "前処理済みテンソルファイルの保存先ディレクトリ",
316
+ "preprocess_btn": "⚡ 前処理",
317
+ "preprocess_progress": "前処理進捗",
318
+ "preprocessed_tensors_dir": "前処理済みテンソルディレクトリ",
319
+ "preprocessed_tensors_info": "前処理済み .pt テンソルファイルを含むディレクトリ",
320
+ "train_section_tensors": "前処理済みデータセット選択",
321
+ "train_section_lora": "LoRA 設定",
322
+ "train_section_params": "トレーニングパラメータ",
323
+ "dataset_info": "データセット情報",
324
+ "lora_rank": "LoRA ランク (r)",
325
+ "lora_rank_info": "高いほど容量は増えるがメモリ使用量も増加",
326
+ "lora_alpha": "LoRA Alpha",
327
+ "lora_alpha_info": "スケーリング係数(通常はランクの2倍)",
328
+ "lora_dropout": "LoRA Dropout",
329
+ "learning_rate": "学習率",
330
+ "learning_rate_info": "3e-4 から始め、必要に応じて調整",
331
+ "max_epochs": "最大エポック数",
332
+ "batch_size": "バッチサイズ",
333
+ "batch_size_info": "VRAM に余裕があれば増やせます",
334
+ "gradient_accumulation": "勾配累積",
335
+ "gradient_accumulation_info": "実効バッチ = batch_size × 累積",
336
+ "save_every_n_epochs": "N エポックごとに保存",
337
+ "shift": "Shift",
338
+ "shift_info": "ターボモデル用タイムステップシフト",
339
+ "seed": "シード",
340
+ "output_dir": "出力ディレクトリ",
341
+ "output_dir_info": "トレーニング済み LoRA 重みの保存先ディレクトリ",
342
+ "start_training_btn": "🚀 トレーニング開始",
343
+ "stop_training_btn": "⏹️ トレーニング停止",
344
+ "training_progress": "トレーニング進捗",
345
+ "training_log": "トレーニングログ",
346
+ "training_loss_title": "トレーニング損失",
347
+ "step": "ステップ",
348
+ "loss": "損失",
349
+ "export_header": "LoRA をエクスポート",
350
+ "export_path": "エクスポートパス",
351
+ "export_lora_btn": "📦 LoRA をエクスポート",
352
+ "export_status": "エクスポート状態"
353
+ }
354
+ }
acestep/gradio_ui/i18n/zh.json ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "app": {
3
+ "title": "🎛️ ACE-Step V1.5 演练场💡",
4
+ "subtitle": "推动开源音乐生成的边界"
5
+ },
6
+ "dataset": {
7
+ "title": "📊 数据集浏览器",
8
+ "dataset_label": "数据集",
9
+ "dataset_info": "选择要浏览的数据集",
10
+ "import_btn": "📥 导入数据集",
11
+ "search_type_label": "搜索类型",
12
+ "search_type_info": "如何查找项目",
13
+ "search_value_label": "搜索值",
14
+ "search_value_placeholder": "输入键或索引(留空表示随机)",
15
+ "search_value_info": "键: 精确匹配, 索引: 0到数据集大小-1",
16
+ "instruction_label": "📝 指令",
17
+ "instruction_placeholder": "无可用指令",
18
+ "metadata_title": "📋 项目元数据 (JSON)",
19
+ "metadata_label": "完整项目信息",
20
+ "source_audio": "源音频",
21
+ "target_audio": "目标音频",
22
+ "reference_audio": "参考音频",
23
+ "get_item_btn": "🔍 获取项目",
24
+ "use_src_checkbox": "使用数据集中的源音频",
25
+ "use_src_info": "勾选以使用数据集中的源音频",
26
+ "data_status_label": "📊 数据状态",
27
+ "data_status_default": "❌ 未导入数据集",
28
+ "autofill_btn": "📋 自动填充生成表单"
29
+ },
30
+ "service": {
31
+ "title": "🔧 服务配置",
32
+ "checkpoint_label": "检查点文件",
33
+ "checkpoint_info": "选择训练好的模型检查点文件(完整路径或文件名)",
34
+ "refresh_btn": "🔄 刷新",
35
+ "model_path_label": "主模型路径",
36
+ "model_path_info": "选择模型配置目录(从检查点自动扫描)",
37
+ "device_label": "设备",
38
+ "device_info": "处理设备(建议自动检测)",
39
+ "lm_model_path_label": "5Hz LM 模型路径",
40
+ "lm_model_path_info": "选择5Hz LM模型检查点(从检查点自动扫描)",
41
+ "backend_label": "5Hz LM 后端",
42
+ "backend_info": "选择5Hz LM的后端: vllm(更快)或pt(PyTorch, 更兼容)",
43
+ "init_llm_label": "初始化 5Hz LM",
44
+ "init_llm_info": "勾选以在服务初始化期间初始化5Hz LM",
45
+ "flash_attention_label": "使用Flash Attention",
46
+ "flash_attention_info_enabled": "启用flash attention以加快推理速度(需要flash_attn包)",
47
+ "flash_attention_info_disabled": "Flash attention不可用(未安装flash_attn包)",
48
+ "offload_cpu_label": "卸载到CPU",
49
+ "offload_cpu_info": "不使用时将模型卸载到CPU以节省GPU内存",
50
+ "offload_dit_cpu_label": "将DiT卸载到CPU",
51
+ "offload_dit_cpu_info": "将DiT卸载到CPU(需要启用卸载到CPU)",
52
+ "compile_model_label": "编译模型",
53
+ "compile_model_info": "使用 torch.compile 优化模型(量化必需)",
54
+ "quantization_label": "INT8 量化",
55
+ "quantization_info": "启用 INT8 仅权重量化以减少显存占用(需要启用编译模型)",
56
+ "init_btn": "初始化服务",
57
+ "status_label": "状态",
58
+ "language_label": "界面语言",
59
+ "language_info": "选择界面语言"
60
+ },
61
+ "generation": {
62
+ "required_inputs": "📝 必需输入",
63
+ "task_type_label": "任务类型",
64
+ "task_type_info": "选择生成的任务类型",
65
+ "instruction_label": "指令",
66
+ "instruction_info": "指令根据任务类型自动生成",
67
+ "load_btn": "加载",
68
+ "track_name_label": "音轨名称",
69
+ "track_name_info": "为lego/extract任务选择音轨名称",
70
+ "track_classes_label": "音轨名称",
71
+ "track_classes_info": "为complete任务选择多个音轨类别",
72
+ "audio_uploads": "🎵 音频上传",
73
+ "reference_audio": "参考音频(可选)",
74
+ "source_audio": "源音频(可选)",
75
+ "convert_codes_btn": "转换为代码",
76
+ "lm_codes_hints": "🎼 LM 代码提示",
77
+ "lm_codes_label": "LM 代码提示",
78
+ "lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
79
+ "lm_codes_info": "粘贴用于text2music生成的LM代码提示",
80
+ "lm_codes_sample": "LM 代码提示(样本 {n})",
81
+ "lm_codes_sample_info": "样本{n}的代码",
82
+ "transcribe_btn": "转录",
83
+ "repainting_controls": "🎨 重绘控制(秒)",
84
+ "repainting_start": "重绘开始",
85
+ "repainting_end": "重绘结束",
86
+ "mode_label": "生成模式",
87
+ "mode_info": "简单模式:用自然语言描述音乐。自定义模式:完全控制描述和歌词。",
88
+ "mode_simple": "简单",
89
+ "mode_custom": "自定义",
90
+ "simple_query_label": "歌曲描述",
91
+ "simple_query_placeholder": "描述你想创作的音乐,例如:'给我生成一首暗黑的戏剧古风,歌词要华丽'。留空则随机生成样本。",
92
+ "simple_query_info": "输入你想生成的音乐的自然语言描述",
93
+ "simple_vocal_language_label": "人声语言(可选)",
94
+ "simple_vocal_language_info": "选择歌词的首选语言。使用 'unknown' 表示任意语言。",
95
+ "create_sample_btn": "创建样本",
96
+ "caption_title": "📝 音乐描述",
97
+ "caption_label": "音乐描述(可选)",
98
+ "caption_placeholder": "一段平和的原声吉他旋律,配有柔和的人声...",
99
+ "caption_info": "描述风格、流派、乐器���情绪",
100
+ "lyrics_title": "📝 歌词",
101
+ "lyrics_label": "歌词(可选)",
102
+ "lyrics_placeholder": "[第一段]\\n在星空下\\n我感到如此活跃...",
103
+ "lyrics_info": "带有结构的歌曲歌词",
104
+ "instrumental_label": "纯音乐",
105
+ "format_btn": "格式化",
106
+ "optional_params": "⚙️ 可选参数",
107
+ "vocal_language_label": "人声语言(可选)",
108
+ "vocal_language_info": "纯音乐使用 `unknown`",
109
+ "bpm_label": "BPM(可选)",
110
+ "bpm_info": "留空表示N/A",
111
+ "keyscale_label": "调性(可选)",
112
+ "keyscale_placeholder": "留空表示N/A",
113
+ "keyscale_info": "A-G, #/♭, 大调/小调",
114
+ "timesig_label": "拍号(可选)",
115
+ "timesig_info": "2/4, 3/4, 4/4...",
116
+ "duration_label": "音频时长(秒)",
117
+ "duration_info": "使用-1表示随机",
118
+ "batch_size_label": "批量大小",
119
+ "batch_size_info": "要生成的音频数量(最多8个)",
120
+ "advanced_settings": "🔧 高级设置",
121
+ "inference_steps_label": "DiT 推理步数",
122
+ "inference_steps_info": "Turbo: 最多8, Base: 最多200",
123
+ "guidance_scale_label": "DiT 引导比例(仅支持base模型)",
124
+ "guidance_scale_info": "更高的值更紧密地遵循文本",
125
+ "seed_label": "种子",
126
+ "seed_info": "批量使用逗号分隔的值",
127
+ "random_seed_label": "随机种子",
128
+ "random_seed_info": "启用以自动生成种子",
129
+ "audio_format_label": "音频格式",
130
+ "audio_format_info": "保存文件的音频格式",
131
+ "use_adg_label": "使用 ADG",
132
+ "use_adg_info": "启用角域引导",
133
+ "shift_label": "Shift",
134
+ "shift_info": "时间步偏移因子,仅对 base 模型生效 (范围 1.0~5.0,默认 3.0)。对 turbo 模型无效。",
135
+ "infer_method_label": "推理方法",
136
+ "infer_method_info": "扩散推理方法。ODE (欧拉) 更快,SDE (随机) 可能产生不同结果。",
137
+ "custom_timesteps_label": "自定义时间步",
138
+ "custom_timesteps_info": "可选:从 1.0 到 0.0 的逗号分隔值(例如 '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')。会覆盖推理步数和 shift 设置。",
139
+ "cfg_interval_start": "CFG 间隔开始",
140
+ "cfg_interval_end": "CFG 间隔结束",
141
+ "lm_params_title": "🤖 LM 生成参数",
142
+ "lm_temperature_label": "LM 温度",
143
+ "lm_temperature_info": "5Hz LM温度(越高越随机)",
144
+ "lm_cfg_scale_label": "LM CFG 比例",
145
+ "lm_cfg_scale_info": "5Hz LM CFG (1.0 = 无CFG)",
146
+ "lm_top_k_label": "LM Top-K",
147
+ "lm_top_k_info": "Top-K (0 = 禁用)",
148
+ "lm_top_p_label": "LM Top-P",
149
+ "lm_top_p_info": "Top-P (1.0 = 禁用)",
150
+ "lm_negative_prompt_label": "LM 负面提示",
151
+ "lm_negative_prompt_placeholder": "输入CFG的负面提示(默认: NO USER INPUT)",
152
+ "lm_negative_prompt_info": "负面提示(当LM CFG比例 > 1.0时使用)",
153
+ "cot_metas_label": "CoT 元数据",
154
+ "cot_metas_info": "使用LM生成CoT元数据(取消勾选以跳过LM CoT生成)",
155
+ "cot_language_label": "CoT 语言",
156
+ "cot_language_info": "在CoT中生成语言(思维链)",
157
+ "constrained_debug_label": "约束解码调试",
158
+ "constrained_debug_info": "启用约束解码的调试日志(勾选以查看详细日志)",
159
+ "auto_score_label": "自动评分",
160
+ "auto_score_info": "自动计算所有生成音频的质量分数",
161
+ "auto_lrc_label": "自动 LRC",
162
+ "auto_lrc_info": "自动为所有生成的音频生成LRC歌词时间戳",
163
+ "lm_batch_chunk_label": "LM 批量块大小",
164
+ "lm_batch_chunk_info": "每个LM批量块的最大项目数(默认: 8, 受GPU内存限制)",
165
+ "codes_strength_label": "LM 代码强度",
166
+ "codes_strength_info": "控制使用LM生成代码的去噪步骤数量",
167
+ "similarity_denoise_label": "相似度 / 降噪",
168
+ "similarity_denoise_info": "控制输出与参考音频的贴合程度。数值越高保留越多结构。",
169
+ "cover_strength_label": "音频覆盖强度",
170
+ "cover_strength_info": "控制使用覆盖模式的去噪步骤数量",
171
+ "score_sensitivity_label": "质量评分敏感度",
172
+ "score_sensitivity_info": "更低 = 更敏感(默认: 1.0). 调整PMI如何映射到[0,1]",
173
+ "think_label": "思考",
174
+ "parallel_thinking_label": "并行思考",
175
+ "generate_btn": "🎵 生成音乐",
176
+ "autogen_label": "自动生成",
177
+ "caption_rewrite_label": "描述重写"
178
+ },
179
+ "results": {
180
+ "title": "🎵 结果",
181
+ "generated_music": "🎵 生成的音乐(样本 {n})",
182
+ "send_to_src_btn": "🔗 发送到源音频",
183
+ "save_btn": "💾 保存",
184
+ "score_btn": "📊 评分",
185
+ "lrc_btn": "🎵 LRC",
186
+ "quality_score_label": "质量分数(样本 {n})",
187
+ "quality_score_placeholder": "点击'评分'以计算基于困惑度的质量分数",
188
+ "codes_label": "LM 代码(样本 {n})",
189
+ "lrc_label": "歌词时间戳(样本 {n})",
190
+ "lrc_placeholder": "点击'LRC'生成时间戳",
191
+ "details_accordion": "📊 评分与LRC与LM代码",
192
+ "generation_status": "生成状态",
193
+ "current_batch": "当前批次",
194
+ "batch_indicator": "��次 {current} / {total}",
195
+ "next_batch_status": "下一批次状态",
196
+ "prev_btn": "◀ 上一个",
197
+ "next_btn": "下一个 ▶",
198
+ "restore_params_btn": "↙️ 将这些设置应用到UI(恢复批次参数)",
199
+ "batch_results_title": "👇 点击查看批量结果和生成详情",
200
+ "all_files_label": "📁 所有生成的文件(下载)",
201
+ "generation_details": "生成详情"
202
+ },
203
+ "messages": {
204
+ "no_audio_to_save": "❌ 没有要保存的音频",
205
+ "save_success": "✅ 已将音频和元数据保存到 {filename}",
206
+ "save_failed": "❌ 保存失败: {error}",
207
+ "no_file_selected": "⚠️ 未选择文件",
208
+ "params_loaded": "✅ 已从 {filename} 加载参数",
209
+ "invalid_json": "❌ 无效的JSON文件: {error}",
210
+ "load_error": "❌ 加载文件时出错: {error}",
211
+ "example_loaded": "📁 已从 {filename} 加载示例",
212
+ "example_failed": "解析JSON文件 {filename} 失败: {error}",
213
+ "example_error": "加载示例时出错: {error}",
214
+ "lm_generated": "🤖 使用LM生成的示例",
215
+ "lm_fallback": "使用LM生成示例失败,回退到示例目录",
216
+ "lm_not_initialized": "❌ 5Hz LM未初始化。请先初始化它。",
217
+ "autogen_enabled": "🔄 已启用自动生成 - 下一批次将在此之后生成",
218
+ "batch_ready": "✅ 批次 {n} 就绪!点击'下一个'查看。",
219
+ "batch_generating": "🔄 开始为批次 {n} 进行后台生成...",
220
+ "batch_failed": "❌ 后台生成失败: {error}",
221
+ "viewing_batch": "✅ 查看批次 {n}",
222
+ "at_first_batch": "已在第一批次",
223
+ "at_last_batch": "没有下一批次可用",
224
+ "batch_not_found": "在队列中未找到批次 {n}",
225
+ "no_batch_data": "没有要恢复的批次数据。",
226
+ "params_restored": "✅ 已从批次 {n} 恢复UI参数",
227
+ "scoring_failed": "❌ 错误: 未找到批次数据",
228
+ "no_codes": "❌ 没有可用的音频代码。请先生成音乐。",
229
+ "score_failed": "❌ 评分失败: {error}",
230
+ "score_error": "❌ 计算分数时出错: {error}",
231
+ "lrc_no_batch_data": "❌ 未找到批次数据。请先生成音乐。",
232
+ "lrc_no_extra_outputs": "❌ 未找到额外输出。条件张量不可用。",
233
+ "lrc_missing_tensors": "❌ 缺少LRC生成所需的张量。",
234
+ "lrc_sample_not_exist": "❌ 当前批次中不存在该样本。",
235
+ "lrc_empty_result": "⚠️ LRC生成结果为空。",
236
+ "empty_query": "⚠️ 请输入音乐描述。",
237
+ "sample_creation_failed": "❌ 创建样本失败。请重试。",
238
+ "sample_created": "✅ 样本已创建!检查描述和歌词,然后点击生成音乐。",
239
+ "simple_examples_not_found": "⚠️ 未找到简单模式示例目录。",
240
+ "simple_examples_empty": "⚠️ 简单模式示例中没有示例文件。",
241
+ "simple_example_loaded": "🎲 已从 {filename} 加载随机示例",
242
+ "format_success": "✅ 描述和歌词格式化成功",
243
+ "format_failed": "❌ 格式化失败: {error}",
244
+ "skipping_metas_cot": "⚡ 跳过 Phase 1 元数据 COT(样本已格式化)",
245
+ "invalid_timesteps_format": "⚠️ 时间步格式无效,使用默认调度。",
246
+ "timesteps_out_of_range": "⚠️ 时间步必须在 [0, 1] 范围内,使用默认调度。",
247
+ "timesteps_count_mismatch": "⚠️ 时间步数量 ({actual}) 与推理步数 ({expected}) 不匹配,将使用时间步数量。"
248
+ },
249
+ "training": {
250
+ "tab_title": "🎓 LoRA 训练",
251
+ "tab_dataset_builder": "📁 数据集构建",
252
+ "tab_train_lora": "🚀 训练 LoRA",
253
+ "quick_start_title": "🚀 快速开始",
254
+ "load_dataset_label": "数据集 JSON 路径",
255
+ "load_btn": "📂 加载",
256
+ "load_status": "加载状态",
257
+ "scan_label": "音频目录路径",
258
+ "scan_info": "扫描音频文件(wav、mp3、flac、ogg、opus)",
259
+ "scan_btn": "🔍 扫描",
260
+ "scan_status": "扫描状态",
261
+ "found_audio_files": "已找到的音频文件",
262
+ "dataset_name": "数据集名称",
263
+ "dataset_name_placeholder": "输入数据集名称",
264
+ "dataset_settings_header": "数据集设置",
265
+ "tag_prepend": "前置(标签,描述)",
266
+ "tag_append": "后置(描述,标签)",
267
+ "tag_replace": "替换描述",
268
+ "step2_title": "步骤 2:AI 自动标注",
269
+ "step3_title": "步骤 3:预览与编辑",
270
+ "step4_title": "步骤 4:保存数据集",
271
+ "step5_title": "步骤 5:预处理为张量",
272
+ "all_instrumental": "全部为纯音乐",
273
+ "all_instrumental_info": "勾选表示所有曲目均为纯音乐(无人声)",
274
+ "custom_tag": "自定义激活标签",
275
+ "custom_tag_info": "用于激活此 LoRA 风格的唯一标签",
276
+ "tag_position": "标签位置",
277
+ "tag_position_info": "在描述中放置自定义标签的位置",
278
+ "genre_ratio": "风格比例 (%)",
279
+ "genre_ratio_info": "0%=全部描述,100%=全部风格。单样本覆盖优先。",
280
+ "skip_metas": "跳过 BPM/调性/拍号",
281
+ "skip_metas_info": "跳过 BPM/调性/拍号生成。描述和风格仍由 LM 生成。",
282
+ "only_unlabeled": "仅未标注",
283
+ "only_unlabeled_info": "仅标注无描述的样本(用于继续失败的标注)",
284
+ "auto_label_btn": "🏷️ 自动标注全部",
285
+ "label_progress": "标注进度",
286
+ "select_sample": "选择样本 #",
287
+ "select_sample_info": "选择要预览和编辑的样本",
288
+ "audio_preview": "音频预览",
289
+ "filename": "文件名",
290
+ "caption": "描述",
291
+ "genre": "风格",
292
+ "prompt_override_label": "提示覆盖(本样本)",
293
+ "prompt_override_info": "覆盖本样本的全局比例",
294
+ "lyrics_editable_label": "歌词(可编辑,用于训练)",
295
+ "raw_lyrics_label": "原始歌词(来自 .txt 文件)",
296
+ "no_lyrics_placeholder": "(无 .txt 歌词文件)",
297
+ "bpm": "BPM",
298
+ "key_label": "调性",
299
+ "key_placeholder": "C 大调",
300
+ "time_sig": "拍号",
301
+ "duration_s": "时长 (秒)",
302
+ "language": "语言",
303
+ "instrumental": "纯音乐",
304
+ "save_changes_btn": "💾 保存更改",
305
+ "edit_status": "编辑状态",
306
+ "save_path": "保存路径",
307
+ "save_path_info": "数据集 JSON 的保存路径",
308
+ "save_dataset_btn": "💾 保存数据集",
309
+ "save_status": "保存状态",
310
+ "load_existing_label": "加载已有数据集(可选)",
311
+ "load_existing_info": "之前保存的数据集 JSON 文件路径",
312
+ "load_dataset_btn": "📂 加载数据集",
313
+ "tensor_output_dir": "张量输出目录",
314
+ "tensor_output_info": "保存预处理张量文件的目录",
315
+ "preprocess_btn": "⚡ 预处理",
316
+ "preprocess_progress": "预处理进度",
317
+ "preprocessed_tensors_dir": "预处理张量目录",
318
+ "preprocessed_tensors_info": "包含预处理 .pt 张量文件的目录",
319
+ "dataset_info": "数据集信息",
320
+ "lora_rank": "LoRA 秩 (r)",
321
+ "lora_rank_info": "越高容量越大,显存占用越多",
322
+ "lora_alpha": "LoRA Alpha",
323
+ "lora_alpha_info": "缩放因子(通常为 2× 秩)",
324
+ "lora_dropout": "LoRA Dropout",
325
+ "learning_rate": "学习率",
326
+ "learning_rate_info": "建议从 3e-4 开始,按需调整",
327
+ "max_epochs": "最大轮数",
328
+ "batch_size": "批大小",
329
+ "batch_size_info": "显存充足时可增大",
330
+ "gradient_accumulation": "梯度累积",
331
+ "gradient_accumulation_info": "有效批大小 = batch_size × 累积步数",
332
+ "save_every_n_epochs": "每 N 轮保存",
333
+ "shift": "Shift",
334
+ "shift_info": "Turbo 模型时间步偏移",
335
+ "seed": "随机种子",
336
+ "output_dir": "输出目录",
337
+ "output_dir_info": "保存训练后 LoRA 权重的目录",
338
+ "start_training_btn": "🚀 开始训练",
339
+ "stop_training_btn": "⏹️ 停止训练",
340
+ "training_progress": "训练进度",
341
+ "training_log": "训练日志",
342
+ "training_loss_title": "训练损失",
343
+ "step": "步数",
344
+ "loss": "损失",
345
+ "export_header": "导出 LoRA",
346
+ "export_path": "导出路径",
347
+ "export_lora_btn": "📦 导出 LoRA",
348
+ "export_status": "导出状态"
349
+ }
350
+ }
acestep/gradio_ui/interfaces/__init__.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Components Module
3
+ Contains all Gradio interface component definitions and layouts
4
+ """
5
+ import gradio as gr
6
+ from acestep.gradio_ui.i18n import get_i18n, t
7
+ from acestep.gradio_ui.interfaces.dataset import create_dataset_section
8
+ from acestep.gradio_ui.interfaces.generation import create_generation_section
9
+ from acestep.gradio_ui.interfaces.result import create_results_section
10
+ from acestep.gradio_ui.interfaces.training import create_training_section
11
+ from acestep.gradio_ui.events import setup_event_handlers, setup_training_event_handlers
12
+
13
+
14
+ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=None, language='en') -> gr.Blocks:
15
+ """
16
+ Create Gradio interface
17
+
18
+ Args:
19
+ dit_handler: DiT handler instance
20
+ llm_handler: LM handler instance
21
+ dataset_handler: Dataset handler instance
22
+ init_params: Dictionary containing initialization parameters and state.
23
+ If None, service will not be pre-initialized.
24
+ language: UI language code ('en', 'zh', 'ja', default: 'en')
25
+
26
+ Returns:
27
+ Gradio Blocks instance
28
+ """
29
+ # Initialize i18n with selected language
30
+ i18n = get_i18n(language)
31
+
32
+ with gr.Blocks(
33
+ title=t("app.title"),
34
+ theme=gr.themes.Soft(),
35
+ css="""
36
+ .main-header {
37
+ text-align: center;
38
+ margin-bottom: 2rem;
39
+ }
40
+ .section-header {
41
+ background: linear-gradient(90deg, #4CAF50, #45a049);
42
+ color: white;
43
+ padding: 10px;
44
+ border-radius: 5px;
45
+ margin: 10px 0;
46
+ }
47
+ .lm-hints-row {
48
+ align-items: stretch;
49
+ }
50
+ .lm-hints-col {
51
+ display: flex;
52
+ }
53
+ .lm-hints-col > div {
54
+ flex: 1;
55
+ display: flex;
56
+ }
57
+ .lm-hints-btn button {
58
+ height: 100%;
59
+ width: 100%;
60
+ }
61
+ /* Position Audio time labels lower to avoid scrollbar overlap */
62
+ .component-wrapper > .timestamps {
63
+ transform: translateY(15px);
64
+ }
65
+ """,
66
+ ) as demo:
67
+
68
+ gr.HTML(f"""
69
+ <div class="main-header">
70
+ <h1>{t("app.title")}</h1>
71
+ <p>{t("app.subtitle")}</p>
72
+ </div>
73
+ """)
74
+
75
+ # Dataset Explorer Section
76
+ dataset_section = create_dataset_section(dataset_handler)
77
+
78
+ # Generation Section (pass init_params and language to support pre-initialization)
79
+ generation_section = create_generation_section(dit_handler, llm_handler, init_params=init_params, language=language)
80
+
81
+ # Results Section
82
+ results_section = create_results_section(dit_handler)
83
+
84
+ # Training Section (LoRA training and dataset builder)
85
+ # Pass init_params to support hiding in service mode
86
+ training_section = create_training_section(dit_handler, llm_handler, init_params=init_params)
87
+
88
+ # Connect event handlers
89
+ setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section)
90
+
91
+ # Connect training event handlers
92
+ setup_training_event_handlers(demo, dit_handler, llm_handler, training_section)
93
+
94
+ return demo
acestep/gradio_ui/interfaces/dataset.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Dataset Section Module
3
+ Contains dataset explorer section component definitions
4
+ """
5
+ import gradio as gr
6
+
7
+
8
+ def create_dataset_section(dataset_handler) -> dict:
9
+ """Create dataset explorer section"""
10
+ with gr.Accordion("📊 Dataset Explorer", open=False, visible=False):
11
+ with gr.Row(equal_height=True):
12
+ dataset_type = gr.Dropdown(
13
+ choices=["train", "test"],
14
+ value="train",
15
+ label="Dataset",
16
+ info="Choose dataset to explore",
17
+ scale=2
18
+ )
19
+ import_dataset_btn = gr.Button("📥 Import Dataset", variant="primary", scale=1)
20
+
21
+ search_type = gr.Dropdown(
22
+ choices=["keys", "idx", "random"],
23
+ value="random",
24
+ label="Search Type",
25
+ info="How to find items",
26
+ scale=1
27
+ )
28
+ search_value = gr.Textbox(
29
+ label="Search Value",
30
+ placeholder="Enter keys or index (leave empty for random)",
31
+ info="Keys: exact match, Index: 0 to dataset size-1",
32
+ scale=2
33
+ )
34
+
35
+ instruction_display = gr.Textbox(
36
+ label="📝 Instruction",
37
+ interactive=False,
38
+ placeholder="No instruction available",
39
+ lines=1
40
+ )
41
+
42
+ repaint_viz_plot = gr.Plot()
43
+
44
+ with gr.Accordion("📋 Item Metadata (JSON)", open=False):
45
+ item_info_json = gr.Code(
46
+ label="Complete Item Information",
47
+ language="json",
48
+ interactive=False,
49
+ lines=15
50
+ )
51
+
52
+ with gr.Row(equal_height=True):
53
+ item_src_audio = gr.Audio(
54
+ label="Source Audio",
55
+ type="filepath",
56
+ interactive=False,
57
+ scale=8
58
+ )
59
+ get_item_btn = gr.Button("🔍 Get Item", variant="secondary", interactive=False, scale=2)
60
+
61
+ with gr.Row(equal_height=True):
62
+ item_target_audio = gr.Audio(
63
+ label="Target Audio",
64
+ type="filepath",
65
+ interactive=False,
66
+ scale=8
67
+ )
68
+ item_refer_audio = gr.Audio(
69
+ label="Reference Audio",
70
+ type="filepath",
71
+ interactive=False,
72
+ scale=2
73
+ )
74
+
75
+ with gr.Row():
76
+ use_src_checkbox = gr.Checkbox(
77
+ label="Use Source Audio from Dataset",
78
+ value=True,
79
+ info="Check to use the source audio from dataset"
80
+ )
81
+
82
+ data_status = gr.Textbox(label="📊 Data Status", interactive=False, value="❌ No dataset imported")
83
+ auto_fill_btn = gr.Button("📋 Auto-fill Generation Form", variant="primary")
84
+
85
+ return {
86
+ "dataset_type": dataset_type,
87
+ "import_dataset_btn": import_dataset_btn,
88
+ "search_type": search_type,
89
+ "search_value": search_value,
90
+ "instruction_display": instruction_display,
91
+ "repaint_viz_plot": repaint_viz_plot,
92
+ "item_info_json": item_info_json,
93
+ "item_src_audio": item_src_audio,
94
+ "get_item_btn": get_item_btn,
95
+ "item_target_audio": item_target_audio,
96
+ "item_refer_audio": item_refer_audio,
97
+ "use_src_checkbox": use_src_checkbox,
98
+ "data_status": data_status,
99
+ "auto_fill_btn": auto_fill_btn,
100
+ }
101
+
acestep/gradio_ui/interfaces/generation.py ADDED
@@ -0,0 +1,824 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Generation Section Module
3
+ Contains generation section component definitions
4
+ """
5
+ import sys
6
+ import gradio as gr
7
+ from acestep.constants import (
8
+ VALID_LANGUAGES,
9
+ TRACK_NAMES,
10
+ TASK_TYPES_TURBO,
11
+ TASK_TYPES_BASE,
12
+ DEFAULT_DIT_INSTRUCTION,
13
+ )
14
+ from acestep.gradio_ui.i18n import t
15
+ from acestep.gpu_config import get_global_gpu_config, GPUConfig
16
+
17
+
18
+ def create_generation_section(dit_handler, llm_handler, init_params=None, language='en') -> dict:
19
+ """Create generation section
20
+
21
+ Args:
22
+ dit_handler: DiT handler instance
23
+ llm_handler: LM handler instance
24
+ init_params: Dictionary containing initialization parameters and state.
25
+ If None, service will not be pre-initialized.
26
+ language: UI language code ('en', 'zh', 'ja')
27
+ """
28
+ # Check if service is pre-initialized
29
+ service_pre_initialized = init_params is not None and init_params.get('pre_initialized', False)
30
+
31
+ # Check if running in service mode (restricted UI)
32
+ service_mode = init_params is not None and init_params.get('service_mode', False)
33
+
34
+ # Get current language from init_params if available
35
+ current_language = init_params.get('language', language) if init_params else language
36
+
37
+ # Get GPU configuration
38
+ gpu_config: GPUConfig = init_params.get('gpu_config') if init_params else None
39
+ if gpu_config is None:
40
+ gpu_config = get_global_gpu_config()
41
+
42
+ # Determine if LM is initialized (for setting appropriate limits)
43
+ lm_initialized = init_params.get('init_llm', False) if init_params else False
44
+
45
+ # Calculate UI limits based on GPU config and LM state
46
+ max_duration = gpu_config.max_duration_with_lm if lm_initialized else gpu_config.max_duration_without_lm
47
+ max_batch_size = gpu_config.max_batch_size_with_lm if lm_initialized else gpu_config.max_batch_size_without_lm
48
+ default_batch_size = min(2, max_batch_size) # Default to 2 or max if lower
49
+ init_lm_default = gpu_config.init_lm_default
50
+
51
+ # Determine default offload setting
52
+ # If XPU is detected, default offload to False (keep models on device)
53
+ # Otherwise default to True (offload to CPU to save VRAM)
54
+ default_offload = True
55
+ try:
56
+ import torch
57
+ if hasattr(torch, 'xpu') and torch.xpu.is_available():
58
+ default_offload = False
59
+ except ImportError:
60
+ pass
61
+
62
+ with gr.Group():
63
+ # Service Configuration - collapse if pre-initialized, hide if in service mode
64
+ accordion_open = not service_pre_initialized
65
+ accordion_visible = not service_pre_initialized # Hide when running in service mode
66
+ with gr.Accordion(t("service.title"), open=accordion_open, visible=accordion_visible) as service_config_accordion:
67
+ # Language selector at the top
68
+ with gr.Row():
69
+ language_dropdown = gr.Dropdown(
70
+ choices=[
71
+ ("English", "en"),
72
+ ("中文", "zh"),
73
+ ("日本語", "ja"),
74
+ ],
75
+ value=current_language,
76
+ label=t("service.language_label"),
77
+ info=t("service.language_info"),
78
+ scale=1,
79
+ )
80
+
81
+ # Dropdown options section - all dropdowns grouped together
82
+ with gr.Row(equal_height=True):
83
+ with gr.Column(scale=4):
84
+ # Set checkpoint value from init_params if pre-initialized
85
+ checkpoint_value = init_params.get('checkpoint') if service_pre_initialized else None
86
+ checkpoint_dropdown = gr.Dropdown(
87
+ label=t("service.checkpoint_label"),
88
+ choices=dit_handler.get_available_checkpoints(),
89
+ value=checkpoint_value,
90
+ info=t("service.checkpoint_info")
91
+ )
92
+ with gr.Column(scale=1, min_width=90):
93
+ refresh_btn = gr.Button(t("service.refresh_btn"), size="sm")
94
+
95
+ with gr.Row():
96
+ # Get available acestep-v15- model list
97
+ available_models = dit_handler.get_available_acestep_v15_models()
98
+ default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)
99
+
100
+ # Set config_path value from init_params if pre-initialized
101
+ config_path_value = init_params.get('config_path', default_model) if service_pre_initialized else default_model
102
+ config_path = gr.Dropdown(
103
+ label=t("service.model_path_label"),
104
+ choices=available_models,
105
+ value=config_path_value,
106
+ info=t("service.model_path_info")
107
+ )
108
+ # Set device value from init_params if pre-initialized
109
+ device_value = init_params.get('device', 'auto') if service_pre_initialized else 'auto'
110
+ device = gr.Dropdown(
111
+ choices=["auto", "cuda", "mps", "xpu", "cpu"],
112
+ value=device_value,
113
+ label=t("service.device_label"),
114
+ info=t("service.device_info")
115
+ )
116
+
117
+ with gr.Row():
118
+ # Get available 5Hz LM model list
119
+ available_lm_models = llm_handler.get_available_5hz_lm_models()
120
+ default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)
121
+
122
+ # Set lm_model_path value from init_params if pre-initialized
123
+ lm_model_path_value = init_params.get('lm_model_path', default_lm_model) if service_pre_initialized else default_lm_model
124
+ lm_model_path = gr.Dropdown(
125
+ label=t("service.lm_model_path_label"),
126
+ choices=available_lm_models,
127
+ value=lm_model_path_value,
128
+ info=t("service.lm_model_path_info")
129
+ )
130
+ # Set backend value from init_params if pre-initialized
131
+ backend_value = init_params.get('backend', 'vllm') if service_pre_initialized else 'vllm'
132
+ backend_dropdown = gr.Dropdown(
133
+ choices=["vllm", "pt", "mlx"],
134
+ value=backend_value,
135
+ label=t("service.backend_label"),
136
+ info=t("service.backend_info")
137
+ )
138
+
139
+ # Checkbox options section - all checkboxes grouped together
140
+ with gr.Row():
141
+ # Set init_llm value from init_params if pre-initialized, otherwise use GPU config default
142
+ init_llm_value = init_params.get('init_llm', init_lm_default) if service_pre_initialized else init_lm_default
143
+ init_llm_checkbox = gr.Checkbox(
144
+ label=t("service.init_llm_label"),
145
+ value=init_llm_value,
146
+ info=t("service.init_llm_info"),
147
+ )
148
+ # Auto-detect flash attention availability
149
+ flash_attn_available = dit_handler.is_flash_attention_available(device_value)
150
+ # Set use_flash_attention value from init_params if pre-initialized
151
+ use_flash_attention_value = init_params.get('use_flash_attention', flash_attn_available) if service_pre_initialized else flash_attn_available
152
+ use_flash_attention_checkbox = gr.Checkbox(
153
+ label=t("service.flash_attention_label"),
154
+ value=use_flash_attention_value,
155
+ interactive=flash_attn_available,
156
+ info=t("service.flash_attention_info_enabled") if flash_attn_available else t("service.flash_attention_info_disabled")
157
+ )
158
+ # Set offload_to_cpu value from init_params if pre-initialized (default True)
159
+ offload_to_cpu_value = init_params.get('offload_to_cpu', default_offload) if service_pre_initialized else default_offload
160
+ offload_to_cpu_checkbox = gr.Checkbox(
161
+ label=t("service.offload_cpu_label"),
162
+ value=offload_to_cpu_value,
163
+ info=t("service.offload_cpu_info")
164
+ )
165
+ # Set offload_dit_to_cpu value from init_params if pre-initialized (default True)
166
+ offload_dit_to_cpu_value = init_params.get('offload_dit_to_cpu', default_offload) if service_pre_initialized else default_offload
167
+ offload_dit_to_cpu_checkbox = gr.Checkbox(
168
+ label=t("service.offload_dit_cpu_label"),
169
+ value=offload_dit_to_cpu_value,
170
+ info=t("service.offload_dit_cpu_info")
171
+ )
172
+ # Set compile_model value from init_params if pre-initialized (default True)
173
+ compile_model_value = init_params.get('compile_model', True) if service_pre_initialized else True
174
+ compile_model_checkbox = gr.Checkbox(
175
+ label=t("service.compile_model_label"),
176
+ value=compile_model_value,
177
+ info=t("service.compile_model_info")
178
+ )
179
+ # Set quantization value from init_params if pre-initialized.
180
+ # Default to False on macOS to avoid torchao incompatibilities.
181
+ default_quantization = False if sys.platform == "darwin" else True
182
+ quantization_value = init_params.get('quantization', default_quantization) if service_pre_initialized else default_quantization
183
+ quantization_checkbox = gr.Checkbox(
184
+ label=t("service.quantization_label"),
185
+ value=quantization_value,
186
+ info=t("service.quantization_info")
187
+ )
188
+
189
+ init_btn = gr.Button(t("service.init_btn"), variant="primary", size="lg")
190
+ # Set init_status value from init_params if pre-initialized
191
+ init_status_value = init_params.get('init_status', '') if service_pre_initialized else ''
192
+ init_status = gr.Textbox(label=t("service.status_label"), interactive=False, lines=3, value=init_status_value)
193
+
194
+ # LoRA Configuration Section
195
+ gr.HTML("<hr><h4>🔧 LoRA Adapter</h4>")
196
+ with gr.Row():
197
+ lora_path = gr.Textbox(
198
+ label="LoRA Path",
199
+ placeholder="./lora_output/final/adapter",
200
+ info="Path to trained LoRA adapter directory",
201
+ scale=3,
202
+ )
203
+ load_lora_btn = gr.Button("📥 Load LoRA", variant="secondary", scale=1)
204
+ unload_lora_btn = gr.Button("🗑️ Unload", variant="secondary", scale=1)
205
+ with gr.Row():
206
+ use_lora_checkbox = gr.Checkbox(
207
+ label="Use LoRA",
208
+ value=False,
209
+ info="Enable LoRA adapter for inference",
210
+ scale=1,
211
+ )
212
+ lora_scale_slider = gr.Slider(
213
+ minimum=0.0,
214
+ maximum=1.0,
215
+ value=1.0,
216
+ step=0.05,
217
+ label="LoRA Scale",
218
+ info="LoRA influence strength (0=disabled, 1=full)",
219
+ scale=2,
220
+ )
221
+ lora_status = gr.Textbox(
222
+ label="LoRA Status",
223
+ value="No LoRA loaded",
224
+ interactive=False,
225
+ scale=2,
226
+ )
227
+
228
+ # Inputs
229
+ with gr.Row():
230
+ with gr.Column(scale=2):
231
+ with gr.Accordion(t("generation.required_inputs"), open=True):
232
+ # Task type
233
+ # Determine initial task_type choices based on actual model in use
234
+ # When service is pre-initialized, use config_path from init_params
235
+ actual_model = init_params.get('config_path', default_model) if service_pre_initialized else default_model
236
+ actual_model_lower = (actual_model or "").lower()
237
+ if "turbo" in actual_model_lower:
238
+ initial_task_choices = TASK_TYPES_TURBO
239
+ else:
240
+ initial_task_choices = TASK_TYPES_BASE
241
+
242
+ with gr.Row(equal_height=True):
243
+ with gr.Column(scale=2):
244
+ task_type = gr.Dropdown(
245
+ choices=initial_task_choices,
246
+ value="text2music",
247
+ label=t("generation.task_type_label"),
248
+ info=t("generation.task_type_info"),
249
+ )
250
+ with gr.Column(scale=7):
251
+ instruction_display_gen = gr.Textbox(
252
+ label=t("generation.instruction_label"),
253
+ value=DEFAULT_DIT_INSTRUCTION,
254
+ interactive=False,
255
+ lines=1,
256
+ info=t("generation.instruction_info"),
257
+ )
258
+ with gr.Column(scale=1, min_width=100):
259
+ load_file = gr.UploadButton(
260
+ t("generation.load_btn"),
261
+ file_types=[".json"],
262
+ file_count="single",
263
+ variant="secondary",
264
+ size="sm",
265
+ )
266
+
267
+ track_name = gr.Dropdown(
268
+ choices=TRACK_NAMES,
269
+ value=None,
270
+ label=t("generation.track_name_label"),
271
+ info=t("generation.track_name_info"),
272
+ visible=False
273
+ )
274
+
275
+ complete_track_classes = gr.CheckboxGroup(
276
+ choices=TRACK_NAMES,
277
+ label=t("generation.track_classes_label"),
278
+ info=t("generation.track_classes_info"),
279
+ visible=False
280
+ )
281
+
282
+ # Audio uploads
283
+ audio_uploads_accordion = gr.Accordion(t("generation.audio_uploads"), open=False)
284
+ with audio_uploads_accordion:
285
+ with gr.Row(equal_height=True):
286
+ with gr.Column(scale=2):
287
+ reference_audio = gr.Audio(
288
+ label=t("generation.reference_audio"),
289
+ type="filepath",
290
+ )
291
+ with gr.Column(scale=7):
292
+ src_audio = gr.Audio(
293
+ label=t("generation.source_audio"),
294
+ type="filepath",
295
+ )
296
+ with gr.Column(scale=1, min_width=80):
297
+ convert_src_to_codes_btn = gr.Button(
298
+ t("generation.convert_codes_btn"),
299
+ variant="secondary",
300
+ size="sm"
301
+ )
302
+
303
+ # Audio Codes for text2music - single input for transcription or cover task
304
+ with gr.Accordion(t("generation.lm_codes_hints"), open=False, visible=True) as text2music_audio_codes_group:
305
+ with gr.Row(equal_height=True):
306
+ text2music_audio_code_string = gr.Textbox(
307
+ label=t("generation.lm_codes_label"),
308
+ placeholder=t("generation.lm_codes_placeholder"),
309
+ lines=6,
310
+ info=t("generation.lm_codes_info"),
311
+ scale=9,
312
+ )
313
+ transcribe_btn = gr.Button(
314
+ t("generation.transcribe_btn"),
315
+ variant="secondary",
316
+ size="sm",
317
+ scale=1,
318
+ )
319
+
320
+ # Repainting controls
321
+ with gr.Group(visible=False) as repainting_group:
322
+ gr.HTML(f"<h5>{t('generation.repainting_controls')}</h5>")
323
+ with gr.Row():
324
+ repainting_start = gr.Number(
325
+ label=t("generation.repainting_start"),
326
+ value=0.0,
327
+ step=0.1,
328
+ )
329
+ repainting_end = gr.Number(
330
+ label=t("generation.repainting_end"),
331
+ value=-1,
332
+ minimum=-1,
333
+ step=0.1,
334
+ )
335
+
336
+ # Simple/Custom Mode Toggle
337
+ # In service mode: only Custom mode, hide the toggle
338
+ with gr.Row(visible=not service_mode):
339
+ generation_mode = gr.Radio(
340
+ choices=[
341
+ (t("generation.mode_simple"), "simple"),
342
+ (t("generation.mode_custom"), "custom"),
343
+ ],
344
+ value="custom" if service_mode else "simple",
345
+ label=t("generation.mode_label"),
346
+ info=t("generation.mode_info"),
347
+ )
348
+
349
+ # Simple Mode Components - hidden in service mode
350
+ with gr.Group(visible=not service_mode) as simple_mode_group:
351
+ with gr.Row(equal_height=True):
352
+ simple_query_input = gr.Textbox(
353
+ label=t("generation.simple_query_label"),
354
+ placeholder=t("generation.simple_query_placeholder"),
355
+ lines=2,
356
+ info=t("generation.simple_query_info"),
357
+ scale=12,
358
+ )
359
+
360
+ with gr.Column(scale=1, min_width=100):
361
+ random_desc_btn = gr.Button(
362
+ "🎲",
363
+ variant="secondary",
364
+ size="sm",
365
+ scale=2
366
+ )
367
+
368
+ with gr.Row(equal_height=True):
369
+ with gr.Column(scale=1, variant="compact"):
370
+ simple_instrumental_checkbox = gr.Checkbox(
371
+ label=t("generation.instrumental_label"),
372
+ value=False,
373
+ )
374
+ with gr.Column(scale=18):
375
+ create_sample_btn = gr.Button(
376
+ t("generation.create_sample_btn"),
377
+ variant="primary",
378
+ size="lg",
379
+ )
380
+ with gr.Column(scale=1, variant="compact"):
381
+ simple_vocal_language = gr.Dropdown(
382
+ choices=VALID_LANGUAGES,
383
+ value="unknown",
384
+ allow_custom_value=True,
385
+ label=t("generation.simple_vocal_language_label"),
386
+ interactive=True,
387
+ )
388
+
389
+ # State to track if sample has been created in Simple mode
390
+ simple_sample_created = gr.State(value=False)
391
+
392
+ # Music Caption - wrapped in accordion that can be collapsed in Simple mode
393
+ # Default to expanded for better UX
394
+ with gr.Accordion(t("generation.caption_title"), open=True) as caption_accordion:
395
+ with gr.Row(equal_height=True):
396
+ captions = gr.Textbox(
397
+ label=t("generation.caption_label"),
398
+ placeholder=t("generation.caption_placeholder"),
399
+ lines=3,
400
+ info=t("generation.caption_info"),
401
+ scale=12,
402
+ )
403
+ with gr.Column(scale=1, min_width=100):
404
+ sample_btn = gr.Button(
405
+ "🎲",
406
+ variant="secondary",
407
+ size="sm",
408
+ scale=2,
409
+ )
410
+ # Lyrics - wrapped in accordion that can be collapsed in Simple mode
411
+ # Default to expanded for better UX
412
+ with gr.Accordion(t("generation.lyrics_title"), open=True) as lyrics_accordion:
413
+ lyrics = gr.Textbox(
414
+ label=t("generation.lyrics_label"),
415
+ placeholder=t("generation.lyrics_placeholder"),
416
+ lines=8,
417
+ info=t("generation.lyrics_info")
418
+ )
419
+
420
+ with gr.Row(variant="compact", equal_height=True):
421
+ instrumental_checkbox = gr.Checkbox(
422
+ label=t("generation.instrumental_label"),
423
+ value=False,
424
+ scale=1,
425
+ min_width=120,
426
+ container=True,
427
+ )
428
+
429
+ # 中间:语言选择 (Dropdown)
430
+ # 移除 gr.HTML hack,直接使用 label 参数,Gradio 会自动处理对齐
431
+ vocal_language = gr.Dropdown(
432
+ choices=VALID_LANGUAGES,
433
+ value="unknown",
434
+ label=t("generation.vocal_language_label"),
435
+ show_label=False,
436
+ container=True,
437
+ allow_custom_value=True,
438
+ scale=3,
439
+ )
440
+
441
+ # 右侧:格式化按钮 (Button)
442
+ # 放在同一行最右侧,操作更顺手
443
+ format_btn = gr.Button(
444
+ t("generation.format_btn"),
445
+ variant="secondary",
446
+ scale=1,
447
+ min_width=80,
448
+ )
449
+
450
+ # Optional Parameters
451
+ # In service mode: auto-expand
452
+ with gr.Accordion(t("generation.optional_params"), open=service_mode) as optional_params_accordion:
453
+ with gr.Row():
454
+ bpm = gr.Number(
455
+ label=t("generation.bpm_label"),
456
+ value=None,
457
+ step=1,
458
+ info=t("generation.bpm_info")
459
+ )
460
+ key_scale = gr.Textbox(
461
+ label=t("generation.keyscale_label"),
462
+ placeholder=t("generation.keyscale_placeholder"),
463
+ value="",
464
+ info=t("generation.keyscale_info")
465
+ )
466
+ time_signature = gr.Dropdown(
467
+ choices=["", "2", "3", "4", "6", "N/A"],
468
+ value="",
469
+ label=t("generation.timesig_label"),
470
+ allow_custom_value=True,
471
+ info=t("generation.timesig_info")
472
+ )
473
+ audio_duration = gr.Number(
474
+ label=t("generation.duration_label"),
475
+ value=-1,
476
+ minimum=-1,
477
+ maximum=float(max_duration),
478
+ step=0.1,
479
+ info=t("generation.duration_info") + f" (Max: {max_duration}s / {max_duration // 60} min)"
480
+ )
481
+ batch_size_input = gr.Number(
482
+ label=t("generation.batch_size_label"),
483
+ value=default_batch_size,
484
+ minimum=1,
485
+ maximum=max_batch_size,
486
+ step=1,
487
+ info=t("generation.batch_size_info") + f" (Max: {max_batch_size})",
488
+ interactive=not service_mode # Fixed in service mode
489
+ )
490
+
491
+ # Advanced Settings
492
+ # Default UI settings use turbo mode (max 20 steps, default 8, show shift with default 3)
493
+ # These will be updated after model initialization based on handler.is_turbo_model()
494
+ with gr.Accordion(t("generation.advanced_settings"), open=False):
495
+ with gr.Row():
496
+ inference_steps = gr.Slider(
497
+ minimum=1,
498
+ maximum=20,
499
+ value=8,
500
+ step=1,
501
+ label=t("generation.inference_steps_label"),
502
+ info=t("generation.inference_steps_info")
503
+ )
504
+ guidance_scale = gr.Slider(
505
+ minimum=1.0,
506
+ maximum=15.0,
507
+ value=7.0,
508
+ step=0.1,
509
+ label=t("generation.guidance_scale_label"),
510
+ info=t("generation.guidance_scale_info"),
511
+ visible=False
512
+ )
513
+ with gr.Column():
514
+ seed = gr.Textbox(
515
+ label=t("generation.seed_label"),
516
+ value="-1",
517
+ info=t("generation.seed_info")
518
+ )
519
+ random_seed_checkbox = gr.Checkbox(
520
+ label=t("generation.random_seed_label"),
521
+ value=True,
522
+ info=t("generation.random_seed_info")
523
+ )
524
+ audio_format = gr.Dropdown(
525
+ choices=["mp3", "flac"],
526
+ value="mp3",
527
+ label=t("generation.audio_format_label"),
528
+ info=t("generation.audio_format_info"),
529
+ interactive=not service_mode # Fixed in service mode
530
+ )
531
+
532
+ with gr.Row():
533
+ use_adg = gr.Checkbox(
534
+ label=t("generation.use_adg_label"),
535
+ value=False,
536
+ info=t("generation.use_adg_info"),
537
+ visible=False
538
+ )
539
+ shift = gr.Slider(
540
+ minimum=1.0,
541
+ maximum=5.0,
542
+ value=3.0,
543
+ step=0.1,
544
+ label=t("generation.shift_label"),
545
+ info=t("generation.shift_info"),
546
+ visible=True
547
+ )
548
+ infer_method = gr.Dropdown(
549
+ choices=["ode", "sde"],
550
+ value="ode",
551
+ label=t("generation.infer_method_label"),
552
+ info=t("generation.infer_method_info"),
553
+ )
554
+
555
+ with gr.Row():
556
+ custom_timesteps = gr.Textbox(
557
+ label=t("generation.custom_timesteps_label"),
558
+ placeholder="0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0",
559
+ value="",
560
+ info=t("generation.custom_timesteps_info"),
561
+ )
562
+
563
+ with gr.Row():
564
+ cfg_interval_start = gr.Slider(
565
+ minimum=0.0,
566
+ maximum=1.0,
567
+ value=0.0,
568
+ step=0.01,
569
+ label=t("generation.cfg_interval_start"),
570
+ visible=False
571
+ )
572
+ cfg_interval_end = gr.Slider(
573
+ minimum=0.0,
574
+ maximum=1.0,
575
+ value=1.0,
576
+ step=0.01,
577
+ label=t("generation.cfg_interval_end"),
578
+ visible=False
579
+ )
580
+
581
+ # LM (Language Model) Parameters
582
+ gr.HTML(f"<h4>{t('generation.lm_params_title')}</h4>")
583
+ with gr.Row():
584
+ lm_temperature = gr.Slider(
585
+ label=t("generation.lm_temperature_label"),
586
+ minimum=0.0,
587
+ maximum=2.0,
588
+ value=0.85,
589
+ step=0.1,
590
+ scale=1,
591
+ info=t("generation.lm_temperature_info")
592
+ )
593
+ lm_cfg_scale = gr.Slider(
594
+ label=t("generation.lm_cfg_scale_label"),
595
+ minimum=1.0,
596
+ maximum=3.0,
597
+ value=2.0,
598
+ step=0.1,
599
+ scale=1,
600
+ info=t("generation.lm_cfg_scale_info")
601
+ )
602
+ lm_top_k = gr.Slider(
603
+ label=t("generation.lm_top_k_label"),
604
+ minimum=0,
605
+ maximum=100,
606
+ value=0,
607
+ step=1,
608
+ scale=1,
609
+ info=t("generation.lm_top_k_info")
610
+ )
611
+ lm_top_p = gr.Slider(
612
+ label=t("generation.lm_top_p_label"),
613
+ minimum=0.0,
614
+ maximum=1.0,
615
+ value=0.9,
616
+ step=0.01,
617
+ scale=1,
618
+ info=t("generation.lm_top_p_info")
619
+ )
620
+
621
+ with gr.Row():
622
+ lm_negative_prompt = gr.Textbox(
623
+ label=t("generation.lm_negative_prompt_label"),
624
+ value="NO USER INPUT",
625
+ placeholder=t("generation.lm_negative_prompt_placeholder"),
626
+ info=t("generation.lm_negative_prompt_info"),
627
+ lines=2,
628
+ scale=2,
629
+ )
630
+
631
+ with gr.Row():
632
+ use_cot_metas = gr.Checkbox(
633
+ label=t("generation.cot_metas_label"),
634
+ value=True,
635
+ info=t("generation.cot_metas_info"),
636
+ scale=1,
637
+ )
638
+ use_cot_language = gr.Checkbox(
639
+ label=t("generation.cot_language_label"),
640
+ value=True,
641
+ info=t("generation.cot_language_info"),
642
+ scale=1,
643
+ )
644
+ constrained_decoding_debug = gr.Checkbox(
645
+ label=t("generation.constrained_debug_label"),
646
+ value=False,
647
+ info=t("generation.constrained_debug_info"),
648
+ scale=1,
649
+ interactive=not service_mode # Fixed in service mode
650
+ )
651
+
652
+ with gr.Row():
653
+ auto_score = gr.Checkbox(
654
+ label=t("generation.auto_score_label"),
655
+ value=False,
656
+ info=t("generation.auto_score_info"),
657
+ scale=1,
658
+ interactive=not service_mode # Fixed in service mode
659
+ )
660
+ auto_lrc = gr.Checkbox(
661
+ label=t("generation.auto_lrc_label"),
662
+ value=False,
663
+ info=t("generation.auto_lrc_info"),
664
+ scale=1,
665
+ interactive=not service_mode # Fixed in service mode
666
+ )
667
+ lm_batch_chunk_size = gr.Number(
668
+ label=t("generation.lm_batch_chunk_label"),
669
+ value=8,
670
+ minimum=1,
671
+ maximum=32,
672
+ step=1,
673
+ info=t("generation.lm_batch_chunk_info"),
674
+ scale=1,
675
+ interactive=not service_mode # Fixed in service mode
676
+ )
677
+
678
+ with gr.Row():
679
+ audio_cover_strength = gr.Slider(
680
+ minimum=0.0,
681
+ maximum=1.0,
682
+ value=1.0,
683
+ step=0.01,
684
+ label=t("generation.codes_strength_label"),
685
+ info=t("generation.codes_strength_info"),
686
+ scale=1,
687
+ )
688
+ score_scale = gr.Slider(
689
+ minimum=0.01,
690
+ maximum=1.0,
691
+ value=0.5,
692
+ step=0.01,
693
+ label=t("generation.score_sensitivity_label"),
694
+ info=t("generation.score_sensitivity_info"),
695
+ scale=1,
696
+ visible=not service_mode # Hidden in service mode
697
+ )
698
+
699
+ # Set generate_btn to interactive if service is pre-initialized
700
+ generate_btn_interactive = init_params.get('enable_generate', False) if service_pre_initialized else False
701
+ with gr.Row(equal_height=True):
702
+ with gr.Column(scale=1, variant="compact"):
703
+ think_checkbox = gr.Checkbox(
704
+ label=t("generation.think_label"),
705
+ value=True,
706
+ scale=1,
707
+ )
708
+ allow_lm_batch = gr.Checkbox(
709
+ label=t("generation.parallel_thinking_label"),
710
+ value=True,
711
+ scale=1,
712
+ )
713
+ with gr.Column(scale=18):
714
+ generate_btn = gr.Button(t("generation.generate_btn"), variant="primary", size="lg", interactive=generate_btn_interactive)
715
+ with gr.Column(scale=1, variant="compact"):
716
+ autogen_checkbox = gr.Checkbox(
717
+ label=t("generation.autogen_label"),
718
+ value=False, # Default to False for both service and local modes
719
+ scale=1,
720
+ interactive=not service_mode # Not selectable in service mode
721
+ )
722
+ use_cot_caption = gr.Checkbox(
723
+ label=t("generation.caption_rewrite_label"),
724
+ value=True,
725
+ scale=1,
726
+ )
727
+
728
+ return {
729
+ "service_config_accordion": service_config_accordion,
730
+ "language_dropdown": language_dropdown,
731
+ "checkpoint_dropdown": checkpoint_dropdown,
732
+ "refresh_btn": refresh_btn,
733
+ "config_path": config_path,
734
+ "device": device,
735
+ "init_btn": init_btn,
736
+ "init_status": init_status,
737
+ "lm_model_path": lm_model_path,
738
+ "init_llm_checkbox": init_llm_checkbox,
739
+ "backend_dropdown": backend_dropdown,
740
+ "use_flash_attention_checkbox": use_flash_attention_checkbox,
741
+ "offload_to_cpu_checkbox": offload_to_cpu_checkbox,
742
+ "offload_dit_to_cpu_checkbox": offload_dit_to_cpu_checkbox,
743
+ "compile_model_checkbox": compile_model_checkbox,
744
+ "quantization_checkbox": quantization_checkbox,
745
+ # LoRA components
746
+ "lora_path": lora_path,
747
+ "load_lora_btn": load_lora_btn,
748
+ "unload_lora_btn": unload_lora_btn,
749
+ "use_lora_checkbox": use_lora_checkbox,
750
+ "lora_scale_slider": lora_scale_slider,
751
+ "lora_status": lora_status,
752
+ "task_type": task_type,
753
+ "instruction_display_gen": instruction_display_gen,
754
+ "track_name": track_name,
755
+ "complete_track_classes": complete_track_classes,
756
+ "audio_uploads_accordion": audio_uploads_accordion,
757
+ "reference_audio": reference_audio,
758
+ "src_audio": src_audio,
759
+ "convert_src_to_codes_btn": convert_src_to_codes_btn,
760
+ "text2music_audio_code_string": text2music_audio_code_string,
761
+ "transcribe_btn": transcribe_btn,
762
+ "text2music_audio_codes_group": text2music_audio_codes_group,
763
+ "lm_temperature": lm_temperature,
764
+ "lm_cfg_scale": lm_cfg_scale,
765
+ "lm_top_k": lm_top_k,
766
+ "lm_top_p": lm_top_p,
767
+ "lm_negative_prompt": lm_negative_prompt,
768
+ "use_cot_metas": use_cot_metas,
769
+ "use_cot_caption": use_cot_caption,
770
+ "use_cot_language": use_cot_language,
771
+ "repainting_group": repainting_group,
772
+ "repainting_start": repainting_start,
773
+ "repainting_end": repainting_end,
774
+ "audio_cover_strength": audio_cover_strength,
775
+ # Simple/Custom Mode Components
776
+ "generation_mode": generation_mode,
777
+ "simple_mode_group": simple_mode_group,
778
+ "simple_query_input": simple_query_input,
779
+ "random_desc_btn": random_desc_btn,
780
+ "simple_instrumental_checkbox": simple_instrumental_checkbox,
781
+ "simple_vocal_language": simple_vocal_language,
782
+ "create_sample_btn": create_sample_btn,
783
+ "simple_sample_created": simple_sample_created,
784
+ "caption_accordion": caption_accordion,
785
+ "lyrics_accordion": lyrics_accordion,
786
+ "optional_params_accordion": optional_params_accordion,
787
+ # Existing components
788
+ "captions": captions,
789
+ "sample_btn": sample_btn,
790
+ "load_file": load_file,
791
+ "lyrics": lyrics,
792
+ "vocal_language": vocal_language,
793
+ "bpm": bpm,
794
+ "key_scale": key_scale,
795
+ "time_signature": time_signature,
796
+ "audio_duration": audio_duration,
797
+ "batch_size_input": batch_size_input,
798
+ "inference_steps": inference_steps,
799
+ "guidance_scale": guidance_scale,
800
+ "seed": seed,
801
+ "random_seed_checkbox": random_seed_checkbox,
802
+ "use_adg": use_adg,
803
+ "cfg_interval_start": cfg_interval_start,
804
+ "cfg_interval_end": cfg_interval_end,
805
+ "shift": shift,
806
+ "infer_method": infer_method,
807
+ "custom_timesteps": custom_timesteps,
808
+ "audio_format": audio_format,
809
+ "think_checkbox": think_checkbox,
810
+ "autogen_checkbox": autogen_checkbox,
811
+ "generate_btn": generate_btn,
812
+ "instrumental_checkbox": instrumental_checkbox,
813
+ "format_btn": format_btn,
814
+ "constrained_decoding_debug": constrained_decoding_debug,
815
+ "score_scale": score_scale,
816
+ "allow_lm_batch": allow_lm_batch,
817
+ "auto_score": auto_score,
818
+ "auto_lrc": auto_lrc,
819
+ "lm_batch_chunk_size": lm_batch_chunk_size,
820
+ # GPU config values for validation
821
+ "gpu_config": gpu_config,
822
+ "max_duration": max_duration,
823
+ "max_batch_size": max_batch_size,
824
+ }
acestep/gradio_ui/interfaces/result.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Results Section Module
3
+ Contains results display section component definitions
4
+ """
5
+ import gradio as gr
6
+ from acestep.gradio_ui.i18n import t
7
+
8
+
9
+ def create_results_section(dit_handler) -> dict:
10
+ """Create results display section"""
11
+ with gr.Accordion(t("results.title"), open=True):
12
+ # Hidden state to store LM-generated metadata
13
+ lm_metadata_state = gr.State(value=None)
14
+
15
+ # Hidden state to track if caption/metadata is from formatted source (LM/transcription)
16
+ is_format_caption_state = gr.State(value=False)
17
+
18
+ # Batch management states
19
+ current_batch_index = gr.State(value=0) # Currently displayed batch index
20
+ total_batches = gr.State(value=1) # Total number of batches generated
21
+ batch_queue = gr.State(value={}) # Dictionary storing all batch data
22
+ generation_params_state = gr.State(value={}) # Store generation parameters for next batches
23
+ is_generating_background = gr.State(value=False) # Background generation flag
24
+
25
+ # All audio components in one row with dynamic visibility
26
+ with gr.Row():
27
+ with gr.Column(visible=True) as audio_col_1:
28
+ generated_audio_1 = gr.Audio(
29
+ label=t("results.generated_music", n=1),
30
+ type="filepath",
31
+ interactive=False,
32
+ buttons=[]
33
+ )
34
+ with gr.Row(equal_height=True):
35
+ send_to_src_btn_1 = gr.Button(
36
+ t("results.send_to_src_btn"),
37
+ variant="secondary",
38
+ size="sm",
39
+ scale=1
40
+ )
41
+ save_btn_1 = gr.Button(
42
+ t("results.save_btn"),
43
+ variant="primary",
44
+ size="sm",
45
+ scale=1
46
+ )
47
+ score_btn_1 = gr.Button(
48
+ t("results.score_btn"),
49
+ variant="secondary",
50
+ size="sm",
51
+ scale=1
52
+ )
53
+ lrc_btn_1 = gr.Button(
54
+ t("results.lrc_btn"),
55
+ variant="secondary",
56
+ size="sm",
57
+ scale=1
58
+ )
59
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_1:
60
+ codes_display_1 = gr.Textbox(
61
+ label=t("results.codes_label", n=1),
62
+ interactive=False,
63
+ buttons=["copy"],
64
+ lines=4,
65
+ max_lines=4,
66
+ visible=True
67
+ )
68
+ score_display_1 = gr.Textbox(
69
+ label=t("results.quality_score_label", n=1),
70
+ interactive=False,
71
+ buttons=["copy"],
72
+ lines=6,
73
+ max_lines=6,
74
+ visible=True
75
+ )
76
+ lrc_display_1 = gr.Textbox(
77
+ label=t("results.lrc_label", n=1),
78
+ interactive=True,
79
+ buttons=["copy"],
80
+ lines=8,
81
+ max_lines=8,
82
+ visible=True
83
+ )
84
+ with gr.Column(visible=True) as audio_col_2:
85
+ generated_audio_2 = gr.Audio(
86
+ label=t("results.generated_music", n=2),
87
+ type="filepath",
88
+ interactive=False,
89
+ buttons=[]
90
+ )
91
+ with gr.Row(equal_height=True):
92
+ send_to_src_btn_2 = gr.Button(
93
+ t("results.send_to_src_btn"),
94
+ variant="secondary",
95
+ size="sm",
96
+ scale=1
97
+ )
98
+ save_btn_2 = gr.Button(
99
+ t("results.save_btn"),
100
+ variant="primary",
101
+ size="sm",
102
+ scale=1
103
+ )
104
+ score_btn_2 = gr.Button(
105
+ t("results.score_btn"),
106
+ variant="secondary",
107
+ size="sm",
108
+ scale=1
109
+ )
110
+ lrc_btn_2 = gr.Button(
111
+ t("results.lrc_btn"),
112
+ variant="secondary",
113
+ size="sm",
114
+ scale=1
115
+ )
116
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_2:
117
+ codes_display_2 = gr.Textbox(
118
+ label=t("results.codes_label", n=2),
119
+ interactive=False,
120
+ buttons=["copy"],
121
+ lines=4,
122
+ max_lines=4,
123
+ visible=True
124
+ )
125
+ score_display_2 = gr.Textbox(
126
+ label=t("results.quality_score_label", n=2),
127
+ interactive=False,
128
+ buttons=["copy"],
129
+ lines=6,
130
+ max_lines=6,
131
+ visible=True
132
+ )
133
+ lrc_display_2 = gr.Textbox(
134
+ label=t("results.lrc_label", n=2),
135
+ interactive=True,
136
+ buttons=["copy"],
137
+ lines=8,
138
+ max_lines=8,
139
+ visible=True
140
+ )
141
+ with gr.Column(visible=False) as audio_col_3:
142
+ generated_audio_3 = gr.Audio(
143
+ label=t("results.generated_music", n=3),
144
+ type="filepath",
145
+ interactive=False,
146
+ buttons=[]
147
+ )
148
+ with gr.Row(equal_height=True):
149
+ send_to_src_btn_3 = gr.Button(
150
+ t("results.send_to_src_btn"),
151
+ variant="secondary",
152
+ size="sm",
153
+ scale=1
154
+ )
155
+ save_btn_3 = gr.Button(
156
+ t("results.save_btn"),
157
+ variant="primary",
158
+ size="sm",
159
+ scale=1
160
+ )
161
+ score_btn_3 = gr.Button(
162
+ t("results.score_btn"),
163
+ variant="secondary",
164
+ size="sm",
165
+ scale=1
166
+ )
167
+ lrc_btn_3 = gr.Button(
168
+ t("results.lrc_btn"),
169
+ variant="secondary",
170
+ size="sm",
171
+ scale=1
172
+ )
173
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_3:
174
+ codes_display_3 = gr.Textbox(
175
+ label=t("results.codes_label", n=3),
176
+ interactive=False,
177
+ buttons=["copy"],
178
+ lines=4,
179
+ max_lines=4,
180
+ visible=True
181
+ )
182
+ score_display_3 = gr.Textbox(
183
+ label=t("results.quality_score_label", n=3),
184
+ interactive=False,
185
+ buttons=["copy"],
186
+ lines=6,
187
+ max_lines=6,
188
+ visible=True
189
+ )
190
+ lrc_display_3 = gr.Textbox(
191
+ label=t("results.lrc_label", n=3),
192
+ interactive=True,
193
+ buttons=["copy"],
194
+ lines=8,
195
+ max_lines=8,
196
+ visible=True
197
+ )
198
+ with gr.Column(visible=False) as audio_col_4:
199
+ generated_audio_4 = gr.Audio(
200
+ label=t("results.generated_music", n=4),
201
+ type="filepath",
202
+ interactive=False,
203
+ buttons=[]
204
+ )
205
+ with gr.Row(equal_height=True):
206
+ send_to_src_btn_4 = gr.Button(
207
+ t("results.send_to_src_btn"),
208
+ variant="secondary",
209
+ size="sm",
210
+ scale=1
211
+ )
212
+ save_btn_4 = gr.Button(
213
+ t("results.save_btn"),
214
+ variant="primary",
215
+ size="sm",
216
+ scale=1
217
+ )
218
+ score_btn_4 = gr.Button(
219
+ t("results.score_btn"),
220
+ variant="secondary",
221
+ size="sm",
222
+ scale=1
223
+ )
224
+ lrc_btn_4 = gr.Button(
225
+ t("results.lrc_btn"),
226
+ variant="secondary",
227
+ size="sm",
228
+ scale=1
229
+ )
230
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_4:
231
+ codes_display_4 = gr.Textbox(
232
+ label=t("results.codes_label", n=4),
233
+ interactive=False,
234
+ buttons=["copy"],
235
+ lines=4,
236
+ max_lines=4,
237
+ visible=True
238
+ )
239
+ score_display_4 = gr.Textbox(
240
+ label=t("results.quality_score_label", n=4),
241
+ interactive=False,
242
+ buttons=["copy"],
243
+ lines=6,
244
+ max_lines=6,
245
+ visible=True
246
+ )
247
+ lrc_display_4 = gr.Textbox(
248
+ label=t("results.lrc_label", n=4),
249
+ interactive=True,
250
+ buttons=["copy"],
251
+ lines=8,
252
+ max_lines=8,
253
+ visible=True
254
+ )
255
+
256
+ # Second row for batch size 5-8 (initially hidden)
257
+ with gr.Row(visible=False) as audio_row_5_8:
258
+ with gr.Column() as audio_col_5:
259
+ generated_audio_5 = gr.Audio(
260
+ label=t("results.generated_music", n=5),
261
+ type="filepath",
262
+ interactive=False,
263
+ buttons=[]
264
+ )
265
+ with gr.Row(equal_height=True):
266
+ send_to_src_btn_5 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
267
+ save_btn_5 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
268
+ score_btn_5 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
269
+ lrc_btn_5 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
270
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_5:
271
+ codes_display_5 = gr.Textbox(
272
+ label=t("results.codes_label", n=5),
273
+ interactive=False,
274
+ buttons=["copy"],
275
+ lines=4,
276
+ max_lines=4,
277
+ visible=True
278
+ )
279
+ score_display_5 = gr.Textbox(
280
+ label=t("results.quality_score_label", n=5),
281
+ interactive=False,
282
+ buttons=["copy"],
283
+ lines=6,
284
+ max_lines=6,
285
+ visible=True
286
+ )
287
+ lrc_display_5 = gr.Textbox(
288
+ label=t("results.lrc_label", n=5),
289
+ interactive=True,
290
+ buttons=["copy"],
291
+ lines=8,
292
+ max_lines=8,
293
+ visible=True
294
+ )
295
+ with gr.Column() as audio_col_6:
296
+ generated_audio_6 = gr.Audio(
297
+ label=t("results.generated_music", n=6),
298
+ type="filepath",
299
+ interactive=False,
300
+ buttons=[]
301
+ )
302
+ with gr.Row(equal_height=True):
303
+ send_to_src_btn_6 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
304
+ save_btn_6 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
305
+ score_btn_6 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
306
+ lrc_btn_6 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
307
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_6:
308
+ codes_display_6 = gr.Textbox(
309
+ label=t("results.codes_label", n=6),
310
+ interactive=False,
311
+ buttons=["copy"],
312
+ lines=4,
313
+ max_lines=4,
314
+ visible=True
315
+ )
316
+ score_display_6 = gr.Textbox(
317
+ label=t("results.quality_score_label", n=6),
318
+ interactive=False,
319
+ buttons=["copy"],
320
+ lines=6,
321
+ max_lines=6,
322
+ visible=True
323
+ )
324
+ lrc_display_6 = gr.Textbox(
325
+ label=t("results.lrc_label", n=6),
326
+ interactive=True,
327
+ buttons=["copy"],
328
+ lines=8,
329
+ max_lines=8,
330
+ visible=True
331
+ )
332
+ with gr.Column() as audio_col_7:
333
+ generated_audio_7 = gr.Audio(
334
+ label=t("results.generated_music", n=7),
335
+ type="filepath",
336
+ interactive=False,
337
+ buttons=[]
338
+ )
339
+ with gr.Row(equal_height=True):
340
+ send_to_src_btn_7 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
341
+ save_btn_7 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
342
+ score_btn_7 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
343
+ lrc_btn_7 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
344
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_7:
345
+ codes_display_7 = gr.Textbox(
346
+ label=t("results.codes_label", n=7),
347
+ interactive=False,
348
+ buttons=["copy"],
349
+ lines=4,
350
+ max_lines=4,
351
+ visible=True
352
+ )
353
+ score_display_7 = gr.Textbox(
354
+ label=t("results.quality_score_label", n=7),
355
+ interactive=False,
356
+ buttons=["copy"],
357
+ lines=6,
358
+ max_lines=6,
359
+ visible=True
360
+ )
361
+ lrc_display_7 = gr.Textbox(
362
+ label=t("results.lrc_label", n=7),
363
+ interactive=True,
364
+ buttons=["copy"],
365
+ lines=8,
366
+ max_lines=8,
367
+ visible=True
368
+ )
369
+ with gr.Column() as audio_col_8:
370
+ generated_audio_8 = gr.Audio(
371
+ label=t("results.generated_music", n=8),
372
+ type="filepath",
373
+ interactive=False,
374
+ buttons=[]
375
+ )
376
+ with gr.Row(equal_height=True):
377
+ send_to_src_btn_8 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
378
+ save_btn_8 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
379
+ score_btn_8 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
380
+ lrc_btn_8 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1)
381
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_8:
382
+ codes_display_8 = gr.Textbox(
383
+ label=t("results.codes_label", n=8),
384
+ interactive=False,
385
+ buttons=["copy"],
386
+ lines=4,
387
+ max_lines=4,
388
+ visible=True
389
+ )
390
+ score_display_8 = gr.Textbox(
391
+ label=t("results.quality_score_label", n=8),
392
+ interactive=False,
393
+ buttons=["copy"],
394
+ lines=6,
395
+ max_lines=6,
396
+ visible=True
397
+ )
398
+ lrc_display_8 = gr.Textbox(
399
+ label=t("results.lrc_label", n=8),
400
+ interactive=True,
401
+ buttons=["copy"],
402
+ lines=8,
403
+ max_lines=8,
404
+ visible=True
405
+ )
406
+
407
+ status_output = gr.Textbox(label=t("results.generation_status"), interactive=False)
408
+
409
+ # Batch navigation controls
410
+ with gr.Row(equal_height=True):
411
+ prev_batch_btn = gr.Button(
412
+ t("results.prev_btn"),
413
+ variant="secondary",
414
+ interactive=False,
415
+ scale=1,
416
+ size="sm"
417
+ )
418
+ batch_indicator = gr.Textbox(
419
+ label=t("results.current_batch"),
420
+ value=t("results.batch_indicator", current=1, total=1),
421
+ interactive=False,
422
+ scale=3
423
+ )
424
+ next_batch_status = gr.Textbox(
425
+ label=t("results.next_batch_status"),
426
+ value="",
427
+ interactive=False,
428
+ scale=3
429
+ )
430
+ next_batch_btn = gr.Button(
431
+ t("results.next_btn"),
432
+ variant="primary",
433
+ interactive=False,
434
+ scale=1,
435
+ size="sm"
436
+ )
437
+
438
+ # One-click restore parameters button
439
+ restore_params_btn = gr.Button(
440
+ t("results.restore_params_btn"),
441
+ variant="secondary",
442
+ interactive=False, # Initially disabled, enabled after generation
443
+ size="sm"
444
+ )
445
+
446
+ with gr.Accordion(t("results.batch_results_title"), open=False):
447
+ generated_audio_batch = gr.File(
448
+ label=t("results.all_files_label"),
449
+ file_count="multiple",
450
+ interactive=False
451
+ )
452
+ generation_info = gr.Markdown(label=t("results.generation_details"))
453
+
454
+ return {
455
+ "lm_metadata_state": lm_metadata_state,
456
+ "is_format_caption_state": is_format_caption_state,
457
+ "current_batch_index": current_batch_index,
458
+ "total_batches": total_batches,
459
+ "batch_queue": batch_queue,
460
+ "generation_params_state": generation_params_state,
461
+ "is_generating_background": is_generating_background,
462
+ "status_output": status_output,
463
+ "prev_batch_btn": prev_batch_btn,
464
+ "batch_indicator": batch_indicator,
465
+ "next_batch_btn": next_batch_btn,
466
+ "next_batch_status": next_batch_status,
467
+ "restore_params_btn": restore_params_btn,
468
+ "generated_audio_1": generated_audio_1,
469
+ "generated_audio_2": generated_audio_2,
470
+ "generated_audio_3": generated_audio_3,
471
+ "generated_audio_4": generated_audio_4,
472
+ "generated_audio_5": generated_audio_5,
473
+ "generated_audio_6": generated_audio_6,
474
+ "generated_audio_7": generated_audio_7,
475
+ "generated_audio_8": generated_audio_8,
476
+ "audio_row_5_8": audio_row_5_8,
477
+ "audio_col_1": audio_col_1,
478
+ "audio_col_2": audio_col_2,
479
+ "audio_col_3": audio_col_3,
480
+ "audio_col_4": audio_col_4,
481
+ "audio_col_5": audio_col_5,
482
+ "audio_col_6": audio_col_6,
483
+ "audio_col_7": audio_col_7,
484
+ "audio_col_8": audio_col_8,
485
+ "send_to_src_btn_1": send_to_src_btn_1,
486
+ "send_to_src_btn_2": send_to_src_btn_2,
487
+ "send_to_src_btn_3": send_to_src_btn_3,
488
+ "send_to_src_btn_4": send_to_src_btn_4,
489
+ "send_to_src_btn_5": send_to_src_btn_5,
490
+ "send_to_src_btn_6": send_to_src_btn_6,
491
+ "send_to_src_btn_7": send_to_src_btn_7,
492
+ "send_to_src_btn_8": send_to_src_btn_8,
493
+ "save_btn_1": save_btn_1,
494
+ "save_btn_2": save_btn_2,
495
+ "save_btn_3": save_btn_3,
496
+ "save_btn_4": save_btn_4,
497
+ "save_btn_5": save_btn_5,
498
+ "save_btn_6": save_btn_6,
499
+ "save_btn_7": save_btn_7,
500
+ "save_btn_8": save_btn_8,
501
+ "score_btn_1": score_btn_1,
502
+ "score_btn_2": score_btn_2,
503
+ "score_btn_3": score_btn_3,
504
+ "score_btn_4": score_btn_4,
505
+ "score_btn_5": score_btn_5,
506
+ "score_btn_6": score_btn_6,
507
+ "score_btn_7": score_btn_7,
508
+ "score_btn_8": score_btn_8,
509
+ "score_display_1": score_display_1,
510
+ "score_display_2": score_display_2,
511
+ "score_display_3": score_display_3,
512
+ "score_display_4": score_display_4,
513
+ "score_display_5": score_display_5,
514
+ "score_display_6": score_display_6,
515
+ "score_display_7": score_display_7,
516
+ "score_display_8": score_display_8,
517
+ "codes_display_1": codes_display_1,
518
+ "codes_display_2": codes_display_2,
519
+ "codes_display_3": codes_display_3,
520
+ "codes_display_4": codes_display_4,
521
+ "codes_display_5": codes_display_5,
522
+ "codes_display_6": codes_display_6,
523
+ "codes_display_7": codes_display_7,
524
+ "codes_display_8": codes_display_8,
525
+ "lrc_btn_1": lrc_btn_1,
526
+ "lrc_btn_2": lrc_btn_2,
527
+ "lrc_btn_3": lrc_btn_3,
528
+ "lrc_btn_4": lrc_btn_4,
529
+ "lrc_btn_5": lrc_btn_5,
530
+ "lrc_btn_6": lrc_btn_6,
531
+ "lrc_btn_7": lrc_btn_7,
532
+ "lrc_btn_8": lrc_btn_8,
533
+ "lrc_display_1": lrc_display_1,
534
+ "lrc_display_2": lrc_display_2,
535
+ "lrc_display_3": lrc_display_3,
536
+ "lrc_display_4": lrc_display_4,
537
+ "lrc_display_5": lrc_display_5,
538
+ "lrc_display_6": lrc_display_6,
539
+ "lrc_display_7": lrc_display_7,
540
+ "lrc_display_8": lrc_display_8,
541
+ "details_accordion_1": details_accordion_1,
542
+ "details_accordion_2": details_accordion_2,
543
+ "details_accordion_3": details_accordion_3,
544
+ "details_accordion_4": details_accordion_4,
545
+ "details_accordion_5": details_accordion_5,
546
+ "details_accordion_6": details_accordion_6,
547
+ "details_accordion_7": details_accordion_7,
548
+ "details_accordion_8": details_accordion_8,
549
+ "generated_audio_batch": generated_audio_batch,
550
+ "generation_info": generation_info,
551
+ }
552
+
acestep/gradio_ui/interfaces/training.py ADDED
@@ -0,0 +1,625 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Training Tab Module
3
+
4
+ Contains the dataset builder and LoRA training interface components.
5
+ """
6
+
7
+ import os
8
+ import gradio as gr
9
+ from acestep.gradio_ui.i18n import t
10
+ from acestep.constants import DEBUG_TRAINING
11
+
12
+
13
+ def create_training_section(dit_handler, llm_handler, init_params=None) -> dict:
14
+ """Create the training tab section with dataset builder and training controls.
15
+
16
+ Args:
17
+ dit_handler: DiT handler instance
18
+ llm_handler: LLM handler instance
19
+ init_params: Dictionary containing initialization parameters and state.
20
+ If None, service will not be pre-initialized.
21
+
22
+ Returns:
23
+ Dictionary of Gradio components for event handling
24
+ """
25
+ # Check if running in service mode (hide training tab)
26
+ service_mode = init_params is not None and init_params.get('service_mode', False)
27
+
28
+ debug_training_enabled = str(DEBUG_TRAINING).strip().upper() != "OFF"
29
+ epoch_min = 1 if debug_training_enabled else 100
30
+ epoch_step = 1 if debug_training_enabled else 100
31
+ epoch_default = 1 if debug_training_enabled else 1000
32
+
33
+ with gr.Tab(t("training.tab_title"), visible=not service_mode):
34
+ gr.HTML("""
35
+ <div style="text-align: center; padding: 10px; margin-bottom: 15px;">
36
+ <h2>🎵 LoRA Training for ACE-Step</h2>
37
+ <p>Build datasets from your audio files and train custom LoRA adapters</p>
38
+ </div>
39
+ """)
40
+
41
+ with gr.Tabs():
42
+ # ==================== Dataset Builder Tab ====================
43
+ with gr.Tab(t("training.tab_dataset_builder")):
44
+ # ========== Load Existing OR Scan New ==========
45
+ gr.HTML(f"""
46
+ <div style="padding: 10px; margin-bottom: 10px; border: 1px solid #4a4a6a; border-radius: 8px; background: linear-gradient(135deg, #2a2a4a 0%, #1a1a3a 100%);">
47
+ <h3 style="margin: 0 0 5px 0;">{t("training.quick_start_title")}</h3>
48
+ <p style="margin: 0; color: #aaa;">Choose one: <b>Load existing dataset</b> OR <b>Scan new directory</b></p>
49
+ </div>
50
+ """)
51
+
52
+ with gr.Row():
53
+ with gr.Column(scale=1):
54
+ gr.HTML("<h4>📂 Load Existing Dataset</h4>")
55
+ with gr.Row():
56
+ load_json_path = gr.Textbox(
57
+ label=t("training.load_dataset_label"),
58
+ placeholder="./datasets/my_lora_dataset.json",
59
+ info=t("training.load_dataset_info"),
60
+ scale=3,
61
+ )
62
+ load_json_btn = gr.Button(t("training.load_btn"), variant="primary", scale=1)
63
+ load_json_status = gr.Textbox(
64
+ label=t("training.load_status"),
65
+ interactive=False,
66
+ )
67
+
68
+ with gr.Column(scale=1):
69
+ gr.HTML("<h4>🔍 Scan New Directory</h4>")
70
+ with gr.Row():
71
+ audio_directory = gr.Textbox(
72
+ label=t("training.scan_label"),
73
+ placeholder="/path/to/your/audio/folder",
74
+ info=t("training.scan_info"),
75
+ scale=3,
76
+ )
77
+ scan_btn = gr.Button(t("training.scan_btn"), variant="secondary", scale=1)
78
+ scan_status = gr.Textbox(
79
+ label=t("training.scan_status"),
80
+ interactive=False,
81
+ )
82
+
83
+ gr.HTML("<hr>")
84
+
85
+ with gr.Row():
86
+ with gr.Column(scale=2):
87
+
88
+ # Audio files table
89
+ audio_files_table = gr.Dataframe(
90
+ headers=["#", "Filename", "Duration", "Lyrics", "Labeled", "BPM", "Key", "Caption"],
91
+ datatype=["number", "str", "str", "str", "str", "str", "str", "str"],
92
+ label=t("training.found_audio_files"),
93
+ interactive=False,
94
+ wrap=True,
95
+ )
96
+
97
+ with gr.Column(scale=1):
98
+ gr.HTML(f"<h3>⚙️ {t('training.dataset_settings_header')}</h3>")
99
+
100
+ dataset_name = gr.Textbox(
101
+ label=t("training.dataset_name"),
102
+ value="my_lora_dataset",
103
+ placeholder=t("training.dataset_name_placeholder"),
104
+ )
105
+
106
+ all_instrumental = gr.Checkbox(
107
+ label=t("training.all_instrumental"),
108
+ value=True,
109
+ info=t("training.all_instrumental_info"),
110
+ )
111
+
112
+ format_lyrics = gr.Checkbox(
113
+ label="Format Lyrics (LM)",
114
+ value=False,
115
+ info="Use LM to format/structure user-provided lyrics from .txt files (coming soon)",
116
+ interactive=False, # Disabled for now - model update needed
117
+ )
118
+
119
+ transcribe_lyrics = gr.Checkbox(
120
+ label="Transcribe Lyrics (LM)",
121
+ value=False,
122
+ info="Use LM to transcribe lyrics from audio (coming soon)",
123
+ interactive=False, # Disabled for now - model update needed
124
+ )
125
+
126
+ custom_tag = gr.Textbox(
127
+ label=t("training.custom_tag"),
128
+ placeholder="e.g., 8bit_retro, my_style",
129
+ info=t("training.custom_tag_info"),
130
+ )
131
+
132
+ tag_position = gr.Radio(
133
+ choices=[
134
+ (t("training.tag_prepend"), "prepend"),
135
+ (t("training.tag_append"), "append"),
136
+ (t("training.tag_replace"), "replace"),
137
+ ],
138
+ value="replace",
139
+ label=t("training.tag_position"),
140
+ info=t("training.tag_position_info"),
141
+ )
142
+
143
+ genre_ratio = gr.Slider(
144
+ minimum=0,
145
+ maximum=100,
146
+ step=10,
147
+ value=0,
148
+ label=t("training.genre_ratio"),
149
+ info=t("training.genre_ratio_info"),
150
+ )
151
+
152
+ gr.HTML(f"<hr><h3>🤖 {t('training.step2_title')}</h3>")
153
+
154
+ with gr.Row():
155
+ with gr.Column(scale=3):
156
+ gr.Markdown("""
157
+ Click the button below to automatically generate metadata for all audio files using AI:
158
+ - **Caption**: Music style, genre, mood description
159
+ - **BPM**: Beats per minute
160
+ - **Key**: Musical key (e.g., C Major, Am)
161
+ - **Time Signature**: 4/4, 3/4, etc.
162
+ """)
163
+ skip_metas = gr.Checkbox(
164
+ label=t("training.skip_metas"),
165
+ value=False,
166
+ info=t("training.skip_metas_info"),
167
+ )
168
+ only_unlabeled = gr.Checkbox(
169
+ label=t("training.only_unlabeled"),
170
+ value=False,
171
+ info=t("training.only_unlabeled_info"),
172
+ )
173
+ with gr.Column(scale=1):
174
+ auto_label_btn = gr.Button(
175
+ t("training.auto_label_btn"),
176
+ variant="primary",
177
+ size="lg",
178
+ )
179
+
180
+ label_progress = gr.Textbox(
181
+ label=t("training.label_progress"),
182
+ interactive=False,
183
+ lines=2,
184
+ )
185
+
186
+ gr.HTML(f"<hr><h3>👀 {t('training.step3_title')}</h3>")
187
+
188
+ with gr.Row():
189
+ with gr.Column(scale=1):
190
+ sample_selector = gr.Slider(
191
+ minimum=0,
192
+ maximum=0,
193
+ step=1,
194
+ value=0,
195
+ label=t("training.select_sample"),
196
+ info=t("training.select_sample_info"),
197
+ )
198
+
199
+ preview_audio = gr.Audio(
200
+ label=t("training.audio_preview"),
201
+ type="filepath",
202
+ interactive=False,
203
+ )
204
+
205
+ preview_filename = gr.Textbox(
206
+ label=t("training.filename"),
207
+ interactive=False,
208
+ )
209
+
210
+ with gr.Column(scale=2):
211
+ with gr.Row():
212
+ edit_caption = gr.Textbox(
213
+ label=t("training.caption"),
214
+ lines=3,
215
+ placeholder="Music description...",
216
+ )
217
+
218
+ with gr.Row():
219
+ edit_genre = gr.Textbox(
220
+ label=t("training.genre"),
221
+ lines=1,
222
+ placeholder="pop, electronic, dance...",
223
+ )
224
+ prompt_override = gr.Dropdown(
225
+ choices=["Use Global Ratio", "Caption", "Genre"],
226
+ value="Use Global Ratio",
227
+ label=t("training.prompt_override_label"),
228
+ info=t("training.prompt_override_info"),
229
+ )
230
+
231
+ with gr.Row():
232
+ edit_lyrics = gr.Textbox(
233
+ label=t("training.lyrics_editable_label"),
234
+ lines=6,
235
+ placeholder="[Verse 1]\nLyrics here...\n\n[Chorus]\n...",
236
+ )
237
+ raw_lyrics_display = gr.Textbox(
238
+ label=t("training.raw_lyrics_label"),
239
+ lines=6,
240
+ placeholder=t("training.no_lyrics_placeholder"),
241
+ interactive=False, # Read-only, can copy but not edit
242
+ visible=False, # Hidden when no raw lyrics
243
+ )
244
+ has_raw_lyrics_state = gr.State(False) # Track visibility
245
+
246
+ with gr.Row():
247
+ edit_bpm = gr.Number(
248
+ label=t("training.bpm"),
249
+ precision=0,
250
+ )
251
+ edit_keyscale = gr.Textbox(
252
+ label=t("training.key_label"),
253
+ placeholder=t("training.key_placeholder"),
254
+ )
255
+ edit_timesig = gr.Dropdown(
256
+ choices=["", "2", "3", "4", "6", "N/A"],
257
+ label=t("training.time_sig"),
258
+ )
259
+ edit_duration = gr.Number(
260
+ label=t("training.duration_s"),
261
+ precision=1,
262
+ interactive=False,
263
+ )
264
+
265
+ with gr.Row():
266
+ edit_language = gr.Dropdown(
267
+ choices=["instrumental", "en", "zh", "ja", "ko", "es", "fr", "de", "pt", "ru", "unknown"],
268
+ value="instrumental",
269
+ label=t("training.language"),
270
+ )
271
+ edit_instrumental = gr.Checkbox(
272
+ label=t("training.instrumental"),
273
+ value=True,
274
+ )
275
+ save_edit_btn = gr.Button(t("training.save_changes_btn"), variant="secondary")
276
+
277
+ edit_status = gr.Textbox(
278
+ label=t("training.edit_status"),
279
+ interactive=False,
280
+ )
281
+
282
+ gr.HTML(f"<hr><h3>💾 {t('training.step4_title')}</h3>")
283
+
284
+ with gr.Row():
285
+ with gr.Column(scale=3):
286
+ save_path = gr.Textbox(
287
+ label=t("training.save_path"),
288
+ value="./datasets/my_lora_dataset.json",
289
+ placeholder="./datasets/dataset_name.json",
290
+ info=t("training.save_path_info"),
291
+ )
292
+ with gr.Column(scale=1):
293
+ save_dataset_btn = gr.Button(
294
+ t("training.save_dataset_btn"),
295
+ variant="primary",
296
+ size="lg",
297
+ )
298
+
299
+ save_status = gr.Textbox(
300
+ label=t("training.save_status"),
301
+ interactive=False,
302
+ lines=2,
303
+ )
304
+
305
+ gr.HTML(f"<hr><h3>⚡ {t('training.step5_title')}</h3>")
306
+
307
+ gr.Markdown("""
308
+ **Preprocessing converts your dataset to pre-computed tensors for fast training.**
309
+
310
+ You can either:
311
+ - Use the dataset from Steps 1-4 above, **OR**
312
+ - Load an existing dataset JSON file (if you've already saved one)
313
+ """)
314
+
315
+ with gr.Row():
316
+ with gr.Column(scale=3):
317
+ load_existing_dataset_path = gr.Textbox(
318
+ label=t("training.load_existing_label"),
319
+ placeholder="./datasets/my_lora_dataset.json",
320
+ info=t("training.load_existing_info"),
321
+ )
322
+ with gr.Column(scale=1):
323
+ load_existing_dataset_btn = gr.Button(
324
+ t("training.load_dataset_btn"),
325
+ variant="secondary",
326
+ size="lg",
327
+ )
328
+
329
+ load_existing_status = gr.Textbox(
330
+ label=t("training.load_status"),
331
+ interactive=False,
332
+ )
333
+
334
+ gr.Markdown("""
335
+ This step:
336
+ - Encodes audio to VAE latents
337
+ - Encodes captions and lyrics to text embeddings
338
+ - Runs the condition encoder
339
+ - Saves all tensors to `.pt` files
340
+
341
+ ⚠️ **This requires the model to be loaded and may take a few minutes.**
342
+ """)
343
+
344
+ with gr.Row():
345
+ with gr.Column(scale=3):
346
+ preprocess_output_dir = gr.Textbox(
347
+ label=t("training.tensor_output_dir"),
348
+ value="./datasets/preprocessed_tensors",
349
+ placeholder="./datasets/preprocessed_tensors",
350
+ info=t("training.tensor_output_info"),
351
+ )
352
+ with gr.Column(scale=1):
353
+ preprocess_btn = gr.Button(
354
+ t("training.preprocess_btn"),
355
+ variant="primary",
356
+ size="lg",
357
+ )
358
+
359
+ preprocess_progress = gr.Textbox(
360
+ label=t("training.preprocess_progress"),
361
+ interactive=False,
362
+ lines=3,
363
+ )
364
+
365
+ # ==================== Training Tab ====================
366
+ with gr.Tab(t("training.tab_train_lora")):
367
+ with gr.Row():
368
+ with gr.Column(scale=2):
369
+ gr.HTML(f"<h3>📊 {t('training.train_section_tensors')}</h3>")
370
+
371
+ gr.Markdown("""
372
+ Select the directory containing preprocessed tensor files (`.pt` files).
373
+ These are created in the "Dataset Builder" tab using the "Preprocess" button.
374
+ """)
375
+
376
+ training_tensor_dir = gr.Textbox(
377
+ label=t("training.preprocessed_tensors_dir"),
378
+ placeholder="./datasets/preprocessed_tensors",
379
+ value="./datasets/preprocessed_tensors",
380
+ info=t("training.preprocessed_tensors_info"),
381
+ )
382
+
383
+ load_dataset_btn = gr.Button(t("training.load_dataset_btn"), variant="secondary")
384
+
385
+ training_dataset_info = gr.Textbox(
386
+ label=t("training.dataset_info"),
387
+ interactive=False,
388
+ lines=3,
389
+ )
390
+
391
+ with gr.Column(scale=1):
392
+ gr.HTML(f"<h3>⚙️ {t('training.train_section_lora')}</h3>")
393
+
394
+ lora_rank = gr.Slider(
395
+ minimum=4,
396
+ maximum=256,
397
+ step=4,
398
+ value=64,
399
+ label=t("training.lora_rank"),
400
+ info=t("training.lora_rank_info"),
401
+ )
402
+
403
+ lora_alpha = gr.Slider(
404
+ minimum=4,
405
+ maximum=512,
406
+ step=4,
407
+ value=128,
408
+ label=t("training.lora_alpha"),
409
+ info=t("training.lora_alpha_info"),
410
+ )
411
+
412
+ lora_dropout = gr.Slider(
413
+ minimum=0.0,
414
+ maximum=0.5,
415
+ step=0.05,
416
+ value=0.1,
417
+ label=t("training.lora_dropout"),
418
+ )
419
+
420
+ gr.HTML(f"<hr><h3>🎛️ {t('training.train_section_params')}</h3>")
421
+
422
+ with gr.Row():
423
+ learning_rate = gr.Number(
424
+ label=t("training.learning_rate"),
425
+ value=3e-4,
426
+ info=t("training.learning_rate_info"),
427
+ )
428
+
429
+ train_epochs = gr.Slider(
430
+ minimum=epoch_min,
431
+ maximum=4000,
432
+ step=epoch_step,
433
+ value=epoch_default,
434
+ label=t("training.max_epochs"),
435
+ )
436
+
437
+ train_batch_size = gr.Slider(
438
+ minimum=1,
439
+ maximum=8,
440
+ step=1,
441
+ value=1,
442
+ label=t("training.batch_size"),
443
+ info=t("training.batch_size_info"),
444
+ )
445
+
446
+ gradient_accumulation = gr.Slider(
447
+ minimum=1,
448
+ maximum=16,
449
+ step=1,
450
+ value=1,
451
+ label=t("training.gradient_accumulation"),
452
+ info=t("training.gradient_accumulation_info"),
453
+ )
454
+
455
+ with gr.Row():
456
+ save_every_n_epochs = gr.Slider(
457
+ minimum=50,
458
+ maximum=1000,
459
+ step=50,
460
+ value=200,
461
+ label=t("training.save_every_n_epochs"),
462
+ )
463
+
464
+ training_shift = gr.Slider(
465
+ minimum=1.0,
466
+ maximum=5.0,
467
+ step=0.5,
468
+ value=3.0,
469
+ label=t("training.shift"),
470
+ info=t("training.shift_info"),
471
+ )
472
+
473
+ training_seed = gr.Number(
474
+ label=t("training.seed"),
475
+ value=42,
476
+ precision=0,
477
+ )
478
+
479
+ with gr.Row():
480
+ lora_output_dir = gr.Textbox(
481
+ label=t("training.output_dir"),
482
+ value="./lora_output",
483
+ placeholder="./lora_output",
484
+ info=t("training.output_dir_info"),
485
+ )
486
+
487
+ with gr.Row():
488
+ resume_checkpoint_dir = gr.Textbox(
489
+ label="Resume Checkpoint (optional)",
490
+ placeholder="./lora_output/checkpoints/epoch_200",
491
+ info="Directory of a saved LoRA checkpoint to resume from",
492
+ )
493
+
494
+ gr.HTML("<hr>")
495
+
496
+ with gr.Row():
497
+ with gr.Column(scale=1):
498
+ start_training_btn = gr.Button(
499
+ t("training.start_training_btn"),
500
+ variant="primary",
501
+ size="lg",
502
+ )
503
+ with gr.Column(scale=1):
504
+ stop_training_btn = gr.Button(
505
+ t("training.stop_training_btn"),
506
+ variant="stop",
507
+ size="lg",
508
+ )
509
+
510
+ training_progress = gr.Textbox(
511
+ label=t("training.training_progress"),
512
+ interactive=False,
513
+ lines=2,
514
+ )
515
+
516
+ with gr.Row():
517
+ training_log = gr.Textbox(
518
+ label=t("training.training_log"),
519
+ interactive=False,
520
+ lines=10,
521
+ max_lines=15,
522
+ scale=1,
523
+ )
524
+ training_loss_plot = gr.LinePlot(
525
+ x="step",
526
+ y="loss",
527
+ title=t("training.training_loss_title"),
528
+ x_title=t("training.step"),
529
+ y_title=t("training.loss"),
530
+ scale=1,
531
+ )
532
+
533
+ gr.HTML(f"<hr><h3>📦 {t('training.export_header')}</h3>")
534
+
535
+ with gr.Row():
536
+ export_path = gr.Textbox(
537
+ label=t("training.export_path"),
538
+ value="./lora_output/final_lora",
539
+ placeholder="./lora_output/my_lora",
540
+ )
541
+ export_lora_btn = gr.Button(t("training.export_lora_btn"), variant="secondary")
542
+
543
+ export_status = gr.Textbox(
544
+ label=t("training.export_status"),
545
+ interactive=False,
546
+ )
547
+
548
+ # Store dataset builder state
549
+ dataset_builder_state = gr.State(None)
550
+ training_state = gr.State({"is_training": False, "should_stop": False})
551
+
552
+ return {
553
+ # Dataset Builder - Load or Scan
554
+ "load_json_path": load_json_path,
555
+ "load_json_btn": load_json_btn,
556
+ "load_json_status": load_json_status,
557
+ "audio_directory": audio_directory,
558
+ "scan_btn": scan_btn,
559
+ "scan_status": scan_status,
560
+ "audio_files_table": audio_files_table,
561
+ "dataset_name": dataset_name,
562
+ "all_instrumental": all_instrumental,
563
+ "format_lyrics": format_lyrics,
564
+ "transcribe_lyrics": transcribe_lyrics,
565
+ "custom_tag": custom_tag,
566
+ "tag_position": tag_position,
567
+ "skip_metas": skip_metas,
568
+ "only_unlabeled": only_unlabeled,
569
+ "auto_label_btn": auto_label_btn,
570
+ "label_progress": label_progress,
571
+ "sample_selector": sample_selector,
572
+ "preview_audio": preview_audio,
573
+ "preview_filename": preview_filename,
574
+ "edit_caption": edit_caption,
575
+ "edit_genre": edit_genre,
576
+ "prompt_override": prompt_override,
577
+ "genre_ratio": genre_ratio,
578
+ "edit_lyrics": edit_lyrics,
579
+ "raw_lyrics_display": raw_lyrics_display,
580
+ "has_raw_lyrics_state": has_raw_lyrics_state,
581
+ "edit_bpm": edit_bpm,
582
+ "edit_keyscale": edit_keyscale,
583
+ "edit_timesig": edit_timesig,
584
+ "edit_duration": edit_duration,
585
+ "edit_language": edit_language,
586
+ "edit_instrumental": edit_instrumental,
587
+ "save_edit_btn": save_edit_btn,
588
+ "edit_status": edit_status,
589
+ "save_path": save_path,
590
+ "save_dataset_btn": save_dataset_btn,
591
+ "save_status": save_status,
592
+ # Preprocessing
593
+ "load_existing_dataset_path": load_existing_dataset_path,
594
+ "load_existing_dataset_btn": load_existing_dataset_btn,
595
+ "load_existing_status": load_existing_status,
596
+ "preprocess_output_dir": preprocess_output_dir,
597
+ "preprocess_btn": preprocess_btn,
598
+ "preprocess_progress": preprocess_progress,
599
+ "dataset_builder_state": dataset_builder_state,
600
+ # Training
601
+ "training_tensor_dir": training_tensor_dir,
602
+ "load_dataset_btn": load_dataset_btn,
603
+ "training_dataset_info": training_dataset_info,
604
+ "lora_rank": lora_rank,
605
+ "lora_alpha": lora_alpha,
606
+ "lora_dropout": lora_dropout,
607
+ "learning_rate": learning_rate,
608
+ "train_epochs": train_epochs,
609
+ "train_batch_size": train_batch_size,
610
+ "gradient_accumulation": gradient_accumulation,
611
+ "save_every_n_epochs": save_every_n_epochs,
612
+ "training_shift": training_shift,
613
+ "training_seed": training_seed,
614
+ "lora_output_dir": lora_output_dir,
615
+ "resume_checkpoint_dir": resume_checkpoint_dir,
616
+ "start_training_btn": start_training_btn,
617
+ "stop_training_btn": stop_training_btn,
618
+ "training_progress": training_progress,
619
+ "training_log": training_log,
620
+ "training_loss_plot": training_loss_plot,
621
+ "export_path": export_path,
622
+ "export_lora_btn": export_lora_btn,
623
+ "export_status": export_status,
624
+ "training_state": training_state,
625
+ }
acestep/handler.py ADDED
The diff for this file is too large to render. See raw diff
 
acestep/inference.py ADDED
@@ -0,0 +1,1310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ACE-Step Inference API Module
3
+
4
+ This module provides a standardized inference interface for music generation,
5
+ designed for third-party integration. It offers both a simplified API and
6
+ backward-compatible Gradio UI support.
7
+ """
8
+
9
+ import math
10
+ import os
11
+ import tempfile
12
+ import shutil
13
+ import subprocess
14
+ import sys
15
+ from typing import Optional, Union, List, Dict, Any, Tuple
16
+ from dataclasses import dataclass, field, asdict
17
+ from loguru import logger
18
+
19
+ from acestep.audio_utils import AudioSaver, generate_uuid_from_params, is_audio_silent
20
+ from acestep.constants import TASK_INSTRUCTIONS
21
+ from acestep.gpu_config import get_gpu_config
22
+
23
+ # HuggingFace Space environment detection
24
+ IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
25
+
26
+ def _get_spaces_gpu_decorator(duration=180):
27
+ """
28
+ Get the @spaces.GPU decorator if running in HuggingFace Space environment.
29
+ Returns identity decorator if not in Space environment.
30
+ """
31
+ if IS_HUGGINGFACE_SPACE:
32
+ try:
33
+ import spaces
34
+ return spaces.GPU(duration=duration)
35
+ except ImportError:
36
+ logger.warning("spaces package not found, GPU decorator disabled")
37
+ return lambda func: func
38
+ return lambda func: func
39
+
40
+
41
+ @dataclass
42
+ class GenerationParams:
43
+ """Configuration for music generation parameters.
44
+
45
+ Attributes:
46
+ # Text Inputs
47
+ caption: A short text prompt describing the desired music (main prompt). < 512 characters
48
+ lyrics: Lyrics for the music. Use "[Instrumental]" for instrumental songs. < 4096 characters
49
+ instrumental: If True, generate instrumental music regardless of lyrics.
50
+
51
+ # Music Metadata
52
+ bpm: BPM (beats per minute), e.g., 120. Set to None for automatic estimation. 30 ~ 300
53
+ keyscale: Musical key (e.g., "C Major", "Am"). Leave empty for auto-detection. A-G, #/♭, major/minor
54
+ timesignature: Time signature (2 for '2/4', 3 for '3/4', 4 for '4/4', 6 for '6/8'). Leave empty for auto-detection.
55
+ vocal_language: Language code for vocals, e.g., "en", "zh", "ja", or "unknown". see acestep/constants.py:VALID_LANGUAGES
56
+ duration: Target audio length in seconds. If <0 or None, model chooses automatically. 10 ~ 600
57
+
58
+ # Generation Parameters
59
+ inference_steps: Number of diffusion steps (e.g., 8 for turbo, 32–100 for base model).
60
+ guidance_scale: CFG (classifier-free guidance) strength. Higher means following the prompt more strictly. Only support for non-turbo model.
61
+ seed: Integer seed for reproducibility. -1 means use random seed each time.
62
+
63
+ # Advanced DiT Parameters
64
+ use_adg: Whether to use Adaptive Dual Guidance (only works for base model).
65
+ cfg_interval_start: Start ratio (0.0–1.0) to apply CFG.
66
+ cfg_interval_end: End ratio (0.0–1.0) to apply CFG.
67
+ shift: Timestep shift factor (default 1.0). When != 1.0, applies t = shift * t / (1 + (shift - 1) * t) to timesteps.
68
+
69
+ # Task-Specific Parameters
70
+ task_type: Type of generation task. One of: "text2music", "cover", "repaint", "lego", "extract", "complete".
71
+ reference_audio: Path to a reference audio file for style transfer or cover tasks.
72
+ src_audio: Path to a source audio file for audio-to-audio tasks.
73
+ audio_codes: Audio semantic codes as a string (advanced use, for code-control generation).
74
+ repainting_start: For repaint/lego tasks: start time in seconds for region to repaint.
75
+ repainting_end: For repaint/lego tasks: end time in seconds for region to repaint (-1 for until end).
76
+ audio_cover_strength: Strength of reference audio/codes influence (range 0.0–1.0). set smaller (0.2) for style transfer tasks.
77
+ instruction: Optional task instruction prompt. If empty, auto-generated by system.
78
+
79
+ # 5Hz Language Model Parameters for CoT reasoning
80
+ thinking: If True, enable 5Hz Language Model "Chain-of-Thought" reasoning for semantic/music metadata and codes.
81
+ lm_temperature: Sampling temperature for the LLM (0.0–2.0). Higher = more creative/varied results.
82
+ lm_cfg_scale: Classifier-free guidance scale for the LLM.
83
+ lm_top_k: LLM top-k sampling (0 = disabled).
84
+ lm_top_p: LLM top-p nucleus sampling (1.0 = disabled).
85
+ lm_negative_prompt: Negative prompt to use for LLM (for control).
86
+ use_cot_metas: Whether to let LLM generate music metadata via CoT reasoning.
87
+ use_cot_caption: Whether to let LLM rewrite or format the input caption via CoT reasoning.
88
+ use_cot_language: Whether to let LLM detect vocal language via CoT.
89
+ """
90
+ # Required Inputs
91
+ task_type: str = "text2music"
92
+ instruction: str = "Fill the audio semantic mask based on the given conditions:"
93
+
94
+ # Audio Uploads
95
+ reference_audio: Optional[str] = None
96
+ src_audio: Optional[str] = None
97
+
98
+ # LM Codes Hints
99
+ audio_codes: str = ""
100
+
101
+ # Text Inputs
102
+ caption: str = ""
103
+ lyrics: str = ""
104
+ instrumental: bool = False
105
+
106
+ # Metadata
107
+ vocal_language: str = "unknown"
108
+ bpm: Optional[int] = None
109
+ keyscale: str = ""
110
+ timesignature: str = ""
111
+ duration: float = -1.0
112
+
113
+ # Advanced Settings
114
+ inference_steps: int = 8
115
+ seed: int = -1
116
+ guidance_scale: float = 7.0
117
+ use_adg: bool = False
118
+ cfg_interval_start: float = 0.0
119
+ cfg_interval_end: float = 1.0
120
+ shift: float = 1.0
121
+ infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
122
+ # Custom timesteps (parsed from string like "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
123
+ # If provided, overrides inference_steps and shift
124
+ timesteps: Optional[List[float]] = None
125
+
126
+ repainting_start: float = 0.0
127
+ repainting_end: float = -1
128
+ audio_cover_strength: float = 1.0
129
+
130
+ # 5Hz Language Model Parameters
131
+ thinking: bool = True
132
+ lm_temperature: float = 0.85
133
+ lm_cfg_scale: float = 2.0
134
+ lm_top_k: int = 0
135
+ lm_top_p: float = 0.9
136
+ lm_negative_prompt: str = "NO USER INPUT"
137
+ use_cot_metas: bool = True
138
+ use_cot_caption: bool = True
139
+ use_cot_lyrics: bool = False # TODO: not used yet
140
+ use_cot_language: bool = True
141
+ use_constrained_decoding: bool = True
142
+
143
+ cot_bpm: Optional[int] = None
144
+ cot_keyscale: str = ""
145
+ cot_timesignature: str = ""
146
+ cot_duration: Optional[float] = None
147
+ cot_vocal_language: str = "unknown"
148
+ cot_caption: str = ""
149
+ cot_lyrics: str = ""
150
+
151
+ def to_dict(self) -> Dict[str, Any]:
152
+ """Convert config to dictionary for JSON serialization."""
153
+ return asdict(self)
154
+
155
+
156
+ @dataclass
157
+ class GenerationConfig:
158
+ """Configuration for music generation.
159
+
160
+ Attributes:
161
+ batch_size: Number of audio samples to generate
162
+ allow_lm_batch: Whether to allow batch processing in LM
163
+ use_random_seed: Whether to use random seed
164
+ seeds: Seed(s) for batch generation. Can be:
165
+ - None: Use random seeds (when use_random_seed=True) or params.seed (when use_random_seed=False)
166
+ - List[int]: List of seeds, will be padded with random seeds if fewer than batch_size
167
+ - int: Single seed value (will be converted to list and padded)
168
+ lm_batch_chunk_size: Batch chunk size for LM processing
169
+ constrained_decoding_debug: Whether to enable constrained decoding debug
170
+ audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac"
171
+ """
172
+ batch_size: int = 2
173
+ allow_lm_batch: bool = False
174
+ use_random_seed: bool = True
175
+ seeds: Optional[List[int]] = None
176
+ lm_batch_chunk_size: int = 8
177
+ constrained_decoding_debug: bool = False
178
+ audio_format: str = "flac" # Default to FLAC for fast saving
179
+
180
+ def to_dict(self) -> Dict[str, Any]:
181
+ """Convert config to dictionary for JSON serialization."""
182
+ return asdict(self)
183
+
184
+
185
+ @dataclass
186
+ class GenerationResult:
187
+ """Result of music generation.
188
+
189
+ Attributes:
190
+ # Audio Outputs
191
+ audios: List of audio dictionaries with paths, keys, params
192
+ status_message: Status message from generation
193
+ extra_outputs: Extra outputs from generation
194
+ success: Whether generation completed successfully
195
+ error: Error message if generation failed
196
+ """
197
+
198
+ # Audio Outputs
199
+ audios: List[Dict[str, Any]] = field(default_factory=list)
200
+ # Generation Information
201
+ status_message: str = ""
202
+ extra_outputs: Dict[str, Any] = field(default_factory=dict)
203
+ # Success Status
204
+ success: bool = True
205
+ error: Optional[str] = None
206
+
207
+ def to_dict(self) -> Dict[str, Any]:
208
+ """Convert result to dictionary for JSON serialization."""
209
+ return asdict(self)
210
+
211
+
212
+ @dataclass
213
+ class UnderstandResult:
214
+ """Result of music understanding from audio codes.
215
+
216
+ Attributes:
217
+ # Metadata Fields
218
+ caption: Generated caption describing the music
219
+ lyrics: Generated or extracted lyrics
220
+ bpm: Beats per minute (None if not detected)
221
+ duration: Duration in seconds (None if not detected)
222
+ keyscale: Musical key (e.g., "C Major")
223
+ language: Vocal language code (e.g., "en", "zh")
224
+ timesignature: Time signature (e.g., "4/4")
225
+
226
+ # Status
227
+ status_message: Status message from understanding
228
+ success: Whether understanding completed successfully
229
+ error: Error message if understanding failed
230
+ """
231
+ # Metadata Fields
232
+ caption: str = ""
233
+ lyrics: str = ""
234
+ bpm: Optional[int] = None
235
+ duration: Optional[float] = None
236
+ keyscale: str = ""
237
+ language: str = ""
238
+ timesignature: str = ""
239
+
240
+ # Status
241
+ status_message: str = ""
242
+ success: bool = True
243
+ error: Optional[str] = None
244
+
245
+ def to_dict(self) -> Dict[str, Any]:
246
+ """Convert result to dictionary for JSON serialization."""
247
+ return asdict(self)
248
+
249
+
250
+ def _update_metadata_from_lm(
251
+ metadata: Dict[str, Any],
252
+ bpm: Optional[int],
253
+ key_scale: str,
254
+ time_signature: str,
255
+ audio_duration: Optional[float],
256
+ vocal_language: str,
257
+ caption: str,
258
+ lyrics: str,
259
+ ) -> Tuple[Optional[int], str, str, Optional[float], str, str, str]:
260
+ """Update metadata fields from LM output if not provided by user."""
261
+
262
+ if bpm is None and metadata.get('bpm'):
263
+ bpm_value = metadata.get('bpm')
264
+ if bpm_value not in ["N/A", ""]:
265
+ try:
266
+ bpm = int(bpm_value)
267
+ except (ValueError, TypeError):
268
+ pass
269
+
270
+ if not key_scale and metadata.get('keyscale'):
271
+ key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
272
+ if key_scale_value != "N/A":
273
+ key_scale = key_scale_value
274
+
275
+ if not time_signature and metadata.get('timesignature'):
276
+ time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
277
+ if time_signature_value != "N/A":
278
+ time_signature = time_signature_value
279
+
280
+ if audio_duration is None or audio_duration <= 0:
281
+ audio_duration_value = metadata.get('duration', -1)
282
+ if audio_duration_value not in ["N/A", ""]:
283
+ try:
284
+ audio_duration = float(audio_duration_value)
285
+ except (ValueError, TypeError):
286
+ pass
287
+
288
+ if not vocal_language and metadata.get('vocal_language'):
289
+ vocal_language = metadata.get('vocal_language')
290
+ if not caption and metadata.get('caption'):
291
+ caption = metadata.get('caption')
292
+ if not lyrics and metadata.get('lyrics'):
293
+ lyrics = metadata.get('lyrics')
294
+ return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
295
+
296
+
297
+ @_get_spaces_gpu_decorator(duration=180)
298
+ def generate_music(
299
+ dit_handler,
300
+ llm_handler,
301
+ params: GenerationParams,
302
+ config: GenerationConfig,
303
+ save_dir: Optional[str] = None,
304
+ progress=None,
305
+ ) -> GenerationResult:
306
+ """Generate music using ACE-Step model with optional LM reasoning.
307
+
308
+ Args:
309
+ dit_handler: Initialized DiT model handler (AceStepHandler instance)
310
+ llm_handler: Initialized LLM handler (LLMHandler instance)
311
+ params: Generation parameters (GenerationParams instance)
312
+ config: Generation configuration (GenerationConfig instance)
313
+
314
+ Returns:
315
+ GenerationResult with generated audio files and metadata
316
+ """
317
+ try:
318
+ # Phase 1: LM-based metadata and code generation (if enabled)
319
+ audio_code_string_to_use = params.audio_codes
320
+ lm_generated_metadata = None
321
+ lm_generated_audio_codes_list = []
322
+ lm_total_time_costs = {
323
+ "phase1_time": 0.0,
324
+ "phase2_time": 0.0,
325
+ "total_time": 0.0,
326
+ }
327
+
328
+ # Extract mutable copies of metadata (will be updated by LM if needed)
329
+ bpm = params.bpm
330
+ key_scale = params.keyscale
331
+ time_signature = params.timesignature
332
+ audio_duration = params.duration
333
+ dit_input_caption = params.caption
334
+ dit_input_vocal_language = params.vocal_language
335
+ dit_input_lyrics = params.lyrics
336
+ # Determine if we need to generate audio codes
337
+ # If user has provided audio_codes, we don't need to generate them
338
+ # Otherwise, check if we need audio codes (lm_dit mode) or just metas (dit mode)
339
+ user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip())
340
+
341
+ # Safety: cover task without any source audio or codes produces silence.
342
+ if params.task_type == "cover":
343
+ no_src_audio = not (params.reference_audio or params.src_audio)
344
+ if no_src_audio and not user_provided_audio_codes:
345
+ logger.warning("Cover task requested without source audio or audio codes. Falling back to text2music.")
346
+ params.task_type = "text2music"
347
+ if params.instruction == TASK_INSTRUCTIONS.get("cover"):
348
+ params.instruction = TASK_INSTRUCTIONS.get("text2music", params.instruction)
349
+
350
+ # Determine infer_type: use "llm_dit" if we need audio codes, "dit" if only metas needed
351
+ # For now, we use "llm_dit" if batch mode or if user hasn't provided codes
352
+ # Use "dit" if user has provided codes (only need metas) or if explicitly only need metas
353
+ # Note: This logic can be refined based on specific requirements
354
+ need_audio_codes = not user_provided_audio_codes
355
+
356
+ # Determine if we should use chunk-based LM generation (always use chunks for consistency)
357
+ # Determine actual batch size for chunk processing
358
+ actual_batch_size = config.batch_size if config.batch_size is not None else 1
359
+
360
+ # Prepare seeds for batch generation
361
+ # Use config.seed if provided, otherwise fallback to params.seed
362
+ # Convert config.seed (None, int, or List[int]) to format that prepare_seeds accepts
363
+ seed_for_generation = ""
364
+ # Original code (commented out because it crashes on int seeds):
365
+ # if config.seeds is not None and len(config.seeds) > 0:
366
+ # if isinstance(config.seeds, list):
367
+ # # Convert List[int] to comma-separated string
368
+ # seed_for_generation = ",".join(str(s) for s in config.seeds)
369
+
370
+ if config.seeds is not None:
371
+ if isinstance(config.seeds, list) and len(config.seeds) > 0:
372
+ # Convert List[int] to comma-separated string
373
+ seed_for_generation = ",".join(str(s) for s in config.seeds)
374
+ elif isinstance(config.seeds, int):
375
+ # Fix: Explicitly handle single integer seeds by converting to string.
376
+ # Previously, this would crash because 'len()' was called on an int.
377
+ seed_for_generation = str(config.seeds)
378
+
379
+ # Use dit_handler.prepare_seeds to handle seed list generation and padding
380
+ # This will handle all the logic: padding with random seeds if needed, etc.
381
+ actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)
382
+
383
+ # LM-based Chain-of-Thought reasoning
384
+ # Skip LM for cover/repaint tasks - these tasks use reference/src audio directly
385
+ # and don't need LM to generate audio codes
386
+ skip_lm_tasks = {"cover", "repaint"}
387
+
388
+ # Determine if we should use LLM
389
+ # LLM is needed for:
390
+ # 1. thinking=True: generate audio codes via LM
391
+ # 2. use_cot_caption=True: enhance/generate caption via CoT
392
+ # 3. use_cot_language=True: detect vocal language via CoT
393
+ # 4. use_cot_metas=True: fill missing metadata via CoT
394
+ need_lm_for_cot = params.use_cot_caption or params.use_cot_language or params.use_cot_metas
395
+ use_lm = (params.thinking or need_lm_for_cot) and llm_handler is not None and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks
396
+ lm_status = []
397
+
398
+ if params.task_type in skip_lm_tasks:
399
+ logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly")
400
+
401
+ logger.info(f"[generate_music] LLM usage decision: thinking={params.thinking}, "
402
+ f"use_cot_caption={params.use_cot_caption}, use_cot_language={params.use_cot_language}, "
403
+ f"use_cot_metas={params.use_cot_metas}, need_lm_for_cot={need_lm_for_cot}, "
404
+ f"llm_initialized={llm_handler.llm_initialized if llm_handler else False}, use_lm={use_lm}")
405
+
406
+ def _infer_audio_duration_seconds(audio_path: str) -> Optional[float]:
407
+ """Best-effort duration inference for common audio formats."""
408
+ if not audio_path:
409
+ return None
410
+ # Try torchaudio (supports more formats when ffmpeg backend is available)
411
+ try:
412
+ import torchaudio
413
+ info = torchaudio.info(audio_path)
414
+ if info and info.num_frames and info.sample_rate:
415
+ return float(info.num_frames) / float(info.sample_rate)
416
+ except Exception:
417
+ pass
418
+ # Try soundfile (fast for wav/flac)
419
+ try:
420
+ import soundfile as sf
421
+ info = sf.info(audio_path)
422
+ if info and info.frames and info.samplerate:
423
+ return float(info.frames) / float(info.samplerate)
424
+ except Exception:
425
+ pass
426
+ # macOS fallback: use afinfo for m4a/aac
427
+ if sys.platform == "darwin" and shutil.which("afinfo"):
428
+ try:
429
+ result = subprocess.run(
430
+ ["afinfo", audio_path],
431
+ check=False,
432
+ capture_output=True,
433
+ text=True,
434
+ )
435
+ if result.stdout:
436
+ for line in result.stdout.splitlines():
437
+ if "duration:" in line:
438
+ # Example: "duration: 183.165s"
439
+ parts = line.strip().split()
440
+ for p in parts:
441
+ if p.endswith("s"):
442
+ try:
443
+ return float(p.rstrip("s"))
444
+ except ValueError:
445
+ continue
446
+ except Exception:
447
+ pass
448
+ return None
449
+
450
+ # Clamp duration and batch size to GPU limits (applies to non-Gradio callers too)
451
+ try:
452
+ # If duration not provided, try to infer from source audio to enable safe clamping.
453
+ if (audio_duration is None or float(audio_duration) <= 0) and (params.src_audio or params.reference_audio):
454
+ audio_path = params.src_audio or params.reference_audio
455
+ try:
456
+ inferred = _infer_audio_duration_seconds(audio_path)
457
+ if inferred and inferred > 0:
458
+ audio_duration = inferred
459
+ params.duration = inferred
460
+ logger.info(f"[generate_music] Inferred duration from audio file: {inferred:.2f}s")
461
+ except Exception as e:
462
+ logger.warning(f"[generate_music] Failed to infer duration from audio file: {e}")
463
+
464
+ gpu_config = get_gpu_config()
465
+ max_duration = gpu_config.max_duration_with_lm if use_lm else gpu_config.max_duration_without_lm
466
+ if audio_duration is not None and float(audio_duration) > 0 and float(audio_duration) > max_duration:
467
+ logger.warning(f"[generate_music] Duration {audio_duration}s exceeds GPU limit {max_duration}s. Clamping.")
468
+ audio_duration = float(max_duration)
469
+ params.duration = float(max_duration)
470
+
471
+ max_batch = gpu_config.max_batch_size_with_lm if use_lm else gpu_config.max_batch_size_without_lm
472
+ if config.batch_size is not None and config.batch_size > max_batch:
473
+ logger.warning(f"[generate_music] Batch size {config.batch_size} exceeds GPU limit {max_batch}. Clamping.")
474
+ config.batch_size = max_batch
475
+
476
+ # Extra safety for MPS: large durations can OOM with batch > 1
477
+ if (
478
+ hasattr(dit_handler, "device")
479
+ and dit_handler.device == "mps"
480
+ and audio_duration is not None
481
+ and float(audio_duration) > 180
482
+ and config.batch_size is not None
483
+ and config.batch_size > 1
484
+ ):
485
+ logger.warning("[generate_music] MPS with long duration detected; reducing batch size to 1 to avoid OOM.")
486
+ config.batch_size = 1
487
+ except Exception as e:
488
+ logger.warning(f"[generate_music] Failed to clamp duration/batch to GPU limits: {e}")
489
+
490
+ if use_lm:
491
+ # Convert sampling parameters - handle None values safely
492
+ top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
493
+ top_p_value = None if not params.lm_top_p or params.lm_top_p >= 1.0 else params.lm_top_p
494
+
495
+ # Build user_metadata from user-provided values
496
+ user_metadata = {}
497
+ if bpm is not None:
498
+ try:
499
+ bpm_value = float(bpm)
500
+ if bpm_value > 0:
501
+ user_metadata['bpm'] = int(bpm_value)
502
+ except (ValueError, TypeError):
503
+ pass
504
+
505
+ if key_scale and key_scale.strip():
506
+ key_scale_clean = key_scale.strip()
507
+ if key_scale_clean.lower() not in ["n/a", ""]:
508
+ user_metadata['keyscale'] = key_scale_clean
509
+
510
+ if time_signature and time_signature.strip():
511
+ time_sig_clean = time_signature.strip()
512
+ if time_sig_clean.lower() not in ["n/a", ""]:
513
+ user_metadata['timesignature'] = time_sig_clean
514
+
515
+ if audio_duration is not None:
516
+ try:
517
+ duration_value = float(audio_duration)
518
+ if duration_value > 0:
519
+ user_metadata['duration'] = int(duration_value)
520
+ except (ValueError, TypeError):
521
+ pass
522
+
523
+ user_metadata_to_pass = user_metadata if user_metadata else None
524
+
525
+ # Determine infer_type based on whether we need audio codes
526
+ # - "llm_dit": generates both metas and audio codes (two-phase internally)
527
+ # - "dit": generates only metas (single phase)
528
+ infer_type = "llm_dit" if need_audio_codes and params.thinking else "dit"
529
+
530
+ # Use chunk size from config, or default to batch_size if not set
531
+ max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size
532
+ num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)
533
+
534
+ all_metadata_list = []
535
+ all_audio_codes_list = []
536
+
537
+ for chunk_idx in range(num_chunks):
538
+ chunk_start = chunk_idx * max_inference_batch_size
539
+ chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
540
+ chunk_size = chunk_end - chunk_start
541
+ chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None
542
+
543
+ logger.info(f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
544
+ f"(size: {chunk_size}, seeds: {chunk_seeds})")
545
+
546
+ # Use the determined infer_type
547
+ # - "llm_dit" will internally run two phases (metas + codes)
548
+ # - "dit" will only run phase 1 (metas only)
549
+ result = llm_handler.generate_with_stop_condition(
550
+ caption=params.caption or "",
551
+ lyrics=params.lyrics or "",
552
+ infer_type=infer_type,
553
+ temperature=params.lm_temperature,
554
+ cfg_scale=params.lm_cfg_scale,
555
+ negative_prompt=params.lm_negative_prompt,
556
+ top_k=top_k_value,
557
+ top_p=top_p_value,
558
+ target_duration=audio_duration, # Pass duration to limit audio codes generation
559
+ user_metadata=user_metadata_to_pass,
560
+ use_cot_caption=params.use_cot_caption,
561
+ use_cot_language=params.use_cot_language,
562
+ use_cot_metas=params.use_cot_metas,
563
+ use_constrained_decoding=params.use_constrained_decoding,
564
+ constrained_decoding_debug=config.constrained_decoding_debug,
565
+ batch_size=chunk_size,
566
+ seeds=chunk_seeds,
567
+ progress=progress,
568
+ )
569
+
570
+ # Check if LM generation failed
571
+ if not result.get("success", False):
572
+ error_msg = result.get("error", "Unknown LM error")
573
+ lm_status.append(f"❌ LM Error: {error_msg}")
574
+ # Return early with error
575
+ return GenerationResult(
576
+ audios=[],
577
+ status_message=f"❌ LM generation failed: {error_msg}",
578
+ extra_outputs={},
579
+ success=False,
580
+ error=error_msg,
581
+ )
582
+
583
+ # Extract metadata and audio_codes from result dict
584
+ if chunk_size > 1:
585
+ metadata_list = result.get("metadata", [])
586
+ audio_codes_list = result.get("audio_codes", [])
587
+ all_metadata_list.extend(metadata_list)
588
+ all_audio_codes_list.extend(audio_codes_list)
589
+ else:
590
+ metadata = result.get("metadata", {})
591
+ audio_codes = result.get("audio_codes", "")
592
+ all_metadata_list.append(metadata)
593
+ all_audio_codes_list.append(audio_codes)
594
+
595
+ # Collect time costs from LM extra_outputs
596
+ lm_extra = result.get("extra_outputs", {})
597
+ lm_chunk_time_costs = lm_extra.get("time_costs", {})
598
+ if lm_chunk_time_costs:
599
+ # Accumulate time costs from all chunks
600
+ for key in ["phase1_time", "phase2_time", "total_time"]:
601
+ if key in lm_chunk_time_costs:
602
+ lm_total_time_costs[key] += lm_chunk_time_costs[key]
603
+
604
+ time_str = ", ".join([f"{k}: {v:.2f}s" for k, v in lm_chunk_time_costs.items()])
605
+ lm_status.append(f"✅ LM chunk {chunk_idx+1}: {time_str}")
606
+
607
+ lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
608
+ lm_generated_audio_codes_list = all_audio_codes_list
609
+
610
+ # Set audio_code_string_to_use based on infer_type
611
+ if infer_type == "llm_dit":
612
+ # If batch mode, use list; otherwise use single string
613
+ if actual_batch_size > 1:
614
+ audio_code_string_to_use = all_audio_codes_list
615
+ else:
616
+ audio_code_string_to_use = all_audio_codes_list[0] if all_audio_codes_list else ""
617
+ else:
618
+ # For "dit" mode, keep user-provided codes or empty
619
+ audio_code_string_to_use = params.audio_codes
620
+
621
+ # Update metadata from LM if not provided by user
622
+ if lm_generated_metadata:
623
+ bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics = _update_metadata_from_lm(
624
+ metadata=lm_generated_metadata,
625
+ bpm=bpm,
626
+ key_scale=key_scale,
627
+ time_signature=time_signature,
628
+ audio_duration=audio_duration,
629
+ vocal_language=dit_input_vocal_language,
630
+ caption=dit_input_caption,
631
+ lyrics=dit_input_lyrics)
632
+ if not params.bpm:
633
+ params.cot_bpm = bpm
634
+ if not params.keyscale:
635
+ params.cot_keyscale = key_scale
636
+ if not params.timesignature:
637
+ params.cot_timesignature = time_signature
638
+ if not params.duration:
639
+ params.cot_duration = audio_duration
640
+ if not params.vocal_language:
641
+ params.cot_vocal_language = vocal_language
642
+ if not params.caption:
643
+ params.cot_caption = caption
644
+ if not params.lyrics:
645
+ params.cot_lyrics = lyrics
646
+ dit_input_lyrics = lyrics
647
+
648
+ # set cot caption and language if needed
649
+ if params.use_cot_caption:
650
+ dit_input_caption = lm_generated_metadata.get("caption", dit_input_caption)
651
+ if params.use_cot_language:
652
+ dit_input_vocal_language = lm_generated_metadata.get("vocal_language", dit_input_vocal_language)
653
+
654
+ # Phase 2: DiT music generation
655
+ # Use seed_for_generation (from config.seed or params.seed) instead of params.seed for actual generation
656
+ result = dit_handler.generate_music(
657
+ captions=dit_input_caption,
658
+ lyrics=dit_input_lyrics,
659
+ bpm=bpm,
660
+ key_scale=key_scale,
661
+ time_signature=time_signature,
662
+ vocal_language=dit_input_vocal_language,
663
+ inference_steps=params.inference_steps,
664
+ guidance_scale=params.guidance_scale,
665
+ use_random_seed=config.use_random_seed,
666
+ seed=seed_for_generation, # Use config.seed (or params.seed fallback) instead of params.seed directly
667
+ reference_audio=params.reference_audio,
668
+ audio_duration=audio_duration,
669
+ batch_size=config.batch_size if config.batch_size is not None else 1,
670
+ src_audio=params.src_audio,
671
+ audio_code_string=audio_code_string_to_use,
672
+ repainting_start=params.repainting_start,
673
+ repainting_end=params.repainting_end,
674
+ instruction=params.instruction,
675
+ audio_cover_strength=params.audio_cover_strength,
676
+ task_type=params.task_type,
677
+ use_adg=params.use_adg,
678
+ cfg_interval_start=params.cfg_interval_start,
679
+ cfg_interval_end=params.cfg_interval_end,
680
+ shift=params.shift,
681
+ infer_method=params.infer_method,
682
+ timesteps=params.timesteps,
683
+ progress=progress,
684
+ )
685
+
686
+ # Check if generation failed
687
+ if not result.get("success", False):
688
+ return GenerationResult(
689
+ audios=[],
690
+ status_message=result.get("status_message", ""),
691
+ extra_outputs={},
692
+ success=False,
693
+ error=result.get("error"),
694
+ )
695
+
696
+ # Extract results from dit_handler.generate_music dict
697
+ dit_audios = result.get("audios", [])
698
+ status_message = result.get("status_message", "")
699
+ dit_extra_outputs = result.get("extra_outputs", {})
700
+
701
+ # Use the seed list already prepared above (from config.seed or params.seed fallback)
702
+ # actual_seed_list was computed earlier using dit_handler.prepare_seeds
703
+ seed_list = actual_seed_list
704
+
705
+ # Get base params dictionary
706
+ base_params_dict = params.to_dict()
707
+
708
+ # Save audio files using AudioSaver (format from config)
709
+ audio_format = config.audio_format if config.audio_format else "flac"
710
+ audio_saver = AudioSaver(default_format=audio_format)
711
+
712
+ # Use handler's temp_dir for saving files
713
+ if save_dir is not None:
714
+ os.makedirs(save_dir, exist_ok=True)
715
+
716
+ # Build audios list for GenerationResult with params and save files
717
+ # Audio saving and UUID generation handled here, outside of handler
718
+ audios = []
719
+ silent_warnings = []
720
+ for idx, dit_audio in enumerate(dit_audios):
721
+ # Create a copy of params dict for this audio
722
+ audio_params = base_params_dict.copy()
723
+
724
+ # Update audio-specific values
725
+ audio_params["seed"] = seed_list[idx] if idx < len(seed_list) else None
726
+
727
+ # Add audio codes if batch mode
728
+ if lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list):
729
+ audio_params["audio_codes"] = lm_generated_audio_codes_list[idx]
730
+
731
+ # Get audio tensor and metadata
732
+ audio_tensor = dit_audio.get("tensor")
733
+ sample_rate = dit_audio.get("sample_rate", 48000)
734
+
735
+ # Generate UUID for this audio (moved from handler)
736
+ batch_seed = seed_list[idx] if idx < len(seed_list) else seed_list[0] if seed_list else -1
737
+ audio_code_str = lm_generated_audio_codes_list[idx] if (
738
+ lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list)) else audio_code_string_to_use
739
+ if isinstance(audio_code_str, list):
740
+ audio_code_str = audio_code_str[idx] if idx < len(audio_code_str) else ""
741
+
742
+ audio_key = generate_uuid_from_params(audio_params)
743
+
744
+ silent_check = False
745
+ if audio_tensor is not None:
746
+ silent_check, rms_val, peak_val = is_audio_silent(audio_tensor, channels_first=True)
747
+ if silent_check:
748
+ logger.warning(
749
+ f"[generate_music] Silent output detected (idx={idx}, RMS={rms_val:.2e}, peak={peak_val:.2e}). "
750
+ "Likely cause: LLM backend returned empty conditioning, or incompatible torch/triton/flash-attn. "
751
+ "Suggest running with --backend pt."
752
+ )
753
+ silent_warnings.append(
754
+ f"Output {idx + 1}: silent or near-silent (RMS≈{rms_val:.2e}). "
755
+ "Likely causes: LLM backend failure, incompatible torch/triton/flash-attn, or CPU/fallback path. "
756
+ "Try running with --backend pt."
757
+ )
758
+
759
+ audio_path = None
760
+ if audio_tensor is not None and save_dir is not None and not silent_check:
761
+ try:
762
+ audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}")
763
+ audio_path = audio_saver.save_audio(audio_tensor,
764
+ audio_file,
765
+ sample_rate=sample_rate,
766
+ format=audio_format,
767
+ channels_first=True)
768
+ except Exception as e:
769
+ logger.error(f"[generate_music] Failed to save audio file: {e}")
770
+ audio_path = ""
771
+
772
+ audio_dict = {
773
+ "path": audio_path or "",
774
+ "tensor": audio_tensor,
775
+ "key": audio_key,
776
+ "sample_rate": sample_rate,
777
+ "params": audio_params,
778
+ "silent": silent_check,
779
+ }
780
+
781
+ audios.append(audio_dict)
782
+
783
+ # Merge extra_outputs: include dit_extra_outputs (latents, masks) and add LM metadata
784
+ extra_outputs = dit_extra_outputs.copy()
785
+ extra_outputs["lm_metadata"] = lm_generated_metadata
786
+
787
+ # Merge time_costs from both LM and DiT into a unified dictionary
788
+ unified_time_costs = {}
789
+
790
+ # Add LM time costs (if LM was used)
791
+ if use_lm and lm_total_time_costs:
792
+ for key, value in lm_total_time_costs.items():
793
+ unified_time_costs[f"lm_{key}"] = value
794
+
795
+ # Add DiT time costs (if available)
796
+ dit_time_costs = dit_extra_outputs.get("time_costs", {})
797
+ if dit_time_costs:
798
+ for key, value in dit_time_costs.items():
799
+ unified_time_costs[f"dit_{key}"] = value
800
+
801
+ # Calculate total pipeline time
802
+ if unified_time_costs:
803
+ lm_total = unified_time_costs.get("lm_total_time", 0.0)
804
+ dit_total = unified_time_costs.get("dit_total_time_cost", 0.0)
805
+ unified_time_costs["pipeline_total_time"] = lm_total + dit_total
806
+
807
+ # Update extra_outputs with unified time_costs
808
+ extra_outputs["time_costs"] = unified_time_costs
809
+
810
+ if lm_status:
811
+ status_message = "\n".join(lm_status) + "\n" + status_message
812
+ else:
813
+ status_message = status_message
814
+ if silent_warnings:
815
+ status_message = "⚠️ Silent output detected:\n" + "\n".join(silent_warnings) + "\n\nSuggested fix: try running with --backend pt\n\n" + (status_message or "")
816
+ # Create and return GenerationResult
817
+ return GenerationResult(
818
+ audios=audios,
819
+ status_message=status_message,
820
+ extra_outputs=extra_outputs,
821
+ success=True,
822
+ error=None,
823
+ )
824
+
825
+ except Exception as e:
826
+ logger.exception("Music generation failed")
827
+ return GenerationResult(
828
+ audios=[],
829
+ status_message=f"Error: {str(e)}",
830
+ extra_outputs={},
831
+ success=False,
832
+ error=str(e),
833
+ )
834
+
835
+
836
+ def understand_music(
837
+ llm_handler,
838
+ audio_codes: str,
839
+ temperature: float = 0.85,
840
+ top_k: Optional[int] = None,
841
+ top_p: Optional[float] = None,
842
+ repetition_penalty: float = 1.0,
843
+ use_constrained_decoding: bool = True,
844
+ constrained_decoding_debug: bool = False,
845
+ ) -> UnderstandResult:
846
+ """Understand music from audio codes using the 5Hz Language Model.
847
+
848
+ This function analyzes audio semantic codes and generates metadata about the music,
849
+ including caption, lyrics, BPM, duration, key scale, language, and time signature.
850
+
851
+ If audio_codes is empty or "NO USER INPUT", the LM will generate a sample example
852
+ instead of analyzing existing codes.
853
+
854
+ Note: cfg_scale and negative_prompt are not supported in understand mode.
855
+
856
+ Args:
857
+ llm_handler: Initialized LLM handler (LLMHandler instance)
858
+ audio_codes: String of audio code tokens (e.g., "<|audio_code_123|><|audio_code_456|>...")
859
+ Use empty string or "NO USER INPUT" to generate a sample example.
860
+ temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
861
+ top_k: Top-K sampling (None or 0 = disabled)
862
+ top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
863
+ repetition_penalty: Repetition penalty (1.0 = no penalty)
864
+ use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
865
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
866
+
867
+ Returns:
868
+ UnderstandResult with parsed metadata fields and status
869
+
870
+ Example:
871
+ >>> result = understand_music(llm_handler, audio_codes="<|audio_code_123|>...")
872
+ >>> if result.success:
873
+ ... print(f"Caption: {result.caption}")
874
+ ... print(f"BPM: {result.bpm}")
875
+ ... print(f"Lyrics: {result.lyrics}")
876
+ """
877
+ # Check if LLM is initialized
878
+ if not llm_handler.llm_initialized:
879
+ return UnderstandResult(
880
+ status_message="5Hz LM not initialized. Please initialize it first.",
881
+ success=False,
882
+ error="LLM not initialized",
883
+ )
884
+
885
+ # If codes are empty, use "NO USER INPUT" to generate a sample example
886
+ if not audio_codes or not audio_codes.strip():
887
+ audio_codes = "NO USER INPUT"
888
+
889
+ try:
890
+ # Call LLM understanding
891
+ metadata, status = llm_handler.understand_audio_from_codes(
892
+ audio_codes=audio_codes,
893
+ temperature=temperature,
894
+ top_k=top_k,
895
+ top_p=top_p,
896
+ repetition_penalty=repetition_penalty,
897
+ use_constrained_decoding=use_constrained_decoding,
898
+ constrained_decoding_debug=constrained_decoding_debug,
899
+ )
900
+
901
+ # Check if LLM returned empty metadata (error case)
902
+ if not metadata:
903
+ return UnderstandResult(
904
+ status_message=status or "Failed to understand audio codes",
905
+ success=False,
906
+ error=status or "Empty metadata returned",
907
+ )
908
+
909
+ # Extract and convert fields
910
+ caption = metadata.get('caption', '')
911
+ lyrics = metadata.get('lyrics', '')
912
+ keyscale = metadata.get('keyscale', '')
913
+ language = metadata.get('language', metadata.get('vocal_language', ''))
914
+ timesignature = metadata.get('timesignature', '')
915
+
916
+ # Convert BPM to int
917
+ bpm = None
918
+ bpm_value = metadata.get('bpm')
919
+ if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
920
+ try:
921
+ bpm = int(bpm_value)
922
+ except (ValueError, TypeError):
923
+ pass
924
+
925
+ # Convert duration to float
926
+ duration = None
927
+ duration_value = metadata.get('duration')
928
+ if duration_value is not None and duration_value != 'N/A' and duration_value != '':
929
+ try:
930
+ duration = float(duration_value)
931
+ except (ValueError, TypeError):
932
+ pass
933
+
934
+ # Clean up N/A values
935
+ if keyscale == 'N/A':
936
+ keyscale = ''
937
+ if language == 'N/A':
938
+ language = ''
939
+ if timesignature == 'N/A':
940
+ timesignature = ''
941
+
942
+ return UnderstandResult(
943
+ caption=caption,
944
+ lyrics=lyrics,
945
+ bpm=bpm,
946
+ duration=duration,
947
+ keyscale=keyscale,
948
+ language=language,
949
+ timesignature=timesignature,
950
+ status_message=status,
951
+ success=True,
952
+ error=None,
953
+ )
954
+
955
+ except Exception as e:
956
+ logger.exception("Music understanding failed")
957
+ return UnderstandResult(
958
+ status_message=f"Error: {str(e)}",
959
+ success=False,
960
+ error=str(e),
961
+ )
962
+
963
+
964
+ @dataclass
965
+ class CreateSampleResult:
966
+ """Result of creating a music sample from a natural language query.
967
+
968
+ This is used by the "Simple Mode" / "Inspiration Mode" feature where users
969
+ provide a natural language description and the LLM generates a complete
970
+ sample with caption, lyrics, and metadata.
971
+
972
+ Attributes:
973
+ # Metadata Fields
974
+ caption: Generated detailed music description/caption
975
+ lyrics: Generated lyrics (or "[Instrumental]" for instrumental music)
976
+ bpm: Beats per minute (None if not generated)
977
+ duration: Duration in seconds (None if not generated)
978
+ keyscale: Musical key (e.g., "C Major")
979
+ language: Vocal language code (e.g., "en", "zh")
980
+ timesignature: Time signature (e.g., "4")
981
+ instrumental: Whether this is an instrumental piece
982
+
983
+ # Status
984
+ status_message: Status message from sample creation
985
+ success: Whether sample creation completed successfully
986
+ error: Error message if sample creation failed
987
+ """
988
+ # Metadata Fields
989
+ caption: str = ""
990
+ lyrics: str = ""
991
+ bpm: Optional[int] = None
992
+ duration: Optional[float] = None
993
+ keyscale: str = ""
994
+ language: str = ""
995
+ timesignature: str = ""
996
+ instrumental: bool = False
997
+
998
+ # Status
999
+ status_message: str = ""
1000
+ success: bool = True
1001
+ error: Optional[str] = None
1002
+
1003
+ def to_dict(self) -> Dict[str, Any]:
1004
+ """Convert result to dictionary for JSON serialization."""
1005
+ return asdict(self)
1006
+
1007
+
1008
+ def create_sample(
1009
+ llm_handler,
1010
+ query: str,
1011
+ instrumental: bool = False,
1012
+ vocal_language: Optional[str] = None,
1013
+ temperature: float = 0.85,
1014
+ top_k: Optional[int] = None,
1015
+ top_p: Optional[float] = None,
1016
+ repetition_penalty: float = 1.0,
1017
+ use_constrained_decoding: bool = True,
1018
+ constrained_decoding_debug: bool = False,
1019
+ ) -> CreateSampleResult:
1020
+ """Create a music sample from a natural language query using the 5Hz Language Model.
1021
+
1022
+ This is the "Simple Mode" / "Inspiration Mode" feature that takes a user's natural
1023
+ language description of music and generates a complete sample including:
1024
+ - Detailed caption/description
1025
+ - Lyrics (unless instrumental)
1026
+ - Metadata (BPM, duration, key, language, time signature)
1027
+
1028
+ Note: cfg_scale and negative_prompt are not supported in create_sample mode.
1029
+
1030
+ Args:
1031
+ llm_handler: Initialized LLM handler (LLMHandler instance)
1032
+ query: User's natural language music description (e.g., "a soft Bengali love song")
1033
+ instrumental: Whether to generate instrumental music (no vocals)
1034
+ vocal_language: Allowed vocal language for constrained decoding (e.g., "en", "zh").
1035
+ If provided, the model will be constrained to generate lyrics in this language.
1036
+ If None or "unknown", no language constraint is applied.
1037
+ temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
1038
+ top_k: Top-K sampling (None or 0 = disabled)
1039
+ top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
1040
+ repetition_penalty: Repetition penalty (1.0 = no penalty)
1041
+ use_constrained_decoding: Whether to use FSM-based constrained decoding
1042
+ constrained_decoding_debug: Whether to enable debug logging
1043
+
1044
+ Returns:
1045
+ CreateSampleResult with generated sample fields and status
1046
+
1047
+ Example:
1048
+ >>> result = create_sample(llm_handler, "a soft Bengali love song for a quiet evening", vocal_language="bn")
1049
+ >>> if result.success:
1050
+ ... print(f"Caption: {result.caption}")
1051
+ ... print(f"Lyrics: {result.lyrics}")
1052
+ ... print(f"BPM: {result.bpm}")
1053
+ """
1054
+ # Check if LLM is initialized
1055
+ if not llm_handler.llm_initialized:
1056
+ return CreateSampleResult(
1057
+ status_message="5Hz LM not initialized. Please initialize it first.",
1058
+ success=False,
1059
+ error="LLM not initialized",
1060
+ )
1061
+
1062
+ try:
1063
+ # Call LLM to create sample
1064
+ metadata, status = llm_handler.create_sample_from_query(
1065
+ query=query,
1066
+ instrumental=instrumental,
1067
+ vocal_language=vocal_language,
1068
+ temperature=temperature,
1069
+ top_k=top_k,
1070
+ top_p=top_p,
1071
+ repetition_penalty=repetition_penalty,
1072
+ use_constrained_decoding=use_constrained_decoding,
1073
+ constrained_decoding_debug=constrained_decoding_debug,
1074
+ )
1075
+
1076
+ # Check if LLM returned empty metadata (error case)
1077
+ if not metadata:
1078
+ return CreateSampleResult(
1079
+ status_message=status or "Failed to create sample",
1080
+ success=False,
1081
+ error=status or "Empty metadata returned",
1082
+ )
1083
+
1084
+ # Extract and convert fields
1085
+ caption = metadata.get('caption', '')
1086
+ lyrics = metadata.get('lyrics', '')
1087
+ keyscale = metadata.get('keyscale', '')
1088
+ language = metadata.get('language', metadata.get('vocal_language', ''))
1089
+ timesignature = metadata.get('timesignature', '')
1090
+ is_instrumental = metadata.get('instrumental', instrumental)
1091
+
1092
+ # Convert BPM to int
1093
+ bpm = None
1094
+ bpm_value = metadata.get('bpm')
1095
+ if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
1096
+ try:
1097
+ bpm = int(bpm_value)
1098
+ except (ValueError, TypeError):
1099
+ pass
1100
+
1101
+ # Convert duration to float
1102
+ duration = None
1103
+ duration_value = metadata.get('duration')
1104
+ if duration_value is not None and duration_value != 'N/A' and duration_value != '':
1105
+ try:
1106
+ duration = float(duration_value)
1107
+ except (ValueError, TypeError):
1108
+ pass
1109
+
1110
+ # Clean up N/A values
1111
+ if keyscale == 'N/A':
1112
+ keyscale = ''
1113
+ if language == 'N/A':
1114
+ language = ''
1115
+ if timesignature == 'N/A':
1116
+ timesignature = ''
1117
+
1118
+ return CreateSampleResult(
1119
+ caption=caption,
1120
+ lyrics=lyrics,
1121
+ bpm=bpm,
1122
+ duration=duration,
1123
+ keyscale=keyscale,
1124
+ language=language,
1125
+ timesignature=timesignature,
1126
+ instrumental=is_instrumental,
1127
+ status_message=status,
1128
+ success=True,
1129
+ error=None,
1130
+ )
1131
+
1132
+ except Exception as e:
1133
+ logger.exception("Sample creation failed")
1134
+ return CreateSampleResult(
1135
+ status_message=f"Error: {str(e)}",
1136
+ success=False,
1137
+ error=str(e),
1138
+ )
1139
+
1140
+
1141
+ @dataclass
1142
+ class FormatSampleResult:
1143
+ """Result of formatting user-provided caption and lyrics.
1144
+
1145
+ This is used by the "Format" feature where users provide caption and lyrics,
1146
+ and the LLM formats them into structured music metadata and an enhanced description.
1147
+
1148
+ Attributes:
1149
+ # Metadata Fields
1150
+ caption: Enhanced/formatted music description/caption
1151
+ lyrics: Formatted lyrics (may be same as input or reformatted)
1152
+ bpm: Beats per minute (None if not detected)
1153
+ duration: Duration in seconds (None if not detected)
1154
+ keyscale: Musical key (e.g., "C Major")
1155
+ language: Vocal language code (e.g., "en", "zh")
1156
+ timesignature: Time signature (e.g., "4")
1157
+
1158
+ # Status
1159
+ status_message: Status message from formatting
1160
+ success: Whether formatting completed successfully
1161
+ error: Error message if formatting failed
1162
+ """
1163
+ # Metadata Fields
1164
+ caption: str = ""
1165
+ lyrics: str = ""
1166
+ bpm: Optional[int] = None
1167
+ duration: Optional[float] = None
1168
+ keyscale: str = ""
1169
+ language: str = ""
1170
+ timesignature: str = ""
1171
+
1172
+ # Status
1173
+ status_message: str = ""
1174
+ success: bool = True
1175
+ error: Optional[str] = None
1176
+
1177
+ def to_dict(self) -> Dict[str, Any]:
1178
+ """Convert result to dictionary for JSON serialization."""
1179
+ return asdict(self)
1180
+
1181
+
1182
+ def format_sample(
1183
+ llm_handler,
1184
+ caption: str,
1185
+ lyrics: str,
1186
+ user_metadata: Optional[Dict[str, Any]] = None,
1187
+ temperature: float = 0.85,
1188
+ top_k: Optional[int] = None,
1189
+ top_p: Optional[float] = None,
1190
+ repetition_penalty: float = 1.0,
1191
+ use_constrained_decoding: bool = True,
1192
+ constrained_decoding_debug: bool = False,
1193
+ ) -> FormatSampleResult:
1194
+ """Format user-provided caption and lyrics using the 5Hz Language Model.
1195
+
1196
+ This function takes user input (caption and lyrics) and generates structured
1197
+ music metadata including an enhanced caption, BPM, duration, key, language,
1198
+ and time signature.
1199
+
1200
+ If user_metadata is provided, those values will be used to constrain the
1201
+ decoding, ensuring the output matches user-specified values.
1202
+
1203
+ Note: cfg_scale and negative_prompt are not supported in format mode.
1204
+
1205
+ Args:
1206
+ llm_handler: Initialized LLM handler (LLMHandler instance)
1207
+ caption: User's caption/description (e.g., "Latin pop, reggaeton")
1208
+ lyrics: User's lyrics with structure tags
1209
+ user_metadata: Optional dict with user-provided metadata to constrain decoding.
1210
+ Supported keys: bpm, duration, keyscale, timesignature, language
1211
+ temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
1212
+ top_k: Top-K sampling (None or 0 = disabled)
1213
+ top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
1214
+ repetition_penalty: Repetition penalty (1.0 = no penalty)
1215
+ use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
1216
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
1217
+
1218
+ Returns:
1219
+ FormatSampleResult with formatted metadata fields and status
1220
+
1221
+ Example:
1222
+ >>> result = format_sample(llm_handler, "Latin pop, reggaeton", "[Verse 1]\\nHola mundo...")
1223
+ >>> if result.success:
1224
+ ... print(f"Caption: {result.caption}")
1225
+ ... print(f"BPM: {result.bpm}")
1226
+ ... print(f"Lyrics: {result.lyrics}")
1227
+ """
1228
+ # Check if LLM is initialized
1229
+ if not llm_handler.llm_initialized:
1230
+ return FormatSampleResult(
1231
+ status_message="5Hz LM not initialized. Please initialize it first.",
1232
+ success=False,
1233
+ error="LLM not initialized",
1234
+ )
1235
+
1236
+ try:
1237
+ # Call LLM formatting
1238
+ metadata, status = llm_handler.format_sample_from_input(
1239
+ caption=caption,
1240
+ lyrics=lyrics,
1241
+ user_metadata=user_metadata,
1242
+ temperature=temperature,
1243
+ top_k=top_k,
1244
+ top_p=top_p,
1245
+ repetition_penalty=repetition_penalty,
1246
+ use_constrained_decoding=use_constrained_decoding,
1247
+ constrained_decoding_debug=constrained_decoding_debug,
1248
+ )
1249
+
1250
+ # Check if LLM returned empty metadata (error case)
1251
+ if not metadata:
1252
+ return FormatSampleResult(
1253
+ status_message=status or "Failed to format input",
1254
+ success=False,
1255
+ error=status or "Empty metadata returned",
1256
+ )
1257
+
1258
+ # Extract and convert fields
1259
+ result_caption = metadata.get('caption', '')
1260
+ result_lyrics = metadata.get('lyrics', lyrics) # Fall back to input lyrics
1261
+ keyscale = metadata.get('keyscale', '')
1262
+ language = metadata.get('language', metadata.get('vocal_language', ''))
1263
+ timesignature = metadata.get('timesignature', '')
1264
+
1265
+ # Convert BPM to int
1266
+ bpm = None
1267
+ bpm_value = metadata.get('bpm')
1268
+ if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
1269
+ try:
1270
+ bpm = int(bpm_value)
1271
+ except (ValueError, TypeError):
1272
+ pass
1273
+
1274
+ # Convert duration to float
1275
+ duration = None
1276
+ duration_value = metadata.get('duration')
1277
+ if duration_value is not None and duration_value != 'N/A' and duration_value != '':
1278
+ try:
1279
+ duration = float(duration_value)
1280
+ except (ValueError, TypeError):
1281
+ pass
1282
+
1283
+ # Clean up N/A values
1284
+ if keyscale == 'N/A':
1285
+ keyscale = ''
1286
+ if language == 'N/A':
1287
+ language = ''
1288
+ if timesignature == 'N/A':
1289
+ timesignature = ''
1290
+
1291
+ return FormatSampleResult(
1292
+ caption=result_caption,
1293
+ lyrics=result_lyrics,
1294
+ bpm=bpm,
1295
+ duration=duration,
1296
+ keyscale=keyscale,
1297
+ language=language,
1298
+ timesignature=timesignature,
1299
+ status_message=status,
1300
+ success=True,
1301
+ error=None,
1302
+ )
1303
+
1304
+ except Exception as e:
1305
+ logger.exception("Format sample failed")
1306
+ return FormatSampleResult(
1307
+ status_message=f"Error: {str(e)}",
1308
+ success=False,
1309
+ error=str(e),
1310
+ )
acestep/llm_inference.py ADDED
The diff for this file is too large to render. See raw diff
 
acestep/local_cache.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Local cache module to replace Redis
2
+
3
+ Uses diskcache as backend, provides Redis-compatible API.
4
+ Supports persistent storage and TTL expiration.
5
+ """
6
+
7
+ import json
8
+ import os
9
+ from typing import Any, Optional
10
+ from threading import Lock
11
+
12
+ try:
13
+ from diskcache import Cache
14
+ HAS_DISKCACHE = True
15
+ except ImportError:
16
+ HAS_DISKCACHE = False
17
+
18
+
19
+ class LocalCache:
20
+ """
21
+ Local cache implementation with Redis-compatible API.
22
+ Uses diskcache as backend, supports persistence and TTL.
23
+ """
24
+
25
+ _instance = None
26
+ _lock = Lock()
27
+
28
+ def __new__(cls, cache_dir: Optional[str] = None):
29
+ """Singleton pattern"""
30
+ if cls._instance is None:
31
+ with cls._lock:
32
+ if cls._instance is None:
33
+ cls._instance = super().__new__(cls)
34
+ cls._instance._initialized = False
35
+ return cls._instance
36
+
37
+ def __init__(self, cache_dir: Optional[str] = None):
38
+ if getattr(self, '_initialized', False):
39
+ return
40
+
41
+ if not HAS_DISKCACHE:
42
+ raise ImportError(
43
+ "diskcache not installed. Run: pip install diskcache"
44
+ )
45
+
46
+ if cache_dir is None:
47
+ cache_dir = os.path.join(
48
+ os.path.dirname(os.path.dirname(__file__)),
49
+ ".cache",
50
+ "local_redis"
51
+ )
52
+
53
+ os.makedirs(cache_dir, exist_ok=True)
54
+ self._cache = Cache(cache_dir)
55
+ self._initialized = True
56
+
57
+ def set(self, name: str, value: Any, ex: Optional[int] = None) -> bool:
58
+ """
59
+ Set key-value pair
60
+
61
+ Args:
62
+ name: Key name
63
+ value: Value (auto-serialize dict/list)
64
+ ex: Expiration time (seconds)
65
+
66
+ Returns:
67
+ bool: Success status
68
+ """
69
+ if isinstance(value, (dict, list)):
70
+ value = json.dumps(value, ensure_ascii=False)
71
+ self._cache.set(name, value, expire=ex)
72
+ return True
73
+
74
+ def get(self, name: str) -> Optional[str]:
75
+ """Get value"""
76
+ return self._cache.get(name)
77
+
78
+ def delete(self, name: str) -> int:
79
+ """Delete key, returns number of deleted items"""
80
+ return 1 if self._cache.delete(name) else 0
81
+
82
+ def exists(self, name: str) -> bool:
83
+ """Check if key exists"""
84
+ return name in self._cache
85
+
86
+ def keys(self, pattern: str = "*") -> list:
87
+ """
88
+ Get list of matching keys
89
+ Note: Simplified implementation, only supports prefix and full matching
90
+ """
91
+ if pattern == "*":
92
+ return list(self._cache.iterkeys())
93
+
94
+ prefix = pattern.rstrip("*")
95
+ return [k for k in self._cache.iterkeys() if k.startswith(prefix)]
96
+
97
+ def expire(self, name: str, seconds: int) -> bool:
98
+ """Set key expiration time"""
99
+ value = self._cache.get(name)
100
+ if value is not None:
101
+ self._cache.set(name, value, expire=seconds)
102
+ return True
103
+ return False
104
+
105
+ def ttl(self, name: str) -> int:
106
+ """
107
+ Get remaining time to live (seconds)
108
+ Note: diskcache does not directly support TTL queries
109
+ """
110
+ if name in self._cache:
111
+ return -1 # Exists but TTL unknown
112
+ return -2 # Key does not exist
113
+
114
+ def close(self):
115
+ """Close cache connection"""
116
+ if hasattr(self, '_cache'):
117
+ self._cache.close()
118
+
119
+
120
+ # Lazily initialized global instance
121
+ _local_cache: Optional[LocalCache] = None
122
+
123
+
124
+ def get_local_cache(cache_dir: Optional[str] = None) -> LocalCache:
125
+ """Get local cache instance"""
126
+ global _local_cache
127
+ if _local_cache is None:
128
+ _local_cache = LocalCache(cache_dir)
129
+ return _local_cache
acestep/model_downloader.py ADDED
@@ -0,0 +1,634 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ACE-Step Model Downloader
3
+
4
+ This module provides functionality to download models from HuggingFace Hub or ModelScope.
5
+ It supports automatic downloading when models are not found locally,
6
+ with intelligent fallback between download sources.
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import argparse
12
+ from typing import Optional, List, Dict, Tuple
13
+ from pathlib import Path
14
+
15
+ from loguru import logger
16
+
17
+
18
+ # =============================================================================
19
+ # Network Detection & Smart Download
20
+ # =============================================================================
21
+
22
+ def _can_access_google(timeout: float = 3.0) -> bool:
23
+ """
24
+ Check if Google is accessible (to determine HuggingFace vs ModelScope).
25
+
26
+ Args:
27
+ timeout: Connection timeout in seconds
28
+
29
+ Returns:
30
+ True if Google is accessible, False otherwise
31
+ """
32
+ import socket
33
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
34
+ try:
35
+ sock.settimeout(timeout)
36
+ sock.connect(("www.google.com", 443))
37
+ return True
38
+ except (socket.timeout, socket.error, OSError):
39
+ return False
40
+ finally:
41
+ sock.close()
42
+
43
+
44
+ def _download_from_huggingface_internal(
45
+ repo_id: str,
46
+ local_dir: Path,
47
+ token: Optional[str] = None,
48
+ ) -> None:
49
+ """
50
+ Internal function to download from HuggingFace Hub.
51
+
52
+ Args:
53
+ repo_id: HuggingFace repository ID (e.g., "ACE-Step/Ace-Step1.5")
54
+ local_dir: Local directory to save the model
55
+ token: HuggingFace token for private repos (optional)
56
+
57
+ Raises:
58
+ Exception: If download fails
59
+ """
60
+ from huggingface_hub import snapshot_download
61
+
62
+ logger.info(f"[Model Download] Downloading from HuggingFace: {repo_id} -> {local_dir}")
63
+
64
+ snapshot_download(
65
+ repo_id=repo_id,
66
+ local_dir=str(local_dir),
67
+ local_dir_use_symlinks=False,
68
+ token=token,
69
+ )
70
+
71
+
72
+ def _download_from_modelscope_internal(
73
+ repo_id: str,
74
+ local_dir: Path,
75
+ ) -> None:
76
+ """
77
+ Internal function to download from ModelScope.
78
+
79
+ Args:
80
+ repo_id: ModelScope repository ID (e.g., "ACE-Step/Ace-Step1.5")
81
+ local_dir: Local directory to save the model
82
+
83
+ Raises:
84
+ Exception: If download fails
85
+ """
86
+ from modelscope import snapshot_download
87
+
88
+ logger.info(f"[Model Download] Downloading from ModelScope: {repo_id} -> {local_dir}")
89
+
90
+ snapshot_download(
91
+ model_id=repo_id,
92
+ local_dir=str(local_dir),
93
+ )
94
+
95
+
96
+ def _smart_download(
97
+ repo_id: str,
98
+ local_dir: Path,
99
+ token: Optional[str] = None,
100
+ prefer_source: Optional[str] = None,
101
+ ) -> Tuple[bool, str]:
102
+ """
103
+ Smart download with automatic fallback between HuggingFace and ModelScope.
104
+
105
+ Automatically detects network environment and chooses the best download source.
106
+ If the primary source fails, automatically falls back to the alternative.
107
+
108
+ Args:
109
+ repo_id: Repository ID (same format for both HF and ModelScope)
110
+ local_dir: Local directory to save the model
111
+ token: HuggingFace token for private repos (optional)
112
+ prefer_source: Preferred download source ("huggingface", "modelscope", or None for auto-detect)
113
+
114
+ Returns:
115
+ Tuple of (success, message)
116
+ """
117
+ # Ensure directory exists
118
+ local_dir.mkdir(parents=True, exist_ok=True)
119
+
120
+ # Determine primary source
121
+ if prefer_source == "huggingface":
122
+ use_huggingface_first = True
123
+ logger.info("[Model Download] User preference: HuggingFace Hub")
124
+ elif prefer_source == "modelscope":
125
+ use_huggingface_first = False
126
+ logger.info("[Model Download] User preference: ModelScope")
127
+ else:
128
+ # Auto-detect network environment
129
+ can_access_google = _can_access_google()
130
+ use_huggingface_first = can_access_google
131
+ logger.info(f"[Model Download] Auto-detected: {'HuggingFace Hub' if can_access_google else 'ModelScope'}")
132
+
133
+ if use_huggingface_first:
134
+ logger.info("[Model Download] Using HuggingFace Hub...")
135
+ try:
136
+ _download_from_huggingface_internal(repo_id, local_dir, token)
137
+ return True, f"Successfully downloaded from HuggingFace: {repo_id}"
138
+ except Exception as e:
139
+ logger.warning(f"[Model Download] HuggingFace download failed: {e}")
140
+ logger.info("[Model Download] Falling back to ModelScope...")
141
+ try:
142
+ _download_from_modelscope_internal(repo_id, local_dir)
143
+ return True, f"Successfully downloaded from ModelScope: {repo_id}"
144
+ except Exception as e2:
145
+ error_msg = f"Both HuggingFace and ModelScope downloads failed. HF: {e}, MS: {e2}"
146
+ logger.error(error_msg)
147
+ return False, error_msg
148
+ else:
149
+ logger.info("[Model Download] Using ModelScope...")
150
+ try:
151
+ _download_from_modelscope_internal(repo_id, local_dir)
152
+ return True, f"Successfully downloaded from ModelScope: {repo_id}"
153
+ except Exception as e:
154
+ logger.warning(f"[Model Download] ModelScope download failed: {e}")
155
+ logger.info("[Model Download] Falling back to HuggingFace Hub...")
156
+ try:
157
+ _download_from_huggingface_internal(repo_id, local_dir, token)
158
+ return True, f"Successfully downloaded from HuggingFace: {repo_id}"
159
+ except Exception as e2:
160
+ error_msg = f"Both ModelScope and HuggingFace downloads failed. MS: {e}, HF: {e2}"
161
+ logger.error(error_msg)
162
+ return False, error_msg
163
+
164
+
165
+ # =============================================================================
166
+ # Model Registry
167
+ # =============================================================================
168
+ # Main model contains core components (vae, text_encoder, default DiT)
169
+ MAIN_MODEL_REPO = "ACE-Step/Ace-Step1.5"
170
+
171
+ # Sub-models that can be downloaded separately into the checkpoints directory
172
+ SUBMODEL_REGISTRY: Dict[str, str] = {
173
+ # LM models
174
+ "acestep-5Hz-lm-0.6B": "ACE-Step/acestep-5Hz-lm-0.6B",
175
+ "acestep-5Hz-lm-4B": "ACE-Step/acestep-5Hz-lm-4B",
176
+ # DiT models
177
+ "acestep-v15-turbo-shift3": "ACE-Step/acestep-v15-turbo-shift3",
178
+ "acestep-v15-sft": "ACE-Step/acestep-v15-sft",
179
+ "acestep-v15-base": "ACE-Step/acestep-v15-base",
180
+ "acestep-v15-turbo-shift1": "ACE-Step/acestep-v15-turbo-shift1",
181
+ "acestep-v15-turbo-continuous": "ACE-Step/acestep-v15-turbo-continuous",
182
+ }
183
+
184
+ # Components that come from the main model repo (ACE-Step/Ace-Step1.5)
185
+ MAIN_MODEL_COMPONENTS = [
186
+ "acestep-v15-turbo", # Default DiT model
187
+ "vae", # VAE for audio encoding/decoding
188
+ "Qwen3-Embedding-0.6B", # Text encoder
189
+ "acestep-5Hz-lm-1.7B", # Default LM model (1.7B)
190
+ ]
191
+
192
+ # Default LM model (included in main model)
193
+ DEFAULT_LM_MODEL = "acestep-5Hz-lm-1.7B"
194
+
195
+
196
+ def get_project_root() -> Path:
197
+ """Get the project root directory."""
198
+ current_file = Path(__file__).resolve()
199
+ return current_file.parent.parent
200
+
201
+
202
+ def get_checkpoints_dir(custom_dir: Optional[str] = None) -> Path:
203
+ """Get the checkpoints directory path."""
204
+ if custom_dir:
205
+ return Path(custom_dir)
206
+ return get_project_root() / "checkpoints"
207
+
208
+
209
+ def check_main_model_exists(checkpoints_dir: Optional[Path] = None) -> bool:
210
+ """
211
+ Check if the main model components exist in the checkpoints directory.
212
+
213
+ Returns:
214
+ True if all main model components exist, False otherwise.
215
+ """
216
+ if checkpoints_dir is None:
217
+ checkpoints_dir = get_checkpoints_dir()
218
+
219
+ for component in MAIN_MODEL_COMPONENTS:
220
+ component_path = checkpoints_dir / component
221
+ if not component_path.exists():
222
+ return False
223
+ return True
224
+
225
+
226
+ def check_model_exists(model_name: str, checkpoints_dir: Optional[Path] = None) -> bool:
227
+ """
228
+ Check if a specific model exists in the checkpoints directory.
229
+
230
+ Args:
231
+ model_name: Name of the model to check
232
+ checkpoints_dir: Custom checkpoints directory (optional)
233
+
234
+ Returns:
235
+ True if the model exists, False otherwise.
236
+ """
237
+ if checkpoints_dir is None:
238
+ checkpoints_dir = get_checkpoints_dir()
239
+
240
+ model_path = checkpoints_dir / model_name
241
+ return model_path.exists()
242
+
243
+
244
+ def list_available_models() -> Dict[str, str]:
245
+ """
246
+ List all available models for download.
247
+
248
+ Returns:
249
+ Dictionary mapping local names to HuggingFace repo IDs.
250
+ """
251
+ models = {
252
+ "main": MAIN_MODEL_REPO,
253
+ **SUBMODEL_REGISTRY
254
+ }
255
+ return models
256
+
257
+
258
+ def download_main_model(
259
+ checkpoints_dir: Optional[Path] = None,
260
+ force: bool = False,
261
+ token: Optional[str] = None,
262
+ prefer_source: Optional[str] = None,
263
+ ) -> Tuple[bool, str]:
264
+ """
265
+ Download the main ACE-Step model from HuggingFace or ModelScope.
266
+
267
+ The main model includes:
268
+ - acestep-v15-turbo (default DiT model)
269
+ - vae (audio encoder/decoder)
270
+ - Qwen3-Embedding-0.6B (text encoder)
271
+ - acestep-5Hz-lm-1.7B (default LM model)
272
+
273
+ Args:
274
+ checkpoints_dir: Custom checkpoints directory (optional)
275
+ force: Force re-download even if model exists
276
+ token: HuggingFace token for private repos (optional)
277
+ prefer_source: Preferred download source ("huggingface", "modelscope", or None for auto-detect)
278
+
279
+ Returns:
280
+ Tuple of (success, message)
281
+ """
282
+ if checkpoints_dir is None:
283
+ checkpoints_dir = get_checkpoints_dir()
284
+
285
+ # Ensure checkpoints directory exists
286
+ checkpoints_dir.mkdir(parents=True, exist_ok=True)
287
+
288
+ if not force and check_main_model_exists(checkpoints_dir):
289
+ return True, f"Main model already exists at {checkpoints_dir}"
290
+
291
+ print(f"Downloading main model from {MAIN_MODEL_REPO}...")
292
+ print(f"Destination: {checkpoints_dir}")
293
+ print("This may take a while depending on your internet connection...")
294
+
295
+ # Use smart download with automatic fallback
296
+ return _smart_download(MAIN_MODEL_REPO, checkpoints_dir, token, prefer_source)
297
+
298
+
299
+ def download_submodel(
300
+ model_name: str,
301
+ checkpoints_dir: Optional[Path] = None,
302
+ force: bool = False,
303
+ token: Optional[str] = None,
304
+ prefer_source: Optional[str] = None,
305
+ ) -> Tuple[bool, str]:
306
+ """
307
+ Download a specific sub-model from HuggingFace or ModelScope.
308
+
309
+ Args:
310
+ model_name: Name of the model to download (must be in SUBMODEL_REGISTRY)
311
+ checkpoints_dir: Custom checkpoints directory (optional)
312
+ force: Force re-download even if model exists
313
+ token: HuggingFace token for private repos (optional)
314
+ prefer_source: Preferred download source ("huggingface", "modelscope", or None for auto-detect)
315
+
316
+ Returns:
317
+ Tuple of (success, message)
318
+ """
319
+ if model_name not in SUBMODEL_REGISTRY:
320
+ available = ", ".join(SUBMODEL_REGISTRY.keys())
321
+ return False, f"Unknown model '{model_name}'. Available models: {available}"
322
+
323
+ if checkpoints_dir is None:
324
+ checkpoints_dir = get_checkpoints_dir()
325
+
326
+ # Ensure checkpoints directory exists
327
+ checkpoints_dir.mkdir(parents=True, exist_ok=True)
328
+
329
+ model_path = checkpoints_dir / model_name
330
+
331
+ if not force and model_path.exists():
332
+ return True, f"Model '{model_name}' already exists at {model_path}"
333
+
334
+ repo_id = SUBMODEL_REGISTRY[model_name]
335
+
336
+ print(f"Downloading {model_name} from {repo_id}...")
337
+ print(f"Destination: {model_path}")
338
+
339
+ # Use smart download with automatic fallback
340
+ return _smart_download(repo_id, model_path, token, prefer_source)
341
+
342
+
343
+ def download_all_models(
344
+ checkpoints_dir: Optional[Path] = None,
345
+ force: bool = False,
346
+ token: Optional[str] = None,
347
+ ) -> Tuple[bool, List[str]]:
348
+ """
349
+ Download all available models.
350
+
351
+ Args:
352
+ checkpoints_dir: Custom checkpoints directory (optional)
353
+ force: Force re-download even if models exist
354
+ token: HuggingFace token for private repos (optional)
355
+
356
+ Returns:
357
+ Tuple of (all_success, list of messages)
358
+ """
359
+ if checkpoints_dir is None:
360
+ checkpoints_dir = get_checkpoints_dir()
361
+
362
+ messages = []
363
+ all_success = True
364
+
365
+ # Download main model first
366
+ success, msg = download_main_model(checkpoints_dir, force, token)
367
+ messages.append(msg)
368
+ if not success:
369
+ all_success = False
370
+
371
+ # Download all sub-models
372
+ for model_name in SUBMODEL_REGISTRY:
373
+ success, msg = download_submodel(model_name, checkpoints_dir, force, token)
374
+ messages.append(msg)
375
+ if not success:
376
+ all_success = False
377
+
378
+ return all_success, messages
379
+
380
+
381
+ def ensure_main_model(
382
+ checkpoints_dir: Optional[Path] = None,
383
+ token: Optional[str] = None,
384
+ prefer_source: Optional[str] = None,
385
+ ) -> Tuple[bool, str]:
386
+ """
387
+ Ensure the main model is available, downloading if necessary.
388
+
389
+ This function is designed to be called during initialization.
390
+ It will only download if the model doesn't exist.
391
+
392
+ Args:
393
+ checkpoints_dir: Custom checkpoints directory (optional)
394
+ token: HuggingFace token for private repos (optional)
395
+ prefer_source: Preferred download source ("huggingface", "modelscope", or None for auto-detect)
396
+
397
+ Returns:
398
+ Tuple of (success, message)
399
+ """
400
+ if checkpoints_dir is None:
401
+ checkpoints_dir = get_checkpoints_dir()
402
+
403
+ if check_main_model_exists(checkpoints_dir):
404
+ return True, "Main model is available"
405
+
406
+ print("\n" + "=" * 60)
407
+ print("Main model not found. Starting automatic download...")
408
+ print("=" * 60 + "\n")
409
+
410
+ return download_main_model(checkpoints_dir, token=token, prefer_source=prefer_source)
411
+
412
+
413
+ def ensure_lm_model(
414
+ model_name: Optional[str] = None,
415
+ checkpoints_dir: Optional[Path] = None,
416
+ token: Optional[str] = None,
417
+ prefer_source: Optional[str] = None,
418
+ ) -> Tuple[bool, str]:
419
+ """
420
+ Ensure an LM model is available, downloading if necessary.
421
+
422
+ Args:
423
+ model_name: Name of the LM model (defaults to DEFAULT_LM_MODEL)
424
+ checkpoints_dir: Custom checkpoints directory (optional)
425
+ token: HuggingFace token for private repos (optional)
426
+ prefer_source: Preferred download source ("huggingface", "modelscope", or None for auto-detect)
427
+
428
+ Returns:
429
+ Tuple of (success, message)
430
+ """
431
+ if model_name is None:
432
+ model_name = DEFAULT_LM_MODEL
433
+
434
+ if checkpoints_dir is None:
435
+ checkpoints_dir = get_checkpoints_dir()
436
+
437
+ if check_model_exists(model_name, checkpoints_dir):
438
+ return True, f"LM model '{model_name}' is available"
439
+
440
+ # Check if this is a known LM model
441
+ if model_name not in SUBMODEL_REGISTRY:
442
+ # Check if it might be a variant name
443
+ for known_model in SUBMODEL_REGISTRY:
444
+ if "lm" in known_model.lower() and model_name.lower() in known_model.lower():
445
+ model_name = known_model
446
+ break
447
+ else:
448
+ return False, f"Unknown LM model: {model_name}"
449
+
450
+ print("\n" + "=" * 60)
451
+ print(f"LM model '{model_name}' not found. Starting automatic download...")
452
+ print("=" * 60 + "\n")
453
+
454
+ return download_submodel(model_name, checkpoints_dir, token=token, prefer_source=prefer_source)
455
+
456
+
457
+ def ensure_dit_model(
458
+ model_name: str,
459
+ checkpoints_dir: Optional[Path] = None,
460
+ token: Optional[str] = None,
461
+ prefer_source: Optional[str] = None,
462
+ ) -> Tuple[bool, str]:
463
+ """
464
+ Ensure a DiT model is available, downloading if necessary.
465
+
466
+ Args:
467
+ model_name: Name of the DiT model
468
+ checkpoints_dir: Custom checkpoints directory (optional)
469
+ token: HuggingFace token for private repos (optional)
470
+ prefer_source: Preferred download source ("huggingface", "modelscope", or None for auto-detect)
471
+
472
+ Returns:
473
+ Tuple of (success, message)
474
+ """
475
+ if checkpoints_dir is None:
476
+ checkpoints_dir = get_checkpoints_dir()
477
+
478
+ if check_model_exists(model_name, checkpoints_dir):
479
+ return True, f"DiT model '{model_name}' is available"
480
+
481
+ # Check if this is the default turbo model (part of main)
482
+ if model_name == "acestep-v15-turbo":
483
+ return ensure_main_model(checkpoints_dir, token, prefer_source)
484
+
485
+ # Check if it's a known sub-model
486
+ if model_name in SUBMODEL_REGISTRY:
487
+ print("\n" + "=" * 60)
488
+ print(f"DiT model '{model_name}' not found. Starting automatic download...")
489
+ print("=" * 60 + "\n")
490
+ return download_submodel(model_name, checkpoints_dir, token=token, prefer_source=prefer_source)
491
+
492
+ return False, f"Unknown DiT model: {model_name}"
493
+
494
+
495
+ def print_model_list():
496
+ """Print formatted list of available models."""
497
+ print("\nAvailable Models for Download:")
498
+ print("=" * 60)
499
+ print("\nSupported Sources: HuggingFace Hub <-> ModelScope (auto-fallback)")
500
+
501
+ print("\n[Main Model]")
502
+ print(f" main -> {MAIN_MODEL_REPO}")
503
+ print(" Contains: vae, Qwen3-Embedding-0.6B, acestep-v15-turbo, acestep-5Hz-lm-1.7B")
504
+
505
+ print("\n[Optional LM Models]")
506
+ for name, repo in SUBMODEL_REGISTRY.items():
507
+ if "lm" in name.lower():
508
+ print(f" {name} -> {repo}")
509
+
510
+ print("\n[Optional DiT Models]")
511
+ for name, repo in SUBMODEL_REGISTRY.items():
512
+ if "lm" not in name.lower():
513
+ print(f" {name} -> {repo}")
514
+
515
+ print("\n" + "=" * 60)
516
+
517
+
518
+ def main():
519
+ """CLI entry point for model downloading."""
520
+ parser = argparse.ArgumentParser(
521
+ description="Download ACE-Step models with automatic fallback (HuggingFace <-> ModelScope)",
522
+ formatter_class=argparse.RawDescriptionHelpFormatter,
523
+ epilog="""
524
+ Examples:
525
+ acestep-download # Download main model (includes LM 1.7B)
526
+ acestep-download --all # Download all available models
527
+ acestep-download --model acestep-v15-sft # Download a specific model
528
+ acestep-download --list # List all available models
529
+
530
+ Network Detection:
531
+ Automatically detects network environment and chooses the best download source:
532
+ - Google accessible -> HuggingFace (fallback to ModelScope)
533
+ - Google blocked -> ModelScope (fallback to HuggingFace)
534
+
535
+ Alternative using huggingface-cli:
536
+ huggingface-cli download ACE-Step/Ace-Step1.5 --local-dir ./checkpoints
537
+ huggingface-cli download ACE-Step/acestep-5Hz-lm-0.6B --local-dir ./checkpoints/acestep-5Hz-lm-0.6B
538
+ """
539
+ )
540
+
541
+ parser.add_argument(
542
+ "--model", "-m",
543
+ type=str,
544
+ help="Specific model to download (use --list to see available models)"
545
+ )
546
+ parser.add_argument(
547
+ "--all", "-a",
548
+ action="store_true",
549
+ help="Download all available models"
550
+ )
551
+ parser.add_argument(
552
+ "--list", "-l",
553
+ action="store_true",
554
+ help="List all available models"
555
+ )
556
+ parser.add_argument(
557
+ "--dir", "-d",
558
+ type=str,
559
+ default=None,
560
+ help="Custom checkpoints directory (default: ./checkpoints)"
561
+ )
562
+ parser.add_argument(
563
+ "--force", "-f",
564
+ action="store_true",
565
+ help="Force re-download even if model exists"
566
+ )
567
+ parser.add_argument(
568
+ "--token", "-t",
569
+ type=str,
570
+ default=None,
571
+ help="HuggingFace token for private repos"
572
+ )
573
+ parser.add_argument(
574
+ "--skip-main",
575
+ action="store_true",
576
+ help="Skip downloading the main model (only download specified sub-model)"
577
+ )
578
+
579
+ args = parser.parse_args()
580
+
581
+ # Handle --list
582
+ if args.list:
583
+ print_model_list()
584
+ return 0
585
+
586
+ # Get checkpoints directory
587
+ checkpoints_dir = get_checkpoints_dir(args.dir) if args.dir else get_checkpoints_dir()
588
+ print(f"Checkpoints directory: {checkpoints_dir}")
589
+
590
+ # Handle --all
591
+ if args.all:
592
+ success, messages = download_all_models(checkpoints_dir, args.force, args.token)
593
+ for msg in messages:
594
+ print(msg)
595
+ return 0 if success else 1
596
+
597
+ # Handle --model
598
+ if args.model:
599
+ if args.model == "main":
600
+ success, msg = download_main_model(checkpoints_dir, args.force, args.token)
601
+ elif args.model in SUBMODEL_REGISTRY:
602
+ # Download main model first if needed (unless --skip-main)
603
+ if not args.skip_main and not check_main_model_exists(checkpoints_dir):
604
+ print("Main model not found. Downloading main model first...")
605
+ main_success, main_msg = download_main_model(checkpoints_dir, args.force, args.token)
606
+ print(main_msg)
607
+ if not main_success:
608
+ return 1
609
+
610
+ success, msg = download_submodel(args.model, checkpoints_dir, args.force, args.token)
611
+ else:
612
+ print(f"Unknown model: {args.model}")
613
+ print("Use --list to see available models")
614
+ return 1
615
+
616
+ print(msg)
617
+ return 0 if success else 1
618
+
619
+ # Default: download main model (includes default LM 1.7B)
620
+ print("Downloading main model (includes vae, text encoder, DiT, and LM 1.7B)...")
621
+
622
+ # Download main model
623
+ success, msg = download_main_model(checkpoints_dir, args.force, args.token)
624
+ print(msg)
625
+
626
+ if success:
627
+ print("\nDownload complete!")
628
+ print(f"Models are available at: {checkpoints_dir}")
629
+
630
+ return 0 if success else 1
631
+
632
+
633
+ if __name__ == "__main__":
634
+ sys.exit(main())
acestep/openrouter_adapter.py ADDED
@@ -0,0 +1,773 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenRouter API adapter for ACE-Step music generation.
2
+
3
+ This module provides OpenRouter-compatible endpoints that wrap the ACE-Step
4
+ music generation API, mounted as a sub-router on the main api_server.
5
+
6
+ All generation requests go through the shared asyncio.Queue, ensuring unified
7
+ GPU scheduling with release_task.
8
+
9
+ Endpoints:
10
+ - POST /v1/chat/completions - Generate music via chat completion format
11
+ - GET /v1/models - List available models (OpenRouter format)
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import asyncio
17
+ import base64
18
+ import json
19
+ import os
20
+ import re
21
+ import tempfile
22
+ import time
23
+ from typing import Any, Dict, List, Optional, Tuple
24
+ from uuid import uuid4
25
+
26
+ from fastapi import APIRouter, HTTPException, Request
27
+ from fastapi.responses import JSONResponse, StreamingResponse
28
+
29
+ from acestep.openrouter_models import (
30
+ AudioConfig,
31
+ ChatCompletionRequest,
32
+ ModelInfo,
33
+ ModelPricing,
34
+ ModelsResponse,
35
+ )
36
+
37
+
38
+ # =============================================================================
39
+ # Constants
40
+ # =============================================================================
41
+
42
+ MODEL_PREFIX = "acestep"
43
+ DEFAULT_AUDIO_FORMAT = "mp3"
44
+
45
+ # Generation timeout for non-streaming requests (seconds)
46
+ GENERATION_TIMEOUT = int(os.environ.get("ACESTEP_GENERATION_TIMEOUT", "600"))
47
+
48
+
49
+ # =============================================================================
50
+ # Helper Functions
51
+ # =============================================================================
52
+
53
+ def _generate_completion_id() -> str:
54
+ """Generate a unique completion ID."""
55
+ return f"chatcmpl-{uuid4().hex[:24]}"
56
+
57
+
58
+ def _get_model_id(model_name: str) -> str:
59
+ """Convert internal model name to OpenRouter model ID."""
60
+ return f"{MODEL_PREFIX}/{model_name}"
61
+
62
+
63
+ def _parse_model_name(model_id: str) -> str:
64
+ """Extract internal model name from OpenRouter model ID."""
65
+ if "/" in model_id:
66
+ return model_id.split("/", 1)[1]
67
+ return model_id
68
+
69
+
70
+ def _audio_to_base64_url(audio_path: str, audio_format: str = "mp3") -> str:
71
+ """Convert audio file to base64 data URL."""
72
+ if not audio_path or not os.path.exists(audio_path):
73
+ return ""
74
+
75
+ mime_types = {
76
+ "mp3": "audio/mpeg",
77
+ "wav": "audio/wav",
78
+ "flac": "audio/flac",
79
+ "ogg": "audio/ogg",
80
+ "m4a": "audio/mp4",
81
+ "aac": "audio/aac",
82
+ }
83
+ mime_type = mime_types.get(audio_format.lower(), "audio/mpeg")
84
+
85
+ with open(audio_path, "rb") as f:
86
+ audio_data = f.read()
87
+
88
+ b64_data = base64.b64encode(audio_data).decode("utf-8")
89
+ return f"data:{mime_type};base64,{b64_data}"
90
+
91
+
92
+ def _format_lm_content(result: Dict[str, Any]) -> str:
93
+ """Format generation result as content string with metadata and lyrics."""
94
+ metas = result.get("metas", {})
95
+ lyrics = result.get("lyrics", "")
96
+
97
+ parts = []
98
+
99
+ # Add metadata section
100
+ meta_lines = []
101
+ caption = metas.get("prompt") or metas.get("caption") or result.get("prompt", "")
102
+ if caption:
103
+ meta_lines.append(f"**Caption:** {caption}")
104
+ if metas.get("bpm") and metas["bpm"] != "N/A":
105
+ meta_lines.append(f"**BPM:** {metas['bpm']}")
106
+ if metas.get("duration") and metas["duration"] != "N/A":
107
+ meta_lines.append(f"**Duration:** {metas['duration']}s")
108
+ if metas.get("keyscale") and metas["keyscale"] != "N/A":
109
+ meta_lines.append(f"**Key:** {metas['keyscale']}")
110
+ if metas.get("timesignature") and metas["timesignature"] != "N/A":
111
+ meta_lines.append(f"**Time Signature:** {metas['timesignature']}")
112
+
113
+ if meta_lines:
114
+ parts.append("## Metadata\n" + "\n".join(meta_lines))
115
+
116
+ # Add lyrics section
117
+ if lyrics and lyrics.strip() and lyrics.strip().lower() not in ("[inst]", "[instrumental]"):
118
+ parts.append(f"## Lyrics\n{lyrics}")
119
+
120
+ if parts:
121
+ return "\n\n".join(parts)
122
+ else:
123
+ return "Music generated successfully."
124
+
125
+
126
+ def _base64_to_temp_file(b64_data: str, audio_format: str = "mp3") -> str:
127
+ """Save base64 audio data to temporary file."""
128
+ if "," in b64_data:
129
+ b64_data = b64_data.split(",", 1)[1]
130
+
131
+ audio_bytes = base64.b64decode(b64_data)
132
+ suffix = f".{audio_format}" if not audio_format.startswith(".") else audio_format
133
+ fd, path = tempfile.mkstemp(suffix=suffix, prefix="openrouter_audio_")
134
+ os.close(fd)
135
+
136
+ with open(path, "wb") as f:
137
+ f.write(audio_bytes)
138
+
139
+ return path
140
+
141
+
142
+ def _extract_tagged_content(text: str) -> Tuple[Optional[str], Optional[str], str]:
143
+ """
144
+ Extract content from <prompt> and <lyrics> tags.
145
+
146
+ Returns:
147
+ (prompt, lyrics, remaining_text)
148
+ """
149
+ prompt = None
150
+ lyrics = None
151
+ remaining = text
152
+
153
+ prompt_match = re.search(r'<prompt>(.*?)</prompt>', text, re.DOTALL | re.IGNORECASE)
154
+ if prompt_match:
155
+ prompt = prompt_match.group(1).strip()
156
+ remaining = remaining.replace(prompt_match.group(0), '').strip()
157
+
158
+ lyrics_match = re.search(r'<lyrics>(.*?)</lyrics>', text, re.DOTALL | re.IGNORECASE)
159
+ if lyrics_match:
160
+ lyrics = lyrics_match.group(1).strip()
161
+ remaining = remaining.replace(lyrics_match.group(0), '').strip()
162
+
163
+ return prompt, lyrics, remaining
164
+
165
+
166
+ def _looks_like_lyrics(text: str) -> bool:
167
+ """Heuristic to detect if text looks like song lyrics."""
168
+ if not text:
169
+ return False
170
+
171
+ lyrics_markers = [
172
+ "[verse", "[chorus", "[bridge", "[intro", "[outro",
173
+ "[hook", "[pre-chorus", "[refrain", "[inst",
174
+ ]
175
+ text_lower = text.lower()
176
+ for marker in lyrics_markers:
177
+ if marker in text_lower:
178
+ return True
179
+
180
+ lines = [line.strip() for line in text.split("\n") if line.strip()]
181
+ if len(lines) >= 4:
182
+ avg_line_length = sum(len(line) for line in lines) / len(lines)
183
+ if avg_line_length < 60:
184
+ return True
185
+
186
+ return False
187
+
188
+
189
+ def _is_instrumental(lyrics: str) -> bool:
190
+ """Check if the music should be instrumental based on lyrics."""
191
+ if not lyrics:
192
+ return True
193
+ lyrics_clean = lyrics.strip().lower()
194
+ if not lyrics_clean:
195
+ return True
196
+ return lyrics_clean in ("[inst]", "[instrumental]")
197
+
198
+
199
+ def _parse_messages(messages: List[Any]) -> Tuple[str, str, List[str], Optional[str], Optional[str]]:
200
+ """
201
+ Parse chat messages to extract prompt, lyrics, sample_query and audio references.
202
+
203
+ Supports two modes:
204
+ 1. Tagged mode: Use <prompt>...</prompt> and <lyrics>...</lyrics> tags
205
+ 2. Heuristic mode: Auto-detect based on content structure
206
+
207
+ Multiple input_audio blocks are collected in order (like multiple images).
208
+ The caller routes them to src_audio / reference_audio based on task_type.
209
+
210
+ Returns:
211
+ (prompt, lyrics, audio_paths, system_instruction, sample_query)
212
+ """
213
+ prompt_parts = []
214
+ lyrics = ""
215
+ sample_query = None
216
+ audio_paths: List[str] = []
217
+ system_instruction = None
218
+ has_tags = False
219
+
220
+ for msg in messages:
221
+ role = msg.role
222
+ content = msg.content
223
+
224
+ if role == "system":
225
+ if isinstance(content, str):
226
+ system_instruction = content
227
+ continue
228
+
229
+ if role != "user":
230
+ continue
231
+
232
+ if isinstance(content, str):
233
+ text = content.strip()
234
+ tagged_prompt, tagged_lyrics, remaining = _extract_tagged_content(text)
235
+ if tagged_prompt is not None or tagged_lyrics is not None:
236
+ has_tags = True
237
+ if tagged_prompt:
238
+ prompt_parts.append(tagged_prompt)
239
+ if tagged_lyrics:
240
+ lyrics = tagged_lyrics
241
+ if remaining:
242
+ prompt_parts.append(remaining)
243
+ else:
244
+ if _looks_like_lyrics(text):
245
+ lyrics = text
246
+ else:
247
+ prompt_parts.append(text)
248
+
249
+ elif isinstance(content, list):
250
+ for part in content:
251
+ if isinstance(part, dict):
252
+ part_type = part.get("type", "")
253
+
254
+ if part_type == "text":
255
+ text = part.get("text", "").strip()
256
+ tagged_prompt, tagged_lyrics, remaining = _extract_tagged_content(text)
257
+ if tagged_prompt is not None or tagged_lyrics is not None:
258
+ has_tags = True
259
+ if tagged_prompt:
260
+ prompt_parts.append(tagged_prompt)
261
+ if tagged_lyrics:
262
+ lyrics = tagged_lyrics
263
+ if remaining:
264
+ prompt_parts.append(remaining)
265
+ elif _looks_like_lyrics(text):
266
+ lyrics = text
267
+ else:
268
+ prompt_parts.append(text)
269
+
270
+ elif part_type == "input_audio":
271
+ audio_data = part.get("input_audio", {})
272
+ if isinstance(audio_data, dict):
273
+ b64_data = audio_data.get("data", "")
274
+ audio_format = audio_data.get("format", "mp3")
275
+ if b64_data:
276
+ try:
277
+ path = _base64_to_temp_file(b64_data, audio_format)
278
+ audio_paths.append(path)
279
+ except Exception:
280
+ pass
281
+
282
+ elif hasattr(part, "type"):
283
+ if part.type == "text":
284
+ text = getattr(part, "text", "").strip()
285
+ tagged_prompt, tagged_lyrics, remaining = _extract_tagged_content(text)
286
+ if tagged_prompt is not None or tagged_lyrics is not None:
287
+ has_tags = True
288
+ if tagged_prompt:
289
+ prompt_parts.append(tagged_prompt)
290
+ if tagged_lyrics:
291
+ lyrics = tagged_lyrics
292
+ if remaining:
293
+ prompt_parts.append(remaining)
294
+ elif _looks_like_lyrics(text):
295
+ lyrics = text
296
+ else:
297
+ prompt_parts.append(text)
298
+
299
+ elif part.type == "input_audio":
300
+ audio_data = getattr(part, "input_audio", None)
301
+ if audio_data:
302
+ b64_data = getattr(audio_data, "data", "")
303
+ audio_format = getattr(audio_data, "format", "mp3")
304
+ if b64_data:
305
+ try:
306
+ path = _base64_to_temp_file(b64_data, audio_format)
307
+ audio_paths.append(path)
308
+ except Exception:
309
+ pass
310
+
311
+ prompt = " ".join(prompt_parts).strip()
312
+
313
+ # Use sample mode when: no tags, no lyrics detected, and we have text input
314
+ if not has_tags and not lyrics and prompt:
315
+ sample_query = prompt
316
+ prompt = ""
317
+
318
+ return prompt, lyrics, audio_paths, system_instruction, sample_query
319
+
320
+
321
+ def _to_generate_music_request(
322
+ req: ChatCompletionRequest,
323
+ prompt: str,
324
+ lyrics: str,
325
+ sample_query: Optional[str],
326
+ reference_audio_path: Optional[str],
327
+ src_audio_path: Optional[str],
328
+ ):
329
+ """
330
+ Convert OpenRouter ChatCompletionRequest to api_server's GenerateMusicRequest.
331
+
332
+ Audio routing depends on task_type:
333
+ text2music: audio[0] → reference_audio
334
+ cover/repaint/lego/…: audio[0] → src_audio, audio[1] → reference_audio
335
+
336
+ task_type auto-detection:
337
+ text2music + reference_audio → music_continuation
338
+
339
+ Uses late import to avoid circular dependency with api_server.
340
+ """
341
+ from acestep.api_server import GenerateMusicRequest
342
+
343
+ audio_config = req.audio_config or AudioConfig()
344
+
345
+ # Resolve parameters from audio_config only
346
+ resolved_instrumental = audio_config.instrumental if audio_config.instrumental is not None else False
347
+
348
+ # If instrumental, set lyrics to [inst]
349
+ resolved_lyrics = lyrics
350
+ if req.lyrics:
351
+ resolved_lyrics = req.lyrics
352
+ if resolved_instrumental and not resolved_lyrics:
353
+ resolved_lyrics = "[inst]"
354
+
355
+ # Resolve sample_mode: explicit field takes priority, then auto-detect from messages
356
+ resolved_sample_mode = req.sample_mode or bool(sample_query)
357
+ resolved_sample_query = sample_query or ""
358
+
359
+ # Resolve seed: pass through as-is (int or comma-separated string)
360
+ # handler.prepare_seeds() handles both formats
361
+ resolved_seed = req.seed if req.seed is not None else -1
362
+ use_random_seed = req.seed is None
363
+
364
+ # Resolve task_type
365
+ # Explicit task_type from request takes priority.
366
+ # For text2music: auto-detect based on reference_audio.
367
+ resolved_task_type = req.task_type
368
+ if resolved_task_type == "text2music" and reference_audio_path:
369
+ resolved_task_type = "music_continuation"
370
+
371
+ return GenerateMusicRequest(
372
+ # Text input
373
+ prompt=prompt,
374
+ lyrics=resolved_lyrics,
375
+ sample_query=resolved_sample_query,
376
+ sample_mode=resolved_sample_mode,
377
+
378
+ # Music metadata
379
+ bpm=audio_config.bpm,
380
+ key_scale=audio_config.key_scale or "",
381
+ time_signature=audio_config.time_signature or "",
382
+ audio_duration=audio_config.duration if audio_config.duration else None,
383
+ vocal_language=audio_config.vocal_language or "en",
384
+
385
+ # LM parameters
386
+ lm_temperature=req.temperature if req.temperature is not None else 0.85,
387
+ lm_top_p=req.top_p if req.top_p is not None else 0.9,
388
+ lm_top_k=req.top_k if req.top_k is not None else 0,
389
+ thinking=req.thinking if req.thinking is not None else False,
390
+
391
+ # Generation parameters
392
+ inference_steps=8,
393
+ guidance_scale=req.guidance_scale if req.guidance_scale is not None else 7.0,
394
+ seed=resolved_seed,
395
+ use_random_seed=use_random_seed,
396
+ batch_size=req.batch_size if req.batch_size is not None else 1,
397
+
398
+ # Task type
399
+ task_type=resolved_task_type,
400
+
401
+ # Audio paths
402
+ reference_audio_path=reference_audio_path or None,
403
+ src_audio_path=src_audio_path or None,
404
+
405
+ # Audio editing
406
+ repainting_start=req.repainting_start,
407
+ repainting_end=req.repainting_end,
408
+ audio_cover_strength=req.audio_cover_strength,
409
+
410
+ # Format / CoT control
411
+ use_format=req.use_format,
412
+ use_cot_caption=req.use_cot_caption,
413
+ use_cot_language=req.use_cot_language,
414
+
415
+ # Model selection
416
+ model=_parse_model_name(req.model),
417
+
418
+ # Audio format
419
+ audio_format=(audio_config.format or DEFAULT_AUDIO_FORMAT),
420
+ )
421
+
422
+
423
+ def _build_openrouter_response(
424
+ rec: Any,
425
+ model_id: str,
426
+ audio_format: str,
427
+ ) -> JSONResponse:
428
+ """Build OpenRouter non-streaming response from a completed JobRecord."""
429
+ if rec.status != "succeeded" or not rec.result:
430
+ error_msg = rec.error or "Generation failed"
431
+ raise HTTPException(status_code=500, detail=error_msg)
432
+
433
+ result = rec.result
434
+ completion_id = _generate_completion_id()
435
+ created_timestamp = int(time.time())
436
+
437
+ text_content = _format_lm_content(result)
438
+
439
+ # Encode audio
440
+ audio_obj = None
441
+ raw_audio_paths = result.get("raw_audio_paths", [])
442
+ if raw_audio_paths:
443
+ audio_path = raw_audio_paths[0]
444
+ if audio_path and os.path.exists(audio_path):
445
+ b64_url = _audio_to_base64_url(audio_path, audio_format)
446
+ if b64_url:
447
+ audio_obj = [{
448
+ "type": "audio_url",
449
+ "audio_url": {"url": b64_url},
450
+ }]
451
+
452
+ response_data = {
453
+ "id": completion_id,
454
+ "object": "chat.completion",
455
+ "created": created_timestamp,
456
+ "model": model_id,
457
+ "choices": [{
458
+ "index": 0,
459
+ "message": {
460
+ "role": "assistant",
461
+ "content": text_content,
462
+ "audio": audio_obj,
463
+ },
464
+ "finish_reason": "stop",
465
+ }],
466
+ "usage": {
467
+ "prompt_tokens": 0,
468
+ "completion_tokens": 0,
469
+ "total_tokens": 0,
470
+ },
471
+ }
472
+
473
+ return JSONResponse(content=response_data)
474
+
475
+
476
+ async def _openrouter_stream_generator(
477
+ rec: Any,
478
+ model_id: str,
479
+ audio_format: str,
480
+ ):
481
+ """
482
+ SSE stream generator that reads from rec.progress_queue.
483
+
484
+ Yields heartbeat chunks every 2 seconds while waiting for the
485
+ queue worker to push the generation result.
486
+ """
487
+ completion_id = _generate_completion_id()
488
+ created_timestamp = int(time.time())
489
+
490
+ def _make_chunk(
491
+ content: Optional[str] = None,
492
+ role: Optional[str] = None,
493
+ audio: Optional[Any] = None,
494
+ finish_reason: Optional[str] = None,
495
+ ) -> str:
496
+ delta = {}
497
+ if role:
498
+ delta["role"] = role
499
+ if content is not None:
500
+ delta["content"] = content
501
+ if audio is not None:
502
+ delta["audio"] = audio
503
+
504
+ chunk = {
505
+ "id": completion_id,
506
+ "object": "chat.completion.chunk",
507
+ "created": created_timestamp,
508
+ "model": model_id,
509
+ "choices": [{
510
+ "index": 0,
511
+ "delta": delta,
512
+ "finish_reason": finish_reason,
513
+ }],
514
+ }
515
+ return f"data: {json.dumps(chunk)}\n\n"
516
+
517
+ # Initial role chunk
518
+ yield _make_chunk(role="assistant", content="Generating music")
519
+ await asyncio.sleep(0)
520
+
521
+ # Wait for result with periodic heartbeats
522
+ while True:
523
+ try:
524
+ msg = await asyncio.wait_for(rec.progress_queue.get(), timeout=2.0)
525
+ except asyncio.TimeoutError:
526
+ yield _make_chunk(content=".")
527
+ await asyncio.sleep(0)
528
+ continue
529
+
530
+ msg_type = msg.get("type")
531
+
532
+ if msg_type == "done":
533
+ break
534
+
535
+ elif msg_type == "error":
536
+ yield _make_chunk(content=f"\n\nError: {msg.get('content', 'Unknown error')}")
537
+ yield _make_chunk(finish_reason="error")
538
+ yield "data: [DONE]\n\n"
539
+ return
540
+
541
+ elif msg_type == "result":
542
+ result = msg.get("result", {})
543
+
544
+ # Send LM content
545
+ lm_content = _format_lm_content(result)
546
+ yield _make_chunk(content=f"\n\n{lm_content}")
547
+ await asyncio.sleep(0)
548
+
549
+ # Send audio
550
+ raw_audio_paths = result.get("raw_audio_paths", [])
551
+ if raw_audio_paths:
552
+ audio_path = raw_audio_paths[0]
553
+ if audio_path and os.path.exists(audio_path):
554
+ b64_url = _audio_to_base64_url(audio_path, audio_format)
555
+ if b64_url:
556
+ audio_list = [{
557
+ "type": "audio_url",
558
+ "audio_url": {"url": b64_url},
559
+ }]
560
+ yield _make_chunk(audio=audio_list)
561
+ await asyncio.sleep(0)
562
+
563
+ # Finish
564
+ yield _make_chunk(finish_reason="stop")
565
+ yield "data: [DONE]\n\n"
566
+
567
+
568
+ # =============================================================================
569
+ # Router Factory
570
+ # =============================================================================
571
+
572
+ def create_openrouter_router(app_state_getter) -> APIRouter:
573
+ """
574
+ Create OpenRouter-compatible API router.
575
+
576
+ Args:
577
+ app_state_getter: Callable that returns the FastAPI app.state object
578
+
579
+ Returns:
580
+ APIRouter with OpenRouter-compatible endpoints
581
+ """
582
+ router = APIRouter(tags=["OpenRouter Compatible"])
583
+
584
+ def _get_model_name_from_path(config_path: str) -> str:
585
+ """Extract model name from config path."""
586
+ if not config_path:
587
+ return ""
588
+ normalized = config_path.rstrip("/\\")
589
+ return os.path.basename(normalized)
590
+
591
+ @router.get("/v1/models", response_model=ModelsResponse)
592
+ async def list_models():
593
+ """List available models in OpenRouter format."""
594
+ state = app_state_getter()
595
+ models = []
596
+ created_timestamp = int(time.time()) - 86400 * 30
597
+
598
+ # Primary model
599
+ if getattr(state, "_initialized", False):
600
+ model_name = _get_model_name_from_path(state._config_path)
601
+ if model_name:
602
+ models.append(ModelInfo(
603
+ id=_get_model_id(model_name),
604
+ name=f"ACE-Step {model_name}",
605
+ created=created_timestamp,
606
+ input_modalities=["text", "audio"],
607
+ output_modalities=["audio", "text"],
608
+ context_length=4096,
609
+ max_output_length=300,
610
+ pricing=ModelPricing(
611
+ prompt="0", completion="0", request="0",
612
+ ),
613
+ description="AI music generation model",
614
+ ))
615
+
616
+ # Secondary model
617
+ if getattr(state, "_initialized2", False) and getattr(state, "_config_path2", ""):
618
+ model_name = _get_model_name_from_path(state._config_path2)
619
+ if model_name:
620
+ models.append(ModelInfo(
621
+ id=_get_model_id(model_name),
622
+ name=f"ACE-Step {model_name}",
623
+ created=created_timestamp,
624
+ input_modalities=["text", "audio"],
625
+ output_modalities=["audio", "text"],
626
+ context_length=4096,
627
+ max_output_length=300,
628
+ pricing=ModelPricing(),
629
+ description="AI music generation model",
630
+ ))
631
+
632
+ # Third model
633
+ if getattr(state, "_initialized3", False) and getattr(state, "_config_path3", ""):
634
+ model_name = _get_model_name_from_path(state._config_path3)
635
+ if model_name:
636
+ models.append(ModelInfo(
637
+ id=_get_model_id(model_name),
638
+ name=f"ACE-Step {model_name}",
639
+ created=created_timestamp,
640
+ input_modalities=["text", "audio"],
641
+ output_modalities=["audio", "text"],
642
+ context_length=4096,
643
+ max_output_length=300,
644
+ pricing=ModelPricing(),
645
+ description="AI music generation model",
646
+ ))
647
+
648
+ return ModelsResponse(data=models)
649
+
650
+ @router.post("/v1/chat/completions")
651
+ async def chat_completions(request: Request):
652
+ """
653
+ OpenRouter-compatible chat completions endpoint for music generation.
654
+
655
+ Submits the request to the shared asyncio.Queue and waits for completion.
656
+ Supports both streaming (SSE) and non-streaming responses.
657
+ """
658
+ state = app_state_getter()
659
+
660
+ # Check initialization
661
+ if not getattr(state, "_initialized", False):
662
+ raise HTTPException(
663
+ status_code=503,
664
+ detail=f"Model not initialized. init_error={getattr(state, '_init_error', None)}"
665
+ )
666
+
667
+ # Parse request
668
+ try:
669
+ body = await request.json()
670
+ req = ChatCompletionRequest(**body)
671
+ except Exception as e:
672
+ raise HTTPException(status_code=400, detail=f"Invalid request format: {str(e)}")
673
+
674
+ # Parse messages for text, audio, and system instruction
675
+ prompt, lyrics, audio_paths, system_instruction, sample_query = _parse_messages(req.messages)
676
+
677
+ # When lyrics or sample_mode is explicitly provided, the message text role
678
+ # is already known — skip auto-detection results.
679
+ # _parse_messages may have put raw text into prompt or sample_query;
680
+ # recover it as raw_text for re-assignment.
681
+ if req.lyrics or req.sample_mode:
682
+ raw_text = prompt or sample_query or ""
683
+ if req.lyrics:
684
+ # lyrics provided → message text is the prompt
685
+ prompt = raw_text
686
+ lyrics = req.lyrics
687
+ sample_query = None
688
+ else:
689
+ # sample_mode → message text is the sample_query
690
+ prompt = ""
691
+ lyrics = ""
692
+ sample_query = raw_text
693
+
694
+ if not prompt and not lyrics and not sample_query and not req.sample_mode and not audio_paths:
695
+ raise HTTPException(
696
+ status_code=400,
697
+ detail="No valid prompt, lyrics, sample query, or input audio found in request"
698
+ )
699
+
700
+ # Route audio paths based on task_type.
701
+ # Multiple input_audio blocks are supported (like multiple images).
702
+ #
703
+ # For cover / repaint / lego / extract / complete:
704
+ # audio[0] → src_audio (primary: the audio to edit / cover)
705
+ # audio[1] → reference_audio (optional: style conditioning)
706
+ #
707
+ # For text2music (default):
708
+ # audio[0] → reference_audio (style conditioning → music_continuation)
709
+ reference_audio_path = None
710
+ src_audio_path = None
711
+ _SRC_AUDIO_TASK_TYPES = {"cover", "repaint", "lego", "extract", "complete"}
712
+ if audio_paths:
713
+ if req.task_type in _SRC_AUDIO_TASK_TYPES:
714
+ src_audio_path = audio_paths[0]
715
+ if len(audio_paths) > 1:
716
+ reference_audio_path = audio_paths[1]
717
+ else:
718
+ reference_audio_path = audio_paths[0]
719
+
720
+ # Convert to GenerateMusicRequest
721
+ gen_request = _to_generate_music_request(
722
+ req, prompt, lyrics, sample_query, reference_audio_path, src_audio_path
723
+ )
724
+
725
+ # Check queue capacity
726
+ job_queue = state.job_queue
727
+ if job_queue.full():
728
+ raise HTTPException(status_code=429, detail="Server busy: queue is full")
729
+
730
+ # Get audio format
731
+ audio_config = req.audio_config or AudioConfig()
732
+ audio_format = audio_config.format or DEFAULT_AUDIO_FORMAT
733
+
734
+ # Create job record and submit to queue
735
+ job_store = state.job_store
736
+ rec = job_store.create()
737
+
738
+ # Track temp files from base64 audio uploads
739
+ if audio_paths:
740
+ async with state.job_temp_files_lock:
741
+ state.job_temp_files.setdefault(rec.job_id, []).extend(audio_paths)
742
+
743
+ if req.stream:
744
+ # Streaming: use progress_queue
745
+ rec.progress_queue = asyncio.Queue()
746
+
747
+ async with state.pending_lock:
748
+ state.pending_ids.append(rec.job_id)
749
+
750
+ await job_queue.put((rec.job_id, gen_request))
751
+
752
+ return StreamingResponse(
753
+ _openrouter_stream_generator(rec, req.model, audio_format),
754
+ media_type="text/event-stream",
755
+ )
756
+ else:
757
+ # Non-streaming: use done_event
758
+ rec.done_event = asyncio.Event()
759
+
760
+ async with state.pending_lock:
761
+ state.pending_ids.append(rec.job_id)
762
+
763
+ await job_queue.put((rec.job_id, gen_request))
764
+
765
+ # Wait for completion with timeout
766
+ try:
767
+ await asyncio.wait_for(rec.done_event.wait(), timeout=GENERATION_TIMEOUT)
768
+ except asyncio.TimeoutError:
769
+ raise HTTPException(status_code=504, detail="Generation timeout")
770
+
771
+ return _build_openrouter_response(rec, req.model, audio_format)
772
+
773
+ return router
acestep/openrouter_models.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenRouter API compatible Pydantic models for ACE-Step.
2
+
3
+ This module defines request/response models that conform to OpenRouter's
4
+ chat completions API specification for audio generation.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import Any, Dict, List, Literal, Optional, Union
10
+ from pydantic import BaseModel, Field
11
+
12
+
13
+ # =============================================================================
14
+ # Request Models
15
+ # =============================================================================
16
+
17
+ class AudioInputContent(BaseModel):
18
+ """Audio input content in base64 format."""
19
+ data: str = Field(..., description="Base64-encoded audio data")
20
+ format: str = Field(default="mp3", description="Audio format (mp3, wav, flac, etc.)")
21
+
22
+
23
+ class TextContent(BaseModel):
24
+ """Text content block."""
25
+ type: Literal["text"] = "text"
26
+ text: str = Field(..., description="Text content")
27
+
28
+
29
+ class AudioContent(BaseModel):
30
+ """Audio input content block."""
31
+ type: Literal["input_audio"] = "input_audio"
32
+ input_audio: AudioInputContent
33
+
34
+
35
+ # Union type for message content
36
+ ContentPart = Union[TextContent, AudioContent, Dict[str, Any]]
37
+
38
+
39
+ class ChatMessage(BaseModel):
40
+ """A single message in the chat conversation."""
41
+ role: Literal["system", "user", "assistant"] = Field(..., description="Message role")
42
+ content: Union[str, List[ContentPart]] = Field(..., description="Message content")
43
+ name: Optional[str] = Field(default=None, description="Optional name for the message author")
44
+
45
+
46
+ class AudioConfig(BaseModel):
47
+ """Audio generation configuration."""
48
+ duration: Optional[float] = Field(default=None, description="Target audio duration in seconds")
49
+ format: str = Field(default="mp3", description="Output audio format")
50
+ # ACE-Step specific parameters
51
+ bpm: Optional[int] = Field(default=None, description="Beats per minute")
52
+ key_scale: Optional[str] = Field(default=None, description="Musical key and scale")
53
+ time_signature: Optional[str] = Field(default=None, description="Time signature (e.g., 4/4)")
54
+ vocal_language: Optional[str] = Field(default=None, description="Vocal language code")
55
+ instrumental: Optional[bool] = Field(default=None, description="Generate instrumental only")
56
+
57
+
58
+ class ChatCompletionRequest(BaseModel):
59
+ """OpenRouter-compatible chat completion request."""
60
+ model: str = Field(..., description="Model ID to use")
61
+ messages: List[ChatMessage] = Field(..., description="List of messages")
62
+
63
+ # Modalities
64
+ modalities: Optional[List[str]] = Field(
65
+ default=None,
66
+ description="Output modalities (e.g., ['audio', 'text'])"
67
+ )
68
+
69
+ # Audio configuration
70
+ audio_config: Optional[AudioConfig] = Field(
71
+ default=None,
72
+ description="Audio generation configuration"
73
+ )
74
+
75
+ # Standard OpenAI parameters
76
+ temperature: Optional[float] = Field(default=None, ge=0, le=2)
77
+ top_p: Optional[float] = Field(default=None, ge=0, le=1)
78
+ top_k: Optional[int] = Field(default=None, ge=0)
79
+ max_tokens: Optional[int] = Field(default=None, ge=1)
80
+ stream: bool = Field(default=False, description="Enable streaming response")
81
+ stop: Optional[Union[str, List[str]]] = Field(default=None)
82
+ seed: Optional[Union[int, str]] = Field(default=None, description="Seed(s) for reproducibility. Comma-separated for batch (e.g. '42,123,456')")
83
+
84
+ # ACE-Step specific parameters (extended)
85
+ thinking: Optional[bool] = Field(default=None, description="Use LM for audio code generation")
86
+ guidance_scale: Optional[float] = Field(default=None, description="Classifier-free guidance scale")
87
+ batch_size: Optional[int] = Field(default=None, description="Number of audio samples to generate")
88
+
89
+ # ACE-Step direct fields (bypass message parsing / audio_config)
90
+ lyrics: str = Field(default="", description="Direct lyrics input (bypass message parsing)")
91
+ sample_mode: bool = Field(default=False, description="Auto-generate caption/lyrics/metas via LM; user message becomes the query")
92
+ use_format: bool = Field(default=False, description="Use format_sample to enhance caption/lyrics")
93
+ use_cot_caption: bool = Field(default=True, description="Use CoT for caption rewriting")
94
+ use_cot_language: bool = Field(default=True, description="Use CoT for language detection")
95
+
96
+ # Task type
97
+ task_type: str = Field(default="text2music", description="Task type: text2music, cover, repaint, extract, lego, complete")
98
+
99
+ # Audio editing parameters
100
+ repainting_start: float = Field(default=0.0, description="Repainting region start (seconds)")
101
+ repainting_end: Optional[float] = Field(default=None, description="Repainting region end (seconds)")
102
+ audio_cover_strength: float = Field(default=1.0, description="Audio cover strength (0.0~1.0)")
103
+
104
+ class Config:
105
+ extra = "allow" # Allow additional fields for forward compatibility
106
+
107
+
108
+ # =============================================================================
109
+ # Response Models
110
+ # =============================================================================
111
+
112
+ class AudioOutputUrl(BaseModel):
113
+ """Audio output URL (base64 data URL)."""
114
+ url: str = Field(..., description="Base64 data URL of the audio")
115
+
116
+
117
+ class AudioOutput(BaseModel):
118
+ """Audio output content block."""
119
+ type: Literal["audio_url"] = "audio_url"
120
+ audio_url: AudioOutputUrl
121
+
122
+
123
+ class AssistantMessage(BaseModel):
124
+ """Assistant response message."""
125
+ role: Literal["assistant"] = "assistant"
126
+ content: Optional[str] = Field(default=None, description="Text content")
127
+ audio: Optional[List[AudioOutput]] = Field(default=None, description="Generated audio files")
128
+
129
+
130
+ class Choice(BaseModel):
131
+ """A single completion choice."""
132
+ index: int = Field(default=0)
133
+ message: AssistantMessage
134
+ finish_reason: Literal["stop", "length", "content_filter", "error"] = "stop"
135
+
136
+
137
+ class Usage(BaseModel):
138
+ """Token usage statistics."""
139
+ prompt_tokens: int = 0
140
+ completion_tokens: int = 0
141
+ total_tokens: int = 0
142
+
143
+
144
+ class ChatCompletionResponse(BaseModel):
145
+ """OpenRouter-compatible chat completion response."""
146
+ id: str = Field(..., description="Unique completion ID")
147
+ object: Literal["chat.completion"] = "chat.completion"
148
+ created: int = Field(..., description="Unix timestamp")
149
+ model: str = Field(..., description="Model ID used")
150
+ choices: List[Choice] = Field(..., description="Completion choices")
151
+ usage: Usage = Field(default_factory=Usage)
152
+
153
+ # Extended metadata
154
+ system_fingerprint: Optional[str] = Field(default=None)
155
+
156
+
157
+ # =============================================================================
158
+ # Streaming Response Models
159
+ # =============================================================================
160
+
161
+ class DeltaContent(BaseModel):
162
+ """Delta content for streaming."""
163
+ role: Optional[Literal["assistant"]] = None
164
+ content: Optional[str] = None
165
+ audio: Optional[List[AudioOutput]] = None
166
+
167
+
168
+ class StreamChoice(BaseModel):
169
+ """Streaming choice."""
170
+ index: int = 0
171
+ delta: DeltaContent
172
+ finish_reason: Optional[Literal["stop", "length", "content_filter", "error"]] = None
173
+
174
+
175
+ class ChatCompletionChunk(BaseModel):
176
+ """Streaming chunk response."""
177
+ id: str
178
+ object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
179
+ created: int
180
+ model: str
181
+ choices: List[StreamChoice]
182
+
183
+
184
+ # =============================================================================
185
+ # Models Endpoint Response
186
+ # =============================================================================
187
+
188
+ class ModelPricing(BaseModel):
189
+ """Model pricing information."""
190
+ prompt: str = Field(default="0", description="Price per prompt token in USD")
191
+ completion: str = Field(default="0", description="Price per completion token in USD")
192
+ request: str = Field(default="0", description="Price per request in USD")
193
+ image: str = Field(default="0", description="Price per image in USD")
194
+
195
+
196
+ class ModelInfo(BaseModel):
197
+ """OpenRouter-compatible model information."""
198
+ id: str = Field(..., description="Model identifier")
199
+ name: str = Field(..., description="Display name")
200
+ created: int = Field(..., description="Unix timestamp of creation")
201
+
202
+ # Modalities
203
+ input_modalities: List[str] = Field(
204
+ default_factory=lambda: ["text"],
205
+ description="Supported input modalities"
206
+ )
207
+ output_modalities: List[str] = Field(
208
+ default_factory=lambda: ["audio", "text"],
209
+ description="Supported output modalities"
210
+ )
211
+
212
+ # Limits
213
+ context_length: int = Field(default=4096, description="Maximum context length")
214
+ max_output_length: int = Field(default=300, description="Maximum output length in seconds")
215
+
216
+ # Pricing
217
+ pricing: ModelPricing = Field(default_factory=ModelPricing)
218
+
219
+ # Metadata
220
+ description: Optional[str] = Field(default=None)
221
+ architecture: Optional[Dict[str, Any]] = Field(default=None)
222
+
223
+
224
+ class ModelsResponse(BaseModel):
225
+ """Response for /v1/models endpoint."""
226
+ object: Literal["list"] = "list"
227
+ data: List[ModelInfo] = Field(default_factory=list)
228
+
229
+
230
+ # =============================================================================
231
+ # Error Response
232
+ # =============================================================================
233
+
234
+ class ErrorDetail(BaseModel):
235
+ """Error detail information."""
236
+ message: str
237
+ type: str = "invalid_request_error"
238
+ param: Optional[str] = None
239
+ code: Optional[str] = None
240
+
241
+
242
+ class ErrorResponse(BaseModel):
243
+ """OpenRouter-compatible error response."""
244
+ error: ErrorDetail
acestep/test_time_scaling.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test-Time Scaling Module
3
+ Implements perplexity-based scoring for generated audio codes
4
+ """
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from typing import Tuple, Optional, Dict, Any, List
8
+ from loguru import logger
9
+ import yaml
10
+ import math
11
+ import re
12
+
13
+
14
+ def pmi_score(log_prob_conditional: float, log_prob_unconditional: float) -> float:
15
+ """
16
+ Calculate Pointwise Mutual Information (PMI) score.
17
+
18
+ PMI = log P(condition|codes) - log P(condition)
19
+ = log [P(codes|condition) / P(codes)]
20
+
21
+ This removes the bias from P(condition) and measures how much the codes
22
+ improve our ability to predict the condition.
23
+
24
+ Args:
25
+ log_prob_conditional: Average log probability of condition given codes
26
+ log_prob_unconditional: Average log probability of condition without codes
27
+
28
+ Returns:
29
+ PMI score (higher is better, can be positive or negative)
30
+ - Positive: codes improve prediction → good match
31
+ - Zero: codes don't help → no correlation
32
+ - Negative: codes hurt prediction → poor match
33
+ """
34
+ return log_prob_conditional - log_prob_unconditional
35
+
36
+
37
+ def pmi_to_normalized_score(pmi: float, scale: float = 0.1) -> float:
38
+ """
39
+ Convert PMI score to normalized [0, 1] range using sigmoid function.
40
+
41
+ score = sigmoid(PMI / scale) = 1 / (1 + exp(-PMI / scale))
42
+
43
+ Args:
44
+ pmi: PMI score (can be positive or negative)
45
+ scale: Scale parameter to control sensitivity (default 0.1)
46
+ - Smaller scale: more sensitive to PMI changes
47
+ - Larger scale: less sensitive to PMI changes
48
+
49
+ Returns:
50
+ Normalized score in [0, 1] range, where:
51
+ - PMI > 0 → score > 0.5 (good match)
52
+ - PMI = 0 → score = 0.5 (neutral)
53
+ - PMI < 0 → score < 0.5 (poor match)
54
+
55
+ Examples (scale=1.0):
56
+ PMI=2.0 → score≈0.88 (excellent)
57
+ PMI=1.0 → score≈0.73 (good)
58
+ PMI=0.0 → score=0.50 (neutral)
59
+ PMI=-1.0 → score≈0.27 (poor)
60
+ PMI=-2.0 → score≈0.12 (bad)
61
+ """
62
+ return 1.0 / (1.0 + math.exp(-pmi / scale))
63
+
64
+
65
+ def _get_logits_and_target_for_scoring(llm_handler, formatted_prompt: str,
66
+ target_text: str) -> Tuple[torch.Tensor, torch.Tensor]:
67
+ """
68
+ Args:
69
+ llm_handler: The handler containing the model and tokenizer.
70
+ formatted_prompt: The input context.
71
+ target_text: The text we want to calculate probability/recall for.
72
+
73
+ Returns:
74
+ Tuple of (target_logits, target_ids)
75
+ - target_logits: Logits used to predict the target tokens.
76
+ - target_ids: The ground truth token IDs of the target.
77
+ """
78
+ model = llm_handler.get_hf_model_for_scoring()
79
+ tokenizer = llm_handler.llm_tokenizer
80
+ device = llm_handler.device if llm_handler.llm_backend == "pt" else next(model.parameters()).device
81
+
82
+ # 1. Tokenize prompt ONLY to get its length (used for slicing later).
83
+ # We must ensure special tokens are added to count the offset correctly.
84
+ prompt_tokens_temp = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=True)
85
+ prompt_len = prompt_tokens_temp['input_ids'].shape[1]
86
+
87
+ # 2. Tokenize the FULL text (Prompt + Target).
88
+ # This ensures subword merging at boundaries is handled correctly by the tokenizer.
89
+ full_text = formatted_prompt + target_text
90
+ full_tokens = tokenizer(full_text, return_tensors="pt", padding=False, truncation=True, add_special_tokens=True).to(device)
91
+
92
+ input_ids = full_tokens['input_ids']
93
+
94
+ # Safety check: if target was empty or truncated entirely
95
+ if input_ids.shape[1] <= prompt_len:
96
+ return torch.empty(0, device=device), torch.empty(0, device=device)
97
+
98
+ # 3. Forward Pass (Teacher Forcing)
99
+ with torch.no_grad():
100
+ with llm_handler._load_model_context():
101
+ outputs = model(input_ids=input_ids, attention_mask=full_tokens['attention_mask'])
102
+ all_logits = outputs.logits # [1, seq_len, vocab_size]
103
+
104
+ # 4. Extract Logits and Labels
105
+ # We need to predict `input_ids[i]`. The logit for this is at `all_logits[i-1]`.
106
+ # Target starts at index `prompt_len`.
107
+ # So we need logits from `prompt_len - 1` up to the second to last position.
108
+
109
+ target_logits = all_logits[0, prompt_len - 1:-1, :] # [target_len, vocab_size]
110
+ target_ids = input_ids[0, prompt_len:] # [target_len]
111
+
112
+ return target_logits, target_ids
113
+
114
+
115
+ # ==============================================================================
116
+ # Scoring Logic
117
+ # ==============================================================================
118
+
119
+
120
+ def _calculate_topk_recall(llm_handler,
121
+ formatted_prompt: str,
122
+ target_text: str,
123
+ topk: int = 10) -> Tuple[float, Dict[int, float]]:
124
+ """
125
+ Calculate top-k recall for target text given prompt.
126
+ Checks if the ground truth token is within the top-k probabilities at each step.
127
+ """
128
+ # Use the fixed helper to get aligned logits/labels
129
+ pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)
130
+
131
+ if target_ids.shape[0] == 0:
132
+ return 0.0, {}
133
+
134
+ target_len = target_ids.shape[0]
135
+
136
+ # Get top-k indices for all positions at once
137
+ # topk_indices: [target_len, topk]
138
+ _, topk_indices = torch.topk(pred_logits, k=min(topk, pred_logits.shape[-1]), dim=-1)
139
+
140
+ recall_per_k = {}
141
+ position_scores = []
142
+
143
+ # Convert to list for faster CPU iteration
144
+ target_ids_list = target_ids.tolist()
145
+ topk_indices_list = topk_indices.tolist()
146
+
147
+ for k in range(1, topk + 1):
148
+ hits = 0
149
+ for pos in range(target_len):
150
+ gt_token = target_ids_list[pos]
151
+ # Check the top-k slice
152
+ topk_at_pos = topk_indices_list[pos][:k]
153
+
154
+ if gt_token in topk_at_pos:
155
+ hits += 1
156
+ # Calculate position-weighted score only once (when k=topk)
157
+ if k == topk:
158
+ rank = topk_at_pos.index(gt_token) + 1
159
+ # Rank 1 = 1.0, Rank k = small positive
160
+ position_weight = 1.0 - (rank - 1) / topk
161
+ position_scores.append(position_weight)
162
+
163
+ recall_per_k[k] = hits / target_len if target_len > 0 else 0.0
164
+
165
+ # Fill scores for positions where GT was NOT in top-k
166
+ while len(position_scores) < target_len:
167
+ position_scores.append(0.0)
168
+
169
+ average_recall = sum(position_scores) / len(position_scores) if position_scores else 0.0
170
+
171
+ return average_recall, recall_per_k
172
+
173
+
174
+ def _calculate_metadata_recall(llm_handler,
175
+ formatted_prompt: str,
176
+ fields_dict: Dict[str, Any],
177
+ topk: int = 10) -> Dict[str, float]:
178
+ """
179
+ Args:
180
+ fields_dict: Dictionary of {field_name: field_value}
181
+ """
182
+ if not fields_dict:
183
+ return {}
184
+
185
+ field_scores = {}
186
+
187
+ for field_name in sorted(fields_dict.keys()):
188
+ # Construct target text for this specific field
189
+ # e.g. <think>\nbpm: 120\n</think>\n
190
+ field_yaml = yaml.dump({field_name: fields_dict[field_name]}, allow_unicode=True, sort_keys=True).strip()
191
+ field_target_text = f"<think>\n{field_yaml}\n</think>\n"
192
+
193
+ # Calculate recall using the robust logic
194
+ avg_score, _ = _calculate_topk_recall(llm_handler, formatted_prompt, field_target_text, topk=topk)
195
+
196
+ field_scores[field_name] = avg_score
197
+ logger.debug(f"Recall for {field_name}: {avg_score:.4f}")
198
+
199
+ return field_scores
200
+
201
+
202
+ def _calculate_log_prob(
203
+ llm_handler,
204
+ formatted_prompt: str,
205
+ target_text: str,
206
+ temperature: float = 1.0 # Kept for API compatibility, but ignored for scoring
207
+ ) -> float:
208
+ """
209
+ Calculate average log probability of target text given prompt.
210
+ """
211
+ pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)
212
+
213
+ if target_ids.shape[0] == 0:
214
+ return float('-inf')
215
+
216
+ # FIX: Do not divide by temperature.
217
+ # Log-probability for PMI/Perplexity should be exact.
218
+
219
+ # Calculate log probabilities (log_softmax)
220
+ log_probs = F.log_softmax(pred_logits, dim=-1) # [target_len, vocab_size]
221
+
222
+ # Gather log probabilities of the ground truth tokens
223
+ target_log_probs = log_probs[torch.arange(target_ids.shape[0]), target_ids]
224
+
225
+ # Return average log probability
226
+ mean_log_prob = target_log_probs.mean().item()
227
+
228
+ return mean_log_prob
229
+
230
+
231
+ def calculate_reward_score(
232
+ scores: Dict[str, float],
233
+ weights_config: Optional[Dict[str, float]] = None
234
+ ) -> Tuple[float, str]:
235
+ """
236
+ Reward Model Calculator: Computes a final reward based on user priorities.
237
+
238
+ Priority Logic:
239
+ 1. Caption (Highest): The overall vibe/style must match.
240
+ 2. Lyrics (Medium): Content accuracy is important but secondary to vibe.
241
+ 3. Metadata (Lowest): Technical constraints (BPM, Key) allow for slight deviations.
242
+
243
+ Strategy: Dynamic Weighted Sum
244
+ - Metadata fields are aggregated into a single 'metadata' score first.
245
+ - Weights are dynamically renormalized if any component (e.g., lyrics) is missing.
246
+
247
+ Args:
248
+ scores: Dictionary of raw scores (0.0 - 1.0) from the evaluation module.
249
+ weights_config: Optional custom weights. Defaults to:
250
+ Caption (50%), Lyrics (30%), Metadata (20%).
251
+
252
+ Returns:
253
+ final_reward: The calculated reward score (0.0 - 1.0).
254
+ explanation: A formatted string explaining how the score was derived.
255
+ """
256
+
257
+ # 1. Default Preference Configuration
258
+ # These weights determine the relative importance of each component.
259
+ if weights_config is None:
260
+ weights_config = {
261
+ 'caption': 0.50, # High priority: Style/Vibe
262
+ 'lyrics': 0.30, # Medium priority: Content
263
+ 'metadata': 0.20 # Low priority: Technical details
264
+ }
265
+
266
+ # 2. Extract and Group Scores
267
+ # Caption and Lyrics are standalone high-level features.
268
+ caption_score = scores.get('caption')
269
+ lyrics_score = scores.get('lyrics')
270
+
271
+ # Metadata fields (bpm, key, duration, etc.) are aggregated.
272
+ # We treat them as a single "Technical Score" to prevent them from
273
+ # diluting the weight of Caption/Lyrics simply by having many fields.
274
+ meta_scores_list = [
275
+ val for key, val in scores.items()
276
+ if key not in ['caption', 'lyrics']
277
+ ]
278
+
279
+ # Calculate average of all metadata fields (if any exist)
280
+ meta_aggregate_score = None
281
+ if meta_scores_list:
282
+ meta_aggregate_score = sum(meta_scores_list) / len(meta_scores_list)
283
+
284
+ # 3. specific Active Components & Dynamic Weighting
285
+ # We only include components that actually exist in this generation.
286
+ active_components = {}
287
+
288
+ if caption_score is not None:
289
+ active_components['caption'] = (caption_score, weights_config['caption'])
290
+
291
+ if lyrics_score is not None:
292
+ active_components['lyrics'] = (lyrics_score, weights_config['lyrics'])
293
+
294
+ if meta_aggregate_score is not None:
295
+ active_components['metadata'] = (meta_aggregate_score, weights_config['metadata'])
296
+
297
+ # 4. Calculate Final Weighted Score
298
+ total_base_weight = sum(w for _, w in active_components.values())
299
+ total_score = 0.0
300
+
301
+ breakdown_lines = []
302
+
303
+ if total_base_weight == 0:
304
+ return 0.0, "❌ No valid scores available to calculate reward."
305
+
306
+ # Sort by weight (importance) for display
307
+ sorted_components = sorted(active_components.items(), key=lambda x: x[1][1], reverse=True)
308
+
309
+ for name, (score, base_weight) in sorted_components:
310
+ # Renormalize weight: If lyrics are missing, caption/metadata weights scale up proportionately.
311
+ normalized_weight = base_weight / total_base_weight
312
+ weighted_contribution = score * normalized_weight
313
+ total_score += weighted_contribution
314
+
315
+ breakdown_lines.append(
316
+ f" • {name.title():<8} | Score: {score:.4f} | Weight: {normalized_weight:.2f} "
317
+ f"-> Contrib: +{weighted_contribution:.4f}"
318
+ )
319
+
320
+ return total_score, "\n".join(breakdown_lines)
321
+
322
+ # ==============================================================================
323
+ # Main Public API
324
+ # ==============================================================================
325
+
326
+
327
+ def calculate_pmi_score_per_condition(
328
+ llm_handler,
329
+ audio_codes: str,
330
+ caption: str = "",
331
+ lyrics: str = "",
332
+ metadata: Optional[Dict[str, Any]] = None,
333
+ temperature: float = 1.0,
334
+ topk: int = 10,
335
+ score_scale: float = 0.1,
336
+ ) -> Tuple[Dict[str, float], float, str]:
337
+ """
338
+ Calculate quality score separately for each condition.
339
+ - Metadata: Uses Top-k Recall.
340
+ - Caption/Lyrics: Uses PMI (Normalized).
341
+ """
342
+ if not llm_handler.llm_initialized:
343
+ return {}, 0.0, "❌ LLM not initialized"
344
+
345
+ if not audio_codes or not audio_codes.strip():
346
+ return {}, 0.0, "❌ No audio codes provided"
347
+
348
+ if "caption" not in metadata:
349
+ metadata['caption'] = caption
350
+
351
+ formatted_prompt = llm_handler.build_formatted_prompt_for_understanding(audio_codes=audio_codes, is_negative_prompt=False)
352
+ prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
353
+ try:
354
+ # 1. Calculate Recall for Metadata Fields
355
+ if metadata and isinstance(metadata, dict):
356
+ scores = {}
357
+ # Define which fields use which metric
358
+ metadata_recall_keys = ['bpm', 'duration', 'genres', 'keyscale', 'language', 'timesignature']
359
+ metadata_pmi_keys = ['caption']
360
+ for key in metadata_recall_keys:
361
+ if key in metadata and metadata[key] is not None:
362
+ recall_metadata = {key: metadata[key]}
363
+ field_scores = _calculate_metadata_recall(llm_handler, formatted_prompt, recall_metadata, topk=topk)
364
+ scores.update(field_scores)
365
+
366
+ # 2. Calculate PMI for Caption
367
+ for key in metadata_pmi_keys:
368
+ if key in metadata and metadata[key] is not None:
369
+ cot_yaml = yaml.dump({key: metadata[key]}, allow_unicode=True, sort_keys=True).strip()
370
+ target_text = f"<think>\n{cot_yaml}\n</think>\n"
371
+
372
+ log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
373
+ log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
374
+
375
+ pmi_normalized = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
376
+ scores[key] = pmi_normalized
377
+
378
+ # 3. Calculate PMI for Lyrics
379
+ if lyrics:
380
+ target_text = f"<think>\n</think>\n# Lyric\n{lyrics}\n"
381
+
382
+ log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
383
+
384
+ prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
385
+ log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
386
+
387
+ scores['lyrics'] = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
388
+
389
+ if not scores:
390
+ return {}, 0.0, "❌ No conditions to evaluate"
391
+
392
+ # 4. Global Score
393
+ global_score = sum(scores.values()) / len(scores)
394
+ global_score, breakdown_lines = calculate_reward_score(scores)
395
+
396
+ # Status Message
397
+ status_lines = [breakdown_lines, "\n✅ Per-condition scores (0-1):"]
398
+ for key, score in sorted(scores.items()):
399
+ metric = "Top-k Recall" if key in metadata_recall_keys else "PMI (Norm)"
400
+ status_lines.append(f" {key}: {score:.4f} ({metric})")
401
+ status = "\n".join(status_lines)
402
+ logger.info(f"Calculated scores: {global_score:.4f}\n{status}")
403
+ return scores, global_score, status
404
+
405
+ except Exception as e:
406
+ import traceback
407
+ error_msg = f"❌ Error: {str(e)}"
408
+ logger.error(error_msg)
409
+ logger.error(traceback.format_exc())
410
+ return {}, float('-inf'), error_msg
acestep/third_parts/nano-vllm/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Xingkai Yu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
acestep/third_parts/nano-vllm/README.md ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <img width="300" src="assets/logo.png">
3
+ </p>
4
+
5
+ <p align="center">
6
+ <a href="https://trendshift.io/repositories/15323" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15323" alt="GeeeekExplorer%2Fnano-vllm | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
7
+ </p>
8
+
9
+ # Nano-vLLM
10
+
11
+ A lightweight vLLM implementation built from scratch.
12
+
13
+ ## Key Features
14
+
15
+ * 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
16
+ * 📖 **Readable codebase** - Clean implementation in ~ 1,200 lines of Python code
17
+ * ⚡ **Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc.
18
+
19
+ ## Installation
20
+
21
+ ```bash
22
+ pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
23
+ ```
24
+
25
+ ## Model Download
26
+
27
+ To download the model weights manually, use the following command:
28
+ ```bash
29
+ huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
30
+ --local-dir ~/huggingface/Qwen3-0.6B/ \
31
+ --local-dir-use-symlinks False
32
+ ```
33
+
34
+ ## Quick Start
35
+
36
+ See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method:
37
+ ```python
38
+ from nanovllm import LLM, SamplingParams
39
+ llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
40
+ sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
41
+ prompts = ["Hello, Nano-vLLM."]
42
+ outputs = llm.generate(prompts, sampling_params)
43
+ outputs[0]["text"]
44
+ ```
45
+
46
+ ## Benchmark
47
+
48
+ See `bench.py` for benchmark.
49
+
50
+ **Test Configuration:**
51
+ - Hardware: RTX 4070 Laptop (8GB)
52
+ - Model: Qwen3-0.6B
53
+ - Total Requests: 256 sequences
54
+ - Input Length: Randomly sampled between 100–1024 tokens
55
+ - Output Length: Randomly sampled between 100–1024 tokens
56
+
57
+ **Performance Results:**
58
+ | Inference Engine | Output Tokens | Time (s) | Throughput (tokens/s) |
59
+ |----------------|-------------|----------|-----------------------|
60
+ | vLLM | 133,966 | 98.37 | 1361.84 |
61
+ | Nano-vLLM | 133,966 | 93.41 | 1434.13 |
62
+
63
+
64
+ ## Star History
65
+
66
+ [![Star History Chart](https://api.star-history.com/svg?repos=GeeeekExplorer/nano-vllm&type=Date)](https://www.star-history.com/#GeeeekExplorer/nano-vllm&Date)
acestep/third_parts/nano-vllm/bench.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from random import randint, seed
4
+ from nanovllm import LLM, SamplingParams
5
+ # from vllm import LLM, SamplingParams
6
+
7
+
8
+ def main():
9
+ seed(0)
10
+ num_seqs = 256
11
+ max_input_len = 1024
12
+ max_ouput_len = 1024
13
+
14
+ path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
15
+ llm = LLM(path, enforce_eager=False, max_model_len=4096)
16
+
17
+ prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
18
+ sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)]
19
+ # uncomment the following line for vllm
20
+ # prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
21
+
22
+ llm.generate(["Benchmark: "], SamplingParams())
23
+ t = time.time()
24
+ llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
25
+ t = (time.time() - t)
26
+ total_tokens = sum(sp.max_tokens for sp in sampling_params)
27
+ throughput = total_tokens / t
28
+ print(f"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
29
+
30
+
31
+ if __name__ == "__main__":
32
+ main()