File size: 4,163 Bytes
3743009
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import pandas as pd
import sqlite3
import streamlit as st
import tempfile
from model_api import query_hf_model  # Hugging Face API wrapper

# -----------------------------
# Connect to SQLite
# -----------------------------
@st.cache_resource
def connect_sqlite(db_path="data.db"):
    """Return a SQLite connection to *db_path*, or ``None`` if opening it fails.

    Decorated with ``st.cache_resource`` so the same connection object is reused
    across Streamlit reruns instead of reconnecting every time. Failures are
    reported in the UI via ``st.error`` rather than raised.

    NOTE(review): presumably the cache also keeps a ``None`` result, so a failed
    connect may not be retried until the cache is cleared — confirm.
    """
    connection = None
    try:
        # The cached connection may be used from a thread other than the one
        # that created it; sqlite3 refuses that unless this check is disabled.
        connection = sqlite3.connect(db_path, check_same_thread=False)
    except Exception as err:
        st.error(f"Error connecting to SQLite: {err}")
    return connection

# -----------------------------
# Load CSV to SQLite
# -----------------------------
def load_csv_to_sqlite(file, conn):
    """Load a CSV into the ``csv_data`` table, replacing any existing table.

    Args:
        file: Anything ``pd.read_csv`` accepts (a path or a file-like object).
        conn: Open sqlite3 connection to write into.

    Success or failure is reported in the Streamlit UI (``st.success`` /
    ``st.error``); no exception is propagated to the caller.
    """
    try:
        table = pd.read_csv(file)
        # Drop and recreate the table so each upload fully replaces the old data;
        # the DataFrame index is not stored as a column.
        table.to_sql(name="csv_data", con=conn, if_exists="replace", index=False)
        st.success("CSV data loaded into SQLite successfully.")
    except Exception as exc:
        st.error(f"Error loading CSV: {exc}")

# -----------------------------
# Generate SQL query using HF API
# -----------------------------
def _strip_code_fences(text):
    """Return *text* trimmed and without a surrounding Markdown code fence.

    Models often wrap SQL as ```sql ... ``` even when told not to; left in place,
    the fence would make the SELECT-only check reject an otherwise valid query.
    """
    cleaned = text.strip()
    if cleaned.startswith("```"):
        newline = cleaned.find("\n")
        # Drop the opening fence together with any language tag (e.g. ```sql).
        cleaned = cleaned[newline + 1:] if newline != -1 else cleaned[3:]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
    return cleaned.strip()


