R-Kentaren commited on
Commit
fd64ee9
Β·
verified Β·
1 Parent(s): 9a0f889

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +576 -0
app.py ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import gc
4
+ import sys
5
+ import threading
6
+ from itertools import islice
7
+ from datetime import datetime
8
+ import re # for parsing <think> blocks
9
+ import gradio as gr
10
+ import torch
11
+ from transformers import pipeline, TextIteratorStreamer, StoppingCriteria
12
+ from transformers import AutoTokenizer
13
+ from ddgs import DDGS
14
+ import spaces # Import spaces early to enable ZeroGPU support
15
+ from torch.utils._pytree import tree_map
16
+ from config import *
17
+ # Global event to signal cancellation from the UI thread to the generation thread
18
+ cancel_event = threading.Event()
19
+
20
+ access_token=os.environ['HF_TOKEN']
21
+
22
+ # Optional: Disable GPU visibility if you wish to force CPU usage
23
+ # os.environ["CUDA_VISIBLE_DEVICES"] = ""
24
+
25
+
26
+
27
+ # Global cache for pipelines to avoid re-loading.
28
+ PIPELINES = {}
29
+
30
+ def load_pipeline(model_name):
31
+ """
32
+ Load and cache a transformers pipeline for text generation.
33
+ Tries bfloat16, falls back to float16 or float32 if unsupported.
34
+ """
35
+ global PIPELINES
36
+ if model_name in PIPELINES:
37
+ return PIPELINES[model_name]
38
+ repo = MODELS[model_name]["repo_id"]
39
+ tokenizer = AutoTokenizer.from_pretrained(repo,
40
+ token=access_token)
41
+ for dtype in (torch.bfloat16, torch.float16, torch.float32):
42
+ try:
43
+ pipe = pipeline(
44
+ task="text-generation",
45
+ model=repo,
46
+ tokenizer=tokenizer,
47
+ trust_remote_code=True,
48
+ dtype=dtype, # Use `dtype` instead of deprecated `torch_dtype`
49
+ device_map="auto",
50
+ use_cache=True, # Enable past-key-value caching
51
+ token=access_token)
52
+ PIPELINES[model_name] = pipe
53
+ return pipe
54
+ except Exception:
55
+ continue
56
+ # Final fallback
57
+ pipe = pipeline(
58
+ task="text-generation",
59
+ model=repo,
60
+ tokenizer=tokenizer,
61
+ trust_remote_code=True,
62
+ device_map="auto",
63
+ use_cache=True
64
+ )
65
+ PIPELINES[model_name] = pipe
66
+ return pipe
67
+
68
+
69
+ def retrieve_context(query, max_results=6, max_chars=50):
70
+ """
71
+ Retrieve search snippets from DuckDuckGo (runs in background).
72
+ Returns a list of result strings.
73
+ """
74
+ try:
75
+ with DDGS() as ddgs:
76
+ return [f"{i+1}. {r.get('title','No Title')} - {r.get('body','')[:max_chars]}"
77
+ for i, r in enumerate(islice(ddgs.text(query, region="wt-wt", safesearch="off", timelimit="y"), max_results))]
78
+ except Exception:
79
+ return []
80
+
81
+ def format_conversation(history, system_prompt, tokenizer):
82
+ if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
83
+ messages = [{"role": "system", "content": system_prompt.strip()}] + history
84
+ return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=True)
85
+ else:
86
+ # Fallback for base LMs without chat template
87
+ prompt = system_prompt.strip() + "\n"
88
+ for msg in history:
89
+ if msg['role'] == 'user':
90
+ prompt += "User: " + msg['content'].strip() + "\n"
91
+ elif msg['role'] == 'assistant':
92
+ prompt += "Assistant: " + msg['content'].strip() + "\n"
93
+ if not prompt.strip().endswith("Assistant:"):
94
+ prompt += "Assistant: "
95
+ return prompt
96
+
97
+ def get_duration(user_msg, chat_history, system_prompt, enable_search, max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty, search_timeout):
98
+ # Get model size from the MODELS dict (more reliable than string parsing)
99
+ model_size = MODELS[model_name].get("params_b", 4.0) # Default to 4B if not found
100
+
101
+ # Only use AOT for models >= 2B parameters
102
+ use_aot = model_size >= 2
103
+
104
+ # Adjusted for H200 performance: faster inference, quicker compilation
105
+ base_duration = 20 if not use_aot else 40 # Reduced base times
106
+ token_duration = max_tokens * 0.005 # ~200 tokens/second average on H200
107
+ search_duration = 10 if enable_search else 0 # Reduced search time
108
+ aot_compilation_buffer = 20 if use_aot else 0 # Faster compilation on H200
109
+
110
+ return base_duration + token_duration + search_duration + aot_compilation_buffer
111
+
112
+ @spaces.GPU(duration=get_duration)
113
+ def chat_response(user_msg, chat_history, system_prompt,
114
+ enable_search, max_results, max_chars,
115
+ model_name, max_tokens, temperature,
116
+ top_k, top_p, repeat_penalty, search_timeout):
117
+ """
118
+ Generates streaming chat responses, optionally with background web search.
119
+ This version includes cancellation support.
120
+ """
121
+ # Clear the cancellation event at the start of a new generation
122
+ cancel_event.clear()
123
+
124
+ history = list(chat_history or [])
125
+ history.append({'role': 'user', 'content': user_msg})
126
+
127
+ # Launch web search if enabled
128
+ debug = ''
129
+ search_results = []
130
+ if enable_search:
131
+ debug = 'Search task started.'
132
+ thread_search = threading.Thread(
133
+ target=lambda: search_results.extend(
134
+ retrieve_context(user_msg, int(max_results), int(max_chars))
135
+ )
136
+ )
137
+ thread_search.daemon = True
138
+ thread_search.start()
139
+ else:
140
+ debug = 'Web search disabled.'
141
+
142
+ try:
143
+ cur_date = datetime.now().strftime('%Y-%m-%d')
144
+ # merge any fetched search results into the system prompt
145
+ if search_results:
146
+
147
+ enriched = system_prompt.strip() + \
148
+ f'''\n# The following contents are the search results related to the user's message:
149
+ {search_results}
150
+ In the search results I provide to you, each result is formatted as [webpage X begin]...[webpage X end], where X represents the numerical index of each article. Please cite the context at the end of the relevant sentence when appropriate. Use the citation format [citation:X] in the corresponding part of your answer. If a sentence is derived from multiple contexts, list all relevant citation numbers, such as [citation:3][citation:5]. Be sure not to cluster all citations at the end; instead, include them in the corresponding parts of the answer.
151
+ When responding, please keep the following points in mind:
152
+ - Today is {cur_date}.
153
+ - Not all content in the search results is closely related to the user's question. You need to evaluate and filter the search results based on the question.
154
+ - For listing-type questions (e.g., listing all flight information), try to limit the answer to 10 key points and inform the user that they can refer to the search sources for complete information. Prioritize providing the most complete and relevant items in the list. Avoid mentioning content not provided in the search results unless necessary.
155
+ - For creative tasks (e.g., writing an essay), ensure that references are cited within the body of the text, such as [citation:3][citation:5], rather than only at the end of the text. You need to interpret and summarize the user's requirements, choose an appropriate format, fully utilize the search results, extract key information, and generate an answer that is insightful, creative, and professional. Extend the length of your response as much as possible, addressing each point in detail and from multiple perspectives, ensuring the content is rich and thorough.
156
+ - If the response is lengthy, structure it well and summarize it in paragraphs. If a point-by-point format is needed, try to limit it to 5 points and merge related content.
157
+ - For objective Q&A, if the answer is very brief, you may add one or two related sentences to enrich the content.
158
+ - Choose an appropriate and visually appealing format for your response based on the user's requirements and the content of the answer, ensuring strong readability.
159
+ - Your answer should synthesize information from multiple relevant webpages and avoid repeatedly citing the same webpage.
160
+ - Unless the user requests otherwise, your response should be in the same language as the user's question.
161
+ # The user's message is:
162
+ '''
163
+ else:
164
+ enriched = system_prompt
165
+
166
+ # wait up to 1s for snippets, then replace debug with them
167
+ if enable_search:
168
+ thread_search.join(timeout=float(search_timeout))
169
+ if search_results:
170
+ debug = "### Search results merged into prompt\n\n" + "\n".join(
171
+ f"- {r}" for r in search_results
172
+ )
173
+ else:
174
+ debug = "*No web search results found.*"
175
+
176
+ # merge fetched snippets into the system prompt
177
+ if search_results:
178
+ enriched = system_prompt.strip() + \
179
+ f'''\n# The following contents are the search results related to the user's message:
180
+ {search_results}
181
+ In the search results I provide to you, each result is formatted as [webpage X begin]...[webpage X end], where X represents the numerical index of each article. Please cite the context at the end of the relevant sentence when appropriate. Use the citation format [citation:X] in the corresponding part of your answer. If a sentence is derived from multiple contexts, list all relevant citation numbers, such as [citation:3][citation:5]. Be sure not to cluster all citations at the end; instead, include them in the corresponding parts of the answer.
182
+ When responding, please keep the following points in mind:
183
+ - Today is {cur_date}.
184
+ - Not all content in the search results is closely related to the user's question. You need to evaluate and filter the search results based on the question.
185
+ - For listing-type questions (e.g., listing all flight information), try to limit the answer to 10 key points and inform the user that they can refer to the search sources for complete information. Prioritize providing the most complete and relevant items in the list. Avoid mentioning content not provided in the search results unless necessary.
186
+ - For creative tasks (e.g., writing an essay), ensure that references are cited within the body of the text, such as [citation:3][citation:5], rather than only at the end of the text. You need to interpret and summarize the user's requirements, choose an appropriate format, fully utilize the search results, extract key information, and generate an answer that is insightful, creative, and professional. Extend the length of your response as much as possible, addressing each point in detail and from multiple perspectives, ensuring the content is rich and thorough.
187
+ - If the response is lengthy, structure it well and summarize it in paragraphs. If a point-by-point format is needed, try to limit it to 5 points and merge related content.
188
+ - For objective Q&A, if the answer is very brief, you may add one or two related sentences to enrich the content.
189
+ - Choose an appropriate and visually appealing format for your response based on the user's requirements and the content of the answer, ensuring strong readability.
190
+ - Your answer should synthesize information from multiple relevant webpages and avoid repeatedly citing the same webpage.
191
+ - Unless the user requests otherwise, your response should be in the same language as the user's question.
192
+ # The user's message is:
193
+ '''
194
+ else:
195
+ enriched = system_prompt
196
+
197
+ pipe = load_pipeline(model_name)
198
+
199
+ prompt = format_conversation(history, enriched, pipe.tokenizer)
200
+ prompt_debug = f"\n\n--- Prompt Preview ---\n```\n{prompt}\n```"
201
+ streamer = TextIteratorStreamer(pipe.tokenizer,
202
+ skip_prompt=True,
203
+ skip_special_tokens=True)
204
+ gen_thread = threading.Thread(
205
+ target=pipe,
206
+ args=(prompt,),
207
+ kwargs={
208
+ 'max_new_tokens': max_tokens,
209
+ 'temperature': temperature,
210
+ 'top_k': top_k,
211
+ 'top_p': top_p,
212
+ 'repetition_penalty': repeat_penalty,
213
+ 'streamer': streamer,
214
+ 'return_full_text': False,
215
+ }
216
+ )
217
+ gen_thread.start()
218
+
219
+ # Buffers for thought vs answer
220
+ thought_buf = ''
221
+ answer_buf = ''
222
+ in_thought = False
223
+ assistant_message_started = False
224
+
225
+ # First yield contains the user message
226
+ yield history, debug
227
+
228
+ # Stream tokens
229
+ for chunk in streamer:
230
+ # Check for cancellation signal
231
+ if cancel_event.is_set():
232
+ if assistant_message_started and history and history[-1]['role'] == 'assistant':
233
+ history[-1]['content'] += " [Generation Canceled]"
234
+ yield history, debug
235
+ break
236
+
237
+ text = chunk
238
+
239
+ # Detect start of thinking
240
+ if not in_thought and '<think>' in text:
241
+ in_thought = True
242
+ history.append({'role': 'assistant', 'content': '', 'metadata': {'title': 'πŸ’­ Thought'}})
243
+ assistant_message_started = True
244
+ after = text.split('<think>', 1)[1]
245
+ thought_buf += after
246
+ if '</think>' in thought_buf:
247
+ before, after2 = thought_buf.split('</think>', 1)
248
+ history[-1]['content'] = before.strip()
249
+ in_thought = False
250
+ answer_buf = after2
251
+ history.append({'role': 'assistant', 'content': answer_buf})
252
+ else:
253
+ history[-1]['content'] = thought_buf
254
+ yield history, debug
255
+ continue
256
+
257
+ if in_thought:
258
+ thought_buf += text
259
+ if '</think>' in thought_buf:
260
+ before, after2 = thought_buf.split('</think>', 1)
261
+ history[-1]['content'] = before.strip()
262
+ in_thought = False
263
+ answer_buf = after2
264
+ history.append({'role': 'assistant', 'content': answer_buf})
265
+ else:
266
+ history[-1]['content'] = thought_buf
267
+ yield history, debug
268
+ continue
269
+
270
+ # Stream answer
271
+ if not assistant_message_started:
272
+ history.append({'role': 'assistant', 'content': ''})
273
+ assistant_message_started = True
274
+
275
+ answer_buf += text
276
+ history[-1]['content'] = answer_buf.strip()
277
+ yield history, debug
278
+
279
+ gen_thread.join()
280
+ yield history, debug + prompt_debug
281
+ except GeneratorExit:
282
+ # Handle cancellation gracefully
283
+ print("Chat response cancelled.")
284
+ # Don't yield anything - let the cancellation propagate
285
+ return
286
+ except Exception as e:
287
+ history.append({'role': 'assistant', 'content': f"Error: {e}"})
288
+ yield history, debug
289
+ finally:
290
+ gc.collect()
291
+
292
+
293
+ def update_default_prompt(enable_search):
294
+ return f"You are a helpful assistant."
295
+
296
+ def update_duration_estimate(model_name, enable_search, max_results, max_chars, max_tokens, search_timeout):
297
+ """Calculate and format the estimated GPU duration for current settings."""
298
+ try:
299
+ dummy_msg, dummy_history, dummy_system_prompt = "", [], ""
300
+ duration = get_duration(dummy_msg, dummy_history, dummy_system_prompt,
301
+ enable_search, max_results, max_chars, model_name,
302
+ max_tokens, 0.7, 40, 0.9, 1.2, search_timeout)
303
+ model_size = MODELS[model_name].get("params_b", 4.0)
304
+ return (f"⏱️ **Estimated GPU Time: {duration:.1f} seconds**\n\n"
305
+ f"πŸ“Š **Model Size:** {model_size:.1f}B parameters\n"
306
+ f"πŸ” **Web Search:** {'Enabled' if enable_search else 'Disabled'}")
307
+ except Exception as e:
308
+ return f"⚠️ Error calculating estimate: {e}"
309
+
310
+ # ------------------------------
311
+ # Gradio UI
312
+ # ------------------------------
313
+ with gr.Blocks(
314
+ title="LLM Inference",
315
+ theme=gr.themes.Soft(
316
+ primary_hue="blue",
317
+ secondary_hue="blue",
318
+ neutral_hue="slate",
319
+ radius_size="lg",
320
+ font=[gr.themes.GoogleFont("Syne"), "Arial", "sans-serif"]
321
+ ),
322
+ css="""
323
+ .duration-estimate { background: linear-gradient(135deg, #667eea15 0%, #764ba215 100%); border-left: 4px solid #667eea; padding: 12px; border-radius: 8px; margin: 16px 0; }
324
+ .chatbot { border-radius: 12px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1); }
325
+ button.primary { font-weight: 600; }
326
+ .gradio-accordion { margin-bottom: 12px; }
327
+ """
328
+ ) as demo:
329
+ # Header
330
+ gr.Markdown("""
331
+ # 🧠 CPU LLM Inference
332
+ """)
333
+
334
+ with gr.Row():
335
+ # Left Panel - Configuration
336
+ with gr.Column(scale=3):
337
+ # Core Settings (Always Visible)
338
+ with gr.Group():
339
+ gr.Markdown("### βš™οΈ Core Settings")
340
+ model_dd = gr.Dropdown(
341
+ label="πŸ€– Model",
342
+ choices=list(MODELS.keys()),
343
+ value="Qwen3-1.7B",
344
+ info="Select the language model to use"
345
+ )
346
+ search_chk = gr.Checkbox(
347
+ label="πŸ” Enable Web Search",
348
+ value=False,
349
+ info="Augment responses with real-time web data"
350
+ )
351
+ sys_prompt = gr.Textbox(label="πŸ“ System Prompt", lines=3, value=update_default_prompt(search_chk.value), placeholder="Define the assistant's behavior and personality...")
352
+
353
+ # Duration Estimate
354
+ duration_display = gr.Markdown(
355
+ value=update_duration_estimate("Qwen3-1.7B", False, 4, 50, 1024, 5.0),
356
+ elem_classes="duration-estimate"
357
+ )
358
+
359
+ # Advanced Settings (Collapsible)
360
+ with gr.Accordion("πŸŽ›οΈ Advanced Generation Parameters", open=False):
361
+ max_tok = gr.Slider(
362
+ 64, 16384, value=1024, step=32,
363
+ label="Max Tokens",
364
+ info="Maximum length of generated response"
365
+ )
366
+ temp = gr.Slider(
367
+ 0.1, 2.0, value=0.7, step=0.1,
368
+ label="Temperature",
369
+ info="Higher = more creative, Lower = more focused"
370
+ )
371
+ with gr.Row():
372
+ k = gr.Slider(
373
+ 1, 100, value=40, step=1,
374
+ label="Top-K",
375
+ info="Number of top tokens to consider"
376
+ )
377
+ p = gr.Slider(
378
+ 0.1, 1.0, value=0.9, step=0.05,
379
+ label="Top-P",
380
+ info="Nucleus sampling threshold"
381
+ )
382
+ rp = gr.Slider(
383
+ 1.0, 2.0, value=1.2, step=0.1,
384
+ label="Repetition Penalty",
385
+ info="Penalize repeated tokens"
386
+ )
387
+
388
+ # Web Search Settings (Collapsible)
389
+ with gr.Accordion("🌐 Web Search Settings", open=False, visible=False) as search_settings:
390
+ mr = gr.Number(
391
+ value=4, precision=0,
392
+ label="Max Results",
393
+ info="Number of search results to retrieve"
394
+ )
395
+ mc = gr.Number(
396
+ value=50, precision=0,
397
+ label="Max Chars/Result",
398
+ info="Character limit per search result"
399
+ )
400
+ st = gr.Slider(
401
+ minimum=0.0, maximum=30.0, step=0.5, value=5.0,
402
+ label="Search Timeout (s)",
403
+ info="Maximum time to wait for search results"
404
+ )
405
+
406
+ # Actions
407
+ with gr.Row():
408
+ clr = gr.Button("πŸ—‘οΈ Clear Chat", variant="secondary", scale=1)
409
+
410
+ # Right Panel - Chat Interface
411
+ with gr.Column(scale=7):
412
+ chat = gr.Chatbot(
413
+ type="messages",
414
+ height=600,
415
+ label="πŸ’¬ Conversation",
416
+ show_copy_button=True,
417
+ avatar_images=(None, "πŸ€–"),
418
+ bubble_full_width=False
419
+ )
420
+
421
+ # Input Area
422
+ with gr.Row():
423
+ txt = gr.Textbox(
424
+ placeholder="πŸ’­ Type your message here... (Press Enter to send)",
425
+ scale=9,
426
+ container=False,
427
+ show_label=False,
428
+ lines=1,
429
+ max_lines=5
430
+ )
431
+ with gr.Column(scale=1, min_width=120):
432
+ submit_btn = gr.Button("πŸ“€ Send", variant="primary", size="lg")
433
+ cancel_btn = gr.Button("⏹️ Stop", variant="stop", visible=False, size="lg")
434
+
435
+ # Example Prompts
436
+ gr.Examples(
437
+ examples=[
438
+ ["Explain quantum computing in simple terms"],
439
+ ["Write a Python function to calculate fibonacci numbers"],
440
+ ["What are the latest developments in AI? (Enable web search)"],
441
+ ["Tell me a creative story about a time traveler"],
442
+ ["Help me debug this code: def add(a,b): return a+b+1"]
443
+ ],
444
+ inputs=txt,
445
+ label="πŸ’‘ Example Prompts"
446
+ )
447
+
448
+ # Debug/Status Info (Collapsible)
449
+ with gr.Accordion("πŸ” Debug Info", open=False):
450
+ dbg = gr.Markdown()
451
+
452
+ # Footer
453
+ gr.Markdown("""
454
+ ---
455
+ πŸ’‘ **Tips:**
456
+ - Use **Advanced Parameters** to fine-tune creativity and response length
457
+ - Enable **Web Search** for real-time, up-to-date information
458
+ - Try different **models** for various tasks (reasoning, coding, general chat)
459
+ - Click the **Copy** button on responses to save them to your clipboard
460
+ """, elem_classes="footer")
461
+
462
+ # --- Event Listeners ---
463
+
464
+ # Group all inputs for cleaner event handling
465
+ chat_inputs = [txt, chat, sys_prompt, search_chk, mr, mc, model_dd, max_tok, temp, k, p, rp, st]
466
+ # Group all UI components that can be updated.
467
+ ui_components = [chat, dbg, txt, submit_btn, cancel_btn]
468
+
469
+ def submit_and_manage_ui(user_msg, chat_history, *args):
470
+ """
471
+ Orchestrator function that manages UI state and calls the backend chat function.
472
+ It uses a try...finally block to ensure the UI is always reset.
473
+ """
474
+ if not user_msg.strip():
475
+ # If the message is empty, do nothing.
476
+ # We yield an empty dict to avoid any state changes.
477
+ yield {}
478
+ return
479
+
480
+ # 1. Update UI to "generating" state.
481
+ # Crucially, we do NOT update the `chat` component here, as the backend
482
+ # will provide the correctly formatted history in the first response chunk.
483
+ yield {
484
+ txt: gr.update(value="", interactive=False),
485
+ submit_btn: gr.update(interactive=False),
486
+ cancel_btn: gr.update(visible=True),
487
+ }
488
+
489
+ cancelled = False
490
+ try:
491
+ # 2. Call the backend and stream updates
492
+ backend_args = [user_msg, chat_history] + list(args)
493
+ for response_chunk in chat_response(*backend_args):
494
+ yield {
495
+ chat: response_chunk[0],
496
+ dbg: response_chunk[1],
497
+ }
498
+ except GeneratorExit:
499
+ # Mark as cancelled and re-raise to prevent "generator ignored GeneratorExit"
500
+ cancelled = True
501
+ print("Generation cancelled by user.")
502
+ raise
503
+ except Exception as e:
504
+ print(f"An error occurred during generation: {e}")
505
+ # If an error happens, add it to the chat history to inform the user.
506
+ error_history = (chat_history or []) + [
507
+ {'role': 'user', 'content': user_msg},
508
+ {'role': 'assistant', 'content': f"**An error occurred:** {str(e)}"}
509
+ ]
510
+ yield {chat: error_history}
511
+ finally:
512
+ # Only reset UI if not cancelled (to avoid "generator ignored GeneratorExit")
513
+ if not cancelled:
514
+ print("Resetting UI state.")
515
+ yield {
516
+ txt: gr.update(interactive=True),
517
+ submit_btn: gr.update(interactive=True),
518
+ cancel_btn: gr.update(visible=False),
519
+ }
520
+
521
+ def set_cancel_flag():
522
+ """Called by the cancel button, sets the global event."""
523
+ cancel_event.set()
524
+ print("Cancellation signal sent.")
525
+
526
+ def reset_ui_after_cancel():
527
+ """Reset UI components after cancellation."""
528
+ cancel_event.clear() # Clear the flag for next generation
529
+ print("UI reset after cancellation.")
530
+ return {
531
+ txt: gr.update(interactive=True),
532
+ submit_btn: gr.update(interactive=True),
533
+ cancel_btn: gr.update(visible=False),
534
+ }
535
+
536
+ # Event for submitting text via Enter key or Submit button
537
+ submit_event = txt.submit(
538
+ fn=submit_and_manage_ui,
539
+ inputs=chat_inputs,
540
+ outputs=ui_components,
541
+ )
542
+ submit_btn.click(
543
+ fn=submit_and_manage_ui,
544
+ inputs=chat_inputs,
545
+ outputs=ui_components,
546
+ )
547
+
548
+ # Event for the "Cancel" button.
549
+ # It sets the cancel flag, cancels the submit event, then resets the UI.
550
+ cancel_btn.click(
551
+ fn=set_cancel_flag,
552
+ cancels=[submit_event]
553
+ ).then(
554
+ fn=reset_ui_after_cancel,
555
+ outputs=ui_components
556
+ )
557
+
558
+ # Listeners for updating the duration estimate
559
+ duration_inputs = [model_dd, search_chk, mr, mc, max_tok, st]
560
+ for component in duration_inputs:
561
+ component.change(fn=update_duration_estimate, inputs=duration_inputs, outputs=duration_display)
562
+
563
+ # Toggle web search settings visibility
564
+ def toggle_search_settings(enabled):
565
+ return gr.update(visible=enabled)
566
+
567
+ search_chk.change(
568
+ fn=lambda enabled: (update_default_prompt(enabled), gr.update(visible=enabled)),
569
+ inputs=search_chk,
570
+ outputs=[sys_prompt, search_settings]
571
+ )
572
+
573
+ # Clear chat action
574
+ clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])
575
+
576
+ demo.launch()