File size: 7,125 Bytes
cb2218d
a58baac
cb2218d
 
 
 
 
4864a9c
cb2218d
 
 
a58baac
 
 
9a1702d
 
a58baac
cb2218d
a58baac
 
 
 
 
 
 
 
 
9a1702d
a58baac
 
 
 
 
9a1702d
 
 
 
 
a58baac
 
 
 
9a1702d
a58baac
 
 
 
 
4d71586
a58baac
 
 
 
9a1702d
 
 
a58baac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a1702d
a58baac
9a1702d
 
 
a58baac
 
9a1702d
a58baac
9a1702d
 
 
 
a58baac
 
 
 
 
 
 
 
 
 
 
9a1702d
 
 
 
a58baac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb2218d
 
 
a58baac
 
7c1f47f
a58baac
 
 
 
 
 
cb2218d
910ce57
cb2218d
 
a58baac
cb2218d
 
a58baac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import os
import re
import gradio as gr
from huggingface_hub import InferenceClient

# Access token for the Hugging Face Inference API, read from the environment.
# os.getenv yields None when the variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

# Single shared inference client, reused by every request handled by the app.
client = InferenceClient(
    token=HF_TOKEN,
    model="Qwen/Qwen3-Coder-30B-A3B-Instruct:ovhcloud",
)

# ─────────────────────────────────────────────
# SYSTEM PROMPT  (strict, few-shot)
# ─────────────────────────────────────────────
# Instruction text sent as the "system" message of every chat completion.
# It fixes the output contract (raw DuckDB SQL only, or the literal sentinel
# NOT_A_DATA_QUESTION) and carries the few-shot examples.
# NOTE: the em dashes below were previously stored as mis-decoded bytes
# (mojibake) and were being sent to the model verbatim.
SYSTEM_PROMPT = """\
You are a strict SQL code generator for DuckDB.
YOUR ONLY JOB is to output a single, valid DuckDB SQL query.

ABSOLUTE OUTPUT RULES — violating any rule makes the output wrong:
1. Output ONLY raw SQL. No markdown, no code fences, no backticks, no explanations.
2. Never prefix with "sql", "SQL:", "Here is", or any natural language.
3. Never output anything after the semicolon.
4. If the question cannot be answered from the schema, output exactly: NOT_A_DATA_QUESTION
5. NOT_A_DATA_QUESTION also applies to: greetings, general knowledge, math unrelated to the schema, anything not about querying the provided tables.

SQL RULES:
- Use ONLY table and column names that appear in the schema — never invent names.
- Use DuckDB syntax exclusively. Never use SQLite or MySQL syntax.
- Text matching: always use ILIKE '%term%'. Never use LOWER() or UPPER() for comparison.
- Prefer the fewest JOINs and subqueries needed to answer the question.
- Never use SELECT * — always name the columns you need.
- Age filters: use a numeric comparison on the age column directly (e.g. age > 50).
- Counts: use COUNT(*) or COUNT(column). Alias it clearly, e.g. AS num_patients.
- Date arithmetic: NEVER use julianday(). Use datediff('day', start_col, end_col) for days between two timestamps. Use epoch(end_col - start_col) / 86400 for interval-to-days.
- Identifier quoting: wrap table and column names in double quotes if they start with a digit or contain special characters (e.g. "2b_concept", "my-column").
- String concatenation: use || operator, never CONCAT().
- Current date/time: use current_date or current_timestamp, never NOW().

FEW-SHOT EXAMPLES:

Schema:
CREATE TABLE patients (patient_id INT, age INT, gender VARCHAR, diagnosis VARCHAR, died BOOLEAN);
CREATE TABLE admissions (subject_id INT, admittime TIMESTAMP, dischtime TIMESTAMP, admission_type VARCHAR);

Q: How many patients above 50 have asthma?
A: SELECT COUNT(*) AS num_patients FROM patients WHERE age > 50 AND diagnosis ILIKE '%asthma%';

Q: Show me all patients who died during their hospital stay.
A: SELECT patient_id, age, gender, diagnosis FROM patients WHERE died = true;

Q: What is the average age of female patients?
A: SELECT AVG(age) AS avg_age FROM patients WHERE gender ILIKE '%female%';

Q: Who are the top 10 patients with the longest hospital stay?
A: SELECT a.subject_id, datediff('day', a.admittime, a.dischtime) AS stay_days FROM admissions a WHERE a.dischtime IS NOT NULL ORDER BY stay_days DESC LIMIT 10;

Q: Hello, how are you?
A: NOT_A_DATA_QUESTION

Q: What is the capital of France?
A: NOT_A_DATA_QUESTION

Now answer the user's question using ONLY the schema they provide."""

