# grandline/scripts/inspect_dataset.py
# Initial GrandLine implementation: deterministic shard-first dataset
# preprocessing for LLM pretraining (commit ed59144, author dignity045)
#!/usr/bin/env python3
"""inspect_dataset.py — Inspect a dataset to understand its structure before processing.
Shows:
- Column names and types
- Row counts
- Sample rows
- Available metadata columns
- Value distributions for score/label columns
Usage:
python scripts/inspect_dataset.py --source path/to/data/
python scripts/inspect_dataset.py --hf-repo "HuggingFaceFW/fineweb-edu-score-2" --split train
"""
from __future__ import annotations

import json  # NOTE(review): imported but unused in this file — confirm before removing
import sys
from pathlib import Path

import click
import pyarrow.parquet as pq

# Make the project's src/ package importable when running this script
# directly (scripts/ sits next to src/ in the repo layout).
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
@click.command()
@click.option("--source", help="Path to local parquet file or directory.")
@click.option("--hf-repo", help="HuggingFace dataset repo ID.")
@click.option("--split", default="train", help="Dataset split (for HF datasets).")
@click.option("--sample", default=5, help="Number of sample rows to show.")
@click.option("--text-column", default="text", help="Name of the text column.")
def main(
    source: str | None,
    hf_repo: str | None,
    split: str,
    sample: int,
    text_column: str,
) -> None:
    """Inspect a dataset's structure and content.

    Exactly one of --source / --hf-repo is expected; --source takes
    precedence when both are supplied. Exits with status 1 when neither
    is given.
    """
    # Guard clause: bail out early when no input was specified at all.
    if not source and not hf_repo:
        click.echo("Error: provide --source or --hf-repo", err=True)
        sys.exit(1)
    if source:
        inspect_local(source, sample, text_column)
    else:
        inspect_hf(hf_repo, split, sample, text_column)
def inspect_local(path_str: str, num_samples: int, text_column: str) -> None:
    """Inspect a local parquet file or directory.

    Prints the schema, row counts (from parquet footer metadata only),
    a few sample rows, and per-column statistics taken from the first
    parquet file found.

    Args:
        path_str: Path to a single parquet file or a directory of them.
        num_samples: Number of sample rows to display.
        text_column: Column treated as document text — shown truncated
            and summarized with length statistics instead of numeric ones.

    Exits with status 1 when the path is invalid or contains no parquet
    files.
    """
    path = Path(path_str)
    files = _resolve_parquet_files(path)

    click.echo(f"\n{'='*60}")
    click.echo("LOCAL DATASET INSPECTION")
    click.echo(f"{'='*60}")
    click.echo(f"Path: {path}")
    click.echo(f"Files: {len(files)} parquet file(s)")
    click.echo()

    # Schema/metadata come from the first shard; later shards are assumed
    # to share the same schema.
    first_file = files[0]
    pf = pq.ParquetFile(str(first_file))
    click.echo(f"Schema (from {first_file.name}):")
    click.echo(f"{'─'*40}")
    for field in pf.schema_arrow:
        click.echo(f" {field.name}: {field.type}")

    click.echo("\nMetadata:")
    click.echo(f" Rows in first file: {pf.metadata.num_rows:,}")
    click.echo(f" Row groups: {pf.metadata.num_row_groups}")

    # Footer metadata only — no data pages are read for the total count.
    total_rows = sum(pq.read_metadata(str(f)).num_rows for f in files)
    click.echo(f" Total rows (all files): {total_rows:,}")

    # Read the first file once and reuse it for both samples and stats
    # (the original code read it twice).
    table_full = pq.read_table(str(first_file))

    click.echo(f"\nSample rows (first {num_samples}):")
    click.echo(f"{'─'*40}")
    _print_sample_rows(table_full.slice(0, num_samples), num_samples, text_column)

    click.echo("\nColumn statistics:")
    click.echo(f"{'─'*40}")
    _print_column_stats(table_full, text_column)


