# DataWhizz / app.py
# NOTE: the lines below were Hugging Face page residue ("DipenDRangwani's picture",
# "Update app.py", commit "3ae0f23 verified") — commented out so the file parses.
import io # For BytesIO
import streamlit as st
import pandas as pd
import duckdb
import os
import re
import uuid
import logging
import traceback
import io # Add this for BytesIO
import plotly.express as px # Add this for Plotly
from openai import OpenAI
from typing import List, Dict, Any, Optional, Tuple
# Configure logging: timestamped records at INFO level for the whole app.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger shared by every function in this file.
logger = logging.getLogger("text-to-sql-app")
def main():
    """Streamlit entry point: renders the upload tab and the analysis tab.

    Flow: configure the page, obtain an OpenAI client (stops early if no API
    key is available), seed session state, then drive two tabs — upload/preview
    and natural-language analysis.
    """
    st.set_page_config(page_title="DataWhiz - Natural Language to SQL", layout="wide")
    setup_ui_header()
    client = setup_openai_client()
    if client is None:
        # No API key yet; setup_openai_client already rendered the key input.
        return
    initialize_session_state()
    # Simplified tabs with only essential ones
    tab1, tab2 = st.tabs(["πŸ“‚ Upload & Preview Data", "πŸ’¬ Analyze Your Data"])
    with tab1:  # Home
        uploaded_file = st.file_uploader("πŸ“‚ Upload a CSV file", type=["csv"])
        if not uploaded_file:
            st.info("πŸ‘† Upload your CSV file to begin analyzing your data. We'll help you explore and understand your data through natural language questions.")
            st.stop()
        # Store the dataframe in session state.
        # Reload only when a file with a different name is uploaded.
        if "df" not in st.session_state or st.session_state.uploaded_file_name != uploaded_file.name:
            df = load_and_process_data(uploaded_file)
            st.session_state.df = df
            st.session_state.uploaded_file_name = uploaded_file.name
            st.session_state.sample_qs = []  # Reset sample questions
            display_data_preview(df)
            # Switch to Ask a Question tab
            st.session_state.current_tab = 1
            st.rerun()
        else:
            df = st.session_state.df
            display_data_preview(df)
    with tab2:  # Ask a Question
        if "df" not in st.session_state:
            st.warning("⚠️ Please upload your data file first to start analyzing.")
            st.stop()
        df = st.session_state.df
        st.markdown("## πŸ’¬ Ask Questions About Your Data")
        st.info("πŸ’‘ Try these sample questions or type your own. For example: 'What is the average price?', 'Show me the top 10 sales by region', or 'What's the trend in daily revenue?'")
        display_sample_questions(df, client)
        user_question = get_user_question()
        if user_question:
            process_user_question(user_question, df, client)
        # Re-display the previous result (if any) across Streamlit reruns.
        # NOTE(review): nothing in this file visibly writes last_result_df — confirm.
        if "last_result_df" in st.session_state and st.session_state.last_result_df is not None:
            df_result = st.session_state.last_result_df
            display_query_results(df_result)
def setup_ui_header():
    """Render the page title and the introductory HTML banner."""
    banner_html = """
<div style="background-color:#f0f2f6;padding:15px;border-radius:10px;margin-bottom:20px">
<h4 style="margin-top:0;color:#262730;">Transform Your Data Questions into SQL Queries</h4>
<ol style="color:#262730;">
<li>πŸ“‚ Upload your CSV file containing the data you want to analyze</li>
<li>πŸ’¬ Ask questions in plain English about your data</li>
<li>✨ Get instant SQL queries and visualized results!</li>
</ol>
<p style="color:#262730;margin-top:10px;">No SQL knowledge required - just ask questions naturally!</p>
</div>
"""
    st.title("πŸ” DataWhiz: Natural Language to SQL Converter")
    # unsafe_allow_html is required for the raw <div> banner to render.
    st.markdown(banner_html, unsafe_allow_html=True)
def setup_openai_client() -> Optional[OpenAI]:
    """Build an OpenAI client from OPENAI_API_KEY or a user-entered key.

    Returns None when no key is available yet, or when client construction
    fails (the error is logged and shown in the UI).
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        # Fall back to a password-style input; the key lives only in this session.
        entered_key = st.text_input("πŸ” Enter your OpenAI API Key", type="password",
                                    help="Your API key is not stored and only used for this session")
        if not entered_key:
            return None
        api_key = entered_key
    try:
        client = OpenAI(api_key=api_key)
    except Exception as exc:
        logger.error(f"Failed to initialize OpenAI client: {str(exc)}")
        st.error(f"❌ Failed to initialize OpenAI client: {str(exc)}")
        return None
    logger.info("OpenAI client initialized successfully")
    return client
def initialize_session_state():
    """Ensure every session-state key the app relies on exists.

    Existing values are never overwritten, so this is safe to call on
    every rerun.
    """
    defaults = {
        "selected_question": "",
        "input_key": str(uuid.uuid4()),  # key for the question text_input widget
        "sample_qs": [],
        "current_tab": 0,
        "uploaded_file_name": None,
        "query_history": [],
    }
    for state_key, default_value in defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = default_value
def normalize_string_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Lowercase and strip whitespace in every object-dtype column.

    Mutates `df` in place (values pass through str(), so NaN becomes "nan")
    and returns it for chaining.
    """
    for column in df.select_dtypes(include="object").columns:
        df[column] = df[column].astype(str).str.strip().str.lower()
    return df
