"""
AI-Powered EDA & Feature Engineering Assistant
This application enables users to upload a CSV dataset, and utilizes LLMs to analyze
the dataset to provide EDA and feature engineering recommendations.
"""
import streamlit as st
import pandas as pd
import os
import base64
from io import BytesIO
from dotenv import load_dotenv
from typing import Dict, List, Any, Optional
import time
import logging
import plotly.express as px
import numpy as np
# Import LangChain memory components
from langchain.memory import ConversationBufferMemory
from langchain_core.messages import AIMessage, HumanMessage
# Import local modules
from eda_analysis import DatasetAnalyzer
from llm_inference import LLMInference
# Configure logging
# Configure module-level logging with timestamped, named records.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Load environment variables (e.g. LLM API keys) from a local .env file.
load_dotenv()
# Set page configuration - must be the first Streamlit command
st.set_page_config(
    page_title="AI-Powered EDA & Feature Engineering Assistant",
    # NOTE(review): "📊" looks like mojibake (UTF-8 read as Latin-1) for 📊 —
    # confirm the file's encoding against version control.
    page_icon="📊",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Initialize our classes
@st.cache_resource
def get_llm_inference():
    """Create the shared LLMInference client, cached across Streamlit reruns.

    Returns:
        The initialized LLMInference instance, or None when construction
        fails (the error is surfaced in the UI via st.error). Callers must
        therefore handle an unavailable LLM gracefully.
    """
    try:
        inference = LLMInference()
    except Exception as e:
        st.error(f"Error initializing LLM inference: {str(e)}")
        return None
    return inference
# Shared LLM client; None when initialization failed (downstream code falls
# back to template-based insights in that case).
llm_inference = get_llm_inference()
# Session state initialization
# Each key is created only when absent so values survive Streamlit reruns.
if "dataset_analyzer" not in st.session_state:
    st.session_state.dataset_analyzer = DatasetAnalyzer()
if "dataset_loaded" not in st.session_state:
    st.session_state.dataset_loaded = False
if "dataset_info" not in st.session_state:
    st.session_state.dataset_info = {}
if "visualizations" not in st.session_state:
    st.session_state.visualizations = {}
if "eda_insights" not in st.session_state:
    st.session_state.eda_insights = ""
if "feature_engineering_recommendations" not in st.session_state:
    st.session_state.feature_engineering_recommendations = ""
if "data_quality_insights" not in st.session_state:
    st.session_state.data_quality_insights = ""
if "active_tab" not in st.session_state:
    st.session_state.active_tab = "welcome"
# Add new functions to support the updated UI
def initialize_session_state():
    """Initialize session state variables needed for the application.

    Every key is created only when missing so existing values survive
    Streamlit reruns. The LangChain conversation memory is handled
    separately because its constructor builds a stateful object.
    """
    # Cheap defaults: the dict literal is rebuilt per call, so each missing
    # key receives a fresh list/str/etc. rather than a shared object.
    defaults = {
        "chat_history": [],
        "descriptive_stats": None,
        "selected_columns": [],
        "filtered_df": None,
        "ai_insights": None,
        "loading_insights": False,
        "selected_tab": 'tab-overview',
        "dataset_name": "",
    }
    for key, default in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default
    # LangChain conversation memory: constructed lazily, only when absent.
    # Note: df is deliberately NOT initialized here — it must only appear in
    # session state once a real DataFrame has been loaded.
    if "conversation_memory" not in st.session_state:
        st.session_state.conversation_memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True
        )
    logger.info("Session state initialized")
def apply_custom_css():
    """Apply additional custom CSS that's not already in the main CSS block"""
    # NOTE(review): the CSS payload appears to have been lost from this copy
    # (the triple-quoted string is empty/whitespace) — as written this call
    # injects nothing. Restore the stylesheet from version control.
    st.markdown("""
    """, unsafe_allow_html=True)
def _summarize_dataset(df):
    """Build a plain-dict summary of *df* suitable for inclusion in LLM prompts.

    Covers shape, column names/dtypes, per-column missing counts and
    percentages, the strongest pairwise numeric correlations (|r| > 0.5,
    top 5), pandas describe() output, and the first five rows.
    """
    num_rows, num_columns = df.shape
    # Missing values as {column: (count, percent)} for readability.
    missing_counts = df.isnull().sum()
    missing_values = {
        col: (count, round(count / len(df) * 100, 2))
        for col, count in missing_counts[missing_counts > 0].items()
    }
    # Strong pairwise correlations (|r| > 0.5) between numeric columns,
    # lower triangle only (j < i) to avoid duplicates and the diagonal.
    numeric_cols = df.select_dtypes(include=['number']).columns
    correlations = "No numerical columns to calculate correlations."
    if len(numeric_cols) > 1:
        corr_matrix = df[numeric_cols].corr()
        corr_pairs = []
        for i in range(len(numeric_cols)):
            for j in range(i):
                val = corr_matrix.iloc[i, j]
                if abs(val) > 0.5:  # Only show strong correlations
                    corr_pairs.append((numeric_cols[i], numeric_cols[j], val))
        if corr_pairs:
            corr_pairs.sort(key=lambda x: abs(x[2]), reverse=True)
            correlations = "\n".join(
                f"{col1} and {col2}: {val:.3f}" for col1, col2, val in corr_pairs[:5]
            )
    return {
        # BUGFIX: the original rebound `num_cols` to the numeric-column Index
        # before building this dict, so the shape string interpolated the
        # Index object instead of the column count.
        "shape": f"{num_rows} rows, {num_columns} columns",
        "columns": df.columns.tolist(),
        "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
        "missing_values": missing_values,
        "basic_stats": df.describe().to_string(),
        "correlations": correlations,
        "sample_data": df.head(5).to_string()
    }


def _request_llm_section(insights, label, generate, dataset_info):
    """Call one LLM generator and store a validated response under *label*.

    A response is accepted only when it is a non-empty string longer than 50
    characters; failures are logged and the section is simply omitted.
    """
    logger.info("Requesting %s from LLM", label)
    try:
        text = generate(dataset_info)
    except Exception as e:
        logger.error(f"Error generating {label}: {str(e)}")
        return
    if text and isinstance(text, str) and len(text) > 50:
        insights[label] = [text.strip()]
        logger.info("Successfully generated %s", label)
    else:
        logger.warning(
            f"{label} response was invalid: {type(text)}, "
            f"length: {len(text) if isinstance(text, str) else 'N/A'}"
        )


