import gradio as gr
from groq import Groq
from pydantic import BaseModel
import json
import sqlite3
import pandas as pd
from typing import List, Optional
import re
from datetime import datetime, timedelta
import random
import html

# Groq model used for both the sample-data and the SQL-generation requests.
MODEL_NAME = "moonshotai/kimi-k2-instruct-0905"

# Every generated table is padded (by cloning rows) to at least this many rows.
MIN_SAMPLE_ROWS = 20

# Relative time windows recognised in the user query, mapped to days back.
# Checked in insertion order; the first phrase found wins.
_TIME_WINDOWS = {
    'last 30 days': 30,
    'last 60 days': 60,
    'last 7 days': 7,
    'last week': 7,
    'last month': 30,
    'last quarter': 90,
    'last year': 365,
}

# Fallback for arbitrary "last N days" phrases not listed above.
_LAST_N_DAYS_RE = re.compile(r'last (\d+) days?')

# Numeric threshold such as "over $80,000" or "greater than 1000".
# Commas are accepted so "$80,000" is read as 80000 rather than 80.
_THRESHOLD_RE = re.compile(r'(?:over|above|greater than) \$?(\d[\d,]*)')

# Patterns capturing the target value of a text filter, tried in order.
_DEPARTMENT_PATTERNS = (
    r'(?:in|from) (?:the )?(\w+) department',
    r'department (?:is |= |== )?["\']?(\w+)["\']?',
    r'(\w+) department',
)
_CATEGORY_PATTERNS = (r'category (?:is |= )?["\']?(\w+)["\']?',)
_STATUS_PATTERNS = (r'status (?:is |= )?["\']?(\w+)["\']?',)

# Words the loose filter patterns above can capture that are never real filter
# values (e.g. "department with salary ..." or "group by department").
_FILTER_STOPWORDS = frozenset({
    'a', 'an', 'the', 'and', 'or', 'of', 'in', 'from', 'for', 'with', 'by',
    'per', 'each', 'every', 'all', 'any', 'where', 'who', 'whose', 'that',
    'which', 'is', 'are', 'has', 'have',
})

# Filter values conventionally written in capitals ("HR", not "Hr").
_UPPERCASE_VALUES = frozenset({'hr', 'it', 'qa'})

# Column-name fragments identifying monetary columns and transaction dates.
_AMOUNT_KEYWORDS = ('amount', 'price', 'salary', 'total', 'cost', 'revenue')
_TRANSACTION_KEYWORDS = ('order', 'transaction', 'purchase', 'sale', 'payment')

# Declared column types (upper-cased) treated as dates by the post-processing.
_DATE_TYPES = frozenset({'DATE', 'DATETIME', 'TIMESTAMP'})

# Replacement values for rows that should NOT match a text filter.
_ALTERNATIVE_VALUES = {
    'department': ['Marketing', 'Sales', 'HR', 'Finance', 'Operations', 'IT'],
    'status': ['Active', 'Inactive', 'Pending', 'Completed', 'Cancelled'],
}


# Pydantic Models
class ValidationStatus(BaseModel):
    """Syntax-validation verdict reported by the model for a generated query."""

    is_valid: bool
    syntax_errors: list[str]


class SQLQueryGeneration(BaseModel):
    """Structured SQL-generation result requested from the model via JSON schema."""

    query: str
    query_type: str
    tables_used: list[str]
    estimated_complexity: str
    execution_notes: list[str]
    validation_status: ValidationStatus


class TableSchema(BaseModel):
    """Schema plus sample rows for one table (not referenced elsewhere in this module)."""

    table_name: str
    columns: list[dict]
    sample_data: list[dict]


def _col_type(col: dict) -> str:
    """Return a column's declared SQL type upper-cased, defaulting to ``TEXT``."""
    return str(col.get('type') or 'TEXT').upper()


def _strip_code_fences(content: str) -> str:
    """Remove a surrounding Markdown code fence (```json or bare ```) from model output."""
    content = content.strip()
    content = re.sub(r'^```(?:json)?\s*', '', content, flags=re.IGNORECASE)
    content = re.sub(r'\s*```$', '', content)
    return content.strip()


