Ghisalbertifederico committed on
Commit
7d633dc
·
verified ·
1 Parent(s): 4d4d6cb

Update tools.py

Browse files
Files changed (1) hide show
  1. tools.py +243 -94
tools.py CHANGED
@@ -7,59 +7,106 @@ import base64
7
  import subprocess
8
  from config import GROQ_API_KEY, OPENROUTER_API_KEY
9
  from functools import lru_cache
 
10
  # Force UTF-8 output on Windows to avoid charmap crashes with Unicode characters
11
  if sys.platform == "win32":
12
  sys.stdout.reconfigure(encoding="utf-8", errors="replace")
13
  sys.stderr.reconfigure(encoding="utf-8", errors="replace")
14
- import pypdf
15
  import requests
16
  from tempfile import NamedTemporaryFile
17
  import pandas as pd
18
  import markdownify
19
  from langchain_community.document_loaders import WikipediaLoader
20
- from langchain_community.tools.tavily_search import TavilySearchResults
21
  from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
22
- from langchain_core.messages import HumanMessage, SystemMessage
23
  from langchain_core.tools import tool
24
- from langchain_openai import ChatOpenAI
25
  from youtube_transcript_api import YouTubeTranscriptApi
26
 
 
 
 
 
27
  @tool
28
- def wikipedia_search(query: str, max_pages: int = 2) -> str:
29
- """Search Wikipedia for a short query and return a truncated summary.
30
- """
31
  print(f"[TOOL] wiki_search called with query: {query}")
32
- docs = WikipediaLoader(query=query, load_max_docs=max_pages).load()
33
- joined = "\n\n---\n\n".join(d.page_content for d in docs)
34
- return joined[:48_000]
 
 
 
 
35
 
36
- @lru_cache(maxsize=256)
37
- def ddg_search(query: str, k: int = 6) -> list[dict[str, str]]:
38
- """Visit a webpage URL and return its text content (truncated).
39
- """
40
- wrapper = DuckDuckGoSearchAPIWrapper(max_results=k)
41
- hits = wrapper.results(query)
42
- return [
43
- {
44
- "title": hit.get("title", "")[:500],
45
- "snippet": hit.get("snippet", "")[:12000],
46
- "link": hit.get("link", "")[:300],
47
- }
48
- for hit in hits[:k]
49
- ]
50
 
51
- @tool
52
- def web_search(query: str, k: int = 6) -> str:
53
- """Search the web using DuckDuckGo and Tavily
54
- """
 
55
  try:
56
- hits = ddg_search(query, k)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  if hits:
58
  return json.dumps(hits, ensure_ascii=False)
 
59
 
60
- except Exception as exc:
61
- return f"search_error:{exc}"
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  @tool
64
  def get_youtube_transcript(video_url: str) -> str:
65
  """Fetch the transcript/captions of a YouTube video.
@@ -68,22 +115,44 @@ def get_youtube_transcript(video_url: str) -> str:
68
  video_url: Full YouTube URL or just the video ID.
69
 
70
  Returns:
71
- The full transcript as a single string.
72
  """
73
-
74
  match = re.search(r"(?:v=|youtu\.be/)([A-Za-z0-9_-]{11})", video_url)
75
  video_id = match.group(1) if match else video_url
 
 
 
 
 
 
 
 
 
 
 
 
76
  try:
77
- try:
78
- # youtube-transcript-api >= 0.6.0
79
- entries = YouTubeTranscriptApi().fetch(video_id)
80
- except TypeError:
81
- # fallback for older versions
82
- entries = YouTubeTranscriptApi.get_transcript(video_id)
 
 
 
 
 
 
 
83
  return " ".join(e["text"] for e in entries)
84
- except Exception as e:
85
- return "TRANSCRIPT_UNAVAILABLE"
86
 
 
 
 
 
87
  @tool
88
  def describe_image(img_bytes: bytes, question: str) -> str:
89
  """Use a vision model to interpret or answer questions about an image file.
@@ -95,89 +164,169 @@ def describe_image(img_bytes: bytes, question: str) -> str:
95
  Returns:
96
  A text description or answer about the image content.
97
  """
98
- mime_type = "image/png"
99
  image_data = base64.standard_b64encode(img_bytes).decode("utf-8")
100
 
