"""
Visualization utilities for the SPARK web app.

Generates matplotlib figures for mechanism classification, parameter
posteriors, and signal reconstruction overlays.
"""

import numpy as np
import matplotlib

matplotlib.use("Agg")  # non-interactive backend; must be selected before pyplot

import matplotlib.pyplot as plt  # noqa: E402
from matplotlib.gridspec import GridSpec  # noqa: E402,F401
from matplotlib.lines import Line2D  # noqa: E402

# Default figure proportions so plots sit predictably inside the app's card
# width (~1080px). These are applied globally through rcParams below; most
# helpers in this module pass an explicit ``figsize`` sized to their panel
# count (e.g. multi-panel grids), which takes precedence over this default.
BASE_FIGSIZE = (7.0, 4.2)
BASE_DPI = 110

plt.rcParams.update({
    "figure.figsize": BASE_FIGSIZE,
    "figure.dpi": BASE_DPI,
    "savefig.dpi": BASE_DPI,
    "figure.autolayout": True,
    "font.size": 10.5,
    "axes.titlesize": 11.5,
    "axes.labelsize": 10.5,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.edgecolor": "#94A3B8",
    "axes.labelcolor": "#0F172A",
    "xtick.color": "#475569",
    "ytick.color": "#475569",
    "legend.frameon": False,
    "legend.fontsize": 9.5,
})

COLORS = {
    "primary": "#2563EB",
    "secondary": "#7C3AED",
    "accent": "#059669",
    "warm": "#DC2626",
    "neutral": "#6B7280",
    "bg": "#F9FAFB",
}

MECH_COLORS_EC = {
    "Nernst": "#3B82F6",
    "BV": "#8B5CF6",
    "MHC": "#EC4899",
    "Ads": "#F59E0B",
    "EC": "#10B981",
    "LH": "#EF4444",
}

MECH_COLORS_TPD = {
    "FirstOrder": "#3B82F6",
    "SecondOrder": "#8B5CF6",
    "LH_Surface": "#EC4899",
    "MvK": "#F59E0B",
    "FirstOrderCovDep": "#10B981",
    "DiffLimited": "#EF4444",
}

MECH_FULL_NAMES_EC = {
    "Nernst": "Nernstian (reversible electron transfer)",
    "BV": "Butler–Volmer (quasi-reversible electron transfer)",
    "MHC": "Marcus–Hush–Chidsey (non-adiabatic electron transfer)",
    "Ads": "surface-adsorbed redox couple",
    "EC": "electron transfer followed by chemical step",
    "LH": "Langmuir–Hinshelwood surface reaction",
    "EE": "two sequential electron transfers",
    "EC_prime": "electron transfer with catalytic regeneration",
    "CE": "chemical step preceding electron transfer",
    # OOD (not in trained set)
    "ECE": "electron transfer–chemical step–electron transfer",
    "EC_LH": "EC followed by Langmuir–Hinshelwood",
    "MHC_EC": "Marcus–Hush–Chidsey followed by chemical step",
    "MHC_LH": "Marcus–Hush–Chidsey followed by Langmuir–Hinshelwood",
}

MECH_FULL_NAMES_TPD = {
    "FirstOrder": "1st-order desorption",
    "SecondOrder": "2nd-order recombinative desorption",
    "ZerothOrder": "0th-order (multilayer) desorption",
    "LH_Surface": "Langmuir–Hinshelwood surface reaction",
    "MvK": "Mars–van Krevelen (lattice oxygen)",
    "FirstOrderCovDep": "1st-order with coverage-dependent Ed",
    "DiffLimited": "diffusion-limited desorption",
    "PrecursorMediated": "precursor-mediated desorption",
    "Dissociative": "dissociative desorption",
    "ActivatedAdsorption": "desorption with activated re-adsorption",
    "TwoSite": "two-site (heterogeneous) desorption",
    # OOD (not in trained set)
    "EleyRideal": "Eley–Rideal surface reaction",
    "BimolecularCovDep": "bimolecular coverage-dependent desorption",
}


def _placeholder_figure(title, message, figsize=(6, 2.5)):
    """
    Return a figure that displays ``message`` in place of a plot.

    Used when a plotting helper receives no data, so callers still get a
    renderable Figure instead of a matplotlib error (e.g. from
    ``plt.subplots(1, 0)`` or ``ax.table(cellText=[])``).
    """
    fig, ax = plt.subplots(figsize=figsize)
    ax.axis("off")
    ax.text(0.5, 0.5, message, transform=ax.transAxes, ha="center",
            va="center", fontsize=11, color=COLORS["neutral"])
    ax.set_title(title, fontsize=12, fontweight="bold")
    fig.tight_layout()
    return fig


