Balaprime committed on
Commit
639a3c6
·
verified ·
1 Parent(s): c2ad5d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +265 -0
app.py CHANGED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+ import re
5
+ import sqlparse
6
+
7
# Load model and tokenizer.
# The checkpoint id is kept in one place so the model and tokenizer can never
# be loaded from two different repositories by accident.
MODEL_ID = "onkolahmet/Qwen2-0.5B-Instruct-SQL-generator"

# Preferred device for tensors created by this script.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
15
+
16
+ # # Few-shot examples to include in each prompt
17
+ # examples = [
18
+ # {
19
+ # "question": "Get the names and emails of customers who placed an order in the last 30 days.",
20
+ # "sql": "SELECT name, email FROM customers WHERE order_date >= DATE_SUB(CURDATE(), INTERVAL 30 DAY);"
21
+ # },
22
+ # {
23
+ # "question": "Find all employees with a salary greater than 50000.",
24
+ # "sql": "SELECT * FROM employees WHERE salary > 50000;"
25
+ # },
26
+ # {
27
+ # "question": "List all product names and their categories where the price is below 50.",
28
+ # "sql": "SELECT name, category FROM products WHERE price < 50;"
29
+ # },
30
+ # {
31
+ # "question": "How many users registered in the year 2022?",
32
+ # "sql": "SELECT COUNT(*) FROM users WHERE YEAR(registration_date) = 2022;"
33
+ # }
34
+ # ]
35
+
36
def generate_sql(question, context=None):
    """Ask the model for a SQL completion of a natural-language question.

    Args:
        question: The natural-language request to translate.
        context: Optional schema / table description. It is added to the
            prompt only when it contains non-whitespace text.

    Returns:
        The text generated after the prompt, decoded without special tokens
        and stripped of surrounding whitespace. The text is raw model output
        and is not cleaned or validated here.
    """
    prompt_parts = ["Translate natural language questions to SQL queries.\n\n"]

    # Add table context if available
    if context and context.strip():
        prompt_parts.append(f"Table Context:\n{context}\n\n")

    # Add the current question
    prompt_parts.append(f"Q: {question}\nSQL:")
    prompt = "".join(prompt_parts)

    # Place the inputs on the device the model was actually loaded on, which
    # may differ from a hand-picked device when device_map="auto" is used.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # generate() needs an explicit pad token id; fall back to EOS when the
    # tokenizer does not define one.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    # Unpacking `inputs` forwards the attention mask together with the ids.
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_token_id,
    )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_length = inputs["input_ids"].shape[-1]
    sql_query = tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,
    )
    return sql_query.strip()
65
+
66
# Characters that open/close string literals or quoted identifiers.
_QUOTE_CHARS = frozenset("'\"`")

# Tokens the boundary scanner cares about: quotes, parentheses, semicolons
# and the SELECT keyword.
_SCAN_TOKEN = re.compile(r"""['"`();]|\bSELECT\b""", re.IGNORECASE)

# Text ending in a set operator: a SELECT that follows it continues the
# current statement instead of starting a new one.
_SET_OPERATOR_TAIL = re.compile(
    r'\b(?:UNION|INTERSECT|EXCEPT|MINUS)(?:\s+(?:ALL|DISTINCT))?\s*$',
    re.IGNORECASE,
)

_SELECT_KEYWORD = re.compile(r'\bSELECT\b', re.IGNORECASE)


def _scan_boundaries(text, honor_quotes=True):
    """Locate statement boundaries in *text*.

    Returns a ``(semicolons, selects)`` pair of offset lists. ``semicolons``
    holds every ``;`` found outside quotes. ``selects`` holds every SELECT
    keyword that is outside quotes, outside parentheses and not preceded by a
    set operator, i.e. every SELECT that can begin a statement of its own.
    """
    semicolons = []
    selects = []
    depth = 0
    open_quote = None

    for match in _SCAN_TOKEN.finditer(text):
        token = match.group()

        if token in _QUOTE_CHARS:
            if honor_quotes:
                if open_quote is None:
                    open_quote = token
                elif open_quote == token:
                    open_quote = None
            continue
        if open_quote is not None:
            # Inside a quoted region: nothing here is a boundary.
            continue

        if token == '(':
            depth += 1
        elif token == ')':
            depth = max(depth - 1, 0)
        elif token == ';':
            semicolons.append(match.start())
        elif depth == 0 and not _SET_OPERATOR_TAIL.search(text, 0, match.start()):
            selects.append(match.start())

    if open_quote is not None:
        # Unbalanced quotes (e.g. an apostrophe in prose around the query):
        # the quote tracking cannot be trusted, so rescan ignoring quotes.
        return _scan_boundaries(text, honor_quotes=False)
    return semicolons, selects


