AlessandroMasala commited on
Commit
54bdb38
·
verified ·
1 Parent(s): 01398b7

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +641 -74
agent.py CHANGED
@@ -1,75 +1,71 @@
1
- # --- IMPORTS --- #
2
  import os
3
  import re
4
- import ast
5
- import requests
6
- import time
7
- import tempfile
8
-
9
  from pathlib import Path
 
10
  from enum import Enum
11
- from typing import TypedDict, Dict, List, Any
 
 
12
 
13
  from dotenv import load_dotenv
14
  from langgraph.graph import StateGraph, END
15
- from langchain.tools import Tool as LangTool, StructuredTool
16
  from langchain_core.runnables import RunnableLambda
 
17
 
18
- # ——— Zephyr LLM setup ———
19
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
20
  from langchain_community.llms import HuggingFacePipeline
21
  import torch
22
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- # Function to get the desired model Zephyr
25
  def get_zephyr_llm():
26
- model_id = 'HuggingFaceH4/zephyr-7b-alpha'
27
  tokenizer = AutoTokenizer.from_pretrained(model_id)
28
  model = AutoModelForCausalLM.from_pretrained(
29
- model_id,
30
- torch_dtype = torch.float16,
31
- device_map = 'auto'
32
  )
33
  gen = pipeline(
34
- 'text-generation',
35
- model = model,
36
- tokenizer = tokenizer,
37
- max_new_tokens = 256,
38
- temperature = 0.7,
39
- top_p = 0.9
40
  )
41
  return HuggingFacePipeline(pipeline=gen)
42
 
43
- # Declare the llm
44
  llm = get_zephyr_llm()
45
 
46
-
47
- # HF APIs
48
  load_dotenv()
 
 
49
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
50
  QUESTIONS_URL = f"{DEFAULT_API_URL}/questions"
51
- SUBMIT_URL = f"{DEFAULT_API_URL}/submit"
52
- FILE_PATH = f"{DEFAULT_API_URL}/files/"
53
 
54
- # --- TOOLS --- #
 
 
 
 
55
 
 
 
56
 
57
-
58
- # Agent Steps Class
59
- class AgentStep(Enum):
60
- ANALYZE = "analyze"
61
- SELECT_TOOLS = "select_tools"
62
- EXECUTE_TOOLS = "execute_tools"
63
- SYNTHESIZE = "synthesize_answer"
64
- ERROR_RECOVERY = "error_recovery"
65
- COMPLETE = "complete"
66
-
67
-
68
-
69
- # Agent State Class
70
  class AgentState(TypedDict):
 
71
  question: str
72
  original_question: str
 
73
  selected_tools: List[str]
74
  tool_results: Dict[str, Any]
75
  final_answer: str
@@ -77,63 +73,634 @@ class AgentState(TypedDict):
77
  error_count: int
78
  max_errors: int
79
 
80
-
81
- # Initialize state
82
- def initialize_state(question: str) -> str:
 
 
 
 
 
 
 
 
83
  return {
84
- 'question': question,
85
- 'original_question': question,
86
- 'selected_tools': [],
87
- 'tool_result': {},
88
- 'final_answer': '',
89
- 'current_step': AgentStep.ANALYZE.value,
 
90
  "error_count": 0,
91
  "max_errors": 3
92
  }
93
 
94
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  def analyze_question(state: AgentState) -> AgentState:
96
- state["current_step"] = AgentStep.SELECT_TOOLS.value
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  return state
98
 
99
- # Agent Class
100
- class GaiaAgent:
101
- def __init__(self):
102
- self.graph = graph
103
-
104
- def __call__(self, task_id: str, question: str) -> str:
105
-
106
-
107
-
108
-
109
-
110
-
111
-
112
-
113
-
114
-
115
-
116
-
117
-
118
-
119
-
120
-
121
-
122
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
 
 
125
 
 
126
 
 
127
 
 
 
128
 
 
129
 
 
 
 
 
 
 
 
130
 
 
131
 
 
132
 
 
133
 
 
 
134
 
 
 
135
 
 
 
 
 
 
 
136
 
 
137
 
 
 
 
 
 
 
 
 
 
 
 
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import re
 
 
 
 
 
3
  from pathlib import Path
4
+ from typing import Optional, Union, Dict, List, Any
5
  from enum import Enum
