import sqlite3
import tempfile

import pandas as pd
import streamlit as st

from model_api import query_hf_model  # Hugging Face API wrapper


# -----------------------------
# Connect to SQLite
# -----------------------------
@st.cache_resource
def connect_sqlite(db_path="data.db"):
    """Open the SQLite database used by the app.

    The connection is cached with ``st.cache_resource`` so script reruns reuse it.

    Args:
        db_path: Path to the SQLite database file.

    Returns:
        A ``sqlite3.Connection``, or ``None`` if the database could not be
        opened (an error message is shown in the UI in that case).
    """
    try:
        # check_same_thread=False lets the cached connection be used from
        # whichever thread Streamlit happens to run the script on.
        return sqlite3.connect(db_path, check_same_thread=False)
    except sqlite3.Error as e:
        st.error(f"Error connecting to SQLite: {e}")
        return None


# -----------------------------
# Load CSV to SQLite
# -----------------------------
def load_csv_to_sqlite(file, conn):
    """Load a CSV into the ``csv_data`` table, replacing any previous contents.

    Args:
        file: A path or file-like object accepted by ``pandas.read_csv``.
        conn: Open SQLite connection.
    """
    try:
        df = pd.read_csv(file)
        df.to_sql("csv_data", conn, if_exists="replace", index=False)
    except Exception as e:  # UI boundary: report any parse/write failure to the user
        st.error(f"Error loading CSV: {e}")
    else:
        st.success("CSV data loaded into SQLite successfully.")


# -----------------------------
# Generate SQL query using HF API
# -----------------------------
def _get_field_names(conn):
    """Return the column names of ``csv_data`` ([] when the table does not exist)."""
    cursor = conn.execute("PRAGMA table_info(csv_data)")
    return [row[1] for row in cursor.fetchall()]


def _get_hf_token():
    """Return the ``HF_TOKEN`` secret, or ``None`` when it is missing or empty."""
    try:
        # Indexing (not .get) so a missing key raises KeyError, which we map to
        # None; the original code let that KeyError skip the friendly message.
        token = st.secrets["HF_TOKEN"]
    except (KeyError, FileNotFoundError):
        # FileNotFoundError: no secrets file configured at all. Any other
        # exception type still reaches generate_query's own handler.
        return None
    return token or None


def _build_prompt(user_input, field_names):
    """Build the instruction prompt asking the model for one SQLite SELECT query."""
    return (
        "You are a SQLite expert. Only respond with a single SQLite SELECT query, "
        "with no explanation and no Markdown, in this exact format:\n"
        "SELECT column1, column2 FROM csv_data WHERE condition;\n"
        "\n"
        "Rules:\n"
        f"- Use only these fields from the 'csv_data' table: {field_names}\n"
        "- Use field names exactly as spelled above.\n"
        "- Wrap field names that contain spaces or special characters in "
        'double quotes, e.g. "Launch Date".\n'
        "- String values must be in single quotes.\n"
        "- For GROUP BY queries, do not include non-aggregated columns in SELECT "
        "unless they are also in GROUP BY.\n"
        "- SQLite has no YEAR() function. To extract the year from a date column "
        "(e.g., LaunchDate), use strftime('%Y', LaunchDate); it returns text, so "
        "compare it with a quoted value such as '2020'.\n"
        "- Assume all dates are stored as TEXT in the format 'YYYY-MM-DD'.\n"
        "\n"
        f'User request: "{user_input}"\n'
    )


def _clean_sql(text):
    """Strip surrounding whitespace and Markdown code fences from model output."""
    sql = text.strip()
    if sql.startswith("```"):
        # Drop the opening fence line (e.g. ```sql) and a trailing fence, if any.
        _, _, sql = sql.partition("\n")
        sql = sql.strip()
        if sql.endswith("```"):
            sql = sql[:-3].strip()
    return sql


def generate_query(user_input, conn):
    """Ask the Hugging Face model to translate a question into a SELECT query.

    Args:
        user_input: The user's question in plain English.
        conn: Open SQLite connection; the ``csv_data`` schema is read from it.

    Returns:
        The generated SQL string, or ``None`` on any failure (the reason is
        shown in the UI).
    """
    try:
        field_names = _get_field_names(conn)
        if not field_names:
            st.error("No data found. Please upload a CSV file first.")
            return None

        hf_token = _get_hf_token()
        if not hf_token:
            st.error("HF_TOKEN not found. Please add it in Streamlit secrets.")
            return None

        # Call Hugging Face API (presumably returns the model's text; an empty
        # or None result is treated as a failure).
        raw = query_hf_model(_build_prompt(user_input, field_names), hf_token)
        if not raw:
            st.error("The model returned an empty response.")
            return None
        query = _clean_sql(raw)

        # Safety check: only allow SELECT queries. The query is run through
        # sqlite3, which refuses to execute more than one statement per call,
        # so "SELECT ...; DROP ..." payloads are rejected at execution time.
        if not query.lower().startswith("select"):
            st.error("Generated SQL is not a SELECT query. Aborting.")
            return None
        return query
    except Exception as e:  # UI boundary: surface any failure instead of crashing
        st.error(f"Error generating query: {e}")
        return None


# -----------------------------
# Execute SQL query on SQLite
# -----------------------------
def execute_query(query, conn):
    """Run *query* on *conn* and return the result as a DataFrame.

    Returns an empty DataFrame if execution fails (the error is shown in the UI).
    """
    try:
        return pd.read_sql_query(query, conn)
    except Exception as e:  # UI boundary: report SQL errors to the user
        st.error(f"Query execution error: {e}")
        return pd.DataFrame()


# -----------------------------
# Streamlit app
# -----------------------------
def main():
    """Render the app: upload a CSV, ask a question, show and download results."""
    st.title("Natural Language to SQL Query and Output Generator (SQLite + HF API)")

    # Connect to SQLite
    conn = connect_sqlite()
    if conn is None:
        st.stop()

    # Upload CSV. pandas reads the uploaded file object directly, so there is
    # no need to spool it to a temp file (the old delete=False file leaked on
    # every rerun).
    uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
    if uploaded_file is not None:
        uploaded_file.seek(0)
        load_csv_to_sqlite(uploaded_file, conn)

    st.markdown("---")
    user_input = st.text_input("Ask your query (in plain English):")
    if not user_input:
        return

    query = generate_query(user_input, conn)
    if not query:
        st.error("Could not generate a valid SQL query.")
        return

    st.code(query, language="sql")
    data = execute_query(query, conn)
    if data.empty:
        st.warning("No matching records found.")
        return

    st.dataframe(data)
    # Download result as CSV
    csv = data.to_csv(index=False).encode("utf-8")
    st.download_button("Download Result as CSV", csv, "result.csv", "text/csv")


if __name__ == "__main__":
    main()