Aya1610 committed on
Commit
afdb032
·
verified ·
1 Parent(s): 3f0d696

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +204 -0
agent.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GAIA Agent Solution with LangGraph and OpenAI
2
import ast
import json
import operator
import os
import re
from typing import TypedDict, Annotated, Sequence, Union

from langchain_community.tools import DuckDuckGoSearchResults
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END
from openai import OpenAI  # For vision capabilities
12
+
13
# OpenAI API key, read from the environment (None when OPENAI_API_KEY is unset).
# Export the variable in your shell instead of hard-coding a key in this file.
openai_api_key = os.getenv("OPENAI_API_KEY")
15
+
16
+ # ---------------------
17
+ # Tool Definitions
18
+ # ---------------------
19
+
20
# Web Search Tool
# NOTE(review): some langchain_community releases expose the result count as
# `num_results`; confirm that `max_results` is honored by the installed version.
search_tool = DuckDuckGoSearchResults(max_results=3)
22
+
23
+ # Image Description Tool (using GPT-4 Vision)
24
@tool
def describe_image(image_url: str) -> str:
    """Generate detailed description of an image from its URL"""
    # The docstring above doubles as the tool description shown to the LLM,
    # so developer notes live in comments instead.
    #
    # Sends one user message holding a text instruction plus the image URL to
    # the chat completions endpoint and returns the model's text reply.
    vision_client = OpenAI()
    response = vision_client.chat.completions.create(
        # "gpt-4-vision-preview" has been retired by OpenAI; "gpt-4o" is the
        # vision-capable replacement.
        model="gpt-4o",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Describe this image in detail. Include text, objects, colors, and context.",
                    },
                    {"type": "image_url", "image_url": {"url": image_url}},
                ],
            }
        ],
        max_tokens=500,
    )
    # `content` can be None (e.g. a refusal); keep the declared `str` return type.
    return response.choices[0].message.content or ""
42
+
43
+ # Math Tool (example - extend with more capabilities)
44
# Binary operators the calculator accepts, keyed by their AST node type.
_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
}

# Unary operators the calculator accepts.
_UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}

# Largest exponent magnitude allowed, so inputs such as 9**9**9 cannot hang the agent.
_MAX_EXPONENT = 1000


def _evaluate_node(node):
    """Recursively evaluate a whitelisted arithmetic AST node.

    Raises ValueError for any node that is not a plain number, a supported
    unary/binary operator, or the top-level expression wrapper.
    """
    if isinstance(node, ast.Expression):
        return _evaluate_node(node.body)
    if isinstance(node, ast.Constant):
        # bool is a subclass of int; reject it together with strings, complex, etc.
        if isinstance(node.value, bool) or not isinstance(node.value, (int, float)):
            raise ValueError("unsupported constant")
        return node.value
    if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
        return _UNARY_OPS[type(node.op)](_evaluate_node(node.operand))
    if isinstance(node, ast.BinOp) and type(node.op) in _BINARY_OPS:
        left = _evaluate_node(node.left)
        right = _evaluate_node(node.right)
        if isinstance(node.op, ast.Pow) and abs(right) > _MAX_EXPONENT:
            raise ValueError("exponent too large")
        return _BINARY_OPS[type(node.op)](left, right)
    raise ValueError("unsupported expression")


@tool
def calculate(expression: str) -> Union[float, str]:
    """Evaluate mathematical expressions. Input must be a valid math expression using numbers, parentheses and the operators + - * / // % **."""
    # SECURITY: the expression comes from the LLM, so it is untrusted. The
    # previous implementation passed it to eval(); it is now parsed with `ast`
    # and only arithmetic nodes are evaluated. Names, calls and attribute
    # access are rejected with the same error string as before.
    if not isinstance(expression, str):
        return "Error: Invalid expression"
    try:
        # strip() because ast.parse, unlike eval(), rejects leading whitespace.
        return _evaluate_node(ast.parse(expression.strip(), mode="eval"))
    except (SyntaxError, ValueError, TypeError, ArithmeticError, RecursionError):
        return "Error: Invalid expression"
