TWLab's picture
Add publication-ready ML project structure with full source code
e2b220f verified
"""
Publication-quality visualization module.
Generates figures suitable for peer-reviewed journals:
- Predicted vs Actual parity plots
- Residual analysis
- Feature importance (XGBoost gain + SHAP)
- Model comparison charts
- Training curves
- Per-material performance breakdown
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics import r2_score
logger = logging.getLogger(__name__)

# Publication style defaults, applied to matplotlib's rcParams by
# set_publication_style().
STYLE_CONFIG = {
    # Base and per-element font sizes (points).
    "font.size": 11,
    "axes.titlesize": 12,
    "axes.labelsize": 11,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
    # Resolution for on-canvas rendering and for saved files.
    "figure.dpi": 300,
    "savefig.dpi": 300,
    # Crop saved figures to their drawn content.
    "savefig.bbox": "tight",
    # Typeface family used for all text.
    "font.family": "serif",
}
def set_publication_style():
    """Apply the module's publication rcParams and the "Set2" seaborn palette.

    This mutates global state: matplotlib's ``rcParams`` and seaborn's active
    colour palette. It therefore affects every figure created afterwards in
    the same process.
    """
    overrides = dict(STYLE_CONFIG)
    plt.rcParams.update(overrides)
    sns.set_palette("Set2")
def plot_predicted_vs_actual(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    target_names: List[str],
    model_name: str = "Ensemble",
    save_path: Optional[Path] = None,
    figsize: Tuple[float, float] = (14, 10),
) -> plt.Figure:
    """
    Create predicted-vs-actual parity plots, one subplot per target.

    Each subplot contains:
    - a scatter of predictions against ground truth;
    - the y = x "perfect prediction" line;
    - a shaded ±10 % relative band;
    - the R² score in the title.

    Subplots are laid out on a grid with three columns, and unused grid
    cells are hidden.

    Parameters
    ----------
    y_true, y_pred : np.ndarray, shape (n_samples, n_targets) or (n_samples,)
        Ground-truth and predicted values. 1-D arrays are treated as a single
        target column.
    target_names : list of str
        One name per target column. Names found in the internal ``units``
        table get a human-readable axis label; others are used verbatim.
    model_name : str
        Prefix of the figure title.
    save_path : Path, optional
        If given, the figure is written to this path.
    figsize : tuple
        Figure size in inches.

    Returns
    -------
    matplotlib.Figure

    Raises
    ------
    ValueError
        If ``target_names`` is empty, ``y_true`` and ``y_pred`` have
        different shapes, or the arrays have fewer columns than targets.
    """
    set_publication_style()
    n_targets = len(target_names)
    if n_targets == 0:
        raise ValueError("target_names must contain at least one target")

    # Promote 1-D inputs (single target) to a column so [:, i] indexing works.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.ndim == 1:
        y_true = y_true[:, np.newaxis]
    if y_pred.ndim == 1:
        y_pred = y_pred[:, np.newaxis]
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred shapes differ: {y_true.shape} vs {y_pred.shape}"
        )
    if y_true.shape[1] < n_targets:
        raise ValueError(
            f"{n_targets} target names given but arrays have only "
            f"{y_true.shape[1]} column(s)"
        )

    n_cols = 3
    n_rows = (n_targets + n_cols - 1) // n_cols
    # squeeze=False always yields a 2-D array of Axes, whatever the grid size.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    # Target-specific units for axis labels
    units = {
        "etch_depth_um": "Etch Depth (µm)",
        "etch_width_um": "Etch Width (µm)",
        "surface_roughness_Sa_um": "Surface Roughness Sa (µm)",
        "aspect_ratio": "Aspect Ratio",
        "side_wall_angle_deg": "Side Wall Angle (°)",
    }

    for i, target in enumerate(target_names):
        ax = axes[i]
        yt, yp = y_true[:, i], y_pred[:, i]
        r2 = r2_score(yt, yp)
        ax.scatter(yt, yp, alpha=0.4, s=12, edgecolors="none", c="steelblue")

        # Shared, square limits so the identity line is the diagonal.
        lim_min = min(yt.min(), yp.min())
        lim_max = max(yt.max(), yp.max())
        span = lim_max - lim_min
        # A zero span (all values identical) would give identical axis limits;
        # fall back to padding proportional to the magnitude instead.
        margin = span * 0.05 if span > 0 else max(abs(lim_max), 1.0) * 0.05
        lo, hi = lim_min - margin, lim_max + margin

        # Perfect prediction line
        ax.plot([lo, hi], [lo, hi], "r--", lw=1.5, label="Perfect prediction")

        # ±10% relative error band around y = x
        x_range = np.linspace(lim_min, lim_max, 100)
        ax.fill_between(x_range, x_range * 0.9, x_range * 1.1,
                        alpha=0.1, color="red", label="±10% band")

        label = units.get(target, target)
        ax.set_xlabel(f"Actual {label}")
        ax.set_ylabel(f"Predicted {label}")
        ax.set_title(f"{label}\nR² = {r2:.4f}", fontweight="bold")
        ax.set_xlim(lo, hi)
        ax.set_ylim(lo, hi)
        ax.set_aspect("equal")
        if i == 0:
            # The legend entries are identical in every panel; show them once.
            ax.legend(loc="lower right", fontsize=9)

    # Hide unused subplots
    for ax in axes[n_targets:]:
        ax.set_visible(False)

    fig.suptitle(f"{model_name} — Predicted vs Actual", fontsize=14, fontweight="bold", y=1.02)
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, bbox_inches="tight")
        logger.info("Saved: %s", save_path)
    return fig
def plot_residual_analysis(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    target_names: List[str],
    save_path: Optional[Path] = None,
) -> plt.Figure:
    """
    Residual diagnostics: one column of two panels per target.

    The top row is a density-normalised histogram of residuals
    (``y_true - y_pred``) with a reference line at zero. The bottom row plots
    residuals against predicted values with a zero line, to reveal bias or
    heteroscedasticity. No Q-Q plot is drawn.

    Parameters
    ----------
    y_true, y_pred : np.ndarray, shape (n_samples, n_targets) or (n_samples,)
        Ground-truth and predicted values. 1-D arrays are treated as a single
        target column.
    target_names : list of str
        One name per target column; underscores become spaces in titles.
    save_path : Path, optional
        If given, the figure is written to this path.

    Returns
    -------
    matplotlib.Figure

    Raises
    ------
    ValueError
        If ``target_names`` is empty, ``y_true`` and ``y_pred`` have
        different shapes, or the arrays have fewer columns than targets.
    """
    set_publication_style()
    n_targets = len(target_names)
    if n_targets == 0:
        raise ValueError("target_names must contain at least one target")

    # Promote 1-D inputs (single target) to a column so [:, i] indexing works.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.ndim == 1:
        y_true = y_true[:, np.newaxis]
    if y_pred.ndim == 1:
        y_pred = y_pred[:, np.newaxis]
    if y_true.shape != y_pred.shape:
        raise ValueError(
            f"y_true and y_pred shapes differ: {y_true.shape} vs {y_pred.shape}"
        )
    if y_true.shape[1] < n_targets:
        raise ValueError(
            f"{n_targets} target names given but arrays have only "
            f"{y_true.shape[1]} column(s)"
        )

    # squeeze=False keeps axes 2-D (2 x n_targets) even for a single target,
    # so no special-casing of the indexing is needed.
    fig, axes = plt.subplots(2, n_targets, figsize=(4 * n_targets, 8), squeeze=False)

    for i, target in enumerate(target_names):
        residuals = y_true[:, i] - y_pred[:, i]

        # Histogram of residuals
        ax = axes[0, i]
        ax.hist(residuals, bins=50, density=True, alpha=0.7, color="steelblue", edgecolor="white")
        ax.axvline(0, color="red", linestyle="--", lw=1)
        ax.set_xlabel("Residual")
        ax.set_ylabel("Density")
        ax.set_title(target.replace("_", " "))

        # Residual vs predicted
        ax2 = axes[1, i]
        ax2.scatter(y_pred[:, i], residuals, alpha=0.3, s=8, c="steelblue")
        ax2.axhline(0, color="red", linestyle="--", lw=1)
        ax2.set_xlabel("Predicted")
        ax2.set_ylabel("Residual")

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, bbox_inches="tight")
        logger.info("Saved: %s", save_path)
    return fig
def plot_feature_importance(
    importances: Dict[str, Dict[str, float]],
    top_n: int = 10,
    save_path: Optional[Path] = None,
    title: str = "Feature Importance by Target (XGBoost Gain)",
) -> plt.Figure:
    """
    Annotated heatmap of the top-ranked features' importance per target.

    Features are ranked by their mean importance across all targets. Missing
    values are skipped when averaging, and ties keep input order. The
    ``top_n`` highest-ranked features are shown, most important first, with
    one column per target.

    Parameters
    ----------
    importances : dict
        {target_name: {feature_name: importance_value}}
    top_n : int
        Number of features to show (must be >= 1).
    save_path : Path, optional
        If given, the figure is written to this path.
    title : str
        Axes title. The default keeps the XGBoost-gain wording; pass a
        different title when plotting another importance measure (e.g. SHAP).

    Returns
    -------
    matplotlib.Figure

    Raises
    ------
    ValueError
        If ``top_n`` < 1 or ``importances`` contains no values.
    """
    set_publication_style()
    if top_n < 1:
        raise ValueError(f"top_n must be >= 1, got {top_n}")

    # Outer dict keys become columns (targets); inner keys become rows (features).
    imp_df = pd.DataFrame(importances)
    if imp_df.empty:
        raise ValueError("importances contains no feature values")

    # Rank on a separate Series rather than adding a "Mean" column, so a
    # target that happens to be named "Mean" is neither overwritten nor dropped.
    mean_importance = imp_df.mean(axis=1).sort_values(ascending=False, kind="stable")
    top_features = imp_df.loc[mean_importance.index[:top_n]]

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        top_features,
        annot=True, fmt=".3f",
        cmap="YlOrRd",
        linewidths=0.5,
        ax=ax,
    )
    ax.set_title(title, fontweight="bold")
    ax.set_xlabel("Target Variable")
    ax.set_ylabel("Feature")
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, bbox_inches="tight")
        logger.info("Saved: %s", save_path)
    return fig
def plot_model_comparison(
    results_dict: Dict[str, pd.DataFrame],
    metric: str = "R²",
    target_names: Optional[List[str]] = None,
    save_path: Optional[Path] = None,
) -> plt.Figure:
    """
    Grouped bar chart comparing models on one metric across targets.

    Parameters
    ----------
    results_dict : dict
        {model_name: metrics_DataFrame}. Each DataFrame is indexed by target
        name and has a column named ``metric``.
    metric : str
        Which metric column to plot.
    target_names : list, optional
        Targets to show, in order. If omitted, the union of all DataFrame
        indices is used, in first-seen order.
    save_path : Path, optional
        If given, the figure is written to this path.

    Returns
    -------
    matplotlib.Figure

    Raises
    ------
    ValueError
        If ``results_dict`` is empty.
    """
    set_publication_style()
    if not results_dict:
        raise ValueError("results_dict must contain at least one model")

    if target_names is None:
        # Ordered union of every model's targets (dict keys preserve insertion order).
        target_names = list(dict.fromkeys(
            t for df in results_dict.values() for t in df.index
        ))

    n_models = len(results_dict)
    fig, ax = plt.subplots(figsize=(12, 6))
    x = np.arange(len(target_names))
    width = 0.8 / n_models  # the whole group of bars spans 80% of a tick slot
    colors = sns.color_palette("Set2", n_models)

    for j, (model_name, df) in enumerate(results_dict.items()):
        # NaN draws no bar for targets a model lacks; 0 would look like a real score.
        values = [df.loc[t, metric] if t in df.index else np.nan for t in target_names]
        # Centre the group of n_models bars on each tick.
        offset = (j - n_models / 2 + 0.5) * width
        ax.bar(x + offset, values, width, label=model_name, color=colors[j], edgecolor="white")

    ax.set_xlabel("Target Variable")
    ax.set_ylabel(metric)
    ax.set_title(f"Model Comparison — {metric}", fontweight="bold")
    ax.set_xticks(x)
    ax.set_xticklabels([str(t).replace("_", "\n") for t in target_names], fontsize=9)
    ax.legend()
    ax.grid(axis="y", alpha=0.3)
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, bbox_inches="tight")
        logger.info("Saved: %s", save_path)
    return fig
def plot_training_curves(
    train_losses: List[float],
    val_losses: List[float],
    save_path: Optional[Path] = None,
    ylabel: str = "MSE Loss",
) -> plt.Figure:
    """
    Plot neural network training and validation loss curves.

    Both lists are assumed to hold one value per epoch, starting at epoch 1.
    Each curve gets its own epoch axis, so the lists may differ in length
    (e.g. when validation stopped early). An empty ``val_losses`` plots only
    the training curve.

    The y axis is logarithmic when every finite loss is positive. Otherwise a
    linear scale is used, because a log axis cannot show zero or negative
    values.

    Parameters
    ----------
    train_losses : list of float
        Training loss per epoch.
    val_losses : list of float
        Validation loss per epoch (may be empty).
    save_path : Path, optional
        If given, the figure is written to this path.
    ylabel : str
        Y-axis label; defaults to "MSE Loss".

    Returns
    -------
    matplotlib.Figure
    """
    set_publication_style()
    fig, ax = plt.subplots(figsize=(8, 5))

    ax.plot(range(1, len(train_losses) + 1), train_losses,
            label="Training Loss", color="steelblue", lw=1.5)
    if len(val_losses) > 0:  # len() rather than truthiness: also safe for numpy arrays
        ax.plot(range(1, len(val_losses) + 1), val_losses,
                label="Validation Loss", color="darkorange", lw=1.5)

    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title("Neural Network Training Curves", fontweight="bold")

    all_losses = np.concatenate([
        np.asarray(train_losses, dtype=float).ravel(),
        np.asarray(val_losses, dtype=float).ravel(),
    ])
    finite = all_losses[np.isfinite(all_losses)]
    if np.all(finite > 0):
        ax.set_yscale("log")

    ax.legend()
    ax.grid(alpha=0.3)
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, bbox_inches="tight")
        logger.info("Saved: %s", save_path)
    return fig
def plot_per_material_performance(
    material_results: pd.DataFrame,
    metric_prefix: str = "R²",
    target_names: Optional[List[str]] = None,
    save_path: Optional[Path] = None,
) -> plt.Figure:
    """
    Grouped bar chart showing performance per material type.

    Columns of ``material_results`` whose (string) name starts with
    ``metric_prefix`` are plotted, one bar group per row. If none match, any
    column containing "R²" is used instead. The ``"{metric_prefix}_"`` and
    ``"R²_"`` prefixes are stripped from column names to form legend labels.

    Parameters
    ----------
    material_results : pd.DataFrame
        One row per material. Presumably indexed by material type, since the
        x axis is labelled that way — confirm with callers.
    metric_prefix : str
        Prefix selecting the metric columns (e.g. "R²", "RMSE").
    target_names : list of str, optional
        If given, only these targets (matched after prefix stripping) are
        shown, in this order.
    save_path : Path, optional
        If given, the figure is written to this path.

    Returns
    -------
    matplotlib.Figure

    Raises
    ------
    ValueError
        If no metric columns are found, or none of ``target_names`` match.
    """
    set_publication_style()

    # Filter columns for the metric
    cols = [c for c in material_results.columns
            if isinstance(c, str) and c.startswith(metric_prefix)]
    metric_label = metric_prefix
    if not cols:
        # Fall back to R² columns; the axis label must then say R² as well.
        cols = [c for c in material_results.columns
                if isinstance(c, str) and "R²" in c]
        metric_label = "R²"
    if not cols:
        raise ValueError(
            f"No columns starting with {metric_prefix!r} (or containing 'R²') found"
        )

    plot_df = material_results[cols].copy()
    plot_df.columns = [c.replace(f"{metric_prefix}_", "").replace("R²_", "") for c in cols]

    if target_names is not None:
        keep = [t for t in target_names if t in plot_df.columns]
        if not keep:
            raise ValueError(f"None of target_names {target_names!r} found in {list(plot_df.columns)!r}")
        plot_df = plot_df[keep]

    fig, ax = plt.subplots(figsize=(12, 6))
    plot_df.plot(kind="bar", ax=ax, colormap="Set2", edgecolor="white")
    ax.set_xlabel("Material Type")
    ax.set_ylabel(metric_label)
    ax.set_title("Model Performance by Material Type", fontweight="bold")
    ax.legend(title="Target", bbox_to_anchor=(1.02, 1), loc="upper left", fontsize=9)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    ax.grid(axis="y", alpha=0.3)

    if metric_label.startswith("R²"):
        # R² is at most 1 but can be negative. Extend the lower limit so
        # poorly fitted materials stay visible instead of being clipped at 0.
        values = plot_df.to_numpy(dtype=float)
        finite = values[np.isfinite(values)]
        bottom = min(0.0, 1.05 * float(finite.min())) if finite.size else 0.0
        ax.set_ylim(bottom, 1.05)
    # Other metrics (e.g. RMSE) are not bounded by 1, so keep matplotlib's autoscaling.

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, bbox_inches="tight")
        logger.info("Saved: %s", save_path)
    return fig