99
Browse files
app.py
CHANGED
|
@@ -187,11 +187,11 @@ except Exception as e:
|
|
| 187 |
GRADIO_QUEUE_MAX_SIZE = int(os.environ.get("GRADIO_QUEUE_MAX_SIZE", "24"))
|
| 188 |
GRADIO_DEFAULT_CONCURRENCY = int(os.environ.get("GRADIO_DEFAULT_CONCURRENCY", "1"))
|
| 189 |
GPU_CONCURRENCY_LIMIT = int(os.environ.get("GRADIO_GPU_CONCURRENCY", "1"))
|
| 190 |
-
STREAM_MIN_CHUNK_SEC = float(os.environ.get("STREAM_MIN_CHUNK_SEC", "
|
| 191 |
|
| 192 |
|
| 193 |
class ModelManager:
|
| 194 |
-
def __init__(self, model_path: str):
|
| 195 |
import torch
|
| 196 |
from heartlib import HeartMuLaGenPipeline, HeartTranscriptorPipeline
|
| 197 |
|
|
@@ -200,7 +200,10 @@ class ModelManager:
|
|
| 200 |
self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
| 201 |
self._gen_pipes: Dict[Tuple[str, str, str], "HeartMuLaGenPipeline"] = {}
|
| 202 |
self._transcribe_pipe: Optional["HeartTranscriptorPipeline"] = None
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
| 204 |
self.ds_inference_config = self._make_ds_inference_config()
|
| 205 |
self._HeartMuLaGenPipeline = HeartMuLaGenPipeline
|
| 206 |
self._HeartTranscriptorPipeline = HeartTranscriptorPipeline
|
|
@@ -270,16 +273,17 @@ class ModelManager:
|
|
| 270 |
return self._transcribe_pipe
|
| 271 |
|
| 272 |
|
| 273 |
-
|
| 274 |
|
| 275 |
|
| 276 |
-
def get_model_manager() -> ModelManager:
|
| 277 |
-
|
| 278 |
-
if
|
| 279 |
os.makedirs(MODEL_PATH, exist_ok=True)
|
| 280 |
download_models_if_needed(MODEL_PATH)
|
| 281 |
-
|
| 282 |
-
|
|
|
|
| 283 |
|
| 284 |
|
| 285 |
def update_tag_string(*args):
|
|
@@ -463,9 +467,11 @@ def download_transcriptor_if_needed(ckpt_dir):
|
|
| 463 |
print("")
|
| 464 |
|
| 465 |
|
| 466 |
-
def load_pipeline(model_path, version, codec_version, quant_mode):
|
| 467 |
"""Load HeartMuLa pipeline (lazy)"""
|
| 468 |
-
|
|
|
|
|
|
|
| 469 |
print(f"Using model from {model_path} on {manager.device}...")
|
| 470 |
return manager.get_gen_pipeline(version, codec_version, quant_mode)
|
| 471 |
|
|
@@ -473,7 +479,7 @@ def load_pipeline(model_path, version, codec_version, quant_mode):
|
|
| 473 |
def load_transcriptor(model_path):
|
| 474 |
"""Load HeartTranscriptor pipeline"""
|
| 475 |
download_transcriptor_if_needed(model_path)
|
| 476 |
-
manager = get_model_manager()
|
| 477 |
return manager.get_transcriptor()
|
| 478 |
|
| 479 |
|
|
@@ -492,6 +498,7 @@ def generate(
|
|
| 492 |
keep_model_loaded,
|
| 493 |
offload_mode,
|
| 494 |
backend,
|
|
|
|
| 495 |
):
|
| 496 |
"""Generate music"""
|
| 497 |
import torch
|
|
@@ -507,7 +514,7 @@ def generate(
|
|
| 507 |
if backend == "exllama_v2":
|
| 508 |
raise gr.Error("ExLlamaV2 backend is not implemented yet.")
|
| 509 |
|
| 510 |
-
pipe = load_pipeline(MODEL_PATH, version, codec_version, quant_mode)
|
| 511 |
output_path = os.path.join(DATA_DIR, f"gen_{uuid.uuid4().hex}.wav")
|
| 512 |
|
| 513 |
with torch.no_grad():
|
|
@@ -540,6 +547,72 @@ def generate(
|
|
| 540 |
raise gr.Error(f"Generation error: {str(e)}")
|
| 541 |
|
| 542 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 543 |
@_gpu_guard
|
| 544 |
def transcribe_audio(audio_path, task, max_new_tokens, num_beams, temperature):
|
| 545 |
"""Transcribe or translate lyrics from audio"""
|
|
@@ -590,10 +663,11 @@ def generate_music_streaming(
|
|
| 590 |
offload_mode,
|
| 591 |
backend,
|
| 592 |
chunk_frames,
|
|
|
|
| 593 |
) -> Iterator[Tuple[int, np.ndarray]]:
|
| 594 |
if backend == "exllama_v2":
|
| 595 |
raise gr.Error("ExLlamaV2 backend is not implemented yet.")
|
| 596 |
-
pipe = load_pipeline(MODEL_PATH, version, codec_version, quant_mode)
|
| 597 |
max_audio_length_ms = int(duration_sec * 1000)
|
| 598 |
for chunk in pipe.stream(
|
| 599 |
{"lyrics": lyrics, "tags": tags},
|
|
@@ -629,15 +703,16 @@ def stream_generate(
|
|
| 629 |
backend,
|
| 630 |
output_format,
|
| 631 |
chunk_frames,
|
|
|
|
| 632 |
):
|
| 633 |
try:
|
| 634 |
-
min_samples = max(
|
| 635 |
buffer = []
|
| 636 |
buffered_samples = 0
|
| 637 |
last_yield_samples = 0
|
| 638 |
print(
|
| 639 |
"stream start:",
|
| 640 |
-
f"
|
| 641 |
f"duration_sec={duration_sec}",
|
| 642 |
f"chunk_frames={chunk_frames}",
|
| 643 |
)
|
|
@@ -656,23 +731,22 @@ def stream_generate(
|
|
| 656 |
offload_mode=offload_mode,
|
| 657 |
backend=backend,
|
| 658 |
chunk_frames=chunk_frames,
|
|
|
|
| 659 |
):
|
| 660 |
chunk_np = chunk_np.astype("float32", copy=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 661 |
buffer.append(chunk_np)
|
| 662 |
buffered_samples += chunk_np.shape[0]
|
| 663 |
-
print(
|
| 664 |
-
"stream buffer:",
|
| 665 |
-
f"chunk={chunk_np.shape[0]}",
|
| 666 |
-
f"buffered={buffered_samples}",
|
| 667 |
-
)
|
| 668 |
if buffered_samples - last_yield_samples < min_samples:
|
| 669 |
continue
|
| 670 |
full_audio = np.concatenate(buffer)
|
| 671 |
last_yield_samples = buffered_samples
|
| 672 |
print(f"stream yield: samples={full_audio.shape[0]}")
|
| 673 |
yield sr, full_audio
|
| 674 |
-
|
| 675 |
-
if buffer:
|
| 676 |
full_audio = np.concatenate(buffer)
|
| 677 |
print(f"stream final yield: samples={full_audio.shape[0]}")
|
| 678 |
yield 48000, full_audio
|
|
@@ -680,6 +754,41 @@ def stream_generate(
|
|
| 680 |
raise gr.Error(f"Streaming error: {str(e)}")
|
| 681 |
|
| 682 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 683 |
def generate_lyrics(theme, tags, language, api_choice, api_key_input, custom_base_url, custom_model, progress=gr.Progress()):
|
| 684 |
"""Generate lyrics using selected LLM API"""
|
| 685 |
|
|
@@ -946,9 +1055,25 @@ def create_ui():
|
|
| 946 |
5, 100, value=20, step=1, label="Streaming Chunk Frames"
|
| 947 |
)
|
| 948 |
|
| 949 |
-
|
| 950 |
-
|
| 951 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 952 |
cancel_state = gr.State()
|
| 953 |
|
| 954 |
with gr.Column():
|
|
@@ -1015,8 +1140,59 @@ Every day the fire burns
|
|
| 1015 |
outputs=[lyrics]
|
| 1016 |
)
|
| 1017 |
|
| 1018 |
-
|
| 1019 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1020 |
inputs=[
|
| 1021 |
lyrics,
|
| 1022 |
tags,
|
|
@@ -1037,8 +1213,8 @@ Every day the fire burns
|
|
| 1037 |
concurrency_limit=GPU_CONCURRENCY_LIMIT,
|
| 1038 |
)
|
| 1039 |
|
| 1040 |
-
stream_event =
|
| 1041 |
-
fn=
|
| 1042 |
inputs=[
|
| 1043 |
lyrics,
|
| 1044 |
tags,
|
|
|
|
| 187 |
GRADIO_QUEUE_MAX_SIZE = int(os.environ.get("GRADIO_QUEUE_MAX_SIZE", "24"))
|
| 188 |
GRADIO_DEFAULT_CONCURRENCY = int(os.environ.get("GRADIO_DEFAULT_CONCURRENCY", "1"))
|
| 189 |
GPU_CONCURRENCY_LIMIT = int(os.environ.get("GRADIO_GPU_CONCURRENCY", "1"))
|
| 190 |
+
STREAM_MIN_CHUNK_SEC = float(os.environ.get("STREAM_MIN_CHUNK_SEC", "0"))
|
| 191 |
|
| 192 |
|
| 193 |
class ModelManager:
|
| 194 |
+
def __init__(self, model_path: str, use_deepspeed_override: Optional[bool] = None):
|
| 195 |
import torch
|
| 196 |
from heartlib import HeartMuLaGenPipeline, HeartTranscriptorPipeline
|
| 197 |
|
|
|
|
| 200 |
self.dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
|
| 201 |
self._gen_pipes: Dict[Tuple[str, str, str], "HeartMuLaGenPipeline"] = {}
|
| 202 |
self._transcribe_pipe: Optional["HeartTranscriptorPipeline"] = None
|
| 203 |
+
if use_deepspeed_override is None:
|
| 204 |
+
self.use_deepspeed = os.getenv("USE_DEEPSPEED_INFERENCE", "0").lower() in ("1", "true", "yes")
|
| 205 |
+
else:
|
| 206 |
+
self.use_deepspeed = use_deepspeed_override
|
| 207 |
self.ds_inference_config = self._make_ds_inference_config()
|
| 208 |
self._HeartMuLaGenPipeline = HeartMuLaGenPipeline
|
| 209 |
self._HeartTranscriptorPipeline = HeartTranscriptorPipeline
|
|
|
|
| 273 |
return self._transcribe_pipe
|
| 274 |
|
| 275 |
|
| 276 |
+
model_managers: Dict[str, ModelManager] = {}
|
| 277 |
|
| 278 |
|
| 279 |
+
def get_model_manager(use_acceleration: bool) -> ModelManager:
|
| 280 |
+
key = "accelerated" if use_acceleration else "original"
|
| 281 |
+
if key not in model_managers:
|
| 282 |
os.makedirs(MODEL_PATH, exist_ok=True)
|
| 283 |
download_models_if_needed(MODEL_PATH)
|
| 284 |
+
use_deepspeed_override = None if use_acceleration else False
|
| 285 |
+
model_managers[key] = ModelManager(MODEL_PATH, use_deepspeed_override=use_deepspeed_override)
|
| 286 |
+
return model_managers[key]
|
| 287 |
|
| 288 |
|
| 289 |
def update_tag_string(*args):
|
|
|
|
| 467 |
print("")
|
| 468 |
|
| 469 |
|
| 470 |
+
def load_pipeline(model_path, version, codec_version, quant_mode, use_acceleration: bool):
|
| 471 |
"""Load HeartMuLa pipeline (lazy)"""
|
| 472 |
+
if not use_acceleration:
|
| 473 |
+
quant_mode = "none"
|
| 474 |
+
manager = get_model_manager(use_acceleration)
|
| 475 |
print(f"Using model from {model_path} on {manager.device}...")
|
| 476 |
return manager.get_gen_pipeline(version, codec_version, quant_mode)
|
| 477 |
|
|
|
|
| 479 |
def load_transcriptor(model_path):
|
| 480 |
"""Load HeartTranscriptor pipeline"""
|
| 481 |
download_transcriptor_if_needed(model_path)
|
| 482 |
+
manager = get_model_manager(use_acceleration=True)
|
| 483 |
return manager.get_transcriptor()
|
| 484 |
|
| 485 |
|
|
|
|
| 498 |
keep_model_loaded,
|
| 499 |
offload_mode,
|
| 500 |
backend,
|
| 501 |
+
use_acceleration,
|
| 502 |
):
|
| 503 |
"""Generate music"""
|
| 504 |
import torch
|
|
|
|
| 514 |
if backend == "exllama_v2":
|
| 515 |
raise gr.Error("ExLlamaV2 backend is not implemented yet.")
|
| 516 |
|
| 517 |
+
pipe = load_pipeline(MODEL_PATH, version, codec_version, quant_mode, use_acceleration)
|
| 518 |
output_path = os.path.join(DATA_DIR, f"gen_{uuid.uuid4().hex}.wav")
|
| 519 |
|
| 520 |
with torch.no_grad():
|
|
|
|
| 547 |
raise gr.Error(f"Generation error: {str(e)}")
|
| 548 |
|
| 549 |
|
| 550 |
+
def generate_original(
|
| 551 |
+
lyrics,
|
| 552 |
+
tags,
|
| 553 |
+
cfg_scale,
|
| 554 |
+
duration_sec,
|
| 555 |
+
temperature,
|
| 556 |
+
topk,
|
| 557 |
+
version,
|
| 558 |
+
codec_version,
|
| 559 |
+
quant_mode,
|
| 560 |
+
output_format,
|
| 561 |
+
keep_model_loaded,
|
| 562 |
+
offload_mode,
|
| 563 |
+
backend,
|
| 564 |
+
):
|
| 565 |
+
return generate(
|
| 566 |
+
lyrics,
|
| 567 |
+
tags,
|
| 568 |
+
cfg_scale,
|
| 569 |
+
duration_sec,
|
| 570 |
+
temperature,
|
| 571 |
+
topk,
|
| 572 |
+
version,
|
| 573 |
+
codec_version,
|
| 574 |
+
quant_mode,
|
| 575 |
+
output_format,
|
| 576 |
+
keep_model_loaded,
|
| 577 |
+
offload_mode,
|
| 578 |
+
backend,
|
| 579 |
+
False,
|
| 580 |
+
)
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
def generate_accelerated(
|
| 584 |
+
lyrics,
|
| 585 |
+
tags,
|
| 586 |
+
cfg_scale,
|
| 587 |
+
duration_sec,
|
| 588 |
+
temperature,
|
| 589 |
+
topk,
|
| 590 |
+
version,
|
| 591 |
+
codec_version,
|
| 592 |
+
quant_mode,
|
| 593 |
+
output_format,
|
| 594 |
+
keep_model_loaded,
|
| 595 |
+
offload_mode,
|
| 596 |
+
backend,
|
| 597 |
+
):
|
| 598 |
+
return generate(
|
| 599 |
+
lyrics,
|
| 600 |
+
tags,
|
| 601 |
+
cfg_scale,
|
| 602 |
+
duration_sec,
|
| 603 |
+
temperature,
|
| 604 |
+
topk,
|
| 605 |
+
version,
|
| 606 |
+
codec_version,
|
| 607 |
+
quant_mode,
|
| 608 |
+
output_format,
|
| 609 |
+
keep_model_loaded,
|
| 610 |
+
offload_mode,
|
| 611 |
+
backend,
|
| 612 |
+
True,
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
|
| 616 |
@_gpu_guard
|
| 617 |
def transcribe_audio(audio_path, task, max_new_tokens, num_beams, temperature):
|
| 618 |
"""Transcribe or translate lyrics from audio"""
|
|
|
|
| 663 |
offload_mode,
|
| 664 |
backend,
|
| 665 |
chunk_frames,
|
| 666 |
+
use_acceleration,
|
| 667 |
) -> Iterator[Tuple[int, np.ndarray]]:
|
| 668 |
if backend == "exllama_v2":
|
| 669 |
raise gr.Error("ExLlamaV2 backend is not implemented yet.")
|
| 670 |
+
pipe = load_pipeline(MODEL_PATH, version, codec_version, quant_mode, use_acceleration)
|
| 671 |
max_audio_length_ms = int(duration_sec * 1000)
|
| 672 |
for chunk in pipe.stream(
|
| 673 |
{"lyrics": lyrics, "tags": tags},
|
|
|
|
| 703 |
backend,
|
| 704 |
output_format,
|
| 705 |
chunk_frames,
|
| 706 |
+
use_acceleration,
|
| 707 |
):
|
| 708 |
try:
|
| 709 |
+
min_samples = max(0, int(STREAM_MIN_CHUNK_SEC * 48000))
|
| 710 |
buffer = []
|
| 711 |
buffered_samples = 0
|
| 712 |
last_yield_samples = 0
|
| 713 |
print(
|
| 714 |
"stream start:",
|
| 715 |
+
f"min_chunk_sec={STREAM_MIN_CHUNK_SEC}",
|
| 716 |
f"duration_sec={duration_sec}",
|
| 717 |
f"chunk_frames={chunk_frames}",
|
| 718 |
)
|
|
|
|
| 731 |
offload_mode=offload_mode,
|
| 732 |
backend=backend,
|
| 733 |
chunk_frames=chunk_frames,
|
| 734 |
+
use_acceleration=use_acceleration,
|
| 735 |
):
|
| 736 |
chunk_np = chunk_np.astype("float32", copy=False)
|
| 737 |
+
if min_samples <= 0:
|
| 738 |
+
print(f"stream yield: samples={chunk_np.shape[0]}")
|
| 739 |
+
yield sr, chunk_np
|
| 740 |
+
continue
|
| 741 |
buffer.append(chunk_np)
|
| 742 |
buffered_samples += chunk_np.shape[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 743 |
if buffered_samples - last_yield_samples < min_samples:
|
| 744 |
continue
|
| 745 |
full_audio = np.concatenate(buffer)
|
| 746 |
last_yield_samples = buffered_samples
|
| 747 |
print(f"stream yield: samples={full_audio.shape[0]}")
|
| 748 |
yield sr, full_audio
|
| 749 |
+
if min_samples > 0 and buffer:
|
|
|
|
| 750 |
full_audio = np.concatenate(buffer)
|
| 751 |
print(f"stream final yield: samples={full_audio.shape[0]}")
|
| 752 |
yield 48000, full_audio
|
|
|
|
| 754 |
raise gr.Error(f"Streaming error: {str(e)}")
|
| 755 |
|
| 756 |
|
| 757 |
+
def stream_generate_accelerated(
|
| 758 |
+
lyrics,
|
| 759 |
+
tags,
|
| 760 |
+
cfg_scale,
|
| 761 |
+
duration_sec,
|
| 762 |
+
temperature,
|
| 763 |
+
topk,
|
| 764 |
+
version,
|
| 765 |
+
codec_version,
|
| 766 |
+
quant_mode,
|
| 767 |
+
keep_model_loaded,
|
| 768 |
+
offload_mode,
|
| 769 |
+
backend,
|
| 770 |
+
output_format,
|
| 771 |
+
chunk_frames,
|
| 772 |
+
):
|
| 773 |
+
return stream_generate(
|
| 774 |
+
lyrics,
|
| 775 |
+
tags,
|
| 776 |
+
cfg_scale,
|
| 777 |
+
duration_sec,
|
| 778 |
+
temperature,
|
| 779 |
+
topk,
|
| 780 |
+
version,
|
| 781 |
+
codec_version,
|
| 782 |
+
quant_mode,
|
| 783 |
+
keep_model_loaded,
|
| 784 |
+
offload_mode,
|
| 785 |
+
backend,
|
| 786 |
+
output_format,
|
| 787 |
+
chunk_frames,
|
| 788 |
+
True,
|
| 789 |
+
)
|
| 790 |
+
|
| 791 |
+
|
| 792 |
def generate_lyrics(theme, tags, language, api_choice, api_key_input, custom_base_url, custom_model, progress=gr.Progress()):
|
| 793 |
"""Generate lyrics using selected LLM API"""
|
| 794 |
|
|
|
|
| 1055 |
5, 100, value=20, step=1, label="Streaming Chunk Frames"
|
| 1056 |
)
|
| 1057 |
|
| 1058 |
+
gr.Markdown("### 🚀 Generation")
|
| 1059 |
+
|
| 1060 |
+
generation_mode = gr.Radio(
|
| 1061 |
+
choices=["Original (No Acceleration)", "Accelerated"],
|
| 1062 |
+
value="Original (No Acceleration)",
|
| 1063 |
+
label="Generation Mode",
|
| 1064 |
+
)
|
| 1065 |
+
|
| 1066 |
+
speed_submode = gr.Radio(
|
| 1067 |
+
choices=["Standard", "Streaming"],
|
| 1068 |
+
value="Standard",
|
| 1069 |
+
label="Accelerated Options",
|
| 1070 |
+
visible=False,
|
| 1071 |
+
)
|
| 1072 |
+
|
| 1073 |
+
btn_original = gr.Button("🎼 Generate Music (Original)", variant="primary", size="lg", visible=True)
|
| 1074 |
+
btn_accel = gr.Button("🎼 Generate Music (Accelerated)", variant="primary", size="lg", visible=False)
|
| 1075 |
+
btn_stream = gr.Button("🎼 Generate Music (Streaming)", variant="primary", size="lg", visible=False)
|
| 1076 |
+
cancel_stream_btn = gr.Button("Cancel Streaming", variant="secondary", size="lg", visible=False)
|
| 1077 |
cancel_state = gr.State()
|
| 1078 |
|
| 1079 |
with gr.Column():
|
|
|
|
| 1140 |
outputs=[lyrics]
|
| 1141 |
)
|
| 1142 |
|
| 1143 |
+
def update_visibility(gen_mode, spd_mode):
|
| 1144 |
+
if gen_mode == "Original (No Acceleration)":
|
| 1145 |
+
return (
|
| 1146 |
+
gr.update(visible=False), # speed_submode
|
| 1147 |
+
gr.update(visible=True), # btn_original
|
| 1148 |
+
gr.update(visible=False), # btn_accel
|
| 1149 |
+
gr.update(visible=False), # btn_stream
|
| 1150 |
+
gr.update(visible=False), # cancel_stream_btn
|
| 1151 |
+
)
|
| 1152 |
+
show_stream = spd_mode == "Streaming"
|
| 1153 |
+
return (
|
| 1154 |
+
gr.update(visible=True), # speed_submode
|
| 1155 |
+
gr.update(visible=False), # btn_original
|
| 1156 |
+
gr.update(visible=not show_stream), # btn_accel
|
| 1157 |
+
gr.update(visible=show_stream), # btn_stream
|
| 1158 |
+
gr.update(visible=show_stream), # cancel_stream_btn
|
| 1159 |
+
)
|
| 1160 |
+
|
| 1161 |
+
generation_mode.change(
|
| 1162 |
+
fn=update_visibility,
|
| 1163 |
+
inputs=[generation_mode, speed_submode],
|
| 1164 |
+
outputs=[speed_submode, btn_original, btn_accel, btn_stream, cancel_stream_btn],
|
| 1165 |
+
)
|
| 1166 |
+
speed_submode.change(
|
| 1167 |
+
fn=update_visibility,
|
| 1168 |
+
inputs=[generation_mode, speed_submode],
|
| 1169 |
+
outputs=[speed_submode, btn_original, btn_accel, btn_stream, cancel_stream_btn],
|
| 1170 |
+
)
|
| 1171 |
+
|
| 1172 |
+
btn_original.click(
|
| 1173 |
+
fn=generate_original,
|
| 1174 |
+
inputs=[
|
| 1175 |
+
lyrics,
|
| 1176 |
+
tags,
|
| 1177 |
+
cfg_scale,
|
| 1178 |
+
duration,
|
| 1179 |
+
temperature,
|
| 1180 |
+
topk,
|
| 1181 |
+
version,
|
| 1182 |
+
codec_version,
|
| 1183 |
+
quant_mode,
|
| 1184 |
+
output_format,
|
| 1185 |
+
keep_model_loaded,
|
| 1186 |
+
offload_mode,
|
| 1187 |
+
backend,
|
| 1188 |
+
],
|
| 1189 |
+
outputs=[output_audio_file],
|
| 1190 |
+
concurrency_id="gpu_queue",
|
| 1191 |
+
concurrency_limit=GPU_CONCURRENCY_LIMIT,
|
| 1192 |
+
)
|
| 1193 |
+
|
| 1194 |
+
btn_accel.click(
|
| 1195 |
+
fn=generate_accelerated,
|
| 1196 |
inputs=[
|
| 1197 |
lyrics,
|
| 1198 |
tags,
|
|
|
|
| 1213 |
concurrency_limit=GPU_CONCURRENCY_LIMIT,
|
| 1214 |
)
|
| 1215 |
|
| 1216 |
+
stream_event = btn_stream.click(
|
| 1217 |
+
fn=stream_generate_accelerated,
|
| 1218 |
inputs=[
|
| 1219 |
lyrics,
|
| 1220 |
tags,
|