ChuxiJ commited on
Commit
4670365
·
1 Parent(s): fcbd6fb

fix llm gen for pt

Browse files
Files changed (4) hide show
  1. .gitignore +2 -1
  2. acestep/gradio_ui.py +143 -61
  3. acestep/handler.py +115 -2
  4. acestep/llm_inference.py +441 -175
.gitignore CHANGED
@@ -214,4 +214,5 @@ checkpoints/
214
  playground.ipynb
215
  .history/
216
  upload_checkpoints.sh
217
- checkpoints.7z
 
 
214
  playground.ipynb
215
  .history/
216
  upload_checkpoints.sh
217
+ checkpoints.7z
218
+ README_old.md
acestep/gradio_ui.py CHANGED
@@ -36,6 +36,20 @@ def create_gradio_interface(dit_handler, llm_handler, dataset_handler, init_para
36
  border-radius: 5px;
37
  margin: 10px 0;
38
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  """
40
  ) as demo:
41
 
@@ -320,43 +334,46 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
320
  )
321
 
322
  # Audio uploads
323
- with gr.Accordion("🎵 Audio Uploads", open=False):
324
- with gr.Row():
 
325
  with gr.Column(scale=2):
326
  reference_audio = gr.Audio(
327
  label="Reference Audio (optional)",
328
  type="filepath",
329
  )
330
- with gr.Column(scale=8):
331
  src_audio = gr.Audio(
332
  label="Source Audio (optional)",
333
  type="filepath",
334
  )
 
 
 
 
 
 
335
 
336
- audio_code_string = gr.Textbox(
337
- label="Audio Codes (optional)",
338
- placeholder="<|audio_code_10695|><|audio_code_54246|>...",
339
- lines=4,
340
- visible=False,
341
- info="Paste precomputed audio code tokens"
342
- )
343
-
344
  # Audio Codes for text2music
345
  with gr.Accordion("🎼 Audio Codes (for text2music)", open=True, visible=True) as text2music_audio_codes_group:
346
- text2music_audio_code_string = gr.Textbox(
347
- label="Audio Codes",
348
- placeholder="<|audio_code_10695|><|audio_code_54246|>...",
349
- lines=6,
350
- info="Paste precomputed audio code tokens for text2music generation"
351
- )
 
 
 
 
 
 
 
 
 
 
352
 
353
- # 5Hz LM
354
- with gr.Row(visible=True) as use_5hz_lm_row:
355
- use_5hz_lm_btn = gr.Button(
356
- "Generate LM Hints",
357
- variant="secondary",
358
- size="lg",
359
- )
360
  lm_temperature = gr.Slider(
361
  label="Temperature",
362
  minimum=0.0,
@@ -364,7 +381,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
364
  value=0.85,
365
  step=0.1,
366
  scale=1,
367
- info="Temperature for 5Hz LM sampling (higher = more random, lower = more deterministic)"
368
  )
369
  lm_cfg_scale = gr.Slider(
370
  label="CFG Scale",
@@ -373,10 +390,8 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
373
  value=2.0,
374
  step=0.1,
375
  scale=1,
376
- info="Classifier-Free Guidance scale for 5Hz LM (1.0 = no CFG, higher = stronger guidance)"
377
  )
378
-
379
- with gr.Row():
380
  lm_top_k = gr.Slider(
381
  label="Top-K",
382
  minimum=0,
@@ -384,7 +399,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
384
  value=0,
385
  step=1,
386
  scale=1,
387
- info="Top-K sampling: consider only top K tokens (0 = disabled)"
388
  )
389
  lm_top_p = gr.Slider(
390
  label="Top-P",
@@ -393,7 +408,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
393
  value=0.9,
394
  step=0.01,
395
  scale=1,
396
- info="Top-P (nucleus) sampling: cumulative probability threshold (1.0 = disabled)"
397
  )
398
  lm_repetition_penalty = gr.Slider(
399
  label="Repetition Penalty",
@@ -402,20 +417,10 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
402
  value=1.0,
403
  step=0.01,
404
  scale=1,
405
- info="Repetition penalty: >1.0 reduces repetition, <1.0 increases it (1.0 = no penalty). For audio generation, use 1.0 or very small values (1.01-1.05) as audio tokens naturally repeat.",
406
  visible=False,
407
  )
408
 
409
- # Negative prompt for CFG (only visible when LM initialized and cfg_scale > 1)
410
- lm_negative_prompt = gr.Textbox(
411
- label="Negative Prompt",
412
- value="NO USER INPUT",
413
- placeholder="Enter negative prompt for CFG (default: NO USER INPUT)",
414
- visible=True,
415
- info="Negative prompt used for Classifier-Free Guidance when CFG Scale > 1.0",
416
- lines=2
417
- )
418
-
419
  # Repainting controls
420
  with gr.Group(visible=False) as repainting_group:
421
  gr.HTML("<h5>🎨 Repainting Controls (seconds) </h5>")
@@ -445,12 +450,24 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
445
 
446
  # Music Caption
447
  with gr.Accordion("📝 Music Caption", open=True):
448
- captions = gr.Textbox(
449
- label="Music Caption (optional)",
450
- placeholder="A peaceful acoustic guitar melody with soft vocals...",
451
- lines=3,
452
- info="Describe the style, genre, instruments, and mood"
453
- )
 
 
 
 
 
 
 
 
 
 
 
 
454
 
455
  # Lyrics
456
  with gr.Accordion("📝 Lyrics", open=True):
@@ -468,7 +485,8 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
468
  choices=["en", "zh", "ja", "ko", "es", "fr", "de"],
469
  value="en",
470
  label="Vocal Language (optional)",
471
- allow_custom_value=True
 
472
  )
473
  bpm = gr.Number(
474
  label="BPM (optional)",
@@ -477,15 +495,17 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
477
  info="leave empty for N/A"
478
  )
479
  key_scale = gr.Textbox(
480
- label="Key/Scale (optional)",
481
  placeholder="Leave empty for N/A",
482
  value="",
 
483
  )
484
  time_signature = gr.Dropdown(
485
  choices=["2", "3", "4", "N/A", ""],
486
  value="4",
487
  label="Time Signature (optional)",
488
- allow_custom_value=True
 
489
  )
490
  audio_duration = gr.Number(
491
  label="Audio Duration (seconds)",
@@ -497,7 +517,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
497
  )
498
  batch_size_input = gr.Number(
499
  label="Batch Size",
500
- value=1,
501
  minimum=1,
502
  maximum=8,
503
  step=1,
@@ -582,6 +602,7 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
582
  generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg", interactive=generate_btn_interactive)
583
 
584
  return {
 
585
  "checkpoint_dropdown": checkpoint_dropdown,
586
  "refresh_btn": refresh_btn,
587
  "config_path": config_path,
@@ -598,9 +619,10 @@ def create_generation_section(dit_handler, llm_handler, init_params=None) -> dic
598
  "instruction_display_gen": instruction_display_gen,
599
  "track_name": track_name,
600
  "complete_track_classes": complete_track_classes,
 
601
  "reference_audio": reference_audio,
602
  "src_audio": src_audio,
603
- "audio_code_string": audio_code_string,
604
  "text2music_audio_code_string": text2music_audio_code_string,
605
  "text2music_audio_codes_group": text2music_audio_codes_group,
606
  "use_5hz_lm_row": use_5hz_lm_row,
@@ -650,12 +672,22 @@ def create_results_section(dit_handler) -> dict:
650
  type="filepath",
651
  interactive=False
652
  )
 
 
 
 
 
653
  with gr.Column():
654
  generated_audio_2 = gr.Audio(
655
  label="🎵 Generated Music (Sample 2)",
656
  type="filepath",
657
  interactive=False
658
  )
 
 
 
 
 
659
 
660
  with gr.Accordion("📁 Batch Results & Generation Details", open=False):
661
  generated_audio_batch = gr.File(
@@ -680,6 +712,8 @@ def create_results_section(dit_handler) -> dict:
680
  "status_output": status_output,
681
  "generated_audio_1": generated_audio_1,
682
  "generated_audio_2": generated_audio_2,
 
 
683
  "generated_audio_batch": generated_audio_batch,
684
  "generation_info": generation_info,
685
  "align_score_1": align_score_1,
@@ -768,7 +802,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
768
 
769
  # Service initialization
770
  def init_service_wrapper(checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
771
- """Wrapper for service initialization, returns status and button state"""
772
  # Initialize DiT handler
773
  status, enable = dit_handler.initialize_service(
774
  checkpoint, config_path, device,
@@ -799,7 +833,11 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
799
  # Don't fail the entire initialization if LM fails, but log it
800
  # Keep enable as is (DiT initialization result) even if LM fails
801
 
802
- return status, gr.update(interactive=enable)
 
 
 
 
803
 
804
  # Update negative prompt visibility based on "Initialize 5Hz LM" checkbox
805
  def update_negative_prompt_visibility(init_llm_checked):
@@ -855,7 +893,7 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
855
  generation_section["offload_to_cpu_checkbox"],
856
  generation_section["offload_dit_to_cpu_checkbox"],
857
  ],
858
- outputs=[generation_section["init_status"], generation_section["generate_btn"]]
859
  )
860
 
861
  # Generation with progress bar
@@ -992,6 +1030,18 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
992
  ]
993
  )
994
 
 
 
 
 
 
 
 
 
 
 
 
 
995
  # Update instruction and UI visibility based on task type
996
  def update_instruction_ui(
997
  task_type_value: str,
@@ -1020,8 +1070,6 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
1020
  else:
1021
  audio_cover_strength_label = "Audio Cover Strength"
1022
  audio_cover_strength_info = "Control how many denoising steps use cover mode"
1023
- # Show audio_code_string for cover
1024
- audio_code_visible = task_type_value == "cover"
1025
  # Show repainting controls for repaint and lego
1026
  repainting_visible = task_type_value in ["repaint", "lego"]
1027
  # Show use_5hz_lm, lm_temperature for text2music
@@ -1037,7 +1085,6 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
1037
  gr.update(visible=complete_visible), # complete_track_classes
1038
  gr.update(visible=audio_cover_strength_visible, label=audio_cover_strength_label, info=audio_cover_strength_info), # audio_cover_strength
1039
  gr.update(visible=repainting_visible), # repainting_group
1040
- gr.update(visible=audio_code_visible), # audio_code_string
1041
  gr.update(visible=use_5hz_lm_visible), # use_5hz_lm_row
1042
  gr.update(visible=text2music_audio_codes_visible), # text2music_audio_codes_group
1043
  )
@@ -1058,7 +1105,6 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
1058
  generation_section["complete_track_classes"],
1059
  generation_section["audio_cover_strength"],
1060
  generation_section["repainting_group"],
1061
- generation_section["audio_code_string"],
1062
  generation_section["use_5hz_lm_row"],
1063
  generation_section["text2music_audio_codes_group"],
1064
  ]
@@ -1080,7 +1126,6 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
1080
  generation_section["complete_track_classes"],
1081
  generation_section["audio_cover_strength"],
1082
  generation_section["repainting_group"],
1083
- generation_section["audio_code_string"],
1084
  generation_section["use_5hz_lm_row"],
1085
  generation_section["text2music_audio_codes_group"],
1086
  ]
@@ -1102,9 +1147,46 @@ def setup_event_handlers(demo, dit_handler, llm_handler, dataset_handler, datase
1102
  generation_section["complete_track_classes"],
1103
  generation_section["audio_cover_strength"],
1104
  generation_section["repainting_group"],
1105
- generation_section["audio_code_string"],
1106
  generation_section["use_5hz_lm_row"],
1107
  generation_section["text2music_audio_codes_group"],
1108
  ]
1109
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1110
 
 
36
  border-radius: 5px;
37
  margin: 10px 0;
38
  }
39
+ .lm-hints-row {
40
+ align-items: stretch;
41
+ }
42
+ .lm-hints-col {
43
+ display: flex;
44
+ }
45
+ .lm-hints-col > div {
46
+ flex: 1;
47
+ display: flex;
48
+ }
49
+ .lm-hints-btn button {
50
+ height: 100%;
51
+ width: 100%;
52
+ }
53
  """
54
  ) as demo:
55
 
 
334
  )
335
 
336
  # Audio uploads
337
+ audio_uploads_accordion = gr.Accordion("🎵 Audio Uploads", open=False)
338
+ with audio_uploads_accordion:
339
+ with gr.Row(equal_height=True):
340
  with gr.Column(scale=2):
341
  reference_audio = gr.Audio(
342
  label="Reference Audio (optional)",
343
  type="filepath",
344
  )
345
+ with gr.Column(scale=7):
346
  src_audio = gr.Audio(
347
  label="Source Audio (optional)",
348
  type="filepath",
349
  )
350
+ with gr.Column(scale=1, min_width=80):
351
+ convert_src_to_codes_btn = gr.Button(
352
+ "Convert to Codes",
353
+ variant="secondary",
354
+ size="sm"
355
+ )
356
 
 
 
 
 
 
 
 
 
357
  # Audio Codes for text2music
358
  with gr.Accordion("🎼 Audio Codes (for text2music)", open=True, visible=True) as text2music_audio_codes_group:
359
+ with gr.Row(equal_height=True, elem_classes=["lm-hints-row"]):
360
+ with gr.Column(scale=9):
361
+ text2music_audio_code_string = gr.Textbox(
362
+ label="Audio Codes",
363
+ placeholder="<|audio_code_10695|><|audio_code_54246|>...",
364
+ lines=6,
365
+ info="Paste precomputed audio code tokens for text2music generation"
366
+ )
367
+ with gr.Column(scale=3, elem_classes=["lm-hints-col"]):
368
+ with gr.Row(equal_height=True, visible=True) as use_5hz_lm_row:
369
+ use_5hz_lm_btn = gr.Button(
370
+ "Generate LM Hints",
371
+ variant="secondary",
372
+ # size="lg",
373
+ elem_classes=["lm-hints-btn"],
374
+ )
375
 
376
+ with gr.Row(equal_height=True):
 
 
 
 
 
 
377
  lm_temperature = gr.Slider(
378
  label="Temperature",
379
  minimum=0.0,
 
381
  value=0.85,
382
  step=0.1,
383
  scale=1,
384
+ info="5Hz LM temperature (higher = random)"
385
  )
386
  lm_cfg_scale = gr.Slider(
387
  label="CFG Scale",
 
390
  value=2.0,
391
  step=0.1,
392
  scale=1,
393
+ info="5Hz LM CFG (1.0 = no CFG)"
394
  )
 
 
395
  lm_top_k = gr.Slider(
396
  label="Top-K",
397
  minimum=0,
 
399
  value=0,
400
  step=1,
401
  scale=1,
402
+ info="Top-K (0 = disabled)"
403
  )
404
  lm_top_p = gr.Slider(
405
  label="Top-P",
 
408
  value=0.9,
409
  step=0.01,
410
  scale=1,
411
+ info="Top-P (1.0 = disabled)"
412
  )
413
  lm_repetition_penalty = gr.Slider(
414
  label="Repetition Penalty",
 
417
  value=1.0,
418
  step=0.01,
419
  scale=1,
420
+ info="Repetition penalty: >1.0 reduces repetition, <1.0 increases it. Use 1.0 or very small values for audio tokens.",
421
  visible=False,
422
  )
423
 
 
 
 
 
 
 
 
 
 
 
424
  # Repainting controls
425
  with gr.Group(visible=False) as repainting_group:
426
  gr.HTML("<h5>🎨 Repainting Controls (seconds) </h5>")
 
450
 
451
  # Music Caption
452
  with gr.Accordion("📝 Music Caption", open=True):
453
+ with gr.Row(equal_height=True):
454
+ captions = gr.Textbox(
455
+ label="Music Caption (optional)",
456
+ placeholder="A peaceful acoustic guitar melody with soft vocals...",
457
+ lines=3,
458
+ info="Describe the style, genre, instruments, and mood",
459
+ scale=7,
460
+ )
461
+ # Negative prompt for CFG (only visible when LM initialized and cfg_scale > 1)
462
+ lm_negative_prompt = gr.Textbox(
463
+ label="Negative Prompt",
464
+ value="NO USER INPUT",
465
+ placeholder="Enter negative prompt for CFG (default: NO USER INPUT)",
466
+ visible=True,
467
+ info="Negative prompt (use when CFG Scale > 1.0)",
468
+ lines=3,
469
+ scale=5,
470
+ )
471
 
472
  # Lyrics
473
  with gr.Accordion("📝 Lyrics", open=True):
 
485
  choices=["en", "zh", "ja", "ko", "es", "fr", "de"],
486
  value="en",
487
  label="Vocal Language (optional)",
488
+ allow_custom_value=True,
489
+ info="use `unknown` for inst"
490
  )
491
  bpm = gr.Number(
492
  label="BPM (optional)",
 
495
  info="leave empty for N/A"
496
  )
