import html
import json
import os
import re
import sqlite3
import warnings
from contextlib import closing
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
import pandas as pd
from groq import Groq
from sklearn.metrics import accuracy_score

# Silence all warnings emitted through the `warnings` module for this process.
warnings.filterwarnings('ignore')

# API key for the Groq client, read from the environment. Surrounding
# whitespace is stripped so a key pasted with a trailing newline still works
# and a whitespace-only value is treated as missing.
GROQ_API_KEY = (os.getenv("GROQ_API_KEY") or "").strip()

if not GROQ_API_KEY:
    print("⚠️ WARNING: GROQ_API_KEY environment variable not set!")
    print("Please add your Groq API key to your Hugging Face Space secrets.")
    print("For demo purposes, the app will continue but API calls will fail.")
    # Placeholder value; the converter checks for this exact string and
    # skips creating a client when it sees it.
    GROQ_API_KEY = "dummy-key-for-demo"


class EnhancedNL2SQLConverter:
    """Turn natural-language questions into SQL text via a Groq chat model.

    The client is created only when the module-level ``GROQ_API_KEY`` holds a
    real key; otherwise ``self.client`` stays ``None`` and ``generate_sql``
    returns an ``"ERROR: ..."`` string instead of raising.
    """

    # Keywords a cleaned statement may legitimately start with. WITH is
    # accepted as a start (CTE queries) but is not searched for inside
    # prose, because "with" is a common English word.
    _STATEMENT_PREFIXES = (
        'SELECT', 'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'DROP', 'ALTER', 'WITH',
    )

    # First fenced code block in a model reply, with an optional "sql" tag.
    _FENCE_RE = re.compile(r"```(?:sql)?\s*(.*?)\s*```", re.IGNORECASE | re.DOTALL)

    # Whole-word statement keyword, used to skip a prose preamble.
    _STATEMENT_START_RE = re.compile(
        r"\b(?:SELECT|INSERT|UPDATE|DELETE|CREATE|DROP|ALTER)\b", re.IGNORECASE
    )

    _SYSTEM_PROMPT = (
        "You are an expert SQL query generator. Convert natural language "
        "questions to SQL queries based on the provided database schema.\n"
        "\n"
        "Rules:\n"
        "1. Only return the SQL query, nothing else\n"
        "2. Use proper SQL syntax\n"
        "3. Be precise with column names and table names\n"
        "4. Use appropriate WHERE clauses, JOINs, and aggregations as needed\n"
        "5. For date comparisons, use proper date format\n"
        "6. Don't include explanations, just the SQL query"
    )

    def __init__(self, model_name: str = "llama-3.3-70b-versatile"):
        """Store the model name and try to build the Groq client.

        Args:
            model_name: Model identifier passed to the chat-completions call.
        """
        self.model_name = model_name
        self.client = None

        try:
            if GROQ_API_KEY and GROQ_API_KEY != "dummy-key-for-demo":
                self.client = Groq(api_key=GROQ_API_KEY)
                print(f"✅ Successfully initialized Groq client with model: {self.model_name}")
            else:
                print("⚠️ Groq client not initialized - API key missing")
        except Exception as e:
            # Construction failures are reported, not raised, so the app can
            # still start and show an error on the first query.
            print(f"❌ Error initializing Groq client: {str(e)}")
            self.client = None

        # Schema text sent to the model when the caller supplies none.
        self.default_schema = (
            "\n"
            "Table: employees\n"
            "Columns:\n"
            "- id (INTEGER) PRIMARY KEY\n"
            "- name (TEXT) NOT NULL\n"
            "- department (TEXT)\n"
            "- salary (REAL)\n"
            "- hire_date (TEXT)\n"
            "- manager_id (INTEGER)\n"
        )

    def generate_sql(self, query: str, schema: Optional[str] = None) -> str:
        """Ask the model for a SQL statement answering ``query``.

        Args:
            query: The natural-language question.
            schema: Schema description to put in the prompt; falls back to
                ``self.default_schema`` when empty or ``None``.

        Returns:
            The cleaned SQL text, or a string starting with ``"ERROR"`` when
            the client is missing, the reply is empty, or the call fails.
        """
        try:
            if not self.client:
                return "ERROR: Groq API client not initialized. Please check your API key."

            schema_to_use = schema or self.default_schema

            user_prompt = (
                f"Database Schema:\n"
                f"{schema_to_use}\n"
                f"\n"
                f"Natural Language Question: {query}\n"
                f"\n"
                f"Generate the SQL query:"
            )

            chat_completion = self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": self._SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt},
                ],
                model=self.model_name,
                temperature=0.1,
                max_tokens=200,
            )

            # Guard against an empty/None reply explicitly rather than
            # relying on the broad except below to catch the AttributeError.
            content = chat_completion.choices[0].message.content
            if not content or not content.strip():
                return "ERROR: Could not generate SQL query - model returned an empty response"

            return self._clean_sql(content)

        except Exception as e:
            print(f"Error generating SQL: {str(e)}")
            return f"ERROR: Could not generate SQL query - {str(e)}"

    def _clean_sql(self, sql: str) -> str:
        """Strip markdown fences, quotes, prose preamble and trailing ``;``.

        Args:
            sql: Raw text returned by the model.

        Returns:
            The statement text with no surrounding whitespace and no
            trailing semicolon.
        """
        sql = sql.strip()

        fenced = self._FENCE_RE.search(sql)
        if fenced:
            # Keep only the fenced body; this also drops any explanation
            # the model appended after the closing fence.
            sql = fenced.group(1)
        else:
            # Unbalanced fence: just remove the markers.
            sql = re.sub(r"```(?:sql)?", "", sql, flags=re.IGNORECASE)

        sql = sql.strip()
        sql = re.sub(r'^["\']|["\']$', '', sql).strip()

        if not sql.upper().startswith(self._STATEMENT_PREFIXES):
            # Cut at the EARLIEST whole-word keyword, so an outer UPDATE is
            # not discarded in favour of a SELECT inside its subquery.
            start = self._STATEMENT_START_RE.search(sql)
            if start:
                sql = sql[start.start():]

        # Strip semicolons together with whitespace; a bare rstrip(';')
        # misses "...;\n".
        return sql.rstrip("; \t\r\n")


