ABLingss commited on
Commit
89fe941
·
1 Parent(s): 57dda4d
Files changed (1) hide show
  1. app.py +206 -30
app.py CHANGED
@@ -187,11 +187,11 @@ except Exception as e:
187
  GRADIO_QUEUE_MAX_SIZE = int(os.environ.get("GRADIO_QUEUE_MAX_SIZE", "24"))
188
  GRADIO_DEFAULT_CONCURRENCY = int(os.environ.get("GRADIO_DEFAULT_CONCURRENCY", "1"))
189
  GPU_CONCURRENCY_LIMIT = int(os.environ.get("GRADIO_GPU_CONCURRENCY", "1"))
190
- STREAM_MIN_CHUNK_SEC = float(os.environ.get("STREAM_MIN_CHUNK_SEC", "29.76"))
191
 
192
 
193
  class ModelManager:
194
- def __init__(self, model_path: str):
195
  import torch
196
  from heartlib import HeartMuLaGenPipeline, HeartTranscriptorPipeline
197
 
@@ -200,7 +200,10 @@ class ModelManager:
200
  self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
201
  self._gen_pipes: Dict[Tuple[str, str, str], "HeartMuLaGenPipeline"] = {}
202
  self._transcribe_pipe: Optional["HeartTranscriptorPipeline"] = None
203
- self.use_deepspeed = os.getenv("USE_DEEPSPEED_INFERENCE", "0").lower() in ("1", "true", "yes")
 
 
 
204
  self.ds_inference_config = self._make_ds_inference_config()
205
  self._HeartMuLaGenPipeline = HeartMuLaGenPipeline
206
  self._HeartTranscriptorPipeline = HeartTranscriptorPipeline
@@ -270,16 +273,17 @@ class ModelManager:
270
  return self._transcribe_pipe
271
 
272
 
273
- model_manager: Optional[ModelManager] = None
274
 
275
 
276
- def get_model_manager() -> ModelManager:
277
- global model_manager
278
- if model_manager is None:
279
  os.makedirs(MODEL_PATH, exist_ok=True)
280
  download_models_if_needed(MODEL_PATH)
281
- model_manager = ModelManager(MODEL_PATH)
282
- return model_manager
 
283
 
284
 
285
  def update_tag_string(*args):
@@ -463,9 +467,11 @@ def download_transcriptor_if_needed(ckpt_dir):
463
  print("")
464
 
465
 
466
- def load_pipeline(model_path, version, codec_version, quant_mode):
467
  """Load HeartMuLa pipeline (lazy)"""
468
- manager = get_model_manager()
 
 
469
  print(f"Using model from {model_path} on {manager.device}...")
470
  return manager.get_gen_pipeline(version, codec_version, quant_mode)
471
 
@@ -473,7 +479,7 @@ def load_pipeline(model_path, version, codec_version, quant_mode):
473
  def load_transcriptor(model_path):
474
  """Load HeartTranscriptor pipeline"""
475
  download_transcriptor_if_needed(model_path)
476
- manager = get_model_manager()
477
  return manager.get_transcriptor()
478
 
479
 
@@ -492,6 +498,7 @@ def generate(
492
  keep_model_loaded,
493
  offload_mode,
494
  backend,
 
495
  ):
496
  """Generate music"""
497
  import torch
