# trace / plotting.py
# bingyan user
# Rebrand TRACE -> SPARK
# 8619a66
"""
Visualization utilities for the SPARK web app.
Generates matplotlib figures for mechanism classification probabilities,
parameter posteriors and summary tables, signal reconstruction overlays,
and surface concentration profiles.
"""
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
# Figure size and resolution installed as matplotlib defaults below, chosen so
# plots sit predictably inside the app's card width (~1080px). A plotting
# function can still pass its own figsize when it needs a wider or taller
# canvas (e.g. multi-panel grids); the plotting functions in this module all
# pass an explicit figsize, so these mainly affect figures created elsewhere.
BASE_FIGSIZE = (7.0, 4.2)
BASE_DPI = 110

# rc overrides applied once, at import time. matplotlib's rcParams are global,
# so every figure created afterwards in this process inherits them.
_RC_STYLE = {
    # Canvas size / resolution. Note that the plotting functions here also
    # call tight_layout() explicitly.
    "figure.figsize": BASE_FIGSIZE,
    "figure.dpi": BASE_DPI,
    "savefig.dpi": BASE_DPI,
    "figure.autolayout": True,
    # Typography.
    "font.size": 10.5,
    "axes.titlesize": 11.5,
    "axes.labelsize": 10.5,
    # Axes chrome: open top/right frame, muted slate colours.
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.edgecolor": "#94A3B8",
    "axes.labelcolor": "#0F172A",
    "xtick.color": "#475569",
    "ytick.color": "#475569",
    # Legends.
    "legend.frameon": False,
    "legend.fontsize": 9.5,
}
plt.rcParams.update(_RC_STYLE)
# General-purpose palette shared by all plot helpers.
COLORS = {
    "primary": "#2563EB",
    "secondary": "#7C3AED",
    "accent": "#059669",
    "warm": "#DC2626",
    "neutral": "#6B7280",
    "bg": "#F9FAFB",
}

# Per-mechanism colours. Every mechanism in the trained set (see the
# MECH_FULL_NAMES_* tables below) gets its own colour so bars and violins can
# be told apart. Mechanisms without an entry (e.g. the OOD ones) fall back
# to a helper-specific default colour.
MECH_COLORS_EC = {
    "Nernst": "#3B82F6",
    "BV": "#8B5CF6",
    "MHC": "#EC4899",
    "Ads": "#F59E0B",
    "EC": "#10B981",
    "LH": "#EF4444",
    "EE": "#06B6D4",
    "EC_prime": "#84CC16",
    "CE": "#F97316",
}
MECH_COLORS_TPD = {
    "FirstOrder": "#3B82F6",
    "SecondOrder": "#8B5CF6",
    "ZerothOrder": "#06B6D4",
    "LH_Surface": "#EC4899",
    "MvK": "#F59E0B",
    "FirstOrderCovDep": "#10B981",
    "DiffLimited": "#EF4444",
    "PrecursorMediated": "#84CC16",
    "Dissociative": "#F97316",
    "ActivatedAdsorption": "#6366F1",
    "TwoSite": "#A16207",
}