# ─────────────────────────────────────────────
# HELPERS
# ─────────────────────────────────────────────
# Leading keywords (upper case) that a response must begin with to be
# treated as SQL; consumed as the prefix tuple of a str.startswith() check.
VALID_SQL_STARTS = (
    "SELECT",
    "WITH",
    "INSERT",
    "UPDATE",
    "DELETE",
    "CREATE",
    "DROP",
    "ALTER",
)


def _strip_fences(text: str) -> str:
    """Return the contents of the first markdown code fence in *text*, if any."""
    before, fence, after = text.partition("```")
    if not fence:
        return text

    # Everything up to the closing fence (or to the end if it is unterminated,
    # e.g. when the generation was cut off by the token limit).
    body = after.partition("```")[0]

    # Drop a language-tag line such as "sql". A tag only counts when it sits
    # alone on the fence line, and a bare SQL keyword is never a tag, so
    # "```SELECT\n..." keeps its SELECT.
    first_line, newline, remainder = body.partition("\n")
    tag = first_line.strip()
    if (
        newline
        and re.fullmatch(r"[A-Za-z0-9+-]*", tag)
        and tag.upper() not in ("SELECT", "WITH")
    ):
        body = remainder

    # A lone trailing fence ("SELECT 1;\n```") has no fenced body: keep what
    # preceded it instead of returning an empty string.
    return body if body.strip() else before


def _statement_end(sql: str) -> int:
    """Return the index just past the first ';' that is outside quotes."""
    open_quote = ""
    for pos, char in enumerate(sql):
        if open_quote:
            # A doubled quote ('' or "") closes and immediately reopens,
            # which leaves the scanner in the right state.
            if char == open_quote:
                open_quote = ""
        elif char in "'\"":
            open_quote = char
        elif char == ";":
            return pos + 1

    if open_quote:
        # Unbalanced quoting: fall back to cutting at the first semicolon.
        first = sql.find(";")
        if first != -1:
            return first + 1
    return len(sql)


def clean_sql(raw: str) -> str:
    """Reduce a raw model response to a single bare SQL statement.

    Steps, in order:
      1. remove ``<think>...</think>`` blocks, then trim whitespace so the
         following anchored patterns see the real start of the text;
      2. unwrap the first markdown code fence, wherever it appears;
      3. drop a leading ``sql`` / ``SQL:`` label;
      4. cut everything after the first semicolon that is not inside a
         quoted string or quoted identifier.

    Args:
        raw: Text produced by the model. ``None`` is treated as empty.

    Returns:
        The cleaned text with surrounding whitespace removed. Text that does
        not contain a semicolon is returned without truncation.
    """
    text = re.sub(r"<think>.*?</think>", "", raw or "", flags=re.DOTALL).strip()
    text = _strip_fences(text).strip()
    # The user prompt ends with "SQL:", so the model may echo that label.
    text = re.sub(r"^sql(?:\s*:\s*|\s+)", "", text, flags=re.IGNORECASE)
    return text[: _statement_end(text)].strip()


