Spaces:
Running on Zero
Running on Zero
fix llm gen for pt
Browse files- .gitignore +2 -1
- acestep/gradio_ui.py +143 -61
- acestep/handler.py +115 -2
- acestep/llm_inference.py +441 -175
.gitignore
CHANGED
|
@@ -214,4 +214,5 @@ checkpoints/
|
|
| 214 |
playground.ipynb
|
| 215 |
.history/
|
| 216 |
upload_checkpoints.sh
|
| 217 |
-
checkpoints.7z
|
|
|
|
|
|
| 214 |
playground.ipynb
|
| 215 |
.history/
|
| 216 |
upload_checkpoints.sh
|
| 217 |
+
checkpoints.7z
|
| 218 |
+
README_old.md
|
acestep/gradio_ui.py
CHANGED
|
@@ -36,6 +36,20 @@ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_para
|
|
| 36 |
border-radius: 5px;
|
| 37 |
margin: 10px 0;
|
| 38 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
"""
|
| 40 |
) as demo:
|
| 41 |
|
|
@@ -320,43 +334,46 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 320 |
)
|
| 321 |
|
| 322 |
# Audio uploads
|
| 323 |
-
|
| 324 |
-
|
|
|
|
| 325 |
with gr.Column(scale=2):
|
| 326 |
reference_audio = gr.Audio(
|
| 327 |
label="Reference Audio (optional)",
|
| 328 |
type="filepath",
|
| 329 |
)
|
| 330 |
-
with gr.Column(scale=
|
| 331 |
src_audio = gr.Audio(
|
| 332 |
label="Source Audio (optional)",
|
| 333 |
type="filepath",
|
| 334 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 335 |
|
| 336 |
-
audio_code_string = gr.Textbox(
|
| 337 |
-
label="Audio Codes (optional)",
|
| 338 |
-
placeholder="<|audio_code_10695|><|audio_code_54246|>...",
|
| 339 |
-
lines=4,
|
| 340 |
-
visible=False,
|
| 341 |
-
info="Paste precomputed audio code tokens"
|
| 342 |
-
)
|
| 343 |
-
|
| 344 |
# Audio Codes for text2music
|
| 345 |
with gr.Accordion("🎼 Audio Codes (for text2music)", open=True, visible=True) as text2music_audio_codes_group:
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
|
| 353 |
-
|
| 354 |
-
with gr.Row(visible=True) as use_5hz_lm_row:
|
| 355 |
-
use_5hz_lm_btn = gr.Button(
|
| 356 |
-
"Generate LM Hints",
|
| 357 |
-
variant="secondary",
|
| 358 |
-
size="lg",
|
| 359 |
-
)
|
| 360 |
lm_temperature = gr.Slider(
|
| 361 |
label="Temperature",
|
| 362 |
minimum=0.0,
|
|
@@ -364,7 +381,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 364 |
value=0.85,
|
| 365 |
step=0.1,
|
| 366 |
scale=1,
|
| 367 |
-
info="
|
| 368 |
)
|
| 369 |
lm_cfg_scale = gr.Slider(
|
| 370 |
label="CFG Scale",
|
|
@@ -373,10 +390,8 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 373 |
value=2.0,
|
| 374 |
step=0.1,
|
| 375 |
scale=1,
|
| 376 |
-
info="
|
| 377 |
)
|
| 378 |
-
|
| 379 |
-
with gr.Row():
|
| 380 |
lm_top_k = gr.Slider(
|
| 381 |
label="Top-K",
|
| 382 |
minimum=0,
|
|
@@ -384,7 +399,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 384 |
value=0,
|
| 385 |
step=1,
|
| 386 |
scale=1,
|
| 387 |
-
info="Top-K
|
| 388 |
)
|
| 389 |
lm_top_p = gr.Slider(
|
| 390 |
label="Top-P",
|
|
@@ -393,7 +408,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 393 |
value=0.9,
|
| 394 |
step=0.01,
|
| 395 |
scale=1,
|
| 396 |
-
info="Top-P (
|
| 397 |
)
|
| 398 |
lm_repetition_penalty = gr.Slider(
|
| 399 |
label="Repetition Penalty",
|
|
@@ -402,20 +417,10 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 402 |
value=1.0,
|
| 403 |
step=0.01,
|
| 404 |
scale=1,
|
| 405 |
-
info="Repetition penalty: >1.0 reduces repetition, <1.0 increases it
|
| 406 |
visible=False,
|
| 407 |
)
|
| 408 |
|
| 409 |
-
# Negative prompt for CFG (only visible when LM initialized and cfg_scale > 1)
|
| 410 |
-
lm_negative_prompt = gr.Textbox(
|
| 411 |
-
label="Negative Prompt",
|
| 412 |
-
value="NO USER INPUT",
|
| 413 |
-
placeholder="Enter negative prompt for CFG (default: NO USER INPUT)",
|
| 414 |
-
visible=True,
|
| 415 |
-
info="Negative prompt used for Classifier-Free Guidance when CFG Scale > 1.0",
|
| 416 |
-
lines=2
|
| 417 |
-
)
|
| 418 |
-
|
| 419 |
# Repainting controls
|
| 420 |
with gr.Group(visible=False) as repainting_group:
|
| 421 |
gr.HTML("<h5>🎨 Repainting Controls (seconds) </h5>")
|
|
@@ -445,12 +450,24 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 445 |
|
| 446 |
# Music Caption
|
| 447 |
with gr.Accordion("📝 Music Caption", open=True):
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
|
| 455 |
# Lyrics
|
| 456 |
with gr.Accordion("📝 Lyrics", open=True):
|
|
@@ -468,7 +485,8 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 468 |
choices=["en", "zh", "ja", "ko", "es", "fr", "de"],
|
| 469 |
value="en",
|
| 470 |
label="Vocal Language (optional)",
|
| 471 |
-
allow_custom_value=True
|
|
|
|
| 472 |
)
|
| 473 |
bpm = gr.Number(
|
| 474 |
label="BPM (optional)",
|
|
@@ -477,15 +495,17 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 477 |
info="leave empty for N/A"
|
| 478 |
)
|
| 479 |
key_scale = gr.Textbox(
|
| 480 |
-
label="
|
| 481 |
placeholder="Leave empty for N/A",
|
| 482 |
value="",
|
|
|
|
| 483 |
)
|
| 484 |
time_signature = gr.Dropdown(
|
| 485 |
choices=["2", "3", "4", "N/A", ""],
|
| 486 |
value="4",
|
| 487 |
label="Time Signature (optional)",
|
| 488 |
-
allow_custom_value=True
|
|
|
|
| 489 |
)
|
| 490 |
audio_duration = gr.Number(
|
| 491 |
label="Audio Duration (seconds)",
|
|
@@ -497,7 +517,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 497 |
)
|
| 498 |
batch_size_input = gr.Number(
|
| 499 |
label="Batch Size",
|
| 500 |
-
value=
|
| 501 |
minimum=1,
|
| 502 |
maximum=8,
|
| 503 |
step=1,
|
|
@@ -582,6 +602,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 582 |
generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg", interactive=generate_btn_interactive)
|
| 583 |
|
| 584 |
return {
|
|
|
|
| 585 |
"checkpoint_dropdown": checkpoint_dropdown,
|
| 586 |
"refresh_btn": refresh_btn,
|
| 587 |
"config_path": config_path,
|
|
@@ -598,9 +619,10 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
|
|
| 598 |
"instruction_display_gen": instruction_display_gen,
|
| 599 |
"track_name": track_name,
|
| 600 |
"complete_track_classes": complete_track_classes,
|
|
|
|
| 601 |
"reference_audio": reference_audio,
|
| 602 |
"src_audio": src_audio,
|
| 603 |
-
"
|
| 604 |
"text2music_audio_code_string": text2music_audio_code_string,
|
| 605 |
"text2music_audio_codes_group": text2music_audio_codes_group,
|
| 606 |
"use_5hz_lm_row": use_5hz_lm_row,
|
|
@@ -650,12 +672,22 @@ def create_results_section(dit_handler) -> dict:
|
|
| 650 |
type="filepath",
|
| 651 |
interactive=False
|
| 652 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 653 |
with gr.Column():
|
| 654 |
generated_audio_2 = gr.Audio(
|
| 655 |
label="🎵 Generated Music (Sample 2)",
|
| 656 |
type="filepath",
|
| 657 |
interactive=False
|
| 658 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 659 |
|
| 660 |
with gr.Accordion("📁 Batch Results & Generation Details", open=False):
|
| 661 |
generated_audio_batch = gr.File(
|
|
@@ -680,6 +712,8 @@ def create_results_section(dit_handler) -> dict:
|
|
| 680 |
"status_output": status_output,
|
| 681 |
"generated_audio_1": generated_audio_1,
|
| 682 |
"generated_audio_2": generated_audio_2,
|
|
|
|
|
|
|
| 683 |
"generated_audio_batch": generated_audio_batch,
|
| 684 |
"generation_info": generation_info,
|
| 685 |
"align_score_1": align_score_1,
|
|
@@ -768,7 +802,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 768 |
|
| 769 |
# Service initialization
|
| 770 |
def init_service_wrapper(checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
|
| 771 |
-
"""Wrapper for service initialization, returns status and
|
| 772 |
# Initialize DiT handler
|
| 773 |
status, enable = dit_handler.initialize_service(
|
| 774 |
checkpoint, config_path, device,
|
|
@@ -799,7 +833,11 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 799 |
# Don't fail the entire initialization if LM fails, but log it
|
| 800 |
# Keep enable as is (DiT initialization result) even if LM fails
|
| 801 |
|
| 802 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 803 |
|
| 804 |
# Update negative prompt visibility based on "Initialize 5Hz LM" checkbox
|
| 805 |
def update_negative_prompt_visibility(init_llm_checked):
|
|
@@ -855,7 +893,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 855 |
generation_section["offload_to_cpu_checkbox"],
|
| 856 |
generation_section["offload_dit_to_cpu_checkbox"],
|
| 857 |
],
|
| 858 |
-
outputs=[generation_section["init_status"], generation_section["generate_btn"]]
|
| 859 |
)
|
| 860 |
|
| 861 |
# Generation with progress bar
|
|
@@ -992,6 +1030,18 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 992 |
]
|
| 993 |
)
|
| 994 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 995 |
# Update instruction and UI visibility based on task type
|
| 996 |
def update_instruction_ui(
|
| 997 |
task_type_value: str,
|
|
@@ -1020,8 +1070,6 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 1020 |
else:
|
| 1021 |
audio_cover_strength_label = "Audio Cover Strength"
|
| 1022 |
audio_cover_strength_info = "Control how many denoising steps use cover mode"
|
| 1023 |
-
# Show audio_code_string for cover
|
| 1024 |
-
audio_code_visible = task_type_value == "cover"
|
| 1025 |
# Show repainting controls for repaint and lego
|
| 1026 |
repainting_visible = task_type_value in ["repaint", "lego"]
|
| 1027 |
# Show use_5hz_lm, lm_temperature for text2music
|
|
@@ -1037,7 +1085,6 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 1037 |
gr.update(visible=complete_visible), # complete_track_classes
|
| 1038 |
gr.update(visible=audio_cover_strength_visible, label=audio_cover_strength_label, info=audio_cover_strength_info), # audio_cover_strength
|
| 1039 |
gr.update(visible=repainting_visible), # repainting_group
|
| 1040 |
-
gr.update(visible=audio_code_visible), # audio_code_string
|
| 1041 |
gr.update(visible=use_5hz_lm_visible), # use_5hz_lm_row
|
| 1042 |
gr.update(visible=text2music_audio_codes_visible), # text2music_audio_codes_group
|
| 1043 |
)
|
|
@@ -1058,7 +1105,6 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 1058 |
generation_section["complete_track_classes"],
|
| 1059 |
generation_section["audio_cover_strength"],
|
| 1060 |
generation_section["repainting_group"],
|
| 1061 |
-
generation_section["audio_code_string"],
|
| 1062 |
generation_section["use_5hz_lm_row"],
|
| 1063 |
generation_section["text2music_audio_codes_group"],
|
| 1064 |
]
|
|
@@ -1080,7 +1126,6 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 1080 |
generation_section["complete_track_classes"],
|
| 1081 |
generation_section["audio_cover_strength"],
|
| 1082 |
generation_section["repainting_group"],
|
| 1083 |
-
generation_section["audio_code_string"],
|
| 1084 |
generation_section["use_5hz_lm_row"],
|
| 1085 |
generation_section["text2music_audio_codes_group"],
|
| 1086 |
]
|
|
@@ -1102,9 +1147,46 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
|
|
| 1102 |
generation_section["complete_track_classes"],
|
| 1103 |
generation_section["audio_cover_strength"],
|
| 1104 |
generation_section["repainting_group"],
|
| 1105 |
-
generation_section["audio_code_string"],
|
| 1106 |
generation_section["use_5hz_lm_row"],
|
| 1107 |
generation_section["text2music_audio_codes_group"],
|
| 1108 |
]
|
| 1109 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1110 |
|
|
|
|
| 36 |
border-radius: 5px;
|
| 37 |
margin: 10px 0;
|
| 38 |
}
|
| 39 |
+
.lm-hints-row {
|
| 40 |
+
align-items: stretch;
|
| 41 |
+
}
|
| 42 |
+
.lm-hints-col {
|
| 43 |
+
display: flex;
|
| 44 |
+
}
|
| 45 |
+
.lm-hints-col > div {
|
| 46 |
+
flex: 1;
|
| 47 |
+
display: flex;
|
| 48 |
+
}
|
| 49 |
+
.lm-hints-btn button {
|
| 50 |
+
height: 100%;
|
| 51 |
+
width: 100%;
|
| 52 |
+
}
|
| 53 |
"""
|
| 54 |
) as demo:
|
| 55 |
|
|
|
|
| 334 |
)
|
| 335 |
|
| 336 |
# Audio uploads
|
| 337 |
+
audio_uploads_accordion = gr.Accordion("🎵 Audio Uploads", open=False)
|
| 338 |
+
with audio_uploads_accordion:
|
| 339 |
+
with gr.Row(equal_height=True):
|
| 340 |
with gr.Column(scale=2):
|
| 341 |
reference_audio = gr.Audio(
|
| 342 |
label="Reference Audio (optional)",
|
| 343 |
type="filepath",
|
| 344 |
)
|
| 345 |
+
with gr.Column(scale=7):
|
| 346 |
src_audio = gr.Audio(
|
| 347 |
label="Source Audio (optional)",
|
| 348 |
type="filepath",
|
| 349 |
)
|
| 350 |
+
with gr.Column(scale=1, min_width=80):
|
| 351 |
+
convert_src_to_codes_btn = gr.Button(
|
| 352 |
+
"Convert to Codes",
|
| 353 |
+
variant="secondary",
|
| 354 |
+
size="sm"
|
| 355 |
+
)
|
| 356 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
# Audio Codes for text2music
|
| 358 |
with gr.Accordion("🎼 Audio Codes (for text2music)", open=True, visible=True) as text2music_audio_codes_group:
|
| 359 |
+
with gr.Row(equal_height=True, elem_classes=["lm-hints-row"]):
|
| 360 |
+
with gr.Column(scale=9):
|
| 361 |
+
text2music_audio_code_string = gr.Textbox(
|
| 362 |
+
label="Audio Codes",
|
| 363 |
+
placeholder="<|audio_code_10695|><|audio_code_54246|>...",
|
| 364 |
+
lines=6,
|
| 365 |
+
info="Paste precomputed audio code tokens for text2music generation"
|
| 366 |
+
)
|
| 367 |
+
with gr.Column(scale=3, elem_classes=["lm-hints-col"]):
|
| 368 |
+
with gr.Row(equal_height=True, visible=True) as use_5hz_lm_row:
|
| 369 |
+
use_5hz_lm_btn = gr.Button(
|
| 370 |
+
"Generate LM Hints",
|
| 371 |
+
variant="secondary",
|
| 372 |
+
# size="lg",
|
| 373 |
+
elem_classes=["lm-hints-btn"],
|
| 374 |
+
)
|
| 375 |
|
| 376 |
+
with gr.Row(equal_height=True):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
lm_temperature = gr.Slider(
|
| 378 |
label="Temperature",
|
| 379 |
minimum=0.0,
|
|
|
|
| 381 |
value=0.85,
|
| 382 |
step=0.1,
|
| 383 |
scale=1,
|
| 384 |
+
info="5Hz LM temperature (higher = random)"
|
| 385 |
)
|
| 386 |
lm_cfg_scale = gr.Slider(
|
| 387 |
label="CFG Scale",
|
|
|
|
| 390 |
value=2.0,
|
| 391 |
step=0.1,
|
| 392 |
scale=1,
|
| 393 |
+
info="5Hz LM CFG (1.0 = no CFG)"
|
| 394 |
)
|
|
|
|
|
|
|
| 395 |
lm_top_k = gr.Slider(
|
| 396 |
label="Top-K",
|
| 397 |
minimum=0,
|
|
|
|
| 399 |
value=0,
|
| 400 |
step=1,
|
| 401 |
scale=1,
|
| 402 |
+
info="Top-K (0 = disabled)"
|
| 403 |
)
|
| 404 |
lm_top_p = gr.Slider(
|
| 405 |
label="Top-P",
|
|
|
|
| 408 |
value=0.9,
|
| 409 |
step=0.01,
|
| 410 |
scale=1,
|
| 411 |
+
info="Top-P (1.0 = disabled)"
|
| 412 |
)
|
| 413 |
lm_repetition_penalty = gr.Slider(
|
| 414 |
label="Repetition Penalty",
|
|
|
|
| 417 |
value=1.0,
|
| 418 |
step=0.01,
|
| 419 |
scale=1,
|
| 420 |
+
info="Repetition penalty: >1.0 reduces repetition, <1.0 increases it. Use 1.0 or very small values for audio tokens.",
|
| 421 |
visible=False,
|
| 422 |
)
|
| 423 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
# Repainting controls
|
| 425 |
with gr.Group(visible=False) as repainting_group:
|
| 426 |
gr.HTML("<h5>🎨 Repainting Controls (seconds) </h5>")
|
|
|
|
| 450 |
|
| 451 |
# Music Caption
|
| 452 |
with gr.Accordion("📝 Music Caption", open=True):
|
| 453 |
+
with gr.Row(equal_height=True):
|
| 454 |
+
captions = gr.Textbox(
|
| 455 |
+
label="Music Caption (optional)",
|
| 456 |
+
placeholder="A peaceful acoustic guitar melody with soft vocals...",
|
| 457 |
+
lines=3,
|
| 458 |
+
info="Describe the style, genre, instruments, and mood",
|
| 459 |
+
scale=7,
|
| 460 |
+
)
|
| 461 |
+
# Negative prompt for CFG (only visible when LM initialized and cfg_scale > 1)
|
| 462 |
+
lm_negative_prompt = gr.Textbox(
|
| 463 |
+
label="Negative Prompt",
|
| 464 |
+
value="NO USER INPUT",
|
| 465 |
+
placeholder="Enter negative prompt for CFG (default: NO USER INPUT)",
|
| 466 |
+
visible=True,
|
| 467 |
+
info="Negative prompt (use when CFG Scale > 1.0)",
|
| 468 |
+
lines=3,
|
| 469 |
+
scale=5,
|
| 470 |
+
)
|
| 471 |
|
| 472 |
# Lyrics
|
| 473 |
with gr.Accordion("📝 Lyrics", open=True):
|
|
|
|
| 485 |
choices=["en", "zh", "ja", "ko", "es", "fr", "de"],
|
| 486 |
value="en",
|
| 487 |
label="Vocal Language (optional)",
|
| 488 |
+
allow_custom_value=True,
|
| 489 |
+
info="use `unknown` for inst"
|
| 490 |
)
|
| 491 |
bpm = gr.Number(
|
| 492 |
label="BPM (optional)",
|
|
|
|
| 495 |
info="leave empty for N/A"
|
| 496 |
)
|
| 497 |
key_scale = gr.Textbox(
|
| 498 |
+
label="KeyScale (optional)",
|
| 499 |
placeholder="Leave empty for N/A",
|
| 500 |
value="",
|
| 501 |
+
info="A-G, #/♭, major/minor"
|
| 502 |
)
|
| 503 |
time_signature = gr.Dropdown(
|
| 504 |
choices=["2", "3", "4", "N/A", ""],
|
| 505 |
value="4",
|
| 506 |
label="Time Signature (optional)",
|
| 507 |
+
allow_custom_value=True,
|
| 508 |
+
info="2/4, 3/4, 4/4..."
|
| 509 |
)
|
| 510 |
audio_duration = gr.Number(
|
| 511 |
label="Audio Duration (seconds)",
|
|
|
|
| 517 |
)
|
| 518 |
batch_size_input = gr.Number(
|
| 519 |
label="Batch Size",
|
| 520 |
+
value=2,
|
| 521 |
minimum=1,
|
| 522 |
maximum=8,
|
| 523 |
step=1,
|
|
|
|
| 602 |
generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg", interactive=generate_btn_interactive)
|
| 603 |
|
| 604 |
return {
|
| 605 |
+
"service_config_accordion": service_config_accordion,
|
| 606 |
"checkpoint_dropdown": checkpoint_dropdown,
|
| 607 |
"refresh_btn": refresh_btn,
|
| 608 |
"config_path": config_path,
|
|
|
|
| 619 |
"instruction_display_gen": instruction_display_gen,
|
| 620 |
"track_name": track_name,
|
| 621 |
"complete_track_classes": complete_track_classes,
|
| 622 |
+
"audio_uploads_accordion": audio_uploads_accordion,
|
| 623 |
"reference_audio": reference_audio,
|
| 624 |
"src_audio": src_audio,
|
| 625 |
+
"convert_src_to_codes_btn": convert_src_to_codes_btn,
|
| 626 |
"text2music_audio_code_string": text2music_audio_code_string,
|
| 627 |
"text2music_audio_codes_group": text2music_audio_codes_group,
|
| 628 |
"use_5hz_lm_row": use_5hz_lm_row,
|
|
|
|
| 672 |
type="filepath",
|
| 673 |
interactive=False
|
| 674 |
)
|
| 675 |
+
send_to_src_btn_1 = gr.Button(
|
| 676 |
+
"Send To Src Audio",
|
| 677 |
+
variant="secondary",
|
| 678 |
+
size="sm"
|
| 679 |
+
)
|
| 680 |
with gr.Column():
|
| 681 |
generated_audio_2 = gr.Audio(
|
| 682 |
label="🎵 Generated Music (Sample 2)",
|
| 683 |
type="filepath",
|
| 684 |
interactive=False
|
| 685 |
)
|
| 686 |
+
send_to_src_btn_2 = gr.Button(
|
| 687 |
+
"Send To Src Audio",
|
| 688 |
+
variant="secondary",
|
| 689 |
+
size="sm"
|
| 690 |
+
)
|
| 691 |
|
| 692 |
with gr.Accordion("📁 Batch Results & Generation Details", open=False):
|
| 693 |
generated_audio_batch = gr.File(
|
|
|
|
| 712 |
"status_output": status_output,
|
| 713 |
"generated_audio_1": generated_audio_1,
|
| 714 |
"generated_audio_2": generated_audio_2,
|
| 715 |
+
"send_to_src_btn_1": send_to_src_btn_1,
|
| 716 |
+
"send_to_src_btn_2": send_to_src_btn_2,
|
| 717 |
"generated_audio_batch": generated_audio_batch,
|
| 718 |
"generation_info": generation_info,
|
| 719 |
"align_score_1": align_score_1,
|
|
|
|
| 802 |
|
| 803 |
# Service initialization
|
| 804 |
def init_service_wrapper(checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
|
| 805 |
+
"""Wrapper for service initialization, returns status, button state, and accordion state"""
|
| 806 |
# Initialize DiT handler
|
| 807 |
status, enable = dit_handler.initialize_service(
|
| 808 |
checkpoint, config_path, device,
|
|
|
|
| 833 |
# Don't fail the entire initialization if LM fails, but log it
|
| 834 |
# Keep enable as is (DiT initialization result) even if LM fails
|
| 835 |
|
| 836 |
+
# Check if model is initialized - if so, collapse the accordion
|
| 837 |
+
is_model_initialized = dit_handler.model is not None
|
| 838 |
+
accordion_state = gr.update(open=not is_model_initialized)
|
| 839 |
+
|
| 840 |
+
return status, gr.update(interactive=enable), accordion_state
|
| 841 |
|
| 842 |
# Update negative prompt visibility based on "Initialize 5Hz LM" checkbox
|
| 843 |
def update_negative_prompt_visibility(init_llm_checked):
|
|
|
|
| 893 |
generation_section["offload_to_cpu_checkbox"],
|
| 894 |
generation_section["offload_dit_to_cpu_checkbox"],
|
| 895 |
],
|
| 896 |
+
outputs=[generation_section["init_status"], generation_section["generate_btn"], generation_section["service_config_accordion"]]
|
| 897 |
)
|
| 898 |
|
| 899 |
# Generation with progress bar
|
|
|
|
| 1030 |
]
|
| 1031 |
)
|
| 1032 |
|
| 1033 |
+
# Convert src audio to codes
|
| 1034 |
+
def convert_src_audio_to_codes_wrapper(src_audio):
|
| 1035 |
+
"""Wrapper for converting src audio to codes"""
|
| 1036 |
+
codes_string = dit_handler.convert_src_audio_to_codes(src_audio)
|
| 1037 |
+
return codes_string
|
| 1038 |
+
|
| 1039 |
+
generation_section["convert_src_to_codes_btn"].click(
|
| 1040 |
+
fn=convert_src_audio_to_codes_wrapper,
|
| 1041 |
+
inputs=[generation_section["src_audio"]],
|
| 1042 |
+
outputs=[generation_section["text2music_audio_code_string"]]
|
| 1043 |
+
)
|
| 1044 |
+
|
| 1045 |
# Update instruction and UI visibility based on task type
|
| 1046 |
def update_instruction_ui(
|
| 1047 |
task_type_value: str,
|
|
|
|
| 1070 |
else:
|
| 1071 |
audio_cover_strength_label = "Audio Cover Strength"
|
| 1072 |
audio_cover_strength_info = "Control how many denoising steps use cover mode"
|
|
|
|
|
|
|
| 1073 |
# Show repainting controls for repaint and lego
|
| 1074 |
repainting_visible = task_type_value in ["repaint", "lego"]
|
| 1075 |
# Show use_5hz_lm, lm_temperature for text2music
|
|
|
|
| 1085 |
gr.update(visible=complete_visible), # complete_track_classes
|
| 1086 |
gr.update(visible=audio_cover_strength_visible, label=audio_cover_strength_label, info=audio_cover_strength_info), # audio_cover_strength
|
| 1087 |
gr.update(visible=repainting_visible), # repainting_group
|
|
|
|
| 1088 |
gr.update(visible=use_5hz_lm_visible), # use_5hz_lm_row
|
| 1089 |
gr.update(visible=text2music_audio_codes_visible), # text2music_audio_codes_group
|
| 1090 |
)
|
|
|
|
| 1105 |
generation_section["complete_track_classes"],
|
| 1106 |
generation_section["audio_cover_strength"],
|
| 1107 |
generation_section["repainting_group"],
|
|
|
|
| 1108 |
generation_section["use_5hz_lm_row"],
|
| 1109 |
generation_section["text2music_audio_codes_group"],
|
| 1110 |
]
|
|
|
|
| 1126 |
generation_section["complete_track_classes"],
|
| 1127 |
generation_section["audio_cover_strength"],
|
| 1128 |
generation_section["repainting_group"],
|
|
|
|
| 1129 |
generation_section["use_5hz_lm_row"],
|
| 1130 |
generation_section["text2music_audio_codes_group"],
|
| 1131 |
]
|
|
|
|
| 1147 |
generation_section["complete_track_classes"],
|
| 1148 |
generation_section["audio_cover_strength"],
|
| 1149 |
generation_section["repainting_group"],
|
|
|
|
| 1150 |
generation_section["use_5hz_lm_row"],
|
| 1151 |
generation_section["text2music_audio_codes_group"],
|
| 1152 |
]
|
| 1153 |
)
|
| 1154 |
+
|
| 1155 |
+
# Send generated audio to src_audio
|
| 1156 |
+
def send_audio_to_src(audio_file):
|
| 1157 |
+
"""Send generated audio file to src_audio input"""
|
| 1158 |
+
if audio_file is None:
|
| 1159 |
+
return None
|
| 1160 |
+
return audio_file
|
| 1161 |
+
|
| 1162 |
+
results_section["send_to_src_btn_1"].click(
|
| 1163 |
+
fn=send_audio_to_src,
|
| 1164 |
+
inputs=[results_section["generated_audio_1"]],
|
| 1165 |
+
outputs=[generation_section["src_audio"]]
|
| 1166 |
+
)
|
| 1167 |
+
|
| 1168 |
+
results_section["send_to_src_btn_2"].click(
|
| 1169 |
+
fn=send_audio_to_src,
|
| 1170 |
+
inputs=[results_section["generated_audio_2"]],
|
| 1171 |
+
outputs=[generation_section["src_audio"]]
|
| 1172 |
+
)
|
| 1173 |
+
|
| 1174 |
+
# Auto-expand Audio Uploads accordion when audio is uploaded
|
| 1175 |
+
def update_audio_uploads_accordion(reference_audio, src_audio):
|
| 1176 |
+
"""Update Audio Uploads accordion open state based on whether audio files are present"""
|
| 1177 |
+
has_audio = (reference_audio is not None) or (src_audio is not None)
|
| 1178 |
+
return gr.update(open=has_audio)
|
| 1179 |
+
|
| 1180 |
+
# Bind to both audio components' change events
|
| 1181 |
+
generation_section["reference_audio"].change(
|
| 1182 |
+
fn=update_audio_uploads_accordion,
|
| 1183 |
+
inputs=[generation_section["reference_audio"], generation_section["src_audio"]],
|
| 1184 |
+
outputs=[generation_section["audio_uploads_accordion"]]
|
| 1185 |
+
)
|
| 1186 |
+
|
| 1187 |
+
generation_section["src_audio"].change(
|
| 1188 |
+
fn=update_audio_uploads_accordion,
|
| 1189 |
+
inputs=[generation_section["reference_audio"], generation_section["src_audio"]],
|
| 1190 |
+
outputs=[generation_section["audio_uploads_accordion"]]
|
| 1191 |
+
)
|
| 1192 |
|
acestep/handler.py
CHANGED
|
@@ -504,6 +504,49 @@ class AceStepHandler:
|
|
| 504 |
|
| 505 |
return parsed_metas
|
| 506 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 508 |
"""Get text hidden states from text encoder."""
|
| 509 |
if self.text_tokenizer is None or self.text_encoder is None:
|
|
@@ -765,6 +808,71 @@ class AceStepHandler:
|
|
| 765 |
except Exception as e:
|
| 766 |
logger.error(f"Error processing target audio: {e}")
|
| 767 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 768 |
|
| 769 |
def prepare_batch_data(
|
| 770 |
self,
|
|
@@ -1919,10 +2027,15 @@ class AceStepHandler:
|
|
| 1919 |
refer_audios = [[torch.zeros(2, 30*self.sample_rate)] for _ in range(actual_batch_size)]
|
| 1920 |
|
| 1921 |
# 2. Process source audio
|
|
|
|
| 1922 |
processed_src_audio = None
|
| 1923 |
if src_audio is not None:
|
| 1924 |
-
|
| 1925 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1926 |
|
| 1927 |
# 3. Prepare batch data
|
| 1928 |
captions_batch, instructions_batch, lyrics_batch, vocal_languages_batch, metas_batch = self.prepare_batch_data(
|
|
|
|
| 504 |
|
| 505 |
return parsed_metas
|
| 506 |
|
| 507 |
+
def build_dit_inputs(
|
| 508 |
+
self,
|
| 509 |
+
task: str,
|
| 510 |
+
instruction: Optional[str],
|
| 511 |
+
caption: str,
|
| 512 |
+
lyrics: str,
|
| 513 |
+
metas: Optional[Union[str, Dict[str, Any]]] = None,
|
| 514 |
+
vocal_language: str = "en",
|
| 515 |
+
) -> Tuple[str, str]:
|
| 516 |
+
"""
|
| 517 |
+
Build text inputs for the caption and lyric branches used by DiT.
|
| 518 |
+
|
| 519 |
+
Args:
|
| 520 |
+
task: Task name (e.g., text2music, cover, repaint); kept for logging/future branching.
|
| 521 |
+
instruction: Instruction text; default fallback matches service_generate behavior.
|
| 522 |
+
caption: Caption string.
|
| 523 |
+
lyrics: Lyrics string.
|
| 524 |
+
metas: Metadata (str or dict); follows _parse_metas formatting.
|
| 525 |
+
vocal_language: Language code for lyrics section.
|
| 526 |
+
|
| 527 |
+
Returns:
|
| 528 |
+
(caption_input_text, lyrics_input_text)
|
| 529 |
+
|
| 530 |
+
Example:
|
| 531 |
+
caption_input, lyrics_input = handler.build_dit_inputs(
|
| 532 |
+
task="text2music",
|
| 533 |
+
instruction=None,
|
| 534 |
+
caption="A calm piano melody",
|
| 535 |
+
lyrics="la la la",
|
| 536 |
+
metas={"bpm": 90, "duration": 45},
|
| 537 |
+
vocal_language="en",
|
| 538 |
+
)
|
| 539 |
+
"""
|
| 540 |
+
# Align instruction formatting with _prepare_batch
|
| 541 |
+
final_instruction = instruction or "Fill the audio semantic mask based on the given conditions:"
|
| 542 |
+
if not final_instruction.endswith(":"):
|
| 543 |
+
final_instruction = final_instruction + ":"
|
| 544 |
+
|
| 545 |
+
parsed_meta = self._parse_metas([metas])[0]
|
| 546 |
+
caption_input = SFT_GEN_PROMPT.format(final_instruction, caption, parsed_meta)
|
| 547 |
+
lyrics_input = f"# Languages\n{vocal_language}\n\n# Lyric\n{lyrics}<|endoftext|>"
|
| 548 |
+
return caption_input, lyrics_input
|
| 549 |
+
|
| 550 |
def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 551 |
"""Get text hidden states from text encoder."""
|
| 552 |
if self.text_tokenizer is None or self.text_encoder is None:
|
|
|
|
| 808 |
except Exception as e:
|
| 809 |
logger.error(f"Error processing target audio: {e}")
|
| 810 |
return None
|
| 811 |
+
|
| 812 |
+
def convert_src_audio_to_codes(self, audio_file) -> str:
|
| 813 |
+
"""
|
| 814 |
+
Convert uploaded source audio to audio codes string.
|
| 815 |
+
|
| 816 |
+
Args:
|
| 817 |
+
audio_file: Path to audio file or None
|
| 818 |
+
|
| 819 |
+
Returns:
|
| 820 |
+
Formatted codes string like '<|audio_code_123|><|audio_code_456|>...' or error message
|
| 821 |
+
"""
|
| 822 |
+
if audio_file is None:
|
| 823 |
+
return "❌ Please upload source audio first"
|
| 824 |
+
|
| 825 |
+
if self.model is None or self.vae is None:
|
| 826 |
+
return "❌ Model not initialized. Please initialize the service first."
|
| 827 |
+
|
| 828 |
+
try:
|
| 829 |
+
# Process audio file
|
| 830 |
+
processed_audio = self.process_src_audio(audio_file)
|
| 831 |
+
if processed_audio is None:
|
| 832 |
+
return "❌ Failed to process audio file"
|
| 833 |
+
|
| 834 |
+
# Encode audio to latents using VAE
|
| 835 |
+
with torch.no_grad():
|
| 836 |
+
with self._load_model_context("vae"):
|
| 837 |
+
# Prepare audio for VAE: [channels, samples] -> [1, channels, samples]
|
| 838 |
+
vae_input = processed_audio.unsqueeze(0).to(self.device).to(self.vae.dtype)
|
| 839 |
+
|
| 840 |
+
# Check if audio is silence
|
| 841 |
+
if self.is_silence(vae_input):
|
| 842 |
+
return "❌ Audio file appears to be silent"
|
| 843 |
+
|
| 844 |
+
# Encode to latents
|
| 845 |
+
latents = self.vae.encode(vae_input).latent_dist.sample()
|
| 846 |
+
# Cast back to model dtype
|
| 847 |
+
latents = latents.to(self.dtype)
|
| 848 |
+
# Transpose: [1, d, T] -> [1, T, d] -> [T, d]
|
| 849 |
+
latents = latents.squeeze(0).transpose(0, 1) # [T, d]
|
| 850 |
+
|
| 851 |
+
# Create attention mask for latents
|
| 852 |
+
attention_mask = torch.ones(latents.shape[0], dtype=torch.bool, device=self.device)
|
| 853 |
+
|
| 854 |
+
# Tokenize latents to get code indices
|
| 855 |
+
with self._load_model_context("model"):
|
| 856 |
+
# Prepare latents for tokenize: [T, d] -> [1, T, d]
|
| 857 |
+
hidden_states = latents.unsqueeze(0) # [1, T, d]
|
| 858 |
+
|
| 859 |
+
# Call tokenize method
|
| 860 |
+
# tokenize returns: (quantized, indices, attention_mask)
|
| 861 |
+
_, indices, _ = self.model.tokenize(hidden_states, self.silence_latent, attention_mask.unsqueeze(0))
|
| 862 |
+
|
| 863 |
+
# Format indices as code string
|
| 864 |
+
# indices shape: [1, T_5Hz] or [1, T_5Hz, num_quantizers]
|
| 865 |
+
# Flatten and convert to list
|
| 866 |
+
indices_flat = indices.flatten().cpu().tolist()
|
| 867 |
+
codes_string = "".join([f"<|audio_code_{idx}|>" for idx in indices_flat])
|
| 868 |
+
|
| 869 |
+
logger.info(f"[convert_src_audio_to_codes] Generated {len(indices_flat)} audio codes")
|
| 870 |
+
return codes_string
|
| 871 |
+
|
| 872 |
+
except Exception as e:
|
| 873 |
+
error_msg = f"❌ Error converting audio to codes: {str(e)}\n{traceback.format_exc()}"
|
| 874 |
+
logger.error(error_msg)
|
| 875 |
+
return error_msg
|
| 876 |
|
| 877 |
def prepare_batch_data(
|
| 878 |
self,
|
|
|
|
| 2027 |
refer_audios = [[torch.zeros(2, 30*self.sample_rate)] for _ in range(actual_batch_size)]
|
| 2028 |
|
| 2029 |
# 2. Process source audio
|
| 2030 |
+
# If audio_code_string is provided, ignore src_audio and use codes instead
|
| 2031 |
processed_src_audio = None
|
| 2032 |
if src_audio is not None:
|
| 2033 |
+
# Check if audio codes are provided - if so, ignore src_audio
|
| 2034 |
+
if audio_code_string and str(audio_code_string).strip():
|
| 2035 |
+
logger.info("[generate_music] Audio codes provided, ignoring src_audio and using codes instead")
|
| 2036 |
+
else:
|
| 2037 |
+
logger.info("[generate_music] Processing source audio...")
|
| 2038 |
+
processed_src_audio = self.process_src_audio(src_audio)
|
| 2039 |
|
| 2040 |
# 3. Prepare batch data
|
| 2041 |
captions_batch, instructions_batch, lyrics_batch, vocal_languages_batch, metas_batch = self.prepare_batch_data(
|
acestep/llm_inference.py
CHANGED
|
@@ -11,20 +11,15 @@ from contextlib import contextmanager
|
|
| 11 |
import torch
|
| 12 |
from tqdm import tqdm
|
| 13 |
from loguru import logger
|
| 14 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 15 |
from transformers.generation.streamers import BaseStreamer
|
| 16 |
from transformers.generation.logits_process import (
|
| 17 |
LogitsProcessorList,
|
| 18 |
-
LogitsProcessor,
|
| 19 |
-
TopKLogitsWarper,
|
| 20 |
-
TopPLogitsWarper,
|
| 21 |
RepetitionPenaltyLogitsProcessor,
|
| 22 |
-
|
| 23 |
)
|
| 24 |
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
class LLMHandler:
|
| 29 |
"""5Hz LM Handler for audio code generation"""
|
| 30 |
|
|
@@ -234,16 +229,7 @@ class LLMHandler:
|
|
| 234 |
try:
|
| 235 |
from nanovllm import SamplingParams
|
| 236 |
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
formatted_prompt = self.llm_tokenizer.apply_chat_template(
|
| 240 |
-
[
|
| 241 |
-
{"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
|
| 242 |
-
{"role": "user", "content": prompt}
|
| 243 |
-
],
|
| 244 |
-
tokenize=False,
|
| 245 |
-
add_generation_prompt=True,
|
| 246 |
-
)
|
| 247 |
logger.debug(f"[debug] formatted_prompt: {formatted_prompt}")
|
| 248 |
|
| 249 |
sampling_params = SamplingParams(
|
|
@@ -257,14 +243,7 @@ class LLMHandler:
|
|
| 257 |
# Use CFG if cfg_scale > 1.0
|
| 258 |
if cfg_scale > 1.0:
|
| 259 |
# Build unconditional prompt (user input replaced with "NO USER INPUT")
|
| 260 |
-
formatted_unconditional_prompt = self.
|
| 261 |
-
[
|
| 262 |
-
{"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
|
| 263 |
-
{"role": "user", "content": negative_prompt}
|
| 264 |
-
],
|
| 265 |
-
tokenize=False,
|
| 266 |
-
add_generation_prompt=True,
|
| 267 |
-
)
|
| 268 |
outputs = self.llm.generate(
|
| 269 |
[formatted_prompt],
|
| 270 |
sampling_params,
|
|
@@ -293,6 +272,53 @@ class LLMHandler:
|
|
| 293 |
error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
|
| 294 |
return {}, "", error_msg
|
| 295 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
def generate_with_5hz_lm_pt(
|
| 297 |
self,
|
| 298 |
caption: str,
|
|
@@ -306,23 +332,13 @@ class LLMHandler:
|
|
| 306 |
) -> Tuple[Dict[str, Any], str, str]:
|
| 307 |
"""Generate metadata and audio codes using 5Hz LM with PyTorch backend"""
|
| 308 |
try:
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
formatted_prompt = self.llm_tokenizer.apply_chat_template(
|
| 312 |
-
[
|
| 313 |
-
{"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
|
| 314 |
-
{"role": "user", "content": prompt}
|
| 315 |
-
],
|
| 316 |
-
tokenize=False,
|
| 317 |
-
add_generation_prompt=True,
|
| 318 |
-
)
|
| 319 |
|
| 320 |
# Tokenize the prompt
|
| 321 |
inputs = self.llm_tokenizer(
|
| 322 |
formatted_prompt,
|
| 323 |
return_tensors="pt",
|
| 324 |
padding=False,
|
| 325 |
-
truncation=True,
|
| 326 |
)
|
| 327 |
|
| 328 |
# Generate with the model
|
|
@@ -352,82 +368,90 @@ class LLMHandler:
|
|
| 352 |
|
| 353 |
streamer = TqdmTokenStreamer(total=max_new_tokens)
|
| 354 |
|
| 355 |
-
# Build logits processor list
|
| 356 |
logits_processor = LogitsProcessorList()
|
| 357 |
|
| 358 |
-
# Add repetition penalty if needed
|
| 359 |
if repetition_penalty != 1.0:
|
| 360 |
logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
| 361 |
|
| 362 |
-
# Add temperature warper if needed (temperature is handled separately in generate, but we can also use warper)
|
| 363 |
-
# Note: temperature is passed directly to generate(), but we can use TemperatureLogitsWarper for consistency
|
| 364 |
-
if temperature != 1.0:
|
| 365 |
-
logits_processor.append(TemperatureLogitsWarper(temperature=temperature))
|
| 366 |
-
|
| 367 |
-
# Add top-k warper if specified
|
| 368 |
-
if top_k is not None and top_k > 0:
|
| 369 |
-
logits_processor.append(TopKLogitsWarper(top_k=top_k))
|
| 370 |
-
|
| 371 |
-
# Add top-p warper if specified
|
| 372 |
-
if top_p is not None and top_p > 0.0 and top_p < 1.0:
|
| 373 |
-
logits_processor.append(TopPLogitsWarper(top_p=top_p))
|
| 374 |
-
|
| 375 |
# Handle CFG if cfg_scale > 1.0
|
| 376 |
if cfg_scale > 1.0:
|
| 377 |
# Build unconditional prompt
|
| 378 |
-
formatted_unconditional_prompt = self.
|
| 379 |
-
[
|
| 380 |
-
{"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
|
| 381 |
-
{"role": "user", "content": negative_prompt}
|
| 382 |
-
],
|
| 383 |
-
tokenize=False,
|
| 384 |
-
add_generation_prompt=True,
|
| 385 |
-
)
|
| 386 |
|
| 387 |
-
# Tokenize
|
| 388 |
-
|
| 389 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
return_tensors="pt",
|
| 391 |
-
padding=
|
| 392 |
truncation=True,
|
| 393 |
)
|
| 394 |
-
|
|
|
|
| 395 |
|
| 396 |
-
#
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
batch_input_ids = torch.cat([inputs['input_ids'], uncond_inputs['input_ids']], dim=0)
|
| 400 |
-
batch_attention_mask = None
|
| 401 |
-
if 'attention_mask' in inputs:
|
| 402 |
-
batch_attention_mask = torch.cat([inputs['attention_mask'], uncond_inputs.get('attention_mask', torch.ones_like(uncond_inputs['input_ids']))], dim=0)
|
| 403 |
|
| 404 |
-
#
|
| 405 |
-
outputs = self.
|
| 406 |
batch_input_ids=batch_input_ids,
|
| 407 |
batch_attention_mask=batch_attention_mask,
|
| 408 |
max_new_tokens=max_new_tokens,
|
| 409 |
temperature=temperature,
|
| 410 |
cfg_scale=cfg_scale,
|
| 411 |
-
|
|
|
|
|
|
|
| 412 |
pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
|
| 413 |
streamer=streamer,
|
| 414 |
)
|
|
|
|
|
|
|
|
|
|
| 415 |
else:
|
| 416 |
-
# Generate without CFG
|
| 417 |
with torch.no_grad():
|
| 418 |
outputs = self.llm.generate(
|
| 419 |
**inputs,
|
| 420 |
max_new_tokens=max_new_tokens,
|
| 421 |
temperature=temperature if temperature > 0 else 1.0,
|
| 422 |
do_sample=True if temperature > 0 else False,
|
|
|
|
|
|
|
| 423 |
logits_processor=logits_processor if len(logits_processor) > 0 else None,
|
| 424 |
pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
|
| 425 |
streamer=streamer,
|
| 426 |
)
|
| 427 |
|
| 428 |
# Decode the generated tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
# Only decode the newly generated tokens (skip the input prompt)
|
| 430 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
|
| 432 |
|
| 433 |
metadata, audio_codes = self.parse_lm_output(output_text)
|
|
@@ -436,8 +460,120 @@ class LLMHandler:
|
|
| 436 |
|
| 437 |
except Exception as e:
|
| 438 |
error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
|
|
|
|
| 439 |
return {}, "", error_msg
|
| 440 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 441 |
def generate_with_5hz_lm(
|
| 442 |
self,
|
| 443 |
caption: str,
|
|
@@ -474,103 +610,115 @@ class LLMHandler:
|
|
| 474 |
caption, lyrics, temperature, cfg_scale, negative_prompt,
|
| 475 |
top_k, top_p, repetition_penalty
|
| 476 |
)
|
| 477 |
-
|
| 478 |
-
def
|
| 479 |
"""
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
Returns:
|
| 494 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
"""
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
if len(parts) == 2:
|
| 537 |
-
key = parts[0].strip().lower()
|
| 538 |
-
value = parts[1].strip()
|
| 539 |
-
|
| 540 |
-
if key == 'bpm':
|
| 541 |
-
try:
|
| 542 |
-
metadata['bpm'] = int(value)
|
| 543 |
-
except:
|
| 544 |
-
metadata['bpm'] = value
|
| 545 |
-
elif key == 'duration':
|
| 546 |
-
try:
|
| 547 |
-
metadata['duration'] = int(value)
|
| 548 |
-
except:
|
| 549 |
-
metadata['duration'] = value
|
| 550 |
-
elif key == 'genres':
|
| 551 |
-
metadata['genres'] = value
|
| 552 |
-
elif key == 'keyscale':
|
| 553 |
-
metadata['keyscale'] = value
|
| 554 |
-
elif key == 'timesignature':
|
| 555 |
-
metadata['timesignature'] = value
|
| 556 |
-
|
| 557 |
-
return metadata, audio_codes
|
| 558 |
|
| 559 |
-
def
|
| 560 |
self,
|
| 561 |
batch_input_ids: torch.Tensor,
|
| 562 |
batch_attention_mask: Optional[torch.Tensor],
|
| 563 |
max_new_tokens: int,
|
| 564 |
temperature: float,
|
| 565 |
cfg_scale: float,
|
| 566 |
-
|
|
|
|
|
|
|
| 567 |
pad_token_id: int,
|
| 568 |
streamer: Optional[BaseStreamer],
|
| 569 |
) -> torch.Tensor:
|
| 570 |
"""
|
| 571 |
-
Custom generation loop
|
| 572 |
-
|
| 573 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 574 |
"""
|
| 575 |
model = self.llm
|
| 576 |
device = self.device
|
|
@@ -594,6 +742,16 @@ class LLMHandler:
|
|
| 594 |
past_key_values = None
|
| 595 |
use_cache = hasattr(model, 'generation_config') and getattr(model.generation_config, 'use_cache', True)
|
| 596 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 597 |
with torch.no_grad():
|
| 598 |
for step in range(max_new_tokens):
|
| 599 |
# Forward pass for the entire batch (conditional + unconditional)
|
|
@@ -613,22 +771,38 @@ class LLMHandler:
|
|
| 613 |
use_cache=use_cache,
|
| 614 |
)
|
| 615 |
|
| 616 |
-
# Get logits
|
| 617 |
next_token_logits = outputs.logits[:, -1, :] # [batch_size*2, vocab_size]
|
| 618 |
|
| 619 |
# Split conditional and unconditional logits
|
| 620 |
cond_logits = next_token_logits[cond_start_idx:cond_start_idx+batch_size]
|
| 621 |
uncond_logits = next_token_logits[uncond_start_idx:uncond_start_idx+batch_size]
|
| 622 |
|
| 623 |
-
# Apply CFG formula:
|
| 624 |
cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
|
| 625 |
|
| 626 |
-
# Apply logits processors (
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 632 |
|
| 633 |
# Apply temperature and sample
|
| 634 |
if temperature > 0:
|
|
@@ -638,9 +812,19 @@ class LLMHandler:
|
|
| 638 |
else:
|
| 639 |
next_tokens = torch.argmax(cfg_logits, dim=-1)
|
| 640 |
|
| 641 |
-
#
|
| 642 |
-
|
| 643 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 644 |
attention_mask = torch.cat([attention_mask, torch.ones((batch_size*2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
|
| 645 |
model_kwargs['attention_mask'] = attention_mask
|
| 646 |
|
|
@@ -650,17 +834,99 @@ class LLMHandler:
|
|
| 650 |
|
| 651 |
# Update streamer
|
| 652 |
if streamer is not None:
|
| 653 |
-
streamer.put(
|
| 654 |
|
| 655 |
-
#
|
| 656 |
-
if
|
| 657 |
break
|
| 658 |
|
| 659 |
if streamer is not None:
|
| 660 |
streamer.end()
|
| 661 |
|
| 662 |
-
# Return
|
| 663 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 664 |
|
| 665 |
@contextmanager
|
| 666 |
def _load_model_context(self):
|
|
|
|
| 11 |
import torch
|
| 12 |
from tqdm import tqdm
|
| 13 |
from loguru import logger
|
| 14 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 15 |
from transformers.generation.streamers import BaseStreamer
|
| 16 |
from transformers.generation.logits_process import (
|
| 17 |
LogitsProcessorList,
|
|
|
|
|
|
|
|
|
|
| 18 |
RepetitionPenaltyLogitsProcessor,
|
| 19 |
+
LogitsProcessor,
|
| 20 |
)
|
| 21 |
|
| 22 |
|
|
|
|
|
|
|
| 23 |
class LLMHandler:
|
| 24 |
"""5Hz LM Handler for audio code generation"""
|
| 25 |
|
|
|
|
| 229 |
try:
|
| 230 |
from nanovllm import SamplingParams
|
| 231 |
|
| 232 |
+
formatted_prompt = self.build_formatted_prompt(caption, lyrics)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
logger.debug(f"[debug] formatted_prompt: {formatted_prompt}")
|
| 234 |
|
| 235 |
sampling_params = SamplingParams(
|
|
|
|
| 243 |
# Use CFG if cfg_scale > 1.0
|
| 244 |
if cfg_scale > 1.0:
|
| 245 |
# Build unconditional prompt (user input replaced with "NO USER INPUT")
|
| 246 |
+
formatted_unconditional_prompt = self.build_formatted_prompt(negative_prompt, is_negative_prompt=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
outputs = self.llm.generate(
|
| 248 |
[formatted_prompt],
|
| 249 |
sampling_params,
|
|
|
|
| 272 |
error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
|
| 273 |
return {}, "", error_msg
|
| 274 |
|
| 275 |
+
def _run_vllm_from_formatted(
|
| 276 |
+
self,
|
| 277 |
+
formatted_prompt: str,
|
| 278 |
+
temperature: float,
|
| 279 |
+
cfg_scale: float,
|
| 280 |
+
negative_prompt: str,
|
| 281 |
+
top_k: Optional[int],
|
| 282 |
+
top_p: Optional[float],
|
| 283 |
+
repetition_penalty: float,
|
| 284 |
+
) -> str:
|
| 285 |
+
"""Shared vllm path: accept prebuilt formatted prompt and return text."""
|
| 286 |
+
from nanovllm import SamplingParams
|
| 287 |
+
|
| 288 |
+
sampling_params = SamplingParams(
|
| 289 |
+
max_tokens=self.max_model_len - 64,
|
| 290 |
+
temperature=temperature,
|
| 291 |
+
cfg_scale=cfg_scale,
|
| 292 |
+
top_k=top_k,
|
| 293 |
+
top_p=top_p,
|
| 294 |
+
repetition_penalty=repetition_penalty,
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
if cfg_scale > 1.0:
|
| 298 |
+
formatted_unconditional_prompt = self.build_formatted_prompt(negative_prompt, is_negative_prompt=True)
|
| 299 |
+
outputs = self.llm.generate(
|
| 300 |
+
[formatted_prompt],
|
| 301 |
+
sampling_params,
|
| 302 |
+
unconditional_prompts=[formatted_unconditional_prompt],
|
| 303 |
+
)
|
| 304 |
+
else:
|
| 305 |
+
outputs = self.llm.generate([formatted_prompt], sampling_params)
|
| 306 |
+
|
| 307 |
+
# Extract text (retain original selection order/logic)
|
| 308 |
+
if isinstance(outputs, list) and len(outputs) > 0:
|
| 309 |
+
if hasattr(outputs[0], "outputs") and len(outputs[0].outputs) > 0:
|
| 310 |
+
output_text = outputs[0].outputs[0].text
|
| 311 |
+
elif hasattr(outputs[0], "text"):
|
| 312 |
+
output_text = outputs[0].text
|
| 313 |
+
elif isinstance(outputs[0], dict) and "text" in outputs[0]:
|
| 314 |
+
output_text = outputs[0]["text"]
|
| 315 |
+
else:
|
| 316 |
+
output_text = str(outputs[0])
|
| 317 |
+
else:
|
| 318 |
+
output_text = str(outputs)
|
| 319 |
+
|
| 320 |
+
return output_text
|
| 321 |
+
|
| 322 |
def generate_with_5hz_lm_pt(
|
| 323 |
self,
|
| 324 |
caption: str,
|
|
|
|
| 332 |
) -> Tuple[Dict[str, Any], str, str]:
|
| 333 |
"""Generate metadata and audio codes using 5Hz LM with PyTorch backend"""
|
| 334 |
try:
|
| 335 |
+
formatted_prompt = self.build_formatted_prompt(caption, lyrics)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
|
| 337 |
# Tokenize the prompt
|
| 338 |
inputs = self.llm_tokenizer(
|
| 339 |
formatted_prompt,
|
| 340 |
return_tensors="pt",
|
| 341 |
padding=False,
|
|
|
|
| 342 |
)
|
| 343 |
|
| 344 |
# Generate with the model
|
|
|
|
| 368 |
|
| 369 |
streamer = TqdmTokenStreamer(total=max_new_tokens)
|
| 370 |
|
| 371 |
+
# Build logits processor list (only for CFG and repetition penalty)
|
| 372 |
logits_processor = LogitsProcessorList()
|
| 373 |
|
| 374 |
+
# Add repetition penalty if needed (generate() doesn't support it natively in all versions)
|
| 375 |
if repetition_penalty != 1.0:
|
| 376 |
logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
| 377 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
# Handle CFG if cfg_scale > 1.0
|
| 379 |
if cfg_scale > 1.0:
|
| 380 |
# Build unconditional prompt
|
| 381 |
+
formatted_unconditional_prompt = self.build_formatted_prompt(negative_prompt, is_negative_prompt=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
|
| 383 |
+
# Tokenize both prompts together to ensure same length (with left padding)
|
| 384 |
+
# Left padding is important for generation tasks
|
| 385 |
+
batch_texts = [formatted_prompt, formatted_unconditional_prompt]
|
| 386 |
+
original_padding_side = self.llm_tokenizer.padding_side
|
| 387 |
+
self.llm_tokenizer.padding_side = 'left'
|
| 388 |
+
batch_inputs = self.llm_tokenizer(
|
| 389 |
+
batch_texts,
|
| 390 |
return_tensors="pt",
|
| 391 |
+
padding=True,
|
| 392 |
truncation=True,
|
| 393 |
)
|
| 394 |
+
self.llm_tokenizer.padding_side = original_padding_side
|
| 395 |
+
batch_inputs = {k: v.to(self.device) for k, v in batch_inputs.items()}
|
| 396 |
|
| 397 |
+
# Extract conditional and unconditional inputs
|
| 398 |
+
batch_input_ids = batch_inputs['input_ids'] # [2, seq_len]
|
| 399 |
+
batch_attention_mask = batch_inputs.get('attention_mask', None)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
|
| 401 |
+
# Use custom CFG generation loop
|
| 402 |
+
outputs = self._generate_with_cfg_custom(
|
| 403 |
batch_input_ids=batch_input_ids,
|
| 404 |
batch_attention_mask=batch_attention_mask,
|
| 405 |
max_new_tokens=max_new_tokens,
|
| 406 |
temperature=temperature,
|
| 407 |
cfg_scale=cfg_scale,
|
| 408 |
+
top_k=top_k,
|
| 409 |
+
top_p=top_p,
|
| 410 |
+
repetition_penalty=repetition_penalty,
|
| 411 |
pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
|
| 412 |
streamer=streamer,
|
| 413 |
)
|
| 414 |
+
|
| 415 |
+
# Extract only the conditional output (first in batch)
|
| 416 |
+
outputs = outputs[0:1] # Keep only conditional output
|
| 417 |
else:
|
| 418 |
+
# Generate without CFG using native generate() parameters
|
| 419 |
with torch.no_grad():
|
| 420 |
outputs = self.llm.generate(
|
| 421 |
**inputs,
|
| 422 |
max_new_tokens=max_new_tokens,
|
| 423 |
temperature=temperature if temperature > 0 else 1.0,
|
| 424 |
do_sample=True if temperature > 0 else False,
|
| 425 |
+
top_k=top_k if top_k is not None and top_k > 0 else None,
|
| 426 |
+
top_p=top_p if top_p is not None and 0.0 < top_p < 1.0 else None,
|
| 427 |
logits_processor=logits_processor if len(logits_processor) > 0 else None,
|
| 428 |
pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
|
| 429 |
streamer=streamer,
|
| 430 |
)
|
| 431 |
|
| 432 |
# Decode the generated tokens
|
| 433 |
+
# outputs is a tensor with shape [batch_size, seq_len], extract first sequence
|
| 434 |
+
if isinstance(outputs, torch.Tensor):
|
| 435 |
+
if outputs.dim() == 2:
|
| 436 |
+
generated_ids = outputs[0]
|
| 437 |
+
else:
|
| 438 |
+
generated_ids = outputs
|
| 439 |
+
else:
|
| 440 |
+
generated_ids = outputs[0]
|
| 441 |
+
|
| 442 |
# Only decode the newly generated tokens (skip the input prompt)
|
| 443 |
+
# Use the correct input length based on whether CFG was used
|
| 444 |
+
if cfg_scale > 1.0:
|
| 445 |
+
# In CFG case, use batch_inputs length (both sequences have same length due to padding)
|
| 446 |
+
input_length = batch_inputs['input_ids'].shape[1]
|
| 447 |
+
else:
|
| 448 |
+
input_length = inputs['input_ids'].shape[1]
|
| 449 |
+
generated_ids = generated_ids[input_length:]
|
| 450 |
+
|
| 451 |
+
# Move to CPU for decoding
|
| 452 |
+
if generated_ids.is_cuda:
|
| 453 |
+
generated_ids = generated_ids.cpu()
|
| 454 |
+
|
| 455 |
output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
|
| 456 |
|
| 457 |
metadata, audio_codes = self.parse_lm_output(output_text)
|
|
|
|
| 460 |
|
| 461 |
except Exception as e:
|
| 462 |
error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
|
| 463 |
+
logger.error(error_msg)
|
| 464 |
return {}, "", error_msg
|
| 465 |
|
| 466 |
+
def _run_pt_from_formatted(
|
| 467 |
+
self,
|
| 468 |
+
formatted_prompt: str,
|
| 469 |
+
temperature: float,
|
| 470 |
+
cfg_scale: float,
|
| 471 |
+
negative_prompt: str,
|
| 472 |
+
top_k: Optional[int],
|
| 473 |
+
top_p: Optional[float],
|
| 474 |
+
repetition_penalty: float,
|
| 475 |
+
) -> str:
|
| 476 |
+
"""Shared PyTorch path: accept prebuilt formatted prompt and return text."""
|
| 477 |
+
inputs = self.llm_tokenizer(
|
| 478 |
+
formatted_prompt,
|
| 479 |
+
return_tensors="pt",
|
| 480 |
+
padding=False,
|
| 481 |
+
truncation=True,
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
with self._load_model_context():
|
| 485 |
+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 486 |
+
max_new_tokens = getattr(self.llm.config, "max_new_tokens", 4096)
|
| 487 |
+
if hasattr(self, "max_model_len"):
|
| 488 |
+
max_new_tokens = min(max_new_tokens, self.max_model_len - 64)
|
| 489 |
+
|
| 490 |
+
# Build logits processor list (only for CFG and repetition penalty)
|
| 491 |
+
logits_processor = LogitsProcessorList()
|
| 492 |
+
|
| 493 |
+
# Add repetition penalty if needed
|
| 494 |
+
if repetition_penalty != 1.0:
|
| 495 |
+
logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
| 496 |
+
|
| 497 |
+
if cfg_scale > 1.0:
|
| 498 |
+
formatted_unconditional_prompt = self.build_formatted_prompt(negative_prompt, is_negative_prompt=True)
|
| 499 |
+
|
| 500 |
+
# Tokenize both prompts together to ensure same length (with left padding)
|
| 501 |
+
# Left padding is important for generation tasks
|
| 502 |
+
batch_texts = [formatted_prompt, formatted_unconditional_prompt]
|
| 503 |
+
original_padding_side = self.llm_tokenizer.padding_side
|
| 504 |
+
self.llm_tokenizer.padding_side = 'left'
|
| 505 |
+
batch_inputs_tokenized = self.llm_tokenizer(
|
| 506 |
+
batch_texts,
|
| 507 |
+
return_tensors="pt",
|
| 508 |
+
padding=True,
|
| 509 |
+
truncation=True,
|
| 510 |
+
)
|
| 511 |
+
self.llm_tokenizer.padding_side = original_padding_side
|
| 512 |
+
batch_inputs_tokenized = {k: v.to(self.device) for k, v in batch_inputs_tokenized.items()}
|
| 513 |
+
|
| 514 |
+
# Extract batch inputs
|
| 515 |
+
batch_input_ids = batch_inputs_tokenized['input_ids']
|
| 516 |
+
batch_attention_mask = batch_inputs_tokenized.get('attention_mask', None)
|
| 517 |
+
|
| 518 |
+
# Use custom CFG generation loop
|
| 519 |
+
outputs = self._generate_with_cfg_custom(
|
| 520 |
+
batch_input_ids=batch_input_ids,
|
| 521 |
+
batch_attention_mask=batch_attention_mask,
|
| 522 |
+
max_new_tokens=max_new_tokens,
|
| 523 |
+
temperature=temperature,
|
| 524 |
+
cfg_scale=cfg_scale,
|
| 525 |
+
top_k=top_k,
|
| 526 |
+
top_p=top_p,
|
| 527 |
+
repetition_penalty=repetition_penalty,
|
| 528 |
+
pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
|
| 529 |
+
streamer=None,
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
# Extract only the conditional output (first in batch)
|
| 533 |
+
outputs = outputs[0:1] # Keep only conditional output
|
| 534 |
+
else:
|
| 535 |
+
# Generate without CFG using native generate() parameters
|
| 536 |
+
with torch.no_grad():
|
| 537 |
+
outputs = self.llm.generate(
|
| 538 |
+
**inputs,
|
| 539 |
+
max_new_tokens=max_new_tokens,
|
| 540 |
+
temperature=temperature if temperature > 0 else 1.0,
|
| 541 |
+
do_sample=True if temperature > 0 else False,
|
| 542 |
+
top_k=top_k if top_k is not None and top_k > 0 else None,
|
| 543 |
+
top_p=top_p if top_p is not None and 0.0 < top_p < 1.0 else None,
|
| 544 |
+
logits_processor=logits_processor if len(logits_processor) > 0 else None,
|
| 545 |
+
pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
|
| 546 |
+
streamer=None,
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
# Decode the generated tokens
|
| 550 |
+
# outputs is a tensor with shape [batch_size, seq_len], extract first sequence
|
| 551 |
+
if isinstance(outputs, torch.Tensor):
|
| 552 |
+
if outputs.dim() == 2:
|
| 553 |
+
generated_ids = outputs[0]
|
| 554 |
+
else:
|
| 555 |
+
generated_ids = outputs
|
| 556 |
+
else:
|
| 557 |
+
generated_ids = outputs[0]
|
| 558 |
+
|
| 559 |
+
# Only decode the newly generated tokens (skip the input prompt)
|
| 560 |
+
# Use the original input length (before batch processing for CFG)
|
| 561 |
+
if cfg_scale > 1.0:
|
| 562 |
+
# In CFG case, we need to use the conditional input length from batch_inputs_tokenized
|
| 563 |
+
# Both sequences have the same length due to padding
|
| 564 |
+
input_length = batch_inputs_tokenized['input_ids'].shape[1]
|
| 565 |
+
else:
|
| 566 |
+
input_length = inputs["input_ids"].shape[1]
|
| 567 |
+
|
| 568 |
+
generated_ids = generated_ids[input_length:]
|
| 569 |
+
|
| 570 |
+
# Move to CPU for decoding
|
| 571 |
+
if generated_ids.is_cuda:
|
| 572 |
+
generated_ids = generated_ids.cpu()
|
| 573 |
+
|
| 574 |
+
output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
|
| 575 |
+
return output_text
|
| 576 |
+
|
| 577 |
def generate_with_5hz_lm(
|
| 578 |
self,
|
| 579 |
caption: str,
|
|
|
|
| 610 |
caption, lyrics, temperature, cfg_scale, negative_prompt,
|
| 611 |
top_k, top_p, repetition_penalty
|
| 612 |
)
|
| 613 |
+
|
| 614 |
+
def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False) -> str:
|
| 615 |
"""
|
| 616 |
+
Build the chat-formatted prompt for 5Hz LM from caption/lyrics.
|
| 617 |
+
Raises a ValueError if the tokenizer is not initialized.
|
| 618 |
+
|
| 619 |
+
Example:
|
| 620 |
+
prompt = handler.build_formatted_prompt("calm piano", "hello world")
|
| 621 |
+
"""
|
| 622 |
+
if self.llm_tokenizer is None:
|
| 623 |
+
raise ValueError("LLM tokenizer is not initialized. Call initialize() first.")
|
| 624 |
+
if is_negative_prompt:
|
| 625 |
+
prompt = caption
|
| 626 |
+
else:
|
| 627 |
+
prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
|
| 628 |
+
return self.llm_tokenizer.apply_chat_template(
|
| 629 |
+
[
|
| 630 |
+
{"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
|
| 631 |
+
{"role": "user", "content": prompt},
|
| 632 |
+
],
|
| 633 |
+
tokenize=False,
|
| 634 |
+
add_generation_prompt=True,
|
| 635 |
+
)
|
| 636 |
+
|
| 637 |
+
def generate_from_formatted_prompt(
|
| 638 |
+
self,
|
| 639 |
+
formatted_prompt: str,
|
| 640 |
+
cfg: Optional[Dict[str, Any]] = None,
|
| 641 |
+
) -> Tuple[str, str]:
|
| 642 |
+
"""
|
| 643 |
+
Generate raw LM text output from a pre-built formatted prompt.
|
| 644 |
+
|
| 645 |
+
Args:
|
| 646 |
+
formatted_prompt: Prompt that is already formatted by `build_formatted_prompt`.
|
| 647 |
+
cfg: Optional dict supporting keys:
|
| 648 |
+
- temperature (float)
|
| 649 |
+
- cfg_scale (float)
|
| 650 |
+
- negative_prompt (str) used when cfg_scale > 1
|
| 651 |
+
- top_k (int), top_p (float), repetition_penalty (float)
|
| 652 |
+
|
| 653 |
Returns:
|
| 654 |
+
(output_text, status_message)
|
| 655 |
+
|
| 656 |
+
Example:
|
| 657 |
+
prompt = handler.build_formatted_prompt(caption, lyric)
|
| 658 |
+
text, status = handler.generate_from_formatted_prompt(prompt, {"temperature": 0.7})
|
| 659 |
"""
|
| 660 |
+
if not getattr(self, "llm_initialized", False):
|
| 661 |
+
return "", "❌ 5Hz LM not initialized. Please initialize it first."
|
| 662 |
+
if self.llm is None or self.llm_tokenizer is None:
|
| 663 |
+
return "", "❌ 5Hz LM is missing model or tokenizer."
|
| 664 |
+
|
| 665 |
+
cfg = cfg or {}
|
| 666 |
+
temperature = cfg.get("temperature", 0.6)
|
| 667 |
+
cfg_scale = cfg.get("cfg_scale", 1.0)
|
| 668 |
+
negative_prompt = cfg.get("negative_prompt", "NO USER INPUT")
|
| 669 |
+
top_k = cfg.get("top_k")
|
| 670 |
+
top_p = cfg.get("top_p")
|
| 671 |
+
repetition_penalty = cfg.get("repetition_penalty", 1.0)
|
| 672 |
+
|
| 673 |
+
try:
|
| 674 |
+
if self.llm_backend == "vllm":
|
| 675 |
+
output_text = self._run_vllm_from_formatted(
|
| 676 |
+
formatted_prompt=formatted_prompt,
|
| 677 |
+
temperature=temperature,
|
| 678 |
+
cfg_scale=cfg_scale,
|
| 679 |
+
negative_prompt=negative_prompt,
|
| 680 |
+
top_k=top_k,
|
| 681 |
+
top_p=top_p,
|
| 682 |
+
repetition_penalty=repetition_penalty,
|
| 683 |
+
)
|
| 684 |
+
return output_text, f"✅ Generated successfully (vllm) | length={len(output_text)}"
|
| 685 |
+
|
| 686 |
+
# PyTorch backend
|
| 687 |
+
output_text = self._run_pt_from_formatted(
|
| 688 |
+
formatted_prompt=formatted_prompt,
|
| 689 |
+
temperature=temperature,
|
| 690 |
+
cfg_scale=cfg_scale,
|
| 691 |
+
negative_prompt=negative_prompt,
|
| 692 |
+
top_k=top_k,
|
| 693 |
+
top_p=top_p,
|
| 694 |
+
repetition_penalty=repetition_penalty,
|
| 695 |
+
)
|
| 696 |
+
return output_text, f"✅ Generated successfully (pt) | length={len(output_text)}"
|
| 697 |
+
|
| 698 |
+
except Exception as e:
|
| 699 |
+
return "", f"❌ Error generating from formatted prompt: {e}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 700 |
|
| 701 |
+
def _generate_with_cfg_custom(
|
| 702 |
self,
|
| 703 |
batch_input_ids: torch.Tensor,
|
| 704 |
batch_attention_mask: Optional[torch.Tensor],
|
| 705 |
max_new_tokens: int,
|
| 706 |
temperature: float,
|
| 707 |
cfg_scale: float,
|
| 708 |
+
top_k: Optional[int],
|
| 709 |
+
top_p: Optional[float],
|
| 710 |
+
repetition_penalty: float,
|
| 711 |
pad_token_id: int,
|
| 712 |
streamer: Optional[BaseStreamer],
|
| 713 |
) -> torch.Tensor:
|
| 714 |
"""
|
| 715 |
+
Custom CFG generation loop that:
|
| 716 |
+
1. Processes both conditional and unconditional sequences in parallel
|
| 717 |
+
2. Applies CFG formula to logits
|
| 718 |
+
3. Samples tokens only for conditional sequences
|
| 719 |
+
4. Applies the same sampled tokens to both conditional and unconditional sequences
|
| 720 |
+
|
| 721 |
+
Batch format: [cond_input, uncond_input]
|
| 722 |
"""
|
| 723 |
model = self.llm
|
| 724 |
device = self.device
|
|
|
|
| 742 |
past_key_values = None
|
| 743 |
use_cache = hasattr(model, 'generation_config') and getattr(model.generation_config, 'use_cache', True)
|
| 744 |
|
| 745 |
+
# Get EOS token ID for stopping condition
|
| 746 |
+
eos_token_id = self.llm_tokenizer.eos_token_id
|
| 747 |
+
if eos_token_id is None:
|
| 748 |
+
eos_token_id = pad_token_id
|
| 749 |
+
|
| 750 |
+
# Build logits processor for non-CFG operations (repetition penalty, top_k, top_p)
|
| 751 |
+
logits_processor = LogitsProcessorList()
|
| 752 |
+
if repetition_penalty != 1.0:
|
| 753 |
+
logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
|
| 754 |
+
|
| 755 |
with torch.no_grad():
|
| 756 |
for step in range(max_new_tokens):
|
| 757 |
# Forward pass for the entire batch (conditional + unconditional)
|
|
|
|
| 771 |
use_cache=use_cache,
|
| 772 |
)
|
| 773 |
|
| 774 |
+
# Get logits for the last position
|
| 775 |
next_token_logits = outputs.logits[:, -1, :] # [batch_size*2, vocab_size]
|
| 776 |
|
| 777 |
# Split conditional and unconditional logits
|
| 778 |
cond_logits = next_token_logits[cond_start_idx:cond_start_idx+batch_size]
|
| 779 |
uncond_logits = next_token_logits[uncond_start_idx:uncond_start_idx+batch_size]
|
| 780 |
|
| 781 |
+
# Apply CFG formula: cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
|
| 782 |
cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
|
| 783 |
|
| 784 |
+
# Apply logits processors (repetition penalty, top-k, top-p)
|
| 785 |
+
# Get current input_ids for repetition penalty (only conditional part)
|
| 786 |
+
current_input_ids = generated_ids[cond_start_idx:cond_start_idx+batch_size]
|
| 787 |
+
for processor in logits_processor:
|
| 788 |
+
cfg_logits = processor(current_input_ids, cfg_logits)
|
| 789 |
+
|
| 790 |
+
# Apply top-k filtering
|
| 791 |
+
if top_k is not None and top_k > 0:
|
| 792 |
+
indices_to_remove = cfg_logits < torch.topk(cfg_logits, top_k)[0][..., -1, None]
|
| 793 |
+
cfg_logits[indices_to_remove] = float('-inf')
|
| 794 |
+
|
| 795 |
+
# Apply top-p (nucleus) filtering
|
| 796 |
+
if top_p is not None and 0.0 < top_p < 1.0:
|
| 797 |
+
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
|
| 798 |
+
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
| 799 |
+
# Remove tokens with cumulative probability above the threshold
|
| 800 |
+
sorted_indices_to_remove = cumulative_probs > top_p
|
| 801 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
| 802 |
+
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
| 803 |
+
sorted_indices_to_remove[..., 0] = 0
|
| 804 |
+
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
| 805 |
+
cfg_logits[indices_to_remove] = float('-inf')
|
| 806 |
|
| 807 |
# Apply temperature and sample
|
| 808 |
if temperature > 0:
|
|
|
|
| 812 |
else:
|
| 813 |
next_tokens = torch.argmax(cfg_logits, dim=-1)
|
| 814 |
|
| 815 |
+
# Check for EOS token in conditional sequences BEFORE unsqueezing
|
| 816 |
+
# Stop if any conditional sequence generates EOS token
|
| 817 |
+
# next_tokens shape: [batch_size] (only conditional tokens)
|
| 818 |
+
should_stop = False
|
| 819 |
+
if torch.any(next_tokens == eos_token_id):
|
| 820 |
+
should_stop = True
|
| 821 |
+
elif pad_token_id is not None and pad_token_id != eos_token_id:
|
| 822 |
+
if torch.any(next_tokens == pad_token_id):
|
| 823 |
+
should_stop = True
|
| 824 |
+
|
| 825 |
+
# Apply the same sampled tokens to both conditional and unconditional sequences
|
| 826 |
+
next_tokens_unsqueezed = next_tokens.unsqueeze(1)
|
| 827 |
+
generated_ids = torch.cat([generated_ids, next_tokens_unsqueezed.repeat(2, 1)], dim=1)
|
| 828 |
attention_mask = torch.cat([attention_mask, torch.ones((batch_size*2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
|
| 829 |
model_kwargs['attention_mask'] = attention_mask
|
| 830 |
|
|
|
|
| 834 |
|
| 835 |
# Update streamer
|
| 836 |
if streamer is not None:
|
| 837 |
+
streamer.put(next_tokens_unsqueezed) # Stream conditional tokens
|
| 838 |
|
| 839 |
+
# Stop generation if EOS token detected
|
| 840 |
+
if should_stop:
|
| 841 |
break
|
| 842 |
|
| 843 |
if streamer is not None:
|
| 844 |
streamer.end()
|
| 845 |
|
| 846 |
+
# Return the full batch (both conditional and unconditional)
|
| 847 |
+
# The caller will extract only the conditional output
|
| 848 |
+
return generated_ids
|
| 849 |
+
|
| 850 |
+
def parse_lm_output(self, output_text: str) -> Tuple[Dict[str, Any], str]:
    """
    Parse raw 5Hz LM output into structured metadata and audio codes.

    Expected output format:
        <think>
        bpm: 73
        duration: 273
        genres: Chinese folk
        keyscale: G major
        timesignature: 4
        </think>

        <|audio_code_56535|><|audio_code_62918|>...

    Args:
        output_text: Full decoded text produced by the LM.

    Returns:
        Tuple of (metadata, audio_codes):
        - metadata: dict mapping the known field names ('bpm', 'duration',
          'genres', 'keyscale', 'timesignature') to parsed values;
          'bpm' and 'duration' are converted to int when numeric and
          otherwise kept as the raw string.
        - audio_codes: concatenation of every ``<|audio_code_N|>`` token
          found, in order of appearance ("" if none).
    """
    import re

    # Log only the reasoning portion (everything before </think>).
    debug_output_text = output_text.split("</think>")[0]
    logger.debug(f"Debug output text: {debug_output_text}")

    metadata: Dict[str, Any] = {}
    audio_codes = ""

    # Extract audio codes - find all <|audio_code_XXX|> patterns.
    code_pattern = r'<\|audio_code_\d+\|>'
    code_matches = re.findall(code_pattern, output_text)
    if code_matches:
        audio_codes = "".join(code_matches)

    # Locate the reasoning section.  Fix: the original list contained the
    # <think> pattern twice; the duplicate was dead and has been removed.
    reasoning_patterns = [
        r'<think>(.*?)</think>',
        r'<reasoning>(.*?)</reasoning>',
    ]

    reasoning_text = None
    for pattern in reasoning_patterns:
        match = re.search(pattern, output_text, re.DOTALL)
        if match:
            reasoning_text = match.group(1).strip()
            break

    # If no reasoning tags were found, fall back to parsing whatever
    # precedes the first audio-code token (or the whole output).
    if not reasoning_text:
        lines_before_codes = output_text.split('<|audio_code_')[0] if '<|audio_code_' in output_text else output_text
        reasoning_text = lines_before_codes.strip()

    # Parse "key: value" metadata lines, skipping markup lines such as
    # leftover tags (anything starting with '<').
    if reasoning_text:
        for line in reasoning_text.split('\n'):
            line = line.strip()
            if ':' in line and not line.startswith('<'):
                parts = line.split(':', 1)
                if len(parts) == 2:
                    key = parts[0].strip().lower()
                    value = parts[1].strip()

                    if key == 'bpm':
                        try:
                            metadata['bpm'] = int(value)
                        except ValueError:  # non-numeric: keep raw string
                            metadata['bpm'] = value
                    elif key == 'duration':
                        try:
                            metadata['duration'] = int(value)
                        except ValueError:  # non-numeric: keep raw string
                            metadata['duration'] = value
                    elif key == 'genres':
                        metadata['genres'] = value
                    elif key == 'keyscale':
                        metadata['keyscale'] = value
                    elif key == 'timesignature':
                        # NOTE(review): kept as a string (original behavior);
                        # callers appear to accept either form — confirm
                        # before converting to int.
                        metadata['timesignature'] = value

    return metadata, audio_codes
|
| 930 |
|
| 931 |
@contextmanager
|
| 932 |
def _load_model_context(self):
|