51
+
52
+ # ---------------------
53
+ # Agent Setup
54
+ # ---------------------
55
+
56
# Available tools
tools = [search_tool, describe_image, calculate]
# The same tools converted to OpenAI tool schemas; run_agent passes these to the model.
tools_as_openai = [convert_to_openai_tool(t) for t in tools]
59
+
60
+ # Agent State Definition
61
class AgentState(TypedDict):
    """State passed between the graph nodes."""

    # Conversation history. The `operator.add` annotation is the LangGraph
    # reducer: messages returned by a node are concatenated onto this list
    # instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
63
+
64
# Initialize LLM (GPT-4 Turbo for best results)
# temperature=0 asks for the least random sampling the API offers.
model = ChatOpenAI(model="gpt-4-turbo", temperature=0)
66
+
67
+ # ---------------------
68
+ # Graph Nodes
69
+ # ---------------------
70
+
71
def run_agent(state: AgentState):
    """Node: Run the agent's reasoning.

    Sends the accumulated messages to the chat model together with the
    OpenAI tool schemas and returns the model's reply.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.

    Returns:
        A partial state update ``{"messages": [response]}`` containing the
        single model response.
    """
    response = model.invoke(state["messages"], tools=tools_as_openai)
    return {"messages": [response]}
76
+
77
def run_tools(state: AgentState):
    """Node: Execute tools based on agent's request.

    Reads the raw OpenAI tool calls from the last message's
    ``additional_kwargs`` and produces exactly one ``ToolMessage`` per call,
    so every ``tool_call_id`` receives a reply even when the call fails.

    Args:
        state: Current graph state; the last entry of ``state["messages"]``
            is expected to be the model reply that requested the tools.

    Returns:
        A partial state update ``{"messages": [...]}`` with the tool results
        (an empty list when the last message carries no tool calls).
    """
    messages = state["messages"]
    last_message = messages[-1]

    tool_messages = []
    for tool_call in last_message.additional_kwargs.get("tool_calls", []):
        function_name = tool_call["function"]["name"]

        # Find matching tool (named `selected_tool` so the imported `tool`
        # decorator is not shadowed).
        selected_tool = next((t for t in tools if t.name == function_name), None)

        if selected_tool is None:
            content = f"Tool {function_name} not available"
        else:
            try:
                # Parsed inside the try: malformed JSON from the model becomes an
                # error message for this call instead of crashing the whole node.
                function_args = json.loads(tool_call["function"]["arguments"])

                # Special handling for image URLs in questions: the model may pass
                # a reference (e.g. a file name) instead of the real URL.
                if function_name == "describe_image" and "http" not in function_args["image_url"]:
                    function_args["image_url"] = find_image_url(messages, function_args["image_url"])

                # Execute tool
                output = selected_tool.invoke(function_args)
                content = f"Tool Result: {str(output)}"
            except Exception as e:
                # Deliberately broad: any tool failure is reported back to the
                # model as text so it can react to it.
                content = f"Error: {str(e)}"

        tool_messages.append(
            ToolMessage(
                content=content,
                tool_call_id=tool_call["id"],
            )
        )

    return {"messages": tool_messages}
112
+
113
+ # ---------------------
114
+ # Helper Functions
115
+ # ---------------------
116
+
117
# Matches an http(s) URL up to the next whitespace character.
_URL_RE = re.compile(r"https?://\S+")
# Sentence punctuation that often trails a URL in prose but is not part of it.
_TRAILING_PUNCTUATION = ".,;:!?)]}>\"'"
# File extensions that mark a URL as pointing at an image.
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".gif", ".webp")