@@ -507,7 +514,7 @@ def generate(
507
  if backend == "exllama_v2":
508
  raise gr.Error("ExLlamaV2 backend is not implemented yet.")
509
 
510
- pipe = load_pipeline(MODEL_PATH, version, codec_version, quant_mode)
511
  output_path = os.path.join(DATA_DIR, f"gen_{uuid.uuid4().hex}.wav")
512
 
513
  with torch.no_grad():
@@ -540,6 +547,72 @@ def generate(
540
  raise gr.Error(f"Generation error: {str(e)}")
541
 
542
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543
  @_gpu_guard
544
  def transcribe_audio(audio_path, task, max_new_tokens, num_beams, temperature):
545
  """Transcribe or translate lyrics from audio"""
@@ -590,10 +663,11 @@ def generate_music_streaming(
590
  offload_mode,
591
  backend,
592
  chunk_frames,
 
593
  ) -> Iterator[Tuple[int, np.ndarray]]:
594
  if backend == "exllama_v2":
595
  raise gr.Error("ExLlamaV2 backend is not implemented yet.")
596
- pipe = load_pipeline(MODEL_PATH, version, codec_version, quant_mode)
597
  max_audio_length_ms = int(duration_sec * 1000)
598
  for chunk in pipe.stream(
599
  {"lyrics": lyrics, "tags": tags},
@@ -629,15 +703,16 @@ def stream_generate(
629
  backend,
630
  output_format,
631
  chunk_frames,
 
632
  ):
633
  try:
634
- min_samples = max(1, int(STREAM_MIN_CHUNK_SEC * 48000))
635
  buffer = []
636
  buffered_samples = 0
637
  last_yield_samples = 0
638
  print(
639
  "stream start:",
640
- f"min_samples={min_samples}",
641
  f"duration_sec={duration_sec}",
642
  f"chunk_frames={chunk_frames}",
643
  )
@@ -656,23 +731,22 @@ def stream_generate(
656
  offload_mode=offload_mode,
657
  backend=backend,
658
  chunk_frames=chunk_frames,
 
659
  ):
660
  chunk_np = chunk_np.astype("float32", copy=False)
 
 
 
 
661
  buffer.append(chunk_np)
662
  buffered_samples += chunk_np.shape[0]
663
- print(
664
- "stream buffer:",
665
- f"chunk={chunk_np.shape[0]}",
666
- f"buffered={buffered_samples}",
667
- )
668
  if buffered_samples - last_yield_samples < min_samples:
669
  continue
670
  full_audio = np.concatenate(buffer)
671
  last_yield_samples = buffered_samples
672
  print(f"stream yield: samples={full_audio.shape[0]}")
673
  yield sr, full_audio
674
-
675
- if buffer:
676
  full_audio = np.concatenate(buffer)
677
  print(f"stream final yield: samples={full_audio.shape[0]}")
678
  yield 48000, full_audio
@@ -680,6 +754,41 @@ def stream_generate(
680
  raise gr.Error(f"Streaming error: {str(e)}")
681
 
682
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
683
  def generate_lyrics(theme, tags, language, api_choice, api_key_input, custom_base_url, custom_model, progress=gr.Progress()):
684
  """Generate lyrics using selected LLM API"""
685
 
@@ -946,9 +1055,25 @@ def create_ui():
946
  5, 100, value=20, step=1, label="Streaming Chunk Frames"
947
  )
948
 
949
- generate_btn = gr.Button("Generate Music", variant="primary", size="lg")
950
- stream_btn = gr.Button("Generate Music (Streaming)", variant="primary", size="lg")
951
- cancel_stream_btn = gr.Button("Cancel Streaming", variant="secondary", size="lg")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
952
  cancel_state = gr.State()
953
 
954
  with gr.Column():
@@ -1015,8 +1140,59 @@ Every day the fire burns
1015
  outputs=[lyrics]
1016
  )
1017
 
1018
- generate_btn.click(
1019
- fn=generate,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1020
  inputs=[
1021
  lyrics,
1022
  tags,
@@ -1037,8 +1213,8 @@ Every day the fire burns
1037
  concurrency_limit=GPU_CONCURRENCY_LIMIT,
1038
  )
1039
 
1040
- stream_event = stream_btn.click(
1041
- fn=stream_generate,
1042
  inputs=[
1043
  lyrics,
1044
  tags,
 
187
  GRADIO_QUEUE_MAX_SIZE = int(os.environ.get("GRADIO_QUEUE_MAX_SIZE", "24"))
188
  GRADIO_DEFAULT_CONCURRENCY = int(os.environ.get("GRADIO_DEFAULT_CONCURRENCY", "1"))
189
  GPU_CONCURRENCY_LIMIT = int(os.environ.get("GRADIO_GPU_CONCURRENCY", "1"))
190
+ STREAM_MIN_CHUNK_SEC = float(os.environ.get("STREAM_MIN_CHUNK_SEC", "0"))
191
 
192
 
193
  class ModelManager:
194
+ def __init__(self, model_path: str, use_deepspeed_override: Optional[bool] = None):
195
  import torch
196
  from heartlib import HeartMuLaGenPipeline, HeartTranscriptorPipeline
197
 
 
200
  self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
201
  self._gen_pipes: Dict[Tuple[str, str, str], "HeartMuLaGenPipeline"] = {}
202
  self._transcribe_pipe: Optional["HeartTranscriptorPipeline"] = None
203
+ if use_deepspeed_override is None:
204
+ self.use_deepspeed = os.getenv("USE_DEEPSPEED_INFERENCE", "0").lower() in ("1", "true", "yes")
205
+ else:
206
+ self.use_deepspeed = use_deepspeed_override
207
  self.ds_inference_config = self._make_ds_inference_config()
208
  self._HeartMuLaGenPipeline = HeartMuLaGenPipeline
209
  self._HeartTranscriptorPipeline = HeartTranscriptorPipeline
 
273
  return self._transcribe_pipe
274
 
275
 
276
+ model_managers: Dict[str, ModelManager] = {}
277
 
278
 
279
+ def get_model_manager(use_acceleration: bool) -> ModelManager:
280
+ key = "accelerated" if use_acceleration else "original"
281
+ if key not in model_managers:
282
  os.makedirs(MODEL_PATH, exist_ok=True)
283
  download_models_if_needed(MODEL_PATH)
284
+ use_deepspeed_override = None if use_acceleration else False
285
+ model_managers[key] = ModelManager(MODEL_PATH, use_deepspeed_override=use_deepspeed_override)
286
+ return model_managers[key]
287
 
288
 
289
  def update_tag_string(*args):
 
467
  print("")
468
 
469
 
470
+ def load_pipeline(model_path, version, codec_version, quant_mode, use_acceleration: bool):
471
  """Load HeartMuLa pipeline (lazy)"""
472
+ if not use_acceleration:
473
+ quant_mode = "none"
474
+ manager = get_model_manager(use_acceleration)
475
  print(f"Using model from {model_path} on {manager.device}...")
476
  return manager.get_gen_pipeline(version, codec_version, quant_mode)
477
 
 
479
  def load_transcriptor(model_path):
480
  """Load HeartTranscriptor pipeline"""
481
  download_transcriptor_if_needed(model_path)
482
+ manager = get_model_manager(use_acceleration=True)
483
  return manager.get_transcriptor()
484
 
485
 
 
498
  keep_model_loaded,
499
  offload_mode,
500
  backend,
501
+ use_acceleration,
502
  ):
503
  """Generate music"""
504
  import torch
 
514
  if backend == "exllama_v2":
515
  raise gr.Error("ExLlamaV2 backend is not implemented yet.")
516
 
517
+ pipe = load_pipeline(MODEL_PATH, version, codec_version, quant_mode, use_acceleration)
518
  output_path = os.path.join(DATA_DIR, f"gen_{uuid.uuid4().hex}.wav")
519
 
520
  with torch.no_grad():
 
547
  raise gr.Error(f"Generation error: {str(e)}")
548
 
549
 
550
+ def generate_original(
551
+ lyrics,
552
+ tags,
553
+ cfg_scale,
554
+ duration_sec,
555
+ temperature,
556
+ topk,
557
+ version,
558
+ codec_version,
559
+ quant_mode,
560
+ output_format,
561
+ keep_model_loaded,
562
+ offload_mode,
563
+ backend,
564
+ ):
565
+ return generate(
566
+ lyrics,
567
+ tags,
568
+ cfg_scale,
569
+ duration_sec,
570
+ temperature,
571
+ topk,
572
+ version,
573
+ codec_version,
574
+ quant_mode,
575
+ output_format,
576
+ keep_model_loaded,
577
+ offload_mode,
578
+ backend,
579
+ False,
580
+ )
581
+
582
+
583
+ def generate_accelerated(
584
+ lyrics,
585
+ tags,
586
+ cfg_scale,
587
+ duration_sec,
588
+ temperature,
589
+ topk,
590
+ version,
591
+ codec_version,
592
+ quant_mode,
593
+ output_format,
594
+ keep_model_loaded,
595
+ offload_mode,
596
+ backend,
597
+ ):
598
+ return generate(
599
+ lyrics,
600
+ tags,
601
+ cfg_scale,
602
+ duration_sec,
603
+ temperature,
604
+ topk,
605
+ version,
606
+ codec_version,
607
+ quant_mode,
608
+ output_format,
609
+ keep_model_loaded,
610
+ offload_mode,
611
+ backend,
612
+ True,
613
+ )
614
+
615
+
616
  @_gpu_guard
617
  def transcribe_audio(audio_path, task, max_new_tokens, num_beams, temperature):
618
  """Transcribe or translate lyrics from audio"""
 
663
  offload_mode,
664
  backend,
665
  chunk_frames,
666
+ use_acceleration,
667
  ) -> Iterator[Tuple[int, np.ndarray]]:
668
  if backend == "exllama_v2":
669
  raise gr.Error("ExLlamaV2 backend is not implemented yet.")
670
+ pipe = load_pipeline(MODEL_PATH, version, codec_version, quant_mode, use_acceleration)
671
  max_audio_length_ms = int(duration_sec * 1000)
672
  for chunk in pipe.stream(
673
  {"lyrics": lyrics, "tags": tags},
 
703
  backend,
704
  output_format,
705
  chunk_frames,
706
+ use_acceleration,
707
  ):
708
  try:
709
+ min_samples = max(0, int(STREAM_MIN_CHUNK_SEC * 48000))
710
  buffer = []
711
  buffered_samples = 0
712
  last_yield_samples = 0
713
  print(
714
  "stream start:",
715
+ f"min_chunk_sec={STREAM_MIN_CHUNK_SEC}",
716
  f"duration_sec={duration_sec}",
717
  f"chunk_frames={chunk_frames}",
718
  )
 
731
  offload_mode=offload_mode,
732
  backend=backend,
733
  chunk_frames=chunk_frames,
734
+ use_acceleration=use_acceleration,
735
  ):
736
  chunk_np = chunk_np.astype("float32", copy=False)
737
+ if min_samples <= 0:
738
+ print(f"stream yield: samples={chunk_np.shape[0]}")
739
+ yield sr, chunk_np
740
+ continue
741
  buffer.append(chunk_np)
742
  buffered_samples += chunk_np.shape[0]
 
 
 
 
 
743
  if buffered_samples - last_yield_samples < min_samples:
744
  continue
745
  full_audio = np.concatenate(buffer)
746
  last_yield_samples = buffered_samples
747
  print(f"stream yield: samples={full_audio.shape[0]}")
748
  yield sr, full_audio
749
+ if min_samples > 0 and buffer:
 
750
  full_audio = np.concatenate(buffer)
751
  print(f"stream final yield: samples={full_audio.shape[0]}")
752
  yield 48000, full_audio
 
754
  raise gr.Error(f"Streaming error: {str(e)}")
755
 
756
 
757
+ def stream_generate_accelerated(
758
+ lyrics,
759
+ tags,
760
+ cfg_scale,
761
+ duration_sec,
762
+ temperature,
763
+ topk,
764
+ version,
765
+ codec_version,
766
+ quant_mode,
767
+ keep_model_loaded,
768
+ offload_mode,
769
+ backend,
770
+ output_format,
771
+ chunk_frames,
772
+ ):
773
+ return stream_generate(
774
+ lyrics,
775
+ tags,
776
+ cfg_scale,
777
+ duration_sec,
778
+ temperature,
779
+ topk,
780
+ version,
781
+ codec_version,
782
+ quant_mode,
783
+ keep_model_loaded,
784
+ offload_mode,
785
+ backend,
786
+ output_format,
787
+ chunk_frames,
788
+ True,
789
+ )
790
+
791
+
792
  def generate_lyrics(theme, tags, language, api_choice, api_key_input, custom_base_url, custom_model, progress=gr.Progress()):
793
  """Generate lyrics using selected LLM API"""
794
 
 
1055
  5, 100, value=20, step=1, label="Streaming Chunk Frames"
1056
  )
1057
 
1058
+ gr.Markdown("### 🚀 Generation")
1059
+
1060
+ generation_mode = gr.Radio(
1061
+ choices=["Original (No Acceleration)", "Accelerated"],
1062
+ value="Original (No Acceleration)",
1063
+ label="Generation Mode",
1064
+ )
1065
+
1066
+ speed_submode = gr.Radio(
1067
+ choices=["Standard", "Streaming"],
1068
+ value="Standard",
1069
+ label="Accelerated Options",
1070
+ visible=False,
1071
+ )
1072
+
1073
+ btn_original = gr.Button("🎼 Generate Music (Original)", variant="primary", size="lg", visible=True)
1074
+ btn_accel = gr.Button("🎼 Generate Music (Accelerated)", variant="primary", size="lg", visible=False)
1075
+ btn_stream = gr.Button("🎼 Generate Music (Streaming)", variant="primary", size="lg", visible=False)
1076
+ cancel_stream_btn = gr.Button("Cancel Streaming", variant="secondary", size="lg", visible=False)
1077
  cancel_state = gr.State()
1078
 
1079
  with gr.Column():
 
1140
  outputs=[lyrics]
1141
  )
