Spaces:
Paused
Paused
| import gradio as gr | |
| from groq import Groq | |
| from pydantic import BaseModel | |
| import json | |
| import sqlite3 | |
| import pandas as pd | |
| from datetime import datetime, timedelta | |
| import random | |
| import re | |
| # Pydantic models for structured output | |
class ValidationStatus(BaseModel):
    """Validation verdict attached to a generated SQL query.

    Mirrors the ``validation_status`` object requested from the model in the
    system prompt used by ``process_nl_query``.
    """

    # Whether the generated query was reported as valid.
    is_valid: bool
    # Syntax problems reported for the query; empty when none were reported.
    syntax_errors: list[str]
class SQLQueryGeneration(BaseModel):
    """Structured result of converting a natural-language request to SQL.

    The field names match the JSON keys the system prompt asks the model to
    return. Field order is significant: it is the key order produced by
    ``model_dump()`` and therefore the order shown in the JSON output panel.
    """

    # The SQL statement text to execute.
    query: str
    # Statement kind; the prompt asks for SELECT/INSERT/UPDATE/DELETE.
    query_type: str
    # Names of every table the query touches; drives sample-table creation.
    tables_used: list[str]
    # Complexity label; the prompt asks for low/medium/high.
    estimated_complexity: str
    # Free-form notes about executing the query.
    execution_notes: list[str]
    # Nested validation verdict for the query.
    validation_status: ValidationStatus
