shukdevdattaEX committed on
Commit
944a160
·
verified ·
1 Parent(s): f9d7c74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +256 -439
app.py CHANGED
@@ -3,10 +3,9 @@ from groq import Groq
3
  from pydantic import BaseModel
4
  import json
5
  import sqlite3
6
- import pandas as pd
7
  from datetime import datetime, timedelta
8
  import random
9
- import re
10
 
11
  # Pydantic models for structured output
12
  class ValidationStatus(BaseModel):
@@ -21,510 +20,328 @@ class SQLQueryGeneration(BaseModel):
21
  execution_notes: list[str]
22
  validation_status: ValidationStatus
23
 
24
- def extract_table_schema_from_sql(sql_query):
25
- """Extract all column names and table names from SQL query"""
26
- # Extract table names
27
- table_pattern = r'FROM\s+(\w+)|JOIN\s+(\w+)'
28
- tables = re.findall(table_pattern, sql_query, re.IGNORECASE)
29
- table_names = [t[0] or t[1] for t in tables]
30
-
31
- # Extract column names from SELECT, WHERE, GROUP BY, ORDER BY
32
- # Remove aliases (AS something)
33
- cleaned_query = re.sub(r'\s+AS\s+\w+', '', sql_query, flags=re.IGNORECASE)
34
-
35
- # Find all potential column references (table.column or column)
36
- column_pattern = r'(?:[\w]+\.)?(\w+)'
37
-
38
- # Extract from different parts
39
- columns = set()
40
-
41
- # From SELECT clause
42
- select_match = re.search(r'SELECT\s+(.+?)\s+FROM', sql_query, re.IGNORECASE | re.DOTALL)
43
- if select_match:
44
- select_part = select_match.group(1)
45
- # Remove aggregation functions
46
- select_part = re.sub(r'(SUM|COUNT|AVG|MAX|MIN|DISTINCT)\s*\(', '', select_part, flags=re.IGNORECASE)
47
- select_part = re.sub(r'\)', '', select_part)
48
- cols = re.findall(r'[\w]+\.(\w+)|(?:^|,\s*)(\w+)', select_part)
49
- for col in cols:
50
- c = col[0] or col[1]
51
- if c and c.upper() not in ['SELECT', 'FROM', 'WHERE', 'AS', 'ON']:
52
- columns.add(c.lower())
53
-
54
- # From WHERE clause
55
- where_match = re.search(r'WHERE\s+(.+?)(?:GROUP|ORDER|LIMIT|$)', sql_query, re.IGNORECASE | re.DOTALL)
56
- if where_match:
57
- where_part = where_match.group(1)
58
- cols = re.findall(r'[\w]+\.(\w+)|(\w+)\s*[=<>!]', where_part)
59
- for col in cols:
60
- c = col[0] or col[1]
61
- if c and c.upper() not in ['AND', 'OR', 'NOT', 'IN', 'LIKE', 'IS', 'NULL']:
62
- columns.add(c.lower())
63
-
64
- # From JOIN ON clause
65
- join_matches = re.findall(r'ON\s+(.+?)(?:WHERE|GROUP|ORDER|JOIN|$)', sql_query, re.IGNORECASE)
66
- for join_match in join_matches:
67
- cols = re.findall(r'[\w]+\.(\w+)', join_match)
68
- columns.update([c.lower() for c in cols])
69
-
70
- # From GROUP BY
71
- group_match = re.search(r'GROUP\s+BY\s+(.+?)(?:ORDER|HAVING|LIMIT|$)', sql_query, re.IGNORECASE)
72
- if group_match:
73
- group_part = group_match.group(1)
74
- cols = re.findall(r'[\w]+\.(\w+)|(\w+)', group_part)
75
- for col in cols:
76
- c = col[0] or col[1]
77
- if c:
78
- columns.add(c.lower())
79
-
80
- # From ORDER BY
81
- order_match = re.search(r'ORDER\s+BY\s+(.+?)(?:LIMIT|$)', sql_query, re.IGNORECASE)
82
- if order_match:
83
- order_part = order_match.group(1)
84
- cols = re.findall(r'[\w]+\.(\w+)|(\w+)', order_part)
85
- for col in cols:
86
- c = col[0] or col[1]
87
- if c and c.upper() not in ['ASC', 'DESC']:
88
- columns.add(c.lower())
89
-
90
- return list(set(table_names)), list(columns)
91
-
92
- def generate_table_with_columns(table_name, required_columns, row_count=15):
93
- """Generate table data ensuring ALL required columns exist"""
94
-
95
- # Helper functions
96
- def gen_id():
97
- return list(range(1, row_count + 1))
98
-
99
- def gen_names():
100
- first = ["Alice", "Bob", "Carol", "David", "Emma", "Frank", "Grace", "Henry", "Ivy", "Jack",
101
- "Karen", "Leo", "Maria", "Nathan", "Olivia"]
102
- last = ["Johnson", "Smith", "Williams", "Brown", "Jones", "Garcia", "Miller", "Davis",
103
- "Rodriguez", "Martinez", "Anderson", "Taylor", "Thomas", "Moore", "Jackson"]
104
- return [f"{random.choice(first)} {random.choice(last)}" for _ in range(row_count)]
105
-
106
- def gen_emails():
107
- return [f"user{i}@example.com" for i in range(1, row_count + 1)]
108
-
109
- def gen_dates(days_back=365):
110
- base = datetime.now()
111
- return [(base - timedelta(days=random.randint(0, days_back))).strftime('%Y-%m-%d')
112
- for _ in range(row_count)]
113
-
114
- def gen_years():
115
- return [random.randint(2000, 2025) for _ in range(row_count)]
116
-
117
- def gen_amounts():
118
- return [round(random.uniform(100, 5000), 2) for _ in range(row_count)]
119
-
120
- def gen_salaries():
121
- return [random.choice([45000, 55000, 65000, 75000, 85000, 95000, 105000, 120000])
122
- for _ in range(row_count)]
123
-
124
- def gen_prices():
125
- return [round(random.uniform(10, 1000), 2) for _ in range(row_count)]
126
-
127
- def gen_quantities():
128
- return [random.randint(0, 100) for _ in range(row_count)]
129
-
130
- def gen_ratings():
131
- return [round(random.uniform(1.0, 10.0), 1) for _ in range(row_count)]
132
-
133
- def gen_scores():
134
- return [random.randint(60, 100) for _ in range(row_count)]
135
-
136
- def gen_ages():
137
- return [random.randint(18, 80) for _ in range(row_count)]
138
-
139
- def gen_boolean():
140
- return [random.choice([True, False, True, True]) for _ in range(row_count)]
141
-
142
- def gen_status():
143
- return [random.choice(['Active', 'Inactive', 'Pending', 'Active', 'Active'])
144
- for _ in range(row_count)]
145
-
146
- def gen_categories():
147
- return [random.choice(['Category A', 'Category B', 'Category C', 'Category D'])
148
- for _ in range(row_count)]
149
-
150
- def gen_foreign_key():
151
- return [random.randint(1, 15) for _ in range(row_count)]
152
-
153
- def gen_phone():
154
- return [f"+1-555-{random.randint(1000, 9999)}" for _ in range(row_count)]
155
-
156
- def gen_text():
157
- return [f"Text content {i}" for i in range(1, row_count + 1)]
158
-
159
- def gen_duration():
160
- return [random.randint(60, 240) for _ in range(row_count)]
161
-
162
- # Column type mapping based on name patterns
163
- def infer_column_data(col_name):
164
- col_lower = col_name.lower()
165
-
166
- # ID columns
167
- if col_lower.endswith('_id') or col_lower == 'id':
168
- if col_lower == f'{table_name}_id' or col_lower == 'id':
169
- return gen_id()
170
- return gen_foreign_key()
171
-
172
- # Name columns
173
- if 'name' in col_lower or 'title' in col_lower:
174
- return gen_names() if 'name' in col_lower else gen_text()
175
-
176
- # Email columns
177
- if 'email' in col_lower:
178
- return gen_emails()
179
-
180
- # Phone columns
181
- if 'phone' in col_lower:
182
- return gen_phone()
183
-
184
- # Date columns
185
- if any(word in col_lower for word in ['date', 'created', 'updated', 'joined', 'registered', 'hired', 'published', 'visited', 'appointed', 'enrolled']):
186
- return gen_dates()
187
-
188
- # Year columns
189
- if 'year' in col_lower or col_lower.endswith('_year'):
190
- return gen_years()
191
-
192
- # Money/Amount columns
193
- if any(word in col_lower for word in ['salary', 'amount', 'price', 'cost', 'revenue', 'budget']):
194
- if 'salary' in col_lower:
195
- return gen_salaries()
196
- elif 'price' in col_lower or 'cost' in col_lower:
197
- return gen_prices()
198
- return gen_amounts()
199
-
200
- # Rating columns
201
- if 'rating' in col_lower or 'score' in col_lower:
202
- if 'rating' in col_lower:
203
- return gen_ratings()
204
- return gen_scores()
205
-
206
- # Age columns
207
- if 'age' in col_lower:
208
- return gen_ages()
209
-
210
- # Quantity/Stock columns
211
- if any(word in col_lower for word in ['quantity', 'stock', 'count', 'level']):
212
- return gen_quantities()
213
-
214
- # Status columns
215
- if 'status' in col_lower:
216
- return gen_status()
217
-
218
- # Category/Type columns
219
- if any(word in col_lower for word in ['category', 'type', 'genre', 'department', 'major', 'subject']):
220
- return gen_categories()
221
-
222
- # Boolean columns
223
- if any(word in col_lower for word in ['available', 'active', 'enabled', 'verified', 'completed']):
224
- return gen_boolean()
225
-
226
- # Duration/Time columns
227
- if any(word in col_lower for word in ['duration', 'time', 'minutes', 'hours']):
228
- return gen_duration()
229
-
230
- # Position/Role columns
231
- if any(word in col_lower for word in ['position', 'role', 'job', 'title']):
232
- return [random.choice(['Manager', 'Engineer', 'Analyst', 'Developer', 'Designer'])
233
- for _ in range(row_count)]
234
-
235
- # Default to text
236
- return gen_text()
237
-
238
- # Build the table schema
239
- table_data = {}
240
-
241
- # Ensure primary ID exists
242
- primary_id = f'{table_name}_id'
243
- if primary_id not in required_columns and 'id' not in required_columns:
244
- table_data[primary_id] = gen_id()
245
-
246
- # Add all required columns
247
- for col in required_columns:
248
- if col not in table_data:
249
- table_data[col] = infer_column_data(col)
250
-
251
- return table_data
252
-
253
- def create_database_from_sql(sql_query, tables_used):
254
- """Create SQLite database with sample data based on SQL query analysis"""
255
  conn = sqlite3.connect(':memory:')
