# SQLGenie / app.py
# Author: shukdevdattaEX. Last commit: 82a80ce "Update app.py" (verified). Size: 22.2 kB.
import gradio as gr
from groq import Groq
from pydantic import BaseModel
import json
import sqlite3
import pandas as pd
from typing import List, Optional
import re
from datetime import datetime, timedelta
import random
# Pydantic Models
# NOTE: SQLQueryGeneration.model_json_schema() (which embeds this model) is
# sent to the API as the response_format schema, so field names and order here
# shape the model's reply. Field docs are kept in `#` comments because a class
# docstring would be emitted as a schema `description`.
class ValidationStatus(BaseModel):
    # Model-reported verdict on whether the generated SQL is valid.
    is_valid: bool
    # Syntax problems reported by the model (presumably empty when is_valid;
    # not enforced here).
    syntax_errors: list[str]
class SQLQueryGeneration(BaseModel):
    # Structured reply of the SQL-generation call: its JSON schema is passed as
    # response_format and the reply is validated back into this model.
    # The SQL text that process_query executes against the sample database.
    query: str
    # Free-form labels filled in by the model; not constrained to an enum here.
    query_type: str
    tables_used: list[str]
    estimated_complexity: str
    # Model commentary, shown in the UI's "SQL Query Metadata" JSON panel.
    execution_notes: list[str]
    validation_status: ValidationStatus
class TableSchema(BaseModel):
    # Shape of one entry of the "tables" list requested by the schema prompt.
    # NOTE(review): not referenced in the visible code; generate_sample_data
    # parses the reply with json.loads into plain dicts instead.
    table_name: str
    # Column dicts, e.g. {"name": ..., "type": "INTEGER|TEXT|REAL|DATE"}.
    columns: list[dict]
    # Row dicts keyed by column name.
    sample_data: list[dict]
def _strip_json_fences(content: str) -> str:
    """Return the JSON payload of an LLM reply with any Markdown code fence removed."""
    text = content.strip()
    # Models often wrap JSON in ```json ... ``` (or a bare ``` ... ```) despite
    # being told not to. Raw JSON starts with "{", so only search for a fenced
    # body otherwise; this leaves backticks inside JSON string values alone.
    if not text.startswith("{"):
        fenced = re.search(r"```[\w-]*\s*(.*?)\s*```", text, re.DOTALL)
        if fenced:
            return fenced.group(1)
    # Unbalanced fence (only an opening or only a closing ```): drop it.
    return re.sub(r"^```[\w-]*\s*|\s*```$", "", text)


