WaelDahech committed on
Commit
e1843cb
·
1 Parent(s): f199baf

use tool wrapper

Browse files
Files changed (2) hide show
  1. app.py +3 -17
  2. my_tools.py +233 -16
app.py CHANGED
@@ -8,33 +8,19 @@ import pandas as pd
8
  # --- Constants ---
9
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
10
 
 
 
11
  from smolagents import Tool
12
  from smolagents import CodeAgent,DuckDuckGoSearchTool, InferenceClientModel,load_tool,tool
13
 
14
  # --- Basic Agent Definition ---
15
- TOOL_REGISTRY = [
16
- Tool(name="wikipedia_search", entry_point="mytools.wikipedia_search.call"),
17
- Tool(name="youtube_transcript", entry_point="mytools.youtube_transcript.call"),
18
- Tool(name="video_frame_analyzer", entry_point="mytools.video_frame_analyzer.call"),
19
- Tool(name="string_manipulator", entry_point="mytools.string_manipulator.call"),
20
- Tool(name="vision_chess_engine", entry_point="mytools.vision_chess_engine.call"),
21
- Tool(name="table_parser", entry_point="mytools.table_parser.call"),
22
- Tool(name="libretext_fetcher", entry_point="mytools.libretext_fetcher.call"),
23
- Tool(name="audio_transcriber", entry_point="mytools.audio_transcriber.call"),
24
- Tool(name="botanical_classifier", entry_point="mytools.botanical_classifier.call"),
25
- Tool(name="imdb_lookup", entry_point="mytools.imdb_lookup.call"),
26
- Tool(name="excel_reader", entry_point="mytools.excel_reader.call"),
27
- Tool(name="competition_db", entry_point="mytools.competition_db.call"),
28
- Tool(name="japanese_baseball_api", entry_point="mytools.japanese_baseball_api.call"),
29
- ]
30
-
31
  import yaml
32
  #with open("prompts.yaml", 'r') as stream:
33
  # prompt_templates = yaml.safe_load(stream)
34
 
