Data-Science-Agent / src /tools /matplotlib_visualizations.py
Pulastya B
feat: Initial commit - Data Science Agent with React frontend and FastAPI backend
226ac39
"""
Matplotlib + Seaborn Visualization Engine
Production-quality visualizations that work reliably with Gradio UI.
All plotting functions return matplotlib Figure objects (not file paths), or None when a plot cannot be created.
Designed for publication-quality plots with professional styling.
"""
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend for Gradio compatibility
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from typing import Dict, Any, List, Optional, Tuple, Union
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')
# Global look-and-feel shared by every plot in this module:
# seaborn's white grid theme plus white backgrounds and fixed font sizes.
sns.set_style('whitegrid')
plt.rcParams.update({
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
    'font.size': 10,
    'axes.labelsize': 12,
    'axes.titlesize': 14,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
})
# ============================================================================
# BASIC PLOTS
# ============================================================================
def create_scatter_plot(
    x: Union[np.ndarray, pd.Series, list],
    y: Union[np.ndarray, pd.Series, list],
    hue: Optional[Union[np.ndarray, pd.Series, list]] = None,
    size: Optional[Union[np.ndarray, pd.Series, list]] = None,
    title: str = "Scatter Plot",
    xlabel: str = "X",
    ylabel: str = "Y",
    figsize: Tuple[int, int] = (10, 6),
    alpha: float = 0.6,
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a professional scatter plot with optional color coding and size variation.

    Args:
        x: X-axis data
        y: Y-axis data
        hue: Optional categorical data for color coding (one legend entry per unique value)
        size: Optional numeric data for size variation; either one value per
            point or a single scalar applied to every point
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size (width, height)
        alpha: Point transparency (0-1)
        save_path: Optional path to save PNG file

    Returns:
        matplotlib Figure object, or None if the plot could not be created
        (the error is printed and the partially built figure is closed).

    Example:
        >>> fig = create_scatter_plot(df['feature1'], df['target'],
        ...                           hue=df['category'], title='Feature vs Target')
        >>> # Display in Gradio: gr.Plot(value=fig)
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=figsize)

        # Convert inputs to arrays so boolean masks can be applied uniformly
        x = np.asarray(x)
        y = np.asarray(y)
        # Convert the sizes once instead of once per hue group
        sizes = None if size is None else np.asarray(size)

        if hue is not None:
            hue = np.asarray(hue)
            unique_hues = np.unique(hue)
            colors = sns.color_palette('Set2', n_colors=len(unique_hues))

            for color, hue_val in zip(colors, unique_hues):
                mask = hue == hue_val
                if sizes is None:
                    scatter_size = 50
                elif sizes.ndim == 0:
                    # A scalar size cannot be masked; use it for every point
                    scatter_size = sizes.item()
                else:
                    scatter_size = sizes[mask]
                ax.scatter(x[mask], y[mask],
                           c=[color],
                           s=scatter_size,
                           alpha=alpha,
                           label=str(hue_val),
                           edgecolors='black',
                           linewidth=0.5)
            ax.legend(title='Category', loc='best', framealpha=0.9)
        else:
            if sizes is None:
                scatter_size = 50
            elif sizes.ndim == 0:
                scatter_size = sizes.item()
            else:
                scatter_size = sizes
            ax.scatter(x, y,
                       c='steelblue',
                       s=scatter_size,
                       alpha=alpha,
                       edgecolors='black',
                       linewidth=0.5)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.grid(True, alpha=0.3, linestyle='--')

        # Lay out this specific figure rather than whatever pyplot considers current
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" βœ“ Saved scatter plot to {save_path}")

        return fig
    except Exception as e:
        # Close the half-built figure so failed calls do not accumulate in pyplot's registry
        if fig is not None:
            plt.close(fig)
        print(f" βœ— Error creating scatter plot: {str(e)}")
        return None