class SQLEvaluator:
    """Run SQL statements against a small SQLite database of sample employees."""

    # Rows written to the employees table on every setup:
    # (id, name, department, salary, hire_date, manager_id)
    _SAMPLE_EMPLOYEES = [
        (1, 'Alice Johnson', 'Engineering', 75000, '2022-01-15', None),
        (2, 'Bob Smith', 'Sales', 65000, '2021-06-20', None),
        (3, 'Charlie Brown', 'Engineering', 80000, '2020-03-10', 1),
        (4, 'Diana Prince', 'HR', 60000, '2023-02-28', None),
        (5, 'Eve Wilson', 'Sales', 70000, '2022-11-05', 2),
        (6, 'Frank Miller', 'Engineering', 85000, '2019-08-12', 1),
        (7, 'Grace Lee', 'Marketing', 55000, '2023-01-20', None),
        (8, 'Henry Davis', 'Engineering', 72000, '2022-07-30', 1),
    ]

    def __init__(self, db_path: str = "test_database.db"):
        """Remember the database path and create/seed the database.

        Args:
            db_path: SQLite file to use; defaults to ``test_database.db`` in
                the current working directory.
        """
        self.db_path = db_path
        self.setup_test_database()

    def setup_test_database(self):
        """Create the ``employees`` table if missing and upsert the sample rows."""
        # closing() guarantees the connection is released even if a
        # statement raises; `with conn` commits on success and rolls back
        # on error.
        with closing(sqlite3.connect(self.db_path)) as conn:
            with conn:
                conn.execute('''
                    CREATE TABLE IF NOT EXISTS employees (
                        id INTEGER PRIMARY KEY,
                        name TEXT NOT NULL,
                        department TEXT,
                        salary REAL,
                        hire_date TEXT,
                        manager_id INTEGER
                    )''')
                conn.executemany('''
                    INSERT OR REPLACE INTO employees (id, name, department, salary, hire_date, manager_id)
                    VALUES (?, ?, ?, ?, ?, ?)''', self._SAMPLE_EMPLOYEES)

        print("✅ Test database initialized successfully")

    def execute_sql(self, sql_query: str) -> Tuple[bool, Any]:
        """Execute one SQL statement and report the outcome without raising.

        Args:
            sql_query: The statement to run.

        Returns:
            ``(True, {'columns': [...], 'data': [...]})`` when the statement
            produces a result set, ``(True, "Query executed successfully")``
            when it does not, and ``(False, error_message)`` on any failure.
        """
        # NOTE(review): the statement is executed as given, including
        # INSERT/UPDATE/DELETE/DROP. Callers passing model-generated or
        # user-influenced SQL should consider restricting it to read-only.
        try:
            with closing(sqlite3.connect(self.db_path)) as conn:
                cursor = conn.execute(sql_query)

                # cursor.description is None exactly when the statement
                # returns no rows; this also recognises WITH/PRAGMA and
                # comment-prefixed queries that a "starts with SELECT"
                # test would miss.
                if cursor.description is None:
                    conn.commit()
                    return True, "Query executed successfully"

                rows = cursor.fetchall()
                columns = [column[0] for column in cursor.description]
                # Harmless after a plain read; needed if a row-returning
                # statement also wrote data.
                conn.commit()
                return True, {'columns': columns, 'data': rows}
        except Exception as e:
            # Boundary handler: every failure becomes an error string for
            # the UI rather than propagating.
            return False, str(e)


