jedick committed
Commit 03db0de · 1 Parent(s): 0efb496

Normalize message types for Gemma

Files changed (5)
  1. app.py +157 -145
  2. graph.py +70 -40
  3. main.py +12 -7
  4. prompts.py +32 -14
  5. util.py +20 -5
app.py CHANGED
@@ -3,9 +3,9 @@ from main import GetChatModel
  from graph import BuildGraph
  from retriever import db_dir
  from langgraph.checkpoint.memory import MemorySaver
- from dotenv import load_dotenv

- # from util import get_collection, get_sources, get_start_end_months
+ from main import openai_model, model_id
+ from util import get_sources, get_start_end_months
  from git import Repo
  import zipfile
  import spaces
@@ -18,11 +18,6 @@ import os
  COMPUTE = "cloud"
  search_type = "hybrid"

- # Load LANGCHAIN_API_KEY (for local deployment)
- load_dotenv(dotenv_path=".env", override=True)
- os.environ["LANGSMITH_TRACING"] = "true"
- os.environ["LANGSMITH_PROJECT"] = "R-help-chat"
-
  # Check for GPU
  if COMPUTE == "edge":
  if not torch.cuda.is_available():
@@ -33,7 +28,7 @@ graph_edge = None
  graph_cloud = None


- def run_workflow(chatbot, input, thread_id):
+ def run_workflow(input, history, thread_id):
  """The main function to run the chat workflow"""

  # Get global graph for compute location
@@ -69,10 +64,10 @@ def run_workflow(chatbot, input, thread_id):

  print(f"Using thread_id: {thread_id}")

- # Display the user input in the chatbot interface
- chatbot.append(gr.ChatMessage(role="user", content=input))
- # Return the chatbot messages and empty lists for emails and citations texboxes
- yield chatbot, [], []
+ # # Display the user input in the history
+ # history.append(gr.ChatMessage(role="user", content=input))
+ # # Return the history and empty lists for emails and citations texboxes
+ # yield history, [], []

  # Asynchronously stream graph steps for a single input
  # https://langchain-ai.lang.chat/langgraph/reference/graphs/#langgraph.graph.state.CompiledStateGraph
@@ -101,7 +96,7 @@ def run_workflow(chatbot, input, thread_id):
  content = f"{content} ({start_year or ''} - {end_year or ''})"
  if "months" in args:
  content = f"{content} {args['months']}"
- chatbot.append(
+ history.append(
  gr.ChatMessage(
  role="assistant",
  content=content,
@@ -109,10 +104,10 @@ def run_workflow(chatbot, input, thread_id):
  )
  )
  if chunk_messages.content:
- chatbot.append(
+ history.append(
  gr.ChatMessage(role="assistant", content=chunk_messages.content)
  )
- yield chatbot, [], []
+ yield history, [], []

  if node == "retrieve_emails":
  chunk_messages = chunk["messages"]
@@ -136,7 +131,7 @@ def run_workflow(chatbot, input, thread_id):
  title = f"🛒 Retrieved {n_emails} emails"
  if email_list[0] == "### No emails were retrieved":
  title = "❌ Retrieved 0 emails"
- chatbot.append(
+ history.append(
  gr.ChatMessage(
  role="assistant",
  content=month_text,
@@ -152,17 +147,17 @@ def run_workflow(chatbot, input, thread_id):
  )
  # Combine all the Tool Call results
  retrieved_emails = "\n\n".join(retrieved_emails)
- yield chatbot, retrieved_emails, []
+ yield history, retrieved_emails, []

  if node == "generate":
  chunk_messages = chunk["messages"]
  # Chat response without citations
  if chunk_messages.content:
- chatbot.append(
+ history.append(
  gr.ChatMessage(role="assistant", content=chunk_messages.content)
  )
  # None is used for no change to the retrieved emails textbox
- yield chatbot, None, []
+ yield history, None, []

  if node == "answer_with_citations":
  chunk_messages = chunk["messages"][0]
@@ -174,8 +169,8 @@ def run_workflow(chatbot, input, thread_id):
  answer = chunk_messages.content
  citations = None

- chatbot.append(gr.ChatMessage(role="assistant", content=answer))
- yield chatbot, None, citations
+ history.append(gr.ChatMessage(role="assistant", content=answer))
+ yield history, None, citations


  def to_workflow(*args):
@@ -230,12 +225,6 @@ with gr.Blocks(
  render=False,
  )

- input = gr.Textbox(
- lines=1,
- label="Your Question",
- info="Press Enter to submit",
- render=False,
- )
  downloading = gr.Textbox(
  lines=1,
  label="Downloading Data, Please Wait",
@@ -248,6 +237,13 @@ with gr.Blocks(
  visible=False,
  render=False,
  )
+ data_error = gr.Textbox(
+ value="App is unavailable. Please contact the maintainer.",
+ lines=1,
+ label="Error downloading or extracting data",
+ visible=False,
+ render=False,
+ )
  show_examples = gr.Checkbox(
  value=False,
  label="💡 Example Questions",
@@ -268,142 +264,142 @@ with gr.Blocks(
  render=False,
  )

+ # ------------
+ # Set up state
+ # ------------
+
+ def generate_thread_id():
+ """Generate a new thread ID"""
+ thread_id = uuid.uuid4()
+ print(f"Generated thread_id: {thread_id}")
+ return thread_id
+
+ # Define thread_id variable
+ thread_id = gr.State(generate_thread_id())
+
+ # Define states for the output textboxes
+ retrieved_emails = gr.State([])
+ citations_text = gr.State([])
+
  # ------------------
  # Make the interface
  # ------------------

  def get_intro_text():
- ## Get start and end months from database
- # start, end = get_start_end_months(get_sources(compute_location.value))
  intro = f"""<!-- # 🤖 R-help-chat -->
+ <!-- Get AI-powered answers about R programming backed by email retrieval. -->
  ## 🇷🤝💬 R-help-chat

- **Chat with the [R-help mailing list archives]((https://stat.ethz.ch/pipermail/r-help/)).** Get AI-powered answers about R programming backed by email retrieval.
- An LLM turns your question into a search query, including year ranges.
+ **Chat with the [R-help mailing list archives]((https://stat.ethz.ch/pipermail/r-help/)).**
+ An LLM turns your question into a search query, including year ranges, and generates an answer from the retrieved emails.
  You can ask follow-up questions with the chat history as context.
- ➡️ To clear the chat history and start a new chat, press the 🗑️ trash button.<br>
+ ➡️ To clear the history and start a new chat, press the 🗑️ trash button.<br>
  **_Answers may be incorrect._**<br>
  """
  return intro

- def get_info_text(compute_location):
- info_prefix = """
- **Features:** conversational RAG, today's date, email database (*start* to *end*), hybrid search (dense+sparse),
- query analysis, multiple tool calls (cloud model), answer with citations.
- **Tech:** LangChain + Hugging Face + Gradio; ChromaDB and BM25S-based retrievers.<br>
- """
+ def get_status_text(compute_location):
  if compute_location.startswith("cloud"):
- info_text = f"""{info_prefix}
+ status_text = f"""
  📍 This is the **cloud** version, using the OpenAI API<br>
- gpt-4o-mini<br>
- ⚠️ **_Privacy Notice_**: Data sharing with OpenAI is enabled, and all interactions are logged<br>
+ text-embedding-3-small and {openai_model}<br>
+ ⚠️ **_Privacy Notice_**: Data sharing with OpenAI is enabled<br>
  🏠 See the project's [GitHub repository](https://github.com/jedick/R-help-chat)
  """
  if compute_location.startswith("edge"):
- info_text = f"""{info_prefix}
- 📍 This is the **edge** version, using [ZeroGPU](https://huggingface.co/docs/hub/spaces-zerogpu) hardware<br>
- ✨ Nomic embeddings and Gemma-3 LLM<br>
- ⚠️ **_Privacy Notice_**: All interactions are logged<br>
+ status_text = f"""
+ 📍 This is the **edge** version, using ZeroGPU hardware<br>
+ Embeddings: [Nomic](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5); LLM: [{model_id}](https://huggingface.co/{model_id})<br>
  🏠 See the project's [GitHub repository](https://github.com/jedick/R-help-chat)
  """
+ return status_text
+
+ def get_info_text():
+ try:
+ # Get source files for each email and start and end months from database
+ sources = get_sources()
+ start, end = get_start_end_months(sources)
+ except:
+ # If database isn't ready, put in empty values
+ sources = []
+ start = None
+ end = None
+ info_text = f"""
+ **Database:** {len(sources)} emails from {start} to {end}.
+ **Features:** RAG, today's date, hybrid search (dense+sparse), query analysis,
+ multiple tool calls (cloud model), answer with citations, chat memory.
+ **Tech:** LangChain + Hugging Face + Gradio; ChromaDB and [BM25S](https://github.com/xhluca/bm25s)-based retrievers.<br>
+ """
  return info_text

- with gr.Row(elem_classes=["row-container"]):
+ with gr.Row():
+ # Left column: Intro, Compute, Chat, Emails
  with gr.Column(scale=2):
  with gr.Row(elem_classes=["row-container"]):
  with gr.Column(scale=2):
  intro = gr.Markdown(get_intro_text())
  with gr.Column(scale=1):
  compute_location.render()
- input.render()
+ chat_interface = gr.ChatInterface(
+ to_workflow,
+ chatbot=chatbot,
+ type="messages",
+ additional_inputs=[thread_id],
+ additional_outputs=[retrieved_emails, citations_text],
+ api_name=False,
+ )
  downloading.render()
  extracting.render()
- with gr.Column(scale=1):
- # Add information about the system
- with gr.Accordion("ℹ️ About This App", open=True):
- ## Get number of emails (unique doc ids) in vector database
- # collection = get_collection(compute_location.value)
- # n_emails = len(set([m["doc_id"] for m in collection["metadatas"]]))
- # gr.Markdown(
- # f"""
- # - **Database**: *n_emails* emails from the [R-help mailing list archives](https://stat.ethz.ch/pipermail/r-help/)
- # - **System**: retrieval and citation tools; system prompt has today's date
- # - **Retrieval**: hybrid of dense (vector embeddings) and sparse ([BM25S](https://github.com/xhluca/bm25s))
- # """
- # )
- info = gr.Markdown(get_info_text(compute_location.value))
- show_examples.render()
-
- with gr.Row():
-
- with gr.Column(scale=2):
- chatbot.render()
-
- with gr.Column(scale=1, visible=False) as examples:
- # Add some helpful examples
- example_questions = [
- # "What is today's date?",
- "Summarize emails from the last two months",
- "What plotmath examples have been discussed?",
- "When was has.HLC mentioned?",
- "Who discussed profiling in 2023?",
- "Any messages about installation problems in 2023-2024?",
- ]
- gr.Examples(
- examples=[[q] for q in example_questions],
- inputs=[input],
- label="Click an example to fill the question box",
- elem_id="example-questions",
- )
- multi_tool_questions = [
- "Speed differences between lapply and for loops",
- "Compare usage of pipe operator between 2022 and 2024",
- ]
- gr.Examples(
- examples=[[q] for q in multi_tool_questions],
- inputs=[input],
- label="Example prompts for multiple retrievals",
- elem_id="example-questions",
- )
- multi_turn_questions = [
- "Lookup emails that reference bugs.r-project.org in 2025",
- "Did those authors report bugs before 2025?",
- ]
- gr.Examples(
- examples=[[q] for q in multi_turn_questions],
- inputs=[input],
- label="Multi-turn example for asking follow-up questions",
- elem_id="example-questions",
- )
-
- with gr.Row():
- with gr.Column(scale=2):
+ data_error.render()
  emails_textbox = gr.Textbox(
  label="Retrieved Emails",
  lines=10,
  visible=False,
  info="Tip: Look for 'Tool Call' and 'Next Email' separators. Quoted lines (starting with '>') are removed before indexing.",
  )
- with gr.Column():
+ # Right column: Info, Examples, Citations
+ with gr.Column(scale=1):
+ status = gr.Markdown(get_status_text(compute_location.value))
+ with gr.Accordion("ℹ️ More Info", open=False):
+ info = gr.Markdown(get_info_text())
+ with gr.Accordion("💡 Examples", open=True):
+ # Add some helpful examples
+ example_questions = [
+ # "What is today's date?",
+ "Summarize emails from the last two months",
+ "What plotmath examples have been discussed?",
+ "When was has.HLC mentioned?",
+ "Who discussed profiling in 2023?",
+ "Any messages about installation problems in 2023-2024?",
+ ]
+ gr.Examples(
+ examples=[[q] for q in example_questions],
+ inputs=[chat_interface.textbox],
+ label="Click an example to fill the message box",
+ elem_id="example-questions",
+ )
+ multi_tool_questions = [
+ "Differences between lapply and for loops",
+ "Compare usage of pipe operator between 2022 and 2024",
+ ]
+ gr.Examples(
+ examples=[[q] for q in multi_tool_questions],
+ inputs=[chat_interface.textbox],
+ label="Prompts for multiple retrievals",
+ elem_id="example-questions",
+ )
+ multi_turn_questions = [
+ "Lookup emails that reference bugs.r-project.org in 2025",
+ "Did those authors report bugs before 2025?",
+ ]
+ gr.Examples(
+ examples=[[q] for q in multi_turn_questions],
+ inputs=[chat_interface.textbox],
+ label="Asking follow-up questions",
+ elem_id="example-questions",
+ )
  citations_textbox = gr.Textbox(label="Citations", lines=2, visible=False)

- # ------------
- # Set up state
- # ------------
-
- def generate_thread_id():
- """Generate a new thread ID"""
- thread_id = uuid.uuid4()
- print(f"Generated thread_id: {thread_id}")
- return thread_id
-
- # Define thread_id variable
- thread_id = gr.State(generate_thread_id())
-
- # Define states for the output textboxes
- retrieved_emails = gr.State([])
- citations_text = gr.State([])
-
  # -------------
  # App functions
  # -------------
@@ -458,16 +454,26 @@ with gr.Blocks(
  # https://github.com/gradio-app/gradio/issues/9722
  chatbot.clear(generate_thread_id, outputs=[thread_id], api_name=False)

+ def clear_component(component):
+ """Return cleared component"""
+ return component.clear()
+
  compute_location.change(
  # Update global COMPUTE variable
  set_compute,
  [compute_location],
  api_name=False,
  ).then(
- # Change the info text
- get_info_text,
+ # Change the app status text
+ get_status_text,
  [compute_location],
- [info],
+ [status],
+ api_name=False,
+ ).then(
+ # Clear the chatbot history
+ clear_component,
+ [chatbot],
+ [chatbot],
  api_name=False,
  ).then(
  # Change the chatbot avatar
@@ -475,21 +481,10 @@ with gr.Blocks(
  [compute_location],
  [chatbot],
  api_name=False,
- )
-
- show_examples.change(
- # Show examples
- change_visibility,
- [show_examples],
- [examples],
- api_name=False,
- )
-
- input.submit(
- # Submit input to the chatbot
- to_workflow,
- [chatbot, input, thread_id],
- [chatbot, retrieved_emails, citations_text],
+ ).then(
+ # Start a new thread
+ generate_thread_id,
+ outputs=[thread_id],
  api_name=False,
  )
@@ -558,13 +553,20 @@ with gr.Blocks(
  need_data = gr.State()
  have_data = gr.State()

+ # When app is launched, check if data is present, download it if necessary,
+ # hide chat interface during downloading, show downloading and extracting
+ # steps as textboxes, show error textbox if needed, restore chat interface,
+ # and show database info
+
  # fmt: off
  demo.load(
  is_data_missing, None, [need_data], api_name=False
  ).then(
  is_data_present, None, [have_data], api_name=False
  ).then(
- change_visibility, [have_data], [input], api_name=False
+ change_visibility, [have_data], [chatbot], api_name=False
+ ).then(
+ change_visibility, [have_data], [chat_interface.textbox], api_name=False
  ).then(
  change_visibility, [need_data], [downloading], api_name=False
  ).then(
@@ -578,7 +580,17 @@ with gr.Blocks(
  ).then(
  change_visibility, [false], [extracting], api_name=False
  ).then(
- change_visibility, [true], [input], api_name=False
+ is_data_missing, None, [need_data], api_name=False
+ ).then(
+ is_data_present, None, [have_data], api_name=False
+ ).then(
+ change_visibility, [have_data], [chatbot], api_name=False
+ ).then(
+ change_visibility, [have_data], [chat_interface.textbox], api_name=False
+ ).then(
+ change_visibility, [need_data], [data_error], api_name=False
+ ).then(
+ get_info_text, None, [info], api_name=False
  )
  # fmt: on
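The structural change in app.py is the move from a hand-wired input.submit() to gr.ChatInterface: the handler now receives (input, history, thread_id) and returns extra values that are routed to the emails and citations components. A minimal sketch of that wiring, assuming Gradio 5+ (where ChatInterface supports additional_inputs and additional_outputs); the handler and component names here are illustrative, not the app's:

import gradio as gr

def respond(message, history, thread_id):
    # A real handler streams graph steps; this stub just returns the reply
    # plus one value per extra output (retrieved emails, citations).
    reply = f"[thread {thread_id}] You asked: {message}"
    return reply, "retrieved emails go here", "citations go here"

with gr.Blocks() as demo:
    thread_id = gr.State("demo-thread")
    emails = gr.Textbox(label="Retrieved Emails")
    citations = gr.Textbox(label="Citations")
    gr.ChatInterface(
        respond,
        type="messages",
        additional_inputs=[thread_id],
        additional_outputs=[emails, citations],
    )

demo.launch()

With additional_outputs, the chat function returns the reply first, followed by one value per extra component, which is why run_workflow now yields (history, retrieved_emails, citations) tuples.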
graph.py CHANGED
@@ -4,37 +4,85 @@ from langchain_core.tools import tool
  from langgraph.prebuilt import ToolNode, tools_condition
  from langchain_huggingface import ChatHuggingFace
  from typing import Optional
+ from dotenv import load_dotenv
  import datetime
  import os

  # Local modules
  from retriever import BuildRetriever
- from prompts import retrieve_prompt, answer_prompt, smollm3_tools_template
+ from prompts import retrieve_prompt, answer_prompt, gemma_tools_template
  from mods.tool_calling_llm import ToolCallingLLM

  # Local modules
  from retriever import BuildRetriever

  ## For LANGCHAIN_API_KEY
- # from dotenv import load_dotenv
- #
  # load_dotenv(dotenv_path=".env", override=True)
  # os.environ["LANGSMITH_TRACING"] = "true"
  # os.environ["LANGSMITH_PROJECT"] = "R-help-chat"


- def ToolifySmolLM3(chat_model, system_message, system_message_suffix="", think=False):
+ def print_messages_summary(messages, header):
+ """Print message types and summaries for debugging"""
+ if header:
+ print(header)
+ for message in messages:
+ summary_text = ""
+ if type(message) == SystemMessage:
+ type_txt = "SystemMessage"
+ summary_txt = f"length = {len(message.content)}"
+ if type(message) == HumanMessage:
+ type_txt = "HumanMessage"
+ summary_txt = message.content
+ if type(message) == AIMessage:
+ type_txt = "AIMessage"
+ summary_txt = f"length = {len(message.content)}"
+ if type(message) == ToolMessage:
+ type_txt = "ToolMessage"
+ summary_txt = f"length = {len(message.content)}"
+ if hasattr(message, "tool_calls"):
+ if len(message.tool_calls) != 1:
+ summary_txt = f"{summary_txt} with {len(message.tool_calls)} tool calls"
+ else:
+ summary_txt = f"{summary_txt} with 1 tool call"
+ print(f"{type_txt}: {summary_txt}")
+
+
+ def normalize_messages(messages):
+ """Normalize messages to sequence of types expected by chat templates"""
+ # Copy the most recent HumanMessage to the end
+ # (avoids SmolLM3 ValueError: Last message must be a HumanMessage!)
+ if not type(messages[-1]) is HumanMessage:
+ for msg in reversed(messages):
+ if type(msg) is HumanMessage:
+ messages.append(msg)
+ # Convert tool output (ToolMessage) to AIMessage
+ # (avoids SmolLM3 ValueError: Unknown message type: <class 'langchain_core.messages.tool.ToolMessage'>)
+ messages = [
+ AIMessage(msg.content) if type(msg) is ToolMessage else msg for msg in messages
+ ]
+ # Delete tool call (AIMessage)
+ # (avoids Gemma TemplateError: Conversation roles must alternate user/assistant/user/assistant/...)
+ messages = [
+ msg
+ for msg in messages
+ if not hasattr(msg, "tool_calls")
+ or (hasattr(msg, "tool_calls") and not msg.tool_calls)
+ ]
+ return messages
+
+
+ def ToolifyHF(chat_model, system_message, system_message_suffix="", think=False):
  """
- Get a SmolLM3 model ready for bind_tools().
+ Get a Hugging Face model ready for bind_tools().
  """

- # Add /no_think flag to turn off thinking mode
- if not think:
- system_message = "/no_think\n" + system_message
+ ## Add /no_think flag to turn off thinking mode (SmolLM3)
+ # if not think:
+ # system_message = "/no_think\n" + system_message

- # NOTE: The first two nonblank lines are taken from the chat template for HuggingFaceTB/SmolLM3-3B
- # The rest are taken from the default system template for ToolCallingLLM
- tool_system_prompt_template = system_message + smollm3_tools_template
+ # Combine system prompt and tools template
+ tool_system_prompt_template = system_message + gemma_tools_template

  class HuggingFaceWithTools(ToolCallingLLM, ChatHuggingFace):

@@ -45,6 +93,7 @@ def ToolifySmolLM3(chat_model, system_message, system_message_suffix="", think=F
  chat_model = HuggingFaceWithTools(
  llm=chat_model.llm,
  tool_system_prompt_template=tool_system_prompt_template,
+ # Suffix is for any additional context (not templated)
  system_message_suffix=system_message_suffix,
  )

@@ -154,12 +203,12 @@ def BuildGraph(
  is_edge = hasattr(chat_model, "model_id")
  if is_edge:
  # For edge model (ChatHuggingFace)
- query_model = ToolifySmolLM3(
+ query_model = ToolifyHF(
  chat_model, retrieve_prompt(compute_location), "", think_retrieve
  ).bind_tools([retrieve_emails])
- generate_model = ToolifySmolLM3(chat_model, answer_prompt(), "", think_generate)
- # For testing with Gemma, don't bind tool for now
- # ).bind_tools([answer_with_citations])
+ generate_model = ToolifyHF(
+ chat_model, answer_prompt(), "", think_generate
+ ).bind_tools([answer_with_citations])
  else:
  # For cloud model (OpenAI API)
  query_model = chat_model.bind_tools([retrieve_emails])
@@ -173,12 +222,9 @@ def BuildGraph(
  if is_edge:
  # Don't include the system message here because it's defined in ToolCallingLLM
  messages = state["messages"]
- # Convert ToolMessage (from previous turns) to AIMessage
- # (avoids SmolLM3 ValueError: Unknown message type: <class 'langchain_core.messages.tool.ToolMessage'>)
- messages = [
- AIMessage(msg.content) if type(msg) is ToolMessage else msg
- for msg in messages
- ]
+ print_messages_summary(messages, "--- query: before normalization ---")
+ messages = normalize_messages(messages)
+ print_messages_summary(messages, "--- query: after normalization ---")
  else:
  messages = [SystemMessage(retrieve_prompt(compute_location))] + state[
  "messages"
@@ -191,25 +237,9 @@ def BuildGraph(
  """Generates an answer with the chat model"""
  if is_edge:
  messages = state["messages"]
- # Copy the most recent HumanMessage to the end
- # (avoids SmolLM3 ValueError: Last message must be a HumanMessage!)
- for msg in reversed(messages):
- if type(msg) is HumanMessage:
- messages.append(msg)
- # Convert tool output (ToolMessage) to AIMessage
- # (avoids SmolLM3 ValueError: Unknown message type: <class 'langchain_core.messages.tool.ToolMessage'>)
- messages = [
- AIMessage(msg.content) if type(msg) is ToolMessage else msg
- for msg in messages
- ]
- # Delete tool call (AIMessage)
- # (avoids Gemma TemplateError: Conversation roles must alternate user/assistant/user/assistant/...)
- messages = [
- msg
- for msg in messages
- if not hasattr(msg, "tool_calls")
- or (hasattr(msg, "tool_calls") and not msg.tool_calls)
- ]
+ print_messages_summary(messages, "--- generate: before normalization ---")
+ messages = normalize_messages(messages)
+ print_messages_summary(messages, "--- generate: after normalization ---")
  else:
  messages = [SystemMessage(answer_prompt())] + state["messages"]
  response = generate_model.invoke(messages)
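The commit's core idea, normalize_messages, enforces the turn structure Gemma's chat template demands: no ToolMessage objects, no assistant tool-call stubs, and a user message last. A quick check of those rules (a sketch assuming langchain-core is installed and graph.py is importable; the history below is made up):

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from graph import normalize_messages

history = [
    HumanMessage("Who discussed profiling in 2023?"),
    AIMessage("", tool_calls=[{"name": "retrieve_emails", "args": {}, "id": "1"}]),
    ToolMessage("...retrieved emails...", tool_call_id="1"),
]
messages = normalize_messages(history)
# The ToolMessage becomes an AIMessage, the tool-calling AIMessage is dropped,
# and the last HumanMessage is repeated so the sequence ends with a user turn:
print([type(m).__name__ for m in messages])
# ['HumanMessage', 'AIMessage', 'HumanMessage']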
main.py CHANGED
@@ -24,9 +24,20 @@ from retriever import BuildRetriever, db_dir
  from graph import BuildGraph
  from prompts import answer_prompt

+ # -----------
  # R-help-chat
+ # -----------
  # First version by Jeffrey Dick on 2025-06-29

+ # Define the cloud (OpenAI) model
+ openai_model = "gpt-4o-mini"
+
+ # Get the edge model ID (we can define the variable in HF Spaces settings)
+ model_id = os.getenv("MODEL_ID")
+ if model_id is None:
+ # model_id = "HuggingFaceTB/SmolLM3-3B"
+ model_id = "google/gemma-3-1b-it"
+
  # Suppress these messages:
  # INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
  # INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
@@ -122,7 +133,7 @@ def GetChatModel(compute_location):

  if compute_location == "cloud":

- chat_model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
+ chat_model = ChatOpenAI(model=openai_model, temperature=0)

  if compute_location == "edge":

@@ -130,12 +141,6 @@
  if compute_location == "edge" and not torch.cuda.is_available():
  raise Exception("Edge chat model selected without GPU")

- # Get the model ID (we can define the variable in HF Spaces settings)
- model_id = os.getenv("MODEL_ID")
- if model_id is None:
- # model_id = "HuggingFaceTB/SmolLM3-3B"
- model_id = "google/gemma-3-1b-it"
-
  # Define the pipeline to pass to the HuggingFacePipeline class
  # https://huggingface.co/blog/langchain
  tokenizer = AutoTokenizer.from_pretrained(model_id)
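Hoisting openai_model and model_id to module level lets app.py import them for the status text and keeps the MODEL_ID override from the HF Spaces settings in one place. The fallback could equally be written with os.getenv's default argument (an equivalent sketch, not the committed code):

import os

# One-liner equivalent of the MODEL_ID fallback in main.py
model_id = os.getenv("MODEL_ID", "google/gemma-3-1b-it")
openai_model = "gpt-4o-mini"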
prompts.py CHANGED
@@ -11,7 +11,7 @@ def retrieve_prompt(compute_location):
  """

  # Get start and end months from database
- start, end = get_start_end_months(get_sources(compute_location))
+ start, end = get_start_end_months(get_sources())

  retrieve_prompt = (
  f"The current date is {date.today()}. "
@@ -58,23 +58,41 @@ def answer_prompt():


  # Prompt template for SmolLM3 with tools
- # The first two lines are from the apply_chat_template for HuggingFaceTB/SmolLM3-3B
- # The remainding lines (starting with You have access...) from tool_calling_llm.py
-
+ # The first two lines, <function-name>, and <args-json-object> are from the apply_chat_template for HuggingFaceTB/SmolLM3-3B
+ # The other lines (You have, {tools}, You must), "tool", and "tool_input" are from tool_calling_llm.py
  smollm3_tools_template = """

- ### Tools
+ ### Tools

- You may call one or more functions to assist with the user query.
+ You may call one or more functions to assist with the user query.

- You have access to the following tools:
+ You have access to the following tools:

- {tools}
+ {tools}

- You must always select one of the above tools and respond with only a JSON object matching the following schema:
+ You must always select one of the above tools and respond with only a JSON object matching the following schema:

- {{
- "tool": <name of the selected tool>,
- "tool_input": <parameters for the selected tool, matching the tool's JSON schema>
- }}
- """
+ {{
+ "tool": <function-name>,
+ "tool_input": <args-json-object>
+ }}
+
+ """
+
+ # Prompt template for Gemma-3 with tools
+ # Based on https://ai.google.dev/gemma/docs/capabilities/function-calling
+ gemma_tools_template = """
+
+ ### Functions
+
+ You have access to functions. If you decide to invoke any of the function(s), you MUST put it in the format of
+
+ {{
+ "tool": <function-name>,
+ "tool_input": <args-json-object>
+ }}
+
+ You SHOULD NOT include any other text in the response if you call a function
+
+ {tools}
+ """
util.py CHANGED
@@ -1,22 +1,37 @@
- import re
  from calendar import month_name
- from retriever import BuildRetriever
+ from retriever import BuildRetriever, db_dir
+ import json
+ import os
+ import re


  def get_collection(compute_location):
  """
  Returns the vectorstore collection.
+
+ Usage Examples:
+ # Number of child documents
+ collection = get_collection("cloud")
+ len(collection["ids"])
+ # Number of parent documents (unique doc_ids)
+ len(set([m["doc_id"] for m in collection["metadatas"]]))
  """
  retriever = BuildRetriever(compute_location, "dense")
  return retriever.vectorstore.get()


- def get_sources(compute_location):
+ def get_sources():
  """
  Return the source files indexed in the database, e.g. 'R-help/2024-April.txt'.
  """
- collection = get_collection(compute_location)
- sources = [m["source"] for m in collection["metadatas"]]
+ # Path to your JSON Lines file
+ file_path = os.path.join(db_dir, "bm25", "corpus.jsonl")
+
+ # Reading the JSON Lines file
+ with open(file_path, "r", encoding="utf-8") as file:
+ # Parse each line as a JSON object
+ sources = [json.loads(line.strip())["metadata"]["source"] for line in file]
+
  return sources

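get_sources() now reads the BM25 corpus directly instead of querying the vectorstore, so it works before any embedding model is loaded. The JSON Lines layout it expects looks like this (a sketch with invented records; only the metadata.source field matters here):

import json

lines = [
    '{"text": "...", "metadata": {"source": "R-help/2024-April.txt"}}',
    '{"text": "...", "metadata": {"source": "R-help/2024-May.txt"}}',
]
sources = [json.loads(line.strip())["metadata"]["source"] for line in lines]
print(sources)  # ['R-help/2024-April.txt', 'R-help/2024-May.txt']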