Yoan Di Cosmo committed on
Commit
ba44575
·
1 Parent(s): 32f776a

Added 2 dataset tools, one to list rows, the other one to download the dataset

Browse files
agent/core/tools.py CHANGED
@@ -13,6 +13,12 @@ from lmnr import observe
13
  from mcp.types import EmbeddedResource, ImageContent, TextContent
14
 
15
  from agent.config import MCPServerConfig
 
 
 
 
 
 
16
  from agent.tools.docs_tools import (
17
  EXPLORE_HF_DOCS_TOOL_SPEC,
18
  HF_DOCS_FETCH_TOOL_SPEC,
@@ -257,6 +263,19 @@ def create_builtin_tools() -> list[ToolSpec]:
257
  parameters=HF_DOCS_FETCH_TOOL_SPEC["parameters"],
258
  handler=hf_docs_fetch_handler,
259
  ),
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  # Planning and job management tools
261
  ToolSpec(
262
  name=PLAN_TOOL_SPEC["name"],
 
13
  from mcp.types import EmbeddedResource, ImageContent, TextContent
14
 
15
  from agent.config import MCPServerConfig
16
+ from agent.tools.dataset_tools import (
17
+ DATASETS_SERVER_DOWNLOAD_ROWS_TOOL_SPEC,
18
+ DATASETS_SERVER_LIST_SPLITS_TOOL_SPEC,
19
+ hf_datasets_download_rows_handler,
20
+ hf_datasets_list_splits_handler,
21
+ )
22
  from agent.tools.docs_tools import (
23
  EXPLORE_HF_DOCS_TOOL_SPEC,
24
  HF_DOCS_FETCH_TOOL_SPEC,
 
263
  parameters=HF_DOCS_FETCH_TOOL_SPEC["parameters"],
264
  handler=hf_docs_fetch_handler,
265
  ),
266
+ # Datasets server tools
267
+ ToolSpec(
268
+ name=DATASETS_SERVER_LIST_SPLITS_TOOL_SPEC["name"],
269
+ description=DATASETS_SERVER_LIST_SPLITS_TOOL_SPEC["description"],
270
+ parameters=DATASETS_SERVER_LIST_SPLITS_TOOL_SPEC["parameters"],
271
+ handler=hf_datasets_list_splits_handler,
272
+ ),
273
+ ToolSpec(
274
+ name=DATASETS_SERVER_DOWNLOAD_ROWS_TOOL_SPEC["name"],
275
+ description=DATASETS_SERVER_DOWNLOAD_ROWS_TOOL_SPEC["description"],
276
+ parameters=DATASETS_SERVER_DOWNLOAD_ROWS_TOOL_SPEC["parameters"],
277
+ handler=hf_datasets_download_rows_handler,
278
+ ),
279
  # Planning and job management tools
280
  ToolSpec(
281
  name=PLAN_TOOL_SPEC["name"],
agent/tools/__init__.py CHANGED
@@ -18,6 +18,12 @@ from agent.tools.github_search_code import (
18
  GITHUB_SEARCH_CODE_TOOL_SPEC,
19
  github_search_code_handler,
20
  )
 
 
 
 
 
 
21
  from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, HfJobsTool, hf_jobs_handler
22
  from agent.tools.types import ToolResult
23
 
@@ -34,4 +40,8 @@ __all__ = [
34
  "github_read_file_handler",
35
  "GITHUB_SEARCH_CODE_TOOL_SPEC",
36
  "github_search_code_handler",
 
 
 
 
37
  ]
 
18
  GITHUB_SEARCH_CODE_TOOL_SPEC,
19
  github_search_code_handler,
20
  )
21
+ from agent.tools.dataset_tools import (
22
+ DATASETS_SERVER_DOWNLOAD_ROWS_TOOL_SPEC,
23
+ DATASETS_SERVER_LIST_SPLITS_TOOL_SPEC,
24
+ hf_datasets_download_rows_handler,
25
+ hf_datasets_list_splits_handler,
26
+ )
27
  from agent.tools.jobs_tool import HF_JOBS_TOOL_SPEC, HfJobsTool, hf_jobs_handler
28
  from agent.tools.types import ToolResult
29
 
 
40
  "github_read_file_handler",
41
  "GITHUB_SEARCH_CODE_TOOL_SPEC",
42
  "github_search_code_handler",
43
+ "DATASETS_SERVER_LIST_SPLITS_TOOL_SPEC",
44
+ "hf_datasets_list_splits_handler",
45
+ "DATASETS_SERVER_DOWNLOAD_ROWS_TOOL_SPEC",
46
+ "hf_datasets_download_rows_handler",
47
  ]