6
+ import requests
7
+ import tempfile
8
+ import ast
9
 
10
  from dotenv import load_dotenv
11
  from langgraph.graph import StateGraph, END
12
+ from langchain.tools import Tool as LangTool
13
  from langchain_core.runnables import RunnableLambda
14
+ from pathlib import Path
15
 
 
16
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
17
  from langchain_community.llms import HuggingFacePipeline
18
  import torch
19
 
20
+ from langchain.tools import StructuredTool
21
+
22
+ from tools import (
23
+ EnhancedSearchTool,
24
+ EnhancedWikipediaTool,
25
+ excel_to_markdown,
26
+ image_file_info,
27
+ audio_file_info,
28
+ code_file_read,
29
+ extract_youtube_info)
30
 
 
31
def get_zephyr_llm():
    """Build a LangChain LLM backed by the Zephyr-7B-alpha model.

    Downloads/loads the tokenizer and fp16 weights (device placement is
    delegated to HF Accelerate via device_map="auto"), wraps them in a
    text-generation pipeline, and exposes it as a LangChain runnable.
    """
    model_id = "HuggingFaceH4/zephyr-7b-alpha"
    zephyr_tokenizer = AutoTokenizer.from_pretrained(model_id)
    zephyr_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    text_generator = pipeline(
        "text-generation",
        model=zephyr_model,
        tokenizer=zephyr_tokenizer,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.9,
    )
    return HuggingFacePipeline(pipeline=text_generator)
42
 
43
# LLM instance: local Zephyr text-generation pipeline.
llm = get_zephyr_llm()

# Load environment variables (GEMINI_MODEL, GEMINI_API_KEY, ...) from .env.
load_dotenv()

# --- Constants ---
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
QUESTIONS_URL = f"{DEFAULT_API_URL}/questions"
SUBMIT_URL = f"{DEFAULT_API_URL}/submit"
FILE_PATH = f"{DEFAULT_API_URL}/files/"

# Initialize LLM
# FIX: the original referenced ChatGoogleGenerativeAI without importing it,
# which raised a NameError the moment this module was imported. Guard the
# import/initialization so Gemini is used when available and the Zephyr
# pipeline above remains the fallback otherwise.
try:
    from langchain_google_genai import ChatGoogleGenerativeAI

    llm = ChatGoogleGenerativeAI(
        model=os.getenv("GEMINI_MODEL", "gemini-pro"),
        google_api_key=os.getenv("GEMINI_API_KEY")
    )
except ImportError:
    pass  # langchain-google-genai not installed: keep the Zephyr LLM

# ----------- Enhanced State Management -----------
from typing import TypedDict
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
class AgentState(TypedDict):
    """Enhanced state tracking for the agent - using TypedDict for LangGraph compatibility"""
    question: str                                # possibly augmented with downloaded-file info
    original_question: str                       # question exactly as received
    conversation_history: List[Dict[str, str]]   # {"role": ..., "content": ...} entries per step
    selected_tools: List[str]                    # registry keys chosen for this question
    tool_results: Dict[str, Any]                 # raw output (or error string) per executed tool
    final_answer: str
    # FIX: current_step was read/written by every node and by initialize_state
    # but was missing from the declared schema.
    current_step: str
    error_count: int
    max_errors: int
 
76
class AgentStep(Enum):
    """Workflow phases for the agent.

    The string values double as routing keys stored in
    AgentState["current_step"] and as LangGraph node names.
    """
    ANALYZE_QUESTION = "analyze_question"
    SELECT_TOOLS = "select_tools"
    EXECUTE_TOOLS = "execute_tools"
    SYNTHESIZE_ANSWER = "synthesize_answer"
    ERROR_RECOVERY = "error_recovery"
    COMPLETE = "complete"
83
+
84
# ----------- Helper Functions -----------
def initialize_state(question: str) -> AgentState:
    """Create a fresh agent state for *question* with all defaults zeroed."""
    fresh_state = {
        "question": question,
        "original_question": question,
        "conversation_history": [],
        "selected_tools": [],
        "tool_results": {},
        "final_answer": "",
        "current_step": "start",
        "error_count": 0,
        "max_errors": 3,
    }
    return fresh_state
98
 
