Spaces:
Running
on
Zero
Running
on
Zero
update ui for mobile
Browse files
app.py
CHANGED
|
@@ -462,7 +462,7 @@ def format_conversation(history, system_prompt, tokenizer, enable_thinking=False
|
|
| 462 |
prompt += "Assistant: "
|
| 463 |
return prompt
|
| 464 |
|
| 465 |
-
def get_duration(user_msg, chat_history, system_prompt, enable_search, max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty, search_timeout,
|
| 466 |
# Get model size from the MODELS dict (more reliable than string parsing)
|
| 467 |
model_size = MODELS[model_name].get("params_b", 4.0) # Default to 4B if not found
|
| 468 |
|
|
@@ -474,18 +474,17 @@ def get_duration(user_msg, chat_history, system_prompt, enable_search, max_resul
|
|
| 474 |
token_duration = max_tokens * 0.005 # ~200 tokens/second average on H200
|
| 475 |
search_duration = 10 if enable_search else 0 # Reduced search time
|
| 476 |
aot_compilation_buffer = 20 if use_aot else 0 # Faster compilation on H200
|
| 477 |
-
tts_duration = 15 if enable_tts else 0 # TTS generation time
|
| 478 |
|
| 479 |
-
return base_duration + token_duration + search_duration + aot_compilation_buffer
|
| 480 |
|
| 481 |
@spaces.GPU(duration=get_duration)
|
| 482 |
def chat_response(user_msg, chat_history, system_prompt,
|
| 483 |
enable_search, max_results, max_chars,
|
| 484 |
model_name, max_tokens, temperature,
|
| 485 |
-
top_k, top_p, repeat_penalty, search_timeout,
|
| 486 |
"""
|
| 487 |
Generates streaming chat responses, optionally with background web search.
|
| 488 |
-
|
| 489 |
"""
|
| 490 |
# Clear the cancellation event at the start of a new generation
|
| 491 |
cancel_event.clear()
|
|
@@ -592,7 +591,7 @@ def chat_response(user_msg, chat_history, system_prompt,
|
|
| 592 |
assistant_message_started = False
|
| 593 |
|
| 594 |
# First yield contains the user message
|
| 595 |
-
yield history, debug
|
| 596 |
|
| 597 |
# Stream tokens
|
| 598 |
for chunk in streamer:
|
|
@@ -600,7 +599,7 @@ def chat_response(user_msg, chat_history, system_prompt,
|
|
| 600 |
if cancel_event.is_set():
|
| 601 |
if assistant_message_started and history and history[-1]['role'] == 'assistant':
|
| 602 |
history[-1]['content'] += " [Generation Canceled]"
|
| 603 |
-
yield history, debug
|
| 604 |
break
|
| 605 |
|
| 606 |
text = chunk
|
|
@@ -620,7 +619,7 @@ def chat_response(user_msg, chat_history, system_prompt,
|
|
| 620 |
history.append({'role': 'assistant', 'content': answer_buf})
|
| 621 |
else:
|
| 622 |
history[-1]['content'] = thought_buf
|
| 623 |
-
yield history, debug
|
| 624 |
continue
|
| 625 |
|
| 626 |
if in_thought:
|
|
@@ -633,7 +632,7 @@ def chat_response(user_msg, chat_history, system_prompt,
|
|
| 633 |
history.append({'role': 'assistant', 'content': answer_buf})
|
| 634 |
else:
|
| 635 |
history[-1]['content'] = thought_buf
|
| 636 |
-
yield history, debug
|
| 637 |
continue
|
| 638 |
|
| 639 |
# Stream answer
|
|
@@ -643,16 +642,11 @@ def chat_response(user_msg, chat_history, system_prompt,
|
|
| 643 |
|
| 644 |
answer_buf += text
|
| 645 |
history[-1]['content'] = answer_buf.strip()
|
| 646 |
-
yield history, debug
|
| 647 |
|
| 648 |
gen_thread.join()
|
| 649 |
|
| 650 |
-
|
| 651 |
-
tts_audio = None
|
| 652 |
-
if enable_tts and answer_buf.strip():
|
| 653 |
-
tts_audio = generate_tts_audio(answer_buf)
|
| 654 |
-
|
| 655 |
-
yield history, debug + prompt_debug, tts_audio
|
| 656 |
except GeneratorExit:
|
| 657 |
# Handle cancellation gracefully
|
| 658 |
print("Chat response cancelled.")
|
|
@@ -660,7 +654,7 @@ def chat_response(user_msg, chat_history, system_prompt,
|
|
| 660 |
return
|
| 661 |
except Exception as e:
|
| 662 |
history.append({'role': 'assistant', 'content': f"Error: {e}"})
|
| 663 |
-
yield history, debug
|
| 664 |
finally:
|
| 665 |
gc.collect()
|
| 666 |
|
|
@@ -668,22 +662,63 @@ def chat_response(user_msg, chat_history, system_prompt,
|
|
| 668 |
def update_default_prompt(enable_search):
|
| 669 |
return f"You are a helpful assistant. Don't use emojis in your response. Keep replies short to a maximum of three sentences."
|
| 670 |
|
| 671 |
-
def update_duration_estimate(model_name, enable_search, max_results, max_chars, max_tokens, search_timeout,
|
| 672 |
"""Calculate and format the estimated GPU duration for current settings."""
|
| 673 |
try:
|
| 674 |
dummy_msg, dummy_history, dummy_system_prompt = "", [], ""
|
| 675 |
duration = get_duration(dummy_msg, dummy_history, dummy_system_prompt,
|
| 676 |
enable_search, max_results, max_chars, model_name,
|
| 677 |
-
max_tokens, 0.7, 40, 0.9, 1.2, search_timeout,
|
| 678 |
model_size = MODELS[model_name].get("params_b", 4.0)
|
| 679 |
return (f"⏱️ **Estimated GPU Time: {duration:.1f} seconds**\n\n"
|
| 680 |
f"📊 **Model Size:** {model_size:.1f}B parameters\n"
|
| 681 |
f"🔍 **Web Search:** {'Enabled' if enable_search else 'Disabled'}\n"
|
| 682 |
-
f"🔊 **TTS:** {'Enabled' if enable_tts else 'Disabled'}\n"
|
| 683 |
f"💭 **Thinking:** {'Enabled' if enable_thinking else 'Disabled'}")
|
| 684 |
except Exception as e:
|
| 685 |
return f"⚠️ Error calculating estimate: {e}"
|
| 686 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 687 |
# ------------------------------
|
| 688 |
# Gradio UI
|
| 689 |
# ------------------------------
|
|
@@ -700,6 +735,32 @@ CUSTOM_CSS = """
|
|
| 700 |
.chatbot { border-radius: 12px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1); }
|
| 701 |
button.primary { font-weight: 600; }
|
| 702 |
.gradio-accordion { margin-bottom: 12px; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 703 |
"""
|
| 704 |
|
| 705 |
with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
|
|
@@ -711,7 +772,7 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
|
|
| 711 |
height=500,
|
| 712 |
label="💬 Conversation",
|
| 713 |
buttons=["copy"],
|
| 714 |
-
avatar_images=(None, "
|
| 715 |
layout="bubble"
|
| 716 |
)
|
| 717 |
|
|
@@ -725,7 +786,7 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
|
|
| 725 |
)
|
| 726 |
|
| 727 |
# Input Area
|
| 728 |
-
with gr.Row():
|
| 729 |
txt = gr.Textbox(
|
| 730 |
placeholder="💭 Type your message here... (Press Enter to send)",
|
| 731 |
scale=9,
|
|
@@ -774,7 +835,7 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
|
|
| 774 |
|
| 775 |
# Duration Estimate
|
| 776 |
duration_display = gr.Markdown(
|
| 777 |
-
value=update_duration_estimate("Qwen3-0.6B", False, 4, 50, 1024, 5.0,
|
| 778 |
elem_classes="duration-estimate"
|
| 779 |
)
|
| 780 |
|
|
@@ -802,81 +863,65 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
|
|
| 802 |
|
| 803 |
# --- Event Listeners ---
|
| 804 |
|
| 805 |
-
# Group
|
| 806 |
-
chat_inputs = [txt, chat, sys_prompt, search_chk, mr, mc, model_dd, max_tok, temp, k, p, rp, st,
|
| 807 |
-
#
|
| 808 |
-
|
| 809 |
|
| 810 |
-
def
|
| 811 |
"""
|
| 812 |
-
|
| 813 |
-
|
| 814 |
"""
|
| 815 |
if not user_msg.strip():
|
| 816 |
-
# If the message is empty, do nothing
|
| 817 |
-
yield
|
| 818 |
-
chat: gr.update(),
|
| 819 |
-
dbg: gr.update(),
|
| 820 |
-
txt: gr.update(),
|
| 821 |
-
submit_btn: gr.update(),
|
| 822 |
-
cancel_btn: gr.update(),
|
| 823 |
-
tts_audio_output: gr.update(),
|
| 824 |
-
}
|
| 825 |
return
|
| 826 |
|
| 827 |
-
#
|
| 828 |
-
|
| 829 |
-
|
| 830 |
-
|
| 831 |
-
|
| 832 |
-
|
| 833 |
-
|
| 834 |
-
|
| 835 |
-
txt: gr.update(value="", interactive=False),
|
| 836 |
-
submit_btn: gr.update(interactive=False),
|
| 837 |
-
cancel_btn: gr.update(visible=True),
|
| 838 |
-
tts_audio_output: gr.update(value=None), # Clear audio but keep visible
|
| 839 |
-
}
|
| 840 |
|
| 841 |
-
cancelled = False
|
| 842 |
try:
|
| 843 |
-
# 2.
|
| 844 |
-
|
| 845 |
-
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
# Update audio output when audio is generated (final yield with TTS)
|
| 854 |
-
if audio is not None:
|
| 855 |
-
update_dict[tts_audio_output] = gr.update(value=audio)
|
| 856 |
-
|
| 857 |
-
yield update_dict
|
| 858 |
except GeneratorExit:
|
| 859 |
-
# Mark as cancelled and re-raise to prevent "generator ignored GeneratorExit"
|
| 860 |
-
cancelled = True
|
| 861 |
print("Generation cancelled by user.")
|
| 862 |
raise
|
| 863 |
except Exception as e:
|
| 864 |
print(f"An error occurred during generation: {e}")
|
| 865 |
-
# If an error happens, add it to the chat history to inform the user.
|
| 866 |
error_history = (chat_history or []) + [
|
| 867 |
{'role': 'user', 'content': user_msg},
|
| 868 |
{'role': 'assistant', 'content': f"**An error occurred:** {str(e)}"}
|
| 869 |
]
|
| 870 |
-
yield
|
| 871 |
-
|
| 872 |
-
|
| 873 |
-
|
| 874 |
-
|
| 875 |
-
|
| 876 |
-
|
| 877 |
-
|
| 878 |
-
|
| 879 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 880 |
|
| 881 |
def set_cancel_flag():
|
| 882 |
"""Called by the cancel button, sets the global event."""
|
|
@@ -887,37 +932,54 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
|
|
| 887 |
"""Reset UI components after cancellation."""
|
| 888 |
cancel_event.clear() # Clear the flag for next generation
|
| 889 |
print("UI reset after cancellation.")
|
| 890 |
-
return
|
| 891 |
-
|
| 892 |
-
|
| 893 |
-
|
| 894 |
-
|
| 895 |
-
|
| 896 |
|
| 897 |
-
# Event for submitting text via Enter key
|
|
|
|
| 898 |
submit_event = txt.submit(
|
| 899 |
-
fn=
|
| 900 |
inputs=chat_inputs,
|
| 901 |
-
outputs=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 902 |
)
|
| 903 |
-
|
| 904 |
-
|
|
|
|
|
|
|
| 905 |
inputs=chat_inputs,
|
| 906 |
-
outputs=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 907 |
)
|
| 908 |
|
| 909 |
# Event for the "Cancel" button.
|
| 910 |
-
# It sets the cancel flag, cancels the submit
|
| 911 |
cancel_btn.click(
|
| 912 |
fn=set_cancel_flag,
|
| 913 |
-
cancels=[submit_event]
|
| 914 |
).then(
|
| 915 |
fn=reset_ui_after_cancel,
|
| 916 |
-
outputs=
|
| 917 |
)
|
| 918 |
|
| 919 |
# Listeners for updating the duration estimate
|
| 920 |
-
duration_inputs = [model_dd, search_chk, mr, mc, max_tok, st,
|
| 921 |
for component in duration_inputs:
|
| 922 |
component.change(fn=update_duration_estimate, inputs=duration_inputs, outputs=duration_display)
|
| 923 |
|
|
|
|
| 462 |
prompt += "Assistant: "
|
| 463 |
return prompt
|
| 464 |
|
| 465 |
+
def get_duration(user_msg, chat_history, system_prompt, enable_search, max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty, search_timeout, enable_thinking):
|
| 466 |
# Get model size from the MODELS dict (more reliable than string parsing)
|
| 467 |
model_size = MODELS[model_name].get("params_b", 4.0) # Default to 4B if not found
|
| 468 |
|
|
|
|
| 474 |
token_duration = max_tokens * 0.005 # ~200 tokens/second average on H200
|
| 475 |
search_duration = 10 if enable_search else 0 # Reduced search time
|
| 476 |
aot_compilation_buffer = 20 if use_aot else 0 # Faster compilation on H200
|
|
|
|
| 477 |
|
| 478 |
+
return base_duration + token_duration + search_duration + aot_compilation_buffer
|
| 479 |
|
| 480 |
@spaces.GPU(duration=get_duration)
|
| 481 |
def chat_response(user_msg, chat_history, system_prompt,
|
| 482 |
enable_search, max_results, max_chars,
|
| 483 |
model_name, max_tokens, temperature,
|
| 484 |
+
top_k, top_p, repeat_penalty, search_timeout, enable_thinking):
|
| 485 |
"""
|
| 486 |
Generates streaming chat responses, optionally with background web search.
|
| 487 |
+
TTS is handled separately after this completes.
|
| 488 |
"""
|
| 489 |
# Clear the cancellation event at the start of a new generation
|
| 490 |
cancel_event.clear()
|
|
|
|
| 591 |
assistant_message_started = False
|
| 592 |
|
| 593 |
# First yield contains the user message
|
| 594 |
+
yield history, debug
|
| 595 |
|
| 596 |
# Stream tokens
|
| 597 |
for chunk in streamer:
|
|
|
|
| 599 |
if cancel_event.is_set():
|
| 600 |
if assistant_message_started and history and history[-1]['role'] == 'assistant':
|
| 601 |
history[-1]['content'] += " [Generation Canceled]"
|
| 602 |
+
yield history, debug
|
| 603 |
break
|
| 604 |
|
| 605 |
text = chunk
|
|
|
|
| 619 |
history.append({'role': 'assistant', 'content': answer_buf})
|
| 620 |
else:
|
| 621 |
history[-1]['content'] = thought_buf
|
| 622 |
+
yield history, debug
|
| 623 |
continue
|
| 624 |
|
| 625 |
if in_thought:
|
|
|
|
| 632 |
history.append({'role': 'assistant', 'content': answer_buf})
|
| 633 |
else:
|
| 634 |
history[-1]['content'] = thought_buf
|
| 635 |
+
yield history, debug
|
| 636 |
continue
|
| 637 |
|
| 638 |
# Stream answer
|
|
|
|
| 642 |
|
| 643 |
answer_buf += text
|
| 644 |
history[-1]['content'] = answer_buf.strip()
|
| 645 |
+
yield history, debug
|
| 646 |
|
| 647 |
gen_thread.join()
|
| 648 |
|
| 649 |
+
yield history, debug + prompt_debug
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 650 |
except GeneratorExit:
|
| 651 |
# Handle cancellation gracefully
|
| 652 |
print("Chat response cancelled.")
|
|
|
|
| 654 |
return
|
| 655 |
except Exception as e:
|
| 656 |
history.append({'role': 'assistant', 'content': f"Error: {e}"})
|
| 657 |
+
yield history, debug
|
| 658 |
finally:
|
| 659 |
gc.collect()
|
| 660 |
|
|
|
|
| 662 |
def update_default_prompt(enable_search):
    """Return the default system prompt.

    Args:
        enable_search: Accepted so the function can be wired to the search
            toggle; the returned prompt does not depend on it.

    Returns:
        The default system prompt string.
    """
    # Plain literal: there is nothing to interpolate, so no f-string is needed.
    return "You are a helpful assistant. Don't use emojis in your response. Keep replies short to a maximum of three sentences."
|
| 664 |
|
| 665 |
+
def update_duration_estimate(model_name, enable_search, max_results, max_chars, max_tokens, search_timeout, enable_thinking):
    """Build the Markdown summary of the estimated GPU time for the given settings.

    Calls ``get_duration`` with placeholder message/history/system-prompt
    values and fixed sampling parameters (0.7, 40, 0.9, 1.2), then formats the
    result together with the model size and the search/thinking toggles.

    Returns:
        A Markdown string with the estimate, or a warning string containing
        the exception text if anything fails while computing it.
    """
    try:
        estimate = get_duration(
            "", [], "",
            enable_search, max_results, max_chars, model_name,
            max_tokens, 0.7, 40, 0.9, 1.2, search_timeout, enable_thinking,
        )
        # Fall back to 4.0 when the model entry carries no "params_b" value.
        params_b = MODELS[model_name].get("params_b", 4.0)

        summary_lines = [
            f"⏱️ **Estimated GPU Time: {estimate:.1f} seconds**\n\n",
            f"📊 **Model Size:** {params_b:.1f}B parameters\n",
            f"🔍 **Web Search:** {'Enabled' if enable_search else 'Disabled'}\n",
            f"💭 **Thinking:** {'Enabled' if enable_thinking else 'Disabled'}",
        ]
        return "".join(summary_lines)
    except Exception as e:
        return f"⚠️ Error calculating estimate: {e}"
|
| 679 |
|
| 680 |
+
|
| 681 |
+
def _message_text(content):
    """Flatten a chat message's ``content`` into plain text ('' if it has none)."""
    if isinstance(content, str):
        return content
    if isinstance(content, dict):
        text = content.get('text')
        return text if isinstance(text, str) else ''
    if isinstance(content, list):
        # Multi-part content: keep only the parts that carry text.
        parts = (_message_text(item) for item in content)
        return ' '.join(part for part in parts if part)
    # None, tuples and any other payload type have no speakable text.
    return ''


def generate_speech_from_chat(chat_history, enable_tts):
    """
    Generate TTS audio from the last assistant message in chat history.

    This runs as a separate step after text generation, allowing the audio
    component to show its loading state.

    Args:
        chat_history: List of message dicts with 'role', 'content' and an
            optional 'metadata' dict. May be empty or None.
        enable_tts: When falsy, no audio is produced.

    Returns:
        Whatever ``generate_tts_audio`` returns for the newest non-thought
        assistant message, or None when TTS is disabled, there is no
        history, or that message contains no text.
    """
    if not enable_tts or not chat_history:
        return None

    for msg in reversed(chat_history):
        if msg.get('role') != 'assistant':
            continue

        # Skip thought bubbles (their metadata title starts with 💭).
        # The title may be present but None, so normalise before testing it.
        metadata = msg.get('metadata') or {}
        title = metadata.get('title') or ''
        if isinstance(title, str) and title.startswith('💭'):
            continue

        # Only the newest non-thought assistant message is considered.
        text = _message_text(msg.get('content'))
        if not text.strip():
            return None
        return generate_tts_audio(text)

    return None
|
| 720 |
+
|
| 721 |
+
|
| 722 |
# ------------------------------
|
| 723 |
# Gradio UI
|
| 724 |
# ------------------------------
|
|
|
|
| 735 |
.chatbot { border-radius: 12px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1); }
|
| 736 |
button.primary { font-weight: 600; }
|
| 737 |
.gradio-accordion { margin-bottom: 12px; }
|
| 738 |
+
|
| 739 |
+
/* Mobile: sticky input at bottom */
|
| 740 |
+
@media (max-width: 768px) {
|
| 741 |
+
#input-row {
|
| 742 |
+
position: fixed;
|
| 743 |
+
bottom: 0;
|
| 744 |
+
left: 0;
|
| 745 |
+
right: 0;
|
| 746 |
+
background: var(--background-fill-primary);
|
| 747 |
+
padding: 12px;
|
| 748 |
+
box-shadow: 0 -2px 10px rgba(0, 0, 0, 0.1);
|
| 749 |
+
z-index: 1000;
|
| 750 |
+
margin: 0 !important;
|
| 751 |
+
}
|
| 752 |
+
|
| 753 |
+
/* Add padding at bottom of main content to prevent overlap */
|
| 754 |
+
.main {
|
| 755 |
+
padding-bottom: 80px !important;
|
| 756 |
+
}
|
| 757 |
+
|
| 758 |
+
/* Adjust chatbot height on mobile */
|
| 759 |
+
.chatbot {
|
| 760 |
+
height: calc(100vh - 200px) !important;
|
| 761 |
+
max-height: none !important;
|
| 762 |
+
}
|
| 763 |
+
}
|
| 764 |
"""
|
| 765 |
|
| 766 |
with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
|
|
|
|
| 772 |
height=500,
|
| 773 |
label="💬 Conversation",
|
| 774 |
buttons=["copy"],
|
| 775 |
+
avatar_images=(None, "pfp.png"),
|
| 776 |
layout="bubble"
|
| 777 |
)
|
| 778 |
|
|
|
|
| 786 |
)
|
| 787 |
|
| 788 |
# Input Area
|
| 789 |
+
with gr.Row(elem_id="input-row"):
|
| 790 |
txt = gr.Textbox(
|
| 791 |
placeholder="💭 Type your message here... (Press Enter to send)",
|
| 792 |
scale=9,
|
|
|
|
| 835 |
|
| 836 |
# Duration Estimate
|
| 837 |
duration_display = gr.Markdown(
|
| 838 |
+
value=update_duration_estimate("Qwen3-0.6B", False, 4, 50, 1024, 5.0, False),
|
| 839 |
elem_classes="duration-estimate"
|
| 840 |
)
|
| 841 |
|
|
|
|
| 863 |
|
| 864 |
# --- Event Listeners ---
|
| 865 |
|
| 866 |
+
# Group inputs for chat generation (no TTS - handled separately via .then())
|
| 867 |
+
chat_inputs = [txt, chat, sys_prompt, search_chk, mr, mc, model_dd, max_tok, temp, k, p, rp, st, thinking_chk]
|
| 868 |
+
# UI components for streaming phase
|
| 869 |
+
stream_outputs = [chat, dbg, txt, submit_btn, cancel_btn]
|
| 870 |
|
| 871 |
+
def stream_chat_and_update_ui(user_msg, chat_history, *args):
    """
    Stream chat responses and manage UI state during generation.
    TTS is handled separately via .then() chaining.

    Args:
        user_msg: Text from the input box; None or whitespace-only is a no-op.
        chat_history: Current list of chat message dicts (may be None).
        *args: Remaining generation settings, forwarded unchanged to
            ``chat_response``.

    Yields:
        5-tuples matching ``stream_outputs``:
        (chat, debug, input box update, submit button update, cancel button update).
    """
    if not (user_msg or "").strip():
        # Empty message: change nothing. The debug panel gets a no-op update
        # rather than "" so its existing content is not wiped.
        yield chat_history, gr.update(), gr.update(), gr.update(), gr.update()
        return

    # 1. Update UI to "generating" state
    yield (
        chat_history,                            # Keep current chat
        "",                                      # Clear debug
        gr.update(value="", interactive=False),  # Clear and disable input
        gr.update(interactive=False),            # Disable submit
        gr.update(visible=True),                 # Show cancel
    )

    try:
        # 2. Stream chat responses; button/input state is left untouched
        for history, debug in chat_response(user_msg, chat_history, *args):
            yield (
                history,
                debug,
                gr.update(),  # Keep input state
                gr.update(),  # Keep submit state
                gr.update(),  # Keep cancel state
            )
    except GeneratorExit:
        # Re-raise so the generator is closed properly instead of ignoring the exit.
        print("Generation cancelled by user.")
        raise
    except Exception as e:
        print(f"An error occurred during generation: {e}")
        # Surface the failure in the chat so the user sees what went wrong.
        error_history = (chat_history or []) + [
            {'role': 'user', 'content': user_msg},
            {'role': 'assistant', 'content': f"**An error occurred:** {str(e)}"}
        ]
        yield (
            error_history,
            f"Error: {e}",
            gr.update(),
            gr.update(),
            gr.update(),
        )
|
| 916 |
+
|
| 917 |
+
def reset_ui_after_generation():
    """Return the updates that put the UI back into its idle state.

    Returns:
        A 3-tuple of updates: input re-enabled, submit re-enabled, cancel hidden.
    """
    print("Resetting UI state after generation.")
    input_update = gr.update(interactive=True)   # Re-enable input
    submit_update = gr.update(interactive=True)  # Re-enable submit
    cancel_update = gr.update(visible=False)     # Hide cancel
    return input_update, submit_update, cancel_update
|
| 925 |
|
| 926 |
def set_cancel_flag():
|
| 927 |
"""Called by the cancel button, sets the global event."""
|
|
|
|
| 932 |
"""Reset UI components after cancellation."""
|
| 933 |
cancel_event.clear() # Clear the flag for next generation
|
| 934 |
print("UI reset after cancellation.")
|
| 935 |
+
return (
|
| 936 |
+
gr.update(interactive=True),
|
| 937 |
+
gr.update(interactive=True),
|
| 938 |
+
gr.update(visible=False),
|
| 939 |
+
None, # Clear audio
|
| 940 |
+
)
|
| 941 |
|
| 942 |
+
# Enter key and Submit button run the same three-step pipeline:
#   stream text -> generate TTS -> reset UI
# Each step's event handle is kept so the Cancel button can name every step.
# NOTE(review): `cancels=` is understood to act on the individual listeners
# it is given, so the handle of the final `.then()` alone would not cover the
# streaming step - confirm against the Gradio version in use.

# Event for submitting text via Enter key
submit_stream_event = txt.submit(
    fn=stream_chat_and_update_ui,
    inputs=chat_inputs,
    outputs=stream_outputs,
)
submit_tts_event = submit_stream_event.then(
    fn=generate_speech_from_chat,
    inputs=[chat, tts_chk],
    outputs=[tts_audio_output],
)
submit_event = submit_tts_event.then(
    fn=reset_ui_after_generation,
    outputs=[txt, submit_btn, cancel_btn],
)

# Event for clicking Submit button
submit_btn_stream_event = submit_btn.click(
    fn=stream_chat_and_update_ui,
    inputs=chat_inputs,
    outputs=stream_outputs,
)
submit_btn_tts_event = submit_btn_stream_event.then(
    fn=generate_speech_from_chat,
    inputs=[chat, tts_chk],
    outputs=[tts_audio_output],
)
submit_btn_event = submit_btn_tts_event.then(
    fn=reset_ui_after_generation,
    outputs=[txt, submit_btn, cancel_btn],
)

# Event for the "Cancel" button.
# It sets the cancel flag, cancels every step of both pipelines, then resets the UI.
generation_events = [
    submit_stream_event, submit_tts_event, submit_event,
    submit_btn_stream_event, submit_btn_tts_event, submit_btn_event,
]
cancel_btn.click(
    fn=set_cancel_flag,
    cancels=generation_events,
).then(
    fn=reset_ui_after_cancel,
    outputs=[txt, submit_btn, cancel_btn, tts_audio_output],
)

# Listeners for updating the duration estimate: any change to one of these
# inputs recomputes the estimate from all of them.
duration_inputs = [model_dd, search_chk, mr, mc, max_tok, st, thinking_chk]
for component in duration_inputs:
    component.change(fn=update_duration_estimate, inputs=duration_inputs, outputs=duration_display)
|
| 985 |
|