Spaces:
Running
Running
Pulastya B
feat: Initial commit - Data Science Agent with React frontend and FastAPI backend
226ac39
| """ | |
| Matplotlib + Seaborn Visualization Engine | |
| Production-quality visualizations that work reliably with Gradio UI. | |
| All functions return matplotlib Figure objects (not file paths). | |
| Designed for publication-quality plots with professional styling. | |
| """ | |
| import matplotlib | |
| matplotlib.use('Agg') # Use non-interactive backend for Gradio compatibility | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import numpy as np | |
| import pandas as pd | |
| from typing import Dict, Any, List, Optional, Tuple, Union | |
| from pathlib import Path | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
# Global plot styling: apply the seaborn grid theme first, then layer the
# matplotlib defaults (white backgrounds and font sizes) on top of it.
sns.set_style('whitegrid')
plt.rcParams.update({
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
    'font.size': 10,
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
})
| # ============================================================================ | |
| # BASIC PLOTS | |
| # ============================================================================ | |
def create_scatter_plot(
    x: Union[np.ndarray, pd.Series, list],
    y: Union[np.ndarray, pd.Series, list],
    hue: Optional[Union[np.ndarray, pd.Series, list]] = None,
    size: Optional[Union[np.ndarray, pd.Series, list]] = None,
    title: str = "Scatter Plot",
    xlabel: str = "X",
    ylabel: str = "Y",
    figsize: Tuple[int, int] = (10, 6),
    alpha: float = 0.6,
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a scatter plot with optional color coding and size variation.

    Args:
        x: X-axis data
        y: Y-axis data
        hue: Optional categorical data for color coding (one group per
            unique value, each with its own legend entry)
        size: Optional numeric data for marker sizes; either one value per
            point or a single scalar applied to every point
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size (width, height)
        alpha: Point transparency (0-1)
        save_path: Optional path to save PNG file

    Returns:
        matplotlib Figure object, or None if the plot could not be created
        (the error is printed and the partially built figure is closed).

    Example:
        >>> fig = create_scatter_plot(df['feature1'], df['target'],
        ...                           hue=df['category'], title='Feature vs Target')
        >>> # Display in Gradio: gr.Plot(value=fig)
    """
    fig = None
    try:
        # Convert inputs to arrays so the boolean group masks below work
        # for lists and pandas Series alike.
        x = np.asarray(x)
        y = np.asarray(y)
        sizes = None if size is None else np.asarray(size)

        fig, ax = plt.subplots(figsize=figsize)

        if hue is not None:
            hue = np.asarray(hue)
            unique_hues = np.unique(hue)
            colors = sns.color_palette('Set2', n_colors=len(unique_hues))
            for color, hue_val in zip(colors, unique_hues):
                mask = hue == hue_val
                if sizes is None:
                    point_sizes = 50
                elif sizes.ndim == 0:
                    # A scalar size cannot be masked; use it for every group.
                    point_sizes = sizes
                else:
                    point_sizes = sizes[mask]
                ax.scatter(x[mask], y[mask],
                           c=[color],
                           s=point_sizes,
                           alpha=alpha,
                           label=str(hue_val),
                           edgecolors='black',
                           linewidth=0.5)
            ax.legend(title='Category', loc='best', framealpha=0.9)
        else:
            ax.scatter(x, y,
                       c='steelblue',
                       s=50 if sizes is None else sizes,
                       alpha=alpha,
                       edgecolors='black',
                       linewidth=0.5)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.grid(True, alpha=0.3, linestyle='--')
        # Lay out this figure explicitly instead of pyplot's "current" one.
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" β Saved scatter plot to {save_path}")
        return fig
    except Exception as e:
        print(f" β Error creating scatter plot: {str(e)}")
        # Callers get None on failure; close the half-built figure so
        # pyplot's figure registry does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
def create_line_plot(
    x: Union[np.ndarray, pd.Series, list],
    y: Union[Dict[str, np.ndarray], np.ndarray, pd.Series, list],
    title: str = "Line Plot",
    xlabel: str = "X",
    ylabel: str = "Y",
    figsize: Tuple[int, int] = (10, 6),
    markers: bool = True,
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a line plot (supports multiple lines via dict).

    Args:
        x: X-axis data
        y: Y-axis data (dict for multiple lines: {'label1': y1, 'label2': y2});
            a dict also produces a legend with one entry per key
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size
        markers: Show markers on lines
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the plot could not be created
        (the error is printed and the partially built figure is closed).
    """
    fig = None
    try:
        x = np.asarray(x)
        marker_style = 'o' if markers else None

        fig, ax = plt.subplots(figsize=figsize)

        if isinstance(y, dict):
            colors = sns.color_palette('husl', n_colors=len(y))
            for color, (label, y_data) in zip(colors, y.items()):
                ax.plot(x, np.asarray(y_data),
                        marker=marker_style,
                        label=label,
                        linewidth=2,
                        markersize=6,
                        color=color)
            ax.legend(loc='best', framealpha=0.9)
        else:
            ax.plot(x, np.asarray(y),
                    marker=marker_style,
                    linewidth=2,
                    markersize=6,
                    color='steelblue')

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.grid(True, alpha=0.3, linestyle='--')
        # Lay out this figure explicitly instead of pyplot's "current" one.
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" β Saved line plot to {save_path}")
        return fig
    except Exception as e:
        print(f" β Error creating line plot: {str(e)}")
        # Close the half-built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