def load_and_process_data(uploaded_file) -> pd.DataFrame:
    """Read the uploaded CSV, normalize text columns, coerce date-like columns.

    Any column whose name contains "date" or "time" is converted to datetime
    (unparseable values become NaT). On a read failure the error is shown and
    the Streamlit script is stopped.
    """
    try:
        logger.info(f"Loading CSV file: {uploaded_file.name}")
        df = normalize_string_columns(pd.read_csv(uploaded_file))
        # Heuristic: column names mentioning date/time get datetime coercion.
        date_like = [c for c in df.columns if "date" in c.lower() or "time" in c.lower()]
        for col in date_like:
            try:
                df[col] = pd.to_datetime(df[col], errors='coerce')
                logger.info(f"Converted column to datetime: {col}")
            except Exception as e:
                logger.warning(f"Failed to convert column {col} to datetime: {str(e)}")
        return df
    except Exception as e:
        logger.error(f"Error loading CSV: {str(e)}", exc_info=True)
        st.error(f"❌ Error loading CSV file: {str(e)}")
        st.stop()
def display_data_preview(df: pd.DataFrame):
    """Show a head() preview plus per-column dtype/null statistics."""
    col1, col2 = st.columns([2, 1])
    with col1:
        with st.expander("πŸ“Š Dataset Preview", expanded=True):
            st.dataframe(df.head(), use_container_width=True)
            st.caption("Preview of your data (first 5 rows)")
    with col2:
        with st.expander("πŸ“ Data Structure", expanded=False):
            total_rows = len(df)
            # One markdown bullet per column: dtype, non-null count, % missing.
            schema_lines = []
            for col in df.columns:
                non_null = df[col].count()
                null_pct = round((total_rows - non_null) / total_rows * 100, 1) if total_rows > 0 else 0
                schema_lines.append(f"- **{col}**: {str(df[col].dtype)} ({non_null} non-null, {null_pct}% missing)")
            st.markdown("\n".join(schema_lines))
            # Basic statistics
            st.divider()
            st.caption(f"**Total Records**: {total_rows:,} | **Total Fields**: {len(df.columns)}")
def get_column_metadata(df: pd.DataFrame, sample_size: int = 5) -> str:
    """Describe each column as '- **name** (dtype): e.g., v1, v2, ...'.

    Example values are the first `sample_size` unique non-null values,
    rendered as strings.
    """
    lines = [
        f"- **{col}** ({df[col].dtype}): e.g., "
        + ", ".join(df[col].dropna().astype(str).unique()[:sample_size])
        for col in df.columns
    ]
    return "\n".join(lines)
def generate_llm_queries(df: pd.DataFrame, client: OpenAI, num_questions: int = 10) -> List[str]:
    """Generate and validate sample analytical questions using OpenAI.

    Asks GPT-4 for `num_questions` candidate questions, then validates each
    by generating SQL and executing it against `df`; up to 5 validated
    questions are returned, padded with generic fallbacks if needed. On any
    top-level failure the generic fallbacks are returned.

    Bug fix: the original opened a brand-new in-memory DuckDB connection for
    every candidate question and never closed any of them; a single
    connection is now reused and closed in a finally block.
    """
    fallback_questions = [
        "What is the total count of records?",
        "What are the unique values in the first column?",
        "What is the average value of numeric columns?",
        "Show the top 10 records",
        "What is the distribution of values in the first column?"
    ]
    try:
        # Get column metadata and sample data
        metadata = get_column_metadata(df)
        sample_data = df.head(3).to_string()
        # Create a more focused prompt for business analytics questions
        prompt = f"""
You are a business analyst assistant. Based on the following data schema and sample data, generate {num_questions} meaningful business analytics questions.
Data Schema:
{metadata}
Sample Data:
{sample_data}
Generate questions that:
1. Focus on business metrics and KPIs
2. Include comparisons, trends, and aggregations
3. Are relevant for business reporting and decision making
4. Can be answered with the available data
5. Are clear and specific
Format each question as a numbered list.
"""
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7
        )
        content = response.choices[0].message.content.strip()
        # Extract questions from the numbered list ("1. ..." -> "...").
        questions = [re.sub(r'^\d+\.\s*', '', q).strip() for q in content.split('\n') if q.strip()]
        # Validate candidates by actually running their generated SQL.
        valid_questions = []
        con = duckdb.connect(database=':memory:')
        try:
            con.register("data_table", df)
            for question in questions:
                try:
                    sql = generate_sql(question, df.head(5), client)
                    con.execute(sql).fetchdf()
                    valid_questions.append(question)
                    if len(valid_questions) >= 5:
                        break
                except Exception as e:
                    logger.warning(f"Question validation failed: {question}. Error: {str(e)}")
        finally:
            con.close()
        # Pad with generic questions when validation produced fewer than 5.
        if len(valid_questions) < 5:
            valid_questions.extend(fallback_questions[:5 - len(valid_questions)])
        logger.info(f"Generated {len(valid_questions)} valid sample questions")
        return valid_questions
    except Exception as e:
        logger.error(f"Error generating sample questions: {str(e)}", exc_info=True)
        st.warning(f"⚠️ Could not generate sample questions: {str(e)}")
        return fallback_questions
