| | |
| |
|
| | import os |
| | import streamlit as st |
| | import pandas as pd |
| | import subprocess |
| | import json |
| | import plotly.express as px |
| | import re |
| | import io |
| | import requests |
| | from sqlalchemy import create_engine, text, inspect |
| |
|
| | |
# Hugging Face Inference API token, read from the environment.
# Falls back to "" when unset — API calls will then fail with an auth error.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
| |
|
| | |
def mistral_call(schema=None, question="no questions were asked", hf_token=HF_TOKEN, model_id="mistralai/Mistral-7B-Instruct-v0.3", timeout=60):
    """Translate a natural-language question to text via the HF Inference API.

    Builds an instruction prompt (optionally including a database schema),
    POSTs it to the hosted model, and returns the portion of the generated
    text that follows the final "### Question:" marker.

    Errors never raise: network failures, non-200 responses, and malformed
    response bodies are all reported as human-readable error strings, which
    is the convention callers in this file rely on.

    Args:
        schema: optional schema text interpolated into the prompt.
        question: the user's natural-language request.
        hf_token: bearer token for the API (defaults to env HF_TOKEN).
        model_id: HF model repository id.
        timeout: seconds before the HTTP request is aborted (new, defaulted,
            so existing callers are unaffected; prevents the UI hanging
            forever on a stalled request).

    Returns:
        The model's answer text, or an error-description string.
    """
    api_url = f"https://api-inference.huggingface.co/models/{model_id}"
    headers = {
        "Authorization": f"Bearer {hf_token}",
        "Content-Type": "application/json"
    }
    prompt = f"""You are a helpful assistant that translates natural language questions into SQL using a database schema.
### Schema:
{schema}
### Question:
{question}
"""
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 500,
            "do_sample": True,
            "temperature": 0.3,
        }
    }
    # Guard the network call: without a timeout requests.post can block
    # indefinitely, and a connection error would otherwise crash the app
    # instead of following this function's return-an-error-string contract.
    try:
        response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as e:
        return f"API call failed: {e}"
    if response.status_code == 200:
        try:
            # Response shape: [{"generated_text": "<prompt + completion>"}].
            generated = response.json()[0]['generated_text']
            # The API echoes the prompt; keep only what follows the question.
            return generated.split("### Question:")[-1].strip()
        except Exception as e:
            return f"Error parsing response: {e}"
    else:
        return f"API call failed: {response.status_code}\n{response.text}"
| |
|
| | |
def extract_json(text):
    """Extract the first valid JSON object embedded in *text*.

    The previous non-greedy regex (r"\\{.*?\\}") stopped at the first '}',
    which silently broke on any nested object. raw_decode parses a complete
    object starting at a '{', handling nesting correctly; we retry from each
    subsequent '{' so stray braces before the real JSON don't defeat it.

    Returns:
        The decoded object (a dict for '{...}' input), or None when no
        parseable JSON object is present.
    """
    decoder = json.JSONDecoder()
    idx = text.find("{")
    while idx != -1:
        try:
            obj, _ = decoder.raw_decode(text, idx)
            return obj
        except json.JSONDecodeError:
            idx = text.find("{", idx + 1)
    return None
| |
|
def get_visualization_suggestion(data):
    """Ask the model for one chart specification for *data*'s columns.

    Returns the parsed suggestion dict ({"x", "y", "chart_type"}) or None
    when the model's reply contains no parseable JSON.
    """
    column_names = list(data.columns)
    prompt = f"""
These are the dataset column names: {column_names}.
Suggest one visualization using the format:
{{"x": "column", "y": "column or list", "chart_type": "bar/line/scatter/pie"}}
"""
    model_reply = mistral_call(question=prompt)
    return extract_json(model_reply)
| |
|
| | |
def generate_demo_data_csv(user_input, num_rows=10):
    """Generate a demo CSV dataset from a natural-language description.

    Asks the model for quoted CSV, keeps only lines that start with a
    double quote (the data rows/header), and round-trips them through
    pandas to validate the CSV.

    Returns:
        (status_message, io.StringIO buffer) on success,
        (error_message, None) otherwise.
    """
    prompt = f"""
Generate a {num_rows}-row structured dataset in CSV format with quoted column headers and values:
"{user_input}"
"""
    raw_reply = mistral_call(question=prompt)
    quoted_lines = [row.strip() for row in raw_reply.splitlines() if row.strip().startswith('"')]
    csv_text = "\n".join(quoted_lines)
    if not csv_text:
        return "No CSV found.", None
    try:
        frame = pd.read_csv(io.StringIO(csv_text))
        out_buffer = io.StringIO()
        frame.to_csv(out_buffer, index=False)
    except Exception as e:
        return f"CSV error: {e}", None
    return "Demo data generated.", out_buffer
