"""
SPARK (Simulation-based Posterior Amortization for Reaction Kinetics)
Gradio web interface for mechanism classification and parameter inference
from cyclic voltammetry (CV) and temperature-programmed desorption (TPD) data.
"""
import os
import sys
import json
import tempfile
from pathlib import Path
import numpy as np
import gradio as gr
from inference import SPARKPredictor
from preprocessing import (
nondimensionalize_cv,
estimate_E0,
parse_cv_csv,
parse_tpd_csv,
)
from plotting import (
plot_mechanism_probs,
plot_posteriors,
plot_parameter_table,
plot_reconstruction,
plot_concentration_profiles,
)
# ---------------------------------------------------------------------------
# Model paths (relative to repo root)
# ---------------------------------------------------------------------------
REPO_ROOT = Path(__file__).resolve().parent
DEMO_DIR = REPO_ROOT / "demo_data"
DEMO_RENDERS = REPO_ROOT / "demo_renders"


def _env_path(*names, default=None):
    """Return the first non-empty environment variable in ``names`` as a Path.

    Variables that are unset *or set to the empty string* are skipped, so an
    empty secret falls through to the next name (and finally ``default``)
    instead of becoming ``Path("")`` -- i.e. ``Path(".")``, the current
    directory, which ``.exists()`` would happily report as present.

    Args:
        *names: Environment variable names, highest priority first.
        default: Fallback used when none of ``names`` is set to a non-empty
            value; converted to ``Path`` unless it is None.

    Returns:
        A ``Path``, or None when nothing is set and ``default`` is None.
    """
    for name in names:
        value = os.environ.get(name)
        if value:
            return Path(value)
    return None if default is None else Path(default)


# Allow override via environment variables. SPARK_*_CHECKPOINT is the
# canonical name; ECFLOW_*_CHECKPOINT remains accepted for backward
# compatibility with existing HF Space secrets. Empty values count as unset.
EC_CHECKPOINT = _env_path(
    "SPARK_EC_CHECKPOINT",
    "ECFLOW_EC_CHECKPOINT",
    default=REPO_ROOT / "checkpoints" / "ec_best.pt",
)
TPD_CHECKPOINT = _env_path(
    "SPARK_TPD_CHECKPOINT",
    "ECFLOW_TPD_CHECKPOINT",
    default=REPO_ROOT / "checkpoints" / "tpd_best.pt",
)
# Optional image-input checkpoints. When set and the file exists, the
# "From Image" tabs use the image-input SPARK model directly instead of
# digitizing the plot. If not set, the digitizer fallback is used.
EC_IMAGE_CHECKPOINT_PATH = os.environ.get("SPARK_EC_IMAGE_CHECKPOINT", "")
TPD_IMAGE_CHECKPOINT_PATH = os.environ.get("SPARK_TPD_IMAGE_CHECKPOINT", "")
EC_IMAGE_CHECKPOINT = _env_path("SPARK_EC_IMAGE_CHECKPOINT")
TPD_IMAGE_CHECKPOINT = _env_path("SPARK_TPD_IMAGE_CHECKPOINT")
# Optional Phase 2 joint (image + waveform) checkpoints. When present, the
# "From Image" tabs expose a "Joint image+waveform (Phase 2)" inference
# strategy that fuses the rendered image with the digitizer-extracted
# waveform inside a single encoder. Falls back to ensemble if missing.
EC_JOINT_CHECKPOINT_PATH = os.environ.get("SPARK_EC_JOINT_CHECKPOINT", "")
TPD_JOINT_CHECKPOINT_PATH = os.environ.get("SPARK_TPD_JOINT_CHECKPOINT", "")
EC_JOINT_CHECKPOINT = _env_path("SPARK_EC_JOINT_CHECKPOINT")
TPD_JOINT_CHECKPOINT = _env_path("SPARK_TPD_JOINT_CHECKPOINT")
# ---------------------------------------------------------------------------
# Demo examples
# ---------------------------------------------------------------------------
def _discover_examples(demo_dir=None):
    """Scan a demo directory for ``*_metadata.json`` files and build catalogs.

    Files named ``ec_*`` feed the CV catalog and ``tpd_*`` the TPD catalog;
    other metadata files are ignored. Each catalog maps a dropdown label such
    as ``"CV \u2014 <mechanism>"`` (suffixed with ``" *OOD*"`` for
    out-of-distribution examples) to the values used to pre-fill the inputs.

    A metadata file that cannot be read or parsed, or that lacks the required
    ``mechanism`` / ``csv_files`` keys, is reported on stderr and skipped, so
    one bad demo file does not stop the whole app from importing.

    Args:
        demo_dir: Directory to scan. Defaults to the module-level ``DEMO_DIR``.

    Returns:
        ``(cv_examples, tpd_examples)``. Both are empty dicts when the
        directory does not exist.
    """
    demo_dir = Path(DEMO_DIR if demo_dir is None else demo_dir)
    cv_examples = {}
    tpd_examples = {}
    if not demo_dir.is_dir():
        return cv_examples, tpd_examples
    for meta_path in sorted(demo_dir.glob("*_metadata.json")):
        is_ec = meta_path.name.startswith("ec_")
        if not is_ec and not meta_path.name.startswith("tpd_"):
            continue  # neither a CV nor a TPD example
        try:
            with open(meta_path, encoding="utf-8") as f:
                meta = json.load(f)
            mech = meta["mechanism"]
            csv_files = [str(demo_dir / fn) for fn in meta["csv_files"]]
            is_ood = bool(meta.get("is_ood", False))
            ood_tag = " *OOD*" if is_ood else ""
            if is_ec:
                rates_str = ", ".join(
                    f"{r:.4g}" for r in meta.get("scan_rates_Vs", [])
                )
                phys = meta.get("physical_params", {})
                label = f"CV \u2014 {mech}{ood_tag}"
                entry = {
                    "files": csv_files,
                    "scan_rates": rates_str,
                    "is_ood": is_ood,
                    "E0_V": phys.get("E0_V"),
                    "T_K": phys.get("T_K", 298.15),
                    "A_cm2": phys.get("A_cm2", 0.0707),
                    "C_mM": phys.get("C_mM", 1.0),
                    "D_cm2s": phys.get("D_cm2s", 1e-5),
                    "n_electrons": phys.get("n_electrons", 1),
                }
            else:
                betas_str = ", ".join(
                    f"{b:.4g}" for b in meta.get("betas_Ks", [])
                )
                label = f"TPD \u2014 {mech}{ood_tag}"
                entry = {
                    "files": csv_files,
                    "heating_rates": betas_str,
                    "is_ood": is_ood,
                }
        except (OSError, ValueError, KeyError, TypeError, AttributeError) as exc:
            # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
            print(f"[SPARK] Skipping demo example {meta_path.name}: {exc}",
                  file=sys.stderr)
            continue
        # Entries are only registered once fully built, so a failure above
        # never leaves a half-filled catalog entry behind.
        (cv_examples if is_ec else tpd_examples)[label] = entry
    return cv_examples, tpd_examples
