File size: 4,287 Bytes
53f0cc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
Component 3 runner script.

Reads YAML config and executes full Hugging Face dataset preprocessing.
"""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, List

import yaml

# This makes "src" imports work when script is run from project root.
# parents[1] is one directory above this script's folder — presumably the
# repo root that contains "src/"; confirm if the script is relocated.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from src.dataset_pipeline.hf_dataset_pipeline import (  # noqa: E402
    HFDatasetPipeline,
    PipelineConfig,
    SourceDatasetSpec,
)


def parse_args() -> argparse.Namespace:
    """Define and evaluate the command-line interface for this runner.

    Returns:
        argparse.Namespace with ``config`` (path to the YAML config file) and
        ``max_records_per_dataset`` (optional per-dataset cap for quick runs).
    """
    cli = argparse.ArgumentParser(description="Run Component 3 dataset preprocessing pipeline.")
    cli.add_argument(
        "--config",
        default="configs/component3_dataset_pipeline.yaml",
        help="Path to YAML config file.",
    )
    cli.add_argument(
        "--max_records_per_dataset",
        type=int,
        default=None,
        help="Optional override for quick test runs.",
    )
    return cli.parse_args()


def _read_yaml(path: Path) -> Dict[str, Any]:
    # Reads YAML file with friendly errors.
    if not path.exists():
        raise FileNotFoundError(f"Config file not found: {path}")
    with path.open("r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    if not isinstance(data, dict):
        raise ValueError("Config file is invalid. Expected a YAML object at top level.")
    return data


def _build_config(data: Dict[str, Any], max_records_override: int | None) -> PipelineConfig:
    # Converts generic dict into strongly typed config objects.
    dataset_specs: List[SourceDatasetSpec] = []
    datasets_data = data.get("datasets", [])
    if not isinstance(datasets_data, list) or not datasets_data:
        raise ValueError("Config must include a non-empty 'datasets' list.")

    for item in datasets_data:
        dataset_specs.append(
            SourceDatasetSpec(
                hf_dataset_id=str(item["hf_dataset_id"]),
                split=str(item.get("split", "train")),
                prompt_field=str(item["prompt_field"]),
                code_field=str(item["code_field"]),
                language_field=item.get("language_field"),
                default_language=str(item.get("default_language", "python")),
            )
        )

    cfg = PipelineConfig(
        datasets=dataset_specs,
        tokenizer_dir=str(data["tokenizer_dir"]),
        interim_output_dir=str(data["interim_output_dir"]),
        processed_output_dir=str(data["processed_output_dir"]),
        dedupe_db_path=str(data["dedupe_db_path"]),
        max_records_per_dataset=data.get("max_records_per_dataset"),
        min_prompt_chars=int(data.get("min_prompt_chars", 8)),
        min_code_chars=int(data.get("min_code_chars", 16)),
        max_code_chars=int(data.get("max_code_chars", 40_000)),
        progress_every=int(data.get("progress_every", 1_000)),
    )

    if max_records_override is not None:
        cfg.max_records_per_dataset = max_records_override
    return cfg


def main() -> None:
    """Entry point: load config, run the pipeline, report results in plain English.

    Any failure is caught at this top-level boundary, explained, and turned
    into exit code 1.
    """
    args = parse_args()
    try:
        cfg = _build_config(_read_yaml(Path(args.config)), args.max_records_per_dataset)
        pipeline = HFDatasetPipeline(cfg)
        try:
            stats = pipeline.run()
        finally:
            # Release pipeline resources even when run() raises.
            pipeline.close()

        interim_dir = Path(cfg.interim_output_dir)
        processed_dir = Path(cfg.processed_output_dir)
        print("Component 3 pipeline completed successfully.")
        print("Saved files:")
        print(f"- {interim_dir / 'combined_clean.jsonl'}")
        print(f"- {processed_dir / 'train_tokenized.jsonl'}")
        print(f"- {processed_dir / 'pipeline_stats.json'}")
        print("Summary stats:")
        print(json.dumps(stats, indent=2))
    except Exception as exc:
        print("Component 3 pipeline failed.")
        print(f"What went wrong: {exc}")
        print(
            "Fix suggestion: verify internet access for Hugging Face, tokenizer path, "
            "and config field names."
        )
        raise SystemExit(1)


if __name__ == "__main__":
    main()