"""
data_factory/pipeline.py
=========================
Master orchestration pipeline for the NL2SQL Synthetic Data Factory.
This module ties together:
1. Template library (66 verified SQL templates across 4 domains)
2. Rule-based NL augmentation (augmentor.py)
3. vLLM persona-based NL generation (generator.py)
4. SQL execution validation (validator.py)
5. Output serialisation (JSONL + Parquet)
Run modes:
--mode base : Only uses template base_nl + rule augmentation (no GPU required)
--mode full : base + vLLM persona generation (requires H100)
Output dataset format (JSONL, one record per line):
{
"prompt": [{"role": "system", ...}, {"role": "user", ...}],
"sql": "SELECT ...",
"metadata": { "domain", "difficulty", "persona", ... }
}
This format is directly loadable by:
datasets.load_dataset("json", data_files="output/train.jsonl")
"""
from __future__ import annotations
import argparse
import json
import logging
import random
import time
from pathlib import Path
from typing import Optional
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
datefmt="%H:%M:%S",
)
logger = logging.getLogger("pipeline")
# ─────────────────────────────────────────────────────────────────────────────
# HELPERS
# ─────────────────────────────────────────────────────────────────────────────
def _ensure_dirs(*dirs: Path) -> None:
for d in dirs:
d.mkdir(parents=True, exist_ok=True)
def _write_jsonl(records: list[dict], path: Path) -> None:
with open(path, "w", encoding="utf-8") as f:
for rec in records:
f.write(json.dumps(rec, ensure_ascii=False) + "\n")
logger.info("Wrote %d records to %s", len(records), path)
def _write_parquet(records: list[dict], path: Path) -> None:
try:
import pandas as pd
df = pd.DataFrame(records)
df.to_parquet(path, index=False, engine="pyarrow", compression="snappy")
logger.info("Wrote %d records to %s (Parquet)", len(records), path)
except ImportError:
        logger.warning("pandas/pyarrow not installed; skipping Parquet output.")