# Demo CSV catalogs (dropdown label -> pre-fill values), built once when the
# module is imported by scanning the demo directory.
CV_EXAMPLES, TPD_EXAMPLES = _discover_examples()
def _discover_image_examples(demo_renders=None):
    """Build image-example catalogs from the demo renders directory.

    Input plots are recognised by filename:

    * ``ec_<mech>_<rate>mVs_physical.png`` -> CV example; the integer
      ``<rate>`` (mV/s) is stored converted to V/s.
    * ``tpd_<mech>_<idx>_physical.png`` -> TPD example; ``<idx>`` is kept as
      the raw digit string.

    For every mechanism found, pre-rendered output figures named
    ``ec_<mech>_<suffix>.png`` / ``tpd_<mech>_<suffix>.png`` are collected
    when they exist.

    Args:
        demo_renders: Directory to scan. Defaults to the module-level
            ``DEMO_RENDERS``.

    Returns:
        ``(cv_img_examples, tpd_img_examples, cv_output_renders,
        tpd_output_renders)``. The first two map mechanism -> list of
        ``(image_path, rate_Vs_or_idx)``; each CV list is ordered by
        ascending scan rate. The last two map mechanism ->
        ``{suffix: render_path}``. All four are empty when the directory
        does not exist.
    """
    import re

    renders_dir = Path(DEMO_RENDERS if demo_renders is None else demo_renders)
    cv_img_examples = {}
    tpd_img_examples = {}
    cv_output_renders = {}
    tpd_output_renders = {}
    if not renders_dir.is_dir():
        return cv_img_examples, tpd_img_examples, cv_output_renders, tpd_output_renders

    cv_pattern = re.compile(r"ec_(\w+?)_(\d+)mVs_physical\.png")
    tpd_pattern = re.compile(r"tpd_(\w+?)_(\d+)_physical\.png")

    for p in sorted(renders_dir.glob("ec_*_physical.png")):
        m = cv_pattern.match(p.name)
        if not m:
            continue
        mech, rate_mVs = m.group(1), int(m.group(2))
        cv_img_examples.setdefault(mech, []).append((str(p), rate_mVs / 1000.0))
    # Filenames sort lexicographically ("100mVs" < "10mVs" < "50mVs"), which
    # would make the loaded scan-rate list and gallery non-monotonic. Order
    # each series numerically; the sort is stable, so equal rates keep
    # filename order. Paths and rates stay paired.
    for entries in cv_img_examples.values():
        entries.sort(key=lambda entry: entry[1])

    for p in sorted(renders_dir.glob("tpd_*_physical.png")):
        m = tpd_pattern.match(p.name)
        if not m:
            continue
        mech, idx = m.group(1), m.group(2)
        tpd_img_examples.setdefault(mech, []).append((str(p), idx))

    for mech in cv_img_examples:
        renders = {}
        for suffix in ["classification", "posteriors", "reconstruction", "concentration"]:
            rp = renders_dir / f"ec_{mech}_{suffix}.png"
            if rp.exists():
                renders[suffix] = str(rp)
        if renders:
            cv_output_renders[mech] = renders
    for mech in tpd_img_examples:
        renders = {}
        for suffix in ["classification", "posteriors", "reconstruction"]:
            rp = renders_dir / f"tpd_{mech}_{suffix}.png"
            if rp.exists():
                renders[suffix] = str(rp)
        if renders:
            tpd_output_renders[mech] = renders
    return cv_img_examples, tpd_img_examples, cv_output_renders, tpd_output_renders
# Image-example catalogs (mechanism -> input plots) and the matching
# pre-rendered output figures, built once when the module is imported.
(CV_IMG_EXAMPLES, TPD_IMG_EXAMPLES,
 CV_OUTPUT_RENDERS, TPD_OUTPUT_RENDERS) = _discover_image_examples()
def _load_cv_image_example(mech_name):
    """Fill the CV image inputs from a demo example.

    Returns ``(image_paths, scan_rates_str)`` for ``mech_name``, where the
    rates are the catalog's V/s values joined with ", ". An empty or unknown
    name yields two no-op ``gr.update()`` objects.
    """
    if mech_name and mech_name in CV_IMG_EXAMPLES:
        image_paths = []
        rate_labels = []
        for image_path, rate in CV_IMG_EXAMPLES[mech_name]:
            image_paths.append(image_path)
            rate_labels.append(str(rate))
        return image_paths, ", ".join(rate_labels)
    return [gr.update()] * 2
