trace / app.py
bingyan user
Fix page header: SPARK acronym expansion (was leftover TRACE)
5b37efe
Raw
History Blame Contribute Delete
73.1 kB
"""
SPARK (Simulation-based Posterior Amortization for Reaction Kinetics)
Gradio web interface for mechanism classification and parameter inference
from cyclic voltammetry (CV) and temperature-programmed desorption (TPD) data.
"""
import os
import sys
import json
import tempfile
from pathlib import Path
import numpy as np
import gradio as gr
from inference import SPARKPredictor
from preprocessing import (
nondimensionalize_cv,
estimate_E0,
parse_cv_csv,
parse_tpd_csv,
)
from plotting import (
plot_mechanism_probs,
plot_posteriors,
plot_parameter_table,
plot_reconstruction,
plot_concentration_profiles,
)
# ---------------------------------------------------------------------------
# Filesystem layout and model checkpoints (all anchored at the directory
# that contains this file)
# ---------------------------------------------------------------------------
REPO_ROOT = Path(__file__).resolve().parent
DEMO_DIR = REPO_ROOT.joinpath("demo_data")
DEMO_RENDERS = REPO_ROOT.joinpath("demo_renders")

# Waveform checkpoints. Lookup order for each:
#   1. SPARK_*_CHECKPOINT   - canonical environment variable
#   2. ECFLOW_*_CHECKPOINT  - legacy name, still accepted for backward
#                             compatibility with existing HF Space secrets
#   3. the default file under checkpoints/
EC_CHECKPOINT = REPO_ROOT.joinpath("checkpoints", "ec_best.pt")
EC_CHECKPOINT = Path(os.getenv(
    "SPARK_EC_CHECKPOINT",
    os.getenv("ECFLOW_EC_CHECKPOINT", str(EC_CHECKPOINT)),
))

TPD_CHECKPOINT = REPO_ROOT.joinpath("checkpoints", "tpd_best.pt")
TPD_CHECKPOINT = Path(os.getenv(
    "SPARK_TPD_CHECKPOINT",
    os.getenv("ECFLOW_TPD_CHECKPOINT", str(TPD_CHECKPOINT)),
))

# Optional image-input checkpoints. When set and the file exists, the
# "From Image" tabs use the image-input SPARK model directly instead of
# digitizing the plot. If not set, the digitizer fallback is used.
# An unset or empty variable yields None for the Path constant.
EC_IMAGE_CHECKPOINT_PATH = os.getenv("SPARK_EC_IMAGE_CHECKPOINT", "")
TPD_IMAGE_CHECKPOINT_PATH = os.getenv("SPARK_TPD_IMAGE_CHECKPOINT", "")

EC_IMAGE_CHECKPOINT = None
if EC_IMAGE_CHECKPOINT_PATH:
    EC_IMAGE_CHECKPOINT = Path(EC_IMAGE_CHECKPOINT_PATH)

TPD_IMAGE_CHECKPOINT = None
if TPD_IMAGE_CHECKPOINT_PATH:
    TPD_IMAGE_CHECKPOINT = Path(TPD_IMAGE_CHECKPOINT_PATH)

# Optional Phase 2 joint (image + waveform) checkpoints. When present, the
# "From Image" tabs expose a "Joint image+waveform (Phase 2)" inference
# strategy that fuses the rendered image with the digitizer-extracted
# waveform inside a single encoder. Falls back to ensemble if missing.
EC_JOINT_CHECKPOINT_PATH = os.getenv("SPARK_EC_JOINT_CHECKPOINT", "")
TPD_JOINT_CHECKPOINT_PATH = os.getenv("SPARK_TPD_JOINT_CHECKPOINT", "")

EC_JOINT_CHECKPOINT = None
if EC_JOINT_CHECKPOINT_PATH:
    EC_JOINT_CHECKPOINT = Path(EC_JOINT_CHECKPOINT_PATH)

TPD_JOINT_CHECKPOINT = None
if TPD_JOINT_CHECKPOINT_PATH:
    TPD_JOINT_CHECKPOINT = Path(TPD_JOINT_CHECKPOINT_PATH)
# ---------------------------------------------------------------------------
# Demo examples
# ---------------------------------------------------------------------------
def _discover_examples(demo_dir=None):
    """Scan a demo directory for ``*_metadata.json`` files and build example catalogs.

    Files whose name starts with ``ec_`` feed the CV catalog, files starting
    with ``tpd_`` feed the TPD catalog; any other metadata file is ignored.

    Args:
        demo_dir: Directory to scan. ``None`` (the default) scans ``DEMO_DIR``.

    Returns:
        ``(cv_examples, tpd_examples)`` - two dicts keyed by display label
        (``"CV \u2014 <mechanism>"`` / ``"TPD \u2014 <mechanism>"``, with a
        trailing ``" *OOD*"`` when the metadata sets ``is_ood``). Both are
        empty when the directory does not exist.

    Raises:
        KeyError: a metadata file lacks ``"mechanism"`` or ``"csv_files"``.
    """
    root = DEMO_DIR if demo_dir is None else Path(demo_dir)
    cv_examples = {}
    tpd_examples = {}
    if not root.is_dir():
        return cv_examples, tpd_examples

    for meta_path in sorted(root.glob("*_metadata.json")):
        # JSON is UTF-8 by specification; do not depend on the process locale.
        with open(meta_path, encoding="utf-8") as f:
            meta = json.load(f)

        mech = meta["mechanism"]
        csv_files = [str(root / fn) for fn in meta["csv_files"]]
        is_ood = bool(meta.get("is_ood", False))
        ood_tag = " *OOD*" if is_ood else ""

        if meta_path.name.startswith("ec_"):
            rates = meta.get("scan_rates_Vs", [])
            phys = meta.get("physical_params", {})
            cv_examples[f"CV \u2014 {mech}{ood_tag}"] = {
                "files": csv_files,
                "scan_rates": ", ".join(f"{r:.4g}" for r in rates),
                "is_ood": is_ood,
                "E0_V": phys.get("E0_V"),
                "T_K": phys.get("T_K", 298.15),
                "A_cm2": phys.get("A_cm2", 0.0707),
                "C_mM": phys.get("C_mM", 1.0),
                "D_cm2s": phys.get("D_cm2s", 1e-5),
                "n_electrons": phys.get("n_electrons", 1),
            }
        elif meta_path.name.startswith("tpd_"):
            betas = meta.get("betas_Ks", [])
            tpd_examples[f"TPD \u2014 {mech}{ood_tag}"] = {
                "files": csv_files,
                "heating_rates": ", ".join(f"{b:.4g}" for b in betas),
                "is_ood": is_ood,
            }
    return cv_examples, tpd_examples
# CSV demo-example catalogs (label -> example settings), built once at import.
CV_EXAMPLES, TPD_EXAMPLES = _discover_examples()
def _discover_image_examples():
    """Index the plot images and pre-rendered result figures in demo_renders/.

    Returns a 4-tuple of dicts keyed by mechanism name:

    * CV inputs:  ``[(png_path, scan_rate_V_per_s), ...]`` taken from
      ``ec_<mech>_<N>mVs_physical.png`` (N is divided by 1000).
    * TPD inputs: ``[(png_path, digits), ...]`` taken from
      ``tpd_<mech>_<digits>_physical.png`` (digits kept as a string).
    * CV outputs / TPD outputs: ``{figure_kind: png_path}`` for every
      ``<domain>_<mech>_<figure_kind>.png`` that exists on disk.

    All four dicts are empty when the directory is missing.
    """
    import re

    cv_inputs, tpd_inputs = {}, {}
    cv_outputs, tpd_outputs = {}, {}
    if not DEMO_RENDERS.is_dir():
        return cv_inputs, tpd_inputs, cv_outputs, tpd_outputs

    cv_name = re.compile(r"ec_(\w+?)_(\d+)mVs_physical\.png")
    for png in sorted(DEMO_RENDERS.glob("ec_*_physical.png")):
        hit = cv_name.match(png.name)
        if hit is None:
            continue
        mech = hit.group(1)
        rate_Vs = int(hit.group(2)) / 1000.0
        cv_inputs.setdefault(mech, []).append((str(png), rate_Vs))

    tpd_name = re.compile(r"tpd_(\w+?)_(\d+)_physical\.png")
    for png in sorted(DEMO_RENDERS.glob("tpd_*_physical.png")):
        hit = tpd_name.match(png.name)
        if hit is None:
            continue
        tpd_inputs.setdefault(hit.group(1), []).append((str(png), hit.group(2)))

    def _existing_renders(domain, mech, kinds):
        # Map figure kind -> path for the rendered figures present on disk.
        found = {}
        for kind in kinds:
            candidate = DEMO_RENDERS / f"{domain}_{mech}_{kind}.png"
            if candidate.exists():
                found[kind] = str(candidate)
        return found

    for mech in cv_inputs:
        found = _existing_renders(
            "ec", mech,
            ["classification", "posteriors", "reconstruction", "concentration"],
        )
        if found:
            cv_outputs[mech] = found

    for mech in tpd_inputs:
        found = _existing_renders(
            "tpd", mech,
            ["classification", "posteriors", "reconstruction"],
        )
        if found:
            tpd_outputs[mech] = found

    return cv_inputs, tpd_inputs, cv_outputs, tpd_outputs
# Image-example catalogs and pre-rendered output figures, built once at import.
(CV_IMG_EXAMPLES, TPD_IMG_EXAMPLES,
 CV_OUTPUT_RENDERS, TPD_OUTPUT_RENDERS) = _discover_image_examples()
def _load_cv_image_example(mech_name):
    """Look up a CV image example by mechanism name.

    Returns ``(image_paths, scan_rates_text)`` where the text is the
    comma-separated scan rates in the same order as the paths. For an empty
    or unknown name, returns two no-op ``gr.update()`` values instead.
    """
    if mech_name and mech_name in CV_IMG_EXAMPLES:
        image_paths = []
        rate_tokens = []
        for entry in CV_IMG_EXAMPLES[mech_name]:
            image_paths.append(entry[0])
            rate_tokens.append(str(entry[1]))
        return image_paths, ", ".join(rate_tokens)
    return [gr.update()] * 2
