Spaces:
Runtime error
Runtime error
Create agent.py
Browse files
agent.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# GAIA Agent Solution with LangGraph and OpenAI
|
| 2 |
+
import json
import operator
import os
import re
from typing import TypedDict, Annotated, Sequence, Union

from langchain_community.tools import DuckDuckGoSearchResults
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END
from openai import OpenAI  # For vision capabilities
|
| 12 |
+
|
| 13 |
+
# Set your OpenAI API key
|
| 14 |
+
openai_api_key = os.getenv("OPENAI_API_KEY") # Read from the environment (None when unset); never hard-code a key here
|
| 15 |
+
|
| 16 |
+
# ---------------------
|
| 17 |
+
# Tool Definitions
|
| 18 |
+
# ---------------------
|
| 19 |
+
|
| 20 |
+
# Web Search Tool
|
| 21 |
+
search_tool = DuckDuckGoSearchResults(max_results=3)  # web-search tool; max_results=3 presumably caps the hits returned per query
|
| 22 |
+
|
| 23 |
+
# Image Description Tool (using GPT-4 Vision)
|
| 24 |
+
@tool
def describe_image(image_url: str) -> str:
    """Generate detailed description of an image from its URL"""
    # NOTE: the docstring above is intentionally left verbatim -- @tool
    # presumably exposes it to the LLM as the tool description, so editing it
    # would change runtime behaviour (confirm before rewording).
    #
    # Sends a single user message (text instruction + image URL) to the
    # chat-completions endpoint and returns the text of the first choice.
    vision_client = OpenAI()
    response = vision_client.chat.completions.create(
        # "gpt-4-vision-preview" has been retired by OpenAI; gpt-4o is the
        # replacement named in OpenAI's deprecation notice and accepts the
        # same image_url content parts.
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image in detail. Include text, objects, colors, and context."},
                    {"type": "image_url", "image_url": {"url": image_url}},
                ],
            }
        ],
        max_tokens=500,
    )
    # The SDK types message.content as optional; fall back to an empty string
    # so the declared `str` return type always holds.
    return response.choices[0].message.content or ""
|
| 42 |
+
|
| 43 |
+
# Math Tool (example - extend with more capabilities)
|
| 44 |
+
@tool
def calculate(expression: str) -> Union[float, str]:
    """Evaluate mathematical expressions. Input must be a valid math expression."""
    # NOTE: the docstring above is intentionally left verbatim -- @tool
    # presumably exposes it to the LLM as the tool description (confirm
    # before rewording).
    #
    # SECURITY: eval() runs arbitrary Python, and `expression` is authored by
    # the LLM, whose context can contain untrusted web-search text. Replace
    # this with a restricted evaluator (e.g. an ast-based arithmetic walker
    # or numexpr) before exposing the agent to untrusted input.
    try:
        return eval(expression)  # For real usage, use a safe evaluator like numexpr
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt / SystemExit
        # still propagate; every evaluation failure maps to the same message.
        return "Error: Invalid expression"
|
| 51 |
+
|
| 52 |
+
# ---------------------
|
| 53 |
+
# Agent Setup
|
| 54 |
+
# ---------------------
|
| 55 |
+
|
| 56 |
+
# Available tools
|
| 57 |
+
# Registry of callable tools; run_tools looks entries up by their `.name`.
tools = [search_tool, describe_image, calculate]
# The same tools converted to the OpenAI tool-schema form passed to the model.
tools_as_openai = list(map(convert_to_openai_tool, tools))
|
| 59 |
+
|
| 60 |
+
# Agent State Definition
|
| 61 |
+
class AgentState(TypedDict):
    """State object passed between the graph nodes."""

    # Conversation history. The `operator.add` metadata is the LangGraph
    # reducer for this key: the list a node returns under "messages" is
    # concatenated onto the existing history instead of replacing it.
    # (`operator` must be imported at module top, otherwise this class body
    # raises NameError when the module is loaded.)
    messages: Annotated[Sequence[BaseMessage], operator.add]
