"""
SPARK (Simulation-based Posterior Amortization for Reaction Kinetics)

Gradio web interface for mechanism classification and parameter inference
from cyclic voltammetry (CV) and temperature-programmed desorption (TPD) data.
"""

import os
import sys
import json
import tempfile
from pathlib import Path

import numpy as np
import gradio as gr

from inference import SPARKPredictor
from preprocessing import (
    nondimensionalize_cv,
    estimate_E0,
    parse_cv_csv,
    parse_tpd_csv,
)
from plotting import (
    plot_mechanism_probs,
    plot_posteriors,
    plot_parameter_table,
    plot_reconstruction,
    plot_concentration_profiles,
)

# ---------------------------------------------------------------------------
# Model paths (relative to repo root)
# ---------------------------------------------------------------------------
REPO_ROOT = Path(__file__).resolve().parent
DEMO_DIR = REPO_ROOT / "demo_data"
DEMO_RENDERS = REPO_ROOT / "demo_renders"

# Bundled default checkpoints; replaced below when an env override is set.
EC_CHECKPOINT = REPO_ROOT / "checkpoints" / "ec_best.pt"
TPD_CHECKPOINT = REPO_ROOT / "checkpoints" / "tpd_best.pt"


def _checkpoint_from_env(env_names, default):
    """Resolve a checkpoint path from environment variables.

    Args:
        env_names: Environment variable names checked in order; the first
            one holding a non-blank value wins.
        default: Path (or str) used when none of the variables holds a
            non-blank value.

    Returns:
        pathlib.Path of the selected checkpoint.

    Blank values are deliberately treated as unset: ``Path("")`` normalizes
    to ``Path(".")``, which always exists, so an empty secret would otherwise
    pass the later ``.exists()`` check and be handed to the predictor as if
    it were a checkpoint file.
    """
    for name in env_names:
        value = os.environ.get(name, "").strip()
        if value:
            return Path(value)
    return Path(default)


# Allow override via environment variables. SPARK_*_CHECKPOINT is the
# canonical name; ECFLOW_*_CHECKPOINT remains accepted for backward
# compatibility with existing HF Space secrets.
EC_CHECKPOINT = _checkpoint_from_env(
    ("SPARK_EC_CHECKPOINT", "ECFLOW_EC_CHECKPOINT"), EC_CHECKPOINT
)
TPD_CHECKPOINT = _checkpoint_from_env(
    ("SPARK_TPD_CHECKPOINT", "ECFLOW_TPD_CHECKPOINT"), TPD_CHECKPOINT
)

# Optional image-input checkpoints. When set and the file exists, the
# "From Image" tabs use the image-input SPARK model directly instead of
# digitizing the plot. If not set, the digitizer fallback is used.
EC_IMAGE_CHECKPOINT_PATH = os.environ.get("SPARK_EC_IMAGE_CHECKPOINT", "")
TPD_IMAGE_CHECKPOINT_PATH = os.environ.get("SPARK_TPD_IMAGE_CHECKPOINT", "")


def _optional_checkpoint(raw_path):
    """Return ``Path(raw_path)``, or None when raw_path is empty/whitespace.

    Whitespace-only values are treated as unset so that a blank secret never
    turns into a bogus relative path.
    """
    raw_path = (raw_path or "").strip()
    return Path(raw_path) if raw_path else None


EC_IMAGE_CHECKPOINT = _optional_checkpoint(EC_IMAGE_CHECKPOINT_PATH)
TPD_IMAGE_CHECKPOINT = _optional_checkpoint(TPD_IMAGE_CHECKPOINT_PATH)

# Optional Phase 2 joint (image + waveform) checkpoints. When present, the
# "From Image" tabs expose a "Joint image+waveform (Phase 2)" inference
# strategy that fuses the rendered image with the digitizer-extracted
# waveform inside a single encoder. Falls back to ensemble if missing.
EC_JOINT_CHECKPOINT_PATH = os.environ.get("SPARK_EC_JOINT_CHECKPOINT", "")
TPD_JOINT_CHECKPOINT_PATH = os.environ.get("SPARK_TPD_JOINT_CHECKPOINT", "")
EC_JOINT_CHECKPOINT = _optional_checkpoint(EC_JOINT_CHECKPOINT_PATH)
TPD_JOINT_CHECKPOINT = _optional_checkpoint(TPD_JOINT_CHECKPOINT_PATH)