def _load_tpd_image_example(mech_name):
    """Look up a TPD image example by mechanism name.

    Returns the first catalog entry for the mechanism as a pair
    ``(image_path, digits)``; only the first image is used. For an empty or
    unknown name, returns two no-op ``gr.update()`` values instead.

    NOTE(review): the second element is the digit group parsed from the
    file name by ``_discover_image_examples``; callers presumably use it as
    the heating-rate text - confirm that the file names encode the rate.
    """
    known = bool(mech_name) and mech_name in TPD_IMG_EXAMPLES
    if not known:
        return [gr.update()] * 2
    first = TPD_IMG_EXAMPLES[mech_name][0]
    return first[0], first[1]
def _load_cv_example(example_name):
    """Fetch the stored settings of a CSV-based CV example.

    Returns an 8-tuple ``(files, scan_rates, E0, T, A, C, D, n)`` read from
    ``CV_EXAMPLES``. For an empty or unknown name, returns eight no-op
    ``gr.update()`` values instead.
    """
    if not example_name or example_name not in CV_EXAMPLES:
        return [gr.update()] * 8

    example = CV_EXAMPLES[example_name]
    # Output order must match the components wired to this callback.
    ordered_keys = (
        "files", "scan_rates", "E0_V", "T_K",
        "A_cm2", "C_mM", "D_cm2s", "n_electrons",
    )
    return tuple(example[key] for key in ordered_keys)
def _load_tpd_example(example_name):
    """Fetch the stored settings of a CSV-based TPD example.

    Returns ``(files, heating_rates)`` read from ``TPD_EXAMPLES``. For an
    empty or unknown name, returns two no-op ``gr.update()`` values instead.
    """
    if example_name and example_name in TPD_EXAMPLES:
        example = TPD_EXAMPLES[example_name]
        return example["files"], example["heating_rates"]
    return [gr.update()] * 2
# Process-wide predictor instance, created on first use by get_predictor().
predictor = None


def get_predictor():
    """Return the shared ``SPARKPredictor``, constructing it on the first call.

    Each checkpoint is handed to the predictor as a string path only when it
    is configured and the file exists; otherwise ``None`` is passed for it.
    The predictor is created with ``device="cpu"``.
    """
    global predictor
    if predictor is not None:
        return predictor

    def _usable(path):
        # String path for a configured, existing checkpoint; None otherwise.
        if path and path.exists():
            return str(path)
        return None

    predictor = SPARKPredictor(
        ec_checkpoint=_usable(EC_CHECKPOINT),
        tpd_checkpoint=_usable(TPD_CHECKPOINT),
        ec_image_checkpoint=_usable(EC_IMAGE_CHECKPOINT),
        tpd_image_checkpoint=_usable(TPD_IMAGE_CHECKPOINT),
        ec_joint_checkpoint=_usable(EC_JOINT_CHECKPOINT),
        tpd_joint_checkpoint=_usable(TPD_JOINT_CHECKPOINT),
        device="cpu",
    )
    return predictor
# =========================================================================
# CV Analysis
# =========================================================================
def analyze_cv(files, scan_rates_text, E0_V, T_K, A_cm2,
               C_mM, D_cm2s, n_electrons, n_samples):
    """Analyze CV data from potentiostat CSV files.

    Accepts CSV files with columns for potential (V) and current (A/mA/µA).
    If the CSV includes a Time (s) column, the scan rate is auto-detected.
    Otherwise, scan rates must be provided.

    Args:
        files: Uploaded file objects; each is read from its ``.name`` path.
        scan_rates_text: Optional comma-separated scan rates in V/s, one per
            file. When given, these override any rate found in the CSV.
        E0_V: Formal potential in V. Any falsy value (``None``, ``0``, empty)
            means "estimate it": the median of the per-file ``estimate_E0``
            results is used.
        T_K, A_cm2, C_mM, D_cm2s, n_electrons: Physical parameters; a falsy
            value falls back to 298.15, 0.0707, 1 mM (1e-6 mol/cm3), 1e-5
            and 1 respectively.
        n_samples: Number of posterior samples, forwarded to the analysis.

    Returns:
        The 8-tuple of UI outputs produced by ``_run_ec_analysis``, or the
        ``_ec_error`` tuple when the inputs are invalid.
    """
    if not files:
        return _ec_error("Please upload at least one CSV file.")

    scan_rates_text = scan_rates_text.strip() if scan_rates_text else ""
    user_rates = None
    if scan_rates_text:
        try:
            user_rates = [float(s.strip()) for s in scan_rates_text.split(",")]
        except ValueError:
            return _ec_error("Invalid scan rates. Enter comma-separated numbers in V/s.")
        if len(files) != len(user_rates):
            return _ec_error(
                f"Number of files ({len(files)}) must match number of "
                f"scan rates ({len(user_rates)}).")

    C_molcm3 = float(C_mM) * 1e-6 if C_mM else 1e-6
    n = int(n_electrons) if n_electrons else 1
    T = float(T_K) if T_K else 298.15
    A = float(A_cm2) if A_cm2 else 0.0707

    parsed_data = []
    scan_rates = []
    for idx, f in enumerate(files):
        content = Path(f.name).read_text()
        parsed = parse_cv_csv(content)
        parsed_data.append(parsed)
        # Precedence: user-entered rate, then the rate found in the CSV.
        if user_rates is not None:
            v = user_rates[idx]
        elif "scan_rate_Vs" in parsed:
            v = parsed["scan_rate_Vs"]
        else:
            return _ec_error(
                f"Cannot determine scan rate for file '{Path(f.name).name}'. "
                "Either provide scan rates (V/s) or upload CSV files that "
                "include a Time (s) column.")
        scan_rates.append(v)

    # One E0 is shared by all scans so they are scaled consistently.
    if E0_V:
        e0 = float(E0_V)
    else:
        e0_estimates = [estimate_E0(p["E_V"], p["i_A"]) for p in parsed_data]
        e0 = float(np.median(e0_estimates))

    D = float(D_cm2s) if D_cm2s else 1e-5

    potentials, fluxes, sigmas_list = [], [], []
    for parsed, v in zip(parsed_data, scan_rates):
        theta, flux, sigma = nondimensionalize_cv(
            parsed["E_V"], parsed["i_A"], v, e0, T, A, C_molcm3, D, n
        )
        potentials.append(theta)
        fluxes.append(flux)
        sigmas_list.append(sigma)

    return _run_ec_analysis(potentials, fluxes, sigmas_list, n_samples)
# Multiplier that converts a current expressed in the keyed unit to amperes.
_CURRENT_UNIT_SCALES = {
    "µA": 1e-6,
    "mA": 1e-3,
    "A": 1.0,
    "nA": 1e-9,
}
def _guess_current_unit(i_values):
    """Pick a current unit from the largest absolute digitized value.

    The peak magnitude is compared against strict lower bounds, largest
    first: above 1e3 -> "nA", above 0.1 -> "µA", above 1e-4 -> "mA",
    anything else -> "A".
    """
    peak = np.max(np.abs(i_values))
    # (exclusive lower bound, unit), checked from the largest bound down.
    ladder = ((1e3, "nA"), (0.1, "µA"), (1e-4, "mA"))
    for floor, unit in ladder:
        if peak > floor:
            return unit
    return "A"
def analyze_cv_image(files, scan_rate_text, E0_V, threshold, current_unit,
                     n_samples, x_min, x_max, y_min, y_max,
                     method_label="Ensemble (recommended)"):
    """Analyze CV from uploaded plot images (one per scan rate).

    When the image-input SPARK model is available we run a hybrid pipeline:
    image-mode prediction in parallel with the digitizer + waveform-mode path,
    and combine via `method_label` (Ensemble / Image-direct / Digitize-then-infer
    / Auto-fallback). When only the waveform model is available, falls back
    to digitize-then-infer transparently.

    Args:
        files: Uploaded image file objects; each is read from its ``.name``.
        scan_rate_text: Comma-separated scan rates in V/s, one per image
            (required).
        E0_V: Formal potential in V; ``None`` or ``0`` means "use the median
            of the per-image estimates".
        threshold: Digitizer threshold, passed to ``digitize_plot`` as int.
        current_unit: Unit of the plotted current, or ``"auto"``/empty to take
            it from the detected axis (``y_unit``) or guess from magnitude.
        n_samples: Number of posterior samples.
        x_min, x_max, y_min, y_max: Axis overrides. They are used only when
            all four are set and non-zero; otherwise bounds are auto-detected.
        method_label: User-facing inference-strategy label.

    Returns 10 outputs: the existing 8 (banner, probs, state, mech_dd, post,
    table, recon, conc) plus method-comparison HTML + preprocessed-image
    gallery preview.
    """
    err = _ec_image_error
    if not files:
        return err("Please upload at least one image.")
    try:
        from digitizer import digitize_plot, auto_detect_axis_bounds
        from PIL import Image as PILImage
    except ImportError:
        return err("Required libraries not available for image digitization.")

    scan_rate_text = scan_rate_text.strip() if scan_rate_text else ""
    if not scan_rate_text:
        return err("Please enter the scan rate(s) (V/s), comma-separated.")
    try:
        scan_rates = [float(s.strip()) for s in scan_rate_text.split(",")]
    except ValueError:
        return err("Invalid scan rates. Enter comma-separated numbers in V/s.")
    if len(files) != len(scan_rates):
        return err(
            f"Number of images ({len(files)}) must match number of "
            f"scan rates ({len(scan_rates)}).")

    # NOTE(review): a bound of exactly 0 counts as "not provided", so
    # overrides containing a zero are ignored - confirm this is intended.
    has_user_bounds = all(
        v is not None and v != 0 for v in [x_min, x_max, y_min, y_max]
    )

    # The image path has no inputs for these, so fixed values are used.
    D = 1e-5
    T = 298.15
    A = 0.0707
    C_molcm3 = 1e-6
    n = 1

    # First pass: digitize all images and estimate E0 per image
    image_data = []
    e0_estimates = []
    pil_images = []
    for idx, f in enumerate(files):
        # Decode each file once and close its handle promptly; both
        # conversions are fully loaded copies independent of the file.
        with PILImage.open(f.name) as src:
            img_arr = np.array(src.convert("RGB"))
            pil_images.append(src.convert("L"))
        v_Vs = scan_rates[idx]

        if has_user_bounds:
            bounds = {
                "x_min": float(x_min), "x_max": float(x_max),
                "y_min": float(y_min), "y_max": float(y_max),
            }
        else:
            bounds = auto_detect_axis_bounds(img_arr)
            if bounds is None:
                return err(
                    f"Could not auto-detect axis bounds for image {idx + 1}. "
                    "Please enter E min, E max, I min, I max under "
                    "'Axis overrides'.")

        try:
            E_V, I_raw = digitize_plot(
                img_arr, bounds["x_min"], bounds["x_max"],
                bounds["y_min"], bounds["y_max"],
                threshold=int(threshold),
                x_ticks=bounds.get("x_ticks"),
                y_ticks=bounds.get("y_ticks"),
            )
        except Exception as exc:
            return err(f"Digitization failed for image {idx + 1}: {exc}")

        # Unit precedence: explicit user choice, detected axis unit, guess.
        if current_unit and current_unit != "auto":
            i_unit = current_unit
        elif "y_unit" in bounds:
            i_unit = bounds["y_unit"]
        else:
            i_unit = _guess_current_unit(I_raw)
        # Unknown unit strings fall back to the 1e-6 multiplier.
        i_scale = _CURRENT_UNIT_SCALES.get(i_unit, 1e-6)
        i_A = I_raw * i_scale

        e0_estimates.append(float(estimate_E0(E_V, i_A)))
        image_data.append((E_V, i_A, v_Vs))

    # Determine E0: user-provided or median of per-image estimates
    if E0_V is not None and E0_V != 0:
        e0 = float(E0_V)
    else:
        e0 = float(np.median(e0_estimates))

    # Second pass: nondimensionalize with shared E0
    potentials, fluxes, sigmas_list = [], [], []
    for E_V, i_A, v_Vs in image_data:
        theta, flux, sigma = nondimensionalize_cv(
            E_V, i_A, v_Vs, e0, T, A, C_molcm3, D, n
        )
        potentials.append(theta)
        fluxes.append(flux)
        sigmas_list.append(sigma)

    return _run_ec_analysis_hybrid(
        pil_images, sigmas_list, potentials, fluxes, n_samples, method_label,
    )
