import matplotlib.pyplot as plt import seaborn as sns import plotly.express as px import plotly.graph_objects as go import pandas as pd import numpy as np class VisualizationEngine: def __init__(self): plt.style.use('seaborn-v0_8') self.color_palette = sns.color_palette("husl", 8) def create_visualizations(self, df, selected_features): """Create various visualizations based on selected features""" plots = [] if not selected_features: selected_features = df.columns[:4] # Default to first 4 columns for feature in selected_features: if feature in df.columns and feature != 'ID': if df[feature].dtype in ['int64', 'float64']: # Numerical feature visualizations plots.extend(self._create_numerical_plots(df, feature)) else: # Categorical feature visualizations plots.extend(self._create_categorical_plots(df, feature)) # Create comparison plots if len(selected_features) >= 2: plots.extend(self._create_comparison_plots(df, selected_features)) return plots def _create_numerical_plots(self, df, feature): """Create plots for numerical features""" plots = [] # Histogram plt.figure(figsize=(10, 6)) plt.hist(df[feature], bins=30, alpha=0.7, color=self.color_palette[0], edgecolor='black') plt.title(f'{feature} Distribution') plt.xlabel(feature) plt.ylabel('Frequency') plt.grid(True, alpha=0.3) plt.tight_layout() plot_name = f'{feature.lower().replace(" ", "_")}_histogram.png' plt.savefig(plot_name, dpi=300, bbox_inches='tight') plots.append(plot_name) plt.close() # Box plot plt.figure(figsize=(8, 6)) plt.boxplot(df[feature], patch_artist=True, boxprops=dict(facecolor=self.color_palette[1])) plt.title(f'{feature} Box Plot') plt.ylabel(feature) plt.grid(True, alpha=0.3) plt.tight_layout() plot_name = f'{feature.lower().replace(" ", "_")}_boxplot.png' plt.savefig(plot_name, dpi=300, bbox_inches='tight') plots.append(plot_name) plt.close() # Density plot plt.figure(figsize=(10, 6)) df[feature].plot(kind='density', color=self.color_palette[2], linewidth=2) plt.title(f'{feature} Density Plot') plt.xlabel(feature) plt.ylabel('Density') plt.grid(True, alpha=0.3) plt.tight_layout() plot_name = f'{feature.lower().replace(" ", "_")}_density.png' plt.savefig(plot_name, dpi=300, bbox_inches='tight') plots.append(plot_name) plt.close() return plots def _create_categorical_plots(self, df, feature): """Create plots for categorical features""" plots = [] value_counts = df[feature].value_counts() # Bar plot plt.figure(figsize=(12, 6)) bars = plt.bar(value_counts.index, value_counts.values, color=self.color_palette[:len(value_counts)]) plt.title(f'{feature} Distribution') plt.xlabel(feature) plt.ylabel('Count') plt.xticks(rotation=45) # Add value labels on bars for bar in bars: height = bar.get_height() plt.text(bar.get_x() + bar.get_width()/2., height, f'{int(height)}', ha='center', va='bottom') plt.tight_layout() plot_name = f'{feature.lower().replace(" ", "_")}_barplot.png' plt.savefig(plot_name, dpi=300, bbox_inches='tight') plots.append(plot_name) plt.close() # Pie chart plt.figure(figsize=(10, 8)) plt.pie(value_counts.values, labels=value_counts.index, autopct='%1.1f%%', colors=self.color_palette[:len(value_counts)]) plt.title(f'{feature} Distribution (Pie Chart)') plt.tight_layout() plot_name = f'{feature.lower().replace(" ", "_")}_piechart.png' plt.savefig(plot_name, dpi=300, bbox_inches='tight') plots.append(plot_name) plt.close() return plots def _create_comparison_plots(self, df, features): """Create comparison plots between features""" plots = [] numeric_features = [f for f in features if df[f].dtype in ['int64', 'float64']] categorical_features = [f for f in features if df[f].dtype in ['object', 'category']] # Scatter plots for numeric features if len(numeric_features) >= 2: for i in range(len(numeric_features)): for j in range(i+1, len(numeric_features)): plt.figure(figsize=(10, 8)) plt.scatter(df[numeric_features[i]], df[numeric_features[j]], alpha=0.6, color=self.color_palette[0]) plt.xlabel(numeric_features[i]) plt.ylabel(numeric_features[j]) plt.title(f'{numeric_features[i]} vs {numeric_features[j]}') plt.grid(True, alpha=0.3) plt.tight_layout() plot_name = f'{numeric_features[i].lower().replace(" ", "_")}_vs_{numeric_features[j].lower().replace(" ", "_")}_scatter.png' plt.savefig(plot_name, dpi=300, bbox_inches='tight') plots.append(plot_name) plt.close() # Box plots for numeric vs categorical if numeric_features and categorical_features: for num_feat in numeric_features[:2]: # Limit to avoid too many plots for cat_feat in categorical_features[:2]: plt.figure(figsize=(12, 8)) df.boxplot(column=num_feat, by=cat_feat, ax=plt.gca()) plt.title(f'{num_feat} by {cat_feat}') plt.suptitle('') # Remove default title plt.xticks(rotation=45) plt.tight_layout() plot_name = f'{num_feat.lower().replace(" ", "_")}_by_{cat_feat.lower().replace(" ", "_")}_boxplot.png' plt.savefig(plot_name, dpi=300, bbox_inches='tight') plots.append(plot_name) plt.close() # Correlation heatmap for numeric features if len(numeric_features) >= 2: plt.figure(figsize=(10, 8)) correlation_matrix = df[numeric_features].corr() sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0, square=True, linewidths=0.5) plt.title('Feature Correlation Matrix') plt.tight_layout() plot_name = 'selected_features_correlation.png' plt.savefig(plot_name, dpi=300, bbox_inches='tight') plots.append(plot_name) plt.close() return plots def create_interactive_plots(self, df, features): """Create interactive Plotly visualizations""" plots = [] for feature in features: if feature in df.columns and feature != 'ID': if df[feature].dtype in ['int64', 'float64']: # Interactive histogram fig = px.histogram(df, x=feature, title=f'{feature} Distribution') fig.write_html(f'{feature.lower().replace(" ", "_")}_interactive_hist.html') plots.append(f'{feature.lower().replace(" ", "_")}_interactive_hist.html') else: # Interactive bar chart value_counts = df[feature].value_counts() fig = px.bar(x=value_counts.index, y=value_counts.values, title=f'{feature} Distribution') fig.write_html(f'{feature.lower().replace(" ", "_")}_interactive_bar.html') plots.append(f'{feature.lower().replace(" ", "_")}_interactive_bar.html') return plots