"""
data_factory/pipeline.py
=========================
Master orchestration pipeline for the NL2SQL Synthetic Data Factory.
This module ties together:
1. Template library (66 verified SQL templates across 4 domains)
2. Rule-based NL augmentation (augmentor.py)
3. vLLM persona-based NL generation (generator.py)
4. SQL execution validation (validator.py)
5. Output serialisation (JSONL + Parquet)
Run modes:
--mode base : Only uses template base_nl + rule augmentation (no GPU required)
--mode full : base + vLLM persona generation (requires H100)
Output dataset format (JSONL, one record per line):
{
"prompt": [{"role": "system", ...}, {"role": "user", ...}],
"sql": "SELECT ...",
"metadata": { "domain", "difficulty", "persona", ... }
}
This format is directly loadable by:
datasets.load_dataset("json", data_files="output/train.jsonl")
"""
from __future__ import annotations
import argparse
import json
import logging
import os
import random
import time
from pathlib import Path
from typing import Any, Iterator, Optional
# Configure root logging at import time so every message emitted by the
# pipeline carries a wall-clock timestamp, level and logger name.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)
# Module-wide logger shared by all functions in this file.
logger = logging.getLogger("pipeline")
# ─────────────────────────────────────────────────────────────────────────────
# HELPERS
# ─────────────────────────────────────────────────────────────────────────────
def _ensure_dirs(*dirs: Path) -> None:
for d in dirs:
d.mkdir(parents=True, exist_ok=True)
def _write_jsonl(records: list[dict], path: Path) -> None:
with open(path, "w", encoding="utf-8") as f:
for rec in records:
f.write(json.dumps(rec, ensure_ascii=False) + "\n")
logger.info("Wrote %d records to %s", len(records), path)
def _write_parquet(records: list[dict], path: Path) -> None:
try:
import pandas as pd
df = pd.DataFrame(records)
df.to_parquet(path, index=False, engine="pyarrow", compression="snappy")
logger.info("Wrote %d records to %s (Parquet)", len(records), path)
except ImportError:
logger.warning("pandas/pyarrow not installed β skipping Parquet output.")
def _train_val_test_split(
records: list[dict],
train_frac: float = 0.90,
val_frac: float = 0.05,
seed: int = 42,
) -> tuple[list[dict], list[dict], list[dict]]:
"""
Stratified split by (domain, difficulty) to ensure all combinations
are represented in every split.
"""
rng = random.Random(seed)
from collections import defaultdict
buckets: dict[str, list[dict]] = defaultdict(list)
for rec in records:
key = f"{rec['metadata']['domain']}_{rec['metadata']['difficulty']}"
buckets[key].append(rec)
train, val, test = [], [], []
for key, bucket in buckets.items():
rng.shuffle(bucket)
n = len(bucket)
n_train = max(1, int(n * train_frac))
n_val = max(1, int(n * val_frac))
train.extend(bucket[:n_train])
val.extend(bucket[n_train:n_train + n_val])
test.extend(bucket[n_train + n_val:])
rng.shuffle(train)
rng.shuffle(val)
rng.shuffle(test)
return train, val, test
# ─────────────────────────────────────────────────────────────────────────────
# PHASE 1: BASE + RULE AUGMENTATION (no GPU required)
# ─────────────────────────────────────────────────────────────────────────────
def run_base_pipeline(
    templates: list,
    n_augmentations: int = 5,
    seed: int = 42,
) -> list[dict]:
    """
    Generate training records without any LLM involvement.

    For every template this produces:
    (a) one candidate from the template's canonical ``base_nl``, and
    (b) one candidate per rule-based variant returned by ``augment_nl``.

    Each candidate goes through ``build_record``; candidates for which it
    returns a falsy value are dropped.

    Args:
        templates: Template mappings. This function reads ``"domain"`` and
            ``"base_nl"``; the whole mapping is forwarded to ``build_record``.
        n_augmentations: Forwarded to ``augment_nl`` as ``n``.
        seed: Seed for every ``SQLValidator``; ``seed + template index`` is
            used as the augmentation seed.

    Returns:
        A list of training dicts (ready to write to JSONL).

    Raises:
        KeyError: If a template's domain has no entry in ``SCHEMA_MAP``.
    """
    from data_factory.augmentor import augment_nl
    from data_factory.validator import SQLValidator, build_record
    from data_factory.schemas import SCHEMA_MAP

    records: list[dict] = []
    validators: dict = {}
    try:
        # One validator per domain, shared by all templates of that domain.
        # Built inside the try so a failure part-way through still closes the
        # validators that were already created.
        for domain in SCHEMA_MAP:
            validators[domain] = SQLValidator(domain, seed=seed)

        for t_idx, template in enumerate(templates):
            v = validators[template["domain"]]

            # (a) Canonical base_nl
            rec = build_record(
                template=template,
                template_idx=t_idx,
                nl_question=template["base_nl"],
                persona="canonical",
                source="template_base",
                validator=v,
            )
            if rec:
                records.append(rec.to_training_dict())

            # (b) Rule-augmented variants
            augmented = augment_nl(
                nl_question=template["base_nl"],
                n=n_augmentations,
                seed=seed + t_idx,
            )
            for nl_variant in augmented:
                rec = build_record(
                    template=template,
                    template_idx=t_idx,
                    nl_question=nl_variant,
                    persona="rule_augmented",
                    source="rule_augmented",
                    validator=v,
                )
                if rec:
                    records.append(rec.to_training_dict())
    finally:
        # Close validators even when a template fails part-way through.
        for v in validators.values():
            v.close()

    logger.info("Base pipeline: %d records generated from %d templates.", len(records), len(templates))
    return records