# ---------------------------------------------------------------------------
# Demo examples
# ---------------------------------------------------------------------------
def _discover_examples():
    """Scan demo_data/ for ``*_metadata.json`` files and build example catalogs.

    Returns:
        (cv_examples, tpd_examples): dicts keyed by dropdown label
        (``"CV — <mechanism>"`` / ``"TPD — <mechanism>"``, with ``" *OOD*"``
        appended for out-of-distribution demos). Metadata files whose name
        starts with neither ``ec_`` nor ``tpd_`` are ignored. Both dicts are
        empty when demo_data/ does not exist.

    A metadata file that cannot be read, is not valid JSON, or lacks the
    required ``mechanism`` / ``csv_files`` keys is skipped with a printed
    warning, so a single bad demo cannot stop the app from starting.
    """
    cv_examples = {}
    tpd_examples = {}
    if not DEMO_DIR.is_dir():
        return cv_examples, tpd_examples

    for meta_path in sorted(DEMO_DIR.glob("*_metadata.json")):
        try:
            with open(meta_path, encoding="utf-8") as fh:
                meta = json.load(fh)
            mech = meta["mechanism"]
            csv_files = [str(DEMO_DIR / fn) for fn in meta["csv_files"]]
        except (OSError, ValueError, KeyError, TypeError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
            # TypeError covers a top-level JSON value that is not an object.
            print(f"[examples] skipping {meta_path.name}: {exc}")
            continue

        is_ood = bool(meta.get("is_ood", False))
        ood_tag = " *OOD*" if is_ood else ""

        if meta_path.name.startswith("ec_"):
            rates = meta.get("scan_rates_Vs", [])
            phys = meta.get("physical_params", {})
            cv_examples[f"CV \u2014 {mech}{ood_tag}"] = {
                "files": csv_files,
                "scan_rates": ", ".join(f"{r:.4g}" for r in rates),
                "is_ood": is_ood,
                "E0_V": phys.get("E0_V"),
                "T_K": phys.get("T_K", 298.15),
                "A_cm2": phys.get("A_cm2", 0.0707),
                "C_mM": phys.get("C_mM", 1.0),
                "D_cm2s": phys.get("D_cm2s", 1e-5),
                "n_electrons": phys.get("n_electrons", 1),
            }
        elif meta_path.name.startswith("tpd_"):
            betas = meta.get("betas_Ks", [])
            tpd_examples[f"TPD \u2014 {mech}{ood_tag}"] = {
                "files": csv_files,
                "heating_rates": ", ".join(f"{b:.4g}" for b in betas),
                "is_ood": is_ood,
            }
    return cv_examples, tpd_examples


CV_EXAMPLES, TPD_EXAMPLES = _discover_examples()


def _collect_output_renders(prefix, mechanisms, suffixes):
    """Map each mechanism to ``{suffix: png_path}`` for pre-rendered outputs.

    Looks for ``DEMO_RENDERS / f"{prefix}_{mech}_{suffix}.png"``; mechanisms
    with no existing render files are omitted from the result.
    """
    renders_by_mech = {}
    for mech in mechanisms:
        found = {}
        for suffix in suffixes:
            render_path = DEMO_RENDERS / f"{prefix}_{mech}_{suffix}.png"
            if render_path.exists():
                found[suffix] = str(render_path)
        if found:
            renders_by_mech[mech] = found
    return renders_by_mech


def _discover_image_examples():
    """Build image example catalogs from the demo_renders/ directory.

    Returns a 4-tuple:
        cv_img_examples: ``{mechanism: [(png_path, scan_rate_Vs), ...]}``
            parsed from ``ec_<mech>_<rate>mVs_physical.png``, in ascending
            scan-rate order.
        tpd_img_examples: ``{mechanism: [(png_path, number_str), ...]}``
            parsed from ``tpd_<mech>_<number>_physical.png``, in ascending
            numeric order of ``number``.
        cv_output_renders / tpd_output_renders:
            ``{mechanism: {suffix: png_path}}`` for whichever pre-rendered
            output plots exist.
    All four are empty dicts when demo_renders/ does not exist.
    """
    import re

    cv_pattern = re.compile(r"ec_(\w+?)_(\d+)mVs_physical\.png")
    tpd_pattern = re.compile(r"tpd_(\w+?)_(\d+)_physical\.png")

    cv_img_examples = {}
    tpd_img_examples = {}
    if not DEMO_RENDERS.is_dir():
        return cv_img_examples, tpd_img_examples, {}, {}

    for p in DEMO_RENDERS.glob("ec_*_physical.png"):
        m = cv_pattern.fullmatch(p.name)
        if not m:
            continue
        mech, rate_mVs = m.group(1), int(m.group(2))
        cv_img_examples.setdefault(mech, []).append((str(p), rate_mVs / 1000.0))

    for p in DEMO_RENDERS.glob("tpd_*_physical.png"):
        m = tpd_pattern.fullmatch(p.name)
        if not m:
            continue
        tpd_img_examples.setdefault(m.group(1), []).append((str(p), m.group(2)))

    # Sort numerically: a lexicographic filename sort would order
    # "1000mVs" < "100mVs" < "50mVs", scrambling the files (and the matching
    # rate string) shown when an example is loaded. Path is the tie-breaker.
    for entries in cv_img_examples.values():
        entries.sort(key=lambda e: (e[1], e[0]))
    for entries in tpd_img_examples.values():
        entries.sort(key=lambda e: (int(e[1]), e[0]))

    cv_output_renders = _collect_output_renders(
        "ec", cv_img_examples,
        ("classification", "posteriors", "reconstruction", "concentration"),
    )
    tpd_output_renders = _collect_output_renders(
        "tpd", tpd_img_examples,
        ("classification", "posteriors", "reconstruction"),
    )
    return cv_img_examples, tpd_img_examples, cv_output_renders, tpd_output_renders


(CV_IMG_EXAMPLES, TPD_IMG_EXAMPLES,
 CV_OUTPUT_RENDERS, TPD_OUTPUT_RENDERS) = _discover_image_examples()


def _load_cv_image_example(mech_name):
    """Return (files, scan_rates_str) for a CV image example.

    For an empty or unknown name, returns two no-op ``gr.update()`` values so
    the bound components are left unchanged.
    """
    if not mech_name or mech_name not in CV_IMG_EXAMPLES:
        return [gr.update()] * 2
    entries = CV_IMG_EXAMPLES[mech_name]
    files = [path for path, _ in entries]
    rates_str = ", ".join(f"{rate}" for _, rate in entries)
    return files, rates_str


def _load_tpd_image_example(mech_name):
    """Return (image_path, heating_rate_str) for a TPD image example.

    Only the first catalog entry for the mechanism is used. For an empty or
    unknown name, returns two no-op ``gr.update()`` values.

    NOTE(review): the second value is the numeric token parsed from the
    ``tpd_<mech>_<number>_physical.png`` filename and is placed in the
    heating-rate field as-is; confirm the render filenames encode β (K/s)
    there rather than a plain file index.
    """
    if not mech_name or mech_name not in TPD_IMG_EXAMPLES:
        return [gr.update()] * 2
    image_path, number_str = TPD_IMG_EXAMPLES[mech_name][0]
    return image_path, number_str


def _load_cv_example(example_name):
    """Return (files, scan_rates, E0, T, A, C, D, n) for the chosen CV example.

    For an empty or unknown name, returns eight no-op ``gr.update()`` values.
    """
    if not example_name or example_name not in CV_EXAMPLES:
        return [gr.update()] * 8
    ex = CV_EXAMPLES[example_name]
    return (
        ex["files"], ex["scan_rates"],
        ex["E0_V"], ex["T_K"], ex["A_cm2"], ex["C_mM"],
        ex["D_cm2s"], ex["n_electrons"],
    )


def _load_tpd_example(example_name):
    """Return (files, heating_rates) for the chosen TPD example.

    For an empty or unknown name, returns two no-op ``gr.update()`` values.
    """
    if not example_name or example_name not in TPD_EXAMPLES:
        return [gr.update()] * 2
    ex = TPD_EXAMPLES[example_name]
    return ex["files"], ex["heating_rates"]


predictor = None


def _existing_path_str(path):
    """Return ``str(path)`` if path is not None and exists on disk, else None."""
    if path is not None and path.exists():
        return str(path)
    return None


def get_predictor():
    """Return the shared SPARKPredictor, constructing it on first call.

    Each checkpoint is passed only if its file exists; unset or missing
    checkpoints are passed as None. The instance runs on CPU and is cached in
    the module-level ``predictor``. No lock is taken, so concurrent first
    calls may each build an instance before one is cached.
    """
    global predictor
    if predictor is None:
        predictor = SPARKPredictor(
            ec_checkpoint=_existing_path_str(EC_CHECKPOINT),
            tpd_checkpoint=_existing_path_str(TPD_CHECKPOINT),
            ec_image_checkpoint=_existing_path_str(EC_IMAGE_CHECKPOINT),
            tpd_image_checkpoint=_existing_path_str(TPD_IMAGE_CHECKPOINT),
            ec_joint_checkpoint=_existing_path_str(EC_JOINT_CHECKPOINT),
            tpd_joint_checkpoint=_existing_path_str(TPD_JOINT_CHECKPOINT),
            device="cpu",
        )
    return predictor


# 
========================================================================= # CV Analysis # ========================================================================= def analyze_cv(files, scan_rates_text, E0_V, T_K, A_cm2, C_mM, D_cm2s, n_electrons, n_samples): """Analyze CV data from potentiostat CSV files. Accepts CSV files with columns for potential (V) and current (A/mA/µA). If the CSV includes a Time (s) column, the scan rate is auto-detected. Otherwise, scan rates must be provided. """ if not files: return _ec_error("Please upload at least one CSV file.") scan_rates_text = scan_rates_text.strip() if scan_rates_text else "" user_rates = None if scan_rates_text: try: user_rates = [float(s.strip()) for s in scan_rates_text.split(",")] except ValueError: return _ec_error("Invalid scan rates. Enter comma-separated numbers in V/s.") if len(files) != len(user_rates): return _ec_error( f"Number of files ({len(files)}) must match number of " f"scan rates ({len(user_rates)}).") C_molcm3 = float(C_mM) * 1e-6 if C_mM else 1e-6 n = int(n_electrons) if n_electrons else 1 T = float(T_K) if T_K else 298.15 A = float(A_cm2) if A_cm2 else 0.0707 parsed_data = [] scan_rates = [] for idx, f in enumerate(files): content = Path(f.name).read_text() parsed = parse_cv_csv(content) parsed_data.append(parsed) if user_rates is not None: v = user_rates[idx] elif "scan_rate_Vs" in parsed: v = parsed["scan_rate_Vs"] else: return _ec_error( f"Cannot determine scan rate for file '{Path(f.name).name}'. 
" "Either provide scan rates (V/s) or upload CSV files that " "include a Time (s) column.") scan_rates.append(v) if E0_V: e0 = float(E0_V) e0_source = "user" else: e0_estimates = [estimate_E0(p["E_V"], p["i_A"]) for p in parsed_data] e0 = float(np.median(e0_estimates)) e0_source = "auto" D = float(D_cm2s) if D_cm2s else 1e-5 potentials, fluxes, sigmas_list = [], [], [] for idx, (parsed, v) in enumerate(zip(parsed_data, scan_rates)): E, i_A = parsed["E_V"], parsed["i_A"] theta, flux, sigma = nondimensionalize_cv( E, i_A, v, e0, T, A, C_molcm3, D, n ) potentials.append(theta) fluxes.append(flux) sigmas_list.append(sigma) return _run_ec_analysis(potentials, fluxes, sigmas_list, n_samples) _CURRENT_UNIT_SCALES = { "µA": 1e-6, "mA": 1e-3, "A": 1.0, "nA": 1e-9, } def _guess_current_unit(i_values): """Guess the current unit from the magnitude of digitized values.""" i_max = np.max(np.abs(i_values)) if i_max > 1e3: return "nA" if i_max > 100: return "µA" if i_max > 0.1: return "µA" if i_max > 1e-4: return "mA" return "A" def analyze_cv_image(files, scan_rate_text, E0_V, threshold, current_unit, n_samples, x_min, x_max, y_min, y_max, method_label="Ensemble (recommended)"): """Analyze CV from uploaded plot images (one per scan rate). When the image-input SPARK model is available we run a hybrid pipeline: image-mode prediction in parallel with the digitizer + waveform-mode path, and combine via `method_label` (Ensemble / Image-direct / Digitize-then-infer / Auto-fallback). When only the waveform model is available, falls back to digitize-then-infer transparently. Returns 10 outputs: the existing 8 (banner, probs, state, mech_dd, post, table, recon, conc) plus method-comparison HTML + preprocessed-image gallery preview. 
""" err = _ec_image_error if not files: return err("Please upload at least one image.") try: from digitizer import digitize_plot, auto_detect_axis_bounds from PIL import Image as PILImage except ImportError: return err("Required libraries not available for image digitization.") scan_rate_text = scan_rate_text.strip() if scan_rate_text else "" if not scan_rate_text: return err("Please enter the scan rate(s) (V/s), comma-separated.") try: scan_rates = [float(s.strip()) for s in scan_rate_text.split(",")] except ValueError: return err("Invalid scan rates. Enter comma-separated numbers in V/s.") if len(files) != len(scan_rates): return err( f"Number of images ({len(files)}) must match number of " f"scan rates ({len(scan_rates)}).") has_user_bounds = all( v is not None and v != 0 for v in [x_min, x_max, y_min, y_max] ) D = 1e-5 T = 298.15 A = 0.0707 C_molcm3 = 1e-6 n = 1 # First pass: digitize all images and estimate E0 per image image_data = [] e0_estimates = [] for idx, f in enumerate(files): img_arr = np.array(PILImage.open(f.name).convert("RGB")) v_Vs = scan_rates[idx] if has_user_bounds: bounds = { "x_min": float(x_min), "x_max": float(x_max), "y_min": float(y_min), "y_max": float(y_max), } else: bounds = auto_detect_axis_bounds(img_arr) if bounds is None: return err( f"Could not auto-detect axis bounds for image {idx + 1}. 
" "Please enter E min, E max, I min, I max under " "'Axis overrides'.") try: E_V, I_raw = digitize_plot( img_arr, bounds["x_min"], bounds["x_max"], bounds["y_min"], bounds["y_max"], threshold=int(threshold), x_ticks=bounds.get("x_ticks"), y_ticks=bounds.get("y_ticks"), ) except Exception as exc: return err(f"Digitization failed for image {idx + 1}: {exc}") if current_unit and current_unit != "auto": i_unit = current_unit elif "y_unit" in bounds: i_unit = bounds["y_unit"] else: i_unit = _guess_current_unit(I_raw) i_scale = _CURRENT_UNIT_SCALES.get(i_unit, 1e-6) i_A = I_raw * i_scale e0_estimates.append(float(estimate_E0(E_V, i_A))) image_data.append((E_V, i_A, v_Vs, i_unit, bounds)) # Determine E0: user-provided or median of per-image estimates if E0_V is not None and E0_V != 0: e0 = float(E0_V) e0_source = "user" else: e0 = float(np.median(e0_estimates)) e0_source = "auto" # Second pass: nondimensionalize with shared E0 potentials, fluxes, sigmas_list = [], [], [] preproc_parts = [] for E_V, i_A, v_Vs, i_unit, bounds in image_data: theta, flux, sigma = nondimensionalize_cv( E_V, i_A, v_Vs, e0, T, A, C_molcm3, D, n ) potentials.append(theta) fluxes.append(flux) sigmas_list.append(sigma) preproc_parts.append( f"{v_Vs*1000:.1f} mV/s (σ={sigma:.2f}, " f"E=[{bounds['x_min']:.3f}, {bounds['x_max']:.3f}] V, " f"I=[{bounds['y_min']:.2f}, {bounds['y_max']:.2f}] {i_unit})" ) pil_images = [PILImage.open(f.name).convert("L") for f in files] return _run_ec_analysis_hybrid( pil_images, sigmas_list, potentials, fluxes, n_samples, method_label, ) OOD_THRESHOLD = float(os.environ.get("SPARK_OOD_THRESHOLD", "0.5")) def _ood_banner_update(ood_score, threshold=OOD_THRESHOLD): """Build the gr.HTML update for the OOD banner. Shown only when the binary OOD head's P(in-distribution) score drops below `threshold`. Hidden otherwise. """ if ood_score is None or ood_score >= threshold: return gr.update(visible=False) html = ( "
" "Possible out-of-distribution input. " f"OOD score = {ood_score:.2f} (below {threshold:.2f}). " "The mechanism prediction below is the closest match within the " "trained mechanism set; treat it as low-confidence and consider " "whether the true mechanism may lie outside the catalog." "
" ) return gr.update(value=html, visible=True) def _ec_error(msg=""): """Return empty outputs for EC error cases.""" return ( gr.update(visible=False), # ood_banner None, None, gr.Dropdown(choices=[], value=None), None, None, None, None, ) def _run_ec_analysis(potentials, fluxes, sigmas, n_samples): """Core EC analysis: predict + reconstruct for top mechanism.""" pred = get_predictor() result = pred.predict_ec(potentials, fluxes, sigmas, n_samples=int(n_samples)) top_mech = result["predicted_mechanism"] recon = pred.reconstruct_ec(result, potentials, fluxes, sigmas) fig_probs = plot_mechanism_probs(result["mechanism_probs"], domain="ec") sorted_mechs = sorted(result["mechanism_probs"].items(), key=lambda x: -x[1]) mech_choices = [f"{m} ({p:.1%})" for m, p in sorted_mechs] state = { "result": result, "potentials": [p.tolist() for p in potentials], "fluxes": [f.tolist() for f in fluxes], "sigmas": sigmas, } fig_post, fig_table, fig_recon, fig_conc = _render_ec_mechanism( top_mech, result, recon, sigmas ) return ( _ood_banner_update(result.get("ood_score")), fig_probs, state, gr.Dropdown(choices=mech_choices, value=mech_choices[0]), fig_post, fig_table, fig_recon, fig_conc, ) _HYBRID_MODE_LABELS = { "Ensemble (recommended)": "ensemble", "Image-direct": "image_only", "Digitize-then-infer": "digitize_only", "Auto-fallback": "auto_fallback", } _JOINT_LABEL = "Joint image+waveform (Phase 2)" def _hybrid_method_to_internal(label): """Map the user-facing radio label to the internal mode key.""" if not label: return "ensemble" if label == _JOINT_LABEL: return "joint" return _HYBRID_MODE_LABELS.get(label, "ensemble") def _hybrid_choices_and_default(domain): """Return (radio_choices, default_label) for the inference-strategy radio, inserting the Phase 2 joint option (and making it default) when the corresponding joint checkpoint is loaded. 
Rationale for ``Ensemble (recommended)`` as the pre-Phase-2 default: head-to-head eval (``outputs/image_vs_digitize/decision.md``) showed that on the rendered held-out test split image-only beats digitize- then-infer by ~30 pp (CV: 40.9% vs 5.4%, TPD: 36.1% vs 2.8%), but on real-world-style stress images the ordering FLIPS (CV: 6.7% vs 26.7%, TPD: 0% vs 60%). Ensemble averages both posteriors, so it's robust to either failure mode. The default automatically switches to ``Joint image+waveform (Phase 2)`` once the joint checkpoint is available. """ pred = get_predictor() has_joint = (pred.has_ec_joint_model if domain == "ec" else pred.has_tpd_joint_model) base = list(_HYBRID_MODE_LABELS.keys()) if has_joint: return [_JOINT_LABEL] + base, _JOINT_LABEL return base, "Ensemble (recommended)" def _method_comparison_html(hybrid_out, threshold=OOD_THRESHOLD): """Render the Method-comparison panel. Shows per-method top mechanism, top probability, OOD score, and an agreement badge. Hidden when only one method ran (no comparison to make). Includes a third 'Joint (Phase 2)' card when the joint encoder ran. """ img = hybrid_out.get("image_mode") wave = hybrid_out.get("waveform_mode") joint = hybrid_out.get("joint_mode") method_used = hybrid_out.get("method_used", "") available = [m for m in (img, wave, joint) if m is not None] if len(available) < 2: return gr.update(visible=False) def _ood_html(score): if score is None: return "—" col = "#1b5e20" if score >= threshold else "#7a0000" return (f"" f"{score:.2f}") def _card(title, result): top = max(result["mechanism_probs"], key=result["mechanism_probs"].get) return f"""
{title}
top mech: {top} ({result['mechanism_probs'][top]:.1%})
OOD score: {_ood_html(result.get('ood_score'))}
""" cards = [] if img is not None: cards.append(_card("Image-direct (Phase 1)", img)) if wave is not None: cards.append(_card("Digitize-then-infer", wave)) if joint is not None: cards.append(_card("Joint image+waveform (Phase 2)", joint)) tops = [max(r["mechanism_probs"], key=r["mechanism_probs"].get) for r in available] all_agree = len(set(tops)) == 1 if all_agree: badge = ("" "methods agree") else: badge = ("" "methods disagree \u2014 review carefully") footnote = ( "Image-direct (Phase 1) is the image-only SPARK; Digitize-then-infer " "uses the headline waveform model on a curve extracted from the plot; " "Joint (Phase 2) fuses both signals inside a single encoder trained " "with real-world distortion augmentation. When available, the joint " "model is the recommended primary path." ) html = f"""
Method comparison headline: {method_used}
{''.join(cards)}
{badge}
{footnote}
""" return gr.update(value=html, visible=True) def _preprocessing_preview(hybrid_out, original_pils): """Build a gallery preview showing what image-mode actually consumed after auto-crop + gridline removal.""" pre_meta = hybrid_out.get("preprocessing_meta") or [] if not pre_meta or hybrid_out.get("image_mode") is None: return gr.update(visible=False) try: from image_preprocessing import prepare_for_image_mode previews = [] for orig, meta in zip(original_pils, pre_meta): prep, _ = prepare_for_image_mode(orig) tag_parts = [] if meta.get("was_cropped"): tag_parts.append("cropped") if meta.get("was_cleaned"): tag_parts.append( f"removed {meta.get('n_horiz_gridlines', 0)} horiz " f"+ {meta.get('n_vert_gridlines', 0)} vert lines" ) tag = " | ".join(tag_parts) if tag_parts else "no changes needed" previews.append((prep, tag)) return gr.update(value=previews, visible=True) except Exception as exc: print(f"[preview] failed: {exc}") return gr.update(visible=False) def _ec_image_error(msg=""): """Return empty outputs for EC image-tab error cases (10 outputs).""" return ( gr.update(visible=False), # ood_banner None, None, # probs, state gr.Dropdown(choices=[], value=None), None, None, None, None, # post, table, recon, conc gr.update(visible=False), # comparison_html gr.update(visible=False), # preview_gallery ) def _run_ec_analysis_hybrid(pil_images, sigmas, potentials, fluxes, n_samples, method_label): """Hybrid CV analysis: runs image-mode + digitizer-mode in parallel via predict_ec_hybrid and renders both the headline result (selected by `method_label`) and the per-method comparison panel. Returns a 10-tuple matching `ec_img_outputs` in the UI wiring. 
""" pred = get_predictor() mode = _hybrid_method_to_internal(method_label) try: hybrid_out = pred.predict_ec_hybrid( pil_images, sigmas, potentials=potentials, fluxes=fluxes, n_samples=int(n_samples), mode=mode, ) except Exception as exc: print(f"[hybrid CV] failed: {exc}") return _ec_image_error() headline = hybrid_out["headline"] top_mech = headline["predicted_mechanism"] recon = pred.reconstruct_ec(headline, potentials, fluxes, sigmas) fig_probs = plot_mechanism_probs(headline["mechanism_probs"], domain="ec") sorted_mechs = sorted(headline["mechanism_probs"].items(), key=lambda x: -x[1]) mech_choices = [f"{m} ({p:.1%})" for m, p in sorted_mechs] state = { "result": headline, "potentials": [p.tolist() for p in potentials], "fluxes": [f.tolist() for f in fluxes], "sigmas": sigmas, } fig_post, fig_table, fig_recon, fig_conc = _render_ec_mechanism( top_mech, headline, recon, sigmas ) comparison = _method_comparison_html(hybrid_out) preview = _preprocessing_preview(hybrid_out, pil_images) return ( _ood_banner_update(headline.get("ood_score")), fig_probs, state, gr.Dropdown(choices=mech_choices, value=mech_choices[0]), fig_post, fig_table, fig_recon, fig_conc, comparison, preview, ) def _render_ec_mechanism(mech, result, recon, sigmas): """Render posteriors, param table, reconstruction, and concentration for one EC mechanism.""" stats = result["parameter_stats"].get(mech) samples = result["posterior_samples"].get(mech) fig_posteriors = None fig_table = None if stats and samples is not None: fig_posteriors = plot_posteriors(samples, stats["names"], mech, domain="ec") fig_table = plot_parameter_table(stats, mech) fig_recon = None fig_conc = None if recon is not None: scan_labels = [f"\u03c3 = {s:.2f}" for s in sigmas] if sigmas else None fig_recon = plot_reconstruction( recon["observed"], recon["reconstructed"], domain="ec", nrmses=recon.get("nrmse"), r2s=recon.get("r2"), scan_labels=scan_labels, ) conc_curves = recon.get("concentrations") if conc_curves: fig_conc = 
plot_concentration_profiles(conc_curves, scan_labels=scan_labels) return fig_posteriors, fig_table, fig_recon, fig_conc def _on_ec_mechanism_change(mech_choice, state): """Callback when user selects a different EC mechanism from the dropdown.""" if not state or not mech_choice: return None, None, None, None mech = mech_choice.split(" (")[0] result = state["result"] potentials = [np.array(p) for p in state["potentials"]] fluxes = [np.array(f) for f in state["fluxes"]] sigmas = state["sigmas"] pred = get_predictor() recon = pred.reconstruct_ec(result, potentials, fluxes, sigmas, mechanism=mech) return _render_ec_mechanism(mech, result, recon, sigmas) # ========================================================================= # TPD Analysis # ========================================================================= def analyze_tpd(files, heating_rates_text, n_samples): """Analyze TPD data.""" if not files: return _tpd_error("Please upload at least one CSV file.") temperatures, rates = [], [] csv_betas = [] for f in files: content = Path(f.name).read_text() parsed = parse_tpd_csv(content) temperatures.append(parsed["T_K"]) rates.append(parsed["signal"]) if "beta_Ks" in parsed: csv_betas.append(parsed["beta_Ks"]) heating_rates_text = heating_rates_text.strip() if heating_rates_text else "" if heating_rates_text: try: betas = [float(s.strip()) for s in heating_rates_text.split(",")] except ValueError: return _tpd_error("Invalid heating rates. Enter comma-separated numbers in K/s.") if len(files) != len(betas): return _tpd_error( f"Number of files ({len(files)}) must match heating rates ({len(betas)}).") elif len(csv_betas) == len(files): betas = csv_betas else: return _tpd_error( "Please enter the heating rate (β in K/s) for each file. " "This value is critical for correct inference. 
" "Alternatively, include a 'Time (s)' column in your CSV so β can be computed automatically.") return _run_tpd_analysis(temperatures, rates, betas, n_samples) def analyze_tpd_image(files, heating_rates_text, threshold, n_samples, x_min, x_max, y_min, y_max, method_label="Ensemble (recommended)"): """Analyze TPD from uploaded plot images (one per heating rate). Hybrid: image-mode + digitizer-then-waveform run in parallel and are combined per `method_label`. Axis bounds auto-detected via OCR. Returns 9 outputs: existing 7 + method-comparison HTML + preview gallery. """ err = _tpd_image_error if not files: return err("Please upload at least one image.") try: from digitizer import digitize_plot, auto_detect_axis_bounds from PIL import Image as PILImage except ImportError: return err("Required libraries not available for image digitization.") heating_rates_text = heating_rates_text.strip() if heating_rates_text else "" if not heating_rates_text: return err( "Please enter the heating rate(s) (β in K/s), comma-separated. " "This value is critical for correct inference.") try: betas = [float(s.strip()) for s in heating_rates_text.split(",")] except ValueError: return err("Invalid heating rates. Enter comma-separated numbers in K/s.") if len(files) != len(betas): return err( f"Number of images ({len(files)}) must match number of " f"heating rates ({len(betas)}).") has_user_bounds = all( v is not None and v != 0 for v in [x_min, x_max, y_min, y_max] ) temperatures, rates = [], [] for idx, f in enumerate(files): img_arr = np.array(PILImage.open(f.name).convert("RGB")) if has_user_bounds: bounds = { "x_min": float(x_min), "x_max": float(x_max), "y_min": float(y_min), "y_max": float(y_max), } else: bounds = auto_detect_axis_bounds(img_arr) if bounds is None: return err( f"Could not auto-detect axis bounds for image {idx + 1}. 
" "Please enter T min, T max, Signal min, Signal max " "under 'Axis overrides'.") try: x_data, y_data = digitize_plot( img_arr, bounds["x_min"], bounds["x_max"], bounds["y_min"], bounds["y_max"], threshold=int(threshold), x_ticks=bounds.get("x_ticks"), y_ticks=bounds.get("y_ticks"), ) except Exception as exc: return err(f"Digitization failed for image {idx + 1}: {exc}") temperatures.append(x_data) rates.append(y_data) pil_images = [PILImage.open(f.name).convert("L") for f in files] return _run_tpd_analysis_hybrid( pil_images, betas, temperatures, rates, n_samples, method_label, ) def _tpd_error(msg=""): """Return empty outputs for TPD error cases (CSV path: 7 outputs).""" return ( gr.update(visible=False), # ood_banner None, None, gr.Dropdown(choices=[], value=None), None, None, None, ) def _tpd_image_error(msg=""): """Return empty outputs for TPD image-tab errors (9 outputs).""" return ( gr.update(visible=False), # ood_banner None, None, gr.Dropdown(choices=[], value=None), None, None, None, gr.update(visible=False), # comparison_html gr.update(visible=False), # preview_gallery ) def _run_tpd_analysis(temperatures, rates, betas, n_samples): """Core TPD analysis: predict + reconstruct for top mechanism.""" pred = get_predictor() result = pred.predict_tpd(temperatures, rates, betas, n_samples=int(n_samples)) top_mech = result["predicted_mechanism"] recon = pred.reconstruct_tpd(result, temperatures, rates, betas) fig_probs = plot_mechanism_probs(result["mechanism_probs"], domain="tpd") sorted_mechs = sorted(result["mechanism_probs"].items(), key=lambda x: -x[1]) mech_choices = [f"{m} ({p:.1%})" for m, p in sorted_mechs] state = { "result": result, "temperatures": [t.tolist() for t in temperatures], "rates": [r.tolist() for r in rates], "betas": betas, } fig_post, fig_table, fig_recon = _render_tpd_mechanism(top_mech, result, recon, betas) return ( _ood_banner_update(result.get("ood_score")), fig_probs, state, gr.Dropdown(choices=mech_choices, value=mech_choices[0]), 
fig_post, fig_table, fig_recon, ) def _run_tpd_analysis_hybrid(pil_images, betas, temperatures, rates, n_samples, method_label): """Hybrid TPD analysis: image-mode + digitizer-then-waveform combined via predict_tpd_hybrid. Returns 9 outputs to match `tpd_img_outputs`.""" pred = get_predictor() mode = _hybrid_method_to_internal(method_label) try: hybrid_out = pred.predict_tpd_hybrid( pil_images, betas, temperatures=temperatures, rates=rates, n_samples=int(n_samples), mode=mode, ) except Exception as exc: print(f"[hybrid TPD] failed: {exc}") return _tpd_image_error() headline = hybrid_out["headline"] top_mech = headline["predicted_mechanism"] recon = pred.reconstruct_tpd(headline, temperatures, rates, betas) fig_probs = plot_mechanism_probs(headline["mechanism_probs"], domain="tpd") sorted_mechs = sorted(headline["mechanism_probs"].items(), key=lambda x: -x[1]) mech_choices = [f"{m} ({p:.1%})" for m, p in sorted_mechs] state = { "result": headline, "temperatures": [t.tolist() for t in temperatures], "rates": [r.tolist() for r in rates], "betas": betas, } fig_post, fig_table, fig_recon = _render_tpd_mechanism(top_mech, headline, recon, betas) comparison = _method_comparison_html(hybrid_out) preview = _preprocessing_preview(hybrid_out, pil_images) return ( _ood_banner_update(headline.get("ood_score")), fig_probs, state, gr.Dropdown(choices=mech_choices, value=mech_choices[0]), fig_post, fig_table, fig_recon, comparison, preview, ) def _render_tpd_mechanism(mech, result, recon, betas): """Render posteriors, param table, and reconstruction for one TPD mechanism.""" stats = result["parameter_stats"].get(mech) samples = result["posterior_samples"].get(mech) fig_posteriors = None fig_table = None if stats and samples is not None: fig_posteriors = plot_posteriors(samples, stats["names"], mech, domain="tpd") fig_table = plot_parameter_table(stats, mech) fig_recon = None if recon is not None: scan_labels = [f"\u03b2 = {b:.2f} K/s" for b in betas] if betas else None fig_recon = 
plot_reconstruction( recon["observed"], recon["reconstructed"], domain="tpd", nrmses=recon.get("nrmse"), r2s=recon.get("r2"), scan_labels=scan_labels, ) return fig_posteriors, fig_table, fig_recon def _on_tpd_mechanism_change(mech_choice, state): """Callback when user selects a different TPD mechanism from the dropdown.""" if not state or not mech_choice: return None, None, None mech = mech_choice.split(" (")[0] result = state["result"] temperatures = [np.array(t) for t in state["temperatures"]] rates = [np.array(r) for r in state["rates"]] betas = state["betas"] pred = get_predictor() recon = pred.reconstruct_tpd(result, temperatures, rates, betas, mechanism=mech) return _render_tpd_mechanism(mech, result, recon, betas) # ========================================================================= # Shared helpers # ========================================================================= def download_results(result_text): """Create a downloadable JSON from the summary.""" if not result_text: return None tmp = tempfile.NamedTemporaryFile(suffix=".json", delete=False, mode="w") tmp.write(result_text) tmp.close() return tmp.name # ========================================================================= # Gradio UI # ========================================================================= def _build_ec_output_section(prefix): """Build shared output components for one EC input tab. Returns (ood_banner, probs, state, mech_dd, posteriors, param_table, recon, conc). """ with gr.Group(elem_classes=["results-card"]): gr.HTML("

Results

") ood_banner = gr.HTML(visible=False) probs = gr.Plot(label="Mechanism Classification") state = gr.State(value=None) mech_dd = gr.Dropdown( label="Inspect a mechanism in detail", choices=[], interactive=True, ) with gr.Tabs(elem_classes=["detail-tabs"]): with gr.Tab("Posteriors"): posteriors = gr.Plot(label=None, show_label=False) with gr.Tab("Parameters"): param_table = gr.Plot(label=None, show_label=False) with gr.Tab("Reconstruction"): recon = gr.Plot(label=None, show_label=False) with gr.Tab("Concentration"): conc = gr.Plot(label=None, show_label=False) return ood_banner, probs, state, mech_dd, posteriors, param_table, recon, conc def _build_tpd_output_section(prefix): """Build shared output components for one TPD input tab. Returns (ood_banner, probs, state, mech_dd, posteriors, param_table, recon). """ with gr.Group(elem_classes=["results-card"]): gr.HTML("

Results

") ood_banner = gr.HTML(visible=False) probs = gr.Plot(label="Mechanism Classification") state = gr.State(value=None) mech_dd = gr.Dropdown( label="Inspect a mechanism in detail", choices=[], interactive=True, ) with gr.Tabs(elem_classes=["detail-tabs"]): with gr.Tab("Posteriors"): posteriors = gr.Plot(label=None, show_label=False) with gr.Tab("Parameters"): param_table = gr.Plot(label=None, show_label=False) with gr.Tab("Reconstruction"): recon = gr.Plot(label=None, show_label=False) return ood_banner, probs, state, mech_dd, posteriors, param_table, recon CUSTOM_CSS = """ /* ---------- Page container & typography ---------- */ .gradio-container { background: #F8FAFC !important; } .trace-page { max-width: 1180px; margin: 0 auto; padding: 8px 24px 32px 24px; } .trace-page * { font-feature-settings: "ss01"; } /* ---------- Header ---------- */ .main-header { text-align: center; padding: 28px 16px 16px 16px; } .main-header h1 { font-size: 2.4em; margin: 0 0 6px 0; letter-spacing: -0.6px; font-weight: 700; color: #0F172A; } .main-header h1 .accent { color: #2563EB; } .main-header p { color: #64748B; font-size: 1.02em; max-width: 720px; margin: 0 auto; line-height: 1.5; } .main-header .accent-rule { height: 3px; width: 56px; margin: 14px auto 0 auto; background: linear-gradient(90deg, #2563EB, #60A5FA); border-radius: 2px; } /* ---------- Cards ---------- */ .input-card, .results-card { background: #FFFFFF !important; border: 1px solid #E5E7EB !important; border-radius: 14px !important; padding: 22px 26px !important; margin-top: 16px !important; box-shadow: 0 1px 2px rgba(15, 23, 42, 0.04) !important; } .results-card { padding-top: 18px !important; } .results-empty { color: #94A3B8; text-align: center; padding: 60px 16px; font-size: 1.0em; } .results-empty .arrow { font-size: 2em; display: block; margin-bottom: 8px; } /* ---------- Helper text ---------- */ .helper { color: #64748B; font-size: 0.93em; line-height: 1.5; margin: 0 0 14px 0; } .helper code { background: 
#F1F5F9; padding: 1px 6px; border-radius: 4px; font-size: 0.92em; color: #334155; } /* ---------- Mode radio (CSV / Image toggle) ---------- */ .mode-radio { margin: 4px 0 0 0 !important; } .mode-radio label[data-testid="block-label"] { display: none !important; } .mode-radio .wrap { gap: 6px !important; } /* ---------- Top tabs polish ---------- */ .tabs > .tab-nav { border-bottom: 1px solid #E5E7EB !important; padding-left: 4px !important; } .tabs > .tab-nav button { font-size: 1.02em !important; font-weight: 500 !important; color: #64748B !important; padding: 12px 20px !important; border-bottom: 2px solid transparent !important; text-transform: none !important; } .tabs > .tab-nav button.selected { color: #2563EB !important; border-bottom-color: #2563EB !important; } /* ---------- Detail tabs (per-mechanism plot views) ---------- */ .detail-tabs { margin-top: 14px; } .detail-tabs .tab-nav button { font-size: 0.95em !important; padding: 8px 14px !important; } /* ---------- Accordion ---------- */ .gr-accordion .label-wrap { padding: 8px 12px !important; } /* ---------- Buttons ---------- */ .analyze-btn { margin-top: 14px !important; } .analyze-btn button { font-size: 1.05em !important; padding: 12px 18px !important; font-weight: 600 !important; } /* ---------- About tab ---------- */ .about-card { background: #FFFFFF; border: 1px solid #E5E7EB; border-radius: 14px; padding: 26px 30px; margin-top: 16px; box-shadow: 0 1px 2px rgba(15, 23, 42, 0.04); } .about-card h2 { margin-top: 0; font-size: 1.4em; color: #0F172A; } .about-card table { font-size: 0.93em; } .about-card pre { background: #F8FAFC; border: 1px solid #E5E7EB; border-radius: 8px; padding: 14px 16px; font-size: 0.9em; color: #334155; } /* ---------- OOD banner ---------- */ .ood-banner { background: #FEF3C7; border: 1px solid #F59E0B; color: #92400E; border-radius: 10px; padding: 12px 16px; margin: 4px 0 14px 0; font-size: 0.96em; line-height: 1.45; } .ood-banner strong { color: #78350F; } /* ---------- 
Footer ---------- */ .trace-footer { text-align: center; color: #94A3B8; font-size: 0.88em; padding: 28px 0 8px 0; } .trace-footer a { color: #64748B; text-decoration: none; border-bottom: 1px dotted #CBD5E1; } .trace-footer a:hover { color: #2563EB; border-bottom-color: #2563EB; } /* ---------- Hide Gradio chrome ---------- */ footer { display: none !important; } """


def build_app() -> gr.Blocks:
    """Build and return the SPARK Gradio ``Blocks`` application.

    Layout: a "CV Analysis" tab and a "TPD Analysis" tab, each with a
    "CSV Data" / "From Image" radio that toggles between two input
    columns, followed by an "About" tab. Every input column owns its own
    results section (built by ``_build_ec_output_section`` or
    ``_build_tpd_output_section``) and its own event wiring. The app is
    only constructed here; the ``__main__`` block launches it.

    NOTE(review): several ``gr.HTML`` literals below contain only text and
    raw line breaks -- their HTML markup appears to have been stripped, and
    a raw newline inside an ordinary "..." literal is a SyntaxError. The
    About-tab ``gr.Markdown`` literals also have their line breaks
    collapsed, so their ``##``/``###`` headings and code fence will not
    render as intended. Restore these literals from version control. The
    container nesting below should also be checked against the intended
    layout.
    """
    with gr.Blocks(
        title="SPARK — Bayesian Inference for Electrochemistry & Catalysis",
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="slate",
            font=gr.themes.GoogleFont("Inter"),
        ),
        css=CUSTOM_CSS,
    ) as app:
        gr.HTML("
")
        # Title ("SPARK") and one-paragraph description of the tool.
        gr.HTML(
            "
" "

SPARK

" "

Simulation-based Posterior " "Amortization for Reaction " "Kinetics — Bayesian inference for " "reaction mechanisms and kinetic parameters from cyclic " "voltammetry (CV) or temperature-programmed desorption (TPD) " "data, in one forward pass.

" "
" "
"
        )

        with gr.Tabs():
            # =================================================================
            # Tab 1: CV Analysis
            # =================================================================
            with gr.Tab("CV Analysis"):
                # Selects which of the two input columns below is visible.
                cv_mode = gr.Radio(
                    choices=["CSV Data", "From Image"],
                    value="CSV Data",
                    label=None,
                    show_label=False,
                    elem_classes=["mode-radio"],
                )

                # --- CSV upload mode (primary) ---
                with gr.Column(visible=True) as cv_csv_group:
                    with gr.Group(elem_classes=["input-card"]):
                        gr.HTML(
                            "

One CSV per scan rate. Expected columns: " "Potential (V), Current (A/mA/µA), " "optional Time (s) — if present, the scan rate is " "inferred as the median of |dE/dt| over the forward " "sweep (no need to specify it manually).

"
                        )
                        if CV_EXAMPLES:
                            with gr.Row():
                                cv_example_dd = gr.Dropdown(
                                    label="Try an example",
                                    choices=list(CV_EXAMPLES.keys()),
                                    value=None,
                                    interactive=True,
                                    scale=4,
                                )
                                cv_example_btn = gr.Button(
                                    "Load", variant="secondary", scale=1,
                                )
                        cv_files = gr.File(
                            label="CSV files (one per scan rate)",
                            file_count="multiple",
                            file_types=[".csv", ".txt"],
                        )
                        cv_rates = gr.Textbox(
                            label="Scan rates (V/s), comma-separated",
                            placeholder="e.g., 0.01, 0.1, 1.0 (leave empty if CSV has time column)",
                            value="",
                        )
                        with gr.Accordion("Advanced parameters", open=False):
                            with gr.Row():
                                cv_E0 = gr.Number(
                                    label="Formal potential E\u2080 (V)",
                                    value=None,
                                    info="Auto-estimated from peak positions if empty",
                                )
                                cv_T = gr.Number(label="Temperature (K)", value=298.15)
                                cv_A = gr.Number(label="Electrode area (cm\u00b2)", value=0.0707)
                            with gr.Row():
                                cv_C = gr.Number(label="Concentration (mM)", value=1.0)
                                cv_D = gr.Number(
                                    label="Diffusion coeff D (cm\u00b2/s)",
                                    value=None,
                                    info="Estimated via Randles-\u0160ev\u010d\u00edk if empty",
                                )
                                cv_n = gr.Number(label="Number of electrons", value=1, precision=0)
                            cv_nsamples = gr.Slider(
                                100, 2000, value=500, step=100,
                                label="Posterior samples",
                            )
                        cv_btn = gr.Button(
                            "Analyze", variant="primary",
                            elem_classes=["analyze-btn"],
                        )
                    # Results section for the CSV path. The EC builder yields
                    # 8 components; the extra one relative to the 7-component
                    # TPD section is the trailing ``cv_conc``.
                    (cv_ood_banner, cv_probs, cv_state, cv_mech_dd,
                     cv_posteriors, cv_param_table, cv_recon,
                     cv_conc) = _build_ec_output_section("cv")

                # --- Image mode ---
                with gr.Column(visible=False) as cv_image_group:
                    with gr.Group(elem_classes=["input-card"]):
                        gr.HTML(
                            "

One image per scan rate (potential on x-axis, " "current on y-axis). Axis bounds auto-detected via OCR — " "override in Advanced if needed. " "Note: digitized curves are inherently noisier than the " "simulated data the model was trained on, so the OOD detector may " "flag image inputs and reconstruction quality is typically lower " "than the CSV path; for production use, prefer CSV.

"
                        )
                        if CV_IMG_EXAMPLES:
                            with gr.Row():
                                cv_img_example_dd = gr.Dropdown(
                                    label="Try an example",
                                    choices=list(CV_IMG_EXAMPLES.keys()),
                                    value=None,
                                    interactive=True,
                                    scale=4,
                                )
                                cv_img_example_btn = gr.Button(
                                    "Load", variant="secondary", scale=1,
                                )
                            # Read-only preview of the selected example's images.
                            cv_img_example_gallery = gr.Gallery(
                                label="Input plot images (one per scan rate) \u2014 these are what the model digitizes",
                                columns=3, height=220, object_fit="contain",
                                interactive=False,
                            )
                        cv_img_files = gr.File(
                            label="Plot images (one per scan rate)",
                            file_count="multiple",
                            file_types=["image"],
                        )
                        cv_img_scan_rate = gr.Textbox(
                            label="Scan rates (V/s), comma-separated",
                            placeholder="e.g., 0.01, 0.1, 1.0",
                            value="",
                        )
                        with gr.Accordion("Advanced parameters", open=False):
                            with gr.Row():
                                cv_img_E0 = gr.Number(
                                    label="Formal potential E\u2080 (V)",
                                    value=None,
                                    info="Auto-estimated from peaks if empty",
                                )
                                cv_img_threshold = gr.Slider(
                                    0, 255, value=0, step=1,
                                    label="Binarization threshold (0 = auto)",
                                )
                                cv_img_current_unit = gr.Dropdown(
                                    label="Current unit on y-axis",
                                    choices=["auto", "\u00b5A", "mA", "A", "nA"],
                                    value="auto",
                                    info="Select the unit shown on the y-axis of your plot",
                                )
                            # Manual axis bounds. Left empty (None), these
                            # presumably defer to the OCR auto-detection that
                            # the helper text describes.
                            with gr.Row():
                                cv_img_xmin = gr.Number(label="E min (V)", value=None)
                                cv_img_xmax = gr.Number(label="E max (V)", value=None)
                                cv_img_ymin = gr.Number(label="I min", value=None)
                                cv_img_ymax = gr.Number(label="I max", value=None)
                            cv_img_nsamples = gr.Slider(
                                100, 2000, value=500, step=100,
                                label="Posterior samples",
                            )
                            # Strategy list/default come from a helper; presumably
                            # the "Joint (Phase 2)" option only appears when the
                            # optional joint checkpoint is configured.
                            _cv_choices, _cv_default = _hybrid_choices_and_default("ec")
                            # NOTE(review): "uses image-direct unless its OOD score
                            # is low" reads inverted if a higher OOD score means
                            # more out-of-distribution -- confirm against the
                            # fallback rule in analyze_cv_image (same wording in
                            # the TPD tab).
                            cv_img_method = gr.Radio(
                                choices=_cv_choices,
                                value=_cv_default,
                                label="Inference strategy",
                                info=("Joint (Phase 2) fuses the rendered image "
                                      "with digitizer-extracted waveforms in a "
                                      "single encoder; Ensemble averages "
                                      "image-direct and digitize-then-infer; "
                                      "Auto-fallback uses image-direct unless "
                                      "its OOD score is low."),
                            )
                        cv_img_btn = gr.Button(
                            "Analyze", variant="primary",
                            elem_classes=["analyze-btn"],
                        )
                    (cv_img_ood_banner, cv_img_probs, cv_img_state,
                     cv_img_mech_dd, cv_img_posteriors, cv_img_param_table,
                     cv_img_recon, cv_img_conc) = _build_ec_output_section("cv_img")
                    # Extra outputs of the hybrid image path. Both start hidden
                    # (visible=False); presumably analyze_cv_image reveals them.
                    with gr.Accordion("Method comparison", open=False):
                        cv_img_comparison = gr.HTML(visible=False)
                        cv_img_preview = gr.Gallery(
                            label="Preprocessed image fed to image-mode CNN",
                            columns=3, height=180, object_fit="contain",
                            interactive=False, visible=False,
                        )

                # Show exactly one input column, matching the selected mode.
                cv_mode.change(
                    lambda v: (
                        gr.update(visible=v == "CSV Data"),
                        gr.update(visible=v == "From Image"),
                    ),
                    inputs=[cv_mode],
                    outputs=[cv_csv_group, cv_image_group],
                )

                # CSV wiring
                ec_outputs = [
                    cv_ood_banner, cv_probs, cv_state, cv_mech_dd,
                    cv_posteriors, cv_param_table, cv_recon, cv_conc,
                ]
                cv_btn.click(
                    analyze_cv,
                    inputs=[
                        cv_files, cv_rates, cv_E0, cv_T, cv_A,
                        cv_C, cv_D, cv_n, cv_nsamples,
                    ],
                    outputs=ec_outputs,
                )
                # Redraw the detail plots for the chosen mechanism, reading the
                # gr.State that analyze_cv wrote (cv_state is in ec_outputs).
                cv_mech_dd.change(
                    _on_ec_mechanism_change,
                    inputs=[cv_mech_dd, cv_state],
                    outputs=[cv_posteriors, cv_param_table, cv_recon, cv_conc],
                )
                if CV_EXAMPLES:
                    # "Load" fills the upload and parameter inputs from the
                    # selected demo example.
                    cv_example_btn.click(
                        _load_cv_example,
                        inputs=[cv_example_dd],
                        outputs=[
                            cv_files, cv_rates, cv_E0, cv_T,
                            cv_A, cv_C, cv_D, cv_n,
                        ],
                    )

                # Image wiring (hybrid: 8 base outputs + comparison + preview)
                ec_img_outputs = [
                    cv_img_ood_banner, cv_img_probs, cv_img_state,
                    cv_img_mech_dd, cv_img_posteriors, cv_img_param_table,
                    cv_img_recon, cv_img_conc,
                    cv_img_comparison, cv_img_preview,
                ]
                cv_img_btn.click(
                    analyze_cv_image,
                    inputs=[
                        cv_img_files, cv_img_scan_rate, cv_img_E0,
                        cv_img_threshold, cv_img_current_unit,
                        cv_img_nsamples,
                        cv_img_xmin, cv_img_xmax, cv_img_ymin, cv_img_ymax,
                        cv_img_method,
                    ],
                    outputs=ec_img_outputs,
                )
                cv_img_mech_dd.change(
                    _on_ec_mechanism_change,
                    inputs=[cv_img_mech_dd, cv_img_state],
                    outputs=[cv_img_posteriors, cv_img_param_table,
                             cv_img_recon, cv_img_conc],
                )
                if CV_IMG_EXAMPLES:
                    def _cv_img_input_gallery(mech_name):
                        """Show the actual input PNGs the user will analyze."""
                        # Uses element 0 of each catalog entry -- presumably the
                        # image path; confirm against the catalog builder.
                        if not mech_name or mech_name not in CV_IMG_EXAMPLES:
                            return []
                        return [e[0] for e in CV_IMG_EXAMPLES[mech_name]]

                    def _on_cv_img_example_select(mech_name):
                        # Load the example's files and scan rates and refresh
                        # the preview gallery in the same event.
                        files, rates = _load_cv_image_example(mech_name)
                        return files, rates, _cv_img_input_gallery(mech_name)

                    cv_img_example_btn.click(
                        _on_cv_img_example_select,
                        inputs=[cv_img_example_dd],
                        outputs=[cv_img_files, cv_img_scan_rate, cv_img_example_gallery],
                    )
                    # Preview the example images as soon as one is selected,
                    # before "Load" is pressed.
                    cv_img_example_dd.change(
                        _cv_img_input_gallery,
                        inputs=[cv_img_example_dd],
                        outputs=[cv_img_example_gallery],
                    )

            # =================================================================
            # Tab 2: TPD Analysis
            # =================================================================
            with gr.Tab("TPD Analysis"):
                # Selects which of the two input columns below is visible.
                tpd_mode = gr.Radio(
                    choices=["CSV Data", "From Image"],
                    value="CSV Data",
                    label=None,
                    show_label=False,
                    elem_classes=["mode-radio"],
                )

                # --- CSV mode ---
                with gr.Column(visible=True) as tpd_csv_group:
                    with gr.Group(elem_classes=["input-card"]):
                        gr.HTML(
                            "

One CSV per heating rate. Expected columns: " "Temperature (K), Signal, optional " "Time (s) — if present, the heating rate β is " "inferred as dT/dt from the time stamps " "(no need to specify it manually).

"
                        )
                        if TPD_EXAMPLES:
                            with gr.Row():
                                tpd_example_dd = gr.Dropdown(
                                    label="Try an example",
                                    choices=list(TPD_EXAMPLES.keys()),
                                    value=None,
                                    interactive=True,
                                    scale=4,
                                )
                                tpd_example_btn = gr.Button(
                                    "Load", variant="secondary", scale=1,
                                )
                        tpd_files = gr.File(
                            label="CSV files (one per heating rate)",
                            file_count="multiple",
                            file_types=[".csv", ".txt"],
                        )
                        tpd_betas = gr.Textbox(
                            label="Heating rates \u03b2 (K/s), comma-separated",
                            placeholder="e.g., 0.3, 2.6, 22.1 (leave empty if CSV has time column)",
                            value="",
                        )
                        with gr.Accordion("Advanced parameters", open=False):
                            tpd_nsamples = gr.Slider(
                                100, 2000, value=500, step=100,
                                label="Posterior samples",
                            )
                        tpd_btn = gr.Button(
                            "Analyze", variant="primary",
                            elem_classes=["analyze-btn"],
                        )
                    # Results section for the CSV path (7 components; the TPD
                    # builder has no counterpart to the EC ``conc`` output).
                    (tpd_ood_banner, tpd_probs, tpd_state, tpd_mech_dd,
                     tpd_posteriors, tpd_param_table,
                     tpd_recon) = _build_tpd_output_section("tpd")

                # --- Image mode ---
                with gr.Column(visible=False) as tpd_image_group:
                    with gr.Group(elem_classes=["input-card"]):
                        gr.HTML(
                            "

One image per heating rate (temperature on " "x-axis, signal on y-axis). Axis bounds auto-detected via OCR " "— override in Advanced if needed. " "Note: digitized curves are inherently noisier than the " "simulated data the model was trained on, so the OOD detector may " "flag image inputs and reconstruction quality is typically lower " "than the CSV path; for production use, prefer CSV.

"
                        )
                        if TPD_IMG_EXAMPLES:
                            with gr.Row():
                                tpd_img_example_dd = gr.Dropdown(
                                    label="Try an example",
                                    choices=list(TPD_IMG_EXAMPLES.keys()),
                                    value=None,
                                    interactive=True,
                                    scale=4,
                                )
                                tpd_img_example_btn = gr.Button(
                                    "Load", variant="secondary", scale=1,
                                )
                            # NOTE(review): label is singular, although several
                            # images (one per heating rate) can be uploaded; the
                            # CV tab uses a plural label.
                            tpd_img_example_gallery = gr.Gallery(
                                label="Input plot image \u2014 this is what the model digitizes",
                                columns=3, height=220, object_fit="contain",
                                interactive=False,
                            )
                        tpd_img_files = gr.File(
                            label="Plot images (one per heating rate)",
                            file_count="multiple",
                            file_types=["image"],
                        )
                        tpd_img_betas = gr.Textbox(
                            label="Heating rates \u03b2 (K/s), comma-separated",
                            placeholder="e.g., 0.3, 2.6, 22.1",
                            value="",
                        )
                        with gr.Accordion("Advanced parameters", open=False):
                            with gr.Row():
                                tpd_img_threshold = gr.Slider(
                                    0, 255, value=0, step=1,
                                    label="Binarization threshold (0 = auto)",
                                )
                            # Manual axis bounds; empty (None) presumably defers
                            # to OCR auto-detection.
                            with gr.Row():
                                tpd_img_xmin = gr.Number(label="T min (K)", value=None)
                                tpd_img_xmax = gr.Number(label="T max (K)", value=None)
                                tpd_img_ymin = gr.Number(label="Signal min", value=None)
                                tpd_img_ymax = gr.Number(label="Signal max", value=None)
                            tpd_img_nsamples = gr.Slider(
                                100, 2000, value=500, step=100,
                                label="Posterior samples",
                            )
                            _tpd_choices, _tpd_default = _hybrid_choices_and_default("tpd")
                            tpd_img_method = gr.Radio(
                                choices=_tpd_choices,
                                value=_tpd_default,
                                label="Inference strategy",
                                info=("Joint (Phase 2) fuses the image with "
                                      "digitizer-extracted waveforms in a "
                                      "single encoder; Ensemble averages "
                                      "image-direct and digitize-then-infer; "
                                      "Auto-fallback uses image-direct unless "
                                      "its OOD score is low."),
                            )
                        tpd_img_btn = gr.Button(
                            "Analyze", variant="primary",
                            elem_classes=["analyze-btn"],
                        )
                    (tpd_img_ood_banner, tpd_img_probs, tpd_img_state,
                     tpd_img_mech_dd, tpd_img_posteriors, tpd_img_param_table,
                     tpd_img_recon) = _build_tpd_output_section("tpd_img")
                    # Extra outputs of the hybrid image path; both start hidden.
                    with gr.Accordion("Method comparison", open=False):
                        tpd_img_comparison = gr.HTML(visible=False)
                        tpd_img_preview = gr.Gallery(
                            label="Preprocessed image fed to image-mode CNN",
                            columns=3, height=180, object_fit="contain",
                            interactive=False, visible=False,
                        )

                # Show exactly one input column, matching the selected mode.
                tpd_mode.change(
                    lambda v: (
                        gr.update(visible=v == "CSV Data"),
                        gr.update(visible=v == "From Image"),
                    ),
                    inputs=[tpd_mode],
                    outputs=[tpd_csv_group, tpd_image_group],
                )

                # CSV wiring
                tpd_outputs = [
                    tpd_ood_banner, tpd_probs, tpd_state, tpd_mech_dd,
                    tpd_posteriors, tpd_param_table, tpd_recon,
                ]
                tpd_btn.click(
                    analyze_tpd,
                    inputs=[tpd_files, tpd_betas, tpd_nsamples],
                    outputs=tpd_outputs,
                )
                # Redraw the detail plots for the chosen mechanism, reading the
                # gr.State that analyze_tpd wrote (tpd_state is in tpd_outputs).
                tpd_mech_dd.change(
                    _on_tpd_mechanism_change,
                    inputs=[tpd_mech_dd, tpd_state],
                    outputs=[tpd_posteriors, tpd_param_table, tpd_recon],
                )
                if TPD_EXAMPLES:
                    # "Load" fills the file upload and heating rates from the
                    # selected demo example.
                    tpd_example_btn.click(
                        _load_tpd_example,
                        inputs=[tpd_example_dd],
                        outputs=[tpd_files, tpd_betas],
                    )

                # Image wiring (hybrid: 7 base outputs + comparison + preview)
                tpd_img_outputs = [
                    tpd_img_ood_banner, tpd_img_probs, tpd_img_state,
                    tpd_img_mech_dd, tpd_img_posteriors, tpd_img_param_table,
                    tpd_img_recon,
                    tpd_img_comparison, tpd_img_preview,
                ]
                tpd_img_btn.click(
                    analyze_tpd_image,
                    inputs=[
                        tpd_img_files, tpd_img_betas, tpd_img_threshold,
                        tpd_img_nsamples,
                        tpd_img_xmin, tpd_img_xmax, tpd_img_ymin, tpd_img_ymax,
                        tpd_img_method,
                    ],
                    outputs=tpd_img_outputs,
                )
                tpd_img_mech_dd.change(
                    _on_tpd_mechanism_change,
                    inputs=[tpd_img_mech_dd, tpd_img_state],
                    outputs=[tpd_img_posteriors, tpd_img_param_table, tpd_img_recon],
                )
                if TPD_IMG_EXAMPLES:
                    def _tpd_img_input_gallery(mech_name):
                        # Image list for the example preview gallery; element 0
                        # of each catalog entry is presumably the image path.
                        if not mech_name or mech_name not in TPD_IMG_EXAMPLES:
                            return []
                        return [e[0] for e in TPD_IMG_EXAMPLES[mech_name]]

                    def _on_tpd_img_example_select(mech_name):
                        # Load the example's files and heating rates and refresh
                        # the preview gallery in the same event.
                        files, betas = _load_tpd_image_example(mech_name)
                        return files, betas, _tpd_img_input_gallery(mech_name)

                    tpd_img_example_btn.click(
                        _on_tpd_img_example_select,
                        inputs=[tpd_img_example_dd],
                        outputs=[tpd_img_files, tpd_img_betas, tpd_img_example_gallery],
                    )
                    # Preview the example images as soon as one is selected.
                    tpd_img_example_dd.change(
                        _tpd_img_input_gallery,
                        inputs=[tpd_img_example_dd],
                        outputs=[tpd_img_example_gallery],
                    )

            # =================================================================
            # Tab 3: About
            # =================================================================
            with gr.Tab("About"):
                with gr.Row(elem_classes=["about-card"]):
                    with gr.Column(scale=2):
                        gr.Markdown(""" ## How it works **SPARK** (Simulation-based Posterior Amortization for Reaction Kinetics) uses **conditional normalizing flows** with a **Set Transformer** encoder to perform amortized Bayesian inference. Given one or more experimental curves, it simultaneously classifies the reaction mechanism and produces full posterior distributions over kinetic parameters — in a single forward pass. The deployed checkpoints are the **noise-augmented headline models** (CV: `v14_9mech`, TPD: `tpd_11mech_v2`) that retain 92.4 % (CV) and 95.6 % (TPD) classification accuracy under realistic measurement noise. Inference runs in **~50 ms on CPU**; mean 90 %-credible-interval coverage is 92.2 % (CV) and 92.9 % (TPD). Training data is generated from physics-based simulators (Crank–Nicolson for CV, ODE integrators for TPD). Posteriors are calibrated via a coverage-aware loss with per-parameter inverse-spread weighting. ### Citation ``` Yan, B. (2026). SPARK: Amortized Bayesian Inference for Mechanism Identification and Parameter Estimation in Electrochemistry and Catalysis via Conditional Normalizing Flows. [Preprint] ``` """)
                    with gr.Column(scale=1):
                        gr.Markdown(""" ### Supported mechanisms **Electrochemistry (CV, 9):** Nernst, Butler–Volmer, Marcus–Hush–Chidsey, Adsorption, EC, Langmuir–Hinshelwood, EE, EC′, CE. **Catalysis (TPD, 11):** First-order, Second-order, Zeroth-order, FirstOrderCovDep, DiffLimited, Precursor, Dissociative, ActivatedAdsorption, LH Surface, Mars–van Krevelen, TwoSite. """)

        # NOTE(review): empty literal -- its content appears to be missing
        # (CUSTOM_CSS styles a .trace-footer element that nothing visible here
        # renders).
        gr.HTML(
            ""
        )
        gr.HTML("
")

    return app


# =========================================================================
# Entry point
# =========================================================================
if __name__ == "__main__":
    # Serve on all network interfaces at port 7860 (Gradio's default port),
    # without a public share link; show_error=True surfaces handler
    # exceptions in the browser UI.
    app = build_app()
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )