# NOTE: removed non-code residue ("Spaces: / Paused / Paused") left over from
# the Hugging Face Spaces page this file was extracted from.
| import gradio as gr | |
| from groq import Groq | |
| from pydantic import BaseModel | |
| import json | |
| import sqlite3 | |
| import pandas as pd | |
| from typing import List, Optional | |
| import re | |
| from datetime import datetime, timedelta | |
| import random | |
# Pydantic Models
class ValidationStatus(BaseModel):
    """Syntax-validation verdict attached to a generated SQL query."""
    # True when the model judged the generated query syntactically valid.
    is_valid: bool
    # Human-readable syntax problems; presumably empty when is_valid is True.
    syntax_errors: list[str]
class SQLQueryGeneration(BaseModel):
    """Structured LLM output for SQL generation (used as a JSON response schema).

    Populated by ``generate_sql_query`` via Groq's ``response_format`` with
    ``model_json_schema()``; all field values are model-supplied.
    """
    # The SQL text that gets executed against the in-memory sample database.
    query: str
    # Free-form category string from the model (e.g. the statement kind) —
    # not validated here; TODO confirm the expected vocabulary.
    query_type: str
    # Table names the model says the query touches — model-supplied, unchecked.
    tables_used: list[str]
    # Free-form complexity label from the model (display only).
    estimated_complexity: str
    # Model-supplied notes about running the query (display only).
    execution_notes: list[str]
    # Nested syntax-validation verdict (see ValidationStatus).
    validation_status: ValidationStatus
class TableSchema(BaseModel):
    """Schema plus sample rows for one generated table.

    NOTE(review): not referenced anywhere in this file's visible code —
    the pipeline passes plain dicts instead. Verify whether it is still needed.
    """
    table_name: str
    # Column descriptors; per the generation prompt each looks like
    # {"name": ..., "type": "INTEGER|TEXT|REAL|DATE"}.
    columns: list[dict]
    # Row dicts keyed by column name.
    sample_data: list[dict]
def generate_sample_data(user_query: str, groq_api_key: str) -> dict:
    """Generate a sample table schema and data matching *user_query*.

    Asks the Groq LLM for a JSON payload describing tables/columns/rows,
    strips any markdown code fences from the reply, parses it, then
    post-processes it with ``enhance_sample_data`` so the generated rows
    can actually satisfy the query's filters.

    Args:
        user_query: Natural-language description of the desired data.
        groq_api_key: API key used to authenticate with Groq.

    Returns:
        dict with a ``"tables"`` list of ``{table_name, columns, sample_data}``.

    Raises:
        Exception: wrapping any API, JSON-parsing, or post-processing error
            (original cause chained for debugging).
    """
    try:
        client = Groq(api_key=groq_api_key)

        # Date anchors injected into the prompt to keep the LLM from
        # inventing future dates.
        today = datetime.now().strftime('%Y-%m-%d')
        past_date_2y = (datetime.now() - timedelta(days=730)).strftime('%Y-%m-%d')  # 2 years ago
        past_date_60d = (datetime.now() - timedelta(days=60)).strftime('%Y-%m-%d')  # 60 days ago

        # Request to generate table schema and sample data
        schema_prompt = f"""Based on this query: "{user_query}"
**Current date: {today}**
Generate a realistic database schema with sample data. Return ONLY valid JSON with this structure:
{{
"tables": [
{{
"table_name": "table_name",
"columns": [
{{"name": "column_name", "type": "INTEGER|TEXT|REAL|DATE"}},
...
],
"sample_data": [
{{"column_name": value, ...}},
...at least 20-25 rows
]
}}
]
}}
**CRITICAL INSTRUCTIONS FOR REALISTIC DATA:**
1. **DATES MUST BE IN THE PAST!**
- For hire_date, created_at, registration_date: Use dates between {past_date_2y} and {today}
- For order_date, transaction_date: If query mentions "last X days", use dates between {past_date_60d} and {today}
- NEVER use future dates!
2. **For NUMERIC filters (salary, amount, price):**
- If query says "over $80000", make 50-60% of records have values ABOVE 80000
- Create realistic variation: some at 85k, some at 95k, some at 120k, etc.
- Also include records BELOW the threshold (40-50%) for realism
3. **For TEXT filters (department, category, status):**
- If query mentions "Engineering department", ensure 50-60% of records have department = "Engineering"
- Include other departments too: "Marketing", "Sales", "HR", "Finance" for variety
4. **Data quality:**
- Use realistic names, emails (first.last@company.com format)
- Make data diverse and meaningful
- Ensure enough records match the query criteria to get meaningful results
Example: For "Find Engineering employees with salary > 80000"
- Create 20+ employee records
- 12-15 should be in Engineering (60%)
- Of Engineering employees, 8-10 should have salary > 80000
- Include other departments with various salaries for realism"""

        response = client.chat.completions.create(
            model="moonshotai/kimi-k2-instruct-0905",
            messages=[
                {"role": "system", "content": "You are a database expert. Generate realistic table schemas and sample data. ALL DATES MUST BE IN THE PAST, NEVER IN THE FUTURE. Return ONLY valid JSON, no markdown formatting."},
                {"role": "user", "content": schema_prompt}
            ],
            temperature=0.7
        )

        content = response.choices[0].message.content.strip()
        # Strip markdown code fences. Fix: the original only removed a
        # ```json opener, so a plain ``` opener survived and broke
        # json.loads; this handles both, plus the closing fence.
        content = re.sub(r'^```(?:json)?\s*', '', content)
        content = re.sub(r'```\s*$', '', content)
        schema_data = json.loads(content)

        # Post-process: fix dates/thresholds/filters so queries return rows.
        schema_data = enhance_sample_data(schema_data, user_query)
        return schema_data
    except Exception as e:
        # Chain the cause so the root error isn't reduced to its str().
        raise Exception(f"Error generating sample data: {str(e)}") from e
