#!/usr/bin/env python3
# Initial GrandLine implementation: deterministic shard-first dataset
# preprocessing for LLM pretraining.
"""inspect_dataset.py — Inspect a dataset to understand its structure before processing.

Shows:
- Column names and types
- Row counts
- Sample rows
- Available metadata columns
- Value distributions for score/label columns

Usage:
    python scripts/inspect_dataset.py --source path/to/data/
    python scripts/inspect_dataset.py --hf-repo "HuggingFaceFW/fineweb-edu-score-2" --split train
"""
from __future__ import annotations

import json
import sys
from pathlib import Path

import click
import pyarrow.parquet as pq

# Make the project's src/ package importable when run directly as a script.
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
@click.command()
@click.option("--source", default=None, help="Local parquet file or directory to inspect.")
@click.option("--hf-repo", default=None, help="HuggingFace dataset repo id to inspect (streaming).")
@click.option("--split", default="train", show_default=True, help="Dataset split (HF only).")
@click.option("--sample", default=5, show_default=True, help="Number of sample rows to display.")
@click.option("--text-column", default="text", show_default=True, help="Name of the text column.")
def main(
    source: str | None,
    hf_repo: str | None,
    split: str,
    sample: int,
    text_column: str,
) -> None:
    """Inspect a dataset's structure and content.

    Exactly one of --source / --hf-repo selects the backend; with neither,
    print an error to stderr and exit 1.

    NOTE(review): main() is invoked bare under __main__ but declared five
    required parameters — the click decorators were evidently missing, so the
    script crashed with TypeError. Restored here; the defaults for --split,
    --sample and --text-column are reasonable reconstructions — confirm
    against the original CLI.
    """
    if source:
        inspect_local(source, sample, text_column)
    elif hf_repo:
        inspect_hf(hf_repo, split, sample, text_column)
    else:
        click.echo("Error: provide --source or --hf-repo", err=True)
        sys.exit(1)
def inspect_local(path_str: str, num_samples: int, text_column: str) -> None:
    """Inspect a local parquet file or directory of parquet files.

    Prints the Arrow schema, row counts (first file + total across files),
    the first *num_samples* rows, and per-column statistics, all taken from
    the first file. Exits 1 if the path is invalid or holds no parquet files.

    Args:
        path_str: Path to a .parquet file or a directory containing them.
        num_samples: Number of sample rows to display.
        text_column: Column treated as free text (truncated for display,
            length stats instead of numeric stats).
    """
    path = Path(path_str)
    if path.is_file():
        files = [path]
    elif path.is_dir():
        files = sorted(path.glob("*.parquet"))
    else:
        click.echo(f"Error: {path} is not a file or directory", err=True)
        sys.exit(1)
    if not files:
        click.echo(f"No parquet files found in {path}", err=True)
        sys.exit(1)

    click.echo(f"\n{'='*60}")
    click.echo("LOCAL DATASET INSPECTION")
    click.echo(f"{'='*60}")
    click.echo(f"Path: {path}")
    click.echo(f"Files: {len(files)} parquet file(s)")
    click.echo()

    # Schema comes from the first file only; other files are assumed to match.
    first_file = files[0]
    pf = pq.ParquetFile(str(first_file))
    click.echo(f"Schema (from {first_file.name}):")
    click.echo(f"{'─'*40}")
    for field in pf.schema_arrow:
        click.echo(f"  {field.name}: {field.type}")

    click.echo("\nMetadata:")
    click.echo(f"  Rows in first file: {pf.metadata.num_rows:,}")
    click.echo(f"  Row groups: {pf.metadata.num_row_groups}")

    # read_metadata only touches the parquet footer — cheap even for big files.
    total_rows = sum(pq.read_metadata(str(f)).num_rows for f in files)
    click.echo(f"  Total rows (all files): {total_rows:,}")

    # Read the first file ONCE and reuse the table for both sampling and
    # statistics (the original called pq.read_table twice, doubling I/O
    # and peak memory).
    table_full = pq.read_table(str(first_file))
    columns = table_full.column_names

    click.echo(f"\nSample rows (first {num_samples}):")
    click.echo(f"{'─'*40}")
    sample_table = table_full.slice(0, num_samples)
    for i in range(min(num_samples, sample_table.num_rows)):
        click.echo(f"\n  [Row {i}]")
        for col in columns:
            value = sample_table.column(col)[i].as_py()
            if col == text_column and isinstance(value, str):
                click.echo(f"    {col}: {_truncate_for_display(value)}")
            else:
                click.echo(f"    {col}: {value}")

    click.echo("\nColumn statistics:")
    click.echo(f"{'─'*40}")
    for col in columns:
        if col == text_column:
            _echo_text_stats(col, table_full.column(col).to_pylist())
        else:
            _echo_numeric_stats(col, table_full.column(col))


