jebaponselvasingh committed on
Commit
93b72dc
·
1 Parent(s): 6190a63

Add application file

Browse files
Files changed (5) hide show
  1. .env +1 -0
  2. agent_enhanced.py +564 -0
  3. app.py +432 -0
  4. flagged/log.csv +2 -0
  5. requirements.txt +20 -0
.env ADDED
@@ -0,0 +1 @@
 
 
1
+ OPENAI_API_KEY="<REDACTED — a live secret key was committed here; revoke it immediately and supply the key via environment variables / Space secrets instead of a tracked .env file>"
agent_enhanced.py ADDED
@@ -0,0 +1,564 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Enhanced GAIA Agent with LangGraph
3
+ Separate module for cleaner architecture and easier customization
4
+ """
5
+
6
+ import os
7
+ import re
8
+ import json
9
+ import requests
10
+ import tempfile
11
+ from typing import TypedDict, Annotated, Sequence, Literal, Any
12
+ import operator
13
+
14
+ from langgraph.graph import StateGraph, END
15
+ from langgraph.prebuilt import ToolNode
16
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
17
+ from langchain_core.tools import tool
18
+ from langchain_openai import ChatOpenAI
19
+ from langchain_community.tools import DuckDuckGoSearchResults
20
+ from langchain_experimental.utilities import PythonREPL
21
+ import pandas as pd
22
+
23
+
24
# ============ STATE DEFINITION ============
class AgentState(TypedDict):
    """State maintained throughout the agent's execution."""
    # Conversation history; operator.add makes LangGraph append (not replace)
    # the messages returned by each node.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Identifier of the GAIA task being solved (may be empty).
    task_id: str
    # Local path of the task's downloaded attachment, if any.
    file_path: str | None
    # NOTE(review): never written by any node in this module — appears unused.
    file_content: str | None
    # Count of agent-node invocations; enforces the max_iterations cap.
    iteration_count: int
    # Cleaned answer produced by the extract_answer node.
    final_answer: str | None
33
+
34
+
35
# ============ TOOL DEFINITIONS ============
@tool
def web_search(query: str) -> str:
    """
    Search the web using DuckDuckGo for current information.
    Use this for questions about recent events, facts, statistics, or any information
    that might have changed or that you're uncertain about.

    Args:
        query: The search query string

    Returns:
        Search results with relevant snippets, or an error message on failure.
    """
    import logging

    # Suppress non-critical errors from DuckDuckGo's internal engines
    # (Some engines like grokipedia may fail due to DNS issues, but others work fine)
    ddgs_logger = logging.getLogger("ddgs.ddgs")
    primp_logger = logging.getLogger("primp")

    # Save the exact current levels (0 is logging.NOTSET, so no special-casing
    # is needed when saving).
    ddgs_original = ddgs_logger.level
    primp_original = primp_logger.level

    # Suppress INFO-level logs (which include non-critical engine errors)
    ddgs_logger.setLevel(logging.WARNING)
    primp_logger.setLevel(logging.WARNING)

    try:
        search = DuckDuckGoSearchResults(max_results=5, output_format="list")
        results = search.run(query)

        if isinstance(results, list):
            formatted = []
            for r in results:
                if isinstance(r, dict):
                    formatted.append(
                        f"Title: {r.get('title', 'N/A')}\nSnippet: {r.get('snippet', 'N/A')}\nLink: {r.get('link', 'N/A')}"
                    )
                else:
                    formatted.append(str(r))
            return "\n\n---\n\n".join(formatted)
        return str(results)
    except Exception as e:
        return f"Search failed: {str(e)}. Try a different query or approach."
    finally:
        # Bug fix: the original only restored levels that were not NOTSET, so
        # loggers that started at NOTSET were left permanently at WARNING.
        # Restoring unconditionally in a finally block is correct for every
        # starting level and removes the duplicated restore code that was
        # repeated on both the success and exception paths.
        ddgs_logger.setLevel(ddgs_original)
        primp_logger.setLevel(primp_original)
90
+
91
+
92
@tool
def python_executor(code: str) -> str:
    """
    Execute Python code for calculations, data analysis, or any computational task.
    You have access to standard libraries: math, statistics, datetime, json, re, collections.

    Args:
        code: Python code to execute. Print statements will show in output.

    Returns:
        The output/result of the code execution
    """
    # Prelude injected ahead of the user code so common modules are in scope.
    prelude = (
        "import math",
        "import statistics",
        "import datetime",
        "import json",
        "import re",
        "from collections import Counter, defaultdict",
    )
    try:
        sandbox = PythonREPL()
        script = "\n" + "\n".join(prelude) + "\n" + code
        output = sandbox.run(script)
        if output:
            return output.strip()
        return "Code executed successfully with no output. Add print() to see results."
    except Exception as e:
        return f"Execution error: {str(e)}. Please fix the code and try again."