def create_bar_chart(
    categories: Union[list, np.ndarray],
    values: Union[np.ndarray, pd.Series, list],
    title: str = "Bar Chart",
    xlabel: str = "Category",
    ylabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    horizontal: bool = False,
    color: str = 'steelblue',
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a bar chart (vertical or horizontal).

    Args:
        categories: Category names
        values: Values for each category
        title: Plot title
        xlabel: Label for the category axis
        ylabel: Label for the value axis
        figsize: Figure size
        horizontal: If True, create horizontal bars; the two axis labels
            are swapped so each still describes the same quantity
        color: Bar color
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the plot could not be created
        (the error is printed and the partially built figure is closed).
    """
    fig = None
    try:
        categories = list(categories)
        values = np.asarray(values)

        fig, ax = plt.subplots(figsize=figsize)

        if horizontal:
            ax.barh(categories, values, color=color, edgecolor='black', linewidth=0.7)
            # Categories are on the y axis here, so the labels swap sides.
            ax.set_xlabel(ylabel, fontsize=12)
            ax.set_ylabel(xlabel, fontsize=12)
        else:
            ax.bar(categories, values, color=color, edgecolor='black', linewidth=0.7)
            ax.set_xlabel(xlabel, fontsize=12)
            ax.set_ylabel(ylabel, fontsize=12)
            # Rotate category labels if there are many. This applies only
            # to vertical bars: in horizontal mode the x axis carries the
            # numeric values, which must not be rotated.
            if len(categories) > 10:
                plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.grid(True, alpha=0.3, linestyle='--', axis='x' if horizontal else 'y')
        # Lay out this figure explicitly instead of pyplot's "current" one.
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" β Saved bar chart to {save_path}")
        return fig
    except Exception as e:
        print(f" β Error creating bar chart: {str(e)}")
        # Close the half-built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
def create_histogram(
    data: Union[np.ndarray, pd.Series, list],
    title: str = "Histogram",
    xlabel: str = "Value",
    ylabel: str = "Frequency",
    bins: int = 30,
    figsize: Tuple[int, int] = (10, 6),
    kde: bool = True,
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a histogram with optional KDE overlay.

    The histogram always shows raw counts on the left axis, so the default
    "Frequency" label is accurate. When ``kde`` is True the density estimate
    is drawn on a secondary right-hand axis labelled "Density".

    Args:
        data: Data to plot (NaN / None entries are dropped)
        title: Plot title
        xlabel: X-axis label
        ylabel: Label for the left (count) axis
        bins: Number of bins
        figsize: Figure size
        kde: Show kernel density estimate on a secondary axis
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if there is no valid data or the
        plot could not be created (the reason is printed).
    """
    fig = None
    try:
        # Coerce to float first so None / object-dtype input becomes NaN
        # instead of making np.isnan raise.
        data = np.asarray(data, dtype=float).ravel()
        data = data[~np.isnan(data)]  # Remove NaN values

        # Validate before creating the figure so nothing is left open.
        if data.size == 0:
            print(" β No valid data for histogram")
            return None

        fig, ax = plt.subplots(figsize=figsize)

        # Create histogram (counts, matching the left axis label)
        ax.hist(data, bins=bins, color='steelblue',
                edgecolor='black', alpha=0.7)

        # Add KDE if requested
        if kde:
            ax2 = ax.twinx()
            sns.kdeplot(data, ax=ax2, color='darkred', linewidth=2, label='KDE')
            ax2.set_ylabel('Density', fontsize=12)
            # Avoid a second, misaligned set of grid lines from the twin axis.
            ax2.grid(False)
            ax2.legend(loc='upper right')

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.grid(True, alpha=0.3, linestyle='--')
        # Lay out this figure explicitly instead of pyplot's "current" one.
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" β Saved histogram to {save_path}")
        return fig
    except Exception as e:
        print(f" β Error creating histogram: {str(e)}")
        # Close the half-built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