99
# Initialize vanilla tools
# NOTE(review): newer LangChain releases expose these under
# langchain_community.tools / langchain_community.utilities — confirm the
# pinned langchain version still provides these import paths.
from langchain.tools import DuckDuckGoSearchResults, WikipediaQueryRun
from langchain.utilities import WikipediaAPIWrapper

duckduckgo_tool = DuckDuckGoSearchResults()
wiki_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())


# Initialize enhanced tools
# NOTE(review): enhanced_search_tool / enhanced_wiki_tool are built but not
# registered below — verify whether they were meant to replace the vanilla
# tools in AVAILABLE_TOOLS.
enhanced_search_tool = LangTool.from_function(
    name="enhanced_web_search",
    func=EnhancedSearchTool().run,
    description="Enhanced web search with intelligent query processing, multiple search strategies, and result filtering. Provides comprehensive and relevant search results."
)

enhanced_wiki_tool = LangTool.from_function(
    name="enhanced_wikipedia",
    func=EnhancedWikipediaTool().run,
    description="Enhanced Wikipedia search with entity extraction, multi-term search, and relevant content filtering. Provides detailed encyclopedic information."
)

excel_tool = StructuredTool.from_function(
    name="excel_to_text",
    func=excel_to_markdown,
    description="Enhanced Excel analysis with metadata, statistics, and structured data preview. Inputs: 'excel_path' (str), 'sheet_name' (str, optional).",
)

image_tool = StructuredTool.from_function(
    name="image_file_info",
    func=image_file_info,
    description="Enhanced image file analysis with detailed metadata and properties."
)

audio_tool = LangTool.from_function(
    name="audio_file_info",
    func=audio_file_info,
    description="Enhanced audio processing with transcription, language detection, and timestamped segments."
)

code_tool = LangTool.from_function(
    name="code_file_read",
    func=code_file_read,
    description="Enhanced code file analysis with language-specific insights and structure analysis."
)

youtube_tool = LangTool.from_function(
    name="extract_youtube_info",
    func=extract_youtube_info,
    description="Extracts transcription from the youtube link"
)

# Enhanced tool registry
# FIX: "search" was mapped to the Wikipedia tool and "wikipedia" to the
# DuckDuckGo tool — every routed question hit the wrong backend. Swapped
# back so each key names the matching tool.
AVAILABLE_TOOLS = {
    "excel": excel_tool,
    "search": duckduckgo_tool,
    "wikipedia": wiki_tool,
    "image": image_tool,
    "audio": audio_tool,
    "code": code_tool,
    "youtube": youtube_tool
}
160
+
161
# ----------- Intelligent Tool Selection -----------
def analyze_question(state: AgentState) -> AgentState:
    """Enhanced question analysis with better tool recommendation.

    Asks the LLM to classify the question (type, tools, strategy, format),
    appends the raw analysis to conversation_history, and advances
    current_step to SELECT_TOOLS on success or ERROR_RECOVERY (incrementing
    error_count) on failure. The analysis text itself is only logged; tool
    choice happens in select_tools.
    """
    analysis_prompt = f"""
    Analyze this question and determine the best tools and approach:
    Question: {state["question"]}

    Available enhanced tools:
    1. excel - Enhanced Excel/CSV analysis with statistics and metadata
    2. search - Enhanced web search with intelligent query processing and result filtering
    3. wikipedia - Enhanced Wikipedia search with entity extraction and content filtering
    4. image - Enhanced image analysis with what the image contains
    5. audio - Enhanced audio processing with transcription
    6. code - Enhanced code analysis with language-specific insights
    7. youtube - Extracts transcription from the youtube link

    Consider:
    - Question type (factual, analytical, current events, technical)
    - Required information sources (files, web, encyclopedic)
    - Time sensitivity (current vs historical information)
    - Complexity level

    Respond with:
    1. Question type: <type>
    2. Primary tools needed: <tools>
    3. Search strategy: <strategy>
    4. Expected answer format: <format>

    Format: TYPE: <type> | TOOLS: <tools> | STRATEGY: <strategy> | FORMAT: <format>
    """

    try:
        # NOTE(review): .content assumes a chat-model response object; a plain
        # HuggingFacePipeline returns a str — confirm which LLM is active.
        response = llm.invoke(analysis_prompt).content
        state["conversation_history"].append({"role": "analysis", "content": response})
        state["current_step"] = AgentStep.SELECT_TOOLS.value
    except Exception as e:
        state["error_count"] += 1
        state["conversation_history"].append({"role": "error", "content": f"Analysis failed: {e}"})
        state["current_step"] = AgentStep.ERROR_RECOVERY.value

    return state
