Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,6 +6,8 @@ import sqlite3
|
|
| 6 |
import pandas as pd
|
| 7 |
from typing import List, Optional
|
| 8 |
import re
|
|
|
|
|
|
|
| 9 |
|
| 10 |
# Pydantic Models
|
| 11 |
class ValidationStatus(BaseModel):
|
|
@@ -30,9 +32,14 @@ def generate_sample_data(user_query: str, groq_api_key: str) -> dict:
|
|
| 30 |
try:
|
| 31 |
client = Groq(api_key=groq_api_key)
|
| 32 |
|
|
|
|
|
|
|
|
|
|
| 33 |
# Request to generate table schema and sample data
|
| 34 |
schema_prompt = f"""Based on this query: "{user_query}"
|
| 35 |
-
|
|
|
|
|
|
|
| 36 |
Generate a realistic database schema with sample data. Return ONLY valid JSON with this structure:
|
| 37 |
{{
|
| 38 |
"tables": [
|
|
@@ -44,18 +51,26 @@ Generate a realistic database schema with sample data. Return ONLY valid JSON wi
|
|
| 44 |
],
|
| 45 |
"sample_data": [
|
| 46 |
{{"column_name": value, ...}},
|
| 47 |
-
...at least
|
| 48 |
]
|
| 49 |
}}
|
| 50 |
]
|
| 51 |
}}
|
| 52 |
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
response = client.chat.completions.create(
|
| 56 |
model="moonshotai/kimi-k2-instruct-0905",
|
| 57 |
messages=[
|
| 58 |
-
{"role": "system", "content": "You are a database expert. Generate realistic table schemas and sample data. Return ONLY valid JSON, no markdown formatting."},
|
| 59 |
{"role": "user", "content": schema_prompt}
|
| 60 |
],
|
| 61 |
temperature=0.7
|
|
@@ -68,10 +83,87 @@ Make the data realistic and relevant to the query. Include enough variety to mak
|
|
| 68 |
content = re.sub(r'```\s*$', '', content)
|
| 69 |
|
| 70 |
schema_data = json.loads(content)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
return schema_data
|
| 72 |
except Exception as e:
|
| 73 |
raise Exception(f"Error generating sample data: {str(e)}")
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
def create_tables_in_db(schema_data: dict) -> sqlite3.Connection:
|
| 76 |
"""Create SQLite tables and populate with sample data"""
|
| 77 |
conn = sqlite3.connect(':memory:')
|
|
@@ -114,14 +206,14 @@ def generate_sql_query(user_query: str, groq_api_key: str, schema_info: str) ->
|
|
| 114 |
|
| 115 |
User Request: {user_query}
|
| 116 |
|
| 117 |
-
Generate a SQL query that works with the above schema."""
|
| 118 |
|
| 119 |
response = client.chat.completions.create(
|
| 120 |
model="moonshotai/kimi-k2-instruct-0905",
|
| 121 |
messages=[
|
| 122 |
{
|
| 123 |
"role": "system",
|
| 124 |
-
"content": "You are a SQL expert. Generate structured SQL queries from natural language descriptions with proper syntax validation and metadata. Use standard SQL syntax compatible with SQLite.",
|
| 125 |
},
|
| 126 |
{"role": "user", "content": enhanced_query},
|
| 127 |
],
|
|
@@ -190,9 +282,9 @@ def process_query(user_query: str, groq_api_key: str):
|
|
| 190 |
# Display sample data
|
| 191 |
sample_tables_html = []
|
| 192 |
for table in schema_data['tables']:
|
| 193 |
-
df_sample = pd.DataFrame(table['sample_data'][:
|
| 194 |
-
sample_tables_html.append(f"<h4>Sample Data from '{table['table_name']}' (first
|
| 195 |
-
sample_tables_html.append(df_sample.to_html(index=False, border=1))
|
| 196 |
|
| 197 |
# Step 3: Generate SQL query
|
| 198 |
output_log.append("### Step 3: Generating SQL Query")
|
|
@@ -221,11 +313,13 @@ def process_query(user_query: str, groq_api_key: str):
|
|
| 221 |
result_df = execute_sql_query(conn, sql_generation.query)
|
| 222 |
|
| 223 |
if len(result_df) == 0:
|
| 224 |
-
output_log.append("
|
| 225 |
-
|
|
|
|
| 226 |
else:
|
| 227 |
output_log.append(f"✅ Query executed successfully! Returned {len(result_df)} row(s)\n")
|
| 228 |
-
result_html = result_df
|
|
|
|
| 229 |
|
| 230 |
conn.close()
|
| 231 |
|
|
@@ -239,32 +333,56 @@ def process_query(user_query: str, groq_api_key: str):
|
|
| 239 |
error_msg = f"❌ Error: {str(e)}"
|
| 240 |
return error_msg, "", "", "", ""
|
| 241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
# Gradio Interface
|
| 243 |
-
with gr.Blocks(title="SQL Query Generator
|
| 244 |
gr.Markdown("""
|
| 245 |
-
#
|
| 246 |
|
| 247 |
-
|
| 248 |
-
1.
|
| 249 |
-
2.
|
| 250 |
-
3.
|
|
|
|
| 251 |
|
| 252 |
### How to use:
|
| 253 |
-
1. Enter your Groq API key ([Get one here](https://console.groq.com/keys))
|
| 254 |
-
2.
|
| 255 |
-
3. Click "Generate & Execute SQL"
|
| 256 |
""")
|
| 257 |
|
| 258 |
with gr.Row():
|
| 259 |
with gr.Column(scale=2):
|
| 260 |
api_key_input = gr.Textbox(
|
| 261 |
-
label="Groq API Key",
|
| 262 |
placeholder="Enter your Groq API key here...",
|
| 263 |
type="password"
|
| 264 |
)
|
| 265 |
|
| 266 |
query_input = gr.Textbox(
|
| 267 |
-
label="Natural Language Query",
|
| 268 |
placeholder="Example: Find all customers who made orders over $500 in the last 30 days, show their name, email, and total order amount",
|
| 269 |
lines=3
|
| 270 |
)
|
|
@@ -302,10 +420,12 @@ with gr.Blocks(title="SQL Query Generator & Executor", theme=gr.themes.Soft()) a
|
|
| 302 |
["List all products that are out of stock along with their supplier information"],
|
| 303 |
["Show the top 5 employees by total sales in the last quarter"],
|
| 304 |
["Find all students who scored above 85% in Mathematics and their contact details"],
|
| 305 |
-
["Get all active users who haven't logged in for more than 60 days"]
|
|
|
|
|
|
|
| 306 |
],
|
| 307 |
inputs=query_input,
|
| 308 |
-
label="Example Queries"
|
| 309 |
)
|
| 310 |
|
| 311 |
submit_btn.click(
|
|
@@ -313,6 +433,15 @@ with gr.Blocks(title="SQL Query Generator & Executor", theme=gr.themes.Soft()) a
|
|
| 313 |
inputs=[query_input, api_key_input],
|
| 314 |
outputs=[process_output, sql_output, sample_data_output, result_output, gr.Textbox(visible=False)]
|
| 315 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
|
| 317 |
if __name__ == "__main__":
|
| 318 |
app.launch()
|
|
|
|
| 6 |
import pandas as pd
|
| 7 |
from typing import List, Optional
|
| 8 |
import re
|
| 9 |
+
from datetime import datetime, timedelta
|
| 10 |
+
import random
|
| 11 |
|
| 12 |
# Pydantic Models
|
| 13 |
class ValidationStatus(BaseModel):
|
|
|
|
| 32 |
try:
|
| 33 |
client = Groq(api_key=groq_api_key)
|
| 34 |
|
| 35 |
+
# Get current date for context
|
| 36 |
+
today = datetime.now().strftime('%Y-%m-%d')
|
| 37 |
+
|
| 38 |
# Request to generate table schema and sample data
|
| 39 |
schema_prompt = f"""Based on this query: "{user_query}"
|
| 40 |
+
|
| 41 |
+
Current date: {today}
|
| 42 |
+
|
| 43 |
Generate a realistic database schema with sample data. Return ONLY valid JSON with this structure:
|
| 44 |
{{
|
| 45 |
"tables": [
|
|
|
|
| 51 |
],
|
| 52 |
"sample_data": [
|
| 53 |
{{"column_name": value, ...}},
|
| 54 |
+
...at least 15-20 rows
|
| 55 |
]
|
| 56 |
}}
|
| 57 |
]
|
| 58 |
}}
|
| 59 |
|
| 60 |
+
IMPORTANT INSTRUCTIONS FOR REALISTIC DATA:
|
| 61 |
+
1. For DATE columns: Use dates in format 'YYYY-MM-DD'. Include dates from the last 60 days to ensure some fall within "last 30 days"
|
| 62 |
+
2. For queries mentioning "last X days": Generate at least 50% of dates within that timeframe
|
| 63 |
+
3. For queries with amount/price filters (e.g., "over $500"): Ensure at least 40% of records meet the criteria
|
| 64 |
+
4. For queries with thresholds: Create data both above AND below the threshold
|
| 65 |
+
5. Use realistic names, emails, and values
|
| 66 |
+
6. Make sure there's enough variety in the data to produce meaningful query results
|
| 67 |
+
|
| 68 |
+
Example date range: from {(datetime.now() - timedelta(days=60)).strftime('%Y-%m-%d')} to {today}"""
|
| 69 |
|
| 70 |
response = client.chat.completions.create(
|
| 71 |
model="moonshotai/kimi-k2-instruct-0905",
|
| 72 |
messages=[
|
| 73 |
+
{"role": "system", "content": "You are a database expert. Generate realistic table schemas and sample data that will produce meaningful query results. Return ONLY valid JSON, no markdown formatting."},
|
| 74 |
{"role": "user", "content": schema_prompt}
|
| 75 |
],
|
| 76 |
temperature=0.7
|
|
|
|
| 83 |
content = re.sub(r'```\s*$', '', content)
|
| 84 |
|
| 85 |
schema_data = json.loads(content)
|
| 86 |
+
|
| 87 |
+
# Post-process: Enhance data to ensure query results
|
| 88 |
+
schema_data = enhance_sample_data(schema_data, user_query)
|
| 89 |
+
|
| 90 |
return schema_data
|
| 91 |
except Exception as e:
|
| 92 |
raise Exception(f"Error generating sample data: {str(e)}")
|
| 93 |
|
| 94 |
+
def enhance_sample_data(schema_data: dict, user_query: str, min_rows: int = 15) -> dict:
    """Rewrite generated sample rows so the user's query is likely to match some.

    For every table in ``schema_data['tables']``:

    * if the query mentions a known time phrase (e.g. "last 30 days"), every
      DATE column value is replaced with a random date inside that window;
    * if the query contains "over <number>", every amount/price/total column
      value is replaced with a random value, roughly 60% above and 40% below
      the threshold;
    * the table is padded with copies of its own rows until it has at least
      ``min_rows`` rows, giving integer id columns fresh values.

    Args:
        schema_data: Parsed schema dict. Each table needs a ``columns`` list of
            dicts with ``name`` (and normally ``type``) and a ``sample_data``
            list of row dicts.
        user_query: The natural-language request the data should satisfy.
        min_rows: Minimum number of rows per table after padding. Tables that
            start with no rows are left empty (there is nothing to copy).

    Returns:
        The same ``schema_data`` object, mutated in place.

    Raises:
        KeyError: If a table lacks ``columns``/``sample_data`` or a column
            lacks ``name``.
    """
    # Phrase -> look-back window in days; the first phrase found wins.
    time_keywords = {
        'last 30 days': 30,
        'last 60 days': 60,
        'last 7 days': 7,
        'last week': 7,
        'last month': 30,
        'last quarter': 90,
        'last year': 365
    }

    query_text = user_query.lower()
    days_back = next(
        (days for phrase, days in time_keywords.items() if phrase in query_text),
        None,
    )

    # Only "over <number>" (optionally with a dollar sign) counts as a threshold.
    amount_match = re.search(r'over \$?(\d+)', query_text)
    threshold_amount = int(amount_match.group(1)) if amount_match else None

    now = datetime.now()

    for table in schema_data['tables']:
        columns = table['columns']
        original_data = table['sample_data']

        # Compare declared types case-insensitively: the schema comes from an
        # LLM and may say "date" / "integer" instead of "DATE" / "INTEGER".
        date_cols = [
            col['name'] for col in columns
            if str(col.get('type', '')).upper() == 'DATE'
        ]
        amount_cols = [
            col['name'] for col in columns
            if any(tag in col['name'].lower() for tag in ('amount', 'price', 'total'))
        ]
        # NOTE(review): the substring test also matches foreign keys and names
        # such as "paid"; kept as in the original heuristic.
        id_cols = [
            col['name'] for col in columns
            if 'id' in col['name'].lower()
            and str(col.get('type', '')).upper() == 'INTEGER'
        ]

        enhanced_data = []
        for row in original_data:
            new_row = row.copy()

            # Pull every date into the requested look-back window.
            if days_back:
                for date_col in date_cols:
                    if date_col in new_row:
                        offset = random.randint(0, days_back)
                        new_row[date_col] = (now - timedelta(days=offset)).strftime('%Y-%m-%d')

            # Spread amounts on both sides of the threshold (60% above).
            if threshold_amount:
                for amount_col in amount_cols:
                    if amount_col in new_row:
                        if random.random() < 0.6:
                            low, high = threshold_amount * 1.1, threshold_amount * 3
                        else:
                            low, high = threshold_amount * 0.3, threshold_amount * 0.9
                        new_row[amount_col] = round(random.uniform(low, high), 2)

            enhanced_data.append(new_row)

        # Pad by cycling through the original rows. Skipped for empty tables,
        # which previously crashed with a modulo-by-zero.
        source_count = len(enhanced_data)
        if source_count:
            # Continue each id column after its current maximum so padded rows
            # never reuse an id that already exists in the table.
            next_id = {
                name: max(
                    (r[name] for r in enhanced_data if isinstance(r.get(name), int)),
                    default=0,
                ) + 1
                for name in id_cols
            }
            while len(enhanced_data) < min_rows:
                clone = enhanced_data[len(enhanced_data) % source_count].copy()
                for name in id_cols:
                    clone[name] = next_id[name]
                    next_id[name] += 1
                enhanced_data.append(clone)

        table['sample_data'] = enhanced_data

    return schema_data
|
| 166 |
+
|
| 167 |
def create_tables_in_db(schema_data: dict) -> sqlite3.Connection:
|
| 168 |
"""Create SQLite tables and populate with sample data"""
|
| 169 |
conn = sqlite3.connect(':memory:')
|
|
|
|
| 206 |
|
| 207 |
User Request: {user_query}
|
| 208 |
|
| 209 |
+
Generate a SQL query that works with the above schema. Use SQLite-compatible syntax."""
|
| 210 |
|
| 211 |
response = client.chat.completions.create(
|
| 212 |
model="moonshotai/kimi-k2-instruct-0905",
|
| 213 |
messages=[
|
| 214 |
{
|
| 215 |
"role": "system",
|
| 216 |
+
"content": "You are a SQL expert. Generate structured SQL queries from natural language descriptions with proper syntax validation and metadata. Use standard SQL syntax compatible with SQLite. For date operations, use SQLite functions like date('now') and datetime().",
|
| 217 |
},
|
| 218 |
{"role": "user", "content": enhanced_query},
|
| 219 |
],
|
|
|
|
| 282 |
# Display sample data
|
| 283 |
sample_tables_html = []
|
| 284 |
for table in schema_data['tables']:
|
| 285 |
+
df_sample = pd.DataFrame(table['sample_data'][:10]) # Show first 10 rows
|
| 286 |
+
sample_tables_html.append(f"<h4>Sample Data from '{table['table_name']}' (first 10 rows):</h4>")
|
| 287 |
+
sample_tables_html.append(df_sample.to_html(index=False, border=1, classes='table table-striped'))
|
| 288 |
|
| 289 |
# Step 3: Generate SQL query
|
| 290 |
output_log.append("### Step 3: Generating SQL Query")
|
|
|
|
| 313 |
result_df = execute_sql_query(conn, sql_generation.query)
|
| 314 |
|
| 315 |
if len(result_df) == 0:
|
| 316 |
+
output_log.append("⚠️ Query executed successfully but returned 0 rows")
|
| 317 |
+
output_log.append("This might happen if the sample data doesn't match the query criteria.")
|
| 318 |
+
result_html = "<p><i>No results found. The query executed successfully but no data matched the criteria.</i></p>"
|
| 319 |
else:
|
| 320 |
output_log.append(f"✅ Query executed successfully! Returned {len(result_df)} row(s)\n")
|
| 321 |
+
result_html = f"<h4>Query Results ({len(result_df)} rows):</h4>"
|
| 322 |
+
result_html += result_df.to_html(index=False, border=1, classes='table table-striped')
|
| 323 |
|
| 324 |
conn.close()
|
| 325 |
|
|
|
|
| 333 |
error_msg = f"❌ Error: {str(e)}"
|
| 334 |
return error_msg, "", "", "", ""
|
| 335 |
|
| 336 |
+
# Custom CSS for better table styling
|
| 337 |
+
custom_css = """
|
| 338 |
+
.table {
|
| 339 |
+
width: 100%;
|
| 340 |
+
border-collapse: collapse;
|
| 341 |
+
margin: 10px 0;
|
| 342 |
+
}
|
| 343 |
+
.table th {
|
| 344 |
+
background-color: #f0f0f0;
|
| 345 |
+
font-weight: bold;
|
| 346 |
+
padding: 8px;
|
| 347 |
+
text-align: left;
|
| 348 |
+
border: 1px solid #ddd;
|
| 349 |
+
}
|
| 350 |
+
.table td {
|
| 351 |
+
padding: 8px;
|
| 352 |
+
border: 1px solid #ddd;
|
| 353 |
+
}
|
| 354 |
+
.table-striped tbody tr:nth-child(odd) {
|
| 355 |
+
background-color: #f9f9f9;
|
| 356 |
+
}
|
| 357 |
+
"""
|
| 358 |
+
|
| 359 |
# Gradio Interface
|
| 360 |
+
with gr.Blocks(title="SQLGenie - AI SQL Query Generator", theme=gr.themes.Ocean(), css=custom_css) as app:
|
| 361 |
gr.Markdown("""
|
| 362 |
+
# ⚡ SQLGenie - AI SQL Query Generator & Executor
|
| 363 |
|
| 364 |
+
Transform natural language into SQL queries and see instant results! This app:
|
| 365 |
+
1. 🎲 Generates realistic sample database tables based on your query
|
| 366 |
+
2. 🧠 Creates a structured SQL query from natural language using AI
|
| 367 |
+
3. ⚙️ Executes the query on sample data
|
| 368 |
+
4. π Shows you the results instantly
|
| 369 |
|
| 370 |
### How to use:
|
| 371 |
+
1. Enter your Groq API key ([Get one free here](https://console.groq.com/keys))
|
| 372 |
+
2. Describe what data you want in plain English
|
| 373 |
+
3. Click "Generate & Execute SQL" and watch the magic happen! ✨
|
| 374 |
""")
|
| 375 |
|
| 376 |
with gr.Row():
|
| 377 |
with gr.Column(scale=2):
|
| 378 |
api_key_input = gr.Textbox(
|
| 379 |
+
label="π Groq API Key",
|
| 380 |
placeholder="Enter your Groq API key here...",
|
| 381 |
type="password"
|
| 382 |
)
|
| 383 |
|
| 384 |
query_input = gr.Textbox(
|
| 385 |
+
label="💬 Natural Language Query",
|
| 386 |
placeholder="Example: Find all customers who made orders over $500 in the last 30 days, show their name, email, and total order amount",
|
| 387 |
lines=3
|
| 388 |
)
|
|
|
|
| 420 |
["List all products that are out of stock along with their supplier information"],
|
| 421 |
["Show the top 5 employees by total sales in the last quarter"],
|
| 422 |
["Find all students who scored above 85% in Mathematics and their contact details"],
|
| 423 |
+
["Get all active users who haven't logged in for more than 60 days"],
|
| 424 |
+
["Show all transactions above $1000 in the last week with customer details"],
|
| 425 |
+
["Find employees in the Engineering department with salary over $80000"]
|
| 426 |
],
|
| 427 |
inputs=query_input,
|
| 428 |
+
label="💡 Example Queries - Click to try!"
|
| 429 |
)
|
| 430 |
|
| 431 |
submit_btn.click(
|
|
|
|
| 433 |
inputs=[query_input, api_key_input],
|
| 434 |
outputs=[process_output, sql_output, sample_data_output, result_output, gr.Textbox(visible=False)]
|
| 435 |
)
|
| 436 |
+
|
| 437 |
+
gr.Markdown("""
|
| 438 |
+
---
|
| 439 |
+
### 🎯 Tips for Best Results:
|
| 440 |
+
- Be specific about time periods (e.g., "last 30 days", "last quarter")
|
| 441 |
+
- Mention thresholds clearly (e.g., "over $500", "above 85%")
|
| 442 |
+
- Specify what fields you want to see (e.g., "show name, email, total")
|
| 443 |
+
- The app generates realistic sample data automatically to match your query!
|
| 444 |
+
""")
|
| 445 |
|
| 446 |
if __name__ == "__main__":
|
| 447 |
app.launch()
|