# Pilatopia / agent.py
# Last change: "Update agent.py" by dareenharthi (commit 23335f7, verified).
import json
from typing import Dict, List, Any, TypedDict, Annotated
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
import re
from datetime import datetime
from system_prompts import *
import json
from tools import *
from utils import *
class AgentState(TypedDict):
    """State shared by the nodes of the Pilatopia workflow graph."""

    # Conversation so far. Annotated with the ``add_messages`` reducer, so
    # message updates returned by nodes are merged into this list rather than
    # replacing it.
    messages: Annotated[list, add_messages]
    # Plan text produced by the planner agent.
    current_plan: str
    # "Arabic" or "English", set by the router node.
    language_detected: str
    # NOTE(review): not written by any node of this workflow — confirm it is
    # still needed.
    response_draft: str
    # Manager verdict, or the manager's revision feedback for the planner.
    manager_feedback: str
    # Reply produced by the execution agent.
    final_response: str
    # Incremented by the manager node on every plan review.
    iteration_count: int
    # Router decision: True when the message needs a plan before execution.
    needs_planning: bool
    # Manager decision: True when the plan may proceed to execution.
    plan_approved: bool
    # Request category set by the router (currently always "GENERAL").
    # Declared so the router's value is kept: LangGraph only retains state
    # keys that are part of this schema.
    request_category: str