119
+
120
+
121
@tool
def read_file(file_path: str) -> str:
    """
    Read and extract content from various file types.
    Supports: PDF, TXT, MD, CSV, JSON, XLSX, XLS, PY, and other text files.

    Args:
        file_path: Path to the file to read

    Returns:
        The content of the file as a string
    """
    try:
        if not os.path.exists(file_path):
            return f"Error: File not found at {file_path}"

        lowered = file_path.lower()

        if lowered.endswith('.pdf'):
            # Imported lazily so the dependency is only needed for PDFs.
            from langchain_community.document_loaders import PyPDFLoader
            pages = PyPDFLoader(file_path).load()
            body = "\n\n--- Page Break ---\n\n".join(p.page_content for p in pages)
            return f"PDF Content ({len(pages)} pages):\n{body}"

        if lowered.endswith(('.xlsx', '.xls')):
            # sheet_name=None loads every sheet as a dict of DataFrames.
            sheets = pd.read_excel(file_path, sheet_name=None)
            rendered = [
                f"=== Sheet: {name} ===\n{frame.to_string()}"
                for name, frame in sheets.items()
            ]
            return "\n\n".join(rendered)

        if lowered.endswith('.csv'):
            frame = pd.read_csv(file_path)
            return f"CSV Data ({len(frame)} rows):\n{frame.to_string()}"

        if lowered.endswith('.json'):
            with open(file_path, 'r', encoding='utf-8') as fh:
                parsed = json.load(fh)
            return f"JSON Content:\n{json.dumps(parsed, indent=2)}"

        # Anything else is treated as plain text; undecodable bytes are dropped.
        with open(file_path, 'r', encoding='utf-8', errors='ignore') as fh:
            return f"File Content:\n{fh.read()}"

    except Exception as e:
        return f"Error reading file: {str(e)}"
169
+
170
+
171
@tool
def calculator(expression: str) -> str:
    """
    Evaluate a mathematical expression safely.
    Supports: basic arithmetic, trigonometry, logarithms, exponents, etc.

    Args:
        expression: Mathematical expression (e.g., "sqrt(16) + log(100, 10)")

    Returns:
        The numerical result as a string
    """
    try:
        import math

        # Whitelisted callables/constants; eval runs with empty builtins so
        # only these names are reachable from the expression.
        allowed = {
            'abs': abs, 'round': round, 'min': min, 'max': max,
            'sum': sum, 'pow': pow, 'len': len,
            'sqrt': math.sqrt, 'log': math.log, 'log10': math.log10,
            'log2': math.log2, 'exp': math.exp,
            'sin': math.sin, 'cos': math.cos, 'tan': math.tan,
            'asin': math.asin, 'acos': math.acos, 'atan': math.atan,
            'sinh': math.sinh, 'cosh': math.cosh, 'tanh': math.tanh,
            'ceil': math.ceil, 'floor': math.floor,
            'pi': math.pi, 'e': math.e,
            'factorial': math.factorial, 'gcd': math.gcd,
            'degrees': math.degrees, 'radians': math.radians,
        }

        value = eval(expression, {"__builtins__": {}}, allowed)

        # Integral floats render without a decimal point; other floats use
        # %g-style formatting to drop trailing zeros.
        if isinstance(value, float):
            return str(int(value)) if value.is_integer() else f"{value:.10g}"
        return str(value)

    except Exception as e:
        return f"Calculation error: {str(e)}. Check your expression syntax."
212
+
213
+
214
@tool
def wikipedia_search(query: str) -> str:
    """
    Search Wikipedia for factual information about a specific topic.
    Best for: historical facts, biographies, scientific concepts, definitions.

    Args:
        query: The topic to search for on Wikipedia

    Returns:
        Summary and key information from relevant Wikipedia articles
    """
    try:
        import urllib.parse

        # Step 1: find candidate articles via the MediaWiki search API.
        search_url = f"https://en.wikipedia.org/w/api.php?action=query&list=search&srsearch={urllib.parse.quote(query)}&format=json&srlimit=3"
        data = requests.get(search_url, timeout=10).json()

        if 'query' not in data or 'search' not in data['query'] or not data['query']['search']:
            return f"No Wikipedia articles found for '{query}'"

        # Step 2: fetch the intro extract of the top-ranked article.
        top_title = data['query']['search'][0]['title']
        content_url = f"https://en.wikipedia.org/w/api.php?action=query&prop=extracts&exintro=true&explaintext=true&titles={urllib.parse.quote(top_title)}&format=json"
        content_data = requests.get(content_url, timeout=10).json()

        # Page id "-1" marks a missing page in MediaWiki responses.
        for page_id, page in content_data.get('query', {}).get('pages', {}).items():
            if page_id != '-1':
                title = page.get('title', '')
                extract = page.get('extract', 'No content available')
                return f"Wikipedia: {title}\n\n{extract[:2000]}"

        return "Could not retrieve article content."

    except Exception as e:
        return f"Wikipedia search failed: {str(e)}"
255
+
256
+
257
@tool
def analyze_image(image_path: str, question: str) -> str:
    """
    Analyze an image file and answer questions about it.
    Note: This is a placeholder - implement with vision model if needed.

    Args:
        image_path: Path to the image file
        question: What to analyze or find in the image

    Returns:
        Description or analysis of the image
    """
    # Stub: wire in a vision-capable model (e.g. GPT-4V) here if image tasks
    # need real support; this tool is not registered in TOOLS.
    placeholder = f"Image analysis not implemented. File: {image_path}, Question: {question}"
    return placeholder
