victor HF Staff commited on
Commit
dc5fc4b
·
0 Parent(s):

feat: ACE-Step Studio — custom frontend for ACE-Step v1.5 music generation

Browse files

- gr.Server with custom HTML frontend + API endpoints
- v1.5 AceStepHandler with acestep-v15-xl-turbo (8-step turbo)
- Peak normalization, ZeroGPU permission fixes
- /generate and /inspire API endpoints

This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +39 -0
  2. .gitignore +1 -0
  3. README.md +31 -0
  4. acestep/__init__.py +1 -0
  5. acestep/acestep_v15_pipeline.py +303 -0
  6. acestep/api_server.py +1700 -0
  7. acestep/audio_utils.py +378 -0
  8. acestep/constants.py +109 -0
  9. acestep/constrained_logits_processor.py +0 -0
  10. acestep/dataset_handler.py +37 -0
  11. acestep/dit_alignment_score.py +870 -0
  12. acestep/genres_vocab.txt +0 -0
  13. acestep/gradio_ui/__init__.py +1 -0
  14. acestep/gradio_ui/events/__init__.py +1355 -0
  15. acestep/gradio_ui/events/generation_handlers.py +1071 -0
  16. acestep/gradio_ui/events/results_handlers.py +0 -0
  17. acestep/gradio_ui/events/training_handlers.py +644 -0
  18. acestep/gradio_ui/i18n.py +152 -0
  19. acestep/gradio_ui/i18n/en.json +245 -0
  20. acestep/gradio_ui/i18n/ja.json +245 -0
  21. acestep/gradio_ui/i18n/zh.json +245 -0
  22. acestep/gradio_ui/interfaces/__init__.py +105 -0
  23. acestep/gradio_ui/interfaces/dataset.py +101 -0
  24. acestep/gradio_ui/interfaces/generation.py +694 -0
  25. acestep/gradio_ui/interfaces/result.py +598 -0
  26. acestep/gradio_ui/interfaces/training.py +562 -0
  27. acestep/handler.py +0 -0
  28. acestep/inference.py +1181 -0
  29. acestep/llm_inference.py +0 -0
  30. acestep/local_cache.py +129 -0
  31. acestep/test_time_scaling.py +410 -0
  32. acestep/third_parts/nano-vllm/LICENSE +21 -0
  33. acestep/third_parts/nano-vllm/README.md +66 -0
  34. acestep/third_parts/nano-vllm/bench.py +32 -0
  35. acestep/third_parts/nano-vllm/example.py +33 -0
  36. acestep/third_parts/nano-vllm/nanovllm/__init__.py +2 -0
  37. acestep/third_parts/nano-vllm/nanovllm/config.py +26 -0
  38. acestep/third_parts/nano-vllm/nanovllm/engine/block_manager.py +119 -0
  39. acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py +178 -0
  40. acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py +543 -0
  41. acestep/third_parts/nano-vllm/nanovllm/engine/scheduler.py +230 -0
  42. acestep/third_parts/nano-vllm/nanovllm/engine/sequence.py +96 -0
  43. acestep/third_parts/nano-vllm/nanovllm/layers/activation.py +14 -0
  44. acestep/third_parts/nano-vllm/nanovllm/layers/attention.py +75 -0
  45. acestep/third_parts/nano-vllm/nanovllm/layers/embed_head.py +66 -0
  46. acestep/third_parts/nano-vllm/nanovllm/layers/layernorm.py +50 -0
  47. acestep/third_parts/nano-vllm/nanovllm/layers/linear.py +153 -0
  48. acestep/third_parts/nano-vllm/nanovllm/layers/rotary_embedding.py +61 -0
  49. acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py +114 -0
  50. acestep/third_parts/nano-vllm/nanovllm/llm.py +5 -0
.gitattributes ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ *.jpg filter=lfs diff=lfs merge=lfs -text
38
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
39
+ *.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .claude/
README.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Ace-Step Studio
3
+ emoji: 🎵
4
+ colorFrom: gray
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 6.12.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ short_description: Minimalist dark UI for ACE-Step music generation
12
+ models:
13
+ - ACE-Step/Ace-Step1.5
14
+ - ACE-Step/acestep-v15-xl-turbo
15
+ preload_from_hub:
16
+ - ACE-Step/Ace-Step1.5
17
+ - ACE-Step/acestep-v15-xl-turbo
18
+ ---
19
+
20
+ # ACE-Step Studio
21
+
22
+ A minimalist, dark-themed interface for generating music with [ACE-Step](https://github.com/ace-step/ACE-Step).
23
+
24
+ **Model**: `ACE-Step/acestep-v15-xl-turbo` — generates 1 minute of audio in ~2 seconds (8-step turbo distillation).
25
+
26
+ ## Usage
27
+
28
+ 1. Enter style tags (e.g. `lo-fi, chill, piano, female vocals`)
29
+ 2. Write lyrics with `[verse]`, `[chorus]`, `[bridge]` section markers
30
+ 3. Hit **Generate** — a waveform appears when ready
31
+ 4. Use **✨ Inspire me** to auto-generate lyrics via LLM
acestep/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """ACE-Step package."""
acestep/acestep_v15_pipeline.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ACE-Step V1.5 Pipeline
3
+ Handler wrapper connecting model and UI
4
+ """
5
+ import os
6
+ import sys
7
+
8
+ # Load environment variables from .env file in project root
9
+ # This allows configuration without hardcoding values
10
+ # Falls back to .env.example if .env is not found
11
+ try:
12
+ from dotenv import load_dotenv
13
+ # Get project root directory
14
+ _current_file = os.path.abspath(__file__)
15
+ _project_root = os.path.dirname(os.path.dirname(_current_file))
16
+ _env_path = os.path.join(_project_root, '.env')
17
+ _env_example_path = os.path.join(_project_root, '.env.example')
18
+
19
+ if os.path.exists(_env_path):
20
+ load_dotenv(_env_path)
21
+ print(f"Loaded configuration from {_env_path}")
22
+ elif os.path.exists(_env_example_path):
23
+ load_dotenv(_env_example_path)
24
+ print(f"Loaded configuration from {_env_example_path} (fallback)")
25
+ except ImportError:
26
+ # python-dotenv not installed, skip loading .env
27
+ pass
28
+
29
+ # Clear proxy settings that may affect Gradio
30
+ for proxy_var in ['http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY']:
31
+ os.environ.pop(proxy_var, None)
32
+
33
+ try:
34
+ # When executed as a module: `python -m acestep.acestep_v15_pipeline`
35
+ from .handler import AceStepHandler
36
+ from .llm_inference import LLMHandler
37
+ from .dataset_handler import DatasetHandler
38
+ from .gradio_ui import create_gradio_interface
39
+ except ImportError:
40
+ # When executed as a script: `python acestep/acestep_v15_pipeline.py`
41
+ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
42
+ if project_root not in sys.path:
43
+ sys.path.insert(0, project_root)
44
+ from acestep.handler import AceStepHandler
45
+ from acestep.llm_inference import LLMHandler
46
+ from acestep.dataset_handler import DatasetHandler
47
+ from acestep.gradio_ui import create_gradio_interface
48
+
49
+
50
+ def create_demo(init_params=None, language='en'):
51
+ """
52
+ Create Gradio demo interface
53
+
54
+ Args:
55
+ init_params: Dictionary containing initialization parameters and state.
56
+ If None, service will not be pre-initialized.
57
+ Keys: 'pre_initialized' (bool), 'checkpoint', 'config_path', 'device',
58
+ 'init_llm', 'lm_model_path', 'backend', 'use_flash_attention',
59
+ 'offload_to_cpu', 'offload_dit_to_cpu', 'init_status',
60
+ 'dit_handler', 'llm_handler' (initialized handlers if pre-initialized),
61
+ 'language' (UI language code)
62
+ language: UI language code ('en', 'zh', 'ja', default: 'en')
63
+
64
+ Returns:
65
+ Gradio Blocks instance
66
+ """
67
+ # Get persistent storage path from init_params (for HuggingFace Space)
68
+ persistent_storage_path = None
69
+ if init_params:
70
+ persistent_storage_path = init_params.get('persistent_storage_path')
71
+
72
+ # Use pre-initialized handlers if available, otherwise create new ones
73
+ if init_params and init_params.get('pre_initialized') and 'dit_handler' in init_params:
74
+ dit_handler = init_params['dit_handler']
75
+ llm_handler = init_params['llm_handler']
76
+ else:
77
+ dit_handler = AceStepHandler(persistent_storage_path=persistent_storage_path)
78
+ llm_handler = LLMHandler(persistent_storage_path=persistent_storage_path)
79
+
80
+ dataset_handler = DatasetHandler() # Dataset handler
81
+
82
+ # Create Gradio interface with all handlers and initialization parameters
83
+ demo = create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=init_params, language=language)
84
+
85
+ return demo
86
+
87
+
88
+ def get_gpu_memory_gb():
89
+ """
90
+ Get GPU memory in GB. Returns 0 if no GPU is available.
91
+ """
92
+ try:
93
+ import torch
94
+ if torch.cuda.is_available():
95
+ # Get total memory of the first GPU in GB
96
+ total_memory = torch.cuda.get_device_properties(0).total_memory
97
+ memory_gb = total_memory / (1024**3) # Convert bytes to GB
98
+ return memory_gb
99
+ else:
100
+ return 0
101
+ except Exception as e:
102
+ print(f"Warning: Failed to detect GPU memory: {e}", file=sys.stderr)
103
+ return 0
104
+
105
+
106
+ def main():
107
+ """Main entry function"""
108
+ import argparse
109
+
110
+ # Detect GPU memory to auto-configure offload settings
111
+ gpu_memory_gb = get_gpu_memory_gb()
112
+ auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < 16
113
+
114
+ if auto_offload:
115
+ print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (< 16GB)")
116
+ print("Auto-enabling CPU offload to reduce GPU memory usage")
117
+ elif gpu_memory_gb > 0:
118
+ print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (>= 16GB)")
119
+ print("CPU offload disabled by default")
120
+ else:
121
+ print("No GPU detected, running on CPU")
122
+
123
+ parser = argparse.ArgumentParser(description="Gradio Demo for ACE-Step V1.5")
124
+ parser.add_argument("--port", type=int, default=7860, help="Port to run the gradio server on")
125
+ parser.add_argument("--share", action="store_true", help="Create a public link")
126
+ parser.add_argument("--debug", action="store_true", help="Enable debug mode")
127
+ parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Server name (default: 127.0.0.1, use 0.0.0.0 for all interfaces)")
128
+ parser.add_argument("--language", type=str, default="en", choices=["en", "zh", "ja"], help="UI language: en (English), zh (中文), ja (日本語)")
129
+
130
+ # Service mode argument
131
+ parser.add_argument("--service_mode", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False,
132
+ help="Enable service mode (default: False). When enabled, uses preset models and restricts UI options.")
133
+
134
+ # Service initialization arguments
135
+ parser.add_argument("--init_service", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Initialize service on startup (default: False)")
136
+ parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint file path (optional, for display purposes)")
137
+ parser.add_argument("--config_path", type=str, default=None, help="Main model path (e.g., 'acestep-v15-turbo')")
138
+ parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "cpu"], help="Processing device (default: auto)")
139
+ parser.add_argument("--init_llm", type=lambda x: x.lower() in ['true', '1', 'yes'], default=True, help="Initialize 5Hz LM (default: True)")
140
+ parser.add_argument("--lm_model_path", type=str, default=None, help="5Hz LM model path (e.g., 'acestep-5Hz-lm-0.6B')")
141
+ parser.add_argument("--backend", type=str, default="vllm", choices=["vllm", "pt"], help="5Hz LM backend (default: vllm)")
142
+ parser.add_argument("--use_flash_attention", type=lambda x: x.lower() in ['true', '1', 'yes'], default=None, help="Use flash attention (default: auto-detect)")
143
+ parser.add_argument("--offload_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=auto_offload, help=f"Offload models to CPU (default: {'True' if auto_offload else 'False'}, auto-detected based on GPU VRAM)")
144
+ parser.add_argument("--offload_dit_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Offload DiT to CPU (default: False)")
145
+
146
+ args = parser.parse_args()
147
+
148
+ # Service mode defaults (can be configured via .env file)
149
+ if args.service_mode:
150
+ print("Service mode enabled - applying preset configurations...")
151
+ # Force init_service in service mode
152
+ args.init_service = True
153
+ # Default DiT model for service mode (from env or fallback)
154
+ if args.config_path is None:
155
+ args.config_path = os.environ.get(
156
+ "SERVICE_MODE_DIT_MODEL",
157
+ "acestep-v15-turbo-fix-inst-shift-dynamic"
158
+ )
159
+ # Default LM model for service mode (from env or fallback)
160
+ if args.lm_model_path is None:
161
+ args.lm_model_path = os.environ.get(
162
+ "SERVICE_MODE_LM_MODEL",
163
+ "acestep-5Hz-lm-1.7B-v4-fix"
164
+ )
165
+ # Backend for service mode (from env or fallback to vllm)
166
+ args.backend = os.environ.get("SERVICE_MODE_BACKEND", "vllm")
167
+ print(f" DiT model: {args.config_path}")
168
+ print(f" LM model: {args.lm_model_path}")
169
+ print(f" Backend: {args.backend}")
170
+
171
+ try:
172
+ init_params = None
173
+
174
+ # If init_service is True, perform initialization before creating UI
175
+ if args.init_service:
176
+ print("Initializing service from command line...")
177
+
178
+ # Create handler instances for initialization
179
+ dit_handler = AceStepHandler()
180
+ llm_handler = LLMHandler()
181
+
182
+ # Auto-select config_path if not provided
183
+ if args.config_path is None:
184
+ available_models = dit_handler.get_available_acestep_v15_models()
185
+ if available_models:
186
+ args.config_path = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else available_models[0]
187
+ print(f"Auto-selected config_path: {args.config_path}")
188
+ else:
189
+ print("Error: No available models found. Please specify --config_path", file=sys.stderr)
190
+ sys.exit(1)
191
+
192
+ # Get project root (same logic as in handler)
193
+ current_file = os.path.abspath(__file__)
194
+ project_root = os.path.dirname(os.path.dirname(current_file))
195
+
196
+ # Determine flash attention setting
197
+ use_flash_attention = args.use_flash_attention
198
+ if use_flash_attention is None:
199
+ use_flash_attention = dit_handler.is_flash_attention_available()
200
+
201
+ # Initialize DiT handler
202
+ print(f"Initializing DiT model: {args.config_path} on {args.device}...")
203
+ init_status, enable_generate = dit_handler.initialize_service(
204
+ project_root=project_root,
205
+ config_path=args.config_path,
206
+ device=args.device,
207
+ use_flash_attention=use_flash_attention,
208
+ compile_model=False,
209
+ offload_to_cpu=args.offload_to_cpu,
210
+ offload_dit_to_cpu=args.offload_dit_to_cpu
211
+ )
212
+
213
+ if not enable_generate:
214
+ print(f"Error initializing DiT model: {init_status}", file=sys.stderr)
215
+ sys.exit(1)
216
+
217
+ print(f"DiT model initialized successfully")
218
+
219
+ # Initialize LM handler if requested
220
+ lm_status = ""
221
+ if args.init_llm:
222
+ if args.lm_model_path is None:
223
+ # Try to get default LM model
224
+ available_lm_models = llm_handler.get_available_5hz_lm_models()
225
+ if available_lm_models:
226
+ args.lm_model_path = available_lm_models[0]
227
+ print(f"Using default LM model: {args.lm_model_path}")
228
+ else:
229
+ print("Warning: No LM models available, skipping LM initialization", file=sys.stderr)
230
+ args.init_llm = False
231
+
232
+ if args.init_llm and args.lm_model_path:
233
+ checkpoint_dir = os.path.join(project_root, "checkpoints")
234
+ print(f"Initializing 5Hz LM: {args.lm_model_path} on {args.device}...")
235
+ lm_status, lm_success = llm_handler.initialize(
236
+ checkpoint_dir=checkpoint_dir,
237
+ lm_model_path=args.lm_model_path,
238
+ backend=args.backend,
239
+ device=args.device,
240
+ offload_to_cpu=args.offload_to_cpu,
241
+ dtype=dit_handler.dtype
242
+ )
243
+
244
+ if lm_success:
245
+ print(f"5Hz LM initialized successfully")
246
+ init_status += f"\n{lm_status}"
247
+ else:
248
+ print(f"Warning: 5Hz LM initialization failed: {lm_status}", file=sys.stderr)
249
+ init_status += f"\n{lm_status}"
250
+
251
+ # Prepare initialization parameters for UI
252
+ init_params = {
253
+ 'pre_initialized': True,
254
+ 'service_mode': args.service_mode,
255
+ 'checkpoint': args.checkpoint,
256
+ 'config_path': args.config_path,
257
+ 'device': args.device,
258
+ 'init_llm': args.init_llm,
259
+ 'lm_model_path': args.lm_model_path,
260
+ 'backend': args.backend,
261
+ 'use_flash_attention': use_flash_attention,
262
+ 'offload_to_cpu': args.offload_to_cpu,
263
+ 'offload_dit_to_cpu': args.offload_dit_to_cpu,
264
+ 'init_status': init_status,
265
+ 'enable_generate': enable_generate,
266
+ 'dit_handler': dit_handler,
267
+ 'llm_handler': llm_handler,
268
+ 'language': args.language
269
+ }
270
+
271
+ print("Service initialization completed successfully!")
272
+
273
+ # Create and launch demo
274
+ print(f"Creating Gradio interface with language: {args.language}...")
275
+ demo = create_demo(init_params=init_params, language=args.language)
276
+
277
+ # Enable queue for multi-user support
278
+ # This ensures proper request queuing and prevents concurrent generation conflicts
279
+ print("Enabling queue for multi-user support...")
280
+ demo.queue(
281
+ max_size=20, # Maximum queue size (adjust based on your needs)
282
+ status_update_rate="auto", # Update rate for queue status
283
+ )
284
+
285
+ print(f"Launching server on {args.server_name}:{args.port}...")
286
+ demo.launch(
287
+ server_name=args.server_name,
288
+ server_port=args.port,
289
+ share=args.share,
290
+ debug=args.debug,
291
+ show_error=True,
292
+ prevent_thread_lock=False, # Keep thread locked to maintain server running
293
+ inbrowser=False, # Don't auto-open browser
294
+ )
295
+ except Exception as e:
296
+ print(f"Error launching Gradio: {e}", file=sys.stderr)
297
+ import traceback
298
+ traceback.print_exc()
299
+ sys.exit(1)
300
+
301
+
302
+ if __name__ == "__main__":
303
+ main()
acestep/api_server.py ADDED
@@ -0,0 +1,1700 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI server for ACE-Step V1.5.
2
+
3
+ Endpoints:
4
+ - POST /release_task Create music generation task
5
+ - POST /query_result Batch query task results
6
+ - POST /v1/music/random Create random sample task
7
+ - GET /v1/models List available models
8
+ - GET /v1/audio Download audio file
9
+ - GET /health Health check
10
+
11
+ NOTE:
12
+ - In-memory queue and job store -> run uvicorn with workers=1.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import asyncio
18
+ import json
19
+ import os
20
+ import sys
21
+ import time
22
+ import traceback
23
+ import tempfile
24
+ import urllib.parse
25
+ from collections import deque
26
+ from concurrent.futures import ThreadPoolExecutor
27
+ from contextlib import asynccontextmanager
28
+ from dataclasses import dataclass
29
+ from pathlib import Path
30
+ from threading import Lock
31
+ from typing import Any, Dict, List, Literal, Optional
32
+ from uuid import uuid4
33
+
34
+ try:
35
+ from dotenv import load_dotenv
36
+ except ImportError: # Optional dependency
37
+ load_dotenv = None # type: ignore
38
+
39
+ from fastapi import FastAPI, HTTPException, Request
40
+ from pydantic import BaseModel, Field
41
+ from starlette.datastructures import UploadFile as StarletteUploadFile
42
+
43
+ from acestep.handler import AceStepHandler
44
+ from acestep.llm_inference import LLMHandler
45
+ from acestep.constants import (
46
+ DEFAULT_DIT_INSTRUCTION,
47
+ DEFAULT_LM_INSTRUCTION,
48
+ TASK_INSTRUCTIONS,
49
+ )
50
+ from acestep.inference import (
51
+ GenerationParams,
52
+ GenerationConfig,
53
+ generate_music,
54
+ create_sample,
55
+ format_sample,
56
+ )
57
+ from acestep.gradio_ui.events.results_handlers import _build_generation_info
58
+
59
+
60
+ # =============================================================================
61
+ # Constants
62
+ # =============================================================================
63
+
64
+ RESULT_KEY_PREFIX = "ace_step_v1.5_"
65
+ RESULT_EXPIRE_SECONDS = 7 * 24 * 60 * 60 # 7 days
66
+ TASK_TIMEOUT_SECONDS = 3600 # 1 hour
67
+ STATUS_MAP = {"queued": 0, "running": 0, "succeeded": 1, "failed": 2}
68
+
69
+ LM_DEFAULT_TEMPERATURE = 0.85
70
+ LM_DEFAULT_CFG_SCALE = 2.5
71
+ LM_DEFAULT_TOP_P = 0.9
72
+
73
+ # Parameter aliases for request parsing
74
+ PARAM_ALIASES = {
75
+ "prompt": ["prompt"],
76
+ "sample_mode": ["sample_mode", "sampleMode"],
77
+ "sample_query": ["sample_query", "sampleQuery", "description", "desc"],
78
+ "use_format": ["use_format", "useFormat", "format"],
79
+ "model": ["model", "dit_model", "ditModel"],
80
+ "key_scale": ["key_scale", "keyscale", "keyScale"],
81
+ "time_signature": ["time_signature", "timesignature", "timeSignature"],
82
+ "audio_duration": ["audio_duration", "duration", "audioDuration", "target_duration", "targetDuration"],
83
+ "vocal_language": ["vocal_language", "vocalLanguage"],
84
+ "inference_steps": ["inference_steps", "inferenceSteps"],
85
+ "guidance_scale": ["guidance_scale", "guidanceScale"],
86
+ "use_random_seed": ["use_random_seed", "useRandomSeed"],
87
+ "audio_code_string": ["audio_code_string", "audioCodeString"],
88
+ "audio_cover_strength": ["audio_cover_strength", "audioCoverStrength"],
89
+ "task_type": ["task_type", "taskType"],
90
+ "infer_method": ["infer_method", "inferMethod"],
91
+ "use_tiled_decode": ["use_tiled_decode", "useTiledDecode"],
92
+ "constrained_decoding": ["constrained_decoding", "constrainedDecoding", "constrained"],
93
+ "constrained_decoding_debug": ["constrained_decoding_debug", "constrainedDecodingDebug"],
94
+ "use_cot_caption": ["use_cot_caption", "cot_caption", "cot-caption"],
95
+ "use_cot_language": ["use_cot_language", "cot_language", "cot-language"],
96
+ "is_format_caption": ["is_format_caption", "isFormatCaption"],
97
+ }
98
+
99
+
100
+ def _parse_description_hints(description: str) -> tuple[Optional[str], bool]:
101
+ """
102
+ Parse a description string to extract language code and instrumental flag.
103
+
104
+ This function analyzes user descriptions like "Pop rock. English" or "piano solo"
105
+ to detect:
106
+ - Language: Maps language names to ISO codes (e.g., "English" -> "en")
107
+ - Instrumental: Detects patterns indicating instrumental/no-vocal music
108
+
109
+ Args:
110
+ description: User's natural language music description
111
+
112
+ Returns:
113
+ (language_code, is_instrumental) tuple:
114
+ - language_code: ISO language code (e.g., "en", "zh") or None if not detected
115
+ - is_instrumental: True if description indicates instrumental music
116
+ """
117
+ import re
118
+
119
+ if not description:
120
+ return None, False
121
+
122
+ description_lower = description.lower().strip()
123
+
124
+ # Language mapping: input patterns -> ISO code
125
+ language_mapping = {
126
+ 'english': 'en', 'en': 'en',
127
+ 'chinese': 'zh', '中文': 'zh', 'zh': 'zh', 'mandarin': 'zh',
128
+ 'japanese': 'ja', '日本語': 'ja', 'ja': 'ja',
129
+ 'korean': 'ko', '한국어': 'ko', 'ko': 'ko',
130
+ 'spanish': 'es', 'español': 'es', 'es': 'es',
131
+ 'french': 'fr', 'français': 'fr', 'fr': 'fr',
132
+ 'german': 'de', 'deutsch': 'de', 'de': 'de',
133
+ 'italian': 'it', 'italiano': 'it', 'it': 'it',
134
+ 'portuguese': 'pt', 'português': 'pt', 'pt': 'pt',
135
+ 'russian': 'ru', 'русский': 'ru', 'ru': 'ru',
136
+ 'bengali': 'bn', 'bn': 'bn',
137
+ 'hindi': 'hi', 'hi': 'hi',
138
+ 'arabic': 'ar', 'ar': 'ar',
139
+ 'thai': 'th', 'th': 'th',
140
+ 'vietnamese': 'vi', 'vi': 'vi',
141
+ 'indonesian': 'id', 'id': 'id',
142
+ 'turkish': 'tr', 'tr': 'tr',
143
+ 'dutch': 'nl', 'nl': 'nl',
144
+ 'polish': 'pl', 'pl': 'pl',
145
+ }
146
+
147
+ # Detect language
148
+ detected_language = None
149
+ for lang_name, lang_code in language_mapping.items():
150
+ if len(lang_name) <= 2:
151
+ pattern = r'(?:^|\s|[.,;:!?])' + re.escape(lang_name) + r'(?:$|\s|[.,;:!?])'
152
+ else:
153
+ pattern = r'\b' + re.escape(lang_name) + r'\b'
154
+
155
+ if re.search(pattern, description_lower):
156
+ detected_language = lang_code
157
+ break
158
+
159
+ # Detect instrumental
160
+ is_instrumental = False
161
+ if 'instrumental' in description_lower:
162
+ is_instrumental = True
163
+ elif 'pure music' in description_lower or 'pure instrument' in description_lower:
164
+ is_instrumental = True
165
+ elif description_lower.endswith(' solo') or description_lower == 'solo':
166
+ is_instrumental = True
167
+
168
+ return detected_language, is_instrumental
169
+
170
+
171
+ JobStatus = Literal["queued", "running", "succeeded", "failed"]
172
+
173
+
174
+ class GenerateMusicRequest(BaseModel):
175
+ prompt: str = Field(default="", description="Text prompt describing the music")
176
+ lyrics: str = Field(default="", description="Lyric text")
177
+
178
+ # New API semantics:
179
+ # - thinking=True: use 5Hz LM to generate audio codes (lm-dit behavior)
180
+ # - thinking=False: do not use LM to generate codes (dit behavior)
181
+ # Regardless of thinking, if some metas are missing, server may use LM to fill them.
182
+ thinking: bool = False
183
+ # Sample-mode requests auto-generate caption/lyrics/metas via LM (no user prompt).
184
+ sample_mode: bool = False
185
+ # Description for sample mode: auto-generate caption/lyrics from description query
186
+ sample_query: str = Field(default="", description="Query/description for sample mode (use create_sample)")
187
+ # Whether to use format_sample() to enhance input caption/lyrics
188
+ use_format: bool = Field(default=False, description="Use format_sample() to enhance input (default: False)")
189
+ # Model name for multi-model support (select which DiT model to use)
190
+ model: Optional[str] = Field(default=None, description="Model name to use (e.g., 'acestep-v15-turbo')")
191
+
192
+ bpm: Optional[int] = None
193
+ # Accept common client keys via manual parsing (see RequestParser).
194
+ key_scale: str = ""
195
+ time_signature: str = ""
196
+ vocal_language: str = "en"
197
+ inference_steps: int = 8
198
+ guidance_scale: float = 7.0
199
+ use_random_seed: bool = True
200
+ seed: int = -1
201
+
202
+ reference_audio_path: Optional[str] = None
203
+ src_audio_path: Optional[str] = None
204
+ audio_duration: Optional[float] = None
205
+ batch_size: Optional[int] = None
206
+
207
+ audio_code_string: str = ""
208
+
209
+ repainting_start: float = 0.0
210
+ repainting_end: Optional[float] = None
211
+
212
+ instruction: str = DEFAULT_DIT_INSTRUCTION
213
+ audio_cover_strength: float = 1.0
214
+ task_type: str = "text2music"
215
+
216
+ use_adg: bool = False
217
+ cfg_interval_start: float = 0.0
218
+ cfg_interval_end: float = 1.0
219
+ infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
220
+ shift: float = Field(
221
+ default=3.0,
222
+ description="Timestep shift factor (range 1.0~5.0, default 3.0). Only effective for base models, not turbo models."
223
+ )
224
+ timesteps: Optional[str] = Field(
225
+ default=None,
226
+ description="Custom timesteps (comma-separated, e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference_steps and shift."
227
+ )
228
+
229
+ audio_format: str = "mp3"
230
+ use_tiled_decode: bool = True
231
+
232
+ # 5Hz LM (server-side): used for metadata completion and (when thinking=True) codes generation.
233
+ lm_model_path: Optional[str] = None # e.g. "acestep-5Hz-lm-0.6B"
234
+ lm_backend: Literal["vllm", "pt"] = "vllm"
235
+
236
+ constrained_decoding: bool = True
237
+ constrained_decoding_debug: bool = False
238
+ use_cot_caption: bool = True
239
+ use_cot_language: bool = True
240
+ is_format_caption: bool = False
241
+
242
+ lm_temperature: float = 0.85
243
+ lm_cfg_scale: float = 2.5
244
+ lm_top_k: Optional[int] = None
245
+ lm_top_p: Optional[float] = 0.9
246
+ lm_repetition_penalty: float = 1.0
247
+ lm_negative_prompt: str = "NO USER INPUT"
248
+
249
+ class Config:
250
+ allow_population_by_field_name = True
251
+ allow_population_by_alias = True
252
+
253
+
254
+ class CreateJobResponse(BaseModel):
255
+ task_id: str
256
+ status: JobStatus
257
+ queue_position: int = 0 # 1-based best-effort position when queued
258
+
259
+
260
+ class JobResult(BaseModel):
261
+ first_audio_path: Optional[str] = None
262
+ second_audio_path: Optional[str] = None
263
+ audio_paths: list[str] = Field(default_factory=list)
264
+
265
+ generation_info: str = ""
266
+ status_message: str = ""
267
+ seed_value: str = ""
268
+
269
+ metas: Dict[str, Any] = Field(default_factory=dict)
270
+ bpm: Optional[int] = None
271
+ duration: Optional[float] = None
272
+ genres: Optional[str] = None
273
+ keyscale: Optional[str] = None
274
+ timesignature: Optional[str] = None
275
+
276
+ # Model information
277
+ lm_model: Optional[str] = None
278
+ dit_model: Optional[str] = None
279
+
280
+
281
+ class JobResponse(BaseModel):
282
+ job_id: str
283
+ status: JobStatus
284
+ created_at: float
285
+ started_at: Optional[float] = None
286
+ finished_at: Optional[float] = None
287
+
288
+ # queue observability
289
+ queue_position: int = 0
290
+ eta_seconds: Optional[float] = None
291
+ avg_job_seconds: Optional[float] = None
292
+
293
+ result: Optional[JobResult] = None
294
+ error: Optional[str] = None
295
+
296
+
297
+ @dataclass
298
+ class _JobRecord:
299
+ job_id: str
300
+ status: JobStatus
301
+ created_at: float
302
+ started_at: Optional[float] = None
303
+ finished_at: Optional[float] = None
304
+ result: Optional[Dict[str, Any]] = None
305
+ error: Optional[str] = None
306
+ env: str = "development"
307
+
308
+
309
+ class _JobStore:
310
+ def __init__(self) -> None:
311
+ self._lock = Lock()
312
+ self._jobs: Dict[str, _JobRecord] = {}
313
+
314
+ def create(self) -> _JobRecord:
315
+ job_id = str(uuid4())
316
+ rec = _JobRecord(job_id=job_id, status="queued", created_at=time.time())
317
+ with self._lock:
318
+ self._jobs[job_id] = rec
319
+ return rec
320
+
321
+ def create_with_id(self, job_id: str, env: str = "development") -> _JobRecord:
322
+ """Create job record with specified ID"""
323
+ rec = _JobRecord(
324
+ job_id=job_id,
325
+ status="queued",
326
+ created_at=time.time(),
327
+ env=env
328
+ )
329
+ with self._lock:
330
+ self._jobs[job_id] = rec
331
+ return rec
332
+
333
+ def get(self, job_id: str) -> Optional[_JobRecord]:
334
+ with self._lock:
335
+ return self._jobs.get(job_id)
336
+
337
+ def mark_running(self, job_id: str) -> None:
338
+ with self._lock:
339
+ rec = self._jobs[job_id]
340
+ rec.status = "running"
341
+ rec.started_at = time.time()
342
+
343
+ def mark_succeeded(self, job_id: str, result: Dict[str, Any]) -> None:
344
+ with self._lock:
345
+ rec = self._jobs[job_id]
346
+ rec.status = "succeeded"
347
+ rec.finished_at = time.time()
348
+ rec.result = result
349
+ rec.error = None
350
+
351
+ def mark_failed(self, job_id: str, error: str) -> None:
352
+ with self._lock:
353
+ rec = self._jobs[job_id]
354
+ rec.status = "failed"
355
+ rec.finished_at = time.time()
356
+ rec.result = None
357
+ rec.error = error
358
+
359
+
360
+ def _env_bool(name: str, default: bool) -> bool:
361
+ v = os.getenv(name)
362
+ if v is None:
363
+ return default
364
+ return v.strip().lower() in {"1", "true", "yes", "y", "on"}
365
+
366
+
367
+ def _get_project_root() -> str:
368
+ current_file = os.path.abspath(__file__)
369
+ return os.path.dirname(os.path.dirname(current_file))
370
+
371
+
372
+ def _get_model_name(config_path: str) -> str:
373
+ """
374
+ Extract model name from config_path.
375
+
376
+ Args:
377
+ config_path: Path like "acestep-v15-turbo" or "/path/to/acestep-v15-turbo"
378
+
379
+ Returns:
380
+ Model name (last directory name from config_path)
381
+ """
382
+ if not config_path:
383
+ return ""
384
+ normalized = config_path.rstrip("/\\")
385
+ return os.path.basename(normalized)
386
+
387
+
388
+ def _load_project_env() -> None:
389
+ if load_dotenv is None:
390
+ return
391
+ try:
392
+ project_root = _get_project_root()
393
+ env_path = os.path.join(project_root, ".env")
394
+ if os.path.exists(env_path):
395
+ load_dotenv(env_path, override=False)
396
+ except Exception:
397
+ # Optional best-effort: continue even if .env loading fails.
398
+ pass
399
+
400
+
401
+ _load_project_env()
402
+
403
+
404
+ def _to_int(v: Any, default: Optional[int] = None) -> Optional[int]:
405
+ if v is None:
406
+ return default
407
+ if isinstance(v, int):
408
+ return v
409
+ s = str(v).strip()
410
+ if s == "":
411
+ return default
412
+ try:
413
+ return int(s)
414
+ except Exception:
415
+ return default
416
+
417
+
418
+ def _to_float(v: Any, default: Optional[float] = None) -> Optional[float]:
419
+ if v is None:
420
+ return default
421
+ if isinstance(v, float):
422
+ return v
423
+ s = str(v).strip()
424
+ if s == "":
425
+ return default
426
+ try:
427
+ return float(s)
428
+ except Exception:
429
+ return default
430
+
431
+
432
+ def _to_bool(v: Any, default: bool = False) -> bool:
433
+ if v is None:
434
+ return default
435
+ if isinstance(v, bool):
436
+ return v
437
+ s = str(v).strip().lower()
438
+ if s == "":
439
+ return default
440
+ return s in {"1", "true", "yes", "y", "on"}
441
+
442
+
443
+ def _map_status(status: str) -> int:
444
+ """Map job status string to integer code."""
445
+ return STATUS_MAP.get(status, 2)
446
+
447
+
448
+ def _parse_timesteps(s: Optional[str]) -> Optional[List[float]]:
449
+ """Parse comma-separated timesteps string to list of floats."""
450
+ if not s or not s.strip():
451
+ return None
452
+ try:
453
+ return [float(t.strip()) for t in s.split(",") if t.strip()]
454
+ except (ValueError, Exception):
455
+ return None
456
+
457
+
458
+ class RequestParser:
459
+ """Parse request parameters from multiple sources with alias support."""
460
+
461
+ def __init__(self, raw: dict):
462
+ self._raw = dict(raw) if raw else {}
463
+ self._param_obj = self._parse_json(self._raw.get("param_obj"))
464
+ self._metas = self._find_metas()
465
+
466
+ def _parse_json(self, v) -> dict:
467
+ if isinstance(v, dict):
468
+ return v
469
+ if isinstance(v, str) and v.strip():
470
+ try:
471
+ return json.loads(v)
472
+ except Exception:
473
+ pass
474
+ return {}
475
+
476
+ def _find_metas(self) -> dict:
477
+ for key in ("metas", "meta", "metadata", "user_metadata", "userMetadata"):
478
+ v = self._raw.get(key)
479
+ if v:
480
+ return self._parse_json(v)
481
+ return {}
482
+
483
+ def get(self, name: str, default=None):
484
+ """Get parameter by canonical name from all sources."""
485
+ aliases = PARAM_ALIASES.get(name, [name])
486
+ for source in (self._raw, self._param_obj, self._metas):
487
+ for alias in aliases:
488
+ v = source.get(alias)
489
+ if v is not None:
490
+ return v
491
+ return default
492
+
493
+ def str(self, name: str, default: str = "") -> str:
494
+ v = self.get(name)
495
+ return str(v) if v is not None else default
496
+
497
+ def int(self, name: str, default: Optional[int] = None) -> Optional[int]:
498
+ return _to_int(self.get(name), default)
499
+
500
+ def float(self, name: str, default: Optional[float] = None) -> Optional[float]:
501
+ return _to_float(self.get(name), default)
502
+
503
+ def bool(self, name: str, default: bool = False) -> bool:
504
+ return _to_bool(self.get(name), default)
505
+
506
+
507
+ async def _save_upload_to_temp(upload: StarletteUploadFile, *, prefix: str) -> str:
508
+ suffix = Path(upload.filename or "").suffix
509
+ fd, path = tempfile.mkstemp(prefix=f"{prefix}_", suffix=suffix)
510
+ os.close(fd)
511
+ try:
512
+ with open(path, "wb") as f:
513
+ while True:
514
+ chunk = await upload.read(1024 * 1024)
515
+ if not chunk:
516
+ break
517
+ f.write(chunk)
518
+ except Exception:
519
+ try:
520
+ os.remove(path)
521
+ except Exception:
522
+ pass
523
+ raise
524
+ finally:
525
+ try:
526
+ await upload.close()
527
+ except Exception:
528
+ pass
529
+ return path
530
+
531
+
532
+ def create_app() -> FastAPI:
533
+ store = _JobStore()
534
+
535
+ QUEUE_MAXSIZE = int(os.getenv("ACESTEP_QUEUE_MAXSIZE", "200"))
536
+ WORKER_COUNT = int(os.getenv("ACESTEP_QUEUE_WORKERS", "1")) # Single GPU recommended
537
+
538
+ INITIAL_AVG_JOB_SECONDS = float(os.getenv("ACESTEP_AVG_JOB_SECONDS", "5.0"))
539
+ AVG_WINDOW = int(os.getenv("ACESTEP_AVG_WINDOW", "50"))
540
+
541
+ def _path_to_audio_url(path: str) -> str:
542
+ """Convert local file path to downloadable relative URL"""
543
+ if not path:
544
+ return path
545
+ if path.startswith("http://") or path.startswith("https://"):
546
+ return path
547
+ encoded_path = urllib.parse.quote(path, safe="")
548
+ return f"/v1/audio?path={encoded_path}"
549
+
550
+ @asynccontextmanager
551
+ async def lifespan(app: FastAPI):
552
+ # Clear proxy env that may affect downstream libs
553
+ for proxy_var in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY"]:
554
+ os.environ.pop(proxy_var, None)
555
+
556
+ # Ensure compilation/temp caches do not fill up small default /tmp.
557
+ # Triton/Inductor (and the system compiler) can create large temporary files.
558
+ project_root = _get_project_root()
559
+ cache_root = os.path.join(project_root, ".cache", "acestep")
560
+ tmp_root = (os.getenv("ACESTEP_TMPDIR") or os.path.join(cache_root, "tmp")).strip()
561
+ triton_cache_root = (os.getenv("TRITON_CACHE_DIR") or os.path.join(cache_root, "triton")).strip()
562
+ inductor_cache_root = (os.getenv("TORCHINDUCTOR_CACHE_DIR") or os.path.join(cache_root, "torchinductor")).strip()
563
+
564
+ for p in [cache_root, tmp_root, triton_cache_root, inductor_cache_root]:
565
+ try:
566
+ os.makedirs(p, exist_ok=True)
567
+ except Exception:
568
+ # Best-effort: do not block startup if directory creation fails.
569
+ pass
570
+
571
+ # Respect explicit user overrides; if ACESTEP_TMPDIR is set, it should win.
572
+ if os.getenv("ACESTEP_TMPDIR"):
573
+ os.environ["TMPDIR"] = tmp_root
574
+ os.environ["TEMP"] = tmp_root
575
+ os.environ["TMP"] = tmp_root
576
+ else:
577
+ os.environ.setdefault("TMPDIR", tmp_root)
578
+ os.environ.setdefault("TEMP", tmp_root)
579
+ os.environ.setdefault("TMP", tmp_root)
580
+
581
+ os.environ.setdefault("TRITON_CACHE_DIR", triton_cache_root)
582
+ os.environ.setdefault("TORCHINDUCTOR_CACHE_DIR", inductor_cache_root)
583
+
584
+ handler = AceStepHandler()
585
+ llm_handler = LLMHandler()
586
+ init_lock = asyncio.Lock()
587
+ app.state._initialized = False
588
+ app.state._init_error = None
589
+ app.state._init_lock = init_lock
590
+
591
+ app.state.llm_handler = llm_handler
592
+ app.state._llm_initialized = False
593
+ app.state._llm_init_error = None
594
+ app.state._llm_init_lock = Lock()
595
+
596
+ # Multi-model support: secondary DiT handlers
597
+ handler2 = None
598
+ handler3 = None
599
+ config_path2 = os.getenv("ACESTEP_CONFIG_PATH2", "").strip()
600
+ config_path3 = os.getenv("ACESTEP_CONFIG_PATH3", "").strip()
601
+
602
+ if config_path2:
603
+ handler2 = AceStepHandler()
604
+ if config_path3:
605
+ handler3 = AceStepHandler()
606
+
607
+ app.state.handler2 = handler2
608
+ app.state.handler3 = handler3
609
+ app.state._initialized2 = False
610
+ app.state._initialized3 = False
611
+ app.state._config_path = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo")
612
+ app.state._config_path2 = config_path2
613
+ app.state._config_path3 = config_path3
614
+
615
+ max_workers = int(os.getenv("ACESTEP_API_WORKERS", "1"))
616
+ executor = ThreadPoolExecutor(max_workers=max_workers)
617
+
618
+ # Queue & observability
619
+ app.state.job_queue = asyncio.Queue(maxsize=QUEUE_MAXSIZE) # (job_id, req)
620
+ app.state.pending_ids = deque() # queued job_ids
621
+ app.state.pending_lock = asyncio.Lock()
622
+
623
+ # temp files per job (from multipart uploads)
624
+ app.state.job_temp_files = {} # job_id -> list[path]
625
+ app.state.job_temp_files_lock = asyncio.Lock()
626
+
627
+ # stats
628
+ app.state.stats_lock = asyncio.Lock()
629
+ app.state.recent_durations = deque(maxlen=AVG_WINDOW)
630
+ app.state.avg_job_seconds = INITIAL_AVG_JOB_SECONDS
631
+
632
+ app.state.handler = handler
633
+ app.state.executor = executor
634
+ app.state.job_store = store
635
+ app.state._python_executable = sys.executable
636
+
637
+ # Temporary directory for saving generated audio files
638
+ app.state.temp_audio_dir = os.path.join(tmp_root, "api_audio")
639
+ os.makedirs(app.state.temp_audio_dir, exist_ok=True)
640
+
641
+ # Initialize local cache
642
+ try:
643
+ from acestep.local_cache import get_local_cache
644
+ local_cache_dir = os.path.join(cache_root, "local_redis")
645
+ app.state.local_cache = get_local_cache(local_cache_dir)
646
+ except ImportError:
647
+ app.state.local_cache = None
648
+
649
+ async def _ensure_initialized() -> None:
650
+ h: AceStepHandler = app.state.handler
651
+
652
+ if getattr(app.state, "_initialized", False):
653
+ return
654
+ if getattr(app.state, "_init_error", None):
655
+ raise RuntimeError(app.state._init_error)
656
+
657
+ async with app.state._init_lock:
658
+ if getattr(app.state, "_initialized", False):
659
+ return
660
+ if getattr(app.state, "_init_error", None):
661
+ raise RuntimeError(app.state._init_error)
662
+
663
+ project_root = _get_project_root()
664
+ config_path = os.getenv("ACESTEP_CONFIG_PATH", "acestep-v15-turbo")
665
+ device = os.getenv("ACESTEP_DEVICE", "auto")
666
+
667
+ use_flash_attention = _env_bool("ACESTEP_USE_FLASH_ATTENTION", True)
668
+ offload_to_cpu = _env_bool("ACESTEP_OFFLOAD_TO_CPU", False)
669
+ offload_dit_to_cpu = _env_bool("ACESTEP_OFFLOAD_DIT_TO_CPU", False)
670
+
671
+ # Initialize primary model
672
+ status_msg, ok = h.initialize_service(
673
+ project_root=project_root,
674
+ config_path=config_path,
675
+ device=device,
676
+ use_flash_attention=use_flash_attention,
677
+ compile_model=False,
678
+ offload_to_cpu=offload_to_cpu,
679
+ offload_dit_to_cpu=offload_dit_to_cpu,
680
+ )
681
+ if not ok:
682
+ app.state._init_error = status_msg
683
+ raise RuntimeError(status_msg)
684
+ app.state._initialized = True
685
+
686
+ # Initialize secondary model if configured
687
+ if app.state.handler2 and app.state._config_path2:
688
+ try:
689
+ status_msg2, ok2 = app.state.handler2.initialize_service(
690
+ project_root=project_root,
691
+ config_path=app.state._config_path2,
692
+ device=device,
693
+ use_flash_attention=use_flash_attention,
694
+ compile_model=False,
695
+ offload_to_cpu=offload_to_cpu,
696
+ offload_dit_to_cpu=offload_dit_to_cpu,
697
+ )
698
+ app.state._initialized2 = ok2
699
+ if ok2:
700
+ print(f"[API Server] Secondary model loaded: {_get_model_name(app.state._config_path2)}")
701
+ else:
702
+ print(f"[API Server] Warning: Secondary model failed to load: {status_msg2}")
703
+ except Exception as e:
704
+ print(f"[API Server] Warning: Failed to initialize secondary model: {e}")
705
+ app.state._initialized2 = False
706
+
707
+ # Initialize third model if configured
708
+ if app.state.handler3 and app.state._config_path3:
709
+ try:
710
+ status_msg3, ok3 = app.state.handler3.initialize_service(
711
+ project_root=project_root,
712
+ config_path=app.state._config_path3,
713
+ device=device,
714
+ use_flash_attention=use_flash_attention,
715
+ compile_model=False,
716
+ offload_to_cpu=offload_to_cpu,
717
+ offload_dit_to_cpu=offload_dit_to_cpu,
718
+ )
719
+ app.state._initialized3 = ok3
720
+ if ok3:
721
+ print(f"[API Server] Third model loaded: {_get_model_name(app.state._config_path3)}")
722
+ else:
723
+ print(f"[API Server] Warning: Third model failed to load: {status_msg3}")
724
+ except Exception as e:
725
+ print(f"[API Server] Warning: Failed to initialize third model: {e}")
726
+ app.state._initialized3 = False
727
+
728
+ async def _cleanup_job_temp_files(job_id: str) -> None:
729
+ async with app.state.job_temp_files_lock:
730
+ paths = app.state.job_temp_files.pop(job_id, [])
731
+ for p in paths:
732
+ try:
733
+ os.remove(p)
734
+ except Exception:
735
+ pass
736
+
737
+ def _update_local_cache(job_id: str, result: Optional[Dict], status: str) -> None:
738
+ """Update local cache with job result"""
739
+ local_cache = getattr(app.state, 'local_cache', None)
740
+ if not local_cache:
741
+ return
742
+
743
+ rec = store.get(job_id)
744
+ env = getattr(rec, 'env', 'development') if rec else 'development'
745
+ create_time = rec.created_at if rec else time.time()
746
+
747
+ status_int = _map_status(status)
748
+
749
+ if status == "succeeded" and result:
750
+ audio_paths = result.get("audio_paths", [])
751
+ # Final prompt/lyrics (may be modified by thinking/format)
752
+ final_prompt = result.get("prompt", "")
753
+ final_lyrics = result.get("lyrics", "")
754
+ # Original user input from metas
755
+ metas_raw = result.get("metas", {}) or {}
756
+ original_prompt = metas_raw.get("prompt", "")
757
+ original_lyrics = metas_raw.get("lyrics", "")
758
+ # metas contains original input + other metadata
759
+ metas = {
760
+ "bpm": metas_raw.get("bpm"),
761
+ "duration": metas_raw.get("duration"),
762
+ "genres": metas_raw.get("genres", ""),
763
+ "keyscale": metas_raw.get("keyscale", ""),
764
+ "timesignature": metas_raw.get("timesignature", ""),
765
+ "prompt": original_prompt,
766
+ "lyrics": original_lyrics,
767
+ }
768
+ # Extra fields for Discord bot
769
+ generation_info = result.get("generation_info", "")
770
+ seed_value = result.get("seed_value", "")
771
+ lm_model = result.get("lm_model", "")
772
+ dit_model = result.get("dit_model", "")
773
+
774
+ if audio_paths:
775
+ result_data = [
776
+ {
777
+ "file": p,
778
+ "wave": "",
779
+ "status": status_int,
780
+ "create_time": int(create_time),
781
+ "env": env,
782
+ "prompt": final_prompt,
783
+ "lyrics": final_lyrics,
784
+ "metas": metas,
785
+ "generation_info": generation_info,
786
+ "seed_value": seed_value,
787
+ "lm_model": lm_model,
788
+ "dit_model": dit_model,
789
+ }
790
+ for p in audio_paths
791
+ ]
792
+ else:
793
+ result_data = [{
794
+ "file": "",
795
+ "wave": "",
796
+ "status": status_int,
797
+ "create_time": int(create_time),
798
+ "env": env,
799
+ "prompt": final_prompt,
800
+ "lyrics": final_lyrics,
801
+ "metas": metas,
802
+ "generation_info": generation_info,
803
+ "seed_value": seed_value,
804
+ "lm_model": lm_model,
805
+ "dit_model": dit_model,
806
+ }]
807
+ else:
808
+ result_data = [{"file": "", "wave": "", "status": status_int, "create_time": int(create_time), "env": env}]
809
+
810
+ result_key = f"{RESULT_KEY_PREFIX}{job_id}"
811
+ local_cache.set(result_key, result_data, ex=RESULT_EXPIRE_SECONDS)
812
+
813
+ async def _run_one_job(job_id: str, req: GenerateMusicRequest) -> None:
814
+ job_store: _JobStore = app.state.job_store
815
+ llm: LLMHandler = app.state.llm_handler
816
+ executor: ThreadPoolExecutor = app.state.executor
817
+
818
+ await _ensure_initialized()
819
+ job_store.mark_running(job_id)
820
+
821
+ # Select DiT handler based on user's model choice
822
+ # Default: use primary handler
823
+ selected_handler: AceStepHandler = app.state.handler
824
+ selected_model_name = _get_model_name(app.state._config_path)
825
+
826
+ if req.model:
827
+ model_matched = False
828
+
829
+ # Check if it matches the second model
830
+ if app.state.handler2 and getattr(app.state, "_initialized2", False):
831
+ model2_name = _get_model_name(app.state._config_path2)
832
+ if req.model == model2_name:
833
+ selected_handler = app.state.handler2
834
+ selected_model_name = model2_name
835
+ model_matched = True
836
+ print(f"[API Server] Job {job_id}: Using second model: {model2_name}")
837
+
838
+ # Check if it matches the third model
839
+ if not model_matched and app.state.handler3 and getattr(app.state, "_initialized3", False):
840
+ model3_name = _get_model_name(app.state._config_path3)
841
+ if req.model == model3_name:
842
+ selected_handler = app.state.handler3
843
+ selected_model_name = model3_name
844
+ model_matched = True
845
+ print(f"[API Server] Job {job_id}: Using third model: {model3_name}")
846
+
847
+ if not model_matched:
848
+ available_models = [_get_model_name(app.state._config_path)]
849
+ if app.state.handler2 and getattr(app.state, "_initialized2", False):
850
+ available_models.append(_get_model_name(app.state._config_path2))
851
+ if app.state.handler3 and getattr(app.state, "_initialized3", False):
852
+ available_models.append(_get_model_name(app.state._config_path3))
853
+ print(f"[API Server] Job {job_id}: Model '{req.model}' not found in {available_models}, using primary: {selected_model_name}")
854
+
855
+ # Use selected handler for generation
856
+ h: AceStepHandler = selected_handler
857
+
858
+ def _blocking_generate() -> Dict[str, Any]:
859
+ """Generate music using unified inference logic from acestep.inference"""
860
+
861
+ def _ensure_llm_ready() -> None:
862
+ """Ensure LLM handler is initialized when needed"""
863
+ with app.state._llm_init_lock:
864
+ initialized = getattr(app.state, "_llm_initialized", False)
865
+ had_error = getattr(app.state, "_llm_init_error", None)
866
+ if initialized or had_error is not None:
867
+ return
868
+
869
+ project_root = _get_project_root()
870
+ checkpoint_dir = os.path.join(project_root, "checkpoints")
871
+ lm_model_path = (req.lm_model_path or os.getenv("ACESTEP_LM_MODEL_PATH") or "acestep-5Hz-lm-0.6B").strip()
872
+ backend = (req.lm_backend or os.getenv("ACESTEP_LM_BACKEND") or "vllm").strip().lower()
873
+ if backend not in {"vllm", "pt"}:
874
+ backend = "vllm"
875
+
876
+ lm_device = os.getenv("ACESTEP_LM_DEVICE", os.getenv("ACESTEP_DEVICE", "auto"))
877
+ lm_offload = _env_bool("ACESTEP_LM_OFFLOAD_TO_CPU", False)
878
+
879
+ status, ok = llm.initialize(
880
+ checkpoint_dir=checkpoint_dir,
881
+ lm_model_path=lm_model_path,
882
+ backend=backend,
883
+ device=lm_device,
884
+ offload_to_cpu=lm_offload,
885
+ dtype=h.dtype,
886
+ )
887
+ if not ok:
888
+ app.state._llm_init_error = status
889
+ else:
890
+ app.state._llm_initialized = True
891
+
892
+ def _normalize_metas(meta: Dict[str, Any]) -> Dict[str, Any]:
893
+ """Ensure a stable `metas` dict (keys always present)."""
894
+ meta = meta or {}
895
+ out: Dict[str, Any] = dict(meta)
896
+
897
+ # Normalize key aliases
898
+ if "keyscale" not in out and "key_scale" in out:
899
+ out["keyscale"] = out.get("key_scale")
900
+ if "timesignature" not in out and "time_signature" in out:
901
+ out["timesignature"] = out.get("time_signature")
902
+
903
+ # Ensure required keys exist
904
+ for k in ["bpm", "duration", "genres", "keyscale", "timesignature"]:
905
+ if out.get(k) in (None, ""):
906
+ out[k] = "N/A"
907
+ return out
908
+
909
+ # Normalize LM sampling parameters
910
+ lm_top_k = req.lm_top_k if req.lm_top_k and req.lm_top_k > 0 else 0
911
+ lm_top_p = req.lm_top_p if req.lm_top_p and req.lm_top_p < 1.0 else 0.9
912
+
913
+ # Determine if LLM is needed
914
+ thinking = bool(req.thinking)
915
+ sample_mode = bool(req.sample_mode)
916
+ has_sample_query = bool(req.sample_query and req.sample_query.strip())
917
+ use_format = bool(req.use_format)
918
+ use_cot_caption = bool(req.use_cot_caption)
919
+ use_cot_language = bool(req.use_cot_language)
920
+
921
+ # LLM is needed for:
922
+ # - thinking mode (LM generates audio codes)
923
+ # - sample_mode (LM generates random caption/lyrics/metas)
924
+ # - sample_query/description (LM generates from description)
925
+ # - use_format (LM enhances caption/lyrics)
926
+ # - use_cot_caption or use_cot_language (LM enhances metadata)
927
+ need_llm = thinking or sample_mode or has_sample_query or use_format or use_cot_caption or use_cot_language
928
+
929
+ # Ensure LLM is ready if needed
930
+ if need_llm:
931
+ _ensure_llm_ready()
932
+ if getattr(app.state, "_llm_init_error", None):
933
+ raise RuntimeError(f"5Hz LM init failed: {app.state._llm_init_error}")
934
+
935
+ # Handle sample mode or description: generate caption/lyrics/metas via LM
936
+ caption = req.prompt
937
+ lyrics = req.lyrics
938
+ bpm = req.bpm
939
+ key_scale = req.key_scale
940
+ time_signature = req.time_signature
941
+ audio_duration = req.audio_duration
942
+
943
+ # Save original user input for metas
944
+ original_prompt = req.prompt or ""
945
+ original_lyrics = req.lyrics or ""
946
+
947
+ if sample_mode or has_sample_query:
948
+ if has_sample_query:
949
+ # Use create_sample() with description query
950
+ parsed_language, parsed_instrumental = _parse_description_hints(req.sample_query)
951
+
952
+ # Determine vocal_language with priority:
953
+ # 1. User-specified vocal_language (if not default "en")
954
+ # 2. Language parsed from description
955
+ # 3. None (no constraint)
956
+ if req.vocal_language and req.vocal_language not in ("en", "unknown", ""):
957
+ sample_language = req.vocal_language
958
+ else:
959
+ sample_language = parsed_language
960
+
961
+ sample_result = create_sample(
962
+ llm_handler=llm,
963
+ query=req.sample_query,
964
+ instrumental=parsed_instrumental,
965
+ vocal_language=sample_language,
966
+ temperature=req.lm_temperature,
967
+ top_k=lm_top_k if lm_top_k > 0 else None,
968
+ top_p=lm_top_p if lm_top_p < 1.0 else None,
969
+ use_constrained_decoding=req.constrained_decoding,
970
+ )
971
+
972
+ if not sample_result.success:
973
+ raise RuntimeError(f"create_sample failed: {sample_result.error or sample_result.status_message}")
974
+
975
+ # Use generated sample data
976
+ caption = sample_result.caption
977
+ lyrics = sample_result.lyrics
978
+ bpm = sample_result.bpm
979
+ key_scale = sample_result.keyscale
980
+ time_signature = sample_result.timesignature
981
+ audio_duration = sample_result.duration
982
+ else:
983
+ # Original sample_mode behavior: random generation
984
+ sample_metadata, sample_status = llm.understand_audio_from_codes(
985
+ audio_codes="NO USER INPUT",
986
+ temperature=req.lm_temperature,
987
+ top_k=lm_top_k if lm_top_k > 0 else None,
988
+ top_p=lm_top_p if lm_top_p < 1.0 else None,
989
+ repetition_penalty=req.lm_repetition_penalty,
990
+ use_constrained_decoding=req.constrained_decoding,
991
+ constrained_decoding_debug=req.constrained_decoding_debug,
992
+ )
993
+
994
+ if not sample_metadata or str(sample_status).startswith("❌"):
995
+ raise RuntimeError(f"Sample generation failed: {sample_status}")
996
+
997
+ # Use generated values with fallback defaults
998
+ caption = sample_metadata.get("caption", "")
999
+ lyrics = sample_metadata.get("lyrics", "")
1000
+ bpm = _to_int(sample_metadata.get("bpm"), None) or _to_int(os.getenv("ACESTEP_SAMPLE_DEFAULT_BPM", "120"), 120)
1001
+ key_scale = sample_metadata.get("keyscale", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_KEY", "C Major")
1002
+ time_signature = sample_metadata.get("timesignature", "") or os.getenv("ACESTEP_SAMPLE_DEFAULT_TIMESIGNATURE", "4/4")
1003
+ audio_duration = _to_float(sample_metadata.get("duration"), None) or _to_float(os.getenv("ACESTEP_SAMPLE_DEFAULT_DURATION_SECONDS", "120"), 120.0)
1004
+
1005
+ # Apply format_sample() if use_format is True and caption/lyrics are provided
1006
+ format_has_duration = False
1007
+
1008
+ if req.use_format and (caption or lyrics):
1009
+ _ensure_llm_ready()
1010
+ if getattr(app.state, "_llm_init_error", None):
1011
+ raise RuntimeError(f"5Hz LM init failed (needed for format): {app.state._llm_init_error}")
1012
+
1013
+ # Build user_metadata from request params (matching bot.py behavior)
1014
+ user_metadata_for_format = {}
1015
+ if bpm is not None:
1016
+ user_metadata_for_format['bpm'] = bpm
1017
+ if audio_duration is not None and audio_duration > 0:
1018
+ user_metadata_for_format['duration'] = int(audio_duration)
1019
+ if key_scale:
1020
+ user_metadata_for_format['keyscale'] = key_scale
1021
+ if time_signature:
1022
+ user_metadata_for_format['timesignature'] = time_signature
1023
+ if req.vocal_language and req.vocal_language != "unknown":
1024
+ user_metadata_for_format['language'] = req.vocal_language
1025
+
1026
+ format_result = format_sample(
1027
+ llm_handler=llm,
1028
+ caption=caption,
1029
+ lyrics=lyrics,
1030
+ user_metadata=user_metadata_for_format if user_metadata_for_format else None,
1031
+ temperature=req.lm_temperature,
1032
+ top_k=lm_top_k if lm_top_k > 0 else None,
1033
+ top_p=lm_top_p if lm_top_p < 1.0 else None,
1034
+ use_constrained_decoding=req.constrained_decoding,
1035
+ )
1036
+
1037
+ if format_result.success:
1038
+ # Extract all formatted data (matching bot.py behavior)
1039
+ caption = format_result.caption or caption
1040
+ lyrics = format_result.lyrics or lyrics
1041
+ if format_result.duration:
1042
+ audio_duration = format_result.duration
1043
+ format_has_duration = True
1044
+ if format_result.bpm:
1045
+ bpm = format_result.bpm
1046
+ if format_result.keyscale:
1047
+ key_scale = format_result.keyscale
1048
+ if format_result.timesignature:
1049
+ time_signature = format_result.timesignature
1050
+
1051
+ # Parse timesteps string to list of floats if provided
1052
+ parsed_timesteps = _parse_timesteps(req.timesteps)
1053
+
1054
+ # Determine actual inference steps (timesteps override inference_steps)
1055
+ actual_inference_steps = len(parsed_timesteps) if parsed_timesteps else req.inference_steps
1056
+
1057
+ # Auto-select instruction based on task_type if user didn't provide custom instruction
1058
+ # This matches gradio behavior which uses TASK_INSTRUCTIONS for each task type
1059
+ instruction_to_use = req.instruction
1060
+ if instruction_to_use == DEFAULT_DIT_INSTRUCTION and req.task_type in TASK_INSTRUCTIONS:
1061
+ instruction_to_use = TASK_INSTRUCTIONS[req.task_type]
1062
+
1063
+ # Build GenerationParams using unified interface
1064
+ # Note: thinking controls LM code generation, sample_mode only affects CoT metas
1065
+ params = GenerationParams(
1066
+ task_type=req.task_type,
1067
+ instruction=instruction_to_use,
1068
+ reference_audio=req.reference_audio_path,
1069
+ src_audio=req.src_audio_path,
1070
+ audio_codes=req.audio_code_string,
1071
+ caption=caption,
1072
+ lyrics=lyrics,
1073
+ instrumental=False,
1074
+ vocal_language=req.vocal_language,
1075
+ bpm=bpm,
1076
+ keyscale=key_scale,
1077
+ timesignature=time_signature,
1078
+ duration=audio_duration if audio_duration else -1.0,
1079
+ inference_steps=req.inference_steps,
1080
+ seed=req.seed,
1081
+ guidance_scale=req.guidance_scale,
1082
+ use_adg=req.use_adg,
1083
+ cfg_interval_start=req.cfg_interval_start,
1084
+ cfg_interval_end=req.cfg_interval_end,
1085
+ shift=req.shift,
1086
+ infer_method=req.infer_method,
1087
+ timesteps=parsed_timesteps,
1088
+ repainting_start=req.repainting_start,
1089
+ repainting_end=req.repainting_end if req.repainting_end else -1,
1090
+ audio_cover_strength=req.audio_cover_strength,
1091
+ # LM parameters
1092
+ thinking=thinking, # Use LM for code generation when thinking=True
1093
+ lm_temperature=req.lm_temperature,
1094
+ lm_cfg_scale=req.lm_cfg_scale,
1095
+ lm_top_k=lm_top_k,
1096
+ lm_top_p=lm_top_p,
1097
+ lm_negative_prompt=req.lm_negative_prompt,
1098
+ # use_cot_metas logic:
1099
+ # - sample_mode: metas already generated, skip Phase 1
1100
+ # - format with duration: metas already generated, skip Phase 1
1101
+ # - format without duration: need Phase 1 to generate duration
1102
+ # - no format: need Phase 1 to generate all metas
1103
+ use_cot_metas=not sample_mode and not format_has_duration,
1104
+ use_cot_caption=req.use_cot_caption,
1105
+ use_cot_language=req.use_cot_language,
1106
+ use_constrained_decoding=req.constrained_decoding,
1107
+ )
1108
+
1109
+ # Build GenerationConfig - default to 2 audios like gradio_ui
1110
+ batch_size = req.batch_size if req.batch_size is not None else 2
1111
+ config = GenerationConfig(
1112
+ batch_size=batch_size,
1113
+ use_random_seed=req.use_random_seed,
1114
+ seeds=None, # Let unified logic handle seed generation
1115
+ audio_format=req.audio_format,
1116
+ constrained_decoding_debug=req.constrained_decoding_debug,
1117
+ )
1118
+
1119
+ # Check LLM initialization status
1120
+ llm_is_initialized = getattr(app.state, "_llm_initialized", False)
1121
+ llm_to_pass = llm if llm_is_initialized else None
1122
+
1123
+ # Generate music using unified interface
1124
+ result = generate_music(
1125
+ dit_handler=h,
1126
+ llm_handler=llm_to_pass,
1127
+ params=params,
1128
+ config=config,
1129
+ save_dir=app.state.temp_audio_dir,
1130
+ progress=None,
1131
+ )
1132
+
1133
+ if not result.success:
1134
+ raise RuntimeError(f"Music generation failed: {result.error or result.status_message}")
1135
+
1136
+ # Extract results
1137
+ audio_paths = [audio["path"] for audio in result.audios if audio.get("path")]
1138
+ first_audio = audio_paths[0] if len(audio_paths) > 0 else None
1139
+ second_audio = audio_paths[1] if len(audio_paths) > 1 else None
1140
+
1141
+ # Get metadata from LM or CoT results
1142
+ lm_metadata = result.extra_outputs.get("lm_metadata", {})
1143
+ metas_out = _normalize_metas(lm_metadata)
1144
+
1145
+ # Update metas with actual values used
1146
+ if params.cot_bpm:
1147
+ metas_out["bpm"] = params.cot_bpm
1148
+ elif bpm:
1149
+ metas_out["bpm"] = bpm
1150
+
1151
+ if params.cot_duration:
1152
+ metas_out["duration"] = params.cot_duration
1153
+ elif audio_duration:
1154
+ metas_out["duration"] = audio_duration
1155
+
1156
+ if params.cot_keyscale:
1157
+ metas_out["keyscale"] = params.cot_keyscale
1158
+ elif key_scale:
1159
+ metas_out["keyscale"] = key_scale
1160
+
1161
+ if params.cot_timesignature:
1162
+ metas_out["timesignature"] = params.cot_timesignature
1163
+ elif time_signature:
1164
+ metas_out["timesignature"] = time_signature
1165
+
1166
+ # Store original user input in metas (not the final/modified values)
1167
+ metas_out["prompt"] = original_prompt
1168
+ metas_out["lyrics"] = original_lyrics
1169
+
1170
+ # Extract seed values for response (comma-separated for multiple audios)
1171
+ seed_values = []
1172
+ for audio in result.audios:
1173
+ audio_params = audio.get("params", {})
1174
+ seed = audio_params.get("seed")
1175
+ if seed is not None:
1176
+ seed_values.append(str(seed))
1177
+ seed_value = ",".join(seed_values) if seed_values else ""
1178
+
1179
+ # Build generation_info using the helper function (like gradio_ui)
1180
+ time_costs = result.extra_outputs.get("time_costs", {})
1181
+ generation_info = _build_generation_info(
1182
+ lm_metadata=lm_metadata,
1183
+ time_costs=time_costs,
1184
+ seed_value=seed_value,
1185
+ inference_steps=req.inference_steps,
1186
+ num_audios=len(result.audios),
1187
+ )
1188
+
1189
+ def _none_if_na_str(v: Any) -> Optional[str]:
1190
+ if v is None:
1191
+ return None
1192
+ s = str(v).strip()
1193
+ if s in {"", "N/A"}:
1194
+ return None
1195
+ return s
1196
+
1197
+ # Get model information
1198
+ lm_model_name = os.getenv("ACESTEP_LM_MODEL_PATH", "acestep-5Hz-lm-0.6B")
1199
+ # Use selected_model_name (set at the beginning of _run_one_job)
1200
+ dit_model_name = selected_model_name
1201
+
1202
+ return {
1203
+ "first_audio_path": _path_to_audio_url(first_audio) if first_audio else None,
1204
+ "second_audio_path": _path_to_audio_url(second_audio) if second_audio else None,
1205
+ "audio_paths": [_path_to_audio_url(p) for p in audio_paths],
1206
+ "generation_info": generation_info,
1207
+ "status_message": result.status_message,
1208
+ "seed_value": seed_value,
1209
+ # Final prompt/lyrics (may be modified by thinking/format)
1210
+ "prompt": caption or "",
1211
+ "lyrics": lyrics or "",
1212
+ # metas contains original user input + other metadata
1213
+ "metas": metas_out,
1214
+ "bpm": metas_out.get("bpm") if isinstance(metas_out.get("bpm"), int) else None,
1215
+ "duration": metas_out.get("duration") if isinstance(metas_out.get("duration"), (int, float)) else None,
1216
+ "genres": _none_if_na_str(metas_out.get("genres")),
1217
+ "keyscale": _none_if_na_str(metas_out.get("keyscale")),
1218
+ "timesignature": _none_if_na_str(metas_out.get("timesignature")),
1219
+ "lm_model": lm_model_name,
1220
+ "dit_model": dit_model_name,
1221
+ }
1222
+
1223
+ t0 = time.time()
1224
+ try:
1225
+ loop = asyncio.get_running_loop()
1226
+ result = await loop.run_in_executor(executor, _blocking_generate)
1227
+ job_store.mark_succeeded(job_id, result)
1228
+
1229
+ # Update local cache
1230
+ _update_local_cache(job_id, result, "succeeded")
1231
+ except Exception:
1232
+ job_store.mark_failed(job_id, traceback.format_exc())
1233
+
1234
+ # Update local cache
1235
+ _update_local_cache(job_id, None, "failed")
1236
+ finally:
1237
+ dt = max(0.0, time.time() - t0)
1238
+ async with app.state.stats_lock:
1239
+ app.state.recent_durations.append(dt)
1240
+ if app.state.recent_durations:
1241
+ app.state.avg_job_seconds = sum(app.state.recent_durations) / len(app.state.recent_durations)
1242
+
1243
+ async def _queue_worker(worker_idx: int) -> None:
1244
+ while True:
1245
+ job_id, req = await app.state.job_queue.get()
1246
+ try:
1247
+ async with app.state.pending_lock:
1248
+ try:
1249
+ app.state.pending_ids.remove(job_id)
1250
+ except ValueError:
1251
+ pass
1252
+
1253
+ await _run_one_job(job_id, req)
1254
+ finally:
1255
+ await _cleanup_job_temp_files(job_id)
1256
+ app.state.job_queue.task_done()
1257
+
1258
+ worker_count = max(1, WORKER_COUNT)
1259
+ workers = [asyncio.create_task(_queue_worker(i)) for i in range(worker_count)]
1260
+ app.state.worker_tasks = workers
1261
+
1262
+ try:
1263
+ yield
1264
+ finally:
1265
+ for t in workers:
1266
+ t.cancel()
1267
+ executor.shutdown(wait=False, cancel_futures=True)
1268
+
1269
+ app = FastAPI(title="ACE-Step API", version="1.0", lifespan=lifespan)
1270
+
1271
+ async def _queue_position(job_id: str) -> int:
1272
+ async with app.state.pending_lock:
1273
+ try:
1274
+ return list(app.state.pending_ids).index(job_id) + 1
1275
+ except ValueError:
1276
+ return 0
1277
+
1278
+ async def _eta_seconds_for_position(pos: int) -> Optional[float]:
1279
+ if pos <= 0:
1280
+ return None
1281
+ async with app.state.stats_lock:
1282
+ avg = float(getattr(app.state, "avg_job_seconds", INITIAL_AVG_JOB_SECONDS))
1283
+ return pos * avg
1284
+
1285
+ @app.post("/release_task", response_model=CreateJobResponse)
1286
+ async def create_music_generate_job(request: Request) -> CreateJobResponse:
1287
+ content_type = (request.headers.get("content-type") or "").lower()
1288
+ temp_files: list[str] = []
1289
+
1290
+ def _build_request(p: RequestParser, **kwargs) -> GenerateMusicRequest:
1291
+ """Build GenerateMusicRequest from parsed parameters."""
1292
+ return GenerateMusicRequest(
1293
+ prompt=p.str("prompt"),
1294
+ lyrics=p.str("lyrics"),
1295
+ thinking=p.bool("thinking"),
1296
+ sample_mode=p.bool("sample_mode"),
1297
+ sample_query=p.str("sample_query"),
1298
+ use_format=p.bool("use_format"),
1299
+ model=p.str("model") or None,
1300
+ bpm=p.int("bpm"),
1301
+ key_scale=p.str("key_scale"),
1302
+ time_signature=p.str("time_signature"),
1303
+ audio_duration=p.float("audio_duration"),
1304
+ vocal_language=p.str("vocal_language", "en"),
1305
+ inference_steps=p.int("inference_steps", 8),
1306
+ guidance_scale=p.float("guidance_scale", 7.0),
1307
+ use_random_seed=p.bool("use_random_seed", True),
1308
+ seed=p.int("seed", -1),
1309
+ batch_size=p.int("batch_size"),
1310
+ audio_code_string=p.str("audio_code_string"),
1311
+ repainting_start=p.float("repainting_start", 0.0),
1312
+ repainting_end=p.float("repainting_end"),
1313
+ instruction=p.str("instruction", DEFAULT_DIT_INSTRUCTION),
1314
+ audio_cover_strength=p.float("audio_cover_strength", 1.0),
1315
+ task_type=p.str("task_type", "text2music"),
1316
+ use_adg=p.bool("use_adg"),
1317
+ cfg_interval_start=p.float("cfg_interval_start", 0.0),
1318
+ cfg_interval_end=p.float("cfg_interval_end", 1.0),
1319
+ infer_method=p.str("infer_method", "ode"),
1320
+ shift=p.float("shift", 3.0),
1321
+ audio_format=p.str("audio_format", "mp3"),
1322
+ use_tiled_decode=p.bool("use_tiled_decode", True),
1323
+ lm_model_path=p.str("lm_model_path") or None,
1324
+ lm_backend=p.str("lm_backend", "vllm"),
1325
+ lm_temperature=p.float("lm_temperature", LM_DEFAULT_TEMPERATURE),
1326
+ lm_cfg_scale=p.float("lm_cfg_scale", LM_DEFAULT_CFG_SCALE),
1327
+ lm_top_k=p.int("lm_top_k"),
1328
+ lm_top_p=p.float("lm_top_p", LM_DEFAULT_TOP_P),
1329
+ lm_repetition_penalty=p.float("lm_repetition_penalty", 1.0),
1330
+ lm_negative_prompt=p.str("lm_negative_prompt", "NO USER INPUT"),
1331
+ constrained_decoding=p.bool("constrained_decoding", True),
1332
+ constrained_decoding_debug=p.bool("constrained_decoding_debug"),
1333
+ use_cot_caption=p.bool("use_cot_caption", True),
1334
+ use_cot_language=p.bool("use_cot_language", True),
1335
+ is_format_caption=p.bool("is_format_caption"),
1336
+ **kwargs,
1337
+ )
1338
+
1339
+ if content_type.startswith("application/json"):
1340
+ body = await request.json()
1341
+ if not isinstance(body, dict):
1342
+ raise HTTPException(status_code=400, detail="JSON payload must be an object")
1343
+ req = _build_request(RequestParser(body))
1344
+
1345
+ elif content_type.endswith("+json"):
1346
+ body = await request.json()
1347
+ if not isinstance(body, dict):
1348
+ raise HTTPException(status_code=400, detail="JSON payload must be an object")
1349
+ req = _build_request(RequestParser(body))
1350
+
1351
+ elif content_type.startswith("multipart/form-data"):
1352
+ form = await request.form()
1353
+
1354
+ ref_up = form.get("reference_audio")
1355
+ src_up = form.get("src_audio")
1356
+
1357
+ reference_audio_path = None
1358
+ src_audio_path = None
1359
+
1360
+ if isinstance(ref_up, StarletteUploadFile):
1361
+ reference_audio_path = await _save_upload_to_temp(ref_up, prefix="reference_audio")
1362
+ temp_files.append(reference_audio_path)
1363
+ else:
1364
+ reference_audio_path = str(form.get("reference_audio_path") or "").strip() or None
1365
+
1366
+ if isinstance(src_up, StarletteUploadFile):
1367
+ src_audio_path = await _save_upload_to_temp(src_up, prefix="src_audio")
1368
+ temp_files.append(src_audio_path)
1369
+ else:
1370
+ src_audio_path = str(form.get("src_audio_path") or "").strip() or None
1371
+
1372
+ req = _build_request(
1373
+ RequestParser(dict(form)),
1374
+ reference_audio_path=reference_audio_path,
1375
+ src_audio_path=src_audio_path,
1376
+ )
1377
+
1378
+ elif content_type.startswith("application/x-www-form-urlencoded"):
1379
+ form = await request.form()
1380
+ reference_audio_path = str(form.get("reference_audio_path") or "").strip() or None
1381
+ src_audio_path = str(form.get("src_audio_path") or "").strip() or None
1382
+ req = _build_request(
1383
+ RequestParser(dict(form)),
1384
+ reference_audio_path=reference_audio_path,
1385
+ src_audio_path=src_audio_path,
1386
+ )
1387
+
1388
+ else:
1389
+ raw = await request.body()
1390
+ raw_stripped = raw.lstrip()
1391
+ # Best-effort: accept missing/incorrect Content-Type if payload is valid JSON.
1392
+ if raw_stripped.startswith(b"{") or raw_stripped.startswith(b"["):
1393
+ try:
1394
+ body = json.loads(raw.decode("utf-8"))
1395
+ if isinstance(body, dict):
1396
+ req = _build_request(RequestParser(body))
1397
+ else:
1398
+ raise HTTPException(status_code=400, detail="JSON payload must be an object")
1399
+ except HTTPException:
1400
+ raise
1401
+ except Exception:
1402
+ raise HTTPException(
1403
+ status_code=400,
1404
+ detail="Invalid JSON body (hint: set 'Content-Type: application/json')",
1405
+ )
1406
+ # Best-effort: parse key=value bodies even if Content-Type is missing.
1407
+ elif raw_stripped and b"=" in raw:
1408
+ parsed = urllib.parse.parse_qs(raw.decode("utf-8"), keep_blank_values=True)
1409
+ flat = {k: (v[0] if isinstance(v, list) and v else v) for k, v in parsed.items()}
1410
+ reference_audio_path = str(flat.get("reference_audio_path") or "").strip() or None
1411
+ src_audio_path = str(flat.get("src_audio_path") or "").strip() or None
1412
+ req = _build_request(
1413
+ RequestParser(flat),
1414
+ reference_audio_path=reference_audio_path,
1415
+ src_audio_path=src_audio_path,
1416
+ )
1417
+ else:
1418
+ raise HTTPException(
1419
+ status_code=415,
1420
+ detail=(
1421
+ f"Unsupported Content-Type: {content_type or '(missing)'}; "
1422
+ "use application/json, application/x-www-form-urlencoded, or multipart/form-data"
1423
+ ),
1424
+ )
1425
+
1426
+ rec = store.create()
1427
+
1428
+ q: asyncio.Queue = app.state.job_queue
1429
+ if q.full():
1430
+ for p in temp_files:
1431
+ try:
1432
+ os.remove(p)
1433
+ except Exception:
1434
+ pass
1435
+ raise HTTPException(status_code=429, detail="Server busy: queue is full")
1436
+
1437
+ if temp_files:
1438
+ async with app.state.job_temp_files_lock:
1439
+ app.state.job_temp_files[rec.job_id] = temp_files
1440
+
1441
+ async with app.state.pending_lock:
1442
+ app.state.pending_ids.append(rec.job_id)
1443
+ position = len(app.state.pending_ids)
1444
+
1445
+ await q.put((rec.job_id, req))
1446
+ return CreateJobResponse(task_id=rec.job_id, status="queued", queue_position=position)
1447
+
1448
+ @app.post("/v1/music/random", response_model=CreateJobResponse)
1449
+ async def create_random_sample_job(request: Request) -> CreateJobResponse:
1450
+ """Create a sample-mode job that auto-generates caption/lyrics via LM."""
1451
+
1452
+ thinking_value: Any = None
1453
+ content_type = (request.headers.get("content-type") or "").lower()
1454
+ body_dict: Dict[str, Any] = {}
1455
+
1456
+ if "json" in content_type:
1457
+ try:
1458
+ payload = await request.json()
1459
+ if isinstance(payload, dict):
1460
+ body_dict = payload
1461
+ except Exception:
1462
+ body_dict = {}
1463
+
1464
+ if not body_dict and request.query_params:
1465
+ body_dict = dict(request.query_params)
1466
+
1467
+ thinking_value = body_dict.get("thinking")
1468
+ if thinking_value is None:
1469
+ thinking_value = body_dict.get("Thinking")
1470
+
1471
+ thinking_flag = _to_bool(thinking_value, True)
1472
+
1473
+ req = GenerateMusicRequest(
1474
+ caption="",
1475
+ lyrics="",
1476
+ thinking=thinking_flag,
1477
+ sample_mode=True,
1478
+ )
1479
+
1480
+ rec = store.create()
1481
+ q: asyncio.Queue = app.state.job_queue
1482
+ if q.full():
1483
+ raise HTTPException(status_code=429, detail="Server busy: queue is full")
1484
+
1485
+ async with app.state.pending_lock:
1486
+ app.state.pending_ids.append(rec.job_id)
1487
+ position = len(app.state.pending_ids)
1488
+
1489
+ await q.put((rec.job_id, req))
1490
+ return CreateJobResponse(task_id=rec.job_id, status="queued", queue_position=position)
1491
+
1492
+ @app.post("/query_result")
1493
+ async def query_result(request: Request) -> List[Dict[str, Any]]:
1494
+ """Batch query job results"""
1495
+ content_type = (request.headers.get("content-type") or "").lower()
1496
+
1497
+ if "json" in content_type:
1498
+ body = await request.json()
1499
+ else:
1500
+ form = await request.form()
1501
+ body = {k: v for k, v in form.items()}
1502
+
1503
+ task_id_list_str = body.get("task_id_list", "[]")
1504
+
1505
+ # Parse task ID list
1506
+ if isinstance(task_id_list_str, list):
1507
+ task_id_list = task_id_list_str
1508
+ else:
1509
+ try:
1510
+ task_id_list = json.loads(task_id_list_str)
1511
+ except Exception:
1512
+ task_id_list = []
1513
+
1514
+ local_cache = getattr(app.state, 'local_cache', None)
1515
+ data_list = []
1516
+ current_time = time.time()
1517
+
1518
+ for task_id in task_id_list:
1519
+ result_key = f"{RESULT_KEY_PREFIX}{task_id}"
1520
+
1521
+ # Read from local cache first
1522
+ if local_cache:
1523
+ data = local_cache.get(result_key)
1524
+ if data:
1525
+ try:
1526
+ data_json = json.loads(data)
1527
+ except Exception:
1528
+ data_json = []
1529
+
1530
+ if len(data_json) <= 0:
1531
+ data_list.append({"task_id": task_id, "result": data, "status": 2})
1532
+ else:
1533
+ status = data_json[0].get("status")
1534
+ create_time = data_json[0].get("create_time", 0)
1535
+ if status == 0 and (current_time - create_time) > TASK_TIMEOUT_SECONDS:
1536
+ data_list.append({"task_id": task_id, "result": data, "status": 2})
1537
+ else:
1538
+ data_list.append({
1539
+ "task_id": task_id,
1540
+ "result": data,
1541
+ "status": int(status) if status is not None else 1,
1542
+ })
1543
+ continue
1544
+
1545
+ # Fallback to job_store query
1546
+ rec = store.get(task_id)
1547
+ if rec:
1548
+ env = getattr(rec, 'env', 'development')
1549
+ create_time = rec.created_at
1550
+ status_int = _map_status(rec.status)
1551
+
1552
+ if rec.result and rec.status == "succeeded":
1553
+ audio_paths = rec.result.get("audio_paths", [])
1554
+ metas = rec.result.get("metas", {}) or {}
1555
+ result_data = [
1556
+ {
1557
+ "file": p, "wave": "", "status": status_int,
1558
+ "create_time": int(create_time), "env": env,
1559
+ "prompt": metas.get("caption", ""),
1560
+ "lyrics": metas.get("lyrics", ""),
1561
+ "metas": {
1562
+ "bpm": metas.get("bpm"),
1563
+ "duration": metas.get("duration"),
1564
+ "genres": metas.get("genres", ""),
1565
+ "keyscale": metas.get("keyscale", ""),
1566
+ "timesignature": metas.get("timesignature", ""),
1567
+ }
1568
+ }
1569
+ for p in audio_paths
1570
+ ] if audio_paths else [{
1571
+ "file": "", "wave": "", "status": status_int,
1572
+ "create_time": int(create_time), "env": env,
1573
+ "prompt": metas.get("caption", ""),
1574
+ "lyrics": metas.get("lyrics", ""),
1575
+ "metas": {
1576
+ "bpm": metas.get("bpm"),
1577
+ "duration": metas.get("duration"),
1578
+ "genres": metas.get("genres", ""),
1579
+ "keyscale": metas.get("keyscale", ""),
1580
+ "timesignature": metas.get("timesignature", ""),
1581
+ }
1582
+ }]
1583
+ else:
1584
+ result_data = [{
1585
+ "file": "", "wave": "", "status": status_int,
1586
+ "create_time": int(create_time), "env": env,
1587
+ "prompt": "", "lyrics": "",
1588
+ "metas": {}
1589
+ }]
1590
+
1591
+ data_list.append({
1592
+ "task_id": task_id,
1593
+ "result": json.dumps(result_data, ensure_ascii=False),
1594
+ "status": status_int,
1595
+ })
1596
+ else:
1597
+ data_list.append({"task_id": task_id, "result": "[]", "status": 0})
1598
+
1599
+ return data_list
1600
+
1601
+ @app.get("/health")
1602
+ async def health_check():
1603
+ """Health check endpoint for service status."""
1604
+ return {
1605
+ "status": "ok",
1606
+ "service": "ACE-Step API",
1607
+ "version": "1.0",
1608
+ }
1609
+
1610
+ @app.get("/v1/models")
1611
+ async def list_models():
1612
+ """List available DiT models."""
1613
+ models = []
1614
+
1615
+ # Primary model (always available if initialized)
1616
+ if getattr(app.state, "_initialized", False):
1617
+ primary_model = _get_model_name(app.state._config_path)
1618
+ if primary_model:
1619
+ models.append({
1620
+ "name": primary_model,
1621
+ "is_default": True,
1622
+ })
1623
+
1624
+ # Secondary model
1625
+ if getattr(app.state, "_initialized2", False) and app.state._config_path2:
1626
+ secondary_model = _get_model_name(app.state._config_path2)
1627
+ if secondary_model:
1628
+ models.append({
1629
+ "name": secondary_model,
1630
+ "is_default": False,
1631
+ })
1632
+
1633
+ # Third model
1634
+ if getattr(app.state, "_initialized3", False) and app.state._config_path3:
1635
+ third_model = _get_model_name(app.state._config_path3)
1636
+ if third_model:
1637
+ models.append({
1638
+ "name": third_model,
1639
+ "is_default": False,
1640
+ })
1641
+
1642
+ return {
1643
+ "models": models,
1644
+ "default_model": models[0]["name"] if models else None,
1645
+ }
1646
+
1647
+ @app.get("/v1/audio")
1648
+ async def get_audio(path: str):
1649
+ """Serve audio file by path."""
1650
+ from fastapi.responses import FileResponse
1651
+
1652
+ if not os.path.exists(path):
1653
+ raise HTTPException(status_code=404, detail=f"Audio file not found: {path}")
1654
+
1655
+ ext = os.path.splitext(path)[1].lower()
1656
+ media_types = {
1657
+ ".mp3": "audio/mpeg",
1658
+ ".wav": "audio/wav",
1659
+ ".flac": "audio/flac",
1660
+ ".ogg": "audio/ogg",
1661
+ }
1662
+ media_type = media_types.get(ext, "audio/mpeg")
1663
+
1664
+ return FileResponse(path, media_type=media_type)
1665
+
1666
+ return app
1667
+
1668
+
1669
+ app = create_app()
1670
+
1671
+
1672
+ def main() -> None:
1673
+ import argparse
1674
+ import uvicorn
1675
+
1676
+ parser = argparse.ArgumentParser(description="ACE-Step API server")
1677
+ parser.add_argument(
1678
+ "--host",
1679
+ default=os.getenv("ACESTEP_API_HOST", "127.0.0.1"),
1680
+ help="Bind host (default from ACESTEP_API_HOST or 127.0.0.1)",
1681
+ )
1682
+ parser.add_argument(
1683
+ "--port",
1684
+ type=int,
1685
+ default=int(os.getenv("ACESTEP_API_PORT", "8001")),
1686
+ help="Bind port (default from ACESTEP_API_PORT or 8001)",
1687
+ )
1688
+ args = parser.parse_args()
1689
+
1690
+ # IMPORTANT: in-memory queue/store -> workers MUST be 1
1691
+ uvicorn.run(
1692
+ "acestep.api_server:app",
1693
+ host=str(args.host),
1694
+ port=int(args.port),
1695
+ reload=False,
1696
+ workers=1,
1697
+ )
1698
+
1699
+ if __name__ == "__main__":
1700
+ main()
acestep/audio_utils.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Audio saving and transcoding utility module
3
+
4
+ Independent audio file operations outside of handler, supporting:
5
+ - Save audio tensor/numpy to files (default FLAC format, fast)
6
+ - Format conversion (FLAC/WAV/MP3)
7
+ - Batch processing
8
+ """
9
+
10
+ import os
11
+
12
+ # Disable torchcodec backend to avoid CUDA dependency issues on HuggingFace Space
13
+ # This forces torchaudio to use ffmpeg/sox/soundfile backends instead
14
+ os.environ["TORCHAUDIO_USE_TORCHCODEC"] = "0"
15
+
16
+ import hashlib
17
+ import json
18
+ from pathlib import Path
19
+ from typing import Union, Optional, List, Tuple
20
+ import torch
21
+ import numpy as np
22
+ import torchaudio
23
+ from loguru import logger
24
+
25
+
26
+ class AudioSaver:
27
+ """Audio saving and transcoding utility class"""
28
+
29
+ def __init__(self, default_format: str = "flac"):
30
+ """
31
+ Initialize audio saver
32
+
33
+ Args:
34
+ default_format: Default save format ('flac', 'wav', 'mp3')
35
+ """
36
+ self.default_format = default_format.lower()
37
+ if self.default_format not in ["flac", "wav", "mp3"]:
38
+ logger.warning(f"Unsupported format {default_format}, using 'flac'")
39
+ self.default_format = "flac"
40
+
41
+ def save_audio(
42
+ self,
43
+ audio_data: Union[torch.Tensor, np.ndarray],
44
+ output_path: Union[str, Path],
45
+ sample_rate: int = 48000,
46
+ format: Optional[str] = None,
47
+ channels_first: bool = True,
48
+ ) -> str:
49
+ """
50
+ Save audio data to file
51
+
52
+ Args:
53
+ audio_data: Audio data, torch.Tensor [channels, samples] or numpy.ndarray
54
+ output_path: Output file path (extension can be omitted)
55
+ sample_rate: Sample rate
56
+ format: Audio format ('flac', 'wav', 'mp3'), defaults to default_format
57
+ channels_first: If True, tensor format is [channels, samples], else [samples, channels]
58
+
59
+ Returns:
60
+ Actual saved file path
61
+ """
62
+ format = (format or self.default_format).lower()
63
+ if format not in ["flac", "wav", "mp3"]:
64
+ logger.warning(f"Unsupported format {format}, using {self.default_format}")
65
+ format = self.default_format
66
+
67
+ # Ensure output path has correct extension
68
+ output_path = Path(output_path)
69
+ if output_path.suffix.lower() not in ['.flac', '.wav', '.mp3']:
70
+ output_path = output_path.with_suffix(f'.{format}')
71
+
72
+ # Convert to torch tensor
73
+ if isinstance(audio_data, np.ndarray):
74
+ if channels_first:
75
+ # numpy [samples, channels] -> tensor [channels, samples]
76
+ audio_tensor = torch.from_numpy(audio_data.T).float()
77
+ else:
78
+ # numpy [samples, channels] -> tensor [samples, channels] -> [channels, samples]
79
+ audio_tensor = torch.from_numpy(audio_data).float()
80
+ if audio_tensor.dim() == 2 and audio_tensor.shape[0] < audio_tensor.shape[1]:
81
+ audio_tensor = audio_tensor.T
82
+ else:
83
+ # torch tensor
84
+ audio_tensor = audio_data.cpu().float()
85
+ if not channels_first and audio_tensor.dim() == 2:
86
+ # [samples, channels] -> [channels, samples]
87
+ if audio_tensor.shape[0] > audio_tensor.shape[1]:
88
+ audio_tensor = audio_tensor.T
89
+
90
+ # Ensure memory is contiguous
91
+ audio_tensor = audio_tensor.contiguous()
92
+
93
+ # Select backend and save
94
+ try:
95
+ if format == "mp3":
96
+ # MP3 uses ffmpeg backend
97
+ torchaudio.save(
98
+ str(output_path),
99
+ audio_tensor,
100
+ sample_rate,
101
+ channels_first=True,
102
+ backend='ffmpeg',
103
+ )
104
+ elif format in ["flac", "wav"]:
105
+ # FLAC and WAV use soundfile backend (fastest)
106
+ torchaudio.save(
107
+ str(output_path),
108
+ audio_tensor,
109
+ sample_rate,
110
+ channels_first=True,
111
+ backend='soundfile',
112
+ )
113
+ else:
114
+ # Other formats use default backend
115
+ torchaudio.save(
116
+ str(output_path),
117
+ audio_tensor,
118
+ sample_rate,
119
+ channels_first=True,
120
+ )
121
+
122
+ logger.debug(f"[AudioSaver] Saved audio to {output_path} ({format}, {sample_rate}Hz)")
123
+ return str(output_path)
124
+
125
+ except Exception as e:
126
+ try:
127
+ import soundfile as sf
128
+ audio_np = audio_tensor.transpose(0, 1).numpy() # -> [samples, channels]
129
+ sf.write(str(output_path), audio_np, sample_rate, format=format.upper())
130
+ logger.debug(f"[AudioSaver] Fallback soundfile Saved audio to {output_path} ({format}, {sample_rate}Hz)")
131
+ return str(output_path)
132
+ except Exception as e:
133
+ logger.error(f"[AudioSaver] Failed to save audio: {e}")
134
+ raise
135
+
136
+ def _load_audio_file(self, audio_file: Union[str, Path]) -> Tuple[torch.Tensor, int]:
137
+ """
138
+ Load audio file with ffmpeg backend, fallback to soundfile if failed.
139
+
140
+ This handles CUDA dependency issues with torchcodec on HuggingFace Space.
141
+
142
+ Args:
143
+ audio_file: Path to the audio file
144
+
145
+ Returns:
146
+ Tuple of (audio_tensor, sample_rate)
147
+
148
+ Raises:
149
+ FileNotFoundError: If the audio file doesn't exist
150
+ Exception: If all methods fail to load the audio
151
+ """
152
+ audio_file = str(audio_file)
153
+
154
+ # Check if file exists first
155
+ if not Path(audio_file).exists():
156
+ raise FileNotFoundError(f"Audio file not found: {audio_file}")
157
+
158
+ # Try torchaudio with explicit ffmpeg backend first
159
+ try:
160
+ audio, sr = torchaudio.load(audio_file, backend="ffmpeg")
161
+ return audio, sr
162
+ except Exception as e:
163
+ logger.debug(f"[AudioSaver._load_audio_file] ffmpeg backend failed: {e}, trying soundfile fallback")
164
+
165
+ # Fallback: use soundfile directly (most compatible)
166
+ try:
167
+ import soundfile as sf
168
+ audio_np, sr = sf.read(audio_file)
169
+ # soundfile returns [samples, channels] or [samples], convert to [channels, samples]
170
+ audio = torch.from_numpy(audio_np).float()
171
+ if audio.dim() == 1:
172
+ # Mono: [samples] -> [1, samples]
173
+ audio = audio.unsqueeze(0)
174
+ else:
175
+ # Stereo: [samples, channels] -> [channels, samples]
176
+ audio = audio.T
177
+ return audio, sr
178
+ except Exception as e:
179
+ logger.error(f"[AudioSaver._load_audio_file] All methods failed to load audio: {audio_file}, error: {e}")
180
+ raise
181
+
182
+ def convert_audio(
183
+ self,
184
+ input_path: Union[str, Path],
185
+ output_path: Union[str, Path],
186
+ output_format: str,
187
+ remove_input: bool = False,
188
+ ) -> str:
189
+ """
190
+ Convert audio format
191
+
192
+ Args:
193
+ input_path: Input audio file path
194
+ output_path: Output audio file path
195
+ output_format: Target format ('flac', 'wav', 'mp3')
196
+ remove_input: Whether to delete input file
197
+
198
+ Returns:
199
+ Output file path
200
+ """
201
+ input_path = Path(input_path)
202
+ output_path = Path(output_path)
203
+
204
+ if not input_path.exists():
205
+ raise FileNotFoundError(f"Input file not found: {input_path}")
206
+
207
+ # Load audio with fallback backends
208
+ audio_tensor, sample_rate = self._load_audio_file(input_path)
209
+
210
+ # Save as new format
211
+ output_path = self.save_audio(
212
+ audio_tensor,
213
+ output_path,
214
+ sample_rate=sample_rate,
215
+ format=output_format,
216
+ channels_first=True
217
+ )
218
+
219
+ # Delete input file if needed
220
+ if remove_input:
221
+ input_path.unlink()
222
+ logger.debug(f"[AudioSaver] Removed input file: {input_path}")
223
+
224
+ return output_path
225
+
226
+ def save_batch(
227
+ self,
228
+ audio_batch: Union[List[torch.Tensor], torch.Tensor],
229
+ output_dir: Union[str, Path],
230
+ file_prefix: str = "audio",
231
+ sample_rate: int = 48000,
232
+ format: Optional[str] = None,
233
+ channels_first: bool = True,
234
+ ) -> List[str]:
235
+ """
236
+ Save audio batch
237
+
238
+ Args:
239
+ audio_batch: Audio batch, List[tensor] or tensor [batch, channels, samples]
240
+ output_dir: Output directory
241
+ file_prefix: File prefix
242
+ sample_rate: Sample rate
243
+ format: Audio format
244
+ channels_first: Tensor format flag
245
+
246
+ Returns:
247
+ List of saved file paths
248
+ """
249
+ output_dir = Path(output_dir)
250
+ output_dir.mkdir(parents=True, exist_ok=True)
251
+
252
+ # Process batch
253
+ if isinstance(audio_batch, torch.Tensor) and audio_batch.dim() == 3:
254
+ # [batch, channels, samples]
255
+ audio_list = [audio_batch[i] for i in range(audio_batch.shape[0])]
256
+ elif isinstance(audio_batch, list):
257
+ audio_list = audio_batch
258
+ else:
259
+ audio_list = [audio_batch]
260
+
261
+ saved_paths = []
262
+ for i, audio in enumerate(audio_list):
263
+ output_path = output_dir / f"{file_prefix}_{i:04d}"
264
+ saved_path = self.save_audio(
265
+ audio,
266
+ output_path,
267
+ sample_rate=sample_rate,
268
+ format=format,
269
+ channels_first=channels_first
270
+ )
271
+ saved_paths.append(saved_path)
272
+
273
+ return saved_paths
274
+
275
+
276
+ def get_audio_file_hash(audio_file) -> str:
277
+ """
278
+ Get hash identifier for an audio file.
279
+
280
+ Args:
281
+ audio_file: Path to audio file (str) or file-like object
282
+
283
+ Returns:
284
+ Hash string or empty string
285
+ """
286
+ if audio_file is None:
287
+ return ""
288
+
289
+ try:
290
+ if isinstance(audio_file, str):
291
+ if os.path.exists(audio_file):
292
+ with open(audio_file, 'rb') as f:
293
+ return hashlib.md5(f.read()).hexdigest()
294
+ return hashlib.md5(audio_file.encode('utf-8')).hexdigest()
295
+ elif hasattr(audio_file, 'name'):
296
+ return hashlib.md5(str(audio_file.name).encode('utf-8')).hexdigest()
297
+ return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
298
+ except Exception:
299
+ return hashlib.md5(str(audio_file).encode('utf-8')).hexdigest()
300
+
301
+
302
+ def generate_uuid_from_params(params_dict) -> str:
303
+ """
304
+ Generate deterministic UUID from generation parameters.
305
+ Same parameters will always generate the same UUID.
306
+
307
+ Args:
308
+ params_dict: Dictionary of parameters
309
+
310
+ Returns:
311
+ UUID string
312
+ """
313
+
314
+ params_json = json.dumps(params_dict, sort_keys=True, ensure_ascii=False)
315
+ hash_obj = hashlib.sha256(params_json.encode('utf-8'))
316
+ hash_hex = hash_obj.hexdigest()
317
+ uuid_str = f"{hash_hex[0:8]}-{hash_hex[8:12]}-{hash_hex[12:16]}-{hash_hex[16:20]}-{hash_hex[20:32]}"
318
+ return uuid_str
319
+
320
+
321
+ def generate_uuid_from_audio_data(
322
+ audio_data: Union[torch.Tensor, np.ndarray],
323
+ seed: Optional[int] = None
324
+ ) -> str:
325
+ """
326
+ Generate UUID from audio data (for caching/deduplication)
327
+
328
+ Args:
329
+ audio_data: Audio data
330
+ seed: Optional seed value
331
+
332
+ Returns:
333
+ UUID string
334
+ """
335
+ if isinstance(audio_data, torch.Tensor):
336
+ # Convert to numpy and calculate hash
337
+ audio_np = audio_data.cpu().numpy()
338
+ else:
339
+ audio_np = audio_data
340
+
341
+ # Calculate data hash
342
+ data_hash = hashlib.md5(audio_np.tobytes()).hexdigest()
343
+
344
+ if seed is not None:
345
+ combined = f"{data_hash}_{seed}"
346
+ return hashlib.md5(combined.encode()).hexdigest()
347
+
348
+ return data_hash
349
+
350
+
351
+ # Global default instance
352
+ _default_saver = AudioSaver(default_format="flac")
353
+
354
+
355
+ def save_audio(
356
+ audio_data: Union[torch.Tensor, np.ndarray],
357
+ output_path: Union[str, Path],
358
+ sample_rate: int = 48000,
359
+ format: Optional[str] = None,
360
+ channels_first: bool = True,
361
+ ) -> str:
362
+ """
363
+ Convenience function: save audio (using default configuration)
364
+
365
+ Args:
366
+ audio_data: Audio data
367
+ output_path: Output path
368
+ sample_rate: Sample rate
369
+ format: Format (default flac)
370
+ channels_first: Tensor format flag
371
+
372
+ Returns:
373
+ Saved file path
374
+ """
375
+ return _default_saver.save_audio(
376
+ audio_data, output_path, sample_rate, format, channels_first
377
+ )
378
+
acestep/constants.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Constants for ACE-Step
3
+ Centralized constants used across the codebase
4
+ """
5
+
6
+ # ==============================================================================
7
+ # Language Constants
8
+ # ==============================================================================
9
+
10
+ VALID_LANGUAGES = [
11
+ 'ar', 'az', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en',
12
+ 'es', 'fa', 'fi', 'fr', 'he', 'hi', 'hr', 'ht', 'hu', 'id',
13
+ 'is', 'it', 'ja', 'ko', 'la', 'lt', 'ms', 'ne', 'nl', 'no',
14
+ 'pa', 'pl', 'pt', 'ro', 'ru', 'sa', 'sk', 'sr', 'sv', 'sw',
15
+ 'ta', 'te', 'th', 'tl', 'tr', 'uk', 'ur', 'vi', 'yue', 'zh',
16
+ 'unknown'
17
+ ]
18
+
19
+
20
+ # ==============================================================================
21
+ # Keyscale Constants
22
+ # ==============================================================================
23
+
24
+ KEYSCALE_NOTES = ['A', 'B', 'C', 'D', 'E', 'F', 'G']
25
+ KEYSCALE_ACCIDENTALS = ['', '#', 'b', '♯', '♭'] # empty + ASCII sharp/flat + Unicode sharp/flat
26
+ KEYSCALE_MODES = ['major', 'minor']
27
+
28
+ # Generate all valid keyscales: 7 notes × 5 accidentals × 2 modes = 70 combinations
29
+ VALID_KEYSCALES = set()
30
+ for note in KEYSCALE_NOTES:
31
+ for acc in KEYSCALE_ACCIDENTALS:
32
+ for mode in KEYSCALE_MODES:
33
+ VALID_KEYSCALES.add(f"{note}{acc} {mode}")
34
+
35
+
36
+ # ==============================================================================
37
+ # Metadata Range Constants
38
+ # ==============================================================================
39
+
40
+ # BPM (Beats Per Minute) range
41
+ BPM_MIN = 30
42
+ BPM_MAX = 300
43
+
44
+ # Duration range (in seconds)
45
+ DURATION_MIN = 10
46
+ DURATION_MAX = 600
47
+
48
+ # Valid time signatures
49
+ VALID_TIME_SIGNATURES = [2, 3, 4, 6]
50
+
51
+
52
+ # ==============================================================================
53
+ # Task Type Constants
54
+ # ==============================================================================
55
+
56
+ TASK_TYPES = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
57
+
58
+ # Task types available for turbo models (subset)
59
+ TASK_TYPES_TURBO = ["text2music", "repaint", "cover"]
60
+
61
+ # Task types available for base models (full set)
62
+ TASK_TYPES_BASE = ["text2music", "repaint", "cover", "extract", "lego", "complete"]
63
+
64
+
65
+ # ==============================================================================
66
+ # Instruction Constants
67
+ # ==============================================================================
68
+
69
+ # Default instructions
70
+ DEFAULT_DIT_INSTRUCTION = "Fill the audio semantic mask based on the given conditions:"
71
+ DEFAULT_LM_INSTRUCTION = "Generate audio semantic tokens based on the given conditions:"
72
+ DEFAULT_LM_UNDERSTAND_INSTRUCTION = "Understand the given musical conditions and describe the audio semantics accordingly:"
73
+ DEFAULT_LM_INSPIRED_INSTRUCTION = "Expand the user's input into a more detailed and specific musical description:"
74
+ DEFAULT_LM_REWRITE_INSTRUCTION = "Format the user's input into a more detailed and specific musical description:"
75
+
76
+ # Instruction templates for each task type
77
+ # Note: Some instructions use placeholders like {TRACK_NAME} or {TRACK_CLASSES}
78
+ # These should be formatted using .format() or f-strings when used
79
+ TASK_INSTRUCTIONS = {
80
+ "text2music": "Fill the audio semantic mask based on the given conditions:",
81
+ "repaint": "Repaint the mask area based on the given conditions:",
82
+ "cover": "Generate audio semantic tokens based on the given conditions:",
83
+ "extract": "Extract the {TRACK_NAME} track from the audio:",
84
+ "extract_default": "Extract the track from the audio:",
85
+ "lego": "Generate the {TRACK_NAME} track based on the audio context:",
86
+ "lego_default": "Generate the track based on the audio context:",
87
+ "complete": "Complete the input track with {TRACK_CLASSES}:",
88
+ "complete_default": "Complete the input track:",
89
+ }
90
+
91
+
92
+ # ==============================================================================
93
+ # Track/Instrument Constants
94
+ # ==============================================================================
95
+
96
+ TRACK_NAMES = [
97
+ "woodwinds", "brass", "fx", "synth", "strings", "percussion",
98
+ "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"
99
+ ]
100
+
101
+ SFT_GEN_PROMPT = """# Instruction
102
+ {}
103
+
104
+ # Caption
105
+ {}
106
+
107
+ # Metas
108
+ {}<|endoftext|>
109
+ """
acestep/constrained_logits_processor.py ADDED
The diff for this file is too large to render. See raw diff
 
acestep/dataset_handler.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset Handler
3
+ Handles dataset import and exploration functionality
4
+ """
5
+ from typing import Optional, Tuple, Any, Dict
6
+
7
+
8
+ class DatasetHandler:
9
+ """Dataset Handler for Dataset Explorer functionality"""
10
+
11
+ def __init__(self):
12
+ """Initialize dataset handler"""
13
+ self.dataset = None
14
+ self.dataset_imported = False
15
+
16
+ def import_dataset(self, dataset_type: str) -> str:
17
+ """
18
+ Import dataset (temporarily disabled)
19
+
20
+ Args:
21
+ dataset_type: Type of dataset to import (e.g., "train", "test")
22
+
23
+ Returns:
24
+ Status message string
25
+ """
26
+ self.dataset_imported = False
27
+ return f"⚠️ Dataset import is currently disabled. Text2MusicDataset dependency not available."
28
+
29
+ def get_item_data(self, *args, **kwargs) -> Tuple:
30
+ """
31
+ Get dataset item (temporarily disabled)
32
+
33
+ Returns:
34
+ Tuple of placeholder values matching the expected return format
35
+ """
36
+ return "", "", "", "", "", None, None, None, "❌ Dataset not available", "", 0, "", None, None, None, {}, "text2music"
37
+
acestep/dit_alignment_score.py ADDED
@@ -0,0 +1,870 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DiT Alignment Score Module
3
+
4
+ This module provides lyrics-to-audio alignment using cross-attention matrices
5
+ from DiT model for generating LRC timestamps.
6
+
7
+ Refactored from lyrics_alignment_infos.py for integration with ACE-Step.
8
+ """
9
+ import numba
10
+ import torch
11
+ import numpy as np
12
+ import torch.nn.functional as F
13
+ from dataclasses import dataclass, asdict
14
+ from typing import List, Dict, Any, Optional, Tuple, Union
15
+
16
+
17
+ # ================= Data Classes =================
18
+ @dataclass
19
+ class TokenTimestamp:
20
+ """Stores per-token timing information."""
21
+ token_id: int
22
+ text: str
23
+ start: float
24
+ end: float
25
+ probability: float
26
+
27
+
28
+ @dataclass
29
+ class SentenceTimestamp:
30
+ """Stores per-sentence timing information with token list."""
31
+ text: str
32
+ start: float
33
+ end: float
34
+ tokens: List[TokenTimestamp]
35
+ confidence: float
36
+
37
+
38
+ # ================= DTW Algorithm (Numba Optimized) =================
39
+ @numba.jit(nopython=True)
40
+ def dtw_cpu(x: np.ndarray):
41
+ """
42
+ Dynamic Time Warping algorithm optimized with Numba.
43
+
44
+ Args:
45
+ x: Cost matrix of shape [N, M]
46
+
47
+ Returns:
48
+ Tuple of (text_indices, time_indices) arrays
49
+ """
50
+ N, M = x.shape
51
+ # Use float32 for memory efficiency
52
+ cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
53
+ trace = -np.ones((N + 1, M + 1), dtype=np.float32)
54
+ cost[0, 0] = 0
55
+
56
+ for j in range(1, M + 1):
57
+ for i in range(1, N + 1):
58
+ c0 = cost[i - 1, j - 1]
59
+ c1 = cost[i - 1, j]
60
+ c2 = cost[i, j - 1]
61
+
62
+ if c0 < c1 and c0 < c2:
63
+ c, t = c0, 0
64
+ elif c1 < c0 and c1 < c2:
65
+ c, t = c1, 1
66
+ else:
67
+ c, t = c2, 2
68
+
69
+ cost[i, j] = x[i - 1, j - 1] + c
70
+ trace[i, j] = t
71
+
72
+ return _backtrace(trace, N, M)
73
+
74
+
75
+ @numba.jit(nopython=True)
76
+ def _backtrace(trace: np.ndarray, N: int, M: int):
77
+ """
78
+ Optimized backtrace function for DTW.
79
+
80
+ Args:
81
+ trace: Trace matrix of shape (N+1, M+1)
82
+ N, M: Original matrix dimensions
83
+
84
+ Returns:
85
+ Path array of shape (2, path_len) - first row is text indices, second is time indices
86
+ """
87
+ # Boundary handling
88
+ trace[0, :] = 2
89
+ trace[:, 0] = 1
90
+
91
+ # Pre-allocate array, max path length is N+M
92
+ max_path_len = N + M
93
+ path = np.zeros((2, max_path_len), dtype=np.int32)
94
+
95
+ i, j = N, M
96
+ path_idx = max_path_len - 1
97
+
98
+ while i > 0 or j > 0:
99
+ path[0, path_idx] = i - 1 # text index
100
+ path[1, path_idx] = j - 1 # time index
101
+ path_idx -= 1
102
+
103
+ t = trace[i, j]
104
+ if t == 0:
105
+ i -= 1
106
+ j -= 1
107
+ elif t == 1:
108
+ i -= 1
109
+ elif t == 2:
110
+ j -= 1
111
+ else:
112
+ break
113
+
114
+ actual_len = max_path_len - path_idx - 1
115
+ return path[:, path_idx + 1:max_path_len]
116
+
117
+
118
+ # ================= Utility Functions =================
119
+ def median_filter(x: torch.Tensor, filter_width: int) -> torch.Tensor:
120
+ """
121
+ Apply median filter to tensor.
122
+
123
+ Args:
124
+ x: Input tensor
125
+ filter_width: Width of median filter
126
+
127
+ Returns:
128
+ Filtered tensor
129
+ """
130
+ pad_width = filter_width // 2
131
+ if x.shape[-1] <= pad_width:
132
+ return x
133
+ if x.ndim == 2:
134
+ x = x[None, :]
135
+ x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
136
+ result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
137
+ if result.ndim > 2:
138
+ result = result.squeeze(0)
139
+ return result
140
+
141
+
142
+ # ================= Main Aligner Class =================
143
+ class MusicStampsAligner:
144
+ """
145
+ Aligner class for generating lyrics timestamps from cross-attention matrices.
146
+
147
+ Uses bidirectional consensus denoising and DTW for alignment.
148
+ """
149
+
150
+ def __init__(self, tokenizer):
151
+ """
152
+ Initialize the aligner.
153
+
154
+ Args:
155
+ tokenizer: Text tokenizer for decoding tokens
156
+ """
157
+ self.tokenizer = tokenizer
158
+
159
+ def _apply_bidirectional_consensus(
160
+ self,
161
+ weights_stack: torch.Tensor,
162
+ violence_level: float,
163
+ medfilt_width: int
164
+ ) -> tuple:
165
+ """
166
+ Core denoising logic using bidirectional consensus.
167
+
168
+ Args:
169
+ weights_stack: Attention weights [Heads, Tokens, Frames]
170
+ violence_level: Denoising strength coefficient
171
+ medfilt_width: Median filter width
172
+
173
+ Returns:
174
+ Tuple of (calc_matrix, energy_matrix) as numpy arrays
175
+ """
176
+ # A. Bidirectional Consensus
177
+ row_prob = F.softmax(weights_stack, dim=-1) # Token -> Frame
178
+ col_prob = F.softmax(weights_stack, dim=-2) # Frame -> Token
179
+ processed = row_prob * col_prob
180
+
181
+ # 1. Row suppression (kill horizontal crossing lines)
182
+ row_medians = torch.quantile(processed, 0.5, dim=-1, keepdim=True)
183
+ processed = processed - (violence_level * row_medians)
184
+ processed = torch.relu(processed)
185
+
186
+ # 2. Column suppression (kill vertical crossing lines)
187
+ col_medians = torch.quantile(processed, 0.5, dim=-2, keepdim=True)
188
+ processed = processed - (violence_level * col_medians)
189
+ processed = torch.relu(processed)
190
+
191
+ # C. Power sharpening
192
+ processed = processed ** 2
193
+
194
+ # Energy matrix for confidence
195
+ energy_matrix = processed.mean(dim=0).cpu().numpy()
196
+
197
+ # D. Z-Score normalization
198
+ std, mean = torch.std_mean(processed, unbiased=False)
199
+ weights_processed = (processed - mean) / (std + 1e-9)
200
+
201
+ # E. Median filtering
202
+ weights_processed = median_filter(weights_processed, filter_width=medfilt_width)
203
+ calc_matrix = weights_processed.mean(dim=0).numpy()
204
+
205
+ return calc_matrix, energy_matrix
206
+
207
+ def _preprocess_attention(
208
+ self,
209
+ attention_matrix: torch.Tensor,
210
+ custom_config: Dict[int, List[int]],
211
+ violence_level: float,
212
+ medfilt_width: int = 7
213
+ ) -> tuple:
214
+ """
215
+ Preprocess attention matrix for alignment.
216
+
217
+ Args:
218
+ attention_matrix: Attention tensor [Layers, Heads, Tokens, Frames]
219
+ custom_config: Dict mapping layer indices to head indices
220
+ violence_level: Denoising strength
221
+ medfilt_width: Median filter width
222
+
223
+ Returns:
224
+ Tuple of (calc_matrix, energy_matrix, visual_matrix)
225
+ """
226
+ if not isinstance(attention_matrix, torch.Tensor):
227
+ weights = torch.tensor(attention_matrix)
228
+ else:
229
+ weights = attention_matrix.clone()
230
+
231
+ weights = weights.cpu().float()
232
+
233
+ selected_tensors = []
234
+ for layer_idx, head_indices in custom_config.items():
235
+ for head_idx in head_indices:
236
+ if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
237
+ head_matrix = weights[layer_idx, head_idx]
238
+ selected_tensors.append(head_matrix)
239
+
240
+ if not selected_tensors:
241
+ return None, None, None
242
+
243
+ # Stack selected heads: [Heads, Tokens, Frames]
244
+ weights_stack = torch.stack(selected_tensors, dim=0)
245
+ visual_matrix = weights_stack.mean(dim=0).numpy()
246
+
247
+ calc_matrix, energy_matrix = self._apply_bidirectional_consensus(
248
+ weights_stack, violence_level, medfilt_width
249
+ )
250
+
251
+ return calc_matrix, energy_matrix, visual_matrix
252
+
253
+ def stamps_align_info(
254
+ self,
255
+ attention_matrix: torch.Tensor,
256
+ lyrics_tokens: List[int],
257
+ total_duration_seconds: float,
258
+ custom_config: Dict[int, List[int]],
259
+ return_matrices: bool = False,
260
+ violence_level: float = 2.0,
261
+ medfilt_width: int = 1
262
+ ) -> Dict[str, Any]:
263
+ """
264
+ Get alignment information from attention matrix.
265
+
266
+ Args:
267
+ attention_matrix: Cross-attention tensor [Layers, Heads, Tokens, Frames]
268
+ lyrics_tokens: List of lyrics token IDs
269
+ total_duration_seconds: Total audio duration in seconds
270
+ custom_config: Dict mapping layer indices to head indices
271
+ return_matrices: Whether to return intermediate matrices
272
+ violence_level: Denoising strength
273
+ medfilt_width: Median filter width
274
+
275
+ Returns:
276
+ Dict containing calc_matrix, lyrics_tokens, total_duration_seconds,
277
+ and optionally energy_matrix and vis_matrix
278
+ """
279
+ calc_matrix, energy_matrix, visual_matrix = self._preprocess_attention(
280
+ attention_matrix, custom_config, violence_level, medfilt_width
281
+ )
282
+
283
+ if calc_matrix is None:
284
+ return {
285
+ "calc_matrix": None,
286
+ "lyrics_tokens": lyrics_tokens,
287
+ "total_duration_seconds": total_duration_seconds,
288
+ "error": "No valid attention heads found"
289
+ }
290
+
291
+ return_dict = {
292
+ "calc_matrix": calc_matrix,
293
+ "lyrics_tokens": lyrics_tokens,
294
+ "total_duration_seconds": total_duration_seconds
295
+ }
296
+
297
+ if return_matrices:
298
+ return_dict['energy_matrix'] = energy_matrix
299
+ return_dict['vis_matrix'] = visual_matrix
300
+
301
+ return return_dict
302
+
303
+ def _decode_tokens_incrementally(self, token_ids: List[int]) -> List[str]:
304
+ """
305
+ Decode tokens incrementally to properly handle multi-byte UTF-8 characters.
306
+
307
+ For Chinese and other multi-byte characters, the tokenizer may split them
308
+ into multiple byte-level tokens. Decoding each token individually produces
309
+ invalid UTF-8 sequences (showing as �). This method uses byte-level comparison
310
+ to correctly track which characters each token contributes.
311
+
312
+ Args:
313
+ token_ids: List of token IDs
314
+
315
+ Returns:
316
+ List of decoded text for each token position
317
+ """
318
+ decoded_tokens = []
319
+ prev_bytes = b""
320
+
321
+ for i in range(len(token_ids)):
322
+ # Decode tokens from start to current position
323
+ current_text = self.tokenizer.decode(token_ids[:i+1], skip_special_tokens=False)
324
+ current_bytes = current_text.encode('utf-8', errors='surrogatepass')
325
+
326
+ # The contribution of current token is the new bytes added
327
+ if len(current_bytes) >= len(prev_bytes):
328
+ new_bytes = current_bytes[len(prev_bytes):]
329
+ # Try to decode the new bytes; if incomplete, use empty string
330
+ try:
331
+ token_text = new_bytes.decode('utf-8')
332
+ except UnicodeDecodeError:
333
+ # Incomplete UTF-8 sequence, this token doesn't complete a character
334
+ token_text = ""
335
+ else:
336
+ # Edge case: current decode is shorter (shouldn't happen normally)
337
+ token_text = ""
338
+
339
+ decoded_tokens.append(token_text)
340
+ prev_bytes = current_bytes
341
+
342
+ return decoded_tokens
343
+
344
+ def token_timestamps(
345
+ self,
346
+ calc_matrix: np.ndarray,
347
+ lyrics_tokens: List[int],
348
+ total_duration_seconds: float
349
+ ) -> List[TokenTimestamp]:
350
+ """
351
+ Generate per-token timestamps using DTW.
352
+
353
+ Args:
354
+ calc_matrix: Processed attention matrix [Tokens, Frames]
355
+ lyrics_tokens: List of token IDs
356
+ total_duration_seconds: Total audio duration
357
+
358
+ Returns:
359
+ List of TokenTimestamp objects
360
+ """
361
+ n_frames = calc_matrix.shape[-1]
362
+ text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float64))
363
+
364
+ seconds_per_frame = total_duration_seconds / n_frames
365
+ alignment_results = []
366
+
367
+ # Use incremental decoding to properly handle multi-byte UTF-8 characters
368
+ decoded_tokens = self._decode_tokens_incrementally(lyrics_tokens)
369
+
370
+ for i in range(len(lyrics_tokens)):
371
+ mask = (text_indices == i)
372
+
373
+ if not np.any(mask):
374
+ start = alignment_results[-1].end if alignment_results else 0.0
375
+ end = start
376
+ token_conf = 0.0
377
+ else:
378
+ times = time_indices[mask] * seconds_per_frame
379
+ start = times[0]
380
+ end = times[-1]
381
+ token_conf = 0.0
382
+
383
+ if end < start:
384
+ end = start
385
+
386
+ alignment_results.append(TokenTimestamp(
387
+ token_id=lyrics_tokens[i],
388
+ text=decoded_tokens[i],
389
+ start=float(start),
390
+ end=float(end),
391
+ probability=token_conf
392
+ ))
393
+
394
+ return alignment_results
395
+
396
+ def _decode_sentence_from_tokens(self, tokens: List[TokenTimestamp]) -> str:
397
+ """
398
+ Decode a sentence by decoding all token IDs together.
399
+ This avoids UTF-8 encoding issues from joining individual token texts.
400
+
401
+ Args:
402
+ tokens: List of TokenTimestamp objects
403
+
404
+ Returns:
405
+ Properly decoded sentence text
406
+ """
407
+ token_ids = [t.token_id for t in tokens]
408
+ return self.tokenizer.decode(token_ids, skip_special_tokens=False)
409
+
410
+ def sentence_timestamps(
411
+ self,
412
+ token_alignment: List[TokenTimestamp]
413
+ ) -> List[SentenceTimestamp]:
414
+ """
415
+ Group token timestamps into sentence timestamps.
416
+
417
+ Args:
418
+ token_alignment: List of TokenTimestamp objects
419
+
420
+ Returns:
421
+ List of SentenceTimestamp objects
422
+ """
423
+ results = []
424
+ current_tokens = []
425
+
426
+ for token in token_alignment:
427
+ current_tokens.append(token)
428
+
429
+ if '\n' in token.text:
430
+ # Decode all token IDs together to avoid UTF-8 issues
431
+ full_text = self._decode_sentence_from_tokens(current_tokens)
432
+
433
+ if full_text.strip():
434
+ valid_scores = [t.probability for t in current_tokens if t.probability > 0]
435
+ sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
436
+
437
+ results.append(SentenceTimestamp(
438
+ text=full_text.strip(),
439
+ start=round(current_tokens[0].start, 3),
440
+ end=round(current_tokens[-1].end, 3),
441
+ tokens=list(current_tokens),
442
+ confidence=sent_conf
443
+ ))
444
+
445
+ current_tokens = []
446
+
447
+ # Handle last sentence
448
+ if current_tokens:
449
+ # Decode all token IDs together to avoid UTF-8 issues
450
+ full_text = self._decode_sentence_from_tokens(current_tokens)
451
+ if full_text.strip():
452
+ valid_scores = [t.probability for t in current_tokens if t.probability > 0]
453
+ sent_conf = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
454
+
455
+ results.append(SentenceTimestamp(
456
+ text=full_text.strip(),
457
+ start=round(current_tokens[0].start, 3),
458
+ end=round(current_tokens[-1].end, 3),
459
+ tokens=list(current_tokens),
460
+ confidence=sent_conf
461
+ ))
462
+
463
+ # Normalize confidence scores
464
+ if results:
465
+ all_scores = [s.confidence for s in results]
466
+ min_score = min(all_scores)
467
+ max_score = max(all_scores)
468
+ score_range = max_score - min_score
469
+
470
+ if score_range > 1e-9:
471
+ for s in results:
472
+ normalized_score = (s.confidence - min_score) / score_range
473
+ s.confidence = round(normalized_score, 2)
474
+ else:
475
+ for s in results:
476
+ s.confidence = round(s.confidence, 2)
477
+
478
+ return results
479
+
480
+ def format_lrc(
481
+ self,
482
+ sentence_timestamps: List[SentenceTimestamp],
483
+ include_end_time: bool = False
484
+ ) -> str:
485
+ """
486
+ Format sentence timestamps as LRC lyrics format.
487
+
488
+ Args:
489
+ sentence_timestamps: List of SentenceTimestamp objects
490
+ include_end_time: Whether to include end time (enhanced LRC format)
491
+
492
+ Returns:
493
+ LRC formatted string
494
+ """
495
+ lines = []
496
+
497
+ for sentence in sentence_timestamps:
498
+ # Convert seconds to mm:ss.xx format
499
+ start_minutes = int(sentence.start // 60)
500
+ start_seconds = sentence.start % 60
501
+
502
+ if include_end_time:
503
+ end_minutes = int(sentence.end // 60)
504
+ end_seconds = sentence.end % 60
505
+ timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}][{end_minutes:02d}:{end_seconds:05.2f}]"
506
+ else:
507
+ timestamp = f"[{start_minutes:02d}:{start_seconds:05.2f}]"
508
+
509
+ # Clean the text (remove structural tags like [verse], [chorus])
510
+ text = sentence.text
511
+
512
+ lines.append(f"{timestamp}{text}")
513
+
514
+ return "\n".join(lines)
515
+
516
+ def get_timestamps_and_lrc(
517
+ self,
518
+ calc_matrix: np.ndarray,
519
+ lyrics_tokens: List[int],
520
+ total_duration_seconds: float
521
+ ) -> Dict[str, Any]:
522
+ """
523
+ Convenience method to get both timestamps and LRC in one call.
524
+
525
+ Args:
526
+ calc_matrix: Processed attention matrix
527
+ lyrics_tokens: List of token IDs
528
+ total_duration_seconds: Total audio duration
529
+
530
+ Returns:
531
+ Dict containing token_timestamps, sentence_timestamps, and lrc_text
532
+ """
533
+ token_stamps = self.token_timestamps(
534
+ calc_matrix=calc_matrix,
535
+ lyrics_tokens=lyrics_tokens,
536
+ total_duration_seconds=total_duration_seconds
537
+ )
538
+
539
+ sentence_stamps = self.sentence_timestamps(token_stamps)
540
+ lrc_text = self.format_lrc(sentence_stamps)
541
+
542
+ return {
543
+ "token_timestamps": token_stamps,
544
+ "sentence_timestamps": sentence_stamps,
545
+ "lrc_text": lrc_text
546
+ }
547
+
548
+
549
+ class MusicLyricScorer:
550
+ """
551
+ Scorer class for evaluating lyrics-to-audio alignment quality.
552
+
553
+ Focuses on calculating alignment quality metrics (Coverage, Monotonicity, Confidence)
554
+ using tensor operations for potential differentiability or GPU acceleration.
555
+ """
556
+
557
+ def __init__(self, tokenizer: Any):
558
+ """
559
+ Initialize the aligner.
560
+
561
+ Args:
562
+ tokenizer: Tokenizer instance (must implement .decode()).
563
+ """
564
+ self.tokenizer = tokenizer
565
+
566
+ def _generate_token_type_mask(self, token_ids: List[int]) -> np.ndarray:
567
+ """
568
+ Generate a mask distinguishing lyrics (1) from structural tags (0).
569
+ Uses self.tokenizer to decode tokens.
570
+
571
+ Args:
572
+ token_ids: List of token IDs.
573
+
574
+ Returns:
575
+ Numpy array of shape [len(token_ids)] with 1 or 0.
576
+ """
577
+ decoded_tokens = [self.tokenizer.decode([tid]) for tid in token_ids]
578
+ mask = np.ones(len(token_ids), dtype=np.int32)
579
+ in_bracket = False
580
+
581
+ for i, token_str in enumerate(decoded_tokens):
582
+ if '[' in token_str:
583
+ in_bracket = True
584
+ if in_bracket:
585
+ mask[i] = 0
586
+ if ']' in token_str:
587
+ in_bracket = False
588
+ mask[i] = 0
589
+ return mask
590
+
591
+ def _preprocess_attention(
592
+ self,
593
+ attention_matrix: Union[torch.Tensor, np.ndarray],
594
+ custom_config: Dict[int, List[int]],
595
+ medfilt_width: int = 1
596
+ ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[torch.Tensor]]:
597
+ """
598
+ Extracts and normalizes the attention matrix.
599
+
600
+ Logic V4: Uses Min-Max normalization to highlight energy differences.
601
+
602
+ Args:
603
+ attention_matrix: Raw attention tensor [Layers, Heads, Tokens, Frames].
604
+ custom_config: Config mapping layers to heads.
605
+ medfilt_width: Width for median filtering.
606
+
607
+ Returns:
608
+ Tuple of (calc_matrix, energy_matrix, avg_weights_tensor).
609
+ """
610
+ # 1. Prepare Tensor
611
+ if not isinstance(attention_matrix, torch.Tensor):
612
+ weights = torch.tensor(attention_matrix)
613
+ else:
614
+ weights = attention_matrix.clone()
615
+ weights = weights.cpu().float()
616
+
617
+ # 2. Select Heads based on config
618
+ selected_tensors = []
619
+ for layer_idx, head_indices in custom_config.items():
620
+ for head_idx in head_indices:
621
+ if layer_idx < weights.shape[0] and head_idx < weights.shape[1]:
622
+ selected_tensors.append(weights[layer_idx, head_idx])
623
+
624
+ if not selected_tensors:
625
+ return None, None, None
626
+
627
+ weights_stack = torch.stack(selected_tensors, dim=0)
628
+
629
+ # 3. Average Heads
630
+ avg_weights = weights_stack.mean(dim=0) # [Tokens, Frames]
631
+
632
+ # 4. Preprocessing Logic
633
+ # Min-Max normalization preserving energy distribution
634
+ # Median filter is applied to the energy matrix
635
+ energy_tensor = median_filter(avg_weights, filter_width=medfilt_width)
636
+ energy_matrix = energy_tensor.numpy()
637
+
638
+ e_min, e_max = energy_matrix.min(), energy_matrix.max()
639
+
640
+ if e_max - e_min > 1e-9:
641
+ energy_matrix = (energy_matrix - e_min) / (e_max - e_min)
642
+ else:
643
+ energy_matrix = np.zeros_like(energy_matrix)
644
+
645
+ # Contrast enhancement for DTW pathfinding
646
+ # calc_matrix is used for pathfinding, energy_matrix for scoring
647
+ calc_matrix = energy_matrix ** 2
648
+
649
+ return calc_matrix, energy_matrix, avg_weights
650
+
651
+ def _compute_alignment_metrics(
652
+ self,
653
+ energy_matrix: torch.Tensor,
654
+ path_coords: torch.Tensor,
655
+ type_mask: torch.Tensor,
656
+ time_weight: float = 0.01,
657
+ overlap_frames: float = 9.0,
658
+ instrumental_weight: float = 1.0
659
+ ) -> Tuple[float, float, float]:
660
+ """
661
+ Core metric calculation logic using high-precision Tensor operations.
662
+
663
+ Args:
664
+ energy_matrix: Normalized energy [Rows, Cols].
665
+ path_coords: DTW path coordinates [Steps, 2].
666
+ type_mask: Token type mask [Rows] (1=Lyrics, 0=Tags).
667
+ time_weight: Minimum energy threshold for monotonicity.
668
+ overlap_frames: Allowed overlap for monotonicity check.
669
+ instrumental_weight: Weight for non-lyric tokens in confidence calc.
670
+
671
+ Returns:
672
+ Tuple of (coverage, monotonicity, confidence).
673
+ """
674
+ # Ensure high precision for internal calculation
675
+ energy_matrix = energy_matrix.to(dtype=torch.float64)
676
+ path_coords = path_coords.long()
677
+ type_mask = type_mask.long()
678
+
679
+ device = energy_matrix.device
680
+ rows, cols = energy_matrix.shape
681
+
682
+ is_lyrics_row = (type_mask == 1)
683
+
684
+ # ================= A. Coverage Score =================
685
+ # Ratio of lyric lines that have significant energy peak
686
+ row_max_energies = energy_matrix.max(dim=1).values
687
+ total_sung_rows = is_lyrics_row.sum().double()
688
+
689
+ coverage_threshold = 0.1
690
+ valid_sung_mask = is_lyrics_row & (row_max_energies > coverage_threshold)
691
+ valid_sung_rows = valid_sung_mask.sum().double()
692
+
693
+ if total_sung_rows > 0:
694
+ coverage_score = valid_sung_rows / total_sung_rows
695
+ else:
696
+ coverage_score = torch.tensor(1.0, device=device, dtype=torch.float64)
697
+
698
+ # ================= B. Monotonicity Score =================
699
+ # Check if the "center of mass" of lyric lines moves forward in time
700
+ col_indices = torch.arange(cols, device=device, dtype=torch.float64)
701
+
702
+ # Zero out low energy noise
703
+ weights = torch.where(
704
+ energy_matrix > time_weight,
705
+ energy_matrix,
706
+ torch.zeros_like(energy_matrix)
707
+ )
708
+
709
+ sum_w = weights.sum(dim=1)
710
+ sum_t = (weights * col_indices).sum(dim=1)
711
+
712
+ # Calculate centroids
713
+ centroids = torch.full((rows,), -1.0, device=device, dtype=torch.float64)
714
+ valid_w_mask = sum_w > 1e-9
715
+ centroids[valid_w_mask] = sum_t[valid_w_mask] / sum_w[valid_w_mask]
716
+
717
+ # Extract sequence of valid lyrics centroids
718
+ valid_sequence_mask = is_lyrics_row & (centroids >= 0)
719
+ sung_centroids = centroids[valid_sequence_mask]
720
+
721
+ cnt = sung_centroids.shape[0]
722
+ if cnt > 1:
723
+ curr_c = sung_centroids[:-1]
724
+ next_c = sung_centroids[1:]
725
+
726
+ # Check non-decreasing order with overlap tolerance
727
+ non_decreasing = (next_c >= (curr_c - overlap_frames)).double().sum()
728
+ pairs = torch.tensor(cnt - 1, device=device, dtype=torch.float64)
729
+ monotonicity_score = non_decreasing / pairs
730
+ else:
731
+ monotonicity_score = torch.tensor(1.0, device=device, dtype=torch.float64)
732
+
733
+ # ================= C. Path Confidence =================
734
+ # Average energy along the optimal path
735
+ if path_coords.shape[0] > 0:
736
+ p_rows = path_coords[:, 0]
737
+ p_cols = path_coords[:, 1]
738
+
739
+ path_energies = energy_matrix[p_rows, p_cols]
740
+ step_weights = torch.ones_like(path_energies)
741
+
742
+ # Lower weight for instrumental/tag steps
743
+ is_inst_step = (type_mask[p_rows] == 0)
744
+ step_weights[is_inst_step] = instrumental_weight
745
+
746
+ total_energy = (path_energies * step_weights).sum()
747
+ total_steps = step_weights.sum()
748
+
749
+ if total_steps > 0:
750
+ path_confidence = total_energy / total_steps
751
+ else:
752
+ path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
753
+ else:
754
+ path_confidence = torch.tensor(0.0, device=device, dtype=torch.float64)
755
+
756
+ return coverage_score.item(), monotonicity_score.item(), path_confidence.item()
757
+
758
+ def lyrics_alignment_info(
759
+ self,
760
+ attention_matrix: Union[torch.Tensor, np.ndarray],
761
+ token_ids: List[int],
762
+ custom_config: Dict[int, List[int]],
763
+ return_matrices: bool = False,
764
+ medfilt_width: int = 1
765
+ ) -> Dict[str, Any]:
766
+ """
767
+ Generates alignment path and processed matrices.
768
+
769
+ Args:
770
+ attention_matrix: Input attention tensor.
771
+ token_ids: Corresponding token IDs.
772
+ custom_config: Layer/Head configuration.
773
+ return_matrices: If True, returns matrices in the output.
774
+ medfilt_width: Median filter width.
775
+
776
+ Returns:
777
+ Dict or AlignmentInfo object containing path and masks.
778
+ """
779
+ calc_matrix, energy_matrix, vis_matrix = self._preprocess_attention(
780
+ attention_matrix, custom_config, medfilt_width
781
+ )
782
+
783
+ if calc_matrix is None:
784
+ return {
785
+ "calc_matrix": None,
786
+ "error": "No valid attention heads found"
787
+ }
788
+
789
+ # 1. Generate Semantic Mask (1=Lyrics, 0=Tags)
790
+ # Uses self.tokenizer internally
791
+ type_mask = self._generate_token_type_mask(token_ids)
792
+
793
+ # Safety check for shape mismatch
794
+ if len(type_mask) != energy_matrix.shape[0]:
795
+ # Fallback to all lyrics if shapes don't align
796
+ type_mask = np.ones(energy_matrix.shape[0], dtype=np.int32)
797
+
798
+ # 2. DTW Pathfinding
799
+ # Using negative calc_matrix because DTW minimizes cost
800
+ text_indices, time_indices = dtw_cpu(-calc_matrix.astype(np.float32))
801
+ path_coords = np.stack([text_indices, time_indices], axis=1)
802
+
803
+ return_dict = {
804
+ "path_coords": path_coords,
805
+ "type_mask": type_mask,
806
+ "energy_matrix": energy_matrix
807
+ }
808
+ if return_matrices:
809
+ return_dict['calc_matrix'] = calc_matrix
810
+ return_dict['vis_matrix'] = vis_matrix
811
+
812
+ return return_dict
813
+
814
+ def calculate_score(
815
+ self,
816
+ energy_matrix: Union[torch.Tensor, np.ndarray],
817
+ type_mask: Union[torch.Tensor, np.ndarray],
818
+ path_coords: Union[torch.Tensor, np.ndarray],
819
+ time_weight: float = 0.01,
820
+ overlap_frames: float = 9.0,
821
+ instrumental_weight: float = 1.0
822
+ ) -> Dict[str, Any]:
823
+ """
824
+ Calculates the final alignment score based on pre-computed components.
825
+
826
+ Args:
827
+ energy_matrix: Processed energy matrix.
828
+ type_mask: Token type mask.
829
+ path_coords: DTW path coordinates.
830
+ time_weight: Minimum energy threshold for monotonicity.
831
+ overlap_frames: Allowed backward movement frames.
832
+ instrumental_weight: Weight for non-lyric path steps.
833
+
834
+ Returns:
835
+ AlignmentScore object containing individual metrics and final score.
836
+ """
837
+ # Ensure Inputs are Tensors on the correct device
838
+ if not isinstance(energy_matrix, torch.Tensor):
839
+ energy_matrix = torch.tensor(energy_matrix, device='cuda', dtype=torch.float32)
840
+
841
+ device = energy_matrix.device
842
+
843
+ if not isinstance(type_mask, torch.Tensor):
844
+ type_mask = torch.tensor(type_mask, device=device, dtype=torch.long)
845
+ else:
846
+ type_mask = type_mask.to(device=device, dtype=torch.long)
847
+
848
+ if not isinstance(path_coords, torch.Tensor):
849
+ path_coords = torch.tensor(path_coords, device=device, dtype=torch.long)
850
+ else:
851
+ path_coords = path_coords.to(device=device, dtype=torch.long)
852
+
853
+ # Compute Metrics
854
+ coverage, monotonicity, confidence = self._compute_alignment_metrics(
855
+ energy_matrix=energy_matrix,
856
+ path_coords=path_coords,
857
+ type_mask=type_mask,
858
+ time_weight=time_weight,
859
+ overlap_frames=overlap_frames,
860
+ instrumental_weight=instrumental_weight
861
+ )
862
+
863
+ # Final Score Calculation
864
+ # (Cov^2 * Mono^2 * Conf)
865
+ final_score = (coverage ** 2) * (monotonicity ** 2) * confidence
866
+ final_score = float(np.clip(final_score, 0.0, 1.0))
867
+
868
+ return {
869
+ "lyrics_score": round(final_score, 4)
870
+ }
acestep/genres_vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
acestep/gradio_ui/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from acestep.gradio_ui.interfaces import create_gradio_interface
acestep/gradio_ui/events/__init__.py ADDED
@@ -0,0 +1,1355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Event Handlers Module
3
+ Main entry point for setting up all event handlers
4
+ """
5
+ import os
6
+ import gradio as gr
7
+ from typing import Optional
8
+ from loguru import logger
9
+
10
+ # Import handler modules
11
+ from . import generation_handlers as gen_h
12
+ from . import results_handlers as res_h
13
+ from . import training_handlers as train_h
14
+ from acestep.gradio_ui.i18n import t
15
+
16
+ # HuggingFace Space environment detection for ZeroGPU support
17
+ IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
18
+
19
+
20
+ def _get_spaces_gpu_decorator(duration=120):
21
+ """
22
+ Get the @spaces.GPU decorator if running in HuggingFace Space environment.
23
+ Returns identity decorator if not in Space environment.
24
+ """
25
+ if IS_HUGGINGFACE_SPACE:
26
+ try:
27
+ import spaces
28
+ return spaces.GPU(duration=duration)
29
+ except ImportError:
30
+ logger.warning("spaces package not found, GPU decorator disabled")
31
+ return lambda func: func
32
+ return lambda func: func
33
+
34
+
35
+ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section, init_params=None):
36
+ """Setup event handlers connecting UI components and business logic
37
+
38
+ Args:
39
+ init_params: Dictionary containing initialization parameters including:
40
+ - dit_handler_2: Optional second DiT handler for multi-model setup
41
+ - available_dit_models: List of available DiT model names
42
+ - config_path: Primary model config path
43
+ - config_path_2: Secondary model config path (if available)
44
+ """
45
+ # Get secondary DiT handler from init_params (for multi-model support)
46
+ dit_handler_2 = init_params.get('dit_handler_2') if init_params else None
47
+ config_path_1 = init_params.get('config_path', '') if init_params else ''
48
+ config_path_2 = init_params.get('config_path_2', '') if init_params else ''
49
+
50
+ # ========== Dataset Handlers ==========
51
+ dataset_section["import_dataset_btn"].click(
52
+ fn=dataset_handler.import_dataset,
53
+ inputs=[dataset_section["dataset_type"]],
54
+ outputs=[dataset_section["data_status"]]
55
+ )
56
+
57
+ # ========== Service Initialization ==========
58
+ generation_section["refresh_btn"].click(
59
+ fn=lambda: gen_h.refresh_checkpoints(dit_handler),
60
+ outputs=[generation_section["checkpoint_dropdown"]]
61
+ )
62
+
63
+ generation_section["config_path"].change(
64
+ fn=gen_h.update_model_type_settings,
65
+ inputs=[generation_section["config_path"]],
66
+ outputs=[
67
+ generation_section["inference_steps"],
68
+ generation_section["guidance_scale"],
69
+ generation_section["use_adg"],
70
+ generation_section["shift"],
71
+ generation_section["cfg_interval_start"],
72
+ generation_section["cfg_interval_end"],
73
+ generation_section["task_type"],
74
+ ]
75
+ )
76
+
77
+ generation_section["init_btn"].click(
78
+ fn=lambda *args: gen_h.init_service_wrapper(dit_handler, llm_handler, *args),
79
+ inputs=[
80
+ generation_section["checkpoint_dropdown"],
81
+ generation_section["config_path"],
82
+ generation_section["device"],
83
+ generation_section["init_llm_checkbox"],
84
+ generation_section["lm_model_path"],
85
+ generation_section["backend_dropdown"],
86
+ generation_section["use_flash_attention_checkbox"],
87
+ generation_section["offload_to_cpu_checkbox"],
88
+ generation_section["offload_dit_to_cpu_checkbox"],
89
+ ],
90
+ outputs=[
91
+ generation_section["init_status"],
92
+ generation_section["generate_btn"],
93
+ generation_section["service_config_accordion"],
94
+ # Model type settings (updated based on actual loaded model)
95
+ generation_section["inference_steps"],
96
+ generation_section["guidance_scale"],
97
+ generation_section["use_adg"],
98
+ generation_section["shift"],
99
+ generation_section["cfg_interval_start"],
100
+ generation_section["cfg_interval_end"],
101
+ generation_section["task_type"],
102
+ ]
103
+ )
104
+
105
+ # ========== LoRA Handlers ==========
106
+ generation_section["load_lora_btn"].click(
107
+ fn=dit_handler.load_lora,
108
+ inputs=[generation_section["lora_path"]],
109
+ outputs=[generation_section["lora_status"]]
110
+ ).then(
111
+ # Update checkbox to enabled state after loading
112
+ fn=lambda: gr.update(value=True),
113
+ outputs=[generation_section["use_lora_checkbox"]]
114
+ )
115
+
116
+ generation_section["unload_lora_btn"].click(
117
+ fn=dit_handler.unload_lora,
118
+ outputs=[generation_section["lora_status"]]
119
+ ).then(
120
+ # Update checkbox to disabled state after unloading
121
+ fn=lambda: gr.update(value=False),
122
+ outputs=[generation_section["use_lora_checkbox"]]
123
+ )
124
+
125
+ generation_section["use_lora_checkbox"].change(
126
+ fn=dit_handler.set_use_lora,
127
+ inputs=[generation_section["use_lora_checkbox"]],
128
+ outputs=[generation_section["lora_status"]]
129
+ )
130
+
131
+ # ========== UI Visibility Updates ==========
132
+ generation_section["init_llm_checkbox"].change(
133
+ fn=gen_h.update_negative_prompt_visibility,
134
+ inputs=[generation_section["init_llm_checkbox"]],
135
+ outputs=[generation_section["lm_negative_prompt"]]
136
+ )
137
+
138
+ generation_section["init_llm_checkbox"].change(
139
+ fn=gen_h.update_audio_cover_strength_visibility,
140
+ inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"]],
141
+ outputs=[generation_section["audio_cover_strength"]]
142
+ )
143
+
144
+ generation_section["task_type"].change(
145
+ fn=gen_h.update_audio_cover_strength_visibility,
146
+ inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"]],
147
+ outputs=[generation_section["audio_cover_strength"]]
148
+ )
149
+
150
+ generation_section["batch_size_input"].change(
151
+ fn=gen_h.update_audio_components_visibility,
152
+ inputs=[generation_section["batch_size_input"]],
153
+ outputs=[
154
+ results_section["audio_col_1"],
155
+ results_section["audio_col_2"],
156
+ results_section["audio_col_3"],
157
+ results_section["audio_col_4"],
158
+ results_section["audio_row_5_8"],
159
+ results_section["audio_col_5"],
160
+ results_section["audio_col_6"],
161
+ results_section["audio_col_7"],
162
+ results_section["audio_col_8"],
163
+ ]
164
+ )
165
+
166
+ # ========== Audio Conversion ==========
167
+ generation_section["convert_src_to_codes_btn"].click(
168
+ fn=lambda src: gen_h.convert_src_audio_to_codes_wrapper(dit_handler, src),
169
+ inputs=[generation_section["src_audio"]],
170
+ outputs=[generation_section["text2music_audio_code_string"]]
171
+ )
172
+
173
+ # ========== Instruction UI Updates ==========
174
+ for trigger in [generation_section["task_type"], generation_section["track_name"], generation_section["complete_track_classes"]]:
175
+ trigger.change(
176
+ fn=lambda *args: gen_h.update_instruction_ui(dit_handler, *args),
177
+ inputs=[
178
+ generation_section["task_type"],
179
+ generation_section["track_name"],
180
+ generation_section["complete_track_classes"],
181
+ generation_section["text2music_audio_code_string"],
182
+ generation_section["init_llm_checkbox"]
183
+ ],
184
+ outputs=[
185
+ generation_section["instruction_display_gen"],
186
+ generation_section["track_name"],
187
+ generation_section["complete_track_classes"],
188
+ generation_section["audio_cover_strength"],
189
+ generation_section["repainting_group"],
190
+ generation_section["text2music_audio_codes_group"],
191
+ ]
192
+ )
193
+
194
+ # ========== Sample/Transcribe Handlers ==========
195
+ # Load random example from ./examples/text2music directory
196
+ generation_section["sample_btn"].click(
197
+ fn=lambda task: gen_h.load_random_example(task) + (True,),
198
+ inputs=[
199
+ generation_section["task_type"],
200
+ ],
201
+ outputs=[
202
+ generation_section["captions"],
203
+ generation_section["lyrics"],
204
+ generation_section["think_checkbox"],
205
+ generation_section["bpm"],
206
+ generation_section["audio_duration"],
207
+ generation_section["key_scale"],
208
+ generation_section["vocal_language"],
209
+ generation_section["time_signature"],
210
+ results_section["is_format_caption_state"]
211
+ ]
212
+ )
213
+
214
+ generation_section["text2music_audio_code_string"].change(
215
+ fn=gen_h.update_transcribe_button_text,
216
+ inputs=[generation_section["text2music_audio_code_string"]],
217
+ outputs=[generation_section["transcribe_btn"]]
218
+ )
219
+
220
+ generation_section["transcribe_btn"].click(
221
+ fn=lambda codes, debug: gen_h.transcribe_audio_codes(llm_handler, codes, debug),
222
+ inputs=[
223
+ generation_section["text2music_audio_code_string"],
224
+ generation_section["constrained_decoding_debug"]
225
+ ],
226
+ outputs=[
227
+ results_section["status_output"],
228
+ generation_section["captions"],
229
+ generation_section["lyrics"],
230
+ generation_section["bpm"],
231
+ generation_section["audio_duration"],
232
+ generation_section["key_scale"],
233
+ generation_section["vocal_language"],
234
+ generation_section["time_signature"],
235
+ results_section["is_format_caption_state"]
236
+ ]
237
+ )
238
+
239
+ # ========== Reset Format Caption Flag ==========
240
+ for trigger in [generation_section["captions"], generation_section["lyrics"], generation_section["bpm"],
241
+ generation_section["key_scale"], generation_section["time_signature"],
242
+ generation_section["vocal_language"], generation_section["audio_duration"]]:
243
+ trigger.change(
244
+ fn=gen_h.reset_format_caption_flag,
245
+ inputs=[],
246
+ outputs=[results_section["is_format_caption_state"]]
247
+ )
248
+
249
+ # ========== Audio Uploads Accordion ==========
250
+ for trigger in [generation_section["reference_audio"], generation_section["src_audio"]]:
251
+ trigger.change(
252
+ fn=gen_h.update_audio_uploads_accordion,
253
+ inputs=[generation_section["reference_audio"], generation_section["src_audio"]],
254
+ outputs=[generation_section["audio_uploads_accordion"]]
255
+ )
256
+
257
+ # ========== Instrumental Checkbox ==========
258
+ generation_section["instrumental_checkbox"].change(
259
+ fn=gen_h.handle_instrumental_checkbox,
260
+ inputs=[generation_section["instrumental_checkbox"], generation_section["lyrics"]],
261
+ outputs=[generation_section["lyrics"]]
262
+ )
263
+
264
+ # ========== Format Button ==========
265
+ # Note: cfg_scale and negative_prompt are not supported in format mode
266
+ @_get_spaces_gpu_decorator(duration=120)
267
+ def handle_format_sample_wrapper(caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug):
268
+ return gen_h.handle_format_sample(
269
+ llm_handler, caption, lyrics, bpm, duration, key_scale, time_sig, temp, top_k, top_p, debug
270
+ )
271
+
272
+ generation_section["format_btn"].click(
273
+ fn=handle_format_sample_wrapper,
274
+ inputs=[
275
+ generation_section["captions"],
276
+ generation_section["lyrics"],
277
+ generation_section["bpm"],
278
+ generation_section["audio_duration"],
279
+ generation_section["key_scale"],
280
+ generation_section["time_signature"],
281
+ generation_section["lm_temperature"],
282
+ generation_section["lm_top_k"],
283
+ generation_section["lm_top_p"],
284
+ generation_section["constrained_decoding_debug"],
285
+ ],
286
+ outputs=[
287
+ generation_section["captions"],
288
+ generation_section["lyrics"],
289
+ generation_section["bpm"],
290
+ generation_section["audio_duration"],
291
+ generation_section["key_scale"],
292
+ generation_section["vocal_language"],
293
+ generation_section["time_signature"],
294
+ results_section["is_format_caption_state"],
295
+ results_section["status_output"],
296
+ ]
297
+ )
298
+
299
+ # ========== Generation Mode Toggle (Simple/Custom/Cover/Repaint) ==========
300
+ generation_section["generation_mode"].change(
301
+ fn=gen_h.handle_generation_mode_change,
302
+ inputs=[generation_section["generation_mode"]],
303
+ outputs=[
304
+ generation_section["simple_mode_group"],
305
+ generation_section["custom_mode_content"],
306
+ generation_section["cover_mode_group"],
307
+ generation_section["repainting_group"],
308
+ generation_section["task_type"],
309
+ generation_section["generate_btn"],
310
+ generation_section["simple_sample_created"],
311
+ generation_section["src_audio_group"],
312
+ generation_section["audio_cover_strength"],
313
+ generation_section["think_checkbox"], # Disable thinking for cover/repaint modes
314
+ ]
315
+ )
316
+
317
+ # ========== Process Source Audio Button ==========
318
+ # Combines Convert to Codes + Transcribe in one step
319
+ # Note: @spaces.GPU decorator must be on the function passed directly to fn=,
320
+ # not on a module-level function wrapped in a lambda. Lambdas capturing handler
321
+ # objects cause pickling errors on ZeroGPU because the model contains unpicklable
322
+ # local objects (e.g. AceStepDiTModel.__init__ lambdas).
323
+ @_get_spaces_gpu_decorator(duration=120)
324
+ def process_source_audio_wrapper(src, debug):
325
+ return gen_h.process_source_audio(dit_handler, llm_handler, src, debug)
326
+
327
+ generation_section["process_src_btn"].click(
328
+ fn=process_source_audio_wrapper,
329
+ inputs=[
330
+ generation_section["src_audio"],
331
+ generation_section["constrained_decoding_debug"]
332
+ ],
333
+ outputs=[
334
+ generation_section["text2music_audio_code_string"],
335
+ results_section["status_output"],
336
+ generation_section["captions"],
337
+ generation_section["lyrics"],
338
+ generation_section["bpm"],
339
+ generation_section["audio_duration"],
340
+ generation_section["key_scale"],
341
+ generation_section["vocal_language"],
342
+ generation_section["time_signature"],
343
+ results_section["is_format_caption_state"],
344
+ ]
345
+ )
346
+
347
+ # ========== Simple Mode Instrumental Checkbox ==========
348
+ # When instrumental is checked, disable vocal language and set to ["unknown"]
349
+ generation_section["simple_instrumental_checkbox"].change(
350
+ fn=gen_h.handle_simple_instrumental_change,
351
+ inputs=[generation_section["simple_instrumental_checkbox"]],
352
+ outputs=[generation_section["simple_vocal_language"]]
353
+ )
354
+
355
+ # ========== Random Description Button ==========
356
+ generation_section["random_desc_btn"].click(
357
+ fn=gen_h.load_random_simple_description,
358
+ inputs=[],
359
+ outputs=[
360
+ generation_section["simple_query_input"],
361
+ generation_section["simple_instrumental_checkbox"],
362
+ generation_section["simple_vocal_language"],
363
+ ]
364
+ )
365
+
366
+ # ========== Create Sample Button (Simple Mode) ==========
367
+ # Note: cfg_scale and negative_prompt are not supported in create_sample mode
368
+ @_get_spaces_gpu_decorator(duration=120)
369
+ def handle_create_sample_wrapper(query, instrumental, vocal_lang, temp, top_k, top_p, debug):
370
+ return gen_h.handle_create_sample(
371
+ llm_handler, query, instrumental, vocal_lang, temp, top_k, top_p, debug
372
+ )
373
+
374
+ generation_section["create_sample_btn"].click(
375
+ fn=handle_create_sample_wrapper,
376
+ inputs=[
377
+ generation_section["simple_query_input"],
378
+ generation_section["simple_instrumental_checkbox"],
379
+ generation_section["simple_vocal_language"],
380
+ generation_section["lm_temperature"],
381
+ generation_section["lm_top_k"],
382
+ generation_section["lm_top_p"],
383
+ generation_section["constrained_decoding_debug"],
384
+ ],
385
+ outputs=[
386
+ generation_section["captions"],
387
+ generation_section["lyrics"],
388
+ generation_section["bpm"],
389
+ generation_section["audio_duration"],
390
+ generation_section["key_scale"],
391
+ generation_section["vocal_language"],
392
+ generation_section["simple_vocal_language"],
393
+ generation_section["time_signature"],
394
+ generation_section["instrumental_checkbox"],
395
+ generation_section["caption_accordion"],
396
+ generation_section["lyrics_accordion"],
397
+ generation_section["generate_btn"],
398
+ generation_section["simple_sample_created"],
399
+ generation_section["think_checkbox"],
400
+ results_section["is_format_caption_state"],
401
+ results_section["status_output"],
402
+ ]
403
+ )
404
+
405
+ # ========== Load/Save Metadata ==========
406
+ generation_section["load_file"].upload(
407
+ fn=gen_h.load_metadata,
408
+ inputs=[generation_section["load_file"]],
409
+ outputs=[
410
+ generation_section["task_type"],
411
+ generation_section["captions"],
412
+ generation_section["lyrics"],
413
+ generation_section["vocal_language"],
414
+ generation_section["bpm"],
415
+ generation_section["key_scale"],
416
+ generation_section["time_signature"],
417
+ generation_section["audio_duration"],
418
+ generation_section["batch_size_input"],
419
+ generation_section["inference_steps"],
420
+ generation_section["guidance_scale"],
421
+ generation_section["seed"],
422
+ generation_section["random_seed_checkbox"],
423
+ generation_section["use_adg"],
424
+ generation_section["cfg_interval_start"],
425
+ generation_section["cfg_interval_end"],
426
+ generation_section["shift"],
427
+ generation_section["infer_method"],
428
+ generation_section["custom_timesteps"],
429
+ generation_section["audio_format"],
430
+ generation_section["lm_temperature"],
431
+ generation_section["lm_cfg_scale"],
432
+ generation_section["lm_top_k"],
433
+ generation_section["lm_top_p"],
434
+ generation_section["lm_negative_prompt"],
435
+ generation_section["use_cot_metas"], # Added: use_cot_metas
436
+ generation_section["use_cot_caption"],
437
+ generation_section["use_cot_language"],
438
+ generation_section["audio_cover_strength"],
439
+ generation_section["think_checkbox"],
440
+ generation_section["text2music_audio_code_string"],
441
+ generation_section["repainting_start"],
442
+ generation_section["repainting_end"],
443
+ generation_section["track_name"],
444
+ generation_section["complete_track_classes"],
445
+ generation_section["instrumental_checkbox"], # Added: instrumental_checkbox
446
+ results_section["is_format_caption_state"]
447
+ ]
448
+ )
449
+
450
+ # Save buttons for all 8 audio outputs
451
+ download_existing_js = """(current_audio, batch_files) => {
452
+ // Debug: print what the input actually is
453
+ console.log("👉 [Debug] Current Audio Input:", current_audio);
454
+
455
+ // 1. Safety check
456
+ if (!current_audio) {
457
+ console.warn("⚠️ No audio selected or audio is empty.");
458
+ return;
459
+ }
460
+ if (!batch_files || !Array.isArray(batch_files)) {
461
+ console.warn("⚠️ Batch file list is empty/not ready.");
462
+ return;
463
+ }
464
+
465
+ // 2. Smartly extract path string
466
+ let pathString = "";
467
+
468
+ if (typeof current_audio === "string") {
469
+ // Case A: direct path string received
470
+ pathString = current_audio;
471
+ } else if (typeof current_audio === "object") {
472
+ // Case B: an object is received, try common properties
473
+ // Gradio file objects usually have path, url, or name
474
+ pathString = current_audio.path || current_audio.name || current_audio.url || "";
475
+ }
476
+
477
+ if (!pathString) {
478
+ console.error("❌ Error: Could not extract a valid path string from input.", current_audio);
479
+ return;
480
+ }
481
+
482
+ // 3. Extract Key (UUID)
483
+ // Path could be /tmp/.../uuid.mp3 or url like /file=.../uuid.mp3
484
+ let filename = pathString.split(/[\\\\/]/).pop(); // get the filename
485
+ let key = filename.split('.')[0]; // get UUID without extension
486
+
487
+ console.log(`🔑 Key extracted: ${key}`);
488
+
489
+ // 4. Find matching file(s) in the list
490
+ let targets = batch_files.filter(f => {
491
+ // Also extract names from batch_files objects
492
+ // f usually contains name (backend path) and orig_name (download name)
493
+ const fPath = f.name || f.path || "";
494
+ return fPath.includes(key);
495
+ });
496
+
497
+ if (targets.length === 0) {
498
+ console.warn("❌ No matching files found in batch list for key:", key);
499
+ alert("Batch list does not contain this file yet. Please wait for generation to finish.");
500
+ return;
501
+ }
502
+
503
+ // 5. Trigger download(s)
504
+ console.log(`🎯 Found ${targets.length} files to download.`);
505
+ targets.forEach((f, index) => {
506
+ setTimeout(() => {
507
+ const a = document.createElement('a');
508
+ // Prefer url (frontend-accessible link), otherwise try data
509
+ a.href = f.url || f.data;
510
+ a.download = f.orig_name || "download";
511
+ a.style.display = 'none';
512
+ document.body.appendChild(a);
513
+ a.click();
514
+ document.body.removeChild(a);
515
+ }, index * 1000); // 300ms interval to avoid browser blocking
516
+ });
517
+ }
518
+ """
519
+ for btn_idx in range(1, 9):
520
+ results_section[f"save_btn_{btn_idx}"].click(
521
+ fn=None,
522
+ inputs=[
523
+ results_section[f"generated_audio_{btn_idx}"],
524
+ results_section["generated_audio_batch"],
525
+ ],
526
+ js=download_existing_js # Run the above JS
527
+ )
528
+ # ========== Send to Cover Handlers ==========
529
+ def send_to_cover_handler(audio_file, lm_metadata):
530
+ """Send audio to cover mode and switch to cover"""
531
+ if audio_file is None:
532
+ return (gr.skip(),) * 11
533
+ return (
534
+ audio_file, # src_audio
535
+ gr.skip(), # bpm
536
+ gr.skip(), # captions
537
+ gr.skip(), # lyrics
538
+ gr.skip(), # audio_duration
539
+ gr.skip(), # key_scale
540
+ gr.skip(), # vocal_language
541
+ gr.skip(), # time_signature
542
+ gr.skip(), # is_format_caption_state
543
+ "cover", # generation_mode - switch to cover
544
+ "cover", # task_type - set to cover
545
+ )
546
+
547
+ for btn_idx in range(1, 9):
548
+ results_section[f"send_to_cover_btn_{btn_idx}"].click(
549
+ fn=send_to_cover_handler,
550
+ inputs=[
551
+ results_section[f"generated_audio_{btn_idx}"],
552
+ results_section["lm_metadata_state"]
553
+ ],
554
+ outputs=[
555
+ generation_section["src_audio"],
556
+ generation_section["bpm"],
557
+ generation_section["captions"],
558
+ generation_section["lyrics"],
559
+ generation_section["audio_duration"],
560
+ generation_section["key_scale"],
561
+ generation_section["vocal_language"],
562
+ generation_section["time_signature"],
563
+ results_section["is_format_caption_state"],
564
+ generation_section["generation_mode"],
565
+ generation_section["task_type"],
566
+ ]
567
+ )
568
+
569
+ # ========== Send to Repaint Handlers ==========
570
+ def send_to_repaint_handler(audio_file, lm_metadata):
571
+ """Send audio to repaint mode and switch to repaint"""
572
+ if audio_file is None:
573
+ return (gr.skip(),) * 11
574
+ return (
575
+ audio_file, # src_audio
576
+ gr.skip(), # bpm
577
+ gr.skip(), # captions
578
+ gr.skip(), # lyrics
579
+ gr.skip(), # audio_duration
580
+ gr.skip(), # key_scale
581
+ gr.skip(), # vocal_language
582
+ gr.skip(), # time_signature
583
+ gr.skip(), # is_format_caption_state
584
+ "repaint", # generation_mode - switch to repaint
585
+ "repaint", # task_type - set to repaint
586
+ )
587
+
588
+ for btn_idx in range(1, 9):
589
+ results_section[f"send_to_repaint_btn_{btn_idx}"].click(
590
+ fn=send_to_repaint_handler,
591
+ inputs=[
592
+ results_section[f"generated_audio_{btn_idx}"],
593
+ results_section["lm_metadata_state"]
594
+ ],
595
+ outputs=[
596
+ generation_section["src_audio"],
597
+ generation_section["bpm"],
598
+ generation_section["captions"],
599
+ generation_section["lyrics"],
600
+ generation_section["audio_duration"],
601
+ generation_section["key_scale"],
602
+ generation_section["vocal_language"],
603
+ generation_section["time_signature"],
604
+ results_section["is_format_caption_state"],
605
+ generation_section["generation_mode"],
606
+ generation_section["task_type"],
607
+ ]
608
+ )
609
+
610
+ # ========== Score Calculation Handlers ==========
611
+ # Use default argument to capture btn_idx value at definition time (Python closure fix)
612
+ # Note: @spaces.GPU decorator applied here (not on module-level function) to avoid
613
+ # pickling issues on ZeroGPU when handler objects are captured in closures.
614
+ def make_score_handler(idx):
615
+ @_get_spaces_gpu_decorator(duration=120)
616
+ def score_handler(scale, batch_idx, queue):
617
+ return res_h.calculate_score_handler_with_selection(
618
+ dit_handler, llm_handler, idx, scale, batch_idx, queue
619
+ )
620
+ return score_handler
621
+
622
+ for btn_idx in range(1, 9):
623
+ results_section[f"score_btn_{btn_idx}"].click(
624
+ fn=make_score_handler(btn_idx),
625
+ inputs=[
626
+ generation_section["score_scale"],
627
+ results_section["current_batch_index"],
628
+ results_section["batch_queue"],
629
+ ],
630
+ outputs=[
631
+ results_section[f"score_display_{btn_idx}"],
632
+ results_section[f"details_accordion_{btn_idx}"],
633
+ results_section["batch_queue"]
634
+ ]
635
+ )
636
+
637
+ # ========== LRC Timestamp Handlers ==========
638
+ # Use default argument to capture btn_idx value at definition time (Python closure fix)
639
+ def make_lrc_handler(idx):
640
+ @_get_spaces_gpu_decorator(duration=120)
641
+ def lrc_handler(batch_idx, queue, vocal_lang, infer_steps):
642
+ return res_h.generate_lrc_handler(
643
+ dit_handler, idx, batch_idx, queue, vocal_lang, infer_steps
644
+ )
645
+ return lrc_handler
646
+
647
+ for btn_idx in range(1, 9):
648
+ results_section[f"lrc_btn_{btn_idx}"].click(
649
+ fn=make_lrc_handler(btn_idx),
650
+ inputs=[
651
+ results_section["current_batch_index"],
652
+ results_section["batch_queue"],
653
+ generation_section["vocal_language"],
654
+ generation_section["inference_steps"],
655
+ ],
656
+ outputs=[
657
+ results_section[f"lrc_display_{btn_idx}"],
658
+ results_section[f"details_accordion_{btn_idx}"],
659
+ # NOTE: Removed generated_audio output!
660
+ # Audio subtitles are now updated via lrc_display.change() event.
661
+ results_section["batch_queue"]
662
+ ]
663
+ )
664
+
665
+ @_get_spaces_gpu_decorator(duration=120)
666
+ def generation_wrapper(selected_model, generation_mode, simple_query_input, simple_vocal_language, *args):
667
+ """Wrapper that selects the appropriate DiT handler based on model selection"""
668
+ # Convert args to list for modification
669
+ args_list = list(args)
670
+
671
+ # args order (after simple mode params):
672
+ # captions (0), lyrics (1), bpm (2), key_scale (3), time_signature (4), vocal_language (5),
673
+ # inference_steps (6), guidance_scale (7), random_seed_checkbox (8), seed (9),
674
+ # reference_audio (10), audio_duration (11), batch_size_input (12), src_audio (13),
675
+ # text2music_audio_code_string (14), repainting_start (15), repainting_end (16),
676
+ # instruction_display_gen (17), audio_cover_strength (18), task_type (19), ...
677
+ # ... lm_temperature (27), think_checkbox (28), ...
678
+ # ... instrumental_checkbox (at position after all regular params)
679
+
680
+ src_audio = args_list[13] if len(args_list) > 13 else None
681
+ task_type = args_list[19] if len(args_list) > 19 else "text2music"
682
+
683
+ # Validate: Cover and Repaint modes require source audio
684
+ if task_type in ["cover", "repaint"] and src_audio is None:
685
+ raise gr.Error(f"Source Audio is required for {task_type.capitalize()} mode. Please upload an audio file.")
686
+
687
+ # Handle Simple mode: first create sample, then generate
688
+ if generation_mode == "simple":
689
+ # Get instrumental from the main checkbox (args[-6] based on input order)
690
+ # The instrumental_checkbox is passed after all the regular generation params
691
+ instrumental = args_list[-6] if len(args_list) > 6 else False # instrumental_checkbox position
692
+ lm_temperature = args_list[27] if len(args_list) > 27 else 0.85
693
+ lm_top_k = args_list[30] if len(args_list) > 30 else 0
694
+ lm_top_p = args_list[31] if len(args_list) > 31 else 0.9
695
+ constrained_decoding_debug = args_list[38] if len(args_list) > 38 else False
696
+
697
+ # Call create_sample to generate caption/lyrics/metadata
698
+ from acestep.inference import create_sample
699
+
700
+ top_k_value = None if not lm_top_k or lm_top_k == 0 else int(lm_top_k)
701
+ top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p
702
+
703
+ result = create_sample(
704
+ llm_handler=llm_handler,
705
+ query=simple_query_input,
706
+ instrumental=instrumental,
707
+ vocal_language=simple_vocal_language,
708
+ temperature=lm_temperature,
709
+ top_k=top_k_value,
710
+ top_p=top_p_value,
711
+ use_constrained_decoding=True,
712
+ constrained_decoding_debug=constrained_decoding_debug,
713
+ )
714
+
715
+ if not result.success:
716
+ raise gr.Error(f"Failed to create sample: {result.status_message}")
717
+
718
+ # Update args with generated data
719
+ args_list[0] = result.caption # captions
720
+ args_list[1] = result.lyrics # lyrics
721
+ args_list[2] = result.bpm # bpm
722
+ args_list[3] = result.keyscale # key_scale
723
+ args_list[4] = result.timesignature # time_signature
724
+ args_list[5] = result.language # vocal_language
725
+ if result.duration and result.duration > 0:
726
+ args_list[11] = result.duration # audio_duration
727
+ # Enable thinking for Simple mode
728
+ args_list[28] = True # think_checkbox
729
+ # Mark as formatted caption (LM-generated sample)
730
+ args_list[36] = True # is_format_caption_state
731
+
732
+ # Determine which handler to use based on model selection
733
+ active_handler = dit_handler # Default to primary handler
734
+ if dit_handler_2 is not None and selected_model == config_path_2:
735
+ active_handler = dit_handler_2
736
+ yield from res_h.generate_with_batch_management(active_handler, llm_handler, *args_list)
737
+
738
+ # ========== Generation Handler ==========
739
+ generation_section["generate_btn"].click(
740
+ fn=generation_wrapper,
741
+ inputs=[
742
+ generation_section["dit_model_selector"], # Model selection input
743
+ generation_section["generation_mode"], # For Simple mode detection
744
+ generation_section["simple_query_input"], # Simple mode query
745
+ generation_section["simple_vocal_language"], # Simple mode vocal language
746
+ generation_section["captions"],
747
+ generation_section["lyrics"],
748
+ generation_section["bpm"],
749
+ generation_section["key_scale"],
750
+ generation_section["time_signature"],
751
+ generation_section["vocal_language"],
752
+ generation_section["inference_steps"],
753
+ generation_section["guidance_scale"],
754
+ generation_section["random_seed_checkbox"],
755
+ generation_section["seed"],
756
+ generation_section["reference_audio"],
757
+ generation_section["audio_duration"],
758
+ generation_section["batch_size_input"],
759
+ generation_section["src_audio"],
760
+ generation_section["text2music_audio_code_string"],
761
+ generation_section["repainting_start"],
762
+ generation_section["repainting_end"],
763
+ generation_section["instruction_display_gen"],
764
+ generation_section["audio_cover_strength"],
765
+ generation_section["task_type"],
766
+ generation_section["use_adg"],
767
+ generation_section["cfg_interval_start"],
768
+ generation_section["cfg_interval_end"],
769
+ generation_section["shift"],
770
+ generation_section["infer_method"],
771
+ generation_section["custom_timesteps"],
772
+ generation_section["audio_format"],
773
+ generation_section["lm_temperature"],
774
+ generation_section["think_checkbox"],
775
+ generation_section["lm_cfg_scale"],
776
+ generation_section["lm_top_k"],
777
+ generation_section["lm_top_p"],
778
+ generation_section["lm_negative_prompt"],
779
+ generation_section["use_cot_metas"],
780
+ generation_section["use_cot_caption"],
781
+ generation_section["use_cot_language"],
782
+ results_section["is_format_caption_state"],
783
+ generation_section["constrained_decoding_debug"],
784
+ generation_section["allow_lm_batch"],
785
+ generation_section["auto_score"],
786
+ generation_section["auto_lrc"],
787
+ generation_section["score_scale"],
788
+ generation_section["lm_batch_chunk_size"],
789
+ generation_section["track_name"],
790
+ generation_section["complete_track_classes"],
791
+ generation_section["autogen_checkbox"],
792
+ results_section["current_batch_index"],
793
+ results_section["total_batches"],
794
+ results_section["batch_queue"],
795
+ results_section["generation_params_state"],
796
+ ],
797
+ outputs=[
798
+ results_section["generated_audio_1"],
799
+ results_section["generated_audio_2"],
800
+ results_section["generated_audio_3"],
801
+ results_section["generated_audio_4"],
802
+ results_section["generated_audio_5"],
803
+ results_section["generated_audio_6"],
804
+ results_section["generated_audio_7"],
805
+ results_section["generated_audio_8"],
806
+ results_section["generated_audio_batch"],
807
+ results_section["generation_info"],
808
+ results_section["status_output"],
809
+ generation_section["seed"],
810
+ results_section["score_display_1"],
811
+ results_section["score_display_2"],
812
+ results_section["score_display_3"],
813
+ results_section["score_display_4"],
814
+ results_section["score_display_5"],
815
+ results_section["score_display_6"],
816
+ results_section["score_display_7"],
817
+ results_section["score_display_8"],
818
+ results_section["codes_display_1"],
819
+ results_section["codes_display_2"],
820
+ results_section["codes_display_3"],
821
+ results_section["codes_display_4"],
822
+ results_section["codes_display_5"],
823
+ results_section["codes_display_6"],
824
+ results_section["codes_display_7"],
825
+ results_section["codes_display_8"],
826
+ results_section["details_accordion_1"],
827
+ results_section["details_accordion_2"],
828
+ results_section["details_accordion_3"],
829
+ results_section["details_accordion_4"],
830
+ results_section["details_accordion_5"],
831
+ results_section["details_accordion_6"],
832
+ results_section["details_accordion_7"],
833
+ results_section["details_accordion_8"],
834
+ results_section["lrc_display_1"],
835
+ results_section["lrc_display_2"],
836
+ results_section["lrc_display_3"],
837
+ results_section["lrc_display_4"],
838
+ results_section["lrc_display_5"],
839
+ results_section["lrc_display_6"],
840
+ results_section["lrc_display_7"],
841
+ results_section["lrc_display_8"],
842
+ results_section["lm_metadata_state"],
843
+ results_section["is_format_caption_state"],
844
+ results_section["current_batch_index"],
845
+ results_section["total_batches"],
846
+ results_section["batch_queue"],
847
+ results_section["generation_params_state"],
848
+ results_section["batch_indicator"],
849
+ results_section["prev_batch_btn"],
850
+ results_section["next_batch_btn"],
851
+ results_section["next_batch_status"],
852
+ results_section["restore_params_btn"],
853
+ ]
854
+ ).then(
855
+ fn=lambda selected_model, *args: res_h.generate_next_batch_background(
856
+ dit_handler_2 if (dit_handler_2 is not None and selected_model == config_path_2) else dit_handler,
857
+ llm_handler, *args
858
+ ),
859
+ inputs=[
860
+ generation_section["dit_model_selector"], # Model selection input
861
+ generation_section["autogen_checkbox"],
862
+ results_section["generation_params_state"],
863
+ results_section["current_batch_index"],
864
+ results_section["total_batches"],
865
+ results_section["batch_queue"],
866
+ results_section["is_format_caption_state"],
867
+ ],
868
+ outputs=[
869
+ results_section["batch_queue"],
870
+ results_section["total_batches"],
871
+ results_section["next_batch_status"],
872
+ results_section["next_batch_btn"],
873
+ ]
874
+ )
875
+
876
+ # ========== Batch Navigation Handlers ==========
877
+ results_section["prev_batch_btn"].click(
878
+ fn=res_h.navigate_to_previous_batch,
879
+ inputs=[
880
+ results_section["current_batch_index"],
881
+ results_section["batch_queue"],
882
+ ],
883
+ outputs=[
884
+ results_section["generated_audio_1"],
885
+ results_section["generated_audio_2"],
886
+ results_section["generated_audio_3"],
887
+ results_section["generated_audio_4"],
888
+ results_section["generated_audio_5"],
889
+ results_section["generated_audio_6"],
890
+ results_section["generated_audio_7"],
891
+ results_section["generated_audio_8"],
892
+ results_section["generated_audio_batch"],
893
+ results_section["generation_info"],
894
+ results_section["current_batch_index"],
895
+ results_section["batch_indicator"],
896
+ results_section["prev_batch_btn"],
897
+ results_section["next_batch_btn"],
898
+ results_section["status_output"],
899
+ results_section["score_display_1"],
900
+ results_section["score_display_2"],
901
+ results_section["score_display_3"],
902
+ results_section["score_display_4"],
903
+ results_section["score_display_5"],
904
+ results_section["score_display_6"],
905
+ results_section["score_display_7"],
906
+ results_section["score_display_8"],
907
+ results_section["codes_display_1"],
908
+ results_section["codes_display_2"],
909
+ results_section["codes_display_3"],
910
+ results_section["codes_display_4"],
911
+ results_section["codes_display_5"],
912
+ results_section["codes_display_6"],
913
+ results_section["codes_display_7"],
914
+ results_section["codes_display_8"],
915
+ results_section["lrc_display_1"],
916
+ results_section["lrc_display_2"],
917
+ results_section["lrc_display_3"],
918
+ results_section["lrc_display_4"],
919
+ results_section["lrc_display_5"],
920
+ results_section["lrc_display_6"],
921
+ results_section["lrc_display_7"],
922
+ results_section["lrc_display_8"],
923
+ results_section["details_accordion_1"],
924
+ results_section["details_accordion_2"],
925
+ results_section["details_accordion_3"],
926
+ results_section["details_accordion_4"],
927
+ results_section["details_accordion_5"],
928
+ results_section["details_accordion_6"],
929
+ results_section["details_accordion_7"],
930
+ results_section["details_accordion_8"],
931
+ results_section["restore_params_btn"],
932
+ ]
933
+ )
934
+
935
+ results_section["next_batch_btn"].click(
936
+ fn=res_h.capture_current_params,
937
+ inputs=[
938
+ generation_section["captions"],
939
+ generation_section["lyrics"],
940
+ generation_section["bpm"],
941
+ generation_section["key_scale"],
942
+ generation_section["time_signature"],
943
+ generation_section["vocal_language"],
944
+ generation_section["inference_steps"],
945
+ generation_section["guidance_scale"],
946
+ generation_section["random_seed_checkbox"],
947
+ generation_section["seed"],
948
+ generation_section["reference_audio"],
949
+ generation_section["audio_duration"],
950
+ generation_section["batch_size_input"],
951
+ generation_section["src_audio"],
952
+ generation_section["text2music_audio_code_string"],
953
+ generation_section["repainting_start"],
954
+ generation_section["repainting_end"],
955
+ generation_section["instruction_display_gen"],
956
+ generation_section["audio_cover_strength"],
957
+ generation_section["task_type"],
958
+ generation_section["use_adg"],
959
+ generation_section["cfg_interval_start"],
960
+ generation_section["cfg_interval_end"],
961
+ generation_section["shift"],
962
+ generation_section["infer_method"],
963
+ generation_section["custom_timesteps"],
964
+ generation_section["audio_format"],
965
+ generation_section["lm_temperature"],
966
+ generation_section["think_checkbox"],
967
+ generation_section["lm_cfg_scale"],
968
+ generation_section["lm_top_k"],
969
+ generation_section["lm_top_p"],
970
+ generation_section["lm_negative_prompt"],
971
+ generation_section["use_cot_metas"],
972
+ generation_section["use_cot_caption"],
973
+ generation_section["use_cot_language"],
974
+ generation_section["constrained_decoding_debug"],
975
+ generation_section["allow_lm_batch"],
976
+ generation_section["auto_score"],
977
+ generation_section["auto_lrc"],
978
+ generation_section["score_scale"],
979
+ generation_section["lm_batch_chunk_size"],
980
+ generation_section["track_name"],
981
+ generation_section["complete_track_classes"],
982
+ ],
983
+ outputs=[results_section["generation_params_state"]]
984
+ ).then(
985
+ fn=res_h.navigate_to_next_batch,
986
+ inputs=[
987
+ generation_section["autogen_checkbox"],
988
+ results_section["current_batch_index"],
989
+ results_section["total_batches"],
990
+ results_section["batch_queue"],
991
+ ],
992
+ outputs=[
993
+ results_section["generated_audio_1"],
994
+ results_section["generated_audio_2"],
995
+ results_section["generated_audio_3"],
996
+ results_section["generated_audio_4"],
997
+ results_section["generated_audio_5"],
998
+ results_section["generated_audio_6"],
999
+ results_section["generated_audio_7"],
1000
+ results_section["generated_audio_8"],
1001
+ results_section["generated_audio_batch"],
1002
+ results_section["generation_info"],
1003
+ results_section["current_batch_index"],
1004
+ results_section["batch_indicator"],
1005
+ results_section["prev_batch_btn"],
1006
+ results_section["next_batch_btn"],
1007
+ results_section["status_output"],
1008
+ results_section["next_batch_status"],
1009
+ results_section["score_display_1"],
1010
+ results_section["score_display_2"],
1011
+ results_section["score_display_3"],
1012
+ results_section["score_display_4"],
1013
+ results_section["score_display_5"],
1014
+ results_section["score_display_6"],
1015
+ results_section["score_display_7"],
1016
+ results_section["score_display_8"],
1017
+ results_section["codes_display_1"],
1018
+ results_section["codes_display_2"],
1019
+ results_section["codes_display_3"],
1020
+ results_section["codes_display_4"],
1021
+ results_section["codes_display_5"],
1022
+ results_section["codes_display_6"],
1023
+ results_section["codes_display_7"],
1024
+ results_section["codes_display_8"],
1025
+ results_section["lrc_display_1"],
1026
+ results_section["lrc_display_2"],
1027
+ results_section["lrc_display_3"],
1028
+ results_section["lrc_display_4"],
1029
+ results_section["lrc_display_5"],
1030
+ results_section["lrc_display_6"],
1031
+ results_section["lrc_display_7"],
1032
+ results_section["lrc_display_8"],
1033
+ results_section["details_accordion_1"],
1034
+ results_section["details_accordion_2"],
1035
+ results_section["details_accordion_3"],
1036
+ results_section["details_accordion_4"],
1037
+ results_section["details_accordion_5"],
1038
+ results_section["details_accordion_6"],
1039
+ results_section["details_accordion_7"],
1040
+ results_section["details_accordion_8"],
1041
+ results_section["restore_params_btn"],
1042
+ ]
1043
+ ).then(
1044
+ fn=lambda selected_model, *args: res_h.generate_next_batch_background(
1045
+ dit_handler_2 if (dit_handler_2 is not None and selected_model == config_path_2) else dit_handler,
1046
+ llm_handler, *args
1047
+ ),
1048
+ inputs=[
1049
+ generation_section["dit_model_selector"], # Model selection input
1050
+ generation_section["autogen_checkbox"],
1051
+ results_section["generation_params_state"],
1052
+ results_section["current_batch_index"],
1053
+ results_section["total_batches"],
1054
+ results_section["batch_queue"],
1055
+ results_section["is_format_caption_state"],
1056
+ ],
1057
+ outputs=[
1058
+ results_section["batch_queue"],
1059
+ results_section["total_batches"],
1060
+ results_section["next_batch_status"],
1061
+ results_section["next_batch_btn"],
1062
+ ]
1063
+ )
1064
+
1065
+ # ========== Restore Parameters Handler ==========
1066
+ results_section["restore_params_btn"].click(
1067
+ fn=res_h.restore_batch_parameters,
1068
+ inputs=[
1069
+ results_section["current_batch_index"],
1070
+ results_section["batch_queue"]
1071
+ ],
1072
+ outputs=[
1073
+ generation_section["text2music_audio_code_string"],
1074
+ generation_section["captions"],
1075
+ generation_section["lyrics"],
1076
+ generation_section["bpm"],
1077
+ generation_section["key_scale"],
1078
+ generation_section["time_signature"],
1079
+ generation_section["vocal_language"],
1080
+ generation_section["audio_duration"],
1081
+ generation_section["batch_size_input"],
1082
+ generation_section["inference_steps"],
1083
+ generation_section["lm_temperature"],
1084
+ generation_section["lm_cfg_scale"],
1085
+ generation_section["lm_top_k"],
1086
+ generation_section["lm_top_p"],
1087
+ generation_section["think_checkbox"],
1088
+ generation_section["use_cot_caption"],
1089
+ generation_section["use_cot_language"],
1090
+ generation_section["allow_lm_batch"],
1091
+ generation_section["track_name"],
1092
+ generation_section["complete_track_classes"],
1093
+ ]
1094
+ )
1095
+
1096
+ # ========== LRC Display Change Handlers ==========
1097
+ # NEW APPROACH: Use lrc_display.change() to update audio subtitles
1098
+ # This decouples audio value updates from subtitle updates, avoiding flickering.
1099
+ #
1100
+ # When lrc_display text changes (from generate, LRC button, or manual edit):
1101
+ # 1. lrc_display.change() is triggered
1102
+ # 2. update_audio_subtitles_from_lrc() parses LRC and updates audio subtitles
1103
+ # 3. Audio value is NEVER updated here - only subtitles
1104
+ for lrc_idx in range(1, 9):
1105
+ results_section[f"lrc_display_{lrc_idx}"].change(
1106
+ fn=res_h.update_audio_subtitles_from_lrc,
1107
+ inputs=[
1108
+ results_section[f"lrc_display_{lrc_idx}"],
1109
+ # audio_duration not needed - parse_lrc_to_subtitles calculates end time from timestamps
1110
+ ],
1111
+ outputs=[
1112
+ results_section[f"generated_audio_{lrc_idx}"], # Only updates subtitles, not value
1113
+ ]
1114
+ )
1115
+
1116
+
1117
+ def setup_training_event_handlers(demo, dit_handler, llm_handler, training_section):
1118
+ """Setup event handlers for the training tab (dataset builder and LoRA training)"""
1119
+
1120
+ # ========== Load Existing Dataset (Top Section) ==========
1121
+
1122
+ # Load existing dataset JSON at the top of Dataset Builder
1123
+ training_section["load_json_btn"].click(
1124
+ fn=train_h.load_existing_dataset_for_preprocess,
1125
+ inputs=[
1126
+ training_section["load_json_path"],
1127
+ training_section["dataset_builder_state"],
1128
+ ],
1129
+ outputs=[
1130
+ training_section["load_json_status"],
1131
+ training_section["audio_files_table"],
1132
+ training_section["sample_selector"],
1133
+ training_section["dataset_builder_state"],
1134
+ # Also update preview fields with first sample
1135
+ training_section["preview_audio"],
1136
+ training_section["preview_filename"],
1137
+ training_section["edit_caption"],
1138
+ training_section["edit_lyrics"],
1139
+ training_section["edit_bpm"],
1140
+ training_section["edit_keyscale"],
1141
+ training_section["edit_timesig"],
1142
+ training_section["edit_duration"],
1143
+ training_section["edit_language"],
1144
+ training_section["edit_instrumental"],
1145
+ ]
1146
+ )
1147
+
1148
+ # ========== Dataset Builder Handlers ==========
1149
+
1150
+ # Scan directory for audio files
1151
+ training_section["scan_btn"].click(
1152
+ fn=lambda dir, name, tag, pos, instr, state: train_h.scan_directory(
1153
+ dir, name, tag, pos, instr, state
1154
+ ),
1155
+ inputs=[
1156
+ training_section["audio_directory"],
1157
+ training_section["dataset_name"],
1158
+ training_section["custom_tag"],
1159
+ training_section["tag_position"],
1160
+ training_section["all_instrumental"],
1161
+ training_section["dataset_builder_state"],
1162
+ ],
1163
+ outputs=[
1164
+ training_section["audio_files_table"],
1165
+ training_section["scan_status"],
1166
+ training_section["sample_selector"],
1167
+ training_section["dataset_builder_state"],
1168
+ ]
1169
+ )
1170
+
1171
+ # Auto-label all samples
1172
+ training_section["auto_label_btn"].click(
1173
+ fn=lambda state, skip: train_h.auto_label_all(dit_handler, llm_handler, state, skip),
1174
+ inputs=[
1175
+ training_section["dataset_builder_state"],
1176
+ training_section["skip_metas"],
1177
+ ],
1178
+ outputs=[
1179
+ training_section["audio_files_table"],
1180
+ training_section["label_progress"],
1181
+ training_section["dataset_builder_state"],
1182
+ ]
1183
+ )
1184
+
1185
+ # Sample selector change - update preview
1186
+ training_section["sample_selector"].change(
1187
+ fn=train_h.get_sample_preview,
1188
+ inputs=[
1189
+ training_section["sample_selector"],
1190
+ training_section["dataset_builder_state"],
1191
+ ],
1192
+ outputs=[
1193
+ training_section["preview_audio"],
1194
+ training_section["preview_filename"],
1195
+ training_section["edit_caption"],
1196
+ training_section["edit_lyrics"],
1197
+ training_section["edit_bpm"],
1198
+ training_section["edit_keyscale"],
1199
+ training_section["edit_timesig"],
1200
+ training_section["edit_duration"],
1201
+ training_section["edit_language"],
1202
+ training_section["edit_instrumental"],
1203
+ ]
1204
+ )
1205
+
1206
+ # Save sample edit
1207
+ training_section["save_edit_btn"].click(
1208
+ fn=train_h.save_sample_edit,
1209
+ inputs=[
1210
+ training_section["sample_selector"],
1211
+ training_section["edit_caption"],
1212
+ training_section["edit_lyrics"],
1213
+ training_section["edit_bpm"],
1214
+ training_section["edit_keyscale"],
1215
+ training_section["edit_timesig"],
1216
+ training_section["edit_language"],
1217
+ training_section["edit_instrumental"],
1218
+ training_section["dataset_builder_state"],
1219
+ ],
1220
+ outputs=[
1221
+ training_section["audio_files_table"],
1222
+ training_section["edit_status"],
1223
+ training_section["dataset_builder_state"],
1224
+ ]
1225
+ )
1226
+
1227
+ # Update settings when changed
1228
+ for trigger in [training_section["custom_tag"], training_section["tag_position"], training_section["all_instrumental"]]:
1229
+ trigger.change(
1230
+ fn=train_h.update_settings,
1231
+ inputs=[
1232
+ training_section["custom_tag"],
1233
+ training_section["tag_position"],
1234
+ training_section["all_instrumental"],
1235
+ training_section["dataset_builder_state"],
1236
+ ],
1237
+ outputs=[training_section["dataset_builder_state"]]
1238
+ )
1239
+
1240
+ # Save dataset
1241
+ training_section["save_dataset_btn"].click(
1242
+ fn=train_h.save_dataset,
1243
+ inputs=[
1244
+ training_section["save_path"],
1245
+ training_section["dataset_name"],
1246
+ training_section["dataset_builder_state"],
1247
+ ],
1248
+ outputs=[training_section["save_status"]]
1249
+ )
1250
+
1251
+ # ========== Preprocess Handlers ==========
1252
+
1253
+ # Load existing dataset JSON for preprocessing
1254
+ # This also updates the preview section so users can view/edit samples
1255
+ training_section["load_existing_dataset_btn"].click(
1256
+ fn=train_h.load_existing_dataset_for_preprocess,
1257
+ inputs=[
1258
+ training_section["load_existing_dataset_path"],
1259
+ training_section["dataset_builder_state"],
1260
+ ],
1261
+ outputs=[
1262
+ training_section["load_existing_status"],
1263
+ training_section["audio_files_table"],
1264
+ training_section["sample_selector"],
1265
+ training_section["dataset_builder_state"],
1266
+ # Also update preview fields with first sample
1267
+ training_section["preview_audio"],
1268
+ training_section["preview_filename"],
1269
+ training_section["edit_caption"],
1270
+ training_section["edit_lyrics"],
1271
+ training_section["edit_bpm"],
1272
+ training_section["edit_keyscale"],
1273
+ training_section["edit_timesig"],
1274
+ training_section["edit_duration"],
1275
+ training_section["edit_language"],
1276
+ training_section["edit_instrumental"],
1277
+ ]
1278
+ )
1279
+
1280
+ # Preprocess dataset to tensor files
1281
+ training_section["preprocess_btn"].click(
1282
+ fn=lambda output_dir, state: train_h.preprocess_dataset(
1283
+ output_dir, dit_handler, state
1284
+ ),
1285
+ inputs=[
1286
+ training_section["preprocess_output_dir"],
1287
+ training_section["dataset_builder_state"],
1288
+ ],
1289
+ outputs=[training_section["preprocess_progress"]]
1290
+ )
1291
+
1292
+ # ========== Training Tab Handlers ==========
1293
+
1294
+ # Load preprocessed tensor dataset
1295
+ training_section["load_dataset_btn"].click(
1296
+ fn=train_h.load_training_dataset,
1297
+ inputs=[training_section["training_tensor_dir"]],
1298
+ outputs=[training_section["training_dataset_info"]]
1299
+ )
1300
+
1301
+ # Start training from preprocessed tensors
1302
+ def training_wrapper(tensor_dir, r, a, d, lr, ep, bs, ga, se, sh, sd, od, ts):
1303
+ try:
1304
+ for progress, log, plot, state in train_h.start_training(
1305
+ tensor_dir, dit_handler, r, a, d, lr, ep, bs, ga, se, sh, sd, od, ts
1306
+ ):
1307
+ yield progress, log, plot, state
1308
+ except Exception as e:
1309
+ logger.exception("Training wrapper error")
1310
+ yield f"❌ Error: {str(e)}", str(e), None, ts
1311
+
1312
+ training_section["start_training_btn"].click(
1313
+ fn=training_wrapper,
1314
+ inputs=[
1315
+ training_section["training_tensor_dir"],
1316
+ training_section["lora_rank"],
1317
+ training_section["lora_alpha"],
1318
+ training_section["lora_dropout"],
1319
+ training_section["learning_rate"],
1320
+ training_section["train_epochs"],
1321
+ training_section["train_batch_size"],
1322
+ training_section["gradient_accumulation"],
1323
+ training_section["save_every_n_epochs"],
1324
+ training_section["training_shift"],
1325
+ training_section["training_seed"],
1326
+ training_section["lora_output_dir"],
1327
+ training_section["training_state"],
1328
+ ],
1329
+ outputs=[
1330
+ training_section["training_progress"],
1331
+ training_section["training_log"],
1332
+ training_section["training_loss_plot"],
1333
+ training_section["training_state"],
1334
+ ]
1335
+ )
1336
+
1337
+ # Stop training
1338
+ training_section["stop_training_btn"].click(
1339
+ fn=train_h.stop_training,
1340
+ inputs=[training_section["training_state"]],
1341
+ outputs=[
1342
+ training_section["training_progress"],
1343
+ training_section["training_state"],
1344
+ ]
1345
+ )
1346
+
1347
+ # Export LoRA
1348
+ training_section["export_lora_btn"].click(
1349
+ fn=train_h.export_lora,
1350
+ inputs=[
1351
+ training_section["export_path"],
1352
+ training_section["lora_output_dir"],
1353
+ ],
1354
+ outputs=[training_section["export_status"]]
1355
+ )
acestep/gradio_ui/events/generation_handlers.py ADDED
@@ -0,0 +1,1071 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generation Input Handlers Module
3
+ Contains event handlers and helper functions related to generation inputs
4
+ """
5
+ import os
6
+ import json
7
+ import random
8
+ import glob
9
+ import gradio as gr
10
+ from typing import Optional, List, Tuple
11
+ from loguru import logger
12
+ from acestep.constants import (
13
+ TASK_TYPES_TURBO,
14
+ TASK_TYPES_BASE,
15
+ )
16
+ from acestep.gradio_ui.i18n import t
17
+ from acestep.inference import understand_music, create_sample, format_sample
18
+
19
+
20
+ # HuggingFace Space environment detection for ZeroGPU support
21
+ IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
22
+
23
+
24
+ def _get_spaces_gpu_decorator(duration=120):
25
+ """
26
+ Get the @spaces.GPU decorator if running in HuggingFace Space environment.
27
+ Returns identity decorator if not in Space environment.
28
+ """
29
+ if IS_HUGGINGFACE_SPACE:
30
+ try:
31
+ import spaces
32
+ return spaces.GPU(duration=duration)
33
+ except ImportError:
34
+ logger.warning("spaces package not found, GPU decorator disabled")
35
+ return lambda func: func
36
+ return lambda func: func
37
+
38
+
39
+ def parse_and_validate_timesteps(
40
+ timesteps_str: str,
41
+ inference_steps: int
42
+ ) -> Tuple[Optional[List[float]], bool, str]:
43
+ """
44
+ Parse timesteps string and validate.
45
+
46
+ Args:
47
+ timesteps_str: Comma-separated timesteps string (e.g., "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
48
+ inference_steps: Expected number of inference steps
49
+
50
+ Returns:
51
+ Tuple of (parsed_timesteps, has_warning, warning_message)
52
+ - parsed_timesteps: List of float timesteps, or None if invalid/empty
53
+ - has_warning: Whether a warning was shown
54
+ - warning_message: Description of the warning
55
+ """
56
+ if not timesteps_str or not timesteps_str.strip():
57
+ return None, False, ""
58
+
59
+ # Parse comma-separated values
60
+ values = [v.strip() for v in timesteps_str.split(",") if v.strip()]
61
+
62
+ if not values:
63
+ return None, False, ""
64
+
65
+ # Handle optional trailing 0
66
+ if values[-1] != "0":
67
+ values.append("0")
68
+
69
+ try:
70
+ timesteps = [float(v) for v in values]
71
+ except ValueError:
72
+ gr.Warning(t("messages.invalid_timesteps_format"))
73
+ return None, True, "Invalid format"
74
+
75
+ # Validate range [0, 1]
76
+ if any(ts < 0 or ts > 1 for ts in timesteps):
77
+ gr.Warning(t("messages.timesteps_out_of_range"))
78
+ return None, True, "Out of range"
79
+
80
+ # Check if count matches inference_steps
81
+ actual_steps = len(timesteps) - 1
82
+ if actual_steps != inference_steps:
83
+ gr.Warning(t("messages.timesteps_count_mismatch", actual=actual_steps, expected=inference_steps))
84
+ return timesteps, True, f"Using {actual_steps} steps from timesteps"
85
+
86
+ return timesteps, False, ""
87
+
88
+
89
+ def load_metadata(file_obj):
90
+ """Load generation parameters from a JSON file"""
91
+ if file_obj is None:
92
+ gr.Warning(t("messages.no_file_selected"))
93
+ return [None] * 36 + [False] # Return None for all fields, False for is_format_caption
94
+
95
+ try:
96
+ # Read the uploaded file
97
+ if hasattr(file_obj, 'name'):
98
+ filepath = file_obj.name
99
+ else:
100
+ filepath = file_obj
101
+
102
+ with open(filepath, 'r', encoding='utf-8') as f:
103
+ metadata = json.load(f)
104
+
105
+ # Extract all fields
106
+ task_type = metadata.get('task_type', 'text2music')
107
+ captions = metadata.get('caption', '')
108
+ lyrics = metadata.get('lyrics', '')
109
+ vocal_language = metadata.get('vocal_language', 'unknown')
110
+
111
+ # Convert bpm
112
+ bpm_value = metadata.get('bpm')
113
+ if bpm_value is not None and bpm_value != "N/A":
114
+ try:
115
+ bpm = int(bpm_value) if bpm_value else None
116
+ except:
117
+ bpm = None
118
+ else:
119
+ bpm = None
120
+
121
+ key_scale = metadata.get('keyscale', '')
122
+ time_signature = metadata.get('timesignature', '')
123
+
124
+ # Convert duration
125
+ duration_value = metadata.get('duration', -1)
126
+ if duration_value is not None and duration_value != "N/A":
127
+ try:
128
+ audio_duration = float(duration_value)
129
+ except:
130
+ audio_duration = -1
131
+ else:
132
+ audio_duration = -1
133
+
134
+ batch_size = metadata.get('batch_size', 2)
135
+ inference_steps = metadata.get('inference_steps', 8)
136
+ guidance_scale = metadata.get('guidance_scale', 7.0)
137
+ seed = metadata.get('seed', '-1')
138
+ random_seed = False # Always set to False when loading to enable reproducibility with saved seed
139
+ use_adg = metadata.get('use_adg', False)
140
+ cfg_interval_start = metadata.get('cfg_interval_start', 0.0)
141
+ cfg_interval_end = metadata.get('cfg_interval_end', 1.0)
142
+ audio_format = metadata.get('audio_format', 'mp3')
143
+ lm_temperature = metadata.get('lm_temperature', 0.85)
144
+ lm_cfg_scale = metadata.get('lm_cfg_scale', 2.0)
145
+ lm_top_k = metadata.get('lm_top_k', 0)
146
+ lm_top_p = metadata.get('lm_top_p', 0.9)
147
+ lm_negative_prompt = metadata.get('lm_negative_prompt', 'NO USER INPUT')
148
+ use_cot_metas = metadata.get('use_cot_metas', True) # Added: read use_cot_metas
149
+ use_cot_caption = metadata.get('use_cot_caption', True)
150
+ use_cot_language = metadata.get('use_cot_language', True)
151
+ audio_cover_strength = metadata.get('audio_cover_strength', 1.0)
152
+ think = metadata.get('thinking', True) # Fixed: read 'thinking' not 'think'
153
+ audio_codes = metadata.get('audio_codes', '')
154
+ repainting_start = metadata.get('repainting_start', 0.0)
155
+ repainting_end = metadata.get('repainting_end', -1)
156
+ track_name = metadata.get('track_name')
157
+ complete_track_classes = metadata.get('complete_track_classes', [])
158
+ shift = metadata.get('shift', 3.0) # Default 3.0 for base models
159
+ infer_method = metadata.get('infer_method', 'ode') # Default 'ode' for diffusion inference
160
+ custom_timesteps = metadata.get('timesteps', '') # Custom timesteps (stored as 'timesteps' in JSON)
161
+ if custom_timesteps is None:
162
+ custom_timesteps = ''
163
+ instrumental = metadata.get('instrumental', False) # Added: read instrumental
164
+
165
+ gr.Info(t("messages.params_loaded", filename=os.path.basename(filepath)))
166
+
167
+ return (
168
+ task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature,
169
+ audio_duration, batch_size, inference_steps, guidance_scale, seed, random_seed,
170
+ use_adg, cfg_interval_start, cfg_interval_end, shift, infer_method,
171
+ custom_timesteps, # Added: custom_timesteps (between infer_method and audio_format)
172
+ audio_format, lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
173
+ use_cot_metas, use_cot_caption, use_cot_language, audio_cover_strength,
174
+ think, audio_codes, repainting_start, repainting_end,
175
+ track_name, complete_track_classes, instrumental,
176
+ True # Set is_format_caption to True when loading from file
177
+ )
178
+
179
+ except json.JSONDecodeError as e:
180
+ gr.Warning(t("messages.invalid_json", error=str(e)))
181
+ return [None] * 36 + [False]
182
+ except Exception as e:
183
+ gr.Warning(t("messages.load_error", error=str(e)))
184
+ return [None] * 36 + [False]
185
+
186
+
187
+ def load_random_example(task_type: str):
188
+ """Load a random example from the task-specific examples directory
189
+
190
+ Args:
191
+ task_type: The task type (e.g., "text2music")
192
+
193
+ Returns:
194
+ Tuple of (caption, lyrics, think, bpm, duration, keyscale, language, timesignature) for updating UI components
195
+ """
196
+ try:
197
+ # Get the project root directory
198
+ current_file = os.path.abspath(__file__)
199
+ # This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
200
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
201
+
202
+ # Construct the examples directory path
203
+ examples_dir = os.path.join(project_root, "examples", task_type)
204
+
205
+ # Check if directory exists
206
+ if not os.path.exists(examples_dir):
207
+ gr.Warning(f"Examples directory not found: examples/{task_type}/")
208
+ return "", "", True, None, None, "", "", ""
209
+
210
+ # Find all JSON files in the directory
211
+ json_files = glob.glob(os.path.join(examples_dir, "*.json"))
212
+
213
+ if not json_files:
214
+ gr.Warning(f"No JSON files found in examples/{task_type}/")
215
+ return "", "", True, None, None, "", "", ""
216
+
217
+ # Randomly select one file
218
+ selected_file = random.choice(json_files)
219
+
220
+ # Read and parse JSON
221
+ try:
222
+ with open(selected_file, 'r', encoding='utf-8') as f:
223
+ data = json.load(f)
224
+
225
+ # Extract caption (prefer 'caption', fallback to 'prompt')
226
+ caption_value = data.get('caption', data.get('prompt', ''))
227
+ if not isinstance(caption_value, str):
228
+ caption_value = str(caption_value) if caption_value else ''
229
+
230
+ # Extract lyrics
231
+ lyrics_value = data.get('lyrics', '')
232
+ if not isinstance(lyrics_value, str):
233
+ lyrics_value = str(lyrics_value) if lyrics_value else ''
234
+
235
+ # Extract think (default to True if not present)
236
+ think_value = data.get('think', True)
237
+ if not isinstance(think_value, bool):
238
+ think_value = True
239
+
240
+ # Extract optional metadata fields
241
+ bpm_value = None
242
+ if 'bpm' in data and data['bpm'] not in [None, "N/A", ""]:
243
+ try:
244
+ bpm_value = int(data['bpm'])
245
+ except (ValueError, TypeError):
246
+ pass
247
+
248
+ duration_value = None
249
+ if 'duration' in data and data['duration'] not in [None, "N/A", ""]:
250
+ try:
251
+ duration_value = float(data['duration'])
252
+ except (ValueError, TypeError):
253
+ pass
254
+
255
+ keyscale_value = data.get('keyscale', '')
256
+ if keyscale_value in [None, "N/A"]:
257
+ keyscale_value = ''
258
+
259
+ language_value = data.get('language', '')
260
+ if language_value in [None, "N/A"]:
261
+ language_value = ''
262
+
263
+ timesignature_value = data.get('timesignature', '')
264
+ if timesignature_value in [None, "N/A"]:
265
+ timesignature_value = ''
266
+
267
+ gr.Info(t("messages.example_loaded", filename=os.path.basename(selected_file)))
268
+ return caption_value, lyrics_value, think_value, bpm_value, duration_value, keyscale_value, language_value, timesignature_value
269
+
270
+ except json.JSONDecodeError as e:
271
+ gr.Warning(t("messages.example_failed", filename=os.path.basename(selected_file), error=str(e)))
272
+ return "", "", True, None, None, "", "", ""
273
+ except Exception as e:
274
+ gr.Warning(t("messages.example_error", error=str(e)))
275
+ return "", "", True, None, None, "", "", ""
276
+
277
+ except Exception as e:
278
+ gr.Warning(t("messages.example_error", error=str(e)))
279
+ return "", "", True, None, None, "", "", ""
280
+
281
+
282
+ def sample_example_smart(llm_handler, task_type: str, constrained_decoding_debug: bool = False):
283
+ """Smart sample function that uses LM if initialized, otherwise falls back to examples
284
+
285
+ This is a Gradio wrapper that uses the understand_music API from acestep.inference
286
+ to generate examples when LM is available.
287
+
288
+ Args:
289
+ llm_handler: LLM handler instance
290
+ task_type: The task type (e.g., "text2music")
291
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
292
+
293
+ Returns:
294
+ Tuple of (caption, lyrics, think, bpm, duration, keyscale, language, timesignature) for updating UI components
295
+ """
296
+ # Check if LM is initialized
297
+ if llm_handler.llm_initialized:
298
+ # Use LM to generate example via understand_music API
299
+ try:
300
+ result = understand_music(
301
+ llm_handler=llm_handler,
302
+ audio_codes="NO USER INPUT", # Empty input triggers example generation
303
+ temperature=0.85,
304
+ use_constrained_decoding=True,
305
+ constrained_decoding_debug=constrained_decoding_debug,
306
+ )
307
+
308
+ if result.success:
309
+ gr.Info(t("messages.lm_generated"))
310
+ return (
311
+ result.caption,
312
+ result.lyrics,
313
+ True, # Always enable think when using LM-generated examples
314
+ result.bpm,
315
+ result.duration,
316
+ result.keyscale,
317
+ result.language,
318
+ result.timesignature,
319
+ )
320
+ else:
321
+ gr.Warning(t("messages.lm_fallback"))
322
+ return load_random_example(task_type)
323
+
324
+ except Exception as e:
325
+ gr.Warning(t("messages.lm_fallback"))
326
+ return load_random_example(task_type)
327
+ else:
328
+ # LM not initialized, use examples directory
329
+ return load_random_example(task_type)
330
+
331
+
332
+ def load_random_simple_description():
333
+ """Load a random description from the simple_mode examples directory.
334
+
335
+ Returns:
336
+ Tuple of (description, instrumental, vocal_language) for updating UI components
337
+ """
338
+ try:
339
+ # Get the project root directory
340
+ current_file = os.path.abspath(__file__)
341
+ # This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
342
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
343
+
344
+ # Construct the examples directory path
345
+ examples_dir = os.path.join(project_root, "examples", "simple_mode")
346
+
347
+ # Check if directory exists
348
+ if not os.path.exists(examples_dir):
349
+ gr.Warning(t("messages.simple_examples_not_found"))
350
+ return gr.update(), gr.update(), gr.update()
351
+
352
+ # Find all JSON files in the directory
353
+ json_files = glob.glob(os.path.join(examples_dir, "*.json"))
354
+
355
+ if not json_files:
356
+ gr.Warning(t("messages.simple_examples_empty"))
357
+ return gr.update(), gr.update(), gr.update()
358
+
359
+ # Randomly select one file
360
+ selected_file = random.choice(json_files)
361
+
362
+ # Read and parse JSON
363
+ try:
364
+ with open(selected_file, 'r', encoding='utf-8') as f:
365
+ data = json.load(f)
366
+
367
+ # Extract fields
368
+ description = data.get('description', '')
369
+ instrumental = data.get('instrumental', False)
370
+ vocal_language = data.get('vocal_language', 'unknown')
371
+
372
+ # Ensure vocal_language is a string
373
+ if isinstance(vocal_language, list):
374
+ vocal_language = vocal_language[0] if vocal_language else 'unknown'
375
+
376
+ gr.Info(t("messages.simple_example_loaded", filename=os.path.basename(selected_file)))
377
+ return description, instrumental, vocal_language
378
+
379
+ except json.JSONDecodeError as e:
380
+ gr.Warning(t("messages.example_failed", filename=os.path.basename(selected_file), error=str(e)))
381
+ return gr.update(), gr.update(), gr.update()
382
+ except Exception as e:
383
+ gr.Warning(t("messages.example_error", error=str(e)))
384
+ return gr.update(), gr.update(), gr.update()
385
+
386
+ except Exception as e:
387
+ gr.Warning(t("messages.example_error", error=str(e)))
388
+ return gr.update(), gr.update(), gr.update()
389
+
390
+
391
+ def refresh_checkpoints(dit_handler):
392
+ """Refresh available checkpoints"""
393
+ choices = dit_handler.get_available_checkpoints()
394
+ return gr.update(choices=choices)
395
+
396
+
397
+ def update_model_type_settings(config_path):
398
+ """Update UI settings based on model type (fallback when handler not initialized yet)
399
+
400
+ Note: This is used as a fallback when the user changes config_path dropdown
401
+ before initializing the model. The actual settings are determined by the
402
+ handler's is_turbo_model() method after initialization.
403
+ """
404
+ if config_path is None:
405
+ config_path = ""
406
+ config_path_lower = config_path.lower()
407
+
408
+ # Determine is_turbo based on config_path string
409
+ # This is a heuristic fallback - actual model type is determined after loading
410
+ if "turbo" in config_path_lower:
411
+ is_turbo = True
412
+ elif "base" in config_path_lower:
413
+ is_turbo = False
414
+ else:
415
+ # Default to turbo settings for unknown model types
416
+ is_turbo = True
417
+
418
+ return get_model_type_ui_settings(is_turbo)
419
+
420
+
421
+ def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
422
+ """Wrapper for service initialization, returns status, button state, accordion state, and model type settings"""
423
+ # Initialize DiT handler
424
+ status, enable = dit_handler.initialize_service(
425
+ checkpoint, config_path, device,
426
+ use_flash_attention=use_flash_attention, compile_model=False,
427
+ offload_to_cpu=offload_to_cpu, offload_dit_to_cpu=offload_dit_to_cpu
428
+ )
429
+
430
+ # Initialize LM handler if requested
431
+ if init_llm:
432
+ # Get checkpoint directory
433
+ current_file = os.path.abspath(__file__)
434
+ # This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
435
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
436
+ checkpoint_dir = os.path.join(project_root, "checkpoints")
437
+
438
+ lm_status, lm_success = llm_handler.initialize(
439
+ checkpoint_dir=checkpoint_dir,
440
+ lm_model_path=lm_model_path,
441
+ backend=backend,
442
+ device=device,
443
+ offload_to_cpu=offload_to_cpu,
444
+ dtype=dit_handler.dtype
445
+ )
446
+
447
+ if lm_success:
448
+ status += f"\n{lm_status}"
449
+ else:
450
+ status += f"\n{lm_status}"
451
+ # Don't fail the entire initialization if LM fails, but log it
452
+ # Keep enable as is (DiT initialization result) even if LM fails
453
+
454
+ # Check if model is initialized - if so, collapse the accordion
455
+ is_model_initialized = dit_handler.model is not None
456
+ accordion_state = gr.Accordion(open=not is_model_initialized)
457
+
458
+ # Get model type settings based on actual loaded model
459
+ is_turbo = dit_handler.is_turbo_model()
460
+ model_type_settings = get_model_type_ui_settings(is_turbo)
461
+
462
+ return (
463
+ status,
464
+ gr.update(interactive=enable),
465
+ accordion_state,
466
+ *model_type_settings
467
+ )
468
+
469
+
470
+ def get_model_type_ui_settings(is_turbo: bool):
471
+ """Get UI settings based on whether the model is turbo or base"""
472
+ if is_turbo:
473
+ # Turbo model: max 20 steps, default 8, show shift with default 3.0, only show text2music/repaint/cover
474
+ return (
475
+ gr.update(value=8, maximum=20, minimum=1), # inference_steps
476
+ gr.update(visible=False), # guidance_scale
477
+ gr.update(visible=False), # use_adg
478
+ gr.update(value=3.0, visible=True), # shift (show with default 3.0)
479
+ gr.update(visible=False), # cfg_interval_start
480
+ gr.update(visible=False), # cfg_interval_end
481
+ gr.update(choices=TASK_TYPES_TURBO), # task_type
482
+ )
483
+ else:
484
+ # Base model: max 200 steps, default 32, show CFG/ADG/shift, show all task types
485
+ return (
486
+ gr.update(value=32, maximum=200, minimum=1), # inference_steps
487
+ gr.update(visible=True), # guidance_scale
488
+ gr.update(visible=True), # use_adg
489
+ gr.update(value=3.0, visible=True), # shift (effective for base, default 3.0)
490
+ gr.update(visible=True), # cfg_interval_start
491
+ gr.update(visible=True), # cfg_interval_end
492
+ gr.update(choices=TASK_TYPES_BASE), # task_type
493
+ )
494
+
495
+
496
+ def update_negative_prompt_visibility(init_llm_checked):
497
+ """Update negative prompt visibility: show if Initialize 5Hz LM checkbox is checked"""
498
+ return gr.update(visible=init_llm_checked)
499
+
500
+
501
+ def update_audio_cover_strength_visibility(task_type_value, init_llm_checked):
502
+ """Update audio_cover_strength visibility and label"""
503
+ # Show if task is cover OR if LM is initialized (but NOT for repaint mode)
504
+ # Repaint mode never shows this control
505
+ is_repaint = task_type_value == "repaint"
506
+ is_cover = task_type_value == "cover"
507
+ is_visible = is_cover or (init_llm_checked and not is_repaint)
508
+
509
+ # Change label based on context
510
+ if init_llm_checked and not is_cover:
511
+ label = "LM codes strength"
512
+ info = "Control how many denoising steps use LM-generated codes"
513
+ else:
514
+ label = "Audio Cover Strength"
515
+ info = "Control how many denoising steps use cover mode"
516
+
517
+ return gr.update(visible=is_visible, label=label, info=info)
518
+
519
+
520
+ def convert_src_audio_to_codes_wrapper(dit_handler, src_audio):
521
+ """Wrapper for converting src audio to codes"""
522
+ codes_string = dit_handler.convert_src_audio_to_codes(src_audio)
523
+ return codes_string
524
+
525
+
526
+ def update_instruction_ui(
527
+ dit_handler,
528
+ task_type_value: str,
529
+ track_name_value: Optional[str],
530
+ complete_track_classes_value: list,
531
+ audio_codes_content: str = "",
532
+ init_llm_checked: bool = False
533
+ ) -> tuple:
534
+ """Update instruction and UI visibility based on task type."""
535
+ instruction = dit_handler.generate_instruction(
536
+ task_type=task_type_value,
537
+ track_name=track_name_value,
538
+ complete_track_classes=complete_track_classes_value
539
+ )
540
+
541
+ # Show track_name for lego and extract
542
+ track_name_visible = task_type_value in ["lego", "extract"]
543
+ # Show complete_track_classes for complete
544
+ complete_visible = task_type_value == "complete"
545
+ # Show audio_cover_strength for cover OR when LM is initialized (but NOT for repaint)
546
+ is_repaint = task_type_value == "repaint"
547
+ is_cover = task_type_value == "cover"
548
+ audio_cover_strength_visible = is_cover or (init_llm_checked and not is_repaint)
549
+ # Determine label and info based on context
550
+ if init_llm_checked and not is_cover:
551
+ audio_cover_strength_label = "LM codes strength"
552
+ audio_cover_strength_info = "Control how many denoising steps use LM-generated codes"
553
+ else:
554
+ audio_cover_strength_label = "Audio Cover Strength"
555
+ audio_cover_strength_info = "Control how many denoising steps use cover mode"
556
+ # Show repainting controls for repaint and lego
557
+ repainting_visible = task_type_value in ["repaint", "lego"]
558
+ # Show text2music_audio_codes if task is text2music OR if it has content
559
+ # This allows it to stay visible even if user switches task type but has codes
560
+ has_audio_codes = audio_codes_content and str(audio_codes_content).strip()
561
+ text2music_audio_codes_visible = task_type_value == "text2music" or has_audio_codes
562
+
563
+ return (
564
+ instruction, # instruction_display_gen
565
+ gr.update(visible=track_name_visible), # track_name
566
+ gr.update(visible=complete_visible), # complete_track_classes
567
+ gr.update(visible=audio_cover_strength_visible, label=audio_cover_strength_label, info=audio_cover_strength_info), # audio_cover_strength
568
+ gr.update(visible=repainting_visible), # repainting_group
569
+ gr.update(visible=text2music_audio_codes_visible), # text2music_audio_codes_group
570
+ )
571
+
572
+
573
+ def transcribe_audio_codes(llm_handler, audio_code_string, constrained_decoding_debug):
574
+ """
575
+ Transcribe audio codes to metadata using LLM understanding.
576
+ If audio_code_string is empty, generate a sample example instead.
577
+
578
+ This is a Gradio wrapper around the understand_music API in acestep.inference.
579
+
580
+ Args:
581
+ llm_handler: LLM handler instance
582
+ audio_code_string: String containing audio codes (or empty for example generation)
583
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
584
+
585
+ Returns:
586
+ Tuple of (status_message, caption, lyrics, bpm, duration, keyscale, language, timesignature, is_format_caption)
587
+ """
588
+ # Call the inference API
589
+ result = understand_music(
590
+ llm_handler=llm_handler,
591
+ audio_codes=audio_code_string,
592
+ use_constrained_decoding=True,
593
+ constrained_decoding_debug=constrained_decoding_debug,
594
+ )
595
+
596
+ # Handle error case with localized message
597
+ if not result.success:
598
+ # Use localized error message for LLM not initialized
599
+ if result.error == "LLM not initialized":
600
+ return t("messages.lm_not_initialized"), "", "", None, None, "", "", "", False
601
+ return result.status_message, "", "", None, None, "", "", "", False
602
+
603
+ return (
604
+ result.status_message,
605
+ result.caption,
606
+ result.lyrics,
607
+ result.bpm,
608
+ result.duration,
609
+ result.keyscale,
610
+ result.language,
611
+ result.timesignature,
612
+ True # Set is_format_caption to True (from Transcribe/LM understanding)
613
+ )
614
+
615
+
616
+ def update_transcribe_button_text(audio_code_string):
617
+ """
618
+ Update the transcribe button text based on input content.
619
+ If empty: "Generate Example"
620
+ If has content: "Transcribe"
621
+ """
622
+ if not audio_code_string or not audio_code_string.strip():
623
+ return gr.update(value="Generate Example")
624
+ else:
625
+ return gr.update(value="Transcribe")
626
+
627
+
628
+ def reset_format_caption_flag():
629
+ """Reset is_format_caption to False when user manually edits caption/metadata"""
630
+ return False
631
+
632
+
633
+ def update_audio_uploads_accordion(reference_audio, src_audio):
634
+ """Update Audio Uploads visibility based on whether audio files are present"""
635
+ has_audio = (reference_audio is not None) or (src_audio is not None)
636
+ return gr.update(visible=has_audio)
637
+
638
+
639
+ def handle_instrumental_checkbox(instrumental_checked, current_lyrics):
640
+ """
641
+ Handle instrumental checkbox changes.
642
+ When checked: if no lyrics, fill with [Instrumental]
643
+ When unchecked: if lyrics is [Instrumental], clear it
644
+ """
645
+ if instrumental_checked:
646
+ # If checked and no lyrics, fill with [Instrumental]
647
+ if not current_lyrics or not current_lyrics.strip():
648
+ return "[Instrumental]"
649
+ else:
650
+ # Has lyrics, don't change
651
+ return current_lyrics
652
+ else:
653
+ # If unchecked and lyrics is exactly [Instrumental], clear it
654
+ if current_lyrics and current_lyrics.strip() == "[Instrumental]":
655
+ return ""
656
+ else:
657
+ # Has other lyrics, don't change
658
+ return current_lyrics
659
+
660
+
661
+ def handle_simple_instrumental_change(is_instrumental: bool):
662
+ """
663
+ Handle simple mode instrumental checkbox changes.
664
+ When checked: set vocal_language to "unknown" and disable editing.
665
+ When unchecked: enable vocal_language editing.
666
+
667
+ Args:
668
+ is_instrumental: Whether instrumental checkbox is checked
669
+
670
+ Returns:
671
+ gr.update for simple_vocal_language dropdown
672
+ """
673
+ if is_instrumental:
674
+ return gr.update(value="unknown", interactive=False)
675
+ else:
676
+ return gr.update(interactive=True)
677
+
678
+
679
+ def update_audio_components_visibility(batch_size):
680
+ """Show/hide individual audio components based on batch size (1-8)
681
+
682
+ Row 1: Components 1-4 (batch_size 1-4)
683
+ Row 2: Components 5-8 (batch_size 5-8)
684
+ """
685
+ # Clamp batch size to 1-8 range for UI
686
+ batch_size = min(max(int(batch_size), 1), 8)
687
+
688
+ # Row 1 columns (1-4)
689
+ updates_row1 = (
690
+ gr.update(visible=True), # audio_col_1: always visible
691
+ gr.update(visible=batch_size >= 2), # audio_col_2
692
+ gr.update(visible=batch_size >= 3), # audio_col_3
693
+ gr.update(visible=batch_size >= 4), # audio_col_4
694
+ )
695
+
696
+ # Row 2 container and columns (5-8)
697
+ show_row_5_8 = batch_size >= 5
698
+ updates_row2 = (
699
+ gr.update(visible=show_row_5_8), # audio_row_5_8 (container)
700
+ gr.update(visible=batch_size >= 5), # audio_col_5
701
+ gr.update(visible=batch_size >= 6), # audio_col_6
702
+ gr.update(visible=batch_size >= 7), # audio_col_7
703
+ gr.update(visible=batch_size >= 8), # audio_col_8
704
+ )
705
+
706
+ return updates_row1 + updates_row2
707
+
708
+
709
+ def handle_generation_mode_change(mode: str):
710
+ """
711
+ Handle generation mode change between Simple, Custom, Cover, and Repaint modes.
712
+
713
+ Modes:
714
+ - Simple: Show simple mode group, hide others
715
+ - Custom: Show custom content (prompt), hide others
716
+ - Cover: Show src_audio_group + custom content + LM codes strength
717
+ - Repaint: Show src_audio_group + custom content + repaint time controls (hide LM codes strength)
718
+
719
+ Args:
720
+ mode: "simple", "custom", "cover", or "repaint"
721
+
722
+ Returns:
723
+ Tuple of updates for:
724
+ - simple_mode_group (visibility)
725
+ - custom_mode_content (visibility)
726
+ - cover_mode_group (visibility) - legacy, always hidden
727
+ - repainting_group (visibility)
728
+ - task_type (value)
729
+ - generate_btn (interactive state)
730
+ - simple_sample_created (reset state)
731
+ - src_audio_group (visibility) - shown for cover and repaint
732
+ - audio_cover_strength (visibility) - shown only for cover mode
733
+ - think_checkbox (value and interactive) - disabled for cover/repaint modes
734
+ """
735
+ is_simple = mode == "simple"
736
+ is_custom = mode == "custom"
737
+ is_cover = mode == "cover"
738
+ is_repaint = mode == "repaint"
739
+
740
+ # Map mode to task_type
741
+ task_type_map = {
742
+ "simple": "text2music",
743
+ "custom": "text2music",
744
+ "cover": "cover",
745
+ "repaint": "repaint",
746
+ }
747
+ task_type_value = task_type_map.get(mode, "text2music")
748
+
749
+ # think_checkbox: disabled and set to False for cover/repaint modes
750
+ # (these modes don't use LM thinking, they use source audio codes)
751
+ if is_cover or is_repaint:
752
+ think_checkbox_update = gr.update(value=False, interactive=False)
753
+ else:
754
+ think_checkbox_update = gr.update(value=True, interactive=True)
755
+
756
+ return (
757
+ gr.update(visible=is_simple), # simple_mode_group
758
+ gr.update(visible=not is_simple), # custom_mode_content - visible for custom/cover/repaint
759
+ gr.update(visible=False), # cover_mode_group - legacy, always hidden
760
+ gr.update(visible=is_repaint), # repainting_group - time range controls
761
+ gr.update(value=task_type_value), # task_type
762
+ gr.update(interactive=True), # generate_btn - always enabled (Simple mode does create+generate in one step)
763
+ False, # simple_sample_created - reset to False on mode change
764
+ gr.update(visible=is_cover or is_repaint), # src_audio_group - shown for cover and repaint
765
+ gr.update(visible=is_cover), # audio_cover_strength - only shown for cover mode
766
+ think_checkbox_update, # think_checkbox - disabled for cover/repaint modes
767
+ )
768
+
769
+ def process_source_audio(dit_handler, llm_handler, src_audio, constrained_decoding_debug):
770
+ """
771
+ Process source audio: convert to codes and then transcribe.
772
+ This combines convert_src_audio_to_codes_wrapper + transcribe_audio_codes.
773
+
774
+ Args:
775
+ dit_handler: DiT handler instance
776
+ llm_handler: LLM handler instance
777
+ src_audio: Path to source audio file
778
+ constrained_decoding_debug: Whether to enable debug logging
779
+
780
+ Returns:
781
+ Tuple of (audio_codes, status_message, caption, lyrics, bpm, duration, keyscale, language, timesignature, is_format_caption)
782
+ """
783
+ if src_audio is None:
784
+ return ("", "No audio file provided", "", "", None, None, "", "", "", False)
785
+
786
+ # Step 1: Convert audio to codes
787
+ try:
788
+ codes_string = dit_handler.convert_src_audio_to_codes(src_audio)
789
+ if not codes_string:
790
+ return ("", "Failed to convert audio to codes", "", "", None, None, "", "", "", False)
791
+ except Exception as e:
792
+ return ("", f"Error converting audio: {str(e)}", "", "", None, None, "", "", "", False)
793
+
794
+ # Step 2: Transcribe the codes
795
+ result = understand_music(
796
+ llm_handler=llm_handler,
797
+ audio_codes=codes_string,
798
+ use_constrained_decoding=True,
799
+ constrained_decoding_debug=constrained_decoding_debug,
800
+ )
801
+
802
+ # Handle error case
803
+ if not result.success:
804
+ if result.error == "LLM not initialized":
805
+ return (codes_string, t("messages.lm_not_initialized"), "", "", None, None, "", "", "", False)
806
+ return (codes_string, result.status_message, "", "", None, None, "", "", "", False)
807
+
808
+ return (
809
+ codes_string,
810
+ result.status_message,
811
+ result.caption,
812
+ result.lyrics,
813
+ result.bpm,
814
+ result.duration,
815
+ result.keyscale,
816
+ result.language,
817
+ result.timesignature,
818
+ True # Set is_format_caption to True
819
+ )
820
+
821
+ def handle_create_sample(
822
+ llm_handler,
823
+ query: str,
824
+ instrumental: bool,
825
+ vocal_language: str,
826
+ lm_temperature: float,
827
+ lm_top_k: int,
828
+ lm_top_p: float,
829
+ constrained_decoding_debug: bool = False,
830
+ ):
831
+ """
832
+ Handle the Create Sample button click in Simple mode.
833
+
834
+ Creates a sample from the user's query using the LLM, then populates
835
+ the caption, lyrics, and metadata fields.
836
+
837
+ Note: cfg_scale and negative_prompt are not supported in create_sample mode.
838
+
839
+ Args:
840
+ llm_handler: LLM handler instance (unused, fetched from registry)
841
+ query: User's natural language music description
842
+ instrumental: Whether to generate instrumental music
843
+ vocal_language: Preferred vocal language for constrained decoding
844
+ lm_temperature: LLM temperature for generation
845
+ lm_top_k: LLM top-k sampling
846
+ lm_top_p: LLM top-p sampling
847
+ constrained_decoding_debug: Whether to enable debug logging
848
+
849
+ Returns:
850
+ Tuple of updates for:
851
+ - captions
852
+ - lyrics
853
+ - bpm
854
+ - audio_duration
855
+ - key_scale
856
+ - vocal_language
857
+ - time_signature
858
+ - instrumental_checkbox
859
+ - caption_accordion (open)
860
+ - lyrics_accordion (open)
861
+ - generate_btn (interactive)
862
+ - simple_sample_created (True)
863
+ - think_checkbox (True)
864
+ - is_format_caption_state (True)
865
+ - status_output
866
+ """
867
+ # Check if LLM is initialized
868
+ if not llm_handler.llm_initialized:
869
+ gr.Warning(t("messages.lm_not_initialized"))
870
+ return (
871
+ gr.update(), # captions - no change
872
+ gr.update(), # lyrics - no change
873
+ gr.update(), # bpm - no change
874
+ gr.update(), # audio_duration - no change
875
+ gr.update(), # key_scale - no change
876
+ gr.update(), # vocal_language - no change
877
+ gr.update(), # time_signature - no change
878
+ gr.update(), # instrumental_checkbox - no change
879
+ gr.update(), # caption_accordion - no change
880
+ gr.update(), # lyrics_accordion - no change
881
+ gr.update(interactive=False), # generate_btn - keep disabled
882
+ False, # simple_sample_created - still False
883
+ gr.update(), # think_checkbox - no change
884
+ gr.update(), # is_format_caption_state - no change
885
+ t("messages.lm_not_initialized"), # status_output
886
+ )
887
+
888
+ # Convert LM parameters
889
+ top_k_value = None if not lm_top_k or lm_top_k == 0 else int(lm_top_k)
890
+ top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p
891
+
892
+ # Call create_sample API
893
+ # Note: cfg_scale and negative_prompt are not supported in create_sample mode
894
+ result = create_sample(
895
+ llm_handler=llm_handler,
896
+ query=query,
897
+ instrumental=instrumental,
898
+ vocal_language=vocal_language,
899
+ temperature=lm_temperature,
900
+ top_k=top_k_value,
901
+ top_p=top_p_value,
902
+ use_constrained_decoding=True,
903
+ constrained_decoding_debug=constrained_decoding_debug,
904
+ )
905
+
906
+ # Handle error
907
+ if not result.success:
908
+ gr.Warning(result.status_message or t("messages.sample_creation_failed"))
909
+ return (
910
+ gr.update(), # captions - no change
911
+ gr.update(), # lyrics - no change
912
+ gr.update(), # bpm - no change
913
+ gr.update(), # audio_duration - no change
914
+ gr.update(), # key_scale - no change
915
+ gr.update(), # vocal_language - no change
916
+ gr.update(), # simple vocal_language - no change
917
+ gr.update(), # time_signature - no change
918
+ gr.update(), # instrumental_checkbox - no change
919
+ gr.update(), # caption_accordion - no change
920
+ gr.update(), # lyrics_accordion - no change
921
+ gr.update(interactive=False), # generate_btn - keep disabled
922
+ False, # simple_sample_created - still False
923
+ gr.update(), # think_checkbox - no change
924
+ gr.update(), # is_format_caption_state - no change
925
+ result.status_message or t("messages.sample_creation_failed"), # status_output
926
+ )
927
+
928
+ # Success - populate fields
929
+ gr.Info(t("messages.sample_created"))
930
+
931
+ return (
932
+ result.caption, # captions
933
+ result.lyrics, # lyrics
934
+ result.bpm, # bpm
935
+ result.duration if result.duration and result.duration > 0 else -1, # audio_duration
936
+ result.keyscale, # key_scale
937
+ result.language, # vocal_language
938
+ result.language, # simple vocal_language
939
+ result.timesignature, # time_signature
940
+ result.instrumental, # instrumental_checkbox
941
+ gr.Accordion(open=True), # caption_accordion - expand
942
+ gr.Accordion(open=True), # lyrics_accordion - expand
943
+ gr.update(interactive=True), # generate_btn - enable
944
+ True, # simple_sample_created - True
945
+ True, # think_checkbox - enable thinking
946
+ True, # is_format_caption_state - True (LM-generated)
947
+ result.status_message, # status_output
948
+ )
949
+
950
+ def handle_format_sample(
951
+ llm_handler,
952
+ caption: str,
953
+ lyrics: str,
954
+ bpm,
955
+ audio_duration,
956
+ key_scale: str,
957
+ time_signature: str,
958
+ lm_temperature: float,
959
+ lm_top_k: int,
960
+ lm_top_p: float,
961
+ constrained_decoding_debug: bool = False,
962
+ ):
963
+ """
964
+ Handle the Format button click to format caption and lyrics.
965
+
966
+ Takes user-provided caption and lyrics, and uses the LLM to generate
967
+ structured music metadata and an enhanced description.
968
+
969
+ Note: cfg_scale and negative_prompt are not supported in format mode.
970
+
971
+ Args:
972
+ llm_handler: LLM handler instance (unused, fetched from registry)
973
+ caption: User's caption/description
974
+ lyrics: User's lyrics
975
+ bpm: User-provided BPM (optional, for constrained decoding)
976
+ audio_duration: User-provided duration (optional, for constrained decoding)
977
+ key_scale: User-provided key scale (optional, for constrained decoding)
978
+ time_signature: User-provided time signature (optional, for constrained decoding)
979
+ lm_temperature: LLM temperature for generation
980
+ lm_top_k: LLM top-k sampling
981
+ lm_top_p: LLM top-p sampling
982
+ constrained_decoding_debug: Whether to enable debug logging
983
+
984
+ Returns:
985
+ Tuple of updates for:
986
+ - captions
987
+ - lyrics
988
+ - bpm
989
+ - audio_duration
990
+ - key_scale
991
+ - vocal_language
992
+ - time_signature
993
+ - is_format_caption_state
994
+ - status_output
995
+ """
996
+ # Check if LLM is initialized
997
+ if not llm_handler.llm_initialized:
998
+ gr.Warning(t("messages.lm_not_initialized"))
999
+ return (
1000
+ gr.update(), # captions - no change
1001
+ gr.update(), # lyrics - no change
1002
+ gr.update(), # bpm - no change
1003
+ gr.update(), # audio_duration - no change
1004
+ gr.update(), # key_scale - no change
1005
+ gr.update(), # vocal_language - no change
1006
+ gr.update(), # time_signature - no change
1007
+ gr.update(), # is_format_caption_state - no change
1008
+ t("messages.lm_not_initialized"), # status_output
1009
+ )
1010
+
1011
+ # Build user_metadata from provided values for constrained decoding
1012
+ user_metadata = {}
1013
+ if bpm is not None and bpm > 0:
1014
+ user_metadata['bpm'] = int(bpm)
1015
+ if audio_duration is not None and audio_duration > 0:
1016
+ user_metadata['duration'] = int(audio_duration)
1017
+ if key_scale and key_scale.strip():
1018
+ user_metadata['keyscale'] = key_scale.strip()
1019
+ if time_signature and time_signature.strip():
1020
+ user_metadata['timesignature'] = time_signature.strip()
1021
+
1022
+ # Only pass user_metadata if we have at least one field
1023
+ user_metadata_to_pass = user_metadata if user_metadata else None
1024
+
1025
+ # Convert LM parameters
1026
+ top_k_value = None if not lm_top_k or lm_top_k == 0 else int(lm_top_k)
1027
+ top_p_value = None if not lm_top_p or lm_top_p >= 1.0 else lm_top_p
1028
+
1029
+ # Call format_sample API
1030
+ result = format_sample(
1031
+ llm_handler=llm_handler,
1032
+ caption=caption,
1033
+ lyrics=lyrics,
1034
+ user_metadata=user_metadata_to_pass,
1035
+ temperature=lm_temperature,
1036
+ top_k=top_k_value,
1037
+ top_p=top_p_value,
1038
+ use_constrained_decoding=True,
1039
+ constrained_decoding_debug=constrained_decoding_debug,
1040
+ )
1041
+
1042
+ # Handle error
1043
+ if not result.success:
1044
+ gr.Warning(result.status_message or t("messages.format_failed"))
1045
+ return (
1046
+ gr.update(), # captions - no change
1047
+ gr.update(), # lyrics - no change
1048
+ gr.update(), # bpm - no change
1049
+ gr.update(), # audio_duration - no change
1050
+ gr.update(), # key_scale - no change
1051
+ gr.update(), # vocal_language - no change
1052
+ gr.update(), # time_signature - no change
1053
+ gr.update(), # is_format_caption_state - no change
1054
+ result.status_message or t("messages.format_failed"), # status_output
1055
+ )
1056
+
1057
+ # Success - populate fields
1058
+ gr.Info(t("messages.format_success"))
1059
+
1060
+ return (
1061
+ result.caption, # captions
1062
+ result.lyrics, # lyrics
1063
+ result.bpm, # bpm
1064
+ result.duration if result.duration and result.duration > 0 else -1, # audio_duration
1065
+ result.keyscale, # key_scale
1066
+ result.language, # vocal_language
1067
+ result.timesignature, # time_signature
1068
+ True, # is_format_caption_state - True (LM-formatted)
1069
+ result.status_message, # status_output
1070
+ )
1071
+
acestep/gradio_ui/events/results_handlers.py ADDED
The diff for this file is too large to render. See raw diff
 
acestep/gradio_ui/events/training_handlers.py ADDED
@@ -0,0 +1,644 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Event Handlers for Training Tab
3
+
4
+ Contains all event handler functions for the dataset builder and training UI.
5
+ """
6
+
7
+ import os
8
+ import json
9
+ from typing import Any, Dict, List, Tuple, Optional
10
+ from loguru import logger
11
+ import gradio as gr
12
+
13
+ from acestep.training.dataset_builder import DatasetBuilder, AudioSample
14
+
15
+
16
+ def create_dataset_builder() -> DatasetBuilder:
17
+ """Create a new DatasetBuilder instance."""
18
+ return DatasetBuilder()
19
+
20
+
21
+ def scan_directory(
22
+ audio_dir: str,
23
+ dataset_name: str,
24
+ custom_tag: str,
25
+ tag_position: str,
26
+ all_instrumental: bool,
27
+ builder_state: Optional[DatasetBuilder],
28
+ ) -> Tuple[Any, str, Any, DatasetBuilder]:
29
+ """Scan a directory for audio files.
30
+
31
+ Returns:
32
+ Tuple of (table_data, status, slider_update, builder_state)
33
+ """
34
+ if not audio_dir or not audio_dir.strip():
35
+ return [], "❌ Please enter a directory path", gr.Slider(maximum=0, value=0), builder_state
36
+
37
+ # Create or use existing builder
38
+ builder = builder_state if builder_state else DatasetBuilder()
39
+
40
+ # Set metadata before scanning
41
+ builder.metadata.name = dataset_name
42
+ builder.metadata.custom_tag = custom_tag
43
+ builder.metadata.tag_position = tag_position
44
+ builder.metadata.all_instrumental = all_instrumental
45
+
46
+ # Scan directory
47
+ samples, status = builder.scan_directory(audio_dir.strip())
48
+
49
+ if not samples:
50
+ return [], status, gr.Slider(maximum=0, value=0), builder
51
+
52
+ # Set instrumental and tag for all samples
53
+ builder.set_all_instrumental(all_instrumental)
54
+ if custom_tag:
55
+ builder.set_custom_tag(custom_tag, tag_position)
56
+
57
+ # Get table data
58
+ table_data = builder.get_samples_dataframe_data()
59
+
60
+ # Calculate slider max and return as Slider update
61
+ slider_max = max(0, len(samples) - 1)
62
+
63
+ return table_data, status, gr.Slider(maximum=slider_max, value=0), builder
64
+
65
+
66
+ def auto_label_all(
67
+ dit_handler,
68
+ llm_handler,
69
+ builder_state: Optional[DatasetBuilder],
70
+ skip_metas: bool = False,
71
+ progress=None,
72
+ ) -> Tuple[List[List[Any]], str, DatasetBuilder]:
73
+ """Auto-label all samples in the dataset.
74
+
75
+ Args:
76
+ dit_handler: DiT handler for audio processing
77
+ llm_handler: LLM handler for caption generation
78
+ builder_state: Dataset builder state
79
+ skip_metas: If True, skip LLM labeling. BPM/Key/TimeSig = N/A, Language = unknown for instrumental
80
+ progress: Progress callback
81
+
82
+ Returns:
83
+ Tuple of (table_data, status, builder_state)
84
+ """
85
+ if builder_state is None:
86
+ return [], "❌ Please scan a directory first", builder_state
87
+
88
+ if not builder_state.samples:
89
+ return [], "❌ No samples to label. Please scan a directory first.", builder_state
90
+
91
+ # If skip_metas is True, just set default values without LLM
92
+ if skip_metas:
93
+ for sample in builder_state.samples:
94
+ sample.bpm = None # Will display as N/A
95
+ sample.keyscale = "N/A"
96
+ sample.timesignature = "N/A"
97
+ # For instrumental, language should be "unknown"
98
+ if sample.is_instrumental:
99
+ sample.language = "unknown"
100
+ else:
101
+ sample.language = "unknown"
102
+ # Use custom tag as caption if set, otherwise use filename
103
+ if builder_state.metadata.custom_tag:
104
+ sample.caption = builder_state.metadata.custom_tag
105
+ else:
106
+ sample.caption = sample.filename
107
+
108
+ table_data = builder_state.get_samples_dataframe_data()
109
+ return table_data, f"✅ Skipped AI labeling. {len(builder_state.samples)} samples set with default values.", builder_state
110
+
111
+ # Check if handlers are initialized
112
+ if dit_handler is None or dit_handler.model is None:
113
+ return builder_state.get_samples_dataframe_data(), "❌ Model not initialized. Please initialize the service first.", builder_state
114
+
115
+ if llm_handler is None or not llm_handler.llm_initialized:
116
+ return builder_state.get_samples_dataframe_data(), "❌ LLM not initialized. Please initialize the service with LLM enabled.", builder_state
117
+
118
+ def progress_callback(msg):
119
+ if progress:
120
+ try:
121
+ progress(msg)
122
+ except:
123
+ pass
124
+
125
+ # Label all samples
126
+ samples, status = builder_state.label_all_samples(
127
+ dit_handler=dit_handler,
128
+ llm_handler=llm_handler,
129
+ progress_callback=progress_callback,
130
+ )
131
+
132
+ # Get updated table data
133
+ table_data = builder_state.get_samples_dataframe_data()
134
+
135
+ return table_data, status, builder_state
136
+
137
+
138
+ def get_sample_preview(
139
+ sample_idx: int,
140
+ builder_state: Optional[DatasetBuilder],
141
+ ) -> Tuple[str, str, str, str, Optional[int], str, str, float, str, bool]:
142
+ """Get preview data for a specific sample.
143
+
144
+ Returns:
145
+ Tuple of (audio_path, filename, caption, lyrics, bpm, keyscale, timesig, duration, language, instrumental)
146
+ """
147
+ if builder_state is None or not builder_state.samples:
148
+ return None, "", "", "", None, "", "", 0.0, "instrumental", True
149
+
150
+ idx = int(sample_idx)
151
+ if idx < 0 or idx >= len(builder_state.samples):
152
+ return None, "", "", "", None, "", "", 0.0, "instrumental", True
153
+
154
+ sample = builder_state.samples[idx]
155
+
156
+ return (
157
+ sample.audio_path,
158
+ sample.filename,
159
+ sample.caption,
160
+ sample.lyrics,
161
+ sample.bpm,
162
+ sample.keyscale,
163
+ sample.timesignature,
164
+ sample.duration,
165
+ sample.language,
166
+ sample.is_instrumental,
167
+ )
168
+
169
+
170
+ def save_sample_edit(
171
+ sample_idx: int,
172
+ caption: str,
173
+ lyrics: str,
174
+ bpm: Optional[int],
175
+ keyscale: str,
176
+ timesig: str,
177
+ language: str,
178
+ is_instrumental: bool,
179
+ builder_state: Optional[DatasetBuilder],
180
+ ) -> Tuple[List[List[Any]], str, DatasetBuilder]:
181
+ """Save edits to a sample.
182
+
183
+ Returns:
184
+ Tuple of (table_data, status, builder_state)
185
+ """
186
+ if builder_state is None:
187
+ return [], "❌ No dataset loaded", builder_state
188
+
189
+ idx = int(sample_idx)
190
+
191
+ # Update sample
192
+ sample, status = builder_state.update_sample(
193
+ idx,
194
+ caption=caption,
195
+ lyrics=lyrics if not is_instrumental else "[Instrumental]",
196
+ bpm=int(bpm) if bpm else None,
197
+ keyscale=keyscale,
198
+ timesignature=timesig,
199
+ language="instrumental" if is_instrumental else language,
200
+ is_instrumental=is_instrumental,
201
+ labeled=True,
202
+ )
203
+
204
+ # Get updated table data
205
+ table_data = builder_state.get_samples_dataframe_data()
206
+
207
+ return table_data, status, builder_state
208
+
209
+
210
+ def update_settings(
211
+ custom_tag: str,
212
+ tag_position: str,
213
+ all_instrumental: bool,
214
+ builder_state: Optional[DatasetBuilder],
215
+ ) -> DatasetBuilder:
216
+ """Update dataset settings.
217
+
218
+ Returns:
219
+ Updated builder_state
220
+ """
221
+ if builder_state is None:
222
+ return builder_state
223
+
224
+ if custom_tag:
225
+ builder_state.set_custom_tag(custom_tag, tag_position)
226
+
227
+ builder_state.set_all_instrumental(all_instrumental)
228
+
229
+ return builder_state
230
+
231
+
232
+ def save_dataset(
233
+ save_path: str,
234
+ dataset_name: str,
235
+ builder_state: Optional[DatasetBuilder],
236
+ ) -> str:
237
+ """Save the dataset to a JSON file.
238
+
239
+ Returns:
240
+ Status message
241
+ """
242
+ if builder_state is None:
243
+ return "❌ No dataset to save. Please scan a directory first."
244
+
245
+ if not builder_state.samples:
246
+ return "❌ No samples in dataset."
247
+
248
+ if not save_path or not save_path.strip():
249
+ return "❌ Please enter a save path."
250
+
251
+ # Check if any samples are labeled
252
+ labeled_count = builder_state.get_labeled_count()
253
+ if labeled_count == 0:
254
+ return "⚠️ Warning: No samples have been labeled. Consider auto-labeling first.\nSaving anyway..."
255
+
256
+ return builder_state.save_dataset(save_path.strip(), dataset_name)
257
+
258
+
259
+ def load_existing_dataset_for_preprocess(
260
+ dataset_path: str,
261
+ builder_state: Optional[DatasetBuilder],
262
+ ) -> Tuple[str, Any, Any, DatasetBuilder, str, str, str, str, Optional[int], str, str, float, str, bool]:
263
+ """Load an existing dataset JSON file for preprocessing.
264
+
265
+ This allows users to load a previously saved dataset and proceed to preprocessing
266
+ without having to re-scan and re-label.
267
+
268
+ Returns:
269
+ Tuple of (status, table_data, slider_update, builder_state,
270
+ audio_path, filename, caption, lyrics, bpm, keyscale, timesig, duration, language, instrumental)
271
+ """
272
+ empty_preview = (None, "", "", "", None, "", "", 0.0, "instrumental", True)
273
+
274
+ if not dataset_path or not dataset_path.strip():
275
+ return ("❌ Please enter a dataset path", [], gr.Slider(maximum=0, value=0), builder_state) + empty_preview
276
+
277
+ dataset_path = dataset_path.strip()
278
+
279
+ if not os.path.exists(dataset_path):
280
+ return (f"❌ Dataset not found: {dataset_path}", [], gr.Slider(maximum=0, value=0), builder_state) + empty_preview
281
+
282
+ # Create new builder (don't reuse old state when loading a file)
283
+ builder = DatasetBuilder()
284
+
285
+ # Load the dataset
286
+ samples, status = builder.load_dataset(dataset_path)
287
+
288
+ if not samples:
289
+ return (status, [], gr.Slider(maximum=0, value=0), builder) + empty_preview
290
+
291
+ # Get table data
292
+ table_data = builder.get_samples_dataframe_data()
293
+
294
+ # Calculate slider max
295
+ slider_max = max(0, len(samples) - 1)
296
+
297
+ # Create info text
298
+ labeled_count = builder.get_labeled_count()
299
+ info = f"✅ Loaded dataset: {builder.metadata.name}\n"
300
+ info += f"📊 Samples: {len(samples)} ({labeled_count} labeled)\n"
301
+ info += f"🏷️ Custom Tag: {builder.metadata.custom_tag or '(none)'}\n"
302
+ info += "📝 Ready for preprocessing! You can also edit samples below."
303
+
304
+ # Get first sample preview
305
+ first_sample = builder.samples[0]
306
+ preview = (
307
+ first_sample.audio_path,
308
+ first_sample.filename,
309
+ first_sample.caption,
310
+ first_sample.lyrics,
311
+ first_sample.bpm,
312
+ first_sample.keyscale,
313
+ first_sample.timesignature,
314
+ first_sample.duration,
315
+ first_sample.language,
316
+ first_sample.is_instrumental,
317
+ )
318
+
319
+ return (info, table_data, gr.Slider(maximum=slider_max, value=0), builder) + preview
320
+
321
+
322
+ def preprocess_dataset(
323
+ output_dir: str,
324
+ dit_handler,
325
+ builder_state: Optional[DatasetBuilder],
326
+ progress=None,
327
+ ) -> str:
328
+ """Preprocess dataset to tensor files for fast training.
329
+
330
+ This converts audio files to VAE latents and text to embeddings.
331
+
332
+ Returns:
333
+ Status message
334
+ """
335
+ if builder_state is None:
336
+ return "❌ No dataset loaded. Please scan a directory first."
337
+
338
+ if not builder_state.samples:
339
+ return "❌ No samples in dataset."
340
+
341
+ labeled_count = builder_state.get_labeled_count()
342
+ if labeled_count == 0:
343
+ return "❌ No labeled samples. Please auto-label or manually label samples first."
344
+
345
+ if not output_dir or not output_dir.strip():
346
+ return "❌ Please enter an output directory."
347
+
348
+ if dit_handler is None or dit_handler.model is None:
349
+ return "❌ Model not initialized. Please initialize the service first."
350
+
351
+ def progress_callback(msg):
352
+ if progress:
353
+ try:
354
+ progress(msg)
355
+ except:
356
+ pass
357
+
358
+ # Run preprocessing
359
+ output_paths, status = builder_state.preprocess_to_tensors(
360
+ dit_handler=dit_handler,
361
+ output_dir=output_dir.strip(),
362
+ progress_callback=progress_callback,
363
+ )
364
+
365
+ return status
366
+
367
+
368
+ def load_training_dataset(
369
+ tensor_dir: str,
370
+ ) -> str:
371
+ """Load a preprocessed tensor dataset for training.
372
+
373
+ Returns:
374
+ Info text about the dataset
375
+ """
376
+ if not tensor_dir or not tensor_dir.strip():
377
+ return "❌ Please enter a tensor directory path"
378
+
379
+ tensor_dir = tensor_dir.strip()
380
+
381
+ if not os.path.exists(tensor_dir):
382
+ return f"❌ Directory not found: {tensor_dir}"
383
+
384
+ if not os.path.isdir(tensor_dir):
385
+ return f"❌ Not a directory: {tensor_dir}"
386
+
387
+ # Check for manifest
388
+ manifest_path = os.path.join(tensor_dir, "manifest.json")
389
+ if os.path.exists(manifest_path):
390
+ try:
391
+ with open(manifest_path, 'r') as f:
392
+ manifest = json.load(f)
393
+
394
+ num_samples = manifest.get("num_samples", 0)
395
+ metadata = manifest.get("metadata", {})
396
+ name = metadata.get("name", "Unknown")
397
+ custom_tag = metadata.get("custom_tag", "")
398
+
399
+ info = f"✅ Loaded preprocessed dataset: {name}\n"
400
+ info += f"📊 Samples: {num_samples} preprocessed tensors\n"
401
+ info += f"🏷️ Custom Tag: {custom_tag or '(none)'}"
402
+
403
+ return info
404
+ except Exception as e:
405
+ logger.warning(f"Failed to read manifest: {e}")
406
+
407
+ # Fallback: count .pt files
408
+ pt_files = [f for f in os.listdir(tensor_dir) if f.endswith('.pt')]
409
+
410
+ if not pt_files:
411
+ return f"❌ No .pt tensor files found in {tensor_dir}"
412
+
413
+ info = f"✅ Found {len(pt_files)} tensor files in {tensor_dir}\n"
414
+ info += "⚠️ No manifest.json found - using all .pt files"
415
+
416
+ return info
417
+
418
+
419
+ # Training handlers
420
+
421
+ import time
422
+ import re
423
+
424
+
425
+ def _format_duration(seconds):
426
+ """Format seconds to human readable string."""
427
+ seconds = int(seconds)
428
+ if seconds < 60:
429
+ return f"{seconds}s"
430
+ elif seconds < 3600:
431
+ return f"{seconds // 60}m {seconds % 60}s"
432
+ else:
433
+ return f"{seconds // 3600}h {(seconds % 3600) // 60}m"
434
+
435
+
436
+ def start_training(
437
+ tensor_dir: str,
438
+ dit_handler,
439
+ lora_rank: int,
440
+ lora_alpha: int,
441
+ lora_dropout: float,
442
+ learning_rate: float,
443
+ train_epochs: int,
444
+ train_batch_size: int,
445
+ gradient_accumulation: int,
446
+ save_every_n_epochs: int,
447
+ training_shift: float,
448
+ training_seed: int,
449
+ lora_output_dir: str,
450
+ training_state: Dict,
451
+ progress=None,
452
+ ):
453
+ """Start LoRA training from preprocessed tensors.
454
+
455
+ This is a generator function that yields progress updates.
456
+ """
457
+ if not tensor_dir or not tensor_dir.strip():
458
+ yield "❌ Please enter a tensor directory path", "", None, training_state
459
+ return
460
+
461
+ tensor_dir = tensor_dir.strip()
462
+
463
+ if not os.path.exists(tensor_dir):
464
+ yield f"❌ Tensor directory not found: {tensor_dir}", "", None, training_state
465
+ return
466
+
467
+ if dit_handler is None or dit_handler.model is None:
468
+ yield "❌ Model not initialized. Please initialize the service first.", "", None, training_state
469
+ return
470
+
471
+ # Check for required training dependencies
472
+ try:
473
+ from lightning.fabric import Fabric
474
+ from peft import get_peft_model, LoraConfig
475
+ except ImportError as e:
476
+ yield f"❌ Missing required packages: {e}\nPlease install: pip install peft lightning", "", None, training_state
477
+ return
478
+
479
+ training_state["is_training"] = True
480
+ training_state["should_stop"] = False
481
+
482
+ try:
483
+ from acestep.training.trainer import LoRATrainer
484
+ from acestep.training.configs import LoRAConfig as LoRAConfigClass, TrainingConfig
485
+
486
+ # Create configs
487
+ lora_config = LoRAConfigClass(
488
+ r=lora_rank,
489
+ alpha=lora_alpha,
490
+ dropout=lora_dropout,
491
+ )
492
+
493
+ training_config = TrainingConfig(
494
+ shift=training_shift,
495
+ learning_rate=learning_rate,
496
+ batch_size=train_batch_size,
497
+ gradient_accumulation_steps=gradient_accumulation,
498
+ max_epochs=train_epochs,
499
+ save_every_n_epochs=save_every_n_epochs,
500
+ seed=training_seed,
501
+ output_dir=lora_output_dir,
502
+ )
503
+
504
+ import pandas as pd
505
+
506
+ # Initialize training log and loss history
507
+ log_lines = []
508
+ loss_data = pd.DataFrame({"step": [0], "loss": [0.0]})
509
+
510
+ # Start timer
511
+ start_time = time.time()
512
+
513
+ yield f"🚀 Starting training from {tensor_dir}...", "", loss_data, training_state
514
+
515
+ # Create trainer
516
+ trainer = LoRATrainer(
517
+ dit_handler=dit_handler,
518
+ lora_config=lora_config,
519
+ training_config=training_config,
520
+ )
521
+
522
+ # Collect loss history
523
+ step_list = []
524
+ loss_list = []
525
+
526
+ # Train with progress updates using preprocessed tensors
527
+ for step, loss, status in trainer.train_from_preprocessed(tensor_dir, training_state):
528
+ # Calculate elapsed time and ETA
529
+ elapsed_seconds = time.time() - start_time
530
+ time_info = f"⏱️ Elapsed: {_format_duration(elapsed_seconds)}"
531
+
532
+ # Parse "Epoch x/y" from status to calculate ETA
533
+ match = re.search(r"Epoch\s+(\d+)/(\d+)", str(status))
534
+ if match:
535
+ current_ep = int(match.group(1))
536
+ total_ep = int(match.group(2))
537
+ if current_ep > 0:
538
+ eta_seconds = (elapsed_seconds / current_ep) * (total_ep - current_ep)
539
+ time_info += f" | ETA: ~{_format_duration(eta_seconds)}"
540
+
541
+ # Display status with time info
542
+ display_status = f"{status}\n{time_info}"
543
+
544
+ # Terminal log
545
+ log_msg = f"[{_format_duration(elapsed_seconds)}] Step {step}: {status}"
546
+ logger.info(log_msg)
547
+
548
+ # Add to UI log
549
+ log_lines.append(status)
550
+ if len(log_lines) > 15:
551
+ log_lines = log_lines[-15:]
552
+ log_text = "\n".join(log_lines)
553
+
554
+ # Track loss for plot (only valid values)
555
+ if step > 0 and loss is not None and loss == loss: # Check for NaN
556
+ step_list.append(step)
557
+ loss_list.append(float(loss))
558
+ loss_data = pd.DataFrame({"step": step_list, "loss": loss_list})
559
+
560
+ yield display_status, log_text, loss_data, training_state
561
+
562
+ if training_state.get("should_stop", False):
563
+ logger.info("⏹️ Training stopped by user")
564
+ log_lines.append("⏹️ Training stopped by user")
565
+ yield f"⏹️ Stopped ({time_info})", "\n".join(log_lines[-15:]), loss_data, training_state
566
+ break
567
+
568
+ total_time = time.time() - start_time
569
+ training_state["is_training"] = False
570
+ completion_msg = f"✅ Training completed! Total time: {_format_duration(total_time)}"
571
+
572
+ logger.info(completion_msg)
573
+ log_lines.append(completion_msg)
574
+
575
+ yield completion_msg, "\n".join(log_lines[-15:]), loss_data, training_state
576
+
577
+ except Exception as e:
578
+ logger.exception("Training error")
579
+ training_state["is_training"] = False
580
+ import pandas as pd
581
+ empty_df = pd.DataFrame({"step": [], "loss": []})
582
+ yield f"❌ Error: {str(e)}", str(e), empty_df, training_state
583
+
584
+
585
+ def stop_training(training_state: Dict) -> Tuple[str, Dict]:
586
+ """Stop the current training process.
587
+
588
+ Returns:
589
+ Tuple of (status, training_state)
590
+ """
591
+ if not training_state.get("is_training", False):
592
+ return "⚠️ No training in progress", training_state
593
+
594
+ training_state["should_stop"] = True
595
+ return "⏹️ Stopping training...", training_state
596
+
597
+
598
+ def export_lora(
599
+ export_path: str,
600
+ lora_output_dir: str,
601
+ ) -> str:
602
+ """Export the trained LoRA weights.
603
+
604
+ Returns:
605
+ Status message
606
+ """
607
+ if not export_path or not export_path.strip():
608
+ return "❌ Please enter an export path"
609
+
610
+ # Check if there's a trained model to export
611
+ final_dir = os.path.join(lora_output_dir, "final")
612
+ checkpoint_dir = os.path.join(lora_output_dir, "checkpoints")
613
+
614
+ # Prefer final, fallback to checkpoints
615
+ if os.path.exists(final_dir):
616
+ source_path = final_dir
617
+ elif os.path.exists(checkpoint_dir):
618
+ # Find the latest checkpoint
619
+ checkpoints = [d for d in os.listdir(checkpoint_dir) if d.startswith("epoch_")]
620
+ if not checkpoints:
621
+ return "❌ No checkpoints found"
622
+
623
+ checkpoints.sort(key=lambda x: int(x.split("_")[1]))
624
+ latest = checkpoints[-1]
625
+ source_path = os.path.join(checkpoint_dir, latest)
626
+ else:
627
+ return f"❌ No trained model found in {lora_output_dir}"
628
+
629
+ try:
630
+ import shutil
631
+
632
+ export_path = export_path.strip()
633
+ os.makedirs(os.path.dirname(export_path) if os.path.dirname(export_path) else ".", exist_ok=True)
634
+
635
+ if os.path.exists(export_path):
636
+ shutil.rmtree(export_path)
637
+
638
+ shutil.copytree(source_path, export_path)
639
+
640
+ return f"✅ LoRA exported to {export_path}"
641
+
642
+ except Exception as e:
643
+ logger.exception("Export error")
644
+ return f"❌ Export failed: {str(e)}"
acestep/gradio_ui/i18n.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Internationalization (i18n) module for Gradio UI
3
+ Supports multiple languages with easy translation management
4
+ """
5
+ import os
6
+ import json
7
+ from typing import Dict, Optional
8
+
9
+
10
+ class I18n:
11
+ """Internationalization handler"""
12
+
13
+ def __init__(self, default_language: str = "en"):
14
+ """
15
+ Initialize i18n handler
16
+
17
+ Args:
18
+ default_language: Default language code (en, zh, ja, etc.)
19
+ """
20
+ self.current_language = default_language
21
+ self.translations: Dict[str, Dict[str, str]] = {}
22
+ self._load_all_translations()
23
+
24
+ def _load_all_translations(self):
25
+ """Load all translation files from i18n directory"""
26
+ current_file = os.path.abspath(__file__)
27
+ module_dir = os.path.dirname(current_file)
28
+ i18n_dir = os.path.join(module_dir, "i18n")
29
+
30
+ if not os.path.exists(i18n_dir):
31
+ # Create i18n directory if it doesn't exist
32
+ os.makedirs(i18n_dir)
33
+ return
34
+
35
+ # Load all JSON files in i18n directory
36
+ for filename in os.listdir(i18n_dir):
37
+ if filename.endswith(".json"):
38
+ lang_code = filename[:-5] # Remove .json extension
39
+ filepath = os.path.join(i18n_dir, filename)
40
+ try:
41
+ with open(filepath, 'r', encoding='utf-8') as f:
42
+ self.translations[lang_code] = json.load(f)
43
+ except Exception as e:
44
+ print(f"Error loading translation file {filename}: {e}")
45
+
46
+ def set_language(self, language: str):
47
+ """Set current language"""
48
+ if language in self.translations:
49
+ self.current_language = language
50
+ else:
51
+ print(f"Warning: Language '{language}' not found, using default")
52
+
53
+ def t(self, key: str, **kwargs) -> str:
54
+ """
55
+ Translate a key to current language
56
+
57
+ Args:
58
+ key: Translation key (dot-separated for nested keys)
59
+ **kwargs: Optional format parameters
60
+
61
+ Returns:
62
+ Translated string
63
+ """
64
+ # Get translation from current language
65
+ translation = self._get_nested_value(
66
+ self.translations.get(self.current_language, {}),
67
+ key
68
+ )
69
+
70
+ # Fallback to English if not found
71
+ if translation is None:
72
+ translation = self._get_nested_value(
73
+ self.translations.get('en', {}),
74
+ key
75
+ )
76
+
77
+ # Final fallback to key itself
78
+ if translation is None:
79
+ translation = key
80
+
81
+ # Apply formatting if kwargs provided
82
+ if kwargs:
83
+ try:
84
+ translation = translation.format(**kwargs)
85
+ except KeyError:
86
+ pass
87
+
88
+ return translation
89
+
90
+ def _get_nested_value(self, data: dict, key: str) -> Optional[str]:
91
+ """
92
+ Get nested dictionary value using dot notation
93
+
94
+ Args:
95
+ data: Dictionary to search
96
+ key: Dot-separated key (e.g., "section.subsection.key")
97
+
98
+ Returns:
99
+ Value if found, None otherwise
100
+ """
101
+ keys = key.split('.')
102
+ current = data
103
+
104
+ for k in keys:
105
+ if isinstance(current, dict) and k in current:
106
+ current = current[k]
107
+ else:
108
+ return None
109
+
110
+ return current if isinstance(current, str) else None
111
+
112
+ def get_available_languages(self) -> list:
113
+ """Get list of available language codes"""
114
+ return list(self.translations.keys())
115
+
116
+
117
+ # Global i18n instance
118
+ _i18n_instance: Optional[I18n] = None
119
+
120
+
121
+ def get_i18n(language: Optional[str] = None) -> I18n:
122
+ """
123
+ Get global i18n instance
124
+
125
+ Args:
126
+ language: Optional language to set
127
+
128
+ Returns:
129
+ I18n instance
130
+ """
131
+ global _i18n_instance
132
+
133
+ if _i18n_instance is None:
134
+ _i18n_instance = I18n(default_language=language or "en")
135
+ elif language is not None:
136
+ _i18n_instance.set_language(language)
137
+
138
+ return _i18n_instance
139
+
140
+
141
+ def t(key: str, **kwargs) -> str:
142
+ """
143
+ Convenience function for translation
144
+
145
+ Args:
146
+ key: Translation key
147
+ **kwargs: Optional format parameters
148
+
149
+ Returns:
150
+ Translated string
151
+ """
152
+ return get_i18n().t(key, **kwargs)
acestep/gradio_ui/i18n/en.json ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "app": {
3
+ "title": "🎛️ ACE-Step V1.5 Playground💡",
4
+ "subtitle": "Pushing the Boundaries of Open-Source Music Generation"
5
+ },
6
+ "dataset": {
7
+ "title": "📊 Dataset Explorer",
8
+ "dataset_label": "Dataset",
9
+ "dataset_info": "Choose dataset to explore",
10
+ "import_btn": "📥 Import Dataset",
11
+ "search_type_label": "Search Type",
12
+ "search_type_info": "How to find items",
13
+ "search_value_label": "Search Value",
14
+ "search_value_placeholder": "Enter keys or index (leave empty for random)",
15
+ "search_value_info": "Keys: exact match, Index: 0 to dataset size-1",
16
+ "instruction_label": "📝 Instruction",
17
+ "instruction_placeholder": "No instruction available",
18
+ "metadata_title": "📋 Item Metadata (JSON)",
19
+ "metadata_label": "Complete Item Information",
20
+ "source_audio": "Source Audio",
21
+ "target_audio": "Target Audio",
22
+ "reference_audio": "Reference Audio",
23
+ "get_item_btn": "🔍 Get Item",
24
+ "use_src_checkbox": "Use Source Audio from Dataset",
25
+ "use_src_info": "Check to use the source audio from dataset",
26
+ "data_status_label": "📊 Data Status",
27
+ "data_status_default": "❌ No dataset imported",
28
+ "autofill_btn": "📋 Auto-fill Generation Form"
29
+ },
30
+ "service": {
31
+ "title": "🔧 Service Configuration",
32
+ "checkpoint_label": "Checkpoint File",
33
+ "checkpoint_info": "Select a trained model checkpoint file (full path or filename)",
34
+ "refresh_btn": "🔄 Refresh",
35
+ "model_path_label": "Main Model Path",
36
+ "model_path_info": "Select the model configuration directory (auto-scanned from checkpoints)",
37
+ "device_label": "Device",
38
+ "device_info": "Processing device (auto-detect recommended)",
39
+ "lm_model_path_label": "5Hz LM Model Path",
40
+ "lm_model_path_info": "Select the 5Hz LM model checkpoint (auto-scanned from checkpoints)",
41
+ "backend_label": "5Hz LM Backend",
42
+ "backend_info": "Select backend for 5Hz LM: vllm (faster) or pt (PyTorch, more compatible)",
43
+ "init_llm_label": "Initialize 5Hz LM",
44
+ "init_llm_info": "Check to initialize 5Hz LM during service initialization",
45
+ "flash_attention_label": "Use Flash Attention",
46
+ "flash_attention_info_enabled": "Enable flash attention for faster inference (requires flash_attn package)",
47
+ "flash_attention_info_disabled": "Flash attention not available (flash_attn package not installed)",
48
+ "offload_cpu_label": "Offload to CPU",
49
+ "offload_cpu_info": "Offload models to CPU when not in use to save GPU memory",
50
+ "offload_dit_cpu_label": "Offload DiT to CPU",
51
+ "offload_dit_cpu_info": "Offload DiT to CPU (needs Offload to CPU)",
52
+ "init_btn": "Initialize Service",
53
+ "status_label": "Status",
54
+ "language_label": "UI Language",
55
+ "language_info": "Select interface language"
56
+ },
57
+ "generation": {
58
+ "required_inputs": "📝 Required Inputs",
59
+ "task_type_label": "Task Type",
60
+ "task_type_info": "Select the task type for generation",
61
+ "instruction_label": "Instruction",
62
+ "instruction_info": "Instruction is automatically generated based on task type",
63
+ "load_btn": "Load",
64
+ "track_name_label": "Track Name",
65
+ "track_name_info": "Select track name for lego/extract tasks",
66
+ "track_classes_label": "Track Names",
67
+ "track_classes_info": "Select multiple track classes for complete task",
68
+ "audio_uploads": "🎵 Audio Uploads",
69
+ "reference_audio": "Reference Audio (optional)",
70
+ "source_audio": "Source Audio (optional)",
71
+ "convert_codes_btn": "Convert to Codes",
72
+ "lm_codes_hints": "🎼 LM Codes Hints",
73
+ "lm_codes_label": "LM Codes Hints",
74
+ "lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
75
+ "lm_codes_info": "Paste LM codes hints for text2music generation",
76
+ "lm_codes_sample": "LM Codes Hints (Sample {n})",
77
+ "lm_codes_sample_info": "Codes for sample {n}",
78
+ "transcribe_btn": "Transcribe",
79
+ "repainting_controls": "🎨 Repainting Controls (seconds)",
80
+ "repainting_start": "Repainting Start",
81
+ "repainting_end": "Repainting End",
82
+ "mode_label": "Generation Mode",
83
+ "mode_info": "Simple: describe music in natural language. Custom: full control over caption and lyrics.",
84
+ "mode_simple": "Simple",
85
+ "mode_custom": "Custom",
86
+ "simple_query_label": "Song Description",
87
+ "simple_query_placeholder": "Describe the music you want to create, e.g., 'a soft Bengali love song for a quiet evening'. Leave empty for a random sample.",
88
+ "simple_query_info": "Enter a natural language description of the music you want to generate",
89
+ "simple_vocal_language_label": "Vocal Language (optional)",
90
+ "simple_vocal_language_info": "Select preferred language(s) for lyrics. Use 'unknown' for any language.",
91
+ "create_sample_btn": "Create Sample",
92
+ "caption_title": "📝 Music Caption",
93
+ "caption_label": "Music Caption (optional)",
94
+ "caption_placeholder": "A peaceful acoustic guitar melody with soft vocals...",
95
+ "caption_info": "Describe the style, genre, instruments, and mood",
96
+ "lyrics_title": "📝 Lyrics",
97
+ "lyrics_label": "Lyrics (optional)",
98
+ "lyrics_placeholder": "[Verse 1]\\nUnder the starry night\\nI feel so alive...",
99
+ "lyrics_info": "Song lyrics with structure",
100
+ "instrumental_label": "Instrumental",
101
+ "format_btn": "Format",
102
+ "optional_params": "⚙️ Optional Parameters",
103
+ "vocal_language_label": "Vocal Language (optional)",
104
+ "vocal_language_info": "use `unknown` for inst",
105
+ "bpm_label": "BPM (optional)",
106
+ "bpm_info": "leave empty for N/A",
107
+ "keyscale_label": "KeyScale (optional)",
108
+ "keyscale_placeholder": "Leave empty for N/A",
109
+ "keyscale_info": "A-G, #/♭, major/minor",
110
+ "timesig_label": "Time Signature (optional)",
111
+ "timesig_info": "2/4, 3/4, 4/4...",
112
+ "duration_label": "Audio Duration (seconds)",
113
+ "duration_info": "Use -1 for random",
114
+ "batch_size_label": "Batch Size",
115
+ "batch_size_info": "Number of audio to generate (max 8)",
116
+ "advanced_settings": "🔧 Advanced Settings",
117
+ "inference_steps_label": "DiT Inference Steps",
118
+ "inference_steps_info": "Turbo: max 8, Base: max 200",
119
+ "guidance_scale_label": "DiT Guidance Scale (Only support for base model)",
120
+ "guidance_scale_info": "Higher values follow text more closely",
121
+ "seed_label": "Seed",
122
+ "seed_info": "Use comma-separated values for batches",
123
+ "random_seed_label": "Random Seed",
124
+ "random_seed_info": "Enable to auto-generate seeds",
125
+ "audio_format_label": "Audio Format",
126
+ "audio_format_info": "Audio format for saved files",
127
+ "use_adg_label": "Use ADG",
128
+ "use_adg_info": "Enable Angle Domain Guidance",
129
+ "shift_label": "Shift",
130
+ "shift_info": "Timestep shift factor for base models (range 1.0~5.0, default 3.0). Not effective for turbo models.",
131
+ "infer_method_label": "Inference Method",
132
+ "infer_method_info": "Diffusion inference method. ODE (Euler) is faster, SDE (stochastic) may produce different results.",
133
+ "custom_timesteps_label": "Custom Timesteps",
134
+ "custom_timesteps_info": "Optional: comma-separated values from 1.0 to 0.0 (e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference steps and shift.",
135
+ "cfg_interval_start": "CFG Interval Start",
136
+ "cfg_interval_end": "CFG Interval End",
137
+ "lm_params_title": "🤖 LM Generation Parameters",
138
+ "lm_temperature_label": "LM Temperature",
139
+ "lm_temperature_info": "5Hz LM temperature (higher = more random)",
140
+ "lm_cfg_scale_label": "LM CFG Scale",
141
+ "lm_cfg_scale_info": "5Hz LM CFG (1.0 = no CFG)",
142
+ "lm_top_k_label": "LM Top-K",
143
+ "lm_top_k_info": "Top-K (0 = disabled)",
144
+ "lm_top_p_label": "LM Top-P",
145
+ "lm_top_p_info": "Top-P (1.0 = disabled)",
146
+ "lm_negative_prompt_label": "LM Negative Prompt",
147
+ "lm_negative_prompt_placeholder": "Enter negative prompt for CFG (default: NO USER INPUT)",
148
+ "lm_negative_prompt_info": "Negative prompt (use when LM CFG Scale > 1.0)",
149
+ "cot_metas_label": "CoT Metas",
150
+ "cot_metas_info": "Use LM to generate CoT metadata (uncheck to skip LM CoT generation)",
151
+ "cot_language_label": "CoT Language",
152
+ "cot_language_info": "Generate language in CoT (chain-of-thought)",
153
+ "constrained_debug_label": "Constrained Decoding Debug",
154
+ "constrained_debug_info": "Enable debug logging for constrained decoding (check to see detailed logs)",
155
+ "auto_score_label": "Auto Score",
156
+ "auto_score_info": "Automatically calculate quality scores for all generated audios",
157
+ "auto_lrc_label": "Auto LRC",
158
+ "auto_lrc_info": "Automatically generate LRC lyrics timestamps for all generated audios",
159
+ "lm_batch_chunk_label": "LM Batch Chunk Size",
160
+ "lm_batch_chunk_info": "Max items per LM batch chunk (default: 8, limited by GPU memory)",
161
+ "codes_strength_label": "LM Codes Strength",
162
+ "codes_strength_info": "Control how many denoising steps use LM-generated codes",
163
+ "cover_strength_label": "Audio Cover Strength",
164
+ "cover_strength_info": "Control how many denoising steps use cover mode",
165
+ "score_sensitivity_label": "Quality Score Sensitivity",
166
+ "score_sensitivity_info": "Lower = more sensitive (default: 1.0). Adjusts how PMI maps to [0,1]",
167
+ "think_label": "Think",
168
+ "parallel_thinking_label": "ParallelThinking",
169
+ "generate_btn": "🎵 Generate Music",
170
+ "autogen_label": "AutoGen",
171
+ "caption_rewrite_label": "CaptionRewrite"
172
+ },
173
+ "results": {
174
+ "title": "🎵 Results",
175
+ "generated_music": "🎵 Generated Music (Sample {n})",
176
+ "send_to_src_btn": "🔗 Send To Src Audio",
177
+ "send_to_cover_btn": "🔗 Send To Cover",
178
+ "send_to_repaint_btn": "🔗 Send To Repaint",
179
+ "save_btn": "💾 Save",
180
+ "score_btn": "📊 Score",
181
+ "lrc_btn": "🎵 LRC",
182
+ "quality_score_label": "Quality Score (Sample {n})",
183
+ "quality_score_placeholder": "Click 'Score' to calculate perplexity-based quality score",
184
+ "codes_label": "LM Codes (Sample {n})",
185
+ "lrc_label": "Lyrics Timestamps (Sample {n})",
186
+ "lrc_placeholder": "Click 'LRC' to generate timestamps",
187
+ "details_accordion": "📊 Score & LRC & LM Codes",
188
+ "generation_status": "Generation Status",
189
+ "current_batch": "Current Batch",
190
+ "batch_indicator": "Batch {current} / {total}",
191
+ "next_batch_status": "Next Batch Status",
192
+ "prev_btn": "◀ Previous",
193
+ "next_btn": "Next ▶",
194
+ "restore_params_btn": "↙️ Apply These Settings to UI (Restore Batch Parameters)",
195
+ "batch_results_title": "📁 Batch Results & Generation Details",
196
+ "all_files_label": "📁 All Generated Files (Download)",
197
+ "generation_details": "Generation Details"
198
+ },
199
+ "messages": {
200
+ "no_audio_to_save": "❌ No audio to save",
201
+ "save_success": "✅ Saved audio and metadata to {filename}",
202
+ "save_failed": "❌ Failed to save: {error}",
203
+ "no_file_selected": "⚠️ No file selected",
204
+ "params_loaded": "✅ Parameters loaded from {filename}",
205
+ "invalid_json": "❌ Invalid JSON file: {error}",
206
+ "load_error": "❌ Error loading file: {error}",
207
+ "example_loaded": "📁 Loaded example from {filename}",
208
+ "example_failed": "Failed to parse JSON file {filename}: {error}",
209
+ "example_error": "Error loading example: {error}",
210
+ "lm_generated": "🤖 Generated example using LM",
211
+ "lm_fallback": "Failed to generate example using LM, falling back to examples directory",
212
+ "lm_not_initialized": "❌ 5Hz LM not initialized. Please initialize it first.",
213
+ "autogen_enabled": "🔄 AutoGen enabled - next batch will generate after this",
214
+ "batch_ready": "✅ Batch {n} ready! Click 'Next' to view.",
215
+ "batch_generating": "🔄 Starting background generation for Batch {n}...",
216
+ "batch_failed": "❌ Background generation failed: {error}",
217
+ "viewing_batch": "✅ Viewing Batch {n}",
218
+ "at_first_batch": "Already at first batch",
219
+ "at_last_batch": "No next batch available",
220
+ "batch_not_found": "Batch {n} not found in queue",
221
+ "no_batch_data": "No batch data found to restore.",
222
+ "params_restored": "✅ UI Parameters restored from Batch {n}",
223
+ "scoring_failed": "❌ Error: Batch data not found",
224
+ "no_codes": "❌ No audio codes available. Please generate music first.",
225
+ "score_failed": "❌ Scoring failed: {error}",
226
+ "score_error": "❌ Error calculating score: {error}",
227
+ "lrc_no_batch_data": "❌ No batch data found. Please generate music first.",
228
+ "lrc_no_extra_outputs": "❌ No extra outputs found. Condition tensors not available.",
229
+ "lrc_missing_tensors": "❌ Missing required tensors for LRC generation.",
230
+ "lrc_sample_not_exist": "❌ Sample does not exist in current batch.",
231
+ "lrc_empty_result": "⚠️ LRC generation produced empty result.",
232
+ "empty_query": "⚠️ Please enter a music description.",
233
+ "sample_creation_failed": "❌ Failed to create sample. Please try again.",
234
+ "sample_created": "✅ Sample created! Review the caption and lyrics, then click Generate Music.",
235
+ "simple_examples_not_found": "⚠️ Simple mode examples directory not found.",
236
+ "simple_examples_empty": "⚠️ No example files found in simple mode examples.",
237
+ "simple_example_loaded": "🎲 Loaded random example from {filename}",
238
+ "format_success": "✅ Caption and lyrics formatted successfully",
239
+ "format_failed": "❌ Format failed: {error}",
240
+ "skipping_metas_cot": "⚡ Skipping Phase 1 metas COT (sample already formatted)",
241
+ "invalid_timesteps_format": "⚠️ Invalid timesteps format. Using default schedule.",
242
+ "timesteps_out_of_range": "⚠️ Timesteps must be in range [0, 1]. Using default schedule.",
243
+ "timesteps_count_mismatch": "⚠️ Timesteps count ({actual}) differs from inference_steps ({expected}). Using timesteps count."
244
+ }
245
+ }
acestep/gradio_ui/i18n/ja.json ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "app": {
3
+ "title": "🎛️ ACE-Step V1.5 プレイグラウンド💡",
4
+ "subtitle": "オープンソース音楽生成の限界を押し広げる"
5
+ },
6
+ "dataset": {
7
+ "title": "📊 データセットエクスプローラー",
8
+ "dataset_label": "データセット",
9
+ "dataset_info": "探索するデータセットを選択",
10
+ "import_btn": "📥 データセットをインポート",
11
+ "search_type_label": "検索タイプ",
12
+ "search_type_info": "アイテムの検索方法",
13
+ "search_value_label": "検索値",
14
+ "search_value_placeholder": "キーまたはインデックスを入力(空白の場合はランダム)",
15
+ "search_value_info": "キー: 完全一致、インデックス: 0からデータセットサイズ-1",
16
+ "instruction_label": "📝 指示",
17
+ "instruction_placeholder": "利用可能な指示がありません",
18
+ "metadata_title": "📋 アイテムメタデータ (JSON)",
19
+ "metadata_label": "完全なアイテム情報",
20
+ "source_audio": "ソースオーディオ",
21
+ "target_audio": "ターゲットオーディオ",
22
+ "reference_audio": "リファレンスオーディオ",
23
+ "get_item_btn": "🔍 アイテムを取得",
24
+ "use_src_checkbox": "データセットのソースオーディオを使用",
25
+ "use_src_info": "データセットのソースオーディオを使用する場合はチェック",
26
+ "data_status_label": "📊 データステータス",
27
+ "data_status_default": "❌ データセットがインポートされていません",
28
+ "autofill_btn": "📋 生成フォームを自動入力"
29
+ },
30
+ "service": {
31
+ "title": "🔧 サービス設定",
32
+ "checkpoint_label": "チェックポイントファイル",
33
+ "checkpoint_info": "訓練済みモデルのチェックポイントファイルを選択(フルパスまたはファイル名)",
34
+ "refresh_btn": "🔄 更新",
35
+ "model_path_label": "メインモデルパス",
36
+ "model_path_info": "モデル設定ディレクトリを選択(チェックポイントから自動スキャン)",
37
+ "device_label": "デバイス",
38
+ "device_info": "処理デバイス(自動検出を推奨)",
39
+ "lm_model_path_label": "5Hz LM モデルパス",
40
+ "lm_model_path_info": "5Hz LMモデルチェックポイントを選択(チェックポイントから自動スキャン)",
41
+ "backend_label": "5Hz LM バックエンド",
42
+ "backend_info": "5Hz LMのバックエンドを選択: vllm(高速)またはpt(PyTorch、より互換性あり)",
43
+ "init_llm_label": "5Hz LM を初期化",
44
+ "init_llm_info": "サービス初期化中に5Hz LMを初期化する場合はチェック",
45
+ "flash_attention_label": "Flash Attention を使用",
46
+ "flash_attention_info_enabled": "推論を高速化するためにflash attentionを有効にする(flash_attnパッケージが必要)",
47
+ "flash_attention_info_disabled": "Flash attentionは利用できません(flash_attnパッケージがインストールされていません)",
48
+ "offload_cpu_label": "CPUにオフロード",
49
+ "offload_cpu_info": "使用していない時にモデルをCPUにオフロードしてGPUメモリを節約",
50
+ "offload_dit_cpu_label": "DiTをCPUにオフロード",
51
+ "offload_dit_cpu_info": "DiTをCPUにオフロード(CPUへのオフロードが必要)",
52
+ "init_btn": "サービスを初期化",
53
+ "status_label": "ステータス",
54
+ "language_label": "UI言語",
55
+ "language_info": "インターフェース言語を選択"
56
+ },
57
+ "generation": {
58
+ "required_inputs": "📝 必須入力",
59
+ "task_type_label": "タスクタイプ",
60
+ "task_type_info": "生成のタスクタイプを選択",
61
+ "instruction_label": "指示",
62
+ "instruction_info": "指示はタスクタイプに基づいて自動生成されます",
63
+ "load_btn": "読み込む",
64
+ "track_name_label": "トラック名",
65
+ "track_name_info": "lego/extractタスクのトラック名を選択",
66
+ "track_classes_label": "トラック名",
67
+ "track_classes_info": "completeタスクの複数のトラッククラスを選択",
68
+ "audio_uploads": "🎵 オーディオアップロード",
69
+ "reference_audio": "リファレンスオーディオ(オプション)",
70
+ "source_audio": "ソースオーディオ(オプション)",
71
+ "convert_codes_btn": "コードに変換",
72
+ "lm_codes_hints": "🎼 LM コードヒント",
73
+ "lm_codes_label": "LM コードヒント",
74
+ "lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
75
+ "lm_codes_info": "text2music生成用のLMコードヒントを貼り付け",
76
+ "lm_codes_sample": "LM コードヒント(サンプル {n})",
77
+ "lm_codes_sample_info": "サンプル{n}のコード",
78
+ "transcribe_btn": "転写",
79
+ "repainting_controls": "🎨 再描画コントロール(秒)",
80
+ "repainting_start": "再描画開始",
81
+ "repainting_end": "再描画終了",
82
+ "mode_label": "生成モード",
83
+ "mode_info": "シンプル:自然言語で音楽を説明��カスタム:キャプションと歌詞を完全にコントロール。",
84
+ "mode_simple": "シンプル",
85
+ "mode_custom": "カスタム",
86
+ "simple_query_label": "曲の説明",
87
+ "simple_query_placeholder": "作成したい音楽を説明してください。例:'静かな夜のための優しいベンガルのラブソング'。空欄の場合はランダムなサンプルが生成されます。",
88
+ "simple_query_info": "生成したい音楽の自然言語の説明を入力",
89
+ "simple_vocal_language_label": "ボーカル言語(オプション)",
90
+ "simple_vocal_language_info": "歌詞の希望言語を選択。任意の言語の場合は'unknown'を使用。",
91
+ "create_sample_btn": "サンプル作成",
92
+ "caption_title": "📝 音楽キャプション",
93
+ "caption_label": "音楽キャプション(オプション)",
94
+ "caption_placeholder": "柔らかいボーカルを伴う穏やかなアコースティックギターのメロディー...",
95
+ "caption_info": "スタイル、ジャンル、楽器、ムードを説明",
96
+ "lyrics_title": "📝 歌詞",
97
+ "lyrics_label": "歌詞(オプション)",
98
+ "lyrics_placeholder": "[バース1]\\n星空の下で\\nとても生きていると感じる...",
99
+ "lyrics_info": "構造を持つ曲の歌詞",
100
+ "instrumental_label": "インストゥルメンタル",
101
+ "format_btn": "フォーマット",
102
+ "optional_params": "⚙️ オプションパラメータ",
103
+ "vocal_language_label": "ボーカル言語(オプション)",
104
+ "vocal_language_info": "インストには`unknown`を使用",
105
+ "bpm_label": "BPM(オプション)",
106
+ "bpm_info": "空白の場合はN/A",
107
+ "keyscale_label": "キースケール(オプション)",
108
+ "keyscale_placeholder": "空白の場合はN/A",
109
+ "keyscale_info": "A-G, #/♭, メジャー/マイナー",
110
+ "timesig_label": "拍子記号(オプション)",
111
+ "timesig_info": "2/4, 3/4, 4/4...",
112
+ "duration_label": "オーディオ長(秒)",
113
+ "duration_info": "ランダムの場合は-1を使用",
114
+ "batch_size_label": "バッチサイズ",
115
+ "batch_size_info": "生成するオーディオの数(最大8)",
116
+ "advanced_settings": "🔧 詳細設定",
117
+ "inference_steps_label": "DiT 推論ステップ",
118
+ "inference_steps_info": "Turbo: 最大8、Base: 最大200",
119
+ "guidance_scale_label": "DiT ガイダンススケール(baseモデルのみサポート)",
120
+ "guidance_scale_info": "値が高いほどテキストに忠実に従う",
121
+ "seed_label": "シード",
122
+ "seed_info": "バッチにはカンマ区切りの値を使用",
123
+ "random_seed_label": "ランダムシード",
124
+ "random_seed_info": "有効にすると自動的にシードを生成",
125
+ "audio_format_label": "オーディオフォーマット",
126
+ "audio_format_info": "保存ファイルのオーディオフォーマット",
127
+ "use_adg_label": "ADG を使用",
128
+ "use_adg_info": "角度ドメインガイダンスを有効化",
129
+ "shift_label": "シフト",
130
+ "shift_info": "baseモデル用タイムステップシフト係数 (範囲 1.0~5.0、デフォルト 3.0)。turboモデルには無効。",
131
+ "infer_method_label": "推論方法",
132
+ "infer_method_info": "拡散推論方法。ODE (オイラー) は高速、SDE (確率的) は異なる結果を生成する可能性があります。",
133
+ "custom_timesteps_label": "カスタムタイムステップ",
134
+ "custom_timesteps_info": "オプション:1.0から0.0へのカンマ区切り値(例:'0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')。推論ステップとシフトを上書きします。",
135
+ "cfg_interval_start": "CFG 間隔開始",
136
+ "cfg_interval_end": "CFG 間隔終了",
137
+ "lm_params_title": "🤖 LM 生成パラメータ",
138
+ "lm_temperature_label": "LM 温度",
139
+ "lm_temperature_info": "5Hz LM温度(高いほどランダム)",
140
+ "lm_cfg_scale_label": "LM CFG スケール",
141
+ "lm_cfg_scale_info": "5Hz LM CFG (1.0 = CFGなし)",
142
+ "lm_top_k_label": "LM Top-K",
143
+ "lm_top_k_info": "Top-K (0 = 無効)",
144
+ "lm_top_p_label": "LM Top-P",
145
+ "lm_top_p_info": "Top-P (1.0 = 無効)",
146
+ "lm_negative_prompt_label": "LM ネガティブプロンプト",
147
+ "lm_negative_prompt_placeholder": "CFGのネガティブプロンプトを入力(デフォルト: NO USER INPUT)",
148
+ "lm_negative_prompt_info": "ネガティブプロンプト(LM CFGスケール > 1.0の場合に使用)",
149
+ "cot_metas_label": "CoT メタデータ",
150
+ "cot_metas_info": "LMを使用してCoTメタデータを生成(チェックを外すとLM CoT生成をスキップ)",
151
+ "cot_language_label": "CoT 言語",
152
+ "cot_language_info": "CoTで言語を生成(思考の連鎖)",
153
+ "constrained_debug_label": "制約付きデコーディングデバッグ",
154
+ "constrained_debug_info": "制約付きデコーディングのデバッグログを有効化(チェックすると詳細ログを表示)",
155
+ "auto_score_label": "自動スコアリング",
156
+ "auto_score_info": "生成���れたすべてのオーディオの品質スコアを自動計算",
157
+ "auto_lrc_label": "自動 LRC",
158
+ "auto_lrc_info": "生成されたすべてのオーディオのLRC歌詞タイムスタンプを自動生成",
159
+ "lm_batch_chunk_label": "LM バッチチャンクサイズ",
160
+ "lm_batch_chunk_info": "LMバッチチャンクあたりの最大アイテム数(デフォルト: 8、GPUメモリによる制限)",
161
+ "codes_strength_label": "LM コード強度",
162
+ "codes_strength_info": "LM生成コードを使用するデノイジングステップ数を制御",
163
+ "cover_strength_label": "オーディオカバー強度",
164
+ "cover_strength_info": "カバーモードを使用するデノイジングステップ数を制御",
165
+ "score_sensitivity_label": "品質スコア感度",
166
+ "score_sensitivity_info": "低い = より敏感(デフォルト: 1.0)。PMIが[0,1]にマッピングする方法を調整",
167
+ "think_label": "思考",
168
+ "parallel_thinking_label": "並列思考",
169
+ "generate_btn": "🎵 音楽を生成",
170
+ "autogen_label": "自動生成",
171
+ "caption_rewrite_label": "キャプション書き換え"
172
+ },
173
+ "results": {
174
+ "title": "🎵 結果",
175
+ "generated_music": "🎵 生成された音楽(サンプル {n})",
176
+ "send_to_src_btn": "🔗 ソースオーディオに送信",
177
+ "send_to_cover_btn": "🔗 Send To Cover",
178
+ "send_to_repaint_btn": "🔗 Send To Repaint",
179
+ "save_btn": "💾 保存",
180
+ "score_btn": "📊 スコア",
181
+ "lrc_btn": "🎵 LRC",
182
+ "quality_score_label": "品質スコア(サンプル {n})",
183
+ "quality_score_placeholder": "'スコア'をクリックしてパープレキシティベースの品質スコアを計算",
184
+ "codes_label": "LM コード(サンプル {n})",
185
+ "lrc_label": "歌詞タイムスタンプ(サンプル {n})",
186
+ "lrc_placeholder": "'LRC'をクリックしてタイムスタンプを生成",
187
+ "details_accordion": "📊 スコア & LRC & LM コード",
188
+ "generation_status": "生成ステータス",
189
+ "current_batch": "現在のバッチ",
190
+ "batch_indicator": "バッチ {current} / {total}",
191
+ "next_batch_status": "次のバッチステータス",
192
+ "prev_btn": "◀ 前へ",
193
+ "next_btn": "次へ ▶",
194
+ "restore_params_btn": "↙️ これらの設定をUIに適用(バッチパラメータを復元)",
195
+ "batch_results_title": "📁 バッチ結果と生成詳細",
196
+ "all_files_label": "📁 すべての生成ファイル(ダウンロード)",
197
+ "generation_details": "生成詳細"
198
+ },
199
+ "messages": {
200
+ "no_audio_to_save": "❌ 保存するオーディオがありません",
201
+ "save_success": "✅ オーディオとメタデータを {filename} に保存しました",
202
+ "save_failed": "❌ 保存に失敗しました: {error}",
203
+ "no_file_selected": "⚠️ ファイルが選択されていません",
204
+ "params_loaded": "✅ {filename} からパラメータを読み込みました",
205
+ "invalid_json": "❌ 無効なJSONファイル: {error}",
206
+ "load_error": "❌ ファイルの読み込みエラー: {error}",
207
+ "example_loaded": "📁 {filename} からサンプルを読み込みました",
208
+ "example_failed": "JSONファイル {filename} の解析に失敗しました: {error}",
209
+ "example_error": "サンプル読み込みエラー: {error}",
210
+ "lm_generated": "🤖 LMを使用してサンプルを生成しました",
211
+ "lm_fallback": "LMを使用したサンプル生成に失敗、サンプルディレクトリにフォールバック",
212
+ "lm_not_initialized": "❌ 5Hz LMが初期化されていません。最初に初期化してください。",
213
+ "autogen_enabled": "🔄 自動生成が有効 - このあと次のバッチを生成します",
214
+ "batch_ready": "✅ バッチ {n} の準備完了!'次へ'をクリックして表示。",
215
+ "batch_generating": "🔄 バッチ {n} のバックグラウンド生成を開始...",
216
+ "batch_failed": "❌ バックグラウンド生成に失敗しました: {error}",
217
+ "viewing_batch": "✅ バッチ {n} を表示中",
218
+ "at_first_batch": "すでに最初のバッチです",
219
+ "at_last_batch": "次のバッチはありません",
220
+ "batch_not_found": "キューにバッチ {n} が見つかりません",
221
+ "no_batch_data": "復元するバッチデータがありません。",
222
+ "params_restored": "✅ バッチ {n} からUIパラメータを復元しました",
223
+ "scoring_failed": "❌ エラー: バッチデータが見つかりません",
224
+ "no_codes": "❌ 利用可能なオーディオコードがありません。最初に音楽を生成してください。",
225
+ "score_failed": "❌ スコアリングに失敗しました: {error}",
226
+ "score_error": "❌ スコア計算エラー: {error}",
227
+ "lrc_no_batch_data": "❌ バッチデータが見つかりません。最初に音楽を生成してください。",
228
+ "lrc_no_extra_outputs": "❌ 追加出力が見つかりません。条件テンソルが利用できません。",
229
+ "lrc_missing_tensors": "❌ LRC生成に必要なテンソルがありません。",
230
+ "lrc_sample_not_exist": "❌ 現在のバッチにサンプルが存在しません。",
231
+ "lrc_empty_result": "⚠️ LRC生成の結果が空です。",
232
+ "empty_query": "⚠️ 音楽の説明を入力してください。",
233
+ "sample_creation_failed": "❌ サンプルの作成に失敗しました。もう一度お試しください。",
234
+ "sample_created": "✅ サンプルが作成されました!キャプションと歌詞を確認して、音楽を生成をクリックしてください。",
235
+ "simple_examples_not_found": "⚠️ シンプルモードサンプルディレクトリが見つかりません。",
236
+ "simple_examples_empty": "⚠️ シンプルモードサンプルにファイルがありません。",
237
+ "simple_example_loaded": "🎲 {filename} からランダムサンプルを読み込みました",
238
+ "format_success": "✅ キャプションと歌詞のフォーマットに成功しました",
239
+ "format_failed": "❌ フォーマットに失敗しました: {error}",
240
+ "skipping_metas_cot": "⚡ Phase 1 メタデータ COT をスキップ(サンプルは既にフォーマット済み)",
241
+ "invalid_timesteps_format": "⚠️ タイムステップ形式が無効です。デフォルトスケジュールを使用します。",
242
+ "timesteps_out_of_range": "⚠️ タイムステップは [0, 1] の範囲内である必要があります。デフォルトスケジュールを使用します。",
243
+ "timesteps_count_mismatch": "⚠️ タイムステップ数 ({actual}) が推論ステップ数 ({expected}) と異なります。タイムステップ数を使用します。"
244
+ }
245
+ }
acestep/gradio_ui/i18n/zh.json ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "app": {
3
+ "title": "🎛️ ACE-Step V1.5 演练场💡",
4
+ "subtitle": "推动开源音乐生成的边界"
5
+ },
6
+ "dataset": {
7
+ "title": "📊 数据集浏览器",
8
+ "dataset_label": "数据集",
9
+ "dataset_info": "选择要浏览的数据集",
10
+ "import_btn": "📥 导入数据集",
11
+ "search_type_label": "搜索类型",
12
+ "search_type_info": "如何查找项目",
13
+ "search_value_label": "搜索值",
14
+ "search_value_placeholder": "输入键或索引(留空表示随机)",
15
+ "search_value_info": "键: 精确匹配, 索引: 0到数据集大小-1",
16
+ "instruction_label": "📝 指令",
17
+ "instruction_placeholder": "无可用指令",
18
+ "metadata_title": "📋 项目元数据 (JSON)",
19
+ "metadata_label": "完整项目信息",
20
+ "source_audio": "源音频",
21
+ "target_audio": "目标音频",
22
+ "reference_audio": "参考音频",
23
+ "get_item_btn": "🔍 获取项目",
24
+ "use_src_checkbox": "使用数据集中的源音频",
25
+ "use_src_info": "勾选以使用数据集中的源音频",
26
+ "data_status_label": "📊 数据状态",
27
+ "data_status_default": "❌ 未导入数据集",
28
+ "autofill_btn": "📋 自动填充生成表单"
29
+ },
30
+ "service": {
31
+ "title": "🔧 服务配置",
32
+ "checkpoint_label": "检查点文件",
33
+ "checkpoint_info": "选择训练好的模型检查点文件(完整路径或文件名)",
34
+ "refresh_btn": "🔄 刷新",
35
+ "model_path_label": "主模型路径",
36
+ "model_path_info": "选择模型配置目录(从检查点自动扫描)",
37
+ "device_label": "设备",
38
+ "device_info": "处理设备(建议自动检测)",
39
+ "lm_model_path_label": "5Hz LM 模型路径",
40
+ "lm_model_path_info": "选择5Hz LM模型检查点(从检查点自动扫描)",
41
+ "backend_label": "5Hz LM 后端",
42
+ "backend_info": "选择5Hz LM的后端: vllm(更快)或pt(PyTorch, 更兼容)",
43
+ "init_llm_label": "初始化 5Hz LM",
44
+ "init_llm_info": "勾选以在服务初始化期间初始化5Hz LM",
45
+ "flash_attention_label": "使用Flash Attention",
46
+ "flash_attention_info_enabled": "启用flash attention以加快推理速度(需要flash_attn包)",
47
+ "flash_attention_info_disabled": "Flash attention不可用(未安装flash_attn包)",
48
+ "offload_cpu_label": "卸载到CPU",
49
+ "offload_cpu_info": "不使用时将模型卸载到CPU以节省GPU内存",
50
+ "offload_dit_cpu_label": "将DiT卸载到CPU",
51
+ "offload_dit_cpu_info": "将DiT卸载到CPU(需要启用卸载到CPU)",
52
+ "init_btn": "初始化服务",
53
+ "status_label": "状态",
54
+ "language_label": "界面语言",
55
+ "language_info": "选择界面语言"
56
+ },
57
+ "generation": {
58
+ "required_inputs": "📝 必需输入",
59
+ "task_type_label": "任务类型",
60
+ "task_type_info": "选择生成的任务类型",
61
+ "instruction_label": "指令",
62
+ "instruction_info": "指令根据任务类型自动生成",
63
+ "load_btn": "加载",
64
+ "track_name_label": "音轨名称",
65
+ "track_name_info": "为lego/extract任务选择音轨名称",
66
+ "track_classes_label": "音轨名称",
67
+ "track_classes_info": "为complete任务选择多个音轨类别",
68
+ "audio_uploads": "🎵 音频上传",
69
+ "reference_audio": "参考音频(可选)",
70
+ "source_audio": "源音频(可选)",
71
+ "convert_codes_btn": "转换为代码",
72
+ "lm_codes_hints": "🎼 LM 代码提示",
73
+ "lm_codes_label": "LM 代码提示",
74
+ "lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
75
+ "lm_codes_info": "粘贴用于text2music生成的LM代码提示",
76
+ "lm_codes_sample": "LM 代码提示(样本 {n})",
77
+ "lm_codes_sample_info": "样本{n}的代码",
78
+ "transcribe_btn": "转录",
79
+ "repainting_controls": "🎨 重绘控制(秒)",
80
+ "repainting_start": "重绘开始",
81
+ "repainting_end": "重绘结束",
82
+ "mode_label": "生成模式",
83
+ "mode_info": "简单模式:用自然语言描述音乐。自定义模式:完全控制描述和歌词。",
84
+ "mode_simple": "简单",
85
+ "mode_custom": "自定义",
86
+ "simple_query_label": "歌曲描述",
87
+ "simple_query_placeholder": "描述你想创作的音乐,例如:'给我生成一首暗黑的戏剧古风,歌词要华丽'。留空则随机生成样本。",
88
+ "simple_query_info": "输入你想生成的音乐的自然语言描述",
89
+ "simple_vocal_language_label": "人声语言(可选)",
90
+ "simple_vocal_language_info": "选择歌词的首选语言。使用 'unknown' 表示任意语言。",
91
+ "create_sample_btn": "创建样本",
92
+ "caption_title": "📝 音乐描述",
93
+ "caption_label": "音乐描述(可选)",
94
+ "caption_placeholder": "一段平和的原声吉他旋律,配有柔和的人声...",
95
+ "caption_info": "描述风格、流派、乐器和情绪",
96
+ "lyrics_title": "📝 歌词",
97
+ "lyrics_label": "歌词(可选)",
98
+ "lyrics_placeholder": "[第一段]\\n在星空下\\n我感到如此活跃...",
99
+ "lyrics_info": "带有结构的歌曲歌词",
100
+ "instrumental_label": "纯音乐",
101
+ "format_btn": "格式化",
102
+ "optional_params": "⚙️ 可选参数",
103
+ "vocal_language_label": "人声语言(可选)",
104
+ "vocal_language_info": "纯音乐使用 `unknown`",
105
+ "bpm_label": "BPM(可选)",
106
+ "bpm_info": "留空表示N/A",
107
+ "keyscale_label": "调性(可选)",
108
+ "keyscale_placeholder": "留空表示N/A",
109
+ "keyscale_info": "A-G, #/♭, 大调/小调",
110
+ "timesig_label": "拍号(可选)",
111
+ "timesig_info": "2/4, 3/4, 4/4...",
112
+ "duration_label": "音频时长(秒)",
113
+ "duration_info": "使用-1表示随机",
114
+ "batch_size_label": "批量大小",
115
+ "batch_size_info": "要生成的音频数量(最多8个)",
116
+ "advanced_settings": "🔧 高级设置",
117
+ "inference_steps_label": "DiT 推理步数",
118
+ "inference_steps_info": "Turbo: 最多8, Base: 最多200",
119
+ "guidance_scale_label": "DiT 引导比例(仅支持base模型)",
120
+ "guidance_scale_info": "更高的值更紧密地遵循文本",
121
+ "seed_label": "种子",
122
+ "seed_info": "批量使用逗号分隔的值",
123
+ "random_seed_label": "随机种子",
124
+ "random_seed_info": "启用以自动生成种子",
125
+ "audio_format_label": "音频格式",
126
+ "audio_format_info": "保存文件的音频格式",
127
+ "use_adg_label": "使用 ADG",
128
+ "use_adg_info": "启用角域引导",
129
+ "shift_label": "Shift",
130
+ "shift_info": "时间步偏移因子,仅对 base 模型生效 (范围 1.0~5.0,默认 3.0)。对 turbo 模型无效。",
131
+ "infer_method_label": "推理方法",
132
+ "infer_method_info": "扩散推理方法。ODE (欧拉) 更快,SDE (随机) 可能产生不同结果。",
133
+ "custom_timesteps_label": "自定义时间步",
134
+ "custom_timesteps_info": "可选:从 1.0 到 0.0 的逗号分隔值(例如 '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0')。会覆盖推理步数和 shift 设置。",
135
+ "cfg_interval_start": "CFG 间隔开始",
136
+ "cfg_interval_end": "CFG 间隔结束",
137
+ "lm_params_title": "🤖 LM 生成参数",
138
+ "lm_temperature_label": "LM 温度",
139
+ "lm_temperature_info": "5Hz LM温度(越高越随机)",
140
+ "lm_cfg_scale_label": "LM CFG 比例",
141
+ "lm_cfg_scale_info": "5Hz LM CFG (1.0 = 无CFG)",
142
+ "lm_top_k_label": "LM Top-K",
143
+ "lm_top_k_info": "Top-K (0 = 禁用)",
144
+ "lm_top_p_label": "LM Top-P",
145
+ "lm_top_p_info": "Top-P (1.0 = 禁用)",
146
+ "lm_negative_prompt_label": "LM 负面提示",
147
+ "lm_negative_prompt_placeholder": "输入CFG的负面提示(默认: NO USER INPUT)",
148
+ "lm_negative_prompt_info": "负面提示(当LM CFG比例 > 1.0时使用)",
149
+ "cot_metas_label": "CoT 元数据",
150
+ "cot_metas_info": "使用LM生成CoT元数据(取消勾选以跳过LM CoT生成)",
151
+ "cot_language_label": "CoT 语言",
152
+ "cot_language_info": "在CoT中生成语言(思维链)",
153
+ "constrained_debug_label": "约束解码调试",
154
+ "constrained_debug_info": "启用约束解码的调试日志(勾选以查看详细日志)",
155
+ "auto_score_label": "自动评分",
156
+ "auto_score_info": "自动计算所有生成音频的质量分数",
157
+ "auto_lrc_label": "自动 LRC",
158
+ "auto_lrc_info": "自动为所有生成的音频生成LRC歌词时间戳",
159
+ "lm_batch_chunk_label": "LM 批量块大小",
160
+ "lm_batch_chunk_info": "每个LM批量块的最大项目数(默认: 8, 受GPU内存限制)",
161
+ "codes_strength_label": "LM 代码强度",
162
+ "codes_strength_info": "控制使用LM生成代码的去噪步骤数量",
163
+ "cover_strength_label": "音频覆盖强度",
164
+ "cover_strength_info": "控制使用覆盖模式的去噪步骤数量",
165
+ "score_sensitivity_label": "质量评分敏感度",
166
+ "score_sensitivity_info": "更低 = 更敏感(默认: 1.0). 调整PMI如何映射到[0,1]",
167
+ "think_label": "思考",
168
+ "parallel_thinking_label": "并行思考",
169
+ "generate_btn": "🎵 生成音乐",
170
+ "autogen_label": "自动生成",
171
+ "caption_rewrite_label": "描述重写"
172
+ },
173
+ "results": {
174
+ "title": "🎵 结果",
175
+ "generated_music": "🎵 生成的音乐(样本 {n})",
176
+ "send_to_src_btn": "🔗 发送到源音频",
177
+ "send_to_cover_btn": "🔗 Send To Cover",
178
+ "send_to_repaint_btn": "🔗 Send To Repaint",
179
+ "save_btn": "💾 保存",
180
+ "score_btn": "📊 评分",
181
+ "lrc_btn": "🎵 LRC",
182
+ "quality_score_label": "质量分数(样本 {n})",
183
+ "quality_score_placeholder": "点击'评分'以计算基于困惑度的质量分数",
184
+ "codes_label": "LM 代码(样本 {n})",
185
+ "lrc_label": "歌词时间戳(样本 {n})",
186
+ "lrc_placeholder": "点击'LRC'生成时间戳",
187
+ "details_accordion": "📊 评分与LRC与LM代码",
188
+ "generation_status": "生成状态",
189
+ "current_batch": "当前批次",
190
+ "batch_indicator": "批次 {current} / {total}",
191
+ "next_batch_status": "下一批次状态",
192
+ "prev_btn": "◀ 上一个",
193
+ "next_btn": "下一个 ▶",
194
+ "restore_params_btn": "↙️ 将这些设置应用到UI(恢复批次参数)",
195
+ "batch_results_title": "📁 批量结果和生成详情",
196
+ "all_files_label": "📁 所有生成的文件(��载)",
197
+ "generation_details": "生成详情"
198
+ },
199
+ "messages": {
200
+ "no_audio_to_save": "❌ 没有要保存的音频",
201
+ "save_success": "✅ 已将音频和元数据保存到 {filename}",
202
+ "save_failed": "❌ 保存失败: {error}",
203
+ "no_file_selected": "⚠️ 未选择文件",
204
+ "params_loaded": "✅ 已从 {filename} 加载参数",
205
+ "invalid_json": "❌ 无效的JSON文件: {error}",
206
+ "load_error": "❌ 加载文件时出错: {error}",
207
+ "example_loaded": "📁 已从 {filename} 加载示例",
208
+ "example_failed": "解析JSON文件 {filename} 失败: {error}",
209
+ "example_error": "加载示例时出错: {error}",
210
+ "lm_generated": "🤖 使用LM生成的示例",
211
+ "lm_fallback": "使用LM生成示例失败,回退到示例目录",
212
+ "lm_not_initialized": "❌ 5Hz LM未初始化。请先初始化它。",
213
+ "autogen_enabled": "🔄 已启用自动生成 - 下一批次将在此之后生成",
214
+ "batch_ready": "✅ 批次 {n} 就绪!点击'下一个'查看。",
215
+ "batch_generating": "🔄 开始为批次 {n} 进行后台生成...",
216
+ "batch_failed": "❌ 后台生成失败: {error}",
217
+ "viewing_batch": "✅ 查看批次 {n}",
218
+ "at_first_batch": "已在第一批次",
219
+ "at_last_batch": "没有下一批次可用",
220
+ "batch_not_found": "在队列中未找到批次 {n}",
221
+ "no_batch_data": "没有要恢复的批次数据。",
222
+ "params_restored": "✅ 已从批次 {n} 恢复UI参数",
223
+ "scoring_failed": "❌ 错误: 未找到批次数据",
224
+ "no_codes": "❌ 没有可用的音频代码。请先生成音乐。",
225
+ "score_failed": "❌ 评分失败: {error}",
226
+ "score_error": "❌ 计算分数时出错: {error}",
227
+ "lrc_no_batch_data": "❌ 未找到批次数据。请先生成音乐。",
228
+ "lrc_no_extra_outputs": "❌ 未找到额外输出。条件张量不可用。",
229
+ "lrc_missing_tensors": "❌ 缺少LRC生成所需的张量。",
230
+ "lrc_sample_not_exist": "❌ 当前批次中不存在该样本。",
231
+ "lrc_empty_result": "⚠️ LRC生成结果为空。",
232
+ "empty_query": "⚠️ 请输入音乐描述。",
233
+ "sample_creation_failed": "❌ 创建样本失败。请重试。",
234
+ "sample_created": "✅ 样本已创建!检查描述和歌词,然后点击生成音乐。",
235
+ "simple_examples_not_found": "⚠️ 未找到简单模式示例目录。",
236
+ "simple_examples_empty": "⚠️ 简单模式示例中没有示例文件。",
237
+ "simple_example_loaded": "🎲 已从 {filename} 加载随机示例",
238
+ "format_success": "✅ 描述和歌词格式化成功",
239
+ "format_failed": "❌ 格式化失败: {error}",
240
+ "skipping_metas_cot": "⚡ 跳过 Phase 1 元数据 COT(样本已格式化)",
241
+ "invalid_timesteps_format": "⚠️ 时间步格式无效,使用默认调度。",
242
+ "timesteps_out_of_range": "⚠️ 时间步必须在 [0, 1] 范围内,使用默认调度。",
243
+ "timesteps_count_mismatch": "⚠️ 时间步数量 ({actual}) 与推理步数 ({expected}) 不匹配,将使用时间步数量。"
244
+ }
245
+ }
acestep/gradio_ui/interfaces/__init__.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Components Module
3
+ Contains all Gradio interface component definitions and layouts
4
+ """
5
+ import gradio as gr
6
+ from acestep.gradio_ui.i18n import get_i18n, t
7
+ from acestep.gradio_ui.interfaces.dataset import create_dataset_section
8
+ from acestep.gradio_ui.interfaces.generation import create_generation_section
9
+ from acestep.gradio_ui.interfaces.result import create_results_section
10
+ from acestep.gradio_ui.interfaces.training import create_training_section
11
+ from acestep.gradio_ui.events import setup_event_handlers, setup_training_event_handlers
12
+
13
+
14
+ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=None, language='en') -> gr.Blocks:
15
+ """
16
+ Create Gradio interface
17
+
18
+ Args:
19
+ dit_handler: DiT handler instance
20
+ llm_handler: LM handler instance
21
+ dataset_handler: Dataset handler instance
22
+ init_params: Dictionary containing initialization parameters and state.
23
+ If None, service will not be pre-initialized.
24
+ language: UI language code ('en', 'zh', 'ja', default: 'en')
25
+
26
+ Returns:
27
+ Gradio Blocks instance
28
+ """
29
+ # Initialize i18n with selected language
30
+ i18n = get_i18n(language)
31
+
32
+ with gr.Blocks(
33
+ title=t("app.title"),
34
+ theme=gr.themes.Soft(),
35
+ css="""
36
+ .main-header {
37
+ text-align: center;
38
+ margin-bottom: 2rem;
39
+ }
40
+ .section-header {
41
+ background: linear-gradient(90deg, #4CAF50, #45a049);
42
+ color: white;
43
+ padding: 10px;
44
+ border-radius: 5px;
45
+ margin: 10px 0;
46
+ }
47
+ .lm-hints-row {
48
+ align-items: stretch;
49
+ }
50
+ .lm-hints-col {
51
+ display: flex;
52
+ }
53
+ .lm-hints-col > div {
54
+ flex: 1;
55
+ display: flex;
56
+ }
57
+ .lm-hints-btn button {
58
+ height: 100%;
59
+ width: 100%;
60
+ }
61
+ """
62
+ ) as demo:
63
+
64
+ gr.HTML(f"""
65
+ <div class="main-header">
66
+ <h1>{t("app.title")}</h1>
67
+ <p>{t("app.subtitle")}</p>
68
+ <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 10px 20px; border-radius: 8px; text-align: center; margin: 8px auto; max-width: 600px;">
69
+ <span style="color: white; font-size: 15px;">
70
+ 🚀 Want faster &amp; more stable experience? Try
71
+ <a href="https://acemusic.ai" target="_blank" style="color: #ffd700; font-weight: bold; text-decoration: underline;">acemusic.ai</a>
72
+ — 100% free!
73
+ </span>
74
+ </div>
75
+ <p style="margin-top: 0.5rem;">
76
+ <a href="https://ace-step.github.io/ace-step-v1.5.github.io/" target="_blank">Project</a> |
77
+ <a href="https://huggingface.co/collections/ACE-Step/ace-step-15" target="_blank">Hugging Face</a> |
78
+ <a href="https://modelscope.cn/models/ACE-Step/ACE-Step-v1-5" target="_blank">ModelScope</a> |
79
+ <a href="https://github.com/ACE-Step/ACE-Step-1.5" target="_blank">GitHub</a> |
80
+ <a href="https://discord.gg/PeWDxrkdj7" target="_blank">Discord</a> |
81
+ <a href="https://arxiv.org/abs/2602.00744" target="_blank">Technical Report</a>
82
+ </p>
83
+ </div>
84
+ """)
85
+
86
+ # Dataset Explorer Section
87
+ dataset_section = create_dataset_section(dataset_handler)
88
+
89
+ # Generation Section (pass init_params and language to support pre-initialization)
90
+ generation_section = create_generation_section(dit_handler, llm_handler, init_params=init_params, language=language)
91
+
92
+ # Results Section
93
+ results_section = create_results_section(dit_handler)
94
+
95
+ # Training Section (LoRA training and dataset builder)
96
+ # Pass init_params to support hiding in service mode
97
+ training_section = create_training_section(dit_handler, llm_handler, init_params=init_params)
98
+
99
+ # Connect event handlers (pass init_params for multi-model support)
100
+ setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section, init_params=init_params)
101
+
102
+ # Connect training event handlers
103
+ setup_training_event_handlers(demo, dit_handler, llm_handler, training_section)
104
+
105
+ return demo
acestep/gradio_ui/interfaces/dataset.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Dataset Section Module
3
+ Contains dataset explorer section component definitions
4
+ """
5
+ import gradio as gr
6
+
7
+
8
+ def create_dataset_section(dataset_handler) -> dict:
9
+ """Create dataset explorer section"""
10
+ with gr.Accordion("📊 Dataset Explorer", open=False, visible=False):
11
+ with gr.Row(equal_height=True):
12
+ dataset_type = gr.Dropdown(
13
+ choices=["train", "test"],
14
+ value="train",
15
+ label="Dataset",
16
+ info="Choose dataset to explore",
17
+ scale=2
18
+ )
19
+ import_dataset_btn = gr.Button("📥 Import Dataset", variant="primary", scale=1)
20
+
21
+ search_type = gr.Dropdown(
22
+ choices=["keys", "idx", "random"],
23
+ value="random",
24
+ label="Search Type",
25
+ info="How to find items",
26
+ scale=1
27
+ )
28
+ search_value = gr.Textbox(
29
+ label="Search Value",
30
+ placeholder="Enter keys or index (leave empty for random)",
31
+ info="Keys: exact match, Index: 0 to dataset size-1",
32
+ scale=2
33
+ )
34
+
35
+ instruction_display = gr.Textbox(
36
+ label="📝 Instruction",
37
+ interactive=False,
38
+ placeholder="No instruction available",
39
+ lines=1
40
+ )
41
+
42
+ repaint_viz_plot = gr.Plot()
43
+
44
+ with gr.Accordion("📋 Item Metadata (JSON)", open=False):
45
+ item_info_json = gr.Code(
46
+ label="Complete Item Information",
47
+ language="json",
48
+ interactive=False,
49
+ lines=15
50
+ )
51
+
52
+ with gr.Row(equal_height=True):
53
+ item_src_audio = gr.Audio(
54
+ label="Source Audio",
55
+ type="filepath",
56
+ interactive=False,
57
+ scale=8
58
+ )
59
+ get_item_btn = gr.Button("🔍 Get Item", variant="secondary", interactive=False, scale=2)
60
+
61
+ with gr.Row(equal_height=True):
62
+ item_target_audio = gr.Audio(
63
+ label="Target Audio",
64
+ type="filepath",
65
+ interactive=False,
66
+ scale=8
67
+ )
68
+ item_refer_audio = gr.Audio(
69
+ label="Reference Audio",
70
+ type="filepath",
71
+ interactive=False,
72
+ scale=2
73
+ )
74
+
75
+ with gr.Row():
76
+ use_src_checkbox = gr.Checkbox(
77
+ label="Use Source Audio from Dataset",
78
+ value=True,
79
+ info="Check to use the source audio from dataset"
80
+ )
81
+
82
+ data_status = gr.Textbox(label="📊 Data Status", interactive=False, value="❌ No dataset imported")
83
+ auto_fill_btn = gr.Button("📋 Auto-fill Generation Form", variant="primary")
84
+
85
+ return {
86
+ "dataset_type": dataset_type,
87
+ "import_dataset_btn": import_dataset_btn,
88
+ "search_type": search_type,
89
+ "search_value": search_value,
90
+ "instruction_display": instruction_display,
91
+ "repaint_viz_plot": repaint_viz_plot,
92
+ "item_info_json": item_info_json,
93
+ "item_src_audio": item_src_audio,
94
+ "get_item_btn": get_item_btn,
95
+ "item_target_audio": item_target_audio,
96
+ "item_refer_audio": item_refer_audio,
97
+ "use_src_checkbox": use_src_checkbox,
98
+ "data_status": data_status,
99
+ "auto_fill_btn": auto_fill_btn,
100
+ }
101
+
acestep/gradio_ui/interfaces/generation.py ADDED
@@ -0,0 +1,694 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Generation Section Module
3
+ Contains generation section component definitions - Simplified UI
4
+ """
5
+ import gradio as gr
6
+ from acestep.constants import (
7
+ VALID_LANGUAGES,
8
+ TRACK_NAMES,
9
+ TASK_TYPES_TURBO,
10
+ TASK_TYPES_BASE,
11
+ DEFAULT_DIT_INSTRUCTION,
12
+ )
13
+ from acestep.gradio_ui.i18n import t
14
+
15
+
16
+ def create_generation_section(dit_handler, llm_handler, init_params=None, language='en') -> dict:
17
+ """Create generation section with simplified UI
18
+
19
+ Args:
20
+ dit_handler: DiT handler instance
21
+ llm_handler: LM handler instance
22
+ init_params: Dictionary containing initialization parameters and state.
23
+ If None, service will not be pre-initialized.
24
+ language: UI language code ('en', 'zh', 'ja')
25
+ """
26
+ # Check if service is pre-initialized
27
+ service_pre_initialized = init_params is not None and init_params.get('pre_initialized', False)
28
+
29
+ # Check if running in service mode (restricted UI)
30
+ service_mode = init_params is not None and init_params.get('service_mode', False)
31
+
32
+ # Get current language from init_params if available
33
+ current_language = init_params.get('language', language) if init_params else language
34
+
35
+ # Get available models
36
+ available_dit_models = init_params.get('available_dit_models', []) if init_params else []
37
+ current_model_value = init_params.get('config_path', '') if init_params else ''
38
+ show_model_selector = len(available_dit_models) > 1
39
+
40
+ with gr.Group():
41
+ # ==================== Service Configuration (Hidden in service mode) ====================
42
+ accordion_open = not service_pre_initialized
43
+ accordion_visible = not service_pre_initialized
44
+ with gr.Accordion(t("service.title"), open=accordion_open, visible=accordion_visible) as service_config_accordion:
45
+ # Language selector at the top
46
+ with gr.Row():
47
+ language_dropdown = gr.Dropdown(
48
+ choices=[
49
+ ("English", "en"),
50
+ ("中文", "zh"),
51
+ ("日本語", "ja"),
52
+ ],
53
+ value=current_language,
54
+ label=t("service.language_label"),
55
+ info=t("service.language_info"),
56
+ scale=1,
57
+ )
58
+
59
+ with gr.Row(equal_height=True):
60
+ with gr.Column(scale=4):
61
+ checkpoint_value = init_params.get('checkpoint') if service_pre_initialized else None
62
+ checkpoint_dropdown = gr.Dropdown(
63
+ label=t("service.checkpoint_label"),
64
+ choices=dit_handler.get_available_checkpoints(),
65
+ value=checkpoint_value,
66
+ info=t("service.checkpoint_info")
67
+ )
68
+ with gr.Column(scale=1, min_width=90):
69
+ refresh_btn = gr.Button(t("service.refresh_btn"), size="sm")
70
+
71
+ with gr.Row():
72
+ available_models = dit_handler.get_available_acestep_v15_models()
73
+ default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)
74
+ config_path_value = init_params.get('config_path', default_model) if service_pre_initialized else default_model
75
+ config_path = gr.Dropdown(
76
+ label=t("service.model_path_label"),
77
+ choices=available_models,
78
+ value=config_path_value,
79
+ info=t("service.model_path_info")
80
+ )
81
+ device_value = init_params.get('device', 'auto') if service_pre_initialized else 'auto'
82
+ device = gr.Dropdown(
83
+ choices=["auto", "cuda", "cpu"],
84
+ value=device_value,
85
+ label=t("service.device_label"),
86
+ info=t("service.device_info")
87
+ )
88
+
89
+ with gr.Row():
90
+ available_lm_models = llm_handler.get_available_5hz_lm_models()
91
+ default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)
92
+ lm_model_path_value = init_params.get('lm_model_path', default_lm_model) if service_pre_initialized else default_lm_model
93
+ lm_model_path = gr.Dropdown(
94
+ label=t("service.lm_model_path_label"),
95
+ choices=available_lm_models,
96
+ value=lm_model_path_value,
97
+ info=t("service.lm_model_path_info")
98
+ )
99
+ backend_value = init_params.get('backend', 'vllm') if service_pre_initialized else 'vllm'
100
+ backend_dropdown = gr.Dropdown(
101
+ choices=["vllm", "pt"],
102
+ value=backend_value,
103
+ label=t("service.backend_label"),
104
+ info=t("service.backend_info")
105
+ )
106
+
107
+ with gr.Row():
108
+ init_llm_value = init_params.get('init_llm', True) if service_pre_initialized else True
109
+ init_llm_checkbox = gr.Checkbox(
110
+ label=t("service.init_llm_label"),
111
+ value=init_llm_value,
112
+ info=t("service.init_llm_info"),
113
+ )
114
+ flash_attn_available = dit_handler.is_flash_attention_available()
115
+ use_flash_attention_value = init_params.get('use_flash_attention', flash_attn_available) if service_pre_initialized else flash_attn_available
116
+ use_flash_attention_checkbox = gr.Checkbox(
117
+ label=t("service.flash_attention_label"),
118
+ value=use_flash_attention_value,
119
+ interactive=flash_attn_available,
120
+ info=t("service.flash_attention_info_enabled") if flash_attn_available else t("service.flash_attention_info_disabled")
121
+ )
122
+ offload_to_cpu_value = init_params.get('offload_to_cpu', False) if service_pre_initialized else False
123
+ offload_to_cpu_checkbox = gr.Checkbox(
124
+ label=t("service.offload_cpu_label"),
125
+ value=offload_to_cpu_value,
126
+ info=t("service.offload_cpu_info")
127
+ )
128
+ offload_dit_to_cpu_value = init_params.get('offload_dit_to_cpu', False) if service_pre_initialized else False
129
+ offload_dit_to_cpu_checkbox = gr.Checkbox(
130
+ label=t("service.offload_dit_cpu_label"),
131
+ value=offload_dit_to_cpu_value,
132
+ info=t("service.offload_dit_cpu_info")
133
+ )
134
+
135
+ init_btn = gr.Button(t("service.init_btn"), variant="primary", size="lg")
136
+ init_status_value = init_params.get('init_status', '') if service_pre_initialized else ''
137
+ init_status = gr.Textbox(label=t("service.status_label"), interactive=False, lines=3, value=init_status_value)
138
+
139
+ # LoRA Configuration Section
140
+ gr.HTML("<hr><h4>🔧 LoRA Adapter</h4>")
141
+ with gr.Row():
142
+ lora_path = gr.Textbox(
143
+ label="LoRA Path",
144
+ placeholder="./lora_output/final/adapter",
145
+ info="Path to trained LoRA adapter directory",
146
+ scale=3,
147
+ )
148
+ load_lora_btn = gr.Button("📥 Load LoRA", variant="secondary", scale=1)
149
+ unload_lora_btn = gr.Button("🗑️ Unload", variant="secondary", scale=1)
150
+ with gr.Row():
151
+ use_lora_checkbox = gr.Checkbox(
152
+ label="Use LoRA",
153
+ value=False,
154
+ info="Enable LoRA adapter for inference",
155
+ scale=1,
156
+ )
157
+ lora_status = gr.Textbox(
158
+ label="LoRA Status",
159
+ value="No LoRA loaded",
160
+ interactive=False,
161
+ scale=2,
162
+ )
163
+
164
+ # ==================== Model Selector (Top, only when multiple models) ====================
165
+ with gr.Row(visible=show_model_selector):
166
+ dit_model_selector = gr.Dropdown(
167
+ choices=available_dit_models,
168
+ value=current_model_value,
169
+ label="models",
170
+ scale=1,
171
+ )
172
+
173
+ # Hidden dropdown when only one model (for event handler compatibility)
174
+ if not show_model_selector:
175
+ dit_model_selector = gr.Dropdown(
176
+ choices=available_dit_models if available_dit_models else [current_model_value],
177
+ value=current_model_value,
178
+ visible=False,
179
+ )
180
+
181
+ # ==================== Generation Mode (4 modes) ====================
182
+ gr.HTML("<div style='background: #4a5568; color: white; padding: 8px 16px; border-radius: 4px; font-weight: bold;'>Generation Mode</div>")
183
+ with gr.Row():
184
+ generation_mode = gr.Radio(
185
+ choices=[
186
+ ("Simple", "simple"),
187
+ ("Custom", "custom"),
188
+ ("Cover", "cover"),
189
+ ("Repaint", "repaint"),
190
+ ],
191
+ value="custom",
192
+ label="",
193
+ show_label=False,
194
+ )
195
+
196
+ # ==================== Simple Mode Group ====================
197
+ with gr.Column(visible=False) as simple_mode_group:
198
+ # Row: Song Description + Vocal Language + Random button
199
+ with gr.Row(equal_height=True):
200
+ simple_query_input = gr.Textbox(
201
+ label=t("generation.simple_query_label"),
202
+ placeholder=t("generation.simple_query_placeholder"),
203
+ lines=2,
204
+ info=t("generation.simple_query_info"),
205
+ scale=10,
206
+ )
207
+ simple_vocal_language = gr.Dropdown(
208
+ choices=VALID_LANGUAGES,
209
+ value="unknown",
210
+ allow_custom_value=True,
211
+ label=t("generation.simple_vocal_language_label"),
212
+ interactive=True,
213
+ info="use unknown for instrumental",
214
+ scale=2,
215
+ )
216
+ with gr.Column(scale=1, min_width=60):
217
+ random_desc_btn = gr.Button(
218
+ "🎲",
219
+ variant="primary",
220
+ size="lg",
221
+ )
222
+
223
+ # Hidden components (kept for compatibility but not shown)
224
+ simple_instrumental_checkbox = gr.Checkbox(
225
+ label=t("generation.instrumental_label"),
226
+ value=False,
227
+ visible=False,
228
+ )
229
+ create_sample_btn = gr.Button(
230
+ t("generation.create_sample_btn"),
231
+ variant="primary",
232
+ size="lg",
233
+ visible=False,
234
+ )
235
+
236
+ # State to track if sample has been created in Simple mode
237
+ simple_sample_created = gr.State(value=False)
238
+
239
+ # ==================== Source Audio (for Cover/Repaint) ====================
240
+ # This is shown above the main content for Cover and Repaint modes
241
+ with gr.Column(visible=False) as src_audio_group:
242
+ with gr.Row(equal_height=True):
243
+ # Source Audio - scale=10 to match (refer_audio=2 + prompt/lyrics=8)
244
+ src_audio = gr.Audio(
245
+ label="Source Audio",
246
+ type="filepath",
247
+ scale=10,
248
+ )
249
+ # Process button - scale=1 to align with random button
250
+ with gr.Column(scale=1, min_width=80):
251
+ process_src_btn = gr.Button(
252
+ "Analyze",
253
+ variant="secondary",
254
+ size="lg",
255
+ )
256
+
257
+ # Hidden Audio Codes storage (needed internally but not displayed)
258
+ text2music_audio_code_string = gr.Textbox(
259
+ label="Audio Codes",
260
+ visible=False,
261
+ )
262
+
263
+ # ==================== Custom/Cover/Repaint Mode Content ====================
264
+ with gr.Column() as custom_mode_content:
265
+ with gr.Row(equal_height=True):
266
+ # Left: Reference Audio
267
+ with gr.Column(scale=2, min_width=200):
268
+ reference_audio = gr.Audio(
269
+ label="Reference Audio (optional)",
270
+ type="filepath",
271
+ show_label=True,
272
+ )
273
+
274
+ # Middle: Prompt + Lyrics + Format button
275
+ with gr.Column(scale=8):
276
+ # Row 1: Prompt and Lyrics
277
+ with gr.Row(equal_height=True):
278
+ captions = gr.Textbox(
279
+ label="Prompt",
280
+ placeholder="Describe the music style, mood, instruments...",
281
+ lines=12,
282
+ max_lines=12,
283
+ scale=1,
284
+ )
285
+ lyrics = gr.Textbox(
286
+ label="Lyrics",
287
+ placeholder="Enter lyrics here... Use [Verse], [Chorus] etc. for structure",
288
+ lines=12,
289
+ max_lines=12,
290
+ scale=1,
291
+ )
292
+
293
+ # Row 2: Format button (only below Prompt and Lyrics)
294
+ format_btn = gr.Button(
295
+ "Format",
296
+ variant="secondary",
297
+ )
298
+
299
+ # Right: Random button
300
+ with gr.Column(scale=1, min_width=60):
301
+ sample_btn = gr.Button(
302
+ "🎲",
303
+ variant="primary",
304
+ size="lg",
305
+ )
306
+
307
+ # Placeholder for removed audio_uploads_accordion (for compatibility)
308
+ audio_uploads_accordion = gr.Column(visible=False)
309
+
310
+ # Legacy cover_mode_group (hidden, for backward compatibility)
311
+ cover_mode_group = gr.Column(visible=False)
312
+ # Legacy convert button (hidden, for backward compatibility)
313
+ convert_src_to_codes_btn = gr.Button("Convert to Codes", visible=False)
314
+
315
+ # ==================== Repaint Mode: Source + Time Range ====================
316
+ with gr.Column(visible=False) as repainting_group:
317
+ with gr.Row():
318
+ repainting_start = gr.Number(
319
+ label="Start (seconds)",
320
+ value=0.0,
321
+ step=0.1,
322
+ scale=1,
323
+ )
324
+ repainting_end = gr.Number(
325
+ label="End (seconds, -1 for end)",
326
+ value=-1,
327
+ minimum=-1,
328
+ step=0.1,
329
+ scale=1,
330
+ )
331
+
332
+ # ==================== Optional Parameters ====================
333
+ with gr.Accordion("⚙️ Optional Parameters", open=False, visible=False) as optional_params_accordion:
334
+ pass
335
+
336
+ # ==================== Advanced Settings ====================
337
+ with gr.Accordion("🔧 Advanced Settings", open=False) as advanced_options_accordion:
338
+ with gr.Row():
339
+ bpm = gr.Number(
340
+ label="BPM (optional)",
341
+ value=0,
342
+ step=1,
343
+ info="leave empty for N/A",
344
+ scale=1,
345
+ )
346
+ key_scale = gr.Textbox(
347
+ label="Key Signature (optional)",
348
+ placeholder="Leave empty for N/A",
349
+ value="",
350
+ info="A-G, #/♭, major/minor",
351
+ scale=1,
352
+ )
353
+ time_signature = gr.Dropdown(
354
+ choices=["", "2", "3", "4"],
355
+ value="",
356
+ label="Time Signature (optional)",
357
+ allow_custom_value=True,
358
+ info="2/4, 3/4, 4/4...",
359
+ scale=1,
360
+ )
361
+ audio_duration = gr.Number(
362
+ label="Audio Duration (seconds)",
363
+ value=-1,
364
+ minimum=-1,
365
+ maximum=600.0,
366
+ step=1,
367
+ info="Use -1 for auto, or 10-600 seconds",
368
+ scale=1,
369
+ )
370
+ vocal_language = gr.Dropdown(
371
+ choices=VALID_LANGUAGES,
372
+ value="unknown",
373
+ label="Vocal Language",
374
+ allow_custom_value=True,
375
+ info="use `unknown` for instrumental",
376
+ scale=1,
377
+ )
378
+ batch_size_input = gr.Number(
379
+ label="batch size",
380
+ info="max 8",
381
+ value=2,
382
+ minimum=1,
383
+ maximum=8,
384
+ step=1,
385
+ scale=1,
386
+ interactive=False,
387
+ )
388
+
389
+ # Row 1: DiT Inference Steps, Seed, Audio Format
390
+ with gr.Row():
391
+ inference_steps = gr.Slider(
392
+ minimum=1,
393
+ maximum=20,
394
+ value=8,
395
+ step=1,
396
+ label="DiT Inference Steps",
397
+ info="Turbo: max 8, Base: max 200",
398
+ )
399
+ seed = gr.Textbox(
400
+ label="Seed",
401
+ value="-1",
402
+ info="Use comma-separated values for batches",
403
+ )
404
+ audio_format = gr.Dropdown(
405
+ choices=["mp3", "flac"],
406
+ value="mp3",
407
+ label="Audio Format",
408
+ info="Audio format for saved files",
409
+ )
410
+
411
+ # Row 2: Shift, Random Seed, Inference Method
412
+ with gr.Row():
413
+ shift = gr.Slider(
414
+ minimum=1.0,
415
+ maximum=5.0,
416
+ value=3.0,
417
+ step=0.1,
418
+ label="Shift",
419
+ info="Timestep shift factor for base models (range 1.0-5.0, default 3.0). Not effective for turbo models.",
420
+ )
421
+ random_seed_checkbox = gr.Checkbox(
422
+ label="Random Seed",
423
+ value=True,
424
+ info="Enable to auto-generate seeds",
425
+ )
426
+ infer_method = gr.Dropdown(
427
+ choices=["ode", "sde"],
428
+ value="ode",
429
+ label="Inference Method",
430
+ info="Diffusion inference method. ODE (Euler) is faster, SDE (stochastic) may produce different results.",
431
+ )
432
+
433
+ # Row 3: Custom Timesteps (full width)
434
+ custom_timesteps = gr.Textbox(
435
+ label="Custom Timesteps",
436
+ placeholder="0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0",
437
+ value="",
438
+ info="Optional: comma-separated values from 1.0 to 0.0 (e.g., '0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0'). Overrides inference steps and shift.",
439
+ )
440
+
441
+ # Section: LM Generation Parameters
442
+ gr.HTML("<h4>🎵 LM Generation Parameters</h4>")
443
+
444
+ # Row 4: LM Temperature, LM CFG Scale, LM Top-K, LM Top-P
445
+ with gr.Row():
446
+ lm_temperature = gr.Slider(
447
+ minimum=0.0,
448
+ maximum=2.0,
449
+ value=0.85,
450
+ step=0.05,
451
+ label="LM Temperature",
452
+ info="5Hz LM temperature (higher = more random)",
453
+ )
454
+ lm_cfg_scale = gr.Slider(
455
+ minimum=1.0,
456
+ maximum=3.0,
457
+ value=2.0,
458
+ step=0.1,
459
+ label="LM CFG Scale",
460
+ info="5Hz LM CFG (1.0 = no CFG)",
461
+ )
462
+ lm_top_k = gr.Slider(
463
+ minimum=0,
464
+ maximum=100,
465
+ value=0,
466
+ step=1,
467
+ label="LM Top-K",
468
+ info="Top-k (0 = disabled)",
469
+ )
470
+ lm_top_p = gr.Slider(
471
+ minimum=0.0,
472
+ maximum=1.0,
473
+ value=0.9,
474
+ step=0.01,
475
+ label="LM Top-P",
476
+ info="Top-p (1.0 = disabled)",
477
+ )
478
+
479
+ # Row 5: LM Negative Prompt (full width)
480
+ lm_negative_prompt = gr.Textbox(
481
+ label="LM Negative Prompt",
482
+ value="NO USER INPUT",
483
+ placeholder="Things to avoid in generation...",
484
+ lines=2,
485
+ info="Negative prompt (use when LM CFG Scale > 1.0)",
486
+ )
487
+ # audio_cover_strength remains hidden for now
488
+ audio_cover_strength = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, visible=False)
489
+
490
+ # Note: audio_duration, bpm, key_scale, time_signature are now visible in Optional Parameters
491
+ # ==================== Generate Button Row ====================
492
+ generate_btn_interactive = init_params.get('enable_generate', False) if service_pre_initialized else False
493
+ with gr.Row(equal_height=True):
494
+ # Left: Thinking and Instrumental checkboxes
495
+ with gr.Column(scale=1, min_width=120):
496
+ think_checkbox = gr.Checkbox(
497
+ label="Thinking",
498
+ value=True,
499
+ )
500
+ instrumental_checkbox = gr.Checkbox(
501
+ label="Instrumental",
502
+ value=False,
503
+ )
504
+
505
+ # Center: Generate button
506
+ with gr.Column(scale=4):
507
+ generate_btn = gr.Button(
508
+ "🎵 Generate Music",
509
+ variant="primary",
510
+ size="lg",
511
+ interactive=generate_btn_interactive,
512
+ )
513
+
514
+ # Right: auto_score, auto_lrc
515
+ with gr.Column(scale=1, min_width=120):
516
+ auto_score = gr.Checkbox(
517
+ label="Get Scores",
518
+ value=False,
519
+ )
520
+ auto_lrc = gr.Checkbox(
521
+ label="Get LRC",
522
+ value=False,
523
+ )
524
+
525
+ # ==================== Hidden Components (for internal use) ====================
526
+ # These are needed for event handlers but not shown in UI
527
+
528
+ # Task type (set automatically based on generation_mode)
529
+ actual_model = init_params.get('config_path', 'acestep-v15-turbo') if service_pre_initialized else 'acestep-v15-turbo'
530
+ actual_model_lower = (actual_model or "").lower()
531
+ if "turbo" in actual_model_lower:
532
+ initial_task_choices = TASK_TYPES_TURBO
533
+ else:
534
+ initial_task_choices = TASK_TYPES_BASE
535
+
536
+ task_type = gr.Dropdown(
537
+ choices=initial_task_choices,
538
+ value="text2music",
539
+ visible=False,
540
+ )
541
+
542
+ instruction_display_gen = gr.Textbox(
543
+ value=DEFAULT_DIT_INSTRUCTION,
544
+ visible=False,
545
+ )
546
+
547
+ track_name = gr.Dropdown(
548
+ choices=TRACK_NAMES,
549
+ value=None,
550
+ visible=False,
551
+ )
552
+
553
+ complete_track_classes = gr.CheckboxGroup(
554
+ choices=TRACK_NAMES,
555
+ visible=False,
556
+ )
557
+
558
+ # Note: lyrics, vocal_language, instrumental_checkbox, format_btn are now visible in custom_mode_content
559
+
560
+ # Hidden advanced settings (keep defaults)
561
+ # Note: Most parameters are now visible in Advanced Settings section above
562
+ guidance_scale = gr.Slider(value=7.0, visible=False)
563
+ use_adg = gr.Checkbox(value=False, visible=False)
564
+ cfg_interval_start = gr.Slider(value=0.0, visible=False)
565
+ cfg_interval_end = gr.Slider(value=1.0, visible=False)
566
+
567
+ # LM parameters (remaining hidden ones)
568
+ use_cot_metas = gr.Checkbox(value=True, visible=False)
569
+ use_cot_caption = gr.Checkbox(value=True, visible=False)
570
+ use_cot_language = gr.Checkbox(value=True, visible=False)
571
+ constrained_decoding_debug = gr.Checkbox(value=False, visible=False)
572
+ allow_lm_batch = gr.Checkbox(value=True, visible=False)
573
+ lm_batch_chunk_size = gr.Number(value=8, visible=False)
574
+ score_scale = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, visible=False)
575
+ autogen_checkbox = gr.Checkbox(value=False, visible=False)
576
+
577
+ # Transcribe button (hidden)
578
+ transcribe_btn = gr.Button(value="Transcribe", visible=False)
579
+ text2music_audio_codes_group = gr.Group(visible=False)
580
+
581
+ # Note: format_btn is now visible in custom_mode_content
582
+
583
+ # Load file button (hidden for now)
584
+ load_file = gr.UploadButton(
585
+ label="Load",
586
+ file_types=[".json"],
587
+ file_count="single",
588
+ visible=False,
589
+ )
590
+
591
+ # Caption/Lyrics accordions (not used in new UI but needed for compatibility)
592
+ caption_accordion = gr.Accordion("Caption", visible=False)
593
+ lyrics_accordion = gr.Accordion("Lyrics", visible=False)
594
+ # Note: optional_params_accordion is now visible above
595
+
596
+ return {
597
+ "service_config_accordion": service_config_accordion,
598
+ "language_dropdown": language_dropdown,
599
+ "checkpoint_dropdown": checkpoint_dropdown,
600
+ "refresh_btn": refresh_btn,
601
+ "config_path": config_path,
602
+ "device": device,
603
+ "init_btn": init_btn,
604
+ "init_status": init_status,
605
+ "lm_model_path": lm_model_path,
606
+ "init_llm_checkbox": init_llm_checkbox,
607
+ "backend_dropdown": backend_dropdown,
608
+ "use_flash_attention_checkbox": use_flash_attention_checkbox,
609
+ "offload_to_cpu_checkbox": offload_to_cpu_checkbox,
610
+ "offload_dit_to_cpu_checkbox": offload_dit_to_cpu_checkbox,
611
+ # LoRA components
612
+ "lora_path": lora_path,
613
+ "load_lora_btn": load_lora_btn,
614
+ "unload_lora_btn": unload_lora_btn,
615
+ "use_lora_checkbox": use_lora_checkbox,
616
+ "lora_status": lora_status,
617
+ # DiT model selector
618
+ "dit_model_selector": dit_model_selector,
619
+ "task_type": task_type,
620
+ "instruction_display_gen": instruction_display_gen,
621
+ "track_name": track_name,
622
+ "complete_track_classes": complete_track_classes,
623
+ "audio_uploads_accordion": audio_uploads_accordion,
624
+ "reference_audio": reference_audio,
625
+ "src_audio": src_audio,
626
+ "convert_src_to_codes_btn": convert_src_to_codes_btn,
627
+ "text2music_audio_code_string": text2music_audio_code_string,
628
+ "transcribe_btn": transcribe_btn,
629
+ "text2music_audio_codes_group": text2music_audio_codes_group,
630
+ "lm_temperature": lm_temperature,
631
+ "lm_cfg_scale": lm_cfg_scale,
632
+ "lm_top_k": lm_top_k,
633
+ "lm_top_p": lm_top_p,
634
+ "lm_negative_prompt": lm_negative_prompt,
635
+ "use_cot_metas": use_cot_metas,
636
+ "use_cot_caption": use_cot_caption,
637
+ "use_cot_language": use_cot_language,
638
+ "repainting_group": repainting_group,
639
+ "repainting_start": repainting_start,
640
+ "repainting_end": repainting_end,
641
+ "audio_cover_strength": audio_cover_strength,
642
+ # Generation mode components
643
+ "generation_mode": generation_mode,
644
+ "simple_mode_group": simple_mode_group,
645
+ "simple_query_input": simple_query_input,
646
+ "random_desc_btn": random_desc_btn,
647
+ "simple_instrumental_checkbox": simple_instrumental_checkbox,
648
+ "simple_vocal_language": simple_vocal_language,
649
+ "create_sample_btn": create_sample_btn,
650
+ "simple_sample_created": simple_sample_created,
651
+ "caption_accordion": caption_accordion,
652
+ "lyrics_accordion": lyrics_accordion,
653
+ "optional_params_accordion": optional_params_accordion,
654
+ # Custom mode components
655
+ "custom_mode_content": custom_mode_content,
656
+ "cover_mode_group": cover_mode_group,
657
+ # Source audio group for Cover/Repaint
658
+ "src_audio_group": src_audio_group,
659
+ "process_src_btn": process_src_btn,
660
+ "advanced_options_accordion": advanced_options_accordion,
661
+ # Existing components
662
+ "captions": captions,
663
+ "sample_btn": sample_btn,
664
+ "load_file": load_file,
665
+ "lyrics": lyrics,
666
+ "vocal_language": vocal_language,
667
+ "bpm": bpm,
668
+ "key_scale": key_scale,
669
+ "time_signature": time_signature,
670
+ "audio_duration": audio_duration,
671
+ "batch_size_input": batch_size_input,
672
+ "inference_steps": inference_steps,
673
+ "guidance_scale": guidance_scale,
674
+ "seed": seed,
675
+ "random_seed_checkbox": random_seed_checkbox,
676
+ "use_adg": use_adg,
677
+ "cfg_interval_start": cfg_interval_start,
678
+ "cfg_interval_end": cfg_interval_end,
679
+ "shift": shift,
680
+ "infer_method": infer_method,
681
+ "custom_timesteps": custom_timesteps,
682
+ "audio_format": audio_format,
683
+ "think_checkbox": think_checkbox,
684
+ "autogen_checkbox": autogen_checkbox,
685
+ "generate_btn": generate_btn,
686
+ "instrumental_checkbox": instrumental_checkbox,
687
+ "format_btn": format_btn,
688
+ "constrained_decoding_debug": constrained_decoding_debug,
689
+ "score_scale": score_scale,
690
+ "allow_lm_batch": allow_lm_batch,
691
+ "auto_score": auto_score,
692
+ "auto_lrc": auto_lrc,
693
+ "lm_batch_chunk_size": lm_batch_chunk_size,
694
+ }
acestep/gradio_ui/interfaces/result.py ADDED
@@ -0,0 +1,598 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Results Section Module
3
+ Contains results display section component definitions
4
+ """
5
+ import gradio as gr
6
+ from acestep.gradio_ui.i18n import t
7
+
8
+
9
+ def create_results_section(dit_handler) -> dict:
10
+ """Create results display section"""
11
+ with gr.Accordion(t("results.title"), open=True):
12
+ # Hidden state to store LM-generated metadata
13
+ lm_metadata_state = gr.State(value=None)
14
+
15
+ # Hidden state to track if caption/metadata is from formatted source (LM/transcription)
16
+ is_format_caption_state = gr.State(value=False)
17
+
18
+ # Batch management states
19
+ current_batch_index = gr.State(value=0) # Currently displayed batch index
20
+ total_batches = gr.State(value=1) # Total number of batches generated
21
+ batch_queue = gr.State(value={}) # Dictionary storing all batch data
22
+ generation_params_state = gr.State(value={}) # Store generation parameters for next batches
23
+ is_generating_background = gr.State(value=False) # Background generation flag
24
+
25
+ # All audio components in one row with dynamic visibility
26
+ with gr.Row():
27
+ with gr.Column(visible=True) as audio_col_1:
28
+ generated_audio_1 = gr.Audio(
29
+ label=t("results.generated_music", n=1),
30
+ type="filepath",
31
+ interactive=False,
32
+ buttons=[]
33
+ )
34
+ with gr.Row(equal_height=True):
35
+ send_to_cover_btn_1 = gr.Button(
36
+ t("results.send_to_cover_btn"),
37
+ variant="secondary",
38
+ size="sm",
39
+ scale=1
40
+ )
41
+ send_to_repaint_btn_1 = gr.Button(
42
+ t("results.send_to_repaint_btn"),
43
+ variant="secondary",
44
+ size="sm",
45
+ scale=1
46
+ )
47
+ save_btn_1 = gr.Button(
48
+ t("results.save_btn"),
49
+ variant="primary",
50
+ size="sm",
51
+ scale=1
52
+ )
53
+ score_btn_1 = gr.Button(
54
+ t("results.score_btn"),
55
+ variant="secondary",
56
+ size="sm",
57
+ scale=1,
58
+ visible=False
59
+ )
60
+ lrc_btn_1 = gr.Button(
61
+ t("results.lrc_btn"),
62
+ variant="secondary",
63
+ size="sm",
64
+ scale=1,
65
+ visible=False
66
+ )
67
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_1:
68
+ score_display_1 = gr.Textbox(
69
+ label=t("results.quality_score_label", n=1),
70
+ interactive=False,
71
+ buttons=["copy"],
72
+ lines=6,
73
+ max_lines=6,
74
+ visible=True
75
+ )
76
+ lrc_display_1 = gr.Textbox(
77
+ label=t("results.lrc_label", n=1),
78
+ interactive=True,
79
+ buttons=["copy"],
80
+ lines=8,
81
+ max_lines=8,
82
+ visible=True
83
+ )
84
+ codes_display_1 = gr.Textbox(
85
+ label=t("results.codes_label", n=1),
86
+ interactive=False,
87
+ buttons=["copy"],
88
+ lines=4,
89
+ max_lines=4,
90
+ visible=True
91
+ )
92
+ with gr.Column(visible=True) as audio_col_2:
93
+ generated_audio_2 = gr.Audio(
94
+ label=t("results.generated_music", n=2),
95
+ type="filepath",
96
+ interactive=False,
97
+ buttons=[]
98
+ )
99
+ with gr.Row(equal_height=True):
100
+ send_to_cover_btn_2 = gr.Button(
101
+ t("results.send_to_cover_btn"),
102
+ variant="secondary",
103
+ size="sm",
104
+ scale=1
105
+ )
106
+ send_to_repaint_btn_2 = gr.Button(
107
+ t("results.send_to_repaint_btn"),
108
+ variant="secondary",
109
+ size="sm",
110
+ scale=1
111
+ )
112
+ save_btn_2 = gr.Button(
113
+ t("results.save_btn"),
114
+ variant="primary",
115
+ size="sm",
116
+ scale=1
117
+ )
118
+ score_btn_2 = gr.Button(
119
+ t("results.score_btn"),
120
+ variant="secondary",
121
+ size="sm",
122
+ scale=1,
123
+ visible=False
124
+ )
125
+ lrc_btn_2 = gr.Button(
126
+ t("results.lrc_btn"),
127
+ variant="secondary",
128
+ size="sm",
129
+ scale=1,
130
+ visible=False
131
+ )
132
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_2:
133
+ score_display_2 = gr.Textbox(
134
+ label=t("results.quality_score_label", n=2),
135
+ interactive=False,
136
+ buttons=["copy"],
137
+ lines=6,
138
+ max_lines=6,
139
+ visible=True
140
+ )
141
+ lrc_display_2 = gr.Textbox(
142
+ label=t("results.lrc_label", n=2),
143
+ interactive=True,
144
+ buttons=["copy"],
145
+ lines=8,
146
+ max_lines=8,
147
+ visible=True
148
+ )
149
+ codes_display_2 = gr.Textbox(
150
+ label=t("results.codes_label", n=2),
151
+ interactive=False,
152
+ buttons=["copy"],
153
+ lines=4,
154
+ max_lines=4,
155
+ visible=True
156
+ )
157
+ with gr.Column(visible=False) as audio_col_3:
158
+ generated_audio_3 = gr.Audio(
159
+ label=t("results.generated_music", n=3),
160
+ type="filepath",
161
+ interactive=False,
162
+ buttons=[]
163
+ )
164
+ with gr.Row(equal_height=True):
165
+ send_to_cover_btn_3 = gr.Button(
166
+ t("results.send_to_cover_btn"),
167
+ variant="secondary",
168
+ size="sm",
169
+ scale=1
170
+ )
171
+ send_to_repaint_btn_3 = gr.Button(
172
+ t("results.send_to_repaint_btn"),
173
+ variant="secondary",
174
+ size="sm",
175
+ scale=1
176
+ )
177
+ save_btn_3 = gr.Button(
178
+ t("results.save_btn"),
179
+ variant="primary",
180
+ size="sm",
181
+ scale=1
182
+ )
183
+ score_btn_3 = gr.Button(
184
+ t("results.score_btn"),
185
+ variant="secondary",
186
+ size="sm",
187
+ scale=1,
188
+ visible=False
189
+ )
190
+ lrc_btn_3 = gr.Button(
191
+ t("results.lrc_btn"),
192
+ variant="secondary",
193
+ size="sm",
194
+ scale=1,
195
+ visible=False
196
+ )
197
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_3:
198
+ score_display_3 = gr.Textbox(
199
+ label=t("results.quality_score_label", n=3),
200
+ interactive=False,
201
+ buttons=["copy"],
202
+ lines=6,
203
+ max_lines=6,
204
+ visible=True
205
+ )
206
+ lrc_display_3 = gr.Textbox(
207
+ label=t("results.lrc_label", n=3),
208
+ interactive=True,
209
+ buttons=["copy"],
210
+ lines=8,
211
+ max_lines=8,
212
+ visible=True
213
+ )
214
+ codes_display_3 = gr.Textbox(
215
+ label=t("results.codes_label", n=3),
216
+ interactive=False,
217
+ buttons=["copy"],
218
+ lines=4,
219
+ max_lines=4,
220
+ visible=True
221
+ )
222
+ with gr.Column(visible=False) as audio_col_4:
223
+ generated_audio_4 = gr.Audio(
224
+ label=t("results.generated_music", n=4),
225
+ type="filepath",
226
+ interactive=False,
227
+ buttons=[]
228
+ )
229
+ with gr.Row(equal_height=True):
230
+ send_to_cover_btn_4 = gr.Button(
231
+ t("results.send_to_cover_btn"),
232
+ variant="secondary",
233
+ size="sm",
234
+ scale=1
235
+ )
236
+ send_to_repaint_btn_4 = gr.Button(
237
+ t("results.send_to_repaint_btn"),
238
+ variant="secondary",
239
+ size="sm",
240
+ scale=1
241
+ )
242
+ save_btn_4 = gr.Button(
243
+ t("results.save_btn"),
244
+ variant="primary",
245
+ size="sm",
246
+ scale=1
247
+ )
248
+ score_btn_4 = gr.Button(
249
+ t("results.score_btn"),
250
+ variant="secondary",
251
+ size="sm",
252
+ scale=1,
253
+ visible=False
254
+ )
255
+ lrc_btn_4 = gr.Button(
256
+ t("results.lrc_btn"),
257
+ variant="secondary",
258
+ size="sm",
259
+ scale=1,
260
+ visible=False
261
+ )
262
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_4:
263
+ score_display_4 = gr.Textbox(
264
+ label=t("results.quality_score_label", n=4),
265
+ interactive=False,
266
+ buttons=["copy"],
267
+ lines=6,
268
+ max_lines=6,
269
+ visible=True
270
+ )
271
+ lrc_display_4 = gr.Textbox(
272
+ label=t("results.lrc_label", n=4),
273
+ interactive=True,
274
+ buttons=["copy"],
275
+ lines=8,
276
+ max_lines=8,
277
+ visible=True
278
+ )
279
+ codes_display_4 = gr.Textbox(
280
+ label=t("results.codes_label", n=4),
281
+ interactive=False,
282
+ buttons=["copy"],
283
+ lines=4,
284
+ max_lines=4,
285
+ visible=True
286
+ )
287
+
288
+ # Second row for batch size 5-8 (initially hidden)
289
+ with gr.Row(visible=False) as audio_row_5_8:
290
+ with gr.Column() as audio_col_5:
291
+ generated_audio_5 = gr.Audio(
292
+ label=t("results.generated_music", n=5),
293
+ type="filepath",
294
+ interactive=False,
295
+ buttons=[]
296
+ )
297
+ with gr.Row(equal_height=True):
298
+ send_to_cover_btn_5 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
299
+ send_to_repaint_btn_5 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
300
+ save_btn_5 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
301
+ score_btn_5 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
302
+ lrc_btn_5 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
303
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_5:
304
+ score_display_5 = gr.Textbox(
305
+ label=t("results.quality_score_label", n=5),
306
+ interactive=False,
307
+ buttons=["copy"],
308
+ lines=6,
309
+ max_lines=6,
310
+ visible=True
311
+ )
312
+ lrc_display_5 = gr.Textbox(
313
+ label=t("results.lrc_label", n=5),
314
+ interactive=True,
315
+ buttons=["copy"],
316
+ lines=8,
317
+ max_lines=8,
318
+ visible=True
319
+ )
320
+ codes_display_5 = gr.Textbox(
321
+ label=t("results.codes_label", n=5),
322
+ interactive=False,
323
+ buttons=["copy"],
324
+ lines=4,
325
+ max_lines=4,
326
+ visible=True
327
+ )
328
+ with gr.Column() as audio_col_6:
329
+ generated_audio_6 = gr.Audio(
330
+ label=t("results.generated_music", n=6),
331
+ type="filepath",
332
+ interactive=False,
333
+ buttons=[]
334
+ )
335
+ with gr.Row(equal_height=True):
336
+ send_to_cover_btn_6 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
337
+ send_to_repaint_btn_6 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
338
+ save_btn_6 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
339
+ score_btn_6 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
340
+ lrc_btn_6 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
341
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_6:
342
+ score_display_6 = gr.Textbox(
343
+ label=t("results.quality_score_label", n=6),
344
+ interactive=False,
345
+ buttons=["copy"],
346
+ lines=6,
347
+ max_lines=6,
348
+ visible=True
349
+ )
350
+ lrc_display_6 = gr.Textbox(
351
+ label=t("results.lrc_label", n=6),
352
+ interactive=True,
353
+ buttons=["copy"],
354
+ lines=8,
355
+ max_lines=8,
356
+ visible=True
357
+ )
358
+ codes_display_6 = gr.Textbox(
359
+ label=t("results.codes_label", n=6),
360
+ interactive=False,
361
+ buttons=["copy"],
362
+ lines=4,
363
+ max_lines=4,
364
+ visible=True
365
+ )
366
+ with gr.Column() as audio_col_7:
367
+ generated_audio_7 = gr.Audio(
368
+ label=t("results.generated_music", n=7),
369
+ type="filepath",
370
+ interactive=False,
371
+ buttons=[]
372
+ )
373
+ with gr.Row(equal_height=True):
374
+ send_to_cover_btn_7 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
375
+ send_to_repaint_btn_7 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
376
+ save_btn_7 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
377
+ score_btn_7 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
378
+ lrc_btn_7 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
379
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_7:
380
+ score_display_7 = gr.Textbox(
381
+ label=t("results.quality_score_label", n=7),
382
+ interactive=False,
383
+ buttons=["copy"],
384
+ lines=6,
385
+ max_lines=6,
386
+ visible=True
387
+ )
388
+ lrc_display_7 = gr.Textbox(
389
+ label=t("results.lrc_label", n=7),
390
+ interactive=True,
391
+ buttons=["copy"],
392
+ lines=8,
393
+ max_lines=8,
394
+ visible=True
395
+ )
396
+ codes_display_7 = gr.Textbox(
397
+ label=t("results.codes_label", n=7),
398
+ interactive=False,
399
+ buttons=["copy"],
400
+ lines=4,
401
+ max_lines=4,
402
+ visible=True
403
+ )
404
+ with gr.Column() as audio_col_8:
405
+ generated_audio_8 = gr.Audio(
406
+ label=t("results.generated_music", n=8),
407
+ type="filepath",
408
+ interactive=False,
409
+ buttons=[]
410
+ )
411
+ with gr.Row(equal_height=True):
412
+ send_to_cover_btn_8 = gr.Button(t("results.send_to_cover_btn"), variant="secondary", size="sm", scale=1)
413
+ send_to_repaint_btn_8 = gr.Button(t("results.send_to_repaint_btn"), variant="secondary", size="sm", scale=1)
414
+ save_btn_8 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
415
+ score_btn_8 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1, visible=False)
416
+ lrc_btn_8 = gr.Button(t("results.lrc_btn"), variant="secondary", size="sm", scale=1, visible=False)
417
+ with gr.Accordion(t("results.details_accordion"), open=False, visible=True) as details_accordion_8:
418
+ score_display_8 = gr.Textbox(
419
+ label=t("results.quality_score_label", n=8),
420
+ interactive=False,
421
+ buttons=["copy"],
422
+ lines=6,
423
+ max_lines=6,
424
+ visible=True
425
+ )
426
+ lrc_display_8 = gr.Textbox(
427
+ label=t("results.lrc_label", n=8),
428
+ interactive=True,
429
+ buttons=["copy"],
430
+ lines=8,
431
+ max_lines=8,
432
+ visible=True
433
+ )
434
+ codes_display_8 = gr.Textbox(
435
+ label=t("results.codes_label", n=8),
436
+ interactive=False,
437
+ buttons=["copy"],
438
+ lines=4,
439
+ max_lines=4,
440
+ visible=True
441
+ )
442
+
443
+ status_output = gr.Textbox(label=t("results.generation_status"), interactive=False)
444
+
445
+ # Batch navigation controls (hidden for simplified UI)
446
+ with gr.Row(equal_height=True, visible=False):
447
+ prev_batch_btn = gr.Button(
448
+ t("results.prev_btn"),
449
+ variant="secondary",
450
+ interactive=False,
451
+ scale=1,
452
+ size="sm"
453
+ )
454
+ batch_indicator = gr.Textbox(
455
+ label=t("results.current_batch"),
456
+ value=t("results.batch_indicator", current=1, total=1),
457
+ interactive=False,
458
+ scale=3
459
+ )
460
+ next_batch_status = gr.Textbox(
461
+ label=t("results.next_batch_status"),
462
+ value="",
463
+ interactive=False,
464
+ scale=3
465
+ )
466
+ next_batch_btn = gr.Button(
467
+ t("results.next_btn"),
468
+ variant="primary",
469
+ interactive=False,
470
+ scale=1,
471
+ size="sm"
472
+ )
473
+
474
+ # One-click restore parameters button (hidden for simplified UI)
475
+ restore_params_btn = gr.Button(
476
+ t("results.restore_params_btn"),
477
+ variant="secondary",
478
+ interactive=False,
479
+ size="sm",
480
+ visible=False
481
+ )
482
+
483
+ with gr.Accordion(t("results.batch_results_title"), open=True):
484
+ generated_audio_batch = gr.File(
485
+ label=t("results.all_files_label"),
486
+ file_count="multiple",
487
+ interactive=False,
488
+ visible=False
489
+ )
490
+ generation_info = gr.Markdown(label=t("results.generation_details"))
491
+
492
+ return {
493
+ "lm_metadata_state": lm_metadata_state,
494
+ "is_format_caption_state": is_format_caption_state,
495
+ "current_batch_index": current_batch_index,
496
+ "total_batches": total_batches,
497
+ "batch_queue": batch_queue,
498
+ "generation_params_state": generation_params_state,
499
+ "is_generating_background": is_generating_background,
500
+ "status_output": status_output,
501
+ "prev_batch_btn": prev_batch_btn,
502
+ "batch_indicator": batch_indicator,
503
+ "next_batch_btn": next_batch_btn,
504
+ "next_batch_status": next_batch_status,
505
+ "restore_params_btn": restore_params_btn,
506
+ "generated_audio_1": generated_audio_1,
507
+ "generated_audio_2": generated_audio_2,
508
+ "generated_audio_3": generated_audio_3,
509
+ "generated_audio_4": generated_audio_4,
510
+ "generated_audio_5": generated_audio_5,
511
+ "generated_audio_6": generated_audio_6,
512
+ "generated_audio_7": generated_audio_7,
513
+ "generated_audio_8": generated_audio_8,
514
+ "audio_row_5_8": audio_row_5_8,
515
+ "audio_col_1": audio_col_1,
516
+ "audio_col_2": audio_col_2,
517
+ "audio_col_3": audio_col_3,
518
+ "audio_col_4": audio_col_4,
519
+ "audio_col_5": audio_col_5,
520
+ "audio_col_6": audio_col_6,
521
+ "audio_col_7": audio_col_7,
522
+ "audio_col_8": audio_col_8,
523
+ "send_to_cover_btn_1": send_to_cover_btn_1,
524
+ "send_to_cover_btn_2": send_to_cover_btn_2,
525
+ "send_to_cover_btn_3": send_to_cover_btn_3,
526
+ "send_to_cover_btn_4": send_to_cover_btn_4,
527
+ "send_to_cover_btn_5": send_to_cover_btn_5,
528
+ "send_to_cover_btn_6": send_to_cover_btn_6,
529
+ "send_to_cover_btn_7": send_to_cover_btn_7,
530
+ "send_to_cover_btn_8": send_to_cover_btn_8,
531
+ "send_to_repaint_btn_1": send_to_repaint_btn_1,
532
+ "send_to_repaint_btn_2": send_to_repaint_btn_2,
533
+ "send_to_repaint_btn_3": send_to_repaint_btn_3,
534
+ "send_to_repaint_btn_4": send_to_repaint_btn_4,
535
+ "send_to_repaint_btn_5": send_to_repaint_btn_5,
536
+ "send_to_repaint_btn_6": send_to_repaint_btn_6,
537
+ "send_to_repaint_btn_7": send_to_repaint_btn_7,
538
+ "send_to_repaint_btn_8": send_to_repaint_btn_8,
539
+ "save_btn_1": save_btn_1,
540
+ "save_btn_2": save_btn_2,
541
+ "save_btn_3": save_btn_3,
542
+ "save_btn_4": save_btn_4,
543
+ "save_btn_5": save_btn_5,
544
+ "save_btn_6": save_btn_6,
545
+ "save_btn_7": save_btn_7,
546
+ "save_btn_8": save_btn_8,
547
+ "score_btn_1": score_btn_1,
548
+ "score_btn_2": score_btn_2,
549
+ "score_btn_3": score_btn_3,
550
+ "score_btn_4": score_btn_4,
551
+ "score_btn_5": score_btn_5,
552
+ "score_btn_6": score_btn_6,
553
+ "score_btn_7": score_btn_7,
554
+ "score_btn_8": score_btn_8,
555
+ "score_display_1": score_display_1,
556
+ "score_display_2": score_display_2,
557
+ "score_display_3": score_display_3,
558
+ "score_display_4": score_display_4,
559
+ "score_display_5": score_display_5,
560
+ "score_display_6": score_display_6,
561
+ "score_display_7": score_display_7,
562
+ "score_display_8": score_display_8,
563
+ "codes_display_1": codes_display_1,
564
+ "codes_display_2": codes_display_2,
565
+ "codes_display_3": codes_display_3,
566
+ "codes_display_4": codes_display_4,
567
+ "codes_display_5": codes_display_5,
568
+ "codes_display_6": codes_display_6,
569
+ "codes_display_7": codes_display_7,
570
+ "codes_display_8": codes_display_8,
571
+ "lrc_btn_1": lrc_btn_1,
572
+ "lrc_btn_2": lrc_btn_2,
573
+ "lrc_btn_3": lrc_btn_3,
574
+ "lrc_btn_4": lrc_btn_4,
575
+ "lrc_btn_5": lrc_btn_5,
576
+ "lrc_btn_6": lrc_btn_6,
577
+ "lrc_btn_7": lrc_btn_7,
578
+ "lrc_btn_8": lrc_btn_8,
579
+ "lrc_display_1": lrc_display_1,
580
+ "lrc_display_2": lrc_display_2,
581
+ "lrc_display_3": lrc_display_3,
582
+ "lrc_display_4": lrc_display_4,
583
+ "lrc_display_5": lrc_display_5,
584
+ "lrc_display_6": lrc_display_6,
585
+ "lrc_display_7": lrc_display_7,
586
+ "lrc_display_8": lrc_display_8,
587
+ "details_accordion_1": details_accordion_1,
588
+ "details_accordion_2": details_accordion_2,
589
+ "details_accordion_3": details_accordion_3,
590
+ "details_accordion_4": details_accordion_4,
591
+ "details_accordion_5": details_accordion_5,
592
+ "details_accordion_6": details_accordion_6,
593
+ "details_accordion_7": details_accordion_7,
594
+ "details_accordion_8": details_accordion_8,
595
+ "generated_audio_batch": generated_audio_batch,
596
+ "generation_info": generation_info,
597
+ }
598
+
acestep/gradio_ui/interfaces/training.py ADDED
@@ -0,0 +1,562 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Training Tab Module
3
+
4
+ Contains the dataset builder and LoRA training interface components.
5
+ """
6
+
7
+ import os
8
+ import gradio as gr
9
+ from acestep.gradio_ui.i18n import t
10
+
11
+
12
+ def create_training_section(dit_handler, llm_handler, init_params=None) -> dict:
13
+ """Create the training tab section with dataset builder and training controls.
14
+
15
+ Args:
16
+ dit_handler: DiT handler instance
17
+ llm_handler: LLM handler instance
18
+ init_params: Dictionary containing initialization parameters and state.
19
+ If None, service will not be pre-initialized.
20
+
21
+ Returns:
22
+ Dictionary of Gradio components for event handling
23
+ """
24
+ # Check if running in service mode (hide training tab)
25
+ service_mode = init_params is not None and init_params.get('service_mode', False)
26
+
27
+ with gr.Tab("🎓 LoRA Training", visible=not service_mode):
28
+ gr.HTML("""
29
+ <div style="text-align: center; padding: 10px; margin-bottom: 15px;">
30
+ <h2>🎵 LoRA Training for ACE-Step</h2>
31
+ <p>Build datasets from your audio files and train custom LoRA adapters</p>
32
+ </div>
33
+ """)
34
+
35
+ with gr.Tabs():
36
+ # ==================== Dataset Builder Tab ====================
37
+ with gr.Tab("📁 Dataset Builder"):
38
+ # ========== Load Existing OR Scan New ==========
39
+ gr.HTML("""
40
+ <div style="padding: 10px; margin-bottom: 10px; border: 1px solid #4a4a6a; border-radius: 8px; background: linear-gradient(135deg, #2a2a4a 0%, #1a1a3a 100%);">
41
+ <h3 style="margin: 0 0 5px 0;">🚀 Quick Start</h3>
42
+ <p style="margin: 0; color: #aaa;">Choose one: <b>Load existing dataset</b> OR <b>Scan new directory</b></p>
43
+ </div>
44
+ """)
45
+
46
+ with gr.Row():
47
+ with gr.Column(scale=1):
48
+ gr.HTML("<h4>📂 Load Existing Dataset</h4>")
49
+ with gr.Row():
50
+ load_json_path = gr.Textbox(
51
+ label="Dataset JSON Path",
52
+ placeholder="./datasets/my_lora_dataset.json",
53
+ info="Load a previously saved dataset",
54
+ scale=3,
55
+ )
56
+ load_json_btn = gr.Button("📂 Load", variant="primary", scale=1)
57
+ load_json_status = gr.Textbox(
58
+ label="Load Status",
59
+ interactive=False,
60
+ )
61
+
62
+ with gr.Column(scale=1):
63
+ gr.HTML("<h4>🔍 Scan New Directory</h4>")
64
+ with gr.Row():
65
+ audio_directory = gr.Textbox(
66
+ label="Audio Directory Path",
67
+ placeholder="/path/to/your/audio/folder",
68
+ info="Scan for audio files (wav, mp3, flac, ogg, opus)",
69
+ scale=3,
70
+ )
71
+ scan_btn = gr.Button("🔍 Scan", variant="secondary", scale=1)
72
+ scan_status = gr.Textbox(
73
+ label="Scan Status",
74
+ interactive=False,
75
+ )
76
+
77
+ gr.HTML("<hr>")
78
+
79
+ with gr.Row():
80
+ with gr.Column(scale=2):
81
+
82
+ # Audio files table
83
+ audio_files_table = gr.Dataframe(
84
+ headers=["#", "Filename", "Duration", "Labeled", "BPM", "Key", "Caption"],
85
+ datatype=["number", "str", "str", "str", "str", "str", "str"],
86
+ label="Found Audio Files",
87
+ interactive=False,
88
+ wrap=True,
89
+ )
90
+
91
+ with gr.Column(scale=1):
92
+ gr.HTML("<h3>⚙️ Dataset Settings</h3>")
93
+
94
+ dataset_name = gr.Textbox(
95
+ label="Dataset Name",
96
+ value="my_lora_dataset",
97
+ placeholder="Enter dataset name",
98
+ )
99
+
100
+ all_instrumental = gr.Checkbox(
101
+ label="All Instrumental",
102
+ value=True,
103
+ info="Check if all tracks are instrumental (no vocals)",
104
+ )
105
+
106
+ need_lyrics = gr.Checkbox(
107
+ label="Transcribe Lyrics",
108
+ value=False,
109
+ info="Attempt to transcribe lyrics (slower)",
110
+ interactive=False, # Disabled for now
111
+ )
112
+
113
+ custom_tag = gr.Textbox(
114
+ label="Custom Activation Tag",
115
+ placeholder="e.g., 8bit_retro, my_style",
116
+ info="Unique tag to activate this LoRA's style",
117
+ )
118
+
119
+ tag_position = gr.Radio(
120
+ choices=[
121
+ ("Prepend (tag, caption)", "prepend"),
122
+ ("Append (caption, tag)", "append"),
123
+ ("Replace caption", "replace"),
124
+ ],
125
+ value="replace",
126
+ label="Tag Position",
127
+ info="Where to place the custom tag in the caption",
128
+ )
129
+
130
+ gr.HTML("<hr><h3>🤖 Step 2: Auto-Label with AI</h3>")
131
+
132
+ with gr.Row():
133
+ with gr.Column(scale=3):
134
+ gr.Markdown("""
135
+ Click the button below to automatically generate metadata for all audio files using AI:
136
+ - **Caption**: Music style, genre, mood description
137
+ - **BPM**: Beats per minute
138
+ - **Key**: Musical key (e.g., C Major, Am)
139
+ - **Time Signature**: 4/4, 3/4, etc.
140
+ """)
141
+ skip_metas = gr.Checkbox(
142
+ label="Skip Metas (No LLM)",
143
+ value=False,
144
+ info="Skip AI labeling. BPM/Key/Time Signature will be N/A, Language will be 'unknown' for instrumental",
145
+ )
146
+ with gr.Column(scale=1):
147
+ auto_label_btn = gr.Button(
148
+ "🏷️ Auto-Label All",
149
+ variant="primary",
150
+ size="lg",
151
+ )
152
+
153
+ label_progress = gr.Textbox(
154
+ label="Labeling Progress",
155
+ interactive=False,
156
+ lines=2,
157
+ )
158
+
159
+ gr.HTML("<hr><h3>👀 Step 3: Preview & Edit</h3>")
160
+
161
+ with gr.Row():
162
+ with gr.Column(scale=1):
163
+ sample_selector = gr.Slider(
164
+ minimum=0,
165
+ maximum=0,
166
+ step=1,
167
+ value=0,
168
+ label="Select Sample #",
169
+ info="Choose a sample to preview and edit",
170
+ )
171
+
172
+ preview_audio = gr.Audio(
173
+ label="Audio Preview",
174
+ type="filepath",
175
+ interactive=False,
176
+ )
177
+
178
+ preview_filename = gr.Textbox(
179
+ label="Filename",
180
+ interactive=False,
181
+ )
182
+
183
+ with gr.Column(scale=2):
184
+ with gr.Row():
185
+ edit_caption = gr.Textbox(
186
+ label="Caption",
187
+ lines=3,
188
+ placeholder="Music description...",
189
+ )
190
+
191
+ with gr.Row():
192
+ edit_lyrics = gr.Textbox(
193
+ label="Lyrics",
194
+ lines=4,
195
+ placeholder="[Verse 1]\nLyrics here...\n\n[Chorus]\n...",
196
+ )
197
+
198
+ with gr.Row():
199
+ edit_bpm = gr.Number(
200
+ label="BPM",
201
+ precision=0,
202
+ )
203
+ edit_keyscale = gr.Textbox(
204
+ label="Key",
205
+ placeholder="C Major",
206
+ )
207
+ edit_timesig = gr.Dropdown(
208
+ choices=["", "2", "3", "4", "6"],
209
+ label="Time Signature",
210
+ )
211
+ edit_duration = gr.Number(
212
+ label="Duration (s)",
213
+ precision=1,
214
+ interactive=False,
215
+ )
216
+
217
+ with gr.Row():
218
+ edit_language = gr.Dropdown(
219
+ choices=["instrumental", "en", "zh", "ja", "ko", "es", "fr", "de", "pt", "ru", "unknown"],
220
+ value="instrumental",
221
+ label="Language",
222
+ )
223
+ edit_instrumental = gr.Checkbox(
224
+ label="Instrumental",
225
+ value=True,
226
+ )
227
+ save_edit_btn = gr.Button("💾 Save Changes", variant="secondary")
228
+
229
+ edit_status = gr.Textbox(
230
+ label="Edit Status",
231
+ interactive=False,
232
+ )
233
+
234
+ gr.HTML("<hr><h3>💾 Step 4: Save Dataset</h3>")
235
+
236
+ with gr.Row():
237
+ with gr.Column(scale=3):
238
+ save_path = gr.Textbox(
239
+ label="Save Path",
240
+ value="./datasets/my_lora_dataset.json",
241
+ placeholder="./datasets/dataset_name.json",
242
+ info="Path where the dataset JSON will be saved",
243
+ )
244
+ with gr.Column(scale=1):
245
+ save_dataset_btn = gr.Button(
246
+ "💾 Save Dataset",
247
+ variant="primary",
248
+ size="lg",
249
+ )
250
+
251
+ save_status = gr.Textbox(
252
+ label="Save Status",
253
+ interactive=False,
254
+ lines=2,
255
+ )
256
+
257
+ gr.HTML("<hr><h3>⚡ Step 5: Preprocess to Tensors</h3>")
258
+
259
+ gr.Markdown("""
260
+ **Preprocessing converts your dataset to pre-computed tensors for fast training.**
261
+
262
+ You can either:
263
+ - Use the dataset from Steps 1-4 above, **OR**
264
+ - Load an existing dataset JSON file (if you've already saved one)
265
+ """)
266
+
267
+ with gr.Row():
268
+ with gr.Column(scale=3):
269
+ load_existing_dataset_path = gr.Textbox(
270
+ label="Load Existing Dataset (Optional)",
271
+ placeholder="./datasets/my_lora_dataset.json",
272
+ info="Path to a previously saved dataset JSON file",
273
+ )
274
+ with gr.Column(scale=1):
275
+ load_existing_dataset_btn = gr.Button(
276
+ "📂 Load Dataset",
277
+ variant="secondary",
278
+ size="lg",
279
+ )
280
+
281
+ load_existing_status = gr.Textbox(
282
+ label="Load Status",
283
+ interactive=False,
284
+ )
285
+
286
+ gr.Markdown("""
287
+ This step:
288
+ - Encodes audio to VAE latents
289
+ - Encodes captions and lyrics to text embeddings
290
+ - Runs the condition encoder
291
+ - Saves all tensors to `.pt` files
292
+
293
+ ⚠️ **This requires the model to be loaded and may take a few minutes.**
294
+ """)
295
+
296
+ with gr.Row():
297
+ with gr.Column(scale=3):
298
+ preprocess_output_dir = gr.Textbox(
299
+ label="Tensor Output Directory",
300
+ value="./datasets/preprocessed_tensors",
301
+ placeholder="./datasets/preprocessed_tensors",
302
+ info="Directory to save preprocessed tensor files",
303
+ )
304
+ with gr.Column(scale=1):
305
+ preprocess_btn = gr.Button(
306
+ "⚡ Preprocess",
307
+ variant="primary",
308
+ size="lg",
309
+ )
310
+
311
+ preprocess_progress = gr.Textbox(
312
+ label="Preprocessing Progress",
313
+ interactive=False,
314
+ lines=3,
315
+ )
316
+
317
+ # ==================== Training Tab ====================
318
+ with gr.Tab("🚀 Train LoRA"):
319
+ with gr.Row():
320
+ with gr.Column(scale=2):
321
+ gr.HTML("<h3>📊 Preprocessed Dataset Selection</h3>")
322
+
323
+ gr.Markdown("""
324
+ Select the directory containing preprocessed tensor files (`.pt` files).
325
+ These are created in the "Dataset Builder" tab using the "Preprocess" button.
326
+ """)
327
+
328
+ training_tensor_dir = gr.Textbox(
329
+ label="Preprocessed Tensors Directory",
330
+ placeholder="./datasets/preprocessed_tensors",
331
+ value="./datasets/preprocessed_tensors",
332
+ info="Directory containing preprocessed .pt tensor files",
333
+ )
334
+
335
+ load_dataset_btn = gr.Button("📂 Load Dataset", variant="secondary")
336
+
337
+ training_dataset_info = gr.Textbox(
338
+ label="Dataset Info",
339
+ interactive=False,
340
+ lines=3,
341
+ )
342
+
343
+ with gr.Column(scale=1):
344
+ gr.HTML("<h3>⚙️ LoRA Settings</h3>")
345
+
346
+ lora_rank = gr.Slider(
347
+ minimum=4,
348
+ maximum=256,
349
+ step=4,
350
+ value=64,
351
+ label="LoRA Rank (r)",
352
+ info="Higher = more capacity, more memory",
353
+ )
354
+
355
+ lora_alpha = gr.Slider(
356
+ minimum=4,
357
+ maximum=512,
358
+ step=4,
359
+ value=128,
360
+ label="LoRA Alpha",
361
+ info="Scaling factor (typically 2x rank)",
362
+ )
363
+
364
+ lora_dropout = gr.Slider(
365
+ minimum=0.0,
366
+ maximum=0.5,
367
+ step=0.05,
368
+ value=0.1,
369
+ label="LoRA Dropout",
370
+ )
371
+
372
+ gr.HTML("<hr><h3>🎛️ Training Parameters</h3>")
373
+
374
+ with gr.Row():
375
+ learning_rate = gr.Number(
376
+ label="Learning Rate",
377
+ value=1e-4,
378
+ info="Start with 1e-4, adjust if needed",
379
+ )
380
+
381
+ train_epochs = gr.Slider(
382
+ minimum=100,
383
+ maximum=4000,
384
+ step=100,
385
+ value=500,
386
+ label="Max Epochs",
387
+ )
388
+
389
+ train_batch_size = gr.Slider(
390
+ minimum=1,
391
+ maximum=8,
392
+ step=1,
393
+ value=1,
394
+ label="Batch Size",
395
+ info="Increase if you have enough VRAM",
396
+ )
397
+
398
+ gradient_accumulation = gr.Slider(
399
+ minimum=1,
400
+ maximum=16,
401
+ step=1,
402
+ value=1,
403
+ label="Gradient Accumulation",
404
+ info="Effective batch = batch_size × accumulation",
405
+ )
406
+
407
+ with gr.Row():
408
+ save_every_n_epochs = gr.Slider(
409
+ minimum=50,
410
+ maximum=1000,
411
+ step=50,
412
+ value=200,
413
+ label="Save Every N Epochs",
414
+ )
415
+
416
+ training_shift = gr.Slider(
417
+ minimum=1.0,
418
+ maximum=5.0,
419
+ step=0.5,
420
+ value=3.0,
421
+ label="Shift",
422
+ info="Timestep shift for turbo model",
423
+ )
424
+
425
+ training_seed = gr.Number(
426
+ label="Seed",
427
+ value=42,
428
+ precision=0,
429
+ )
430
+
431
+ with gr.Row():
432
+ lora_output_dir = gr.Textbox(
433
+ label="Output Directory",
434
+ value="./lora_output",
435
+ placeholder="./lora_output",
436
+ info="Directory to save trained LoRA weights",
437
+ )
438
+
439
+ gr.HTML("<hr>")
440
+
441
+ with gr.Row():
442
+ with gr.Column(scale=1):
443
+ start_training_btn = gr.Button(
444
+ "🚀 Start Training",
445
+ variant="primary",
446
+ size="lg",
447
+ )
448
+ with gr.Column(scale=1):
449
+ stop_training_btn = gr.Button(
450
+ "⏹️ Stop Training",
451
+ variant="stop",
452
+ size="lg",
453
+ )
454
+
455
+ training_progress = gr.Textbox(
456
+ label="Training Progress",
457
+ interactive=False,
458
+ lines=2,
459
+ )
460
+
461
+ with gr.Row():
462
+ training_log = gr.Textbox(
463
+ label="Training Log",
464
+ interactive=False,
465
+ lines=10,
466
+ max_lines=15,
467
+ scale=1,
468
+ )
469
+ training_loss_plot = gr.LinePlot(
470
+ x="step",
471
+ y="loss",
472
+ title="Training Loss",
473
+ x_title="Step",
474
+ y_title="Loss",
475
+ scale=1,
476
+ )
477
+
478
+ gr.HTML("<hr><h3>📦 Export LoRA</h3>")
479
+
480
+ with gr.Row():
481
+ export_path = gr.Textbox(
482
+ label="Export Path",
483
+ value="./lora_output/final_lora",
484
+ placeholder="./lora_output/my_lora",
485
+ )
486
+ export_lora_btn = gr.Button("📦 Export LoRA", variant="secondary")
487
+
488
+ export_status = gr.Textbox(
489
+ label="Export Status",
490
+ interactive=False,
491
+ )
492
+
493
+ # Store dataset builder state
494
+ dataset_builder_state = gr.State(None)
495
+ training_state = gr.State({"is_training": False, "should_stop": False})
496
+
497
+ return {
498
+ # Dataset Builder - Load or Scan
499
+ "load_json_path": load_json_path,
500
+ "load_json_btn": load_json_btn,
501
+ "load_json_status": load_json_status,
502
+ "audio_directory": audio_directory,
503
+ "scan_btn": scan_btn,
504
+ "scan_status": scan_status,
505
+ "audio_files_table": audio_files_table,
506
+ "dataset_name": dataset_name,
507
+ "all_instrumental": all_instrumental,
508
+ "need_lyrics": need_lyrics,
509
+ "custom_tag": custom_tag,
510
+ "tag_position": tag_position,
511
+ "skip_metas": skip_metas,
512
+ "auto_label_btn": auto_label_btn,
513
+ "label_progress": label_progress,
514
+ "sample_selector": sample_selector,
515
+ "preview_audio": preview_audio,
516
+ "preview_filename": preview_filename,
517
+ "edit_caption": edit_caption,
518
+ "edit_lyrics": edit_lyrics,
519
+ "edit_bpm": edit_bpm,
520
+ "edit_keyscale": edit_keyscale,
521
+ "edit_timesig": edit_timesig,
522
+ "edit_duration": edit_duration,
523
+ "edit_language": edit_language,
524
+ "edit_instrumental": edit_instrumental,
525
+ "save_edit_btn": save_edit_btn,
526
+ "edit_status": edit_status,
527
+ "save_path": save_path,
528
+ "save_dataset_btn": save_dataset_btn,
529
+ "save_status": save_status,
530
+ # Preprocessing
531
+ "load_existing_dataset_path": load_existing_dataset_path,
532
+ "load_existing_dataset_btn": load_existing_dataset_btn,
533
+ "load_existing_status": load_existing_status,
534
+ "preprocess_output_dir": preprocess_output_dir,
535
+ "preprocess_btn": preprocess_btn,
536
+ "preprocess_progress": preprocess_progress,
537
+ "dataset_builder_state": dataset_builder_state,
538
+ # Training
539
+ "training_tensor_dir": training_tensor_dir,
540
+ "load_dataset_btn": load_dataset_btn,
541
+ "training_dataset_info": training_dataset_info,
542
+ "lora_rank": lora_rank,
543
+ "lora_alpha": lora_alpha,
544
+ "lora_dropout": lora_dropout,
545
+ "learning_rate": learning_rate,
546
+ "train_epochs": train_epochs,
547
+ "train_batch_size": train_batch_size,
548
+ "gradient_accumulation": gradient_accumulation,
549
+ "save_every_n_epochs": save_every_n_epochs,
550
+ "training_shift": training_shift,
551
+ "training_seed": training_seed,
552
+ "lora_output_dir": lora_output_dir,
553
+ "start_training_btn": start_training_btn,
554
+ "stop_training_btn": stop_training_btn,
555
+ "training_progress": training_progress,
556
+ "training_log": training_log,
557
+ "training_loss_plot": training_loss_plot,
558
+ "export_path": export_path,
559
+ "export_lora_btn": export_lora_btn,
560
+ "export_status": export_status,
561
+ "training_state": training_state,
562
+ }
acestep/handler.py ADDED
The diff for this file is too large to render. See raw diff
 
acestep/inference.py ADDED
@@ -0,0 +1,1181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ACE-Step Inference API Module
3
+
4
+ This module provides a standardized inference interface for music generation,
5
+ designed for third-party integration. It offers both a simplified API and
6
+ backward-compatible Gradio UI support.
7
+ """
8
+
9
+ import math
10
+ import os
11
+ import tempfile
12
+ from typing import Optional, Union, List, Dict, Any, Tuple
13
+ from dataclasses import dataclass, field, asdict
14
+ from loguru import logger
15
+
16
+ from acestep.audio_utils import AudioSaver, generate_uuid_from_params
17
+
18
+ # HuggingFace Space environment detection
19
+ IS_HUGGINGFACE_SPACE = os.environ.get("SPACE_ID") is not None
20
+
21
+
22
+ @dataclass
23
+ class GenerationParams:
24
+ """Configuration for music generation parameters.
25
+
26
+ Attributes:
27
+ # Text Inputs
28
+ caption: A short text prompt describing the desired music (main prompt). < 512 characters
29
+ lyrics: Lyrics for the music. Use "[Instrumental]" for instrumental songs. < 4096 characters
30
+ instrumental: If True, generate instrumental music regardless of lyrics.
31
+
32
+ # Music Metadata
33
+ bpm: BPM (beats per minute), e.g., 120. Set to None for automatic estimation. 30 ~ 300
34
+ keyscale: Musical key (e.g., "C Major", "Am"). Leave empty for auto-detection. A-G, #/♭, major/minor
35
+ timesignature: Time signature (2 for '2/4', 3 for '3/4', 4 for '4/4', 6 for '6/8'). Leave empty for auto-detection.
36
+ vocal_language: Language code for vocals, e.g., "en", "zh", "ja", or "unknown". see acestep/constants.py:VALID_LANGUAGES
37
+ duration: Target audio length in seconds. If <0 or None, model chooses automatically. 10 ~ 600
38
+
39
+ # Generation Parameters
40
+ inference_steps: Number of diffusion steps (e.g., 8 for turbo, 32–100 for base model).
41
+ guidance_scale: CFG (classifier-free guidance) strength. Higher means following the prompt more strictly. Only support for non-turbo model.
42
+ seed: Integer seed for reproducibility. -1 means use random seed each time.
43
+
44
+ # Advanced DiT Parameters
45
+ use_adg: Whether to use Adaptive Dual Guidance (only works for base model).
46
+ cfg_interval_start: Start ratio (0.0–1.0) to apply CFG.
47
+ cfg_interval_end: End ratio (0.0–1.0) to apply CFG.
48
+ shift: Timestep shift factor (default 1.0). When != 1.0, applies t = shift * t / (1 + (shift - 1) * t) to timesteps.
49
+
50
+ # Task-Specific Parameters
51
+ task_type: Type of generation task. One of: "text2music", "cover", "repaint", "lego", "extract", "complete".
52
+ reference_audio: Path to a reference audio file for style transfer or cover tasks.
53
+ src_audio: Path to a source audio file for audio-to-audio tasks.
54
+ audio_codes: Audio semantic codes as a string (advanced use, for code-control generation).
55
+ repainting_start: For repaint/lego tasks: start time in seconds for region to repaint.
56
+ repainting_end: For repaint/lego tasks: end time in seconds for region to repaint (-1 for until end).
57
+ audio_cover_strength: Strength of reference audio/codes influence (range 0.0–1.0). set smaller (0.2) for style transfer tasks.
58
+ instruction: Optional task instruction prompt. If empty, auto-generated by system.
59
+
60
+ # 5Hz Language Model Parameters for CoT reasoning
61
+ thinking: If True, enable 5Hz Language Model "Chain-of-Thought" reasoning for semantic/music metadata and codes.
62
+ lm_temperature: Sampling temperature for the LLM (0.0–2.0). Higher = more creative/varied results.
63
+ lm_cfg_scale: Classifier-free guidance scale for the LLM.
64
+ lm_top_k: LLM top-k sampling (0 = disabled).
65
+ lm_top_p: LLM top-p nucleus sampling (1.0 = disabled).
66
+ lm_negative_prompt: Negative prompt to use for LLM (for control).
67
+ use_cot_metas: Whether to let LLM generate music metadata via CoT reasoning.
68
+ use_cot_caption: Whether to let LLM rewrite or format the input caption via CoT reasoning.
69
+ use_cot_language: Whether to let LLM detect vocal language via CoT.
70
+ """
71
+ # Required Inputs
72
+ task_type: str = "text2music"
73
+ instruction: str = "Fill the audio semantic mask based on the given conditions:"
74
+
75
+ # Audio Uploads
76
+ reference_audio: Optional[str] = None
77
+ src_audio: Optional[str] = None
78
+
79
+ # LM Codes Hints
80
+ audio_codes: str = ""
81
+
82
+ # Text Inputs
83
+ caption: str = ""
84
+ lyrics: str = ""
85
+ instrumental: bool = False
86
+
87
+ # Metadata
88
+ vocal_language: str = "unknown"
89
+ bpm: Optional[int] = None
90
+ keyscale: str = ""
91
+ timesignature: str = ""
92
+ duration: float = -1.0
93
+
94
+ # Advanced Settings
95
+ inference_steps: int = 8
96
+ seed: int = -1
97
+ guidance_scale: float = 7.0
98
+ use_adg: bool = False
99
+ cfg_interval_start: float = 0.0
100
+ cfg_interval_end: float = 1.0
101
+ shift: float = 1.0
102
+ infer_method: str = "ode" # "ode" or "sde" - diffusion inference method
103
+ # Custom timesteps (parsed from string like "0.97,0.76,0.615,0.5,0.395,0.28,0.18,0.085,0")
104
+ # If provided, overrides inference_steps and shift
105
+ timesteps: Optional[List[float]] = None
106
+
107
+ repainting_start: float = 0.0
108
+ repainting_end: float = -1
109
+ audio_cover_strength: float = 1.0
110
+
111
+ # 5Hz Language Model Parameters
112
+ thinking: bool = True
113
+ lm_temperature: float = 0.85
114
+ lm_cfg_scale: float = 2.0
115
+ lm_top_k: int = 0
116
+ lm_top_p: float = 0.9
117
+ lm_negative_prompt: str = "NO USER INPUT"
118
+ use_cot_metas: bool = True
119
+ use_cot_caption: bool = True
120
+ use_cot_lyrics: bool = False # TODO: not used yet
121
+ use_cot_language: bool = True
122
+ use_constrained_decoding: bool = True
123
+
124
+ cot_bpm: Optional[int] = None
125
+ cot_keyscale: str = ""
126
+ cot_timesignature: str = ""
127
+ cot_duration: Optional[float] = None
128
+ cot_vocal_language: str = "unknown"
129
+ cot_caption: str = ""
130
+ cot_lyrics: str = ""
131
+
132
+ def to_dict(self) -> Dict[str, Any]:
133
+ """Convert config to dictionary for JSON serialization."""
134
+ return asdict(self)
135
+
136
+
137
+ @dataclass
138
+ class GenerationConfig:
139
+ """Configuration for music generation.
140
+
141
+ Attributes:
142
+ batch_size: Number of audio samples to generate
143
+ allow_lm_batch: Whether to allow batch processing in LM
144
+ use_random_seed: Whether to use random seed
145
+ seeds: Seed(s) for batch generation. Can be:
146
+ - None: Use random seeds (when use_random_seed=True) or params.seed (when use_random_seed=False)
147
+ - List[int]: List of seeds, will be padded with random seeds if fewer than batch_size
148
+ - int: Single seed value (will be converted to list and padded)
149
+ lm_batch_chunk_size: Batch chunk size for LM processing
150
+ constrained_decoding_debug: Whether to enable constrained decoding debug
151
+ audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac"
152
+ """
153
+ batch_size: int = 2
154
+ allow_lm_batch: bool = False
155
+ use_random_seed: bool = True
156
+ seeds: Optional[List[int]] = None
157
+ lm_batch_chunk_size: int = 8
158
+ constrained_decoding_debug: bool = False
159
+ audio_format: str = "flac" # Default to FLAC for fast saving
160
+
161
+ def to_dict(self) -> Dict[str, Any]:
162
+ """Convert config to dictionary for JSON serialization."""
163
+ return asdict(self)
164
+
165
+
166
+ @dataclass
167
+ class GenerationResult:
168
+ """Result of music generation.
169
+
170
+ Attributes:
171
+ # Audio Outputs
172
+ audios: List of audio dictionaries with paths, keys, params
173
+ status_message: Status message from generation
174
+ extra_outputs: Extra outputs from generation
175
+ success: Whether generation completed successfully
176
+ error: Error message if generation failed
177
+ """
178
+
179
+ # Audio Outputs
180
+ audios: List[Dict[str, Any]] = field(default_factory=list)
181
+ # Generation Information
182
+ status_message: str = ""
183
+ extra_outputs: Dict[str, Any] = field(default_factory=dict)
184
+ # Success Status
185
+ success: bool = True
186
+ error: Optional[str] = None
187
+
188
+ def to_dict(self) -> Dict[str, Any]:
189
+ """Convert result to dictionary for JSON serialization."""
190
+ return asdict(self)
191
+
192
+
193
+ @dataclass
194
+ class UnderstandResult:
195
+ """Result of music understanding from audio codes.
196
+
197
+ Attributes:
198
+ # Metadata Fields
199
+ caption: Generated caption describing the music
200
+ lyrics: Generated or extracted lyrics
201
+ bpm: Beats per minute (None if not detected)
202
+ duration: Duration in seconds (None if not detected)
203
+ keyscale: Musical key (e.g., "C Major")
204
+ language: Vocal language code (e.g., "en", "zh")
205
+ timesignature: Time signature (e.g., "4/4")
206
+
207
+ # Status
208
+ status_message: Status message from understanding
209
+ success: Whether understanding completed successfully
210
+ error: Error message if understanding failed
211
+ """
212
+ # Metadata Fields
213
+ caption: str = ""
214
+ lyrics: str = ""
215
+ bpm: Optional[int] = None
216
+ duration: Optional[float] = None
217
+ keyscale: str = ""
218
+ language: str = ""
219
+ timesignature: str = ""
220
+
221
+ # Status
222
+ status_message: str = ""
223
+ success: bool = True
224
+ error: Optional[str] = None
225
+
226
+ def to_dict(self) -> Dict[str, Any]:
227
+ """Convert result to dictionary for JSON serialization."""
228
+ return asdict(self)
229
+
230
+
231
+ def _update_metadata_from_lm(
232
+ metadata: Dict[str, Any],
233
+ bpm: Optional[int],
234
+ key_scale: str,
235
+ time_signature: str,
236
+ audio_duration: Optional[float],
237
+ vocal_language: str,
238
+ caption: str,
239
+ lyrics: str,
240
+ ) -> Tuple[Optional[int], str, str, Optional[float]]:
241
+ """Update metadata fields from LM output if not provided by user."""
242
+
243
+ if bpm is None and metadata.get('bpm'):
244
+ bpm_value = metadata.get('bpm')
245
+ if bpm_value not in ["N/A", ""]:
246
+ try:
247
+ bpm = int(bpm_value)
248
+ except (ValueError, TypeError):
249
+ pass
250
+
251
+ if not key_scale and metadata.get('keyscale'):
252
+ key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
253
+ if key_scale_value != "N/A":
254
+ key_scale = key_scale_value
255
+
256
+ if not time_signature and metadata.get('timesignature'):
257
+ time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
258
+ if time_signature_value != "N/A":
259
+ time_signature = time_signature_value
260
+
261
+ if audio_duration is None or audio_duration <= 0:
262
+ audio_duration_value = metadata.get('duration', -1)
263
+ if audio_duration_value not in ["N/A", ""]:
264
+ try:
265
+ audio_duration = float(audio_duration_value)
266
+ except (ValueError, TypeError):
267
+ pass
268
+
269
+ if not vocal_language and metadata.get('vocal_language'):
270
+ vocal_language = metadata.get('vocal_language')
271
+ if not caption and metadata.get('caption'):
272
+ caption = metadata.get('caption')
273
+ if not lyrics and metadata.get('lyrics'):
274
+ lyrics = metadata.get('lyrics')
275
+ return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics
276
+
277
+
278
+ def generate_music(
279
+ dit_handler,
280
+ llm_handler,
281
+ params: GenerationParams,
282
+ config: GenerationConfig,
283
+ save_dir: Optional[str] = None,
284
+ progress=None,
285
+ ) -> GenerationResult:
286
+ """Generate music using ACE-Step model with optional LM reasoning.
287
+
288
+ Args:
289
+ dit_handler: Initialized DiT model handler (AceStepHandler instance)
290
+ llm_handler: Initialized LLM handler (LLMHandler instance)
291
+ params: Generation parameters (GenerationParams instance)
292
+ config: Generation configuration (GenerationConfig instance)
293
+
294
+ Returns:
295
+ GenerationResult with generated audio files and metadata
296
+ """
297
+ try:
298
+ # Phase 1: LM-based metadata and code generation (if enabled)
299
+ audio_code_string_to_use = params.audio_codes
300
+ lm_generated_metadata = None
301
+ lm_generated_audio_codes_list = []
302
+ lm_total_time_costs = {
303
+ "phase1_time": 0.0,
304
+ "phase2_time": 0.0,
305
+ "total_time": 0.0,
306
+ }
307
+
308
+ # Extract mutable copies of metadata (will be updated by LM if needed)
309
+ bpm = params.bpm
310
+ key_scale = params.keyscale
311
+ time_signature = params.timesignature
312
+ audio_duration = params.duration
313
+ dit_input_caption = params.caption
314
+ dit_input_vocal_language = params.vocal_language
315
+ dit_input_lyrics = params.lyrics
316
+ # Determine if we need to generate audio codes
317
+ # If user has provided audio_codes, we don't need to generate them
318
+ # Otherwise, check if we need audio codes (lm_dit mode) or just metas (dit mode)
319
+ user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip())
320
+
321
+ # Determine infer_type: use "llm_dit" if we need audio codes, "dit" if only metas needed
322
+ # For now, we use "llm_dit" if batch mode or if user hasn't provided codes
323
+ # Use "dit" if user has provided codes (only need metas) or if explicitly only need metas
324
+ # Note: This logic can be refined based on specific requirements
325
+ need_audio_codes = not user_provided_audio_codes
326
+
327
+ # Determine if we should use chunk-based LM generation (always use chunks for consistency)
328
+ # Determine actual batch size for chunk processing
329
+ actual_batch_size = config.batch_size if config.batch_size is not None else 1
330
+
331
+ # Prepare seeds for batch generation
332
+ # Use config.seed if provided, otherwise fallback to params.seed
333
+ # Convert config.seed (None, int, or List[int]) to format that prepare_seeds accepts
334
+ seed_for_generation = ""
335
+ if config.seeds is not None and len(config.seeds) > 0:
336
+ if isinstance(config.seeds, list):
337
+ # Convert List[int] to comma-separated string
338
+ seed_for_generation = ",".join(str(s) for s in config.seeds)
339
+
340
+ # Use dit_handler.prepare_seeds to handle seed list generation and padding
341
+ # This will handle all the logic: padding with random seeds if needed, etc.
342
+ actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)
343
+
344
+ # LM-based Chain-of-Thought reasoning
345
+ # Skip LM for cover/repaint tasks - these tasks use reference/src audio directly
346
+ # and don't need LM to generate audio codes
347
+ skip_lm_tasks = {"cover", "repaint"}
348
+
349
+ # Determine if we should use LLM
350
+ # LLM is needed for:
351
+ # 1. thinking=True: generate audio codes via LM
352
+ # 2. use_cot_caption=True: enhance/generate caption via CoT
353
+ # 3. use_cot_language=True: detect vocal language via CoT
354
+ # 4. use_cot_metas=True: fill missing metadata via CoT
355
+ need_lm_for_cot = params.use_cot_caption or params.use_cot_language or params.use_cot_metas
356
+ use_lm = (params.thinking or need_lm_for_cot) and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks
357
+ lm_status = []
358
+
359
+ if params.task_type in skip_lm_tasks:
360
+ logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly")
361
+
362
+ logger.info(f"[generate_music] LLM usage decision: thinking={params.thinking}, "
363
+ f"use_cot_caption={params.use_cot_caption}, use_cot_language={params.use_cot_language}, "
364
+ f"use_cot_metas={params.use_cot_metas}, need_lm_for_cot={need_lm_for_cot}, "
365
+ f"llm_initialized={llm_handler.llm_initialized if llm_handler else False}, use_lm={use_lm}")
366
+
367
+ if use_lm:
368
+ # Convert sampling parameters - handle None values safely
369
+ top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
370
+ top_p_value = None if not params.lm_top_p or params.lm_top_p >= 1.0 else params.lm_top_p
371
+
372
+ # Build user_metadata from user-provided values
373
+ user_metadata = {}
374
+ if bpm is not None:
375
+ try:
376
+ bpm_value = float(bpm)
377
+ if bpm_value > 0:
378
+ user_metadata['bpm'] = int(bpm_value)
379
+ except (ValueError, TypeError):
380
+ pass
381
+
382
+ if key_scale and key_scale.strip():
383
+ key_scale_clean = key_scale.strip()
384
+ if key_scale_clean.lower() not in ["n/a", ""]:
385
+ user_metadata['keyscale'] = key_scale_clean
386
+
387
+ if time_signature and time_signature.strip():
388
+ time_sig_clean = time_signature.strip()
389
+ if time_sig_clean.lower() not in ["n/a", ""]:
390
+ user_metadata['timesignature'] = time_sig_clean
391
+
392
+ if audio_duration is not None:
393
+ try:
394
+ duration_value = float(audio_duration)
395
+ if duration_value > 0:
396
+ user_metadata['duration'] = int(duration_value)
397
+ except (ValueError, TypeError):
398
+ pass
399
+
400
+ user_metadata_to_pass = user_metadata if user_metadata else None
401
+
402
+ # Determine infer_type based on whether we need audio codes
403
+ # - "llm_dit": generates both metas and audio codes (two-phase internally)
404
+ # - "dit": generates only metas (single phase)
405
+ infer_type = "llm_dit" if need_audio_codes and params.thinking else "dit"
406
+
407
+ # Use chunk size from config, or default to batch_size if not set
408
+ max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size
409
+ num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)
410
+
411
+ all_metadata_list = []
412
+ all_audio_codes_list = []
413
+
414
+ for chunk_idx in range(num_chunks):
415
+ chunk_start = chunk_idx * max_inference_batch_size
416
+ chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
417
+ chunk_size = chunk_end - chunk_start
418
+ chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None
419
+
420
+ logger.info(f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
421
+ f"(size: {chunk_size}, seeds: {chunk_seeds})")
422
+
423
+ # Use the determined infer_type
424
+ # - "llm_dit" will internally run two phases (metas + codes)
425
+ # - "dit" will only run phase 1 (metas only)
426
+ result = llm_handler.generate_with_stop_condition(
427
+ caption=params.caption or "",
428
+ lyrics=params.lyrics or "",
429
+ infer_type=infer_type,
430
+ temperature=params.lm_temperature,
431
+ cfg_scale=params.lm_cfg_scale,
432
+ negative_prompt=params.lm_negative_prompt,
433
+ top_k=top_k_value,
434
+ top_p=top_p_value,
435
+ target_duration=audio_duration, # Pass duration to limit audio codes generation
436
+ user_metadata=user_metadata_to_pass,
437
+ use_cot_caption=params.use_cot_caption,
438
+ use_cot_language=params.use_cot_language,
439
+ use_cot_metas=params.use_cot_metas,
440
+ use_constrained_decoding=params.use_constrained_decoding,
441
+ constrained_decoding_debug=config.constrained_decoding_debug,
442
+ batch_size=chunk_size,
443
+ seeds=chunk_seeds,
444
+ progress=progress,
445
+ )
446
+
447
+ # Check if LM generation failed
448
+ if not result.get("success", False):
449
+ error_msg = result.get("error", "Unknown LM error")
450
+ lm_status.append(f"❌ LM Error: {error_msg}")
451
+ # Return early with error
452
+ return GenerationResult(
453
+ audios=[],
454
+ status_message=f"❌ LM generation failed: {error_msg}",
455
+ extra_outputs={},
456
+ success=False,
457
+ error=error_msg,
458
+ )
459
+
460
+ # Extract metadata and audio_codes from result dict
461
+ if chunk_size > 1:
462
+ metadata_list = result.get("metadata", [])
463
+ audio_codes_list = result.get("audio_codes", [])
464
+ all_metadata_list.extend(metadata_list)
465
+ all_audio_codes_list.extend(audio_codes_list)
466
+ else:
467
+ metadata = result.get("metadata", {})
468
+ audio_codes = result.get("audio_codes", "")
469
+ all_metadata_list.append(metadata)
470
+ all_audio_codes_list.append(audio_codes)
471
+
472
+ # Collect time costs from LM extra_outputs
473
+ lm_extra = result.get("extra_outputs", {})
474
+ lm_chunk_time_costs = lm_extra.get("time_costs", {})
475
+ if lm_chunk_time_costs:
476
+ # Accumulate time costs from all chunks
477
+ for key in ["phase1_time", "phase2_time", "total_time"]:
478
+ if key in lm_chunk_time_costs:
479
+ lm_total_time_costs[key] += lm_chunk_time_costs[key]
480
+
481
+ time_str = ", ".join([f"{k}: {v:.2f}s" for k, v in lm_chunk_time_costs.items()])
482
+ lm_status.append(f"✅ LM chunk {chunk_idx+1}: {time_str}")
483
+
484
+ lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
485
+ lm_generated_audio_codes_list = all_audio_codes_list
486
+
487
+ # Set audio_code_string_to_use based on infer_type
488
+ if infer_type == "llm_dit":
489
+ # If batch mode, use list; otherwise use single string
490
+ if actual_batch_size > 1:
491
+ audio_code_string_to_use = all_audio_codes_list
492
+ else:
493
+ audio_code_string_to_use = all_audio_codes_list[0] if all_audio_codes_list else ""
494
+ else:
495
+ # For "dit" mode, keep user-provided codes or empty
496
+ audio_code_string_to_use = params.audio_codes
497
+
498
+ # Update metadata from LM if not provided by user
499
+ if lm_generated_metadata:
500
+ bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics = _update_metadata_from_lm(
501
+ metadata=lm_generated_metadata,
502
+ bpm=bpm,
503
+ key_scale=key_scale,
504
+ time_signature=time_signature,
505
+ audio_duration=audio_duration,
506
+ vocal_language=dit_input_vocal_language,
507
+ caption=dit_input_caption,
508
+ lyrics=dit_input_lyrics)
509
+ if not params.bpm:
510
+ params.cot_bpm = bpm
511
+ if not params.keyscale:
512
+ params.cot_keyscale = key_scale
513
+ if not params.timesignature:
514
+ params.cot_timesignature = time_signature
515
+ if not params.duration:
516
+ params.cot_duration = audio_duration
517
+ if not params.vocal_language:
518
+ params.cot_vocal_language = vocal_language
519
+ if not params.caption:
520
+ params.cot_caption = caption
521
+ if not params.lyrics:
522
+ params.cot_lyrics = lyrics
523
+
524
+ # set cot caption and language if needed
525
+ if params.use_cot_caption:
526
+ dit_input_caption = lm_generated_metadata.get("caption", dit_input_caption)
527
+ if params.use_cot_language:
528
+ dit_input_vocal_language = lm_generated_metadata.get("vocal_language", dit_input_vocal_language)
529
+
530
+ # Phase 2: DiT music generation
531
+ # Use seed_for_generation (from config.seed or params.seed) instead of params.seed for actual generation
532
+ result = dit_handler.generate_music(
533
+ captions=dit_input_caption,
534
+ lyrics=dit_input_lyrics,
535
+ bpm=bpm,
536
+ key_scale=key_scale,
537
+ time_signature=time_signature,
538
+ vocal_language=dit_input_vocal_language,
539
+ inference_steps=params.inference_steps,
540
+ guidance_scale=params.guidance_scale,
541
+ use_random_seed=config.use_random_seed,
542
+ seed=seed_for_generation, # Use config.seed (or params.seed fallback) instead of params.seed directly
543
+ reference_audio=params.reference_audio,
544
+ audio_duration=audio_duration,
545
+ batch_size=config.batch_size if config.batch_size is not None else 1,
546
+ src_audio=params.src_audio,
547
+ audio_code_string=audio_code_string_to_use,
548
+ repainting_start=params.repainting_start,
549
+ repainting_end=params.repainting_end,
550
+ instruction=params.instruction,
551
+ audio_cover_strength=params.audio_cover_strength,
552
+ task_type=params.task_type,
553
+ use_adg=params.use_adg,
554
+ cfg_interval_start=params.cfg_interval_start,
555
+ cfg_interval_end=params.cfg_interval_end,
556
+ shift=params.shift,
557
+ infer_method=params.infer_method,
558
+ timesteps=params.timesteps,
559
+ progress=progress,
560
+ )
561
+
562
+ # Check if generation failed
563
+ if not result.get("success", False):
564
+ return GenerationResult(
565
+ audios=[],
566
+ status_message=result.get("status_message", ""),
567
+ extra_outputs={},
568
+ success=False,
569
+ error=result.get("error"),
570
+ )
571
+
572
+ # Extract results from dit_handler.generate_music dict
573
+ dit_audios = result.get("audios", [])
574
+ status_message = result.get("status_message", "")
575
+ dit_extra_outputs = result.get("extra_outputs", {})
576
+
577
+ # Use the seed list already prepared above (from config.seed or params.seed fallback)
578
+ # actual_seed_list was computed earlier using dit_handler.prepare_seeds
579
+ seed_list = actual_seed_list
580
+
581
+ # Get base params dictionary
582
+ base_params_dict = params.to_dict()
583
+
584
+ # Save audio files using AudioSaver (format from config)
585
+ audio_format = config.audio_format if config.audio_format else "flac"
586
+ audio_saver = AudioSaver(default_format=audio_format)
587
+
588
+ # Use handler's temp_dir for saving files
589
+ if save_dir is not None:
590
+ os.makedirs(save_dir, exist_ok=True)
591
+
592
+ # Build audios list for GenerationResult with params and save files
593
+ # Audio saving and UUID generation handled here, outside of handler
594
+ audios = []
595
+ for idx, dit_audio in enumerate(dit_audios):
596
+ # Create a copy of params dict for this audio
597
+ audio_params = base_params_dict.copy()
598
+
599
+ # Update audio-specific values
600
+ audio_params["seed"] = seed_list[idx] if idx < len(seed_list) else None
601
+
602
+ # Add audio codes if batch mode
603
+ if lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list):
604
+ audio_params["audio_codes"] = lm_generated_audio_codes_list[idx]
605
+
606
+ # Get audio tensor and metadata
607
+ audio_tensor = dit_audio.get("tensor")
608
+ sample_rate = dit_audio.get("sample_rate", 48000)
609
+
610
+ # Generate UUID for this audio (moved from handler)
611
+ batch_seed = seed_list[idx] if idx < len(seed_list) else seed_list[0] if seed_list else -1
612
+ audio_code_str = lm_generated_audio_codes_list[idx] if (
613
+ lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list)) else audio_code_string_to_use
614
+ if isinstance(audio_code_str, list):
615
+ audio_code_str = audio_code_str[idx] if idx < len(audio_code_str) else ""
616
+
617
+ audio_key = generate_uuid_from_params(audio_params)
618
+
619
+ # Save audio file (handled outside handler)
620
+ audio_path = None
621
+ if audio_tensor is not None and save_dir is not None:
622
+ try:
623
+ audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}")
624
+ audio_path = audio_saver.save_audio(audio_tensor,
625
+ audio_file,
626
+ sample_rate=sample_rate,
627
+ format=audio_format,
628
+ channels_first=True)
629
+ except Exception as e:
630
+ logger.error(f"[generate_music] Failed to save audio file: {e}")
631
+ audio_path = "" # Fallback to empty path
632
+
633
+ audio_dict = {
634
+ "path": audio_path or "", # File path (saved here, not in handler)
635
+ "tensor": audio_tensor, # Audio tensor [channels, samples], CPU, float32
636
+ "key": audio_key,
637
+ "sample_rate": sample_rate,
638
+ "params": audio_params,
639
+ }
640
+
641
+ audios.append(audio_dict)
642
+
643
+ # Merge extra_outputs: include dit_extra_outputs (latents, masks) and add LM metadata
644
+ extra_outputs = dit_extra_outputs.copy()
645
+ extra_outputs["lm_metadata"] = lm_generated_metadata
646
+
647
+ # Merge time_costs from both LM and DiT into a unified dictionary
648
+ unified_time_costs = {}
649
+
650
+ # Add LM time costs (if LM was used)
651
+ if use_lm and lm_total_time_costs:
652
+ for key, value in lm_total_time_costs.items():
653
+ unified_time_costs[f"lm_{key}"] = value
654
+
655
+ # Add DiT time costs (if available)
656
+ dit_time_costs = dit_extra_outputs.get("time_costs", {})
657
+ if dit_time_costs:
658
+ for key, value in dit_time_costs.items():
659
+ unified_time_costs[f"dit_{key}"] = value
660
+
661
+ # Calculate total pipeline time
662
+ if unified_time_costs:
663
+ lm_total = unified_time_costs.get("lm_total_time", 0.0)
664
+ dit_total = unified_time_costs.get("dit_total_time_cost", 0.0)
665
+ unified_time_costs["pipeline_total_time"] = lm_total + dit_total
666
+
667
+ # Update extra_outputs with unified time_costs
668
+ extra_outputs["time_costs"] = unified_time_costs
669
+
670
+ if lm_status:
671
+ status_message = "\n".join(lm_status) + "\n" + status_message
672
+ else:
673
+ status_message = status_message
674
+ # Create and return GenerationResult
675
+ return GenerationResult(
676
+ audios=audios,
677
+ status_message=status_message,
678
+ extra_outputs=extra_outputs,
679
+ success=True,
680
+ error=None,
681
+ )
682
+
683
+ except Exception as e:
684
+ logger.exception("Music generation failed")
685
+ return GenerationResult(
686
+ audios=[],
687
+ status_message=f"Error: {str(e)}",
688
+ extra_outputs={},
689
+ success=False,
690
+ error=str(e),
691
+ )
692
+
693
+
694
+ def understand_music(
695
+ llm_handler,
696
+ audio_codes: str,
697
+ temperature: float = 0.85,
698
+ top_k: Optional[int] = None,
699
+ top_p: Optional[float] = None,
700
+ repetition_penalty: float = 1.0,
701
+ use_constrained_decoding: bool = True,
702
+ constrained_decoding_debug: bool = False,
703
+ ) -> UnderstandResult:
704
+ """Understand music from audio codes using the 5Hz Language Model.
705
+
706
+ This function analyzes audio semantic codes and generates metadata about the music,
707
+ including caption, lyrics, BPM, duration, key scale, language, and time signature.
708
+
709
+ If audio_codes is empty or "NO USER INPUT", the LM will generate a sample example
710
+ instead of analyzing existing codes.
711
+
712
+ Note: cfg_scale and negative_prompt are not supported in understand mode.
713
+
714
+ Args:
715
+ llm_handler: Initialized LLM handler (LLMHandler instance)
716
+ audio_codes: String of audio code tokens (e.g., "<|audio_code_123|><|audio_code_456|>...")
717
+ Use empty string or "NO USER INPUT" to generate a sample example.
718
+ temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
719
+ top_k: Top-K sampling (None or 0 = disabled)
720
+ top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
721
+ repetition_penalty: Repetition penalty (1.0 = no penalty)
722
+ use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
723
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
724
+
725
+ Returns:
726
+ UnderstandResult with parsed metadata fields and status
727
+
728
+ Example:
729
+ >>> result = understand_music(llm_handler, audio_codes="<|audio_code_123|>...")
730
+ >>> if result.success:
731
+ ... print(f"Caption: {result.caption}")
732
+ ... print(f"BPM: {result.bpm}")
733
+ ... print(f"Lyrics: {result.lyrics}")
734
+ """
735
+ # Check if LLM is initialized
736
+ if not llm_handler.llm_initialized:
737
+ return UnderstandResult(
738
+ status_message="5Hz LM not initialized. Please initialize it first.",
739
+ success=False,
740
+ error="LLM not initialized",
741
+ )
742
+
743
+ # If codes are empty, use "NO USER INPUT" to generate a sample example
744
+ if not audio_codes or not audio_codes.strip():
745
+ audio_codes = "NO USER INPUT"
746
+
747
+ try:
748
+ # Call LLM understanding
749
+ metadata, status = llm_handler.understand_audio_from_codes(
750
+ audio_codes=audio_codes,
751
+ temperature=temperature,
752
+ top_k=top_k,
753
+ top_p=top_p,
754
+ repetition_penalty=repetition_penalty,
755
+ use_constrained_decoding=use_constrained_decoding,
756
+ constrained_decoding_debug=constrained_decoding_debug,
757
+ )
758
+
759
+ # Check if LLM returned empty metadata (error case)
760
+ if not metadata:
761
+ return UnderstandResult(
762
+ status_message=status or "Failed to understand audio codes",
763
+ success=False,
764
+ error=status or "Empty metadata returned",
765
+ )
766
+
767
+ # Extract and convert fields
768
+ caption = metadata.get('caption', '')
769
+ lyrics = metadata.get('lyrics', '')
770
+ keyscale = metadata.get('keyscale', '')
771
+ language = metadata.get('language', metadata.get('vocal_language', ''))
772
+ timesignature = metadata.get('timesignature', '')
773
+
774
+ # Convert BPM to int
775
+ bpm = None
776
+ bpm_value = metadata.get('bpm')
777
+ if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
778
+ try:
779
+ bpm = int(bpm_value)
780
+ except (ValueError, TypeError):
781
+ pass
782
+
783
+ # Convert duration to float
784
+ duration = None
785
+ duration_value = metadata.get('duration')
786
+ if duration_value is not None and duration_value != 'N/A' and duration_value != '':
787
+ try:
788
+ duration = float(duration_value)
789
+ except (ValueError, TypeError):
790
+ pass
791
+
792
+ # Clean up N/A values
793
+ if keyscale == 'N/A':
794
+ keyscale = ''
795
+ if language == 'N/A':
796
+ language = ''
797
+ if timesignature == 'N/A':
798
+ timesignature = ''
799
+
800
+ return UnderstandResult(
801
+ caption=caption,
802
+ lyrics=lyrics,
803
+ bpm=bpm,
804
+ duration=duration,
805
+ keyscale=keyscale,
806
+ language=language,
807
+ timesignature=timesignature,
808
+ status_message=status,
809
+ success=True,
810
+ error=None,
811
+ )
812
+
813
+ except Exception as e:
814
+ logger.exception("Music understanding failed")
815
+ return UnderstandResult(
816
+ status_message=f"Error: {str(e)}",
817
+ success=False,
818
+ error=str(e),
819
+ )
820
+
821
+
822
+ @dataclass
823
+ class CreateSampleResult:
824
+ """Result of creating a music sample from a natural language query.
825
+
826
+ This is used by the "Simple Mode" / "Inspiration Mode" feature where users
827
+ provide a natural language description and the LLM generates a complete
828
+ sample with caption, lyrics, and metadata.
829
+
830
+ Attributes:
831
+ # Metadata Fields
832
+ caption: Generated detailed music description/caption
833
+ lyrics: Generated lyrics (or "[Instrumental]" for instrumental music)
834
+ bpm: Beats per minute (None if not generated)
835
+ duration: Duration in seconds (None if not generated)
836
+ keyscale: Musical key (e.g., "C Major")
837
+ language: Vocal language code (e.g., "en", "zh")
838
+ timesignature: Time signature (e.g., "4")
839
+ instrumental: Whether this is an instrumental piece
840
+
841
+ # Status
842
+ status_message: Status message from sample creation
843
+ success: Whether sample creation completed successfully
844
+ error: Error message if sample creation failed
845
+ """
846
+ # Metadata Fields
847
+ caption: str = ""
848
+ lyrics: str = ""
849
+ bpm: Optional[int] = None
850
+ duration: Optional[float] = None
851
+ keyscale: str = ""
852
+ language: str = ""
853
+ timesignature: str = ""
854
+ instrumental: bool = False
855
+
856
+ # Status
857
+ status_message: str = ""
858
+ success: bool = True
859
+ error: Optional[str] = None
860
+
861
+ def to_dict(self) -> Dict[str, Any]:
862
+ """Convert result to dictionary for JSON serialization."""
863
+ return asdict(self)
864
+
865
+
866
+ def create_sample(
867
+ llm_handler,
868
+ query: str,
869
+ instrumental: bool = False,
870
+ vocal_language: Optional[str] = None,
871
+ temperature: float = 0.85,
872
+ top_k: Optional[int] = None,
873
+ top_p: Optional[float] = None,
874
+ repetition_penalty: float = 1.0,
875
+ use_constrained_decoding: bool = True,
876
+ constrained_decoding_debug: bool = False,
877
+ ) -> CreateSampleResult:
878
+ """Create a music sample from a natural language query using the 5Hz Language Model.
879
+
880
+ This is the "Simple Mode" / "Inspiration Mode" feature that takes a user's natural
881
+ language description of music and generates a complete sample including:
882
+ - Detailed caption/description
883
+ - Lyrics (unless instrumental)
884
+ - Metadata (BPM, duration, key, language, time signature)
885
+
886
+ Note: cfg_scale and negative_prompt are not supported in create_sample mode.
887
+
888
+ Args:
889
+ llm_handler: Initialized LLM handler (LLMHandler instance)
890
+ query: User's natural language music description (e.g., "a soft Bengali love song")
891
+ instrumental: Whether to generate instrumental music (no vocals)
892
+ vocal_language: Allowed vocal language for constrained decoding (e.g., "en", "zh").
893
+ If provided, the model will be constrained to generate lyrics in this language.
894
+ If None or "unknown", no language constraint is applied.
895
+ temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
896
+ top_k: Top-K sampling (None or 0 = disabled)
897
+ top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
898
+ repetition_penalty: Repetition penalty (1.0 = no penalty)
899
+ use_constrained_decoding: Whether to use FSM-based constrained decoding
900
+ constrained_decoding_debug: Whether to enable debug logging
901
+
902
+ Returns:
903
+ CreateSampleResult with generated sample fields and status
904
+
905
+ Example:
906
+ >>> result = create_sample(llm_handler, "a soft Bengali love song for a quiet evening", vocal_language="bn")
907
+ >>> if result.success:
908
+ ... print(f"Caption: {result.caption}")
909
+ ... print(f"Lyrics: {result.lyrics}")
910
+ ... print(f"BPM: {result.bpm}")
911
+ """
912
+ import torch
913
+ # Debug logging for ZeroGPU diagnosis
914
+ logger.info(f"[create_sample Debug] Entry: IS_HUGGINGFACE_SPACE={IS_HUGGINGFACE_SPACE}")
915
+ logger.info(f"[create_sample Debug] torch.cuda.is_available()={torch.cuda.is_available()}")
916
+ if torch.cuda.is_available():
917
+ logger.info(f"[create_sample Debug] torch.cuda.current_device()={torch.cuda.current_device()}")
918
+ logger.info(f"[create_sample Debug] llm_handler.device={llm_handler.device}, llm_handler.offload_to_cpu={llm_handler.offload_to_cpu}")
919
+ if llm_handler.llm is not None:
920
+ try:
921
+ logger.info(f"[create_sample Debug] Model device: {next(llm_handler.llm.parameters()).device}")
922
+ except Exception as e:
923
+ logger.info(f"[create_sample Debug] Could not get model device: {e}")
924
+
925
+ # Check if LLM is initialized
926
+ if not llm_handler.llm_initialized:
927
+ return CreateSampleResult(
928
+ status_message="5Hz LM not initialized. Please initialize it first.",
929
+ success=False,
930
+ error="LLM not initialized",
931
+ )
932
+
933
+ try:
934
+ # Call LLM to create sample
935
+ metadata, status = llm_handler.create_sample_from_query(
936
+ query=query,
937
+ instrumental=instrumental,
938
+ vocal_language=vocal_language,
939
+ temperature=temperature,
940
+ top_k=top_k,
941
+ top_p=top_p,
942
+ repetition_penalty=repetition_penalty,
943
+ use_constrained_decoding=use_constrained_decoding,
944
+ constrained_decoding_debug=constrained_decoding_debug,
945
+ )
946
+
947
+ # Check if LLM returned empty metadata (error case)
948
+ if not metadata:
949
+ return CreateSampleResult(
950
+ status_message=status or "Failed to create sample",
951
+ success=False,
952
+ error=status or "Empty metadata returned",
953
+ )
954
+
955
+ # Extract and convert fields
956
+ caption = metadata.get('caption', '')
957
+ lyrics = metadata.get('lyrics', '')
958
+ keyscale = metadata.get('keyscale', '')
959
+ language = metadata.get('language', metadata.get('vocal_language', ''))
960
+ timesignature = metadata.get('timesignature', '')
961
+ is_instrumental = metadata.get('instrumental', instrumental)
962
+
963
+ # Convert BPM to int
964
+ bpm = None
965
+ bpm_value = metadata.get('bpm')
966
+ if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
967
+ try:
968
+ bpm = int(bpm_value)
969
+ except (ValueError, TypeError):
970
+ pass
971
+
972
+ # Convert duration to float
973
+ duration = None
974
+ duration_value = metadata.get('duration')
975
+ if duration_value is not None and duration_value != 'N/A' and duration_value != '':
976
+ try:
977
+ duration = float(duration_value)
978
+ except (ValueError, TypeError):
979
+ pass
980
+
981
+ # Clean up N/A values
982
+ if keyscale == 'N/A':
983
+ keyscale = ''
984
+ if language == 'N/A':
985
+ language = ''
986
+ if timesignature == 'N/A':
987
+ timesignature = ''
988
+
989
+ return CreateSampleResult(
990
+ caption=caption,
991
+ lyrics=lyrics,
992
+ bpm=bpm,
993
+ duration=duration,
994
+ keyscale=keyscale,
995
+ language=language,
996
+ timesignature=timesignature,
997
+ instrumental=is_instrumental,
998
+ status_message=status,
999
+ success=True,
1000
+ error=None,
1001
+ )
1002
+
1003
+ except Exception as e:
1004
+ logger.exception("Sample creation failed")
1005
+ return CreateSampleResult(
1006
+ status_message=f"Error: {str(e)}",
1007
+ success=False,
1008
+ error=str(e),
1009
+ )
1010
+
1011
+
1012
+ @dataclass
1013
+ class FormatSampleResult:
1014
+ """Result of formatting user-provided caption and lyrics.
1015
+
1016
+ This is used by the "Format" feature where users provide caption and lyrics,
1017
+ and the LLM formats them into structured music metadata and an enhanced description.
1018
+
1019
+ Attributes:
1020
+ # Metadata Fields
1021
+ caption: Enhanced/formatted music description/caption
1022
+ lyrics: Formatted lyrics (may be same as input or reformatted)
1023
+ bpm: Beats per minute (None if not detected)
1024
+ duration: Duration in seconds (None if not detected)
1025
+ keyscale: Musical key (e.g., "C Major")
1026
+ language: Vocal language code (e.g., "en", "zh")
1027
+ timesignature: Time signature (e.g., "4")
1028
+
1029
+ # Status
1030
+ status_message: Status message from formatting
1031
+ success: Whether formatting completed successfully
1032
+ error: Error message if formatting failed
1033
+ """
1034
+ # Metadata Fields
1035
+ caption: str = ""
1036
+ lyrics: str = ""
1037
+ bpm: Optional[int] = None
1038
+ duration: Optional[float] = None
1039
+ keyscale: str = ""
1040
+ language: str = ""
1041
+ timesignature: str = ""
1042
+
1043
+ # Status
1044
+ status_message: str = ""
1045
+ success: bool = True
1046
+ error: Optional[str] = None
1047
+
1048
+ def to_dict(self) -> Dict[str, Any]:
1049
+ """Convert result to dictionary for JSON serialization."""
1050
+ return asdict(self)
1051
+
1052
+
1053
+ def format_sample(
1054
+ llm_handler,
1055
+ caption: str,
1056
+ lyrics: str,
1057
+ user_metadata: Optional[Dict[str, Any]] = None,
1058
+ temperature: float = 0.85,
1059
+ top_k: Optional[int] = None,
1060
+ top_p: Optional[float] = None,
1061
+ repetition_penalty: float = 1.0,
1062
+ use_constrained_decoding: bool = True,
1063
+ constrained_decoding_debug: bool = False,
1064
+ ) -> FormatSampleResult:
1065
+ """Format user-provided caption and lyrics using the 5Hz Language Model.
1066
+
1067
+ This function takes user input (caption and lyrics) and generates structured
1068
+ music metadata including an enhanced caption, BPM, duration, key, language,
1069
+ and time signature.
1070
+
1071
+ If user_metadata is provided, those values will be used to constrain the
1072
+ decoding, ensuring the output matches user-specified values.
1073
+
1074
+ Note: cfg_scale and negative_prompt are not supported in format mode.
1075
+
1076
+ Args:
1077
+ llm_handler: Initialized LLM handler (LLMHandler instance)
1078
+ caption: User's caption/description (e.g., "Latin pop, reggaeton")
1079
+ lyrics: User's lyrics with structure tags
1080
+ user_metadata: Optional dict with user-provided metadata to constrain decoding.
1081
+ Supported keys: bpm, duration, keyscale, timesignature, language
1082
+ temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
1083
+ top_k: Top-K sampling (None or 0 = disabled)
1084
+ top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
1085
+ repetition_penalty: Repetition penalty (1.0 = no penalty)
1086
+ use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
1087
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
1088
+
1089
+ Returns:
1090
+ FormatSampleResult with formatted metadata fields and status
1091
+
1092
+ Example:
1093
+ >>> result = format_sample(llm_handler, "Latin pop, reggaeton", "[Verse 1]\\nHola mundo...")
1094
+ >>> if result.success:
1095
+ ... print(f"Caption: {result.caption}")
1096
+ ... print(f"BPM: {result.bpm}")
1097
+ ... print(f"Lyrics: {result.lyrics}")
1098
+ """
1099
+ # Check if LLM is initialized
1100
+ if not llm_handler.llm_initialized:
1101
+ return FormatSampleResult(
1102
+ status_message="5Hz LM not initialized. Please initialize it first.",
1103
+ success=False,
1104
+ error="LLM not initialized",
1105
+ )
1106
+
1107
+ try:
1108
+ # Call LLM formatting
1109
+ metadata, status = llm_handler.format_sample_from_input(
1110
+ caption=caption,
1111
+ lyrics=lyrics,
1112
+ user_metadata=user_metadata,
1113
+ temperature=temperature,
1114
+ top_k=top_k,
1115
+ top_p=top_p,
1116
+ repetition_penalty=repetition_penalty,
1117
+ use_constrained_decoding=use_constrained_decoding,
1118
+ constrained_decoding_debug=constrained_decoding_debug,
1119
+ )
1120
+
1121
+ # Check if LLM returned empty metadata (error case)
1122
+ if not metadata:
1123
+ return FormatSampleResult(
1124
+ status_message=status or "Failed to format input",
1125
+ success=False,
1126
+ error=status or "Empty metadata returned",
1127
+ )
1128
+
1129
+ # Extract and convert fields
1130
+ result_caption = metadata.get('caption', '')
1131
+ result_lyrics = metadata.get('lyrics', lyrics) # Fall back to input lyrics
1132
+ keyscale = metadata.get('keyscale', '')
1133
+ language = metadata.get('language', metadata.get('vocal_language', ''))
1134
+ timesignature = metadata.get('timesignature', '')
1135
+
1136
+ # Convert BPM to int
1137
+ bpm = None
1138
+ bpm_value = metadata.get('bpm')
1139
+ if bpm_value is not None and bpm_value != 'N/A' and bpm_value != '':
1140
+ try:
1141
+ bpm = int(bpm_value)
1142
+ except (ValueError, TypeError):
1143
+ pass
1144
+
1145
+ # Convert duration to float
1146
+ duration = None
1147
+ duration_value = metadata.get('duration')
1148
+ if duration_value is not None and duration_value != 'N/A' and duration_value != '':
1149
+ try:
1150
+ duration = float(duration_value)
1151
+ except (ValueError, TypeError):
1152
+ pass
1153
+
1154
+ # Clean up N/A values
1155
+ if keyscale == 'N/A':
1156
+ keyscale = ''
1157
+ if language == 'N/A':
1158
+ language = ''
1159
+ if timesignature == 'N/A':
1160
+ timesignature = ''
1161
+
1162
+ return FormatSampleResult(
1163
+ caption=result_caption,
1164
+ lyrics=result_lyrics,
1165
+ bpm=bpm,
1166
+ duration=duration,
1167
+ keyscale=keyscale,
1168
+ language=language,
1169
+ timesignature=timesignature,
1170
+ status_message=status,
1171
+ success=True,
1172
+ error=None,
1173
+ )
1174
+
1175
+ except Exception as e:
1176
+ logger.exception("Format sample failed")
1177
+ return FormatSampleResult(
1178
+ status_message=f"Error: {str(e)}",
1179
+ success=False,
1180
+ error=str(e),
1181
+ )
acestep/llm_inference.py ADDED
The diff for this file is too large to render. See raw diff
 
acestep/local_cache.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Local cache module to replace Redis
2
+
3
+ Uses diskcache as backend, provides Redis-compatible API.
4
+ Supports persistent storage and TTL expiration.
5
+ """
6
+
7
+ import json
8
+ import os
9
+ from typing import Any, Optional
10
+ from threading import Lock
11
+
12
+ try:
13
+ from diskcache import Cache
14
+ HAS_DISKCACHE = True
15
+ except ImportError:
16
+ HAS_DISKCACHE = False
17
+
18
+
19
+ class LocalCache:
20
+ """
21
+ Local cache implementation with Redis-compatible API.
22
+ Uses diskcache as backend, supports persistence and TTL.
23
+ """
24
+
25
+ _instance = None
26
+ _lock = Lock()
27
+
28
+ def __new__(cls, cache_dir: Optional[str] = None):
29
+ """Singleton pattern"""
30
+ if cls._instance is None:
31
+ with cls._lock:
32
+ if cls._instance is None:
33
+ cls._instance = super().__new__(cls)
34
+ cls._instance._initialized = False
35
+ return cls._instance
36
+
37
+ def __init__(self, cache_dir: Optional[str] = None):
38
+ if getattr(self, '_initialized', False):
39
+ return
40
+
41
+ if not HAS_DISKCACHE:
42
+ raise ImportError(
43
+ "diskcache not installed. Run: pip install diskcache"
44
+ )
45
+
46
+ if cache_dir is None:
47
+ cache_dir = os.path.join(
48
+ os.path.dirname(os.path.dirname(__file__)),
49
+ ".cache",
50
+ "local_redis"
51
+ )
52
+
53
+ os.makedirs(cache_dir, exist_ok=True)
54
+ self._cache = Cache(cache_dir)
55
+ self._initialized = True
56
+
57
+ def set(self, name: str, value: Any, ex: Optional[int] = None) -> bool:
58
+ """
59
+ Set key-value pair
60
+
61
+ Args:
62
+ name: Key name
63
+ value: Value (auto-serialize dict/list)
64
+ ex: Expiration time (seconds)
65
+
66
+ Returns:
67
+ bool: Success status
68
+ """
69
+ if isinstance(value, (dict, list)):
70
+ value = json.dumps(value, ensure_ascii=False)
71
+ self._cache.set(name, value, expire=ex)
72
+ return True
73
+
74
+ def get(self, name: str) -> Optional[str]:
75
+ """Get value"""
76
+ return self._cache.get(name)
77
+
78
+ def delete(self, name: str) -> int:
79
+ """Delete key, returns number of deleted items"""
80
+ return 1 if self._cache.delete(name) else 0
81
+
82
+ def exists(self, name: str) -> bool:
83
+ """Check if key exists"""
84
+ return name in self._cache
85
+
86
+ def keys(self, pattern: str = "*") -> list:
87
+ """
88
+ Get list of matching keys
89
+ Note: Simplified implementation, only supports prefix and full matching
90
+ """
91
+ if pattern == "*":
92
+ return list(self._cache.iterkeys())
93
+
94
+ prefix = pattern.rstrip("*")
95
+ return [k for k in self._cache.iterkeys() if k.startswith(prefix)]
96
+
97
+ def expire(self, name: str, seconds: int) -> bool:
98
+ """Set key expiration time"""
99
+ value = self._cache.get(name)
100
+ if value is not None:
101
+ self._cache.set(name, value, expire=seconds)
102
+ return True
103
+ return False
104
+
105
+ def ttl(self, name: str) -> int:
106
+ """
107
+ Get remaining time to live (seconds)
108
+ Note: diskcache does not directly support TTL queries
109
+ """
110
+ if name in self._cache:
111
+ return -1 # Exists but TTL unknown
112
+ return -2 # Key does not exist
113
+
114
+ def close(self):
115
+ """Close cache connection"""
116
+ if hasattr(self, '_cache'):
117
+ self._cache.close()
118
+
119
+
120
+ # Lazily initialized global instance
121
+ _local_cache: Optional[LocalCache] = None
122
+
123
+
124
+ def get_local_cache(cache_dir: Optional[str] = None) -> LocalCache:
125
+ """Get local cache instance"""
126
+ global _local_cache
127
+ if _local_cache is None:
128
+ _local_cache = LocalCache(cache_dir)
129
+ return _local_cache
acestep/test_time_scaling.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test-Time Scaling Module
3
+ Implements perplexity-based scoring for generated audio codes
4
+ """
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from typing import Tuple, Optional, Dict, Any, List
8
+ from loguru import logger
9
+ import yaml
10
+ import math
11
+ import re
12
+
13
+
14
+ def pmi_score(log_prob_conditional: float, log_prob_unconditional: float) -> float:
15
+ """
16
+ Calculate Pointwise Mutual Information (PMI) score.
17
+
18
+ PMI = log P(condition|codes) - log P(condition)
19
+ = log [P(codes|condition) / P(codes)]
20
+
21
+ This removes the bias from P(condition) and measures how much the codes
22
+ improve our ability to predict the condition.
23
+
24
+ Args:
25
+ log_prob_conditional: Average log probability of condition given codes
26
+ log_prob_unconditional: Average log probability of condition without codes
27
+
28
+ Returns:
29
+ PMI score (higher is better, can be positive or negative)
30
+ - Positive: codes improve prediction → good match
31
+ - Zero: codes don't help → no correlation
32
+ - Negative: codes hurt prediction → poor match
33
+ """
34
+ return log_prob_conditional - log_prob_unconditional
35
+
36
+
37
+ def pmi_to_normalized_score(pmi: float, scale: float = 0.1) -> float:
38
+ """
39
+ Convert PMI score to normalized [0, 1] range using sigmoid function.
40
+
41
+ score = sigmoid(PMI / scale) = 1 / (1 + exp(-PMI / scale))
42
+
43
+ Args:
44
+ pmi: PMI score (can be positive or negative)
45
+ scale: Scale parameter to control sensitivity (default 0.1)
46
+ - Smaller scale: more sensitive to PMI changes
47
+ - Larger scale: less sensitive to PMI changes
48
+
49
+ Returns:
50
+ Normalized score in [0, 1] range, where:
51
+ - PMI > 0 → score > 0.5 (good match)
52
+ - PMI = 0 → score = 0.5 (neutral)
53
+ - PMI < 0 → score < 0.5 (poor match)
54
+
55
+ Examples (scale=1.0):
56
+ PMI=2.0 → score≈0.88 (excellent)
57
+ PMI=1.0 → score≈0.73 (good)
58
+ PMI=0.0 → score=0.50 (neutral)
59
+ PMI=-1.0 → score≈0.27 (poor)
60
+ PMI=-2.0 → score≈0.12 (bad)
61
+ """
62
+ return 1.0 / (1.0 + math.exp(-pmi / scale))
63
+
64
+
65
+ def _get_logits_and_target_for_scoring(llm_handler, formatted_prompt: str,
66
+ target_text: str) -> Tuple[torch.Tensor, torch.Tensor]:
67
+ """
68
+ Args:
69
+ llm_handler: The handler containing the model and tokenizer.
70
+ formatted_prompt: The input context.
71
+ target_text: The text we want to calculate probability/recall for.
72
+
73
+ Returns:
74
+ Tuple of (target_logits, target_ids)
75
+ - target_logits: Logits used to predict the target tokens.
76
+ - target_ids: The ground truth token IDs of the target.
77
+ """
78
+ model = llm_handler.get_hf_model_for_scoring()
79
+ tokenizer = llm_handler.llm_tokenizer
80
+ device = llm_handler.device if llm_handler.llm_backend == "pt" else next(model.parameters()).device
81
+
82
+ # 1. Tokenize prompt ONLY to get its length (used for slicing later).
83
+ # We must ensure special tokens are added to count the offset correctly.
84
+ prompt_tokens_temp = tokenizer(formatted_prompt, return_tensors="pt", add_special_tokens=True)
85
+ prompt_len = prompt_tokens_temp['input_ids'].shape[1]
86
+
87
+ # 2. Tokenize the FULL text (Prompt + Target).
88
+ # This ensures subword merging at boundaries is handled correctly by the tokenizer.
89
+ full_text = formatted_prompt + target_text
90
+ full_tokens = tokenizer(full_text, return_tensors="pt", padding=False, truncation=True, add_special_tokens=True).to(device)
91
+
92
+ input_ids = full_tokens['input_ids']
93
+
94
+ # Safety check: if target was empty or truncated entirely
95
+ if input_ids.shape[1] <= prompt_len:
96
+ return torch.empty(0, device=device), torch.empty(0, device=device)
97
+
98
+ # 3. Forward Pass (Teacher Forcing)
99
+ with torch.no_grad():
100
+ with llm_handler._load_model_context():
101
+ outputs = model(input_ids=input_ids, attention_mask=full_tokens['attention_mask'])
102
+ all_logits = outputs.logits # [1, seq_len, vocab_size]
103
+
104
+ # 4. Extract Logits and Labels
105
+ # We need to predict `input_ids[i]`. The logit for this is at `all_logits[i-1]`.
106
+ # Target starts at index `prompt_len`.
107
+ # So we need logits from `prompt_len - 1` up to the second to last position.
108
+
109
+ target_logits = all_logits[0, prompt_len - 1:-1, :] # [target_len, vocab_size]
110
+ target_ids = input_ids[0, prompt_len:] # [target_len]
111
+
112
+ return target_logits, target_ids
113
+
114
+
115
+ # ==============================================================================
116
+ # Scoring Logic
117
+ # ==============================================================================
118
+
119
+
120
+ def _calculate_topk_recall(llm_handler,
121
+ formatted_prompt: str,
122
+ target_text: str,
123
+ topk: int = 10) -> Tuple[float, Dict[int, float]]:
124
+ """
125
+ Calculate top-k recall for target text given prompt.
126
+ Checks if the ground truth token is within the top-k probabilities at each step.
127
+ """
128
+ # Use the fixed helper to get aligned logits/labels
129
+ pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)
130
+
131
+ if target_ids.shape[0] == 0:
132
+ return 0.0, {}
133
+
134
+ target_len = target_ids.shape[0]
135
+
136
+ # Get top-k indices for all positions at once
137
+ # topk_indices: [target_len, topk]
138
+ _, topk_indices = torch.topk(pred_logits, k=min(topk, pred_logits.shape[-1]), dim=-1)
139
+
140
+ recall_per_k = {}
141
+ position_scores = []
142
+
143
+ # Convert to list for faster CPU iteration
144
+ target_ids_list = target_ids.tolist()
145
+ topk_indices_list = topk_indices.tolist()
146
+
147
+ for k in range(1, topk + 1):
148
+ hits = 0
149
+ for pos in range(target_len):
150
+ gt_token = target_ids_list[pos]
151
+ # Check the top-k slice
152
+ topk_at_pos = topk_indices_list[pos][:k]
153
+
154
+ if gt_token in topk_at_pos:
155
+ hits += 1
156
+ # Calculate position-weighted score only once (when k=topk)
157
+ if k == topk:
158
+ rank = topk_at_pos.index(gt_token) + 1
159
+ # Rank 1 = 1.0, Rank k = small positive
160
+ position_weight = 1.0 - (rank - 1) / topk
161
+ position_scores.append(position_weight)
162
+
163
+ recall_per_k[k] = hits / target_len if target_len > 0 else 0.0
164
+
165
+ # Fill scores for positions where GT was NOT in top-k
166
+ while len(position_scores) < target_len:
167
+ position_scores.append(0.0)
168
+
169
+ average_recall = sum(position_scores) / len(position_scores) if position_scores else 0.0
170
+
171
+ return average_recall, recall_per_k
172
+
173
+
174
+ def _calculate_metadata_recall(llm_handler,
175
+ formatted_prompt: str,
176
+ fields_dict: Dict[str, Any],
177
+ topk: int = 10) -> Dict[str, float]:
178
+ """
179
+ Args:
180
+ fields_dict: Dictionary of {field_name: field_value}
181
+ """
182
+ if not fields_dict:
183
+ return {}
184
+
185
+ field_scores = {}
186
+
187
+ for field_name in sorted(fields_dict.keys()):
188
+ # Construct target text for this specific field
189
+ # e.g. <think>\nbpm: 120\n</think>\n
190
+ field_yaml = yaml.dump({field_name: fields_dict[field_name]}, allow_unicode=True, sort_keys=True).strip()
191
+ field_target_text = f"<think>\n{field_yaml}\n</think>\n"
192
+
193
+ # Calculate recall using the robust logic
194
+ avg_score, _ = _calculate_topk_recall(llm_handler, formatted_prompt, field_target_text, topk=topk)
195
+
196
+ field_scores[field_name] = avg_score
197
+ logger.debug(f"Recall for {field_name}: {avg_score:.4f}")
198
+
199
+ return field_scores
200
+
201
+
202
+ def _calculate_log_prob(
203
+ llm_handler,
204
+ formatted_prompt: str,
205
+ target_text: str,
206
+ temperature: float = 1.0 # Kept for API compatibility, but ignored for scoring
207
+ ) -> float:
208
+ """
209
+ Calculate average log probability of target text given prompt.
210
+ """
211
+ pred_logits, target_ids = _get_logits_and_target_for_scoring(llm_handler, formatted_prompt, target_text)
212
+
213
+ if target_ids.shape[0] == 0:
214
+ return float('-inf')
215
+
216
+ # FIX: Do not divide by temperature.
217
+ # Log-probability for PMI/Perplexity should be exact.
218
+
219
+ # Calculate log probabilities (log_softmax)
220
+ log_probs = F.log_softmax(pred_logits, dim=-1) # [target_len, vocab_size]
221
+
222
+ # Gather log probabilities of the ground truth tokens
223
+ target_log_probs = log_probs[torch.arange(target_ids.shape[0]), target_ids]
224
+
225
+ # Return average log probability
226
+ mean_log_prob = target_log_probs.mean().item()
227
+
228
+ return mean_log_prob
229
+
230
+
231
+ def calculate_reward_score(
232
+ scores: Dict[str, float],
233
+ weights_config: Optional[Dict[str, float]] = None
234
+ ) -> Tuple[float, str]:
235
+ """
236
+ Reward Model Calculator: Computes a final reward based on user priorities.
237
+
238
+ Priority Logic:
239
+ 1. Caption (Highest): The overall vibe/style must match.
240
+ 2. Lyrics (Medium): Content accuracy is important but secondary to vibe.
241
+ 3. Metadata (Lowest): Technical constraints (BPM, Key) allow for slight deviations.
242
+
243
+ Strategy: Dynamic Weighted Sum
244
+ - Metadata fields are aggregated into a single 'metadata' score first.
245
+ - Weights are dynamically renormalized if any component (e.g., lyrics) is missing.
246
+
247
+ Args:
248
+ scores: Dictionary of raw scores (0.0 - 1.0) from the evaluation module.
249
+ weights_config: Optional custom weights. Defaults to:
250
+ Caption (50%), Lyrics (30%), Metadata (20%).
251
+
252
+ Returns:
253
+ final_reward: The calculated reward score (0.0 - 1.0).
254
+ explanation: A formatted string explaining how the score was derived.
255
+ """
256
+
257
+ # 1. Default Preference Configuration
258
+ # These weights determine the relative importance of each component.
259
+ if weights_config is None:
260
+ weights_config = {
261
+ 'caption': 0.50, # High priority: Style/Vibe
262
+ 'lyrics': 0.30, # Medium priority: Content
263
+ 'metadata': 0.20 # Low priority: Technical details
264
+ }
265
+
266
+ # 2. Extract and Group Scores
267
+ # Caption and Lyrics are standalone high-level features.
268
+ caption_score = scores.get('caption')
269
+ lyrics_score = scores.get('lyrics')
270
+
271
+ # Metadata fields (bpm, key, duration, etc.) are aggregated.
272
+ # We treat them as a single "Technical Score" to prevent them from
273
+ # diluting the weight of Caption/Lyrics simply by having many fields.
274
+ meta_scores_list = [
275
+ val for key, val in scores.items()
276
+ if key not in ['caption', 'lyrics']
277
+ ]
278
+
279
+ # Calculate average of all metadata fields (if any exist)
280
+ meta_aggregate_score = None
281
+ if meta_scores_list:
282
+ meta_aggregate_score = sum(meta_scores_list) / len(meta_scores_list)
283
+
284
+ # 3. specific Active Components & Dynamic Weighting
285
+ # We only include components that actually exist in this generation.
286
+ active_components = {}
287
+
288
+ if caption_score is not None:
289
+ active_components['caption'] = (caption_score, weights_config['caption'])
290
+
291
+ if lyrics_score is not None:
292
+ active_components['lyrics'] = (lyrics_score, weights_config['lyrics'])
293
+
294
+ if meta_aggregate_score is not None:
295
+ active_components['metadata'] = (meta_aggregate_score, weights_config['metadata'])
296
+
297
+ # 4. Calculate Final Weighted Score
298
+ total_base_weight = sum(w for _, w in active_components.values())
299
+ total_score = 0.0
300
+
301
+ breakdown_lines = []
302
+
303
+ if total_base_weight == 0:
304
+ return 0.0, "❌ No valid scores available to calculate reward."
305
+
306
+ # Sort by weight (importance) for display
307
+ sorted_components = sorted(active_components.items(), key=lambda x: x[1][1], reverse=True)
308
+
309
+ for name, (score, base_weight) in sorted_components:
310
+ # Renormalize weight: If lyrics are missing, caption/metadata weights scale up proportionately.
311
+ normalized_weight = base_weight / total_base_weight
312
+ weighted_contribution = score * normalized_weight
313
+ total_score += weighted_contribution
314
+
315
+ breakdown_lines.append(
316
+ f" • {name.title():<8} | Score: {score:.4f} | Weight: {normalized_weight:.2f} "
317
+ f"-> Contrib: +{weighted_contribution:.4f}"
318
+ )
319
+
320
+ return total_score, "\n".join(breakdown_lines)
321
+
322
+ # ==============================================================================
323
+ # Main Public API
324
+ # ==============================================================================
325
+
326
+
327
+ def calculate_pmi_score_per_condition(
328
+ llm_handler,
329
+ audio_codes: str,
330
+ caption: str = "",
331
+ lyrics: str = "",
332
+ metadata: Optional[Dict[str, Any]] = None,
333
+ temperature: float = 1.0,
334
+ topk: int = 10,
335
+ score_scale: float = 0.1,
336
+ ) -> Tuple[Dict[str, float], float, str]:
337
+ """
338
+ Calculate quality score separately for each condition.
339
+ - Metadata: Uses Top-k Recall.
340
+ - Caption/Lyrics: Uses PMI (Normalized).
341
+ """
342
+ if not llm_handler.llm_initialized:
343
+ return {}, 0.0, "❌ LLM not initialized"
344
+
345
+ if not audio_codes or not audio_codes.strip():
346
+ return {}, 0.0, "❌ No audio codes provided"
347
+
348
+ if "caption" not in metadata:
349
+ metadata['caption'] = caption
350
+
351
+ formatted_prompt = llm_handler.build_formatted_prompt_for_understanding(audio_codes=audio_codes, is_negative_prompt=False)
352
+ prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
353
+ try:
354
+ # 1. Calculate Recall for Metadata Fields
355
+ if metadata and isinstance(metadata, dict):
356
+ scores = {}
357
+ # Define which fields use which metric
358
+ metadata_recall_keys = ['bpm', 'duration', 'genres', 'keyscale', 'language', 'timesignature']
359
+ metadata_pmi_keys = ['caption']
360
+ for key in metadata_recall_keys:
361
+ if key in metadata and metadata[key] is not None:
362
+ recall_metadata = {key: metadata[key]}
363
+ field_scores = _calculate_metadata_recall(llm_handler, formatted_prompt, recall_metadata, topk=topk)
364
+ scores.update(field_scores)
365
+
366
+ # 2. Calculate PMI for Caption
367
+ for key in metadata_pmi_keys:
368
+ if key in metadata and metadata[key] is not None:
369
+ cot_yaml = yaml.dump({key: metadata[key]}, allow_unicode=True, sort_keys=True).strip()
370
+ target_text = f"<think>\n{cot_yaml}\n</think>\n"
371
+
372
+ log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
373
+ log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
374
+
375
+ pmi_normalized = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
376
+ scores[key] = pmi_normalized
377
+
378
+ # 3. Calculate PMI for Lyrics
379
+ if lyrics:
380
+ target_text = f"<think>\n</think>\n# Lyric\n{lyrics}\n"
381
+
382
+ log_prob_cond = _calculate_log_prob(llm_handler, formatted_prompt, target_text)
383
+
384
+ prompt_uncond = llm_handler.build_formatted_prompt_for_understanding(audio_codes="NO USER INPUT", is_negative_prompt=False)
385
+ log_prob_uncond = _calculate_log_prob(llm_handler, prompt_uncond, target_text)
386
+
387
+ scores['lyrics'] = pmi_to_normalized_score(log_prob_cond - log_prob_uncond, scale=score_scale)
388
+
389
+ if not scores:
390
+ return {}, 0.0, "❌ No conditions to evaluate"
391
+
392
+ # 4. Global Score
393
+ global_score = sum(scores.values()) / len(scores)
394
+ global_score, breakdown_lines = calculate_reward_score(scores)
395
+
396
+ # Status Message
397
+ status_lines = [breakdown_lines, "\n✅ Per-condition scores (0-1):"]
398
+ for key, score in sorted(scores.items()):
399
+ metric = "Top-k Recall" if key in metadata_recall_keys else "PMI (Norm)"
400
+ status_lines.append(f" {key}: {score:.4f} ({metric})")
401
+ status = "\n".join(status_lines)
402
+ logger.info(f"Calculated scores: {global_score:.4f}\n{status}")
403
+ return scores, global_score, status
404
+
405
+ except Exception as e:
406
+ import traceback
407
+ error_msg = f"❌ Error: {str(e)}"
408
+ logger.error(error_msg)
409
+ logger.error(traceback.format_exc())
410
+ return {}, float('-inf'), error_msg
acestep/third_parts/nano-vllm/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Xingkai Yu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
acestep/third_parts/nano-vllm/README.md ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <img width="300" src="assets/logo.png">
3
+ </p>
4
+
5
+ <p align="center">
6
+ <a href="https://trendshift.io/repositories/15323" target="_blank"><img src="https://trendshift.io/api/badge/repositories/15323" alt="GeeeekExplorer%2Fnano-vllm | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
7
+ </p>
8
+
9
+ # Nano-vLLM
10
+
11
+ A lightweight vLLM implementation built from scratch.
12
+
13
+ ## Key Features
14
+
15
+ * 🚀 **Fast offline inference** - Comparable inference speeds to vLLM
16
+ * 📖 **Readable codebase** - Clean implementation in ~ 1,200 lines of Python code
17
+ * ⚡ **Optimization Suite** - Prefix caching, Tensor Parallelism, Torch compilation, CUDA graph, etc.
18
+
19
+ ## Installation
20
+
21
+ ```bash
22
+ pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
23
+ ```
24
+
25
+ ## Model Download
26
+
27
+ To download the model weights manually, use the following command:
28
+ ```bash
29
+ huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
30
+ --local-dir ~/huggingface/Qwen3-0.6B/ \
31
+ --local-dir-use-symlinks False
32
+ ```
33
+
34
+ ## Quick Start
35
+
36
+ See `example.py` for usage. The API mirrors vLLM's interface with minor differences in the `LLM.generate` method:
37
+ ```python
38
+ from nanovllm import LLM, SamplingParams
39
+ llm = LLM("/YOUR/MODEL/PATH", enforce_eager=True, tensor_parallel_size=1)
40
+ sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
41
+ prompts = ["Hello, Nano-vLLM."]
42
+ outputs = llm.generate(prompts, sampling_params)
43
+ outputs[0]["text"]
44
+ ```
45
+
46
+ ## Benchmark
47
+
48
+ See `bench.py` for benchmark.
49
+
50
+ **Test Configuration:**
51
+ - Hardware: RTX 4070 Laptop (8GB)
52
+ - Model: Qwen3-0.6B
53
+ - Total Requests: 256 sequences
54
+ - Input Length: Randomly sampled between 100–1024 tokens
55
+ - Output Length: Randomly sampled between 100–1024 tokens
56
+
57
+ **Performance Results:**
58
+ | Inference Engine | Output Tokens | Time (s) | Throughput (tokens/s) |
59
+ |----------------|-------------|----------|-----------------------|
60
+ | vLLM | 133,966 | 98.37 | 1361.84 |
61
+ | Nano-vLLM | 133,966 | 93.41 | 1434.13 |
62
+
63
+
64
+ ## Star History
65
+
66
+ [![Star History Chart](https://api.star-history.com/svg?repos=GeeeekExplorer/nano-vllm&type=Date)](https://www.star-history.com/#GeeeekExplorer/nano-vllm&Date)
acestep/third_parts/nano-vllm/bench.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from random import randint, seed
4
+ from nanovllm import LLM, SamplingParams
5
+ # from vllm import LLM, SamplingParams
6
+
7
+
8
+ def main():
9
+ seed(0)
10
+ num_seqs = 256
11
+ max_input_len = 1024
12
+ max_ouput_len = 1024
13
+
14
+ path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
15
+ llm = LLM(path, enforce_eager=False, max_model_len=4096)
16
+
17
+ prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
18
+ sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)]
19
+ # uncomment the following line for vllm
20
+ # prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
21
+
22
+ llm.generate(["Benchmark: "], SamplingParams())
23
+ t = time.time()
24
+ llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
25
+ t = (time.time() - t)
26
+ total_tokens = sum(sp.max_tokens for sp in sampling_params)
27
+ throughput = total_tokens / t
28
+ print(f"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
29
+
30
+
31
+ if __name__ == "__main__":
32
+ main()
acestep/third_parts/nano-vllm/example.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from nanovllm import LLM, SamplingParams
3
+ from transformers import AutoTokenizer
4
+
5
+
6
+ def main():
7
+ path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
8
+ tokenizer = AutoTokenizer.from_pretrained(path)
9
+ llm = LLM(path, enforce_eager=True, tensor_parallel_size=1)
10
+
11
+ sampling_params = SamplingParams(temperature=0.6, max_tokens=256)
12
+ prompts = [
13
+ "introduce yourself",
14
+ "list all prime numbers within 100",
15
+ ]
16
+ prompts = [
17
+ tokenizer.apply_chat_template(
18
+ [{"role": "user", "content": prompt}],
19
+ tokenize=False,
20
+ add_generation_prompt=True,
21
+ )
22
+ for prompt in prompts
23
+ ]
24
+ outputs = llm.generate(prompts, sampling_params)
25
+
26
+ for prompt, output in zip(prompts, outputs):
27
+ print("\n")
28
+ print(f"Prompt: {prompt!r}")
29
+ print(f"Completion: {output['text']!r}")
30
+
31
+
32
+ if __name__ == "__main__":
33
+ main()
acestep/third_parts/nano-vllm/nanovllm/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from nanovllm.llm import LLM
2
+ from nanovllm.sampling_params import SamplingParams
acestep/third_parts/nano-vllm/nanovllm/config.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+ from transformers import AutoConfig
4
+
5
+
6
+ @dataclass
7
+ class Config:
8
+ model: str
9
+ max_num_batched_tokens: int = 16384
10
+ max_num_seqs: int = 512
11
+ max_model_len: int = 4096
12
+ gpu_memory_utilization: float = 0.9
13
+ tensor_parallel_size: int = 1
14
+ enforce_eager: bool = False
15
+ hf_config: AutoConfig | None = None
16
+ eos: int = -1
17
+ kvcache_block_size: int = 256
18
+ num_kvcache_blocks: int = -1
19
+
20
+ def __post_init__(self):
21
+ assert os.path.isdir(self.model)
22
+ assert self.kvcache_block_size % 256 == 0
23
+ assert 1 <= self.tensor_parallel_size <= 8
24
+ self.hf_config = AutoConfig.from_pretrained(self.model)
25
+ self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
26
+ assert self.max_num_batched_tokens >= self.max_model_len
acestep/third_parts/nano-vllm/nanovllm/engine/block_manager.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+ import xxhash
3
+ import numpy as np
4
+
5
+ from nanovllm.engine.sequence import Sequence
6
+
7
+
8
+ class Block:
9
+
10
+ def __init__(self, block_id):
11
+ self.block_id = block_id
12
+ self.ref_count = 0
13
+ self.hash = -1
14
+ self.token_ids = []
15
+
16
+ def update(self, hash: int, token_ids: list[int]):
17
+ self.hash = hash
18
+ self.token_ids = token_ids
19
+
20
+ def reset(self):
21
+ self.ref_count = 1
22
+ self.hash = -1
23
+ self.token_ids = []
24
+
25
+
26
+ class BlockManager:
27
+
28
+ def __init__(self, num_blocks: int, block_size: int):
29
+ self.block_size = block_size
30
+ self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
31
+ self.hash_to_block_id: dict[int, int] = dict()
32
+ self.free_block_ids: deque[int] = deque(range(num_blocks))
33
+ self.used_block_ids: set[int] = set()
34
+
35
+ @classmethod
36
+ def compute_hash(cls, token_ids: list[int], prefix: int = -1):
37
+ h = xxhash.xxh64()
38
+ if prefix != -1:
39
+ h.update(prefix.to_bytes(8, "little"))
40
+ h.update(np.array(token_ids).tobytes())
41
+ return h.intdigest()
42
+
43
+ def _allocate_block(self, block_id: int) -> Block:
44
+ block = self.blocks[block_id]
45
+ assert block.ref_count == 0
46
+ block.reset()
47
+ self.free_block_ids.remove(block_id)
48
+ self.used_block_ids.add(block_id)
49
+ return self.blocks[block_id]
50
+
51
+ def _deallocate_block(self, block_id: int) -> Block:
52
+ assert self.blocks[block_id].ref_count == 0
53
+ self.used_block_ids.remove(block_id)
54
+ self.free_block_ids.append(block_id)
55
+
56
+ def can_allocate(self, seq: Sequence) -> bool:
57
+ return len(self.free_block_ids) >= seq.num_blocks
58
+
59
+ def allocate(self, seq: Sequence):
60
+ assert not seq.block_table
61
+ h = -1
62
+ cache_miss = False
63
+ for i in range(seq.num_blocks):
64
+ token_ids = seq.block(i)
65
+ h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
66
+ block_id = self.hash_to_block_id.get(h, -1)
67
+ if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
68
+ cache_miss = True
69
+ if cache_miss:
70
+ block_id = self.free_block_ids[0]
71
+ block = self._allocate_block(block_id)
72
+ else:
73
+ seq.num_cached_tokens += self.block_size
74
+ if block_id in self.used_block_ids:
75
+ block = self.blocks[block_id]
76
+ block.ref_count += 1
77
+ else:
78
+ block = self._allocate_block(block_id)
79
+ if h != -1:
80
+ block.update(h, token_ids)
81
+ self.hash_to_block_id[h] = block_id
82
+ seq.block_table.append(block_id)
83
+
84
+ def deallocate(self, seq: Sequence):
85
+ for block_id in reversed(seq.block_table):
86
+ block = self.blocks[block_id]
87
+ block.ref_count -= 1
88
+ if block.ref_count == 0:
89
+ # Fix: Clean up hash_to_block_id mapping to prevent stale references
90
+ # This prevents CUDA illegal memory access when prefix cache tries to
91
+ # reuse a block_id that has already been freed
92
+ if block.hash != -1:
93
+ cached_id = self.hash_to_block_id.get(block.hash)
94
+ if cached_id == block_id:
95
+ del self.hash_to_block_id[block.hash]
96
+ self._deallocate_block(block_id)
97
+ seq.num_cached_tokens = 0
98
+ seq.block_table.clear()
99
+
100
+ def can_append(self, seq: Sequence) -> bool:
101
+ return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
102
+
103
+ def may_append(self, seq: Sequence):
104
+ block_table = seq.block_table
105
+ last_block = self.blocks[block_table[-1]]
106
+ if len(seq) % self.block_size == 1:
107
+ assert last_block.hash != -1
108
+ block_id = self.free_block_ids[0]
109
+ self._allocate_block(block_id)
110
+ block_table.append(block_id)
111
+ elif len(seq) % self.block_size == 0:
112
+ assert last_block.hash == -1
113
+ token_ids = seq.block(seq.num_blocks-1)
114
+ prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
115
+ h = self.compute_hash(token_ids, prefix)
116
+ last_block.update(h, token_ids)
117
+ self.hash_to_block_id[h] = last_block.block_id
118
+ else:
119
+ assert last_block.hash == -1
acestep/third_parts/nano-vllm/nanovllm/engine/llm_engine.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import atexit
2
+ import threading
3
+ from dataclasses import fields
4
+ from time import perf_counter
5
+ from tqdm.auto import tqdm
6
+ from transformers import AutoTokenizer
7
+ import torch.multiprocessing as mp
8
+
9
+ from nanovllm.config import Config
10
+ from nanovllm.sampling_params import SamplingParams
11
+ from nanovllm.engine.sequence import Sequence
12
+ from nanovllm.engine.scheduler import Scheduler
13
+ from nanovllm.engine.model_runner import ModelRunner
14
+
15
+
16
+ class LLMEngine:
17
+
18
+ def __init__(self, model, **kwargs):
19
+ config_fields = {field.name for field in fields(Config)}
20
+ config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
21
+ config = Config(model, **config_kwargs)
22
+ self.ps = []
23
+ self.events = []
24
+ # Thread-safety lock for generate().
25
+ # The scheduler, block manager, model runner, and CUDA graph buffers are all
26
+ # shared mutable state that is NOT thread-safe. In concurrent serving scenarios
27
+ # (API server with ThreadPoolExecutor, multiple queue workers, Gradio with
28
+ # concurrent requests), multiple threads can call generate() simultaneously.
29
+ # Without this lock, concurrent access corrupts scheduler state, block tables,
30
+ # and CUDA graph input buffers, leading to intermittent CUDA device-side
31
+ # assertion failures (illegal memory access in KV cache).
32
+ self._generate_lock = threading.Lock()
33
+ ctx = mp.get_context("spawn")
34
+ for i in range(1, config.tensor_parallel_size):
35
+ event = ctx.Event()
36
+ process = ctx.Process(target=ModelRunner, args=(config, i, event))
37
+ process.start()
38
+ self.ps.append(process)
39
+ self.events.append(event)
40
+ self.model_runner = ModelRunner(config, 0, self.events)
41
+ tokenizer = kwargs.get("tokenizer", None)
42
+ if tokenizer is not None:
43
+ self.tokenizer = tokenizer
44
+ else:
45
+ self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
46
+ config.eos = self.tokenizer.eos_token_id
47
+ self.scheduler = Scheduler(config)
48
+ atexit.register(self.exit)
49
+
50
+ def exit(self):
51
+ self.model_runner.call("exit")
52
+ del self.model_runner
53
+ for p in self.ps:
54
+ p.join()
55
+
56
+ def add_request(self, prompt: str | list[int], sampling_params: SamplingParams, unconditional_prompt: str | list[int] | None = None):
57
+ if isinstance(prompt, str):
58
+ prompt = self.tokenizer.encode(prompt)
59
+ # For CFG: if cfg_scale > 1.0, create both conditional and unconditional sequences
60
+ if sampling_params.cfg_scale > 1.0:
61
+ if unconditional_prompt is None:
62
+ # Try to construct unconditional prompt by replacing user input with "NO USER INPUT"
63
+ # This is a fallback - ideally users should provide unconditional_prompt
64
+ if isinstance(prompt, list):
65
+ # For now, just use the same prompt (user should provide unconditional_prompt)
66
+ # TODO: Implement automatic "NO USER INPUT" replacement if possible
67
+ unconditional_prompt = prompt
68
+ else:
69
+ unconditional_prompt = prompt
70
+ if isinstance(unconditional_prompt, str):
71
+ unconditional_prompt = self.tokenizer.encode(unconditional_prompt)
72
+ # Create unconditional sequence first (so we can reference it from conditional)
73
+ uncond_seq = Sequence(unconditional_prompt, sampling_params, is_unconditional=True)
74
+ # Create conditional sequence with reference to unconditional
75
+ cond_seq = Sequence(prompt, sampling_params, is_unconditional=False, conditional_seq=uncond_seq)
76
+ uncond_seq.paired_seq = cond_seq # Link them bidirectionally
77
+ # Add both sequences to scheduler
78
+ self.scheduler.add(cond_seq)
79
+ self.scheduler.add(uncond_seq)
80
+ else:
81
+ seq = Sequence(prompt, sampling_params)
82
+ self.scheduler.add(seq)
83
+
84
+ def step(self):
85
+ seqs, is_prefill = self.scheduler.schedule()
86
+ token_ids = self.model_runner.call("run", seqs, is_prefill)
87
+ self.scheduler.postprocess(seqs, token_ids)
88
+ # Only output conditional sequences (unconditional sequences are just for CFG computation)
89
+ output_seqs = [seq for seq in seqs if seq.is_finished and (seq.cfg_scale <= 1.0 or not seq.is_unconditional)]
90
+ outputs = [(seq.seq_id, seq.completion_token_ids) for seq in output_seqs]
91
+ num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len([s for s in seqs if not s.is_unconditional])
92
+ return outputs, num_tokens
93
+
94
+ def is_finished(self):
95
+ return self.scheduler.is_finished()
96
+
97
+ def reset(self):
98
+ """
99
+ Reset the scheduler state and release all allocated blocks.
100
+ This should be called when an exception occurs during generation to prevent
101
+ KV cache block leaks that can cause 'deque index out of range' errors.
102
+ """
103
+ # Deallocate all running sequences
104
+ while self.scheduler.running:
105
+ seq = self.scheduler.running.popleft()
106
+ if seq.block_table: # Only deallocate if blocks are allocated
107
+ self.scheduler.block_manager.deallocate(seq)
108
+
109
+ # Deallocate all waiting sequences (they might have blocks from preemption)
110
+ while self.scheduler.waiting:
111
+ seq = self.scheduler.waiting.popleft()
112
+ if seq.block_table:
113
+ self.scheduler.block_manager.deallocate(seq)
114
+
115
+ def generate(
116
+ self,
117
+ prompts: list[str] | list[list[int]],
118
+ sampling_params: SamplingParams | list[SamplingParams],
119
+ use_tqdm: bool = True,
120
+ unconditional_prompts: list[str] | list[list[int]] | None = None,
121
+ ) -> list[str]:
122
+ # Serialize access to the engine to prevent concurrent corruption of
123
+ # scheduler state, block manager, CUDA graph buffers, and KV cache.
124
+ # This is the primary defense against the intermittent CUDA device-side
125
+ # assertion error that occurs in concurrent serving scenarios.
126
+ with self._generate_lock:
127
+ return self._generate_impl(prompts, sampling_params, use_tqdm, unconditional_prompts)
128
+
129
+ def _generate_impl(
130
+ self,
131
+ prompts: list[str] | list[list[int]],
132
+ sampling_params: SamplingParams | list[SamplingParams],
133
+ use_tqdm: bool = True,
134
+ unconditional_prompts: list[str] | list[list[int]] | None = None,
135
+ ) -> list[str]:
136
+ # Clean up any residual state from previous interrupted generations
137
+ # This prevents 'deque index out of range' errors from accumulated block leaks
138
+ if not self.is_finished():
139
+ self.reset()
140
+
141
+ if use_tqdm:
142
+ pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True)
143
+ if not isinstance(sampling_params, list):
144
+ sampling_params = [sampling_params] * len(prompts)
145
+ if unconditional_prompts is None:
146
+ unconditional_prompts = [None] * len(prompts)
147
+ for prompt, sp, uncond_prompt in zip(prompts, sampling_params, unconditional_prompts):
148
+ self.add_request(prompt, sp, uncond_prompt)
149
+ outputs = {}
150
+ prefill_throughput = decode_throughput = 0.
151
+ try:
152
+ while not self.is_finished():
153
+ t = perf_counter()
154
+ output, num_tokens = self.step()
155
+ if use_tqdm:
156
+ if num_tokens > 0:
157
+ prefill_throughput = num_tokens / (perf_counter() - t)
158
+ else:
159
+ decode_throughput = -num_tokens / (perf_counter() - t)
160
+ pbar.set_postfix({
161
+ "Prefill": f"{int(prefill_throughput)}tok/s",
162
+ "Decode": f"{int(decode_throughput)}tok/s",
163
+ })
164
+ for seq_id, token_ids in output:
165
+ outputs[seq_id] = token_ids
166
+ if use_tqdm:
167
+ pbar.update(1)
168
+ except Exception:
169
+ # Clean up on exception to prevent block leaks
170
+ self.reset()
171
+ raise
172
+ finally:
173
+ if use_tqdm:
174
+ pbar.close()
175
+
176
+ outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
177
+ outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
178
+ return outputs
acestep/third_parts/nano-vllm/nanovllm/engine/model_runner.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import torch
3
+ import torch.distributed as dist
4
+ from multiprocessing.synchronize import Event
5
+ from multiprocessing.shared_memory import SharedMemory
6
+ import sys
7
+
8
+ from nanovllm.config import Config
9
+ from nanovllm.engine.sequence import Sequence
10
+ from nanovllm.models.qwen3 import Qwen3ForCausalLM
11
+ from nanovllm.layers.sampler import Sampler
12
+ from nanovllm.utils.context import set_context, get_context, reset_context
13
+ from nanovllm.utils.loader import load_model
14
+
15
+ import socket
16
+
17
+
18
+ def find_available_port(start_port: int = 2333, max_attempts: int = 100) -> int:
19
+ """Find an available port starting from start_port.
20
+
21
+ Args:
22
+ start_port: The starting port number to check
23
+ max_attempts: Maximum number of ports to try
24
+
25
+ Returns:
26
+ An available port number
27
+
28
+ Raises:
29
+ RuntimeError: If no available port is found within max_attempts
30
+ """
31
+ for i in range(max_attempts):
32
+ port = start_port + i
33
+ try:
34
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
35
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
36
+ s.bind(('localhost', port))
37
+ return port
38
+ except OSError:
39
+ # Port is in use, try next one
40
+ continue
41
+ raise RuntimeError(f"Could not find an available port starting from {start_port} after {max_attempts} attempts")
42
+
43
+
44
+ class ModelRunner:
45
+
46
+ def __init__(self, config: Config, rank: int, event: Event | list[Event]):
47
+ # Enable capturing scalar outputs to avoid graph breaks from Tensor.item() calls
48
+ torch._dynamo.config.capture_scalar_outputs = True
49
+
50
+ self.config = config
51
+ hf_config = config.hf_config
52
+ self.block_size = config.kvcache_block_size
53
+ self.enforce_eager = config.enforce_eager
54
+ self.world_size = config.tensor_parallel_size
55
+ self.rank = rank
56
+ self.event = event
57
+ dist_port = find_available_port()
58
+ print(f"[debug]dist_port: {dist_port}")
59
+ # Use gloo backend on Windows, nccl on Linux/other platforms
60
+ backend = "gloo" if sys.platform == "win32" else "nccl"
61
+ dist.init_process_group(backend, f"tcp://127.0.0.1:{dist_port}", world_size=self.world_size, rank=rank)
62
+ torch.cuda.set_device(rank)
63
+ default_dtype = torch.get_default_dtype()
64
+ # Use dtype instead of deprecated torch_dtype
65
+ config_dtype = getattr(hf_config, 'dtype', getattr(hf_config, 'torch_dtype', torch.float32))
66
+ torch.set_default_dtype(config_dtype)
67
+ torch.set_default_device("cuda")
68
+ self.model = Qwen3ForCausalLM(hf_config)
69
+ load_model(self.model, config.model)
70
+ self.sampler = Sampler()
71
+
72
+ # Pre-allocate buffers for sampling (optimization: avoid repeated tensor creation)
73
+ # Must be called before warmup_model() since it uses these buffers
74
+ self._allocate_sample_buffers()
75
+
76
+ self.warmup_model()
77
+ self.allocate_kv_cache()
78
+ if not self.enforce_eager:
79
+ self.capture_cudagraph()
80
+
81
+ torch.set_default_device("cpu")
82
+ torch.set_default_dtype(default_dtype)
83
+
84
+ if self.world_size > 1:
85
+ if rank == 0:
86
+ self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
87
+ dist.barrier()
88
+ else:
89
+ dist.barrier()
90
+ self.shm = SharedMemory(name="nanovllm")
91
+ self.loop()
92
+
93
+ def _allocate_sample_buffers(self):
94
+ """Pre-allocate reusable buffers for sampling to avoid repeated tensor creation."""
95
+ max_bs = self.config.max_num_seqs
96
+ max_tokens = self.config.max_num_batched_tokens
97
+ max_num_blocks = (self.config.max_model_len + self.block_size - 1) // self.block_size
98
+
99
+ # Pre-allocate pinned memory buffers on CPU for fast transfer
100
+ # Must explicitly specify device="cpu" since default device may be "cuda"
101
+ self._cpu_temperatures = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
102
+ self._cpu_cfg_scales = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
103
+ self._cpu_top_ks = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
104
+ self._cpu_top_ps = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
105
+ self._cpu_repetition_penalties = torch.zeros(max_bs, dtype=torch.float32, device="cpu", pin_memory=True)
106
+
107
+ # Pre-allocate decode buffers on CPU with pinned memory
108
+ self._cpu_input_ids = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
109
+ self._cpu_positions = torch.zeros(max_bs, dtype=torch.int64, device="cpu", pin_memory=True)
110
+ self._cpu_slot_mapping = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
111
+ self._cpu_context_lens = torch.zeros(max_bs, dtype=torch.int32, device="cpu", pin_memory=True)
112
+
113
+ # Pre-allocate prefill buffers on CPU with pinned memory (optimization to avoid repeated tensor creation)
114
+ self._cpu_prefill_input_ids = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
115
+ self._cpu_prefill_positions = torch.zeros(max_tokens, dtype=torch.int64, device="cpu", pin_memory=True)
116
+ self._cpu_prefill_cu_seqlens = torch.zeros(max_bs + 1, dtype=torch.int32, device="cpu", pin_memory=True)
117
+ self._cpu_prefill_slot_mapping = torch.zeros(max_tokens, dtype=torch.int32, device="cpu", pin_memory=True)
118
+
119
+ # Pre-allocate block tables buffer (shared by both decode and prefill)
120
+ self._cpu_block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32, device="cpu", pin_memory=True)
121
+
122
+ # Pre-allocate buffer for sequence token IDs (used in logits processor and sampler)
123
+ # Max length is max_model_len since sequences can be that long
124
+ self._seq_token_ids_buffer = torch.zeros(max_bs, self.config.max_model_len, dtype=torch.int64, device="cpu", pin_memory=True)
125
+
126
+ def exit(self):
127
+ if self.world_size > 1:
128
+ self.shm.close()
129
+ dist.barrier()
130
+ if self.rank == 0:
131
+ self.shm.unlink()
132
+ if not self.enforce_eager:
133
+ del self.graphs, self.graph_pool
134
+ torch.cuda.synchronize()
135
+ dist.destroy_process_group()
136
+
137
+ def loop(self):
138
+ while True:
139
+ method_name, args = self.read_shm()
140
+ self.call(method_name, *args)
141
+ if method_name == "exit":
142
+ break
143
+
144
+ def read_shm(self):
145
+ assert self.world_size > 1 and self.rank > 0
146
+ self.event.wait()
147
+ n = int.from_bytes(self.shm.buf[0:4], "little")
148
+ method_name, *args = pickle.loads(self.shm.buf[4:n+4])
149
+ self.event.clear()
150
+ return method_name, args
151
+
152
+ def write_shm(self, method_name, *args):
153
+ assert self.world_size > 1 and self.rank == 0
154
+ data = pickle.dumps([method_name, *args])
155
+ n = len(data)
156
+ self.shm.buf[0:4] = n.to_bytes(4, "little")
157
+ self.shm.buf[4:n+4] = data
158
+ for event in self.event:
159
+ event.set()
160
+
161
+ def call(self, method_name, *args):
162
+ if self.world_size > 1 and self.rank == 0:
163
+ self.write_shm(method_name, *args)
164
+ method = getattr(self, method_name, None)
165
+ return method(*args)
166
+
167
+ def warmup_model(self):
168
+ torch.cuda.empty_cache()
169
+ torch.cuda.reset_peak_memory_stats()
170
+ max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
171
+ num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
172
+ seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
173
+ self.run(seqs, True)
174
+ torch.cuda.empty_cache()
175
+
176
+ def allocate_kv_cache(self):
177
+ config = self.config
178
+ hf_config = config.hf_config
179
+ free, total = torch.cuda.mem_get_info()
180
+ current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
181
+ num_kv_heads = hf_config.num_key_value_heads // self.world_size
182
+ head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
183
+ # Use dtype instead of deprecated torch_dtype
184
+ config_dtype = getattr(hf_config, 'dtype', getattr(hf_config, 'torch_dtype', torch.float32))
185
+ block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * config_dtype.itemsize
186
+
187
+ # Calculate available memory for KV cache
188
+ # After warmup_model, empty_cache has been called, so current represents model memory only
189
+ # Use free memory but respect the gpu_memory_utilization limit
190
+ target_total_usage = total * config.gpu_memory_utilization
191
+ available_for_kv_cache = min(free * 0.9, target_total_usage - current)
192
+
193
+ # Ensure we have positive memory available
194
+ if available_for_kv_cache <= 0:
195
+ available_for_kv_cache = free * 0.5 # Fallback to 50% of free memory
196
+
197
+ config.num_kvcache_blocks = max(1, int(available_for_kv_cache) // block_bytes)
198
+ if config.num_kvcache_blocks <= 0:
199
+ raise RuntimeError(
200
+ f"Insufficient GPU memory for KV cache. "
201
+ f"Free: {free / 1024**3:.2f} GB, Current: {current / 1024**3:.2f} GB, "
202
+ f"Available for KV: {available_for_kv_cache / 1024**3:.2f} GB, "
203
+ f"Block size: {block_bytes / 1024**2:.2f} MB"
204
+ )
205
+ self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
206
+ layer_id = 0
207
+ for module in self.model.modules():
208
+ if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
209
+ module.k_cache = self.kv_cache[0, layer_id]
210
+ module.v_cache = self.kv_cache[1, layer_id]
211
+ layer_id += 1
212
+
213
+ def prepare_block_tables(self, seqs: list[Sequence]):
214
+ max_len = max(len(seq.block_table) for seq in seqs)
215
+ block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
216
+ block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
217
+ return block_tables
218
+
219
+ def prepare_prefill(self, seqs: list[Sequence]):
220
+ input_ids = []
221
+ positions = []
222
+ cu_seqlens_q = [0]
223
+ cu_seqlens_k = [0]
224
+ max_seqlen_q = 0
225
+ max_seqlen_k = 0
226
+ slot_mapping = []
227
+ block_tables = None
228
+ for seq in seqs:
229
+ seqlen = len(seq)
230
+ input_ids.extend(seq[seq.num_cached_tokens:])
231
+ positions.extend(list(range(seq.num_cached_tokens, seqlen)))
232
+ seqlen_q = seqlen - seq.num_cached_tokens
233
+ seqlen_k = seqlen
234
+ cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
235
+ cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
236
+ max_seqlen_q = max(seqlen_q, max_seqlen_q)
237
+ max_seqlen_k = max(seqlen_k, max_seqlen_k)
238
+ if not seq.block_table: # warmup
239
+ continue
240
+ for i in range(seq.num_cached_blocks, seq.num_blocks):
241
+ start = seq.block_table[i] * self.block_size
242
+ if i != seq.num_blocks - 1:
243
+ end = start + self.block_size
244
+ else:
245
+ end = start + seq.last_block_num_tokens
246
+ slot_mapping.extend(list(range(start, end)))
247
+ if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
248
+ block_tables = self.prepare_block_tables(seqs)
249
+ input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
250
+ positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
251
+ cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
252
+ cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
253
+ slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
254
+ set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
255
+ return input_ids, positions
256
+
257
+ def prepare_decode(self, seqs: list[Sequence]):
258
+ """Optimized decode preparation using pre-allocated buffers."""
259
+ bs = len(seqs)
260
+
261
+ # Use pre-allocated CPU buffers
262
+ for i, seq in enumerate(seqs):
263
+ self._cpu_input_ids[i] = seq.last_token
264
+ self._cpu_positions[i] = len(seq) - 1
265
+ self._cpu_context_lens[i] = len(seq)
266
+ self._cpu_slot_mapping[i] = seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1
267
+
268
+ # Transfer to GPU using sliced views
269
+ input_ids = self._cpu_input_ids[:bs].cuda(non_blocking=True)
270
+ positions = self._cpu_positions[:bs].cuda(non_blocking=True)
271
+ slot_mapping = self._cpu_slot_mapping[:bs].cuda(non_blocking=True)
272
+ context_lens = self._cpu_context_lens[:bs].cuda(non_blocking=True)
273
+ block_tables = self.prepare_block_tables(seqs)
274
+ set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
275
+ return input_ids, positions
276
+
277
+ def prepare_sample(self, seqs: list[Sequence], is_cfg_batch: bool = False):
278
+ """Optimized sample preparation using pre-allocated buffers."""
279
+ if is_cfg_batch:
280
+ num_seqs = len(seqs) // 2
281
+ target_seqs = seqs[:num_seqs]
282
+ else:
283
+ num_seqs = len(seqs)
284
+ target_seqs = seqs
285
+
286
+ # Fill pre-allocated CPU buffers
287
+ top_ks_is_zero = True
288
+ top_ps_is_one = True
289
+ repetition_penalties_is_one = True
290
+ for i, seq in enumerate(target_seqs):
291
+ self._cpu_temperatures[i] = seq.temperature
292
+ self._cpu_cfg_scales[i] = seq.cfg_scale
293
+ self._cpu_top_ks[i] = seq.top_k if seq.top_k is not None else 0
294
+ if seq.top_k is not None and seq.top_k > 0:
295
+ top_ks_is_zero = False
296
+ self._cpu_top_ps[i] = seq.top_p if seq.top_p is not None else 1.0
297
+ if seq.top_p is not None and seq.top_p == 1.0:
298
+ top_ps_is_one = False
299
+ self._cpu_repetition_penalties[i] = seq.repetition_penalty if seq.repetition_penalty is not None else 1.0
300
+ if seq.repetition_penalty is not None and seq.repetition_penalty == 1.0:
301
+ repetition_penalties_is_one = False
302
+
303
+ # Transfer to GPU using sliced views (single batched transfer)
304
+ temperatures = self._cpu_temperatures[:num_seqs].cuda(non_blocking=True)
305
+ cfg_scales = self._cpu_cfg_scales[:num_seqs].cuda(non_blocking=True)
306
+ top_ks = self._cpu_top_ks[:num_seqs].cuda(non_blocking=True) if not top_ks_is_zero else None
307
+ top_ps = self._cpu_top_ps[:num_seqs].cuda(non_blocking=True) if not top_ps_is_one else None
308
+ repetition_penalties = self._cpu_repetition_penalties[:num_seqs].cuda(non_blocking=True) if not repetition_penalties_is_one else None
309
+
310
+ return temperatures, cfg_scales, top_ks, top_ps, repetition_penalties
311
+
312
+ @torch.inference_mode()
313
+ def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
314
+ if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
315
+ return self.model.compute_logits(self.model(input_ids, positions))
316
+ else:
317
+ bs = input_ids.size(0)
318
+ context = get_context()
319
+
320
+ # Check if block_tables size exceeds pre-allocated buffer size
321
+ # This can happen when conditional and unconditional sequences have different lengths
322
+ # in CFG mode, causing block_tables to have more columns than expected
323
+ max_num_blocks = self.graph_vars["block_tables"].size(1)
324
+ if context.block_tables.size(1) > max_num_blocks:
325
+ # Fall back to eager mode when block_tables is too large for CUDA graph
326
+ return self.model.compute_logits(self.model(input_ids, positions))
327
+
328
+ # Fix: Also check if block_tables row count matches batch size
329
+ # Dimension mismatch can cause CUDA illegal memory access during graph replay
330
+ if context.block_tables.size(0) != bs:
331
+ # Fall back to eager mode when block_tables row count doesn't match batch size
332
+ return self.model.compute_logits(self.model(input_ids, positions))
333
+
334
+ # Fix: Verify slot_mapping and context_lens dimensions match batch size
335
+ if context.slot_mapping.size(0) != bs or context.context_lens.size(0) != bs:
336
+ # Fall back to eager mode when dimensions don't match
337
+ return self.model.compute_logits(self.model(input_ids, positions))
338
+
339
+ graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
340
+ graph_vars = self.graph_vars
341
+ graph_vars["input_ids"][:bs] = input_ids
342
+ graph_vars["positions"][:bs] = positions
343
+ graph_vars["slot_mapping"].fill_(-1)
344
+ graph_vars["slot_mapping"][:bs] = context.slot_mapping
345
+ graph_vars["context_lens"].zero_()
346
+ graph_vars["context_lens"][:bs] = context.context_lens
347
+ # Clear block_tables first to ensure no stale data from previous runs
348
+ graph_vars["block_tables"][:bs].fill_(-1)
349
+ graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
350
+ graph.replay()
351
+ return self.model.compute_logits(graph_vars["outputs"][:bs])
352
+
353
+ def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
354
+ """Run model forward and sampling. For CFG sequences, batch is structured as:
355
+ [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
356
+ where uncond_seqi is the paired unconditional sequence of cond_seqi."""
357
+ # Check if this is a CFG batch (contains paired conditional and unconditional sequences)
358
+ is_cfg_batch = seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None
359
+ if is_cfg_batch:
360
+ # CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
361
+ num_cond = len(seqs) // 2
362
+ cond_seqs = seqs[:num_cond]
363
+ # uncond_seqs = seqs[num_cond:]
364
+
365
+ # Prepare inputs for both conditional and unconditional (they're already in the batch)
366
+ input_ids, positions = (self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs))
367
+ sample_params = self.prepare_sample(seqs, is_cfg_batch=True) if self.rank == 0 else None
368
+ if sample_params is not None:
369
+ temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
370
+ else:
371
+ temperatures = cfg_scales = top_ks = top_ps = repetition_penalties = None
372
+
373
+ # Run model forward (processes entire batch: cond + uncond)
374
+ logits_all = self.run_model(input_ids, positions, is_prefill)
375
+ reset_context()
376
+
377
+ if self.rank == 0:
378
+ # Split logits: first half is conditional, second half is unconditional
379
+ logits_cond = logits_all[:num_cond]
380
+ logits_uncond = logits_all[num_cond:]
381
+
382
+ # Apply repetition penalty to conditional logits (before CFG)
383
+ if repetition_penalties is not None:
384
+ for i, seq in enumerate(cond_seqs):
385
+ penalty = repetition_penalties[i].item()
386
+ if penalty != 1.0:
387
+ # Only penalize completion tokens (not prompt tokens)
388
+ completion_tokens = torch.tensor(seq.completion_token_ids, device=logits_cond.device)
389
+ if len(completion_tokens) > 0:
390
+ # Create token mask: mark tokens that appeared in completion
391
+ token_mask = torch.zeros(logits_cond.shape[1], dtype=torch.bool, device=logits_cond.device)
392
+ token_mask[completion_tokens] = True
393
+
394
+ # Apply standard repetition penalty formula (matching transformers implementation):
395
+ # For tokens in completion: if score < 0 then score * penalty, else score / penalty
396
+ penalty_scores = torch.where(
397
+ logits_cond[i] < 0,
398
+ logits_cond[i] * penalty,
399
+ logits_cond[i] / penalty
400
+ )
401
+ # Only apply penalty to tokens that appeared in completion
402
+ logits_cond[i] = torch.where(token_mask, penalty_scores, logits_cond[i])
403
+
404
+ # Apply CFG formula: logits_cfg = logits_uncond + cfg_scale * (logits_cond - logits_uncond)
405
+ cfg_scales_tensor = cfg_scales.unsqueeze(1) # [num_cond, 1]
406
+ logits_cfg = logits_uncond + cfg_scales_tensor * (logits_cond - logits_uncond)
407
+
408
+ # Apply logits processor for constrained decoding (if any sequence has one)
409
+ for i, seq in enumerate(cond_seqs):
410
+ if seq.logits_processor is not None:
411
+ # Create input_ids tensor for this sequence
412
+ seq_input_ids = torch.tensor([seq.token_ids], device=logits_cfg.device)
413
+ # Apply processor to this sequence's logits
414
+ logits_cfg[i:i+1] = seq.logits_processor(seq_input_ids, logits_cfg[i:i+1])
415
+
416
+ # Prepare input_ids for sampler (for repetition penalty, though we already applied it)
417
+ # cond_input_ids = torch.tensor([seq.token_ids for seq in cond_seqs], device=logits_cfg.device)
418
+
419
+ # Sample from CFG logits
420
+ token_ids_cfg = self.sampler(
421
+ logits_cfg,
422
+ temperatures,
423
+ top_ks=top_ks if top_ks is not None else None,
424
+ top_ps=top_ps if top_ps is not None else None,
425
+ repetition_penalties=None, # Already applied above
426
+ # input_ids=cond_input_ids,
427
+ ).tolist()
428
+
429
+ # Update logits processor state after sampling
430
+ # NOTE: Only update for the first sequence since all sequences share the same processor
431
+ # Updating multiple times would cause duplicate state updates (e.g., codes_count += N instead of += 1)
432
+ if cond_seqs and cond_seqs[0].logits_processor_update_state is not None:
433
+ cond_seqs[0].logits_processor_update_state(token_ids_cfg[0])
434
+
435
+ # Return token_ids (will be applied to both conditional and unconditional sequences)
436
+ return token_ids_cfg
437
+ else:
438
+ return None
439
+ else:
440
+ # Normal batch (non-CFG)
441
+ input_ids, positions = (self.prepare_prefill(seqs) if is_prefill
442
+ else self.prepare_decode(seqs))
443
+ sample_params = self.prepare_sample(seqs, is_cfg_batch=False) if self.rank == 0 else None
444
+ if sample_params is not None:
445
+ temperatures, cfg_scales, top_ks, top_ps, repetition_penalties = sample_params
446
+ else:
447
+ temperatures = cfg_scales = top_ks = top_ps = repetition_penalties = None
448
+ logits = self.run_model(input_ids, positions, is_prefill)
449
+ reset_context()
450
+
451
+ if self.rank == 0:
452
+ # Apply repetition penalty to logits
453
+ if repetition_penalties is not None:
454
+ for i, seq in enumerate(seqs):
455
+ penalty = repetition_penalties[i].item()
456
+ if penalty != 1.0:
457
+ # Only penalize completion tokens (not prompt tokens)
458
+ completion_tokens = torch.tensor(seq.completion_token_ids, device=logits.device)
459
+ if len(completion_tokens) > 0:
460
+ # Create token mask: mark tokens that appeared in completion
461
+ token_mask = torch.zeros(logits.shape[1], dtype=torch.bool, device=logits.device)
462
+ token_mask[completion_tokens] = True
463
+
464
+ # Apply standard repetition penalty formula (matching transformers implementation):
465
+ # For tokens in completion: if score < 0 then score * penalty, else score / penalty
466
+ penalty_scores = torch.where(
467
+ logits[i] < 0,
468
+ logits[i] * penalty,
469
+ logits[i] / penalty
470
+ )
471
+ # Only apply penalty to tokens that appeared in completion
472
+ logits[i] = torch.where(token_mask, penalty_scores, logits[i])
473
+
474
+ # Apply logits processor for constrained decoding (if any sequence has one)
475
+ # Clone logits to avoid in-place update issues in inference mode
476
+ logits = logits.clone()
477
+ for i, seq in enumerate(seqs):
478
+ if seq.logits_processor is not None:
479
+ # Create input_ids tensor for this sequence
480
+ seq_input_ids = torch.tensor([seq.token_ids], device=logits.device)
481
+ # Apply processor to this sequence's logits (clone to avoid inference mode issues)
482
+ processed = seq.logits_processor(seq_input_ids, logits[i:i+1].clone())
483
+ logits[i] = processed[0]
484
+
485
+ # Prepare input_ids for sampler
486
+ # seq_input_ids = torch.tensor([seq.token_ids for seq in seqs], device=logits.device)
487
+
488
+ token_ids = self.sampler(
489
+ logits,
490
+ temperatures,
491
+ top_ks=top_ks if top_ks is not None else None,
492
+ top_ps=top_ps if top_ps is not None else None,
493
+ repetition_penalties=None, # Already applied above
494
+ # input_ids=seq_input_ids,
495
+ ).tolist()
496
+
497
+ # Update logits processor state after sampling
498
+ # NOTE: Only update for the first sequence since all sequences may share the same processor
499
+ # (when using a single SamplingParams for batch generation)
500
+ # Updating multiple times would cause duplicate state updates (e.g., codes_count += N instead of += 1)
501
+ if seqs and seqs[0].logits_processor_update_state is not None:
502
+ seqs[0].logits_processor_update_state(token_ids[0])
503
+
504
+ return token_ids
505
+ else:
506
+ return None
507
+
508
+ @torch.inference_mode()
509
+ def capture_cudagraph(self):
510
+ config = self.config
511
+ hf_config = config.hf_config
512
+ max_bs = min(self.config.max_num_seqs, 512)
513
+ max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
514
+ input_ids = torch.zeros(max_bs, dtype=torch.int64)
515
+ positions = torch.zeros(max_bs, dtype=torch.int64)
516
+ slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
517
+ context_lens = torch.zeros(max_bs, dtype=torch.int32)
518
+ block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
519
+ outputs = torch.zeros(max_bs, hf_config.hidden_size)
520
+ self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
521
+ self.graphs = {}
522
+ self.graph_pool = None
523
+
524
+ for bs in reversed(self.graph_bs):
525
+ graph = torch.cuda.CUDAGraph()
526
+ set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
527
+ outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
528
+ with torch.cuda.graph(graph, self.graph_pool):
529
+ outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
530
+ if self.graph_pool is None:
531
+ self.graph_pool = graph.pool()
532
+ self.graphs[bs] = graph
533
+ torch.cuda.synchronize()
534
+ reset_context()
535
+
536
+ self.graph_vars = dict(
537
+ input_ids=input_ids,
538
+ positions=positions,
539
+ slot_mapping=slot_mapping,
540
+ context_lens=context_lens,
541
+ block_tables=block_tables,
542
+ outputs=outputs,
543
+ )
acestep/third_parts/nano-vllm/nanovllm/engine/scheduler.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import deque
2
+
3
+ from nanovllm.config import Config
4
+ from nanovllm.engine.sequence import Sequence, SequenceStatus
5
+ from nanovllm.engine.block_manager import BlockManager
6
+
7
+
8
+ class Scheduler:
9
+
10
+ def __init__(self, config: Config):
11
+ self.max_num_seqs = config.max_num_seqs
12
+ self.max_num_batched_tokens = config.max_num_batched_tokens
13
+ self.eos = config.eos
14
+ self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
15
+ self.waiting: deque[Sequence] = deque()
16
+ self.running: deque[Sequence] = deque()
17
+
18
+ def is_finished(self):
19
+ return not self.waiting and not self.running
20
+
21
+ def add(self, seq: Sequence):
22
+ self.waiting.append(seq)
23
+
24
+ def schedule(self) -> tuple[list[Sequence], bool]:
25
+ # prefill
26
+ scheduled_seqs = []
27
+ num_seqs = 0
28
+ num_batched_tokens = 0
29
+ processed_seqs = set() # Track processed sequences to handle CFG pairs
30
+
31
+ while self.waiting and num_seqs < self.max_num_seqs:
32
+ seq = self.waiting[0]
33
+
34
+ # For CFG sequences, ensure conditional and unconditional are scheduled together
35
+ if seq.cfg_scale > 1.0 and seq.paired_seq is not None and not seq.is_unconditional:
36
+ # This is a conditional sequence, need to schedule its paired unconditional sequence too
37
+ paired_seq = seq.paired_seq
38
+ if paired_seq.status != SequenceStatus.WAITING:
39
+ # Paired sequence not in waiting, skip this conditional sequence for now
40
+ break
41
+
42
+ # Calculate tokens for both sequences
43
+ total_tokens = (len(seq) - seq.num_cached_tokens) + (len(paired_seq) - paired_seq.num_cached_tokens)
44
+
45
+ # FIX: Check if we have enough blocks for BOTH sequences combined
46
+ # The old check was wrong: it checked each sequence independently,
47
+ # but didn't account for the total blocks needed by both
48
+ total_blocks_needed = seq.num_blocks + paired_seq.num_blocks
49
+ can_allocate_both = len(self.block_manager.free_block_ids) >= total_blocks_needed
50
+
51
+ if num_batched_tokens + total_tokens > self.max_num_batched_tokens or not can_allocate_both:
52
+ break
53
+
54
+ # Schedule both sequences: conditional first, then unconditional
55
+ for s in [seq, paired_seq]:
56
+ num_seqs += 1
57
+ self.block_manager.allocate(s)
58
+ num_batched_tokens += len(s) - s.num_cached_tokens
59
+ s.status = SequenceStatus.RUNNING
60
+ self.waiting.remove(s)
61
+ self.running.append(s)
62
+ scheduled_seqs.append(s)
63
+ processed_seqs.add(s.seq_id)
64
+ else:
65
+ # Normal sequence or unconditional sequence (already processed with its conditional)
66
+ if seq.seq_id in processed_seqs:
67
+ # Skip if already processed as part of a CFG pair
68
+ self.waiting.popleft()
69
+ continue
70
+
71
+ if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
72
+ break
73
+ num_seqs += 1
74
+ self.block_manager.allocate(seq)
75
+ num_batched_tokens += len(seq) - seq.num_cached_tokens
76
+ seq.status = SequenceStatus.RUNNING
77
+ self.waiting.popleft()
78
+ self.running.append(seq)
79
+ scheduled_seqs.append(seq)
80
+
81
+ if scheduled_seqs:
82
+ # For CFG batches, ensure conditional sequences come before their unconditional pairs
83
+ cfg_cond_seqs = [s for s in scheduled_seqs if s.cfg_scale > 1.0 and not s.is_unconditional]
84
+ cfg_uncond_seqs = [s for s in scheduled_seqs if s.is_unconditional]
85
+ non_cfg_seqs = [s for s in scheduled_seqs if s.cfg_scale <= 1.0]
86
+
87
+ # Reorder: non-CFG, then CFG conditional, then CFG unconditional
88
+ scheduled_seqs = non_cfg_seqs + cfg_cond_seqs + cfg_uncond_seqs
89
+ return scheduled_seqs, True
90
+
91
+ # decode
92
+ processed_seqs = set()
93
+ temp_running = list(self.running) # Work with a copy
94
+
95
+ while temp_running and num_seqs < self.max_num_seqs:
96
+ seq = temp_running.pop(0)
97
+
98
+ # For CFG sequences, ensure conditional and unconditional are scheduled together
99
+ if seq.cfg_scale > 1.0 and seq.paired_seq is not None and not seq.is_unconditional:
100
+ paired_seq = seq.paired_seq
101
+ if paired_seq not in temp_running:
102
+ # Paired sequence not available, skip for now
103
+ continue
104
+
105
+ # Remove paired_seq from temp_running
106
+ temp_running.remove(paired_seq)
107
+
108
+ # FIX: Check if we have enough blocks for BOTH sequences to append
109
+ # Each sequence needs 1 block when at block boundary (len % block_size == 1)
110
+ block_size = self.block_manager.block_size
111
+ blocks_needed_seq = 1 if len(seq) % block_size == 1 else 0
112
+ blocks_needed_paired = 1 if len(paired_seq) % block_size == 1 else 0
113
+ total_blocks_needed = blocks_needed_seq + blocks_needed_paired
114
+ can_append_both = len(self.block_manager.free_block_ids) >= total_blocks_needed
115
+
116
+ if not can_append_both:
117
+ # Try preempting other sequences
118
+ preempted = False
119
+ while not can_append_both and temp_running:
120
+ other_seq = temp_running.pop(0)
121
+ if other_seq != seq and other_seq != paired_seq:
122
+ self.preempt(other_seq)
123
+ # Recalculate with the same correct logic
124
+ can_append_both = len(self.block_manager.free_block_ids) >= total_blocks_needed
125
+ preempted = True
126
+ else:
127
+ temp_running.append(other_seq)
128
+ break
129
+
130
+ if not can_append_both:
131
+ # Can't schedule this pair right now
132
+ temp_running.append(seq)
133
+ temp_running.append(paired_seq)
134
+ continue
135
+
136
+ # Schedule both sequences
137
+ for s in [seq, paired_seq]:
138
+ num_seqs += 1
139
+ self.block_manager.may_append(s)
140
+ scheduled_seqs.append(s)
141
+ processed_seqs.add(s.seq_id)
142
+ # Remove from actual running list if scheduled
143
+ if s in self.running:
144
+ self.running.remove(s)
145
+ else:
146
+ # Normal sequence or unconditional (already processed)
147
+ if seq.seq_id in processed_seqs:
148
+ continue
149
+
150
+ while not self.block_manager.can_append(seq):
151
+ if temp_running:
152
+ other_seq = temp_running.pop(0)
153
+ if other_seq != seq:
154
+ self.preempt(other_seq)
155
+ else:
156
+ temp_running.append(other_seq)
157
+ break
158
+ else:
159
+ self.preempt(seq)
160
+ if seq in self.running:
161
+ self.running.remove(seq)
162
+ break
163
+ else:
164
+ num_seqs += 1
165
+ self.block_manager.may_append(seq)
166
+ scheduled_seqs.append(seq)
167
+ if seq in self.running:
168
+ self.running.remove(seq)
169
+
170
+ assert scheduled_seqs
171
+
172
+ # For CFG batches in decode, ensure conditional sequences come before unconditional
173
+ cfg_cond_seqs = [s for s in scheduled_seqs if s.cfg_scale > 1.0 and not s.is_unconditional]
174
+ cfg_uncond_seqs = [s for s in scheduled_seqs if s.is_unconditional]
175
+ non_cfg_seqs = [s for s in scheduled_seqs if s.cfg_scale <= 1.0]
176
+ scheduled_seqs = non_cfg_seqs + cfg_cond_seqs + cfg_uncond_seqs
177
+
178
+ self.running.extendleft(reversed(scheduled_seqs))
179
+ return scheduled_seqs, False
180
+
181
+ def preempt(self, seq: Sequence):
182
+ seq.status = SequenceStatus.WAITING
183
+ self.block_manager.deallocate(seq)
184
+ self.waiting.appendleft(seq)
185
+
186
+ def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
187
+ # Check if this is a CFG batch
188
+ is_cfg_batch = False
189
+ if len(seqs) > 0 and seqs[0].cfg_scale > 1.0 and seqs[0].paired_seq is not None:
190
+ num_cond = len(seqs) // 2
191
+ is_cfg_batch = (num_cond > 0 and
192
+ not seqs[0].is_unconditional and
193
+ seqs[num_cond].is_unconditional)
194
+
195
+ if is_cfg_batch:
196
+ # CFG batch: seqs = [cond_seq1, cond_seq2, ..., uncond_seq1, uncond_seq2, ...]
197
+ # token_ids correspond to conditional sequences only (sampled from CFG logits)
198
+ num_cond = len(seqs) // 2
199
+ cond_seqs = seqs[:num_cond]
200
+ uncond_seqs = seqs[num_cond:]
201
+
202
+ # Apply the same sampled token to both conditional and unconditional sequences
203
+ for i, (cond_seq, uncond_seq, token_id) in enumerate(zip(cond_seqs, uncond_seqs, token_ids)):
204
+ cond_seq.append_token(token_id)
205
+ uncond_seq.append_token(token_id) # Same token for unconditional
206
+
207
+ # Check if either sequence is finished
208
+ cond_finished = ((not cond_seq.ignore_eos and token_id == self.eos) or
209
+ cond_seq.num_completion_tokens == cond_seq.max_tokens)
210
+ uncond_finished = ((not uncond_seq.ignore_eos and token_id == self.eos) or
211
+ uncond_seq.num_completion_tokens == uncond_seq.max_tokens)
212
+
213
+ if cond_finished or uncond_finished:
214
+ # Mark both as finished
215
+ cond_seq.status = SequenceStatus.FINISHED
216
+ uncond_seq.status = SequenceStatus.FINISHED
217
+ self.block_manager.deallocate(cond_seq)
218
+ self.block_manager.deallocate(uncond_seq)
219
+ if cond_seq in self.running:
220
+ self.running.remove(cond_seq)
221
+ if uncond_seq in self.running:
222
+ self.running.remove(uncond_seq)
223
+ else:
224
+ # Normal batch
225
+ for seq, token_id in zip(seqs, token_ids):
226
+ seq.append_token(token_id)
227
+ if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
228
+ seq.status = SequenceStatus.FINISHED
229
+ self.block_manager.deallocate(seq)
230
+ self.running.remove(seq)
acestep/third_parts/nano-vllm/nanovllm/engine/sequence.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import copy
2
+ from enum import Enum, auto
3
+ from itertools import count
4
+ from typing import Optional, Callable, Any
5
+
6
+ from nanovllm.sampling_params import SamplingParams
7
+
8
+
9
+ class SequenceStatus(Enum):
10
+ WAITING = auto()
11
+ RUNNING = auto()
12
+ FINISHED = auto()
13
+
14
+
15
+ class Sequence:
16
+ block_size = 256
17
+ counter = count()
18
+
19
+ def __init__(self, token_ids: list[int], sampling_params = SamplingParams(), is_unconditional: bool = False, conditional_seq = None):
20
+ self.seq_id = next(Sequence.counter)
21
+ self.status = SequenceStatus.WAITING
22
+ self.token_ids = copy(token_ids)
23
+ self.last_token = token_ids[-1]
24
+ self.num_tokens = len(self.token_ids)
25
+ self.num_prompt_tokens = len(token_ids)
26
+ self.num_cached_tokens = 0
27
+ self.block_table = []
28
+ self.temperature = sampling_params.temperature
29
+ self.max_tokens = sampling_params.max_tokens
30
+ self.ignore_eos = sampling_params.ignore_eos
31
+ self.cfg_scale = sampling_params.cfg_scale
32
+ self.top_k = sampling_params.top_k
33
+ self.top_p = sampling_params.top_p
34
+ self.repetition_penalty = sampling_params.repetition_penalty
35
+ # For CFG: mark if this is an unconditional sequence
36
+ self.is_unconditional = is_unconditional
37
+ # For CFG: reference to the corresponding conditional sequence (if this is unconditional)
38
+ # For conditional sequences, this points to the unconditional sequence
39
+ self.paired_seq = conditional_seq # For conditional seq, points to uncond; for uncond seq, points to cond
40
+ # For constrained decoding: logits processor and state update callback
41
+ self.logits_processor: Optional[Any] = sampling_params.logits_processor
42
+ self.logits_processor_update_state: Optional[Callable[[int], None]] = sampling_params.logits_processor_update_state
43
+
44
+ def __len__(self):
45
+ return self.num_tokens
46
+
47
+ def __getitem__(self, key):
48
+ return self.token_ids[key]
49
+
50
+ @property
51
+ def is_finished(self):
52
+ return self.status == SequenceStatus.FINISHED
53
+
54
+ @property
55
+ def num_completion_tokens(self):
56
+ return self.num_tokens - self.num_prompt_tokens
57
+
58
+ @property
59
+ def prompt_token_ids(self):
60
+ return self.token_ids[:self.num_prompt_tokens]
61
+
62
+ @property
63
+ def completion_token_ids(self):
64
+ return self.token_ids[self.num_prompt_tokens:]
65
+
66
+ @property
67
+ def num_cached_blocks(self):
68
+ return self.num_cached_tokens // self.block_size
69
+
70
+ @property
71
+ def num_blocks(self):
72
+ return (self.num_tokens + self.block_size - 1) // self.block_size
73
+
74
+ @property
75
+ def last_block_num_tokens(self):
76
+ return self.num_tokens - (self.num_blocks - 1) * self.block_size
77
+
78
+ def block(self, i):
79
+ assert 0 <= i < self.num_blocks
80
+ return self.token_ids[i*self.block_size: (i+1)*self.block_size]
81
+
82
+ def append_token(self, token_id: int):
83
+ self.token_ids.append(token_id)
84
+ self.last_token = token_id
85
+ self.num_tokens += 1
86
+
87
+ def __getstate__(self):
88
+ return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
89
+ self.token_ids if self.num_completion_tokens == 0 else self.last_token)
90
+
91
+ def __setstate__(self, state):
92
+ self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
93
+ if self.num_completion_tokens == 0:
94
+ self.token_ids = state[-1]
95
+ else:
96
+ self.last_token = state[-1]
acestep/third_parts/nano-vllm/nanovllm/layers/activation.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class SiluAndMul(nn.Module):
7
+
8
+ def __init__(self):
9
+ super().__init__()
10
+
11
+ @torch.compile
12
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
13
+ x, y = x.chunk(2, -1)
14
+ return F.silu(x) * y
acestep/third_parts/nano-vllm/nanovllm/layers/attention.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import triton
4
+ import triton.language as tl
5
+
6
+ from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
7
+ from nanovllm.utils.context import get_context
8
+
9
+
10
+ @triton.jit
11
+ def store_kvcache_kernel(
12
+ key_ptr,
13
+ key_stride,
14
+ value_ptr,
15
+ value_stride,
16
+ k_cache_ptr,
17
+ v_cache_ptr,
18
+ slot_mapping_ptr,
19
+ D: tl.constexpr,
20
+ ):
21
+ idx = tl.program_id(0)
22
+ slot = tl.load(slot_mapping_ptr + idx)
23
+ if slot == -1: return
24
+ key_offsets = idx * key_stride + tl.arange(0, D)
25
+ value_offsets = idx * value_stride + tl.arange(0, D)
26
+ key = tl.load(key_ptr + key_offsets)
27
+ value = tl.load(value_ptr + value_offsets)
28
+ cache_offsets = slot * D + tl.arange(0, D)
29
+ tl.store(k_cache_ptr + cache_offsets, key)
30
+ tl.store(v_cache_ptr + cache_offsets, value)
31
+
32
+
33
+ def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
34
+ N, num_heads, head_dim = key.shape
35
+ D = num_heads * head_dim
36
+ assert key.stride(-1) == 1 and value.stride(-1) == 1
37
+ assert key.stride(1) == head_dim and value.stride(1) == head_dim
38
+ assert k_cache.stride(1) == D and v_cache.stride(1) == D
39
+ assert slot_mapping.numel() == N
40
+ store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
41
+
42
+
43
+ class Attention(nn.Module):
44
+
45
+ def __init__(
46
+ self,
47
+ num_heads,
48
+ head_dim,
49
+ scale,
50
+ num_kv_heads,
51
+ ):
52
+ super().__init__()
53
+ self.num_heads = num_heads
54
+ self.head_dim = head_dim
55
+ self.scale = scale
56
+ self.num_kv_heads = num_kv_heads
57
+ self.k_cache = self.v_cache = torch.tensor([])
58
+
59
+ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
60
+ context = get_context()
61
+ k_cache, v_cache = self.k_cache, self.v_cache
62
+ if k_cache.numel() and v_cache.numel():
63
+ store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
64
+ if context.is_prefill:
65
+ if context.block_tables is not None: # prefix cache
66
+ k, v = k_cache, v_cache
67
+ o = flash_attn_varlen_func(q, k, v,
68
+ max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
69
+ max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
70
+ softmax_scale=self.scale, causal=True, block_table=context.block_tables)
71
+ else: # decode
72
+ o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
73
+ cache_seqlens=context.context_lens, block_table=context.block_tables,
74
+ softmax_scale=self.scale, causal=True)
75
+ return o
acestep/third_parts/nano-vllm/nanovllm/layers/embed_head.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ import torch.distributed as dist
5
+
6
+ from nanovllm.utils.context import get_context
7
+
8
+
9
+ class VocabParallelEmbedding(nn.Module):
10
+
11
+ def __init__(
12
+ self,
13
+ num_embeddings: int,
14
+ embedding_dim: int,
15
+ ):
16
+ super().__init__()
17
+ self.tp_rank = dist.get_rank()
18
+ self.tp_size = dist.get_world_size()
19
+ assert num_embeddings % self.tp_size == 0
20
+ self.num_embeddings = num_embeddings
21
+ self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
22
+ self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
23
+ self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
24
+ self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
25
+ self.weight.weight_loader = self.weight_loader
26
+
27
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
28
+ param_data = param.data
29
+ shard_size = param_data.size(0)
30
+ start_idx = self.tp_rank * shard_size
31
+ loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
32
+ param_data.copy_(loaded_weight)
33
+
34
+ def forward(self, x: torch.Tensor):
35
+ if self.tp_size > 1:
36
+ mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
37
+ x = mask * (x - self.vocab_start_idx)
38
+ y = F.embedding(x, self.weight)
39
+ if self.tp_size > 1:
40
+ y = mask.unsqueeze(1) * y
41
+ dist.all_reduce(y)
42
+ return y
43
+
44
+
45
+ class ParallelLMHead(VocabParallelEmbedding):
46
+
47
+ def __init__(
48
+ self,
49
+ num_embeddings: int,
50
+ embedding_dim: int,
51
+ bias: bool = False,
52
+ ):
53
+ assert not bias
54
+ super().__init__(num_embeddings, embedding_dim)
55
+
56
+ def forward(self, x: torch.Tensor):
57
+ context = get_context()
58
+ if context.is_prefill:
59
+ last_indices = context.cu_seqlens_q[1:] - 1
60
+ x = x[last_indices].contiguous()
61
+ logits = F.linear(x, self.weight)
62
+ if self.tp_size > 1:
63
+ all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
64
+ dist.gather(logits, all_logits, 0)
65
+ logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
66
+ return logits
acestep/third_parts/nano-vllm/nanovllm/layers/layernorm.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class RMSNorm(nn.Module):
6
+
7
+ def __init__(
8
+ self,
9
+ hidden_size: int,
10
+ eps: float = 1e-6,
11
+ ) -> None:
12
+ super().__init__()
13
+ self.eps = eps
14
+ self.weight = nn.Parameter(torch.ones(hidden_size))
15
+
16
+ @torch.compile
17
+ def rms_forward(
18
+ self,
19
+ x: torch.Tensor,
20
+ ) -> torch.Tensor:
21
+ orig_dtype = x.dtype
22
+ x = x.float()
23
+ var = x.pow(2).mean(dim=-1, keepdim=True)
24
+ x.mul_(torch.rsqrt(var + self.eps))
25
+ x = x.to(orig_dtype).mul_(self.weight)
26
+ return x
27
+
28
+ @torch.compile
29
+ def add_rms_forward(
30
+ self,
31
+ x: torch.Tensor,
32
+ residual: torch.Tensor,
33
+ ) -> tuple[torch.Tensor, torch.Tensor]:
34
+ orig_dtype = x.dtype
35
+ x = x.float().add_(residual.float())
36
+ residual = x.to(orig_dtype)
37
+ var = x.pow(2).mean(dim=-1, keepdim=True)
38
+ x.mul_(torch.rsqrt(var + self.eps))
39
+ x = x.to(orig_dtype).mul_(self.weight)
40
+ return x, residual
41
+
42
+ def forward(
43
+ self,
44
+ x: torch.Tensor,
45
+ residual: torch.Tensor | None = None,
46
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
47
+ if residual is None:
48
+ return self.rms_forward(x)
49
+ else:
50
+ return self.add_rms_forward(x, residual)
acestep/third_parts/nano-vllm/nanovllm/layers/linear.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ import torch.distributed as dist
5
+
6
+
7
+ def divide(numerator, denominator):
8
+ assert numerator % denominator == 0
9
+ return numerator // denominator
10
+
11
+
12
+ class LinearBase(nn.Module):
13
+
14
+ def __init__(
15
+ self,
16
+ input_size: int,
17
+ output_size: int,
18
+ bias: bool = False,
19
+ tp_dim: int | None = None,
20
+ ):
21
+ super().__init__()
22
+ self.tp_dim = tp_dim
23
+ self.tp_rank = dist.get_rank()
24
+ self.tp_size = dist.get_world_size()
25
+ self.weight = nn.Parameter(torch.empty(output_size, input_size))
26
+ self.weight.weight_loader = self.weight_loader
27
+ if bias:
28
+ self.bias = nn.Parameter(torch.empty(output_size))
29
+ self.bias.weight_loader = self.weight_loader
30
+ else:
31
+ self.register_parameter("bias", None)
32
+
33
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
34
+ raise NotImplementedError
35
+
36
+
37
+ class ReplicatedLinear(LinearBase):
38
+
39
+ def __init__(
40
+ self,
41
+ input_size: int,
42
+ output_size: int,
43
+ bias: bool = False,
44
+ ):
45
+ super().__init__(input_size, output_size, bias)
46
+
47
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
48
+ param.data.copy_(loaded_weight)
49
+
50
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
51
+ return F.linear(x, self.weight, self.bias)
52
+
53
+
54
+ class ColumnParallelLinear(LinearBase):
55
+
56
+ def __init__(
57
+ self,
58
+ input_size: int,
59
+ output_size: int,
60
+ bias: bool = False,
61
+ ):
62
+ tp_size = dist.get_world_size()
63
+ super().__init__(input_size, divide(output_size, tp_size), bias, 0)
64
+
65
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
66
+ param_data = param.data
67
+ shard_size = param_data.size(self.tp_dim)
68
+ start_idx = self.tp_rank * shard_size
69
+ loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
70
+ param_data.copy_(loaded_weight)
71
+
72
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
73
+ return F.linear(x, self.weight, self.bias)
74
+
75
+
76
+ class MergedColumnParallelLinear(ColumnParallelLinear):
77
+
78
+ def __init__(
79
+ self,
80
+ input_size: int,
81
+ output_sizes: list[int],
82
+ bias: bool = False,
83
+ ):
84
+ self.output_sizes = output_sizes
85
+ super().__init__(input_size, sum(output_sizes), bias)
86
+
87
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
88
+ param_data = param.data
89
+ shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
90
+ shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
91
+ param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
92
+ loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
93
+ param_data.copy_(loaded_weight)
94
+
95
+
96
+ class QKVParallelLinear(ColumnParallelLinear):
97
+
98
+ def __init__(
99
+ self,
100
+ hidden_size: int,
101
+ head_size: int,
102
+ total_num_heads: int,
103
+ total_num_kv_heads: int | None = None,
104
+ bias: bool = False,
105
+ ):
106
+ tp_size = dist.get_world_size()
107
+ total_num_kv_heads = total_num_kv_heads or total_num_heads
108
+ self.head_size = head_size
109
+ self.num_heads = divide(total_num_heads, tp_size)
110
+ self.num_kv_heads = divide(total_num_kv_heads, tp_size)
111
+ output_size = (total_num_heads + 2 * total_num_kv_heads) * self.head_size
112
+ super().__init__(hidden_size, output_size, bias)
113
+
114
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
115
+ param_data = param.data
116
+ assert loaded_shard_id in ["q", "k", "v"]
117
+ if loaded_shard_id == "q":
118
+ shard_size = self.num_heads * self.head_size
119
+ shard_offset = 0
120
+ elif loaded_shard_id == "k":
121
+ shard_size = self.num_kv_heads * self.head_size
122
+ shard_offset = self.num_heads * self.head_size
123
+ else:
124
+ shard_size = self.num_kv_heads * self.head_size
125
+ shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
126
+ param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
127
+ loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
128
+ param_data.copy_(loaded_weight)
129
+
130
+
131
+ class RowParallelLinear(LinearBase):
132
+
133
+ def __init__(
134
+ self,
135
+ input_size: int,
136
+ output_size: int,
137
+ bias: bool = False,
138
+ ):
139
+ tp_size = dist.get_world_size()
140
+ super().__init__(divide(input_size, tp_size), output_size, bias, 1)
141
+
142
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
143
+ param_data = param.data
144
+ shard_size = param_data.size(self.tp_dim)
145
+ start_idx = self.tp_rank * shard_size
146
+ loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
147
+ param_data.copy_(loaded_weight)
148
+
149
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
150
+ y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
151
+ if self.tp_size > 1:
152
+ dist.all_reduce(y)
153
+ return y
acestep/third_parts/nano-vllm/nanovllm/layers/rotary_embedding.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import lru_cache
2
+ import torch
3
+ from torch import nn
4
+
5
+
6
+ def apply_rotary_emb(
7
+ x: torch.Tensor,
8
+ cos: torch.Tensor,
9
+ sin: torch.Tensor,
10
+ ) -> torch.Tensor:
11
+ x1, x2 = torch.chunk(x.float(), 2, dim=-1)
12
+ y1 = x1 * cos - x2 * sin
13
+ y2 = x2 * cos + x1 * sin
14
+ return torch.cat((y1, y2), dim=-1).to(x.dtype)
15
+
16
+
17
+ class RotaryEmbedding(nn.Module):
18
+
19
+ def __init__(
20
+ self,
21
+ head_size: int,
22
+ rotary_dim: int,
23
+ max_position_embeddings: int,
24
+ base: float,
25
+ ) -> None:
26
+ super().__init__()
27
+ self.head_size = head_size
28
+ assert rotary_dim == head_size
29
+ inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
30
+ t = torch.arange(max_position_embeddings, dtype=torch.float)
31
+ freqs = torch.einsum("i,j -> ij", t, inv_freq)
32
+ cos = freqs.cos()
33
+ sin = freqs.sin()
34
+ cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1)
35
+ self.register_buffer("cos_sin_cache", cache, persistent=False)
36
+
37
+ @torch.compile
38
+ def forward(
39
+ self,
40
+ positions: torch.Tensor,
41
+ query: torch.Tensor,
42
+ key: torch.Tensor,
43
+ ) -> tuple[torch.Tensor, torch.Tensor]:
44
+ cos_sin = self.cos_sin_cache[positions]
45
+ cos, sin = cos_sin.chunk(2, dim=-1)
46
+ query = apply_rotary_emb(query, cos, sin)
47
+ key = apply_rotary_emb(key, cos, sin)
48
+ return query, key
49
+
50
+
51
+ @lru_cache(1)
52
+ def get_rope(
53
+ head_size: int,
54
+ rotary_dim: int,
55
+ max_position: int,
56
+ base: float,
57
+ rope_scaling: dict | None = None,
58
+ ):
59
+ assert rope_scaling is None
60
+ rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
61
+ return rotary_emb
acestep/third_parts/nano-vllm/nanovllm/layers/sampler.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from typing import Optional
4
+
5
+
6
+ def apply_top_k_top_p(
7
+ logits: torch.Tensor,
8
+ k: Optional[torch.Tensor],
9
+ p: Optional[torch.Tensor],
10
+ ) -> torch.Tensor:
11
+ """Apply top-k and top-p masks to the logits (vLLM style).
12
+
13
+ The logits tensor is updated in-place.
14
+ """
15
+ if p is None:
16
+ if k is None:
17
+ return logits
18
+ # Avoid sorting vocab for top-k only case
19
+ return apply_top_k_only(logits, k)
20
+
21
+ # Need to sort for top-p
22
+ logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
23
+
24
+ if k is not None:
25
+ # Apply top-k first
26
+ vocab_size = logits_sort.size(1)
27
+ # Clamp k to valid range
28
+ k_clamped = k.clamp(1, vocab_size).long()
29
+ top_k_mask_idx = vocab_size - k_clamped # shape: [B]
30
+ # Get the threshold value for each batch
31
+ top_k_thresh = logits_sort.gather(1, top_k_mask_idx.unsqueeze(1))
32
+ top_k_mask = logits_sort < top_k_thresh
33
+ logits_sort.masked_fill_(top_k_mask, float('-inf'))
34
+
35
+ # Apply top-p
36
+ probs_sort = logits_sort.softmax(dim=-1)
37
+ probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) # reuse buffer
38
+ top_p_mask = probs_sum <= (1.0 - p.unsqueeze(1))
39
+ # Ensure at least one token is kept
40
+ top_p_mask[:, -1] = False
41
+ logits_sort.masked_fill_(top_p_mask, float('-inf'))
42
+
43
+ # Re-sort back to original positions
44
+ logits.scatter_(dim=-1, index=logits_idx, src=logits_sort)
45
+ return logits
46
+
47
+
48
+ def apply_top_k_only(
49
+ logits: torch.Tensor,
50
+ k: torch.Tensor,
51
+ ) -> torch.Tensor:
52
+ """Apply top-k mask without sorting the entire vocab (vLLM style).
53
+
54
+ This is much faster than sorting for top-k only cases.
55
+ The logits tensor is updated in-place.
56
+ """
57
+ vocab_size = logits.shape[1]
58
+ # Handle cases where k >= vocab_size (no filtering needed)
59
+ no_top_k_mask = (k <= 0) | (k >= vocab_size)
60
+ # Set invalid k to 1 so we can still gather
61
+ k_safe = k.masked_fill(no_top_k_mask, 1).long()
62
+ # NOTE: This int() causes CPU-GPU sync, but torch.topk requires Python int
63
+ max_top_k = int(k_safe.max().clamp(max=vocab_size))
64
+
65
+ # Get top-k values for all batches
66
+ # topk.values has shape [batch_size, max_top_k]
67
+ topk_values = logits.topk(max_top_k, dim=1).values
68
+
69
+ # Convert k to 0-based index: we want the k-th largest value (index k-1)
70
+ # Clamp to valid range for gather
71
+ k_index = (k_safe - 1).clamp(0, max_top_k - 1).unsqueeze(1) # shape: [B, 1]
72
+ # Gather the threshold value (the k-th largest)
73
+ top_k_thresh = topk_values.gather(1, k_index)
74
+
75
+ # For rows with no top-k filtering, set threshold to -inf so nothing gets masked
76
+ top_k_thresh.masked_fill_(no_top_k_mask.unsqueeze(1), float('-inf'))
77
+
78
+ # Mask all values below the threshold
79
+ logits.masked_fill_(logits < top_k_thresh, float('-inf'))
80
+ return logits
81
+
82
+
83
+ class Sampler(nn.Module):
84
+
85
+ def __init__(self):
86
+ super().__init__()
87
+
88
+ @torch.compile
89
+ def forward(
90
+ self,
91
+ logits: torch.Tensor,
92
+ temperatures: torch.Tensor,
93
+ top_ks: Optional[torch.Tensor] = None,
94
+ top_ps: Optional[torch.Tensor] = None,
95
+ repetition_penalties: Optional[torch.Tensor] = None,
96
+ input_ids: Optional[torch.Tensor] = None,
97
+ ):
98
+ """
99
+ Sample tokens from logits with optional top-k and top-p filtering.
100
+
101
+ Condition checking is done OUTSIDE the compiled function to avoid
102
+ graph breaks from .any() calls.
103
+ """
104
+ # Apply temperature
105
+ logits = logits.float().div_(temperatures.unsqueeze(dim=1))
106
+
107
+ logits = apply_top_k_top_p(
108
+ logits,
109
+ top_ks,
110
+ top_ps,
111
+ )
112
+ probs = torch.softmax(logits, dim=-1)
113
+ sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
114
+ return sample_tokens
acestep/third_parts/nano-vllm/nanovllm/llm.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from nanovllm.engine.llm_engine import LLMEngine
2
+
3
+
4
+ class LLM(LLMEngine):
5
+ pass