#!/usr/bin/env python3
"""run_dataset.py โ€” Main entry point for processing a dataset through GrandLine.
Usage:
# Process local parquet files
python scripts/run_dataset.py --config configs/datasets/fineweb_edu.yaml
# Download and process specific HF parquet shards (selective!)
python scripts/run_dataset.py --config configs/datasets/fineweb_edu.yaml \
--hf-files "data/CC-MAIN-2024-10/000_00000.parquet" \
"data/CC-MAIN-2024-10/000_00001.parquet"
# Download first N shards from a specific subdirectory
python scripts/run_dataset.py --config configs/datasets/fineweb_edu.yaml \
--hf-subdir "data/CC-MAIN-2024-10" --max-shards 5
# With executor tuning
python scripts/run_dataset.py --config configs/datasets/fineweb_edu.yaml \
--executor configs/executors/kaggle.yaml \
--hf-subdir "data/CC-MAIN-2024-10" --max-shards 3
Workflow:
1. Load dataset config (+ global + executor configs)
2. Resolve input shards (local paths OR selective HF download)
3. Build the appropriate pipeline from trust level and dataset type
4. Process each shard independently (resumable)
5. Write packed parquet artifacts + manifests
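Example dataset config (an illustrative sketch only; the keys below are inferred
from how this script reads the config, and the real schema may differ or contain
additional fields):
name: fineweb_edu
pipeline_type: curated_web          # one of: curated_web, code, math, papers, synthetic
trust_level: 0
source:
  repo: HuggingFaceFW/fineweb-edu   # hypothetical HF repo id, used by --hf-files/--hf-subdir
  paths: []                         # local parquet paths when not downloading from HF
executor:
  num_workers: 1                    # defaults to 1 if omitted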
"""
from __future__ import annotations
import sys
from pathlib import Path
import click
# Add src to path for development mode
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
from grandline.dedup_store import DedupStore
from grandline.io import download_hf_parquet_files, list_hf_parquet_shards
from grandline.pipelines.code import build_code_pipeline
from grandline.pipelines.curated_web import build_curated_web_pipeline
from grandline.pipelines.math import build_math_pipeline
from grandline.pipelines.papers import build_papers_pipeline
from grandline.pipelines.synthetic import build_synthetic_pipeline
from grandline.runtime import Runtime, RuntimeConfig
from grandline.util.config import load_dataset_config, validate_dataset_config
from grandline.util.logging import get_logger
from grandline.util.paths import ProjectPaths
logger = get_logger("run_dataset")
# Pipeline type -> builder function
PIPELINE_BUILDERS = {
"curated_web": build_curated_web_pipeline,
"code": build_code_pipeline,
"math": build_math_pipeline,
"papers": build_papers_pipeline,
"synthetic": build_synthetic_pipeline,
}
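# To support a new dataset type, implement a builder and register it here.
# Judging from the call site in main(), a builder takes (config, dedup_store)
# and returns a pipeline object exposing a `fingerprint` attribute.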
@click.command()
@click.option(
"--config", "config_path", required=True, help="Path to dataset config YAML."
)
@click.option(
"--global-config",
default="configs/global.yaml",
help="Path to global config YAML.",
)
@click.option("--executor", default=None, help="Path to executor config YAML.")
@click.option("--base-dir", default=".", help="Project base directory.")
@click.option("--dry-run", is_flag=True, help="Only show what would be done.")
@click.option("--no-resume", is_flag=True, help="Reprocess all shards from scratch.")
@click.option(
"--hf-files",
multiple=True,
help="Specific parquet file paths within the HF repo to download and process.",
)
@click.option(
"--hf-subdir",
default=None,
help="HF repo subdirectory to discover shards from (e.g. 'data/CC-MAIN-2024-10').",
)
@click.option(
"--max-shards",
default=None,
type=int,
help="Maximum number of shards to download from --hf-subdir.",
)
@click.option(
"--list-shards",
is_flag=True,
help="List available shards in the HF repo and exit (no processing).",
)
@click.argument("overrides", nargs=-1)
def main(
config_path: str,
global_config: str,
executor: str | None,
base_dir: str,
dry_run: bool,
no_resume: bool,
hf_files: tuple[str, ...],
hf_subdir: str | None,
max_shards: int | None,
list_shards: bool,
overrides: tuple[str, ...],
) -> None:
"""Process a dataset through the GrandLine pipeline."""
# Setup paths
paths = ProjectPaths(base_dir)
paths.ensure_all()
# Load config
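# The global config is optional: it is only layered in if the file exists,
# so a missing configs/global.yaml does not abort the run.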
global_path = Path(global_config)
global_cfg = str(global_path) if global_path.exists() else None
config = load_dataset_config(
config_path=config_path,
global_config_path=global_cfg,
executor_config_path=executor,
overrides=list(overrides),
)
# Validate
errors = validate_dataset_config(config)
if errors:
for err in errors:
logger.error(f"Config error: {err}")
sys.exit(1)
dataset_name = config["name"]
pipeline_type = config.get("pipeline_type", "curated_web")
source_cfg = config.get("source", {})
repo_id = source_cfg.get("repo")
# Handle --list-shards: just show what's available and exit
if list_shards:
if not repo_id:
logger.error("Cannot list shards: no source.repo in config")
sys.exit(1)
subdir = hf_subdir or "data"
shards = list_hf_parquet_shards(repo_id, subdir=subdir)
print(f"\nAvailable shards in {repo_id}/{subdir}:")
print(f"{'โ”€' * 60}")
for s in shards:
if s.get("is_dir"):
print(f" ๐Ÿ“ {s['path']}/")
else:
size_gb = s["size_bytes"] / 1e9
print(f" ๐Ÿ“„ {s['path']} ({size_gb:.2f} GB)")
print(f"\nTotal: {len(shards)} entries")
return
logger.info(f"Dataset: {dataset_name}")
logger.info(f"Pipeline type: {pipeline_type}")
logger.info(f"Trust level: {config.get('trust_level', 0)}")
# Resolve input paths
input_paths: list[str] = []
if hf_files or hf_subdir:
# Download specific HF parquet files
if not repo_id:
logger.error("Cannot download HF files: no source.repo in config")
sys.exit(1)
if hf_files:
logger.info(f"Downloading {len(hf_files)} specific shard(s) from {repo_id}...")
local_paths = download_hf_parquet_files(
repo_id=repo_id,
file_patterns=list(hf_files),
)
else:
logger.info(
f"Discovering shards in {repo_id}/{hf_subdir}"
f"{f' (max {max_shards})' if max_shards else ''}..."
)
local_paths = download_hf_parquet_files(
repo_id=repo_id,
subdir=hf_subdir,
max_files=max_shards,
)
input_paths = local_paths
logger.info(f"Downloaded {len(input_paths)} shard(s)")
else:
# Use local paths from config
input_paths = source_cfg.get("paths", [])
if isinstance(input_paths, str):
input_paths = [input_paths]
if not input_paths:
logger.error(
"No input paths. Use --hf-files, --hf-subdir, or set source.paths in config."
)
sys.exit(1)
# Build pipeline
if pipeline_type not in PIPELINE_BUILDERS:
logger.error(
f"Unknown pipeline type: {pipeline_type}. "
f"Available: {list(PIPELINE_BUILDERS.keys())}"
)
sys.exit(1)
# Initialize dedup store
dedup_store = DedupStore(paths.dedup_db_path)
try:
builder = PIPELINE_BUILDERS[pipeline_type]
pipeline = builder(config, dedup_store)
logger.info(f"Pipeline: {pipeline}")
logger.info(f"Pipeline fingerprint: {pipeline.fingerprint[:16]}...")
# Build runtime config
rt_config = RuntimeConfig(
output_dir=str(paths.dataset_output_dir(dataset_name)),
state_dir=str(paths.state_dir / dataset_name),
num_workers=config.get("executor", {}).get("num_workers", 1),
resume=not no_resume,
dry_run=dry_run,
)
# Run
runtime = Runtime(config=rt_config, pipeline=pipeline)
manifests = runtime.run(input_paths)
# Summary
if manifests:
total_seqs = sum(m.num_output_sequences for m in manifests)
total_time = sum(m.processing_time_seconds for m in manifests)
logger.info(
f"Done: {len(manifests)} shards, "
f"{total_seqs} sequences, "
f"{total_time:.1f}s total"
)
elif not dry_run:
logger.info("No shards to process (all completed or none found).")
finally:
dedup_store.close()
if __name__ == "__main__":
main()