def _item_at(seq, i):
    """Return ``seq[i]``, or None when ``seq`` is None or shorter than i+1."""
    if seq is None or i >= len(seq):
        return None
    return seq[i]


def _is_finite_number(value):
    """Return True if ``value`` is a numeric scalar that is neither NaN nor ±inf."""
    if value is None:
        return False
    try:
        return bool(np.isfinite(value))
    except (TypeError, ValueError):
        # Non-numeric values (strings, objects) or non-scalar arrays.
        return False


def plot_mechanism_probs(probs_dict, domain="ec"):
    """
    Horizontal bar chart of mechanism classification probabilities.

    Bars are sorted by probability with the largest at the top. Mechanisms
    missing from the domain colour map use the neutral colour; mechanisms
    without a long-form description are labelled by their key alone.
    Percentage labels are drawn only for bars above 5%.

    Args:
        probs_dict: {mechanism_name: probability}
        domain: 'ec' or 'tpd' (any value other than 'ec' selects TPD tables)

    Returns:
        matplotlib Figure
    """
    colors = MECH_COLORS_EC if domain == "ec" else MECH_COLORS_TPD
    full_names = MECH_FULL_NAMES_EC if domain == "ec" else MECH_FULL_NAMES_TPD

    names = list(probs_dict.keys())
    probs = [probs_dict[n] for n in names]
    # barh draws index 0 at the bottom, so ascending order puts the most
    # probable mechanism on top. A stable sort keeps ties in input order.
    sorted_idx = np.argsort(probs, kind="stable")
    names = [names[i] for i in sorted_idx]
    probs = [probs[i] for i in sorted_idx]

    bar_colors = [colors.get(n, COLORS["neutral"]) for n in names]
    display_names = [f"{n} ({full_names.get(n, n)})" for n in names]

    fig, ax = plt.subplots(figsize=(11, max(3, len(names) * 0.7)))
    bars = ax.barh(range(len(names)), probs, color=bar_colors,
                   edgecolor="white", linewidth=0.5, height=0.7)

    ax.set_yticks(range(len(names)))
    ax.set_yticklabels(display_names, fontsize=11, fontweight="medium")
    ax.set_xlim(0, 1.05)
    ax.set_xlabel("Probability", fontsize=12)
    ax.set_title("Mechanism Classification", fontsize=14,
                 fontweight="bold", pad=15)

    for bar, prob in zip(bars, probs):
        if prob > 0.05:
            ax.text(bar.get_width() + 0.02,
                    bar.get_y() + bar.get_height() / 2,
                    f"{prob:.1%}", va="center", fontsize=11,
                    fontweight="bold")

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.grid(axis="x", alpha=0.3, linestyle="--")
    fig.tight_layout()
    return fig