def _detect_days_back(query_lower: str) -> Optional[int]:
    """Return the look-back window in days implied by the query, or None."""
    time_keywords = {
        'last 30 days': 30,
        'last 60 days': 60,
        'last 7 days': 7,
        'last week': 7,
        'last month': 30,
        'last quarter': 90,
        'last year': 365,
    }
    for keyword, days in time_keywords.items():
        if keyword in query_lower:
            return days
    return None


def _detect_threshold_amount(query_lower: str) -> Optional[int]:
    """Return a numeric threshold (e.g. "over $500") found in the query, or None."""
    amount_match = re.search(r'(?:over|above|greater than) \$?(\d+)', query_lower)
    return int(amount_match.group(1)) if amount_match else None


def _detect_text_filters(query_lower: str) -> dict:
    """Extract department/category/status values the query filters on."""
    text_filters = {}
    # NOTE(review): the fallback pattern r'(\w+) department' is loose and can
    # capture filler words ("the department"); kept for behavioral parity.
    dept_patterns = [
        r'(?:in|from) (?:the )?(\w+) department',
        r'department (?:is |= |== )?["\']?(\w+)["\']?',
        r'(\w+) department',
    ]
    for pattern in dept_patterns:
        dept_match = re.search(pattern, query_lower)
        if dept_match:
            text_filters['department'] = dept_match.group(1).capitalize()
            break
    category_match = re.search(r'category (?:is |= )?["\']?(\w+)["\']?', query_lower)
    if category_match:
        text_filters['category'] = category_match.group(1).capitalize()
    status_match = re.search(r'status (?:is |= )?["\']?(\w+)["\']?', query_lower)
    if status_match:
        text_filters['status'] = status_match.group(1).capitalize()
    return text_filters


def _random_past_date(min_days_ago: int, max_days_ago: int) -> str:
    """ISO date string a random number of days in the past (inclusive bounds)."""
    days = random.randint(min_days_ago, max_days_ago)
    return (datetime.now() - timedelta(days=days)).strftime('%Y-%m-%d')


