R-Kentaren committed on
Commit
0ff1c7b
·
verified ·
1 Parent(s): 7fa6b40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -94
app.py CHANGED
@@ -5,20 +5,18 @@ import sys
5
  import threading
6
  from itertools import islice
7
  from datetime import datetime
8
- import re # for parsing <think> blocks
9
  import gradio as gr
10
  import torch
11
- from transformers import pipeline, TextIteratorStreamer, StoppingCriteria
12
  from transformers import AutoTokenizer
13
  from ddgs import DDGS
14
- from torch.utils._pytree import tree_map
15
- from config import *
16
  # Global event to signal cancellation from the UI thread to the generation thread
17
  cancel_event = threading.Event()
18
 
19
- access_token=os.environ['HF_TOKEN']
20
-
21
-
22
 
23
  # Global cache for pipelines to avoid re-loading.
24
  PIPELINES = {}
@@ -32,8 +30,7 @@ def load_pipeline(model_name):
32
  if model_name in PIPELINES:
33
  return PIPELINES[model_name]
34
  repo = MODELS[model_name]["repo_id"]
35
- tokenizer = AutoTokenizer.from_pretrained(repo,
36
- token=access_token)
37
  for dtype in (torch.bfloat16, torch.float16, torch.float32):
38
  try:
39
  pipe = pipeline(
@@ -41,9 +38,9 @@ def load_pipeline(model_name):
41
  model=repo,
42
  tokenizer=tokenizer,
43
  trust_remote_code=True,
44
- dtype=dtype, # Use `dtype` instead of deprecated `torch_dtype`
45
  device_map="auto",
46
- use_cache=True, # Enable past-key-value caching
47
  token=access_token)
48
  PIPELINES[model_name] = pipe
49
  return pipe
@@ -61,7 +58,6 @@ def load_pipeline(model_name):
61
  PIPELINES[model_name] = pipe
62
  return pipe
63
 
64
-
65
  def retrieve_context(query, max_results=6, max_chars=50):
66
  """
67
  Retrieve search snippets from DuckDuckGo (runs in background).
@@ -91,20 +87,23 @@ def format_conversation(history, system_prompt, tokenizer):
91
  return prompt
92
 
93
  def get_duration(user_msg, chat_history, system_prompt, enable_search, max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty, search_timeout):
94
- # Get model size from the MODELS dict (more reliable than string parsing)
95
- model_size = MODELS[model_name].get("params_b", 4.0) # Default to 4B if not found
96
 
97
  # Only use AOT for models >= 2B parameters
98
  use_aot = model_size >= 2
99
 
100
- # Adjusted for H200 performance: faster inference, quicker compilation
101
- base_duration = 20 if not use_aot else 40 # Reduced base times
102
- token_duration = max_tokens * 0.005 # ~200 tokens/second average on H200
103
- search_duration = 10 if enable_search else 0 # Reduced search time
104
- aot_compilation_buffer = 20 if use_aot else 0 # Faster compilation on H200
105
 
106
  return base_duration + token_duration + search_duration + aot_compilation_buffer
107
 
 
 
 
108
 
109
  def chat_response(user_msg, chat_history, system_prompt,
110
  enable_search, max_results, max_chars,
@@ -135,71 +134,53 @@ def chat_response(user_msg, chat_history, system_prompt,
135
  else:
136
  debug = 'Web search disabled.'
137
 
 
 
 
 
 
 
 
 
 
 
138
  try:
139
  cur_date = datetime.now().strftime('%Y-%m-%d')
140
- # merge any fetched search results into the system prompt
141
- if search_results:
142
- enriched = system_prompt.strip() + f"""
143
- # SEARCH CONTEXT (TRUSTED SOURCES ONLY)
144
- Below are web search results. Treat them as the ONLY source of truth for answering.
145
- {search_results}
146
-
147
- RULES (VERY IMPORTANT):
148
- - Do NOT use outside knowledge. Do NOT guess or fill missing information.
149
- - If the answer is not clearly supported by the search results, say: "Not enough information in the provided sources."
150
- - Every factual statement must be directly supported by at least one citation [citation:X].
151
- - Do NOT add explanations, examples, or background that are not explicitly present in the sources.
152
- - Do NOT paraphrase beyond what is necessary for clarity.
153
- - If sources conflict, mention the conflict and cite both.
154
- - If multiple sources are used, distribute citations per sentence, not only at the end.
155
 