256
-
257
- # Extract schema from SQL
258
- detected_tables, detected_columns = extract_table_schema_from_sql(sql_query)
259
-
260
- # Merge with provided tables
261
- all_tables = list(set(tables_used + detected_tables))
262
 
263
  sample_data = {}
264
 
265
- # For each table, determine which columns it needs
266
- for table in all_tables:
267
- table_name = table.lower().strip()
268
-
269
- # Find columns that belong to this table from SQL
270
- table_columns = []
271
-
272
- # Look for table.column references
273
- table_col_pattern = rf'{table_name}\.(\w+)'
274
- table_specific_cols = re.findall(table_col_pattern, sql_query, re.IGNORECASE)
275
- table_columns.extend([col.lower() for col in table_specific_cols])
276
-
277
- # If no table-specific columns found, add common columns based on detected columns
278
- if not table_columns:
279
- table_columns = detected_columns
280
-
281
- # Ensure we have at least some basic columns
282
- if not table_columns:
283
- table_columns = ['id', 'name', 'created_date', 'status']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
 
285
- # Generate table with required columns
286
- row_count = 5 if table_name == 'departments' else 15
287
- table_dict = generate_table_with_columns(table_name, table_columns, row_count)
 
 
 
 
288
 
289
- df = pd.DataFrame(table_dict)
290
- df.to_sql(table_name, conn, index=False, if_exists='replace')
291
- sample_data[table_name] = df
292
 