101
- payload = {
102
- "model": "nvidia/nemotron-nano-12b-v2-vl:free",
103
- "messages": [
104
- {
105
- "role": "user",
106
- "content": [
107
- {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_data}"}},
108
- {"type": "text", "text": question},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  ],
 
110
  }
111
- ],
112
- "max_tokens": 1024,
113
- }
114
- headers = {"Authorization": f"Bearer {OPENROUTER_API_KEY}", "Content-Type": "application/json"}
115
- resp = requests.post(
116
- "https://openrouter.ai/api/v1/chat/completions",
117
- json=payload, headers=headers, timeout=60,
118
- )
119
- resp.raise_for_status()
120
- return resp.json()["choices"][0]["message"]["content"]
 
 
 
 
 
121
 
122
 
 
 
 
123
  @tool
124
  def transcribe_audio(audio_bytes: bytes) -> str:
125
- """Transcribe an audio file (.mp3, .wav, .m4a, .flac) to text using Whisper.
126
- """
127
  headers = {"Authorization": f"Bearer {GROQ_API_KEY}"}
128
  with NamedTemporaryFile(suffix=".mp3", delete=False) as f:
129
  f.write(audio_bytes)
130
  file_path = f.name
131
- with open(file_path, "rb") as f:
132
- resp = requests.post(
133
- "https://api.groq.com/openai/v1/audio/transcriptions",
134
- headers=headers,
135
- files={"file": (os.path.basename(file_path), f)},
136
- data={"model": "whisper-large-v3"},
137
- timeout=120,
138
- )
139
- resp.raise_for_status()
140
- return resp.json().get("text", "")
 
 
 
 
 
 
 
 
 
 
 
 
141
 
 
 
 
142
  @tool
143
  def run_python_file(code: str) -> str:
144
- """Execute a Python (.py) file and return its printed output.
145
- """
 
 
146
 
 
 
 
147
  try:
148
  with NamedTemporaryFile(delete=False, suffix=".py", mode="w") as f:
149
  f.write(code)
150
  path = f.name
151
  proc = subprocess.run(
152
- ["python", path], capture_output=True, text=True, timeout=45
153
  )
154
- out = proc.stdout.strip().splitlines()
155
- return out[-1] if out else ""
 
 
 
 
 
 
 
 
 
156
  except Exception as exc:
157
- return f"py_error:{exc}"
 
 
 
 
 
 
158
 
 
 
 
159
  @tool
160
  def read_task_file(xls_bytes: bytes) -> str:
161
- """Read the contents of a local file attached to the task.
162
- Supports plain text, Python, CSV, JSON, Excel (.xlsx/.xls), PDF, and audio files.
 
 
 
 
 
 
163
  """
 
164
  try:
165
- df = pd.read_excel(xls_bytes)
166
  return df.to_string(index=False)
167
- except:
168
- df = pd.read_csv(xls_bytes)
 
 
 
 
169
  return df.to_string(index=False)
170
- # if ext == ".pdf":
171
- # try:
172
- # from pypdf import PdfReader
173
- # except ImportError:
174
- # return "PDF reading requires the 'pypdf' package (pip install pypdf)."
175
- # reader = PdfReader(file_path)
176
- # pages = [page.extract_text() or "" for page in reader.pages]
177
- # return "\n".join(pages).strip()
178
- # Default: read as UTF-8 text (covers .txt, .py, .json, .md, etc.)
179
- with open(file_path, "r", encoding="utf-8", errors="replace") as f:
180
- return f.read()
 
 
 
 
 
 
 
 
181
 
182
 
183
  _DOWNLOAD_DIR = os.path.join(os.environ.get("TMPDIR", "/tmp"), "gaia_files")
 
7
  import subprocess
8
  from config import GROQ_API_KEY, OPENROUTER_API_KEY
9
  from functools import lru_cache
10
+
11
  # Force UTF-8 output on Windows to avoid charmap crashes with Unicode characters
12
  if sys.platform == "win32":
13
  sys.stdout.reconfigure(encoding="utf-8", errors="replace")
14
  sys.stderr.reconfigure(encoding="utf-8", errors="replace")
15
+
16
  import requests
17
  from tempfile import NamedTemporaryFile
18
  import pandas as pd
19
  import markdownify
20
  from langchain_community.document_loaders import WikipediaLoader
 
21
  from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
 
22
  from langchain_core.tools import tool
 
23
  from youtube_transcript_api import YouTubeTranscriptApi
24
 