# Scores below this value trigger the out-of-distribution warning banner.
OOD_THRESHOLD = float(os.getenv("SPARK_OOD_THRESHOLD", "0.5"))


def _ood_banner_update(ood_score, threshold=OOD_THRESHOLD):
    """Produce the ``gr.update`` for the out-of-distribution banner.

    The banner is hidden when no score is available or when the score is at
    or above ``threshold``; otherwise it becomes visible with a warning that
    quotes the score and the threshold.
    """
    looks_in_distribution = ood_score is None or ood_score >= threshold
    if looks_in_distribution:
        return gr.update(visible=False)

    pieces = [
        "<div class='ood-banner'>",
        "<strong>Possible out-of-distribution input.</strong> ",
        f"OOD score = {ood_score:.2f} (below {threshold:.2f}). ",
        "The mechanism prediction below is the closest match within the ",
        "trained mechanism set; treat it as low-confidence and consider ",
        "whether the true mechanism may lie outside the catalog.",
        "</div>",
    ]
    return gr.update(value="".join(pieces), visible=True)
def _ec_error(msg=""):
    """Return the 8 outputs of the EC CSV tab for an error case.

    All result slots are cleared. When ``msg`` is non-empty it is shown,
    HTML-escaped, in the banner slot so the user can see why nothing was
    produced; with an empty ``msg`` the banner stays hidden.

    Args:
        msg: Human-readable description of the problem.

    Returns:
        ``(banner, probs, state, mech_dd, post, table, recon, conc)``.
    """
    if msg:
        import html

        banner = gr.update(
            value=(
                "<div class='ood-banner'><strong>Error:</strong> "
                f"{html.escape(str(msg))}</div>"
            ),
            visible=True,
        )
    else:
        banner = gr.update(visible=False)
    return (
        banner,  # ood_banner slot (doubles as the error display)
        None, None, gr.Dropdown(choices=[], value=None), None, None, None, None,
    )
def _run_ec_analysis(potentials, fluxes, sigmas, n_samples):
    """Run EC inference and assemble the 8 outputs of the CSV tab.

    Predicts mechanism probabilities and posteriors, reconstructs the signal
    for the predicted mechanism, and renders the figures for it. The
    dropdown lists every mechanism as ``"<name> (<prob>)"`` sorted by
    descending probability, with the first entry selected. The returned
    state holds the result and the inputs (arrays as lists) so another
    mechanism can be inspected later.
    """
    engine = get_predictor()
    result = engine.predict_ec(
        potentials, fluxes, sigmas, n_samples=int(n_samples)
    )
    best_mech = result["predicted_mechanism"]
    recon = engine.reconstruct_ec(result, potentials, fluxes, sigmas)

    probs = result["mechanism_probs"]
    fig_probs = plot_mechanism_probs(probs, domain="ec")
    ranked = sorted(probs.items(), key=lambda item: -item[1])
    labels = [f"{name} ({prob:.1%})" for name, prob in ranked]

    state = {
        "result": result,
        "potentials": [arr.tolist() for arr in potentials],
        "fluxes": [arr.tolist() for arr in fluxes],
        "sigmas": sigmas,
    }

    detail_figs = _render_ec_mechanism(best_mech, result, recon, sigmas)
    fig_post, fig_table, fig_recon, fig_conc = detail_figs

    banner = _ood_banner_update(result.get("ood_score"))
    selector = gr.Dropdown(choices=labels, value=labels[0])
    return (
        banner, fig_probs, state, selector,
        fig_post, fig_table, fig_recon, fig_conc,
    )
# User-facing strategy label -> internal mode key. The key order is also the
# order in which the options are offered (see the list built from .keys()).
_HYBRID_MODE_LABELS = {
    "Ensemble (recommended)": "ensemble",
    "Image-direct": "image_only",
    "Digitize-then-infer": "digitize_only",
    "Auto-fallback": "auto_fallback",
}
# Label of the optional Phase 2 strategy; it maps to the "joint" mode.
_JOINT_LABEL = "Joint image+waveform (Phase 2)"


def _hybrid_method_to_internal(label):
    """Translate a strategy label from the UI into its internal mode key.

    An empty or unrecognised label resolves to ``"ensemble"``.
    """
    if label == _JOINT_LABEL:
        return "joint"
    if label:
        return _HYBRID_MODE_LABELS.get(label, "ensemble")
    return "ensemble"
def _hybrid_choices_and_default(domain):
    """Return ``(radio_choices, default_label)`` for the strategy radio.

    ``domain == "ec"`` consults the predictor's EC joint model flag; any
    other value consults the TPD one. When that joint model is loaded, the
    Phase 2 joint option is placed first and becomes the default; otherwise
    the base options are offered with ``Ensemble (recommended)`` as default.

    Why ensemble is the pre-Phase-2 default: the head-to-head eval
    (``outputs/image_vs_digitize/decision.md``) showed that on the rendered
    held-out test split image-only beats digitize-then-infer by ~30 pp
    (CV: 40.9% vs 5.4%, TPD: 36.1% vs 2.8%), while on real-world-style
    stress images the ordering flips (CV: 6.7% vs 26.7%, TPD: 0% vs 60%).
    Averaging both posteriors is therefore robust to either failure mode.
    """
    engine = get_predictor()
    if domain == "ec":
        joint_loaded = engine.has_ec_joint_model
    else:
        joint_loaded = engine.has_tpd_joint_model

    choices = list(_HYBRID_MODE_LABELS.keys())
    if not joint_loaded:
        return choices, "Ensemble (recommended)"
    return [_JOINT_LABEL] + choices, _JOINT_LABEL
def _method_comparison_html(hybrid_out, threshold=OOD_THRESHOLD):
    """Build the ``gr.update`` for the method-comparison panel.

    One card is rendered per method that produced a result (image-direct,
    digitize-then-infer, joint), showing its top mechanism, that mechanism's
    probability and its OOD score (coloured by comparison with
    ``threshold``). A badge states whether all methods picked the same top
    mechanism. The panel is hidden when fewer than two methods ran.
    """
    image_res = hybrid_out.get("image_mode")
    wave_res = hybrid_out.get("waveform_mode")
    joint_res = hybrid_out.get("joint_mode")
    method_used = hybrid_out.get("method_used", "")

    # Card title / result pairs in display order, minus methods that did not run.
    panels = [
        (title, res)
        for title, res in (
            ("Image-direct (Phase 1)", image_res),
            ("Digitize-then-infer", wave_res),
            ("Joint image+waveform (Phase 2)", joint_res),
        )
        if res is not None
    ]
    if len(panels) < 2:
        return gr.update(visible=False)

    def _top_mechanism(res):
        probs = res["mechanism_probs"]
        return max(probs, key=probs.get)

    def _ood_cell(score):
        if score is None:
            return "&mdash;"
        if score >= threshold:
            colour = "#1b5e20"
        else:
            colour = "#7a0000"
        return (f"<span style='color:{colour};font-weight:600'>"
                f"{score:.2f}</span>")

    def _card(title, res):
        top = _top_mechanism(res)
        top_prob = res["mechanism_probs"][top]
        ood_cell = _ood_cell(res.get("ood_score"))
        return f"""
<div style='flex:1;min-width:220px;border:1px solid #e0e0e0;border-radius:8px;
padding:10px 12px;background:white;'>
<div style='font-weight:600;color:#333;margin-bottom:4px;'>{title}</div>
<div>top mech: <strong>{top}</strong>
({top_prob:.1%})</div>
<div>OOD score: {ood_cell}</div>
</div>"""

    cards_html = "".join(_card(title, res) for title, res in panels)

    distinct_tops = {_top_mechanism(res) for _, res in panels}
    if len(distinct_tops) == 1:
        badge = ("<span style='display:inline-block;padding:2px 10px;border-radius:8px;"
                 "background:#dff5e1;color:#1b5e20;font-weight:600;font-size:13px;'>"
                 "methods agree</span>")
    else:
        badge = ("<span style='display:inline-block;padding:2px 10px;border-radius:8px;"
                 "background:#ffe2c2;color:#7a3c00;font-weight:600;font-size:13px;'>"
                 "methods disagree \u2014 review carefully</span>")

    footnote = (
        "Image-direct (Phase 1) is the image-only SPARK; Digitize-then-infer "
        "uses the headline waveform model on a curve extracted from the plot; "
        "Joint (Phase 2) fuses both signals inside a single encoder trained "
        "with real-world distortion augmentation. When available, the joint "
        "model is the recommended primary path."
    )

    html = f"""
<div style='border:1px solid #d0d0d0;border-radius:10px;padding:14px 16px;
background:#fafafa;margin:8px 0;font-size:14px;line-height:1.5;'>
<div style='display:flex;justify-content:space-between;align-items:center;margin-bottom:8px;'>
<strong>Method comparison</strong>
<span style='color:#555;font-size:12px;'>headline: {method_used}</span>
</div>
<div style='display:flex;gap:12px;flex-wrap:wrap;'>
{cards_html}
</div>
<div style='margin-top:8px;'>{badge}</div>
<div style='margin-top:6px;color:#666;font-size:12px;'>{footnote}</div>
</div>
"""
    return gr.update(value=html, visible=True)