def _generate_llm_insights(dataset_info):
    """Request EDA / feature-engineering / data-quality sections from the LLM.

    Feature engineering is only requested when the EDA call succeeded (to
    avoid wasting a call when the LLM is clearly failing); data quality is
    always attempted. Returns {section: [markdown_text]}.
    """
    insights = {}
    _request_llm_section(insights, "EDA Insights",
                         llm_inference.generate_eda_insights, dataset_info)
    if "EDA Insights" in insights:  # Only proceed if EDA worked
        _request_llm_section(insights, "Feature Engineering Recommendations",
                             llm_inference.generate_feature_engineering_recommendations,
                             dataset_info)
    _request_llm_section(insights, "Data Quality Insights",
                         llm_inference.generate_data_quality_insights, dataset_info)
    return insights


def _missing_value_notes(df):
    """Template insights: missing-value extent plus remediation suggestions."""
    missing_data = df.isnull().sum()
    missing_percent = (missing_data / len(df)) * 100
    missing_cols = missing_data[missing_data > 0]
    notes = []
    if len(missing_cols) > 0:
        notes.append(f"Found {len(missing_cols)} columns with missing values.")
        for col in missing_cols.index[:3]:  # Show details for top 3
            notes.append(f"Column '{col}' has {missing_data[col]} missing values ({missing_percent[col]:.2f}%).")
        if len(missing_cols) > 3:
            notes.append(f"And {len(missing_cols) - 3} more columns have missing values.")
        # Recommend dropping only when a column is mostly missing.
        if any(missing_percent > 50):
            high_missing = missing_percent[missing_percent > 50].index.tolist()
            notes.append(f"Consider dropping columns with >50% missing values: {', '.join(high_missing[:3])}.")
        else:
            notes.append("Consider using imputation techniques for columns with missing values.")
    else:
        notes.append("No missing values found in the dataset. Great job!")
    return notes


def _distribution_notes(df, num_cols):
    """Template insights: skewness and 1.5*IQR outliers for numeric columns."""
    notes = []
    if len(num_cols) > 0:
        for col in num_cols[:3]:  # Analyze top 3 numeric columns
            skew = df[col].skew()
            if abs(skew) > 1:
                direction = "right" if skew > 0 else "left"
                notes.append(f"Column '{col}' is {direction}-skewed (skewness: {skew:.2f}). Consider log transformation.")
            # Flag outliers outside the 1.5*IQR fences.
            Q1 = df[col].quantile(0.25)
            Q3 = df[col].quantile(0.75)
            IQR = Q3 - Q1
            outliers = df[(df[col] < (Q1 - 1.5 * IQR)) | (df[col] > (Q3 + 1.5 * IQR))][col].count()
            if outliers > 0:
                pct = (outliers / len(df)) * 100
                notes.append(f"Column '{col}' has {outliers} outliers ({pct:.2f}%). Consider outlier treatment.")
        if len(num_cols) > 3:
            notes.append(f"Additional {len(num_cols) - 3} numerical columns not analyzed here.")
    else:
        notes.append("No numerical columns found for distribution analysis.")
    return notes


def _correlation_notes(df, num_cols):
    """Template insights: highly correlated (|r| > 0.7) feature pairs."""
    notes = []
    if len(num_cols) > 1:
        corr_matrix = df[num_cols].corr()
        high_corr = []
        # Lower triangle only (j < i): each unordered pair appears once.
        for i in range(len(corr_matrix.columns)):
            for j in range(i):
                if abs(corr_matrix.iloc[i, j]) > 0.7:
                    high_corr.append((corr_matrix.columns[i], corr_matrix.columns[j], corr_matrix.iloc[i, j]))
        if high_corr:
            notes.append(f"Found {len(high_corr)} pairs of highly correlated features.")
            for col1, col2, corr_val in high_corr[:3]:  # Show top 3
                corr_direction = "positively" if corr_val > 0 else "negatively"
                notes.append(f"'{col1}' and '{col2}' are strongly {corr_direction} correlated (r={corr_val:.2f}).")
            if len(high_corr) > 3:
                notes.append(f"And {len(high_corr) - 3} more highly correlated pairs found.")
            notes.append("Consider removing some highly correlated features to reduce dimensionality.")
        else:
            notes.append("No strong correlations found between features.")
    else:
        notes.append("Need at least 2 numerical columns to analyze correlations.")
    return notes


def _feature_engineering_notes(df, num_cols):
    """Template insights: date candidates, encodings, cardinality, interactions."""
    notes = []
    # Object columns that parse cleanly as datetimes are date candidates.
    date_cols = []
    for col in df.columns:
        if df[col].dtype == 'object':
            try:
                pd.to_datetime(df[col])
                date_cols.append(col)
            except (ValueError, TypeError):  # BUGFIX: was a bare except
                pass
    if date_cols:
        notes.append(f"Found {len(date_cols)} potential date columns: {', '.join(date_cols[:3])}.")
        notes.append("Consider extracting year, month, day, weekday from these columns.")
    cat_cols = df.select_dtypes(include=['object']).columns
    if len(cat_cols) > 0:
        notes.append(f"Found {len(cat_cols)} categorical columns.")
        notes.append("Consider one-hot encoding or label encoding for categorical features.")
        # High-cardinality categoricals blow up one-hot encodings.
        high_card_cols = [(col, df[col].nunique()) for col in cat_cols if df[col].nunique() > 10]
        if high_card_cols:
            notes.append(f"Some categorical columns have high cardinality:")
            for col, card in high_card_cols[:2]:
                notes.append(f"Column '{col}' has {card} unique values. Consider grouping less common categories.")
    # Suggest polynomial features only for a small numeric feature set.
    if 1 < len(num_cols) < 5:
        notes.append("Consider creating polynomial features or interaction terms between numerical features.")
    return notes