25
+
26
+ # ──────────────────────────────────────────────────────────────────────────── #
27
+ # Wikipedia
28
+ # ──────────────────────────────────────────────────────────────────────────── #
29
@tool
def wikipedia_search(query: str, max_pages: int = 3) -> str:
    """Look up a query on Wikipedia and return article contents.

    Args:
        query: Search terms passed to the Wikipedia loader.
        max_pages: Maximum number of articles to load.

    Returns:
        Up to 50,000 characters of article text joined by "---" separators,
        a "not found" message when the search yields nothing, or an error
        string when the lookup raises.
    """
    print(f"[TOOL] wiki_search called with query: {query}")
    try:
        loader = WikipediaLoader(query=query, load_max_docs=max_pages)
        articles = [doc.page_content for doc in loader.load()]
        combined = "\n\n---\n\n".join(articles)
        if not combined:
            return "No Wikipedia results found."
        # Truncate so the result stays within a reasonable context budget.
        return combined[:50_000]
    except Exception as e:
        print(f"[TOOL] wiki_search error: {e}")
        return f"Wikipedia search failed: {e}"
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ # ──────────────────────────────────────────────────────────────────────────── #
43
+ # Web Search (DuckDuckGo)
44
+ # ──────────────────────────────────────────────────────────────────────────── #
45
def _ddg_search_raw(query: str, k: int = 8) -> list[dict]:
    """Run a DuckDuckGo search and return up to *k* truncated result dicts.

    Each dict carries "title", "snippet", and "link" keys, clipped to keep
    the payload small. Returns an empty list when the search fails.
    """
    try:
        wrapper = DuckDuckGoSearchAPIWrapper(max_results=k)
        trimmed = []
        for hit in wrapper.results(query)[:k]:
            trimmed.append({
                "title": hit.get("title", "")[:500],
                "snippet": hit.get("snippet", "")[:4000],
                "link": hit.get("link", "")[:300],
            })
        return trimmed
    except Exception as e:
        # Search failures are non-fatal; the caller treats [] as "no hits".
        print(f"[TOOL] DDG search error: {e}")
        return []
61
+
62
+
63
@tool
def web_search(query: str, k: int = 8) -> str:
    """Search the web using DuckDuckGo and return results as JSON.

    Args:
        query: The search query.
        k: Maximum number of results to return.

    Returns:
        A JSON array of result objects, or a plain "not found" message when
        both the original and a quote-stripped retry yield no hits.
    """
    attempts = [query]
    # Quoted phrases sometimes over-constrain DDG; retry without quotes.
    simplified = re.sub(r'["\']', '', query)
    if simplified != query:
        attempts.append(simplified)

    for candidate in attempts:
        hits = _ddg_search_raw(candidate, k)
        if hits:
            return json.dumps(hits, ensure_ascii=False)
    return "No search results found."
76
 
 
 
77
 
78
+ # ──────────────────────────────────────────────────────────────────────────── #
79
+ # Visit Webpage (fetch actual page content)
80
+ # ──────────────────────────────────────────────────────────────────────────── #
81
@tool
def visit_webpage(url: str) -> str:
    """Fetch the content of a webpage URL and return cleaned text.

    Args:
        url: The URL to fetch.

    Returns:
        The main text content of the page, truncated to ~40k chars.
    """
    print(f"[TOOL] visit_webpage: {url}")
    # Present a desktop-browser User-Agent; some sites reject default client UAs.
    browser_headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
        "(KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
    }
    try:
        response = requests.get(url, headers=browser_headers, timeout=20)
        response.raise_for_status()
        # Convert HTML to markdown, dropping images/scripts/styles, then
        # collapse runs of 3+ newlines left behind by removed elements.
        markdown_text = markdownify.markdownify(
            response.text, strip=["img", "script", "style"]
        )
        cleaned = re.sub(r'\n{3,}', '\n\n', markdown_text).strip()
        return cleaned[:40_000]
    except Exception as e:
        print(f"[TOOL] visit_webpage error: {e}")
        return f"Could not fetch {url}: {e}"