def _normalize_for_comparison(query):
    """Return a canonical form of *query* used only to detect duplicates."""
    try:
        return sqlparse.format(
            query + ('' if query.strip().endswith(';') else ';'),
            keyword_case='lower',
            identifier_case='lower',
            strip_comments=True,
            reindent=True,
        )
    except Exception:
        # Normalization is best-effort: fall back to a plain
        # lowercase / single-space form if the formatter is unusable.
        return re.sub(r'\s+', ' ', query.lower().strip())


def clean_sql_output(sql_text):
    """Reduce raw model output to a single, formatted SQL statement.

    Steps:
        1. Remove ``--`` and ``/* */`` comments and markdown code fences.
        2. Split the text into statements at semicolons that are outside
           string literals. If there are none, split at SELECT keywords that
           start a new top-level statement. Subqueries and UNION / INTERSECT /
           EXCEPT branches are kept intact.
        3. Drop duplicate statements (compared in a normalized form).
        4. Pick the longest statement containing SELECT, or the longest
           statement overall when none contains SELECT. Ties go to the
           earliest statement.
        5. Ensure a trailing semicolon, collapse whitespace, and pretty-print
           with sqlparse when possible.

    Args:
        sql_text: Raw text produced by the model. May be empty or ``None``.

    Returns:
        The chosen statement as a string, or ``""`` when no statement is left
        after cleaning.
    """
    if not sql_text:
        return ""

    # Remove SQL comments (both single line and multi-line)
    sql_text = re.sub(r'--.*?$', '', sql_text, flags=re.MULTILINE)
    sql_text = re.sub(r'/\*.*?\*/', '', sql_text, flags=re.DOTALL)

    # Remove markdown code block syntax if present
    sql_text = re.sub(r'```sql|```', '', sql_text, flags=re.IGNORECASE)

    semicolons, _ = _scan_boundaries(sql_text)
    if semicolons:
        queries = []
        start = 0
        for offset in semicolons:
            queries.append(sql_text[start:offset])
            start = offset + 1
        queries.append(sql_text[start:])
    else:
        # No semicolons: the model may have repeated queries back to back, so
        # look for SELECT keywords that begin a new top-level statement.
        flattened = re.sub(r'\s+', ' ', sql_text)
        _, selects = _scan_boundaries(flattened)
        if len(selects) > 1:
            ends = selects[1:] + [len(flattened)]
            queries = [flattened[begin:end] for begin, end in zip(selects, ends)]
        else:
            queries = [sql_text]

    # Remove empty queries
    queries = [q.strip() for q in queries if q.strip()]
    if not queries:
        return ""

    # Deduplicate while keeping the first spelling of each distinct query.
    unique = {}
    for query in queries:
        unique.setdefault(_normalize_for_comparison(query), query)
    candidates = list(unique.values())

    # Prefer SELECT statements; max() keeps the earliest on equal length.
    select_candidates = [q for q in candidates if _SELECT_KEYWORD.search(q)]
    best_query = max(select_candidates or candidates, key=len)

    # Clean up the chosen query
    if not best_query.endswith(';'):
        best_query += ';'

    # Final formatting to ensure consistent spacing
    best_query = re.sub(r'\s+', ' ', best_query)

    try:
        # Use sqlparse to nicely format the SQL for display
        return sqlparse.format(
            best_query,
            keyword_case='upper',
            identifier_case='lower',
            reindent=True,
            indent_width=2,
        )
    except Exception:
        # Pretty-printing is cosmetic; return the unformatted query rather
        # than failing the whole request.
        return best_query
