razvan commited on
Commit
3ab9f4d
·
verified ·
1 Parent(s): 5f7ed04

Upload plugins/mlintern/skills/ml-intern-harness/scripts/inspect_dataset.py with huggingface_hub

Browse files
plugins/mlintern/skills/ml-intern-harness/scripts/inspect_dataset.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Inspect a Hugging Face dataset using the Dataset Viewer API.
3
+
4
+ This mirrors the useful parts of upstream ml-intern's hf_inspect_dataset tool:
5
+ status, configs/splits, schema, sample rows, and parquet file availability.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import argparse
11
+ import json
12
+ import os
13
+ import sys
14
+ import urllib.error
15
+ import urllib.parse
16
+ import urllib.request
17
+ from concurrent.futures import ThreadPoolExecutor, as_completed
18
+ from typing import Any
19
+
20
+
21
+ BASE_URL = "https://datasets-server.huggingface.co"
22
+ MAX_SAMPLE_VALUE_LEN = 150
23
+
24
+
25
+ def fetch_json(path: str, params: dict[str, Any], token: str | None) -> dict[str, Any]:
26
+ query = urllib.parse.urlencode({k: v for k, v in params.items() if v is not None})
27
+ request = urllib.request.Request(f"{BASE_URL}{path}?{query}")
28
+ if token:
29
+ request.add_header("Authorization", f"Bearer {token}")
30
+ try:
31
+ with urllib.request.urlopen(request, timeout=30) as response:
32
+ return json.loads(response.read().decode("utf-8"))
33
+ except urllib.error.HTTPError as exc:
34
+ body = exc.read().decode("utf-8", errors="replace")
35
+ raise RuntimeError(f"{path} returned HTTP {exc.code}: {body[:500]}") from exc
36
+
37
+
38
+ def type_name(feature: Any) -> str:
39
+ if isinstance(feature, str):
40
+ return feature
41
+ if not isinstance(feature, dict):
42
+ return type(feature).__name__
43
+ feature_type = feature.get("_type")
44
+ if feature_type == "ClassLabel":
45
+ names = feature.get("names") or []
46
+ if 0 < len(names) <= 5:
47
+ values = ", ".join(f"{name}={idx}" for idx, name in enumerate(names))
48
+ return f"ClassLabel ({values})"
49
+ return f"ClassLabel ({len(names)} classes)"
50
+ if feature_type:
51
+ return feature_type
52
+ if "dtype" in feature:
53
+ return str(feature["dtype"])
54
+ return json.dumps(feature, ensure_ascii=False)[:120]
55
+
56
+
57
+ def extract_configs(splits_data: dict[str, Any]) -> list[dict[str, Any]]:
58
+ configs: dict[str, dict[str, Any]] = {}
59
+ for item in splits_data.get("splits", []):
60
+ config = item.get("config", "default")
61
+ split = item.get("split", "train")
62
+ row_count = item.get("num_rows") or item.get("num_examples")
63
+ configs.setdefault(config, {"name": config, "splits": []})
64
+ configs[config]["splits"].append({"name": split, "rows": row_count})
65
+ return list(configs.values())
66
+
67
+
68
+ def format_status(data: dict[str, Any]) -> str:
69
+ available = [
70
+ key
71
+ for key in ("viewer", "preview", "search", "filter", "statistics")
72
+ if data.get(key)
73
+ ]
74
+ if available:
75
+ return f"## Status\nValid ({', '.join(available)})"
76
+ return "## Status\nDataset may have Dataset Viewer issues"
77
+
78
+
79
+ def format_structure(configs: list[dict[str, Any]], max_rows: int = 20) -> str:
80
+ lines = ["## Structure (configs & splits)", "| Config | Split | Rows |", "|---|---|---:|"]
81
+ total = sum(len(config["splits"]) for config in configs)
82
+ shown = 0
83
+ for config in configs:
84
+ for split in config["splits"]:
85
+ if shown >= max_rows:
86
+ break
87
+ rows = split["rows"] if split["rows"] is not None else "?"
88
+ lines.append(f"| {config['name']} | {split['name']} | {rows} |")
89
+ shown += 1
90
+ if shown >= max_rows:
91
+ break
92
+ if total > shown:
93
+ lines.append("| ... | ... | ... |")
94
+ lines.append(f"_Showing {shown} of {total} config/split rows._")
95
+ return "\n".join(lines)
96
+
97
+
98
+ def format_schema(info: dict[str, Any], config: str) -> str:
99
+ features = info.get("dataset_info", {}).get("features", {})
100
+ lines = [f"## Schema ({config})", "| Column | Type |", "|---|---|"]
101
+ if not features:
102
+ lines.append("| (none found) | unknown |")
103
+ for column, feature in features.items():
104
+ lines.append(f"| {column} | {type_name(feature)} |")
105
+ return "\n".join(lines)
106
+
107
+
108
+ def maybe_json(value: Any) -> Any:
109
+ if isinstance(value, str):
110
+ try:
111
+ return json.loads(value)
112
+ except json.JSONDecodeError:
113
+ return value
114
+ return value
115
+
116
+
117
+ def format_messages(messages: Any) -> str | None:
118
+ messages = maybe_json(messages)
119
+ if not isinstance(messages, list) or not messages:
120
+ return None
121
+ roles: set[str] = set()
122
+ keys: set[str] = set()
123
+ has_tool_calls = False
124
+ has_tool_results = False
125
+ example: dict[str, Any] | None = None
126
+ fallback: dict[str, Any] | None = None
127
+ for message in messages:
128
+ if not isinstance(message, dict):
129
+ continue
130
+ keys.update(message.keys())
131
+ if message.get("role"):
132
+ roles.add(str(message["role"]))
133
+ if message.get("tool_calls") or message.get("function_call"):
134
+ has_tool_calls = True
135
+ example = example or message
136
+ if message.get("role") in {"tool", "function"} or message.get("tool_call_id"):
137
+ has_tool_results = True
138
+ if message.get("role") == "assistant":
139
+ example = example or message
140
+ elif message.get("role") != "system":
141
+ fallback = fallback or message
142
+ example = example or fallback
143
+ lines = ["## Messages Column Format"]
144
+ lines.append(f"Roles: {', '.join(sorted(roles)) if roles else 'unknown'}")
145
+ common = ["role", "content", "tool_calls", "tool_call_id", "name", "function_call"]
146
+ lines.append("Message keys: " + ", ".join(f"{key} {'yes' if key in keys else 'no'}" for key in common))
147
+ if has_tool_calls:
148
+ lines.append("Tool calls: present")
149
+ if has_tool_results:
150
+ lines.append("Tool results: present")
151
+ if example:
152
+ cleaned = dict(example)
153
+ content = cleaned.get("content")
154
+ if isinstance(content, str) and len(content) > 100:
155
+ cleaned["content"] = content[:100] + "..."
156
+ lines.append("")
157
+ lines.append("Example message structure:")
158
+ lines.append("```json")
159
+ lines.append(json.dumps(cleaned, indent=2, ensure_ascii=False))
160
+ lines.append("```")
161
+ return "\n".join(lines)
162
+
163
+
164
+ def format_samples(rows_data: dict[str, Any], config: str, split: str, limit: int) -> str:
165
+ rows = rows_data.get("rows", [])[:limit]
166
+ lines = [f"## Sample Rows ({config}/{split})"]
167
+ first_messages: Any = None
168
+ for idx, row_wrapper in enumerate(rows, 1):
169
+ row = row_wrapper.get("row", {})
170
+ lines.append(f"**Row {idx}:**")
171
+ for key, value in row.items():
172
+ if key.lower() == "messages" and first_messages is None:
173
+ first_messages = value
174
+ text = str(value)
175
+ if len(text) > MAX_SAMPLE_VALUE_LEN:
176
+ text = text[:MAX_SAMPLE_VALUE_LEN] + "..."
177
+ lines.append(f"- {key}: {text}")
178
+ if not rows:
179
+ lines.append("(no rows returned)")
180
+ message_section = format_messages(first_messages) if first_messages is not None else None
181
+ if message_section:
182
+ lines.append("")
183
+ lines.append(message_section)
184
+ return "\n".join(lines)
185
+
186
+
187
+ def format_parquet(data: dict[str, Any], max_rows: int = 20) -> str | None:
188
+ files = data.get("parquet_files", [])
189
+ if not files:
190
+ return None
191
+ groups: dict[str, dict[str, int]] = {}
192
+ for item in files:
193
+ key = f"{item.get('config', 'default')}/{item.get('split', 'train')}"
194
+ groups.setdefault(key, {"count": 0, "size": 0})
195
+ groups[key]["count"] += 1
196
+ size = item.get("size") or 0
197
+ groups[key]["size"] += int(size) if isinstance(size, (int, float)) else 0
198
+ lines = ["## Files (Parquet)"]
199
+ for key, values in list(groups.items())[:max_rows]:
200
+ size_mb = values["size"] / (1024 * 1024)
201
+ lines.append(f"- {key}: {values['count']} file(s), {size_mb:.1f} MB")
202
+ if len(groups) > max_rows:
203
+ lines.append(f"- ... showing {max_rows} of {len(groups)} groups")
204
+ return "\n".join(lines)
205
+
206
+
207
+ def compatibility_notes(features: dict[str, Any]) -> str:
208
+ columns = set(features)
209
+ lines = ["## Training Compatibility"]
210
+ checks = {
211
+ "SFT": bool({"messages", "text"} & columns or {"prompt", "completion"} <= columns),
212
+ "DPO": {"prompt", "chosen", "rejected"} <= columns,
213
+ "GRPO": "prompt" in columns,
214
+ }
215
+ for method, ok in checks.items():
216
+ lines.append(f"- {method}: {'looks compatible' if ok else 'columns not sufficient'}")
217
+ if "messages" in columns:
218
+ lines.append("- Chat data: inspect the sample message roles before choosing a trainer template.")
219
+ return "\n".join(lines)
220
+
221
+
222
+ def inspect_dataset(dataset: str, config: str | None, split: str | None, sample_rows: int, token: str | None) -> str:
223
+ warnings: list[str] = []
224
+ with ThreadPoolExecutor(max_workers=3) as pool:
225
+ futures = {
226
+ pool.submit(fetch_json, "/is-valid", {"dataset": dataset}, token): "is-valid",
227
+ pool.submit(fetch_json, "/splits", {"dataset": dataset}, token): "splits",
228
+ pool.submit(fetch_json, "/parquet", {"dataset": dataset}, token): "parquet",
229
+ }
230
+ phase1: dict[str, Any] = {}
231
+ for future in as_completed(futures):
232
+ name = futures[future]
233
+ try:
234
+ phase1[name] = future.result()
235
+ except Exception as exc:
236
+ warnings.append(f"{name}: {exc}")
237
+
238
+ configs = extract_configs(phase1.get("splits", {}))
239
+ selected_config = config or (configs[0]["name"] if configs else "default")
240
+ selected_split = split or (configs[0]["splits"][0]["name"] if configs and configs[0]["splits"] else "train")
241
+
242
+ with ThreadPoolExecutor(max_workers=2) as pool:
243
+ futures = {
244
+ pool.submit(fetch_json, "/info", {"dataset": dataset, "config": selected_config}, token): "info",
245
+ pool.submit(
246
+ fetch_json,
247
+ "/first-rows",
248
+ {"dataset": dataset, "config": selected_config, "split": selected_split},
249
+ token,
250
+ ): "first-rows",
251
+ }
252
+ phase2: dict[str, Any] = {}
253
+ for future in as_completed(futures):
254
+ name = futures[future]
255
+ try:
256
+ phase2[name] = future.result()
257
+ except Exception as exc:
258
+ warnings.append(f"{name}: {exc}")
259
+
260
+ features = phase2.get("info", {}).get("dataset_info", {}).get("features", {})
261
+ sections = [f"# {dataset}"]
262
+ if "is-valid" in phase1:
263
+ sections.append(format_status(phase1["is-valid"]))
264
+ if configs:
265
+ sections.append(format_structure(configs))
266
+ if "info" in phase2:
267
+ sections.append(format_schema(phase2["info"], selected_config))
268
+ sections.append(compatibility_notes(features))
269
+ if "first-rows" in phase2:
270
+ sections.append(format_samples(phase2["first-rows"], selected_config, selected_split, sample_rows))
271
+ parquet = format_parquet(phase1.get("parquet", {}))
272
+ if parquet:
273
+ sections.append(parquet)
274
+ if warnings:
275
+ sections.append("## Warnings\n" + "\n".join(f"- {warning}" for warning in warnings))
276
+ return "\n\n".join(sections)
277
+
278
+
279
+ def main() -> int:
280
+ parser = argparse.ArgumentParser(description=__doc__)
281
+ parser.add_argument("dataset", help="Dataset id, for example stanfordnlp/imdb")
282
+ parser.add_argument("--config", help="Config/subset name")
283
+ parser.add_argument("--split", help="Split for sample rows")
284
+ parser.add_argument("--sample-rows", type=int, default=3, help="Number of rows to show, max 10")
285
+ parser.add_argument("--token-env", default="HF_TOKEN", help="Environment variable containing an HF token")
286
+ args = parser.parse_args()
287
+ token = os.environ.get(args.token_env)
288
+ print(inspect_dataset(args.dataset, args.config, args.split, min(args.sample_rows, 10), token))
289
+ return 0
290
+
291
+
292
+ if __name__ == "__main__":
293
+ sys.exit(main())