Sanyam0605 commited on
Commit
7c6bfd3
·
verified ·
1 Parent(s): d7f1382

Delete core_agent.py

Browse files
Files changed (1) hide show
  1. core_agent.py +0 -556
core_agent.py DELETED
@@ -1,556 +0,0 @@
1
- from smolagents import (
2
- CodeAgent,
3
- DuckDuckGoSearchTool,
4
- HfApiModel,
5
- OpenAIServerModel,
6
- PythonInterpreterTool,
7
- tool,
8
- InferenceClientModel,
9
- TransformersModel
10
- )
11
- from typing import List, Dict, Any, Optional
12
- import os
13
- import tempfile
14
- import re
15
- import json
16
- import requests
17
- from urllib.parse import urlparse
18
- import google.generativeai as genai
19
-
20
class GeminiResponseWrapper:
    """Adapter exposing a Gemini reply under the attributes smolagents expects.

    smolagents reads ``.content`` (and calls ``str.strip()`` on it, so it must
    be a plain string), while Gemini-style code reads ``.text``.  Both names
    point at the same string.
    """

    def __init__(self, text):
        # One assignment, two attribute names — same underlying string object.
        self.text = self.content = text
26
-
27
class GeminiModel:
    """Minimal smolagents-compatible chat model backed by Google's Gemini API."""

    def __init__(self, api_key, model_id="gemini-2.0-flash-lite", temperature=0.2):
        # The SDK is configured globally; we then hold a handle to one model.
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name=model_id)
        self.temperature = temperature

    @staticmethod
    def _build_prompt(content):
        """Flatten a smolagents message ``content`` field into a single string.

        List content concatenates each ``{"text": ...}`` dict or bare string
        (one per line); anything else is stringified as-is.
        """
        if not isinstance(content, list):
            return str(content)
        pieces = []
        for item in content:
            if isinstance(item, dict) and "text" in item:
                pieces.append(item["text"] + "\n")
            elif isinstance(item, str):
                pieces.append(item + "\n")
        return "".join(pieces)

    def __call__(self, messages, stop_sequences=None, **kwargs):
        """Generate a completion for the latest message.

        Only the last message is sent to Gemini (it carries the full context).
        Never raises: any failure is returned as a wrapped error string so the
        agent loop keeps running.
        """
        try:
            if not messages:
                return GeminiResponseWrapper("No messages provided")

            latest = messages[-1]
            if "content" not in latest:
                return GeminiResponseWrapper("Invalid message format")

            response = self.model.generate_content(self._build_prompt(latest["content"]))

            # Some SDK responses have no ``.text`` (e.g. blocked prompts);
            # fall back to their string form.
            reply = response.text if hasattr(response, 'text') else str(response)
            return GeminiResponseWrapper(reply)
        except Exception as e:
            return GeminiResponseWrapper(f"Error: {str(e)}")
70
-
71
@tool
def save_and_read_file(content: str, filename: Optional[str] = None) -> str:
    """
    Save content to a temporary file and return the path.
    Useful for processing files from the GAIA API.

    Args:
        content: The content to save to the file
        filename: Optional filename, will generate a random name if not provided

    Returns:
        Path to the saved file
    """
    temp_dir = tempfile.gettempdir()
    if filename is None:
        # BUGFIX: use mkstemp and close the descriptor immediately.  The
        # previous NamedTemporaryFile(delete=False) left its handle open,
        # leaking it and breaking the re-open for writing on Windows.
        fd, filepath = tempfile.mkstemp(dir=temp_dir)
        os.close(fd)
    else:
        filepath = os.path.join(temp_dir, filename)

    # Explicit UTF-8 so behaviour does not depend on the platform's
    # default locale encoding.
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(content)

    return f"File saved to {filepath}. You can read this file to process its contents."