def create_line_plot(
    x: Union[np.ndarray, pd.Series, list],
    y: Union[Dict[str, np.ndarray], np.ndarray, pd.Series, list],
    title: str = "Line Plot",
    xlabel: str = "X",
    ylabel: str = "Y",
    figsize: Tuple[int, int] = (10, 6),
    markers: bool = True,
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Draw one line, or several lines sharing the same x values.

    Args:
        x: X-axis data
        y: Y-axis data; pass a dict ({'label1': y1, 'label2': y2}) to draw
            one labelled line per entry
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size
        markers: Show circular markers on the lines
        save_path: Optional save path

    Returns:
        matplotlib Figure object (None is returned if an error occurs)
    """
    try:
        fig, ax = plt.subplots(figsize=figsize)
        x_values = np.array(x)

        # Styling shared by every line drawn on this axes
        line_style = dict(
            marker='o' if markers else None,
            linewidth=2,
            markersize=6,
        )

        if not isinstance(y, dict):
            ax.plot(x_values, np.array(y), color='steelblue', **line_style)
        else:
            palette = sns.color_palette('husl', n_colors=len(y))
            for line_color, (series_label, series_values) in zip(palette, y.items()):
                ax.plot(x_values, np.array(series_values),
                        label=series_label,
                        color=line_color,
                        **line_style)
            ax.legend(loc='best', framealpha=0.9)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.grid(True, alpha=0.3, linestyle='--')
        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" βœ“ Saved line plot to {save_path}")
        return fig
    except Exception as err:
        print(f" βœ— Error creating line plot: {str(err)}")
        return None
def create_bar_chart(
    categories: Union[list, np.ndarray],
    values: Union[np.ndarray, pd.Series, list],
    title: str = "Bar Chart",
    xlabel: str = "Category",
    ylabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    horizontal: bool = False,
    color: str = 'steelblue',
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a bar chart (vertical or horizontal).

    Args:
        categories: Category names
        values: Values for each category
        title: Plot title
        xlabel: Label for the category axis
        ylabel: Label for the value axis
        figsize: Figure size
        horizontal: If True, create horizontal bars (the two axis labels are
            swapped so each still describes its own data)
        color: Bar color
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the plot could not be created
        (the error is printed and the partially built figure is closed).
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=figsize)

        categories = list(categories)
        values = np.asarray(values)

        if horizontal:
            ax.barh(categories, values, color=color, edgecolor='black', linewidth=0.7)
            ax.set_xlabel(ylabel, fontsize=12)
            ax.set_ylabel(xlabel, fontsize=12)
        else:
            ax.bar(categories, values, color=color, edgecolor='black', linewidth=0.7)
            ax.set_xlabel(xlabel, fontsize=12)
            ax.set_ylabel(ylabel, fontsize=12)

            # Category names sit on the x-axis only in the vertical layout;
            # rotate them on this axes (not via pyplot's "current axes") when crowded.
            if len(categories) > 10:
                plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        # Grid lines only along the value axis
        ax.grid(True, alpha=0.3, linestyle='--', axis='y' if not horizontal else 'x')

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" βœ“ Saved bar chart to {save_path}")

        return fig
    except Exception as e:
        # Close the half-built figure so failed calls do not accumulate in pyplot's registry
        if fig is not None:
            plt.close(fig)
        print(f" βœ— Error creating bar chart: {str(e)}")
        return None
def create_histogram(
    data: Union[np.ndarray, pd.Series, list],
    title: str = "Histogram",
    xlabel: str = "Value",
    ylabel: str = "Frequency",
    bins: int = 30,
    figsize: Tuple[int, int] = (10, 6),
    kde: bool = True,
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a histogram with optional KDE overlay.

    Args:
        data: Numeric data to plot; NaN (and None) entries are dropped
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        bins: Number of bins
        figsize: Figure size
        kde: Show kernel density estimate on a secondary y-axis; when True
            the histogram itself is drawn normalized (density=True)
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if there is no valid data or the
        plot could not be created.
    """
    fig = None
    try:
        # Casting to float turns None into NaN, so object-dtype input
        # (e.g. a list containing None) does not break np.isnan.
        data = np.asarray(data, dtype=float)
        data = data[~np.isnan(data)]  # Remove NaN values

        # Validate before creating the figure so this early exit leaks nothing
        if len(data) == 0:
            print(" βœ— No valid data for histogram")
            return None

        fig, ax = plt.subplots(figsize=figsize)

        # Create histogram
        ax.hist(data, bins=bins, color='steelblue',
                edgecolor='black', alpha=0.7, density=kde)

        # Add KDE if requested
        if kde:
            ax2 = ax.twinx()
            sns.kdeplot(data, ax=ax2, color='darkred', linewidth=2, label='KDE')
            ax2.set_ylabel('Density', fontsize=12)
            ax2.legend(loc='upper right')

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.grid(True, alpha=0.3, linestyle='--')

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" βœ“ Saved histogram to {save_path}")

        return fig
    except Exception as e:
        # Close the half-built figure so failed calls do not accumulate in pyplot's registry
        if fig is not None:
            plt.close(fig)
        print(f" βœ— Error creating histogram: {str(e)}")
        return None
