sage / data /pipeline.py
sage002's picture
feat: add authenticated remote control UI and ngrok launcher
1e799aa verified
"""End-to-end raw-corpus to Parquet shard pipeline."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import sentencepiece as spm
from data.dedup import deduplicate_records
from data.filter import FilterConfig, filter_record
from data.ingest import SOURCE_REGISTRY, stream_source
from data.shard import ShardConfig, write_shards
def _select_sources(names: list[str] | None) -> tuple:
    """Resolve requested source names to their specs in ``SOURCE_REGISTRY``.

    Args:
        names: Source names to keep. ``None`` or an empty list selects
            everything, in which case ``SOURCE_REGISTRY`` itself is returned.

    Returns:
        The matching specs as a tuple, in registry order (not request order).

    Raises:
        ValueError: If any requested name matches no spec; the message lists
            the unknown names in sorted order.
    """
    if names:
        requested = set(names)
        matched = tuple(entry for entry in SOURCE_REGISTRY if entry.name in requested)
        found = {entry.name for entry in matched}
        # Anything requested but not found is reported in one error.
        unknown = sorted(requested - found)
        if unknown:
            raise ValueError(f"Unknown sources: {', '.join(unknown)}")
        return matched
    return SOURCE_REGISTRY
def build_records(source_names: list[str] | None = None, limit_per_source: int | None = None) -> list[dict[str, object]]:
    """Load, filter, and deduplicate records from the configured raw sources.

    Args:
        source_names: Names of the sources to read. ``None`` or an empty list
            reads every source in the registry.
        limit_per_source: Maximum number of records kept per source, counted
            after filtering and before deduplication. ``None`` means no cap;
            zero or a negative value keeps nothing.

    Returns:
        The result of ``deduplicate_records`` applied to the kept records of
        all selected sources, concatenated in source order.

    Raises:
        ValueError: If ``source_names`` contains a name that is not registered
            (raised by ``_select_sources``).
    """
    # Resolve first so unknown names are rejected even when the cap is zero.
    specs = _select_sources(source_names)

    # A non-positive cap means "keep nothing". Checking the cap only after an
    # append would let one record per source slip through, so handle it here
    # and avoid opening the sources at all.
    if limit_per_source is not None and limit_per_source <= 0:
        return deduplicate_records([])

    # Built once per call instead of once per record.
    # NOTE(review): assumes filter_record does not mutate its config - confirm.
    filter_config = FilterConfig()

    records: list[dict[str, object]] = []
    for spec in specs:
        kept = 0
        for record in stream_source(spec):
            filtered = filter_record(record, filter_config)
            if filtered is None:
                # Rejected records do not count toward the cap.
                continue
            records.append(filtered)
            kept += 1
            if limit_per_source is not None and kept >= limit_per_source:
                break
    return deduplicate_records(records)
def run_pipeline(
    tokenizer_model: str,
    output_dir: str = "data/processed",
    source_names: list[str] | None = None,
    shard_size: int = 2048,
    limit_per_source: int | None = None,
) -> dict[str, object]:
    """Create Parquet shards from the current raw JSONL corpora.

    Args:
        tokenizer_model: Path to the SentencePiece model file to load.
        output_dir: Directory passed to the shard writer and used for the
            ``pipeline_summary.json`` file.
        source_names: Subset of source names to process; ``None`` or empty
            means every registered source.
        shard_size: Shard size forwarded to ``ShardConfig``.
        limit_per_source: Optional per-source cap forwarded to
            ``build_records``.

    Returns:
        The summary dict that is also written to ``pipeline_summary.json``:
        tokenizer path, output directory, record count, source names, and the
        manifest returned by ``write_shards``.
    """
    processor = spm.SentencePieceProcessor()
    processor.load(tokenizer_model)

    corpus = build_records(source_names=source_names, limit_per_source=limit_per_source)
    shard_config = ShardConfig(output_dir=output_dir, shard_size=shard_size)
    shard_manifest = write_shards(corpus, processor, shard_config)

    # Report the explicit selection if there was one, otherwise every
    # registered source name.
    if source_names:
        reported_sources = source_names
    else:
        reported_sources = [spec.name for spec in SOURCE_REGISTRY]

    summary = {
        "tokenizer_model": tokenizer_model,
        "output_dir": output_dir,
        "records": len(corpus),
        "sources": reported_sources,
        "manifest": shard_manifest,
    }

    # exist_ok=True makes this a no-op when the directory is already there.
    destination = Path(output_dir)
    destination.mkdir(parents=True, exist_ok=True)
    summary_path = destination / "pipeline_summary.json"
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    return summary
def build_argparser() -> argparse.ArgumentParser:
    """Build the CLI parser.

    Returns:
        A parser exposing ``--tokenizer-model``, ``--output-dir``,
        ``--sources``, ``--shard-size`` and ``--limit-per-source``.
    """
    cli = argparse.ArgumentParser(description="Filter, deduplicate, tokenize, and shard SAGE raw corpora.")

    # One (flag, add_argument keyword arguments) entry per option, registered
    # in this order so the generated help text keeps the same layout.
    option_table = (
        ("--tokenizer-model", {"default": "tokenizer/tokenizer.model", "help": "SentencePiece tokenizer model."}),
        ("--output-dir", {"default": "data/processed", "help": "Destination directory for parquet shards."}),
        ("--sources", {"nargs": "*", "default": None, "help": "Subset of source names from data.ingest.SOURCE_REGISTRY."}),
        ("--shard-size", {"type": int, "default": 2048, "help": "Rows per parquet shard."}),
        ("--limit-per-source", {"type": int, "default": None, "help": "Optional cap for smoke-testing."}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli
def main() -> None:
    """CLI entrypoint.

    Parses the command line, runs the pipeline with the parsed options, and
    prints the resulting summary to stdout as indented JSON.
    """
    cli_args = build_argparser().parse_args()
    pipeline_options = {
        "tokenizer_model": cli_args.tokenizer_model,
        "output_dir": cli_args.output_dir,
        "source_names": cli_args.sources,
        "shard_size": cli_args.shard_size,
        "limit_per_source": cli_args.limit_per_source,
    }
    result = run_pipeline(**pipeline_options)
    report = json.dumps(result, indent=2)
    print(report)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()