def validate_sample_question(question: str, df: pd.DataFrame, client: OpenAI) -> Tuple[bool, str]:
    """
    Validate a sample question by generating and executing its SQL.

    Returns (success, error_message); error_message is "" on success.

    Bug fix: the DuckDB connection is now closed in a finally block — the
    original returned from inside the try without ever closing it.
    """
    try:
        # Generate SQL for the question (only a 5-row sample is needed for schema).
        sql = generate_sql(question, df.head(5), client)
        con = duckdb.connect(database=':memory:')
        try:
            con.register("data_table", df)
            con.execute(sql).fetchdf()
            return True, ""
        except Exception as e:
            return False, str(e)
        finally:
            con.close()
    except Exception as e:
        # SQL generation itself failed.
        return False, str(e)
def display_sample_questions(df: pd.DataFrame, client: OpenAI):
    """Render cached sample questions as a two-column grid of buttons.

    Questions are generated once per upload (cached in session state).
    Clicking a button stores the question, rotates the text-input widget key
    so the input picks up the new value, and reruns the script.
    """
    st.markdown("### πŸ’‘ Try Sample Questions")
    if not st.session_state.sample_qs:
        with st.spinner("Generating and validating sample questions..."):
            st.session_state.sample_qs = generate_llm_queries(df, client, num_questions=10)
    left_col, right_col = st.columns(2)
    for idx, question in enumerate(st.session_state.sample_qs):
        target = left_col if idx % 2 == 0 else right_col
        clicked = target.button(
            question,
            key=f"sample_{idx}",
            use_container_width=True,
            help="Click to use this question",
        )
        if clicked:
            st.session_state.selected_question = question
            st.session_state.input_key = str(uuid.uuid4())
            logger.info(f"Sample question selected: {question}")
            st.rerun()
def get_user_question() -> Optional[str]:
    """Render the free-text question input and return its value.

    Returns:
        The non-empty question string, or None when the input is empty.

    Fix: the original was annotated `-> str` but fell off the end (implicitly
    returning None) for empty input; the annotation and the empty path are
    now explicit.
    """
    st.markdown("### πŸ€– Ask Your Own Question")
    user_question = st.text_input(
        "Enter your question:",
        value=st.session_state.selected_question,
        key=st.session_state.input_key,
        placeholder="e.g., What is the average price?",
    )
    if not user_question:
        return None
    logger.info(f"User question received: {user_question}")
    return user_question
def sanitize_sql(query: str) -> str:
    """Sanitize and format an LLM-generated SQL query for DuckDB.

    Steps, in order:
      1. Strip markdown code fences.
      2. Replace unicode comparison operators / smart quotes with ASCII and
         rewrite SQLite/MySQL-style date expressions into DuckDB intervals;
         also retarget 'table_name' to the registered 'data_table'.
      3. Rewrite DATE(expr) into CAST(expr AS DATE).

    Fixes vs. the original:
      - The smart-quote keys are written as \\uXXXX escapes (the literal
        characters had been mangled by an encoding round-trip).
      - The DATE()->CAST regex now runs AFTER the literal replacements;
        previously it ran first, turning DATE('now', '-7 days') into
        CAST(...) so the interval rewrites could never match.
    """
    # Remove code block markers
    query = re.sub(r"```sql|```", "", query, flags=re.IGNORECASE).strip()
    replacements = {
        "\u2265": ">=", "\u2264": "<=", "\u2260": "!=",                 # β‰₯ ≀ β‰ 
        "\u201c": '"', "\u201d": '"', "\u2018": "'", "\u2019": "'",      # curly quotes
        "DATE('now', '-30 days')": "CURRENT_DATE - INTERVAL '30 days'",
        "DATE('now', '-7 days')": "CURRENT_DATE - INTERVAL '7 days'",
        "DATE('now', '-1 day')": "CURRENT_DATE - INTERVAL '1 day'",
        "DATE_SUB(CURRENT_DATE, INTERVAL 7 DAY)": "CURRENT_DATE - INTERVAL '7 days'",
        "CURRENT_DATE - INTERVAL 30 DAY": "CURRENT_DATE - INTERVAL '30 days'",
        "CURRENT_DATE - INTERVAL 7 DAY": "CURRENT_DATE - INTERVAL '7 days'",
        "CURRENT_DATE - INTERVAL 1 DAY": "CURRENT_DATE - INTERVAL '1 day'",
        "table_name": "data_table",
    }
    for old, new in replacements.items():
        query = query.replace(old, new)
    # Replace remaining DATE(column) calls with CAST(column AS DATE).
    query = re.sub(r"\bDATE\((.*?)\)", r"CAST(\1 AS DATE)", query, flags=re.IGNORECASE)
    return query
def generate_sql(nl_query: str, df_sample: pd.DataFrame, client: OpenAI) -> str:
    """Generate a DuckDB SQL query from a natural-language question.

    Builds a prompt from `df_sample`'s dtypes and first 3 rows (the table is
    always referred to as 'data_table'), asks gpt-3.5-turbo at low
    temperature, then sanitizes the returned SQL.

    Raises:
        Exception: wrapping any failure of prompt construction or the API call.
    """
    try:
        logger.info(f"Generating SQL for question: {nl_query}")
        # Create schema description
        schema_desc = df_sample.dtypes.to_string()
        sample_data = df_sample.head(3).to_string()
        # Enhanced prompt with guidance on proper SQL syntax and common patterns
        prompt = f"""
You are a data analysis assistant trained to transform analytical questions into well-structured SQL queries using DuckDB syntax.
Below is the schema of a table named 'data_table':
{schema_desc}
Sample data (first 3 rows):
{sample_data}
Translate the following question into a valid SQL query:
Question: {nl_query}
IMPORTANT: Generate ONLY DuckDB-compliant SQL code. Do not generate explanations or alternatives unless there's ambiguity in the question.
CORE SQL RULES:
1. Always include columns in GROUP BY if they appear in SELECT and aren't inside aggregate functions
2. When calculating overall averages or statistics, don't include individual IDs in the outer SELECT
3. Make sure subqueries have appropriate aliases
4. For aggregations across groups, use GROUP BY. For overall aggregations, don't select individual rows
5. When using window functions, ensure PARTITION BY and ORDER BY clauses are correct
6. Use proper column references (table_name.column_name or alias.column_name) when joining tables or using subqueries
DUCKDB-SPECIFIC SYNTAX:
7. DATE/TIME OPERATIONS:
- Subtract from dates: Use DATE_SUB(date, INTERVAL n unit) or date - INTERVAL 'n unit'
- Add to dates: Use DATE_ADD(date, INTERVAL n unit) or date + INTERVAL 'n unit'
- Valid interval units: 'day', 'month', 'year', 'hour', 'minute', 'second'
- Extract components: Use EXTRACT(unit FROM date)
- Format dates: Use strftime(timestamp, format)
- DO NOT use: DATEADD(), DATEDIFF(), GETDATE(), DATEPART() or other SQL Server/MySQL specific functions
8. STRING OPERATIONS:
- Concatenation: Use || operator (NOT +)
- Case insensitive comparison: Use LOWER(col) = LOWER(value) or ILIKE
- Substring: Use SUBSTRING(string, start, length)
- Pattern matching: Use string LIKE pattern (% for wildcards)
- Regular expressions: Use REGEXP_MATCHES(string, pattern)
9. AGGREGATION FUNCTIONS:
- Use NULLIF() to prevent division by zero: SUM(value) / NULLIF(COUNT(*), 0)
- For percentages, multiply by 100.0 (not 100) to ensure float division
- For conditional counting: COUNT(CASE WHEN condition THEN 1 END)
- For conditional sums: SUM(CASE WHEN condition THEN value ELSE 0 END)
10. COMMON EXCEPTIONS TO AVOID:
- Don't use non-standard SQL functions without checking DuckDB compatibility
- Don't use database-specific functions from SQL Server, MySQL, PostgreSQL, etc.
- Don't mix aggregate and non-aggregate columns in SELECT without proper GROUP BY
- Don't use column aliases in WHERE clauses (use the full expression or a subquery)
- Don't use ORDER BY positional references (ORDER BY 1, 2) - use explicit column names
- Don't use TOP; use LIMIT instead
11. PERFORMANCE CONSIDERATIONS:
- Use WITH clauses (CTEs) for complex subqueries to improve readability
- Filter data as early as possible in the query
- Use EXISTS or IN for membership tests, not joins, when appropriate
- Avoid unnecessary DISTINCT operations when possible
12. TEMPORAL DATA PATTERNS:
- Recent time period: timestamp >= CURRENT_DATE - INTERVAL 'n days/months/years'
- Current month: EXTRACT(MONTH FROM timestamp) = EXTRACT(MONTH FROM CURRENT_DATE) AND EXTRACT(YEAR FROM timestamp) = EXTRACT(YEAR FROM CURRENT_DATE)
- Month-to-date: timestamp >= DATE_TRUNC('month', CURRENT_DATE)
- Year-to-date: timestamp >= DATE_TRUNC('year', CURRENT_DATE)
- Last N complete months: timestamp >= DATE_TRUNC('month', CURRENT_DATE) - INTERVAL 'n months' AND timestamp < DATE_TRUNC('month', CURRENT_DATE)
COMPLEX PATTERN EXAMPLES:
Example 1 (for the question: "What percentage of users who add a product to their cart actually proceed to purchase it?"):
WITH cart_users AS (
SELECT DISTINCT user_id
FROM data_table
WHERE action = 'add_to_cart'
),
purchase_users AS (
SELECT DISTINCT user_id
FROM data_table
WHERE action = 'purchase'
)
SELECT
(COUNT(DISTINCT pu.user_id) * 100.0) / NULLIF(COUNT(DISTINCT cu.user_id), 0) AS conversion_rate
FROM cart_users cu
LEFT JOIN purchase_users pu ON cu.user_id = pu.user_id;
Example 2 (for date filtering): "How many purchases were made in the last 7 days?"
SELECT COUNT(*) AS recent_purchases
FROM data_table
WHERE action = 'purchase'
AND timestamp >= CURRENT_DATE - INTERVAL '7 days';
Example 3 (for sequential analysis): "What's the average time between a user viewing a product and purchasing it?"
WITH view_events AS (
SELECT user_id, product_id, timestamp AS view_time
FROM data_table
WHERE action = 'view'
),
purchase_events AS (
SELECT user_id, product_id, timestamp AS purchase_time
FROM data_table
WHERE action = 'purchase'
)
SELECT AVG(
EXTRACT(EPOCH FROM purchase_time - view_time) / 3600
) AS avg_hours_to_purchase
FROM view_events ve
JOIN purchase_events pe ON ve.user_id = pe.user_id AND ve.product_id = pe.product_id
WHERE purchase_time > view_time;
SQL:"""
        # Low temperature keeps generation close to deterministic.
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
        )
        sql = response.choices[0].message.content.strip()
        # sanitize_sql strips code fences and rewrites non-DuckDB constructs.
        sanitized_sql = sanitize_sql(sql)
        logger.info(f"Generated SQL: {sanitized_sql}")
        return sanitized_sql
    except Exception as e:
        logger.error(f"Error generating SQL: {str(e)}", exc_info=True)
        raise Exception(f"Failed to generate SQL: {str(e)}")
