| """ |
| SPARK (Simulation-based Posterior Amortization for Reaction Kinetics) |
| |
| Gradio web interface for mechanism classification and parameter inference |
| from cyclic voltammetry (CV) and temperature-programmed desorption (TPD) data. |
| """ |
|
|
| import os |
| import sys |
| import json |
| import tempfile |
| from pathlib import Path |
|
|
| import numpy as np |
| import gradio as gr |
|
|
|
|
|
|
| from inference import SPARKPredictor |
| from preprocessing import ( |
| nondimensionalize_cv, |
| estimate_E0, |
| parse_cv_csv, |
| parse_tpd_csv, |
| ) |
| from plotting import ( |
| plot_mechanism_probs, |
| plot_posteriors, |
| plot_parameter_table, |
| plot_reconstruction, |
| plot_concentration_profiles, |
| ) |
|
|
| |
| |
| |
# Repository layout: bundled demo inputs and pre-rendered demo figures are
# looked up relative to the directory containing this file.
REPO_ROOT = Path(__file__).resolve().parent
DEMO_DIR = REPO_ROOT.joinpath("demo_data")
DEMO_RENDERS = REPO_ROOT.joinpath("demo_renders")

# Default waveform-model checkpoints inside the repository.
EC_CHECKPOINT = REPO_ROOT.joinpath("checkpoints", "ec_best.pt")
TPD_CHECKPOINT = REPO_ROOT.joinpath("checkpoints", "tpd_best.pt")

# Environment overrides. Lookup order: SPARK_* variable, then ECFLOW_*
# variable, then the repository default defined just above.
EC_CHECKPOINT = Path(
    os.environ.get(
        "SPARK_EC_CHECKPOINT",
        os.environ.get("ECFLOW_EC_CHECKPOINT", os.fspath(EC_CHECKPOINT)),
    )
)
TPD_CHECKPOINT = Path(
    os.environ.get(
        "SPARK_TPD_CHECKPOINT",
        os.environ.get("ECFLOW_TPD_CHECKPOINT", os.fspath(TPD_CHECKPOINT)),
    )
)

# Image-input checkpoints are optional: an unset or empty variable means
# "no image model" and is represented as None.
EC_IMAGE_CHECKPOINT_PATH = os.environ.get("SPARK_EC_IMAGE_CHECKPOINT", "")
TPD_IMAGE_CHECKPOINT_PATH = os.environ.get("SPARK_TPD_IMAGE_CHECKPOINT", "")
EC_IMAGE_CHECKPOINT = (
    None if not EC_IMAGE_CHECKPOINT_PATH else Path(EC_IMAGE_CHECKPOINT_PATH)
)
TPD_IMAGE_CHECKPOINT = (
    None if not TPD_IMAGE_CHECKPOINT_PATH else Path(TPD_IMAGE_CHECKPOINT_PATH)
)

# Joint image+waveform checkpoints follow the same optional convention.
EC_JOINT_CHECKPOINT_PATH = os.environ.get("SPARK_EC_JOINT_CHECKPOINT", "")
TPD_JOINT_CHECKPOINT_PATH = os.environ.get("SPARK_TPD_JOINT_CHECKPOINT", "")
EC_JOINT_CHECKPOINT = (
    None if not EC_JOINT_CHECKPOINT_PATH else Path(EC_JOINT_CHECKPOINT_PATH)
)
TPD_JOINT_CHECKPOINT = (
    None if not TPD_JOINT_CHECKPOINT_PATH else Path(TPD_JOINT_CHECKPOINT_PATH)
)
|
|
| |
| |
| |
|
|
def _discover_examples():
    """Scan ``DEMO_DIR`` for ``*_metadata.json`` files and build example catalogs.

    Metadata files whose name starts with ``ec_`` become CV examples, those
    starting with ``tpd_`` become TPD examples; other prefixes are ignored.
    A metadata file that cannot be read, is not valid JSON, or lacks the
    required ``mechanism`` / ``csv_files`` keys is skipped with a printed
    notice instead of aborting module import.

    Returns:
        ``(cv_examples, tpd_examples)``: two dicts keyed by the display label
        (``"CV — <mechanism>"`` / ``"TPD — <mechanism>"``, with an ``*OOD*``
        tag when the metadata sets ``is_ood``). Both are empty when
        ``DEMO_DIR`` does not exist.
    """
    cv_examples = {}
    tpd_examples = {}
    if not DEMO_DIR.is_dir():
        return cv_examples, tpd_examples

    for meta_path in sorted(DEMO_DIR.glob("*_metadata.json")):
        try:
            # JSON is UTF-8 by specification; do not depend on the locale.
            with open(meta_path, encoding="utf-8") as f:
                meta = json.load(f)
            mech = meta["mechanism"]
            csv_files = [str(DEMO_DIR / fn) for fn in meta["csv_files"]]
        except (OSError, ValueError, KeyError, TypeError) as exc:
            # One broken demo file should not take the whole app down.
            print(f"[examples] skipping {meta_path.name}: {exc!r}")
            continue

        is_ood = bool(meta.get("is_ood", False))
        ood_tag = " *OOD*" if is_ood else ""

        if meta_path.name.startswith("ec_"):
            rates = meta.get("scan_rates_Vs", [])
            rates_str = ", ".join(f"{r:.4g}" for r in rates)
            phys = meta.get("physical_params", {})
            cv_examples[f"CV \u2014 {mech}{ood_tag}"] = {
                "files": csv_files,
                "scan_rates": rates_str,
                "is_ood": is_ood,
                "E0_V": phys.get("E0_V"),
                "T_K": phys.get("T_K", 298.15),
                "A_cm2": phys.get("A_cm2", 0.0707),
                "C_mM": phys.get("C_mM", 1.0),
                "D_cm2s": phys.get("D_cm2s", 1e-5),
                "n_electrons": phys.get("n_electrons", 1),
            }
        elif meta_path.name.startswith("tpd_"):
            betas = meta.get("betas_Ks", [])
            betas_str = ", ".join(f"{b:.4g}" for b in betas)
            tpd_examples[f"TPD \u2014 {mech}{ood_tag}"] = {
                "files": csv_files,
                "heating_rates": betas_str,
                "is_ood": is_ood,
            }
    return cv_examples, tpd_examples


# Catalogs are built once at import time and read by the example loaders.
CV_EXAMPLES, TPD_EXAMPLES = _discover_examples()
|
|
|
|
def _discover_image_examples():
    """Index the plot images and pre-rendered outputs found in ``DEMO_RENDERS``.

    Returns a 4-tuple of dicts keyed by the mechanism token parsed from the
    file name:

    * CV inputs: ``[(image_path, scan_rate_V_per_s), ...]`` taken from
      ``ec_<mech>_<N>mVs_physical.png`` (``N`` is divided by 1000).
    * TPD inputs: ``[(image_path, numeric_token_str), ...]`` taken from
      ``tpd_<mech>_<N>_physical.png``.
    * CV renders / TPD renders: ``{suffix: path}`` for each existing
      ``<prefix>_<mech>_<suffix>.png``; mechanisms with no render are omitted.

    All four dicts are empty when the directory does not exist.
    """
    import re

    cv_inputs, tpd_inputs = {}, {}
    cv_renders, tpd_renders = {}, {}
    if not DEMO_RENDERS.is_dir():
        return cv_inputs, tpd_inputs, cv_renders, tpd_renders

    cv_name = re.compile(r"ec_(\w+?)_(\d+)mVs_physical\.png")
    for path in sorted(DEMO_RENDERS.glob("ec_*_physical.png")):
        hit = cv_name.match(path.name)
        if hit is None:
            continue
        mech = hit.group(1)
        rate_Vs = int(hit.group(2)) / 1000.0  # file names carry mV/s
        cv_inputs.setdefault(mech, []).append((str(path), rate_Vs))

    tpd_name = re.compile(r"tpd_(\w+?)_(\d+)_physical\.png")
    for path in sorted(DEMO_RENDERS.glob("tpd_*_physical.png")):
        hit = tpd_name.match(path.name)
        if hit is None:
            continue
        tpd_inputs.setdefault(hit.group(1), []).append((str(path), hit.group(2)))

    def _existing_renders(prefix, mech, suffixes):
        # Keep only the render files that are actually present on disk.
        found = {}
        for suffix in suffixes:
            candidate = DEMO_RENDERS / f"{prefix}_{mech}_{suffix}.png"
            if candidate.exists():
                found[suffix] = str(candidate)
        return found

    cv_suffixes = ["classification", "posteriors", "reconstruction", "concentration"]
    for mech in cv_inputs:
        found = _existing_renders("ec", mech, cv_suffixes)
        if found:
            cv_renders[mech] = found

    tpd_suffixes = ["classification", "posteriors", "reconstruction"]
    for mech in tpd_inputs:
        found = _existing_renders("tpd", mech, tpd_suffixes)
        if found:
            tpd_renders[mech] = found

    return cv_inputs, tpd_inputs, cv_renders, tpd_renders


# Built once at import time.
(CV_IMG_EXAMPLES, TPD_IMG_EXAMPLES,
 CV_OUTPUT_RENDERS, TPD_OUTPUT_RENDERS) = _discover_image_examples()
|
|
|
|
def _load_cv_image_example(mech_name):
    """Look up a CV image example by mechanism key.

    Returns ``(image_paths, scan_rates_str)`` where the rates are joined
    with ``", "``. For an empty or unknown key, returns two no-op
    ``gr.update()`` values so the bound components stay unchanged.
    """
    known = bool(mech_name) and mech_name in CV_IMG_EXAMPLES
    if not known:
        return [gr.update()] * 2

    image_paths = []
    rate_texts = []
    for image_path, rate in CV_IMG_EXAMPLES[mech_name]:
        image_paths.append(image_path)
        rate_texts.append(f"{rate}")
    return image_paths, ", ".join(rate_texts)
|
|
|
|
def _load_tpd_image_example(mech_name):
    """Look up a TPD image example by mechanism key.

    Returns the path and the second tuple field of the FIRST catalog entry
    only. For an empty or unknown key, returns two no-op ``gr.update()``
    values.

    NOTE(review): the second field is the numeric token parsed from the
    file name; confirm that it really is the heating rate the UI expects.
    """
    known = bool(mech_name) and mech_name in TPD_IMG_EXAMPLES
    if not known:
        return [gr.update()] * 2

    first_path, first_token = TPD_IMG_EXAMPLES[mech_name][0]
    return first_path, first_token
