| """Tokenization, manifesting, and Parquet sharding.""" |
|
|
| from __future__ import annotations |
|
|
| import hashlib |
| import json |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Iterable |
|
|
| try: |
| import pyarrow as pa |
| import pyarrow.parquet as pq |
| except ImportError: |
| pa = None |
| pq = None |
|
|
|
|
# Names, in on-disk order, of the columns every shard row carries.  The same
# tuple is published as the "schema" entry of the manifest.
SCHEMA_COLUMNS = (
    "id",
    "text",
    "tokens",
    "domain_tag",
    "quality_tier",
    "lang",
    "token_count",
    "split",
)
|
|
|
|
@dataclass(frozen=True)
class ShardConfig:
    """Parameters for Parquet shard writing.

    Attributes:
        output_dir: Directory that receives the shard files and the manifest.
        shard_size: Number of records buffered before a shard is flushed.
        validation_ratio: Fraction of the id space assigned to "validation".
        test_ratio: Fraction of the id space assigned to "test".

    Raises:
        ValueError: If ``shard_size`` is smaller than 1, if either ratio lies
            outside ``[0, 1]``, or if the two ratios together exceed 1.
    """

    output_dir: str
    shard_size: int = 2048
    validation_ratio: float = 0.01
    test_ratio: float = 0.001

    def __post_init__(self) -> None:
        """Reject configurations that would silently produce a broken dataset."""
        # A non-positive size makes every "buffer is full" check succeed, which
        # degenerates into one shard per record.
        if self.shard_size < 1:
            raise ValueError(f"shard_size must be >= 1, got {self.shard_size!r}")
        for label, ratio in (
            ("validation_ratio", self.validation_ratio),
            ("test_ratio", self.test_ratio),
        ):
            # Written as a negated range check so NaN is rejected as well.
            if not 0.0 <= ratio <= 1.0:
                raise ValueError(f"{label} must be within [0, 1], got {ratio!r}")
        # Small tolerance so pairs meant to sum to exactly 1 survive float rounding.
        if self.validation_ratio + self.test_ratio > 1.0 + 1e-9:
            raise ValueError("validation_ratio + test_ratio must not exceed 1")
|
|
|
|
def assign_split(record_id: str, validation_ratio: float, test_ratio: float) -> str:
    """Assign a deterministic split from the content id.

    The first eight hex digits of ``record_id`` are scaled to a value in
    ``[0, 1]``.  Values below ``test_ratio`` map to "test", values below
    ``test_ratio + validation_ratio`` map to "validation", and everything else
    maps to "train".

    Ids that do not begin with eight hex digits (short ids, numeric counters,
    slugs, ...) are first hashed with SHA-256 so they are spread over the whole
    range instead of raising ``ValueError`` or clustering near zero.

    Args:
        record_id: Identifier of the record; the same id always yields the
            same split.
        validation_ratio: Fraction of the id space assigned to "validation".
        test_ratio: Fraction of the id space assigned to "test".

    Returns:
        One of "test", "validation" or "train".
    """
    prefix = record_id[:8]
    is_hex_prefix = len(prefix) == 8 and all(ch in "0123456789abcdefABCDEF" for ch in prefix)
    if not is_hex_prefix:
        # A short prefix such as "42" would parse to a tiny number and always
        # fall into "test"; a non-hex prefix would not parse at all.
        prefix = hashlib.sha256(record_id.encode("utf-8")).hexdigest()[:8]
    value = int(prefix, 16) / 0xFFFFFFFF
    if value < test_ratio:
        return "test"
    if value < test_ratio + validation_ratio:
        return "validation"
    return "train"
