LuisZermeno committed · Commit f64ef80 · verified · 1 Parent(s): a6d07ff

Update tools.py

Files changed (1):
  1. tools.py +124 -29
tools.py CHANGED
@@ -3,7 +3,7 @@ import re
 import json
 import base64
 import requests
-import wikipedia
+import wikipediaapi
 import numpy as np
 import pandas as pd
 from typing import Dict, Any, List, Optional, Union
@@ -22,6 +22,9 @@ import logging
 
 logger = logging.getLogger(__name__)
 
+# Initialize Wikipedia API
+wiki_wiki = wikipediaapi.Wikipedia('GAIA-Agent/1.0', 'en')
+
 # Tool implementations
 
 def web_search_tool(query: str, num_results: int = 5) -> str:
@@ -50,26 +53,41 @@ def web_search_tool(query: str, num_results: int = 5) -> str:
 def wikipedia_tool(query: str) -> str:
     """Search and get content from Wikipedia"""
     try:
-        # Try direct page first
-        try:
-            page = wikipedia.page(query)
-            return f"Title: {page.title}\n\nSummary: {page.summary[:1000]}...\n\nURL: {page.url}"
-        except wikipedia.exceptions.DisambiguationError as e:
-            # If ambiguous, try first option
-            if e.options:
-                page = wikipedia.page(e.options[0])
-                return f"Title: {page.title}\n\nSummary: {page.summary[:1000]}...\n\nURL: {page.url}"
-        except wikipedia.exceptions.PageError:
-            # If page not found, search
-            search_results = wikipedia.search(query, results=5)
-            if search_results:
-                page = wikipedia.page(search_results[0])
-                return f"Title: {page.title}\n\nSummary: {page.summary[:1000]}...\n\nURL: {page.url}"
+        # Try to get page directly
+        page = wiki_wiki.page(query)
+
+        if page.exists():
+            # Get summary (first 1000 characters)
+            summary = page.summary[:1000] if len(page.summary) > 1000 else page.summary
+            return f"Title: {page.title}\n\nSummary: {summary}...\n\nURL: {page.fullurl}"
+        else:
+            # Search for pages
+            from duckduckgo_search import DDGS
+            ddgs = DDGS()
+            search_query = f"site:wikipedia.org {query}"
+            results = list(ddgs.text(search_query, max_results=3))
+
+            if results:
+                # Try to extract Wikipedia page title from first result
+                first_result = results[0]
+                if 'wikipedia.org/wiki/' in first_result['link']:
+                    page_title = first_result['link'].split('/wiki/')[-1].replace('_', ' ')
+                    page = wiki_wiki.page(page_title)
+                    if page.exists():
+                        summary = page.summary[:1000] if len(page.summary) > 1000 else page.summary
+                        return f"Title: {page.title}\n\nSummary: {summary}...\n\nURL: {page.fullurl}"
+
+                # Return search results if can't get page
+                formatted_results = []
+                for result in results:
+                    formatted_results.append(f"- {result['title']}: {result['body'][:200]}...")
+                return "Wikipedia search results:\n" + "\n".join(formatted_results)
 
         return "No Wikipedia results found."
     except Exception as e:
         logger.error(f"Wikipedia error: {str(e)}")
-        return f"Wikipedia search failed: {str(e)}"
+        # Fallback to web search
+        return web_search_tool(f"site:wikipedia.org {query}", num_results=3)
 
 def calculator_tool(expression: str) -> str:
     """Evaluate mathematical expressions safely"""
@@ -87,9 +105,9 @@ def calculator_tool(expression: str) -> str:
         node = ast.parse(expression, mode='eval')
 
         # Safety check
-        for node in ast.walk(node):
-            if isinstance(node, ast.Name) and node.id not in allowed_names:
-                raise ValueError(f"Unsafe operation: {node.id}")
+        for n in ast.walk(node):
+            if isinstance(n, ast.Name) and n.id not in allowed_names:
+                raise ValueError(f"Unsafe operation: {n.id}")
 
         result = eval(compile(ast.parse(expression, mode='eval'), '<string>', 'eval'),
                       {"__builtins__": {}}, allowed_names)
@@ -135,12 +153,21 @@ def python_repl_tool(code: str) -> str:
 def image_analysis_tool(image_path: str, query: str = "") -> str:
     """Analyze images using OCR and basic computer vision"""
     try:
+        # Handle base64 encoded images
         if image_path.startswith('data:'):
-            # Handle base64 encoded images
             header, encoded = image_path.split(',', 1)
             data = base64.b64decode(encoded)
             image = Image.open(io.BytesIO(data))
         else:
+            # Check if file exists in uploaded files
+            uploaded_files = json.loads(os.environ.get("UPLOADED_FILES", "[]"))
+            if uploaded_files and not os.path.exists(image_path):
+                # Try to find the file in uploaded files
+                for file_path in uploaded_files:
+                    if os.path.basename(file_path) == os.path.basename(image_path):
+                        image_path = file_path
+                        break
+
             image = Image.open(image_path)
 
         # Perform OCR
@@ -173,6 +200,18 @@ def image_analysis_tool(image_path: str, query: str = "") -> str:
 def file_reader_tool(file_path: str, query: str = "") -> str:
     """Read and analyze various file types"""
     try:
+        # Check uploaded files
+        uploaded_files = json.loads(os.environ.get("UPLOADED_FILES", "[]"))
+        if uploaded_files and not os.path.exists(file_path):
+            # Try to find the file in uploaded files
+            for uploaded_path in uploaded_files:
+                if os.path.basename(uploaded_path) == os.path.basename(file_path):
+                    file_path = uploaded_path
+                    break
+
+        if not os.path.exists(file_path):
+            return f"File not found: {file_path}"
+
         file_ext = os.path.splitext(file_path)[1].lower()
 
         if file_ext in ['.txt', '.md', '.py', '.json', '.xml', '.html']:
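The UPLOADED_FILES lookup added here (and repeated in the image, audio and data tools) assumes the host app exports the uploaded paths as a JSON list in that environment variable. A minimal sketch of the resolution, with a hypothetical path:

import json
import os

os.environ["UPLOADED_FILES"] = json.dumps(["/tmp/uploads/report.csv"])  # hypothetical setup

requested = "report.csv"  # bare name passed by the caller
uploaded_files = json.loads(os.environ.get("UPLOADED_FILES", "[]"))
if uploaded_files and not os.path.exists(requested):
    for uploaded_path in uploaded_files:
        if os.path.basename(uploaded_path) == os.path.basename(requested):
            requested = uploaded_path  # resolved to /tmp/uploads/report.csv
            break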
@@ -181,11 +220,37 @@ def file_reader_tool(file_path: str, query: str = "") -> str:
             return f"File content:\n{content[:2000]}{'...' if len(content) > 2000 else ''}"
 
         elif file_ext in ['.csv']:
-            df = pd.read_csv(file_path)
+            # Try multiple encodings and delimiters
+            encodings = ['utf-8', 'latin1', 'iso-8859-1', 'cp1252']
+            delimiters = [',', ';', '\t', '|']
+
+            df = None
+            for encoding in encodings:
+                for delimiter in delimiters:
+                    try:
+                        df = pd.read_csv(file_path, encoding=encoding, delimiter=delimiter)
+                        if len(df.columns) > 1:  # Successful parse
+                            break
+                    except:
+                        continue
+                if df is not None and len(df.columns) > 1:
+                    break
+
+            if df is None:
+                return "Failed to parse CSV file with multiple encoding/delimiter attempts"
+
             info = f"CSV file with {len(df)} rows and {len(df.columns)} columns.\n"
             info += f"Columns: {', '.join(df.columns)}\n\n"
             info += f"First 5 rows:\n{df.head().to_string()}\n\n"
             info += f"Data types:\n{df.dtypes.to_string()}"
+
+            # Check for date columns and analyze if query mentions time
+            if query and any(word in query.lower() for word in ['month', 'year', 'date', 'january', 'february', 'march', 'april', 'may', 'june', 'july', 'august', 'september', 'october', 'november', 'december']):
+                from search_strategies import DataAnalysisStrategy
+                temporal_result = DataAnalysisStrategy.analyze_for_temporal_data(df, query)
+                if temporal_result is not None:
+                    info += f"\n\nTemporal analysis result:\n{temporal_result.head(10).to_string()}"
+
             return info
 
         elif file_ext in ['.xlsx', '.xls']:
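The CSV fallback loop above can be read as a small helper; a sketch under the same assumptions (the first encoding/delimiter pair that yields more than one column wins):

from typing import Optional
import pandas as pd

def read_csv_flexibly(path: str) -> Optional[pd.DataFrame]:
    for encoding in ('utf-8', 'latin1', 'iso-8859-1', 'cp1252'):
        for delimiter in (',', ';', '\t', '|'):
            try:
                df = pd.read_csv(path, encoding=encoding, delimiter=delimiter)
            except Exception:
                continue
            if len(df.columns) > 1:  # treat a multi-column parse as success
                return df
    return None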
@@ -213,6 +278,14 @@ def audio_analysis_tool(audio_path: str) -> str:
     try:
         recognizer = sr.Recognizer()
 
+        # Check uploaded files
+        uploaded_files = json.loads(os.environ.get("UPLOADED_FILES", "[]"))
+        if uploaded_files and not os.path.exists(audio_path):
+            for uploaded_path in uploaded_files:
+                if os.path.basename(uploaded_path) == os.path.basename(audio_path):
+                    audio_path = uploaded_path
+                    break
+
         # Convert to WAV if needed
         if not audio_path.endswith('.wav'):
             audio = AudioSegment.from_file(audio_path)
@@ -235,7 +308,7 @@ def audio_analysis_tool(audio_path: str) -> str:
             result = f"Speech recognition error: {str(e)}"
 
         # Clean up temp file
-        if wav_path != audio_path:
+        if wav_path != audio_path and os.path.exists(wav_path):
             os.unlink(wav_path)
 
         return result
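A sketch of the conversion-plus-cleanup flow that the added existence check protects, assuming pydub (with ffmpeg) is available; 'clip.mp3' is only illustrative:

import os
import tempfile
from pydub import AudioSegment

def ensure_wav(audio_path: str) -> str:
    # Convert to WAV only when needed, mirroring the tool
    if audio_path.endswith('.wav'):
        return audio_path
    wav_path = tempfile.mktemp(suffix='.wav')
    AudioSegment.from_file(audio_path).export(wav_path, format='wav')
    return wav_path

wav_path = ensure_wav('clip.mp3')
# ... speech recognition runs here ...
if wav_path != 'clip.mp3' and os.path.exists(wav_path):
    os.unlink(wav_path)  # the os.path.exists guard avoids a FileNotFoundError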
@@ -247,6 +320,14 @@ def audio_analysis_tool(audio_path: str) -> str:
 def data_analysis_tool(file_path: str, operation: str, **kwargs) -> str:
     """Perform data analysis operations on CSV/Excel files"""
     try:
+        # Check uploaded files
+        uploaded_files = json.loads(os.environ.get("UPLOADED_FILES", "[]"))
+        if uploaded_files and not os.path.exists(file_path):
+            for uploaded_path in uploaded_files:
+                if os.path.basename(uploaded_path) == os.path.basename(file_path):
+                    file_path = uploaded_path
+                    break
+
         # Load data
         if file_path.endswith('.csv'):
             df = pd.read_csv(file_path)
@@ -256,22 +337,29 @@ def data_analysis_tool(file_path: str, operation: str, **kwargs) -> str:
         # Perform requested operation
         if operation == "sum":
             column = kwargs.get('column')
-            if column:
+            if column and column in df.columns:
                 result = df[column].sum()
                 return f"Sum of {column}: {result}"
+            return f"Column '{column}' not found"
 
         elif operation == "mean":
             column = kwargs.get('column')
-            if column:
+            if column and column in df.columns:
                 result = df[column].mean()
                 return f"Mean of {column}: {result}"
+            return f"Column '{column}' not found"
 
         elif operation == "count":
             column = kwargs.get('column')
             value = kwargs.get('value')
-            if column and value:
-                result = len(df[df[column] == value])
-                return f"Count of {column}={value}: {result}"
+            if column and column in df.columns:
+                if value:
+                    result = len(df[df[column] == value])
+                    return f"Count of {column}={value}: {result}"
+                else:
+                    result = df[column].value_counts()
+                    return f"Value counts for {column}:\n{result.to_string()}"
+            return f"Column '{column}' not found"
 
         elif operation == "groupby":
             group_column = kwargs.get('group_column')
@@ -280,16 +368,23 @@ def data_analysis_tool(file_path: str, operation: str, **kwargs) -> str:
             if group_column and agg_column:
                 result = df.groupby(group_column)[agg_column].agg(agg_func)
                 return f"Grouped results:\n{result.to_string()}"
+            return "Missing group_column or agg_column"
 
         elif operation == "filter":
             condition = kwargs.get('condition')
             if condition:
                 filtered_df = df.query(condition)
                 return f"Filtered data ({len(filtered_df)} rows):\n{filtered_df.head().to_string()}"
+            return "Missing filter condition"
 
         elif operation == "describe":
             return f"Data description:\n{df.describe().to_string()}"
 
+        elif operation == "info":
+            buffer = io.StringIO()
+            df.info(buf=buffer)
+            return buffer.getvalue()
+
         return "Operation not recognized or missing parameters."
 
     except Exception as e:
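Illustrative calls against the updated operations, assuming a hypothetical sales.csv with 'region' and 'amount' columns and that agg_func is read from kwargs:

print(data_analysis_tool("sales.csv", "sum", column="amount"))
print(data_analysis_tool("sales.csv", "count", column="region"))  # value counts when no value is given
print(data_analysis_tool("sales.csv", "groupby",
                         group_column="region", agg_column="amount", agg_func="mean"))
print(data_analysis_tool("sales.csv", "info"))                    # new operation in this commit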
@@ -374,7 +469,7 @@ tool_schemas = {
             "type": "object",
             "properties": {
                 "file_path": {"type": "string", "description": "Path to data file"},
-                "operation": {"type": "string", "description": "Operation: sum, mean, count, groupby, filter, describe"},
+                "operation": {"type": "string", "description": "Operation: sum, mean, count, groupby, filter, describe, info"},
                 "kwargs": {"type": "object", "description": "Additional parameters for the operation"}
             },
             "required": ["file_path", "operation"]
 