def create_boxplot(
    data: Union[Dict[str, np.ndarray], pd.DataFrame],
    title: str = "Box Plot",
    xlabel: str = "Category",
    ylabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    horizontal: bool = False,
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create box plots for multiple columns/categories.

    Args:
        data: Dictionary of {column_name: values} or DataFrame; NaN values
            are dropped per column before plotting
        title: Plot title
        xlabel: Label for the category axis
        ylabel: Label for the value axis
        figsize: Figure size
        horizontal: If True, create horizontal boxplots (axis labels swap)
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the input type is unsupported
        or the plot could not be created (the error is printed).
    """
    fig = None
    try:
        # Prepare the data before creating the figure so invalid input
        # does not leave an open figure behind.
        if isinstance(data, pd.DataFrame):
            labels = [str(col) for col in data.columns]
            data_to_plot = [series.dropna().to_numpy() for _, series in data.items()]
        elif isinstance(data, dict):
            labels = [str(key) for key in data]
            data_to_plot = []
            for values in data.values():
                arr = np.asarray(values, dtype=float).ravel()
                data_to_plot.append(arr[~np.isnan(arr)])
        else:
            raise ValueError("Data must be DataFrame or dict")

        fig, ax = plt.subplots(figsize=figsize)

        bp = ax.boxplot(data_to_plot,
                        vert=not horizontal,
                        patch_artist=True,
                        notch=True,
                        showmeans=True)

        # Tick labels are set on the axis rather than through boxplot's
        # `labels=` keyword, which was renamed in Matplotlib 3.9.
        if horizontal:
            ax.set_yticklabels(labels)
        else:
            ax.set_xticklabels(labels)

        # Styling
        for patch in bp['boxes']:
            patch.set_facecolor('lightblue')
            patch.set_alpha(0.7)
        for whisker in bp['whiskers']:
            whisker.set(linewidth=1.5, color='gray')
        for cap in bp['caps']:
            cap.set(linewidth=1.5, color='gray')
        for median in bp['medians']:
            median.set(linewidth=2, color='darkred')
        for mean in bp['means']:
            mean.set(marker='D', markerfacecolor='green', markersize=6)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        if horizontal:
            ax.set_xlabel(ylabel, fontsize=12)
            ax.set_ylabel(xlabel, fontsize=12)
        else:
            ax.set_xlabel(xlabel, fontsize=12)
            ax.set_ylabel(ylabel, fontsize=12)
            # Category labels sit on the x axis only for vertical boxes,
            # so that is the only case where rotating them helps.
            if len(labels) > 8:
                plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        ax.grid(True, alpha=0.3, linestyle='--', axis='x' if horizontal else 'y')
        # Lay out this figure explicitly instead of pyplot's "current" one.
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" β Saved boxplot to {save_path}")
        return fig
    except Exception as e:
        print(f" β Error creating boxplot: {str(e)}")
        # Close the half-built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
| # ============================================================================ | |
| # STATISTICAL PLOTS | |
| # ============================================================================ | |
def create_correlation_heatmap(
    data: Union[pd.DataFrame, np.ndarray],
    columns: Optional[List[str]] = None,
    title: str = "Correlation Heatmap",
    figsize: Tuple[int, int] = (12, 10),
    annot: bool = True,
    cmap: str = 'RdBu_r',
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a correlation heatmap with annotations.

    A DataFrame is treated as raw data: correlations are computed over its
    numeric (and boolean) columns and other columns are ignored. Any other
    input is treated as an already computed correlation matrix.

    Args:
        data: DataFrame of raw data, or a correlation matrix as np.ndarray
        columns: Column names (if data is np.ndarray)
        title: Plot title
        figsize: Figure size
        annot: Show correlation values as annotations
        cmap: Colormap (diverging, centered at 0)
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if there is nothing to correlate
        or the plot could not be created (the error is printed).

    Example:
        >>> fig = create_correlation_heatmap(df[numeric_cols])
    """
    fig = None
    try:
        # Calculate correlation if DataFrame
        if isinstance(data, pd.DataFrame):
            # Restrict to numeric columns explicitly: recent pandas versions
            # raise in corr() when non-numeric columns are present.
            numeric_data = data.select_dtypes(include=['number', 'bool'])
            corr_matrix = numeric_data.corr()
        else:
            corr_matrix = pd.DataFrame(data, columns=columns, index=columns)

        # Validate before creating the figure so nothing is left open.
        if corr_matrix.empty:
            raise ValueError("No numeric columns available for correlation")

        fig, ax = plt.subplots(figsize=figsize)

        # Mask the upper triangle (diagonal included) to hide mirrored values
        mask = np.triu(np.ones(corr_matrix.shape, dtype=bool))
        sns.heatmap(corr_matrix,
                    mask=mask,
                    annot=annot,
                    fmt='.2f',
                    cmap=cmap,
                    center=0,
                    square=True,
                    linewidths=0.5,
                    cbar_kws={'shrink': 0.8, 'label': 'Correlation'},
                    ax=ax,
                    vmin=-1,
                    vmax=1)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        # Lay out this figure explicitly instead of pyplot's "current" one.
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" β Saved correlation heatmap to {save_path}")
        return fig
    except Exception as e:
        print(f" β Error creating correlation heatmap: {str(e)}")
        # Close the half-built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
def create_distribution_plot(
    data: Union[np.ndarray, pd.Series, list],
    title: str = "Distribution Plot",
    xlabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    show_rug: bool = False,
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a distribution plot with histogram and KDE.

    The plot also marks the mean and median with vertical lines and shows
    mean / median / standard deviation in a text box.

    Args:
        data: Data to plot (NaN / None entries are dropped)
        title: Plot title
        xlabel: X-axis label
        figsize: Figure size
        show_rug: Show rug plot (data points on x-axis)
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if there is no valid data or the
        plot could not be created (the reason is printed).
    """
    fig = None
    try:
        # Coerce to float first so None / object-dtype input becomes NaN
        # instead of making np.isnan raise.
        data = np.asarray(data, dtype=float).ravel()
        data = data[~np.isnan(data)]

        # Validate before creating the figure so nothing is left open.
        if data.size == 0:
            print(" β No valid data for distribution plot")
            return None

        fig, ax = plt.subplots(figsize=figsize)

        # Create distribution plot
        sns.histplot(data, kde=True, ax=ax, color='steelblue',
                     edgecolor='black', alpha=0.6, bins=30)
        if show_rug:
            sns.rugplot(data, ax=ax, color='darkred', alpha=0.5, height=0.05)

        # Add statistics text
        mean_val = np.mean(data)
        median_val = np.median(data)
        std_val = np.std(data)
        stats_text = f'Mean: {mean_val:.2f}\nMedian: {median_val:.2f}\nStd: {std_val:.2f}'
        ax.text(0.98, 0.98, stats_text,
                transform=ax.transAxes,
                verticalalignment='top',
                horizontalalignment='right',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5),
                fontsize=10)

        # Add vertical lines for mean and median
        ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label='Mean')
        ax.axvline(median_val, color='green', linestyle='--', linewidth=2, label='Median')

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel('Frequency / Density', fontsize=12)
        ax.legend(loc='upper left')
        ax.grid(True, alpha=0.3, linestyle='--')
        # Lay out this figure explicitly instead of pyplot's "current" one.
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" β Saved distribution plot to {save_path}")
        return fig
    except Exception as e:
        print(f" β Error creating distribution plot: {str(e)}")
        # Close the half-built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
def create_violin_plot(
    data: Union[Dict[str, np.ndarray], pd.DataFrame],
    title: str = "Violin Plot",
    xlabel: str = "Category",
    ylabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create violin plots showing distribution for multiple categories.

    Args:
        data: One of
            - dict of {category_name: values};
            - long-format DataFrame with 'Category' and 'Value' columns;
            - wide DataFrame, where each numeric column becomes a category.
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the input is unsupported or
        the plot could not be created (the error is printed).
    """
    fig = None
    try:
        # Build the long-format frame before creating the figure so that
        # invalid input does not leave an open figure behind.
        if isinstance(data, dict):
            # Convert dict to DataFrame for seaborn
            frames = [
                pd.DataFrame({'Category': [key] * len(values), 'Value': values})
                for key, values in data.items()
            ]
            plot_df = pd.concat(frames, ignore_index=True)
        elif isinstance(data, pd.DataFrame):
            if {'Category', 'Value'}.issubset(data.columns):
                plot_df = data
            else:
                # Wide frame: reshape so each numeric column is a category.
                plot_df = data.select_dtypes(include='number').melt(
                    var_name='Category', value_name='Value')
        else:
            raise ValueError("Data must be DataFrame or dict")

        fig, ax = plt.subplots(figsize=figsize)

        # Create violin plot
        sns.violinplot(data=plot_df, x='Category', y='Value', ax=ax,
                       palette='Set2', inner='box')

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        if len(plot_df['Category'].unique()) > 8:
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        ax.grid(True, alpha=0.3, linestyle='--', axis='y')
        # Lay out this figure explicitly instead of pyplot's "current" one.
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" β Saved violin plot to {save_path}")
        return fig
    except Exception as e:
        print(f" β Error creating violin plot: {str(e)}")
        # Close the half-built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