def create_boxplot(
    data: Union[Dict[str, np.ndarray], pd.DataFrame],
    title: str = "Box Plot",
    xlabel: str = "Category",
    ylabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    horizontal: bool = False,
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create box plots for multiple columns/categories.

    Args:
        data: Dictionary of {column_name: values} or DataFrame (one box per
            key/column; NaN values are dropped per box)
        title: Plot title
        xlabel: Label for the category axis
        ylabel: Label for the value axis
        figsize: Figure size
        horizontal: If True, create horizontal boxplots (axis labels are swapped)
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the plot could not be created
        (the error is printed and the partially built figure is closed).
    """
    fig = None
    try:
        if isinstance(data, pd.DataFrame):
            data_to_plot = [data[col].dropna() for col in data.columns]
            labels = [str(col) for col in data.columns]
        elif isinstance(data, dict):
            data_to_plot = []
            for values in data.values():
                # Float cast maps None to NaN so the NaN filter works on object input
                arr = np.asarray(values, dtype=float)
                data_to_plot.append(arr[~np.isnan(arr)])
            labels = [str(key) for key in data.keys()]
        else:
            raise ValueError("Data must be DataFrame or dict")

        fig, ax = plt.subplots(figsize=figsize)

        bp = ax.boxplot(data_to_plot,
                        vert=not horizontal,
                        patch_artist=True,
                        notch=True,
                        showmeans=True)

        # Label the boxes through the axis instead of boxplot's `labels=` keyword,
        # which Matplotlib 3.9 deprecated in favour of `tick_labels`; setting the
        # tick labels directly works on both older and newer releases.
        if horizontal:
            ax.set_yticklabels(labels)
        else:
            ax.set_xticklabels(labels)

        # Styling
        for patch in bp['boxes']:
            patch.set_facecolor('lightblue')
            patch.set_alpha(0.7)
        for whisker in bp['whiskers']:
            whisker.set(linewidth=1.5, color='gray')
        for cap in bp['caps']:
            cap.set(linewidth=1.5, color='gray')
        for median in bp['medians']:
            median.set(linewidth=2, color='darkred')
        for mean in bp['means']:
            mean.set(marker='D', markerfacecolor='green', markersize=6)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        if horizontal:
            ax.set_xlabel(ylabel, fontsize=12)
            ax.set_ylabel(xlabel, fontsize=12)
        else:
            ax.set_xlabel(xlabel, fontsize=12)
            ax.set_ylabel(ylabel, fontsize=12)
            # Category names are on the x-axis only in the vertical layout
            if len(labels) > 8:
                plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        # Grid lines only along the value axis
        ax.grid(True, alpha=0.3, linestyle='--', axis='y' if not horizontal else 'x')

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" βœ“ Saved boxplot to {save_path}")

        return fig
    except Exception as e:
        # Close the half-built figure so failed calls do not accumulate in pyplot's registry
        if fig is not None:
            plt.close(fig)
        print(f" βœ— Error creating boxplot: {str(e)}")
        return None
# ============================================================================
# STATISTICAL PLOTS
# ============================================================================
def create_correlation_heatmap(
    data: Union[pd.DataFrame, np.ndarray],
    columns: Optional[List[str]] = None,
    title: str = "Correlation Heatmap",
    figsize: Tuple[int, int] = (12, 10),
    annot: bool = True,
    cmap: str = 'RdBu_r',
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a correlation heatmap with annotations.

    Args:
        data: A DataFrame of raw features (correlations are computed over its
            numeric/boolean columns) or an ndarray that already holds a
            correlation matrix
        columns: Row/column names (only used when data is an ndarray)
        title: Plot title
        figsize: Figure size
        annot: Show correlation values as annotations
        cmap: Colormap (diverging, centered at 0)
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the plot could not be created
        (for example when there are no numeric columns).

    Example:
        >>> fig = create_correlation_heatmap(df[numeric_cols])
    """
    fig = None
    try:
        if isinstance(data, pd.DataFrame):
            # Restrict to numeric columns explicitly: since pandas 2.0,
            # DataFrame.corr() raises on string/object columns instead of
            # silently skipping them.
            numeric_data = data.select_dtypes(include=[np.number, 'bool'])
            corr_matrix = numeric_data.corr()
        else:
            corr_matrix = pd.DataFrame(data, columns=columns, index=columns)

        if corr_matrix.empty:
            raise ValueError("No numeric columns available for correlation")

        fig, ax = plt.subplots(figsize=figsize)

        # Hide the upper triangle (diagonal included); it mirrors the lower one
        mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
        sns.heatmap(corr_matrix,
                    mask=mask,
                    annot=annot,
                    fmt='.2f',
                    cmap=cmap,
                    center=0,
                    square=True,
                    linewidths=0.5,
                    cbar_kws={'shrink': 0.8, 'label': 'Correlation'},
                    ax=ax,
                    vmin=-1,
                    vmax=1)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" βœ“ Saved correlation heatmap to {save_path}")

        return fig
    except Exception as e:
        # Close the half-built figure so failed calls do not accumulate in pyplot's registry
        if fig is not None:
            plt.close(fig)
        print(f" βœ— Error creating correlation heatmap: {str(e)}")
        return None
def create_distribution_plot(
    data: Union[np.ndarray, pd.Series, list],
    title: str = "Distribution Plot",
    xlabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    show_rug: bool = False,
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a distribution plot with histogram and KDE.

    The plot also marks the mean and median with vertical lines and shows
    mean / median / standard deviation in a text box.

    Args:
        data: Numeric data to plot; NaN (and None) entries are dropped
        title: Plot title
        xlabel: X-axis label
        figsize: Figure size
        show_rug: Show rug plot (data points on x-axis)
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if there is no valid data or the
        plot could not be created.
    """
    fig = None
    try:
        # Casting to float turns None into NaN, so object-dtype input
        # (e.g. a list containing None) does not break np.isnan.
        data = np.asarray(data, dtype=float)
        data = data[~np.isnan(data)]

        # Validate before creating the figure so this early exit leaks nothing
        if len(data) == 0:
            print(" βœ— No valid data for distribution plot")
            return None

        fig, ax = plt.subplots(figsize=figsize)

        # Create distribution plot
        sns.histplot(data, kde=True, ax=ax, color='steelblue',
                     edgecolor='black', alpha=0.6, bins=30)

        if show_rug:
            sns.rugplot(data, ax=ax, color='darkred', alpha=0.5, height=0.05)

        # Add statistics text in the top-right corner (axes coordinates)
        mean_val = np.mean(data)
        median_val = np.median(data)
        std_val = np.std(data)
        stats_text = f'Mean: {mean_val:.2f}\nMedian: {median_val:.2f}\nStd: {std_val:.2f}'
        ax.text(0.98, 0.98, stats_text,
                transform=ax.transAxes,
                verticalalignment='top',
                horizontalalignment='right',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5),
                fontsize=10)

        # Add vertical lines for mean and median
        ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label='Mean')
        ax.axvline(median_val, color='green', linestyle='--', linewidth=2, label='Median')

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel('Frequency / Density', fontsize=12)
        ax.legend(loc='upper left')
        ax.grid(True, alpha=0.3, linestyle='--')

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" βœ“ Saved distribution plot to {save_path}")

        return fig
    except Exception as e:
        # Close the half-built figure so failed calls do not accumulate in pyplot's registry
        if fig is not None:
            plt.close(fig)
        print(f" βœ— Error creating distribution plot: {str(e)}")
        return None
def create_violin_plot(
    data: Union[Dict[str, np.ndarray], pd.DataFrame],
    title: str = "Violin Plot",
    xlabel: str = "Category",
    ylabel: str = "Value",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create violin plots showing distribution for multiple categories.

    Args:
        data: One of
            * a dict of {category_name: values};
            * a long-format DataFrame with 'Category' and 'Value' columns;
            * a wide DataFrame, in which case every numeric column becomes
              one category (same input shape create_boxplot accepts).
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        figsize: Figure size
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the plot could not be created
        (the error is printed and the partially built figure is closed).
    """
    fig = None
    try:
        if isinstance(data, dict):
            # Convert dict to a long-format DataFrame for seaborn
            plot_df = pd.concat(
                [
                    pd.DataFrame({'Category': [key] * len(values), 'Value': values})
                    for key, values in data.items()
                ],
                ignore_index=True,
            )
        elif {'Category', 'Value'}.issubset(data.columns):
            # Already in the long format seaborn is asked to plot below
            plot_df = data
        else:
            # Wide DataFrame: reshape numeric columns into Category/Value pairs
            plot_df = data.select_dtypes(include=[np.number]).melt(
                var_name='Category', value_name='Value'
            )

        fig, ax = plt.subplots(figsize=figsize)

        # Create violin plot
        sns.violinplot(data=plot_df, x='Category', y='Value', ax=ax,
                       palette='Set2', inner='box')

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel(xlabel, fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)

        # Rotate crowded category labels on this axes (not pyplot's current axes)
        if len(plot_df['Category'].unique()) > 8:
            plt.setp(ax.get_xticklabels(), rotation=45, ha='right')

        ax.grid(True, alpha=0.3, linestyle='--', axis='y')

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" βœ“ Saved violin plot to {save_path}")

        return fig
    except Exception as e:
        # Close the half-built figure so failed calls do not accumulate in pyplot's registry
        if fig is not None:
            plt.close(fig)
        print(f" βœ— Error creating violin plot: {str(e)}")
        return None