# Build the two module-level singletons independently, so a failure in one
# neither prevents the other from being created nor gets retried from inside
# an exception handler. A component that fails to initialise is left as None.
converter = None
evaluator = None

try:
    converter = EnhancedNL2SQLConverter()
except Exception as e:
    print(f"❌ Error initializing components: {str(e)}")

try:
    evaluator = SQLEvaluator()
except Exception as e:
    print(f"❌ Error initializing components: {str(e)}")

if converter is not None and evaluator is not None:
    print("✅ Application components initialized successfully")


def process_nl_query(nl_query: str) -> Tuple[str, str, str]:
    """Process natural language query and return SQL + results.

    Args:
        nl_query: The user's question. ``None``, empty, or whitespace-only
            input is rejected with a warning message.

    Returns:
        A ``(sql, results_text, status_message)`` tuple. ``sql`` holds the
        generated statement (or the generator's ``"ERROR..."`` text),
        ``results_text`` the formatted rows or execution message, and
        ``status_message`` a human-readable status. The function does not
        raise; unexpected errors are reported in ``status_message``.
    """
    if not nl_query or not nl_query.strip():
        return "", "", "⚠️ Please enter a natural language query."

    try:
        if not converter:
            return "", "", "❌ Error: SQL converter not initialized. Please check API configuration."
        if not evaluator:
            return "", "", "❌ Error: SQL evaluator not initialized. Please check the database setup."

        generated_sql = converter.generate_sql(nl_query)

        if generated_sql.startswith("ERROR"):
            # The generator's own message (shown in the SQL field) carries
            # the real cause, which is not necessarily the API key.
            return generated_sql, "", "❌ Failed to generate SQL query. Check the error details and your API key."

        success, result = evaluator.execute_sql(generated_sql)

        if not success:
            return generated_sql, "", f"❌ Error executing query: {result}"

        if isinstance(result, dict):
            if not result['data']:
                formatted_output = "No results found."
            else:
                frame = pd.DataFrame(result['data'], columns=result['columns'])
                formatted_output = frame.to_string(index=False)
            return generated_sql, formatted_output, "✅ Query executed successfully!"

        # Statement ran but produced no result set.
        return generated_sql, str(result), "✅ Query executed successfully!"

    except Exception as e:
        return "", "", f"❌ Unexpected error: {str(e)}"


def get_sample_queries() -> List[str]:
    """Return the example questions offered as one-click buttons in the UI."""
    return [
        "Show all employees in the Engineering department",
        "Find employees with salary greater than 70000",
        "List all employees hired after 2022",
        "Count employees by department",
        "Show the highest paid employee in each department",
        "Find employees who don't have a manager",
        "Show average salary by department",
    ]


