vidulpanickan commited on
Commit
9a1702d
Β·
verified Β·
1 Parent(s): 7c1f47f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -7
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # app.py
2
  import os
3
  import re
4
  import gradio as gr
@@ -13,8 +12,8 @@ client = InferenceClient(
13
  # ─────────────────────────────────────────────
14
  # SYSTEM PROMPT (strict, few-shot)
15
  # ─────────────────────────────────────────────
16
- SYSTEM_PROMPT = """You are a strict SQL code generator for DuckDB.
17
-
18
  YOUR ONLY JOB is to output a single, valid DuckDB SQL query.
19
 
20
  ABSOLUTE OUTPUT RULES β€” violating any rule makes the output wrong:
@@ -26,18 +25,23 @@ ABSOLUTE OUTPUT RULES β€” violating any rule makes the output wrong:
26
 
27
  SQL RULES:
28
  - Use ONLY table and column names that appear in the schema β€” never invent names.
29
- - Use DuckDB syntax exclusively.
30
  - Text matching: always use ILIKE '%term%'. Never use LOWER() or UPPER() for comparison.
31
  - For SELECT queries, default to LIMIT 100 unless the user asks for all rows or a specific count.
32
  - Prefer the fewest JOINs and subqueries needed to answer the question.
33
  - Never use SELECT * β€” always name the columns you need.
34
  - Age filters: use a numeric comparison on the age column directly (e.g. age > 50).
35
  - Counts: use COUNT(*) or COUNT(column). Alias it clearly, e.g. AS num_patients.
36
- ...
 
 
 
 
37
  FEW-SHOT EXAMPLES:
38
 
39
  Schema:
40
  CREATE TABLE patients (patient_id INT, age INT, gender VARCHAR, diagnosis VARCHAR, died BOOLEAN);
 
41
 
42
  Q: How many patients above 50 have asthma?
43
  A: SELECT COUNT(*) AS num_patients FROM patients WHERE age > 50 AND diagnosis ILIKE '%asthma%';
@@ -48,6 +52,9 @@ A: SELECT patient_id, age, gender, diagnosis FROM patients WHERE died = true LIM
48
  Q: What is the average age of female patients?
49
  A: SELECT AVG(age) AS avg_age FROM patients WHERE gender ILIKE '%female%';
50
 
 
 
 
51
  Q: Hello, how are you?
52
  A: NOT_A_DATA_QUESTION
53
 
@@ -63,11 +70,19 @@ VALID_SQL_STARTS = ("SELECT", "WITH", "INSERT", "UPDATE", "DELETE", "CREATE", "D
63
 
64
 
65
  def clean_sql(raw: str) -> str:
66
- """Remove markdown fences, leading 'sql' keyword, and extra whitespace."""
67
  sql = raw.strip()
 
 
 
68
  sql = re.sub(r"^```[a-zA-Z]*\n?", "", sql)
69
  sql = re.sub(r"```$", "", sql)
 
70
  sql = re.sub(r"(?i)^sql\s+", "", sql)
 
 
 
 
71
  return sql.strip()
72
 
73
 
@@ -79,7 +94,10 @@ def validate_sql(sql: str) -> str:
79
  upper = sql.upper().strip()
80
 
81
  if upper == "NOT_A_DATA_QUESTION":
82
- return "⚠️ That question doesn't appear to be about the database. Try asking something that can be answered by querying the schema."
 
 
 
83
 
84
  if not upper.startswith(VALID_SQL_STARTS):
85
  return (
 
 
1
  import os
2
  import re
3
  import gradio as gr
 
12
  # ─────────────────────────────────────────────
13
  # SYSTEM PROMPT (strict, few-shot)
14
  # ─────────────────────────────────────────────
15
+ SYSTEM_PROMPT = """\
16
+ You are a strict SQL code generator for DuckDB.
17
  YOUR ONLY JOB is to output a single, valid DuckDB SQL query.
18
 
19
  ABSOLUTE OUTPUT RULES β€” violating any rule makes the output wrong:
 
25
 
26
  SQL RULES:
27
  - Use ONLY table and column names that appear in the schema β€” never invent names.
28
+ - Use DuckDB syntax exclusively. Never use SQLite or MySQL syntax.
29
  - Text matching: always use ILIKE '%term%'. Never use LOWER() or UPPER() for comparison.
30
  - For SELECT queries, default to LIMIT 100 unless the user asks for all rows or a specific count.
31
  - Prefer the fewest JOINs and subqueries needed to answer the question.
32
  - Never use SELECT * β€” always name the columns you need.
33
  - Age filters: use a numeric comparison on the age column directly (e.g. age > 50).
34
  - Counts: use COUNT(*) or COUNT(column). Alias it clearly, e.g. AS num_patients.
35
+ - Date arithmetic: NEVER use julianday(). Use datediff('day', start_col, end_col) for days between two timestamps. Use epoch(end_col - start_col) / 86400 for interval-to-days.
36
+ - Identifier quoting: wrap table and column names in double quotes if they start with a digit or contain special characters (e.g. "2b_concept", "my-column").
37
+ - String concatenation: use || operator, never CONCAT().
38
+ - Current date/time: use current_date or current_timestamp, never NOW().
39
+
40
  FEW-SHOT EXAMPLES:
41
 
42
  Schema:
43
  CREATE TABLE patients (patient_id INT, age INT, gender VARCHAR, diagnosis VARCHAR, died BOOLEAN);
44
+ CREATE TABLE admissions (subject_id INT, admittime TIMESTAMP, dischtime TIMESTAMP, admission_type VARCHAR);
45
 
46
  Q: How many patients above 50 have asthma?
47
  A: SELECT COUNT(*) AS num_patients FROM patients WHERE age > 50 AND diagnosis ILIKE '%asthma%';
 
52
  Q: What is the average age of female patients?
53
  A: SELECT AVG(age) AS avg_age FROM patients WHERE gender ILIKE '%female%';
54
 
55
+ Q: Who are the top 10 patients with the longest hospital stay?
56
+ A: SELECT a.subject_id, datediff('day', a.admittime, a.dischtime) AS stay_days FROM admissions a WHERE a.dischtime IS NOT NULL ORDER BY stay_days DESC LIMIT 10;
57
+
58
  Q: Hello, how are you?
59
  A: NOT_A_DATA_QUESTION
60
 
 
70
 
71
 
72
  def clean_sql(raw: str) -> str:
73
+ """Remove markdown fences, leading 'sql' keyword, thinking tags, and extra whitespace."""
74
  sql = raw.strip()
75
+ # Strip <think>...</think> blocks (some models emit these)
76
+ sql = re.sub(r"<think>.*?</think>", "", sql, flags=re.DOTALL)
77
+ # Strip markdown code fences
78
  sql = re.sub(r"^```[a-zA-Z]*\n?", "", sql)
79
  sql = re.sub(r"```$", "", sql)
80
+ # Strip leading "sql" keyword
81
  sql = re.sub(r"(?i)^sql\s+", "", sql)
82
+ # Strip any trailing text after the semicolon
83
+ semi_match = re.search(r";", sql)
84
+ if semi_match:
85
+ sql = sql[: semi_match.end()]
86
  return sql.strip()
87
 
88
 
 
94
  upper = sql.upper().strip()
95
 
96
  if upper == "NOT_A_DATA_QUESTION":
97
+ return (
98
+ "⚠️ That question doesn't appear to be about the database. "
99
+ "Try asking something that can be answered by querying the schema."
100
+ )
101
 
102
  if not upper.startswith(VALID_SQL_STARTS):
103
  return (