def enhance_sample_data(schema_data: dict, user_query: str) -> dict:
    """Post-process LLM-generated sample data so the query can match rows.

    Regenerates transaction-style dates inside the query's time window,
    pushes future/unparseable dates into the past, skews numeric columns
    around any mentioned threshold (~55% above), forces ~55% of rows to
    match detected text filters, and pads every table to at least 20 rows.
    Mutates *schema_data* in place and returns it.
    """
    query_lower = user_query.lower()
    days_back = _detect_days_back(query_lower)
    threshold_amount = _detect_threshold_amount(query_lower)
    text_filters = _detect_text_filters(query_lower)

    for table in schema_data['tables']:
        enhanced_data = []
        original_data = table['sample_data']

        # Classify columns once per table.
        date_cols = [col['name'] for col in table['columns'] if col['type'] == 'DATE']
        amount_cols = [col['name'] for col in table['columns']
                       if any(keyword in col['name'].lower()
                              for keyword in ['amount', 'price', 'salary', 'total', 'cost', 'revenue'])]
        # Transaction-style dates get clamped into the query's time window;
        # other dates (hire_date, created_at, ...) only need to be in the past.
        transaction_date_cols = [col for col in date_cols
                                 if any(keyword in col.lower()
                                        for keyword in ['order', 'transaction', 'purchase', 'sale', 'payment'])]
        other_date_cols = [col for col in date_cols if col not in transaction_date_cols]

        for i, row in enumerate(original_data):
            new_row = row.copy()

            # Transaction/order dates: always regenerate inside the window
            # (default 60 days when the query names no period).
            for date_col in transaction_date_cols:
                if date_col in new_row:
                    new_row[date_col] = _random_past_date(0, days_back if days_back else 60)

            # Other dates: keep them unless in the future or unparseable.
            for date_col in other_date_cols:
                if date_col in new_row:
                    try:
                        parsed = datetime.strptime(new_row[date_col], '%Y-%m-%d')
                        if parsed > datetime.now():
                            # Future date from the LLM -> 1 month..3 years ago.
                            new_row[date_col] = _random_past_date(30, 1095)
                    except (TypeError, ValueError):
                        # Fix: was a bare `except:`; narrow to what strptime
                        # raises on a bad/non-string value.
                        new_row[date_col] = _random_past_date(30, 1095)

            # Skew numeric columns: ~55% above the threshold, ~45% below
            # (deterministic split by row index for reproducible shape).
            if threshold_amount and amount_cols:
                for amount_col in amount_cols:
                    if amount_col in new_row:
                        if i % 100 < 55:
                            new_row[amount_col] = int(random.uniform(threshold_amount * 1.05,
                                                                     threshold_amount * 2.5))
                        else:
                            new_row[amount_col] = int(random.uniform(threshold_amount * 0.4,
                                                                     threshold_amount * 0.95))

            # Force ~55% of rows to match each detected text filter; the rest
            # get varied values (departments/statuses only) for realism.
            for col_name, target_value in text_filters.items():
                if col_name in new_row:
                    if i % 100 < 55:
                        new_row[col_name] = target_value
                    elif col_name == 'department':
                        other_depts = ['Marketing', 'Sales', 'HR', 'Finance', 'Operations', 'IT']
                        new_row[col_name] = random.choice([d for d in other_depts if d != target_value])
                    elif col_name == 'status':
                        other_statuses = ['Active', 'Inactive', 'Pending', 'Completed', 'Cancelled']
                        new_row[col_name] = random.choice([s for s in other_statuses if s != target_value])

            enhanced_data.append(new_row)

        # Pad to at least 20 rows by cloning rows with fresh integer ids.
        # Fix: the original looped unconditionally and computed
        # `% len(original_data)`, raising ZeroDivisionError on an empty table.
        if original_data:
            while len(enhanced_data) < 20:
                template_idx = len(enhanced_data) % len(original_data)
                template_row = enhanced_data[template_idx].copy()
                for col in table['columns']:
                    if 'id' in col['name'].lower() and col['type'] == 'INTEGER':
                        template_row[col['name']] = len(enhanced_data) + 1
                enhanced_data.append(template_row)

        table['sample_data'] = enhanced_data
    return schema_data
def _quote_ident(name: str) -> str:
    """Double-quote an SQLite identifier, escaping embedded quotes."""
    return '"' + str(name).replace('"', '""') + '"'