293
- return conn, sample_data
294
 
295
- def execute_sql_on_sample_data(sql_query, conn):
296
- """Execute the generated SQL query on sample database"""
297
  try:
298
- df_result = pd.read_sql_query(sql_query, conn)
299
- return df_result, None
 
 
 
 
 
 
 
300
  except Exception as e:
301
- return None, str(e)
302
 
303
- def process_nl_query(api_key, natural_query):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
  """Main function to process natural language query"""
305
  if not api_key:
306
- return "❌ Please enter your Groq API key", "", pd.DataFrame(), ""
307
 
308
- if not natural_query:
309
- return "❌ Please enter a natural language query", "", pd.DataFrame(), ""
310
 
311
  try:
312
- # Initialize Groq client
313
  client = Groq(api_key=api_key)
314
 
315
- # Step 1: Generate SQL from natural language
316
- output_text = "## πŸ“‹ STEP-BY-STEP PROCESS\n\n"
317
- output_text += "### Step 1: Understanding User Intent\n"
318
- output_text += f"**User Query:** {natural_query}\n\n"
319
-
320
- # Call Groq API for SQL generation with Kimi model
321
  response = client.chat.completions.create(
322
  model="moonshotai/kimi-k2-instruct-0905",
323
  messages=[
324
  {
325
  "role": "system",
326
- "content": """You are a SQL expert. Generate structured SQL queries from natural language descriptions with proper syntax validation and metadata.
327
-
328
- IMPORTANT: Return your response in JSON format with the following structure:
329
- {
330
- "query": "SQL query string",
331
- "query_type": "SELECT/INSERT/UPDATE/DELETE",
332
- "tables_used": ["table1", "table2"],
333
- "estimated_complexity": "low/medium/high",
334
- "execution_notes": ["note1", "note2"],
335
- "validation_status": {
336
- "is_valid": true/false,
337
- "syntax_errors": []
338
- }
339
- }
340
-
341
- CRITICAL SQL GENERATION RULES:
342
- - Use standard SQL syntax compatible with SQLite
343
- - Always use proper JOINs when multiple tables are involved
344
- - Use WHERE clauses for filtering
345
- - Use GROUP BY for aggregations
346
- - For date/year comparisons, use column names like 'release_year' NOT 'release_date' for year-based filtering
347
- - Common date columns: created_date, updated_date, order_date, hire_date, publication_year, release_year
348
- - Extract ALL table names mentioned or implied in the query and list them in "tables_used"
349
- - If a query mentions departments and employees, include BOTH tables
350
- - Be thorough in identifying all tables needed for the query
351
- - Use consistent column naming: prefer release_year over release_date for movies, publication_year for books
352
- - When filtering by years or time periods, use the appropriate column (release_year, publication_year, etc.)""",
353
- },
354
- {
355
- "role": "user",
356
- "content": f"Convert this natural language query to SQL and return as JSON. Use proper column names (e.g., release_year instead of release_date for year-based filters): {natural_query}"
357
  },
 
358
  ],
359
  response_format={
360
- "type": "json_object"
361
- },
362
- temperature=0.3
 
 
 
363
  )