| |
|
| | |
def extract_sql_code_blocks(text):
    """Return the contents of every ```sql ...``` fenced block in *text*."""
    fence_pattern = re.compile(r"```sql\s+(.*?)```", re.DOTALL | re.IGNORECASE)
    return fence_pattern.findall(text)
| |
|
def remove_think_tags(text):
    """Strip any <think>...</think> reasoning sections from model output."""
    think_pattern = re.compile(r'<think>.*?</think>', re.DOTALL | re.IGNORECASE)
    return think_pattern.sub('', text)
| |
|
def classify_sql_task_prompt_engineered(user_input: str) -> str:
    """Classify *user_input* into one of the supported SQL task labels.

    The model's reply is stripped of <think> sections and upper-cased,
    then scanned for the first known label it contains. Returns "UNKNOWN"
    when no label matches.
    """
    prompt = f"""
Classify into:
CREATE_TABLE, INSERT_INTO, SELECT, UPDATE, DELETE, ALTER_TABLE, INSERT_CSV_EXISTING, INSERT_CSV_NEW
Input: {user_input}
Only return the task.
"""
    raw_label = mistral_call(question=prompt)
    normalized = remove_think_tags(raw_label).strip().upper()
    known_tasks = ("CREATE_TABLE", "INSERT_INTO", "SELECT", "UPDATE", "DELETE", "ALTER_TABLE", "INSERT_CSV_EXISTING", "INSERT_CSV_NEW")
    return next((task for task in known_tasks if task in normalized), "UNKNOWN")
| |
|
def handle_query(user_input, engine, task_type):
    """Generate SQL for *user_input* with the model and execute it on *engine*.

    The prompt includes the database's current table names so the model
    can reference real tables. Returns whatever execute_sql returns:
    (sql_code, status_message); on any failure here, ("None", error text).
    """
    try:
        table_names = inspect(engine).get_table_names()
        prompt = f"Generate {task_type} SQL for: {user_input} using tables: {table_names}"
        model_reply = mistral_call(question=prompt)
        code_blocks = extract_sql_code_blocks(model_reply)
        return execute_sql(code_blocks, engine)
    except Exception as e:
        return "None", f"Error: {e}"
| |
|
def execute_sql(sql_code, engine):
    """Execute one or more SQL statements against *engine*.

    Args:
        sql_code: a SQL string, or a list of SQL strings (as returned by
            extract_sql_code_blocks), which is joined before execution.
        engine: a SQLAlchemy engine.

    Returns:
        (sql_code, success message) on success, ("None", error text) on
        failure — the string "None" (not the None object) is this file's
        established failure convention.

    NOTE(review): statements are split on ';' — this naive split breaks
    statements containing a literal ';' inside a string value.
    """
    try:
        if isinstance(sql_code, list):
            sql_code = "\n".join(sql_code)
        statements = [stmt.strip() for stmt in sql_code.split(';') if stmt.strip()]
        with engine.connect() as conn:
            for stmt in statements:
                conn.execute(text(stmt + ";"))
            conn.commit()
        # Success string restored: the source was mojibake-garbled, splitting
        # the "✅" emoji (and the string literal) across two lines.
        return sql_code, "✅ SQL executed."
    except Exception as e:
        return "None", f"SQL error: {e}"
| |
|
def insert_csv_existing(table_name, csv_file, engine):
    """Append rows from an uploaded CSV into the existing table *table_name*.

    Returns a human-readable success or error message string.
    """
    try:
        df = pd.read_csv(csv_file)
        # if_exists='append' preserves existing rows and schema.
        df.to_sql(table_name, engine, if_exists='append', index=False)
        # Message restored: the "✅" emoji was mojibake-garbled in the source,
        # breaking the string literal across two lines.
        return f"✅ CSV inserted into '{table_name}'."
    except Exception as e:
        return f"CSV insert error: {e}"
| |
|
def insert_csv_new(table_name, csv_file, engine):
    """Load an uploaded CSV into a new table *table_name*, replacing any
    existing table of that name.

    Returns a human-readable success or error message string.
    """
    try:
        df = pd.read_csv(csv_file)
        # if_exists='replace' drops and recreates the table from the CSV.
        df.to_sql(table_name, engine, if_exists='replace', index=False)
        # Message restored: the "✅" emoji was mojibake-garbled in the source,
        # breaking the string literal across two lines.
        return f"✅ CSV inserted into new table '{table_name}'."
    except Exception as e:
        return f"New CSV insert error: {e}"
| |
|
| | |
# --- Streamlit page setup and sidebar navigation ---
# NOTE(review): the "π..." characters in the title and radio labels appear to
# be mojibake-garbled emoji. They are kept byte-identical here because the
# elif branches below compare option == "<label>" against these exact strings;
# fix the encoding in all places at once or the dispatch will break.
st.set_page_config(page_title="AI Dashboard", layout="wide")
st.title("π€ AI-Powered Multi-Feature Dashboard")