def _template_insights(df):
    """Rule-based fallback insights computed directly from the dataframe."""
    num_cols = df.select_dtypes(include=['number']).columns
    return {
        "Missing Values Analysis": _missing_value_notes(df),
        "Distribution Insights": _distribution_notes(df, num_cols),
        "Correlation Analysis": _correlation_notes(df, num_cols),
        "Feature Engineering Recommendations": _feature_engineering_notes(df, num_cols),
    }


def generate_ai_insights():
    """Generate AI-powered insights about the dataset in session state.

    Prefers the LLM (via the module-level ``llm_inference``); on failure or
    unavailability, falls back to deterministic template-based insights.
    Clears the ``loading_insights`` flag on success.

    Returns:
        Dict mapping section name -> list of markdown strings; empty dict
        when no dataframe has been loaded.
    """
    if 'df' not in st.session_state:
        logger.warning("Cannot generate AI insights: No dataframe in session state")
        return {}
    df = st.session_state.df
    # First choice: LLM-generated insights.
    try:
        if llm_inference is not None:
            insights = _generate_llm_insights(_summarize_dataset(df))
            if insights:
                st.session_state['loading_insights'] = False
                logger.info("Successfully generated AI insights using LLM")
                return insights
            logger.warning("All LLM generated insights failed or were too short. Falling back to template insights.")
        else:
            logger.warning("LLM inference is not available. Falling back to template insights.")
    except Exception as e:
        logger.error(f"Error in generate_ai_insights(): {str(e)}. Falling back to template insights.")
    # Fallback: deterministic template-based insights.
    logger.info("Falling back to template-based insights generation")
    insights = _template_insights(df)
    # Slight delay so the loading spinner is visible to the user.
    time.sleep(1)
    st.session_state['loading_insights'] = False
    logger.info("Template-based insights generation completed")
    return insights