def create_pairplot(
    data: pd.DataFrame,
    hue: Optional[str] = None,
    title: str = "Pair Plot",
    figsize: Tuple[int, int] = (12, 12),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a pairplot (scatterplot matrix) for multiple features.

    Args:
        data: DataFrame with features to plot
        hue: Column name for color coding; ignored if it is not a column
            of ``data``
        title: Plot title
        figsize: Overall figure size (width, height) in inches, applied to
            the figure seaborn creates
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the plot could not be created
        (the error is printed).
    """
    fig = None
    try:
        plot_kwargs = {'diag_kind': 'kde', 'corner': True}
        # A palette only has meaning together with a hue variable.
        if hue and hue in data.columns:
            plot_kwargs.update(hue=hue, palette='Set2')

        # Seaborn pairplot returns a PairGrid, we need to extract the figure
        pair_grid = sns.pairplot(data, **plot_kwargs)

        # `.figure` is the current attribute name; fall back to `.fig` for
        # seaborn releases that predate it.
        fig = getattr(pair_grid, 'figure', None)
        if fig is None:
            fig = pair_grid.fig

        # pairplot has no figsize argument, so resize the finished figure.
        fig.set_size_inches(*figsize)
        fig.suptitle(title, fontsize=14, fontweight='bold', y=1.01)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" β Saved pairplot to {save_path}")
        return fig
    except Exception as e:
        print(f" β Error creating pairplot: {str(e)}")
        # Close the figure (if one was obtained) so pyplot does not keep
        # it alive.
        if fig is not None:
            plt.close(fig)
        return None
| # ============================================================================ | |
| # MACHINE LEARNING PLOTS | |
| # ============================================================================ | |
def create_roc_curve(
    models_data: Dict[str, Tuple[np.ndarray, np.ndarray, float]],
    title: str = "ROC Curve Comparison",
    figsize: Tuple[int, int] = (10, 8),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create ROC curves for multiple models on the same plot.

    A dashed diagonal is always added as the random-classifier reference.

    Args:
        models_data: Dict of {model_name: (fpr, tpr, auc_score)}
        title: Plot title
        figsize: Figure size
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the plot could not be created
        (the error is printed and the partially built figure is closed).

    Example:
        >>> from sklearn.metrics import roc_curve, auc
        >>> fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
        >>> auc_score = auc(fpr, tpr)
        >>> models = {'Random Forest': (fpr, tpr, auc_score)}
        >>> fig = create_roc_curve(models)
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=figsize)

        colors = sns.color_palette('husl', n_colors=len(models_data))
        for color, (model_name, (fpr, tpr, auc_score)) in zip(colors, models_data.items()):
            ax.plot(fpr, tpr,
                    linewidth=2.5,
                    label=f'{model_name} (AUC = {auc_score:.3f})',
                    color=color)

        # Add diagonal reference line (random classifier)
        ax.plot([0, 1], [0, 1],
                linestyle='--',
                linewidth=2,
                color='gray',
                label='Random Classifier (AUC = 0.500)')

        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('False Positive Rate', fontsize=12)
        ax.set_ylabel('True Positive Rate', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.legend(loc='lower right', fontsize=10, framealpha=0.9)
        ax.grid(True, alpha=0.3, linestyle='--')
        # Lay out this figure explicitly instead of pyplot's "current" one.
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" β Saved ROC curve to {save_path}")
        return fig
    except Exception as e:
        print(f" β Error creating ROC curve: {str(e)}")
        # Close the half-built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
def create_confusion_matrix(
    cm: np.ndarray,
    class_names: Optional[List[str]] = None,
    title: str = "Confusion Matrix",
    figsize: Tuple[int, int] = (10, 8),
    show_percentages: bool = True,
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a confusion matrix heatmap with annotations.

    Args:
        cm: Confusion matrix (from sklearn.metrics.confusion_matrix); any
            2-D array-like is accepted
        class_names: Names of classes (optional, defaults to 'Class i')
        title: Plot title
        figsize: Figure size
        show_percentages: Show row-wise percentages in addition to counts.
            Rows that sum to zero are shown as 0.0%.
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the input is invalid or the
        plot could not be created (the error is printed).

    Example:
        >>> from sklearn.metrics import confusion_matrix
        >>> cm = confusion_matrix(y_true, y_pred)
        >>> fig = create_confusion_matrix(cm, class_names=['Class 0', 'Class 1'])
    """
    fig = None
    try:
        # Accept nested lists as well as arrays; validate before creating
        # the figure so invalid input does not leave one open.
        cm = np.asarray(cm)
        if cm.ndim != 2:
            raise ValueError("Confusion matrix must be a 2-D array")

        if class_names is None:
            class_names = [f'Class {i}' for i in range(len(cm))]

        # Create annotations
        if show_percentages:
            # Normalize each row for percentages. Rows with no samples
            # would divide by zero, so they stay at 0 instead of NaN.
            row_totals = cm.sum(axis=1, keepdims=True).astype(float)
            cm_percent = np.divide(cm * 100.0, row_totals,
                                   out=np.zeros(cm.shape, dtype=float),
                                   where=row_totals != 0)
            annotations = np.array([[f'{count}\n({percent:.1f}%)'
                                     for count, percent in zip(row_counts, row_percents)]
                                    for row_counts, row_percents in zip(cm, cm_percent)])
        else:
            annotations = cm

        fig, ax = plt.subplots(figsize=figsize)

        # Create heatmap (fmt='' because the annotations are pre-formatted)
        sns.heatmap(cm,
                    annot=annotations,
                    fmt='',
                    cmap='Blues',
                    square=True,
                    linewidths=0.5,
                    cbar_kws={'label': 'Count'},
                    xticklabels=class_names,
                    yticklabels=class_names,
                    ax=ax)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_ylabel('Actual', fontsize=12)
        ax.set_xlabel('Predicted', fontsize=12)
        # Lay out this figure explicitly instead of pyplot's "current" one.
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" β Saved confusion matrix to {save_path}")
        return fig
    except Exception as e:
        print(f" β Error creating confusion matrix: {str(e)}")
        # Close the half-built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