96
-
97
@tool
def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
    """
    Download a file from a URL and save it to a temporary location.

    Args:
        url: The URL to download from
        filename: Optional filename, will generate one based on URL if not provided

    Returns:
        Path to the downloaded file (or an error message on failure)
    """
    try:
        # Derive a filename from the URL path when none was given.
        if not filename:
            path = urlparse(url).path
            filename = os.path.basename(path)
            if not filename:
                # URL has no usable path component — generate a random name.
                import uuid
                filename = f"downloaded_{uuid.uuid4().hex[:8]}"

        temp_dir = tempfile.gettempdir()
        filepath = os.path.join(temp_dir, filename)

        # BUGFIX: add a timeout so a stalled server cannot hang the agent
        # forever, and use the response as a context manager so the
        # connection is released when done.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()

            # Stream to disk in chunks to keep memory bounded.
            with open(filepath, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

        return f"File downloaded to {filepath}. You can now process this file."
    except Exception as e:
        return f"Error downloading file: {str(e)}"
135
-
136
@tool
def extract_text_from_image(image_path: str) -> str:
    """
    Extract text from an image using pytesseract (if available).

    Args:
        image_path: Path to the image file

    Returns:
        Extracted text or error message
    """
    try:
        # OCR dependencies are optional, so import them lazily here.
        import pytesseract
        from PIL import Image

        ocr_text = pytesseract.image_to_string(Image.open(image_path))
        return f"Extracted text from image:\n\n{ocr_text}"
    except ImportError:
        return "Error: pytesseract is not installed. Please install it with 'pip install pytesseract' and ensure Tesseract OCR is installed on your system."
    except Exception as e:
        return f"Error extracting text from image: {str(e)}"
163
-
164
@tool
def analyze_csv_file(file_path: str, query: str) -> str:
    """
    Analyze a CSV file using pandas and answer a question about it.

    Args:
        file_path: Path to the CSV file
        query: Question about the data

    Returns:
        Analysis result or error message
    """
    try:
        import pandas as pd

        frame = pd.read_csv(file_path)

        # NOTE(review): ``query`` is accepted but not consulted — the tool
        # always emits the same overview (shape, columns, describe()).
        report_lines = [
            f"CSV file loaded with {len(frame)} rows and {len(frame.columns)} columns.",
            f"Columns: {', '.join(frame.columns)}",
            "",
            "Summary statistics:",
            str(frame.describe()),
        ]
        return "\n".join(report_lines)
    except ImportError:
        return "Error: pandas is not installed. Please install it with 'pip install pandas'."
    except Exception as e:
        return f"Error analyzing CSV file: {str(e)}"
195
-
196
@tool
def analyze_excel_file(file_path: str, query: str) -> str:
    """
    Analyze an Excel file using pandas and answer a question about it.

    Args:
        file_path: Path to the Excel file
        query: Question about the data

    Returns:
        Analysis result or error message
    """
    try:
        import pandas as pd

        sheet = pd.read_excel(file_path)

        # NOTE(review): ``query`` is accepted but not consulted — the tool
        # always emits the same overview (shape, columns, describe()).
        report = (
            f"Excel file loaded with {len(sheet)} rows and {len(sheet.columns)} columns.\n"
            f"Columns: {', '.join(sheet.columns)}\n\n"
            "Summary statistics:\n" + str(sheet.describe())
        )
        return report
    except ImportError:
        # pd.read_excel itself raises ImportError when openpyxl is missing,
        # which is why the message mentions both packages.
        return "Error: pandas and openpyxl are not installed. Please install them with 'pip install pandas openpyxl'."
    except Exception as e:
        return f"Error analyzing Excel file: {str(e)}"
227
-
228
class GAIAAgent:
    """Tool-using agent for the GAIA benchmark, built on smolagents' CodeAgent.

    Selects and configures one of several chat-model backends, assembles a
    default tool set (web search, Python interpreter, file helpers, optional
    OCR), and exposes :meth:`answer_question` for GAIA-style questions.
    """

    def __init__(
        self,
        model_type: str = "HfApiModel",
        model_id: Optional[str] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        temperature: float = 0.2,
        executor_type: str = "local",  # 'local' or 'e2b'
        additional_imports: Optional[List[str]] = None,
        additional_tools: Optional[List[Any]] = None,
        system_prompt: Optional[str] = None,  # accepted for backward compatibility; not forwarded
        verbose: bool = False,
        provider: Optional[str] = None,  # provider for InferenceClientModel
        timeout: Optional[int] = None  # timeout in seconds for API calls
    ):
        """
        Initialize a GAIAAgent with specified configuration.

        Args:
            model_type: Model backend: HfApiModel, InferenceClientModel,
                LiteLLMModel, TransformersModel, GeminiModel or OpenAIServerModel
            model_id: ID of the model to use
            api_key: API key for the model provider
            api_base: Base URL for API calls
            temperature: Temperature for text generation
            executor_type: Type of executor for code execution ('local' or 'e2b')
            additional_imports: Additional Python modules the agent's generated
                code is allowed to import
            additional_tools: Additional tools to provide to the agent
            system_prompt: Custom system prompt (stored but not used directly;
                kept for backward compatibility)
            verbose: Enable verbose logging
            provider: Provider for InferenceClientModel (e.g., "hf-inference")
            timeout: Timeout in seconds for API calls

        Raises:
            ValueError: If required credentials are missing or model_type is unknown.
        """
        self.verbose = verbose
        self.system_prompt = system_prompt  # stored for potential future use

        # ---- Model backend selection ----
        if model_type == "HfApiModel":
            if api_key is None:
                api_key = os.getenv("HF_API_TOKEN")
                if not api_key:
                    raise ValueError("No Hugging Face token provided. Please set HF_API_TOKEN environment variable or pass api_key parameter.")

            if self.verbose:
                print(f"Using Hugging Face token: {api_key[:5]}...")

            self.model = HfApiModel(
                model_id=model_id or "meta-llama/Llama-3-70B-Instruct",
                token=api_key,
                temperature=temperature
            )
        elif model_type == "InferenceClientModel":
            if api_key is None:
                api_key = os.getenv("HF_API_TOKEN")
                if not api_key:
                    # BUGFIX: message previously named HUGGINGFACEHUB_API_TOKEN,
                    # but the variable actually read above is HF_API_TOKEN.
                    raise ValueError("No Hugging Face token provided. Please set HF_API_TOKEN environment variable or pass api_key parameter.")

            if self.verbose:
                print(f"Using Hugging Face token: {api_key[:5]}...")

            self.model = InferenceClientModel(
                model_id=model_id or "meta-llama/Llama-3-70B-Instruct",
                provider=provider or "hf-inference",
                token=api_key,
                timeout=timeout or 120,
                temperature=temperature
            )
        elif model_type == "LiteLLMModel":
            # Imported lazily: LiteLLM is optional.
            from smolagents import LiteLLMModel
            self.model = LiteLLMModel(
                model_id=model_id or "gpt-4o",
                api_key=api_key or os.getenv("OPENAI_API_KEY"),
                temperature=temperature
            )
        elif model_type == "TransformersModel":
            if not model_id:
                raise ValueError("model_id is required for TransformersModel")
            self.model = TransformersModel(
                model_id=model_id,
                device_map="auto",  # let transformers decide device placement
                trust_remote_code=True,  # needed for some models
                temperature=temperature
            )
        elif model_type == "GeminiModel":
            api_key = api_key or os.getenv("GOOGLE_API_KEY")
            if not api_key:
                raise ValueError("No Google API key provided. Please set GOOGLE_API_KEY environment variable or pass api_key parameter.")
            self.model = GeminiModel(
                api_key=api_key,
                model_id=model_id or "gemini-pro",
                temperature=temperature
            )
        elif model_type == "OpenAIServerModel":
            # xAI exposes an OpenAI-compatible API; prefer its credentials
            # when present and no explicit key/base was passed in.
            xai_api_key = os.getenv("XAI_API_KEY")
            xai_api_base = os.getenv("XAI_API_BASE")

            if xai_api_key and api_key is None:
                api_key = xai_api_key
                if self.verbose:
                    print(f"Using xAI API key: {api_key[:5]}...")

            # Fall back to OPENAI_API_KEY when nothing else was supplied.
            if api_key is None:
                api_key = os.getenv("OPENAI_API_KEY")
                if not api_key:
                    raise ValueError("No OpenAI API key provided. Please set OPENAI_API_KEY or XAI_API_KEY environment variable or pass api_key parameter.")

            if xai_api_base and api_base is None:
                api_base = xai_api_base
                if self.verbose:
                    print(f"Using xAI API base URL: {api_base}")

            if api_base is None:
                api_base = os.getenv("AGENT_API_BASE")
                if api_base and self.verbose:
                    print(f"Using API base from AGENT_API_BASE: {api_base}")

            self.model = OpenAIServerModel(
                model_id=model_id or "gpt-4o",
                api_key=api_key,
                api_base=api_base,
                temperature=temperature
            )
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        if self.verbose:
            print(f"Initialized model: {model_type} - {model_id}")

        # ---- Default tool set ----
        self.tools = [
            DuckDuckGoSearchTool(),
            PythonInterpreterTool(),
            save_and_read_file,
            download_file_from_url,
            analyze_csv_file,
            analyze_excel_file
        ]

        # The OCR tool is registered only if its optional deps are importable.
        try:
            import pytesseract  # noqa: F401
            from PIL import Image  # noqa: F401
            self.tools.append(extract_text_from_image)
            if self.verbose:
                print("Added image processing tool")
        except ImportError:
            if self.verbose:
                print("Image processing libraries not available")

        if additional_tools:
            self.tools.extend(additional_tools)

        if self.verbose:
            print(f"Initialized with {len(self.tools)} tools")

        # Modules the agent's generated code is allowed to import.
        self.imports = ["pandas", "numpy", "datetime", "json", "re", "math", "os", "requests", "csv", "urllib"]
        if additional_imports:
            self.imports.extend(additional_imports)

        # ---- Code executor ----
        executor_kwargs = {}
        if executor_type == "e2b":
            try:
                # Probe for e2b availability before committing to it.
                from e2b_code_interpreter import Sandbox  # noqa: F401
                if self.verbose:
                    print("Using e2b executor")
            except ImportError:
                if self.verbose:
                    print("e2b dependencies not found, falling back to local executor")
                executor_type = "local"  # graceful fallback when e2b is absent

        self.agent = CodeAgent(
            tools=self.tools,
            model=self.model,
            additional_authorized_imports=self.imports,
            executor_type=executor_type,
            executor_kwargs=executor_kwargs,
            verbosity_level=2 if self.verbose else 0
        )

        if self.verbose:
            print("Agent initialized and ready")

    def answer_question(self, question: str, task_file_path: Optional[str] = None) -> str:
        """
        Process a GAIA benchmark question and return the answer.

        Args:
            question: The question to answer
            task_file_path: Optional path to a file associated with the question

        Returns:
            The answer to the question, or an error message on failure.
        """
        try:
            if self.verbose:
                print(f"Processing question: {question}")
                if task_file_path:
                    print(f"With associated file: {task_file_path}")

            context = question
            file_content = None

            # Inline the associated file's content into the prompt when present.
            if task_file_path:
                try:
                    with open(task_file_path, 'r') as f:
                        file_content = f.read()

                    # Tag the fenced block with the file extension as a format hint.
                    file_ext = os.path.splitext(task_file_path)[1].lower()

                    context = f"""
Question: {question}
This question has an associated file. Here is the file content:
```{file_ext}
{file_content}
```
Analyze the file content above to answer the question.
"""
                except Exception as file_e:
                    # Reading failed (binary file, permissions, ...): tell the
                    # model and let it answer from the question alone.
                    context = f"""
Question: {question}
This question has an associated file at path: {task_file_path}
However, there was an error reading the file: {file_e}
You can still try to answer the question based on the information provided.
"""

            # GAIA includes reversed-text questions; present the un-reversed form.
            if question.startswith(".") or ".rewsna eht sa" in question:
                context = f"""
This question appears to be in reversed text. Here's the reversed version:
{question[::-1]}
Now answer the question above. Remember to format your answer exactly as requested.
"""

            # GAIA scoring is exact-match, so push the model toward bare answers.
            full_prompt = f"""{context}
When answering, provide ONLY the precise answer requested.
Do not include explanations, steps, reasoning, or additional text.
Be direct and specific. GAIA benchmark requires exact matching answers.
For example, if asked "What is the capital of France?", respond simply with "Paris".
"""

            answer = self.agent.run(full_prompt)

            # Strip prefixes/quotes that would break exact-match scoring.
            answer = self._clean_answer(answer)

            if self.verbose:
                print(f"Generated answer: {answer}")

            return answer
        except Exception as e:
            error_msg = f"Error answering question: {e}"
            if self.verbose:
                print(error_msg)
            return error_msg

    def _clean_answer(self, answer: Any) -> str:
        """
        Clean up the answer to remove common prefixes and formatting
        that models often add but that can cause exact match failures.

        Args:
            answer: The raw answer from the model (any type)

        Returns:
            The cleaned answer as a string
        """
        # BUGFIX: annotation was the builtin ``any`` (a function), not
        # ``typing.Any``.  Logic below is unchanged.
        # Normalize non-string results to strings first.
        if not isinstance(answer, str):
            if isinstance(answer, float):
                if answer.is_integer():
                    # e.g. 12.0 -> "12"
                    return str(int(answer))
                # Heuristic: large floats are formatted as currency.
                # NOTE(review): this prefixes "$" to any |value| >= 1000 —
                # confirm that matches the expected GAIA answer format.
                if abs(answer) >= 1000:
                    return f"${answer:,.2f}"
                return str(answer)
            elif isinstance(answer, int):
                return str(answer)
            else:
                return str(answer)

        answer = answer.strip()

        # Remove boilerplate prefixes models commonly prepend.
        prefixes_to_remove = [
            "The answer is ",
            "Answer: ",
            "Final answer: ",
            "The result is ",
            "To answer this question: ",
            "Based on the information provided, ",
            "According to the information: ",
        ]

        for prefix in prefixes_to_remove:
            if answer.startswith(prefix):
                answer = answer[len(prefix):].strip()

        # Strip one level of matching surrounding quotes.
        if (answer.startswith('"') and answer.endswith('"')) or (answer.startswith("'") and answer.endswith("'")):
            answer = answer[1:-1].strip()

        return answer