def _preprocessing_preview(hybrid_out, original_pils):
    """Build the gallery update showing the images as image-mode consumed them.

    Each original image is passed through ``prepare_for_image_mode`` again
    and captioned from its preprocessing metadata ("cropped", gridline
    counts, or "no changes needed"). The gallery is hidden when there is no
    metadata, when image-mode did not run, or when building the preview
    fails for any reason (the failure is printed, not raised).
    """
    metas = hybrid_out.get("preprocessing_meta") or []
    if not metas or hybrid_out.get("image_mode") is None:
        return gr.update(visible=False)

    try:
        from image_preprocessing import prepare_for_image_mode

        gallery_items = []
        for original, meta in zip(original_pils, metas):
            processed, _ = prepare_for_image_mode(original)

            notes = []
            if meta.get("was_cropped"):
                notes.append("cropped")
            if meta.get("was_cleaned"):
                n_horiz = meta.get('n_horiz_gridlines', 0)
                n_vert = meta.get('n_vert_gridlines', 0)
                notes.append(
                    f"removed {n_horiz} horiz "
                    f"+ {n_vert} vert lines"
                )
            caption = " | ".join(notes) if notes else "no changes needed"
            gallery_items.append((processed, caption))
        return gr.update(value=gallery_items, visible=True)
    except Exception as exc:
        # Best effort only: the preview is optional, so never fail the analysis.
        print(f"[preview] failed: {exc}")
        return gr.update(visible=False)
def _ec_image_error(msg=""):
    """Return the 10 outputs of the EC image tab for an error case.

    All result slots are cleared and the comparison panel and preview
    gallery are hidden. When ``msg`` is non-empty it is shown, HTML-escaped,
    in the banner slot so the user can see what went wrong; with an empty
    ``msg`` the banner stays hidden.

    Args:
        msg: Human-readable description of the problem.
    """
    if msg:
        import html

        banner = gr.update(
            value=(
                "<div class='ood-banner'><strong>Error:</strong> "
                f"{html.escape(str(msg))}</div>"
            ),
            visible=True,
        )
    else:
        banner = gr.update(visible=False)
    return (
        banner,                               # ood_banner (doubles as error display)
        None, None,                           # probs, state
        gr.Dropdown(choices=[], value=None),
        None, None, None, None,               # post, table, recon, conc
        gr.update(visible=False),             # comparison_html
        gr.update(visible=False),             # preview_gallery
    )
def _run_ec_analysis_hybrid(pil_images, sigmas, potentials, fluxes,
                            n_samples, method_label):
    """Hybrid CV analysis: runs image-mode + digitizer-mode in parallel
    via predict_ec_hybrid and renders both the headline result (selected
    by `method_label`) and the per-method comparison panel.

    If the hybrid prediction raises, the failure is printed and the
    error-tuple from ``_ec_image_error`` is returned, with the failure text
    passed along like every other error path in the image tab.

    Returns a 10-tuple matching `ec_img_outputs` in the UI wiring.
    """
    pred = get_predictor()
    mode = _hybrid_method_to_internal(method_label)
    try:
        hybrid_out = pred.predict_ec_hybrid(
            pil_images, sigmas,
            potentials=potentials, fluxes=fluxes,
            n_samples=int(n_samples), mode=mode,
        )
    except Exception as exc:
        print(f"[hybrid CV] failed: {exc}")
        return _ec_image_error(f"Hybrid CV inference failed: {exc}")

    # Everything below is rendered from the headline (selected) result.
    headline = hybrid_out["headline"]
    top_mech = headline["predicted_mechanism"]
    recon = pred.reconstruct_ec(headline, potentials, fluxes, sigmas)

    fig_probs = plot_mechanism_probs(headline["mechanism_probs"], domain="ec")
    sorted_mechs = sorted(headline["mechanism_probs"].items(), key=lambda x: -x[1])
    mech_choices = [f"{m} ({p:.1%})" for m, p in sorted_mechs]

    # Kept so a different mechanism can be re-rendered without re-running
    # inference; arrays are stored as plain lists.
    state = {
        "result": headline,
        "potentials": [p.tolist() for p in potentials],
        "fluxes": [f.tolist() for f in fluxes],
        "sigmas": sigmas,
    }

    fig_post, fig_table, fig_recon, fig_conc = _render_ec_mechanism(
        top_mech, headline, recon, sigmas
    )
    comparison = _method_comparison_html(hybrid_out)
    preview = _preprocessing_preview(hybrid_out, pil_images)
    return (
        _ood_banner_update(headline.get("ood_score")),
        fig_probs, state,
        gr.Dropdown(choices=mech_choices, value=mech_choices[0]),
        fig_post, fig_table, fig_recon, fig_conc,
        comparison, preview,
    )
def _render_ec_mechanism(mech, result, recon, sigmas):
    """Render the four detail figures for one EC mechanism.

    Returns ``(posteriors, parameter_table, reconstruction, concentration)``.
    The first two are ``None`` unless the result has both parameter stats
    and posterior samples for ``mech``; the last two are ``None`` when
    ``recon`` is ``None``, and the concentration figure additionally needs
    ``recon["concentrations"]`` to be non-empty.
    """
    fig_posteriors = fig_table = fig_recon = fig_conc = None

    stats = result["parameter_stats"].get(mech)
    samples = result["posterior_samples"].get(mech)
    if stats and samples is not None:
        fig_posteriors = plot_posteriors(samples, stats["names"], mech, domain="ec")
        fig_table = plot_parameter_table(stats, mech)

    if recon is None:
        return fig_posteriors, fig_table, fig_recon, fig_conc

    # One legend label per scan, derived from its sigma value.
    scan_labels = [f"\u03c3 = {s:.2f}" for s in sigmas] if sigmas else None
    fig_recon = plot_reconstruction(
        recon["observed"], recon["reconstructed"], domain="ec",
        nrmses=recon.get("nrmse"), r2s=recon.get("r2"),
        scan_labels=scan_labels,
    )
    concentrations = recon.get("concentrations")
    if concentrations:
        fig_conc = plot_concentration_profiles(
            concentrations, scan_labels=scan_labels
        )
    return fig_posteriors, fig_table, fig_recon, fig_conc
def _on_ec_mechanism_change(mech_choice, state):
    """Re-render the EC detail figures for the mechanism picked in the dropdown.

    ``mech_choice`` is a dropdown label of the form ``"<name> (<prob>)"``;
    the text before the first ``" ("`` is taken as the mechanism name. With
    no stored state or no selection, four ``None`` values are returned.
    """
    if not (state and mech_choice):
        return None, None, None, None

    mech = mech_choice.partition(" (")[0]
    result = state["result"]
    # State stores plain lists; the predictor is given arrays again.
    potentials = [np.array(seq) for seq in state["potentials"]]
    fluxes = [np.array(seq) for seq in state["fluxes"]]
    sigmas = state["sigmas"]

    recon = get_predictor().reconstruct_ec(
        result, potentials, fluxes, sigmas, mechanism=mech
    )
    return _render_ec_mechanism(mech, result, recon, sigmas)
# =========================================================================
# TPD Analysis
# =========================================================================
def analyze_tpd(files, heating_rates_text, n_samples):
    """Analyze TPD data from uploaded CSV files.

    Every file is parsed first. Heating rates then come from the text box
    (comma-separated, one per file) when it is filled in; otherwise from the
    CSVs, but only if every file supplied one. Returns the 7 outputs of
    ``_run_tpd_analysis``, or the ``_tpd_error`` tuple for invalid input.
    """
    if not files:
        return _tpd_error("Please upload at least one CSV file.")

    temperatures, rates = [], []
    csv_betas = []
    for upload in files:
        parsed = parse_tpd_csv(Path(upload.name).read_text())
        temperatures.append(parsed["T_K"])
        rates.append(parsed["signal"])
        if "beta_Ks" in parsed:
            csv_betas.append(parsed["beta_Ks"])

    rates_text = heating_rates_text.strip() if heating_rates_text else ""

    if rates_text:
        # Values typed by the user take precedence over values from the CSVs.
        try:
            betas = [float(token.strip()) for token in rates_text.split(",")]
        except ValueError:
            return _tpd_error("Invalid heating rates. Enter comma-separated numbers in K/s.")
        if len(files) != len(betas):
            return _tpd_error(
                f"Number of files ({len(files)}) must match heating rates ({len(betas)}).")
        return _run_tpd_analysis(temperatures, rates, betas, n_samples)

    if len(csv_betas) != len(files):
        return _tpd_error(
            "Please enter the heating rate (β in K/s) for each file. "
            "This value is critical for correct inference. "
            "Alternatively, include a 'Time (s)' column in your CSV so β can be computed automatically.")

    return _run_tpd_analysis(temperatures, rates, csv_betas, n_samples)