|
|
|
|
def _load_cv_example(example_name):
    """Fetch the form values for a named CV example.

    Returns ``(files, scan_rates, E0_V, T_K, A_cm2, C_mM, D_cm2s,
    n_electrons)`` from the catalog, or eight no-op ``gr.update()`` values
    when the name is empty or unknown.
    """
    known = bool(example_name) and example_name in CV_EXAMPLES
    if not known:
        return [gr.update()] * 8

    example = CV_EXAMPLES[example_name]
    field_order = (
        "files",
        "scan_rates",
        "E0_V",
        "T_K",
        "A_cm2",
        "C_mM",
        "D_cm2s",
        "n_electrons",
    )
    return tuple(example[field] for field in field_order)
|
|
|
|
def _load_tpd_example(example_name):
    """Fetch ``(files, heating_rates)`` for a named TPD example.

    Returns two no-op ``gr.update()`` values when the name is empty or not
    in the catalog.
    """
    known = bool(example_name) and example_name in TPD_EXAMPLES
    if not known:
        return [gr.update()] * 2

    example = TPD_EXAMPLES[example_name]
    return example["files"], example["heating_rates"]
|
|
|
|
# Process-wide predictor instance, created on first use by get_predictor().
predictor = None


def get_predictor():
    """Return the shared ``SPARKPredictor``, constructing it on first call.

    Each checkpoint is handed to the predictor as a string path only when
    it is configured and the file exists; otherwise ``None`` is passed.
    The predictor is created with ``device="cpu"``.
    """
    global predictor
    if predictor is not None:
        return predictor

    def _path_or_none(ckpt):
        # Optional checkpoints may be None; configured ones may be missing.
        if ckpt and ckpt.exists():
            return str(ckpt)
        return None

    predictor = SPARKPredictor(
        ec_checkpoint=_path_or_none(EC_CHECKPOINT),
        tpd_checkpoint=_path_or_none(TPD_CHECKPOINT),
        ec_image_checkpoint=_path_or_none(EC_IMAGE_CHECKPOINT),
        tpd_image_checkpoint=_path_or_none(TPD_IMAGE_CHECKPOINT),
        ec_joint_checkpoint=_path_or_none(EC_JOINT_CHECKPOINT),
        tpd_joint_checkpoint=_path_or_none(TPD_JOINT_CHECKPOINT),
        device="cpu",
    )
    return predictor
|
|
|
|
| |
| |
| |
|
|
def analyze_cv(files, scan_rates_text, E0_V, T_K, A_cm2,
               C_mM, D_cm2s, n_electrons, n_samples):
    """Analyze CV data from potentiostat CSV files.

    Accepts CSV files with columns for potential (V) and current (A/mA/µA).
    If the CSV includes a Time (s) column, the scan rate is auto-detected.
    Otherwise, scan rates must be provided.

    Args:
        files: Uploaded file objects; each must expose a ``.name`` path.
        scan_rates_text: Optional comma-separated scan rates in V/s, one
            per file. When given they take precedence over rates parsed
            from the CSVs.
        E0_V: Formal potential in V. A falsy value (``None`` or ``0``)
            means "estimate it": the median of ``estimate_E0`` over all
            files is used.
        T_K, A_cm2, C_mM, D_cm2s, n_electrons: Physical parameters; a falsy
            value falls back to 298.15 K, 0.0707 cm², 1 mM, 1e-5 cm²/s and
            1 electron respectively.
        n_samples: Forwarded to ``_run_ec_analysis``.

    Returns:
        The 8-output tuple of ``_run_ec_analysis`` on success, or of
        ``_ec_error`` when the inputs are rejected.
    """
    if not files:
        return _ec_error("Please upload at least one CSV file.")

    scan_rates_text = scan_rates_text.strip() if scan_rates_text else ""

    user_rates = None
    if scan_rates_text:
        try:
            user_rates = [float(s.strip()) for s in scan_rates_text.split(",")]
        except ValueError:
            return _ec_error("Invalid scan rates. Enter comma-separated numbers in V/s.")
        if len(files) != len(user_rates):
            return _ec_error(
                f"Number of files ({len(files)}) must match number of "
                f"scan rates ({len(user_rates)}).")
        # float() also accepts "nan"/"inf"; a scan rate must be a positive,
        # finite number to be physically meaningful.
        if any(not np.isfinite(v) or v <= 0 for v in user_rates):
            return _ec_error("Scan rates must be positive, finite numbers in V/s.")

    # mM -> mol/cm^3 is a factor of 1e-6.
    C_molcm3 = float(C_mM) * 1e-6 if C_mM else 1e-6
    n = int(n_electrons) if n_electrons else 1
    T = float(T_K) if T_K else 298.15
    A = float(A_cm2) if A_cm2 else 0.0707
    D = float(D_cm2s) if D_cm2s else 1e-5

    parsed_data = []
    scan_rates = []
    for idx, f in enumerate(files):
        file_label = Path(f.name).name
        try:
            parsed = parse_cv_csv(Path(f.name).read_text())
        except Exception as exc:
            # UI boundary: report unreadable or malformed uploads instead
            # of letting the callback crash.
            return _ec_error(f"Could not read CV data from '{file_label}': {exc}")
        parsed_data.append(parsed)

        if user_rates is not None:
            v = user_rates[idx]
        elif "scan_rate_Vs" in parsed:
            v = parsed["scan_rate_Vs"]
        else:
            return _ec_error(
                f"Cannot determine scan rate for file '{file_label}'. "
                "Either provide scan rates (V/s) or upload CSV files that "
                "include a Time (s) column.")
        scan_rates.append(v)

    if E0_V:
        e0 = float(E0_V)
    else:
        e0_estimates = [estimate_E0(p["E_V"], p["i_A"]) for p in parsed_data]
        e0 = float(np.median(e0_estimates))

    potentials, fluxes, sigmas_list = [], [], []
    for parsed, v in zip(parsed_data, scan_rates):
        theta, flux, sigma = nondimensionalize_cv(
            parsed["E_V"], parsed["i_A"], v, e0, T, A, C_molcm3, D, n
        )
        potentials.append(theta)
        fluxes.append(flux)
        sigmas_list.append(sigma)

    return _run_ec_analysis(potentials, fluxes, sigmas_list, n_samples)
|
|
|
|
# Multiplier that converts a current expressed in the keyed unit to amperes.
_CURRENT_UNIT_SCALES = {
    "µA": 1e-6,
    "mA": 1e-3,
    "A": 1.0,
    "nA": 1e-9,
}


def _guess_current_unit(i_values):
    """Guess the current unit from the magnitude of digitized values.

    The largest finite absolute value decides the unit:

    * above 1e3  -> ``"nA"``
    * above 0.1  -> ``"µA"``
    * above 1e-4 -> ``"mA"``
    * otherwise  -> ``"A"``

    Non-finite entries are ignored. When no finite value is available
    (empty or all-NaN input) the function returns ``"µA"``, the same unit
    the caller's scale lookup defaults to, instead of raising.
    """
    magnitudes = np.abs(np.asarray(i_values, dtype=float))
    magnitudes = magnitudes[np.isfinite(magnitudes)]
    if magnitudes.size == 0:
        return "µA"

    i_max = magnitudes.max()
    if i_max > 1e3:
        return "nA"
    if i_max > 0.1:
        return "µA"
    if i_max > 1e-4:
        return "mA"
    return "A"
|
|
|
|
def analyze_cv_image(files, scan_rate_text, E0_V, threshold, current_unit,
                     n_samples, x_min, x_max, y_min, y_max,
                     method_label="Ensemble (recommended)"):
    """Analyze CV from uploaded plot images (one per scan rate).

    Each image is digitized into an (E, I) curve, converted to amperes,
    nondimensionalized with fixed default physical parameters, and passed
    together with a grayscale copy of the image to
    ``_run_ec_analysis_hybrid``, which combines the available inference
    paths according to ``method_label``.

    Args:
        files: Uploaded image file objects exposing a ``.name`` path.
        scan_rate_text: Comma-separated scan rates in V/s, one per image.
        E0_V: Formal potential in V; ``None`` or ``0`` means "use the
            median of the per-image estimates".
        threshold: Digitizer threshold, cast to ``int``.
        current_unit: Unit of the plotted current, or ``"auto"``/empty to
            use the unit detected with the axis bounds, falling back to a
            magnitude-based guess.
        n_samples: Forwarded to the hybrid analysis.
        x_min, x_max, y_min, y_max: Optional axis overrides. They are used
            only when all four are given and both ranges are non-degenerate
            (``x_min != x_max`` and ``y_min != y_max``), so a bound of 0 is
            a valid override while an all-zero form still means "auto".
        method_label: User-facing inference-strategy label.

    Returns:
        The 10-output tuple of ``_run_ec_analysis_hybrid`` (banner, probs,
        state, mechanism dropdown, posteriors, table, reconstruction,
        concentration, method comparison, preprocessing preview), or of
        ``_ec_image_error`` when the inputs are rejected.
    """
    err = _ec_image_error
    if not files:
        return err("Please upload at least one image.")

    try:
        from digitizer import digitize_plot, auto_detect_axis_bounds
        from PIL import Image as PILImage
    except ImportError:
        return err("Required libraries not available for image digitization.")

    scan_rate_text = scan_rate_text.strip() if scan_rate_text else ""
    if not scan_rate_text:
        return err("Please enter the scan rate(s) (V/s), comma-separated.")
    try:
        scan_rates = [float(s.strip()) for s in scan_rate_text.split(",")]
    except ValueError:
        return err("Invalid scan rates. Enter comma-separated numbers in V/s.")

    if len(files) != len(scan_rates):
        return err(
            f"Number of images ({len(files)}) must match number of "
            f"scan rates ({len(scan_rates)}).")
    # float() also accepts "nan"/"inf"; reject anything non-physical.
    if any(not np.isfinite(v) or v <= 0 for v in scan_rates):
        return err("Scan rates must be positive, finite numbers in V/s.")

    # A single bound equal to 0 is legitimate (e.g. E from 0 to 0.5 V), so
    # "was an override entered?" is decided by the ranges being non-empty
    # rather than by every value being non-zero.
    has_user_bounds = (
        all(v is not None for v in (x_min, x_max, y_min, y_max))
        and float(x_min) != float(x_max)
        and float(y_min) != float(y_max)
    )

    # The image tab has no physical-parameter inputs; fixed defaults are used.
    D = 1e-5
    T = 298.15
    A = 0.0707
    C_molcm3 = 1e-6
    n = 1

    curves = []        # (E_V, i_A, scan rate) per image
    e0_estimates = []
    pil_images = []    # grayscale copies handed to the image-mode model
    for idx, (f, v_Vs) in enumerate(zip(files, scan_rates)):
        try:
            # Open each file once and release the handle right away; the
            # converted copies stay valid after the source is closed.
            with PILImage.open(f.name) as src:
                img_arr = np.array(src.convert("RGB"))
                pil_images.append(src.convert("L"))
        except Exception as exc:
            return err(f"Could not open image {idx + 1}: {exc}")

        if has_user_bounds:
            bounds = {
                "x_min": float(x_min), "x_max": float(x_max),
                "y_min": float(y_min), "y_max": float(y_max),
            }
        else:
            bounds = auto_detect_axis_bounds(img_arr)
            if bounds is None:
                return err(
                    f"Could not auto-detect axis bounds for image {idx + 1}. "
                    "Please enter E min, E max, I min, I max under "
                    "'Axis overrides'.")

        try:
            E_V, I_raw = digitize_plot(
                img_arr, bounds["x_min"], bounds["x_max"],
                bounds["y_min"], bounds["y_max"],
                threshold=int(threshold),
                x_ticks=bounds.get("x_ticks"),
                y_ticks=bounds.get("y_ticks"),
            )
        except Exception as exc:
            return err(f"Digitization failed for image {idx + 1}: {exc}")

        # Unit precedence: explicit user choice, then the unit reported
        # with the detected bounds, then a magnitude-based guess.
        if current_unit and current_unit != "auto":
            i_unit = current_unit
        elif "y_unit" in bounds:
            i_unit = bounds["y_unit"]
        else:
            i_unit = _guess_current_unit(I_raw)

        i_A = I_raw * _CURRENT_UNIT_SCALES.get(i_unit, 1e-6)

        e0_estimates.append(float(estimate_E0(E_V, i_A)))
        curves.append((E_V, i_A, v_Vs))

    if E0_V is not None and E0_V != 0:
        e0 = float(E0_V)
    else:
        e0 = float(np.median(e0_estimates))

    potentials, fluxes, sigmas_list = [], [], []
    for E_V, i_A, v_Vs in curves:
        theta, flux, sigma = nondimensionalize_cv(
            E_V, i_A, v_Vs, e0, T, A, C_molcm3, D, n
        )
        potentials.append(theta)
        fluxes.append(flux)
        sigmas_list.append(sigma)

    return _run_ec_analysis_hybrid(
        pil_images, sigmas_list, potentials, fluxes, n_samples, method_label,
    )