def _truncate_for_display(value: str, limit: int = 200) -> str:
    """Return *value* capped at *limit* chars (with '...') and newlines escaped."""
    display = value[:limit] + "..." if len(value) > limit else value
    return display.replace("\n", "\\n")


def _echo_text_stats(col: str, texts: list) -> None:
    """Print min/max/mean character-length stats for a text column (skips nulls)."""
    lengths = [len(t) for t in texts if t is not None]
    if lengths:
        click.echo(
            f"  {col} (text): "
            f"min_len={min(lengths)}, "
            f"max_len={max(lengths)}, "
            f"mean_len={sum(lengths)//len(lengths)}"
        )


def _echo_numeric_stats(col: str, arr) -> None:
    """Print min/max/mean/null-count for a numeric column; mark others non-numeric."""
    try:
        values = arr.to_pylist()
        # bool is a subclass of int in Python — exclude it so boolean flag
        # columns are not reported with misleading min/max/mean numbers.
        numeric = [
            v
            for v in values
            if isinstance(v, (int, float)) and not isinstance(v, bool)
        ]
        if numeric:
            click.echo(
                f"  {col}: "
                f"min={min(numeric):.4g}, "
                f"max={max(numeric):.4g}, "
                f"mean={sum(numeric)/len(numeric):.4g}, "
                f"nulls={values.count(None)}"
            )
    except Exception:
        # Nested/struct/list columns can fail conversion or comparison;
        # this is an inspection tool, so report and move on.
        click.echo(f"  {col}: (non-numeric)")
def inspect_hf(
    repo_id: str, split: str, num_samples: int, text_column: str
) -> None:
    """Inspect a HuggingFace dataset via streaming (no full download).

    Prints columns, features (when the schema is resolvable), and the first
    *num_samples* rows; the *text_column* value is truncated for display.

    Args:
        repo_id: HF hub dataset repo id, e.g. "HuggingFaceFW/fineweb-edu-score-2".
        split: Split name to stream, e.g. "train".
        num_samples: Number of rows to show.
        text_column: Column treated as free text.
    """
    # Imported lazily so local-only inspection works without `datasets`.
    from datasets import load_dataset

    click.echo(f"\n{'='*60}")
    click.echo("HF DATASET INSPECTION")
    click.echo(f"{'='*60}")
    click.echo(f"Repo: {repo_id}")
    click.echo(f"Split: {split}")
    click.echo()

    ds = load_dataset(repo_id, split=split, streaming=True)
    click.echo(f"Columns: {ds.column_names}")
    # streaming=True yields an IterableDataset; `features` may be None when
    # the schema isn't resolved up front — calling .items() on it would
    # raise AttributeError.
    if ds.features is not None:
        click.echo("Features:")
        for name, feat in ds.features.items():
            click.echo(f"  {name}: {feat}")
    else:
        click.echo("Features: (unknown until rows are streamed)")

    click.echo(f"\nSample rows (first {num_samples}):")
    click.echo(f"{'─'*40}")
    for i, row in enumerate(ds):
        if i >= num_samples:
            break
        click.echo(f"\n  [Row {i}]")
        for key, value in row.items():
            if key == text_column and isinstance(value, str):
                display = value[:200] + "..." if len(value) > 200 else value
                display = display.replace("\n", "\\n")
                click.echo(f"  {key}: {display}")
            else:
                click.echo(f"  {key}: {value}")
# Script entry point. NOTE(review): main() is called with no arguments here
# although it declares required parameters — it presumably carries click
# decorators that handle CLI parsing; confirm they are present above.
if __name__ == "__main__":
    main()