"""
run_data_factory.py
====================

Entry point and smoke-test runner for the NL2SQL Data Factory.

Run this FIRST before running the full pipeline to verify:
  1. All 66 SQL templates execute without errors
  2. Rule augmentation produces diverse NL variants
  3. Validators correctly accept/reject queries
  4. Base pipeline generates well-formed JSONL records

Usage:
    # Smoke test only (fast, ~10 seconds)
    python run_data_factory.py --smoke-test

    # Base mode (no GPU, generates all rule-augmented records)
    python run_data_factory.py --mode base

    # Full mode (H100 required)
    python run_data_factory.py --mode full --model meta-llama/Meta-Llama-3-70B-Instruct --tensor-parallel 4

    # Preview what the dataset looks like
    python run_data_factory.py --smoke-test --show-samples 3

    # Pretty-print records from an existing output JSONL file
    python run_data_factory.py --inspect path/to/output.jsonl --inspect-n 5
"""

from __future__ import annotations

import argparse
import json
import sys
import textwrap
from pathlib import Path

# Allow running from project root
sys.path.insert(0, str(Path(__file__).parent))


# ─────────────────────────────────────────────────────────────────────────────
# SMOKE TEST
# ─────────────────────────────────────────────────────────────────────────────

def run_smoke_test(show_samples: int = 0) -> bool:
    """Run the four smoke checks and print a human-readable report.

    Every check runs even if an earlier one fails, so a single invocation
    reports all problems at once.

    Args:
        show_samples: When > 0, also print up to this many augmented NL
            variants per probe sentence and one sample pipeline record.

    Returns:
        True if every check passed, False otherwise.
    """
    print("\n" + "=" * 60)
    print(" NL2SQL DATA FACTORY — SMOKE TEST")
    print("=" * 60)

    # Evaluate into a list first (not a short-circuiting all(...) over calls)
    # so that every check is executed and reported.
    outcomes = [
        _check_templates(),
        _check_augmentation(show_samples),
        _check_validator(),
        _check_mini_pipeline(show_samples),
    ]
    all_passed = all(outcomes)

    # Summary
    print("\n" + "=" * 60)
    if all_passed:
        print(" ALL SMOKE TESTS PASSED ✓")
        print(" Safe to run: python run_data_factory.py --mode base")
    else:
        print(" SOME TESTS FAILED ✗ — fix errors before running pipeline")
    print("=" * 60 + "\n")
    return all_passed


def _check_templates() -> bool:
    """Step 1: execute every SQL template against the seeded data; True if none fail."""
    print("\n[1/4] Validating all SQL templates against seeded data...")
    from data_factory.templates import ALL_TEMPLATES, template_stats
    from data_factory.validator import validate_all_templates

    stats = template_stats()
    result = validate_all_templates(ALL_TEMPLATES)
    print(f" Templates: {stats}")
    print(f" Validation: {result['passed']}/{result['total']} passed", end="")
    if not result["failed"]:
        print(" ✓")
        return True

    print(f" ← {result['failed']} FAILURES:")
    for failure in result["failures"]:
        print(f" [{failure['domain']}] {failure['sql']}... → {failure['error']}")
    return False


def _check_augmentation(show_samples: int) -> bool:
    """Step 2: confirm rule augmentation yields at least one variant per probe sentence."""
    print("\n[2/4] Testing rule-based augmentation...")
    from data_factory.augmentor import augment_nl

    probe_nls = [
        "List all gold-tier customers ordered by name alphabetically. Return id, name, email, country.",
        "Which medications are prescribed most often? Return medication_name, category, times_prescribed.",
        "Rank active employees by salary within their department. Return salary_rank.",
    ]
    passed = True
    for nl in probe_nls:
        variants = augment_nl(nl, n=3, seed=42)
        if not variants:
            print(f" FAIL: No variants generated for: {nl[:50]}")
            passed = False
            continue
        print(f" ✓ {len(variants)} variants from: '{nl[:45]}...'")
        if show_samples > 0:
            for idx, variant in enumerate(variants[:show_samples], start=1):
                print(f" [{idx}] {variant}")
    return passed


