Spaces:
Sleeping
Sleeping
import os
import re
import sqlite3
import urllib.request
from contextlib import closing
from pathlib import Path
from typing import Optional, Tuple

import gradio as gr
import matplotlib.pyplot as plt
import openai
import pandas as pd
import seaborn as sns
# OpenRouter credentials and model. The API key is read from the environment
# (e.g. a Hugging Face Space secret) instead of being committed to source: a key
# checked into a repository must be treated as compromised, so the previously
# hard-coded key should be revoked. An empty key lets the app start; API calls
# then fail and are reported to the user as an "Error: ..." string.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
OPENROUTER_MODEL = os.environ.get(
    "OPENROUTER_MODEL", "sophosympatheia/rogue-rose-103b-v0.2:free"
)
# SQLite database queried by the app (relative paths resolve against the CWD).
DB_PATH = os.environ.get("ECOMMERCE_DB_PATH", "ecommerce.db")
# Where to fetch the SQLite dataset from on first start.
# NOTE(review): placeholder URL carried over from the original - set the
# DATASET_URL environment variable (or edit here) to the real link.
DATASET_URL = os.environ.get(
    "DATASET_URL", "https://your-dataset-link.com/ecommerce.db"
)

# Ensure the dataset exists. Download to a temporary ".part" file and rename it
# into place only after the transfer completes, so a failed or interrupted
# download never leaves a partial file at DB_PATH. No shell is involved.
if not os.path.exists(DB_PATH):
    _partial_path = DB_PATH + ".part"
    try:
        urllib.request.urlretrieve(DATASET_URL, _partial_path)
        os.replace(_partial_path, DB_PATH)
    except OSError as exc:  # URLError/HTTPError are OSError subclasses
        # Best-effort, like the original shell call: report and keep starting.
        print(f"Warning: could not download dataset from {DATASET_URL}: {exc}")
        if os.path.exists(_partial_path):
            os.remove(_partial_path)
# Base URL of OpenRouter's OpenAI-compatible REST API.
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
# Shared client: the OpenAI SDK pointed at OpenRouter rather than OpenAI itself.
openai_client = openai.OpenAI(
    base_url=OPENROUTER_BASE_URL,
    api_key=OPENROUTER_API_KEY,
)
def fetch_schema(db_path: str) -> str:
    """Describe every table in the SQLite database at *db_path* as plain text.

    The text is intended for an LLM prompt and has the form::

        Table: <table name>
         Column: <column name>, Type: <declared type>

    Args:
        db_path: Path to the SQLite database file.

    Returns:
        One ``Table:`` line per table, each followed by one ``Column:`` line per
        column; an empty string when the database has no tables.
    """
    lines = []
    # closing() releases the connection even if one of the queries raises.
    with closing(sqlite3.connect(db_path)) as conn:
        table_names = [
            row[0]
            for row in conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table';"
            )
        ]
        for table_name in table_names:
            # PRAGMA arguments cannot be bound as parameters, so quote the
            # identifier ourselves; unquoted, a name containing spaces or a
            # keyword is a syntax error.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            lines.append(f"Table: {table_name}\n")
            for column in conn.execute(f"PRAGMA table_info({quoted_name});"):
                # table_info rows: (cid, name, type, notnull, dflt_value, pk)
                lines.append(f" Column: {column[1]}, Type: {column[2]}\n")
    return "".join(lines)
def extract_sql_query(response: str) -> str:
    """Pull the SQL statement out of an LLM reply.

    Tries, in order:

    1. the first fenced block tagged ``sql`` or ``sqlite`` (tag matched
       case-insensitively);
    2. the first fenced block with any other or no language tag (the tag line
       is skipped);
    3. the whole reply, unchanged.

    Args:
        response: Raw text returned by the model.

    Returns:
        The fenced content with surrounding whitespace stripped (cases 1-2), or
        *response* itself when it contains no fenced block.
    """
    # Matching only a lowercase ```sql fence missed ```SQL, ```sqlite and bare
    # ``` fences; the reply was then executed with its backticks and failed.
    tagged = re.search(
        r"```[ \t]*sql(?:ite)?(.*?)```", response, re.IGNORECASE | re.DOTALL
    )
    if tagged:
        return tagged.group(1).strip()
    fenced = re.search(r"```[^\n`]*\n(.*?)```", response, re.DOTALL)
    if fenced:
        return fenced.group(1).strip()
    return response  # no code fence: assume the reply is the SQL itself
def text_to_sql(query: str, schema: str) -> str:
    """Ask the configured OpenRouter model to translate *query* into SQL.

    Args:
        query: Natural-language question from the user.
        schema: Text description of the database, as built by fetch_schema.

    Returns:
        The SQL extracted from the reply by extract_sql_query, or a string that
        starts with ``"Error: "`` when the API call fails or the model returns
        no content.
    """
    # Name the dialect explicitly: results are executed by SQLite, and without
    # this the model tends to emit MySQL-style functions such as YEAR()/MONTH().
    prompt = (
        "You are an SQL expert. Given the following SQLite database schema:\n\n"
        f"{schema}\n\n"
        "Convert the following query into a single SQLite SELECT statement. "
        "Return only the SQL inside a ```sql code block.\n\n"
        f"Query: {query}\n"
        "SQL:"
    )
    try:
        response = openai_client.chat.completions.create(
            model=OPENROUTER_MODEL,
            messages=[
                {"role": "system", "content": "You are an SQL expert."},
                {"role": "user", "content": prompt},
            ],
        )
        content = response.choices[0].message.content
    except Exception as e:  # network, auth, rate limit, malformed response, ...
        return f"Error: {e}"
    if not content:
        return "Error: the model returned an empty response"
    return extract_sql_query(content.strip())
def preprocess_sql_for_sqlite(sql_query: str) -> str:
    """
    Replace non-SQLite date functions with SQLite-compatible equivalents.

    ``YEAR(x)``, ``MONTH(x)`` and ``DAY(x)`` (in any letter case) are rewritten
    to ``CAST(strftime(fmt, x) AS INTEGER)``. The cast is essential: strftime
    returns zero-padded TEXT, and SQLite never considers a TEXT value equal to
    an INTEGER, so an uncast ``strftime('%Y', d) = 2018`` matches no rows.

    Only simple arguments (identifiers, optionally dotted such as
    ``o.order_date``) are rewritten; other expressions are left untouched.
    """
    for func_name, fmt in (("YEAR", "%Y"), ("MONTH", "%m"), ("DAY", "%d")):
        sql_query = re.sub(
            rf"\b{func_name}\s*\(\s*([\w.]+)\s*\)",
            rf"CAST(strftime('{fmt}', \1) AS INTEGER)",
            sql_query,
            flags=re.IGNORECASE,
        )
    return sql_query
def execute_sql(
    sql_query: str, db_path: Optional[str] = None
) -> Tuple[Optional[pd.DataFrame], Optional[str]]:
    """Run *sql_query* against the SQLite database and return the result rows.

    The query is produced by an LLM from user text, so it is untrusted. The
    database is therefore opened read-only, and statements such as DROP, DELETE
    or UPDATE fail instead of modifying data.

    Args:
        sql_query: SQL to execute. MySQL-style date functions are rewritten by
            preprocess_sql_for_sqlite first.
        db_path: Database file to query; defaults to the module-level DB_PATH.

    Returns:
        ``(df, None)`` on success, or ``(None, "SQL Execution Error: ...")`` on
        any failure, including a missing database file.
    """
    try:
        path = DB_PATH if db_path is None else db_path
        # mode=ro also stops sqlite3 from silently creating an empty database
        # when the file is missing. Path.as_uri() percent-escapes characters
        # that are special in URIs (e.g. '?', '#').
        uri = Path(path).resolve().as_uri() + "?mode=ro"
        # closing() releases the connection even when the query raises.
        with closing(sqlite3.connect(uri, uri=True)) as conn:
            df = pd.read_sql_query(preprocess_sql_for_sqlite(sql_query), conn)
        return df, None
    except Exception as e:
        return None, f"SQL Execution Error: {e}"
def visualize_data(
    df: pd.DataFrame, output_path: str = "chart.png"
) -> Optional[str]:
    """Render a quick chart of a query result and save it as a PNG.

    The chart type depends on how many numeric columns *df* has:

    - exactly one: histogram of that column;
    - exactly two: scatter plot of the first against the second;
    - three or more with fewer than 10 rows: pie of the first numeric column,
      labelled by the first column (falls back to a bar chart when the values
      cannot form a pie);
    - otherwise: bar chart of the first numeric column against the first column.

    Args:
        df: Query result to plot.
        output_path: Destination PNG file (default ``"chart.png"`` in the CWD).

    Returns:
        *output_path* when a chart was written, or ``None`` when there is nothing
        to plot (empty frame, fewer than two columns, or no numeric column).
    """
    if df.empty or df.shape[1] < 2:
        return None
    numeric_cols = df.select_dtypes(include=["number"]).columns
    # Check before creating a figure: the original allocated one first and
    # never released it on this path.
    if len(numeric_cols) < 1:
        return None

    first_numeric = numeric_cols[0]
    values = df[first_numeric]
    sns.set_theme(style="darkgrid")
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        if len(numeric_cols) == 1:  # single numeric column, assume a count metric
            sns.histplot(values, bins=10, kde=True, color="teal", ax=ax)
            ax.set_title(f"Distribution of {first_numeric}")
        elif len(numeric_cols) == 2:  # two numeric columns, assume an X-Y plot
            sns.scatterplot(
                x=df[numeric_cols[0]], y=df[numeric_cols[1]], color="blue", ax=ax
            )
            ax.set_title(f"{numeric_cols[0]} vs {numeric_cols[1]}")
        elif (
            df.shape[0] < 10
            # pie() rejects negative wedge sizes, and NaN or all-zero columns
            # cannot form a meaningful pie; those use the bar chart below.
            and values.notna().all()
            and (values >= 0).all()
            and values.sum() > 0
        ):
            ax.pie(
                values,
                labels=df.iloc[:, 0],
                autopct="%1.1f%%",
                colors=sns.color_palette("pastel"),
            )
            ax.set_title(f"Proportion of {first_numeric}")
        else:  # default: bar chart of first column vs first numeric column
            sns.barplot(x=df.iloc[:, 0], y=values, palette="coolwarm", ax=ax)
            ax.tick_params(axis="x", labelrotation=45)
            ax.set_title(f"{df.columns[0]} vs {first_numeric}")
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # pyplot keeps every open figure alive until it is closed; without this,
        # each request leaked a figure.
        plt.close(fig)
    return output_path
def gradio_ui(query: str) -> Tuple[str, str, Optional[str]]:
    """Gradio callback: turn a question into SQL, run it, and chart the result.

    Args:
        query: Natural-language question typed by the user.

    Returns:
        ``(sql_query, result_text, chart_path)``. When a step fails,
        ``result_text`` carries the error message and ``chart_path`` is ``None``.
    """
    # Don't spend an API call on an empty input box.
    if not query or not query.strip():
        return "", "Please enter a query.", None
    try:
        schema = fetch_schema(DB_PATH)
    except sqlite3.Error as e:
        # e.g. DB_PATH exists but is not a valid SQLite database.
        return "", f"Schema Error: {e}", None
    sql_query = text_to_sql(query, schema)
    # text_to_sql reports API failures as an "Error: ..." string. Running that
    # through SQLite would only replace the real cause with a syntax error.
    if sql_query.startswith("Error:"):
        return sql_query, sql_query, None
    df, error = execute_sql(sql_query)
    if error:
        return sql_query, error, None
    if df is None:  # not produced by execute_sql today; keeps to_string() safe
        return sql_query, "No result returned.", None
    return sql_query, df.to_string(index=False), visualize_data(df)
# Gradio front end: a question box and a button, wired to gradio_ui, plus three
# output panes (generated SQL, result table as text, chart image). The app is
# launched as soon as this module runs.
with gr.Blocks() as demo:
    gr.Markdown("## SQL Explorer: Text-to-SQL with Real Execution & Visualization")
    query_input = gr.Textbox(
        label="Enter your query",
        placeholder="e.g., Show all products sold in 2018.",
    )
    submit_btn = gr.Button("Convert & Execute")
    sql_output = gr.Textbox(label="Generated SQL Query")
    table_output = gr.Textbox(label="Query Results")
    chart_output = gr.Image(label="Data Visualization")
    submit_btn.click(
        fn=gradio_ui,
        inputs=[query_input],
        outputs=[sql_output, table_output, chart_output],
    )

demo.launch()