364
 
365
- # Parse the response
366
- response_content = response.choices[0].message.content
367
- sql_data = json.loads(response_content)
368
-
369
- # Try to map to our Pydantic model with better error handling
370
- try:
371
- sql_query_gen = SQLQueryGeneration(**sql_data)
372
- except Exception as e:
373
- # If response doesn't match exact schema, create it manually
374
- sql_query_gen = SQLQueryGeneration(
375
- query=sql_data.get('query', sql_data.get('sql_query', '')),
376
- query_type=sql_data.get('query_type', 'SELECT'),
377
- tables_used=sql_data.get('tables_used', sql_data.get('tables', [])),
378
- estimated_complexity=sql_data.get('estimated_complexity', 'medium'),
379
- execution_notes=sql_data.get('execution_notes', sql_data.get('notes', [])),
380
- validation_status=ValidationStatus(
381
- is_valid=sql_data.get('validation_status', {}).get('is_valid', True),
382
- syntax_errors=sql_data.get('validation_status', {}).get('syntax_errors', [])
383
- )
384
- )
385
-
386
- # Step 2: Display Structured SQL Output
387
- output_text += "### Step 2: Generated Structured SQL\n\n"
388
- output_text += "```json\n"
389
- output_text += json.dumps(sql_query_gen.model_dump(), indent=2)
390
- output_text += "\n```\n\n"
391
-
392
- # Step 3: Generate Sample Database Tables - INTELLIGENT SCHEMA DETECTION
393
- output_text += "### Step 3: Auto-Generated Sample Database Tables\n\n"
394
- output_text += f"**Analyzing SQL query to create appropriate table schemas...**\n\n"
395
-
396
- conn, sample_data = create_database_from_sql(sql_query_gen.query, sql_query_gen.tables_used)
397
 
398
- # Display sample tables (show first 10 rows for readability)
399
- for table_name, df in sample_data.items():
400
- output_text += f"**πŸ“Š Sample `{table_name}` Table** ({len(df)} rows):\n\n"
401
- output_text += f"*Columns: {', '.join(df.columns.tolist())}*\n\n"
402
- display_df = df.head(10)
403
- output_text += display_df.to_markdown(index=False)
404
- if len(df) > 10:
405
- output_text += f"\n\n*...and {len(df) - 10} more rows*"
406
- output_text += "\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
 
408
- # Step 4: Execute SQL Query
409
- output_text += "### Step 4: Execute Generated SQL on Sample Tables\n\n"
410
- output_text += f"**SQL Query:**\n```sql\n{sql_query_gen.query}\n```\n\n"
 
 
411
 
412
- result_df, error = execute_sql_on_sample_data(sql_query_gen.query, conn)
 
413
 
414
  if error:
415
- output_text += f"❌ **Execution Error:** {error}\n\n"
416
- output_text += "**Troubleshooting:** The SQL query may reference columns that don't exist in the generated tables. "
417
- output_text += "This can happen if the AI model uses different column names than what was generated.\n"
418
- result_table = pd.DataFrame({"Error": [error]})
 
 
 
419
  else:
420
- output_text += "βœ… **Query executed successfully!**\n\n"
421
- output_text += f"**πŸ“ˆ SQL Execution Result** ({len(result_df)} rows returned):\n\n"
422
- if len(result_df) > 0:
423
- output_text += result_df.to_markdown(index=False)
424
- else:
425
- output_text += "*No results found matching the criteria*"
426
- result_table = result_df
 
 
 
427
 
428
  conn.close()
429
 
430
- # Format outputs for Gradio
431
- json_output = json.dumps(sql_query_gen.model_dump(), indent=2)
432
-
433
- return output_text, json_output, result_table, sql_query_gen.query
434
 
435
  except Exception as e:
436
- error_msg = f"❌ **Error:** {str(e)}\n\n**Full error details:**\n```\n{repr(e)}\n```\n\nPlease check your API key and try again."
437
- return error_msg, "", pd.DataFrame({"Error": [str(e)]}), ""
438
 
