vidulpanickan commited on
Commit
cb2218d
·
verified ·
1 Parent(s): 8077940

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from huggingface_hub import InferenceClient
4
+
5
+ HF_TOKEN = os.environ.get("HF_TOKEN")
6
+ client = InferenceClient(
7
+ model="Qwen/Qwen2.5-Coder-32B-Instruct",
8
+ token=HF_TOKEN,
9
+ )
10
+
11
+ SYSTEM_PROMPT = """You are a SQL expert. Given a database schema and a question in English, generate a DuckDB-compatible SQL query that answers the question.
12
+
13
+ Rules:
14
+ - Return ONLY the SQL query, no explanation, no markdown, no code fences
15
+ - Use exact table and column names from the schema
16
+ - Use DuckDB SQL syntax
17
+ - Add LIMIT 100 unless the user asks for a specific count or all rows
18
+ - Keep queries simple: use the fewest tables and JOINs possible
19
+ - For text search, always use ILIKE '%term%' (case-insensitive) instead of exact match
20
+ - Never use LOWER() or UPPER() for comparison, use ILIKE instead
21
+ - If the question is not related to querying the database (e.g. personal questions, general knowledge, chitchat), respond with exactly: NOT_A_DATA_QUESTION
22
+ - Only generate SQL for questions that can be answered by querying the provided database schema"""
23
+
24
+
25
+ def generate_sql(question: str, schema_ddl: str) -> str:
26
+ prompt = f"Database schema:\n{schema_ddl}\n\nQuestion: {question}\n\nSQL:"
27
+
28
+ sql = ""
29
+ for token in client.chat_completion(
30
+ messages=[
31
+ {"role": "system", "content": SYSTEM_PROMPT},
32
+ {"role": "user", "content": prompt},
33
+ ],
34
+ max_tokens=500,
35
+ temperature=0.1,
36
+ stream=True,
37
+ ):
38
+ chunk = token.choices[0].delta.content or ""
39
+ sql += chunk
40
+ yield sql.strip()
41
+
42
+ # Clean up final result
43
+ sql = sql.strip()
44
+ if sql.startswith("```"):
45
+ sql = sql.split("\n", 1)[1] if "\n" in sql else sql[3:]
46
+ if sql.endswith("```"):
47
+ sql = sql[:-3]
48
+ sql = sql.strip()
49
+ if sql.lower().startswith("sql"):
50
+ sql = sql[3:].strip()
51
+
52
+ yield sql
53
+
54
+
55
+ demo = gr.Interface(
56
+ fn=generate_sql,
57
+ inputs=[
58
+ gr.Textbox(label="Question", placeholder="Show me all patients who died during their hospital stay"),
59
+ gr.Textbox(label="Schema DDL", lines=10, placeholder="CREATE TABLE patients (...)"),
60
+ ],
61
+ outputs=gr.Textbox(label="Generated SQL"),
62
+ title="TinyEHR Text-to-SQL",
63
+ description="Generate SQL queries for the TinyEHR dataset from natural language.",
64
+ )
65
+
66
+ demo.launch()