def create_precision_recall_curve(
    models_data: Dict[str, Tuple[np.ndarray, np.ndarray, float]],
    title: str = "Precision-Recall Curve",
    figsize: Tuple[int, int] = (10, 8),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create precision-recall curves for multiple models.

    Args:
        models_data: Dict of {model_name: (precision, recall, avg_precision)}
        title: Plot title
        figsize: Figure size
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the plot could not be created
        (the error is printed and the partially built figure is closed).
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=figsize)

        colors = sns.color_palette('husl', n_colors=len(models_data))
        for color, (model_name, (precision, recall, avg_precision)) in zip(
                colors, models_data.items()):
            # Recall goes on the x axis, precision on the y axis.
            ax.plot(recall, precision,
                    linewidth=2.5,
                    label=f'{model_name} (AP = {avg_precision:.3f})',
                    color=color)

        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel('Recall', fontsize=12)
        ax.set_ylabel('Precision', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.legend(loc='best', fontsize=10, framealpha=0.9)
        ax.grid(True, alpha=0.3, linestyle='--')
        # Lay out this figure explicitly instead of pyplot's "current" one.
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" β Saved precision-recall curve to {save_path}")
        return fig
    except Exception as e:
        print(f" β Error creating precision-recall curve: {str(e)}")
        # Close the half-built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
def create_feature_importance(
    feature_names: List[str],
    importances: np.ndarray,
    title: str = "Feature Importance",
    top_n: int = 20,
    figsize: Tuple[int, int] = (10, 8),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a horizontal bar chart of feature importances.

    Features are ordered by descending importance value, with the largest
    at the top. Bars are green for values >= 0 and red for negative values.

    Args:
        feature_names: List of feature names
        importances: Importance values, one per feature (any array-like;
            multi-dimensional input is flattened)
        title: Plot title
        top_n: Number of top features to show
        figsize: Figure size; the height grows beyond figsize[1] when
            needed to fit the bars that are drawn
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the inputs are inconsistent
        or the plot could not be created (the error is printed).

    Example:
        >>> importances = model.feature_importances_
        >>> fig = create_feature_importance(feature_names, importances, top_n=15)
    """
    fig = None
    try:
        # Normalize inputs so plain lists support the fancy indexing
        # below just like arrays do.
        feature_names = list(feature_names)
        importances = np.asarray(importances, dtype=float).ravel()
        if len(feature_names) != importances.size:
            raise ValueError(
                f"Got {len(feature_names)} feature names for "
                f"{importances.size} importance values")

        # Sort by importance
        order = np.argsort(importances)[::-1][:top_n]
        sorted_features = [feature_names[i] for i in order]
        sorted_importances = importances[order]

        # Size the figure by the bars actually drawn, which can be fewer
        # than top_n, and never below the requested height.
        height = max(figsize[1], len(order) * 0.4)
        fig, ax = plt.subplots(figsize=(figsize[0], height))

        # Color bars by positive/negative (if any negative values)
        colors = ['green' if value >= 0 else 'red' for value in sorted_importances]

        # Create horizontal bar chart
        y_pos = np.arange(len(sorted_features))
        ax.barh(y_pos, sorted_importances, color=colors, edgecolor='black', linewidth=0.7)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(sorted_features)
        ax.invert_yaxis()  # Top features at top
        ax.set_xlabel('Importance Score', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.grid(True, alpha=0.3, linestyle='--', axis='x')
        # Lay out this figure explicitly instead of pyplot's "current" one.
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" β Saved feature importance to {save_path}")
        return fig
    except Exception as e:
        print(f" β Error creating feature importance plot: {str(e)}")
        # Close the half-built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
def create_residual_plot(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    title: str = "Residual Plot",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a two-panel residual figure for regression models.

    The left panel plots predicted against actual values with a
    perfect-prediction reference line. The right panel plots the residuals
    (y_true - y_pred) against the predictions.

    Args:
        y_true: True target values (any array-like)
        y_pred: Predicted values (any array-like with the same number of
            elements; column vectors such as shape (n, 1) are flattened)
        title: Overall figure title
        figsize: Size of ONE panel; the figure is twice as wide
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the inputs are invalid or the
        plot could not be created (the error is printed).
    """
    fig = None
    try:
        # Flatten to plain 1-D float arrays. This avoids broadcasting an
        # (n,) against an (n, 1) array into an (n, n) residual matrix, and
        # pandas aligning the two inputs by index instead of by position.
        y_true = np.asarray(y_true, dtype=float).ravel()
        y_pred = np.asarray(y_pred, dtype=float).ravel()
        if y_true.size != y_pred.size:
            raise ValueError(
                f"y_true has {y_true.size} values but y_pred has {y_pred.size}")
        if y_true.size == 0:
            raise ValueError("y_true and y_pred are empty")

        residuals = y_true - y_pred

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(figsize[0] * 2, figsize[1]))

        # Plot 1: Predicted vs Actual
        ax1.scatter(y_true, y_pred, alpha=0.5, s=50, edgecolors='black', linewidth=0.5)
        # Add perfect prediction line
        min_val = min(y_true.min(), y_pred.min())
        max_val = max(y_true.max(), y_pred.max())
        ax1.plot([min_val, max_val], [min_val, max_val],
                 'r--', linewidth=2, label='Perfect Prediction')
        ax1.set_xlabel('Actual Values', fontsize=12)
        ax1.set_ylabel('Predicted Values', fontsize=12)
        ax1.set_title('Predicted vs Actual', fontsize=13, fontweight='bold')
        ax1.legend()
        ax1.grid(True, alpha=0.3, linestyle='--')

        # Plot 2: Residuals vs Predicted
        ax2.scatter(y_pred, residuals, alpha=0.5, s=50,
                    color='steelblue', edgecolors='black', linewidth=0.5)
        ax2.axhline(y=0, color='red', linestyle='--', linewidth=2)
        ax2.set_xlabel('Predicted Values', fontsize=12)
        ax2.set_ylabel('Residuals', fontsize=12)
        ax2.set_title('Residuals vs Predicted', fontsize=13, fontweight='bold')
        ax2.grid(True, alpha=0.3, linestyle='--')

        fig.suptitle(title, fontsize=14, fontweight='bold', y=1.02)
        # Lay out this figure explicitly instead of pyplot's "current" one.
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" β Saved residual plot to {save_path}")
        return fig
    except Exception as e:
        print(f" β Error creating residual plot: {str(e)}")
        # Close the half-built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
