Spaces:
Sleeping
Sleeping
File size: 6,688 Bytes
3973360 f87934b 3973360 f87934b 3973360 f87934b 3973360 ef0145e 3973360 ef0145e f87934b 3973360 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 | from dotenv import load_dotenv
load_dotenv()
from src.langgraph.langchain.llm import llm
from src.langgraph.tools.destination_tools import destination_suggestion
from src.langgraph.tools.search_tools import search_and_summarize_website
from .react_agent import create_react_agent
from src.langgraph.langchain.prompt import planner_prompt, parser_output_planner_prompt
from typing import TypedDict, Union, Any, Optional
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.exceptions import OutputParserException
from langgraph.graph import StateGraph, START, END
from src.utils.logger import logger
from src.utils.helper import format_include_destinations
import operator
from typing import Annotated
from .utils import format_log_to_str
from langchain.agents.output_parsers import ReActSingleInputOutputParser
class State(TypedDict):
    """Shared LangGraph state for the trip-planner ReAct loop.

    NOTE: the misspelled keys (``tools_ouput``, ``limit_interation``,
    ``current_interation``) are read and written under exactly these names
    by every node in this module; renaming one requires updating all nodes
    and whatever code builds the initial state for ``planner_app``.
    """

    # Parsed agent steps (AgentAction / AgentFinish). The operator.add
    # annotation is the LangGraph reducer: lists returned by nodes are
    # concatenated onto the existing list instead of replacing it.
    llm_response: Annotated[list[Union[AgentAction, Any]], operator.add]
    # Tool observations, one appended per executed tool; paired positionally
    # with llm_response (via zip) when the scratchpad is rebuilt.
    tools_ouput: Annotated[list[str], operator.add]
    # Last OutputParserException raised while parsing the agent's reply;
    # reset to None by nodes that complete successfully.
    error: Optional[Any]
    # Trip parameters forwarded into the planner / output prompt partials.
    duration: str
    start_date: str
    location: str
    interests: str
    nation: str
    # Destinations to include; rendered with format_include_destinations
    # before being passed to the planner prompt.
    include_destination: list
    # Number of tool executions after which the graph forces parse_output.
    limit_interation: int
    # Number of tool executions performed so far.
    current_interation: int
    # Final formatted answer written by the parse_output node.
    final_answer: str
# Parses raw ReAct-formatted LLM text into AgentAction / AgentFinish objects.
parser = ReActSingleInputOutputParser()
# Tools handed to create_react_agent when the agent node builds its agent.
tools = [destination_suggestion, search_and_summarize_website]
# Lookup from each tool's registered name to the tool object itself.
tools_mapping = {registered.name: registered for registered in tools}
async def agent_fn(state: State):
    """Run one reasoning step of the planner ReAct agent.

    Builds ``planner_prompt`` from the trip parameters in ``state``, replays
    the earlier (action, tool output) pairs as the agent scratchpad, invokes
    the agent and parses its reply with the module-level ReAct ``parser``.

    Args:
        state: Current graph state.

    Returns:
        A partial state update. On success ``{"llm_response": [step],
        "error": None}``, where ``step`` is the parsed ``AgentAction`` or
        ``AgentFinish``. When the reply cannot be parsed, ``{"error": exc}``
        carrying the ``OutputParserException``.
    """
    llm_response = state["llm_response"]
    tools_output = state["tools_ouput"]
    error = state["error"]

    prompt = planner_prompt.partial(
        duration=state["duration"],
        start_date=state["start_date"],
        location=state["location"],
        interests=state["interests"],
        nation=state["nation"],
        include_destination=format_include_destinations(
            state["include_destination"]
        ),
    )

    # Replay earlier steps; zip pairs each action with the tool output it produced.
    agent_scratchpad = (
        format_log_to_str(zip(llm_response, tools_output), llm_prefix="")
        if llm_response
        else ""
    )

    if error:
        if isinstance(error, OutputParserException):
            # Some ReAct parser failures carry no ``observation``; fall back to
            # the exception message instead of injecting the text "None".
            error = error.observation or str(error)
        # The leading space keeps the error text and the instruction from
        # running together ("...errorso agent ...").
        agent_scratchpad += (
            "\nPrevious response have error: "
            + str(error)
            + " so agent will try to recover. Please return in right format defined in prompt"
        )

    agent = create_react_agent(llm, tools, prompt)
    try:
        response = await agent.ainvoke(agent_scratchpad)
        logger.info(f"-> Agent response {response.content}")
        parsed_step: Union[AgentAction, AgentFinish] = parser.parse(response.content)
    except OutputParserException as e:
        # after_call_agent routes OutputParserException states to parse_output.
        logger.error(f"Error in agent invoke {e}")
        return {
            "error": e,
        }
    return {
        "llm_response": [parsed_step],
        "error": None,
    }
def after_call_agent(state: State):
    """Choose the node to run after the ``agent`` node.

    Returns:
        ``"parse_output"`` when the agent's reply failed to parse, when the
        latest step is not an ``AgentAction`` (e.g. an ``AgentFinish`` final
        answer), or when there is no step at all. Otherwise ``"execute_tools"``.
    """
    if isinstance(state["error"], OutputParserException):
        logger.info("-> paser output 1")
        return "parse_output"
    # Look at the last step only after the error check: a parse failure on the
    # first call leaves llm_response empty, and indexing [-1] would raise.
    steps = state["llm_response"]
    last_step = steps[-1] if steps else None
    if not isinstance(last_step, AgentAction):
        # AgentFinish has no .tool / .tool_input, so execute_tools would crash on it.
        logger.info("-> Final answer")
        return "parse_output"
    logger.info("-> Tool")
    return "execute_tools"