def _check_validator() -> bool:
    """Step 3: confirm SQLValidator accepts a valid SELECT and rejects known-bad inputs."""
    print("\n[3/4] Testing SQL validator accept/reject logic...")
    from data_factory.validator import SQLValidator

    cases = [
        ("SELECT id, name FROM customers WHERE tier = 'gold'", True, "valid SELECT"),
        ("INSERT INTO customers VALUES (1,'x','x@x.com','IN','gold','2024-01-01')", False, "rejected INSERT"),
        ("SELECT nonexistent_col FROM customers", False, "bad column name"),
        ("", False, "empty string"),
    ]
    passed = True
    validator = SQLValidator("ecommerce")
    try:
        for sql, expect_pass, label in cases:
            vr = validator.validate(sql)
            status = "✓" if vr.passed == expect_pass else "✗"
            print(f" {status} {label}: passed={vr.passed}", end="")
            if not vr.passed:
                print(f" (error: {vr.error})", end="")
            print()
            if vr.passed != expect_pass:
                passed = False
    finally:
        # Release the validator even if validate() raises, so nothing leaks.
        validator.close()
    return passed


def _check_mini_pipeline(show_samples: int) -> bool:
    """Step 4: run the base pipeline on a few templates and check record structure."""
    n_mini = 5
    print(f"\n[4/4] Running mini base pipeline (first {n_mini} templates)...")
    from data_factory.pipeline import run_base_pipeline
    from data_factory.templates import ALL_TEMPLATES

    mini_templates = ALL_TEMPLATES[:n_mini]
    records = run_base_pipeline(mini_templates, n_augmentations=2, seed=42)

    passed = True
    # At least one (canonical) record per template is expected. Use the actual
    # slice length so the check stays right if fewer templates exist.
    expected_min = len(mini_templates)
    if len(records) < expected_min:
        print(f" FAIL: Only {len(records)} records (expected ≥{expected_min})")
        passed = False
    else:
        print(f" ✓ Generated {len(records)} records from {len(mini_templates)} templates")

    # Validate structure
    required_keys = {"prompt", "sql", "metadata"}
    for rec in records[:3]:
        missing = required_keys - rec.keys()
        if missing:
            print(f" FAIL: Record missing keys: {missing}")
            passed = False
            break
    else:
        print(" ✓ Record structure validated")

    # Only print the sample if it has the keys the printout reads; otherwise
    # a structural failure already reported above would turn into a KeyError.
    if show_samples > 0 and records and required_keys <= records[0].keys():
        _print_sample_record(records[0])
    return passed


def _print_sample_record(sample: dict) -> None:
    """Print a compact, truncated view of one pipeline record."""
    print("\n --- Sample Record ---")
    print(f" Domain: {sample['metadata']['domain']}")
    print(f" Difficulty: {sample['metadata']['difficulty']}")
    print(f" Persona: {sample['metadata']['persona']}")
    print(f" NL: {sample['prompt'][1]['content'].split('QUESTION: ')[-1][:100]}")
    print(f" SQL: {sample['sql'][:80]}...")


# ─────────────────────────────────────────────────────────────────────────────
# INSPECT DATASET
# ─────────────────────────────────────────────────────────────────────────────