35
  MyAgent = CodeAgent(
36
  model= InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
37
- tools=[*TOOL_REGISTRY], ## add your tools here (don't remove final answer)
38
  max_steps=6,
39
  verbosity_level=1,
40
  grammar=None,
 
8
  # --- Constants ---
9
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
10
 
11
+ from my_tools import tools_list
12
+
13
  from smolagents import Tool
14
  from smolagents import CodeAgent,DuckDuckGoSearchTool, InferenceClientModel,load_tool,tool
15
 
16
  # --- Basic Agent Definition ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  import yaml
18
  #with open("prompts.yaml", 'r') as stream:
19
  # prompt_templates = yaml.safe_load(stream)
20
 
21
  MyAgent = CodeAgent(
22
  model= InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
23
+ tools=[*tools_list], ## add your tools here (don't remove final answer)
24
  max_steps=6,
25
  verbosity_level=1,
26
  grammar=None,
my_tools.py CHANGED
@@ -12,19 +12,51 @@ import whisper
12
  from imdb import IMDb
13
  import subprocess
14
  import sys
 
 
15
 
16
  # === wikipedia_search ===
17
- def wikipedia_search_call(query: str) -> dict:
 
 
 
 
 
 
 
 
 
 
18
  page = wikipedia.page(query)
19
  sections = {sec: page.section(sec) for sec in page.sections}
20
  return {"title": page.title, "content": page.content, "sections": sections}
21
 
22
  # === youtube_transcript ===
23
- def youtube_transcript_call(video_id: str) -> list:
 
 
 
 
 
 
 
 
 
 
24
  return YouTubeTranscriptApi.get_transcript(video_id)
25
 
26
  # === video_frame_analyzer ===
27
- def download_and_sample(video_id: str, fps: int = 1) -> list:
 
 
 
 
 
 
 
 
 
 
28
  url = f"https://www.youtube.com/watch?v={video_id}"
29
  yt = YouTube(url)
30
  stream = yt.streams.filter(progressive=True, file_extension='mp4').first()
@@ -44,17 +76,49 @@ def download_and_sample(video_id: str, fps: int = 1) -> list:
44
  cap.release()
45
  return frames
46
 
47
- def detect_species(frame) -> list:
 
 
 
 
 
 
 
 
 
48
  # TODO: integrate actual CV model for bird-species detection
49
  return []
50
 
 
51
  def video_frame_analyzer_call(video_id: str) -> int:
 
 
 
 
 
 
 
 
 
52
  frames = download_and_sample(video_id)
53
  counts = [len(set(detect_species(f))) for f in frames]
54
  return max(counts) if counts else 0
55
 
56
  # === string_manipulator ===
57
- def string_manipulator_call(text: str, operation: str = "reverse", pattern: str = None, replacement: str = None):
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  if operation == "reverse":
59
  return text[::-1]
60
  if operation == "split":
@@ -64,7 +128,18 @@ def string_manipulator_call(text: str, operation: str = "reverse", pattern: str
64
  raise ValueError(f"Unsupported operation: {operation}")
65
 
66
  # === vision_chess_engine ===
 
67
  def vision_chess_engine_call(fen: str, depth: int = 20) -> str:
 
 
 
 
 
 
 
 
 
 
68
  engine = chess.engine.SimpleEngine.popen_uci("stockfish")
69
  board = chess.Board(fen)
70
  result = engine.play(board, chess.engine.Limit(depth=depth))
@@ -72,23 +147,58 @@ def vision_chess_engine_call(fen: str, depth: int = 20) -> str:
72
  return board.san(result.move)
73
 
74
  # === table_parser ===
75
- def table_parser_call(file_path: str, sheet_name: str = None) -> pd.DataFrame:
 
 
 
 
 
 
 
 
 
 
 
76
  if file_path.lower().endswith('.csv'):
77
  return pd.read_csv(file_path)
78
  return pd.read_excel(file_path, sheet_name=sheet_name)
79
 
80
  # === libretext_fetcher ===
81
- def libretext_fetcher_call(url: str, section_id: str) -> list:
 
 
 
 
 
 
 
 
 
 
 
82
  resp = requests.get(url)
83
  soup = BeautifulSoup(resp.text, "html.parser")
84
  sec = soup.find(id=section_id)
85
  if not sec:
86
  return []
87
- items = sec.find_next('ul').find_all('li')
88
- return [li.get_text(strip=True) for li in items]
 
 
 
89
 
90
  # === audio_transcriber ===
 
91
  def audio_transcriber_call(audio_path: str) -> str:
 
 
 
 
 
 
 
 
 
92
  model = whisper.load_model("base")
93
  result = model.transcribe(audio_path)
94
  return result.get("text", "")
@@ -96,11 +206,31 @@ def audio_transcriber_call(audio_path: str) -> str:
96
  # === botanical_classifier ===
97
  BOTANICAL_VEGETABLES = {"tomato", "eggplant", "pepper", "squash"}
98
 
99
- def botanical_classifier_call(items: list) -> list:
 
 
 
 
 
 
 
 
 
 
100
  return [item for item in items if item.lower() in BOTANICAL_VEGETABLES]
101
 
102
  # === imdb_lookup ===
103
- def imdb_lookup_call(person_name: str) -> dict:
 
 
 
 
 
 
 
 
 
 
104
  ia = IMDb()
105
  results = ia.search_person(person_name)
106
  if not results:
@@ -110,28 +240,115 @@ def imdb_lookup_call(person_name: str) -> dict:
110
  return {"name": person['name'], "filmography": person.get('filmography', {})}
111
 
112
  # === python_executor ===
 
113
  def python_executor_call(script_path: str) -> str:
 
 
 
 
 
 
 
 
 
114
  proc = subprocess.run([sys.executable, script_path], capture_output=True, text=True, check=True)
115
  return proc.stdout.strip()
116
 
117
  # === sports_stats_api ===
118
- def sports_stats_api_call(season: int, team: str, stat: str = "BB") -> dict:
 
 
 
 
 
 
 
 
 
 
 
 
119
  raise NotImplementedError("sports_stats_api integration not configured")
120
 
121
  # === web_scraper ===
122
- def web_scraper_call(url: str, css_selector: str) -> list:
 
 
 
 
 
 
 
 
 
 
 
123
  resp = requests.get(url)
124
  soup = BeautifulSoup(resp.text, "html.parser")
125
  return [el.get_text(strip=True) for el in soup.select(css_selector)]
126
 
127
  # === excel_reader ===
128
- def excel_reader_call(file_path: str, sheet_name: str = None) -> pd.DataFrame:
 
 
 
 
 
 
 
 
 
 
 
129
  return pd.read_excel(file_path, sheet_name=sheet_name)
130
 
131
  # === competition_db ===
132
- def competition_db_call(year_start: int, year_end: int) -> list:
 
 
 
 
 
 
 
 
 
 
 
133
  raise NotImplementedError("competition_db integration not configured")
134
 
135
  # === japanese_baseball_api ===
136
- def japanese_baseball_api_call(team: str, date: str) -> list:
 
 
 
 
 
 
 
 
 
 
 
137
  raise NotImplementedError("japanese_baseball_api integration not configured")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  from imdb import IMDb
13
  import subprocess
14
  import sys
15
+ from typing import Optional, List, Dict, Any
16
+ from smolagents import tool
17
 
18
  # === wikipedia_search ===
19
@tool
def wikipedia_search_call(query: str) -> Dict[str, Any]:
    """
    Search Wikipedia for information about a specific topic.

    Args:
        query (str): The search query/topic to look up on Wikipedia

    Returns:
        dict: Dictionary containing the page title, content, and sections
    """
    article = wikipedia.page(query)

    # Map each section heading reported by the page to whatever
    # `article.section(heading)` yields for it.
    section_texts = {}
    for heading in article.sections:
        section_texts[heading] = article.section(heading)

    return {
        "title": article.title,
        "content": article.content,
        "sections": section_texts,
    }
33
 
34
  # === youtube_transcript ===
35
@tool
def youtube_transcript_call(video_id: str) -> List[Dict[str, Any]]:
    """
    Get the transcript/subtitles from a YouTube video.

    Args:
        video_id (str): The YouTube video ID (the part after v= in the URL)

    Returns:
        list: List of transcript segments with text and timing information
    """
    # NOTE(review): the class-level `get_transcript` helper appears to be
    # deprecated in recent youtube-transcript-api releases - confirm the
    # pinned version still provides it.
    transcript_segments = YouTubeTranscriptApi.get_transcript(video_id)
    return transcript_segments
47
 
48
  # === video_frame_analyzer ===
49
+ def download_and_sample(video_id: str, fps: int = 1) -> List[Any]:
50
+ """
51
+ Download a YouTube video and sample frames at specified FPS.
52
+
53
+ Args:
54
+ video_id (str): The YouTube video ID
55
+ fps (int): Frames per second to sample (default: 1)
56
+
57
+ Returns:
58
+ list: List of video frames as numpy arrays
59
+ """
60
  url = f"https://www.youtube.com/watch?v={video_id}"
61
  yt = YouTube(url)
62
  stream = yt.streams.filter(progressive=True, file_extension='mp4').first()
 
76
  cap.release()
77
  return frames
78
 
79
def detect_species(frame: Any) -> List[str]:
    """
    Detect bird species in a video frame.

    Placeholder implementation: no model is wired in yet, so every frame
    yields an empty result regardless of its content.

    Args:
        frame: Video frame as numpy array

    Returns:
        list: List of detected bird species names (currently always empty)
    """
    # TODO: integrate actual CV model for bird-species detection
    no_detections: List[str] = []
    return no_detections
91
 
92
@tool
def video_frame_analyzer_call(video_id: str) -> int:
    """
    Analyze video frames to count unique bird species.

    Args:
        video_id (str): The YouTube video ID to analyze

    Returns:
        int: Maximum number of unique bird species detected in any frame
    """
    # Track the running maximum; starting at 0 also covers the case where
    # no frames were sampled at all.
    most_species = 0
    for frame in download_and_sample(video_id):
        distinct_in_frame = len(set(detect_species(frame)))
        if distinct_in_frame > most_species:
            most_species = distinct_in_frame
    return most_species
106
 
107
  # === string_manipulator ===
108
+ @tool
109
+ def string_manipulator_call(text: str, operation: str = "reverse", pattern: Optional[str] = None, replacement: Optional[str] = None) -> Any:
110
+ """
111
+ Perform various string manipulation operations.
112
+
113
+ Args:
114
+ text (str): The input text to manipulate
115
+ operation (str): The operation to perform ("reverse", "split", "regex_replace")
116
+ pattern (str, optional): Regex pattern for replacement operations
117
+ replacement (str, optional): Replacement string for regex operations
118
+
119
+ Returns:
120
+ Any: Result of the string operation (string or list)
121
+ """
122
  if operation == "reverse":
123
  return text[::-1]
124
  if operation == "split":
 
128
  raise ValueError(f"Unsupported operation: {operation}")
129
 
130
  # === vision_chess_engine ===
131
+ @tool
132
  def vision_chess_engine_call(fen: str, depth: int = 20) -> str:
133
+ """
134
+ Analyze a chess position and suggest the best move using Stockfish engine.
135
+
136
+ Args:
137
+ fen (str): FEN notation representing the chess position
138
+ depth (int): Search depth for the chess engine (default: 20)
139
+
140
+ Returns:
141
+ str: The best move in Standard Algebraic Notation (SAN)
142
+ """
143
  engine = chess.engine.SimpleEngine.popen_uci("stockfish")
144
  board = chess.Board(fen)
145
  result = engine.play(board, chess.engine.Limit(depth=depth))
 
147
  return board.san(result.move)
148
 
149
  # === table_parser ===
150
@tool
def table_parser_call(file_path: str, sheet_name: Optional[str] = None) -> pd.DataFrame:
    """
    Parse a CSV or Excel file into a single pandas DataFrame.

    Args:
        file_path (str): Path to the CSV or Excel file; a name ending in ".csv" (case-insensitive) is read as CSV, anything else as Excel
        sheet_name (str, optional): Sheet name for Excel files; when omitted the first sheet is read. Ignored for CSV files

    Returns:
        pd.DataFrame: Parsed data as a pandas DataFrame
    """
    if file_path.lower().endswith(".csv"):
        return pd.read_csv(file_path)

    # pandas interprets sheet_name=None as "load every sheet" and returns a
    # dict of DataFrames, which contradicts the declared return type. Fall
    # back to the first sheet (index 0) so a single DataFrame always comes back.
    target_sheet = 0 if sheet_name is None else sheet_name
    return pd.read_excel(file_path, sheet_name=target_sheet)
165
 
166
  # === libretext_fetcher ===
167
@tool
def libretext_fetcher_call(url: str, section_id: str) -> List[str]:
    """
    Fetch the bullet-list items that follow a given section on a LibreTexts page.

    Args:
        url (str): The LibreTexts page URL
        section_id (str): The HTML section ID to extract content from

    Returns:
        list: Text of each <li> in the first <ul> found after the section element; empty if the section or list is missing
    """
    # Bound the request: without a timeout, requests can wait indefinitely
    # on an unresponsive server and stall the calling agent.
    resp = requests.get(url, timeout=30)
    soup = BeautifulSoup(resp.text, "html.parser")

    section = soup.find(id=section_id)
    if section is None:
        return []

    # find_next looks forward in document order, so the list does not have
    # to be nested inside the section element itself.
    bullet_list = section.find_next("ul")
    if bullet_list is None:
        return []

    return [li.get_text(strip=True) for li in bullet_list.find_all("li")]
189
 
190
  # === audio_transcriber ===
191
@tool
def audio_transcriber_call(audio_path: str) -> str:
    """
    Transcribe audio files to text using OpenAI Whisper.

    Args:
        audio_path (str): Path to the audio file to transcribe

    Returns:
        str: Transcribed text from the audio
    """
    # NOTE(review): the "base" model is loaded on every call; consider
    # caching it if this tool is invoked repeatedly.
    transcription = whisper.load_model("base").transcribe(audio_path)
    # Fall back to an empty string when the result carries no "text" entry.
    return transcription.get("text", "")
 
206
  # === botanical_classifier ===
207
  BOTANICAL_VEGETABLES = {"tomato", "eggplant", "pepper", "squash"}
208
 
209
@tool
def botanical_classifier_call(items: List[str]) -> List[str]:
    """
    Classify items as botanical vegetables.

    Args:
        items (list): List of items to classify

    Returns:
        list: Items that are classified as botanical vegetables
    """
    # NOTE(review): the entries in BOTANICAL_VEGETABLES (tomato, eggplant,
    # pepper, squash) are usually described as botanical *fruits* - confirm
    # the intended meaning of the set and of this tool's description.
    matches = []
    for candidate in items:
        # Membership is case-insensitive; the original spelling is returned.
        if candidate.lower() in BOTANICAL_VEGETABLES:
            matches.append(candidate)
    return matches
221
 
222
  # === imdb_lookup ===
223
+ @tool
224
+ def imdb_lookup_call(person_name: str) -> Dict[str, Any]:
225
+ """
226
+ Look up information about a person on IMDb.
227
+
228
+ Args:
229
+ person_name (str): Name of the person to search for
230
+
231
+ Returns:
232
+ dict: Dictionary containing person's name and filmography
233
+ """
234
  ia = IMDb()
235
  results = ia.search_person(person_name)
236
  if not results:
 
240
  return {"name": person['name'], "filmography": person.get('filmography', {})}
241
 
242
  # === python_executor ===
243
@tool
def python_executor_call(script_path: str) -> str:
    """
    Execute a Python script and return its output.

    Args:
        script_path (str): Path to the Python script to execute

    Returns:
        str: Standard output from the script execution
    """
    # NOTE(review): this runs whatever file it is pointed at, with the
    # current interpreter and no timeout - confirm callers only pass
    # trusted scripts.
    command = [sys.executable, script_path]
    # check=True makes a non-zero exit status raise CalledProcessError.
    completed = subprocess.run(command, capture_output=True, text=True, check=True)
    return completed.stdout.strip()
256
 
257
  # === sports_stats_api ===
258
@tool
def sports_stats_api_call(season: int, team: str, stat: str = "BB") -> Dict[str, Any]:
    """
    Get sports statistics for a team in a specific season.

    Args:
        season (int): The sports season year
        team (str): The team name
        stat (str): The statistic type to retrieve (default: "BB")

    Returns:
        dict: Sports statistics data
    """
    # Placeholder: no backend is wired in, so every call fails explicitly.
    message = "sports_stats_api integration not configured"
    raise NotImplementedError(message)
272
 
273
  # === web_scraper ===
274
@tool
def web_scraper_call(url: str, css_selector: str) -> List[str]:
    """
    Scrape content from a website using CSS selectors.

    Args:
        url (str): The URL to scrape
        css_selector (str): CSS selector to find elements

    Returns:
        list: Whitespace-stripped text of every element matching the selector, in document order; empty if nothing matches
    """
    # Bound the request: without a timeout, requests can wait indefinitely
    # on an unresponsive server and stall the calling agent.
    resp = requests.get(url, timeout=30)
    soup = BeautifulSoup(resp.text, "html.parser")
    return [el.get_text(strip=True) for el in soup.select(css_selector)]
289
 
290
  # === excel_reader ===
291
@tool
def excel_reader_call(file_path: str, sheet_name: Optional[str] = None) -> pd.DataFrame:
    """
    Read one sheet of an Excel file into a pandas DataFrame.

    Args:
        file_path (str): Path to the Excel file
        sheet_name (str, optional): Specific sheet name to read; when omitted the first sheet is read

    Returns:
        pd.DataFrame: Data from the Excel file as a pandas DataFrame
    """
    # pandas interprets sheet_name=None as "load every sheet" and returns a
    # dict of DataFrames, which contradicts the declared return type. Fall
    # back to the first sheet (index 0) so a single DataFrame always comes back.
    target_sheet = 0 if sheet_name is None else sheet_name
    return pd.read_excel(file_path, sheet_name=target_sheet)
304
 
305
  # === competition_db ===
306
@tool
def competition_db_call(year_start: int, year_end: int) -> List[Dict[str, Any]]:
    """
    Query competition database for events between specified years.

    Args:
        year_start (int): Start year for the query range
        year_end (int): End year for the query range

    Returns:
        list: List of competition events in the specified year range
    """
    # Placeholder: no backend is wired in, so every call fails explicitly.
    message = "competition_db integration not configured"
    raise NotImplementedError(message)
319
 
320
  # === japanese_baseball_api ===
321
@tool
def japanese_baseball_api_call(team: str, date: str) -> List[Dict[str, Any]]:
    """
    Get Japanese baseball data for a specific team and date.

    Args:
        team (str): The baseball team name
        date (str): The date in YYYY-MM-DD format

    Returns:
        list: List of baseball game data for the specified team and date
    """
    # Placeholder: no backend is wired in, so every call fails explicitly.
    message = "japanese_baseball_api integration not configured"
    raise NotImplementedError(message)
334
+
335
+
336
+
337
# Every @tool-wrapped callable exported by this module, in registration
# order. app.py unpacks this list into CodeAgent(tools=[...]).
tools_list = [
    # Reference lookups and media
    wikipedia_search_call,
    youtube_transcript_call,
    video_frame_analyzer_call,
    # Text and chess helpers
    string_manipulator_call,
    vision_chess_engine_call,
    # Files, pages and audio
    table_parser_call,
    libretext_fetcher_call,
    audio_transcriber_call,
    botanical_classifier_call,
    imdb_lookup_call,
    python_executor_call,
    # NOTE: sports_stats_api_call, competition_db_call and
    # japanese_baseball_api_call currently always raise NotImplementedError.
    sports_stats_api_call,
    web_scraper_call,
    excel_reader_call,
    competition_db_call,
    japanese_baseball_api_call,
]