105
+
106
+
107
+ # ──────────────────────────────────────────────────────────────────────────── #
108
+ # YouTube Transcript
109
+ # ──────────────────────────────────────────────────────────────────────────── #
110
@tool
def get_youtube_transcript(video_url: str) -> str:
    """Fetch the transcript/captions of a YouTube video.

    Args:
        video_url: Full YouTube URL or just the 11-character video ID.

    Returns:
        The full transcript as a single string, or "TRANSCRIPT_UNAVAILABLE"
        when neither API variant could retrieve captions.
    """
    # Accept both full URLs (watch?v=..., youtu.be/...) and bare IDs.
    match = re.search(r"(?:v=|youtu\.be/)([A-Za-z0-9_-]{11})", video_url)
    video_id = match.group(1) if match else video_url

    # Try the new (instance-based) API first, then the legacy classmethod API.
    # Both helpers return "" on failure, so a plain truthiness check suffices;
    # the former comparison against the "TRANSCRIPT_UNAVAILABLE" sentinel was
    # unreachable (helpers never return it) and could wrongly reject a
    # transcript that happened to equal the sentinel.
    for fetch in (_fetch_transcript_new_api, _fetch_transcript_old_api):
        transcript = fetch(video_id)
        if transcript:
            print(f"[TOOL] YouTube transcript: {len(transcript)} chars")
            return transcript

    return "TRANSCRIPT_UNAVAILABLE"
131
+
132
+
133
def _fetch_transcript_new_api(video_id: str) -> str:
    """Fetch a transcript via the instance-based youtube-transcript-api.

    Returns "" on any failure so the caller can fall back to the legacy API.
    """
    try:
        entries = YouTubeTranscriptApi().fetch(video_id)
        parts = []
        for entry in entries:
            # Newer releases yield snippet objects with a .text attribute;
            # tolerate dict-shaped entries as well.
            if hasattr(entry, 'text'):
                parts.append(entry.text)
            else:
                parts.append(entry.get("text", ""))
        return " ".join(parts)
    except Exception:
        return ""
143
+
144
+
145
def _fetch_transcript_old_api(video_id: str) -> str:
    """Fetch a transcript via the legacy classmethod API (older releases).

    Returns "" on any failure.
    """
    try:
        segments = YouTubeTranscriptApi.get_transcript(video_id)
        return " ".join(segment["text"] for segment in segments)
    except Exception:
        return ""
151
 
152
+
153
+ # ──────────────────────────────────────────────────────────────────────────── #
154
+ # Image Description (Vision model)
155
+ # ──────────────────────────────────────────────────────────────────────────── #
156
@tool
def describe_image(img_bytes: bytes, question: str) -> str:
    """Use a vision model to interpret or answer questions about an image file.

    Args:
        img_bytes: Raw image bytes; sent to the API as a base64 data URL
            labeled image/png (assumes PNG-compatible input — TODO confirm
            other formats are accepted by the models).
        question: The question to ask about the image.

    Returns:
        A text description or answer about the image content, or
        "IMAGE_DESCRIPTION_UNAVAILABLE" when every model attempt fails.
    """
    image_data = base64.standard_b64encode(img_bytes).decode("utf-8")

    # Ordered fallback chain: try each OpenRouter model until one answers.
    models_to_try = [
        "google/gemini-2.0-flash-001",
        "qwen/qwen-2.5-vl-72b-instruct",
        "nvidia/nemotron-nano-12b-v2-vl:free",
    ]

    for model in models_to_try:
        try:
            payload = {
                "model": model,
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            # Image is embedded inline as a base64 data URL.
                            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_data}"}},
                            {"type": "text", "text": (
                                f"{question}\n\n"
                                "Be extremely specific and precise. "
                                "If this is a chess position, list ALL pieces with their exact square coordinates in algebraic notation. "
                                "If there is text in the image, transcribe it exactly. "
                                "If there are numbers, list them all."
                            )},
                        ],
                    }
                ],
                "max_tokens": 2048,
            }
            headers = {"Authorization": f"Bearer {OPENROUTER_API_KEY}", "Content-Type": "application/json"}
            resp = requests.post(
                "https://openrouter.ai/api/v1/chat/completions",
                json=payload, headers=headers, timeout=90,
            )
            resp.raise_for_status()
            content = resp.json()["choices"][0]["message"]["content"]
            # Accept only non-trivial answers (> 10 chars); presumably this
            # filters empty/near-empty model replies — verify threshold.
            if content and len(content.strip()) > 10:
                print(f"[TOOL] describe_image success with {model}")
                return content
        except Exception as e:
            # Any per-model failure falls through to the next model.
            print(f"[TOOL] describe_image failed with {model}: {e}")
            continue

    return "IMAGE_DESCRIPTION_UNAVAILABLE"
211
 
212
 