439
- # Create Gradio Interface
440
- with gr.Blocks(title="Natural Language to SQL Query Executor", theme=gr.themes.Ocean()) as demo:
441
  gr.Markdown("""
442
- # πŸ” Natural Language to SQL Query Executor with Intelligent Schema Detection
443
 
444
- Convert **ANY** natural language query into SQL, automatically generate matching database schemas, and execute queries!
445
 
446
  **Example queries to try:**
447
  - "Find all customers who made orders over $500 in the last 30 days, show their name, email, and total order amount"
448
- - "Show all employees who earn more than $75,000 and work in the Engineering department"
449
- - "List students who scored above 85% in Mathematics"
450
- - "Find all movies released in the last 5 years with rating above 8.0"
451
- - "Show properties with price between $200,000 and $500,000"
452
- - "List all books published after 2020 that are available"
453
- - "Show active gym members whose membership expires in the next 30 days"
454
  """)
455
 
456
  with gr.Row():
457
- with gr.Column(scale=1):
458
  api_key_input = gr.Textbox(
459
  label="πŸ”‘ Groq API Key",
460
- type="password",
461
  placeholder="Enter your Groq API key here...",
462
- info="Get your API key from https://console.groq.com"
463
  )
464
-
465
  query_input = gr.Textbox(
466
  label="πŸ’¬ Natural Language Query",
467
- placeholder="e.g., Find all movies released in the last 5 years with rating above 8.0...",
468
  lines=3
469
  )
470
-
471
  submit_btn = gr.Button("πŸš€ Generate & Execute SQL", variant="primary", size="lg")
472
-
473
- gr.Markdown("### πŸ“ Generated SQL Query")
474
- sql_output = gr.Code(label="SQL Query", language="sql")
475
-
476
- with gr.Row():
477
- with gr.Column():
478
- gr.Markdown("### πŸ“Š Process & Results")
479
- process_output = gr.Markdown()
480
 
481
  with gr.Row():
482
  with gr.Column():
483
- gr.Markdown("### 🎯 Structured JSON Output")
484
- json_output = gr.Code(label="JSON Response", language="json")
 
 
485
 
486
- with gr.Row():
487
- with gr.Column():
488
- gr.Markdown("### πŸ“ˆ Query Execution Result")
489
- result_output = gr.Dataframe(
490
- label="Result Table",
491
- interactive=False,
492
- wrap=True
493
- )
494
-
495
- # Connect the button to the processing function
496
  submit_btn.click(
497
- fn=process_nl_query,
498
  inputs=[api_key_input, query_input],
499
- outputs=[process_output, json_output, result_output, sql_output]
500
  )
501
 
502
  gr.Markdown("""
503
  ---
504
- ### πŸ“– How it works:
505
- 1. **Enter your Groq API key** - Required for SQL generation (using Kimi K2 Instruct model)
506
- 2. **Write your query in plain English** - Describe what data you want to find
507
- 3. **Click Generate & Execute** - The system will:
508
- - Convert your query to SQL
509
- - **Intelligently analyze the SQL to detect required columns**
510
- - Automatically create tables with the exact columns needed
511
- - Generate realistic sample data matching the schema
512
- - Execute the query
513
- - Show you the results
514
-
515
- ### 🎯 Revolutionary Features:
516
- - βœ… **AI-powered SQL generation** using Kimi K2 Instruct
517
- - βœ… **Intelligent schema detection** - Analyzes SQL to create matching tables
518
- - βœ… **Dynamic column inference** - Automatically determines column types from SQL
519
- - βœ… **Handles ANY query** - No predefined schemas, works with any table/column combination
520
- - βœ… **Smart data generation** - Creates realistic data based on column names
521
- - βœ… **Zero errors** - Tables always match the generated SQL
522
- - βœ… **Universal support** - Works with employees, movies, students, products, and ANY other domain!
523
-
524
- ### 🧠 Intelligence:
525
- The system analyzes your SQL query to understand what columns are needed, then generates tables with exactly those columns!
526
  """)
527
 
528
- # Launch the app
529
  if __name__ == "__main__":
530
- demo.launch()
 
3
  from pydantic import BaseModel
4
  import json
5
  import sqlite3
6
+ import re
7
  from datetime import datetime, timedelta
8
  import random
 
9
 
10
  # Pydantic models for structured output
11
  class ValidationStatus(BaseModel):
 
20
  execution_notes: list[str]
21
  validation_status: ValidationStatus
22
 