def display_chat_interface():
    """Display a chat interface for interacting with the data"""
    # NOTE(review): the HTML markup inside several st.markdown literals appears
    # to have been stripped from this copy of the file (some quotes are left
    # unterminated) — restore the original markup from version control.
    st.markdown('
    ', unsafe_allow_html=True)
    st.markdown('
    💬 Chat with Your Data
    ', unsafe_allow_html=True)
    # Initialize chat history if not present
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    # Make sure we have data to chat about
    if 'df' not in st.session_state or st.session_state.df is None:
        st.error("No dataset loaded. Please upload a CSV file to chat with your data.")
        # Show a preview of chat capabilities before any dataset is uploaded.
        st.markdown("""
        What can I help you with?
        Once you upload a dataset, you can ask questions like:
        - What patterns do you see in my data?
        - How many missing values are there?
        - What feature engineering would you recommend?
        - Show me the distribution of a specific column
        - What are the correlations between features?
        """, unsafe_allow_html=True)
        st.markdown('
        ', unsafe_allow_html=True)
        return
    # Add a button to clear chat history
    col1, col2 = st.columns([4, 1])
    with col2:
        if st.button("Clear Chat", key="clear_chat"):
            st.session_state.chat_history = []
            # Reset conversation memory so the LLM also forgets prior turns.
            if "conversation_memory" in st.session_state:
                st.session_state.conversation_memory = ConversationBufferMemory(
                    memory_key="chat_history",
                    return_messages=True
                )
            logger.info("Chat history and memory cleared")
            st.rerun()
    # Display chat history
    for message in st.session_state.chat_history:
        if message["role"] == "user":
            st.chat_message("user").write(message["content"])
        else:
            st.chat_message("assistant").write(message["content"])
    # If no chat history, show some example questions
    if not st.session_state.chat_history:
        st.info("Ask me anything about your dataset! I can help you understand patterns, identify issues, and suggest improvements.")
        st.markdown("### Example questions you can ask:")
        # Create a grid of example questions using columns
        col1, col2 = st.columns(2)
        with col1:
            example_questions = [
                "What are the key patterns in this dataset?",
                "Which columns have missing values?",
                "What kind of feature engineering would help?"
            ]
            for i, question in enumerate(example_questions):
                if st.button(question, key=f"example_q_{i}"):
                    process_chat_message(question)
                    st.rerun()
        with col2:
            more_questions = [
                "How are the numerical variables distributed?",
                "What are the strongest correlations?",
                "How can I prepare this data for modeling?"
            ]
            for i, question in enumerate(more_questions):
                # Keys offset by 3 to stay unique across both columns.
                if st.button(question, key=f"example_q_{i+3}"):
                    process_chat_message(question)
                    st.rerun()
    # Input area for new messages
    user_input = st.chat_input("Ask a question about your data...", key="chat_input")
    if user_input:
        # Add user message to chat history and trigger a rerender.
        process_chat_message(user_input)
        st.rerun()
    st.markdown('', unsafe_allow_html=True)
def display_descriptive_tab():
    """Render the Descriptive Statistics tab: numerical summary table,
    dataset overview counts, and a missing-values table with bar chart."""
    # NOTE(review): the HTML markup inside several st.markdown literals appears
    # to have been stripped from this copy (some quotes are left unterminated)
    # — restore the original markup from version control.
    st.markdown('', unsafe_allow_html=True)
    st.markdown('
    📊 Descriptive Statistics
    ', unsafe_allow_html=True)
    # Guard: this tab needs both the dataframe and precomputed stats.
    if 'df' not in st.session_state or 'descriptive_stats' not in st.session_state:
        st.error("No dataset loaded. Please upload a CSV file.")
        st.markdown('', unsafe_allow_html=True)
        return
    df = st.session_state.df
    descriptive_stats = st.session_state.descriptive_stats
    # Display descriptive statistics in a more visually appealing way
    col1, col2 = st.columns([3, 1])
    with col1:
        # Styled describe() table with a blue per-column gradient.
        st.markdown('', unsafe_allow_html=True)
        st.subheader("Numerical Summary")
        st.dataframe(descriptive_stats.style.background_gradient(cmap='Blues', axis=0)
                     .format(precision=2, na_rep="Missing"), use_container_width=True)
        st.markdown('
        ', unsafe_allow_html=True)
    with col2:
        st.markdown('', unsafe_allow_html=True)
        st.subheader("Dataset Overview")
        # Row/column/dtype counts feeding the overview panel.
        total_rows = df.shape[0]
        total_cols = df.shape[1]
        numeric_cols = len(df.select_dtypes(include=['number']).columns)
        cat_cols = len(df.select_dtypes(include=['object', 'category']).columns)
        date_cols = len(df.select_dtypes(include=['datetime']).columns)
        # NOTE(review): the f-string body below is empty in this copy; it
        # presumably rendered the counts computed above — confirm against VCS.
        st.markdown(f"""
        """, unsafe_allow_html=True)
        st.markdown('
        ', unsafe_allow_html=True)
    # Add missing values information with visualization
    st.markdown('', unsafe_allow_html=True)
    st.subheader("Missing Values")
    col1, col2 = st.columns([2, 3])
    with col1:
        # Per-column missing counts/percentages, sorted, zero rows dropped.
        missing_data = df.isnull().sum()
        missing_percent = (missing_data / len(df)) * 100
        missing_data = pd.DataFrame({
            'Missing Values': missing_data,
            'Percentage (%)': missing_percent.round(2)
        })
        missing_data = missing_data[missing_data['Missing Values'] > 0].sort_values('Missing Values', ascending=False)
        if not missing_data.empty:
            st.dataframe(missing_data.style.background_gradient(cmap='Reds', subset=['Percentage (%)'])
                         .format({'Percentage (%)': '{:.2f}%'}), use_container_width=True)
        else:
            st.success("No missing values found in the dataset! 🎉")
    with col2:
        if not missing_data.empty:
            # Create a horizontal bar chart for missing values
            fig = px.bar(missing_data,
                         x='Percentage (%)',
                         y=missing_data.index,
                         orientation='h',
                         color='Percentage (%)',
                         color_continuous_scale='Reds',
                         title='Missing Values by Column')
            fig.update_layout(
                # Grow the chart with the number of affected columns.
                height=max(350, len(missing_data) * 30),
                xaxis_title='Missing (%)',
                yaxis_title='',
                coloraxis_showscale=False,
                margin=dict(l=0, r=10, t=30, b=0)
            )
            st.plotly_chart(fig, use_container_width=True)
    st.markdown('
    ', unsafe_allow_html=True)
    st.markdown('', unsafe_allow_html=True)
def display_distribution_tab():
    """Render the Data Distribution tab: histograms, box/violin plots, and
    scatter/heatmap distribution comparisons for user-selected columns."""
    # NOTE(review): the HTML markup inside several st.markdown literals appears
    # to have been stripped from this copy (some quotes are left unterminated)
    # — restore the original markup from version control.
    st.markdown('', unsafe_allow_html=True)
    st.markdown('
    📈 Data Distribution
    ', unsafe_allow_html=True)
    # Guard: the tab needs a loaded dataframe in session state.
    if 'df' not in st.session_state:
        st.error("No dataset loaded. Please upload a CSV file.")
        st.markdown('', unsafe_allow_html=True)
        return
    df = st.session_state.df
    # Chart-type and column filters.
    st.markdown('', unsafe_allow_html=True)
    col1, col2 = st.columns([1, 1])
    with col1:
        chart_type = st.selectbox(
            "Select Chart Type",
            ["Histogram", "Box Plot", "Violin Plot", "Distribution Plot"],
            key="chart_type_select"
        )
    with col2:
        if chart_type != "Distribution Plot":
            # Histogram/Box/Violin all use numerical columns, so this always
            # evaluates to "Numerical" for the current chart list.
            column_type = "Numerical" if chart_type in ["Histogram", "Box Plot", "Violin Plot"] else "Categorical"
            columns_to_show = list(df.select_dtypes(include=['number']).columns) if column_type == "Numerical" else list(df.select_dtypes(include=['object', 'category']).columns)
            selected_columns = st.multiselect(
                f"Select {column_type} Columns to Visualize",
                options=columns_to_show,
                default=list(columns_to_show[:min(3, len(columns_to_show))]),  # default: first three
                key="column_select"
            )
        else:
            num_cols = list(df.select_dtypes(include=['number']).columns)  # Index -> list for multiselect
            selected_columns = st.multiselect(
                "Select Numerical Columns",
                options=num_cols,
                default=list(num_cols[:min(3, len(num_cols))]),  # default: first three
                key="column_select"
            )
    st.markdown('
    ', unsafe_allow_html=True)
    # Render the chosen chart type for each selected column.
    if selected_columns:
        st.markdown('', unsafe_allow_html=True)
        if chart_type == "Histogram":
            col1, col2 = st.columns([3, 1])
            with col2:
                bins = st.slider("Number of bins", min_value=5, max_value=100, value=30, key="hist_bins")
                kde = st.checkbox("Show KDE", value=True, key="show_kde")
            with col1:
                pass
            # One histogram (plus summary statistics) per selected column.
            for column in selected_columns:
                st.markdown(f'
                {column}
                ', unsafe_allow_html=True)
                fig = px.histogram(df, x=column, nbins=bins,
                                   title=f"Histogram of {column}",
                                   marginal="box" if kde else None,
                                   color_discrete_sequence=['rgba(99, 102, 241, 0.7)'])
                fig.update_layout(
                    template="plotly_white",
                    height=400,
                    margin=dict(l=10, r=10, t=40, b=10),
                    xaxis_title=column,
                    yaxis_title="Frequency",
                    bargap=0.1
                )
                st.plotly_chart(fig, use_container_width=True)
                # Show basic statistics below the chart.
                stats = df[column].describe().to_dict()
                st.markdown(f"""
                Mean: {stats['mean']:.2f}
                Median: {stats['50%']:.2f}
                Std Dev: {stats['std']:.2f}
                Min: {stats['min']:.2f}
                Max: {stats['max']:.2f}
                """, unsafe_allow_html=True)
                st.markdown('
                ', unsafe_allow_html=True)
        elif chart_type == "Box Plot":
            for column in selected_columns:
                st.markdown(f'
                {column}
                ', unsafe_allow_html=True)
                fig = px.box(df, y=column, title=f"Box Plot of {column}",
                             color_discrete_sequence=['rgba(99, 102, 241, 0.7)'])
                fig.update_layout(
                    template="plotly_white",
                    height=400,
                    margin=dict(l=10, r=10, t=40, b=10),
                    yaxis_title=column
                )
                st.plotly_chart(fig, use_container_width=True)
                # Outlier summary using the 1.5*IQR fences.
                q1 = df[column].quantile(0.25)
                q3 = df[column].quantile(0.75)
                iqr = q3 - q1
                lower_bound = q1 - 1.5 * iqr
                upper_bound = q3 + 1.5 * iqr
                outliers = df[(df[column] < lower_bound) | (df[column] > upper_bound)][column]
                st.markdown(f"""
                Q1 (25%): {q1:.2f}
                Median: {df[column].median():.2f}
                Q3 (75%): {q3:.2f}
                IQR: {iqr:.2f}
                Outliers: {len(outliers)} ({(len(outliers)/len(df)*100):.2f}%)
                """, unsafe_allow_html=True)
                st.markdown('
                ', unsafe_allow_html=True)
        elif chart_type == "Violin Plot":
            for column in selected_columns:
                st.markdown(f'
                {column}
                ', unsafe_allow_html=True)
                fig = px.violin(df, y=column, box=True, points="all", title=f"Violin Plot of {column}",
                                color_discrete_sequence=['rgba(99, 102, 241, 0.7)'])
                fig.update_layout(
                    template="plotly_white",
                    height=400,
                    margin=dict(l=10, r=10, t=40, b=10),
                    yaxis_title=column
                )
                # Shrink/fade the per-point markers so they don't dominate.
                fig.update_traces(marker=dict(size=3, opacity=0.5))
                st.plotly_chart(fig, use_container_width=True)
                st.markdown('', unsafe_allow_html=True)
        elif chart_type == "Distribution Plot":
            if len(selected_columns) >= 2:
                st.markdown('
                ', unsafe_allow_html=True)
                chart_options = st.radio(
                    "Select Distribution Plot Type",
                    ["Scatter Plot", "Correlation Heatmap"],
                    horizontal=True
                )
                if chart_options == "Scatter Plot":
                    col1, col2 = st.columns([3, 1])
                    with col2:
                        x_axis = st.selectbox("X-axis", options=selected_columns, index=0)
                        y_axis = st.selectbox("Y-axis", options=selected_columns, index=min(1, len(selected_columns)-1))
                        color_option = st.selectbox("Color by", options=["None"] + df.columns.tolist())
                    with col1:
                        # Only pass `color` when the user picked a real column.
                        if color_option != "None":
                            fig = px.scatter(df, x=x_axis, y=y_axis,
                                             color=color_option,
                                             title=f"{y_axis} vs {x_axis} (colored by {color_option})",
                                             opacity=0.7,
                                             marginal_x="histogram", marginal_y="histogram")
                        else:
                            fig = px.scatter(df, x=x_axis, y=y_axis,
                                             title=f"{y_axis} vs {x_axis}",
                                             opacity=0.7,
                                             marginal_x="histogram", marginal_y="histogram")
                        fig.update_layout(
                            template="plotly_white",
                            height=600,
                            margin=dict(l=10, r=10, t=40, b=10),
                        )
                        st.plotly_chart(fig, use_container_width=True)
                elif chart_options == "Correlation Heatmap":
                    # Pairwise correlations restricted to the selected columns.
                    corr_matrix = df[selected_columns].corr()
                    fig = px.imshow(corr_matrix,
                                    text_auto=".2f",
                                    color_continuous_scale="RdBu_r",
                                    zmin=-1, zmax=1,
                                    title="Correlation Heatmap")
                    fig.update_layout(
                        template="plotly_white",
                        height=600,
                        margin=dict(l=10, r=10, t=40, b=10),
                    )
                    st.plotly_chart(fig, use_container_width=True)
                    # List the five largest (signed, not absolute) correlations.
                    corr_df = corr_matrix.stack().reset_index()
                    corr_df.columns = ['Variable 1', 'Variable 2', 'Correlation']
                    corr_df = corr_df[corr_df['Variable 1'] != corr_df['Variable 2']]
                    corr_df = corr_df.sort_values('Correlation', ascending=False).head(5)
                    st.markdown("##### Top 5 Highest Correlations")
                    st.dataframe(corr_df.style.background_gradient(cmap='Blues')
                                 .format({'Correlation': '{:.2f}'}), use_container_width=True)
                    st.markdown('
                    ', unsafe_allow_html=True)
            else:
                st.warning("Please select at least 2 numerical columns to see distribution plots")
            st.markdown('
            ', unsafe_allow_html=True)
    else:
        st.info("Please select at least one column to visualize")
    st.markdown('', unsafe_allow_html=True)
def display_ai_insights_tab():
    """Render the AI Insights tab: cached insights in expanders, plus
    generate/regenerate controls wired to the `loading_insights` flag."""
    # NOTE(review): the HTML markup inside several st.markdown literals appears
    # to have been stripped from this copy (some quotes are left unterminated)
    # — restore the original markup from version control. Also "🧠" /
    # "ðŸ§" look like mojibake for the 🧠 emoji; confirm file encoding.
    st.markdown('', unsafe_allow_html=True)
    st.markdown('
    🧠AI-Generated Insights
    ', unsafe_allow_html=True)
    # Guard: requires a loaded dataframe.
    if 'df' not in st.session_state:
        st.error("No dataset loaded. Please upload a CSV file.")
        st.markdown('', unsafe_allow_html=True)
        return
    if st.session_state.get('loading_insights', False):
        with st.spinner("Generating AI insights about your data..."):
            st.markdown('', unsafe_allow_html=True)
            time.sleep(0.1)  # Small delay to ensure UI updates
    # AI insights section
    if 'ai_insights' in st.session_state and st.session_state.ai_insights and len(st.session_state.ai_insights) > 0:
        insights = st.session_state.ai_insights
        st.markdown('', unsafe_allow_html=True)
        # One expander per insight category; first two expanded by default.
        for i, (category, insight_list) in enumerate(insights.items()):
            with st.expander(f"{category}", expanded=i < 2):
                st.markdown('
                ', unsafe_allow_html=True)
                # LLM output is one long string; template output is a list of
                # short strings — render each form differently.
                if len(insight_list) == 1 and isinstance(insight_list[0], str) and len(insight_list[0]) > 100:
                    st.markdown(insight_list[0])
                else:
                    # NOTE(review): the f-string below no longer interpolates
                    # `insight` (content appears lost) — as written each
                    # iteration renders an empty block. Confirm against VCS.
                    for insight in insight_list:
                        st.markdown(f"""
                        """, unsafe_allow_html=True)
                st.markdown('
                ', unsafe_allow_html=True)
        st.markdown('
        ', unsafe_allow_html=True)
        # Regenerate control: clear the cache and flag a reload.
        st.markdown('', unsafe_allow_html=True)
        if st.button("Regenerate Insights", key="regenerate_insights"):
            st.session_state['loading_insights'] = True
            st.session_state['ai_insights'] = None
            logger.info("User requested regeneration of AI insights")
            st.rerun()
        st.markdown('
        ', unsafe_allow_html=True)
    else:
        if not st.session_state.get('loading_insights', False):
            # Show generate button if insights are not loading and not available
            st.markdown('', unsafe_allow_html=True)
            st.markdown("""
            ðŸ§
            Generate AI-powered insights about your dataset to discover patterns, anomalies, and suggestions for feature engineering.
            """, unsafe_allow_html=True)
            if st.button("Generate Insights", key="generate_insights"):
                st.session_state['loading_insights'] = True
                logger.info("User initiated AI insights generation")
                st.rerun()
            st.markdown('
            ', unsafe_allow_html=True)
    st.markdown('', unsafe_allow_html=True)
def display_welcome_page():
    """Display a welcome page describing the application and how to use it.

    Built from native Streamlit components (title/columns/markdown) rather
    than raw HTML. BUGFIX: the section-header emoji were mojibake — UTF-8
    bytes decoded as Latin-1 (e.g. "📊", "🧠", "âš¡") — and are
    restored here to the intended characters.
    """
    st.title("Welcome to AI-Powered EDA & Feature Engineering Assistant")
    st.write("""
    Upload your CSV dataset and leverage the power of AI to analyze, visualize, and improve your data.
    This tool helps you understand your data better and prepare it for machine learning models.
    """)
    # Feature cards laid out in a two-column grid.
    st.subheader("Key Features")
    col1, col2 = st.columns(2)
    with col1:
        st.markdown("#### 📊 Exploratory Data Analysis")
        st.write("Quickly understand your dataset with automatic statistical analysis and visualizations")
        st.markdown("#### 🧠 AI-Powered Insights")
        st.write("Get intelligent recommendations about patterns, anomalies, and opportunities in your data")
        st.markdown("#### ⚡ Feature Engineering")
        st.write("Transform and enhance your features to improve machine learning model performance")
    with col2:
        st.markdown("#### 📈 Interactive Visualizations")
        st.write("Explore distributions, relationships, and outliers with dynamic charts")
        st.markdown("#### 💬 Chat Interface")
        st.write("Ask questions about your data and get AI-powered answers in natural language")
        st.markdown("#### 🔄 Data Transformation")
        st.write("Clean, transform, and prepare your data for modeling with guided workflows")
    # Quick-start instructions.
    st.subheader("How to use")
    st.markdown("""
    1. **Upload** your CSV dataset using the sidebar on the left
    2. **Explore** automatically generated statistics and visualizations
    3. **Generate** AI insights to better understand your data
    4. **Chat** with AI to ask specific questions about your dataset
    5. **Transform** your features based on recommendations
    """)
    # Upload prompt
    st.info("👈 Please upload a CSV file using the sidebar to get started")
def display_relationships_tab():
    """Display correlations and relationships between variables"""
    # NOTE(review): the HTML markup inside the markdown literals below appears
    # to have been stripped from this copy (one quote is left unterminated) —
    # restore the original markup from version control.
    st.markdown('', unsafe_allow_html=True)
    st.markdown('
    🔄 Relationships & Correlations
    ', unsafe_allow_html=True)
    # Guard: the tab needs a loaded dataframe in session state.
    if 'df' not in st.session_state or st.session_state.df is None:
        st.error("No dataset loaded. Please upload a CSV file.")
        st.markdown('', unsafe_allow_html=True)
        return
    df = st.session_state.df
    # Select numerical columns for correlation analysis
    num_cols = df.select_dtypes(include=['number']).columns
    if len(num_cols) < 2:
        st.warning("At least 2 numerical columns are needed for correlation analysis.")
        st.markdown('', unsafe_allow_html=True)
        return
    # Correlation matrix heatmap
    st.subheader("Correlation Matrix")
    corr_matrix = df[num_cols].corr()
    fig = px.imshow(
        corr_matrix,
        text_auto=".2f",
        color_continuous_scale="RdBu_r",
        zmin=-1, zmax=1,
        aspect="auto",
        title="Correlation Heatmap"
    )
    fig.update_layout(
        height=600,
        width=800,
        title_font_size=20,
        margin=dict(l=10, r=10, t=30, b=10)
    )
    st.plotly_chart(fig, use_container_width=True)
    # Show top correlations
    st.subheader("Top Correlations")
    # Lower triangle only (j < i): each unordered pair appears exactly once
    # and the diagonal is excluded.
    corr_pairs = []
    for i in range(len(num_cols)):
        for j in range(i):
            corr_pairs.append({
                'Feature 1': num_cols[i],
                'Feature 2': num_cols[j],
                'Correlation': corr_matrix.iloc[i, j]
            })
    # Rank by absolute correlation strength; keep the strongest ten.
    corr_df = pd.DataFrame(corr_pairs)
    sorted_corr = corr_df.sort_values('Correlation', key=abs, ascending=False).head(10)
    # Show table with styled background
    st.dataframe(
        sorted_corr.style.background_gradient(cmap='RdBu_r', subset=['Correlation'])
        .format({'Correlation': '{:.3f}'}),
        use_container_width=True
    )
    # Scatter plot matrix
    st.subheader("Scatter Plot Matrix")
    # Convert num_cols to a list before using it in multiselect
    num_cols = list(df.select_dtypes(include=['number']).columns)
    # Default selection: first four numeric columns (or fewer).
    selected_cols = st.multiselect(
        "Select columns for scatter plot matrix (max 5 recommended)",
        options=num_cols,
        default=list(num_cols[:min(4, len(num_cols))])
    )
    if selected_cols:
        if len(selected_cols) > 5:
            st.warning("More than 5 columns may make the plot hard to read.")
        color_col = st.selectbox("Color by", options=["None"] + df.columns.tolist())
        # Only pass the color parameter if not "None"
        if color_col != "None":
            fig = px.scatter_matrix(
                df,
                dimensions=selected_cols,
                color=color_col,
                opacity=0.7,
                title="Scatter Plot Matrix"
            )
        else:
            fig = px.scatter_matrix(
                df,
                dimensions=selected_cols,
                opacity=0.7,
                title="Scatter Plot Matrix"
            )
        fig.update_layout(
            height=700,
            title_font_size=18,
            margin=dict(l=10, r=10, t=30, b=10)
        )
        st.plotly_chart(fig, use_container_width=True)
    st.markdown('', unsafe_allow_html=True)
def _build_dataset_info(df):
    """Summarize *df* into the dict structure the LLM prompt expects.

    Returns a dict with the dataset shape, column names, dtypes, per-column
    missing-value (count, percent) pairs, ``describe()`` output, the top
    strong pairwise correlations, and a small sample of rows.
    """
    num_rows, num_cols = df.shape
    # Missing values: {column: (count, percent)} for columns that have any.
    # Compute the null counts once instead of re-summing per column.
    null_counts = df.isnull().sum()
    missing_values = {
        col: (int(count), round(count / len(df) * 100, 2))
        for col, count in null_counts.items()
        if count > 0
    }
    # Strong pairwise correlations (|r| > 0.5) among numerical columns,
    # formatted as "col1 and col2: r" lines, top 5 by absolute value.
    numeric_cols = df.select_dtypes(include=['number']).columns
    correlations = "No numerical columns to calculate correlations."
    if len(numeric_cols) > 1:
        corr_matrix = df[numeric_cols].corr()
        corr_pairs = []
        for i in range(len(numeric_cols)):
            for j in range(i):  # lower triangle only — each pair once
                val = corr_matrix.iloc[i, j]
                if abs(val) > 0.5:  # only show strong correlations
                    corr_pairs.append((numeric_cols[i], numeric_cols[j], val))
        if corr_pairs:
            corr_pairs.sort(key=lambda pair: abs(pair[2]), reverse=True)
            correlations = "\n".join(
                f"{col1} and {col2}: {val:.3f}"
                for col1, col2, val in corr_pairs[:5]
            )
    return {
        "shape": f"{num_rows} rows, {num_cols} columns",
        "columns": df.columns.tolist(),
        "dtypes": {col: str(dtype) for col, dtype in df.dtypes.items()},
        "missing_values": missing_values,
        "basic_stats": df.describe().to_string(),
        "correlations": correlations,
        "sample_data": df.head(5).to_string(),
    }


def _llm_response(df, user_message):
    """Ask the LLM about *df*; return a cleaned reply.

    Raises on any failure (LLM unavailable, empty/short/invalid response)
    so the caller can fall back to template answers.
    """
    if llm_inference is None:
        raise Exception("LLM not available")
    dataset_info = _build_dataset_info(df)
    logger.info(f"Sending question to LLM with memory: {user_message}")
    if len(st.session_state.chat_history) > 1 and "conversation_memory" in st.session_state:
        # Later turns: use the memory-enabled path to keep conversation context.
        response = llm_inference.answer_with_memory(
            user_message,
            dataset_info,
            st.session_state.conversation_memory,
        )
    else:
        # First exchange: plain Q&A, then seed the memory with it.
        response = llm_inference.answer_dataset_question(user_message, dataset_info)
        if "conversation_memory" in st.session_state:
            st.session_state.conversation_memory.save_context(
                {"input": user_message},
                {"output": response},
            )
    # Log the raw response for debugging.
    logger.info(f"Raw LLM response: {response[:100]}...")
    if response and isinstance(response, str) and len(response) > 10:
        return response.strip()
    logger.warning(f"LLM response too short or invalid: {response}")
    raise Exception("LLM response too short or invalid")


def _template_response(df, user_message):
    """Keyword-matched canned answer, used when the LLM path fails."""
    null_counts = df.isnull().sum()
    responses = {
        "missing": f"I found {null_counts.sum()} missing values across the dataset. The columns with the most missing values are: {null_counts.sort_values(ascending=False).head(3).index.tolist()}.",
        "pattern": "Looking at the data, I can see several interesting patterns. The numerical features show varied distributions, and there might be some correlations worth exploring further.",
        "feature": "Based on the data, I'd recommend feature engineering steps like handling missing values, encoding categorical variables, and possibly creating interaction terms for highly correlated features.",
        "distribution": f"The numerical variables show different distributions. Some appear to be normally distributed while others show skewness. Let me know if you want to see visualizations for specific columns.",
        "correlation": "I detected several strong correlations in the dataset. You might want to look at the correlation heatmap in the Relationships tab for more details.",
        "prepare": "To prepare this data for modeling, I suggest: 1) Handling missing values, 2) Encoding categorical variables, 3) Feature scaling, and 4) Possibly dimensionality reduction if you have many features."
    }
    # Ordered first-match keyword dispatch (order preserved from the
    # original if/elif chain); the message is lowercased once.
    keyword_map = [
        (("missing",), "missing"),
        (("pattern",), "pattern"),
        (("feature", "engineering"), "feature"),
        (("distribut",), "distribution"),
        (("correlat", "relation"), "correlation"),
        (("prepare", "model"), "prepare"),
    ]
    msg = user_message.lower()
    for keywords, key in keyword_map:
        if any(kw in msg for kw in keywords):
            return responses[key]
    # Generic response when nothing matched.
    return "I analyzed your dataset and found some interesting insights. You can explore different aspects of your data using the tabs above. Is there anything specific you'd like to know about your data?"


def process_chat_message(user_message):
    """Process a user message in the chat interface.

    Appends the user message to ``st.session_state.chat_history``, tries to
    answer via the LLM (with conversation memory when available), and falls
    back to keyword-matched template responses on any failure.  The reply is
    appended to the chat history; nothing is returned.
    """
    # Add user message to chat history
    st.session_state.chat_history.append({"role": "user", "content": user_message})

    df = st.session_state.get('df')
    if df is not None:
        try:
            reply = _llm_response(df, user_message)
            st.session_state.chat_history.append({"role": "assistant", "content": reply})
            return
        except Exception as e:
            logger.warning(f"Error using LLM for chat response: {str(e)}. Falling back to templates.")
            # Fall back happens below

    # If we're here, either there's no dataframe, the LLM failed, or the
    # response was invalid — use template-based responses as fallback.
    if df is not None:
        response = _template_response(df, user_message)
    else:
        response = "Please upload a dataset first so I can analyze it and answer your questions."
    # Add AI response to chat history
    st.session_state.chat_history.append({"role": "assistant", "content": response})
def main():
    """Main function to run the application.

    Renders the sidebar (file upload / example datasets / settings),
    tab-style navigation, the content for the selected tab, and finally
    the deferred AI-insight generation step.
    """
    # Initialize session state at the beginning
    initialize_session_state()
    # Apply CSS styling
    apply_custom_css()

    uploaded_file = None
    # Sidebar for file upload and settings
    with st.sidebar:
        st.markdown('', unsafe_allow_html=True)
        # File uploader.
        # NOTE(review): `uploaded_file` was referenced below but never
        # assigned in this function (NameError); restoring the uploader
        # widget here — confirm label/placement against the original UI.
        uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"])
        st.markdown('', unsafe_allow_html=True)
        # Load example dataset
        with st.expander("Or use an example dataset"):
            example_datasets = {
                "Iris": "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv",
                "Tips": "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/tips.csv",
                "Titanic": "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/titanic.csv",
                "Diamonds": "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/diamonds.csv"
            }
            selected_example = st.selectbox("Select example dataset", list(example_datasets.keys()))
            if st.button("Load Example", key="load_example_btn"):
                try:
                    # Load the selected example dataset
                    df = pd.read_csv(example_datasets[selected_example])
                    # Verify we have a valid dataframe
                    if df is not None and not df.empty:
                        st.session_state['df'] = df
                        st.session_state['descriptive_stats'] = df.describe()
                        st.session_state['dataset_name'] = selected_example
                        st.success(f"Loaded {selected_example} dataset!")
                    else:
                        st.error(f"The {selected_example} dataset appears to be empty.")
                except Exception as e:
                    st.error(f"Error loading example dataset: {str(e)}")
        # Only show these sections if a dataset is loaded
        if 'df' in st.session_state:
            # Dataset Info
            st.markdown('', unsafe_allow_html=True)
            # Column filters
            st.markdown('', unsafe_allow_html=True)
            # Feature Engineering options with Streamlit buttons instead of JavaScript
            st.markdown('', unsafe_allow_html=True)

    # If data is uploaded, process it (only when no dataset is loaded yet)
    if uploaded_file is not None and st.session_state.get('df') is None:
        try:
            # Attempt to read the CSV file
            df = pd.read_csv(uploaded_file)
            # Verify that we have a valid dataframe before storing in session state
            if df is not None and not df.empty:
                st.session_state['df'] = df
                st.session_state['descriptive_stats'] = df.describe()
                st.session_state['dataset_name'] = uploaded_file.name
                st.success(f"Successfully loaded dataset: {uploaded_file.name}")
            else:
                st.error("The uploaded file appears to be empty.")
        except Exception as e:
            st.error(f"Error reading CSV file: {str(e)}")

    # Create navigation tabs using Streamlit
    st.write("### Navigation")
    tabs = ["Overview", "Distribution", "Relationships", "AI Insights", "Chat"]
    # Create columns for each tab
    cols = st.columns(len(tabs))
    # Handle tab selection using Streamlit buttons
    for i, tab in enumerate(tabs):
        with cols[i]:
            if st.button(tab, key=f"tab_{tab.lower()}"):
                st.session_state['selected_tab'] = f"tab-{tab.lower().replace(' ', '-')}"
                st.rerun()

    # Show selected tab indicator.
    # (Fixed: this f-string literal was split across two physical lines,
    # which is a syntax error.)
    selected_tab_name = st.session_state['selected_tab'].replace('tab-', '').replace('-', ' ').title()
    st.markdown(f"Selected: {selected_tab_name}", unsafe_allow_html=True)

    # Show welcome message if no data is uploaded
    if 'df' not in st.session_state:
        display_welcome_page()
    else:
        # Display content based on selected tab
        selected = st.session_state['selected_tab']
        if selected == 'tab-overview':
            display_descriptive_tab()
        elif selected == 'tab-distribution':
            display_distribution_tab()
        elif selected == 'tab-relationships':
            display_relationships_tab()
        elif selected in ('tab-ai-insights', 'tab-ai'):
            display_ai_insights_tab()
        elif selected == 'tab-chat':
            display_chat_interface()

    # After all tabs are rendered, check if we have a regenerate action.
    # This is processed at the end to avoid session state changes during
    # rendering.
    if (st.session_state.get('loading_insights', False)
            and st.session_state.get('ai_insights') is None):
        logger.info("Generating AI insights at end of main function")
        try:
            st.session_state['ai_insights'] = generate_ai_insights()
            logger.info(f"Generated insights: {len(st.session_state['ai_insights'])} categories")
            st.session_state['loading_insights'] = False
        except Exception as e:
            logger.error(f"Error generating insights in main function: {str(e)}")
            st.session_state['loading_insights'] = False
            st.session_state['ai_insights'] = {}  # Set to empty dict to prevent repeated failures
        finally:
            st.rerun()
# Script entry point: run the Streamlit app when executed directly.
if __name__ == "__main__":
    main()