LuisZermeno's picture
Update agent.py
fa47a9d verified
import ast
import asyncio
import logging
import operator
import os
import re
from typing import Annotated, Any, Dict, List, Optional

from langchain.tools import Tool
from langchain_anthropic import ChatAnthropic
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
# Module logger; the root logger is configured to emit INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Load system prompt
def load_system_prompt():
try:
with open("system_prompt.txt", "r") as f:
return f.read()
except FileNotFoundError:
return """You are a helpful AI assistant designed to answer questions accurately and concisely.
When answering questions:
1. Be direct and precise
2. For numerical answers, provide ONLY the number
3. For yes/no questions, answer ONLY 'yes' or 'no'
4. For names or single words, provide ONLY that word
5. Always end your response with: FINAL ANSWER: [your answer]"""
SYSTEM_PROMPT = load_system_prompt()
# Define state
class GraphState(BaseModel):
    """State for the agent graph.

    ``messages`` carries LangGraph's ``add_messages`` reducer, so the
    ``{"messages": [...]}`` updates returned by nodes are merged into the
    existing conversation instead of overwriting the whole list (which would
    drop the user's question after the first turn).
    """

    # Conversation history (human, assistant and tool messages), merged via add_messages.
    messages: Annotated[List[Any], add_messages] = Field(default_factory=list)
    # Answer text parsed from a "FINAL ANSWER:" marker, once one has been found.
    final_answer_text: Optional[str] = None
    # Number of assistant turns taken so far.
    iterations: int = Field(default=0)
    # Turn budget; once reached, the graph routes to answer extraction.
    max_iterations: int = Field(default=5)
