Sanyam0605 committed
Commit 76f42c5 · verified · 1 Parent(s): ebf370a

Create core_agent.py

Files changed (1):
  1. core_agent.py +556 -0

core_agent.py ADDED
@@ -0,0 +1,556 @@
+from smolagents import (
+    CodeAgent,
+    DuckDuckGoSearchTool,
+    HfApiModel,
+    OpenAIServerModel,
+    PythonInterpreterTool,
+    tool,
+    InferenceClientModel,
+    TransformersModel
+)
+from typing import List, Dict, Any, Optional
+import os
+import tempfile
+import re
+import json
+import requests
+from urllib.parse import urlparse
+import google.generativeai as genai
+
+class GeminiResponseWrapper:
+    """A wrapper class to make Gemini API responses compatible with smolagents."""
+    def __init__(self, text):
+        self.text = text
+        # This content attribute is what smolagents accesses - it must be a string for strip() to work
+        self.content = text
+
+class GeminiModel:
+    """Model implementation for Google's Gemini models."""
+
+    def __init__(self, api_key, model_id="gemini-2.0-flash-lite", temperature=0.2):
+        genai.configure(api_key=api_key)
+        self.model = genai.GenerativeModel(model_name=model_id)
+        self.temperature = temperature
+
+    def __call__(self, messages, stop_sequences=None, **kwargs):
+        try:
+            # For Gemini, use only the latest message's content,
+            # since it carries the full context we need
+            if not messages:
+                return GeminiResponseWrapper("No messages provided")
+
+            # Get the content from the latest message
+            last_message = messages[-1]
+            if "content" not in last_message:
+                return GeminiResponseWrapper("Invalid message format")
+
+            content = last_message["content"]
+            if isinstance(content, list):
+                prompt = ""
+                for item in content:
+                    if isinstance(item, dict) and "text" in item:
+                        prompt += item["text"] + "\n"
+                    elif isinstance(item, str):
+                        prompt += item + "\n"
+            else:
+                prompt = str(content)
+
+            # Generate the response with Gemini, applying the configured temperature
+            response = self.model.generate_content(
+                prompt,
+                generation_config={"temperature": self.temperature}
+            )
+
+            # Return a wrapper object that has both .text and .content attributes
+            if hasattr(response, 'text'):
+                return GeminiResponseWrapper(response.text)
+            else:
+                return GeminiResponseWrapper(str(response))
+
+        except Exception as e:
+            # Return the error in the wrapper format
+            return GeminiResponseWrapper(f"Error: {str(e)}")
+
+@tool
+def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
+    """
+    Save content to a temporary file and return the path.
+    Useful for processing files from the GAIA API.
+
+    Args:
+        content: The content to save to the file
+        filename: Optional filename, will generate a random name if not provided
+
+    Returns:
+        Path to the saved file
+    """
+    temp_dir = tempfile.gettempdir()
+    if filename is None:
+        temp_file = tempfile.NamedTemporaryFile(delete=False)
+        filepath = temp_file.name
+    else:
+        filepath = os.path.join(temp_dir, filename)
+
+    # Write the content to the file
+    with open(filepath, 'w') as f:
+        f.write(content)
+
+    return f"File saved to {filepath}. You can read this file to process its contents."
+
+@tool
+def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
+    """
+    Download a file from a URL and save it to a temporary location.
+
+    Args:
+        url: The URL to download from
+        filename: Optional filename, will generate one based on URL if not provided
+
+    Returns:
+        Path to the downloaded file
+    """
+    try:
+        # Parse the URL to get a filename if one was not provided
+        if not filename:
+            path = urlparse(url).path
+            filename = os.path.basename(path)
+            if not filename:
+                # Generate a random name if we couldn't extract one
+                import uuid
+                filename = f"downloaded_{uuid.uuid4().hex[:8]}"
+
+        # Build the temporary file path
+        temp_dir = tempfile.gettempdir()
+        filepath = os.path.join(temp_dir, filename)
+
+        # Download the file
+        response = requests.get(url, stream=True)
+        response.raise_for_status()
+
+        # Save the file
+        with open(filepath, 'wb') as f:
+            for chunk in response.iter_content(chunk_size=8192):
+                f.write(chunk)
+
+        return f"File downloaded to {filepath}. You can now process this file."
+    except Exception as e:
+        return f"Error downloading file: {str(e)}"
+
+@tool
+def extract_text_from_image(image_path: str) -> str:
+    """
+    Extract text from an image using pytesseract (if available).
+
+    Args:
+        image_path: Path to the image file
+
+    Returns:
+        Extracted text or error message
+    """
+    try:
+        # Try to import pytesseract
+        import pytesseract
+        from PIL import Image
+
+        # Open the image
+        image = Image.open(image_path)
+
+        # Extract the text
+        text = pytesseract.image_to_string(image)
+
+        return f"Extracted text from image:\n\n{text}"
+    except ImportError:
+        return "Error: pytesseract is not installed. Please install it with 'pip install pytesseract' and ensure Tesseract OCR is installed on your system."
+    except Exception as e:
+        return f"Error extracting text from image: {str(e)}"
+
+@tool
+def analyze_csv_file(file_path: str, query: str) -> str:
+    """
+    Analyze a CSV file using pandas and answer a question about it.
+
+    Args:
+        file_path: Path to the CSV file
+        query: Question about the data
+
+    Returns:
+        Analysis result or error message
+    """
+    try:
+        import pandas as pd
+
+        # Read the CSV file
+        df = pd.read_csv(file_path)
+
+        # Return a basic overview of the data; query-specific analysis
+        # is left to the agent's own code execution
+        result = f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
+        result += f"Columns: {', '.join(df.columns)}\n\n"
+
+        # Add summary statistics
+        result += "Summary statistics:\n"
+        result += str(df.describe())
+
+        return result
+    except ImportError:
+        return "Error: pandas is not installed. Please install it with 'pip install pandas'."
+    except Exception as e:
+        return f"Error analyzing CSV file: {str(e)}"
+
+@tool
+def analyze_excel_file(file_path: str, query: str) -> str:
+    """
+    Analyze an Excel file using pandas and answer a question about it.
+
+    Args:
+        file_path: Path to the Excel file
+        query: Question about the data
+
+    Returns:
+        Analysis result or error message
+    """
+    try:
+        import pandas as pd
+
+        # Read the Excel file
+        df = pd.read_excel(file_path)
+
+        # Return a basic overview of the data; query-specific analysis
+        # is left to the agent's own code execution
+        result = f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
+        result += f"Columns: {', '.join(df.columns)}\n\n"
+
+        # Add summary statistics
+        result += "Summary statistics:\n"
+        result += str(df.describe())
+
+        return result
+    except ImportError:
+        return "Error: pandas and openpyxl are not installed. Please install them with 'pip install pandas openpyxl'."
+    except Exception as e:
+        return f"Error analyzing Excel file: {str(e)}"
+
+class GAIAAgent:
+    def __init__(
+        self,
+        model_type: str = "HfApiModel",
+        model_id: Optional[str] = None,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        temperature: float = 0.2,
+        executor_type: str = "local",  # Changed from use_e2b to executor_type
+        additional_imports: Optional[List[str]] = None,
+        additional_tools: Optional[List[Any]] = None,
+        system_prompt: Optional[str] = None,  # Accepted but not used directly
+        verbose: bool = False,
+        provider: Optional[str] = None,  # Provider for InferenceClientModel
+        timeout: Optional[int] = None  # Timeout for InferenceClientModel
+    ):
+        """
+        Initialize a GAIAAgent with the specified configuration.
+
+        Args:
+            model_type: Type of model to use (HfApiModel, InferenceClientModel, LiteLLMModel, TransformersModel, GeminiModel, OpenAIServerModel)
+            model_id: ID of the model to use
+            api_key: API key for the model provider
+            api_base: Base URL for API calls
+            temperature: Temperature for text generation
+            executor_type: Type of executor for code execution ('local' or 'e2b')
+            additional_imports: Additional Python modules to allow importing
+            additional_tools: Additional tools to provide to the agent
+            system_prompt: Custom system prompt (not used directly, kept for backward compatibility)
+            verbose: Enable verbose logging
+            provider: Provider for InferenceClientModel (e.g., "hf-inference")
+            timeout: Timeout in seconds for API calls
+        """
+        # Set verbosity
+        self.verbose = verbose
+        self.system_prompt = system_prompt  # Stored for potential future use
+
+        # Initialize the model based on the configuration
+        if model_type == "HfApiModel":
+            if api_key is None:
+                api_key = os.getenv("HF_API_TOKEN")
+                if not api_key:
+                    raise ValueError("No Hugging Face token provided. Please set HF_API_TOKEN environment variable or pass api_key parameter.")
+
+            if self.verbose:
+                print(f"Using Hugging Face token: {api_key[:5]}...")
+
+            self.model = HfApiModel(
+                model_id=model_id or "meta-llama/Llama-3-70B-Instruct",
+                token=api_key,
+                temperature=temperature
+            )
+        elif model_type == "InferenceClientModel":
+            if api_key is None:
+                api_key = os.getenv("HF_API_TOKEN")
+                if not api_key:
+                    raise ValueError("No Hugging Face token provided. Please set HF_API_TOKEN environment variable or pass api_key parameter.")
+
+            if self.verbose:
+                print(f"Using Hugging Face token: {api_key[:5]}...")
+
+            self.model = InferenceClientModel(
+                model_id=model_id or "meta-llama/Llama-3-70B-Instruct",
+                provider=provider or "hf-inference",
+                token=api_key,
+                timeout=timeout or 120,
+                temperature=temperature
+            )
+        elif model_type == "LiteLLMModel":
+            from smolagents import LiteLLMModel
+            self.model = LiteLLMModel(
+                model_id=model_id or "gpt-4o",
+                api_key=api_key or os.getenv("OPENAI_API_KEY"),
+                temperature=temperature
+            )
+        elif model_type == "TransformersModel":
+            if not model_id:
+                raise ValueError("model_id is required for TransformersModel")
+            self.model = TransformersModel(
+                model_id=model_id,
+                device_map="auto",  # Decide device placement automatically
+                trust_remote_code=True,  # Needed for some models
+                temperature=temperature
+            )
+        elif model_type == "GeminiModel":
+            api_key = api_key or os.getenv("GOOGLE_API_KEY")
+            if not api_key:
+                raise ValueError("No Google API key provided. Please set GOOGLE_API_KEY environment variable or pass api_key parameter.")
+            self.model = GeminiModel(
+                api_key=api_key,
+                model_id=model_id or "gemini-pro",
+                temperature=temperature
+            )
+        elif model_type == "OpenAIServerModel":
+            # Check for an xAI API key and base URL first
+            xai_api_key = os.getenv("XAI_API_KEY")
+            xai_api_base = os.getenv("XAI_API_BASE")
+
+            # If xAI credentials are available, use them
+            if xai_api_key and api_key is None:
+                api_key = xai_api_key
+                if self.verbose:
+                    print(f"Using xAI API key: {api_key[:5]}...")
+
+            # If no API key is specified, fall back to OPENAI_API_KEY
+            if api_key is None:
+                api_key = os.getenv("OPENAI_API_KEY")
+                if not api_key:
+                    raise ValueError("No OpenAI API key provided. Please set OPENAI_API_KEY or XAI_API_KEY environment variable or pass api_key parameter.")
+
+            # If an xAI API base is available and no api_base was provided, use it
+            if xai_api_base and api_base is None:
+                api_base = xai_api_base
+                if self.verbose:
+                    print(f"Using xAI API base URL: {api_base}")
+
+            # If no API base was specified but the environment variable is set, use it
+            if api_base is None:
+                api_base = os.getenv("AGENT_API_BASE")
+                if api_base and self.verbose:
+                    print(f"Using API base from AGENT_API_BASE: {api_base}")
+
+            self.model = OpenAIServerModel(
+                model_id=model_id or "gpt-4o",
+                api_key=api_key,
+                api_base=api_base,
+                temperature=temperature
+            )
+        else:
+            raise ValueError(f"Unknown model type: {model_type}")
+
+        if self.verbose:
+            print(f"Initialized model: {model_type} - {model_id}")
+
+        # Initialize the default tools
+        self.tools = [
+            DuckDuckGoSearchTool(),
+            PythonInterpreterTool(),
+            save_and_read_file,
+            download_file_from_url,
+            analyze_csv_file,
+            analyze_excel_file
+        ]
+
+        # Add extract_text_from_image if PIL and pytesseract are available
+        try:
+            import pytesseract
+            from PIL import Image
+            self.tools.append(extract_text_from_image)
+            if self.verbose:
+                print("Added image processing tool")
+        except ImportError:
+            if self.verbose:
+                print("Image processing libraries not available")
+
+        # Add any additional tools
+        if additional_tools:
+            self.tools.extend(additional_tools)
+
+        if self.verbose:
+            print(f"Initialized with {len(self.tools)} tools")
+
+        # Set up the allowed imports
+        self.imports = ["pandas", "numpy", "datetime", "json", "re", "math", "os", "requests", "csv", "urllib"]
+        if additional_imports:
+            self.imports.extend(additional_imports)
+
+        # Initialize the CodeAgent
+        executor_kwargs = {}
+        if executor_type == "e2b":
+            try:
+                # Import the e2b dependencies to check that they are available
+                from e2b_code_interpreter import Sandbox
+                if self.verbose:
+                    print("Using e2b executor")
+            except ImportError:
+                if self.verbose:
+                    print("e2b dependencies not found, falling back to local executor")
+                executor_type = "local"  # Fall back to local if e2b is not available
+
+        self.agent = CodeAgent(
+            tools=self.tools,
+            model=self.model,
+            additional_authorized_imports=self.imports,
+            executor_type=executor_type,
+            executor_kwargs=executor_kwargs,
+            verbosity_level=2 if self.verbose else 0
+        )
+
+        if self.verbose:
+            print("Agent initialized and ready")
+
+    def answer_question(self, question: str, task_file_path: Optional[str] = None) -> str:
+        """
+        Process a GAIA benchmark question and return the answer.
+
+        Args:
+            question: The question to answer
+            task_file_path: Optional path to a file associated with the question
+
+        Returns:
+            The answer to the question
+        """
+        try:
+            if self.verbose:
+                print(f"Processing question: {question}")
+                if task_file_path:
+                    print(f"With associated file: {task_file_path}")
+
+            # Create a context with file information if available
+            context = question
+            file_content = None
+
+            # If there is a file, read it and include its content in the context
+            if task_file_path:
+                try:
+                    with open(task_file_path, 'r') as f:
+                        file_content = f.read()
+
+                    # Determine the file type from the extension
+                    file_ext = os.path.splitext(task_file_path)[1].lower()
+
+                    context = f"""
+Question: {question}
+This question has an associated file. Here is the file content:
+```{file_ext}
+{file_content}
+```
+Analyze the file content above to answer the question.
+"""
+                except Exception as file_e:
+                    context = f"""
+Question: {question}
+This question has an associated file at path: {task_file_path}
+However, there was an error reading the file: {file_e}
+You can still try to answer the question based on the information provided.
+"""
+
+            # Check for special cases that need specific formatting,
+            # e.g. reversed-text questions
+            if question.startswith(".") or ".rewsna eht sa" in question:
+                context = f"""
+This question appears to be in reversed text. Here's the reversed version:
+{question[::-1]}
+Now answer the question above. Remember to format your answer exactly as requested.
+"""
+
+            # Add a prompt to ensure precise answers
+            full_prompt = f"""{context}
+When answering, provide ONLY the precise answer requested.
+Do not include explanations, steps, reasoning, or additional text.
+Be direct and specific. The GAIA benchmark requires exact-match answers.
+For example, if asked "What is the capital of France?", respond simply with "Paris".
+"""
+
+            # Run the agent with the question
+            answer = self.agent.run(full_prompt)
+
+            # Clean up the answer to ensure it is in the expected format:
+            # remove common prefixes that models often add
+            answer = self._clean_answer(answer)
+
+            if self.verbose:
+                print(f"Generated answer: {answer}")
+
+            return answer
+        except Exception as e:
+            error_msg = f"Error answering question: {e}"
+            if self.verbose:
+                print(error_msg)
+            return error_msg
+
+    def _clean_answer(self, answer: Any) -> str:
+        """
+        Clean up the answer to remove common prefixes and formatting
+        that models often add but that can cause exact-match failures.
+
+        Args:
+            answer: The raw answer from the model
+
+        Returns:
+            The cleaned answer as a string
+        """
+        # Convert non-string types to strings
+        if not isinstance(answer, str):
+            # Handle numeric types (float, int)
+            if isinstance(answer, float):
+                # Format floating-point numbers properly:
+                # check whether it is an integer value in float form (e.g., 12.0)
+                if answer.is_integer():
+                    formatted_answer = str(int(answer))
+                else:
+                    # For currency values that might need formatting
+                    if abs(answer) >= 1000:
+                        formatted_answer = f"${answer:,.2f}"
+                    else:
+                        formatted_answer = str(answer)
+                return formatted_answer
+            elif isinstance(answer, int):
+                return str(answer)
+            else:
+                # For any other type
+                return str(answer)
+
+        # The answer is now known to be a string, so string methods are safe.
+        # Normalize whitespace
+        answer = answer.strip()
+
+        # Remove common prefixes and formatting that models add
+        prefixes_to_remove = [
+            "The answer is ",
+            "Answer: ",
+            "Final answer: ",
+            "The result is ",
+            "To answer this question: ",
+            "Based on the information provided, ",
+            "According to the information: ",
+        ]
+
+        for prefix in prefixes_to_remove:
+            if answer.startswith(prefix):
+                answer = answer[len(prefix):].strip()
+
+        # Remove quotes if they wrap the entire answer
+        if (answer.startswith('"') and answer.endswith('"')) or (answer.startswith("'") and answer.endswith("'")):
+            answer = answer[1:-1].strip()
+
+        return answer
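
For reference, a minimal usage sketch of the class this commit adds. It assumes `core_agent.py` is on the import path and that `HF_API_TOKEN` is set in the environment; the question and default model ID are placeholders, not part of the commit:

```python
# Hypothetical usage sketch for GAIAAgent, assuming HF_API_TOKEN is exported.
from core_agent import GAIAAgent

agent = GAIAAgent(
    model_type="HfApiModel",   # or "OpenAIServerModel", "GeminiModel", ...
    model_id=None,             # falls back to meta-llama/Llama-3-70B-Instruct
    temperature=0.2,
    executor_type="local",     # "e2b" falls back to local if unavailable
    verbose=True,
)

# An optional task_file_path would be inlined into the prompt before the run.
answer = agent.answer_question("What is the capital of France?")
print(answer)  # after _clean_answer, ideally just "Paris"
```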