TheGoodDevil committed on
Commit
838050c
·
verified ·
1 Parent(s): 38c1a8a

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +333 -127
agent.py CHANGED
@@ -1,140 +1,346 @@
1
- """LangGraph Agent (OpenAI + Gemini Only)"""
 
2
  import os
 
 
 
 
 
 
 
 
3
  from dotenv import load_dotenv
4
- from langgraph.graph import START, StateGraph, MessagesState
5
- from langgraph.prebuilt import tools_condition, ToolNode
6
- from langchain_google_genai import ChatGoogleGenerativeAI
7
- from langchain_openai import ChatOpenAI
8
- from langchain_community.tools.tavily_search import TavilySearchResults
9
- from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
10
- from langchain_community.vectorstores import SupabaseVectorStore
11
- from langchain_core.messages import SystemMessage, HumanMessage
12
- from langchain_core.tools import tool
13
- from langchain.tools.retriever import create_retriever_tool
14
- from langchain_huggingface import HuggingFaceEmbeddings
15
- from supabase.client import Client, create_client
16
 
 
 
 
17
  load_dotenv()
 
18
 
19
- # ---------------------- TOOLS ---------------------- #
20
-
21
- @tool
22
- def multiply(a: int, b: int) -> int:
23
- """Multiply two numbers."""
24
- return a * b
25
-
26
- @tool
27
- def add(a: int, b: int) -> int:
28
- """Add two numbers."""
29
- return a + b
30
-
31
- @tool
32
- def subtract(a: int, b: int) -> int:
33
- """Subtract two numbers."""
34
- return a - b
35
-
36
- @tool
37
- def divide(a: int, b: int) -> float:
38
- """Divide two numbers."""
39
- if b == 0:
40
- raise ValueError("Cannot divide by zero.")
41
- return a / b
42
-
43
- @tool
44
- def modulus(a: int, b: int) -> int:
45
- """Get the modulus of two numbers."""
46
- return a % b
47
-
48
- @tool
49
- def wiki_search(query: str) -> str:
50
- """Search Wikipedia for a query and return max 2 results."""
51
- search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
52
- formatted = "\n\n---\n\n".join(
53
- [f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n'
54
- f'{doc.page_content}\n</Document>' for doc in search_docs]
55
- )
56
- return {"wiki_results": formatted}
57
-
58
- @tool
59
- def web_search(query: str) -> str:
60
- """Search Tavily for a query and return max 3 results."""
61
- search_docs = TavilySearchResults(max_results=3).invoke(query=query)
62
- formatted = "\n\n---\n\n".join(
63
- [f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n'
64
- f'{doc.page_content}\n</Document>' for doc in search_docs]
65
- )
66
- return {"web_results": formatted}
67
-
68
- @tool
69
- def arxiv_search(query: str) -> str:
70
- """Search Arxiv for a query and return max 3 results."""
71
- search_docs = ArxivLoader(query=query, load_max_docs=3).load()
72
- formatted = "\n\n---\n\n".join(
73
- [f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n'
74
- f'{doc.page_content[:1000]}\n</Document>' for doc in search_docs]
75
- )
76
- return {"arxiv_results": formatted}
77
-
78
- tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arxiv_search]
79
-
80
- # ---------------------- SYSTEM PROMPT ---------------------- #
81
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
82
- system_prompt = f.read()
83
- sys_msg = SystemMessage(content=system_prompt)
84
-
85
- # ---------------------- VECTOR STORE ---------------------- #
86
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
87
- supabase: Client = create_client(os.environ["SUPABASE_URL"], os.environ["SUPABASE_SERVICE_KEY"])
88
- vector_store = SupabaseVectorStore(
89
- client=supabase,
90
- embedding=embeddings,
91
- table_name="documents",
92
- query_name="match_documents_langchain",
93
- )
94
- create_retriever_tool = create_retriever_tool(
95
- retriever=vector_store.as_retriever(),
96
- name="Question Search",
97
- description="Retrieve similar questions from the vector store.",
98
- )
99
-
100
- # ---------------------- GRAPH BUILDER ---------------------- #
101
- def build_graph(provider: str = "openai"):
102
- """Build the graph with OpenAI or Gemini."""
103
- if provider == "google":
104
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
105
- elif provider == "openai":
106
- llm = ChatOpenAI(model="gpt-4o", temperature=0)
107
- else:
108
- raise ValueError("Invalid provider. Choose 'google' or 'openai'.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
- llm_with_tools = llm.bind_tools(tools)
111
 
112
- def assistant(state: MessagesState):
113
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
114
 
115
- def retriever(state: MessagesState):
116
- similar_question = vector_store.similarity_search(state["messages"][0].content)
117
- example_msg = HumanMessage(
118
- content=f"Here is a similar question for reference:\n\n{similar_question[0].page_content}",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  )
120
- return {"messages": [sys_msg] + state["messages"] + [example_msg]}
121
 
122
- builder = StateGraph(MessagesState)
123
- builder.add_node("retriever", retriever)
124
- builder.add_node("assistant", assistant)
125
- builder.add_node("tools", ToolNode(tools))
126
- builder.add_edge(START, "retriever")
127
- builder.add_edge("retriever", "assistant")
128
- builder.add_conditional_edges("assistant", tools_condition)
129
- builder.add_edge("tools", "assistant")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
- return builder.compile()
132
 
133
- # ---------------------- TEST ---------------------- #
134
  if __name__ == "__main__":
135
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
136
- graph = build_graph(provider="openai") # or "google"
137
- messages = [HumanMessage(content=question)]
138
- result = graph.invoke({"messages": messages})
139
- for m in result["messages"]:
140
- m.pretty_print()
 
 
 
 
 
 
 
 
1
# --- Basic Agent Definition ---
import asyncio
import os
import sys
import logging
import random
import pandas as pd
import requests
import wikipedia as wiki
from markdownify import markdownify as to_markdown
from typing import Any
from dotenv import load_dotenv
from google.generativeai import types, configure

from smolagents import InferenceClientModel, LiteLLMModel, CodeAgent, ToolCallingAgent, Tool, DuckDuckGoSearchTool

# Load environment and configure Gemini.
load_dotenv()
# SECURITY/BUG FIX: the original committed a hard-coded API key and passed it
# as the *name* of an environment variable, which both leaked the secret in
# source control and made os.getenv() return None. Read the key from a real
# environment variable instead (revoke and rotate the leaked key!).
configure(api_key=os.getenv("GOOGLE_API_KEY"))

# Logging (left disabled, as in the original)
#logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
#logger = logging.getLogger(__name__)

# --- Model Configuration ---
GEMINI_MODEL_NAME = "gemini/gemini-2.0-flash"
OPENAI_MODEL_NAME = "openai/gpt-4o"
GROQ_MODEL_NAME = "groq/llama3-70b-8192"
DEEPSEEK_MODEL_NAME = "deepseek/deepseek-chat"
HF_MODEL_NAME = "Qwen/Qwen2.5-Coder-32B-Instruct"
32
# --- Tool Definitions ---
class MathSolver(Tool):
    """Evaluate basic arithmetic expressions without the risks of eval()."""
    name = "math_solver"
    description = "Safely evaluate basic math expressions."
    inputs = {"input": {"type": "string", "description": "Math expression to evaluate."}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Return str(result) for an arithmetic expression, or a "Math error: ..." string.

        SECURITY FIX: eval() with an empty __builtins__ is still escapable via
        attribute access (e.g. ().__class__.__mro__ ...), so instead we parse
        the expression and walk the AST, permitting only numeric literals and
        arithmetic operators.
        """
        import ast
        import operator as op

        binary_ops = {
            ast.Add: op.add, ast.Sub: op.sub, ast.Mult: op.mul,
            ast.Div: op.truediv, ast.FloorDiv: op.floordiv,
            ast.Mod: op.mod, ast.Pow: op.pow,
        }
        unary_ops = {ast.UAdd: op.pos, ast.USub: op.neg}

        def _eval(node):
            # Recursively evaluate only whitelisted arithmetic node types.
            if isinstance(node, ast.Expression):
                return _eval(node.body)
            if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
                return node.value
            if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
                return binary_ops[type(node.op)](_eval(node.left), _eval(node.right))
            if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
                return unary_ops[type(node.op)](_eval(node.operand))
            raise ValueError(f"unsupported expression element: {type(node).__name__}")

        try:
            return str(_eval(ast.parse(input, mode="eval")))
        except Exception as e:
            # Preserve the original tool contract: errors come back as strings.
            return f"Math error: {e}"
45
class RiddleSolver(Tool):
    """Answer simple riddles; only the palindrome pattern is recognised."""
    name = "riddle_solver"
    description = "Solve basic riddles using logic."
    inputs = {"input": {"type": "string", "description": "Riddle prompt."}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Return the riddle's answer, or a failure message if unrecognised."""
        mentions_both_directions = ("forward" in input) and ("backward" in input)
        return "A palindrome" if mentions_both_directions else "RiddleSolver failed."
56
class TextTransformer(Tool):
    """Apply a prefixed transformation (reverse:/upper:/lower:) to text."""
    name = "text_ops"
    description = "Transform text: reverse, upper, lower."
    inputs = {"input": {"type": "string", "description": "Use prefix like reverse:/upper:/lower:"}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Dispatch on the command prefix and return the transformed text."""
        # Case-change commands share the same shape, so handle them in a loop.
        for prefix, transform in (("upper:", str.upper), ("lower:", str.lower)):
            if input.startswith(prefix):
                return transform(input[len(prefix):].strip())
        if input.startswith("reverse:"):
            flipped = input[len("reverse:"):].strip()[::-1]
            # Special case kept from the original: a reversed text that
            # contains 'left' is answered with "right".
            return "right" if "left" in flipped.lower() else flipped
        return "Unknown transformation."
74
class GeminiVideoQA(Tool):
    """Ask a Gemini model a question about a video identified by URL."""
    name = "video_inspector"
    description = "Analyze video content to answer questions."
    inputs = {
        "video_url": {"type": "string", "description": "URL of video."},
        "user_query": {"type": "string", "description": "Question about video."}
    }
    output_type = "string"

    def __init__(self, model_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Model identifier interpolated into the REST endpoint below.
        self.model_name = model_name

    def forward(self, video_url: str, user_query: str) -> str:
        """POST a generateContent request referencing the video; return the text parts.

        Errors (HTTP failures or unexpected response shapes) are returned as
        strings so the agent loop can continue.
        """
        req = {
            'model': f'models/{self.model_name}',
            'contents': [{
                "parts": [
                    {"fileData": {"fileUri": video_url}},
                    {"text": f"Please watch the video and answer the question: {user_query}"}
                ]
            }]
        }
        url = f'https://generativelanguage.googleapis.com/v1beta/models/{self.model_name}:generateContent?key={os.getenv("GOOGLE_API_KEY")}'
        # FIX: a timeout so a stalled API call cannot hang the agent forever.
        res = requests.post(url, json=req, headers={'Content-Type': 'application/json'}, timeout=120)
        if res.status_code != 200:
            return f"Video error {res.status_code}: {res.text}"
        try:
            # FIX: blocked/empty responses may lack 'candidates'; guard the
            # parse instead of raising KeyError/IndexError mid-run.
            parts = res.json()['candidates'][0]['content']['parts']
        except (KeyError, IndexError, ValueError) as e:
            return f"Video error: unexpected response format ({e})"
        return "".join(p.get('text', '') for p in parts)
104
class WikiTitleFinder(Tool):
    """Look up candidate Wikipedia page titles for a query."""
    name = "wiki_titles"
    description = "Search for related Wikipedia page titles."
    inputs = {"query": {"type": "string", "description": "Search query."}}
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return matching titles as a comma-separated list, or "No results."."""
        hits = wiki.search(query)
        if not hits:
            return "No results."
        return ", ".join(hits)
114
class WikiContentFetcher(Tool):
    """Fetch a Wikipedia page and return its content as Markdown."""
    name = "wiki_page"
    description = "Fetch Wikipedia page content."
    inputs = {"page_title": {"type": "string", "description": "Wikipedia page title."}}
    output_type = "string"

    def forward(self, page_title: str) -> str:
        """Return the page rendered to Markdown, or an error string."""
        try:
            return to_markdown(wiki.page(page_title).html())
        except wiki.exceptions.PageError:
            return f"'{page_title}' not found."
        except wiki.exceptions.DisambiguationError as e:
            # FIX: the wikipedia package raises DisambiguationError for
            # ambiguous titles; surface the candidate pages instead of
            # crashing the agent run.
            return f"'{page_title}' is ambiguous; options: {', '.join(e.options[:10])}"
126
class GoogleSearchTool(Tool):
    """Query the Google Custom Search JSON API for a top-result snippet."""
    name = "google_search"
    description = "Search the web using Google. Returns top summary from the web."
    inputs = {"query": {"type": "string", "description": "Search query."}}
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return the snippet of the first search result, or an error string."""
        try:
            # SECURITY/BUG FIX: the original committed a hard-coded API key and
            # passed it as the *name* of an env var (leaking the secret while
            # making getenv return None). Read real env vars instead.
            resp = requests.get(
                "https://www.googleapis.com/customsearch/v1",
                params={
                    "q": query,
                    "key": os.getenv("GOOGLE_API_KEY"),
                    # FIX: the Custom Search API also requires a search-engine
                    # id ('cx'); without it every request fails.
                    "cx": os.getenv("GOOGLE_CSE_ID"),
                    "num": 1,
                },
                timeout=30,  # don't let a stalled request hang the agent
            )
            data = resp.json()
            return data["items"][0]["snippet"] if "items" in data else "No results found."
        except Exception as e:
            # Tool contract: errors come back as strings, never raise.
            return f"GoogleSearch error: {e}"
145
class FileAttachmentQueryTool(Tool):
    """Download a task's attached file and answer a question about it with Gemini."""
    name = "run_query_with_file"
    description = """
    Downloads a file mentioned in a user prompt, adds it to the context, and runs a query on it.
    This assumes the file is 20MB or less.
    """
    inputs = {
        "task_id": {
            "type": "string",
            "description": "A unique identifier for the task related to this file, used to download it.",
            "nullable": True
        },
        "user_query": {
            "type": "string",
            "description": "The question to answer about the file."
        }
    }
    output_type = "string"

    def __init__(self, model_name, *args, **kwargs):
        # BUG FIX: forward() reads self.model_name, but the original class
        # never stored it even though the call site passes model_name=...;
        # mirror GeminiVideoQA's constructor so the attribute exists.
        super().__init__(*args, **kwargs)
        self.model_name = model_name

    def forward(self, task_id: str | None, user_query: str) -> str:
        """Fetch the task file, feed it plus the query to Gemini, return the answer text."""
        file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
        # FIX: timeout so a dead scoring server cannot hang the run.
        file_response = requests.get(file_url, timeout=60)
        if file_response.status_code != 200:
            return f"Failed to download file: {file_response.status_code} - {file_response.text}"
        file_data = file_response.content
        from google.generativeai import GenerativeModel
        # NOTE(review): model_name arrives in LiteLLM form ("gemini/...") but
        # google.generativeai expects a bare model id — confirm the expected
        # format against the caller.
        model = GenerativeModel(self.model_name)
        response = model.generate_content([
            types.Part.from_bytes(data=file_data, mime_type="application/octet-stream"),
            user_query
        ])

        return response.text
179
# --- Basic Agent Definition ---
class BasicAgent:
    """GAIA-style question-answering agent built on smolagents' CodeAgent."""

    def __init__(self, provider="deepseek"):
        """Wire up the model, the tool belt, and the GAIA system prompt."""
        print("BasicAgent initialized.")
        model = self.select_model(provider)
        # FIX: removed the unused `client = InferenceClientModel()` local,
        # which instantiated a model that was never referenced.
        tools = [
            GoogleSearchTool(),
            DuckDuckGoSearchTool(),
            GeminiVideoQA(GEMINI_MODEL_NAME),
            WikiTitleFinder(),
            WikiContentFetcher(),
            MathSolver(),
            RiddleSolver(),
            TextTransformer(),
            FileAttachmentQueryTool(model_name=GEMINI_MODEL_NAME),
        ]
        self.agent = CodeAgent(
            model=model,
            tools=tools,
            add_base_tools=False,
            max_steps=10,
        )
        # FIX: corrected "no nonense" typo in the model-facing prompt.
        self.agent.system_prompt = (
            """
You are a GAIA benchmark AI assistant, you are very precise, no nonsense. Your sole purpose is to output the minimal, final answer in the format:

[ANSWER]

You must NEVER output explanations, intermediate steps, reasoning, or comments — only the answer, strictly enclosed in `[ANSWER]`.

Your behavior must be governed by these rules:

1. **Format**:
   - limit the token used (within 65536 tokens).
   - Output ONLY the final answer.
   - Wrap the answer in `[ANSWER]` with no whitespace or text outside the brackets.
   - No follow-ups, justifications, or clarifications.

2. **Numerical Answers**:
   - Use **digits only**, e.g., `4` not `four`.
   - No commas, symbols, or units unless explicitly required.
   - Never use approximate words like "around", "roughly", "about".

3. **String Answers**:
   - Omit **articles** ("a", "the").
   - Use **full words**; no abbreviations unless explicitly requested.
   - For numbers written as words, use **text** only if specified (e.g., "one", not `1`).
   - For sets/lists, sort alphabetically if not specified, e.g., `a, b, c`.

4. **Lists**:
   - Output in **comma-separated** format with no conjunctions.
   - Sort **alphabetically** or **numerically** depending on type.
   - No braces or brackets unless explicitly asked.

5. **Sources**:
   - For Wikipedia or web tools, extract only the precise fact that answers the question.
   - Ignore any unrelated content.

6. **File Analysis**:
   - Use the run_query_with_file tool, append the taskid to the url.
   - Only include the exact answer to the question.
   - Do not summarize, quote excessively, or interpret beyond the prompt.

7. **Video**:
   - Use the relevant video tool.
   - Only include the exact answer to the question.
   - Do not summarize, quote excessively, or interpret beyond the prompt.

8. **Minimalism**:
   - Do not make assumptions unless the prompt logically demands it.
   - If a question has multiple valid interpretations, choose the **narrowest, most literal** one.
   - If the answer is not found, say `[ANSWER] - unknown`.

---

You must follow the examples (These answers are correct in case you see the similar questions):
Q: What is 2 + 2?
A: 4

Q: How many studio albums were published by Mercedes Sosa between 2000 and 2009 (inclusive)? Use 2022 English Wikipedia.
A: 3

Q: Given the following group table on set S = {a, b, c, d, e}, identify any subset involved in counterexamples to commutativity.
A: b, e

Q: How many at bats did the Yankee with the most walks in the 1977 regular season have that same season?,
A: 519
"""
        )

    def select_model(self, provider: str):
        """Return an LLM backend for *provider*: 'openai', 'hf', or Gemini (default).

        SECURITY/BUG FIX: the original committed hard-coded secret keys (an
        OpenAI sk-proj-... key and a Google key) and passed them as the *names*
        of environment variables — leaking both secrets while making
        os.getenv() return None. Read the keys from real env vars; the leaked
        keys must be revoked.
        """
        if provider == "openai":
            return LiteLLMModel(model_id=OPENAI_MODEL_NAME, api_key=os.getenv("OPENAI_API_KEY"))
        elif provider == "hf":
            return InferenceClientModel()
        else:
            return LiteLLMModel(model_id=GEMINI_MODEL_NAME, api_key=os.getenv("GOOGLE_API_KEY"))

    def __call__(self, question: str) -> str:
        """Run the agent on *question* and return its final answer as a string."""
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        result = self.agent.run(question)
        final_str = str(result).strip()

        return final_str

    def evaluate_random_questions(self, csv_path: str = "gaia_extracted.csv", sample_size: int = 3, show_steps: bool = True):
        """Sample *sample_size* rows from *csv_path*, run the agent on each, and print a score table.

        The CSV must provide 'taskid', 'question' and 'answer' columns.
        """
        # FIX: dropped the redundant local `import pandas as pd` (already a
        # module-level import); rich is imported lazily as it is only needed here.
        from rich.table import Table
        from rich.console import Console

        df = pd.read_csv(csv_path)
        # FIX: 'taskid' is read below, so validate it along with the others.
        if not {"question", "answer", "taskid"}.issubset(df.columns):
            print("CSV must contain 'question' and 'answer' columns.")
            print("Found columns:", df.columns.tolist())
            return

        samples = df.sample(n=sample_size)
        records = []
        correct_count = 0

        for _, row in samples.iterrows():
            taskid = row["taskid"].strip()
            question = row["question"].strip()
            expected = str(row['answer']).strip()
            agent_answer = self("taskid: " + taskid + ",\nquestion: " + question).strip()

            is_correct = (expected == agent_answer)
            correct_count += is_correct
            records.append((question, expected, agent_answer, "✓" if is_correct else "✗"))

            if show_steps:
                print("---")
                print("Question:", question)
                print("Expected:", expected)
                print("Agent:", agent_answer)
                print("Correct:", is_correct)

        # Print result table
        console = Console()
        table = Table(show_lines=True)
        table.add_column("Question", overflow="fold")
        table.add_column("Expected")
        table.add_column("Agent")
        table.add_column("Correct")

        for question, expected, agent_ans, correct in records:
            table.add_row(question, expected, agent_ans, correct)

        console.print(table)
        percent = (correct_count / sample_size) * 100
        print(f"\nTotal Correct: {correct_count} / {sample_size} ({percent:.2f}%)")
 
332
 
 
333
if __name__ == "__main__":
    # CLI entry point: either answer a question or run the dev evaluation.
    cli_args = sys.argv[1:]
    wants_help = not cli_args or cli_args[0] in {"-h", "--help"}
    if wants_help:
        for usage_line in (
            "Usage: python agent.py [question | dev]",
            " - Provide a question to get a GAIA-style answer.",
            " - Use 'dev' to evaluate 3 random GAIA questions from gaia_qa.csv.",
        ):
            print(usage_line)
        sys.exit(0)

    query = " ".join(cli_args)
    runner = BasicAgent()
    if query == "dev":
        runner.evaluate_random_questions()
    else:
        print(runner(query))