def plot_posteriors(samples, param_names, mechanism_name, domain="ec"):
    """
    Violin plots of posterior distributions for each parameter.

    Each panel shows the violin body, the mean (black), the median (red)
    and dashed lines at the 5th/95th percentiles. Non-finite samples are
    ignored. A parameter whose finite samples have zero spread, or that has
    fewer than two finite samples, cannot be drawn as a violin (matplotlib's
    kernel density estimate fails on it), so it is drawn as a single
    horizontal line at its value instead.

    Args:
        samples: [n_samples, D] array of posterior samples
        param_names: list of parameter names (column i of ``samples``)
        mechanism_name: name of the mechanism
        domain: 'ec' or 'tpd'

    Returns:
        matplotlib Figure (a placeholder figure when there are no parameters)
    """
    title = f"Parameter Posteriors — {mechanism_name}"
    n_params = len(param_names)
    if n_params == 0:
        return _placeholder_figure(title, "No parameters to display")

    fig, axes = plt.subplots(1, n_params, figsize=(max(4, 3 * n_params), 4.5),
                             squeeze=False)

    colors = MECH_COLORS_EC if domain == "ec" else MECH_COLORS_TPD
    color = colors.get(mechanism_name, COLORS["primary"])

    for i, (ax, name) in enumerate(zip(axes[0], param_names)):
        data = np.asarray(samples[:, i], dtype=float)
        finite = data[np.isfinite(data)]

        ax.set_title(_format_param_name(name), fontsize=11,
                     fontweight="medium")
        ax.set_xticks([])
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["bottom"].set_visible(False)

        if finite.size == 0:
            ax.set_yticks([])
            ax.text(0.5, 0.5, "no finite samples", transform=ax.transAxes,
                    ha="center", va="center", fontsize=9,
                    color=COLORS["neutral"])
            continue

        if finite.size >= 2 and np.ptp(finite) > 0:
            parts = ax.violinplot(finite, positions=[0], showmeans=True,
                                  showmedians=True, showextrema=False)
            for pc in parts["bodies"]:
                pc.set_facecolor(color)
                pc.set_alpha(0.6)
            parts["cmeans"].set_color("black")
            parts["cmedians"].set_color(COLORS["warm"])

            q05, q95 = np.quantile(finite, [0.05, 0.95])
            ax.axhline(q05, color=COLORS["neutral"], linestyle="--",
                       alpha=0.5, linewidth=0.8)
            ax.axhline(q95, color=COLORS["neutral"], linestyle="--",
                       alpha=0.5, linewidth=0.8)
        else:
            # Degenerate posterior: mean, median and percentiles coincide,
            # so draw one line at the common value.
            ax.hlines(finite[0], -0.25, 0.25, colors="black", linewidth=1.5)
            ax.set_xlim(-0.5, 0.5)

        mean_val = finite.mean()
        ax.text(0.5, 0.02, f"mean={mean_val:.3f}", transform=ax.transAxes,
                ha="center", fontsize=9, color=COLORS["neutral"])

    fig.suptitle(title, fontsize=14, fontweight="bold")

    # Shared legend explaining the mean / median markers and 90% interval.
    legend_handles = [
        Line2D([0], [0], color="black", lw=1.5, label="mean"),
        Line2D([0], [0], color=COLORS["warm"], lw=1.5, label="median"),
        Line2D([0], [0], color=COLORS["neutral"], lw=0.8, ls="--",
               label="5th / 95th percentile"),
    ]
    # Anchor the legend on the figure's bottom edge (not below it, where it
    # can be clipped on save), and reserve that strip plus the suptitle band
    # when laying out the axes.
    fig.legend(
        handles=legend_handles,
        loc="lower center",
        ncol=3,
        frameon=False,
        fontsize=9,
        bbox_to_anchor=(0.5, 0.0),
    )
    fig.tight_layout(rect=[0, 0.07, 1, 0.93])
    return fig


def plot_reconstruction(observed_curves, recon_curves, domain="ec",
                        nrmses=None, r2s=None, scan_labels=None,
                        max_panels=4):
    """
    Overlay of observed vs reconstructed signals with optional metrics.

    One panel is drawn per curve, up to ``max_panels``. Curves beyond that
    are not plotted, but their metrics still count toward the averages in
    the figure title.

    Args:
        observed_curves: list of dicts with 'x' and 'y' arrays
        recon_curves: list of dicts with 'x' and 'y' arrays (same length);
            a None entry (or a missing trailing entry) draws only the
            observed curve for that panel
        domain: 'ec' or 'tpd'
        nrmses: optional sequence (list or numpy array) of NRMSE values per
            curve; None / NaN / inf entries are skipped
        r2s: optional sequence of R2 values per curve, handled like nrmses
        scan_labels: optional list of label strings per curve
        max_panels: maximum number of panels to draw (default 4)

    Returns:
        matplotlib Figure (a placeholder figure when there are no curves)
    """
    n_curves = len(observed_curves)
    if n_curves == 0:
        return _placeholder_figure("Signal Reconstruction",
                                   "No curves to display")

    n_panels = max(1, min(n_curves, max_panels))
    fig, axes = plt.subplots(1, n_panels,
                             figsize=(max(5, 4 * n_panels), 5),
                             squeeze=False)

    xlabel = "Potential (\u03b8)" if domain == "ec" else "Temperature (K)"
    ylabel = "Flux" if domain == "ec" else "Rate"
    default_title = "Scan rate" if domain == "ec" else "Heating rate"

    for i, ax in enumerate(axes[0]):
        obs = observed_curves[i]
        rec = _item_at(recon_curves, i)

        ax.plot(obs["x"], obs["y"], color=COLORS["neutral"], linewidth=1.5,
                label="Observed", alpha=0.8)
        if rec is not None:
            ax.plot(rec["x"], rec["y"], color=COLORS["primary"],
                    linewidth=1.5, label="Reconstructed", linestyle="--")

        ax.set_xlabel(xlabel, fontsize=10)
        if i == 0:
            ax.set_ylabel(ylabel, fontsize=10)
        ax.legend(fontsize=8, framealpha=0.8, loc="best")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        label = _item_at(scan_labels, i)
        title = label if label is not None else f"{default_title} {i + 1}"
        ax.set_title(title, fontsize=10)

        metrics_parts = []
        nrmse = _item_at(nrmses, i)
        if _is_finite_number(nrmse):
            metrics_parts.append(f"NRMSE={nrmse:.4f}")
        r2 = _item_at(r2s, i)
        if _is_finite_number(r2):
            metrics_parts.append(f"R\u00b2={r2:.4f}")
        if metrics_parts:
            ax.text(0.02, 0.98, " ".join(metrics_parts),
                    transform=ax.transAxes, fontsize=8, va="top",
                    color=COLORS["accent"], fontweight="bold",
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="white",
                              alpha=0.8, edgecolor=COLORS["accent"]))

    # Each average is reported independently, so one missing metric list no
    # longer hides the other.
    averages = []
    for metric_label, values in (("NRMSE", nrmses), ("R\u00b2", r2s)):
        if values is None:
            continue
        finite = [v for v in values if _is_finite_number(v)]
        if finite:
            averages.append(f"avg {metric_label}={np.mean(finite):.4f}")

    suptitle = "Signal Reconstruction"
    if averages:
        suptitle += f" ({', '.join(averages)})"

    fig.suptitle(suptitle, fontsize=12, fontweight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.93])
    return fig


