Spaces:
Running
Running
UI callbacks and style changes
Browse files- app.py +2 -0
- ask_candid/base/config/data.py +5 -14
- ask_candid/chat.py +6 -1
- ask_candid/graph.py +53 -28
- ask_candid/retrieval/elastic.py +60 -117
- ask_candid/retrieval/sources/candid_blog.py +7 -0
- ask_candid/retrieval/sources/candid_help.py +7 -0
- ask_candid/retrieval/sources/candid_learning.py +7 -0
- ask_candid/retrieval/sources/candid_news.py +7 -0
- ask_candid/retrieval/sources/issuelab.py +7 -0
- ask_candid/retrieval/sources/schema.py +9 -0
- ask_candid/retrieval/sources/youtube.py +8 -0
- ask_candid/tools/elastic/list_indices_tool.py +2 -1
- ask_candid/tools/org_seach.py +25 -12
- ask_candid/tools/search.py +86 -2
- ask_candid/utils.py +1 -1
app.py
CHANGED
|
@@ -147,6 +147,8 @@ def build_rag_chat() -> Tuple[LoggedComponents, gr.Blocks]:
|
|
| 147 |
show_copy_button=True,
|
| 148 |
show_share_button=None,
|
| 149 |
show_copy_all_button=False,
|
|
|
|
|
|
|
| 150 |
)
|
| 151 |
msg = gr.MultimodalTextbox(label="Your message", interactive=True)
|
| 152 |
thread_id = gr.Text(visible=False, value="", label="thread_id")
|
|
|
|
| 147 |
show_copy_button=True,
|
| 148 |
show_share_button=None,
|
| 149 |
show_copy_all_button=False,
|
| 150 |
+
autoscroll=True,
|
| 151 |
+
layout="panel",
|
| 152 |
)
|
| 153 |
msg = gr.MultimodalTextbox(label="Your message", interactive=True)
|
| 154 |
thread_id = gr.Text(visible=False, value="", label="thread_id")
|
ask_candid/base/config/data.py
CHANGED
|
@@ -1,21 +1,12 @@
|
|
| 1 |
-
|
| 2 |
-
"Mapping from plain name to Elasticsearch index name"
|
| 3 |
|
| 4 |
-
|
| 5 |
-
ISSUELAB_INDEX_ELSER = "search-semantic-issuelab-elser_ve2"
|
| 6 |
-
YOUTUBE_INDEX = "search-semantic-youtube_v1"
|
| 7 |
-
YOUTUBE_INDEX_ELSER = "search-semantic-youtube-elser_ve1"
|
| 8 |
-
CANDID_BLOG_INDEX = "search-semantic-candid-blog_v1"
|
| 9 |
-
CANDID_BLOG_INDEX_ELSER = "search-semantic-candid-blog"
|
| 10 |
-
CANDID_LEARNING_INDEX_ELSER = "search-semantic-candid-learning_ve1"
|
| 11 |
-
CANDID_HELP_INDEX_ELSER = "search-semantic-candid-help-elser_ve1"
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
ALL_INDICES = (
|
| 15 |
"issuelab",
|
| 16 |
"youtube",
|
| 17 |
"candid_blog",
|
| 18 |
"candid_learning",
|
| 19 |
"candid_help",
|
| 20 |
"news"
|
| 21 |
-
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Literal, get_args
|
|
|
|
| 2 |
|
| 3 |
+
DataIndices = Literal[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"issuelab",
|
| 5 |
"youtube",
|
| 6 |
"candid_blog",
|
| 7 |
"candid_learning",
|
| 8 |
"candid_help",
|
| 9 |
"news"
|
| 10 |
+
]
|
| 11 |
+
|
| 12 |
+
ALL_INDICES = get_args(DataIndices)
|
ask_candid/chat.py
CHANGED
|
@@ -29,7 +29,12 @@ def run_chat(
|
|
| 29 |
config = {"configurable": {"thread_id": thread_id}}
|
| 30 |
|
| 31 |
enable_recommendations = "Recommendation" in premium_features
|
| 32 |
-
workflow = build_compute_graph(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
memory = MemorySaver() # TODO: don't use for Prod
|
| 35 |
graph = workflow.compile(checkpointer=memory)
|
|
|
|
| 29 |
config = {"configurable": {"thread_id": thread_id}}
|
| 30 |
|
| 31 |
enable_recommendations = "Recommendation" in premium_features
|
| 32 |
+
workflow = build_compute_graph(
|
| 33 |
+
llm=llm,
|
| 34 |
+
indices=indices,
|
| 35 |
+
user_callback=gr.Info,
|
| 36 |
+
enable_recommendations=enable_recommendations
|
| 37 |
+
)
|
| 38 |
|
| 39 |
memory = MemorySaver() # TODO: don't use for Prod
|
| 40 |
graph = workflow.compile(checkpointer=memory)
|
ask_candid/graph.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from typing import List
|
| 2 |
from functools import partial
|
| 3 |
import logging
|
| 4 |
|
|
@@ -11,7 +11,6 @@ from langgraph.prebuilt import tools_condition, ToolNode
|
|
| 11 |
from langgraph.graph.state import StateGraph
|
| 12 |
from langgraph.constants import START, END
|
| 13 |
|
| 14 |
-
from ask_candid.retrieval.elastic import retriever_tool
|
| 15 |
from ask_candid.tools.recommendation import (
|
| 16 |
detect_intent_with_llm,
|
| 17 |
determine_context,
|
|
@@ -19,8 +18,9 @@ from ask_candid.tools.recommendation import (
|
|
| 19 |
)
|
| 20 |
from ask_candid.tools.question_reformulation import reformulate_question_using_history
|
| 21 |
from ask_candid.tools.org_seach import has_org_name, insert_org_link
|
| 22 |
-
from ask_candid.tools.search import search_agent
|
| 23 |
from ask_candid.agents.schema import AgentState
|
|
|
|
| 24 |
|
| 25 |
from ask_candid.utils import html_format_docs_chat
|
| 26 |
|
|
@@ -29,7 +29,11 @@ logger = logging.getLogger(__name__)
|
|
| 29 |
logger.setLevel(logging.INFO)
|
| 30 |
|
| 31 |
|
| 32 |
-
def generate_with_context(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
"""Generate answer.
|
| 34 |
|
| 35 |
Parameters
|
|
@@ -37,6 +41,8 @@ def generate_with_context(state: AgentState, llm: LLM) -> AgentState:
|
|
| 37 |
state : AgentState
|
| 38 |
The current state
|
| 39 |
llm : LLM
|
|
|
|
|
|
|
| 40 |
|
| 41 |
Returns
|
| 42 |
-------
|
|
@@ -45,14 +51,20 @@ def generate_with_context(state: AgentState, llm: LLM) -> AgentState:
|
|
| 45 |
"""
|
| 46 |
|
| 47 |
logger.info("---GENERATE ANSWER---")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
messages = state["messages"]
|
| 49 |
question = state["user_input"]
|
| 50 |
last_message = messages[-1]
|
| 51 |
|
| 52 |
sources_str = last_message.content
|
| 53 |
-
sources_list = last_message.artifact
|
| 54 |
-
# converting to html string
|
| 55 |
sources_html = html_format_docs_chat(sources_list)
|
|
|
|
| 56 |
if sources_list:
|
| 57 |
logger.info("---ADD SOURCES---")
|
| 58 |
state["messages"].append(BaseMessage(content=sources_html, type="HTML"))
|
|
@@ -97,13 +109,13 @@ def add_recommendations_pipeline_(
|
|
| 97 |
"""
|
| 98 |
|
| 99 |
# Nodes for recommendation functionalities
|
| 100 |
-
G.add_node("detect_intent_with_llm", partial(detect_intent_with_llm, llm=llm))
|
| 101 |
-
G.add_node("determine_context", determine_context)
|
| 102 |
-
G.add_node("make_recommendation", make_recommendation)
|
| 103 |
|
| 104 |
# Check for recommendation query first
|
| 105 |
# Execute until reaching END if user asks for recommendation
|
| 106 |
-
G.add_edge(reformulation_node_name, "detect_intent_with_llm")
|
| 107 |
G.add_conditional_edges(
|
| 108 |
source="detect_intent_with_llm",
|
| 109 |
path=lambda state: "determine_context" if state["intent"] in ["rfp", "funder"] else search_node_name,
|
|
@@ -112,24 +124,27 @@ def add_recommendations_pipeline_(
|
|
| 112 |
search_node_name: search_node_name
|
| 113 |
},
|
| 114 |
)
|
| 115 |
-
G.add_edge("determine_context", "make_recommendation")
|
| 116 |
-
G.add_edge("make_recommendation", END)
|
| 117 |
|
| 118 |
|
| 119 |
def build_compute_graph(
|
| 120 |
llm: LLM,
|
| 121 |
-
indices: List[
|
| 122 |
-
enable_recommendations: bool = False
|
|
|
|
| 123 |
) -> StateGraph:
|
| 124 |
"""Execution graph builder, the output is the execution flow for an interaction with the assistant.
|
| 125 |
|
| 126 |
Parameters
|
| 127 |
----------
|
| 128 |
llm : LLM
|
| 129 |
-
indices : List[
|
| 130 |
Semantic index names to search over
|
| 131 |
enable_recommendations : bool, optional
|
| 132 |
Set to `True` to allow the flow to generate recommendations based on context, by default False
|
|
|
|
|
|
|
| 133 |
|
| 134 |
Returns
|
| 135 |
-------
|
|
@@ -137,25 +152,35 @@ def build_compute_graph(
|
|
| 137 |
Execution graph
|
| 138 |
"""
|
| 139 |
|
| 140 |
-
candid_retriever_tool = retriever_tool(indices=indices)
|
| 141 |
retrieve = ToolNode([candid_retriever_tool])
|
| 142 |
tools = [candid_retriever_tool]
|
| 143 |
|
| 144 |
G = StateGraph(AgentState)
|
| 145 |
|
| 146 |
-
G.add_node(
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
G.add_node("
|
| 151 |
-
G.add_node("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
if enable_recommendations:
|
| 154 |
-
add_recommendations_pipeline_(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
else:
|
| 156 |
-
G.add_edge("reformulate", "search_agent")
|
| 157 |
|
| 158 |
-
G.add_edge(START, "reformulate")
|
| 159 |
G.add_conditional_edges(
|
| 160 |
source="search_agent",
|
| 161 |
path=tools_condition,
|
|
@@ -164,8 +189,8 @@ def build_compute_graph(
|
|
| 164 |
END: "has_org_name",
|
| 165 |
},
|
| 166 |
)
|
| 167 |
-
G.add_edge("retrieve", "generate_with_context")
|
| 168 |
-
G.add_edge("generate_with_context", "has_org_name")
|
| 169 |
G.add_conditional_edges(
|
| 170 |
source="has_org_name",
|
| 171 |
path=lambda x: x["next"], # Now we're accessing the 'next' key from the dict
|
|
@@ -174,5 +199,5 @@ def build_compute_graph(
|
|
| 174 |
END: END
|
| 175 |
},
|
| 176 |
)
|
| 177 |
-
G.add_edge("insert_org_link", END)
|
| 178 |
return G
|
|
|
|
| 1 |
+
from typing import List, Optional, Callable, Any
|
| 2 |
from functools import partial
|
| 3 |
import logging
|
| 4 |
|
|
|
|
| 11 |
from langgraph.graph.state import StateGraph
|
| 12 |
from langgraph.constants import START, END
|
| 13 |
|
|
|
|
| 14 |
from ask_candid.tools.recommendation import (
|
| 15 |
detect_intent_with_llm,
|
| 16 |
determine_context,
|
|
|
|
| 18 |
)
|
| 19 |
from ask_candid.tools.question_reformulation import reformulate_question_using_history
|
| 20 |
from ask_candid.tools.org_seach import has_org_name, insert_org_link
|
| 21 |
+
from ask_candid.tools.search import search_agent, retriever_tool
|
| 22 |
from ask_candid.agents.schema import AgentState
|
| 23 |
+
from ask_candid.base.config.data import DataIndices
|
| 24 |
|
| 25 |
from ask_candid.utils import html_format_docs_chat
|
| 26 |
|
|
|
|
| 29 |
logger.setLevel(logging.INFO)
|
| 30 |
|
| 31 |
|
| 32 |
+
def generate_with_context(
|
| 33 |
+
state: AgentState,
|
| 34 |
+
llm: LLM,
|
| 35 |
+
user_callback: Optional[Callable[[str], Any]] = None
|
| 36 |
+
) -> AgentState:
|
| 37 |
"""Generate answer.
|
| 38 |
|
| 39 |
Parameters
|
|
|
|
| 41 |
state : AgentState
|
| 42 |
The current state
|
| 43 |
llm : LLM
|
| 44 |
+
user_callback : Optional[Callable[[str], Any]], optional
|
| 45 |
+
Optional UI callback to inform the user of apps states, by default None
|
| 46 |
|
| 47 |
Returns
|
| 48 |
-------
|
|
|
|
| 51 |
"""
|
| 52 |
|
| 53 |
logger.info("---GENERATE ANSWER---")
|
| 54 |
+
if user_callback is not None:
|
| 55 |
+
try:
|
| 56 |
+
user_callback("Writing a response...")
|
| 57 |
+
except Exception as ex:
|
| 58 |
+
logger.warning("User callback was passed in but failed: %s", ex)
|
| 59 |
+
|
| 60 |
messages = state["messages"]
|
| 61 |
question = state["user_input"]
|
| 62 |
last_message = messages[-1]
|
| 63 |
|
| 64 |
sources_str = last_message.content
|
| 65 |
+
sources_list = last_message.artifact
|
|
|
|
| 66 |
sources_html = html_format_docs_chat(sources_list)
|
| 67 |
+
|
| 68 |
if sources_list:
|
| 69 |
logger.info("---ADD SOURCES---")
|
| 70 |
state["messages"].append(BaseMessage(content=sources_html, type="HTML"))
|
|
|
|
| 109 |
"""
|
| 110 |
|
| 111 |
# Nodes for recommendation functionalities
|
| 112 |
+
G.add_node(node="detect_intent_with_llm", action=partial(detect_intent_with_llm, llm=llm))
|
| 113 |
+
G.add_node(node="determine_context", action=determine_context)
|
| 114 |
+
G.add_node(node="make_recommendation", action=make_recommendation)
|
| 115 |
|
| 116 |
# Check for recommendation query first
|
| 117 |
# Execute until reaching END if user asks for recommendation
|
| 118 |
+
G.add_edge(start_key=reformulation_node_name, end_key="detect_intent_with_llm")
|
| 119 |
G.add_conditional_edges(
|
| 120 |
source="detect_intent_with_llm",
|
| 121 |
path=lambda state: "determine_context" if state["intent"] in ["rfp", "funder"] else search_node_name,
|
|
|
|
| 124 |
search_node_name: search_node_name
|
| 125 |
},
|
| 126 |
)
|
| 127 |
+
G.add_edge(start_key="determine_context", end_key="make_recommendation")
|
| 128 |
+
G.add_edge(start_key="make_recommendation", end_key=END)
|
| 129 |
|
| 130 |
|
| 131 |
def build_compute_graph(
|
| 132 |
llm: LLM,
|
| 133 |
+
indices: List[DataIndices],
|
| 134 |
+
enable_recommendations: bool = False,
|
| 135 |
+
user_callback: Optional[Callable[[str], Any]] = None
|
| 136 |
) -> StateGraph:
|
| 137 |
"""Execution graph builder, the output is the execution flow for an interaction with the assistant.
|
| 138 |
|
| 139 |
Parameters
|
| 140 |
----------
|
| 141 |
llm : LLM
|
| 142 |
+
indices : List[DataIndices]
|
| 143 |
Semantic index names to search over
|
| 144 |
enable_recommendations : bool, optional
|
| 145 |
Set to `True` to allow the flow to generate recommendations based on context, by default False
|
| 146 |
+
user_callback : Optional[Callable[[str], Any]], optional
|
| 147 |
+
Optional UI callback to inform the user of apps states, by default None
|
| 148 |
|
| 149 |
Returns
|
| 150 |
-------
|
|
|
|
| 152 |
Execution graph
|
| 153 |
"""
|
| 154 |
|
| 155 |
+
candid_retriever_tool = retriever_tool(indices=indices, user_callback=user_callback)
|
| 156 |
retrieve = ToolNode([candid_retriever_tool])
|
| 157 |
tools = [candid_retriever_tool]
|
| 158 |
|
| 159 |
G = StateGraph(AgentState)
|
| 160 |
|
| 161 |
+
G.add_node(
|
| 162 |
+
node="reformulate",
|
| 163 |
+
action=partial(reformulate_question_using_history, llm=llm, focus_on_recommendations=enable_recommendations)
|
| 164 |
+
)
|
| 165 |
+
G.add_node(node="search_agent", action=partial(search_agent, llm=llm, tools=tools))
|
| 166 |
+
G.add_node(node="retrieve", action=retrieve)
|
| 167 |
+
G.add_node(
|
| 168 |
+
node="generate_with_context",
|
| 169 |
+
action=partial(generate_with_context, llm=llm, user_callback=user_callback)
|
| 170 |
+
)
|
| 171 |
+
G.add_node(node="has_org_name", action=partial(has_org_name, llm=llm, user_callback=user_callback))
|
| 172 |
+
G.add_node(node="insert_org_link", action=insert_org_link)
|
| 173 |
|
| 174 |
if enable_recommendations:
|
| 175 |
+
add_recommendations_pipeline_(
|
| 176 |
+
G, llm=llm,
|
| 177 |
+
reformulation_node_name="reformulate",
|
| 178 |
+
search_node_name="search_agent"
|
| 179 |
+
)
|
| 180 |
else:
|
| 181 |
+
G.add_edge(start_key="reformulate", end_key="search_agent")
|
| 182 |
|
| 183 |
+
G.add_edge(start_key=START, end_key="reformulate")
|
| 184 |
G.add_conditional_edges(
|
| 185 |
source="search_agent",
|
| 186 |
path=tools_condition,
|
|
|
|
| 189 |
END: "has_org_name",
|
| 190 |
},
|
| 191 |
)
|
| 192 |
+
G.add_edge(start_key="retrieve", end_key="generate_with_context")
|
| 193 |
+
G.add_edge(start_key="generate_with_context", end_key="has_org_name")
|
| 194 |
G.add_conditional_edges(
|
| 195 |
source="has_org_name",
|
| 196 |
path=lambda x: x["next"], # Now we're accessing the 'next' key from the dict
|
|
|
|
| 199 |
END: END
|
| 200 |
},
|
| 201 |
)
|
| 202 |
+
G.add_edge(start_key="insert_org_link", end_key=END)
|
| 203 |
return G
|
ask_candid/retrieval/elastic.py
CHANGED
|
@@ -1,20 +1,24 @@
|
|
| 1 |
from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Union, Any
|
| 2 |
from dataclasses import dataclass
|
| 3 |
-
from functools import partial
|
| 4 |
from itertools import groupby
|
| 5 |
|
| 6 |
from torch.nn import functional as F
|
| 7 |
|
| 8 |
from pydantic import BaseModel, Field
|
| 9 |
from langchain_core.documents import Document
|
| 10 |
-
from langchain_core.tools import Tool
|
| 11 |
|
| 12 |
from elasticsearch import Elasticsearch
|
| 13 |
|
| 14 |
from ask_candid.retrieval.sparse_lexical import SpladeEncoder
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
from ask_candid.services.small_lm import CandidSLM
|
| 16 |
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA, NEWS_ELASTIC
|
| 17 |
-
from ask_candid.base.config.data import
|
| 18 |
|
| 19 |
encoder = SpladeEncoder()
|
| 20 |
|
|
@@ -82,6 +86,18 @@ def build_sparse_vector_query(
|
|
| 82 |
|
| 83 |
|
| 84 |
def news_query_builder(query: str) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
tokens = encoder.token_expand(query)
|
| 86 |
|
| 87 |
query = {
|
|
@@ -103,81 +119,70 @@ def news_query_builder(query: str) -> Dict[str, Any]:
|
|
| 103 |
query["query"]["bool"]["should"].append({
|
| 104 |
"multi_match": {
|
| 105 |
"query": token,
|
| 106 |
-
"fields":
|
| 107 |
"boost": score
|
| 108 |
}
|
| 109 |
})
|
| 110 |
return query
|
| 111 |
|
| 112 |
|
| 113 |
-
def query_builder(query: str, indices: List[
|
| 114 |
"""Builds Elasticsearch multi-search query payload
|
| 115 |
|
| 116 |
Parameters
|
| 117 |
----------
|
| 118 |
query : str
|
| 119 |
Search context string
|
| 120 |
-
indices : List[
|
| 121 |
Semantic index names to search over
|
| 122 |
|
| 123 |
Returns
|
| 124 |
-------
|
| 125 |
-
List[Dict[str, Any]]
|
|
|
|
| 126 |
"""
|
| 127 |
|
| 128 |
-
queries = []
|
| 129 |
if indices is None:
|
| 130 |
indices = list(ALL_INDICES)
|
| 131 |
|
| 132 |
for index in indices:
|
| 133 |
if index == "issuelab":
|
| 134 |
-
q = build_sparse_vector_query(
|
| 135 |
-
query=query,
|
| 136 |
-
fields=("description", "content", "combined_issuelab_findings", "combined_item_description")
|
| 137 |
-
)
|
| 138 |
q["_source"] = {"excludes": ["embeddings"]}
|
| 139 |
q["size"] = 1
|
| 140 |
-
queries.extend([{"index":
|
| 141 |
elif index == "youtube":
|
| 142 |
-
q = build_sparse_vector_query(
|
| 143 |
-
|
| 144 |
-
fields=("captions_cleaned", "description_cleaned", "title")
|
| 145 |
-
)
|
| 146 |
-
# text_cleaned duplicates captions_cleaned
|
| 147 |
-
q["_source"] = {"excludes": ["embeddings", "captions", "description", "text_cleaned"]}
|
| 148 |
q["size"] = 2
|
| 149 |
-
queries.extend([{"index":
|
| 150 |
elif index == "candid_blog":
|
| 151 |
-
q = build_sparse_vector_query(
|
| 152 |
-
query=query,
|
| 153 |
-
fields=("content", "authors_text", "title_summary_tags")
|
| 154 |
-
)
|
| 155 |
q["_source"] = {"excludes": ["embeddings"]}
|
| 156 |
q["size"] = 2
|
| 157 |
-
queries.extend([{"index":
|
| 158 |
elif index == "candid_learning":
|
| 159 |
-
q = build_sparse_vector_query(
|
| 160 |
-
query=query,
|
| 161 |
-
fields=("content", "title", "training_topics", "staff_recommendations")
|
| 162 |
-
)
|
| 163 |
q["_source"] = {"excludes": ["embeddings"]}
|
| 164 |
q["size"] = 2
|
| 165 |
-
queries.extend([{"index":
|
| 166 |
elif index == "candid_help":
|
| 167 |
-
q = build_sparse_vector_query(
|
| 168 |
-
query=query,
|
| 169 |
-
fields=("content", "combined_article_description")
|
| 170 |
-
)
|
| 171 |
q["_source"] = {"excludes": ["embeddings"]}
|
| 172 |
q["size"] = 2
|
| 173 |
-
queries.extend([{"index":
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
-
return queries
|
| 176 |
|
| 177 |
|
| 178 |
def multi_search(
|
| 179 |
queries: List[Dict[str, Any]],
|
| 180 |
-
|
| 181 |
) -> List[ElasticHitsResult]:
|
| 182 |
"""Runs multi-search query
|
| 183 |
|
|
@@ -191,6 +196,17 @@ def multi_search(
|
|
| 191 |
List[ElasticHitsResult]
|
| 192 |
"""
|
| 193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
results = []
|
| 195 |
|
| 196 |
if len(queries) > 0:
|
|
@@ -200,31 +216,16 @@ def multi_search(
|
|
| 200 |
verify_certs=False,
|
| 201 |
request_timeout=60 * 3
|
| 202 |
) as es:
|
| 203 |
-
for
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
id=hit["_id"],
|
| 208 |
-
score=hit["_score"],
|
| 209 |
-
source=hit["_source"],
|
| 210 |
-
inner_hits=hit.get("inner_hits", {})
|
| 211 |
-
)
|
| 212 |
-
results.append(hit)
|
| 213 |
-
|
| 214 |
-
if news_query is not None:
|
| 215 |
with Elasticsearch(
|
| 216 |
NEWS_ELASTIC.url,
|
| 217 |
http_auth=(NEWS_ELASTIC.username, NEWS_ELASTIC.password),
|
| 218 |
timeout=60
|
| 219 |
) as es:
|
| 220 |
-
for hit in es.
|
| 221 |
-
hit = ElasticHitsResult(
|
| 222 |
-
index=hit["_index"],
|
| 223 |
-
id=hit["_id"],
|
| 224 |
-
score=hit["_score"],
|
| 225 |
-
source=hit["_source"],
|
| 226 |
-
inner_hits=hit.get("inner_hits", {})
|
| 227 |
-
)
|
| 228 |
results.append(hit)
|
| 229 |
return results
|
| 230 |
|
|
@@ -244,9 +245,8 @@ def get_query_results(search_text: str, indices: Optional[List[str]] = None) ->
|
|
| 244 |
List[ElasticHitsResult]
|
| 245 |
"""
|
| 246 |
|
| 247 |
-
queries = query_builder(query=search_text, indices=indices)
|
| 248 |
-
|
| 249 |
-
return multi_search(queries, news_query=news_q)
|
| 250 |
|
| 251 |
|
| 252 |
def retrieved_text(hits: Dict[str, Any]) -> str:
|
|
@@ -335,36 +335,6 @@ def reranker(
|
|
| 335 |
yield from sorted(results, key=lambda x: x.score, reverse=True)
|
| 336 |
|
| 337 |
|
| 338 |
-
def get_results(user_input: str, indices: List[str]) -> Tuple[str, List[Document]]:
|
| 339 |
-
"""End-to-end search and re-rank function.
|
| 340 |
-
|
| 341 |
-
Parameters
|
| 342 |
-
----------
|
| 343 |
-
user_input : str
|
| 344 |
-
Search context string
|
| 345 |
-
indices : List[str]
|
| 346 |
-
Semantic index names to search over
|
| 347 |
-
|
| 348 |
-
Returns
|
| 349 |
-
-------
|
| 350 |
-
Tuple[str, List[Document]]
|
| 351 |
-
(concatenated text from search results, documents list)
|
| 352 |
-
"""
|
| 353 |
-
|
| 354 |
-
output = ["Search didn't return any Candid sources"]
|
| 355 |
-
page_content = []
|
| 356 |
-
content = "Search didn't return any Candid sources"
|
| 357 |
-
results = get_query_results(search_text=user_input, indices=indices)
|
| 358 |
-
if results:
|
| 359 |
-
output = get_reranked_results(results, search_text=user_input)
|
| 360 |
-
for doc in output:
|
| 361 |
-
page_content.append(doc.page_content)
|
| 362 |
-
content = "\n\n".join(page_content)
|
| 363 |
-
|
| 364 |
-
# for the tool we need to return a tuple for content_and_artifact type
|
| 365 |
-
return content, output
|
| 366 |
-
|
| 367 |
-
|
| 368 |
def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
|
| 369 |
"""Pads the relevant chunk of text with context before and after
|
| 370 |
|
|
@@ -537,30 +507,3 @@ def get_reranked_results(results: List[ElasticHitsResult], search_text: Optional
|
|
| 537 |
if hit is not None:
|
| 538 |
output.append(hit)
|
| 539 |
return output
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
def retriever_tool(indices: List[str]) -> Tool:
|
| 543 |
-
"""Tool component for use in conditional edge building for RAG execution graph.
|
| 544 |
-
Cannot use `create_retriever_tool` because it only provides content losing all metadata on the way
|
| 545 |
-
https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution
|
| 546 |
-
|
| 547 |
-
Parameters
|
| 548 |
-
----------
|
| 549 |
-
indices : List[str]
|
| 550 |
-
Semantic index names to search over
|
| 551 |
-
|
| 552 |
-
Returns
|
| 553 |
-
-------
|
| 554 |
-
Tool
|
| 555 |
-
"""
|
| 556 |
-
|
| 557 |
-
return Tool(
|
| 558 |
-
name="retrieve_social_sector_information",
|
| 559 |
-
func=partial(get_results, indices=indices),
|
| 560 |
-
description=(
|
| 561 |
-
"Return additional information about social and philanthropic sector, "
|
| 562 |
-
"including nonprofits (NGO), grants, foundations, funding, RFP, LOI, Candid."
|
| 563 |
-
),
|
| 564 |
-
args_schema=RetrieverInput,
|
| 565 |
-
response_format="content_and_artifact"
|
| 566 |
-
)
|
|
|
|
| 1 |
from typing import List, Tuple, Dict, Iterable, Iterator, Optional, Union, Any
|
| 2 |
from dataclasses import dataclass
|
|
|
|
| 3 |
from itertools import groupby
|
| 4 |
|
| 5 |
from torch.nn import functional as F
|
| 6 |
|
| 7 |
from pydantic import BaseModel, Field
|
| 8 |
from langchain_core.documents import Document
|
|
|
|
| 9 |
|
| 10 |
from elasticsearch import Elasticsearch
|
| 11 |
|
| 12 |
from ask_candid.retrieval.sparse_lexical import SpladeEncoder
|
| 13 |
+
from ask_candid.retrieval.sources.issuelab import IssueLabConfig
|
| 14 |
+
from ask_candid.retrieval.sources.youtube import YoutubeConfig
|
| 15 |
+
from ask_candid.retrieval.sources.candid_blog import CandidBlogConfig
|
| 16 |
+
from ask_candid.retrieval.sources.candid_learning import CandidLearningConfig
|
| 17 |
+
from ask_candid.retrieval.sources.candid_help import CandidHelpConfig
|
| 18 |
+
from ask_candid.retrieval.sources.candid_news import CandidNewsConfig
|
| 19 |
from ask_candid.services.small_lm import CandidSLM
|
| 20 |
from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA, NEWS_ELASTIC
|
| 21 |
+
from ask_candid.base.config.data import DataIndices, ALL_INDICES
|
| 22 |
|
| 23 |
encoder = SpladeEncoder()
|
| 24 |
|
|
|
|
| 86 |
|
| 87 |
|
| 88 |
def news_query_builder(query: str) -> Dict[str, Any]:
|
| 89 |
+
"""Builds a valid Elasticsearch query against Candid news, simulating a token expansion.
|
| 90 |
+
|
| 91 |
+
Parameters
|
| 92 |
+
----------
|
| 93 |
+
query : str
|
| 94 |
+
Search context string
|
| 95 |
+
|
| 96 |
+
Returns
|
| 97 |
+
-------
|
| 98 |
+
Dict[str, Any]
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
tokens = encoder.token_expand(query)
|
| 102 |
|
| 103 |
query = {
|
|
|
|
| 119 |
query["query"]["bool"]["should"].append({
|
| 120 |
"multi_match": {
|
| 121 |
"query": token,
|
| 122 |
+
"fields": CandidNewsConfig.text_fields,
|
| 123 |
"boost": score
|
| 124 |
}
|
| 125 |
})
|
| 126 |
return query
|
| 127 |
|
| 128 |
|
| 129 |
+
def query_builder(query: str, indices: List[DataIndices]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
| 130 |
"""Builds Elasticsearch multi-search query payload
|
| 131 |
|
| 132 |
Parameters
|
| 133 |
----------
|
| 134 |
query : str
|
| 135 |
Search context string
|
| 136 |
+
indices : List[DataIndices]
|
| 137 |
Semantic index names to search over
|
| 138 |
|
| 139 |
Returns
|
| 140 |
-------
|
| 141 |
+
Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]
|
| 142 |
+
(semantic index queries, news queries)
|
| 143 |
"""
|
| 144 |
|
| 145 |
+
queries, news_queries = [], []
|
| 146 |
if indices is None:
|
| 147 |
indices = list(ALL_INDICES)
|
| 148 |
|
| 149 |
for index in indices:
|
| 150 |
if index == "issuelab":
|
| 151 |
+
q = build_sparse_vector_query(query=query, fields=IssueLabConfig.text_fields)
|
|
|
|
|
|
|
|
|
|
| 152 |
q["_source"] = {"excludes": ["embeddings"]}
|
| 153 |
q["size"] = 1
|
| 154 |
+
queries.extend([{"index": IssueLabConfig.index_name}, q])
|
| 155 |
elif index == "youtube":
|
| 156 |
+
q = build_sparse_vector_query(query=query, fields=YoutubeConfig.text_fields)
|
| 157 |
+
q["_source"] = {"excludes": ["embeddings", *YoutubeConfig.excluded_fields]}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
q["size"] = 2
|
| 159 |
+
queries.extend([{"index": YoutubeConfig.index_name}, q])
|
| 160 |
elif index == "candid_blog":
|
| 161 |
+
q = build_sparse_vector_query(query=query, fields=CandidBlogConfig.text_fields)
|
|
|
|
|
|
|
|
|
|
| 162 |
q["_source"] = {"excludes": ["embeddings"]}
|
| 163 |
q["size"] = 2
|
| 164 |
+
queries.extend([{"index": CandidBlogConfig.index_name}, q])
|
| 165 |
elif index == "candid_learning":
|
| 166 |
+
q = build_sparse_vector_query(query=query, fields=CandidLearningConfig.text_fields)
|
|
|
|
|
|
|
|
|
|
| 167 |
q["_source"] = {"excludes": ["embeddings"]}
|
| 168 |
q["size"] = 2
|
| 169 |
+
queries.extend([{"index": CandidLearningConfig.index_name}, q])
|
| 170 |
elif index == "candid_help":
|
| 171 |
+
q = build_sparse_vector_query(query=query, fields=CandidHelpConfig.text_fields)
|
|
|
|
|
|
|
|
|
|
| 172 |
q["_source"] = {"excludes": ["embeddings"]}
|
| 173 |
q["size"] = 2
|
| 174 |
+
queries.extend([{"index": CandidHelpConfig.index_name}, q])
|
| 175 |
+
elif index == "news":
|
| 176 |
+
q = news_query_builder(query=query)
|
| 177 |
+
q["size"] = 5
|
| 178 |
+
news_queries.extend([{"index": CandidNewsConfig.index_name}, q])
|
| 179 |
|
| 180 |
+
return queries, news_queries
|
| 181 |
|
| 182 |
|
| 183 |
def multi_search(
|
| 184 |
queries: List[Dict[str, Any]],
|
| 185 |
+
news_queries: Optional[List[Dict[str, Any]]] = None
|
| 186 |
) -> List[ElasticHitsResult]:
|
| 187 |
"""Runs multi-search query
|
| 188 |
|
|
|
|
| 196 |
List[ElasticHitsResult]
|
| 197 |
"""
|
| 198 |
|
| 199 |
+
def _msearch_response_generator(responses: List[Dict[str, Any]]) -> Iterator[ElasticHitsResult]:
|
| 200 |
+
for query_group in responses:
|
| 201 |
+
for h in query_group.get("hits", {}).get("hits", []):
|
| 202 |
+
yield ElasticHitsResult(
|
| 203 |
+
index=h["_index"],
|
| 204 |
+
id=h["_id"],
|
| 205 |
+
score=h["_score"],
|
| 206 |
+
source=h["_source"],
|
| 207 |
+
inner_hits=h.get("inner_hits", {})
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
results = []
|
| 211 |
|
| 212 |
if len(queries) > 0:
|
|
|
|
| 216 |
verify_certs=False,
|
| 217 |
request_timeout=60 * 3
|
| 218 |
) as es:
|
| 219 |
+
for hit in _msearch_response_generator(es.msearch(body=queries).get("responses", [])):
|
| 220 |
+
results.append(hit)
|
| 221 |
+
|
| 222 |
+
if news_queries is not None and len(news_queries):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
with Elasticsearch(
|
| 224 |
NEWS_ELASTIC.url,
|
| 225 |
http_auth=(NEWS_ELASTIC.username, NEWS_ELASTIC.password),
|
| 226 |
timeout=60
|
| 227 |
) as es:
|
| 228 |
+
for hit in _msearch_response_generator(es.msearch(body=news_queries).get("responses", [])):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
results.append(hit)
|
| 230 |
return results
|
| 231 |
|
|
|
|
| 245 |
List[ElasticHitsResult]
|
| 246 |
"""
|
| 247 |
|
| 248 |
+
queries, news_q = query_builder(query=search_text, indices=indices)
|
| 249 |
+
return multi_search(queries, news_queries=news_q)
|
|
|
|
| 250 |
|
| 251 |
|
| 252 |
def retrieved_text(hits: Dict[str, Any]) -> str:
|
|
|
|
| 335 |
yield from sorted(results, key=lambda x: x.score, reverse=True)
|
| 336 |
|
| 337 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
|
| 339 |
"""Pads the relevant chunk of text with context before and after
|
| 340 |
|
|
|
|
| 507 |
if hit is not None:
|
| 508 |
output.append(hit)
|
| 509 |
return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ask_candid/retrieval/sources/candid_blog.py
CHANGED
|
@@ -1,4 +1,11 @@
|
|
| 1 |
from typing import Dict, Any
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
|
|
|
| 1 |
from typing import Dict, Any
|
| 2 |
+
from ask_candid.retrieval.sources.schema import ElasticSourceConfig
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
CandidBlogConfig = ElasticSourceConfig(
|
| 6 |
+
index_name="search-semantic-candid-blog",
|
| 7 |
+
text_fields=("content", "authors_text", "title_summary_tags")
|
| 8 |
+
)
|
| 9 |
|
| 10 |
|
| 11 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
ask_candid/retrieval/sources/candid_help.py
CHANGED
|
@@ -1,4 +1,11 @@
|
|
| 1 |
from typing import Dict, Any
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
|
|
|
| 1 |
from typing import Dict, Any
|
| 2 |
+
from ask_candid.retrieval.sources.schema import ElasticSourceConfig
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
CandidHelpConfig = ElasticSourceConfig(
|
| 6 |
+
index_name="search-semantic-candid-help-elser_ve1",
|
| 7 |
+
text_fields=("content", "combined_article_description")
|
| 8 |
+
)
|
| 9 |
|
| 10 |
|
| 11 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
ask_candid/retrieval/sources/candid_learning.py
CHANGED
|
@@ -1,4 +1,11 @@
|
|
| 1 |
from typing import Dict, Any
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
|
|
|
| 1 |
from typing import Dict, Any
|
| 2 |
+
from ask_candid.retrieval.sources.schema import ElasticSourceConfig
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
CandidLearningConfig = ElasticSourceConfig(
|
| 6 |
+
index_name="search-semantic-candid-learning_ve1",
|
| 7 |
+
text_fields=("content", "title", "training_topics", "staff_recommendations")
|
| 8 |
+
)
|
| 9 |
|
| 10 |
|
| 11 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
ask_candid/retrieval/sources/candid_news.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ask_candid.retrieval.sources.schema import ElasticSourceConfig
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
CandidNewsConfig = ElasticSourceConfig(
|
| 5 |
+
index_name="news_1",
|
| 6 |
+
text_fields=("title", "content")
|
| 7 |
+
)
|
ask_candid/retrieval/sources/issuelab.py
CHANGED
|
@@ -1,4 +1,11 @@
|
|
| 1 |
from typing import Dict, Any
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
def issuelab_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
|
|
|
| 1 |
from typing import Dict, Any
|
| 2 |
+
from ask_candid.retrieval.sources.schema import ElasticSourceConfig
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
IssueLabConfig = ElasticSourceConfig(
|
| 6 |
+
index_name="search-semantic-issuelab-elser_ve2",
|
| 7 |
+
text_fields=("description", "content", "combined_issuelab_findings", "combined_item_description")
|
| 8 |
+
)
|
| 9 |
|
| 10 |
|
| 11 |
def issuelab_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
ask_candid/retrieval/sources/schema.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, Optional
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@dataclass
|
| 6 |
+
class ElasticSourceConfig:
|
| 7 |
+
index_name: str
|
| 8 |
+
text_fields: Tuple[str, ...]
|
| 9 |
+
excluded_fields: Optional[Tuple[str, ...]] = field(default_factory=tuple)
|
ask_candid/retrieval/sources/youtube.py
CHANGED
|
@@ -1,4 +1,12 @@
|
|
| 1 |
from typing import Dict, Any
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
|
| 4 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
|
|
|
| 1 |
from typing import Dict, Any
|
| 2 |
+
from ask_candid.retrieval.sources.schema import ElasticSourceConfig
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
YoutubeConfig = ElasticSourceConfig(
|
| 6 |
+
index_name="search-semantic-youtube-elser_ve1",
|
| 7 |
+
text_fields=("captions_cleaned", "description_cleaned", "title"),
|
| 8 |
+
excluded_fields=("captions", "description", "text_cleaned")
|
| 9 |
+
)
|
| 10 |
|
| 11 |
|
| 12 |
def build_card_html(doc: Dict[str, Any], height_px: int = 200, show_chunks=False) -> str:
|
ask_candid/tools/elastic/list_indices_tool.py
CHANGED
|
@@ -31,7 +31,8 @@ class ListIndicesTool(BaseTool):
|
|
| 31 |
|
| 32 |
name: str = "elastic_list_indices" # Added type annotation
|
| 33 |
description: str = (
|
| 34 |
-
"Input is a delimiter like comma or new line. Output is a separated list of indices in the database.
|
|
|
|
| 35 |
)
|
| 36 |
args_schema: Optional[Type[BaseModel]] = (
|
| 37 |
ListIndicesInput # Define this before methods
|
|
|
|
| 31 |
|
| 32 |
name: str = "elastic_list_indices" # Added type annotation
|
| 33 |
description: str = (
|
| 34 |
+
"Input is a delimiter like comma or new line. Output is a separated list of indices in the database. "
|
| 35 |
+
"Always use this tool to get to know the indices in the ElasticSearch cluster."
|
| 36 |
)
|
| 37 |
args_schema: Optional[Type[BaseModel]] = (
|
| 38 |
ListIndicesInput # Define this before methods
|
ask_candid/tools/org_seach.py
CHANGED
|
@@ -1,11 +1,10 @@
|
|
| 1 |
-
from typing import List
|
| 2 |
import logging
|
| 3 |
import re
|
| 4 |
|
| 5 |
from thefuzz import fuzz
|
| 6 |
|
| 7 |
from langchain.output_parsers.openai_tools import JsonOutputToolsParser
|
| 8 |
-
# from langchain_openai.chat_models import ChatOpenAI
|
| 9 |
from langchain_core.runnables import RunnableSequence
|
| 10 |
from langchain_core.prompts import ChatPromptTemplate
|
| 11 |
from langchain_core.language_models.llms import LLM
|
|
@@ -15,7 +14,6 @@ from pydantic import BaseModel, Field
|
|
| 15 |
|
| 16 |
from ask_candid.agents.schema import AgentState
|
| 17 |
from ask_candid.services.org_search import OrgSearch
|
| 18 |
-
# from ask_candid.base.config.rest import OPENAI
|
| 19 |
|
| 20 |
search = OrgSearch()
|
| 21 |
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
|
|
@@ -59,7 +57,6 @@ def extract_org_links_from_chatbot(chatbot_output: str, llm: LLM):
|
|
| 59 |
|
| 60 |
try:
|
| 61 |
parser = JsonOutputToolsParser()
|
| 62 |
-
# llm = ChatOpenAI(model="gpt-4o", api_key=OPENAI["key"]).bind_tools([OrganizationNames])
|
| 63 |
model = llm.bind_tools([OrganizationNames])
|
| 64 |
prompt = ChatPromptTemplate.from_template(prompt)
|
| 65 |
chain = RunnableSequence(prompt, model, parser)
|
|
@@ -203,17 +200,33 @@ def embed_org_links_in_text(input_text: str, org_link_dict: dict):
|
|
| 203 |
return input_text
|
| 204 |
|
| 205 |
|
| 206 |
-
def has_org_name(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
"""
|
| 208 |
-
Processes the latest message to extract organization links and determine the next step.
|
| 209 |
|
| 210 |
-
Args:
|
| 211 |
-
state (AgentState): The current state of the agent, including a list of messages.
|
| 212 |
-
|
| 213 |
-
Returns:
|
| 214 |
-
dict: A dictionary with the next agent action and, if available, a dictionary of organization links.
|
| 215 |
-
"""
|
| 216 |
logger.info("---HAS ORG NAMES?---")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
messages = state["messages"]
|
| 218 |
last_message = messages[-1].content
|
| 219 |
output_list = extract_org_links_from_chatbot(last_message, llm=llm)
|
|
|
|
| 1 |
+
from typing import List, Optional, Callable, Any
|
| 2 |
import logging
|
| 3 |
import re
|
| 4 |
|
| 5 |
from thefuzz import fuzz
|
| 6 |
|
| 7 |
from langchain.output_parsers.openai_tools import JsonOutputToolsParser
|
|
|
|
| 8 |
from langchain_core.runnables import RunnableSequence
|
| 9 |
from langchain_core.prompts import ChatPromptTemplate
|
| 10 |
from langchain_core.language_models.llms import LLM
|
|
|
|
| 14 |
|
| 15 |
from ask_candid.agents.schema import AgentState
|
| 16 |
from ask_candid.services.org_search import OrgSearch
|
|
|
|
| 17 |
|
| 18 |
search = OrgSearch()
|
| 19 |
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
|
|
|
|
| 57 |
|
| 58 |
try:
|
| 59 |
parser = JsonOutputToolsParser()
|
|
|
|
| 60 |
model = llm.bind_tools([OrganizationNames])
|
| 61 |
prompt = ChatPromptTemplate.from_template(prompt)
|
| 62 |
chain = RunnableSequence(prompt, model, parser)
|
|
|
|
| 200 |
return input_text
|
| 201 |
|
| 202 |
|
| 203 |
+
def has_org_name(
|
| 204 |
+
state: AgentState,
|
| 205 |
+
llm: LLM,
|
| 206 |
+
user_callback: Optional[Callable[[str], Any]] = None
|
| 207 |
+
) -> AgentState:
|
| 208 |
+
"""Processes the latest message to extract organization links and determine the next step.
|
| 209 |
+
|
| 210 |
+
Parameters
|
| 211 |
+
----------
|
| 212 |
+
state : AgentState
|
| 213 |
+
The current state of the agent, including a list of messages.
|
| 214 |
+
llm : LLM
|
| 215 |
+
user_callback : Optional[Callable[[str], Any]], optional
|
| 216 |
+
Optional UI callback to inform the user of apps states, by default None
|
| 217 |
+
|
| 218 |
+
Returns
|
| 219 |
+
-------
|
| 220 |
+
AgentState
|
| 221 |
"""
|
|
|
|
| 222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
logger.info("---HAS ORG NAMES?---")
|
| 224 |
+
if user_callback is not None:
|
| 225 |
+
try:
|
| 226 |
+
user_callback("Checking for relevant organizations")
|
| 227 |
+
except Exception as ex:
|
| 228 |
+
logger.warning("User callback was passed in but failed: %s", ex)
|
| 229 |
+
|
| 230 |
messages = state["messages"]
|
| 231 |
last_message = messages[-1].content
|
| 232 |
output_list = extract_org_links_from_chatbot(last_message, llm=llm)
|
ask_candid/tools/search.py
CHANGED
|
@@ -1,9 +1,14 @@
|
|
| 1 |
-
from typing import List
|
|
|
|
| 2 |
import logging
|
| 3 |
|
|
|
|
| 4 |
from langchain_core.language_models.llms import LLM
|
|
|
|
| 5 |
from langchain_core.tools import Tool
|
| 6 |
|
|
|
|
|
|
|
| 7 |
from ask_candid.agents.schema import AgentState
|
| 8 |
|
| 9 |
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
|
|
@@ -11,7 +16,86 @@ logger = logging.getLogger(__name__)
|
|
| 11 |
logger.setLevel(logging.INFO)
|
| 12 |
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
"""Invokes the agent model to generate a response based on the current state. Given
|
| 16 |
the question, it will decide to retrieve using the retriever tool, or simply end.
|
| 17 |
|
|
|
|
| 1 |
+
from typing import List, Tuple, Callable, Optional, Any
|
| 2 |
+
from functools import partial
|
| 3 |
import logging
|
| 4 |
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
from langchain_core.language_models.llms import LLM
|
| 7 |
+
from langchain_core.documents import Document
|
| 8 |
from langchain_core.tools import Tool
|
| 9 |
|
| 10 |
+
from ask_candid.retrieval.elastic import get_query_results, get_reranked_results
|
| 11 |
+
from ask_candid.base.config.data import DataIndices
|
| 12 |
from ask_candid.agents.schema import AgentState
|
| 13 |
|
| 14 |
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
|
|
|
|
| 16 |
logger.setLevel(logging.INFO)
|
| 17 |
|
| 18 |
|
| 19 |
+
class RetrieverInput(BaseModel):
|
| 20 |
+
"""Input to the Elasticsearch retriever."""
|
| 21 |
+
user_input: str = Field(description="query to look up in retriever")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_search_results(
|
| 25 |
+
user_input: str,
|
| 26 |
+
indices: List[DataIndices],
|
| 27 |
+
user_callback: Optional[Callable[[str], Any]] = None
|
| 28 |
+
) -> Tuple[str, List[Document]]:
|
| 29 |
+
"""End-to-end search and re-rank function.
|
| 30 |
+
|
| 31 |
+
Parameters
|
| 32 |
+
----------
|
| 33 |
+
user_input : str
|
| 34 |
+
Search context string
|
| 35 |
+
indices : List[DataIndices]
|
| 36 |
+
Semantic index names to search over
|
| 37 |
+
user_callback : Optional[Callable[[str], Any]], optional
|
| 38 |
+
Optional UI callback to inform the user of app states, by default None
|
| 39 |
+
|
| 40 |
+
Returns
|
| 41 |
+
-------
|
| 42 |
+
Tuple[str, List[Document]]
|
| 43 |
+
(concatenated text from search results, documents list)
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
if user_callback is not None:
|
| 47 |
+
try:
|
| 48 |
+
user_callback("Searching for relevant information")
|
| 49 |
+
except Exception as ex:
|
| 50 |
+
logger.warning("User callback was passed in but failed: %s", ex)
|
| 51 |
+
|
| 52 |
+
output = ["Search didn't return any Candid sources"]
|
| 53 |
+
page_content = []
|
| 54 |
+
content = "Search didn't return any Candid sources"
|
| 55 |
+
results = get_query_results(search_text=user_input, indices=indices)
|
| 56 |
+
if results:
|
| 57 |
+
output = get_reranked_results(results, search_text=user_input)
|
| 58 |
+
for doc in output:
|
| 59 |
+
page_content.append(doc.page_content)
|
| 60 |
+
content = "\n\n".join(page_content)
|
| 61 |
+
|
| 62 |
+
# for the tool we need to return a tuple for content_and_artifact type
|
| 63 |
+
return content, output
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def retriever_tool(
|
| 67 |
+
indices: List[DataIndices],
|
| 68 |
+
user_callback: Optional[Callable[[str], Any]] = None
|
| 69 |
+
) -> Tool:
|
| 70 |
+
"""Tool component for use in conditional edge building for RAG execution graph.
|
| 71 |
+
Cannot use `create_retriever_tool` because it only provides content losing all metadata on the way
|
| 72 |
+
https://python.langchain.com/docs/how_to/custom_tools/#returning-artifacts-of-tool-execution
|
| 73 |
+
|
| 74 |
+
Parameters
|
| 75 |
+
----------
|
| 76 |
+
indices : List[DataIndices]
|
| 77 |
+
Semantic index names to search over
|
| 78 |
+
user_callback : Optional[Callable[[str], Any]], optional
|
| 79 |
+
Optional UI callback to inform the user of app states, by default None
|
| 80 |
+
|
| 81 |
+
Returns
|
| 82 |
+
-------
|
| 83 |
+
Tool
|
| 84 |
+
"""
|
| 85 |
+
|
| 86 |
+
return Tool(
|
| 87 |
+
name="retrieve_social_sector_information",
|
| 88 |
+
func=partial(get_search_results, indices=indices, user_callback=user_callback),
|
| 89 |
+
description=(
|
| 90 |
+
"Return additional information about social and philanthropic sector, "
|
| 91 |
+
"including nonprofits (NGO), grants, foundations, funding, RFP, LOI, Candid."
|
| 92 |
+
),
|
| 93 |
+
args_schema=RetrieverInput,
|
| 94 |
+
response_format="content_and_artifact"
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def search_agent(state: AgentState, llm: LLM, tools: List[Tool]) -> AgentState:
|
| 99 |
"""Invokes the agent model to generate a response based on the current state. Given
|
| 100 |
the question, it will decide to retrieve using the retriever tool, or simply end.
|
| 101 |
|
ask_candid/utils.py
CHANGED
|
@@ -77,7 +77,7 @@ def format_chat_ag_response(chatbot: List[Any]) -> List[Any]:
|
|
| 77 |
"""
|
| 78 |
sources = ""
|
| 79 |
if chatbot:
|
| 80 |
-
title = chatbot[-1]
|
| 81 |
if title == "Sources HTML":
|
| 82 |
sources = chatbot[-1]["content"]
|
| 83 |
chatbot.pop(-1)
|
|
|
|
| 77 |
"""
|
| 78 |
sources = ""
|
| 79 |
if chatbot:
|
| 80 |
+
title = (chatbot[-1].get("metadata") or {}).get("title", None)
|
| 81 |
if title == "Sources HTML":
|
| 82 |
sources = chatbot[-1]["content"]
|
| 83 |
chatbot.pop(-1)
|