| # Enhanced data generators for ANY table type | |
def generate_generic_table_data(table_name, row_count=15):
    """Generate random sample data for a table as a dict of columns.

    Known table names (matched case-insensitively, surrounding whitespace
    ignored) get a purpose-built schema; any other name gets a generic
    five-column schema.

    Only the requested table is generated: each schema is a zero-argument
    builder that is called on demand. For tables that have both ``name`` and
    ``email`` columns, the e-mail addresses are derived from the *same* names
    that populate ``name`` so the two columns are consistent row by row.

    Args:
        table_name: Name of the table to generate data for.
        row_count: Number of rows to generate. The ``departments`` table is a
            fixed five-row lookup table and ignores this value.

    Returns:
        A dict mapping column name to a list of values, in column order.
    """
    first_names = ["Alice", "Bob", "Carol", "David", "Emma", "Frank", "Grace", "Henry", "Ivy", "Jack",
                   "Karen", "Leo", "Maria", "Nathan", "Olivia"]
    last_names = ["Johnson", "Smith", "Williams", "Brown", "Jones", "Garcia", "Miller", "Davis",
                  "Rodriguez", "Martinez", "Anderson", "Taylor", "Thomas", "Moore", "Jackson"]
    # Foreign keys point at sibling tables, which are created with 15 rows.
    related_rows = 15

    # --- column generators -------------------------------------------------
    def ids():
        return list(range(1, row_count + 1))

    def numbered(prefix):
        return [f"{prefix} {i}" for i in range(1, row_count + 1)]

    def pick(options):
        return [random.choice(options) for _ in range(row_count)]

    def ints(low, high):
        return [random.randint(low, high) for _ in range(row_count)]

    def floats(low, high, digits=2):
        return [round(random.uniform(low, high), digits) for _ in range(row_count)]

    def names():
        return [f"{random.choice(first_names)} {random.choice(last_names)}"
                for _ in range(row_count)]

    def emails_for(people):
        return [f"{person.lower().replace(' ', '.')}@example.com" for person in people]

    def dates(days_back=365):
        base = datetime.now()
        return [(base - timedelta(days=random.randint(0, days_back))).strftime('%Y-%m-%d')
                for _ in range(row_count)]

    def phones():
        return [f"+1-555-{random.randint(1000, 9999)}" for _ in range(row_count)]

    # --- tables whose e-mail column must match their name column -----------
    def employees():
        people = names()
        return {
            'employee_id': ids(),
            'name': people,
            'email': emails_for(people),
            'department_id': ints(1, 5),
            'salary': pick([45000, 55000, 65000, 75000, 85000, 95000, 105000, 120000]),
            'hire_date': dates(1825),
            'position': pick(['Engineer', 'Manager', 'Analyst', 'Developer', 'Designer']),
        }

    def students():
        people = names()
        return {
            'student_id': ids(),
            'name': people,
            'email': emails_for(people),
            'age': ints(18, 25),
            'major': pick(['Computer Science', 'Engineering', 'Business', 'Mathematics', 'Physics']),
            'gpa': floats(2.5, 4.0),
            'enrollment_year': ints(2020, 2025),
        }

    def patients():
        people = names()
        return {
            'patient_id': ids(),
            'name': people,
            'age': ints(18, 80),
            'email': emails_for(people),
            'phone': phones(),
            'last_visit': dates(90),
            'condition': pick(['Diabetes', 'Hypertension', 'Asthma', 'Healthy']),
        }

    def members():
        people = names()
        return {
            'member_id': ids(),
            'name': people,
            'email': emails_for(people),
            'membership_type': pick(['Basic', 'Premium', 'VIP']),
            'join_date': dates(730),
            'expiry_date': [
                (datetime.now() + timedelta(days=random.randint(-30, 90))).strftime('%Y-%m-%d')
                for _ in range(row_count)
            ],
            # Repeated entries weight the choice towards 'Active'.
            'status': pick(['Active', 'Active', 'Active', 'Inactive']),
        }

    def customers():
        people = names()
        return {
            'customer_id': ids(),
            'name': people,
            'email': emails_for(people),
            'phone': phones(),
            'registration_date': dates(365),
            'status': pick(['Active', 'Inactive', 'Pending', 'Active', 'Active']),
        }

    # --- lazy builders: nothing is generated until the entry is called -----
    builders = {
        'employees': employees,
        'departments': lambda: {
            'id': list(range(1, 6)),
            'name': ['Engineering', 'Sales', 'Marketing', 'HR', 'Finance'],
            'manager_id': [random.randint(1, related_rows) for _ in range(5)],
            'budget': [random.randint(100000, 1000000) for _ in range(5)],
        },
        'books': lambda: {
            'book_id': ids(),
            'title': numbered("Book Title"),
            'author': names(),
            'publication_year': ints(2000, 2025),
            'isbn': [f"978-{random.randint(1000000000, 9999999999)}" for _ in range(row_count)],
            # Repeated True weights the choice towards available books.
            'available': pick([True, False, True, True]),
            'category': pick(['Fiction', 'Science', 'History', 'Technology', 'Arts']),
        },
        'students': students,
        'courses': lambda: {
            'course_id': ids(),
            'course_name': numbered("Course"),
            'subject': pick(['Mathematics', 'Computer Science', 'Physics', 'Chemistry']),
            'credits': pick([3, 4, 5]),
            'instructor': names(),
        },
        'grades': lambda: {
            'grade_id': ids(),
            'student_id': ints(1, related_rows),
            'course_id': ints(1, related_rows),
            'score': ints(60, 100),
            'grade_date': dates(180),
        },
        'items': lambda: {
            'item_id': ids(),
            'item_name': numbered("Item"),
            'category': pick(['Electronics', 'Furniture', 'Supplies', 'Equipment']),
            'stock_level': ints(0, 100),
            'reorder_point': ints(10, 30),
            'price': floats(10, 1000),
        },
        'movies': lambda: {
            'movie_id': ids(),
            'title': numbered("Movie Title"),
            'director': names(),
            'release_year': ints(2015, 2025),
            'rating': floats(1, 10, 1),
            'genre': pick(['Action', 'Drama', 'Comedy', 'Sci-Fi', 'Thriller']),
            'duration_minutes': ints(90, 180),
        },
        'patients': patients,
        'appointments': lambda: {
            'appointment_id': ids(),
            'patient_id': ints(1, related_rows),
            'doctor_name': names(),
            'appointment_date': dates(60),
            'status': pick(['Scheduled', 'Completed', 'Cancelled']),
        },
        'properties': lambda: {
            'property_id': ids(),
            'address': [f"{random.randint(100, 9999)} Main St" for _ in range(row_count)],
            'city': pick(['Downtown', 'Suburbs', 'Uptown', 'Eastside']),
            'price': ints(150000, 800000),
            'bedrooms': ints(1, 5),
            'bathrooms': ints(1, 3),
            'sqft': ints(800, 3500),
            'status': pick(['Available', 'Sold', 'Pending']),
        },
        'events': lambda: {
            'event_id': ids(),
            'event_name': numbered("Event"),
            # January has 31 days, so every generated day is a valid date.
            'event_date': [datetime(2026, 1, random.randint(1, 31)).strftime('%Y-%m-%d')
                           for _ in range(row_count)],
            'location': pick(['Hall A', 'Conference Room', 'Auditorium', 'Stadium']),
            'attendees': ints(10, 200),
            'status': pick(['Upcoming', 'Completed', 'Cancelled']),
        },
        'dishes': lambda: {
            'dish_id': ids(),
            'dish_name': numbered("Dish"),
            'category': pick(['Appetizer', 'Main Course', 'Dessert', 'Beverage']),
            'price': floats(5, 50),
            'preparation_time': ints(10, 60),
        },
        'orders': lambda: {
            'order_id': ids(),
            'customer_id': ints(1, related_rows),
            'dish_id': ints(1, related_rows),
            'quantity': ints(1, 5),
            'order_date': dates(30),
            'total_amount': floats(100, 5000),
        },
        'members': members,
        'customers': customers,
        'products': lambda: {
            'product_id': ids(),
            'product_name': numbered("Product"),
            'category': pick(['Electronics', 'Clothing', 'Home', 'Sports', 'Books']),
            'price': floats(10, 1000),
            'stock_quantity': ints(0, 100),
            'supplier_id': ints(1, 5),
        },
    }

    builder = builders.get(table_name.strip().lower())
    if builder is not None:
        return builder()

    # Generic fallback for tables without a predefined schema.
    return {
        f'{table_name}_id': ids(),
        'name': names(),
        'created_date': dates(),
        'status': pick(['Active', 'Inactive', 'Pending', 'Active', 'Active']),
        'value': floats(100, 5000),
    }
