ChuxiJ commited on
Commit
1da0418
·
1 Parent(s): 134e9c0

refact ui and add i18n

Browse files
acestep/acestep_v15_pipeline.py CHANGED
@@ -26,7 +26,7 @@ except ImportError:
26
  from acestep.gradio_ui import create_gradio_interface
27
 
28
 
29
- def create_demo(init_params=None):
30
  """
31
  Create Gradio demo interface
32
 
@@ -36,7 +36,9 @@ def create_demo(init_params=None):
36
  Keys: 'pre_initialized' (bool), 'checkpoint', 'config_path', 'device',
37
  'init_llm', 'lm_model_path', 'backend', 'use_flash_attention',
38
  'offload_to_cpu', 'offload_dit_to_cpu', 'init_status',
39
- 'dit_handler', 'llm_handler' (initialized handlers if pre-initialized)
 
 
40
 
41
  Returns:
42
  Gradio Blocks instance
@@ -52,20 +54,52 @@ def create_demo(init_params=None):
52
  dataset_handler = DatasetHandler() # Dataset handler
53
 
54
  # Create Gradio interface with all handlers and initialization parameters
55
- demo = create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=init_params)
56
 
57
  return demo
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def main():
61
  """Main entry function"""
62
  import argparse
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  parser = argparse.ArgumentParser(description="Gradio Demo for ACE-Step V1.5")
65
  parser.add_argument("--port", type=int, default=7860, help="Port to run the gradio server on")
66
  parser.add_argument("--share", action="store_true", help="Create a public link")
67
  parser.add_argument("--debug", action="store_true", help="Enable debug mode")
68
  parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Server name (default: 127.0.0.1, use 0.0.0.0 for all interfaces)")
 
69
 
70
  # Service initialization arguments
71
  parser.add_argument("--init_service", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Initialize service on startup (default: False)")
@@ -76,7 +110,7 @@ def main():
76
  parser.add_argument("--lm_model_path", type=str, default=None, help="5Hz LM model path (e.g., 'acestep-5Hz-lm-0.6B')")
77
  parser.add_argument("--backend", type=str, default="vllm", choices=["vllm", "pt"], help="5Hz LM backend (default: vllm)")
78
  parser.add_argument("--use_flash_attention", type=lambda x: x.lower() in ['true', '1', 'yes'], default=None, help="Use flash attention (default: auto-detect)")
79
- parser.add_argument("--offload_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Offload models to CPU (default: False)")
80
  parser.add_argument("--offload_dit_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Offload DiT to CPU (default: False)")
81
 
82
  args = parser.parse_args()
@@ -176,14 +210,24 @@ def main():
176
  'init_status': init_status,
177
  'enable_generate': enable_generate,
178
  'dit_handler': dit_handler,
179
- 'llm_handler': llm_handler
 
180
  }
181
 
182
  print("Service initialization completed successfully!")
183
 
184
  # Create and launch demo
185
- print("Creating Gradio interface...")
186
- demo = create_demo(init_params=init_params)
 
 
 
 
 
 
 
 
 
187
  print(f"Launching server on {args.server_name}:{args.port}...")
188
  demo.launch(
189
  server_name=args.server_name,
 
26
  from acestep.gradio_ui import create_gradio_interface
27
 
28
 
29
+ def create_demo(init_params=None, language='en'):
30
  """
31
  Create Gradio demo interface
32
 
 
36
  Keys: 'pre_initialized' (bool), 'checkpoint', 'config_path', 'device',
37
  'init_llm', 'lm_model_path', 'backend', 'use_flash_attention',
38
  'offload_to_cpu', 'offload_dit_to_cpu', 'init_status',
39
+ 'dit_handler', 'llm_handler' (initialized handlers if pre-initialized),
40
+ 'language' (UI language code)
41
+ language: UI language code ('en', 'zh', 'ja', default: 'en')
42
 
43
  Returns:
44
  Gradio Blocks instance
 
54
  dataset_handler = DatasetHandler() # Dataset handler
55
 
56
  # Create Gradio interface with all handlers and initialization parameters
57
+ demo = create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=init_params, language=language)
58
 
59
  return demo
60
 
61
 
62
+ def get_gpu_memory_gb():
63
+ """
64
+ Get GPU memory in GB. Returns 0 if no GPU is available.
65
+ """
66
+ try:
67
+ import torch
68
+ if torch.cuda.is_available():
69
+ # Get total memory of the first GPU in GB
70
+ total_memory = torch.cuda.get_device_properties(0).total_memory
71
+ memory_gb = total_memory / (1024**3) # Convert bytes to GB
72
+ return memory_gb
73
+ else:
74
+ return 0
75
+ except Exception as e:
76
+ print(f"Warning: Failed to detect GPU memory: {e}", file=sys.stderr)
77
+ return 0
78
+
79
+
80
  def main():
81
  """Main entry function"""
82
  import argparse
83
 
84
+ # Detect GPU memory to auto-configure offload settings
85
+ gpu_memory_gb = get_gpu_memory_gb()
86
+ auto_offload = gpu_memory_gb > 0 and gpu_memory_gb < 16
87
+
88
+ if auto_offload:
89
+ print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (< 16GB)")
90
+ print("Auto-enabling CPU offload to reduce GPU memory usage")
91
+ elif gpu_memory_gb > 0:
92
+ print(f"Detected GPU memory: {gpu_memory_gb:.2f} GB (>= 16GB)")
93
+ print("CPU offload disabled by default")
94
+ else:
95
+ print("No GPU detected, running on CPU")
96
+
97
  parser = argparse.ArgumentParser(description="Gradio Demo for ACE-Step V1.5")
98
  parser.add_argument("--port", type=int, default=7860, help="Port to run the gradio server on")
99
  parser.add_argument("--share", action="store_true", help="Create a public link")
100
  parser.add_argument("--debug", action="store_true", help="Enable debug mode")
101
  parser.add_argument("--server-name", type=str, default="127.0.0.1", help="Server name (default: 127.0.0.1, use 0.0.0.0 for all interfaces)")
102
+ parser.add_argument("--language", type=str, default="en", choices=["en", "zh", "ja"], help="UI language: en (English), zh (中文), ja (日本語)")
103
 
104
  # Service initialization arguments