497
  key_scale = gr.Textbox(
498
+ label="KeyScale (optional)",
499
  placeholder="Leave empty for N/A",
500
  value="",
501
+ info="A-G, #/♭, major/minor"
502
  )
503
  time_signature = gr.Dropdown(
504
  choices=["2", "3", "4", "N/A", ""],
505
  value="4",
506
  label="Time Signature (optional)",
507
+ allow_custom_value=True,
508
+ info="2/4, 3/4, 4/4..."
509
  )
510
  audio_duration = gr.Number(
511
  label="Audio Duration (seconds)",
 
517
  )
518
  batch_size_input = gr.Number(
519
  label="Batch Size",
520
+ value=2,
521
  minimum=1,
522
  maximum=8,
523
  step=1,
 
602
  generate_btn = gr.Button("🎵 Generate Music", variant="primary", size="lg", interactive=generate_btn_interactive)
603
 
604
  return {
605
+ "service_config_accordion": service_config_accordion,
606
  "checkpoint_dropdown": checkpoint_dropdown,
607
  "refresh_btn": refresh_btn,
608
  "config_path": config_path,
 
619
  "instruction_display_gen": instruction_display_gen,
620
  "track_name": track_name,
621
  "complete_track_classes": complete_track_classes,
622
+ "audio_uploads_accordion": audio_uploads_accordion,
623
  "reference_audio": reference_audio,
624
  "src_audio": src_audio,
625
+ "convert_src_to_codes_btn": convert_src_to_codes_btn,
626
  "text2music_audio_code_string": text2music_audio_code_string,
627
  "text2music_audio_codes_group": text2music_audio_codes_group,
628
  "use_5hz_lm_row": use_5hz_lm_row,
 
672
  type="filepath",
673
  interactive=False
674
  )
675
+ send_to_src_btn_1 = gr.Button(
676
+ "Send To Src Audio",
677
+ variant="secondary",
678
+ size="sm"
679
+ )
680
  with gr.Column():
681
  generated_audio_2 = gr.Audio(
682
  label="🎵 Generated Music (Sample 2)",
683
  type="filepath",
684
  interactive=False
685
  )
686
+ send_to_src_btn_2 = gr.Button(
687
+ "Send To Src Audio",
688
+ variant="secondary",
689
+ size="sm"
690
+ )
691
 
692
  with gr.Accordion("📁 Batch Results & Generation Details", open=False):
693
  generated_audio_batch = gr.File(
 
712
  "status_output": status_output,
713
  "generated_audio_1": generated_audio_1,
714
  "generated_audio_2": generated_audio_2,
715
+ "send_to_src_btn_1": send_to_src_btn_1,
716
+ "send_to_src_btn_2": send_to_src_btn_2,
717
  "generated_audio_batch": generated_audio_batch,
718
  "generation_info": generation_info,
719
  "align_score_1": align_score_1,
 
802
 
803
  # Service initialization
804
  def init_service_wrapper(checkpoint, config_path, device, init_llm, lm_model_path, backend, use_flash_attention, offload_to_cpu, offload_dit_to_cpu):
805
+ """Wrapper for service initialization, returns status, button state, and accordion state"""
806
  # Initialize DiT handler
807
  status, enable = dit_handler.initialize_service(
808
  checkpoint, config_path, device,
 
833
  # Don't fail the entire initialization if LM fails, but log it
834
  # Keep enable as is (DiT initialization result) even if LM fails
835
 
836
+ # Check if model is initialized - if so, collapse the accordion
837
+ is_model_initialized = dit_handler.model is not None
838
+ accordion_state = gr.update(open=not is_model_initialized)
839
+
840
+ return status, gr.update(interactive=enable), accordion_state
841
 
842
  # Update negative prompt visibility based on "Initialize 5Hz LM" checkbox
843
  def update_negative_prompt_visibility(init_llm_checked):
 
893
  generation_section["offload_to_cpu_checkbox"],
894
  generation_section["offload_dit_to_cpu_checkbox"],
895
  ],
896
+ outputs=[generation_section["init_status"], generation_section["generate_btn"], generation_section["service_config_accordion"]]
897
  )
898
 
899
  # Generation with progress bar
 
1030
  ]
1031
  )
1032
 
1033
+ # Convert src audio to codes
1034
+ def convert_src_audio_to_codes_wrapper(src_audio):
1035
+ """Wrapper for converting src audio to codes"""
1036
+ codes_string = dit_handler.convert_src_audio_to_codes(src_audio)
1037
+ return codes_string
1038
+
1039
+ generation_section["convert_src_to_codes_btn"].click(
1040
+ fn=convert_src_audio_to_codes_wrapper,
1041
+ inputs=[generation_section["src_audio"]],
1042
+ outputs=[generation_section["text2music_audio_code_string"]]
1043
+ )
1044
+
1045
  # Update instruction and UI visibility based on task type
1046
  def update_instruction_ui(
1047
  task_type_value: str,
 
1070
  else:
1071
  audio_cover_strength_label = "Audio Cover Strength"
1072
  audio_cover_strength_info = "Control how many denoising steps use cover mode"
 
 
1073
  # Show repainting controls for repaint and lego
1074
  repainting_visible = task_type_value in ["repaint", "lego"]
1075
  # Show use_5hz_lm, lm_temperature for text2music
 
1085
  gr.update(visible=complete_visible), # complete_track_classes
1086
  gr.update(visible=audio_cover_strength_visible, label=audio_cover_strength_label, info=audio_cover_strength_info), # audio_cover_strength
1087
  gr.update(visible=repainting_visible), # repainting_group
 
1088
  gr.update(visible=use_5hz_lm_visible), # use_5hz_lm_row
1089
  gr.update(visible=text2music_audio_codes_visible), # text2music_audio_codes_group
1090
  )
 
1105
  generation_section["complete_track_classes"],
1106
  generation_section["audio_cover_strength"],
1107
  generation_section["repainting_group"],
 
1108
  generation_section["use_5hz_lm_row"],
1109
  generation_section["text2music_audio_codes_group"],
1110
  ]
 
1126
  generation_section["complete_track_classes"],
1127
  generation_section["audio_cover_strength"],
1128
  generation_section["repainting_group"],
 
1129
  generation_section["use_5hz_lm_row"],
1130
  generation_section["text2music_audio_codes_group"],
1131
  ]
 
1147
  generation_section["complete_track_classes"],
1148
  generation_section["audio_cover_strength"],
1149
  generation_section["repainting_group"],
 
1150
  generation_section["use_5hz_lm_row"],
1151
  generation_section["text2music_audio_codes_group"],
1152
  ]
1153
  )
1154
+
1155
+ # Send generated audio to src_audio
1156
+ def send_audio_to_src(audio_file):
1157
+ """Send generated audio file to src_audio input"""
1158
+ if audio_file is None:
1159
+ return None
1160
+ return audio_file
1161
+
1162
+ results_section["send_to_src_btn_1"].click(
1163
+ fn=send_audio_to_src,
1164
+ inputs=[results_section["generated_audio_1"]],
1165
+ outputs=[generation_section["src_audio"]]
1166
+ )
1167
+
1168
+ results_section["send_to_src_btn_2"].click(
1169
+ fn=send_audio_to_src,
1170
+ inputs=[results_section["generated_audio_2"]],
1171
+ outputs=[generation_section["src_audio"]]
1172
+ )
1173
+
1174
+ # Auto-expand Audio Uploads accordion when audio is uploaded
1175
+ def update_audio_uploads_accordion(reference_audio, src_audio):
1176
+ """Update Audio Uploads accordion open state based on whether audio files are present"""
1177
+ has_audio = (reference_audio is not None) or (src_audio is not None)
1178
+ return gr.update(open=has_audio)
1179
+
1180
+ # Bind to both audio components' change events
1181
+ generation_section["reference_audio"].change(
1182
+ fn=update_audio_uploads_accordion,
1183
+ inputs=[generation_section["reference_audio"], generation_section["src_audio"]],
1184
+ outputs=[generation_section["audio_uploads_accordion"]]
1185
+ )
1186
+
1187
+ generation_section["src_audio"].change(
1188
+ fn=update_audio_uploads_accordion,
1189
+ inputs=[generation_section["reference_audio"], generation_section["src_audio"]],
1190
+ outputs=[generation_section["audio_uploads_accordion"]]
1191
+ )
1192
 
acestep/handler.py CHANGED
@@ -504,6 +504,49 @@ class AceStepHandler:
504
 
505
  return parsed_metas
506
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
507
  def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
