#!/usr/bin/env python3
"""inspect_dataset.py — Inspect a dataset to understand its structure before processing.
Shows:
- Column names and types
- Row counts
- Sample rows
- Available metadata columns
- Value distributions for score/label columns
Usage:
python scripts/inspect_dataset.py --source path/to/data/
python scripts/inspect_dataset.py --hf-repo "HuggingFaceFW/fineweb-edu-score-2" --split train
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
import click
import pyarrow.parquet as pq
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
@click.command()
@click.option("--source", help="Path to local parquet file or directory.")
@click.option("--hf-repo", help="HuggingFace dataset repo ID.")
@click.option("--split", default="train", help="Dataset split (for HF datasets).")
@click.option("--sample", default=5, help="Number of sample rows to show.")
@click.option("--text-column", default="text", help="Name of the text column.")
def main(
    source: str | None,
    hf_repo: str | None,
    split: str,
    sample: int,
    text_column: str,
) -> None:
    """Inspect a dataset's structure and content.

    Dispatches to the local-parquet inspector when --source is given,
    otherwise to the HuggingFace inspector when --hf-repo is given.
    Exits with status 1 if neither option was supplied.
    """
    # Guard clause: bail out early when no data source was specified.
    if not source and not hf_repo:
        click.echo("Error: provide --source or --hf-repo", err=True)
        sys.exit(1)
    if source:
        inspect_local(source, sample, text_column)
    else:
        inspect_hf(hf_repo, split, sample, text_column)
def inspect_local(path_str: str, num_samples: int, text_column: str) -> None:
    """Inspect a local parquet file or directory.

    Prints the schema, per-file and total row counts, the first
    *num_samples* rows, and simple per-column statistics. Schema, samples
    and statistics all come from the first parquet file only; the total
    row count covers every file.

    Args:
        path_str: Path to a ``.parquet`` file or a directory containing them.
        num_samples: Number of sample rows to display.
        text_column: Column treated as document text (truncated in samples,
            summarized by length in the statistics section).

    Exits with status 1 if the path is invalid or holds no parquet files.
    """
    path = Path(path_str)
    if path.is_file():
        files = [path]
    elif path.is_dir():
        files = sorted(path.glob("*.parquet"))
    else:
        click.echo(f"Error: {path} is not a file or directory", err=True)
        sys.exit(1)
    if not files:
        click.echo(f"No parquet files found in {path}", err=True)
        sys.exit(1)

    click.echo(f"\n{'='*60}")
    click.echo("LOCAL DATASET INSPECTION")
    click.echo(f"{'='*60}")
    click.echo(f"Path: {path}")
    click.echo(f"Files: {len(files)} parquet file(s)")
    click.echo()

    # Schema and row-group info are taken from the first file only.
    first_file = files[0]
    pf = pq.ParquetFile(str(first_file))
    click.echo(f"Schema (from {first_file.name}):")
    click.echo(f"{'─'*40}")
    for field in pf.schema_arrow:
        click.echo(f"  {field.name}: {field.type}")
    click.echo("\nMetadata:")
    click.echo(f"  Rows in first file: {pf.metadata.num_rows:,}")
    click.echo(f"  Row groups: {pf.metadata.num_row_groups}")

    # Footer metadata is cheap to read, so the total spans all files
    # without loading any data pages.
    total_rows = sum(pq.read_metadata(str(f)).num_rows for f in files)
    click.echo(f"  Total rows (all files): {total_rows:,}")

    # Read the first file ONCE and reuse the table for both samples and
    # statistics (previously it was read from disk twice).
    table_full = pq.read_table(str(first_file))
    _show_samples(table_full, num_samples, text_column)
    _show_column_stats(table_full, text_column)


def _show_samples(table_full, num_samples: int, text_column: str) -> None:
    """Print the first *num_samples* rows, truncating the text column."""
    click.echo(f"\nSample rows (first {num_samples}):")
    click.echo(f"{'─'*40}")
    table = table_full.slice(0, num_samples)
    columns = table.column_names
    for i in range(min(num_samples, table.num_rows)):
        click.echo(f"\n  [Row {i}]")
        for col in columns:
            value = table.column(col)[i].as_py()
            if col == text_column and isinstance(value, str):
                # Truncate and escape newlines so each sample stays on
                # one display line.
                display = value[:200] + "..." if len(value) > 200 else value
                display = display.replace("\n", "\\n")
                click.echo(f"    {col}: {display}")
            else:
                click.echo(f"    {col}: {value}")


def _show_column_stats(table_full, text_column: str) -> None:
    """Print length stats for the text column, min/max/mean for numeric ones."""
    click.echo("\nColumn statistics:")
    click.echo(f"{'─'*40}")
    for col in table_full.column_names:
        if col == text_column:
            # Text column: summarize by string length, ignoring nulls.
            texts = table_full.column(col).to_pylist()
            lengths = [len(t) for t in texts if t is not None]
            if lengths:
                click.echo(
                    f"  {col} (text): "
                    f"min_len={min(lengths)}, "
                    f"max_len={max(lengths)}, "
                    f"mean_len={sum(lengths)//len(lengths)}"
                )
        else:
            try:
                values = table_full.column(col).to_pylist()
                numeric = [v for v in values if isinstance(v, (int, float))]
                if numeric:
                    click.echo(
                        f"  {col}: "
                        f"min={min(numeric):.4g}, "
                        f"max={max(numeric):.4g}, "
                        f"mean={sum(numeric)/len(numeric):.4g}, "
                        f"nulls={values.count(None)}"
                    )
                else:
                    # Fix: previously, columns with no numeric values
                    # printed nothing at all; label them explicitly.
                    click.echo(f"  {col}: (non-numeric)")
            except Exception:
                # e.g. to_pylist() failing on an exotic arrow type.
                click.echo(f"  {col}: (non-numeric)")
def inspect_hf(
    repo_id: str, split: str, num_samples: int, text_column: str
) -> None:
    """Inspect a HuggingFace dataset via streaming (no full download).

    Prints the column names and declared features (when the repo exposes
    them), then the first *num_samples* rows with the text column
    truncated for one-line display.

    Args:
        repo_id: HF dataset repo ID, e.g. "HuggingFaceFW/fineweb-edu-score-2".
        split: Split name to stream (e.g. "train").
        num_samples: Number of rows to print.
        text_column: Column to truncate/escape when displaying samples.
    """
    from datasets import load_dataset

    click.echo(f"\n{'='*60}")
    click.echo("HF DATASET INSPECTION")
    click.echo(f"{'='*60}")
    click.echo(f"Repo: {repo_id}")
    click.echo(f"Split: {split}")
    click.echo()
    ds = load_dataset(repo_id, split=split, streaming=True)
    click.echo(f"Columns: {ds.column_names}")
    # Fix: streaming (IterableDataset) may not know its schema up front —
    # `features` can be None, and `.items()` on it raised AttributeError.
    if ds.features is not None:
        click.echo("Features:")
        for name, feat in ds.features.items():
            click.echo(f"  {name}: {feat}")
    else:
        click.echo("Features: (unknown until rows are read)")
    click.echo(f"\nSample rows (first {num_samples}):")
    click.echo(f"{'─'*40}")
    for i, row in enumerate(ds):
        if i >= num_samples:
            break
        click.echo(f"\n  [Row {i}]")
        for key, value in row.items():
            if key == text_column and isinstance(value, str):
                # Truncate and escape newlines for one-line display.
                display = value[:200] + "..." if len(value) > 200 else value
                display = display.replace("\n", "\\n")
                click.echo(f"    {key}: {display}")
            else:
                click.echo(f"    {key}: {value}")
# Run the click CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|