|
|
import streamlit as st |
|
|
import sqlite3 |
|
|
import pandas as pd |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import database |
|
|
import json |
|
|
|
|
|
|
|
|
# Prepare the backing database through the project-local `database` module
# before anything below reads from it.
# NOTE(review): Streamlit re-executes this script top to bottom on every widget
# interaction, so this call runs on each rerun — presumably init_database() is
# idempotent; verify in database.py.
database.init_database()

# Schema snapshot used to build the LLM prompt and the "View Database Schema"
# panel. Judging by create_schema_prompt(), it maps table name -> dict with
# 'columns', 'types' and 'sample_values' entries.
schema_info = database.get_schema_info()
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model(model_name="codellama/CodeLlama-7b-hf"):
    """Load the causal language model and its tokenizer, cached by Streamlit.

    ``st.cache_resource`` keeps one loaded instance per distinct
    ``model_name`` and reuses it across script reruns, so the large weights
    are only loaded on the first call.

    Args:
        model_name: Hugging Face Hub id or local path of the checkpoint. The
            default is the CodeLlama 7B base model the app was built around.

    Returns:
        A ``(model, tokenizer)`` tuple — note that the model comes first.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer
|
|
|
|
|
def create_schema_prompt(schema=None):
    """Render a database schema description as plain text for the LLM prompt.

    Args:
        schema: Mapping of table name -> dict with ``'columns'`` and
            ``'types'`` (paired element-wise) and ``'sample_values'`` (mapping
            of column name -> sliceable sequence of example values). Defaults
            to the module-level ``schema_info`` loaded at startup.

    Returns:
        Text beginning with ``"Database Schema:"`` followed by one section per
        table that lists every column, its type and up to three example
        values. Columns without sample values are listed with no examples.
    """
    if schema is None:
        schema = schema_info

    lines = ["Database Schema:"]
    for table, info in schema.items():
        lines.append("")  # blank line separating table sections
        lines.append(f"Table: {table}")
        lines.append("Columns:")
        samples = info.get("sample_values", {})
        for col, col_type in zip(info["columns"], info["types"]):
            # A column missing from sample_values should not break the
            # whole prompt; it is simply shown without examples.
            examples = samples.get(col, [])[:3]
            lines.append(
                f"- {col} ({col_type}), Example values: "
                f"{', '.join(map(str, examples))}"
            )
    # Every line, including the last, is newline-terminated.
    return "\n".join(lines) + "\n"
|
|
|
|
|
def generate_sql_query(question):
    """Ask the language model to translate *question* into one SQL query.

    The prompt holds the schema text from ``create_schema_prompt()``, the
    question and a short list of instructions. Only the tokens generated
    after the prompt are decoded, and the first complete SQL statement is
    extracted from them.

    Args:
        question: Natural-language question typed by the user.

    Returns:
        The generated SQL text. This may be an empty string if the model
        produced nothing usable.
    """
    model, tokenizer = load_model()
    prompt = _build_sql_prompt(create_schema_prompt(), question)

    # No truncation: the default right-side truncation would silently cut
    # off the question and the trailing "SQL Query:" cue.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    outputs = model.generate(
        **inputs,
        # max_new_tokens bounds only the completion; max_length would also
        # count the prompt and leave no room for output on large schemas.
        max_new_tokens=256,
        num_return_sequences=1,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        # Explicit pad id so generate() does not have to fall back (with a
        # warning) when the tokenizer defines no pad token.
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens, not the echoed prompt.
    completion = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return _extract_sql(completion)


def _build_sql_prompt(schema_prompt, question):
    """Assemble the instruction prompt sent to the model."""
    return (
        "Given the following database schema and question, generate a SQL "
        "query that answers the question.\n"
        "\n"
        f"{schema_prompt}\n"
        "\n"
        f"Question: {question}\n"
        "\n"
        "Write only the SQL query without any additional text or explanation. "
        "Make sure to:\n"
        "1. Use the correct table and column names as shown in the schema\n"
        "2. Handle joins appropriately if multiple tables are needed\n"
        "3. Use appropriate SQL functions based on the question context\n"
        "\n"
        "SQL Query:"
    )


def _extract_sql(text):
    """Return the first complete SQL statement found in model output *text*.

    Strips a leading Markdown code fence, then cuts after the first ``;``
    that ``sqlite3.complete_statement`` accepts (so semicolons inside string
    literals are skipped). If there is no terminated statement, the text up
    to the first blank line is returned.
    """
    text = text.strip()
    if text.startswith("```"):
        # Drop the opening fence line (e.g. "```sql") and the closing fence.
        _, _, text = text.partition("\n")
        text = text.split("```", 1)[0].strip()

    # sqlite3 executes a single statement at a time, and sampled output
    # often continues past the query, so keep only the first statement.
    end = text.find(";")
    while end != -1:
        candidate = text[: end + 1]
        if sqlite3.complete_statement(candidate):
            return candidate.strip()
        end = text.find(";", end + 1)

    # No terminated statement: drop anything after the first blank line,
    # where the model typically starts unrelated continuation text.
    return text.split("\n\n", 1)[0].strip()
|
|
|
|
|
def execute_query(query, db_path='data.db'):
    """Run *query* against the SQLite database in read-only mode.

    The SQL comes from a language model driven by arbitrary user text, so
    the database is opened read-only. Statements that would modify it
    (``DROP TABLE``, ``DELETE``, ...) fail with an error instead of changing
    data. A missing database file is reported as an error rather than being
    silently created.

    Args:
        query: SQL text to execute.
        db_path: Path of the SQLite database file (a real file path; an
            in-memory ``":memory:"`` database is not supported here).

    Returns:
        ``(DataFrame, None)`` on success, or ``(None, error_message)`` on
        any failure.
    """
    import pathlib

    # URI form is needed to request mode=ro; as_uri() percent-encodes the path.
    uri = pathlib.Path(db_path).resolve().as_uri() + "?mode=ro"
    try:
        conn = sqlite3.connect(uri, uri=True)
    except sqlite3.Error as exc:
        return None, str(exc)

    try:
        return pd.read_sql_query(query, conn), None
    except Exception as exc:  # UI boundary: surface any failure to the user
        return None, str(exc)
    finally:
        conn.close()
|
|
|
|
|
|
|
|
st.title("Intelligent Text to SQL Query Assistant")
st.write("Ask questions about your data in natural language!")

# Show users exactly what schema description the model is given.
with st.expander("View Database Schema"):
    st.code(create_schema_prompt(), language="text")

user_question = st.text_area("Enter your question:", height=100)

if st.button("Generate and Execute Query"):
    question = user_question.strip()
    if not question:
        # Give feedback instead of silently ignoring the click.
        st.warning("Please enter a question first.")
    else:
        with st.spinner("Generating SQL query..."):
            sql_query = generate_sql_query(question)

        st.subheader("Generated SQL Query:")
        st.code(sql_query, language="sql")

        if not sql_query.strip():
            # Nothing to run; executing an empty string only yields a
            # confusing driver error.
            st.error("The model did not produce a SQL query. Please try again.")
        else:
            with st.spinner("Executing query..."):
                results, error = execute_query(sql_query)

            if error:
                st.error(f"Error executing query: {error}")
            else:
                st.subheader("Query Results:")
                st.dataframe(results)
|
|
|