202
 
203
def select_tools(state: AgentState) -> AgentState:
    """Enhanced tool selection with smarter logic.

    Combines two passes: (1) keyword matching on the question text for
    file-based tools, and (2) an LLM prompt that suggests tools as a Python
    list literal. The union (de-duplicated) is stored in
    state["selected_tools"] and current_step advances to EXECUTE_TOOLS.
    """
    question = state["question"].lower()
    selected_tools = []

    # File-based tool selection
    if any(keyword in question for keyword in ["excel", "csv", "spreadsheet", ".xlsx", ".xls"]):
        selected_tools.append("excel")
    if any(keyword in question for keyword in [".png", ".jpg", ".jpeg", ".bmp", ".gif", "image"]):
        selected_tools.append("image")
    if any(keyword in question for keyword in [".mp3", ".wav", ".ogg", "audio", "transcribe"]):
        selected_tools.append("audio")
    if any(keyword in question for keyword in [".py", ".ipynb", "code", "script", "function"]):
        selected_tools.append("code")
    if any(keyword in question for keyword in ["youtube"]):
        selected_tools.append("youtube")

    print(f"File-based tools selected: {selected_tools}")

    tools_prompt = f"""
    You are a smart assistant that selects relevant tools based on the user's natural language question.

    Available tools:
    - "search" → Use for real-time, recent, or broad web information.
    - "wikipedia" → Use for factual or encyclopedic knowledge.
    - "excel" → Use for spreadsheet-related questions (.xlsx, .csv).
    - "image" → Use for image files (.png, .jpg, etc.) or image-based tasks.
    - "audio" → Use for sound files (.mp3, .wav, etc.) or transcription.
    - "code" → Use for programming-related questions or when files like .py are mentioned.
    - "youtube" → Use for questions involving YouTube videos.

    Return the result as a **Python list of strings**, no explanation. Use only the relevant tools.
    If no relevant tool is found, return an empty list such as [].

    ### Examples:

    Q: "Show me recent news about elections in 2025"
    A: ["search"]

    Q: "Summarize this Wikipedia article about Einstein"
    A: ["wikipedia"]

    Q: "Analyze this .csv file"
    A: ["excel"]

    Q: "Transcribe this .wav audio file"
    A: ["audio"]

    Q: "Generate Python code from this prompt"
    A: ["code"]

    Q: "Who was the president of USA in 1945?"
    A: ["wikipedia"]

    Q: "Give me current weather updates"
    A: ["search"]

    Q: "Look up the history of space exploration"
    A: ["search", "wikipedia"]

    Q: "What is 2 + 2?"
    A: []

    ### Now answer:

    Q: {state["question"]}
    A:
    """

    # FIX: ast.literal_eval on raw LLM output raised ValueError/SyntaxError
    # whenever the model returned anything that was not a valid Python
    # literal, aborting the whole workflow. Fall back to "no suggestions".
    try:
        llm_tools = ast.literal_eval(llm.invoke(tools_prompt).content.strip())
    except (ValueError, SyntaxError):
        llm_tools = []
    if not isinstance(llm_tools, list):
        llm_tools = []
    print(f"LLM suggested tools: {llm_tools}")
    selected_tools.extend(llm_tools)
    selected_tools = list(set(selected_tools))  # Remove duplicates

    print(f"Final selected tools after LLM suggestion: {selected_tools}")

    state["selected_tools"] = selected_tools
    state["current_step"] = AgentStep.EXECUTE_TOOLS.value
    return state
306
 
