"""
data_factory/validator.py
==========================
SQL execution validation layer.
GUARANTEE: Every record that passes this validator has a SQL that:
1. Runs without error against the actual seeded SQLite schema
2. Returns at least one row (non-empty result)
3. Returns the expected column names
No LLM-generated SQL ever reaches this validator — SQL always comes from
the human-verified template library. This validator is an extra safety net
to catch any copy-paste or formatting regressions.
"""
from __future__ import annotations
import sqlite3
from dataclasses import dataclass, field
from typing import Any, Optional
from data_factory.schemas import build_connection, SCHEMA_CONTEXT
from data_factory.templates import Template
# ─────────────────────────────────────────────────────────────────────────────
# DATA CLASSES
# ─────────────────────────────────────────────────────────────────────────────
@dataclass
class ValidationResult:
    """Outcome of running a single SQL statement through the validator."""

    # True when the statement executed without a SQLite error.
    passed: bool
    # The SQL text the result refers to.
    sql: str
    # Human-readable failure reason; stays None unless one is supplied.
    error: Optional[str] = None
    # Number of rows fetched; defaults to 0.
    row_count: int = 0
    # Result column names; every instance gets its own fresh list.
    columns: list[str] = field(default_factory=list)
@dataclass
class DataRecord:
    """One training example ready to be written to JSONL/Parquet."""

    domain: str
    difficulty: str
    sql: str
    nl_question: str      # The NL paraphrase used as prompt
    persona: str          # ceo | chatty | lazy_typist | non_techie | analyst | augmented
    has_order: bool
    schema_context: str
    row_count: int        # From validation run
    columns: list[str]    # From validation run
    source: str           # "template_base" | "vllm_persona" | "rule_augmented"
    template_id: int      # Index into ALL_TEMPLATES

    def to_training_dict(self) -> dict[str, Any]:
        """
        Return the dictionary that will be written to the output dataset.

        Format is compatible with TRL / HuggingFace `datasets`:
            prompt  : chat-format messages list (system + user)
            sql     : ground-truth SQL (label / reward reference)
            metadata: auxiliary fields for curriculum or filtering
        """
        # The em dash below is intentional; it had been stored mis-encoded,
        # which would have leaked garbage characters into every system prompt.
        system_msg = (
            "You are an expert SQL analyst. "
            "Write a single SELECT query that answers the question. "
            "Output ONLY the SQL query — no markdown, no explanation, no backticks."
        )
        user_msg = (
            f"DATABASE SCHEMA\n"
            f"---------------\n"
            f"{self.schema_context}\n\n"
            f"QUESTION: {self.nl_question}"
        )
        return {
            "prompt": [
                {"role": "system", "content": system_msg},
                {"role": "user", "content": user_msg},
            ],
            "sql": self.sql,
            "metadata": {
                "domain": self.domain,
                "difficulty": self.difficulty,
                "persona": self.persona,
                "has_order": self.has_order,
                "row_count": self.row_count,
                "columns": self.columns,
                "source": self.source,
                "template_id": self.template_id,
            },
        }
