polats committed on
Commit
51a9e4a
·
1 Parent(s): c6a3e9f

update ui for mobile

Browse files
Files changed (1) hide show
  1. app.py +159 -97
app.py CHANGED
@@ -462,7 +462,7 @@ def format_conversation(history, system_prompt, tokenizer, enable_thinking=False
462
  prompt += "Assistant: "
463
  return prompt
464
 
465
- def get_duration(user_msg, chat_history, system_prompt, enable_search, max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty, search_timeout, enable_tts, enable_thinking):
466
  # Get model size from the MODELS dict (more reliable than string parsing)
467
  model_size = MODELS[model_name].get("params_b", 4.0) # Default to 4B if not found
468
 
@@ -474,18 +474,17 @@ def get_duration(user_msg, chat_history, system_prompt, enable_search, max_resul
474
  token_duration = max_tokens * 0.005 # ~200 tokens/second average on H200
475
  search_duration = 10 if enable_search else 0 # Reduced search time
476
  aot_compilation_buffer = 20 if use_aot else 0 # Faster compilation on H200
477
- tts_duration = 15 if enable_tts else 0 # TTS generation time
478
 
479
- return base_duration + token_duration + search_duration + aot_compilation_buffer + tts_duration
480
 
481
  @spaces.GPU(duration=get_duration)
482
  def chat_response(user_msg, chat_history, system_prompt,
483
  enable_search, max_results, max_chars,
484
  model_name, max_tokens, temperature,
485
- top_k, top_p, repeat_penalty, search_timeout, enable_tts, enable_thinking):
486
  """
487
  Generates streaming chat responses, optionally with background web search.
488
- This version includes cancellation support.
489
  """
490
  # Clear the cancellation event at the start of a new generation
491
  cancel_event.clear()
@@ -592,7 +591,7 @@ def chat_response(user_msg, chat_history, system_prompt,
592
  assistant_message_started = False
593
 
594
  # First yield contains the user message
595
- yield history, debug, None
596
 
597
  # Stream tokens
598
  for chunk in streamer:
@@ -600,7 +599,7 @@ def chat_response(user_msg, chat_history, system_prompt,
600
  if cancel_event.is_set():
601
  if assistant_message_started and history and history[-1]['role'] == 'assistant':
602
  history[-1]['content'] += " [Generation Canceled]"
603
- yield history, debug, None
604
  break
605
 
606
  text = chunk
@@ -620,7 +619,7 @@ def chat_response(user_msg, chat_history, system_prompt,
620
  history.append({'role': 'assistant', 'content': answer_buf})
621
  else:
622
  history[-1]['content'] = thought_buf
623
- yield history, debug, None
624
  continue
625
 
626
  if in_thought:
@@ -633,7 +632,7 @@ def chat_response(user_msg, chat_history, system_prompt,
633
  history.append({'role': 'assistant', 'content': answer_buf})
634
  else:
635
  history[-1]['content'] = thought_buf
636
- yield history, debug, None
637
  continue
638
 
639
  # Stream answer
@@ -643,16 +642,11 @@ def chat_response(user_msg, chat_history, system_prompt,
643
 
644
  answer_buf += text
645
  history[-1]['content'] = answer_buf.strip()
646
- yield history, debug, None
647
 
648
  gen_thread.join()
649
 
650
- # Generate TTS audio if enabled
651
- tts_audio = None
652
- if enable_tts and answer_buf.strip():
653
- tts_audio = generate_tts_audio(answer_buf)
654
-
655
- yield history, debug + prompt_debug, tts_audio
656
  except GeneratorExit:
657
  # Handle cancellation gracefully
658
  print("Chat response cancelled.")
@@ -660,7 +654,7 @@ def chat_response(user_msg, chat_history, system_prompt,
660
  return
661
  except Exception as e:
662
  history.append({'role': 'assistant', 'content': f"Error: {e}"})
663
- yield history, debug, None
664
  finally:
665
  gc.collect()
666
 
@@ -668,22 +662,63 @@ def chat_response(user_msg, chat_history, system_prompt,
668
  def update_default_prompt(enable_search):
669
  return f"You are a helpful assistant. Don't use emojis in your response. Keep replies short to a maximum of three sentences."
670
 
671
- def update_duration_estimate(model_name, enable_search, max_results, max_chars, max_tokens, search_timeout, enable_tts, enable_thinking):
672
  """Calculate and format the estimated GPU duration for current settings."""
673
  try:
674
  dummy_msg, dummy_history, dummy_system_prompt = "", [], ""
675
  duration = get_duration(dummy_msg, dummy_history, dummy_system_prompt,
676
  enable_search, max_results, max_chars, model_name,
677
- max_tokens, 0.7, 40, 0.9, 1.2, search_timeout, enable_tts, enable_thinking)
678
  model_size = MODELS[model_name].get("params_b", 4.0)
679
  return (f"⏱️ **Estimated GPU Time: {duration:.1f} seconds**\n\n"
680
  f"📊 **Model Size:** {model_size:.1f}B parameters\n"
681
  f"🔍 **Web Search:** {'Enabled' if enable_search else 'Disabled'}\n"
682
- f"🔊 **TTS:** {'Enabled' if enable_tts else 'Disabled'}\n"
683
  f"💭 **Thinking:** {'Enabled' if enable_thinking else 'Disabled'}")
684
  except Exception as e:
685
  return f"⚠️ Error calculating estimate: {e}"
686
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
687
  # ------------------------------
688
  # Gradio UI
689
  # ------------------------------
@@ -700,6 +735,32 @@ CUSTOM_CSS = """
700
  .chatbot { border-radius: 12px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1); }
701
  button.primary { font-weight: 600; }
702
  .gradio-accordion { margin-bottom: 12px; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
703
  """
704
 
705
  with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
@@ -711,7 +772,7 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
711
  height=500,
712
  label="💬 Conversation",
713
  buttons=["copy"],
714
- avatar_images=(None, "🤖"),
715
  layout="bubble"
716
  )
717
 
@@ -725,7 +786,7 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
725
  )
726
 
727
  # Input Area
728
- with gr.Row():
729
  txt = gr.Textbox(
730
  placeholder="💭 Type your message here... (Press Enter to send)",
731
  scale=9,
@@ -774,7 +835,7 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
774
 
775
  # Duration Estimate
776
  duration_display = gr.Markdown(
777
- value=update_duration_estimate("Qwen3-0.6B", False, 4, 50, 1024, 5.0, True, False),
778
  elem_classes="duration-estimate"
779
  )
780
 
@@ -802,81 +863,65 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
802
 
803
  # --- Event Listeners ---
804
 
805
- # Group all inputs for cleaner event handling
806
- chat_inputs = [txt, chat, sys_prompt, search_chk, mr, mc, model_dd, max_tok, temp, k, p, rp, st, tts_chk, thinking_chk]
807
- # Group all UI components that can be updated.
808
- ui_components = [chat, dbg, txt, submit_btn, cancel_btn, tts_audio_output]
809
 
810
- def submit_and_manage_ui(user_msg, chat_history, *args):
811
  """
812
- Orchestrator function that manages UI state and calls the backend chat function.
813
- It uses a try...finally block to ensure the UI is always reset.
814
  """
815
  if not user_msg.strip():
816
- # If the message is empty, do nothing.
817
- yield {
818
- chat: gr.update(),
819
- dbg: gr.update(),
820
- txt: gr.update(),
821
- submit_btn: gr.update(),
822
- cancel_btn: gr.update(),
823
- tts_audio_output: gr.update(),
824
- }
825
  return
826
 
827
- # Check if TTS is enabled (last argument)
828
- tts_enabled = args[-1] if args else False
829
-
830
- # 1. Update UI to "generating" state.
831
- # Crucially, we do NOT update the `chat` component here, as the backend
832
- # will provide the correctly formatted history in the first response chunk.
833
- # Keep audio visible but clear it - Gradio will show loading state
834
- yield {
835
- txt: gr.update(value="", interactive=False),
836
- submit_btn: gr.update(interactive=False),
837
- cancel_btn: gr.update(visible=True),
838
- tts_audio_output: gr.update(value=None), # Clear audio but keep visible
839
- }
840
 
841
- cancelled = False
842
  try:
843
- # 2. Call the backend and stream updates
844
- backend_args = [user_msg, chat_history] + list(args)
845
- for response_chunk in chat_response(*backend_args):
846
- history, debug, audio = response_chunk[0], response_chunk[1], response_chunk[2] if len(response_chunk) > 2 else None
847
-
848
- update_dict = {
849
- chat: history,
850
- dbg: debug,
851
- }
852
-
853
- # Update audio output when audio is generated (final yield with TTS)
854
- if audio is not None:
855
- update_dict[tts_audio_output] = gr.update(value=audio)
856
-
857
- yield update_dict
858
  except GeneratorExit:
859
- # Mark as cancelled and re-raise to prevent "generator ignored GeneratorExit"
860
- cancelled = True
861
  print("Generation cancelled by user.")
862
  raise
863
  except Exception as e:
864
  print(f"An error occurred during generation: {e}")
865
- # If an error happens, add it to the chat history to inform the user.
866
  error_history = (chat_history or []) + [
867
  {'role': 'user', 'content': user_msg},
868
  {'role': 'assistant', 'content': f"**An error occurred:** {str(e)}"}
869
  ]
870
- yield {chat: error_history}
871
- finally:
872
- # Only reset UI if not cancelled (to avoid "generator ignored GeneratorExit")
873
- if not cancelled:
874
- print("Resetting UI state.")
875
- yield {
876
- txt: gr.update(interactive=True),
877
- submit_btn: gr.update(interactive=True),
878
- cancel_btn: gr.update(visible=False),
879
- }
 
 
 
 
 
 
880
 
881
  def set_cancel_flag():
882
  """Called by the cancel button, sets the global event."""
@@ -887,37 +932,54 @@ with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
887
  """Reset UI components after cancellation."""
888
  cancel_event.clear() # Clear the flag for next generation
889
  print("UI reset after cancellation.")
890
- return {
891
- txt: gr.update(interactive=True),
892
- submit_btn: gr.update(interactive=True),
893
- cancel_btn: gr.update(visible=False),
894
- tts_audio_output: gr.update(value=None), # Clear audio but keep visible
895
- }
896
 
897
- # Event for submitting text via Enter key or Submit button
 
898
  submit_event = txt.submit(
899
- fn=submit_and_manage_ui,
900
  inputs=chat_inputs,
901
- outputs=ui_components,
 
 
 
 
 
 
 
902
  )
903
- submit_btn.click(
904
- fn=submit_and_manage_ui,
 
 
905
  inputs=chat_inputs,
906
- outputs=ui_components,
 
 
 
 
 
 
 
907
  )
908
 
909
  # Event for the "Cancel" button.
910
- # It sets the cancel flag, cancels the submit event, then resets the UI.
911
  cancel_btn.click(
912
  fn=set_cancel_flag,
913
- cancels=[submit_event]
914
  ).then(
915
  fn=reset_ui_after_cancel,
916
- outputs=ui_components
917
  )
918
 
919
  # Listeners for updating the duration estimate
920
- duration_inputs = [model_dd, search_chk, mr, mc, max_tok, st, tts_chk, thinking_chk]
921
  for component in duration_inputs:
922
  component.change(fn=update_duration_estimate, inputs=duration_inputs, outputs=duration_display)
923
 
 
462
  prompt += "Assistant: "
463
  return prompt
464
 
465
+ def get_duration(user_msg, chat_history, system_prompt, enable_search, max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty, search_timeout, enable_thinking):
466
  # Get model size from the MODELS dict (more reliable than string parsing)
467
  model_size = MODELS[model_name].get("params_b", 4.0) # Default to 4B if not found
468
 
 
474
  token_duration = max_tokens * 0.005 # ~200 tokens/second average on H200
475
  search_duration = 10 if enable_search else 0 # Reduced search time
476
  aot_compilation_buffer = 20 if use_aot else 0 # Faster compilation on H200
 
477
 
478
+ return base_duration + token_duration + search_duration + aot_compilation_buffer
479
 
480
  @spaces.GPU(duration=get_duration)
481
  def chat_response(user_msg, chat_history, system_prompt,
482
  enable_search, max_results, max_chars,
483
  model_name, max_tokens, temperature,
484
+ top_k, top_p, repeat_penalty, search_timeout, enable_thinking):
485
  """
486
  Generates streaming chat responses, optionally with background web search.
487
+ TTS is handled separately after this completes.
488
  """
489
  # Clear the cancellation event at the start of a new generation
490
  cancel_event.clear()
 
591
  assistant_message_started = False
592
 
593
  # First yield contains the user message
594
+ yield history, debug
595
 
596
  # Stream tokens
597
  for chunk in streamer:
 
599
  if cancel_event.is_set():
600
  if assistant_message_started and history and history[-1]['role'] == 'assistant':
601
  history[-1]['content'] += " [Generation Canceled]"
602
+ yield history, debug
603
  break
604
 
605
  text = chunk
 
619
  history.append({'role': 'assistant', 'content': answer_buf})
620
  else:
621
  history[-1]['content'] = thought_buf
622
+ yield history, debug
623
  continue
624
 
625
  if in_thought:
 
632
  history.append({'role': 'assistant', 'content': answer_buf})
633
  else:
634
  history[-1]['content'] = thought_buf
635
+ yield history, debug
636
  continue
637
 
638
  # Stream answer
 
642
 
643
  answer_buf += text
644
  history[-1]['content'] = answer_buf.strip()
645
+ yield history, debug
646
 
647
  gen_thread.join()
648
 
649
+ yield history, debug + prompt_debug
 
 
 
 
 
650
  except GeneratorExit:
651
  # Handle cancellation gracefully
652
  print("Chat response cancelled.")
 
654
  return
655
  except Exception as e:
656
  history.append({'role': 'assistant', 'content': f"Error: {e}"})
657
+ yield history, debug
658
  finally:
659
  gc.collect()
660
 
 
662
  def update_default_prompt(enable_search):
663
  return f"You are a helpful assistant. Don't use emojis in your response. Keep replies short to a maximum of three sentences."
664
 
665
+ def update_duration_estimate(model_name, enable_search, max_results, max_chars, max_tokens, search_timeout, enable_thinking):
666
  """Calculate and format the estimated GPU duration for current settings."""
667
  try:
668
  dummy_msg, dummy_history, dummy_system_prompt = "", [], ""
669
  duration = get_duration(dummy_msg, dummy_history, dummy_system_prompt,
670
  enable_search, max_results, max_chars, model_name,
671
+ max_tokens, 0.7, 40, 0.9, 1.2, search_timeout, enable_thinking)
672
  model_size = MODELS[model_name].get("params_b", 4.0)
673
  return (f"⏱️ **Estimated GPU Time: {duration:.1f} seconds**\n\n"
674
  f"📊 **Model Size:** {model_size:.1f}B parameters\n"
675
  f"🔍 **Web Search:** {'Enabled' if enable_search else 'Disabled'}\n"
 
676
  f"💭 **Thinking:** {'Enabled' if enable_thinking else 'Disabled'}")
677
  except Exception as e:
678
  return f"⚠️ Error calculating estimate: {e}"
679
 
680
+
681
+ def generate_speech_from_chat(chat_history, enable_tts):
682
+ """
683
+ Generate TTS audio from the last assistant message in chat history.
684
+ This runs as a separate step after text generation, allowing the audio
685
+ component to show its loading state.
686
+ """
687
+ if not enable_tts:
688
+ return None
689
+
690
+ if not chat_history:
691
+ return None
692
+
693
+ # Find the last assistant message (skip thought bubbles)
694
+ last_message = None
695
+ for msg in reversed(chat_history):
696
+ if msg.get('role') == 'assistant':
697
+ # Skip thought bubbles (they have metadata with title starting with 💭)
698
+ metadata = msg.get('metadata') or {}
699
+ if metadata.get('title', '').startswith('💭'):
700
+ continue
701
+ content = msg.get('content', '')
702
+ # Handle both string and list content (Gradio multi-modal format)
703
+ if isinstance(content, list):
704
+ # Extract text from list items
705
+ text_parts = []
706
+ for item in content:
707
+ if isinstance(item, str):
708
+ text_parts.append(item)
709
+ elif isinstance(item, dict) and 'text' in item:
710
+ text_parts.append(item['text'])
711
+ last_message = ' '.join(text_parts)
712
+ else:
713
+ last_message = content
714
+ break
715
+
716
+ if not last_message or not last_message.strip():
717
+ return None
718
+
719
+ return generate_tts_audio(last_message)
720
+
721
+
722
  # ------------------------------
723
  # Gradio UI
724
  # ------------------------------
 
735
  .chatbot { border-radius: 12px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1); }
736
  button.primary { font-weight: 600; }
737
  .gradio-accordion { margin-bottom: 12px; }
738
+
739
+ /* Mobile: sticky input at bottom */
740
+ @media (max-width: 768px) {
741
+ #input-row {
742
+ position: fixed;
743
+ bottom: 0;
744
+ left: 0;
745
+ right: 0;
746
+ background: var(--background-fill-primary);
747
+ padding: 12px;
748
+ box-shadow: 0 -2px 10px rgba(0, 0, 0, 0.1);
749
+ z-index: 1000;
750
+ margin: 0 !important;
751
+ }
752
+
753
+ /* Add padding at bottom of main content to prevent overlap */
754
+ .main {
755
+ padding-bottom: 80px !important;
756
+ }
757
+
758
+ /* Adjust chatbot height on mobile */
759
+ .chatbot {
760
+ height: calc(100vh - 200px) !important;
761
+ max-height: none !important;
762
+ }
763
+ }
764
  """
765
 
766
  with gr.Blocks(title="LLM Inference with ZeroGPU") as demo:
 
772
  height=500,
773
  label="💬 Conversation",
774
  buttons=["copy"],
775
+ avatar_images=(None, "pfp.png"),
776
  layout="bubble"
777
  )
778
 
 
786
  )
787
 
788
  # Input Area
789
+ with gr.Row(elem_id="input-row"):
790
  txt = gr.Textbox(
791
  placeholder="💭 Type your message here... (Press Enter to send)",
792
  scale=9,
 
835
 
836
  # Duration Estimate
837
  duration_display = gr.Markdown(
838
+ value=update_duration_estimate("Qwen3-0.6B", False, 4, 50, 1024, 5.0, False),
839
  elem_classes="duration-estimate"
840
  )
841
 
 
863
 
864
  # --- Event Listeners ---
865
 
866
+ # Group inputs for chat generation (no TTS - handled separately via .then())
867
+ chat_inputs = [txt, chat, sys_prompt, search_chk, mr, mc, model_dd, max_tok, temp, k, p, rp, st, thinking_chk]
868
+ # UI components for streaming phase
869
+ stream_outputs = [chat, dbg, txt, submit_btn, cancel_btn]
870
 
871
+ def stream_chat_and_update_ui(user_msg, chat_history, *args):
872
  """
873
+ Stream chat responses and manage UI state during generation.
874
+ TTS is handled separately via .then() chaining.
875
  """
876
  if not user_msg.strip():
877
+ # If the message is empty, do nothing - return current state
878
+ yield chat_history, "", gr.update(), gr.update(), gr.update()
 
 
 
 
 
 
 
879
  return
880
 
881
+ # 1. Update UI to "generating" state
882
+ yield (
883
+ chat_history, # Keep current chat
884
+ "", # Clear debug
885
+ gr.update(value="", interactive=False), # Clear and disable input
886
+ gr.update(interactive=False), # Disable submit
887
+ gr.update(visible=True), # Show cancel
888
+ )
 
 
 
 
 
889
 
 
890
  try:
891
+ # 2. Stream chat responses
892
+ for history, debug in chat_response(user_msg, chat_history, *args):
893
+ yield (
894
+ history,
895
+ debug,
896
+ gr.update(), # Keep input state
897
+ gr.update(), # Keep submit state
898
+ gr.update(), # Keep cancel state
899
+ )
 
 
 
 
 
 
900
  except GeneratorExit:
 
 
901
  print("Generation cancelled by user.")
902
  raise
903
  except Exception as e:
904
  print(f"An error occurred during generation: {e}")
 
905
  error_history = (chat_history or []) + [
906
  {'role': 'user', 'content': user_msg},
907
  {'role': 'assistant', 'content': f"**An error occurred:** {str(e)}"}
908
  ]
909
+ yield (
910
+ error_history,
911
+ f"Error: {e}",
912
+ gr.update(),
913
+ gr.update(),
914
+ gr.update(),
915
+ )
916
+
917
+ def reset_ui_after_generation():
918
+ """Reset UI to idle state after generation completes."""
919
+ print("Resetting UI state after generation.")
920
+ return (
921
+ gr.update(interactive=True), # Re-enable input
922
+ gr.update(interactive=True), # Re-enable submit
923
+ gr.update(visible=False), # Hide cancel
924
+ )
925
 
926
  def set_cancel_flag():
927
  """Called by the cancel button, sets the global event."""
 
932
  """Reset UI components after cancellation."""
933
  cancel_event.clear() # Clear the flag for next generation
934
  print("UI reset after cancellation.")
935
+ return (
936
+ gr.update(interactive=True),
937
+ gr.update(interactive=True),
938
+ gr.update(visible=False),
939
+ None, # Clear audio
940
+ )
941
 
942
+ # Event for submitting text via Enter key
943
+ # Uses .then() chaining: stream text -> generate TTS -> reset UI
944
  submit_event = txt.submit(
945
+ fn=stream_chat_and_update_ui,
946
  inputs=chat_inputs,
947
+ outputs=stream_outputs,
948
+ ).then(
949
+ fn=generate_speech_from_chat,
950
+ inputs=[chat, tts_chk],
951
+ outputs=[tts_audio_output],
952
+ ).then(
953
+ fn=reset_ui_after_generation,
954
+ outputs=[txt, submit_btn, cancel_btn],
955
  )
956
+
957
+ # Event for clicking Submit button
958
+ submit_btn_event = submit_btn.click(
959
+ fn=stream_chat_and_update_ui,
960
  inputs=chat_inputs,
961
+ outputs=stream_outputs,
962
+ ).then(
963
+ fn=generate_speech_from_chat,
964
+ inputs=[chat, tts_chk],
965
+ outputs=[tts_audio_output],
966
+ ).then(
967
+ fn=reset_ui_after_generation,
968
+ outputs=[txt, submit_btn, cancel_btn],
969
  )
970
 
971
  # Event for the "Cancel" button.
972
+ # It sets the cancel flag, cancels the submit events, then resets the UI.
973
  cancel_btn.click(
974
  fn=set_cancel_flag,
975
+ cancels=[submit_event, submit_btn_event]
976
  ).then(
977
  fn=reset_ui_after_cancel,
978
+ outputs=[txt, submit_btn, cancel_btn, tts_audio_output]
979
  )
980
 
981
  # Listeners for updating the duration estimate
982
+ duration_inputs = [model_dd, search_chk, mr, mc, max_tok, st, thinking_chk]
983
  for component in duration_inputs:
984
  component.change(fn=update_duration_estimate, inputs=duration_inputs, outputs=duration_display)
985