def generate_query(user_input, conn):
    """Ask the Hugging Face model to translate *user_input* into a SQLite SELECT.

    Args:
        user_input: Natural-language question typed by the user.
        conn: Open sqlite3 connection; the column names of its ``csv_data``
            table are read and listed in the prompt.

    Returns:
        The generated SQL string, or ``None`` when no table is loaded, the
        ``HF_TOKEN`` secret is missing, the model returns nothing or a non-SELECT
        statement, or any error occurs. Problems are reported in the UI
        (``st.warning`` / ``st.error``) rather than raised.
    """
    try:
        # Get column names from SQLite
        cursor = conn.cursor()
        cursor.execute("PRAGMA table_info(csv_data)")
        field_names = [row[1] for row in cursor.fetchall()]
        if not field_names:
            # PRAGMA table_info yields no rows when csv_data does not exist,
            # i.e. no CSV has been loaded yet; don't spend an API call on it.
            st.warning("No data loaded yet. Please upload a CSV file first.")
            return None

        # Build prompt for HF model. The query is executed by SQLite, so the
        # prompt must request SQLite syntax (SQLite has no MySQL YEAR()).
        prompt = f"""You are a SQLite expert. Only respond with a single SQLite SELECT query in this exact format:

SELECT column1, column2 FROM csv_data WHERE condition;

Do not wrap the query in Markdown code fences and do not add any explanation.

Rules:
- Use only these fields from the 'csv_data' table: {field_names}
- Use field names exactly as listed; wrap a field name in double quotes if it contains spaces or special characters.
- String values must be in single quotes.
- For GROUP BY queries, do not include non-aggregated columns in SELECT unless they are also in GROUP BY.
- If extracting the year from a date column (e.g., LaunchDate), use CAST(strftime('%Y', column) AS INTEGER); SQLite has no YEAR() function.
- Assume all dates are stored as TEXT in the format 'YYYY-MM-DD'.

User request:
"{user_input}"
"""

        # Read HF token from Streamlit secrets. .get() lets a missing key reach
        # the friendly message below instead of surfacing as a bare KeyError.
        hf_token = st.secrets.get("HF_TOKEN")
        if not hf_token:
            st.error("HF_TOKEN not found. Please add it in Streamlit secrets.")
            return None

        # Call Hugging Face API (assumes query_hf_model returns a str or None —
        # TODO confirm against model_api).
        raw_output = query_hf_model(prompt, hf_token)
        if not raw_output:
            st.error("The model returned an empty response.")
            return None
        query = _strip_code_fences(raw_output)

        # Safety check: only allow SELECT queries.
        # NOTE(review): a prefix check is a heuristic, not a sandbox.
        if not query.lower().startswith("select"):
            st.error("Generated SQL is not a SELECT query. Aborting.")
            return None

        return query
    except Exception as e:
        st.error(f"Error generating query: {e}")
        return None

# -----------------------------
# Execute SQL query on SQLite
# -----------------------------
def execute_query(query, conn):
    """Run *query* on *conn* and return the rows as a DataFrame.

    Args:
        query: SQL text to execute.
        conn: Open sqlite3 connection.

    Returns:
        The query result, or an empty ``DataFrame`` if execution fails (the
        error is shown via ``st.error``). Callers therefore cannot tell a failed
        query from one that matched no rows.
    """
    try:
        result = pd.read_sql_query(query, conn)
    except Exception as exc:
        st.error(f"Query execution error: {exc}")
        result = pd.DataFrame()
    return result

# -----------------------------
# Streamlit app
# -----------------------------
def main():
    """Render the Streamlit app.

    On each script run:
      1. Obtain the (cached) SQLite connection; stop the app if it is unavailable.
      2. If a CSV is uploaded, (re)load it into the ``csv_data`` table.
      3. If the user entered a question, generate SQL with the HF model, show it,
         execute it, and display the rows with a CSV download button.
    """
    st.title("Natural Language to SQL Query and Output Generator (SQLite + HF API)")

    # Connect to SQLite
    conn = connect_sqlite()
    if conn is None:
        st.stop()

    # Upload CSV. The uploaded object is file-like, so pandas reads it directly;
    # copying it to a NamedTemporaryFile(delete=False) left one undeleted temp
    # file behind on every script rerun.
    uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
    if uploaded_file:
        # Rewind in case the buffer has already been read during this run.
        uploaded_file.seek(0)
        load_csv_to_sqlite(uploaded_file, conn)

    st.markdown("---")
    user_input = st.text_input("Ask your query (in plain English):")
    if user_input:
        query = generate_query(user_input, conn)
        if query:
            st.code(query, language="sql")
            data = execute_query(query, conn)
            if not data.empty:
                st.dataframe(data)

                # Download result as CSV
                csv = data.to_csv(index=False).encode("utf-8")
                st.download_button("Download Result as CSV", csv, "result.csv", "text/csv")
            else:
                # execute_query also returns an empty frame on errors, in which
                # case an st.error has already been shown above this warning.
                st.warning("No matching records found.")
        else:
            st.error("Could not generate a valid SQL query.")

# Start the app when this file is executed as the main script. NOTE(review):
# presumably launched via `streamlit run`, which runs it as __main__ — confirm.
if __name__ == "__main__":
    main()