#!/usr/bin/env python3
"""inspect_dataset.py — Inspect a dataset to understand its structure before processing.

Shows:
  - Column names and types
  - Row counts
  - Sample rows
  - Available metadata columns
  - Value distributions for score/label columns

Usage:
  python scripts/inspect_dataset.py --source path/to/data/
  python scripts/inspect_dataset.py --hf-repo "HuggingFaceFW/fineweb-edu-score-2" --split train
"""

from __future__ import annotations

import json
import sys
from itertools import islice
from pathlib import Path

import click
import pyarrow.parquet as pq

sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

# Maximum number of characters of the text column shown per sample row.
_TEXT_PREVIEW_CHARS = 200

# Horizontal rules used for the banner and for section headings.
_BANNER_RULE = "=" * 60
_SECTION_RULE = "─" * 40


@click.command()
@click.option("--source", help="Path to local parquet file or directory.")
@click.option("--hf-repo", help="HuggingFace dataset repo ID.")
@click.option("--split", default="train", help="Dataset split (for HF datasets).")
@click.option("--sample", default=5, help="Number of sample rows to show.")
@click.option("--text-column", default="text", help="Name of the text column.")
def main(
    source: str | None,
    hf_repo: str | None,
    split: str,
    sample: int,
    text_column: str,
) -> None:
    """Inspect a dataset's structure and content."""
    # --source takes precedence when both options are supplied.
    if source:
        inspect_local(source, sample, text_column)
    elif hf_repo:
        inspect_hf(hf_repo, split, sample, text_column)
    else:
        click.echo("Error: provide --source or --hf-repo", err=True)
        sys.exit(1)


def _format_value(column: str, value: object, text_column: str) -> str:
    """Return the display form of a single cell.

    String values of the text column are cut to ``_TEXT_PREVIEW_CHARS``
    characters (with a trailing ``...`` when cut) and have newlines escaped so
    each cell stays on one output line. Every other value is formatted as-is.
    """
    if column == text_column and isinstance(value, str):
        preview = value
        if len(value) > _TEXT_PREVIEW_CHARS:
            preview = value[:_TEXT_PREVIEW_CHARS] + "..."
        return preview.replace("\n", "\\n")
    return f"{value}"


def _echo_row(index: int, row: dict[str, object], text_column: str) -> None:
    """Print one sample row as a ``[Row i]`` header followed by its cells."""
    click.echo(f"\n [Row {index}]")
    for column, value in row.items():
        click.echo(f" {column}: {_format_value(column, value, text_column)}")


def _summarize_column(name: str, values: list, is_text: bool) -> str:
    """Build the one-line statistics summary for a column.

    Args:
        name: Column name, used as the line's label.
        values: All values of the column as Python objects (may contain None).
        is_text: True when this is the designated text column; length
            statistics are reported instead of value statistics.

    Returns:
        The formatted summary line. A column without usable values still gets
        a line, so that no column silently disappears from the report.
    """
    if is_text:
        # Only measure real text; len() on e.g. an int would raise TypeError
        # if the user pointed --text-column at a non-string column.
        lengths = [len(v) for v in values if isinstance(v, (str, bytes))]
        if not lengths:
            return f" {name} (text): (no text values)"
        return (
            f" {name} (text): "
            f"min_len={min(lengths)}, "
            f"max_len={max(lengths)}, "
            f"mean_len={sum(lengths) // len(lengths)}"
        )

    numeric = [v for v in values if isinstance(v, (int, float))]
    if not numeric:
        return f" {name}: (non-numeric)"
    return (
        f" {name}: "
        f"min={min(numeric):.4g}, "
        f"max={max(numeric):.4g}, "
        f"mean={sum(numeric) / len(numeric):.4g}, "
        f"nulls={values.count(None)}"
    )


def _find_parquet_files(path: Path) -> list[Path]:
    """Resolve ``path`` to a non-empty list of parquet files.

    A file path is returned as a single-element list; a directory yields its
    ``*.parquet`` entries (top level only), sorted by name. Prints an error
    and exits with status 1 when the path is neither a file nor a directory,
    or when no files are found.
    """
    if path.is_file():
        files = [path]
    elif path.is_dir():
        files = sorted(path.glob("*.parquet"))
    else:
        click.echo(f"Error: {path} is not a file or directory", err=True)
        sys.exit(1)

    if not files:
        click.echo(f"No parquet files found in {path}", err=True)
        sys.exit(1)
    return files