|
|
|
|
# Scores below this value trigger the out-of-distribution banner; it can be
# overridden with the SPARK_OOD_THRESHOLD environment variable.
OOD_THRESHOLD = float(os.environ.get("SPARK_OOD_THRESHOLD", "0.5"))


def _ood_banner_update(ood_score, threshold=OOD_THRESHOLD):
    """Return the ``gr.update`` that shows or hides the OOD banner.

    The banner is hidden when ``ood_score`` is ``None`` or is at least
    ``threshold``; otherwise it is made visible with an HTML warning that
    quotes both numbers to two decimals.
    """
    show_banner = ood_score is not None and not (ood_score >= threshold)
    if not show_banner:
        return gr.update(visible=False)

    pieces = [
        "<div class='ood-banner'>",
        "<strong>Possible out-of-distribution input.</strong> ",
        f"OOD score = {ood_score:.2f} (below {threshold:.2f}). ",
        "The mechanism prediction below is the closest match within the ",
        "trained mechanism set; treat it as low-confidence and consider ",
        "whether the true mechanism may lie outside the catalog.",
        "</div>",
    ]
    return gr.update(value="".join(pieces), visible=True)
|
|
|
|
def _ec_error(msg=""):
    """Return the 8 blank outputs used when a CV (CSV) analysis is rejected.

    Args:
        msg: Explanation for the user. When non-empty it is raised as a
            Gradio warning toast, so the rejection reason is not lost.

    Returns:
        ``(hidden banner, None, None, empty dropdown, None, None, None,
        None)``, matching the output order of ``_run_ec_analysis``.
    """
    if msg:
        gr.Warning(msg)
    return (
        gr.update(visible=False),
        None, None, gr.Dropdown(choices=[], value=None), None, None, None, None,
    )
|
|
|
|
def _run_ec_analysis(potentials, fluxes, sigmas, n_samples):
    """Run waveform-mode EC inference and render the top mechanism.

    Calls ``predict_ec`` and ``reconstruct_ec`` on the shared predictor,
    then builds the 8 outputs: OOD banner update, mechanism-probability
    figure, state dict (result plus inputs as plain lists, reused when the
    user switches mechanism), mechanism dropdown sorted by descending
    probability, and the posterior / table / reconstruction / concentration
    figures for the predicted mechanism.
    """
    model = get_predictor()
    result = model.predict_ec(potentials, fluxes, sigmas, n_samples=int(n_samples))
    recon = model.reconstruct_ec(result, potentials, fluxes, sigmas)
    probs = result["mechanism_probs"]

    # Dropdown labels look like "<mechanism> (<probability>)", best first.
    ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)
    mech_choices = []
    for name, prob in ranked:
        mech_choices.append(f"{name} ({prob:.1%})")

    state = {
        "result": result,
        "potentials": [theta.tolist() for theta in potentials],
        "fluxes": [flux.tolist() for flux in fluxes],
        "sigmas": sigmas,
    }

    figures = _render_ec_mechanism(
        result["predicted_mechanism"], result, recon, sigmas
    )
    fig_post, fig_table, fig_recon, fig_conc = figures

    banner = _ood_banner_update(result.get("ood_score"))
    fig_probs = plot_mechanism_probs(probs, domain="ec")
    dropdown = gr.Dropdown(choices=mech_choices, value=mech_choices[0])
    return (
        banner,
        fig_probs, state,
        dropdown,
        fig_post, fig_table, fig_recon, fig_conc,
    )
|
|
|
|
# User-facing radio label -> internal mode key understood by the predictor.
_HYBRID_MODE_LABELS = {
    "Ensemble (recommended)": "ensemble",
    "Image-direct": "image_only",
    "Digitize-then-infer": "digitize_only",
    "Auto-fallback": "auto_fallback",
}

# Label of the extra radio option offered when a joint model is loaded.
_JOINT_LABEL = "Joint image+waveform (Phase 2)"


def _hybrid_method_to_internal(label):
    """Translate the inference-strategy radio label into its mode key.

    The joint label maps to ``"joint"``; any empty or unrecognised label
    maps to ``"ensemble"``.
    """
    if label == _JOINT_LABEL:
        return "joint"
    return _HYBRID_MODE_LABELS.get(label, "ensemble") if label else "ensemble"
|
|
|
|
def _hybrid_choices_and_default(domain):
    """Return ``(radio_choices, default_label)`` for the inference-strategy radio.

    ``domain == "ec"`` consults ``has_ec_joint_model`` on the shared
    predictor; any other value consults ``has_tpd_joint_model``. When the
    joint model is available its label is placed first and becomes the
    default; otherwise the choices are the standard labels and the default
    is ``Ensemble (recommended)``.

    Rationale for ``Ensemble (recommended)`` as the pre-Phase-2 default:
    head-to-head eval (``outputs/image_vs_digitize/decision.md``) showed
    that on the rendered held-out test split image-only beats digitize-
    then-infer by ~30 pp (CV: 40.9% vs 5.4%, TPD: 36.1% vs 2.8%), but on
    real-world-style stress images the ordering FLIPS (CV: 6.7% vs 26.7%,
    TPD: 0% vs 60%). Ensemble averages both posteriors, so it's robust to
    either failure mode.
    """
    model = get_predictor()
    if domain == "ec":
        joint_loaded = model.has_ec_joint_model
    else:
        joint_loaded = model.has_tpd_joint_model

    standard = list(_HYBRID_MODE_LABELS.keys())
    if not joint_loaded:
        return standard, "Ensemble (recommended)"
    return [_JOINT_LABEL] + standard, _JOINT_LABEL
|
|
|
|
def _method_comparison_html(hybrid_out, threshold=OOD_THRESHOLD):
    """Build the ``gr.update`` for the method-comparison panel.

    Reads the ``image_mode``, ``waveform_mode`` and ``joint_mode`` results
    from ``hybrid_out`` and renders one card per result that is present
    (top mechanism, its probability, OOD score coloured by ``threshold``),
    plus an agree/disagree badge and the ``method_used`` headline. The
    panel is hidden when fewer than two methods produced a result.
    """
    method_used = hybrid_out.get("method_used", "")
    # (card title, result) in display order; absent methods are dropped.
    slots = [
        ("Image-direct (Phase 1)", hybrid_out.get("image_mode")),
        ("Digitize-then-infer", hybrid_out.get("waveform_mode")),
        ("Joint image+waveform (Phase 2)", hybrid_out.get("joint_mode")),
    ]
    ran = [(title, res) for title, res in slots if res is not None]
    if len(ran) < 2:
        return gr.update(visible=False)

    def _top_of(res):
        probs = res["mechanism_probs"]
        return max(probs, key=probs.get)

    def _ood_html(score):
        if score is None:
            return "—"
        col = "#7a0000" if not (score >= threshold) else "#1b5e20"
        return (f"<span style='color:{col};font-weight:600'>"
                f"{score:.2f}</span>")

    def _card(title, result):
        top = _top_of(result)
        return f"""
        <div style='flex:1;min-width:220px;border:1px solid #e0e0e0;border-radius:8px;
        padding:10px 12px;background:white;'>
        <div style='font-weight:600;color:#333;margin-bottom:4px;'>{title}</div>
        <div>top mech: <strong>{top}</strong>
        ({result['mechanism_probs'][top]:.1%})</div>
        <div>OOD score: {_ood_html(result.get('ood_score'))}</div>
        </div>"""

    cards = [_card(title, res) for title, res in ran]

    distinct_tops = {_top_of(res) for _, res in ran}
    if len(distinct_tops) == 1:
        badge = ("<span style='display:inline-block;padding:2px 10px;border-radius:8px;"
                 "background:#dff5e1;color:#1b5e20;font-weight:600;font-size:13px;'>"
                 "methods agree</span>")
    else:
        badge = ("<span style='display:inline-block;padding:2px 10px;border-radius:8px;"
                 "background:#ffe2c2;color:#7a3c00;font-weight:600;font-size:13px;'>"
                 "methods disagree \u2014 review carefully</span>")

    footnote = (
        "Image-direct (Phase 1) is the image-only SPARK; Digitize-then-infer "
        "uses the headline waveform model on a curve extracted from the plot; "
        "Joint (Phase 2) fuses both signals inside a single encoder trained "
        "with real-world distortion augmentation. When available, the joint "
        "model is the recommended primary path."
    )

    html = f"""
    <div style='border:1px solid #d0d0d0;border-radius:10px;padding:14px 16px;
    background:#fafafa;margin:8px 0;font-size:14px;line-height:1.5;'>
    <div style='display:flex;justify-content:space-between;align-items:center;margin-bottom:8px;'>
    <strong>Method comparison</strong>
    <span style='color:#555;font-size:12px;'>headline: {method_used}</span>
    </div>
    <div style='display:flex;gap:12px;flex-wrap:wrap;'>
    {''.join(cards)}
    </div>
    <div style='margin-top:8px;'>{badge}</div>
    <div style='margin-top:6px;color:#666;font-size:12px;'>{footnote}</div>
    </div>
    """
    return gr.update(value=html, visible=True)