def validate_sql(sql: str, df: pd.DataFrame) -> Tuple[bool, str]:
    """
    Validate a SQL query without executing it, via DuckDB's EXPLAIN.

    Returns (is_valid, error_message); error_message is "" when valid.

    Bug fix: the in-memory DuckDB connection is now closed in a finally
    block — the original leaked one connection per call.
    """
    try:
        logger.info(f"Validating SQL query: {sql}")
        con = duckdb.connect(database=':memory:')
        try:
            con.register("data_table", df)
            # EXPLAIN plans the query (catching syntax/binding errors) without running it.
            con.execute(f"EXPLAIN {sql}")
            return True, ""
        except Exception as e:
            error_msg = str(e)
            logger.warning(f"SQL validation failed: {error_msg}")
            return False, error_msg
        finally:
            con.close()
    except Exception as e:
        logger.error(f"Error during SQL validation: {str(e)}", exc_info=True)
        return False, str(e)
def fix_sql_with_llm(sql: str, error_msg: str, client: OpenAI) -> str:
    """
    Ask the LLM to repair a failing SQL query.

    Returns the sanitized, corrected SQL; if the repair attempt itself
    fails, the original `sql` is returned unchanged.
    """
    prompt = f"""
You are an expert SQL developer who specializes in fixing SQL syntax errors.
Here's a SQL query that has an error:
```sql
{sql}
```
The error message is:
```
{error_msg}
```
Please fix the SQL query to resolve this error. Return ONLY the corrected SQL query without any explanations or markdown formatting.
"""
    try:
        logger.info(f"Attempting to fix SQL error: {error_msg}")
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
        )
        repaired = sanitize_sql(response.choices[0].message.content.strip())
        logger.info(f"Fixed SQL: {repaired}")
        return repaired
    except Exception as e:
        logger.error(f"Error fixing SQL: {str(e)}", exc_info=True)
        # Fall back to the original query when the fix attempt fails.
        return sql