def create_pairplot(
    data: pd.DataFrame,
    hue: Optional[str] = None,
    title: str = "Pair Plot",
    figsize: Tuple[int, int] = (12, 12),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a pairplot (scatterplot matrix) for multiple features.

    Only the lower triangle is drawn (corner=True) with KDE plots on the diagonal.

    Args:
        data: DataFrame with features to plot
        hue: Column name for color coding; ignored if it is not a column of data
        title: Plot title
        figsize: Overall figure size (width, height) in inches, applied to the
            figure seaborn creates
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the plot could not be created.
    """
    fig = None
    try:
        pairplot_kwargs = dict(palette='Set2', diag_kind='kde', corner=True)
        if hue and hue in data.columns:
            pairplot_kwargs['hue'] = hue

        # Seaborn pairplot returns a PairGrid, we need to extract the figure
        pair_grid = sns.pairplot(data, **pairplot_kwargs)

        # `PairGrid.fig` is deprecated in favour of `.figure`; fall back to the
        # old attribute only for seaborn releases that predate `.figure`.
        fig = getattr(pair_grid, 'figure', None)
        if fig is None:
            fig = pair_grid.fig

        # Honour the documented figsize argument (it was previously ignored)
        fig.set_size_inches(figsize[0], figsize[1])

        fig.suptitle(title, fontsize=14, fontweight='bold', y=1.01)
        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" βœ“ Saved pairplot to {save_path}")

        return fig
    except Exception as e:
        # Close the figure if seaborn got far enough to hand one back
        if fig is not None:
            plt.close(fig)
        print(f" βœ— Error creating pairplot: {str(e)}")
        return None
# ============================================================================
# MACHINE LEARNING PLOTS
# ============================================================================
def create_roc_curve(
    models_data: Dict[str, Tuple[np.ndarray, np.ndarray, float]],
    title: str = "ROC Curve Comparison",
    figsize: Tuple[int, int] = (10, 8),
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Overlay the ROC curves of one or more models on a single axes.

    A dashed grey diagonal is added as the random-classifier baseline.

    Args:
        models_data: Dict of {model_name: (fpr, tpr, auc_score)}
        title: Plot title
        figsize: Figure size
        save_path: Optional save path

    Returns:
        matplotlib Figure object (None is returned if an error occurs)

    Example:
        >>> from sklearn.metrics import roc_curve, auc
        >>> fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
        >>> auc_score = auc(fpr, tpr)
        >>> models = {'Random Forest': (fpr, tpr, auc_score)}
        >>> fig = create_roc_curve(models)
    """
    try:
        fig, ax = plt.subplots(figsize=figsize)

        # One palette colour per model, paired with the models in dict order
        palette = sns.color_palette('husl', n_colors=len(models_data))
        for curve_color, (model_name, curve) in zip(palette, models_data.items()):
            fpr, tpr, auc_score = curve
            curve_label = f'{model_name} (AUC = {auc_score:.3f})'
            ax.plot(fpr, tpr, linewidth=2.5, label=curve_label, color=curve_color)

        # Baseline: a classifier that guesses at random
        ax.plot([0, 1], [0, 1],
                linestyle='--',
                linewidth=2,
                color='gray',
                label='Random Classifier (AUC = 0.500)')

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel('False Positive Rate', fontsize=12)
        ax.set_ylabel('True Positive Rate', fontsize=12)
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.legend(loc='lower right', fontsize=10, framealpha=0.9)
        ax.grid(True, alpha=0.3, linestyle='--')
        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" βœ“ Saved ROC curve to {save_path}")
        return fig
    except Exception as err:
        print(f" βœ— Error creating ROC curve: {str(err)}")
        return None