def generate_sample_data(user_query: str, groq_api_key: str) -> dict:
    """Generate a sample table schema and rows tailored to *user_query*.

    Args:
        user_query: Natural-language request the sample data should satisfy.
        groq_api_key: API key used to create the Groq client.

    Returns:
        The parsed JSON dict (``{"tables": [...]}``, each table carrying
        ``table_name``, ``columns`` and ``sample_data``) after
        :func:`enhance_sample_data` post-processing.

    Raises:
        Exception: Wrapping any API, JSON-parsing or post-processing failure.
    """
    try:
        client = Groq(api_key=groq_api_key)

        # Give the model concrete date bounds so it does not invent future dates.
        now = datetime.now()
        today = now.strftime('%Y-%m-%d')
        past_date_2y = (now - timedelta(days=730)).strftime('%Y-%m-%d')  # 2 years ago
        past_date_60d = (now - timedelta(days=60)).strftime('%Y-%m-%d')  # 60 days ago

        # Request to generate table schema and sample data
        schema_prompt = f"""Based on this query: "{user_query}"

**Current date: {today}**

Generate a realistic database schema with sample data. Return ONLY valid JSON with this structure:
{{
    "tables": [
        {{
            "table_name": "table_name",
            "columns": [
                {{"name": "column_name", "type": "INTEGER|TEXT|REAL|DATE"}},
                ...
            ],
            "sample_data": [
                {{"column_name": value, ...}},
                ...at least 20-25 rows
            ]
        }}
    ]
}}

**CRITICAL INSTRUCTIONS FOR REALISTIC DATA:**

1. **DATES MUST BE IN THE PAST!**
   - For hire_date, created_at, registration_date: Use dates between {past_date_2y} and {today}
   - For order_date, transaction_date: If query mentions "last X days", use dates between {past_date_60d} and {today}
   - NEVER use future dates!

2. **For NUMERIC filters (salary, amount, price):**
   - If query says "over $80000", make 50-60% of records have values ABOVE 80000
   - Create realistic variation: some at 85k, some at 95k, some at 120k, etc.
   - Also include records BELOW the threshold (40-50%) for realism

3. **For TEXT filters (department, category, status):**
   - If query mentions "Engineering department", ensure 50-60% of records have department = "Engineering"
   - Include other departments too: "Marketing", "Sales", "HR", "Finance" for variety

4. **Data quality:**
   - Use realistic names, emails (first.last@company.com format)
   - Make data diverse and meaningful
   - Ensure enough records match the query criteria to get meaningful results

Example: For "Find Engineering employees with salary > 80000"
- Create 20+ employee records
- 12-15 should be in Engineering (60%)
- Of Engineering employees, 8-10 should have salary > 80000
- Include other departments with various salaries for realism"""

        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are a database expert. Generate realistic table schemas and sample data. "
                        "ALL DATES MUST BE IN THE PAST, NEVER IN THE FUTURE. "
                        "Return ONLY valid JSON, no markdown formatting."
                    ),
                },
                {"role": "user", "content": schema_prompt},
            ],
            temperature=0.7,
        )

        content = _strip_code_fences(response.choices[0].message.content or "")
        schema_data = json.loads(content)
        # Fail with a clear message instead of a bare KeyError('tables') later on.
        if not isinstance(schema_data, dict) or not isinstance(schema_data.get('tables'), list):
            raise ValueError("model response is missing a 'tables' list")

        # Post-process: enforce past dates and enough rows matching the query.
        return enhance_sample_data(schema_data, user_query)

    except Exception as e:
        raise Exception(f"Error generating sample data: {str(e)}") from e


def _detect_days_back(query_lower: str) -> Optional[int]:
    """Return the day window implied by phrases like "last month", or None."""
    for phrase, days in _TIME_WINDOWS.items():
        if phrase in query_lower:
            return days
    match = _LAST_N_DAYS_RE.search(query_lower)
    return int(match.group(1)) if match else None


def _detect_threshold(query_lower: str) -> Optional[int]:
    """Return the integer in "over/above/greater than [$]N[,NNN]", or None."""
    match = _THRESHOLD_RE.search(query_lower)
    if not match:
        return None
    return int(match.group(1).replace(',', ''))


def _first_filter_value(patterns, query_lower: str) -> Optional[str]:
    """Return the first non-stop-word value captured by *patterns*, title- or upper-cased."""
    for pattern in patterns:
        for match in re.finditer(pattern, query_lower):
            word = match.group(1)
            if word in _FILTER_STOPWORDS:
                continue
            return word.upper() if word in _UPPERCASE_VALUES else word.capitalize()
    return None