307
def execute_tools(state: AgentState) -> AgentState:
    """Enhanced tool execution with better error handling.

    Runs each selected tool: file-based tools get the locally downloaded
    attachment path (if the question advertises one), the youtube tool gets
    the full question text (it contains the URL), and query tools get the
    question stripped of the file-information footer. Per-tool failures are
    stored as error strings in tool_results instead of aborting the loop.
    """
    results = {}

    # Enhanced file detection: process_file appends a marker line followed
    # by the local path of the downloaded attachment.
    file_path = None
    downloaded_file_marker = "A file was downloaded for this task and saved locally at:"
    if downloaded_file_marker in state["question"]:
        lines = state["question"].splitlines()
        for i, line in enumerate(lines):
            if downloaded_file_marker in line:
                if i + 1 < len(lines):
                    file_path_candidate = lines[i + 1].strip()
                    if Path(file_path_candidate).exists():
                        file_path = file_path_candidate
                        print(f"Detected file path: {file_path}")
                break

    for tool_name in state["selected_tools"]:
        try:
            print(f"Executing tool: {tool_name}")

            # File-based tools
            if tool_name in ["excel", "image", "audio", "code"] and file_path:
                if tool_name == "excel":
                    result = AVAILABLE_TOOLS["excel"].run({"excel_path": file_path, "sheet_name": None})
                elif tool_name == "image":
                    result = AVAILABLE_TOOLS["image"].run({"image_path": file_path, "question": state["question"]})
                else:
                    result = AVAILABLE_TOOLS[tool_name].run(file_path)
            # FIX: the youtube case was a dead `elif` nested inside the
            # file-tool branch above (youtube is not in that list, so it
            # could never fire). Handle it explicitly: the video URL lives
            # in the question text, not in a downloaded file.
            elif tool_name == "youtube":
                result = AVAILABLE_TOOLS["youtube"].run(state["question"])
            # Information-based tools
            else:
                # Extract clean query for search tools
                clean_query = state["question"]
                if downloaded_file_marker in clean_query:
                    clean_query = clean_query.split(downloaded_file_marker)[0].strip()

                result = AVAILABLE_TOOLS[tool_name].run(clean_query)

            results[tool_name] = result

            print(f"Tool {tool_name} completed successfully.")
            print(f"Output for {tool_name}: {result}")

        except Exception as e:
            error_msg = f"Error using {tool_name}: {str(e)}"
            results[tool_name] = error_msg
            state["error_count"] += 1
            print(error_msg)

    state["tool_results"] = results
    state["current_step"] = AgentStep.SYNTHESIZE_ANSWER.value
    return state
363
 
364
def synthesize_answer(state: AgentState) -> AgentState:
    """Enhanced answer synthesis with better formatting.

    Two-stage synthesis: (1) a chain-of-thought analysis over the collected
    tool outputs, then (2) a second LLM pass that distills the analysis into
    a terse final answer. On success stores it in final_answer and marks
    COMPLETE; on failure increments error_count and routes to ERROR_RECOVERY.
    """

    # One section per tool, e.g. "=== SEARCH RESULTS ===\n<output>"
    tool_results_str = "\n".join([f"=== {tool.upper()} RESULTS ===\n{result}\n" for tool, result in state["tool_results"].items()])

    cot_prompt = f"""You are a precise assistant tasked with analyzing the user's question{" using the available tool outputs" if state["tool_results"] else ""}.

Question:
{state["question"]}

{f"Available tool outputs: {tool_results_str}" if state["tool_results"] else ""}

Instructions:
- Think step-by-step to determine the best strategy to answer the question.
- Use only the given information; do not hallucinate or infer from external knowledge.
- If decoding, logical deduction, counting, or interpretation is required, show each step clearly.
- If any part of the tool output is unclear or incomplete, mention it and its impact.
- Do not guess. If the information is insufficient, say so clearly.
- Finish with a clearly marked line: `---END OF ANALYSIS---`

Your step-by-step analysis:"""

    # NOTE(review): this first invoke is OUTSIDE the try below — if it raises,
    # the node fails without routing to error_recovery. Confirm intended.
    # .content also assumes a chat-style response object (HuggingFacePipeline
    # returns a plain str).
    cot_response = llm.invoke(cot_prompt).content

    final_answer_prompt = f"""You are a precise assistant tasked with deriving the **final answer** from the step-by-step analysis below.

Question:
{state["question"]}

Step-by-step analysis:
{cot_response}

Instructions:
- Read the analysis thoroughly before responding.
- Output ONLY the final answer. Do NOT include any reasoning or explanation.
- Remove any punctuation at the corners of the answer unless it is explicitly mentioned in the question.
- The answer must be concise and factual.
- If the analysis concluded that a definitive answer cannot be determined, respond with: `NA` (exactly).

Final answer:"""

    try:
        response = llm.invoke(final_answer_prompt).content
        state["final_answer"] = response
        state["current_step"] = AgentStep.COMPLETE.value
    except Exception as e:
        state["error_count"] += 1
        state["final_answer"] = f"Error synthesizing answer: {e}"
        state["current_step"] = AgentStep.ERROR_RECOVERY.value

    return state
