Spaces:
Running
Running
Delete ask_candid/graph.py
Browse files- ask_candid/graph.py +0 -203
ask_candid/graph.py
DELETED
|
@@ -1,203 +0,0 @@
|
|
| 1 |
-
from typing import List, Optional, Callable, Any
|
| 2 |
-
from functools import partial
|
| 3 |
-
import logging
|
| 4 |
-
|
| 5 |
-
from langchain_core.messages import AIMessage, BaseMessage
|
| 6 |
-
from langchain_core.output_parsers import StrOutputParser
|
| 7 |
-
from langchain_core.prompts import ChatPromptTemplate
|
| 8 |
-
from langchain_core.language_models.llms import LLM
|
| 9 |
-
|
| 10 |
-
from langgraph.prebuilt import tools_condition, ToolNode
|
| 11 |
-
from langgraph.graph.state import StateGraph
|
| 12 |
-
from langgraph.constants import START, END
|
| 13 |
-
|
| 14 |
-
from ask_candid.tools.recommendation import (
|
| 15 |
-
detect_intent_with_llm,
|
| 16 |
-
determine_context,
|
| 17 |
-
make_recommendation
|
| 18 |
-
)
|
| 19 |
-
from ask_candid.tools.question_reformulation import reformulate_question_using_history
|
| 20 |
-
from ask_candid.tools.org_seach import has_org_name, insert_org_link
|
| 21 |
-
from ask_candid.tools.search import search_agent, retriever_tool
|
| 22 |
-
from ask_candid.agents.schema import AgentState
|
| 23 |
-
from ask_candid.base.config.data import DataIndices
|
| 24 |
-
|
| 25 |
-
from ask_candid.utils import html_format_docs_chat
|
| 26 |
-
|
| 27 |
-
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
|
| 28 |
-
logger = logging.getLogger(__name__)
|
| 29 |
-
logger.setLevel(logging.INFO)
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def generate_with_context(
|
| 33 |
-
state: AgentState,
|
| 34 |
-
llm: LLM,
|
| 35 |
-
user_callback: Optional[Callable[[str], Any]] = None
|
| 36 |
-
) -> AgentState:
|
| 37 |
-
"""Generate answer.
|
| 38 |
-
|
| 39 |
-
Parameters
|
| 40 |
-
----------
|
| 41 |
-
state : AgentState
|
| 42 |
-
The current state
|
| 43 |
-
llm : LLM
|
| 44 |
-
user_callback : Optional[Callable[[str], Any]], optional
|
| 45 |
-
Optional UI callback to inform the user of apps states, by default None
|
| 46 |
-
|
| 47 |
-
Returns
|
| 48 |
-
-------
|
| 49 |
-
AgentState
|
| 50 |
-
The updated state with the agent response appended to messages
|
| 51 |
-
"""
|
| 52 |
-
|
| 53 |
-
logger.info("---GENERATE ANSWER---")
|
| 54 |
-
if user_callback is not None:
|
| 55 |
-
try:
|
| 56 |
-
user_callback("Writing a response...")
|
| 57 |
-
except Exception as ex:
|
| 58 |
-
logger.warning("User callback was passed in but failed: %s", ex)
|
| 59 |
-
|
| 60 |
-
messages = state["messages"]
|
| 61 |
-
question = state["user_input"]
|
| 62 |
-
last_message = messages[-1]
|
| 63 |
-
|
| 64 |
-
sources_str = last_message.content
|
| 65 |
-
sources_list = last_message.artifact
|
| 66 |
-
sources_html = html_format_docs_chat(sources_list)
|
| 67 |
-
|
| 68 |
-
if sources_list:
|
| 69 |
-
logger.info("---ADD SOURCES---")
|
| 70 |
-
state["messages"].append(BaseMessage(content=sources_html, type="HTML"))
|
| 71 |
-
|
| 72 |
-
# Prompt
|
| 73 |
-
qa_system_prompt = """
|
| 74 |
-
You are an assistant for question-answering tasks in the social and philanthropic sector. \n
|
| 75 |
-
Use the following pieces of retrieved context to answer the question at the end. \n
|
| 76 |
-
If you don't know the answer, just say that you don't know. \n
|
| 77 |
-
Keep the response professional, friendly, and as concise as possible. \n
|
| 78 |
-
Question: {question}
|
| 79 |
-
Context: {context}
|
| 80 |
-
Answer:
|
| 81 |
-
"""
|
| 82 |
-
|
| 83 |
-
qa_prompt = ChatPromptTemplate([
|
| 84 |
-
("system", qa_system_prompt),
|
| 85 |
-
("human", question),
|
| 86 |
-
])
|
| 87 |
-
|
| 88 |
-
rag_chain = qa_prompt | llm | StrOutputParser()
|
| 89 |
-
response = rag_chain.invoke({"context": sources_str, "question": question})
|
| 90 |
-
return {"messages": [AIMessage(content=response)], "user_input": question}
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
def add_recommendations_pipeline_(
|
| 94 |
-
G: StateGraph,
|
| 95 |
-
llm: LLM,
|
| 96 |
-
reformulation_node_name: str = "reformulate",
|
| 97 |
-
search_node_name: str = "search_agent"
|
| 98 |
-
) -> None:
|
| 99 |
-
"""Adds execution sub-graph for recommendation engine flow. Graph changes are in-place.
|
| 100 |
-
|
| 101 |
-
Parameters
|
| 102 |
-
----------
|
| 103 |
-
G : StateGraph
|
| 104 |
-
Execution graph
|
| 105 |
-
reformulation_node_name : str, optional
|
| 106 |
-
Name of the node which reforumates input queries, by default "reformulate"
|
| 107 |
-
search_node_name : str, optional
|
| 108 |
-
Name of the node which executes document search + retrieval, by default "search_agent"
|
| 109 |
-
"""
|
| 110 |
-
|
| 111 |
-
# Nodes for recommendation functionalities
|
| 112 |
-
G.add_node(node="detect_intent_with_llm", action=partial(detect_intent_with_llm, llm=llm))
|
| 113 |
-
G.add_node(node="determine_context", action=determine_context)
|
| 114 |
-
G.add_node(node="make_recommendation", action=make_recommendation)
|
| 115 |
-
|
| 116 |
-
# Check for recommendation query first
|
| 117 |
-
# Execute until reaching END if user asks for recommendation
|
| 118 |
-
G.add_edge(start_key=reformulation_node_name, end_key="detect_intent_with_llm")
|
| 119 |
-
G.add_conditional_edges(
|
| 120 |
-
source="detect_intent_with_llm",
|
| 121 |
-
path=lambda state: "determine_context" if state["intent"] in ["rfp", "funder"] else search_node_name,
|
| 122 |
-
path_map={
|
| 123 |
-
"determine_context": "determine_context",
|
| 124 |
-
search_node_name: search_node_name
|
| 125 |
-
},
|
| 126 |
-
)
|
| 127 |
-
G.add_edge(start_key="determine_context", end_key="make_recommendation")
|
| 128 |
-
G.add_edge(start_key="make_recommendation", end_key=END)
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
def build_compute_graph(
|
| 132 |
-
llm: LLM,
|
| 133 |
-
indices: List[DataIndices],
|
| 134 |
-
enable_recommendations: bool = False,
|
| 135 |
-
user_callback: Optional[Callable[[str], Any]] = None
|
| 136 |
-
) -> StateGraph:
|
| 137 |
-
"""Execution graph builder, the output is the execution flow for an interaction with the assistant.
|
| 138 |
-
|
| 139 |
-
Parameters
|
| 140 |
-
----------
|
| 141 |
-
llm : LLM
|
| 142 |
-
indices : List[DataIndices]
|
| 143 |
-
Semantic index names to search over
|
| 144 |
-
enable_recommendations : bool, optional
|
| 145 |
-
Set to `True` to allow the flow to generate recommendations based on context, by default False
|
| 146 |
-
user_callback : Optional[Callable[[str], Any]], optional
|
| 147 |
-
Optional UI callback to inform the user of apps states, by default None
|
| 148 |
-
|
| 149 |
-
Returns
|
| 150 |
-
-------
|
| 151 |
-
StateGraph
|
| 152 |
-
Execution graph
|
| 153 |
-
"""
|
| 154 |
-
|
| 155 |
-
candid_retriever_tool = retriever_tool(indices=indices, user_callback=user_callback)
|
| 156 |
-
retrieve = ToolNode([candid_retriever_tool])
|
| 157 |
-
tools = [candid_retriever_tool]
|
| 158 |
-
|
| 159 |
-
G = StateGraph(AgentState)
|
| 160 |
-
|
| 161 |
-
G.add_node(
|
| 162 |
-
node="reformulate",
|
| 163 |
-
action=partial(reformulate_question_using_history, llm=llm, focus_on_recommendations=enable_recommendations)
|
| 164 |
-
)
|
| 165 |
-
G.add_node(node="search_agent", action=partial(search_agent, llm=llm, tools=tools))
|
| 166 |
-
G.add_node(node="retrieve", action=retrieve)
|
| 167 |
-
G.add_node(
|
| 168 |
-
node="generate_with_context",
|
| 169 |
-
action=partial(generate_with_context, llm=llm, user_callback=user_callback)
|
| 170 |
-
)
|
| 171 |
-
G.add_node(node="has_org_name", action=partial(has_org_name, llm=llm, user_callback=user_callback))
|
| 172 |
-
G.add_node(node="insert_org_link", action=insert_org_link)
|
| 173 |
-
|
| 174 |
-
if enable_recommendations:
|
| 175 |
-
add_recommendations_pipeline_(
|
| 176 |
-
G, llm=llm,
|
| 177 |
-
reformulation_node_name="reformulate",
|
| 178 |
-
search_node_name="search_agent"
|
| 179 |
-
)
|
| 180 |
-
else:
|
| 181 |
-
G.add_edge(start_key="reformulate", end_key="search_agent")
|
| 182 |
-
|
| 183 |
-
G.add_edge(start_key=START, end_key="reformulate")
|
| 184 |
-
G.add_conditional_edges(
|
| 185 |
-
source="search_agent",
|
| 186 |
-
path=tools_condition,
|
| 187 |
-
path_map={
|
| 188 |
-
"tools": "retrieve",
|
| 189 |
-
END: "has_org_name",
|
| 190 |
-
},
|
| 191 |
-
)
|
| 192 |
-
G.add_edge(start_key="retrieve", end_key="generate_with_context")
|
| 193 |
-
G.add_edge(start_key="generate_with_context", end_key="has_org_name")
|
| 194 |
-
G.add_conditional_edges(
|
| 195 |
-
source="has_org_name",
|
| 196 |
-
path=lambda x: x["next"], # Now we're accessing the 'next' key from the dict
|
| 197 |
-
path_map={
|
| 198 |
-
"insert_org_link": "insert_org_link",
|
| 199 |
-
END: END
|
| 200 |
-
},
|
| 201 |
-
)
|
| 202 |
-
G.add_edge(start_key="insert_org_link", end_key=END)
|
| 203 |
-
return G
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|