def _add_sweep_arrows(ax, pot, y_ox, y_red, mid, show_labels=False):
    """
    Add direction arrows for forward/reverse sweeps on both species.

    The forward sweep is ``[:mid]`` and the reverse sweep ``[mid:]``. One
    arrow is placed ~35% of the way into the forward segment and ~65% into
    the reverse segment, for both ``y_ox`` (primary colour) and ``y_red``
    (warm colour). Segments with fewer than 10 points get no arrow.

    ``show_labels`` is accepted but not used by this implementation.
    """
    sweep_specs = [
        (slice(None, mid), 0.35),
        (slice(mid, None), 0.65),
    ]
    curves = [
        (y_ox, COLORS["primary"]),
        (y_red, COLORS["warm"]),
    ]
    for y_data, color in curves:
        for segment, frac in sweep_specs:
            x_seg = pot[segment]
            y_seg = y_data[segment]
            n = len(x_seg)
            if n < 10:
                continue
            idx = int(n * frac)
            idx = max(2, min(idx, n - 3))
            # Span the arrow over ~1/30 of the segment on each side of idx so
            # its direction follows the local slope of the curve.
            step = max(1, n // 30)
            i0 = max(0, idx - step)
            i1 = min(n - 1, idx + step)
            ax.annotate(
                "",
                xy=(x_seg[i1], y_seg[i1]),
                xytext=(x_seg[i0], y_seg[i0]),
                arrowprops=dict(arrowstyle="-|>", color=color,
                                lw=1.8, mutation_scale=14),
            )


def plot_concentration_profiles(conc_curves, scan_labels=None, max_panels=4):
    """
    Plot surface concentration profiles (C_A and C_B) vs potential.

    The first half of each curve is treated as the forward sweep and the
    second half as the reverse sweep; direction arrows are added to both.
    Panels whose entry is None are hidden.

    Args:
        conc_curves: list of dicts with 'x' (potential), 'c_ox', 'c_red',
                     or None for failed reconstructions
        scan_labels: optional list of label strings per curve
        max_panels: maximum number of panels to draw (default 4)

    Returns:
        matplotlib Figure, or None if none of the displayed panels has data
    """
    shown = list(conc_curves)[:max(1, max_panels)]
    if all(c is None for c in shown):
        return None

    n_panels = len(shown)
    fig, axes = plt.subplots(1, n_panels,
                             figsize=(max(5, 4 * n_panels), 5),
                             squeeze=False)

    ylabel_placed = False
    for i, (ax, c) in enumerate(zip(axes[0], shown)):
        if c is None:
            ax.set_visible(False)
            continue

        pot = np.asarray(c["x"])
        c_ox = np.asarray(c["c_ox"])
        c_red = np.asarray(c["c_red"])
        mid = len(pot) // 2

        # Forward sweep (reductive): first half
        ax.plot(pot[:mid], c_ox[:mid], color=COLORS["primary"],
                linewidth=1.5, label="C$_A$ (ox)")
        ax.plot(pot[:mid], c_red[:mid], color=COLORS["warm"],
                linewidth=1.5, label="C$_B$ (red)")
        # Reverse sweep (oxidative): second half, unlabelled so the legend
        # lists each species once.
        ax.plot(pot[mid:], c_ox[mid:], color=COLORS["primary"], linewidth=1.5)
        ax.plot(pot[mid:], c_red[mid:], color=COLORS["warm"], linewidth=1.5)

        _add_sweep_arrows(ax, pot, c_ox, c_red, mid)

        ax.set_xlabel("Potential (\u03b8)", fontsize=10)
        # Label the leftmost *visible* panel, even if panel 0 is hidden.
        if not ylabel_placed:
            ax.set_ylabel("Surface concentration", fontsize=10)
            ylabel_placed = True
        ax.legend(fontsize=8, framealpha=0.8, loc="best")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        label = _item_at(scan_labels, i)
        ax.set_title(label if label is not None else f"Scan rate {i + 1}",
                     fontsize=10)

    fig.suptitle("Surface Concentration Profiles", fontsize=12,
                 fontweight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.93])
    return fig


