# NOTE(review): this file was pasted from a Hugging Face Space whose run page
# reported "Runtime error"; the surrounding page text and table formatting were
# not part of the original module.
import ast
import asyncio
import logging
import operator
import os
from typing import Annotated, Any, Dict, List, Optional

from langchain.tools import Tool
from langchain_anthropic import ChatAnthropic
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
# Configure logging: INFO level for the whole process, one named logger
# per module (standard `__name__` convention).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Load system prompt
def load_system_prompt() -> str:
    """Return the agent's system prompt.

    Reads ``system_prompt.txt`` from the current working directory; if the
    file is missing, falls back to a built-in default that enforces the
    ``FINAL ANSWER:`` output format the graph's routing logic depends on.

    Returns:
        The prompt text as a string.
    """
    try:
        # Explicit encoding avoids platform-dependent default codecs.
        with open("system_prompt.txt", "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        return """You are a helpful AI assistant designed to answer questions accurately and concisely.
When answering questions:
1. Be direct and precise
2. For numerical answers, provide ONLY the number
3. For yes/no questions, answer ONLY 'yes' or 'no'
4. For names or single words, provide ONLY that word
5. Always end your response with: FINAL ANSWER: [your answer]"""

# Loaded once at import time; all nodes share this prompt.
SYSTEM_PROMPT = load_system_prompt()
# Define state
class GraphState(BaseModel):
    """State carried between nodes of the agent graph.

    Attributes:
        messages: Conversation history. Annotated with the ``add_messages``
            reducer so node return values are APPENDED to the history.
            Without a reducer LangGraph *replaces* the field with each
            node's return value, so the conversation would be lost after
            every step — this is why ``add_messages`` is imported.
        final_answer_text: The extracted final answer, once available.
        iterations: Number of assistant turns taken so far.
        max_iterations: Cap on assistant turns before forcing extraction.
    """
    messages: Annotated[List[Any], add_messages] = Field(default_factory=list)
    final_answer_text: Optional[str] = None
    iterations: int = Field(default=0)
    max_iterations: int = Field(default=5)
# Tools setup
def setup_tools():
    """Initialize and return all tools.

    Each tool is optional: if its backing library fails to initialize, a
    warning is logged and the remaining tools are still returned, so the
    agent degrades gracefully rather than crashing at startup.

    Returns:
        A list of ``Tool`` objects (web search, Wikipedia, calculator).
    """
    tools = []
    # Web search tool
    try:
        search = DuckDuckGoSearchRun()
        web_search = Tool(
            name="web_search",
            func=search.run,
            description="Search the web for current information"
        )
        tools.append(web_search)
    except Exception as e:
        logger.warning(f"Could not initialize web search: {e}")
    # Wikipedia tool
    try:
        wikipedia = WikipediaAPIWrapper()
        wiki_tool = Tool(
            name="wikipedia",
            func=wikipedia.run,
            description="Search Wikipedia for information"
        )
        tools.append(wiki_tool)
    except Exception as e:
        logger.warning(f"Could not initialize Wikipedia: {e}")

    # Calculator tool
    def calculate(expression: str) -> str:
        """Safely evaluate a mathematical expression.

        The input comes from the model, i.e. it is untrusted; instead of
        ``eval`` the expression is parsed with ``ast`` and only arithmetic
        node types are interpreted, so no code execution is possible.
        """
        binops = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
        }
        unaryops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

        def _eval(node):
            # Recursively interpret the whitelisted arithmetic AST nodes.
            if isinstance(node, ast.Expression):
                return _eval(node.body)
            if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
                return node.value
            if isinstance(node, ast.BinOp) and type(node.op) in binops:
                return binops[type(node.op)](_eval(node.left), _eval(node.right))
            if isinstance(node, ast.UnaryOp) and type(node.op) in unaryops:
                return unaryops[type(node.op)](_eval(node.operand))
            raise ValueError(f"unsupported expression element: {type(node).__name__}")

        try:
            # Keep the original character whitelist as a first filter.
            safe_chars = "0123456789+-*/()., "
            expression = ''.join(c for c in expression if c in safe_chars)
            result = _eval(ast.parse(expression, mode="eval"))
            return str(result)
        except Exception as e:
            return f"Error: {str(e)}"

    calc_tool = Tool(
        name="calculator",
        func=calculate,
        description="Perform mathematical calculations"
    )
    tools.append(calc_tool)
    return tools
# Create the agent
class GAIAAgent:
    """LangGraph agent: an assistant node that may call tools, plus a
    final-answer extraction node.

    Graph topology: assistant -> (tools -> assistant)* -> extract_answer -> END,
    with routing decided by :meth:`should_continue`.
    """

    def __init__(self):
        # Claude model; the API key is read from the environment.
        self.llm = ChatAnthropic(
            model="claude-3-5-sonnet-20241022",
            temperature=0,
            max_tokens=1024,
            api_key=os.getenv("ANTHROPIC_API_KEY")
        )
        self.tools = setup_tools()

    def create_graph(self):
        """Create and compile the state graph."""
        workflow = StateGraph(GraphState)
        # Add nodes
        workflow.add_node("assistant", self.assistant_node)
        workflow.add_node("tools", self.tools_node)
        workflow.add_node("extract_answer", self.extract_answer_node)
        # Set entry point
        workflow.set_entry_point("assistant")
        # Add edges
        workflow.add_conditional_edges(
            "assistant",
            self.should_continue,
            {
                "tools": "tools",
                "extract_answer": "extract_answer",
                "end": END
            }
        )
        workflow.add_edge("tools", "assistant")
        workflow.add_edge("extract_answer", END)
        return workflow.compile()

    def assistant_node(self, state: GraphState) -> Dict:
        """Run the LLM on the conversation; detect an inline FINAL ANSWER."""
        messages = state.messages
        # Prepend the system prompt on the first turn only.
        if state.iterations == 0:
            messages = [SystemMessage(content=SYSTEM_PROMPT)] + messages
        # Bind tools when available so the model can emit tool calls.
        if self.tools:
            response = self.llm.bind_tools(self.tools).invoke(messages)
        else:
            response = self.llm.invoke(messages)
        # Short-circuit if the model already produced the final answer.
        if isinstance(response.content, str) and "FINAL ANSWER:" in response.content:
            answer = response.content.split("FINAL ANSWER:")[-1].strip()
            return {
                "messages": [response],
                "final_answer_text": answer,
                "iterations": state.iterations + 1
            }
        return {
            "messages": [response],
            "iterations": state.iterations + 1
        }

    def tools_node(self, state: GraphState) -> Dict:
        """Execute the tool calls requested by the last assistant message.

        Results are returned as ``ToolMessage`` objects carrying the
        matching ``tool_call_id``: the Anthropic messages API requires each
        tool_use block to be answered by a tool_result block, so replying
        with a plain AIMessage (as this node previously did) makes the next
        model invocation fail.
        """
        messages = state.messages
        last_message = messages[-1]
        tool_messages = []
        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
            for tool_call in last_message.tool_calls:
                tool_name = tool_call["name"]
                tool_args = tool_call.get("args", {})
                call_id = tool_call.get("id", "")
                # Find the requested tool by name.
                matched = next((t for t in self.tools if t.name == tool_name), None)
                if matched is None:
                    # Report unknown tools instead of silently dropping the call.
                    tool_messages.append(ToolMessage(
                        content=f"Error: unknown tool '{tool_name}'",
                        tool_call_id=call_id
                    ))
                    continue
                try:
                    # All tools here take a single string argument; unwrap a
                    # one-key args dict to its value.
                    if isinstance(tool_args, dict) and len(tool_args) == 1:
                        result = matched.func(next(iter(tool_args.values())))
                    else:
                        result = matched.func(str(tool_args))
                except Exception as e:
                    result = f"Error: {str(e)}"
                tool_messages.append(ToolMessage(
                    content=str(result),
                    tool_call_id=call_id
                ))
        return {"messages": tool_messages}

    def extract_answer_node(self, state: GraphState) -> Dict:
        """Extract a FINAL ANSWER from the conversation, or ask the LLM for one."""
        # Look through messages newest-first for an explicit answer.
        for message in reversed(state.messages):
            if hasattr(message, "content") and message.content:
                content = str(message.content)
                if "FINAL ANSWER:" in content:
                    answer = content.split("FINAL ANSWER:")[-1].strip()
                    return {"final_answer_text": answer}
        # If no explicit final answer, ask for one
        prompt = "Based on our conversation, please provide your final answer. Format: FINAL ANSWER: [your answer]"
        response = self.llm.invoke([HumanMessage(content=prompt)])
        if "FINAL ANSWER:" in response.content:
            answer = response.content.split("FINAL ANSWER:")[-1].strip()
            return {"final_answer_text": answer}
        return {"final_answer_text": "Unable to determine answer"}

    def should_continue(self, state: GraphState) -> str:
        """Route from the assistant node: 'tools', 'extract_answer', or 'end'."""
        if state.final_answer_text:
            return "end"
        if state.iterations >= state.max_iterations:
            return "extract_answer"
        last_message = state.messages[-1] if state.messages else None
        if last_message and hasattr(last_message, "tool_calls") and last_message.tool_calls:
            return "tools"
        if last_message and "FINAL ANSWER:" in str(last_message.content):
            return "extract_answer"
        return "end"
# Main agent function
async def basic_agent(question: str) -> str:
    """Process a question with a fresh agent and return the answer.

    Args:
        question: Natural-language question to answer.

    Returns:
        The extracted answer string, or an ``"Error: ..."`` message if the
        run fails.
    """
    try:
        # Create agent and compiled graph per call.
        agent = GAIAAgent()
        graph = agent.create_graph()
        # Run the graph
        initial_state = GraphState(
            messages=[HumanMessage(content=question)]
        )
        result = await graph.ainvoke(initial_state)
        # Extract answer
        if result.get("final_answer_text"):
            return result["final_answer_text"]
        # Fallback: scan messages newest-first. Strip the FINAL ANSWER
        # marker when present so callers get only the answer text rather
        # than the whole formatted message.
        for message in reversed(result.get("messages", [])):
            if hasattr(message, "content") and message.content:
                content = str(message.content)
                if "FINAL ANSWER:" in content:
                    return content.split("FINAL ANSWER:")[-1].strip()
                return content
        return "Unable to determine answer"
    except Exception as e:
        # Top-level boundary: log and degrade to an error string.
        logger.error(f"Error in basic_agent: {str(e)}")
        return f"Error: {str(e)}"
# For testing
if __name__ == "__main__":
    # asyncio is already imported at module level; the redundant local
    # re-import was removed.
    test_question = "What is the capital of France?"
    answer = asyncio.run(basic_agent(test_question))
    print(f"Question: {test_question}")
    print(f"Answer: {answer}")