def find_image_url(messages: Sequence[BaseMessage], reference: str) -> str:
    """Extract actual image URL from message context.

    Scans the messages in order for one whose text contains ``reference`` and
    returns a URL from that message.

    Args:
        messages: Conversation messages; only ``msg.content`` is read, and
            messages whose content is not a string are skipped.
        reference: Text the model supplied instead of a URL (e.g. a file name).

    Returns:
        The first URL in the matching message that contains an image file
        extension (case-insensitive). If none does but the message mentions
        "jpg" or "png", the first URL of that message. Otherwise ``reference``
        unchanged.
    """
    for msg in messages:
        content = msg.content
        if not isinstance(content, str) or reference not in content:
            continue

        urls = [match.rstrip(_TRAILING_PUNCTUATION) for match in _URL_RE.findall(content)]
        if not urls:
            continue

        # Prefer a URL that itself looks like an image.
        for url in urls:
            lowered_url = url.lower()
            if any(ext in lowered_url for ext in _IMAGE_EXTENSIONS):
                return url

        # Fallback kept from the earlier heuristic: the message mentions an image
        # format somewhere, so assume its first URL is the image.
        lowered_content = content.lower()
        if "jpg" in lowered_content or "png" in lowered_content:
            return urls[0]

    return reference  # Fallback to original reference
126
+
127
+ # ---------------------
128
+ # Graph Construction
129
+ # ---------------------
130
+
131
+ # Decision logic for graph flow
132
def should_continue(state: AgentState):
    """Decide which node runs after the agent node.

    Args:
        state: Current graph state; only the last message is inspected.

    Returns:
        "run_tools" when the last message carries a non-empty ``tool_calls``
        list, otherwise "end". Messages without a ``tool_calls`` attribute
        count as having none.
    """
    last_message = state["messages"][-1]
    if getattr(last_message, "tool_calls", None):
        return "run_tools"
    return "end"
137
+
138
# Build the graph: two nodes, with the agent node as the entry point.
graph = StateGraph(AgentState)
graph.add_node("run_agent", run_agent)
graph.add_node("run_tools", run_tools)
graph.set_entry_point("run_agent")

# After each agent step, should_continue picks the next hop: execute the
# requested tools, or stop.
graph.add_conditional_edges(
    "run_agent",
    should_continue,
    {
        "run_tools": "run_tools",
        "end": END,
    },
)

# Tool results always go back to the agent for another reasoning step.
graph.add_edge("run_tools", "run_agent")
agent = graph.compile()
155
+
156
+ # ---------------------
157
+ # Execution Function
158
+ # ---------------------
159
+
160
def solve_gaia_task(question: str) -> str:
    """Solve GAIA tasks with our agent.

    Builds the initial state (system prompt plus the question), runs the
    compiled graph, and returns the last AI message that has no tool calls.

    Args:
        question: The task text, sent to the model as the human message.

    Returns:
        The content of the final AI answer, or "No final answer found" when
        the run produced no AI message without tool calls.
    """
    system_prompt = (
        "You are a GAIA problem-solving expert. Follow these rules:\n"
        "1. Use tools for current information\n"
        "2. Break complex problems into steps\n"
        "3. Verify answers before finalizing\n"
        "4. Format final answers EXACTLY as requested:\n"
        " - Lists: comma-separated values\n"
        " - Numbers: digits only\n"
        " - Dates: YYYY-MM-DD format\n"
        "5. Never include reasoning in final answers"
    )

    # Initialize agent state
    state = {
        "messages": [
            SystemMessage(content=system_prompt),
            HumanMessage(content=question),
        ]
    }

    # Run the agent
    final_state = agent.invoke(state)

    # Extract and return final answer: walk backwards so the most recent
    # tool-free AI message wins.
    for msg in reversed(final_state["messages"]):
        if msg.type == "ai" and not msg.tool_calls:
            return msg.content
    return "No final answer found"
190
+
191
+ # ---------------------
192
+ # Example Execution
193
+ # ---------------------
194
if __name__ == "__main__":
    # Example GAIA task
    task = (
        "What is the current population of the country where the 2023 "
        "World Artificial Intelligence Conference was held? "
        "Include only the numeric value in your answer."
    )

    result = solve_gaia_task(task)
    print("\n--- FINAL ANSWER ---")
    print(result)