def create_learning_curve(
    train_sizes: np.ndarray,
    train_scores_mean: np.ndarray,
    train_scores_std: np.ndarray,
    val_scores_mean: np.ndarray,
    val_scores_std: np.ndarray,
    title: str = "Learning Curve",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a learning curve showing training and validation scores.

    Each curve is drawn with a shaded band of +/- one standard deviation.

    Args:
        train_sizes: Array of training set sizes
        train_scores_mean: Mean training scores
        train_scores_std: Std of training scores
        val_scores_mean: Mean validation scores
        val_scores_std: Std of validation scores
        title: Plot title
        figsize: Figure size
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the plot could not be created
        (the error is printed and the partially built figure is closed).
    """
    fig = None
    try:
        # Convert up front so plain lists support the element-wise
        # mean +/- std arithmetic used for the shaded bands.
        train_sizes = np.asarray(train_sizes)
        train_scores_mean = np.asarray(train_scores_mean, dtype=float)
        train_scores_std = np.asarray(train_scores_std, dtype=float)
        val_scores_mean = np.asarray(val_scores_mean, dtype=float)
        val_scores_std = np.asarray(val_scores_std, dtype=float)

        fig, ax = plt.subplots(figsize=figsize)

        # Plot training scores
        ax.plot(train_sizes, train_scores_mean, 'o-', color='blue',
                linewidth=2, markersize=8, label='Training Score')
        ax.fill_between(train_sizes,
                        train_scores_mean - train_scores_std,
                        train_scores_mean + train_scores_std,
                        alpha=0.2, color='blue')

        # Plot validation scores
        ax.plot(train_sizes, val_scores_mean, 'o-', color='orange',
                linewidth=2, markersize=8, label='Validation Score')
        ax.fill_between(train_sizes,
                        val_scores_mean - val_scores_std,
                        val_scores_mean + val_scores_std,
                        alpha=0.2, color='orange')

        ax.set_xlabel('Training Set Size', fontsize=12)
        ax.set_ylabel('Score', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.legend(loc='best', fontsize=11, framealpha=0.9)
        ax.grid(True, alpha=0.3, linestyle='--')
        # Lay out this figure explicitly instead of pyplot's "current" one.
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" β Saved learning curve to {save_path}")
        return fig
    except Exception as e:
        print(f" β Error creating learning curve: {str(e)}")
        # Close the half-built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
| # ============================================================================ | |
| # DATA QUALITY PLOTS | |
| # ============================================================================ | |
def create_missing_values_heatmap(
    df: pd.DataFrame,
    title: str = "Missing Values Heatmap",
    figsize: Tuple[int, int] = (12, 8),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a heatmap showing missing values pattern.

    Features are laid out along the y axis and samples along the x axis.
    Cells use the two ends of the reversed 'RdYlGn' colormap: 0 for present
    values and 1 for missing values.

    Args:
        df: DataFrame to analyze
        title: Plot title
        figsize: Figure size
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the plot could not be created
        (the error is printed and the partially built figure is closed).
    """
    fig = None
    try:
        # Create binary matrix (1 = missing, 0 = present)
        missing_matrix = df.isnull().astype(int)

        fig, ax = plt.subplots(figsize=figsize)

        # Plot heatmap. The color range is pinned to [0, 1] so that a frame
        # with no missing values and a frame with only missing values do
        # not both autoscale to the same color.
        sns.heatmap(missing_matrix.T,
                    cbar=False,
                    cmap='RdYlGn_r',
                    vmin=0,
                    vmax=1,
                    ax=ax,
                    yticklabels=df.columns)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel('Sample Index', fontsize=12)
        ax.set_ylabel('Features', fontsize=12)
        # Lay out this figure explicitly instead of pyplot's "current" one.
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" β Saved missing values heatmap to {save_path}")
        return fig
    except Exception as e:
        print(f" β Error creating missing values heatmap: {str(e)}")
        # Close the half-built figure so pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