def analyze_tpd_image(files, heating_rates_text, threshold, n_samples,
                      x_min, x_max, y_min, y_max,
                      method_label="Ensemble (recommended)"):
    """Analyze TPD from uploaded plot images (one per heating rate).

    Hybrid: image-mode + digitizer-then-waveform run in parallel and are
    combined per `method_label`. Axis bounds come from the overrides when
    all four are set and non-zero, and are auto-detected otherwise.

    Args:
        files: Uploaded image file objects; each is read from its ``.name``.
        heating_rates_text: Comma-separated heating rates in K/s, one per
            image (required).
        threshold: Digitizer threshold, passed to ``digitize_plot`` as int.
        n_samples: Number of posterior samples.
        x_min, x_max, y_min, y_max: Axis overrides (see above).
        method_label: User-facing inference-strategy label.

    Returns 9 outputs: existing 7 + method-comparison HTML + preview gallery.
    """
    err = _tpd_image_error
    if not files:
        return err("Please upload at least one image.")
    try:
        from digitizer import digitize_plot, auto_detect_axis_bounds
        from PIL import Image as PILImage
    except ImportError:
        return err("Required libraries not available for image digitization.")

    heating_rates_text = heating_rates_text.strip() if heating_rates_text else ""
    if not heating_rates_text:
        return err(
            "Please enter the heating rate(s) (β in K/s), comma-separated. "
            "This value is critical for correct inference.")
    try:
        betas = [float(s.strip()) for s in heating_rates_text.split(",")]
    except ValueError:
        return err("Invalid heating rates. Enter comma-separated numbers in K/s.")
    if len(files) != len(betas):
        return err(
            f"Number of images ({len(files)}) must match number of "
            f"heating rates ({len(betas)}).")

    # NOTE(review): a bound of exactly 0 counts as "not provided", so
    # overrides containing a zero (e.g. signal min = 0) are ignored -
    # confirm this is intended.
    has_user_bounds = all(
        v is not None and v != 0 for v in [x_min, x_max, y_min, y_max]
    )

    temperatures, rates = [], []
    pil_images = []
    for idx, f in enumerate(files):
        # Decode each file once and close its handle promptly; both
        # conversions are fully loaded copies independent of the file.
        with PILImage.open(f.name) as src:
            img_arr = np.array(src.convert("RGB"))
            pil_images.append(src.convert("L"))

        if has_user_bounds:
            bounds = {
                "x_min": float(x_min), "x_max": float(x_max),
                "y_min": float(y_min), "y_max": float(y_max),
            }
        else:
            bounds = auto_detect_axis_bounds(img_arr)
            if bounds is None:
                return err(
                    f"Could not auto-detect axis bounds for image {idx + 1}. "
                    "Please enter T min, T max, Signal min, Signal max "
                    "under 'Axis overrides'.")

        try:
            x_data, y_data = digitize_plot(
                img_arr, bounds["x_min"], bounds["x_max"],
                bounds["y_min"], bounds["y_max"],
                threshold=int(threshold),
                x_ticks=bounds.get("x_ticks"),
                y_ticks=bounds.get("y_ticks"),
            )
        except Exception as exc:
            return err(f"Digitization failed for image {idx + 1}: {exc}")

        temperatures.append(x_data)
        rates.append(y_data)

    return _run_tpd_analysis_hybrid(
        pil_images, betas, temperatures, rates, n_samples, method_label,
    )
def _tpd_error(msg=""):
    """Return the 7 outputs of the TPD CSV tab for an error case.

    All result slots are cleared. When ``msg`` is non-empty it is shown,
    HTML-escaped, in the banner slot so the user can see why nothing was
    produced; with an empty ``msg`` the banner stays hidden.

    Args:
        msg: Human-readable description of the problem.

    Returns:
        ``(banner, probs, state, mech_dd, post, table, recon)``.
    """
    if msg:
        import html

        banner = gr.update(
            value=(
                "<div class='ood-banner'><strong>Error:</strong> "
                f"{html.escape(str(msg))}</div>"
            ),
            visible=True,
        )
    else:
        banner = gr.update(visible=False)
    return (
        banner,  # ood_banner slot (doubles as the error display)
        None, None, gr.Dropdown(choices=[], value=None), None, None, None,
    )
def _tpd_image_error(msg=""):
    """Return the 9 outputs of the TPD image tab for an error case.

    All result slots are cleared and the comparison panel and preview
    gallery are hidden. When ``msg`` is non-empty it is shown, HTML-escaped,
    in the banner slot so the user can see what went wrong; with an empty
    ``msg`` the banner stays hidden.

    Args:
        msg: Human-readable description of the problem.
    """
    if msg:
        import html

        banner = gr.update(
            value=(
                "<div class='ood-banner'><strong>Error:</strong> "
                f"{html.escape(str(msg))}</div>"
            ),
            visible=True,
        )
    else:
        banner = gr.update(visible=False)
    return (
        banner,                    # ood_banner (doubles as error display)
        None, None, gr.Dropdown(choices=[], value=None), None, None, None,
        gr.update(visible=False),  # comparison_html
        gr.update(visible=False),  # preview_gallery
    )
def _run_tpd_analysis(temperatures, rates, betas, n_samples):
    """Run TPD inference and assemble the 7 outputs of the CSV tab.

    Predicts mechanism probabilities and posteriors, reconstructs the signal
    for the predicted mechanism, and renders the figures for it. The
    dropdown lists every mechanism as ``"<name> (<prob>)"`` sorted by
    descending probability, with the first entry selected. The returned
    state holds the result and the inputs (arrays as lists) so another
    mechanism can be inspected later.
    """
    engine = get_predictor()
    result = engine.predict_tpd(
        temperatures, rates, betas, n_samples=int(n_samples)
    )
    best_mech = result["predicted_mechanism"]
    recon = engine.reconstruct_tpd(result, temperatures, rates, betas)

    probs = result["mechanism_probs"]
    fig_probs = plot_mechanism_probs(probs, domain="tpd")
    ranked = sorted(probs.items(), key=lambda item: -item[1])
    labels = [f"{name} ({prob:.1%})" for name, prob in ranked]

    state = {
        "result": result,
        "temperatures": [arr.tolist() for arr in temperatures],
        "rates": [arr.tolist() for arr in rates],
        "betas": betas,
    }

    detail_figs = _render_tpd_mechanism(best_mech, result, recon, betas)
    fig_post, fig_table, fig_recon = detail_figs

    banner = _ood_banner_update(result.get("ood_score"))
    selector = gr.Dropdown(choices=labels, value=labels[0])
    return (
        banner, fig_probs, state, selector,
        fig_post, fig_table, fig_recon,
    )
def _run_tpd_analysis_hybrid(pil_images, betas, temperatures, rates,
                             n_samples, method_label):
    """Hybrid TPD analysis: image-mode + digitizer-then-waveform combined
    via predict_tpd_hybrid. Returns 9 outputs to match `tpd_img_outputs`.

    If the hybrid prediction raises, the failure is printed and the
    error-tuple from ``_tpd_image_error`` is returned, with the failure text
    passed along like every other error path in the image tab.
    """
    pred = get_predictor()
    mode = _hybrid_method_to_internal(method_label)
    try:
        hybrid_out = pred.predict_tpd_hybrid(
            pil_images, betas,
            temperatures=temperatures, rates=rates,
            n_samples=int(n_samples), mode=mode,
        )
    except Exception as exc:
        print(f"[hybrid TPD] failed: {exc}")
        return _tpd_image_error(f"Hybrid TPD inference failed: {exc}")

    # Everything below is rendered from the headline (selected) result.
    headline = hybrid_out["headline"]
    top_mech = headline["predicted_mechanism"]
    recon = pred.reconstruct_tpd(headline, temperatures, rates, betas)

    fig_probs = plot_mechanism_probs(headline["mechanism_probs"], domain="tpd")
    sorted_mechs = sorted(headline["mechanism_probs"].items(), key=lambda x: -x[1])
    mech_choices = [f"{m} ({p:.1%})" for m, p in sorted_mechs]

    # Kept so a different mechanism can be re-rendered without re-running
    # inference; arrays are stored as plain lists.
    state = {
        "result": headline,
        "temperatures": [t.tolist() for t in temperatures],
        "rates": [r.tolist() for r in rates],
        "betas": betas,
    }

    fig_post, fig_table, fig_recon = _render_tpd_mechanism(top_mech, headline, recon, betas)
    comparison = _method_comparison_html(hybrid_out)
    preview = _preprocessing_preview(hybrid_out, pil_images)
    return (
        _ood_banner_update(headline.get("ood_score")),
        fig_probs, state,
        gr.Dropdown(choices=mech_choices, value=mech_choices[0]),
        fig_post, fig_table, fig_recon,
        comparison, preview,
    )
def _render_tpd_mechanism(mech, result, recon, betas):
    """Render the three detail figures for one TPD mechanism.

    Returns ``(posteriors, parameter_table, reconstruction)``. The first two
    are ``None`` unless the result has both parameter stats and posterior
    samples for ``mech``; the reconstruction is ``None`` when ``recon`` is
    ``None``.
    """
    fig_posteriors = fig_table = fig_recon = None

    stats = result["parameter_stats"].get(mech)
    samples = result["posterior_samples"].get(mech)
    if stats and samples is not None:
        fig_posteriors = plot_posteriors(samples, stats["names"], mech, domain="tpd")
        fig_table = plot_parameter_table(stats, mech)

    if recon is None:
        return fig_posteriors, fig_table, fig_recon

    # One legend label per curve, derived from its heating rate.
    scan_labels = [f"\u03b2 = {b:.2f} K/s" for b in betas] if betas else None
    fig_recon = plot_reconstruction(
        recon["observed"], recon["reconstructed"], domain="tpd",
        nrmses=recon.get("nrmse"), r2s=recon.get("r2"),
        scan_labels=scan_labels,
    )
    return fig_posteriors, fig_table, fig_recon
