# QueryBridge / app.py
# Omkar1872's picture
# Upload 5 files
# 3743009 verified
import pandas as pd
import sqlite3
import streamlit as st
import tempfile
from model_api import query_hf_model # Hugging Face API wrapper
# -----------------------------
# Connect to SQLite
# -----------------------------
@st.cache_resource
def connect_sqlite(db_path="data.db"):
    """Open the SQLite database at ``db_path`` and return the connection.

    The connection is created with ``check_same_thread=False`` and the
    function is wrapped in ``st.cache_resource``. If connecting raises,
    the failure is reported through ``st.error`` and ``None`` is returned.
    """
    try:
        return sqlite3.connect(db_path, check_same_thread=False)
    except Exception as exc:
        st.error(f"Error connecting to SQLite: {exc}")
    return None
# -----------------------------
# Load CSV to SQLite
# -----------------------------
def load_csv_to_sqlite(file, conn):
    """Read ``file`` as CSV and store it as the ``csv_data`` table.

    ``file`` is handed to ``pandas.read_csv`` unchanged. Any existing
    ``csv_data`` table on ``conn`` is replaced and the DataFrame index is
    not written. Success is reported with ``st.success``; any exception is
    reported with ``st.error`` instead of being raised.
    """
    try:
        pd.read_csv(file).to_sql(
            name="csv_data",
            con=conn,
            if_exists="replace",
            index=False,
        )
        st.success("CSV data loaded into SQLite successfully.")
    except Exception as exc:
        st.error(f"Error loading CSV: {exc}")
# -----------------------------
# Generate SQL query using HF API
# -----------------------------
def generate_query(user_input, conn):
    """Turn a plain-English request into a SQLite SELECT statement.

    Reads the column names of the ``csv_data`` table from ``conn``, builds
    a prompt that lists them, sends the prompt to ``query_hf_model`` with
    the ``HF_TOKEN`` secret, and accepts the reply only if it starts with
    ``select`` (case-insensitive).

    Args:
        user_input: The user's request in natural language.
        conn: Open SQLite connection that holds the ``csv_data`` table.

    Returns:
        The generated SQL string with surrounding whitespace removed, or
        ``None`` when no data is loaded, the token is missing, the reply
        is not a SELECT statement, or any step raises. Every failure is
        reported through ``st.error``.
    """
    try:
        # PRAGMA table_info yields one row per column; index 1 is the name.
        cursor = conn.cursor()
        cursor.execute("PRAGMA table_info(csv_data)")
        field_names = [row[1] for row in cursor.fetchall()]
        if not field_names:
            # No rows means the table has not been created yet; calling the
            # model without a schema would only produce guessed column names.
            st.error("No CSV data loaded. Please upload a CSV file first.")
            return None

        # Subscripting st.secrets raises when the key is absent, so the
        # friendly message below would otherwise never be shown.
        # NOTE(review): a missing secrets file is expected to surface as
        # FileNotFoundError (or a subclass) - confirm for the installed
        # Streamlit version.
        try:
            hf_token = st.secrets["HF_TOKEN"]
        except (KeyError, FileNotFoundError):
            hf_token = None
        if not hf_token:
            st.error("HF_TOKEN not found. Please add it in Streamlit secrets.")
            return None

        # The query runs on SQLite, so the prompt must ask for SQLite
        # syntax: SQLite has no YEAR() function, strftime('%Y', ...) is the
        # equivalent for 'YYYY-MM-DD' text dates.
        prompt = (
            "\n"
            "You are a SQLite expert. Only respond with a SQLite SELECT query in this exact format:\n"
            "SELECT column1, column2 FROM csv_data WHERE condition;\n"
            "Rules:\n"
            f"- Use only these fields from the 'csv_data' table: {field_names}\n"
            "- All field names are case-sensitive.\n"
            "- String values must be in single quotes.\n"
            "- For GROUP BY queries, do not include non-aggregated columns in SELECT unless they are also in GROUP BY.\n"
            "- If extracting year from a date column (e.g., LaunchDate), use strftime('%Y', column); SQLite has no YEAR() function.\n"
            "- Assume all dates are stored as TEXT in the format 'YYYY-MM-DD'.\n"
            "User request:\n"
            f'"{user_input}"\n'
        )

        # Call Hugging Face API
        query = query_hf_model(prompt, hf_token)
        # Treat an empty or missing reply like any other non-SELECT reply.
        query = query.strip() if query else ""

        # Safety check: only allow SELECT queries
        if not query.lower().startswith("select"):
            st.error("Generated SQL is not a SELECT query. Aborting.")
            return None
        return query
    except Exception as e:
        st.error(f"Error generating query: {e}")
        return None
# -----------------------------
# Execute SQL query on SQLite
# -----------------------------
def execute_query(query, conn):
    """Run ``query`` against ``conn`` and return the rows as a DataFrame.

    If ``pandas.read_sql_query`` raises, the error is reported through
    ``st.error`` and an empty DataFrame is returned instead.
    """
    try:
        return pd.read_sql_query(sql=query, con=conn)
    except Exception as exc:
        st.error(f"Query execution error: {exc}")
    return pd.DataFrame()
# -----------------------------
# Streamlit app
# -----------------------------
def main():
    """Render the Streamlit page.

    Flow: connect to SQLite, optionally load an uploaded CSV into the
    ``csv_data`` table, take a plain-English request, show the generated
    SQL, run it, and offer the result as a table and CSV download.
    """
    st.title("Natural Language to SQL Query and Output Generator (SQLite + HF API)")

    # Connect to SQLite
    conn = connect_sqlite()
    if not conn:
        st.stop()

    # Upload CSV
    uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
    if uploaded_file:
        # Hand the upload straight to the loader instead of copying it into
        # a NamedTemporaryFile(delete=False) that was never removed and so
        # leaked one temp file per script rerun.
        # NOTE(review): relies on the upload being a seekable file-like
        # object accepted by pandas.read_csv - confirm against the
        # Streamlit file_uploader docs.
        uploaded_file.seek(0)
        load_csv_to_sqlite(uploaded_file, conn)

    st.markdown("---")
    user_input = st.text_input("Ask your query (in plain English):")
    if user_input:
        query = generate_query(user_input, conn)
        if query:
            st.code(query, language="sql")
            data = execute_query(query, conn)
            if not data.empty:
                st.dataframe(data)
                # Download result as CSV
                csv = data.to_csv(index=False).encode("utf-8")
                st.download_button("Download Result as CSV", csv, "result.csv", "text/csv")
            else:
                st.warning("No matching records found.")
        else:
            st.error("Could not generate a valid SQL query.")


# Run the app only when this file is executed as a script.
if __name__ == "__main__":
    main()