# Human-readable mechanism descriptions, shown next to the short keys in the
# classification bar chart.
MECH_FULL_NAMES_EC = {
    "Nernst": "Nernstian (reversible electron transfer)",
    "BV": "Butler–Volmer (quasi-reversible electron transfer)",
    "MHC": "Marcus–Hush–Chidsey (non-adiabatic electron transfer)",
    "Ads": "surface-adsorbed redox couple",
    "EC": "electron transfer followed by chemical step",
    "LH": "Langmuir–Hinshelwood surface reaction",
    "EE": "two sequential electron transfers",
    "EC_prime": "electron transfer with catalytic regeneration",
    "CE": "chemical step preceding electron transfer",
    # OOD (not in trained set)
    "ECE": "electron transfer–chemical step–electron transfer",
    "EC_LH": "EC followed by Langmuir–Hinshelwood",
    "MHC_EC": "Marcus–Hush–Chidsey followed by chemical step",
    "MHC_LH": "Marcus–Hush–Chidsey followed by Langmuir–Hinshelwood",
}
MECH_FULL_NAMES_TPD = {
    "FirstOrder": "1st-order desorption",
    "SecondOrder": "2nd-order recombinative desorption",
    "ZerothOrder": "0th-order (multilayer) desorption",
    "LH_Surface": "Langmuir–Hinshelwood surface reaction",
    "MvK": "Mars–van Krevelen (lattice oxygen)",
    "FirstOrderCovDep": "1st-order with coverage-dependent Ed",
    "DiffLimited": "diffusion-limited desorption",
    "PrecursorMediated": "precursor-mediated desorption",
    "Dissociative": "dissociative desorption",
    "ActivatedAdsorption": "desorption with activated re-adsorption",
    "TwoSite": "two-site (heterogeneous) desorption",
    # OOD (not in trained set)
    "EleyRideal": "Eley–Rideal surface reaction",
    "BimolecularCovDep": "bimolecular coverage-dependent desorption",
}
def plot_mechanism_probs(probs_dict, domain="ec"):
    """
    Draw a horizontal bar chart of mechanism classification probabilities.

    Bars are ordered by probability so the most likely mechanism sits at the
    top. Each y label reads "<key> (<long name>)"; the key itself is used as
    the long name when none is known. Bars above 5% get a percentage label
    just past their end.

    Args:
        probs_dict: {mechanism_name: probability}
        domain: 'ec' or 'tpd'; any value other than 'ec' selects the TPD
            colour and name tables.
    Returns:
        matplotlib Figure
    """
    use_ec = domain == "ec"
    palette = MECH_COLORS_EC if use_ec else MECH_COLORS_TPD
    long_names = MECH_FULL_NAMES_EC if use_ec else MECH_FULL_NAMES_TPD

    keys = list(probs_dict.keys())
    raw = [probs_dict[key] for key in keys]
    # Ascending order: barh places row 0 at the bottom, so the largest
    # probability ends up on top.
    order = np.argsort(raw)
    names = [keys[j] for j in order]
    probs = [raw[j] for j in order]

    rows = range(len(names))
    fig, ax = plt.subplots(figsize=(11, max(3, len(names) * 0.7)))
    bars = ax.barh(
        rows,
        probs,
        color=[palette.get(name, COLORS["neutral"]) for name in names],
        edgecolor="white",
        linewidth=0.5,
        height=0.7,
    )
    ax.set_yticks(rows)
    ax.set_yticklabels(
        [f"{name} ({long_names.get(name, name)})" for name in names],
        fontsize=11,
        fontweight="medium",
    )
    ax.set_xlim(0, 1.05)
    ax.set_xlabel("Probability", fontsize=12)
    ax.set_title("Mechanism Classification", fontsize=14, fontweight="bold", pad=15)

    for bar, prob in zip(bars, probs):
        if prob > 0.05:
            ax.text(
                bar.get_width() + 0.02,
                bar.get_y() + bar.get_height() / 2,
                f"{prob:.1%}",
                va="center",
                fontsize=11,
                fontweight="bold",
            )

    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    ax.grid(axis="x", alpha=0.3, linestyle="--")
    fig.tight_layout()
    return fig