|
|
|
|
def _preprocessing_preview(hybrid_out, original_pils):
    """Build the gallery update previewing the image-mode preprocessing.

    Each original image is passed through ``prepare_for_image_mode`` and
    captioned from its ``preprocessing_meta`` entry ("cropped", gridline
    counts, or "no changes needed"). The gallery is hidden when there is
    no metadata, when image mode did not run, or when building the preview
    fails (the failure is printed).
    """
    metas = hybrid_out.get("preprocessing_meta") or []
    if hybrid_out.get("image_mode") is None or not metas:
        return gr.update(visible=False)

    def _caption(meta):
        notes = []
        if meta.get("was_cropped"):
            notes.append("cropped")
        if meta.get("was_cleaned"):
            notes.append(
                f"removed {meta.get('n_horiz_gridlines', 0)} horiz "
                f"+ {meta.get('n_vert_gridlines', 0)} vert lines"
            )
        return " | ".join(notes) if notes else "no changes needed"

    try:
        from image_preprocessing import prepare_for_image_mode

        gallery = []
        for original, meta in zip(original_pils, metas):
            prepared, _ = prepare_for_image_mode(original)
            gallery.append((prepared, _caption(meta)))
        return gr.update(value=gallery, visible=True)
    except Exception as exc:
        # The preview is best-effort; never fail the analysis because of it.
        print(f"[preview] failed: {exc}")
        return gr.update(visible=False)
|
|
|
|
def _ec_image_error(msg=""):
    """Return the 10 blank outputs used when a CV image analysis fails.

    Args:
        msg: Explanation for the user. When non-empty it is raised as a
            Gradio warning toast, so the failure reason is not lost.

    Returns:
        The 8 blank CV outputs (hidden banner, two ``None``, empty
        dropdown, four ``None``) followed by hidden updates for the
        method-comparison panel and the preprocessing preview.
    """
    if msg:
        gr.Warning(msg)
    return (
        gr.update(visible=False),
        None, None,
        gr.Dropdown(choices=[], value=None),
        None, None, None, None,
        gr.update(visible=False),
        gr.update(visible=False),
    )
|
|
|
|
def _run_ec_analysis_hybrid(pil_images, sigmas, potentials, fluxes,
                            n_samples, method_label):
    """Hybrid CV analysis via ``predict_ec_hybrid``.

    ``method_label`` is mapped to an internal mode key and passed to the
    predictor together with the images and the digitized waveforms. The
    ``headline`` result is rendered like a normal EC analysis; the
    per-method comparison panel and the preprocessing preview are added.

    Returns:
        A 10-tuple matching ``ec_img_outputs`` in the UI wiring. If the
        hybrid prediction raises, the error is printed, shown to the user
        as a warning toast, and the blank 10-tuple is returned.
    """
    pred = get_predictor()
    mode = _hybrid_method_to_internal(method_label)

    try:
        hybrid_out = pred.predict_ec_hybrid(
            pil_images, sigmas,
            potentials=potentials, fluxes=fluxes,
            n_samples=int(n_samples), mode=mode,
        )
    except Exception as exc:
        print(f"[hybrid CV] failed: {exc}")
        # Tell the user too; a blank result with no explanation is
        # indistinguishable from "nothing happened".
        gr.Warning(f"CV image inference failed: {exc}")
        return _ec_image_error()

    headline = hybrid_out["headline"]
    top_mech = headline["predicted_mechanism"]
    recon = pred.reconstruct_ec(headline, potentials, fluxes, sigmas)

    fig_probs = plot_mechanism_probs(headline["mechanism_probs"], domain="ec")
    sorted_mechs = sorted(headline["mechanism_probs"].items(), key=lambda x: -x[1])
    mech_choices = [f"{m} ({p:.1%})" for m, p in sorted_mechs]
    # Kept so the mechanism dropdown can re-render without re-running inference.
    state = {
        "result": headline,
        "potentials": [p.tolist() for p in potentials],
        "fluxes": [f.tolist() for f in fluxes],
        "sigmas": sigmas,
    }
    fig_post, fig_table, fig_recon, fig_conc = _render_ec_mechanism(
        top_mech, headline, recon, sigmas
    )
    comparison = _method_comparison_html(hybrid_out)
    preview = _preprocessing_preview(hybrid_out, pil_images)
    return (
        _ood_banner_update(headline.get("ood_score")),
        fig_probs, state,
        gr.Dropdown(choices=mech_choices, value=mech_choices[0]),
        fig_post, fig_table, fig_recon, fig_conc,
        comparison, preview,
    )
|
|
|
|
def _render_ec_mechanism(mech, result, recon, sigmas):
    """Build the four figures shown for one EC mechanism.

    Returns ``(posteriors, parameter_table, reconstruction,
    concentration)``. The first two are ``None`` unless the result holds
    both parameter stats and posterior samples for ``mech``; the last two
    are ``None`` when ``recon`` is ``None``, and the concentration figure
    additionally requires non-empty ``recon["concentrations"]``.
    """
    mech_stats = result["parameter_stats"].get(mech)
    mech_samples = result["posterior_samples"].get(mech)

    fig_posteriors = fig_table = None
    if mech_stats and mech_samples is not None:
        fig_posteriors = plot_posteriors(
            mech_samples, mech_stats["names"], mech, domain="ec"
        )
        fig_table = plot_parameter_table(mech_stats, mech)

    if recon is None:
        return fig_posteriors, fig_table, None, None

    # One legend entry per scan, labelled with its sigma value.
    scan_labels = None
    if sigmas:
        scan_labels = [f"\u03c3 = {s:.2f}" for s in sigmas]

    fig_recon = plot_reconstruction(
        recon["observed"], recon["reconstructed"], domain="ec",
        nrmses=recon.get("nrmse"), r2s=recon.get("r2"),
        scan_labels=scan_labels,
    )

    fig_conc = None
    conc_curves = recon.get("concentrations")
    if conc_curves:
        fig_conc = plot_concentration_profiles(conc_curves, scan_labels=scan_labels)

    return fig_posteriors, fig_table, fig_recon, fig_conc
|
|
|
|
def _on_ec_mechanism_change(mech_choice, state):
    """Callback when user selects a different EC mechanism from the dropdown.

    Args:
        mech_choice: Dropdown label of the form ``"<mechanism> (<prob>)"``.
        state: Dict stored by the analysis run (``result``, ``potentials``,
            ``fluxes``, ``sigmas``).

    Returns:
        ``(posteriors, table, reconstruction, concentration)`` figures for
        the chosen mechanism, or four ``None`` values when either argument
        is empty.
    """
    if not state or not mech_choice:
        return None, None, None, None

    # Strip only the trailing " (<prob>)" suffix, so a mechanism name that
    # itself contains " (" is not truncated.
    mech = mech_choice.rsplit(" (", 1)[0]
    result = state["result"]
    # The state stores plain lists; the predictor is given arrays again.
    potentials = [np.array(p) for p in state["potentials"]]
    fluxes = [np.array(f) for f in state["fluxes"]]
    sigmas = state["sigmas"]

    pred = get_predictor()
    recon = pred.reconstruct_ec(result, potentials, fluxes, sigmas, mechanism=mech)
    return _render_ec_mechanism(mech, result, recon, sigmas)
|
|
|
|
| |
| |
| |
|
|
def analyze_tpd(files, heating_rates_text, n_samples):
    """Analyze TPD data from uploaded CSV files.

    Args:
        files: Uploaded file objects; each must expose a ``.name`` path.
        heating_rates_text: Optional comma-separated heating rates in K/s,
            one per file. When empty, the rates parsed from the CSVs are
            used, provided every file supplied one.
        n_samples: Forwarded to ``_run_tpd_analysis``.

    Returns:
        The 7-output tuple of ``_run_tpd_analysis`` on success, or of
        ``_tpd_error`` when the inputs are rejected.
    """
    if not files:
        return _tpd_error("Please upload at least one CSV file.")

    temperatures, rates = [], []
    csv_betas = []
    for f in files:
        try:
            parsed = parse_tpd_csv(Path(f.name).read_text())
        except Exception as exc:
            # UI boundary: report unreadable or malformed uploads instead
            # of letting the callback crash.
            return _tpd_error(
                f"Could not read TPD data from '{Path(f.name).name}': {exc}")
        temperatures.append(parsed["T_K"])
        rates.append(parsed["signal"])
        if "beta_Ks" in parsed:
            csv_betas.append(parsed["beta_Ks"])

    heating_rates_text = heating_rates_text.strip() if heating_rates_text else ""
    if heating_rates_text:
        try:
            betas = [float(s.strip()) for s in heating_rates_text.split(",")]
        except ValueError:
            return _tpd_error("Invalid heating rates. Enter comma-separated numbers in K/s.")
        if len(files) != len(betas):
            return _tpd_error(
                f"Number of files ({len(files)}) must match heating rates ({len(betas)}).")
        # float() also accepts "nan"/"inf"; reject anything non-physical.
        if any(not np.isfinite(b) or b <= 0 for b in betas):
            return _tpd_error("Heating rates must be positive, finite numbers in K/s.")
    elif len(csv_betas) == len(files):
        betas = csv_betas
    else:
        return _tpd_error(
            "Please enter the heating rate (β in K/s) for each file. "
            "This value is critical for correct inference. "
            "Alternatively, include a 'Time (s)' column in your CSV so β can be computed automatically.")

    return _run_tpd_analysis(temperatures, rates, betas, n_samples)
