Spaces:
Running
Running
Smarter document context retrieval
Browse files
* Retrieved documents re-ranking w/ SPLADE-v3
* Enable news as a default source
- app.py +2 -1
- ask_candid/agents/elastic.py +246 -54
- ask_candid/retrieval/elastic.py +39 -160
- ask_candid/retrieval/sources/candid_blog.py +22 -1
- ask_candid/retrieval/sources/candid_help.py +20 -1
- ask_candid/retrieval/sources/candid_learning.py +22 -1
- ask_candid/retrieval/sources/candid_news.py +14 -1
- ask_candid/retrieval/sources/issuelab.py +27 -2
- ask_candid/retrieval/sources/schema.py +12 -1
- ask_candid/retrieval/sources/utils.py +47 -0
- ask_candid/retrieval/sources/youtube.py +20 -1
- ask_candid/retrieval/sparse_lexical.py +14 -4
- ask_candid/tools/elastic/index_search_tool.py +9 -2
- ask_candid/tools/question_reformulation.py +43 -39
app.py
CHANGED
|
@@ -113,7 +113,8 @@ def build_rag_chat() -> Tuple[LoggedComponents, gr.Blocks]:
|
|
| 113 |
with gr.Accordion(label="Advanced settings", open=False):
|
| 114 |
es_indices = gr.CheckboxGroup(
|
| 115 |
choices=list(ALL_INDICES),
|
| 116 |
-
value=[idx for idx in ALL_INDICES if "news" not in idx],
|
|
|
|
| 117 |
label="Sources to include",
|
| 118 |
interactive=True,
|
| 119 |
)
|
|
|
|
| 113 |
with gr.Accordion(label="Advanced settings", open=False):
|
| 114 |
es_indices = gr.CheckboxGroup(
|
| 115 |
choices=list(ALL_INDICES),
|
| 116 |
+
# value=[idx for idx in ALL_INDICES if "news" not in idx],
|
| 117 |
+
value=list(ALL_INDICES),
|
| 118 |
label="Sources to include",
|
| 119 |
interactive=True,
|
| 120 |
)
|
ask_candid/agents/elastic.py
CHANGED
|
@@ -2,6 +2,9 @@ from typing import TypedDict, List
|
|
| 2 |
from functools import partial
|
| 3 |
import json
|
| 4 |
import ast
|
|
|
|
|
|
|
|
|
|
| 5 |
from pydantic import BaseModel, Field
|
| 6 |
|
| 7 |
from langchain_core.runnables import RunnableSequence
|
|
@@ -24,10 +27,118 @@ from ask_candid.tools.elastic.index_search_tool import create_search_tool
|
|
| 24 |
tools = [
|
| 25 |
IndexShowDataTool(),
|
| 26 |
IndexDetailsTool(),
|
| 27 |
-
create_search_tool(),
|
| 28 |
]
|
| 29 |
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
class GraphState(TypedDict):
|
| 32 |
query: str = Field(
|
| 33 |
..., description="The user's query to be processed by the system."
|
|
@@ -47,6 +158,7 @@ class GraphState(TypedDict):
|
|
| 47 |
...,
|
| 48 |
description="The Elasticsearch query result generated or used by the agent.",
|
| 49 |
)
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
class AnalysisResult(BaseModel):
|
|
@@ -334,8 +446,6 @@ def build_compute_graph(llm: LLM) -> StateGraph:
|
|
| 334 |
|
| 335 |
|
| 336 |
class ElasticGraph(StateGraph):
|
| 337 |
-
"""Elastic Seach Agent State Graph"""
|
| 338 |
-
|
| 339 |
llm: LLM
|
| 340 |
tools: List[Tool]
|
| 341 |
|
|
@@ -345,6 +455,41 @@ class ElasticGraph(StateGraph):
|
|
| 345 |
self.tools = tools
|
| 346 |
self.construct_graph()
|
| 347 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 348 |
def agent_factory(self) -> AgentExecutor:
|
| 349 |
"""
|
| 350 |
Creates and configures an AgentExecutor instance for interacting with Elasticsearch.
|
|
@@ -387,7 +532,7 @@ class ElasticGraph(StateGraph):
|
|
| 387 |
return_intermediate_steps=True,
|
| 388 |
)
|
| 389 |
|
| 390 |
-
def agent_factory_claude(self) -> AgentExecutor:
|
| 391 |
"""
|
| 392 |
Creates and configures an AgentExecutor instance for interacting with Elasticsearch.
|
| 393 |
|
|
@@ -400,15 +545,6 @@ class ElasticGraph(StateGraph):
|
|
| 400 |
AgentExecutor: Configured agent ready to execute tasks with specified tools,
|
| 401 |
providing detailed intermediate steps for transparency.
|
| 402 |
"""
|
| 403 |
-
prefix = """
|
| 404 |
-
You are an intelligent agent tasked with generating accurate Elasticsearch DSL queries.
|
| 405 |
-
Analyze the intent behind the query and determine the appropriate Elasticsearch operations required.
|
| 406 |
-
Guidelines for generating right elastic seach query:
|
| 407 |
-
1. Automatically determine whether to return document hits or aggregation results based on the query structure.
|
| 408 |
-
2. Use keyword fields instead of text fields for aggregations and sorting to avoid fielddata errors
|
| 409 |
-
3. Avoid using field.keyword if a keyword field is already present to prevent redundant queries.
|
| 410 |
-
4. Ensure efficient query execution by selecting appropriate query types for filtering, searching, and aggregating.
|
| 411 |
-
"""
|
| 412 |
prompt = ChatPromptTemplate.from_messages(
|
| 413 |
[
|
| 414 |
("system", f"You are a helpful elasticsearch assistant. {prefix}"),
|
|
@@ -418,9 +554,19 @@ class ElasticGraph(StateGraph):
|
|
| 418 |
]
|
| 419 |
)
|
| 420 |
|
| 421 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
agent_executor = AgentExecutor.from_agent_and_tools(
|
| 423 |
-
agent=agent,
|
|
|
|
|
|
|
|
|
|
| 424 |
)
|
| 425 |
# Create the agent
|
| 426 |
return agent_executor
|
|
@@ -467,6 +613,8 @@ class ElasticGraph(StateGraph):
|
|
| 467 |
|
| 468 |
def grant_index_agent(self, state: GraphState) -> GraphState:
|
| 469 |
print("> Grant Index Agent")
|
|
|
|
|
|
|
| 470 |
input_data = {
|
| 471 |
"input": f"""
|
| 472 |
You are an Elasticsearch database agent designed to accurately understand and respond to user queries. Follow these steps:
|
|
@@ -479,52 +627,51 @@ class ElasticGraph(StateGraph):
|
|
| 479 |
Users may not always provide the exact name, so the Elasticsearch query should accommodate partial or incomplete names
|
| 480 |
by searching for relevant keywords.
|
| 481 |
6. Present the response in a clear and natural language format, addressing the user's question directly.
|
| 482 |
-
|
| 483 |
-
|
| 484 |
Description of some of the fields in the index but rest of the fields which are not here should be easy to understand:
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
pcs_v3: PCS is taxonomy, describing the work of grantmakers, recipient organizations and the philanthropic transactions between those entities.
|
| 490 |
-
The facets of the PCS illuminate the work and answer the following questions about philanthropy:
|
| 491 |
-
Who? = Population Served
|
| 492 |
-
What? = Subject and Organization Type
|
| 493 |
-
How? = Support Strategy and Transaction Type
|
| 494 |
-
the Facets:
|
| 495 |
-
Subjects: Describes WHAT is being supported. Example: Elementary education or Clean water supply.
|
| 496 |
-
Populations: Describes WHO is being supported. Example: Girls or People with disabilities.
|
| 497 |
-
Organization Type: Describes WHAT type of organization is providing or receiving support.
|
| 498 |
-
Transaction Type: Describes HOW support is being provided.
|
| 499 |
-
Support Strategies: Describes HOW activities are being implemented.
|
| 500 |
-
|
| 501 |
-
pcs_v3 itself is in a json format:
|
| 502 |
-
key - subject
|
| 503 |
-
value: it is a list of dictionary so might need to loop around to find the particular aspect
|
| 504 |
-
hierarchy: (it is a list having subject name)
|
| 505 |
-
[
|
| 506 |
-
{{
|
| 507 |
-
'name':
|
| 508 |
-
}},
|
| 509 |
-
{{
|
| 510 |
-
'name':
|
| 511 |
-
}}
|
| 512 |
-
]
|
| 513 |
-
Before Writing elastic search query think through which field to use
|
| 514 |
-
|
| 515 |
-
Note: first you should focus on query `text` then look into pcs_v3. Make sure you pick the right size for the query
|
| 516 |
|
|
|
|
| 517 |
User's query:
|
| 518 |
```{state["query"]}```
|
| 519 |
"""
|
| 520 |
}
|
| 521 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 522 |
res = agent_exec.invoke(input_data)
|
| 523 |
state["agent_out"] = res["output"]
|
| 524 |
-
|
| 525 |
es_queries, es_results = {}, {}
|
| 526 |
for i, action in enumerate(res.get("intermediate_steps", []), start=1):
|
| 527 |
if action[0].tool == "elastic_index_search_tool":
|
|
|
|
| 528 |
es_queries[f"query_{i}"] = json.loads(
|
| 529 |
action[0].tool_input.get("query") or "{}"
|
| 530 |
)
|
|
@@ -550,6 +697,18 @@ class ElasticGraph(StateGraph):
|
|
| 550 |
"""
|
| 551 |
|
| 552 |
print("> Org Index Agent")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 553 |
input_data = {
|
| 554 |
"input": f"""
|
| 555 |
You are an Elasticsearch database agent designed to accurately understand and respond to user queries. Follow these steps:
|
|
@@ -557,14 +716,45 @@ class ElasticGraph(StateGraph):
|
|
| 557 |
1. Understand the user query to determine the required information.
|
| 558 |
2. Query the indices in the Elasticsearch database.
|
| 559 |
3. Retrieve the mappings and field names relevant to the query.
|
| 560 |
-
4. Use the `
|
| 561 |
5. Present the response in a clear and natural language format, addressing the user's question directly.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 562 |
|
| 563 |
-
User's
|
| 564 |
```{state["query"]}```
|
| 565 |
"""
|
| 566 |
}
|
| 567 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 568 |
res = agent_exec.invoke(input_data)
|
| 569 |
state["agent_out"] = res["output"]
|
| 570 |
|
|
@@ -622,13 +812,15 @@ class ElasticGraph(StateGraph):
|
|
| 622 |
"""
|
| 623 |
|
| 624 |
# Add nodes
|
|
|
|
| 625 |
self.add_node("analyse", self.analyse_query)
|
| 626 |
self.add_node("grant-index", self.grant_index_agent)
|
| 627 |
self.add_node("org-index", self.org_index_agent)
|
| 628 |
self.add_node("final_answer", self.final_answer)
|
| 629 |
|
| 630 |
# Set entry point
|
| 631 |
-
self.set_entry_point("
|
|
|
|
| 632 |
|
| 633 |
# Add conditional edges
|
| 634 |
self.add_conditional_edges(
|
|
|
|
| 2 |
from functools import partial
|
| 3 |
import json
|
| 4 |
import ast
|
| 5 |
+
from ask_candid.base.api_base import BaseAPI
|
| 6 |
+
import os
|
| 7 |
+
import pandas as pd
|
| 8 |
from pydantic import BaseModel, Field
|
| 9 |
|
| 10 |
from langchain_core.runnables import RunnableSequence
|
|
|
|
| 27 |
tools = [
|
| 28 |
IndexShowDataTool(),
|
| 29 |
IndexDetailsTool(),
|
| 30 |
+
create_search_tool(pcs_codes={}),
|
| 31 |
]
|
| 32 |
|
| 33 |
|
| 34 |
+
class AutocodingAPI(BaseAPI):
    """Callable `BaseAPI` client configured entirely from environment variables.

    The endpoint is read from ``AUTOCODING_API_URL`` and the API key from
    ``AUTOCODING_API_KEY``; both are passed to ``BaseAPI.__init__``.
    NOTE(review): presumably this targets the taxonomy auto-coding service --
    confirm against `BaseAPI` and the deployment configuration.
    """

    def __init__(self):
        # os.getenv yields None for an unset variable; no validation happens here.
        endpoint = os.getenv("AUTOCODING_API_URL")
        request_headers = {
            "x-api-key": os.getenv("AUTOCODING_API_KEY"),
            "Content-Type": "application/json",
        }
        super().__init__(url=endpoint, headers=request_headers)

    def __call__(self, text: str, taxonomy: str = "pcs-v3"):
        """Forward ``text`` and ``taxonomy`` as keyword arguments to ``self.get``.

        Returns whatever ``self.get`` returns (defined on `BaseAPI`, not
        visible here).
        """
        return self.get(text=text, taxonomy=taxonomy)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def find_subject_levels(filtered_df, subject_level_i, target_value):
    """Resolve the ancestors of a subject inside a hierarchical taxonomy sheet.

    The first row whose ``'Subject Level <subject_level_i>'`` cell (stripped)
    equals ``target_value`` is located.  For every shallower level, the rows
    above are scanned upwards until a non-NaN cell is found in that level's
    column; that cell is recorded as the ancestor for the level.

    Rows are addressed by *position*, not by index label, so the function
    also works on frames whose index has gaps (e.g. the result of boolean
    filtering without ``reset_index``).  For a default contiguous index the
    results are the same as with label-based access.

    Parameters:
    filtered_df (pd.DataFrame): Frame containing the 'Subject Level N' columns
        for levels 1..``subject_level_i``.
    subject_level_i (int): The subject level to search in (1 to 4).
    target_value (str): The value to search for in 'Subject Level i'.

    Returns:
    dict: Maps 'Subject Level N' to its value for the target level and for
        every shallower level where an ancestor was found.
    pd.DataFrame: Rows from just below the top-most ancestor found through
        the target row (inclusive); an empty frame when there is no such
        range (e.g. level 1) or when ``target_value`` is not present.

    Raises:
    ValueError: If ``subject_level_i`` is outside 1..4.
    """
    if subject_level_i < 1 or subject_level_i > 4:
        raise ValueError("subject_level_i should be between 1 and 4")

    target_column = f"Subject Level {subject_level_i}"

    # Boolean mask by position; astype(str) keeps .str usable on non-string columns.
    matches = (
        filtered_df[target_column].astype(str).str.strip() == target_value
    ).to_numpy()
    if not matches.any():
        return {}, pd.DataFrame()  # target_value is not present at this level

    # argmax on a boolean array gives the position of the first True.
    target_pos = int(matches.argmax())

    subject_level_values = {target_column: target_value}
    start_pos = target_pos

    # Walk from the parent level up to level 1.
    for level in range(subject_level_i - 1, 0, -1):
        column_name = f"Subject Level {level}"
        column = filtered_df[column_name]

        # Scan upwards, skipping rows that have no value at this level.
        pos = start_pos - 1
        while pos >= 0 and pd.isna(column.iloc[pos]):
            pos -= 1

        # Record the ancestor only when one exists; falling off the top of
        # the frame must not turn into a bogus lookup.
        if pos >= 0:
            subject_level_values[column_name] = column.iloc[pos]

        # The next (shallower) level resumes its scan at the row just found.
        start_pos = pos + 1

    if start_pos < target_pos:
        # End is inclusive of the target row, mirroring a label slice.
        return subject_level_values, filtered_df.iloc[start_pos:target_pos + 1]
    return subject_level_values, pd.DataFrame()
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def extract_heirarchy(full_code, target_value, taxonomy_path=None):
    """Build the ' : '-joined subject hierarchy string for a PCS subject.

    The taxonomy workbook is loaded, narrowed to rows whose 'PCS Code' starts
    with the first two characters of ``full_code``, and searched for
    ``target_value`` in the 'Subject Level 1'..'Subject Level 4' columns.
    The ancestors are then resolved with ``find_subject_levels`` and joined
    from level 1 down to the matched level.

    Parameters:
    full_code (str): PCS code of the subject; only its first two characters
        are used, as a prefix filter.
    target_value (str): Subject description to locate.
    taxonomy_path (str, optional): Path to the taxonomy Excel workbook.
        When omitted, the ``PCS_TAXONOMY_PATH`` environment variable is used,
        falling back to the legacy hard-coded location.

    Returns:
    str: e.g. 'Level 1 : Level 2 : Level 3', or '' when ``target_value`` is
        not found among the rows matching the code prefix.
    """
    import os  # function-scope import keeps this block self-contained

    if taxonomy_path is None:
        # NOTE(review): the fallback is a developer-machine path kept only for
        # backward compatibility; deployments should set PCS_TAXONOMY_PATH.
        taxonomy_path = os.getenv(
            "PCS_TAXONOMY_PATH",
            r"C:\Users\siqi.deng\Downloads\PCS_Taxonomy_Definitions_2024.xlsx",
        )

    df = pd.read_excel(taxonomy_path)
    filtered_df = df[df["PCS Code"].str.startswith(full_code[:2], na=False)]

    # Detect the level on the same rows that are searched afterwards, so a
    # same-named subject in another branch cannot select the wrong level.
    for level in range(1, 5):
        level_cells = filtered_df[f"Subject Level {level}"].astype(str).str.strip()
        if (level_cells == target_value).any():
            break
    else:
        return ""  # not present at any level for this code prefix

    subject_level_values, _ = find_subject_levels(filtered_df, level, target_value)

    # Keys look like 'Subject Level N'; order by N so level 1 comes first.
    ordered = sorted(
        subject_level_values.items(), key=lambda item: int(item[0].split()[-1])
    )
    return " : ".join(str(value) for _, value in ordered)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
class GraphState(TypedDict):
|
| 143 |
query: str = Field(
|
| 144 |
..., description="The user's query to be processed by the system."
|
|
|
|
| 158 |
...,
|
| 159 |
description="The Elasticsearch query result generated or used by the agent.",
|
| 160 |
)
|
| 161 |
+
pcs_codes: dict = Field(..., description="pcs codes")
|
| 162 |
|
| 163 |
|
| 164 |
class AnalysisResult(BaseModel):
|
|
|
|
| 446 |
|
| 447 |
|
| 448 |
class ElasticGraph(StateGraph):
|
|
|
|
|
|
|
| 449 |
llm: LLM
|
| 450 |
tools: List[Tool]
|
| 451 |
|
|
|
|
| 455 |
self.tools = tools
|
| 456 |
self.construct_graph()
|
| 457 |
|
| 458 |
+
def Extract_PCS_Codes(self, state):
    """Populate ``state["pcs_codes"]`` from the auto-coding response for the query.

    Calls ``AutocodingAPI`` with ``state["query"]``, reads the ``"subject"``
    and ``"population"`` entries under the response's ``"data"`` key, and
    stores the subject descriptions, their hierarchy strings (via
    ``extract_heirarchy``) and the population descriptions on the state.
    Progress is printed to stdout.  Returns the same ``state`` object.

    TODO: Add Subject heirarchies, Population, Geo
    """

    def _has_descriptions(entries):
        # A usable payload is a non-empty list whose first item has "description".
        if not entries:
            return False
        if not isinstance(entries, list):
            return False
        return "description" in entries[0]

    print("query", state["query"])
    response = AutocodingAPI()(text=state["query"]).get("data", {})

    subject_names, hierarchy_paths = [], []
    subject_entries = response.get("subject", {})
    if _has_descriptions(subject_entries):
        for entry in subject_entries:
            name = entry["description"]
            subject_names.append(name)
            hierarchy_paths.append(extract_heirarchy(entry["full_code"], name))
    print("descriptions", subject_names)

    population_entries = response.get("population", {})
    if _has_descriptions(population_entries):
        population_names = [entry["description"] for entry in population_entries]
    else:
        population_names = []

    state["pcs_codes"] = {
        "subject": subject_names,
        "heirarchy_string": hierarchy_paths,
        "population": population_names,
    }
    print("pcs_codes_new", state["pcs_codes"])
    return state
|
| 492 |
+
|
| 493 |
def agent_factory(self) -> AgentExecutor:
|
| 494 |
"""
|
| 495 |
Creates and configures an AgentExecutor instance for interacting with Elasticsearch.
|
|
|
|
| 532 |
return_intermediate_steps=True,
|
| 533 |
)
|
| 534 |
|
| 535 |
+
def agent_factory_claude(self, pcs_codes, prefix) -> AgentExecutor:
|
| 536 |
"""
|
| 537 |
Creates and configures an AgentExecutor instance for interacting with Elasticsearch.
|
| 538 |
|
|
|
|
| 545 |
AgentExecutor: Configured agent ready to execute tasks with specified tools,
|
| 546 |
providing detailed intermediate steps for transparency.
|
| 547 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 548 |
prompt = ChatPromptTemplate.from_messages(
|
| 549 |
[
|
| 550 |
("system", f"You are a helpful elasticsearch assistant. {prefix}"),
|
|
|
|
| 554 |
]
|
| 555 |
)
|
| 556 |
|
| 557 |
+
tools = [
|
| 558 |
+
# ListIndicesTool(),
|
| 559 |
+
IndexShowDataTool(),
|
| 560 |
+
IndexDetailsTool(),
|
| 561 |
+
create_search_tool(pcs_codes=pcs_codes),
|
| 562 |
+
]
|
| 563 |
+
agent = create_tool_calling_agent(self.llm, tools, prompt)
|
| 564 |
+
|
| 565 |
agent_executor = AgentExecutor.from_agent_and_tools(
|
| 566 |
+
agent=agent,
|
| 567 |
+
tools=tools,
|
| 568 |
+
verbose=True,
|
| 569 |
+
return_intermediate_steps=True,
|
| 570 |
)
|
| 571 |
# Create the agent
|
| 572 |
return agent_executor
|
|
|
|
| 613 |
|
| 614 |
def grant_index_agent(self, state: GraphState) -> GraphState:
|
| 615 |
print("> Grant Index Agent")
|
| 616 |
+
# autocoding test
|
| 617 |
+
|
| 618 |
input_data = {
|
| 619 |
"input": f"""
|
| 620 |
You are an Elasticsearch database agent designed to accurately understand and respond to user queries. Follow these steps:
|
|
|
|
| 627 |
Users may not always provide the exact name, so the Elasticsearch query should accommodate partial or incomplete names
|
| 628 |
by searching for relevant keywords.
|
| 629 |
6. Present the response in a clear and natural language format, addressing the user's question directly.
|
| 630 |
+
|
|
|
|
| 631 |
Description of some of the fields in the index but rest of the fields which are not here should be easy to understand:
|
| 632 |
+
*fiscal_year: Year when grantmaker allocates budget for funding and grants. format YYYY
|
| 633 |
+
*recipient_state: is abbreviated for eg. NY, FL, CA
|
| 634 |
+
*recipient_city - Full Name of the City e.g, New York City, Boston
|
| 635 |
+
*recipient_country - Country Abbreviation of the recipient organization e.g USA
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 636 |
|
| 637 |
+
Note: Do not include `title`, `program_area`, `text` field in the elastic search query
|
| 638 |
User's query:
|
| 639 |
```{state["query"]}```
|
| 640 |
"""
|
| 641 |
}
|
| 642 |
+
pcs_codes = state["pcs_codes"]
|
| 643 |
+
pcs_match_term = ""
|
| 644 |
+
for pcs_code in pcs_codes["subject"]:
|
| 645 |
+
if pcs_code != "Philanthropy":
|
| 646 |
+
pcs_match_term += f"*'pcs_v3.subject.value.name': {pcs_code}* \n"
|
| 647 |
+
|
| 648 |
+
for pcs_code in pcs_codes["population"]:
|
| 649 |
+
if pcs_code != "Other population":
|
| 650 |
+
pcs_match_term += f"*'pcs_v3.population.value.name': {pcs_code}* \n"
|
| 651 |
+
print("pcs_match_term", pcs_match_term)
|
| 652 |
+
prefix = f"""
|
| 653 |
+
You are an intelligent agent tasked with generating accurate Elasticsearch DSL queries.
|
| 654 |
+
Analyze the intent behind the query and determine the appropriate Elasticsearch operations required.
|
| 655 |
+
Guidelines for generating right elastic seach query:
|
| 656 |
+
1. Automatically determine whether to return document hits or aggregation results based on the query structure.
|
| 657 |
+
2. Use keyword fields instead of text fields for aggregations and sorting to avoid fielddata errors
|
| 658 |
+
3. Avoid using field.keyword if a keyword field is already present to prevent redundant queries.
|
| 659 |
+
4. Ensure efficient query execution by selecting appropriate query types for filtering, searching, and aggregating.
|
| 660 |
+
|
| 661 |
+
Instruction for pcs_v3 Field-
|
| 662 |
+
If {pcs_codes['subject']} not empty:
|
| 663 |
+
Only include all of the following match terms. No other pcs_v3 fields should be added, duplicated, or altered except for those listed below.
|
| 664 |
+
- {pcs_match_term}
|
| 665 |
+
"""
|
| 666 |
+
agent_exec = self.agent_factory_claude(
|
| 667 |
+
pcs_codes=state["pcs_codes"], prefix=prefix
|
| 668 |
+
)
|
| 669 |
res = agent_exec.invoke(input_data)
|
| 670 |
state["agent_out"] = res["output"]
|
|
|
|
| 671 |
es_queries, es_results = {}, {}
|
| 672 |
for i, action in enumerate(res.get("intermediate_steps", []), start=1):
|
| 673 |
if action[0].tool == "elastic_index_search_tool":
|
| 674 |
+
print("query", action[0].tool_input.get("query"))
|
| 675 |
es_queries[f"query_{i}"] = json.loads(
|
| 676 |
action[0].tool_input.get("query") or "{}"
|
| 677 |
)
|
|
|
|
| 697 |
"""
|
| 698 |
|
| 699 |
print("> Org Index Agent")
|
| 700 |
+
mapping_description = """
|
| 701 |
+
"admin1_code": "state abbreviation"
|
| 702 |
+
"admin1_description": "Full name/label of the state"
|
| 703 |
+
"city": Full Name of the city with 1st letter being capital for e.g. New York City
|
| 704 |
+
"assets": "The assets value of the most recent fiscals available for the organization."
|
| 705 |
+
"country_code": "Country abbreviation"
|
| 706 |
+
"country_name": "Country name"
|
| 707 |
+
"fiscal_year": "The year of the most recent fiscals available for the organization. (YYYY format)"
|
| 708 |
+
"mission_statement": "The mission statement of the organization."
|
| 709 |
+
"roles": "grantmaker: Indicates the organization gives grants., recipient: Indicates the organization receives grants., company: Indicates the organization is a company/corporation."
|
| 710 |
+
|
| 711 |
+
"""
|
| 712 |
input_data = {
|
| 713 |
"input": f"""
|
| 714 |
You are an Elasticsearch database agent designed to accurately understand and respond to user queries. Follow these steps:
|
|
|
|
| 716 |
1. Understand the user query to determine the required information.
|
| 717 |
2. Query the indices in the Elasticsearch database.
|
| 718 |
3. Retrieve the mappings and field names relevant to the query.
|
| 719 |
+
4. Use the `organization_qa_ds1` index to extract the necessary data.
|
| 720 |
5. Present the response in a clear and natural language format, addressing the user's question directly.
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
Given Below is mapping description of some of the fields
|
| 724 |
+
```{mapping_description}```
|
| 725 |
+
|
| 726 |
|
| 727 |
+
User's query:
|
| 728 |
```{state["query"]}```
|
| 729 |
"""
|
| 730 |
}
|
| 731 |
+
|
| 732 |
+
pcs_codes = state["pcs_codes"]
|
| 733 |
+
pcs_match_term = ""
|
| 734 |
+
for pcs_code in pcs_codes["subject"]:
|
| 735 |
+
pcs_match_term += f'"taxonomy_descriptions": "{pcs_code}" \n"'
|
| 736 |
+
|
| 737 |
+
print("pcs_match_term", pcs_match_term)
|
| 738 |
+
prefix = f"""You are an intelligent agent tasked with generating accurate Elasticsearch DSL queries.
|
| 739 |
+
Analyze the intent behind the query and determine the appropriate Elasticsearch operations required.
|
| 740 |
+
Guidelines for generating right elastic seach query:
|
| 741 |
+
1. Automatically determine whether to return document hits or aggregation results based on the query structure.
|
| 742 |
+
2. Use keyword fields instead of text fields for aggregations and sorting to avoid fielddata errors
|
| 743 |
+
3. Avoid using field.keyword if a keyword field is already present to prevent redundant queries.
|
| 744 |
+
4. Ensure efficient query execution by selecting appropriate query types for filtering, searching, and aggregating.
|
| 745 |
+
|
| 746 |
+
Instructions to use `taxonomy_descriptions` field:
|
| 747 |
+
If {pcs_codes['subject']} not empty, only add the following match term:
|
| 748 |
+
Only add the following `match` term, No other `taxonomy_descriptions` fields should be added, duplicated, or modified except belowIf {pcs_codes['subject']} not empty,
|
| 749 |
+
- {pcs_match_term}
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
Avoid using `ntee_major_description` field in the es query
|
| 753 |
+
|
| 754 |
+
"""
|
| 755 |
+
agent_exec = self.agent_factory_claude(
|
| 756 |
+
pcs_codes=state["pcs_codes"], prefix=prefix
|
| 757 |
+
)
|
| 758 |
res = agent_exec.invoke(input_data)
|
| 759 |
state["agent_out"] = res["output"]
|
| 760 |
|
|
|
|
| 812 |
"""
|
| 813 |
|
| 814 |
# Add nodes
|
| 815 |
+
self.add_node("Context_Extraction", self.Extract_PCS_Codes)
|
| 816 |
self.add_node("analyse", self.analyse_query)
|
| 817 |
self.add_node("grant-index", self.grant_index_agent)
|
| 818 |
self.add_node("org-index", self.org_index_agent)
|
| 819 |
self.add_node("final_answer", self.final_answer)
|
| 820 |
|
| 821 |
# Set entry point
|
| 822 |
+
self.set_entry_point("Context_Extraction")
|
| 823 |
+
self.add_edge("Context_Extraction", "analyse")
|
| 824 |
|
| 825 |
# Add conditional edges
|
| 826 |
self.add_conditional_edges(
|
ask_candid/retrieval/elastic.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Union, Any
|
| 2 |
-
from dataclasses import dataclass
|
| 3 |
from itertools import groupby
|
| 4 |
|
| 5 |
from torch.nn import functional as F
|
|
@@ -10,12 +9,14 @@ from langchain_core.documents import Document
|
|
| 10 |
from elasticsearch import Elasticsearch
|
| 11 |
|
| 12 |
from ask_candid.retrieval.sparse_lexical import SpladeEncoder
|
| 13 |
-
from ask_candid.retrieval.sources.
|
| 14 |
-
from ask_candid.retrieval.sources.
|
| 15 |
-
from ask_candid.retrieval.sources.
|
| 16 |
-
from ask_candid.retrieval.sources.
|
| 17 |
-
from ask_candid.retrieval.sources.
|
| 18 |
-
from ask_candid.retrieval.sources.
|
|
|
|
|
|
|
| 19 |
from ask_candid.services.small_lm import CandidSLM
|
| 20 |
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA, NEWS_ELASTIC
|
| 21 |
from ask_candid.base.config.data import DataIndices, ALL_INDICES
|
|
@@ -23,17 +24,6 @@ from ask_candid.base.config.data import DataIndices, ALL_INDICES
|
|
| 23 |
encoder = SpladeEncoder()
|
| 24 |
|
| 25 |
|
| 26 |
-
@dataclass
|
| 27 |
-
class ElasticHitsResult:
|
| 28 |
-
"""Dataclass for Elasticsearch hits results
|
| 29 |
-
"""
|
| 30 |
-
index: str
|
| 31 |
-
id: Any
|
| 32 |
-
score: float
|
| 33 |
-
source: Dict[str, Any]
|
| 34 |
-
inner_hits: Dict[str, Any]
|
| 35 |
-
|
| 36 |
-
|
| 37 |
class RetrieverInput(BaseModel):
|
| 38 |
"""Input to the Elasticsearch retriever."""
|
| 39 |
user_input: str = Field(description="query to look up in retriever")
|
|
@@ -101,7 +91,7 @@ def news_query_builder(query: str) -> Dict[str, Any]:
|
|
| 101 |
tokens = encoder.token_expand(query)
|
| 102 |
|
| 103 |
query = {
|
| 104 |
-
"_source": ["id", "link", "title", "content"],
|
| 105 |
"query": {
|
| 106 |
"bool": {
|
| 107 |
"filter": [
|
|
@@ -150,27 +140,27 @@ def query_builder(query: str, indices: List[DataIndices]) -> Tuple[List[Dict[str
|
|
| 150 |
if index == "issuelab":
|
| 151 |
q = build_sparse_vector_query(query=query, fields=IssueLabConfig.text_fields)
|
| 152 |
q["_source"] = {"excludes": ["embeddings"]}
|
| 153 |
-
q["size"] =
|
| 154 |
queries.extend([{"index": IssueLabConfig.index_name}, q])
|
| 155 |
elif index == "youtube":
|
| 156 |
q = build_sparse_vector_query(query=query, fields=YoutubeConfig.text_fields)
|
| 157 |
q["_source"] = {"excludes": ["embeddings", *YoutubeConfig.excluded_fields]}
|
| 158 |
-
q["size"] =
|
| 159 |
queries.extend([{"index": YoutubeConfig.index_name}, q])
|
| 160 |
elif index == "candid_blog":
|
| 161 |
q = build_sparse_vector_query(query=query, fields=CandidBlogConfig.text_fields)
|
| 162 |
q["_source"] = {"excludes": ["embeddings"]}
|
| 163 |
-
q["size"] =
|
| 164 |
queries.extend([{"index": CandidBlogConfig.index_name}, q])
|
| 165 |
elif index == "candid_learning":
|
| 166 |
q = build_sparse_vector_query(query=query, fields=CandidLearningConfig.text_fields)
|
| 167 |
q["_source"] = {"excludes": ["embeddings"]}
|
| 168 |
-
q["size"] =
|
| 169 |
queries.extend([{"index": CandidLearningConfig.index_name}, q])
|
| 170 |
elif index == "candid_help":
|
| 171 |
q = build_sparse_vector_query(query=query, fields=CandidHelpConfig.text_fields)
|
| 172 |
q["_source"] = {"excludes": ["embeddings"]}
|
| 173 |
-
q["size"] =
|
| 174 |
queries.extend([{"index": CandidHelpConfig.index_name}, q])
|
| 175 |
elif index == "news":
|
| 176 |
q = news_query_builder(query=query)
|
|
@@ -199,12 +189,18 @@ def multi_search(
|
|
| 199 |
def _msearch_response_generator(responses: List[Dict[str, Any]]) -> Iterator[ElasticHitsResult]:
|
| 200 |
for query_group in responses:
|
| 201 |
for h in query_group.get("hits", {}).get("hits", []):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
yield ElasticHitsResult(
|
| 203 |
index=h["_index"],
|
| 204 |
id=h["_id"],
|
| 205 |
score=h["_score"],
|
| 206 |
source=h["_source"],
|
| 207 |
-
inner_hits=
|
| 208 |
)
|
| 209 |
|
| 210 |
results = []
|
|
@@ -264,6 +260,10 @@ def retrieved_text(hits: Dict[str, Any]) -> str:
|
|
| 264 |
|
| 265 |
text = []
|
| 266 |
for _, v in hits.items():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
for h in (v.get("hits", {}).get("hits") or []):
|
| 268 |
for _, field in h.get("fields", {}).items():
|
| 269 |
for chunk in field:
|
|
@@ -298,7 +298,8 @@ def cosine_rescore(query: str, contexts: List[str]) -> List[float]:
|
|
| 298 |
|
| 299 |
def reranker(
|
| 300 |
query_results: Iterable[ElasticHitsResult],
|
| 301 |
-
search_text: Optional[str] = None
|
|
|
|
| 302 |
) -> Iterator[ElasticHitsResult]:
|
| 303 |
"""Reranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
|
| 304 |
This will shuffle results
|
|
@@ -327,58 +328,13 @@ def reranker(
|
|
| 327 |
text = retrieved_text(d.inner_hits)
|
| 328 |
texts.append(text)
|
| 329 |
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
yield from sorted(results, key=lambda x: x.score, reverse=True)
|
| 336 |
|
| 337 |
-
|
| 338 |
-
def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
|
| 339 |
-
"""Pads the relevant chunk of text with context before and after
|
| 340 |
-
|
| 341 |
-
Parameters
|
| 342 |
-
----------
|
| 343 |
-
field_name : str
|
| 344 |
-
a field with the long text that was chunked into pieces
|
| 345 |
-
hit : ElasticHitsResult
|
| 346 |
-
context_length : int, optional
|
| 347 |
-
length of text to add before and after the chunk, by default 1024
|
| 348 |
-
|
| 349 |
-
Returns
|
| 350 |
-
-------
|
| 351 |
-
str
|
| 352 |
-
longer chunks stuffed together
|
| 353 |
-
"""
|
| 354 |
-
|
| 355 |
-
chunks = []
|
| 356 |
-
# NOTE chunks have tokens, long text is a normal text, but may contain html that also gets weird after tokenization
|
| 357 |
-
long_text = hit.source.get(f"{field_name}", "")
|
| 358 |
-
long_text = long_text.lower()
|
| 359 |
-
inner_hits_field = f"embeddings.{field_name}.chunks"
|
| 360 |
-
found_chunks = hit.inner_hits.get(inner_hits_field, {})
|
| 361 |
-
if found_chunks:
|
| 362 |
-
hits = found_chunks.get("hits", {}).get("hits", [])
|
| 363 |
-
for h in hits:
|
| 364 |
-
chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]
|
| 365 |
-
|
| 366 |
-
# cutting the middle because we may have tokenizing artifacts there
|
| 367 |
-
chunk = chunk[3: -3]
|
| 368 |
-
|
| 369 |
-
if add_context:
|
| 370 |
-
# Find the start and end indices of the chunk in the large text
|
| 371 |
-
start_index = long_text.find(chunk[:20])
|
| 372 |
-
|
| 373 |
-
# Chunk is found
|
| 374 |
-
if start_index != -1:
|
| 375 |
-
end_index = start_index + len(chunk)
|
| 376 |
-
pre_start_index = max(0, start_index - context_length)
|
| 377 |
-
post_end_index = min(len(long_text), end_index + context_length)
|
| 378 |
-
chunks.append(long_text[pre_start_index:post_end_index])
|
| 379 |
-
else:
|
| 380 |
-
chunks.append(chunk)
|
| 381 |
-
return '\n\n'.join(chunks)
|
| 382 |
|
| 383 |
|
| 384 |
def process_hit(hit: ElasticHitsResult) -> Union[Document, None]:
|
|
@@ -394,94 +350,17 @@ def process_hit(hit: ElasticHitsResult) -> Union[Document, None]:
|
|
| 394 |
"""
|
| 395 |
|
| 396 |
if "issuelab-elser" in hit.index:
|
| 397 |
-
|
| 398 |
-
description = hit.source.get("description", "")
|
| 399 |
-
combined_issuelab_findings = hit.source.get("combined_issuelab_findings", "")
|
| 400 |
-
# we only need to process long texts
|
| 401 |
-
chunks_with_context_txt = get_context("content", hit, context_length=12)
|
| 402 |
-
doc = Document(
|
| 403 |
-
page_content='\n\n'.join([
|
| 404 |
-
combined_item_description,
|
| 405 |
-
combined_issuelab_findings,
|
| 406 |
-
description,
|
| 407 |
-
chunks_with_context_txt
|
| 408 |
-
]),
|
| 409 |
-
metadata={
|
| 410 |
-
"title": hit.source["title"],
|
| 411 |
-
"source": "IssueLab",
|
| 412 |
-
"source_id": hit.source["resource_id"],
|
| 413 |
-
"url": hit.source.get("permalink", "")
|
| 414 |
-
}
|
| 415 |
-
)
|
| 416 |
elif "youtube" in hit.index:
|
| 417 |
-
|
| 418 |
-
# we only need to process long texts
|
| 419 |
-
description_cleaned_with_context_txt = get_context("description_cleaned", hit, context_length=12)
|
| 420 |
-
captions_cleaned_with_context_txt = get_context("captions_cleaned", hit, context_length=12)
|
| 421 |
-
doc = Document(
|
| 422 |
-
page_content='\n\n'.join([title, description_cleaned_with_context_txt, captions_cleaned_with_context_txt]),
|
| 423 |
-
metadata={
|
| 424 |
-
"title": title,
|
| 425 |
-
"source": "Candid YouTube",
|
| 426 |
-
"source_id": hit.source['video_id'],
|
| 427 |
-
"url": f"https://www.youtube.com/watch?v={hit.source['video_id']}"
|
| 428 |
-
}
|
| 429 |
-
)
|
| 430 |
elif "candid-blog" in hit.index:
|
| 431 |
-
|
| 432 |
-
title = hit.source.get("title", "")
|
| 433 |
-
# we only need to process long text
|
| 434 |
-
content_with_context_txt = get_context("content", hit, context_length=12, add_context=False)
|
| 435 |
-
authors = get_context("authors_text", hit, context_length=12, add_context=False)
|
| 436 |
-
tags = hit.source.get("title_summary_tags", "")
|
| 437 |
-
doc = Document(
|
| 438 |
-
page_content='\n\n'.join([title, excerpt, content_with_context_txt, authors, tags]),
|
| 439 |
-
metadata={
|
| 440 |
-
"title": title,
|
| 441 |
-
"source": "Candid Blog",
|
| 442 |
-
"source_id": hit.source["id"],
|
| 443 |
-
"url": hit.source["link"]
|
| 444 |
-
}
|
| 445 |
-
)
|
| 446 |
elif "candid-learning" in hit.index:
|
| 447 |
-
|
| 448 |
-
content_with_context_txt = get_context("content", hit, context_length=12)
|
| 449 |
-
training_topics = hit.source.get("training_topics", "")
|
| 450 |
-
staff_recommendations = hit.source.get("staff_recommendations", "")
|
| 451 |
-
|
| 452 |
-
doc = Document(
|
| 453 |
-
page_content='\n\n'.join([title, staff_recommendations, training_topics, content_with_context_txt]),
|
| 454 |
-
metadata={
|
| 455 |
-
"title": hit.source["title"],
|
| 456 |
-
"source": "Candid Learning",
|
| 457 |
-
"source_id": hit.source["post_id"],
|
| 458 |
-
"url": hit.source.get("url", "")
|
| 459 |
-
}
|
| 460 |
-
)
|
| 461 |
elif "candid-help" in hit.index:
|
| 462 |
-
|
| 463 |
-
content_with_context_txt = get_context("content", hit, context_length=12)
|
| 464 |
-
combined_article_description = hit.source.get("combined_article_description", "")
|
| 465 |
-
|
| 466 |
-
doc = Document(
|
| 467 |
-
page_content='\n\n'.join([combined_article_description, content_with_context_txt]),
|
| 468 |
-
metadata={
|
| 469 |
-
"title": title,
|
| 470 |
-
"source": "Candid Help",
|
| 471 |
-
"source_id": hit.source["id"],
|
| 472 |
-
"url": hit.source.get("link", "")
|
| 473 |
-
}
|
| 474 |
-
)
|
| 475 |
elif "news" in hit.index:
|
| 476 |
-
doc =
|
| 477 |
-
page_content='\n\n'.join([hit.source.get("title", ""), hit.source.get("content", "")]),
|
| 478 |
-
metadata={
|
| 479 |
-
"title": hit.source.get("title", ""),
|
| 480 |
-
"source": "Candid News",
|
| 481 |
-
"source_id": hit.source["id"],
|
| 482 |
-
"url": hit.source.get("link", "")
|
| 483 |
-
}
|
| 484 |
-
)
|
| 485 |
else:
|
| 486 |
doc = None
|
| 487 |
return doc
|
|
|
|
| 1 |
from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Union, Any
|
|
|
|
| 2 |
from itertools import groupby
|
| 3 |
|
| 4 |
from torch.nn import functional as F
|
|
|
|
| 9 |
from elasticsearch import Elasticsearch
|
| 10 |
|
| 11 |
from ask_candid.retrieval.sparse_lexical import SpladeEncoder
|
| 12 |
+
from ask_candid.retrieval.sources.schema import ElasticHitsResult
|
| 13 |
+
from ask_candid.retrieval.sources.issuelab import IssueLabConfig, process_issuelab_hit
|
| 14 |
+
from ask_candid.retrieval.sources.youtube import YoutubeConfig, process_youtube_hit
|
| 15 |
+
from ask_candid.retrieval.sources.candid_blog import CandidBlogConfig, process_blog_hit
|
| 16 |
+
from ask_candid.retrieval.sources.candid_learning import CandidLearningConfig, process_learning_hit
|
| 17 |
+
from ask_candid.retrieval.sources.candid_help import CandidHelpConfig, process_help_hit
|
| 18 |
+
from ask_candid.retrieval.sources.candid_news import CandidNewsConfig, process_news_hit
|
| 19 |
+
|
| 20 |
from ask_candid.services.small_lm import CandidSLM
|
| 21 |
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA, NEWS_ELASTIC
|
| 22 |
from ask_candid.base.config.data import DataIndices, ALL_INDICES
|
|
|
|
| 24 |
encoder = SpladeEncoder()
|
| 25 |
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
class RetrieverInput(BaseModel):
|
| 28 |
"""Input to the Elasticsearch retriever."""
|
| 29 |
user_input: str = Field(description="query to look up in retriever")
|
|
|
|
| 91 |
tokens = encoder.token_expand(query)
|
| 92 |
|
| 93 |
query = {
|
| 94 |
+
"_source": ["id", "link", "title", "content", "site_name"],
|
| 95 |
"query": {
|
| 96 |
"bool": {
|
| 97 |
"filter": [
|
|
|
|
| 140 |
if index == "issuelab":
|
| 141 |
q = build_sparse_vector_query(query=query, fields=IssueLabConfig.text_fields)
|
| 142 |
q["_source"] = {"excludes": ["embeddings"]}
|
| 143 |
+
q["size"] = 2
|
| 144 |
queries.extend([{"index": IssueLabConfig.index_name}, q])
|
| 145 |
elif index == "youtube":
|
| 146 |
q = build_sparse_vector_query(query=query, fields=YoutubeConfig.text_fields)
|
| 147 |
q["_source"] = {"excludes": ["embeddings", *YoutubeConfig.excluded_fields]}
|
| 148 |
+
q["size"] = 5
|
| 149 |
queries.extend([{"index": YoutubeConfig.index_name}, q])
|
| 150 |
elif index == "candid_blog":
|
| 151 |
q = build_sparse_vector_query(query=query, fields=CandidBlogConfig.text_fields)
|
| 152 |
q["_source"] = {"excludes": ["embeddings"]}
|
| 153 |
+
q["size"] = 5
|
| 154 |
queries.extend([{"index": CandidBlogConfig.index_name}, q])
|
| 155 |
elif index == "candid_learning":
|
| 156 |
q = build_sparse_vector_query(query=query, fields=CandidLearningConfig.text_fields)
|
| 157 |
q["_source"] = {"excludes": ["embeddings"]}
|
| 158 |
+
q["size"] = 5
|
| 159 |
queries.extend([{"index": CandidLearningConfig.index_name}, q])
|
| 160 |
elif index == "candid_help":
|
| 161 |
q = build_sparse_vector_query(query=query, fields=CandidHelpConfig.text_fields)
|
| 162 |
q["_source"] = {"excludes": ["embeddings"]}
|
| 163 |
+
q["size"] = 5
|
| 164 |
queries.extend([{"index": CandidHelpConfig.index_name}, q])
|
| 165 |
elif index == "news":
|
| 166 |
q = news_query_builder(query=query)
|
|
|
|
| 189 |
def _msearch_response_generator(responses: List[Dict[str, Any]]) -> Iterator[ElasticHitsResult]:
|
| 190 |
for query_group in responses:
|
| 191 |
for h in query_group.get("hits", {}).get("hits", []):
|
| 192 |
+
inner_hits = h.get("inner_hits", {})
|
| 193 |
+
|
| 194 |
+
if not inner_hits:
|
| 195 |
+
if "news" in h.get("_index"):
|
| 196 |
+
inner_hits = {"text": h.get("_source", {}).get("content")}
|
| 197 |
+
|
| 198 |
yield ElasticHitsResult(
|
| 199 |
index=h["_index"],
|
| 200 |
id=h["_id"],
|
| 201 |
score=h["_score"],
|
| 202 |
source=h["_source"],
|
| 203 |
+
inner_hits=inner_hits
|
| 204 |
)
|
| 205 |
|
| 206 |
results = []
|
|
|
|
| 260 |
|
| 261 |
text = []
|
| 262 |
for _, v in hits.items():
|
| 263 |
+
if _ == "text":
|
| 264 |
+
text.append(v)
|
| 265 |
+
continue
|
| 266 |
+
|
| 267 |
for h in (v.get("hits", {}).get("hits") or []):
|
| 268 |
for _, field in h.get("fields", {}).items():
|
| 269 |
for chunk in field:
|
|
|
|
| 298 |
|
| 299 |
def reranker(
|
| 300 |
query_results: Iterable[ElasticHitsResult],
|
| 301 |
+
search_text: Optional[str] = None,
|
| 302 |
+
max_num_results: int = 10
|
| 303 |
) -> Iterator[ElasticHitsResult]:
|
| 304 |
"""Reranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
|
| 305 |
This will shuffle results
|
|
|
|
| 328 |
text = retrieved_text(d.inner_hits)
|
| 329 |
texts.append(text)
|
| 330 |
|
| 331 |
+
if search_text and len(texts) == len(results):
|
| 332 |
+
# scores = cosine_rescore(search_text, texts)
|
| 333 |
+
scores = encoder.query_reranking(query=search_text, documents=texts)
|
| 334 |
+
for r, s in zip(results, scores):
|
| 335 |
+
r.score = s
|
|
|
|
| 336 |
|
| 337 |
+
yield from sorted(results, key=lambda x: x.score, reverse=True)[:max_num_results]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
|
| 339 |
|
| 340 |
def process_hit(hit: ElasticHitsResult) -> Union[Document, None]:
|
|
|
|
| 350 |
"""
|
| 351 |
|
| 352 |
if "issuelab-elser" in hit.index:
|
| 353 |
+
doc = process_issuelab_hit(hit)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
elif "youtube" in hit.index:
|
| 355 |
+
doc = process_youtube_hit(hit)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
elif "candid-blog" in hit.index:
|
| 357 |
+
doc = process_blog_hit(hit)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 358 |
elif "candid-learning" in hit.index:
|
| 359 |
+
doc = process_learning_hit(hit)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
elif "candid-help" in hit.index:
|
| 361 |
+
doc = process_help_hit(hit)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
elif "news" in hit.index:
|
| 363 |
+
doc = process_news_hit(hit)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
else:
|
| 365 |
doc = None
|
| 366 |
return doc
|
ask_candid/retrieval/sources/candid_blog.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
| 1 |
from typing import Dict, Any
|
| 2 |
-
from ask_candid.retrieval.sources.schema import ElasticSourceConfig
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
CandidBlogConfig = ElasticSourceConfig(
|
| 6 |
index_name="search-semantic-candid-blog",
|
|
@@ -8,6 +11,24 @@ CandidBlogConfig = ElasticSourceConfig(
|
|
| 8 |
)
|
| 9 |
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
| 12 |
url = f"{doc['link']}"
|
| 13 |
fields = ["title", "excerpt"]
|
|
|
|
| 1 |
from typing import Dict, Any
|
|
|
|
| 2 |
|
| 3 |
+
from langchain_core.documents import Document
|
| 4 |
+
|
| 5 |
+
from ask_candid.retrieval.sources.schema import ElasticSourceConfig, ElasticHitsResult
|
| 6 |
+
from ask_candid.retrieval.sources.utils import get_context
|
| 7 |
|
| 8 |
CandidBlogConfig = ElasticSourceConfig(
|
| 9 |
index_name="search-semantic-candid-blog",
|
|
|
|
| 11 |
)
|
| 12 |
|
| 13 |
|
| 14 |
+
def process_blog_hit(hit: ElasticHitsResult) -> Document:
|
| 15 |
+
excerpt = hit.source.get("excerpt", "")
|
| 16 |
+
title = hit.source.get("title", "")
|
| 17 |
+
# we only need to process long text
|
| 18 |
+
content_with_context_txt = get_context("content", hit, context_length=12, add_context=False)
|
| 19 |
+
authors = get_context("authors_text", hit, context_length=12, add_context=False)
|
| 20 |
+
tags = hit.source.get("title_summary_tags", "")
|
| 21 |
+
return Document(
|
| 22 |
+
page_content='\n\n'.join([title, excerpt, content_with_context_txt, authors, tags]),
|
| 23 |
+
metadata={
|
| 24 |
+
"title": title,
|
| 25 |
+
"source": "Candid Blog",
|
| 26 |
+
"source_id": hit.source["id"],
|
| 27 |
+
"url": hit.source["link"]
|
| 28 |
+
}
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
| 33 |
url = f"{doc['link']}"
|
| 34 |
fields = ["title", "excerpt"]
|
ask_candid/retrieval/sources/candid_help.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
| 1 |
from typing import Dict, Any
|
| 2 |
-
from ask_candid.retrieval.sources.schema import ElasticSourceConfig
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
CandidHelpConfig = ElasticSourceConfig(
|
| 6 |
index_name="search-semantic-candid-help-elser_ve1",
|
|
@@ -8,6 +11,22 @@ CandidHelpConfig = ElasticSourceConfig(
|
|
| 8 |
)
|
| 9 |
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
| 12 |
url = f"{doc['link']}"
|
| 13 |
fields = ["title", "summary"]
|
|
|
|
| 1 |
from typing import Dict, Any
|
|
|
|
| 2 |
|
| 3 |
+
from langchain_core.documents import Document
|
| 4 |
+
|
| 5 |
+
from ask_candid.retrieval.sources.schema import ElasticSourceConfig, ElasticHitsResult
|
| 6 |
+
from ask_candid.retrieval.sources.utils import get_context
|
| 7 |
|
| 8 |
CandidHelpConfig = ElasticSourceConfig(
|
| 9 |
index_name="search-semantic-candid-help-elser_ve1",
|
|
|
|
| 11 |
)
|
| 12 |
|
| 13 |
|
| 14 |
+
def process_help_hit(hit: ElasticHitsResult) -> Document:
|
| 15 |
+
title = hit.source.get("title", "")
|
| 16 |
+
content_with_context_txt = get_context("content", hit, context_length=12)
|
| 17 |
+
combined_article_description = hit.source.get("combined_article_description", "")
|
| 18 |
+
|
| 19 |
+
return Document(
|
| 20 |
+
page_content='\n\n'.join([combined_article_description, content_with_context_txt]),
|
| 21 |
+
metadata={
|
| 22 |
+
"title": title,
|
| 23 |
+
"source": "Candid Help",
|
| 24 |
+
"source_id": hit.source["id"],
|
| 25 |
+
"url": hit.source.get("link", "")
|
| 26 |
+
}
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
| 31 |
url = f"{doc['link']}"
|
| 32 |
fields = ["title", "summary"]
|
ask_candid/retrieval/sources/candid_learning.py
CHANGED
|
@@ -1,5 +1,9 @@
|
|
| 1 |
from typing import Dict, Any
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
|
| 5 |
CandidLearningConfig = ElasticSourceConfig(
|
|
@@ -8,6 +12,23 @@ CandidLearningConfig = ElasticSourceConfig(
|
|
| 8 |
)
|
| 9 |
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
| 12 |
url = f"{doc['url']}"
|
| 13 |
fields = ["title", "excerpt"]
|
|
|
|
| 1 |
from typing import Dict, Any
|
| 2 |
+
|
| 3 |
+
from langchain_core.documents import Document
|
| 4 |
+
|
| 5 |
+
from ask_candid.retrieval.sources.schema import ElasticSourceConfig, ElasticHitsResult
|
| 6 |
+
from ask_candid.retrieval.sources.utils import get_context
|
| 7 |
|
| 8 |
|
| 9 |
CandidLearningConfig = ElasticSourceConfig(
|
|
|
|
| 12 |
)
|
| 13 |
|
| 14 |
|
| 15 |
+
def process_learning_hit(hit: ElasticHitsResult) -> Document:
|
| 16 |
+
title = hit.source.get("title", "")
|
| 17 |
+
content_with_context_txt = get_context("content", hit, context_length=12)
|
| 18 |
+
training_topics = hit.source.get("training_topics", "")
|
| 19 |
+
staff_recommendations = hit.source.get("staff_recommendations", "")
|
| 20 |
+
|
| 21 |
+
return Document(
|
| 22 |
+
page_content='\n\n'.join([title, staff_recommendations, training_topics, content_with_context_txt]),
|
| 23 |
+
metadata={
|
| 24 |
+
"title": hit.source["title"],
|
| 25 |
+
"source": "Candid Learning",
|
| 26 |
+
"source_id": hit.source["post_id"],
|
| 27 |
+
"url": hit.source.get("url", "")
|
| 28 |
+
}
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
| 33 |
url = f"{doc['url']}"
|
| 34 |
fields = ["title", "excerpt"]
|
ask_candid/retrieval/sources/candid_news.py
CHANGED
|
@@ -1,7 +1,20 @@
|
|
| 1 |
-
from
|
| 2 |
|
|
|
|
| 3 |
|
| 4 |
CandidNewsConfig = ElasticSourceConfig(
|
| 5 |
index_name="news_1",
|
| 6 |
text_fields=("title", "content")
|
| 7 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_core.documents import Document
|
| 2 |
|
| 3 |
+
from ask_candid.retrieval.sources.schema import ElasticSourceConfig, ElasticHitsResult
|
| 4 |
|
| 5 |
CandidNewsConfig = ElasticSourceConfig(
|
| 6 |
index_name="news_1",
|
| 7 |
text_fields=("title", "content")
|
| 8 |
)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def process_news_hit(hit: ElasticHitsResult) -> Document:
|
| 12 |
+
return Document(
|
| 13 |
+
page_content='\n\n'.join([hit.source.get("title", ""), hit.source.get("content", "")]),
|
| 14 |
+
metadata={
|
| 15 |
+
"title": hit.source.get("title", ""),
|
| 16 |
+
"source": hit.source.get("site_name") or "Candid News",
|
| 17 |
+
"source_id": hit.source["id"],
|
| 18 |
+
"url": hit.source.get("link", "")
|
| 19 |
+
}
|
| 20 |
+
)
|
ask_candid/retrieval/sources/issuelab.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
| 1 |
from typing import Dict, Any
|
| 2 |
-
from ask_candid.retrieval.sources.schema import ElasticSourceConfig
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
IssueLabConfig = ElasticSourceConfig(
|
| 6 |
index_name="search-semantic-issuelab-elser_ve2",
|
|
@@ -8,11 +11,33 @@ IssueLabConfig = ElasticSourceConfig(
|
|
| 8 |
)
|
| 9 |
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
def issuelab_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
| 12 |
chunks_html = ""
|
| 13 |
if show_chunks:
|
| 14 |
cleaned_text = []
|
| 15 |
-
for
|
| 16 |
hits = v["hits"]["hits"]
|
| 17 |
for h in hits:
|
| 18 |
for k1, v1 in h["fields"].items():
|
|
|
|
| 1 |
from typing import Dict, Any
|
|
|
|
| 2 |
|
| 3 |
+
from langchain_core.documents import Document
|
| 4 |
+
|
| 5 |
+
from ask_candid.retrieval.sources.schema import ElasticSourceConfig, ElasticHitsResult
|
| 6 |
+
from ask_candid.retrieval.sources.utils import get_context
|
| 7 |
|
| 8 |
IssueLabConfig = ElasticSourceConfig(
|
| 9 |
index_name="search-semantic-issuelab-elser_ve2",
|
|
|
|
| 11 |
)
|
| 12 |
|
| 13 |
|
| 14 |
+
def process_issuelab_hit(hit: ElasticHitsResult) -> Document:
|
| 15 |
+
combined_item_description = hit.source.get("combined_item_description", "") # title inside
|
| 16 |
+
description = hit.source.get("description", "")
|
| 17 |
+
combined_issuelab_findings = hit.source.get("combined_issuelab_findings", "")
|
| 18 |
+
# we only need to process long texts
|
| 19 |
+
chunks_with_context_txt = get_context("content", hit, context_length=12)
|
| 20 |
+
return Document(
|
| 21 |
+
page_content='\n\n'.join([
|
| 22 |
+
combined_item_description,
|
| 23 |
+
combined_issuelab_findings,
|
| 24 |
+
description,
|
| 25 |
+
chunks_with_context_txt
|
| 26 |
+
]),
|
| 27 |
+
metadata={
|
| 28 |
+
"title": hit.source["title"],
|
| 29 |
+
"source": "IssueLab",
|
| 30 |
+
"source_id": hit.source["resource_id"],
|
| 31 |
+
"url": hit.source.get("permalink", "")
|
| 32 |
+
}
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
def issuelab_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
| 37 |
chunks_html = ""
|
| 38 |
if show_chunks:
|
| 39 |
cleaned_text = []
|
| 40 |
+
for _, v in doc["inner_hits"].items():
|
| 41 |
hits = v["hits"]["hits"]
|
| 42 |
for h in hits:
|
| 43 |
for k1, v1 in h["fields"].items():
|
ask_candid/retrieval/sources/schema.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from typing import Tuple, Optional
|
| 2 |
from dataclasses import dataclass, field
|
| 3 |
|
| 4 |
|
|
@@ -7,3 +7,14 @@ class ElasticSourceConfig:
|
|
| 7 |
index_name: str
|
| 8 |
text_fields: Tuple[str]
|
| 9 |
excluded_fields: Optional[Tuple[str]] = field(default_factory=tuple)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, Dict, Optional, Any
|
| 2 |
from dataclasses import dataclass, field
|
| 3 |
|
| 4 |
|
|
|
|
| 7 |
index_name: str
|
| 8 |
text_fields: Tuple[str]
|
| 9 |
excluded_fields: Optional[Tuple[str]] = field(default_factory=tuple)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class ElasticHitsResult:
|
| 14 |
+
"""Dataclass for Elasticsearch hits results
|
| 15 |
+
"""
|
| 16 |
+
index: str
|
| 17 |
+
id: Any
|
| 18 |
+
score: float
|
| 19 |
+
source: Dict[str, Any]
|
| 20 |
+
inner_hits: Dict[str, Any]
|
ask_candid/retrieval/sources/utils.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ask_candid.retrieval.sources.schema import ElasticHitsResult
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
|
| 5 |
+
"""Pads the relevant chunk of text with context before and after
|
| 6 |
+
|
| 7 |
+
Parameters
|
| 8 |
+
----------
|
| 9 |
+
field_name : str
|
| 10 |
+
a field with the long text that was chunked into pieces
|
| 11 |
+
hit : ElasticHitsResult
|
| 12 |
+
context_length : int, optional
|
| 13 |
+
length of text to add before and after the chunk, by default 1024
|
| 14 |
+
|
| 15 |
+
Returns
|
| 16 |
+
-------
|
| 17 |
+
str
|
| 18 |
+
longer chunks stuffed together
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
chunks = []
|
| 22 |
+
# NOTE chunks have tokens, long text is a normal text, but may contain html that also gets weird after tokenization
|
| 23 |
+
long_text = hit.source.get(f"{field_name}", "")
|
| 24 |
+
long_text = long_text.lower()
|
| 25 |
+
inner_hits_field = f"embeddings.{field_name}.chunks"
|
| 26 |
+
found_chunks = hit.inner_hits.get(inner_hits_field, {})
|
| 27 |
+
if found_chunks:
|
| 28 |
+
hits = found_chunks.get("hits", {}).get("hits", [])
|
| 29 |
+
for h in hits:
|
| 30 |
+
chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]
|
| 31 |
+
|
| 32 |
+
# cutting the middle because we may have tokenizing artifacts there
|
| 33 |
+
chunk = chunk[3: -3]
|
| 34 |
+
|
| 35 |
+
if add_context:
|
| 36 |
+
# Find the start and end indices of the chunk in the large text
|
| 37 |
+
start_index = long_text.find(chunk[:20])
|
| 38 |
+
|
| 39 |
+
# Chunk is found
|
| 40 |
+
if start_index != -1:
|
| 41 |
+
end_index = start_index + len(chunk)
|
| 42 |
+
pre_start_index = max(0, start_index - context_length)
|
| 43 |
+
post_end_index = min(len(long_text), end_index + context_length)
|
| 44 |
+
chunks.append(long_text[pre_start_index:post_end_index])
|
| 45 |
+
else:
|
| 46 |
+
chunks.append(chunk)
|
| 47 |
+
return '\n\n'.join(chunks)
|
ask_candid/retrieval/sources/youtube.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
| 1 |
from typing import Dict, Any
|
| 2 |
-
from ask_candid.retrieval.sources.schema import ElasticSourceConfig
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
YoutubeConfig = ElasticSourceConfig(
|
| 6 |
index_name="search-semantic-youtube-elser_ve1",
|
|
@@ -9,6 +12,22 @@ YoutubeConfig = ElasticSourceConfig(
|
|
| 9 |
)
|
| 10 |
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
| 13 |
url = f"https://www.youtube.com/watch?v={doc['video_id']}"
|
| 14 |
fields = ["title", "description_cleaned"]
|
|
|
|
| 1 |
from typing import Dict, Any
|
|
|
|
| 2 |
|
| 3 |
+
from langchain_core.documents import Document
|
| 4 |
+
|
| 5 |
+
from ask_candid.retrieval.sources.schema import ElasticSourceConfig, ElasticHitsResult
|
| 6 |
+
from ask_candid.retrieval.sources.utils import get_context
|
| 7 |
|
| 8 |
YoutubeConfig = ElasticSourceConfig(
|
| 9 |
index_name="search-semantic-youtube-elser_ve1",
|
|
|
|
| 12 |
)
|
| 13 |
|
| 14 |
|
| 15 |
+
def process_youtube_hit(hit: ElasticHitsResult) -> Document:
|
| 16 |
+
title = hit.source.get("title", "")
|
| 17 |
+
# we only need to process long texts
|
| 18 |
+
description_cleaned_with_context_txt = get_context("description_cleaned", hit, context_length=12)
|
| 19 |
+
captions_cleaned_with_context_txt = get_context("captions_cleaned", hit, context_length=12)
|
| 20 |
+
return Document(
|
| 21 |
+
page_content='\n\n'.join([title, description_cleaned_with_context_txt, captions_cleaned_with_context_txt]),
|
| 22 |
+
metadata={
|
| 23 |
+
"title": title,
|
| 24 |
+
"source": "Candid YouTube",
|
| 25 |
+
"source_id": hit.source['video_id'],
|
| 26 |
+
"url": f"https://www.youtube.com/watch?v={hit.source['video_id']}"
|
| 27 |
+
}
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
| 32 |
url = f"https://www.youtube.com/watch?v={doc['video_id']}"
|
| 33 |
fields = ["title", "description_cleaned"]
|
ask_candid/retrieval/sparse_lexical.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
-
from typing import Dict
|
| 2 |
|
| 3 |
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
|
|
|
| 4 |
import torch
|
| 5 |
|
| 6 |
|
|
@@ -14,14 +15,23 @@ class SpladeEncoder:
|
|
| 14 |
self.idx2token = {idx: token for token, idx in self.tokenizer.get_vocab().items()}
|
| 15 |
|
| 16 |
@torch.no_grad()
|
| 17 |
-
def
|
| 18 |
-
tokens = self.tokenizer(
|
| 19 |
output = self.model(**tokens)
|
| 20 |
-
|
| 21 |
vec = torch.max(
|
| 22 |
torch.log(1 + torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1),
|
| 23 |
dim=1
|
| 24 |
)[0].squeeze()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
cols = vec.nonzero().squeeze().cpu().tolist()
|
| 26 |
weights = vec[cols].cpu().tolist()
|
| 27 |
|
|
|
|
| 1 |
+
from typing import List, Dict
|
| 2 |
|
| 3 |
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
| 4 |
+
from torch.nn import functional as F
|
| 5 |
import torch
|
| 6 |
|
| 7 |
|
|
|
|
| 15 |
self.idx2token = {idx: token for token, idx in self.tokenizer.get_vocab().items()}
|
| 16 |
|
| 17 |
@torch.no_grad()
|
| 18 |
+
def forward(self, texts: List[str]):
|
| 19 |
+
tokens = self.tokenizer(texts, return_tensors='pt', truncation=True, padding=True)
|
| 20 |
output = self.model(**tokens)
|
|
|
|
| 21 |
vec = torch.max(
|
| 22 |
torch.log(1 + torch.relu(output.logits)) * tokens.attention_mask.unsqueeze(-1),
|
| 23 |
dim=1
|
| 24 |
)[0].squeeze()
|
| 25 |
+
return vec
|
| 26 |
+
|
| 27 |
+
def query_reranking(self, query: str, documents: List[str]):
|
| 28 |
+
vec = self.forward([query, *documents])
|
| 29 |
+
xQ = F.normalize(vec[:1], dim=-1, p=2.)
|
| 30 |
+
xD = F.normalize(vec[1:], dim=-1, p=2.)
|
| 31 |
+
return (xQ * xD).sum(dim=-1).cpu().tolist()
|
| 32 |
+
|
| 33 |
+
def token_expand(self, query: str) -> Dict[str, float]:
|
| 34 |
+
vec = self.forward([query])
|
| 35 |
cols = vec.nonzero().squeeze().cpu().tolist()
|
| 36 |
weights = vec[cols].cpu().tolist()
|
| 37 |
|
ask_candid/tools/elastic/index_search_tool.py
CHANGED
|
@@ -40,6 +40,7 @@ class SearchToolInput(BaseModel):
|
|
| 40 |
|
| 41 |
|
| 42 |
def elastic_search(
|
|
|
|
| 43 |
index_name: str,
|
| 44 |
query: str,
|
| 45 |
from_: int = 0,
|
|
@@ -107,9 +108,15 @@ def elastic_search(
|
|
| 107 |
return msg
|
| 108 |
|
| 109 |
|
| 110 |
-
def create_search_tool():
|
| 111 |
return StructuredTool.from_function(
|
| 112 |
-
elastic_search
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
name="elastic_index_search_tool",
|
| 114 |
description=(
|
| 115 |
"""This tool allows executing queries on an Elasticsearch index efficiently. Provide:
|
|
|
|
| 40 |
|
| 41 |
|
| 42 |
def elastic_search(
|
| 43 |
+
pcs_codes: dict,
|
| 44 |
index_name: str,
|
| 45 |
query: str,
|
| 46 |
from_: int = 0,
|
|
|
|
| 108 |
return msg
|
| 109 |
|
| 110 |
|
| 111 |
+
def create_search_tool(pcs_codes):
|
| 112 |
return StructuredTool.from_function(
|
| 113 |
+
func=lambda index_name, query, from_, size: elastic_search(
|
| 114 |
+
pcs_codes=pcs_codes,
|
| 115 |
+
index_name=index_name,
|
| 116 |
+
query=query,
|
| 117 |
+
from_=from_,
|
| 118 |
+
size=size,
|
| 119 |
+
),
|
| 120 |
name="elastic_index_search_tool",
|
| 121 |
description=(
|
| 122 |
"""This tool allows executing queries on an Elasticsearch index efficiently. Provide:
|
ask_candid/tools/question_reformulation.py
CHANGED
|
@@ -1,55 +1,55 @@
|
|
| 1 |
from langchain_core.prompts import ChatPromptTemplate
|
| 2 |
from langchain_core.output_parsers import StrOutputParser
|
|
|
|
| 3 |
|
|
|
|
| 4 |
|
| 5 |
-
def reformulate_question_using_history(state, llm, focus_on_recommendations=False):
|
| 6 |
-
"""
|
| 7 |
-
Transform the query to produce a better query with details from previous messages and emphasize
|
| 8 |
-
aspects important for recommendations if needed.
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
"""
|
|
|
|
| 19 |
print("---REFORMULATE THE USER INPUT---")
|
| 20 |
messages = state["messages"]
|
| 21 |
question = messages[-1].content
|
| 22 |
|
| 23 |
-
if len(messages) > 1:
|
| 24 |
if focus_on_recommendations:
|
| 25 |
-
prompt_text = """Given a chat history and the latest user input
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
{chat_history}
|
| 32 |
-
\n ------- \n
|
| 33 |
-
User input:
|
| 34 |
-
\n ------- \n
|
| 35 |
-
{question}
|
| 36 |
-
\n ------- \n
|
| 37 |
Reformulate the question without adding implications or assumptions about the user's needs or intentions.
|
| 38 |
Focus solely on clarifying any contextual details present in the original input."""
|
| 39 |
else:
|
| 40 |
-
prompt_text = """Given a chat history and the latest user input
|
| 41 |
-
which
|
| 42 |
-
|
| 43 |
-
Chat history:
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
User input:
|
| 48 |
-
\n ------- \n
|
| 49 |
-
{question}
|
| 50 |
-
\n ------- \n
|
| 51 |
-
Do NOT answer the question, \
|
| 52 |
-
just reformulate it if needed and otherwise return it as is.
|
| 53 |
"""
|
| 54 |
|
| 55 |
contextualize_q_prompt = ChatPromptTemplate([
|
|
@@ -58,7 +58,11 @@ def reformulate_question_using_history(state, llm, focus_on_recommendations=Fals
|
|
| 58 |
])
|
| 59 |
|
| 60 |
rag_chain = contextualize_q_prompt | llm | StrOutputParser()
|
| 61 |
-
new_question = rag_chain.invoke({"chat_history": messages, "question": question})
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
print(f"user asked: '{question}', agent reformulated the question basing on the chat history: {new_question}")
|
| 63 |
return {"messages": [new_question], "user_input" : question}
|
| 64 |
return {"messages": [question], "user_input" : question}
|
|
|
|
| 1 |
from langchain_core.prompts import ChatPromptTemplate
|
| 2 |
from langchain_core.output_parsers import StrOutputParser
|
| 3 |
+
from langchain_core.language_models.llms import LLM
|
| 4 |
|
| 5 |
+
from ask_candid.agents.schema import AgentState
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
+
def reformulate_question_using_history(
|
| 9 |
+
state: AgentState,
|
| 10 |
+
llm: LLM,
|
| 11 |
+
focus_on_recommendations: bool = False
|
| 12 |
+
) -> AgentState:
|
| 13 |
+
"""Transform the query to produce a better query with details from previous messages and emphasize aspects important
|
| 14 |
+
for recommendations if needed.
|
| 15 |
+
|
| 16 |
+
Parameters
|
| 17 |
+
----------
|
| 18 |
+
state : AgentState
|
| 19 |
+
The current state
|
| 20 |
+
llm : LLM
The language model that the reformulation prompt is piped into
|
| 21 |
+
focus_on_recommendations : bool, optional
|
| 22 |
+
Flag to determine if the reformulation should emphasize recommendation-relevant aspects such as geographies,
|
| 23 |
+
cause areas, etc., by default False
|
| 24 |
+
|
| 25 |
+
Returns
|
| 26 |
+
-------
|
| 27 |
+
AgentState
|
| 28 |
+
The updated state
|
| 29 |
"""
|
| 30 |
+
|
| 31 |
print("---REFORMULATE THE USER INPUT---")
|
| 32 |
messages = state["messages"]
|
| 33 |
question = messages[-1].content
|
| 34 |
|
| 35 |
+
if len(messages[:-1]) > 1: # messages before the latest input must hold more than just the system message
|
| 36 |
if focus_on_recommendations:
|
| 37 |
+
prompt_text = """Given a chat history and the latest user input which might reference context in the chat
|
| 38 |
+
history, especially geographic locations, cause areas and/or population groups, formulate a standalone input
|
| 39 |
+
which can be understood without the chat history.
|
| 40 |
+
Chat history: ```{chat_history}```
|
| 41 |
+
User input: ```{question}```
|
| 42 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
Reformulate the question without adding implications or assumptions about the user's needs or intentions.
|
| 44 |
Focus solely on clarifying any contextual details present in the original input."""
|
| 45 |
else:
|
| 46 |
+
prompt_text = """Given a chat history and the latest user input which might reference context in the chat
|
| 47 |
+
history, formulate a standalone input which can be understood without the chat history. Include hints as to
|
| 48 |
+
what the user is getting at given the context in the chat history.
|
| 49 |
+
Chat history: ```{chat_history}```
|
| 50 |
+
User input: ```{question}```
|
| 51 |
+
|
| 52 |
+
Do NOT answer the question, just reformulate it if needed and otherwise return it as is.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
"""
|
| 54 |
|
| 55 |
contextualize_q_prompt = ChatPromptTemplate([
|
|
|
|
| 58 |
])
|
| 59 |
|
| 60 |
rag_chain = contextualize_q_prompt | llm | StrOutputParser()
|
| 61 |
+
# new_question = rag_chain.invoke({"chat_history": messages, "question": question})
|
| 62 |
+
new_question = rag_chain.invoke({
|
| 63 |
+
"chat_history": '\n'.join(f"{m.type.upper()}: {m.content}" for m in messages[1:]),
|
| 64 |
+
"question": question
|
| 65 |
+
})
|
| 66 |
print(f"user asked: '{question}', agent reformulated the question basing on the chat history: {new_question}")
|
| 67 |
return {"messages": [new_question], "user_input" : question}
|
| 68 |
return {"messages": [question], "user_input" : question}
|