# Tools setup
def setup_tools():
"""Initialize and return all tools"""
tools = []
# Web search tool
try:
search = DuckDuckGoSearchRun()
web_search = Tool(
name="web_search",
func=search.run,
description="Search the web for current information"
)
tools.append(web_search)
except Exception as e:
logger.warning(f"Could not initialize web search: {e}")
# Wikipedia tool
try:
wikipedia = WikipediaAPIWrapper()
wiki_tool = Tool(
name="wikipedia",
func=wikipedia.run,
description="Search Wikipedia for information"
)
tools.append(wiki_tool)
except Exception as e:
logger.warning(f"Could not initialize Wikipedia: {e}")
# Calculator tool
def calculate(expression: str) -> str:
"""Safely evaluate mathematical expressions"""
try:
# Remove any dangerous characters
safe_chars = "0123456789+-*/()., "
expression = ''.join(c for c in expression if c in safe_chars)
result = eval(expression)
return str(result)
except Exception as e:
return f"Error: {str(e)}"
calc_tool = Tool(
name="calculator",
func=calculate,
description="Perform mathematical calculations"
)
tools.append(calc_tool)
return tools
# Create the agent
class GAIAAgent:
    """LangGraph agent that answers a question with Claude plus optional tools.

    The graph loops ``assistant -> tools -> assistant``. It stops when the
    model emits ``FINAL ANSWER:``, stops requesting tools, or has used
    ``GraphState.max_iterations`` assistant turns. In that last case
    ``extract_answer`` asks the model for an answer based on a transcript of
    the conversation.
    """

    FINAL_ANSWER_MARKER = "FINAL ANSWER:"

    def __init__(self):
        self.llm = ChatAnthropic(
            model="claude-3-5-sonnet-20241022",
            temperature=0,
            max_tokens=1024,
            api_key=os.getenv("ANTHROPIC_API_KEY"),
        )
        self.tools = setup_tools()
        # Name -> tool lookup used when executing tool calls.
        self._tools_by_name = {tool.name: tool for tool in self.tools}
        # Bind the tool schemas once rather than on every assistant turn.
        self._llm_with_tools = self.llm.bind_tools(self.tools) if self.tools else self.llm

    @staticmethod
    def _message_text(message) -> str:
        """Return a message's plain text, joining text blocks when content is a list."""
        content = getattr(message, "content", "")
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and part.get("type") == "text":
                    parts.append(str(part.get("text", "")))
            return "\n".join(parts)
        return "" if content is None else str(content)

    @classmethod
    def _parse_final_answer(cls, text: str) -> Optional[str]:
        """Return the stripped text after the last FINAL ANSWER marker, or None if absent."""
        if cls.FINAL_ANSWER_MARKER not in text:
            return None
        return text.split(cls.FINAL_ANSWER_MARKER)[-1].strip()

    def create_graph(self):
        """Create and compile the state graph."""
        workflow = StateGraph(GraphState)

        # Add nodes
        workflow.add_node("assistant", self.assistant_node)
        workflow.add_node("tools", self.tools_node)
        workflow.add_node("extract_answer", self.extract_answer_node)

        # Set entry point
        workflow.set_entry_point("assistant")

        # Add edges
        workflow.add_conditional_edges(
            "assistant",
            self.should_continue,
            {
                "tools": "tools",
                "extract_answer": "extract_answer",
                "end": END,
            },
        )
        workflow.add_edge("tools", "assistant")
        workflow.add_edge("extract_answer", END)
        return workflow.compile()

    def assistant_node(self, state: GraphState) -> Dict:
        """Run one model turn over the conversation.

        The system prompt is not stored in the state, so it is prepended on
        every call; otherwise the model would lose its instructions after the
        first turn. If the reply contains ``FINAL ANSWER:``, the text after it
        is stored as ``final_answer_text``.
        """
        messages = [SystemMessage(content=SYSTEM_PROMPT)] + list(state.messages)
        response = self._llm_with_tools.invoke(messages)

        update = {"messages": [response], "iterations": state.iterations + 1}
        answer = self._parse_final_answer(self._message_text(response))
        if answer is not None:
            update["final_answer_text"] = answer
        return update

    def tools_node(self, state: GraphState) -> Dict:
        """Execute the tool calls requested by the last assistant message.

        Emits one ``ToolMessage`` per tool call, tagged with that call's id.
        The model can then pair every tool request with its result; the
        Anthropic API rejects a turn whose tool calls have no results.
        Unknown tools and tool failures come back as ``"Error: ..."`` results
        instead of being dropped.
        """
        last_message = state.messages[-1] if state.messages else None
        tool_calls = getattr(last_message, "tool_calls", None) or []

        results = []
        for tool_call in tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call.get("args", {})
            tool = self._tools_by_name.get(tool_name)
            if tool is None:
                result = f"Error: unknown tool '{tool_name}'"
            else:
                try:
                    if isinstance(tool_args, dict) and len(tool_args) == 1:
                        # Single-input tools: pass the lone argument value straight through.
                        result = tool.func(next(iter(tool_args.values())))
                    else:
                        result = tool.func(str(tool_args))
                except Exception as e:
                    result = f"Error: {str(e)}"
            results.append(ToolMessage(
                content=str(result),
                tool_call_id=tool_call.get("id") or "",
                name=tool_name,
            ))
        return {"messages": results}

    def extract_answer_node(self, state: GraphState) -> Dict:
        """Produce ``final_answer_text`` after the assistant loop has stopped.

        First looks, newest first, for a ``FINAL ANSWER:`` marker in the
        assistant's own messages. Failing that, asks the model for a final
        answer and includes a plain-text transcript of the conversation, so
        the model actually has the context it is asked about. Tool-use
        pairing rules do not apply to a single text prompt.
        """
        for message in reversed(state.messages):
            if isinstance(message, AIMessage):
                answer = self._parse_final_answer(self._message_text(message))
                if answer is not None:
                    return {"final_answer_text": answer}

        lines = []
        for message in state.messages:
            text = self._message_text(message)
            tool_calls = getattr(message, "tool_calls", None)
            if tool_calls:
                calls = ", ".join(f"{call.get('name')}({call.get('args')})" for call in tool_calls)
                text = f"{text}\n[tool calls: {calls}]".strip()
            if text:
                lines.append(f"{type(message).__name__}: {text}")
        transcript = "\n\n".join(lines)

        prompt = (
            f"Conversation so far:\n{transcript}\n\n"
            "Based on our conversation, please provide your final answer. "
            "Format: FINAL ANSWER: [your answer]"
        )
        response = self.llm.invoke([SystemMessage(content=SYSTEM_PROMPT), HumanMessage(content=prompt)])
        answer = self._parse_final_answer(self._message_text(response))
        if answer is not None:
            return {"final_answer_text": answer}
        return {"final_answer_text": "Unable to determine answer"}

    def should_continue(self, state: GraphState) -> str:
        """Route after an assistant turn: ``"end"``, ``"tools"`` or ``"extract_answer"``."""
        if state.final_answer_text:
            return "end"
        if state.iterations >= state.max_iterations:
            return "extract_answer"
        last_message = state.messages[-1] if state.messages else None
        if last_message is None:
            return "end"
        if getattr(last_message, "tool_calls", None):
            return "tools"
        if self.FINAL_ANSWER_MARKER in self._message_text(last_message):
            return "extract_answer"
        return "end"
# Main agent function
async def basic_agent(question: str) -> str:
    """Answer ``question`` with a freshly built GAIAAgent graph.

    Returns:
        str: The graph's ``final_answer_text`` when it is non-empty. Otherwise
        the text of the newest message that has any, or
        ``"Unable to determine answer"``. Any exception is logged with its
        traceback and returned as ``"Error: <message>"`` rather than raised.
    """
    try:
        agent = GAIAAgent()
        graph = agent.create_graph()

        initial_state = GraphState(messages=[HumanMessage(content=question)])
        result = await graph.ainvoke(initial_state)

        if result.get("final_answer_text"):
            return result["final_answer_text"]

        # Fallback: newest message with non-empty text. Block-style (list)
        # content is reduced to its text parts so tool-call blocks are not
        # returned as a stringified list.
        for message in reversed(result.get("messages", [])):
            content = getattr(message, "content", None)
            if isinstance(content, list):
                texts = []
                for part in content:
                    if isinstance(part, str):
                        texts.append(part)
                    elif isinstance(part, dict) and part.get("type") == "text":
                        texts.append(str(part.get("text", "")))
                content = "\n".join(texts)
            if content:
                return str(content)
        return "Unable to determine answer"
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception("Error in basic_agent: %s", e)
        return f"Error: {str(e)}"
# For testing: manual smoke test that runs the agent on one sample question.
if __name__ == "__main__":
    test_question = "What is the capital of France?"
    answer = asyncio.run(basic_agent(test_question))
    print(f"Question: {test_question}")
    print(f"Answer: {answer}")