def create_tables_in_db(schema_data: dict) -> sqlite3.Connection:
    """Create an in-memory SQLite database and load the sample data into it.

    Args:
        schema_data: ``{"tables": [{table_name, columns, sample_data}, ...]}``
            as produced by ``generate_sample_data``.

    Returns:
        An open ``sqlite3.Connection`` with all tables created and populated.
        The caller is responsible for closing it.
    """
    conn = sqlite3.connect(':memory:')
    cursor = conn.cursor()
    for table in schema_data['tables']:
        table_name = table['table_name']
        columns = table['columns']

        # Identifiers come from an LLM response (untrusted), so quote them
        # instead of splicing them raw into the DDL/DML.
        column_defs = ', '.join(
            f"{_quote_ident(col['name'])} {col['type'].upper()}" for col in columns
        )
        cursor.execute(f"CREATE TABLE {_quote_ident(table_name)} ({column_defs})")

        sample_data = table['sample_data']
        if sample_data:
            col_names = [col['name'] for col in columns]
            placeholders = ', '.join('?' for _ in col_names)
            insert_sql = (
                f"INSERT INTO {_quote_ident(table_name)} "
                f"({', '.join(_quote_ident(c) for c in col_names)}) "
                f"VALUES ({placeholders})"
            )
            # executemany: one prepared statement for all rows; missing keys
            # insert as NULL via row.get().
            cursor.executemany(insert_sql,
                               ([row.get(col) for col in col_names] for row in sample_data))
    conn.commit()
    return conn
def generate_sql_query(user_query: str, groq_api_key: str, schema_info: str) -> SQLQueryGeneration:
    """Ask the Groq LLM for a structured SQL query against *schema_info*.

    Sends the schema summary plus the user's request, constraining the reply
    with a JSON schema derived from ``SQLQueryGeneration`` so it parses
    directly into the model.

    Args:
        user_query: Natural-language request.
        groq_api_key: API key used to authenticate with Groq.
        schema_info: Text description of the generated tables/columns.

    Returns:
        A validated ``SQLQueryGeneration`` instance.

    Raises:
        Exception: wrapping any API/parsing/validation error
            (original cause chained for debugging).
    """
    try:
        client = Groq(api_key=groq_api_key)
        enhanced_query = f"""Database Schema:
{schema_info}
User Request: {user_query}
Generate a SQL query that works with the above schema. Use SQLite-compatible syntax."""
        response = client.chat.completions.create(
            model="moonshotai/kimi-k2-instruct-0905",
            messages=[
                {
                    "role": "system",
                    "content": "You are a SQL expert. Generate structured SQL queries from natural language descriptions with proper syntax validation and metadata. Use standard SQL syntax compatible with SQLite. For date operations, use SQLite functions like date('now') and datetime().",
                },
                {"role": "user", "content": enhanced_query},
            ],
            # Constrain the response to the pydantic model's JSON schema.
            response_format={
                "type": "json_schema",
                "json_schema": {
                    "name": "sql_query_generation",
                    "schema": SQLQueryGeneration.model_json_schema(),
                },
            },
        )
        return SQLQueryGeneration.model_validate(
            json.loads(response.choices[0].message.content)
        )
    except Exception as e:
        # Chain the cause so API errors aren't reduced to their str().
        raise Exception(f"Error generating SQL query: {str(e)}") from e
def execute_sql_query(conn: sqlite3.Connection, query: str) -> pd.DataFrame:
    """Run *query* on *conn* and return the result set as a DataFrame.

    Raises:
        Exception: wrapping any pandas/SQLite error, with the original
            cause chained for debugging.
    """
    try:
        return pd.read_sql_query(query, conn)
    except Exception as e:
        raise Exception(f"Error executing SQL query: {str(e)}") from e
def format_schema_info(schema_data: dict) -> str:
    """Render the generated schema as a human-readable text summary.

    One section per table: name, a column list with types, and the
    number of sample rows.
    """
    lines = []
    for tbl in schema_data['tables']:
        lines.append(f"\nTable: {tbl['table_name']}")
        lines.append("Columns:")
        lines.extend(f" - {c['name']} ({c['type']})" for c in tbl['columns'])
        lines.append(f"Sample rows: {len(tbl['sample_data'])}")
    return '\n'.join(lines)