156
- CITATION RULES:
157
- - Use inline citations like this: [citation:1]
158
- - If multiple sources support a sentence: [citation:1][citation:3]
159
- - Never place all citations only at the end.
160
-
161
- ANSWER POLICY:
162
- - Be concise and strictly grounded.
163
- - No speculation, no assumptions, no "likely", no "probably".
164
- - If the user requests a list, only include items explicitly found in sources.
165
- - If sources are insufficient, stop and ask for more data instead of guessing.
166
- DATE CONTEXT:
167
- - Today is {cur_date} (use only for time reference, not for assumptions).
168
- USER QUESTION:
169
- """
170
- else:
171
- enriched = system_prompt
172
-
173
- # wait up to 1s for snippets, then replace debug with them
174
- if enable_search:
175
- thread_search.join(timeout=float(search_timeout))
176
- if search_results:
177
- debug = "### Search results merged into prompt\n\n" + "\n".join(
178
- f"- {r}" for r in search_results
179
- )
180
- else:
181
- debug = "*No web search results found.*"
182
-
183
- # merge fetched snippets into the system prompt
184
  if search_results:
185
- enriched = system_prompt.strip() + \
186
- f'''\n# The following contents are the search results related to the user's message:
187
- {search_results}
188
- In the search results I provide to you, each result is formatted as [webpage X begin]...[webpage X end], where X represents the numerical index of each article. Please cite the context at the end of the relevant sentence when appropriate. Use the citation format [citation:X] in the corresponding part of your answer. If a sentence is derived from multiple contexts, list all relevant citation numbers, such as [citation:3][citation:5]. Be sure not to cluster all citations at the end; instead, include them in the corresponding parts of the answer.
189
- When responding, please keep the following points in mind:
190
- - Today is {cur_date}.
191
- - Not all content in the search results is closely related to the user's question. You need to evaluate and filter the search results based on the question.
192
- - For listing-type questions (e.g., listing all flight information), try to limit the answer to 10 key points and inform the user that they can refer to the search sources for complete information. Prioritize providing the most complete and relevant items in the list. Avoid mentioning content not provided in the search results unless necessary.
193
- - For creative tasks (e.g., writing an essay), ensure that references are cited within the body of the text, such as [citation:3][citation:5], rather than only at the end of the text. You need to interpret and summarize the user's requirements, choose an appropriate format, fully utilize the search results, extract key information, and generate an answer that is insightful, creative, and professional. Extend the length of your response as much as possible, addressing each point in detail and from multiple perspectives, ensuring the content is rich and thorough.
194
- - If the response is lengthy, structure it well and summarize it in paragraphs. If a point-by-point format is needed, try to limit it to 5 points and merge related content.
195
- - For objective Q&A, if the answer is very brief, you may add one or two related sentences to enrich the content.
196
- - Choose an appropriate and visually appealing format for your response based on the user's requirements and the content of the answer, ensuring strong readability.
197
- - Your answer should synthesize information from multiple relevant webpages and avoid repeatedly citing the same webpage.
198
- - Unless the user requests otherwise, your response should be in the same language as the user's question.
199
- # The user's message is:
200
- '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  else:
202
- enriched = system_prompt
203
 
204
  pipe = load_pipeline(model_name)
205
 
@@ -288,7 +269,6 @@ def chat_response(user_msg, chat_history, system_prompt,
288
  except GeneratorExit:
289
  # Handle cancellation gracefully
290
  print("Chat response cancelled.")
291
- # Don't yield anything - let the cancellation propagate
292
  return
293
  except Exception as e:
294
  history.append({'role': 'assistant', 'content': f"Error: {e}"})
@@ -296,7 +276,6 @@ def chat_response(user_msg, chat_history, system_prompt,
296
  finally:
297
  gc.collect()
298
 
299
-
300
  def update_default_prompt(enable_search):
301
  return f"You are a helpful assistant."
302
 
@@ -307,7 +286,7 @@ def update_duration_estimate(model_name, enable_search, max_results, max_chars,
307
  duration = get_duration(dummy_msg, dummy_history, dummy_system_prompt,
308
  enable_search, max_results, max_chars, model_name,
309
  max_tokens, 0.7, 40, 0.9, 1.2, search_timeout)
310
- model_size = MODELS[model_name].get("params_b", 4.0)
311
  return (f"⏱️ **Estimated GPU Time: {duration:.1f} seconds**\n\n"
312
  f"📊 **Model Size:** {model_size:.1f}B parameters\n"
313
  f"🔍 **Web Search:** {'Enabled' if enable_search else 'Disabled'}")
@@ -355,7 +334,7 @@ with gr.Blocks(
355
  value=False,
356
  info="Augment responses with real-time web data"
357
  )
358
- sys_prompt = gr.Textbox(label="📝 System Prompt", lines=3, value=update_default_prompt(search_chk.value), placeholder="Define the assistant's behavior and personality...")
359
 
360
  # Duration Estimate
361
  duration_display = gr.Markdown(
@@ -479,14 +458,10 @@ with gr.Blocks(
479
  It uses a try...finally block to ensure the UI is always reset.
480
  """