# ─────────────────────────────────────────────────────────────────────────────
# PHASE 2: vLLM PERSONA GENERATION (H100 required)
# ─────────────────────────────────────────────────────────────────────────────
def run_vllm_pipeline(
    templates: list,
    generator,  # VLLMGenerator instance
    personas: list[str],
    n_variants_per_persona: int = 10,
    batch_size: int = 64,
    temperature: float = 0.85,
    max_new_tokens: int = 350,
    seed: int = 42,
) -> list[dict]:
    """
    Generate additional NL variants using the LLM, then validate SQL.

    ``generate_persona_variants_batch`` yields job results, each carrying a
    ``template_idx``, a ``persona`` and a list of ``nl_variants``. Every
    variant goes through ``build_record`` against its source template;
    variants for which it returns a falsy value are dropped.

    Args:
        templates: Template mappings; ``template_idx`` in each job result
            indexes into this list.
        generator: Generation backend, forwarded unchanged.
        personas: Persona names, forwarded unchanged.
        n_variants_per_persona: Forwarded to the batch generator.
        batch_size: Forwarded to the batch generator.
        temperature: Forwarded to the batch generator.
        max_new_tokens: Forwarded to the batch generator.
        seed: Seed for every ``SQLValidator``.

    Returns:
        A list of training dicts.

    Raises:
        KeyError: If a template's domain has no entry in ``SCHEMA_MAP``.
    """
    from data_factory.generator import generate_persona_variants_batch
    from data_factory.validator import SQLValidator, build_record
    from data_factory.schemas import SCHEMA_MAP

    records: list[dict] = []
    validators: dict = {}
    try:
        # Built inside the try so a failure part-way through still closes the
        # validators that were already created.
        for domain in SCHEMA_MAP:
            validators[domain] = SQLValidator(domain, seed=seed)

        gen_iter = generate_persona_variants_batch(
            templates_subset=templates,
            generator=generator,
            personas=personas,
            n_variants_per_persona=n_variants_per_persona,
            batch_size=batch_size,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
        )
        for job_result in gen_iter:
            t_idx = job_result["template_idx"]
            persona = job_result["persona"]
            template = templates[t_idx]
            v = validators[template["domain"]]
            for nl_variant in job_result["nl_variants"]:
                rec = build_record(
                    template=template,
                    template_idx=t_idx,
                    nl_question=nl_variant,
                    persona=persona,
                    source="vllm_persona",
                    validator=v,
                )
                if rec:
                    records.append(rec.to_training_dict())
    finally:
        # Close validators even when generation or validation raises.
        for v in validators.values():
            v.close()

    logger.info("vLLM pipeline: %d records generated.", len(records))
    return records
# ─────────────────────────────────────────────────────────────────────────────
# CHECKPOINT UTILITIES
# ─────────────────────────────────────────────────────────────────────────────
def save_checkpoint(records: list[dict], checkpoint_dir: Path, name: str) -> Path:
    """
    Persist *records* as ``<checkpoint_dir>/<name>.jsonl``.

    Returns the path of the file that was written.
    """
    target = checkpoint_dir / f"{name}.jsonl"
    _write_jsonl(records, target)
    return target
def load_checkpoint(checkpoint_dir: Path, name: str) -> Optional[list[dict]]:
    """
    Read ``<checkpoint_dir>/<name>.jsonl`` back into a list of records.

    Returns ``None`` when the checkpoint file does not exist. Lines that are
    empty after stripping whitespace are ignored; every other line is parsed
    as one JSON document.
    """
    ckpt_file = checkpoint_dir / f"{name}.jsonl"
    if not ckpt_file.exists():
        return None

    with open(ckpt_file, encoding="utf-8") as handle:
        stripped_lines = (raw.strip() for raw in handle)
        loaded = [json.loads(text) for text in stripped_lines if text]

    logger.info("Loaded %d records from checkpoint %s", len(loaded), ckpt_file)
    return loaded