def generate_sample_data(user_query: str, groq_api_key: str) -> dict:
    """Ask the LLM for a table schema plus sample rows suited to *user_query*.

    The prompt includes today's date and asks for JSON shaped like
    ``{"tables": [{"table_name", "columns", "sample_data"}, ...]}``. Any
    Markdown code fence around the reply is removed, the JSON is parsed, and
    the rows are post-processed by ``enhance_sample_data``.

    Args:
        user_query: Natural-language request the sample data should suit.
        groq_api_key: API key used to construct the Groq client.

    Returns:
        The parsed and enhanced schema dictionary.

    Raises:
        Exception: Wrapping any failure (API error, empty reply, invalid JSON,
            missing ``tables`` list). The original error is chained as
            ``__cause__``.
    """
    try:
        client = Groq(api_key=groq_api_key)
        # Date anchors that keep generated dates in the past.
        now = datetime.now()
        today = now.strftime('%Y-%m-%d')
        past_date_2y = (now - timedelta(days=730)).strftime('%Y-%m-%d')  # 2 years ago
        past_date_60d = (now - timedelta(days=60)).strftime('%Y-%m-%d')  # 60 days ago
        # Request to generate table schema and sample data
        schema_prompt = f"""Based on this query: "{user_query}"
**Current date: {today}**
Generate a realistic database schema with sample data. Return ONLY valid JSON with this structure:
{{
"tables": [
{{
"table_name": "table_name",
"columns": [
{{"name": "column_name", "type": "INTEGER|TEXT|REAL|DATE"}},
...
],
"sample_data": [
{{"column_name": value, ...}},
...at least 20-25 rows
]
}}
]
}}
**CRITICAL INSTRUCTIONS FOR REALISTIC DATA:**
1. **DATES MUST BE IN THE PAST!**
- For hire_date, created_at, registration_date: Use dates between {past_date_2y} and {today}
- For order_date, transaction_date: If query mentions "last X days", use dates between {past_date_60d} and {today}
- NEVER use future dates!
2. **For NUMERIC filters (salary, amount, price):**
- If query says "over $80000", make 50-60% of records have values ABOVE 80000
- Create realistic variation: some at 85k, some at 95k, some at 120k, etc.
- Also include records BELOW the threshold (40-50%) for realism
3. **For TEXT filters (department, category, status):**
- If query mentions "Engineering department", ensure 50-60% of records have department = "Engineering"
- Include other departments too: "Marketing", "Sales", "HR", "Finance" for variety
4. **Data quality:**
- Use realistic names, emails (first.last@company.com format)
- Make data diverse and meaningful
- Ensure enough records match the query criteria to get meaningful results
Example: For "Find Engineering employees with salary > 80000"
- Create 20+ employee records
- 12-15 should be in Engineering (60%)
- Of Engineering employees, 8-10 should have salary > 80000
- Include other departments with various salaries for realism"""
        response = client.chat.completions.create(
            model="moonshotai/kimi-k2-instruct-0905",
            messages=[
                {"role": "system", "content": "You are a database expert. Generate realistic table schemas and sample data. ALL DATES MUST BE IN THE PAST, NEVER IN THE FUTURE. Return ONLY valid JSON, no markdown formatting."},
                {"role": "user", "content": schema_prompt}
            ],
            temperature=0.7
        )
        # `message.content` may be None; treat that as an empty (invalid) reply.
        content = _strip_json_fences(response.choices[0].message.content or "")
        schema_data = json.loads(content)
        # Everything downstream iterates schema_data['tables']; fail clearly here.
        if not isinstance(schema_data, dict) or not isinstance(schema_data.get('tables'), list):
            raise ValueError("response JSON does not contain a 'tables' list")
        # Post-process: enhance and fix data so the query returns results
        schema_data = enhance_sample_data(schema_data, user_query)
        return schema_data
    except Exception as e:
        raise Exception(f"Error generating sample data: {str(e)}") from e
def _detect_days_back(query_lower: str) -> Optional[int]:
    """Return the look-back window (days) named in the query, if any."""
    time_keywords = {
        'last 30 days': 30,
        'last 60 days': 60,
        'last 7 days': 7,
        'last week': 7,
        'last month': 30,
        'last quarter': 90,
        'last year': 365
    }
    for keyword, days in time_keywords.items():
        if keyword in query_lower:
            return days
    return None


def _detect_threshold(query_lower: str) -> Optional[int]:
    """Return N from "over/above/greater than [$]N" in the query, if present."""
    # Accept thousands separators ("$80,000") as well as plain digits; the
    # previous `\d+` stopped at the comma and read "$80,000" as 80.
    match = re.search(r'(?:over|above|greater than) \$?(\d[\d,]*)', query_lower)
    if not match:
        return None
    return int(match.group(1).replace(',', ''))


def _detect_text_filters(query_lower: str) -> dict:
    """Return {column: value} for department/category/status filters in the query."""
    text_filters = {}
    dept_patterns = [
        r'(?:in|from) (?:the )?(\w+) department',
        r'department (?:is |= |== )?["\']?(\w+)["\']?',
        r'(\w+) department',
    ]
    for pattern in dept_patterns:
        dept_match = re.search(pattern, query_lower)
        if dept_match:
            text_filters['department'] = dept_match.group(1).capitalize()
            break
    category_match = re.search(r'category (?:is |= )?["\']?(\w+)["\']?', query_lower)
    if category_match:
        text_filters['category'] = category_match.group(1).capitalize()
    status_match = re.search(r'status (?:is |= )?["\']?(\w+)["\']?', query_lower)
    if status_match:
        text_filters['status'] = status_match.group(1).capitalize()
    return text_filters


def _is_in_share(index: int, percent: int = 55) -> bool:
    """Deterministically select `percent`% of row indices, interleaved."""
    # (index * percent) % 100 cycles through evenly spaced residues, so exactly
    # `percent`% of any 100 // gcd(percent, 100) consecutive indices qualify.
    # The previous `index % 100 < percent` was true for every row of the 20-25
    # row tables this app generates.
    return (index * percent) % 100 < percent


def _is_id_column(name: str) -> bool:
    """Return True for identifier-like column names such as id, user_id, userId."""
    # Token match instead of substring match, so "paid_amount" or "width" are
    # not treated as ID columns.
    tokens = re.split(r"_|(?<=[a-z])(?=[A-Z])", name)
    return any(token.lower() == "id" for token in tokens)


