Spaces:
Running
on
Zero
Running
on
Zero
add pocket-tts voice
Browse files- .gitattributes +1 -0
- app.py +163 -50
- requirements.txt +10 -1
- voice.wav +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.wav filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
|
@@ -13,11 +13,16 @@ from transformers import AutoTokenizer
|
|
| 13 |
from ddgs import DDGS
|
| 14 |
import spaces # Import spaces early to enable ZeroGPU support
|
| 15 |
from torch.utils._pytree import tree_map
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
# Global event to signal cancellation from the UI thread to the generation thread
|
| 18 |
cancel_event = threading.Event()
|
| 19 |
|
| 20 |
-
access_token=os.environ
|
| 21 |
|
| 22 |
# Optional: Disable GPU visibility if you wish to force CPU usage
|
| 23 |
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
|
@@ -317,6 +322,74 @@ MODELS = {
|
|
| 317 |
# Global cache for pipelines to avoid re-loading.
|
| 318 |
PIPELINES = {}
|
| 319 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
def load_pipeline(model_name):
|
| 321 |
"""
|
| 322 |
Load and cache a transformers pipeline for text generation.
|
|
@@ -384,26 +457,27 @@ def format_conversation(history, system_prompt, tokenizer):
|
|
| 384 |
prompt += "Assistant: "
|
| 385 |
return prompt
|
| 386 |
|
| 387 |
-
def get_duration(user_msg, chat_history, system_prompt, enable_search, max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty, search_timeout):
|
| 388 |
# Get model size from the MODELS dict (more reliable than string parsing)
|
| 389 |
model_size = MODELS[model_name].get("params_b", 4.0) # Default to 4B if not found
|
| 390 |
-
|
| 391 |
# Only use AOT for models >= 2B parameters
|
| 392 |
use_aot = model_size >= 2
|
| 393 |
-
|
| 394 |
# Adjusted for H200 performance: faster inference, quicker compilation
|
| 395 |
base_duration = 20 if not use_aot else 40 # Reduced base times
|
| 396 |
token_duration = max_tokens * 0.005 # ~200 tokens/second average on H200
|
| 397 |
search_duration = 10 if enable_search else 0 # Reduced search time
|
| 398 |
aot_compilation_buffer = 20 if use_aot else 0 # Faster compilation on H200
|
| 399 |
-
|
| 400 |
-
|
|
|
|
| 401 |
|
| 402 |
@spaces.GPU(duration=get_duration)
|
| 403 |
def chat_response(user_msg, chat_history, system_prompt,
|
| 404 |
enable_search, max_results, max_chars,
|
| 405 |
model_name, max_tokens, temperature,
|
| 406 |
-
top_k, top_p, repeat_penalty, search_timeout):
|
| 407 |
"""
|
| 408 |
Generates streaming chat responses, optionally with background web search.
|
| 409 |
This version includes cancellation support.
|
|
@@ -513,7 +587,7 @@ def chat_response(user_msg, chat_history, system_prompt,
|
|
| 513 |
assistant_message_started = False
|
| 514 |
|
| 515 |
# First yield contains the user message
|
| 516 |
-
yield history, debug
|
| 517 |
|
| 518 |
# Stream tokens
|
| 519 |
for chunk in streamer:
|
|
@@ -521,7 +595,7 @@ def chat_response(user_msg, chat_history, system_prompt,
|
|
| 521 |
if cancel_event.is_set():
|
| 522 |
if assistant_message_started and history and history[-1]['role'] == 'assistant':
|
| 523 |
history[-1]['content'] += " [Generation Canceled]"
|
| 524 |
-
yield history, debug
|
| 525 |
break
|
| 526 |
|
| 527 |
text = chunk
|
|
@@ -541,7 +615,7 @@ def chat_response(user_msg, chat_history, system_prompt,
|
|
| 541 |
history.append({'role': 'assistant', 'content': answer_buf})
|
| 542 |
else:
|
| 543 |
history[-1]['content'] = thought_buf
|
| 544 |
-
yield history, debug
|
| 545 |
continue
|
| 546 |
|
| 547 |
if in_thought:
|
|
@@ -554,7 +628,7 @@ def chat_response(user_msg, chat_history, system_prompt,
|
|
| 554 |
history.append({'role': 'assistant', 'content': answer_buf})
|
| 555 |
else:
|
| 556 |
history[-1]['content'] = thought_buf
|
| 557 |
-
yield history, debug
|
| 558 |
continue
|
| 559 |
|
| 560 |
# Stream answer
|
|
@@ -564,10 +638,16 @@ def chat_response(user_msg, chat_history, system_prompt,
|
|
| 564 |
|
| 565 |
answer_buf += text
|
| 566 |
history[-1]['content'] = answer_buf.strip()
|
| 567 |
-
yield history, debug
|
| 568 |
|
| 569 |
gen_thread.join()
|
| 570 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 571 |
except GeneratorExit:
|
| 572 |
# Handle cancellation gracefully
|
| 573 |
print("Chat response cancelled.")
|
|
@@ -575,7 +655,7 @@ def chat_response(user_msg, chat_history, system_prompt,
|
|
| 575 |
return
|
| 576 |
except Exception as e:
|
| 577 |
history.append({'role': 'assistant', 'content': f"Error: {e}"})
|
| 578 |
-
yield history, debug
|
| 579 |
finally:
|
| 580 |
gc.collect()
|
| 581 |
|
|
@@ -583,39 +663,40 @@ def chat_response(user_msg, chat_history, system_prompt,
|
|
| 583 |
def update_default_prompt(enable_search):
|
| 584 |
return f"You are a helpful assistant."
|
| 585 |
|
| 586 |
-
def update_duration_estimate(model_name, enable_search, max_results, max_chars, max_tokens, search_timeout):
|
| 587 |
"""Calculate and format the estimated GPU duration for current settings."""
|
| 588 |
try:
|
| 589 |
dummy_msg, dummy_history, dummy_system_prompt = "", [], ""
|
| 590 |
-
duration = get_duration(dummy_msg, dummy_history, dummy_system_prompt,
|
| 591 |
-
enable_search, max_results, max_chars, model_name,
|
| 592 |
-
max_tokens, 0.7, 40, 0.9, 1.2, search_timeout)
|
| 593 |
model_size = MODELS[model_name].get("params_b", 4.0)
|
| 594 |
return (f"โฑ๏ธ **Estimated GPU Time: {duration:.1f} seconds**\n\n"
|
| 595 |
f"๐ **Model Size:** {model_size:.1f}B parameters\n"
|
| 596 |
-
f"๐ **Web Search:** {'Enabled' if enable_search else 'Disabled'}"
|
|
|
|
| 597 |
except Exception as e:
|
| 598 |
return f"โ ๏ธ Error calculating estimate: {e}"
|
| 599 |
|
| 600 |
# ------------------------------
|
| 601 |
# Gradio UI
|
| 602 |
# ------------------------------
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
) as demo:
|
| 619 |
# Header
|
| 620 |
gr.Markdown("""
|
| 621 |
# ๐ง ZeroGPU LLM Inference
|
|
@@ -639,6 +720,11 @@ with gr.Blocks(
|
|
| 639 |
value=False,
|
| 640 |
info="Augment responses with real-time web data"
|
| 641 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 642 |
sys_prompt = gr.Textbox(
|
| 643 |
label="๐ System Prompt",
|
| 644 |
lines=3,
|
|
@@ -648,7 +734,7 @@ with gr.Blocks(
|
|
| 648 |
|
| 649 |
# Duration Estimate
|
| 650 |
duration_display = gr.Markdown(
|
| 651 |
-
value=update_duration_estimate("Qwen3-1.7B", False, 4, 50, 1024, 5.0),
|
| 652 |
elem_classes="duration-estimate"
|
| 653 |
)
|
| 654 |
|
|
@@ -706,14 +792,22 @@ with gr.Blocks(
|
|
| 706 |
# Right Panel - Chat Interface
|
| 707 |
with gr.Column(scale=7):
|
| 708 |
chat = gr.Chatbot(
|
| 709 |
-
type="messages",
|
| 710 |
height=600,
|
| 711 |
label="๐ฌ Conversation",
|
| 712 |
-
|
| 713 |
avatar_images=(None, "๐ค"),
|
| 714 |
-
|
| 715 |
)
|
| 716 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 717 |
# Input Area
|
| 718 |
with gr.Row():
|
| 719 |
txt = gr.Textbox(
|
|
@@ -758,9 +852,9 @@ with gr.Blocks(
|
|
| 758 |
# --- Event Listeners ---
|
| 759 |
|
| 760 |
# Group all inputs for cleaner event handling
|
| 761 |
-
chat_inputs = [txt, chat, sys_prompt, search_chk, mr, mc, model_dd, max_tok, temp, k, p, rp, st]
|
| 762 |
# Group all UI components that can be updated.
|
| 763 |
-
ui_components = [chat, dbg, txt, submit_btn, cancel_btn]
|
| 764 |
|
| 765 |
def submit_and_manage_ui(user_msg, chat_history, *args):
|
| 766 |
"""
|
|
@@ -769,10 +863,19 @@ with gr.Blocks(
|
|
| 769 |
"""
|
| 770 |
if not user_msg.strip():
|
| 771 |
# If the message is empty, do nothing.
|
| 772 |
-
|
| 773 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 774 |
return
|
| 775 |
|
|
|
|
|
|
|
|
|
|
| 776 |
# 1. Update UI to "generating" state.
|
| 777 |
# Crucially, we do NOT update the `chat` component here, as the backend
|
| 778 |
# will provide the correctly formatted history in the first response chunk.
|
|
@@ -780,6 +883,7 @@ with gr.Blocks(
|
|
| 780 |
txt: gr.update(value="", interactive=False),
|
| 781 |
submit_btn: gr.update(interactive=False),
|
| 782 |
cancel_btn: gr.update(visible=True),
|
|
|
|
| 783 |
}
|
| 784 |
|
| 785 |
cancelled = False
|
|
@@ -787,10 +891,18 @@ with gr.Blocks(
|
|
| 787 |
# 2. Call the backend and stream updates
|
| 788 |
backend_args = [user_msg, chat_history] + list(args)
|
| 789 |
for response_chunk in chat_response(*backend_args):
|
| 790 |
-
|
| 791 |
-
|
| 792 |
-
|
|
|
|
|
|
|
| 793 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 794 |
except GeneratorExit:
|
| 795 |
# Mark as cancelled and re-raise to prevent "generator ignored GeneratorExit"
|
| 796 |
cancelled = True
|
|
@@ -827,6 +939,7 @@ with gr.Blocks(
|
|
| 827 |
txt: gr.update(interactive=True),
|
| 828 |
submit_btn: gr.update(interactive=True),
|
| 829 |
cancel_btn: gr.update(visible=False),
|
|
|
|
| 830 |
}
|
| 831 |
|
| 832 |
# Event for submitting text via Enter key or Submit button
|
|
@@ -852,7 +965,7 @@ with gr.Blocks(
|
|
| 852 |
)
|
| 853 |
|
| 854 |
# Listeners for updating the duration estimate
|
| 855 |
-
duration_inputs = [model_dd, search_chk, mr, mc, max_tok, st]
|
| 856 |
for component in duration_inputs:
|
| 857 |
component.change(fn=update_duration_estimate, inputs=duration_inputs, outputs=duration_display)
|
| 858 |
|
|
@@ -867,6 +980,6 @@ with gr.Blocks(
|
|
| 867 |
)
|
| 868 |
|
| 869 |
# Clear chat action
|
| 870 |
-
clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])
|
| 871 |
|
| 872 |
-
demo.launch()
|
|
|
|
| 13 |
from ddgs import DDGS
|
| 14 |
import spaces # Import spaces early to enable ZeroGPU support
|
| 15 |
from torch.utils._pytree import tree_map
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
# Add pocket-tts to path for TTS functionality
|
| 19 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'pocket-tts'))
|
| 20 |
+
from pocket_tts import TTSModel
|
| 21 |
|
| 22 |
# Global event to signal cancellation from the UI thread to the generation thread
|
| 23 |
cancel_event = threading.Event()
|
| 24 |
|
| 25 |
+
access_token = os.environ.get('HF_TOKEN')
|
| 26 |
|
| 27 |
# Optional: Disable GPU visibility if you wish to force CPU usage
|
| 28 |
# os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
|
|
|
| 322 |
# Global cache for pipelines to avoid re-loading.
|
| 323 |
PIPELINES = {}
|
| 324 |
|
| 325 |
+
# ------------------------------
|
| 326 |
+
# TTS Configuration
|
| 327 |
+
# ------------------------------
|
| 328 |
+
TTS_VOICE_FILE = "./voice.wav" # Path to custom voice file for voice cloning
|
| 329 |
+
|
| 330 |
+
# Global TTS model cache
|
| 331 |
+
TTS_MODEL = None
|
| 332 |
+
TTS_VOICE_STATE = None
|
| 333 |
+
|
| 334 |
+
def load_tts_model():
|
| 335 |
+
"""Load and cache the TTS model."""
|
| 336 |
+
global TTS_MODEL
|
| 337 |
+
if TTS_MODEL is None:
|
| 338 |
+
TTS_MODEL = TTSModel.load_model()
|
| 339 |
+
return TTS_MODEL
|
| 340 |
+
|
| 341 |
+
def get_voice_state():
|
| 342 |
+
"""Get cached voice state from the custom voice file."""
|
| 343 |
+
global TTS_VOICE_STATE
|
| 344 |
+
if TTS_VOICE_STATE is None:
|
| 345 |
+
tts_model = load_tts_model()
|
| 346 |
+
TTS_VOICE_STATE = tts_model.get_state_for_audio_prompt(TTS_VOICE_FILE)
|
| 347 |
+
return TTS_VOICE_STATE
|
| 348 |
+
|
| 349 |
+
def clean_text_for_tts(text: str) -> str:
|
| 350 |
+
"""Clean text for better TTS output by removing code blocks, markdown, and thinking tags."""
|
| 351 |
+
# Remove thinking blocks (Qwen3 models)
|
| 352 |
+
text = re.sub(r'<think>[\s\S]*?</think>', '', text)
|
| 353 |
+
# Remove markdown code blocks
|
| 354 |
+
text = re.sub(r'```[\s\S]*?```', '', text)
|
| 355 |
+
# Remove inline code
|
| 356 |
+
text = re.sub(r'`[^`]+`', '', text)
|
| 357 |
+
# Remove citation markers
|
| 358 |
+
text = re.sub(r'\[citation:\d+\]', '', text)
|
| 359 |
+
# Remove markdown headers
|
| 360 |
+
text = re.sub(r'#{1,6}\s+', '', text)
|
| 361 |
+
# Remove markdown bold/italic
|
| 362 |
+
text = re.sub(r'\*{1,2}([^*]+)\*{1,2}', r'\1', text)
|
| 363 |
+
# Remove markdown links, keep text
|
| 364 |
+
text = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', text)
|
| 365 |
+
# Remove multiple spaces/newlines
|
| 366 |
+
text = re.sub(r'\s+', ' ', text)
|
| 367 |
+
return text.strip()
|
| 368 |
+
|
| 369 |
+
def generate_tts_audio(text: str) -> tuple[int, np.ndarray] | None:
|
| 370 |
+
"""
|
| 371 |
+
Generate TTS audio from text using the custom voice.
|
| 372 |
+
|
| 373 |
+
Args:
|
| 374 |
+
text: The text to convert to speech
|
| 375 |
+
|
| 376 |
+
Returns:
|
| 377 |
+
Tuple of (sample_rate, audio_array) or None if TTS fails
|
| 378 |
+
"""
|
| 379 |
+
try:
|
| 380 |
+
# Clean the text for better TTS
|
| 381 |
+
clean_text = clean_text_for_tts(text)
|
| 382 |
+
if not clean_text:
|
| 383 |
+
return None
|
| 384 |
+
|
| 385 |
+
tts_model = load_tts_model()
|
| 386 |
+
voice_state = get_voice_state()
|
| 387 |
+
audio = tts_model.generate_audio(voice_state, clean_text)
|
| 388 |
+
return (tts_model.sample_rate, audio.numpy())
|
| 389 |
+
except Exception as e:
|
| 390 |
+
print(f"TTS generation error: {e}")
|
| 391 |
+
return None
|
| 392 |
+
|
| 393 |
def load_pipeline(model_name):
|
| 394 |
"""
|
| 395 |
Load and cache a transformers pipeline for text generation.
|
|
|
|
| 457 |
prompt += "Assistant: "
|
| 458 |
return prompt
|
| 459 |
|
| 460 |
+
def get_duration(user_msg, chat_history, system_prompt, enable_search, max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty, search_timeout, enable_tts):
|
| 461 |
# Get model size from the MODELS dict (more reliable than string parsing)
|
| 462 |
model_size = MODELS[model_name].get("params_b", 4.0) # Default to 4B if not found
|
| 463 |
+
|
| 464 |
# Only use AOT for models >= 2B parameters
|
| 465 |
use_aot = model_size >= 2
|
| 466 |
+
|
| 467 |
# Adjusted for H200 performance: faster inference, quicker compilation
|
| 468 |
base_duration = 20 if not use_aot else 40 # Reduced base times
|
| 469 |
token_duration = max_tokens * 0.005 # ~200 tokens/second average on H200
|
| 470 |
search_duration = 10 if enable_search else 0 # Reduced search time
|
| 471 |
aot_compilation_buffer = 20 if use_aot else 0 # Faster compilation on H200
|
| 472 |
+
tts_duration = 15 if enable_tts else 0 # TTS generation time
|
| 473 |
+
|
| 474 |
+
return base_duration + token_duration + search_duration + aot_compilation_buffer + tts_duration
|
| 475 |
|
| 476 |
@spaces.GPU(duration=get_duration)
|
| 477 |
def chat_response(user_msg, chat_history, system_prompt,
|
| 478 |
enable_search, max_results, max_chars,
|
| 479 |
model_name, max_tokens, temperature,
|
| 480 |
+
top_k, top_p, repeat_penalty, search_timeout, enable_tts):
|
| 481 |
"""
|
| 482 |
Generates streaming chat responses, optionally with background web search.
|
| 483 |
This version includes cancellation support.
|
|
|
|
| 587 |
assistant_message_started = False
|
| 588 |
|
| 589 |
# First yield contains the user message
|
| 590 |
+
yield history, debug, None
|
| 591 |
|
| 592 |
# Stream tokens
|
| 593 |
for chunk in streamer:
|
|
|
|
| 595 |
if cancel_event.is_set():
|
| 596 |
if assistant_message_started and history and history[-1]['role'] == 'assistant':
|
| 597 |
history[-1]['content'] += " [Generation Canceled]"
|
| 598 |
+
yield history, debug, None
|
| 599 |
break
|
| 600 |
|
| 601 |
text = chunk
|
|
|
|
| 615 |
history.append({'role': 'assistant', 'content': answer_buf})
|
| 616 |
else:
|
| 617 |
history[-1]['content'] = thought_buf
|
| 618 |
+
yield history, debug, None
|
| 619 |
continue
|
| 620 |
|
| 621 |
if in_thought:
|
|
|
|
| 628 |
history.append({'role': 'assistant', 'content': answer_buf})
|
| 629 |
else:
|
| 630 |
history[-1]['content'] = thought_buf
|
| 631 |
+
yield history, debug, None
|
| 632 |
continue
|
| 633 |
|
| 634 |
# Stream answer
|
|
|
|
| 638 |
|
| 639 |
answer_buf += text
|
| 640 |
history[-1]['content'] = answer_buf.strip()
|
| 641 |
+
yield history, debug, None
|
| 642 |
|
| 643 |
gen_thread.join()
|
| 644 |
+
|
| 645 |
+
# Generate TTS audio if enabled
|
| 646 |
+
tts_audio = None
|
| 647 |
+
if enable_tts and answer_buf.strip():
|
| 648 |
+
tts_audio = generate_tts_audio(answer_buf)
|
| 649 |
+
|
| 650 |
+
yield history, debug + prompt_debug, tts_audio
|
| 651 |
except GeneratorExit:
|
| 652 |
# Handle cancellation gracefully
|
| 653 |
print("Chat response cancelled.")
|
|
|
|
| 655 |
return
|
| 656 |
except Exception as e:
|
| 657 |
history.append({'role': 'assistant', 'content': f"Error: {e}"})
|
| 658 |
+
yield history, debug, None
|
| 659 |
finally:
|
| 660 |
gc.collect()
|
| 661 |
|
|
|
|
| 663 |
def update_default_prompt(enable_search):
|
| 664 |
return f"You are a helpful assistant."
|
| 665 |
|
| 666 |
+
def update_duration_estimate(model_name, enable_search, max_results, max_chars, max_tokens, search_timeout, enable_tts):
|
| 667 |
"""Calculate and format the estimated GPU duration for current settings."""
|
| 668 |
try:
|
| 669 |
dummy_msg, dummy_history, dummy_system_prompt = "", [], ""
|
| 670 |
+
duration = get_duration(dummy_msg, dummy_history, dummy_system_prompt,
|
| 671 |
+
enable_search, max_results, max_chars, model_name,
|
| 672 |
+
max_tokens, 0.7, 40, 0.9, 1.2, search_timeout, enable_tts)
|
| 673 |
model_size = MODELS[model_name].get("params_b", 4.0)
|
| 674 |
return (f"โฑ๏ธ **Estimated GPU Time: {duration:.1f} seconds**\n\n"
|
| 675 |
f"๐ **Model Size:** {model_size:.1f}B parameters\n"
|
| 676 |
+
f"๐ **Web Search:** {'Enabled' if enable_search else 'Disabled'}\n"
|
| 677 |
+
f"๐ **TTS:** {'Enabled' if enable_tts else 'Disabled'}")
|
| 678 |
except Exception as e:
|
| 679 |
return f"โ ๏ธ Error calculating estimate: {e}"
|
| 680 |
|
| 681 |
# ------------------------------
|
| 682 |
# Gradio UI
|
| 683 |
# ------------------------------
|
| 684 |
+
CUSTOM_THEME = gr.themes.Soft(
|
| 685 |
+
primary_hue="indigo",
|
| 686 |
+
secondary_hue="purple",
|
| 687 |
+
neutral_hue="slate",
|
| 688 |
+
radius_size="lg",
|
| 689 |
+
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
|
| 690 |
+
)
|
| 691 |
+
|
| 692 |
+
CUSTOM_CSS = """
|
| 693 |
+
.duration-estimate { background: linear-gradient(135deg, #667eea15 0%, #764ba215 100%); border-left: 4px solid #667eea; padding: 12px; border-radius: 8px; margin: 16px 0; }
|
| 694 |
+
.chatbot { border-radius: 12px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1); }
|
| 695 |
+
button.primary { font-weight: 600; }
|
| 696 |
+
.gradio-accordion { margin-bottom: 12px; }
|
| 697 |
+
"""
|
| 698 |
+
|
| 699 |
+
with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
|
| 700 |
# Header
|
| 701 |
gr.Markdown("""
|
| 702 |
# ๐ง ZeroGPU LLM Inference
|
|
|
|
| 720 |
value=False,
|
| 721 |
info="Augment responses with real-time web data"
|
| 722 |
)
|
| 723 |
+
tts_chk = gr.Checkbox(
|
| 724 |
+
label="๐ Enable Text-to-Speech",
|
| 725 |
+
value=False,
|
| 726 |
+
info="Convert responses to speech using voice cloning"
|
| 727 |
+
)
|
| 728 |
sys_prompt = gr.Textbox(
|
| 729 |
label="๐ System Prompt",
|
| 730 |
lines=3,
|
|
|
|
| 734 |
|
| 735 |
# Duration Estimate
|
| 736 |
duration_display = gr.Markdown(
|
| 737 |
+
value=update_duration_estimate("Qwen3-1.7B", False, 4, 50, 1024, 5.0, False),
|
| 738 |
elem_classes="duration-estimate"
|
| 739 |
)
|
| 740 |
|
|
|
|
| 792 |
# Right Panel - Chat Interface
|
| 793 |
with gr.Column(scale=7):
|
| 794 |
chat = gr.Chatbot(
|
|
|
|
| 795 |
height=600,
|
| 796 |
label="๐ฌ Conversation",
|
| 797 |
+
buttons=["copy"],
|
| 798 |
avatar_images=(None, "๐ค"),
|
| 799 |
+
layout="bubble"
|
| 800 |
)
|
| 801 |
+
|
| 802 |
+
# TTS Audio Output
|
| 803 |
+
tts_audio_output = gr.Audio(
|
| 804 |
+
label="๐ Generated Speech",
|
| 805 |
+
type="numpy",
|
| 806 |
+
autoplay=True,
|
| 807 |
+
visible=False,
|
| 808 |
+
elem_id="tts-audio"
|
| 809 |
+
)
|
| 810 |
+
|
| 811 |
# Input Area
|
| 812 |
with gr.Row():
|
| 813 |
txt = gr.Textbox(
|
|
|
|
| 852 |
# --- Event Listeners ---
|
| 853 |
|
| 854 |
# Group all inputs for cleaner event handling
|
| 855 |
+
chat_inputs = [txt, chat, sys_prompt, search_chk, mr, mc, model_dd, max_tok, temp, k, p, rp, st, tts_chk]
|
| 856 |
# Group all UI components that can be updated.
|
| 857 |
+
ui_components = [chat, dbg, txt, submit_btn, cancel_btn, tts_audio_output]
|
| 858 |
|
| 859 |
def submit_and_manage_ui(user_msg, chat_history, *args):
|
| 860 |
"""
|
|
|
|
| 863 |
"""
|
| 864 |
if not user_msg.strip():
|
| 865 |
# If the message is empty, do nothing.
|
| 866 |
+
yield {
|
| 867 |
+
chat: gr.update(),
|
| 868 |
+
dbg: gr.update(),
|
| 869 |
+
txt: gr.update(),
|
| 870 |
+
submit_btn: gr.update(),
|
| 871 |
+
cancel_btn: gr.update(),
|
| 872 |
+
tts_audio_output: gr.update(),
|
| 873 |
+
}
|
| 874 |
return
|
| 875 |
|
| 876 |
+
# Check if TTS is enabled (last argument)
|
| 877 |
+
tts_enabled = args[-1] if args else False
|
| 878 |
+
|
| 879 |
# 1. Update UI to "generating" state.
|
| 880 |
# Crucially, we do NOT update the `chat` component here, as the backend
|
| 881 |
# will provide the correctly formatted history in the first response chunk.
|
|
|
|
| 883 |
txt: gr.update(value="", interactive=False),
|
| 884 |
submit_btn: gr.update(interactive=False),
|
| 885 |
cancel_btn: gr.update(visible=True),
|
| 886 |
+
tts_audio_output: gr.update(visible=False, value=None), # Hide audio during generation
|
| 887 |
}
|
| 888 |
|
| 889 |
cancelled = False
|
|
|
|
| 891 |
# 2. Call the backend and stream updates
|
| 892 |
backend_args = [user_msg, chat_history] + list(args)
|
| 893 |
for response_chunk in chat_response(*backend_args):
|
| 894 |
+
history, debug, audio = response_chunk[0], response_chunk[1], response_chunk[2] if len(response_chunk) > 2 else None
|
| 895 |
+
|
| 896 |
+
update_dict = {
|
| 897 |
+
chat: history,
|
| 898 |
+
dbg: debug,
|
| 899 |
}
|
| 900 |
+
|
| 901 |
+
# Show audio output when audio is generated (final yield with TTS)
|
| 902 |
+
if audio is not None:
|
| 903 |
+
update_dict[tts_audio_output] = gr.update(visible=True, value=audio)
|
| 904 |
+
|
| 905 |
+
yield update_dict
|
| 906 |
except GeneratorExit:
|
| 907 |
# Mark as cancelled and re-raise to prevent "generator ignored GeneratorExit"
|
| 908 |
cancelled = True
|
|
|
|
| 939 |
txt: gr.update(interactive=True),
|
| 940 |
submit_btn: gr.update(interactive=True),
|
| 941 |
cancel_btn: gr.update(visible=False),
|
| 942 |
+
tts_audio_output: gr.update(visible=False, value=None),
|
| 943 |
}
|
| 944 |
|
| 945 |
# Event for submitting text via Enter key or Submit button
|
|
|
|
| 965 |
)
|
| 966 |
|
| 967 |
# Listeners for updating the duration estimate
|
| 968 |
+
duration_inputs = [model_dd, search_chk, mr, mc, max_tok, st, tts_chk]
|
| 969 |
for component in duration_inputs:
|
| 970 |
component.change(fn=update_duration_estimate, inputs=duration_inputs, outputs=duration_display)
|
| 971 |
|
|
|
|
| 980 |
)
|
| 981 |
|
| 982 |
# Clear chat action
|
| 983 |
+
clr.click(fn=lambda: ([], "", "", gr.update(visible=False, value=None)), outputs=[chat, txt, dbg, tts_audio_output])
|
| 984 |
|
| 985 |
+
demo.launch(theme=CUSTOM_THEME, css=CUSTOM_CSS)
|
requirements.txt
CHANGED
|
@@ -9,4 +9,13 @@ sentencepiece
|
|
| 9 |
accelerate
|
| 10 |
autoawq
|
| 11 |
timm
|
| 12 |
-
compressed-tensors
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
accelerate
|
| 10 |
autoawq
|
| 11 |
timm
|
| 12 |
+
compressed-tensors
|
| 13 |
+
|
| 14 |
+
# pocket-tts dependencies
|
| 15 |
+
numpy>=2
|
| 16 |
+
pydantic>=2
|
| 17 |
+
beartype>=0.22.5
|
| 18 |
+
safetensors>=0.4.0
|
| 19 |
+
scipy>=1.5.0
|
| 20 |
+
einops>=0.4.0
|
| 21 |
+
huggingface_hub>=0.10
|
voice.wav
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0ba0e2f61e1e03c63791bd946c935b4dbc3b1a0e2b38f960b52ba746f2ca7e30
|
| 3 |
+
size 337028
|