import io  # For BytesIO
import os
import re
import uuid
import logging
import traceback

import streamlit as st
import pandas as pd
import duckdb
import plotly.express as px
from openai import OpenAI
from typing import List, Dict, Any, Optional, Tuple
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("text-to-sql-app")
def main():
    st.set_page_config(page_title="DataWhiz - Natural Language to SQL", layout="wide")
    setup_ui_header()

    client = setup_openai_client()
    if client is None:
        return

    initialize_session_state()

    # Simplified tabs with only the essential views
    tab1, tab2 = st.tabs(["📊 Upload & Preview Data", "💬 Analyze Your Data"])

    with tab1:  # Upload & Preview
        uploaded_file = st.file_uploader("📁 Upload a CSV file", type=["csv"])
        if not uploaded_file:
            st.info("📄 Upload your CSV file to begin analyzing your data. We'll help you explore and understand your data through natural language questions.")
            st.stop()

        # Store the dataframe in session state
        if "df" not in st.session_state or st.session_state.uploaded_file_name != uploaded_file.name:
            df = load_and_process_data(uploaded_file)
            st.session_state.df = df
            st.session_state.uploaded_file_name = uploaded_file.name
            st.session_state.sample_qs = []  # Reset sample questions
            display_data_preview(df)
            # Flag the "Ask a Question" tab (advisory only: st.tabs has no
            # programmatic tab switching, so the user still clicks over manually)
            st.session_state.current_tab = 1
            st.rerun()
        else:
            df = st.session_state.df
            display_data_preview(df)

    with tab2:  # Ask a Question
        if "df" not in st.session_state:
            st.warning("⚠️ Please upload your data file first to start analyzing.")
            st.stop()

        df = st.session_state.df
        st.markdown("## 💬 Ask Questions About Your Data")
        st.info("💡 Try these sample questions or type your own. For example: 'What is the average price?', 'Show me the top 10 sales by region', or 'What's the trend in daily revenue?'")

        display_sample_questions(df, client)
        user_question = get_user_question()
        if user_question:
            process_user_question(user_question, df, client)

        if "last_result_df" in st.session_state and st.session_state.last_result_df is not None:
            df_result = st.session_state.last_result_df
            display_query_results(df_result)
def setup_ui_header():
    """Set up the app header and description."""
    st.title("📊 DataWhiz: Natural Language to SQL Converter")
    st.markdown("""
    <div style="background-color:#f0f2f6;padding:15px;border-radius:10px;margin-bottom:20px">
    <h4 style="margin-top:0;color:#262730;">Transform Your Data Questions into SQL Queries</h4>
    <ol style="color:#262730;">
        <li>📁 Upload your CSV file containing the data you want to analyze</li>
        <li>💬 Ask questions in plain English about your data</li>
        <li>✨ Get instant SQL queries and visualized results!</li>
    </ol>
    <p style="color:#262730;margin-top:10px;">No SQL knowledge required - just ask questions naturally!</p>
    </div>
    """, unsafe_allow_html=True)
def setup_openai_client() -> Optional[OpenAI]:
    """Set up the OpenAI client with an API key."""
    api_key = os.getenv("OPENAI_API_KEY")
    user_api_key = ""
    if not api_key:
        user_api_key = st.text_input("🔑 Enter your OpenAI API Key", type="password",
                                     help="Your API key is not stored and only used for this session")
        if not user_api_key:
            return None
        api_key = user_api_key
    try:
        client = OpenAI(api_key=api_key)
        logger.info("OpenAI client initialized successfully")
        return client
    except Exception as e:
        logger.error(f"Failed to initialize OpenAI client: {str(e)}")
        st.error(f"❌ Failed to initialize OpenAI client: {str(e)}")
        return None
def initialize_session_state():
    """Initialize session state variables."""
    if "selected_question" not in st.session_state:
        st.session_state.selected_question = ""
    if "input_key" not in st.session_state:
        st.session_state.input_key = str(uuid.uuid4())
    if "sample_qs" not in st.session_state:
        st.session_state.sample_qs = []
    if "current_tab" not in st.session_state:
        st.session_state.current_tab = 0
    if "uploaded_file_name" not in st.session_state:
        st.session_state.uploaded_file_name = None
    if "query_history" not in st.session_state:
        st.session_state.query_history = []
def normalize_string_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Normalize string columns by converting to lowercase and stripping whitespace."""
    string_cols = df.select_dtypes(include="object").columns
    for col in string_cols:
        df[col] = df[col].astype(str).str.lower().str.strip()
    return df
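
# Caveat worth knowing: astype(str) turns NaN into the literal string "nan", so
# missing values in text columns become a real category after normalization.
# If that matters for your data, mask nulls first, e.g. with
# df[col].where(df[col].notna()).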
def load_and_process_data(uploaded_file) -> pd.DataFrame:
    """Load CSV data and perform initial processing."""
    try:
        logger.info(f"Loading CSV file: {uploaded_file.name}")
        df = pd.read_csv(uploaded_file)

        # Normalize string columns
        df = normalize_string_columns(df)

        # Convert likely datetime columns
        for col in df.columns:
            if "date" in col.lower() or "time" in col.lower():
                try:
                    df[col] = pd.to_datetime(df[col], errors='coerce')
                    logger.info(f"Converted column to datetime: {col}")
                except Exception as e:
                    logger.warning(f"Failed to convert column {col} to datetime: {str(e)}")
        return df
    except Exception as e:
        logger.error(f"Error loading CSV: {str(e)}", exc_info=True)
        st.error(f"❌ Error loading CSV file: {str(e)}")
        st.stop()
def display_data_preview(df: pd.DataFrame):
    """Display dataset preview and schema information."""
    col1, col2 = st.columns([2, 1])
    with col1:
        with st.expander("📋 Dataset Preview", expanded=True):
            st.dataframe(df.head(), use_container_width=True)
            st.caption("Preview of your data (first 5 rows)")
    with col2:
        with st.expander("📊 Data Structure", expanded=False):
            # Enhanced schema display
            schema_info = []
            for col in df.columns:
                col_type = str(df[col].dtype)
                non_null = df[col].count()
                null_pct = round((len(df) - non_null) / len(df) * 100, 1) if len(df) > 0 else 0
                schema_info.append(f"- **{col}**: {col_type} ({non_null} non-null, {null_pct}% missing)")
            st.markdown("\n".join(schema_info))
            # Basic statistics
            st.divider()
            st.caption(f"**Total Records**: {len(df):,} | **Total Fields**: {len(df.columns)}")
def get_column_metadata(df: pd.DataFrame, sample_size: int = 5) -> str:
    """Get column metadata including types and example values."""
    metadata = []
    for col in df.columns:
        col_type = str(df[col].dtype)
        samples = df[col].dropna().astype(str).unique()[:sample_size]
        sample_str = ", ".join(samples)
        metadata.append(f"- **{col}** ({col_type}): e.g., {sample_str}")
    return "\n".join(metadata)
def generate_llm_queries(df: pd.DataFrame, client: OpenAI, num_questions: int = 10) -> List[str]:
    """Generate sample analytical questions using OpenAI."""
    try:
        # Get column metadata and sample data
        metadata = get_column_metadata(df)
        sample_data = df.head(3).to_string()

        # Create a focused prompt for business analytics questions
        prompt = f"""
You are a business analyst assistant. Based on the following data schema and sample data, generate {num_questions} meaningful business analytics questions.

Data Schema:
{metadata}

Sample Data:
{sample_data}

Generate questions that:
1. Focus on business metrics and KPIs
2. Include comparisons, trends, and aggregations
3. Are relevant for business reporting and decision making
4. Can be answered with the available data
5. Are clear and specific

Format each question as a numbered list.
"""
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7
        )
        content = response.choices[0].message.content.strip()

        # Extract questions from the numbered list
        questions = [re.sub(r'^\d+\.\s*', '', q).strip() for q in content.split('\n') if q.strip()]

        # Validate each question by generating and executing its SQL
        valid_questions = []
        for question in questions:
            try:
                sql = generate_sql(question, df.head(5), client)
                con = duckdb.connect(database=':memory:')
                con.register("data_table", df)
                con.execute(sql).fetchdf()
                # If execution succeeds, keep the question
                valid_questions.append(question)
                # Stop once we have enough valid questions
                if len(valid_questions) >= 5:
                    break
            except Exception as e:
                logger.warning(f"Question validation failed: {question}. Error: {str(e)}")
                continue

        # Top up with generic questions if validation left us short
        if len(valid_questions) < 5:
            generic_questions = [
                "What is the total count of records?",
                "What are the unique values in the first column?",
                "What is the average value of numeric columns?",
                "Show the top 10 records",
                "What is the distribution of values in the first column?"
            ]
            valid_questions.extend(generic_questions[:5 - len(valid_questions)])

        logger.info(f"Generated {len(valid_questions)} valid sample questions")
        return valid_questions
    except Exception as e:
        logger.error(f"Error generating sample questions: {str(e)}", exc_info=True)
        st.warning(f"⚠️ Could not generate sample questions: {str(e)}")
        return [
            "What is the total count of records?",
            "What are the unique values in the first column?",
            "What is the average value of numeric columns?",
            "Show the top 10 records",
            "What is the distribution of values in the first column?"
        ]
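
# Note: validation round-trips to the LLM once per candidate question (via
# generate_sql), so generating N samples can cost up to N extra completions
# on top of the initial question-generation call.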
def validate_sample_question(question: str, df: pd.DataFrame, client: OpenAI) -> Tuple[bool, str]:
    """
    Validate a sample question by generating and executing its SQL.
    Returns (success, error_message).
    """
    try:
        # Generate SQL for the question
        sql = generate_sql(question, df.head(5), client)
        # Try to execute the SQL
        con = duckdb.connect(database=':memory:')
        con.register("data_table", df)
        try:
            con.execute(sql).fetchdf()
            return True, ""
        except Exception as e:
            return False, str(e)
    except Exception as e:
        return False, str(e)
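
# Note: this helper mirrors the inline validation loop in generate_llm_queries
# and does not appear to be called anywhere in this module.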
def display_sample_questions(df: pd.DataFrame, client: OpenAI):
    """Display sample questions as clickable buttons."""
    st.markdown("### 💡 Try Sample Questions")
    # Generate sample questions if not already done
    if not st.session_state.sample_qs:
        with st.spinner("Generating and validating sample questions..."):
            st.session_state.sample_qs = generate_llm_queries(df, client, num_questions=10)
    # Display in a grid with 2 columns
    cols = st.columns(2)
    for i, question in enumerate(st.session_state.sample_qs):
        col = cols[i % 2]
        if col.button(question, key=f"sample_{i}",
                      use_container_width=True,
                      help="Click to use this question"):
            st.session_state.selected_question = question
            st.session_state.input_key = str(uuid.uuid4())
            logger.info(f"Sample question selected: {question}")
            st.rerun()
def get_user_question() -> Optional[str]:
    """Get the user's question input; returns None when the box is empty."""
    st.markdown("### 🤔 Ask Your Own Question")
    user_question = st.text_input(
        "Enter your question:",
        value=st.session_state.selected_question,
        key=st.session_state.input_key,
        placeholder="e.g., What is the average price?",
    )
    if user_question:
        logger.info(f"User question received: {user_question}")
        return user_question
    return None
def sanitize_sql(query: str) -> str:
    """Sanitize and format a SQL query for DuckDB."""
    # Remove code block markers
    query = re.sub(r"```sql|```", "", query, flags=re.IGNORECASE).strip()
    # Replace DATE(column) with CAST(column AS DATE)
    query = re.sub(r"\bDATE\((.*?)\)", r"CAST(\1 AS DATE)", query, flags=re.IGNORECASE)
    # Handle common replacements: smart punctuation, dialect quirks, and the
    # placeholder table name
    replacements = {
        "≥": ">=", "≤": "<=", "≠": "!=",
        "“": '"', "”": '"', "‘": "'", "’": "'",
        "DATE('now', '-30 days')": "CURRENT_DATE - INTERVAL '30 days'",
        "DATE('now', '-7 days')": "CURRENT_DATE - INTERVAL '7 days'",
        "DATE('now', '-1 day')": "CURRENT_DATE - INTERVAL '1 day'",
        "DATE_SUB(CURRENT_DATE, INTERVAL 7 DAY)": "CURRENT_DATE - INTERVAL '7 days'",
        "CURRENT_DATE - INTERVAL 30 DAY": "CURRENT_DATE - INTERVAL '30 days'",
        "CURRENT_DATE - INTERVAL 7 DAY": "CURRENT_DATE - INTERVAL '7 days'",
        "CURRENT_DATE - INTERVAL 1 DAY": "CURRENT_DATE - INTERVAL '1 day'",
        "table_name": "data_table"
    }
    for k, v in replacements.items():
        query = query.replace(k, v)
    return query
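
# Illustrative example of the rewrites above:
#   sanitize_sql("```sql\nSELECT * FROM table_name WHERE price ≥ 10\n```")
#   -> "SELECT * FROM data_table WHERE price >= 10"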
def generate_sql(nl_query: str, df_sample: pd.DataFrame, client: OpenAI) -> str:
    """Generate a SQL query from natural language."""
    try:
        logger.info(f"Generating SQL for question: {nl_query}")
        # Create schema description
        schema_desc = df_sample.dtypes.to_string()
        sample_data = df_sample.head(3).to_string()

        # Enhanced prompt with guidance on proper SQL syntax and common patterns
        prompt = f"""
You are a data analysis assistant trained to transform analytical questions into well-structured SQL queries using DuckDB syntax.

Below is the schema of a table named 'data_table':
{schema_desc}

Sample data (first 3 rows):
{sample_data}

Translate the following question into a valid SQL query:
Question: {nl_query}

IMPORTANT: Generate ONLY DuckDB-compliant SQL code. Do not generate explanations or alternatives unless there's ambiguity in the question.

CORE SQL RULES:
1. Always include columns in GROUP BY if they appear in SELECT and aren't inside aggregate functions
2. When calculating overall averages or statistics, don't include individual IDs in the outer SELECT
3. Make sure subqueries have appropriate aliases
4. For aggregations across groups, use GROUP BY. For overall aggregations, don't select individual rows
5. When using window functions, ensure PARTITION BY and ORDER BY clauses are correct
6. Use proper column references (table_name.column_name or alias.column_name) when joining tables or using subqueries

DUCKDB-SPECIFIC SYNTAX:
7. DATE/TIME OPERATIONS:
   - Subtract from dates: Use DATE_SUB(date, INTERVAL n unit) or date - INTERVAL 'n unit'
   - Add to dates: Use DATE_ADD(date, INTERVAL n unit) or date + INTERVAL 'n unit'
   - Valid interval units: 'day', 'month', 'year', 'hour', 'minute', 'second'
   - Extract components: Use EXTRACT(unit FROM date)
   - Format dates: Use strftime(timestamp, format)
   - DO NOT use: DATEADD(), DATEDIFF(), GETDATE(), DATEPART() or other SQL Server/MySQL specific functions
8. STRING OPERATIONS:
   - Concatenation: Use || operator (NOT +)
   - Case insensitive comparison: Use LOWER(col) = LOWER(value) or ILIKE
   - Substring: Use SUBSTRING(string, start, length)
   - Pattern matching: Use string LIKE pattern (% for wildcards)
   - Regular expressions: Use REGEXP_MATCHES(string, pattern)
9. AGGREGATION FUNCTIONS:
   - Use NULLIF() to prevent division by zero: SUM(value) / NULLIF(COUNT(*), 0)
   - For percentages, multiply by 100.0 (not 100) to ensure float division
   - For conditional counting: COUNT(CASE WHEN condition THEN 1 END)
   - For conditional sums: SUM(CASE WHEN condition THEN value ELSE 0 END)
10. COMMON EXCEPTIONS TO AVOID:
    - Don't use non-standard SQL functions without checking DuckDB compatibility
    - Don't use database-specific functions from SQL Server, MySQL, PostgreSQL, etc.
    - Don't mix aggregate and non-aggregate columns in SELECT without proper GROUP BY
    - Don't use column aliases in WHERE clauses (use the full expression or a subquery)
    - Don't use ORDER BY positional references (ORDER BY 1, 2) - use explicit column names
    - Don't use TOP; use LIMIT instead
11. PERFORMANCE CONSIDERATIONS:
    - Use WITH clauses (CTEs) for complex subqueries to improve readability
    - Filter data as early as possible in the query
    - Use EXISTS or IN for membership tests, not joins, when appropriate
    - Avoid unnecessary DISTINCT operations when possible
12. TEMPORAL DATA PATTERNS:
    - Recent time period: timestamp >= CURRENT_DATE - INTERVAL 'n days/months/years'
    - Current month: EXTRACT(MONTH FROM timestamp) = EXTRACT(MONTH FROM CURRENT_DATE) AND EXTRACT(YEAR FROM timestamp) = EXTRACT(YEAR FROM CURRENT_DATE)
    - Month-to-date: timestamp >= DATE_TRUNC('month', CURRENT_DATE)
    - Year-to-date: timestamp >= DATE_TRUNC('year', CURRENT_DATE)
    - Last N complete months: timestamp >= DATE_TRUNC('month', CURRENT_DATE) - INTERVAL 'n months' AND timestamp < DATE_TRUNC('month', CURRENT_DATE)

COMPLEX PATTERN EXAMPLES:

Example 1 (for the question: "What percentage of users who add a product to their cart actually proceed to purchase it?"):
WITH cart_users AS (
    SELECT DISTINCT user_id
    FROM data_table
    WHERE action = 'add_to_cart'
),
purchase_users AS (
    SELECT DISTINCT user_id
    FROM data_table
    WHERE action = 'purchase'
)
SELECT
    (COUNT(DISTINCT pu.user_id) * 100.0) / NULLIF(COUNT(DISTINCT cu.user_id), 0) AS conversion_rate
FROM cart_users cu
LEFT JOIN purchase_users pu ON cu.user_id = pu.user_id;

Example 2 (for date filtering): "How many purchases were made in the last 7 days?"
SELECT COUNT(*) AS recent_purchases
FROM data_table
WHERE action = 'purchase'
  AND timestamp >= CURRENT_DATE - INTERVAL '7 days';

Example 3 (for sequential analysis): "What's the average time between a user viewing a product and purchasing it?"
WITH view_events AS (
    SELECT user_id, product_id, timestamp AS view_time
    FROM data_table
    WHERE action = 'view'
),
purchase_events AS (
    SELECT user_id, product_id, timestamp AS purchase_time
    FROM data_table
    WHERE action = 'purchase'
)
SELECT AVG(
    EXTRACT(EPOCH FROM purchase_time - view_time) / 3600
) AS avg_hours_to_purchase
FROM view_events ve
JOIN purchase_events pe ON ve.user_id = pe.user_id AND ve.product_id = pe.product_id
WHERE purchase_time > view_time;

SQL:"""
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
        )
        sql = response.choices[0].message.content.strip()
        sanitized_sql = sanitize_sql(sql)
        logger.info(f"Generated SQL: {sanitized_sql}")
        return sanitized_sql
    except Exception as e:
        logger.error(f"Error generating SQL: {str(e)}", exc_info=True)
        raise Exception(f"Failed to generate SQL: {str(e)}")
def validate_sql(sql: str, df: pd.DataFrame) -> Tuple[bool, str]:
    """
    Validate a SQL query without executing it.
    Returns (is_valid, error_message).
    """
    try:
        logger.info(f"Validating SQL query: {sql}")
        # Check basic SQL syntax using DuckDB's EXPLAIN
        con = duckdb.connect(database=':memory:')
        con.register("data_table", df)
        # Use EXPLAIN to validate the SQL without executing it
        try:
            con.execute(f"EXPLAIN {sql}")
            return True, ""
        except Exception as e:
            error_msg = str(e)
            logger.warning(f"SQL validation failed: {error_msg}")
            return False, error_msg
    except Exception as e:
        logger.error(f"Error during SQL validation: {str(e)}", exc_info=True)
        return False, str(e)
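
# Note: EXPLAIN makes DuckDB parse, bind, and plan the query (catching syntax
# errors and unknown columns) without running it, so data-dependent runtime
# errors can still surface later in execute_query.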
def fix_sql_with_llm(sql: str, error_msg: str, client: OpenAI) -> str:
    """Try to fix SQL errors using the LLM."""
    try:
        logger.info(f"Attempting to fix SQL error: {error_msg}")
        prompt = f"""
You are an expert SQL developer who specializes in fixing SQL syntax errors.

Here's a SQL query that has an error:
```sql
{sql}
```

The error message is:
```
{error_msg}
```

Please fix the SQL query to resolve this error. Return ONLY the corrected SQL query without any explanations or markdown formatting.
"""
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
        )
        fixed_sql = response.choices[0].message.content.strip()
        fixed_sql = sanitize_sql(fixed_sql)
        logger.info(f"Fixed SQL: {fixed_sql}")
        return fixed_sql
    except Exception as e:
        logger.error(f"Error fixing SQL: {str(e)}", exc_info=True)
        # Return the original SQL if fixing fails
        return sql
def execute_query(sql: str, df: pd.DataFrame) -> pd.DataFrame:
    """Execute a SQL query using DuckDB."""
    try:
        logger.info(f"Executing SQL query: {sql}")
        con = duckdb.connect(database=':memory:')
        con.register("data_table", df)
        result_df = con.execute(sql).fetchdf()
        logger.info(f"Query executed successfully, returned {len(result_df)} rows")
        return result_df
    except Exception as e:
        logger.error(f"Error executing SQL: {str(e)}", exc_info=True)
        error_details = traceback.format_exc()
        logger.debug(f"SQL execution error details: {error_details}")
        raise Exception(f"Failed to execute SQL: {str(e)}")
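
# Note: each call opens a fresh in-memory DuckDB connection, and register()
# exposes the pandas DataFrame to DuckDB as a queryable view rather than
# importing a copy, so repeated queries stay cheap for moderate data sizes.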
def auto_visualize(df: pd.DataFrame):
    """Automatically choose and create appropriate visualizations based on data types."""
    try:
        # Get column types
        numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
        categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
        date_cols = detect_date_columns(df)

        # Ensure we have data to visualize
        if len(df) == 0:
            st.info("No data to visualize.")
            return

        # Case 1: Time series data
        if date_cols and numeric_cols:
            date_col = date_cols[0]  # Take the first date column
            metric_col = numeric_cols[0]  # Take the first numeric column

            st.subheader("📈 Time Series Analysis")
            st.caption(f"Showing {metric_col} over time")

            # Convert to datetime to ensure proper plotting
            # (infer_datetime_format removed: it is deprecated in pandas 2.x)
            if not pd.api.types.is_datetime64_any_dtype(df[date_col]):
                df[date_col] = pd.to_datetime(df[date_col], errors='coerce')

            # Sort by date
            df_sorted = df.sort_values(by=date_col)

            # Create time series chart with trendline
            fig = px.line(df_sorted, x=date_col, y=metric_col,
                          title=f"{metric_col} Trend Over Time",
                          template="plotly_white")

            # Add an OLS trendline (requires the statsmodels package);
            # .data[1] picks the trendline trace out of the scatter figure
            fig.add_trace(px.scatter(df_sorted, x=date_col, y=metric_col,
                                     trendline="ols",
                                     trendline_color_override="red").data[1])

            # Customize layout
            fig.update_layout(
                xaxis_title="Date",
                yaxis_title=metric_col,
                hovermode="x unified",
                showlegend=True
            )
            st.plotly_chart(fig, use_container_width=True)

            # Add summary statistics
            with st.expander("📊 Summary Statistics", expanded=False):
                stats = df_sorted[metric_col].describe()
                st.write(stats)
        # Case 2: Categorical vs numeric data
        elif categorical_cols and numeric_cols:
            cat_col = categorical_cols[0]  # Take the first categorical column
            metric_col = numeric_cols[0]  # Take the first numeric column

            st.subheader("📊 Category Analysis")
            st.caption(f"Showing {metric_col} by {cat_col}")

            # Limit categories for readability
            top_cats = df.groupby(cat_col)[metric_col].sum().nlargest(10).index
            df_filtered = df[df[cat_col].isin(top_cats)]

            # Create bar chart with hover data
            fig = px.bar(df_filtered,
                         x=cat_col,
                         y=metric_col,
                         title=f"Top 10 {cat_col} by {metric_col}",
                         template="plotly_white",
                         color=metric_col,
                         color_continuous_scale="Viridis")

            # Customize layout
            fig.update_layout(
                xaxis_title=cat_col,
                yaxis_title=metric_col,
                hovermode="x unified",
                showlegend=False
            )
            st.plotly_chart(fig, use_container_width=True)

            # Add summary statistics
            with st.expander("📊 Summary Statistics", expanded=False):
                stats = df_filtered.groupby(cat_col)[metric_col].agg(['sum', 'mean', 'count']).round(2)
                st.write(stats)

        # Case 3: Two numeric columns - scatter plot with correlation
        elif len(numeric_cols) >= 2:
            x_col = numeric_cols[0]
            y_col = numeric_cols[1]

            st.subheader("🔗 Correlation Analysis")
            st.caption(f"Exploring the relationship between {x_col} and {y_col}")

            # Create scatter plot with trendline
            fig = px.scatter(df,
                             x=x_col,
                             y=y_col,
                             title=f"{y_col} vs {x_col}",
                             template="plotly_white",
                             trendline="ols",
                             trendline_color_override="red")

            # Customize layout
            fig.update_layout(
                xaxis_title=x_col,
                yaxis_title=y_col,
                hovermode="closest"
            )
            st.plotly_chart(fig, use_container_width=True)

            # Show correlation matrix
            with st.expander("🔢 Correlation Matrix", expanded=False):
                corr = df[numeric_cols].corr()
                fig = px.imshow(corr,
                                text_auto=True,
                                title="Correlation Heatmap",
                                color_continuous_scale='RdBu_r',
                                template="plotly_white")
                st.plotly_chart(fig, use_container_width=True)

        # Case 4: Single numeric column - distribution analysis
        elif numeric_cols:
            num_col = numeric_cols[0]

            st.subheader("📊 Distribution Analysis")
            st.caption(f"Distribution of {num_col}")

            # Create histogram with a marginal box plot
            fig = px.histogram(df,
                               x=num_col,
                               title=f"Distribution of {num_col}",
                               template="plotly_white",
                               marginal="box",
                               nbins=30)

            # Customize layout
            fig.update_layout(
                xaxis_title=num_col,
                yaxis_title="Count",
                hovermode="x unified"
            )
            st.plotly_chart(fig, use_container_width=True)

            # Show summary statistics
            with st.expander("📊 Summary Statistics", expanded=False):
                stats = df[num_col].describe()
                st.write(stats)

        # Case 5: Categorical data only - frequency analysis
        elif categorical_cols:
            cat_col = categorical_cols[0]

            st.subheader("🔢 Category Distribution")
            st.caption(f"Frequency of {cat_col}")

            # Count the frequency of each category
            value_counts = df[cat_col].value_counts().reset_index()
            value_counts.columns = [cat_col, 'Count']

            # Limit to the top 10 categories
            if len(value_counts) > 10:
                value_counts = value_counts.head(10)

            # Create bar chart
            fig = px.bar(value_counts,
                         x=cat_col,
                         y='Count',
                         title=f"Top 10 {cat_col} Distribution",
                         template="plotly_white",
                         color='Count',
                         color_continuous_scale="Viridis")

            # Customize layout
            fig.update_layout(
                xaxis_title=cat_col,
                yaxis_title="Count",
                hovermode="x unified",
                showlegend=False
            )
            st.plotly_chart(fig, use_container_width=True)

            # Show value counts
            with st.expander("📊 Value Counts", expanded=False):
                st.write(value_counts)

        else:
            st.info("Could not determine an appropriate visualization for this data.")
    except Exception as e:
        st.error(f"Auto-visualization error: {str(e)}")
def detect_date_columns(df: pd.DataFrame) -> list:
    """Detect columns that contain date information."""
    date_cols = []

    # First take columns that are already datetime
    date_cols.extend(df.select_dtypes(include=['datetime']).columns.tolist())

    # Then check whether string columns might contain dates
    string_cols = df.select_dtypes(include=['object']).columns
    for col in string_cols:
        if col not in date_cols:  # Skip if already identified as a date
            # Sample a few values to check whether they can be parsed as dates
            sample = df[col].dropna().head(10)  # Small sample to minimize parse warnings
            if len(sample) > 0:
                date_count = 0
                for val in sample:
                    try:
                        # Count values that parse to a real timestamp (not NaT)
                        if pd.notna(pd.to_datetime(val, errors='coerce')):
                            date_count += 1
                    except Exception:
                        pass
                # If more than 70% of sampled values are valid dates, treat it as a date column
                if date_count / len(sample) > 0.7:
                    date_cols.append(col)
    return date_cols
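
# Note: with a 10-value sample, the 0.7 threshold means at least 8 sampled values
# must parse as dates. This is a heuristic; loosely formatted text can still slip
# through, so tighten the threshold or whitelist columns if you see false positives.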
def process_user_question(user_question: str, df: pd.DataFrame, client: OpenAI):
    """Process the user's question: generate SQL, execute it, and display results."""
    try:
        # Generate SQL from the question
        with st.spinner("Generating SQL query..."):
            sql = generate_sql(user_question, df.head(5), client)

        # Display the generated SQL
        st.markdown("#### 🧾 Generated SQL")
        # SQL display with a copy button (the button only shows a toast;
        # Streamlit has no built-in clipboard write)
        col1, col2 = st.columns([9, 1])
        col1.code(sql, language="sql")
        if col2.button("📋", help="Copy SQL to clipboard"):
            st.toast("SQL copied to clipboard!")

        # Validate the SQL before execution
        is_valid, error_msg = validate_sql(sql, df)
        if not is_valid:
            st.warning(f"⚠️ The generated SQL may have issues: {error_msg}")
            # Attempt to fix the SQL
            with st.spinner("Attempting to fix SQL..."):
                fixed_sql = fix_sql_with_llm(sql, error_msg, client)
            # Only proceed with the fixed SQL if it differs from the original
            if fixed_sql != sql:
                st.markdown("#### 🛠️ Fixed SQL")
                col1, col2 = st.columns([9, 1])
                col1.code(fixed_sql, language="sql")
                if col2.button("📋", help="Copy fixed SQL to clipboard", key="copy_fixed"):
                    st.toast("Fixed SQL copied to clipboard!")
                # Ask the user which version to use
                use_fixed = st.radio("Which SQL would you like to execute?",
                                     ["Use original SQL", "Use fixed SQL"], index=1)
                if use_fixed == "Use fixed SQL":
                    sql = fixed_sql

        # Execute the query
        with st.spinner("Executing query..."):
            try:
                result_df = execute_query(sql, df)
                # Store in history
                st.session_state.query_history.append({
                    "question": user_question,
                    "sql": sql,
                    "succeeded": True,
                    "timestamp": pd.Timestamp.now().strftime("%H:%M:%S")
                })
                # Show results
                st.success("✅ Query executed successfully!")
                display_query_results(result_df)
            except Exception as e:
                # Add to history as failed
                st.session_state.query_history.append({
                    "question": user_question,
                    "sql": sql,
                    "succeeded": False,
                    "error": str(e),
                    "timestamp": pd.Timestamp.now().strftime("%H:%M:%S")
                })
                st.error(f"❌ Error executing SQL: {str(e)}")

                # Show debug help
                with st.expander("🔍 Debug Information", expanded=True):
                    st.write("The error might be due to:")
                    common_issues = [
                        "Missing GROUP BY for columns in SELECT",
                        "Incorrect column names or table references",
                        "Syntax errors in functions or operators",
                        "Type mismatches in comparisons or joins"
                    ]
                    for issue in common_issues:
                        st.markdown(f"- {issue}")

                    # Offer to try a simpler version
                    st.markdown("#### Try a simpler version")
                    st.write("Let's simplify the query to help diagnose the issue:")

                    # Generate a simplified version
                    with st.spinner("Generating simplified query..."):
                        simplify_prompt = f"""
The following SQL query is failing with this error: {str(e)}

SQL:
```
{sql}
```

Please create a much simpler version of this query that focuses only on the core intent.
Remove complex joins, subqueries, and functions that might be causing issues.
Return ONLY the simplified SQL without any explanations.
"""
                        simple_response = client.chat.completions.create(
                            model="gpt-3.5-turbo",
                            messages=[{"role": "user", "content": simplify_prompt}],
                            temperature=0.2,
                        )
                        simple_sql = simple_response.choices[0].message.content.strip()
                        simple_sql = sanitize_sql(simple_sql)

                    st.code(simple_sql, language="sql")
                    if st.button("Try Simplified Query"):
                        try:
                            simple_result = execute_query(simple_sql, df)
                            st.success("✅ Simplified query succeeded!")
                            st.dataframe(simple_result, use_container_width=True)
                        except Exception as simple_error:
                            st.error(f"Even the simplified query failed: {str(simple_error)}")
    except Exception as e:
        st.error(f"❌ Error generating SQL: {str(e)}")
def display_query_results(result_df: pd.DataFrame):
    """Display query results with basic visualization options."""
    # Show the number of results
    st.caption(f"🔎 Found {len(result_df):,} matching records")

    # Initialize view mode in session state if it doesn't exist
    if "view_mode" not in st.session_state:
        st.session_state.view_mode = "Table"

    # Determine whether a chart is possible
    can_show_chart = not result_df.empty and len(result_df.columns) >= 2

    # If a chart is not possible, force table view
    if not can_show_chart and st.session_state.view_mode == "Chart":
        st.session_state.view_mode = "Table"

    # Add view mode toggle
    col1, col2 = st.columns([1, 4])
    with col1:
        view_mode = st.radio(
            "Display Results As",
            ["Table", "Chart"],
            horizontal=True,
            key="view_mode",
            disabled=not can_show_chart,
            help="Chart view is disabled because the result is empty or has fewer than two columns; at least two columns are required for visualization." if not can_show_chart else None
        )

    # Display based on view mode
    if st.session_state.view_mode == "Chart" and can_show_chart:
        auto_visualize(result_df)
    else:
        st.dataframe(result_df, use_container_width=True, height=300)
# Run the main application
if __name__ == "__main__":
    main()
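
# Usage sketch (assuming this module is saved as app.py):
#   streamlit run app.py
# Set OPENAI_API_KEY in the environment beforehand to skip the in-app key prompt.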