File size: 8,034 Bytes
a39d8ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
"""
data_factory/validator.py
==========================
SQL execution validation layer.

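GUARANTEE: Every record that passes this validator has a SQL that:
  1. Runs without error against the actual seeded SQLite schema
  2. Returns at least one row (non-empty result)
  3. Returns the expected column names

NOTE(review): `SQLValidator.validate` marks a statement as passed as soon as
it executes without a SQLite error, i.e. it enforces only (1). The row count
and column names are recorded on the `ValidationResult` but are not checked
in this module, so (2) and (3) must be verified by the caller.

No LLM-generated SQL ever reaches this validator — SQL always comes from
the human-verified template library. This validator is an extra safety net
to catch any copy-paste or formatting regressions.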
"""

from __future__ import annotations

import sqlite3
from dataclasses import dataclass, field
from typing import Any, Optional

from data_factory.schemas import build_connection, SCHEMA_CONTEXT
from data_factory.templates import Template


# ─────────────────────────────────────────────────────────────────────────────
# DATA CLASSES
# ─────────────────────────────────────────────────────────────────────────────

@dataclass
class ValidationResult:
    """Outcome of running a single SQL statement through a validator."""

    # True when the statement was accepted and executed without error.
    passed: bool
    # The SQL text the result refers to.
    sql: str
    # Human-readable failure reason; None unless one is supplied.
    error: Optional[str] = None
    # Number of rows the statement returned (defaults to 0).
    row_count: int = 0
    # Result column names; each instance gets its own fresh list.
    columns: list[str] = field(default_factory=list)


@dataclass
class DataRecord:
    """One training example ready to be written to JSONL/Parquet."""

    domain: str
    difficulty: str
    sql: str
    nl_question: str      # The NL paraphrase used as prompt
    persona: str          # ceo | chatty | lazy_typist | non_techie | analyst | augmented
    has_order: bool
    schema_context: str
    row_count: int        # From validation run
    columns: list[str]    # From validation run
    source: str           # "template_base" | "vllm_persona" | "rule_augmented"
    template_id: int      # Index into ALL_TEMPLATES

    def to_training_dict(self) -> dict[str, Any]:
        """
        Return the dictionary that will be written to the output dataset.

        Format is compatible with TRL / HuggingFace `datasets`:
          prompt  : chat-format messages list (system + user)
          sql     : ground-truth SQL (label / reward reference)
          metadata: auxiliary fields for curriculum or filtering

        The user message embeds `schema_context` under a "DATABASE SCHEMA"
        heading, followed by the natural-language question.
        """
        # The em dash is spelled as an escape so the prompt text cannot be
        # corrupted again if this file is re-saved with the wrong encoding
        # (the literal previously contained the mojibake of U+2014).
        system_msg = (
            "You are an expert SQL analyst. "
            "Write a single SELECT query that answers the question. "
            "Output ONLY the SQL query \u2014 no markdown, no explanation, no backticks."
        )
        user_msg = (
            f"DATABASE SCHEMA\n"
            f"---------------\n"
            f"{self.schema_context}\n\n"
            f"QUESTION: {self.nl_question}"
        )
        return {
            "prompt": [
                {"role": "system", "content": system_msg},
                {"role": "user", "content": user_msg},
            ],
            "sql": self.sql,
            "metadata": {
                "domain": self.domain,
                "difficulty": self.difficulty,
                "persona": self.persona,
                "has_order": self.has_order,
                "row_count": self.row_count,
                "columns": self.columns,
                "source": self.source,
                "template_id": self.template_id,
            },
        }


# ─────────────────────────────────────────────────────────────────────────────
# VALIDATOR
# ─────────────────────────────────────────────────────────────────────────────

class SQLValidator:
    """
    Validates SQL against a seeded in-memory SQLite connection.

    One validator per domain to reuse the same connection for all templates
    in that domain (performance optimization).

    Instances can also be used as context managers; leaving the `with`
    block closes the underlying connection.
    """

    # Leading keywords that are rejected before execution. Built once at
    # class level instead of on every `validate` call.
    _FORBIDDEN_FIRST_WORDS = frozenset({
        "insert", "update", "delete", "drop", "alter",
        "create", "replace", "truncate", "pragma",
    })

    def __init__(self, domain: str, seed: int = 42) -> None:
        """Open the connection for `domain` via `build_connection`."""
        self.domain = domain
        self._conn = build_connection(domain, seed=seed)

    def __enter__(self) -> "SQLValidator":
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()

    def validate(self, sql: str) -> ValidationResult:
        """
        Execute SQL and return a ValidationResult.

        SQLite-level failures are reported through the result object
        (`passed=False` with `error` set) rather than raised.

        A statement passes as soon as it executes without error; the row
        count and column names are recorded but not checked here.
        """
        # Drop surrounding whitespace and any run of trailing semicolons
        # (including ones separated by whitespace) so that input such as
        # "; ;" is reported as empty instead of being executed.
        sql = sql.strip().rstrip("; \t\r\n")
        if not sql:
            return ValidationResult(passed=False, sql=sql, error="Empty SQL string.")

        # Block any write operations.
        # NOTE: only the leading keyword is inspected, so this is a cheap
        # guard against obvious mistakes, not a full read-only sandbox.
        first_word = sql.split(None, 1)[0].lower()
        if first_word in self._FORBIDDEN_FIRST_WORDS:
            return ValidationResult(
                passed=False, sql=sql,
                error=f"Write operation '{first_word.upper()}' is not permitted."
            )

        try:
            cur = self._conn.execute(sql)
            cols = [d[0] for d in cur.description] if cur.description else []
            rows = cur.fetchall()
        except (sqlite3.Error, sqlite3.Warning) as exc:
            # sqlite3.Warning is not a subclass of sqlite3.Error; older
            # Python versions raise it for multi-statement strings, so it
            # has to be caught explicitly to honour the no-raise contract.
            return ValidationResult(passed=False, sql=sql, error=str(exc))

        return ValidationResult(
            passed=True,
            sql=sql,
            row_count=len(rows),
            columns=cols,
        )

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self._conn.close()


def validate_template(template: Template, seed: int = 42) -> ValidationResult:
    """
    Convenience function: validate a single template.

    Builds a throwaway SQLValidator for the template's domain, validates the
    template's SQL, and always closes the validator's connection afterwards.
    """
    v = SQLValidator(template["domain"], seed=seed)
    try:
        return v.validate(template["sql"])
    finally:
        # Release the connection even if the template lacks an "sql" key or
        # validation raises unexpectedly.
        v.close()


def validate_all_templates(templates: list[Template], seed: int = 42) -> dict[str, Any]:
    """
    Run validation across all templates. Returns a summary dict.
    Used during CI / smoke testing.

    One SQLValidator is built per domain listed in SCHEMA_MAP and shared by
    every template of that domain. All validators are closed before
    returning, including when an exception interrupts the run.

    Returns a dict with:
      total    : number of templates given
      passed   : number of templates whose SQL validated
      failed   : number of templates that did not
      failures : one entry per failure with index, domain, the first 80
                 characters of the SQL, and the error message
    """
    from data_factory.schemas import SCHEMA_MAP

    validators: dict[str, SQLValidator] = {}
    failures: list[dict[str, Any]] = []

    try:
        # Built inside the try block so that validators created before a
        # failing one are still closed.
        for domain in SCHEMA_MAP:
            validators[domain] = SQLValidator(domain, seed)

        for index, template in enumerate(templates):
            domain = template["domain"]
            validator = validators.get(domain)
            if validator is None:
                # Report a template that points at an unregistered domain as
                # a failure instead of aborting the whole summary.
                failures.append({"index": index, "domain": domain,
                                 "sql": template["sql"][:80],
                                 "error": f"Unknown domain '{domain}'."})
                continue

            result = validator.validate(template["sql"])
            if not result.passed:
                failures.append({"index": index, "domain": domain,
                                 "sql": template["sql"][:80],
                                 "error": result.error})
    finally:
        for validator in validators.values():
            validator.close()

    return {
        "total": len(templates),
        "passed": len(templates) - len(failures),
        "failed": len(failures),
        "failures": failures,
    }


def build_record(
    template: Template,
    template_idx: int,
    nl_question: str,
    persona: str,
    source: str,
    validator: SQLValidator,
) -> Optional[DataRecord]:
    """
    Turn a template plus one NL phrasing into a DataRecord, gated on validation.

    The template's SQL is run through `validator`; the resulting row count
    and column names are copied onto the record.

    Parameters
    ----------
    template     : The source template (contains SQL, domain, difficulty).
    template_idx : Index of template in ALL_TEMPLATES (for deduplication).
    nl_question  : The NL paraphrase to use as the prompt.
    persona      : Which persona/strategy generated this NL.
    source       : 'template_base' | 'vllm_persona' | 'rule_augmented'
    validator    : Pre-built SQLValidator for this domain.

    Returns
    -------
    The populated DataRecord, or None when the SQL does not validate.
    """
    outcome = validator.validate(template["sql"])
    if outcome.passed:
        domain = template["domain"]
        return DataRecord(
            domain=domain,
            difficulty=template["difficulty"],
            # The record keeps the template's SQL text as written.
            sql=template["sql"],
            nl_question=nl_question,
            persona=persona,
            has_order=template["has_order"],
            schema_context=SCHEMA_CONTEXT[domain],
            row_count=outcome.row_count,
            columns=outcome.columns,
            source=source,
            template_id=template_idx,
        )

    return None