def create_confusion_matrix(
    cm: np.ndarray,
    class_names: Optional[List[str]] = None,
    title: str = "Confusion Matrix",
    figsize: Tuple[int, int] = (10, 8),
    show_percentages: bool = True,
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a confusion matrix heatmap with annotations.

    Args:
        cm: Confusion matrix (from sklearn.metrics.confusion_matrix); nested
            lists are accepted as well as arrays
        class_names: Names of classes (defaults to 'Class 0', 'Class 1', ...)
        title: Plot title
        figsize: Figure size
        show_percentages: Show row-wise percentages in addition to counts.
            Rows that sum to zero are shown as 0.0% rather than NaN.
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the plot could not be created
        (the error is printed and the partially built figure is closed).

    Example:
        >>> from sklearn.metrics import confusion_matrix
        >>> cm = confusion_matrix(y_true, y_pred)
        >>> fig = create_confusion_matrix(cm, class_names=['Class 0', 'Class 1'])
    """
    fig = None
    try:
        cm = np.asarray(cm)

        if class_names is None:
            class_names = [f'Class {i}' for i in range(len(cm))]

        if show_percentages:
            # Row-normalize, skipping rows with no samples so they read 0.0%
            # instead of dividing by zero and printing "nan%".
            row_totals = cm.sum(axis=1, keepdims=True).astype(float)
            cm_percent = np.divide(
                cm * 100.0,
                row_totals,
                out=np.zeros(cm.shape, dtype=float),
                where=row_totals != 0,
            )
            annotations = np.array([[f'{count}\n({percent:.1f}%)'
                                     for count, percent in zip(row_counts, row_percents)]
                                    for row_counts, row_percents in zip(cm, cm_percent)])
        else:
            annotations = cm

        fig, ax = plt.subplots(figsize=figsize)

        # fmt='' because the annotations are already formatted strings (or raw counts)
        sns.heatmap(cm,
                    annot=annotations,
                    fmt='',
                    cmap='Blues',
                    square=True,
                    linewidths=0.5,
                    cbar_kws={'label': 'Count'},
                    xticklabels=class_names,
                    yticklabels=class_names,
                    ax=ax)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_ylabel('Actual', fontsize=12)
        ax.set_xlabel('Predicted', fontsize=12)

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" βœ“ Saved confusion matrix to {save_path}")

        return fig
    except Exception as e:
        # Close the half-built figure so failed calls do not accumulate in pyplot's registry
        if fig is not None:
            plt.close(fig)
        print(f" βœ— Error creating confusion matrix: {str(e)}")
        return None
def create_precision_recall_curve(
    models_data: Dict[str, Tuple[np.ndarray, np.ndarray, float]],
    title: str = "Precision-Recall Curve",
    figsize: Tuple[int, int] = (10, 8),
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Overlay the precision-recall curves of one or more models.

    Recall is drawn on the x-axis and precision on the y-axis.

    Args:
        models_data: Dict of {model_name: (precision, recall, avg_precision)}
        title: Plot title
        figsize: Figure size
        save_path: Optional save path

    Returns:
        matplotlib Figure object (None is returned if an error occurs)
    """
    try:
        fig, ax = plt.subplots(figsize=figsize)

        # One palette colour per model, paired with the models in dict order
        palette = sns.color_palette('husl', n_colors=len(models_data))
        for curve_color, (model_name, curve) in zip(palette, models_data.items()):
            precision, recall, avg_precision = curve
            curve_label = f'{model_name} (AP = {avg_precision:.3f})'
            ax.plot(recall, precision, linewidth=2.5, label=curve_label, color=curve_color)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel('Recall', fontsize=12)
        ax.set_ylabel('Precision', fontsize=12)
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.legend(loc='best', fontsize=10, framealpha=0.9)
        ax.grid(True, alpha=0.3, linestyle='--')
        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" βœ“ Saved precision-recall curve to {save_path}")
        return fig
    except Exception as err:
        print(f" βœ— Error creating precision-recall curve: {str(err)}")
        return None
