Spaces:
Sleeping
Sleeping
Upload 13 files
Browse files- .gitignore +1 -0
- __init__.py +0 -0
- admin_dashboard.py +182 -0
- api_server.py +128 -0
- app_ui.py +182 -0
- check_schema.py +20 -0
- db.sql +34 -0
- db_connector.py +84 -0
- index.html +430 -0
- rag_manager.py +66 -0
- requirements.txt +21 -0
- setup_full_db.py +105 -0
- sql_generator.py +76 -0
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
.env
|
__init__.py
ADDED
|
File without changes
|
admin_dashboard.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
|
| 7 |
+
# Add src to path
|
| 8 |
+
sys.path.append(os.getcwd())
|
| 9 |
+
|
| 10 |
+
from src.rag_manager import RAGManager
|
| 11 |
+
from src.sql_generator import SQLGenerator
|
| 12 |
+
from src.db_connector import DatabaseConnector
|
| 13 |
+
|
| 14 |
+
# --- 1. CONFIGURATION ---
|
| 15 |
+
st.set_page_config(
|
| 16 |
+
page_title="NexusAI | Enterprise Data",
|
| 17 |
+
page_icon="✨",
|
| 18 |
+
layout="wide",
|
| 19 |
+
initial_sidebar_state="collapsed"
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# Custom CSS
|
| 23 |
+
st.markdown("""
|
| 24 |
+
<style>
|
| 25 |
+
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
|
| 26 |
+
html, body, [class*="css"] { font-family: 'Inter', sans-serif; }
|
| 27 |
+
.stApp { background-color: #0F1117; }
|
| 28 |
+
#MainMenu, footer, header { visibility: hidden; }
|
| 29 |
+
|
| 30 |
+
.stChatMessage { background-color: transparent !important; border: none !important; }
|
| 31 |
+
|
| 32 |
+
div[data-testid="stChatMessage"]:nth-child(odd) { flex-direction: row-reverse; }
|
| 33 |
+
div[data-testid="stChatMessage"]:nth-child(odd) .stMarkdown {
|
| 34 |
+
background-color: #2B2D31; color: #E0E0E0;
|
| 35 |
+
border-radius: 18px 18px 4px 18px; padding: 12px 20px;
|
| 36 |
+
text-align: right; margin-left: auto;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
div[data-testid="stChatMessage"]:nth-child(even) .stMarkdown {
|
| 40 |
+
background-color: transparent; color: #F0F0F0; padding-left: 10px;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
.stChatInput { position: fixed; bottom: 30px; width: 70% !important; left: 50%; transform: translateX(-50%); z-index: 1000; }
|
| 44 |
+
.stTextInput > div > div > input { background-color: #1E2128; color: white; border-radius: 24px; border: 1px solid #363B47; }
|
| 45 |
+
|
| 46 |
+
div[data-testid="stDataFrame"] { background-color: #161920; border-radius: 10px; padding: 10px; border: 1px solid #30363D; }
|
| 47 |
+
section[data-testid="stSidebar"] { background-color: #0E1015; border-right: 1px solid #222; }
|
| 48 |
+
</style>
|
| 49 |
+
""", unsafe_allow_html=True)
|
| 50 |
+
|
| 51 |
+
# --- 2. INITIALIZATION ---
|
| 52 |
+
@st.cache_resource
|
| 53 |
+
def get_core():
|
| 54 |
+
load_dotenv()
|
| 55 |
+
key = os.getenv("GEMINI_API_KEY")
|
| 56 |
+
return RAGManager(), SQLGenerator(key), DatabaseConnector()
|
| 57 |
+
|
| 58 |
+
try:
|
| 59 |
+
rag, sql_gen, db = get_core()
|
| 60 |
+
except Exception as e:
|
| 61 |
+
st.error(f"System Offline: {e}")
|
| 62 |
+
st.stop()
|
| 63 |
+
|
| 64 |
+
# --- 3. SIDEBAR ---
|
| 65 |
+
with st.sidebar:
|
| 66 |
+
st.markdown("## 🧠 NexusAI")
|
| 67 |
+
st.caption("Enterprise SQL Agent v2.0")
|
| 68 |
+
st.divider()
|
| 69 |
+
|
| 70 |
+
if db:
|
| 71 |
+
st.success("🟢 Database Connected")
|
| 72 |
+
|
| 73 |
+
st.markdown("### 📚 Quick Prompts")
|
| 74 |
+
prompts = [
|
| 75 |
+
"Top 5 employees by salary",
|
| 76 |
+
"Total sales revenue by Region",
|
| 77 |
+
"Show me products with low stock",
|
| 78 |
+
"Which department spends the most?"
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
for p in prompts:
|
| 82 |
+
if st.button(p, use_container_width=True):
|
| 83 |
+
st.session_state.last_prompt = p
|
| 84 |
+
|
| 85 |
+
if st.button("🗑️ Clear Context", type="primary", use_container_width=True):
|
| 86 |
+
st.session_state.messages = []
|
| 87 |
+
st.rerun()
|
| 88 |
+
|
| 89 |
+
# --- 4. MAIN INTERFACE ---
|
| 90 |
+
if "messages" not in st.session_state:
|
| 91 |
+
st.session_state.messages = []
|
| 92 |
+
|
| 93 |
+
if not st.session_state.messages:
|
| 94 |
+
st.markdown("""
|
| 95 |
+
<div style="text-align: center; margin-top: 100px;">
|
| 96 |
+
<h1 style="font-size: 3rem; background: -webkit-linear-gradient(#eee, #333); -webkit-background-clip: text; -webkit-text-fill-color: transparent;">
|
| 97 |
+
What can I help you analyze?
|
| 98 |
+
</h1>
|
| 99 |
+
<p style="color: #666;">Connect to your database and ask questions in plain English.</p>
|
| 100 |
+
</div>
|
| 101 |
+
""", unsafe_allow_html=True)
|
| 102 |
+
|
| 103 |
+
for msg in st.session_state.messages:
|
| 104 |
+
with st.chat_message(msg["role"], avatar="👤" if msg["role"] == "user" else "✨"):
|
| 105 |
+
st.markdown(msg["content"])
|
| 106 |
+
|
| 107 |
+
if "data" in msg:
|
| 108 |
+
# ✅ FIX: Switched to clean dataframe display
|
| 109 |
+
st.dataframe(msg["data"], hide_index=True)
|
| 110 |
+
if "chart" in msg:
|
| 111 |
+
st.bar_chart(msg["chart"])
|
| 112 |
+
if "sql" in msg:
|
| 113 |
+
with st.expander("🛠️ View Query Logic"):
|
| 114 |
+
st.code(msg["sql"], language="sql")
|
| 115 |
+
|
| 116 |
+
# Handle Input
|
| 117 |
+
user_input = st.chat_input("Ask anything...")
|
| 118 |
+
|
| 119 |
+
if "last_prompt" in st.session_state and st.session_state.last_prompt:
|
| 120 |
+
user_input = st.session_state.last_prompt
|
| 121 |
+
st.session_state.last_prompt = None
|
| 122 |
+
|
| 123 |
+
if user_input:
|
| 124 |
+
st.session_state.messages.append({"role": "user", "content": user_input})
|
| 125 |
+
with st.chat_message("user", avatar="👤"):
|
| 126 |
+
st.markdown(user_input)
|
| 127 |
+
|
| 128 |
+
with st.chat_message("assistant", avatar="✨"):
|
| 129 |
+
status_box = st.empty()
|
| 130 |
+
status_box.markdown("`⚡ analyzing...`")
|
| 131 |
+
|
| 132 |
+
try:
|
| 133 |
+
tables = rag.get_relevant_tables(user_input)
|
| 134 |
+
context = "\n".join(tables)
|
| 135 |
+
|
| 136 |
+
sql = sql_gen.generate_sql(user_input, context)
|
| 137 |
+
|
| 138 |
+
results = db.execute_sql(sql)
|
| 139 |
+
status_box.empty()
|
| 140 |
+
|
| 141 |
+
if not results:
|
| 142 |
+
response = "No data found matching that request."
|
| 143 |
+
st.markdown(response)
|
| 144 |
+
st.session_state.messages.append({"role": "assistant", "content": response, "sql": sql})
|
| 145 |
+
else:
|
| 146 |
+
df = pd.DataFrame(results)
|
| 147 |
+
df_clean = df.reset_index(drop=True)
|
| 148 |
+
|
| 149 |
+
response = f"Found **{len(df)}** records."
|
| 150 |
+
st.markdown(response)
|
| 151 |
+
# ✅ FIX: Updated dataframe display
|
| 152 |
+
st.dataframe(df_clean, hide_index=True)
|
| 153 |
+
|
| 154 |
+
chart_data = None
|
| 155 |
+
numeric_cols = df_clean.select_dtypes(include=['number']).columns
|
| 156 |
+
|
| 157 |
+
if not numeric_cols.empty and len(df_clean) > 1:
|
| 158 |
+
try:
|
| 159 |
+
non_numeric = df_clean.select_dtypes(exclude=['number']).columns
|
| 160 |
+
st.markdown("##### 📊 Trends")
|
| 161 |
+
if not non_numeric.empty:
|
| 162 |
+
x_axis = non_numeric[0]
|
| 163 |
+
y_axis = numeric_cols[0]
|
| 164 |
+
chart_data = df_clean.set_index(x_axis)[y_axis]
|
| 165 |
+
st.bar_chart(chart_data, color="#7B61FF")
|
| 166 |
+
else:
|
| 167 |
+
chart_data = df_clean[numeric_cols[0]]
|
| 168 |
+
st.bar_chart(chart_data, color="#7B61FF")
|
| 169 |
+
except Exception:
|
| 170 |
+
pass
|
| 171 |
+
|
| 172 |
+
st.session_state.messages.append({
|
| 173 |
+
"role": "assistant",
|
| 174 |
+
"content": response,
|
| 175 |
+
"data": df_clean,
|
| 176 |
+
"chart": chart_data,
|
| 177 |
+
"sql": sql
|
| 178 |
+
})
|
| 179 |
+
|
| 180 |
+
except Exception as e:
|
| 181 |
+
status_box.empty()
|
| 182 |
+
st.error(f"Error: {e}")
|
api_server.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
import uvicorn
|
| 4 |
+
import re
|
| 5 |
+
|
| 6 |
+
# 🚨 FORCE PYTHON TO FIND THE 'src' FOLDER
|
| 7 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
| 8 |
+
|
| 9 |
+
from fastapi import FastAPI, HTTPException
|
| 10 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 11 |
+
from pydantic import BaseModel
|
| 12 |
+
from typing import List, Optional
|
| 13 |
+
|
| 14 |
+
# Import your custom modules
|
| 15 |
+
from src.db_connector import Database
|
| 16 |
+
from src.rag_manager import RAGSystem
|
| 17 |
+
from src.sql_generator import SQLGenerator
|
| 18 |
+
|
| 19 |
+
app = FastAPI()
|
| 20 |
+
|
| 21 |
+
app.add_middleware(
|
| 22 |
+
CORSMiddleware,
|
| 23 |
+
allow_origins=["*"],
|
| 24 |
+
allow_credentials=True,
|
| 25 |
+
allow_methods=["*"],
|
| 26 |
+
allow_headers=["*"],
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
class ChatRequest(BaseModel):
|
| 30 |
+
question: str
|
| 31 |
+
history: Optional[List[dict]] = []
|
| 32 |
+
|
| 33 |
+
# --- HELPER: CLEAN AI OUTPUT ---
|
| 34 |
+
def clean_sql(sql_text: str) -> str:
|
| 35 |
+
"""Removes markdown, 'sql' tags, and extra whitespace."""
|
| 36 |
+
if not sql_text:
|
| 37 |
+
return ""
|
| 38 |
+
# Remove markdown code blocks (```sql ... ```)
|
| 39 |
+
cleaned = re.sub(r"```sql|```", "", sql_text, flags=re.IGNORECASE).strip()
|
| 40 |
+
# Remove trailing semicolons for consistency (optional, depending on DB driver)
|
| 41 |
+
cleaned = cleaned.rstrip(';')
|
| 42 |
+
return cleaned
|
| 43 |
+
|
| 44 |
+
print("--- 🚀 SYSTEM STARTUP SEQUENCE ---")
|
| 45 |
+
try:
|
| 46 |
+
print(" ...Connecting to Database")
|
| 47 |
+
db = Database()
|
| 48 |
+
print(" ✅ Database Connection: SUCCESS")
|
| 49 |
+
|
| 50 |
+
print(" ...Initializing RAG System")
|
| 51 |
+
rag = RAGSystem(db)
|
| 52 |
+
print(" ✅ RAG System: ONLINE")
|
| 53 |
+
|
| 54 |
+
print(" ...Loading AI Model")
|
| 55 |
+
generator = SQLGenerator()
|
| 56 |
+
print(" ✅ AI Model: LOADED")
|
| 57 |
+
|
| 58 |
+
except Exception as e:
|
| 59 |
+
print(f" ❌ CRITICAL STARTUP ERROR: {e}")
|
| 60 |
+
|
| 61 |
+
@app.post("/chat")
|
| 62 |
+
def chat_endpoint(request: ChatRequest):
|
| 63 |
+
print(f"\n📨 NEW REQUEST: {request.question}")
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
# 1. Get Context
|
| 67 |
+
context = rag.get_relevant_schema(request.question)
|
| 68 |
+
|
| 69 |
+
# 2. Generate SQL
|
| 70 |
+
raw_sql, explanation, friendly_msg = generator.generate_sql(request.question, context, request.history)
|
| 71 |
+
|
| 72 |
+
# 3. Clean and Validate SQL
|
| 73 |
+
# If the generator returned an error string directly (Error 1 fix)
|
| 74 |
+
if "Error:" in raw_sql or "Invalid Query" in raw_sql:
|
| 75 |
+
return {
|
| 76 |
+
"answer": [],
|
| 77 |
+
"sql": raw_sql,
|
| 78 |
+
"message": "I couldn't generate a safe query for that request. Try asking for specific data like 'Show me users' or 'List orders'.",
|
| 79 |
+
"follow_ups": ["Show top 10 rows from users", "Count total orders"]
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
cleaned_sql = clean_sql(raw_sql)
|
| 83 |
+
print(f" 🧹 Cleaned SQL: {cleaned_sql}")
|
| 84 |
+
|
| 85 |
+
# Safety check: Ensure it's a SELECT
|
| 86 |
+
if not cleaned_sql.upper().startswith("SELECT"):
|
| 87 |
+
return {
|
| 88 |
+
"answer": [],
|
| 89 |
+
"sql": cleaned_sql,
|
| 90 |
+
"message": "Security Alert: I can only perform READ (SELECT) operations.",
|
| 91 |
+
"follow_ups": []
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
# 4. Run Query (Error 2 fix)
|
| 95 |
+
try:
|
| 96 |
+
results = db.run_query(cleaned_sql)
|
| 97 |
+
except Exception as db_err:
|
| 98 |
+
# Catch MySQL Syntax errors specifically
|
| 99 |
+
print(f" ⚠️ DB Error: {db_err}")
|
| 100 |
+
return {
|
| 101 |
+
"answer": [f"Error: {str(db_err)}"], # This puts the error in a safe list
|
| 102 |
+
"sql": cleaned_sql,
|
| 103 |
+
"message": "There was a syntax error in the generated SQL. I have displayed the error above.",
|
| 104 |
+
"follow_ups": []
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
# 5. Generate Follow-ups
|
| 108 |
+
follow_ups = generator.generate_followup_questions(request.question, cleaned_sql)
|
| 109 |
+
|
| 110 |
+
return {
|
| 111 |
+
"answer": results,
|
| 112 |
+
"sql": cleaned_sql,
|
| 113 |
+
"explanation": explanation,
|
| 114 |
+
"message": friendly_msg,
|
| 115 |
+
"follow_ups": follow_ups
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
except Exception as e:
|
| 119 |
+
print(f"❌ General Processing Error: {e}")
|
| 120 |
+
return {
|
| 121 |
+
"answer": [],
|
| 122 |
+
"sql": "-- System Error",
|
| 123 |
+
"message": f"Critical Error: {str(e)}",
|
| 124 |
+
"follow_ups": []
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
if __name__ == "__main__":
|
| 128 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
app_ui.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
|
| 7 |
+
# Add src to path
|
| 8 |
+
sys.path.append(os.getcwd())
|
| 9 |
+
|
| 10 |
+
from src.rag_manager import RAGManager
|
| 11 |
+
from src.sql_generator import SQLGenerator
|
| 12 |
+
from src.db_connector import DatabaseConnector
|
| 13 |
+
|
| 14 |
+
# --- 1. CONFIGURATION ---
|
| 15 |
+
st.set_page_config(
|
| 16 |
+
page_title="NexusAI | Enterprise Data",
|
| 17 |
+
page_icon="✨",
|
| 18 |
+
layout="wide",
|
| 19 |
+
initial_sidebar_state="collapsed"
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# Custom CSS
|
| 23 |
+
st.markdown("""
|
| 24 |
+
<style>
|
| 25 |
+
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap');
|
| 26 |
+
html, body, [class*="css"] { font-family: 'Inter', sans-serif; }
|
| 27 |
+
.stApp { background-color: #0F1117; }
|
| 28 |
+
#MainMenu, footer, header { visibility: hidden; }
|
| 29 |
+
|
| 30 |
+
.stChatMessage { background-color: transparent !important; border: none !important; }
|
| 31 |
+
|
| 32 |
+
div[data-testid="stChatMessage"]:nth-child(odd) { flex-direction: row-reverse; }
|
| 33 |
+
div[data-testid="stChatMessage"]:nth-child(odd) .stMarkdown {
|
| 34 |
+
background-color: #2B2D31; color: #E0E0E0;
|
| 35 |
+
border-radius: 18px 18px 4px 18px; padding: 12px 20px;
|
| 36 |
+
text-align: right; margin-left: auto;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
div[data-testid="stChatMessage"]:nth-child(even) .stMarkdown {
|
| 40 |
+
background-color: transparent; color: #F0F0F0; padding-left: 10px;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
.stChatInput { position: fixed; bottom: 30px; width: 70% !important; left: 50%; transform: translateX(-50%); z-index: 1000; }
|
| 44 |
+
.stTextInput > div > div > input { background-color: #1E2128; color: white; border-radius: 24px; border: 1px solid #363B47; }
|
| 45 |
+
|
| 46 |
+
div[data-testid="stDataFrame"] { background-color: #161920; border-radius: 10px; padding: 10px; border: 1px solid #30363D; }
|
| 47 |
+
section[data-testid="stSidebar"] { background-color: #0E1015; border-right: 1px solid #222; }
|
| 48 |
+
</style>
|
| 49 |
+
""", unsafe_allow_html=True)
|
| 50 |
+
|
| 51 |
+
# --- 2. INITIALIZATION ---
|
| 52 |
+
@st.cache_resource
|
| 53 |
+
def get_core():
|
| 54 |
+
load_dotenv()
|
| 55 |
+
key = os.getenv("GEMINI_API_KEY")
|
| 56 |
+
return RAGManager(), SQLGenerator(key), DatabaseConnector()
|
| 57 |
+
|
| 58 |
+
try:
|
| 59 |
+
rag, sql_gen, db = get_core()
|
| 60 |
+
except Exception as e:
|
| 61 |
+
st.error(f"System Offline: {e}")
|
| 62 |
+
st.stop()
|
| 63 |
+
|
| 64 |
+
# --- 3. SIDEBAR ---
|
| 65 |
+
with st.sidebar:
|
| 66 |
+
st.markdown("## 🧠 NexusAI")
|
| 67 |
+
st.caption("Enterprise SQL Agent v2.0")
|
| 68 |
+
st.divider()
|
| 69 |
+
|
| 70 |
+
if db:
|
| 71 |
+
st.success("🟢 Database Connected")
|
| 72 |
+
|
| 73 |
+
st.markdown("### 📚 Quick Prompts")
|
| 74 |
+
prompts = [
|
| 75 |
+
"Top 5 employees by salary",
|
| 76 |
+
"Total sales revenue by Region",
|
| 77 |
+
"Show me products with low stock",
|
| 78 |
+
"Which department spends the most?"
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
for p in prompts:
|
| 82 |
+
if st.button(p, use_container_width=True):
|
| 83 |
+
st.session_state.last_prompt = p
|
| 84 |
+
|
| 85 |
+
if st.button("🗑️ Clear Context", type="primary", use_container_width=True):
|
| 86 |
+
st.session_state.messages = []
|
| 87 |
+
st.rerun()
|
| 88 |
+
|
| 89 |
+
# --- 4. MAIN INTERFACE ---
|
| 90 |
+
if "messages" not in st.session_state:
|
| 91 |
+
st.session_state.messages = []
|
| 92 |
+
|
| 93 |
+
if not st.session_state.messages:
|
| 94 |
+
st.markdown("""
|
| 95 |
+
<div style="text-align: center; margin-top: 100px;">
|
| 96 |
+
<h1 style="font-size: 3rem; background: -webkit-linear-gradient(#eee, #333); -webkit-background-clip: text; -webkit-text-fill-color: transparent;">
|
| 97 |
+
What can I help you analyze?
|
| 98 |
+
</h1>
|
| 99 |
+
<p style="color: #666;">Connect to your database and ask questions in plain English.</p>
|
| 100 |
+
</div>
|
| 101 |
+
""", unsafe_allow_html=True)
|
| 102 |
+
|
| 103 |
+
for msg in st.session_state.messages:
|
| 104 |
+
with st.chat_message(msg["role"], avatar="👤" if msg["role"] == "user" else "✨"):
|
| 105 |
+
st.markdown(msg["content"])
|
| 106 |
+
|
| 107 |
+
if "data" in msg:
|
| 108 |
+
# ✅ FIX: Switched to clean dataframe display
|
| 109 |
+
st.dataframe(msg["data"], hide_index=True)
|
| 110 |
+
if "chart" in msg:
|
| 111 |
+
st.bar_chart(msg["chart"])
|
| 112 |
+
if "sql" in msg:
|
| 113 |
+
with st.expander("🛠️ View Query Logic"):
|
| 114 |
+
st.code(msg["sql"], language="sql")
|
| 115 |
+
|
| 116 |
+
# Handle Input
|
| 117 |
+
user_input = st.chat_input("Ask anything...")
|
| 118 |
+
|
| 119 |
+
if "last_prompt" in st.session_state and st.session_state.last_prompt:
|
| 120 |
+
user_input = st.session_state.last_prompt
|
| 121 |
+
st.session_state.last_prompt = None
|
| 122 |
+
|
| 123 |
+
if user_input:
|
| 124 |
+
st.session_state.messages.append({"role": "user", "content": user_input})
|
| 125 |
+
with st.chat_message("user", avatar="👤"):
|
| 126 |
+
st.markdown(user_input)
|
| 127 |
+
|
| 128 |
+
with st.chat_message("assistant", avatar="✨"):
|
| 129 |
+
status_box = st.empty()
|
| 130 |
+
status_box.markdown("`⚡ analyzing...`")
|
| 131 |
+
|
| 132 |
+
try:
|
| 133 |
+
tables = rag.get_relevant_tables(user_input)
|
| 134 |
+
context = "\n".join(tables)
|
| 135 |
+
|
| 136 |
+
sql = sql_gen.generate_sql(user_input, context)
|
| 137 |
+
|
| 138 |
+
results = db.execute_sql(sql)
|
| 139 |
+
status_box.empty()
|
| 140 |
+
|
| 141 |
+
if not results:
|
| 142 |
+
response = "No data found matching that request."
|
| 143 |
+
st.markdown(response)
|
| 144 |
+
st.session_state.messages.append({"role": "assistant", "content": response, "sql": sql})
|
| 145 |
+
else:
|
| 146 |
+
df = pd.DataFrame(results)
|
| 147 |
+
df_clean = df.reset_index(drop=True)
|
| 148 |
+
|
| 149 |
+
response = f"Found **{len(df)}** records."
|
| 150 |
+
st.markdown(response)
|
| 151 |
+
# ✅ FIX: Updated dataframe display
|
| 152 |
+
st.dataframe(df_clean, hide_index=True)
|
| 153 |
+
|
| 154 |
+
chart_data = None
|
| 155 |
+
numeric_cols = df_clean.select_dtypes(include=['number']).columns
|
| 156 |
+
|
| 157 |
+
if not numeric_cols.empty and len(df_clean) > 1:
|
| 158 |
+
try:
|
| 159 |
+
non_numeric = df_clean.select_dtypes(exclude=['number']).columns
|
| 160 |
+
st.markdown("##### 📊 Trends")
|
| 161 |
+
if not non_numeric.empty:
|
| 162 |
+
x_axis = non_numeric[0]
|
| 163 |
+
y_axis = numeric_cols[0]
|
| 164 |
+
chart_data = df_clean.set_index(x_axis)[y_axis]
|
| 165 |
+
st.bar_chart(chart_data, color="#7B61FF")
|
| 166 |
+
else:
|
| 167 |
+
chart_data = df_clean[numeric_cols[0]]
|
| 168 |
+
st.bar_chart(chart_data, color="#7B61FF")
|
| 169 |
+
except Exception:
|
| 170 |
+
pass
|
| 171 |
+
|
| 172 |
+
st.session_state.messages.append({
|
| 173 |
+
"role": "assistant",
|
| 174 |
+
"content": response,
|
| 175 |
+
"data": df_clean,
|
| 176 |
+
"chart": chart_data,
|
| 177 |
+
"sql": sql
|
| 178 |
+
})
|
| 179 |
+
|
| 180 |
+
except Exception as e:
|
| 181 |
+
status_box.empty()
|
| 182 |
+
st.error(f"Error: {e}")
|
check_schema.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
from src.db_connector import Database
|
| 4 |
+
|
| 5 |
+
# Load your actual database connection
|
| 6 |
+
db = Database()
|
| 7 |
+
|
| 8 |
+
print("\n--- 🔍 REAL DATABASE SCHEMA ---")
|
| 9 |
+
try:
|
| 10 |
+
# Get all tables
|
| 11 |
+
tables = db.get_tables()
|
| 12 |
+
for table in tables:
|
| 13 |
+
print(f"\n📂 TABLE: {table}")
|
| 14 |
+
# Get columns for this table
|
| 15 |
+
columns = db.get_table_schema(table)
|
| 16 |
+
for col in columns:
|
| 17 |
+
print(f" - {col}")
|
| 18 |
+
print("\n-------------------------------")
|
| 19 |
+
except Exception as e:
|
| 20 |
+
print(f"❌ Error reading schema: {e}")
|
db.sql
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
create database chatbot;
|
| 2 |
+
use chatbot;
|
| 3 |
+
|
| 4 |
+
CREATE USER 'bot_user'@'%' IDENTIFIED BY 'YourSecurePassword123!';
|
| 5 |
+
GRANT SELECT ON chatbot.* TO 'bot_user'@'%';
|
| 6 |
+
FLUSH PRIVILEGES;
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
USE chatbot;
|
| 10 |
+
|
| 11 |
+
CREATE TABLE employees (
|
| 12 |
+
id INT AUTO_INCREMENT PRIMARY KEY,
|
| 13 |
+
name VARCHAR(100),
|
| 14 |
+
department VARCHAR(50),
|
| 15 |
+
salary DECIMAL(10,2),
|
| 16 |
+
hire_date DATE
|
| 17 |
+
);
|
| 18 |
+
|
| 19 |
+
CREATE TABLE sales (
|
| 20 |
+
sale_id INT AUTO_INCREMENT PRIMARY KEY,
|
| 21 |
+
employee_id INT,
|
| 22 |
+
amount DECIMAL(10,2),
|
| 23 |
+
sale_date DATE,
|
| 24 |
+
FOREIGN KEY (employee_id) REFERENCES employees(id)
|
| 25 |
+
);
|
| 26 |
+
|
| 27 |
+
-- Insert a little dummy data
|
| 28 |
+
INSERT INTO employees (name, department, salary, hire_date) VALUES
|
| 29 |
+
('Alice', 'Sales', 70000, '2023-01-15'),
|
| 30 |
+
('Bob', 'Engineering', 90000, '2022-05-20');
|
| 31 |
+
|
| 32 |
+
INSERT INTO sales (employee_id, amount, sale_date) VALUES
|
| 33 |
+
(1, 500.00, '2023-06-01'),
|
| 34 |
+
(1, 1200.50, '2023-06-03');
|
db_connector.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pymysql
|
| 2 |
+
import os
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
from urllib.parse import urlparse, unquote
|
| 5 |
+
|
| 6 |
+
# ✅ FIX 1: Class renamed to 'Database' (matches api_server.py)
|
| 7 |
+
class Database:
|
| 8 |
+
def __init__(self):
|
| 9 |
+
load_dotenv()
|
| 10 |
+
db_uri = os.getenv("DB_URI")
|
| 11 |
+
|
| 12 |
+
if not db_uri:
|
| 13 |
+
raise ValueError("❌ DB_URI is missing from .env file")
|
| 14 |
+
|
| 15 |
+
parsed = urlparse(db_uri)
|
| 16 |
+
self.host = parsed.hostname
|
| 17 |
+
self.port = parsed.port or 3306
|
| 18 |
+
self.user = parsed.username
|
| 19 |
+
self.password = unquote(parsed.password)
|
| 20 |
+
self.db_name = parsed.path[1:]
|
| 21 |
+
|
| 22 |
+
def get_connection(self):
|
| 23 |
+
return pymysql.connect(
|
| 24 |
+
host=self.host,
|
| 25 |
+
user=self.user,
|
| 26 |
+
password=self.password,
|
| 27 |
+
database=self.db_name,
|
| 28 |
+
port=self.port,
|
| 29 |
+
cursorclass=pymysql.cursors.DictCursor
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
# ✅ FIX 2: Method renamed to 'run_query' (matches api_server.py)
|
| 33 |
+
def run_query(self, query):
|
| 34 |
+
conn = self.get_connection()
|
| 35 |
+
try:
|
| 36 |
+
with conn.cursor() as cursor:
|
| 37 |
+
cursor.execute(query)
|
| 38 |
+
return cursor.fetchall()
|
| 39 |
+
except Exception as e:
|
| 40 |
+
return [f"Error: {e}"]
|
| 41 |
+
finally:
|
| 42 |
+
conn.close()
|
| 43 |
+
|
| 44 |
+
def get_tables(self):
|
| 45 |
+
"""Returns a list of all table names."""
|
| 46 |
+
conn = self.get_connection()
|
| 47 |
+
try:
|
| 48 |
+
with conn.cursor() as cursor:
|
| 49 |
+
cursor.execute("SHOW TABLES")
|
| 50 |
+
return [list(row.values())[0] for row in cursor.fetchall()]
|
| 51 |
+
except Exception as e:
|
| 52 |
+
return []
|
| 53 |
+
finally:
|
| 54 |
+
conn.close()
|
| 55 |
+
|
| 56 |
+
def get_table_schema(self, table_name):
|
| 57 |
+
"""Returns column details for a specific table."""
|
| 58 |
+
conn = self.get_connection()
|
| 59 |
+
try:
|
| 60 |
+
with conn.cursor() as cursor:
|
| 61 |
+
cursor.execute(f"DESCRIBE {table_name}")
|
| 62 |
+
columns = []
|
| 63 |
+
for row in cursor.fetchall():
|
| 64 |
+
columns.append(f"{row['Field']} ({row['Type']})")
|
| 65 |
+
return columns
|
| 66 |
+
except Exception:
|
| 67 |
+
return []
|
| 68 |
+
finally:
|
| 69 |
+
conn.close()
|
| 70 |
+
|
| 71 |
+
# ✅ FIX 3: Added 'get_schema()' (no args) for the RAG system
|
| 72 |
+
def get_schema(self):
|
| 73 |
+
"""Generates a full text schema of the database for the AI."""
|
| 74 |
+
tables = self.get_tables()
|
| 75 |
+
schema_text = ""
|
| 76 |
+
|
| 77 |
+
for table in tables:
|
| 78 |
+
columns = self.get_table_schema(table)
|
| 79 |
+
schema_text += f"Table: {table}\nColumns:\n"
|
| 80 |
+
for col in columns:
|
| 81 |
+
schema_text += f" - {col}\n"
|
| 82 |
+
schema_text += "\n"
|
| 83 |
+
|
| 84 |
+
return schema_text
|
index.html
ADDED
|
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>PlainSQL | Enterprise Data Assistant</title>
|
| 7 |
+
<script src="https://cdn.tailwindcss.com"></script>
|
| 8 |
+
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
|
| 9 |
+
<link href="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/themes/prism-tomorrow.min.css" rel="stylesheet" />
|
| 10 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/prism.min.js"></script>
|
| 11 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/prism/1.29.0/components/prism-sql.min.js"></script>
|
| 12 |
+
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&family=JetBrains+Mono:wght@400;500&display=swap" rel="stylesheet">
|
| 13 |
+
|
| 14 |
+
<script>
|
| 15 |
+
tailwind.config = {
|
| 16 |
+
theme: {
|
| 17 |
+
extend: {
|
| 18 |
+
fontFamily: { sans: ['Inter', 'sans-serif'], mono: ['JetBrains Mono', 'monospace'] },
|
| 19 |
+
colors: {
|
| 20 |
+
dark: { 900: '#0B0C10', 800: '#15171E', 700: '#1F222E' },
|
| 21 |
+
brand: { 500: '#38bdf8', 400: '#0ea5e9' }
|
| 22 |
+
}
|
| 23 |
+
}
|
| 24 |
+
}
|
| 25 |
+
}
|
| 26 |
+
</script>
|
| 27 |
+
<style>
|
| 28 |
+
body { background-color: #0B0C10; color: #E2E8F0; }
|
| 29 |
+
.scrollbar-custom::-webkit-scrollbar { width: 8px; }
|
| 30 |
+
.scrollbar-custom::-webkit-scrollbar-track { background: #15171E; }
|
| 31 |
+
.scrollbar-custom::-webkit-scrollbar-thumb { background: #334155; border-radius: 4px; }
|
| 32 |
+
|
| 33 |
+
.bubble-user { background: linear-gradient(135deg, #38bdf8 0%, #0284c7 100%); color: white; border-radius: 16px 16px 4px 16px; }
|
| 34 |
+
.bubble-ai { background: #1F222E; border: 1px solid #2D3142; border-radius: 16px 16px 16px 4px; }
|
| 35 |
+
|
| 36 |
+
.custom-table th { background: #262A3B; color: #94A3B8; font-size: 0.75rem; padding: 12px; text-transform: uppercase; letter-spacing: 0.05em; }
|
| 37 |
+
.custom-table td { border-bottom: 1px solid #1F222E; color: #CBD5E1; padding: 12px; font-size: 0.9rem; }
|
| 38 |
+
.custom-table tr:last-child td { border-bottom: none; }
|
| 39 |
+
|
| 40 |
+
@keyframes fadeIn { from { opacity: 0; transform: translateY(10px); } to { opacity: 1; transform: translateY(0); } }
|
| 41 |
+
.animate-fade { animation: fadeIn 0.3s ease-out forwards; }
|
| 42 |
+
.typing-cursor::after { content: '▋'; animation: blink 1s step-start infinite; color: #38bdf8; margin-left: 2px; }
|
| 43 |
+
@keyframes blink { 50% { opacity: 0; } }
|
| 44 |
+
|
| 45 |
+
.scrollbar-hide::-webkit-scrollbar { display: none; }
|
| 46 |
+
.scrollbar-hide { -ms-overflow-style: none; scrollbar-width: none; }
|
| 47 |
+
</style>
|
| 48 |
+
</head>
|
| 49 |
+
<body class="flex h-screen overflow-hidden selection:bg-brand-500 selection:text-white relative font-sans">
|
| 50 |
+
|
| 51 |
+
<audio id="sound-welcome" src="https://assets.mixkit.co/active_storage/sfx/2568/2568-preview.mp3" preload="auto"></audio>
|
| 52 |
+
<audio id="sound-message" src="https://assets.mixkit.co/active_storage/sfx/2346/2346-preview.mp3" preload="auto"></audio>
|
| 53 |
+
|
| 54 |
+
<div id="splash-screen" class="fixed inset-0 z-[100] bg-dark-900 flex flex-col items-center justify-center transition-opacity duration-700">
|
| 55 |
+
<div class="text-center space-y-6 animate-fade">
|
| 56 |
+
<div class="w-20 h-20 bg-brand-500 rounded-2xl flex items-center justify-center shadow-2xl shadow-brand-500/50 mx-auto mb-6">
|
| 57 |
+
<svg class="w-10 h-10 text-white" fill="none" viewBox="0 0 24 24" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 7v10c0 2.21 3.58 4 8 4s8-1.79 8-4V7M4 7c0 2.21 3.58 4 8 4s8-1.79 8-4M4 7c0-2.21 3.58-4 8-4s8 1.79 8 4m0 5c0 2.21-3.58 4-8 4s-8-1.79-8-4"/></svg>
|
| 58 |
+
</div>
|
| 59 |
+
<h1 class="text-4xl font-bold text-white tracking-tight">Plain<span class="text-brand-400">SQL</span></h1>
|
| 60 |
+
<p class="text-gray-400 text-sm">Enterprise Text-to-SQL Engine</p>
|
| 61 |
+
<button onclick="enterSystem()" class="mt-8 px-8 py-3 bg-white text-dark-900 font-bold rounded-full hover:bg-gray-100 transition-all shadow-lg hover:scale-105 transform">
|
| 62 |
+
Initialize System
|
| 63 |
+
</button>
|
| 64 |
+
</div>
|
| 65 |
+
</div>
|
| 66 |
+
|
| 67 |
+
<div id="chart-modal" class="fixed inset-0 z-50 flex items-center justify-center bg-black/80 backdrop-blur-sm hidden opacity-0 transition-opacity duration-300">
|
| 68 |
+
<div class="bg-dark-800 border border-dark-700 w-full max-w-4xl h-[600px] rounded-2xl shadow-2xl flex flex-col p-6 transform scale-95 transition-transform duration-300" id="chart-content">
|
| 69 |
+
<div class="flex justify-between items-center mb-4">
|
| 70 |
+
<h3 class="text-lg font-semibold text-white">Data Visualization</h3>
|
| 71 |
+
<button onclick="closeChart()" class="text-gray-500 hover:text-white p-2 text-xl">✕</button>
|
| 72 |
+
</div>
|
| 73 |
+
<div class="flex-1 relative w-full h-full"><canvas id="myChart"></canvas></div>
|
| 74 |
+
</div>
|
| 75 |
+
</div>
|
| 76 |
+
|
| 77 |
+
<aside class="w-64 bg-dark-800 border-r border-dark-700 flex flex-col hidden md:flex">
|
| 78 |
+
<div class="p-6 flex items-center gap-3">
|
| 79 |
+
<div class="w-8 h-8 bg-brand-500 rounded-lg flex items-center justify-center shadow-lg">
|
| 80 |
+
<svg class="w-5 h-5 text-white" fill="none" stroke="currentColor" viewBox="0 0 24 24"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 7v10c0 2.21 3.58 4 8 4s8-1.79 8-4V7M4 7c0 2.21 3.58 4 8 4s8-1.79 8-4M4 7c0-2.21 3.58-4 8-4s8 1.79 8 4m0 5c0 2.21-3.58 4-8 4s-8-1.79-8-4"/></svg>
|
| 81 |
+
</div>
|
| 82 |
+
<span class="font-bold text-xl tracking-tight">PlainSQL</span>
|
| 83 |
+
</div>
|
| 84 |
+
<nav class="flex-1 px-4 space-y-2 overflow-y-auto scrollbar-custom">
|
| 85 |
+
<button onclick="window.location.reload()" class="w-full flex items-center gap-3 px-4 py-3 bg-brand-500/10 text-brand-400 rounded-xl border border-brand-500/20 hover:bg-brand-500/20 transition-all group">
|
| 86 |
+
<svg class="w-5 h-5 group-hover:rotate-90 transition-transform" fill="none" viewBox="0 0 24 24" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M12 6v6m0 0v6m0-6h6m-6 0H6"/></svg>
|
| 87 |
+
<span class="text-sm font-medium">New Analysis</span>
|
| 88 |
+
</button>
|
| 89 |
+
<div class="mt-8">
|
| 90 |
+
<p class="px-4 text-xs font-semibold text-gray-500 uppercase tracking-wider mb-2">Recent Queries</p>
|
| 91 |
+
<div id="history-list" class="space-y-1"></div>
|
| 92 |
+
</div>
|
| 93 |
+
</nav>
|
| 94 |
+
<div class="p-4 border-t border-dark-700">
|
| 95 |
+
<div class="flex items-center gap-3 px-3 py-2 bg-green-500/10 rounded-lg border border-green-500/20">
|
| 96 |
+
<div class="w-2 h-2 bg-green-500 rounded-full animate-pulse"></div>
|
| 97 |
+
<div class="flex flex-col">
|
| 98 |
+
<span class="text-[11px] font-bold text-green-400 uppercase">System Online</span>
|
| 99 |
+
<span class="text-[10px] text-gray-500">Read-Only Mode</span>
|
| 100 |
+
</div>
|
| 101 |
+
</div>
|
| 102 |
+
</div>
|
| 103 |
+
</aside>
|
| 104 |
+
|
| 105 |
+
<main class="flex-1 flex flex-col relative bg-dark-900">
|
| 106 |
+
<div id="chat-box" class="flex-1 overflow-y-auto p-4 md:p-8 space-y-6 pb-80 scroll-smooth scrollbar-custom">
|
| 107 |
+
<div class="flex gap-4 max-w-3xl mx-auto animate-fade">
|
| 108 |
+
<div class="w-8 h-8 rounded-full bg-brand-500/20 flex items-center justify-center flex-shrink-0 border border-brand-500/30 text-brand-400">
|
| 109 |
+
<svg class="w-5 h-5" fill="none" viewBox="0 0 24 24" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M13 10V3L4 14h7v7l9-11h-7z"/></svg>
|
| 110 |
+
</div>
|
| 111 |
+
<div class="bubble-ai p-5 shadow-sm">
|
| 112 |
+
<p class="text-sm leading-relaxed text-gray-200">
|
| 113 |
+
Hello. I am <strong>PlainSQL</strong>. I can access your database securely to fetch real-time insights.<br><br>
|
| 114 |
+
<em>Try asking: "Show me top 5 employees by salary" or "List all active users."</em>
|
| 115 |
+
</p>
|
| 116 |
+
</div>
|
| 117 |
+
</div>
|
| 118 |
+
</div>
|
| 119 |
+
|
| 120 |
+
<div class="absolute bottom-0 w-full p-4 md:p-6 bg-gradient-to-t from-dark-900 via-dark-900 to-transparent z-20">
|
| 121 |
+
<div class="max-w-3xl mx-auto">
|
| 122 |
+
<div id="suggestions" class="flex gap-2 mb-3 overflow-x-auto pb-1 scrollbar-hide"></div>
|
| 123 |
+
<form id="chat-form" class="relative group">
|
| 124 |
+
<div class="absolute inset-0 bg-brand-500/20 rounded-2xl blur-lg group-hover:bg-brand-500/30 transition-all opacity-0 group-hover:opacity-100"></div>
|
| 125 |
+
<input type="text" id="user-input"
|
| 126 |
+
class="relative w-full bg-dark-800 text-white border border-dark-700 rounded-2xl py-4 pl-5 pr-14 focus:outline-none focus:border-brand-500 focus:ring-1 focus:ring-brand-500 transition-all placeholder-gray-500 shadow-xl"
|
| 127 |
+
placeholder="Ask a question in plain English..." autocomplete="off">
|
| 128 |
+
<button type="submit" class="absolute right-2 top-2 p-2 bg-brand-500 hover:bg-brand-400 text-white rounded-xl transition-all shadow-lg active:scale-95">
|
| 129 |
+
<svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M14 5l7 7m0 0l-7 7m7-7H3"/></svg>
|
| 130 |
+
</button>
|
| 131 |
+
</form>
|
| 132 |
+
<div class="text-center mt-2">
|
| 133 |
+
<p class="text-[10px] text-gray-600">AI Generated SQL can be inaccurate. Always verify important data.</p>
|
| 134 |
+
</div>
|
| 135 |
+
</div>
|
| 136 |
+
</div>
|
| 137 |
+
</main>
|
| 138 |
+
|
| 139 |
+
<script>
|
| 140 |
+
const API_URL = "http://127.0.0.1:8000/chat";
|
| 141 |
+
|
| 142 |
+
const form = document.getElementById('chat-form');
|
| 143 |
+
const input = document.getElementById('user-input');
|
| 144 |
+
const chatBox = document.getElementById('chat-box');
|
| 145 |
+
const suggestionsBox = document.getElementById('suggestions');
|
| 146 |
+
const historyList = document.getElementById('history-list');
|
| 147 |
+
const soundWelcome = document.getElementById('sound-welcome');
|
| 148 |
+
const soundMessage = document.getElementById('sound-message');
|
| 149 |
+
const splashScreen = document.getElementById('splash-screen');
|
| 150 |
+
const modal = document.getElementById('chart-modal');
|
| 151 |
+
const modalContent = document.getElementById('chart-content');
|
| 152 |
+
|
| 153 |
+
let chartInstance = null;
|
| 154 |
+
let conversationHistory = [];
|
| 155 |
+
|
| 156 |
+
function unlockAudio() {
|
| 157 |
+
soundMessage.volume = 0;
|
| 158 |
+
soundMessage.play().then(() => {
|
| 159 |
+
soundMessage.pause();
|
| 160 |
+
soundMessage.currentTime = 0;
|
| 161 |
+
}).catch(() => {});
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
function playIncomingSound() {
|
| 165 |
+
soundMessage.volume = 0.4;
|
| 166 |
+
soundMessage.currentTime = 0;
|
| 167 |
+
soundMessage.play().catch(e => console.warn("Audio blocked:", e));
|
| 168 |
+
}
|
| 169 |
+
|
| 170 |
+
function enterSystem() {
|
| 171 |
+
soundWelcome.volume = 0.5;
|
| 172 |
+
soundWelcome.play().catch(e => console.log("Init Audio Error:", e));
|
| 173 |
+
splashScreen.style.opacity = '0';
|
| 174 |
+
setTimeout(() => { splashScreen.style.display = 'none'; }, 700);
|
| 175 |
+
setTimeout(() => input.focus(), 800);
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
function checkGreeting(text) {
|
| 179 |
+
const t = text.toLowerCase();
|
| 180 |
+
const greetings = ['hello', 'hi', 'hey', 'good morning', 'good afternoon', 'hola'];
|
| 181 |
+
|
| 182 |
+
if (greetings.some(g => t === g || t.startsWith(g + ' '))) {
|
| 183 |
+
return "Hello! 👋 I am **PlainSQL**, your data assistant. I'm here to help you query your database without writing code.<br><br>You can ask me things like: <em>'Show me top 5 employees by salary'</em> or <em>'List all active users.'</em>";
|
| 184 |
+
}
|
| 185 |
+
return null;
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
// FIX 2: IMPROVED SCROLL FUNCTION
|
| 189 |
+
// Uses setTimeout to ensure DOM is fully rendered before scrolling
|
| 190 |
+
function scrollToBottom() {
|
| 191 |
+
setTimeout(() => {
|
| 192 |
+
chatBox.scrollTop = chatBox.scrollHeight;
|
| 193 |
+
}, 50);
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
async function typeText(element, text) {
|
| 197 |
+
element.classList.add('typing-cursor');
|
| 198 |
+
return new Promise(resolve => {
|
| 199 |
+
setTimeout(() => {
|
| 200 |
+
element.innerHTML = text;
|
| 201 |
+
element.classList.remove('typing-cursor');
|
| 202 |
+
resolve();
|
| 203 |
+
}, 300 + Math.random() * 500);
|
| 204 |
+
});
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
function showLoading() {
|
| 208 |
+
const id = 'loading-' + Date.now();
|
| 209 |
+
const html = `
|
| 210 |
+
<div id="${id}" class="flex gap-4 max-w-3xl mx-auto animate-fade">
|
| 211 |
+
<div class="w-8 h-8 rounded-full bg-brand-500/20 flex items-center justify-center text-brand-400">
|
| 212 |
+
<svg class="w-4 h-4 animate-spin" fill="none" viewBox="0 0 24 24"><circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle><path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>
|
| 213 |
+
</div>
|
| 214 |
+
<div class="bubble-ai p-4 flex gap-2 items-center">
|
| 215 |
+
<span class="text-xs text-gray-400">Analyzing database schema...</span>
|
| 216 |
+
</div>
|
| 217 |
+
</div>`;
|
| 218 |
+
chatBox.insertAdjacentHTML('beforeend', html);
|
| 219 |
+
scrollToBottom();
|
| 220 |
+
return id;
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
async function appendMessage(role, content, isHtml = false) {
|
| 224 |
+
const wrapper = document.createElement('div');
|
| 225 |
+
wrapper.className = `flex gap-4 max-w-3xl mx-auto animate-fade ${role === 'user' ? 'justify-end' : ''}`;
|
| 226 |
+
|
| 227 |
+
const aiIcon = `<div class="w-8 h-8 rounded-full bg-brand-500/20 flex items-center justify-center flex-shrink-0 border border-brand-500/30 text-brand-400"><svg class="w-5 h-5" fill="none" viewBox="0 0 24 24" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M13 10V3L4 14h7v7l9-11h-7z"/></svg></div>`;
|
| 228 |
+
const userIcon = `<div class="w-8 h-8 rounded-full bg-gray-700 flex items-center justify-center flex-shrink-0 text-gray-300"><svg class="w-5 h-5" fill="none" viewBox="0 0 24 24" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M16 7a4 4 0 11-8 0 4 4 0 018 0zM12 14a7 7 0 00-7 7h14a7 7 0 00-7-7z"/></svg></div>`;
|
| 229 |
+
|
| 230 |
+
const bubble = document.createElement('div');
|
| 231 |
+
bubble.className = role === 'user' ? 'bubble-user p-4 max-w-[85%] shadow-lg text-sm' : 'bubble-ai p-5 max-w-[90%] shadow-sm w-full text-sm leading-relaxed';
|
| 232 |
+
|
| 233 |
+
if (role === 'user') {
|
| 234 |
+
bubble.textContent = content;
|
| 235 |
+
wrapper.innerHTML = bubble.outerHTML + userIcon;
|
| 236 |
+
} else {
|
| 237 |
+
wrapper.innerHTML = aiIcon + bubble.outerHTML;
|
| 238 |
+
}
|
| 239 |
+
chatBox.appendChild(wrapper);
|
| 240 |
+
|
| 241 |
+
if (role === 'ai') {
|
| 242 |
+
playIncomingSound();
|
| 243 |
+
if (isHtml) {
|
| 244 |
+
wrapper.querySelector('.bubble-ai').innerHTML = content;
|
| 245 |
+
Prism.highlightAll();
|
| 246 |
+
} else {
|
| 247 |
+
await typeText(wrapper.querySelector('.bubble-ai'), content);
|
| 248 |
+
}
|
| 249 |
+
}
|
| 250 |
+
scrollToBottom();
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
function renderSuggestions(questions) {
|
| 254 |
+
suggestionsBox.innerHTML = '';
|
| 255 |
+
if (!questions) return;
|
| 256 |
+
questions.forEach((q, i) => {
|
| 257 |
+
const btn = document.createElement('button');
|
| 258 |
+
btn.className = "whitespace-nowrap px-4 py-1.5 bg-dark-800 border border-dark-700 hover:border-brand-500 hover:text-brand-400 text-gray-400 text-xs font-medium rounded-full transition-all animate-fade";
|
| 259 |
+
btn.style.animationDelay = `${i * 0.1}s`;
|
| 260 |
+
btn.textContent = q;
|
| 261 |
+
btn.onclick = () => { input.value = q; form.dispatchEvent(new Event('submit')); };
|
| 262 |
+
suggestionsBox.appendChild(btn);
|
| 263 |
+
});
|
| 264 |
+
scrollToBottom();
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
function addToHistory(question) {
|
| 268 |
+
const btn = document.createElement('button');
|
| 269 |
+
btn.className = "w-full text-left px-4 py-2 text-xs text-gray-400 hover:text-white hover:bg-dark-700 rounded-lg transition-colors truncate animate-fade";
|
| 270 |
+
btn.textContent = question;
|
| 271 |
+
btn.onclick = () => { input.value = question; form.dispatchEvent(new Event('submit')); };
|
| 272 |
+
historyList.prepend(btn);
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
// FIX 3: EXPORT CSV FUNCTION
|
| 276 |
+
function downloadCSV(dataString) {
|
| 277 |
+
try {
|
| 278 |
+
const data = JSON.parse(decodeURIComponent(dataString));
|
| 279 |
+
if (!data || !data.length) return;
|
| 280 |
+
|
| 281 |
+
const headers = Object.keys(data[0]);
|
| 282 |
+
const csvRows = [];
|
| 283 |
+
csvRows.push(headers.join(','));
|
| 284 |
+
|
| 285 |
+
for (const row of data) {
|
| 286 |
+
const values = headers.map(header => {
|
| 287 |
+
const escaped = ('' + row[header]).replace(/"/g, '\\"');
|
| 288 |
+
return `"${escaped}"`;
|
| 289 |
+
});
|
| 290 |
+
csvRows.push(values.join(','));
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
const blob = new Blob([csvRows.join('\n')], { type: 'text/csv' });
|
| 294 |
+
const url = window.URL.createObjectURL(blob);
|
| 295 |
+
const a = document.createElement('a');
|
| 296 |
+
a.setAttribute('hidden', '');
|
| 297 |
+
a.setAttribute('href', url);
|
| 298 |
+
a.setAttribute('download', 'data_export.csv');
|
| 299 |
+
document.body.appendChild(a);
|
| 300 |
+
a.click();
|
| 301 |
+
document.body.removeChild(a);
|
| 302 |
+
} catch (e) {
|
| 303 |
+
console.error("Export failed", e);
|
| 304 |
+
}
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
form.addEventListener('submit', async (e) => {
|
| 308 |
+
e.preventDefault();
|
| 309 |
+
const question = input.value.trim();
|
| 310 |
+
if (!question) return;
|
| 311 |
+
|
| 312 |
+
unlockAudio();
|
| 313 |
+
suggestionsBox.innerHTML = '';
|
| 314 |
+
input.value = '';
|
| 315 |
+
|
| 316 |
+
await appendMessage('user', question);
|
| 317 |
+
addToHistory(question);
|
| 318 |
+
|
| 319 |
+
const greetingResponse = checkGreeting(question);
|
| 320 |
+
if (greetingResponse) {
|
| 321 |
+
setTimeout(() => appendMessage('ai', greetingResponse, true), 500);
|
| 322 |
+
return;
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
const loadingId = showLoading();
|
| 326 |
+
|
| 327 |
+
try {
|
| 328 |
+
const res = await fetch(`${API_URL}?ts=${Date.now()}`, {
|
| 329 |
+
method: 'POST',
|
| 330 |
+
headers: { 'Content-Type': 'application/json' },
|
| 331 |
+
body: JSON.stringify({ question, history: conversationHistory })
|
| 332 |
+
});
|
| 333 |
+
|
| 334 |
+
if (!res.ok) throw new Error("Backend connection failed.");
|
| 335 |
+
|
| 336 |
+
const data = await res.json();
|
| 337 |
+
document.getElementById(loadingId).remove();
|
| 338 |
+
|
| 339 |
+
if (data.sql && !data.sql.includes("Error")) {
|
| 340 |
+
conversationHistory.push({ user: question, sql: data.sql });
|
| 341 |
+
if(conversationHistory.length > 5) conversationHistory.shift();
|
| 342 |
+
}
|
| 343 |
+
|
| 344 |
+
let content = "";
|
| 345 |
+
if (data.message) content += `<div class="mb-4">${data.message}</div>`;
|
| 346 |
+
|
| 347 |
+
if (Array.isArray(data.answer) && data.answer.length > 0) {
|
| 348 |
+
const firstRow = data.answer[0];
|
| 349 |
+
if (typeof firstRow === 'string' && (firstRow.toLowerCase().includes("error"))) {
|
| 350 |
+
content += `<div class="p-4 bg-red-500/10 border border-red-500/20 rounded-lg text-red-400 font-mono text-xs mb-3">⚠️ ${firstRow}</div>`;
|
| 351 |
+
} else {
|
| 352 |
+
const headers = Object.keys(firstRow);
|
| 353 |
+
const dataStr = encodeURIComponent(JSON.stringify(data.answer));
|
| 354 |
+
|
| 355 |
+
content += `
|
| 356 |
+
<div class="overflow-hidden rounded-xl border border-dark-700 shadow-xl mb-3 bg-[#15171E]">
|
| 357 |
+
<div class="overflow-x-auto">
|
| 358 |
+
<table class="w-full text-left custom-table">
|
| 359 |
+
<thead><tr>${headers.map(h => `<th>${h}</th>`).join('')}</tr></thead>
|
| 360 |
+
<tbody>${data.answer.map(row => `<tr>${headers.map(h => `<td>${row[h]}</td>`).join('')}</tr>`).join('')}</tbody>
|
| 361 |
+
</table>
|
| 362 |
+
</div>
|
| 363 |
+
</div>
|
| 364 |
+
<div class="flex gap-2 mb-4">
|
| 365 |
+
<button onclick="openChart('${dataStr}')" class="flex items-center gap-2 px-3 py-2 bg-brand-500 text-white text-xs font-bold rounded-lg hover:bg-brand-400 transition-colors">
|
| 366 |
+
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M11 3.055A9.001 9.001 0 1020.945 13H11V3.055z"/></svg>
|
| 367 |
+
Visualize
|
| 368 |
+
</button>
|
| 369 |
+
<button onclick="downloadCSV('${dataStr}')" class="flex items-center gap-2 px-3 py-2 bg-dark-700 border border-dark-600 text-gray-300 text-xs font-bold rounded-lg hover:bg-dark-600 transition-colors">
|
| 370 |
+
<svg class="w-4 h-4" fill="none" viewBox="0 0 24 24" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M4 16v1a3 3 0 003 3h10a3 3 0 003-3v-1m-4-4l-4 4m0 0l-4-4m4 4V4"/></svg>
|
| 371 |
+
Export CSV
|
| 372 |
+
</button>
|
| 373 |
+
</div>
|
| 374 |
+
`;
|
| 375 |
+
}
|
| 376 |
+
} else if (Array.isArray(data.answer) && data.answer.length === 0) {
|
| 377 |
+
content += `<div class="p-4 bg-yellow-500/5 border border-yellow-500/10 rounded-lg text-yellow-500/80 text-xs mb-3">No records found matching query.</div>`;
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
if (data.sql) {
|
| 381 |
+
content += `
|
| 382 |
+
<div class="relative group mt-2">
|
| 383 |
+
<div class="absolute -top-3 left-3 bg-dark-700 px-2 text-[10px] text-gray-400 rounded border border-dark-600">Generated SQL</div>
|
| 384 |
+
<pre class="!m-0 !p-4 !bg-[#0d1117] !text-xs rounded-xl border border-dark-700/50"><code class="language-sql">${data.sql}</code></pre>
|
| 385 |
+
</div>`;
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
if (content) await appendMessage('ai', content, true);
|
| 389 |
+
if (data.follow_ups) renderSuggestions(data.follow_ups);
|
| 390 |
+
|
| 391 |
+
} catch (err) {
|
| 392 |
+
document.getElementById(loadingId)?.remove();
|
| 393 |
+
await appendMessage('ai', `<div class="p-4 bg-red-900/20 border border-red-500/30 rounded-lg text-red-400 text-sm">Error: ${err.message}</div>`, true);
|
| 394 |
+
}
|
| 395 |
+
});
|
| 396 |
+
|
| 397 |
+
function openChart(dataString) {
|
| 398 |
+
const data = JSON.parse(decodeURIComponent(dataString));
|
| 399 |
+
const headers = Object.keys(data[0]);
|
| 400 |
+
let labelKey = headers.find(h => isNaN(data[0][h])) || headers[0];
|
| 401 |
+
let valueKey = headers.find(h => !isNaN(data[0][h]) && h !== labelKey) || headers[1];
|
| 402 |
+
const labels = data.map(row => row[labelKey]);
|
| 403 |
+
const values = data.map(row => row[valueKey]);
|
| 404 |
+
|
| 405 |
+
modal.classList.remove('hidden');
|
| 406 |
+
setTimeout(() => { modal.classList.remove('opacity-0'); modalContent.classList.remove('scale-95'); modalContent.classList.add('scale-100'); }, 10);
|
| 407 |
+
if (chartInstance) chartInstance.destroy();
|
| 408 |
+
|
| 409 |
+
const ctx = document.getElementById('myChart').getContext('2d');
|
| 410 |
+
chartInstance = new Chart(ctx, {
|
| 411 |
+
type: labels.length > 8 ? 'bar' : 'doughnut',
|
| 412 |
+
data: {
|
| 413 |
+
labels: labels,
|
| 414 |
+
datasets: [{
|
| 415 |
+
label: valueKey.toUpperCase(),
|
| 416 |
+
data: values,
|
| 417 |
+
backgroundColor: ['#38bdf8', '#a855f7', '#ec4899', '#22c55e', '#eab308'],
|
| 418 |
+
borderColor: '#15171E', borderWidth: 2
|
| 419 |
+
}]
|
| 420 |
+
},
|
| 421 |
+
options: { responsive: true, maintainAspectRatio: false, plugins: { legend: { position: 'bottom', labels: { color: '#94A3B8' } } }, scales: { x: { display: false }, y: { display: false } } }
|
| 422 |
+
});
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
function closeChart() {
|
| 426 |
+
modal.classList.add('opacity-0'); modalContent.classList.remove('scale-100'); modalContent.classList.add('scale-95'); setTimeout(() => { modal.classList.add('hidden'); }, 300);
|
| 427 |
+
}
|
| 428 |
+
</script>
|
| 429 |
+
</body>
|
| 430 |
+
</html>
|
rag_manager.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import chromadb
|
| 2 |
+
from src.db_connector import Database
|
| 3 |
+
|
| 4 |
+
class RAGSystem:
|
| 5 |
+
def __init__(self, db_instance=None):
|
| 6 |
+
# ✅ FIX: Accept the DB connection passed from the server
|
| 7 |
+
if db_instance:
|
| 8 |
+
self.db = db_instance
|
| 9 |
+
else:
|
| 10 |
+
self.db = Database()
|
| 11 |
+
|
| 12 |
+
# Initialize ChromaDB
|
| 13 |
+
print(" ...Connecting to ChromaDB")
|
| 14 |
+
self.client = chromadb.PersistentClient(path="./chroma_db")
|
| 15 |
+
self.collection = self.client.get_or_create_collection(name="schema_knowledge")
|
| 16 |
+
|
| 17 |
+
# Refresh memory
|
| 18 |
+
self._index_schema()
|
| 19 |
+
|
| 20 |
+
def _index_schema(self):
|
| 21 |
+
"""Reads the database structure and saves it to ChromaDB."""
|
| 22 |
+
try:
|
| 23 |
+
tables = self.db.get_tables()
|
| 24 |
+
|
| 25 |
+
if self.collection.count() > 0:
|
| 26 |
+
existing_ids = self.collection.get()['ids']
|
| 27 |
+
if existing_ids:
|
| 28 |
+
self.collection.delete(ids=existing_ids)
|
| 29 |
+
|
| 30 |
+
for table in tables:
|
| 31 |
+
columns = self.db.get_table_schema(table)
|
| 32 |
+
col_list = []
|
| 33 |
+
for col in columns:
|
| 34 |
+
if isinstance(col, dict):
|
| 35 |
+
col_list.append(f"{col['name']} ({col['type']})")
|
| 36 |
+
else:
|
| 37 |
+
col_list.append(str(col))
|
| 38 |
+
|
| 39 |
+
schema_text = f"Table: {table}\nColumns: {', '.join(col_list)}"
|
| 40 |
+
self.collection.add(
|
| 41 |
+
documents=[schema_text],
|
| 42 |
+
metadatas=[{"table": table}],
|
| 43 |
+
ids=[table]
|
| 44 |
+
)
|
| 45 |
+
print(f" ✅ RAG System: Indexed {len(tables)} tables.")
|
| 46 |
+
|
| 47 |
+
except Exception as e:
|
| 48 |
+
print(f" ⚠️ RAG Indexing Warning: {e}")
|
| 49 |
+
|
| 50 |
+
def get_relevant_schema(self, question):
|
| 51 |
+
try:
|
| 52 |
+
results = self.collection.query(query_texts=[question], n_results=3)
|
| 53 |
+
if results['documents']:
|
| 54 |
+
return "\n\n".join(results['documents'][0])
|
| 55 |
+
return ""
|
| 56 |
+
except Exception:
|
| 57 |
+
return self._get_full_schema_fallback()
|
| 58 |
+
|
| 59 |
+
def _get_full_schema_fallback(self):
|
| 60 |
+
tables = self.db.get_tables()
|
| 61 |
+
schema = []
|
| 62 |
+
for table in tables:
|
| 63 |
+
cols = self.db.get_table_schema(table)
|
| 64 |
+
col_list = [c['name'] if isinstance(c, dict) else str(c) for c in cols]
|
| 65 |
+
schema.append(f"Table: {table}\nColumns: {', '.join(col_list)}")
|
| 66 |
+
return "\n\n".join(schema)
|
requirements.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core Framework
|
| 2 |
+
langchain==0.3.0
|
| 3 |
+
langchain-community
|
| 4 |
+
langchain-core
|
| 5 |
+
|
| 6 |
+
# Google Gemini Integration
|
| 7 |
+
langchain-google-genai
|
| 8 |
+
google-generativeai
|
| 9 |
+
|
| 10 |
+
# Database & Vector Store
|
| 11 |
+
sqlalchemy
|
| 12 |
+
pymysql
|
| 13 |
+
chromadb
|
| 14 |
+
langchain-chroma
|
| 15 |
+
|
| 16 |
+
# Utilities
|
| 17 |
+
python-dotenv
|
| 18 |
+
streamlit
|
| 19 |
+
|
| 20 |
+
# Critical Version Fixes
|
| 21 |
+
numpy<2.0.0
|
setup_full_db.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pymysql
|
| 2 |
+
import random
|
| 3 |
+
from faker import Faker
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
import os
|
| 6 |
+
from urllib.parse import urlparse, unquote # <--- Added 'unquote' here
|
| 7 |
+
|
| 8 |
+
# --- CONFIGURATION ---
|
| 9 |
+
NUM_EMPLOYEES = 50
|
| 10 |
+
NUM_CUSTOMERS = 100
|
| 11 |
+
NUM_PRODUCTS = 20
|
| 12 |
+
NUM_SALES = 1000
|
| 13 |
+
|
| 14 |
+
# Load credentials
|
| 15 |
+
load_dotenv()
|
| 16 |
+
db_uri = os.getenv("DB_URI")
|
| 17 |
+
|
| 18 |
+
# Parse URI
|
| 19 |
+
parsed = urlparse(db_uri)
|
| 20 |
+
username = parsed.username
|
| 21 |
+
# FIX: 'unquote' converts 'Lalit%40851' back to 'Lalit@851'
|
| 22 |
+
password = unquote(parsed.password)
|
| 23 |
+
host = parsed.hostname
|
| 24 |
+
port = parsed.port
|
| 25 |
+
# FIX: Handle cases where path is empty or just slash
|
| 26 |
+
dbname = parsed.path[1:] if parsed.path else "chatbot"
|
| 27 |
+
|
| 28 |
+
print(f"--- 🏭 INITIALIZING BUSINESS SIMULATOR for DB: {dbname} ---")
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
# Connect without selecting a DB first (to create it if missing)
|
| 32 |
+
conn = pymysql.connect(host=host, user=username, password=password, port=port)
|
| 33 |
+
cursor = conn.cursor()
|
| 34 |
+
|
| 35 |
+
# 1. CREATE DATABASE & TABLES
|
| 36 |
+
print("1. Rebuilding Schema...")
|
| 37 |
+
cursor.execute(f"CREATE DATABASE IF NOT EXISTS {dbname}")
|
| 38 |
+
cursor.execute(f"USE {dbname}")
|
| 39 |
+
|
| 40 |
+
cursor.execute("SET FOREIGN_KEY_CHECKS = 0")
|
| 41 |
+
for t in ["sales", "employees", "products", "customers", "departments"]:
|
| 42 |
+
cursor.execute(f"DROP TABLE IF EXISTS {t}")
|
| 43 |
+
cursor.execute("SET FOREIGN_KEY_CHECKS = 1")
|
| 44 |
+
|
| 45 |
+
queries = [
|
| 46 |
+
"CREATE TABLE departments (id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(50) UNIQUE, budget DECIMAL(15, 2), location VARCHAR(100))",
|
| 47 |
+
"CREATE TABLE employees (id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(100), email VARCHAR(100), department_id INT, role VARCHAR(50), salary DECIMAL(10, 2), hire_date DATE, FOREIGN KEY (department_id) REFERENCES departments(id))",
|
| 48 |
+
"CREATE TABLE products (id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(100), category VARCHAR(50), price DECIMAL(10, 2), stock_quantity INT)",
|
| 49 |
+
"CREATE TABLE customers (id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(100), company VARCHAR(100), region VARCHAR(50), join_date DATE)",
|
| 50 |
+
"CREATE TABLE sales (id INT AUTO_INCREMENT PRIMARY KEY, employee_id INT, customer_id INT, product_id INT, quantity INT, total_amount DECIMAL(10, 2), sale_date DATE, FOREIGN KEY (employee_id) REFERENCES employees(id), FOREIGN KEY (customer_id) REFERENCES customers(id), FOREIGN KEY (product_id) REFERENCES products(id))"
|
| 51 |
+
]
|
| 52 |
+
|
| 53 |
+
for q in queries:
|
| 54 |
+
cursor.execute(q)
|
| 55 |
+
|
| 56 |
+
# 2. GENERATE DATA
|
| 57 |
+
fake = Faker()
|
| 58 |
+
print("2. Manufacturing Data...")
|
| 59 |
+
|
| 60 |
+
# Departments
|
| 61 |
+
depts = ["Sales", "Engineering", "HR", "Marketing", "Executive"]
|
| 62 |
+
dept_ids = []
|
| 63 |
+
for d in depts:
|
| 64 |
+
cursor.execute("INSERT INTO departments (name, budget, location) VALUES (%s, %s, %s)", (d, random.randint(50000, 1000000), fake.city()))
|
| 65 |
+
dept_ids.append(cursor.lastrowid)
|
| 66 |
+
|
| 67 |
+
# Employees
|
| 68 |
+
emp_ids = []
|
| 69 |
+
roles = ["Manager", "Associate", "Analyst", "Director", "Intern"]
|
| 70 |
+
for _ in range(NUM_EMPLOYEES):
|
| 71 |
+
cursor.execute("INSERT INTO employees (name, email, department_id, role, salary, hire_date) VALUES (%s, %s, %s, %s, %s, %s)",
|
| 72 |
+
(fake.name(), fake.email(), random.choice(dept_ids), random.choice(roles), random.randint(40000, 150000), fake.date_between(start_date='-5y', end_date='today')))
|
| 73 |
+
emp_ids.append(cursor.lastrowid)
|
| 74 |
+
|
| 75 |
+
# Products
|
| 76 |
+
prod_ids = []
|
| 77 |
+
for _ in range(NUM_PRODUCTS):
|
| 78 |
+
cursor.execute("INSERT INTO products (name, category, price, stock_quantity) VALUES (%s, %s, %s, %s)",
|
| 79 |
+
(fake.bs().title(), random.choice(["Software", "Hardware", "Service"]), round(random.uniform(50, 5000), 2), random.randint(0, 500)))
|
| 80 |
+
prod_ids.append(cursor.lastrowid)
|
| 81 |
+
|
| 82 |
+
# Customers
|
| 83 |
+
cust_ids = []
|
| 84 |
+
for _ in range(NUM_CUSTOMERS):
|
| 85 |
+
cursor.execute("INSERT INTO customers (name, company, region, join_date) VALUES (%s, %s, %s, %s)",
|
| 86 |
+
(fake.name(), fake.company(), random.choice(["North America", "Europe", "Asia", "South America"]), fake.date_between(start_date='-3y', end_date='today')))
|
| 87 |
+
cust_ids.append(cursor.lastrowid)
|
| 88 |
+
|
| 89 |
+
# Sales
|
| 90 |
+
print(f" -> Generating {NUM_SALES} sales transactions...")
|
| 91 |
+
for _ in range(NUM_SALES):
|
| 92 |
+
prod = random.choice(prod_ids)
|
| 93 |
+
cursor.execute("SELECT price FROM products WHERE id=%s", (prod,))
|
| 94 |
+
price = cursor.fetchone()[0]
|
| 95 |
+
qty = random.randint(1, 10)
|
| 96 |
+
cursor.execute("INSERT INTO sales (employee_id, customer_id, product_id, quantity, total_amount, sale_date) VALUES (%s, %s, %s, %s, %s, %s)",
|
| 97 |
+
(random.choice(emp_ids), random.choice(cust_ids), prod, qty, price * qty, fake.date_between(start_date='-1y', end_date='today')))
|
| 98 |
+
|
| 99 |
+
conn.commit()
|
| 100 |
+
conn.close()
|
| 101 |
+
print("✅ DONE! Database is populated.")
|
| 102 |
+
|
| 103 |
+
except Exception as e:
|
| 104 |
+
print(f"\n❌ CRITICAL ERROR: {e}")
|
| 105 |
+
print("Double check your password in .env!")
|
sql_generator.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import json
|
| 4 |
+
from huggingface_hub import InferenceClient
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
|
| 7 |
+
class SQLGenerator:
|
| 8 |
+
def __init__(self, api_key=None):
|
| 9 |
+
load_dotenv()
|
| 10 |
+
self.api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
| 11 |
+
self.repo_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
|
| 12 |
+
self.client = InferenceClient(token=self.api_token, timeout=25.0)
|
| 13 |
+
|
| 14 |
+
def generate_followup_questions(self, question, sql_query):
|
| 15 |
+
return ["Visualize this result", "Export as CSV", "Compare with last year"]
|
| 16 |
+
|
| 17 |
+
def generate_sql(self, question, context, history=None):
|
| 18 |
+
if history is None: history = []
|
| 19 |
+
|
| 20 |
+
forbidden = ["DROP", "DELETE", "UPDATE", "INSERT", "ALTER", "TRUNCATE", "GRANT"]
|
| 21 |
+
if any(word in question.upper() for word in forbidden):
|
| 22 |
+
return "SELECT 'Error: Blocked by Safety Layer' as status", "Safety Alert", "I cannot execute commands that modify data."
|
| 23 |
+
|
| 24 |
+
history_text = ""
|
| 25 |
+
if history:
|
| 26 |
+
history_text = "PREVIOUS CONVERSATION:\n" + "\n".join([f"User: {h['user']}\nSQL: {h['sql']}" for h in history[-2:]])
|
| 27 |
+
|
| 28 |
+
system_prompt = f"""You are an elite SQL Expert.
|
| 29 |
+
Schema:
|
| 30 |
+
{context}
|
| 31 |
+
|
| 32 |
+
{history_text}
|
| 33 |
+
|
| 34 |
+
Rules:
|
| 35 |
+
1. Output JSON: {{ "sql": "SELECT ...", "message": "Friendly text", "explanation": "Brief summary" }}
|
| 36 |
+
2. Query MUST be Read-Only (SELECT).
|
| 37 |
+
3. Do not include markdown formatting like ```json.
|
| 38 |
+
"""
|
| 39 |
+
messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": question}]
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
print(f" ⚡ Generating SQL...")
|
| 43 |
+
response = self.client.chat_completion(messages=messages, model=self.repo_id, max_tokens=1024, temperature=0.1)
|
| 44 |
+
raw_text = response.choices[0].message.content
|
| 45 |
+
|
| 46 |
+
sql_query = ""
|
| 47 |
+
message = "Here is the data."
|
| 48 |
+
explanation = "Query generated successfully."
|
| 49 |
+
|
| 50 |
+
try:
|
| 51 |
+
clean_json = re.sub(r"```json|```", "", raw_text).strip()
|
| 52 |
+
data = json.loads(clean_json)
|
| 53 |
+
sql_query = data.get("sql", "")
|
| 54 |
+
message = data.get("message", message)
|
| 55 |
+
explanation = data.get("explanation", explanation)
|
| 56 |
+
except:
|
| 57 |
+
match = re.search(r"(SELECT[\s\S]+?;)", raw_text, re.IGNORECASE)
|
| 58 |
+
if match: sql_query = match.group(1)
|
| 59 |
+
|
| 60 |
+
sql_query = sql_query.strip().replace("\n", " ")
|
| 61 |
+
if sql_query and not sql_query.endswith(";"): sql_query += ";"
|
| 62 |
+
|
| 63 |
+
# ✅ FIX: Strip comments and whitespace before validation
|
| 64 |
+
clean_check = re.sub(r"/\*.*?\*/|--.*?\n", "", sql_query, flags=re.DOTALL).strip().upper()
|
| 65 |
+
|
| 66 |
+
# ✅ FIX: Allow SELECT or WITH clauses
|
| 67 |
+
if not clean_check.startswith("SELECT") and not clean_check.startswith("WITH"):
|
| 68 |
+
print(f" ⚠️ Invalid SQL Blocked: {sql_query}")
|
| 69 |
+
return "SELECT 'Error: Invalid Query Type (Non-SELECT)' as status", "Safety Error", "I can only perform read-only operations."
|
| 70 |
+
|
| 71 |
+
return sql_query, explanation, message
|
| 72 |
+
|
| 73 |
+
except Exception as e:
|
| 74 |
+
print(f" ❌ Model Error: {e}")
|
| 75 |
+
safe_e = str(e).replace("'", "").replace('"', "")
|
| 76 |
+
return f"SELECT 'Error: {safe_e}' as status", "System Error", "An unexpected error occurred."
|