def _random_past_date(min_days: int, max_days: int) -> str:
    """Return a YYYY-MM-DD date between `min_days` and `max_days` days ago."""
    days = random.randint(min_days, max_days)
    return (datetime.now() - timedelta(days=days)).strftime('%Y-%m-%d')


def enhance_sample_data(schema_data: dict, user_query: str) -> dict:
    """Post-process LLM sample rows so the user's query finds matching data.

    Mutates ``schema_data`` in place (each table's ``sample_data`` list is
    replaced) and returns it. Based on phrases in ``user_query`` it:

    * rewrites transaction-style DATE columns (name contains order,
      transaction, purchase, sale or payment) to random dates within the
      requested window ("last 30 days", "last quarter", ...) or the last
      60 days;
    * replaces other DATE values that are in the future, missing or
      unparseable with a date 30-1095 days in the past;
    * for "over/above/greater than N" (``$`` and thousands commas allowed),
      sets amount-like columns strictly above N in ~55% of rows and below N
      in the rest;
    * for detected department/category/status filters, sets that column to
      the target value in the same ~55% of rows (other departments/statuses
      are substituted in the remaining rows);
    * pads each non-empty table to at least 20 rows by cloning rows and
      giving INTEGER ID columns sequential values.

    Args:
        schema_data: Dict with a ``tables`` list of
            ``{"table_name", "columns", "sample_data"}`` entries.
        user_query: The user's natural-language request.

    Returns:
        The same ``schema_data`` object.
    """
    query_lower = user_query.lower()
    days_back = _detect_days_back(query_lower)
    threshold_amount = _detect_threshold(query_lower)
    text_filters = _detect_text_filters(query_lower)

    for table in schema_data['tables']:
        columns = table['columns']
        original_data = table['sample_data']

        # Compare types case-insensitively; the LLM may emit "date" or "Date".
        date_cols = [col['name'] for col in columns if str(col.get('type', '')).upper() == 'DATE']
        amount_cols = [col['name'] for col in columns
                       if any(keyword in col['name'].lower() for keyword in ['amount', 'price', 'salary', 'total', 'cost', 'revenue'])]
        # Order/transaction dates get a recent window; hire/created dates only
        # need to be in the past.
        transaction_date_cols = [col for col in date_cols
                                 if any(keyword in col.lower() for keyword in ['order', 'transaction', 'purchase', 'sale', 'payment'])]
        other_date_cols = [col for col in date_cols if col not in transaction_date_cols]

        enhanced_data = []
        for i, row in enumerate(original_data):
            new_row = row.copy()

            for date_col in transaction_date_cols:
                if date_col in new_row:
                    # Within the requested period, else within the last 60 days.
                    new_row[date_col] = _random_past_date(0, days_back or 60)

            for date_col in other_date_cols:
                if date_col in new_row:
                    try:
                        current_date = datetime.strptime(new_row[date_col], '%Y-%m-%d')
                    except (TypeError, ValueError):
                        # Missing or unparseable date: replace with a past one.
                        new_row[date_col] = _random_past_date(30, 1095)
                    else:
                        if current_date > datetime.now():
                            new_row[date_col] = _random_past_date(30, 1095)

            in_share = _is_in_share(i)

            if threshold_amount and amount_cols:
                for amount_col in amount_cols:
                    if amount_col in new_row:
                        if in_share:
                            # max() keeps the value strictly above the threshold
                            # even when int() truncation would land on it.
                            new_row[amount_col] = max(
                                threshold_amount + 1,
                                int(random.uniform(threshold_amount * 1.05, threshold_amount * 2.5)),
                            )
                        else:
                            new_row[amount_col] = int(random.uniform(threshold_amount * 0.4, threshold_amount * 0.95))

            for col_name, target_value in text_filters.items():
                if col_name not in new_row:
                    continue
                if in_share:
                    new_row[col_name] = target_value
                elif col_name == 'department':
                    other_depts = ['Marketing', 'Sales', 'HR', 'Finance', 'Operations', 'IT']
                    new_row[col_name] = random.choice([d for d in other_depts if d != target_value])
                elif col_name == 'status':
                    other_statuses = ['Active', 'Inactive', 'Pending', 'Completed', 'Cancelled']
                    new_row[col_name] = random.choice([s for s in other_statuses if s != target_value])

            enhanced_data.append(new_row)

        # Pad to at least 20 rows by cloning existing ones; a table the model
        # returned empty has nothing to clone and is left as is.
        id_cols = [col['name'] for col in columns
                   if _is_id_column(col['name']) and str(col.get('type', '')).upper() == 'INTEGER']
        template_count = len(enhanced_data)
        while template_count and len(enhanced_data) < 20:
            template_row = enhanced_data[len(enhanced_data) % template_count].copy()
            for id_col in id_cols:
                template_row[id_col] = len(enhanced_data) + 1
            enhanced_data.append(template_row)

        table['sample_data'] = enhanced_data
    return schema_data