class PilatopiaAgentSystem:
    """Manager-first multi-agent workflow built on LangGraph.

    Every turn enters at ``router_agent``, which decides whether the latest
    message needs a plan:

    * ``needs_planning``: ``planner_agent`` drafts a plan, ``manager_agent``
      either accepts it or sends it back to the planner with feedback (at most
      ``MAX_PLAN_REVIEWS`` reviews), then ``execution_agent`` replies.
    * otherwise: ``execution_agent`` replies directly.

    Only the execution agent is given tools (``tools_list``).
    """

    # After this many manager reviews in one turn the latest plan is accepted
    # regardless of the verdict, so the planner <-> manager loop terminates.
    MAX_PLAN_REVIEWS = 2

    # Arabic, Arabic Supplement, Arabic Extended-A and Arabic Presentation
    # Forms-A/-B code points (U+FEFF, the byte-order mark, is excluded).
    _ARABIC_CHARS = re.compile(
        r"[\u0600-\u06FF\u0750-\u077F\u08A0-\u08FF\uFB50-\uFDFF\uFE70-\uFEFC]"
    )
    # Decision tags emitted by the router / manager prompts; tolerate
    # surrounding whitespace and letter case inside the tags.
    _ROUTER_NEEDS_PLANNING = re.compile(
        r"<router_decision>\s*needs_planning\s*</router_decision>", re.IGNORECASE
    )
    _MANAGER_ACCEPT = re.compile(
        r"<manager_verify>\s*accept\s*</manager_verify>", re.IGNORECASE
    )
    _MANAGER_FEEDBACK = re.compile(
        r"<feedback_comment>(.*?)</feedback_comment>", re.DOTALL
    )

    def __init__(
        self,
        openai_api_key: str,
        graph_image_path: str = "pilatopia_graph_agents.png",
    ):
        """Create the four agents and compile the workflow graph.

        Args:
            openai_api_key: API key passed to every ``ChatOpenAI`` client.
            graph_image_path: File to save a Mermaid PNG of the compiled graph
                to. Pass an empty string (or ``None``) to skip it. Saving is
                best-effort: a failure is printed and construction continues.
        """
        # gpt-4o-mini routes, reviews and executes; gpt-4o plans.
        self.llm_base = ChatOpenAI(model="gpt-4o-mini", api_key=openai_api_key)
        self.llm_planner = ChatOpenAI(model="gpt-4o", api_key=openai_api_key)

        tools_json = json.dumps(tools, indent=2)
        self.router_agent = create_react_agent(
            model=self.llm_base,
            tools=[],
            prompt=router_prompt,
            name="router_agent",
        )
        self.planner_agent = create_react_agent(
            model=self.llm_planner,
            tools=[],
            prompt=planner_prompt.format(
                tools=tools_json,
                pilatopia_complete_info=pilatopia_complete_info,
                saudi_consumer_protection=saudi_consumer_protection,
            ),
            name="planner_agent",
        )
        self.manager_agent = create_react_agent(
            model=self.llm_base,
            tools=[],
            prompt=manager_prompt.format(
                pilatopia_complete_info=pilatopia_complete_info,
                saudi_consumer_protection=saudi_consumer_protection,
                tools=tools_json,
                verify_tool_check_prompt=verify_tool_check_prompt,
            ),
            name="manager_agent",
        )
        self.execution_agent = create_react_agent(
            model=self.llm_base,
            tools=tools_list,
            prompt=execution_prompt.format(tools=tools_json),
            name="execution_agent",
        )
        # Formatted planner input of the most recent planning call; shown to
        # the manager so it can review the plan against it.
        self.planner_context = ""
        self.graph = self._build_graph()
        if graph_image_path:
            self._save_graph_image(graph_image_path)

    def _save_graph_image(self, path: str) -> None:
        """Best-effort: write a Mermaid PNG of ``self.graph`` to *path*."""
        try:
            # The image is diagnostic only; rendering may depend on an external
            # renderer and the target may be unwritable, so never let it
            # prevent the agent from being constructed.
            png = self.graph.get_graph().draw_mermaid_png()
            with open(path, "wb") as f:
                f.write(png)
        except Exception as exc:
            print(f"Could not save agent graph image to {path!r}: {exc}")

    def _build_graph(self) -> StateGraph:
        """Build and compile the manager-first workflow.

        Returns the compiled graph (what ``StateGraph.compile()`` returns).
        """
        workflow = StateGraph(AgentState)
        # Each node wraps one of the react agents created in __init__.
        workflow.add_node("router_agent", self.router_decision_node)
        workflow.add_node("planner_agent", self.planning_agent_node)
        workflow.add_node("manager_agent", self.manager_approval_node)
        workflow.add_node("execution_agent", self.execution_agent_node)

        workflow.add_edge(START, "router_agent")
        # Router: plan first, or answer directly.
        workflow.add_conditional_edges(
            "router_agent",
            self.route_from_router_decision,
            {
                "needs_planning": "planner_agent",
                "direct_execution": "execution_agent",
            },
        )
        # Every plan is reviewed by the manager.
        workflow.add_edge("planner_agent", "manager_agent")
        # Manager: execute an approved plan, or send it back for revision.
        workflow.add_conditional_edges(
            "manager_agent",
            self.route_from_manager_approval,
            {
                "approved": "execution_agent",
                "revise_plan": "planner_agent",
            },
        )
        workflow.add_edge("execution_agent", END)
        return workflow.compile()

    def route_from_router_decision(self, state: AgentState) -> str:
        """Edge key after the router: "needs_planning" or "direct_execution"."""
        return "needs_planning" if state.get("needs_planning", False) else "direct_execution"

    def route_from_manager_approval(self, state: AgentState) -> str:
        """Edge key after the manager: "approved" or "revise_plan"."""
        return "approved" if state.get("plan_approved", False) else "revise_plan"

    @staticmethod
    def _as_text(content: Any) -> str:
        """Return message *content* as a string (``str()`` of non-string content)."""
        return content if isinstance(content, str) else str(content)

    @classmethod
    def _last_ai_text(cls, messages: List[Any]) -> str:
        """Text of the last ``AIMessage`` in *messages*, or "" if there is none."""
        for msg in reversed(messages):
            if isinstance(msg, AIMessage):
                return cls._as_text(msg.content)
        return ""

    def _run_agent(self, agent: Any, messages: List[Any]) -> str:
        """Invoke a react *agent* on *messages* and return its final AI reply text."""
        response = agent.invoke({"messages": messages})
        return self._last_ai_text(response["messages"])

    def _latest_message_text(self, state: AgentState) -> str:
        """Text of the newest message in the state (the current user turn), or ""."""
        messages = state.get("messages") or []
        return self._as_text(messages[-1].content) if messages else ""

    @staticmethod
    def _speaker_label(msg: Any) -> str:
        """Speaker prefix used for *msg* in the conversation transcript."""
        if hasattr(msg, "type"):
            return {"human": "Human", "ai": "AI"}.get(msg.type, "System")
        class_name = type(msg).__name__
        if "Human" in class_name:
            return "Human"
        if "AI" in class_name:
            return "AI"
        return "Message"

    def get_conversation_history(self, state: AgentState) -> str:
        """Render all messages except the newest as ``"<Speaker>: <content>"`` lines.

        Returns "" when the state holds at most one message.
        """
        earlier = state["messages"][:-1]  # the newest message is the current turn
        return "\n".join(f"{self._speaker_label(m)}: {m.content}" for m in earlier)

    def router_decision_node(self, state: AgentState) -> AgentState:
        """Entry node: decide whether to plan, detect the reply language, reset turn state."""
        latest_message = self._latest_message_text(state)
        prompt = router_input.format(
            previous_context=self.get_conversation_history(state),
            new_user_input=latest_message,
        )
        content = self._run_agent(self.router_agent, [SystemMessage(content=prompt)])
        print("router decision", content)

        state["needs_planning"] = bool(self._ROUTER_NEEDS_PLANNING.search(content))
        # Only Arabic script selects Arabic; other non-ASCII characters (curly
        # quotes, emoji, accented Latin letters) do not.
        state["language_detected"] = (
            "Arabic" if self._ARABIC_CHARS.search(latest_message) else "English"
        )
        state["request_category"] = "GENERAL"  # Default category

        # Every turn starts here: drop planning state from an earlier turn so a
        # stale plan is never executed and the review cap counts this turn only.
        state["current_plan"] = ""
        state["manager_feedback"] = ""
        state["plan_approved"] = False
        state["iteration_count"] = 0
        return state

    def planning_agent_node(self, state: AgentState) -> AgentState:
        """Draft a plan for a complex query, or revise it after a manager rejection."""
        self.planner_context = planner_input.format(
            user_message=self._latest_message_text(state),
            conversation_history=self.get_conversation_history(state),
        )
        messages = [SystemMessage(content=self.planner_context)]

        # On a revision pass, give the planner the rejected plan and the
        # manager's feedback; without them it would receive exactly the same
        # input as before and the feedback would be wasted.
        previous_plan = state.get("current_plan", "")
        feedback = state.get("manager_feedback", "")
        if previous_plan and feedback and not state.get("plan_approved", False):
            messages.append(HumanMessage(content=(
                "The manager rejected your previous plan.\n\n"
                f"Previous plan:\n{previous_plan}\n\n"
                f"Manager feedback:\n{feedback}\n\n"
                "Write a revised plan that addresses this feedback."
            )))

        content = self._run_agent(self.planner_agent, messages)
        state["current_plan"] = content
        print("---" * 20)
        print("Planning agent response:", content)
        print("---" * 20)
        return state

    def manager_approval_node(self, state: AgentState) -> AgentState:
        """Review the current plan; approve it or record feedback for a revision."""
        prompt = manager_input.format(
            agent_system_prompt=self.planner_context,
            initial_user_prompt=self._latest_message_text(state),
            messages=format_messages_with_actions(state["messages"]),
            current_plan=state.get("current_plan", ""),
        )
        content = self._run_agent(self.manager_agent, [SystemMessage(content=prompt)])
        print("---" * 20)
        print("Manager decision content:", content)
        print("---" * 20)

        reviews = state.get("iteration_count", 0) + 1
        state["iteration_count"] = reviews
        if self._MANAGER_ACCEPT.search(content):
            state["plan_approved"] = True
            state["manager_feedback"] = "Plan approved"
        elif reviews >= self.MAX_PLAN_REVIEWS:
            # Prevent an endless planner <-> manager loop: accept the plan as is.
            state["plan_approved"] = True
            state["manager_feedback"] = "Plan approved after max iterations"
        else:
            state["plan_approved"] = False
            match = self._MANAGER_FEEDBACK.search(content)
            feedback = match.group(1).strip() if match else ""
            state["manager_feedback"] = feedback or "Please revise the plan"
        return state

    def execution_agent_node(self, state: AgentState) -> AgentState:
        """Generate the final reply and append it to the conversation as an AIMessage."""
        original_message = self._latest_message_text(state)
        print("---" * 20)
        conversation_history = self.get_conversation_history(state)
        print("conversation_history:", conversation_history)
        print("---" * 20)
        plan_context = state.get("current_plan", "")
        print("##" * 20)
        print("Plan context:", plan_context)
        print("##" * 20)

        prompt = execution_input.format(
            language=state.get("language_detected", "English"),
            user_message=original_message,
            conversation_history=conversation_history,
            plan=plan_context,
        )
        content = self._run_agent(self.execution_agent, [SystemMessage(content=prompt)])
        state["final_response"] = content
        state["messages"] = state["messages"] + [AIMessage(content=content)]
        return state