import os
import re
import io  # For BytesIO
import uuid
import logging
import traceback

import duckdb
import pandas as pd
import plotly.express as px
import streamlit as st
from openai import OpenAI
from typing import List, Dict, Any, Optional, Tuple

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("text-to-sql-app")


def main():
    st.set_page_config(page_title="DataWhiz - Natural Language to SQL", layout="wide")
    setup_ui_header()

    client = setup_openai_client()
    if client is None:
        return

    initialize_session_state()

    # Simplified tabs with only the essential ones
    tab1, tab2 = st.tabs(["📂 Upload & Preview Data", "💬 Analyze Your Data"])

    with tab1:  # Upload & Preview Data
        uploaded_file = st.file_uploader("📂 Upload a CSV file", type=["csv"])
        if not uploaded_file:
            st.info("👆 Upload your CSV file to begin analyzing your data. "
                    "We'll help you explore and understand your data through natural language questions.")
            st.stop()

        # Store the dataframe in session state, reloading when a new file is uploaded
        if "df" not in st.session_state or st.session_state.uploaded_file_name != uploaded_file.name:
            df = load_and_process_data(uploaded_file)
            st.session_state.df = df
            st.session_state.uploaded_file_name = uploaded_file.name
            st.session_state.sample_qs = []  # Reset sample questions
            display_data_preview(df)
            # Switch to the analysis tab
            st.session_state.current_tab = 1
            st.rerun()
        else:
            df = st.session_state.df
            display_data_preview(df)

    with tab2:  # Analyze Your Data
        if "df" not in st.session_state:
            st.warning("⚠️ Please upload your data file first to start analyzing.")
            st.stop()

        df = st.session_state.df
        st.markdown("## 💬 Ask Questions About Your Data")
        st.info("💡 Try these sample questions or type your own. For example: "
                "'What is the average price?', 'Show me the top 10 sales by region', "
                "or 'What's the trend in daily revenue?'")

        display_sample_questions(df, client)

        user_question = get_user_question()
        if user_question:
            process_user_question(user_question, df, client)

        if "last_result_df" in st.session_state and st.session_state.last_result_df is not None:
            df_result = st.session_state.last_result_df
            display_query_results(df_result)


def setup_ui_header():
    """Set up the app header and description."""
    st.title("🔍 DataWhiz: Natural Language to SQL Converter")
    st.markdown("""

Transform Your Data Questions into SQL Queries

  1. 📂 Upload your CSV file containing the data you want to analyze
  2. 💬 Ask questions in plain English about your data
  3. ✨ Get instant SQL queries and visualized results!

No SQL knowledge required - just ask questions naturally!

""", unsafe_allow_html=True) def setup_openai_client() -> Optional[OpenAI]: """Setup OpenAI client with API key.""" api_key = os.getenv("OPENAI_API_KEY") user_api_key = "" if not api_key: user_api_key = st.text_input("๐Ÿ” Enter your OpenAI API Key", type="password", help="Your API key is not stored and only used for this session") if not user_api_key: return None api_key = user_api_key try: client = OpenAI(api_key=api_key) logger.info("OpenAI client initialized successfully") return client except Exception as e: logger.error(f"Failed to initialize OpenAI client: {str(e)}") st.error(f"โŒ Failed to initialize OpenAI client: {str(e)}") return None def initialize_session_state(): """Initialize session state variables.""" if "selected_question" not in st.session_state: st.session_state.selected_question = "" if "input_key" not in st.session_state: st.session_state.input_key = str(uuid.uuid4()) if "sample_qs" not in st.session_state: st.session_state.sample_qs = [] if "current_tab" not in st.session_state: st.session_state.current_tab = 0 if "uploaded_file_name" not in st.session_state: st.session_state.uploaded_file_name = None if "query_history" not in st.session_state: st.session_state.query_history = [] def normalize_string_columns(df: pd.DataFrame) -> pd.DataFrame: """Normalize string columns by converting to lowercase and stripping whitespace.""" string_cols = df.select_dtypes(include="object").columns for col in string_cols: df[col] = df[col].astype(str).str.lower().str.strip() return df def load_and_process_data(uploaded_file) -> pd.DataFrame: """Load CSV data and perform initial processing.""" try: logger.info(f"Loading CSV file: {uploaded_file.name}") df = pd.read_csv(uploaded_file) # Normalize string columns df = normalize_string_columns(df) # Convert potential datetime columns for col in df.columns: if "date" in col.lower() or "time" in col.lower(): try: df[col] = pd.to_datetime(df[col], errors='coerce') logger.info(f"Converted column to datetime: {col}") except Exception as e: logger.warning(f"Failed to convert column {col} to datetime: {str(e)}") return df except Exception as e: logger.error(f"Error loading CSV: {str(e)}", exc_info=True) st.error(f"โŒ Error loading CSV file: {str(e)}") st.stop() def display_data_preview(df: pd.DataFrame): """Display dataset preview and schema information.""" col1, col2 = st.columns([2, 1]) with col1: with st.expander("๐Ÿ“Š Dataset Preview", expanded=True): st.dataframe(df.head(), use_container_width=True) st.caption("Preview of your data (first 5 rows)") with col2: with st.expander("๐Ÿ“ Data Structure", expanded=False): # Enhanced schema display schema_info = [] for col in df.columns: col_type = str(df[col].dtype) non_null = df[col].count() null_pct = round((len(df) - non_null) / len(df) * 100, 1) if len(df) > 0 else 0 schema_info.append(f"- **{col}**: {col_type} ({non_null} non-null, {null_pct}% missing)") st.markdown("\n".join(schema_info)) # Basic statistics st.divider() st.caption(f"**Total Records**: {len(df):,} | **Total Fields**: {len(df.columns)}") def get_column_metadata(df: pd.DataFrame, sample_size: int = 5) -> str: """Get column metadata including types and examples.""" metadata = [] for col in df.columns: col_type = str(df[col].dtype) samples = df[col].dropna().astype(str).unique()[:sample_size] sample_str = ", ".join(samples) metadata.append(f"- **{col}** ({col_type}): e.g., {sample_str}") return "\n".join(metadata) def generate_llm_queries(df: pd.DataFrame, client: OpenAI, num_questions: int = 10) -> List[str]: """Generate sample 
analytical questions using OpenAI.""" try: # Get column metadata and sample data metadata = get_column_metadata(df) sample_data = df.head(3).to_string() # Create a more focused prompt for business analytics questions prompt = f""" You are a business analyst assistant. Based on the following data schema and sample data, generate {num_questions} meaningful business analytics questions. Data Schema: {metadata} Sample Data: {sample_data} Generate questions that: 1. Focus on business metrics and KPIs 2. Include comparisons, trends, and aggregations 3. Are relevant for business reporting and decision making 4. Can be answered with the available data 5. Are clear and specific Format each question as a numbered list. """ response = client.chat.completions.create( model="gpt-4", messages=[{"role": "user", "content": prompt}], temperature=0.7 ) content = response.choices[0].message.content.strip() # Extract questions from numbered list questions = [re.sub(r'^\d+\.\s*', '', q).strip() for q in content.split('\n') if q.strip()] # Validate and filter questions valid_questions = [] for question in questions: try: # Generate SQL for the question sql = generate_sql(question, df.head(5), client) # Try to execute the SQL con = duckdb.connect(database=':memory:') con.register("data_table", df) con.execute(sql).fetchdf() # If execution succeeds, add to valid questions valid_questions.append(question) # If we have enough valid questions, stop if len(valid_questions) >= 5: break except Exception as e: logger.warning(f"Question validation failed: {question}. Error: {str(e)}") continue # If we don't have enough valid questions, add some generic ones if len(valid_questions) < 5: generic_questions = [ "What is the total count of records?", "What are the unique values in the first column?", "What is the average value of numeric columns?", "Show the top 10 records", "What is the distribution of values in the first column?" ] valid_questions.extend(generic_questions[:5-len(valid_questions)]) logger.info(f"Generated {len(valid_questions)} valid sample questions") return valid_questions except Exception as e: logger.error(f"Error generating sample questions: {str(e)}", exc_info=True) st.warning(f"โš ๏ธ Could not generate sample questions: {str(e)}") return [ "What is the total count of records?", "What are the unique values in the first column?", "What is the average value of numeric columns?", "Show the top 10 records", "What is the distribution of values in the first column?" ] def validate_sample_question(question: str, df: pd.DataFrame, client: OpenAI) -> Tuple[bool, str]: """ Validate a sample question by generating and executing its SQL. 
Returns (success, error_message) """ try: # Generate SQL for the question sql = generate_sql(question, df.head(5), client) # Try to execute the SQL con = duckdb.connect(database=':memory:') con.register("data_table", df) try: con.execute(sql).fetchdf() return True, "" except Exception as e: return False, str(e) except Exception as e: return False, str(e) def display_sample_questions(df: pd.DataFrame, client: OpenAI): """Display sample questions as clickable buttons.""" st.markdown("### ๐Ÿ’ก Try Sample Questions") # Generate sample questions if not already done if not st.session_state.sample_qs: with st.spinner("Generating and validating sample questions..."): st.session_state.sample_qs = generate_llm_queries(df, client, num_questions=10) # Display in a nice grid with 2 columns cols = st.columns(2) for i, question in enumerate(st.session_state.sample_qs): col = cols[i % 2] if col.button(question, key=f"sample_{i}", use_container_width=True, help="Click to use this question"): st.session_state.selected_question = question st.session_state.input_key = str(uuid.uuid4()) logger.info(f"Sample question selected: {question}") st.rerun() def get_user_question() -> str: """Get the user's question input.""" st.markdown("### ๐Ÿค– Ask Your Own Question") user_question = st.text_input( "Enter your question:", value=st.session_state.selected_question, key=st.session_state.input_key, placeholder="e.g., What is the average price?", ) if user_question: logger.info(f"User question received: {user_question}") return user_question def sanitize_sql(query: str) -> str: """Sanitize and format SQL query for DuckDB.""" # Remove code block markers query = re.sub(r"```sql|```", "", query, flags=re.IGNORECASE).strip() # Replace DATE(column) with CAST(column AS DATE) query = re.sub(r"\bDATE\((.*?)\)", r"CAST(\1 AS DATE)", query, flags=re.IGNORECASE) # Handle common replacements replacements = { "โ‰ฅ": ">=", "โ‰ค": "<=", "โ‰ ": "!=", """: '"', """: '"', "'": "'", "'": "'", "DATE('now', '-30 days')": "CURRENT_DATE - INTERVAL '30 days'", "DATE('now', '-7 days')": "CURRENT_DATE - INTERVAL '7 days'", "DATE('now', '-1 day')": "CURRENT_DATE - INTERVAL '1 day'", "DATE_SUB(CURRENT_DATE, INTERVAL 7 DAY)": "CURRENT_DATE - INTERVAL '7 days'", "CURRENT_DATE - INTERVAL 30 DAY": "CURRENT_DATE - INTERVAL '30 days'", "CURRENT_DATE - INTERVAL 7 DAY": "CURRENT_DATE - INTERVAL '7 days'", "CURRENT_DATE - INTERVAL 1 DAY": "CURRENT_DATE - INTERVAL '1 day'", "table_name": "data_table" } for k, v in replacements.items(): query = query.replace(k, v) return query def generate_sql(nl_query: str, df_sample: pd.DataFrame, client: OpenAI) -> str: """Generate SQL query from natural language.""" try: logger.info(f"Generating SQL for question: {nl_query}") # Create schema description schema_desc = df_sample.dtypes.to_string() sample_data = df_sample.head(3).to_string() # Enhanced prompt with guidance on proper SQL syntax and common patterns prompt = f""" You are a data analysis assistant trained to transform analytical questions into well-structured SQL queries using DuckDB syntax. Below is the schema of a table named 'data_table': {schema_desc} Sample data (first 3 rows): {sample_data} Translate the following question into a valid SQL query: Question: {nl_query} IMPORTANT: Generate ONLY DuckDB-compliant SQL code. Do not generate explanations or alternatives unless there's ambiguity in the question. CORE SQL RULES: 1. Always include columns in GROUP BY if they appear in SELECT and aren't inside aggregate functions 2. 
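
# Illustrative check of the rewrites above (a sketch, not executed by the app;
# the literal input/output is derived from the rules in sanitize_sql, not
# captured program output):
#
#   sanitize_sql("```sql\nSELECT COUNT(*) FROM table_name "
#                "WHERE ts >= DATE('now', '-7 days')\n```")
#   # -> "SELECT COUNT(*) FROM data_table
#   #     WHERE ts >= CURRENT_DATE - INTERVAL '7 days'"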

def generate_sql(nl_query: str, df_sample: pd.DataFrame, client: OpenAI) -> str:
    """Generate an SQL query from natural language."""
    try:
        logger.info(f"Generating SQL for question: {nl_query}")

        # Create schema description
        schema_desc = df_sample.dtypes.to_string()
        sample_data = df_sample.head(3).to_string()

        # Enhanced prompt with guidance on proper SQL syntax and common patterns
        prompt = f"""
You are a data analysis assistant trained to transform analytical questions into well-structured SQL queries using DuckDB syntax.

Below is the schema of a table named 'data_table':
{schema_desc}

Sample data (first 3 rows):
{sample_data}

Translate the following question into a valid SQL query:
Question: {nl_query}

IMPORTANT: Generate ONLY DuckDB-compliant SQL code. Do not generate explanations or alternatives unless there's ambiguity in the question.

CORE SQL RULES:
1. Always include columns in GROUP BY if they appear in SELECT and aren't inside aggregate functions
2. When calculating overall averages or statistics, don't include individual IDs in the outer SELECT
3. Make sure subqueries have appropriate aliases
4. For aggregations across groups, use GROUP BY. For overall aggregations, don't select individual rows
5. When using window functions, ensure PARTITION BY and ORDER BY clauses are correct
6. Use proper column references (table_name.column_name or alias.column_name) when joining tables or using subqueries

DUCKDB-SPECIFIC SYNTAX:
7. DATE/TIME OPERATIONS:
   - Subtract from dates: Use DATE_SUB(date, INTERVAL n unit) or date - INTERVAL 'n unit'
   - Add to dates: Use DATE_ADD(date, INTERVAL n unit) or date + INTERVAL 'n unit'
   - Valid interval units: 'day', 'month', 'year', 'hour', 'minute', 'second'
   - Extract components: Use EXTRACT(unit FROM date)
   - Format dates: Use strftime(timestamp, format)
   - DO NOT use: DATEADD(), DATEDIFF(), GETDATE(), DATEPART() or other SQL Server/MySQL specific functions
8. STRING OPERATIONS:
   - Concatenation: Use || operator (NOT +)
   - Case insensitive comparison: Use LOWER(col) = LOWER(value) or ILIKE
   - Substring: Use SUBSTRING(string, start, length)
   - Pattern matching: Use string LIKE pattern (% for wildcards)
   - Regular expressions: Use REGEXP_MATCHES(string, pattern)
9. AGGREGATION FUNCTIONS:
   - Use NULLIF() to prevent division by zero: SUM(value) / NULLIF(COUNT(*), 0)
   - For percentages, multiply by 100.0 (not 100) to ensure float division
   - For conditional counting: COUNT(CASE WHEN condition THEN 1 END)
   - For conditional sums: SUM(CASE WHEN condition THEN value ELSE 0 END)
10. COMMON EXCEPTIONS TO AVOID:
   - Don't use non-standard SQL functions without checking DuckDB compatibility
   - Don't use database-specific functions from SQL Server, MySQL, PostgreSQL, etc.
   - Don't mix aggregate and non-aggregate columns in SELECT without proper GROUP BY
   - Don't use column aliases in WHERE clauses (use the full expression or a subquery)
   - Don't use ORDER BY positional references (ORDER BY 1, 2) - use explicit column names
   - Don't use TOP; use LIMIT instead
11. PERFORMANCE CONSIDERATIONS:
   - Use WITH clauses (CTEs) for complex subqueries to improve readability
   - Filter data as early as possible in the query
   - Use EXISTS or IN for membership tests, not joins, when appropriate
   - Avoid unnecessary DISTINCT operations when possible
12. TEMPORAL DATA PATTERNS:
   - Recent time period: timestamp >= CURRENT_DATE - INTERVAL 'n days/months/years'
   - Current month: EXTRACT(MONTH FROM timestamp) = EXTRACT(MONTH FROM CURRENT_DATE) AND EXTRACT(YEAR FROM timestamp) = EXTRACT(YEAR FROM CURRENT_DATE)
   - Month-to-date: timestamp >= DATE_TRUNC('month', CURRENT_DATE)
   - Year-to-date: timestamp >= DATE_TRUNC('year', CURRENT_DATE)
   - Last N complete months: timestamp >= DATE_TRUNC('month', CURRENT_DATE) - INTERVAL 'n months' AND timestamp < DATE_TRUNC('month', CURRENT_DATE)

COMPLEX PATTERN EXAMPLES:

Example 1 (for the question: "What percentage of users who add a product to their cart actually proceed to purchase it?"):
WITH cart_users AS (
    SELECT DISTINCT user_id FROM data_table WHERE action = 'add_to_cart'
),
purchase_users AS (
    SELECT DISTINCT user_id FROM data_table WHERE action = 'purchase'
)
SELECT (COUNT(DISTINCT pu.user_id) * 100.0) / NULLIF(COUNT(DISTINCT cu.user_id), 0) AS conversion_rate
FROM cart_users cu
LEFT JOIN purchase_users pu ON cu.user_id = pu.user_id;

Example 2 (for date filtering): "How many purchases were made in the last 7 days?"
SELECT COUNT(*) AS recent_purchases
FROM data_table
WHERE action = 'purchase'
  AND timestamp >= CURRENT_DATE - INTERVAL '7 days';

Example 3 (for sequential analysis): "What's the average time between a user viewing a product and purchasing it?"
WITH view_events AS (
    SELECT user_id, product_id, timestamp AS view_time
    FROM data_table WHERE action = 'view'
),
purchase_events AS (
    SELECT user_id, product_id, timestamp AS purchase_time
    FROM data_table WHERE action = 'purchase'
)
SELECT AVG(EXTRACT(EPOCH FROM purchase_time - view_time) / 3600) AS avg_hours_to_purchase
FROM view_events ve
JOIN purchase_events pe
  ON ve.user_id = pe.user_id AND ve.product_id = pe.product_id
WHERE purchase_time > view_time;

SQL:"""

        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
        )
        sql = response.choices[0].message.content.strip()
        sanitized_sql = sanitize_sql(sql)
        logger.info(f"Generated SQL: {sanitized_sql}")
        return sanitized_sql
    except Exception as e:
        logger.error(f"Error generating SQL: {str(e)}", exc_info=True)
        raise Exception(f"Failed to generate SQL: {str(e)}")


def validate_sql(sql: str, df: pd.DataFrame) -> Tuple[bool, str]:
    """
    Validate an SQL query without executing it.
    Returns (is_valid, error_message).
    """
    try:
        logger.info(f"Validating SQL query: {sql}")

        # Check basic SQL syntax using DuckDB's EXPLAIN
        con = duckdb.connect(database=':memory:')
        con.register("data_table", df)

        # EXPLAIN plans the query without executing it
        try:
            con.execute(f"EXPLAIN {sql}")
            return True, ""
        except Exception as e:
            error_msg = str(e)
            logger.warning(f"SQL validation failed: {error_msg}")
            return False, error_msg
    except Exception as e:
        logger.error(f"Error during SQL validation: {str(e)}", exc_info=True)
        return False, str(e)


def fix_sql_with_llm(sql: str, error_msg: str, client: OpenAI) -> str:
    """Try to fix SQL errors using the LLM."""
    try:
        logger.info(f"Attempting to fix SQL error: {error_msg}")

        prompt = f"""
You are an expert SQL developer who specializes in fixing SQL syntax errors.

Here's a SQL query that has an error:
```sql
{sql}
```

The error message is:
```
{error_msg}
```

Please fix the SQL query to resolve this error. Return ONLY the corrected SQL query without any explanations or markdown formatting.
"""

        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
        )
        fixed_sql = response.choices[0].message.content.strip()
        fixed_sql = sanitize_sql(fixed_sql)
        logger.info(f"Fixed SQL: {fixed_sql}")
        return fixed_sql
    except Exception as e:
        logger.error(f"Error fixing SQL: {str(e)}", exc_info=True)
        # Return the original SQL if fixing fails
        return sql


def execute_query(sql: str, df: pd.DataFrame) -> pd.DataFrame:
    """Execute an SQL query against the DataFrame using DuckDB."""
    try:
        logger.info(f"Executing SQL query: {sql}")
        con = duckdb.connect(database=':memory:')
        con.register("data_table", df)
        result_df = con.execute(sql).fetchdf()
        logger.info(f"Query executed successfully, returned {len(result_df)} rows")
        return result_df
    except Exception as e:
        logger.error(f"Error executing SQL: {str(e)}", exc_info=True)
        error_details = traceback.format_exc()
        logger.debug(f"SQL execution error details: {error_details}")
        raise Exception(f"Failed to execute SQL: {str(e)}")

def auto_visualize(df: pd.DataFrame):
    """Automatically choose and create appropriate visualizations based on data types."""
    try:
        # Get column types
        numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
        categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
        date_cols = detect_date_columns(df)

        # Ensure we have data to visualize
        if len(df) == 0:
            st.info("No data to visualize.")
            return

        # Case 1: Time series data
        if date_cols and numeric_cols:
            date_col = date_cols[0]       # Take the first date column
            metric_col = numeric_cols[0]  # Take the first numeric column

            st.subheader("📈 Time Series Analysis")
            st.caption(f"Showing {metric_col} over time")

            # Convert to datetime to ensure proper plotting
            if not pd.api.types.is_datetime64_any_dtype(df[date_col]):
                df[date_col] = pd.to_datetime(df[date_col], errors='coerce')

            # Sort by date
            df_sorted = df.sort_values(by=date_col)

            # Create time series chart
            fig = px.line(df_sorted, x=date_col, y=metric_col,
                          title=f"{metric_col} Trend Over Time",
                          template="plotly_white")

            # Add an OLS trendline, borrowed from a scatter figure's traces
            fig.add_trace(px.scatter(df_sorted, x=date_col, y=metric_col,
                                     trendline="ols",
                                     trendline_color_override="red").data[1])

            # Customize layout
            fig.update_layout(
                xaxis_title="Date",
                yaxis_title=metric_col,
                hovermode="x unified",
                showlegend=True
            )
            st.plotly_chart(fig, use_container_width=True)

            # Add summary statistics
            with st.expander("📊 Summary Statistics", expanded=False):
                stats = df_sorted[metric_col].describe()
                st.write(stats)

        # Case 2: Categorical vs numeric data
        elif categorical_cols and numeric_cols:
            cat_col = categorical_cols[0]  # Take the first categorical column
            metric_col = numeric_cols[0]   # Take the first numeric column

            st.subheader("📊 Category Analysis")
            st.caption(f"Showing {metric_col} by {cat_col}")

            # Limit categories for readability
            top_cats = df.groupby(cat_col)[metric_col].sum().nlargest(10).index
            df_filtered = df[df[cat_col].isin(top_cats)]

            # Create bar chart with hover data
            fig = px.bar(df_filtered, x=cat_col, y=metric_col,
                         title=f"Top 10 {cat_col} by {metric_col}",
                         template="plotly_white",
                         color=metric_col,
                         color_continuous_scale="Viridis")

            # Customize layout
            fig.update_layout(
                xaxis_title=cat_col,
                yaxis_title=metric_col,
                hovermode="x unified",
                showlegend=False
            )
            st.plotly_chart(fig, use_container_width=True)

            # Add summary statistics
            with st.expander("📊 Summary Statistics", expanded=False):
                stats = df_filtered.groupby(cat_col)[metric_col].agg(['sum', 'mean', 'count']).round(2)
                st.write(stats)

        # Case 3: Two numeric columns - scatter plot with correlation
        elif len(numeric_cols) >= 2:
            x_col = numeric_cols[0]
            y_col = numeric_cols[1]

            st.subheader("🔍 Correlation Analysis")
            st.caption(f"Exploring relationship between {x_col} and {y_col}")

            # Create scatter plot with trendline
            fig = px.scatter(df, x=x_col, y=y_col,
                             title=f"{y_col} vs {x_col}",
                             template="plotly_white",
                             trendline="ols",
                             trendline_color_override="red")

            # Customize layout
            fig.update_layout(
                xaxis_title=x_col,
                yaxis_title=y_col,
                hovermode="closest"
            )
            st.plotly_chart(fig, use_container_width=True)

            # Show correlation matrix
            with st.expander("📊 Correlation Matrix", expanded=False):
                corr = df[numeric_cols].corr()
                fig = px.imshow(corr, text_auto=True,
                                title="Correlation Heatmap",
                                color_continuous_scale='RdBu_r',
                                template="plotly_white")
                st.plotly_chart(fig, use_container_width=True)

        # Case 4: Single numeric column - distribution analysis
        elif numeric_cols:
            num_col = numeric_cols[0]

            st.subheader("📉 Distribution Analysis")
            st.caption(f"Distribution of {num_col}")

            # Create histogram with a marginal box plot
            fig = px.histogram(df, x=num_col,
                               title=f"Distribution of {num_col}",
                               template="plotly_white",
                               marginal="box",
                               nbins=30)

            # Customize layout
            fig.update_layout(
                xaxis_title=num_col,
                yaxis_title="Count",
                hovermode="x unified"
            )
            st.plotly_chart(fig, use_container_width=True)

            # Show summary statistics
            with st.expander("📊 Summary Statistics", expanded=False):
                stats = df[num_col].describe()
                st.write(stats)

        # Case 5: Categorical data only - frequency analysis
        elif categorical_cols:
            cat_col = categorical_cols[0]

            st.subheader("🔢 Category Distribution")
            st.caption(f"Frequency of {cat_col}")

            # Count frequency of each category
            value_counts = df[cat_col].value_counts().reset_index()
            value_counts.columns = [cat_col, 'Count']

            # Limit to top 10 categories
            if len(value_counts) > 10:
                value_counts = value_counts.head(10)

            # Create bar chart
            fig = px.bar(value_counts, x=cat_col, y='Count',
                         title=f"Top 10 {cat_col} Distribution",
                         template="plotly_white",
                         color='Count',
                         color_continuous_scale="Viridis")

            # Customize layout
            fig.update_layout(
                xaxis_title=cat_col,
                yaxis_title="Count",
                hovermode="x unified",
                showlegend=False
            )
            st.plotly_chart(fig, use_container_width=True)

            # Show value counts
            with st.expander("📊 Value Counts", expanded=False):
                st.write(value_counts)

        else:
            st.info("Could not determine appropriate visualization for this data.")
    except Exception as e:
        st.error(f"Auto-visualization error: {str(e)}")


def detect_date_columns(df: pd.DataFrame) -> list:
    """Detect columns that contain date information."""
    date_cols = []

    # First collect columns that are already datetime
    date_cols.extend(df.select_dtypes(include=['datetime']).columns.tolist())

    # Then check whether string columns might contain dates
    string_cols = df.select_dtypes(include=['object']).columns
    for col in string_cols:
        if col not in date_cols:  # Skip if already identified as a date
            # Sample a few values to check whether they can be parsed as dates
            sample = df[col].dropna().head(10)  # Small sample to minimize warnings

            # Try to determine whether this is a date column
            if len(sample) > 0:
                date_count = 0
                for val in sample:
                    try:
                        # Try to convert to datetime
                        if pd.to_datetime(val, errors='coerce') is not pd.NaT:
                            date_count += 1
                    except Exception:
                        pass
                # If more than 70% of sampled values parse as dates, treat it as a date column
                if date_count / len(sample) > 0.7:
                    date_cols.append(col)

    return date_cols


def process_user_question(user_question: str, df: pd.DataFrame, client: OpenAI):
    """Process the user's question, generate SQL, execute it, and display results."""
    try:
        # Generate SQL from the question
        with st.spinner("Generating SQL query..."):
            sql = generate_sql(user_question, df.head(5), client)

        # Display the generated SQL with a copy button
        st.markdown("#### 🧾 Generated SQL")
        col1, col2 = st.columns([9, 1])
        col1.code(sql, language="sql")
        if col2.button("📋", help="Copy SQL to clipboard"):
            st.toast("SQL copied to clipboard!")

        # Validate the SQL before execution
        is_valid, error_msg = validate_sql(sql, df)

        if not is_valid:
            st.warning(f"⚠️ The generated SQL may have issues: {error_msg}")

            # Attempt to fix the SQL
            with st.spinner("Attempting to fix SQL..."):
                fixed_sql = fix_sql_with_llm(sql, error_msg, client)

            # Only offer the fixed SQL if it differs from the original
            if fixed_sql != sql:
                st.markdown("#### 🛠️ Fixed SQL")
                col1, col2 = st.columns([9, 1])
                col1.code(fixed_sql, language="sql")
                if col2.button("📋", help="Copy fixed SQL to clipboard", key="copy_fixed"):
                    st.toast("Fixed SQL copied to clipboard!")

                # Ask the user which version to use
                use_fixed = st.radio("Which SQL would you like to execute?",
                                     ["Use original SQL", "Use fixed SQL"],
                                     index=1)
                if use_fixed == "Use fixed SQL":
                    sql = fixed_sql

        # Execute the query
        with st.spinner("Executing query..."):
            try:
                result_df = execute_query(sql, df)

                # Store in history
                st.session_state.query_history.append({
                    "question": user_question,
                    "sql": sql,
                    "succeeded": True,
                    "timestamp": pd.Timestamp.now().strftime("%H:%M:%S")
                })

                # Show results
                st.success("✅ Query executed successfully!")
                display_query_results(result_df)
            except Exception as e:
                # Add to history as failed
                st.session_state.query_history.append({
                    "question": user_question,
                    "sql": sql,
                    "succeeded": False,
                    "error": str(e),
                    "timestamp": pd.Timestamp.now().strftime("%H:%M:%S")
                })

                st.error(f"❌ Error executing SQL: {str(e)}")

                # Show debugging help
                with st.expander("🔍 Debug Information", expanded=True):
                    st.write("The error might be due to:")
                    common_issues = [
                        "Missing GROUP BY for columns in SELECT",
                        "Incorrect column names or table references",
                        "Syntax errors in functions or operators",
                        "Type mismatches in comparisons or joins"
                    ]
                    for issue in common_issues:
                        st.markdown(f"- {issue}")

                    # Offer to try a simpler version
                    st.markdown("#### Try a simpler version")
                    st.write("Let's simplify the query to help diagnose the issue:")

                    # Generate a simplified version
                    with st.spinner("Generating simplified query..."):
                        simplify_prompt = f"""
The following SQL query is failing with this error: {str(e)}

SQL:
```
{sql}
```

Please create a much simpler version of this query that focuses only on the core intent.
Remove complex joins, subqueries, and functions that might be causing issues.
Return ONLY the simplified SQL without any explanations.
"""
                        simple_response = client.chat.completions.create(
                            model="gpt-3.5-turbo",
                            messages=[{"role": "user", "content": simplify_prompt}],
                            temperature=0.2,
                        )
                        simple_sql = simple_response.choices[0].message.content.strip()
                        simple_sql = sanitize_sql(simple_sql)

                    st.code(simple_sql, language="sql")
                    if st.button("Try Simplified Query"):
                        try:
                            simple_result = execute_query(simple_sql, df)
                            st.success("✅ Simplified query succeeded!")
                            st.dataframe(simple_result, use_container_width=True)
                        except Exception as simple_error:
                            st.error(f"Even the simplified query failed: {str(simple_error)}")
    except Exception as e:
        st.error(f"❌ Error generating SQL: {str(e)}")


def display_query_results(result_df: pd.DataFrame):
    """Display query results with basic visualization options."""
    # Show count of results
    st.caption(f"📊 Found {len(result_df):,} matching records")

    # Initialize view mode in session state if it doesn't exist
    if "view_mode" not in st.session_state:
        st.session_state.view_mode = "Table"

    # Determine whether a chart is possible
    can_show_chart = not result_df.empty and len(result_df.columns) >= 2

    # If a chart is not possible, force table view
    if not can_show_chart and st.session_state.view_mode == "Chart":
        st.session_state.view_mode = "Table"

    # Add a view mode toggle
    col1, col2 = st.columns([1, 4])
    with col1:
        st.radio(
            "Display Results As",
            ["Table", "Chart"],
            horizontal=True,
            key="view_mode",
            disabled=not can_show_chart,
            help=("Chart view is disabled because the query result is empty or has "
                  "only one column. At least 2 columns are required for visualization.")
                 if not can_show_chart else None
        )

    # Display based on view mode
    if st.session_state.view_mode == "Chart" and can_show_chart:
        auto_visualize(result_df)
    else:
        st.dataframe(result_df, use_container_width=True, height=300)


# Run the main application
if __name__ == "__main__":
    main()