def _load_tpd_image_example(mech_name):
    """Fill the TPD image inputs from a demo example.

    Returns ``(image_path, heating_rate_str)`` taken from the first catalog
    entry for ``mech_name``; the second element is the numeric tag parsed
    from the example's filename. An empty or unknown name yields two no-op
    ``gr.update()`` objects.
    """
    if mech_name and mech_name in TPD_IMG_EXAMPLES:
        first_entry = TPD_IMG_EXAMPLES[mech_name][0]
        return first_entry[0], first_entry[1]
    return [gr.update()] * 2
def _load_cv_example(example_name):
    """Fill the CV CSV inputs from a demo example.

    Returns ``(files, scan_rates, E0, T, A, C, D, n)`` in the order of the
    output components wired to the example "Load" button. An empty or
    unknown name yields eight no-op ``gr.update()`` objects.
    """
    is_known = bool(example_name) and example_name in CV_EXAMPLES
    if not is_known:
        return [gr.update()] * 8
    field_order = (
        "files", "scan_rates", "E0_V", "T_K",
        "A_cm2", "C_mM", "D_cm2s", "n_electrons",
    )
    example = CV_EXAMPLES[example_name]
    return tuple(example[field] for field in field_order)
def _load_tpd_example(example_name):
    """Fill the TPD CSV inputs from a demo example.

    Returns ``(files, heating_rates)`` for the chosen example, or two no-op
    ``gr.update()`` objects when the name is empty or unknown.
    """
    is_known = bool(example_name) and example_name in TPD_EXAMPLES
    if not is_known:
        return [gr.update()] * 2
    example = TPD_EXAMPLES[example_name]
    files = example["files"]
    heating_rates = example["heating_rates"]
    return files, heating_rates
# Lazily-created, module-wide SPARKPredictor (see get_predictor).
predictor = None


def get_predictor():
    """Return the shared SPARKPredictor, constructing it on the first call.

    Each configured checkpoint is forwarded as a string only if its path is
    set and exists on disk; otherwise None is passed for that slot. The
    predictor is created on CPU and cached in the module-level ``predictor``.
    """
    global predictor
    if predictor is not None:
        return predictor

    def _existing(path):
        # Optional checkpoints are None when their env var is unset.
        return str(path) if path and path.exists() else None

    predictor = SPARKPredictor(
        ec_checkpoint=_existing(EC_CHECKPOINT),
        tpd_checkpoint=_existing(TPD_CHECKPOINT),
        ec_image_checkpoint=_existing(EC_IMAGE_CHECKPOINT),
        tpd_image_checkpoint=_existing(TPD_IMAGE_CHECKPOINT),
        ec_joint_checkpoint=_existing(EC_JOINT_CHECKPOINT),
        tpd_joint_checkpoint=_existing(TPD_JOINT_CHECKPOINT),
        device="cpu",
    )
    return predictor
# =========================================================================
# CV Analysis
# =========================================================================
def _upload_path(upload):
    """Return the filesystem path of an uploaded file.

    Depending on the Gradio version / ``gr.File`` settings, an upload arrives
    either as a plain path (``str`` or ``os.PathLike``) or as a tempfile-like
    object exposing its path as ``.name``; both are accepted. Plain paths are
    checked first because ``pathlib.Path.name`` is only the basename.
    """
    if isinstance(upload, (str, os.PathLike)):
        return os.fspath(upload)
    return upload.name


def _read_upload_text(path):
    """Read a text upload, stripping a UTF-8 BOM, with a Latin-1 fallback.

    Exported files may start with a byte-order mark or use a legacy 8-bit
    encoding (e.g. a 0xB5 "µ" in a unit header). Latin-1 decodes any byte
    sequence, so such files load instead of raising UnicodeDecodeError.
    Text mode keeps universal-newline translation, as ``read_text()`` did.
    """
    p = Path(path)
    try:
        return p.read_text(encoding="utf-8-sig")
    except UnicodeDecodeError:
        return p.read_text(encoding="latin-1")