272
+
273
+
274
# Collect all tools
# NOTE(review): analyze_image is omitted — presumably deliberate since it is
# only a placeholder, but confirm.
TOOLS = [web_search, python_executor, read_file, calculator, wikipedia_search]
276
+
277
+
278
# ============ SYSTEM PROMPT ============
# NOTE(review): the βœ…/❌ glyphs below look mojibake-encoded (likely ✅/❌ in
# the author's original); harmless to the LLM but worth re-saving as UTF-8.
SYSTEM_PROMPT = """You are an expert AI assistant designed to solve GAIA benchmark questions with maximum accuracy.

## Your Mission
Provide PRECISE, EXACT answers. The benchmark uses EXACT STRING MATCHING, so your final answer must match the ground truth character-for-character.

## Critical Answer Formatting Rules (MUST FOLLOW)

**DO NOT include "FINAL ANSWER:" or any prefix - just the answer itself.**

1. **Numbers**: Give just the number.
- βœ… CORRECT: "42"
- ❌ WRONG: "The answer is 42", "42 units", "Answer: 42"

2. **Names**: Exact spelling as found in sources. Check Wikipedia/official sources for correct spelling, capitalization, and punctuation.
- βœ… CORRECT: "John Smith"
- ❌ WRONG: "john smith", "John smith"

3. **Lists**: Comma-separated, NO spaces after commas.
- βœ… CORRECT: "apple,banana,cherry"
- ❌ WRONG: "apple, banana, cherry", "apple,banana, cherry"

4. **Dates**: Use the format specified in the question, or YYYY-MM-DD if not specified.
- βœ… CORRECT: "2024-01-15" or "January 15, 2024" (if question asks for that format)
- ❌ WRONG: "1/15/2024" (unless question asks for it)

5. **Yes/No**: Just "Yes" or "No" (capitalized, no period).
- βœ… CORRECT: "Yes"
- ❌ WRONG: "yes", "Yes.", "The answer is Yes"

6. **Counts**: Just the number.
- βœ… CORRECT: "5"
- ❌ WRONG: "5 items", "five", "There are 5"

7. **No explanations**: Your final response must contain ONLY the answer, nothing else.
- βœ… CORRECT: "Paris"
- ❌ WRONG: "The answer is Paris because..."

## Problem-Solving Strategy
1. **Understand**: Read the question carefully. What exactly is being asked? Note any specific format requirements.
2. **Check for File**: If a file is mentioned or available, ALWAYS read it FIRST - the answer is likely there.
3. **Plan**: What information do I need? Which tools should I use?
4. **Execute**: Use tools systematically. Verify information from multiple sources when possible.
5. **Verify**: Double-check your answer format. Does it match the question's requirements? Is spelling correct?
6. **Respond**: Give ONLY the final answer, no prefixes, no explanations.

## Available Tools
- `read_file`: Read PDFs, spreadsheets, text files - USE THIS FIRST if a file is available
- `web_search`: Current information, recent events, facts
- `wikipedia_search`: Historical facts, biographies, definitions
- `python_executor`: Calculations, data processing, analysis
- `calculator`: Quick mathematical calculations

## Tool Usage Priority
1. **If file available**: Read file FIRST before doing anything else
2. **For calculations**: Use python_executor for complex math, calculator for simple expressions
3. **For facts**: Use wikipedia_search for established facts, web_search for current/recent information
4. **Cross-reference**: When possible, verify important facts from multiple sources

## Critical Reminders
- NEVER include "FINAL ANSWER:" or any prefix in your response
- NEVER add explanations or context to your final answer
- ALWAYS verify spelling, capitalization, and formatting
- ALWAYS read files first if they are available
- If uncertain about format, look for clues in the question itself
- Never guess - use tools to find accurate information

Remember: Your final message must contain ONLY the answer, nothing else. The scoring system uses exact string matching."""
346
+
347
+
348
# ============ LANGGRAPH AGENT ============
class GAIAAgent:
    """LangGraph-based agent for GAIA benchmark.

    Workflow: agent node (LLM with bound tools) -> tools node -> agent ...
    until the model replies without tool calls or the iteration cap is hit,
    then an extract node normalizes the reply for GAIA's exact-string scoring.

    NOTE: annotations naming project/third-party types are written as strings
    so importing this class does not eagerly resolve them; behavior unchanged.
    """

    def __init__(
        self,
        model_name: str = "gpt-4o",
        api_key: str = None,
        temperature: float = 0,
        max_iterations: int = 15
    ):
        """
        Initialize the GAIA agent.

        Args:
            model_name: OpenAI model to use
            api_key: OpenAI API key (or set OPENAI_API_KEY env var)
            temperature: Model temperature (0 for deterministic)
            max_iterations: Maximum tool-use iterations
        """
        self.model_name = model_name
        self.max_iterations = max_iterations

        self.llm = ChatOpenAI(
            model=model_name,
            temperature=temperature,
            api_key=api_key or os.environ.get("OPENAI_API_KEY")
        )
        self.llm_with_tools = self.llm.bind_tools(TOOLS)
        self.graph = self._build_graph()

    def _build_graph(self) -> "StateGraph":
        """Construct the LangGraph workflow (agent <-> tools loop + extract)."""
        workflow = StateGraph(AgentState)

        workflow.add_node("agent", self._agent_node)
        workflow.add_node("tools", ToolNode(TOOLS))
        workflow.add_node("extract_answer", self._extract_answer_node)

        workflow.set_entry_point("agent")

        # After the agent speaks, either run the requested tools or finish.
        workflow.add_conditional_edges(
            "agent",
            self._route_agent_output,
            {
                "tools": "tools",
                "end": "extract_answer"
            }
        )
        workflow.add_edge("tools", "agent")
        workflow.add_edge("extract_answer", END)

        return workflow.compile()

    def _agent_node(self, state: "AgentState") -> dict:
        """Invoke the LLM on the running conversation and count the iteration."""
        messages = state["messages"]
        iteration = state.get("iteration_count", 0)

        # Nudge the model toward a bare final answer as the cap approaches:
        # a hard warning in the last 3 iterations, a softer reminder earlier.
        if iteration >= self.max_iterations - 3:
            warning_msg = "WARNING: Approaching iteration limit. Please provide your final answer now. Remember: just the answer, no prefix."
            messages = list(messages) + [SystemMessage(content=warning_msg)]
        elif iteration >= self.max_iterations - 5:
            reminder_msg = "Reminder: When you're ready to answer, provide ONLY the final answer with no prefix like 'FINAL ANSWER:' or 'The answer is:'"
            messages = list(messages) + [SystemMessage(content=reminder_msg)]

        try:
            response = self.llm_with_tools.invoke(messages)
        except Exception as e:
            # Graceful degradation: feed the failure back into the
            # conversation instead of crashing the graph.
            error_msg = AIMessage(content=f"Error during reasoning: {str(e)}. Please try a different approach or provide your best answer.")
            return {
                "messages": [error_msg],
                "iteration_count": iteration + 1
            }

        return {
            "messages": [response],
            "iteration_count": iteration + 1
        }

    def _route_agent_output(self, state: "AgentState") -> Literal["tools", "end"]:
        """Route to the tools node when tool calls are pending, else finish."""
        last_message = state["messages"][-1]
        iteration = state.get("iteration_count", 0)

        # Force end if max iterations reached
        if iteration >= self.max_iterations:
            return "end"

        # Check if agent wants to use tools
        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
            return "tools"

        return "end"

    def _extract_answer_node(self, state: "AgentState") -> dict:
        """Normalize the last message's content into `final_answer`."""
        last_message = state["messages"][-1]
        content = last_message.content if hasattr(last_message, "content") else str(last_message)

        return {"final_answer": self._clean_answer(content)}

    def _clean_answer(self, raw_answer: str) -> str:
        """Strip prefixes/quotes/markdown so the answer survives exact matching.

        Bug fixes vs. the original implementation:
        - bare words ("answer", "result", "solution") are stripped only when
          followed by a separator, so answers that merely begin with one of
          them (e.g. "Answers to Job") are left intact;
        - "it is" / "it's" / "that is" / "that's" are no longer stripped —
          they can legitimately start an answer ("It's a Wonderful Life");
        - prefixes are matched longest-first and re-applied until stable,
          and the duplicated list entries are gone.
        """
        answer = raw_answer.strip()

        # Longest-first so e.g. "the final answer is" beats "final answer".
        prefixes = [
            "the final answer is",
            "the answer is",
            "the solution is",
            "final answer",
            "answer is",
            "solution is",
            "result is",
            "solution",
            "answer",
            "result",
        ]

        stripped = True
        while stripped and answer:
            stripped = False
            lowered = answer.lower()
            for prefix in prefixes:
                if not lowered.startswith(prefix):
                    continue
                remainder = answer[len(prefix):]
                # Whole-word check: require a separator (or end of string)
                # immediately after the prefix.
                if remainder and remainder[0] not in " \t:-":
                    continue
                answer = remainder.lstrip(" \t:-").strip()
                stripped = True
                break

        # Unwrap quotes that enclose the whole answer.
        if len(answer) >= 2 and answer[0] == answer[-1] and answer[0] in ('"', "'"):
            answer = answer[1:-1].strip()

        # Trailing punctuation is noise on single-token answers ("Yes." -> "Yes").
        if answer and ' ' not in answer:
            answer = answer.rstrip('.,;:')

        # Collapse internal whitespace.
        answer = ' '.join(answer.split())

        # Unwrap markdown emphasis markers.
        if answer.startswith('**') and answer.endswith('**'):
            answer = answer[2:-2]
        if answer.startswith('*') and answer.endswith('*'):
            answer = answer[1:-1]

        return answer.strip()

    def run(self, question: str, task_id: str = "", file_path: str = None) -> str:
        """
        Run the agent on a question.

        Args:
            question: The GAIA question to answer
            task_id: Optional task identifier
            file_path: Optional path to associated file

        Returns:
            The agent's final answer
        """
        # Prepare the user message, front-loading the attached file if any.
        user_content = question
        if file_path and os.path.exists(file_path):
            # Strongly emphasize reading the file first
            user_content = f"[IMPORTANT: A file is available at {file_path}]\n\nYou MUST read this file FIRST using the read_file tool before attempting to answer. The answer is very likely contained in this file.\n\nQuestion: {question}"

        initial_state: "AgentState" = {
            "messages": [
                SystemMessage(content=SYSTEM_PROMPT),
                HumanMessage(content=user_content)
            ],
            "task_id": task_id,
            "file_path": file_path,
            "file_content": None,
            "iteration_count": 0,
            "final_answer": None
        }

        try:
            # Each iteration is an agent step plus a possible tools step,
            # hence the 2x (+ slack) recursion limit.
            final_state = self.graph.invoke(
                initial_state,
                {"recursion_limit": self.max_iterations * 2 + 5}
            )
            answer = final_state.get("final_answer", "Unable to determine answer")

            # Fall back to the raw last message when extraction produced
            # nothing useful.
            if not answer or answer.startswith("Agent error:") or answer.startswith("Unable to determine"):
                if final_state.get("messages"):
                    last_msg = final_state["messages"][-1]
                    if hasattr(last_msg, "content") and last_msg.content:
                        answer = self._clean_answer(last_msg.content)

            return answer if answer else "Unable to determine answer"
        except Exception as e:
            # Log for debugging but return a clean error string to the caller.
            import logging
            logging.error(f"Agent execution error: {str(e)}")
            return f"Agent error: {str(e)}"
