Spaces:
Running
Running
jedick committed on
Commit ·
429393a
1
Parent(s): 8627eb1
Remove local compute mode
Browse files- app.py +67 -179
- eval.py +5 -14
- graph.py +12 -73
- images/graph_LR.png +0 -0
- index.py +4 -5
- main.py +13 -96
- mods/tool_calling_llm.py +0 -313
- pipeline.py +0 -86
- prompts.py +7 -66
- requirements.txt +13 -38
- retriever.py +11 -58
- util.py +0 -15
app.py
CHANGED
|
@@ -1,42 +1,25 @@
|
|
| 1 |
from langgraph.checkpoint.memory import MemorySaver
|
| 2 |
-
from
|
| 3 |
from dotenv import load_dotenv
|
| 4 |
from datetime import datetime
|
| 5 |
import gradio as gr
|
| 6 |
-
import spaces
|
| 7 |
-
import torch
|
| 8 |
import uuid
|
| 9 |
import ast
|
| 10 |
import os
|
| 11 |
import re
|
| 12 |
|
| 13 |
# Local modules
|
| 14 |
-
from main import GetChatModel, openai_model, model_id
|
| 15 |
from util import get_sources, get_start_end_months
|
| 16 |
-
from retriever import db_dir, embedding_model_id
|
| 17 |
-
from mods.tool_calling_llm import extract_think
|
| 18 |
from data import download_data, extract_data
|
|
|
|
| 19 |
from graph import BuildGraph
|
|
|
|
| 20 |
|
| 21 |
# Set environment variables
|
| 22 |
load_dotenv(dotenv_path=".env", override=True)
|
| 23 |
# Hide BM25S progress bars
|
| 24 |
os.environ["DISABLE_TQDM"] = "true"
|
| 25 |
|
| 26 |
-
# Download model snapshots from Hugging Face Hub
|
| 27 |
-
if torch.cuda.is_available():
|
| 28 |
-
print(f"Downloading checkpoints for {model_id}...")
|
| 29 |
-
ckpt_dir = snapshot_download(model_id, local_dir_use_symlinks=False)
|
| 30 |
-
print(f"Using checkpoints from {ckpt_dir}")
|
| 31 |
-
print(f"Downloading checkpoints for {embedding_model_id}...")
|
| 32 |
-
embedding_ckpt_dir = snapshot_download(
|
| 33 |
-
embedding_model_id, local_dir_use_symlinks=False
|
| 34 |
-
)
|
| 35 |
-
print(f"Using embedding checkpoints from {embedding_ckpt_dir}")
|
| 36 |
-
else:
|
| 37 |
-
ckpt_dir = None
|
| 38 |
-
embedding_ckpt_dir = None
|
| 39 |
-
|
| 40 |
# Download and extract data if data directory is not present
|
| 41 |
if not os.path.isdir(db_dir):
|
| 42 |
print("Downloading data ... ", end="")
|
|
@@ -51,17 +34,35 @@ search_type = "hybrid"
|
|
| 51 |
|
| 52 |
# Global variables for LangChain graph: use dictionaries to store user-specific instances
|
| 53 |
# https://www.gradio.app/guides/state-in-blocks
|
| 54 |
-
graph_instances = {
|
| 55 |
|
| 56 |
|
| 57 |
def cleanup_graph(request: gr.Request):
|
| 58 |
timestamp = datetime.now().replace(microsecond=0).isoformat()
|
| 59 |
-
if request.session_hash in graph_instances
|
| 60 |
-
del graph_instances[
|
| 61 |
-
print(f"{timestamp} - Delete
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
|
| 67 |
def append_content(chunk_messages, history, thinking_about):
|
|
@@ -85,48 +86,32 @@ def append_content(chunk_messages, history, thinking_about):
|
|
| 85 |
return history
|
| 86 |
|
| 87 |
|
| 88 |
-
def run_workflow(input, history, compute_mode, thread_id, session_hash):
|
| 89 |
"""The main function to run the chat workflow"""
|
| 90 |
|
| 91 |
-
# Error if user tries to run local mode without GPU
|
| 92 |
-
if compute_mode == "local":
|
| 93 |
-
if not torch.cuda.is_available():
|
| 94 |
-
raise gr.Error(
|
| 95 |
-
"Local mode requires GPU.",
|
| 96 |
-
print_exception=False,
|
| 97 |
-
)
|
| 98 |
-
|
| 99 |
# Get graph instance
|
| 100 |
-
graph = graph_instances
|
| 101 |
|
| 102 |
if graph is None:
|
| 103 |
-
# Notify when we're loading the local model because it takes some time
|
| 104 |
-
if compute_mode == "local":
|
| 105 |
-
gr.Info(
|
| 106 |
-
f"Please wait for the local model to load",
|
| 107 |
-
title=f"Model loading...",
|
| 108 |
-
)
|
| 109 |
# Get the chat model and build the graph
|
| 110 |
-
chat_model =
|
| 111 |
graph_builder = BuildGraph(
|
| 112 |
chat_model,
|
| 113 |
-
compute_mode,
|
| 114 |
search_type,
|
| 115 |
-
embedding_ckpt_dir=embedding_ckpt_dir,
|
| 116 |
)
|
| 117 |
# Compile the graph with an in-memory checkpointer
|
| 118 |
memory = MemorySaver()
|
| 119 |
graph = graph_builder.compile(checkpointer=memory)
|
| 120 |
-
# Set global graph
|
| 121 |
-
graph_instances[
|
| 122 |
# ISO 8601 timestamp with local timezone information without microsecond
|
| 123 |
timestamp = datetime.now().replace(microsecond=0).isoformat()
|
| 124 |
-
print(f"{timestamp} - Set
|
| 125 |
-
# Notify when model finishes loading
|
| 126 |
-
gr.Success(
|
| 127 |
else:
|
| 128 |
timestamp = datetime.now().replace(microsecond=0).isoformat()
|
| 129 |
-
print(f"{timestamp} - Get
|
| 130 |
|
| 131 |
# print(f"Using thread_id: {thread_id}")
|
| 132 |
|
|
@@ -235,28 +220,11 @@ def run_workflow(input, history, compute_mode, thread_id, session_hash):
|
|
| 235 |
|
| 236 |
|
| 237 |
def to_workflow(request: gr.Request, *args):
|
| 238 |
-
"""Wrapper function to call
|
| 239 |
input = args[0]
|
| 240 |
-
compute_mode = args[2]
|
| 241 |
# Add session_hash to arguments
|
| 242 |
new_args = args + (request.session_hash,)
|
| 243 |
-
|
| 244 |
-
# Call the workflow function with the @spaces.GPU decorator
|
| 245 |
-
for value in run_workflow_local(*new_args):
|
| 246 |
-
yield value
|
| 247 |
-
if compute_mode == "remote":
|
| 248 |
-
for value in run_workflow_remote(*new_args):
|
| 249 |
-
yield value
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
@spaces.GPU(duration=100)
|
| 253 |
-
def run_workflow_local(*args):
|
| 254 |
-
for value in run_workflow(*args):
|
| 255 |
-
yield value
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
def run_workflow_remote(*args):
|
| 259 |
-
for value in run_workflow(*args):
|
| 260 |
yield value
|
| 261 |
|
| 262 |
|
|
@@ -290,19 +258,6 @@ with gr.Blocks(
|
|
| 290 |
# Define components
|
| 291 |
# -----------------
|
| 292 |
|
| 293 |
-
compute_mode = gr.Radio(
|
| 294 |
-
choices=[
|
| 295 |
-
"local",
|
| 296 |
-
"remote",
|
| 297 |
-
],
|
| 298 |
-
# Default to remote because it provides a better first impression for most people
|
| 299 |
-
# value=("local" if torch.cuda.is_available() else "remote"),
|
| 300 |
-
value="remote",
|
| 301 |
-
label="Compute Mode",
|
| 302 |
-
info="NOTE: remote mode **does not** use ZeroGPU",
|
| 303 |
-
render=False,
|
| 304 |
-
)
|
| 305 |
-
|
| 306 |
loading_data = gr.Textbox(
|
| 307 |
"Please wait for the email database to be downloaded and extracted.",
|
| 308 |
max_lines=0,
|
|
@@ -332,14 +287,7 @@ with gr.Blocks(
|
|
| 332 |
chatbot = gr.Chatbot(
|
| 333 |
type="messages",
|
| 334 |
show_label=False,
|
| 335 |
-
avatar_images=(
|
| 336 |
-
None,
|
| 337 |
-
(
|
| 338 |
-
"images/cloud.png"
|
| 339 |
-
if compute_mode.value == "remote"
|
| 340 |
-
else "images/chip.png"
|
| 341 |
-
),
|
| 342 |
-
),
|
| 343 |
show_copy_all_button=True,
|
| 344 |
render=False,
|
| 345 |
)
|
|
@@ -398,24 +346,17 @@ with gr.Blocks(
|
|
| 398 |
and generates an answer from the retrieved emails (*emails are shown below the chatbot*).
|
| 399 |
You can ask follow-up questions with the chat history as context.
|
| 400 |
Press the clear button (🗑) to clear the history and start a new chat.
|
|
|
|
| 401 |
"""
|
| 402 |
return intro
|
| 403 |
|
| 404 |
-
def get_status_text(
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
"""
|
| 412 |
-
if compute_mode == "local":
|
| 413 |
-
status_text = f"""
|
| 414 |
-
📍 Now in **local** mode, using ZeroGPU hardware<br>
|
| 415 |
-
⌛ Response time is about one minute<br>
|
| 416 |
-
✨ [{embedding_model_id.split("/")[-1]}](https://huggingface.co/{embedding_model_id}) and [{model_id.split("/")[-1]}](https://huggingface.co/{model_id})<br>
|
| 417 |
-
🏠 See the project's [GitHub repository](https://github.com/jedick/R-help-chat)
|
| 418 |
-
"""
|
| 419 |
return status_text
|
| 420 |
|
| 421 |
def get_info_text():
|
|
@@ -430,13 +371,13 @@ with gr.Blocks(
|
|
| 430 |
end = None
|
| 431 |
info_text = f"""
|
| 432 |
**Database:** {len(sources)} emails from {start} to {end}.
|
| 433 |
-
**Features:** RAG, today's date, hybrid search (dense+sparse), multiple retrievals, citations output
|
| 434 |
-
**Tech:**
|
| 435 |
"""
|
| 436 |
return info_text
|
| 437 |
|
| 438 |
-
def get_example_questions(
|
| 439 |
-
"""Get example questions
|
| 440 |
questions = [
|
| 441 |
# "What is today's date?",
|
| 442 |
"Summarize emails from the most recent two months",
|
|
@@ -445,15 +386,11 @@ with gr.Blocks(
|
|
| 445 |
"Who reported installation problems in 2023-2024?",
|
| 446 |
]
|
| 447 |
|
| 448 |
-
## Remove "/think" from questions in remote mode
|
| 449 |
-
# if compute_mode == "remote":
|
| 450 |
-
# questions = [q.replace(" /think", "") for q in questions]
|
| 451 |
-
|
| 452 |
# cf. https://github.com/gradio-app/gradio/pull/8745 for updating examples
|
| 453 |
return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
|
| 454 |
|
| 455 |
-
def get_multi_tool_questions(
|
| 456 |
-
"""Get multi-tool example questions
|
| 457 |
questions = [
|
| 458 |
"Differences between lapply and for loops",
|
| 459 |
"Discuss pipe operator usage in 2022, 2023, and 2024",
|
|
@@ -461,8 +398,8 @@ with gr.Blocks(
|
|
| 461 |
|
| 462 |
return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
|
| 463 |
|
| 464 |
-
def get_multi_turn_questions(
|
| 465 |
-
"""Get multi-turn example questions
|
| 466 |
questions = [
|
| 467 |
"Lookup emails that reference bugs.r-project.org in 2025",
|
| 468 |
"Did the authors you cited report bugs before 2025?",
|
|
@@ -474,10 +411,14 @@ with gr.Blocks(
|
|
| 474 |
# Left column: Intro, Compute, Chat
|
| 475 |
with gr.Column(scale=2):
|
| 476 |
with gr.Row(elem_classes=["row-container"]):
|
| 477 |
-
with gr.Column(scale=
|
| 478 |
intro = gr.Markdown(get_intro_text())
|
| 479 |
with gr.Column(scale=1):
|
| 480 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
with gr.Group() as chat_interface:
|
| 482 |
chatbot.render()
|
| 483 |
input.render()
|
|
@@ -488,29 +429,23 @@ with gr.Blocks(
|
|
| 488 |
missing_data.render()
|
| 489 |
# Right column: Info, Examples
|
| 490 |
with gr.Column(scale=1):
|
| 491 |
-
status = gr.Markdown(get_status_text(
|
| 492 |
with gr.Accordion("ℹ️ More Info", open=False):
|
| 493 |
info = gr.Markdown(get_info_text())
|
| 494 |
with gr.Accordion("💡 Examples", open=True):
|
| 495 |
# Add some helpful examples
|
| 496 |
example_questions = gr.Examples(
|
| 497 |
-
examples=get_example_questions(
|
| 498 |
-
compute_mode.value, as_dataset=False
|
| 499 |
-
),
|
| 500 |
inputs=[input],
|
| 501 |
label="Click an example to fill the message box",
|
| 502 |
)
|
| 503 |
multi_tool_questions = gr.Examples(
|
| 504 |
-
examples=get_multi_tool_questions(
|
| 505 |
-
compute_mode.value, as_dataset=False
|
| 506 |
-
),
|
| 507 |
inputs=[input],
|
| 508 |
label="Multiple retrievals",
|
| 509 |
)
|
| 510 |
multi_turn_questions = gr.Examples(
|
| 511 |
-
examples=get_multi_turn_questions(
|
| 512 |
-
compute_mode.value, as_dataset=False
|
| 513 |
-
),
|
| 514 |
inputs=[input],
|
| 515 |
label="Asking follow-up questions",
|
| 516 |
)
|
|
@@ -530,18 +465,6 @@ with gr.Blocks(
|
|
| 530 |
"""Return updated value for a component"""
|
| 531 |
return gr.update(value=value)
|
| 532 |
|
| 533 |
-
def set_avatar(compute_mode):
|
| 534 |
-
if compute_mode == "remote":
|
| 535 |
-
image_file = "images/cloud.png"
|
| 536 |
-
if compute_mode == "local":
|
| 537 |
-
image_file = "images/chip.png"
|
| 538 |
-
return gr.update(
|
| 539 |
-
avatar_images=(
|
| 540 |
-
None,
|
| 541 |
-
image_file,
|
| 542 |
-
),
|
| 543 |
-
)
|
| 544 |
-
|
| 545 |
def change_visibility(visible):
|
| 546 |
"""Return updated visibility state for a component"""
|
| 547 |
return gr.update(visible=visible)
|
|
@@ -565,45 +488,10 @@ with gr.Blocks(
|
|
| 565 |
# https://github.com/gradio-app/gradio/issues/9722
|
| 566 |
chatbot.clear(generate_thread_id, outputs=[thread_id], api_name=False)
|
| 567 |
|
| 568 |
-
def clear_component(component):
|
| 569 |
-
"""Return cleared component"""
|
| 570 |
-
return component.clear()
|
| 571 |
-
|
| 572 |
-
compute_mode.change(
|
| 573 |
-
# Start a new thread
|
| 574 |
-
generate_thread_id,
|
| 575 |
-
outputs=[thread_id],
|
| 576 |
-
api_name=False,
|
| 577 |
-
).then(
|
| 578 |
-
# Focus textbox by updating the textbox with the current value
|
| 579 |
-
lambda x: gr.update(value=x),
|
| 580 |
-
[input],
|
| 581 |
-
[input],
|
| 582 |
-
api_name=False,
|
| 583 |
-
).then(
|
| 584 |
-
# Change the app status text
|
| 585 |
-
get_status_text,
|
| 586 |
-
[compute_mode],
|
| 587 |
-
[status],
|
| 588 |
-
api_name=False,
|
| 589 |
-
).then(
|
| 590 |
-
# Clear the chatbot history
|
| 591 |
-
clear_component,
|
| 592 |
-
[chatbot],
|
| 593 |
-
[chatbot],
|
| 594 |
-
api_name=False,
|
| 595 |
-
).then(
|
| 596 |
-
# Change the chatbot avatar
|
| 597 |
-
set_avatar,
|
| 598 |
-
[compute_mode],
|
| 599 |
-
[chatbot],
|
| 600 |
-
api_name=False,
|
| 601 |
-
)
|
| 602 |
-
|
| 603 |
input.submit(
|
| 604 |
# Submit input to the chatbot
|
| 605 |
to_workflow,
|
| 606 |
-
[input, chatbot,
|
| 607 |
[chatbot, retrieved_emails, citations_text],
|
| 608 |
api_name=False,
|
| 609 |
)
|
|
|
|
| 1 |
from langgraph.checkpoint.memory import MemorySaver
|
| 2 |
+
from langchain_openai import ChatOpenAI
|
| 3 |
from dotenv import load_dotenv
|
| 4 |
from datetime import datetime
|
| 5 |
import gradio as gr
|
|
|
|
|
|
|
| 6 |
import uuid
|
| 7 |
import ast
|
| 8 |
import os
|
| 9 |
import re
|
| 10 |
|
| 11 |
# Local modules
|
|
|
|
| 12 |
from util import get_sources, get_start_end_months
|
|
|
|
|
|
|
| 13 |
from data import download_data, extract_data
|
| 14 |
+
from main import openai_model
|
| 15 |
from graph import BuildGraph
|
| 16 |
+
from retriever import db_dir
|
| 17 |
|
| 18 |
# Set environment variables
|
| 19 |
load_dotenv(dotenv_path=".env", override=True)
|
| 20 |
# Hide BM25S progress bars
|
| 21 |
os.environ["DISABLE_TQDM"] = "true"
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
# Download and extract data if data directory is not present
|
| 24 |
if not os.path.isdir(db_dir):
|
| 25 |
print("Downloading data ... ", end="")
|
|
|
|
| 34 |
|
| 35 |
# Global variables for LangChain graph: use dictionaries to store user-specific instances
|
| 36 |
# https://www.gradio.app/guides/state-in-blocks
|
| 37 |
+
graph_instances = {}
|
| 38 |
|
| 39 |
|
| 40 |
def cleanup_graph(request: gr.Request):
|
| 41 |
timestamp = datetime.now().replace(microsecond=0).isoformat()
|
| 42 |
+
if request.session_hash in graph_instances:
|
| 43 |
+
del graph_instances[request.session_hash]
|
| 44 |
+
print(f"{timestamp} - Delete graph for session {request.session_hash}")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def extract_think(content):
|
| 48 |
+
# Added by Cursor 20250726 jmd
|
| 49 |
+
# Extract content within <think>...</think>
|
| 50 |
+
think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
|
| 51 |
+
think_text = think_match.group(1).strip() if think_match else ""
|
| 52 |
+
# Extract text after </think>
|
| 53 |
+
if think_match:
|
| 54 |
+
post_think = content[think_match.end() :].lstrip()
|
| 55 |
+
else:
|
| 56 |
+
# Check if content starts with <think> but missing closing tag
|
| 57 |
+
if content.strip().startswith("<think>"):
|
| 58 |
+
# Extract everything after <think>
|
| 59 |
+
think_start = content.find("<think>") + len("<think>")
|
| 60 |
+
think_text = content[think_start:].strip()
|
| 61 |
+
post_think = ""
|
| 62 |
+
else:
|
| 63 |
+
# No <think> found, so return entire content as post_think
|
| 64 |
+
post_think = content
|
| 65 |
+
return think_text, post_think
|
| 66 |
|
| 67 |
|
| 68 |
def append_content(chunk_messages, history, thinking_about):
|
|
|
|
| 86 |
return history
|
| 87 |
|
| 88 |
|
| 89 |
+
def run_workflow(input, history, thread_id, session_hash):
|
| 90 |
"""The main function to run the chat workflow"""
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
# Get graph instance
|
| 93 |
+
graph = graph_instances.get(session_hash)
|
| 94 |
|
| 95 |
if graph is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
# Get the chat model and build the graph
|
| 97 |
+
chat_model = ChatOpenAI(model=openai_model, temperature=0)
|
| 98 |
graph_builder = BuildGraph(
|
| 99 |
chat_model,
|
|
|
|
| 100 |
search_type,
|
|
|
|
| 101 |
)
|
| 102 |
# Compile the graph with an in-memory checkpointer
|
| 103 |
memory = MemorySaver()
|
| 104 |
graph = graph_builder.compile(checkpointer=memory)
|
| 105 |
+
# Set global graph
|
| 106 |
+
graph_instances[session_hash] = graph
|
| 107 |
# ISO 8601 timestamp with local timezone information without microsecond
|
| 108 |
timestamp = datetime.now().replace(microsecond=0).isoformat()
|
| 109 |
+
print(f"{timestamp} - Set graph for session {session_hash}")
|
| 110 |
+
## Notify when model finishes loading
|
| 111 |
+
# gr.Success("Model loaded!", duration=4)
|
| 112 |
else:
|
| 113 |
timestamp = datetime.now().replace(microsecond=0).isoformat()
|
| 114 |
+
print(f"{timestamp} - Get graph for session {session_hash}")
|
| 115 |
|
| 116 |
# print(f"Using thread_id: {thread_id}")
|
| 117 |
|
|
|
|
| 220 |
|
| 221 |
|
| 222 |
def to_workflow(request: gr.Request, *args):
|
| 223 |
+
"""Wrapper function to call run_workflow() with session_hash"""
|
| 224 |
input = args[0]
|
|
|
|
| 225 |
# Add session_hash to arguments
|
| 226 |
new_args = args + (request.session_hash,)
|
| 227 |
+
for value in run_workflow(*new_args):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
yield value
|
| 229 |
|
| 230 |
|
|
|
|
| 258 |
# Define components
|
| 259 |
# -----------------
|
| 260 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
loading_data = gr.Textbox(
|
| 262 |
"Please wait for the email database to be downloaded and extracted.",
|
| 263 |
max_lines=0,
|
|
|
|
| 287 |
chatbot = gr.Chatbot(
|
| 288 |
type="messages",
|
| 289 |
show_label=False,
|
| 290 |
+
avatar_images=(None, "images/cloud.png"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
show_copy_all_button=True,
|
| 292 |
render=False,
|
| 293 |
)
|
|
|
|
| 346 |
and generates an answer from the retrieved emails (*emails are shown below the chatbot*).
|
| 347 |
You can ask follow-up questions with the chat history as context.
|
| 348 |
Press the clear button (🗑) to clear the history and start a new chat.
|
| 349 |
+
🚧 Under construction: Select a mailing list to search, or use Auto to let the LLM choose.
|
| 350 |
"""
|
| 351 |
return intro
|
| 352 |
|
| 353 |
+
def get_status_text():
|
| 354 |
+
status_text = f"""
|
| 355 |
+
🌐 This app uses the OpenAI API<br>
|
| 356 |
+
⚠️ **_Privacy Notice_**: Data sharing with OpenAI is enabled<br>
|
| 357 |
+
✨ text-embedding-3-small and {openai_model}<br>
|
| 358 |
+
🏠 More info: [R-help-chat GitHub repository](https://github.com/jedick/R-help-chat)
|
| 359 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
return status_text
|
| 361 |
|
| 362 |
def get_info_text():
|
|
|
|
| 371 |
end = None
|
| 372 |
info_text = f"""
|
| 373 |
**Database:** {len(sources)} emails from {start} to {end}.
|
| 374 |
+
**Features:** RAG, today's date, hybrid search (dense+sparse), multiple retrievals, citations output, chat memory.
|
| 375 |
+
**Tech:** OpenAI API + LangGraph + Gradio; ChromaDB and BM25S-based retrievers.<br>
|
| 376 |
"""
|
| 377 |
return info_text
|
| 378 |
|
| 379 |
+
def get_example_questions(as_dataset=True):
|
| 380 |
+
"""Get example questions"""
|
| 381 |
questions = [
|
| 382 |
# "What is today's date?",
|
| 383 |
"Summarize emails from the most recent two months",
|
|
|
|
| 386 |
"Who reported installation problems in 2023-2024?",
|
| 387 |
]
|
| 388 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
# cf. https://github.com/gradio-app/gradio/pull/8745 for updating examples
|
| 390 |
return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
|
| 391 |
|
| 392 |
+
def get_multi_tool_questions(as_dataset=True):
|
| 393 |
+
"""Get multi-tool example questions"""
|
| 394 |
questions = [
|
| 395 |
"Differences between lapply and for loops",
|
| 396 |
"Discuss pipe operator usage in 2022, 2023, and 2024",
|
|
|
|
| 398 |
|
| 399 |
return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
|
| 400 |
|
| 401 |
+
def get_multi_turn_questions(as_dataset=True):
|
| 402 |
+
"""Get multi-turn example questions"""
|
| 403 |
questions = [
|
| 404 |
"Lookup emails that reference bugs.r-project.org in 2025",
|
| 405 |
"Did the authors you cited report bugs before 2025?",
|
|
|
|
| 411 |
# Left column: Intro, Compute, Chat
|
| 412 |
with gr.Column(scale=2):
|
| 413 |
with gr.Row(elem_classes=["row-container"]):
|
| 414 |
+
with gr.Column(scale=4):
|
| 415 |
intro = gr.Markdown(get_intro_text())
|
| 416 |
with gr.Column(scale=1):
|
| 417 |
+
gr.Radio(
|
| 418 |
+
["Auto", "R-help", "R-devel", "R-pkg-devel"],
|
| 419 |
+
label="Mailing List",
|
| 420 |
+
interactive=False,
|
| 421 |
+
)
|
| 422 |
with gr.Group() as chat_interface:
|
| 423 |
chatbot.render()
|
| 424 |
input.render()
|
|
|
|
| 429 |
missing_data.render()
|
| 430 |
# Right column: Info, Examples
|
| 431 |
with gr.Column(scale=1):
|
| 432 |
+
status = gr.Markdown(get_status_text())
|
| 433 |
with gr.Accordion("ℹ️ More Info", open=False):
|
| 434 |
info = gr.Markdown(get_info_text())
|
| 435 |
with gr.Accordion("💡 Examples", open=True):
|
| 436 |
# Add some helpful examples
|
| 437 |
example_questions = gr.Examples(
|
| 438 |
+
examples=get_example_questions(as_dataset=False),
|
|
|
|
|
|
|
| 439 |
inputs=[input],
|
| 440 |
label="Click an example to fill the message box",
|
| 441 |
)
|
| 442 |
multi_tool_questions = gr.Examples(
|
| 443 |
+
examples=get_multi_tool_questions(as_dataset=False),
|
|
|
|
|
|
|
| 444 |
inputs=[input],
|
| 445 |
label="Multiple retrievals",
|
| 446 |
)
|
| 447 |
multi_turn_questions = gr.Examples(
|
| 448 |
+
examples=get_multi_turn_questions(as_dataset=False),
|
|
|
|
|
|
|
| 449 |
inputs=[input],
|
| 450 |
label="Asking follow-up questions",
|
| 451 |
)
|
|
|
|
| 465 |
"""Return updated value for a component"""
|
| 466 |
return gr.update(value=value)
|
| 467 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
def change_visibility(visible):
|
| 469 |
"""Return updated visibility state for a component"""
|
| 470 |
return gr.update(visible=visible)
|
|
|
|
| 488 |
# https://github.com/gradio-app/gradio/issues/9722
|
| 489 |
chatbot.clear(generate_thread_id, outputs=[thread_id], api_name=False)
|
| 490 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 491 |
input.submit(
|
| 492 |
# Submit input to the chatbot
|
| 493 |
to_workflow,
|
| 494 |
+
[input, chatbot, thread_id],
|
| 495 |
[chatbot, retrieved_emails, citations_text],
|
| 496 |
api_name=False,
|
| 497 |
)
|
eval.py
CHANGED
|
@@ -34,7 +34,7 @@ def load_questions_and_references(csv_path):
|
|
| 34 |
return questions, references
|
| 35 |
|
| 36 |
|
| 37 |
-
def build_eval_dataset(questions, references, compute_mode, workflow, search_type):
|
| 38 |
"""Build dataset for evaluation"""
|
| 39 |
dataset = []
|
| 40 |
for question, reference in zip(questions, references):
|
|
@@ -42,15 +42,15 @@ def build_eval_dataset(questions, references, compute_mode, workflow, search_typ
|
|
| 42 |
if workflow == "chain":
|
| 43 |
print("\n\n--- Question ---")
|
| 44 |
print(question)
|
| 45 |
-
response = RunChain(question,
|
| 46 |
print("--- Response ---")
|
| 47 |
print(response)
|
| 48 |
# Retrieve context documents for a question
|
| 49 |
-
retriever = BuildRetriever(
|
| 50 |
docs = retriever.invoke(question)
|
| 51 |
retrieved_contexts = [doc.page_content for doc in docs]
|
| 52 |
if workflow == "graph":
|
| 53 |
-
result = RunGraph(question,
|
| 54 |
retrieved_contexts = []
|
| 55 |
if "retrieved_emails" in result:
|
| 56 |
# Remove the source file names (e.g. R-help/2022-September.txt) as it confuses the evaluator
|
|
@@ -142,12 +142,6 @@ def main():
|
|
| 142 |
parser = argparse.ArgumentParser(
|
| 143 |
description="Evaluate RAG retrieval and generation."
|
| 144 |
)
|
| 145 |
-
parser.add_argument(
|
| 146 |
-
"--compute_mode",
|
| 147 |
-
choices=["remote", "local"],
|
| 148 |
-
required=True,
|
| 149 |
-
help="Compute mode: remote or local.",
|
| 150 |
-
)
|
| 151 |
parser.add_argument(
|
| 152 |
"--workflow",
|
| 153 |
choices=["chain", "graph"],
|
|
@@ -161,14 +155,11 @@ def main():
|
|
| 161 |
help="Search type: dense, sparse, or hybrid.",
|
| 162 |
)
|
| 163 |
args = parser.parse_args()
|
| 164 |
-
compute_mode = args.compute_mode
|
| 165 |
workflow = args.workflow
|
| 166 |
search_type = args.search_type
|
| 167 |
|
| 168 |
questions, references = load_questions_and_references("eval.csv")
|
| 169 |
-
dataset = build_eval_dataset(
|
| 170 |
-
questions, references, compute_mode, workflow, search_type
|
| 171 |
-
)
|
| 172 |
evaluation_dataset = EvaluationDataset.from_list(dataset)
|
| 173 |
|
| 174 |
# Set up LLM for evaluation
|
|
|
|
| 34 |
return questions, references
|
| 35 |
|
| 36 |
|
| 37 |
+
def build_eval_dataset(questions, references, workflow, search_type):
|
| 38 |
"""Build dataset for evaluation"""
|
| 39 |
dataset = []
|
| 40 |
for question, reference in zip(questions, references):
|
|
|
|
| 42 |
if workflow == "chain":
|
| 43 |
print("\n\n--- Question ---")
|
| 44 |
print(question)
|
| 45 |
+
response = RunChain(question, search_type)
|
| 46 |
print("--- Response ---")
|
| 47 |
print(response)
|
| 48 |
# Retrieve context documents for a question
|
| 49 |
+
retriever = BuildRetriever(search_type)
|
| 50 |
docs = retriever.invoke(question)
|
| 51 |
retrieved_contexts = [doc.page_content for doc in docs]
|
| 52 |
if workflow == "graph":
|
| 53 |
+
result = RunGraph(question, search_type)
|
| 54 |
retrieved_contexts = []
|
| 55 |
if "retrieved_emails" in result:
|
| 56 |
# Remove the source file names (e.g. R-help/2022-September.txt) as it confuses the evaluator
|
|
|
|
| 142 |
parser = argparse.ArgumentParser(
|
| 143 |
description="Evaluate RAG retrieval and generation."
|
| 144 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
parser.add_argument(
|
| 146 |
"--workflow",
|
| 147 |
choices=["chain", "graph"],
|
|
|
|
| 155 |
help="Search type: dense, sparse, or hybrid.",
|
| 156 |
)
|
| 157 |
args = parser.parse_args()
|
|
|
|
| 158 |
workflow = args.workflow
|
| 159 |
search_type = args.search_type
|
| 160 |
|
| 161 |
questions, references = load_questions_and_references("eval.csv")
|
| 162 |
+
dataset = build_eval_dataset(questions, references, workflow, search_type)
|
|
|
|
|
|
|
| 163 |
evaluation_dataset = EvaluationDataset.from_list(dataset)
|
| 164 |
|
| 165 |
# Set up LLM for evaluation
|
graph.py
CHANGED
|
@@ -2,15 +2,13 @@ from langchain_core.messages import SystemMessage, ToolMessage, HumanMessage, AI
|
|
| 2 |
from langgraph.graph import START, END, MessagesState, StateGraph
|
| 3 |
from langchain_core.tools import tool
|
| 4 |
from langgraph.prebuilt import ToolNode, tools_condition
|
| 5 |
-
from langchain_huggingface import ChatHuggingFace
|
| 6 |
from typing import Optional
|
| 7 |
import datetime
|
| 8 |
import os
|
| 9 |
|
| 10 |
# Local modules
|
| 11 |
from retriever import BuildRetriever
|
| 12 |
-
from prompts import query_prompt, answer_prompt
|
| 13 |
-
from mods.tool_calling_llm import ToolCallingLLM
|
| 14 |
|
| 15 |
# For tracing (disabled)
|
| 16 |
# os.environ["LANGSMITH_TRACING"] = "true"
|
|
@@ -105,48 +103,18 @@ def normalize_messages(messages, summaries_for=None):
|
|
| 105 |
return messages
|
| 106 |
|
| 107 |
|
| 108 |
-
def ToolifyHF(chat_model, system_message):
|
| 109 |
-
"""
|
| 110 |
-
Get a Hugging Face model ready for bind_tools().
|
| 111 |
-
"""
|
| 112 |
-
|
| 113 |
-
# Combine system prompt and tools template
|
| 114 |
-
tool_system_prompt_template = system_message + generic_tools_template
|
| 115 |
-
|
| 116 |
-
class HuggingFaceWithTools(ToolCallingLLM, ChatHuggingFace):
|
| 117 |
-
def __init__(self, **kwargs):
|
| 118 |
-
super().__init__(**kwargs)
|
| 119 |
-
|
| 120 |
-
chat_model = HuggingFaceWithTools(
|
| 121 |
-
llm=chat_model.llm,
|
| 122 |
-
tool_system_prompt_template=tool_system_prompt_template,
|
| 123 |
-
)
|
| 124 |
-
|
| 125 |
-
return chat_model
|
| 126 |
-
|
| 127 |
-
|
| 128 |
def BuildGraph(
|
| 129 |
chat_model,
|
| 130 |
-
compute_mode,
|
| 131 |
search_type,
|
| 132 |
top_k=6,
|
| 133 |
-
think_query=False,
|
| 134 |
-
think_answer=False,
|
| 135 |
-
local_citations=False,
|
| 136 |
-
embedding_ckpt_dir=None,
|
| 137 |
):
|
| 138 |
"""
|
| 139 |
Build conversational RAG graph for email retrieval and answering with citations.
|
| 140 |
|
| 141 |
Args:
|
| 142 |
-
chat_model: LangChain chat model
|
| 143 |
-
compute_mode: remote or local (for retriever)
|
| 144 |
search_type: dense, sparse, or hybrid (for retriever)
|
| 145 |
top_k: number of documents to retrieve
|
| 146 |
-
think_query: Whether to use thinking mode for the query (local model)
|
| 147 |
-
think_answer: Whether to use thinking mode for the answer (local model)
|
| 148 |
-
local_citations: Whether to use answer_with_citations() tool (local model)
|
| 149 |
-
embedding_ckpt_dir: Directory for embedding model checkpoint
|
| 150 |
|
| 151 |
Based on:
|
| 152 |
https://python.langchain.com/docs/how_to/qa_sources
|
|
@@ -158,7 +126,7 @@ def BuildGraph(
|
|
| 158 |
# Build graph with chat model
|
| 159 |
from langchain_openai import ChatOpenAI
|
| 160 |
chat_model = ChatOpenAI(model="gpt-4o-mini")
|
| 161 |
-
graph = BuildGraph(chat_model, "
|
| 162 |
|
| 163 |
# Add simple in-memory checkpointer
|
| 164 |
from langgraph.checkpoint.memory import MemorySaver
|
|
@@ -198,7 +166,10 @@ def BuildGraph(
|
|
| 198 |
months (str, optional): One or more months separated by spaces
|
| 199 |
"""
|
| 200 |
retriever = BuildRetriever(
|
| 201 |
-
|
|
|
|
|
|
|
|
|
|
| 202 |
)
|
| 203 |
# For now, just add the months to the search query
|
| 204 |
if months:
|
|
@@ -230,55 +201,23 @@ def BuildGraph(
|
|
| 230 |
"""
|
| 231 |
return answer, citations
|
| 232 |
|
| 233 |
-
# Add tools to the
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
# For local models (ChatHuggingFace with SmolLM, Gemma, or Qwen)
|
| 237 |
-
query_model = ToolifyHF(
|
| 238 |
-
chat_model, query_prompt(chat_model, think=think_query)
|
| 239 |
-
).bind_tools([retrieve_emails])
|
| 240 |
-
if local_citations:
|
| 241 |
-
answer_model = ToolifyHF(
|
| 242 |
-
chat_model,
|
| 243 |
-
answer_prompt(chat_model, think=think_answer, with_tools=True),
|
| 244 |
-
).bind_tools([answer_with_citations])
|
| 245 |
-
else:
|
| 246 |
-
# Don't use answer_with_citations tool because responses with are sometimes unparseable
|
| 247 |
-
answer_model = chat_model
|
| 248 |
-
else:
|
| 249 |
-
# For remote model (OpenAI API)
|
| 250 |
-
query_model = chat_model.bind_tools([retrieve_emails])
|
| 251 |
-
answer_model = chat_model.bind_tools([answer_with_citations])
|
| 252 |
|
| 253 |
# Initialize the graph object
|
| 254 |
graph = StateGraph(MessagesState)
|
| 255 |
|
| 256 |
def query(state: MessagesState):
|
| 257 |
"""Queries the retriever with the chat model"""
|
| 258 |
-
|
| 259 |
-
# Don't include the system message here because it's defined in ToolCallingLLM
|
| 260 |
-
messages = state["messages"]
|
| 261 |
-
messages = normalize_messages(messages)
|
| 262 |
-
else:
|
| 263 |
-
messages = [SystemMessage(query_prompt(chat_model))] + state["messages"]
|
| 264 |
response = query_model.invoke(messages)
|
| 265 |
|
| 266 |
return {"messages": response}
|
| 267 |
|
| 268 |
def answer(state: MessagesState):
|
| 269 |
"""Generates an answer with the chat model"""
|
| 270 |
-
|
| 271 |
-
messages = state["messages"]
|
| 272 |
-
messages = normalize_messages(messages)
|
| 273 |
-
if not local_citations:
|
| 274 |
-
# Add the system message here if we're not using tools
|
| 275 |
-
messages = [
|
| 276 |
-
SystemMessage(answer_prompt(chat_model, think=think_answer))
|
| 277 |
-
] + messages
|
| 278 |
-
else:
|
| 279 |
-
messages = [
|
| 280 |
-
SystemMessage(answer_prompt(chat_model, with_tools=True))
|
| 281 |
-
] + state["messages"]
|
| 282 |
response = answer_model.invoke(messages)
|
| 283 |
|
| 284 |
return {"messages": response}
|
|
|
|
| 2 |
from langgraph.graph import START, END, MessagesState, StateGraph
|
| 3 |
from langchain_core.tools import tool
|
| 4 |
from langgraph.prebuilt import ToolNode, tools_condition
|
|
|
|
| 5 |
from typing import Optional
|
| 6 |
import datetime
|
| 7 |
import os
|
| 8 |
|
| 9 |
# Local modules
|
| 10 |
from retriever import BuildRetriever
|
| 11 |
+
from prompts import query_prompt, answer_prompt
|
|
|
|
| 12 |
|
| 13 |
# For tracing (disabled)
|
| 14 |
# os.environ["LANGSMITH_TRACING"] = "true"
|
|
|
|
| 103 |
return messages
|
| 104 |
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
def BuildGraph(
|
| 107 |
chat_model,
|
|
|
|
| 108 |
search_type,
|
| 109 |
top_k=6,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
):
|
| 111 |
"""
|
| 112 |
Build conversational RAG graph for email retrieval and answering with citations.
|
| 113 |
|
| 114 |
Args:
|
| 115 |
+
chat_model: LangChain chat model
|
|
|
|
| 116 |
search_type: dense, sparse, or hybrid (for retriever)
|
| 117 |
top_k: number of documents to retrieve
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
Based on:
|
| 120 |
https://python.langchain.com/docs/how_to/qa_sources
|
|
|
|
| 126 |
# Build graph with chat model
|
| 127 |
from langchain_openai import ChatOpenAI
|
| 128 |
chat_model = ChatOpenAI(model="gpt-4o-mini")
|
| 129 |
+
graph = BuildGraph(chat_model, "hybrid")
|
| 130 |
|
| 131 |
# Add simple in-memory checkpointer
|
| 132 |
from langgraph.checkpoint.memory import MemorySaver
|
|
|
|
| 166 |
months (str, optional): One or more months separated by spaces
|
| 167 |
"""
|
| 168 |
retriever = BuildRetriever(
|
| 169 |
+
search_type,
|
| 170 |
+
top_k,
|
| 171 |
+
start_year,
|
| 172 |
+
end_year,
|
| 173 |
)
|
| 174 |
# For now, just add the months to the search query
|
| 175 |
if months:
|
|
|
|
| 201 |
"""
|
| 202 |
return answer, citations
|
| 203 |
|
| 204 |
+
# Add tools to the chat model
|
| 205 |
+
query_model = chat_model.bind_tools([retrieve_emails])
|
| 206 |
+
answer_model = chat_model.bind_tools([answer_with_citations])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
# Initialize the graph object
|
| 209 |
graph = StateGraph(MessagesState)
|
| 210 |
|
| 211 |
def query(state: MessagesState):
|
| 212 |
"""Queries the retriever with the chat model"""
|
| 213 |
+
messages = [SystemMessage(query_prompt())] + state["messages"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
response = query_model.invoke(messages)
|
| 215 |
|
| 216 |
return {"messages": response}
|
| 217 |
|
| 218 |
def answer(state: MessagesState):
|
| 219 |
"""Generates an answer with the chat model"""
|
| 220 |
+
messages = [SystemMessage(answer_prompt())] + state["messages"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
response = answer_model.invoke(messages)
|
| 222 |
|
| 223 |
return {"messages": response}
|
images/graph_LR.png
CHANGED
|
|
index.py
CHANGED
|
@@ -9,14 +9,13 @@ from retriever import BuildRetriever, db_dir
|
|
| 9 |
from mods.bm25s_retriever import BM25SRetriever
|
| 10 |
|
| 11 |
|
| 12 |
-
def ProcessFile(file_path, search_type: str = "dense"
|
| 13 |
"""
|
| 14 |
Wrapper function to process file for dense or sparse search
|
| 15 |
|
| 16 |
Args:
|
| 17 |
file_path: File to process
|
| 18 |
search_type: Type of search to use. Options: "dense", "sparse"
|
| 19 |
-
compute_mode: Compute mode for embeddings (remote or local)
|
| 20 |
"""
|
| 21 |
|
| 22 |
# Preprocess: remove quoted lines and handle email boundaries
|
|
@@ -69,7 +68,7 @@ def ProcessFile(file_path, search_type: str = "dense", compute_mode: str = "remo
|
|
| 69 |
ProcessFileSparse(truncated_temp_file, file_path)
|
| 70 |
elif search_type == "dense":
|
| 71 |
# Handle dense search with ChromaDB
|
| 72 |
-
ProcessFileDense(truncated_temp_file, file_path
|
| 73 |
else:
|
| 74 |
raise ValueError(f"Unsupported search type: {search_type}")
|
| 75 |
finally:
|
|
@@ -81,12 +80,12 @@ def ProcessFile(file_path, search_type: str = "dense", compute_mode: str = "remo
|
|
| 81 |
pass
|
| 82 |
|
| 83 |
|
| 84 |
-
def ProcessFileDense(cleaned_temp_file, file_path
|
| 85 |
"""
|
| 86 |
Process file for dense vector search using ChromaDB
|
| 87 |
"""
|
| 88 |
# Get a retriever instance
|
| 89 |
-
retriever = BuildRetriever(
|
| 90 |
# Load cleaned text file
|
| 91 |
loader = TextLoader(cleaned_temp_file)
|
| 92 |
documents = loader.load()
|
|
|
|
| 9 |
from mods.bm25s_retriever import BM25SRetriever
|
| 10 |
|
| 11 |
|
| 12 |
+
def ProcessFile(file_path, search_type: str = "dense"):
|
| 13 |
"""
|
| 14 |
Wrapper function to process file for dense or sparse search
|
| 15 |
|
| 16 |
Args:
|
| 17 |
file_path: File to process
|
| 18 |
search_type: Type of search to use. Options: "dense", "sparse"
|
|
|
|
| 19 |
"""
|
| 20 |
|
| 21 |
# Preprocess: remove quoted lines and handle email boundaries
|
|
|
|
| 68 |
ProcessFileSparse(truncated_temp_file, file_path)
|
| 69 |
elif search_type == "dense":
|
| 70 |
# Handle dense search with ChromaDB
|
| 71 |
+
ProcessFileDense(truncated_temp_file, file_path)
|
| 72 |
else:
|
| 73 |
raise ValueError(f"Unsupported search type: {search_type}")
|
| 74 |
finally:
|
|
|
|
| 80 |
pass
|
| 81 |
|
| 82 |
|
| 83 |
+
def ProcessFileDense(cleaned_temp_file, file_path):
|
| 84 |
"""
|
| 85 |
Process file for dense vector search using ChromaDB
|
| 86 |
"""
|
| 87 |
# Get a retriever instance
|
| 88 |
+
retriever = BuildRetriever("dense")
|
| 89 |
# Load cleaned text file
|
| 90 |
loader = TextLoader(cleaned_temp_file)
|
| 91 |
documents = loader.load()
|
main.py
CHANGED
|
@@ -5,20 +5,15 @@ from langchain_core.prompts import ChatPromptTemplate
|
|
| 5 |
from langgraph.checkpoint.memory import MemorySaver
|
| 6 |
from langchain_core.messages import SystemMessage
|
| 7 |
from langchain_core.messages import ToolMessage
|
|
|
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
from datetime import datetime
|
| 10 |
import logging
|
| 11 |
-
import torch
|
| 12 |
import glob
|
| 13 |
import ast
|
| 14 |
import os
|
| 15 |
|
| 16 |
-
# Imports for local and remote chat models
|
| 17 |
-
from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline
|
| 18 |
-
from langchain_openai import ChatOpenAI
|
| 19 |
-
|
| 20 |
# Local modules
|
| 21 |
-
from pipeline import MyTextGenerationPipeline
|
| 22 |
from retriever import BuildRetriever, db_dir
|
| 23 |
from prompts import answer_prompt
|
| 24 |
from index import ProcessFile
|
|
@@ -32,16 +27,9 @@ from graph import BuildGraph
|
|
| 32 |
# Setup environment variables
|
| 33 |
load_dotenv(dotenv_path=".env", override=True)
|
| 34 |
|
| 35 |
-
# Define the
|
| 36 |
openai_model = "gpt-4o-mini"
|
| 37 |
|
| 38 |
-
# Get the local model ID
|
| 39 |
-
model_id = os.getenv("MODEL_ID")
|
| 40 |
-
if model_id is None:
|
| 41 |
-
# model_id = "HuggingFaceTB/SmolLM3-3B"
|
| 42 |
-
model_id = "google/gemma-3-12b-it"
|
| 43 |
-
# model_id = "Qwen/Qwen3-14B"
|
| 44 |
-
|
| 45 |
# Suppress these messages:
|
| 46 |
# INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
|
| 47 |
# INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
|
|
@@ -50,30 +38,29 @@ httpx_logger = logging.getLogger("httpx")
|
|
| 50 |
httpx_logger.setLevel(logging.WARNING)
|
| 51 |
|
| 52 |
|
| 53 |
-
def ProcessDirectory(path
|
| 54 |
"""
|
| 55 |
Update vector store and sparse index for files in a directory, only adding new or updated files
|
| 56 |
|
| 57 |
Args:
|
| 58 |
path: Directory to process
|
| 59 |
-
compute_mode: Compute mode for embeddings (remote or local)
|
| 60 |
|
| 61 |
Usage example:
|
| 62 |
-
ProcessDirectory("R-help"
|
| 63 |
"""
|
| 64 |
|
| 65 |
# TODO: use UUID to process only changed documents
|
| 66 |
# https://stackoverflow.com/questions/76265631/chromadb-add-single-document-only-if-it-doesnt-exist
|
| 67 |
|
| 68 |
# Get a dense retriever instance
|
| 69 |
-
retriever = BuildRetriever(
|
| 70 |
|
| 71 |
# List all text files in target directory
|
| 72 |
file_paths = glob.glob(f"{path}/*.txt")
|
| 73 |
for file_path in file_paths:
|
| 74 |
|
| 75 |
# Process file for sparse search (BM25S)
|
| 76 |
-
ProcessFile(file_path, "sparse"
|
| 77 |
|
| 78 |
# Logic for dense search: skip file if already indexed
|
| 79 |
# Look for existing embeddings for this file
|
|
@@ -103,7 +90,7 @@ def ProcessDirectory(path, compute_mode):
|
|
| 103 |
update_file = True
|
| 104 |
|
| 105 |
if add_file:
|
| 106 |
-
ProcessFile(file_path, "dense"
|
| 107 |
|
| 108 |
if update_file:
|
| 109 |
print(f"Chroma: updated embeddings for {file_path}")
|
|
@@ -114,7 +101,7 @@ def ProcessDirectory(path, compute_mode):
|
|
| 114 |
]
|
| 115 |
files_to_keep = list(set(used_doc_ids))
|
| 116 |
# Get all files in the file store
|
| 117 |
-
file_store = f"{db_dir}/
|
| 118 |
all_files = os.listdir(file_store)
|
| 119 |
# Iterate through the files and delete those not in the list
|
| 120 |
for file in all_files:
|
|
@@ -127,93 +114,32 @@ def ProcessDirectory(path, compute_mode):
|
|
| 127 |
print(f"Chroma: no change for {file_path}")
|
| 128 |
|
| 129 |
|
| 130 |
-
def GetChatModel(compute_mode, ckpt_dir=None):
|
| 131 |
-
"""
|
| 132 |
-
Get a chat model.
|
| 133 |
-
|
| 134 |
-
Args:
|
| 135 |
-
compute_mode: Compute mode for chat model (remote or local)
|
| 136 |
-
ckpt_dir: Checkpoint directory for model weights (optional)
|
| 137 |
-
"""
|
| 138 |
-
|
| 139 |
-
if compute_mode == "remote":
|
| 140 |
-
|
| 141 |
-
chat_model = ChatOpenAI(model=openai_model, temperature=0)
|
| 142 |
-
|
| 143 |
-
if compute_mode == "local":
|
| 144 |
-
|
| 145 |
-
# Don't try to use local models without a GPU
|
| 146 |
-
if compute_mode == "local" and not torch.cuda.is_available():
|
| 147 |
-
raise Exception("Local chat model selected without GPU")
|
| 148 |
-
|
| 149 |
-
# Define the pipeline to pass to the HuggingFacePipeline class
|
| 150 |
-
# https://huggingface.co/blog/langchain
|
| 151 |
-
id_or_dir = ckpt_dir if ckpt_dir else model_id
|
| 152 |
-
tokenizer = AutoTokenizer.from_pretrained(id_or_dir)
|
| 153 |
-
model = AutoModelForCausalLM.from_pretrained(
|
| 154 |
-
id_or_dir,
|
| 155 |
-
# We need this to load the model in BF16 instead of fp32 (torch.float)
|
| 156 |
-
torch_dtype=torch.bfloat16,
|
| 157 |
-
# Enable FlashAttention (requires pip install flash-attn)
|
| 158 |
-
# https://huggingface.co/docs/transformers/en/attention_interface
|
| 159 |
-
# https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention
|
| 160 |
-
# attn_implementation="flash_attention_2",
|
| 161 |
-
)
|
| 162 |
-
# For Flash Attention version of Qwen3
|
| 163 |
-
tokenizer.padding_side = "left"
|
| 164 |
-
|
| 165 |
-
# Use MyTextGenerationPipeline with custom preprocess() method
|
| 166 |
-
pipe = MyTextGenerationPipeline(
|
| 167 |
-
model=model,
|
| 168 |
-
tokenizer=tokenizer,
|
| 169 |
-
# ToolCallingLLM needs return_full_text=False in order to parse just the assistant response
|
| 170 |
-
return_full_text=False,
|
| 171 |
-
# It seems that max_new_tokens has to be specified here, not in .invoke()
|
| 172 |
-
max_new_tokens=2000,
|
| 173 |
-
# Use padding for proper alignment for FlashAttention
|
| 174 |
-
# Part of fix for: "RuntimeError: p.attn_bias_ptr is not correctly aligned"
|
| 175 |
-
# https://github.com/google-deepmind/gemma/issues/169
|
| 176 |
-
padding="longest",
|
| 177 |
-
)
|
| 178 |
-
# We need the task so HuggingFacePipeline can deal with our class
|
| 179 |
-
pipe.task = "text-generation"
|
| 180 |
-
|
| 181 |
-
llm = HuggingFacePipeline(pipeline=pipe)
|
| 182 |
-
chat_model = ChatHuggingFace(llm=llm)
|
| 183 |
-
|
| 184 |
-
return chat_model
|
| 185 |
-
|
| 186 |
-
|
| 187 |
def RunChain(
|
| 188 |
query,
|
| 189 |
-
compute_mode: str = "remote",
|
| 190 |
search_type: str = "hybrid",
|
| 191 |
-
think: bool = False,
|
| 192 |
):
|
| 193 |
"""
|
| 194 |
Run chain to retrieve documents and send to chat
|
| 195 |
|
| 196 |
Args:
|
| 197 |
query: User's query
|
| 198 |
-
compute_mode: Compute mode for embedding and chat models (remote or local)
|
| 199 |
search_type: Type of search to use. Options: "dense", "sparse", or "hybrid"
|
| 200 |
-
think: Control thinking mode for SmolLM3
|
| 201 |
|
| 202 |
Example:
|
| 203 |
RunChain("What R functions are discussed?")
|
| 204 |
"""
|
| 205 |
|
| 206 |
# Get retriever instance
|
| 207 |
-
retriever = BuildRetriever(
|
| 208 |
|
| 209 |
if retriever is None:
|
| 210 |
return "No retriever available. Please process some documents first."
|
| 211 |
|
| 212 |
# Get chat model (LLM)
|
| 213 |
-
chat_model =
|
| 214 |
|
| 215 |
-
# Get
|
| 216 |
-
system_prompt = answer_prompt(
|
| 217 |
|
| 218 |
# Create a prompt template
|
| 219 |
system_template = ChatPromptTemplate.from_messages([SystemMessage(system_prompt)])
|
|
@@ -244,22 +170,16 @@ def RunChain(
|
|
| 244 |
|
| 245 |
def RunGraph(
|
| 246 |
query: str,
|
| 247 |
-
compute_mode: str = "remote",
|
| 248 |
search_type: str = "hybrid",
|
| 249 |
top_k: int = 6,
|
| 250 |
-
think_query=False,
|
| 251 |
-
think_answer=False,
|
| 252 |
thread_id=None,
|
| 253 |
):
|
| 254 |
"""Run graph for conversational RAG app
|
| 255 |
|
| 256 |
Args:
|
| 257 |
query: User query to start the chat
|
| 258 |
-
compute_mode: Compute mode for embedding and chat models (remote or local)
|
| 259 |
search_type: Type of search to use. Options: "dense", "sparse", or "hybrid"
|
| 260 |
top_k: Number of documents to retrieve
|
| 261 |
-
think_query: Whether to use thinking mode for the query
|
| 262 |
-
think_answer: Whether to use thinking mode for the answer
|
| 263 |
thread_id: Thread ID for memory (optional)
|
| 264 |
|
| 265 |
Example:
|
|
@@ -267,15 +187,12 @@ def RunGraph(
|
|
| 267 |
"""
|
| 268 |
|
| 269 |
# Get chat model used in both query and generate steps
|
| 270 |
-
chat_model =
|
| 271 |
# Build the graph
|
| 272 |
graph_builder = BuildGraph(
|
| 273 |
chat_model,
|
| 274 |
-
compute_mode,
|
| 275 |
search_type,
|
| 276 |
top_k,
|
| 277 |
-
think_query,
|
| 278 |
-
think_answer,
|
| 279 |
)
|
| 280 |
|
| 281 |
# Compile the graph with an in-memory checkpointer
|
|
|
|
| 5 |
from langgraph.checkpoint.memory import MemorySaver
|
| 6 |
from langchain_core.messages import SystemMessage
|
| 7 |
from langchain_core.messages import ToolMessage
|
| 8 |
+
from langchain_openai import ChatOpenAI
|
| 9 |
from dotenv import load_dotenv
|
| 10 |
from datetime import datetime
|
| 11 |
import logging
|
|
|
|
| 12 |
import glob
|
| 13 |
import ast
|
| 14 |
import os
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
# Local modules
|
|
|
|
| 17 |
from retriever import BuildRetriever, db_dir
|
| 18 |
from prompts import answer_prompt
|
| 19 |
from index import ProcessFile
|
|
|
|
| 27 |
# Setup environment variables
|
| 28 |
load_dotenv(dotenv_path=".env", override=True)
|
| 29 |
|
| 30 |
+
# Define the OpenAI model
|
| 31 |
openai_model = "gpt-4o-mini"
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
# Suppress these messages:
|
| 34 |
# INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
|
| 35 |
# INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
|
|
|
|
| 38 |
httpx_logger.setLevel(logging.WARNING)
|
| 39 |
|
| 40 |
|
| 41 |
+
def ProcessDirectory(path):
|
| 42 |
"""
|
| 43 |
Update vector store and sparse index for files in a directory, only adding new or updated files
|
| 44 |
|
| 45 |
Args:
|
| 46 |
path: Directory to process
|
|
|
|
| 47 |
|
| 48 |
Usage example:
|
| 49 |
+
ProcessDirectory("R-help")
|
| 50 |
"""
|
| 51 |
|
| 52 |
# TODO: use UUID to process only changed documents
|
| 53 |
# https://stackoverflow.com/questions/76265631/chromadb-add-single-document-only-if-it-doesnt-exist
|
| 54 |
|
| 55 |
# Get a dense retriever instance
|
| 56 |
+
retriever = BuildRetriever("dense")
|
| 57 |
|
| 58 |
# List all text files in target directory
|
| 59 |
file_paths = glob.glob(f"{path}/*.txt")
|
| 60 |
for file_path in file_paths:
|
| 61 |
|
| 62 |
# Process file for sparse search (BM25S)
|
| 63 |
+
ProcessFile(file_path, "sparse")
|
| 64 |
|
| 65 |
# Logic for dense search: skip file if already indexed
|
| 66 |
# Look for existing embeddings for this file
|
|
|
|
| 90 |
update_file = True
|
| 91 |
|
| 92 |
if add_file:
|
| 93 |
+
ProcessFile(file_path, "dense")
|
| 94 |
|
| 95 |
if update_file:
|
| 96 |
print(f"Chroma: updated embeddings for {file_path}")
|
|
|
|
| 101 |
]
|
| 102 |
files_to_keep = list(set(used_doc_ids))
|
| 103 |
# Get all files in the file store
|
| 104 |
+
file_store = f"{db_dir}/file_store"
|
| 105 |
all_files = os.listdir(file_store)
|
| 106 |
# Iterate through the files and delete those not in the list
|
| 107 |
for file in all_files:
|
|
|
|
| 114 |
print(f"Chroma: no change for {file_path}")
|
| 115 |
|
| 116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
def RunChain(
|
| 118 |
query,
|
|
|
|
| 119 |
search_type: str = "hybrid",
|
|
|
|
| 120 |
):
|
| 121 |
"""
|
| 122 |
Run chain to retrieve documents and send to chat
|
| 123 |
|
| 124 |
Args:
|
| 125 |
query: User's query
|
|
|
|
| 126 |
search_type: Type of search to use. Options: "dense", "sparse", or "hybrid"
|
|
|
|
| 127 |
|
| 128 |
Example:
|
| 129 |
RunChain("What R functions are discussed?")
|
| 130 |
"""
|
| 131 |
|
| 132 |
# Get retriever instance
|
| 133 |
+
retriever = BuildRetriever(search_type)
|
| 134 |
|
| 135 |
if retriever is None:
|
| 136 |
return "No retriever available. Please process some documents first."
|
| 137 |
|
| 138 |
# Get chat model (LLM)
|
| 139 |
+
chat_model = ChatOpenAI(model=openai_model, temperature=0)
|
| 140 |
|
| 141 |
+
# Get system prompt
|
| 142 |
+
system_prompt = answer_prompt()
|
| 143 |
|
| 144 |
# Create a prompt template
|
| 145 |
system_template = ChatPromptTemplate.from_messages([SystemMessage(system_prompt)])
|
|
|
|
| 170 |
|
| 171 |
def RunGraph(
|
| 172 |
query: str,
|
|
|
|
| 173 |
search_type: str = "hybrid",
|
| 174 |
top_k: int = 6,
|
|
|
|
|
|
|
| 175 |
thread_id=None,
|
| 176 |
):
|
| 177 |
"""Run graph for conversational RAG app
|
| 178 |
|
| 179 |
Args:
|
| 180 |
query: User query to start the chat
|
|
|
|
| 181 |
search_type: Type of search to use. Options: "dense", "sparse", or "hybrid"
|
| 182 |
top_k: Number of documents to retrieve
|
|
|
|
|
|
|
| 183 |
thread_id: Thread ID for memory (optional)
|
| 184 |
|
| 185 |
Example:
|
|
|
|
| 187 |
"""
|
| 188 |
|
| 189 |
# Get chat model used in both query and generate steps
|
| 190 |
+
chat_model = ChatOpenAI(model=openai_model, temperature=0)
|
| 191 |
# Build the graph
|
| 192 |
graph_builder = BuildGraph(
|
| 193 |
chat_model,
|
|
|
|
| 194 |
search_type,
|
| 195 |
top_k,
|
|
|
|
|
|
|
| 196 |
)
|
| 197 |
|
| 198 |
# Compile the graph with an in-memory checkpointer
|
mods/tool_calling_llm.py
DELETED
|
@@ -1,313 +0,0 @@
|
|
| 1 |
-
import re
|
| 2 |
-
import json
|
| 3 |
-
import uuid
|
| 4 |
-
import warnings
|
| 5 |
-
from abc import ABC
|
| 6 |
-
from typing import (
|
| 7 |
-
Any,
|
| 8 |
-
AsyncIterator,
|
| 9 |
-
Callable,
|
| 10 |
-
Dict,
|
| 11 |
-
List,
|
| 12 |
-
Optional,
|
| 13 |
-
Sequence,
|
| 14 |
-
Tuple,
|
| 15 |
-
Type,
|
| 16 |
-
Union,
|
| 17 |
-
cast,
|
| 18 |
-
)
|
| 19 |
-
|
| 20 |
-
from langchain_core.callbacks import (
|
| 21 |
-
AsyncCallbackManagerForLLMRun,
|
| 22 |
-
CallbackManagerForLLMRun,
|
| 23 |
-
)
|
| 24 |
-
from langchain_core.language_models import BaseChatModel, LanguageModelInput
|
| 25 |
-
from langchain_core.messages import (
|
| 26 |
-
SystemMessage,
|
| 27 |
-
AIMessage,
|
| 28 |
-
BaseMessage,
|
| 29 |
-
BaseMessageChunk,
|
| 30 |
-
ToolCall,
|
| 31 |
-
)
|
| 32 |
-
from langchain_core.outputs import ChatGeneration, ChatResult
|
| 33 |
-
from langchain_core.prompts import SystemMessagePromptTemplate
|
| 34 |
-
from pydantic import BaseModel
|
| 35 |
-
from langchain_core.runnables import Runnable, RunnableConfig
|
| 36 |
-
from langchain_core.tools import BaseTool
|
| 37 |
-
from langchain_core.utils.function_calling import convert_to_openai_tool
|
| 38 |
-
|
| 39 |
-
DEFAULT_SYSTEM_TEMPLATE = """You have access to the following tools:
|
| 40 |
-
|
| 41 |
-
{tools}
|
| 42 |
-
|
| 43 |
-
You must always select one of the above tools and respond with only a JSON object matching the following schema:
|
| 44 |
-
|
| 45 |
-
{{
|
| 46 |
-
"tool": <name of selected tool 1>,
|
| 47 |
-
"tool_input": <parameters for selected tool 1, matching the tool's JSON schema>
|
| 48 |
-
}},
|
| 49 |
-
{{
|
| 50 |
-
"tool": <name of selected tool 2>,
|
| 51 |
-
"tool_input": <parameters for selected tool 2, matching the tool's JSON schema>
|
| 52 |
-
}}
|
| 53 |
-
""" # noqa: E501
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
def extract_think(content):
|
| 57 |
-
# Added by Cursor 20250726 jmd
|
| 58 |
-
# Extract content within <think>...</think>
|
| 59 |
-
think_match = re.search(r"<think>(.*?)</think>", content, re.DOTALL)
|
| 60 |
-
think_text = think_match.group(1).strip() if think_match else ""
|
| 61 |
-
# Extract text after </think>
|
| 62 |
-
if think_match:
|
| 63 |
-
post_think = content[think_match.end() :].lstrip()
|
| 64 |
-
else:
|
| 65 |
-
# Check if content starts with <think> but missing closing tag
|
| 66 |
-
if content.strip().startswith("<think>"):
|
| 67 |
-
# Extract everything after <think>
|
| 68 |
-
think_start = content.find("<think>") + len("<think>")
|
| 69 |
-
think_text = content[think_start:].strip()
|
| 70 |
-
post_think = ""
|
| 71 |
-
else:
|
| 72 |
-
# No <think> found, so return entire content as post_think
|
| 73 |
-
post_think = content
|
| 74 |
-
return think_text, post_think
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
class ToolCallingLLM(BaseChatModel, ABC):
|
| 78 |
-
"""ToolCallingLLM mixin to enable tool calling features on non tool calling models.
|
| 79 |
-
|
| 80 |
-
Note: This is an incomplete mixin and should not be used directly. It must be used to extent an existing Chat Model.
|
| 81 |
-
|
| 82 |
-
Setup:
|
| 83 |
-
Install dependencies for your Chat Model.
|
| 84 |
-
Any API Keys or setup needed for your Chat Model is still applicable.
|
| 85 |
-
|
| 86 |
-
Key init args — completion params:
|
| 87 |
-
Refer to the documentation of the Chat Model you wish to extend with Tool Calling.
|
| 88 |
-
|
| 89 |
-
Key init args — client params:
|
| 90 |
-
Refer to the documentation of the Chat Model you wish to extend with Tool Calling.
|
| 91 |
-
|
| 92 |
-
See full list of supported init args and their descriptions in the params section.
|
| 93 |
-
|
| 94 |
-
Instantiate:
|
| 95 |
-
```
|
| 96 |
-
# Example implementation using LiteLLM
|
| 97 |
-
from langchain_community.chat_models import ChatLiteLLM
|
| 98 |
-
|
| 99 |
-
class LiteLLMFunctions(ToolCallingLLM, ChatLiteLLM):
|
| 100 |
-
|
| 101 |
-
def __init__(self, **kwargs: Any) -> None:
|
| 102 |
-
super().__init__(**kwargs)
|
| 103 |
-
|
| 104 |
-
@property
|
| 105 |
-
def _llm_type(self) -> str:
|
| 106 |
-
return "litellm_functions"
|
| 107 |
-
|
| 108 |
-
llm = LiteLLMFunctions(model="ollama/phi3")
|
| 109 |
-
```
|
| 110 |
-
|
| 111 |
-
Invoke:
|
| 112 |
-
```
|
| 113 |
-
messages = [
|
| 114 |
-
("human", "What is the capital of France?")
|
| 115 |
-
]
|
| 116 |
-
llm.invoke(messages)
|
| 117 |
-
```
|
| 118 |
-
```
|
| 119 |
-
AIMessage(content='The capital of France is Paris.', id='run-497d0e1a-d63b-45e8-9c8b-5e76d99b9468-0')
|
| 120 |
-
```
|
| 121 |
-
|
| 122 |
-
Tool calling:
|
| 123 |
-
```
|
| 124 |
-
from pydantic import BaseModel, Field
|
| 125 |
-
|
| 126 |
-
class GetWeather(BaseModel):
|
| 127 |
-
'''Get the current weather in a given location'''
|
| 128 |
-
|
| 129 |
-
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
|
| 130 |
-
|
| 131 |
-
class GetPopulation(BaseModel):
|
| 132 |
-
'''Get the current population in a given location'''
|
| 133 |
-
|
| 134 |
-
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
|
| 135 |
-
|
| 136 |
-
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
|
| 137 |
-
ai_msg = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?")
|
| 138 |
-
ai_msg.tool_calls
|
| 139 |
-
```
|
| 140 |
-
```
|
| 141 |
-
[{'name': 'GetWeather', 'args': {'location': 'Austin, TX'}, 'id': 'call_25ed526917b94d8fa5db3fe30a8cf3c0'}]
|
| 142 |
-
```
|
| 143 |
-
|
| 144 |
-
Response metadata
|
| 145 |
-
Refer to the documentation of the Chat Model you wish to extend with Tool Calling.
|
| 146 |
-
|
| 147 |
-
""" # noqa: E501
|
| 148 |
-
|
| 149 |
-
tool_system_prompt_template: str = DEFAULT_SYSTEM_TEMPLATE
|
| 150 |
-
|
| 151 |
-
def __init__(self, **kwargs: Any) -> None:
|
| 152 |
-
super().__init__(**kwargs)
|
| 153 |
-
|
| 154 |
-
def _generate_system_message_and_functions(
|
| 155 |
-
self,
|
| 156 |
-
kwargs: Dict[str, Any],
|
| 157 |
-
) -> Tuple[BaseMessage, List]:
|
| 158 |
-
functions = kwargs.get("tools", [])
|
| 159 |
-
|
| 160 |
-
# Convert functions to OpenAI tool schema
|
| 161 |
-
functions = [convert_to_openai_tool(fn) for fn in functions]
|
| 162 |
-
# Create system message with tool descriptions
|
| 163 |
-
system_message_prompt_template = SystemMessagePromptTemplate.from_template(
|
| 164 |
-
self.tool_system_prompt_template
|
| 165 |
-
)
|
| 166 |
-
system_message = system_message_prompt_template.format(
|
| 167 |
-
tools=json.dumps(functions, indent=2)
|
| 168 |
-
)
|
| 169 |
-
return system_message, functions
|
| 170 |
-
|
| 171 |
-
def _process_response(
|
| 172 |
-
self, response_message: BaseMessage, functions: List[Dict]
|
| 173 |
-
) -> AIMessage:
|
| 174 |
-
if not isinstance(response_message.content, str):
|
| 175 |
-
raise ValueError("ToolCallingLLM does not support non-string output.")
|
| 176 |
-
|
| 177 |
-
# Extract <think>...</think> content and text after </think> for further processing 20250726 jmd
|
| 178 |
-
think_text, post_think = extract_think(response_message.content)
|
| 179 |
-
|
| 180 |
-
## For debugging
|
| 181 |
-
# print("post_think")
|
| 182 |
-
# print(post_think)
|
| 183 |
-
|
| 184 |
-
# Remove backticks around code blocks
|
| 185 |
-
post_think = re.sub(r"^```json", "", post_think)
|
| 186 |
-
post_think = re.sub(r"^```", "", post_think)
|
| 187 |
-
post_think = re.sub(r"```$", "", post_think)
|
| 188 |
-
# Remove intervening backticks from adjacent code blocks
|
| 189 |
-
post_think = re.sub(r"```\n```json", ",", post_think)
|
| 190 |
-
# Remove trailing comma (if there is one)
|
| 191 |
-
post_think = post_think.rstrip(",")
|
| 192 |
-
# Parse output for JSON (support multiple objects separated by commas)
|
| 193 |
-
try:
|
| 194 |
-
# Works for one JSON object, or multiple JSON objects enclosed in "[]"
|
| 195 |
-
parsed_json_results = json.loads(f"{post_think}")
|
| 196 |
-
if not isinstance(parsed_json_results, list):
|
| 197 |
-
parsed_json_results = [parsed_json_results]
|
| 198 |
-
except:
|
| 199 |
-
try:
|
| 200 |
-
# Works for multiple JSON objects not enclosed in "[]"
|
| 201 |
-
parsed_json_results = json.loads(f"[{post_think}]")
|
| 202 |
-
except json.JSONDecodeError:
|
| 203 |
-
# Return entire response if JSON wasn't parsed or is missing
|
| 204 |
-
return AIMessage(content=response_message.content)
|
| 205 |
-
|
| 206 |
-
# print("parsed_json_results")
|
| 207 |
-
# print(parsed_json_results)
|
| 208 |
-
|
| 209 |
-
tool_calls = []
|
| 210 |
-
for parsed_json_result in parsed_json_results:
|
| 211 |
-
# Get tool name from output
|
| 212 |
-
called_tool_name = (
|
| 213 |
-
parsed_json_result["tool"]
|
| 214 |
-
if "tool" in parsed_json_result
|
| 215 |
-
else (
|
| 216 |
-
parsed_json_result["name"] if "name" in parsed_json_result else None
|
| 217 |
-
)
|
| 218 |
-
)
|
| 219 |
-
|
| 220 |
-
# Check if tool name is in functions list
|
| 221 |
-
called_tool = next(
|
| 222 |
-
(fn for fn in functions if fn["function"]["name"] == called_tool_name),
|
| 223 |
-
None,
|
| 224 |
-
)
|
| 225 |
-
if called_tool is None:
|
| 226 |
-
# Issue a warning and skip this tool call
|
| 227 |
-
warnings.warn(f"Called tool ({called_tool_name}) not in functions list")
|
| 228 |
-
continue
|
| 229 |
-
|
| 230 |
-
# Get tool arguments from output
|
| 231 |
-
called_tool_arguments = (
|
| 232 |
-
parsed_json_result["tool_input"]
|
| 233 |
-
if "tool_input" in parsed_json_result
|
| 234 |
-
else (
|
| 235 |
-
parsed_json_result["parameters"]
|
| 236 |
-
if "parameters" in parsed_json_result
|
| 237 |
-
else {}
|
| 238 |
-
)
|
| 239 |
-
)
|
| 240 |
-
|
| 241 |
-
tool_calls.append(
|
| 242 |
-
ToolCall(
|
| 243 |
-
name=called_tool_name,
|
| 244 |
-
args=called_tool_arguments,
|
| 245 |
-
id=f"call_{str(uuid.uuid4()).replace('-', '')}",
|
| 246 |
-
)
|
| 247 |
-
)
|
| 248 |
-
|
| 249 |
-
if not tool_calls:
|
| 250 |
-
# If nothing valid, return original content
|
| 251 |
-
return AIMessage(content=response_message.content)
|
| 252 |
-
|
| 253 |
-
# Put together response message
|
| 254 |
-
response_message = AIMessage(
|
| 255 |
-
content=f"<think>\n{think_text}\n</think>",
|
| 256 |
-
tool_calls=tool_calls,
|
| 257 |
-
)
|
| 258 |
-
return response_message
|
| 259 |
-
|
| 260 |
-
def _generate(
|
| 261 |
-
self,
|
| 262 |
-
messages: List[BaseMessage],
|
| 263 |
-
stop: Optional[List[str]] = None,
|
| 264 |
-
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
| 265 |
-
**kwargs: Any,
|
| 266 |
-
) -> ChatResult:
|
| 267 |
-
system_message, functions = self._generate_system_message_and_functions(kwargs)
|
| 268 |
-
response_message = super()._generate( # type: ignore[safe-super]
|
| 269 |
-
[system_message] + messages, stop=stop, run_manager=run_manager, **kwargs
|
| 270 |
-
)
|
| 271 |
-
response = self._process_response(
|
| 272 |
-
response_message.generations[0].message, functions
|
| 273 |
-
)
|
| 274 |
-
return ChatResult(generations=[ChatGeneration(message=response)])
|
| 275 |
-
|
| 276 |
-
async def _agenerate(
|
| 277 |
-
self,
|
| 278 |
-
messages: List[BaseMessage],
|
| 279 |
-
stop: Optional[List[str]] = None,
|
| 280 |
-
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
| 281 |
-
**kwargs: Any,
|
| 282 |
-
) -> ChatResult:
|
| 283 |
-
system_message, functions = self._generate_system_message_and_functions(kwargs)
|
| 284 |
-
response_message = await super()._agenerate(
|
| 285 |
-
[system_message] + messages, stop=stop, run_manager=run_manager, **kwargs
|
| 286 |
-
)
|
| 287 |
-
response = self._process_response(
|
| 288 |
-
response_message.generations[0].message, functions
|
| 289 |
-
)
|
| 290 |
-
return ChatResult(generations=[ChatGeneration(message=response)])
|
| 291 |
-
|
| 292 |
-
async def astream(
|
| 293 |
-
self,
|
| 294 |
-
input: LanguageModelInput,
|
| 295 |
-
config: Optional[RunnableConfig] = None,
|
| 296 |
-
*,
|
| 297 |
-
stop: Optional[List[str]] = None,
|
| 298 |
-
**kwargs: Any,
|
| 299 |
-
) -> AsyncIterator[BaseMessageChunk]:
|
| 300 |
-
system_message, functions = self._generate_system_message_and_functions(kwargs)
|
| 301 |
-
generation: Optional[BaseMessageChunk] = None
|
| 302 |
-
async for chunk in super().astream(
|
| 303 |
-
[system_message] + super()._convert_input(input).to_messages(),
|
| 304 |
-
stop=stop,
|
| 305 |
-
**kwargs,
|
| 306 |
-
):
|
| 307 |
-
if generation is None:
|
| 308 |
-
generation = chunk
|
| 309 |
-
else:
|
| 310 |
-
generation += chunk
|
| 311 |
-
assert generation is not None
|
| 312 |
-
response = self._process_response(generation, functions)
|
| 313 |
-
yield cast(BaseMessageChunk, response)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pipeline.py
DELETED
|
@@ -1,86 +0,0 @@
|
|
| 1 |
-
from transformers.pipelines.text_generation import Chat
|
| 2 |
-
from transformers import TextGenerationPipeline
|
| 3 |
-
from typing import Dict
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
class MyTextGenerationPipeline(TextGenerationPipeline):
|
| 7 |
-
"""
|
| 8 |
-
This subclass overrides the preprocess method to add pad_to_multiple_of=8 to tokenizer_kwargs.
|
| 9 |
-
Fix for: "RuntimeError: p.attn_bias_ptr is not correctly aligned"
|
| 10 |
-
https://github.com/google-deepmind/gemma/issues/169
|
| 11 |
-
NOTE: we also need padding="longest", which is set during class instantiation
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
def preprocess(
|
| 15 |
-
self,
|
| 16 |
-
prompt_text,
|
| 17 |
-
prefix="",
|
| 18 |
-
handle_long_generation=None,
|
| 19 |
-
add_special_tokens=None,
|
| 20 |
-
truncation=None,
|
| 21 |
-
padding=None,
|
| 22 |
-
max_length=None,
|
| 23 |
-
continue_final_message=None,
|
| 24 |
-
**generate_kwargs,
|
| 25 |
-
):
|
| 26 |
-
# Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
|
| 27 |
-
tokenizer_kwargs = {
|
| 28 |
-
"add_special_tokens": add_special_tokens,
|
| 29 |
-
"truncation": truncation,
|
| 30 |
-
"padding": padding,
|
| 31 |
-
"max_length": max_length,
|
| 32 |
-
"pad_to_multiple_of": 8,
|
| 33 |
-
}
|
| 34 |
-
tokenizer_kwargs = {
|
| 35 |
-
key: value for key, value in tokenizer_kwargs.items() if value is not None
|
| 36 |
-
}
|
| 37 |
-
|
| 38 |
-
if isinstance(prompt_text, Chat):
|
| 39 |
-
tokenizer_kwargs.pop(
|
| 40 |
-
"add_special_tokens", None
|
| 41 |
-
) # ignore add_special_tokens on chats
|
| 42 |
-
# If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
|
| 43 |
-
# because very few models support multiple separate, consecutive assistant messages
|
| 44 |
-
if continue_final_message is None:
|
| 45 |
-
continue_final_message = prompt_text.messages[-1]["role"] == "assistant"
|
| 46 |
-
inputs = self.tokenizer.apply_chat_template(
|
| 47 |
-
prompt_text.messages,
|
| 48 |
-
add_generation_prompt=not continue_final_message,
|
| 49 |
-
continue_final_message=continue_final_message,
|
| 50 |
-
return_dict=True,
|
| 51 |
-
return_tensors=self.framework,
|
| 52 |
-
**tokenizer_kwargs,
|
| 53 |
-
)
|
| 54 |
-
else:
|
| 55 |
-
inputs = self.tokenizer(
|
| 56 |
-
prefix + prompt_text, return_tensors=self.framework, **tokenizer_kwargs
|
| 57 |
-
)
|
| 58 |
-
|
| 59 |
-
inputs["prompt_text"] = prompt_text
|
| 60 |
-
|
| 61 |
-
if handle_long_generation == "hole":
|
| 62 |
-
cur_len = inputs["input_ids"].shape[-1]
|
| 63 |
-
if "max_new_tokens" in generate_kwargs:
|
| 64 |
-
new_tokens = generate_kwargs["max_new_tokens"]
|
| 65 |
-
else:
|
| 66 |
-
new_tokens = (
|
| 67 |
-
generate_kwargs.get("max_length", self.generation_config.max_length)
|
| 68 |
-
- cur_len
|
| 69 |
-
)
|
| 70 |
-
if new_tokens < 0:
|
| 71 |
-
raise ValueError("We cannot infer how many new tokens are expected")
|
| 72 |
-
if cur_len + new_tokens > self.tokenizer.model_max_length:
|
| 73 |
-
keep_length = self.tokenizer.model_max_length - new_tokens
|
| 74 |
-
if keep_length <= 0:
|
| 75 |
-
raise ValueError(
|
| 76 |
-
"We cannot use `hole` to handle this generation the number of desired tokens exceeds the"
|
| 77 |
-
" models max length"
|
| 78 |
-
)
|
| 79 |
-
|
| 80 |
-
inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
|
| 81 |
-
if "attention_mask" in inputs:
|
| 82 |
-
inputs["attention_mask"] = inputs["attention_mask"][
|
| 83 |
-
:, -keep_length:
|
| 84 |
-
]
|
| 85 |
-
|
| 86 |
-
return inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prompts.py
CHANGED
|
@@ -3,22 +3,16 @@ from util import get_sources, get_start_end_months
|
|
| 3 |
import re
|
| 4 |
|
| 5 |
|
| 6 |
-
def check_prompt(prompt, chat_model, think):
|
| 7 |
-
"""Check for unassigned variables
|
| 8 |
# A sanity check that we don't have unassigned variables
|
| 9 |
-
# (this causes KeyError in parsing by ToolCallingLLM)
|
| 10 |
matches = re.findall(r"\{.*?\}", " ".join(prompt))
|
| 11 |
if matches:
|
| 12 |
raise ValueError(f"Unassigned variables in prompt: {' '.join(matches)}")
|
| 13 |
-
# Check if we should add /no_think to turn off thinking mode
|
| 14 |
-
if hasattr(chat_model, "model_id"):
|
| 15 |
-
model_id = chat_model.model_id
|
| 16 |
-
if ("SmolLM" in model_id or "Qwen" in model_id) and not think:
|
| 17 |
-
prompt = "/no_think\n" + prompt
|
| 18 |
return prompt
|
| 19 |
|
| 20 |
|
| 21 |
-
def query_prompt(chat_model, think=False):
|
| 22 |
"""Return system prompt for query step"""
|
| 23 |
|
| 24 |
# Get start and end months from database
|
|
@@ -43,12 +37,12 @@ def query_prompt(chat_model, think=False):
|
|
| 43 |
# "Do not use your memory or knowledge to answer the user's question. Only retrieve emails based on the user's question. " # Qwen
|
| 44 |
# "If you decide not to retrieve emails, tell the user why and suggest how to improve their question to chat with the R-help mailing list. "
|
| 45 |
)
|
| 46 |
-
prompt = check_prompt(prompt, chat_model, think)
|
| 47 |
|
| 48 |
return prompt
|
| 49 |
|
| 50 |
|
| 51 |
-
def answer_prompt(chat_model, think=False, with_tools=False):
|
| 52 |
"""Return system prompt for answer step"""
|
| 53 |
prompt = (
|
| 54 |
f"Today Date: {date.today()}. "
|
|
@@ -64,61 +58,8 @@ def answer_prompt(chat_model, think=False, with_tools=False):
|
|
| 64 |
"Only answer general questions about R if the answer is in the retrieved emails. "
|
| 65 |
"Only include URLs if they were used by human authors (not in email headers), and do not modify any URLs. " # Qwen, Gemma
|
| 66 |
"Respond with 500 words maximum and 50 lines of code maximum. "
|
|
|
|
| 67 |
)
|
| 68 |
-
|
| 69 |
-
prompt = (
|
| 70 |
-
f"{prompt}"
|
| 71 |
-
"Use answer_with_citations to provide the complete answer and all citations used. "
|
| 72 |
-
)
|
| 73 |
-
prompt = check_prompt(prompt, chat_model, think)
|
| 74 |
|
| 75 |
return prompt
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
# Prompt template for SmolLM3 with tools
|
| 79 |
-
# The first two lines, <function-name>, and <args-json-object> are from the apply_chat_template for HuggingFaceTB/SmolLM3-3B
|
| 80 |
-
# The other lines (You have, {tools}, You must), "tool", and "tool_input" are from tool_calling_llm.py
|
| 81 |
-
smollm3_tools_template = """
|
| 82 |
-
|
| 83 |
-
### Tools
|
| 84 |
-
|
| 85 |
-
You may call one or more functions to assist with the user query.
|
| 86 |
-
|
| 87 |
-
You have access to the following tools:
|
| 88 |
-
|
| 89 |
-
{tools}
|
| 90 |
-
|
| 91 |
-
You must always select one of the above tools and respond with only a JSON object matching the following schema:
|
| 92 |
-
|
| 93 |
-
{{
|
| 94 |
-
"tool": <function-name>,
|
| 95 |
-
"tool_input": <args-json-object>
|
| 96 |
-
}},
|
| 97 |
-
{{
|
| 98 |
-
"tool": <function-name>,
|
| 99 |
-
"tool_input": <args-json-object>
|
| 100 |
-
}}
|
| 101 |
-
|
| 102 |
-
"""
|
| 103 |
-
|
| 104 |
-
# Prompt template for Gemma/Qwen with tools
|
| 105 |
-
# Based on https://ai.google.dev/gemma/docs/capabilities/function-calling
|
| 106 |
-
generic_tools_template = """
|
| 107 |
-
|
| 108 |
-
### Functions
|
| 109 |
-
|
| 110 |
-
You have access to functions. If you decide to invoke any of the function(s), you MUST put it in the format of
|
| 111 |
-
|
| 112 |
-
{{
|
| 113 |
-
"tool": <function-name>,
|
| 114 |
-
"tool_input": <args-json-object>
|
| 115 |
-
}},
|
| 116 |
-
{{
|
| 117 |
-
"tool": <function-name>,
|
| 118 |
-
"tool_input": <args-json-object>
|
| 119 |
-
}}
|
| 120 |
-
|
| 121 |
-
You SHOULD NOT include any other text in the response if you call a function
|
| 122 |
-
|
| 123 |
-
{tools}
|
| 124 |
-
"""
|
|
|
|
| 3 |
import re
|
| 4 |
|
| 5 |
|
| 6 |
+
def check_prompt(prompt):
|
| 7 |
+
"""Check for unassigned variables"""
|
| 8 |
# A sanity check that we don't have unassigned variables
|
|
|
|
| 9 |
matches = re.findall(r"\{.*?\}", " ".join(prompt))
|
| 10 |
if matches:
|
| 11 |
raise ValueError(f"Unassigned variables in prompt: {' '.join(matches)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
return prompt
|
| 13 |
|
| 14 |
|
| 15 |
+
def query_prompt():
|
| 16 |
"""Return system prompt for query step"""
|
| 17 |
|
| 18 |
# Get start and end months from database
|
|
|
|
| 37 |
# "Do not use your memory or knowledge to answer the user's question. Only retrieve emails based on the user's question. " # Qwen
|
| 38 |
# "If you decide not to retrieve emails, tell the user why and suggest how to improve their question to chat with the R-help mailing list. "
|
| 39 |
)
|
| 40 |
+
prompt = check_prompt(prompt)
|
| 41 |
|
| 42 |
return prompt
|
| 43 |
|
| 44 |
|
| 45 |
+
def answer_prompt():
|
| 46 |
"""Return system prompt for answer step"""
|
| 47 |
prompt = (
|
| 48 |
f"Today Date: {date.today()}. "
|
|
|
|
| 58 |
"Only answer general questions about R if the answer is in the retrieved emails. "
|
| 59 |
"Only include URLs if they were used by human authors (not in email headers), and do not modify any URLs. " # Qwen, Gemma
|
| 60 |
"Respond with 500 words maximum and 50 lines of code maximum. "
|
| 61 |
+
"Use answer_with_citations to provide the complete answer and all citations used. "
|
| 62 |
)
|
| 63 |
+
prompt = check_prompt(prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
return prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,25 +1,17 @@
|
|
| 1 |
-
#
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
chromadb==0.6.3
|
| 4 |
# NOTE: chromadb==1.0.13 was giving intermittent error:
|
| 5 |
# ValueError('Could not connect to tenant default_tenant. Are you sure it exists?')
|
| 6 |
-
|
| 7 |
-
#
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
# Stated requirements:
|
| 11 |
-
# Gemma 3: transformers>=4.50
|
| 12 |
-
# Qwen3: transformers>=4.51
|
| 13 |
-
# SmolLM3: transformers>=4.53
|
| 14 |
-
transformers==4.51.3
|
| 15 |
-
tokenizers==0.21.2
|
| 16 |
-
# Only needed with AutoModelForCausalLM.from_pretrained(device_map="auto")
|
| 17 |
-
#accelerate==1.8.1
|
| 18 |
-
|
| 19 |
-
# Required by langchain-huggingface
|
| 20 |
-
sentence-transformers==5.0.0
|
| 21 |
-
# For snapshot_download
|
| 22 |
-
huggingface-hub==0.34.3
|
| 23 |
|
| 24 |
# Langchain packages
|
| 25 |
langchain==0.3.26
|
|
@@ -27,31 +19,14 @@ langchain-core==0.3.72
|
|
| 27 |
langchain-chroma==0.2.3
|
| 28 |
langchain-openai==0.3.27
|
| 29 |
langchain-community==0.3.27
|
| 30 |
-
langchain-huggingface==0.3.0
|
| 31 |
langchain-text-splitters==0.3.8
|
| 32 |
langgraph==0.4.7
|
| 33 |
langgraph-sdk==0.1.72
|
| 34 |
langgraph-prebuilt==0.5.2
|
| 35 |
langgraph-checkpoint==2.1.0
|
| 36 |
|
| 37 |
-
#
|
| 38 |
-
einops==0.8.1
|
| 39 |
-
|
| 40 |
-
# Commented because we have local modifications
|
| 41 |
-
#tool-calling-llm==0.1.2
|
| 42 |
-
bm25s==0.2.12
|
| 43 |
ragas==0.2.15
|
| 44 |
|
| 45 |
-
#
|
| 46 |
-
# https://github.com/vanna-ai/vanna/issues/917
|
| 47 |
-
posthog==5.4.0
|
| 48 |
-
|
| 49 |
-
# Gradio for the web interface
|
| 50 |
gradio==5.38.2
|
| 51 |
-
spaces==0.37.1
|
| 52 |
-
|
| 53 |
-
# For downloading data from S3
|
| 54 |
-
boto3==1.39.14
|
| 55 |
-
|
| 56 |
-
# Others
|
| 57 |
-
python-dotenv==1.1.1
|
|
|
|
| 1 |
+
# To load API keys
|
| 2 |
+
python-dotenv==1.1.1
|
| 3 |
+
|
| 4 |
+
# To download data from S3
|
| 5 |
+
boto3==1.39.14
|
| 6 |
+
|
| 7 |
+
# Retrieval
|
| 8 |
+
bm25s==0.2.12
|
| 9 |
chromadb==0.6.3
|
| 10 |
# NOTE: chromadb==1.0.13 was giving intermittent error:
|
| 11 |
# ValueError('Could not connect to tenant default_tenant. Are you sure it exists?')
|
| 12 |
+
# posthog<6.0.0 is temporary fix for ChromaDB telemetry error log messages
|
| 13 |
+
# https://github.com/vanna-ai/vanna/issues/917
|
| 14 |
+
posthog==5.4.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
# Langchain packages
|
| 17 |
langchain==0.3.26
|
|
|
|
| 19 |
langchain-chroma==0.2.3
|
| 20 |
langchain-openai==0.3.27
|
| 21 |
langchain-community==0.3.27
|
|
|
|
| 22 |
langchain-text-splitters==0.3.8
|
| 23 |
langgraph==0.4.7
|
| 24 |
langgraph-sdk==0.1.72
|
| 25 |
langgraph-prebuilt==0.5.2
|
| 26 |
langgraph-checkpoint==2.1.0
|
| 27 |
|
| 28 |
+
# Evaluations
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
ragas==0.2.15
|
| 30 |
|
| 31 |
+
# Frontend
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
gradio==5.38.2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
retriever.py
CHANGED
|
@@ -1,25 +1,17 @@
|
|
| 1 |
# Main retriever modules
|
| 2 |
-
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 3 |
-
from langchain_community.document_loaders import TextLoader
|
| 4 |
-
from langchain_chroma import Chroma
|
| 5 |
from langchain.retrievers import ParentDocumentRetriever, EnsembleRetriever
|
| 6 |
-
from langchain_core.documents import Document
|
| 7 |
from langchain_core.retrievers import BaseRetriever, RetrieverLike
|
| 8 |
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
from typing import Any, Optional
|
| 10 |
import chromadb
|
| 11 |
-
import torch
|
| 12 |
import os
|
| 13 |
import re
|
| 14 |
|
| 15 |
-
# To use OpenAI models (remote)
|
| 16 |
-
from langchain_openai import OpenAIEmbeddings
|
| 17 |
-
|
| 18 |
-
## To use Hugging Face models (local)
|
| 19 |
-
# from langchain_huggingface import HuggingFaceEmbeddings
|
| 20 |
-
# For more control over BGE and Nomic embeddings
|
| 21 |
-
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
| 22 |
-
|
| 23 |
# Local modules
|
| 24 |
from mods.bm25s_retriever import BM25SRetriever
|
| 25 |
from mods.file_system import LocalFileStore
|
|
@@ -27,41 +19,30 @@ from mods.file_system import LocalFileStore
|
|
| 27 |
# Database directory
|
| 28 |
db_dir = "db"
|
| 29 |
|
| 30 |
-
# Embedding model
|
| 31 |
-
embedding_model_id = "nomic-ai/nomic-embed-text-v1.5"
|
| 32 |
-
|
| 33 |
|
| 34 |
def BuildRetriever(
|
| 35 |
-
compute_mode,
|
| 36 |
search_type: str = "hybrid",
|
| 37 |
top_k=6,
|
| 38 |
start_year=None,
|
| 39 |
end_year=None,
|
| 40 |
-
embedding_ckpt_dir=None,
|
| 41 |
):
|
| 42 |
"""
|
| 43 |
Build retriever instance.
|
| 44 |
All retriever types are configured to return up to 6 documents for fair comparison in evals.
|
| 45 |
|
| 46 |
Args:
|
| 47 |
-
compute_mode: Compute mode for embeddings (remote or local)
|
| 48 |
search_type: Type of search to use. Options: "dense", "sparse", "hybrid"
|
| 49 |
top_k: Number of documents to retrieve for "dense" and "sparse"
|
| 50 |
start_year: Start year (optional)
|
| 51 |
end_year: End year (optional)
|
| 52 |
-
embedding_ckpt_dir: Directory for embedding model checkpoint
|
| 53 |
"""
|
| 54 |
if search_type == "dense":
|
| 55 |
if not (start_year or end_year):
|
| 56 |
# No year filtering, so directly use base retriever
|
| 57 |
-
return BuildRetrieverDense(
|
| 58 |
-
compute_mode, top_k=top_k, embedding_ckpt_dir=embedding_ckpt_dir
|
| 59 |
-
)
|
| 60 |
else:
|
| 61 |
# Get 1000 documents then keep top_k filtered by year
|
| 62 |
-
base_retriever = BuildRetrieverDense(
|
| 63 |
-
compute_mode, top_k=1000, embedding_ckpt_dir=embedding_ckpt_dir
|
| 64 |
-
)
|
| 65 |
return TopKRetriever(
|
| 66 |
base_retriever=base_retriever,
|
| 67 |
top_k=top_k,
|
|
@@ -85,20 +66,16 @@ def BuildRetriever(
|
|
| 85 |
# Use floor (top_k // 2) and ceiling -(top_k // -2) to divide odd values of top_k
|
| 86 |
# https://stackoverflow.com/questions/14822184/is-there-a-ceiling-equivalent-of-operator-in-python
|
| 87 |
dense_retriever = BuildRetriever(
|
| 88 |
-
compute_mode,
|
| 89 |
"dense",
|
| 90 |
(top_k // 2),
|
| 91 |
start_year,
|
| 92 |
end_year,
|
| 93 |
-
embedding_ckpt_dir,
|
| 94 |
)
|
| 95 |
sparse_retriever = BuildRetriever(
|
| 96 |
-
compute_mode,
|
| 97 |
"sparse",
|
| 98 |
-(top_k // -2),
|
| 99 |
start_year,
|
| 100 |
end_year,
|
| 101 |
-
embedding_ckpt_dir,
|
| 102 |
)
|
| 103 |
ensemble_retriever = EnsembleRetriever(
|
| 104 |
retrievers=[dense_retriever, sparse_retriever], weights=[1, 1]
|
|
@@ -128,43 +105,19 @@ def BuildRetrieverSparse(top_k=6):
|
|
| 128 |
return retriever
|
| 129 |
|
| 130 |
|
| 131 |
-
def BuildRetrieverDense(compute_mode: str, top_k=6, embedding_ckpt_dir=None):
|
| 132 |
"""
|
| 133 |
Build dense retriever instance with ChromaDB vectorstore
|
| 134 |
|
| 135 |
Args:
|
| 136 |
-
compute_mode: Compute mode for embeddings (remote or local)
|
| 137 |
top_k: Number of documents to retrieve
|
| 138 |
-
embedding_ckpt_dir: Directory for embedding model checkpoint
|
| 139 |
"""
|
| 140 |
|
| 141 |
-
# Don't try to use local models without a GPU
|
| 142 |
-
if compute_mode == "local" and not torch.cuda.is_available():
|
| 143 |
-
raise Exception("Local embeddings selected without GPU")
|
| 144 |
-
|
| 145 |
# Define embedding model
|
| 146 |
-
|
| 147 |
-
embedding_function = OpenAIEmbeddings(model="text-embedding-3-small")
|
| 148 |
-
if compute_mode == "local":
|
| 149 |
-
# embedding_function = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5", show_progress=True)
|
| 150 |
-
# https://python.langchain.com/api_reference/community/embeddings/langchain_community.embeddings.huggingface.HuggingFaceBgeEmbeddings.html
|
| 151 |
-
model_kwargs = {
|
| 152 |
-
"device": "cuda",
|
| 153 |
-
"trust_remote_code": True,
|
| 154 |
-
}
|
| 155 |
-
encode_kwargs = {"normalize_embeddings": True}
|
| 156 |
-
# Use embedding model ID or checkpoint directory if given
|
| 157 |
-
id_or_dir = embedding_ckpt_dir if embedding_ckpt_dir else embedding_model_id
|
| 158 |
-
embedding_function = HuggingFaceBgeEmbeddings(
|
| 159 |
-
model_name=id_or_dir,
|
| 160 |
-
model_kwargs=model_kwargs,
|
| 161 |
-
encode_kwargs=encode_kwargs,
|
| 162 |
-
query_instruction="search_query:",
|
| 163 |
-
embed_instruction="search_document:",
|
| 164 |
-
)
|
| 165 |
# Create vector store
|
| 166 |
client_settings = chromadb.config.Settings(anonymized_telemetry=False)
|
| 167 |
-
persist_directory = f"{db_dir}/
|
| 168 |
vectorstore = Chroma(
|
| 169 |
collection_name="R-help",
|
| 170 |
embedding_function=embedding_function,
|
|
@@ -172,7 +125,7 @@ def BuildRetrieverDense(compute_mode: str, top_k=6, embedding_ckpt_dir=None):
|
|
| 172 |
persist_directory=persist_directory,
|
| 173 |
)
|
| 174 |
# The storage layer for the parent documents
|
| 175 |
-
file_store = f"{db_dir}/
|
| 176 |
byte_store = LocalFileStore(file_store)
|
| 177 |
# Text splitter for child documents
|
| 178 |
child_splitter = RecursiveCharacterTextSplitter(
|
|
|
|
| 1 |
# Main retriever modules
|
|
|
|
|
|
|
|
|
|
| 2 |
from langchain.retrievers import ParentDocumentRetriever, EnsembleRetriever
|
|
|
|
| 3 |
from langchain_core.retrievers import BaseRetriever, RetrieverLike
|
| 4 |
from langchain_core.callbacks import CallbackManagerForRetrieverRun
|
| 5 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 6 |
+
from langchain_community.document_loaders import TextLoader
|
| 7 |
+
from langchain_core.documents import Document
|
| 8 |
+
from langchain_openai import OpenAIEmbeddings
|
| 9 |
+
from langchain_chroma import Chroma
|
| 10 |
from typing import Any, Optional
|
| 11 |
import chromadb
|
|
|
|
| 12 |
import os
|
| 13 |
import re
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
# Local modules
|
| 16 |
from mods.bm25s_retriever import BM25SRetriever
|
| 17 |
from mods.file_system import LocalFileStore
|
|
|
|
| 19 |
# Database directory
|
| 20 |
db_dir = "db"
|
| 21 |
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
def BuildRetriever(
|
|
|
|
| 24 |
search_type: str = "hybrid",
|
| 25 |
top_k=6,
|
| 26 |
start_year=None,
|
| 27 |
end_year=None,
|
|
|
|
| 28 |
):
|
| 29 |
"""
|
| 30 |
Build retriever instance.
|
| 31 |
All retriever types are configured to return up to 6 documents for fair comparison in evals.
|
| 32 |
|
| 33 |
Args:
|
|
|
|
| 34 |
search_type: Type of search to use. Options: "dense", "sparse", "hybrid"
|
| 35 |
top_k: Number of documents to retrieve for "dense" and "sparse"
|
| 36 |
start_year: Start year (optional)
|
| 37 |
end_year: End year (optional)
|
|
|
|
| 38 |
"""
|
| 39 |
if search_type == "dense":
|
| 40 |
if not (start_year or end_year):
|
| 41 |
# No year filtering, so directly use base retriever
|
| 42 |
+
return BuildRetrieverDense(top_k=top_k)
|
|
|
|
|
|
|
| 43 |
else:
|
| 44 |
# Get 1000 documents then keep top_k filtered by year
|
| 45 |
+
base_retriever = BuildRetrieverDense(top_k=1000)
|
|
|
|
|
|
|
| 46 |
return TopKRetriever(
|
| 47 |
base_retriever=base_retriever,
|
| 48 |
top_k=top_k,
|
|
|
|
| 66 |
# Use floor (top_k // 2) and ceiling -(top_k // -2) to divide odd values of top_k
|
| 67 |
# https://stackoverflow.com/questions/14822184/is-there-a-ceiling-equivalent-of-operator-in-python
|
| 68 |
dense_retriever = BuildRetriever(
|
|
|
|
| 69 |
"dense",
|
| 70 |
(top_k // 2),
|
| 71 |
start_year,
|
| 72 |
end_year,
|
|
|
|
| 73 |
)
|
| 74 |
sparse_retriever = BuildRetriever(
|
|
|
|
| 75 |
"sparse",
|
| 76 |
-(top_k // -2),
|
| 77 |
start_year,
|
| 78 |
end_year,
|
|
|
|
| 79 |
)
|
| 80 |
ensemble_retriever = EnsembleRetriever(
|
| 81 |
retrievers=[dense_retriever, sparse_retriever], weights=[1, 1]
|
|
|
|
| 105 |
return retriever
|
| 106 |
|
| 107 |
|
| 108 |
+
def BuildRetrieverDense(top_k=6):
|
| 109 |
"""
|
| 110 |
Build dense retriever instance with ChromaDB vectorstore
|
| 111 |
|
| 112 |
Args:
|
|
|
|
| 113 |
top_k: Number of documents to retrieve
|
|
|
|
| 114 |
"""
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
# Define embedding model
|
| 117 |
+
embedding_function = OpenAIEmbeddings(model="text-embedding-3-small")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
# Create vector store
|
| 119 |
client_settings = chromadb.config.Settings(anonymized_telemetry=False)
|
| 120 |
+
persist_directory = f"{db_dir}/chroma"
|
| 121 |
vectorstore = Chroma(
|
| 122 |
collection_name="R-help",
|
| 123 |
embedding_function=embedding_function,
|
|
|
|
| 125 |
persist_directory=persist_directory,
|
| 126 |
)
|
| 127 |
# The storage layer for the parent documents
|
| 128 |
+
file_store = f"{db_dir}/file_store"
|
| 129 |
byte_store = LocalFileStore(file_store)
|
| 130 |
# Text splitter for child documents
|
| 131 |
child_splitter = RecursiveCharacterTextSplitter(
|
util.py
CHANGED
|
@@ -5,21 +5,6 @@ import os
|
|
| 5 |
import re
|
| 6 |
|
| 7 |
|
| 8 |
-
def get_collection(compute_mode):
|
| 9 |
-
"""
|
| 10 |
-
Returns the vectorstore collection.
|
| 11 |
-
|
| 12 |
-
Usage Examples:
|
| 13 |
-
# Number of child documents
|
| 14 |
-
collection = get_collection("remote")
|
| 15 |
-
len(collection["ids"])
|
| 16 |
-
# Number of parent documents (unique doc_ids)
|
| 17 |
-
len(set([m["doc_id"] for m in collection["metadatas"]]))
|
| 18 |
-
"""
|
| 19 |
-
retriever = BuildRetriever(compute_mode, "dense")
|
| 20 |
-
return retriever.vectorstore.get()
|
| 21 |
-
|
| 22 |
-
|
| 23 |
def get_sources():
|
| 24 |
"""
|
| 25 |
Return the source files indexed in the database, e.g. 'R-help/2024-April.txt'.
|
|
|
|
| 5 |
import re
|
| 6 |
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
def get_sources():
|
| 9 |
"""
|
| 10 |
Return the source files indexed in the database, e.g. 'R-help/2024-April.txt'.
|