Sazzz02 committed on
Commit
f61d013
·
verified ·
1 Parent(s): 4dec6b1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +311 -0
app.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import sqlite3
4
+ import numpy as np
5
+ import json
6
+ import re
7
+ from typing import List, Dict, Tuple
8
+ from groq import Groq
9
+ import gradio as gr
10
+ from sklearn.metrics import accuracy_score
11
+ import warnings
12
+ warnings.filterwarnings('ignore')
13
+
14
+ # ------------------------------
15
+ # ✅ GROQ API KEY FROM ENVIRONMENT
16
+ # ------------------------------
17
+ # Don't hardcode API keys - use Hugging Face Secrets
18
# Read the Groq credential from the process environment and fail fast at
# import time when it is missing or empty, so the app never starts without it.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

if GROQ_API_KEY in (None, ""):
    raise ValueError("GROQ_API_KEY environment variable not set. Please add it to your Hugging Face Space secrets.")
22
+
23
+ # ------------------------------
24
+ # SQL Converter Using Groq API
25
+ # ------------------------------
26
+
27
class EnhancedNL2SQLConverter:
    """Convert natural-language questions into SQL text via the Groq chat API.

    The converter sends a fixed system prompt plus the database schema and the
    user's question to the configured model, then post-processes the reply so
    that only the SQL statement remains.
    """

    # Statement keywords used to locate the SQL inside a chatty reply, in
    # priority order (SELECT is tried first).
    _SQL_KEYWORDS = ('SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'DROP', 'ALTER')

    # A reply that already begins with one of these is left untouched. WITH is
    # accepted here so CTE queries are not cut down to their inner SELECT, but
    # it is not searched for in prose because "with" is a common English word.
    _LEADING_SQL_RE = re.compile(
        r'(?:SELECT|INSERT|UPDATE|DELETE|CREATE|DROP|ALTER|WITH)\b',
        re.IGNORECASE,
    )

    # Opening/closing markdown fences, with an optional "sql" language tag.
    _FENCE_RE = re.compile(r'```(?:sql)?\n?', re.IGNORECASE)

    _SYSTEM_PROMPT = (
        "You are an expert SQL query generator. Convert natural language questions to SQL queries based on the provided database schema.\n"
        "\n"
        "Rules:\n"
        "1. Only return the SQL query, nothing else\n"
        "2. Use proper SQL syntax\n"
        "3. Be precise with column names and table names\n"
        "4. Use appropriate WHERE clauses, JOINs, and aggregations as needed\n"
        "5. For date comparisons, use proper date format\n"
        "6. Don't include explanations, just the SQL query"
    )

    def __init__(self, model_name: str = "llama3-70b-8192"):
        """Create the Groq client and remember which model to query.

        Args:
            model_name: Identifier of the Groq-hosted chat model to use.
        """
        self.client = Groq(api_key=GROQ_API_KEY)
        self.model_name = model_name
        print(f"Using Groq API with model: {self.model_name}")

        # Schema description sent to the model when the caller supplies none.
        self.default_schema = (
            "\n"
            "Table: employees\n"
            "Columns:\n"
            "- id (INTEGER) PRIMARY KEY\n"
            "- name (TEXT) NOT NULL\n"
            "- department (TEXT)\n"
            "- salary (REAL)\n"
            "- hire_date (TEXT)\n"
            "- manager_id (INTEGER)\n"
        )

    def generate_sql(self, query: str, schema: str = None) -> str:
        """Ask the model for a SQL statement answering ``query``.

        Args:
            query: The natural-language question.
            schema: Optional schema description; ``self.default_schema`` is
                used when this is ``None`` or empty.

        Returns:
            The cleaned SQL text, or a string starting with ``"ERROR"`` when
            the request fails or the model returns nothing. Errors are
            reported through the return value rather than raised.
        """
        try:
            schema_to_use = schema or self.default_schema

            user_prompt = (
                f"Database Schema:\n"
                f"{schema_to_use}\n"
                f"\n"
                f"Natural Language Question: {query}\n"
                f"\n"
                f"Generate the SQL query:"
            )

            chat_completion = self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": self._SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                model=self.model_name,
                temperature=0.1,
                max_tokens=200,
            )

            content = chat_completion.choices[0].message.content
            if not content or not content.strip():
                # An empty reply would otherwise be handed on as an empty
                # "query"; report it as a generation failure instead.
                raise ValueError("empty response from model")
            return self._clean_sql(content)

        except Exception as e:
            # Boundary handler: callers detect failure via the "ERROR" prefix.
            print(f"Error generating SQL: {str(e)}")
            return f"ERROR: Could not generate SQL query - {str(e)}"

    def _clean_sql(self, sql: str) -> str:
        """Reduce a model reply to the bare SQL statement.

        Removes markdown code fences, one pair of quotes wrapping the whole
        reply, trailing semicolons, and any prose preceding the first SQL
        statement keyword.

        Args:
            sql: Raw text returned by the model.

        Returns:
            The cleaned SQL text (unchanged apart from trimming when no
            statement keyword can be found).
        """
        sql = self._FENCE_RE.sub('', sql).strip()

        # Unwrap only a *matching* pair of quotes around the whole reply. A
        # lone trailing quote is the end of a string literal such as
        # "... = 'Engineering'" and must be kept.
        if len(sql) >= 2 and sql[0] == sql[-1] and sql[0] in ('"', "'"):
            sql = sql[1:-1].strip()

        # Trim whitespace first so a semicolon followed by a newline is caught.
        sql = sql.rstrip().rstrip(';').rstrip()

        if self._LEADING_SQL_RE.match(sql):
            return sql

        # Drop leading prose. Whole-word matching avoids false hits inside
        # identifiers such as "selected" or "created_at".
        for keyword in self._SQL_KEYWORDS:
            match = re.search(rf'\b{keyword}\b', sql, re.IGNORECASE)
            if match:
                return sql[match.start():]
        return sql