23
+ def generate_sample_data(query, tables_used):
24
+ """Generate sample data based on the query and tables used"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  conn = sqlite3.connect(':memory:')
26
+ cursor = conn.cursor()
 
 
 
 
 
27
 
28
  sample_data = {}
29
 
30
+ # Generate data based on common table patterns
31
+ if 'customers' in tables_used:
32
+ cursor.execute('''
33
+ CREATE TABLE customers (
34
+ customer_id INTEGER PRIMARY KEY,
35
+ name TEXT,
36
+ email TEXT
37
+ )
38
+ ''')
39
+
40
+ customers = [
41
+ (1, 'Alice Johnson', 'alice@example.com'),
42
+ (2, 'Bob Smith', 'bob@example.com'),
43
+ (3, 'Carol Williams', 'carol@example.com'),
44
+ (4, 'David Brown', 'david@example.com'),
45
+ (5, 'Eve Davis', 'eve@example.com')
46
+ ]
47
+
48
+ cursor.executemany('INSERT INTO customers VALUES (?, ?, ?)', customers)
49
+ sample_data['customers'] = customers
50
+
51
+ if 'orders' in tables_used:
52
+ cursor.execute('''
53
+ CREATE TABLE orders (
54
+ order_id INTEGER PRIMARY KEY,
55
+ customer_id INTEGER,
56
+ total_amount REAL,
57
+ order_date TEXT
58
+ )
59
+ ''')
60
+
61
+ today = datetime.now()
62
+ orders = [
63
+ (101, 1, 600, (today - timedelta(days=10)).strftime('%Y-%m-%d')),
64
+ (102, 1, 450, (today - timedelta(days=5)).strftime('%Y-%m-%d')),
65
+ (103, 2, 1200, (today - timedelta(days=15)).strftime('%Y-%m-%d')),
66
+ (104, 3, 300, (today - timedelta(days=20)).strftime('%Y-%m-%d')),
67
+ (105, 3, 800, (today - timedelta(days=2)).strftime('%Y-%m-%d')),
68
+ (106, 4, 550, (today - timedelta(days=7)).strftime('%Y-%m-%d')),
69
+ (107, 5, 1500, (today - timedelta(days=12)).strftime('%Y-%m-%d'))
70
+ ]
71
+
72
+ cursor.executemany('INSERT INTO orders VALUES (?, ?, ?, ?)', orders)
73
+ sample_data['orders'] = orders
74
+
75
+ if 'products' in tables_used:
76
+ cursor.execute('''
77
+ CREATE TABLE products (
78
+ product_id INTEGER PRIMARY KEY,
79
+ product_name TEXT,
80
+ price REAL,
81
+ category TEXT
82
+ )
83
+ ''')
84
+
85
+ products = [
86
+ (1, 'Laptop', 999.99, 'Electronics'),
87
+ (2, 'Mouse', 29.99, 'Electronics'),
88
+ (3, 'Keyboard', 79.99, 'Electronics'),
89
+ (4, 'Monitor', 299.99, 'Electronics'),
90
+ (5, 'Desk', 199.99, 'Furniture')
91
+ ]
92
+
93
+ cursor.executemany('INSERT INTO products VALUES (?, ?, ?, ?)', products)
94
+ sample_data['products'] = products
95
+
96
+ if 'employees' in tables_used:
97
+ cursor.execute('''
98
+ CREATE TABLE employees (
99
+ employee_id INTEGER PRIMARY KEY,
100
+ name TEXT,
101
+ department TEXT,
102
+ salary REAL
103
+ )
104
+ ''')
105
 
106
+ employees = [
107
+ (1, 'John Doe', 'Engineering', 85000),
108
+ (2, 'Jane Smith', 'Marketing', 75000),
109
+ (3, 'Mike Johnson', 'Sales', 70000),
110
+ (4, 'Sarah Williams', 'Engineering', 90000),
111
+ (5, 'Tom Brown', 'HR', 65000)
112
+ ]
113
 
114
+ cursor.executemany('INSERT INTO employees VALUES (?, ?, ?, ?)', employees)
115
+ sample_data['employees'] = employees
 
116
 
117
+ return conn, cursor, sample_data
118
 
119
+ def execute_sql_query(cursor, query):
120
+ """Execute the SQL query and return results"""
121
  try:
122
+ # Convert MySQL/PostgreSQL specific functions to SQLite
123
+ sqlite_query = query.replace('DATE_SUB(NOW(), INTERVAL 30 DAY)',
124
+ f"date('now', '-30 days')")
125
+ sqlite_query = sqlite_query.replace('NOW()', "date('now')")
126
+
127
+ cursor.execute(sqlite_query)
128
+ results = cursor.fetchall()
129
+ columns = [description[0] for description in cursor.description]
130
+ return results, columns, None
131
  except Exception as e:
132
+ return None, None, str(e)
133
 
134
+ def format_sample_tables(sample_data):
135
+ """Format sample tables as HTML for display"""
136
+ html = "<div style='margin: 20px 0;'>"
137
+
138
+ for table_name, data in sample_data.items():
139
+ html += f"<h3>πŸ“Š Sample {table_name} Table</h3>"
140
+ html += "<table style='border-collapse: collapse; width: 100%; margin-bottom: 20px;'>"
141
+
142
+ if table_name == 'customers':
143
+ html += "<tr style='background-color: #f0f0f0;'><th style='border: 1px solid #ddd; padding: 8px;'>customer_id</th><th style='border: 1px solid #ddd; padding: 8px;'>name</th><th style='border: 1px solid #ddd; padding: 8px;'>email</th></tr>"
144
+ for row in data:
145
+ html += f"<tr><td style='border: 1px solid #ddd; padding: 8px;'>{row[0]}</td><td style='border: 1px solid #ddd; padding: 8px;'>{row[1]}</td><td style='border: 1px solid #ddd; padding: 8px;'>{row[2]}</td></tr>"
146
+
147
+ elif table_name == 'orders':
148
+ html += "<tr style='background-color: #f0f0f0;'><th style='border: 1px solid #ddd; padding: 8px;'>order_id</th><th style='border: 1px solid #ddd; padding: 8px;'>customer_id</th><th style='border: 1px solid #ddd; padding: 8px;'>total_amount</th><th style='border: 1px solid #ddd; padding: 8px;'>order_date</th></tr>"
149
+ for row in data:
150
+ html += f"<tr><td style='border: 1px solid #ddd; padding: 8px;'>{row[0]}</td><td style='border: 1px solid #ddd; padding: 8px;'>{row[1]}</td><td style='border: 1px solid #ddd; padding: 8px;'>{row[2]}</td><td style='border: 1px solid #ddd; padding: 8px;'>{row[3]}</td></tr>"
151
+
152
+ elif table_name == 'products':
153
+ html += "<tr style='background-color: #f0f0f0;'><th style='border: 1px solid #ddd; padding: 8px;'>product_id</th><th style='border: 1px solid #ddd; padding: 8px;'>product_name</th><th style='border: 1px solid #ddd; padding: 8px;'>price</th><th style='border: 1px solid #ddd; padding: 8px;'>category</th></tr>"
154
+ for row in data:
155
+ html += f"<tr><td style='border: 1px solid #ddd; padding: 8px;'>{row[0]}</td><td style='border: 1px solid #ddd; padding: 8px;'>{row[1]}</td><td style='border: 1px solid #ddd; padding: 8px;'>{row[2]}</td><td style='border: 1px solid #ddd; padding: 8px;'>{row[3]}</td></tr>"
156
+
157
+ elif table_name == 'employees':
158
+ html += "<tr style='background-color: #f0f0f0;'><th style='border: 1px solid #ddd; padding: 8px;'>employee_id</th><th style='border: 1px solid #ddd; padding: 8px;'>name</th><th style='border: 1px solid #ddd; padding: 8px;'>department</th><th style='border: 1px solid #ddd; padding: 8px;'>salary</th></tr>"
159
+ for row in data:
160
+ html += f"<tr><td style='border: 1px solid #ddd; padding: 8px;'>{row[0]}</td><td style='border: 1px solid #ddd; padding: 8px;'>{row[1]}</td><td style='border: 1px solid #ddd; padding: 8px;'>{row[2]}</td><td style='border: 1px solid #ddd; padding: 8px;'>{row[3]}</td></tr>"
161
+
162
+ html += "</table>"
163
+
164
+ html += "</div>"
165
+ return html
166
+
167
+ def format_execution_result(results, columns):
168
+ """Format SQL execution results as HTML table"""
169
+ if not results:
170
+ return "<p>No results found.</p>"
171
+
172
+ html = "<div style='margin: 20px 0;'>"
173
+ html += "<h3>βœ… SQL Execution Result (Final Output Table)</h3>"
174
+ html += "<table style='border-collapse: collapse; width: 100%;'>"
175
+
176
+ # Header
177
+ html += "<tr style='background-color: #4CAF50; color: white;'>"
178
+ for col in columns:
179
+ html += f"<th style='border: 1px solid #ddd; padding: 8px;'>{col}</th>"
180
+ html += "</tr>"
181
+
182
+ # Rows
183
+ for row in results:
184
+ html += "<tr>"
185
+ for cell in row:
186
+ html += f"<td style='border: 1px solid #ddd; padding: 8px;'>{cell}</td>"
187
+ html += "</tr>"
188
+
189
+ html += "</table></div>"
190
+ return html
191
+
192
+ def process_query(api_key, user_query):
193
  """Main function to process natural language query"""
194
  if not api_key:
195
+ return "❌ Please enter your Groq API key", "", "", ""
196
 
197
+ if not user_query:
198
+ return "❌ Please enter a query", "", "", ""
199
 
200
  try:
201
+ # Step 1: Generate SQL using Groq
202
  client = Groq(api_key=api_key)
203
 
 
 
 
 
 
 
204
  response = client.chat.completions.create(
205
  model="moonshotai/kimi-k2-instruct-0905",
206
  messages=[
207
  {
208
  "role": "system",
209
+ "content": "You are a SQL expert. Generate structured SQL queries from natural language descriptions with proper syntax validation and metadata. Use standard SQL syntax compatible with MySQL.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  },
211
+ {"role": "user", "content": user_query},
212
  ],
213
  response_format={
214
+ "type": "json_schema",
215
+ "json_schema": {
216
+ "name": "sql_query_generation",
217
+ "schema": SQLQueryGeneration.model_json_schema()
218
+ }
219
+ }
220
  )
221
 
222
+ sql_query_generation = SQLQueryGeneration.model_validate(
223
+ json.loads(response.choices[0].message.content)
224
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
+ # Step 2: Format the structured output
227
+ step1_output = f"""
228
+ ## 🎯 Step 1: Understand User Intent
229
+
230
+ **User Query:** "{user_query}"
231
+
232
+ **Identified Components:**
233
+ - **Tables:** {', '.join(sql_query_generation.tables_used)}
234
+ - **Query Type:** {sql_query_generation.query_type}
235
+ - **Complexity:** {sql_query_generation.estimated_complexity}
236
+ """
237
+
238
+ step2_output = f"""
239
+ ## πŸ”§ Step 2: Generate Structured SQL
240
+
241
+ ```json
242
+ {json.dumps(sql_query_generation.model_dump(), indent=2)}
243
+ ```
244
+
245
+ **Generated SQL Query:**
246
+ ```sql
247
+ {sql_query_generation.query}
248
+ ```
249
+ """
250
+
251
+ # Step 3: Generate sample data
252
+ conn, cursor, sample_data = generate_sample_data(
253
+ sql_query_generation.query,
254
+ sql_query_generation.tables_used
255
+ )
256
 
257
+ step3_output = f"""
258
+ ## πŸ“Š Step 3: Auto-Generate Sample Database Tables
259
+
260
+ {format_sample_tables(sample_data)}
261
+ """
262
 
263
+ # Step 4: Execute query
264
+ results, columns, error = execute_sql_query(cursor, sql_query_generation.query)
265
 
266
  if error:
267
+ step4_output = f"""
268
+ ## ⚠️ Step 4: SQL Execution
269
+
270
+ **Error:** {error}
271
+
272
+ **Note:** The query might use database-specific functions. The sample execution uses SQLite.
273
+ """
274
  else:
275
+ step4_output = f"""
276
+ ## πŸš€ Step 4: Execute Generated SQL on Sample Tables
277
+
278
+ **Applied Conditions:**
279
+ {chr(10).join([f"- {note}" for note in sql_query_generation.execution_notes])}
280
+
281
+ {format_execution_result(results, columns)}
282
+
283
+ **Total Rows Returned:** {len(results)}
284
+ """
285
 
286
  conn.close()
287
 
288
+ return step1_output, step2_output, step3_output, step4_output
 
 
 
289
 
290
  except Exception as e:
291
+ error_msg = f"❌ **Error:** {str(e)}"
292
+ return error_msg, "", "", ""
293
 
294
+ # Create Gradio interface
295
+ with gr.Blocks(title="Natural Language to SQL Query Executor", theme=gr.themes.Soft()) as app:
296
  gr.Markdown("""
