import sqlite3
from contextlib import closing

import pandas as pd
import streamlit as st
from openai import OpenAI

from init_db import initialize_database

# Streamlit app: turns a natural-language request into SQL with an OpenAI chat
# model, runs the generated SQL against the local SQLite database, and shows
# both the query and its result set.

# Path of the SQLite file queried below.
# NOTE(review): presumably the same file initialize_database() prepares - confirm.
DATABASE_PATH = "database.db"

# Chat model used for query generation.
MODEL_NAME = "gpt-4o"

# System prompt: task description plus the schema the model may query.
SYSTEM_CONTEXT = """Given the following SQL tables, your job is to write queries given a user's request.
CREATE TABLE Produkte (
ProduktID INTEGER PRIMARY KEY AUTOINCREMENT,
Produktname TEXT NOT NULL,
Preis REAL NOT NULL
);

CREATE TABLE Bestellungen (
BestellungID INTEGER PRIMARY KEY AUTOINCREMENT,
ProduktID INTEGER NOT NULL,
Menge INTEGER NOT NULL,
Bestelldatum TEXT NOT NULL,
Person TEXT NOT NULL,
FOREIGN KEY (ProduktID) REFERENCES Produkte(ProduktID)
);"""

# Language tags accepted on an opening Markdown fence (empty string = no tag).
_FENCE_TAGS = {"", "sql", "sqlite", "sqlite3"}


def _strip_code_fences(text):
    """Return *text* stripped of whitespace and of one surrounding ``` fence.

    The prompt asks the model for raw SQL, but a fenced reply would otherwise
    be handed to SQLite verbatim and fail to parse.
    """
    cleaned = text.strip()
    if len(cleaned) >= 6 and cleaned.startswith("```") and cleaned.endswith("```"):
        inner = cleaned[3:-3]
        first_line, newline, rest = inner.partition("\n")
        # Only drop the first line when it is a fence tag, never real SQL.
        if newline and first_line.strip().casefold() in _FENCE_TAGS:
            inner = rest
        cleaned = inner.strip()
    return cleaned


def _run_query(sql):
    """Execute *sql* on the local database and return ``(column_names, rows)``.

    The database is opened read-only (``mode=ro``) because *sql* is produced
    by a language model: statements that try to modify data raise
    ``sqlite3.OperationalError`` instead of changing the file.  The connection
    is closed even when execution fails.

    Raises:
        sqlite3.Error: if the database cannot be opened or the SQL is rejected.
    """
    uri = f"file:{DATABASE_PATH}?mode=ro"
    with closing(sqlite3.connect(uri, uri=True)) as conn:
        cursor = conn.execute(sql)
        rows = cursor.fetchall()
        # description is None when the statement produced no result set.
        column_names = [column[0] for column in cursor.description or ()]
    return column_names, rows


initialize_database()

st.set_page_config(page_title="Zero SQL", layout="wide")
st.title("Zero SQL - Natural Language to SQL Query")

# Sidebar: the user supplies the OpenAI key on every session.
with st.sidebar:
    st.header("API Configuration")
    api_key = st.text_input("OpenAI API Key", type="password")

# Main form: free-text request plus submit button.
with st.form("query_form"):
    user_input = st.text_area(
        "Enter your data request in natural language:",
        placeholder="e.g. Show all orders from last week",
        height=150,
    )
    submitted = st.form_submit_button("Generate Query")

if submitted:
    if not api_key:
        st.error("🔑 API key is required!")
    elif not user_input.strip():
        # Whitespace-only input is treated the same as no input.
        st.error("📝 Please enter your data request!")
    else:
        try:
            client = OpenAI(api_key=api_key)
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_CONTEXT},
                    {"role": "user", "content": f"Generate the SQL query for: {user_input}. Only output the raw SQL query without any code block delimiters or markdown."},
                ],
                response_format={"type": "text"},
            )

            # content can be None, so fall back to an empty string.
            sql_query = _strip_code_fences(response.choices[0].message.content or "")

            if not sql_query:
                st.error("The model did not return a SQL query.")
            else:
                # Show the query before running it so it stays visible even
                # when execution fails.
                st.subheader("Generated SQL Query")
                st.code(sql_query, language="sql")

                column_names, results = _run_query(sql_query)

                st.subheader("Query Results")
                if results:
                    # st.dataframe has no `columns` argument; attach the
                    # column names through a DataFrame instead.
                    st.dataframe(
                        pd.DataFrame(results, columns=column_names),
                        use_container_width=True,
                        hide_index=True,
                    )
                else:
                    st.info("No results found", icon="ℹ️")

        except sqlite3.Error as e:
            st.error(f"SQL Error: {e}")
        except Exception as e:
            # Top-level UI boundary: surface API or other failures to the user
            # instead of crashing the script run.
            st.error(f"An error occurred: {e}")