171
+
172
def process_input(question, table_context):
    """Turn a user question into a cleaned SQL query for display.

    Args:
        question: Natural-language question from the UI. ``None``, empty and
            whitespace-only values are all treated as "no question".
        table_context: Optional schema text, passed through to the generator.

    Returns:
        The cleaned SQL query, or a user-facing message when the question is
        missing or no query could be extracted from the model output.
    """
    # Guard against None as well as blank text so a missing value never
    # reaches .strip() or the model.
    if not question or not question.strip():
        return "Please enter a question."

    # Generate SQL from the question and context
    raw_sql = generate_sql(question, table_context)

    # Clean the SQL output
    cleaned_sql = clean_sql_output(raw_sql)

    if not cleaned_sql:
        return "Sorry, I couldn't generate a valid SQL query. Please try rephrasing your question."

    return cleaned_sql
187
+
188
# Sample table context examples for the example selector.
# The examples below index into this list, so the order matters:
#   0 -> customers / orders, 1 -> products, 2 -> employees / departments.
example_contexts = [
    """CREATE TABLE customers (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    email VARCHAR(100)
);

CREATE TABLE orders (
    id INT PRIMARY KEY,
    customer_id INT,
    order_date DATE,
    total_amount DECIMAL(10, 2),
    FOREIGN KEY (customer_id) REFERENCES customers(id)
);""",
    """CREATE TABLE products (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    category VARCHAR(50),
    price DECIMAL(10, 2)
);""",
    """CREATE TABLE departments (
    id INT PRIMARY KEY,
    name VARCHAR(100)
);

CREATE TABLE employees (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    department_id INT,
    salary DECIMAL(10, 2),
    FOREIGN KEY (department_id) REFERENCES departments(id)
);""",
]

# Sample question examples, each paired with the schema it refers to.
example_rows = [
    ["List all products in the 'Electronics' category with price less than $500", example_contexts[1]],
    ["Find the total number of employees in each department", example_contexts[2]],
    ["Get customers who placed orders in the last 7 days", example_contexts[0]],
    ["Count the number of products in each category", example_contexts[1]],
    ["Find the average salary by department", example_contexts[2]],
]

# Create the Gradio interface
with gr.Blocks(title="Text to SQL Converter") as demo:
    gr.Markdown("# Text to SQL Query Converter")
    gr.Markdown("Enter your question and optional table context to generate an SQL query.")

    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(
                label="Your Question",
                placeholder="e.g., Find all products with price less than $50",
                lines=2,
            )

            table_context = gr.Textbox(
                label="Table Context (Optional)",
                placeholder="Enter your database schema or table definitions here...",
                lines=10,
            )

            submit_btn = gr.Button("Generate SQL Query")

        with gr.Column():
            sql_output = gr.Code(
                label="Generated SQL Query",
                language="sql",
                lines=12,
            )

    # Examples section
    gr.Markdown("### Try some examples")

    example_selector = gr.Examples(
        examples=example_rows,
        inputs=[question_input, table_context],
    )

    # Set up the submit button to trigger the process_input function
    submit_btn.click(
        fn=process_input,
        inputs=[question_input, table_context],
        outputs=sql_output,
    )

    # Also trigger on pressing Enter in the question input
    question_input.submit(
        fn=process_input,
        inputs=[question_input, table_context],
        outputs=sql_output,
    )

    # Add information about the model
    gr.Markdown("""
    ### About
    This app uses a fine-tuned language model to convert natural language questions into SQL queries.

    - **Model**: [onkolahmet/Qwen2-0.5B-Instruct-SQL-generator](https://huggingface.co/onkolahmet/Qwen2-0.5B-Instruct-SQL-generator)
    - **How to use**:
      1. Enter your question in natural language
      2. If you have specific table schemas, add them in the Table Context field
      3. Click "Generate SQL Query" or press Enter

    Note: The model works best when table context is provided, but can generate generic SQL queries without it.
    """)

# Launch the app
demo.launch()