def plot_parameter_table(param_stats, mechanism_name):
    """
    Create a formatted parameter summary table as a figure.

    Values are shown with four decimal places; the header row is shaded
    and body rows alternate background colours.

    Args:
        param_stats: dict with 'names', 'mean', 'std', 'q05', 'q95'
        mechanism_name: name of the mechanism

    Returns:
        matplotlib Figure (a placeholder figure when there are no parameters,
        since ``ax.table`` cannot build an empty table)
    """
    title = f"Parameter Estimates — {mechanism_name}"
    names = list(param_stats["names"])
    if not names:
        return _placeholder_figure(title, "No parameters to display")

    means = param_stats["mean"]
    stds = param_stats["std"]
    q05s = param_stats["q05"]
    q95s = param_stats["q95"]

    fig, ax = plt.subplots(figsize=(8, max(2, 0.6 * len(names) + 1)))
    ax.axis("off")

    col_labels = ["Parameter", "Mean", "Std", "5th %ile", "95th %ile"]
    cell_text = [
        [
            _format_param_name(name),
            f"{means[i]:.4f}",
            f"{stds[i]:.4f}",
            f"{q05s[i]:.4f}",
            f"{q95s[i]:.4f}",
        ]
        for i, name in enumerate(names)
    ]

    table = ax.table(cellText=cell_text, colLabels=col_labels,
                     loc="center", cellLoc="center")
    table.auto_set_font_size(False)
    table.set_fontsize(11)
    table.scale(1.0, 1.5)

    for (row, _col), cell in table.get_celld().items():
        if row == 0:
            cell.set_facecolor("#E5E7EB")
            cell.set_text_props(fontweight="bold")
        else:
            cell.set_facecolor("#F9FAFB" if row % 2 == 0 else "white")

    ax.set_title(title, fontsize=14, fontweight="bold", pad=20)
    fig.tight_layout()
    return fig


# Display labels for known parameter names; unknown names are shown as-is.
_PARAM_DISPLAY_NAMES = {
    "log10(K0)": "log₁₀(K₀)",
    "log10(dB)": "log₁₀(d_B)",
    "log10(dA)": "log₁₀(d_A)",
    "log10(kc)": "log₁₀(k_c)",
    "log10(reorg_e)": "log₁₀(λ)",
    "log10(Gamma_sat)": "log₁₀(Γ_sat)",
    "log10(KA_eq)": "log₁₀(K_A,eq)",
    "log10(KB_eq)": "log₁₀(K_B,eq)",
    "log10(nu)": "log₁₀(ν)",
    "log10(nu_red)": "log₁₀(ν_red)",
    "log10(D0)": "log₁₀(D₀)",
    "E0_offset": "E₀ offset",
    "alpha": "α",
    "alpha_cov": "α_cov",
    "Ed": "E_d (K)",
    "Ed0": "E_d0 (K)",
    "Ea": "E_a (K)",
    "Ea_red": "E_a,red (K)",
    "Ea_reox": "E_a,reox (K)",
    "E_diff": "E_diff (K)",
    "theta_0": "θ₀",
    "theta_A0": "θ_A0",
    "theta_B0": "θ_B0",
    "theta_O0": "θ_O0",
}


def _format_param_name(name):
    """Format parameter names for display (unknown names pass through)."""
    return _PARAM_DISPLAY_NAMES.get(name, name)