|
|
|
|
def analyze_tpd_image(files, heating_rates_text, threshold, n_samples,
                      x_min, x_max, y_min, y_max,
                      method_label="Ensemble (recommended)"):
    """Analyze TPD from uploaded plot images (one per heating rate).

    Each image is digitized into a (temperature, signal) curve and passed,
    together with a grayscale copy of the image, to
    ``_run_tpd_analysis_hybrid``, which combines the available inference
    paths according to ``method_label``.

    Args:
        files: Uploaded image file objects exposing a ``.name`` path.
        heating_rates_text: Comma-separated heating rates in K/s, one per
            image (required).
        threshold: Digitizer threshold, cast to ``int``.
        n_samples: Forwarded to the hybrid analysis.
        x_min, x_max, y_min, y_max: Optional axis overrides. They are used
            only when all four are given and both ranges are non-degenerate
            (``x_min != x_max`` and ``y_min != y_max``), so a signal axis
            starting at 0 is a valid override while an all-zero form still
            means "auto".
        method_label: User-facing inference-strategy label.

    Returns:
        The 9-output tuple of ``_run_tpd_analysis_hybrid`` (the 7 TPD
        outputs plus method comparison and preprocessing preview), or of
        ``_tpd_image_error`` when the inputs are rejected.
    """
    err = _tpd_image_error
    if not files:
        return err("Please upload at least one image.")

    try:
        from digitizer import digitize_plot, auto_detect_axis_bounds
        from PIL import Image as PILImage
    except ImportError:
        return err("Required libraries not available for image digitization.")

    heating_rates_text = heating_rates_text.strip() if heating_rates_text else ""
    if not heating_rates_text:
        return err(
            "Please enter the heating rate(s) (β in K/s), comma-separated. "
            "This value is critical for correct inference.")
    try:
        betas = [float(s.strip()) for s in heating_rates_text.split(",")]
    except ValueError:
        return err("Invalid heating rates. Enter comma-separated numbers in K/s.")

    if len(files) != len(betas):
        return err(
            f"Number of images ({len(files)}) must match number of "
            f"heating rates ({len(betas)}).")
    # float() also accepts "nan"/"inf"; reject anything non-physical.
    if any(not np.isfinite(b) or b <= 0 for b in betas):
        return err("Heating rates must be positive, finite numbers in K/s.")

    # A single bound equal to 0 is legitimate (signal min = 0 is the usual
    # case), so "was an override entered?" is decided by the ranges being
    # non-empty rather than by every value being non-zero.
    has_user_bounds = (
        all(v is not None for v in (x_min, x_max, y_min, y_max))
        and float(x_min) != float(x_max)
        and float(y_min) != float(y_max)
    )

    temperatures, rates = [], []
    pil_images = []    # grayscale copies handed to the image-mode model
    for idx, f in enumerate(files):
        try:
            # Open each file once and release the handle right away; the
            # converted copies stay valid after the source is closed.
            with PILImage.open(f.name) as src:
                img_arr = np.array(src.convert("RGB"))
                pil_images.append(src.convert("L"))
        except Exception as exc:
            return err(f"Could not open image {idx + 1}: {exc}")

        if has_user_bounds:
            bounds = {
                "x_min": float(x_min), "x_max": float(x_max),
                "y_min": float(y_min), "y_max": float(y_max),
            }
        else:
            bounds = auto_detect_axis_bounds(img_arr)
            if bounds is None:
                return err(
                    f"Could not auto-detect axis bounds for image {idx + 1}. "
                    "Please enter T min, T max, Signal min, Signal max "
                    "under 'Axis overrides'.")

        try:
            x_data, y_data = digitize_plot(
                img_arr, bounds["x_min"], bounds["x_max"],
                bounds["y_min"], bounds["y_max"],
                threshold=int(threshold),
                x_ticks=bounds.get("x_ticks"),
                y_ticks=bounds.get("y_ticks"),
            )
        except Exception as exc:
            return err(f"Digitization failed for image {idx + 1}: {exc}")

        temperatures.append(x_data)
        rates.append(y_data)

    return _run_tpd_analysis_hybrid(
        pil_images, betas, temperatures, rates, n_samples, method_label,
    )
|
|
|
|
def _tpd_error(msg=""):
    """Return the 7 blank outputs used when a TPD (CSV) analysis is rejected.

    Args:
        msg: Explanation for the user. When non-empty it is raised as a
            Gradio warning toast, so the rejection reason is not lost.

    Returns:
        ``(hidden banner, None, None, empty dropdown, None, None, None)``,
        matching the output order of ``_run_tpd_analysis``.
    """
    if msg:
        gr.Warning(msg)
    return (
        gr.update(visible=False),
        None, None, gr.Dropdown(choices=[], value=None), None, None, None,
    )
|
|
|
|
def _tpd_image_error(msg=""):
    """Return the 9 blank outputs used when a TPD image analysis fails.

    Args:
        msg: Explanation for the user. When non-empty it is raised as a
            Gradio warning toast, so the failure reason is not lost.

    Returns:
        The 7 blank TPD outputs followed by hidden updates for the
        method-comparison panel and the preprocessing preview.
    """
    if msg:
        gr.Warning(msg)
    return (
        gr.update(visible=False),
        None, None, gr.Dropdown(choices=[], value=None), None, None, None,
        gr.update(visible=False),
        gr.update(visible=False),
    )
|
|
|
|
def _run_tpd_analysis(temperatures, rates, betas, n_samples):
    """Run waveform-mode TPD inference and render the top mechanism.

    Calls ``predict_tpd`` and ``reconstruct_tpd`` on the shared predictor,
    then builds the 7 outputs: OOD banner update, mechanism-probability
    figure, state dict (result plus inputs as plain lists, reused when the
    user switches mechanism), mechanism dropdown sorted by descending
    probability, and the posterior / table / reconstruction figures for
    the predicted mechanism.
    """
    model = get_predictor()
    result = model.predict_tpd(temperatures, rates, betas, n_samples=int(n_samples))
    recon = model.reconstruct_tpd(result, temperatures, rates, betas)
    probs = result["mechanism_probs"]

    # Dropdown labels look like "<mechanism> (<probability>)", best first.
    ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)
    mech_choices = []
    for name, prob in ranked:
        mech_choices.append(f"{name} ({prob:.1%})")

    state = {
        "result": result,
        "temperatures": [temp.tolist() for temp in temperatures],
        "rates": [rate.tolist() for rate in rates],
        "betas": betas,
    }

    figures = _render_tpd_mechanism(
        result["predicted_mechanism"], result, recon, betas
    )
    fig_post, fig_table, fig_recon = figures

    banner = _ood_banner_update(result.get("ood_score"))
    fig_probs = plot_mechanism_probs(probs, domain="tpd")
    dropdown = gr.Dropdown(choices=mech_choices, value=mech_choices[0])
    return (
        banner,
        fig_probs, state,
        dropdown,
        fig_post, fig_table, fig_recon,
    )
|
|
|
|
def _run_tpd_analysis_hybrid(pil_images, betas, temperatures, rates,
                             n_samples, method_label):
    """Hybrid TPD analysis via ``predict_tpd_hybrid``.

    ``method_label`` is mapped to an internal mode key and passed to the
    predictor together with the images and the digitized curves. The
    ``headline`` result is rendered like a normal TPD analysis; the
    per-method comparison panel and the preprocessing preview are added.

    Returns:
        A 9-tuple matching ``tpd_img_outputs`` in the UI wiring. If the
        hybrid prediction raises, the error is printed, shown to the user
        as a warning toast, and the blank 9-tuple is returned.
    """
    pred = get_predictor()
    mode = _hybrid_method_to_internal(method_label)

    try:
        hybrid_out = pred.predict_tpd_hybrid(
            pil_images, betas,
            temperatures=temperatures, rates=rates,
            n_samples=int(n_samples), mode=mode,
        )
    except Exception as exc:
        print(f"[hybrid TPD] failed: {exc}")
        # Tell the user too; a blank result with no explanation is
        # indistinguishable from "nothing happened".
        gr.Warning(f"TPD image inference failed: {exc}")
        return _tpd_image_error()

    headline = hybrid_out["headline"]
    top_mech = headline["predicted_mechanism"]
    recon = pred.reconstruct_tpd(headline, temperatures, rates, betas)

    fig_probs = plot_mechanism_probs(headline["mechanism_probs"], domain="tpd")
    sorted_mechs = sorted(headline["mechanism_probs"].items(), key=lambda x: -x[1])
    mech_choices = [f"{m} ({p:.1%})" for m, p in sorted_mechs]
    # Kept so the mechanism dropdown can re-render without re-running inference.
    state = {
        "result": headline,
        "temperatures": [t.tolist() for t in temperatures],
        "rates": [r.tolist() for r in rates],
        "betas": betas,
    }
    fig_post, fig_table, fig_recon = _render_tpd_mechanism(top_mech, headline, recon, betas)
    comparison = _method_comparison_html(hybrid_out)
    preview = _preprocessing_preview(hybrid_out, pil_images)
    return (
        _ood_banner_update(headline.get("ood_score")),
        fig_probs, state,
        gr.Dropdown(choices=mech_choices, value=mech_choices[0]),
        fig_post, fig_table, fig_recon,
        comparison, preview,
    )