416
 
417
def error_recovery(state: AgentState) -> AgentState:
    """Enhanced error recovery with multiple fallback strategies.

    If the error budget is exhausted, gives up with a fixed apology message;
    otherwise attempts one tool-free, direct-knowledge LLM answer. Either
    way the node terminates the workflow (current_step == COMPLETE).
    """
    if state["error_count"] >= state["max_errors"]:
        state["final_answer"] = "I encountered multiple errors and cannot complete this task reliably."
        state["current_step"] = AgentStep.COMPLETE.value
    else:
        # Enhanced fallback: try with simplified approach
        try:
            fallback_prompt = f"""
            Answer this question directly using your knowledge:
            {state["original_question"]}

            Provide a helpful response even if you cannot access external tools.
            Be clear about any limitations in your answer.
            """
            # NOTE(review): .content assumes a chat-model response object —
            # confirm against the configured llm.
            response = llm.invoke(fallback_prompt).content
            state["final_answer"] = f"Using available knowledge (some tools unavailable): {response}"
            state["current_step"] = AgentStep.COMPLETE.value
        except Exception as e:
            state["final_answer"] = f"All approaches failed. Error: {e}"
            state["current_step"] = AgentStep.COMPLETE.value

    return state
440
 
441
# ----------- Enhanced LangGraph Workflow -----------
def route_next_step(state: AgentState) -> str:
    """Return the next workflow node for the current step (END for unknown/COMPLETE)."""
    current = state["current_step"]
    if current == "start":
        return AgentStep.ANALYZE_QUESTION.value
    if current == AgentStep.ANALYZE_QUESTION.value:
        return AgentStep.SELECT_TOOLS.value
    if current == AgentStep.SELECT_TOOLS.value:
        return AgentStep.EXECUTE_TOOLS.value
    if current == AgentStep.EXECUTE_TOOLS.value:
        return AgentStep.SYNTHESIZE_ANSWER.value
    if current == AgentStep.SYNTHESIZE_ANSWER.value:
        return AgentStep.COMPLETE.value
    if current == AgentStep.ERROR_RECOVERY.value:
        return AgentStep.COMPLETE.value
    # COMPLETE — or any unrecognized step — terminates the graph.
    return END
455
+
456
# Create enhanced workflow
workflow = StateGraph(AgentState)

# Add nodes — each node is one of the state-transforming functions above.
workflow.add_node("analyze_question", RunnableLambda(analyze_question))
workflow.add_node("select_tools", RunnableLambda(select_tools))
workflow.add_node("execute_tools", RunnableLambda(execute_tools))
workflow.add_node("synthesize_answer", RunnableLambda(synthesize_answer))
workflow.add_node("error_recovery", RunnableLambda(error_recovery))

# Set entry point
workflow.set_entry_point("analyze_question")

# Add conditional edges. Each node advances current_step on success, so any
# other value means the node failed and control diverts to error_recovery.
workflow.add_conditional_edges(
    "analyze_question",
    lambda state: "select_tools" if state["current_step"] == AgentStep.SELECT_TOOLS.value else "error_recovery"
)
workflow.add_edge("select_tools", "execute_tools")
workflow.add_conditional_edges(
    "execute_tools",
    lambda state: "synthesize_answer" if state["current_step"] == AgentStep.SYNTHESIZE_ANSWER.value else "error_recovery"
)
workflow.add_conditional_edges(
    "synthesize_answer",
    lambda state: END if state["current_step"] == AgentStep.COMPLETE.value else "error_recovery"
)
workflow.add_edge("error_recovery", END)