def analyze_cv(files, scan_rates_text, E0_V, T_K, A_cm2,
               C_mM, D_cm2s, n_electrons, n_samples):
    """Analyze CV data from potentiostat CSV files.

    Accepts CSV files with columns for potential (V) and current (A/mA/µA).
    If the CSV includes a Time (s) column, the scan rate is auto-detected.
    Otherwise, scan rates must be provided.

    Args:
        files: Uploaded CSV files (paths or objects exposing ``.name``), one
            per scan rate.
        scan_rates_text: Comma-separated scan rates in V/s, one per file, or
            empty to use the rate detected by ``parse_cv_csv``.
        E0_V: Formal potential in V. Empty/None/0 means "estimate it": the
            median of ``estimate_E0`` over all files is used.
        T_K: Temperature in K (default 298.15 when empty).
        A_cm2: Electrode area in cm² (default 0.0707 when empty).
        C_mM: Concentration in mM (default 1.0 mM when empty).
        D_cm2s: Diffusion coefficient in cm²/s (default 1e-5 when empty).
            NOTE(review): the UI hint mentions a Randles–Ševčík estimate,
            but this function substitutes 1e-5 -- confirm intended behavior.
        n_electrons: Number of electrons (default 1 when empty).
        n_samples: Number of posterior samples.

    Returns:
        Whatever ``_run_ec_analysis`` returns, or ``_ec_error(message)``
        for invalid input or unreadable files.
    """
    if not files:
        return _ec_error("Please upload at least one CSV file.")
    scan_rates_text = scan_rates_text.strip() if scan_rates_text else ""
    user_rates = None
    if scan_rates_text:
        try:
            user_rates = [float(s.strip()) for s in scan_rates_text.split(",")]
        except ValueError:
            return _ec_error("Invalid scan rates. Enter comma-separated numbers in V/s.")
        if len(files) != len(user_rates):
            return _ec_error(
                f"Number of files ({len(files)}) must match number of "
                f"scan rates ({len(user_rates)}).")
    C_molcm3 = float(C_mM) * 1e-6 if C_mM else 1e-6  # 1 mM = 1e-6 mol/cm³
    n = int(n_electrons) if n_electrons else 1
    T = float(T_K) if T_K else 298.15
    A = float(A_cm2) if A_cm2 else 0.0707
    parsed_data = []
    scan_rates = []
    for idx, f in enumerate(files):
        path = _upload_path(f)
        label = Path(path).name
        try:
            parsed = parse_cv_csv(_read_upload_text(path))
        except Exception as exc:  # UI handler boundary: report, don't crash
            return _ec_error(f"Could not read CSV file '{label}': {exc}")
        parsed_data.append(parsed)
        if user_rates is not None:
            v = user_rates[idx]
        elif "scan_rate_Vs" in parsed:
            v = parsed["scan_rate_Vs"]
        else:
            return _ec_error(
                f"Cannot determine scan rate for file '{label}'. "
                "Either provide scan rates (V/s) or upload CSV files that "
                "include a Time (s) column.")
        # A zero, negative, NaN or infinite rate would only surface later
        # as NaN/inf in the nondimensionalized inputs.
        if not (np.isfinite(v) and v > 0):
            return _ec_error(
                f"Scan rate for file '{label}' must be a positive number "
                f"in V/s (got {v}).")
        scan_rates.append(v)
    # Empty or 0 E0 means "auto-estimate", matching the image path.
    if E0_V:
        e0 = float(E0_V)
    else:
        e0_estimates = [estimate_E0(p["E_V"], p["i_A"]) for p in parsed_data]
        e0 = float(np.median(e0_estimates))
    D = float(D_cm2s) if D_cm2s else 1e-5
    potentials, fluxes, sigmas_list = [], [], []
    for parsed, v in zip(parsed_data, scan_rates):
        theta, flux, sigma = nondimensionalize_cv(
            parsed["E_V"], parsed["i_A"], v, e0, T, A, C_molcm3, D, n
        )
        potentials.append(theta)
        fluxes.append(flux)
        sigmas_list.append(sigma)
    return _run_ec_analysis(potentials, fluxes, sigmas_list, n_samples)
# Multipliers that convert a current expressed in the given unit to amperes.
# Keys are the unit labels returned by _guess_current_unit and offered in
# the image tab's "Current unit on y-axis" dropdown. analyze_cv_image falls
# back to 1e-6 (µA) for any label not listed here.
_CURRENT_UNIT_SCALES = {
    "µA": 1e-6,
    "mA": 1e-3,
    "A": 1.0,
    "nA": 1e-9,
}
def _guess_current_unit(i_values):
    """Guess the current unit of digitized y-axis values from their magnitude.

    The largest finite absolute value decides:

    * ``> 1e3``          -> ``"nA"``
    * ``(0.1, 1e3]``     -> ``"µA"``
    * ``(1e-4, 0.1]``    -> ``"mA"``
    * otherwise          -> ``"A"``

    NaN/inf entries are ignored: with a plain ``max`` a single NaN makes
    every comparison False and silently selects ``"A"`` (a 10^6 scale error
    versus µA), and an inf would force ``"nA"``. Input with no finite values
    (including an empty array) returns ``"A"`` instead of raising.

    Args:
        i_values: Array-like of digitized current values.

    Returns:
        One of the unit keys of ``_CURRENT_UNIT_SCALES``.
    """
    magnitudes = np.abs(np.asarray(i_values, dtype=float))
    finite = magnitudes[np.isfinite(magnitudes)]
    if finite.size == 0:
        return "A"
    i_max = finite.max()
    if i_max > 1e3:
        return "nA"
    if i_max > 0.1:
        return "µA"
    if i_max > 1e-4:
        return "mA"
    return "A"