def create_missing_values_bar(
    df: pd.DataFrame,
    title: str = "Missing Values by Column",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> Optional["plt.Figure"]:
    """
    Create a horizontal bar chart of the percentage of missing values per column.

    Only columns that contain at least one missing value are shown, sorted by
    missing percentage. Each bar is annotated with its percentage.

    Args:
        df: DataFrame to analyze
        title: Plot title
        figsize: Figure size (width, height) in inches; the height is enlarged
            automatically when there are many columns to show
        save_path: Optional save path; missing parent directories are created

    Returns:
        matplotlib Figure object, or None if there are no missing values or
        the chart could not be created (a message is printed in both cases)
    """
    fig = None
    try:
        # Calculate missing percentages, keeping only affected columns
        missing_pct = (df.isnull().sum() / len(df) * 100).sort_values(ascending=False)
        missing_pct = missing_pct[missing_pct > 0]

        if missing_pct.empty:
            print(" ℹ No missing values found")
            return None

        values = missing_pct.to_numpy()
        positions = np.arange(len(values))

        # Respect the requested height, but allow 0.3 inch per bar so long
        # lists of columns stay readable.
        height = max(figsize[1], len(values) * 0.3)
        fig, ax = plt.subplots(figsize=(figsize[0], height))

        # Darker red = larger share of missing values
        colors = plt.cm.Reds(values / 100)
        ax.barh(positions, values, color=colors, edgecolor='black', linewidth=0.7)

        ax.set_yticks(positions)
        ax.set_yticklabels(missing_pct.index)
        ax.set_xlabel('Missing Values (%)', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.grid(True, alpha=0.3, linestyle='--', axis='x')

        # Leave headroom on the right so the labels stay inside the axes.
        ax.set_xlim(0, values.max() * 1.15)

        # Offset the labels in points rather than data units: a fixed data
        # offset pushes them far outside the axes when percentages are small.
        for pos, pct in zip(positions, values):
            ax.annotate(
                f'{pct:.1f}%',
                xy=(pct, pos),
                xytext=(4, 0),
                textcoords='offset points',
                va='center',
                fontsize=10,
            )

        fig.tight_layout()

        if save_path:
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved missing values bar chart to {save_path}")
        return fig
    except Exception as e:
        print(f" ✗ Error creating missing values bar chart: {e}")
        # Release the half-built figure; pyplot would otherwise keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
def create_outlier_detection_boxplot(
    df: pd.DataFrame,
    columns: Optional[List[str]] = None,
    title: str = "Outlier Detection",
    figsize: Tuple[int, int] = (12, 6),
    save_path: Optional[str] = None,
    max_columns: Optional[int] = 10
) -> Optional["plt.Figure"]:
    """
    Create box plots for outlier detection across multiple columns.

    The selected columns are handed to ``create_boxplot``, which does the
    actual drawing and saving.

    Args:
        df: DataFrame with numeric columns
        columns: Columns to plot (None = the first ``max_columns`` numeric
            columns of ``df``)
        title: Plot title
        figsize: Figure size
        save_path: Optional save path
        max_columns: Cap applied only when ``columns`` is None; pass None to
            plot every numeric column

    Returns:
        Whatever ``create_boxplot`` returns for the selection, or None when
        there is nothing to plot or an error occurs
    """
    try:
        if columns is None:
            numeric = df.select_dtypes(include=[np.number]).columns.tolist()
            columns = numeric[:max_columns]

        # Nothing selected: say so instead of asking for a plot of an empty frame.
        if len(columns) == 0:
            print(" ℹ No numeric columns to plot")
            return None

        return create_boxplot(df[columns], title=title, figsize=figsize, save_path=save_path)
    except Exception as e:
        print(f" ✗ Error creating outlier detection plot: {e}")
        return None
def create_skewness_plot(
    df: pd.DataFrame,
    title: str = "Feature Skewness Distribution",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> Optional["plt.Figure"]:
    """
    Create a horizontal bar chart showing the skewness of numeric features.

    Bars are coloured by magnitude: green for |skew| < 0.5, orange for
    0.5 <= |skew| < 1 and red for |skew| >= 1. Columns whose skewness is
    undefined (NaN, e.g. fewer than three valid values) are left out.

    Args:
        df: DataFrame with numeric columns
        title: Plot title
        figsize: Figure size (width, height) in inches; the height is enlarged
            automatically when there are many columns to show
        save_path: Optional save path; missing parent directories are created

    Returns:
        matplotlib Figure object, or None if no column has a defined skewness
        or the plot could not be created (a message is printed in both cases)
    """
    fig = None
    try:
        # Calculate skewness for numeric columns. Undefined (NaN) results are
        # dropped: they would otherwise fall through to the "red / High"
        # colour bucket and show up as an empty row.
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        skewness = df[numeric_cols].skew().dropna().sort_values(ascending=False)

        if skewness.empty:
            print(" ℹ No numeric columns to analyze")
            return None

        # Respect the requested height, but allow 0.3 inch per bar.
        height = max(figsize[1], len(skewness) * 0.3)
        fig, ax = plt.subplots(figsize=(figsize[0], height))

        # Color by skewness level
        magnitude = skewness.abs().to_numpy()
        colors = np.where(
            magnitude < 0.5, 'green',
            np.where(magnitude < 1, 'orange', 'red')
        ).tolist()

        positions = range(len(skewness))
        ax.barh(positions, skewness.to_numpy(),
                color=colors, edgecolor='black', linewidth=0.7)

        ax.set_yticks(positions)
        ax.set_yticklabels(skewness.index)
        ax.set_xlabel('Skewness', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)

        # Reference lines: zero skew and the +/-0.5 "low skew" band
        ax.axvline(x=0, color='black', linestyle='-', linewidth=1)
        for bound in (-0.5, 0.5):
            ax.axvline(x=bound, color='gray', linestyle='--', linewidth=1, alpha=0.5)
        ax.grid(True, alpha=0.3, linestyle='--', axis='x')

        # Add legend
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor='green', label='Low (|skew| < 0.5)'),
            Patch(facecolor='orange', label='Moderate (0.5 ≤ |skew| < 1)'),
            Patch(facecolor='red', label='High (|skew| ≥ 1)')
        ]
        ax.legend(handles=legend_elements, loc='best')

        fig.tight_layout()

        if save_path:
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved skewness plot to {save_path}")
        return fig
    except Exception as e:
        print(f" ✗ Error creating skewness plot: {e}")
        # Release the half-built figure; pyplot would otherwise keep it alive.
        if fig is not None:
            plt.close(fig)
        return None
| # ============================================================================ | |
| # UTILITY FUNCTIONS | |
| # ============================================================================ | |
def save_figure(fig: Optional["plt.Figure"], path: str, dpi: int = 300) -> None:
    """
    Save a matplotlib figure to file.

    Failures are reported on stdout rather than raised, so a failed save
    never interrupts the caller.

    Args:
        fig: Matplotlib Figure object; None is reported as an error and
            nothing is written
        path: Output file path (supports .png, .jpg, .pdf, .svg); missing
            parent directories are created
        dpi: Resolution (dots per inch)
    """
    # Check before touching the filesystem, so that a missing figure does not
    # leave empty directories behind.
    if fig is None:
        print(" ✗ Error saving figure: no figure to save (got None)")
        return

    try:
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(path, dpi=dpi, bbox_inches='tight', facecolor='white')
        print(f" ✓ Saved figure to {path}")
    except Exception as e:
        print(f" ✗ Error saving figure: {e}")
def close_figure(fig: plt.Figure) -> None:
    """
    Release the memory held by a matplotlib figure by closing it.

    Args:
        fig: Matplotlib Figure object; None is accepted and ignored
    """
    # Nothing to release when no figure was passed in.
    if fig is None:
        return
    plt.close(fig)
def create_subplots_grid(
    plot_data: List[Dict[str, Any]],
    rows: int,
    cols: int,
    figsize: Tuple[int, int] = (15, 12),
    title: str = "Plot Grid",
    save_path: Optional[str] = None
) -> Optional["plt.Figure"]:
    """
    Create a grid of subplots.

    Each entry of ``plot_data`` fills one panel, row by row. Supported
    ``'type'`` values are ``'scatter'`` (keys ``'x'``, ``'y'``), ``'hist'``
    (key ``'data'``) and ``'line'`` (keys ``'x'``, ``'y'``); the type defaults
    to ``'scatter'``. An optional ``'title'`` key names the panel. Panels
    without a spec are hidden.

    Args:
        plot_data: List of dicts with plot specifications
        rows: Number of rows (must be >= 1)
        cols: Number of columns (must be >= 1)
        figsize: Figure size
        title: Overall title
        save_path: Optional save path; missing parent directories are created

    Returns:
        matplotlib Figure object, or None if the grid could not be created
        (an error message is printed in that case)

    Example:
        >>> plots = [
        ...     {'type': 'scatter', 'x': x1, 'y': y1, 'title': 'Plot 1'},
        ...     {'type': 'hist', 'data': data1, 'title': 'Plot 2'}
        ... ]
        >>> fig = create_subplots_grid(plots, 2, 2)
    """
    fig = None
    try:
        if rows < 1 or cols < 1:
            raise ValueError(
                f"rows and cols must be positive (got rows={rows}, cols={cols})"
            )

        # squeeze=False always yields a 2-D array of axes, so a 1x1 grid
        # needs no special case.
        fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
        axes = axes.ravel()

        if len(plot_data) > len(axes):
            print(f" ℹ Only the first {len(axes)} of {len(plot_data)} plots "
                  f"fit in a {rows}x{cols} grid")

        for number, (ax, plot_spec) in enumerate(zip(axes, plot_data), start=1):
            plot_type = plot_spec.get('type', 'scatter')

            if plot_type == 'scatter':
                ax.scatter(plot_spec['x'], plot_spec['y'], alpha=0.6)
            elif plot_type == 'hist':
                ax.hist(plot_spec['data'], bins=30, edgecolor='black')
            elif plot_type == 'line':
                ax.plot(plot_spec['x'], plot_spec['y'])
            else:
                # Keep the (empty) panel, but tell the caller why it is empty.
                print(f" ℹ Unknown plot type '{plot_type}' for subplot "
                      f"{number}; leaving it empty")

            ax.set_title(plot_spec.get('title', f'Subplot {number}'), fontweight='bold')
            ax.grid(True, alpha=0.3)

        # Hide unused subplots
        for ax in axes[len(plot_data):]:
            ax.axis('off')

        fig.suptitle(title, fontsize=16, fontweight='bold', y=0.995)
        fig.tight_layout()

        if save_path:
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" ✓ Saved subplot grid to {save_path}")
        return fig
    except Exception as e:
        print(f" ✗ Error creating subplot grid: {e}")
        # A malformed spec can fail after the figure exists; close it so
        # pyplot does not keep it alive.
        if fig is not None:
            plt.close(fig)
        return None