import matplotlib.pyplot as plt import seaborn as sns import numpy as np from scipy import stats import pandas as pd from sklearn.metrics import confusion_matrix, roc_curve, auc from sklearn.cluster import KMeans from sklearn.decomposition import PCA from sklearn.manifold import TSNE from scipy.cluster.hierarchy import dendrogram, linkage import logging # Configure logging for this module logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # Consistent theme settings for plots FIG_SIZE = (10, 6) TITLE_FONT_SIZE = 14 LABEL_FONT_SIZE = 12 LEGEND_FONT_SIZE = 10 PRIMARY_COLOR = "#4C72B0" # A nice blue SECONDARY_COLOR = "#55A868" # A nice green def plot_histogram(df, col): """Generates a histogram for a given numeric column. Args: df (pd.DataFrame): The input DataFrame. col (str): The name of the numeric column to plot. Returns: tuple: A matplotlib Figure object and an error message (None if successful). """ logging.info(f"Generating histogram for column: {col}") if col not in df.columns: logging.error(f"Column '{col}' not found for histogram.") return None, f"Column '{col}' not found." if not pd.api.types.is_numeric_dtype(df[col]): logging.error(f"Column '{col}' is not numeric for histogram.") return None, "Histogram is only for numeric columns." plt.figure(figsize=FIG_SIZE) sns.set_style("whitegrid") # Calculate optimal bin width using Freedman-Diaconis rule try: iqr = df[col].quantile(0.75) - df[col].quantile(0.25) if iqr > 0: bin_width = 2 * iqr / (len(df[col]) ** (1/3)) bins = int((df[col].max() - df[col].min()) / bin_width) if bin_width > 0 else 25 else: bins = 25 # Default if IQR is zero except Exception as e: logging.warning(f"Could not calculate optimal bins for {col}: {e}. Using default 25 bins.") bins = 25 ax = sns.histplot(df[col], kde=True, bins=bins, color=PRIMARY_COLOR, edgecolor='black', line_kws={'linewidth': 2, 'linestyle': '--'}) # Add mean and median lines mean_val = df[col].mean() median_val = df[col].median() ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_val:.2f}') ax.axvline(median_val, color='green', linestyle='-', linewidth=2, label=f'Median: {median_val:.2f}') skewness = df[col].skew() plt.title(f'Distribution of {col} (Skewness: {skewness:.2f})', fontsize=TITLE_FONT_SIZE, weight='bold') plt.xlabel(col, fontsize=LABEL_FONT_SIZE) plt.ylabel('Density', fontsize=LABEL_FONT_SIZE) plt.legend(fontsize=LEGEND_FONT_SIZE) plt.tight_layout() logging.info(f"Histogram for {col} generated successfully.") return plt.gcf(), None def plot_bar(df, col): """Generates a bar plot for a given categorical or discrete numeric column. Args: df (pd.DataFrame): The input DataFrame. col (str): The name of the column to plot. Returns: tuple: A matplotlib Figure object and an error message (None if successful). """ logging.info(f"Generating bar plot for column: {col}") if col not in df.columns: logging.error(f"Column '{col}' not found for bar plot.") return None, f"Column '{col}' not found." plt.figure(figsize=FIG_SIZE) sns.set_style("whitegrid") counts = df[col].value_counts() # Handle too many categories by showing top N and grouping others if len(counts) > 15: logging.info(f"Column {col} has too many unique values ({len(counts)}). Showing top 14 and grouping others.") top_14 = counts.nlargest(14) other_sum = counts.nsmallest(len(counts) - 14).sum() top_14['Other'] = other_sum counts = top_14 ax = sns.barplot(y=counts.index.astype(str), x=counts.values, palette="viridis", orient='h') # Add count labels to bars for i, v in enumerate(counts.values): ax.text(v + 1, i, str(v), color='black', va='center', fontsize=10) plt.title(f'Frequency of {col}', fontsize=TITLE_FONT_SIZE, weight='bold') plt.xlabel('Count', fontsize=LABEL_FONT_SIZE) plt.ylabel(col, fontsize=LABEL_FONT_SIZE) plt.tight_layout() logging.info(f"Bar plot for {col} generated successfully.") return plt.gcf(), None def plot_scatter(df, col1, col2, color_col=None): """Generates a scatter plot between two numeric columns, with optional coloring. Args: df (pd.DataFrame): The input DataFrame. col1 (str): The name of the first numeric column (x-axis). col2 (str): The name of the second numeric column (y-axis). color_col (str, optional): The name of a column to use for coloring points. Returns: tuple: A matplotlib Figure object and an error message (None if successful). """ logging.info(f"Generating scatter plot for {col1} vs {col2}, colored by {color_col or 'None'}") if col1 not in df.columns or col2 not in df.columns: logging.error(f"One or both columns ({col1}, {col2}) not found for scatter plot.") return None, "One or both columns not found." if not pd.api.types.is_numeric_dtype(df[col1]) or not pd.api.types.is_numeric_dtype(df[col2]): logging.error(f"Columns {col1} or {col2} are not numeric for scatter plot.") return None, "Scatter plots are only available for numeric columns." if color_col and color_col != 'None' and color_col not in df.columns: logging.error(f"Color column '{color_col}' not found for scatter plot.") return None, f"Color column '{color_col}' not found." try: plt.figure(figsize=FIG_SIZE) sns.set_style("whitegrid") hue = color_col if color_col and color_col != 'None' else None plot_df = df.dropna(subset=[col1, col2]) # Drop NaNs for plotting sns.scatterplot(data=plot_df, x=col1, y=col2, hue=hue, palette="coolwarm", s=50, alpha=0.6) # Add a linear regression trend line if both columns are numeric if pd.api.types.is_numeric_dtype(df[col1]) and pd.api.types.is_numeric_dtype(df[col2]): # Ensure there's enough data for linear regression if len(plot_df) > 1: m, b, r_value, _, _ = stats.linregress(plot_df[col1], plot_df[col2]) x_line = np.array([plot_df[col1].min(), plot_df[col1].max()]) y_line = m * x_line + b plt.plot(x_line, y_line, color='red', linestyle='--', label=f'Trend Line (R² = {r_value**2:.2f})') plt.legend(fontsize=LEGEND_FONT_SIZE) else: logging.warning("Not enough data points for linear regression trend line.") plt.title(f'{col1} vs. {col2}', fontsize=TITLE_FONT_SIZE, weight='bold') plt.xlabel(col1, fontsize=LABEL_FONT_SIZE) plt.ylabel(col2, fontsize=LABEL_FONT_SIZE) plt.tight_layout() logging.info(f"Scatter plot for {col1} vs {col2} generated successfully.") return plt.gcf(), None except Exception as e: logging.error(f"An error occurred during scatter plot generation: {e}", exc_info=True) return None, f"An error occurred during plot generation: {e}" def plot_box(df, continuous_var, group_var): """Generates a box plot to show the distribution of a continuous variable across categories of a grouping variable. Args: df (pd.DataFrame): The input DataFrame. continuous_var (str): The name of the continuous numeric column. group_var (str): The name of the categorical or discrete column for grouping. Returns: tuple: A matplotlib Figure object and an error message (None if successful). """ logging.info(f"Generating box plot for {continuous_var} by {group_var}") if continuous_var not in df.columns or group_var not in df.columns: logging.error(f"One or both columns ({continuous_var}, {group_var}) not found for box plot.") return None, "One or both columns not found." if not pd.api.types.is_numeric_dtype(df[continuous_var]): logging.error(f"Column '{continuous_var}' is not numeric for box plot.") return None, "Box plots require a numeric column for the x-axis." plt.figure(figsize=FIG_SIZE) sns.set_style("whitegrid") # Order categories by median of the continuous variable order = df.groupby(group_var)[continuous_var].median().sort_values(ascending=False).index sns.boxplot(data=df, x=continuous_var, y=group_var, palette="Set2", order=order, orient='h') plt.title(f'{continuous_var} by {group_var}', fontsize=TITLE_FONT_SIZE, weight='bold') plt.xlabel(continuous_var, fontsize=LABEL_FONT_SIZE) plt.ylabel(group_var, fontsize=LABEL_FONT_SIZE) plt.tight_layout() logging.info(f"Box plot for {continuous_var} by {group_var} generated successfully.") return plt.gcf(), None def plot_pie(df, col): """Generates a pie chart for a given categorical column. Args: df (pd.DataFrame): The input DataFrame. col (str): The name of the categorical column to plot. Returns: tuple: A matplotlib Figure object and an error message (None if successful). """ logging.info(f"Generating pie chart for column: {col}") if col not in df.columns: logging.error(f"Column '{col}' not found for pie chart.") return None, f"Column '{col}' not found." counts = df[col].value_counts() # Handle too many categories by showing top N and grouping others if len(counts) > 7: logging.info(f"Column {col} has too many unique values ({len(counts)}). Showing top 6 and grouping others.") top_6 = counts.nlargest(6) other_sum = counts.nsmallest(len(counts) - 6).sum() top_6['Other'] = other_sum counts = top_6 plt.figure(figsize=(8, 8)) # Pie charts often look better square explode = [0.03] * len(counts) # Slightly separate slices for better visual colors = sns.color_palette('pastel')[0:len(counts)] plt.pie(counts, labels=counts.index, autopct='%1.1f%%', startangle=90, explode=explode, colors=colors, pctdistance=0.85) centre_circle = plt.Circle((0,0),0.70,fc='white') # Donut chart effect fig = plt.gcf() fig.gca().add_artist(centre_circle) plt.title(f'Distribution of {col}', fontsize=TITLE_FONT_SIZE, weight='bold') plt.tight_layout() logging.info(f"Pie chart for {col} generated successfully.") return plt.gcf(), None def plot_heatmap(df): """Generates a correlation heatmap for all numeric columns in the DataFrame. Args: df (pd.DataFrame): The input DataFrame. Returns: tuple: A matplotlib Figure object and an error message (None if successful). """ logging.info("Generating correlation heatmap.") numeric_df = df.select_dtypes(include=np.number) if numeric_df.shape[1] < 2: logging.error("Not enough numeric columns for a heatmap.") return None, "Not enough numeric columns for a heatmap." plt.figure(figsize=(12, 10)) corr = numeric_df.corr() # Generate a mask for the upper triangle mask = np.triu(np.ones_like(corr, dtype=bool)) sns.heatmap(corr, mask=mask, annot=True, cmap='coolwarm', fmt=".2f", linewidths=.5, vmin=-1, vmax=1, annot_kws={"size": 8}) plt.title('Correlation Heatmap', fontsize=TITLE_FONT_SIZE, weight='bold') plt.xticks(rotation=45, ha='right', fontsize=LABEL_FONT_SIZE) plt.yticks(rotation=0, fontsize=LABEL_FONT_SIZE) plt.tight_layout() logging.info("Correlation heatmap generated successfully.") return plt.gcf(), None def plot_confusion_matrix(y_true, y_pred, class_names): """Generates a confusion matrix plot. Args: y_true (array-like): True labels. y_pred (array-like): Predicted labels. class_names (list): List of class names for labels. Returns: tuple: A matplotlib Figure object and an error message (None if successful). """ logging.info("Generating confusion matrix.") try: cm = confusion_matrix(y_true, y_pred) plt.figure(figsize=(8, 6)) sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names) plt.title('Confusion Matrix', fontsize=TITLE_FONT_SIZE, weight='bold') plt.ylabel('Actual', fontsize=LABEL_FONT_SIZE) plt.xlabel('Predicted', fontsize=LABEL_FONT_SIZE) plt.tight_layout() logging.info("Confusion matrix generated successfully.") return plt.gcf(), None except Exception as e: logging.error(f"Error generating confusion matrix: {e}", exc_info=True) return None, f"Error generating confusion matrix: {e}" def plot_roc_curve(y_true, y_pred_proba, class_names=None): """Generates a Receiver Operating Characteristic (ROC) curve. Args: y_true (array-like): True binary labels. y_pred_proba (array-like): Target scores, probabilities of the positive class. class_names (list, optional): List of class names. Not directly used in plot but good for context. Returns: tuple: A matplotlib Figure object and an error message (None if successful). """ logging.info("Generating ROC curve.") try: # Handle multi-class or binary probability predictions if y_pred_proba.ndim == 1: # Binary classification, single probability array fpr, tpr, _ = roc_curve(y_true, y_pred_proba) elif y_pred_proba.shape[1] == 2: # Binary classification, two columns of probabilities fpr, tpr, _ = roc_curve(y_true, y_pred_proba[:, 1]) # Assume second column is positive class else: # Multi-class, need to binarize or choose a class # For simplicity, if multi-class, we'll plot ROC for the first class vs. rest # A more robust solution would allow selecting a class or plotting all. logging.warning("Multi-class ROC curve requested. Plotting for first class vs. rest.") # Binarize y_true for the first class y_true_bin = (y_true == sorted(np.unique(y_true))[0]).astype(int) fpr, tpr, _ = roc_curve(y_true_bin, y_pred_proba[:, 0]) roc_auc = auc(fpr, tpr) plt.figure(figsize=FIG_SIZE) sns.set_style("whitegrid") plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})') plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--') plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.05]) plt.xlabel('False Positive Rate', fontsize=LABEL_FONT_SIZE) plt.ylabel('True Positive Rate', fontsize=LABEL_FONT_SIZE) plt.title('Receiver Operating Characteristic', fontsize=TITLE_FONT_SIZE, weight='bold') plt.legend(loc="lower right", fontsize=LEGEND_FONT_SIZE) plt.tight_layout() logging.info("ROC curve generated successfully.") return plt.gcf(), None except Exception as e: logging.error(f"Error generating ROC curve: {e}", exc_info=True) return None, f"Error generating ROC curve: {e}" def plot_feature_importance(model, feature_names): """Generates a feature importance bar plot for tree-based models. Args: model: A trained model with a 'feature_importances_' attribute. feature_names (list): List of feature names corresponding to the importances. Returns: tuple: A matplotlib Figure object and an error message (None if successful). """ logging.info("Generating feature importance plot.") if not hasattr(model, 'feature_importances_'): logging.error("Model does not have feature importances attribute.") return None, "Model does not have feature importances." try: importances = model.feature_importances_ # Sort features by importance in descending order indices = np.argsort(importances)[::-1] plt.figure(figsize=FIG_SIZE) sns.set_style("whitegrid") # Plot top N features for clarity num_features_to_plot = min(len(feature_names), 20) # Plot top 20 features or fewer if less available plt.title("Feature Importances", fontsize=TITLE_FONT_SIZE, weight='bold') sns.barplot(x=importances[indices[:num_features_to_plot]], y=[feature_names[i] for i in indices[:num_features_to_plot]], palette="viridis") plt.xlabel("Relative Importance", fontsize=LABEL_FONT_SIZE) plt.ylabel("Feature Name", fontsize=LABEL_FONT_SIZE) plt.tight_layout() logging.info("Feature importance plot generated successfully.") return plt.gcf(), None except Exception as e: logging.error(f"Error generating feature importance plot: {e}", exc_info=True) return None, f"Error generating feature importance plot: {e}" def plot_elbow_curve(X, max_k=10): """Generates an elbow curve to help determine the optimal number of clusters (k) for KMeans. Args: X (pd.DataFrame or np.array): The input data for clustering. max_k (int, optional): The maximum number of clusters to test. Defaults to 10. Returns: tuple: A matplotlib Figure object and an error message (None if successful). """ logging.info(f"Generating elbow curve for max_k={max_k}") inertias = [] if not isinstance(X, pd.DataFrame): X = pd.DataFrame(X) # Ensure X is a DataFrame for .dropna() X_cleaned = X.dropna() # Handle NaNs for KMeans if X_cleaned.empty: logging.error("Data is empty after cleaning for Elbow Curve.") return None, "Data is empty after cleaning for Elbow Curve." # Ensure max_k is not greater than the number of samples if max_k > len(X_cleaned): logging.warning(f"max_k ({max_k}) is greater than number of samples ({len(X_cleaned)}). Adjusting max_k.") max_k = len(X_cleaned) if max_k < 1: return None, "max_k must be at least 1." try: for k in range(1, max_k + 1): kmeans = KMeans(n_clusters=k, random_state=42, n_init=10) # n_init to suppress warning kmeans.fit(X_cleaned) inertias.append(kmeans.inertia_) plt.figure(figsize=FIG_SIZE) sns.set_style("whitegrid") plt.plot(range(1, max_k + 1), inertias, marker='o', linestyle='-', color=PRIMARY_COLOR) plt.xlabel('Number of clusters (k)', fontsize=LABEL_FONT_SIZE) plt.ylabel('Inertia', fontsize=LABEL_FONT_SIZE) plt.title('Elbow Method For Optimal k', fontsize=TITLE_FONT_SIZE, weight='bold') plt.xticks(np.arange(1, max_k + 1, 1)) # Ensure integer ticks plt.tight_layout() logging.info("Elbow curve generated successfully.") return plt.gcf(), None except Exception as e: logging.error(f"Error generating elbow curve: {e}", exc_info=True) return None, f"Error generating elbow curve: {e}" def plot_cluster_plot(X, labels, title="Cluster Plot"): """Generates a 2D scatter plot of clusters, optionally after dimensionality reduction. Args: X (pd.DataFrame or np.array): The input data. labels (array-like, optional): Cluster labels for coloring points. If None, points are not colored. title (str, optional): Title of the plot. Defaults to "Cluster Plot". Returns: tuple: A matplotlib Figure object and an error message (None if successful). """ logging.info(f"Generating cluster plot with title: {title}") if not isinstance(X, pd.DataFrame): X = pd.DataFrame(X) # Handle NaNs before dimensionality reduction X_cleaned = X.dropna() if X_cleaned.empty: logging.error("Data is empty after cleaning for Cluster Plot.") return None, "Data is empty after cleaning for Cluster Plot." plot_df = X_cleaned.copy() xlabel = 'Feature 1' ylabel = 'Feature 2' # Reduce dimensions to 2 if data has more than 2 features if X_cleaned.shape[1] > 2: try: logging.info("Applying PCA for dimensionality reduction to 2 components.") pca = PCA(n_components=2) X_reduced = pca.fit_transform(X_cleaned) plot_df = pd.DataFrame(X_reduced, columns=['PC1', 'PC2']) xlabel = 'Principal Component 1' ylabel = 'Principal Component 2' except Exception as e: logging.error(f"Could not reduce dimensions for cluster plot using PCA: {e}", exc_info=True) return None, f"Could not reduce dimensions for cluster plot: {e}" elif X_cleaned.shape[1] == 1: logging.error("Data must have at least 2 dimensions for a 2D cluster plot.") return None, "Data must have at least 2 dimensions for a 2D cluster plot." plt.figure(figsize=FIG_SIZE) sns.set_style("whitegrid") if labels is not None: # Align labels with cleaned data if necessary if isinstance(labels, pd.Series): labels_aligned = labels.loc[X_cleaned.index] if labels.index.equals(X.index) else labels # Simple alignment else: labels_aligned = labels # Assume already aligned or numpy array sns.scatterplot(x=plot_df.iloc[:, 0], y=plot_df.iloc[:, 1], hue=labels_aligned, palette='viridis', s=50, alpha=0.7) plt.legend(title='Cluster', bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=LEGEND_FONT_SIZE) else: sns.scatterplot(x=plot_df.iloc[:, 0], y=plot_df.iloc[:, 1], s=50, alpha=0.7, color=PRIMARY_COLOR) plt.title(title, fontsize=TITLE_FONT_SIZE, weight='bold') plt.xlabel(xlabel, fontsize=LABEL_FONT_SIZE) plt.ylabel(ylabel, fontsize=LABEL_FONT_SIZE) plt.tight_layout() logging.info("Cluster plot generated successfully.") return plt.gcf(), None def plot_dendrogram(X): """Generates a dendrogram for hierarchical clustering. Args: X (pd.DataFrame or np.array): The input data for clustering. Returns: tuple: A matplotlib Figure object and an error message (None if successful). """ logging.info("Generating dendrogram.") if not isinstance(X, pd.DataFrame): X = pd.DataFrame(X) X_cleaned = X.dropna() # Handle NaNs if X_cleaned.empty: logging.error("Data is empty after cleaning for Dendrogram.") return None, "Data is empty after cleaning for Dendrogram." # Limit the number of samples for dendrogram for performance and readability if X_cleaned.shape[0] > 1000: logging.warning(f"Dendrogram data size ({X_cleaned.shape[0]}) is large. Sampling 1000 points.") X_cleaned = X_cleaned.sample(n=1000, random_state=42) try: linked = linkage(X_cleaned, 'ward') # Ward method minimizes variance within clusters plt.figure(figsize=(12, 8)) dendrogram(linked, orientation='top', distance_sort='descending', show_leaf_counts=True, leaf_rotation=90, leaf_font_size=8) plt.title('Hierarchical Clustering Dendrogram', fontsize=TITLE_FONT_SIZE, weight='bold') plt.xlabel('Sample Index or Cluster Size', fontsize=LABEL_FONT_SIZE) plt.ylabel('Distance', fontsize=LABEL_FONT_SIZE) plt.tight_layout() logging.info("Dendrogram generated successfully.") return plt.gcf(), None except Exception as e: logging.error(f"Error generating dendrogram: {e}", exc_info=True) return None, f"Error generating dendrogram: {e}" def plot_tsne(X, labels=None): """Generates a t-SNE plot for dimensionality reduction and visualization of high-dimensional data. Args: X (pd.DataFrame or np.array): The input high-dimensional data. labels (array-like, optional): Labels for coloring points (e.g., cluster assignments). Returns: tuple: A matplotlib Figure object and an error message (None if successful). """ logging.info("Generating t-SNE plot.") if not isinstance(X, pd.DataFrame): X = pd.DataFrame(X) X_cleaned = X.dropna() # Handle NaNs if X_cleaned.empty: logging.error("Data is empty after cleaning for t-SNE.") return None, "Data is empty after cleaning for t-SNE." # t-SNE can be computationally expensive on large datasets, consider sampling if X_cleaned.shape[0] > 2000: logging.warning(f"t-SNE data size ({X_cleaned.shape[0]}) is large. Sampling 2000 points.") X_cleaned = X_cleaned.sample(n=2000, random_state=42) if labels is not None: # Align labels with sampled data if isinstance(labels, pd.Series): labels = labels.loc[X_cleaned.index] else: # If numpy array, convert to series for easy indexing labels = pd.Series(labels).loc[X_cleaned.index] try: # Perplexity should be less than the number of samples perplexity_val = min(30, len(X_cleaned) - 1) if len(X_cleaned) > 1 else 1 if perplexity_val < 1: return None, "Not enough samples for t-SNE (need at least 2)." tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity_val) X_tsne = tsne.fit_transform(X_cleaned) plt.figure(figsize=FIG_SIZE) sns.set_style("whitegrid") if labels is not None: sns.scatterplot(x=X_tsne[:, 0], y=X_tsne[:, 1], hue=labels, palette='viridis', s=50, alpha=0.7) plt.legend(title='Cluster/Label', bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=LEGEND_FONT_SIZE) else: sns.scatterplot(x=X_tsne[:, 0], y=X_tsne[:, 1], s=50, alpha=0.7, color=PRIMARY_COLOR) plt.title('t-SNE Plot', fontsize=TITLE_FONT_SIZE, weight='bold') plt.xlabel('t-SNE Component 1', fontsize=LABEL_FONT_SIZE) plt.ylabel('t-SNE Component 2', fontsize=LABEL_FONT_SIZE) plt.tight_layout() logging.info("t-SNE plot generated successfully.") return plt.gcf(), None except Exception as e: logging.error(f"Error generating t-SNE plot: {e}", exc_info=True) return None, f"Error generating t-SNE plot: {e}"