def _quote_identifier(name) -> str:
    """Quote `name` as an SQLite identifier, escaping embedded double quotes."""
    return '"' + str(name).replace('"', '""') + '"'


def create_tables_in_db(schema_data: dict) -> sqlite3.Connection:
    """Create an in-memory SQLite database holding the generated sample tables.

    Table and column names come from the LLM, so they are double-quoted. This
    lets names such as "order" or "group" (SQL keywords) be created. Unquoted
    references to ordinary names in later queries still resolve. Column types
    are upper-cased. Missing row values are inserted as NULL.

    Args:
        schema_data: Dict with a ``tables`` list of
            ``{"table_name", "columns": [{"name", "type"}], "sample_data"}``.

    Returns:
        An open connection; the caller is responsible for closing it.

    Raises:
        sqlite3.Error (or KeyError for malformed input). The connection is
        closed before the error propagates.
    """
    conn = sqlite3.connect(':memory:')
    try:
        cursor = conn.cursor()
        for table in schema_data['tables']:
            table_name = _quote_identifier(table['table_name'])
            columns = table['columns']
            col_names = [col['name'] for col in columns]
            quoted_cols = [_quote_identifier(name) for name in col_names]

            column_defs = ', '.join(
                f"{quoted} {col['type'].upper()}" for quoted, col in zip(quoted_cols, columns)
            )
            cursor.execute(f"CREATE TABLE {table_name} ({column_defs})")

            sample_data = table['sample_data']
            if sample_data:
                placeholders = ', '.join('?' for _ in col_names)
                insert_sql = f"INSERT INTO {table_name} ({', '.join(quoted_cols)}) VALUES ({placeholders})"
                # Values are bound as parameters, never interpolated.
                cursor.executemany(insert_sql, ([row.get(col) for col in col_names] for row in sample_data))
        conn.commit()
    except Exception:
        conn.close()
        raise
    return conn
def generate_sql_query(user_query: str, groq_api_key: str, schema_info: str) -> SQLQueryGeneration:
    """Ask the LLM for a SQLite query answering *user_query* over *schema_info*.

    The reply is constrained with a JSON-schema ``response_format`` derived
    from ``SQLQueryGeneration`` and validated back into that model.

    Args:
        user_query: The user's natural-language request.
        groq_api_key: API key used to construct the Groq client.
        schema_info: Text description of the tables (see format_schema_info).

    Returns:
        The validated ``SQLQueryGeneration``.

    Raises:
        Exception: Wrapping any API, parsing or validation failure. The
            original error is chained as ``__cause__``.
    """
    try:
        client = Groq(api_key=groq_api_key)
        enhanced_query = f"""Database Schema:
{schema_info}
User Request: {user_query}
Generate a SQL query that works with the above schema. Use SQLite-compatible syntax."""
        response = client.chat.completions.create(
            model="moonshotai/kimi-k2-instruct-0905",
            messages=[
                {
                    "role": "system",
                    "content": "You are a SQL expert. Generate structured SQL queries from natural language descriptions with proper syntax validation and metadata. Use standard SQL syntax compatible with SQLite. For date operations, use SQLite functions like date('now') and datetime().",
                },
                {"role": "user", "content": enhanced_query},
            ],
            response_format={
                "type": "json_schema",
                "json_schema": {
                    "name": "sql_query_generation",
                    "schema": SQLQueryGeneration.model_json_schema()
                }
            }
        )
        content = response.choices[0].message.content
        # `message.content` may be None; report that plainly instead of a
        # confusing json.loads TypeError.
        if not content:
            raise ValueError("model returned an empty response")
        return SQLQueryGeneration.model_validate(json.loads(content))
    except Exception as e:
        raise Exception(f"Error generating SQL query: {str(e)}") from e