481
  if not user_msg.strip():
482
- # If the message is empty, do nothing.
483
- # We yield an empty dict to avoid any state changes.
484
  yield {}
485
  return
486
 
487
- # 1. Update UI to "generating" state.
488
- # Crucially, we do NOT update the `chat` component here, as the backend
489
- # will provide the correctly formatted history in the first response chunk.
490
  yield {
491
  txt: gr.update(value="", interactive=False),
492
  submit_btn: gr.update(interactive=False),
@@ -495,7 +470,6 @@ with gr.Blocks(
495
 
496
  cancelled = False
497
  try:
498
- # 2. Call the backend and stream updates
499
  backend_args = [user_msg, chat_history] + list(args)
500
  for response_chunk in chat_response(*backend_args):
501
  yield {
@@ -503,20 +477,17 @@ with gr.Blocks(
503
  dbg: response_chunk[1],
504
  }
505
  except GeneratorExit:
506
- # Mark as cancelled and re-raise to prevent "generator ignored GeneratorExit"
507
  cancelled = True
508
  print("Generation cancelled by user.")
509
  raise
510
  except Exception as e:
511
  print(f"An error occurred during generation: {e}")
512
- # If an error happens, add it to the chat history to inform the user.
513
  error_history = (chat_history or []) + [
514
  {'role': 'user', 'content': user_msg},
515
  {'role': 'assistant', 'content': f"**An error occurred:** {str(e)}"}
516
  ]
517
  yield {chat: error_history}
518
  finally:
519
- # Only reset UI if not cancelled (to avoid "generator ignored GeneratorExit")
520
  if not cancelled:
521
  print("Resetting UI state.")
522
  yield {
@@ -532,7 +503,7 @@ with gr.Blocks(
532
 
533
  def reset_ui_after_cancel():
534
  """Reset UI components after cancellation."""
535
- cancel_event.clear() # Clear the flag for next generation
536
  print("UI reset after cancellation.")
537
  return {
538
  txt: gr.update(interactive=True),
@@ -553,7 +524,6 @@ with gr.Blocks(
553
  )
554
 
555
  # Event for the "Cancel" button.
556
- # It sets the cancel flag, cancels the submit event, then resets the UI.
557
  cancel_btn.click(
558
  fn=set_cancel_flag,
559
  cancels=[submit_event]
 
5
  import threading
6
  from itertools import islice
7
  from datetime import datetime
8
+ import re
9
  import gradio as gr
10
  import torch
11
+ from transformers import pipeline, TextIteratorStreamer
12
  from transformers import AutoTokenizer
13
  from ddgs import DDGS
14
+ from config import MODELS # Import from config file
15
+
16
  # Global event to signal cancellation from the UI thread to the generation thread
17
  cancel_event = threading.Event()
18
 
19
+ access_token = os.environ.get('HF_TOKEN', '')
 
 
20
 
21
  # Global cache for pipelines to avoid re-loading.
22
  PIPELINES = {}
 
30
  if model_name in PIPELINES:
31
  return PIPELINES[model_name]
32
  repo = MODELS[model_name]["repo_id"]
33
+ tokenizer = AutoTokenizer.from_pretrained(repo, token=access_token)
 
34
  for dtype in (torch.bfloat16, torch.float16, torch.float32):
35
  try:
36
  pipe = pipeline(
 
38
  model=repo,
39
  tokenizer=tokenizer,
40
  trust_remote_code=True,
41
+ dtype=dtype,
42
  device_map="auto",
43
+ use_cache=True,
44
  token=access_token)
45
  PIPELINES[model_name] = pipe
46
  return pipe
 
58
  PIPELINES[model_name] = pipe
59
  return pipe
60
 
 
61
  def retrieve_context(query, max_results=6, max_chars=50):
62
  """
63
  Retrieve search snippets from DuckDuckGo (runs in background).
 
87
  return prompt
88
 
89
  def get_duration(user_msg, chat_history, system_prompt, enable_search, max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty, search_timeout):
90
+ # Get model size from the MODELS dict
91
+ model_size = MODELS[model_name].get("params_b", 4.0)
92
 
93
  # Only use AOT for models >= 2B parameters
94
  use_aot = model_size >= 2
95
 
96
+ # Adjusted for H200 performance
97
+ base_duration = 20 if not use_aot else 40
98
+ token_duration = max_tokens * 0.005
99
+ search_duration = 10 if enable_search else 0
100
+ aot_compilation_buffer = 20 if use_aot else 0
101
 
102
  return base_duration + token_duration + search_duration + aot_compilation_buffer
103
 
104
+ def get_model_size(model_name):
105
+ """Get model size from the MODELS dict."""
106
+ return MODELS.get(model_name, {}).get("params_b", 4.0)
107
 
108
  def chat_response(user_msg, chat_history, system_prompt,
109
  enable_search, max_results, max_chars,
 
134
  else:
135
  debug = 'Web search disabled.'
136
 
137
+ # Wait for search results if enabled
138
+ if enable_search:
139
+ thread_search.join(timeout=float(search_timeout))
140
+ if search_results:
141
+ debug = "### Search results merged into prompt\n\n" + "\n".join(
142
+ f"- {r}" for r in search_results
143
+ )
144
+ else:
145
+ debug = "*No web search results found.*"
146
+
147
  try:
148
  cur_date = datetime.now().strftime('%Y-%m-%d')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
+ # Prepare enriched system prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  if search_results:
152
+ enriched = system_prompt.strip() + f"""
153
+ # SEARCH CONTEXT (TRUSTED SOURCES ONLY)
154
+ Below are web search results. Treat them as the ONLY source of truth for answering.
155
+ {search_results}
156
+
157
+ RULES (VERY IMPORTANT):
158
+ - Do NOT use outside knowledge. Do NOT guess or fill missing information.
159
+ - If the answer is not clearly supported by the search results, say: "Not enough information in the provided sources."
160
+ - Every factual statement must be directly supported by at least one citation [citation:X].
161
+ - Do NOT add explanations, examples, or background that are not explicitly present in the sources.
162
+ - Do NOT paraphrase beyond what is necessary for clarity.
163
+ - If sources conflict, mention the conflict and cite both.
164
+ - If multiple sources are used, distribute citations per sentence, not only at the end.
165
+
166
+ CITATION RULES:
167
+ - Use inline citations like this: [citation:1]
168
+ - If multiple sources support a sentence: [citation:1][citation:3]
169
+ - Never place all citations only at the end.
170
+
171
+ ANSWER POLICY:
172
+ - Be concise and strictly grounded.
173
+ - No speculation, no assumptions, no "likely", no "probably".
174
+ - If the user requests a list, only include items explicitly found in sources.
175
+ - If sources are insufficient, stop and ask for more data instead of guessing.
176
+
177
+ DATE CONTEXT:
178
+ - Today is {cur_date} (use only for time reference, not for assumptions).
179
+
180
+ USER QUESTION:
181
+ """
182
  else:
183
+ enriched = system_prompt.strip()
184
 
185
  pipe = load_pipeline(model_name)
186
 
 
269
  except GeneratorExit:
270
  # Handle cancellation gracefully
271
  print("Chat response cancelled.")
 
272
  return
273
  except Exception as e:
274
  history.append({'role': 'assistant', 'content': f"Error: {e}"})
 
276
  finally:
277
  gc.collect()
278
 
 
279
  def update_default_prompt(enable_search):
280
  return f"You are a helpful assistant."
281
 
 
286
  duration = get_duration(dummy_msg, dummy_history, dummy_system_prompt,
287
  enable_search, max_results, max_chars, model_name,
288
  max_tokens, 0.7, 40, 0.9, 1.2, search_timeout)
289
+ model_size = get_model_size(model_name)
290
  return (f"⏱️ **Estimated GPU Time: {duration:.1f} seconds**\n\n"
291
  f"📊 **Model Size:** {model_size:.1f}B parameters\n"
292
  f"🔍 **Web Search:** {'Enabled' if enable_search else 'Disabled'}")
 
334
  value=False,
335
  info="Augment responses with real-time web data"
336
  )
337
+ sys_prompt = gr.Textbox(label="📝 System Prompt", lines=3, value=update_default_prompt(False), placeholder="Define the assistant's behavior and personality...")
338
 
339
  # Duration Estimate
340
  duration_display = gr.Markdown(
 
458
  It uses a try...finally block to ensure the UI is always reset.
459
  """
460
  if not user_msg.strip():
 
 
461
  yield {}
462
  return
463
 
464
+ # Update UI to "generating" state
 
 
465
  yield {
466
  txt: gr.update(value="", interactive=False),
467
  submit_btn: gr.update(interactive=False),
 
470
 
471
  cancelled = False
472
  try:
 
473
  backend_args = [user_msg, chat_history] + list(args)
474
  for response_chunk in chat_response(*backend_args):
475
  yield {
 
477
  dbg: response_chunk[1],
478
  }
479
  except GeneratorExit:
 
480
  cancelled = True
481
  print("Generation cancelled by user.")
482
  raise
483
  except Exception as e:
484
  print(f"An error occurred during generation: {e}")
 
485
  error_history = (chat_history or []) + [
486
  {'role': 'user', 'content': user_msg},
487
  {'role': 'assistant', 'content': f"**An error occurred:** {str(e)}"}
488
  ]
489
  yield {chat: error_history}
490
  finally:
 
491
  if not cancelled:
492
  print("Resetting UI state.")
493
  yield {
 
503
 
504
  def reset_ui_after_cancel():
505
  """Reset UI components after cancellation."""
506
+ cancel_event.clear()
507
  print("UI reset after cancellation.")
508
  return {
509
  txt: gr.update(interactive=True),
 
524
  )
525
 
526
  # Event for the "Cancel" button.
 
527
  cancel_btn.click(
528
  fn=set_cancel_flag,
529
  cancels=[submit_event]