Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import sqlite3 | |
| import streamlit as st | |
| import tempfile | |
| from model_api import query_hf_model # Hugging Face API wrapper | |
# -----------------------------
# Connect to SQLite
# -----------------------------
def connect_sqlite(db_path="data.db"):
    """Open a SQLite connection for the app.

    Args:
        db_path: Path of the SQLite database file.

    Returns:
        A ``sqlite3.Connection`` on success; ``None`` if connecting failed, in
        which case the error has already been shown with ``st.error``.
    """
    connection = None
    try:
        # check_same_thread=False lets the connection be used from threads
        # other than the one that created it (presumably needed because the
        # Streamlit script may run on different threads — confirm).
        connection = sqlite3.connect(db_path, check_same_thread=False)
    except Exception as exc:
        st.error(f"Error connecting to SQLite: {exc}")
    return connection
# -----------------------------
# Load CSV to SQLite
# -----------------------------
def load_csv_to_sqlite(file, conn):
    """Load a CSV into SQLite as the ``csv_data`` table.

    Any existing ``csv_data`` table is replaced. Outcomes are reported in the
    Streamlit UI (``st.success`` / ``st.error``) rather than returned or raised.

    Args:
        file: A path or file-like object accepted by ``pandas.read_csv``.
        conn: Open SQLite connection to write the table into.
    """
    target_table = "csv_data"
    try:
        pd.read_csv(file).to_sql(
            target_table,
            conn,
            if_exists="replace",  # drop any previously loaded table first
            index=False,  # do not store the DataFrame's row index as a column
        )
        st.success("CSV data loaded into SQLite successfully.")
    except Exception as load_err:
        st.error(f"Error loading CSV: {load_err}")
# -----------------------------
# Generate SQL query using HF API
# -----------------------------
def _build_prompt(field_names, user_input):
    """Build the text-to-SQL instruction prompt for the ``csv_data`` table.

    Args:
        field_names: Column names of ``csv_data``, listed verbatim in the prompt.
        user_input: The user's natural-language request.

    Returns:
        The prompt string sent to the Hugging Face model.
    """
    # Queries are executed by SQLite, so the model must emit SQLite syntax:
    # MySQL-only functions such as YEAR() fail at execution time.
    lines = [
        "You are a SQLite expert. Only respond with a single SQLite SELECT query in this exact format:",
        "SELECT column1, column2 FROM csv_data WHERE condition;",
        "Rules:",
        f"- Use only these fields from the 'csv_data' table: {field_names}",
        "- Use the field names exactly as listed; wrap any name containing spaces or special characters in double quotes.",
        "- String values must be in single quotes.",
        "- For GROUP BY queries, do not include non-aggregated columns in SELECT unless they are also in GROUP BY.",
        "- If extracting the year from a date column (e.g., LaunchDate), use strftime('%Y', column); SQLite has no YEAR() function. strftime returns TEXT, so compare it with a quoted year such as '2020'.",
        "- Assume all dates are stored as TEXT in the format 'YYYY-MM-DD'.",
        "- Do not wrap the query in Markdown code fences and do not add any explanation.",
        "User request:",
        f'"{user_input}"',
    ]
    return "\n".join(lines) + "\n"


def _clean_sql(raw):
    """Normalize raw model output to a bare SQL string.

    Strips surrounding whitespace and, if present, a Markdown code fence such
    as ```sql ... ```. ``None`` becomes an empty string.
    """
    text = (raw or "").strip()
    if text.startswith("```"):
        text = text[3:]
        if text.endswith("```"):
            text = text[:-3]
        text = text.strip()
        # Drop an optional language tag directly after the opening fence.
        if text[:3].lower() == "sql":
            text = text[3:]
    return text.strip()


def _read_hf_token():
    """Return the ``HF_TOKEN`` Streamlit secret, or ``None`` if it is not configured."""
    try:
        return st.secrets["HF_TOKEN"]
    except (KeyError, FileNotFoundError):
        # A missing key raises KeyError; a missing secrets file raises
        # FileNotFoundError (or a subclass). Both mean "not configured".
        return None


def generate_query(user_input, conn):
    """Translate a natural-language request into a SELECT query on ``csv_data``.

    Args:
        user_input: The user's question in plain English.
        conn: Open SQLite connection holding the ``csv_data`` table.

    Returns:
        The generated SQL string, with whitespace and any Markdown code fence
        removed. Returns ``None`` if no data is loaded, the token is missing,
        the model output is not a SELECT query, or an error occurred. Every
        failure is reported in the UI with ``st.error``.
    """
    try:
        # PRAGMA table_info yields one row per column, with the name at
        # index 1. It returns no rows (rather than raising) for a missing table.
        field_names = [row[1] for row in conn.execute("PRAGMA table_info(csv_data)")]
        if not field_names:
            st.error("No data loaded. Please upload a CSV file first.")
            return None

        hf_token = _read_hf_token()
        if not hf_token:
            st.error("HF_TOKEN not found. Please add it in Streamlit secrets.")
            return None

        # Call Hugging Face API
        raw_output = query_hf_model(_build_prompt(field_names, user_input), hf_token)
        query = _clean_sql(raw_output)

        # Safety check: only allow SELECT queries
        if not query.lower().startswith("select"):
            st.error("Generated SQL is not a SELECT query. Aborting.")
            return None
        return query
    except Exception as e:
        st.error(f"Error generating query: {e}")
        return None
# -----------------------------
# Execute SQL query on SQLite
# -----------------------------
def execute_query(query, conn):
    """Run ``query`` on ``conn`` and return the result rows as a DataFrame.

    On any failure the error is shown with ``st.error`` and an empty
    DataFrame is returned, so callers can always check ``.empty``.
    """
    try:
        return pd.read_sql_query(query, conn)
    except Exception as exc:
        st.error(f"Query execution error: {exc}")
    return pd.DataFrame()
# -----------------------------
# Streamlit app
# -----------------------------
def main():
    """Run the Streamlit app.

    Flow:
        1. Connect to the SQLite database.
        2. Load an uploaded CSV into the ``csv_data`` table, if one is given.
        3. Turn the user's plain-English question into SQL via the Hugging
           Face model.
        4. Show the resulting rows, with a CSV download button.
    """
    st.title("Natural Language to SQL Query and Output Generator (SQLite + HF API)")

    # Connect to SQLite
    conn = connect_sqlite()
    if conn is None:
        st.stop()

    try:
        # Upload CSV
        uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
        if uploaded_file:
            # The upload is file-like and pandas reads it directly. This avoids
            # copying it to a NamedTemporaryFile(delete=False), which is never
            # removed and so leaves one file behind each time the script runs.
            # Rewind first in case the buffer was already read.
            uploaded_file.seek(0)
            load_csv_to_sqlite(uploaded_file, conn)

        st.markdown("---")
        user_input = st.text_input("Ask your query (in plain English):")
        if user_input:
            query = generate_query(user_input, conn)
            if query:
                st.code(query, language="sql")
                data = execute_query(query, conn)
                if not data.empty:
                    st.dataframe(data)
                    # Download result as CSV
                    csv = data.to_csv(index=False).encode("utf-8")
                    st.download_button("Download Result as CSV", csv, "result.csv", "text/csv")
                else:
                    st.warning("No matching records found.")
            else:
                st.error("Could not generate a valid SQL query.")
    finally:
        # main() opens a new connection every time it runs; close it so
        # connections do not accumulate.
        conn.close()
# Start the app when this file is executed directly (`streamlit run` also
# executes the file as __main__).
if __name__ == "__main__":
    main()