Spaces:
Running
Running
jedick committed on
Commit ·
e1365aa
1
Parent(s): cf58ea1
Support collections and filtering by months
Browse files
- app.py +63 -34
- data.py +1 -3
- graph.py +14 -10
- index.py +23 -11
- main.py +27 -13
- prompts.py +30 -12
- retriever.py +134 -46
- util.py +17 -10
app.py
CHANGED
|
@@ -13,13 +13,15 @@ from util import get_sources, get_start_end_months
|
|
| 13 |
from data import download_data, extract_data
|
| 14 |
from main import openai_model
|
| 15 |
from graph import BuildGraph
|
| 16 |
-
from retriever import db_dir
|
| 17 |
|
| 18 |
# Set environment variables
|
| 19 |
load_dotenv(dotenv_path=".env", override=True)
|
| 20 |
# Hide BM25S progress bars
|
| 21 |
os.environ["DISABLE_TQDM"] = "true"
|
| 22 |
|
|
|
|
|
|
|
|
|
|
| 23 |
# Download and extract data if data directory is not present
|
| 24 |
if not os.path.isdir(db_dir):
|
| 25 |
print("Downloading data ... ", end="")
|
|
@@ -32,7 +34,8 @@ if not os.path.isdir(db_dir):
|
|
| 32 |
# Global setting for search type
|
| 33 |
search_type = "hybrid"
|
| 34 |
|
| 35 |
-
# Global
|
|
|
|
| 36 |
# https://www.gradio.app/guides/state-in-blocks
|
| 37 |
graph_instances = {}
|
| 38 |
|
|
@@ -86,7 +89,7 @@ def append_content(chunk_messages, history, thinking_about):
|
|
| 86 |
return history
|
| 87 |
|
| 88 |
|
| 89 |
-
def run_workflow(input, history, thread_id, session_hash):
|
| 90 |
"""The main function to run the chat workflow"""
|
| 91 |
|
| 92 |
# Get graph instance
|
|
@@ -97,6 +100,8 @@ def run_workflow(input, history, thread_id, session_hash):
|
|
| 97 |
chat_model = ChatOpenAI(model=openai_model, temperature=0)
|
| 98 |
graph_builder = BuildGraph(
|
| 99 |
chat_model,
|
|
|
|
|
|
|
| 100 |
search_type,
|
| 101 |
)
|
| 102 |
# Compile the graph with an in-memory checkpointer
|
|
@@ -106,7 +111,7 @@ def run_workflow(input, history, thread_id, session_hash):
|
|
| 106 |
graph_instances[session_hash] = graph
|
| 107 |
# ISO 8601 timestamp with local timezone information without microsecond
|
| 108 |
timestamp = datetime.now().replace(microsecond=0).isoformat()
|
| 109 |
-
print(f"{timestamp} - Set graph for session {session_hash}")
|
| 110 |
## Notify when model finishes loading
|
| 111 |
# gr.Success("Model loaded!", duration=4)
|
| 112 |
else:
|
|
@@ -148,7 +153,7 @@ def run_workflow(input, history, thread_id, session_hash):
|
|
| 148 |
if start_year or end_year:
|
| 149 |
content = f"{content} ({start_year or ''} - {end_year or ''})"
|
| 150 |
if "months" in args:
|
| 151 |
-
content = f"{content} {args['months']}"
|
| 152 |
history.append(
|
| 153 |
gr.ChatMessage(
|
| 154 |
role="assistant",
|
|
@@ -169,12 +174,12 @@ def run_workflow(input, history, thread_id, session_hash):
|
|
| 169 |
email_list = message.content.replace(
|
| 170 |
"### Retrieved Emails:\n\n", ""
|
| 171 |
).split("--- --- --- --- Next Email --- --- --- ---\n\n")
|
| 172 |
-
# Get the
|
| 173 |
-
month_list = [
|
|
|
|
|
|
|
| 174 |
# Format months (e.g. 2024-December) into text
|
| 175 |
-
month_text = (
|
| 176 |
-
", ".join(month_list).replace("R-help/", "").replace(".txt", "")
|
| 177 |
-
)
|
| 178 |
# Get the number of retrieved emails
|
| 179 |
n_emails = len(email_list)
|
| 180 |
title = f"🗎 Retrieved {n_emails} emails"
|
|
@@ -219,7 +224,7 @@ def run_workflow(input, history, thread_id, session_hash):
|
|
| 219 |
yield history, None, citations
|
| 220 |
|
| 221 |
|
| 222 |
-
def
|
| 223 |
"""Wrapper function to call run_workflow() with session_hash"""
|
| 224 |
input = args[0]
|
| 225 |
# Add session_hash to arguments
|
|
@@ -318,19 +323,20 @@ with gr.Blocks(
|
|
| 318 |
<!-- Get AI-powered answers about R programming backed by email retrieval. -->
|
| 319 |
## 🇷🤝💬 R-help-chat
|
| 320 |
|
| 321 |
-
**Search and chat with the [R-help
|
|
|
|
| 322 |
An LLM turns your question into a search query, including year ranges and months.
|
| 323 |
Retrieved emails are shown below the chatbot and are used by the LLM to generate an answer.
|
| 324 |
-
You can ask follow-up questions with the chat history as context.
|
| 325 |
Press the clear button (🗑) to clear the history and start a new chat.
|
| 326 |
*Privacy notice*: Data sharing with OpenAI is enabled.
|
| 327 |
"""
|
| 328 |
return intro
|
| 329 |
|
| 330 |
-
def get_info_text():
|
| 331 |
try:
|
| 332 |
# Get source files for each email and start and end months from database
|
| 333 |
-
sources = get_sources()
|
| 334 |
start, end = get_start_end_months(sources)
|
| 335 |
except:
|
| 336 |
# If database isn't ready, put in empty values
|
|
@@ -339,28 +345,26 @@ with gr.Blocks(
|
|
| 339 |
end = None
|
| 340 |
info_text = f"""
|
| 341 |
**Database:** {len(sources)} emails from {start} to {end}<br>
|
| 342 |
-
**Models:** {openai_model} and text-embedding-3-small<br>
|
| 343 |
**Features:** RAG, today's date, hybrid search (semantic + lexical), multiple retrievals, citations output, chat memory<br>
|
| 344 |
**Tech:** [OpenAI](https://openai.com/), [Chroma](https://www.trychroma.com/),
|
| 345 |
[BM25S](https://github.com/xhluca/bm25s), [LangGraph](https://www.langchain.com/langgraph), [Gradio](https://www.langchain.com/langgraph)<br>
|
|
|
|
| 346 |
🏠 **More info:** [R-help-chat GitHub repository](https://github.com/jedick/R-help-chat)
|
| 347 |
"""
|
| 348 |
return info_text
|
| 349 |
|
| 350 |
-
def get_example_questions(as_dataset=
|
| 351 |
"""Get example questions"""
|
| 352 |
questions = [
|
| 353 |
# "What is today's date?",
|
| 354 |
-
"Summarize emails from the most recent two months",
|
| 355 |
"Show me code examples using plotmath",
|
| 356 |
-
"
|
| 357 |
-
"Who reported installation problems in 2023-2024?",
|
| 358 |
]
|
| 359 |
|
| 360 |
# cf. https://github.com/gradio-app/gradio/pull/8745 for updating examples
|
| 361 |
return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
|
| 362 |
|
| 363 |
-
def get_multi_tool_questions(as_dataset=
|
| 364 |
"""Get multi-tool example questions"""
|
| 365 |
questions = [
|
| 366 |
"Differences between lapply and for loops",
|
|
@@ -369,7 +373,7 @@ with gr.Blocks(
|
|
| 369 |
|
| 370 |
return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
|
| 371 |
|
| 372 |
-
def get_multi_turn_questions(as_dataset=
|
| 373 |
"""Get multi-turn example questions"""
|
| 374 |
questions = [
|
| 375 |
"Lookup emails that reference bugs.r-project.org in 2025",
|
|
@@ -378,6 +382,15 @@ with gr.Blocks(
|
|
| 378 |
|
| 379 |
return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
|
| 380 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
with gr.Row():
|
| 382 |
# Left column: Intro, Compute, Chat
|
| 383 |
with gr.Column(scale=2):
|
|
@@ -385,10 +398,10 @@ with gr.Blocks(
|
|
| 385 |
with gr.Column(scale=4):
|
| 386 |
intro = gr.Markdown(get_intro_text())
|
| 387 |
with gr.Column(scale=1):
|
| 388 |
-
gr.Radio(
|
| 389 |
-
["
|
| 390 |
-
|
| 391 |
-
|
| 392 |
)
|
| 393 |
with gr.Group() as chat_interface:
|
| 394 |
chatbot.render()
|
|
@@ -401,23 +414,28 @@ with gr.Blocks(
|
|
| 401 |
# Right column: Info, Examples
|
| 402 |
with gr.Column(scale=1):
|
| 403 |
with gr.Accordion("ℹ️ App Info", open=True):
|
| 404 |
-
|
| 405 |
with gr.Accordion("💡 Examples", open=True):
|
| 406 |
# Add some helpful examples
|
| 407 |
example_questions = gr.Examples(
|
| 408 |
-
examples=get_example_questions(
|
| 409 |
inputs=[input],
|
| 410 |
-
label="
|
| 411 |
)
|
| 412 |
multi_tool_questions = gr.Examples(
|
| 413 |
-
examples=get_multi_tool_questions(
|
| 414 |
inputs=[input],
|
| 415 |
label="Multiple retrievals",
|
| 416 |
)
|
| 417 |
multi_turn_questions = gr.Examples(
|
| 418 |
-
examples=get_multi_turn_questions(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
inputs=[input],
|
| 420 |
-
label="
|
| 421 |
)
|
| 422 |
|
| 423 |
# Bottom row: retrieved emails and citations
|
|
@@ -458,10 +476,21 @@ with gr.Blocks(
|
|
| 458 |
# https://github.com/gradio-app/gradio/issues/9722
|
| 459 |
chatbot.clear(generate_thread_id, outputs=[thread_id], api_visibility="private")
|
| 460 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
input.submit(
|
| 462 |
# Submit input to the chatbot
|
| 463 |
-
|
| 464 |
-
[input, chatbot, thread_id],
|
| 465 |
[chatbot, retrieved_emails, citations_text],
|
| 466 |
api_visibility="private",
|
| 467 |
)
|
|
|
|
| 13 |
from data import download_data, extract_data
|
| 14 |
from main import openai_model
|
| 15 |
from graph import BuildGraph
|
|
|
|
| 16 |
|
| 17 |
# Set environment variables
|
| 18 |
load_dotenv(dotenv_path=".env", override=True)
|
| 19 |
# Hide BM25S progress bars
|
| 20 |
os.environ["DISABLE_TQDM"] = "true"
|
| 21 |
|
| 22 |
+
# Database directory
|
| 23 |
+
db_dir = "db"
|
| 24 |
+
|
| 25 |
# Download and extract data if data directory is not present
|
| 26 |
if not os.path.isdir(db_dir):
|
| 27 |
print("Downloading data ... ", end="")
|
|
|
|
| 34 |
# Global setting for search type
|
| 35 |
search_type = "hybrid"
|
| 36 |
|
| 37 |
+
# Global variable for LangChain graph
|
| 38 |
+
# Use dictionary to store user-specific instances
|
| 39 |
# https://www.gradio.app/guides/state-in-blocks
|
| 40 |
graph_instances = {}
|
| 41 |
|
|
|
|
| 89 |
return history
|
| 90 |
|
| 91 |
|
| 92 |
+
def run_workflow(input, collection, history, thread_id, session_hash):
|
| 93 |
"""The main function to run the chat workflow"""
|
| 94 |
|
| 95 |
# Get graph instance
|
|
|
|
| 100 |
chat_model = ChatOpenAI(model=openai_model, temperature=0)
|
| 101 |
graph_builder = BuildGraph(
|
| 102 |
chat_model,
|
| 103 |
+
db_dir,
|
| 104 |
+
collection,
|
| 105 |
search_type,
|
| 106 |
)
|
| 107 |
# Compile the graph with an in-memory checkpointer
|
|
|
|
| 111 |
graph_instances[session_hash] = graph
|
| 112 |
# ISO 8601 timestamp with local timezone information without microsecond
|
| 113 |
timestamp = datetime.now().replace(microsecond=0).isoformat()
|
| 114 |
+
print(f"{timestamp} - Set {collection} graph for session {session_hash}")
|
| 115 |
## Notify when model finishes loading
|
| 116 |
# gr.Success("Model loaded!", duration=4)
|
| 117 |
else:
|
|
|
|
| 153 |
if start_year or end_year:
|
| 154 |
content = f"{content} ({start_year or ''} - {end_year or ''})"
|
| 155 |
if "months" in args:
|
| 156 |
+
content = f"{content} {", ".join(args['months'])}"
|
| 157 |
history.append(
|
| 158 |
gr.ChatMessage(
|
| 159 |
role="assistant",
|
|
|
|
| 174 |
email_list = message.content.replace(
|
| 175 |
"### Retrieved Emails:\n\n", ""
|
| 176 |
).split("--- --- --- --- Next Email --- --- --- ---\n\n")
|
| 177 |
+
# Get the source file names (e.g. 2024-December.txt) for retrieved emails
|
| 178 |
+
month_list = [
|
| 179 |
+
os.path.basename(text.splitlines()[0]) for text in email_list
|
| 180 |
+
]
|
| 181 |
# Format months (e.g. 2024-December) into text
|
| 182 |
+
month_text = ", ".join(month_list).replace(".txt", "")
|
|
|
|
|
|
|
| 183 |
# Get the number of retrieved emails
|
| 184 |
n_emails = len(email_list)
|
| 185 |
title = f"🗎 Retrieved {n_emails} emails"
|
|
|
|
| 224 |
yield history, None, citations
|
| 225 |
|
| 226 |
|
| 227 |
+
def run_workflow_in_session(request: gr.Request, *args):
|
| 228 |
"""Wrapper function to call run_workflow() with session_hash"""
|
| 229 |
input = args[0]
|
| 230 |
# Add session_hash to arguments
|
|
|
|
| 323 |
<!-- Get AI-powered answers about R programming backed by email retrieval. -->
|
| 324 |
## 🇷🤝💬 R-help-chat
|
| 325 |
|
| 326 |
+
**Search and chat with the [R-help](https://stat.ethz.ch/pipermail/r-help/) and [R-devel](https://stat.ethz.ch/pipermail/r-devel/)
|
| 327 |
+
mailing list archives.**
|
| 328 |
An LLM turns your question into a search query, including year ranges and months.
|
| 329 |
Retrieved emails are shown below the chatbot and are used by the LLM to generate an answer.
|
| 330 |
+
You can ask follow-up questions with the chat history as context; changing the mailing list maintains history.
|
| 331 |
Press the clear button (🗑) to clear the history and start a new chat.
|
| 332 |
*Privacy notice*: Data sharing with OpenAI is enabled.
|
| 333 |
"""
|
| 334 |
return intro
|
| 335 |
|
| 336 |
+
def get_info_text(collection):
|
| 337 |
try:
|
| 338 |
# Get source files for each email and start and end months from database
|
| 339 |
+
sources = get_sources(db_dir, collection)
|
| 340 |
start, end = get_start_end_months(sources)
|
| 341 |
except:
|
| 342 |
# If database isn't ready, put in empty values
|
|
|
|
| 345 |
end = None
|
| 346 |
info_text = f"""
|
| 347 |
**Database:** {len(sources)} emails from {start} to {end}<br>
|
|
|
|
| 348 |
**Features:** RAG, today's date, hybrid search (semantic + lexical), multiple retrievals, citations output, chat memory<br>
|
| 349 |
**Tech:** [OpenAI](https://openai.com/), [Chroma](https://www.trychroma.com/),
|
| 350 |
[BM25S](https://github.com/xhluca/bm25s), [LangGraph](https://www.langchain.com/langgraph), [Gradio](https://www.langchain.com/langgraph)<br>
|
| 351 |
+
**Maintainer:** [Jeffrey Dick](mailto:j3ffdick@gmail.com) - feedback welcome!<br>
|
| 352 |
🏠 **More info:** [R-help-chat GitHub repository](https://github.com/jedick/R-help-chat)
|
| 353 |
"""
|
| 354 |
return info_text
|
| 355 |
|
| 356 |
+
def get_example_questions(as_dataset=False):
|
| 357 |
"""Get example questions"""
|
| 358 |
questions = [
|
| 359 |
# "What is today's date?",
|
|
|
|
| 360 |
"Show me code examples using plotmath",
|
| 361 |
+
"Summarize emails from the most recent two months",
|
|
|
|
| 362 |
]
|
| 363 |
|
| 364 |
# cf. https://github.com/gradio-app/gradio/pull/8745 for updating examples
|
| 365 |
return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
|
| 366 |
|
| 367 |
+
def get_multi_tool_questions(as_dataset=False):
|
| 368 |
"""Get multi-tool example questions"""
|
| 369 |
questions = [
|
| 370 |
"Differences between lapply and for loops",
|
|
|
|
| 373 |
|
| 374 |
return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
|
| 375 |
|
| 376 |
+
def get_multi_turn_questions(as_dataset=False):
|
| 377 |
"""Get multi-turn example questions"""
|
| 378 |
questions = [
|
| 379 |
"Lookup emails that reference bugs.r-project.org in 2025",
|
|
|
|
| 382 |
|
| 383 |
return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
|
| 384 |
|
| 385 |
+
def get_month_questions(as_dataset=False):
|
| 386 |
+
"""Get month example questions"""
|
| 387 |
+
questions = [
|
| 388 |
+
"Was there any discussion of ggplot2 in Q4 2025?",
|
| 389 |
+
"How about Q3?",
|
| 390 |
+
]
|
| 391 |
+
|
| 392 |
+
return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
|
| 393 |
+
|
| 394 |
with gr.Row():
|
| 395 |
# Left column: Intro, Compute, Chat
|
| 396 |
with gr.Column(scale=2):
|
|
|
|
| 398 |
with gr.Column(scale=4):
|
| 399 |
intro = gr.Markdown(get_intro_text())
|
| 400 |
with gr.Column(scale=1):
|
| 401 |
+
collection = gr.Radio(
|
| 402 |
+
["R-help", "R-devel"],
|
| 403 |
+
value="R-help",
|
| 404 |
+
label="Mailing List",
|
| 405 |
)
|
| 406 |
with gr.Group() as chat_interface:
|
| 407 |
chatbot.render()
|
|
|
|
| 414 |
# Right column: Info, Examples
|
| 415 |
with gr.Column(scale=1):
|
| 416 |
with gr.Accordion("ℹ️ App Info", open=True):
|
| 417 |
+
app_info = gr.Markdown(get_info_text(collection.value))
|
| 418 |
with gr.Accordion("💡 Examples", open=True):
|
| 419 |
# Add some helpful examples
|
| 420 |
example_questions = gr.Examples(
|
| 421 |
+
examples=get_example_questions(),
|
| 422 |
inputs=[input],
|
| 423 |
+
label="Basic examples",
|
| 424 |
)
|
| 425 |
multi_tool_questions = gr.Examples(
|
| 426 |
+
examples=get_multi_tool_questions(),
|
| 427 |
inputs=[input],
|
| 428 |
label="Multiple retrievals",
|
| 429 |
)
|
| 430 |
multi_turn_questions = gr.Examples(
|
| 431 |
+
examples=get_multi_turn_questions(),
|
| 432 |
+
inputs=[input],
|
| 433 |
+
label="Follow-up questions",
|
| 434 |
+
)
|
| 435 |
+
month_questions = gr.Examples(
|
| 436 |
+
examples=get_month_questions(),
|
| 437 |
inputs=[input],
|
| 438 |
+
label="Three-month periods",
|
| 439 |
)
|
| 440 |
|
| 441 |
# Bottom row: retrieved emails and citations
|
|
|
|
| 476 |
# https://github.com/gradio-app/gradio/issues/9722
|
| 477 |
chatbot.clear(generate_thread_id, outputs=[thread_id], api_visibility="private")
|
| 478 |
|
| 479 |
+
collection.change(
|
| 480 |
+
# We need to build a new graph if the collection changes
|
| 481 |
+
cleanup_graph
|
| 482 |
+
).then(
|
| 483 |
+
# Update the database stats in the app info box
|
| 484 |
+
get_info_text,
|
| 485 |
+
[collection],
|
| 486 |
+
[app_info],
|
| 487 |
+
api_name=False,
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
input.submit(
|
| 491 |
# Submit input to the chatbot
|
| 492 |
+
run_workflow_in_session,
|
| 493 |
+
[input, collection, chatbot, thread_id],
|
| 494 |
[chatbot, retrieved_emails, citations_text],
|
| 495 |
api_visibility="private",
|
| 496 |
)
|
data.py
CHANGED
|
@@ -45,9 +45,7 @@ def download_data():
|
|
| 45 |
|
| 46 |
if not os.path.exists("db.zip"):
|
| 47 |
# For S3 (need AWS_ACCESS_KEY_ID and AWS_ACCESS_KEY_SECRET)
|
| 48 |
-
|
| 49 |
-
# db_20250801a.zip: chromadb==0.6.3
|
| 50 |
-
download_file_from_bucket("r-help-chat", "db_20260102.zip", "db.zip")
|
| 51 |
## For Dropbox (shared file - key is in URL)
|
| 52 |
# shared_link = "https://www.dropbox.com/scl/fi/jx90g5lorpgkkyyzeurtc/db.zip?rlkey=wvqa3p9hdy4rmod1r8yf2am09&st=l9tsam56&dl=0"
|
| 53 |
# output_filename = "db.zip"
|
|
|
|
| 45 |
|
| 46 |
if not os.path.exists("db.zip"):
|
| 47 |
# For S3 (need AWS_ACCESS_KEY_ID and AWS_ACCESS_KEY_SECRET)
|
| 48 |
+
download_file_from_bucket("r-help-chat", "db_20260104.zip", "db.zip")
|
|
|
|
|
|
|
| 49 |
## For Dropbox (shared file - key is in URL)
|
| 50 |
# shared_link = "https://www.dropbox.com/scl/fi/jx90g5lorpgkkyyzeurtc/db.zip?rlkey=wvqa3p9hdy4rmod1r8yf2am09&st=l9tsam56&dl=0"
|
| 51 |
# output_filename = "db.zip"
|
graph.py
CHANGED
|
@@ -17,7 +17,9 @@ from prompts import query_prompt, answer_prompt
|
|
| 17 |
|
| 18 |
def BuildGraph(
|
| 19 |
chat_model,
|
| 20 |
-
|
|
|
|
|
|
|
| 21 |
top_k=6,
|
| 22 |
):
|
| 23 |
"""
|
|
@@ -25,7 +27,9 @@ def BuildGraph(
|
|
| 25 |
|
| 26 |
Args:
|
| 27 |
chat_model: LangChain chat model
|
| 28 |
-
|
|
|
|
|
|
|
| 29 |
top_k: number of documents to retrieve
|
| 30 |
|
| 31 |
Based on:
|
|
@@ -64,7 +68,7 @@ def BuildGraph(
|
|
| 64 |
search_query: str,
|
| 65 |
start_year: Optional[int] = None,
|
| 66 |
end_year: Optional[int] = None,
|
| 67 |
-
months: Optional[str] = None,
|
| 68 |
) -> str:
|
| 69 |
"""
|
| 70 |
Retrieve emails related to a search query from the R-help mailing list archives.
|
|
@@ -75,23 +79,23 @@ def BuildGraph(
|
|
| 75 |
search_query (str): Search query
|
| 76 |
start_year (int, optional): Starting year for emails
|
| 77 |
end_year (int, optional): Ending year for emails
|
| 78 |
-
months (str, optional):
|
| 79 |
"""
|
| 80 |
retriever = BuildRetriever(
|
|
|
|
|
|
|
| 81 |
search_type,
|
| 82 |
top_k,
|
| 83 |
start_year,
|
| 84 |
end_year,
|
|
|
|
| 85 |
)
|
| 86 |
-
# For now, just add the months to the search query
|
| 87 |
-
if months:
|
| 88 |
-
search_query = " ".join([search_query, months])
|
| 89 |
# If the search query is empty, use the years
|
| 90 |
if not search_query:
|
| 91 |
search_query = " ".join([search_query, start_year, end_year])
|
| 92 |
retrieved_docs = retriever.invoke(search_query)
|
| 93 |
serialized = "\n\n--- --- --- --- Next Email --- --- --- ---".join(
|
| 94 |
-
# Add file name (e.g. R-help/2024-December.txt) from source key
|
| 95 |
"\n\n" + doc.metadata["source"] + doc.page_content
|
| 96 |
for doc in retrieved_docs
|
| 97 |
)
|
|
@@ -122,14 +126,14 @@ def BuildGraph(
|
|
| 122 |
|
| 123 |
def query(state: MessagesState):
|
| 124 |
"""Queries the retriever with the chat model"""
|
| 125 |
-
messages = [SystemMessage(query_prompt())] + state["messages"]
|
| 126 |
response = query_model.invoke(messages)
|
| 127 |
|
| 128 |
return {"messages": response}
|
| 129 |
|
| 130 |
def answer(state: MessagesState):
|
| 131 |
"""Generates an answer with the chat model"""
|
| 132 |
-
messages = [SystemMessage(answer_prompt())] + state["messages"]
|
| 133 |
response = answer_model.invoke(messages)
|
| 134 |
|
| 135 |
return {"messages": response}
|
|
|
|
| 17 |
|
| 18 |
def BuildGraph(
|
| 19 |
chat_model,
|
| 20 |
+
db_dir,
|
| 21 |
+
collection,
|
| 22 |
+
search_type="hybrid",
|
| 23 |
top_k=6,
|
| 24 |
):
|
| 25 |
"""
|
|
|
|
| 27 |
|
| 28 |
Args:
|
| 29 |
chat_model: LangChain chat model
|
| 30 |
+
db_dir: Database directory
|
| 31 |
+
collection: Email collection
|
| 32 |
+
search_type: dense, sparse, or hybrid
|
| 33 |
top_k: number of documents to retrieve
|
| 34 |
|
| 35 |
Based on:
|
|
|
|
| 68 |
search_query: str,
|
| 69 |
start_year: Optional[int] = None,
|
| 70 |
end_year: Optional[int] = None,
|
| 71 |
+
months: Optional[list[str]] = None,
|
| 72 |
) -> str:
|
| 73 |
"""
|
| 74 |
Retrieve emails related to a search query from the R-help mailing list archives.
|
|
|
|
| 79 |
search_query (str): Search query
|
| 80 |
start_year (int, optional): Starting year for emails
|
| 81 |
end_year (int, optional): Ending year for emails
|
| 82 |
+
months (list(str), optional): List of one or more months (three-letter abbreviations)
|
| 83 |
"""
|
| 84 |
retriever = BuildRetriever(
|
| 85 |
+
db_dir,
|
| 86 |
+
collection,
|
| 87 |
search_type,
|
| 88 |
top_k,
|
| 89 |
start_year,
|
| 90 |
end_year,
|
| 91 |
+
months,
|
| 92 |
)
|
|
|
|
|
|
|
|
|
|
| 93 |
# If the search query is empty, use the years
|
| 94 |
if not search_query:
|
| 95 |
search_query = " ".join([search_query, start_year, end_year])
|
| 96 |
retrieved_docs = retriever.invoke(search_query)
|
| 97 |
serialized = "\n\n--- --- --- --- Next Email --- --- --- ---".join(
|
| 98 |
+
# Add source file name (e.g. R-help/2024-December.txt) from source key
|
| 99 |
"\n\n" + doc.metadata["source"] + doc.page_content
|
| 100 |
for doc in retrieved_docs
|
| 101 |
)
|
|
|
|
| 126 |
|
| 127 |
def query(state: MessagesState):
|
| 128 |
"""Queries the retriever with the chat model"""
|
| 129 |
+
messages = [SystemMessage(query_prompt(db_dir, collection))] + state["messages"]
|
| 130 |
response = query_model.invoke(messages)
|
| 131 |
|
| 132 |
return {"messages": response}
|
| 133 |
|
| 134 |
def answer(state: MessagesState):
|
| 135 |
"""Generates an answer with the chat model"""
|
| 136 |
+
messages = [SystemMessage(answer_prompt(collection))] + state["messages"]
|
| 137 |
response = answer_model.invoke(messages)
|
| 138 |
|
| 139 |
return {"messages": response}
|
index.py
CHANGED
|
@@ -3,18 +3,21 @@ from langchain_community.document_loaders import TextLoader
|
|
| 3 |
from datetime import datetime
|
| 4 |
import tempfile
|
| 5 |
import os
|
|
|
|
| 6 |
|
| 7 |
# Local modules
|
| 8 |
-
from retriever import BuildRetriever
|
| 9 |
from mods.bm25s_retriever import BM25SRetriever
|
| 10 |
|
| 11 |
|
| 12 |
-
def ProcessFile(file_path,
|
| 13 |
"""
|
| 14 |
Wrapper function to process file for dense or sparse search
|
| 15 |
|
| 16 |
Args:
|
| 17 |
file_path: File to process
|
|
|
|
|
|
|
| 18 |
search_type: Type of search to use. Options: "dense", "sparse"
|
| 19 |
"""
|
| 20 |
|
|
@@ -65,10 +68,10 @@ def ProcessFile(file_path, search_type: str = "dense"):
|
|
| 65 |
try:
|
| 66 |
if search_type == "sparse":
|
| 67 |
# Handle sparse search with BM25
|
| 68 |
-
ProcessFileSparse(truncated_temp_file, file_path)
|
| 69 |
elif search_type == "dense":
|
| 70 |
# Handle dense search with ChromaDB
|
| 71 |
-
ProcessFileDense(truncated_temp_file, file_path)
|
| 72 |
else:
|
| 73 |
raise ValueError(f"Unsupported search type: {search_type}")
|
| 74 |
finally:
|
|
@@ -80,17 +83,25 @@ def ProcessFile(file_path, search_type: str = "dense"):
|
|
| 80 |
pass
|
| 81 |
|
| 82 |
|
| 83 |
-
def ProcessFileDense(cleaned_temp_file, file_path):
|
| 84 |
"""
|
| 85 |
Process file for dense vector search using ChromaDB
|
| 86 |
"""
|
| 87 |
# Get a retriever instance
|
| 88 |
-
retriever = BuildRetriever("dense")
|
| 89 |
# Load cleaned text file
|
| 90 |
loader = TextLoader(cleaned_temp_file)
|
| 91 |
documents = loader.load()
|
| 92 |
# Use original file path for "source" key in metadata
|
| 93 |
documents[0].metadata["source"] = file_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
# Add file timestamp to metadata
|
| 95 |
mod_time = os.path.getmtime(file_path)
|
| 96 |
timestamp = datetime.fromtimestamp(mod_time).isoformat()
|
|
@@ -113,7 +124,7 @@ def ProcessFileDense(cleaned_temp_file, file_path):
|
|
| 113 |
retriever.add_documents(documents_batch)
|
| 114 |
|
| 115 |
|
| 116 |
-
def ProcessFileSparse(cleaned_temp_file, file_path):
|
| 117 |
"""
|
| 118 |
Process file for sparse search using BM25
|
| 119 |
"""
|
|
@@ -126,18 +137,19 @@ def ProcessFileSparse(cleaned_temp_file, file_path):
|
|
| 126 |
splitter = RecursiveCharacterTextSplitter(
|
| 127 |
separators=["\n\n\nFrom"], chunk_size=1, chunk_overlap=0
|
| 128 |
)
|
| 129 |
-
## Using 'EmailFrom' as the separator (requires preprocessing)
|
| 130 |
-
# splitter = RecursiveCharacterTextSplitter(separators=["EmailFrom"])
|
| 131 |
emails = splitter.split_documents(documents)
|
| 132 |
|
| 133 |
-
#
|
| 134 |
for email in emails:
|
|
|
|
| 135 |
email.metadata["source"] = file_path
|
|
|
|
|
|
|
| 136 |
|
| 137 |
# Create or update BM25 index
|
| 138 |
try:
|
| 139 |
# Update BM25 index if it exists
|
| 140 |
-
bm25_persist_directory =
|
| 141 |
retriever = BM25SRetriever.from_persisted_directory(bm25_persist_directory)
|
| 142 |
# Get new emails - ones which have not been indexed
|
| 143 |
new_emails = [email for email in emails if email not in retriever.docs]
|
|
|
|
| 3 |
from datetime import datetime
|
| 4 |
import tempfile
|
| 5 |
import os
|
| 6 |
+
import re
|
| 7 |
|
| 8 |
# Local modules
|
| 9 |
+
from retriever import BuildRetriever
|
| 10 |
from mods.bm25s_retriever import BM25SRetriever
|
| 11 |
|
| 12 |
|
| 13 |
+
def ProcessFile(file_path, db_dir, collection, search_type):
|
| 14 |
"""
|
| 15 |
Wrapper function to process file for dense or sparse search
|
| 16 |
|
| 17 |
Args:
|
| 18 |
file_path: File to process
|
| 19 |
+
db_dir: Database directory
|
| 20 |
+
collection: Email collection
|
| 21 |
search_type: Type of search to use. Options: "dense", "sparse"
|
| 22 |
"""
|
| 23 |
|
|
|
|
| 68 |
try:
|
| 69 |
if search_type == "sparse":
|
| 70 |
# Handle sparse search with BM25
|
| 71 |
+
ProcessFileSparse(truncated_temp_file, file_path, db_dir, collection)
|
| 72 |
elif search_type == "dense":
|
| 73 |
# Handle dense search with ChromaDB
|
| 74 |
+
ProcessFileDense(truncated_temp_file, file_path, db_dir, collection)
|
| 75 |
else:
|
| 76 |
raise ValueError(f"Unsupported search type: {search_type}")
|
| 77 |
finally:
|
|
|
|
| 83 |
pass
|
| 84 |
|
| 85 |
|
| 86 |
+
def ProcessFileDense(cleaned_temp_file, file_path, db_dir, collection):
|
| 87 |
"""
|
| 88 |
Process file for dense vector search using ChromaDB
|
| 89 |
"""
|
| 90 |
# Get a retriever instance
|
| 91 |
+
retriever = BuildRetriever(db_dir, collection, "dense")
|
| 92 |
# Load cleaned text file
|
| 93 |
loader = TextLoader(cleaned_temp_file)
|
| 94 |
documents = loader.load()
|
| 95 |
# Use original file path for "source" key in metadata
|
| 96 |
documents[0].metadata["source"] = file_path
|
| 97 |
+
# Add year and month to metadata
|
| 98 |
+
filename = os.path.basename(file_path)
|
| 99 |
+
pattern = re.compile(r"(\d{4})-([A-Za-z]+)\.txt")
|
| 100 |
+
match = pattern.match(filename)
|
| 101 |
+
year = int(match.group(1))
|
| 102 |
+
month = match.group(2)
|
| 103 |
+
documents[0].metadata["year"] = year
|
| 104 |
+
documents[0].metadata["month"] = month
|
| 105 |
# Add file timestamp to metadata
|
| 106 |
mod_time = os.path.getmtime(file_path)
|
| 107 |
timestamp = datetime.fromtimestamp(mod_time).isoformat()
|
|
|
|
| 124 |
retriever.add_documents(documents_batch)
|
| 125 |
|
| 126 |
|
| 127 |
+
def ProcessFileSparse(cleaned_temp_file, file_path, db_dir, collection):
|
| 128 |
"""
|
| 129 |
Process file for sparse search using BM25
|
| 130 |
"""
|
|
|
|
| 137 |
splitter = RecursiveCharacterTextSplitter(
|
| 138 |
separators=["\n\n\nFrom"], chunk_size=1, chunk_overlap=0
|
| 139 |
)
|
|
|
|
|
|
|
| 140 |
emails = splitter.split_documents(documents)
|
| 141 |
|
| 142 |
+
# Add metadata keys
|
| 143 |
for email in emails:
|
| 144 |
+
# Original file path, e.g. "R-help/2025-December.txt"
|
| 145 |
email.metadata["source"] = file_path
|
| 146 |
+
# Collection name, e.g. "R-help"
|
| 147 |
+
email.metadata["collection"] = collection
|
| 148 |
|
| 149 |
# Create or update BM25 index
|
| 150 |
try:
|
| 151 |
# Update BM25 index if it exists
|
| 152 |
+
bm25_persist_directory = os.path.join(db_dir, collection, "bm25")
|
| 153 |
retriever = BM25SRetriever.from_persisted_directory(bm25_persist_directory)
|
| 154 |
# Get new emails - ones which have not been indexed
|
| 155 |
new_emails = [email for email in emails if email not in retriever.docs]
|
main.py
CHANGED
|
@@ -13,7 +13,7 @@ import ast
|
|
| 13 |
import os
|
| 14 |
|
| 15 |
# Local modules
|
| 16 |
-
from retriever import BuildRetriever
|
| 17 |
from prompts import answer_prompt
|
| 18 |
from index import ProcessFile
|
| 19 |
from graph import BuildGraph
|
|
@@ -38,29 +38,33 @@ httpx_logger = logging.getLogger("httpx")
|
|
| 38 |
httpx_logger.setLevel(logging.WARNING)
|
| 39 |
|
| 40 |
|
| 41 |
-
def
|
| 42 |
"""
|
| 43 |
Update vector store and sparse index for files in a directory, only adding new or updated files
|
| 44 |
|
| 45 |
Args:
|
| 46 |
-
|
|
|
|
| 47 |
|
| 48 |
Usage example:
|
| 49 |
-
|
| 50 |
"""
|
| 51 |
|
| 52 |
# TODO: use UUID to process only changed documents
|
| 53 |
# https://stackoverflow.com/questions/76265631/chromadb-add-single-document-only-if-it-doesnt-exist
|
| 54 |
|
|
|
|
|
|
|
|
|
|
| 55 |
# Get a dense retriever instance
|
| 56 |
-
retriever = BuildRetriever("dense")
|
| 57 |
|
| 58 |
# List all text files in target directory
|
| 59 |
-
file_paths = glob.glob(f"{
|
| 60 |
for file_path in file_paths:
|
| 61 |
|
| 62 |
# Process file for sparse search (BM25S)
|
| 63 |
-
ProcessFile(file_path, "sparse")
|
| 64 |
|
| 65 |
# Logic for dense search: skip file if already indexed
|
| 66 |
# Look for existing embeddings for this file
|
|
@@ -90,7 +94,7 @@ def ProcessDirectory(path):
|
|
| 90 |
update_file = True
|
| 91 |
|
| 92 |
if add_file:
|
| 93 |
-
ProcessFile(file_path, "dense")
|
| 94 |
|
| 95 |
if update_file:
|
| 96 |
print(f"Chroma: updated embeddings for {file_path}")
|
|
@@ -101,7 +105,7 @@ def ProcessDirectory(path):
|
|
| 101 |
]
|
| 102 |
files_to_keep = list(set(used_doc_ids))
|
| 103 |
# Get all files in the file store
|
| 104 |
-
file_store =
|
| 105 |
all_files = os.listdir(file_store)
|
| 106 |
# Iterate through the files and delete those not in the list
|
| 107 |
for file in all_files:
|
|
@@ -115,7 +119,9 @@ def ProcessDirectory(path):
|
|
| 115 |
|
| 116 |
|
| 117 |
def RunChain(
|
| 118 |
-
query,
|
|
|
|
|
|
|
| 119 |
search_type: str = "hybrid",
|
| 120 |
):
|
| 121 |
"""
|
|
@@ -123,14 +129,16 @@ def RunChain(
|
|
| 123 |
|
| 124 |
Args:
|
| 125 |
query: User's query
|
|
|
|
|
|
|
| 126 |
search_type: Type of search to use. Options: "dense", "sparse", or "hybrid"
|
| 127 |
|
| 128 |
Example:
|
| 129 |
-
RunChain("What R functions are discussed?")
|
| 130 |
"""
|
| 131 |
|
| 132 |
# Get retriever instance
|
| 133 |
-
retriever = BuildRetriever(search_type)
|
| 134 |
|
| 135 |
if retriever is None:
|
| 136 |
return "No retriever available. Please process some documents first."
|
|
@@ -139,7 +147,7 @@ def RunChain(
|
|
| 139 |
chat_model = ChatOpenAI(model=openai_model, temperature=0)
|
| 140 |
|
| 141 |
# Get system prompt
|
| 142 |
-
system_prompt = answer_prompt()
|
| 143 |
|
| 144 |
# Create a prompt template
|
| 145 |
system_template = ChatPromptTemplate.from_messages([SystemMessage(system_prompt)])
|
|
@@ -170,6 +178,8 @@ def RunChain(
|
|
| 170 |
|
| 171 |
def RunGraph(
|
| 172 |
query: str,
|
|
|
|
|
|
|
| 173 |
search_type: str = "hybrid",
|
| 174 |
top_k: int = 6,
|
| 175 |
thread_id=None,
|
|
@@ -178,6 +188,8 @@ def RunGraph(
|
|
| 178 |
|
| 179 |
Args:
|
| 180 |
query: User query to start the chat
|
|
|
|
|
|
|
| 181 |
search_type: Type of search to use. Options: "dense", "sparse", or "hybrid"
|
| 182 |
top_k: Number of documents to retrieve
|
| 183 |
thread_id: Thread ID for memory (optional)
|
|
@@ -191,6 +203,8 @@ def RunGraph(
|
|
| 191 |
# Build the graph
|
| 192 |
graph_builder = BuildGraph(
|
| 193 |
chat_model,
|
|
|
|
|
|
|
| 194 |
search_type,
|
| 195 |
top_k,
|
| 196 |
)
|
|
|
|
| 13 |
import os
|
| 14 |
|
| 15 |
# Local modules
|
| 16 |
+
from retriever import BuildRetriever
|
| 17 |
from prompts import answer_prompt
|
| 18 |
from index import ProcessFile
|
| 19 |
from graph import BuildGraph
|
|
|
|
| 38 |
httpx_logger.setLevel(logging.WARNING)
|
| 39 |
|
| 40 |
|
| 41 |
+
def ProcessCollection(email_dir, db_dir):
|
| 42 |
"""
|
| 43 |
Update vector store and sparse index for files in a directory, only adding new or updated files
|
| 44 |
|
| 45 |
Args:
|
| 46 |
+
email_dir: Email directory to process
|
| 47 |
+
db_dir: Database directory
|
| 48 |
|
| 49 |
Usage example:
|
| 50 |
+
ProcessCollection("R-help", "db")
|
| 51 |
"""
|
| 52 |
|
| 53 |
# TODO: use UUID to process only changed documents
|
| 54 |
# https://stackoverflow.com/questions/76265631/chromadb-add-single-document-only-if-it-doesnt-exist
|
| 55 |
|
| 56 |
+
# Get last part of path
|
| 57 |
+
# https://stackoverflow.com/questions/3925096/how-to-get-only-the-last-part-of-a-path-in-python
|
| 58 |
+
collection = os.path.basename(os.path.normpath(email_dir))
|
| 59 |
# Get a dense retriever instance
|
| 60 |
+
retriever = BuildRetriever(db_dir, collection, "dense")
|
| 61 |
|
| 62 |
# List all text files in target directory
|
| 63 |
+
file_paths = glob.glob(f"{email_dir}/*.txt")
|
| 64 |
for file_path in file_paths:
|
| 65 |
|
| 66 |
# Process file for sparse search (BM25S)
|
| 67 |
+
ProcessFile(file_path, db_dir, collection, "sparse")
|
| 68 |
|
| 69 |
# Logic for dense search: skip file if already indexed
|
| 70 |
# Look for existing embeddings for this file
|
|
|
|
| 94 |
update_file = True
|
| 95 |
|
| 96 |
if add_file:
|
| 97 |
+
ProcessFile(file_path, db_dir, collection, "dense")
|
| 98 |
|
| 99 |
if update_file:
|
| 100 |
print(f"Chroma: updated embeddings for {file_path}")
|
|
|
|
| 105 |
]
|
| 106 |
files_to_keep = list(set(used_doc_ids))
|
| 107 |
# Get all files in the file store
|
| 108 |
+
file_store = os.path.join(db_dir, collection, "file_store")
|
| 109 |
all_files = os.listdir(file_store)
|
| 110 |
# Iterate through the files and delete those not in the list
|
| 111 |
for file in all_files:
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
def RunChain(
|
| 122 |
+
query: str,
|
| 123 |
+
db_dir: str,
|
| 124 |
+
collection: str,
|
| 125 |
search_type: str = "hybrid",
|
| 126 |
):
|
| 127 |
"""
|
|
|
|
| 129 |
|
| 130 |
Args:
|
| 131 |
query: User's query
|
| 132 |
+
db_dir: Database directory
|
| 133 |
+
collection: Email collection
|
| 134 |
search_type: Type of search to use. Options: "dense", "sparse", or "hybrid"
|
| 135 |
|
| 136 |
Example:
|
| 137 |
+
RunChain("What R functions are discussed?", "db", "R-help")
|
| 138 |
"""
|
| 139 |
|
| 140 |
# Get retriever instance
|
| 141 |
+
retriever = BuildRetriever(db_dir, collection, search_type)
|
| 142 |
|
| 143 |
if retriever is None:
|
| 144 |
return "No retriever available. Please process some documents first."
|
|
|
|
| 147 |
chat_model = ChatOpenAI(model=openai_model, temperature=0)
|
| 148 |
|
| 149 |
# Get system prompt
|
| 150 |
+
system_prompt = answer_prompt(collection)
|
| 151 |
|
| 152 |
# Create a prompt template
|
| 153 |
system_template = ChatPromptTemplate.from_messages([SystemMessage(system_prompt)])
|
|
|
|
| 178 |
|
| 179 |
def RunGraph(
|
| 180 |
query: str,
|
| 181 |
+
db_dir: str,
|
| 182 |
+
collection: str,
|
| 183 |
search_type: str = "hybrid",
|
| 184 |
top_k: int = 6,
|
| 185 |
thread_id=None,
|
|
|
|
| 188 |
|
| 189 |
Args:
|
| 190 |
query: User query to start the chat
|
| 191 |
+
db_dir: Database directory
|
| 192 |
+
collection: Email collection
|
| 193 |
search_type: Type of search to use. Options: "dense", "sparse", or "hybrid"
|
| 194 |
top_k: Number of documents to retrieve
|
| 195 |
thread_id: Thread ID for memory (optional)
|
|
|
|
| 203 |
# Build the graph
|
| 204 |
graph_builder = BuildGraph(
|
| 205 |
chat_model,
|
| 206 |
+
db_dir,
|
| 207 |
+
collection,
|
| 208 |
search_type,
|
| 209 |
top_k,
|
| 210 |
)
|
prompts.py
CHANGED
|
@@ -12,15 +12,28 @@ def check_prompt(prompt):
|
|
| 12 |
return prompt
|
| 13 |
|
| 14 |
|
| 15 |
-
def query_prompt():
|
| 16 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# Get start and end months from database
|
| 19 |
-
start, end = get_start_end_months(get_sources())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
prompt = (
|
| 22 |
f"Today Date: {date.today()}. "
|
| 23 |
-
"You are a
|
| 24 |
"Write a search query to retrieve emails relevant to the user's question. "
|
| 25 |
"Do not answer the user's question and do not ask the user for more information. "
|
| 26 |
# gpt-4o-mini thinks last two months aren't available with this: "Emails from from {start} to {end} are available for retrieval. "
|
|
@@ -29,31 +42,36 @@ def query_prompt():
|
|
| 29 |
"Always use retrieve_emails with a non-empty query string for search_query. "
|
| 30 |
"For general summaries, use retrieve_emails(search_query='R'). "
|
| 31 |
"For questions about years, use retrieve_emails(search_query=<query>, start_year=, end_year=). "
|
| 32 |
-
"For questions about months, use 3-letter abbreviations (Jan...Dec) for the '
|
| 33 |
"Use all previous messages as context to formulate your search query. " # Gemma
|
| 34 |
"You should always retrieve more emails based on context and the most recent question. " # Qwen
|
| 35 |
-
|
| 36 |
-
# "You must perform the search yourself. Do not tell the user how to retrieve emails. " # Qwen
|
| 37 |
-
# "Do not use your memory or knowledge to answer the user's question. Only retrieve emails based on the user's question. " # Qwen
|
| 38 |
-
# "If you decide not to retrieve emails, tell the user why and suggest how to improve their question to chat with the R-help mailing list. "
|
| 39 |
)
|
| 40 |
prompt = check_prompt(prompt)
|
| 41 |
|
| 42 |
return prompt
|
| 43 |
|
| 44 |
|
| 45 |
-
def answer_prompt():
|
| 46 |
"""Return system prompt for answer step"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
prompt = (
|
| 48 |
f"Today Date: {date.today()}. "
|
| 49 |
-
"You are a helpful chatbot
|
| 50 |
"Summarize the retrieved emails to answer the user's question or query. "
|
| 51 |
"If any of the retrieved emails are irrelevant (e.g. wrong dates), then do not use them. "
|
| 52 |
"Tell the user if there are no retrieved emails or if you are unable to answer the question based on the information in the emails. "
|
| 53 |
"Do not give an answer based on your own knowledge or memory, and do not include examples that aren't based on the retrieved emails. "
|
| 54 |
"Example: For a question about using lm(), take examples of lm() from the retrieved emails to answer the user's question. "
|
| 55 |
# "Do not respond with packages that are only listed under sessionInfo, session info, or other attached packages. "
|
| 56 |
-
"Summarize the content of the emails rather than copying the headers. " # Qwen
|
| 57 |
"You must include inline citations (email senders and dates) in each part of your response. "
|
| 58 |
"Only answer general questions about R if the answer is in the retrieved emails. "
|
| 59 |
"Only include URLs if they were used by human authors (not in email headers), and do not modify any URLs. " # Qwen, Gemma
|
|
|
|
| 12 |
return prompt
|
| 13 |
|
| 14 |
|
| 15 |
+
def query_prompt(db_dir, collection):
|
| 16 |
+
"""
|
| 17 |
+
Return system prompt for query step
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
db_dir: Database directory
|
| 21 |
+
collection: Email collection
|
| 22 |
+
"""
|
| 23 |
|
| 24 |
# Get start and end months from database
|
| 25 |
+
start, end = get_start_end_months(get_sources(db_dir, collection))
|
| 26 |
+
# Use appropriate list topic
|
| 27 |
+
if collection == "R-help":
|
| 28 |
+
topic = "R programming"
|
| 29 |
+
elif collection == "R-devel":
|
| 30 |
+
topic = "R development"
|
| 31 |
+
elif collection == "R-package-devel":
|
| 32 |
+
topic = "R package development"
|
| 33 |
|
| 34 |
prompt = (
|
| 35 |
f"Today Date: {date.today()}. "
|
| 36 |
+
f"You are a search assistant for retrieving information about {topic} from the {collection} mailing list archives. "
|
| 37 |
"Write a search query to retrieve emails relevant to the user's question. "
|
| 38 |
"Do not answer the user's question and do not ask the user for more information. "
|
| 39 |
# gpt-4o-mini thinks last two months aren't available with this: "Emails from from {start} to {end} are available for retrieval. "
|
|
|
|
| 42 |
"Always use retrieve_emails with a non-empty query string for search_query. "
|
| 43 |
"For general summaries, use retrieve_emails(search_query='R'). "
|
| 44 |
"For questions about years, use retrieve_emails(search_query=<query>, start_year=, end_year=). "
|
| 45 |
+
"For questions about months, use 3-letter abbreviations (Jan...Dec) for the 'months' argument. "
|
| 46 |
"Use all previous messages as context to formulate your search query. " # Gemma
|
| 47 |
"You should always retrieve more emails based on context and the most recent question. " # Qwen
|
| 48 |
+
f"If you decide not to retrieve emails, tell the user how to improve their question to search the {collection} mailing list. "
|
|
|
|
|
|
|
|
|
|
| 49 |
)
|
| 50 |
prompt = check_prompt(prompt)
|
| 51 |
|
| 52 |
return prompt
|
| 53 |
|
| 54 |
|
| 55 |
+
def answer_prompt(collection):
|
| 56 |
"""Return system prompt for answer step"""
|
| 57 |
+
|
| 58 |
+
# Use appropriate list topic
|
| 59 |
+
if collection == "R-help":
|
| 60 |
+
topic = "R programming"
|
| 61 |
+
elif collection == "R-devel":
|
| 62 |
+
topic = "R development"
|
| 63 |
+
elif collection == "R-package-devel":
|
| 64 |
+
topic = "R package development"
|
| 65 |
+
|
| 66 |
prompt = (
|
| 67 |
f"Today Date: {date.today()}. "
|
| 68 |
+
f"You are a helpful chatbot that can answer questions about {topic} based on the {collection} mailing list archives. "
|
| 69 |
"Summarize the retrieved emails to answer the user's question or query. "
|
| 70 |
"If any of the retrieved emails are irrelevant (e.g. wrong dates), then do not use them. "
|
| 71 |
"Tell the user if there are no retrieved emails or if you are unable to answer the question based on the information in the emails. "
|
| 72 |
"Do not give an answer based on your own knowledge or memory, and do not include examples that aren't based on the retrieved emails. "
|
| 73 |
"Example: For a question about using lm(), take examples of lm() from the retrieved emails to answer the user's question. "
|
| 74 |
# "Do not respond with packages that are only listed under sessionInfo, session info, or other attached packages. "
|
|
|
|
| 75 |
"You must include inline citations (email senders and dates) in each part of your response. "
|
| 76 |
"Only answer general questions about R if the answer is in the retrieved emails. "
|
| 77 |
"Only include URLs if they were used by human authors (not in email headers), and do not modify any URLs. " # Qwen, Gemma
|
retriever.py
CHANGED
|
@@ -11,54 +11,69 @@ from typing import Any, Optional
|
|
| 11 |
import chromadb
|
| 12 |
import os
|
| 13 |
import re
|
|
|
|
| 14 |
|
| 15 |
# Local modules
|
| 16 |
from mods.bm25s_retriever import BM25SRetriever
|
| 17 |
from mods.file_system import LocalFileStore
|
| 18 |
-
|
| 19 |
-
# Database directory
|
| 20 |
-
db_dir = "db"
|
| 21 |
|
| 22 |
|
| 23 |
def BuildRetriever(
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
| 28 |
):
|
| 29 |
"""
|
| 30 |
Build retriever instance.
|
| 31 |
All retriever types are configured to return up to 6 documents for fair comparison in evals.
|
| 32 |
|
| 33 |
Args:
|
|
|
|
|
|
|
| 34 |
search_type: Type of search to use. Options: "dense", "sparse", "hybrid"
|
| 35 |
top_k: Number of documents to retrieve for "dense" and "sparse"
|
| 36 |
start_year: Start year (optional)
|
| 37 |
end_year: End year (optional)
|
|
|
|
| 38 |
"""
|
| 39 |
if search_type == "dense":
|
| 40 |
-
if not (start_year or end_year):
|
| 41 |
-
# No year filtering, so directly use base retriever
|
| 42 |
-
return BuildRetrieverDense(
|
|
|
|
|
|
|
| 43 |
else:
|
| 44 |
-
# Get
|
| 45 |
-
base_retriever = BuildRetrieverDense(
|
|
|
|
|
|
|
| 46 |
return TopKRetriever(
|
| 47 |
base_retriever=base_retriever,
|
| 48 |
top_k=top_k,
|
| 49 |
start_year=start_year,
|
| 50 |
end_year=end_year,
|
|
|
|
| 51 |
)
|
| 52 |
if search_type == "sparse":
|
| 53 |
-
if not (start_year or end_year):
|
| 54 |
-
return BuildRetrieverSparse(
|
|
|
|
|
|
|
| 55 |
else:
|
| 56 |
-
base_retriever = BuildRetrieverSparse(
|
|
|
|
|
|
|
| 57 |
return TopKRetriever(
|
| 58 |
base_retriever=base_retriever,
|
| 59 |
top_k=top_k,
|
| 60 |
start_year=start_year,
|
| 61 |
end_year=end_year,
|
|
|
|
| 62 |
)
|
| 63 |
elif search_type == "hybrid":
|
| 64 |
# Hybrid search (dense + sparse) - use ensemble method
|
|
@@ -66,16 +81,22 @@ def BuildRetriever(
|
|
| 66 |
# Use floor (top_k // 2) and ceiling -(top_k // -2) to divide odd values of top_k
|
| 67 |
# https://stackoverflow.com/questions/14822184/is-there-a-ceiling-equivalent-of-operator-in-python
|
| 68 |
dense_retriever = BuildRetriever(
|
|
|
|
|
|
|
| 69 |
"dense",
|
| 70 |
(top_k // 2),
|
| 71 |
start_year,
|
| 72 |
end_year,
|
|
|
|
| 73 |
)
|
| 74 |
sparse_retriever = BuildRetriever(
|
|
|
|
|
|
|
| 75 |
"sparse",
|
| 76 |
-(top_k // -2),
|
| 77 |
start_year,
|
| 78 |
end_year,
|
|
|
|
| 79 |
)
|
| 80 |
ensemble_retriever = EnsembleRetriever(
|
| 81 |
retrievers=[dense_retriever, sparse_retriever], weights=[1, 1]
|
|
@@ -85,31 +106,38 @@ def BuildRetriever(
|
|
| 85 |
raise ValueError(f"Unsupported search type: {search_type}")
|
| 86 |
|
| 87 |
|
| 88 |
-
def BuildRetrieverSparse(top_k=6):
|
| 89 |
"""
|
| 90 |
Build sparse retriever instance
|
| 91 |
|
| 92 |
Args:
|
|
|
|
|
|
|
| 93 |
top_k: Number of documents to retrieve
|
| 94 |
"""
|
| 95 |
# BM25 persistent directory
|
| 96 |
-
bm25_persist_directory =
|
| 97 |
if not os.path.exists(bm25_persist_directory):
|
| 98 |
os.makedirs(bm25_persist_directory)
|
| 99 |
|
| 100 |
# Use BM25 sparse search
|
|
|
|
|
|
|
|
|
|
| 101 |
retriever = BM25SRetriever.from_persisted_directory(
|
| 102 |
path=bm25_persist_directory,
|
| 103 |
-
k=
|
| 104 |
)
|
| 105 |
return retriever
|
| 106 |
|
| 107 |
|
| 108 |
-
def BuildRetrieverDense(top_k=6):
|
| 109 |
"""
|
| 110 |
Build dense retriever instance with ChromaDB vectorstore
|
| 111 |
|
| 112 |
Args:
|
|
|
|
|
|
|
| 113 |
top_k: Number of documents to retrieve
|
| 114 |
"""
|
| 115 |
|
|
@@ -117,15 +145,15 @@ def BuildRetrieverDense(top_k=6):
|
|
| 117 |
embedding_function = OpenAIEmbeddings(model="text-embedding-3-small")
|
| 118 |
# Create vector store
|
| 119 |
client_settings = chromadb.config.Settings(anonymized_telemetry=False)
|
| 120 |
-
persist_directory =
|
| 121 |
vectorstore = Chroma(
|
| 122 |
-
collection_name=
|
| 123 |
embedding_function=embedding_function,
|
| 124 |
client_settings=client_settings,
|
| 125 |
persist_directory=persist_directory,
|
| 126 |
)
|
| 127 |
# The storage layer for the parent documents
|
| 128 |
-
file_store =
|
| 129 |
byte_store = LocalFileStore(file_store)
|
| 130 |
# Text splitter for child documents
|
| 131 |
child_splitter = RecursiveCharacterTextSplitter(
|
|
@@ -152,18 +180,21 @@ def BuildRetrieverDense(top_k=6):
|
|
| 152 |
|
| 153 |
|
| 154 |
class TopKRetriever(BaseRetriever):
|
| 155 |
-
"""
|
|
|
|
|
|
|
| 156 |
|
| 157 |
-
|
|
|
|
| 158 |
|
|
|
|
| 159 |
base_retriever: RetrieverLike
|
| 160 |
-
|
| 161 |
-
|
| 162 |
top_k: int = 6
|
| 163 |
-
|
| 164 |
-
|
| 165 |
start_year: Optional[int] = None
|
| 166 |
end_year: Optional[int] = None
|
|
|
|
| 167 |
|
| 168 |
def _get_relevant_documents(
|
| 169 |
self,
|
|
@@ -172,7 +203,8 @@ class TopKRetriever(BaseRetriever):
|
|
| 172 |
run_manager: CallbackManagerForRetrieverRun,
|
| 173 |
**kwargs: Any,
|
| 174 |
) -> list[Document]:
|
| 175 |
-
"""
|
|
|
|
| 176 |
|
| 177 |
Returns:
|
| 178 |
Sequence of documents
|
|
@@ -183,28 +215,84 @@ class TopKRetriever(BaseRetriever):
|
|
| 183 |
)
|
| 184 |
if retrieved_docs:
|
| 185 |
|
| 186 |
-
# Get the
|
| 187 |
sources = [doc.metadata["source"] for doc in filtered_docs]
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
]
|
| 192 |
-
#
|
| 193 |
-
years = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
-
#
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
|
|
|
|
|
|
| 203 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
if self.start_year or self.end_year:
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
filtered_docs = [
|
| 207 |
-
doc for doc,
|
| 208 |
]
|
| 209 |
|
| 210 |
# Return the top k docs
|
|
|
|
| 11 |
import chromadb
|
| 12 |
import os
|
| 13 |
import re
|
| 14 |
+
from calendar import month_abbr, month_name
|
| 15 |
|
| 16 |
# Local modules
|
| 17 |
from mods.bm25s_retriever import BM25SRetriever
|
| 18 |
from mods.file_system import LocalFileStore
|
| 19 |
+
from util import get_sources
|
|
|
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
def BuildRetriever(
|
| 23 |
+
db_dir: str,
|
| 24 |
+
collection: str,
|
| 25 |
+
search_type: str,
|
| 26 |
+
top_k: Optional[int] = 6,
|
| 27 |
+
start_year: Optional[int] = None,
|
| 28 |
+
end_year: Optional[int] = None,
|
| 29 |
+
months: Optional[list[str]] = None,
|
| 30 |
):
|
| 31 |
"""
|
| 32 |
Build retriever instance.
|
| 33 |
All retriever types are configured to return up to 6 documents for fair comparison in evals.
|
| 34 |
|
| 35 |
Args:
|
| 36 |
+
db_dir: Database directory
|
| 37 |
+
collection: Email collection
|
| 38 |
search_type: Type of search to use. Options: "dense", "sparse", "hybrid"
|
| 39 |
top_k: Number of documents to retrieve for "dense" and "sparse"
|
| 40 |
start_year: Start year (optional)
|
| 41 |
end_year: End year (optional)
|
| 42 |
+
months: List of months (3-letter abbreviations) (optional)
|
| 43 |
"""
|
| 44 |
if search_type == "dense":
|
| 45 |
+
if not (start_year or end_year or months):
|
| 46 |
+
# No year or month filtering, so directly use base retriever
|
| 47 |
+
return BuildRetrieverDense(
|
| 48 |
+
db_dir=db_dir, collection=collection, top_k=top_k
|
| 49 |
+
)
|
| 50 |
else:
|
| 51 |
+
# Get 10000 documents then keep top_k filtered by year and month
|
| 52 |
+
base_retriever = BuildRetrieverDense(
|
| 53 |
+
db_dir=db_dir, collection=collection, top_k=10000
|
| 54 |
+
)
|
| 55 |
return TopKRetriever(
|
| 56 |
base_retriever=base_retriever,
|
| 57 |
top_k=top_k,
|
| 58 |
start_year=start_year,
|
| 59 |
end_year=end_year,
|
| 60 |
+
months=months,
|
| 61 |
)
|
| 62 |
if search_type == "sparse":
|
| 63 |
+
if not (start_year or end_year or months):
|
| 64 |
+
return BuildRetrieverSparse(
|
| 65 |
+
db_dir=db_dir, collection=collection, top_k=top_k
|
| 66 |
+
)
|
| 67 |
else:
|
| 68 |
+
base_retriever = BuildRetrieverSparse(
|
| 69 |
+
db_dir=db_dir, collection=collection, top_k=10000
|
| 70 |
+
)
|
| 71 |
return TopKRetriever(
|
| 72 |
base_retriever=base_retriever,
|
| 73 |
top_k=top_k,
|
| 74 |
start_year=start_year,
|
| 75 |
end_year=end_year,
|
| 76 |
+
months=months,
|
| 77 |
)
|
| 78 |
elif search_type == "hybrid":
|
| 79 |
# Hybrid search (dense + sparse) - use ensemble method
|
|
|
|
| 81 |
# Use floor (top_k // 2) and ceiling -(top_k // -2) to divide odd values of top_k
|
| 82 |
# https://stackoverflow.com/questions/14822184/is-there-a-ceiling-equivalent-of-operator-in-python
|
| 83 |
dense_retriever = BuildRetriever(
|
| 84 |
+
db_dir,
|
| 85 |
+
collection,
|
| 86 |
"dense",
|
| 87 |
(top_k // 2),
|
| 88 |
start_year,
|
| 89 |
end_year,
|
| 90 |
+
months,
|
| 91 |
)
|
| 92 |
sparse_retriever = BuildRetriever(
|
| 93 |
+
db_dir,
|
| 94 |
+
collection,
|
| 95 |
"sparse",
|
| 96 |
-(top_k // -2),
|
| 97 |
start_year,
|
| 98 |
end_year,
|
| 99 |
+
months,
|
| 100 |
)
|
| 101 |
ensemble_retriever = EnsembleRetriever(
|
| 102 |
retrievers=[dense_retriever, sparse_retriever], weights=[1, 1]
|
|
|
|
| 106 |
raise ValueError(f"Unsupported search type: {search_type}")
|
| 107 |
|
| 108 |
|
| 109 |
+
def BuildRetrieverSparse(db_dir, collection, top_k=6):
|
| 110 |
"""
|
| 111 |
Build sparse retriever instance
|
| 112 |
|
| 113 |
Args:
|
| 114 |
+
db_dir: Database directory
|
| 115 |
+
collection: Email collection
|
| 116 |
top_k: Number of documents to retrieve
|
| 117 |
"""
|
| 118 |
# BM25 persistent directory
|
| 119 |
+
bm25_persist_directory = os.path.join(db_dir, collection, "bm25")
|
| 120 |
if not os.path.exists(bm25_persist_directory):
|
| 121 |
os.makedirs(bm25_persist_directory)
|
| 122 |
|
| 123 |
# Use BM25 sparse search
|
| 124 |
+
# top_k can't be larger than the corpus size (number of emails)
|
| 125 |
+
corpus_size = len(get_sources(db_dir, collection))
|
| 126 |
+
k = top_k if top_k < corpus_size else corpus_size
|
| 127 |
retriever = BM25SRetriever.from_persisted_directory(
|
| 128 |
path=bm25_persist_directory,
|
| 129 |
+
k=k,
|
| 130 |
)
|
| 131 |
return retriever
|
| 132 |
|
| 133 |
|
| 134 |
+
def BuildRetrieverDense(db_dir, collection, top_k=6):
|
| 135 |
"""
|
| 136 |
Build dense retriever instance with ChromaDB vectorstore
|
| 137 |
|
| 138 |
Args:
|
| 139 |
+
db_dir: Database directory
|
| 140 |
+
collection: Email collection
|
| 141 |
top_k: Number of documents to retrieve
|
| 142 |
"""
|
| 143 |
|
|
|
|
| 145 |
embedding_function = OpenAIEmbeddings(model="text-embedding-3-small")
|
| 146 |
# Create vector store
|
| 147 |
client_settings = chromadb.config.Settings(anonymized_telemetry=False)
|
| 148 |
+
persist_directory = os.path.join(db_dir, collection, "chroma")
|
| 149 |
vectorstore = Chroma(
|
| 150 |
+
collection_name=collection,
|
| 151 |
embedding_function=embedding_function,
|
| 152 |
client_settings=client_settings,
|
| 153 |
persist_directory=persist_directory,
|
| 154 |
)
|
| 155 |
# The storage layer for the parent documents
|
| 156 |
+
file_store = os.path.join(db_dir, collection, "file_store")
|
| 157 |
byte_store = LocalFileStore(file_store)
|
| 158 |
# Text splitter for child documents
|
| 159 |
child_splitter = RecursiveCharacterTextSplitter(
|
|
|
|
| 180 |
|
| 181 |
|
| 182 |
class TopKRetriever(BaseRetriever):
|
| 183 |
+
"""
|
| 184 |
+
Retriever that wraps a base retriever and returns the top k documents,
|
| 185 |
+
optionally matching given start and/or end years and a list of months.
|
| 186 |
|
| 187 |
+
Code adapted from langchain/retrievers/contextual_compression.py
|
| 188 |
+
"""
|
| 189 |
|
| 190 |
+
# Base Retriever to use for getting relevant documents
|
| 191 |
base_retriever: RetrieverLike
|
| 192 |
+
# Number of documents to return
|
|
|
|
| 193 |
top_k: int = 6
|
| 194 |
+
# Optional year and month arguments
|
|
|
|
| 195 |
start_year: Optional[int] = None
|
| 196 |
end_year: Optional[int] = None
|
| 197 |
+
months: Optional[list[str]] = None
|
| 198 |
|
| 199 |
def _get_relevant_documents(
|
| 200 |
self,
|
|
|
|
| 203 |
run_manager: CallbackManagerForRetrieverRun,
|
| 204 |
**kwargs: Any,
|
| 205 |
) -> list[Document]:
|
| 206 |
+
"""
|
| 207 |
+
Return the top k documents within start and end years (and months) if given.
|
| 208 |
|
| 209 |
Returns:
|
| 210 |
Sequence of documents
|
|
|
|
| 215 |
)
|
| 216 |
if retrieved_docs:
|
| 217 |
|
| 218 |
+
# Get the email source files and basenames
|
| 219 |
sources = [doc.metadata["source"] for doc in filtered_docs]
|
| 220 |
+
filenames = [os.path.basename(source) for source in sources]
|
| 221 |
+
# Get the years and months
|
| 222 |
+
pattern = re.compile(r"(\d{4})-([A-Za-z]+)\.txt")
|
| 223 |
+
matches = [pattern.match(filename) for filename in filenames]
|
| 224 |
+
# Extract years and month names, handling None matches
|
| 225 |
+
years = []
|
| 226 |
+
month_names = []
|
| 227 |
+
for match in matches:
|
| 228 |
+
if match:
|
| 229 |
+
years.append(int(match.group(1)))
|
| 230 |
+
month_names.append(match.group(2))
|
| 231 |
+
else:
|
| 232 |
+
years.append(None)
|
| 233 |
+
month_names.append(None)
|
| 234 |
|
| 235 |
+
# Create mapping from 3-letter abbreviations to full month names
|
| 236 |
+
# month_abbr[0] is empty string, month_abbr[1] is "Jan", etc.
|
| 237 |
+
# month_name[0] is empty string, month_name[1] is "January", etc.
|
| 238 |
+
abbr_to_full = {month_abbr[i].lower(): month_name[i] for i in range(1, 13)}
|
| 239 |
+
|
| 240 |
+
# Convert months list (3-letter abbreviations) to full month names
|
| 241 |
+
target_months = None
|
| 242 |
+
if self.months:
|
| 243 |
+
target_months = [
|
| 244 |
+
abbr_to_full.get(month.lower()) for month in self.months
|
| 245 |
]
|
| 246 |
+
# Filter out None values in case of invalid abbreviations
|
| 247 |
+
target_months = [m for m in target_months if m is not None]
|
| 248 |
+
|
| 249 |
+
# Initialize filter flags
|
| 250 |
+
year_filter = None
|
| 251 |
+
month_filter = None
|
| 252 |
+
|
| 253 |
+
# Filtering by year
|
| 254 |
if self.start_year or self.end_year:
|
| 255 |
+
if self.start_year and self.end_year:
|
| 256 |
+
year_filter = [
|
| 257 |
+
year is not None
|
| 258 |
+
and year >= self.start_year
|
| 259 |
+
and year <= self.end_year
|
| 260 |
+
for year in years
|
| 261 |
+
]
|
| 262 |
+
elif self.start_year:
|
| 263 |
+
year_filter = [
|
| 264 |
+
year is not None and year >= self.start_year for year in years
|
| 265 |
+
]
|
| 266 |
+
elif self.end_year:
|
| 267 |
+
year_filter = [
|
| 268 |
+
year is not None and year <= self.end_year for year in years
|
| 269 |
+
]
|
| 270 |
+
|
| 271 |
+
# Filtering by month
|
| 272 |
+
if target_months:
|
| 273 |
+
month_filter = [
|
| 274 |
+
month_name is not None and month_name in target_months
|
| 275 |
+
for month_name in month_names
|
| 276 |
+
]
|
| 277 |
+
|
| 278 |
+
# Combine filters
|
| 279 |
+
if year_filter is not None and month_filter is not None:
|
| 280 |
+
# Both year and month filters
|
| 281 |
+
combined_filter = [
|
| 282 |
+
year and month for year, month in zip(year_filter, month_filter)
|
| 283 |
+
]
|
| 284 |
+
filtered_docs = [
|
| 285 |
+
doc for doc, keep in zip(retrieved_docs, combined_filter) if keep
|
| 286 |
+
]
|
| 287 |
+
elif year_filter is not None:
|
| 288 |
+
# Only year filter
|
| 289 |
+
filtered_docs = [
|
| 290 |
+
doc for doc, keep in zip(retrieved_docs, year_filter) if keep
|
| 291 |
+
]
|
| 292 |
+
elif month_filter is not None:
|
| 293 |
+
# Only month filter
|
| 294 |
filtered_docs = [
|
| 295 |
+
doc for doc, keep in zip(retrieved_docs, month_filter) if keep
|
| 296 |
]
|
| 297 |
|
| 298 |
# Return the top k docs
|
util.py
CHANGED
|
@@ -1,18 +1,23 @@
|
|
| 1 |
from calendar import month_name
|
| 2 |
-
from retriever import BuildRetriever, db_dir
|
| 3 |
import json
|
| 4 |
import os
|
| 5 |
import re
|
| 6 |
|
| 7 |
|
| 8 |
-
def get_sources():
|
| 9 |
"""
|
| 10 |
-
Return the source files indexed in the database
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
"""
|
| 12 |
-
# Path to
|
| 13 |
-
file_path = os.path.join(db_dir, "bm25", "corpus.jsonl")
|
| 14 |
|
| 15 |
-
#
|
| 16 |
with open(file_path, "r", encoding="utf-8") as file:
|
| 17 |
# Parse each line as a JSON object
|
| 18 |
sources = [json.loads(line.strip())["metadata"]["source"] for line in file]
|
|
@@ -24,11 +29,13 @@ def get_start_end_months(sources):
|
|
| 24 |
"""
|
| 25 |
Given a set of filenames like 'R-help/2024-January.txt', return the earliest and latest month in 'Month YYYY' format.
|
| 26 |
"""
|
| 27 |
-
|
|
|
|
|
|
|
| 28 |
months = []
|
| 29 |
-
# Start with the unique
|
| 30 |
-
|
| 31 |
-
for src in
|
| 32 |
m = pattern.match(src)
|
| 33 |
if m:
|
| 34 |
year = int(m.group(1))
|
|
|
|
| 1 |
from calendar import month_name
|
|
|
|
| 2 |
import json
|
| 3 |
import os
|
| 4 |
import re
|
| 5 |
|
| 6 |
|
| 7 |
+
def get_sources(db_dir, collection):
|
| 8 |
"""
|
| 9 |
+
Return the source files for all emails indexed in the database.
|
| 10 |
+
The source file names look like 'R-help/2024-April.txt' and are repeated
|
| 11 |
+
for as many times as there are indexed emails from each source file.
|
| 12 |
+
|
| 13 |
+
Args:
|
| 14 |
+
db_dir: Database directory
|
| 15 |
+
collection: Email collection
|
| 16 |
"""
|
| 17 |
+
# Path to the JSON Lines file
|
| 18 |
+
file_path = os.path.join(db_dir, collection, "bm25", "corpus.jsonl")
|
| 19 |
|
| 20 |
+
# Read the JSON Lines file
|
| 21 |
with open(file_path, "r", encoding="utf-8") as file:
|
| 22 |
# Parse each line as a JSON object
|
| 23 |
sources = [json.loads(line.strip())["metadata"]["source"] for line in file]
|
|
|
|
| 29 |
"""
|
| 30 |
Given a set of filenames like 'R-help/2024-January.txt', return the earliest and latest month in 'Month YYYY' format.
|
| 31 |
"""
|
| 32 |
+
# Get just the file names (e.g. 2024-January.txt)
|
| 33 |
+
filenames = [os.path.basename(source) for source in sources]
|
| 34 |
+
pattern = re.compile(r"(\d{4})-([A-Za-z]+)\.txt")
|
| 35 |
months = []
|
| 36 |
+
# Start with the unique filenames
|
| 37 |
+
unique_filenames = set(filenames)
|
| 38 |
+
for src in unique_filenames:
|
| 39 |
m = pattern.match(src)
|
| 40 |
if m:
|
| 41 |
year = int(m.group(1))
|