# NOTE: the "Spaces: Running / Running" banner lines that preceded this file
# were hosting-page text captured during extraction, not part of this module.
"""
data_factory/pipeline.py
========================

Master orchestration pipeline for the NL2SQL Synthetic Data Factory.

This module ties together:
  1. Template library (66 verified SQL templates across 4 domains)
  2. Rule-based NL augmentation (augmentor.py)
  3. vLLM persona-based NL generation (generator.py)
  4. SQL execution validation (validator.py)
  5. Output serialisation (JSONL + Parquet)

Run modes:
  --mode base : Only uses template base_nl + rule augmentation (no GPU required)
  --mode full : base + vLLM persona generation (requires H100)

Output dataset format (JSONL, one record per line):
    {
      "prompt": [{"role": "system", ...}, {"role": "user", ...}],
      "sql": "SELECT ...",
      "metadata": { "domain", "difficulty", "persona", ... }
    }

This format is directly loadable by:
    datasets.load_dataset("json", data_files="output/train.jsonl")
"""
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| import os | |
| import random | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Iterator, Optional | |
# Configure root logging once at import time; timestamps show HH:MM:SS only.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%H:%M:%S",
)
# Module-level logger shared by every function in this pipeline.
logger = logging.getLogger("pipeline")
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # HELPERS | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _ensure_dirs(*dirs: Path) -> None: | |
| for d in dirs: | |
| d.mkdir(parents=True, exist_ok=True) | |
def _write_jsonl(records: list[dict], path: Path) -> None:
    """Serialise *records* to *path* as UTF-8 JSONL, one JSON object per line."""
    serialised = (json.dumps(record, ensure_ascii=False) + "\n" for record in records)
    with open(path, "w", encoding="utf-8") as handle:
        handle.writelines(serialised)
    logger.info("Wrote %d records to %s", len(records), path)
def _write_parquet(records: list[dict], path: Path) -> None:
    """Best-effort Parquet dump of *records*.

    Skips (with a warning) when pandas or pyarrow is not installed; the
    ImportError can come from either the import or ``to_parquet`` itself.
    """
    try:
        import pandas as pd

        frame = pd.DataFrame(records)
        frame.to_parquet(path, index=False, engine="pyarrow", compression="snappy")
    except ImportError:
        logger.warning("pandas/pyarrow not installed β skipping Parquet output.")
    else:
        logger.info("Wrote %d records to %s (Parquet)", len(records), path)