554
+
555
+
556
# ============ UTILITY FUNCTIONS ============
def create_agent(api_key: str = None, model: str = "gpt-4o") -> GAIAAgent:
    """Factory function to create a configured agent."""
    # Deterministic settings (temperature 0) with the standard iteration cap.
    settings = {
        "model_name": model,
        "api_key": api_key,
        "temperature": 0,
        "max_iterations": 15,
    }
    return GAIAAgent(**settings)
app.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import requests
4
+ import pandas as pd
5
+ import tempfile
6
+ import json
7
+ import logging
8
+ from typing import Optional
9
+
10
+ # Import the optimized agent from the separate module
11
+ from agent_enhanced import GAIAAgent
12
+
13
+ # ============ CONFIGURATION ============
14
+ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
15
+
16
+ # Set up logging
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
# ============ API INTERACTION ============
def fetch_questions(api_url: str = DEFAULT_API_URL, max_retries: int = 3) -> list:
    """Fetch all questions from the GAIA API with retry logic.

    Retries up to ``max_retries`` times on request failures; the last
    failure is re-raised so the caller can surface it.
    """
    endpoint = f"{api_url}/questions"
    last_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            resp = requests.get(endpoint, timeout=30)
            resp.raise_for_status()
            return resp.json()
        except requests.exceptions.RequestException as e:
            logger.warning(f"Attempt {attempt + 1} failed: {e}")
            if attempt == last_attempt:
                raise
    return []  # unreachable in practice; keeps the annotated return type honest
