|
|
|
|
|
import os
import re
from urllib.parse import quote_plus

import openai
import pandas as pd
import streamlit as st
from sqlalchemy import create_engine, text
|
|
|
|
|
|
|
|
|
|
|
# The OpenAI key is read from the environment so it never lives in source.
openai.api_key = os.getenv("OPENAI_API_KEY")

# Database connection settings. Each value can be overridden by an
# environment variable of the same name; the literals are placeholders only,
# so real credentials do not have to be committed to the repository.
DB_TYPE = os.getenv("DB_TYPE", "mysql+pymysql")
DB_USER = os.getenv("DB_USER", "username")
DB_PASS = os.getenv("DB_PASS", "password")
DB_HOST = os.getenv("DB_HOST", "host")
DB_PORT = os.getenv("DB_PORT", "3306")
DB_NAME = os.getenv("DB_NAME", "db_name")

# quote_plus escapes characters such as "@", ":" or "/" in the credentials;
# left raw they would be parsed as URL delimiters and corrupt the URL.
DATABASE_URL = (
    f"{DB_TYPE}://{quote_plus(DB_USER)}:{quote_plus(DB_PASS)}"
    f"@{DB_HOST}:{DB_PORT}/{DB_NAME}"
)

engine = create_engine(DATABASE_URL)
|
|
|
|
|
|
|
|
def _strip_code_fences(raw):
    """Return ``raw`` trimmed and without a surrounding Markdown ``` fence."""
    cleaned = raw.strip()
    if not cleaned.startswith("```"):
        return cleaned
    if "\n" not in cleaned:
        # Single-line form such as ```SELECT 1```.
        return cleaned.strip("`").strip()
    lines = cleaned.splitlines()[1:]  # drop the opening ``` / ```sql line
    if lines and lines[-1].strip().startswith("```"):
        lines = lines[:-1]  # drop the closing fence when present
    return "\n".join(lines).strip()


def generate_sql(user_question, table_names=None, model="text-davinci-003"):
    """Generate a SQL query that answers ``user_question`` using OpenAI.

    Args:
        user_question: Natural-language question to translate into SQL.
        table_names: Optional iterable of table names listed in the prompt
            as context. ``None`` (the default) or an empty iterable omits
            the table hint.
        model: Completion model name, passed to the API as ``engine``.

    Returns:
        The generated SQL text, stripped of surrounding whitespace and of
        any Markdown code fence the model wrapped it in.
    """
    names = [] if table_names is None else [str(name) for name in table_names]
    table_info = f"These are your tables: {', '.join(names)}\n" if names else ""

    prompt = f"""
You are an expert SQL generator.
{table_info}
Write a SQL query that answers the following question:
\"\"\"{user_question}\"\"\"
Only return SQL, do not explain.
"""

    # NOTE(review): openai.Completion was removed in openai>=1.0 and
    # text-davinci-003 has been retired by OpenAI; migrate to the chat API
    # (or pass a supported ``model``) together with the client upgrade.
    response = openai.Completion.create(
        engine=model,
        prompt=prompt,
        temperature=0,  # same question -> same SQL
        max_tokens=300,
    )

    # Models frequently wrap SQL in ```sql fences despite the instruction;
    # sending the fence to the database would be a syntax error.
    return _strip_code_fences(response.choices[0].text)
|
|
|
|
|
def _check_read_only(sql_query):
    """Return ``sql_query`` normalised, or raise ValueError if it is not a
    single read-only statement.

    This is a best-effort guard against the model (or a prompt-injected
    question) producing DDL/DML. The authoritative protection is a database
    account that only has read privileges.
    """
    statement = sql_query.strip().rstrip(";").strip()
    if not statement:
        raise ValueError("empty query")
    if ";" in statement:
        # Conservative: also rejects ';' inside string literals, which is safe.
        raise ValueError("multiple statements are not allowed")
    match = re.match(r"[A-Za-z]+", statement)
    keyword = match.group(0).upper() if match else ""
    # WITH is allowed for CTEs; a data-modifying CTE is never committed here
    # because the connection below never calls commit().
    if keyword not in {"SELECT", "WITH", "SHOW", "DESCRIBE", "DESC", "EXPLAIN"}:
        raise ValueError(f"only read-only statements are allowed (got {keyword!r})")
    return statement


def run_query(sql_query):
    """Run a generated SQL query using SQLAlchemy.

    The query comes from a language model and is therefore treated as
    untrusted: only a single read-only statement is executed.

    Args:
        sql_query: SQL text to execute.

    Returns:
        A ``pandas.DataFrame`` with the result rows on success, otherwise a
        string starting with ``"Error executing query: "`` describing the
        failure (including rejected non-read-only statements).
    """
    try:
        statement = _check_read_only(sql_query)
        with engine.connect() as conn:
            result = pd.read_sql(text(statement), conn)
        return result
    except Exception as e:
        # UI boundary: the caller renders this message instead of crashing.
        return f"Error executing query: {e}"
|
|
|
|
|
|
|
|
st.title("🧠 AI SQL Assistant")
st.markdown("Ask a question about your database, and it will generate SQL and show results.")

user_question = st.text_input("Enter your question:")

# Streamlit re-runs this whole script on every interaction; the button
# returns True only on the rerun triggered by clicking it.
run_clicked = st.button("Run Query")

if run_clicked and not user_question.strip():
    st.warning("Please enter a question first.")
elif run_clicked:
    sql_query = None
    with st.spinner("Generating SQL..."):
        try:
            sql_query = generate_sql(user_question.strip())
        except Exception as e:
            # Network/API failures are shown in the page instead of a traceback.
            st.error(f"Error generating SQL: {e}")

    if sql_query:
        st.code(sql_query, language="sql")

        with st.spinner("Executing SQL..."):
            result = run_query(sql_query)

        if isinstance(result, pd.DataFrame):
            st.success("Query executed successfully!")
            st.dataframe(result)
        else:
            # run_query reports failures as an error string.
            st.error(result)
    elif sql_query is not None:
        st.error("The model returned an empty query.")
|
|
|