def execute_query(sql: str, df: pd.DataFrame) -> pd.DataFrame:
    """Execute a SQL query against `df` (registered as 'data_table') in DuckDB.

    Returns the result as a DataFrame.

    Raises:
        Exception: wrapping any DuckDB execution error.

    Bug fix: the connection is now closed in a finally block — the original
    leaked it on both the success and the error path.
    """
    try:
        logger.info(f"Executing SQL query: {sql}")
        con = duckdb.connect(database=':memory:')
        try:
            con.register("data_table", df)
            result_df = con.execute(sql).fetchdf()
        finally:
            con.close()
        logger.info(f"Query executed successfully, returned {len(result_df)} rows")
        return result_df
    except Exception as e:
        logger.error(f"Error executing SQL: {str(e)}", exc_info=True)
        error_details = traceback.format_exc()
        logger.debug(f"SQL execution error details: {error_details}")
        raise Exception(f"Failed to execute SQL: {str(e)}")
def auto_visualize(df: pd.DataFrame):
    """Automatically choose and create an appropriate visualization.

    Dispatches on the available column types, in priority order:
      1. date + numeric      -> time series line chart
      2. categorical+numeric -> top-10 bar chart
      3. two+ numeric        -> scatter + correlation heatmap
      4. one numeric         -> histogram with box marginal
      5. categorical only    -> top-10 frequency bar chart
    """
    try:
        numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
        categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
        date_cols = detect_date_columns(df)
        if len(df) == 0:
            st.info("No data to visualize.")
            return
        if date_cols and numeric_cols:
            _plot_time_series(df, date_cols[0], numeric_cols[0])
        elif categorical_cols and numeric_cols:
            _plot_category_metric(df, categorical_cols[0], numeric_cols[0])
        elif len(numeric_cols) >= 2:
            _plot_correlation(df, numeric_cols)
        elif numeric_cols:
            _plot_distribution(df, numeric_cols[0])
        elif categorical_cols:
            _plot_frequency(df, categorical_cols[0])
        else:
            st.info("Could not determine appropriate visualization for this data.")
    except Exception as e:
        st.error(f"Auto-visualization error: {str(e)}")

def _plot_time_series(df: pd.DataFrame, date_col: str, metric_col: str):
    """Case 1: line chart of the first numeric column over the first date column."""
    st.subheader("πŸ“ˆ Time Series Analysis")
    st.caption(f"Showing {metric_col} over time")
    if not pd.api.types.is_datetime64_any_dtype(df[date_col]):
        # Fix: dropped infer_datetime_format=True — deprecated in pandas 2.x
        # (strict format inference is now the default behavior).
        df[date_col] = pd.to_datetime(df[date_col], errors='coerce')
    df_sorted = df.sort_values(by=date_col)
    fig = px.line(df_sorted, x=date_col, y=metric_col,
                  title=f"{metric_col} Trend Over Time",
                  template="plotly_white")
    # Overlay the OLS trendline trace (index 1 of the scatter figure's traces).
    fig.add_trace(px.scatter(df_sorted, x=date_col, y=metric_col,
                             trendline="ols",
                             trendline_color_override="red").data[1])
    fig.update_layout(
        xaxis_title="Date",
        yaxis_title=metric_col,
        hovermode="x unified",
        showlegend=True
    )
    st.plotly_chart(fig, use_container_width=True)
    with st.expander("πŸ“Š Summary Statistics", expanded=False):
        st.write(df_sorted[metric_col].describe())

def _plot_category_metric(df: pd.DataFrame, cat_col: str, metric_col: str):
    """Case 2: bar chart of a numeric metric across the top-10 categories."""
    st.subheader("πŸ“Š Category Analysis")
    st.caption(f"Showing {metric_col} by {cat_col}")
    # Limit to the 10 categories with the largest metric sum for readability.
    top_cats = df.groupby(cat_col)[metric_col].sum().nlargest(10).index
    df_filtered = df[df[cat_col].isin(top_cats)]
    fig = px.bar(df_filtered,
                 x=cat_col,
                 y=metric_col,
                 title=f"Top 10 {cat_col} by {metric_col}",
                 template="plotly_white",
                 color=metric_col,
                 color_continuous_scale="Viridis")
    fig.update_layout(
        xaxis_title=cat_col,
        yaxis_title=metric_col,
        hovermode="x unified",
        showlegend=False
    )
    st.plotly_chart(fig, use_container_width=True)
    with st.expander("πŸ“Š Summary Statistics", expanded=False):
        stats = df_filtered.groupby(cat_col)[metric_col].agg(['sum', 'mean', 'count']).round(2)
        st.write(stats)