1142
 
1143
+ def update_visibility(gen_mode, spd_mode):
1144
+ if gen_mode == "Original (No Acceleration)":
1145
+ return (
1146
+ gr.update(visible=False), # speed_submode
1147
+ gr.update(visible=True), # btn_original
1148
+ gr.update(visible=False), # btn_accel
1149
+ gr.update(visible=False), # btn_stream
1150
+ gr.update(visible=False), # cancel_stream_btn
1151
+ )
1152
+ show_stream = spd_mode == "Streaming"
1153
+ return (
1154
+ gr.update(visible=True), # speed_submode
1155
+ gr.update(visible=False), # btn_original
1156
+ gr.update(visible=not show_stream), # btn_accel
1157
+ gr.update(visible=show_stream), # btn_stream
1158
+ gr.update(visible=show_stream), # cancel_stream_btn
1159
+ )
1160
+
1161
+ generation_mode.change(
1162
+ fn=update_visibility,
1163
+ inputs=[generation_mode, speed_submode],
1164
+ outputs=[speed_submode, btn_original, btn_accel, btn_stream, cancel_stream_btn],
1165
+ )
1166
+ speed_submode.change(
1167
+ fn=update_visibility,
1168
+ inputs=[generation_mode, speed_submode],
1169
+ outputs=[speed_submode, btn_original, btn_accel, btn_stream, cancel_stream_btn],
1170
+ )
1171
+
1172
+ btn_original.click(
1173
+ fn=generate_original,
1174
+ inputs=[
1175
+ lyrics,
1176
+ tags,
1177
+ cfg_scale,
1178
+ duration,
1179
+ temperature,
1180
+ topk,
1181
+ version,
1182
+ codec_version,
1183
+ quant_mode,
1184
+ output_format,
1185
+ keep_model_loaded,
1186
+ offload_mode,
1187
+ backend,
1188
+ ],
1189
+ outputs=[output_audio_file],
1190
+ concurrency_id="gpu_queue",
1191
+ concurrency_limit=GPU_CONCURRENCY_LIMIT,
1192
+ )
1193
+
1194
+ btn_accel.click(
1195
+ fn=generate_accelerated,
1196
  inputs=[
1197
  lyrics,
1198
  tags,
 
1213
  concurrency_limit=GPU_CONCURRENCY_LIMIT,
1214
  )
1215
 
1216
+ stream_event = btn_stream.click(
1217
+ fn=stream_generate_accelerated,
1218
  inputs=[
1219
  lyrics,
1220
  tags,