|
|
|
|
| def _render_tpd_mechanism(mech, result, recon, betas): |
| """Render posteriors, param table, and reconstruction for one TPD mechanism.""" |
| stats = result["parameter_stats"].get(mech) |
| samples = result["posterior_samples"].get(mech) |
|
|
| fig_posteriors = None |
| fig_table = None |
| if stats and samples is not None: |
| fig_posteriors = plot_posteriors(samples, stats["names"], mech, domain="tpd") |
| fig_table = plot_parameter_table(stats, mech) |
|
|
| fig_recon = None |
| if recon is not None: |
| scan_labels = [f"\u03b2 = {b:.2f} K/s" for b in betas] if betas else None |
| fig_recon = plot_reconstruction( |
| recon["observed"], recon["reconstructed"], domain="tpd", |
| nrmses=recon.get("nrmse"), r2s=recon.get("r2"), |
| scan_labels=scan_labels, |
| ) |
|
|
| return fig_posteriors, fig_table, fig_recon |
|
|
|
|
| def _on_tpd_mechanism_change(mech_choice, state): |
| """Callback when user selects a different TPD mechanism from the dropdown.""" |
| if not state or not mech_choice: |
| return None, None, None |
|
|
| mech = mech_choice.split(" (")[0] |
| result = state["result"] |
| temperatures = [np.array(t) for t in state["temperatures"]] |
| rates = [np.array(r) for r in state["rates"]] |
| betas = state["betas"] |
|
|
| pred = get_predictor() |
| recon = pred.reconstruct_tpd(result, temperatures, rates, betas, mechanism=mech) |
| return _render_tpd_mechanism(mech, result, recon, betas) |
|
|
|
|
| |
| |
| |
|
|
def download_results(result_text):
    """Write the summary text to a temporary ``.json`` file for download.

    Args:
        result_text: JSON summary as a string; empty or ``None`` means there
            is nothing to download.

    Returns:
        Path of the written temporary file, or ``None`` when ``result_text``
        is empty. The file is created with ``delete=False`` so it outlives
        this call; removing it is left to the caller / the OS temp cleanup.
    """
    if not result_text:
        return None
    # The context manager closes the handle even if the write fails, and the
    # explicit UTF-8 encoding keeps non-ASCII text independent of the
    # platform's default locale encoding.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False, encoding="utf-8"
    ) as tmp:
        tmp.write(result_text)
    return tmp.name
|
|
|
|
| |
| |
| |
|
|
def _build_ec_output_section(prefix):
    """Create the results card used under each EC (CV) input panel.

    ``prefix`` is accepted but currently unused.

    Returns (ood_banner, probs, state, mech_dd, posteriors, param_table, recon, conc).
    """
    detail_titles = ("Posteriors", "Parameters", "Reconstruction", "Concentration")
    detail_plots = []

    with gr.Group(elem_classes=["results-card"]):
        gr.HTML("<h3 style='margin:0 0 10px 0;color:#0F172A;font-size:1.05em;'>Results</h3>")
        ood_banner = gr.HTML(visible=False)
        probs = gr.Plot(label="Mechanism Classification")
        state = gr.State(value=None)
        mech_dd = gr.Dropdown(
            label="Inspect a mechanism in detail",
            choices=[],
            interactive=True,
        )
        # One unlabeled plot per detail tab, created in display order.
        with gr.Tabs(elem_classes=["detail-tabs"]):
            for tab_title in detail_titles:
                with gr.Tab(tab_title):
                    detail_plots.append(gr.Plot(label=None, show_label=False))

    posteriors, param_table, recon, conc = detail_plots
    return ood_banner, probs, state, mech_dd, posteriors, param_table, recon, conc
|
|
|
|
def _build_tpd_output_section(prefix):
    """Create the results card used under each TPD input panel.

    ``prefix`` is accepted but currently unused.

    Returns (ood_banner, probs, state, mech_dd, posteriors, param_table, recon).
    """
    detail_titles = ("Posteriors", "Parameters", "Reconstruction")
    detail_plots = []

    with gr.Group(elem_classes=["results-card"]):
        gr.HTML("<h3 style='margin:0 0 10px 0;color:#0F172A;font-size:1.05em;'>Results</h3>")
        ood_banner = gr.HTML(visible=False)
        probs = gr.Plot(label="Mechanism Classification")
        state = gr.State(value=None)
        mech_dd = gr.Dropdown(
            label="Inspect a mechanism in detail",
            choices=[],
            interactive=True,
        )
        # One unlabeled plot per detail tab, created in display order.
        with gr.Tabs(elem_classes=["detail-tabs"]):
            for tab_title in detail_titles:
                with gr.Tab(tab_title):
                    detail_plots.append(gr.Plot(label=None, show_label=False))

    posteriors, param_table, recon = detail_plots
    return ood_banner, probs, state, mech_dd, posteriors, param_table, recon