|
| 63 |
+
|
| 64 |
+
# Initialize LLM (GPT-4 Turbo for best results)
|
| 65 |
+
model = ChatOpenAI(model="gpt-4-turbo", temperature=0)  # chat model invoked by run_agent; temperature=0 to minimise sampling randomness
|
| 66 |
+
|
| 67 |
+
# ---------------------
|
| 68 |
+
# Graph Nodes
|
| 69 |
+
# ---------------------
|
| 70 |
+
|
| 71 |
+
def run_agent(state: AgentState):
    """Node: ask the model for its next message.

    Invokes the module-level ``model`` on the accumulated conversation with
    the OpenAI-format tool schemas attached, and returns a partial state
    update containing only the model's reply.
    """
    reply = model.invoke(state["messages"], tools=tools_as_openai)
    return {"messages": [reply]}
|
| 76 |
+
|
| 77 |
+
def run_tools(state: AgentState):
    """Node: execute the tool calls requested by the latest message.

    Reads the raw OpenAI-style ``tool_calls`` entries from the last message's
    ``additional_kwargs``, runs the matching tool from ``tools`` for each one,
    and returns a partial state update with one ``ToolMessage`` per call.
    Failures (unknown tool, malformed JSON arguments, tool exceptions) are
    reported back as message text instead of being raised.
    """
    messages = state["messages"]
    last_message = messages[-1]

    tool_messages = []
    # `or []` also covers a "tool_calls" key that is present but set to None.
    for tool_call in last_message.additional_kwargs.get("tool_calls") or []:
        function_name = tool_call["function"]["name"]

        # Find matching tool. Named `selected_tool` so the imported `tool`
        # decorator is not shadowed inside this function.
        selected_tool = next((t for t in tools if t.name == function_name), None)

        if selected_tool is None:
            content = f"Tool {function_name} not available"
        else:
            try:
                # Decoded inside the try block: malformed JSON from the model
                # must become an error reply for this call id rather than
                # aborting the node and leaving the call unanswered.
                function_args = json.loads(tool_call["function"]["arguments"])

                # Special handling for image URLs in questions: when the model
                # passed a reference instead of a URL, resolve it from context.
                if function_name == "describe_image" and "http" not in function_args["image_url"]:
                    function_args["image_url"] = find_image_url(messages, function_args["image_url"])

                # Execute tool
                output = selected_tool.invoke(function_args)
                content = f"Tool Result: {str(output)}"
            except Exception as e:
                content = f"Error: {str(e)}"

        tool_messages.append(
            ToolMessage(
                content=content,
                tool_call_id=tool_call["id"],
            )
        )

    return {"messages": tool_messages}
|
| 112 |
+
|
| 113 |
+
# ---------------------
|
| 114 |
+
# Helper Functions
|
| 115 |
+
# ---------------------
|
| 116 |
+
|
| 117 |
+
# An http(s) URL running up to the next whitespace character.
_URL_RE = re.compile(r"https?://\S+")
# A recognised image extension at the end of the URL path (before ?query / #fragment).
_IMAGE_EXT_RE = re.compile(r"\.(?:jpe?g|png|gif|webp)(?=$|[?#])", re.IGNORECASE)
# Sentence / markdown punctuation that commonly clings to a pasted URL.
_URL_TRAILING_PUNCT = ".,;:!?)]}>\"'"