def _detect_text_filters(query_lower: str) -> dict:
    """Map column name -> required value for department/category/status filters."""
    filters = {}
    for column, patterns in (
        ('department', _DEPARTMENT_PATTERNS),
        ('category', _CATEGORY_PATTERNS),
        ('status', _STATUS_PATTERNS),
    ):
        value = _first_filter_value(patterns, query_lower)
        if value:
            filters[column] = value
    return filters


def _in_share(index: int, share: float = 0.55, offset: int = 0, period: int = 20) -> bool:
    """Deterministically select about ``share`` of consecutive row indices.

    In every window of ``period`` consecutive indices exactly
    ``round(share * period)`` return True. A plain ``i % 100 < 55`` would be
    True for every row of a typical 20-25 row table.
    """
    return (index + offset) % period < round(share * period)


def _past_date(now: datetime, min_days: int, max_days: int) -> str:
    """Return a random YYYY-MM-DD date between ``max_days`` and ``min_days`` before *now*."""
    return (now - timedelta(days=random.randint(min_days, max_days))).strftime('%Y-%m-%d')


def _needs_past_date(value, now: datetime) -> bool:
    """Return True if *value* is not an ISO date/datetime string or lies after *now*."""
    try:
        parsed = datetime.fromisoformat(str(value))
    except ValueError:
        return True
    if parsed.tzinfo is not None:
        # Mixing aware and naive datetimes raises TypeError; compare the
        # wall-clock part instead (approximate, fine for sample data).
        parsed = parsed.replace(tzinfo=None)
    return parsed > now


def _primary_id_column(columns: list) -> Optional[str]:
    """Return the table's surrogate-key column name, or None.

    A column named exactly ``id`` wins; otherwise the first column qualifies
    if it is named ``*_id``. Only integer-affinity types (containing "INT",
    per SQLite's affinity rules) are considered.
    """
    def is_integer(col: dict) -> bool:
        return 'INT' in _col_type(col)

    for col in columns:
        if str(col.get('name', '')).lower() == 'id' and is_integer(col):
            return col['name']
    if columns:
        first = columns[0]
        if str(first.get('name', '')).lower().endswith('_id') and is_integer(first):
            return first['name']
    return None


def _pad_rows(rows: list, columns: list, min_rows: int) -> list:
    """Clone existing rows (cycling through them) until there are ``min_rows``.

    Only the surrogate-key column of each clone is renumbered, continuing
    from the largest existing integer key. Foreign keys such as
    ``customer_id`` therefore keep pointing at real rows, and columns whose
    names merely contain "id" (``width``, ``is_paid``) are left alone.
    An empty input is returned unchanged because there is nothing to clone.
    """
    if not rows:
        return rows
    padded = list(rows)
    template_count = len(rows)
    id_col = _primary_id_column(columns)
    next_id = 1
    if id_col is not None:
        existing = [row[id_col] for row in rows if isinstance(row.get(id_col), int)]
        next_id = max(existing, default=0) + 1
    while len(padded) < min_rows:
        clone = dict(padded[len(padded) % template_count])
        if id_col is not None:
            clone[id_col] = next_id
            next_id += 1
        padded.append(clone)
    return padded