# Compile the enhanced graph
graph = workflow.compile()
487
+
488
# ----------- Agent Class -----------
class GaiaAgent:
    """GAIA Agent with tools and intelligent processing.

    Thin wrapper around the compiled LangGraph workflow: downloads any task
    attachment, runs the graph, keeps per-tool usage statistics, and falls
    back to a direct LLM answer if the workflow itself raises.
    """

    def __init__(self):
        self.graph = graph
        self.tool_usage_stats = {}  # tool name -> number of times selected
        print("Enhanced GAIA Agent initialized with:")
        print("✓ Intelligent multi-query web search")
        print("✓ Entity-aware Wikipedia search")
        print("✓ Enhanced file processing tools")
        print("✓ Advanced error recovery")
        print("✓ Comprehensive result synthesis")

    def get_tool_stats(self) -> Dict[str, int]:
        """Get usage statistics for tools (a copy, so callers cannot mutate ours)."""
        return self.tool_usage_stats.copy()

    def __call__(self, task_id: str, question: str) -> str:
        """Process one task end-to-end and return the final answer string."""
        print(f"\n{'='*60}")
        print(f"[{task_id}] ENHANCED PROCESSING: {question}")

        # Initialize state
        processed_question = process_file(task_id, question)
        initial_state = initialize_state(processed_question)

        try:
            # Execute the enhanced workflow
            result = self.graph.invoke(initial_state)

            # Extract results
            answer = result.get("final_answer", "No answer generated")
            selected_tools = result.get("selected_tools", [])
            conversation_history = result.get("conversation_history", [])
            tool_results = result.get("tool_results", {})
            error_count = result.get("error_count", 0)

            # Update tool usage statistics
            for tool in selected_tools:
                self.tool_usage_stats[tool] = self.tool_usage_stats.get(tool, 0) + 1

            # Enhanced logging
            print(f"[{task_id}] Selected tools: {selected_tools}")
            print(f"[{task_id}] Tools executed: {list(tool_results.keys())}")
            print(f"[{task_id}] Processing steps: {len(conversation_history)}")
            print(f"[{task_id}] Errors encountered: {error_count}")

            # Log tool result sizes for debugging.
            # FIX: the loop variable previously shadowed `result` (the graph
            # output); renamed to keep the two distinct.
            for tool, tool_output in tool_results.items():
                result_size = len(str(tool_output)) if tool_output else 0
                print(f"[{task_id}] {tool} result size: {result_size} chars")

            print(f"[{task_id}] FINAL ANSWER: {answer}")
            print(f"{'='*60}")

            return answer

        except Exception as e:
            error_msg = f"Critical error in enhanced agent execution: {str(e)}"
            print(f"[{task_id}] {error_msg}")

            # Try fallback direct LLM response
            try:
                fallback_response = llm.invoke(f"Please answer this question: {question}").content
                return f"Fallback response: {fallback_response}"
            except Exception:
                # FIX: was a bare `except:`, which also swallowed
                # SystemExit / KeyboardInterrupt.
                return error_msg
555
+
556
# ----------- Enhanced File Processing -----------
def detect_file_type(file_path: str) -> Optional[str]:
    """Enhanced file type detection with more formats.

    Classifies *file_path* by its (case-insensitive) extension into one of
    'excel', 'image', 'audio', 'code', 'text' or 'document'; returns None
    for unrecognized extensions.
    """
    extension_groups = {
        # Spreadsheets
        'excel': ('.xlsx', '.xls', '.csv'),
        # Images
        'image': ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tiff', '.webp'),
        # Audio
        'audio': ('.mp3', '.wav', '.ogg', '.flac', '.m4a', '.aac'),
        # Code
        'code': ('.py', '.ipynb', '.js', '.html', '.css', '.java', '.cpp',
                 '.c', '.sql', '.r', '.json', '.xml'),
        # Documents
        'text': ('.txt', '.md'),
        'document': ('.pdf', '.doc', '.docx'),
    }

    suffix = Path(file_path).suffix.lower()
    for kind, extensions in extension_groups.items():
        if suffix in extensions:
            return kind
    return None