def create_feature_importance(
    feature_names: List[str],
    importances: np.ndarray,
    title: str = "Feature Importance",
    top_n: int = 20,
    figsize: Tuple[int, int] = (10, 8),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a horizontal bar chart of feature importances.

    Features are sorted by importance value (largest first); bars are green
    for non-negative values and red for negative ones.

    Args:
        feature_names: List of feature names
        importances: Importance values, one per feature (array or list; a
            (1, n) array such as a linear model's coef_ is flattened)
        title: Plot title
        top_n: Number of top features to show
        figsize: Figure size; the height is treated as a minimum and grows
            with the number of bars actually drawn
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the plot could not be created
        (the error is printed and the partially built figure is closed).

    Example:
        >>> importances = model.feature_importances_
        >>> fig = create_feature_importance(feature_names, importances, top_n=15)
    """
    fig = None
    try:
        # Accept lists and 2-D single-row arrays, not just 1-D ndarrays
        importances = np.asarray(importances, dtype=float).ravel()

        # Sort by importance, largest first, and keep the top_n entries
        indices = np.argsort(importances)[::-1][:top_n]
        sorted_features = [feature_names[i] for i in indices]
        sorted_importances = importances[indices]

        # Size the figure from the bars that are really drawn (not from top_n,
        # which may far exceed the feature count) and respect the requested height.
        height = max(figsize[1], len(sorted_features) * 0.4)
        fig, ax = plt.subplots(figsize=(figsize[0], height))

        # Color bars by positive/negative (if any negative values)
        colors = ['green' if x >= 0 else 'red' for x in sorted_importances]

        # Create horizontal bar chart
        y_pos = np.arange(len(sorted_features))
        ax.barh(y_pos, sorted_importances, color=colors, edgecolor='black', linewidth=0.7)
        ax.set_yticks(y_pos)
        ax.set_yticklabels(sorted_features)
        ax.invert_yaxis()  # Top features at top

        ax.set_xlabel('Importance Score', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.grid(True, alpha=0.3, linestyle='--', axis='x')

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" βœ“ Saved feature importance to {save_path}")

        return fig
    except Exception as e:
        # Close the half-built figure so failed calls do not accumulate in pyplot's registry
        if fig is not None:
            plt.close(fig)
        print(f" βœ— Error creating feature importance plot: {str(e)}")
        return None
def create_residual_plot(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    title: str = "Residual Plot",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a two-panel regression diagnostic: Predicted vs Actual, and
    Residuals vs Predicted.

    Residuals are computed as y_true - y_pred, element by element.

    Args:
        y_true: True target values (array, list or Series)
        y_pred: Predicted values, same length as y_true
        title: Overall figure title
        figsize: Size of ONE panel; the figure is twice this width
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the plot could not be created
        (the error is printed and the partially built figure is closed).
    """
    fig = None
    try:
        # Flatten to plain 1-D float arrays so that:
        #  - lists work (they have no .min()/.max() and do not subtract),
        #  - pandas Series are paired by position, not aligned on index labels,
        #  - an (n, 1) prediction does not broadcast against (n,) into n x n.
        y_true = np.asarray(y_true, dtype=float).ravel()
        y_pred = np.asarray(y_pred, dtype=float).ravel()
        if y_true.shape != y_pred.shape:
            raise ValueError(
                f"y_true and y_pred must have the same length "
                f"(got {y_true.size} and {y_pred.size})"
            )

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(figsize[0] * 2, figsize[1]))

        residuals = y_true - y_pred

        # Plot 1: Predicted vs Actual
        ax1.scatter(y_true, y_pred, alpha=0.5, s=50, edgecolors='black', linewidth=0.5)

        # Add perfect prediction line spanning the combined data range
        min_val = min(y_true.min(), y_pred.min())
        max_val = max(y_true.max(), y_pred.max())
        ax1.plot([min_val, max_val], [min_val, max_val],
                 'r--', linewidth=2, label='Perfect Prediction')

        ax1.set_xlabel('Actual Values', fontsize=12)
        ax1.set_ylabel('Predicted Values', fontsize=12)
        ax1.set_title('Predicted vs Actual', fontsize=13, fontweight='bold')
        ax1.legend()
        ax1.grid(True, alpha=0.3, linestyle='--')

        # Plot 2: Residuals vs Predicted
        ax2.scatter(y_pred, residuals, alpha=0.5, s=50,
                    color='steelblue', edgecolors='black', linewidth=0.5)
        ax2.axhline(y=0, color='red', linestyle='--', linewidth=2)

        ax2.set_xlabel('Predicted Values', fontsize=12)
        ax2.set_ylabel('Residuals', fontsize=12)
        ax2.set_title('Residuals vs Predicted', fontsize=13, fontweight='bold')
        ax2.grid(True, alpha=0.3, linestyle='--')

        fig.suptitle(title, fontsize=14, fontweight='bold', y=1.02)

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" βœ“ Saved residual plot to {save_path}")

        return fig
    except Exception as e:
        # Close the half-built figure so failed calls do not accumulate in pyplot's registry
        if fig is not None:
            plt.close(fig)
        print(f" βœ— Error creating residual plot: {str(e)}")
        return None