def enhance_sample_data(schema_data: dict, user_query: str) -> dict:
    """Post-process generated rows so the user's query has data to match.

    For every table:
      * transaction-like date columns are re-dated into the query's time
        window ("last 30 days", "last quarter", ...; default 60 days);
      * other date columns that are unparseable or in the future are moved
        1 month to 3 years into the past;
      * if the query has an "over $N" threshold, about 55% of rows get amount
        values above N and the rest below;
      * department/category/status filters from the query are applied to
        about 55% of rows (others get a different value where one is known);
      * the table is padded to at least ``MIN_SAMPLE_ROWS`` rows.

    Args:
        schema_data: Dict with a ``"tables"`` list as produced by the model.
        user_query: The natural-language query the data should satisfy.

    Returns:
        The same ``schema_data`` object, with each ``sample_data`` replaced.
    """
    query_lower = user_query.lower()
    days_back = _detect_days_back(query_lower)
    threshold_amount = _detect_threshold(query_lower)
    text_filters = _detect_text_filters(query_lower)
    now = datetime.now()

    for table in schema_data['tables']:
        columns = table['columns']
        original_rows = table.get('sample_data') or []

        date_cols = [col['name'] for col in columns if _col_type(col) in _DATE_TYPES]
        amount_cols = [
            col['name'] for col in columns
            if any(keyword in col['name'].lower() for keyword in _AMOUNT_KEYWORDS)
        ]
        # Order/transaction dates follow the query's time window; hire/created
        # dates only need to be in the past.
        transaction_date_cols = [
            name for name in date_cols
            if any(keyword in name.lower() for keyword in _TRANSACTION_KEYWORDS)
        ]
        other_date_cols = [name for name in date_cols if name not in transaction_date_cols]

        enhanced_rows = []
        for i, row in enumerate(original_rows):
            new_row = dict(row)

            for date_col in transaction_date_cols:
                if date_col in new_row:
                    new_row[date_col] = _past_date(now, 0, days_back or 60)

            for date_col in other_date_cols:
                if date_col in new_row and _needs_past_date(new_row[date_col], now):
                    new_row[date_col] = _past_date(now, 30, 1095)

            if threshold_amount:
                above = _in_share(i)
                for amount_col in amount_cols:
                    if amount_col not in new_row:
                        continue
                    if above:
                        new_row[amount_col] = int(random.uniform(threshold_amount * 1.05, threshold_amount * 2.5))
                    else:
                        new_row[amount_col] = int(random.uniform(threshold_amount * 0.4, threshold_amount * 0.95))

            for col_name, target_value in text_filters.items():
                if col_name not in new_row:
                    continue
                # Offset 3 so most, but not all, filter-matching rows are also
                # above the amount threshold (8 of 11 per 20-row window).
                if _in_share(i, offset=3):
                    new_row[col_name] = target_value
                else:
                    alternatives = [
                        value for value in _ALTERNATIVE_VALUES.get(col_name, [])
                        if value.lower() != target_value.lower()
                    ]
                    if alternatives:
                        new_row[col_name] = random.choice(alternatives)

            enhanced_rows.append(new_row)

        table['sample_data'] = _pad_rows(enhanced_rows, columns, MIN_SAMPLE_ROWS)

    return schema_data


def _quote_identifier(name) -> str:
    """Quote *name* as an SQLite identifier (handles spaces and reserved words)."""
    return '"' + str(name).replace('"', '""') + '"'


def _sql_type_for_ddl(col: dict) -> str:
    """Return the column type restricted to characters valid in an SQLite type name."""
    cleaned = ' '.join(re.sub(r'[^A-Z0-9_(),]', ' ', _col_type(col)).split())
    return cleaned or 'TEXT'


def _to_sqlite_value(value):
    """Serialise nested JSON values (dict/list) to text; sqlite3 cannot bind them."""
    if isinstance(value, (dict, list)):
        return json.dumps(value)
    return value


def create_tables_in_db(schema_data: dict) -> sqlite3.Connection:
    """Create an in-memory SQLite database populated with the sample tables.

    Table and column names come from model output, so they are quoted as
    identifiers; row values are always bound as parameters.

    Returns:
        An open connection; the caller is responsible for closing it.

    Raises:
        sqlite3.Error / KeyError: If the schema cannot be created or loaded
        (the connection is closed before re-raising).
    """
    conn = sqlite3.connect(':memory:')
    try:
        cursor = conn.cursor()
        for table in schema_data['tables']:
            table_sql_name = _quote_identifier(table['table_name'])
            columns = table['columns']
            col_names = [col['name'] for col in columns]

            column_defs = ', '.join(
                f"{_quote_identifier(col['name'])} {_sql_type_for_ddl(col)}" for col in columns
            )
            cursor.execute(f"CREATE TABLE {table_sql_name} ({column_defs})")

            sample_data = table.get('sample_data') or []
            if sample_data:
                quoted_cols = ', '.join(_quote_identifier(name) for name in col_names)
                placeholders = ', '.join('?' for _ in col_names)
                insert_sql = f"INSERT INTO {table_sql_name} ({quoted_cols}) VALUES ({placeholders})"
                cursor.executemany(
                    insert_sql,
                    ([_to_sqlite_value(row.get(name)) for name in col_names] for row in sample_data),
                )
        conn.commit()
    except Exception:
        conn.close()
        raise
    return conn