580
+
581
def process_file(task_id: str, question_text: str) -> str:
    """Enhanced file processing with better error handling and metadata.

    Attempts to download the task's attachment from the scoring API. If no
    file exists (or the download fails) the question is returned unchanged;
    otherwise the file is saved under a per-task temp directory and the
    question text is augmented with the local path plus file metadata that
    execute_tools later parses.
    """
    file_url = f"{FILE_PATH}{task_id}"

    try:
        print(f"[{task_id}] Attempting to download file from: {file_url}")
        response = requests.get(file_url, timeout=30)
        response.raise_for_status()
        print(f"[{task_id}] File download successful. Status: {response.status_code}")

    except requests.exceptions.RequestException as exc:
        print(f"[{task_id}] File download failed: {str(exc)}")
        return question_text  # Return original question if no file

    # Enhanced filename extraction
    content_disposition = response.headers.get("content-disposition", "")
    filename = task_id  # Default fallback

    # Try to extract filename from Content-Disposition header
    filename_match = re.search(r'filename[*]?=(?:"([^"]+)"|([^;]+))', content_disposition)
    if filename_match:
        filename = filename_match.group(1) or filename_match.group(2)
        filename = filename.strip()

    # Create enhanced temp directory structure
    temp_storage_dir = Path(tempfile.gettempdir()) / "gaia_enhanced_files" / task_id
    temp_storage_dir.mkdir(parents=True, exist_ok=True)

    file_path = temp_storage_dir / filename
    file_path.write_bytes(response.content)

    # Get file metadata
    file_size = len(response.content)
    file_type = detect_file_type(filename)

    # FIX: this log line and the "- Name:" field below printed the literal
    # text "(unknown)" even though `filename` had just been extracted —
    # restore the interpolation so the real name is reported.
    print(f"[{task_id}] File saved: {filename} ({file_size:,} bytes, type: {file_type})")

    # Enhanced question augmentation
    enhanced_question = f"{question_text}\n\n"
    enhanced_question += f"{'='*50}\n"
    enhanced_question += f"FILE INFORMATION:\n"
    enhanced_question += f"A file was downloaded for this task and saved locally at:\n"
    enhanced_question += f"{str(file_path)}\n"
    enhanced_question += f"File details:\n"
    enhanced_question += f"- Name: {filename}\n"
    enhanced_question += f"- Size: {file_size:,} bytes ({file_size/1024:.1f} KB)\n"
    enhanced_question += f"- Type: {file_type or 'unknown'}\n"
    enhanced_question += f"{'='*50}\n\n"

    return enhanced_question
631
+
632
# ----------- Usage Examples and Testing -----------
def run_enhanced_tests():
    """Run comprehensive tests of the enhanced agent"""
    agent = GaiaAgent()

    test_cases = [
        {
            "id": "test_search_1",
            "question": "What are the latest developments in artificial intelligence in 2024?",
            "expected_tools": ["search"]
        },
        {
            "id": "test_wiki_1",
            "question": "Tell me about Albert Einstein's contributions to physics",
            "expected_tools": ["wikipedia"]
        },
        {
            "id": "test_combined_1",
            "question": "What is machine learning and what are recent breakthroughs?",
            "expected_tools": ["wikipedia", "search"]
        },
        {
            "id": "test_excel_1",
            "question": "Analyze the data in the Excel file sales_data.xlsx",
            "expected_tools": ["excel"]
        }
    ]

    print("\n" + "="*80)
    print("RUNNING ENHANCED AGENT TESTS")
    print("="*80)

    for case in test_cases:
        print(f"\nTest Case: {case['id']}")
        print(f"Question: {case['question']}")
        print(f"Expected tools: {case['expected_tools']}")

        try:
            result = agent(case['id'], case['question'])
        except Exception as e:
            print(f"Test failed: {e}")
        else:
            print(f"Result length: {len(result)} characters")
            print(f"Result preview: {result[:200]}...")

        print("-" * 60)

    # Print tool usage statistics
    print(f"\nTool Usage Statistics:")
    for tool, count in agent.get_tool_stats().items():
        print(f"  {tool}: {count} times")
682
+
683
# Usage example
if __name__ == "__main__":
    # Create enhanced agent
    agent = GaiaAgent()

    # Example usage
    sample_questions = [
        "What is the current population of Tokyo and how has it changed recently?",
        "Explain quantum computing and its recent developments",
        "Tell me about the history of machine learning and current AI trends",
    ]

    print("\n" + "="*80)
    print("ENHANCED GAIA AGENT DEMONSTRATION")
    print("="*80)

    for idx, sample in enumerate(sample_questions, start=1):
        print(f"\nExample {idx}: {sample}")
        answer = agent(f"demo_{idx - 1}", sample)
        print(f"Answer: {answer[:300]}...")
        print("-" * 60)

    # Uncomment to run comprehensive tests
    # run_enhanced_tests()