def create_learning_curve(
    train_sizes: np.ndarray,
    train_scores_mean: np.ndarray,
    train_scores_std: np.ndarray,
    val_scores_mean: np.ndarray,
    val_scores_std: np.ndarray,
    title: str = "Learning Curve",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Plot training and validation scores against training set size.

    Each curve is drawn as a marked line with a shaded band of
    mean +/- one standard deviation.

    Args:
        train_sizes: Array of training set sizes
        train_scores_mean: Mean training scores
        train_scores_std: Std of training scores
        val_scores_mean: Mean validation scores
        val_scores_std: Std of validation scores
        title: Plot title
        figsize: Figure size
        save_path: Optional save path

    Returns:
        matplotlib Figure object (None is returned if an error occurs)
    """
    try:
        fig, ax = plt.subplots(figsize=figsize)

        # (mean, std, colour, legend label) for each curve, training first
        curves = (
            (train_scores_mean, train_scores_std, 'blue', 'Training Score'),
            (val_scores_mean, val_scores_std, 'orange', 'Validation Score'),
        )
        for mean_scores, std_scores, curve_color, curve_label in curves:
            ax.plot(train_sizes, mean_scores, 'o-', color=curve_color,
                    linewidth=2, markersize=8, label=curve_label)
            ax.fill_between(train_sizes,
                            mean_scores - std_scores,
                            mean_scores + std_scores,
                            alpha=0.2, color=curve_color)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel('Training Set Size', fontsize=12)
        ax.set_ylabel('Score', fontsize=12)
        ax.legend(loc='best', fontsize=11, framealpha=0.9)
        ax.grid(True, alpha=0.3, linestyle='--')
        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" βœ“ Saved learning curve to {save_path}")
        return fig
    except Exception as err:
        print(f" βœ— Error creating learning curve: {str(err)}")
        return None
# ============================================================================
# DATA QUALITY PLOTS
# ============================================================================
def create_missing_values_heatmap(
    df: pd.DataFrame,
    title: str = "Missing Values Heatmap",
    figsize: Tuple[int, int] = (12, 8),
    save_path: Optional[str] = None
) -> Optional[plt.Figure]:
    """
    Create a heatmap showing missing values pattern.

    Each row of the heatmap is a DataFrame column (feature) and each column
    of the heatmap is a sample; missing cells use the high end of the
    'RdYlGn_r' colormap and present cells the low end.

    Args:
        df: DataFrame to analyze
        title: Plot title
        figsize: Figure size
        save_path: Optional save path

    Returns:
        matplotlib Figure object, or None if the plot could not be created
        (the error is printed and the partially built figure is closed).
    """
    fig = None
    try:
        fig, ax = plt.subplots(figsize=figsize)

        # Create binary matrix (1 = missing, 0 = present)
        missing_matrix = df.isnull().astype(int)

        # Pin the colour range to [0, 1]. With an auto-scaled range a frame
        # that is entirely missing (all ones) collapses to vmin == vmax and is
        # drawn in the same colour as a frame with nothing missing.
        sns.heatmap(missing_matrix.T,
                    cbar=False,
                    cmap='RdYlGn_r',
                    ax=ax,
                    yticklabels=df.columns,
                    vmin=0,
                    vmax=1)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel('Sample Index', fontsize=12)
        ax.set_ylabel('Features', fontsize=12)

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" βœ“ Saved missing values heatmap to {save_path}")

        return fig
    except Exception as e:
        # Close the half-built figure so failed calls do not accumulate in pyplot's registry
        if fig is not None:
            plt.close(fig)
        print(f" βœ— Error creating missing values heatmap: {str(e)}")
        return None
def create_missing_values_bar(
    df: pd.DataFrame,
    title: str = "Missing Values by Column",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Draw a horizontal bar per column that has missing values, sized by the
    percentage of rows that are null in that column.

    Args:
        df: DataFrame to analyze
        title: Plot title
        figsize: Figure size (only the width is used; the height scales with
            the number of bars, with a minimum of 6)
        save_path: Optional save path

    Returns:
        matplotlib Figure object; None if no column has missing values or an
        error occurs
    """
    try:
        # Percentage of null rows per column, largest first, zero-free
        null_share = df.isnull().sum() / len(df) * 100
        missing_pct = null_share.sort_values(ascending=False)
        missing_pct = missing_pct[missing_pct > 0]

        if missing_pct.empty:
            print(" β„Ή No missing values found")
            return None

        bar_positions = range(len(missing_pct))
        fig_height = max(6, len(missing_pct) * 0.3)
        fig, ax = plt.subplots(figsize=(figsize[0], fig_height))

        # Darker red means a larger share of missing values
        bar_colors = plt.cm.Reds(missing_pct / 100)
        ax.barh(bar_positions, missing_pct.values,
                color=bar_colors, edgecolor='black', linewidth=0.7)
        ax.set_yticks(bar_positions)
        ax.set_yticklabels(missing_pct.index)

        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
        ax.set_xlabel('Missing Values (%)', fontsize=12)
        ax.grid(True, alpha=0.3, linestyle='--', axis='x')

        # Write the percentage just to the right of each bar
        for position, pct in enumerate(missing_pct.values):
            ax.text(pct + 1, position, f'{pct:.1f}%', va='center', fontsize=10)

        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" βœ“ Saved missing values bar chart to {save_path}")
        return fig
    except Exception as err:
        print(f" βœ— Error creating missing values bar chart: {str(err)}")
        return None