def execute_sql_query(conn: sqlite3.Connection, query: str) -> pd.DataFrame:
    """Run *query* on *conn* and return the result set as a DataFrame.

    Raises:
        Exception: ``"Error executing SQL query: ..."``, with the underlying
            database/pandas error chained as ``__cause__`` so the root cause
            and its traceback are not lost.
    """
    try:
        return pd.read_sql_query(query, conn)
    except Exception as e:
        raise Exception(f"Error executing SQL query: {str(e)}") from e
def format_schema_info(schema_data: dict) -> str:
    """Render the tables in *schema_data* as plain text.

    Each table is shown as a "Table:" header (preceded by a blank line), its
    columns as " - name (TYPE)" lines, and a sample-row count. The same text is
    used in the process log and as schema context for SQL generation.
    """
    def _describe():
        for tbl in schema_data['tables']:
            yield f"\nTable: {tbl['table_name']}"
            yield "Columns:"
            for column in tbl['columns']:
                yield f" - {column['name']} ({column['type']})"
            yield f"Sample rows: {len(tbl['sample_data'])}"

    return '\n'.join(_describe())
def process_query(user_query: str, groq_api_key: str):
    """Run the whole pipeline behind the "Generate & Execute SQL" button.

    Steps: generate a sample schema with data, load it into an in-memory
    SQLite database, ask the model for a SQL query, execute it, and render
    everything for the UI.

    Args:
        user_query: Natural-language request typed by the user.
        groq_api_key: Groq API key from the password textbox.

    Returns:
        A 5-tuple matching the click handler's outputs: process-log text,
        SQL metadata dict (or None on failure), sample-data HTML, results
        HTML, and an empty string for the hidden textbox. Errors are reported
        in the log text rather than raised.
    """
    if not groq_api_key or not groq_api_key.strip():
        return "❌ Please enter your Groq API key", None, "", "", ""
    if not user_query or not user_query.strip():
        return "❌ Please enter a query", None, "", "", ""

    conn = None
    try:
        output_log = []

        # Step 1: Generate sample data
        output_log.append("### Step 1: Generating Sample Database Schema and Data")
        output_log.append(f"Query: {user_query}\n")
        schema_data = generate_sample_data(user_query, groq_api_key)
        schema_info = format_schema_info(schema_data)
        output_log.append("✅ Generated database schema:")
        output_log.append(schema_info)
        output_log.append("")

        # Step 2: Create tables
        output_log.append("### Step 2: Creating In-Memory SQLite Database")
        conn = create_tables_in_db(schema_data)
        output_log.append("✅ Tables created and populated with sample data\n")

        # Preview the first 10 rows of every generated table.
        sample_tables_html = []
        for table in schema_data['tables']:
            df_sample = pd.DataFrame(table['sample_data'][:10])
            sample_tables_html.append(f"<h4>Sample Data from '{table['table_name']}' (first 10 rows):</h4>")
            sample_tables_html.append(df_sample.to_html(index=False, border=1, classes='table table-striped'))

        # Step 3: Generate SQL query
        output_log.append("### Step 3: Generating SQL Query")
        sql_generation = generate_sql_query(user_query, groq_api_key, schema_info)
        # Nested dict of the model's fields, rendered by the gr.JSON output.
        sql_output = sql_generation.model_dump()
        output_log.append("✅ SQL Query Generated:\n")

        # Step 4: Execute query
        output_log.append("\n### Step 4: Executing SQL Query")
        output_log.append(f"Executing: {sql_generation.query}\n")
        result_df = execute_sql_query(conn, sql_generation.query)
        if len(result_df) == 0:
            output_log.append("⚠️ Query executed successfully but returned 0 rows")
            output_log.append("This might happen if the sample data doesn't match the query criteria.")
            result_html = "<p><i>No results found. The query executed successfully but no data matched the criteria.</i></p>"
        else:
            output_log.append(f"✅ Query executed successfully! Returned {len(result_df)} row(s)\n")
            result_html = f"<h4>Query Results ({len(result_df)} rows):</h4>"
            result_html += result_df.to_html(index=False, border=1, classes='table table-striped')

        process_log = '\n'.join(output_log)
        sample_data_html = '\n'.join(sample_tables_html)
        return process_log, sql_output, sample_data_html, result_html, ""
    except Exception as e:
        error_msg = f"❌ Error: {str(e)}"
        return error_msg, None, "", "", ""
    finally:
        # Release the in-memory database on success and on every error path
        # (previously it leaked when SQL generation or execution failed).
        if conn is not None:
            conn.close()