def process_query(user_query: str, groq_api_key: str):
    """Run the full pipeline and return values for the five Gradio outputs.

    Steps: generate a sample schema -> load it into SQLite -> generate a
    SQL query for it -> execute the query.

    Returns:
        5-tuple of (process-log text, SQL metadata dict | None,
        sample-data HTML, results HTML, ""). On any error the first
        element carries the message and the others are empty/None.
    """
    if not groq_api_key or not groq_api_key.strip():
        return "β Please enter your Groq API key", None, "", "", ""
    if not user_query or not user_query.strip():
        return "β Please enter a query", None, "", "", ""
    try:
        output_log = []

        # Step 1: ask the LLM for a schema plus sample rows.
        output_log.append("### Step 1: Generating Sample Database Schema and Data")
        output_log.append(f"Query: {user_query}\n")
        schema_data = generate_sample_data(user_query, groq_api_key)
        schema_info = format_schema_info(schema_data)
        output_log.append("β Generated database schema:")
        output_log.append(schema_info)
        output_log.append("")

        # Step 2: load it into an in-memory SQLite database.
        output_log.append("### Step 2: Creating In-Memory SQLite Database")
        conn = create_tables_in_db(schema_data)
        try:
            output_log.append("β Tables created and populated with sample data\n")

            # Render a preview of each table (first 10 rows) as HTML.
            sample_tables_html = []
            for table in schema_data['tables']:
                df_sample = pd.DataFrame(table['sample_data'][:10])
                sample_tables_html.append(
                    f"<h4>Sample Data from '{table['table_name']}' (first 10 rows):</h4>")
                sample_tables_html.append(
                    df_sample.to_html(index=False, border=1, classes='table table-striped'))

            # Step 3: generate the SQL with the schema as context.
            output_log.append("### Step 3: Generating SQL Query")
            sql_generation = generate_sql_query(user_query, groq_api_key, schema_info)
            sql_output = {
                "query": sql_generation.query,
                "query_type": sql_generation.query_type,
                "tables_used": sql_generation.tables_used,
                "estimated_complexity": sql_generation.estimated_complexity,
                "execution_notes": sql_generation.execution_notes,
                "validation_status": {
                    "is_valid": sql_generation.validation_status.is_valid,
                    "syntax_errors": sql_generation.validation_status.syntax_errors,
                },
            }
            output_log.append("β SQL Query Generated:\n")

            # Step 4: run the query against the sample database.
            output_log.append("\n### Step 4: Executing SQL Query")
            output_log.append(f"Executing: {sql_generation.query}\n")
            result_df = execute_sql_query(conn, sql_generation.query)
        finally:
            # Fix: the original closed the connection only on full success,
            # leaking it whenever SQL generation or execution raised.
            conn.close()

        if len(result_df) == 0:
            output_log.append("β οΈ Query executed successfully but returned 0 rows")
            output_log.append("This might happen if the sample data doesn't match the query criteria.")
            result_html = "<p><i>No results found. The query executed successfully but no data matched the criteria.</i></p>"
        else:
            output_log.append(f"β Query executed successfully! Returned {len(result_df)} row(s)\n")
            result_html = f"<h4>Query Results ({len(result_df)} rows):</h4>"
            result_html += result_df.to_html(index=False, border=1, classes='table table-striped')

        return '\n'.join(output_log), sql_output, '\n'.join(sample_tables_html), result_html, ""
    except Exception as e:
        return f"β Error: {str(e)}", None, "", "", ""
