Spaces:
Sleeping
Sleeping
File size: 4,163 Bytes
3743009 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import pandas as pd
import sqlite3
import streamlit as st
import tempfile
from model_api import query_hf_model # Hugging Face API wrapper
# -----------------------------
# Connect to SQLite
# -----------------------------
@st.cache_resource
def connect_sqlite(db_path="data.db"):
    """Open a SQLite connection for the app.

    The function is wrapped in ``st.cache_resource``, so the value it
    returns is handed to Streamlit's resource cache.

    Args:
        db_path: Path of the SQLite database file (default ``"data.db"``).

    Returns:
        The ``sqlite3.Connection`` on success, or ``None`` after the
        failure has been reported with ``st.error``.
    """
    try:
        # check_same_thread=False allows the connection to be used from a
        # thread other than the one that created it.
        connection = sqlite3.connect(db_path, check_same_thread=False)
    except Exception as exc:
        st.error(f"Error connecting to SQLite: {exc}")
        return None
    return connection
# -----------------------------
# Load CSV to SQLite
# -----------------------------
def load_csv_to_sqlite(file, conn):
    """Load a CSV into the ``csv_data`` table of the given SQLite database.

    Any existing ``csv_data`` table is replaced. Success and failure are
    both reported in the Streamlit UI; nothing is returned.

    Args:
        file: Anything ``pd.read_csv`` accepts as its first argument.
        conn: Connection passed to ``DataFrame.to_sql``.
    """
    try:
        # Read and write in one chain; the DataFrame index is not stored.
        pd.read_csv(file).to_sql(
            "csv_data",
            conn,
            if_exists="replace",
            index=False,
        )
        st.success("CSV data loaded into SQLite successfully.")
    except Exception as exc:
        st.error(f"Error loading CSV: {exc}")
# -----------------------------
# Generate SQL query using HF API
# -----------------------------
def generate_query(user_input, conn):
    """Turn a plain-English request into a SQLite SELECT statement.

    The column names of the ``csv_data`` table are read from ``conn`` and
    embedded in a prompt sent to the model via ``query_hf_model``. The
    prompt asks for SQLite syntax because the resulting query is executed
    against SQLite (which, for example, has no ``YEAR()`` function).

    Args:
        user_input: The user's request in plain English.
        conn: SQLite connection that holds the ``csv_data`` table.

    Returns:
        The generated SQL text if it starts with ``SELECT``; otherwise
        ``None``, after an explanatory ``st.error`` message has been shown.
    """
    try:
        # PRAGMA table_info returns one row per column; index 1 is the name.
        cursor = conn.cursor()
        cursor.execute("PRAGMA table_info(csv_data)")
        field_names = [row[1] for row in cursor.fetchall()]
        if not field_names:
            # An empty result means the table is missing or has no columns,
            # so there is nothing the model could legitimately query.
            st.error("No data found in 'csv_data'. Please upload a CSV file first.")
            return None

        # Subscripting st.secrets raises when the key (or the secrets file)
        # is absent, so catch that to reach the friendly message below.
        try:
            hf_token = st.secrets["HF_TOKEN"]
        except (KeyError, FileNotFoundError):
            hf_token = None
        if not hf_token:
            st.error("HF_TOKEN not found. Please add it in Streamlit secrets.")
            return None

        # Build the prompt; the leading/trailing empty entries keep the
        # surrounding newlines of the original triple-quoted layout.
        prompt = "\n".join(
            [
                "",
                "You are a SQLite expert. Only respond with a SQLite SELECT query in this exact format:",
                "SELECT column1, column2 FROM csv_data WHERE condition;",
                "Rules:",
                f"- Use only these fields from the 'csv_data' table: {field_names}",
                "- All field names are case-sensitive.",
                "- String values must be in single quotes.",
                "- For GROUP BY queries, do not include non-aggregated columns in SELECT unless they are also in GROUP BY.",
                "- If extracting year from a date column (e.g., LaunchDate), use strftime('%Y', column); SQLite has no YEAR() function.",
                "- Assume all dates are stored as TEXT in the format 'YYYY-MM-DD'.",
                "User request:",
                f'"{user_input}"',
                "",
            ]
        )

        # Call Hugging Face API
        query = query_hf_model(prompt, hf_token)

        # Safety check: only allow SELECT queries
        if not query.lower().strip().startswith("select"):
            st.error("Generated SQL is not a SELECT query. Aborting.")
            return None
        return query
    except Exception as e:
        st.error(f"Error generating query: {e}")
        return None
# -----------------------------
# Execute SQL query on SQLite
# -----------------------------
def execute_query(query, conn):
    """Run ``query`` against ``conn`` and return the rows as a DataFrame.

    Args:
        query: SQL text to execute.
        conn: Database connection handed to ``pd.read_sql_query``.

    Returns:
        The query result as a ``pandas.DataFrame``. If execution raises,
        the error is shown with ``st.error`` and an empty DataFrame is
        returned instead.
    """
    try:
        return pd.read_sql_query(query, conn)
    except Exception as exc:
        st.error(f"Query execution error: {exc}")
    # Only reached when the query failed.
    return pd.DataFrame()
# -----------------------------
# Streamlit app
# -----------------------------
def main():
    """Run the Streamlit app.

    Flow: connect to SQLite, optionally load an uploaded CSV into the
    ``csv_data`` table, then turn the user's plain-English question into
    SQL, execute it, and show the result with a CSV download button.
    """
    st.title("Natural Language to SQL Query and Output Generator (SQLite + HF API)")

    # Connect to SQLite; connect_sqlite returns None on failure.
    conn = connect_sqlite()
    if conn is None:
        st.stop()

    # Upload CSV
    uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
    if uploaded_file:
        # Hand the uploaded buffer straight to the loader instead of copying
        # it into a NamedTemporaryFile(delete=False) that was never removed.
        # Rewind first in case the buffer has already been read.
        uploaded_file.seek(0)
        load_csv_to_sqlite(uploaded_file, conn)

    st.markdown("---")
    user_input = st.text_input("Ask your query (in plain English):")
    if user_input:
        query = generate_query(user_input, conn)
        if query:
            st.code(query, language="sql")
            data = execute_query(query, conn)
            if not data.empty:
                st.dataframe(data)
                # Download result as CSV
                csv = data.to_csv(index=False).encode("utf-8")
                st.download_button("Download Result as CSV", csv, "result.csv", "text/csv")
            else:
                st.warning("No matching records found.")
        else:
            st.error("Could not generate a valid SQL query.")


if __name__ == "__main__":
    main()
|