def _plot_correlation(df: pd.DataFrame, numeric_cols: list):
    """Case 3: scatter of the first two numeric columns plus a correlation heatmap."""
    x_col, y_col = numeric_cols[0], numeric_cols[1]
    st.subheader("πŸ” Correlation Analysis")
    st.caption(f"Exploring relationship between {x_col} and {y_col}")
    fig = px.scatter(df,
                     x=x_col,
                     y=y_col,
                     title=f"{y_col} vs {x_col}",
                     template="plotly_white",
                     trendline="ols",
                     trendline_color_override="red")
    fig.update_layout(
        xaxis_title=x_col,
        yaxis_title=y_col,
        hovermode="closest"
    )
    st.plotly_chart(fig, use_container_width=True)
    with st.expander("πŸ“Š Correlation Matrix", expanded=False):
        corr = df[numeric_cols].corr()
        fig = px.imshow(corr,
                        text_auto=True,
                        title="Correlation Heatmap",
                        color_continuous_scale='RdBu_r',
                        template="plotly_white")
        st.plotly_chart(fig, use_container_width=True)

def _plot_distribution(df: pd.DataFrame, num_col: str):
    """Case 4: histogram (with box marginal) of a single numeric column."""
    st.subheader("πŸ“‰ Distribution Analysis")
    st.caption(f"Distribution of {num_col}")
    fig = px.histogram(df,
                       x=num_col,
                       title=f"Distribution of {num_col}",
                       template="plotly_white",
                       marginal="box",
                       nbins=30)
    fig.update_layout(
        xaxis_title=num_col,
        yaxis_title="Count",
        hovermode="x unified"
    )
    st.plotly_chart(fig, use_container_width=True)
    with st.expander("πŸ“Š Summary Statistics", expanded=False):
        st.write(df[num_col].describe())

def _plot_frequency(df: pd.DataFrame, cat_col: str):
    """Case 5: frequency bar chart for a categorical column (top 10)."""
    st.subheader("πŸ”’ Category Distribution")
    st.caption(f"Frequency of {cat_col}")
    value_counts = df[cat_col].value_counts().reset_index()
    value_counts.columns = [cat_col, 'Count']
    if len(value_counts) > 10:
        value_counts = value_counts.head(10)
    fig = px.bar(value_counts,
                 x=cat_col,
                 y='Count',
                 title=f"Top 10 {cat_col} Distribution",
                 template="plotly_white",
                 color='Count',
                 color_continuous_scale="Viridis")
    fig.update_layout(
        xaxis_title=cat_col,
        yaxis_title="Count",
        hovermode="x unified",
        showlegend=False
    )
    st.plotly_chart(fig, use_container_width=True)
    with st.expander("πŸ“Š Value Counts", expanded=False):
        st.write(value_counts)
def detect_date_columns(df: pd.DataFrame) -> List[str]:
    """Detect columns that contain date information.

    Datetime-dtype columns are included directly. Object columns are
    included when more than 70% of a small sample (first 10 non-null
    values) parses as a date.

    Fixes vs. the original: `pd.isna()` replaces the fragile
    `is not pd.NaT` identity check, the bare `except:` is narrowed to
    `except Exception`, and the return annotation is `List[str]`.
    """
    # Columns that are already datetime dtype.
    date_cols = df.select_dtypes(include=['datetime']).columns.tolist()
    # Check whether string columns might contain dates.
    for col in df.select_dtypes(include=['object']).columns:
        if col in date_cols:
            continue
        sample = df[col].dropna().head(10)  # small sample to minimize parse warnings
        if len(sample) == 0:
            continue
        parsed_count = 0
        for val in sample:
            try:
                if not pd.isna(pd.to_datetime(val, errors='coerce')):
                    parsed_count += 1
            except Exception:
                pass
        # >70% of sampled values parse as dates -> treat as a date column.
        if parsed_count / len(sample) > 0.7:
            date_cols.append(col)
    return date_cols