def create_database_from_tables(tables_used):
    """Create an in-memory SQLite database holding sample data for each table.

    Table names are lower-cased and stripped before use. Blank or non-string
    entries are skipped, and a name that appears more than once is generated
    only once. The ``departments`` table is created with 5 rows; every other
    table gets 15.

    Args:
        tables_used: Iterable of table names (may be empty or ``None``).

    Returns:
        A ``(conn, sample_data)`` tuple: the open SQLite connection, which the
        caller is responsible for closing, and a dict mapping each normalised
        table name to the DataFrame that was written to it.

    Raises:
        Exception: Whatever data generation or ``DataFrame.to_sql`` raises;
            the connection is closed before the exception propagates.
    """
    conn = sqlite3.connect(':memory:')
    sample_data = {}
    try:
        for table in tables_used or []:
            if not isinstance(table, str):
                continue
            table_name = table.lower().strip()
            if not table_name or table_name in sample_data:
                continue

            # departments is a small fixed lookup table (only 5 rows).
            row_count = 5 if table_name == 'departments' else 15
            df = pd.DataFrame(generate_generic_table_data(table_name, row_count=row_count))
            df.to_sql(table_name, conn, index=False, if_exists='replace')
            sample_data[table_name] = df
    except Exception:
        # Do not leak the connection when table creation fails part-way.
        conn.close()
        raise
    return conn, sample_data
def execute_sql_on_sample_data(sql_query, conn):
    """Execute the generated SQL statement on the sample database.

    Statements that return rows (e.g. SELECT) produce a DataFrame of those
    rows. Statements without a result set (INSERT/UPDATE/DELETE/DDL) are
    committed and reported as a one-row DataFrame with a ``rows_affected``
    column, instead of failing because there is nothing to read.

    Args:
        sql_query: The SQL text to run (a single statement).
        conn: An open ``sqlite3`` connection.

    Returns:
        ``(dataframe, None)`` on success, or ``(None, error_message)`` when
        the query is empty or execution fails. This function never raises
        for a bad query; the error text is returned for display instead.
    """
    if not sql_query or not sql_query.strip():
        return None, "No SQL query was provided"

    try:
        cursor = conn.execute(sql_query)
        # description is None when the statement produced no result set.
        if cursor.description is None:
            conn.commit()
            return pd.DataFrame({"rows_affected": [cursor.rowcount]}), None

        columns = [column[0] for column in cursor.description]
        df_result = pd.DataFrame.from_records(
            cursor.fetchall(), columns=columns, coerce_float=True
        )
        return df_result, None
    except Exception as e:
        # Deliberately broad: the SQL is model-generated, and any failure
        # must be shown to the user rather than crash the app.
        return None, str(e)
# System prompt sent with every SQL-generation request.
_SQL_SYSTEM_PROMPT = """You are a SQL expert. Generate structured SQL queries from natural language descriptions with proper syntax validation and metadata.
IMPORTANT: Return your response in JSON format with the following structure:
{
"query": "SQL query string",
"query_type": "SELECT/INSERT/UPDATE/DELETE",
"tables_used": ["table1", "table2"],
"estimated_complexity": "low/medium/high",
"execution_notes": ["note1", "note2"],
"validation_status": {
"is_valid": true/false,
"syntax_errors": []
}
}
Use standard SQL syntax compatible with SQLite.
- Always use proper JOINs when multiple tables are involved
- Use WHERE clauses for filtering
- Use GROUP BY for aggregations
- For date comparisons, use date('now') and datetime functions
- Extract ALL table names mentioned or implied in the query and list them in "tables_used"
- If a query mentions departments and employees, include BOTH tables
- Be thorough in identifying all tables needed for the query"""