def inspect_local(path_str: str, num_samples: int, text_column: str) -> None:
    """Inspect a local parquet file or directory.

    Prints the schema and row-group metadata of the first file, the total row
    count across all files, the first ``num_samples`` rows of the first file,
    and per-column statistics computed over the first file.

    Args:
        path_str: Path to a parquet file or to a directory of parquet files.
        num_samples: Number of sample rows to print; negative values are
            treated as 0.
        text_column: Name of the column whose values are truncated in the
            sample output and summarized by length instead of by value.
    """
    num_samples = max(num_samples, 0)
    path = Path(path_str)
    files = _find_parquet_files(path)

    click.echo(f"\n{_BANNER_RULE}")
    click.echo("LOCAL DATASET INSPECTION")
    click.echo(_BANNER_RULE)
    click.echo(f"Path: {path}")
    click.echo(f"Files: {len(files)} parquet file(s)")
    click.echo()

    # Schema and row-group details come from the first file only.
    first_file = files[0]
    pf = pq.ParquetFile(str(first_file))
    click.echo(f"Schema (from {first_file.name}):")
    click.echo(_SECTION_RULE)
    for field in pf.schema_arrow:
        click.echo(f" {field.name}: {field.type}")

    click.echo("\nMetadata:")
    click.echo(f" Rows in first file: {pf.metadata.num_rows:,}")
    click.echo(f" Row groups: {pf.metadata.num_row_groups}")

    # Row counts are taken from each file's footer metadata.
    total_rows = sum(pq.read_metadata(str(f)).num_rows for f in files)
    click.echo(f" Total rows (all files): {total_rows:,}")

    # Read the first file once; both the sample rows and the statistics below
    # are derived from this single table.
    first_table = pq.read_table(str(first_file))
    columns = first_table.column_names

    click.echo(f"\nSample rows (first {num_samples}):")
    click.echo(_SECTION_RULE)
    sample = first_table.slice(0, num_samples)
    sample_values = {name: sample.column(name).to_pylist() for name in columns}
    for index in range(sample.num_rows):
        row = {name: sample_values[name][index] for name in columns}
        _echo_row(index, row, text_column)

    click.echo("\nColumn statistics:")
    click.echo(_SECTION_RULE)
    for name in columns:
        # Best effort per column: one column that cannot be converted or
        # formatted must not abort the report, but the reason is shown.
        try:
            values = first_table.column(name).to_pylist()
            line = _summarize_column(name, values, is_text=(name == text_column))
        except Exception as exc:
            line = f" {name}: (statistics unavailable: {exc})"
        click.echo(line)


def inspect_hf(
    repo_id: str, split: str, num_samples: int, text_column: str
) -> None:
    """Inspect a HuggingFace dataset.

    Opens the dataset in streaming mode and prints its column names, its
    declared features (when available) and the first ``num_samples`` rows.

    Args:
        repo_id: HuggingFace dataset repo ID.
        split: Name of the split to load.
        num_samples: Number of sample rows to print; negative values are
            treated as 0.
        text_column: Name of the column whose string values are truncated in
            the sample output.
    """
    # Imported here so `datasets` is only needed when this function is called.
    from datasets import load_dataset

    num_samples = max(num_samples, 0)

    click.echo(f"\n{_BANNER_RULE}")
    click.echo("HF DATASET INSPECTION")
    click.echo(_BANNER_RULE)
    click.echo(f"Repo: {repo_id}")
    click.echo(f"Split: {split}")
    click.echo()

    ds = load_dataset(repo_id, split=split, streaming=True)

    click.echo(f"Columns: {ds.column_names}")
    click.echo("Features:")
    # A streaming dataset may not declare its features up front, in which
    # case `features` is None and there is nothing to iterate.
    features = ds.features
    if features is None:
        click.echo(" (not declared by the dataset; see sample rows below)")
    else:
        for name, feat in features.items():
            click.echo(f" {name}: {feat}")

    click.echo(f"\nSample rows (first {num_samples}):")
    click.echo(_SECTION_RULE)
    # islice stops after exactly num_samples rows, so no extra row is pulled
    # from the stream (and none at all when num_samples is 0).
    for index, row in enumerate(islice(ds, num_samples)):
        _echo_row(index, row, text_column)


if __name__ == "__main__":
    main()