vidulpanickan commited on
Commit
a58baac
Β·
verified Β·
1 Parent(s): 2a1fcd0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -46
app.py CHANGED
@@ -1,4 +1,6 @@
 
1
  import os
 
2
  import gradio as gr
3
  from huggingface_hub import InferenceClient
4
 
@@ -8,59 +10,148 @@ client = InferenceClient(
8
  token=HF_TOKEN,
9
  )
10
 
11
- SYSTEM_PROMPT = """You are a SQL expert. Given a database schema and a question in English, generate a DuckDB-compatible SQL query that answers the question.
12
-
13
- Rules:
14
- - Return ONLY the SQL query, no explanation, no markdown, no code fences
15
- - Use exact table and column names from the schema
16
- - Use DuckDB SQL syntax
17
- - Add LIMIT 100 unless the user asks for a specific count or all rows
18
- - Keep queries simple: use the fewest tables and JOINs possible
19
- - For text search, always use ILIKE '%term%' (case-insensitive) instead of exact match
20
- - Never use LOWER() or UPPER() for comparison, use ILIKE instead
21
- - If the question is not related to querying the database (e.g. personal questions, general knowledge, chitchat), respond with exactly: NOT_A_DATA_QUESTION
22
- - Only generate SQL for questions that can be answered by querying the provided database schema"""
23
-
24
-
25
- def generate_sql(question: str, schema_ddl: str) -> str:
26
- prompt = f"Database schema:\n{schema_ddl}\n\nQuestion: {question}\n\nSQL:"
27
-
28
- sql = ""
29
- for token in client.chat_completion(
30
- messages=[
31
- {"role": "system", "content": SYSTEM_PROMPT},
32
- {"role": "user", "content": prompt},
33
- ],
34
- max_tokens=500,
35
- temperature=0.1,
36
- stream=True,
37
- ):
38
- chunk = token.choices[0].delta.content or ""
39
- sql += chunk
40
- yield sql.strip()
41
-
42
- # Clean up final result
43
- sql = sql.strip()
44
- if sql.startswith("```"):
45
- sql = sql.split("\n", 1)[1] if "\n" in sql else sql[3:]
46
- if sql.endswith("```"):
47
- sql = sql[:-3]
48
- sql = sql.strip()
49
- if sql.lower().startswith("sql"):
50
- sql = sql[3:].strip()
51
-
52
- yield sql
53
 
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  demo = gr.Interface(
56
  fn=generate_sql,
57
  inputs=[
58
- gr.Textbox(label="Question", placeholder="Show me all patients who died during their hospital stay"),
59
- gr.Textbox(label="Schema DDL", lines=10, placeholder="CREATE TABLE patients (...)"),
 
 
 
 
 
 
 
60
  ],
61
- outputs=gr.Textbox(label="Generated SQL"),
62
  title="TinyEHR Text-to-SQL",
63
  description="Generate SQL queries for the TinyEHR dataset from natural language.",
 
64
  )
65
 
66
- demo.launch()
 
1
+ # app.py
2
  import os
3
+ import re
4
  import gradio as gr
5
  from huggingface_hub import InferenceClient
6
 
 
10
  token=HF_TOKEN,
11
  )
12
 