def validate_sql(sql: str) -> str:
    """Run a light sanity check on cleaned model output.

    Args:
        sql: Text to check, normally the result of cleaning the raw response.

    Returns:
        * a friendly warning when the text is the ``NOT_A_DATA_QUESTION``
          sentinel (case-insensitive, tolerating surrounding whitespace,
          quotes, backticks and trailing ``;`` / ``.``);
        * a warning that embeds the text when it does not start with one of
          the keywords in ``VALID_SQL_STARTS``;
        * otherwise ``sql`` itself, unchanged.
    """
    upper = sql.upper().strip()

    # Models often decorate the sentinel ("NOT_A_DATA_QUESTION;" or with a
    # trailing period); an exact comparison would misreport those as garbage.
    if upper.strip(" \t\r\n;.`'\"") == "NOT_A_DATA_QUESTION":
        return (
            "⚠️ That question doesn't appear to be about the database. "
            "Try asking something that can be answered by querying the schema."
        )

    # str.startswith accepts a tuple, so this is a single prefix test.
    if not upper.startswith(VALID_SQL_STARTS):
        return (
            f"⚠️ The model returned an unexpected response instead of SQL:\n\n{sql}\n\n"
            "Try rephrasing your question to be more specific about the data."
        )

    return sql  # looks good


# ─────────────────────────────────────────────
# MAIN GENERATOR
# ─────────────────────────────────────────────
def generate_sql(question: str, schema_ddl: str):
    """Stream a SQL query for *question* against *schema_ddl*.

    This is a generator: each yielded string replaces the previously
    displayed output. While the model is responding it yields the raw text
    accumulated so far; the last value yielded is the cleaned and validated
    result (or a warning / error message).

    Args:
        question: Natural-language question. ``None`` or blank yields a
            warning and stops.
        schema_ddl: Schema DDL text the query must be written against.
            ``None`` or blank yields a warning and stops.

    Yields:
        Progressive output strings as described above.
    """
    # Treat a missing value like an empty one instead of crashing on .strip().
    question = (question or "").strip()
    schema_ddl = (schema_ddl or "").strip()

    if not question:
        yield "⚠️ Please enter a question."
        return
    if not schema_ddl:
        yield "⚠️ Please provide your schema DDL."
        return

    prompt = (
        f"Database schema:\n{schema_ddl}\n\n"
        f"Question: {question}\n\n"
        "SQL:"
    )

    accumulated = ""
    try:
        for event in client.chat_completion(
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            max_tokens=500,
            temperature=0.0,
            stream=True,
        ):
            # Some stream events carry no choices at all; indexing [0] on
            # them would raise and discard an otherwise good response.
            if not event.choices:
                continue
            chunk = event.choices[0].delta.content or ""
            if not chunk:
                continue
            accumulated += chunk
            yield accumulated  # stream raw while typing

    except Exception as e:
        # UI boundary: surface any client/network failure to the user
        # rather than letting the generator die with a traceback.
        yield f"❌ Error calling model: {e}"
        return

    # Final: clean then validate
    yield validate_sql(clean_sql(accumulated))


# ─────────────────────────────────────────────
# GRADIO UI
# ─────────────────────────────────────────────
# Input widgets, listed in the positional order generate_sql receives them.
_question_box = gr.Textbox(
    label="Question",
    placeholder="Show me how many patients aged above 50 have asthma",
)
_schema_box = gr.Textbox(
    label="Schema DDL",
    lines=10,
    placeholder="CREATE TABLE patients (...)",
)

# generate_sql is a generator, so each value it yields replaces the contents
# of the single output textbox.
demo = gr.Interface(
    fn=generate_sql,
    inputs=[_question_box, _schema_box],
    outputs=gr.Textbox(label="Generated SQL"),
    title="TinyEHR Text-to-SQL",
    description="Generate SQL queries for the TinyEHR dataset from natural language.",
    flagging_mode="never",
)

# Start the web server as soon as this module is executed.
demo.launch()