def _build_sql_generation(sql_data):
    """Map the model's decoded JSON onto ``SQLQueryGeneration``.

    Tries a strict mapping first. If the payload does not match the schema,
    falls back to a tolerant mapping that accepts the alternative keys
    ``sql_query``, ``tables`` and ``notes``, and substitutes defaults for
    missing or ``null`` fields.

    Raises:
        ValueError: If the payload is not a JSON object, or cannot be mapped
            even with the tolerant fallback.
    """
    if not isinstance(sql_data, dict):
        raise ValueError("The model response is not a JSON object")

    try:
        return SQLQueryGeneration(**sql_data)
    except (TypeError, ValueError):
        # Pydantic validation errors are ValueError subclasses; TypeError
        # covers payloads whose keys cannot be used as keyword arguments.
        status = sql_data.get('validation_status')
        if not isinstance(status, dict):
            status = {}
        is_valid = status.get('is_valid')
        return SQLQueryGeneration(
            query=sql_data.get('query') or sql_data.get('sql_query') or '',
            query_type=sql_data.get('query_type') or 'SELECT',
            tables_used=sql_data.get('tables_used') or sql_data.get('tables') or [],
            estimated_complexity=sql_data.get('estimated_complexity') or 'medium',
            execution_notes=sql_data.get('execution_notes') or sql_data.get('notes') or [],
            validation_status=ValidationStatus(
                is_valid=True if is_valid is None else is_valid,
                syntax_errors=status.get('syntax_errors') or [],
            ),
        )


def process_nl_query(api_key, natural_query):
    """Convert a natural-language request to SQL and run it on sample data.

    Steps: ask the Groq-hosted model for a structured SQL description, create
    an in-memory SQLite database with sample tables for every table the model
    lists, execute the generated query, and build a markdown report.

    Args:
        api_key: Groq API key entered by the user.
        natural_query: The request written in plain language.

    Returns:
        A 4-tuple ``(report_markdown, json_text, result_dataframe, sql_text)``
        in the order expected by the Gradio outputs. On invalid input or any
        failure, the first element carries the error message and the other
        elements are empty (the DataFrame holds the error text on failure).
    """
    if not api_key or not api_key.strip():
        return "β Please enter your Groq API key", "", pd.DataFrame(), ""
    if not natural_query or not natural_query.strip():
        return "β Please enter a natural language query", "", pd.DataFrame(), ""

    conn = None
    try:
        # Initialize Groq client
        client = Groq(api_key=api_key.strip())

        # Step 1: Generate SQL from natural language
        report = [
            "## π STEP-BY-STEP PROCESS\n\n",
            "### Step 1: Understanding User Intent\n",
            f"**User Query:** {natural_query}\n\n",
        ]

        # Call Groq API for SQL generation with Kimi model
        response = client.chat.completions.create(
            model="moonshotai/kimi-k2-instruct-0905",
            messages=[
                {"role": "system", "content": _SQL_SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": f"Convert this natural language query to SQL and return as JSON: {natural_query}",
                },
            ],
            response_format={"type": "json_object"},
            temperature=0.3,
        )

        # Parse the response
        response_content = response.choices[0].message.content
        if not response_content:
            raise ValueError("The model returned an empty response")
        sql_query_gen = _build_sql_generation(json.loads(response_content))
        json_output = json.dumps(sql_query_gen.model_dump(), indent=2)

        # Step 2: Display Structured SQL Output
        report.append("### Step 2: Generated Structured SQL\n\n")
        report.append("```json\n")
        report.append(json_output)
        report.append("\n```\n\n")

        # Step 3: Generate Sample Database Tables
        report.append("### Step 3: Auto-Generated Sample Database Tables\n\n")
        report.append(f"**Tables to be created:** {', '.join(sql_query_gen.tables_used)}\n\n")
        conn, sample_data = create_database_from_tables(sql_query_gen.tables_used)

        # Display sample tables (show first 10 rows for readability)
        for table_name, df in sample_data.items():
            report.append(f"**π Sample `{table_name}` Table** ({len(df)} rows):\n\n")
            report.append(df.head(10).to_markdown(index=False))
            if len(df) > 10:
                report.append(f"\n\n*...and {len(df) - 10} more rows*")
            report.append("\n\n")

        # Step 4: Execute SQL Query
        report.append("### Step 4: Execute Generated SQL on Sample Tables\n\n")
        report.append(f"**SQL Query:**\n```sql\n{sql_query_gen.query}\n```\n\n")
        result_df, error = execute_sql_on_sample_data(sql_query_gen.query, conn)
        if error:
            report.append(f"β **Execution Error:** {error}\n")
            result_table = pd.DataFrame({"Error": [error]})
        else:
            report.append("β **Query executed successfully!**\n\n")
            report.append(f"**π SQL Execution Result** ({len(result_df)} rows returned):\n\n")
            if len(result_df) > 0:
                report.append(result_df.to_markdown(index=False))
            else:
                report.append("*No results found matching the criteria*")
            result_table = result_df

        # Format outputs for Gradio
        return "".join(report), json_output, result_table, sql_query_gen.query
    except Exception as e:
        # UI boundary: any failure is reported to the user instead of raised.
        error_msg = f"β **Error:** {str(e)}\n\n**Full error details:**\n```\n{repr(e)}\n```\n\nPlease check your API key and try again."
        return error_msg, "", pd.DataFrame({"Error": [str(e)]}), ""
    finally:
        # Always release the sample database, including on failures that
        # happen after it was created.
        if conn is not None:
            conn.close()