13
+ # ─────────────────────────────────────────────
14
+ # SYSTEM PROMPT (strict, few-shot)
15
+ # ─────────────────────────────────────────────
16
+ SYSTEM_PROMPT = """You are a strict SQL code generator for DuckDB.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ YOUR ONLY JOB is to output a single, valid DuckDB SQL query.
19
 
20
+ ABSOLUTE OUTPUT RULES β€” violating any rule makes the output wrong:
21
+ 1. Output ONLY raw SQL. No markdown, no code fences, no backticks, no explanations.
22
+ 2. Never prefix with "sql", "SQL:", "Here is", or any natural language.
23
+ 3. Never output anything after the semicolon.
24
+ 4. If the question cannot be answered from the schema, output exactly: NOT_A_DATA_QUESTION
25
+ 5. NOT_A_DATA_QUESTION also applies to: greetings, general knowledge, math unrelated to the schema, anything not about querying the provided tables.
26
+
27
+ SQL RULES:
28
+ - Use ONLY table and column names that appear in the schema β€” never invent names.
29
+ - Use DuckDB syntax exclusively.
30
+ - Text matching: always use ILIKE '%term%'. Never use LOWER() or UPPER() for comparison.
31
+ - For SELECT queries, default to LIMIT 100 unless the user asks for all rows or a specific count.
32
+ - Prefer the fewest JOINs and subqueries needed to answer the question.
33
+ - Never use SELECT * β€” always name the columns you need.
34
+ - Age filters: use a numeric comparison on the age column directly (e.g. age > 50).
35
+ - Counts: use COUNT(*) or COUNT(column). Alias it clearly, e.g. AS num_patients.
36
+ - INSERT, UPDATE, DELETE, CREATE, DROP, ALTER are all allowed β€” the user owns their database.
37
+
38
+ FEW-SHOT EXAMPLES:
39
+
40
+ Schema:
41
+ CREATE TABLE patients (patient_id INT, age INT, gender VARCHAR, diagnosis VARCHAR, died BOOLEAN);
42
+
43
+ Q: How many patients above 50 have asthma?
44
+ A: SELECT COUNT(*) AS num_patients FROM patients WHERE age > 50 AND diagnosis ILIKE '%asthma%';
45
+
46
+ Q: Show me all patients who died during their hospital stay.
47
+ A: SELECT patient_id, age, gender, diagnosis FROM patients WHERE died = true LIMIT 100;
48
+
49
+ Q: What is the average age of female patients?
50
+ A: SELECT AVG(age) AS avg_age FROM patients WHERE gender ILIKE '%female%';
51
+
52
+ Q: Hello, how are you?
53
+ A: NOT_A_DATA_QUESTION
54
+
55
+ Q: What is the capital of France?
56
+ A: NOT_A_DATA_QUESTION
57
+
58
+ Now answer the user's question using ONLY the schema they provide."""
59
+
60
+ # ─────────────────────────────────────────────
61
+ # HELPERS
62
+ # ─────────────────────────────────────────────
63
+ VALID_SQL_STARTS = ("SELECT", "WITH", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER")
64
+
65
+
66
+ def clean_sql(raw: str) -> str:
67
+ """Remove markdown fences, leading 'sql' keyword, and extra whitespace."""
68
+ sql = raw.strip()
69
+ sql = re.sub(r"^```[a-zA-Z]*\n?", "", sql)
70
+ sql = re.sub(r"```$", "", sql)
71
+ sql = re.sub(r"(?i)^sql\s+", "", sql)
72
+ return sql.strip()
73
+
74
+
75
+ def validate_sql(sql: str) -> str:
76
+ """
77
+ Light sanity check on the generated SQL.
78
+ Returns the SQL unchanged if it looks valid, or an error string.
79
+ """
80
+ upper = sql.upper().strip()
81
+
82
+ if upper == "NOT_A_DATA_QUESTION":
83
+ return "⚠️ That question doesn't appear to be about the database. Try asking something that can be answered by querying the schema."
84
+
85
+ if not upper.startswith(VALID_SQL_STARTS):
86
+ return (
87
+ f"⚠️ The model returned an unexpected response instead of SQL:\n\n{sql}\n\n"
88
+ "Try rephrasing your question to be more specific about the data."
89
+ )
90
+
91
+ return sql # looks good
92
+
93
+
94
+ # ─────────────────────────────────────────────
95
+ # MAIN GENERATOR
96
+ # ─────────────────────────────────────────────
97
+ def generate_sql(question: str, schema_ddl: str):
98
+ if not question.strip():
99
+ yield "⚠️ Please enter a question."
100
+ return
101
+ if not schema_ddl.strip():
102
+ yield "⚠️ Please provide your schema DDL."
103
+ return
104
+
105
+ prompt = (
106
+ f"Database schema:\n{schema_ddl.strip()}\n\n"
107
+ f"Question: {question.strip()}\n\n"
108
+ "SQL:"
109
+ )
110
+
111
+ accumulated = ""
112
+ try:
113
+ for token in client.chat_completion(
114
+ messages=[
115
+ {"role": "system", "content": SYSTEM_PROMPT},
116
+ {"role": "user", "content": prompt},
117
+ ],
118
+ max_tokens=500,
119
+ temperature=0.0,
120
+ stream=True,
121
+ ):
122
+ chunk = token.choices[0].delta.content or ""
123
+ accumulated += chunk
124
+ yield accumulated # stream raw while typing
125
+
126
+ except Exception as e:
127
+ yield f"❌ Error calling model: {e}"
128
+ return
129
+
130
+ # Final: clean then validate
131
+ final = validate_sql(clean_sql(accumulated))
132
+ yield final
133
+
134
+
135
+ # ─────────────────────────────────────────────
136
+ # GRADIO UI
137
+ # ─────────────────────────────────────────────
138
  demo = gr.Interface(
139
  fn=generate_sql,
140
  inputs=[
141
+ gr.Textbox(
142
+ label="Question",
143
+ placeholder="Show me how many patients above 50 have asthma",
144
+ ),
145
+ gr.Textbox(
146
+ label="Schema DDL",
147
+ lines=10,
148
+ placeholder="CREATE TABLE patients (...)",
149
+ ),
150
  ],
151
+ outputs=gr.Textbox(label="Generated SQL", show_copy_button=True),
152
  title="TinyEHR Text-to-SQL",
153
  description="Generate SQL queries for the TinyEHR dataset from natural language.",
154
+ flagging_mode="never",
155
  )
156
 
157
+ demo.launch()