# ─────────────────────────────────────────────────────────────────────────────
# DATASET STATISTICS
# ─────────────────────────────────────────────────────────────────────────────
def print_dataset_stats(records: list[dict]) -> None:
    """
    Print a summary of *records* to stdout.

    Counts are grouped by the ``domain``, ``difficulty``, ``persona`` and
    ``source`` entries of each record's ``metadata``. Domain and difficulty
    rows also show their share of the total as a percentage.
    """
    from collections import Counter

    total = len(records)
    # All four tallies are computed before anything is printed, so a record
    # with a missing metadata key fails before any partial output.
    counts = {
        field: Counter(r["metadata"][field] for r in records)
        for field in ("domain", "difficulty", "persona", "source")
    }
    rule = "=" * 55

    def emit(title: str, field: str, with_share: bool) -> None:
        print(title)
        for label, count in sorted(counts[field].items()):
            if with_share:
                print(f" {label:20s}: {count:6,} ({count/total*100:.1f}%)")
            else:
                print(f" {label:20s}: {count:6,}")

    print("\n" + rule)
    print(f" DATASET STATISTICS ({total:,} total records)")
    print(rule)
    emit("\nBy Domain:", "domain", True)
    emit("\nBy Difficulty:", "difficulty", True)
    emit("\nBy Persona/Source:", "persona", False)
    emit("\nBy Source:", "source", False)
    print(rule + "\n")