# Stylesheet handed to gr.Blocks(css=...). The class names below are the ones
# attached to components through `elem_classes` or used in inline HTML.
custom_css = """
/* Main container styling */
.gradio-container {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    min-height: 100vh;
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
}

/* Header styling */
.header-container {
    background: rgba(255, 255, 255, 0.1);
    backdrop-filter: blur(20px);
    border-radius: 20px;
    padding: 2rem;
    margin-bottom: 2rem;
    border: 1px solid rgba(255, 255, 255, 0.2);
    box-shadow: 0 8px 32px rgba(31, 38, 135, 0.37);
}

/* Card styling */
.card {
    background: rgba(255, 255, 255, 0.95);
    backdrop-filter: blur(20px);
    border-radius: 16px;
    padding: 1.5rem;
    margin: 1rem 0;
    border: 1px solid rgba(255, 255, 255, 0.3);
    box-shadow: 0 8px 32px rgba(31, 38, 135, 0.15);
    transition: all 0.3s ease;
}

.card:hover {
    transform: translateY(-2px);
    box-shadow: 0 12px 40px rgba(31, 38, 135, 0.25);
}

/* Input styling */
.gr-textbox {
    border-radius: 12px !important;
    border: 2px solid rgba(103, 126, 234, 0.3) !important;
    background: rgba(255, 255, 255, 0.9) !important;
    transition: all 0.3s ease !important;
}

.gr-textbox:focus {
    border-color: #667eea !important;
    box-shadow: 0 0 0 3px rgba(103, 126, 234, 0.1) !important;
    transform: scale(1.02);
}

/* Button styling */
.gr-button {
    background: linear-gradient(45deg, #667eea, #764ba2) !important;
    border: none !important;
    border-radius: 12px !important;
    padding: 12px 24px !important;
    font-weight: 600 !important;
    color: white !important;
    transition: all 0.3s ease !important;
    box-shadow: 0 4px 15px rgba(103, 126, 234, 0.4) !important;
}

.gr-button:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 8px 25px rgba(103, 126, 234, 0.6) !important;
}

.sample-btn {
    background: linear-gradient(45deg, #f093fb, #f5576c) !important;
    margin: 0.25rem !important;
    font-size: 0.9rem !important;
    padding: 8px 16px !important;
}

.sample-btn:hover {
    background: linear-gradient(45deg, #f5576c, #f093fb) !important;
}

/* Results area styling */
.results-container {
    background: linear-gradient(135deg, #a8edea 0%, #fed6e3 100%);
    border-radius: 16px;
    padding: 1.5rem;
    margin-top: 1rem;
}

/* Status indicators */
.status-success {
    color: #10b981 !important;
    font-weight: 600 !important;
}

.status-error {
    color: #ef4444 !important;
    font-weight: 600 !important;
}

.status-warning {
    color: #f59e0b !important;
    font-weight: 600 !important;
}

.status-info {
    color: #3b82f6 !important;
    font-weight: 600 !important;
}

/* Schema box */
.schema-box {
    background: linear-gradient(135deg, #ffecd2 0%, #fcb69f 100%);
    border-radius: 12px;
    padding: 1rem;
    font-family: 'Monaco', 'Consolas', monospace;
    border-left: 4px solid #f59e0b;
}

/* Animation keyframes */
@keyframes fadeInUp {
    from {
        opacity: 0;
        transform: translateY(30px);
    }
    to {
        opacity: 1;
        transform: translateY(0);
    }
}

.fade-in {
    animation: fadeInUp 0.6s ease-out;
}

/* Responsive design */
@media (max-width: 768px) {
    .gradio-container {
        padding: 1rem;
    }

    .card {
        padding: 1rem;
        margin: 0.5rem 0;
    }
}

/* Loading spinner */
.loading {
    display: inline-block;
    width: 20px;
    height: 20px;
    border: 3px solid rgba(255,255,255,.3);
    border-radius: 50%;
    border-top-color: #fff;
    animation: spin 1s ease-in-out infinite;
}

@keyframes spin {
    to { transform: rotate(360deg); }
}
"""


