jedick committed on
Commit
429393a
·
1 Parent(s): 8627eb1

Remove local compute mode

Browse files
Files changed (12) hide show
  1. app.py +67 -179
  2. eval.py +5 -14
  3. graph.py +12 -73
  4. images/graph_LR.png +0 -0
  5. index.py +4 -5
  6. main.py +13 -96
  7. mods/tool_calling_llm.py +0 -313
  8. pipeline.py +0 -86
  9. prompts.py +7 -66
  10. requirements.txt +13 -38
  11. retriever.py +11 -58
  12. util.py +0 -15
app.py CHANGED
@@ -1,42 +1,25 @@
1
  from langgraph.checkpoint.memory import MemorySaver
2
- from huggingface_hub import snapshot_download
3
  from dotenv import load_dotenv
4
  from datetime import datetime
5
  import gradio as gr
6
- import spaces
7
- import torch
8
  import uuid
9
  import ast
10
  import os
11
  import re
12
 
13
  # Local modules
14
- from main import GetChatModel, openai_model, model_id
15
  from util import get_sources, get_start_end_months
16
- from retriever import db_dir, embedding_model_id
17
- from mods.tool_calling_llm import extract_think
18
  from data import download_data, extract_data
 
19
  from graph import BuildGraph
 
20
 
21
  # Set environment variables
22
  load_dotenv(dotenv_path=".env", override=True)
23
  # Hide BM25S progress bars
24
  os.environ["DISABLE_TQDM"] = "true"
25
 
26
- # Download model snapshots from Hugging Face Hub
27
- if torch.cuda.is_available():
28
- print(f"Downloading checkpoints for {model_id}...")
29
- ckpt_dir = snapshot_download(model_id, local_dir_use_symlinks=False)
30
- print(f"Using checkpoints from {ckpt_dir}")
31
- print(f"Downloading checkpoints for {embedding_model_id}...")
32
- embedding_ckpt_dir = snapshot_download(
33
- embedding_model_id, local_dir_use_symlinks=False
34
- )
35
- print(f"Using embedding checkpoints from {embedding_ckpt_dir}")
36
- else:
37
- ckpt_dir = None
38
- embedding_ckpt_dir = None
39
-
40
  # Download and extract data if data directory is not present
41
  if not os.path.isdir(db_dir):
42
  print("Downloading data ... ", end="")
@@ -51,17 +34,35 @@ search_type = "hybrid"
51
 
52
  # Global variables for LangChain graph: use dictionaries to store user-specific instances
53
  # https://www.gradio.app/guides/state-in-blocks
54
- graph_instances = {"local": {}, "remote": {}}
55
 
56
 
57
  def cleanup_graph(request: gr.Request):
58
  timestamp = datetime.now().replace(microsecond=0).isoformat()