def plot_posteriors(samples, param_names, mechanism_name, domain="ec"):
    """
    Violin plots of posterior distributions for each parameter.

    One panel is drawn per parameter. Each panel shows a violin of the
    samples, a black mean line, a red median line, and dashed lines at the
    5th / 95th percentiles. A shared legend explains these markers.

    Non-finite samples (NaN / inf) are dropped per parameter before plotting.
    A parameter whose remaining samples are all identical is drawn as flat
    mean/median lines instead of a violin. ``violinplot`` fits a Gaussian KDE,
    and that fit fails for zero-variance data, which would otherwise abort
    the whole figure. A parameter with no finite samples gets a short
    "no finite samples" note in place of the plot.

    Args:
        samples: [n_samples, D] array-like of posterior samples; column ``i``
            belongs to ``param_names[i]``. A 1-D array is treated as a single
            parameter column.
        param_names: list of parameter names (length D)
        mechanism_name: name of the mechanism (selects colour, shown in title)
        domain: 'ec' or 'tpd'
    Returns:
        matplotlib Figure
    """
    from matplotlib.lines import Line2D

    samples = np.asarray(samples, dtype=float)
    if samples.ndim == 1:
        samples = samples.reshape(-1, 1)
    n_params = len(param_names)
    # squeeze=False always yields a 2-D axes grid, so no 1-panel special case.
    fig, axes = plt.subplots(1, n_params, figsize=(max(4, 3 * n_params), 4.5),
                             squeeze=False)
    axes = axes[0]
    colors = MECH_COLORS_EC if domain == "ec" else MECH_COLORS_TPD
    color = colors.get(mechanism_name, COLORS["primary"])
    for i, (ax, name) in enumerate(zip(axes, param_names)):
        column = samples[:, i]
        data = column[np.isfinite(column)]
        if data.size == 0:
            ax.text(0.5, 0.5, "no finite samples", transform=ax.transAxes,
                    ha="center", va="center", fontsize=9,
                    color=COLORS["neutral"])
        else:
            if data.max() > data.min():
                parts = ax.violinplot(data, positions=[0], showmeans=True,
                                      showmedians=True, showextrema=False)
                for pc in parts["bodies"]:
                    pc.set_facecolor(color)
                    pc.set_alpha(0.6)
                parts["cmeans"].set_color("black")
                parts["cmedians"].set_color(COLORS["warm"])
            else:
                # Zero spread: no density to draw, so show only the
                # mean/median markers at roughly the violin's marker width.
                value = data[0]
                ax.hlines(value, -0.125, 0.125, colors="black")
                ax.hlines(value, -0.125, 0.125, colors=COLORS["warm"])
            q05, q95 = np.quantile(data, [0.05, 0.95])
            ax.axhline(q05, color=COLORS["neutral"], linestyle="--", alpha=0.5, linewidth=0.8)
            ax.axhline(q95, color=COLORS["neutral"], linestyle="--", alpha=0.5, linewidth=0.8)
            ax.text(0.5, 0.02, f"mean={data.mean():.3f}", transform=ax.transAxes,
                    ha="center", fontsize=9, color=COLORS["neutral"])
        ax.set_title(_format_param_name(name), fontsize=11, fontweight="medium")
        ax.set_xticks([])
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["bottom"].set_visible(False)
    fig.suptitle(f"Parameter Posteriors — {mechanism_name}",
                 fontsize=14, fontweight="bold")
    # Shared legend explaining the mean / median markers and 90% interval.
    legend_handles = [
        Line2D([0], [0], color="black", lw=1.5, label="mean"),
        Line2D([0], [0], color=COLORS["warm"], lw=1.5, label="median"),
        Line2D([0], [0], color=COLORS["neutral"], lw=0.8, ls="--",
               label="5th / 95th percentile"),
    ]
    fig.legend(
        handles=legend_handles,
        loc="lower center",
        ncol=3,
        frameon=False,
        fontsize=9,
        bbox_to_anchor=(0.5, -0.02),
    )
    # Leave room at the bottom for the legend and at the top for the suptitle.
    fig.tight_layout(rect=[0, 0.04, 1, 0.93])
    return fig
def _finite_metric(values, i):
    """Return ``values[i]`` as a float, or None if it is absent, None, or non-finite.

    Uses explicit ``is None`` / ``len`` checks rather than truthiness so that
    NumPy arrays are accepted (``bool()`` of a multi-element array raises).
    """
    if values is None or i >= len(values):
        return None
    value = values[i]
    if value is None:
        return None
    value = float(value)
    return value if np.isfinite(value) else None


def _finite_metrics(values):
    """Return every finite entry of ``values`` as floats (empty list for None)."""
    if values is None:
        return []
    picked = (_finite_metric(values, k) for k in range(len(values)))
    return [v for v in picked if v is not None]


