Add selective HF parquet shard download support (--hf-files, --hf-subdir, --max-shards, --list-shards)
#!/usr/bin/env python3
"""run_dataset.py - Main entry point for processing a dataset through GrandLine.

Usage:

    # Process local parquet files
    python scripts/run_dataset.py --config configs/datasets/fineweb_edu.yaml

    # Download and process specific HF parquet shards (selective!)
    python scripts/run_dataset.py --config configs/datasets/fineweb_edu.yaml \
        --hf-files "data/CC-MAIN-2024-10/000_00000.parquet" \
                   "data/CC-MAIN-2024-10/000_00001.parquet"

    # Download first N shards from a specific subdirectory
    python scripts/run_dataset.py --config configs/datasets/fineweb_edu.yaml \
        --hf-subdir "data/CC-MAIN-2024-10" --max-shards 5

    # With executor tuning
    python scripts/run_dataset.py --config configs/datasets/fineweb_edu.yaml \
        --executor configs/executors/kaggle.yaml \
        --hf-subdir "data/CC-MAIN-2024-10" --max-shards 3

Workflow:
    1. Load dataset config (+ global + executor configs)
    2. Resolve input shards (local paths OR selective HF download)
    3. Build the appropriate pipeline from trust level and dataset type
    4. Process each shard independently (resumable)
    5. Write packed parquet artifacts + manifests
"""
from __future__ import annotations

import sys
from pathlib import Path

import click

# Add src to path for development mode
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from grandline.dedup_store import DedupStore
from grandline.io import download_hf_parquet_files, list_hf_parquet_shards
from grandline.pipelines.code import build_code_pipeline
from grandline.pipelines.curated_web import build_curated_web_pipeline
from grandline.pipelines.math import build_math_pipeline
from grandline.pipelines.papers import build_papers_pipeline
from grandline.pipelines.synthetic import build_synthetic_pipeline
from grandline.runtime import Runtime, RuntimeConfig
from grandline.util.config import load_dataset_config, validate_dataset_config
from grandline.util.logging import get_logger
from grandline.util.paths import ProjectPaths

logger = get_logger("run_dataset")

# Pipeline type -> builder function
PIPELINE_BUILDERS = {
    "curated_web": build_curated_web_pipeline,
    "code": build_code_pipeline,
    "math": build_math_pipeline,
    "papers": build_papers_pipeline,
    "synthetic": build_synthetic_pipeline,
}
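

# NOTE: the click option definitions are not part of this excerpt; main() is
# called with no arguments at the bottom of the file, so it must be a click
# command. The decorators below are a minimal reconstruction, assuming option
# names taken from the usage examples in the module docstring and from main()'s
# parameter names; the defaults ("configs/global.yaml", ".") and the --override
# spelling are illustrative, not confirmed. With multiple=True the --hf-files
# flag is repeated once per file (--hf-files A --hf-files B).
@click.command()
@click.option("--config", "config_path", required=True, help="Dataset config YAML.")
@click.option("--global-config", "global_config", default="configs/global.yaml",
              show_default=True, help="Global config YAML (skipped if missing).")
@click.option("--executor", default=None, help="Executor config YAML (optional).")
@click.option("--base-dir", default=".", show_default=True, help="Project base directory.")
@click.option("--dry-run", is_flag=True, help="Build the pipeline but skip processing.")
@click.option("--no-resume", is_flag=True, help="Reprocess shards even if already completed.")
@click.option("--hf-files", multiple=True, help="Specific parquet shard path(s) in the HF repo.")
@click.option("--hf-subdir", default=None, help="HF repo subdirectory to discover shards in.")
@click.option("--max-shards", type=int, default=None, help="Limit the number of discovered shards.")
@click.option("--list-shards", is_flag=True, help="List available shards in the HF repo and exit.")
@click.option("--override", "overrides", multiple=True, help="Config override in key=value form.")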
def main(
    config_path: str,
    global_config: str,
    executor: str | None,
    base_dir: str,
    dry_run: bool,
    no_resume: bool,
    hf_files: tuple[str, ...],
    hf_subdir: str | None,
    max_shards: int | None,
    list_shards: bool,
    overrides: tuple[str, ...],
) -> None:
| """Process a dataset through the GrandLine pipeline.""" | |
| # Setup paths | |
| paths = ProjectPaths(base_dir) | |
| paths.ensure_all() | |
| # Load config | |
| global_path = Path(global_config) | |
| global_cfg = str(global_path) if global_path.exists() else None | |
| config = load_dataset_config( | |
| config_path=config_path, | |
| global_config_path=global_cfg, | |
| executor_config_path=executor, | |
| overrides=list(overrides), | |
| ) | |
| # Validate | |
| errors = validate_dataset_config(config) | |
| if errors: | |
| for err in errors: | |
| logger.error(f"Config error: {err}") | |
| sys.exit(1) | |
| dataset_name = config["name"] | |
| pipeline_type = config.get("pipeline_type", "curated_web") | |
| source_cfg = config.get("source", {}) | |
| repo_id = source_cfg.get("repo") | |
| # Handle --list-shards: just show what's available and exit | |
| if list_shards: | |
| if not repo_id: | |
| logger.error("Cannot list shards: no source.repo in config") | |
| sys.exit(1) | |
| subdir = hf_subdir or "data" | |
| shards = list_hf_parquet_shards(repo_id, subdir=subdir) | |
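        # Each entry is assumed to be a dict with a "path" key; directory
        # entries carry "is_dir", file entries carry "size_bytes" (see below).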
| print(f"\nAvailable shards in {repo_id}/{subdir}:") | |
| print(f"{'โ' * 60}") | |
| for s in shards: | |
| if s.get("is_dir"): | |
| print(f" ๐ {s['path']}/") | |
| else: | |
| size_gb = s["size_bytes"] / 1e9 | |
| print(f" ๐ {s['path']} ({size_gb:.2f} GB)") | |
| print(f"\nTotal: {len(shards)} entries") | |
| return | |
| logger.info(f"Dataset: {dataset_name}") | |
| logger.info(f"Pipeline type: {pipeline_type}") | |
| logger.info(f"Trust level: {config.get('trust_level', 0)}") | |
| # Resolve input paths | |
| input_paths: list[str] = [] | |
| if hf_files or hf_subdir: | |
| # Download specific HF parquet files | |
| if not repo_id: | |
| logger.error("Cannot download HF files: no source.repo in config") | |
| sys.exit(1) | |
| if hf_files: | |
| logger.info(f"Downloading {len(hf_files)} specific shard(s) from {repo_id}...") | |
| local_paths = download_hf_parquet_files( | |
| repo_id=repo_id, | |
| file_patterns=list(hf_files), | |
| ) | |
| else: | |
| logger.info( | |
| f"Discovering shards in {repo_id}/{hf_subdir}" | |
| f"{f' (max {max_shards})' if max_shards else ''}..." | |
| ) | |
| local_paths = download_hf_parquet_files( | |
| repo_id=repo_id, | |
| subdir=hf_subdir, | |
| max_files=max_shards, | |
| ) | |
| input_paths = local_paths | |
| logger.info(f"Downloaded {len(input_paths)} shard(s)") | |
| else: | |
| # Use local paths from config | |
| input_paths = source_cfg.get("paths", []) | |
| if isinstance(input_paths, str): | |
| input_paths = [input_paths] | |
| if not input_paths: | |
| logger.error( | |
| "No input paths. Use --hf-files, --hf-subdir, or set source.paths in config." | |
| ) | |
| sys.exit(1) | |
    # Build pipeline
    if pipeline_type not in PIPELINE_BUILDERS:
        logger.error(
            f"Unknown pipeline type: {pipeline_type}. "
            f"Available: {list(PIPELINE_BUILDERS.keys())}"
        )
        sys.exit(1)

    # Initialize dedup store
    dedup_store = DedupStore(paths.dedup_db_path)
    try:
        builder = PIPELINE_BUILDERS[pipeline_type]
        pipeline = builder(config, dedup_store)
        logger.info(f"Pipeline: {pipeline}")
        logger.info(f"Pipeline fingerprint: {pipeline.fingerprint[:16]}...")

        # Build runtime config
        rt_config = RuntimeConfig(
            output_dir=str(paths.dataset_output_dir(dataset_name)),
            state_dir=str(paths.state_dir / dataset_name),
            num_workers=config.get("executor", {}).get("num_workers", 1),
            resume=not no_resume,
            dry_run=dry_run,
        )

        # Run
        runtime = Runtime(config=rt_config, pipeline=pipeline)
        manifests = runtime.run(input_paths)

        # Summary
        if manifests:
            total_seqs = sum(m.num_output_sequences for m in manifests)
            total_time = sum(m.processing_time_seconds for m in manifests)
            logger.info(
                f"Done: {len(manifests)} shards, "
                f"{total_seqs} sequences, "
                f"{total_time:.1f}s total"
            )
        elif not dry_run:
            logger.info("No shards to process (all completed or none found).")
    finally:
        dedup_store.close()


if __name__ == "__main__":
    main()