Spaces:
Running
Running
File size: 7,125 Bytes
cb2218d a58baac cb2218d 4864a9c cb2218d a58baac 9a1702d a58baac cb2218d a58baac 9a1702d a58baac 9a1702d a58baac 9a1702d a58baac 4d71586 a58baac 9a1702d a58baac 9a1702d a58baac 9a1702d a58baac 9a1702d a58baac 9a1702d a58baac 9a1702d a58baac cb2218d a58baac 7c1f47f a58baac cb2218d 910ce57 cb2218d a58baac cb2218d a58baac | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 | import os
import re
import gradio as gr
from huggingface_hub import InferenceClient
# Hugging Face access token taken from the environment; os.getenv returns
# None when the variable is not set, and that value is passed through as-is.
HF_TOKEN = os.getenv("HF_TOKEN")

# Shared inference client used by generate_sql for chat completions.
# NOTE(review): the ":ovhcloud" suffix on the model id presumably selects the
# inference provider that serves the model — confirm against HF router docs.
client = InferenceClient(
    token=HF_TOKEN,
    model="Qwen/Qwen3-Coder-30B-A3B-Instruct:ovhcloud",
)
# ─────────────────────────────────────────────
# SYSTEM PROMPT (strict, few-shot)
# ─────────────────────────────────────────────
# System message sent with every request by generate_sql. It pins the output
# format (raw SQL or the sentinel NOT_A_DATA_QUESTION, which validate_sql
# checks for), DuckDB dialect rules, and few-shot examples.
# NOTE(review): the blanket "ILIKE '%term%'" rule means a question about
# "male" patients would also match "female" rows — consider exact matching
# for categorical columns.
SYSTEM_PROMPT = """\
You are a strict SQL code generator for DuckDB.
YOUR ONLY JOB is to output a single, valid DuckDB SQL query.
ABSOLUTE OUTPUT RULES — violating any rule makes the output wrong:
1. Output ONLY raw SQL. No markdown, no code fences, no backticks, no explanations.
2. Never prefix with "sql", "SQL:", "Here is", or any natural language.
3. Never output anything after the semicolon.
4. If the question cannot be answered from the schema, output exactly: NOT_A_DATA_QUESTION
5. NOT_A_DATA_QUESTION also applies to: greetings, general knowledge, math unrelated to the schema, anything not about querying the provided tables.
SQL RULES:
- Use ONLY table and column names that appear in the schema — never invent names.
- Use DuckDB syntax exclusively. Never use SQLite or MySQL syntax.
- Text matching: always use ILIKE '%term%'. Never use LOWER() or UPPER() for comparison.
- Prefer the fewest JOINs and subqueries needed to answer the question.
- Never use SELECT * — always name the columns you need.
- Age filters: use a numeric comparison on the age column directly (e.g. age > 50).
- Counts: use COUNT(*) or COUNT(column). Alias it clearly, e.g. AS num_patients.
- Date arithmetic: NEVER use julianday(). Use datediff('day', start_col, end_col) for days between two timestamps. Use epoch(end_col - start_col) / 86400 for interval-to-days.
- Identifier quoting: wrap table and column names in double quotes if they start with a digit or contain special characters (e.g. "2b_concept", "my-column").
- String concatenation: use || operator, never CONCAT().
- Current date/time: use current_date or current_timestamp, never NOW().
FEW-SHOT EXAMPLES:
Schema:
CREATE TABLE patients (patient_id INT, age INT, gender VARCHAR, diagnosis VARCHAR, died BOOLEAN);
CREATE TABLE admissions (subject_id INT, admittime TIMESTAMP, dischtime TIMESTAMP, admission_type VARCHAR);
Q: How many patients above 50 have asthma?
A: SELECT COUNT(*) AS num_patients FROM patients WHERE age > 50 AND diagnosis ILIKE '%asthma%';
Q: Show me all patients who died during their hospital stay.
A: SELECT patient_id, age, gender, diagnosis FROM patients WHERE died = true;
Q: What is the average age of female patients?
A: SELECT AVG(age) AS avg_age FROM patients WHERE gender ILIKE '%female%';
Q: Who are the top 10 patients with the longest hospital stay?
A: SELECT a.subject_id, datediff('day', a.admittime, a.dischtime) AS stay_days FROM admissions a WHERE a.dischtime IS NOT NULL ORDER BY stay_days DESC LIMIT 10;
Q: Hello, how are you?
A: NOT_A_DATA_QUESTION
Q: What is the capital of France?
A: NOT_A_DATA_QUESTION
Now answer the user's question using ONLY the schema they provide."""
# ─────────────────────────────────────────────
# HELPERS
# ─────────────────────────────────────────────
# Statement keywords the cleaned model output must begin with to be accepted
# as SQL by validate_sql (matched case-insensitively, as whole words).
VALID_SQL_STARTS = ("SELECT", "WITH", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER")

# <think>...</think> reasoning blocks. An unterminated block (e.g. output cut
# off by max_tokens) is removed through to the end of the text.
_THINK_RE = re.compile(r"<think>.*?(?:</think>|\Z)", re.DOTALL)
# First markdown code fence, with an optional language tag on its own line
# ("```sql\n"); captures the body up to the closing fence or end of text.
_FENCE_RE = re.compile(r"```(?:[A-Za-z]*[ \t]*\n)?(.*?)(?:```|\Z)", re.DOTALL)
# A leading "sql" / "SQL:" label in front of the query.
_SQL_LABEL_RE = re.compile(r"^sql(?:\s*:\s*|\s+)", re.IGNORECASE)
# One of VALID_SQL_STARTS as a whole word, optionally after "(" — so prose
# like "Alternatively..." or "Without..." is not mistaken for ALTER / WITH.
_SQL_START_RE = re.compile(
    r"[(\s]*(?:" + "|".join(map(re.escape, VALID_SQL_STARTS)) + r")\b",
    re.IGNORECASE,
)


def _truncate_after_first_statement(sql: str) -> str:
    """Return *sql* up to and including its first ';' outside a quoted literal."""
    # Track single/double quotes so a ';' inside e.g. ILIKE '%;%' is kept.
    # A doubled quote ('') simply toggles twice, so escapes work naturally.
    # NOTE: semicolons inside "--" comments are not special-cased.
    quote = None
    for index, char in enumerate(sql):
        if quote is not None:
            if char == quote:
                quote = None
        elif char in ("'", '"'):
            quote = char
        elif char == ";":
            return sql[: index + 1]
    return sql


def clean_sql(raw: str) -> str:
    """Extract a single SQL statement from raw model output.

    Steps, in order:
      1. remove ``<think>`` blocks (closed or unterminated);
      2. if a markdown code fence is present anywhere, keep only its body
         (so prose before/after the fence is dropped); otherwise remove any
         stray triple backticks;
      3. remove a leading ``sql`` / ``SQL:`` label;
      4. cut everything after the first ``;`` not inside a quoted literal.

    Args:
        raw: Text returned by the model.

    Returns:
        The cleaned text with surrounding whitespace stripped. It is not
        guaranteed to be SQL; pass it to ``validate_sql``.
    """
    # Strip *after* removing think blocks, otherwise the newline left behind
    # would hide a code fence that follows the block.
    text = _THINK_RE.sub("", raw).strip()

    fenced = _FENCE_RE.search(text)
    if fenced and fenced.group(1).strip():
        text = fenced.group(1)
    else:
        # No usable fence body (e.g. only a stray closing fence).
        text = text.replace("```", "")

    text = _SQL_LABEL_RE.sub("", text.strip())
    return _truncate_after_first_statement(text).strip()


def validate_sql(sql: str) -> str:
    """Light sanity check on cleaned model output.

    Args:
        sql: Output of ``clean_sql``.

    Returns:
        ``sql`` unchanged when it starts with one of ``VALID_SQL_STARTS`` as
        a whole word (optionally after "("); otherwise a user-facing warning
        string. The sentinel ``NOT_A_DATA_QUESTION`` (case-insensitive,
        tolerating a trailing ';' or '.') yields a dedicated warning.
    """
    stripped = sql.strip()
    if stripped.upper().rstrip("; .") == "NOT_A_DATA_QUESTION":
        return (
            "⚠️ That question doesn't appear to be about the database. "
            "Try asking something that can be answered by querying the schema."
        )
    if not _SQL_START_RE.match(stripped):
        return (
            f"⚠️ The model returned an unexpected response instead of SQL:\n\n{sql}\n\n"
            "Try rephrasing your question to be more specific about the data."
        )
    return sql  # looks good
# ─────────────────────────────────────────────
# MAIN GENERATOR
# ─────────────────────────────────────────────
def generate_sql(question: str, schema_ddl: str):
    """Stream a DuckDB SQL query answering *question* against *schema_ddl*.

    This is a generator. While the model responds, the raw accumulated text
    is yielded after every streamed chunk. When the stream ends, one final
    value is yielded: the text passed through ``clean_sql`` then
    ``validate_sql``, i.e. either the SQL or a warning message.

    Args:
        question: Natural-language question. ``None`` is treated as empty.
        schema_ddl: Schema DDL (e.g. ``CREATE TABLE`` statements). ``None``
            is treated as empty.

    Yields:
        Progressive raw model output, then the final cleaned/validated
        result. Input problems and model-call failures are reported as a
        single yielded message instead of raising.
    """
    question = (question or "").strip()
    schema_ddl = (schema_ddl or "").strip()
    if not question:
        yield "⚠️ Please enter a question."
        return
    if not schema_ddl:
        yield "⚠️ Please provide your schema DDL."
        return

    prompt = (
        f"Database schema:\n{schema_ddl}\n\n"
        f"Question: {question}\n\n"
        "SQL:"
    )

    accumulated = ""
    try:
        stream = client.chat_completion(
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            max_tokens=500,
            temperature=0.0,
            stream=True,
        )
        for chunk in stream:
            # OpenAI-compatible streams can contain chunks with an empty
            # `choices` list (e.g. a trailing usage-only chunk); indexing [0]
            # on one would raise and discard an otherwise good response.
            if not chunk.choices:
                continue
            accumulated += chunk.choices[0].delta.content or ""
            yield accumulated  # stream raw while typing
    except Exception as exc:  # UI boundary: surface any client error to the user
        yield f"❌ Error calling model: {exc}"
        return

    # Final: clean then validate
    yield validate_sql(clean_sql(accumulated))
# ─────────────────────────────────────────────
# GRADIO UI
# ─────────────────────────────────────────────
# Two free-text inputs (question, schema DDL) map positionally onto
# generate_sql's parameters. Because generate_sql is a generator, each value
# it yields updates the output textbox.
demo = gr.Interface(
    fn=generate_sql,
    inputs=[
        gr.Textbox(
            label="Question",
            placeholder="Show me how many patients aged above 50 have asthma",
        ),
        gr.Textbox(
            label="Schema DDL",
            lines=10,
            placeholder="CREATE TABLE patients (...)",
        ),
    ],
    outputs=gr.Textbox(label="Generated SQL"),
    title="TinyEHR Text-to-SQL",
    description="Generate SQL queries for the TinyEHR dataset from natural language.",
    flagging_mode="never",
)

# Start the server only when this file is run as a script, so importing the
# module (e.g. from tests) does not block on launch().
if __name__ == "__main__":
    demo.launch()