hemantvirmani committed
Commit
407e466
·
1 Parent(s): 43a2c8c

Refactor agent architecture with wrapper pattern and modular structure

This commit restructures the agent codebase to support multiple agent
implementations through a flexible wrapper pattern, improving code
organization and maintainability.

MAJOR CHANGES:

1. Agent Architecture Refactoring:
- Created MyGAIAAgents wrapper class for managing multiple agent types
- Renamed MyLangGraphAgent → LangGraphAgent for consistency
- Moved LangGraphAgent to dedicated langgraphagent.py module
- Renamed internal helper methods to private (prefixed with _)
- Added ACTIVE_AGENT config variable to switch between agent types
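The dispatch described above can be sketched as follows. This is a condensed, self-contained illustration of the wrapper pattern this commit introduces — the real code reads `ACTIVE_AGENT` from config.py and imports `LangGraphAgent` from langgraphagent.py; here both are replaced with inline stand-ins so the selection logic can run on its own:

```python
# Stand-in for config.ACTIVE_AGENT (in the real code this lives in config.py)
ACTIVE_AGENT = "LangGraph"


class LangGraphAgent:
    """Stub for the real LangGraphAgent implemented in langgraphagent.py."""

    def __call__(self, question: str, file_name: str = None) -> str:
        return f"answered: {question}"


class MyGAIAAgents:
    """Wrapper that picks the agent implementation based on configuration."""

    def __init__(self, active_agent: str = ACTIVE_AGENT):
        if active_agent == "LangGraph":
            self.agent = LangGraphAgent()
        else:
            # Unknown types fall back to LangGraph, matching this commit's behavior
            print(f"[WARNING] Unknown agent type '{active_agent}', defaulting to LangGraph")
            self.agent = LangGraphAgent()

    def __call__(self, question: str, file_name: str = None) -> str:
        # Delegate to whichever implementation was selected
        return self.agent(question, file_name)


agent = MyGAIAAgents()
print(agent("What is the capital of France?"))  # delegates to the stub LangGraphAgent
```

Adding a new agent type then means adding one branch to `__init__` (plus the implementation module), without touching any caller.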

2. New Files:
- langgraphagent.py: Standalone LangGraphAgent implementation with
private methods (_create_llm_client, _init_questions, _assistant,
_should_continue, _build_graph)

3. Code Organization Improvements:
- agents.py: Now contains only the MyGAIAAgents wrapper (35 lines vs 337)
- Better separation of concerns between wrapper and implementation
- Cleaner import structure across the codebase

4. API Refactoring (app.py):
- Moved ResultFormatter.format_for_api() call into submit_and_score()
- submit_and_score() now accepts raw results instead of formatted payload
- run_and_submit_all() simplified - passes raw results directly
- Renamed results_for_display → logs_for_display in run_test_code()
- Better responsibility distribution between functions
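The responsibility shift can be sketched like this. It is a minimal stand-alone illustration, not the real app.py: `ResultFormatter` is stubbed and the HTTP submission is elided, but the shape matches the refactoring — callers now hand `submit_and_score()` raw result tuples and formatting happens inside:

```python
class ResultFormatter:
    """Stand-in for the project's ResultFormatter."""

    @staticmethod
    def format_for_api(results):
        # results: list of (task_id, question_text, answer) tuples
        return [{"task_id": t, "submitted_answer": a} for (t, _q, a) in results]


def submit_and_score(username, results):
    # Formatting moved inside this function, per this commit
    answers_payload = ResultFormatter.format_for_api(results)
    if not answers_payload:
        return "No answers to submit."
    # ... the real code POSTs answers_payload to the scoring server here ...
    return f"Submitted {len(answers_payload)} answer(s) for {username}"


# The caller now passes raw results directly; no pre-formatting step
raw_results = [("task-1", "What is 2+2?", "4")]
print(submit_and_score("demo-user", raw_results))
```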

5. Configuration (config.py):
- Added ACTIVE_AGENT = "LangGraph" configuration variable
- Supports future agent types: ReActLangGraph, LLamaIndex, SMOL

6. Updated Imports:
- agent_runner.py: MyLangGraphAgent → MyGAIAAgents
- app.py: Updated docstrings and removed unused import

7. Documentation (README.md):
- Updated code examples to use MyGAIAAgents wrapper
- Replaced "Change LLM Provider" section with "Change Agent Type"
- Documents ACTIVE_AGENT configuration usage

BENEFITS:
- Easier to add new agent implementations (just add to wrapper)
- Better code modularity and single responsibility principle
- Cleaner API boundaries between functions
- Improved testability with separated concerns
- Future-ready for multiple agent architectures

FILES CHANGED:
- Modified: README.md, agent_runner.py, agents.py, app.py, config.py
- Added: langgraphagent.py

Files changed (6)
  1. README.md +9 -11
  2. agent_runner.py +2 -2
  3. agents.py +20 -290
  4. app.py +23 -21
  5. config.py +2 -0
  6. langgraphagent.py +305 -0
README.md CHANGED
@@ -113,10 +113,10 @@ Edit the question indices in [app.py:196](app.py#L196) to customize which questi
 ### Using the Agent Programmatically
 
 ```python
-from agents import MyLangGraphAgent
+from agents import MyGAIAAgents
 
-# Initialize agent
-agent = MyLangGraphAgent()
+# Initialize agent (automatically uses ACTIVE_AGENT from config)
+agent = MyGAIAAgents()
 
 # Ask a question
 answer = agent("What is the capital of France?")
@@ -174,19 +174,17 @@ The agent follows strict output formatting rules defined in [system_prompt.py](s
 
 ## Configuration
 
-### Change LLM Provider
+### Change Agent Type
 
-Edit [agents.py:52](agents.py#L52) in the `create_llm_client` method:
+Edit the `ACTIVE_AGENT` variable in [config.py:32](config.py#L32):
 
 ```python
-# Use Google Gemini (default)
-agent = MyLangGraphAgent()
-
-# Use Hugging Face models
-def create_llm_client(self, model_provider: str = "huggingface"):
-    # ...
+# Valid values: "LangGraph", "ReActLangGraph", "LLamaIndex", "SMOL"
+ACTIVE_AGENT = "LangGraph"  # Currently only LangGraph is implemented
 ```
 
+The `MyGAIAAgents` wrapper class will automatically instantiate the correct agent based on this configuration.
+
 ### Adjust Step Limits
 
 Modify the maximum iteration count in [agents.py:169](agents.py#L169):
agent_runner.py CHANGED
@@ -2,7 +2,7 @@
 
 from typing import Optional, Tuple, List, Dict
 from colorama import Fore, Style
-from agents import MyLangGraphAgent
+from agents import MyGAIAAgents
 import config
 
 
@@ -16,7 +16,7 @@ class AgentRunner:
     def initialize_agent(self) -> bool:
         """Initialize the agent. Returns True if successful."""
         try:
-            self.agent = MyLangGraphAgent()
+            self.agent = MyGAIAAgents()
             return True
         except Exception as e:
             print(f"{Fore.RED}Error instantiating agent: {e}{Style.RESET_ALL}")
agents.py CHANGED
@@ -1,305 +1,35 @@
-import os
-import logging
-import warnings
-import re
-import time
-
-# Suppress TensorFlow/Keras warnings
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
-logging.getLogger('tensorflow').setLevel(logging.ERROR)
-warnings.filterwarnings('ignore', module='tensorflow')
-warnings.filterwarnings('ignore', module='tf_keras')
-
-from typing import TypedDict, Optional, List, Annotated
-from langchain_core.messages import HumanMessage, SystemMessage
-from langgraph.graph import MessagesState, StateGraph, START, END
-from langgraph.graph.message import add_messages
-from langgraph.prebuilt import tools_condition
-from langgraph.prebuilt import ToolNode
-from langchain_google_genai import ChatGoogleGenerativeAI
-from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
-
-from custom_tools import get_custom_tools_list
-from system_prompt import SYSTEM_PROMPT
+"""Agent wrapper module for GAIA Benchmark."""
+
 import config
-
-# Suppress BeautifulSoup GuessedAtParserWarning
-try:
-    from bs4 import GuessedAtParserWarning
-    warnings.filterwarnings('ignore', category=GuessedAtParserWarning)
-except ImportError:
-    pass
+from langgraphagent import LangGraphAgent
 
 
-class AgentState(TypedDict):
-    question: str
-    messages: Annotated[list, add_messages]  # for LangGraph
-    answer: str
-    step_count: int  # Track number of iterations to prevent infinite loops
-    file_name: str  # Optional file name for questions that reference files
-
-
-class MyLangGraphAgent:
+class MyGAIAAgents:
+    """Wrapper class to manage multiple agent implementations.
+
+    This class provides a unified interface for different agent types.
+    The active agent is determined by the ACTIVE_AGENT configuration.
+    """
 
     def __init__(self):
-        # Validate API keys
-        if not os.getenv("GOOGLE_API_KEY"):
-            print("WARNING: GOOGLE_API_KEY not found - analyze_youtube_video will fail")
-
-        self.tools = get_custom_tools_list()
-        self.llm_client_with_tools = self.create_llm_client()
-        self.graph = self.build_graph()
-
-    def create_llm_client(self, model_provider: str = "google"):
-        """Create and return the LLM client with tools bound based on the model provider."""
-
-        if model_provider == "google":
-            apikey = os.getenv("GOOGLE_API_KEY")
-
-            return ChatGoogleGenerativeAI(
-                model="gemini-2.5-flash",  # Changed from gemini-2.5-flash-lite - better tool calling
-                temperature=0,
-                api_key=apikey,
-                timeout=60  # Add timeout to prevent hanging
-            ).bind_tools(self.tools)
-
-        elif model_provider == "huggingface":
-            LLM_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
-            apikey = os.getenv("HUGGINGFACEHUB_API_TOKEN")
-
-            llmObject = HuggingFaceEndpoint(
-                repo_id=LLM_MODEL,
-                task="text-generation",
-                max_new_tokens=512,
-                temperature=0.7,
-                do_sample=False,
-                repetition_penalty=1.03,
-                huggingfacehub_api_token=apikey
-            )
-            return ChatHuggingFace(llm=llmObject).bind_tools(self.tools)
-
-    # Nodes
-    def init_questions(self, state: AgentState):
-        """Initialize the messages in the state with system prompt and user question."""
-
-        # Build the question message, including file name if available
-        question_content = state["question"]
-        if state.get("file_name"):
-            question_content += f'\n\nNote: This question references a file: {state["file_name"]}'
-
-        return {
-            "messages": [
-                SystemMessage(content=SYSTEM_PROMPT),
-                HumanMessage(content=question_content)
-            ],
-            "step_count": 0  # Initialize step counter
-        }
-
-    def assistant(self, state: AgentState):
-        """Assistant node which calls the LLM with tools"""
-
-        # Track and log current step
-        current_step = state.get("step_count", 0) + 1
-        print(f"[STEP {current_step}] Calling assistant with {len(state['messages'])} messages")
-
-        # Invoke LLM with tools enabled, with retry logic for 504 errors
-        max_retries = config.MAX_RETRIES
-        delay = config.INITIAL_RETRY_DELAY
-
-        for attempt in range(max_retries + 1):
-            try:
-                response = self.llm_client_with_tools.invoke(state["messages"])
-                # Success - break out of retry loop
-                break
-            except Exception as e:
-                error_msg = str(e)
-
-                # Check if this is a 504 DEADLINE_EXCEEDED error
-                if "504" in error_msg and "DEADLINE_EXCEEDED" in error_msg:
-                    if attempt < max_retries:
-                        print(f"[RETRY] Attempt {attempt + 1}/{max_retries} failed with 504 DEADLINE_EXCEEDED")
-                        print(f"[RETRY] Retrying in {delay:.1f} seconds...")
-                        time.sleep(delay)
-                        delay *= config.RETRY_BACKOFF_FACTOR
-                        continue
-                    else:
-                        print(f"[RETRY] All {max_retries} retries exhausted for 504 error")
-                        print(f"[ERROR] LLM invocation failed after retries: {e}")
-                        return {
-                            "messages": [],
-                            "answer": f"Error: LLM failed after {max_retries} retries - {str(e)[:100]}",
-                            "step_count": current_step
-                        }
-                else:
-                    # Not a 504 error - fail immediately without retry
-                    print(f"[ERROR] LLM invocation failed: {e}")
-                    return {
-                        "messages": [],
-                        "answer": f"Error: LLM failed - {str(e)[:100]}",
-                        "step_count": current_step
-                    }
-
-        # If no tool calls, set the final answer
-        if not response.tool_calls:
-            content = response.content
-            print(f"[FINAL ANSWER] Agent produced answer (no tool calls)")
-
-            # Handle case where content is a list (e.g. mixed content from Gemini)
-            if isinstance(content, list):
-                # Extract text from list of content parts
-                text_parts = []
-                for item in content:
-                    if isinstance(item, dict) and 'text' in item:
-                        text_parts.append(item['text'])
-                    elif hasattr(item, 'text'):
-                        text_parts.append(item.text)
-                    else:
-                        text_parts.append(str(item))
-                content = " ".join(text_parts)
-            elif isinstance(content, dict) and 'text' in content:
-                # Handle single dict with 'text' field
-                content = content['text']
-            elif hasattr(content, 'text'):
-                # Handle object with text attribute
-                content = content.text
-            else:
-                # Fallback to string conversion
-                content = str(content)
-
-            # Clean up any remaining noise
-            content = content.strip()
-            print(f"[EXTRACTED TEXT] {content[:100]}{'...' if len(content) > 100 else ''}")
-
-            return {
-                "messages": [response],
-                "answer": content,
-                "step_count": current_step
-            }
-
-        # Has tool calls, log them
-        print(f"[TOOL CALLS] Agent requesting {len(response.tool_calls)} tool(s):")
-        for tc in response.tool_calls:
-            print(f" - {tc['name']}")
-
-        return {
-            "messages": [response],
-            "step_count": current_step
-        }
-
-
-    def should_continue(self, state: AgentState):
-        """Check if we should continue or stop based on step count and other conditions."""
-
-        step_count = state.get("step_count", 0)
-
-        # Stop if we've exceeded maximum steps
-        if step_count >= 40:  # Increased from 25 to handle complex multi-step reasoning
-            print(f"[WARNING] Max steps (40) reached, forcing termination")
-            # Force a final answer if we don't have one
-            if not state.get("answer"):
-                state["answer"] = "Error: Maximum iteration limit reached"
-            return END
-
-        # Otherwise use the default tools_condition
-        return tools_condition(state)
-
-
-    def build_graph(self):
-        """Build and return the Compiled Graph for the agent."""
-
-        graph = StateGraph(AgentState)
-
-        # Build graph
-        graph.add_node("init", self.init_questions)
-        graph.add_node("assistant", self.assistant)
-        graph.add_node("tools", ToolNode(self.tools))
-        graph.add_edge(START, "init")
-        graph.add_edge("init", "assistant")
-        graph.add_conditional_edges(
-            "assistant",
-            # Use custom should_continue instead of tools_condition
-            self.should_continue,
-        )
-        graph.add_edge("tools", "assistant")
-        # Compile graph
-        return graph.compile()
+        """Initialize the wrapper with the active agent based on config."""
+        active_agent = config.ACTIVE_AGENT
+
+        if active_agent == "LangGraph":
+            self.agent = LangGraphAgent()
+        else:
+            # Default to LangGraph if unknown agent type
+            print(f"[WARNING] Unknown agent type '{active_agent}', defaulting to LangGraph")
+            self.agent = LangGraphAgent()
 
     def __call__(self, question: str, file_name: str = None) -> str:
-        """Invoke the agent graph with the given question and return the final answer.
+        """Invoke the active agent with the given question.
 
         Args:
             question: The question to answer
             file_name: Optional file name if the question references a file
-        """
-
-        print(f"\n{'='*60}")
-        print(f"[AGENT START] Question: {question}")
-        if file_name:
-            print(f"[FILE] {file_name}")
-        print(f"{'='*60}")
-
-        start_time = time.time()
-
-        try:
-            response = self.graph.invoke(
-                {"question": question, "messages": [], "answer": None, "step_count": 0, "file_name": file_name or ""},
-                config={"recursion_limit": 80}  # Must be >= 2x step limit (40 * 2 = 80)
-            )
-
-            elapsed_time = time.time() - start_time
-            print(f"[AGENT COMPLETE] Time: {elapsed_time:.2f}s")
-            print(f"{'='*60}\n")
-
-            answer = response.get("answer")
-            if answer is None:
-                print("[WARNING] Agent completed but returned None as answer")
-                return "Error: No answer generated"
-
-            # Final safety check: ensure answer is plain text string
-            if isinstance(answer, dict):
-                # If it's a dict, try to extract text field
-                if 'text' in answer:
-                    answer = answer['text']
-                else:
-                    answer = str(answer)
-                print(f"[WARNING] Answer was dict, extracted: {answer[:100]}")
-            elif isinstance(answer, list):
-                # If it's a list, extract text from each item
-                text_parts = []
-                for item in answer:
-                    if isinstance(item, dict) and 'text' in item:
-                        text_parts.append(item['text'])
-                    else:
-                        text_parts.append(str(item))
-                answer = " ".join(text_parts)
-                print(f"[WARNING] Answer was list, extracted: {answer[:100]}")
-            elif not isinstance(answer, str):
-                # Convert to string if it's any other type
-                answer = str(answer)
-                print(f"[WARNING] Answer was {type(answer)}, converted to string")
-
-            answer = answer.strip()
-
-            # Additional validation for numerical answers
-            # Remove common formatting issues that break exact matching
-            if answer:
-                # Remove comma separators from numbers (e.g., "1,000" -> "1000")
-                if ',' in answer and answer.replace(',', '').replace('.', '').isdigit():
-                    answer = answer.replace(',', '')
-                    print(f"[VALIDATION] Removed comma separators from answer")
-
-                # Ensure no trailing/leading whitespace or punctuation
-                answer = answer.strip().rstrip('.')
-
-                # Log if answer looks suspicious (for debugging)
-                if any(char in answer for char in ['{', '}', '[', ']', '`', '*', '#']):
-                    print(f"[WARNING] Answer contains suspicious formatting characters: {answer[:100]}")
-
-            print(f"[FINAL ANSWER] {answer}")
-            return answer
-
-        except Exception as e:
-            elapsed_time = time.time() - start_time
-            print(f"[AGENT ERROR] Failed after {elapsed_time:.2f}s: {e}")
-            print(f"{'='*60}\n")
-            return f"Error: {str(e)[:100]}"
+
+        Returns:
+            The agent's answer as a string
+        """
+        return self.agent(question, file_name)
app.py CHANGED
@@ -13,8 +13,7 @@ init(autoreset=True)
 # Import configuration
 import config
 
-# Import agent-related code from agents module
-from agents import MyLangGraphAgent
+# Agent-related code is imported via agent_runner module
 # Import Gradio UI creation function
 from gradioapp import create_ui
 # Import scoring function for answer verification
@@ -40,13 +39,13 @@ def _submit_to_server(submit_url: str, submission_data: dict) -> dict:
     response.raise_for_status()
     return response.json()
 
-def submit_and_score(username: str, answers_payload: list) -> str:
+def submit_and_score(username: str, results: list) -> str:
     """
     Submit answers to the GAIA scoring server and return status message.
 
     Args:
         username: Hugging Face username for submission
-        answers_payload: List of dicts with {"task_id": str, "submitted_answer": str}
+        results: List of tuples (task_id, question_text, answer)
 
     Returns:
         str: Status message (success or error details)
@@ -59,6 +58,14 @@ def submit_and_score(username: str, answers_payload: list) -> str:
         print(error_msg)
         return error_msg
 
+    # Format results for API submission
+    answers_payload = ResultFormatter.format_for_api(results)
+
+    if not answers_payload:
+        error_msg = "No answers to submit."
+        print(error_msg)
+        return error_msg
+
     space_id = config.SPACE_ID
     submit_url = f"{config.DEFAULT_API_URL}/submit"
     agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
@@ -118,7 +125,7 @@ def submit_and_score(username: str, answers_payload: list) -> str:
 
 def run_and_submit_all(username: str) -> tuple:
     """
-    Fetches all questions, runs the MyLangGraphAgent on them, submits all answers,
+    Fetches all questions, runs the GAIA agent on them, submits all answers,
     and displays the results.
 
     Returns:
@@ -142,16 +149,11 @@ def run_and_submit_all(username: str) -> tuple:
     if results is None:
         return "Error initializing agent.", None
 
-    # Format data structures: one for API submission, one for UI display
-    answers_for_api = ResultFormatter.format_for_api(results)
-    results_for_display = ResultFormatter.format_for_display(results)
-
-    if not answers_for_api:
-        print("Agent did not produce any answers to submit.")
-        return "Agent did not produce any answers to submit.", pd.DataFrame(results_for_display)
 
-    # Submit answers and get score
-    status_message = submit_and_score(username, answers_for_api)
+    # Submit answers and get score (formatting happens inside submit_and_score)
+    status_message = submit_and_score(username, results)
+
+    # Format results for UI display
+    results_for_display = ResultFormatter.format_for_display(results)
     results_df = pd.DataFrame(results_for_display)
     return status_message, results_df
 
@@ -239,8 +241,8 @@ def run_test_code(filter=None) -> pd.DataFrame:
     pd.DataFrame: Results and verification output
     """
     start_time = time.time()
-    results_for_display = []
-    results_for_display.append("=== Processing Example Questions One by One ===")
+    logs_for_display = []
+    logs_for_display.append("=== Processing Example Questions One by One ===")
 
     # Fetch questions (OFFLINE for testing)
     try:
@@ -263,10 +265,10 @@
     # Apply filter or use all questions
     if filter is not None:
         questions_to_process = [questions_data[i] for i in filter]
-        results_for_display.append(f"Testing {len(questions_to_process)} selected questions (indices: {filter})")
+        logs_for_display.append(f"Testing {len(questions_to_process)} selected questions (indices: {filter})")
     else:
         questions_to_process = questions_data
-        results_for_display.append(f"Testing all {len(questions_to_process)} questions")
+        logs_for_display.append(f"Testing all {len(questions_to_process)} questions")
 
     # Run agent on selected questions
     results = AgentRunner().run_on_questions(questions_to_process)
@@ -274,15 +276,15 @@
 
     if results is None:
         return pd.DataFrame(["Error initializing agent."])
 
-    results_for_display.append("\n=== Completed Example Questions ===")
+    logs_for_display.append("\n=== Completed Example Questions ===")
 
     # Calculate runtime
     elapsed_time = time.time() - start_time
     minutes = int(elapsed_time // 60)
     seconds = int(elapsed_time % 60)
 
-    verify_answers(results, results_for_display, runtime=(minutes, seconds))
-    return pd.DataFrame(results_for_display)
+    verify_answers(results, logs_for_display, runtime=(minutes, seconds))
+    return pd.DataFrame(logs_for_display)
 
 
 def main() -> None:
config.py CHANGED
@@ -29,6 +29,8 @@ SPACE_HOST = os.getenv("SPACE_HOST")
 SPACE_ID = os.getenv("SPACE_ID")
 GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
 
+ACTIVE_AGENT = "LangGraph"  # Valid values are ReActLangGraph, LLamaIndex, LangGraph, SMOL
+
 # Model Configuration
 GEMINI_MODEL = "gemini-2.5-flash"
 GEMINI_TEMPERATURE = 0
langgraphagent.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import warnings
4
+ import re
5
+ import time
6
+
7
+ # Suppress TensorFlow/Keras warnings
8
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
9
+ logging.getLogger('tensorflow').setLevel(logging.ERROR)
10
+ warnings.filterwarnings('ignore', module='tensorflow')
11
+ warnings.filterwarnings('ignore', module='tf_keras')
12
+
13
+ from typing import TypedDict, Optional, List, Annotated
14
+ from langchain_core.messages import HumanMessage, SystemMessage
15
+ from langgraph.graph import MessagesState, StateGraph, START, END
16
+ from langgraph.graph.message import add_messages
17
+ from langgraph.prebuilt import tools_condition
18
+ from langgraph.prebuilt import ToolNode
19
+ from langchain_google_genai import ChatGoogleGenerativeAI
20
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
21
+
22
+ from custom_tools import get_custom_tools_list
23
+ from system_prompt import SYSTEM_PROMPT
24
+ import config
25
+
26
+ # Suppress BeautifulSoup GuessedAtParserWarning
27
+ try:
28
+ from bs4 import GuessedAtParserWarning
29
+ warnings.filterwarnings('ignore', category=GuessedAtParserWarning)
30
+ except ImportError:
31
+ pass
32
+
33
+
34
+ class AgentState(TypedDict):
35
+ question: str
36
+ messages: Annotated[list , add_messages] # for LangGraph
37
+ answer: str
38
+ step_count: int # Track number of iterations to prevent infinite loops
39
+ file_name: str # Optional file name for questions that reference files
40
+
41
+
42
+ class LangGraphAgent:
43
+
44
+ def __init__(self):
45
+ # Validate API keys
46
+ if not os.getenv("GOOGLE_API_KEY"):
47
+ print("WARNING: GOOGLE_API_KEY not found - analyze_youtube_video will fail")
48
+
49
+ self.tools = get_custom_tools_list()
50
+ self.llm_client_with_tools = self._create_llm_client()
51
+ self.graph = self._build_graph()
52
+
53
+ def _create_llm_client(self, model_provider: str = "google"):
54
+ """Create and return the LLM client with tools bound based on the model provider."""
55
+
56
+ if model_provider == "google":
57
+ apikey = os.getenv("GOOGLE_API_KEY")
58
+
59
+ return ChatGoogleGenerativeAI(
60
+ model="gemini-2.5-flash", # Changed from gemini-2.5-flash-lite - better tool calling
61
+ temperature=0,
62
+ api_key=apikey,
63
+ timeout=60 # Add timeout to prevent hanging
64
+ ).bind_tools(self.tools)
65
+
66
+ elif model_provider == "huggingface":
67
+ LLM_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
68
+ apikey = os.getenv("HUGGINGFACEHUB_API_TOKEN")
69
+
70
+ llmObject = HuggingFaceEndpoint(
71
+ repo_id=LLM_MODEL,
72
+ task="text-generation",
73
+ max_new_tokens=512,
74
+ temperature=0.7,
75
+ do_sample=False,
76
+ repetition_penalty=1.03,
77
+ huggingfacehub_api_token=apikey
78
+ )
79
+ return ChatHuggingFace(llm=llmObject).bind_tools(self.tools)
80
+
81
+ # Nodes
82
+ def _init_questions(self, state: AgentState):
83
+ """Initialize the messages in the state with system prompt and user question."""
84
+
85
+ # Build the question message, including file name if available
86
+ question_content = state["question"]
87
+ if state.get("file_name"):
88
+ question_content += f'\n\nNote: This question references a file: {state["file_name"]}'
89
+
90
+ return {
91
+ "messages": [
92
+ SystemMessage(content=SYSTEM_PROMPT),
93
+ HumanMessage(content=question_content)
94
+ ],
95
+ "step_count": 0 # Initialize step counter
96
+ }
97
+
98
+ def _assistant(self, state: AgentState):
99
+ """Assistant node which calls the LLM with tools"""
100
+
101
+ # Track and log current step
102
+ current_step = state.get("step_count", 0) + 1
103
+ print(f"[STEP {current_step}] Calling assistant with {len(state['messages'])} messages")
104
+
105
+ # Invoke LLM with tools enabled, with retry logic for 504 errors
106
+ max_retries = config.MAX_RETRIES
107
+ delay = config.INITIAL_RETRY_DELAY
108
+
109
+ for attempt in range(max_retries + 1):
110
+ try:
111
+ response = self.llm_client_with_tools.invoke(state["messages"])
112
+ # Success - break out of retry loop
113
+ break
114
+ except Exception as e:
115
+ error_msg = str(e)
116
+
117
+ # Check if this is a 504 DEADLINE_EXCEEDED error
118
+ if "504" in error_msg and "DEADLINE_EXCEEDED" in error_msg:
119
+ if attempt < max_retries:
120
+                        print(f"[RETRY] Attempt {attempt + 1}/{max_retries} failed with 504 DEADLINE_EXCEEDED")
+                        print(f"[RETRY] Retrying in {delay:.1f} seconds...")
+                        time.sleep(delay)
+                        delay *= config.RETRY_BACKOFF_FACTOR
+                        continue
+                    else:
+                        print(f"[RETRY] All {max_retries} retries exhausted for 504 error")
+                        print(f"[ERROR] LLM invocation failed after retries: {e}")
+                        return {
+                            "messages": [],
+                            "answer": f"Error: LLM failed after {max_retries} retries - {str(e)[:100]}",
+                            "step_count": current_step
+                        }
+                else:
+                    # Not a 504 error - fail immediately without retry
+                    print(f"[ERROR] LLM invocation failed: {e}")
+                    return {
+                        "messages": [],
+                        "answer": f"Error: LLM failed - {str(e)[:100]}",
+                        "step_count": current_step
+                    }
+
+        # If no tool calls, set the final answer
+        if not response.tool_calls:
+            content = response.content
+            print(f"[FINAL ANSWER] Agent produced answer (no tool calls)")
+
+            # Handle case where content is a list (e.g. mixed content from Gemini)
+            if isinstance(content, list):
+                # Extract text from list of content parts
+                text_parts = []
+                for item in content:
+                    if isinstance(item, dict) and 'text' in item:
+                        text_parts.append(item['text'])
+                    elif hasattr(item, 'text'):
+                        text_parts.append(item.text)
+                    else:
+                        text_parts.append(str(item))
+                content = " ".join(text_parts)
+            elif isinstance(content, dict) and 'text' in content:
+                # Handle single dict with 'text' field
+                content = content['text']
+            elif hasattr(content, 'text'):
+                # Handle object with text attribute
+                content = content.text
+            else:
+                # Fallback to string conversion
+                content = str(content)
+
+            # Clean up any remaining noise
+            content = content.strip()
+            print(f"[EXTRACTED TEXT] {content[:100]}{'...' if len(content) > 100 else ''}")
+
+            return {
+                "messages": [response],
+                "answer": content,
+                "step_count": current_step
+            }
+
+        # Has tool calls, log them
+        print(f"[TOOL CALLS] Agent requesting {len(response.tool_calls)} tool(s):")
+        for tc in response.tool_calls:
+            print(f" - {tc['name']}")
+
+        return {
+            "messages": [response],
+            "step_count": current_step
+        }
+
+
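The text-extraction branches above (list of parts, dict with `'text'`, object with a `.text` attribute) reappear almost verbatim in `__call__` below when the final answer is normalized. A possible follow-up refactor, not part of this commit, would factor them into one helper; `extract_text` is a hypothetical name for this sketch:

```python
def extract_text(content) -> str:
    """Normalize LLM content (str, dict, list of parts, or object) to plain text.

    Hypothetical helper, not in the commit: factors out the normalization
    that _assistant and __call__ each implement inline.
    """
    if isinstance(content, list):
        # Mixed content (e.g. from Gemini): join the text of each part
        parts = []
        for item in content:
            if isinstance(item, dict) and 'text' in item:
                parts.append(item['text'])
            elif hasattr(item, 'text'):
                parts.append(item.text)
            else:
                parts.append(str(item))
        return " ".join(parts).strip()
    if isinstance(content, dict) and 'text' in content:
        return content['text'].strip()
    if hasattr(content, 'text'):
        return content.text.strip()
    return str(content).strip()
```

With a helper like this, both call sites shrink to `content = extract_text(response.content)`.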
+    def _should_continue(self, state: AgentState):
+        """Check if we should continue or stop based on step count and other conditions."""
+
+        step_count = state.get("step_count", 0)
+
+        # Stop if we've exceeded maximum steps
+        if step_count >= 40:  # Increased from 25 to handle complex multi-step reasoning
+            print(f"[WARNING] Max steps (40) reached, forcing termination")
+            # Force a final answer if we don't have one
+            if not state.get("answer"):
+                state["answer"] = "Error: Maximum iteration limit reached"
+            return END
+
+        # Otherwise use the default tools_condition
+        return tools_condition(state)
+
+
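For readers unfamiliar with LangGraph: `tools_condition` routes to the `"tools"` node when the last message carries tool calls and to `END` otherwise, and `_should_continue` layers a hard step ceiling on top of that. A minimal stand-alone sketch of the routing decision, with `END` modeled as a plain string and messages as simple namespaces (framework-free, for illustration only):

```python
from types import SimpleNamespace

END = "__end__"  # stand-in for langgraph's END sentinel

def should_continue(state: dict, max_steps: int = 40) -> str:
    """Route to 'tools' or END, with a hard cap on agent steps (sketch)."""
    if state.get("step_count", 0) >= max_steps:
        # Cap reached: make sure some answer is recorded, then stop
        if not state.get("answer"):
            state["answer"] = "Error: Maximum iteration limit reached"
        return END
    # Default tools_condition behavior: continue only if the last
    # message is asking for tool executions
    last = state["messages"][-1]
    return "tools" if getattr(last, "tool_calls", None) else END
```

The step cap exists because `tools_condition` alone would happily loop forever if the model keeps requesting tools.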
+    def _build_graph(self):
+        """Build and return the Compiled Graph for the agent."""
+
+        graph = StateGraph(AgentState)
+
+        # Build graph
+        graph.add_node("init", self._init_questions)
+        graph.add_node("assistant", self._assistant)
+        graph.add_node("tools", ToolNode(self.tools))
+        graph.add_edge(START, "init")
+        graph.add_edge("init", "assistant")
+        graph.add_conditional_edges(
+            "assistant",
+            # Use custom should_continue instead of tools_condition
+            self._should_continue,
+        )
+        graph.add_edge("tools", "assistant")
+        # Compile graph
+        return graph.compile()
+
+    def __call__(self, question: str, file_name: str = None) -> str:
+        """Invoke the agent graph with the given question and return the final answer.
+
+        Args:
+            question: The question to answer
+            file_name: Optional file name if the question references a file
+        """
+
+        print(f"\n{'='*60}")
+        print(f"[AGENT START] Question: {question}")
+        if file_name:
+            print(f"[FILE] {file_name}")
+        print(f"{'='*60}")
+
+        start_time = time.time()
+
+        try:
+            response = self.graph.invoke(
+                {"question": question, "messages": [], "answer": None, "step_count": 0, "file_name": file_name or ""},
+                config={"recursion_limit": 80}  # Must be >= 2x step limit (40 * 2 = 80)
+            )
+
+            elapsed_time = time.time() - start_time
+            print(f"[AGENT COMPLETE] Time: {elapsed_time:.2f}s")
+            print(f"{'='*60}\n")
+
+            answer = response.get("answer")
+            if answer is None:
+                print("[WARNING] Agent completed but returned None as answer")
+                return "Error: No answer generated"
+
+            # Final safety check: ensure answer is plain text string
+            if isinstance(answer, dict):
+                # If it's a dict, try to extract text field
+                if 'text' in answer:
+                    answer = answer['text']
+                else:
+                    answer = str(answer)
+                print(f"[WARNING] Answer was dict, extracted: {answer[:100]}")
+            elif isinstance(answer, list):
+                # If it's a list, extract text from each item
+                text_parts = []
+                for item in answer:
+                    if isinstance(item, dict) and 'text' in item:
+                        text_parts.append(item['text'])
+                    else:
+                        text_parts.append(str(item))
+                answer = " ".join(text_parts)
+                print(f"[WARNING] Answer was list, extracted: {answer[:100]}")
+            elif not isinstance(answer, str):
+                # Convert to string if it's any other type
+                # (capture the type first, or the warning would always report str)
+                original_type = type(answer)
+                answer = str(answer)
+                print(f"[WARNING] Answer was {original_type}, converted to string")
+
+            answer = answer.strip()
+
+            # Additional validation for numerical answers
+            # Remove common formatting issues that break exact matching
+            if answer:
+                # Remove comma separators from numbers (e.g., "1,000" -> "1000")
+                if ',' in answer and answer.replace(',', '').replace('.', '').isdigit():
+                    answer = answer.replace(',', '')
+                    print(f"[VALIDATION] Removed comma separators from answer")
+
+                # Ensure no trailing/leading whitespace or punctuation
+                answer = answer.strip().rstrip('.')
+
+                # Log if answer looks suspicious (for debugging)
+                if any(char in answer for char in ['{', '}', '[', ']', '`', '*', '#']):
+                    print(f"[WARNING] Answer contains suspicious formatting characters: {answer[:100]}")
+
+            print(f"[FINAL ANSWER] {answer}")
+            return answer
+
+        except Exception as e:
+            elapsed_time = time.time() - start_time
+            print(f"[AGENT ERROR] Failed after {elapsed_time:.2f}s: {e}")
+            print(f"{'='*60}\n")
+            return f"Error: {str(e)[:100]}"
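The 504 handling at the top of this hunk interleaves retry bookkeeping with the happy path inside `_assistant`. The same exponential-backoff pattern can be sketched as a reusable helper; `invoke_with_retry` and `call` are hypothetical names, and the defaults stand in for the `config.py` values (`RETRY_BACKOFF_FACTOR` etc.) rather than quoting them:

```python
import time

def invoke_with_retry(call, max_retries: int = 3, initial_delay: float = 1.0,
                      backoff_factor: float = 2.0):
    """Retry `call()` on 504 DEADLINE_EXCEEDED with exponential backoff (sketch)."""
    delay = initial_delay
    for attempt in range(max_retries):
        try:
            return call()
        except Exception as e:
            retryable = "504" in str(e) or "DEADLINE_EXCEEDED" in str(e)
            if retryable and attempt < max_retries - 1:
                print(f"[RETRY] Attempt {attempt + 1}/{max_retries} failed; retrying in {delay:.1f}s")
                time.sleep(delay)
                delay *= backoff_factor  # 1s, 2s, 4s, ...
                continue
            # Non-retryable error, or retries exhausted: surface it
            raise
```

The node would then call `invoke_with_retry(lambda: self.llm.invoke(messages))` and keep only the error-to-answer mapping inline, rather than duplicating the return dicts per branch.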