def process_user_question(user_question: str, df: pd.DataFrame, client: OpenAI):
    """Generate SQL for the question, validate/repair it, execute, and render results.

    Pipeline: generate_sql -> validate_sql (EXPLAIN) -> optional LLM repair
    with a user choice between original/fixed SQL -> execute_query ->
    display_query_results. Successes and failures are appended to
    st.session_state.query_history; on execution failure a debug expander
    offers an LLM-simplified version of the query.
    """
    try:
        # Generate SQL from the question (5-row sample is enough for schema context).
        with st.spinner("Generating SQL query..."):
            sql = generate_sql(user_question, df.head(5), client)
        # Display the generated SQL
        st.markdown("#### 🧾 Generated SQL")
        # Create SQL display with copy button
        # NOTE(review): the button only shows a toast; it does not actually
        # copy to the clipboard — confirm intended behavior.
        col1, col2 = st.columns([9, 1])
        col1.code(sql, language="sql")
        if col2.button("πŸ“‹", help="Copy SQL to clipboard"):
            st.toast("SQL copied to clipboard!")
        # Validate the SQL before execution
        is_valid, error_msg = validate_sql(sql, df)
        if not is_valid:
            st.warning(f"⚠️ The generated SQL may have issues: {error_msg}")
            # Attempt to fix the SQL
            with st.spinner("Attempting to fix SQL..."):
                fixed_sql = fix_sql_with_llm(sql, error_msg, client)
            # Only proceed with fixed SQL if it's different from the original
            if fixed_sql != sql:
                st.markdown("#### πŸ› οΈ Fixed SQL")
                col1, col2 = st.columns([9, 1])
                col1.code(fixed_sql, language="sql")
                if col2.button("πŸ“‹", help="Copy fixed SQL to clipboard", key="copy_fixed"):
                    st.toast("Fixed SQL copied to clipboard!")
                # Ask user which version to use (defaults to the fixed SQL).
                use_fixed = st.radio("Which SQL would you like to execute?",
                                     ["Use original SQL", "Use fixed SQL"], index=1)
                if use_fixed == "Use fixed SQL":
                    sql = fixed_sql
        # Execute the query
        with st.spinner("Executing query..."):
            try:
                result_df = execute_query(sql, df)
                # Store in history
                st.session_state.query_history.append({
                    "question": user_question,
                    "sql": sql,
                    "succeeded": True,
                    "timestamp": pd.Timestamp.now().strftime("%H:%M:%S")
                })
                # Show results
                st.success("βœ… Query executed successfully!")
                # Display results
                display_query_results(result_df)
            except Exception as e:
                # Add to history as failed
                st.session_state.query_history.append({
                    "question": user_question,
                    "sql": sql,
                    "succeeded": False,
                    "error": str(e),
                    "timestamp": pd.Timestamp.now().strftime("%H:%M:%S")
                })
                st.error(f"❌ Error executing SQL: {str(e)}")
                # Show debug help
                with st.expander("πŸ” Debug Information", expanded=True):
                    st.write("The error might be due to:")
                    common_issues = [
                        "Missing GROUP BY for columns in SELECT",
                        "Incorrect column names or table references",
                        "Syntax errors in functions or operators",
                        "Type mismatches in comparisons or joins"
                    ]
                    for issue in common_issues:
                        st.markdown(f"- {issue}")
                    # Offer to try a simpler version
                    st.markdown("#### Try a simpler version")
                    st.write("Let's simplify the query to help diagnose the issue:")
                    # Generate a simplified version of the failing query via the LLM.
                    with st.spinner("Generating simplified query..."):
                        simplify_prompt = f"""
The following SQL query is failing with this error: {str(e)}
SQL:
```
{sql}
```
Please create a much simpler version of this query that focuses only on the core intent.
Remove complex joins, subqueries, and functions that might be causing issues.
Return ONLY the simplified SQL without any explanations.
"""
                        simple_response = client.chat.completions.create(
                            model="gpt-3.5-turbo",
                            messages=[{"role": "user", "content": simplify_prompt}],
                            temperature=0.2,
                        )
                        simple_sql = simple_response.choices[0].message.content.strip()
                        simple_sql = sanitize_sql(simple_sql)
                        st.code(simple_sql, language="sql")
                    if st.button("Try Simplified Query"):
                        try:
                            simple_result = execute_query(simple_sql, df)
                            st.success("βœ… Simplified query succeeded!")
                            st.dataframe(simple_result, use_container_width=True)
                        except Exception as simple_error:
                            st.error(f"Even simplified query failed: {str(simple_error)}")
    except Exception as e:
        # SQL generation itself failed.
        st.error(f"❌ Error generating SQL: {str(e)}")
def display_query_results(result_df: pd.DataFrame):
    """Display query results as a table or an auto-chosen chart.

    Chart view requires a non-empty result with at least 2 columns;
    otherwise the toggle is disabled and the table view is forced.
    """
    # Show count of results
    st.caption(f"πŸ“Š Found {len(result_df):,} matching records")
    # Initialize view mode in session state if not exists
    if "view_mode" not in st.session_state:
        st.session_state.view_mode = "Table"
    # Determine if chart is possible
    can_show_chart = not result_df.empty and len(result_df.columns) >= 2
    # If chart is not possible, force table view
    # NOTE(review): writing to st.session_state.view_mode here while the radio
    # below uses key="view_mode" can raise in some Streamlit versions once the
    # widget has been instantiated — confirm against the pinned version.
    if not can_show_chart and st.session_state.view_mode == "Chart":
        st.session_state.view_mode = "Table"
    # Add view mode toggle (the radio writes its value back via key="view_mode")
    col1, col2 = st.columns([1, 4])
    with col1:
        view_mode = st.radio(
            "Display Results As",
            ["Table", "Chart"],
            horizontal=True,
            key="view_mode",
            disabled=not can_show_chart,
            help="Chart view is disabled because the query result has only one column. At least 2 columns are required for visualization." if not can_show_chart else None
        )
    # Display based on view mode
    if st.session_state.view_mode == "Chart" and can_show_chart:
        auto_visualize(result_df)
    else:
        st.dataframe(result_df, use_container_width=True, height=300)
# Run the main application only when executed as a script (not on import).
if __name__ == "__main__":
    main()