59
- if request.session_hash in graph_instances["local"]:
60
- del graph_instances["local"][request.session_hash]
61
- print(f"{timestamp} - Delete local graph for session {request.session_hash}")
62
- if request.session_hash in graph_instances["remote"]:
63
- del graph_instances["remote"][request.session_hash]
64
- print(f"{timestamp} - Delete remote graph for session {request.session_hash}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
 
67
  def append_content(chunk_messages, history, thinking_about):
@@ -85,48 +86,32 @@ def append_content(chunk_messages, history, thinking_about):
85
  return history
86
 
87
 
88
- def run_workflow(input, history, compute_mode, thread_id, session_hash):
89
  """The main function to run the chat workflow"""
90
 
91
- # Error if user tries to run local mode without GPU
92
- if compute_mode == "local":
93
- if not torch.cuda.is_available():
94
- raise gr.Error(
95
- "Local mode requires GPU.",
96
- print_exception=False,
97
- )
98
-
99
  # Get graph instance
100
- graph = graph_instances[compute_mode].get(session_hash)
101
 
102
  if graph is None:
103
- # Notify when we're loading the local model because it takes some time
104
- if compute_mode == "local":
105
- gr.Info(
106
- f"Please wait for the local model to load",
107
- title=f"Model loading...",
108
- )
109
  # Get the chat model and build the graph
110
- chat_model = GetChatModel(compute_mode, ckpt_dir)
111
  graph_builder = BuildGraph(
112
  chat_model,
113
- compute_mode,
114
  search_type,
115
- embedding_ckpt_dir=embedding_ckpt_dir,
116
  )
117
  # Compile the graph with an in-memory checkpointer
118
  memory = MemorySaver()
119
  graph = graph_builder.compile(checkpointer=memory)
120
- # Set global graph for compute mode
121
- graph_instances[compute_mode][session_hash] = graph
122
  # ISO 8601 timestamp with local timezone information without microsecond
123
  timestamp = datetime.now().replace(microsecond=0).isoformat()
124
- print(f"{timestamp} - Set {compute_mode} graph for session {session_hash}")
125
- # Notify when model finishes loading
126
- gr.Success(f"{compute_mode}", duration=4, title=f"Model loaded!")
127
  else:
128
  timestamp = datetime.now().replace(microsecond=0).isoformat()
129
- print(f"{timestamp} - Get {compute_mode} graph for session {session_hash}")
130
 
131
  # print(f"Using thread_id: {thread_id}")
132
 
@@ -235,28 +220,11 @@ def run_workflow(input, history, compute_mode, thread_id, session_hash):
235
 
236
 
237
  def to_workflow(request: gr.Request, *args):
238
- """Wrapper function to call function with or without @spaces.GPU"""
239
  input = args[0]
240
- compute_mode = args[2]
241
  # Add session_hash to arguments
242
  new_args = args + (request.session_hash,)
243
- if compute_mode == "local":
244
- # Call the workflow function with the @spaces.GPU decorator
245
- for value in run_workflow_local(*new_args):
246
- yield value
247
- if compute_mode == "remote":
248
- for value in run_workflow_remote(*new_args):
249
- yield value
250
-
251
-
252
- @spaces.GPU(duration=100)
253
- def run_workflow_local(*args):
254
- for value in run_workflow(*args):
255
- yield value
256
-
257
-
258
- def run_workflow_remote(*args):
259
- for value in run_workflow(*args):
260
  yield value
261
 
262
 
@@ -290,19 +258,6 @@ with gr.Blocks(
290
  # Define components
291
  # -----------------
292
 
293
- compute_mode = gr.Radio(
294
- choices=[
295
- "local",
296
- "remote",
297
- ],
298
- # Default to remote because it provides a better first impression for most people
299
- # value=("local" if torch.cuda.is_available() else "remote"),
300
- value="remote",
301
- label="Compute Mode",
302
- info="NOTE: remote mode **does not** use ZeroGPU",
303
- render=False,
304
- )
305
-
306
  loading_data = gr.Textbox(
307
  "Please wait for the email database to be downloaded and extracted.",
308
  max_lines=0,
@@ -332,14 +287,7 @@ with gr.Blocks(
332
  chatbot = gr.Chatbot(
333
  type="messages",
334
  show_label=False,
335
- avatar_images=(
336
- None,
337
- (
338
- "images/cloud.png"
339
- if compute_mode.value == "remote"
340
- else "images/chip.png"
341
- ),
342
- ),
343
  show_copy_all_button=True,
344
  render=False,
345
  )
@@ -398,24 +346,17 @@ with gr.Blocks(
398
  and generates an answer from the retrieved emails (*emails are shown below the chatbot*).
399
  You can ask follow-up questions with the chat history as context.
400
  Press the clear button (🗑) to clear the history and start a new chat.
 
401
  """
402
  return intro
403
 
404
- def get_status_text(compute_mode):
405
- if compute_mode == "remote":
406
- status_text = f"""
407
- 🌐 Now in **remote** mode, using the OpenAI API<br>
408
- ⚠️ **_Privacy Notice_**: Data sharing with OpenAI is enabled<br>
409
- text-embedding-3-small and {openai_model}<br>
410
- 🏠 See the project's [GitHub repository](https://github.com/jedick/R-help-chat)
411
- """
412
- if compute_mode == "local":
413
- status_text = f"""
414
- 📍 Now in **local** mode, using ZeroGPU hardware<br>
415
- ⌛ Response time is about one minute<br>
416
- ✨ [{embedding_model_id.split("/")[-1]}](https://huggingface.co/{embedding_model_id}) and [{model_id.split("/")[-1]}](https://huggingface.co/{model_id})<br>
417
- 🏠 See the project's [GitHub repository](https://github.com/jedick/R-help-chat)
418
- """
419
  return status_text
420
 
421
  def get_info_text():
@@ -430,13 +371,13 @@ with gr.Blocks(
430
  end = None
431
  info_text = f"""
432
  **Database:** {len(sources)} emails from {start} to {end}.
433
- **Features:** RAG, today's date, hybrid search (dense+sparse), multiple retrievals, citations output (remote), chat memory.
434
- **Tech:** LangChain + Hugging Face + Gradio; ChromaDB and BM25S-based retrievers.<br>
435
  """
436
  return info_text
437
 
438
- def get_example_questions(compute_mode, as_dataset=True):
439
- """Get example questions based on compute mode"""
440
  questions = [
441
  # "What is today's date?",
442
  "Summarize emails from the most recent two months",
@@ -445,15 +386,11 @@ with gr.Blocks(
445
  "Who reported installation problems in 2023-2024?",
446
  ]
447
 
448
- ## Remove "/think" from questions in remote mode
449
- # if compute_mode == "remote":
450
- # questions = [q.replace(" /think", "") for q in questions]
451
-
452
  # cf. https://github.com/gradio-app/gradio/pull/8745 for updating examples
453
  return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
454
 
455
- def get_multi_tool_questions(compute_mode, as_dataset=True):
456
- """Get multi-tool example questions based on compute mode"""
457
  questions = [
458
  "Differences between lapply and for loops",
459
  "Discuss pipe operator usage in 2022, 2023, and 2024",
@@ -461,8 +398,8 @@ with gr.Blocks(
461
 
462
  return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
463
 
464
- def get_multi_turn_questions(compute_mode, as_dataset=True):
465
- """Get multi-turn example questions based on compute mode"""
466
  questions = [
467
  "Lookup emails that reference bugs.r-project.org in 2025",
468
  "Did the authors you cited report bugs before 2025?",
@@ -474,10 +411,14 @@ with gr.Blocks(
474
  # Left column: Intro, Compute, Chat
475
  with gr.Column(scale=2):
476
  with gr.Row(elem_classes=["row-container"]):
477
- with gr.Column(scale=2):
478
  intro = gr.Markdown(get_intro_text())
479
  with gr.Column(scale=1):
480
- compute_mode.render()
 
 
 
 
481
  with gr.Group() as chat_interface:
482
  chatbot.render()
483
  input.render()
@@ -488,29 +429,23 @@ with gr.Blocks(
488
  missing_data.render()
489
  # Right column: Info, Examples
490
  with gr.Column(scale=1):
491
- status = gr.Markdown(get_status_text(compute_mode.value))
492
  with gr.Accordion("ℹ️ More Info", open=False):
493
  info = gr.Markdown(get_info_text())
494
  with gr.Accordion("💡 Examples", open=True):
495
  # Add some helpful examples
496
  example_questions = gr.Examples(
497
- examples=get_example_questions(
498
- compute_mode.value, as_dataset=False
499
- ),
500
  inputs=[input],
501
  label="Click an example to fill the message box",
502
  )
503
  multi_tool_questions = gr.Examples(
504
- examples=get_multi_tool_questions(
505
- compute_mode.value, as_dataset=False
506
- ),
507
  inputs=[input],
508
  label="Multiple retrievals",
509
  )
510
  multi_turn_questions = gr.Examples(
511
- examples=get_multi_turn_questions(
512
- compute_mode.value, as_dataset=False
513
- ),
514
  inputs=[input],
515
  label="Asking follow-up questions",
516
  )
@@ -530,18 +465,6 @@ with gr.Blocks(
530
  """Return updated value for a component"""
531
  return gr.update(value=value)
532
 
533
- def set_avatar(compute_mode):
534
- if compute_mode == "remote":
535
- image_file = "images/cloud.png"
536
- if compute_mode == "local":
537
- image_file = "images/chip.png"
538
- return gr.update(
539
- avatar_images=(
540
- None,
541
- image_file,
542
- ),
543
- )
544
-
545
  def change_visibility(visible):
546
  """Return updated visibility state for a component"""
547
  return gr.update(visible=visible)
@@ -565,45 +488,10 @@ with gr.Blocks(
565
  # https://github.com/gradio-app/gradio/issues/9722
566
  chatbot.clear(generate_thread_id, outputs=[thread_id], api_name=False)
567
 
568
- def clear_component(component):
569
- """Return cleared component"""
570
- return component.clear()
571
-
572
- compute_mode.change(
573
- # Start a new thread
574
- generate_thread_id,
575
- outputs=[thread_id],
576
- api_name=False,
577
- ).then(
578
- # Focus textbox by updating the textbox with the current value
579
- lambda x: gr.update(value=x),
580
- [input],
581
- [input],
582
- api_name=False,
583
- ).then(
584
- # Change the app status text
585
- get_status_text,
586
- [compute_mode],
587
- [status],
588
- api_name=False,
589
- ).then(
590
- # Clear the chatbot history
591
- clear_component,
592
- [chatbot],
593
- [chatbot],
594
- api_name=False,
595
- ).then(
596
- # Change the chatbot avatar
597
- set_avatar,
598
- [compute_mode],
599
- [chatbot],
600
- api_name=False,
601
- )
602
-
603
  input.submit(
604
  # Submit input to the chatbot
605
  to_workflow,
606
- [input, chatbot, compute_mode, thread_id],
607
  [chatbot, retrieved_emails, citations_text],
608
  api_name=False,
609
  )
 
1
  from langgraph.checkpoint.memory import MemorySaver
2
+ from langchain_openai import ChatOpenAI
3
  from dotenv import load_dotenv
4
  from datetime import datetime
5
  import gradio as gr
 
 
6
  import uuid
7
  import ast
8
  import os
9
  import re
10
 
11
  # Local modules
 
12
  from util import get_sources, get_start_end_months
 
 
13
  from data import download_data, extract_data
14
+ from main import openai_model
15
  from graph import BuildGraph
16
+ from retriever import db_dir
17
 
18
  # Set environment variables
19
  load_dotenv(dotenv_path=".env", override=True)
20
  # Hide BM25S progress bars
21
  os.environ["DISABLE_TQDM"] = "true"
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  # Download and extract data if data directory is not present
24
  if not os.path.isdir(db_dir):
25
  print("Downloading data ... ", end="")
 
34
 
35
  # Global variables for LangChain graph: use dictionaries to store user-specific instances
36
  # https://www.gradio.app/guides/state-in-blocks
37
+ graph_instances = {}
38
 
39
 
40
  def cleanup_graph(request: gr.Request):
41
  timestamp = datetime.now().replace(microsecond=0).isoformat()
42
+ if request.session_hash in graph_instances:
43
+ del graph_instances[request.session_hash]
44
+ print(f"{timestamp} - Delete graph for session {request.session_hash}")
45
+
46
+
47
+ def extract_think(content):
48
+ # Added by Cursor 20250726 jmd
49
+ # Extract content within <think>...</think>
50
+ think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
51
+ think_text = think_match.group(1).strip() if think_match else ""
52
+ # Extract text after </think>
53
+ if think_match:
54
+ post_think = content[think_match.end() :].lstrip()
55
+ else:
56
+ # Check if content starts with <think> but missing closing tag
57
+ if content.strip().startswith("<think>"):
58
+ # Extract everything after <think>
59
+ think_start = content.find("<think>") + len("<think>")
60
+ think_text = content[think_start:].strip()
61
+ post_think = ""
62
+ else:
63
+ # No <think> found, so return entire content as post_think
64
+ post_think = content
65
+ return think_text, post_think
66
 
67
 
68
  def append_content(chunk_messages, history, thinking_about):
 
86
  return history
87
 
88
 
89
+ def run_workflow(input, history, thread_id, session_hash):
90
  """The main function to run the chat workflow"""
91
 
 
 
 
 
 
 
 
 
92
  # Get graph instance
93
+ graph = graph_instances.get(session_hash)
94
 
95
  if graph is None:
 
 
 
 
 
 
96
  # Get the chat model and build the graph
97
+ chat_model = ChatOpenAI(model=openai_model, temperature=0)
98
  graph_builder = BuildGraph(
99
  chat_model,
 
100
  search_type,
 
101
  )
102
  # Compile the graph with an in-memory checkpointer
103
  memory = MemorySaver()
104
  graph = graph_builder.compile(checkpointer=memory)
105
+ # Set global graph
106
+ graph_instances[session_hash] = graph
107
  # ISO 8601 timestamp with local timezone information without microsecond
108
  timestamp = datetime.now().replace(microsecond=0).isoformat()
109
+ print(f"{timestamp} - Set graph for session {session_hash}")
110
+ ## Notify when model finishes loading
111
+ # gr.Success("Model loaded!", duration=4)
112
  else:
113
  timestamp = datetime.now().replace(microsecond=0).isoformat()
114
+ print(f"{timestamp} - Get graph for session {session_hash}")
115
 
116
  # print(f"Using thread_id: {thread_id}")
117
 
 
220
 
221
 
222
  def to_workflow(request: gr.Request, *args):
223
+ """Wrapper function to call run_workflow() with session_hash"""
224
  input = args[0]
 
225
  # Add session_hash to arguments
226
  new_args = args + (request.session_hash,)
227
+ for value in run_workflow(*new_args):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  yield value
229
 
230
 
 
258
  # Define components
259
  # -----------------
260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  loading_data = gr.Textbox(
262
  "Please wait for the email database to be downloaded and extracted.",
263
  max_lines=0,
 
287
  chatbot = gr.Chatbot(
288
  type="messages",
289
  show_label=False,
290
+ avatar_images=(None, "images/cloud.png"),
 
 
 
 
 
 
 
291
  show_copy_all_button=True,
292
  render=False,
293
  )
 
346
  and generates an answer from the retrieved emails (*emails are shown below the chatbot*).
347
  You can ask follow-up questions with the chat history as context.
348
  Press the clear button (🗑) to clear the history and start a new chat.
349
+ 🚧 Under construction: Select a mailing list to search, or use Auto to let the LLM choose.
350
  """
351
  return intro
352
 
353
+ def get_status_text():
354
+ status_text = f"""
355
+ 🌐 This app uses the OpenAI API<br>
356
+ ⚠️ **_Privacy Notice_**: Data sharing with OpenAI is enabled<br>
357
+ text-embedding-3-small and {openai_model}<br>
358
+ 🏠 More info: [R-help-chat GitHub repository](https://github.com/jedick/R-help-chat)
359
+ """
 
 
 
 
 
 
 
 
360
  return status_text
361
 
362
  def get_info_text():
 
371
  end = None
372
  info_text = f"""
373
  **Database:** {len(sources)} emails from {start} to {end}.
374
+ **Features:** RAG, today's date, hybrid search (dense+sparse), multiple retrievals, citations output, chat memory.
375
+ **Tech:** OpenAI API + LangGraph + Gradio; ChromaDB and BM25S-based retrievers.<br>
376
  """
377
  return info_text
378
 
379
+ def get_example_questions(as_dataset=True):
380
+ """Get example questions"""
381
  questions = [
382
  # "What is today's date?",
383
  "Summarize emails from the most recent two months",
 
386
  "Who reported installation problems in 2023-2024?",
387
  ]
388
 
 
 
 
 
389
  # cf. https://github.com/gradio-app/gradio/pull/8745 for updating examples
390
  return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
391
 
392
+ def get_multi_tool_questions(as_dataset=True):
393
+ """Get multi-tool example questions"""
394
  questions = [
395
  "Differences between lapply and for loops",
396
  "Discuss pipe operator usage in 2022, 2023, and 2024",
 
398
 
399
  return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
400
 
401
+ def get_multi_turn_questions(as_dataset=True):
402
+ """Get multi-turn example questions"""
403
  questions = [
404
  "Lookup emails that reference bugs.r-project.org in 2025",
405
  "Did the authors you cited report bugs before 2025?",
 
411
  # Left column: Intro, Compute, Chat
412
  with gr.Column(scale=2):
413
  with gr.Row(elem_classes=["row-container"]):
414
+ with gr.Column(scale=4):
415
  intro = gr.Markdown(get_intro_text())
416
  with gr.Column(scale=1):
417
+ gr.Radio(
418
+ ["Auto", "R-help", "R-devel", "R-pkg-devel"],
419
+ label="Mailing List",
420
+ interactive=False,
421
+ )
422
  with gr.Group() as chat_interface:
423
  chatbot.render()
424
  input.render()
 
429
  missing_data.render()
430
  # Right column: Info, Examples
431
  with gr.Column(scale=1):
432
+ status = gr.Markdown(get_status_text())
433
  with gr.Accordion("ℹ️ More Info", open=False):
434
  info = gr.Markdown(get_info_text())
435
  with gr.Accordion("💡 Examples", open=True):
436
  # Add some helpful examples
437
  example_questions = gr.Examples(
438
+ examples=get_example_questions(as_dataset=False),
 
 
439
  inputs=[input],
440
  label="Click an example to fill the message box",
441
  )
442
  multi_tool_questions = gr.Examples(
443
+ examples=get_multi_tool_questions(as_dataset=False),
 
 
444
  inputs=[input],
445
  label="Multiple retrievals",
446
  )
447
  multi_turn_questions = gr.Examples(
448
+ examples=get_multi_turn_questions(as_dataset=False),
 
 
449
  inputs=[input],
450
  label="Asking follow-up questions",
451
  )
 
465
  """Return updated value for a component"""
466
  return gr.update(value=value)
467
 
 
 
 
 
 
 
 
 
 
 
 
 
468
  def change_visibility(visible):
469
  """Return updated visibility state for a component"""
470
  return gr.update(visible=visible)
 
488
  # https://github.com/gradio-app/gradio/issues/9722
489
  chatbot.clear(generate_thread_id, outputs=[thread_id], api_name=False)
490
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
491
  input.submit(
492
  # Submit input to the chatbot
493
  to_workflow,
494
+ [input, chatbot, thread_id],
495
  [chatbot, retrieved_emails, citations_text],
496
  api_name=False,
497
  )
eval.py CHANGED
@@ -34,7 +34,7 @@ def load_questions_and_references(csv_path):
34
  return questions, references
35
 
36
 
37
- def build_eval_dataset(questions, references, compute_mode, workflow, search_type):
38
  """Build dataset for evaluation"""
39
  dataset = []
40
  for question, reference in zip(questions, references):
@@ -42,15 +42,15 @@ def build_eval_dataset(questions, references, compute_mode, workflow, search_typ
42
  if workflow == "chain":
43
  print("\n\n--- Question ---")
44
  print(question)
45
- response = RunChain(question, compute_mode, search_type)
46
  print("--- Response ---")
47
  print(response)
48
  # Retrieve context documents for a question
49
- retriever = BuildRetriever(compute_mode, search_type)
50
  docs = retriever.invoke(question)
51
  retrieved_contexts = [doc.page_content for doc in docs]
52
  if workflow == "graph":
53
- result = RunGraph(question, compute_mode, search_type)
54
  retrieved_contexts = []
55
  if "retrieved_emails" in result:
56
  # Remove the source file names (e.g. R-help/2022-September.txt) as it confuses the evaluator
@@ -142,12 +142,6 @@ def main():
142
  parser = argparse.ArgumentParser(
143
  description="Evaluate RAG retrieval and generation."
144
  )
145
- parser.add_argument(
146
- "--compute_mode",
147
- choices=["remote", "local"],
148
- required=True,
149
- help="Compute mode: remote or local.",
150
- )
151
  parser.add_argument(
152
  "--workflow",
153
  choices=["chain", "graph"],
@@ -161,14 +155,11 @@ def main():
161
  help="Search type: dense, sparse, or hybrid.",
162
  )
163
  args = parser.parse_args()
164
- compute_mode = args.compute_mode
165
  workflow = args.workflow
166
  search_type = args.search_type
167
 
168
  questions, references = load_questions_and_references("eval.csv")
169
- dataset = build_eval_dataset(
170
- questions, references, compute_mode, workflow, search_type
171
- )
172
  evaluation_dataset = EvaluationDataset.from_list(dataset)
173
 
174
  # Set up LLM for evaluation
 
34
  return questions, references
35
 
36
 
37
+ def build_eval_dataset(questions, references, workflow, search_type):
38
  """Build dataset for evaluation"""
39
  dataset = []
40
  for question, reference in zip(questions, references):
 
42
  if workflow == "chain":
43
  print("\n\n--- Question ---")
44
  print(question)
45
+ response = RunChain(question, search_type)
46
  print("--- Response ---")
47
  print(response)
48
  # Retrieve context documents for a question
49
+ retriever = BuildRetriever(search_type)
50
  docs = retriever.invoke(question)
51
  retrieved_contexts = [doc.page_content for doc in docs]
52
  if workflow == "graph":
53
+ result = RunGraph(question, search_type)
54
  retrieved_contexts = []
55
  if "retrieved_emails" in result:
56
  # Remove the source file names (e.g. R-help/2022-September.txt) as it confuses the evaluator
 
142
  parser = argparse.ArgumentParser(
143
  description="Evaluate RAG retrieval and generation."
144
  )
 
 
 
 
 
 
145
  parser.add_argument(
146
  "--workflow",
147
  choices=["chain", "graph"],
 
155
  help="Search type: dense, sparse, or hybrid.",
156
  )
157
  args = parser.parse_args()
 
158
  workflow = args.workflow
159
  search_type = args.search_type
160
 
161
  questions, references = load_questions_and_references("eval.csv")
162
+ dataset = build_eval_dataset(questions, references, workflow, search_type)
 
 
163
  evaluation_dataset = EvaluationDataset.from_list(dataset)
164
 
165
  # Set up LLM for evaluation
graph.py CHANGED
@@ -2,15 +2,13 @@ from langchain_core.messages import SystemMessage, ToolMessage, HumanMessage, AI
2
  from langgraph.graph import START, END, MessagesState, StateGraph
3
  from langchain_core.tools import tool
4
  from langgraph.prebuilt import ToolNode, tools_condition
5
- from langchain_huggingface import ChatHuggingFace
6
  from typing import Optional
7
  import datetime
8
  import os
9
 
10
  # Local modules
11
  from retriever import BuildRetriever
12
- from prompts import query_prompt, answer_prompt, generic_tools_template
13
- from mods.tool_calling_llm import ToolCallingLLM
14
 
15
  # For tracing (disabled)
16
  # os.environ["LANGSMITH_TRACING"] = "true"
@@ -105,48 +103,18 @@ def normalize_messages(messages, summaries_for=None):
105
  return messages
106
 
107
 
108
- def ToolifyHF(chat_model, system_message):
109
- """
110
- Get a Hugging Face model ready for bind_tools().
111
- """
112
-
113
- # Combine system prompt and tools template
114
- tool_system_prompt_template = system_message + generic_tools_template
115
-
116
- class HuggingFaceWithTools(ToolCallingLLM, ChatHuggingFace):
117
- def __init__(self, **kwargs):
118
- super().__init__(**kwargs)
119
-
120
- chat_model = HuggingFaceWithTools(
121
- llm=chat_model.llm,
122
- tool_system_prompt_template=tool_system_prompt_template,
123
- )
124
-
125
- return chat_model
126
-
127
-
128
  def BuildGraph(
129
  chat_model,
130
- compute_mode,
131
  search_type,
132
  top_k=6,
133
- think_query=False,
134
- think_answer=False,
135
- local_citations=False,
136
- embedding_ckpt_dir=None,
137
  ):
138
  """
139
  Build conversational RAG graph for email retrieval and answering with citations.
140
 
141
  Args:
142
- chat_model: LangChain chat model from GetChatModel()
143
- compute_mode: remote or local (for retriever)
144
  search_type: dense, sparse, or hybrid (for retriever)
145
  top_k: number of documents to retrieve
146
- think_query: Whether to use thinking mode for the query (local model)
147
- think_answer: Whether to use thinking mode for the answer (local model)
148
- local_citations: Whether to use answer_with_citations() tool (local model)
149
- embedding_ckpt_dir: Directory for embedding model checkpoint
150
 
151
  Based on:
152
  https://python.langchain.com/docs/how_to/qa_sources
@@ -158,7 +126,7 @@ def BuildGraph(
158
  # Build graph with chat model
159
  from langchain_openai import ChatOpenAI
160
  chat_model = ChatOpenAI(model="gpt-4o-mini")
161
- graph = BuildGraph(chat_model, "remote", "hybrid")
162
 
163
  # Add simple in-memory checkpointer
164
  from langgraph.checkpoint.memory import MemorySaver
@@ -198,7 +166,10 @@ def BuildGraph(
198
  months (str, optional): One or more months separated by spaces
199
  """
200
  retriever = BuildRetriever(
201
- compute_mode, search_type, top_k, start_year, end_year, embedding_ckpt_dir
 
 
 
202
  )
203
  # For now, just add the months to the search query
204
  if months:
@@ -230,55 +201,23 @@ def BuildGraph(
230
  """
231
  return answer, citations
232
 
233
- # Add tools to the local or remote chat model
234
- is_local = hasattr(chat_model, "model_id")
235
- if is_local:
236
- # For local models (ChatHuggingFace with SmolLM, Gemma, or Qwen)
237
- query_model = ToolifyHF(
238
- chat_model, query_prompt(chat_model, think=think_query)
239
- ).bind_tools([retrieve_emails])
240
- if local_citations:
241
- answer_model = ToolifyHF(
242
- chat_model,
243
- answer_prompt(chat_model, think=think_answer, with_tools=True),
244
- ).bind_tools([answer_with_citations])
245
- else:
246
- # Don't use answer_with_citations tool because responses with are sometimes unparseable
247
- answer_model = chat_model
248
- else:
249
- # For remote model (OpenAI API)
250
- query_model = chat_model.bind_tools([retrieve_emails])
251
- answer_model = chat_model.bind_tools([answer_with_citations])
252
 
253
  # Initialize the graph object
254
  graph = StateGraph(MessagesState)
255
 
256
  def query(state: MessagesState):
257
  """Queries the retriever with the chat model"""
258
- if is_local:
259
- # Don't include the system message here because it's defined in ToolCallingLLM
260
- messages = state["messages"]
261
- messages = normalize_messages(messages)
262
- else:
263
- messages = [SystemMessage(query_prompt(chat_model))] + state["messages"]
264
  response = query_model.invoke(messages)
265
 
266
  return {"messages": response}
267
 
268
  def answer(state: MessagesState):
269
  """Generates an answer with the chat model"""
270
- if is_local:
271
- messages = state["messages"]
272
- messages = normalize_messages(messages)
273
- if not local_citations:
274
- # Add the system message here if we're not using tools
275
- messages = [
276
- SystemMessage(answer_prompt(chat_model, think=think_answer))
277
- ] + messages
278
- else:
279
- messages = [
280
- SystemMessage(answer_prompt(chat_model, with_tools=True))
281
- ] + state["messages"]
282
  response = answer_model.invoke(messages)
283
 
284
  return {"messages": response}
 
2
  from langgraph.graph import START, END, MessagesState, StateGraph
3
  from langchain_core.tools import tool
4
  from langgraph.prebuilt import ToolNode, tools_condition
 
5
  from typing import Optional
6
  import datetime
7
  import os
8
 
9
  # Local modules
10
  from retriever import BuildRetriever
11
+ from prompts import query_prompt, answer_prompt
 
12
 
13
  # For tracing (disabled)
14
  # os.environ["LANGSMITH_TRACING"] = "true"
 
103
  return messages
104
 
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  def BuildGraph(
107
  chat_model,
 
108
  search_type,
109
  top_k=6,
 
 
 
 
110
  ):
111
  """
112
  Build conversational RAG graph for email retrieval and answering with citations.
113
 
114
  Args:
115
+ chat_model: LangChain chat model
 
116
  search_type: dense, sparse, or hybrid (for retriever)
117
  top_k: number of documents to retrieve
 
 
 
 
118
 
119
  Based on:
120
  https://python.langchain.com/docs/how_to/qa_sources
 
126
  # Build graph with chat model
127
  from langchain_openai import ChatOpenAI
128
  chat_model = ChatOpenAI(model="gpt-4o-mini")
129
+ graph = BuildGraph(chat_model, "hybrid")
130
 
131
  # Add simple in-memory checkpointer
132
  from langgraph.checkpoint.memory import MemorySaver
 
166
  months (str, optional): One or more months separated by spaces
167
  """
168
  retriever = BuildRetriever(
169
+ search_type,
170
+ top_k,
171
+ start_year,
172
+ end_year,
173
  )
174
  # For now, just add the months to the search query
175
  if months:
 
201
  """
202
  return answer, citations
203
 
204
+ # Add tools to the chat model
205
+ query_model = chat_model.bind_tools([retrieve_emails])
206
+ answer_model = chat_model.bind_tools([answer_with_citations])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
  # Initialize the graph object
209
  graph = StateGraph(MessagesState)
210
 
211
  def query(state: MessagesState):
212
  """Queries the retriever with the chat model"""
213
+ messages = [SystemMessage(query_prompt())] + state["messages"]
 
 
 
 
 
214
  response = query_model.invoke(messages)
215
 
216
  return {"messages": response}
217
 
218
  def answer(state: MessagesState):
219
  """Generates an answer with the chat model"""
220
+ messages = [SystemMessage(answer_prompt())] + state["messages"]
 
 
 
 
 
 
 
 
 
 
 
221
  response = answer_model.invoke(messages)
222
 
223
  return {"messages": response}
images/graph_LR.png CHANGED
index.py CHANGED
@@ -9,14 +9,13 @@ from retriever import BuildRetriever, db_dir
9
  from mods.bm25s_retriever import BM25SRetriever
10
 
11
 
12
- def ProcessFile(file_path, search_type: str = "dense", compute_mode: str = "remote"):
13
  """
14
  Wrapper function to process file for dense or sparse search
15
 
16
  Args:
17
  file_path: File to process
18
  search_type: Type of search to use. Options: "dense", "sparse"
19
- compute_mode: Compute mode for embeddings (remote or local)
20
  """
21
 
22
  # Preprocess: remove quoted lines and handle email boundaries
@@ -69,7 +68,7 @@ def ProcessFile(file_path, search_type: str = "dense", compute_mode: str = "remo
69
  ProcessFileSparse(truncated_temp_file, file_path)
70
  elif search_type == "dense":
71
  # Handle dense search with ChromaDB
72
- ProcessFileDense(truncated_temp_file, file_path, compute_mode)
73
  else:
74
  raise ValueError(f"Unsupported search type: {search_type}")
75
  finally:
@@ -81,12 +80,12 @@ def ProcessFile(file_path, search_type: str = "dense", compute_mode: str = "remo
81
  pass
82
 
83
 
84
- def ProcessFileDense(cleaned_temp_file, file_path, compute_mode):
85
  """
86
  Process file for dense vector search using ChromaDB
87
  """
88
  # Get a retriever instance
89
- retriever = BuildRetriever(compute_mode, "dense")
90
  # Load cleaned text file
91
  loader = TextLoader(cleaned_temp_file)
92
  documents = loader.load()
 
9
  from mods.bm25s_retriever import BM25SRetriever
10
 
11
 
12
+ def ProcessFile(file_path, search_type: str = "dense"):
13
  """
14
  Wrapper function to process file for dense or sparse search
15
 
16
  Args:
17
  file_path: File to process
18
  search_type: Type of search to use. Options: "dense", "sparse"
 
19
  """
20
 
21
  # Preprocess: remove quoted lines and handle email boundaries
 
68
  ProcessFileSparse(truncated_temp_file, file_path)
69
  elif search_type == "dense":
70
  # Handle dense search with ChromaDB
71
+ ProcessFileDense(truncated_temp_file, file_path)
72
  else:
73
  raise ValueError(f"Unsupported search type: {search_type}")
74
  finally:
 
80
  pass
81
 
82
 
83
+ def ProcessFileDense(cleaned_temp_file, file_path):
84
  """
85
  Process file for dense vector search using ChromaDB
86
  """
87
  # Get a retriever instance
88
+ retriever = BuildRetriever("dense")
89
  # Load cleaned text file
90
  loader = TextLoader(cleaned_temp_file)
91
  documents = loader.load()
main.py CHANGED
@@ -5,20 +5,15 @@ from langchain_core.prompts import ChatPromptTemplate
5
  from langgraph.checkpoint.memory import MemorySaver
6
  from langchain_core.messages import SystemMessage
7
  from langchain_core.messages import ToolMessage
 
8
  from dotenv import load_dotenv
9
  from datetime import datetime
10
  import logging
11
- import torch
12
  import glob
13
  import ast
14
  import os
15
 
16
- # Imports for local and remote chat models
17
- from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline
18
- from langchain_openai import ChatOpenAI
19
-
20
  # Local modules
21
- from pipeline import MyTextGenerationPipeline
22
  from retriever import BuildRetriever, db_dir
23
  from prompts import answer_prompt
24
  from index import ProcessFile
@@ -32,16 +27,9 @@ from graph import BuildGraph
32
  # Setup environment variables
33
  load_dotenv(dotenv_path=".env", override=True)
34
 
35
- # Define the remote (OpenAI) model
36
  openai_model = "gpt-4o-mini"
37
 
38
- # Get the local model ID
39
- model_id = os.getenv("MODEL_ID")
40
- if model_id is None:
41
- # model_id = "HuggingFaceTB/SmolLM3-3B"
42
- model_id = "google/gemma-3-12b-it"
43
- # model_id = "Qwen/Qwen3-14B"
44
-
45
  # Suppress these messages:
46
  # INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
47
  # INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
@@ -50,30 +38,29 @@ httpx_logger = logging.getLogger("httpx")
50
  httpx_logger.setLevel(logging.WARNING)
51
 
52
 
53
- def ProcessDirectory(path, compute_mode):
54
  """
55
  Update vector store and sparse index for files in a directory, only adding new or updated files
56
 
57
  Args:
58
  path: Directory to process
59
- compute_mode: Compute mode for embeddings (remote or local)
60
 
61
  Usage example:
62
- ProcessDirectory("R-help", "remote")
63
  """
64
 
65
  # TODO: use UUID to process only changed documents
66
  # https://stackoverflow.com/questions/76265631/chromadb-add-single-document-only-if-it-doesnt-exist
67
 
68
  # Get a dense retriever instance
69
- retriever = BuildRetriever(compute_mode, "dense")
70
 
71
  # List all text files in target directory
72
  file_paths = glob.glob(f"{path}/*.txt")
73
  for file_path in file_paths:
74
 
75
  # Process file for sparse search (BM25S)
76
- ProcessFile(file_path, "sparse", compute_mode)
77
 
78
  # Logic for dense search: skip file if already indexed
79
  # Look for existing embeddings for this file
@@ -103,7 +90,7 @@ def ProcessDirectory(path, compute_mode):
103
  update_file = True
104
 
105
  if add_file:
106
- ProcessFile(file_path, "dense", compute_mode)
107
 
108
  if update_file:
109
  print(f"Chroma: updated embeddings for {file_path}")
@@ -114,7 +101,7 @@ def ProcessDirectory(path, compute_mode):
114
  ]
115
  files_to_keep = list(set(used_doc_ids))
116
  # Get all files in the file store
117
- file_store = f"{db_dir}/file_store_{compute_mode}"
118
  all_files = os.listdir(file_store)
119
  # Iterate through the files and delete those not in the list
120
  for file in all_files:
@@ -127,93 +114,32 @@ def ProcessDirectory(path, compute_mode):
127
  print(f"Chroma: no change for {file_path}")
128
 
129
 
130
- def GetChatModel(compute_mode, ckpt_dir=None):
131
- """
132
- Get a chat model.
133
-
134
- Args:
135
- compute_mode: Compute mode for chat model (remote or local)
136
- ckpt_dir: Checkpoint directory for model weights (optional)
137
- """
138
-
139
- if compute_mode == "remote":
140
-
141
- chat_model = ChatOpenAI(model=openai_model, temperature=0)
142
-
143
- if compute_mode == "local":
144
-
145
- # Don't try to use local models without a GPU
146
- if compute_mode == "local" and not torch.cuda.is_available():
147
- raise Exception("Local chat model selected without GPU")
148
-
149
- # Define the pipeline to pass to the HuggingFacePipeline class
150
- # https://huggingface.co/blog/langchain
151
- id_or_dir = ckpt_dir if ckpt_dir else model_id
152
- tokenizer = AutoTokenizer.from_pretrained(id_or_dir)
153
- model = AutoModelForCausalLM.from_pretrained(
154
- id_or_dir,
155
- # We need this to load the model in BF16 instead of fp32 (torch.float)
156
- torch_dtype=torch.bfloat16,
157
- # Enable FlashAttention (requires pip install flash-attn)
158
- # https://huggingface.co/docs/transformers/en/attention_interface
159
- # https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention
160
- # attn_implementation="flash_attention_2",
161
- )
162
- # For Flash Attention version of Qwen3
163
- tokenizer.padding_side = "left"
164
-
165
- # Use MyTextGenerationPipeline with custom preprocess() method
166
- pipe = MyTextGenerationPipeline(
167
- model=model,
168
- tokenizer=tokenizer,
169
- # ToolCallingLLM needs return_full_text=False in order to parse just the assistant response
170
- return_full_text=False,
171
- # It seems that max_new_tokens has to be specified here, not in .invoke()
172
- max_new_tokens=2000,
173
- # Use padding for proper alignment for FlashAttention
174
- # Part of fix for: "RuntimeError: p.attn_bias_ptr is not correctly aligned"
175
- # https://github.com/google-deepmind/gemma/issues/169
176
- padding="longest",
177
- )
178
- # We need the task so HuggingFacePipeline can deal with our class
179
- pipe.task = "text-generation"
180
-
181
- llm = HuggingFacePipeline(pipeline=pipe)
182
- chat_model = ChatHuggingFace(llm=llm)
183
-
184
- return chat_model
185
-
186
-
187
  def RunChain(
188
  query,
189
- compute_mode: str = "remote",
190
  search_type: str = "hybrid",
191
- think: bool = False,
192
  ):
193
  """
194
  Run chain to retrieve documents and send to chat
195
 
196
  Args:
197
  query: User's query
198
- compute_mode: Compute mode for embedding and chat models (remote or local)
199
  search_type: Type of search to use. Options: "dense", "sparse", or "hybrid"
200
- think: Control thinking mode for SmolLM3
201
 
202
  Example:
203
  RunChain("What R functions are discussed?")
204
  """
205
 
206
  # Get retriever instance
207
- retriever = BuildRetriever(compute_mode, search_type)
208
 
209
  if retriever is None:
210
  return "No retriever available. Please process some documents first."
211
 
212
  # Get chat model (LLM)
213
- chat_model = GetChatModel(compute_mode)
214
 
215
- # Get prompt with /no_think for SmolLM3/Qwen
216
- system_prompt = answer_prompt(chat_model)
217
 
218
  # Create a prompt template
219
  system_template = ChatPromptTemplate.from_messages([SystemMessage(system_prompt)])
@@ -244,22 +170,16 @@ def RunChain(
244
 
245
  def RunGraph(
246
  query: str,
247
- compute_mode: str = "remote",
248
  search_type: str = "hybrid",
249
  top_k: int = 6,
250
- think_query=False,
251
- think_answer=False,
252
  thread_id=None,
253
  ):
254
  """Run graph for conversational RAG app
255
 
256
  Args:
257
  query: User query to start the chat
258
- compute_mode: Compute mode for embedding and chat models (remote or local)
259
  search_type: Type of search to use. Options: "dense", "sparse", or "hybrid"
260
  top_k: Number of documents to retrieve
261
- think_query: Whether to use thinking mode for the query
262
- think_answer: Whether to use thinking mode for the answer
263
  thread_id: Thread ID for memory (optional)
264
 
265
  Example:
@@ -267,15 +187,12 @@ def RunGraph(
267
  """
268
 
269
  # Get chat model used in both query and generate steps
270
- chat_model = GetChatModel(compute_mode)
271
  # Build the graph
272
  graph_builder = BuildGraph(
273
  chat_model,
274
- compute_mode,
275
  search_type,
276
  top_k,
277
- think_query,
278
- think_answer,
279
  )
280
 
281
  # Compile the graph with an in-memory checkpointer
 
5
  from langgraph.checkpoint.memory import MemorySaver
6
  from langchain_core.messages import SystemMessage
7
  from langchain_core.messages import ToolMessage
8
+ from langchain_openai import ChatOpenAI
9
  from dotenv import load_dotenv
10
  from datetime import datetime
11
  import logging
 
12
  import glob
13
  import ast
14
  import os
15
 
 
 
 
 
16
  # Local modules
 
17
  from retriever import BuildRetriever, db_dir
18
  from prompts import answer_prompt
19
  from index import ProcessFile
 
27
  # Setup environment variables
28
  load_dotenv(dotenv_path=".env", override=True)
29
 
30
+ # Define the OpenAI model
31
  openai_model = "gpt-4o-mini"
32
 
 
 
 
 
 
 
 
33
  # Suppress these messages:
34
  # INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
35
  # INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
 
38
  httpx_logger.setLevel(logging.WARNING)
39
 
40
 
41
+ def ProcessDirectory(path):
42
  """
43
  Update vector store and sparse index for files in a directory, only adding new or updated files
44
 
45
  Args:
46
  path: Directory to process
 
47
 
48
  Usage example:
49
+ ProcessDirectory("R-help")
50
  """
51
 
52
  # TODO: use UUID to process only changed documents
53
  # https://stackoverflow.com/questions/76265631/chromadb-add-single-document-only-if-it-doesnt-exist
54
 
55
  # Get a dense retriever instance
56
+ retriever = BuildRetriever("dense")
57
 
58
  # List all text files in target directory
59
  file_paths = glob.glob(f"{path}/*.txt")
60
  for file_path in file_paths:
61
 
62
  # Process file for sparse search (BM25S)
63
+ ProcessFile(file_path, "sparse")
64
 
65
  # Logic for dense search: skip file if already indexed
66
  # Look for existing embeddings for this file
 
90
  update_file = True
91
 
92
  if add_file:
93
+ ProcessFile(file_path, "dense")
94
 
95
  if update_file:
96
  print(f"Chroma: updated embeddings for {file_path}")
 
101
  ]
102
  files_to_keep = list(set(used_doc_ids))
103
  # Get all files in the file store
104
+ file_store = f"{db_dir}/file_store"
105
  all_files = os.listdir(file_store)
106
  # Iterate through the files and delete those not in the list
107
  for file in all_files:
 
114
  print(f"Chroma: no change for {file_path}")
115
 
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  def RunChain(
118
  query,
 
119
  search_type: str = "hybrid",
 
120
  ):
121
  """
122
  Run chain to retrieve documents and send to chat
123
 
124
  Args:
125
  query: User's query
 
126
  search_type: Type of search to use. Options: "dense", "sparse", or "hybrid"
 
127
 
128
  Example:
129
  RunChain("What R functions are discussed?")
130
  """
131
 
132
  # Get retriever instance
133
+ retriever = BuildRetriever(search_type)
134
 
135
  if retriever is None:
136
  return "No retriever available. Please process some documents first."
137
 
138
  # Get chat model (LLM)
139
+ chat_model = ChatOpenAI(model=openai_model, temperature=0)
140
 
141
+ # Get system prompt
142
+ system_prompt = answer_prompt()
143
 
144
  # Create a prompt template
145
  system_template = ChatPromptTemplate.from_messages([SystemMessage(system_prompt)])
 
170
 
171
  def RunGraph(
172
  query: str,
 
173
  search_type: str = "hybrid",
174
  top_k: int = 6,
 
 
175
  thread_id=None,
176
  ):
177
  """Run graph for conversational RAG app
178
 
179
  Args:
180
  query: User query to start the chat
 
181
  search_type: Type of search to use. Options: "dense", "sparse", or "hybrid"
182
  top_k: Number of documents to retrieve
 
 
183
  thread_id: Thread ID for memory (optional)
184
 
185
  Example:
 
187
  """
188
 
189
  # Get chat model used in both query and generate steps
190
+ chat_model = ChatOpenAI(model=openai_model, temperature=0)
191
  # Build the graph
192
  graph_builder = BuildGraph(
193
  chat_model,
 
194
  search_type,
195
  top_k,
 
 
196
  )
197
 
198
  # Compile the graph with an in-memory checkpointer
mods/tool_calling_llm.py DELETED
@@ -1,313 +0,0 @@
1
- import re
2
- import json
3
- import uuid
4
- import warnings
5
- from abc import ABC
6
- from typing import (
7
- Any,
8
- AsyncIterator,
9
- Callable,
10
- Dict,
11
- List,
12
- Optional,
13
- Sequence,
14
- Tuple,
15
- Type,
16
- Union,
17
- cast,
18
- )
19
-
20
- from langchain_core.callbacks import (
21
- AsyncCallbackManagerForLLMRun,
22
- CallbackManagerForLLMRun,
23
- )
24
- from langchain_core.language_models import BaseChatModel, LanguageModelInput
25
- from langchain_core.messages import (
26
- SystemMessage,
27
- AIMessage,
28
- BaseMessage,
29
- BaseMessageChunk,
30
- ToolCall,
31
- )
32
- from langchain_core.outputs import ChatGeneration, ChatResult
33
- from langchain_core.prompts import SystemMessagePromptTemplate
34
- from pydantic import BaseModel
35
- from langchain_core.runnables import Runnable, RunnableConfig
36
- from langchain_core.tools import BaseTool
37
- from langchain_core.utils.function_calling import convert_to_openai_tool
38
-
39
- DEFAULT_SYSTEM_TEMPLATE = """You have access to the following tools:
40
-
41
- {tools}
42
-
43
- You must always select one of the above tools and respond with only a JSON object matching the following schema:
44
-
45
- {{
46
- "tool": <name of selected tool 1>,
47
- "tool_input": <parameters for selected tool 1, matching the tool's JSON schema>
48
- }},
49
- {{
50
- "tool": <name of selected tool 2>,
51
- "tool_input": <parameters for selected tool 2, matching the tool's JSON schema>
52
- }}
53
- """ # noqa: E501
54
-
55
-
56
- def extract_think(content):
57
- # Added by Cursor 20250726 jmd
58
- # Extract content within <think>...</think>
59
- think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
60
- think_text = think_match.group(1).strip() if think_match else ""
61
- # Extract text after </think>
62
- if think_match:
63
- post_think = content[think_match.end() :].lstrip()
64
- else:
65
- # Check if content starts with <think> but missing closing tag
66
- if content.strip().startswith("<think>"):
67
- # Extract everything after <think>
68
- think_start = content.find("<think>") + len("<think>")
69
- think_text = content[think_start:].strip()
70
- post_think = ""
71
- else:
72
- # No <think> found, so return entire content as post_think
73
- post_think = content
74
- return think_text, post_think
75
-
76
-
77
- class ToolCallingLLM(BaseChatModel, ABC):
78
- """ToolCallingLLM mixin to enable tool calling features on non tool calling models.
79
-
80
- Note: This is an incomplete mixin and should not be used directly. It must be used to extent an existing Chat Model.
81
-
82
- Setup:
83
- Install dependencies for your Chat Model.
84
- Any API Keys or setup needed for your Chat Model is still applicable.
85
-
86
- Key init args — completion params:
87
- Refer to the documentation of the Chat Model you wish to extend with Tool Calling.
88
-
89
- Key init args — client params:
90
- Refer to the documentation of the Chat Model you wish to extend with Tool Calling.
91
-
92
- See full list of supported init args and their descriptions in the params section.
93
-
94
- Instantiate:
95
- ```
96
- # Example implementation using LiteLLM
97
- from langchain_community.chat_models import ChatLiteLLM
98
-
99
- class LiteLLMFunctions(ToolCallingLLM, ChatLiteLLM):
100
-
101
- def __init__(self, **kwargs: Any) -> None:
102
- super().__init__(**kwargs)
103
-
104
- @property
105
- def _llm_type(self) -> str:
106
- return "litellm_functions"
107
-
108
- llm = LiteLLMFunctions(model="ollama/phi3")
109
- ```
110
-
111
- Invoke:
112
- ```
113
- messages = [
114
- ("human", "What is the capital of France?")
115
- ]
116
- llm.invoke(messages)
117
- ```
118
- ```
119
- AIMessage(content='The capital of France is Paris.', id='run-497d0e1a-d63b-45e8-9c8b-5e76d99b9468-0')
120
- ```
121
-
122
- Tool calling:
123
- ```
124
- from pydantic import BaseModel, Field
125
-
126
- class GetWeather(BaseModel):
127
- '''Get the current weather in a given location'''
128
-
129
- location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
130
-
131
- class GetPopulation(BaseModel):
132
- '''Get the current population in a given location'''
133
-
134
- location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
135
-
136
- llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
137
- ai_msg = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?")
138
- ai_msg.tool_calls
139
- ```
140
- ```
141
- [{'name': 'GetWeather', 'args': {'location': 'Austin, TX'}, 'id': 'call_25ed526917b94d8fa5db3fe30a8cf3c0'}]
142
- ```
143
-
144
- Response metadata
145
- Refer to the documentation of the Chat Model you wish to extend with Tool Calling.
146
-
147
- """ # noqa: E501
148
-
149
- tool_system_prompt_template: str = DEFAULT_SYSTEM_TEMPLATE
150
-
151
- def __init__(self, **kwargs: Any) -> None:
152
- super().__init__(**kwargs)
153
-
154
- def _generate_system_message_and_functions(
155
- self,
156
- kwargs: Dict[str, Any],
157
- ) -> Tuple[BaseMessage, List]:
158
- functions = kwargs.get("tools", [])
159
-
160
- # Convert functions to OpenAI tool schema
161
- functions = [convert_to_openai_tool(fn) for fn in functions]
162
- # Create system message with tool descriptions
163
- system_message_prompt_template = SystemMessagePromptTemplate.from_template(
164
- self.tool_system_prompt_template
165
- )
166
- system_message = system_message_prompt_template.format(
167
- tools=json.dumps(functions, indent=2)
168
- )
169
- return system_message, functions
170
-
171
- def _process_response(
172
- self, response_message: BaseMessage, functions: List[Dict]
173
- ) -> AIMessage:
174
- if not isinstance(response_message.content, str):
175
- raise ValueError("ToolCallingLLM does not support non-string output.")
176
-
177
- # Extract <think>...</think> content and text after </think> for further processing 20250726 jmd
178
- think_text, post_think = extract_think(response_message.content)
179
-
180
- ## For debugging
181
- # print("post_think")
182
- # print(post_think)
183
-
184
- # Remove backticks around code blocks
185
- post_think = re.sub(r"^```json", "", post_think)
186
- post_think = re.sub(r"^```", "", post_think)
187
- post_think = re.sub(r"```$", "", post_think)
188
- # Remove intervening backticks from adjacent code blocks
189
- post_think = re.sub(r"```\n```json", ",", post_think)
190
- # Remove trailing comma (if there is one)
191
- post_think = post_think.rstrip(",")
192
- # Parse output for JSON (support multiple objects separated by commas)
193
- try:
194
- # Works for one JSON object, or multiple JSON objects enclosed in "[]"
195
- parsed_json_results = json.loads(f"{post_think}")
196
- if not isinstance(parsed_json_results, list):
197
- parsed_json_results = [parsed_json_results]
198
- except:
199
- try:
200
- # Works for multiple JSON objects not enclosed in "[]"
201
- parsed_json_results = json.loads(f"[{post_think}]")
202
- except json.JSONDecodeError:
203
- # Return entire response if JSON wasn't parsed or is missing
204
- return AIMessage(content=response_message.content)
205
-
206
- # print("parsed_json_results")
207
- # print(parsed_json_results)
208
-
209
- tool_calls = []
210
- for parsed_json_result in parsed_json_results:
211
- # Get tool name from output
212
- called_tool_name = (
213
- parsed_json_result["tool"]
214
- if "tool" in parsed_json_result
215
- else (
216
- parsed_json_result["name"] if "name" in parsed_json_result else None
217
- )
218
- )
219
-
220
- # Check if tool name is in functions list
221
- called_tool = next(
222
- (fn for fn in functions if fn["function"]["name"] == called_tool_name),
223
- None,
224
- )
225
- if called_tool is None:
226
- # Issue a warning and skip this tool call
227
- warnings.warn(f"Called tool ({called_tool_name}) not in functions list")
228
- continue
229
-
230
- # Get tool arguments from output
231
- called_tool_arguments = (
232
- parsed_json_result["tool_input"]
233
- if "tool_input" in parsed_json_result
234
- else (
235
- parsed_json_result["parameters"]
236
- if "parameters" in parsed_json_result
237
- else {}
238
- )
239
- )
240
-
241
- tool_calls.append(
242
- ToolCall(
243
- name=called_tool_name,
244
- args=called_tool_arguments,
245
- id=f"call_{str(uuid.uuid4()).replace('-', '')}",
246
- )
247
- )
248
-
249
- if not tool_calls:
250
- # If nothing valid, return original content
251
- return AIMessage(content=response_message.content)
252
-
253
- # Put together response message
254
- response_message = AIMessage(
255
- content=f"<think>\n{think_text}\n</think>",
256
- tool_calls=tool_calls,
257
- )
258
- return response_message
259
-
260
- def _generate(
261
- self,
262
- messages: List[BaseMessage],
263
- stop: Optional[List[str]] = None,
264
- run_manager: Optional[CallbackManagerForLLMRun] = None,
265
- **kwargs: Any,
266
- ) -> ChatResult:
267
- system_message, functions = self._generate_system_message_and_functions(kwargs)
268
- response_message = super()._generate( # type: ignore[safe-super]
269
- [system_message] + messages, stop=stop, run_manager=run_manager, **kwargs
270
- )
271
- response = self._process_response(
272
- response_message.generations[0].message, functions
273
- )
274
- return ChatResult(generations=[ChatGeneration(message=response)])
275
-
276
- async def _agenerate(
277
- self,
278
- messages: List[BaseMessage],
279
- stop: Optional[List[str]] = None,
280
- run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
281
- **kwargs: Any,
282
- ) -> ChatResult:
283
- system_message, functions = self._generate_system_message_and_functions(kwargs)
284
- response_message = await super()._agenerate(
285
- [system_message] + messages, stop=stop, run_manager=run_manager, **kwargs
286
- )
287
- response = self._process_response(
288
- response_message.generations[0].message, functions
289
- )
290
- return ChatResult(generations=[ChatGeneration(message=response)])
291
-
292
- async def astream(
293
- self,
294
- input: LanguageModelInput,
295
- config: Optional[RunnableConfig] = None,
296
- *,
297
- stop: Optional[List[str]] = None,
298
- **kwargs: Any,
299
- ) -> AsyncIterator[BaseMessageChunk]:
300
- system_message, functions = self._generate_system_message_and_functions(kwargs)
301
- generation: Optional[BaseMessageChunk] = None
302
- async for chunk in super().astream(
303
- [system_message] + super()._convert_input(input).to_messages(),
304
- stop=stop,
305
- **kwargs,
306
- ):
307
- if generation is None:
308
- generation = chunk
309
- else:
310
- generation += chunk
311
- assert generation is not None
312
- response = self._process_response(generation, functions)
313
- yield cast(BaseMessageChunk, response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pipeline.py DELETED
@@ -1,86 +0,0 @@
1
- from transformers.pipelines.text_generation import Chat
2
- from transformers import TextGenerationPipeline
3
- from typing import Dict
4
-
5
-
6
- class MyTextGenerationPipeline(TextGenerationPipeline):
7
- """
8
- This subclass overrides the preprocess method to add pad_to_multiple_of=8 to tokenizer_kwargs.
9
- Fix for: "RuntimeError: p.attn_bias_ptr is not correctly aligned"
10
- https://github.com/google-deepmind/gemma/issues/169
11
- NOTE: we also need padding="longest", which is set during class instantiation
12
- """
13
-
14
- def preprocess(
15
- self,
16
- prompt_text,
17
- prefix="",
18
- handle_long_generation=None,
19
- add_special_tokens=None,
20
- truncation=None,
21
- padding=None,
22
- max_length=None,
23
- continue_final_message=None,
24
- **generate_kwargs,
25
- ):
26
- # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
27
- tokenizer_kwargs = {
28
- "add_special_tokens": add_special_tokens,
29
- "truncation": truncation,
30
- "padding": padding,
31
- "max_length": max_length,
32
- "pad_to_multiple_of": 8,
33
- }
34
- tokenizer_kwargs = {
35
- key: value for key, value in tokenizer_kwargs.items() if value is not None
36
- }
37
-
38
- if isinstance(prompt_text, Chat):
39
- tokenizer_kwargs.pop(
40
- "add_special_tokens", None
41
- ) # ignore add_special_tokens on chats
42
- # If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
43
- # because very few models support multiple separate, consecutive assistant messages
44
- if continue_final_message is None:
45
- continue_final_message = prompt_text.messages[-1]["role"] == "assistant"
46
- inputs = self.tokenizer.apply_chat_template(
47
- prompt_text.messages,
48
- add_generation_prompt=not continue_final_message,
49
- continue_final_message=continue_final_message,
50
- return_dict=True,
51
- return_tensors=self.framework,
52
- **tokenizer_kwargs,
53
- )
54
- else:
55
- inputs = self.tokenizer(
56
- prefix + prompt_text, return_tensors=self.framework, **tokenizer_kwargs
57
- )
58
-
59
- inputs["prompt_text"] = prompt_text
60
-
61
- if handle_long_generation == "hole":
62
- cur_len = inputs["input_ids"].shape[-1]
63
- if "max_new_tokens" in generate_kwargs:
64
- new_tokens = generate_kwargs["max_new_tokens"]
65
- else:
66
- new_tokens = (
67
- generate_kwargs.get("max_length", self.generation_config.max_length)
68
- - cur_len
69
- )
70
- if new_tokens < 0:
71
- raise ValueError("We cannot infer how many new tokens are expected")
72
- if cur_len + new_tokens > self.tokenizer.model_max_length:
73
- keep_length = self.tokenizer.model_max_length - new_tokens
74
- if keep_length <= 0:
75
- raise ValueError(
76
- "We cannot use `hole` to handle this generation the number of desired tokens exceeds the"
77
- " models max length"
78
- )
79
-
80
- inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
81
- if "attention_mask" in inputs:
82
- inputs["attention_mask"] = inputs["attention_mask"][
83
- :, -keep_length:
84
- ]
85
-
86
- return inputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
prompts.py CHANGED
@@ -3,22 +3,16 @@ from util import get_sources, get_start_end_months
3
  import re
4
 
5
 
6
- def check_prompt(prompt, chat_model, think):
7
- """Check for unassigned variables and add /no_think if needed"""
8
  # A sanity check that we don't have unassigned variables
9
- # (this causes KeyError in parsing by ToolCallingLLM)
10
  matches = re.findall(r"\{.*?\}", " ".join(prompt))
11
  if matches:
12
  raise ValueError(f"Unassigned variables in prompt: {' '.join(matches)}")
13
- # Check if we should add /no_think to turn off thinking mode
14
- if hasattr(chat_model, "model_id"):
15
- model_id = chat_model.model_id
16
- if ("SmolLM" in model_id or "Qwen" in model_id) and not think:
17
- prompt = "/no_think\n" + prompt
18
  return prompt
19
 
20
 
21
- def query_prompt(chat_model, think=False):
22
  """Return system prompt for query step"""
23
 
24
  # Get start and end months from database
@@ -43,12 +37,12 @@ def query_prompt(chat_model, think=False):
43
  # "Do not use your memory or knowledge to answer the user's question. Only retrieve emails based on the user's question. " # Qwen
44
  # "If you decide not to retrieve emails, tell the user why and suggest how to improve their question to chat with the R-help mailing list. "
45
  )
46
- prompt = check_prompt(prompt, chat_model, think)
47
 
48
  return prompt
49
 
50
 
51
- def answer_prompt(chat_model, think=False, with_tools=False):
52
  """Return system prompt for answer step"""
53
  prompt = (
54
  f"Today Date: {date.today()}. "
@@ -64,61 +58,8 @@ def answer_prompt(chat_model, think=False, with_tools=False):
64
  "Only answer general questions about R if the answer is in the retrieved emails. "
65
  "Only include URLs if they were used by human authors (not in email headers), and do not modify any URLs. " # Qwen, Gemma
66
  "Respond with 500 words maximum and 50 lines of code maximum. "
 
67
  )
68
- if with_tools:
69
- prompt = (
70
- f"{prompt}"
71
- "Use answer_with_citations to provide the complete answer and all citations used. "
72
- )
73
- prompt = check_prompt(prompt, chat_model, think)
74
 
75
  return prompt
76
-
77
-
78
- # Prompt template for SmolLM3 with tools
79
- # The first two lines, <function-name>, and <args-json-object> are from the apply_chat_template for HuggingFaceTB/SmolLM3-3B
80
- # The other lines (You have, {tools}, You must), "tool", and "tool_input" are from tool_calling_llm.py
81
- smollm3_tools_template = """
82
-
83
- ### Tools
84
-
85
- You may call one or more functions to assist with the user query.
86
-
87
- You have access to the following tools:
88
-
89
- {tools}
90
-
91
- You must always select one of the above tools and respond with only a JSON object matching the following schema:
92
-
93
- {{
94
- "tool": <function-name>,
95
- "tool_input": <args-json-object>
96
- }},
97
- {{
98
- "tool": <function-name>,
99
- "tool_input": <args-json-object>
100
- }}
101
-
102
- """
103
-
104
- # Prompt template for Gemma/Qwen with tools
105
- # Based on https://ai.google.dev/gemma/docs/capabilities/function-calling
106
- generic_tools_template = """
107
-
108
- ### Functions
109
-
110
- You have access to functions. If you decide to invoke any of the function(s), you MUST put it in the format of
111
-
112
- {{
113
- "tool": <function-name>,
114
- "tool_input": <args-json-object>
115
- }},
116
- {{
117
- "tool": <function-name>,
118
- "tool_input": <args-json-object>
119
- }}
120
-
121
- You SHOULD NOT include any other text in the response if you call a function
122
-
123
- {tools}
124
- """
 
3
  import re
4
 
5
 
6
+ def check_prompt(prompt):
7
+ """Check for unassigned variables"""
8
  # A sanity check that we don't have unassigned variables
 
9
  matches = re.findall(r"\{.*?\}", " ".join(prompt))
10
  if matches:
11
  raise ValueError(f"Unassigned variables in prompt: {' '.join(matches)}")
 
 
 
 
 
12
  return prompt
13
 
14
 
15
+ def query_prompt():
16
  """Return system prompt for query step"""
17
 
18
  # Get start and end months from database
 
37
  # "Do not use your memory or knowledge to answer the user's question. Only retrieve emails based on the user's question. " # Qwen
38
  # "If you decide not to retrieve emails, tell the user why and suggest how to improve their question to chat with the R-help mailing list. "
39
  )
40
+ prompt = check_prompt(prompt)
41
 
42
  return prompt
43
 
44
 
45
+ def answer_prompt():
46
  """Return system prompt for answer step"""
47
  prompt = (
48
  f"Today Date: {date.today()}. "
 
58
  "Only answer general questions about R if the answer is in the retrieved emails. "
59
  "Only include URLs if they were used by human authors (not in email headers), and do not modify any URLs. " # Qwen, Gemma
60
  "Respond with 500 words maximum and 50 lines of code maximum. "
61
+ "Use answer_with_citations to provide the complete answer and all citations used. "
62
  )
63
+ prompt = check_prompt(prompt)
 
 
 
 
 
64
 
65
  return prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,25 +1,17 @@
1
- # Pin torch and chroma versions
2
- torch==2.5.1
 
 
 
 
 
 
3
  chromadb==0.6.3
4
  # NOTE: chromadb==1.0.13 was giving intermittent error:
5
  # ValueError('Could not connect to tenant default_tenant. Are you sure it exists?')
6
-
7
- # FlashAttention
8
- #flash-attn==2.8.2
9
-
10
- # Stated requirements:
11
- # Gemma 3: transformers>=4.50
12
- # Qwen3: transformers>=4.51
13
- # SmolLM3: transformers>=4.53
14
- transformers==4.51.3
15
- tokenizers==0.21.2
16
- # Only needed with AutoModelForCausalLM.from_pretrained(device_map="auto")
17
- #accelerate==1.8.1
18
-
19
- # Required by langchain-huggingface
20
- sentence-transformers==5.0.0
21
- # For snapshot_download
22
- huggingface-hub==0.34.3
23
 
24
  # Langchain packages
25
  langchain==0.3.26
@@ -27,31 +19,14 @@ langchain-core==0.3.72
27
  langchain-chroma==0.2.3
28
  langchain-openai==0.3.27
29
  langchain-community==0.3.27
30
- langchain-huggingface==0.3.0
31
  langchain-text-splitters==0.3.8
32
  langgraph==0.4.7
33
  langgraph-sdk==0.1.72
34
  langgraph-prebuilt==0.5.2
35
  langgraph-checkpoint==2.1.0
36
 
37
- # Required by Nomic embeddings
38
- einops==0.8.1
39
-
40
- # Commented because we have local modifications
41
- #tool-calling-llm==0.1.2
42
- bm25s==0.2.12
43
  ragas==0.2.15
44
 
45
- # posthog<6.0.0 is temporary fix for ChromaDB telemetry error log messages
46
- # https://github.com/vanna-ai/vanna/issues/917
47
- posthog==5.4.0
48
-
49
- # Gradio for the web interface
50
  gradio==5.38.2
51
- spaces==0.37.1
52
-
53
- # For downloading data from S3
54
- boto3==1.39.14
55
-
56
- # Others
57
- python-dotenv==1.1.1
 
1
+ # To load API keys
2
+ python-dotenv==1.1.1
3
+
4
+ # To download data from S3
5
+ boto3==1.39.14
6
+
7
+ # Retrieval
8
+ bm25s==0.2.12
9
  chromadb==0.6.3
10
  # NOTE: chromadb==1.0.13 was giving intermittent error:
11
  # ValueError('Could not connect to tenant default_tenant. Are you sure it exists?')
12
+ # posthog<6.0.0 is temporary fix for ChromaDB telemetry error log messages
13
+ # https://github.com/vanna-ai/vanna/issues/917
14
+ posthog==5.4.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  # Langchain packages
17
  langchain==0.3.26
 
19
  langchain-chroma==0.2.3
20
  langchain-openai==0.3.27
21
  langchain-community==0.3.27
 
22
  langchain-text-splitters==0.3.8
23
  langgraph==0.4.7
24
  langgraph-sdk==0.1.72
25
  langgraph-prebuilt==0.5.2
26
  langgraph-checkpoint==2.1.0
27
 
28
+ # Evaluations
 
 
 
 
 
29
  ragas==0.2.15
30
 
31
+ # Frontend
 
 
 
 
32
  gradio==5.38.2
 
 
 
 
 
 
 
retriever.py CHANGED
@@ -1,25 +1,17 @@
1
  # Main retriever modules
2
- from langchain_text_splitters import RecursiveCharacterTextSplitter
3
- from langchain_community.document_loaders import TextLoader
4
- from langchain_chroma import Chroma
5
  from langchain.retrievers import ParentDocumentRetriever, EnsembleRetriever
6
- from langchain_core.documents import Document
7
  from langchain_core.retrievers import BaseRetriever, RetrieverLike
8
  from langchain_core.callbacks import CallbackManagerForRetrieverRun
 
 
 
 
 
9
  from typing import Any, Optional
10
  import chromadb
11
- import torch
12
  import os
13
  import re
14
 
15
- # To use OpenAI models (remote)
16
- from langchain_openai import OpenAIEmbeddings
17
-
18
- ## To use Hugging Face models (local)
19
- # from langchain_huggingface import HuggingFaceEmbeddings
20
- # For more control over BGE and Nomic embeddings
21
- from langchain_community.embeddings import HuggingFaceBgeEmbeddings
22
-
23
  # Local modules
24
  from mods.bm25s_retriever import BM25SRetriever
25
  from mods.file_system import LocalFileStore
@@ -27,41 +19,30 @@ from mods.file_system import LocalFileStore
27
  # Database directory
28
  db_dir = "db"
29
 
30
- # Embedding model
31
- embedding_model_id = "nomic-ai/nomic-embed-text-v1.5"
32
-
33
 
34
  def BuildRetriever(
35
- compute_mode,
36
  search_type: str = "hybrid",
37
  top_k=6,
38
  start_year=None,
39
  end_year=None,
40
- embedding_ckpt_dir=None,
41
  ):
42
  """
43
  Build retriever instance.
44
  All retriever types are configured to return up to 6 documents for fair comparison in evals.
45
 
46
  Args:
47
- compute_mode: Compute mode for embeddings (remote or local)
48
  search_type: Type of search to use. Options: "dense", "sparse", "hybrid"
49
  top_k: Number of documents to retrieve for "dense" and "sparse"
50
  start_year: Start year (optional)
51
  end_year: End year (optional)
52
- embedding_ckpt_dir: Directory for embedding model checkpoint
53
  """
54
  if search_type == "dense":
55
  if not (start_year or end_year):
56
  # No year filtering, so directly use base retriever
57
- return BuildRetrieverDense(
58
- compute_mode, top_k=top_k, embedding_ckpt_dir=embedding_ckpt_dir
59
- )
60
  else:
61
  # Get 1000 documents then keep top_k filtered by year
62
- base_retriever = BuildRetrieverDense(
63
- compute_mode, top_k=1000, embedding_ckpt_dir=embedding_ckpt_dir
64
- )
65
  return TopKRetriever(
66
  base_retriever=base_retriever,
67
  top_k=top_k,
@@ -85,20 +66,16 @@ def BuildRetriever(
85
  # Use floor (top_k // 2) and ceiling -(top_k // -2) to divide odd values of top_k
86
  # https://stackoverflow.com/questions/14822184/is-there-a-ceiling-equivalent-of-operator-in-python
87
  dense_retriever = BuildRetriever(
88
- compute_mode,
89
  "dense",
90
  (top_k // 2),
91
  start_year,
92
  end_year,
93
- embedding_ckpt_dir,
94
  )
95
  sparse_retriever = BuildRetriever(
96
- compute_mode,
97
  "sparse",
98
  -(top_k // -2),
99
  start_year,
100
  end_year,
101
- embedding_ckpt_dir,
102
  )
103
  ensemble_retriever = EnsembleRetriever(
104
  retrievers=[dense_retriever, sparse_retriever], weights=[1, 1]
@@ -128,43 +105,19 @@ def BuildRetrieverSparse(top_k=6):
128
  return retriever
129
 
130
 
131
- def BuildRetrieverDense(compute_mode: str, top_k=6, embedding_ckpt_dir=None):
132
  """
133
  Build dense retriever instance with ChromaDB vectorstore
134
 
135
  Args:
136
- compute_mode: Compute mode for embeddings (remote or local)
137
  top_k: Number of documents to retrieve
138
- embedding_ckpt_dir: Directory for embedding model checkpoint
139
  """
140
 
141
- # Don't try to use local models without a GPU
142
- if compute_mode == "local" and not torch.cuda.is_available():
143
- raise Exception("Local embeddings selected without GPU")
144
-
145
  # Define embedding model
146
- if compute_mode == "remote":
147
- embedding_function = OpenAIEmbeddings(model="text-embedding-3-small")
148
- if compute_mode == "local":
149
- # embedding_function = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5", show_progress=True)
150
- # https://python.langchain.com/api_reference/community/embeddings/langchain_community.embeddings.huggingface.HuggingFaceBgeEmbeddings.html
151
- model_kwargs = {
152
- "device": "cuda",
153
- "trust_remote_code": True,
154
- }
155
- encode_kwargs = {"normalize_embeddings": True}
156
- # Use embedding model ID or checkpoint directory if given
157
- id_or_dir = embedding_ckpt_dir if embedding_ckpt_dir else embedding_model_id
158
- embedding_function = HuggingFaceBgeEmbeddings(
159
- model_name=id_or_dir,
160
- model_kwargs=model_kwargs,
161
- encode_kwargs=encode_kwargs,
162
- query_instruction="search_query:",
163
- embed_instruction="search_document:",
164
- )
165
  # Create vector store
166
  client_settings = chromadb.config.Settings(anonymized_telemetry=False)
167
- persist_directory = f"{db_dir}/chroma_{compute_mode}"
168
  vectorstore = Chroma(
169
  collection_name="R-help",
170
  embedding_function=embedding_function,
@@ -172,7 +125,7 @@ def BuildRetrieverDense(compute_mode: str, top_k=6, embedding_ckpt_dir=None):
172
  persist_directory=persist_directory,
173
  )
174
  # The storage layer for the parent documents
175
- file_store = f"{db_dir}/file_store_{compute_mode}"
176
  byte_store = LocalFileStore(file_store)
177
  # Text splitter for child documents
178
  child_splitter = RecursiveCharacterTextSplitter(
 
1
  # Main retriever modules
 
 
 
2
  from langchain.retrievers import ParentDocumentRetriever, EnsembleRetriever
 
3
  from langchain_core.retrievers import BaseRetriever, RetrieverLike
4
  from langchain_core.callbacks import CallbackManagerForRetrieverRun
5
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
6
+ from langchain_community.document_loaders import TextLoader
7
+ from langchain_core.documents import Document
8
+ from langchain_openai import OpenAIEmbeddings
9
+ from langchain_chroma import Chroma
10
  from typing import Any, Optional
11
  import chromadb
 
12
  import os
13
  import re
14
 
 
 
 
 
 
 
 
 
15
  # Local modules
16
  from mods.bm25s_retriever import BM25SRetriever
17
  from mods.file_system import LocalFileStore
 
19
  # Database directory
20
  db_dir = "db"
21
 
 
 
 
22
 
23
  def BuildRetriever(
 
24
  search_type: str = "hybrid",
25
  top_k=6,
26
  start_year=None,
27
  end_year=None,
 
28
  ):
29
  """
30
  Build retriever instance.
31
  All retriever types are configured to return up to 6 documents for fair comparison in evals.
32
 
33
  Args:
 
34
  search_type: Type of search to use. Options: "dense", "sparse", "hybrid"
35
  top_k: Number of documents to retrieve for "dense" and "sparse"
36
  start_year: Start year (optional)
37
  end_year: End year (optional)
 
38
  """
39
  if search_type == "dense":
40
  if not (start_year or end_year):
41
  # No year filtering, so directly use base retriever
42
+ return BuildRetrieverDense(top_k=top_k)
 
 
43
  else:
44
  # Get 1000 documents then keep top_k filtered by year
45
+ base_retriever = BuildRetrieverDense(top_k=1000)
 
 
46
  return TopKRetriever(
47
  base_retriever=base_retriever,
48
  top_k=top_k,
 
66
  # Use floor (top_k // 2) and ceiling -(top_k // -2) to divide odd values of top_k
67
  # https://stackoverflow.com/questions/14822184/is-there-a-ceiling-equivalent-of-operator-in-python
68
  dense_retriever = BuildRetriever(
 
69
  "dense",
70
  (top_k // 2),
71
  start_year,
72
  end_year,
 
73
  )
74
  sparse_retriever = BuildRetriever(
 
75
  "sparse",
76
  -(top_k // -2),
77
  start_year,
78
  end_year,
 
79
  )
80
  ensemble_retriever = EnsembleRetriever(
81
  retrievers=[dense_retriever, sparse_retriever], weights=[1, 1]
 
105
  return retriever
106
 
107
 
108
+ def BuildRetrieverDense(top_k=6):
109
  """
110
  Build dense retriever instance with ChromaDB vectorstore
111
 
112
  Args:
 
113
  top_k: Number of documents to retrieve
 
114
  """
115
 
 
 
 
 
116
  # Define embedding model
117
+ embedding_function = OpenAIEmbeddings(model="text-embedding-3-small")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  # Create vector store
119
  client_settings = chromadb.config.Settings(anonymized_telemetry=False)
120
+ persist_directory = f"{db_dir}/chroma"
121
  vectorstore = Chroma(
122
  collection_name="R-help",
123
  embedding_function=embedding_function,
 
125
  persist_directory=persist_directory,
126
  )
127
  # The storage layer for the parent documents
128
+ file_store = f"{db_dir}/file_store"
129
  byte_store = LocalFileStore(file_store)
130
  # Text splitter for child documents
131
  child_splitter = RecursiveCharacterTextSplitter(
util.py CHANGED
@@ -5,21 +5,6 @@ import os
5
  import re
6
 
7
 
8
- def get_collection(compute_mode):
9
- """
10
- Returns the vectorstore collection.
11
-
12
- Usage Examples:
13
- # Number of child documents
14
- collection = get_collection("remote")
15
- len(collection["ids"])
16
- # Number of parent documents (unique doc_ids)
17
- len(set([m["doc_id"] for m in collection["metadatas"]]))
18
- """
19
- retriever = BuildRetriever(compute_mode, "dense")
20
- return retriever.vectorstore.get()
21
-
22
-
23
  def get_sources():
24
  """
25
  Return the source files indexed in the database, e.g. 'R-help/2024-April.txt'.
 
5
  import re
6
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def get_sources():
9
  """
10
  Return the source files indexed in the database, e.g. 'R-help/2024-April.txt'.