# Gradio UI definition. NOTE(review): the source this was restored from had
# lost its indentation, so the exact nesting of rows/columns below is
# reconstructed from syntax and blank-line grouping — confirm the layout.
with gr.Blocks(css=custom_css, title="AI-Powered NL2SQL Converter", theme=gr.themes.Glass()) as iface:

    # Initial / reset content of the status area.
    READY_STATUS_HTML = (
        "<div style='padding: 1rem; text-align: center; font-size: 1.1rem;'>"
        "Ready to process your query! 🚀</div>"
    )

    # --- Header banner ---
    with gr.Row(elem_classes="header-container fade-in"):
        gr.HTML("""
        <div style="text-align: center; color: white;">
            <h1 style="font-size: 3rem; margin-bottom: 0.5rem; background: linear-gradient(45deg, #fff, #f0f0f0); -webkit-background-clip: text; -webkit-text-fill-color: transparent;">
                🚀 AI-Powered NL2SQL Converter
            </h1>
            <p style="font-size: 1.2rem; opacity: 0.9; margin-bottom: 1rem;">
                Transform natural language into powerful SQL queries using Groq's advanced AI
            </p>
            <div style="display: flex; justify-content: center; gap: 2rem; margin-top: 1rem;">
                <div style="text-align: center;">
                    <div style="font-size: 2rem;">🤖</div>
                    <div style="font-size: 0.9rem; opacity: 0.8;">AI-Powered</div>
                </div>
                <div style="text-align: center;">
                    <div style="font-size: 2rem;">⚡</div>
                    <div style="font-size: 0.9rem; opacity: 0.8;">Lightning Fast</div>
                </div>
                <div style="text-align: center;">
                    <div style="font-size: 2rem;">🎯</div>
                    <div style="font-size: 0.9rem; opacity: 0.8;">Precise Results</div>
                </div>
            </div>
        </div>
        """)

    # --- Database schema reference ---
    with gr.Row(elem_classes="card fade-in"):
        gr.HTML("""
        <div class="schema-box">
            <h3 style="color: #d97706; margin-bottom: 1rem;">📊 Database Schema</h3>
            <div style="background: rgba(255,255,255,0.7); padding: 1rem; border-radius: 8px;">
                <strong>employees</strong> table:<br>
                • <code>id</code> (INTEGER) - Primary Key<br>
                • <code>name</code> (TEXT) - Employee Name<br>
                • <code>department</code> (TEXT) - Department<br>
                • <code>salary</code> (REAL) - Salary Amount<br>
                • <code>hire_date</code> (TEXT) - Hiring Date<br>
                • <code>manager_id</code> (INTEGER) - Manager Reference
            </div>
        </div>
        """)

    # --- Question input and example buttons ---
    with gr.Row(elem_classes="card fade-in"):
        with gr.Column(scale=3):
            nl_input = gr.Textbox(
                label="💬 Ask your question in plain English",
                placeholder="e.g., Show me all engineers earning more than $75,000",
                lines=3,
                elem_classes="main-input",
            )

            with gr.Row():
                submit_btn = gr.Button(
                    "🔮 Generate & Execute SQL",
                    variant="primary",
                    size="lg",
                    elem_classes="main-button",
                )
                clear_btn = gr.Button(
                    "🗑️ Clear",
                    variant="secondary",
                    size="lg",
                )

        with gr.Column(scale=2):
            gr.HTML("<h3 style='color: #667eea; margin-bottom: 1rem;'>🎯 Try These Examples</h3>")

            for query in get_sample_queries():
                sample_btn = gr.Button(
                    f"💡 {query}",
                    variant="secondary",
                    size="sm",
                    elem_classes="sample-btn",
                )
                # `q=query` binds the current text at definition time;
                # a plain closure would make every button insert the
                # last example.
                sample_btn.click(
                    lambda q=query: q,
                    outputs=nl_input,
                )

    # --- Generated SQL, status line and query results ---
    with gr.Row(elem_classes="results-container fade-in"):
        with gr.Column():
            gr.HTML("<h3 style='color: #6366f1; margin-bottom: 1rem;'>📝 Generated SQL Query</h3>")
            sql_output = gr.Code(
                label="",
                language="sql",
                lines=4,
                interactive=False,
                elem_classes="sql-output",
            )

            status_output = gr.HTML(READY_STATUS_HTML)

            with gr.Row(elem_classes="card fade-in"):
                gr.HTML("<h3 style='color: #059669; margin-bottom: 1rem;'>📊 Query Results</h3>")
                results_output = gr.Code(
                    label="",
                    lines=12,
                    interactive=False,
                    elem_classes="results-output",
                )

    # --- About / tips ---
    with gr.Row(elem_classes="card fade-in"):
        gr.HTML("""
        <div style="text-align: center; padding: 1rem;">
            <h3 style="color: #667eea; margin-bottom: 1rem;">🔍 About This Application</h3>
            <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(250px, 1fr)); gap: 1rem; margin-top: 1rem;">
                <div style="background: linear-gradient(135deg, #667eea, #764ba2); color: white; padding: 1rem; border-radius: 12px;">
                    <h4>🤖 AI Model</h4>
                    <p>Powered by Groq's Llama3-70B for intelligent SQL generation</p>
                </div>
                <div style="background: linear-gradient(135deg, #f093fb, #f5576c); color: white; padding: 1rem; border-radius: 12px;">
                    <h4>💾 Database</h4>
                    <p>SQLite with sample employee data for testing and learning</p>
                </div>
                <div style="background: linear-gradient(135deg, #a8edea, #fed6e3); color: #374151; padding: 1rem; border-radius: 12px;">
                    <h4>✨ Features</h4>
                    <p>Natural language processing, SQL execution, and formatted results</p>
                </div>
            </div>
            <div style="margin-top: 2rem; padding: 1rem; background: rgba(103, 126, 234, 0.1); border-radius: 12px;">
                <h4 style="color: #667eea;">💡 Pro Tips for Better Results</h4>
                <ul style="text-align: left; display: inline-block; color: #4b5563;">
                    <li>Be specific and clear in your questions</li>
                    <li>Use column names mentioned in the schema</li>
                    <li>Try the sample queries to understand the format</li>
                    <li>Use natural language - no need for technical jargon</li>
                </ul>
            </div>
        </div>
        """)

    def enhanced_process(query):
        """Run the query pipeline and wrap the status in styled HTML.

        Returns ``(sql, status_html, results)`` in the order expected by
        the ``[sql_output, status_output, results_output]`` outputs.
        """
        if not query or not query.strip():
            return "", "<div class='status-warning'>⚠️ Please enter a question first!</div>", ""

        try:
            sql, results, status = process_nl_query(query)

            lowered = status.lower()
            if "successfully" in lowered:
                css_class = "status-success"
            elif "error" in lowered or "failed" in lowered:
                css_class = "status-error"
            else:
                css_class = "status-warning"

            # The status can carry SQLite/exception text; escape it so it
            # is displayed literally rather than interpreted as markup.
            status_html = f"<div class='{css_class}'>{html.escape(status)}</div>"
            return sql, status_html, results

        except Exception as e:
            return "", f"<div class='status-error'>❌ Unexpected error: {html.escape(str(e))}</div>", ""

    def clear_all():
        """Reset the input, SQL, status and results components."""
        return "", "", READY_STATUS_HTML, ""

    # --- Event wiring ---
    submit_btn.click(
        fn=enhanced_process,
        inputs=[nl_input],
        outputs=[sql_output, status_output, results_output],
    )

    # Pressing Enter in the textbox behaves like the submit button.
    nl_input.submit(
        fn=enhanced_process,
        inputs=[nl_input],
        outputs=[sql_output, status_output, results_output],
    )

    clear_btn.click(
        fn=clear_all,
        outputs=[nl_input, sql_output, status_output, results_output],
    )


# Script entry point: start the Gradio server only when run directly.
if __name__ == "__main__":
    print("🚀 Launching Enhanced NL2SQL Application...")
    iface.launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        share=True,
        show_error=True,
    )