def inspect_dataset(jsonl_path: str, n: int = 5) -> bool:
    """Pretty-print the first *n* records from an output JSONL file.

    Blank lines (e.g. a trailing newline pair) are skipped and do not count
    toward *n*.

    Args:
        jsonl_path: Path to the JSONL file.
        n: Maximum number of records to show; values <= 0 show none.

    Returns:
        True if the file was found and read. False if it is missing or not a
        regular file; a "File not found" message is printed in that case.

    Raises:
        json.JSONDecodeError: If a non-blank line among those read is not
            valid JSON.
        KeyError: If a record lacks the "prompt"/"sql"/"metadata" fields
            read below.
    """
    path = Path(jsonl_path)
    # is_file() rather than exists(): a directory would pass exists() and
    # then fail on open().
    if not path.is_file():
        print(f"File not found: {path}")
        return False

    records: list[dict] = []
    if n > 0:
        with path.open(encoding="utf-8") as fh:
            for line in fh:
                if not line.strip():
                    continue
                records.append(json.loads(line))
                if len(records) >= n:
                    break

    print(f"\n{'='*65}")
    print(f" Showing {len(records)} records from {path.name}")
    print(f"{'='*65}")

    for i, rec in enumerate(records):
        # The NL question follows the "QUESTION:" marker in the user message
        # (prompt[1]).
        nl = rec["prompt"][1]["content"].split("QUESTION:")[-1].strip()
        sql = rec["sql"]
        meta = rec["metadata"]
        print(f"\n[{i+1}] Domain={meta['domain']} | Difficulty={meta['difficulty']} | "
              f"Persona={meta['persona']} | Source={meta['source']}")
        print(f" NL: {textwrap.shorten(nl, 90)}")
        print(f" SQL: {textwrap.shorten(sql, 90)}")
    print()
    return True


# ─────────────────────────────────────────────────────────────────────────────
# MAIN
# ─────────────────────────────────────────────────────────────────────────────

def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser for this entry point."""
    parser = argparse.ArgumentParser(
        description="NL2SQL Data Factory — entry point.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        help="Run smoke test only (validates all templates, no output written).",
    )
    parser.add_argument(
        "--show-samples",
        type=int,
        default=0,
        help="During smoke test, show N sample NL variants and records.",
    )
    parser.add_argument(
        "--inspect",
        type=str,
        default=None,
        help="Path to a JSONL output file to inspect.",
    )
    parser.add_argument(
        "--inspect-n",
        type=int,
        default=5,
        help="Number of records to show when inspecting.",
    )
    parser.add_argument(
        "--mode",
        choices=["base", "full"],
        default="base",
        help=(
            "base: rule augmentation only, ~450 records, no GPU needed.\n"
            "full: + vLLM persona variants, 500K+ records, H100 required."
        ),
    )
    parser.add_argument("--model", default="meta-llama/Meta-Llama-3-70B-Instruct")
    parser.add_argument("--tensor-parallel", type=int, default=4)
    parser.add_argument("--n-rule-augments", type=int, default=5)
    parser.add_argument("--n-persona-variants", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--temperature", type=float, default=0.85)
    parser.add_argument("--output-dir", default="generated_data/output")
    parser.add_argument("--checkpoint-dir", default="generated_data/checkpoints")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--no-parquet", action="store_true")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument(
        "--domains",
        nargs="+",
        choices=["ecommerce", "healthcare", "finance", "hr"],
        default=["ecommerce", "healthcare", "finance", "hr"],
    )
    parser.add_argument(
        "--difficulties",
        nargs="+",
        choices=["easy", "medium", "hard"],
        default=["easy", "medium", "hard"],
    )
    return parser


def main() -> None:
    """Parse CLI arguments and dispatch to the smoke test, inspection, or pipeline.

    Exits with status 0 on success and 1 when the smoke test fails or the
    inspected file is missing. Otherwise control passes to the pipeline's own
    ``main``.
    """
    args = _build_parser().parse_args()

    if args.smoke_test:
        ok = run_smoke_test(show_samples=args.show_samples)
        sys.exit(0 if ok else 1)

    if args.inspect:
        found = inspect_dataset(args.inspect, n=args.inspect_n)
        # A missing file is an error; surface it to shell scripts / CI.
        sys.exit(0 if found else 1)

    # Forward to pipeline
    from data_factory.pipeline import main as pipeline_main

    # Re-parse with pipeline's own parser by forwarding sys.argv. The parse
    # above only validates the flags and serves --help.
    # NOTE(review): flags that exist only in this entry point (e.g.
    # --show-samples, --inspect-n) may be rejected by the pipeline's parser
    # if it does not define them — confirm against data_factory.pipeline.
    pipeline_main()


if __name__ == "__main__":
    main()