def create_outlier_detection_boxplot(
    df: pd.DataFrame,
    columns: Optional[List[str]] = None,
    title: str = "Outlier Detection",
    figsize: Tuple[int, int] = (12, 6),
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Box-plot a set of DataFrame columns side by side to expose outliers.

    Thin wrapper that selects the columns and delegates to create_boxplot.

    Args:
        df: DataFrame with numeric columns
        columns: Columns to plot; when None, the first 10 numeric columns are used
        title: Plot title
        figsize: Figure size
        save_path: Optional save path

    Returns:
        Whatever create_boxplot returns for the selected columns; None if an
        error occurs here
    """
    try:
        selected = columns
        if selected is None:
            # Default selection: numeric columns only, capped at ten
            numeric_frame = df.select_dtypes(include=[np.number])
            selected = numeric_frame.columns.tolist()[:10]

        return create_boxplot(
            df[selected],
            title=title,
            figsize=figsize,
            save_path=save_path,
        )
    except Exception as err:
        print(f" βœ— Error creating outlier detection plot: {str(err)}")
        return None
def create_skewness_plot(
    df: pd.DataFrame,
    title: str = "Feature Skewness Distribution",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Draw a horizontal bar per numeric column showing its skewness.

    Bars are colored by |skew|: green below 0.5, orange below 1, red otherwise.
    Reference lines are drawn at 0 and at +/-0.5.

    Args:
        df: DataFrame with numeric columns
        title: Plot title
        figsize: Figure size (only the width is used; the height scales with
            the number of bars, with a minimum of 6)
        save_path: Optional save path

    Returns:
        matplotlib Figure object; None if there are no numeric columns or an
        error occurs
    """
    def _severity_color(skew_value):
        # Map a skewness value onto the traffic-light scale used in the legend
        magnitude = abs(skew_value)
        if magnitude < 0.5:
            return 'green'
        if magnitude < 1:
            return 'orange'
        return 'red'

    try:
        # Skewness of every numeric column, most positive first
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        skewness = df[numeric_cols].skew().sort_values(ascending=False)

        if skewness.empty:
            print(" β„Ή No numeric columns to analyze")
            return None

        bar_positions = range(len(skewness))
        fig_height = max(6, len(skewness) * 0.3)
        fig, ax = plt.subplots(figsize=(figsize[0], fig_height))

        bar_colors = [_severity_color(value) for value in skewness.values]
        ax.barh(bar_positions, skewness.values,
                color=bar_colors, edgecolor='black', linewidth=0.7)
        ax.set_yticks(bar_positions)
        ax.set_yticklabels(skewness.index)

        ax.set_xlabel('Skewness', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold', pad=20)

        # Solid line at zero, dashed guides at the "low skew" thresholds
        ax.axvline(x=0, color='black', linestyle='-', linewidth=1)
        for threshold in (-0.5, 0.5):
            ax.axvline(x=threshold, color='gray', linestyle='--', linewidth=1, alpha=0.5)
        ax.grid(True, alpha=0.3, linestyle='--', axis='x')

        # Legend explaining the bar colors
        from matplotlib.patches import Patch
        legend_entries = (
            ('green', 'Low (|skew| < 0.5)'),
            ('orange', 'Moderate (0.5 ≀ |skew| < 1)'),
            ('red', 'High (|skew| β‰₯ 1)'),
        )
        legend_elements = [
            Patch(facecolor=face, label=text) for face, text in legend_entries
        ]
        ax.legend(handles=legend_elements, loc='best')

        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" βœ“ Saved skewness plot to {save_path}")
        return fig
    except Exception as err:
        print(f" βœ— Error creating skewness plot: {str(err)}")
        return None
# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================
def save_figure(fig: plt.Figure, path: str, dpi: int = 300) -> None:
    """
    Write a matplotlib figure to disk, creating parent directories as needed.

    Errors are printed rather than raised.

    Args:
        fig: Matplotlib Figure object
        path: Output file path (supports .png, .jpg, .pdf, .svg)
        dpi: Resolution (dots per inch)
    """
    try:
        # Make sure the destination directory exists before saving
        output_dir = Path(path).parent
        output_dir.mkdir(parents=True, exist_ok=True)

        save_options = dict(dpi=dpi, bbox_inches='tight', facecolor='white')
        fig.savefig(path, **save_options)
        print(f" βœ“ Saved figure to {path}")
    except Exception as err:
        print(f" βœ— Error saving figure: {str(err)}")
def close_figure(fig: plt.Figure) -> None:
    """
    Release a matplotlib figure via pyplot so its memory can be freed.

    Passing None is a no-op.

    Args:
        fig: Matplotlib Figure object
    """
    if fig is None:
        return
    plt.close(fig)
def create_subplots_grid(
    plot_data: List[Dict[str, Any]],
    rows: int,
    cols: int,
    figsize: Tuple[int, int] = (15, 12),
    title: str = "Plot Grid",
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Lay out several simple plots on a rows x cols grid.

    Each spec's 'type' selects the plot: 'scatter' (needs 'x', 'y'),
    'hist' (needs 'data') or 'line' (needs 'x', 'y'); it defaults to
    'scatter'. Specs with any other type leave their panel empty, and grid
    cells without a spec are hidden.

    Args:
        plot_data: List of dicts with plot specifications
        rows: Number of rows
        cols: Number of columns
        figsize: Figure size
        title: Overall title
        save_path: Optional save path

    Returns:
        matplotlib Figure object (None is returned if an error occurs)

    Example:
        >>> plots = [
        ...     {'type': 'scatter', 'x': x1, 'y': y1, 'title': 'Plot 1'},
        ...     {'type': 'hist', 'data': data1, 'title': 'Plot 2'}
        ... ]
        >>> fig = create_subplots_grid(plots, 2, 2)
    """
    try:
        # squeeze=False always yields a 2-D array of axes, so a single
        # flattening step covers the 1x1 grid as well as larger ones.
        fig, axes_grid = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
        panels = axes_grid.ravel()

        for position, (panel, spec) in enumerate(zip(panels, plot_data), start=1):
            kind = spec.get('type', 'scatter')

            if kind == 'scatter':
                panel.scatter(spec['x'], spec['y'], alpha=0.6)
            elif kind == 'hist':
                panel.hist(spec['data'], bins=30, edgecolor='black')
            elif kind == 'line':
                panel.plot(spec['x'], spec['y'])

            panel.set_title(spec.get('title', f'Subplot {position}'), fontweight='bold')
            panel.grid(True, alpha=0.3)

        # Hide grid cells that received no plot specification
        for spare_panel in panels[len(plot_data):]:
            spare_panel.axis('off')

        fig.suptitle(title, fontsize=16, fontweight='bold', y=0.995)
        plt.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f" βœ“ Saved subplot grid to {save_path}")
        return fig
    except Exception as err:
        print(f" βœ— Error creating subplot grid: {str(err)}")
        return None