def analyze_cv_image(files, scan_rate_text, E0_V, threshold, current_unit,
                     n_samples, x_min, x_max, y_min, y_max,
                     method_label="Ensemble (recommended)"):
    """Analyze CV from uploaded plot images (one per scan rate).

    When the image-input SPARK model is available we run a hybrid pipeline:
    image-mode prediction in parallel with the digitizer + waveform-mode
    path, combined via ``method_label`` (Ensemble / Image-direct /
    Digitize-then-infer / Auto-fallback). When only the waveform model is
    available, this falls back to digitize-then-infer transparently.

    Each image is digitized into (potential, current) points, the current
    is scaled to amperes, a shared E0 is chosen, and every curve is
    nondimensionalized before being passed to ``_run_ec_analysis_hybrid``
    together with a grayscale copy of each image.

    Args:
        files: Uploaded images (paths or objects exposing ``.name``), one
            per scan rate.
        scan_rate_text: Comma-separated scan rates in V/s, one per image.
        E0_V: Formal potential in V; None/0 means "use the median of the
            per-image ``estimate_E0`` values".
        threshold: Binarization threshold for the digitizer (0 = auto).
        current_unit: Y-axis unit, or "auto" to use the unit detected with
            the axis bounds, else a magnitude-based guess.
        n_samples: Number of posterior samples.
        x_min, x_max, y_min, y_max: Optional axis overrides. They are used
            when all four are filled in and not all zero, so a single 0
            bound (e.g. a 0 V window edge) is honoured.
        method_label: Inference-strategy label selected in the UI.

    Returns:
        10 outputs: the existing 8 (banner, probs, state, mech_dd, post,
        table, recon, conc) plus method-comparison HTML + preprocessed-image
        gallery preview; or ``_ec_image_error(message)`` on invalid input.
    """
    err = _ec_image_error
    if not files:
        return err("Please upload at least one image.")
    try:
        from digitizer import digitize_plot, auto_detect_axis_bounds
        from PIL import Image as PILImage
    except ImportError:
        return err("Required libraries not available for image digitization.")
    scan_rate_text = scan_rate_text.strip() if scan_rate_text else ""
    if not scan_rate_text:
        return err("Please enter the scan rate(s) (V/s), comma-separated.")
    try:
        scan_rates = [float(s.strip()) for s in scan_rate_text.split(",")]
    except ValueError:
        return err("Invalid scan rates. Enter comma-separated numbers in V/s.")
    if len(files) != len(scan_rates):
        return err(
            f"Number of images ({len(files)}) must match number of "
            f"scan rates ({len(scan_rates)}).")
    if not all(np.isfinite(v) and v > 0 for v in scan_rates):
        return err("Scan rates must be positive numbers in V/s.")
    # Use the overrides only when all four are given. An all-zero set is
    # still treated as "not provided" (presumably what untouched Number
    # fields can deliver), but one legitimate 0 bound no longer discards
    # the user's overrides in favour of auto-detection.
    overrides = [x_min, x_max, y_min, y_max]
    has_user_bounds = (all(v is not None for v in overrides)
                       and any(v != 0 for v in overrides))
    # The image tab has no physical-parameter inputs, so the CSV path's
    # defaults are used: D (cm²/s), T (K), A (cm²), C (mol/cm³ = 1 mM), n.
    D = 1e-5
    T = 298.15
    A = 0.0707
    C_molcm3 = 1e-6
    n = 1
    # First pass: digitize all images and estimate E0 per image
    image_data = []
    e0_estimates = []
    pil_images = []
    for idx, (f, v_Vs) in enumerate(zip(files, scan_rates), start=1):
        path = f if isinstance(f, (str, os.PathLike)) else f.name
        try:
            # Open each file once: RGB array for the digitizer, grayscale
            # copy (converted from the original image) for the image CNN.
            with PILImage.open(path) as im:
                img_arr = np.array(im.convert("RGB"))
                pil_images.append(im.convert("L"))
        except OSError as exc:
            return err(f"Could not open image {idx}: {exc}")
        if has_user_bounds:
            bounds = {
                "x_min": float(x_min), "x_max": float(x_max),
                "y_min": float(y_min), "y_max": float(y_max),
            }
        else:
            bounds = auto_detect_axis_bounds(img_arr)
            if bounds is None:
                return err(
                    f"Could not auto-detect axis bounds for image {idx}. "
                    "Please enter E min, E max, I min, I max under "
                    "'Advanced parameters'.")
        try:
            E_V, I_raw = digitize_plot(
                img_arr, bounds["x_min"], bounds["x_max"],
                bounds["y_min"], bounds["y_max"],
                threshold=int(threshold),
                x_ticks=bounds.get("x_ticks"),
                y_ticks=bounds.get("y_ticks"),
            )
        except Exception as exc:
            return err(f"Digitization failed for image {idx}: {exc}")
        if current_unit and current_unit != "auto":
            i_unit = current_unit
        elif "y_unit" in bounds:
            i_unit = bounds["y_unit"]
        else:
            i_unit = _guess_current_unit(I_raw)
        i_A = I_raw * _CURRENT_UNIT_SCALES.get(i_unit, 1e-6)
        e0_estimates.append(float(estimate_E0(E_V, i_A)))
        image_data.append((E_V, i_A, v_Vs))
    # Determine E0: user-provided or median of per-image estimates
    if E0_V is not None and E0_V != 0:
        e0 = float(E0_V)
    else:
        e0 = float(np.median(e0_estimates))
    # Second pass: nondimensionalize with shared E0
    potentials, fluxes, sigmas_list = [], [], []
    for E_V, i_A, v_Vs in image_data:
        theta, flux, sigma = nondimensionalize_cv(
            E_V, i_A, v_Vs, e0, T, A, C_molcm3, D, n
        )
        potentials.append(theta)
        fluxes.append(flux)
        sigmas_list.append(sigma)
    return _run_ec_analysis_hybrid(
        pil_images, sigmas_list, potentials, fluxes, n_samples, method_label,
    )