|
|
|
|
def build_manifest(shard_paths: Iterable[Path]) -> dict[str, object]:
    """Describe a set of shard files as a manifest dictionary.

    The ``dataset_hash`` is a SHA-256 over each shard's file name followed by
    its size in bytes, taken in the order the paths are supplied.  File
    contents are not read, so every path must exist for ``stat()`` to succeed.

    Args:
        shard_paths: Paths of the shard files, in manifest order.

    Returns:
        A dictionary with the keys "format", "schema", "shards" and
        "dataset_hash".
    """
    shards = list(shard_paths)
    # Hashing the concatenation is equivalent to feeding the pieces one by one.
    fingerprint_source = "".join(f"{shard.name}{shard.stat().st_size}" for shard in shards)
    fingerprint = hashlib.sha256(fingerprint_source.encode("utf-8")).hexdigest()
    shard_names = [shard.name for shard in shards]
    return {
        "format": "parquet",
        "schema": list(SCHEMA_COLUMNS),
        "shards": shard_names,
        "dataset_hash": fingerprint,
    }
|
|
|
|
def write_shards(records: Iterable[dict[str, object]], tokenizer, config: ShardConfig) -> dict[str, object]:
    """Tokenize records, write them as Parquet shards, and emit a manifest.

    Records are consumed lazily.  Each one is tokenized, tagged with a split,
    and buffered; the buffer is flushed to a new shard once it holds
    ``config.shard_size`` rows, and any remainder is flushed at the end.  The
    manifest is written to ``manifest.json`` inside the output directory.

    Args:
        records: Mappings providing the keys "id", "text", "domain_tag",
            "quality_tier" and "lang".
        tokenizer: Object exposing ``encode(text, out_type=int)``.
            NOTE(review): looks like a SentencePiece processor - confirm.
        config: Output directory, shard size and split ratios.

    Returns:
        The manifest dictionary that was also written to disk.

    Raises:
        ImportError: If pyarrow could not be imported.
    """
    if pa is None or pq is None:
        raise ImportError("pyarrow is required to write parquet shards.")

    target_dir = Path(config.output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    written: list[Path] = []
    pending: list[dict[str, object]] = []

    for source in records:
        text = str(source["text"])
        token_ids = tokenizer.encode(text, out_type=int)
        record_id = str(source["id"])
        pending.append(
            {
                "id": record_id,
                "text": text,
                "tokens": token_ids,
                "domain_tag": str(source["domain_tag"]),
                "quality_tier": str(source["quality_tier"]),
                "lang": str(source["lang"]),
                "token_count": len(token_ids),
                "split": assign_split(record_id, config.validation_ratio, config.test_ratio),
            }
        )
        if len(pending) < config.shard_size:
            continue
        # The next shard index equals the number of shards written so far.
        written.append(_flush_shard(target_dir, len(written), pending))
        pending = []

    if pending:
        written.append(_flush_shard(target_dir, len(written), pending))

    summary = build_manifest(written)
    manifest_file = target_dir / "manifest.json"
    manifest_file.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    return summary
|
|
|
|
def _flush_shard(output_dir: Path, shard_index: int, rows: list[dict[str, object]]) -> Path:
    """Persist one Parquet shard.

    Args:
        output_dir: Directory the shard file is written into.
        shard_index: Zero-based index used in the ``shard-NNNNN.parquet`` name.
        rows: Row dictionaries; each must provide every name in
            ``SCHEMA_COLUMNS``.

    Returns:
        Path of the shard file that was written.
    """
    # Arrow infers column types from the values it sees.  A shard whose rows all
    # carry an empty token list would be inferred as ``list<null>`` and stop
    # matching the ``list<int64>`` schema of its sibling shards, so the token
    # column type is pinned.  Other columns keep Arrow's inference (type=None).
    pinned_types = {"tokens": pa.list_(pa.int64())}
    columns = {
        column: pa.array([row[column] for row in rows], type=pinned_types.get(column))
        for column in SCHEMA_COLUMNS
    }
    table = pa.table(columns)
    path = output_dir / f"shard-{shard_index:05d}.parquet"
    pq.write_table(table, path)
    return path
|
|