def _strip_markdown_wrappers(name: str) -> str:
    """Return ``name`` without surrounding ``**``, ``*`` or backtick wrappers.

    LLMs sometimes emit tool names as ``**tool**`` or ```tool```. Wrappers are
    removed repeatedly, so nested forms such as ```**tool**``` also resolve.
    """
    while name:
        if name.startswith("**") and name.endswith("**"):
            name = name[2:-2]
        elif name[0] in "*`" and name.endswith(name[0]):
            name = name[1:-1]
        else:
            break
    return name


async def excute_tools_fn(state: State):
    """Execute the tool requested by the agent's latest ``AgentAction``.

    The tool name is cleaned of markdown wrappers, then dispatched to
    ``destination_suggestion`` or ``search_and_summarize_website``.

    Args:
        state: Current graph state; its last ``llm_response`` entry is
            expected to be an ``AgentAction``.

    Returns:
        A partial state update that appends the tool observation to
        ``tools_ouput``, clears ``error`` and increments
        ``current_interation``. An unknown tool name produces an explanatory
        observation (instead of an UnboundLocalError), so the agent can
        correct itself on the next step.
    """
    llm_response: AgentAction = state["llm_response"][-1]
    tool_args = llm_response.tool_input
    tool_call_name = _strip_markdown_wrappers(llm_response.tool)
    logger.info(f"-> Original tool name: {llm_response.tool}")
    logger.info(f"-> Processed tool name: {tool_call_name}")
    logger.info(f"-> Tool args: {tool_args}")

    if tool_call_name == "destination_suggestion":
        tool_response = await destination_suggestion.ainvoke(
            {"query": tool_args, "config": ""}
        )
    elif tool_call_name == "search_and_summarize_website":
        tool_response = await search_and_summarize_website.ainvoke({"query": tool_args})
    else:
        logger.error(f"-> Unknown tool requested: {tool_call_name}")
        tool_response = (
            f"{tool_call_name} is not a valid tool, "
            f"try one of [{', '.join(tools_mapping)}]."
        )

    logger.info("-> Agent")
    # Incrementing even for unknown tools keeps the loop bounded by limit_interation.
    return {
        "tools_ouput": [tool_response],
        "error": None,
        "current_interation": state["current_interation"] + 1,
    }
async def parser_output_fn(state: State):
    """Produce the final answer from the accumulated agent trace.

    Rebuilds the scratchpad from the (action, tool output) pairs. It then
    appends the agent's own final-answer text when the last step is an
    ``AgentFinish``, and a recovery note when a parse error is pending.
    Finally it asks the LLM, through ``parser_output_planner_prompt``, to
    format the result.

    Args:
        state: Current graph state.

    Returns:
        ``{"final_answer": <content of the LLM reply>}``.
    """
    llm_response = state["llm_response"]
    tools_output = state["tools_ouput"]
    error = state["error"]

    # zip() pairs each action with its tool output and drops any trailing
    # step that has no output (such as a final AgentFinish).
    agent_scratchpad = (
        format_log_to_str(zip(llm_response, tools_output), llm_prefix="")
        if llm_response
        else ""
    )
    if llm_response and isinstance(llm_response[-1], AgentFinish):
        # Keep the agent's own final answer, which the zip above discarded.
        agent_scratchpad += "\n" + llm_response[-1].log

    if error:
        if isinstance(error, OutputParserException):
            # Some ReAct parser failures carry no ``observation``; fall back to
            # the exception message instead of injecting the text "None".
            error = error.observation or str(error)
        # The leading space keeps the error text and the instruction apart.
        agent_scratchpad += (
            "\nPrevious response have error: "
            + str(error)
            + " so agent will try to recover. Please return in right format defined in prompt"
        )

    prompt = parser_output_planner_prompt.partial(
        duration=state["duration"],
        start_date=state["start_date"],
        location=state["location"],
        interests=state["interests"],
        nation=state["nation"],
    )
    chain_output = prompt | llm
    output = await chain_output.ainvoke({"agent_scratchpad": agent_scratchpad})
    return {
        "final_answer": output.content,
    }
# Assemble the planner graph over the shared State schema.
workflow = StateGraph(State)
workflow.add_node("agent", agent_fn)
workflow.add_node("execute_tools", excute_tools_fn)
workflow.add_node("parse_output", parser_output_fn)
# Every run begins with an agent reasoning step.
workflow.add_edge(START, "agent")
# after_call_agent returns the destination node's name, so its path map is
# an identity mapping over the two possible routes.
workflow.add_conditional_edges(
    "agent",
    after_call_agent,
    {route: route for route in ("parse_output", "execute_tools")},
)
def after_execute_tools(state: State):
    """Route after a tool run.

    Returns ``"parse_output"`` once ``current_interation`` has reached
    ``limit_interation``; otherwise ``"agent"`` for another reasoning step.
    """
    budget_spent = state["current_interation"] >= state["limit_interation"]
    return "parse_output" if budget_spent else "agent"
# Tool results loop back to the agent until the iteration budget is spent,
# then flow into the final formatter; the path map is an identity mapping.
workflow.add_conditional_edges(
    "execute_tools",
    after_execute_tools,
    {route: route for route in ("parse_output", "agent")},
)
# parse_output is the only terminal node.
workflow.add_edge("parse_output", END)
planner_app = workflow.compile()
|