|
|
|
|
# Stylesheet handed to ``gr.Blocks(css=...)`` in ``build_app``. The selectors
# target the class names used by the layout code (``elem_classes=[...]``) and
# by the inline HTML snippets, plus a few of Gradio's own elements.
CUSTOM_CSS = """
/* ---------- Page container & typography ---------- */
.gradio-container { background: #F8FAFC !important; }
.trace-page { max-width: 1180px; margin: 0 auto; padding: 8px 24px 32px 24px; }
.trace-page * { font-feature-settings: "ss01"; }

/* ---------- Header ---------- */
.main-header { text-align: center; padding: 28px 16px 16px 16px; }
.main-header h1 {
font-size: 2.4em;
margin: 0 0 6px 0;
letter-spacing: -0.6px;
font-weight: 700;
color: #0F172A;
}
.main-header h1 .accent { color: #2563EB; }
.main-header p {
color: #64748B;
font-size: 1.02em;
max-width: 720px;
margin: 0 auto;
line-height: 1.5;
}
.main-header .accent-rule {
height: 3px;
width: 56px;
margin: 14px auto 0 auto;
background: linear-gradient(90deg, #2563EB, #60A5FA);
border-radius: 2px;
}

/* ---------- Cards ---------- */
.input-card, .results-card {
background: #FFFFFF !important;
border: 1px solid #E5E7EB !important;
border-radius: 14px !important;
padding: 22px 26px !important;
margin-top: 16px !important;
box-shadow: 0 1px 2px rgba(15, 23, 42, 0.04) !important;
}
.results-card { padding-top: 18px !important; }
.results-empty {
color: #94A3B8;
text-align: center;
padding: 60px 16px;
font-size: 1.0em;
}
.results-empty .arrow { font-size: 2em; display: block; margin-bottom: 8px; }

/* ---------- Helper text ---------- */
.helper {
color: #64748B;
font-size: 0.93em;
line-height: 1.5;
margin: 0 0 14px 0;
}
.helper code {
background: #F1F5F9;
padding: 1px 6px;
border-radius: 4px;
font-size: 0.92em;
color: #334155;
}

/* ---------- Mode radio (CSV / Image toggle) ---------- */
.mode-radio { margin: 4px 0 0 0 !important; }
.mode-radio label[data-testid="block-label"] { display: none !important; }
.mode-radio .wrap { gap: 6px !important; }

/* ---------- Top tabs polish ---------- */
.tabs > .tab-nav {
border-bottom: 1px solid #E5E7EB !important;
padding-left: 4px !important;
}
.tabs > .tab-nav button {
font-size: 1.02em !important;
font-weight: 500 !important;
color: #64748B !important;
padding: 12px 20px !important;
border-bottom: 2px solid transparent !important;
text-transform: none !important;
}
.tabs > .tab-nav button.selected {
color: #2563EB !important;
border-bottom-color: #2563EB !important;
}

/* ---------- Detail tabs (per-mechanism plot views) ---------- */
.detail-tabs { margin-top: 14px; }
.detail-tabs .tab-nav button { font-size: 0.95em !important; padding: 8px 14px !important; }

/* ---------- Accordion ---------- */
.gr-accordion .label-wrap { padding: 8px 12px !important; }

/* ---------- Buttons ---------- */
.analyze-btn { margin-top: 14px !important; }
.analyze-btn button {
font-size: 1.05em !important;
padding: 12px 18px !important;
font-weight: 600 !important;
}

/* ---------- About tab ---------- */
.about-card {
background: #FFFFFF;
border: 1px solid #E5E7EB;
border-radius: 14px;
padding: 26px 30px;
margin-top: 16px;
box-shadow: 0 1px 2px rgba(15, 23, 42, 0.04);
}
.about-card h2 { margin-top: 0; font-size: 1.4em; color: #0F172A; }
.about-card table { font-size: 0.93em; }
.about-card pre {
background: #F8FAFC;
border: 1px solid #E5E7EB;
border-radius: 8px;
padding: 14px 16px;
font-size: 0.9em;
color: #334155;
}

/* ---------- OOD banner ---------- */
.ood-banner {
background: #FEF3C7;
border: 1px solid #F59E0B;
color: #92400E;
border-radius: 10px;
padding: 12px 16px;
margin: 4px 0 14px 0;
font-size: 0.96em;
line-height: 1.45;
}
.ood-banner strong { color: #78350F; }

/* ---------- Footer ---------- */
.trace-footer {
text-align: center;
color: #94A3B8;
font-size: 0.88em;
padding: 28px 0 8px 0;
}
.trace-footer a { color: #64748B; text-decoration: none; border-bottom: 1px dotted #CBD5E1; }
.trace-footer a:hover { color: #2563EB; border-bottom-color: #2563EB; }

/* ---------- Hide Gradio chrome ---------- */
footer { display: none !important; }
"""
|
|
|
|
def _build_cv_tab():
    """Build the "CV Analysis" tab: CSV panel, image panel and event wiring.

    Must be called inside an open ``gr.Tabs()`` context.
    """
    with gr.Tab("CV Analysis"):
        cv_mode = gr.Radio(
            choices=["CSV Data", "From Image"],
            value="CSV Data",
            label=None,
            show_label=False,
            elem_classes=["mode-radio"],
        )

        # ---- CSV input panel (shown by default) ----
        with gr.Column(visible=True) as cv_csv_group:
            with gr.Group(elem_classes=["input-card"]):
                gr.HTML(
                    "<p class='helper'>One CSV per scan rate. Expected columns: "
                    "<code>Potential (V)</code>, <code>Current (A/mA/µA)</code>, "
                    "optional <code>Time (s)</code> — if present, the scan rate is "
                    "inferred as the median of <code>|dE/dt|</code> over the forward "
                    "sweep (no need to specify it manually).</p>"
                )
                # The example picker only exists when bundled examples were found.
                if CV_EXAMPLES:
                    with gr.Row():
                        cv_example_dd = gr.Dropdown(
                            label="Try an example",
                            choices=list(CV_EXAMPLES.keys()),
                            value=None,
                            interactive=True,
                            scale=4,
                        )
                        cv_example_btn = gr.Button(
                            "Load", variant="secondary", scale=1,
                        )
                cv_files = gr.File(
                    label="CSV files (one per scan rate)",
                    file_count="multiple",
                    file_types=[".csv", ".txt"],
                )
                cv_rates = gr.Textbox(
                    label="Scan rates (V/s), comma-separated",
                    placeholder="e.g., 0.01, 0.1, 1.0 (leave empty if CSV has time column)",
                    value="",
                )
                with gr.Accordion("Advanced parameters", open=False):
                    with gr.Row():
                        cv_E0 = gr.Number(
                            label="Formal potential E\u2080 (V)",
                            value=None,
                            info="Auto-estimated from peak positions if empty",
                        )
                        cv_T = gr.Number(label="Temperature (K)", value=298.15)
                        cv_A = gr.Number(label="Electrode area (cm\u00b2)", value=0.0707)
                    with gr.Row():
                        cv_C = gr.Number(label="Concentration (mM)", value=1.0)
                        cv_D = gr.Number(
                            label="Diffusion coeff D (cm\u00b2/s)",
                            value=None,
                            info="Estimated via Randles-\u0160ev\u010d\u00edk if empty",
                        )
                        cv_n = gr.Number(label="Number of electrons", value=1, precision=0)
                    cv_nsamples = gr.Slider(
                        100, 2000, value=500, step=100,
                        label="Posterior samples",
                    )
                cv_btn = gr.Button(
                    "Analyze", variant="primary",
                    elem_classes=["analyze-btn"],
                )

            (cv_ood_banner, cv_probs, cv_state,
             cv_mech_dd, cv_posteriors, cv_param_table,
             cv_recon, cv_conc) = _build_ec_output_section("cv")

        # ---- Image input panel (hidden until "From Image" is selected) ----
        with gr.Column(visible=False) as cv_image_group:
            with gr.Group(elem_classes=["input-card"]):
                gr.HTML(
                    "<p class='helper'>One image per scan rate (potential on x-axis, "
                    "current on y-axis). Axis bounds auto-detected via OCR — "
                    "override in Advanced if needed. "
                    "<em>Note:</em> digitized curves are inherently noisier than the "
                    "simulated data the model was trained on, so the OOD detector may "
                    "flag image inputs and reconstruction quality is typically lower "
                    "than the CSV path; for production use, prefer CSV.</p>"
                )
                if CV_IMG_EXAMPLES:
                    with gr.Row():
                        cv_img_example_dd = gr.Dropdown(
                            label="Try an example",
                            choices=list(CV_IMG_EXAMPLES.keys()),
                            value=None,
                            interactive=True,
                            scale=4,
                        )
                        cv_img_example_btn = gr.Button(
                            "Load", variant="secondary", scale=1,
                        )
                    cv_img_example_gallery = gr.Gallery(
                        label="Input plot images (one per scan rate) \u2014 these are what the model digitizes",
                        columns=3, height=220, object_fit="contain",
                        interactive=False,
                    )
                cv_img_files = gr.File(
                    label="Plot images (one per scan rate)",
                    file_count="multiple",
                    file_types=["image"],
                )
                cv_img_scan_rate = gr.Textbox(
                    label="Scan rates (V/s), comma-separated",
                    placeholder="e.g., 0.01, 0.1, 1.0",
                    value="",
                )
                with gr.Accordion("Advanced parameters", open=False):
                    with gr.Row():
                        cv_img_E0 = gr.Number(
                            label="Formal potential E\u2080 (V)",
                            value=None,
                            info="Auto-estimated from peaks if empty",
                        )
                        cv_img_threshold = gr.Slider(
                            0, 255, value=0, step=1,
                            label="Binarization threshold (0 = auto)",
                        )
                        cv_img_current_unit = gr.Dropdown(
                            label="Current unit on y-axis",
                            choices=["auto", "\u00b5A", "mA", "A", "nA"],
                            value="auto",
                            info="Select the unit shown on the y-axis of your plot",
                        )
                    with gr.Row():
                        cv_img_xmin = gr.Number(label="E min (V)", value=None)
                        cv_img_xmax = gr.Number(label="E max (V)", value=None)
                        cv_img_ymin = gr.Number(label="I min", value=None)
                        cv_img_ymax = gr.Number(label="I max", value=None)
                    cv_img_nsamples = gr.Slider(
                        100, 2000, value=500, step=100,
                        label="Posterior samples",
                    )
                    _cv_choices, _cv_default = _hybrid_choices_and_default("ec")
                    cv_img_method = gr.Radio(
                        choices=_cv_choices,
                        value=_cv_default,
                        label="Inference strategy",
                        info=("Joint (Phase 2) fuses the rendered image "
                              "with digitizer-extracted waveforms in a "
                              "single encoder; Ensemble averages "
                              "image-direct and digitize-then-infer; "
                              "Auto-fallback uses image-direct unless "
                              "its OOD score is low."),
                    )
                cv_img_btn = gr.Button(
                    "Analyze", variant="primary",
                    elem_classes=["analyze-btn"],
                )

            (cv_img_ood_banner, cv_img_probs, cv_img_state,
             cv_img_mech_dd, cv_img_posteriors, cv_img_param_table,
             cv_img_recon, cv_img_conc) = _build_ec_output_section("cv_img")

            with gr.Accordion("Method comparison", open=False):
                cv_img_comparison = gr.HTML(visible=False)
                cv_img_preview = gr.Gallery(
                    label="Preprocessed image fed to image-mode CNN",
                    columns=3, height=180, object_fit="contain",
                    interactive=False, visible=False,
                )

        # Show exactly one of the two input panels, following the mode radio.
        cv_mode.change(
            lambda v: (
                gr.update(visible=v == "CSV Data"),
                gr.update(visible=v == "From Image"),
            ),
            inputs=[cv_mode],
            outputs=[cv_csv_group, cv_image_group],
        )

        # ---- CSV panel events ----
        ec_outputs = [
            cv_ood_banner,
            cv_probs, cv_state,
            cv_mech_dd, cv_posteriors, cv_param_table,
            cv_recon, cv_conc,
        ]
        cv_btn.click(
            analyze_cv,
            inputs=[
                cv_files, cv_rates, cv_E0, cv_T,
                cv_A, cv_C, cv_D, cv_n, cv_nsamples,
            ],
            outputs=ec_outputs,
        )
        cv_mech_dd.change(
            _on_ec_mechanism_change,
            inputs=[cv_mech_dd, cv_state],
            outputs=[cv_posteriors, cv_param_table, cv_recon, cv_conc],
        )
        if CV_EXAMPLES:
            cv_example_btn.click(
                _load_cv_example,
                inputs=[cv_example_dd],
                outputs=[
                    cv_files, cv_rates, cv_E0, cv_T,
                    cv_A, cv_C, cv_D, cv_n,
                ],
            )

        # ---- Image panel events (two extra outputs: comparison + preview) ----
        ec_img_outputs = [
            cv_img_ood_banner,
            cv_img_probs, cv_img_state,
            cv_img_mech_dd, cv_img_posteriors, cv_img_param_table,
            cv_img_recon, cv_img_conc,
            cv_img_comparison, cv_img_preview,
        ]
        cv_img_btn.click(
            analyze_cv_image,
            inputs=[
                cv_img_files, cv_img_scan_rate, cv_img_E0,
                cv_img_threshold, cv_img_current_unit,
                cv_img_nsamples,
                cv_img_xmin, cv_img_xmax,
                cv_img_ymin, cv_img_ymax,
                cv_img_method,
            ],
            outputs=ec_img_outputs,
        )
        cv_img_mech_dd.change(
            _on_ec_mechanism_change,
            inputs=[cv_img_mech_dd, cv_img_state],
            outputs=[cv_img_posteriors, cv_img_param_table, cv_img_recon, cv_img_conc],
        )
        if CV_IMG_EXAMPLES:
            def _cv_img_input_gallery(mech_name):
                """Show the actual input PNGs the user will analyze."""
                if not mech_name or mech_name not in CV_IMG_EXAMPLES:
                    return []
                return [e[0] for e in CV_IMG_EXAMPLES[mech_name]]

            def _on_cv_img_example_select(mech_name):
                """Load the example's files and rates and refresh the gallery."""
                files, rates = _load_cv_image_example(mech_name)
                return files, rates, _cv_img_input_gallery(mech_name)

            cv_img_example_btn.click(
                _on_cv_img_example_select,
                inputs=[cv_img_example_dd],
                outputs=[cv_img_files, cv_img_scan_rate, cv_img_example_gallery],
            )
            cv_img_example_dd.change(
                _cv_img_input_gallery,
                inputs=[cv_img_example_dd],
                outputs=[cv_img_example_gallery],
            )