def generate_sql_query(user_query: str, groq_api_key: str, schema_info: str) -> SQLQueryGeneration:
    """Generate a structured SQLite query for *user_query* against *schema_info*.

    Raises:
        Exception: Wrapping any API or response-validation failure.
    """
    try:
        client = Groq(api_key=groq_api_key)

        enhanced_query = f"""Database Schema:
{schema_info}

User Request: {user_query}

Generate a SQL query that works with the above schema. Use SQLite-compatible syntax."""

        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are a SQL expert. Generate structured SQL queries from natural language "
                        "descriptions with proper syntax validation and metadata. Use standard SQL "
                        "syntax compatible with SQLite. For date operations, use SQLite functions "
                        "like date('now') and datetime()."
                    ),
                },
                {"role": "user", "content": enhanced_query},
            ],
            response_format={
                "type": "json_schema",
                "json_schema": {
                    "name": "sql_query_generation",
                    "schema": SQLQueryGeneration.model_json_schema(),
                },
            },
        )

        return SQLQueryGeneration.model_validate(
            json.loads(response.choices[0].message.content)
        )

    except Exception as e:
        raise Exception(f"Error generating SQL query: {str(e)}") from e


def execute_sql_query(conn: sqlite3.Connection, query: str) -> pd.DataFrame:
    """Execute *query* on *conn* and return the result set as a DataFrame."""
    try:
        return pd.read_sql_query(query, conn)
    except Exception as e:
        raise Exception(f"Error executing SQL query: {str(e)}") from e


def format_schema_info(schema_data: dict) -> str:
    """Render table names, column names/types and row counts as plain text."""
    info = []
    for table in schema_data['tables']:
        info.append(f"\nTable: {table['table_name']}")
        info.append("Columns:")
        for col in table['columns']:
            info.append(f"  - {col['name']} ({col['type']})")
        info.append(f"Sample rows: {len(table['sample_data'])}")
    return '\n'.join(info)


def process_query(user_query: str, groq_api_key: str):
    """Run the pipeline: sample data -> SQLite DB -> SQL generation -> execution.

    Returns:
        A 5-tuple ``(process_log, sql_metadata, sample_data_html, result_html, "")``
        matching the outputs wired to the submit button. On failure the first
        element is an error message and the others are empty/``None``.
    """
    if not groq_api_key or not groq_api_key.strip():
        return "❌ Please enter your Groq API key", None, "", "", ""

    if not user_query or not user_query.strip():
        return "❌ Please enter a query", None, "", "", ""

    conn = None
    try:
        output_log = []

        # Step 1: Generate sample data
        output_log.append("### Step 1: Generating Sample Database Schema and Data")
        output_log.append(f"Query: {user_query}\n")
        schema_data = generate_sample_data(user_query, groq_api_key)
        schema_info = format_schema_info(schema_data)
        output_log.append("✅ Generated database schema:")
        output_log.append(schema_info)
        output_log.append("")

        # Step 2: Create tables
        output_log.append("### Step 2: Creating In-Memory SQLite Database")
        conn = create_tables_in_db(schema_data)
        output_log.append("✅ Tables created and populated with sample data\n")

        # Display sample data (first 10 rows of each table)
        sample_tables_html = []
        for table in schema_data['tables']:
            df_sample = pd.DataFrame(table['sample_data'][:10])
            # Table names come from model output; escape before embedding in HTML.
            table_label = html.escape(str(table['table_name']))
            sample_tables_html.append(f"<h4>Sample Data from '{table_label}' (first 10 rows):</h4>")
            sample_tables_html.append(df_sample.to_html(index=False, border=1, classes='table table-striped'))

        # Step 3: Generate SQL query
        output_log.append("### Step 3: Generating SQL Query")
        sql_generation = generate_sql_query(user_query, groq_api_key, schema_info)
        # Same nested dict shape as the SQLQueryGeneration model fields.
        sql_metadata = sql_generation.model_dump()
        output_log.append("✅ SQL Query Generated:\n")

        # Step 4: Execute query
        output_log.append("\n### Step 4: Executing SQL Query")
        output_log.append(f"Executing: {sql_generation.query}\n")
        result_df = execute_sql_query(conn, sql_generation.query)

        if result_df.empty:
            output_log.append("⚠️ Query executed successfully but returned 0 rows")
            output_log.append("This might happen if the sample data doesn't match the query criteria.")
            result_html = "<p>No results found. The query executed successfully but no data matched the criteria.</p>"
        else:
            output_log.append(f"✅ Query executed successfully! Returned {len(result_df)} row(s)\n")
            result_html = f"<h4>Query Results ({len(result_df)} rows):</h4>"
            result_html += result_df.to_html(index=False, border=1, classes='table table-striped')

        return '\n'.join(output_log), sql_metadata, '\n'.join(sample_tables_html), result_html, ""

    except Exception as e:
        return f"❌ Error: {str(e)}", None, "", "", ""
    finally:
        # Always release the in-memory database, including on error paths.
        if conn is not None:
            conn.close()


