texttoSQL / app.py
souvik16011991roy's picture
Upload 4 files
972aab5 verified
import json
import sqlite3
from contextlib import closing
from pathlib import Path

import pandas as pd
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer

import database
# Initialize database. NOTE(review): what init_database() creates or seeds is
# defined in the local `database` module — confirm there. Streamlit re-executes
# this script top to bottom on every interaction, so this runs on each rerun.
database.init_database()
# Get schema information. Expected shape (as read by create_schema_prompt):
# table name -> dict with 'columns', 'types' (parallel sequences) and
# 'sample_values' (column name -> example values). Presumably this relies on
# init_database() having run first — verify in database.py.
schema_info = database.get_schema_info()
@st.cache_resource
def load_model(model_name="codellama/CodeLlama-7b-hf"):
    """Load the causal language model and its tokenizer, cached by Streamlit.

    ``st.cache_resource`` keeps one shared instance per distinct argument
    value across reruns and sessions, so the (large) weights are loaded once
    rather than on every script rerun.

    Args:
        model_name: Hugging Face Hub id or local path of the model. Defaults
            to the CodeLlama 7B base model this app was built around.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Single source for the model id, so tokenizer and weights always match.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return model, tokenizer
def create_schema_prompt(schema=None, max_samples=3):
    """Describe the database schema as plain text for the LLM prompt and the UI.

    Args:
        schema: Mapping of table name -> info dict with ``'columns'`` and
            ``'types'`` (parallel sequences) and ``'sample_values'`` (mapping
            of column name -> sequence of example values). Defaults to the
            module-level ``schema_info``.
        max_samples: Maximum number of example values listed per column
            (default 3, the previously hard-coded value).

    Returns:
        A newline-terminated string beginning with ``"Database Schema:"``,
        followed by one ``Table:`` section per table.
    """
    if schema is None:
        schema = schema_info
    parts = ["Database Schema:\n"]
    for table, info in schema.items():
        parts.append(f"\nTable: {table}\n")
        parts.append("Columns:\n")
        for col, type_ in zip(info['columns'], info['types']):
            # A column with no recorded samples gets an empty list instead of
            # raising KeyError (which would break both the prompt and the UI).
            samples = info['sample_values'].get(col, [])[:max_samples]
            examples = ', '.join(map(str, samples))
            parts.append(f"- {col} ({type_}), Example values: {examples}\n")
    # Collect pieces and join once instead of repeated string concatenation.
    return "".join(parts)
def _extract_sql(generated):
    """Trim raw model continuation text down to a single SQL statement.

    Returns an empty string when nothing usable was generated.
    """
    sql = generated.strip()
    # Drop an opening Markdown fence (e.g. ```sql) if the model added one.
    if sql.startswith("```"):
        sql = sql.partition("\n")[2]
    # A causal LM keeps writing after the query (closing fence, a new
    # "Question:", further statements), and sqlite3 rejects multi-statement
    # input, so keep only the first statement.
    # NOTE: a ';' inside a string literal would also end the statement here.
    for stop in ("```", "\nQuestion:", ";"):
        sql = sql.split(stop, 1)[0]
    sql = sql.strip()
    return f"{sql};" if sql else ""


def generate_sql_query(question):
    """Ask the language model to translate a natural-language question into SQL.

    Args:
        question: The user's question about the data.

    Returns:
        A single SQL statement (terminated by ``;``), or ``""`` if the model
        produced nothing usable. Sampling is enabled, so repeated calls with
        the same question may return different queries.
    """
    model, tokenizer = load_model()
    # Create detailed prompt with schema information
    schema_prompt = create_schema_prompt()
    prompt = f"""Given the following database schema and question, generate a SQL query that answers the question.
{schema_prompt}
Question: {question}
Write only the SQL query without any additional text or explanation. Make sure to:
1. Use the correct table and column names as shown in the schema
2. Handle joins appropriately if multiple tables are needed
3. Use appropriate SQL functions based on the question context
SQL Query:"""
    # No truncation: right-side truncation would silently cut off the question
    # and the trailing "SQL Query:" cue, which sit at the end of the prompt.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    outputs = model.generate(
        **inputs,
        # max_new_tokens bounds only the generated part; the previous
        # max_length=500 also counted prompt tokens, so a large schema left
        # no room to generate anything.
        max_new_tokens=256,
        num_return_sequences=1,
        temperature=0.7,
        top_p=0.95,
        do_sample=True,
        # Llama tokenizers define no pad token; use EOS explicitly.
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens rather than re-splitting the
    # echoed prompt on "SQL Query:".
    generated = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return _extract_sql(generated)
def execute_query(query, db_path='data.db'):
    """Run a SQL query against the SQLite database and return its rows.

    The query text comes from the language model, so it is treated as
    untrusted. The database is opened read-only, so statements such as
    ``DROP TABLE`` or ``DELETE`` fail instead of modifying data.

    Args:
        query: SQL text to execute (a single statement).
        db_path: Path to the SQLite database file (default ``'data.db'``).

    Returns:
        ``(DataFrame, None)`` on success, or ``(None, error_message)`` if the
        database cannot be opened or the query fails.
    """
    try:
        # mode=ro also stops sqlite3 from silently creating an empty file when
        # db_path does not exist. as_uri() needs an absolute path and
        # percent-encodes characters that are special in URIs.
        uri = Path(db_path).resolve().as_uri() + "?mode=ro"
        with closing(sqlite3.connect(uri, uri=True)) as conn:
            return pd.read_sql_query(query, conn), None
    except Exception as e:  # UI boundary: surface any failure to the user
        return None, str(e)
# Streamlit UI
st.title("Intelligent Text to SQL Query Assistant")
st.write("Ask questions about your data in natural language!")

# Display schema information in expandable section
with st.expander("View Database Schema"):
    st.code(create_schema_prompt(), language="text")

# User input
user_question = st.text_area("Enter your question:", height=100)

if st.button("Generate and Execute Query"):
    # Whitespace-only input would otherwise trigger a full (slow) LLM run.
    question = user_question.strip()
    if not question:
        st.warning("Please enter a question first.")
    else:
        with st.spinner("Generating SQL query..."):
            sql_query = generate_sql_query(question)
        if not sql_query:
            # Nothing to run; sending "" to the database only yields an
            # opaque driver error.
            st.error("The model did not return a SQL query. Try rephrasing your question.")
        else:
            # Display the generated query
            st.subheader("Generated SQL Query:")
            st.code(sql_query, language="sql")
            # Execute the query
            with st.spinner("Executing query..."):
                results, error = execute_query(sql_query)
            if error:
                st.error(f"Error executing query: {error}")
            else:
                st.subheader("Query Results:")
                st.dataframe(results)