agent/tools/dataset_tools.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hugging Face Dataset Tool - Query datasets via the Datasets Server API
3
+
4
+ Allows downloading rows and listing splits from Hugging Face datasets.
5
+ """
6
+
7
+ from typing import Any, Dict
8
+
9
+ import httpx
10
+
11
+ from agent.tools.types import ToolResult
12
+
13
+
14
def list_splits(dataset: str) -> ToolResult:
    """
    List all available splits for a dataset.

    Queries the Hugging Face Datasets Server `/splits` endpoint, which returns
    one entry per (config, split) pair.

    Args:
        dataset: Dataset identifier (e.g., "facebook/research-plan-gen")

    Returns:
        ToolResult with split information, including the config name that
        multi-config datasets need when calling hf_datasets_download_rows.
    """
    base_url = "https://datasets-server.huggingface.co"
    url = f"{base_url}/splits"

    params = {"dataset": dataset}

    try:
        response = httpx.get(url, params=params, timeout=30.0)
        response.raise_for_status()
        data = response.json()

        splits = data.get("splits", [])
        if not splits:
            return {
                "formatted": f"No splits found for dataset '{dataset}'",
                "totalResults": 0,
                "resultsShared": 0,
                "isError": False,
            }

        # Format splits information. Surface the config name: download_rows
        # requires it for datasets that define multiple configs.
        split_info = []
        for split in splits:
            split_name = split.get("split", "unknown")
            config_name = split.get("config", "default")
            line = f"- config **{config_name}**, split **{split_name}**"
            # NOTE(review): /splits is not documented to return row counts
            # (num_examples), so only show one when the API actually provides
            # it rather than printing "unknown rows" for every split.
            num_rows = split.get("num_examples")
            if num_rows is not None:
                line += f": {num_rows} rows"
            split_info.append(line)

        formatted = f"Available splits for dataset '{dataset}':\n\n" + "\n".join(split_info)

        return {
            "formatted": formatted,
            "totalResults": len(splits),
            "resultsShared": len(splits),
            "isError": False,
        }

    except httpx.HTTPStatusError as e:
        # Server answered with a non-2xx status (e.g. 404 for an unknown
        # dataset, 401 for gated/private datasets without a token).
        return {
            "formatted": f"HTTP error {e.response.status_code}: {str(e)}",
            "totalResults": 0,
            "resultsShared": 0,
            "isError": True,
        }
    except Exception as e:
        # Network failures, timeouts, malformed JSON, etc.
        return {
            "formatted": f"Failed to list splits: {str(e)}",
            "totalResults": 0,
            "resultsShared": 0,
            "isError": True,
        }
73
+
74
+
75
def download_rows(
    dataset: str,
    split: str,
    config: str | None = None,
    offset: int = 0,
    length: int = 100,
) -> ToolResult:
    """
    Download rows from a dataset split.

    Queries the Hugging Face Datasets Server `/rows` endpoint.

    Args:
        dataset: Dataset identifier (e.g., "facebook/research-plan-gen")
        split: Split name (e.g., "train", "test", "validation")
        config: Optional config name (for datasets with multiple configs)
        offset: Starting row index (default: 0)
        length: Number of rows to fetch (default: 100). The Datasets Server
            API allows at most 100 rows per request; larger values are
            clamped so the call still succeeds.

    Returns:
        ToolResult with row data
    """
    base_url = "https://datasets-server.huggingface.co"
    url = f"{base_url}/rows"

    # The /rows endpoint rejects requests with length > 100; clamp instead of
    # letting the API return an error, so callers asking for more still get
    # the largest allowed page.
    length = min(length, 100)

    params = {
        "dataset": dataset,
        "split": split,
        "offset": offset,
        "length": length,
    }

    if config:
        params["config"] = config

    try:
        response = httpx.get(url, params=params, timeout=60.0)
        response.raise_for_status()
        data = response.json()

        rows = data.get("rows", [])
        features = data.get("features", [])

        if not rows:
            return {
                "formatted": f"No rows found for dataset '{dataset}', split '{split}' at offset {offset}",
                "totalResults": 0,
                "resultsShared": 0,
                "isError": False,
            }

        # Format a summary of the rows
        formatted_parts = [
            f"Downloaded {len(rows)} rows from dataset '{dataset}'",
            f"Split: {split}",
            f"Offset: {offset}",
        ]

        if config:
            formatted_parts.append(f"Config: {config}")

        formatted_parts.append(f"\nFeatures: {', '.join([f.get('name', 'unknown') for f in features])}")
        formatted_parts.append(f"\nTotal rows in response: {len(rows)}")

        # Show first row as example (truncated so huge cells don't flood the
        # tool output).
        if rows:
            first_row = rows[0].get("row", {})
            formatted_parts.append("\nExample row (first row):")
            for key, value in list(first_row.items())[:5]:  # Show first 5 fields
                value_str = str(value)
                if len(value_str) > 200:
                    value_str = value_str[:200] + "..."
                formatted_parts.append(f"  - {key}: {value_str}")

        formatted = "\n".join(formatted_parts)

        return {
            "formatted": formatted,
            "totalResults": len(rows),
            "resultsShared": len(rows),
            "isError": False,
        }

    except httpx.HTTPStatusError as e:
        # Server answered with a non-2xx status (e.g. 404 for an unknown
        # dataset/config/split, 401 for gated/private datasets).
        return {
            "formatted": f"HTTP error {e.response.status_code}: {str(e)}",
            "totalResults": 0,
            "resultsShared": 0,
            "isError": True,
        }
    except Exception as e:
        # Network failures, timeouts, malformed JSON, etc.
        return {
            "formatted": f"Failed to download rows: {str(e)}",
            "totalResults": 0,
            "resultsShared": 0,
            "isError": True,
        }
170
+
171
+
172
# Tool specifications
#
# JSON-Schema-style specs registered as ToolSpec entries (see
# create_builtin_tools in agent/core/tools.py). "description" is the text
# shown to the model; "parameters" is the JSON Schema for tool arguments.
DATASETS_SERVER_LIST_SPLITS_TOOL_SPEC = {
    "name": "hf_datasets_list_splits",
    "description": (
        "List all available splits for a Hugging Face dataset.\n\n"
        "Use this to discover what splits (train, test, validation, etc.) are available "
        "for a dataset before downloading rows.\n\n"
        "## When to use\n"
        "- When you need to know what splits are available for a dataset\n"
        "- Before downloading rows to identify the correct split name\n"
        "- To check dataset structure and organization\n\n"
        "## Example\n"
        "{\n"
        '  "dataset": "facebook/research-plan-gen"\n'
        "}"
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "dataset": {
                "type": "string",
                "description": "Dataset identifier in format 'org/dataset-name' (e.g., 'facebook/research-plan-gen'). Required.",
            },
        },
        "required": ["dataset"],
    },
}
199
+
200
# JSON-Schema-style spec registered as a ToolSpec entry (see
# create_builtin_tools in agent/core/tools.py). "description" is the text
# shown to the model; "parameters" is the JSON Schema for tool arguments.
DATASETS_SERVER_DOWNLOAD_ROWS_TOOL_SPEC = {
    "name": "hf_datasets_download_rows",
    "description": (
        "Download rows from a Hugging Face dataset split via the Datasets Server API.\n\n"
        "Fetches a specified number of rows starting from a given offset. Useful for "
        "sampling data, inspecting dataset contents, or processing datasets in batches.\n\n"
        "## When to use\n"
        "- When you need to inspect or sample data from a dataset\n"
        "- To download specific rows for analysis or processing\n"
        "- To fetch data in batches (use offset and length parameters)\n\n"
        "## When NOT to use\n"
        "- For downloading entire large datasets (use huggingface_hub or datasets library instead)\n"
        "- When you need to process all data (use streaming or local download)\n\n"
        "## Examples\n"
        "// Get first 100 rows from training split\n"
        "{\n"
        '  "dataset": "facebook/research-plan-gen",\n'
        '  "split": "train",\n'
        '  "config": "arxiv",\n'
        '  "offset": 0,\n'
        '  "length": 100\n'
        "}\n\n"
        "// Get next batch (rows 100-200)\n"
        "{\n"
        '  "dataset": "facebook/research-plan-gen",\n'
        '  "split": "train",\n'
        '  "offset": 100,\n'
        '  "length": 100\n'
        "}"
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "dataset": {
                "type": "string",
                "description": "Dataset identifier in format 'org/dataset-name' (e.g., 'facebook/research-plan-gen'). Required.",
            },
            "split": {
                "type": "string",
                "description": "Split name (e.g., 'train', 'test', 'validation'). Required.",
            },
            "config": {
                "type": "string",
                "description": "Config name (only needed for datasets with multiple configs). Optional.",
            },
            "offset": {
                "type": "integer",
                "description": "Starting row index (default: 0).",
                "default": 0,
                "minimum": 0,
            },
            "length": {
                # Fixed: the previous description said "max recommended: 1000",
                # but the Datasets Server /rows endpoint allows at most 100
                # rows per request.
                "type": "integer",
                "description": "Number of rows to fetch (default: 100; the Datasets Server API allows at most 100 rows per request).",
                "default": 100,
                "minimum": 1,
                "maximum": 100,
            },
        },
        "required": ["dataset", "split"],
    },
}
259
+
260
+
261
async def hf_datasets_list_splits_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
    """Adapt list_splits into the (text, success) shape tool handlers return."""
    try:
        outcome = list_splits(dataset=arguments["dataset"])
        # Success flag is the inverse of the ToolResult error marker.
        return outcome["formatted"], not outcome.get("isError", False)
    except Exception as exc:
        return f"Error: {str(exc)}", False
268
+
269
+
270
async def hf_datasets_download_rows_handler(arguments: Dict[str, Any]) -> tuple[str, bool]:
    """Adapt download_rows into the (text, success) shape tool handlers return."""
    try:
        outcome = download_rows(
            dataset=arguments["dataset"],
            split=arguments["split"],
            config=arguments.get("config"),
            offset=arguments.get("offset", 0),
            length=arguments.get("length", 100),
        )
        # Success flag is the inverse of the ToolResult error marker.
        return outcome["formatted"], not outcome.get("isError", False)
    except Exception as exc:
        return f"Error: {str(exc)}", False
283
+