def _on_tpd_mechanism_change(mech_choice, state):
"""Callback when user selects a different TPD mechanism from the dropdown."""
if not state or not mech_choice:
return None, None, None
mech = mech_choice.split(" (")[0]
result = state["result"]
temperatures = [np.array(t) for t in state["temperatures"]]
rates = [np.array(r) for r in state["rates"]]
betas = state["betas"]
pred = get_predictor()
recon = pred.reconstruct_tpd(result, temperatures, rates, betas, mechanism=mech)
return _render_tpd_mechanism(mech, result, recon, betas)
# =========================================================================
# Shared helpers
# =========================================================================
def download_results(result_text):
    """Write the result summary to a temporary ``.json`` file for download.

    Args:
        result_text: Text to save. Falsy values (``None``, ``""``) mean
            there is nothing to download.

    Returns:
        Path of the newly created file, or ``None`` when ``result_text`` is
        empty. The file is created with ``delete=False`` and therefore
        persists after it is closed.
    """
    if not result_text:
        return None
    # UTF-8 is forced so non-ASCII text is written identically regardless of
    # the host locale; the context manager closes the handle even if the
    # write raises.
    with tempfile.NamedTemporaryFile(
        suffix=".json", delete=False, mode="w", encoding="utf-8"
    ) as tmp:
        tmp.write(result_text)
    return tmp.name
# =========================================================================
# Gradio UI
# =========================================================================
def _build_ec_output_section(prefix):
    """Create the results card shown beneath one EC input tab.

    Components are instantiated in display order: heading, OOD banner,
    classification plot, state holder, mechanism dropdown, then one plot per
    detail tab. ``prefix`` is accepted but not used by this function.

    Returns (ood_banner, probs, state, mech_dd, posteriors, param_table, recon, conc).
    """
    detail_titles = ("Posteriors", "Parameters", "Reconstruction", "Concentration")
    detail_plots = []
    with gr.Group(elem_classes=["results-card"]):
        gr.HTML("<h3 style='margin:0 0 10px 0;color:#0F172A;font-size:1.05em;'>Results</h3>")
        banner = gr.HTML(visible=False)
        classification = gr.Plot(label="Mechanism Classification")
        session = gr.State(value=None)
        selector = gr.Dropdown(
            label="Inspect a mechanism in detail",
            choices=[],
            interactive=True,
        )
        with gr.Tabs(elem_classes=["detail-tabs"]):
            # One unlabeled plot per detail tab, created in tab order.
            for title in detail_titles:
                with gr.Tab(title):
                    detail_plots.append(gr.Plot(label=None, show_label=False))
    return (banner, classification, session, selector, *detail_plots)
def _build_tpd_output_section(prefix):
    """Create the results card shown beneath one TPD input tab.

    Components are instantiated in display order: heading, OOD banner,
    classification plot, state holder, mechanism dropdown, then one plot per
    detail tab. ``prefix`` is accepted but not used by this function.

    Returns (ood_banner, probs, state, mech_dd, posteriors, param_table, recon).
    """
    detail_titles = ("Posteriors", "Parameters", "Reconstruction")
    detail_plots = []
    with gr.Group(elem_classes=["results-card"]):
        gr.HTML("<h3 style='margin:0 0 10px 0;color:#0F172A;font-size:1.05em;'>Results</h3>")
        banner = gr.HTML(visible=False)
        classification = gr.Plot(label="Mechanism Classification")
        session = gr.State(value=None)
        selector = gr.Dropdown(
            label="Inspect a mechanism in detail",
            choices=[],
            interactive=True,
        )
        with gr.Tabs(elem_classes=["detail-tabs"]):
            # One unlabeled plot per detail tab, created in tab order.
            for title in detail_titles:
                with gr.Tab(title):
                    detail_plots.append(gr.Plot(label=None, show_label=False))
    return (banner, classification, session, selector, *detail_plots)