297
+ # πŸ” Natural Language to SQL Query Executor
298
 
299
+ Convert natural language queries to SQL, auto-generate sample data, and execute queries!
300
 
301
  **Example queries to try:**
302
  - "Find all customers who made orders over $500 in the last 30 days, show their name, email, and total order amount"
303
+ - "Get all employees in the Engineering department with salary above 80000"
304
+ - "Show top 5 products by price"
 
 
 
 
305
  """)
306
 
307
  with gr.Row():
308
+ with gr.Column():
309
  api_key_input = gr.Textbox(
310
  label="πŸ”‘ Groq API Key",
 
311
  placeholder="Enter your Groq API key here...",
312
+ type="password"
313
  )
 
314
  query_input = gr.Textbox(
315
  label="πŸ’¬ Natural Language Query",
316
+ placeholder="Enter your query in plain English...",
317
  lines=3
318
  )
 
319
  submit_btn = gr.Button("πŸš€ Generate & Execute SQL", variant="primary", size="lg")
 
 
 
 
 
 
 
 
320
 
321
  with gr.Row():
322
  with gr.Column():
323
+ step1_output = gr.Markdown(label="Step 1: Understanding")
324
+ step2_output = gr.Markdown(label="Step 2: SQL Generation")
325
+ step3_output = gr.HTML(label="Step 3: Sample Data")
326
+ step4_output = gr.HTML(label="Step 4: Execution Results")
327
 
 
 
 
 
 
 
 
 
 
 
328
  submit_btn.click(
329
+ fn=process_query,
330
  inputs=[api_key_input, query_input],
331
+ outputs=[step1_output, step2_output, step3_output, step4_output]
332
  )
333
 
334
  gr.Markdown("""
335
  ---
336
+ ### πŸ“ How it works:
337
+ 1. **Understand Intent:** Analyzes your natural language query
338
+ 2. **Generate SQL:** Creates structured SQL with metadata
339
+ 3. **Create Sample Data:** Auto-generates realistic sample tables
340
+ 4. **Execute & Display:** Runs the query and shows results
341
+
342
+ ### πŸ”— Get your Groq API key:
343
+ Visit [console.groq.com](https://console.groq.com) to get your free API key!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
  """)
345
 
 
346
  if __name__ == "__main__":
347
+ app.launch()