|
|
from typing import List, Union, Dict, Tuple |
|
|
from typing_extensions import NotRequired, TypedDict |
|
|
|
|
|
from langchain_core.agents import ( |
|
|
AgentAction, |
|
|
AgentFinish |
|
|
) |
|
|
from langchain_core.messages import BaseMessage, AIMessage |
|
|
from langgraph.graph import END, StateGraph |
|
|
|
|
|
from runnables import Answerer, Agent |
|
|
from prompts import _ANSWERER_SYSTEM_TEMPLATE, _AGENT_SYSTEM_TEMPLATE |
|
|
|
|
|
# Graph-state key under which the final response payload is stored; the tool
# and finish nodes built in `_get_graph` both write their result under it.
OUTPUT_KEY = "response"
|
|
|
|
|
def _get_graph(
    agent_model_name: str = "gpt-4-turbo",
    agent_system_template: str = _AGENT_SYSTEM_TEMPLATE,
    agent_temperature: float = 0.0,
    answerer_model_name: str = "gpt-4-turbo",
    answerer_system_template: str = _ANSWERER_SYSTEM_TEMPLATE,
    answerer_temperature: float = 0.0,
    collection_index: int = 0,
    use_doctrines: bool = True,
    search_type: str = "similarity",
    similarity_threshold: float = 0.0,
    k: int = 15,
):
    """
    Build and compile the agent / tools / finish state graph.

    The graph starts at the "agent" node. Depending on the agent outcome it
    either goes straight to "finish" or passes through "tools" (which runs the
    answerer runnable) before reaching "finish" and then END.

    Args:
        agent_model_name: Forwarded to ``Agent`` as ``model_name``.
        agent_system_template: Forwarded to ``Agent`` as ``system_template``.
        agent_temperature: Forwarded to ``Agent`` as ``temperature``.
        answerer_model_name: Forwarded to ``Answerer`` as ``model_name``.
        answerer_system_template: Forwarded to ``Answerer`` as
            ``system_template``.
        answerer_temperature: Forwarded to ``Answerer`` as ``temperature``.
        collection_index: Forwarded unchanged to ``Answerer``.
        use_doctrines: Forwarded unchanged to ``Answerer``.
        search_type: Forwarded unchanged to ``Answerer``.
        similarity_threshold: Forwarded unchanged to ``Answerer``.
        k: Forwarded unchanged to ``Answerer``.
            NOTE(review): the last five look like retrieval settings; their
            exact meaning is defined by ``Answerer`` - confirm there.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """

    agent = Agent(
        model_name=agent_model_name,
        system_template=agent_system_template,
        temperature=agent_temperature,
    )
    agent_runnable = agent.runnable

    answerer = Answerer(
        model_name=answerer_model_name,
        system_template=answerer_system_template,
        temperature=answerer_temperature,
        collection_index=collection_index,
        use_doctrines=use_doctrines,
        search_type=search_type,
        similarity_threshold=similarity_threshold,
        k=k,
    )
    answerer_runnable = answerer.runnable

    class GraphState(TypedDict):
        """State shared by every node of the graph."""

        # The user query handed to the agent.
        query: str

        # Outcome of the last agent call. The routing and tool nodes also
        # accept a list of actions (they use its first element), so the list
        # form is part of the declared type.
        agent_outcome: NotRequired[
            Union[AgentAction, List[AgentAction], AgentFinish]
        ]

        # Conversation messages passed to the agent along with the query.
        chat_history: List[BaseMessage]

        # Final payload, stored under OUTPUT_KEY by the tool/finish nodes.
        response: NotRequired[
            Dict[
                str,
                Union[
                    str,
                    List[int],
                    List[Dict[str, Union[int, str]]]
                ]
            ]
        ]

    def _first_action(agent_outcome):
        """Return the action to execute from an agent outcome, or None.

        Accepts either a single AgentAction or a non-empty list whose first
        element is the action to run.
        """
        if isinstance(agent_outcome, AgentAction):
            return agent_outcome
        if isinstance(agent_outcome, list) and agent_outcome:
            return agent_outcome[0]
        return None

    async def execute_agent(
        state: GraphState,
        config: Dict,
    ) -> Dict[str, Union[AgentAction, List[AgentAction], AgentFinish]]:
        """
        Invoke the agent runnable on the current state.

        Args:
            state: The current graph state; a shallow copy is passed to the
                agent so the incoming mapping is not handed over directly.
            config: Runnable config forwarded to the agent call.

        Returns:
            dict: State update holding the new value of "agent_outcome".
        """
        inputs = state.copy()

        runnable = agent_runnable.with_config({"run_name": "agent_node"})
        agent_outcome = await runnable.ainvoke(inputs, config=config)

        return {"agent_outcome": agent_outcome}

    def execute_tool(
        state: GraphState,
        config: Dict,
    ) -> Dict:
        """
        Run the answerer runnable with the question chosen by the agent.

        Args:
            state: The current graph state; "agent_outcome" must hold the
                action (or list of actions) whose tool input carries a
                "standalone_question" entry.
            config: Runnable config forwarded to the answerer call.

        Returns:
            dict: State update storing the answerer output under OUTPUT_KEY.

        Raises:
            ValueError: If the state holds no action to execute.
        """
        action = _first_action(state["agent_outcome"])
        if action is None:
            raise ValueError(
                "execute_tool needs an agent action, got: "
                f"{state['agent_outcome']!r}"
            )
        inputs = action.tool_input

        tool_output = answerer_runnable.invoke(
            {"query": inputs["standalone_question"]},
            config=config,
        )

        return {OUTPUT_KEY: tool_output}

    def finish(
        state: GraphState
    ) -> Dict:
        """
        Produce the final response of the graph.

        If a response is already present in the state it is passed through
        unchanged. Otherwise the agent finished without using the tool, and a
        response is built from the agent's final output with no documents and
        no standalone question.

        Args:
            state: The current graph state.

        Returns:
            dict: State update storing the final response under OUTPUT_KEY.
        """
        # The response entry is NotRequired, so it can be absent (not merely
        # None) when the tool node never ran; .get() covers both cases.
        response = state.get(OUTPUT_KEY)

        if response is None:
            response = {
                "answer": AIMessage(state['agent_outcome'].return_values['output']),
                "docs": [],
                "standalone_question": None
            }

        return {OUTPUT_KEY: response}

    def parse(
        state: GraphState
    ) -> str:
        """
        Route based on the previous agent outcome.

        Args:
            state: The current graph state; "agent_outcome" must be set.

        Returns:
            str: "end" when the agent finished, "use_tool" when it requested
            a tool.

        Raises:
            ValueError: If the outcome is neither a finish nor an action with
                a tool. Returning None here would only fail later, inside the
                conditional-edge lookup, with a far less helpful error.
        """
        agent_outcome = state["agent_outcome"]

        if isinstance(agent_outcome, AgentFinish):
            return "end"

        action = _first_action(agent_outcome)
        if action is not None and action.tool is not None:
            return "use_tool"

        raise ValueError(
            f"Unexpected agent outcome; cannot route: {agent_outcome!r}"
        )

    graph = StateGraph(GraphState)

    graph.add_node("agent", execute_agent)
    graph.add_node("tools", execute_tool)
    graph.add_node("finish", finish)

    graph.set_entry_point("agent")

    # parse() returns one of the keys below; map it to the next node.
    graph.add_conditional_edges(
        "agent",
        parse,
        {
            "use_tool": "tools",
            "end": "finish",
        },
    )

    graph.add_edge("tools", "finish")
    graph.add_edge("finish", END)

    compiled_graph = graph.compile()

    return compiled_graph