# Create Gradio Interface
# Components are laid out in the order they are created, so the statement
# order below defines the page layout: intro text, input column, three
# output rows, then the help text.
with gr.Blocks(title="Natural Language to SQL Query Executor", theme=gr.themes.Soft()) as demo:
    # Intro and example prompts shown at the top of the page.
    gr.Markdown("""
# π Natural Language to SQL Query Executor
Convert natural language queries into SQL, generate sample data, and execute queries automatically!
**Example queries to try:**
- "Find all customers who made orders over $500 in the last 30 days, show their name, email, and total order amount"
- "Show all employees who earn more than $75,000 and work in the Engineering department"
- "List students who scored above 85% in Mathematics"
- "Find all books published after 2020 that are currently available"
- "Show properties with price between $200,000 and $500,000"
""")
    # Inputs: API key (masked), the natural-language request, and the submit
    # button, followed by the generated SQL text.
    with gr.Row():
        with gr.Column(scale=1):
            api_key_input = gr.Textbox(
                label="π Groq API Key",
                type="password",
                placeholder="Enter your Groq API key here...",
                info="Get your API key from https://console.groq.com"
            )
            query_input = gr.Textbox(
                label="π¬ Natural Language Query",
                placeholder="e.g., Find all customers who made orders over $500 in the last 30 days...",
                lines=3
            )
            submit_btn = gr.Button("π Generate & Execute SQL", variant="primary", size="lg")
            gr.Markdown("### π Generated SQL Query")
            sql_output = gr.Code(label="SQL Query", language="sql")
    # Step-by-step markdown report produced by process_nl_query.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### π Process & Results")
            process_output = gr.Markdown()
    # Structured JSON description of the generated query.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### π― Structured JSON Output")
            json_output = gr.Code(label="JSON Response", language="json")
    # Rows returned by executing the query on the sample data.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### π Query Execution Result")
            result_output = gr.Dataframe(
                label="Result Table",
                interactive=False,
                wrap=True
            )
    # Connect the button to the processing function
    # The outputs list follows the order of the 4-tuple returned by
    # process_nl_query: (report markdown, JSON text, result table, SQL text).
    submit_btn.click(
        fn=process_nl_query,
        inputs=[api_key_input, query_input],
        outputs=[process_output, json_output, result_output, sql_output]
    )
    # Usage instructions and feature list shown at the bottom of the page.
    gr.Markdown("""
---
### π How it works:
1. **Enter your Groq API key** - Required for SQL generation (using Kimi K2 Instruct model)
2. **Write your query in plain English** - Describe what data you want to find
3. **Click Generate & Execute** - The system will:
- Convert your query to SQL
- Automatically detect and create ALL required tables
- Generate realistic sample data for those tables
- Execute the query
- Show you the results
### π― Features:
- β Natural language to SQL conversion using Kimi K2 Instruct
- β **Smart table detection** - Creates ANY table mentioned in your query
- β Automatic sample data generation for 15+ table types
- β Query validation and metadata
- β SQL execution on sample data
- β Structured JSON output format
- β Support for employees, books, students, movies, patients, properties, events, and more!
""")
# Launch the app
# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()