def _build_tpd_tab():
    """Build the "TPD Analysis" tab: CSV panel, image panel and event wiring.

    Must be called inside an open ``gr.Tabs()`` context.
    """
    with gr.Tab("TPD Analysis"):
        tpd_mode = gr.Radio(
            choices=["CSV Data", "From Image"],
            value="CSV Data",
            label=None,
            show_label=False,
            elem_classes=["mode-radio"],
        )

        # ---- CSV input panel (shown by default) ----
        with gr.Column(visible=True) as tpd_csv_group:
            with gr.Group(elem_classes=["input-card"]):
                gr.HTML(
                    "<p class='helper'>One CSV per heating rate. Expected columns: "
                    "<code>Temperature (K)</code>, <code>Signal</code>, optional "
                    "<code>Time (s)</code> — if present, the heating rate β is "
                    "inferred as <code>dT/dt</code> from the time stamps "
                    "(no need to specify it manually).</p>"
                )
                # The example picker only exists when bundled examples were found.
                if TPD_EXAMPLES:
                    with gr.Row():
                        tpd_example_dd = gr.Dropdown(
                            label="Try an example",
                            choices=list(TPD_EXAMPLES.keys()),
                            value=None,
                            interactive=True,
                            scale=4,
                        )
                        tpd_example_btn = gr.Button(
                            "Load", variant="secondary", scale=1,
                        )
                tpd_files = gr.File(
                    label="CSV files (one per heating rate)",
                    file_count="multiple",
                    file_types=[".csv", ".txt"],
                )
                tpd_betas = gr.Textbox(
                    label="Heating rates \u03b2 (K/s), comma-separated",
                    placeholder="e.g., 0.3, 2.6, 22.1 (leave empty if CSV has time column)",
                    value="",
                )
                with gr.Accordion("Advanced parameters", open=False):
                    tpd_nsamples = gr.Slider(
                        100, 2000, value=500, step=100,
                        label="Posterior samples",
                    )
                tpd_btn = gr.Button(
                    "Analyze", variant="primary",
                    elem_classes=["analyze-btn"],
                )

            (tpd_ood_banner, tpd_probs, tpd_state,
             tpd_mech_dd, tpd_posteriors, tpd_param_table,
             tpd_recon) = _build_tpd_output_section("tpd")

        # ---- Image input panel (hidden until "From Image" is selected) ----
        with gr.Column(visible=False) as tpd_image_group:
            with gr.Group(elem_classes=["input-card"]):
                gr.HTML(
                    "<p class='helper'>One image per heating rate (temperature on "
                    "x-axis, signal on y-axis). Axis bounds auto-detected via OCR "
                    "— override in Advanced if needed. "
                    "<em>Note:</em> digitized curves are inherently noisier than the "
                    "simulated data the model was trained on, so the OOD detector may "
                    "flag image inputs and reconstruction quality is typically lower "
                    "than the CSV path; for production use, prefer CSV.</p>"
                )
                if TPD_IMG_EXAMPLES:
                    with gr.Row():
                        tpd_img_example_dd = gr.Dropdown(
                            label="Try an example",
                            choices=list(TPD_IMG_EXAMPLES.keys()),
                            value=None,
                            interactive=True,
                            scale=4,
                        )
                        tpd_img_example_btn = gr.Button(
                            "Load", variant="secondary", scale=1,
                        )
                    tpd_img_example_gallery = gr.Gallery(
                        label="Input plot image \u2014 this is what the model digitizes",
                        columns=3, height=220, object_fit="contain",
                        interactive=False,
                    )
                tpd_img_files = gr.File(
                    label="Plot images (one per heating rate)",
                    file_count="multiple",
                    file_types=["image"],
                )
                tpd_img_betas = gr.Textbox(
                    label="Heating rates \u03b2 (K/s), comma-separated",
                    placeholder="e.g., 0.3, 2.6, 22.1",
                    value="",
                )
                with gr.Accordion("Advanced parameters", open=False):
                    with gr.Row():
                        tpd_img_threshold = gr.Slider(
                            0, 255, value=0, step=1,
                            label="Binarization threshold (0 = auto)",
                        )
                    with gr.Row():
                        tpd_img_xmin = gr.Number(label="T min (K)", value=None)
                        tpd_img_xmax = gr.Number(label="T max (K)", value=None)
                        tpd_img_ymin = gr.Number(label="Signal min", value=None)
                        tpd_img_ymax = gr.Number(label="Signal max", value=None)
                    tpd_img_nsamples = gr.Slider(
                        100, 2000, value=500, step=100,
                        label="Posterior samples",
                    )
                    _tpd_choices, _tpd_default = _hybrid_choices_and_default("tpd")
                    tpd_img_method = gr.Radio(
                        choices=_tpd_choices,
                        value=_tpd_default,
                        label="Inference strategy",
                        info=("Joint (Phase 2) fuses the image with "
                              "digitizer-extracted waveforms in a "
                              "single encoder; Ensemble averages "
                              "image-direct and digitize-then-infer; "
                              "Auto-fallback uses image-direct unless "
                              "its OOD score is low."),
                    )
                tpd_img_btn = gr.Button(
                    "Analyze", variant="primary",
                    elem_classes=["analyze-btn"],
                )

            (tpd_img_ood_banner, tpd_img_probs, tpd_img_state,
             tpd_img_mech_dd, tpd_img_posteriors, tpd_img_param_table,
             tpd_img_recon) = _build_tpd_output_section("tpd_img")

            with gr.Accordion("Method comparison", open=False):
                tpd_img_comparison = gr.HTML(visible=False)
                tpd_img_preview = gr.Gallery(
                    label="Preprocessed image fed to image-mode CNN",
                    columns=3, height=180, object_fit="contain",
                    interactive=False, visible=False,
                )

        # Show exactly one of the two input panels, following the mode radio.
        tpd_mode.change(
            lambda v: (
                gr.update(visible=v == "CSV Data"),
                gr.update(visible=v == "From Image"),
            ),
            inputs=[tpd_mode],
            outputs=[tpd_csv_group, tpd_image_group],
        )

        # ---- CSV panel events ----
        tpd_outputs = [
            tpd_ood_banner,
            tpd_probs, tpd_state,
            tpd_mech_dd, tpd_posteriors, tpd_param_table, tpd_recon,
        ]
        tpd_btn.click(
            analyze_tpd,
            inputs=[tpd_files, tpd_betas, tpd_nsamples],
            outputs=tpd_outputs,
        )
        tpd_mech_dd.change(
            _on_tpd_mechanism_change,
            inputs=[tpd_mech_dd, tpd_state],
            outputs=[tpd_posteriors, tpd_param_table, tpd_recon],
        )
        if TPD_EXAMPLES:
            tpd_example_btn.click(
                _load_tpd_example,
                inputs=[tpd_example_dd],
                outputs=[tpd_files, tpd_betas],
            )

        # ---- Image panel events (two extra outputs: comparison + preview) ----
        tpd_img_outputs = [
            tpd_img_ood_banner,
            tpd_img_probs, tpd_img_state,
            tpd_img_mech_dd, tpd_img_posteriors, tpd_img_param_table, tpd_img_recon,
            tpd_img_comparison, tpd_img_preview,
        ]
        tpd_img_btn.click(
            analyze_tpd_image,
            inputs=[
                tpd_img_files, tpd_img_betas,
                tpd_img_threshold, tpd_img_nsamples,
                tpd_img_xmin, tpd_img_xmax,
                tpd_img_ymin, tpd_img_ymax,
                tpd_img_method,
            ],
            outputs=tpd_img_outputs,
        )
        tpd_img_mech_dd.change(
            _on_tpd_mechanism_change,
            inputs=[tpd_img_mech_dd, tpd_img_state],
            outputs=[tpd_img_posteriors, tpd_img_param_table, tpd_img_recon],
        )
        if TPD_IMG_EXAMPLES:
            def _tpd_img_input_gallery(mech_name):
                """Return the example's input image paths for the gallery."""
                if not mech_name or mech_name not in TPD_IMG_EXAMPLES:
                    return []
                return [e[0] for e in TPD_IMG_EXAMPLES[mech_name]]

            def _on_tpd_img_example_select(mech_name):
                """Load the example's files and rates and refresh the gallery."""
                files, betas = _load_tpd_image_example(mech_name)
                return files, betas, _tpd_img_input_gallery(mech_name)

            tpd_img_example_btn.click(
                _on_tpd_img_example_select,
                inputs=[tpd_img_example_dd],
                outputs=[tpd_img_files, tpd_img_betas, tpd_img_example_gallery],
            )
            tpd_img_example_dd.change(
                _tpd_img_input_gallery,
                inputs=[tpd_img_example_dd],
                outputs=[tpd_img_example_gallery],
            )


def _build_about_tab():
    """Build the static "About" tab (method summary, citation, mechanisms).

    Must be called inside an open ``gr.Tabs()`` context.
    """
    with gr.Tab("About"):
        with gr.Row(elem_classes=["about-card"]):
            with gr.Column(scale=2):
                gr.Markdown("""
## How it works

**SPARK** (Simulation-based Posterior Amortization for Reaction Kinetics) uses
**conditional normalizing flows** with a **Set Transformer** encoder to
perform amortized Bayesian inference. Given one or more experimental
curves, it simultaneously classifies the reaction mechanism and produces
full posterior distributions over kinetic parameters — in a single
forward pass.

The deployed checkpoints are the **noise-augmented headline models**
(CV: `v14_9mech`, TPD: `tpd_11mech_v2`) that retain 92.4 % (CV) and
95.6 % (TPD) classification accuracy under realistic measurement noise.
Inference runs in **~50 ms on CPU**; mean 90 %-credible-interval coverage
is 92.2 % (CV) and 92.9 % (TPD).

Training data is generated from physics-based simulators
(Crank–Nicolson for CV, ODE integrators for TPD). Posteriors are
calibrated via a coverage-aware loss with per-parameter inverse-spread
weighting.

### Citation

```
Yan, B. (2026). SPARK: Amortized Bayesian Inference for
Mechanism Identification and Parameter Estimation in
Electrochemistry and Catalysis via Conditional
Normalizing Flows. [Preprint]
```
""")
            with gr.Column(scale=1):
                gr.Markdown("""
### Supported mechanisms

**Electrochemistry (CV, 9):**
Nernst, Butler–Volmer, Marcus–Hush–Chidsey,
Adsorption, EC, Langmuir–Hinshelwood, EE, EC′, CE.

**Catalysis (TPD, 11):**
First-order, Second-order, Zeroth-order, FirstOrderCovDep,
DiffLimited, Precursor, Dissociative, ActivatedAdsorption,
LH Surface, Mars–van Krevelen, TwoSite.
""")


def build_app():
    """Assemble the SPARK Gradio interface and return the ``gr.Blocks`` app.

    The page consists of a header, three tabs (CV Analysis, TPD Analysis,
    About) built by the private ``_build_*_tab`` helpers, and a footer link.
    The app is only constructed here; launching is left to the caller.
    """
    with gr.Blocks(
        title="SPARK — Bayesian Inference for Electrochemistry & Catalysis",
        theme=gr.themes.Soft(
            primary_hue="blue",
            secondary_hue="slate",
            font=gr.themes.GoogleFont("Inter"),
        ),
        css=CUSTOM_CSS,
    ) as app:
        # The page wrapper is a real layout container. Each gr.HTML renders
        # as its own component, so an opening <div> in one gr.HTML and the
        # closing </div> in another cannot enclose the components between
        # them, and the .trace-page rules would never reach the content.
        with gr.Column(elem_classes=["trace-page"]):
            gr.HTML(
                "<div class='main-header'>"
                "<h1><span class='accent'>SPARK</span></h1>"
                "<p><strong>S</strong>imulation-based <strong>P</strong>osterior "
                "<strong>A</strong>mortization for <strong>R</strong>eaction "
                "<strong>K</strong>inetics — Bayesian inference for "
                "reaction mechanisms and kinetic parameters from cyclic "
                "voltammetry (CV) or temperature-programmed desorption (TPD) "
                "data, in one forward pass.</p>"
                "<div class='accent-rule'></div>"
                "</div>"
            )

            with gr.Tabs():
                _build_cv_tab()
                _build_tpd_tab()
                _build_about_tab()

            gr.HTML(
                "<div class='trace-footer'>"
                "<a href='https://github.com/bingyan/ECFlow' target='_blank'>"
                "github.com/bingyan/ECFlow</a></div>"
            )

    return app
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Script entry point: build the interface and serve it with fixed
    # host/port settings and no public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    app = build_app()
    app.launch(**launch_options)
|
|