34
+
35
def fetch_random_question(api_url: str = DEFAULT_API_URL, max_retries: int = 3) -> dict:
    """Fetch a random question from the GAIA API with retry logic.

    Retries up to ``max_retries`` times; the final failure is re-raised.
    """
    endpoint = f"{api_url}/random-question"
    last_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            resp = requests.get(endpoint, timeout=30)
            resp.raise_for_status()
            return resp.json()
        except requests.exceptions.RequestException as e:
            logger.warning(f"Attempt {attempt + 1} failed: {e}")
            if attempt == last_attempt:
                raise
    return {}  # unreachable in practice; satisfies the annotated return type
47
+
48
def fetch_file(task_id: str, api_url: str = DEFAULT_API_URL, max_retries: int = 3) -> Optional[str]:
    """Fetch the file attached to a task, saving it into a fresh temp dir.

    Returns:
        The local file path, None when the task has no file (HTTP 404), or
        None after all retries fail.

    Bug fixes vs. the original:
    - non-200/non-404 statuses (e.g. 500) used to fall through silently and
      loop; they now raise via raise_for_status so the retry/log path runs;
    - the Content-Disposition filename is sanitized (parameters stripped,
      basename only) so a hostile or odd header cannot escape the temp dir.
    """
    for attempt in range(max_retries):
        try:
            response = requests.get(f"{api_url}/files/{task_id}", timeout=30)
            if response.status_code == 404:
                logger.info(f"No file found for task {task_id}")
                return None
            # Surface any other error status instead of silently looping.
            response.raise_for_status()

            # Derive a safe filename from the Content-Disposition header.
            content_disposition = response.headers.get('content-disposition', '')
            filename = f"task_{task_id}_file"
            if 'filename=' in content_disposition:
                raw = content_disposition.split('filename=')[1]
                raw = raw.split(';')[0].strip().strip('"')
                candidate = os.path.basename(raw)  # drop any path components
                if candidate:
                    filename = candidate

            temp_dir = tempfile.mkdtemp()
            file_path = os.path.join(temp_dir, filename)

            with open(file_path, 'wb') as f:
                f.write(response.content)

            logger.info(f"Downloaded file: {file_path}")
            return file_path
        except requests.exceptions.RequestException as e:
            logger.warning(f"File fetch attempt {attempt + 1} failed: {e}")
            if attempt == max_retries - 1:
                logger.error(f"Failed to fetch file for task {task_id}: {e}")
    return None
76
+
77
def submit_answers(username: str, agent_code: str, answers: list, api_url: str = DEFAULT_API_URL, max_retries: int = 3) -> dict:
    """Submit answers to the GAIA API with retry logic.

    Retries the POST up to ``max_retries`` times; the last failure is
    re-raised so the caller can report it.
    """
    payload = {
        "username": username,
        "agent_code": agent_code,
        "answers": answers,
    }

    final_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            resp = requests.post(f"{api_url}/submit", json=payload, timeout=60)
            resp.raise_for_status()
            return resp.json()
        except requests.exceptions.RequestException as e:
            logger.warning(f"Submission attempt {attempt + 1} failed: {e}")
            if attempt == final_attempt:
                raise
    return {}  # unreachable in practice; satisfies the annotated return type