213
+ # ──────────────────────────────────────────────────────────────────────────── #
214
+ # Audio Transcription (Whisper via Groq)
215
+ # ──────────────────────────────────────────────────────────────────────────── #
216
@tool
def transcribe_audio(audio_bytes: bytes) -> str:
    """Transcribe an audio file (.mp3, .wav, .m4a, .flac) to text using Whisper."""
    auth_headers = {"Authorization": f"Bearer {GROQ_API_KEY}"}

    # The Groq endpoint expects a file upload, so spill the bytes to disk.
    # delete=False is needed on Windows so the file can be reopened below.
    with NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
        tmp.write(audio_bytes)
        file_path = tmp.name

    try:
        with open(file_path, "rb") as audio_file:
            response = requests.post(
                "https://api.groq.com/openai/v1/audio/transcriptions",
                headers=auth_headers,
                files={"file": (os.path.basename(file_path), audio_file)},
                data={"model": "whisper-large-v3"},
                timeout=120,
            )
        response.raise_for_status()
        transcript = response.json().get("text", "")
        print(f"[TOOL] transcribe_audio: {len(transcript)} chars")
        return transcript
    except Exception as e:
        print(f"[TOOL] transcribe_audio error: {e}")
        return f"TRANSCRIPTION_ERROR: {e}"
    finally:
        # Best-effort cleanup of the temp file created above.
        try:
            os.unlink(file_path)
        except OSError:
            pass
244
+
245
 
246
+ # ──────────────────────────────────────────────────────────────────────────── #
247
+ # Python Execution
248
+ # ──────────────────────────────────────────────────────────────────────────── #
249
@tool
def run_python_file(code: str) -> str:
    """Execute Python code in a subprocess and return its printed output.

    Args:
        code: The Python source code to execute.

    Returns:
        The last non-empty line of stdout; "py_stderr: ..." when there is
        no stdout but stderr is present; "" when nothing is printed; or
        "py_error: ..." on timeout or any other failure.
    """
    # Bind before the try so the finally-block cleanup never hits an unbound
    # name: previously, a failure inside NamedTemporaryFile raised NameError
    # from `os.unlink(path)` in finally, masking the original exception.
    path = None
    try:
        # encoding="utf-8": the platform default text encoding (e.g. cp1252
        # on Windows) would crash on non-ASCII source code.
        with NamedTemporaryFile(delete=False, suffix=".py", mode="w", encoding="utf-8") as f:
            f.write(code)
            path = f.name
        # sys.executable runs the same interpreter hosting this tool.
        proc = subprocess.run(
            [sys.executable, path], capture_output=True, text=True, timeout=45
        )
        stdout = proc.stdout.strip()
        stderr = proc.stderr.strip()
        if stdout:
            lines = [line for line in stdout.splitlines() if line.strip()]
            return lines[-1] if lines else stdout
        if stderr:
            return f"py_stderr: {stderr[:2000]}"
        return ""
    except subprocess.TimeoutExpired:
        return "py_error: execution timed out after 45s"
    except Exception as exc:
        return f"py_error: {exc}"
    finally:
        # Clean up the temp script; skip when it was never created.
        if path is not None:
            try:
                os.unlink(path)
            except OSError:
                pass
284
+
285
 
286
+ # ──────────────────────────────────────────────────────────────────────────── #
287
+ # File Reading (Excel / CSV / PDF / Text)
288
+ # ──────────────────────────────────────────────────────────────────────────── #
289
@tool
def read_task_file(xls_bytes: bytes) -> str:
    """Read the contents of a file attached to the task.
    Supports Excel (.xlsx/.xls), CSV, PDF, and plain text.

    Args:
        xls_bytes: Raw bytes of the file.

    Returns:
        The file contents as text.
    """
    # Spreadsheet formats first: Excel, then CSV. Each parser gets its own
    # fresh in-memory buffer; any parse failure falls through to the next.
    for parse in (pd.read_excel, pd.read_csv):
        try:
            frame = parse(io.BytesIO(xls_bytes))
            return frame.to_string(index=False)
        except Exception:
            continue

    # PDF next. Imported lazily so the other branches work without pypdf.
    try:
        from pypdf import PdfReader
        reader = PdfReader(io.BytesIO(xls_bytes))
        page_texts = [page.extract_text() or "" for page in reader.pages]
        combined = "\n".join(page_texts).strip()
        if combined:
            return combined
    except Exception:
        pass

    # Last resort: treat the bytes as UTF-8 text.
    try:
        return xls_bytes.decode("utf-8", errors="replace")
    except Exception:
        return "Could not read the attached file in any supported format."
330
 
331
 
332
  _DOWNLOAD_DIR = os.path.join(os.environ.get("TMPDIR", "/tmp"), "gaia_files")