def _resolve_parquet_files(path: Path) -> list[Path]:
    """Return the parquet files behind *path*, or exit(1) with a message."""
    if path.is_file():
        return [path]
    if path.is_dir():
        files = sorted(path.glob("*.parquet"))
        if not files:
            click.echo(f"No parquet files found in {path}", err=True)
            sys.exit(1)
        return files
    click.echo(f"Error: {path} is not a file or directory", err=True)
    sys.exit(1)


def _truncate_text(value: str) -> str:
    """Shorten long text and escape newlines for one-line display."""
    shown = value[:200] + "..." if len(value) > 200 else value
    return shown.replace("\n", "\\n")


def _print_sample_rows(table, num_samples: int, text_column: str) -> None:
    """Print up to *num_samples* rows from *table*, truncating text values."""
    for i in range(min(num_samples, table.num_rows)):
        click.echo(f"\n [Row {i}]")
        for col in table.column_names:
            value = table.column(col)[i].as_py()
            if col == text_column and isinstance(value, str):
                click.echo(f" {col}: {_truncate_text(value)}")
            else:
                click.echo(f" {col}: {value}")


def _print_column_stats(table, text_column: str) -> None:
    """Print per-column summary statistics for *table*."""
    for col in table.column_names:
        if col == text_column:
            # Text column: length stats over non-null values.
            lengths = [len(t) for t in table.column(col).to_pylist() if t is not None]
            if lengths:
                click.echo(
                    f" {col} (text): "
                    f"min_len={min(lengths)}, "
                    f"max_len={max(lengths)}, "
                    f"mean_len={sum(lengths)//len(lengths)}"
                )
            continue
        try:
            values = table.column(col).to_pylist()
            numeric = [v for v in values if isinstance(v, (int, float))]
            if numeric:
                click.echo(
                    f" {col}: "
                    f"min={min(numeric):.4g}, "
                    f"max={max(numeric):.4g}, "
                    f"mean={sum(numeric)/len(numeric):.4g}, "
                    f"nulls={values.count(None)}"
                )
            else:
                # BUG FIX: columns holding no numeric values (strings,
                # lists, structs) convert cleanly via to_pylist() and so
                # never hit the except below — they were silently skipped.
                click.echo(f" {col}: (non-numeric)")
        except Exception:
            click.echo(f" {col}: (non-numeric)")
def inspect_hf(
    repo_id: str, split: str, num_samples: int, text_column: str
) -> None:
    """Inspect a HuggingFace dataset via streaming (no full download).

    Args:
        repo_id: HF Hub dataset repository ID.
        split: Dataset split to stream.
        num_samples: Number of sample rows to display.
        text_column: Column treated as document text (truncated display).
    """
    from datasets import load_dataset

    click.echo(f"\n{'='*60}")
    click.echo("HF DATASET INSPECTION")
    click.echo(f"{'='*60}")
    click.echo(f"Repo: {repo_id}")
    click.echo(f"Split: {split}")
    click.echo()

    ds = load_dataset(repo_id, split=split, streaming=True)
    click.echo(f"Columns: {ds.column_names}")
    # BUG FIX: streaming datasets may not resolve features until iterated,
    # in which case ds.features is None and .items() would raise.
    if ds.features is not None:
        click.echo("Features:")
        for name, feat in ds.features.items():
            click.echo(f" {name}: {feat}")
    else:
        click.echo("Features: (not resolved for this streaming dataset)")

    click.echo(f"\nSample rows (first {num_samples}):")
    click.echo(f"{'─'*40}")
    for i, row in enumerate(ds):
        if i >= num_samples:
            break
        click.echo(f"\n [Row {i}]")
        for key, value in row.items():
            if key == text_column and isinstance(value, str):
                # Truncate and flatten the text for one-line display.
                display = value[:200] + "..." if len(value) > 200 else value
                display = display.replace("\n", "\\n")
                click.echo(f" {key}: {display}")
            else:
                click.echo(f" {key}: {value}")
# Script entry point: delegate to the click command.
if __name__ == "__main__":
    main()