#!/usr/bin/env python3
"""inspect_dataset.py — Inspect a dataset to understand its structure before processing.

Shows:
- Column names and types
- Row counts
- Sample rows
- Available metadata columns
- Basic statistics (min/max/mean) for numeric columns and text lengths

Usage:
    python scripts/inspect_dataset.py --source path/to/data/
    python scripts/inspect_dataset.py --hf-repo "HuggingFaceFW/fineweb-edu-score-2" --split train
"""

from __future__ import annotations

import sys
from pathlib import Path

import click
import pyarrow.parquet as pq

# Make the repo's src/ importable when this script is run directly.
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))


@click.command()
@click.option("--source", help="Path to local parquet file or directory.")
@click.option("--hf-repo", help="HuggingFace dataset repo ID.")
@click.option("--split", default="train", help="Dataset split (for HF datasets).")
@click.option("--sample", default=5, help="Number of sample rows to show.")
@click.option("--text-column", default="text", help="Name of the text column.")
def main(
    source: str | None,
    hf_repo: str | None,
    split: str,
    sample: int,
    text_column: str,
) -> None:
    """Inspect a dataset's structure and content."""
    if source:
        inspect_local(source, sample, text_column)
    elif hf_repo:
        inspect_hf(hf_repo, split, sample, text_column)
    else:
        click.echo("Error: provide --source or --hf-repo", err=True)
        sys.exit(1)


def inspect_local(path_str: str, num_samples: int, text_column: str) -> None:
    """Inspect a local parquet file or directory."""
    path = Path(path_str)

    if path.is_file():
        files = [path]
    elif path.is_dir():
        files = sorted(path.glob("*.parquet"))
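        # Non-recursive on purpose: use path.rglob("*.parquet") instead if
        # shards live in nested subdirectories.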
    else:
        click.echo(f"Error: {path} is not a file or directory", err=True)
        sys.exit(1)

    if not files:
        click.echo(f"No parquet files found in {path}", err=True)
        sys.exit(1)

    click.echo(f"\n{'='*60}")
    click.echo(f"LOCAL DATASET INSPECTION")
    click.echo(f"{'='*60}")
    click.echo(f"Path: {path}")
    click.echo(f"Files: {len(files)} parquet file(s)")
    click.echo()

    # Read first file for schema
    first_file = files[0]
    pf = pq.ParquetFile(str(first_file))

    click.echo(f"Schema (from {first_file.name}):")
    click.echo(f"{'─'*40}")
    for field in pf.schema_arrow:
        click.echo(f"  {field.name}: {field.type}")

    click.echo(f"\nMetadata:")
    click.echo(f"  Rows in first file: {pf.metadata.num_rows:,}")
    click.echo(f"  Row groups: {pf.metadata.num_row_groups}")

    # Total rows across all files
    total_rows = 0
    for f in files:
        meta = pq.read_metadata(str(f))
        total_rows += meta.num_rows
    click.echo(f"  Total rows (all files): {total_rows:,}")

    # Sample rows
    click.echo(f"\nSample rows (first {num_samples}):")
    click.echo(f"{'─'*40}")
    # Read the first file once; the full table is reused for statistics below.
    table_full = pq.read_table(str(first_file))
    table = table_full.slice(0, num_samples)
    columns = table.column_names

    for i in range(min(num_samples, table.num_rows)):
        click.echo(f"\n  [Row {i}]")
        for col in columns:
            value = table.column(col)[i].as_py()
            if col == text_column and isinstance(value, str):
                # Truncate text for display
                display = value[:200] + "..." if len(value) > 200 else value
                display = display.replace("\n", "\\n")
                click.echo(f"    {col}: {display}")
            else:
                click.echo(f"    {col}: {value}")

    # Column statistics for numeric columns
    click.echo(f"\nColumn statistics:")
    click.echo(f"{'─'*40}")
    for col in columns:
        if col == text_column:
            # Text column: show length stats
            texts = table_full.column(col).to_pylist()
            lengths = [len(t) for t in texts if t is not None]
            if lengths:
                click.echo(
                    f"  {col} (text): "
                    f"min_len={min(lengths)}, "
                    f"max_len={max(lengths)}, "
                    f"mean_len={sum(lengths)//len(lengths)}"
                )
        else:
            arr = table_full.column(col)
            try:
                values = arr.to_pylist()
                # bool is an int subclass; exclude it from numeric stats.
                numeric = [
                    v
                    for v in values
                    if isinstance(v, (int, float)) and not isinstance(v, bool)
                ]
                if numeric:
                    click.echo(
                        f"  {col}: "
                        f"min={min(numeric):.4g}, "
                        f"max={max(numeric):.4g}, "
                        f"mean={sum(numeric)/len(numeric):.4g}, "
                        f"nulls={values.count(None)}"
                    )
                else:
                    click.echo(f"  {col}: (non-numeric)")
            except Exception:
                click.echo(f"  {col}: (unreadable)")


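# Hedged sketch, not wired into the CLI: the same numeric statistics computed
# with pyarrow.compute kernels, so columns never round-trip through Python
# lists. _numeric_stats_arrow is a hypothetical helper name.
def _numeric_stats_arrow(table, col: str) -> str | None:
    """Return a stats summary line for a numeric column, else None."""
    import pyarrow as pa
    import pyarrow.compute as pc

    arr = table.column(col)
    try:
        mm = pc.min_max(arr).as_py()  # {"min": ..., "max": ...}
        mean = pc.mean(arr).as_py()
    except (pa.ArrowNotImplementedError, pa.ArrowInvalid):
        return None  # kernels are not defined for this column type
    if mean is None:  # all-null column
        return None
    return (
        f"min={mm['min']:.4g}, max={mm['max']:.4g}, "
        f"mean={mean:.4g}, nulls={arr.null_count}"
    )

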
def inspect_hf(
    repo_id: str, split: str, num_samples: int, text_column: str
) -> None:
    """Inspect a HuggingFace dataset."""
    from datasets import load_dataset

    click.echo(f"\n{'='*60}")
    click.echo(f"HF DATASET INSPECTION")
    click.echo(f"{'='*60}")
    click.echo(f"Repo: {repo_id}")
    click.echo(f"Split: {split}")
    click.echo()

    ds = load_dataset(repo_id, split=split, streaming=True)

    click.echo(f"Columns: {ds.column_names}")
    click.echo(f"Features:")
    for name, feat in ds.features.items():
        click.echo(f"  {name}: {feat}")

    click.echo(f"\nSample rows (first {num_samples}):")
    click.echo(f"{'─'*40}")

    for i, row in enumerate(ds):
        if i >= num_samples:
            break
        click.echo(f"\n  [Row {i}]")
        for key, value in row.items():
            if key == text_column and isinstance(value, str):
                display = value[:200] + "..." if len(value) > 200 else value
                display = display.replace("\n", "\\n")
                click.echo(f"    {key}: {display}")
            else:
                click.echo(f"    {key}: {value}")


if __name__ == "__main__":
    main()