# Custom CSS for better table styling.
# Targets the `table` / `table-striped` classes that process_query passes to
# DataFrame.to_html(classes='table table-striped'), so the sample-data and
# result tables share one look. Injected via gr.Blocks(css=custom_css).
custom_css = """
.table {
width: 100%;
border-collapse: collapse;
margin: 10px 0;
font-size: 14px;
}
.table th {
background-color: #4a5568;
color: white;
font-weight: bold;
padding: 10px;
text-align: left;
border: 1px solid #2d3748;
}
.table td {
padding: 8px 10px;
border: 1px solid #e2e8f0;
}
.table-striped tbody tr:nth-child(odd) {
background-color: #f7fafc;
}
.table-striped tbody tr:nth-child(even) {
background-color: #ffffff;
}
.table-striped tbody tr:hover {
background-color: #edf2f7;
}
"""
# Gradio Interface
# Layout (top to bottom): intro, API key + query inputs, process log, sample
# tables, SQL metadata, results, examples, tips. Component creation order is
# the on-screen order, so keep it.
# Emoji in labels were mis-encoded (UTF-8 bytes decoded as cp1253, e.g. "βœ…");
# they are restored to the intended characters here.
with gr.Blocks(title="SQLGenie - AI SQL Query Generator", theme=gr.themes.Ocean(), css=custom_css) as app:
    gr.Markdown("""
# ⚑ SQLGenie - AI SQL Query Generator & Executor
Transform natural language into SQL queries and see instant results! This app:
1. 🎲 Generates realistic sample database tables based on your query
2. 🧙 Creates a structured SQL query from natural language using AI
3. ⚙️ Executes the query on sample data
4. 📊 Shows you the results instantly
### How to use:
1. Enter your Groq API key ([Get one free here](https://console.groq.com/keys))
2. Describe what data you want in plain English
3. Click "Generate & Execute SQL" and watch the magic happen! ✨
""")
    with gr.Row():
        with gr.Column(scale=2):
            api_key_input = gr.Textbox(
                label="🔑 Groq API Key",
                placeholder="Enter your Groq API key here...",
                type="password"
            )
            query_input = gr.Textbox(
                label="💬 Natural Language Query",
                placeholder="Example: Find all customers who made orders over $500 in the last 30 days, show their name, email, and total order amount",
                lines=3
            )
            submit_btn = gr.Button("🚀 Generate & Execute SQL", variant="primary", size="lg")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 📋 Process Log")
            process_output = gr.Textbox(
                label="Execution Steps",
                lines=12,
                max_lines=20
            )
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🗂️ Sample Database Tables")
            sample_data_output = gr.HTML(label="Sample Data")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 📝 Generated SQL Query (Structured Output)")
            sql_output = gr.JSON(label="SQL Query Metadata")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### ✨ Query Execution Results")
            result_output = gr.HTML(label="Results")
    # Examples
    gr.Examples(
        examples=[
            ["Find all customers who made orders over $500 in the last 30 days, show their name, email, and total order amount"],
            ["List all products that are out of stock along with their supplier information"],
            ["Show the top 5 employees by total sales in the last quarter"],
            ["Find all students who scored above 85% in Mathematics and their contact details"],
            ["Get all active users who haven't logged in for more than 60 days"],
            ["Show all transactions above $1000 in the last week with customer details"],
            ["Find employees in the Engineering department with salary over $80000"]
        ],
        inputs=query_input,
        label="💡 Example Queries - Click to try!"
    )
    # process_query returns 5 values; the last one feeds this invisible textbox.
    submit_btn.click(
        fn=process_query,
        inputs=[query_input, api_key_input],
        outputs=[process_output, sql_output, sample_data_output, result_output, gr.Textbox(visible=False)]
    )
    gr.Markdown("""
---
### 🎯 Tips for Best Results:
- Be specific about time periods (e.g., "last 30 days", "last quarter")
- Mention thresholds clearly (e.g., "over $500", "above 85%")
- Specify what fields you want to see (e.g., "show name, email, total")
- The app generates realistic sample data automatically to match your query!
""")
def _launch_app() -> None:
    """Start the Gradio server for the SQLGenie UI built above."""
    app.launch()


if __name__ == "__main__":
    _launch_app()