95
+
96
+
97
# ============ ANSWER VALIDATION ============
def validate_answer_format(answer: str) -> tuple[bool, str]:
    """Validate answer format and return (is_valid, warning_message)."""
    if not answer or not answer.strip():
        return False, "Warning: Answer is empty"

    # Flag prefixes the scorer would penalize; match case-insensitively but
    # report the canonical spelling from this list.
    lowered = answer.lower()
    for prefix in ("FINAL ANSWER:", "The answer is:", "Answer:", "final answer:"):
        if lowered.startswith(prefix.lower()):
            return False, f"Warning: Answer contains prefix '{prefix}' which will be removed. Consider removing it."

    # Heuristic: multiple sentences or causal connectives suggest the answer
    # carries an explanation rather than a bare value.
    looks_explained = answer.count('.') > 1 or 'because' in answer or 'since' in answer
    if looks_explained:
        return False, "Warning: Answer may contain explanations. Only the answer should be submitted."

    return True, ""
115
+
116
# ============ GRADIO INTERFACE ============
# NOTE(review): gr.Progress() as a default is the Gradio-documented idiom for
# progress injection, but it IS evaluated once at import time — confirm that
# is intended if this function is ever called outside a Gradio event.
def run_agent_on_questions(openai_api_key: str, progress=gr.Progress()):
    """Run the agent on all GAIA questions.

    Returns a (DataFrame, answers-list) pair on success, or an
    (error-string, None) pair on failure — callers must handle both shapes.
    """
    if not openai_api_key:
        return "Please provide your OpenAI API key.", None

    try:
        # Initialize agent
        progress(0, desc="Initializing agent...")
        agent = GAIAAgent(api_key=openai_api_key)

        # Fetch questions
        progress(0.05, desc="Fetching questions from API...")
        questions = fetch_questions()

        if not questions:
            return "Error: Failed to fetch questions from API. Please try again.", None

        total_questions = len(questions)
        results = []
        answers_for_submission = []

        for i, q in enumerate(questions):
            # NOTE(review): (i + 1)/total is reported BEFORE the work for item
            # i, then (i + 0.5) and (i + 0.7) below move the bar backwards —
            # the fractions look swapped; confirm intended ordering.
            progress((i + 1) / total_questions, desc=f"Processing question {i+1}/{total_questions}...")

            task_id = q.get("task_id", "")
            question_text = q.get("question", "")

            # Check if there's an associated file
            file_path = None
            if q.get("file_name"):
                progress((i + 0.5) / total_questions, desc=f"Downloading file for question {i+1}...")
                file_path = fetch_file(task_id)

            # Run agent
            try:
                progress((i + 0.7) / total_questions, desc=f"Agent reasoning for question {i+1}...")
                answer = agent.run(question_text, task_id, file_path)

                # Validate answer format (warn only; the answer is still kept).
                is_valid, warning = validate_answer_format(answer)
                if not is_valid:
                    logger.warning(f"Question {i+1} ({task_id}): {warning}")

            except Exception as e:
                # A per-question failure becomes an "Error: ..." answer rather
                # than aborting the whole run.
                logger.error(f"Error processing question {i+1} ({task_id}): {e}")
                answer = f"Error: {str(e)}"

            # NOTE(review): the βœ“/βœ— status glyphs below appear mojibake-encoded
            # (likely ✓/✗ originally); they are display-only.
            results.append({
                "Task ID": task_id,
                "Question": question_text[:100] + "..." if len(question_text) > 100 else question_text,
                "Answer": answer,
                "Status": "βœ“" if answer and not answer.startswith("Error:") else "βœ—"
            })

            answers_for_submission.append({
                "task_id": task_id,
                "submitted_answer": answer
            })

            # Cleanup temp file (best-effort: failures are logged, not raised).
            if file_path and os.path.exists(file_path):
                try:
                    os.remove(file_path)
                    # Also try to remove temp directory if empty
                    temp_dir = os.path.dirname(file_path)
                    if os.path.exists(temp_dir):
                        try:
                            os.rmdir(temp_dir)
                        except:
                            # Deliberate best-effort: a non-empty dir is fine.
                            pass
                except Exception as e:
                    logger.warning(f"Failed to cleanup file {file_path}: {e}")

        df = pd.DataFrame(results)
        progress(1.0, desc="Complete!")
        return df, answers_for_submission

    except Exception as e:
        logger.error(f"Error in run_agent_on_questions: {e}")
        return f"Error: {str(e)}", None