# ─────────────────────────────────────────────────────────────────────────────
# MAIN ENTRY POINT
# ─────────────────────────────────────────────────────────────────────────────
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the data factory."""
    all_domains = ["ecommerce", "healthcare", "finance", "hr"]
    all_difficulties = ["easy", "medium", "hard"]

    parser = argparse.ArgumentParser(
        description="NL2SQL Synthetic Data Factory — generates verified training data."
    )
    parser.add_argument(
        "--mode", choices=["base", "full"], default="base",
        help="base = rule augmentation only (no GPU). full = + vLLM on H100.",
    )
    parser.add_argument("--model", default="meta-llama/Meta-Llama-3-70B-Instruct",
                        help="HuggingFace model name for vLLM (full mode only).")
    parser.add_argument("--tensor-parallel", type=int, default=4,
                        help="Tensor parallel size for vLLM (number of H100s).")
    parser.add_argument("--n-rule-augments", type=int, default=5,
                        help="Number of rule-based NL augmentations per template.")
    parser.add_argument("--n-persona-variants", type=int, default=10,
                        help="Number of vLLM NL variants per (template, persona) pair.")
    parser.add_argument("--batch-size", type=int, default=64,
                        help="vLLM batch size (larger = faster on H100).")
    parser.add_argument("--temperature", type=float, default=0.85,
                        help="Sampling temperature for vLLM generation.")
    parser.add_argument("--output-dir", type=str, default="generated_data/output",
                        help="Directory to write final dataset files.")
    parser.add_argument("--checkpoint-dir", type=str, default="generated_data/checkpoints",
                        help="Directory for intermediate checkpoints.")
    parser.add_argument("--seed", type=int, default=42, help="Global random seed.")
    parser.add_argument("--no-parquet", action="store_true",
                        help="Skip Parquet output (write only JSONL).")
    parser.add_argument("--resume", action="store_true",
                        help="Resume from latest checkpoint if available.")
    parser.add_argument("--domains", nargs="+",
                        choices=all_domains,
                        default=list(all_domains),
                        help="Domains to include (default: all 4).")
    parser.add_argument("--difficulties", nargs="+",
                        choices=all_difficulties,
                        default=list(all_difficulties),
                        help="Difficulty levels to include (default: all 3).")
    return parser


def _dedupe_by_question(records: list[dict]) -> list[dict]:
    """Keep the first record for each distinct user-message text, in order."""
    seen_nl: set[str] = set()
    deduped: list[dict] = []
    for rec in records:
        nl = rec["prompt"][1]["content"]  # user message contains the NL question
        if nl not in seen_nl:
            seen_nl.add(nl)
            deduped.append(rec)
    return deduped


def main(argv: Optional[list[str]] = None) -> None:
    """
    Run the data factory end to end.

    Steps: parse arguments, filter templates by domain and difficulty, run
    Phase 1 (base + rule augmentation), optionally run Phase 2 (vLLM persona
    generation, ``--mode full``), deduplicate by question text, print
    statistics, split into train / val / test, and write JSONL, optional
    Parquet and a ``dataset_card.json`` into the output directory.

    With ``--resume``, a phase whose checkpoint file exists is loaded from
    disk instead of being recomputed.
    NOTE(review): a resumed checkpoint is used as-is; it is not checked
    against the current --domains / --difficulties / --seed values.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.
    """
    args = _build_arg_parser().parse_args(argv)

    output_dir = Path(args.output_dir)
    checkpoint_dir = Path(args.checkpoint_dir)
    _ensure_dirs(output_dir, checkpoint_dir)

    # ── Load templates ────────────────────────────────────────────────────
    from data_factory.templates import ALL_TEMPLATES

    templates = [
        t for t in ALL_TEMPLATES
        if t["domain"] in args.domains and t["difficulty"] in args.difficulties
    ]
    logger.info("Loaded %d templates (domains=%s, difficulties=%s).",
                len(templates), args.domains, args.difficulties)

    # ── Phase 1: Base + rule augmentation ─────────────────────────────────
    all_records: list[dict] = []
    ckpt_base = load_checkpoint(checkpoint_dir, "phase1_base") if args.resume else None
    if ckpt_base is not None:
        all_records.extend(ckpt_base)
        logger.info("Resumed Phase 1 from checkpoint (%d records).", len(ckpt_base))
    else:
        logger.info("=== Phase 1: Base + Rule Augmentation ===")
        base_records = run_base_pipeline(
            templates=templates,
            n_augmentations=args.n_rule_augments,
            seed=args.seed,
        )
        all_records.extend(base_records)
        save_checkpoint(base_records, checkpoint_dir, "phase1_base")

    # ── Phase 2: vLLM persona generation (full mode only) ─────────────────
    if args.mode == "full":
        ckpt_vllm = load_checkpoint(checkpoint_dir, "phase2_vllm") if args.resume else None
        if ckpt_vllm is not None:
            all_records.extend(ckpt_vllm)
            logger.info("Resumed Phase 2 from checkpoint (%d records).", len(ckpt_vllm))
        else:
            logger.info("=== Phase 2: vLLM Persona Generation ===")
            # Imported lazily so base mode never needs the GPU stack.
            from data_factory.generator import VLLMGenerator
            from data_factory.config import PERSONAS

            generator = VLLMGenerator(
                model_name=args.model,
                mode="offline",
                tensor_parallel_size=args.tensor_parallel,
                gpu_memory_utilization=0.90,
            )
            vllm_records = run_vllm_pipeline(
                templates=templates,
                generator=generator,
                personas=PERSONAS,
                n_variants_per_persona=args.n_persona_variants,
                batch_size=args.batch_size,
                temperature=args.temperature,
                max_new_tokens=350,
                seed=args.seed,
            )
            all_records.extend(vllm_records)
            save_checkpoint(vllm_records, checkpoint_dir, "phase2_vllm")

    # ── Deduplication ─────────────────────────────────────────────────────
    logger.info("Deduplicating %d records...", len(all_records))
    deduped = _dedupe_by_question(all_records)
    logger.info("After dedup: %d unique records (removed %d duplicates).",
                len(deduped), len(all_records) - len(deduped))

    # ── Statistics ────────────────────────────────────────────────────────
    print_dataset_stats(deduped)

    # ── Train / Val / Test split ──────────────────────────────────────────
    train, val, test = _train_val_test_split(deduped, seed=args.seed)
    logger.info("Split: train=%d | val=%d | test=%d", len(train), len(val), len(test))

    # ── Write outputs ─────────────────────────────────────────────────────
    _write_jsonl(train, output_dir / "train.jsonl")
    _write_jsonl(val, output_dir / "val.jsonl")
    _write_jsonl(test, output_dir / "test.jsonl")
    if not args.no_parquet:
        _write_parquet(train, output_dir / "train.parquet")
        _write_parquet(val, output_dir / "val.parquet")
        _write_parquet(test, output_dir / "test.parquet")

    # ── Write dataset card ────────────────────────────────────────────────
    card = {
        "name": "NL2SQL-Bench Synthetic Training Dataset",
        "version": "1.0",
        "generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "total_records": len(deduped),
        "splits": {"train": len(train), "val": len(val), "test": len(test)},
        "domains": args.domains,
        "difficulties": args.difficulties,
        "mode": args.mode,
        "seed": args.seed,
        "sql_guarantee": (
            "Every SQL in this dataset was human-authored and execution-validated "
            "against a seeded SQLite database. Zero LLM-generated SQL."
        ),
    }
    # Explicit UTF-8, matching the JSONL writer, so output does not depend on
    # the platform's default encoding.
    with open(output_dir / "dataset_card.json", "w", encoding="utf-8") as f:
        json.dump(card, f, indent=2)

    logger.info("=== Done! Dataset written to %s ===", output_dir)
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()