def _train_val_test_split(
records: list[dict],
train_frac: float = 0.90,
val_frac: float = 0.05,
seed: int = 42,
) -> tuple[list[dict], list[dict], list[dict]]:
"""
    Stratified split by (domain, difficulty): each bucket is shuffled and sliced
    so every combination appears in train, and in val/test whenever the bucket
    is large enough to supply them.
"""
rng = random.Random(seed)
from collections import defaultdict
buckets: dict[str, list[dict]] = defaultdict(list)
for rec in records:
key = f"{rec['metadata']['domain']}_{rec['metadata']['difficulty']}"
buckets[key].append(rec)
train, val, test = [], [], []
for key, bucket in buckets.items():
rng.shuffle(bucket)
n = len(bucket)
n_train = max(1, int(n * train_frac))
n_val = max(1, int(n * val_frac))
train.extend(bucket[:n_train])
val.extend(bucket[n_train:n_train + n_val])
test.extend(bucket[n_train + n_val:])
rng.shuffle(train)
rng.shuffle(val)
rng.shuffle(test)
return train, val, test
# ─────────────────────────────────────────────────────────────────────────────
# PHASE 1: BASE + RULE AUGMENTATION (no GPU required)
# ─────────────────────────────────────────────────────────────────────────────
def run_base_pipeline(
templates: list,
n_augmentations: int = 5,
seed: int = 42,
) -> list[dict]:
"""
Generate training records from:
(a) the canonical base_nl of each template
(b) rule-based augmented NL variants
Returns a list of training dicts (ready to write to JSONL).
"""
from data_factory.augmentor import augment_nl
from data_factory.validator import SQLValidator, build_record
from data_factory.schemas import SCHEMA_MAP
# Build one validator per domain (reuse connection across templates)
validators = {domain: SQLValidator(domain, seed=seed) for domain in SCHEMA_MAP}
records: list[dict] = []
for t_idx, template in enumerate(templates):
v = validators[template["domain"]]
# (a) Canonical base_nl
rec = build_record(
template=template,
template_idx=t_idx,
nl_question=template["base_nl"],
persona="canonical",
source="template_base",
validator=v,
)
if rec:
records.append(rec.to_training_dict())
# (b) Rule-augmented variants
augmented = augment_nl(
nl_question=template["base_nl"],
n=n_augmentations,
seed=seed + t_idx,
)
for nl_variant in augmented:
rec = build_record(
template=template,
template_idx=t_idx,
nl_question=nl_variant,
persona="rule_augmented",
source="rule_augmented",
validator=v,
)
if rec:
records.append(rec.to_training_dict())
for v in validators.values():
v.close()
logger.info("Base pipeline: %d records generated from %d templates.", len(records), len(templates))
return records
# ─────────────────────────────────────────────────────────────────────────────
# PHASE 2: vLLM PERSONA GENERATION (H100 required)
# ─────────────────────────────────────────────────────────────────────────────
def run_vllm_pipeline(
templates: list,
generator, # VLLMGenerator instance
personas: list[str],
n_variants_per_persona: int = 10,
batch_size: int = 64,
temperature: float = 0.85,
max_new_tokens: int = 350,
seed: int = 42,
) -> list[dict]:
"""
Generate additional NL variants using the LLM, then validate SQL.
Returns a list of training dicts.
"""
from data_factory.generator import generate_persona_variants_batch
from data_factory.validator import SQLValidator, build_record
from data_factory.schemas import SCHEMA_MAP
validators = {domain: SQLValidator(domain, seed=seed) for domain in SCHEMA_MAP}
records: list[dict] = []
gen_iter = generate_persona_variants_batch(
templates_subset=templates,
generator=generator,
personas=personas,
n_variants_per_persona=n_variants_per_persona,
batch_size=batch_size,
temperature=temperature,
max_new_tokens=max_new_tokens,
)
for job_result in gen_iter:
t_idx = job_result["template_idx"]
persona = job_result["persona"]
template = templates[t_idx]
v = validators[template["domain"]]
for nl_variant in job_result["nl_variants"]:
rec = build_record(
template=template,
template_idx=t_idx,
nl_question=nl_variant,
persona=persona,
source="vllm_persona",
validator=v,
)
if rec:
records.append(rec.to_training_dict())
for v in validators.values():
v.close()
logger.info("vLLM pipeline: %d records generated.", len(records))
return records
# ─────────────────────────────────────────────────────────────────────────────
# CHECKPOINT UTILITIES
# ─────────────────────────────────────────────────────────────────────────────
def save_checkpoint(records: list[dict], checkpoint_dir: Path, name: str) -> Path:
path = checkpoint_dir / f"{name}.jsonl"
_write_jsonl(records, path)
return path
def load_checkpoint(checkpoint_dir: Path, name: str) -> Optional[list[dict]]:
path = checkpoint_dir / f"{name}.jsonl"
if not path.exists():
return None
records = []
with open(path, encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
records.append(json.loads(line))
logger.info("Loaded %d records from checkpoint %s", len(records), path)
return records
# ─────────────────────────────────────────────────────────────────────────────
# DATASET STATISTICS
# ─────────────────────────────────────────────────────────────────────────────
def print_dataset_stats(records: list[dict]) -> None:
from collections import Counter
domains = Counter(r["metadata"]["domain"] for r in records)
diffs = Counter(r["metadata"]["difficulty"] for r in records)
personas = Counter(r["metadata"]["persona"] for r in records)
sources = Counter(r["metadata"]["source"] for r in records)
print("\n" + "=" * 55)
print(f" DATASET STATISTICS ({len(records):,} total records)")
print("=" * 55)
print("\nBy Domain:")
for k, v in sorted(domains.items()):
print(f" {k:20s}: {v:6,} ({v/len(records)*100:.1f}%)")
print("\nBy Difficulty:")
for k, v in sorted(diffs.items()):
print(f" {k:20s}: {v:6,} ({v/len(records)*100:.1f}%)")
print("\nBy Persona/Source:")
for k, v in sorted(personas.items()):
print(f" {k:20s}: {v:6,}")
print("\nBy Source:")
for k, v in sorted(sources.items()):
print(f" {k:20s}: {v:6,}")
print("=" * 55 + "\n")
# ─────────────────────────────────────────────────────────────────────────────
# MAIN ENTRY POINT
# ─────────────────────────────────────────────────────────────────────────────
def main() -> None:
parser = argparse.ArgumentParser(
description="NL2SQL Synthetic Data Factory β€” generates verified training data."
)
parser.add_argument(
"--mode", choices=["base", "full"], default="base",
help="base = rule augmentation only (no GPU). full = + vLLM on H100.",
)
parser.add_argument("--model", default="meta-llama/Meta-Llama-3-70B-Instruct",
help="HuggingFace model name for vLLM (full mode only).")
parser.add_argument("--tensor-parallel", type=int, default=4,
help="Tensor parallel size for vLLM (number of H100s).")
parser.add_argument("--n-rule-augments", type=int, default=5,
help="Number of rule-based NL augmentations per template.")
parser.add_argument("--n-persona-variants", type=int, default=10,
help="Number of vLLM NL variants per (template, persona) pair.")
parser.add_argument("--batch-size", type=int, default=64,
help="vLLM batch size (larger = faster on H100).")
parser.add_argument("--temperature", type=float, default=0.85,
help="Sampling temperature for vLLM generation.")
parser.add_argument("--output-dir", type=str, default="generated_data/output",
help="Directory to write final dataset files.")
parser.add_argument("--checkpoint-dir", type=str, default="generated_data/checkpoints",
help="Directory for intermediate checkpoints.")
parser.add_argument("--seed", type=int, default=42, help="Global random seed.")
parser.add_argument("--no-parquet", action="store_true",
help="Skip Parquet output (write only JSONL).")
parser.add_argument("--resume", action="store_true",
help="Resume from latest checkpoint if available.")
parser.add_argument("--domains", nargs="+",
choices=["ecommerce","healthcare","finance","hr"],
default=["ecommerce","healthcare","finance","hr"],
help="Domains to include (default: all 4).")
parser.add_argument("--difficulties", nargs="+",
choices=["easy","medium","hard"],
default=["easy","medium","hard"],
help="Difficulty levels to include (default: all 3).")
args = parser.parse_args()
output_dir = Path(args.output_dir)
checkpoint_dir = Path(args.checkpoint_dir)
_ensure_dirs(output_dir, checkpoint_dir)
# ── Load templates ─────────────────────────────────────────────────────
from data_factory.templates import ALL_TEMPLATES
templates = [
t for t in ALL_TEMPLATES
if t["domain"] in args.domains and t["difficulty"] in args.difficulties
]
logger.info("Loaded %d templates (domains=%s, difficulties=%s).",
len(templates), args.domains, args.difficulties)
# ── Phase 1: Base + rule augmentation ─────────────────────────────────
all_records: list[dict] = []
ckpt_base = load_checkpoint(checkpoint_dir, "phase1_base") if args.resume else None
if ckpt_base is not None:
all_records.extend(ckpt_base)
logger.info("Resumed Phase 1 from checkpoint (%d records).", len(ckpt_base))
else:
logger.info("=== Phase 1: Base + Rule Augmentation ===")
base_records = run_base_pipeline(
templates=templates,
n_augmentations=args.n_rule_augments,
seed=args.seed,
)
all_records.extend(base_records)
save_checkpoint(base_records, checkpoint_dir, "phase1_base")
# ── Phase 2: vLLM persona generation (full mode only) ─────────────────
if args.mode == "full":
ckpt_vllm = load_checkpoint(checkpoint_dir, "phase2_vllm") if args.resume else None
if ckpt_vllm is not None:
all_records.extend(ckpt_vllm)
logger.info("Resumed Phase 2 from checkpoint (%d records).", len(ckpt_vllm))
else:
logger.info("=== Phase 2: vLLM Persona Generation ===")
from data_factory.generator import VLLMGenerator
from data_factory.config import PERSONAS
generator = VLLMGenerator(
model_name=args.model,
mode="offline",
tensor_parallel_size=args.tensor_parallel,
gpu_memory_utilization=0.90,
)
vllm_records = run_vllm_pipeline(
templates=templates,
generator=generator,
personas=PERSONAS,
n_variants_per_persona=args.n_persona_variants,
batch_size=args.batch_size,
temperature=args.temperature,
max_new_tokens=350,
seed=args.seed,
)
all_records.extend(vllm_records)
save_checkpoint(vllm_records, checkpoint_dir, "phase2_vllm")
# ── Deduplication ──────────────────────────────────────────────────────
logger.info("Deduplicating %d records...", len(all_records))
seen_nl: set[str] = set()
deduped: list[dict] = []
for rec in all_records:
nl = rec["prompt"][1]["content"] # user message contains the NL question
if nl not in seen_nl:
seen_nl.add(nl)
deduped.append(rec)
logger.info("After dedup: %d unique records (removed %d duplicates).",
len(deduped), len(all_records) - len(deduped))
# ── Statistics ─────────────────────────────────────────────────────────
print_dataset_stats(deduped)
# ── Train / Val / Test split ───────────────────────────────────────────
train, val, test = _train_val_test_split(deduped, seed=args.seed)
logger.info("Split: train=%d | val=%d | test=%d", len(train), len(val), len(test))
# ── Write outputs ─────────────────────────────────────────────────────
_write_jsonl(train, output_dir / "train.jsonl")
_write_jsonl(val, output_dir / "val.jsonl")
_write_jsonl(test, output_dir / "test.jsonl")
if not args.no_parquet:
_write_parquet(train, output_dir / "train.parquet")
_write_parquet(val, output_dir / "val.parquet")
_write_parquet(test, output_dir / "test.parquet")
# ── Write dataset card ─────────────────────────────────────────────────
card = {
"name": "NL2SQL-Bench Synthetic Training Dataset",
"version": "1.0",
"generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
"total_records": len(deduped),
"splits": {"train": len(train), "val": len(val), "test": len(test)},
"domains": args.domains,
"difficulties": args.difficulties,
"mode": args.mode,
"seed": args.seed,
"sql_guarantee": (
"Every SQL in this dataset was human-authored and execution-validated "
"against a seeded SQLite database. Zero LLM-generated SQL."
),
}
with open(output_dir / "dataset_card.json", "w") as f:
json.dump(card, f, indent=2)
logger.info("=== Done! Dataset written to %s ===", output_dir)
if __name__ == "__main__":
main()