# Global stylesheet handed to gr.Blocks(css=...) in build_app().
# The selectors below are coupled to class names used elsewhere in this file:
# elem_classes such as "input-card", "results-card", "mode-radio",
# "detail-tabs", "analyze-btn" and "about-card", plus the inline HTML classes
# "trace-page", "main-header", "helper" and "trace-footer". Renaming a class
# here requires the matching change at its point of use.
# NOTE(review): ".results-empty" and ".ood-banner" are not referenced in the
# visible part of build_app; presumably they are emitted by helpers defined
# elsewhere in the file — confirm before removing.
CUSTOM_CSS = """
/* ---------- Page container & typography ---------- */
.gradio-container { background: #F8FAFC !important; }
.trace-page { max-width: 1180px; margin: 0 auto; padding: 8px 24px 32px 24px; }
.trace-page * { font-feature-settings: "ss01"; }
/* ---------- Header ---------- */
.main-header { text-align: center; padding: 28px 16px 16px 16px; }
.main-header h1 {
font-size: 2.4em;
margin: 0 0 6px 0;
letter-spacing: -0.6px;
font-weight: 700;
color: #0F172A;
}
.main-header h1 .accent { color: #2563EB; }
.main-header p {
color: #64748B;
font-size: 1.02em;
max-width: 720px;
margin: 0 auto;
line-height: 1.5;
}
.main-header .accent-rule {
height: 3px;
width: 56px;
margin: 14px auto 0 auto;
background: linear-gradient(90deg, #2563EB, #60A5FA);
border-radius: 2px;
}
/* ---------- Cards ---------- */
.input-card, .results-card {
background: #FFFFFF !important;
border: 1px solid #E5E7EB !important;
border-radius: 14px !important;
padding: 22px 26px !important;
margin-top: 16px !important;
box-shadow: 0 1px 2px rgba(15, 23, 42, 0.04) !important;
}
.results-card { padding-top: 18px !important; }
.results-empty {
color: #94A3B8;
text-align: center;
padding: 60px 16px;
font-size: 1.0em;
}
.results-empty .arrow { font-size: 2em; display: block; margin-bottom: 8px; }
/* ---------- Helper text ---------- */
.helper {
color: #64748B;
font-size: 0.93em;
line-height: 1.5;
margin: 0 0 14px 0;
}
.helper code {
background: #F1F5F9;
padding: 1px 6px;
border-radius: 4px;
font-size: 0.92em;
color: #334155;
}
/* ---------- Mode radio (CSV / Image toggle) ---------- */
.mode-radio { margin: 4px 0 0 0 !important; }
.mode-radio label[data-testid="block-label"] { display: none !important; }
.mode-radio .wrap { gap: 6px !important; }
/* ---------- Top tabs polish ---------- */
.tabs > .tab-nav {
border-bottom: 1px solid #E5E7EB !important;
padding-left: 4px !important;
}
.tabs > .tab-nav button {
font-size: 1.02em !important;
font-weight: 500 !important;
color: #64748B !important;
padding: 12px 20px !important;
border-bottom: 2px solid transparent !important;
text-transform: none !important;
}
.tabs > .tab-nav button.selected {
color: #2563EB !important;
border-bottom-color: #2563EB !important;
}
/* ---------- Detail tabs (per-mechanism plot views) ---------- */
.detail-tabs { margin-top: 14px; }
.detail-tabs .tab-nav button { font-size: 0.95em !important; padding: 8px 14px !important; }
/* ---------- Accordion ---------- */
.gr-accordion .label-wrap { padding: 8px 12px !important; }
/* ---------- Buttons ---------- */
.analyze-btn { margin-top: 14px !important; }
.analyze-btn button {
font-size: 1.05em !important;
padding: 12px 18px !important;
font-weight: 600 !important;
}
/* ---------- About tab ---------- */
.about-card {
background: #FFFFFF;
border: 1px solid #E5E7EB;
border-radius: 14px;
padding: 26px 30px;
margin-top: 16px;
box-shadow: 0 1px 2px rgba(15, 23, 42, 0.04);
}
.about-card h2 { margin-top: 0; font-size: 1.4em; color: #0F172A; }
.about-card table { font-size: 0.93em; }
.about-card pre {
background: #F8FAFC;
border: 1px solid #E5E7EB;
border-radius: 8px;
padding: 14px 16px;
font-size: 0.9em;
color: #334155;
}
/* ---------- OOD banner ---------- */
.ood-banner {
background: #FEF3C7;
border: 1px solid #F59E0B;
color: #92400E;
border-radius: 10px;
padding: 12px 16px;
margin: 4px 0 14px 0;
font-size: 0.96em;
line-height: 1.45;
}
.ood-banner strong { color: #78350F; }
/* ---------- Footer ---------- */
.trace-footer {
text-align: center;
color: #94A3B8;
font-size: 0.88em;
padding: 28px 0 8px 0;
}
.trace-footer a { color: #64748B; text-decoration: none; border-bottom: 1px dotted #CBD5E1; }
.trace-footer a:hover { color: #2563EB; border-bottom-color: #2563EB; }
/* ---------- Hide Gradio chrome ---------- */
footer { display: none !important; }
"""
def build_app():
"""Assemble the SPARK Gradio interface and return it without launching.

The layout is a page header followed by three top-level tabs:

* "CV Analysis" and "TPD Analysis" each offer a "CSV Data" and a
  "From Image" input mode, selected with a radio that toggles the
  visibility of the two input columns. Every mode has its own Analyze
  button, results section and mechanism dropdown, wired to the
  corresponding analysis / mechanism-change callbacks.
* "About" shows static Markdown describing the method and the supported
  mechanisms.

Example pickers are only created (and wired) when the matching
``*_EXAMPLES`` mapping is non-empty.

Returns:
    The constructed ``gr.Blocks`` application.
"""
with gr.Blocks(
title="SPARK — Bayesian Inference for Electrochemistry & Catalysis",
theme=gr.themes.Soft(
primary_hue="blue",
secondary_hue="slate",
font=gr.themes.GoogleFont("Inter"),
),
css=CUSTOM_CSS,
) as app:
# NOTE(review): this opening <div> and the closing one at the end of the
# function are separate gr.HTML components, so the tag pair presumably
# does not wrap the components created in between — confirm in the
# rendered DOM whether the `.trace-page` rules actually take effect.
gr.HTML("<div class='trace-page'>")
gr.HTML(
"<div class='main-header'>"
"<h1><span class='accent'>SPARK</span></h1>"
"<p><strong>S</strong>imulation-based <strong>P</strong>osterior "
"<strong>A</strong>mortization for <strong>R</strong>eaction "
"<strong>K</strong>inetics &mdash; Bayesian inference for "
"reaction mechanisms and kinetic parameters from cyclic "
"voltammetry (CV) or temperature-programmed desorption (TPD) "
"data, in one forward pass.</p>"
"<div class='accent-rule'></div>"
"</div>"
)
with gr.Tabs():
# =================================================================
# Tab 1: CV Analysis
# =================================================================
with gr.Tab("CV Analysis"):
cv_mode = gr.Radio(
choices=["CSV Data", "From Image"],
value="CSV Data",
label=None,
show_label=False,
elem_classes=["mode-radio"],
)
# --- CSV upload mode (primary) ---
with gr.Column(visible=True) as cv_csv_group:
with gr.Group(elem_classes=["input-card"]):
gr.HTML(
"<p class='helper'>One CSV per scan rate. Expected columns: "
"<code>Potential (V)</code>, <code>Current (A/mA/&micro;A)</code>, "
"optional <code>Time (s)</code> &mdash; if present, the scan rate is "
"inferred as the median of <code>|dE/dt|</code> over the forward "
"sweep (no need to specify it manually).</p>"
)
# The example picker is only created when CV_EXAMPLES is non-empty.
if CV_EXAMPLES:
with gr.Row():
cv_example_dd = gr.Dropdown(
label="Try an example",
choices=list(CV_EXAMPLES.keys()),
value=None,
interactive=True,
scale=4,
)
cv_example_btn = gr.Button(
"Load", variant="secondary", scale=1,
)
cv_files = gr.File(
label="CSV files (one per scan rate)",
file_count="multiple",
file_types=[".csv", ".txt"],
)
cv_rates = gr.Textbox(
label="Scan rates (V/s), comma-separated",
placeholder="e.g., 0.01, 0.1, 1.0 (leave empty if CSV has time column)",
value="",
)
with gr.Accordion("Advanced parameters", open=False):
with gr.Row():
cv_E0 = gr.Number(
label="Formal potential E\u2080 (V)",
value=None,
info="Auto-estimated from peak positions if empty",
)
cv_T = gr.Number(label="Temperature (K)", value=298.15)
cv_A = gr.Number(label="Electrode area (cm\u00b2)", value=0.0707)
with gr.Row():
cv_C = gr.Number(label="Concentration (mM)", value=1.0)
cv_D = gr.Number(
label="Diffusion coeff D (cm\u00b2/s)",
value=None,
info="Estimated via Randles-\u0160ev\u010d\u00edk if empty",
)
cv_n = gr.Number(label="Number of electrons", value=1, precision=0)
cv_nsamples = gr.Slider(
100, 2000, value=500, step=100,
label="Posterior samples",
)
cv_btn = gr.Button(
"Analyze", variant="primary",
elem_classes=["analyze-btn"],
)
(cv_ood_banner, cv_probs, cv_state,
cv_mech_dd, cv_posteriors, cv_param_table,
cv_recon, cv_conc) = _build_ec_output_section("cv")
# --- Image mode ---
with gr.Column(visible=False) as cv_image_group:
with gr.Group(elem_classes=["input-card"]):
gr.HTML(
"<p class='helper'>One image per scan rate (potential on x-axis, "
"current on y-axis). Axis bounds auto-detected via OCR &mdash; "
"override in Advanced if needed. "
"<em>Note:</em> digitized curves are inherently noisier than the "
"simulated data the model was trained on, so the OOD detector may "
"flag image inputs and reconstruction quality is typically lower "
"than the CSV path; for production use, prefer CSV.</p>"
)
# The example picker and its preview gallery are only created when
# CV_IMG_EXAMPLES is non-empty.
if CV_IMG_EXAMPLES:
with gr.Row():
cv_img_example_dd = gr.Dropdown(
label="Try an example",
choices=list(CV_IMG_EXAMPLES.keys()),
value=None,
interactive=True,
scale=4,
)
cv_img_example_btn = gr.Button(
"Load", variant="secondary", scale=1,
)
cv_img_example_gallery = gr.Gallery(
label="Input plot images (one per scan rate) \u2014 these are what the model digitizes",
columns=3, height=220, object_fit="contain",
interactive=False,
)
cv_img_files = gr.File(
label="Plot images (one per scan rate)",
file_count="multiple",
file_types=["image"],
)
cv_img_scan_rate = gr.Textbox(
label="Scan rates (V/s), comma-separated",
placeholder="e.g., 0.01, 0.1, 1.0",
value="",
)
with gr.Accordion("Advanced parameters", open=False):
with gr.Row():
cv_img_E0 = gr.Number(
label="Formal potential E\u2080 (V)",
value=None,
info="Auto-estimated from peaks if empty",
)
cv_img_threshold = gr.Slider(
0, 255, value=0, step=1,
label="Binarization threshold (0 = auto)",
)
cv_img_current_unit = gr.Dropdown(
label="Current unit on y-axis",
choices=["auto", "\u00b5A", "mA", "A", "nA"],
value="auto",
info="Select the unit shown on the y-axis of your plot",
)
with gr.Row():
cv_img_xmin = gr.Number(label="E min (V)", value=None)
cv_img_xmax = gr.Number(label="E max (V)", value=None)
cv_img_ymin = gr.Number(label="I min", value=None)
cv_img_ymax = gr.Number(label="I max", value=None)
cv_img_nsamples = gr.Slider(
100, 2000, value=500, step=100,
label="Posterior samples",
)
# Strategy choices and the default come from a helper defined elsewhere
# in this file, keyed by domain ("ec" here).
_cv_choices, _cv_default = _hybrid_choices_and_default("ec")
cv_img_method = gr.Radio(
choices=_cv_choices,
value=_cv_default,
label="Inference strategy",
info=("Joint (Phase 2) fuses the rendered image "
"with digitizer-extracted waveforms in a "
"single encoder; Ensemble averages "
"image-direct and digitize-then-infer; "
"Auto-fallback uses image-direct unless "
"its OOD score is low."),
)
cv_img_btn = gr.Button(
"Analyze", variant="primary",
elem_classes=["analyze-btn"],
)
(cv_img_ood_banner, cv_img_probs, cv_img_state,
cv_img_mech_dd, cv_img_posteriors, cv_img_param_table,
cv_img_recon, cv_img_conc) = _build_ec_output_section("cv_img")
# Both components start hidden; they are the last two outputs of the
# image-mode Analyze callback.
with gr.Accordion("Method comparison", open=False):
cv_img_comparison = gr.HTML(visible=False)
cv_img_preview = gr.Gallery(
label="Preprocessed image fed to image-mode CNN",
columns=3, height=180, object_fit="contain",
interactive=False, visible=False,
)
# Show exactly one of the two input columns, matching the radio value.
cv_mode.change(
lambda v: (
gr.update(visible=v == "CSV Data"),
gr.update(visible=v == "From Image"),
),
inputs=[cv_mode],
outputs=[cv_csv_group, cv_image_group],
)
# CSV wiring
ec_outputs = [
cv_ood_banner,
cv_probs, cv_state,
cv_mech_dd, cv_posteriors, cv_param_table,
cv_recon, cv_conc,
]
cv_btn.click(
analyze_cv,
inputs=[
cv_files, cv_rates, cv_E0, cv_T,
cv_A, cv_C, cv_D, cv_n, cv_nsamples,
],
outputs=ec_outputs,
)
cv_mech_dd.change(
_on_ec_mechanism_change,
inputs=[cv_mech_dd, cv_state],
outputs=[cv_posteriors, cv_param_table, cv_recon, cv_conc],
)
# cv_example_btn only exists when CV_EXAMPLES is truthy, so its wiring
# carries the same guard as its creation.
if CV_EXAMPLES:
cv_example_btn.click(
_load_cv_example,
inputs=[cv_example_dd],
outputs=[
cv_files, cv_rates, cv_E0, cv_T,
cv_A, cv_C, cv_D, cv_n,
],
)
# Image wiring (hybrid: 8 base outputs + comparison + preview)
ec_img_outputs = [
cv_img_ood_banner,
cv_img_probs, cv_img_state,
cv_img_mech_dd, cv_img_posteriors, cv_img_param_table,
cv_img_recon, cv_img_conc,
cv_img_comparison, cv_img_preview,
]
cv_img_btn.click(
analyze_cv_image,
inputs=[
cv_img_files, cv_img_scan_rate, cv_img_E0,
cv_img_threshold, cv_img_current_unit,
cv_img_nsamples,
cv_img_xmin, cv_img_xmax,
cv_img_ymin, cv_img_ymax,
cv_img_method,
],
outputs=ec_img_outputs,
)
cv_img_mech_dd.change(
_on_ec_mechanism_change,
inputs=[cv_img_mech_dd, cv_img_state],
outputs=[cv_img_posteriors, cv_img_param_table, cv_img_recon, cv_img_conc],
)
if CV_IMG_EXAMPLES:
def _cv_img_input_gallery(mech_name):
"""Show the actual input PNGs the user will analyze."""
if not mech_name or mech_name not in CV_IMG_EXAMPLES:
return []
return [e[0] for e in CV_IMG_EXAMPLES[mech_name]]
def _on_cv_img_example_select(mech_name):
files, rates = _load_cv_image_example(mech_name)
return files, rates, _cv_img_input_gallery(mech_name)
# "Load" fills the file and scan-rate inputs and refreshes the gallery;
# changing the dropdown alone only refreshes the gallery.
cv_img_example_btn.click(
_on_cv_img_example_select,
inputs=[cv_img_example_dd],
outputs=[cv_img_files, cv_img_scan_rate, cv_img_example_gallery],
)
cv_img_example_dd.change(
_cv_img_input_gallery,
inputs=[cv_img_example_dd],
outputs=[cv_img_example_gallery],
)
# =================================================================
# Tab 2: TPD Analysis
# =================================================================
with gr.Tab("TPD Analysis"):
tpd_mode = gr.Radio(
choices=["CSV Data", "From Image"],
value="CSV Data",
label=None,
show_label=False,
elem_classes=["mode-radio"],
)
# --- CSV mode ---
with gr.Column(visible=True) as tpd_csv_group:
with gr.Group(elem_classes=["input-card"]):
gr.HTML(
"<p class='helper'>One CSV per heating rate. Expected columns: "
"<code>Temperature (K)</code>, <code>Signal</code>, optional "
"<code>Time (s)</code> &mdash; if present, the heating rate &beta; is "
"inferred as <code>dT/dt</code> from the time stamps "
"(no need to specify it manually).</p>"
)
# The example picker is only created when TPD_EXAMPLES is non-empty.
if TPD_EXAMPLES:
with gr.Row():
tpd_example_dd = gr.Dropdown(
label="Try an example",
choices=list(TPD_EXAMPLES.keys()),
value=None,
interactive=True,
scale=4,
)
tpd_example_btn = gr.Button(
"Load", variant="secondary", scale=1,
)
tpd_files = gr.File(
label="CSV files (one per heating rate)",
file_count="multiple",
file_types=[".csv", ".txt"],
)
tpd_betas = gr.Textbox(
label="Heating rates \u03b2 (K/s), comma-separated",
placeholder="e.g., 0.3, 2.6, 22.1 (leave empty if CSV has time column)",
value="",
)
with gr.Accordion("Advanced parameters", open=False):
tpd_nsamples = gr.Slider(
100, 2000, value=500, step=100,
label="Posterior samples",
)
tpd_btn = gr.Button(
"Analyze", variant="primary",
elem_classes=["analyze-btn"],
)
(tpd_ood_banner, tpd_probs, tpd_state,
tpd_mech_dd, tpd_posteriors, tpd_param_table,
tpd_recon) = _build_tpd_output_section("tpd")
# --- Image mode ---
with gr.Column(visible=False) as tpd_image_group:
with gr.Group(elem_classes=["input-card"]):
gr.HTML(
"<p class='helper'>One image per heating rate (temperature on "
"x-axis, signal on y-axis). Axis bounds auto-detected via OCR "
"&mdash; override in Advanced if needed. "
"<em>Note:</em> digitized curves are inherently noisier than the "
"simulated data the model was trained on, so the OOD detector may "
"flag image inputs and reconstruction quality is typically lower "
"than the CSV path; for production use, prefer CSV.</p>"
)
# The example picker and its preview gallery are only created when
# TPD_IMG_EXAMPLES is non-empty.
if TPD_IMG_EXAMPLES:
with gr.Row():
tpd_img_example_dd = gr.Dropdown(
label="Try an example",
choices=list(TPD_IMG_EXAMPLES.keys()),
value=None,
interactive=True,
scale=4,
)
tpd_img_example_btn = gr.Button(
"Load", variant="secondary", scale=1,
)
tpd_img_example_gallery = gr.Gallery(
label="Input plot image \u2014 this is what the model digitizes",
columns=3, height=220, object_fit="contain",
interactive=False,
)
tpd_img_files = gr.File(
label="Plot images (one per heating rate)",
file_count="multiple",
file_types=["image"],
)
tpd_img_betas = gr.Textbox(
label="Heating rates \u03b2 (K/s), comma-separated",
placeholder="e.g., 0.3, 2.6, 22.1",
value="",
)
with gr.Accordion("Advanced parameters", open=False):
with gr.Row():
tpd_img_threshold = gr.Slider(
0, 255, value=0, step=1,
label="Binarization threshold (0 = auto)",
)
with gr.Row():
tpd_img_xmin = gr.Number(label="T min (K)", value=None)
tpd_img_xmax = gr.Number(label="T max (K)", value=None)
tpd_img_ymin = gr.Number(label="Signal min", value=None)
tpd_img_ymax = gr.Number(label="Signal max", value=None)
tpd_img_nsamples = gr.Slider(
100, 2000, value=500, step=100,
label="Posterior samples",
)
# Strategy choices and the default come from a helper defined elsewhere
# in this file, keyed by domain ("tpd" here).
_tpd_choices, _tpd_default = _hybrid_choices_and_default("tpd")
tpd_img_method = gr.Radio(
choices=_tpd_choices,
value=_tpd_default,
label="Inference strategy",
info=("Joint (Phase 2) fuses the image with "
"digitizer-extracted waveforms in a "
"single encoder; Ensemble averages "
"image-direct and digitize-then-infer; "
"Auto-fallback uses image-direct unless "
"its OOD score is low."),
)
tpd_img_btn = gr.Button(
"Analyze", variant="primary",
elem_classes=["analyze-btn"],
)
(tpd_img_ood_banner, tpd_img_probs, tpd_img_state,
tpd_img_mech_dd, tpd_img_posteriors, tpd_img_param_table,
tpd_img_recon) = _build_tpd_output_section("tpd_img")
# Both components start hidden; they are the last two outputs of the
# image-mode Analyze callback.
with gr.Accordion("Method comparison", open=False):
tpd_img_comparison = gr.HTML(visible=False)
tpd_img_preview = gr.Gallery(
label="Preprocessed image fed to image-mode CNN",
columns=3, height=180, object_fit="contain",
interactive=False, visible=False,
)
# Show exactly one of the two input columns, matching the radio value.
tpd_mode.change(
lambda v: (
gr.update(visible=v == "CSV Data"),
gr.update(visible=v == "From Image"),
),
inputs=[tpd_mode],
outputs=[tpd_csv_group, tpd_image_group],
)
# CSV wiring
tpd_outputs = [
tpd_ood_banner,
tpd_probs, tpd_state,
tpd_mech_dd, tpd_posteriors, tpd_param_table, tpd_recon,
]
tpd_btn.click(
analyze_tpd,
inputs=[tpd_files, tpd_betas, tpd_nsamples],
outputs=tpd_outputs,
)
tpd_mech_dd.change(
_on_tpd_mechanism_change,
inputs=[tpd_mech_dd, tpd_state],
outputs=[tpd_posteriors, tpd_param_table, tpd_recon],
)
# tpd_example_btn only exists when TPD_EXAMPLES is truthy, so its wiring
# carries the same guard as its creation.
if TPD_EXAMPLES:
tpd_example_btn.click(
_load_tpd_example,
inputs=[tpd_example_dd],
outputs=[tpd_files, tpd_betas],
)
# Image wiring (hybrid: 7 base outputs + comparison + preview)
tpd_img_outputs = [
tpd_img_ood_banner,
tpd_img_probs, tpd_img_state,
tpd_img_mech_dd, tpd_img_posteriors, tpd_img_param_table, tpd_img_recon,
tpd_img_comparison, tpd_img_preview,
]
tpd_img_btn.click(
analyze_tpd_image,
inputs=[
tpd_img_files, tpd_img_betas,
tpd_img_threshold, tpd_img_nsamples,
tpd_img_xmin, tpd_img_xmax,
tpd_img_ymin, tpd_img_ymax,
tpd_img_method,
],
outputs=tpd_img_outputs,
)
tpd_img_mech_dd.change(
_on_tpd_mechanism_change,
inputs=[tpd_img_mech_dd, tpd_img_state],
outputs=[tpd_img_posteriors, tpd_img_param_table, tpd_img_recon],
)
if TPD_IMG_EXAMPLES:
def _tpd_img_input_gallery(mech_name):
if not mech_name or mech_name not in TPD_IMG_EXAMPLES:
return []
return [e[0] for e in TPD_IMG_EXAMPLES[mech_name]]
def _on_tpd_img_example_select(mech_name):
files, betas = _load_tpd_image_example(mech_name)
return files, betas, _tpd_img_input_gallery(mech_name)
# "Load" fills the file and heating-rate inputs and refreshes the gallery;
# changing the dropdown alone only refreshes the gallery.
tpd_img_example_btn.click(
_on_tpd_img_example_select,
inputs=[tpd_img_example_dd],
outputs=[tpd_img_files, tpd_img_betas, tpd_img_example_gallery],
)
tpd_img_example_dd.change(
_tpd_img_input_gallery,
inputs=[tpd_img_example_dd],
outputs=[tpd_img_example_gallery],
)
# =================================================================
# Tab 3: About
# =================================================================
with gr.Tab("About"):
with gr.Row(elem_classes=["about-card"]):
with gr.Column(scale=2):
gr.Markdown("""
## How it works
**SPARK** (Simulation-based Posterior Amortization for Reaction Kinetics) uses
**conditional normalizing flows** with a **Set Transformer** encoder to
perform amortized Bayesian inference. Given one or more experimental
curves, it simultaneously classifies the reaction mechanism and produces
full posterior distributions over kinetic parameters &mdash; in a single
forward pass.
The deployed checkpoints are the **noise-augmented headline models**
(CV: `v14_9mech`, TPD: `tpd_11mech_v2`) that retain 92.4 % (CV) and
95.6 % (TPD) classification accuracy under realistic measurement noise.
Inference runs in **~50 ms on CPU**; mean 90 %-credible-interval coverage
is 92.2 % (CV) and 92.9 % (TPD).
Training data is generated from physics-based simulators
(Crank&ndash;Nicolson for CV, ODE integrators for TPD). Posteriors are
calibrated via a coverage-aware loss with per-parameter inverse-spread
weighting.
### Citation
```
Yan, B. (2026). SPARK: Amortized Bayesian Inference for
Mechanism Identification and Parameter Estimation in
Electrochemistry and Catalysis via Conditional
Normalizing Flows. [Preprint]
```
""")
with gr.Column(scale=1):
gr.Markdown("""
### Supported mechanisms
**Electrochemistry (CV, 9):**
Nernst, Butler&ndash;Volmer, Marcus&ndash;Hush&ndash;Chidsey,
Adsorption, EC, Langmuir&ndash;Hinshelwood, EE, EC&prime;, CE.
**Catalysis (TPD, 11):**
First-order, Second-order, Zeroth-order, FirstOrderCovDep,
DiffLimited, Precursor, Dissociative, ActivatedAdsorption,
LH Surface, Mars&ndash;van Krevelen, TwoSite.
""")
# NOTE(review): the footer links to the ECFlow repository while the UI is
# branded SPARK — confirm this is still the intended URL.
gr.HTML(
"<div class='trace-footer'>"
"<a href='https://github.com/bingyan/ECFlow' target='_blank'>"
"github.com/bingyan/ECFlow</a></div>"
)
gr.HTML("</div>")
return app
# =========================================================================
# Entry point
# =========================================================================
if __name__ == "__main__":
    # Build the UI first, then serve it on port 7860 with sharing disabled
    # and error display enabled.
    app = build_app()
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    app.launch(**launch_options)