def plot_reconstruction(observed_curves, recon_curves, domain="ec",
                        nrmses=None, r2s=None, scan_labels=None):
    """
    Overlay of observed vs reconstructed signals with optional metrics.

    At most four panels are drawn (one per curve); curves beyond the fourth
    are not plotted, but their metrics still count towards the averages shown
    in the figure title. Missing, None, or non-finite metric entries are
    simply not shown.

    Args:
        observed_curves: list of dicts with 'x' and 'y' arrays
        recon_curves: list of dicts with 'x' and 'y' arrays (same length as
            ``observed_curves``). An entry may be None (e.g. a failed
            reconstruction), in which case only the observed curve is drawn.
        domain: 'ec' or 'tpd'
        nrmses: optional list/array of NRMSE values per curve
        r2s: optional list/array of R2 values per curve
        scan_labels: optional list of label strings per curve
    Returns:
        matplotlib Figure
    """
    n_curves = len(observed_curves)
    # max(1, ...) keeps plt.subplots valid when no curves are given; the
    # lone axis is then hidden by the loop below.
    n_panels = max(1, min(n_curves, 4))
    fig, axes = plt.subplots(1, n_panels,
                             figsize=(max(5, 4 * n_panels), 5),
                             squeeze=False)
    axes = axes[0]
    xlabel = "Potential (\u03b8)" if domain == "ec" else "Temperature (K)"
    ylabel = "Flux" if domain == "ec" else "Rate"
    for i, ax in enumerate(axes):
        if i >= n_curves:
            ax.set_visible(False)
            continue
        obs = observed_curves[i]
        has_recon = recon_curves is not None and i < len(recon_curves)
        rec = recon_curves[i] if has_recon else None
        ax.plot(obs["x"], obs["y"], color=COLORS["neutral"], linewidth=1.5,
                label="Observed", alpha=0.8)
        if rec is not None:
            ax.plot(rec["x"], rec["y"], color=COLORS["primary"], linewidth=1.5,
                    label="Reconstructed", linestyle="--")
        ax.set_xlabel(xlabel, fontsize=10)
        if i == 0:
            ax.set_ylabel(ylabel, fontsize=10)
        ax.legend(fontsize=8, framealpha=0.8, loc="best")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        if scan_labels is not None and i < len(scan_labels):
            title = scan_labels[i]
        elif domain == "ec":
            title = f"Scan rate {i + 1}"
        else:
            title = f"Heating rate {i + 1}"
        ax.set_title(title, fontsize=10)
        metrics_parts = []
        nrmse = _finite_metric(nrmses, i)
        if nrmse is not None:
            metrics_parts.append(f"NRMSE={nrmse:.4f}")
        r2 = _finite_metric(r2s, i)
        if r2 is not None:
            metrics_parts.append(f"R\u00b2={r2:.4f}")
        if metrics_parts:
            ax.text(0.02, 0.98, " ".join(metrics_parts),
                    transform=ax.transAxes, fontsize=8, va="top",
                    color=COLORS["accent"], fontweight="bold",
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="white",
                              alpha=0.8, edgecolor=COLORS["accent"]))
    suptitle = "Signal Reconstruction"
    # Averages are shown only when both metric lists have finite entries.
    valid_nrmse = _finite_metrics(nrmses)
    valid_r2 = _finite_metrics(r2s)
    if valid_nrmse and valid_r2:
        avg_nrmse = np.mean(valid_nrmse)
        avg_r2 = np.mean(valid_r2)
        suptitle += f" (avg NRMSE={avg_nrmse:.4f}, avg R\u00b2={avg_r2:.4f})"
    fig.suptitle(suptitle, fontsize=12, fontweight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.93])
    return fig
def _add_sweep_arrows(ax, pot, y_ox, y_red, mid, show_labels=False):
    """
    Draw one direction arrow per sweep half on each of two traces.

    The data are split at index ``mid`` into a forward part (``[:mid]``) and a
    reverse part (``[mid:]``). On every part of every trace, an arrow is drawn
    along the data order. The forward arrow is centred about 35% of the way
    through its part and the reverse arrow about 65%. Parts with fewer than
    10 points get no arrow.

    Args:
        ax: matplotlib Axes to annotate.
        pot: x values (potential), sliceable like ``y_ox`` / ``y_red``.
        y_ox: first trace, arrowed in the primary colour.
        y_red: second trace, arrowed in the warm colour.
        mid: index separating the forward and reverse parts.
        show_labels: accepted for API compatibility; not used.
    """
    halves = ((slice(None, mid), 0.35), (slice(mid, None), 0.65))
    traces = ((y_ox, COLORS["primary"]), (y_red, COLORS["warm"]))
    for trace, colour in traces:
        for part, position in halves:
            xs = pot[part]
            ys = trace[part]
            count = len(xs)
            if count < 10:
                continue
            # Keep the centre away from the segment ends so both arrow
            # endpoints stay inside the segment.
            centre = min(max(int(count * position), 2), count - 3)
            half_span = max(1, count // 30)
            start = max(0, centre - half_span)
            end = min(count - 1, centre + half_span)
            ax.annotate(
                "",
                xy=(xs[end], ys[end]),
                xytext=(xs[start], ys[start]),
                arrowprops=dict(arrowstyle="-|>", color=colour,
                                lw=1.8, mutation_scale=14),
            )
def plot_concentration_profiles(conc_curves, scan_labels=None):
    """
    Plot surface concentration profiles (C_A and C_B) against potential.

    Up to four panels are drawn, one per entry of ``conc_curves``. Each curve
    is split at its midpoint into a forward half and a reverse half. Both
    halves use the same colour per species, and direction arrows are added
    by ``_add_sweep_arrows``. Panels whose entry is None are hidden.

    Args:
        conc_curves: list of dicts with 'x' (potential), 'c_ox', 'c_red',
            or None for failed reconstructions
        scan_labels: optional list of label strings per curve; panels without
            a label are titled "Scan rate <n>".
    Returns:
        matplotlib Figure, or None if the list is empty or every entry is None
    """
    if all(curve is None for curve in conc_curves):
        return None

    n_curves = len(conc_curves)
    n_panels = min(n_curves, 4)
    fig, axes = plt.subplots(1, n_panels,
                             figsize=(max(5, 4 * n_panels), 5),
                             squeeze=False)
    for idx, ax in enumerate(axes[0]):
        curve = conc_curves[idx] if idx < n_curves else None
        if curve is None:
            ax.set_visible(False)
            continue

        potential = np.asarray(curve["x"])
        ox = np.asarray(curve["c_ox"])
        red = np.asarray(curve["c_red"])
        split = len(potential) // 2
        fwd = slice(None, split)
        rev = slice(split, None)

        # Forward sweep (reductive): first half, carries the legend labels.
        ax.plot(potential[fwd], ox[fwd], color=COLORS["primary"], linewidth=1.5,
                label="C$_A$ (ox)")
        ax.plot(potential[fwd], red[fwd], color=COLORS["warm"], linewidth=1.5,
                label="C$_B$ (red)")
        # Reverse sweep (oxidative): second half, unlabelled.
        ax.plot(potential[rev], ox[rev], color=COLORS["primary"], linewidth=1.5)
        ax.plot(potential[rev], red[rev], color=COLORS["warm"], linewidth=1.5)
        _add_sweep_arrows(ax, potential, ox, red, split)

        ax.set_xlabel("Potential (\u03b8)", fontsize=10)
        if idx == 0:
            ax.set_ylabel("Surface concentration", fontsize=10)
        ax.legend(fontsize=8, framealpha=0.8, loc="best")
        for side in ("top", "right"):
            ax.spines[side].set_visible(False)
        if scan_labels and idx < len(scan_labels):
            title = scan_labels[idx]
        else:
            title = f"Scan rate {idx + 1}"
        ax.set_title(title, fontsize=10)

    fig.suptitle("Surface Concentration Profiles", fontsize=12,
                 fontweight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.93])
    return fig