# Custom CSS for better table styling
custom_css = """
.table {
    width: 100%;
    border-collapse: collapse;
    margin: 10px 0;
    font-size: 14px;
}
.table th {
    background-color: #4a5568;
    color: white;
    font-weight: bold;
    padding: 10px;
    text-align: left;
    border: 1px solid #2d3748;
}
.table td {
    padding: 8px 10px;
    border: 1px solid #e2e8f0;
}
.table-striped tbody tr:nth-child(odd) {
    background-color: #f7fafc;
}
.table-striped tbody tr:nth-child(even) {
    background-color: #ffffff;
}
.table-striped tbody tr:hover {
    background-color: #edf2f7;
}
"""

# Gradio Interface
with gr.Blocks(title="SQLGenie - AI SQL Query Generator", theme=gr.themes.Ocean(), css=custom_css) as app:
    gr.Markdown("""
    # ⚡ SQLGenie - AI SQL Query Generator & Executor

    Transform natural language into SQL queries and see instant results! This app:
    1. 🎲 Generates realistic sample database tables based on your query
    2. 🧙 Creates a structured SQL query from natural language using AI
    3. ⚙️ Executes the query on sample data
    4. 📊 Shows you the results instantly

    ### How to use:
    1. Enter your Groq API key ([Get one free here](https://console.groq.com/keys))
    2. Describe what data you want in plain English
    3. Click "Generate & Execute SQL" and watch the magic happen! ✨
    """)

    with gr.Row():
        with gr.Column(scale=2):
            api_key_input = gr.Textbox(
                label="🔑 Groq API Key",
                placeholder="Enter your Groq API key here...",
                type="password",
            )
            query_input = gr.Textbox(
                label="💬 Natural Language Query",
                placeholder="Example: Find all customers who made orders over $500 in the last 30 days, show their name, email, and total order amount",
                lines=3,
            )
            submit_btn = gr.Button("🚀 Generate & Execute SQL", variant="primary", size="lg")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 📋 Process Log")
            process_output = gr.Textbox(
                label="Execution Steps",
                lines=12,
                max_lines=20,
            )

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🗂️ Sample Database Tables")
            sample_data_output = gr.HTML(label="Sample Data")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 📝 Generated SQL Query (Structured Output)")
            sql_output = gr.JSON(label="SQL Query Metadata")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### ✨ Query Execution Results")
            result_output = gr.HTML(label="Results")

    # Examples
    gr.Examples(
        examples=[
            ["Find all customers who made orders over $500 in the last 30 days, show their name, email, and total order amount"],
            ["List all products that are out of stock along with their supplier information"],
            ["Show the top 5 employees by total sales in the last quarter"],
            ["Find all students who scored above 85% in Mathematics and their contact details"],
            ["Get all active users who haven't logged in for more than 60 days"],
            ["Show all transactions above $1000 in the last week with customer details"],
            ["Find employees in the Engineering department with salary over $80000"],
        ],
        inputs=query_input,
        label="💡 Example Queries - Click to try!",
    )

    submit_btn.click(
        fn=process_query,
        inputs=[query_input, api_key_input],
        outputs=[process_output, sql_output, sample_data_output, result_output, gr.Textbox(visible=False)],
    )

    gr.Markdown("""
    ---
    ### 🎯 Tips for Best Results:
    - Be specific about time periods (e.g., "last 30 days", "last quarter")
    - Mention thresholds clearly (e.g., "over $500", "above 85%")
    - Specify what fields you want to see (e.g., "show name, email, total")
    - The app generates realistic sample data automatically to match your query!
    """)


if __name__ == "__main__":
    app.launch()