105
  parser.add_argument("--init_service", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Initialize service on startup (default: False)")
 
110
  parser.add_argument("--lm_model_path", type=str, default=None, help="5Hz LM model path (e.g., 'acestep-5Hz-lm-0.6B')")
111
  parser.add_argument("--backend", type=str, default="vllm", choices=["vllm", "pt"], help="5Hz LM backend (default: vllm)")
112
  parser.add_argument("--use_flash_attention", type=lambda x: x.lower() in ['true', '1', 'yes'], default=None, help="Use flash attention (default: auto-detect)")
113
+ parser.add_argument("--offload_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=auto_offload, help=f"Offload models to CPU (default: {'True' if auto_offload else 'False'}, auto-detected based on GPU VRAM)")
114
  parser.add_argument("--offload_dit_to_cpu", type=lambda x: x.lower() in ['true', '1', 'yes'], default=False, help="Offload DiT to CPU (default: False)")
115
 
116
  args = parser.parse_args()
 
210
  'init_status': init_status,
211
  'enable_generate': enable_generate,
212
  'dit_handler': dit_handler,
213
+ 'llm_handler': llm_handler,
214
+ 'language': args.language
215
  }
216
 
217
  print("Service initialization completed successfully!")
218
 
219
  # Create and launch demo
220
+ print(f"Creating Gradio interface with language: {args.language}...")
221
+ demo = create_demo(init_params=init_params, language=args.language)
222
+
223
+ # Enable queue for multi-user support
224
+ # This ensures proper request queuing and prevents concurrent generation conflicts
225
+ print("Enabling queue for multi-user support...")
226
+ demo.queue(
227
+ max_size=20, # Maximum queue size (adjust based on your needs)
228
+ status_update_rate="auto", # Update rate for queue status
229
+ )
230
+
231
  print(f"Launching server on {args.server_name}:{args.port}...")
232
  demo.launch(
233
  server_name=args.server_name,
acestep/constants.py CHANGED
@@ -96,3 +96,12 @@ TRACK_NAMES = [
96
  "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"
97
  ]
98
 
 
 
 
 
 
 
 
 
 
 
96
  "keyboard", "guitar", "bass", "drums", "backing_vocals", "vocals"
97
  ]
98
 
99
+ SFT_GEN_PROMPT = """# Instruction
100
+ {}
101
+
102
+ # Caption
103
+ {}
104
+
105
+ # Metas
106
+ {}<|endoftext|>
107
+ """
acestep/gradio_ui/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from acestep.gradio_ui.interfaces import create_gradio_interface
acestep/{gradio_ui.py → gradio_ui/event.py} RENAMED
The diff for this file is too large to render. See raw diff
 
acestep/gradio_ui/events/__init__.py ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Event Handlers Module
3
+ Main entry point for setting up all event handlers
4
+ """
5
+ import gradio as gr
6
+ from typing import Optional
7
+
8
+ # Import handler modules
9
+ from . import generation_handlers as gen_h
10
+ from . import results_handlers as res_h
11
+ from acestep.gradio_ui.i18n import t
12
+
13
+
14
+ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section):
15
+ """Setup event handlers connecting UI components and business logic"""
16
+
17
+ # ========== Dataset Handlers ==========
18
+ dataset_section["import_dataset_btn"].click(
19
+ fn=dataset_handler.import_dataset,
20
+ inputs=[dataset_section["dataset_type"]],
21
+ outputs=[dataset_section["data_status"]]
22
+ )
23
+
24
+ # ========== Service Initialization ==========
25
+ generation_section["refresh_btn"].click(
26
+ fn=lambda: gen_h.refresh_checkpoints(dit_handler),
27
+ outputs=[generation_section["checkpoint_dropdown"]]
28
+ )
29
+
30
+ generation_section["config_path"].change(
31
+ fn=gen_h.update_model_type_settings,
32
+ inputs=[generation_section["config_path"]],
33
+ outputs=[
34
+ generation_section["inference_steps"],
35
+ generation_section["guidance_scale"],
36
+ generation_section["use_adg"],
37
+ generation_section["cfg_interval_start"],
38
+ generation_section["cfg_interval_end"],
39
+ generation_section["task_type"],
40
+ ]
41
+ )
42
+
43
+ generation_section["init_btn"].click(
44
+ fn=lambda *args: gen_h.init_service_wrapper(dit_handler, llm_handler, *args),
45
+ inputs=[
46
+ generation_section["checkpoint_dropdown"],
47
+ generation_section["config_path"],
48
+ generation_section["device"],
49
+ generation_section["init_llm_checkbox"],
50
+ generation_section["lm_model_path"],
51
+ generation_section["backend_dropdown"],
52
+ generation_section["use_flash_attention_checkbox"],
53
+ generation_section["offload_to_cpu_checkbox"],
54
+ generation_section["offload_dit_to_cpu_checkbox"],
55
+ ],
56
+ outputs=[generation_section["init_status"], generation_section["generate_btn"], generation_section["service_config_accordion"]]
57
+ )
58
+
59
+ # ========== UI Visibility Updates ==========
60
+ generation_section["init_llm_checkbox"].change(
61
+ fn=gen_h.update_negative_prompt_visibility,
62
+ inputs=[generation_section["init_llm_checkbox"]],
63
+ outputs=[generation_section["lm_negative_prompt"]]
64
+ )
65
+
66
+ generation_section["init_llm_checkbox"].change(
67
+ fn=gen_h.update_audio_cover_strength_visibility,
68
+ inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"]],
69
+ outputs=[generation_section["audio_cover_strength"]]
70
+ )
71
+
72
+ generation_section["task_type"].change(
73
+ fn=gen_h.update_audio_cover_strength_visibility,
74
+ inputs=[generation_section["task_type"], generation_section["init_llm_checkbox"]],
75
+ outputs=[generation_section["audio_cover_strength"]]
76
+ )
77
+
78
+ generation_section["batch_size_input"].change(
79
+ fn=gen_h.update_audio_components_visibility,
80
+ inputs=[generation_section["batch_size_input"]],
81
+ outputs=[
82
+ results_section["audio_col_1"],
83
+ results_section["audio_col_2"],
84
+ results_section["audio_col_3"],
85
+ results_section["audio_col_4"],
86
+ results_section["audio_row_5_8"],
87
+ results_section["audio_col_5"],
88
+ results_section["audio_col_6"],
89
+ results_section["audio_col_7"],
90
+ results_section["audio_col_8"],
91
+ ]
92
+ )
93
+
94
+ # Update codes hints visibility
95
+ for trigger in [generation_section["src_audio"], generation_section["allow_lm_batch"], generation_section["batch_size_input"]]:
96
+ trigger.change(
97
+ fn=gen_h.update_codes_hints_visibility,
98
+ inputs=[
99
+ generation_section["src_audio"],
100
+ generation_section["allow_lm_batch"],
101
+ generation_section["batch_size_input"]
102
+ ],
103
+ outputs=[
104
+ generation_section["codes_single_row"],
105
+ generation_section["codes_batch_row"],
106
+ generation_section["codes_batch_row_2"],
107
+ generation_section["codes_col_1"],
108
+ generation_section["codes_col_2"],
109
+ generation_section["codes_col_3"],
110
+ generation_section["codes_col_4"],
111
+ generation_section["codes_col_5"],
112
+ generation_section["codes_col_6"],
113
+ generation_section["codes_col_7"],
114
+ generation_section["codes_col_8"],
115
+ generation_section["transcribe_btn"],
116
+ ]
117
+ )
118
+
119
+ # ========== Audio Conversion ==========
120
+ generation_section["convert_src_to_codes_btn"].click(
121
+ fn=lambda src: gen_h.convert_src_audio_to_codes_wrapper(dit_handler, src),
122
+ inputs=[generation_section["src_audio"]],
123
+ outputs=[generation_section["text2music_audio_code_string"]]
124
+ )
125
+
126
+ # ========== Instruction UI Updates ==========
127
+ for trigger in [generation_section["task_type"], generation_section["track_name"], generation_section["complete_track_classes"]]:
128
+ trigger.change(
129
+ fn=lambda *args: gen_h.update_instruction_ui(dit_handler, *args),
130
+ inputs=[
131
+ generation_section["task_type"],
132
+ generation_section["track_name"],
133
+ generation_section["complete_track_classes"],
134
+ generation_section["text2music_audio_code_string"],
135
+ generation_section["init_llm_checkbox"]
136
+ ],
137
+ outputs=[
138
+ generation_section["instruction_display_gen"],
139
+ generation_section["track_name"],
140
+ generation_section["complete_track_classes"],
141
+ generation_section["audio_cover_strength"],
142
+ generation_section["repainting_group"],
143
+ generation_section["text2music_audio_codes_group"],
144
+ ]
145
+ )
146
+
147
+ # ========== Sample/Transcribe Handlers ==========
148
+ generation_section["sample_btn"].click(
149
+ fn=lambda task, debug: gen_h.sample_example_smart(llm_handler, task, debug) + (True,),
150
+ inputs=[
151
+ generation_section["task_type"],
152
+ generation_section["constrained_decoding_debug"]
153
+ ],
154
+ outputs=[
155
+ generation_section["captions"],
156
+ generation_section["lyrics"],
157
+ generation_section["think_checkbox"],
158
+ generation_section["bpm"],
159
+ generation_section["audio_duration"],
160
+ generation_section["key_scale"],
161
+ generation_section["vocal_language"],
162
+ generation_section["time_signature"],
163
+ results_section["is_format_caption_state"]
164
+ ]
165
+ )
166
+
167
+ generation_section["text2music_audio_code_string"].change(
168
+ fn=gen_h.update_transcribe_button_text,
169
+ inputs=[generation_section["text2music_audio_code_string"]],
170
+ outputs=[generation_section["transcribe_btn"]]
171
+ )
172
+
173
+ generation_section["transcribe_btn"].click(
174
+ fn=lambda codes, debug: gen_h.transcribe_audio_codes(llm_handler, codes, debug),
175
+ inputs=[
176
+ generation_section["text2music_audio_code_string"],
177
+ generation_section["constrained_decoding_debug"]
178
+ ],
179
+ outputs=[
180
+ results_section["status_output"],
181
+ generation_section["captions"],
182
+ generation_section["lyrics"],
183
+ generation_section["bpm"],
184
+ generation_section["audio_duration"],
185
+ generation_section["key_scale"],
186
+ generation_section["vocal_language"],
187
+ generation_section["time_signature"],
188
+ results_section["is_format_caption_state"]
189
+ ]
190
+ )
191
+
192
+ # ========== Reset Format Caption Flag ==========
193
+ for trigger in [generation_section["captions"], generation_section["lyrics"], generation_section["bpm"],
194
+ generation_section["key_scale"], generation_section["time_signature"],
195
+ generation_section["vocal_language"], generation_section["audio_duration"]]:
196
+ trigger.change(
197
+ fn=gen_h.reset_format_caption_flag,
198
+ inputs=[],
199
+ outputs=[results_section["is_format_caption_state"]]
200
+ )
201
+
202
+ # ========== Audio Uploads Accordion ==========
203
+ for trigger in [generation_section["reference_audio"], generation_section["src_audio"]]:
204
+ trigger.change(
205
+ fn=gen_h.update_audio_uploads_accordion,
206
+ inputs=[generation_section["reference_audio"], generation_section["src_audio"]],
207
+ outputs=[generation_section["audio_uploads_accordion"]]
208
+ )
209
+
210
+ # ========== Instrumental Checkbox ==========
211
+ generation_section["instrumental_checkbox"].change(
212
+ fn=gen_h.handle_instrumental_checkbox,
213
+ inputs=[generation_section["instrumental_checkbox"], generation_section["lyrics"]],
214
+ outputs=[generation_section["lyrics"]]
215
+ )
216
+
217
+ # ========== Load/Save Metadata ==========
218
+ generation_section["load_file"].upload(
219
+ fn=gen_h.load_metadata,
220
+ inputs=[generation_section["load_file"]],
221
+ outputs=[
222
+ generation_section["task_type"],
223
+ generation_section["captions"],
224
+ generation_section["lyrics"],
225
+ generation_section["vocal_language"],
226
+ generation_section["bpm"],
227
+ generation_section["key_scale"],
228
+ generation_section["time_signature"],
229
+ generation_section["audio_duration"],
230
+ generation_section["batch_size_input"],
231
+ generation_section["inference_steps"],
232
+ generation_section["guidance_scale"],
233
+ generation_section["seed"],
234
+ generation_section["random_seed_checkbox"],
235
+ generation_section["use_adg"],
236
+ generation_section["cfg_interval_start"],
237
+ generation_section["cfg_interval_end"],
238
+ generation_section["audio_format"],
239
+ generation_section["lm_temperature"],
240
+ generation_section["lm_cfg_scale"],
241
+ generation_section["lm_top_k"],
242
+ generation_section["lm_top_p"],
243
+ generation_section["lm_negative_prompt"],
244
+ generation_section["use_cot_caption"],
245
+ generation_section["use_cot_language"],
246
+ generation_section["audio_cover_strength"],
247
+ generation_section["think_checkbox"],
248
+ generation_section["text2music_audio_code_string"],
249
+ generation_section["repainting_start"],
250
+ generation_section["repainting_end"],
251
+ generation_section["track_name"],
252
+ generation_section["complete_track_classes"],
253
+ results_section["is_format_caption_state"]
254
+ ]
255
+ )
256
+
257
+ # Save buttons for audio 1 and 2
258
+ for btn_idx, btn_key in [(1, "save_btn_1"), (2, "save_btn_2")]:
259
+ results_section[btn_key].click(
260
+ fn=res_h.save_audio_and_metadata,
261
+ inputs=[
262
+ results_section[f"generated_audio_{btn_idx}"],
263
+ generation_section["task_type"],
264
+ generation_section["captions"],
265
+ generation_section["lyrics"],
266
+ generation_section["vocal_language"],
267
+ generation_section["bpm"],
268
+ generation_section["key_scale"],
269
+ generation_section["time_signature"],
270
+ generation_section["audio_duration"],
271
+ generation_section["batch_size_input"],
272
+ generation_section["inference_steps"],
273
+ generation_section["guidance_scale"],
274
+ generation_section["seed"],
275
+ generation_section["random_seed_checkbox"],
276
+ generation_section["use_adg"],
277
+ generation_section["cfg_interval_start"],
278
+ generation_section["cfg_interval_end"],
279
+ generation_section["audio_format"],
280
+ generation_section["lm_temperature"],
281
+ generation_section["lm_cfg_scale"],
282
+ generation_section["lm_top_k"],
283
+ generation_section["lm_top_p"],
284
+ generation_section["lm_negative_prompt"],
285
+ generation_section["use_cot_caption"],
286
+ generation_section["use_cot_language"],
287
+ generation_section["audio_cover_strength"],
288
+ generation_section["think_checkbox"],
289
+ generation_section["text2music_audio_code_string"],
290
+ generation_section["repainting_start"],
291
+ generation_section["repainting_end"],
292
+ generation_section["track_name"],
293
+ generation_section["complete_track_classes"],
294
+ results_section["lm_metadata_state"],
295
+ ],
296
+ outputs=[gr.File(label="Download Package", visible=False)]
297
+ )
298
+
299
+ # ========== Send to SRC Handlers ==========
300
+ for btn_idx in range(1, 9):
301
+ results_section[f"send_to_src_btn_{btn_idx}"].click(
302
+ fn=res_h.send_audio_to_src_with_metadata,
303
+ inputs=[
304
+ results_section[f"generated_audio_{btn_idx}"],
305
+ results_section["lm_metadata_state"]
306
+ ],
307
+ outputs=[
308
+ generation_section["src_audio"],
309
+ generation_section["bpm"],
310
+ generation_section["captions"],
311
+ generation_section["lyrics"],
312
+ generation_section["audio_duration"],
313
+ generation_section["key_scale"],
314
+ generation_section["vocal_language"],
315
+ generation_section["time_signature"],
316
+ results_section["is_format_caption_state"]
317
+ ]
318
+ )
319
+
320
+ # ========== Score Calculation Handlers ==========
321
+ for btn_idx in range(1, 9):
322
+ results_section[f"score_btn_{btn_idx}"].click(
323
+ fn=lambda sample_idx, scale, batch_idx, queue: res_h.calculate_score_handler_with_selection(
324
+ llm_handler, sample_idx, scale, batch_idx, queue
325
+ ),
326
+ inputs=[
327
+ gr.State(value=btn_idx),
328
+ generation_section["score_scale"],
329
+ results_section["current_batch_index"],
330
+ results_section["batch_queue"],
331
+ ],
332
+ outputs=[results_section[f"score_display_{btn_idx}"], results_section["batch_queue"]]
333
+ )
334
+
335
+ # ========== Generation Handler ==========
336
+ generation_section["generate_btn"].click(
337
+ fn=lambda *args: res_h.generate_with_batch_management(dit_handler, llm_handler, *args),
338
+ inputs=[
339
+ generation_section["captions"],
340
+ generation_section["lyrics"],
341
+ generation_section["bpm"],
342
+ generation_section["key_scale"],
343
+ generation_section["time_signature"],
344
+ generation_section["vocal_language"],
345
+ generation_section["inference_steps"],
346
+ generation_section["guidance_scale"],
347
+ generation_section["random_seed_checkbox"],
348
+ generation_section["seed"],
349
+ generation_section["reference_audio"],
350
+ generation_section["audio_duration"],
351
+ generation_section["batch_size_input"],
352
+ generation_section["src_audio"],
353
+ generation_section["text2music_audio_code_string"],
354
+ generation_section["repainting_start"],
355
+ generation_section["repainting_end"],
356
+ generation_section["instruction_display_gen"],
357
+ generation_section["audio_cover_strength"],
358
+ generation_section["task_type"],
359
+ generation_section["use_adg"],
360
+ generation_section["cfg_interval_start"],
361
+ generation_section["cfg_interval_end"],
362
+ generation_section["audio_format"],
363
+ generation_section["lm_temperature"],
364
+ generation_section["think_checkbox"],
365
+ generation_section["lm_cfg_scale"],
366
+ generation_section["lm_top_k"],
367
+ generation_section["lm_top_p"],
368
+ generation_section["lm_negative_prompt"],
369
+ generation_section["use_cot_metas"],
370
+ generation_section["use_cot_caption"],
371
+ generation_section["use_cot_language"],
372
+ results_section["is_format_caption_state"],
373
+ generation_section["constrained_decoding_debug"],
374
+ generation_section["allow_lm_batch"],
375
+ generation_section["auto_score"],
376
+ generation_section["score_scale"],
377
+ generation_section["lm_batch_chunk_size"],
378
+ generation_section["track_name"],
379
+ generation_section["complete_track_classes"],
380
+ generation_section["autogen_checkbox"],
381
+ results_section["current_batch_index"],
382
+ results_section["total_batches"],
383
+ results_section["batch_queue"],
384
+ results_section["generation_params_state"],
385
+ ],
386
+ outputs=[
387
+ results_section["generated_audio_1"],
388
+ results_section["generated_audio_2"],
389
+ results_section["generated_audio_3"],
390
+ results_section["generated_audio_4"],
391
+ results_section["generated_audio_5"],
392
+ results_section["generated_audio_6"],
393
+ results_section["generated_audio_7"],
394
+ results_section["generated_audio_8"],
395
+ results_section["generated_audio_batch"],
396
+ results_section["generation_info"],
397
+ results_section["status_output"],
398
+ generation_section["seed"],
399
+ results_section["align_score_1"],
400
+ results_section["align_text_1"],
401
+ results_section["align_plot_1"],
402
+ results_section["align_score_2"],
403
+ results_section["align_text_2"],
404
+ results_section["align_plot_2"],
405
+ results_section["score_display_1"],
406
+ results_section["score_display_2"],
407
+ results_section["score_display_3"],
408
+ results_section["score_display_4"],
409
+ results_section["score_display_5"],
410
+ results_section["score_display_6"],
411
+ results_section["score_display_7"],
412
+ results_section["score_display_8"],
413
+ generation_section["text2music_audio_code_string"],
414
+ generation_section["text2music_audio_code_string_1"],
415
+ generation_section["text2music_audio_code_string_2"],
416
+ generation_section["text2music_audio_code_string_3"],
417
+ generation_section["text2music_audio_code_string_4"],
418
+ generation_section["text2music_audio_code_string_5"],
419
+ generation_section["text2music_audio_code_string_6"],
420
+ generation_section["text2music_audio_code_string_7"],
421
+ generation_section["text2music_audio_code_string_8"],
422
+ results_section["lm_metadata_state"],
423
+ results_section["is_format_caption_state"],
424
+ results_section["current_batch_index"],
425
+ results_section["total_batches"],
426
+ results_section["batch_queue"],
427
+ results_section["generation_params_state"],
428
+ results_section["batch_indicator"],
429
+ results_section["prev_batch_btn"],
430
+ results_section["next_batch_btn"],
431
+ results_section["next_batch_status"],
432
+ results_section["restore_params_btn"],
433
+ ]
434
+ ).then(
435
+ fn=lambda *args: res_h.generate_next_batch_background(dit_handler, llm_handler, *args),
436
+ inputs=[
437
+ generation_section["autogen_checkbox"],
438
+ results_section["generation_params_state"],
439
+ results_section["current_batch_index"],
440
+ results_section["total_batches"],
441
+ results_section["batch_queue"],
442
+ results_section["is_format_caption_state"],
443
+ ],
444
+ outputs=[
445
+ results_section["batch_queue"],
446
+ results_section["total_batches"],
447
+ results_section["next_batch_status"],
448
+ results_section["next_batch_btn"],
449
+ ]
450
+ )
451
+
452
+ # ========== Batch Navigation Handlers ==========
453
+ results_section["prev_batch_btn"].click(
454
+ fn=res_h.navigate_to_previous_batch,
455
+ inputs=[
456
+ results_section["current_batch_index"],
457
+ results_section["batch_queue"],
458
+ ],
459
+ outputs=[
460
+ results_section["generated_audio_1"],
461
+ results_section["generated_audio_2"],
462
+ results_section["generated_audio_3"],
463
+ results_section["generated_audio_4"],
464
+ results_section["generated_audio_5"],
465
+ results_section["generated_audio_6"],
466
+ results_section["generated_audio_7"],
467
+ results_section["generated_audio_8"],
468
+ results_section["generated_audio_batch"],
469
+ results_section["generation_info"],
470
+ results_section["current_batch_index"],
471
+ results_section["batch_indicator"],
472
+ results_section["prev_batch_btn"],
473
+ results_section["next_batch_btn"],
474
+ results_section["status_output"],
475
+ results_section["score_display_1"],
476
+ results_section["score_display_2"],
477
+ results_section["score_display_3"],
478
+ results_section["score_display_4"],
479
+ results_section["score_display_5"],
480
+ results_section["score_display_6"],
481
+ results_section["score_display_7"],
482
+ results_section["score_display_8"],
483
+ results_section["restore_params_btn"],
484
+ ]
485
+ )
486
+
487
+ results_section["next_batch_btn"].click(
488
+ fn=res_h.capture_current_params,
489
+ inputs=[
490
+ generation_section["captions"],
491
+ generation_section["lyrics"],
492
+ generation_section["bpm"],
493
+ generation_section["key_scale"],
494
+ generation_section["time_signature"],
495
+ generation_section["vocal_language"],
496
+ generation_section["inference_steps"],
497
+ generation_section["guidance_scale"],
498
+ generation_section["random_seed_checkbox"],
499
+ generation_section["seed"],
500
+ generation_section["reference_audio"],
501
+ generation_section["audio_duration"],
502
+ generation_section["batch_size_input"],
503
+ generation_section["src_audio"],
504
+ generation_section["text2music_audio_code_string"],
505
+ generation_section["repainting_start"],
506
+ generation_section["repainting_end"],
507
+ generation_section["instruction_display_gen"],
508
+ generation_section["audio_cover_strength"],
509
+ generation_section["task_type"],
510
+ generation_section["use_adg"],
511
+ generation_section["cfg_interval_start"],
512
+ generation_section["cfg_interval_end"],
513
+ generation_section["audio_format"],
514
+ generation_section["lm_temperature"],
515
+ generation_section["think_checkbox"],
516
+ generation_section["lm_cfg_scale"],
517
+ generation_section["lm_top_k"],
518
+ generation_section["lm_top_p"],
519
+ generation_section["lm_negative_prompt"],
520
+ generation_section["use_cot_metas"],
521
+ generation_section["use_cot_caption"],
522
+ generation_section["use_cot_language"],
523
+ generation_section["constrained_decoding_debug"],
524
+ generation_section["allow_lm_batch"],
525
+ generation_section["auto_score"],
526
+ generation_section["score_scale"],
527
+ generation_section["lm_batch_chunk_size"],
528
+ generation_section["track_name"],
529
+ generation_section["complete_track_classes"],
530
+ ],
531
+ outputs=[results_section["generation_params_state"]]
532
+ ).then(
533
+ fn=res_h.navigate_to_next_batch,
534
+ inputs=[
535
+ generation_section["autogen_checkbox"],
536
+ results_section["current_batch_index"],
537
+ results_section["total_batches"],
538
+ results_section["batch_queue"],
539
+ ],
540
+ outputs=[
541
+ results_section["generated_audio_1"],
542
+ results_section["generated_audio_2"],
543
+ results_section["generated_audio_3"],
544
+ results_section["generated_audio_4"],
545
+ results_section["generated_audio_5"],
546
+ results_section["generated_audio_6"],
547
+ results_section["generated_audio_7"],
548
+ results_section["generated_audio_8"],
549
+ results_section["generated_audio_batch"],
550
+ results_section["generation_info"],
551
+ results_section["current_batch_index"],
552
+ results_section["batch_indicator"],
553
+ results_section["prev_batch_btn"],
554
+ results_section["next_batch_btn"],
555
+ results_section["status_output"],
556
+ results_section["next_batch_status"],
557
+ results_section["score_display_1"],
558
+ results_section["score_display_2"],
559
+ results_section["score_display_3"],
560
+ results_section["score_display_4"],
561
+ results_section["score_display_5"],
562
+ results_section["score_display_6"],
563
+ results_section["score_display_7"],
564
+ results_section["score_display_8"],
565
+ results_section["restore_params_btn"],
566
+ ]
567
+ ).then(
568
+ fn=lambda *args: res_h.generate_next_batch_background(dit_handler, llm_handler, *args),
569
+ inputs=[
570
+ generation_section["autogen_checkbox"],
571
+ results_section["generation_params_state"],
572
+ results_section["current_batch_index"],
573
+ results_section["total_batches"],
574
+ results_section["batch_queue"],
575
+ results_section["is_format_caption_state"],
576
+ ],
577
+ outputs=[
578
+ results_section["batch_queue"],
579
+ results_section["total_batches"],
580
+ results_section["next_batch_status"],
581
+ results_section["next_batch_btn"],
582
+ ]
583
+ )
584
+
585
+ # ========== Restore Parameters Handler ==========
586
+ results_section["restore_params_btn"].click(
587
+ fn=res_h.restore_batch_parameters,
588
+ inputs=[
589
+ results_section["current_batch_index"],
590
+ results_section["batch_queue"]
591
+ ],
592
+ outputs=[
593
+ generation_section["text2music_audio_code_string"],
594
+ generation_section["text2music_audio_code_string_1"],
595
+ generation_section["text2music_audio_code_string_2"],
596
+ generation_section["text2music_audio_code_string_3"],
597
+ generation_section["text2music_audio_code_string_4"],
598
+ generation_section["text2music_audio_code_string_5"],
599
+ generation_section["text2music_audio_code_string_6"],
600
+ generation_section["text2music_audio_code_string_7"],
601
+ generation_section["text2music_audio_code_string_8"],
602
+ generation_section["captions"],
603
+ generation_section["lyrics"],
604
+ generation_section["bpm"],
605
+ generation_section["key_scale"],
606
+ generation_section["time_signature"],
607
+ generation_section["vocal_language"],
608
+ generation_section["audio_duration"],
609
+ generation_section["batch_size_input"],
610
+ generation_section["inference_steps"],
611
+ generation_section["lm_temperature"],
612
+ generation_section["lm_cfg_scale"],
613
+ generation_section["lm_top_k"],
614
+ generation_section["lm_top_p"],
615
+ generation_section["think_checkbox"],
616
+ generation_section["use_cot_caption"],
617
+ generation_section["use_cot_language"],
618
+ generation_section["allow_lm_batch"],
619
+ generation_section["track_name"],
620
+ generation_section["complete_track_classes"],
621
+ ]
622
+ )
acestep/gradio_ui/events/generation_handlers.py ADDED
@@ -0,0 +1,619 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generation Input Handlers Module
3
+ Contains event handlers and helper functions related to generation inputs
4
+ """
5
+ import os
6
+ import json
7
+ import random
8
+ import glob
9
+ import gradio as gr
10
+ from typing import Optional
11
+ from acestep.constants import (
12
+ TASK_TYPES_TURBO,
13
+ TASK_TYPES_BASE,
14
+ )
15
+ from acestep.gradio_ui.i18n import t
16
+
17
+
18
+ def load_metadata(file_obj):
19
+ """Load generation parameters from a JSON file"""
20
+ if file_obj is None:
21
+ gr.Warning(t("messages.no_file_selected"))
22
+ return [None] * 31 + [False] # Return None for all fields, False for is_format_caption
23
+
24
+ try:
25
+ # Read the uploaded file
26
+ if hasattr(file_obj, 'name'):
27
+ filepath = file_obj.name
28
+ else:
29
+ filepath = file_obj
30
+
31
+ with open(filepath, 'r', encoding='utf-8') as f:
32
+ metadata = json.load(f)
33
+
34
+ # Extract all fields
35
+ task_type = metadata.get('task_type', 'text2music')
36
+ captions = metadata.get('caption', '')
37
+ lyrics = metadata.get('lyrics', '')
38
+ vocal_language = metadata.get('vocal_language', 'unknown')
39
+
40
+ # Convert bpm
41
+ bpm_value = metadata.get('bpm')
42
+ if bpm_value is not None and bpm_value != "N/A":
43
+ try:
44
+ bpm = int(bpm_value) if bpm_value else None
45
+ except:
46
+ bpm = None
47
+ else:
48
+ bpm = None
49
+
50
+ key_scale = metadata.get('keyscale', '')
51
+ time_signature = metadata.get('timesignature', '')
52
+
53
+ # Convert duration
54
+ duration_value = metadata.get('duration', -1)
55
+ if duration_value is not None and duration_value != "N/A":
56
+ try:
57
+ audio_duration = float(duration_value)
58
+ except:
59
+ audio_duration = -1
60
+ else:
61
+ audio_duration = -1
62
+
63
+ batch_size = metadata.get('batch_size', 2)
64
+ inference_steps = metadata.get('inference_steps', 8)
65
+ guidance_scale = metadata.get('guidance_scale', 7.0)
66
+ seed = metadata.get('seed', '-1')
67
+ random_seed = metadata.get('random_seed', True)
68
+ use_adg = metadata.get('use_adg', False)
69
+ cfg_interval_start = metadata.get('cfg_interval_start', 0.0)
70
+ cfg_interval_end = metadata.get('cfg_interval_end', 1.0)
71
+ audio_format = metadata.get('audio_format', 'mp3')
72
+ lm_temperature = metadata.get('lm_temperature', 0.85)
73
+ lm_cfg_scale = metadata.get('lm_cfg_scale', 2.0)
74
+ lm_top_k = metadata.get('lm_top_k', 0)
75
+ lm_top_p = metadata.get('lm_top_p', 0.9)
76
+ lm_negative_prompt = metadata.get('lm_negative_prompt', 'NO USER INPUT')
77
+ use_cot_caption = metadata.get('use_cot_caption', True)
78
+ use_cot_language = metadata.get('use_cot_language', True)
79
+ audio_cover_strength = metadata.get('audio_cover_strength', 1.0)
80
+ think = metadata.get('think', True)
81
+ audio_codes = metadata.get('audio_codes', '')
82
+ repainting_start = metadata.get('repainting_start', 0.0)
83
+ repainting_end = metadata.get('repainting_end', -1)
84
+ track_name = metadata.get('track_name')
85
+ complete_track_classes = metadata.get('complete_track_classes', [])
86
+
87
+ gr.Info(t("messages.params_loaded", filename=os.path.basename(filepath)))
88
+
89
+ return (
90
+ task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature,
91
+ audio_duration, batch_size, inference_steps, guidance_scale, seed, random_seed,
92
+ use_adg, cfg_interval_start, cfg_interval_end, audio_format,
93
+ lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
94
+ use_cot_caption, use_cot_language, audio_cover_strength,
95
+ think, audio_codes, repainting_start, repainting_end,
96
+ track_name, complete_track_classes,
97
+ True # Set is_format_caption to True when loading from file
98
+ )
99
+
100
+ except json.JSONDecodeError as e:
101
+ gr.Warning(t("messages.invalid_json", error=str(e)))
102
+ return [None] * 31 + [False]
103
+ except Exception as e:
104
+ gr.Warning(t("messages.load_error", error=str(e)))
105
+ return [None] * 31 + [False]
106
+
107
+
108
+ def load_random_example(task_type: str):
109
+ """Load a random example from the task-specific examples directory
110
+
111
+ Args:
112
+ task_type: The task type (e.g., "text2music")
113
+
114
+ Returns:
115
+ Tuple of (caption, lyrics, think, bpm, duration, keyscale, language, timesignature) for updating UI components
116
+ """
117
+ try:
118
+ # Get the project root directory
119
+ current_file = os.path.abspath(__file__)
120
+ # This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
121
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
122
+
123
+ # Construct the examples directory path
124
+ examples_dir = os.path.join(project_root, "examples", task_type)
125
+
126
+ # Check if directory exists
127
+ if not os.path.exists(examples_dir):
128
+ gr.Warning(f"Examples directory not found: examples/{task_type}/")
129
+ return "", "", True, None, None, "", "", ""
130
+
131
+ # Find all JSON files in the directory
132
+ json_files = glob.glob(os.path.join(examples_dir, "*.json"))
133
+
134
+ if not json_files:
135
+ gr.Warning(f"No JSON files found in examples/{task_type}/")
136
+ return "", "", True, None, None, "", "", ""
137
+
138
+ # Randomly select one file
139
+ selected_file = random.choice(json_files)
140
+
141
+ # Read and parse JSON
142
+ try:
143
+ with open(selected_file, 'r', encoding='utf-8') as f:
144
+ data = json.load(f)
145
+
146
+ # Extract caption (prefer 'caption', fallback to 'prompt')
147
+ caption_value = data.get('caption', data.get('prompt', ''))
148
+ if not isinstance(caption_value, str):
149
+ caption_value = str(caption_value) if caption_value else ''
150
+
151
+ # Extract lyrics
152
+ lyrics_value = data.get('lyrics', '')
153
+ if not isinstance(lyrics_value, str):
154
+ lyrics_value = str(lyrics_value) if lyrics_value else ''
155
+
156
+ # Extract think (default to True if not present)
157
+ think_value = data.get('think', True)
158
+ if not isinstance(think_value, bool):
159
+ think_value = True
160
+
161
+ # Extract optional metadata fields
162
+ bpm_value = None
163
+ if 'bpm' in data and data['bpm'] not in [None, "N/A", ""]:
164
+ try:
165
+ bpm_value = int(data['bpm'])
166
+ except (ValueError, TypeError):
167
+ pass
168
+
169
+ duration_value = None
170
+ if 'duration' in data and data['duration'] not in [None, "N/A", ""]:
171
+ try:
172
+ duration_value = float(data['duration'])
173
+ except (ValueError, TypeError):
174
+ pass
175
+
176
+ keyscale_value = data.get('keyscale', '')
177
+ if keyscale_value in [None, "N/A"]:
178
+ keyscale_value = ''
179
+
180
+ language_value = data.get('language', '')
181
+ if language_value in [None, "N/A"]:
182
+ language_value = ''
183
+
184
+ timesignature_value = data.get('timesignature', '')
185
+ if timesignature_value in [None, "N/A"]:
186
+ timesignature_value = ''
187
+
188
+ gr.Info(t("messages.example_loaded", filename=os.path.basename(selected_file)))
189
+ return caption_value, lyrics_value, think_value, bpm_value, duration_value, keyscale_value, language_value, timesignature_value
190
+
191
+ except json.JSONDecodeError as e:
192
+ gr.Warning(t("messages.example_failed", filename=os.path.basename(selected_file), error=str(e)))
193
+ return "", "", True, None, None, "", "", ""
194
+ except Exception as e:
195
+ gr.Warning(t("messages.example_error", error=str(e)))
196
+ return "", "", True, None, None, "", "", ""
197
+
198
+ except Exception as e:
199
+ gr.Warning(t("messages.example_error", error=str(e)))
200
+ return "", "", True, None, None, "", "", ""
201
+
202
+
203
+ def sample_example_smart(llm_handler, task_type: str, constrained_decoding_debug: bool = False):
204
+ """Smart sample function that uses LM if initialized, otherwise falls back to examples
205
+
206
+ Args:
207
+ llm_handler: LLM handler instance
208
+ task_type: The task type (e.g., "text2music")
209
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
210
+
211
+ Returns:
212
+ Tuple of (caption, lyrics, think, bpm, duration, keyscale, language, timesignature) for updating UI components
213
+ """
214
+ # Check if LM is initialized
215
+ if llm_handler.llm_initialized:
216
+ # Use LM to generate example
217
+ try:
218
+ # Generate example using LM with empty input (NO USER INPUT)
219
+ metadata, status = llm_handler.understand_audio_from_codes(
220
+ audio_codes="NO USER INPUT",
221
+ use_constrained_decoding=True,
222
+ temperature=0.85,
223
+ constrained_decoding_debug=constrained_decoding_debug,
224
+ )
225
+
226
+ if metadata:
227
+ caption_value = metadata.get('caption', '')
228
+ lyrics_value = metadata.get('lyrics', '')
229
+ think_value = True # Always enable think when using LM-generated examples
230
+
231
+ # Extract optional metadata fields
232
+ bpm_value = None
233
+ if 'bpm' in metadata and metadata['bpm'] not in [None, "N/A", ""]:
234
+ try:
235
+ bpm_value = int(metadata['bpm'])
236
+ except (ValueError, TypeError):
237
+ pass
238
+
239
+ duration_value = None
240
+ if 'duration' in metadata and metadata['duration'] not in [None, "N/A", ""]:
241
+ try:
242
+ duration_value = float(metadata['duration'])
243
+ except (ValueError, TypeError):
244
+ pass
245
+
246
+ keyscale_value = metadata.get('keyscale', '')
247
+ if keyscale_value in [None, "N/A"]:
248
+ keyscale_value = ''
249
+
250
+ language_value = metadata.get('language', '')
251
+ if language_value in [None, "N/A"]:
252
+ language_value = ''
253
+
254
+ timesignature_value = metadata.get('timesignature', '')
255
+ if timesignature_value in [None, "N/A"]:
256
+ timesignature_value = ''
257
+
258
+ gr.Info(t("messages.lm_generated"))
259
+ return caption_value, lyrics_value, think_value, bpm_value, duration_value, keyscale_value, language_value, timesignature_value
260
+ else:
261
+ gr.Warning(t("messages.lm_fallback"))
262
+ return load_random_example(task_type)
263
+
264
+ except Exception as e:
265
+ gr.Warning(t("messages.lm_fallback"))
266
+ return load_random_example(task_type)
267
+ else:
268
+ # LM not initialized, use examples directory
269
+ return load_random_example(task_type)
270
+
271
+
272
+ def refresh_checkpoints(dit_handler):
273
+ """Refresh available checkpoints"""
274
+ choices = dit_handler.get_available_checkpoints()
275
+ return gr.update(choices=choices)
276
+
277
+
278
+ def update_model_type_settings(config_path):
279
+ """Update UI settings based on model type"""
280
+ if config_path is None:
281
+ config_path = ""
282
+ config_path_lower = config_path.lower()
283
+
284
+ if "turbo" in config_path_lower:
285
+ # Turbo model: max 8 steps, hide CFG/ADG, only show text2music/repaint/cover
286
+ return (
287
+ gr.update(value=8, maximum=8, minimum=1), # inference_steps
288
+ gr.update(visible=False), # guidance_scale
289
+ gr.update(visible=False), # use_adg
290
+ gr.update(visible=False), # cfg_interval_start
291
+ gr.update(visible=False), # cfg_interval_end
292
+ gr.update(choices=TASK_TYPES_TURBO), # task_type
293
+ )
294
+ elif "base" in config_path_lower:
295
+ # Base model: max 100 steps, show CFG/ADG, show all task types
296
+ return (
297
+ gr.update(value=32, maximum=100, minimum=1), # inference_steps
298
+ gr.update(visible=True), # guidance_scale
299
+ gr.update(visible=True), # use_adg
300
+ gr.update(visible=True), # cfg_interval_start
301
+ gr.update(visible=True), # cfg_interval_end
302
+ gr.update(choices=TASK_TYPES_BASE), # task_type
303
+ )
304
+ else:
305
+ # Default to turbo settings
306
+ return (
307
+ gr.update(value=8, maximum=8, minimum=1),
308
+ gr.update(visible=False),
309
+ gr.update(visible=False),
310
+ gr.update(visible=False),
311
+ gr.update(visible=False),
312
+ gr.update(choices=TASK_TYPES_TURBO), # task_type
313
+ )
314
+
315
+
316
+ def init_service_wrapper(dit_handler, llm_handler, checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
317
+ """Wrapper for service initialization, returns status, button state, and accordion state"""
318
+ # Initialize DiT handler
319
+ status, enable = dit_handler.initialize_service(
320
+ checkpoint, config_path, device,
321
+ use_flash_attention=use_flash_attention, compile_model=False,
322
+ offload_to_cpu=offload_to_cpu, offload_dit_to_cpu=offload_dit_to_cpu
323
+ )
324
+
325
+ # Initialize LM handler if requested
326
+ if init_llm:
327
+ # Get checkpoint directory
328
+ current_file = os.path.abspath(__file__)
329
+ # This file is in acestep/gradio_ui/events/, need 4 levels up to reach project root
330
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(current_file))))
331
+ checkpoint_dir = os.path.join(project_root, "checkpoints")
332
+
333
+ lm_status, lm_success = llm_handler.initialize(
334
+ checkpoint_dir=checkpoint_dir,
335
+ lm_model_path=lm_model_path,
336
+ backend=backend,
337
+ device=device,
338
+ offload_to_cpu=offload_to_cpu,
339
+ dtype=dit_handler.dtype
340
+ )
341
+
342
+ if lm_success:
343
+ status += f"\n{lm_status}"
344
+ else:
345
+ status += f"\n{lm_status}"
346
+ # Don't fail the entire initialization if LM fails, but log it
347
+ # Keep enable as is (DiT initialization result) even if LM fails
348
+
349
+ # Check if model is initialized - if so, collapse the accordion
350
+ is_model_initialized = dit_handler.model is not None
351
+ accordion_state = gr.update(open=not is_model_initialized)
352
+
353
+ return status, gr.update(interactive=enable), accordion_state
354
+
355
+
356
+ def update_negative_prompt_visibility(init_llm_checked):
357
+ """Update negative prompt visibility: show if Initialize 5Hz LM checkbox is checked"""
358
+ return gr.update(visible=init_llm_checked)
359
+
360
+
361
+ def update_audio_cover_strength_visibility(task_type_value, init_llm_checked):
362
+ """Update audio_cover_strength visibility and label"""
363
+ # Show if task is cover OR if LM is initialized
364
+ is_visible = (task_type_value == "cover") or init_llm_checked
365
+ # Change label based on context
366
+ if init_llm_checked and task_type_value != "cover":
367
+ label = "LM codes strength"
368
+ info = "Control how many denoising steps use LM-generated codes"
369
+ else:
370
+ label = "Audio Cover Strength"
371
+ info = "Control how many denoising steps use cover mode"
372
+
373
+ return gr.update(visible=is_visible, label=label, info=info)
374
+
375
+
376
+ def convert_src_audio_to_codes_wrapper(dit_handler, src_audio):
377
+ """Wrapper for converting src audio to codes"""
378
+ codes_string = dit_handler.convert_src_audio_to_codes(src_audio)
379
+ return codes_string
380
+
381
+
382
+ def update_instruction_ui(
383
+ dit_handler,
384
+ task_type_value: str,
385
+ track_name_value: Optional[str],
386
+ complete_track_classes_value: list,
387
+ audio_codes_content: str = "",
388
+ init_llm_checked: bool = False
389
+ ) -> tuple:
390
+ """Update instruction and UI visibility based on task type."""
391
+ instruction = dit_handler.generate_instruction(
392
+ task_type=task_type_value,
393
+ track_name=track_name_value,
394
+ complete_track_classes=complete_track_classes_value
395
+ )
396
+
397
+ # Show track_name for lego and extract
398
+ track_name_visible = task_type_value in ["lego", "extract"]
399
+ # Show complete_track_classes for complete
400
+ complete_visible = task_type_value == "complete"
401
+ # Show audio_cover_strength for cover OR when LM is initialized
402
+ audio_cover_strength_visible = (task_type_value == "cover") or init_llm_checked
403
+ # Determine label and info based on context
404
+ if init_llm_checked and task_type_value != "cover":
405
+ audio_cover_strength_label = "LM codes strength"
406
+ audio_cover_strength_info = "Control how many denoising steps use LM-generated codes"
407
+ else:
408
+ audio_cover_strength_label = "Audio Cover Strength"
409
+ audio_cover_strength_info = "Control how many denoising steps use cover mode"
410
+ # Show repainting controls for repaint and lego
411
+ repainting_visible = task_type_value in ["repaint", "lego"]
412
+ # Show text2music_audio_codes if task is text2music OR if it has content
413
+ # This allows it to stay visible even if user switches task type but has codes
414
+ has_audio_codes = audio_codes_content and str(audio_codes_content).strip()
415
+ text2music_audio_codes_visible = task_type_value == "text2music" or has_audio_codes
416
+
417
+ return (
418
+ instruction, # instruction_display_gen
419
+ gr.update(visible=track_name_visible), # track_name
420
+ gr.update(visible=complete_visible), # complete_track_classes
421
+ gr.update(visible=audio_cover_strength_visible, label=audio_cover_strength_label, info=audio_cover_strength_info), # audio_cover_strength
422
+ gr.update(visible=repainting_visible), # repainting_group
423
+ gr.update(visible=text2music_audio_codes_visible), # text2music_audio_codes_group
424
+ )
425
+
426
+
427
+ def transcribe_audio_codes(llm_handler, audio_code_string, constrained_decoding_debug):
428
+ """
429
+ Transcribe audio codes to metadata using LLM understanding.
430
+ If audio_code_string is empty, generate a sample example instead.
431
+
432
+ Args:
433
+ llm_handler: LLM handler instance
434
+ audio_code_string: String containing audio codes (or empty for example generation)
435
+ constrained_decoding_debug: Whether to enable debug logging for constrained decoding
436
+
437
+ Returns:
438
+ Tuple of (status_message, caption, lyrics, bpm, duration, keyscale, language, timesignature)
439
+ """
440
+ if not llm_handler.llm_initialized:
441
+ return t("messages.lm_not_initialized"), "", "", None, None, "", "", ""
442
+
443
+ # If codes are empty, this becomes a "generate example" task
444
+ # Use "NO USER INPUT" as the input to generate a sample
445
+ if not audio_code_string or not audio_code_string.strip():
446
+ audio_code_string = "NO USER INPUT"
447
+
448
+ # Call LLM understanding
449
+ metadata, status = llm_handler.understand_audio_from_codes(
450
+ audio_codes=audio_code_string,
451
+ use_constrained_decoding=True,
452
+ constrained_decoding_debug=constrained_decoding_debug,
453
+ )
454
+
455
+ # Extract fields for UI update
456
+ caption = metadata.get('caption', '')
457
+ lyrics = metadata.get('lyrics', '')
458
+ bpm = metadata.get('bpm')
459
+ duration = metadata.get('duration')
460
+ keyscale = metadata.get('keyscale', '')
461
+ language = metadata.get('language', '')
462
+ timesignature = metadata.get('timesignature', '')
463
+
464
+ # Convert to appropriate types
465
+ try:
466
+ bpm = int(bpm) if bpm and bpm != 'N/A' else None
467
+ except:
468
+ bpm = None
469
+
470
+ try:
471
+ duration = float(duration) if duration and duration != 'N/A' else None
472
+ except:
473
+ duration = None
474
+
475
+ return (
476
+ status,
477
+ caption,
478
+ lyrics,
479
+ bpm,
480
+ duration,
481
+ keyscale,
482
+ language,
483
+ timesignature,
484
+ True # Set is_format_caption to True (from Transcribe/LM understanding)
485
+ )
486
+
487
+
488
+ def update_transcribe_button_text(audio_code_string):
489
+ """
490
+ Update the transcribe button text based on input content.
491
+ If empty: "Generate Example"
492
+ If has content: "Transcribe"
493
+ """
494
+ if not audio_code_string or not audio_code_string.strip():
495
+ return gr.update(value="Generate Example")
496
+ else:
497
+ return gr.update(value="Transcribe")
498
+
499
+
500
+ def reset_format_caption_flag():
501
+ """Reset is_format_caption to False when user manually edits caption/metadata"""
502
+ return False
503
+
504
+
505
+ def update_audio_uploads_accordion(reference_audio, src_audio):
506
+ """Update Audio Uploads accordion open state based on whether audio files are present"""
507
+ has_audio = (reference_audio is not None) or (src_audio is not None)
508
+ return gr.update(open=has_audio)
509
+
510
+
511
+ def handle_instrumental_checkbox(instrumental_checked, current_lyrics):
512
+ """
513
+ Handle instrumental checkbox changes.
514
+ When checked: if no lyrics, fill with [Instrumental]
515
+ When unchecked: if lyrics is [Instrumental], clear it
516
+ """
517
+ if instrumental_checked:
518
+ # If checked and no lyrics, fill with [Instrumental]
519
+ if not current_lyrics or not current_lyrics.strip():
520
+ return "[Instrumental]"
521
+ else:
522
+ # Has lyrics, don't change
523
+ return current_lyrics
524
+ else:
525
+ # If unchecked and lyrics is exactly [Instrumental], clear it
526
+ if current_lyrics and current_lyrics.strip() == "[Instrumental]":
527
+ return ""
528
+ else:
529
+ # Has other lyrics, don't change
530
+ return current_lyrics
531
+
532
+
533
+ def update_audio_components_visibility(batch_size):
534
+ """Show/hide individual audio components based on batch size (1-8)
535
+
536
+ Row 1: Components 1-4 (batch_size 1-4)
537
+ Row 2: Components 5-8 (batch_size 5-8)
538
+ """
539
+ # Clamp batch size to 1-8 range for UI
540
+ batch_size = min(max(int(batch_size), 1), 8)
541
+
542
+ # Row 1 columns (1-4)
543
+ updates_row1 = (
544
+ gr.update(visible=True), # audio_col_1: always visible
545
+ gr.update(visible=batch_size >= 2), # audio_col_2
546
+ gr.update(visible=batch_size >= 3), # audio_col_3
547
+ gr.update(visible=batch_size >= 4), # audio_col_4
548
+ )
549
+
550
+ # Row 2 container and columns (5-8)
551
+ show_row_5_8 = batch_size >= 5
552
+ updates_row2 = (
553
+ gr.update(visible=show_row_5_8), # audio_row_5_8 (container)
554
+ gr.update(visible=batch_size >= 5), # audio_col_5
555
+ gr.update(visible=batch_size >= 6), # audio_col_6
556
+ gr.update(visible=batch_size >= 7), # audio_col_7
557
+ gr.update(visible=batch_size >= 8), # audio_col_8
558
+ )
559
+
560
+ return updates_row1 + updates_row2
561
+
562
+
563
+ def update_codes_hints_visibility(src_audio, allow_lm_batch, batch_size):
564
+ """Switch between single/batch codes input based on src_audio presence
565
+
566
+ When src_audio is present:
567
+ - Show single mode with transcribe button
568
+ - Clear codes (will be filled by transcription)
569
+
570
+ When src_audio is absent:
571
+ - Hide transcribe button
572
+ - Show batch mode if allow_lm_batch=True and batch_size>=2
573
+ - Show single mode otherwise
574
+
575
+ Row 1: Codes 1-4
576
+ Row 2: Codes 5-8 (batch_size >= 5)
577
+ """
578
+ batch_size = min(max(int(batch_size), 1), 8)
579
+ has_src_audio = src_audio is not None
580
+
581
+ if has_src_audio:
582
+ # Has src_audio: show single mode with transcribe button
583
+ return (
584
+ gr.update(visible=True), # codes_single_row
585
+ gr.update(visible=False), # codes_batch_row
586
+ gr.update(visible=False), # codes_batch_row_2
587
+ *[gr.update(visible=False)] * 8, # Hide all batch columns
588
+ gr.update(visible=True), # transcribe_btn: show when src_audio present
589
+ )
590
+ else:
591
+ # No src_audio: decide between single/batch mode based on settings
592
+ if allow_lm_batch and batch_size >= 2:
593
+ # Batch mode: hide single, show batch codes with dynamic columns
594
+ show_row_2 = batch_size >= 5
595
+ return (
596
+ gr.update(visible=False), # codes_single_row
597
+ gr.update(visible=True), # codes_batch_row (row 1)
598
+ gr.update(visible=show_row_2), # codes_batch_row_2 (row 2)
599
+ # Row 1 columns (1-4)
600
+ gr.update(visible=True), # codes_col_1: always visible in batch mode
601
+ gr.update(visible=batch_size >= 2), # codes_col_2
602
+ gr.update(visible=batch_size >= 3), # codes_col_3
603
+ gr.update(visible=batch_size >= 4), # codes_col_4
604
+ # Row 2 columns (5-8)
605
+ gr.update(visible=batch_size >= 5), # codes_col_5
606
+ gr.update(visible=batch_size >= 6), # codes_col_6
607
+ gr.update(visible=batch_size >= 7), # codes_col_7
608
+ gr.update(visible=batch_size >= 8), # codes_col_8
609
+ gr.update(visible=False), # transcribe_btn: hide when no src_audio
610
+ )
611
+ else:
612
+ # Single mode: show single, hide batch
613
+ return (
614
+ gr.update(visible=True), # codes_single_row
615
+ gr.update(visible=False), # codes_batch_row
616
+ gr.update(visible=False), # codes_batch_row_2
617
+ *[gr.update(visible=False)] * 8, # Hide all batch columns
618
+ gr.update(visible=False), # transcribe_btn: hide when no src_audio
619
+ )
acestep/gradio_ui/events/results_handlers.py ADDED
@@ -0,0 +1,1381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Results Handlers Module
3
+ Contains event handlers and helper functions related to result display, scoring, and batch management
4
+ """
5
+ import os
6
+ import json
7
+ import datetime
8
+ import tempfile
9
+ import shutil
10
+ import zipfile
11
+ import time as time_module
12
+ import gradio as gr
13
+ from loguru import logger
14
+ from acestep.gradio_ui.i18n import t
15
+
16
+
17
+ def store_batch_in_queue(
18
+ batch_queue,
19
+ batch_index,
20
+ audio_paths,
21
+ generation_info,
22
+ seeds,
23
+ codes=None,
24
+ scores=None,
25
+ allow_lm_batch=False,
26
+ batch_size=2,
27
+ generation_params=None,
28
+ lm_generated_metadata=None,
29
+ status="completed"
30
+ ):
31
+ """Store batch results in queue with ALL generation parameters
32
+
33
+ Args:
34
+ codes: Audio codes used for generation (list for batch mode, string for single mode)
35
+ scores: List of score displays for each audio (optional)
36
+ allow_lm_batch: Whether batch LM mode was used for this batch
37
+ batch_size: Batch size used for this batch
38
+ generation_params: Complete dictionary of ALL generation parameters used
39
+ lm_generated_metadata: LM-generated metadata for scoring (optional)
40
+ """
41
+ batch_queue[batch_index] = {
42
+ "status": status,
43
+ "audio_paths": audio_paths,
44
+ "generation_info": generation_info,
45
+ "seeds": seeds,
46
+ "codes": codes, # Store codes used for this batch
47
+ "scores": scores if scores else [""] * 8, # Store scores, default to empty
48
+ "allow_lm_batch": allow_lm_batch, # Store batch mode setting
49
+ "batch_size": batch_size, # Store batch size
50
+ "generation_params": generation_params if generation_params else {}, # Store ALL parameters
51
+ "lm_generated_metadata": lm_generated_metadata, # Store LM metadata for scoring
52
+ "timestamp": datetime.datetime.now().isoformat()
53
+ }
54
+ return batch_queue
55
+
56
+
57
+ def update_batch_indicator(current_batch, total_batches):
58
+ """Update batch indicator text"""
59
+ return t("results.batch_indicator", current=current_batch + 1, total=total_batches)
60
+
61
+
62
+ def update_navigation_buttons(current_batch, total_batches):
63
+ """Determine navigation button states"""
64
+ can_go_previous = current_batch > 0
65
+ can_go_next = current_batch < total_batches - 1
66
+ return can_go_previous, can_go_next
67
+
68
+
69
+ def save_audio_and_metadata(
70
+ audio_path, task_type, captions, lyrics, vocal_language, bpm, key_scale, time_signature, audio_duration,
71
+ batch_size_input, inference_steps, guidance_scale, seed, random_seed_checkbox,
72
+ use_adg, cfg_interval_start, cfg_interval_end, audio_format,
73
+ lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
74
+ use_cot_caption, use_cot_language, audio_cover_strength,
75
+ think_checkbox, text2music_audio_code_string, repainting_start, repainting_end,
76
+ track_name, complete_track_classes, lm_metadata
77
+ ):
78
+ """Save audio file and its metadata as a zip package"""
79
+ if audio_path is None:
80
+ gr.Warning(t("messages.no_audio_to_save"))
81
+ return None
82
+
83
+ try:
84
+ # Create metadata dictionary
85
+ metadata = {
86
+ "saved_at": datetime.datetime.now().isoformat(),
87
+ "task_type": task_type,
88
+ "caption": captions or "",
89
+ "lyrics": lyrics or "",
90
+ "vocal_language": vocal_language,
91
+ "bpm": bpm if bpm is not None else None,
92
+ "keyscale": key_scale or "",
93
+ "timesignature": time_signature or "",
94
+ "duration": audio_duration if audio_duration is not None else -1,
95
+ "batch_size": batch_size_input,
96
+ "inference_steps": inference_steps,
97
+ "guidance_scale": guidance_scale,
98
+ "seed": seed,
99
+ "random_seed": False, # Disable random seed for reproducibility
100
+ "use_adg": use_adg,
101
+ "cfg_interval_start": cfg_interval_start,
102
+ "cfg_interval_end": cfg_interval_end,
103
+ "audio_format": audio_format,
104
+ "lm_temperature": lm_temperature,
105
+ "lm_cfg_scale": lm_cfg_scale,
106
+ "lm_top_k": lm_top_k,
107
+ "lm_top_p": lm_top_p,
108
+ "lm_negative_prompt": lm_negative_prompt,
109
+ "use_cot_caption": use_cot_caption,
110
+ "use_cot_language": use_cot_language,
111
+ "audio_cover_strength": audio_cover_strength,
112
+ "think": think_checkbox,
113
+ "audio_codes": text2music_audio_code_string or "",
114
+ "repainting_start": repainting_start,
115
+ "repainting_end": repainting_end,
116
+ "track_name": track_name,
117
+ "complete_track_classes": complete_track_classes or [],
118
+ }
119
+
120
+ # Add LM-generated metadata if available
121
+ if lm_metadata:
122
+ metadata["lm_generated_metadata"] = lm_metadata
123
+
124
+ # Generate timestamp and base name
125
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
126
+
127
+ # Extract audio filename extension
128
+ audio_ext = os.path.splitext(audio_path)[1]
129
+
130
+ # Create temporary directory for packaging
131
+ temp_dir = tempfile.mkdtemp()
132
+
133
+ # Save JSON metadata
134
+ json_path = os.path.join(temp_dir, f"metadata_{timestamp}.json")
135
+ with open(json_path, 'w', encoding='utf-8') as f:
136
+ json.dump(metadata, f, indent=2, ensure_ascii=False)
137
+
138
+ # Copy audio file
139
+ audio_copy_path = os.path.join(temp_dir, f"audio_{timestamp}{audio_ext}")
140
+ shutil.copy2(audio_path, audio_copy_path)
141
+
142
+ # Create zip file
143
+ zip_path = os.path.join(tempfile.gettempdir(), f"music_package_{timestamp}.zip")
144
+ with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
145
+ zipf.write(audio_copy_path, os.path.basename(audio_copy_path))
146
+ zipf.write(json_path, os.path.basename(json_path))
147
+
148
+ # Clean up temp directory
149
+ shutil.rmtree(temp_dir)
150
+
151
+ gr.Info(t("messages.save_success", filename=os.path.basename(zip_path)))
152
+ return zip_path
153
+
154
+ except Exception as e:
155
+ gr.Warning(t("messages.save_failed", error=str(e)))
156
+ import traceback
157
+ traceback.print_exc()
158
+ return None
159
+
160
+
161
+ def send_audio_to_src_with_metadata(audio_file, lm_metadata):
162
+ """Send generated audio file to src_audio input and populate metadata fields
163
+
164
+ Args:
165
+ audio_file: Audio file path
166
+ lm_metadata: Dictionary containing LM-generated metadata
167
+
168
+ Returns:
169
+ Tuple of (audio_file, bpm, caption, lyrics, duration, key_scale, language, time_signature, is_format_caption)
170
+ """
171
+ if audio_file is None:
172
+ return None, None, None, None, None, None, None, None, True # Keep is_format_caption as True
173
+
174
+ # Extract metadata fields if available
175
+ bpm_value = None
176
+ caption_value = None
177
+ lyrics_value = None
178
+ duration_value = None
179
+ key_scale_value = None
180
+ language_value = None
181
+ time_signature_value = None
182
+
183
+ if lm_metadata:
184
+ # BPM
185
+ if lm_metadata.get('bpm'):
186
+ bpm_str = lm_metadata.get('bpm')
187
+ if bpm_str and bpm_str != "N/A":
188
+ try:
189
+ bpm_value = int(bpm_str)
190
+ except (ValueError, TypeError):
191
+ pass
192
+
193
+ # Caption (Rewritten Caption)
194
+ if lm_metadata.get('caption'):
195
+ caption_value = lm_metadata.get('caption')
196
+
197
+ # Lyrics
198
+ if lm_metadata.get('lyrics'):
199
+ lyrics_value = lm_metadata.get('lyrics')
200
+
201
+ # Duration
202
+ if lm_metadata.get('duration'):
203
+ duration_str = lm_metadata.get('duration')
204
+ if duration_str and duration_str != "N/A":
205
+ try:
206
+ duration_value = float(duration_str)
207
+ except (ValueError, TypeError):
208
+ pass
209
+
210
+ # KeyScale
211
+ if lm_metadata.get('keyscale'):
212
+ key_scale_str = lm_metadata.get('keyscale')
213
+ if key_scale_str and key_scale_str != "N/A":
214
+ key_scale_value = key_scale_str
215
+
216
+ # Language
217
+ if lm_metadata.get('language'):
218
+ language_str = lm_metadata.get('language')
219
+ if language_str and language_str != "N/A":
220
+ language_value = language_str
221
+
222
+ # Time Signature
223
+ if lm_metadata.get('timesignature'):
224
+ time_sig_str = lm_metadata.get('timesignature')
225
+ if time_sig_str and time_sig_str != "N/A":
226
+ time_signature_value = time_sig_str
227
+
228
+ return (
229
+ audio_file,
230
+ bpm_value,
231
+ caption_value,
232
+ lyrics_value,
233
+ duration_value,
234
+ key_scale_value,
235
+ language_value,
236
+ time_signature_value,
237
+ True # Set is_format_caption to True (from LM-generated metadata)
238
+ )
239
+
240
+
241
+ def generate_with_progress(
242
+ dit_handler, llm_handler,
243
+ captions, lyrics, bpm, key_scale, time_signature, vocal_language,
244
+ inference_steps, guidance_scale, random_seed_checkbox, seed,
245
+ reference_audio, audio_duration, batch_size_input, src_audio,
246
+ text2music_audio_code_string, repainting_start, repainting_end,
247
+ instruction_display_gen, audio_cover_strength, task_type,
248
+ use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
249
+ think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
250
+ use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
251
+ constrained_decoding_debug,
252
+ allow_lm_batch,
253
+ auto_score,
254
+ score_scale,
255
+ lm_batch_chunk_size,
256
+ progress=gr.Progress(track_tqdm=True)
257
+ ):
258
+ """Generate audio with progress tracking"""
259
+ # If think is enabled (llm_dit mode) and use_cot_metas is True, generate audio codes using LM first
260
+ audio_code_string_to_use = text2music_audio_code_string
261
+ lm_generated_metadata = None # Store LM-generated metadata for display
262
+ lm_generated_audio_codes = None # Store LM-generated audio codes for display
263
+ lm_generated_audio_codes_list = [] # Store list of audio codes for batch processing
264
+
265
+ # Determine if we should use batch LM generation
266
+ should_use_lm_batch = (
267
+ think_checkbox and
268
+ llm_handler.llm_initialized and
269
+ use_cot_metas and
270
+ allow_lm_batch and
271
+ batch_size_input >= 2
272
+ )
273
+
274
+ if think_checkbox and llm_handler.llm_initialized and use_cot_metas:
275
+ # Convert top_k: 0 means None (disabled)
276
+ top_k_value = None if lm_top_k == 0 else int(lm_top_k)
277
+ # Convert top_p: 1.0 means None (disabled)
278
+ top_p_value = None if lm_top_p >= 1.0 else lm_top_p
279
+
280
+ # Build user_metadata from user-provided values (only include non-empty values)
281
+ user_metadata = {}
282
+ # Handle bpm: gr.Number can be None, int, float, or string
283
+ if bpm is not None:
284
+ try:
285
+ bpm_value = float(bpm)
286
+ if bpm_value > 0:
287
+ user_metadata['bpm'] = str(int(bpm_value))
288
+ except (ValueError, TypeError):
289
+ # If bpm is not a valid number, skip it
290
+ pass
291
+ if key_scale and key_scale.strip():
292
+ key_scale_clean = key_scale.strip()
293
+ if key_scale_clean.lower() not in ["n/a", ""]:
294
+ user_metadata['keyscale'] = key_scale_clean
295
+ if time_signature and time_signature.strip():
296
+ time_sig_clean = time_signature.strip()
297
+ if time_sig_clean.lower() not in ["n/a", ""]:
298
+ user_metadata['timesignature'] = time_sig_clean
299
+ if audio_duration is not None:
300
+ try:
301
+ duration_value = float(audio_duration)
302
+ if duration_value > 0:
303
+ user_metadata['duration'] = str(int(duration_value))
304
+ except (ValueError, TypeError):
305
+ # If audio_duration is not a valid number, skip it
306
+ pass
307
+
308
+ # Only pass user_metadata if user provided any values, otherwise let LM generate
309
+ user_metadata_to_pass = user_metadata if user_metadata else None
310
+
311
+ if should_use_lm_batch:
312
+ # BATCH LM GENERATION
313
+ import math
314
+ from acestep.handler import AceStepHandler
315
+
316
+ logger.info(f"Using LM batch generation for {batch_size_input} items...")
317
+
318
+ # Prepare seeds for batch items
319
+ temp_handler = AceStepHandler()
320
+ actual_seed_list, _ = temp_handler.prepare_seeds(batch_size_input, seed, random_seed_checkbox)
321
+
322
+ # Split batch into chunks (GPU memory constraint)
323
+ max_inference_batch_size = int(lm_batch_chunk_size)
324
+ num_chunks = math.ceil(batch_size_input / max_inference_batch_size)
325
+
326
+ all_metadata_list = []
327
+ all_audio_codes_list = []
328
+
329
+ for chunk_idx in range(num_chunks):
330
+ chunk_start = chunk_idx * max_inference_batch_size
331
+ chunk_end = min(chunk_start + max_inference_batch_size, batch_size_input)
332
+ chunk_size = chunk_end - chunk_start
333
+ chunk_seeds = actual_seed_list[chunk_start:chunk_end]
334
+
335
+ logger.info(f"Generating LM batch chunk {chunk_idx+1}/{num_chunks} (size: {chunk_size}, seeds: {chunk_seeds})...")
336
+
337
+ # Generate batch
338
+ metadata_list, audio_codes_list, status = llm_handler.generate_with_stop_condition_batch(
339
+ caption=captions or "",
340
+ lyrics=lyrics or "",
341
+ batch_size=chunk_size,
342
+ infer_type="llm_dit",
343
+ temperature=lm_temperature,
344
+ cfg_scale=lm_cfg_scale,
345
+ negative_prompt=lm_negative_prompt,
346
+ top_k=top_k_value,
347
+ top_p=top_p_value,
348
+ user_metadata=user_metadata_to_pass,
349
+ use_cot_caption=use_cot_caption,
350
+ use_cot_language=use_cot_language,
351
+ is_format_caption=is_format_caption,
352
+ constrained_decoding_debug=constrained_decoding_debug,
353
+ seeds=chunk_seeds,
354
+ )
355
+
356
+ all_metadata_list.extend(metadata_list)
357
+ all_audio_codes_list.extend(audio_codes_list)
358
+
359
+ # Use first metadata as representative (all are same)
360
+ lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
361
+
362
+ # Store audio codes list for later use
363
+ lm_generated_audio_codes_list = all_audio_codes_list
364
+
365
+ # Prepare audio codes for DiT (list of codes, one per batch item)
366
+ audio_code_string_to_use = all_audio_codes_list
367
+
368
+ # Update metadata fields from LM if not provided by user
369
+ if lm_generated_metadata:
370
+ if bpm is None and lm_generated_metadata.get('bpm'):
371
+ bpm_value = lm_generated_metadata.get('bpm')
372
+ if bpm_value != "N/A" and bpm_value != "":
373
+ try:
374
+ bpm = int(bpm_value)
375
+ except:
376
+ pass
377
+ if not key_scale and lm_generated_metadata.get('keyscale'):
378
+ key_scale_value = lm_generated_metadata.get('keyscale', lm_generated_metadata.get('key_scale', ""))
379
+ if key_scale_value != "N/A":
380
+ key_scale = key_scale_value
381
+ if not time_signature and lm_generated_metadata.get('timesignature'):
382
+ time_signature_value = lm_generated_metadata.get('timesignature', lm_generated_metadata.get('time_signature', ""))
383
+ if time_signature_value != "N/A":
384
+ time_signature = time_signature_value
385
+ if audio_duration is None or audio_duration <= 0:
386
+ audio_duration_value = lm_generated_metadata.get('duration', -1)
387
+ if audio_duration_value != "N/A" and audio_duration_value != "":
388
+ try:
389
+ audio_duration = float(audio_duration_value)
390
+ except:
391
+ pass
392
+ else:
393
+ # SEQUENTIAL LM GENERATION (current behavior, when allow_lm_batch is False)
394
+ # Phase 1: Generate CoT metadata
395
+ phase1_start = time_module.time()
396
+ metadata, _, status = llm_handler.generate_with_stop_condition(
397
+ caption=captions or "",
398
+ lyrics=lyrics or "",
399
+ infer_type="dit", # Only generate metadata in Phase 1
400
+ temperature=lm_temperature,
401
+ cfg_scale=lm_cfg_scale,
402
+ negative_prompt=lm_negative_prompt,
403
+ top_k=top_k_value,
404
+ top_p=top_p_value,
405
+ user_metadata=user_metadata_to_pass,
406
+ use_cot_caption=use_cot_caption,
407
+ use_cot_language=use_cot_language,
408
+ is_format_caption=is_format_caption,
409
+ constrained_decoding_debug=constrained_decoding_debug,
410
+ )
411
+ lm_phase1_time = time_module.time() - phase1_start
412
+ logger.info(f"LM Phase 1 (CoT) completed in {lm_phase1_time:.2f}s")
413
+
414
+ # Phase 2: Generate audio codes
415
+ phase2_start = time_module.time()
416
+ metadata, audio_codes, status = llm_handler.generate_with_stop_condition(
417
+ caption=captions or "",
418
+ lyrics=lyrics or "",
419
+ infer_type="llm_dit", # Generate both metadata and codes
420
+ temperature=lm_temperature,
421
+ cfg_scale=lm_cfg_scale,
422
+ negative_prompt=lm_negative_prompt,
423
+ top_k=top_k_value,
424
+ top_p=top_p_value,
425
+ user_metadata=user_metadata_to_pass,
426
+ use_cot_caption=use_cot_caption,
427
+ use_cot_language=use_cot_language,
428
+ is_format_caption=is_format_caption,
429
+ constrained_decoding_debug=constrained_decoding_debug,
430
+ )
431
+ lm_phase2_time = time_module.time() - phase2_start
432
+ logger.info(f"LM Phase 2 (Codes) completed in {lm_phase2_time:.2f}s")
433
+
434
+ # Store LM-generated metadata and audio codes for display
435
+ lm_generated_metadata = metadata
436
+ if audio_codes:
437
+ audio_code_string_to_use = audio_codes
438
+ lm_generated_audio_codes = audio_codes
439
+ # Update metadata fields only if they are empty/None (user didn't provide them)
440
+ if bpm is None and metadata.get('bpm'):
441
+ bpm_value = metadata.get('bpm')
442
+ if bpm_value != "N/A" and bpm_value != "":
443
+ try:
444
+ bpm = int(bpm_value)
445
+ except:
446
+ pass
447
+ if not key_scale and metadata.get('keyscale'):
448
+ key_scale_value = metadata.get('keyscale', metadata.get('key_scale', ""))
449
+ if key_scale_value != "N/A":
450
+ key_scale = key_scale_value
451
+ if not time_signature and metadata.get('timesignature'):
452
+ time_signature_value = metadata.get('timesignature', metadata.get('time_signature', ""))
453
+ if time_signature_value != "N/A":
454
+ time_signature = time_signature_value
455
+ if audio_duration is None or audio_duration <= 0:
456
+ audio_duration_value = metadata.get('duration', -1)
457
+ if audio_duration_value != "N/A" and audio_duration_value != "":
458
+ try:
459
+ audio_duration = float(audio_duration_value)
460
+ except:
461
+ pass
462
+
463
+ # Call generate_music and get results
464
+ result = dit_handler.generate_music(
465
+ captions=captions, lyrics=lyrics, bpm=bpm, key_scale=key_scale,
466
+ time_signature=time_signature, vocal_language=vocal_language,
467
+ inference_steps=inference_steps, guidance_scale=guidance_scale,
468
+ use_random_seed=random_seed_checkbox, seed=seed,
469
+ reference_audio=reference_audio, audio_duration=audio_duration,
470
+ batch_size=batch_size_input, src_audio=src_audio,
471
+ audio_code_string=audio_code_string_to_use,
472
+ repainting_start=repainting_start, repainting_end=repainting_end,
473
+ instruction=instruction_display_gen, audio_cover_strength=audio_cover_strength,
474
+ task_type=task_type, use_adg=use_adg,
475
+ cfg_interval_start=cfg_interval_start, cfg_interval_end=cfg_interval_end,
476
+ audio_format=audio_format, lm_temperature=lm_temperature,
477
+ progress=progress
478
+ )
479
+
480
+ # Extract results
481
+ first_audio, second_audio, all_audio_paths, generation_info, status_message, seed_value_for_ui, \
482
+ align_score_1, align_text_1, align_plot_1, align_score_2, align_text_2, align_plot_2 = result
483
+
484
+ # Extract LM timing from status if available and prepend to generation_info
485
+ if status:
486
+ import re
487
+ # Try to extract timing info from status using regex
488
+ # Expected format: "Phase1: X.XXs" and "Phase2: X.XXs"
489
+ phase1_match = re.search(r'Phase1:\s*([\d.]+)s', status)
490
+ phase2_match = re.search(r'Phase2:\s*([\d.]+)s', status)
491
+
492
+ if phase1_match or phase2_match:
493
+ lm_timing_section = "\n\n**🤖 LM Timing:**\n"
494
+ lm_total = 0.0
495
+ if phase1_match:
496
+ phase1_time = float(phase1_match.group(1))
497
+ lm_timing_section += f" - Phase 1 (CoT Metadata): {phase1_time:.2f}s\n"
498
+ lm_total += phase1_time
499
+ if phase2_match:
500
+ phase2_time = float(phase2_match.group(1))
501
+ lm_timing_section += f" - Phase 2 (Audio Codes): {phase2_time:.2f}s\n"
502
+ lm_total += phase2_time
503
+ if lm_total > 0:
504
+ lm_timing_section += f" - Total LM Time: {lm_total:.2f}s\n"
505
+ generation_info = lm_timing_section + "\n" + generation_info
506
+
507
+ # Append LM-generated metadata to generation_info if available
508
+ if lm_generated_metadata:
509
+ metadata_lines = []
510
+ if lm_generated_metadata.get('bpm'):
511
+ metadata_lines.append(f"- **BPM:** {lm_generated_metadata['bpm']}")
512
+ if lm_generated_metadata.get('caption'):
513
+ metadata_lines.append(f"- **User Query Rewritten Caption:** {lm_generated_metadata['caption']}")
514
+ if lm_generated_metadata.get('duration'):
515
+ metadata_lines.append(f"- **Duration:** {lm_generated_metadata['duration']} seconds")
516
+ if lm_generated_metadata.get('keyscale'):
517
+ metadata_lines.append(f"- **KeyScale:** {lm_generated_metadata['keyscale']}")
518
+ if lm_generated_metadata.get('language'):
519
+ metadata_lines.append(f"- **Language:** {lm_generated_metadata['language']}")
520
+ if lm_generated_metadata.get('timesignature'):
521
+ metadata_lines.append(f"- **Time Signature:** {lm_generated_metadata['timesignature']}")
522
+
523
+ if metadata_lines:
524
+ metadata_section = "\n\n**🤖 LM-Generated Metadata:**\n" + "\n\n".join(metadata_lines)
525
+ generation_info = metadata_section + "\n\n" + generation_info
526
+
527
+ # Update audio codes in UI if LM generated them
528
+ codes_outputs = [""] * 8 # Codes for 8 components
529
+ if should_use_lm_batch and lm_generated_audio_codes_list:
530
+ # Batch mode: update individual codes inputs
531
+ for idx in range(min(len(lm_generated_audio_codes_list), 8)):
532
+ codes_outputs[idx] = lm_generated_audio_codes_list[idx]
533
+ # For single codes input, show first one
534
+ updated_audio_codes = lm_generated_audio_codes_list[0] if lm_generated_audio_codes_list else text2music_audio_code_string
535
+ else:
536
+ # Single mode: update main codes input
537
+ updated_audio_codes = lm_generated_audio_codes if lm_generated_audio_codes else text2music_audio_code_string
538
+
539
+ # AUTO-SCORING
540
+ score_displays = [""] * 8 # Scores for 8 components
541
+ if auto_score and all_audio_paths:
542
+ logger.info(f"Auto-scoring enabled, calculating quality scores for {batch_size_input} generated audios...")
543
+
544
+ # Determine which audio codes to use for scoring
545
+ if should_use_lm_batch and lm_generated_audio_codes_list:
546
+ codes_list = lm_generated_audio_codes_list
547
+ elif audio_code_string_to_use and isinstance(audio_code_string_to_use, list):
548
+ codes_list = audio_code_string_to_use
549
+ else:
550
+ # Single code string, replicate for all audios
551
+ codes_list = [audio_code_string_to_use] * len(all_audio_paths)
552
+
553
+ # Calculate scores only for actually generated audios (up to batch_size_input)
554
+ # Don't score beyond the actual batch size to avoid duplicates
555
+ actual_audios_to_score = min(len(all_audio_paths), int(batch_size_input))
556
+ for idx in range(actual_audios_to_score):
557
+ if idx < len(codes_list) and codes_list[idx]:
558
+ try:
559
+ score_display = calculate_score_handler(
560
+ llm_handler,
561
+ codes_list[idx],
562
+ captions,
563
+ lyrics,
564
+ lm_generated_metadata,
565
+ bpm, key_scale, time_signature, audio_duration, vocal_language,
566
+ score_scale
567
+ )
568
+ score_displays[idx] = score_display
569
+ logger.info(f"Auto-scored audio {idx+1}")
570
+ except Exception as e:
571
+ logger.error(f"Auto-scoring failed for audio {idx+1}: {e}")
572
+ score_displays[idx] = f"❌ Auto-scoring failed: {str(e)}"
573
+
574
+ # Prepare audio outputs (up to 8)
575
+ audio_outputs = [None] * 8
576
+ for idx in range(min(len(all_audio_paths), 8)):
577
+ audio_outputs[idx] = all_audio_paths[idx]
578
+
579
+ return (
580
+ audio_outputs[0], # generated_audio_1
581
+ audio_outputs[1], # generated_audio_2
582
+ audio_outputs[2], # generated_audio_3
583
+ audio_outputs[3], # generated_audio_4
584
+ audio_outputs[4], # generated_audio_5
585
+ audio_outputs[5], # generated_audio_6
586
+ audio_outputs[6], # generated_audio_7
587
+ audio_outputs[7], # generated_audio_8
588
+ all_audio_paths, # generated_audio_batch
589
+ generation_info,
590
+ status_message,
591
+ seed_value_for_ui,
592
+ align_score_1,
593
+ align_text_1,
594
+ align_plot_1,
595
+ align_score_2,
596
+ align_text_2,
597
+ align_plot_2,
598
+ score_displays[0], # score_display_1
599
+ score_displays[1], # score_display_2
600
+ score_displays[2], # score_display_3
601
+ score_displays[3], # score_display_4
602
+ score_displays[4], # score_display_5
603
+ score_displays[5], # score_display_6
604
+ score_displays[6], # score_display_7
605
+ score_displays[7], # score_display_8
606
+ updated_audio_codes, # Update main audio codes in UI
607
+ codes_outputs[0], # text2music_audio_code_string_1
608
+ codes_outputs[1], # text2music_audio_code_string_2
609
+ codes_outputs[2], # text2music_audio_code_string_3
610
+ codes_outputs[3], # text2music_audio_code_string_4
611
+ codes_outputs[4], # text2music_audio_code_string_5
612
+ codes_outputs[5], # text2music_audio_code_string_6
613
+ codes_outputs[6], # text2music_audio_code_string_7
614
+ codes_outputs[7], # text2music_audio_code_string_8
615
+ lm_generated_metadata, # Store metadata for "Send to src audio" buttons
616
+ is_format_caption, # Keep is_format_caption unchanged
617
+ )
618
+
619
+
620
+ def calculate_score_handler(llm_handler, audio_codes_str, caption, lyrics, lm_metadata, bpm, key_scale, time_signature, audio_duration, vocal_language, score_scale):
621
+ """
622
+ Calculate PMI-based quality score for generated audio.
623
+
624
+ PMI (Pointwise Mutual Information) removes condition bias:
625
+ score = log P(condition|codes) - log P(condition)
626
+
627
+ Args:
628
+ llm_handler: LLM handler instance
629
+ audio_codes_str: Generated audio codes string
630
+ caption: Caption text used for generation
631
+ lyrics: Lyrics text used for generation
632
+ lm_metadata: LM-generated metadata dictionary (from CoT generation)
633
+ bpm: BPM value
634
+ key_scale: Key scale value
635
+ time_signature: Time signature value
636
+ audio_duration: Audio duration value
637
+ vocal_language: Vocal language value
638
+ score_scale: Sensitivity scale parameter
639
+
640
+ Returns:
641
+ Score display string
642
+ """
643
+ from acestep.test_time_scaling import calculate_pmi_score_per_condition
644
+
645
+ if not llm_handler.llm_initialized:
646
+ return t("messages.lm_not_initialized")
647
+
648
+ if not audio_codes_str or not audio_codes_str.strip():
649
+ return t("messages.no_codes")
650
+
651
+ try:
652
+ # Build metadata dictionary from both LM metadata and user inputs
653
+ metadata = {}
654
+
655
+ # Priority 1: Use LM-generated metadata if available
656
+ if lm_metadata and isinstance(lm_metadata, dict):
657
+ metadata.update(lm_metadata)
658
+
659
+ # Priority 2: Add user-provided metadata (if not already in LM metadata)
660
+ if bpm is not None and 'bpm' not in metadata:
661
+ try:
662
+ metadata['bpm'] = int(bpm)
663
+ except:
664
+ pass
665
+
666
+ if caption and 'caption' not in metadata:
667
+ metadata['caption'] = caption
668
+
669
+ if audio_duration is not None and audio_duration > 0 and 'duration' not in metadata:
670
+ try:
671
+ metadata['duration'] = int(audio_duration)
672
+ except:
673
+ pass
674
+
675
+ if key_scale and key_scale.strip() and 'keyscale' not in metadata:
676
+ metadata['keyscale'] = key_scale.strip()
677
+
678
+ if vocal_language and vocal_language.strip() and 'language' not in metadata:
679
+ metadata['language'] = vocal_language.strip()
680
+
681
+ if time_signature and time_signature.strip() and 'timesignature' not in metadata:
682
+ metadata['timesignature'] = time_signature.strip()
683
+
684
+ # Calculate per-condition scores with appropriate metrics
685
+ # - Metadata fields (bpm, duration, etc.): Top-k recall
686
+ # - Caption and lyrics: PMI (normalized)
687
+ scores_per_condition, global_score, status = calculate_pmi_score_per_condition(
688
+ llm_handler=llm_handler,
689
+ audio_codes=audio_codes_str,
690
+ caption=caption or "",
691
+ lyrics=lyrics or "",
692
+ metadata=metadata if metadata else None,
693
+ temperature=1.0,
694
+ topk=10,
695
+ score_scale=score_scale
696
+ )
697
+
698
+ # Format display string with per-condition breakdown
699
+ if global_score == 0.0 and not scores_per_condition:
700
+ return t("messages.score_failed", error=status)
701
+ else:
702
+ # Build per-condition scores display
703
+ condition_lines = []
704
+ for condition_name, score_value in sorted(scores_per_condition.items()):
705
+ condition_lines.append(
706
+ f" • {condition_name}: {score_value:.4f}"
707
+ )
708
+
709
+ conditions_display = "\n".join(condition_lines) if condition_lines else " (no conditions)"
710
+
711
+ return (
712
+ f"✅ Global Quality Score: {global_score:.4f} (0-1, higher=better)\n\n"
713
+ f"📊 Per-Condition Scores (0-1):\n{conditions_display}\n\n"
714
+ f"Note: Metadata uses Top-k Recall, Caption/Lyrics use PMI\n"
715
+ )
716
+
717
+ except Exception as e:
718
+ import traceback
719
+ error_msg = t("messages.score_error", error=str(e)) + f"\n{traceback.format_exc()}"
720
+ return error_msg
721
+
722
+
723
+ def calculate_score_handler_with_selection(llm_handler, sample_idx, score_scale, current_batch_index, batch_queue):
724
+ """
725
+ Calculate PMI-based quality score - REFACTORED to read from batch_queue only.
726
+ This ensures scoring uses the actual generation parameters, not current UI values.
727
+
728
+ Args:
729
+ llm_handler: LLM handler instance
730
+ sample_idx: Which sample to score (1-8)
731
+ score_scale: Sensitivity scale parameter (tool setting, can be from UI)
732
+ current_batch_index: Current batch index
733
+ batch_queue: Batch queue containing historical generation data
734
+ """
735
+ if current_batch_index not in batch_queue:
736
+ return t("messages.scoring_failed"), batch_queue
737
+
738
+ batch_data = batch_queue[current_batch_index]
739
+ params = batch_data.get("generation_params", {})
740
+
741
+ # Read ALL parameters from historical batch data
742
+ caption = params.get("captions", "")
743
+ lyrics = params.get("lyrics", "")
744
+ bpm = params.get("bpm")
745
+ key_scale = params.get("key_scale", "")
746
+ time_signature = params.get("time_signature", "")
747
+ audio_duration = params.get("audio_duration", -1)
748
+ vocal_language = params.get("vocal_language", "")
749
+
750
+ # Get LM metadata from batch_data (if it was saved during generation)
751
+ lm_metadata = batch_data.get("lm_generated_metadata", None)
752
+
753
+ # Get codes from batch_data
754
+ stored_codes = batch_data.get("codes", "")
755
+ stored_allow_lm_batch = batch_data.get("allow_lm_batch", False)
756
+
757
+ # Select correct codes for this sample
758
+ audio_codes_str = ""
759
+ if stored_allow_lm_batch and isinstance(stored_codes, list):
760
+ # Batch mode: use specific sample's codes
761
+ if 0 <= sample_idx - 1 < len(stored_codes):
762
+ audio_codes_str = stored_codes[sample_idx - 1]
763
+ else:
764
+ # Single mode: all samples use same codes
765
+ audio_codes_str = stored_codes if isinstance(stored_codes, str) else ""
766
+
767
+ # Calculate score using historical parameters
768
+ score_display = calculate_score_handler(
769
+ llm_handler,
770
+ audio_codes_str, caption, lyrics, lm_metadata,
771
+ bpm, key_scale, time_signature, audio_duration, vocal_language,
772
+ score_scale
773
+ )
774
+
775
+ # Update batch_queue with the calculated score
776
+ if current_batch_index in batch_queue:
777
+ if "scores" not in batch_queue[current_batch_index]:
778
+ batch_queue[current_batch_index]["scores"] = [""] * 8
779
+ batch_queue[current_batch_index]["scores"][sample_idx - 1] = score_display
780
+
781
+ return score_display, batch_queue
782
+
783
+
784
+ def capture_current_params(
785
+ captions, lyrics, bpm, key_scale, time_signature, vocal_language,
786
+ inference_steps, guidance_scale, random_seed_checkbox, seed,
787
+ reference_audio, audio_duration, batch_size_input, src_audio,
788
+ text2music_audio_code_string, repainting_start, repainting_end,
789
+ instruction_display_gen, audio_cover_strength, task_type,
790
+ use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
791
+ think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
792
+ use_cot_metas, use_cot_caption, use_cot_language,
793
+ constrained_decoding_debug, allow_lm_batch, auto_score, score_scale, lm_batch_chunk_size,
794
+ track_name, complete_track_classes
795
+ ):
796
+ """Capture current UI parameters for next batch generation
797
+
798
+ IMPORTANT: For AutoGen batches, we clear audio codes to ensure:
799
+ - Thinking mode: LM generates NEW codes for each batch
800
+ - Non-thinking mode: DiT generates with different random seeds
801
+ """
802
+ return {
803
+ "captions": captions,
804
+ "lyrics": lyrics,
805
+ "bpm": bpm,
806
+ "key_scale": key_scale,
807
+ "time_signature": time_signature,
808
+ "vocal_language": vocal_language,
809
+ "inference_steps": inference_steps,
810
+ "guidance_scale": guidance_scale,
811
+ "random_seed_checkbox": True, # Always use random for AutoGen batches
812
+ "seed": seed,
813
+ "reference_audio": reference_audio,
814
+ "audio_duration": audio_duration,
815
+ "batch_size_input": batch_size_input,
816
+ "src_audio": src_audio,
817
+ "text2music_audio_code_string": "", # CLEAR codes for next batch! Let LM regenerate or DiT use new seeds
818
+ "repainting_start": repainting_start,
819
+ "repainting_end": repainting_end,
820
+ "instruction_display_gen": instruction_display_gen,
821
+ "audio_cover_strength": audio_cover_strength,
822
+ "task_type": task_type,
823
+ "use_adg": use_adg,
824
+ "cfg_interval_start": cfg_interval_start,
825
+ "cfg_interval_end": cfg_interval_end,
826
+ "audio_format": audio_format,
827
+ "lm_temperature": lm_temperature,
828
+ "think_checkbox": think_checkbox,
829
+ "lm_cfg_scale": lm_cfg_scale,
830
+ "lm_top_k": lm_top_k,
831
+ "lm_top_p": lm_top_p,
832
+ "lm_negative_prompt": lm_negative_prompt,
833
+ "use_cot_metas": use_cot_metas,
834
+ "use_cot_caption": use_cot_caption,
835
+ "use_cot_language": use_cot_language,
836
+ "constrained_decoding_debug": constrained_decoding_debug,
837
+ "allow_lm_batch": allow_lm_batch,
838
+ "auto_score": auto_score,
839
+ "score_scale": score_scale,
840
+ "lm_batch_chunk_size": lm_batch_chunk_size,
841
+ "track_name": track_name,
842
+ "complete_track_classes": complete_track_classes,
843
+ }
844
+
845
+
846
+ def generate_with_batch_management(
847
+ dit_handler, llm_handler,
848
+ captions, lyrics, bpm, key_scale, time_signature, vocal_language,
849
+ inference_steps, guidance_scale, random_seed_checkbox, seed,
850
+ reference_audio, audio_duration, batch_size_input, src_audio,
851
+ text2music_audio_code_string, repainting_start, repainting_end,
852
+ instruction_display_gen, audio_cover_strength, task_type,
853
+ use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
854
+ think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
855
+ use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
856
+ constrained_decoding_debug,
857
+ allow_lm_batch,
858
+ auto_score,
859
+ score_scale,
860
+ lm_batch_chunk_size,
861
+ track_name,
862
+ complete_track_classes,
863
+ autogen_checkbox,
864
+ current_batch_index,
865
+ total_batches,
866
+ batch_queue,
867
+ generation_params_state,
868
+ progress=gr.Progress(track_tqdm=True)
869
+ ):
870
+ """
871
+ Wrapper for generate_with_progress that adds batch queue management
872
+ """
873
+ # Call the original generation function
874
+ result = generate_with_progress(
875
+ dit_handler, llm_handler,
876
+ captions, lyrics, bpm, key_scale, time_signature, vocal_language,
877
+ inference_steps, guidance_scale, random_seed_checkbox, seed,
878
+ reference_audio, audio_duration, batch_size_input, src_audio,
879
+ text2music_audio_code_string, repainting_start, repainting_end,
880
+ instruction_display_gen, audio_cover_strength, task_type,
881
+ use_adg, cfg_interval_start, cfg_interval_end, audio_format, lm_temperature,
882
+ think_checkbox, lm_cfg_scale, lm_top_k, lm_top_p, lm_negative_prompt,
883
+ use_cot_metas, use_cot_caption, use_cot_language, is_format_caption,
884
+ constrained_decoding_debug,
885
+ allow_lm_batch,
886
+ auto_score,
887
+ score_scale,
888
+ lm_batch_chunk_size,
889
+ progress
890
+ )
891
+
892
+ # Extract results from generation
893
+ all_audio_paths = result[8] # generated_audio_batch
894
+ generation_info = result[9]
895
+ seed_value_for_ui = result[11]
896
+ lm_generated_metadata = result[34] # Index 34 is lm_metadata_state
897
+
898
+ # Extract codes
899
+ generated_codes_single = result[26]
900
+ generated_codes_batch = [result[27], result[28], result[29], result[30], result[31], result[32], result[33], result[34]]
901
+
902
+ # Determine which codes to store based on mode
903
+ if allow_lm_batch and batch_size_input >= 2:
904
+ codes_to_store = generated_codes_batch[:int(batch_size_input)]
905
+ else:
906
+ codes_to_store = generated_codes_single
907
+
908
+ # Save parameters for history
909
+ saved_params = {
910
+ "captions": captions,
911
+ "lyrics": lyrics,
912
+ "bpm": bpm,
913
+ "key_scale": key_scale,
914
+ "time_signature": time_signature,
915
+ "vocal_language": vocal_language,
916
+ "inference_steps": inference_steps,
917
+ "guidance_scale": guidance_scale,
918
+ "random_seed_checkbox": random_seed_checkbox,
919
+ "seed": seed,
920
+ "reference_audio": reference_audio,
921
+ "audio_duration": audio_duration,
922
+ "batch_size_input": batch_size_input,
923
+ "src_audio": src_audio,
924
+ "text2music_audio_code_string": text2music_audio_code_string,
925
+ "repainting_start": repainting_start,
926
+ "repainting_end": repainting_end,
927
+ "instruction_display_gen": instruction_display_gen,
928
+ "audio_cover_strength": audio_cover_strength,
929
+ "task_type": task_type,
930
+ "use_adg": use_adg,
931
+ "cfg_interval_start": cfg_interval_start,
932
+ "cfg_interval_end": cfg_interval_end,
933
+ "audio_format": audio_format,
934
+ "lm_temperature": lm_temperature,
935
+ "think_checkbox": think_checkbox,
936
+ "lm_cfg_scale": lm_cfg_scale,
937
+ "lm_top_k": lm_top_k,
938
+ "lm_top_p": lm_top_p,
939
+ "lm_negative_prompt": lm_negative_prompt,
940
+ "use_cot_metas": use_cot_metas,
941
+ "use_cot_caption": use_cot_caption,
942
+ "use_cot_language": use_cot_language,
943
+ "constrained_decoding_debug": constrained_decoding_debug,
944
+ "allow_lm_batch": allow_lm_batch,
945
+ "auto_score": auto_score,
946
+ "score_scale": score_scale,
947
+ "lm_batch_chunk_size": lm_batch_chunk_size,
948
+ "track_name": track_name,
949
+ "complete_track_classes": complete_track_classes,
950
+ }
951
+
952
+ # Next batch parameters (with cleared codes & random seed)
953
+ next_params = saved_params.copy()
954
+ next_params["text2music_audio_code_string"] = ""
955
+ next_params["random_seed_checkbox"] = True
956
+
957
+ # Store current batch in queue
958
+ batch_queue = store_batch_in_queue(
959
+ batch_queue,
960
+ current_batch_index,
961
+ all_audio_paths,
962
+ generation_info,
963
+ seed_value_for_ui,
964
+ codes=codes_to_store,
965
+ allow_lm_batch=allow_lm_batch,
966
+ batch_size=int(batch_size_input),
967
+ generation_params=saved_params,
968
+ lm_generated_metadata=lm_generated_metadata,
969
+ status="completed"
970
+ )
971
+
972
+ # Update batch counters
973
+ total_batches = max(total_batches, current_batch_index + 1)
974
+
975
+ # Update batch indicator
976
+ batch_indicator_text = update_batch_indicator(current_batch_index, total_batches)
977
+
978
+ # Update navigation button states
979
+ can_go_previous, can_go_next = update_navigation_buttons(current_batch_index, total_batches)
980
+
981
+ # Prepare next batch status message
982
+ next_batch_status_text = ""
983
+ if autogen_checkbox:
984
+ next_batch_status_text = t("messages.autogen_enabled")
985
+
986
+ # Return original results plus batch management state updates
987
+ return result + (
988
+ current_batch_index,
989
+ total_batches,
990
+ batch_queue,
991
+ next_params,
992
+ batch_indicator_text,
993
+ gr.update(interactive=can_go_previous),
994
+ gr.update(interactive=can_go_next),
995
+ next_batch_status_text,
996
+ gr.update(interactive=True),
997
+ )
998
+
999
+
1000
+ def generate_next_batch_background(
1001
+ dit_handler,
1002
+ llm_handler,
1003
+ autogen_enabled,
1004
+ generation_params,
1005
+ current_batch_index,
1006
+ total_batches,
1007
+ batch_queue,
1008
+ is_format_caption,
1009
+ progress=gr.Progress(track_tqdm=True)
1010
+ ):
1011
+ """
1012
+ Generate next batch in background if AutoGen is enabled
1013
+ """
1014
+ # Early return if AutoGen not enabled
1015
+ if not autogen_enabled:
1016
+ return (
1017
+ batch_queue,
1018
+ total_batches,
1019
+ "",
1020
+ gr.update(interactive=False),
1021
+ )
1022
+
1023
+ # Calculate next batch index
1024
+ next_batch_idx = current_batch_index + 1
1025
+
1026
+ # Check if next batch already exists
1027
+ if next_batch_idx in batch_queue and batch_queue[next_batch_idx].get("status") == "completed":
1028
+ return (
1029
+ batch_queue,
1030
+ total_batches,
1031
+ t("messages.batch_ready", n=next_batch_idx + 1),
1032
+ gr.update(interactive=True),
1033
+ )
1034
+
1035
+ # Update total batches count
1036
+ total_batches = next_batch_idx + 1
1037
+
1038
+ gr.Info(t("messages.batch_generating", n=next_batch_idx + 1))
1039
+
1040
+ # Generate next batch using stored parameters
1041
+ params = generation_params.copy()
1042
+
1043
+ # DEBUG LOGGING: Log all parameters used for background generation
1044
+ logger.info(f"========== BACKGROUND GENERATION BATCH {next_batch_idx + 1} ==========")
1045
+ logger.info(f"Parameters used for background generation:")
1046
+ logger.info(f" - captions: {params.get('captions', 'N/A')}")
1047
+ logger.info(f" - lyrics: {params.get('lyrics', 'N/A')[:50]}..." if params.get('lyrics') else " - lyrics: N/A")
1048
+ logger.info(f" - bpm: {params.get('bpm')}")
1049
+ logger.info(f" - batch_size_input: {params.get('batch_size_input')}")
1050
+ logger.info(f" - allow_lm_batch: {params.get('allow_lm_batch')}")
1051
+ logger.info(f" - think_checkbox: {params.get('think_checkbox')}")
1052
+ logger.info(f" - lm_temperature: {params.get('lm_temperature')}")
1053
+ logger.info(f" - track_name: {params.get('track_name')}")
1054
+ logger.info(f" - complete_track_classes: {params.get('complete_track_classes')}")
1055
+ logger.info(f" - text2music_audio_code_string: {'<CLEARED>' if params.get('text2music_audio_code_string') == '' else 'HAS_VALUE'}")
1056
+ logger.info(f"=========================================================")
1057
+
1058
+ # Add error handling for background generation
1059
+ try:
1060
+ # Ensure all parameters have default values to prevent None errors
1061
+ params.setdefault("captions", "")
1062
+ params.setdefault("lyrics", "")
1063
+ params.setdefault("bpm", None)
1064
+ params.setdefault("key_scale", "")
1065
+ params.setdefault("time_signature", "")
1066
+ params.setdefault("vocal_language", "unknown")
1067
+ params.setdefault("inference_steps", 8)
1068
+ params.setdefault("guidance_scale", 7.0)
1069
+ params.setdefault("random_seed_checkbox", True)
1070
+ params.setdefault("seed", "-1")
1071
+ params.setdefault("reference_audio", None)
1072
+ params.setdefault("audio_duration", -1)
1073
+ params.setdefault("batch_size_input", 2)
1074
+ params.setdefault("src_audio", None)
1075
+ params.setdefault("text2music_audio_code_string", "")
1076
+ params.setdefault("repainting_start", 0.0)
1077
+ params.setdefault("repainting_end", -1)
1078
+ params.setdefault("instruction_display_gen", "")
1079
+ params.setdefault("audio_cover_strength", 1.0)
1080
+ params.setdefault("task_type", "text2music")
1081
+ params.setdefault("use_adg", False)
1082
+ params.setdefault("cfg_interval_start", 0.0)
1083
+ params.setdefault("cfg_interval_end", 1.0)
1084
+ params.setdefault("audio_format", "mp3")
1085
+ params.setdefault("lm_temperature", 0.85)
1086
+ params.setdefault("think_checkbox", True)
1087
+ params.setdefault("lm_cfg_scale", 2.0)
1088
+ params.setdefault("lm_top_k", 0)
1089
+ params.setdefault("lm_top_p", 0.9)
1090
+ params.setdefault("lm_negative_prompt", "NO USER INPUT")
1091
+ params.setdefault("use_cot_metas", True)
1092
+ params.setdefault("use_cot_caption", True)
1093
+ params.setdefault("use_cot_language", True)
1094
+ params.setdefault("constrained_decoding_debug", False)
1095
+ params.setdefault("allow_lm_batch", True)
1096
+ params.setdefault("auto_score", False)
1097
+ params.setdefault("score_scale", 0.5)
1098
+ params.setdefault("lm_batch_chunk_size", 8)
1099
+ params.setdefault("track_name", None)
1100
+ params.setdefault("complete_track_classes", [])
1101
+
1102
+ # Call generate_with_progress with the saved parameters
1103
+ result = generate_with_progress(
1104
+ dit_handler,
1105
+ llm_handler,
1106
+ captions=params.get("captions"),
1107
+ lyrics=params.get("lyrics"),
1108
+ bpm=params.get("bpm"),
1109
+ key_scale=params.get("key_scale"),
1110
+ time_signature=params.get("time_signature"),
1111
+ vocal_language=params.get("vocal_language"),
1112
+ inference_steps=params.get("inference_steps"),
1113
+ guidance_scale=params.get("guidance_scale"),
1114
+ random_seed_checkbox=params.get("random_seed_checkbox"),
1115
+ seed=params.get("seed"),
1116
+ reference_audio=params.get("reference_audio"),
1117
+ audio_duration=params.get("audio_duration"),
1118
+ batch_size_input=params.get("batch_size_input"),
1119
+ src_audio=params.get("src_audio"),
1120
+ text2music_audio_code_string=params.get("text2music_audio_code_string"),
1121
+ repainting_start=params.get("repainting_start"),
1122
+ repainting_end=params.get("repainting_end"),
1123
+ instruction_display_gen=params.get("instruction_display_gen"),
1124
+ audio_cover_strength=params.get("audio_cover_strength"),
1125
+ task_type=params.get("task_type"),
1126
+ use_adg=params.get("use_adg"),
1127
+ cfg_interval_start=params.get("cfg_interval_start"),
1128
+ cfg_interval_end=params.get("cfg_interval_end"),
1129
+ audio_format=params.get("audio_format"),
1130
+ lm_temperature=params.get("lm_temperature"),
1131
+ think_checkbox=params.get("think_checkbox"),
1132
+ lm_cfg_scale=params.get("lm_cfg_scale"),
1133
+ lm_top_k=params.get("lm_top_k"),
1134
+ lm_top_p=params.get("lm_top_p"),
1135
+ lm_negative_prompt=params.get("lm_negative_prompt"),
1136
+ use_cot_metas=params.get("use_cot_metas"),
1137
+ use_cot_caption=params.get("use_cot_caption"),
1138
+ use_cot_language=params.get("use_cot_language"),
1139
+ is_format_caption=is_format_caption,
1140
+ constrained_decoding_debug=params.get("constrained_decoding_debug"),
1141
+ allow_lm_batch=params.get("allow_lm_batch"),
1142
+ auto_score=params.get("auto_score"),
1143
+ score_scale=params.get("score_scale"),
1144
+ lm_batch_chunk_size=params.get("lm_batch_chunk_size"),
1145
+ progress=progress
1146
+ )
1147
+
1148
+ # Extract results
1149
+ all_audio_paths = result[8] # generated_audio_batch
1150
+ generation_info = result[9]
1151
+ seed_value_for_ui = result[11]
1152
+ lm_generated_metadata = result[34] # Index 34 is lm_metadata_state
1153
+
1154
+ # Extract codes
1155
+ generated_codes_single = result[26]
1156
+ generated_codes_batch = [result[27], result[28], result[29], result[30], result[31], result[32], result[33], result[34]]
1157
+
1158
+ # Determine which codes to store
1159
+ batch_size = params.get("batch_size_input", 2)
1160
+ allow_lm_batch = params.get("allow_lm_batch", False)
1161
+ if allow_lm_batch and batch_size >= 2:
1162
+ codes_to_store = generated_codes_batch[:int(batch_size)]
1163
+ else:
1164
+ codes_to_store = generated_codes_single
1165
+
1166
+ # DEBUG LOGGING: Log codes extraction and storage
1167
+ logger.info(f"Codes extraction for Batch {next_batch_idx + 1}:")
1168
+ logger.info(f" - allow_lm_batch: {allow_lm_batch}")
1169
+ logger.info(f" - batch_size: {batch_size}")
1170
+ logger.info(f" - generated_codes_single exists: {bool(generated_codes_single)}")
1171
+ if isinstance(codes_to_store, list):
1172
+ logger.info(f" - codes_to_store: LIST with {len(codes_to_store)} items")
1173
+ for idx, code in enumerate(codes_to_store):
1174
+ logger.info(f" * Sample {idx + 1}: {len(code) if code else 0} chars")
1175
+ else:
1176
+ logger.info(f" - codes_to_store: STRING with {len(codes_to_store) if codes_to_store else 0} chars")
1177
+
1178
+ # Store next batch in queue with codes, batch settings, and ALL generation params
1179
+ batch_queue = store_batch_in_queue(
1180
+ batch_queue,
1181
+ next_batch_idx,
1182
+ all_audio_paths,
1183
+ generation_info,
1184
+ seed_value_for_ui,
1185
+ codes=codes_to_store,
1186
+ allow_lm_batch=allow_lm_batch,
1187
+ batch_size=int(batch_size),
1188
+ generation_params=params,
1189
+ lm_generated_metadata=lm_generated_metadata,
1190
+ status="completed"
1191
+ )
1192
+
1193
+ logger.info(f"Batch {next_batch_idx + 1} stored in queue successfully")
1194
+
1195
+ # Success message
1196
+ next_batch_status = t("messages.batch_ready", n=next_batch_idx + 1)
1197
+
1198
+ # Enable next button now that batch is ready
1199
+ return (
1200
+ batch_queue,
1201
+ total_batches,
1202
+ next_batch_status,
1203
+ gr.update(interactive=True),
1204
+ )
1205
+ except Exception as e:
1206
+ # Handle generation errors
1207
+ import traceback
1208
+ error_msg = t("messages.batch_failed", error=str(e))
1209
+ gr.Warning(error_msg)
1210
+
1211
+ # Mark batch as failed in queue
1212
+ batch_queue[next_batch_idx] = {
1213
+ "status": "error",
1214
+ "error": str(e),
1215
+ "traceback": traceback.format_exc()
1216
+ }
1217
+
1218
+ return (
1219
+ batch_queue,
1220
+ total_batches,
1221
+ error_msg,
1222
+ gr.update(interactive=False),
1223
+ )
1224
+
1225
+
1226
+ def navigate_to_previous_batch(current_batch_index, batch_queue):
1227
+ """Navigate to previous batch (Result View Only - Never touches Input UI)"""
1228
+ if current_batch_index <= 0:
1229
+ gr.Warning(t("messages.at_first_batch"))
1230
+ return [gr.update()] * 24
1231
+
1232
+ # Move to previous batch
1233
+ new_batch_index = current_batch_index - 1
1234
+
1235
+ # Load batch data from queue
1236
+ if new_batch_index not in batch_queue:
1237
+ gr.Warning(t("messages.batch_not_found", n=new_batch_index + 1))
1238
+ return [gr.update()] * 24
1239
+
1240
+ batch_data = batch_queue[new_batch_index]
1241
+ audio_paths = batch_data.get("audio_paths", [])
1242
+ generation_info_text = batch_data.get("generation_info", "")
1243
+
1244
+ # Prepare audio outputs (up to 8)
1245
+ audio_outputs = [None] * 8
1246
+ for idx in range(min(len(audio_paths), 8)):
1247
+ audio_outputs[idx] = audio_paths[idx]
1248
+
1249
+ # Update batch indicator
1250
+ total_batches = len(batch_queue)
1251
+ batch_indicator_text = update_batch_indicator(new_batch_index, total_batches)
1252
+
1253
+ # Update button states
1254
+ can_go_previous, can_go_next = update_navigation_buttons(new_batch_index, total_batches)
1255
+
1256
+ # Restore score displays from batch queue
1257
+ stored_scores = batch_data.get("scores", [""] * 8)
1258
+ score_displays = stored_scores if stored_scores else [""] * 8
1259
+
1260
+ return (
1261
+ audio_outputs[0], audio_outputs[1], audio_outputs[2], audio_outputs[3],
1262
+ audio_outputs[4], audio_outputs[5], audio_outputs[6], audio_outputs[7],
1263
+ audio_paths, generation_info_text, new_batch_index, batch_indicator_text,
1264
+ gr.update(interactive=can_go_previous), gr.update(interactive=can_go_next),
1265
+ t("messages.viewing_batch", n=new_batch_index + 1),
1266
+ score_displays[0], score_displays[1], score_displays[2], score_displays[3],
1267
+ score_displays[4], score_displays[5], score_displays[6], score_displays[7],
1268
+ gr.update(interactive=True),
1269
+ )
1270
+
1271
+
1272
+ def navigate_to_next_batch(autogen_enabled, current_batch_index, total_batches, batch_queue):
1273
+ """Navigate to next batch (Result View Only - Never touches Input UI)"""
1274
+ if current_batch_index >= total_batches - 1:
1275
+ gr.Warning(t("messages.at_last_batch"))
1276
+ return [gr.update()] * 25
1277
+
1278
+ # Move to next batch
1279
+ new_batch_index = current_batch_index + 1
1280
+
1281
+ # Load batch data from queue
1282
+ if new_batch_index not in batch_queue:
1283
+ gr.Warning(t("messages.batch_not_found", n=new_batch_index + 1))
1284
+ return [gr.update()] * 25
1285
+
1286
+ batch_data = batch_queue[new_batch_index]
1287
+ audio_paths = batch_data.get("audio_paths", [])
1288
+ generation_info_text = batch_data.get("generation_info", "")
1289
+
1290
+ # Prepare audio outputs (up to 8)
1291
+ audio_outputs = [None] * 8
1292
+ for idx in range(min(len(audio_paths), 8)):
1293
+ audio_outputs[idx] = audio_paths[idx]
1294
+
1295
+ # Update batch indicator
1296
+ batch_indicator_text = update_batch_indicator(new_batch_index, total_batches)
1297
+
1298
+ # Update button states
1299
+ can_go_previous, can_go_next = update_navigation_buttons(new_batch_index, total_batches)
1300
+
1301
+ # Prepare next batch status message
1302
+ next_batch_status_text = ""
1303
+ is_latest_view = (new_batch_index == total_batches - 1)
1304
+ if autogen_enabled and is_latest_view:
1305
+ next_batch_status_text = "🔄 AutoGen will generate next batch in background..."
1306
+
1307
+ # Restore score displays from batch queue
1308
+ stored_scores = batch_data.get("scores", [""] * 8)
1309
+ score_displays = stored_scores if stored_scores else [""] * 8
1310
+
1311
+ return (
1312
+ audio_outputs[0], audio_outputs[1], audio_outputs[2], audio_outputs[3],
1313
+ audio_outputs[4], audio_outputs[5], audio_outputs[6], audio_outputs[7],
1314
+ audio_paths, generation_info_text, new_batch_index, batch_indicator_text,
1315
+ gr.update(interactive=can_go_previous), gr.update(interactive=can_go_next),
1316
+ t("messages.viewing_batch", n=new_batch_index + 1), next_batch_status_text,
1317
+ score_displays[0], score_displays[1], score_displays[2], score_displays[3],
1318
+ score_displays[4], score_displays[5], score_displays[6], score_displays[7],
1319
+ gr.update(interactive=True),
1320
+ )
1321
+
1322
+
1323
+ def restore_batch_parameters(current_batch_index, batch_queue):
1324
+ """
1325
+ Restore parameters from currently viewed batch to Input UI.
1326
+ This is the bridge allowing users to "reuse" historical settings.
1327
+ """
1328
+ if current_batch_index not in batch_queue:
1329
+ gr.Warning(t("messages.no_batch_data"))
1330
+ return [gr.update()] * 29
1331
+
1332
+ batch_data = batch_queue[current_batch_index]
1333
+ params = batch_data.get("generation_params", {})
1334
+
1335
+ # Extract all parameters with defaults
1336
+ captions = params.get("captions", "")
1337
+ lyrics = params.get("lyrics", "")
1338
+ bpm = params.get("bpm", None)
1339
+ key_scale = params.get("key_scale", "")
1340
+ time_signature = params.get("time_signature", "")
1341
+ vocal_language = params.get("vocal_language", "unknown")
1342
+ audio_duration = params.get("audio_duration", -1)
1343
+ batch_size_input = params.get("batch_size_input", 2)
1344
+ inference_steps = params.get("inference_steps", 8)
1345
+ lm_temperature = params.get("lm_temperature", 0.85)
1346
+ lm_cfg_scale = params.get("lm_cfg_scale", 2.0)
1347
+ lm_top_k = params.get("lm_top_k", 0)
1348
+ lm_top_p = params.get("lm_top_p", 0.9)
1349
+ think_checkbox = params.get("think_checkbox", True)
1350
+ use_cot_caption = params.get("use_cot_caption", True)
1351
+ use_cot_language = params.get("use_cot_language", True)
1352
+ allow_lm_batch = params.get("allow_lm_batch", True)
1353
+ track_name = params.get("track_name", None)
1354
+ complete_track_classes = params.get("complete_track_classes", [])
1355
+
1356
+ # Extract and process codes
1357
+ stored_codes = batch_data.get("codes", "")
1358
+ stored_allow_lm_batch = params.get("allow_lm_batch", False)
1359
+
1360
+ codes_outputs = [""] * 9 # [Main, 1-8]
1361
+ if stored_codes:
1362
+ if stored_allow_lm_batch and isinstance(stored_codes, list):
1363
+ # Batch mode: populate codes 1-8, main shows first
1364
+ codes_outputs[0] = stored_codes[0] if stored_codes else ""
1365
+ for idx in range(min(len(stored_codes), 8)):
1366
+ codes_outputs[idx + 1] = stored_codes[idx]
1367
+ else:
1368
+ # Single mode: populate main, clear 1-8
1369
+ codes_outputs[0] = stored_codes if isinstance(stored_codes, str) else (stored_codes[0] if stored_codes else "")
1370
+
1371
+ gr.Info(t("messages.params_restored", n=current_batch_index + 1))
1372
+
1373
+ return (
1374
+ codes_outputs[0], codes_outputs[1], codes_outputs[2], codes_outputs[3],
1375
+ codes_outputs[4], codes_outputs[5], codes_outputs[6], codes_outputs[7],
1376
+ codes_outputs[8], captions, lyrics, bpm, key_scale, time_signature,
1377
+ vocal_language, audio_duration, batch_size_input, inference_steps,
1378
+ lm_temperature, lm_cfg_scale, lm_top_k, lm_top_p, think_checkbox,
1379
+ use_cot_caption, use_cot_language, allow_lm_batch,
1380
+ track_name, complete_track_classes
1381
+ )
acestep/gradio_ui/i18n.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Internationalization (i18n) module for Gradio UI
3
+ Supports multiple languages with easy translation management
4
+ """
5
+ import os
6
+ import json
7
+ from typing import Dict, Optional
8
+
9
+
10
+ class I18n:
11
+ """Internationalization handler"""
12
+
13
+ def __init__(self, default_language: str = "en"):
14
+ """
15
+ Initialize i18n handler
16
+
17
+ Args:
18
+ default_language: Default language code (en, zh, ja, etc.)
19
+ """
20
+ self.current_language = default_language
21
+ self.translations: Dict[str, Dict[str, str]] = {}
22
+ self._load_all_translations()
23
+
24
+ def _load_all_translations(self):
25
+ """Load all translation files from i18n directory"""
26
+ current_file = os.path.abspath(__file__)
27
+ module_dir = os.path.dirname(current_file)
28
+ i18n_dir = os.path.join(module_dir, "i18n")
29
+
30
+ if not os.path.exists(i18n_dir):
31
+ # Create i18n directory if it doesn't exist
32
+ os.makedirs(i18n_dir)
33
+ return
34
+
35
+ # Load all JSON files in i18n directory
36
+ for filename in os.listdir(i18n_dir):
37
+ if filename.endswith(".json"):
38
+ lang_code = filename[:-5] # Remove .json extension
39
+ filepath = os.path.join(i18n_dir, filename)
40
+ try:
41
+ with open(filepath, 'r', encoding='utf-8') as f:
42
+ self.translations[lang_code] = json.load(f)
43
+ except Exception as e:
44
+ print(f"Error loading translation file {filename}: {e}")
45
+
46
+ def set_language(self, language: str):
47
+ """Set current language"""
48
+ if language in self.translations:
49
+ self.current_language = language
50
+ else:
51
+ print(f"Warning: Language '{language}' not found, using default")
52
+
53
+ def t(self, key: str, **kwargs) -> str:
54
+ """
55
+ Translate a key to current language
56
+
57
+ Args:
58
+ key: Translation key (dot-separated for nested keys)
59
+ **kwargs: Optional format parameters
60
+
61
+ Returns:
62
+ Translated string
63
+ """
64
+ # Get translation from current language
65
+ translation = self._get_nested_value(
66
+ self.translations.get(self.current_language, {}),
67
+ key
68
+ )
69
+
70
+ # Fallback to English if not found
71
+ if translation is None:
72
+ translation = self._get_nested_value(
73
+ self.translations.get('en', {}),
74
+ key
75
+ )
76
+
77
+ # Final fallback to key itself
78
+ if translation is None:
79
+ translation = key
80
+
81
+ # Apply formatting if kwargs provided
82
+ if kwargs:
83
+ try:
84
+ translation = translation.format(**kwargs)
85
+ except KeyError:
86
+ pass
87
+
88
+ return translation
89
+
90
+ def _get_nested_value(self, data: dict, key: str) -> Optional[str]:
91
+ """
92
+ Get nested dictionary value using dot notation
93
+
94
+ Args:
95
+ data: Dictionary to search
96
+ key: Dot-separated key (e.g., "section.subsection.key")
97
+
98
+ Returns:
99
+ Value if found, None otherwise
100
+ """
101
+ keys = key.split('.')
102
+ current = data
103
+
104
+ for k in keys:
105
+ if isinstance(current, dict) and k in current:
106
+ current = current[k]
107
+ else:
108
+ return None
109
+
110
+ return current if isinstance(current, str) else None
111
+
112
+ def get_available_languages(self) -> list:
113
+ """Get list of available language codes"""
114
+ return list(self.translations.keys())
115
+
116
+
117
+ # Global i18n instance
118
+ _i18n_instance: Optional[I18n] = None
119
+
120
+
121
+ def get_i18n(language: Optional[str] = None) -> I18n:
122
+ """
123
+ Get global i18n instance
124
+
125
+ Args:
126
+ language: Optional language to set
127
+
128
+ Returns:
129
+ I18n instance
130
+ """
131
+ global _i18n_instance
132
+
133
+ if _i18n_instance is None:
134
+ _i18n_instance = I18n(default_language=language or "en")
135
+ elif language is not None:
136
+ _i18n_instance.set_language(language)
137
+
138
+ return _i18n_instance
139
+
140
+
141
+ def t(key: str, **kwargs) -> str:
142
+ """
143
+ Convenience function for translation
144
+
145
+ Args:
146
+ key: Translation key
147
+ **kwargs: Optional format parameters
148
+
149
+ Returns:
150
+ Translated string
151
+ """
152
+ return get_i18n().t(key, **kwargs)
acestep/gradio_ui/i18n/en.json ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "app": {
3
+ "title": "🎛️ ACE-Step V1.5 Playground💡",
4
+ "subtitle": "Pushing the Boundaries of Open-Source Music Generation"
5
+ },
6
+ "dataset": {
7
+ "title": "📊 Dataset Explorer",
8
+ "dataset_label": "Dataset",
9
+ "dataset_info": "Choose dataset to explore",
10
+ "import_btn": "📥 Import Dataset",
11
+ "search_type_label": "Search Type",
12
+ "search_type_info": "How to find items",
13
+ "search_value_label": "Search Value",
14
+ "search_value_placeholder": "Enter keys or index (leave empty for random)",
15
+ "search_value_info": "Keys: exact match, Index: 0 to dataset size-1",
16
+ "instruction_label": "📝 Instruction",
17
+ "instruction_placeholder": "No instruction available",
18
+ "metadata_title": "📋 Item Metadata (JSON)",
19
+ "metadata_label": "Complete Item Information",
20
+ "source_audio": "Source Audio",
21
+ "target_audio": "Target Audio",
22
+ "reference_audio": "Reference Audio",
23
+ "get_item_btn": "🔍 Get Item",
24
+ "use_src_checkbox": "Use Source Audio from Dataset",
25
+ "use_src_info": "Check to use the source audio from dataset",
26
+ "data_status_label": "📊 Data Status",
27
+ "data_status_default": "❌ No dataset imported",
28
+ "autofill_btn": "📋 Auto-fill Generation Form"
29
+ },
30
+ "service": {
31
+ "title": "🔧 Service Configuration",
32
+ "checkpoint_label": "Checkpoint File",
33
+ "checkpoint_info": "Select a trained model checkpoint file (full path or filename)",
34
+ "refresh_btn": "🔄 Refresh",
35
+ "model_path_label": "Main Model Path",
36
+ "model_path_info": "Select the model configuration directory (auto-scanned from checkpoints)",
37
+ "device_label": "Device",
38
+ "device_info": "Processing device (auto-detect recommended)",
39
+ "lm_model_path_label": "5Hz LM Model Path",
40
+ "lm_model_path_info": "Select the 5Hz LM model checkpoint (auto-scanned from checkpoints)",
41
+ "backend_label": "5Hz LM Backend",
42
+ "backend_info": "Select backend for 5Hz LM: vllm (faster) or pt (PyTorch, more compatible)",
43
+ "init_llm_label": "Initialize 5Hz LM",
44
+ "init_llm_info": "Check to initialize 5Hz LM during service initialization",
45
+ "flash_attention_label": "Use Flash Attention",
46
+ "flash_attention_info_enabled": "Enable flash attention for faster inference (requires flash_attn package)",
47
+ "flash_attention_info_disabled": "Flash attention not available (flash_attn package not installed)",
48
+ "offload_cpu_label": "Offload to CPU",
49
+ "offload_cpu_info": "Offload models to CPU when not in use to save GPU memory",
50
+ "offload_dit_cpu_label": "Offload DiT to CPU",
51
+ "offload_dit_cpu_info": "Offload DiT to CPU (needs Offload to CPU)",
52
+ "init_btn": "Initialize Service",
53
+ "status_label": "Status",
54
+ "language_label": "UI Language",
55
+ "language_info": "Select interface language"
56
+ },
57
+ "generation": {
58
+ "required_inputs": "📝 Required Inputs",
59
+ "task_type_label": "Task Type",
60
+ "task_type_info": "Select the task type for generation",
61
+ "instruction_label": "Instruction",
62
+ "instruction_info": "Instruction is automatically generated based on task type",
63
+ "load_btn": "Load",
64
+ "track_name_label": "Track Name",
65
+ "track_name_info": "Select track name for lego/extract tasks",
66
+ "track_classes_label": "Track Names",
67
+ "track_classes_info": "Select multiple track classes for complete task",
68
+ "audio_uploads": "🎵 Audio Uploads",
69
+ "reference_audio": "Reference Audio (optional)",
70
+ "source_audio": "Source Audio (optional)",
71
+ "convert_codes_btn": "Convert to Codes",
72
+ "lm_codes_hints": "🎼 LM Codes Hints",
73
+ "lm_codes_label": "LM Codes Hints",
74
+ "lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
75
+ "lm_codes_info": "Paste LM codes hints for text2music generation",
76
+ "lm_codes_sample": "LM Codes Hints (Sample {n})",
77
+ "lm_codes_sample_info": "Codes for sample {n}",
78
+ "transcribe_btn": "Transcribe",
79
+ "repainting_controls": "🎨 Repainting Controls (seconds)",
80
+ "repainting_start": "Repainting Start",
81
+ "repainting_end": "Repainting End",
82
+ "caption_title": "📝 Music Caption",
83
+ "caption_label": "Music Caption (optional)",
84
+ "caption_placeholder": "A peaceful acoustic guitar melody with soft vocals...",
85
+ "caption_info": "Describe the style, genre, instruments, and mood",
86
+ "sample_btn": "Sample",
87
+ "lyrics_title": "📝 Lyrics",
88
+ "lyrics_label": "Lyrics (optional)",
89
+ "lyrics_placeholder": "[Verse 1]\\nUnder the starry night\\nI feel so alive...",
90
+ "lyrics_info": "Song lyrics with structure",
91
+ "instrumental_label": "Instrumental",
92
+ "optional_params": "⚙️ Optional Parameters",
93
+ "vocal_language_label": "Vocal Language (optional)",
94
+ "vocal_language_info": "use `unknown` for inst",
95
+ "bpm_label": "BPM (optional)",
96
+ "bpm_info": "leave empty for N/A",
97
+ "keyscale_label": "KeyScale (optional)",
98
+ "keyscale_placeholder": "Leave empty for N/A",
99
+ "keyscale_info": "A-G, #/♭, major/minor",
100
+ "timesig_label": "Time Signature (optional)",
101
+ "timesig_info": "2/4, 3/4, 4/4...",
102
+ "duration_label": "Audio Duration (seconds)",
103
+ "duration_info": "Use -1 for random",
104
+ "batch_size_label": "Batch Size",
105
+ "batch_size_info": "Number of audio to generate (max 8)",
106
+ "advanced_settings": "🔧 Advanced Settings",
107
+ "inference_steps_label": "DiT Inference Steps",
108
+ "inference_steps_info": "Turbo: max 8, Base: max 100",
109
+ "guidance_scale_label": "DiT Guidance Scale (Only support for base model)",
110
+ "guidance_scale_info": "Higher values follow text more closely",
111
+ "seed_label": "Seed",
112
+ "seed_info": "Use comma-separated values for batches",
113
+ "random_seed_label": "Random Seed",
114
+ "random_seed_info": "Enable to auto-generate seeds",
115
+ "audio_format_label": "Audio Format",
116
+ "audio_format_info": "Audio format for saved files",
117
+ "use_adg_label": "Use ADG",
118
+ "use_adg_info": "Enable Angle Domain Guidance",
119
+ "cfg_interval_start": "CFG Interval Start",
120
+ "cfg_interval_end": "CFG Interval End",
121
+ "lm_params_title": "🤖 LM Generation Parameters",
122
+ "lm_temperature_label": "LM Temperature",
123
+ "lm_temperature_info": "5Hz LM temperature (higher = more random)",
124
+ "lm_cfg_scale_label": "LM CFG Scale",
125
+ "lm_cfg_scale_info": "5Hz LM CFG (1.0 = no CFG)",
126
+ "lm_top_k_label": "LM Top-K",
127
+ "lm_top_k_info": "Top-K (0 = disabled)",
128
+ "lm_top_p_label": "LM Top-P",
129
+ "lm_top_p_info": "Top-P (1.0 = disabled)",
130
+ "lm_negative_prompt_label": "LM Negative Prompt",
131
+ "lm_negative_prompt_placeholder": "Enter negative prompt for CFG (default: NO USER INPUT)",
132
+ "lm_negative_prompt_info": "Negative prompt (use when LM CFG Scale > 1.0)",
133
+ "cot_metas_label": "CoT Metas",
134
+ "cot_metas_info": "Use LM to generate CoT metadata (uncheck to skip LM CoT generation)",
135
+ "cot_language_label": "CoT Language",
136
+ "cot_language_info": "Generate language in CoT (chain-of-thought)",
137
+ "constrained_debug_label": "Constrained Decoding Debug",
138
+ "constrained_debug_info": "Enable debug logging for constrained decoding (check to see detailed logs)",
139
+ "auto_score_label": "Auto Score",
140
+ "auto_score_info": "Automatically calculate quality scores for all generated audios",
141
+ "lm_batch_chunk_label": "LM Batch Chunk Size",
142
+ "lm_batch_chunk_info": "Max items per LM batch chunk (default: 8, limited by GPU memory)",
143
+ "codes_strength_label": "LM Codes Strength",
144
+ "codes_strength_info": "Control how many denoising steps use LM-generated codes",
145
+ "cover_strength_label": "Audio Cover Strength",
146
+ "cover_strength_info": "Control how many denoising steps use cover mode",
147
+ "score_sensitivity_label": "Quality Score Sensitivity",
148
+ "score_sensitivity_info": "Lower = more sensitive (default: 1.0). Adjusts how PMI maps to [0,1]",
149
+ "attention_focus_label": "Output Attention Focus Score (disabled)",
150
+ "attention_focus_info": "Output attention focus score analysis",
151
+ "think_label": "Think",
152
+ "parallel_thinking_label": "ParallelThinking",
153
+ "generate_btn": "🎵 Generate Music",
154
+ "autogen_label": "AutoGen",
155
+ "caption_rewrite_label": "CaptionRewrite"
156
+ },
157
+ "results": {
158
+ "title": "🎵 Results",
159
+ "generated_music": "🎵 Generated Music (Sample {n})",
160
+ "send_to_src_btn": "🔗 Send To Src Audio",
161
+ "save_btn": "💾 Save",
162
+ "score_btn": "📊 Score",
163
+ "quality_score_label": "Quality Score (Sample {n})",
164
+ "quality_score_placeholder": "Click 'Score' to calculate perplexity-based quality score",
165
+ "generation_status": "Generation Status",
166
+ "current_batch": "Current Batch",
167
+ "batch_indicator": "Batch {current} / {total}",
168
+ "next_batch_status": "Next Batch Status",
169
+ "prev_btn": "◀ Previous",
170
+ "next_btn": "Next ▶",
171
+ "restore_params_btn": "↙️ Apply These Settings to UI (Restore Batch Parameters)",
172
+ "batch_results_title": "📁 Batch Results & Generation Details",
173
+ "all_files_label": "📁 All Generated Files (Download)",
174
+ "generation_details": "Generation Details",
175
+ "attention_analysis": "⚖️ Attention Focus Score Analysis",
176
+ "attention_score": "Attention Focus Score (Sample {n})",
177
+ "lyric_timestamps": "Lyric Timestamps (Sample {n})",
178
+ "attention_heatmap": "Attention Focus Score Heatmap (Sample {n})"
179
+ },
180
+ "messages": {
181
+ "no_audio_to_save": "❌ No audio to save",
182
+ "save_success": "✅ Saved audio and metadata to {filename}",
183
+ "save_failed": "❌ Failed to save: {error}",
184
+ "no_file_selected": "⚠️ No file selected",
185
+ "params_loaded": "✅ Parameters loaded from {filename}",
186
+ "invalid_json": "❌ Invalid JSON file: {error}",
187
+ "load_error": "❌ Error loading file: {error}",
188
+ "example_loaded": "📁 Loaded example from {filename}",
189
+ "example_failed": "Failed to parse JSON file {filename}: {error}",
190
+ "example_error": "Error loading example: {error}",
191
+ "lm_generated": "🤖 Generated example using LM",
192
+ "lm_fallback": "Failed to generate example using LM, falling back to examples directory",
193
+ "lm_not_initialized": "❌ 5Hz LM not initialized. Please initialize it first.",
194
+ "autogen_enabled": "🔄 AutoGen enabled - next batch will generate after this",
195
+ "batch_ready": "✅ Batch {n} ready! Click 'Next' to view.",
196
+ "batch_generating": "🔄 Starting background generation for Batch {n}...",
197
+ "batch_failed": "❌ Background generation failed: {error}",
198
+ "viewing_batch": "✅ Viewing Batch {n}",
199
+ "at_first_batch": "Already at first batch",
200
+ "at_last_batch": "No next batch available",
201
+ "batch_not_found": "Batch {n} not found in queue",
202
+ "no_batch_data": "No batch data found to restore.",
203
+ "params_restored": "✅ UI Parameters restored from Batch {n}",
204
+ "scoring_failed": "❌ Error: Batch data not found",
205
+ "no_codes": "❌ No audio codes available. Please generate music first.",
206
+ "score_failed": "❌ Scoring failed: {error}",
207
+ "score_error": "❌ Error calculating score: {error}"
208
+ }
209
+ }
acestep/gradio_ui/i18n/ja.json ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "app": {
3
+ "title": "🎛️ ACE-Step V1.5 プレイグラウンド💡",
4
+ "subtitle": "オープンソース音楽生成の限界を押し広げる"
5
+ },
6
+ "dataset": {
7
+ "title": "📊 データセットエクスプローラー",
8
+ "dataset_label": "データセット",
9
+ "dataset_info": "探索するデータセットを選択",
10
+ "import_btn": "📥 データセットをインポート",
11
+ "search_type_label": "検索タイプ",
12
+ "search_type_info": "アイテムの検索方法",
13
+ "search_value_label": "検索値",
14
+ "search_value_placeholder": "キーまたはインデックスを入力(空白の場合はランダム)",
15
+ "search_value_info": "キー: 完全一致、インデックス: 0からデータセットサイズ-1",
16
+ "instruction_label": "📝 指示",
17
+ "instruction_placeholder": "利用可能な指示がありません",
18
+ "metadata_title": "📋 アイテムメタデータ (JSON)",
19
+ "metadata_label": "完全なアイテム情報",
20
+ "source_audio": "ソースオーディオ",
21
+ "target_audio": "ターゲットオーディオ",
22
+ "reference_audio": "リファレンスオーディオ",
23
+ "get_item_btn": "🔍 アイテムを取得",
24
+ "use_src_checkbox": "データセットのソースオーディオを使用",
25
+ "use_src_info": "データセットのソースオーディオを使用する場合はチェック",
26
+ "data_status_label": "📊 データステータス",
27
+ "data_status_default": "❌ データセットがインポートされていません",
28
+ "autofill_btn": "📋 生成フォームを自動入力"
29
+ },
30
+ "service": {
31
+ "title": "🔧 サービス設定",
32
+ "checkpoint_label": "チェックポイントファイル",
33
+ "checkpoint_info": "訓練済みモデルのチェックポイントファイルを選択(フルパスまたはファイル名)",
34
+ "refresh_btn": "🔄 更新",
35
+ "model_path_label": "メインモデルパス",
36
+ "model_path_info": "モデル設定ディレクトリを選択(チェックポイントから自動スキャン)",
37
+ "device_label": "デバイス",
38
+ "device_info": "処理デバイス(自動検出を推奨)",
39
+ "lm_model_path_label": "5Hz LM モデルパス",
40
+ "lm_model_path_info": "5Hz LMモデルチェックポイントを選択(チェックポイントから自動スキャン)",
41
+ "backend_label": "5Hz LM バックエンド",
42
+ "backend_info": "5Hz LMのバックエンドを選択: vllm(高速)またはpt(PyTorch、より互換性あり)",
43
+ "init_llm_label": "5Hz LM を初期化",
44
+ "init_llm_info": "サービス初期化中に5Hz LMを初期化する場合はチェック",
45
+ "flash_attention_label": "Flash Attention を使用",
46
+ "flash_attention_info_enabled": "推論を高速化するためにflash attentionを有効にする(flash_attnパッケージが必要)",
47
+ "flash_attention_info_disabled": "Flash attentionは利用できません(flash_attnパッケージがインストールされていません)",
48
+ "offload_cpu_label": "CPUにオフロード",
49
+ "offload_cpu_info": "使用していない時にモデルをCPUにオフロードしてGPUメモリを節約",
50
+ "offload_dit_cpu_label": "DiTをCPUにオフロード",
51
+ "offload_dit_cpu_info": "DiTをCPUにオフロード(CPUへのオフロードが必要)",
52
+ "init_btn": "サービスを初期化",
53
+ "status_label": "ステータス",
54
+ "language_label": "UI言語",
55
+ "language_info": "インターフェース言語を選択"
56
+ },
57
+ "generation": {
58
+ "required_inputs": "📝 必須入力",
59
+ "task_type_label": "タスクタイプ",
60
+ "task_type_info": "生成のタスクタイプを選択",
61
+ "instruction_label": "指示",
62
+ "instruction_info": "指示はタスクタイプに基づいて自動生成されます",
63
+ "load_btn": "読み込む",
64
+ "track_name_label": "トラック名",
65
+ "track_name_info": "lego/extractタスクのトラック名を選択",
66
+ "track_classes_label": "トラック名",
67
+ "track_classes_info": "completeタスクの複数のトラッククラスを選択",
68
+ "audio_uploads": "🎵 オーディオアップロード",
69
+ "reference_audio": "リファレンスオーディオ(オプション)",
70
+ "source_audio": "ソースオーディオ(オプション)",
71
+ "convert_codes_btn": "コードに変換",
72
+ "lm_codes_hints": "🎼 LM コードヒント",
73
+ "lm_codes_label": "LM コードヒント",
74
+ "lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
75
+ "lm_codes_info": "text2music生成用のLMコードヒントを貼り付け",
76
+ "lm_codes_sample": "LM コードヒント(サンプル {n})",
77
+ "lm_codes_sample_info": "サンプル{n}のコード",
78
+ "transcribe_btn": "転写",
79
+ "repainting_controls": "🎨 再描画コントロール(秒)",
80
+ "repainting_start": "再描画開始",
81
+ "repainting_end": "再描画終了",
82
+ "caption_title": "📝 音楽キャプション",
83
+ "caption_label": "音楽キャプション(���プション)",
84
+ "caption_placeholder": "柔らかいボーカルを伴う穏やかなアコースティックギターのメロディー...",
85
+ "caption_info": "スタイル、ジャンル、楽器、ムードを説明",
86
+ "sample_btn": "サンプル",
87
+ "lyrics_title": "📝 歌詞",
88
+ "lyrics_label": "歌詞(オプション)",
89
+ "lyrics_placeholder": "[バース1]\\n星空の下で\\nとても生きていると感じる...",
90
+ "lyrics_info": "構造を持つ曲の歌詞",
91
+ "instrumental_label": "インストゥルメンタル",
92
+ "optional_params": "⚙️ オプションパラメータ",
93
+ "vocal_language_label": "ボーカル言語(オプション)",
94
+ "vocal_language_info": "インストには`unknown`を使用",
95
+ "bpm_label": "BPM(オプション)",
96
+ "bpm_info": "空白の場合はN/A",
97
+ "keyscale_label": "キースケール(オプション)",
98
+ "keyscale_placeholder": "空白の場合はN/A",
99
+ "keyscale_info": "A-G, #/♭, メジャー/マイナー",
100
+ "timesig_label": "拍子記号(オプション)",
101
+ "timesig_info": "2/4, 3/4, 4/4...",
102
+ "duration_label": "オーディオ長(秒)",
103
+ "duration_info": "ランダムの場合は-1を使用",
104
+ "batch_size_label": "バッチサイズ",
105
+ "batch_size_info": "生成するオーディオの数(最大8)",
106
+ "advanced_settings": "🔧 詳細設定",
107
+ "inference_steps_label": "DiT 推論ステップ",
108
+ "inference_steps_info": "Turbo: 最大8、Base: 最大100",
109
+ "guidance_scale_label": "DiT ガイダンススケール(baseモデルのみサポート)",
110
+ "guidance_scale_info": "値が高いほどテキストに忠実に従う",
111
+ "seed_label": "シード",
112
+ "seed_info": "バッチにはカンマ区切りの値を使用",
113
+ "random_seed_label": "ランダムシード",
114
+ "random_seed_info": "有効にすると自動的にシードを生成",
115
+ "audio_format_label": "オーディオフォーマット",
116
+ "audio_format_info": "保存ファイルのオーディオフォーマット",
117
+ "use_adg_label": "ADG を使用",
118
+ "use_adg_info": "角度ドメインガイダンスを有効化",
119
+ "cfg_interval_start": "CFG 間隔開始",
120
+ "cfg_interval_end": "CFG 間隔終了",
121
+ "lm_params_title": "🤖 LM 生成パラメータ",
122
+ "lm_temperature_label": "LM 温度",
123
+ "lm_temperature_info": "5Hz LM温度(高いほどランダム)",
124
+ "lm_cfg_scale_label": "LM CFG スケール",
125
+ "lm_cfg_scale_info": "5Hz LM CFG (1.0 = CFGなし)",
126
+ "lm_top_k_label": "LM Top-K",
127
+ "lm_top_k_info": "Top-K (0 = 無効)",
128
+ "lm_top_p_label": "LM Top-P",
129
+ "lm_top_p_info": "Top-P (1.0 = 無効)",
130
+ "lm_negative_prompt_label": "LM ネガティブプロンプト",
131
+ "lm_negative_prompt_placeholder": "CFGのネガティブプロンプトを入力(デフォルト: NO USER INPUT)",
132
+ "lm_negative_prompt_info": "ネガティブプロンプト(LM CFGスケール > 1.0の場合に使用)",
133
+ "cot_metas_label": "CoT メタデータ",
134
+ "cot_metas_info": "LMを使用してCoTメタデータを生成(チェックを外すとLM CoT生成をスキップ)",
135
+ "cot_language_label": "CoT 言語",
136
+ "cot_language_info": "CoTで言語を生成(思考の連鎖)",
137
+ "constrained_debug_label": "制約付きデコーディングデバッグ",
138
+ "constrained_debug_info": "制約付きデコーディングのデバッグログを有効化(チェックすると詳細ログを表示)",
139
+ "auto_score_label": "自動スコアリング",
140
+ "auto_score_info": "生成されたすべてのオーディオの品質スコアを自動計算",
141
+ "lm_batch_chunk_label": "LM バッチチャンクサイズ",
142
+ "lm_batch_chunk_info": "LMバッチチャンクあたりの最大アイテム数(デフォルト: 8、GPUメモリによる制限)",
143
+ "codes_strength_label": "LM コード強度",
144
+ "codes_strength_info": "LM生成コードを使用するデノイジングステップ数を制御",
145
+ "cover_strength_label": "オーディオカバー強度",
146
+ "cover_strength_info": "カバーモードを使用するデノイジングステップ数を制御",
147
+ "score_sensitivity_label": "品質スコア感度",
148
+ "score_sensitivity_info": "低い = より敏感(デフォルト: 1.0)。PMIが[0,1]にマッピングする方法を調整",
149
+ "attention_focus_label": "注意焦点スコアを出力(無効)",
150
+ "attention_focus_info": "注意焦点スコア分析を出力",
151
+ "think_label": "思考",
152
+ "parallel_thinking_label": "並列思考",
153
+ "generate_btn": "🎵 音楽を生成",
154
+ "autogen_label": "自動生成",
155
+ "caption_rewrite_label": "キャプション書き換え"
156
+ },
157
+ "results": {
158
+ "title": "🎵 結果",
159
+ "generated_music": "🎵 生成された音楽(サンプル {n})",
160
+ "send_to_src_btn": "🔗 ソースオーディオに送信",
161
+ "save_btn": "💾 保存",
162
+ "score_btn": "📊 スコア",
163
+ "quality_score_label": "品質スコア(サンプル {n})",
164
+ "quality_score_placeholder": "'��コア'をクリックしてパープレキシティベースの品質スコアを計算",
165
+ "generation_status": "生成ステータス",
166
+ "current_batch": "現在のバッチ",
167
+ "batch_indicator": "バッチ {current} / {total}",
168
+ "next_batch_status": "次のバッチステータス",
169
+ "prev_btn": "◀ 前へ",
170
+ "next_btn": "次へ ▶",
171
+ "restore_params_btn": "↙️ これらの設定をUIに適用(バッチパラメータを復元)",
172
+ "batch_results_title": "📁 バッチ結果と生成詳細",
173
+ "all_files_label": "📁 すべての生成ファイル(ダウンロード)",
174
+ "generation_details": "生成詳細",
175
+ "attention_analysis": "⚖️ 注意焦点スコア分析",
176
+ "attention_score": "注意焦点スコア(サンプル {n})",
177
+ "lyric_timestamps": "歌詞タイムスタンプ(サンプル {n})",
178
+ "attention_heatmap": "注意焦点スコアヒートマップ(サンプル {n})"
179
+ },
180
+ "messages": {
181
+ "no_audio_to_save": "❌ 保存するオーディオがありません",
182
+ "save_success": "✅ オーディオとメタデータを {filename} に保存しました",
183
+ "save_failed": "❌ 保存に失敗しました: {error}",
184
+ "no_file_selected": "⚠️ ファイルが選択されていません",
185
+ "params_loaded": "✅ {filename} からパラメータを読み込みました",
186
+ "invalid_json": "❌ 無効なJSONファイル: {error}",
187
+ "load_error": "❌ ファイルの読み込みエラー: {error}",
188
+ "example_loaded": "📁 {filename} からサンプルを読み込みました",
189
+ "example_failed": "JSONファイル {filename} の解析に失敗しました: {error}",
190
+ "example_error": "サンプル読み込みエラー: {error}",
191
+ "lm_generated": "🤖 LMを使用してサンプルを生成しました",
192
+ "lm_fallback": "LMを使用したサンプル生成に失敗、サンプルディレクトリにフォールバック",
193
+ "lm_not_initialized": "❌ 5Hz LMが初期化されていません。最初に初期化してください。",
194
+ "autogen_enabled": "🔄 自動生成が有効 - このあと次のバッチを生成します",
195
+ "batch_ready": "✅ バッチ {n} の準備完了!'次へ'をクリックして表示。",
196
+ "batch_generating": "🔄 バッチ {n} のバックグラウンド生成を開始...",
197
+ "batch_failed": "❌ バックグラウンド生成に失敗しました: {error}",
198
+ "viewing_batch": "✅ バッチ {n} を表示中",
199
+ "at_first_batch": "すでに最初のバッチです",
200
+ "at_last_batch": "次のバッチはありません",
201
+ "batch_not_found": "キューにバッチ {n} が見つかりません",
202
+ "no_batch_data": "復元するバッチデータがありません。",
203
+ "params_restored": "✅ バッチ {n} からUIパラメータを復元しました",
204
+ "scoring_failed": "❌ エラー: バッチデータが見つかりません",
205
+ "no_codes": "❌ 利用可能なオーディオコードがありません。最初に音楽を生成してください。",
206
+ "score_failed": "❌ スコアリングに失敗しました: {error}",
207
+ "score_error": "❌ スコア計算エラー: {error}"
208
+ }
209
+ }
acestep/gradio_ui/i18n/zh.json ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "app": {
3
+ "title": "🎛️ ACE-Step V1.5 演练场💡",
4
+ "subtitle": "推动开源音乐生成的边界"
5
+ },
6
+ "dataset": {
7
+ "title": "📊 数据集浏览器",
8
+ "dataset_label": "数据集",
9
+ "dataset_info": "选择要浏览的数据集",
10
+ "import_btn": "📥 导入数据集",
11
+ "search_type_label": "搜索类型",
12
+ "search_type_info": "如何查找项目",
13
+ "search_value_label": "搜索值",
14
+ "search_value_placeholder": "输入键或索引(留空表示随机)",
15
+ "search_value_info": "键: 精确匹配, 索引: 0到数据集大小-1",
16
+ "instruction_label": "📝 指令",
17
+ "instruction_placeholder": "无可用指令",
18
+ "metadata_title": "📋 项目元数据 (JSON)",
19
+ "metadata_label": "完整项目信息",
20
+ "source_audio": "源音频",
21
+ "target_audio": "目标音频",
22
+ "reference_audio": "参考音频",
23
+ "get_item_btn": "🔍 获取项目",
24
+ "use_src_checkbox": "使用数据集中的源音频",
25
+ "use_src_info": "勾选以使用数据集中的源音频",
26
+ "data_status_label": "📊 数据状态",
27
+ "data_status_default": "❌ 未导入数据集",
28
+ "autofill_btn": "📋 自动填充生成表单"
29
+ },
30
+ "service": {
31
+ "title": "🔧 服务配置",
32
+ "checkpoint_label": "检查点文件",
33
+ "checkpoint_info": "选择训练好的模型检查点文件(完整路径或文件名)",
34
+ "refresh_btn": "🔄 刷新",
35
+ "model_path_label": "主模型路径",
36
+ "model_path_info": "选择模型配置目录(从检查点自动扫描)",
37
+ "device_label": "设备",
38
+ "device_info": "处理设备(建议自动检测)",
39
+ "lm_model_path_label": "5Hz LM 模型路径",
40
+ "lm_model_path_info": "选择5Hz LM模型检查点(从检查点自动扫描)",
41
+ "backend_label": "5Hz LM 后端",
42
+ "backend_info": "选择5Hz LM的后端: vllm(更快)或pt(PyTorch, 更兼容)",
43
+ "init_llm_label": "初始化 5Hz LM",
44
+ "init_llm_info": "勾选以在服务初始化期间初始化5Hz LM",
45
+ "flash_attention_label": "使用Flash Attention",
46
+ "flash_attention_info_enabled": "启用flash attention以加快推理速度(需要flash_attn包)",
47
+ "flash_attention_info_disabled": "Flash attention不可用(未安装flash_attn包)",
48
+ "offload_cpu_label": "卸载到CPU",
49
+ "offload_cpu_info": "不使用时将模型卸载到CPU以节省GPU内存",
50
+ "offload_dit_cpu_label": "将DiT卸载到CPU",
51
+ "offload_dit_cpu_info": "将DiT卸载到CPU(需要启用卸载到CPU)",
52
+ "init_btn": "初始化服务",
53
+ "status_label": "状态",
54
+ "language_label": "界面语言",
55
+ "language_info": "选择界面语言"
56
+ },
57
+ "generation": {
58
+ "required_inputs": "📝 必需输入",
59
+ "task_type_label": "任务类型",
60
+ "task_type_info": "选择生成的任务类型",
61
+ "instruction_label": "指令",
62
+ "instruction_info": "指令根据任务类型自动生成",
63
+ "load_btn": "加载",
64
+ "track_name_label": "音轨名称",
65
+ "track_name_info": "为lego/extract任务选择音轨名称",
66
+ "track_classes_label": "音轨名称",
67
+ "track_classes_info": "为complete任务选择多个音轨类别",
68
+ "audio_uploads": "🎵 音频上传",
69
+ "reference_audio": "参考音频(可选)",
70
+ "source_audio": "源音频(可选)",
71
+ "convert_codes_btn": "转换为代码",
72
+ "lm_codes_hints": "🎼 LM 代码提示",
73
+ "lm_codes_label": "LM 代码提示",
74
+ "lm_codes_placeholder": "<|audio_code_10695|><|audio_code_54246|>...",
75
+ "lm_codes_info": "粘贴用于text2music生成的LM代码提示",
76
+ "lm_codes_sample": "LM 代码提示(样本 {n})",
77
+ "lm_codes_sample_info": "样本{n}的代码",
78
+ "transcribe_btn": "转录",
79
+ "repainting_controls": "🎨 重绘控制(秒)",
80
+ "repainting_start": "重绘开始",
81
+ "repainting_end": "重绘结束",
82
+ "caption_title": "📝 音乐描述",
83
+ "caption_label": "音乐描述(可选)",
84
+ "caption_placeholder": "一段平和的原声吉他旋律,配有柔和的人声...",
85
+ "caption_info": "描述风格、流派、乐器和情绪",
86
+ "sample_btn": "示例",
87
+ "lyrics_title": "📝 歌词",
88
+ "lyrics_label": "歌词(可选)",
89
+ "lyrics_placeholder": "[第一段]\\n在星空下\\n我感到如此活跃...",
90
+ "lyrics_info": "带有结构的歌曲歌词",
91
+ "instrumental_label": "纯音乐",
92
+ "optional_params": "⚙️ 可选参数",
93
+ "vocal_language_label": "人声语言(可选)",
94
+ "vocal_language_info": "纯音乐使用 `unknown`",
95
+ "bpm_label": "BPM(可选)",
96
+ "bpm_info": "留空表示N/A",
97
+ "keyscale_label": "调性(可选)",
98
+ "keyscale_placeholder": "留空表示N/A",
99
+ "keyscale_info": "A-G, #/♭, 大调/小调",
100
+ "timesig_label": "拍号(可选)",
101
+ "timesig_info": "2/4, 3/4, 4/4...",
102
+ "duration_label": "音频时长(秒)",
103
+ "duration_info": "使用-1表示随机",
104
+ "batch_size_label": "批量大小",
105
+ "batch_size_info": "要生成的音频数量(最多8个)",
106
+ "advanced_settings": "🔧 高级设置",
107
+ "inference_steps_label": "DiT 推理步数",
108
+ "inference_steps_info": "Turbo: 最多8, Base: 最多100",
109
+ "guidance_scale_label": "DiT 引导比例(仅支持base模型)",
110
+ "guidance_scale_info": "更高的值更紧密地遵循文本",
111
+ "seed_label": "种子",
112
+ "seed_info": "批量使用逗号分隔的值",
113
+ "random_seed_label": "随机种子",
114
+ "random_seed_info": "启用以自动生成种子",
115
+ "audio_format_label": "音频格式",
116
+ "audio_format_info": "保存文件的音频格式",
117
+ "use_adg_label": "使用 ADG",
118
+ "use_adg_info": "启用角域引导",
119
+ "cfg_interval_start": "CFG 间隔开始",
120
+ "cfg_interval_end": "CFG 间隔结束",
121
+ "lm_params_title": "🤖 LM 生成参数",
122
+ "lm_temperature_label": "LM 温度",
123
+ "lm_temperature_info": "5Hz LM温度(越高越随机)",
124
+ "lm_cfg_scale_label": "LM CFG 比例",
125
+ "lm_cfg_scale_info": "5Hz LM CFG (1.0 = 无CFG)",
126
+ "lm_top_k_label": "LM Top-K",
127
+ "lm_top_k_info": "Top-K (0 = 禁用)",
128
+ "lm_top_p_label": "LM Top-P",
129
+ "lm_top_p_info": "Top-P (1.0 = 禁用)",
130
+ "lm_negative_prompt_label": "LM 负面提示",
131
+ "lm_negative_prompt_placeholder": "输入CFG的负面提示(默认: NO USER INPUT)",
132
+ "lm_negative_prompt_info": "负面提示(当LM CFG比例 > 1.0时使用)",
133
+ "cot_metas_label": "CoT 元数据",
134
+ "cot_metas_info": "使用LM生成CoT元数据(取消勾选以跳过LM CoT生成)",
135
+ "cot_language_label": "CoT 语言",
136
+ "cot_language_info": "在CoT中生成语言(思维链)",
137
+ "constrained_debug_label": "约束解码调试",
138
+ "constrained_debug_info": "启用约束解码的调试日志(勾选以查看详细日志)",
139
+ "auto_score_label": "自动评分",
140
+ "auto_score_info": "自动计算所有生成音频的质量分数",
141
+ "lm_batch_chunk_label": "LM 批量块大小",
142
+ "lm_batch_chunk_info": "每个LM批量块的最大项目数(默认: 8, 受GPU内存限制)",
143
+ "codes_strength_label": "LM 代码强度",
144
+ "codes_strength_info": "控制使用LM生成代码的去噪步骤数量",
145
+ "cover_strength_label": "音频覆盖强度",
146
+ "cover_strength_info": "控制使用覆盖模式的去噪步骤数量",
147
+ "score_sensitivity_label": "质量评分敏感度",
148
+ "score_sensitivity_info": "更低 = 更敏感(默认: 1.0). 调整PMI如何映射到[0,1]",
149
+ "attention_focus_label": "输出注意力焦点分数(已禁用)",
150
+ "attention_focus_info": "输出注意力焦点分数分析",
151
+ "think_label": "思考",
152
+ "parallel_thinking_label": "并行思考",
153
+ "generate_btn": "🎵 生成音乐",
154
+ "autogen_label": "自动生成",
155
+ "caption_rewrite_label": "描述重写"
156
+ },
157
+ "results": {
158
+ "title": "🎵 结果",
159
+ "generated_music": "🎵 生成的音乐(样本 {n})",
160
+ "send_to_src_btn": "🔗 发送到源音频",
161
+ "save_btn": "💾 保存",
162
+ "score_btn": "📊 评分",
163
+ "quality_score_label": "质量分数(样本 {n})",
164
+ "quality_score_placeholder": "点击'评分'以计算基于困惑度的质量分数",
165
+ "generation_status": "生成状态",
166
+ "current_batch": "当前批次",
167
+ "batch_indicator": "批次 {current} / {total}",
168
+ "next_batch_status": "下一批次状态",
169
+ "prev_btn": "◀ 上一个",
170
+ "next_btn": "下一个 ▶",
171
+ "restore_params_btn": "↙️ 将这些设置应用到UI(恢复批次参数)",
172
+ "batch_results_title": "📁 批量结果和生成详情",
173
+ "all_files_label": "📁 所有生成的文件(下载)",
174
+ "generation_details": "生成详情",
175
+ "attention_analysis": "⚖️ 注意力焦点分数分析",
176
+ "attention_score": "注意力焦点分数(样本 {n})",
177
+ "lyric_timestamps": "歌词时间戳(样本 {n})",
178
+ "attention_heatmap": "注意力焦点分数热图(样本 {n})"
179
+ },
180
+ "messages": {
181
+ "no_audio_to_save": "❌ 没有要保存的音频",
182
+ "save_success": "✅ 已将音频和元数据保存到 {filename}",
183
+ "save_failed": "❌ 保存失败: {error}",
184
+ "no_file_selected": "⚠️ 未选择文件",
185
+ "params_loaded": "✅ 已从 {filename} 加载参数",
186
+ "invalid_json": "❌ 无效的JSON文件: {error}",
187
+ "load_error": "❌ 加载文件时出错: {error}",
188
+ "example_loaded": "📁 已从 {filename} 加载示例",
189
+ "example_failed": "解析JSON文件 {filename} 失败: {error}",
190
+ "example_error": "加载示例时出错: {error}",
191
+ "lm_generated": "🤖 使用LM生成的示例",
192
+ "lm_fallback": "使用LM生成示例失败,回退到示例目录",
193
+ "lm_not_initialized": "❌ 5Hz LM未初始化。请先初始化它。",
194
+ "autogen_enabled": "🔄 已启用自动生成 - 下一批次将在此之后生成",
195
+ "batch_ready": "✅ 批次 {n} 就绪!点击'下一个'查看。",
196
+ "batch_generating": "🔄 开始为批次 {n} 进行后台生成...",
197
+ "batch_failed": "❌ 后台生成失败: {error}",
198
+ "viewing_batch": "✅ 查看批次 {n}",
199
+ "at_first_batch": "已在第一批次",
200
+ "at_last_batch": "没有下一批次可用",
201
+ "batch_not_found": "在队列中未找到批次 {n}",
202
+ "no_batch_data": "没有要恢复的批次数据。",
203
+ "params_restored": "✅ 已从批次 {n} 恢复UI参数",
204
+ "scoring_failed": "❌ 错误: 未找到批次数据",
205
+ "no_codes": "❌ 没有可用的音频代码。请先生成音乐。",
206
+ "score_failed": "❌ 评分失败: {error}",
207
+ "score_error": "❌ 计算分数时出错: {error}"
208
+ }
209
+ }
acestep/gradio_ui/interfaces/__init__.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Components Module
3
+ Contains all Gradio interface component definitions and layouts
4
+ """
5
+ import gradio as gr
6
+ from acestep.gradio_ui.i18n import get_i18n, t
7
+ from acestep.gradio_ui.interfaces.dataset import create_dataset_section
8
+ from acestep.gradio_ui.interfaces.generation import create_generation_section
9
+ from acestep.gradio_ui.interfaces.result import create_results_section
10
+ from acestep.gradio_ui.events import setup_event_handlers
11
+
12
+
13
+ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_params=None, language='en') -> gr.Blocks:
14
+ """
15
+ Create Gradio interface
16
+
17
+ Args:
18
+ dit_handler: DiT handler instance
19
+ llm_handler: LM handler instance
20
+ dataset_handler: Dataset handler instance
21
+ init_params: Dictionary containing initialization parameters and state.
22
+ If None, service will not be pre-initialized.
23
+ language: UI language code ('en', 'zh', 'ja', default: 'en')
24
+
25
+ Returns:
26
+ Gradio Blocks instance
27
+ """
28
+ # Initialize i18n with selected language
29
+ i18n = get_i18n(language)
30
+
31
+ with gr.Blocks(
32
+ title=t("app.title"),
33
+ theme=gr.themes.Soft(),
34
+ css="""
35
+ .main-header {
36
+ text-align: center;
37
+ margin-bottom: 2rem;
38
+ }
39
+ .section-header {
40
+ background: linear-gradient(90deg, #4CAF50, #45a049);
41
+ color: white;
42
+ padding: 10px;
43
+ border-radius: 5px;
44
+ margin: 10px 0;
45
+ }
46
+ .lm-hints-row {
47
+ align-items: stretch;
48
+ }
49
+ .lm-hints-col {
50
+ display: flex;
51
+ }
52
+ .lm-hints-col > div {
53
+ flex: 1;
54
+ display: flex;
55
+ }
56
+ .lm-hints-btn button {
57
+ height: 100%;
58
+ width: 100%;
59
+ }
60
+ """
61
+ ) as demo:
62
+
63
+ gr.HTML(f"""
64
+ <div class="main-header">
65
+ <h1>{t("app.title")}</h1>
66
+ <p>{t("app.subtitle")}</p>
67
+ </div>
68
+ """)
69
+
70
+ # Dataset Explorer Section
71
+ dataset_section = create_dataset_section(dataset_handler)
72
+
73
+ # Generation Section (pass init_params and language to support pre-initialization)
74
+ generation_section = create_generation_section(dit_handler, llm_handler, init_params=init_params, language=language)
75
+
76
+ # Results Section
77
+ results_section = create_results_section(dit_handler)
78
+
79
+ # Connect event handlers
80
+ setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, dataset_section, generation_section, results_section)
81
+
82
+ return demo
acestep/gradio_ui/interfaces/dataset.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Dataset Section Module
3
+ Contains dataset explorer section component definitions
4
+ """
5
+ import gradio as gr
6
+
7
+
8
+ def create_dataset_section(dataset_handler) -> dict:
9
+ """Create dataset explorer section"""
10
+ with gr.Accordion("📊 Dataset Explorer", open=False, visible=False):
11
+ with gr.Row(equal_height=True):
12
+ dataset_type = gr.Dropdown(
13
+ choices=["train", "test"],
14
+ value="train",
15
+ label="Dataset",
16
+ info="Choose dataset to explore",
17
+ scale=2
18
+ )
19
+ import_dataset_btn = gr.Button("📥 Import Dataset", variant="primary", scale=1)
20
+
21
+ search_type = gr.Dropdown(
22
+ choices=["keys", "idx", "random"],
23
+ value="random",
24
+ label="Search Type",
25
+ info="How to find items",
26
+ scale=1
27
+ )
28
+ search_value = gr.Textbox(
29
+ label="Search Value",
30
+ placeholder="Enter keys or index (leave empty for random)",
31
+ info="Keys: exact match, Index: 0 to dataset size-1",
32
+ scale=2
33
+ )
34
+
35
+ instruction_display = gr.Textbox(
36
+ label="📝 Instruction",
37
+ interactive=False,
38
+ placeholder="No instruction available",
39
+ lines=1
40
+ )
41
+
42
+ repaint_viz_plot = gr.Plot()
43
+
44
+ with gr.Accordion("📋 Item Metadata (JSON)", open=False):
45
+ item_info_json = gr.Code(
46
+ label="Complete Item Information",
47
+ language="json",
48
+ interactive=False,
49
+ lines=15
50
+ )
51
+
52
+ with gr.Row(equal_height=True):
53
+ item_src_audio = gr.Audio(
54
+ label="Source Audio",
55
+ type="filepath",
56
+ interactive=False,
57
+ scale=8
58
+ )
59
+ get_item_btn = gr.Button("🔍 Get Item", variant="secondary", interactive=False, scale=2)
60
+
61
+ with gr.Row(equal_height=True):
62
+ item_target_audio = gr.Audio(
63
+ label="Target Audio",
64
+ type="filepath",
65
+ interactive=False,
66
+ scale=8
67
+ )
68
+ item_refer_audio = gr.Audio(
69
+ label="Reference Audio",
70
+ type="filepath",
71
+ interactive=False,
72
+ scale=2
73
+ )
74
+
75
+ with gr.Row():
76
+ use_src_checkbox = gr.Checkbox(
77
+ label="Use Source Audio from Dataset",
78
+ value=True,
79
+ info="Check to use the source audio from dataset"
80
+ )
81
+
82
+ data_status = gr.Textbox(label="📊 Data Status", interactive=False, value="❌ No dataset imported")
83
+ auto_fill_btn = gr.Button("📋 Auto-fill Generation Form", variant="primary")
84
+
85
+ return {
86
+ "dataset_type": dataset_type,
87
+ "import_dataset_btn": import_dataset_btn,
88
+ "search_type": search_type,
89
+ "search_value": search_value,
90
+ "instruction_display": instruction_display,
91
+ "repaint_viz_plot": repaint_viz_plot,
92
+ "item_info_json": item_info_json,
93
+ "item_src_audio": item_src_audio,
94
+ "get_item_btn": get_item_btn,
95
+ "item_target_audio": item_target_audio,
96
+ "item_refer_audio": item_refer_audio,
97
+ "use_src_checkbox": use_src_checkbox,
98
+ "data_status": data_status,
99
+ "auto_fill_btn": auto_fill_btn,
100
+ }
101
+
acestep/gradio_ui/interfaces/generation.py ADDED
@@ -0,0 +1,683 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Generation Section Module
3
+ Contains generation section component definitions
4
+ """
5
+ import gradio as gr
6
+ from acestep.constants import (
7
+ VALID_LANGUAGES,
8
+ TRACK_NAMES,
9
+ TASK_TYPES_TURBO,
10
+ TASK_TYPES_BASE,
11
+ DEFAULT_DIT_INSTRUCTION,
12
+ )
13
+ from acestep.gradio_ui.i18n import t
14
+
15
+
16
+ def create_generation_section(dit_handler, llm_handler, init_params=None, language='en') -> dict:
17
+ """Create generation section
18
+
19
+ Args:
20
+ dit_handler: DiT handler instance
21
+ llm_handler: LM handler instance
22
+ init_params: Dictionary containing initialization parameters and state.
23
+ If None, service will not be pre-initialized.
24
+ language: UI language code ('en', 'zh', 'ja')
25
+ """
26
+ # Check if service is pre-initialized
27
+ service_pre_initialized = init_params is not None and init_params.get('pre_initialized', False)
28
+
29
+ # Get current language from init_params if available
30
+ current_language = init_params.get('language', language) if init_params else language
31
+
32
+ with gr.Group():
33
+ # Service Configuration - collapse if pre-initialized, hide if in service mode
34
+ accordion_open = not service_pre_initialized
35
+ accordion_visible = not service_pre_initialized # Hide when running in service mode
36
+ with gr.Accordion(t("service.title"), open=accordion_open, visible=accordion_visible) as service_config_accordion:
37
+ # Language selector at the top
38
+ with gr.Row():
39
+ language_dropdown = gr.Dropdown(
40
+ choices=[
41
+ ("English", "en"),
42
+ ("中文", "zh"),
43
+ ("日本語", "ja"),
44
+ ],
45
+ value=current_language,
46
+ label=t("service.language_label"),
47
+ info=t("service.language_info"),
48
+ scale=1,
49
+ )
50
+
51
+ # Dropdown options section - all dropdowns grouped together
52
+ with gr.Row(equal_height=True):
53
+ with gr.Column(scale=4):
54
+ # Set checkpoint value from init_params if pre-initialized
55
+ checkpoint_value = init_params.get('checkpoint') if service_pre_initialized else None
56
+ checkpoint_dropdown = gr.Dropdown(
57
+ label=t("service.checkpoint_label"),
58
+ choices=dit_handler.get_available_checkpoints(),
59
+ value=checkpoint_value,
60
+ info=t("service.checkpoint_info")
61
+ )
62
+ with gr.Column(scale=1, min_width=90):
63
+ refresh_btn = gr.Button(t("service.refresh_btn"), size="sm")
64
+
65
+ with gr.Row():
66
+ # Get available acestep-v15- model list
67
+ available_models = dit_handler.get_available_acestep_v15_models()
68
+ default_model = "acestep-v15-turbo" if "acestep-v15-turbo" in available_models else (available_models[0] if available_models else None)
69
+
70
+ # Set config_path value from init_params if pre-initialized
71
+ config_path_value = init_params.get('config_path', default_model) if service_pre_initialized else default_model
72
+ config_path = gr.Dropdown(
73
+ label=t("service.model_path_label"),
74
+ choices=available_models,
75
+ value=config_path_value,
76
+ info=t("service.model_path_info")
77
+ )
78
+ # Set device value from init_params if pre-initialized
79
+ device_value = init_params.get('device', 'auto') if service_pre_initialized else 'auto'
80
+ device = gr.Dropdown(
81
+ choices=["auto", "cuda", "cpu"],
82
+ value=device_value,
83
+ label=t("service.device_label"),
84
+ info=t("service.device_info")
85
+ )
86
+
87
+ with gr.Row():
88
+ # Get available 5Hz LM model list
89
+ available_lm_models = llm_handler.get_available_5hz_lm_models()
90
+ default_lm_model = "acestep-5Hz-lm-0.6B" if "acestep-5Hz-lm-0.6B" in available_lm_models else (available_lm_models[0] if available_lm_models else None)
91
+
92
+ # Set lm_model_path value from init_params if pre-initialized
93
+ lm_model_path_value = init_params.get('lm_model_path', default_lm_model) if service_pre_initialized else default_lm_model
94
+ lm_model_path = gr.Dropdown(
95
+ label=t("service.lm_model_path_label"),
96
+ choices=available_lm_models,
97
+ value=lm_model_path_value,
98
+ info=t("service.lm_model_path_info")
99
+ )
100
+ # Set backend value from init_params if pre-initialized
101
+ backend_value = init_params.get('backend', 'vllm') if service_pre_initialized else 'vllm'
102
+ backend_dropdown = gr.Dropdown(
103
+ choices=["vllm", "pt"],
104
+ value=backend_value,
105
+ label=t("service.backend_label"),
106
+ info=t("service.backend_info")
107
+ )
108
+
109
+ # Checkbox options section - all checkboxes grouped together
110
+ with gr.Row():
111
+ # Set init_llm value from init_params if pre-initialized
112
+ init_llm_value = init_params.get('init_llm', True) if service_pre_initialized else True
113
+ init_llm_checkbox = gr.Checkbox(
114
+ label=t("service.init_llm_label"),
115
+ value=init_llm_value,
116
+ info=t("service.init_llm_info"),
117
+ )
118
+ # Auto-detect flash attention availability
119
+ flash_attn_available = dit_handler.is_flash_attention_available()
120
+ # Set use_flash_attention value from init_params if pre-initialized
121
+ use_flash_attention_value = init_params.get('use_flash_attention', flash_attn_available) if service_pre_initialized else flash_attn_available
122
+ use_flash_attention_checkbox = gr.Checkbox(
123
+ label=t("service.flash_attention_label"),
124
+ value=use_flash_attention_value,
125
+ interactive=flash_attn_available,
126
+ info=t("service.flash_attention_info_enabled") if flash_attn_available else t("service.flash_attention_info_disabled")
127
+ )
128
+ # Set offload_to_cpu value from init_params if pre-initialized
129
+ offload_to_cpu_value = init_params.get('offload_to_cpu', False) if service_pre_initialized else False
130
+ offload_to_cpu_checkbox = gr.Checkbox(
131
+ label=t("service.offload_cpu_label"),
132
+ value=offload_to_cpu_value,
133
+ info=t("service.offload_cpu_info")
134
+ )
135
+ # Set offload_dit_to_cpu value from init_params if pre-initialized
136
+ offload_dit_to_cpu_value = init_params.get('offload_dit_to_cpu', False) if service_pre_initialized else False
137
+ offload_dit_to_cpu_checkbox = gr.Checkbox(
138
+ label=t("service.offload_dit_cpu_label"),
139
+ value=offload_dit_to_cpu_value,
140
+ info=t("service.offload_dit_cpu_info")
141
+ )
142
+
143
+ init_btn = gr.Button(t("service.init_btn"), variant="primary", size="lg")
144
+ # Set init_status value from init_params if pre-initialized
145
+ init_status_value = init_params.get('init_status', '') if service_pre_initialized else ''
146
+ init_status = gr.Textbox(label=t("service.status_label"), interactive=False, lines=3, value=init_status_value)
147
+
148
+ # Inputs
149
+ with gr.Row():
150
+ with gr.Column(scale=2):
151
+ with gr.Accordion(t("generation.required_inputs"), open=True):
152
+ # Task type
153
+ # Determine initial task_type choices based on default model
154
+ default_model_lower = (default_model or "").lower()
155
+ if "turbo" in default_model_lower:
156
+ initial_task_choices = TASK_TYPES_TURBO
157
+ else:
158
+ initial_task_choices = TASK_TYPES_BASE
159
+
160
+ with gr.Row(equal_height=True):
161
+ with gr.Column(scale=2):
162
+ task_type = gr.Dropdown(
163
+ choices=initial_task_choices,
164
+ value="text2music",
165
+ label=t("generation.task_type_label"),
166
+ info=t("generation.task_type_info"),
167
+ )
168
+ with gr.Column(scale=7):
169
+ instruction_display_gen = gr.Textbox(
170
+ label=t("generation.instruction_label"),
171
+ value=DEFAULT_DIT_INSTRUCTION,
172
+ interactive=False,
173
+ lines=1,
174
+ info=t("generation.instruction_info"),
175
+ )
176
+ with gr.Column(scale=1, min_width=100):
177
+ load_file = gr.UploadButton(
178
+ t("generation.load_btn"),
179
+ file_types=[".json"],
180
+ file_count="single",
181
+ variant="secondary",
182
+ size="sm",
183
+ )
184
+
185
+ track_name = gr.Dropdown(
186
+ choices=TRACK_NAMES,
187
+ value=None,
188
+ label=t("generation.track_name_label"),
189
+ info=t("generation.track_name_info"),
190
+ visible=False
191
+ )
192
+
193
+ complete_track_classes = gr.CheckboxGroup(
194
+ choices=TRACK_NAMES,
195
+ label=t("generation.track_classes_label"),
196
+ info=t("generation.track_classes_info"),
197
+ visible=False
198
+ )
199
+
200
+ # Audio uploads
201
+ audio_uploads_accordion = gr.Accordion(t("generation.audio_uploads"), open=False)
202
+ with audio_uploads_accordion:
203
+ with gr.Row(equal_height=True):
204
+ with gr.Column(scale=2):
205
+ reference_audio = gr.Audio(
206
+ label=t("generation.reference_audio"),
207
+ type="filepath",
208
+ )
209
+ with gr.Column(scale=7):
210
+ src_audio = gr.Audio(
211
+ label=t("generation.source_audio"),
212
+ type="filepath",
213
+ )
214
+ with gr.Column(scale=1, min_width=80):
215
+ convert_src_to_codes_btn = gr.Button(
216
+ t("generation.convert_codes_btn"),
217
+ variant="secondary",
218
+ size="sm"
219
+ )
220
+
221
+ # Audio Codes for text2music (dynamic display based on batch size and allow_lm_batch)
222
+ with gr.Accordion(t("generation.lm_codes_hints"), open=False, visible=True) as text2music_audio_codes_group:
223
+ # Single codes input (default mode)
224
+ with gr.Row(equal_height=True, visible=True) as codes_single_row:
225
+ text2music_audio_code_string = gr.Textbox(
226
+ label=t("generation.lm_codes_label"),
227
+ placeholder=t("generation.lm_codes_placeholder"),
228
+ lines=6,
229
+ info=t("generation.lm_codes_info"),
230
+ scale=9,
231
+ )
232
+ transcribe_btn = gr.Button(
233
+ t("generation.transcribe_btn"),
234
+ variant="secondary",
235
+ size="sm",
236
+ scale=1,
237
+ )
238
+
239
+ # Multiple codes inputs (batch mode when allow_lm_batch is enabled)
240
+ with gr.Row(visible=False) as codes_batch_row:
241
+ with gr.Column(visible=True) as codes_col_1:
242
+ text2music_audio_code_string_1 = gr.Textbox(
243
+ label=t("generation.lm_codes_sample", n=1),
244
+ placeholder="<|audio_code_...|>",
245
+ lines=4,
246
+ info=t("generation.lm_codes_sample_info", n=1),
247
+ )
248
+ with gr.Column(visible=True) as codes_col_2:
249
+ text2music_audio_code_string_2 = gr.Textbox(
250
+ label=t("generation.lm_codes_sample", n=2),
251
+ placeholder="<|audio_code_...|>",
252
+ lines=4,
253
+ info=t("generation.lm_codes_sample_info", n=2),
254
+ )
255
+ with gr.Column(visible=False) as codes_col_3:
256
+ text2music_audio_code_string_3 = gr.Textbox(
257
+ label=t("generation.lm_codes_sample", n=3),
258
+ placeholder="<|audio_code_...|>",
259
+ lines=4,
260
+ info=t("generation.lm_codes_sample_info", n=3),
261
+ )
262
+ with gr.Column(visible=False) as codes_col_4:
263
+ text2music_audio_code_string_4 = gr.Textbox(
264
+ label=t("generation.lm_codes_sample", n=4),
265
+ placeholder="<|audio_code_...|>",
266
+ lines=4,
267
+ info=t("generation.lm_codes_sample_info", n=4),
268
+ )
269
+
270
+ # Additional row for codes 5-8
271
+ with gr.Row(visible=False) as codes_batch_row_2:
272
+ with gr.Column() as codes_col_5:
273
+ text2music_audio_code_string_5 = gr.Textbox(
274
+ label=t("generation.lm_codes_sample", n=5),
275
+ placeholder="<|audio_code_...|>",
276
+ lines=4,
277
+ info=t("generation.lm_codes_sample_info", n=5),
278
+ )
279
+ with gr.Column() as codes_col_6:
280
+ text2music_audio_code_string_6 = gr.Textbox(
281
+ label=t("generation.lm_codes_sample", n=6),
282
+ placeholder="<|audio_code_...|>",
283
+ lines=4,
284
+ info=t("generation.lm_codes_sample_info", n=6),
285
+ )
286
+ with gr.Column() as codes_col_7:
287
+ text2music_audio_code_string_7 = gr.Textbox(
288
+ label=t("generation.lm_codes_sample", n=7),
289
+ placeholder="<|audio_code_...|>",
290
+ lines=4,
291
+ info=t("generation.lm_codes_sample_info", n=7),
292
+ )
293
+ with gr.Column() as codes_col_8:
294
+ text2music_audio_code_string_8 = gr.Textbox(
295
+ label=t("generation.lm_codes_sample", n=8),
296
+ placeholder="<|audio_code_...|>",
297
+ lines=4,
298
+ info=t("generation.lm_codes_sample_info", n=8),
299
+ )
300
+
301
+ # Repainting controls
302
+ with gr.Group(visible=False) as repainting_group:
303
+ gr.HTML(f"<h5>{t('generation.repainting_controls')}</h5>")
304
+ with gr.Row():
305
+ repainting_start = gr.Number(
306
+ label=t("generation.repainting_start"),
307
+ value=0.0,
308
+ step=0.1,
309
+ )
310
+ repainting_end = gr.Number(
311
+ label=t("generation.repainting_end"),
312
+ value=-1,
313
+ minimum=-1,
314
+ step=0.1,
315
+ )
316
+
317
+ # Music Caption
318
+ with gr.Accordion(t("generation.caption_title"), open=True):
319
+ with gr.Row(equal_height=True):
320
+ captions = gr.Textbox(
321
+ label=t("generation.caption_label"),
322
+ placeholder=t("generation.caption_placeholder"),
323
+ lines=3,
324
+ info=t("generation.caption_info"),
325
+ scale=9,
326
+ )
327
+ sample_btn = gr.Button(
328
+ t("generation.sample_btn"),
329
+ variant="secondary",
330
+ size="sm",
331
+ scale=1,
332
+ )
333
+
334
+ # Lyrics
335
+ with gr.Accordion(t("generation.lyrics_title"), open=True):
336
+ lyrics = gr.Textbox(
337
+ label=t("generation.lyrics_label"),
338
+ placeholder=t("generation.lyrics_placeholder"),
339
+ lines=8,
340
+ info=t("generation.lyrics_info")
341
+ )
342
+ instrumental_checkbox = gr.Checkbox(
343
+ label=t("generation.instrumental_label"),
344
+ value=False,
345
+ scale=1,
346
+ )
347
+
348
+ # Optional Parameters
349
+ with gr.Accordion(t("generation.optional_params"), open=True):
350
+ with gr.Row():
351
+ vocal_language = gr.Dropdown(
352
+ choices=VALID_LANGUAGES,
353
+ value="unknown",
354
+ label=t("generation.vocal_language_label"),
355
+ allow_custom_value=True,
356
+ info=t("generation.vocal_language_info")
357
+ )
358
+ bpm = gr.Number(
359
+ label=t("generation.bpm_label"),
360
+ value=None,
361
+ step=1,
362
+ info=t("generation.bpm_info")
363
+ )
364
+ key_scale = gr.Textbox(
365
+ label=t("generation.keyscale_label"),
366
+ placeholder=t("generation.keyscale_placeholder"),
367
+ value="",
368
+ info=t("generation.keyscale_info")
369
+ )
370
+ time_signature = gr.Dropdown(
371
+ choices=["2", "3", "4", "N/A", ""],
372
+ value="",
373
+ label=t("generation.timesig_label"),
374
+ allow_custom_value=True,
375
+ info=t("generation.timesig_info")
376
+ )
377
+ audio_duration = gr.Number(
378
+ label=t("generation.duration_label"),
379
+ value=-1,
380
+ minimum=-1,
381
+ maximum=600.0,
382
+ step=0.1,
383
+ info=t("generation.duration_info")
384
+ )
385
+ batch_size_input = gr.Number(
386
+ label=t("generation.batch_size_label"),
387
+ value=2,
388
+ minimum=1,
389
+ maximum=8,
390
+ step=1,
391
+ info=t("generation.batch_size_info")
392
+ )
393
+
394
+ # Advanced Settings
395
+ with gr.Accordion(t("generation.advanced_settings"), open=False):
396
+ with gr.Row():
397
+ inference_steps = gr.Slider(
398
+ minimum=1,
399
+ maximum=8,
400
+ value=8,
401
+ step=1,
402
+ label=t("generation.inference_steps_label"),
403
+ info=t("generation.inference_steps_info")
404
+ )
405
+ guidance_scale = gr.Slider(
406
+ minimum=1.0,
407
+ maximum=15.0,
408
+ value=7.0,
409
+ step=0.1,
410
+ label=t("generation.guidance_scale_label"),
411
+ info=t("generation.guidance_scale_info"),
412
+ visible=False
413
+ )
414
+ with gr.Column():
415
+ seed = gr.Textbox(
416
+ label=t("generation.seed_label"),
417
+ value="-1",
418
+ info=t("generation.seed_info")
419
+ )
420
+ random_seed_checkbox = gr.Checkbox(
421
+ label=t("generation.random_seed_label"),
422
+ value=True,
423
+ info=t("generation.random_seed_info")
424
+ )
425
+ audio_format = gr.Dropdown(
426
+ choices=["mp3", "flac"],
427
+ value="mp3",
428
+ label=t("generation.audio_format_label"),
429
+ info=t("generation.audio_format_info")
430
+ )
431
+
432
+ with gr.Row():
433
+ use_adg = gr.Checkbox(
434
+ label=t("generation.use_adg_label"),
435
+ value=False,
436
+ info=t("generation.use_adg_info"),
437
+ visible=False
438
+ )
439
+
440
+ with gr.Row():
441
+ cfg_interval_start = gr.Slider(
442
+ minimum=0.0,
443
+ maximum=1.0,
444
+ value=0.0,
445
+ step=0.01,
446
+ label=t("generation.cfg_interval_start"),
447
+ visible=False
448
+ )
449
+ cfg_interval_end = gr.Slider(
450
+ minimum=0.0,
451
+ maximum=1.0,
452
+ value=1.0,
453
+ step=0.01,
454
+ label=t("generation.cfg_interval_end"),
455
+ visible=False
456
+ )
457
+
458
+ # LM (Language Model) Parameters
459
+ gr.HTML(f"<h4>{t('generation.lm_params_title')}</h4>")
460
+ with gr.Row():
461
+ lm_temperature = gr.Slider(
462
+ label=t("generation.lm_temperature_label"),
463
+ minimum=0.0,
464
+ maximum=2.0,
465
+ value=0.85,
466
+ step=0.1,
467
+ scale=1,
468
+ info=t("generation.lm_temperature_info")
469
+ )
470
+ lm_cfg_scale = gr.Slider(
471
+ label=t("generation.lm_cfg_scale_label"),
472
+ minimum=1.0,
473
+ maximum=3.0,
474
+ value=2.0,
475
+ step=0.1,
476
+ scale=1,
477
+ info=t("generation.lm_cfg_scale_info")
478
+ )
479
+ lm_top_k = gr.Slider(
480
+ label=t("generation.lm_top_k_label"),
481
+ minimum=0,
482
+ maximum=100,
483
+ value=0,
484
+ step=1,
485
+ scale=1,
486
+ info=t("generation.lm_top_k_info")
487
+ )
488
+ lm_top_p = gr.Slider(
489
+ label=t("generation.lm_top_p_label"),
490
+ minimum=0.0,
491
+ maximum=1.0,
492
+ value=0.9,
493
+ step=0.01,
494
+ scale=1,
495
+ info=t("generation.lm_top_p_info")
496
+ )
497
+
498
+ with gr.Row():
499
+ lm_negative_prompt = gr.Textbox(
500
+ label=t("generation.lm_negative_prompt_label"),
501
+ value="NO USER INPUT",
502
+ placeholder=t("generation.lm_negative_prompt_placeholder"),
503
+ info=t("generation.lm_negative_prompt_info"),
504
+ lines=2,
505
+ scale=2,
506
+ )
507
+
508
+ with gr.Row():
509
+ use_cot_metas = gr.Checkbox(
510
+ label=t("generation.cot_metas_label"),
511
+ value=True,
512
+ info=t("generation.cot_metas_info"),
513
+ scale=1,
514
+ )
515
+ use_cot_language = gr.Checkbox(
516
+ label=t("generation.cot_language_label"),
517
+ value=True,
518
+ info=t("generation.cot_language_info"),
519
+ scale=1,
520
+ )
521
+ constrained_decoding_debug = gr.Checkbox(
522
+ label=t("generation.constrained_debug_label"),
523
+ value=False,
524
+ info=t("generation.constrained_debug_info"),
525
+ scale=1,
526
+ )
527
+
528
+ with gr.Row():
529
+ auto_score = gr.Checkbox(
530
+ label=t("generation.auto_score_label"),
531
+ value=False,
532
+ info=t("generation.auto_score_info"),
533
+ scale=1,
534
+ )
535
+ lm_batch_chunk_size = gr.Number(
536
+ label=t("generation.lm_batch_chunk_label"),
537
+ value=8,
538
+ minimum=1,
539
+ maximum=32,
540
+ step=1,
541
+ info=t("generation.lm_batch_chunk_info"),
542
+ scale=1,
543
+ )
544
+
545
+ with gr.Row():
546
+ audio_cover_strength = gr.Slider(
547
+ minimum=0.0,
548
+ maximum=1.0,
549
+ value=1.0,
550
+ step=0.01,
551
+ label=t("generation.codes_strength_label"),
552
+ info=t("generation.codes_strength_info"),
553
+ scale=1,
554
+ )
555
+ score_scale = gr.Slider(
556
+ minimum=0.01,
557
+ maximum=1.0,
558
+ value=0.5,
559
+ step=0.01,
560
+ label=t("generation.score_sensitivity_label"),
561
+ info=t("generation.score_sensitivity_info"),
562
+ scale=1,
563
+ )
564
+ output_alignment_preference = gr.Checkbox(
565
+ label=t("generation.attention_focus_label"),
566
+ value=False,
567
+ info=t("generation.attention_focus_info"),
568
+ interactive=False,
569
+ scale=1,
570
+ )
571
+
572
+ # Set generate_btn to interactive if service is pre-initialized
573
+ generate_btn_interactive = init_params.get('enable_generate', False) if service_pre_initialized else False
574
+ with gr.Row(equal_height=True):
575
+ think_checkbox = gr.Checkbox(
576
+ label=t("generation.think_label"),
577
+ value=True,
578
+ scale=1,
579
+ )
580
+ allow_lm_batch = gr.Checkbox(
581
+ label=t("generation.parallel_thinking_label"),
582
+ value=True,
583
+ scale=1,
584
+ )
585
+ generate_btn = gr.Button(t("generation.generate_btn"), variant="primary", size="lg", interactive=generate_btn_interactive, scale=9)
586
+ autogen_checkbox = gr.Checkbox(
587
+ label=t("generation.autogen_label"),
588
+ value=True,
589
+ scale=1,
590
+ )
591
+ use_cot_caption = gr.Checkbox(
592
+ label=t("generation.caption_rewrite_label"),
593
+ value=True,
594
+ scale=1,
595
+ )
596
+
597
+ return {
598
+ "service_config_accordion": service_config_accordion,
599
+ "language_dropdown": language_dropdown,
600
+ "checkpoint_dropdown": checkpoint_dropdown,
601
+ "refresh_btn": refresh_btn,
602
+ "config_path": config_path,
603
+ "device": device,
604
+ "init_btn": init_btn,
605
+ "init_status": init_status,
606
+ "lm_model_path": lm_model_path,
607
+ "init_llm_checkbox": init_llm_checkbox,
608
+ "backend_dropdown": backend_dropdown,
609
+ "use_flash_attention_checkbox": use_flash_attention_checkbox,
610
+ "offload_to_cpu_checkbox": offload_to_cpu_checkbox,
611
+ "offload_dit_to_cpu_checkbox": offload_dit_to_cpu_checkbox,
612
+ "task_type": task_type,
613
+ "instruction_display_gen": instruction_display_gen,
614
+ "track_name": track_name,
615
+ "complete_track_classes": complete_track_classes,
616
+ "audio_uploads_accordion": audio_uploads_accordion,
617
+ "reference_audio": reference_audio,
618
+ "src_audio": src_audio,
619
+ "convert_src_to_codes_btn": convert_src_to_codes_btn,
620
+ "text2music_audio_code_string": text2music_audio_code_string,
621
+ "transcribe_btn": transcribe_btn,
622
+ "text2music_audio_codes_group": text2music_audio_codes_group,
623
+ "lm_temperature": lm_temperature,
624
+ "lm_cfg_scale": lm_cfg_scale,
625
+ "lm_top_k": lm_top_k,
626
+ "lm_top_p": lm_top_p,
627
+ "lm_negative_prompt": lm_negative_prompt,
628
+ "use_cot_metas": use_cot_metas,
629
+ "use_cot_caption": use_cot_caption,
630
+ "use_cot_language": use_cot_language,
631
+ "repainting_group": repainting_group,
632
+ "repainting_start": repainting_start,
633
+ "repainting_end": repainting_end,
634
+ "audio_cover_strength": audio_cover_strength,
635
+ "captions": captions,
636
+ "sample_btn": sample_btn,
637
+ "load_file": load_file,
638
+ "lyrics": lyrics,
639
+ "vocal_language": vocal_language,
640
+ "bpm": bpm,
641
+ "key_scale": key_scale,
642
+ "time_signature": time_signature,
643
+ "audio_duration": audio_duration,
644
+ "batch_size_input": batch_size_input,
645
+ "inference_steps": inference_steps,
646
+ "guidance_scale": guidance_scale,
647
+ "seed": seed,
648
+ "random_seed_checkbox": random_seed_checkbox,
649
+ "use_adg": use_adg,
650
+ "cfg_interval_start": cfg_interval_start,
651
+ "cfg_interval_end": cfg_interval_end,
652
+ "audio_format": audio_format,
653
+ "output_alignment_preference": output_alignment_preference,
654
+ "think_checkbox": think_checkbox,
655
+ "autogen_checkbox": autogen_checkbox,
656
+ "generate_btn": generate_btn,
657
+ "instrumental_checkbox": instrumental_checkbox,
658
+ "constrained_decoding_debug": constrained_decoding_debug,
659
+ "score_scale": score_scale,
660
+ "allow_lm_batch": allow_lm_batch,
661
+ "auto_score": auto_score,
662
+ "lm_batch_chunk_size": lm_batch_chunk_size,
663
+ "codes_single_row": codes_single_row,
664
+ "codes_batch_row": codes_batch_row,
665
+ "codes_batch_row_2": codes_batch_row_2,
666
+ "text2music_audio_code_string_1": text2music_audio_code_string_1,
667
+ "text2music_audio_code_string_2": text2music_audio_code_string_2,
668
+ "text2music_audio_code_string_3": text2music_audio_code_string_3,
669
+ "text2music_audio_code_string_4": text2music_audio_code_string_4,
670
+ "text2music_audio_code_string_5": text2music_audio_code_string_5,
671
+ "text2music_audio_code_string_6": text2music_audio_code_string_6,
672
+ "text2music_audio_code_string_7": text2music_audio_code_string_7,
673
+ "text2music_audio_code_string_8": text2music_audio_code_string_8,
674
+ "codes_col_1": codes_col_1,
675
+ "codes_col_2": codes_col_2,
676
+ "codes_col_3": codes_col_3,
677
+ "codes_col_4": codes_col_4,
678
+ "codes_col_5": codes_col_5,
679
+ "codes_col_6": codes_col_6,
680
+ "codes_col_7": codes_col_7,
681
+ "codes_col_8": codes_col_8,
682
+ }
683
+
acestep/gradio_ui/interfaces/result.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI Results Section Module
3
+ Contains results display section component definitions
4
+ """
5
+ import gradio as gr
6
+ from acestep.gradio_ui.i18n import t
7
+
8
+
9
+ def create_results_section(dit_handler) -> dict:
10
+ """Create results display section"""
11
+ with gr.Accordion(t("results.title"), open=True):
12
+ # Hidden state to store LM-generated metadata
13
+ lm_metadata_state = gr.State(value=None)
14
+
15
+ # Hidden state to track if caption/metadata is from formatted source (LM/transcription)
16
+ is_format_caption_state = gr.State(value=False)
17
+
18
+ # Batch management states
19
+ current_batch_index = gr.State(value=0) # Currently displayed batch index
20
+ total_batches = gr.State(value=1) # Total number of batches generated
21
+ batch_queue = gr.State(value={}) # Dictionary storing all batch data
22
+ generation_params_state = gr.State(value={}) # Store generation parameters for next batches
23
+ is_generating_background = gr.State(value=False) # Background generation flag
24
+
25
+ # All audio components in one row with dynamic visibility
26
+ with gr.Row():
27
+ with gr.Column(visible=True) as audio_col_1:
28
+ generated_audio_1 = gr.Audio(
29
+ label=t("results.generated_music", n=1),
30
+ type="filepath",
31
+ interactive=False
32
+ )
33
+ with gr.Row(equal_height=True):
34
+ send_to_src_btn_1 = gr.Button(
35
+ t("results.send_to_src_btn"),
36
+ variant="secondary",
37
+ size="sm",
38
+ scale=1
39
+ )
40
+ save_btn_1 = gr.Button(
41
+ t("results.save_btn"),
42
+ variant="primary",
43
+ size="sm",
44
+ scale=1
45
+ )
46
+ score_btn_1 = gr.Button(
47
+ t("results.score_btn"),
48
+ variant="secondary",
49
+ size="sm",
50
+ scale=1
51
+ )
52
+ score_display_1 = gr.Textbox(
53
+ label=t("results.quality_score_label", n=1),
54
+ interactive=False,
55
+ placeholder=t("results.quality_score_placeholder")
56
+ )
57
+ with gr.Column(visible=True) as audio_col_2:
58
+ generated_audio_2 = gr.Audio(
59
+ label=t("results.generated_music", n=2),
60
+ type="filepath",
61
+ interactive=False
62
+ )
63
+ with gr.Row(equal_height=True):
64
+ send_to_src_btn_2 = gr.Button(
65
+ t("results.send_to_src_btn"),
66
+ variant="secondary",
67
+ size="sm",
68
+ scale=1
69
+ )
70
+ save_btn_2 = gr.Button(
71
+ t("results.save_btn"),
72
+ variant="primary",
73
+ size="sm",
74
+ scale=1
75
+ )
76
+ score_btn_2 = gr.Button(
77
+ t("results.score_btn"),
78
+ variant="secondary",
79
+ size="sm",
80
+ scale=1
81
+ )
82
+ score_display_2 = gr.Textbox(
83
+ label=t("results.quality_score_label", n=2),
84
+ interactive=False,
85
+ placeholder=t("results.quality_score_placeholder")
86
+ )
87
+ with gr.Column(visible=False) as audio_col_3:
88
+ generated_audio_3 = gr.Audio(
89
+ label=t("results.generated_music", n=3),
90
+ type="filepath",
91
+ interactive=False
92
+ )
93
+ with gr.Row(equal_height=True):
94
+ send_to_src_btn_3 = gr.Button(
95
+ t("results.send_to_src_btn"),
96
+ variant="secondary",
97
+ size="sm",
98
+ scale=1
99
+ )
100
+ save_btn_3 = gr.Button(
101
+ t("results.save_btn"),
102
+ variant="primary",
103
+ size="sm",
104
+ scale=1
105
+ )
106
+ score_btn_3 = gr.Button(
107
+ t("results.score_btn"),
108
+ variant="secondary",
109
+ size="sm",
110
+ scale=1
111
+ )
112
+ score_display_3 = gr.Textbox(
113
+ label=t("results.quality_score_label", n=3),
114
+ interactive=False,
115
+ placeholder=t("results.quality_score_placeholder")
116
+ )
117
+ with gr.Column(visible=False) as audio_col_4:
118
+ generated_audio_4 = gr.Audio(
119
+ label=t("results.generated_music", n=4),
120
+ type="filepath",
121
+ interactive=False
122
+ )
123
+ with gr.Row(equal_height=True):
124
+ send_to_src_btn_4 = gr.Button(
125
+ t("results.send_to_src_btn"),
126
+ variant="secondary",
127
+ size="sm",
128
+ scale=1
129
+ )
130
+ save_btn_4 = gr.Button(
131
+ t("results.save_btn"),
132
+ variant="primary",
133
+ size="sm",
134
+ scale=1
135
+ )
136
+ score_btn_4 = gr.Button(
137
+ t("results.score_btn"),
138
+ variant="secondary",
139
+ size="sm",
140
+ scale=1
141
+ )
142
+ score_display_4 = gr.Textbox(
143
+ label=t("results.quality_score_label", n=4),
144
+ interactive=False,
145
+ placeholder=t("results.quality_score_placeholder")
146
+ )
147
+
148
+ # Second row for batch size 5-8 (initially hidden)
149
+ with gr.Row(visible=False) as audio_row_5_8:
150
+ with gr.Column() as audio_col_5:
151
+ generated_audio_5 = gr.Audio(
152
+ label=t("results.generated_music", n=5),
153
+ type="filepath",
154
+ interactive=False
155
+ )
156
+ with gr.Row(equal_height=True):
157
+ send_to_src_btn_5 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
158
+ save_btn_5 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
159
+ score_btn_5 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
160
+ score_display_5 = gr.Textbox(
161
+ label=t("results.quality_score_label", n=5),
162
+ interactive=False,
163
+ placeholder=t("results.quality_score_placeholder")
164
+ )
165
+ with gr.Column() as audio_col_6:
166
+ generated_audio_6 = gr.Audio(
167
+ label=t("results.generated_music", n=6),
168
+ type="filepath",
169
+ interactive=False
170
+ )
171
+ with gr.Row(equal_height=True):
172
+ send_to_src_btn_6 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
173
+ save_btn_6 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
174
+ score_btn_6 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
175
+ score_display_6 = gr.Textbox(
176
+ label=t("results.quality_score_label", n=6),
177
+ interactive=False,
178
+ placeholder=t("results.quality_score_placeholder")
179
+ )
180
+ with gr.Column() as audio_col_7:
181
+ generated_audio_7 = gr.Audio(
182
+ label=t("results.generated_music", n=7),
183
+ type="filepath",
184
+ interactive=False
185
+ )
186
+ with gr.Row(equal_height=True):
187
+ send_to_src_btn_7 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
188
+ save_btn_7 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
189
+ score_btn_7 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
190
+ score_display_7 = gr.Textbox(
191
+ label=t("results.quality_score_label", n=7),
192
+ interactive=False,
193
+ placeholder=t("results.quality_score_placeholder")
194
+ )
195
+ with gr.Column() as audio_col_8:
196
+ generated_audio_8 = gr.Audio(
197
+ label=t("results.generated_music", n=8),
198
+ type="filepath",
199
+ interactive=False
200
+ )
201
+ with gr.Row(equal_height=True):
202
+ send_to_src_btn_8 = gr.Button(t("results.send_to_src_btn"), variant="secondary", size="sm", scale=1)
203
+ save_btn_8 = gr.Button(t("results.save_btn"), variant="primary", size="sm", scale=1)
204
+ score_btn_8 = gr.Button(t("results.score_btn"), variant="secondary", size="sm", scale=1)
205
+ score_display_8 = gr.Textbox(
206
+ label=t("results.quality_score_label", n=8),
207
+ interactive=False,
208
+ placeholder=t("results.quality_score_placeholder")
209
+ )
210
+
211
+ status_output = gr.Textbox(label=t("results.generation_status"), interactive=False)
212
+
213
+ # Batch navigation controls
214
+ with gr.Row(equal_height=True):
215
+ prev_batch_btn = gr.Button(
216
+ t("results.prev_btn"),
217
+ variant="secondary",
218
+ interactive=False,
219
+ scale=1,
220
+ size="sm"
221
+ )
222
+ batch_indicator = gr.Textbox(
223
+ label=t("results.current_batch"),
224
+ value=t("results.batch_indicator", current=1, total=1),
225
+ interactive=False,
226
+ scale=3
227
+ )
228
+ next_batch_status = gr.Textbox(
229
+ label=t("results.next_batch_status"),
230
+ value="",
231
+ interactive=False,
232
+ scale=3
233
+ )
234
+ next_batch_btn = gr.Button(
235
+ t("results.next_btn"),
236
+ variant="primary",
237
+ interactive=False,
238
+ scale=1,
239
+ size="sm"
240
+ )
241
+
242
+ # One-click restore parameters button
243
+ restore_params_btn = gr.Button(
244
+ t("results.restore_params_btn"),
245
+ variant="secondary",
246
+ interactive=False, # Initially disabled, enabled after generation
247
+ size="sm"
248
+ )
249
+
250
+ with gr.Accordion(t("results.batch_results_title"), open=False):
251
+ generated_audio_batch = gr.File(
252
+ label=t("results.all_files_label"),
253
+ file_count="multiple",
254
+ interactive=False
255
+ )
256
+ generation_info = gr.Markdown(label=t("results.generation_details"))
257
+
258
+ with gr.Accordion(t("results.attention_analysis"), open=False):
259
+ with gr.Row():
260
+ with gr.Column():
261
+ align_score_1 = gr.Textbox(label=t("results.attention_score", n=1), interactive=False)
262
+ align_text_1 = gr.Textbox(label=t("results.lyric_timestamps", n=1), interactive=False, lines=10)
263
+ align_plot_1 = gr.Plot(label=t("results.attention_heatmap", n=1))
264
+ with gr.Column():
265
+ align_score_2 = gr.Textbox(label=t("results.attention_score", n=2), interactive=False)
266
+ align_text_2 = gr.Textbox(label=t("results.lyric_timestamps", n=2), interactive=False, lines=10)
267
+ align_plot_2 = gr.Plot(label=t("results.attention_heatmap", n=2))
268
+
269
+ return {
270
+ "lm_metadata_state": lm_metadata_state,
271
+ "is_format_caption_state": is_format_caption_state,
272
+ "current_batch_index": current_batch_index,
273
+ "total_batches": total_batches,
274
+ "batch_queue": batch_queue,
275
+ "generation_params_state": generation_params_state,
276
+ "is_generating_background": is_generating_background,
277
+ "status_output": status_output,
278
+ "prev_batch_btn": prev_batch_btn,
279
+ "batch_indicator": batch_indicator,
280
+ "next_batch_btn": next_batch_btn,
281
+ "next_batch_status": next_batch_status,
282
+ "restore_params_btn": restore_params_btn,
283
+ "generated_audio_1": generated_audio_1,
284
+ "generated_audio_2": generated_audio_2,
285
+ "generated_audio_3": generated_audio_3,
286
+ "generated_audio_4": generated_audio_4,
287
+ "generated_audio_5": generated_audio_5,
288
+ "generated_audio_6": generated_audio_6,
289
+ "generated_audio_7": generated_audio_7,
290
+ "generated_audio_8": generated_audio_8,
291
+ "audio_row_5_8": audio_row_5_8,
292
+ "audio_col_1": audio_col_1,
293
+ "audio_col_2": audio_col_2,
294
+ "audio_col_3": audio_col_3,
295
+ "audio_col_4": audio_col_4,
296
+ "audio_col_5": audio_col_5,
297
+ "audio_col_6": audio_col_6,
298
+ "audio_col_7": audio_col_7,
299
+ "audio_col_8": audio_col_8,
300
+ "send_to_src_btn_1": send_to_src_btn_1,
301
+ "send_to_src_btn_2": send_to_src_btn_2,
302
+ "send_to_src_btn_3": send_to_src_btn_3,
303
+ "send_to_src_btn_4": send_to_src_btn_4,
304
+ "send_to_src_btn_5": send_to_src_btn_5,
305
+ "send_to_src_btn_6": send_to_src_btn_6,
306
+ "send_to_src_btn_7": send_to_src_btn_7,
307
+ "send_to_src_btn_8": send_to_src_btn_8,
308
+ "save_btn_1": save_btn_1,
309
+ "save_btn_2": save_btn_2,
310
+ "save_btn_3": save_btn_3,
311
+ "save_btn_4": save_btn_4,
312
+ "save_btn_5": save_btn_5,
313
+ "save_btn_6": save_btn_6,
314
+ "save_btn_7": save_btn_7,
315
+ "save_btn_8": save_btn_8,
316
+ "score_btn_1": score_btn_1,
317
+ "score_btn_2": score_btn_2,
318
+ "score_btn_3": score_btn_3,
319
+ "score_btn_4": score_btn_4,
320
+ "score_btn_5": score_btn_5,
321
+ "score_btn_6": score_btn_6,
322
+ "score_btn_7": score_btn_7,
323
+ "score_btn_8": score_btn_8,
324
+ "score_display_1": score_display_1,
325
+ "score_display_2": score_display_2,
326
+ "score_display_3": score_display_3,
327
+ "score_display_4": score_display_4,
328
+ "score_display_5": score_display_5,
329
+ "score_display_6": score_display_6,
330
+ "score_display_7": score_display_7,
331
+ "score_display_8": score_display_8,
332
+ "generated_audio_batch": generated_audio_batch,
333
+ "generation_info": generation_info,
334
+ "align_score_1": align_score_1,
335
+ "align_text_1": align_text_1,
336
+ "align_plot_1": align_plot_1,
337
+ "align_score_2": align_score_2,
338
+ "align_text_2": align_text_2,
339
+ "align_plot_2": align_plot_2,
340
+ }
341
+
acestep/handler.py CHANGED
@@ -9,6 +9,7 @@ import tempfile
9
  import traceback
10
  import re
11
  import random
 
12
  from contextlib import contextmanager
13
  from typing import Optional, Dict, Any, Tuple, List, Union
14
 
@@ -25,7 +26,7 @@ from transformers.generation.streamers import BaseStreamer
25
  from diffusers.models import AutoencoderOobleck
26
  from acestep.constants import (
27
  TASK_INSTRUCTIONS,
28
- TRACK_NAMES,
29
  DEFAULT_DIT_INSTRUCTION,
30
  )
31
 
@@ -33,16 +34,6 @@ from acestep.constants import (
33
  warnings.filterwarnings("ignore")
34
 
35
 
36
- SFT_GEN_PROMPT = """# Instruction
37
- {}
38
-
39
- # Caption
40
- {}
41
-
42
- # Metas
43
- {}<|endoftext|>
44
- """
45
-
46
  class AceStepHandler:
47
  """ACE-Step Business Logic Handler"""
48
 
@@ -2237,12 +2228,16 @@ class AceStepHandler:
2237
  audio_format_lower = "wav"
2238
 
2239
  saved_files = []
 
2240
  for i in range(actual_batch_size):
2241
- audio_file = os.path.join(self.temp_dir, f"generated_{i}_{actual_seed_list[i]}.{audio_format_lower}")
 
 
2242
  # Convert to numpy: [channels, samples] -> [samples, channels]
2243
  audio_np = pred_wavs[i].cpu().float().numpy().T
2244
  sf.write(audio_file, audio_np, self.sample_rate)
2245
  saved_files.append(audio_file)
 
2246
 
2247
  # Prepare return values
2248
  first_audio = saved_files[0] if len(saved_files) > 0 else None
 
9
  import traceback
10
  import re
11
  import random
12
+ import uuid
13
  from contextlib import contextmanager
14
  from typing import Optional, Dict, Any, Tuple, List, Union
15
 
 
26
  from diffusers.models import AutoencoderOobleck
27
  from acestep.constants import (
28
  TASK_INSTRUCTIONS,
29
+ SFT_GEN_PROMPT,
30
  DEFAULT_DIT_INSTRUCTION,
31
  )
32
 
 
34
  warnings.filterwarnings("ignore")
35
 
36
 
 
 
 
 
 
 
 
 
 
 
37
  class AceStepHandler:
38
  """ACE-Step Business Logic Handler"""
39
 
 
2228
  audio_format_lower = "wav"
2229
 
2230
  saved_files = []
2231
+ saved_uuids = [] # Store UUIDs for each file
2232
  for i in range(actual_batch_size):
2233
+ # Generate unique UUID for each audio file
2234
+ file_uuid = str(uuid.uuid4())
2235
+ audio_file = os.path.join(self.temp_dir, f"{file_uuid}.{audio_format_lower}")
2236
  # Convert to numpy: [channels, samples] -> [samples, channels]
2237
  audio_np = pred_wavs[i].cpu().float().numpy().T
2238
  sf.write(audio_file, audio_np, self.sample_rate)
2239
  saved_files.append(audio_file)
2240
+ saved_uuids.append(file_uuid)
2241
 
2242
  # Prepare return values
2243
  first_audio = saved_files[0] if len(saved_files) > 0 else None
acestep/test_time_scaling.py CHANGED
@@ -228,6 +228,97 @@ def _calculate_log_prob(
228
  return mean_log_prob
229
 
230
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  # ==============================================================================
232
  # Main Public API
233
  # ==============================================================================
@@ -300,16 +391,16 @@ def calculate_pmi_score_per_condition(
300
 
301
  # 4. Global Score
302
  global_score = sum(scores.values()) / len(scores)
 
303
 
304
  # Status Message
305
- status_lines = ["✅ Per-condition scores (0-1):"]
306
  for key, score in sorted(scores.items()):
307
  metric = "Top-k Recall" if key in metadata_recall_keys else "PMI (Norm)"
308
  status_lines.append(f" {key}: {score:.4f} ({metric})")
309
- status_lines.append(f"Global score: {global_score:.4f}")
310
-
311
- logger.info(f"Calculated scores: {global_score:.4f}")
312
- return scores, global_score, "\n".join(status_lines)
313
 
314
  except Exception as e:
315
  import traceback
 
228
  return mean_log_prob
229
 
230
 
231
+ def calculate_reward_score(
232
+ scores: Dict[str, float],
233
+ weights_config: Optional[Dict[str, float]] = None
234
+ ) -> Tuple[float, str]:
235
+ """
236
+ Reward Model Calculator: Computes a final reward based on user priorities.
237
+
238
+ Priority Logic:
239
+ 1. Caption (Highest): The overall vibe/style must match.
240
+ 2. Lyrics (Medium): Content accuracy is important but secondary to vibe.
241
+ 3. Metadata (Lowest): Technical constraints (BPM, Key) allow for slight deviations.
242
+
243
+ Strategy: Dynamic Weighted Sum
244
+ - Metadata fields are aggregated into a single 'metadata' score first.
245
+ - Weights are dynamically renormalized if any component (e.g., lyrics) is missing.
246
+
247
+ Args:
248
+ scores: Dictionary of raw scores (0.0 - 1.0) from the evaluation module.
249
+ weights_config: Optional custom weights. Defaults to:
250
+ Caption (50%), Lyrics (30%), Metadata (20%).
251
+
252
+ Returns:
253
+ final_reward: The calculated reward score (0.0 - 1.0).
254
+ explanation: A formatted string explaining how the score was derived.
255
+ """
256
+
257
+ # 1. Default Preference Configuration
258
+ # These weights determine the relative importance of each component.
259
+ if weights_config is None:
260
+ weights_config = {
261
+ 'caption': 0.50, # High priority: Style/Vibe
262
+ 'lyrics': 0.30, # Medium priority: Content
263
+ 'metadata': 0.20 # Low priority: Technical details
264
+ }
265
+
266
+ # 2. Extract and Group Scores
267
+ # Caption and Lyrics are standalone high-level features.
268
+ caption_score = scores.get('caption')
269
+ lyrics_score = scores.get('lyrics')
270
+
271
+ # Metadata fields (bpm, key, duration, etc.) are aggregated.
272
+ # We treat them as a single "Technical Score" to prevent them from
273
+ # diluting the weight of Caption/Lyrics simply by having many fields.
274
+ meta_scores_list = [
275
+ val for key, val in scores.items()
276
+ if key not in ['caption', 'lyrics']
277
+ ]
278
+
279
+ # Calculate average of all metadata fields (if any exist)
280
+ meta_aggregate_score = None
281
+ if meta_scores_list:
282
+ meta_aggregate_score = sum(meta_scores_list) / len(meta_scores_list)
283
+
284
+ # 3. specific Active Components & Dynamic Weighting
285
+ # We only include components that actually exist in this generation.
286
+ active_components = {}
287
+
288
+ if caption_score is not None:
289
+ active_components['caption'] = (caption_score, weights_config['caption'])
290
+
291
+ if lyrics_score is not None:
292
+ active_components['lyrics'] = (lyrics_score, weights_config['lyrics'])
293
+
294
+ if meta_aggregate_score is not None:
295
+ active_components['metadata'] = (meta_aggregate_score, weights_config['metadata'])
296
+
297
+ # 4. Calculate Final Weighted Score
298
+ total_base_weight = sum(w for _, w in active_components.values())
299
+ total_score = 0.0
300
+
301
+ breakdown_lines = []
302
+
303
+ if total_base_weight == 0:
304
+ return 0.0, "❌ No valid scores available to calculate reward."
305
+
306
+ # Sort by weight (importance) for display
307
+ sorted_components = sorted(active_components.items(), key=lambda x: x[1][1], reverse=True)
308
+
309
+ for name, (score, base_weight) in sorted_components:
310
+ # Renormalize weight: If lyrics are missing, caption/metadata weights scale up proportionately.
311
+ normalized_weight = base_weight / total_base_weight
312
+ weighted_contribution = score * normalized_weight
313
+ total_score += weighted_contribution
314
+
315
+ breakdown_lines.append(
316
+ f" • {name.title():<8} | Score: {score:.4f} | Weight: {normalized_weight:.2f} "
317
+ f"-> Contrib: +{weighted_contribution:.4f}"
318
+ )
319
+
320
+ return total_score, "\n".join(breakdown_lines)
321
+
322
  # ==============================================================================
323
  # Main Public API
324
  # ==============================================================================
 
391
 
392
  # 4. Global Score
393
  global_score = sum(scores.values()) / len(scores)
394
+ global_score, breakdown_lines = calculate_reward_score(scores)
395
 
396
  # Status Message
397
+ status_lines = [breakdown_lines, "\n✅ Per-condition scores (0-1):"]
398
  for key, score in sorted(scores.items()):
399
  metric = "Top-k Recall" if key in metadata_recall_keys else "PMI (Norm)"
400
  status_lines.append(f" {key}: {score:.4f} ({metric})")
401
+ status = "\n".join(status_lines)
402
+ logger.info(f"Calculated scores: {global_score:.4f}\n{status}")
403
+ return scores, global_score, status
 
404
 
405
  except Exception as e:
406
  import traceback