96
+
97
+ # ------------------------------
98
+ # SQL Evaluator & Test Database
99
+ # ------------------------------
100
+
101
class SQLEvaluator:
    """Run SQL statements against a small SQLite sample database.

    The database file lives at ``self.db_path`` and holds a single
    ``employees`` table that is (re)seeded on construction.

    NOTE(review): ``execute_sql`` runs whatever statement it is given,
    including INSERT/UPDATE/DELETE/DROP, and the caller in this file passes
    model-generated SQL. Consider restricting it to read-only statements.
    """

    def __init__(self):
        """Point at the on-disk test database and make sure it is seeded."""
        self.db_path = "test_database.db"
        self.setup_test_database()

    def setup_test_database(self):
        """Create the ``employees`` table if needed and upsert the sample rows."""
        sample_data = [
            (1, 'Alice Johnson', 'Engineering', 75000, '2022-01-15', None),
            (2, 'Bob Smith', 'Sales', 65000, '2021-06-20', None),
            (3, 'Charlie Brown', 'Engineering', 80000, '2020-03-10', 1),
            (4, 'Diana Prince', 'HR', 60000, '2023-02-28', None),
            (5, 'Eve Wilson', 'Sales', 70000, '2022-11-05', 2),
            (6, 'Frank Miller', 'Engineering', 85000, '2019-08-12', 1),
            (7, 'Grace Lee', 'Marketing', 55000, '2023-01-20', None),
            (8, 'Henry Davis', 'Engineering', 72000, '2022-07-30', 1),
        ]

        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            # Create employees table
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS employees (
                    id INTEGER PRIMARY KEY,
                    name TEXT NOT NULL,
                    department TEXT,
                    salary REAL,
                    hire_date TEXT,
                    manager_id INTEGER
                )''')

            # INSERT OR REPLACE keeps seeding idempotent across restarts.
            cursor.executemany('''
                INSERT OR REPLACE INTO employees (id, name, department, salary, hire_date, manager_id)
                VALUES (?, ?, ?, ?, ?, ?)''', sample_data)

            conn.commit()
        finally:
            # Release the file handle even if table creation or seeding fails.
            conn.close()
        print("βœ… Test database initialized successfully")

    def execute_sql(self, sql_query: str) -> Tuple[bool, object]:
        """Execute one SQL statement and report the outcome.

        Args:
            sql_query: The statement to run.

        Returns:
            ``(True, {'columns': [...], 'data': [...]})`` for statements that
            produce a result set, ``(True, "Query executed successfully")``
            for statements that do not, and ``(False, error_message)`` when
            anything goes wrong. Errors are never raised to the caller.
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()
            cursor.execute(sql_query)

            # cursor.description is set for any row-returning statement, so
            # "WITH ... SELECT" and PRAGMA queries are handled, not only text
            # that literally starts with SELECT.
            if cursor.description is not None:
                columns = [description[0] for description in cursor.description]
                result = {'columns': columns, 'data': cursor.fetchall()}
            else:
                result = "Query executed successfully"

            conn.commit()
            return True, result
        except Exception as e:
            return False, str(e)
        finally:
            # Always close, including on the error path.
            if conn is not None:
                conn.close()
158
+
159
+ # ------------------------------
160
+ # Initialize components
161
+ # ------------------------------
162
# Module-level instances used by process_nl_query below: one
# natural-language-to-SQL converter and one SQLite-backed evaluator.
converter = EnhancedNL2SQLConverter()
evaluator = SQLEvaluator()
164
+
165
+ # ------------------------------
166
+ # Gradio Interface Functions
167
+ # ------------------------------
168
+
169
def process_nl_query(nl_query: str) -> Tuple[str, str, str]:
    """Turn a natural-language question into SQL, run it, and format the result.

    Args:
        nl_query: The user's question. ``None``, empty, or whitespace-only
            input is rejected with a prompt to enter a query.

    Returns:
        A ``(generated_sql, results_text, status_message)`` tuple. Failures
        are reported through ``status_message``; nothing is raised.
    """
    # Guard against None as well as blank text so the handler never raises
    # before reaching the try block.
    if not nl_query or not nl_query.strip():
        return "", "", "Please enter a natural language query."

    try:
        # Generate SQL
        generated_sql = converter.generate_sql(nl_query)
        if generated_sql.startswith("ERROR"):
            return generated_sql, "", "Failed to generate SQL query."

        # Execute SQL
        success, result = evaluator.execute_sql(generated_sql)
        if not success:
            return generated_sql, "", f"❌ Error executing query: {result}"

        if isinstance(result, dict):
            # A dict result carries a row set; render it as a plain-text table.
            df = pd.DataFrame(result['data'], columns=result['columns'])
            if len(df) == 0:
                formatted_output = "No results found."
            else:
                formatted_output = df.to_string(index=False)
            return generated_sql, formatted_output, "βœ… Query executed successfully!"

        # Any other successful result is a status message.
        return generated_sql, str(result), "βœ… Query executed successfully!"

    except Exception as e:
        # UI boundary: surface unexpected failures as a status message.
        return "", "", f"❌ Unexpected error: {str(e)}"
199
+
200
def get_sample_queries():
    """Provide the example questions shown to users as one-click prompts."""
    examples = (
        "Show all employees in the Engineering department",
        "Find employees with salary greater than 70000",
        "List all employees hired after 2022",
        "Count employees by department",
        "Show the highest paid employee in each department",
        "Find employees who don't have a manager",
        "Show average salary by department",
    )
    # Hand back a fresh list on every call, as callers may mutate it.
    return list(examples)
211
+
212
def load_sample_query(query):
    """Pass the chosen sample question through unchanged.

    Identity helper: whatever is given as ``query`` is returned as-is.
    """
    return query
215
+
216
+ # ------------------------------
217
+ # Gradio UI
218
+ # ------------------------------
219
+
220
+ # Custom CSS for better styling
221
css = """
.gradio-container {
    max-width: 1200px !important;
}
.sample-queries {
    margin: 10px 0;
}
"""

# Page layout: header, question input with sample-query buttons beside it,
# then the generated SQL, status and result panes, and a short "about" footer.
with gr.Blocks(css=css, title="NL2SQL with Groq AI", theme=gr.themes.Soft()) as iface:
    gr.Markdown("""
    # πŸš€ Natural Language to SQL Converter

    Convert your natural language questions into SQL queries using **Groq AI** and execute them on a sample employee database!

    ### Sample Database Schema:
    **employees** table with columns: `id`, `name`, `department`, `salary`, `hire_date`, `manager_id`
    """)

    with gr.Row():
        with gr.Column(scale=2):
            nl_input = gr.Textbox(
                label="πŸ’¬ Enter Your Question",
                placeholder="e.g., Show all employees in Engineering department",
                lines=2
            )

            submit_btn = gr.Button("πŸ”„ Generate & Execute SQL", variant="primary")

        with gr.Column(scale=1):
            gr.Markdown("### πŸ“ Try These Sample Queries:")
            sample_queries = get_sample_queries()

            for query in sample_queries:
                # Bind the current query as a default argument; a plain
                # closure would make every button return the last query.
                gr.Button(
                    query,
                    variant="secondary",
                    size="sm"
                ).click(
                    lambda q=query: q,
                    outputs=nl_input
                )

    with gr.Row():
        with gr.Column():
            sql_output = gr.Textbox(
                label="πŸ”§ Generated SQL Query",
                lines=3,
                interactive=False
            )

            status_output = gr.Textbox(
                label="πŸ“Š Status",
                lines=1,
                interactive=False
            )

            results_output = gr.Textbox(
                label="πŸ“‹ Query Results",
                lines=10,
                interactive=False
            )

    # Event handlers: the button and pressing Enter in the textbox both run
    # the same pipeline. process_nl_query returns (sql, results, status), so
    # the outputs are listed in that order.
    submit_btn.click(
        fn=process_nl_query,
        inputs=[nl_input],
        outputs=[sql_output, results_output, status_output]
    )

    nl_input.submit(
        fn=process_nl_query,
        inputs=[nl_input],
        outputs=[sql_output, results_output, status_output]
    )

    gr.Markdown("""
    ### πŸ” About This App:
    - **AI Model**: Groq's Llama3-70B for SQL generation
    - **Database**: SQLite with sample employee data
    - **Features**: Natural language processing, SQL execution, formatted results

    ### πŸ’‘ Tips:
    - Be specific in your questions
    - Use clear, simple language
    - Try the sample queries to get started
    """)

# Launch the app
if __name__ == "__main__":
    iface.launch()