def find_image_url(messages: Sequence[BaseMessage], reference: str) -> str:
    """Extract actual image URL from message context.

    Scans ``messages`` in order for the first one whose text contains
    ``reference`` and returns an image URL found in that text.

    Args:
        messages: Conversation messages; only ``.content`` is read.
        reference: Text the model passed in place of a real URL.

    Returns:
        The first URL in the matching message that ends in an image
        extension. If there is none but the message mentions "jpg" or "png"
        and contains a URL, the first URL is returned (the previous
        behaviour). If nothing matches, ``reference`` is returned unchanged.
    """
    for msg in messages:
        text = msg.content
        # Content may be a list of parts rather than a string; skip those
        # instead of failing on string-only operations.
        if not isinstance(text, str) or reference not in text:
            continue

        urls = [u.rstrip(_URL_TRAILING_PUNCT) for u in _URL_RE.findall(text)]
        if not urls:
            continue

        # Prefer a URL that actually points at an image over whichever link
        # happens to appear first in the message.
        for url in urls:
            if _IMAGE_EXT_RE.search(url):
                return url

        # Backward-compatible fallback: the message mentions an image type
        # somewhere, so hand back its first link.
        if "jpg" in text or "png" in text:
            return urls[0]

    return reference  # Fallback to original reference
|
| 126 |
+
|
| 127 |
+
# ---------------------
|
| 128 |
+
# Graph Construction
|
| 129 |
+
# ---------------------
|
| 130 |
+
|
| 131 |
+
# Decision logic for graph flow
|
| 132 |
+
def should_continue(state: AgentState):
    """Choose the edge to follow after the agent node.

    Returns:
        "run_tools" when the most recent message carries tool calls,
        otherwise "end". An empty history, or a last message without a
        ``tool_calls`` attribute, also yields "end" instead of raising.
    """
    messages = state["messages"]
    if not messages:
        return "end"
    # getattr: not every message type is guaranteed to expose `tool_calls`.
    if getattr(messages[-1], "tool_calls", None):
        return "run_tools"
    return "end"
|
| 137 |
+
|
| 138 |
+
# Build the graph
|
| 139 |
+
# Two nodes wired into a loop: the agent node runs first, and after each
# agent turn should_continue's return value selects the next step.
graph = StateGraph(AgentState)
graph.add_node("run_agent", run_agent)
graph.add_node("run_tools", run_tools)
graph.set_entry_point("run_agent")

# Map should_continue's return values onto destinations.
graph.add_conditional_edges(
    "run_agent", should_continue, {"run_tools": "run_tools", "end": END}
)

# Tool output always flows back to the agent node.
graph.add_edge("run_tools", "run_agent")
agent = graph.compile()
|
| 155 |
+
|
| 156 |
+
# ---------------------
|
| 157 |
+
# Execution Function
|
| 158 |
+
# ---------------------
|
| 159 |
+
|
| 160 |
+
def solve_gaia_task(question: str) -> str:
    """Run the compiled agent on one GAIA question and return its answer.

    The conversation is seeded with a fixed system prompt followed by the
    question. After the graph finishes, the history is searched from newest
    to oldest for an AI message without tool calls; its content is the
    answer. If no such message exists, "No final answer found" is returned.
    """
    system_prompt = (
        "You are a GAIA problem-solving expert. Follow these rules:\n"
        "1. Use tools for current information\n"
        "2. Break complex problems into steps\n"
        "3. Verify answers before finalizing\n"
        "4. Format final answers EXACTLY as requested:\n"
        " - Lists: comma-separated values\n"
        " - Numbers: digits only\n"
        " - Dates: YYYY-MM-DD format\n"
        "5. Never include reasoning in final answers"
    )

    conversation = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=question),
    ]
    final_state = agent.invoke({"messages": conversation})

    # Newest-first scan; tool-requesting AI turns are intermediate steps.
    answers = (
        reply.content
        for reply in reversed(final_state["messages"])
        if reply.type == "ai" and not reply.tool_calls
    )
    return next(answers, "No final answer found")
|
| 190 |
+
|
| 191 |
+
# ---------------------
|
| 192 |
+
# Example Execution
|
| 193 |
+
# ---------------------
|
| 194 |
+
if __name__ == "__main__":
    # Demo run: one sample GAIA-style question, answer printed to stdout.
    task = (
        "What is the current population of the country where the 2023 "
        "World Artificial Intelligence Conference was held? "
        "Include only the numeric value in your answer."
    )

    result = solve_gaia_task(task)
    print("\n--- FINAL ANSWER ---", result, sep="\n")
|