197
+
198
+
199
def submit_to_leaderboard(username: str, space_url: str, answers_json: str):
    """Submit answers to the leaderboard.

    Accepts the answers either as a JSON string or as an already-parsed list
    of {"task_id", "submitted_answer"} dicts; returns a Markdown status
    message for display in the Gradio UI.
    """
    if not username or not space_url or not answers_json:
        return "Please fill in all fields and run the agent first."

    try:
        answers = json.loads(answers_json) if isinstance(answers_json, str) else answers_json

        if not isinstance(answers, list) or len(answers) == 0:
            return "Error: Answers must be a non-empty list. Please run the agent first."

        # Validate answer format before submission (collect warnings; only a
        # missing required key aborts the submission).
        warnings = []
        for ans in answers:
            if "task_id" not in ans or "submitted_answer" not in ans:
                return "Error: Invalid answer format. Each answer must have 'task_id' and 'submitted_answer'."
            is_valid, warning = validate_answer_format(ans.get("submitted_answer", ""))
            if not is_valid:
                warnings.append(f"Task {ans.get('task_id')}: {warning}")

        # Ensure space URL ends with /tree/main
        if not space_url.endswith("/tree/main"):
            space_url = space_url.rstrip("/") + "/tree/main"

        # Submit to API
        result = submit_answers(username, space_url, answers)

        # NOTE(review): the {score:.1%} format and the >= 0.3 threshold below
        # both assume the API returns `score` as a 0–1 fraction; if the API
        # returns a percentage (e.g. 30.0), both are wrong — confirm.
        score = result.get("score", 0)
        correct = result.get("correct_count", 0)
        total = result.get("total_attempted", 0)

        # Show at most five per-task warnings to keep the message compact.
        warning_text = ""
        if warnings:
            warning_text = f"\n\n⚠️ **Warnings:**\n" + "\n".join(f"- {w}" for w in warnings[:5])
            if len(warnings) > 5:
                warning_text += f"\n- ... and {len(warnings) - 5} more warnings"

        # NOTE(review): the emoji below appear mojibake-encoded (πŸŽ‰ etc.,
        # likely 🎉/🏆/📈 originally); display-only.
        return f"""
## Submission Successful! πŸŽ‰

**Score:** {score:.1%}
**Correct:** {correct}/{total}

{'πŸ† Congratulations! You passed the 30% threshold!' if score >= 0.3 else 'πŸ“ˆ Keep improving! You need 30% to earn your certificate.'}
{warning_text}

Check the [leaderboard](https://huggingface.co/spaces/agents-course/Students_leaderboard) to see your ranking!
"""
    except json.JSONDecodeError as e:
        return f"Error: Invalid JSON format. Please run the agent first.\nDetails: {str(e)}"
    except Exception as e:
        logger.error(f"Submission error: {e}")
        return f"Submission error: {str(e)}"