st.sidebar.title("Navigation")
# The selected label drives the if/elif feature dispatch below.
option = st.sidebar.radio("Select Feature", ["π Data Visualization", "π§ SQL Query Generator", "π Demo Data Generator", "π§ Smart SQL Task Handler"])
| |
|
# Feature 1: upload a CSV, ask the model for a chart suggestion, render it.
if option == "π Data Visualization":
    uploaded_file = st.file_uploader("Upload your CSV", type="csv")
    if uploaded_file:
        try:
            # Decode explicitly so a clear error surfaces for non-UTF-8 files.
            content = uploaded_file.getvalue().decode("utf-8")
            df = pd.read_csv(io.StringIO(content))
            # Normalize column names (trim + underscores) so the model's
            # suggested column names are more likely to match exactly.
            df.columns = df.columns.str.strip().str.replace(" ", "_")
            st.write("CSV Preview")
            st.dataframe(df.head())
            st.write("Shape:", df.shape)

            with st.spinner("Getting chart suggestion..."):
                # Returns a dict like {"x": ..., "y": ..., "chart_type": ...}
                # or None when the model reply held no parseable JSON.
                suggestion = get_visualization_suggestion(df)

            st.write("Model suggestion:")
            st.code(suggestion)

            if suggestion:
                x_col = suggestion.get("x", "").strip()
                y_col = suggestion.get("y", [])
                # "y" may be a single column name or a list; normalize to list.
                y_col = [y_col] if isinstance(y_col, str) else y_col
                chart = suggestion.get("chart_type")
                # Only plot when every suggested column actually exists.
                if x_col in df.columns and all(y in df.columns for y in y_col):
                    fig = None
                    if chart == "bar":
                        fig = px.bar(df, x=x_col, y=y_col)
                    elif chart == "line":
                        fig = px.line(df, x=x_col, y=y_col)
                    elif chart == "scatter":
                        fig = px.scatter(df, x=x_col, y=y_col)
                    elif chart == "pie" and len(y_col) == 1:
                        # Pie charts take exactly one value column.
                        fig = px.pie(df, names=x_col, values=y_col[0])
                    if fig:
                        st.plotly_chart(fig)
                    else:
                        st.error("Unsupported chart type.")
                else:
                    st.error("β οΈ Column suggestion doesn't match your CSV.")
            else:
                st.error("β No valid visualization suggestion returned.")
        except Exception as e:
            st.error(f"β Error reading CSV: {e}")
| |
|
# Feature 2: free-form natural language -> SQL text (no DB connection used).
elif option == "π§ SQL Query Generator":
    user_input = st.text_area("Describe your SQL query in plain English:")
    if st.button("Generate SQL"):
        # Show the raw model reply as code; nothing is executed here.
        st.code(mistral_call(question=user_input))
| |
|
# Feature 3: model-generated demo CSV, offered as a download.
elif option == "π Demo Data Generator":
    user_input = st.text_area("Describe your dataset:")
    num_rows = st.number_input("Rows", 1, 1000, 10)
    if st.button("Generate Dataset"):
        # buffer is an io.StringIO with the CSV, or None on failure.
        msg, buffer = generate_demo_data_csv(user_input, num_rows)
        st.write(msg)
        if buffer:
            st.download_button("Download CSV", buffer.getvalue(), file_name="generated_data.csv", mime="text/csv")
| |
|
# Feature 4: classify the user's request and run it against a SQLite DB.
elif option == "π§ Smart SQL Task Handler":
    st.sidebar.header("DB Settings")
    db_type = "SQLite"  # only SQLite is wired up; kept for future DB choices
    db_path = st.sidebar.text_input("SQLite File Path", value="smart_sql.db")
    connection_url = f"sqlite:///{db_path}"
    try:
        engine = create_engine(connection_url)
        # Open and immediately close a connection purely to verify access.
        with engine.connect(): pass
        st.sidebar.success("Connected!")
    except Exception as e:
        # Without a working DB the rest of this branch is useless — halt.
        st.sidebar.error(f"Connection failed: {e}")
        st.stop()

    user_input = st.text_area("Enter SQL task (or natural language):")
    csv_file = st.file_uploader("Optional CSV Upload")
    table_name = st.text_input("Table name (for CSV):")
    if st.button("Run SQL Task"):
        # Model-based classification into CREATE_TABLE / SELECT / ... labels.
        task = classify_sql_task_prompt_engineered(user_input)
        st.markdown(f"**Detected Task:** `{task}`")
        # CSV tasks need both an uploaded file and a table name; anything
        # else (including UNKNOWN) falls through to model-generated SQL.
        if task == "INSERT_CSV_EXISTING" and csv_file and table_name:
            st.write(insert_csv_existing(table_name, csv_file, engine))
        elif task == "INSERT_CSV_NEW" and csv_file and table_name:
            st.write(insert_csv_new(table_name, csv_file, engine))
        else:
            sql_code, msg = handle_query(user_input, engine, task)
            st.code(sql_code)
            st.write(msg)
| |
|