| def _train_val_test_split( | |
| records: list[dict], | |
| train_frac: float = 0.90, | |
| val_frac: float = 0.05, | |
| seed: int = 42, | |
| ) -> tuple[list[dict], list[dict], list[dict]]: | |
| """ | |
| Stratified split by (domain, difficulty) to ensure all combinations | |
| are represented in every split. | |
| """ | |
| rng = random.Random(seed) | |
| from collections import defaultdict | |
| buckets: dict[str, list[dict]] = defaultdict(list) | |
| for rec in records: | |
| key = f"{rec['metadata']['domain']}_{rec['metadata']['difficulty']}" | |
| buckets[key].append(rec) | |
| train, val, test = [], [], [] | |
| for key, bucket in buckets.items(): | |
| rng.shuffle(bucket) | |
| n = len(bucket) | |
| n_train = max(1, int(n * train_frac)) | |
| n_val = max(1, int(n * val_frac)) | |
| train.extend(bucket[:n_train]) | |
| val.extend(bucket[n_train:n_train + n_val]) | |
| test.extend(bucket[n_train + n_val:]) | |
| rng.shuffle(train) | |
| rng.shuffle(val) | |
| rng.shuffle(test) | |
| return train, val, test | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # PHASE 1: BASE + RULE AUGMENTATION (no GPU required) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def run_base_pipeline(
    templates: list,
    n_augmentations: int = 5,
    seed: int = 42,
) -> list[dict]:
    """
    Build training records from each template's canonical base_nl plus
    rule-based augmented NL variants. No GPU is required.

    Returns a list of training dicts ready for JSONL serialisation.
    """
    from data_factory.augmentor import augment_nl
    from data_factory.validator import SQLValidator, build_record
    from data_factory.schemas import SCHEMA_MAP

    # One validator per domain so the seeded DB connection is reused
    # across all templates of that domain.
    validators = {domain: SQLValidator(domain, seed=seed) for domain in SCHEMA_MAP}
    training_dicts: list[dict] = []

    for t_idx, template in enumerate(templates):
        domain_validator = validators[template["domain"]]

        # Canonical question first, then the rule-augmented paraphrases,
        # each tagged with its persona/source labels.
        candidates = [(template["base_nl"], "canonical", "template_base")]
        candidates += [
            (variant, "rule_augmented", "rule_augmented")
            for variant in augment_nl(
                nl_question=template["base_nl"],
                n=n_augmentations,
                seed=seed + t_idx,
            )
        ]

        for nl_question, persona, source in candidates:
            candidate = build_record(
                template=template,
                template_idx=t_idx,
                nl_question=nl_question,
                persona=persona,
                source=source,
                validator=domain_validator,
            )
            # build_record returns a falsy value when validation fails.
            if candidate:
                training_dicts.append(candidate.to_training_dict())

    for domain_validator in validators.values():
        domain_validator.close()
    logger.info("Base pipeline: %d records generated from %d templates.", len(training_dicts), len(templates))
    return training_dicts
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # PHASE 2: vLLM PERSONA GENERATION (H100 required) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def run_vllm_pipeline(
    templates: list,
    generator,  # VLLMGenerator instance
    personas: list[str],
    n_variants_per_persona: int = 10,
    batch_size: int = 64,
    temperature: float = 0.85,
    max_new_tokens: int = 350,
    seed: int = 42,
) -> list[dict]:
    """
    Generate additional NL variants with the LLM, then validate each one's
    SQL against the seeded database. Returns a list of training dicts.
    """
    from data_factory.generator import generate_persona_variants_batch
    from data_factory.validator import SQLValidator, build_record
    from data_factory.schemas import SCHEMA_MAP

    # One validator per domain, reused across all generation jobs.
    validators = {domain: SQLValidator(domain, seed=seed) for domain in SCHEMA_MAP}
    training_dicts: list[dict] = []

    job_stream = generate_persona_variants_batch(
        templates_subset=templates,
        generator=generator,
        personas=personas,
        n_variants_per_persona=n_variants_per_persona,
        batch_size=batch_size,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
    )

    for job in job_stream:
        template = templates[job["template_idx"]]
        domain_validator = validators[template["domain"]]
        for variant in job["nl_variants"]:
            candidate = build_record(
                template=template,
                template_idx=job["template_idx"],
                nl_question=variant,
                persona=job["persona"],
                source="vllm_persona",
                validator=domain_validator,
            )
            # Only keep variants whose SQL validated successfully.
            if candidate:
                training_dicts.append(candidate.to_training_dict())

    for domain_validator in validators.values():
        domain_validator.close()
    logger.info("vLLM pipeline: %d records generated.", len(training_dicts))
    return training_dicts
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # CHECKPOINT UTILITIES | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def save_checkpoint(records: list[dict], checkpoint_dir: Path, name: str) -> Path:
    """Persist *records* as ``<checkpoint_dir>/<name>.jsonl`` and return that path."""
    target = checkpoint_dir / f"{name}.jsonl"
    _write_jsonl(records, target)
    return target