252
+
253
+
254
+ def test_single_question(openai_api_key: str):
255
+ """Test the agent on a single random question."""
256
+ if not openai_api_key:
257
+ return "Please provide your OpenAI API key.", "", "", ""
258
+
259
+ try:
260
+ agent = GAIAAgent(api_key=openai_api_key)
261
+ question_data = fetch_random_question()
262
+
263
+ if not question_data:
264
+ return "Error: Failed to fetch question from API.", "", "", ""
265
+
266
+ task_id = question_data.get("task_id", "")
267
+ question_text = question_data.get("question", "")
268
+
269
+ file_path = None
270
+ if question_data.get("file_name"):
271
+ file_path = fetch_file(task_id)
272
+
273
+ answer = agent.run(question_text, task_id, file_path)
274
+
275
+ # Validate answer format
276
+ is_valid, warning = validate_answer_format(answer)
277
+ validation_status = "βœ“ Valid format" if is_valid else f"⚠️ {warning}"
278
+
279
+ # Cleanup temp file
280
+ if file_path and os.path.exists(file_path):
281
+ try:
282
+ os.remove(file_path)
283
+ temp_dir = os.path.dirname(file_path)
284
+ if os.path.exists(temp_dir):
285
+ try:
286
+ os.rmdir(temp_dir)
287
+ except:
288
+ pass
289
+ except Exception as e:
290
+ logger.warning(f"Failed to cleanup file: {e}")
291
+
292
+ return question_text, answer, task_id, validation_status
293
+
294
+ except Exception as e:
295
+ logger.error(f"Error in test_single_question: {e}")
296
+ return f"Error: {str(e)}", "", "", ""
297
+
298
+
299
# ============ BUILD GRADIO APP ============
# Top-level UI definition: one shared API-key field plus three tabs
# (single-question smoke test, full benchmark run, leaderboard submission).
# `demo` is launched from the __main__ guard at the bottom of the file.
with gr.Blocks(title="GAIA Agent - LangGraph", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # πŸ€– GAIA Benchmark Agent (LangGraph)

    This agent uses **LangGraph** to solve GAIA benchmark questions. It has access to:
    - πŸ” Web Search (DuckDuckGo)
    - πŸ“š Wikipedia Search
    - 🐍 Python Code Execution
    - πŸ“„ File Reading (PDF, Text, Excel)
    - πŸ”’ Calculator

    ## Instructions
    1. Enter your OpenAI API key
    2. Test with a single question or run on all questions
    3. Submit your answers to the leaderboard
    """)

    # Single shared key input used by both the test and benchmark tabs.
    with gr.Row():
        openai_key = gr.Textbox(
            label="OpenAI API Key",
            type="password",
            placeholder="sk-...",
            info="Required for GPT-4o"
        )

    with gr.Tabs():
        # Tab 1: run the agent on one random question as a quick sanity check.
        with gr.TabItem("πŸ§ͺ Test Single Question"):
            test_btn = gr.Button("Fetch & Solve Random Question", variant="primary")
            test_question = gr.Textbox(label="Question", lines=5, interactive=False)
            test_answer = gr.Textbox(label="Agent's Answer", lines=3, interactive=False)
            test_task_id = gr.Textbox(label="Task ID", interactive=False)
            test_validation = gr.Textbox(label="Answer Validation", interactive=False)

            test_btn.click(
                test_single_question,
                inputs=[openai_key],
                outputs=[test_question, test_answer, test_task_id, test_validation]
            )

        # Tab 2: run the agent over every question; answers land in
        # `answers_state` so the submission tab can pick them up.
        with gr.TabItem("πŸš€ Run Full Benchmark"):
            run_btn = gr.Button("Run Agent on All Questions", variant="primary")
            results_table = gr.Dataframe(label="Results")
            answers_state = gr.State()

            run_btn.click(
                run_agent_on_questions,
                inputs=[openai_key],
                outputs=[results_table, answers_state]
            )

        # Tab 3: submit the collected answers to the course leaderboard.
        with gr.TabItem("πŸ“€ Submit to Leaderboard"):
            gr.Markdown("""
            ### Submit Your Results

            After running the full benchmark, fill in your details and submit to the leaderboard.

            **Requirements:**
            - Your HuggingFace username
            - Your Space URL (must end with `/tree/main`)
            - Answers will be auto-filled after running the benchmark
            """)

            with gr.Row():
                username_input = gr.Textbox(
                    label="HuggingFace Username",
                    placeholder="your-username",
                    info="Your HuggingFace account username"
                )
                space_url_input = gr.Textbox(
                    label="Your Space URL",
                    placeholder="https://huggingface.co/spaces/your-username/your-space",
                    info="Full URL to your Space (will auto-append /tree/main if needed)"
                )

            answers_input = gr.Textbox(
                label="Answers JSON (auto-filled after running benchmark)",
                lines=10,
                placeholder="Run the full benchmark first...",
                info="This will be automatically populated after running the benchmark"
            )

            submit_btn = gr.Button("Submit to Leaderboard", variant="primary")
            submit_result = gr.Markdown()

            # Auto-fill answers when benchmark completes
            def format_answers(answers):
                """Render the benchmark's answer list as pretty-printed JSON."""
                if answers:
                    return json.dumps(answers, indent=2)
                return ""

            # Whenever the benchmark tab updates `answers_state`, mirror it
            # into the (editable) JSON textbox on this tab.
            answers_state.change(format_answers, inputs=[answers_state], outputs=[answers_input])

            submit_btn.click(
                submit_to_leaderboard,
                inputs=[username_input, space_url_input, answers_input],
                outputs=[submit_result]
            )

    gr.Markdown("""
    ---
    ### πŸ“‹ Tips for Better Scores

    **Answer Formatting:**
    - Answers are matched **exactly** (character-for-character), so precision is critical
    - Do NOT include prefixes like "FINAL ANSWER:" or "The answer is:"
    - For lists: use comma-separated format with NO spaces (e.g., "item1,item2,item3")
    - For numbers: just the number, no units unless specified
    - Check the validation status in the test tab

    **Agent Capabilities:**
    - Uses GPT-4o for optimal reasoning
    - Automatically reads files (PDFs, Excel, text) when available
    - Web search for current information
    - Wikipedia for factual lookups
    - Python execution for calculations

    **Best Practices:**
    1. Test with a single question first to verify the agent works
    2. Run the full benchmark (takes ~10-15 minutes)
    3. Review answers before submission
    4. Ensure your Space is public for verification

    ### πŸ”— Links
    - [GAIA Benchmark](https://huggingface.co/spaces/gaia-benchmark/leaderboard)
    - [Student Leaderboard](https://huggingface.co/spaces/agents-course/Students_leaderboard)
    - [Course Unit 4](https://huggingface.co/learn/agents-course/en/unit4/hands-on)
    - [API Documentation](https://agents-course-unit4-scoring.hf.space/docs)
    """)
if __name__ == "__main__":
    # For HuggingFace Spaces, use share=False
    # For local development, you can use share=True to get a public link
    # Bind to all interfaces on port 7860 so the Space's container proxy
    # can reach the app.
    demo.launch(server_name="0.0.0.0", server_port=7860)
flagged/log.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ name,output,flag,username,timestamp
2
+ asdf,Hello asdf!!,,,2026-01-16 10:40:06.831644
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core dependencies
2
+ gradio>=4.0.0,<5.0.0
3
+ requests>=2.31.0,<3.0.0
4
+ pandas>=2.0.0,<3.0.0
5
+
6
+ # LangChain & LangGraph
7
+ langgraph>=0.2.0,<1.0.0
8
+ langchain>=0.2.0,<1.0.0
9
+ langchain-core>=0.2.0,<1.0.0
10
+ langchain-openai>=0.1.0,<1.0.0
11
+ langchain-community>=0.2.0,<1.0.0
12
+ langchain-experimental>=0.0.60,<1.0.0
13
+
14
+ # Tools dependencies
15
+ duckduckgo-search>=6.0.0,<7.0.0
16
+ pypdf>=4.0.0,<5.0.0
17
+ openpyxl>=3.1.0,<4.0.0
18
+
19
+ # Utilities
20
+ python-dotenv>=1.0.0,<2.0.0