# Cut-off on the OOD head's P(in-distribution) score below which the OOD
# banner is shown. Override with SPARK_OOD_THRESHOLD; a non-numeric value
# raises ValueError when the module is imported.
OOD_THRESHOLD = float(os.environ.get("SPARK_OOD_THRESHOLD", "0.5"))
def _ood_banner_update(ood_score, threshold=OOD_THRESHOLD):
"""Build the gr.HTML update for the OOD banner.
Shown only when the binary OOD head's P(in-distribution) score drops
below `threshold`. Hidden otherwise.
"""
if ood_score is None or ood_score >= threshold:
return gr.update(visible=False)
html = (
"
")
gr.HTML(
"
"
"
SPARK
"
"
Simulation-based Posterior "
"Amortization for Reaction "
"Kinetics — Bayesian inference for "
"reaction mechanisms and kinetic parameters from cyclic "
"voltammetry (CV) or temperature-programmed desorption (TPD) "
"data, in one forward pass.
"
"
"
"
"
)
with gr.Tabs():
# =================================================================
# Tab 1: CV Analysis
# =================================================================
with gr.Tab("CV Analysis"):
cv_mode = gr.Radio(
choices=["CSV Data", "From Image"],
value="CSV Data",
label=None,
show_label=False,
elem_classes=["mode-radio"],
)
# --- CSV upload mode (primary) ---
with gr.Column(visible=True) as cv_csv_group:
with gr.Group(elem_classes=["input-card"]):
gr.HTML(
"
One CSV per scan rate. Expected columns: "
"Potential (V), Current (A/mA/µA), "
"optional Time (s) — if present, the scan rate is "
"inferred as the median of |dE/dt| over the forward "
"sweep (no need to specify it manually).
"
)
if CV_EXAMPLES:
with gr.Row():
cv_example_dd = gr.Dropdown(
label="Try an example",
choices=list(CV_EXAMPLES.keys()),
value=None,
interactive=True,
scale=4,
)
cv_example_btn = gr.Button(
"Load", variant="secondary", scale=1,
)
cv_files = gr.File(
label="CSV files (one per scan rate)",
file_count="multiple",
file_types=[".csv", ".txt"],
)
cv_rates = gr.Textbox(
label="Scan rates (V/s), comma-separated",
placeholder="e.g., 0.01, 0.1, 1.0 (leave empty if CSV has time column)",
value="",
)
with gr.Accordion("Advanced parameters", open=False):
with gr.Row():
cv_E0 = gr.Number(
label="Formal potential E\u2080 (V)",
value=None,
info="Auto-estimated from peak positions if empty",
)
cv_T = gr.Number(label="Temperature (K)", value=298.15)
cv_A = gr.Number(label="Electrode area (cm\u00b2)", value=0.0707)
with gr.Row():
cv_C = gr.Number(label="Concentration (mM)", value=1.0)
cv_D = gr.Number(
label="Diffusion coeff D (cm\u00b2/s)",
value=None,
info="Estimated via Randles-\u0160ev\u010d\u00edk if empty",
)
cv_n = gr.Number(label="Number of electrons", value=1, precision=0)
cv_nsamples = gr.Slider(
100, 2000, value=500, step=100,
label="Posterior samples",
)
cv_btn = gr.Button(
"Analyze", variant="primary",
elem_classes=["analyze-btn"],
)
(cv_ood_banner, cv_probs, cv_state,
cv_mech_dd, cv_posteriors, cv_param_table,
cv_recon, cv_conc) = _build_ec_output_section("cv")
# --- Image mode ---
with gr.Column(visible=False) as cv_image_group:
with gr.Group(elem_classes=["input-card"]):
gr.HTML(
"
One image per scan rate (potential on x-axis, "
"current on y-axis). Axis bounds auto-detected via OCR — "
"override in Advanced if needed. "
"Note: digitized curves are inherently noisier than the "
"simulated data the model was trained on, so the OOD detector may "
"flag image inputs and reconstruction quality is typically lower "
"than the CSV path; for production use, prefer CSV.
"
)
if CV_IMG_EXAMPLES:
with gr.Row():
cv_img_example_dd = gr.Dropdown(
label="Try an example",
choices=list(CV_IMG_EXAMPLES.keys()),
value=None,
interactive=True,
scale=4,
)
cv_img_example_btn = gr.Button(
"Load", variant="secondary", scale=1,
)
cv_img_example_gallery = gr.Gallery(
label="Input plot images (one per scan rate) \u2014 these are what the model digitizes",
columns=3, height=220, object_fit="contain",
interactive=False,
)
cv_img_files = gr.File(
label="Plot images (one per scan rate)",
file_count="multiple",
file_types=["image"],
)
cv_img_scan_rate = gr.Textbox(
label="Scan rates (V/s), comma-separated",
placeholder="e.g., 0.01, 0.1, 1.0",
value="",
)
with gr.Accordion("Advanced parameters", open=False):
with gr.Row():
cv_img_E0 = gr.Number(
label="Formal potential E\u2080 (V)",
value=None,
info="Auto-estimated from peaks if empty",
)
cv_img_threshold = gr.Slider(
0, 255, value=0, step=1,
label="Binarization threshold (0 = auto)",
)
cv_img_current_unit = gr.Dropdown(
label="Current unit on y-axis",
choices=["auto", "\u00b5A", "mA", "A", "nA"],
value="auto",
info="Select the unit shown on the y-axis of your plot",
)
with gr.Row():
cv_img_xmin = gr.Number(label="E min (V)", value=None)
cv_img_xmax = gr.Number(label="E max (V)", value=None)
cv_img_ymin = gr.Number(label="I min", value=None)
cv_img_ymax = gr.Number(label="I max", value=None)
cv_img_nsamples = gr.Slider(
100, 2000, value=500, step=100,
label="Posterior samples",
)
_cv_choices, _cv_default = _hybrid_choices_and_default("ec")
cv_img_method = gr.Radio(
choices=_cv_choices,
value=_cv_default,
label="Inference strategy",
info=("Joint (Phase 2) fuses the rendered image "
"with digitizer-extracted waveforms in a "
"single encoder; Ensemble averages "
"image-direct and digitize-then-infer; "
"Auto-fallback uses image-direct unless "
"its OOD score is low."),
)
cv_img_btn = gr.Button(
"Analyze", variant="primary",
elem_classes=["analyze-btn"],
)
(cv_img_ood_banner, cv_img_probs, cv_img_state,
cv_img_mech_dd, cv_img_posteriors, cv_img_param_table,
cv_img_recon, cv_img_conc) = _build_ec_output_section("cv_img")
with gr.Accordion("Method comparison", open=False):
cv_img_comparison = gr.HTML(visible=False)
cv_img_preview = gr.Gallery(
label="Preprocessed image fed to image-mode CNN",
columns=3, height=180, object_fit="contain",
interactive=False, visible=False,
)
cv_mode.change(
lambda v: (
gr.update(visible=v == "CSV Data"),
gr.update(visible=v == "From Image"),
),
inputs=[cv_mode],
outputs=[cv_csv_group, cv_image_group],
)
# CSV wiring
ec_outputs = [
cv_ood_banner,
cv_probs, cv_state,
cv_mech_dd, cv_posteriors, cv_param_table,
cv_recon, cv_conc,
]
cv_btn.click(
analyze_cv,
inputs=[
cv_files, cv_rates, cv_E0, cv_T,
cv_A, cv_C, cv_D, cv_n, cv_nsamples,
],
outputs=ec_outputs,
)
cv_mech_dd.change(
_on_ec_mechanism_change,
inputs=[cv_mech_dd, cv_state],
outputs=[cv_posteriors, cv_param_table, cv_recon, cv_conc],
)
if CV_EXAMPLES:
cv_example_btn.click(
_load_cv_example,
inputs=[cv_example_dd],
outputs=[
cv_files, cv_rates, cv_E0, cv_T,
cv_A, cv_C, cv_D, cv_n,
],
)
# Image wiring (hybrid: 8 base outputs + comparison + preview)
ec_img_outputs = [
cv_img_ood_banner,
cv_img_probs, cv_img_state,
cv_img_mech_dd, cv_img_posteriors, cv_img_param_table,
cv_img_recon, cv_img_conc,
cv_img_comparison, cv_img_preview,
]
cv_img_btn.click(
analyze_cv_image,
inputs=[
cv_img_files, cv_img_scan_rate, cv_img_E0,
cv_img_threshold, cv_img_current_unit,
cv_img_nsamples,
cv_img_xmin, cv_img_xmax,
cv_img_ymin, cv_img_ymax,
cv_img_method,
],
outputs=ec_img_outputs,
)
cv_img_mech_dd.change(
_on_ec_mechanism_change,
inputs=[cv_img_mech_dd, cv_img_state],
outputs=[cv_img_posteriors, cv_img_param_table, cv_img_recon, cv_img_conc],
)
if CV_IMG_EXAMPLES:
def _cv_img_input_gallery(mech_name):
"""Show the actual input PNGs the user will analyze."""
if not mech_name or mech_name not in CV_IMG_EXAMPLES:
return []
return [e[0] for e in CV_IMG_EXAMPLES[mech_name]]
def _on_cv_img_example_select(mech_name):
files, rates = _load_cv_image_example(mech_name)
return files, rates, _cv_img_input_gallery(mech_name)
cv_img_example_btn.click(
_on_cv_img_example_select,
inputs=[cv_img_example_dd],
outputs=[cv_img_files, cv_img_scan_rate, cv_img_example_gallery],
)
cv_img_example_dd.change(
_cv_img_input_gallery,
inputs=[cv_img_example_dd],
outputs=[cv_img_example_gallery],
)
# =================================================================
# Tab 2: TPD Analysis
# =================================================================
with gr.Tab("TPD Analysis"):
tpd_mode = gr.Radio(
choices=["CSV Data", "From Image"],
value="CSV Data",
label=None,
show_label=False,
elem_classes=["mode-radio"],
)
# --- CSV mode ---
with gr.Column(visible=True) as tpd_csv_group:
with gr.Group(elem_classes=["input-card"]):
gr.HTML(
"
One CSV per heating rate. Expected columns: "
"Temperature (K), Signal, optional "
"Time (s) — if present, the heating rate β is "
"inferred as dT/dt from the time stamps "
"(no need to specify it manually).
"
)
if TPD_EXAMPLES:
with gr.Row():
tpd_example_dd = gr.Dropdown(
label="Try an example",
choices=list(TPD_EXAMPLES.keys()),
value=None,
interactive=True,
scale=4,
)
tpd_example_btn = gr.Button(
"Load", variant="secondary", scale=1,
)
tpd_files = gr.File(
label="CSV files (one per heating rate)",
file_count="multiple",
file_types=[".csv", ".txt"],
)
tpd_betas = gr.Textbox(
label="Heating rates \u03b2 (K/s), comma-separated",
placeholder="e.g., 0.3, 2.6, 22.1 (leave empty if CSV has time column)",
value="",
)
with gr.Accordion("Advanced parameters", open=False):
tpd_nsamples = gr.Slider(
100, 2000, value=500, step=100,
label="Posterior samples",
)
tpd_btn = gr.Button(
"Analyze", variant="primary",
elem_classes=["analyze-btn"],
)
(tpd_ood_banner, tpd_probs, tpd_state,
tpd_mech_dd, tpd_posteriors, tpd_param_table,
tpd_recon) = _build_tpd_output_section("tpd")
# --- Image mode ---
with gr.Column(visible=False) as tpd_image_group:
with gr.Group(elem_classes=["input-card"]):
gr.HTML(
"
One image per heating rate (temperature on "
"x-axis, signal on y-axis). Axis bounds auto-detected via OCR "
"— override in Advanced if needed. "
"Note: digitized curves are inherently noisier than the "
"simulated data the model was trained on, so the OOD detector may "
"flag image inputs and reconstruction quality is typically lower "
"than the CSV path; for production use, prefer CSV.
"
)
if TPD_IMG_EXAMPLES:
with gr.Row():
tpd_img_example_dd = gr.Dropdown(
label="Try an example",
choices=list(TPD_IMG_EXAMPLES.keys()),
value=None,
interactive=True,
scale=4,
)
tpd_img_example_btn = gr.Button(
"Load", variant="secondary", scale=1,
)
tpd_img_example_gallery = gr.Gallery(
label="Input plot image \u2014 this is what the model digitizes",
columns=3, height=220, object_fit="contain",
interactive=False,
)
tpd_img_files = gr.File(
label="Plot images (one per heating rate)",
file_count="multiple",
file_types=["image"],
)
tpd_img_betas = gr.Textbox(
label="Heating rates \u03b2 (K/s), comma-separated",
placeholder="e.g., 0.3, 2.6, 22.1",
value="",
)
with gr.Accordion("Advanced parameters", open=False):
with gr.Row():
tpd_img_threshold = gr.Slider(
0, 255, value=0, step=1,
label="Binarization threshold (0 = auto)",
)
with gr.Row():
tpd_img_xmin = gr.Number(label="T min (K)", value=None)
tpd_img_xmax = gr.Number(label="T max (K)", value=None)
tpd_img_ymin = gr.Number(label="Signal min", value=None)
tpd_img_ymax = gr.Number(label="Signal max", value=None)
tpd_img_nsamples = gr.Slider(
100, 2000, value=500, step=100,
label="Posterior samples",
)
_tpd_choices, _tpd_default = _hybrid_choices_and_default("tpd")
tpd_img_method = gr.Radio(
choices=_tpd_choices,
value=_tpd_default,
label="Inference strategy",
info=("Joint (Phase 2) fuses the image with "
"digitizer-extracted waveforms in a "
"single encoder; Ensemble averages "
"image-direct and digitize-then-infer; "
"Auto-fallback uses image-direct unless "
"its OOD score is low."),
)
tpd_img_btn = gr.Button(
"Analyze", variant="primary",
elem_classes=["analyze-btn"],
)
(tpd_img_ood_banner, tpd_img_probs, tpd_img_state,
tpd_img_mech_dd, tpd_img_posteriors, tpd_img_param_table,
tpd_img_recon) = _build_tpd_output_section("tpd_img")
with gr.Accordion("Method comparison", open=False):
tpd_img_comparison = gr.HTML(visible=False)
tpd_img_preview = gr.Gallery(
label="Preprocessed image fed to image-mode CNN",
columns=3, height=180, object_fit="contain",
interactive=False, visible=False,
)
tpd_mode.change(
lambda v: (
gr.update(visible=v == "CSV Data"),
gr.update(visible=v == "From Image"),
),
inputs=[tpd_mode],
outputs=[tpd_csv_group, tpd_image_group],
)
# CSV wiring
tpd_outputs = [
tpd_ood_banner,
tpd_probs, tpd_state,
tpd_mech_dd, tpd_posteriors, tpd_param_table, tpd_recon,
]
tpd_btn.click(
analyze_tpd,
inputs=[tpd_files, tpd_betas, tpd_nsamples],
outputs=tpd_outputs,
)
tpd_mech_dd.change(
_on_tpd_mechanism_change,
inputs=[tpd_mech_dd, tpd_state],
outputs=[tpd_posteriors, tpd_param_table, tpd_recon],
)
if TPD_EXAMPLES:
tpd_example_btn.click(
_load_tpd_example,
inputs=[tpd_example_dd],
outputs=[tpd_files, tpd_betas],
)
# Image wiring (hybrid: 7 base outputs + comparison + preview)
tpd_img_outputs = [
tpd_img_ood_banner,
tpd_img_probs, tpd_img_state,
tpd_img_mech_dd, tpd_img_posteriors, tpd_img_param_table, tpd_img_recon,
tpd_img_comparison, tpd_img_preview,
]
tpd_img_btn.click(
analyze_tpd_image,
inputs=[
tpd_img_files, tpd_img_betas,
tpd_img_threshold, tpd_img_nsamples,
tpd_img_xmin, tpd_img_xmax,
tpd_img_ymin, tpd_img_ymax,
tpd_img_method,
],
outputs=tpd_img_outputs,
)
tpd_img_mech_dd.change(
_on_tpd_mechanism_change,
inputs=[tpd_img_mech_dd, tpd_img_state],
outputs=[tpd_img_posteriors, tpd_img_param_table, tpd_img_recon],
)
if TPD_IMG_EXAMPLES:
def _tpd_img_input_gallery(mech_name):
if not mech_name or mech_name not in TPD_IMG_EXAMPLES:
return []
return [e[0] for e in TPD_IMG_EXAMPLES[mech_name]]
def _on_tpd_img_example_select(mech_name):
files, betas = _load_tpd_image_example(mech_name)
return files, betas, _tpd_img_input_gallery(mech_name)
tpd_img_example_btn.click(
_on_tpd_img_example_select,
inputs=[tpd_img_example_dd],
outputs=[tpd_img_files, tpd_img_betas, tpd_img_example_gallery],
)
tpd_img_example_dd.change(
_tpd_img_input_gallery,
inputs=[tpd_img_example_dd],
outputs=[tpd_img_example_gallery],
)
# =================================================================
# Tab 3: About
# =================================================================
with gr.Tab("About"):
with gr.Row(elem_classes=["about-card"]):
with gr.Column(scale=2):
gr.Markdown("""
## How it works
**SPARK** (Simulation-based Posterior Amortization for Reaction Kinetics) uses
**conditional normalizing flows** with a **Set Transformer** encoder to
perform amortized Bayesian inference. Given one or more experimental
curves, it simultaneously classifies the reaction mechanism and produces
full posterior distributions over kinetic parameters — in a single
forward pass.
The deployed checkpoints are the **noise-augmented headline models**
(CV: `v14_9mech`, TPD: `tpd_11mech_v2`) that retain 92.4 % (CV) and
95.6 % (TPD) classification accuracy under realistic measurement noise.
Inference runs in **~50 ms on CPU**; mean 90 %-credible-interval coverage
is 92.2 % (CV) and 92.9 % (TPD).
Training data is generated from physics-based simulators
(Crank–Nicolson for CV, ODE integrators for TPD). Posteriors are
calibrated via a coverage-aware loss with per-parameter inverse-spread
weighting.
### Citation
```
Yan, B. (2026). SPARK: Amortized Bayesian Inference for
Mechanism Identification and Parameter Estimation in
Electrochemistry and Catalysis via Conditional
Normalizing Flows. [Preprint]
```
""")
with gr.Column(scale=1):
gr.Markdown("""
### Supported mechanisms
**Electrochemistry (CV, 9):**
Nernst, Butler–Volmer, Marcus–Hush–Chidsey,
Adsorption, EC, Langmuir–Hinshelwood, EE, EC′, CE.
**Catalysis (TPD, 11):**
First-order, Second-order, Zeroth-order, FirstOrderCovDep,
DiffLimited, Precursor, Dissociative, ActivatedAdsorption,
LH Surface, Mars–van Krevelen, TwoSite.
""")
gr.HTML(
""
)
gr.HTML("
")
return app
# =========================================================================
# Entry point
# =========================================================================
if __name__ == "__main__":
    # Listen on all interfaces at port 7860, without a public share link,
    # and with show_error enabled so handler errors are reported in the UI.
    _launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    app = build_app()
    app.launch(**_launch_options)