# Custom CSS for better table styling: targets the 'table' / 'table-striped'
# classes that pandas DataFrame.to_html emits elsewhere in this file
# (borders, dark header row, zebra striping, hover highlight).
custom_css = """
.table {
    width: 100%;
    border-collapse: collapse;
    margin: 10px 0;
    font-size: 14px;
}
.table th {
    background-color: #4a5568;
    color: white;
    font-weight: bold;
    padding: 10px;
    text-align: left;
    border: 1px solid #2d3748;
}
.table td {
    padding: 8px 10px;
    border: 1px solid #e2e8f0;
}
.table-striped tbody tr:nth-child(odd) {
    background-color: #f7fafc;
}
.table-striped tbody tr:nth-child(even) {
    background-color: #ffffff;
}
.table-striped tbody tr:hover {
    background-color: #edf2f7;
}
"""
# Gradio Interface
# Builds the UI declaratively: components created inside the Blocks context
# are laid out in creation order; `app` is launched from the main guard below.
with gr.Blocks(title="SQLGenie - AI SQL Query Generator", theme=gr.themes.Ocean(), css=custom_css) as app:
    # Header and usage instructions.
    gr.Markdown("""
    # β‘ SQLGenie - AI SQL Query Generator & Executor
    Transform natural language into SQL queries and see instant results! This app:
    1. π² Generates realistic sample database tables based on your query
    2. π§ Creates a structured SQL query from natural language using AI
    3. βοΈ Executes the query on sample data
    4. π Shows you the results instantly
    ### How to use:
    1. Enter your Groq API key ([Get one free here](https://console.groq.com/keys))
    2. Describe what data you want in plain English
    3. Click "Generate & Execute SQL" and watch the magic happen! β¨
    """)
    # Input row: API key, the natural-language query, and the submit button.
    with gr.Row():
        with gr.Column(scale=2):
            api_key_input = gr.Textbox(
                label="π Groq API Key",
                placeholder="Enter your Groq API key here...",
                type="password"
            )
            query_input = gr.Textbox(
                label="π¬ Natural Language Query",
                placeholder="Example: Find all customers who made orders over $500 in the last 30 days, show their name, email, and total order amount",
                lines=3
            )
            submit_btn = gr.Button("π Generate & Execute SQL", variant="primary", size="lg")
    # Output: step-by-step process log (1st return value of process_query).
    with gr.Row():
        with gr.Column():
            gr.Markdown("### π Process Log")
            process_output = gr.Textbox(
                label="Execution Steps",
                lines=12,
                max_lines=20
            )
    # Output: HTML preview of the generated sample tables (3rd return value).
    with gr.Row():
        with gr.Column():
            gr.Markdown("### ποΈ Sample Database Tables")
            sample_data_output = gr.HTML(label="Sample Data")
    # Output: structured SQL metadata as JSON (2nd return value).
    with gr.Row():
        with gr.Column():
            gr.Markdown("### π Generated SQL Query (Structured Output)")
            sql_output = gr.JSON(label="SQL Query Metadata")
    # Output: HTML table of query results (4th return value).
    with gr.Row():
        with gr.Column():
            gr.Markdown("### β¨ Query Execution Results")
            result_output = gr.HTML(label="Results")
    # Clickable example queries that populate query_input.
    gr.Examples(
        examples=[
            ["Find all customers who made orders over $500 in the last 30 days, show their name, email, and total order amount"],
            ["List all products that are out of stock along with their supplier information"],
            ["Show the top 5 employees by total sales in the last quarter"],
            ["Find all students who scored above 85% in Mathematics and their contact details"],
            ["Get all active users who haven't logged in for more than 60 days"],
            ["Show all transactions above $1000 in the last week with customer details"],
            ["Find employees in the Engineering department with salary over $80000"]
        ],
        inputs=query_input,
        label="π‘ Example Queries - Click to try!"
    )
    # Wire the button to the pipeline. process_query returns 5 values; the
    # inline hidden Textbox absorbs the unused 5th one.
    submit_btn.click(
        fn=process_query,
        inputs=[query_input, api_key_input],
        outputs=[process_output, sql_output, sample_data_output, result_output, gr.Textbox(visible=False)]
    )
    # Footer tips.
    gr.Markdown("""
    ---
    ### π― Tips for Best Results:
    - Be specific about time periods (e.g., "last 30 days", "last quarter")
    - Mention thresholds clearly (e.g., "over $500", "above 85%")
    - Specify what fields you want to see (e.g., "show name, email, total")
    - The app generates realistic sample data automatically to match your query!
    """)
if __name__ == "__main__":
    # Launch the Gradio server when run as a script (not on import).
    app.launch()