def plot_parameter_table(param_stats, mechanism_name):
    """
    Render a parameter summary table as a matplotlib figure.

    One row is drawn per parameter, with columns for the display name, mean,
    standard deviation, and 5th / 95th percentiles (numbers to 4 decimals).
    The header row is grey and bold, and body rows alternate white / light grey.

    Args:
        param_stats: dict with 'names', 'mean', 'std', 'q05', 'q95'; each
            statistic sequence is indexed in step with 'names'.
        mechanism_name: name of the mechanism (shown in the title)
    Returns:
        matplotlib Figure
    """
    names = param_stats["names"]
    stat_columns = [param_stats[key] for key in ("mean", "std", "q05", "q95")]
    n_rows = len(names)

    fig, ax = plt.subplots(figsize=(8, max(2, 0.6 * n_rows + 1)))
    ax.axis("off")

    header = ["Parameter", "Mean", "Std", "5th %ile", "95th %ile"]
    body = [
        [_format_param_name(name)] + [f"{column[k]:.4f}" for column in stat_columns]
        for k, name in enumerate(names)
    ]

    table = ax.table(cellText=body, colLabels=header,
                     loc="center", cellLoc="center")
    table.auto_set_font_size(False)
    table.set_fontsize(11)
    table.scale(1.0, 1.5)

    # Row 0 is the header; body rows are striped.
    for (row, _col), cell in table.get_celld().items():
        if row == 0:
            cell.set_facecolor("#E5E7EB")
            cell.set_text_props(fontweight="bold")
            continue
        cell.set_facecolor("white" if row % 2 else "#F9FAFB")

    ax.set_title(f"Parameter Estimates — {mechanism_name}",
                 fontsize=14, fontweight="bold", pad=20)
    fig.tight_layout()
    return fig
def _format_param_name(name):
"""Format parameter names for display."""
replacements = {
"log10(K0)": "log₁₀(K₀)",
"log10(dB)": "log₁₀(d_B)",
"log10(dA)": "log₁₀(d_A)",
"log10(kc)": "log₁₀(k_c)",
"log10(reorg_e)": "log₁₀(λ)",
"log10(Gamma_sat)": "log₁₀(Γ_sat)",
"log10(KA_eq)": "log₁₀(K_A,eq)",
"log10(KB_eq)": "log₁₀(K_B,eq)",
"log10(nu)": "log₁₀(ν)",
"log10(nu_red)": "log₁₀(ν_red)",
"log10(D0)": "log₁₀(D₀)",
"E0_offset": "E₀ offset",
"alpha": "α",
"alpha_cov": "α_cov",
"Ed": "E_d (K)",
"Ed0": "E_d0 (K)",
"Ea": "E_a (K)",
"Ea_red": "E_a,red (K)",
"Ea_reox": "E_a,reox (K)",
"E_diff": "E_diff (K)",
"theta_0": "θ₀",
"theta_A0": "θ_A0",
"theta_B0": "θ_B0",
"theta_O0": "θ_O0",
}
return replacements.get(name, name)