# ─────────────────────────────────────────────────────────────────────────────
# VALIDATOR
# ─────────────────────────────────────────────────────────────────────────────
class SQLValidator:
    """
    Validates SQL against a seeded in-memory SQLite connection.

    One validator per domain so the same connection is reused for all
    templates in that domain (performance optimization).
    """

    # Leading keywords that are rejected before anything is executed.
    # This is a first-keyword check only; it does not parse the statement.
    _FORBIDDEN_KEYWORDS = frozenset({
        "insert", "update", "delete", "drop", "alter",
        "create", "replace", "truncate", "pragma",
    })

    def __init__(self, domain: str, seed: int = 42) -> None:
        """Open the seeded connection for *domain* via ``build_connection``."""
        self.domain = domain
        self._conn = build_connection(domain, seed=seed)

    def validate(self, sql: str) -> ValidationResult:
        """
        Execute SQL and return a ValidationResult.

        Never raises for SQLite-level failures — always returns a result
        object. The SQL is stripped of surrounding whitespace and trailing
        semicolons before it is checked and executed; that normalized text
        is what the result carries.

        A result is ``passed=True`` whenever the statement executes without
        error. No minimum row count or expected-column check is applied
        here; callers needing those must inspect ``row_count`` / ``columns``.
        """
        sql = sql.strip().rstrip(";")
        if not sql:
            return ValidationResult(passed=False, sql=sql, error="Empty SQL string.")

        # Block any write operations
        tokens = sql.split()
        first_word = tokens[0].lower() if tokens else ""
        if first_word in self._FORBIDDEN_KEYWORDS:
            return ValidationResult(
                passed=False, sql=sql,
                error=f"Write operation '{first_word.upper()}' is not permitted."
            )

        try:
            cur = self._conn.execute(sql)
            cols = [d[0] for d in cur.description] if cur.description else []
            rows = cur.fetchall()
        except (sqlite3.Error, sqlite3.Warning) as exc:
            # sqlite3.Warning is NOT a subclass of sqlite3.Error. Before
            # Python 3.12, execute() raises it for multi-statement strings,
            # so it must be caught explicitly to honor "never raises".
            return ValidationResult(passed=False, sql=sql, error=str(exc))

        return ValidationResult(
            passed=True,
            sql=sql,
            row_count=len(rows),
            columns=cols,
        )

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self._conn.close()
def validate_template(template: Template, seed: int = 42) -> ValidationResult:
    """
    Convenience function: validate a single template.

    Builds a throwaway SQLValidator for the template's domain, validates the
    template's SQL, and closes the validator again.
    """
    v = SQLValidator(template["domain"], seed=seed)
    try:
        return v.validate(template["sql"])
    finally:
        # Release the connection even if the "sql" lookup or validation raises.
        v.close()
def validate_all_templates(templates: list[Template], seed: int = 42) -> dict[str, Any]:
    """
    Run validation across all templates. Returns a summary dict.

    Used during CI / smoke testing.

    The summary has the keys ``total``, ``passed``, ``failed`` (counts) and
    ``failures`` (one dict per failing template with its index, domain, the
    first 80 characters of its SQL, and the error message).

    Raises KeyError if a template names a domain that is not in SCHEMA_MAP.
    """
    from data_factory.schemas import SCHEMA_MAP

    validators: dict[str, SQLValidator] = {}
    passed_count = 0
    failed: list[dict[str, Any]] = []
    try:
        # One validator per known domain, shared by all templates of that domain.
        for domain in SCHEMA_MAP:
            validators[domain] = SQLValidator(domain, seed)

        for i, t in enumerate(templates):
            result = validators[t["domain"]].validate(t["sql"])
            if result.passed:
                passed_count += 1
            else:
                failed.append({"index": i, "domain": t["domain"],
                               "sql": t["sql"][:80], "error": result.error})
    finally:
        # Close every connection opened so far, including when a template
        # lookup or a validator constructor raised part-way through.
        for v in validators.values():
            v.close()

    return {
        "total": len(templates),
        "passed": passed_count,
        "failed": len(failed),
        "failures": failed,
    }
def build_record(
    template: Template,
    template_idx: int,
    nl_question: str,
    persona: str,
    source: str,
    validator: SQLValidator,
) -> Optional[DataRecord]:
    """
    Run the template's SQL through *validator* and wrap it in a DataRecord.

    Parameters
    ----------
    template     : The source template (contains SQL, domain, difficulty).
    template_idx : Index of template in ALL_TEMPLATES (for deduplication).
    nl_question  : The NL paraphrase to use as the prompt.
    persona      : Which persona/strategy generated this NL.
    source       : 'template_base' | 'vllm_persona' | 'rule_augmented'
    validator    : Pre-built SQLValidator for this domain.

    Returns the populated DataRecord, or None when validation fails.
    """
    outcome = validator.validate(template["sql"])
    if outcome.passed:
        domain = template["domain"]
        return DataRecord(
            domain=domain,
            difficulty=template["difficulty"],
            sql=template["sql"],
            nl_question=nl_question,
            persona=persona,
            has_order=template["has_order"],
            # Schema text is looked up by domain so it matches the validated DB.
            schema_context=SCHEMA_CONTEXT[domain],
            # Row count and column names come straight from the validation run.
            row_count=outcome.row_count,
            columns=outcome.columns,
            source=source,
            template_id=template_idx,
        )
    return None
|