508
  """Get text hidden states from text encoder."""
509
  if self.text_tokenizer is None or self.text_encoder is None:
@@ -765,6 +808,71 @@ class AceStepHandler:
765
  except Exception as e:
766
  logger.error(f"Error processing target audio: {e}")
767
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
768
 
769
  def prepare_batch_data(
770
  self,
@@ -1919,10 +2027,15 @@ class AceStepHandler:
1919
  refer_audios = [[torch.zeros(2, 30*self.sample_rate)] for _ in range(actual_batch_size)]
1920
 
1921
  # 2. Process source audio
 
1922
  processed_src_audio = None
1923
  if src_audio is not None:
1924
- logger.info("[generate_music] Processing source audio...")
1925
- processed_src_audio = self.process_src_audio(src_audio)
 
 
 
 
1926
 
1927
  # 3. Prepare batch data
1928
  captions_batch, instructions_batch, lyrics_batch, vocal_languages_batch, metas_batch = self.prepare_batch_data(
 
504
 
505
  return parsed_metas
506
 
507
+ def build_dit_inputs(
508
+ self,
509
+ task: str,
510
+ instruction: Optional[str],
511
+ caption: str,
512
+ lyrics: str,
513
+ metas: Optional[Union[str, Dict[str, Any]]] = None,
514
+ vocal_language: str = "en",
515
+ ) -> Tuple[str, str]:
516
+ """
517
+ Build text inputs for the caption and lyric branches used by DiT.
518
+
519
+ Args:
520
+ task: Task name (e.g., text2music, cover, repaint); kept for logging/future branching.
521
+ instruction: Instruction text; default fallback matches service_generate behavior.
522
+ caption: Caption string.
523
+ lyrics: Lyrics string.
524
+ metas: Metadata (str or dict); follows _parse_metas formatting.
525
+ vocal_language: Language code for lyrics section.
526
+
527
+ Returns:
528
+ (caption_input_text, lyrics_input_text)
529
+
530
+ Example:
531
+ caption_input, lyrics_input = handler.build_dit_inputs(
532
+ task="text2music",
533
+ instruction=None,
534
+ caption="A calm piano melody",
535
+ lyrics="la la la",
536
+ metas={"bpm": 90, "duration": 45},
537
+ vocal_language="en",
538
+ )
539
+ """
540
+ # Align instruction formatting with _prepare_batch
541
+ final_instruction = instruction or "Fill the audio semantic mask based on the given conditions:"
542
+ if not final_instruction.endswith(":"):
543
+ final_instruction = final_instruction + ":"
544
+
545
+ parsed_meta = self._parse_metas([metas])[0]
546
+ caption_input = SFT_GEN_PROMPT.format(final_instruction, caption, parsed_meta)
547
+ lyrics_input = f"# Languages\n{vocal_language}\n\n# Lyric\n{lyrics}<|endoftext|>"
548
+ return caption_input, lyrics_input
549
+
550
  def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
551
  """Get text hidden states from text encoder."""
552
  if self.text_tokenizer is None or self.text_encoder is None:
 
808
  except Exception as e:
809
  logger.error(f"Error processing target audio: {e}")
810
  return None
811
+
812
+ def convert_src_audio_to_codes(self, audio_file) -> str:
813
+ """
814
+ Convert uploaded source audio to audio codes string.
815
+
816
+ Args:
817
+ audio_file: Path to audio file or None
818
+
819
+ Returns:
820
+ Formatted codes string like '<|audio_code_123|><|audio_code_456|>...' or error message
821
+ """
822
+ if audio_file is None:
823
+ return "❌ Please upload source audio first"
824
+
825
+ if self.model is None or self.vae is None:
826
+ return "❌ Model not initialized. Please initialize the service first."
827
+
828
+ try:
829
+ # Process audio file
830
+ processed_audio = self.process_src_audio(audio_file)
831
+ if processed_audio is None:
832
+ return "❌ Failed to process audio file"
833
+
834
+ # Encode audio to latents using VAE
835
+ with torch.no_grad():
836
+ with self._load_model_context("vae"):
837
+ # Prepare audio for VAE: [channels, samples] -> [1, channels, samples]
838
+ vae_input = processed_audio.unsqueeze(0).to(self.device).to(self.vae.dtype)
839
+
840
+ # Check if audio is silence
841
+ if self.is_silence(vae_input):
842
+ return "❌ Audio file appears to be silent"
843
+
844
+ # Encode to latents
845
+ latents = self.vae.encode(vae_input).latent_dist.sample()
846
+ # Cast back to model dtype
847
+ latents = latents.to(self.dtype)
848
+ # Transpose: [1, d, T] -> [1, T, d] -> [T, d]
849
+ latents = latents.squeeze(0).transpose(0, 1) # [T, d]
850
+
851
+ # Create attention mask for latents
852
+ attention_mask = torch.ones(latents.shape[0], dtype=torch.bool, device=self.device)
853
+
854
+ # Tokenize latents to get code indices
855
+ with self._load_model_context("model"):
856
+ # Prepare latents for tokenize: [T, d] -> [1, T, d]
857
+ hidden_states = latents.unsqueeze(0) # [1, T, d]
858
+
859
+ # Call tokenize method
860
+ # tokenize returns: (quantized, indices, attention_mask)
861
+ _, indices, _ = self.model.tokenize(hidden_states, self.silence_latent, attention_mask.unsqueeze(0))
862
+
863
+ # Format indices as code string
864
+ # indices shape: [1, T_5Hz] or [1, T_5Hz, num_quantizers]
865
+ # Flatten and convert to list
866
+ indices_flat = indices.flatten().cpu().tolist()
867
+ codes_string = "".join([f"<|audio_code_{idx}|>" for idx in indices_flat])
868
+
869
+ logger.info(f"[convert_src_audio_to_codes] Generated {len(indices_flat)} audio codes")
870
+ return codes_string
871
+
872
+ except Exception as e:
873
+ error_msg = f"❌ Error converting audio to codes: {str(e)}\n{traceback.format_exc()}"
874
+ logger.error(error_msg)
875
+ return error_msg
876
 
877
  def prepare_batch_data(
878
  self,
 
2027
  refer_audios = [[torch.zeros(2, 30*self.sample_rate)] for _ in range(actual_batch_size)]
2028
 
2029
  # 2. Process source audio
2030
+ # If audio_code_string is provided, ignore src_audio and use codes instead
2031
  processed_src_audio = None
2032
  if src_audio is not None:
2033
+ # Check if audio codes are provided - if so, ignore src_audio
2034
+ if audio_code_string and str(audio_code_string).strip():
2035
+ logger.info("[generate_music] Audio codes provided, ignoring src_audio and using codes instead")
2036
+ else:
2037
+ logger.info("[generate_music] Processing source audio...")
2038
+ processed_src_audio = self.process_src_audio(src_audio)
2039
 
2040
  # 3. Prepare batch data
2041
  captions_batch, instructions_batch, lyrics_batch, vocal_languages_batch, metas_batch = self.prepare_batch_data(
acestep/llm_inference.py CHANGED
@@ -11,20 +11,15 @@ from contextlib import contextmanager
11
  import torch
12
  from tqdm import tqdm
13
  from loguru import logger
14
- from transformers import AutoTokenizer, AutoModelForCausalLM, ClassifierFreeGuidanceLogitsProcessor
15
  from transformers.generation.streamers import BaseStreamer
16
  from transformers.generation.logits_process import (
17
  LogitsProcessorList,
18
- LogitsProcessor,
19
- TopKLogitsWarper,
20
- TopPLogitsWarper,
21
  RepetitionPenaltyLogitsProcessor,
22
- TemperatureLogitsWarper,
23
  )
24
 
25
 
26
-
27
-
28
  class LLMHandler:
29
  """5Hz LM Handler for audio code generation"""
30
 
@@ -234,16 +229,7 @@ class LLMHandler:
234
  try:
235
  from nanovllm import SamplingParams
236
 
237
- prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
238
-
239
- formatted_prompt = self.llm_tokenizer.apply_chat_template(
240
- [
241
- {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
242
- {"role": "user", "content": prompt}
243
- ],
244
- tokenize=False,
245
- add_generation_prompt=True,
246
- )
247
  logger.debug(f"[debug] formatted_prompt: {formatted_prompt}")
248
 
249
  sampling_params = SamplingParams(
@@ -257,14 +243,7 @@ class LLMHandler:
257
  # Use CFG if cfg_scale > 1.0
258
  if cfg_scale > 1.0:
259
  # Build unconditional prompt (user input replaced with "NO USER INPUT")
260
- formatted_unconditional_prompt = self.llm_tokenizer.apply_chat_template(
261
- [
262
- {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
263
- {"role": "user", "content": negative_prompt}
264
- ],
265
- tokenize=False,
266
- add_generation_prompt=True,
267
- )
268
  outputs = self.llm.generate(
269
  [formatted_prompt],
270
  sampling_params,
@@ -293,6 +272,53 @@ class LLMHandler:
293
  error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
294
  return {}, "", error_msg
295
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
296
  def generate_with_5hz_lm_pt(
297
  self,
298
  caption: str,
@@ -306,23 +332,13 @@ class LLMHandler:
306
  ) -> Tuple[Dict[str, Any], str, str]:
307
  """Generate metadata and audio codes using 5Hz LM with PyTorch backend"""
308
  try:
309
- prompt = f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
310
-
311
- formatted_prompt = self.llm_tokenizer.apply_chat_template(
312
- [
313
- {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
314
- {"role": "user", "content": prompt}
315
- ],
316
- tokenize=False,
317
- add_generation_prompt=True,
318
- )
319
 
320
  # Tokenize the prompt
321
  inputs = self.llm_tokenizer(
322
  formatted_prompt,
323
  return_tensors="pt",
324
  padding=False,
325
- truncation=True,
326
  )
327
 
328
  # Generate with the model
@@ -352,82 +368,90 @@ class LLMHandler:
352
 
353
  streamer = TqdmTokenStreamer(total=max_new_tokens)
354
 
355
- # Build logits processor list
356
  logits_processor = LogitsProcessorList()
357
 
358
- # Add repetition penalty if needed
359
  if repetition_penalty != 1.0:
360
  logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
361
 
362
- # Add temperature warper if needed (temperature is handled separately in generate, but we can also use warper)
363
- # Note: temperature is passed directly to generate(), but we can use TemperatureLogitsWarper for consistency
364
- if temperature != 1.0:
365
- logits_processor.append(TemperatureLogitsWarper(temperature=temperature))
366
-
367
- # Add top-k warper if specified
368
- if top_k is not None and top_k > 0:
369
- logits_processor.append(TopKLogitsWarper(top_k=top_k))
370
-
371
- # Add top-p warper if specified
372
- if top_p is not None and top_p > 0.0 and top_p < 1.0:
373
- logits_processor.append(TopPLogitsWarper(top_p=top_p))
374
-
375
  # Handle CFG if cfg_scale > 1.0
376
  if cfg_scale > 1.0:
377
  # Build unconditional prompt
378
- formatted_unconditional_prompt = self.llm_tokenizer.apply_chat_template(
379
- [
380
- {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
381
- {"role": "user", "content": negative_prompt}
382
- ],
383
- tokenize=False,
384
- add_generation_prompt=True,
385
- )
386
 
387
- # Tokenize unconditional prompt
388
- uncond_inputs = self.llm_tokenizer(
389
- formatted_unconditional_prompt,
 
 
 
 
390
  return_tensors="pt",
391
- padding=False,
392
  truncation=True,
393
  )
394
- uncond_inputs = {k: v.to(self.device) for k, v in uncond_inputs.items()}
 
395
 
396
- # Use custom CFG generation with batch processing
397
- # Combine conditional and unconditional inputs into a batch
398
- # Format: [cond_input, uncond_input]
399
- batch_input_ids = torch.cat([inputs['input_ids'], uncond_inputs['input_ids']], dim=0)
400
- batch_attention_mask = None
401
- if 'attention_mask' in inputs:
402
- batch_attention_mask = torch.cat([inputs['attention_mask'], uncond_inputs.get('attention_mask', torch.ones_like(uncond_inputs['input_ids']))], dim=0)
403
 
404
- # Custom CFG generation loop
405
- outputs = self._generate_with_cfg(
406
  batch_input_ids=batch_input_ids,
407
  batch_attention_mask=batch_attention_mask,
408
  max_new_tokens=max_new_tokens,
409
  temperature=temperature,
410
  cfg_scale=cfg_scale,
411
- logits_processor=logits_processor,
 
 
412
  pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
413
  streamer=streamer,
414
  )
 
 
 
415
  else:
416
- # Generate without CFG
417
  with torch.no_grad():
418
  outputs = self.llm.generate(
419
  **inputs,
420
  max_new_tokens=max_new_tokens,
421
  temperature=temperature if temperature > 0 else 1.0,
422
  do_sample=True if temperature > 0 else False,
 
 
423
  logits_processor=logits_processor if len(logits_processor) > 0 else None,
424
  pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
425
  streamer=streamer,
426
  )
427
 
428
  # Decode the generated tokens
 
 
 
 
 
 
 
 
 
429
  # Only decode the newly generated tokens (skip the input prompt)
430
- generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
 
 
 
 
 
 
 
 
 
 
 
431
  output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
432
 
433
  metadata, audio_codes = self.parse_lm_output(output_text)
@@ -436,8 +460,120 @@ class LLMHandler:
436
 
437
  except Exception as e:
438
  error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
 
439
  return {}, "", error_msg
440
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
441
  def generate_with_5hz_lm(
442
  self,
443
  caption: str,
@@ -474,103 +610,115 @@ class LLMHandler:
474
  caption, lyrics, temperature, cfg_scale, negative_prompt,
475
  top_k, top_p, repetition_penalty
476
  )
477
-
478
- def parse_lm_output(self, output_text: str) -> Tuple[Dict[str, Any], str]:
479
  """
480
- Parse LM output to extract metadata and audio codes.
481
-
482
- Expected format:
483
- <think>
484
- bpm: 73
485
- duration: 273
486
- genres: Chinese folk
487
- keyscale: G major
488
- timesignature: 4
489
- </think>
490
-
491
- <|audio_code_56535|><|audio_code_62918|>...
492
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493
  Returns:
494
- Tuple of (metadata_dict, audio_codes_string)
 
 
 
 
495
  """
496
- debug_output_text = output_text.split("</think>")[0]
497
- logger.debug(f"Debug output text: {debug_output_text}")
498
- metadata = {}
499
- audio_codes = ""
500
-
501
- import re
502
-
503
- # Extract audio codes - find all <|audio_code_XXX|> patterns
504
- code_pattern = r'<\|audio_code_\d+\|>'
505
- code_matches = re.findall(code_pattern, output_text)
506
- if code_matches:
507
- audio_codes = "".join(code_matches)
508
-
509
- # Extract metadata from reasoning section
510
- # Try different reasoning tag patterns
511
- reasoning_patterns = [
512
- r'<think>(.*?)</think>',
513
- r'<think>(.*?)</think>',
514
- r'<reasoning>(.*?)</reasoning>',
515
- ]
516
-
517
- reasoning_text = None
518
- for pattern in reasoning_patterns:
519
- match = re.search(pattern, output_text, re.DOTALL)
520
- if match:
521
- reasoning_text = match.group(1).strip()
522
- break
523
-
524
- # If no reasoning tags found, try to parse metadata from the beginning of output
525
- if not reasoning_text:
526
- # Look for metadata lines before audio codes
527
- lines_before_codes = output_text.split('<|audio_code_')[0] if '<|audio_code_' in output_text else output_text
528
- reasoning_text = lines_before_codes.strip()
529
-
530
- # Parse metadata fields
531
- if reasoning_text:
532
- for line in reasoning_text.split('\n'):
533
- line = line.strip()
534
- if ':' in line and not line.startswith('<'):
535
- parts = line.split(':', 1)
536
- if len(parts) == 2:
537
- key = parts[0].strip().lower()
538
- value = parts[1].strip()
539
-
540
- if key == 'bpm':
541
- try:
542
- metadata['bpm'] = int(value)
543
- except:
544
- metadata['bpm'] = value
545
- elif key == 'duration':
546
- try:
547
- metadata['duration'] = int(value)
548
- except:
549
- metadata['duration'] = value
550
- elif key == 'genres':
551
- metadata['genres'] = value
552
- elif key == 'keyscale':
553
- metadata['keyscale'] = value
554
- elif key == 'timesignature':
555
- metadata['timesignature'] = value
556
-
557
- return metadata, audio_codes
558
 
559
- def _generate_with_cfg(
560
  self,
561
  batch_input_ids: torch.Tensor,
562
  batch_attention_mask: Optional[torch.Tensor],
563
  max_new_tokens: int,
564
  temperature: float,
565
  cfg_scale: float,
566
- logits_processor: Optional[LogitsProcessorList],
 
 
567
  pad_token_id: int,
568
  streamer: Optional[BaseStreamer],
569
  ) -> torch.Tensor:
570
  """
571
- Custom generation loop with CFG support using batch processing.
572
- Batch format: [conditional_input, unconditional_input]
573
- This properly utilizes KV cache by processing both sequences in parallel.
 
 
 
 
574
  """
575
  model = self.llm
576
  device = self.device
@@ -594,6 +742,16 @@ class LLMHandler:
594
  past_key_values = None
595
  use_cache = hasattr(model, 'generation_config') and getattr(model.generation_config, 'use_cache', True)
596
 
 
 
 
 
 
 
 
 
 
 
597
  with torch.no_grad():
598
  for step in range(max_new_tokens):
599
  # Forward pass for the entire batch (conditional + unconditional)
@@ -613,22 +771,38 @@ class LLMHandler:
613
  use_cache=use_cache,
614
  )
615
 
616
- # Get logits
617
  next_token_logits = outputs.logits[:, -1, :] # [batch_size*2, vocab_size]
618
 
619
  # Split conditional and unconditional logits
620
  cond_logits = next_token_logits[cond_start_idx:cond_start_idx+batch_size]
621
  uncond_logits = next_token_logits[uncond_start_idx:uncond_start_idx+batch_size]
622
 
623
- # Apply CFG formula: logits_cfg = logits_uncond + cfg_scale * (logits_cond - logits_uncond)
624
  cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
625
 
626
- # Apply logits processors (temperature, top-k, top-p, repetition penalty)
627
- if logits_processor is not None:
628
- # Get current input_ids for repetition penalty (only conditional part)
629
- current_input_ids = generated_ids[cond_start_idx:cond_start_idx+batch_size]
630
- for processor in logits_processor:
631
- cfg_logits = processor(current_input_ids, cfg_logits)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
632
 
633
  # Apply temperature and sample
634
  if temperature > 0:
@@ -638,9 +812,19 @@ class LLMHandler:
638
  else:
639
  next_tokens = torch.argmax(cfg_logits, dim=-1)
640
 
641
- # Update generated sequences (apply same token to both conditional and unconditional)
642
- next_tokens = next_tokens.unsqueeze(1)
643
- generated_ids = torch.cat([generated_ids, next_tokens.repeat(2, 1)], dim=1)
 
 
 
 
 
 
 
 
 
 
644
  attention_mask = torch.cat([attention_mask, torch.ones((batch_size*2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
645
  model_kwargs['attention_mask'] = attention_mask
646
 
@@ -650,17 +834,99 @@ class LLMHandler:
650
 
651
  # Update streamer
652
  if streamer is not None:
653
- streamer.put(next_tokens[0]) # Only stream conditional tokens
654
 
655
- # Check for EOS (simplified - you may want to check model's eos_token_id)
656
- if (next_tokens[0] == pad_token_id).all():
657
  break
658
 
659
  if streamer is not None:
660
  streamer.end()
661
 
662
- # Return only conditional output
663
- return generated_ids[cond_start_idx:cond_start_idx+batch_size]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
664
 
665
  @contextmanager
666
  def _load_model_context(self):
 
11
  import torch
12
  from tqdm import tqdm
13
  from loguru import logger
14
+ from transformers import AutoTokenizer, AutoModelForCausalLM
15
  from transformers.generation.streamers import BaseStreamer
16
  from transformers.generation.logits_process import (
17
  LogitsProcessorList,
 
 
 
18
  RepetitionPenaltyLogitsProcessor,
19
+ LogitsProcessor,
20
  )
21
 
22
 
 
 
23
  class LLMHandler:
24
  """5Hz LM Handler for audio code generation"""
25
 
 
229
  try:
230
  from nanovllm import SamplingParams
231
 
232
+ formatted_prompt = self.build_formatted_prompt(caption, lyrics)
 
 
 
 
 
 
 
 
 
233
  logger.debug(f"[debug] formatted_prompt: {formatted_prompt}")
234
 
235
  sampling_params = SamplingParams(
 
243
  # Use CFG if cfg_scale > 1.0
244
  if cfg_scale > 1.0:
245
  # Build unconditional prompt (user input replaced with "NO USER INPUT")
246
+ formatted_unconditional_prompt = self.build_formatted_prompt(negative_prompt, is_negative_prompt=True)
 
 
 
 
 
 
 
247
  outputs = self.llm.generate(
248
  [formatted_prompt],
249
  sampling_params,
 
272
  error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
273
  return {}, "", error_msg
274
 
275
def _run_vllm_from_formatted(
    self,
    formatted_prompt: str,
    temperature: float,
    cfg_scale: float,
    negative_prompt: str,
    top_k: Optional[int],
    top_p: Optional[float],
    repetition_penalty: float,
) -> str:
    """
    Run the nanovllm backend on an already-formatted prompt and return the raw text.

    When cfg_scale > 1.0 a second, unconditional prompt (built from
    negative_prompt) is passed so the backend can apply classifier-free guidance.
    """
    from nanovllm import SamplingParams

    params = SamplingParams(
        max_tokens=self.max_model_len - 64,
        temperature=temperature,
        cfg_scale=cfg_scale,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )

    if cfg_scale > 1.0:
        # CFG path: pair the conditional prompt with an unconditional one.
        uncond = self.build_formatted_prompt(negative_prompt, is_negative_prompt=True)
        raw = self.llm.generate(
            [formatted_prompt],
            params,
            unconditional_prompts=[uncond],
        )
    else:
        raw = self.llm.generate([formatted_prompt], params)

    # Probe the possible result shapes in the same priority order as before.
    if not (isinstance(raw, list) and len(raw) > 0):
        return str(raw)

    first = raw[0]
    if hasattr(first, "outputs") and len(first.outputs) > 0:
        return first.outputs[0].text
    if hasattr(first, "text"):
        return first.text
    if isinstance(first, dict) and "text" in first:
        return first["text"]
    return str(first)
320
+ return output_text
321
+
322
  def generate_with_5hz_lm_pt(
323
  self,
324
  caption: str,
 
332
  ) -> Tuple[Dict[str, Any], str, str]:
333
  """Generate metadata and audio codes using 5Hz LM with PyTorch backend"""
334
  try:
335
+ formatted_prompt = self.build_formatted_prompt(caption, lyrics)
 
 
 
 
 
 
 
 
 
336
 
337
  # Tokenize the prompt
338
  inputs = self.llm_tokenizer(
339
  formatted_prompt,
340
  return_tensors="pt",
341
  padding=False,
 
342
  )
343
 
344
  # Generate with the model
 
368
 
369
  streamer = TqdmTokenStreamer(total=max_new_tokens)
370
 
371
+ # Build logits processor list (only for CFG and repetition penalty)
372
  logits_processor = LogitsProcessorList()
373
 
374
+ # Add repetition penalty if needed (generate() doesn't support it natively in all versions)
375
  if repetition_penalty != 1.0:
376
  logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
377
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
  # Handle CFG if cfg_scale > 1.0
379
  if cfg_scale > 1.0:
380
  # Build unconditional prompt
381
+ formatted_unconditional_prompt = self.build_formatted_prompt(negative_prompt, is_negative_prompt=True)
 
 
 
 
 
 
 
382
 
383
+ # Tokenize both prompts together to ensure same length (with left padding)
384
+ # Left padding is important for generation tasks
385
+ batch_texts = [formatted_prompt, formatted_unconditional_prompt]
386
+ original_padding_side = self.llm_tokenizer.padding_side
387
+ self.llm_tokenizer.padding_side = 'left'
388
+ batch_inputs = self.llm_tokenizer(
389
+ batch_texts,
390
  return_tensors="pt",
391
+ padding=True,
392
  truncation=True,
393
  )
394
+ self.llm_tokenizer.padding_side = original_padding_side
395
+ batch_inputs = {k: v.to(self.device) for k, v in batch_inputs.items()}
396
 
397
+ # Extract conditional and unconditional inputs
398
+ batch_input_ids = batch_inputs['input_ids'] # [2, seq_len]
399
+ batch_attention_mask = batch_inputs.get('attention_mask', None)
 
 
 
 
400
 
401
+ # Use custom CFG generation loop
402
+ outputs = self._generate_with_cfg_custom(
403
  batch_input_ids=batch_input_ids,
404
  batch_attention_mask=batch_attention_mask,
405
  max_new_tokens=max_new_tokens,
406
  temperature=temperature,
407
  cfg_scale=cfg_scale,
408
+ top_k=top_k,
409
+ top_p=top_p,
410
+ repetition_penalty=repetition_penalty,
411
  pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
412
  streamer=streamer,
413
  )
414
+
415
+ # Extract only the conditional output (first in batch)
416
+ outputs = outputs[0:1] # Keep only conditional output
417
  else:
418
+ # Generate without CFG using native generate() parameters
419
  with torch.no_grad():
420
  outputs = self.llm.generate(
421
  **inputs,
422
  max_new_tokens=max_new_tokens,
423
  temperature=temperature if temperature > 0 else 1.0,
424
  do_sample=True if temperature > 0 else False,
425
+ top_k=top_k if top_k is not None and top_k > 0 else None,
426
+ top_p=top_p if top_p is not None and 0.0 < top_p < 1.0 else None,
427
  logits_processor=logits_processor if len(logits_processor) > 0 else None,
428
  pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
429
  streamer=streamer,
430
  )
431
 
432
  # Decode the generated tokens
433
+ # outputs is a tensor with shape [batch_size, seq_len], extract first sequence
434
+ if isinstance(outputs, torch.Tensor):
435
+ if outputs.dim() == 2:
436
+ generated_ids = outputs[0]
437
+ else:
438
+ generated_ids = outputs
439
+ else:
440
+ generated_ids = outputs[0]
441
+
442
  # Only decode the newly generated tokens (skip the input prompt)
443
+ # Use the correct input length based on whether CFG was used
444
+ if cfg_scale > 1.0:
445
+ # In CFG case, use batch_inputs length (both sequences have same length due to padding)
446
+ input_length = batch_inputs['input_ids'].shape[1]
447
+ else:
448
+ input_length = inputs['input_ids'].shape[1]
449
+ generated_ids = generated_ids[input_length:]
450
+
451
+ # Move to CPU for decoding
452
+ if generated_ids.is_cuda:
453
+ generated_ids = generated_ids.cpu()
454
+
455
  output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
456
 
457
  metadata, audio_codes = self.parse_lm_output(output_text)
 
460
 
461
  except Exception as e:
462
  error_msg = f"❌ Error generating with 5Hz LM: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
463
+ logger.error(error_msg)
464
  return {}, "", error_msg
465
 
466
def _run_pt_from_formatted(
    self,
    formatted_prompt: str,
    temperature: float,
    cfg_scale: float,
    negative_prompt: str,
    top_k: Optional[int],
    top_p: Optional[float],
    repetition_penalty: float,
) -> str:
    """
    Shared PyTorch path: accept prebuilt formatted prompt and return text.

    When cfg_scale > 1.0 the conditional and unconditional prompts are batched
    together (left-padded to equal length) and decoded through the custom CFG
    loop; otherwise the model's native ``generate()`` is used.

    Returns:
        The decoded text of the newly generated tokens only (prompt stripped),
        with special tokens preserved so audio-code tokens survive decoding.
    """
    # Tokenize the conditional prompt alone; only used by the non-CFG branch.
    inputs = self.llm_tokenizer(
        formatted_prompt,
        return_tensors="pt",
        padding=False,
        truncation=True,
    )

    with self._load_model_context():
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        # Token budget: model config default, capped by the context window
        # minus 64 tokens of headroom when max_model_len is known.
        max_new_tokens = getattr(self.llm.config, "max_new_tokens", 4096)
        if hasattr(self, "max_model_len"):
            max_new_tokens = min(max_new_tokens, self.max_model_len - 64)

        # Build logits processor list (only for CFG and repetition penalty)
        logits_processor = LogitsProcessorList()

        # Add repetition penalty if needed
        if repetition_penalty != 1.0:
            logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))

        if cfg_scale > 1.0:
            formatted_unconditional_prompt = self.build_formatted_prompt(negative_prompt, is_negative_prompt=True)

            # Tokenize both prompts together to ensure same length (with left padding)
            # Left padding is important for generation tasks
            batch_texts = [formatted_prompt, formatted_unconditional_prompt]
            # Temporarily force left padding, then restore the tokenizer's setting.
            original_padding_side = self.llm_tokenizer.padding_side
            self.llm_tokenizer.padding_side = 'left'
            batch_inputs_tokenized = self.llm_tokenizer(
                batch_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            self.llm_tokenizer.padding_side = original_padding_side
            batch_inputs_tokenized = {k: v.to(self.device) for k, v in batch_inputs_tokenized.items()}

            # Extract batch inputs: row 0 = conditional, row 1 = unconditional.
            batch_input_ids = batch_inputs_tokenized['input_ids']
            batch_attention_mask = batch_inputs_tokenized.get('attention_mask', None)

            # Use custom CFG generation loop (streamer disabled on this path).
            outputs = self._generate_with_cfg_custom(
                batch_input_ids=batch_input_ids,
                batch_attention_mask=batch_attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                cfg_scale=cfg_scale,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
                streamer=None,
            )

            # Extract only the conditional output (first in batch)
            outputs = outputs[0:1]  # Keep only conditional output
        else:
            # Generate without CFG using native generate() parameters
            with torch.no_grad():
                outputs = self.llm.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature if temperature > 0 else 1.0,
                    do_sample=True if temperature > 0 else False,
                    top_k=top_k if top_k is not None and top_k > 0 else None,
                    top_p=top_p if top_p is not None and 0.0 < top_p < 1.0 else None,
                    logits_processor=logits_processor if len(logits_processor) > 0 else None,
                    pad_token_id=self.llm_tokenizer.pad_token_id or self.llm_tokenizer.eos_token_id,
                    streamer=None,
                )

        # Decode the generated tokens
        # outputs is a tensor with shape [batch_size, seq_len], extract first sequence
        if isinstance(outputs, torch.Tensor):
            if outputs.dim() == 2:
                generated_ids = outputs[0]
            else:
                generated_ids = outputs
        else:
            generated_ids = outputs[0]

        # Only decode the newly generated tokens (skip the input prompt)
        # Use the original input length (before batch processing for CFG)
        if cfg_scale > 1.0:
            # In CFG case, we need to use the conditional input length from batch_inputs_tokenized
            # Both sequences have the same length due to padding
            input_length = batch_inputs_tokenized['input_ids'].shape[1]
        else:
            input_length = inputs["input_ids"].shape[1]

        generated_ids = generated_ids[input_length:]

        # Move to CPU for decoding
        if generated_ids.is_cuda:
            generated_ids = generated_ids.cpu()

    # Tokenizer-only work from here on; no model/device access needed.
    # skip_special_tokens=False keeps the <|audio_code_*|> tokens in the text.
    output_text = self.llm_tokenizer.decode(generated_ids, skip_special_tokens=False)
    return output_text
576
+
577
  def generate_with_5hz_lm(
578
  self,
579
  caption: str,
 
610
  caption, lyrics, temperature, cfg_scale, negative_prompt,
611
  top_k, top_p, repetition_penalty
612
  )
613
+
614
def build_formatted_prompt(self, caption: str, lyrics: str = "", is_negative_prompt: bool = False) -> str:
    """
    Assemble the chat-template prompt the 5Hz LM expects.

    For a normal prompt the caption and lyrics are wrapped in the
    "# Caption"/"# Lyric" layout; for a negative (unconditional) prompt the
    caption text is passed through verbatim as the user message.

    Raises:
        ValueError: if the tokenizer has not been initialized yet.

    Example:
        prompt = handler.build_formatted_prompt("calm piano", "hello world")
    """
    if self.llm_tokenizer is None:
        raise ValueError("LLM tokenizer is not initialized. Call initialize() first.")

    user_content = caption if is_negative_prompt else f"# Caption\n{caption}\n\n# Lyric\n{lyrics}\n"
    messages = [
        {"role": "system", "content": "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n"},
        {"role": "user", "content": user_content},
    ]
    return self.llm_tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
636
+
637
def generate_from_formatted_prompt(
    self,
    formatted_prompt: str,
    cfg: Optional[Dict[str, Any]] = None,
) -> Tuple[str, str]:
    """
    Generate raw LM text output from a pre-built formatted prompt.

    Args:
        formatted_prompt: Prompt already formatted by `build_formatted_prompt`.
        cfg: Optional sampling overrides: temperature, cfg_scale,
            negative_prompt (used when cfg_scale > 1), top_k, top_p,
            repetition_penalty.

    Returns:
        (output_text, status_message)

    Example:
        prompt = handler.build_formatted_prompt(caption, lyric)
        text, status = handler.generate_from_formatted_prompt(prompt, {"temperature": 0.7})
    """
    if not getattr(self, "llm_initialized", False):
        return "", "❌ 5Hz LM not initialized. Please initialize it first."
    if self.llm is None or self.llm_tokenizer is None:
        return "", "❌ 5Hz LM is missing model or tokenizer."

    options = cfg or {}
    sampling_kwargs = dict(
        temperature=options.get("temperature", 0.6),
        cfg_scale=options.get("cfg_scale", 1.0),
        negative_prompt=options.get("negative_prompt", "NO USER INPUT"),
        top_k=options.get("top_k"),
        top_p=options.get("top_p"),
        repetition_penalty=options.get("repetition_penalty", 1.0),
    )

    try:
        if self.llm_backend == "vllm":
            text = self._run_vllm_from_formatted(
                formatted_prompt=formatted_prompt,
                **sampling_kwargs,
            )
            return text, f"✅ Generated successfully (vllm) | length={len(text)}"

        # PyTorch backend
        text = self._run_pt_from_formatted(
            formatted_prompt=formatted_prompt,
            **sampling_kwargs,
        )
        return text, f"✅ Generated successfully (pt) | length={len(text)}"

    except Exception as e:
        return "", f"❌ Error generating from formatted prompt: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
700
 
701
+ def _generate_with_cfg_custom(
702
  self,
703
  batch_input_ids: torch.Tensor,
704
  batch_attention_mask: Optional[torch.Tensor],
705
  max_new_tokens: int,
706
  temperature: float,
707
  cfg_scale: float,
708
+ top_k: Optional[int],
709
+ top_p: Optional[float],
710
+ repetition_penalty: float,
711
  pad_token_id: int,
712
  streamer: Optional[BaseStreamer],
713
  ) -> torch.Tensor:
714
  """
715
+ Custom CFG generation loop that:
716
+ 1. Processes both conditional and unconditional sequences in parallel
717
+ 2. Applies CFG formula to logits
718
+ 3. Samples tokens only for conditional sequences
719
+ 4. Applies the same sampled tokens to both conditional and unconditional sequences
720
+
721
+ Batch format: [cond_input, uncond_input]
722
  """
723
  model = self.llm
724
  device = self.device
 
742
  past_key_values = None
743
  use_cache = hasattr(model, 'generation_config') and getattr(model.generation_config, 'use_cache', True)
744
 
745
+ # Get EOS token ID for stopping condition
746
+ eos_token_id = self.llm_tokenizer.eos_token_id
747
+ if eos_token_id is None:
748
+ eos_token_id = pad_token_id
749
+
750
+ # Build logits processor for non-CFG operations (repetition penalty, top_k, top_p)
751
+ logits_processor = LogitsProcessorList()
752
+ if repetition_penalty != 1.0:
753
+ logits_processor.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
754
+
755
  with torch.no_grad():
756
  for step in range(max_new_tokens):
757
  # Forward pass for the entire batch (conditional + unconditional)
 
771
  use_cache=use_cache,
772
  )
773
 
774
+ # Get logits for the last position
775
  next_token_logits = outputs.logits[:, -1, :] # [batch_size*2, vocab_size]
776
 
777
  # Split conditional and unconditional logits
778
  cond_logits = next_token_logits[cond_start_idx:cond_start_idx+batch_size]
779
  uncond_logits = next_token_logits[uncond_start_idx:uncond_start_idx+batch_size]
780
 
781
+ # Apply CFG formula: cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
782
  cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
783
 
784
+ # Apply logits processors (repetition penalty, top-k, top-p)
785
+ # Get current input_ids for repetition penalty (only conditional part)
786
+ current_input_ids = generated_ids[cond_start_idx:cond_start_idx+batch_size]
787
+ for processor in logits_processor:
788
+ cfg_logits = processor(current_input_ids, cfg_logits)
789
+
790
+ # Apply top-k filtering
791
+ if top_k is not None and top_k > 0:
792
+ indices_to_remove = cfg_logits < torch.topk(cfg_logits, top_k)[0][..., -1, None]
793
+ cfg_logits[indices_to_remove] = float('-inf')
794
+
795
+ # Apply top-p (nucleus) filtering
796
+ if top_p is not None and 0.0 < top_p < 1.0:
797
+ sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
798
+ cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
799
+ # Remove tokens with cumulative probability above the threshold
800
+ sorted_indices_to_remove = cumulative_probs > top_p
801
+ # Shift the indices to the right to keep also the first token above the threshold
802
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
803
+ sorted_indices_to_remove[..., 0] = 0
804
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
805
+ cfg_logits[indices_to_remove] = float('-inf')
806
 
807
  # Apply temperature and sample
808
  if temperature > 0:
 
812
  else:
813
  next_tokens = torch.argmax(cfg_logits, dim=-1)
814
 
815
+ # Check for EOS token in conditional sequences BEFORE unsqueezing
816
+ # Stop if any conditional sequence generates EOS token
817
+ # next_tokens shape: [batch_size] (only conditional tokens)
818
+ should_stop = False
819
+ if torch.any(next_tokens == eos_token_id):
820
+ should_stop = True
821
+ elif pad_token_id is not None and pad_token_id != eos_token_id:
822
+ if torch.any(next_tokens == pad_token_id):
823
+ should_stop = True
824
+
825
+ # Apply the same sampled tokens to both conditional and unconditional sequences
826
+ next_tokens_unsqueezed = next_tokens.unsqueeze(1)
827
+ generated_ids = torch.cat([generated_ids, next_tokens_unsqueezed.repeat(2, 1)], dim=1)
828
  attention_mask = torch.cat([attention_mask, torch.ones((batch_size*2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
829
  model_kwargs['attention_mask'] = attention_mask
830
 
 
834
 
835
  # Update streamer
836
  if streamer is not None:
837
+ streamer.put(next_tokens_unsqueezed) # Stream conditional tokens
838
 
839
+ # Stop generation if EOS token detected
840
+ if should_stop:
841
  break
842
 
843
  if streamer is not None:
844
  streamer.end()
845
 
846
+ # Return the full batch (both conditional and unconditional)
847
+ # The caller will extract only the conditional output
848
+ return generated_ids
849
+
850
def parse_lm_output(self, output_text: str) -> Tuple[Dict[str, Any], str]:
    """
    Parse LM output to extract metadata and audio codes.

    Expected format:
        <think>
        bpm: 73
        duration: 273
        genres: Chinese folk
        keyscale: G major
        timesignature: 4
        </think>

        <|audio_code_56535|><|audio_code_62918|>...

    Returns:
        Tuple of (metadata_dict, audio_codes_string). Fields the model did
        not emit are absent from the dict; audio_codes is "" when no
        <|audio_code_*|> tokens are found.
    """
    import re

    debug_output_text = output_text.split("</think>")[0]
    logger.debug(f"Debug output text: {debug_output_text}")
    metadata = {}
    audio_codes = ""

    # Extract audio codes - concatenate every <|audio_code_XXX|> token.
    code_matches = re.findall(r'<\|audio_code_\d+\|>', output_text)
    if code_matches:
        audio_codes = "".join(code_matches)

    # Extract metadata from the reasoning section.
    # (Fixed: the pattern list previously contained the <think> pattern twice.)
    reasoning_patterns = [
        r'<think>(.*?)</think>',
        r'<reasoning>(.*?)</reasoning>',
    ]

    reasoning_text = None
    for pattern in reasoning_patterns:
        match = re.search(pattern, output_text, re.DOTALL)
        if match:
            reasoning_text = match.group(1).strip()
            break

    # If no reasoning tags found, try to parse metadata from the beginning of output
    if not reasoning_text:
        # Look for metadata lines before audio codes
        lines_before_codes = output_text.split('<|audio_code_')[0] if '<|audio_code_' in output_text else output_text
        reasoning_text = lines_before_codes.strip()

    # Parse "key: value" metadata lines; bpm/duration are coerced to int when
    # possible, otherwise kept as the raw string (fixed: bare `except:` narrowed
    # to ValueError, the only error int() can raise on a str here).
    if reasoning_text:
        for line in reasoning_text.split('\n'):
            line = line.strip()
            if ':' in line and not line.startswith('<'):
                raw_key, _, raw_value = line.partition(':')
                key = raw_key.strip().lower()
                value = raw_value.strip()

                if key in ('bpm', 'duration'):
                    try:
                        metadata[key] = int(value)
                    except ValueError:
                        metadata[key] = value
                elif key in ('genres', 'keyscale', 'timesignature'):
                    metadata[key] = value

    return metadata, audio_codes
930
 
931
  @contextmanager
932
  def _load_model_context(self):