rakesh-dvg commited on
Commit
504fce3
·
verified ·
1 Parent(s): f56eacc

Delete main_agent.py

Browse files
Files changed (1) hide show
  1. main_agent.py +0 -492
main_agent.py DELETED
@@ -1,492 +0,0 @@
1
- from smolagents import (
2
- CodeAgent,
3
- DuckDuckGoSearchTool,
4
- HfApiModel,
5
- LiteLLMModel,
6
- OpenAIServerModel,
7
- PythonInterpreterTool,
8
- tool,
9
- InferenceClientModel
10
- )
11
- from typing import List, Dict, Any, Optional
12
- import os
13
- import tempfile
14
- import re
15
- import json
16
- import requests
17
- from urllib.parse import urlparse
18
-
19
- @tool
20
- def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
21
- """
22
- Save content to a temporary file and return the path.
23
- Useful for processing files from the GAIA API.
24
-
25
- Args:
26
- content: The content to save to the file
27
- filename: Optional filename, will generate a random name if not provided
28
-
29
- Returns:
30
- Path to the saved file
31
- """
32
- temp_dir = tempfile.gettempdir()
33
- if filename is None:
34
- temp_file = tempfile.NamedTemporaryFile(delete=False)
35
- filepath = temp_file.name
36
- else:
37
- filepath = os.path.join(temp_dir, filename)
38
-
39
- # Write content to the file
40
- with open(filepath, 'w') as f:
41
- f.write(content)
42
-
43
- return f"File saved to {filepath}. You can read this file to process its contents."
44
-
45
- @tool
46
- def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
47
- """
48
- Download a file from a URL and save it to a temporary location.
49
-
50
- Args:
51
- url: The URL to download from
52
- filename: Optional filename, will generate one based on URL if not provided
53
-
54
- Returns:
55
- Path to the downloaded file
56
- """
57
- try:
58
- # Parse URL to get filename if not provided
59
- if not filename:
60
- path = urlparse(url).path
61
- filename = os.path.basename(path)
62
- if not filename:
63
- # Generate a random name if we couldn't extract one
64
- import uuid
65
- filename = f"downloaded_{uuid.uuid4().hex[:8]}"
66
-
67
- # Create temporary file
68
- temp_dir = tempfile.gettempdir()
69
- filepath = os.path.join(temp_dir, filename)
70
-
71
- # Download the file
72
- response = requests.get(url, stream=True)
73
- response.raise_for_status()
74
-
75
- # Save the file
76
- with open(filepath, 'wb') as f:
77
- for chunk in response.iter_content(chunk_size=8192):
78
- f.write(chunk)
79
-
80
- return f"File downloaded to {filepath}. You can now process this file."
81
- except Exception as e:
82
- return f"Error downloading file: {str(e)}"
83
-
84
- @tool
85
- def extract_text_from_image(image_path: str) -> str:
86
- """
87
- Extract text from an image using pytesseract (if available).
88
-
89
- Args:
90
- image_path: Path to the image file
91
-
92
- Returns:
93
- Extracted text or error message
94
- """
95
- try:
96
- # Try to import pytesseract
97
- import pytesseract
98
- from PIL import Image
99
-
100
- # Open the image
101
- image = Image.open(image_path)
102
-
103
- # Extract text
104
- text = pytesseract.image_to_string(image)
105
-
106
- return f"Extracted text from image:\n\n{text}"
107
- except ImportError:
108
- return "Error: pytesseract is not installed. Please install it with 'pip install pytesseract' and ensure Tesseract OCR is installed on your system."
109
- except Exception as e:
110
- return f"Error extracting text from image: {str(e)}"
111
-
112
- @tool
113
- def analyze_csv_file(file_path: str, query: str) -> str:
114
- """
115
- Analyze a CSV file using pandas and answer a question about it.
116
-
117
- Args:
118
- file_path: Path to the CSV file
119
- query: Question about the data
120
-
121
- Returns:
122
- Analysis result or error message
123
- """
124
- try:
125
- import pandas as pd
126
-
127
- # Read the CSV file
128
- df = pd.read_csv(file_path)
129
-
130
- # Run various analyses based on the query
131
- result = f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
132
- result += f"Columns: {', '.join(df.columns)}\n\n"
133
-
134
- # Add summary statistics
135
- result += "Summary statistics:\n"
136
- result += str(df.describe())
137
-
138
- return result
139
- except ImportError:
140
- return "Error: pandas is not installed. Please install it with 'pip install pandas'."
141
- except Exception as e:
142
- return f"Error analyzing CSV file: {str(e)}"
143
-
144
- @tool
145
- def analyze_excel_file(file_path: str, query: str) -> str:
146
- """
147
- Analyze an Excel file using pandas and answer a question about it.
148
-
149
- Args:
150
- file_path: Path to the Excel file
151
- query: Question about the data
152
-
153
- Returns:
154
- Analysis result or error message
155
- """
156
- try:
157
- import pandas as pd
158
-
159
- # Read the Excel file
160
- df = pd.read_excel(file_path)
161
-
162
- # Run various analyses based on the query
163
- result = f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
164
- result += f"Columns: {', '.join(df.columns)}\n\n"
165
-
166
- # Add summary statistics
167
- result += "Summary statistics:\n"
168
- result += str(df.describe())
169
-
170
- return result
171
- except ImportError:
172
- return "Error: pandas and openpyxl are not installed. Please install them with 'pip install pandas openpyxl'."
173
- except Exception as e:
174
- return f"Error analyzing Excel file: {str(e)}"
175
-
176
- class GAIAAgent:
177
- def __init__(
178
- self,
179
- model_type: str = "HfApiModel",
180
- model_id: Optional[str] = None,
181
- api_key: Optional[str] = None,
182
- api_base: Optional[str] = None,
183
- temperature: float = 0.2,
184
- executor_type: str = "local", # Changed from use_e2b to executor_type
185
- additional_imports: List[str] = None,
186
- additional_tools: List[Any] = None,
187
- system_prompt: Optional[str] = None, # We'll still accept this parameter but not use it directly
188
- verbose: bool = False,
189
- provider: Optional[str] = None, # Add provider for InferenceClientModel
190
- timeout: Optional[int] = None # Add timeout for InferenceClientModel
191
- ):
192
- """
193
- Initialize a GAIAAgent with specified configuration
194
-
195
- Args:
196
- model_type: Type of model to use (HfApiModel, LiteLLMModel, OpenAIServerModel, InferenceClientModel)
197
- model_id: ID of the model to use
198
- api_key: API key for the model provider
199
- api_base: Base URL for API calls
200
- temperature: Temperature for text generation
201
- executor_type: Type of executor for code execution ('local' or 'e2b')
202
- additional_imports: Additional Python modules to allow importing
203
- additional_tools: Additional tools to provide to the agent
204
- system_prompt: Custom system prompt to use (not directly used, kept for backward compatibility)
205
- verbose: Enable verbose logging
206
- provider: Provider for InferenceClientModel (e.g., "hf-inference")
207
- timeout: Timeout in seconds for API calls
208
- """
209
- # Set verbosity
210
- self.verbose = verbose
211
- self.system_prompt = system_prompt # Store for potential future use
212
-
213
- # Initialize model based on configuration
214
- if model_type == "HfApiModel":
215
- if api_key is None:
216
- api_key = os.getenv("HUGGINGFACEHUB_API_TOKEN")
217
- if not api_key:
218
- raise ValueError("No Hugging Face token provided. Please set HUGGINGFACEHUB_API_TOKEN environment variable or pass api_key parameter.")
219
-
220
- if self.verbose:
221
- print(f"Using Hugging Face token: {api_key[:5]}...")
222
-
223
- self.model = HfApiModel(
224
- model_id=model_id or "meta-llama/Llama-3-70B-Instruct",
225
- token=api_key,
226
- temperature=temperature
227
- )
228
- elif model_type == "InferenceClientModel":
229
- if api_key is None:
230
- api_key = os.getenv("HUGGINGFACEHUB_API_TOKEN")
231
- if not api_key:
232
- raise ValueError("No Hugging Face token provided. Please set HUGGINGFACEHUB_API_TOKEN environment variable or pass api_key parameter.")
233
-
234
- if self.verbose:
235
- print(f"Using Hugging Face token: {api_key[:5]}...")
236
-
237
- self.model = InferenceClientModel(
238
- model_id=model_id or "meta-llama/Llama-3-70B-Instruct",
239
- provider=provider or "hf-inference",
240
- token=api_key,
241
- timeout=timeout or 120,
242
- temperature=temperature
243
- )
244
- elif model_type == "LiteLLMModel":
245
- from smolagents import LiteLLMModel
246
- self.model = LiteLLMModel(
247
- model_id=model_id or "gpt-4o",
248
- api_key=api_key or os.getenv("OPENAI_API_KEY"),
249
- temperature=temperature
250
- )
251
- elif model_type == "OpenAIServerModel":
252
- # Check for xAI API key and base URL first
253
- xai_api_key = os.getenv("XAI_API_KEY")
254
- xai_api_base = os.getenv("XAI_API_BASE")
255
-
256
- # If xAI credentials are available, use them
257
- if xai_api_key and api_key is None:
258
- api_key = xai_api_key
259
- if self.verbose:
260
- print(f"Using xAI API key: {api_key[:5]}...")
261
-
262
- # If no API key specified, fall back to OPENAI_API_KEY
263
- if api_key is None:
264
- api_key = os.getenv("OPENAI_API_KEY")
265
- if not api_key:
266
- raise ValueError("No OpenAI API key provided. Please set OPENAI_API_KEY or XAI_API_KEY environment variable or pass api_key parameter.")
267
-
268
- # If xAI API base is available and no api_base is provided, use it
269
- if xai_api_base and api_base is None:
270
- api_base = xai_api_base
271
- if self.verbose:
272
- print(f"Using xAI API base URL: {api_base}")
273
-
274
- # If no API base specified but environment variable available, use it
275
- if api_base is None:
276
- api_base = os.getenv("AGENT_API_BASE")
277
- if api_base and self.verbose:
278
- print(f"Using API base from AGENT_API_BASE: {api_base}")
279
-
280
- self.model = OpenAIServerModel(
281
- model_id=model_id or "gpt-4o",
282
- api_key=api_key,
283
- api_base=api_base,
284
- temperature=temperature
285
- )
286
- else:
287
- raise ValueError(f"Unknown model type: {model_type}")
288
-
289
- if self.verbose:
290
- print(f"Initialized model: {model_type} - {model_id}")
291
-
292
- # Initialize default tools
293
- self.tools = [
294
- DuckDuckGoSearchTool(),
295
- PythonInterpreterTool(),
296
- save_and_read_file,
297
- download_file_from_url,
298
- analyze_csv_file,
299
- analyze_excel_file
300
- ]
301
-
302
- # Add extract_text_from_image if PIL and pytesseract are available
303
- try:
304
- import pytesseract
305
- from PIL import Image
306
- self.tools.append(extract_text_from_image)
307
- if self.verbose:
308
- print("Added image processing tool")
309
- except ImportError:
310
- if self.verbose:
311
- print("Image processing libraries not available")
312
-
313
- # Add any additional tools
314
- if additional_tools:
315
- self.tools.extend(additional_tools)
316
-
317
- if self.verbose:
318
- print(f"Initialized with {len(self.tools)} tools")
319
-
320
- # Setup imports allowed
321
- self.imports = ["pandas", "numpy", "datetime", "json", "re", "math", "os", "requests", "csv", "urllib"]
322
- if additional_imports:
323
- self.imports.extend(additional_imports)
324
-
325
- # Initialize the CodeAgent
326
- executor_kwargs = {}
327
- if executor_type == "e2b":
328
- try:
329
- # Try to import e2b dependencies to check if they're available
330
- from e2b_code_interpreter import Sandbox
331
- if self.verbose:
332
- print("Using e2b executor")
333
- except ImportError:
334
- if self.verbose:
335
- print("e2b dependencies not found, falling back to local executor")
336
- executor_type = "local" # Fallback to local if e2b is not available
337
-
338
- self.agent = CodeAgent(
339
- tools=self.tools,
340
- model=self.model,
341
- additional_authorized_imports=self.imports,
342
- executor_type=executor_type,
343
- executor_kwargs=executor_kwargs,
344
- verbosity_level=2 if self.verbose else 0
345
- )
346
-
347
- if self.verbose:
348
- print("Agent initialized and ready")
349
-
350
- def answer_question(self, question: str, task_file_path: Optional[str] = None) -> str:
351
- """
352
- Process a GAIA benchmark question and return the answer
353
-
354
- Args:
355
- question: The question to answer
356
- task_file_path: Optional path to a file associated with the question
357
-
358
- Returns:
359
- The answer to the question
360
- """
361
- try:
362
- if self.verbose:
363
- print(f"Processing question: {question}")
364
- if task_file_path:
365
- print(f"With associated file: {task_file_path}")
366
-
367
- # Create a context with file information if available
368
- context = question
369
- file_content = None
370
-
371
- # If there's a file, read it and include its content in the context
372
- if task_file_path:
373
- try:
374
- with open(task_file_path, 'r') as f:
375
- file_content = f.read()
376
-
377
- # Determine file type from extension
378
- import os
379
- file_ext = os.path.splitext(task_file_path)[1].lower()
380
-
381
- context = f"""
382
- Question: {question}
383
-
384
- This question has an associated file. Here is the file content:
385
-
386
- ```{file_ext}
387
- {file_content}
388
- ```
389
-
390
- Analyze the file content above to answer the question.
391
- """
392
- except Exception as file_e:
393
- context = f"""
394
- Question: {question}
395
-
396
- This question has an associated file at path: {task_file_path}
397
- However, there was an error reading the file: {file_e}
398
- You can still try to answer the question based on the information provided.
399
- """
400
-
401
- # Check for special cases that need specific formatting
402
- # Reversed text questions
403
- if question.startswith(".") or ".rewsna eht sa" in question:
404
- context = f"""
405
- This question appears to be in reversed text. Here's the reversed version:
406
- {question[::-1]}
407
-
408
- Now answer the question above. Remember to format your answer exactly as requested.
409
- """
410
-
411
- # Add a prompt to ensure precise answers
412
- full_prompt = f"""{context}
413
-
414
- When answering, provide ONLY the precise answer requested.
415
- Do not include explanations, steps, reasoning, or additional text.
416
- Be direct and specific. GAIA benchmark requires exact matching answers.
417
- For example, if asked "What is the capital of France?", respond simply with "Paris".
418
- """
419
-
420
- # Run the agent with the question
421
- answer = self.agent.run(full_prompt)
422
-
423
- # Clean up the answer to ensure it's in the expected format
424
- # Remove common prefixes that models often add
425
- answer = self._clean_answer(answer)
426
-
427
- if self.verbose:
428
- print(f"Generated answer: {answer}")
429
-
430
- return answer
431
- except Exception as e:
432
- error_msg = f"Error answering question: {e}"
433
- if self.verbose:
434
- print(error_msg)
435
- return error_msg
436
-
437
- def _clean_answer(self, answer: any) -> str:
438
- """
439
- Clean up the answer to remove common prefixes and formatting
440
- that models often add but that can cause exact match failures.
441
-
442
- Args:
443
- answer: The raw answer from the model
444
-
445
- Returns:
446
- The cleaned answer as a string
447
- """
448
- # Convert non-string types to strings
449
- if not isinstance(answer, str):
450
- # Handle numeric types (float, int)
451
- if isinstance(answer, float):
452
- # Format floating point numbers properly
453
- # Check if it's an integer value in float form (e.g., 12.0)
454
- if answer.is_integer():
455
- formatted_answer = str(int(answer))
456
- else:
457
- # For currency values that might need formatting
458
- if abs(answer) >= 1000:
459
- formatted_answer = f"${answer:,.2f}"
460
- else:
461
- formatted_answer = str(answer)
462
- return formatted_answer
463
- elif isinstance(answer, int):
464
- return str(answer)
465
- else:
466
- # For any other type
467
- return str(answer)
468
-
469
- # Now we know answer is a string, so we can safely use string methods
470
- # Normalize whitespace
471
- answer = answer.strip()
472
-
473
- # Remove common prefixes and formatting that models add
474
- prefixes_to_remove = [
475
- "The answer is ",
476
- "Answer: ",
477
- "Final answer: ",
478
- "The result is ",
479
- "To answer this question: ",
480
- "Based on the information provided, ",
481
- "According to the information: ",
482
- ]
483
-
484
- for prefix in prefixes_to_remove:
485
- if answer.startswith(prefix):
486
- answer = answer[len(prefix):].strip()
487
-
488
- # Remove quotes if they wrap the entire answer
489
- if (answer.startswith('"') and answer.endswith('"')) or (answer.startswith("'") and answer.endswith("'")):
490
- answer = answer[1:-1].strip()
491
-
492
- return answer