def load_checkpoint(checkpoint_dir: Path, name: str) -> Optional[list[dict]]:
    """Load ``<checkpoint_dir>/<name>.jsonl``; return None when it doesn't exist."""
    path = checkpoint_dir / f"{name}.jsonl"
    if not path.exists():
        return None
    with open(path, encoding="utf-8") as handle:
        # Blank lines are tolerated and skipped.
        records = [json.loads(stripped) for stripped in
                   (raw.strip() for raw in handle) if stripped]
    logger.info("Loaded %d records from checkpoint %s", len(records), path)
    return records
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # DATASET STATISTICS | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def print_dataset_stats(records: list[dict]) -> None:
    """Print a human-readable breakdown of *records* by domain, difficulty,
    persona, and source to stdout."""
    from collections import Counter

    total = len(records)
    metadata = [r["metadata"] for r in records]
    domains = Counter(m["domain"] for m in metadata)
    diffs = Counter(m["difficulty"] for m in metadata)
    personas = Counter(m["persona"] for m in metadata)
    sources = Counter(m["source"] for m in metadata)

    bar = "=" * 55
    print("\n" + bar)
    print(f" DATASET STATISTICS ({total:,} total records)")
    print(bar)
    print("\nBy Domain:")
    for key, count in sorted(domains.items()):
        print(f" {key:20s}: {count:6,} ({count/total*100:.1f}%)")
    print("\nBy Difficulty:")
    for key, count in sorted(diffs.items()):
        print(f" {key:20s}: {count:6,} ({count/total*100:.1f}%)")
    print("\nBy Persona/Source:")
    for key, count in sorted(personas.items()):
        print(f" {key:20s}: {count:6,}")
    print("\nBy Source:")
    for key, count in sorted(sources.items()):
        print(f" {key:20s}: {count:6,}")
    print(bar + "\n")
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # MAIN ENTRY POINT | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def main() -> None:
    """CLI entry point.

    Parses arguments, loads and filters templates, runs Phase 1 (base +
    rule augmentation) and optionally Phase 2 (vLLM persona generation),
    deduplicates by NL question, splits into train/val/test, and writes
    JSONL (and optionally Parquet) outputs plus a dataset card.
    """
    parser = argparse.ArgumentParser(
        description="NL2SQL Synthetic Data Factory β generates verified training data."
    )
    parser.add_argument(
        "--mode", choices=["base", "full"], default="base",
        help="base = rule augmentation only (no GPU). full = + vLLM on H100.",
    )
    parser.add_argument("--model", default="meta-llama/Meta-Llama-3-70B-Instruct",
                        help="HuggingFace model name for vLLM (full mode only).")
    parser.add_argument("--tensor-parallel", type=int, default=4,
                        help="Tensor parallel size for vLLM (number of H100s).")
    parser.add_argument("--n-rule-augments", type=int, default=5,
                        help="Number of rule-based NL augmentations per template.")
    parser.add_argument("--n-persona-variants", type=int, default=10,
                        help="Number of vLLM NL variants per (template, persona) pair.")
    parser.add_argument("--batch-size", type=int, default=64,
                        help="vLLM batch size (larger = faster on H100).")
    parser.add_argument("--temperature", type=float, default=0.85,
                        help="Sampling temperature for vLLM generation.")
    parser.add_argument("--output-dir", type=str, default="generated_data/output",
                        help="Directory to write final dataset files.")
    parser.add_argument("--checkpoint-dir", type=str, default="generated_data/checkpoints",
                        help="Directory for intermediate checkpoints.")
    parser.add_argument("--seed", type=int, default=42, help="Global random seed.")
    parser.add_argument("--no-parquet", action="store_true",
                        help="Skip Parquet output (write only JSONL).")
    parser.add_argument("--resume", action="store_true",
                        help="Resume from latest checkpoint if available.")
    parser.add_argument("--domains", nargs="+",
                        choices=["ecommerce","healthcare","finance","hr"],
                        default=["ecommerce","healthcare","finance","hr"],
                        help="Domains to include (default: all 4).")
    parser.add_argument("--difficulties", nargs="+",
                        choices=["easy","medium","hard"],
                        default=["easy","medium","hard"],
                        help="Difficulty levels to include (default: all 3).")
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    checkpoint_dir = Path(args.checkpoint_dir)
    _ensure_dirs(output_dir, checkpoint_dir)

    # -- Load templates, filtered by the requested domains/difficulties ------
    from data_factory.templates import ALL_TEMPLATES
    templates = [
        t for t in ALL_TEMPLATES
        if t["domain"] in args.domains and t["difficulty"] in args.difficulties
    ]
    logger.info("Loaded %d templates (domains=%s, difficulties=%s).",
                len(templates), args.domains, args.difficulties)

    # -- Phase 1: Base + rule augmentation (checkpointed, resumable) ---------
    all_records: list[dict] = []
    ckpt_base = load_checkpoint(checkpoint_dir, "phase1_base") if args.resume else None
    if ckpt_base is not None:
        all_records.extend(ckpt_base)
        logger.info("Resumed Phase 1 from checkpoint (%d records).", len(ckpt_base))
    else:
        logger.info("=== Phase 1: Base + Rule Augmentation ===")
        base_records = run_base_pipeline(
            templates=templates,
            n_augmentations=args.n_rule_augments,
            seed=args.seed,
        )
        all_records.extend(base_records)
        save_checkpoint(base_records, checkpoint_dir, "phase1_base")

    # -- Phase 2: vLLM persona generation (full mode only, resumable) --------
    if args.mode == "full":
        ckpt_vllm = load_checkpoint(checkpoint_dir, "phase2_vllm") if args.resume else None
        if ckpt_vllm is not None:
            all_records.extend(ckpt_vllm)
            logger.info("Resumed Phase 2 from checkpoint (%d records).", len(ckpt_vllm))
        else:
            logger.info("=== Phase 2: vLLM Persona Generation ===")
            # Imported lazily so base mode never needs vLLM installed.
            from data_factory.generator import VLLMGenerator
            from data_factory.config import PERSONAS
            generator = VLLMGenerator(
                model_name=args.model,
                mode="offline",
                tensor_parallel_size=args.tensor_parallel,
                gpu_memory_utilization=0.90,
            )
            vllm_records = run_vllm_pipeline(
                templates=templates,
                generator=generator,
                personas=PERSONAS,
                n_variants_per_persona=args.n_persona_variants,
                batch_size=args.batch_size,
                temperature=args.temperature,
                max_new_tokens=350,
                seed=args.seed,
            )
            all_records.extend(vllm_records)
            save_checkpoint(vllm_records, checkpoint_dir, "phase2_vllm")

    # -- Deduplication: keep the first record for each exact NL question -----
    logger.info("Deduplicating %d records...", len(all_records))
    seen_nl: set[str] = set()
    deduped: list[dict] = []
    for rec in all_records:
        nl = rec["prompt"][1]["content"]  # user message contains the NL question
        if nl not in seen_nl:
            seen_nl.add(nl)
            deduped.append(rec)
    logger.info("After dedup: %d unique records (removed %d duplicates).",
                len(deduped), len(all_records) - len(deduped))

    # -- Statistics ----------------------------------------------------------
    print_dataset_stats(deduped)

    # -- Train / Val / Test split (stratified, seeded) -----------------------
    train, val, test = _train_val_test_split(deduped, seed=args.seed)
    logger.info("Split: train=%d | val=%d | test=%d", len(train), len(val), len(test))

    # -- Write outputs (JSONL always; Parquet unless --no-parquet) -----------
    _write_jsonl(train, output_dir / "train.jsonl")
    _write_jsonl(val, output_dir / "val.jsonl")
    _write_jsonl(test, output_dir / "test.jsonl")
    if not args.no_parquet:
        _write_parquet(train, output_dir / "train.parquet")
        _write_parquet(val, output_dir / "val.parquet")
        _write_parquet(test, output_dir / "test.parquet")

    # -- Write dataset card with run provenance ------------------------------
    card = {
        "name": "NL2SQL-Bench Synthetic Training Dataset",
        "version": "1.0",
        "generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "total_records": len(deduped),
        "splits": {"train": len(train), "val": len(val), "test": len(test)},
        "domains": args.domains,
        "difficulties": args.difficulties,
        "mode": args.mode,
        "seed": args.seed,
        "sql_guarantee": (
            "Every SQL in this dataset was human-authored and execution-validated "
            "against a seeded SQLite database. Zero LLM-generated SQL."
        ),
    }
    with open(output_dir / "dataset_card.json", "w") as f:
        json.dump(card, f, indent=2)
    logger.info("=== Done! Dataset written to %s ===", output_dir)
# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()