Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
# app.py
|
| 2 |
import os
|
| 3 |
import re
|
| 4 |
import gradio as gr
|
|
@@ -13,8 +12,8 @@ client = InferenceClient(
|
|
| 13 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 14 |
# SYSTEM PROMPT (strict, few-shot)
|
| 15 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 16 |
-
SYSTEM_PROMPT = """
|
| 17 |
-
|
| 18 |
YOUR ONLY JOB is to output a single, valid DuckDB SQL query.
|
| 19 |
|
| 20 |
ABSOLUTE OUTPUT RULES β violating any rule makes the output wrong:
|
|
@@ -26,18 +25,23 @@ ABSOLUTE OUTPUT RULES β violating any rule makes the output wrong:
|
|
| 26 |
|
| 27 |
SQL RULES:
|
| 28 |
- Use ONLY table and column names that appear in the schema β never invent names.
|
| 29 |
-
- Use DuckDB syntax exclusively.
|
| 30 |
- Text matching: always use ILIKE '%term%'. Never use LOWER() or UPPER() for comparison.
|
| 31 |
- For SELECT queries, default to LIMIT 100 unless the user asks for all rows or a specific count.
|
| 32 |
- Prefer the fewest JOINs and subqueries needed to answer the question.
|
| 33 |
- Never use SELECT * β always name the columns you need.
|
| 34 |
- Age filters: use a numeric comparison on the age column directly (e.g. age > 50).
|
| 35 |
- Counts: use COUNT(*) or COUNT(column). Alias it clearly, e.g. AS num_patients.
|
| 36 |
-
...
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
FEW-SHOT EXAMPLES:
|
| 38 |
|
| 39 |
Schema:
|
| 40 |
CREATE TABLE patients (patient_id INT, age INT, gender VARCHAR, diagnosis VARCHAR, died BOOLEAN);
|
|
|
|
| 41 |
|
| 42 |
Q: How many patients above 50 have asthma?
|
| 43 |
A: SELECT COUNT(*) AS num_patients FROM patients WHERE age > 50 AND diagnosis ILIKE '%asthma%';
|
|
@@ -48,6 +52,9 @@ A: SELECT patient_id, age, gender, diagnosis FROM patients WHERE died = true LIM
|
|
| 48 |
Q: What is the average age of female patients?
|
| 49 |
A: SELECT AVG(age) AS avg_age FROM patients WHERE gender ILIKE '%female%';
|
| 50 |
|
|
|
|
|
|
|
|
|
|
| 51 |
Q: Hello, how are you?
|
| 52 |
A: NOT_A_DATA_QUESTION
|
| 53 |
|
|
@@ -63,11 +70,19 @@ VALID_SQL_STARTS = ("SELECT", "WITH", "INSERT", "UPDATE", "DELETE", "CREATE", "D
|
|
| 63 |
|
| 64 |
|
| 65 |
def clean_sql(raw: str) -> str:
|
| 66 |
-
"""Remove markdown fences, leading 'sql' keyword, and extra whitespace."""
|
| 67 |
sql = raw.strip()
|
|
|
|
|
|
|
|
|
|
| 68 |
sql = re.sub(r"^```[a-zA-Z]*\n?", "", sql)
|
| 69 |
sql = re.sub(r"```$", "", sql)
|
|
|
|
| 70 |
sql = re.sub(r"(?i)^sql\s+", "", sql)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
return sql.strip()
|
| 72 |
|
| 73 |
|
|
@@ -79,7 +94,10 @@ def validate_sql(sql: str) -> str:
|
|
| 79 |
upper = sql.upper().strip()
|
| 80 |
|
| 81 |
if upper == "NOT_A_DATA_QUESTION":
|
| 82 |
-
return
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
if not upper.startswith(VALID_SQL_STARTS):
|
| 85 |
return (
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import re
|
| 3 |
import gradio as gr
|
|
|
|
| 12 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 13 |
# SYSTEM PROMPT (strict, few-shot)
|
| 14 |
# βββββββββββββββββββββββββββββββββββββββββββββ
|
| 15 |
+
SYSTEM_PROMPT = """\
|
| 16 |
+
You are a strict SQL code generator for DuckDB.
|
| 17 |
YOUR ONLY JOB is to output a single, valid DuckDB SQL query.
|
| 18 |
|
| 19 |
ABSOLUTE OUTPUT RULES β violating any rule makes the output wrong:
|
|
|
|
| 25 |
|
| 26 |
SQL RULES:
|
| 27 |
- Use ONLY table and column names that appear in the schema β never invent names.
|
| 28 |
+
- Use DuckDB syntax exclusively. Never use SQLite or MySQL syntax.
|
| 29 |
- Text matching: always use ILIKE '%term%'. Never use LOWER() or UPPER() for comparison.
|
| 30 |
- For SELECT queries, default to LIMIT 100 unless the user asks for all rows or a specific count.
|
| 31 |
- Prefer the fewest JOINs and subqueries needed to answer the question.
|
| 32 |
- Never use SELECT * β always name the columns you need.
|
| 33 |
- Age filters: use a numeric comparison on the age column directly (e.g. age > 50).
|
| 34 |
- Counts: use COUNT(*) or COUNT(column). Alias it clearly, e.g. AS num_patients.
|
| 35 |
+
- Date arithmetic: NEVER use julianday(). Use datediff('day', start_col, end_col) for days between two timestamps. Use epoch(end_col - start_col) / 86400 for interval-to-days.
|
| 36 |
+
- Identifier quoting: wrap table and column names in double quotes if they start with a digit or contain special characters (e.g. "2b_concept", "my-column").
|
| 37 |
+
- String concatenation: use || operator, never CONCAT().
|
| 38 |
+
- Current date/time: use current_date or current_timestamp, never NOW().
|
| 39 |
+
|
| 40 |
FEW-SHOT EXAMPLES:
|
| 41 |
|
| 42 |
Schema:
|
| 43 |
CREATE TABLE patients (patient_id INT, age INT, gender VARCHAR, diagnosis VARCHAR, died BOOLEAN);
|
| 44 |
+
CREATE TABLE admissions (subject_id INT, admittime TIMESTAMP, dischtime TIMESTAMP, admission_type VARCHAR);
|
| 45 |
|
| 46 |
Q: How many patients above 50 have asthma?
|
| 47 |
A: SELECT COUNT(*) AS num_patients FROM patients WHERE age > 50 AND diagnosis ILIKE '%asthma%';
|
|
|
|
| 52 |
Q: What is the average age of female patients?
|
| 53 |
A: SELECT AVG(age) AS avg_age FROM patients WHERE gender ILIKE '%female%';
|
| 54 |
|
| 55 |
+
Q: Who are the top 10 patients with the longest hospital stay?
|
| 56 |
+
A: SELECT a.subject_id, datediff('day', a.admittime, a.dischtime) AS stay_days FROM admissions a WHERE a.dischtime IS NOT NULL ORDER BY stay_days DESC LIMIT 10;
|
| 57 |
+
|
| 58 |
Q: Hello, how are you?
|
| 59 |
A: NOT_A_DATA_QUESTION
|
| 60 |
|
|
|
|
| 70 |
|
| 71 |
|
| 72 |
def clean_sql(raw: str) -> str:
|
| 73 |
+
"""Remove markdown fences, leading 'sql' keyword, thinking tags, and extra whitespace."""
|
| 74 |
sql = raw.strip()
|
| 75 |
+
# Strip <think>...</think> blocks (some models emit these)
|
| 76 |
+
sql = re.sub(r"<think>.*?</think>", "", sql, flags=re.DOTALL)
|
| 77 |
+
# Strip markdown code fences
|
| 78 |
sql = re.sub(r"^```[a-zA-Z]*\n?", "", sql)
|
| 79 |
sql = re.sub(r"```$", "", sql)
|
| 80 |
+
# Strip leading "sql" keyword
|
| 81 |
sql = re.sub(r"(?i)^sql\s+", "", sql)
|
| 82 |
+
# Strip any trailing text after the semicolon
|
| 83 |
+
semi_match = re.search(r";", sql)
|
| 84 |
+
if semi_match:
|
| 85 |
+
sql = sql[: semi_match.end()]
|
| 86 |
return sql.strip()
|
| 87 |
|
| 88 |
|
|
|
|
| 94 |
upper = sql.upper().strip()
|
| 95 |
|
| 96 |
if upper == "NOT_A_DATA_QUESTION":
|
| 97 |
+
return (
|
| 98 |
+
"β οΈ That question doesn't appear to be about the database. "
|
| 99 |
+
"Try asking something that can be answered by querying the schema."
|
| 100 |
+
)
|
| 101 |
|
| 102 |
if not upper.startswith(VALID_SQL_STARTS):
|
| 103 |
return (
|