""" SPARK inference engine. Loads trained EC and TPD models and runs end-to-end inference from preprocessed arrays (dimensionless for CV, physical for TPD). """ import json import sys import os from pathlib import Path import numpy as np import torch import flow_model as _fm_module import multi_mechanism_model as _mm_module import tpd_model as _tpd_mod import generate_tpd_data as _tpd_gen from multi_mechanism_model import MultiMechanismFlow from tpd_model import MultiMechanismFlowTPD from flow_model import MECHANISM_PARAMS, ActNorm from generate_tpd_data import TPD_MECHANISM_PARAMS def _fix_actnorm_initialized(model): """Mark all ActNorm layers as initialized after loading a checkpoint. Old checkpoints lack the ``_initialized`` buffer, so ``load_state_dict`` leaves it at ``False``. The first forward pass would then overwrite the trained ``log_scale``/``bias`` with data-dependent statistics. """ for module in model.modules(): if isinstance(module, ActNorm) and not module.initialized: module.initialized = True class SPARKPredictor: """Unified predictor for both EC (cyclic voltammetry) and TPD domains.""" def __init__(self, ec_checkpoint=None, tpd_checkpoint=None, device=None, ec_image_checkpoint=None, tpd_image_checkpoint=None, ec_joint_checkpoint=None, tpd_joint_checkpoint=None): if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" self.device = device self.ec_model = None self.ec_norm_stats = None self.tpd_model = None self.tpd_norm_stats = None # Optional image-input variants; loaded only if checkpoint paths # are supplied. Each is a separate model with its own encoder. self.ec_image_model = None self.ec_image_mech_list = None self.tpd_image_model = None self.tpd_image_mech_list = None # Optional Phase-2 joint (image + waveform) variants. self.ec_joint_model = None self.ec_joint_mech_list = None self.tpd_joint_model = None self.tpd_joint_mech_list = None if ec_checkpoint is not None: self._load_ec(ec_checkpoint) if tpd_checkpoint is not None: self._load_tpd(tpd_checkpoint) if ec_image_checkpoint is not None: self._load_ec_image(ec_image_checkpoint) if tpd_image_checkpoint is not None: self._load_tpd_image(tpd_image_checkpoint) if ec_joint_checkpoint is not None: self._load_ec_joint(ec_joint_checkpoint) if tpd_joint_checkpoint is not None: self._load_tpd_joint(tpd_joint_checkpoint) @property def has_ec_image_model(self) -> bool: return self.ec_image_model is not None @property def has_tpd_image_model(self) -> bool: return self.tpd_image_model is not None @property def has_ec_joint_model(self) -> bool: return self.ec_joint_model is not None @property def has_tpd_joint_model(self) -> bool: return self.tpd_joint_model is not None def _load_ec(self, ckpt_path): ckpt_path = Path(ckpt_path) checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False) args = checkpoint["args"] # If the checkpoint was trained on a custom mechanism subset (e.g. # v14_9mech), patch the global MECHANISM_LIST in flow_model and # multi_mechanism_model BEFORE constructing MultiMechanismFlow so # the classifier and flow_heads sizes match the checkpoint exactly. # Otherwise strict=False below would silently drop classifier # weights and we'd end up with random-weights predictions. if args.get("mechanism_list") is not None: new_list = list(args["mechanism_list"]) _fm_module.MECHANISM_LIST = new_list _mm_module.MECHANISM_LIST = new_list self.ec_mechanism_list = list(_fm_module.MECHANISM_LIST) self.ec_model = MultiMechanismFlow( d_context=args.get("d_context", 128), d_model=args.get("d_model", 128), n_coupling_layers=args.get("n_coupling_layers", 6), hidden_dim=args.get("hidden_dim", 96), coupling_type=args.get("coupling_type", "spline"), n_bins=args.get("n_bins", 8), tail_bound=args.get("tail_bound", 5.0), aggregation=args.get("aggregation", "set_transformer"), use_summary_features=args.get("use_summary_features", False), ) # If checkpoint includes a trained OOD head, initialize the matching # nn.Sequential before loading so its weights are restored too. state = checkpoint["model_state_dict"] if any(k.startswith("ood_head.") for k in state): meta = checkpoint.get("ood_head_meta", {}) self.ec_model.init_ood_head( hidden_dim=meta.get("hidden_dim", 64), extra_input_dim=meta.get("extra_input_dim", 0), use_nll=meta.get("use_nll", False), use_posterior_width=meta.get("use_posterior_width", False), ) missing, unexpected = self.ec_model.load_state_dict( state, strict=False) # We only tolerate buffers like ActNorm._initialized being missing/extra. suspicious_missing = [ k for k in missing if not k.endswith("_initialized") ] suspicious_unexpected = [ k for k in unexpected if not k.endswith("_initialized") ] if suspicious_missing or suspicious_unexpected: print(f"[SPARKPredictor] WARNING: state_dict mismatch on EC ckpt.") if suspicious_missing: print(f" missing ({len(suspicious_missing)}): " f"{suspicious_missing[:6]}{' ...' if len(suspicious_missing) > 6 else ''}") if suspicious_unexpected: print(f" unexpected({len(suspicious_unexpected)}): " f"{suspicious_unexpected[:6]}{' ...' if len(suspicious_unexpected) > 6 else ''}") _fix_actnorm_initialized(self.ec_model) self.ec_model.to(self.device).eval() # Search for norm_stats in multiple locations ckpt_dir = ckpt_path.parent stem = ckpt_path.stem.replace("best", "").rstrip("_") prefix = stem + "_" if stem else "" for search_dir in [ckpt_dir, ckpt_dir.parent]: for name_pattern in [f"{prefix}norm_stats.json", "ec_norm_stats.json", "norm_stats.json"]: p = search_dir / name_pattern if p.exists(): with open(p) as f: self.ec_norm_stats = json.load(f) break if self.ec_norm_stats is not None: break for search_dir in [ckpt_dir, ckpt_dir.parent]: for name_pattern in [f"{prefix}theta_stats.json", "ec_theta_stats.json", "theta_stats.json"]: p = search_dir / name_pattern if p.exists(): with open(p) as f: self.ec_theta_stats = json.load(f) break if hasattr(self, "ec_theta_stats") and self.ec_theta_stats is not None: break def _load_tpd(self, ckpt_path): ckpt_path = Path(ckpt_path) checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False) args = checkpoint["args"] self.tpd_use_summary = args.get("use_summary_features", False) # As with EC: align module-globals AND pass mechanism_list explicitly # so the classifier/flow_heads match the checkpoint shape (e.g. # tpd_11mech_v1 has 11 mechs, not the default 13). ckpt_mech_list = args.get("mechanism_list") if ckpt_mech_list is not None: new_list = list(ckpt_mech_list) _tpd_gen.TPD_MECHANISM_LIST = new_list _tpd_gen.TPD_MECHANISM_TO_ID = {m: i for i, m in enumerate(new_list)} _tpd_mod.TPD_MECHANISM_LIST = new_list try: import dataset_tpd as _ds_tpd _ds_tpd.TPD_MECHANISM_LIST = new_list except ImportError: pass self.tpd_mechanism_list = list(_tpd_gen.TPD_MECHANISM_LIST) self.tpd_model = MultiMechanismFlowTPD( d_context=args.get("d_context", 128), d_model=args.get("d_model", 128), n_coupling_layers=args.get("n_coupling_layers", 6), hidden_dim=args.get("hidden_dim", 96), coupling_type=args.get("coupling_type", "spline"), n_bins=args.get("n_bins", 8), tail_bound=args.get("tail_bound", 5.0), use_summary_features=self.tpd_use_summary, use_bounded_flow=args.get("use_bounded_flow", False), mechanism_list=self.tpd_mechanism_list, ) state = checkpoint["model_state_dict"] if any(k.startswith("ood_head.") for k in state): meta = checkpoint.get("ood_head_meta", {}) self.tpd_model.init_ood_head( hidden_dim=meta.get("hidden_dim", 64), extra_input_dim=meta.get("extra_input_dim", 0), ) missing, unexpected = self.tpd_model.load_state_dict( state, strict=False) suspicious_missing = [ k for k in missing if not k.endswith("_initialized") ] suspicious_unexpected = [ k for k in unexpected if not k.endswith("_initialized") ] if suspicious_missing or suspicious_unexpected: print(f"[SPARKPredictor] WARNING: state_dict mismatch on TPD ckpt.") if suspicious_missing: print(f" missing ({len(suspicious_missing)}): " f"{suspicious_missing[:6]}{' ...' if len(suspicious_missing) > 6 else ''}") if suspicious_unexpected: print(f" unexpected({len(suspicious_unexpected)}): " f"{suspicious_unexpected[:6]}{' ...' if len(suspicious_unexpected) > 6 else ''}") _fix_actnorm_initialized(self.tpd_model) self.tpd_model.to(self.device).eval() # Search for norm_stats in multiple locations ckpt_dir = ckpt_path.parent stem = ckpt_path.stem.replace("best", "").rstrip("_") prefix = stem + "_" if stem else "" for search_dir in [ckpt_dir, ckpt_dir.parent]: for name_pattern in [f"{prefix}norm_stats.json", "tpd_norm_stats.json", "norm_stats.json"]: p = search_dir / name_pattern if p.exists(): with open(p) as f: self.tpd_norm_stats = json.load(f) break if self.tpd_norm_stats is not None: break for search_dir in [ckpt_dir, ckpt_dir.parent]: for name_pattern in [f"{prefix}theta_stats.json", "tpd_theta_stats.json", "theta_stats.json"]: p = search_dir / name_pattern if p.exists(): with open(p) as f: self.tpd_theta_stats = json.load(f) break if hasattr(self, "tpd_theta_stats") and self.tpd_theta_stats is not None: break # ===================================================================== # Image-input variants (parallel to the waveform models above) # ===================================================================== @staticmethod def _detect_input_mode(state_dict) -> str: """Detect input_mode from a checkpoint's state_dict keys. Phase-1 image -> encoder.per_cv_encoder.stem.* Phase-1 wave -> encoder.per_cv_encoder.conv.* Phase-2 joint -> encoder.image_encoder.* + encoder.waveform_encoder.* """ if any(k.startswith("encoder.image_encoder.") for k in state_dict) and \ any(k.startswith("encoder.waveform_encoder.") for k in state_dict): return "image+waveform" if any(k.startswith("encoder.per_cv_encoder.stem.") for k in state_dict): return "image" return "waveform" def _build_image_model_state(self, ckpt_path, builder_cls, is_tpd=False, expected_input_mode="image"): ckpt_path = Path(ckpt_path) ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) args = ckpt["args"] if isinstance(ckpt, dict) and "args" in ckpt else {} # Older image-mode trainers didn't store input_mode in args; sniff # the state_dict to disambiguate. actual = (args.get("input_mode") or self._detect_input_mode(ckpt["model_state_dict"])) if actual != expected_input_mode: raise ValueError( f"{expected_input_mode!r} checkpoint expected; got input_mode=" f"{actual!r} for {ckpt_path}" ) return ckpt_path, ckpt, args def _load_ec_image(self, ckpt_path): ckpt_path, ckpt, args = self._build_image_model_state( ckpt_path, MultiMechanismFlow, ) if args.get("mechanism_list") is not None: new_list = list(args["mechanism_list"]) _fm_module.MECHANISM_LIST = new_list _mm_module.MECHANISM_LIST = new_list self.ec_image_mech_list = list(_fm_module.MECHANISM_LIST) self.ec_image_model = MultiMechanismFlow( d_context=args.get("d_context", 128), d_model=args.get("d_model", 128), n_coupling_layers=args.get("n_coupling_layers", 6), hidden_dim=args.get("hidden_dim", 96), coupling_type=args.get("coupling_type", "spline"), n_bins=args.get("n_bins", 8), tail_bound=args.get("tail_bound", 5.0), aggregation=args.get("aggregation", "set_transformer"), use_summary_features=False, input_mode="image", image_in_channels=args.get("image_in_channels", 1), ) state = ckpt["model_state_dict"] if any(k.startswith("ood_head.") for k in state): meta = ckpt.get("ood_head_meta", {}) self.ec_image_model.init_ood_head( hidden_dim=meta.get("hidden_dim", 64), extra_input_dim=meta.get("extra_input_dim", 0), use_nll=meta.get("use_nll", False), use_posterior_width=meta.get("use_posterior_width", False), ) missing, unexpected = self.ec_image_model.load_state_dict( state, strict=False) suspicious_missing = [k for k in missing if not k.endswith("_initialized")] suspicious_unexpected = [k for k in unexpected if not k.endswith("_initialized")] if suspicious_missing or suspicious_unexpected: print("[SPARKPredictor] WARNING: state_dict mismatch on EC image ckpt.") if suspicious_missing: print(f" missing ({len(suspicious_missing)}): " f"{suspicious_missing[:6]}{' ...' if len(suspicious_missing) > 6 else ''}") if suspicious_unexpected: print(f" unexpected({len(suspicious_unexpected)}): " f"{suspicious_unexpected[:6]}{' ...' if len(suspicious_unexpected) > 6 else ''}") _fix_actnorm_initialized(self.ec_image_model) self.ec_image_model.to(self.device).eval() def _load_tpd_image(self, ckpt_path): ckpt_path, ckpt, args = self._build_image_model_state( ckpt_path, MultiMechanismFlowTPD, is_tpd=True, ) ckpt_mech_list = args.get("mechanism_list") if ckpt_mech_list is not None: new_list = list(ckpt_mech_list) _tpd_gen.TPD_MECHANISM_LIST = new_list _tpd_gen.TPD_MECHANISM_TO_ID = {m: i for i, m in enumerate(new_list)} _tpd_mod.TPD_MECHANISM_LIST = new_list try: import dataset_tpd as _ds_tpd _ds_tpd.TPD_MECHANISM_LIST = new_list except ImportError: pass self.tpd_image_mech_list = list(_tpd_gen.TPD_MECHANISM_LIST) self.tpd_image_model = MultiMechanismFlowTPD( d_context=args.get("d_context", 128), d_model=args.get("d_model", 128), n_coupling_layers=args.get("n_coupling_layers", 6), hidden_dim=args.get("hidden_dim", 96), coupling_type=args.get("coupling_type", "spline"), n_bins=args.get("n_bins", 8), tail_bound=args.get("tail_bound", 5.0), use_summary_features=False, use_bounded_flow=args.get("use_bounded_flow", False), mechanism_list=self.tpd_image_mech_list, input_mode="image", image_in_channels=args.get("image_in_channels", 1), ) state = ckpt["model_state_dict"] if any(k.startswith("ood_head.") for k in state): meta = ckpt.get("ood_head_meta", {}) self.tpd_image_model.init_ood_head( hidden_dim=meta.get("hidden_dim", 64), extra_input_dim=meta.get("extra_input_dim", 0), ) missing, unexpected = self.tpd_image_model.load_state_dict( state, strict=False) suspicious_missing = [k for k in missing if not k.endswith("_initialized")] suspicious_unexpected = [k for k in unexpected if not k.endswith("_initialized")] if suspicious_missing or suspicious_unexpected: print("[SPARKPredictor] WARNING: state_dict mismatch on TPD image ckpt.") if suspicious_missing: print(f" missing ({len(suspicious_missing)}): " f"{suspicious_missing[:6]}{' ...' if len(suspicious_missing) > 6 else ''}") if suspicious_unexpected: print(f" unexpected({len(suspicious_unexpected)}): " f"{suspicious_unexpected[:6]}{' ...' if len(suspicious_unexpected) > 6 else ''}") _fix_actnorm_initialized(self.tpd_image_model) self.tpd_image_model.to(self.device).eval() # ===================================================================== # Phase-2 joint (image + waveform) loaders # ===================================================================== def _load_ec_joint(self, ckpt_path): ckpt_path, ckpt, args = self._build_image_model_state( ckpt_path, MultiMechanismFlow, expected_input_mode="image+waveform", ) if args.get("mechanism_list") is not None: new_list = list(args["mechanism_list"]) _fm_module.MECHANISM_LIST = new_list _mm_module.MECHANISM_LIST = new_list self.ec_joint_mech_list = list(_fm_module.MECHANISM_LIST) self.ec_joint_model = MultiMechanismFlow( d_context=args.get("d_context", 128), d_model=args.get("d_model", 128), n_coupling_layers=args.get("n_coupling_layers", 6), hidden_dim=args.get("hidden_dim", 96), coupling_type=args.get("coupling_type", "spline"), n_bins=args.get("n_bins", 8), tail_bound=args.get("tail_bound", 5.0), aggregation=args.get("aggregation", "set_transformer"), use_summary_features=False, input_mode="image+waveform", image_in_channels=args.get("image_in_channels", 1), ) state = ckpt["model_state_dict"] if any(k.startswith("ood_head.") for k in state): meta = ckpt.get("ood_head_meta", {}) self.ec_joint_model.init_ood_head( hidden_dim=meta.get("hidden_dim", 64), extra_input_dim=meta.get("extra_input_dim", 0), use_nll=meta.get("use_nll", False), use_posterior_width=meta.get("use_posterior_width", False), ) missing, unexpected = self.ec_joint_model.load_state_dict( state, strict=False) suspicious_missing = [k for k in missing if not k.endswith("_initialized")] suspicious_unexpected = [k for k in unexpected if not k.endswith("_initialized")] if suspicious_missing or suspicious_unexpected: print("[SPARKPredictor] WARNING: state_dict mismatch on EC joint ckpt.") if suspicious_missing: print(f" missing ({len(suspicious_missing)}): " f"{suspicious_missing[:6]}{' ...' if len(suspicious_missing) > 6 else ''}") if suspicious_unexpected: print(f" unexpected({len(suspicious_unexpected)}): " f"{suspicious_unexpected[:6]}{' ...' if len(suspicious_unexpected) > 6 else ''}") _fix_actnorm_initialized(self.ec_joint_model) self.ec_joint_model.to(self.device).eval() def _load_tpd_joint(self, ckpt_path): ckpt_path, ckpt, args = self._build_image_model_state( ckpt_path, MultiMechanismFlowTPD, is_tpd=True, expected_input_mode="image+waveform", ) ckpt_mech_list = args.get("mechanism_list") if ckpt_mech_list is not None: new_list = list(ckpt_mech_list) _tpd_gen.TPD_MECHANISM_LIST = new_list _tpd_gen.TPD_MECHANISM_TO_ID = {m: i for i, m in enumerate(new_list)} _tpd_mod.TPD_MECHANISM_LIST = new_list try: import dataset_tpd as _ds_tpd _ds_tpd.TPD_MECHANISM_LIST = new_list except ImportError: pass self.tpd_joint_mech_list = list(_tpd_gen.TPD_MECHANISM_LIST) self.tpd_joint_model = MultiMechanismFlowTPD( d_context=args.get("d_context", 128), d_model=args.get("d_model", 128), n_coupling_layers=args.get("n_coupling_layers", 6), hidden_dim=args.get("hidden_dim", 96), coupling_type=args.get("coupling_type", "spline"), n_bins=args.get("n_bins", 8), tail_bound=args.get("tail_bound", 5.0), use_summary_features=False, use_bounded_flow=args.get("use_bounded_flow", False), mechanism_list=self.tpd_joint_mech_list, input_mode="image+waveform", image_in_channels=args.get("image_in_channels", 1), ) state = ckpt["model_state_dict"] if any(k.startswith("ood_head.") for k in state): meta = ckpt.get("ood_head_meta", {}) self.tpd_joint_model.init_ood_head( hidden_dim=meta.get("hidden_dim", 64), extra_input_dim=meta.get("extra_input_dim", 0), ) missing, unexpected = self.tpd_joint_model.load_state_dict( state, strict=False) suspicious_missing = [k for k in missing if not k.endswith("_initialized")] suspicious_unexpected = [k for k in unexpected if not k.endswith("_initialized")] if suspicious_missing or suspicious_unexpected: print("[SPARKPredictor] WARNING: state_dict mismatch on TPD joint ckpt.") if suspicious_missing: print(f" missing ({len(suspicious_missing)}): " f"{suspicious_missing[:6]}{' ...' if len(suspicious_missing) > 6 else ''}") if suspicious_unexpected: print(f" unexpected({len(suspicious_unexpected)}): " f"{suspicious_unexpected[:6]}{' ...' if len(suspicious_unexpected) > 6 else ''}") _fix_actnorm_initialized(self.tpd_joint_model) self.tpd_joint_model.to(self.device).eval() @staticmethod def _pil_to_grayscale_tensor(pil_image, target_size=224): """Convert a PIL image to a [1, target_size, target_size] float tensor in [0, 1].""" from PIL import Image as PILImage img = pil_image.convert("L") if img.size != (target_size, target_size): img = img.resize((target_size, target_size), PILImage.BILINEAR) arr = np.asarray(img, dtype=np.float32) / 255.0 return torch.from_numpy(arr).unsqueeze(0) # [1, H, W] def _build_image_input(self, pil_images, sigmas, flux_scales, target_size=224): """Build image-mode model input from a list of PIL images. Args: pil_images: list of PIL.Image objects (one per scan rate / heating rate). sigmas: 1-D array of raw scan rates (V/s) or heating rates (K/s). flux_scales: 1-D array of log10(peak |signal|) per scan/curve. If None, set to zero (model still works since flux_scales is additive conditioning that the network learned to handle). target_size: image edge length expected by the encoder. """ n = len(pil_images) imgs = torch.stack( [self._pil_to_grayscale_tensor(p, target_size) for p in pil_images] ) # [N, 1, H, W] x = imgs.unsqueeze(0).to(self.device) # [1, N, 1, H, W] scan_mask = torch.ones(1, n, dtype=torch.bool, device=self.device) sigmas_log = np.log10(np.clip(np.asarray(sigmas, dtype=np.float32), 1e-10, None)) sigmas_t = torch.from_numpy(sigmas_log).unsqueeze(0).to(self.device) if flux_scales is None: fs_t = torch.zeros(1, n, dtype=torch.float32, device=self.device) else: fs_t = torch.from_numpy( np.asarray(flux_scales, dtype=np.float32) ).unsqueeze(0).to(self.device) return { "input": x, "scan_mask": scan_mask, "sigmas": sigmas_t, "flux_scales": fs_t, } @torch.no_grad() def predict_ec_image(self, pil_images, sigmas, flux_scales=None, n_samples=500, temperature=1.0): """Run image-mode CV inference. `pil_images` length should match `sigmas`. Returns the same dict shape as `predict_ec`. """ if self.ec_image_model is None: raise RuntimeError("EC image model not loaded") if len(pil_images) != len(sigmas): raise ValueError( f"#images ({len(pil_images)}) must match #sigmas ({len(sigmas)})" ) tensors = self._build_image_input(pil_images, sigmas, flux_scales) pred = self.ec_image_model.predict( tensors["input"], scan_mask=tensors["scan_mask"], sigmas=tensors["sigmas"], flux_scales=tensors["flux_scales"], n_samples=n_samples, temperature=temperature, ) return self._format_ec_pred(pred, self.ec_image_mech_list) def _build_joint_input_ec(self, pil_images, sigmas, potentials, fluxes, times): """Build the joint encoder's input dict for CV. Combines `_build_image_input` (image branch) with the same normalization/resampling pipeline `_prepare_ec_tensor` uses for the waveform branch, then fuses them into the dict shape the joint encoder expects. """ n_scans = len(pil_images) if len(sigmas) != n_scans or len(potentials) != n_scans \ or len(fluxes) != n_scans: raise ValueError( "predict_ec_joint: pil_images, sigmas, potentials, fluxes " "must all have the same length" ) # Image branch: [1, N, 1, H, W] imgs = torch.stack( [self._pil_to_grayscale_tensor(p, target_size=224) for p in pil_images] ).unsqueeze(0).to(self.device) scan_mask_image = torch.ones(1, n_scans, dtype=torch.bool, device=self.device) # Waveform branch: re-use the existing tensor builder. wf_tensors = self._prepare_ec_tensor(potentials, fluxes, times, sigmas) x = { "image": imgs, "waveform": wf_tensors["input"], "scan_mask_image": scan_mask_image, "scan_mask_waveform": wf_tensors["scan_mask"], } return x, wf_tensors["sigmas"], wf_tensors["flux_scales"] def _build_joint_input_tpd(self, pil_images, betas, temperatures, rates): n_rates = len(pil_images) if len(betas) != n_rates or len(temperatures) != n_rates \ or len(rates) != n_rates: raise ValueError( "predict_tpd_joint: pil_images, betas, temperatures, rates " "must all have the same length" ) imgs = torch.stack( [self._pil_to_grayscale_tensor(p, target_size=224) for p in pil_images] ).unsqueeze(0).to(self.device) scan_mask_image = torch.ones(1, n_rates, dtype=torch.bool, device=self.device) wf_tensors = self._prepare_tpd_tensor(temperatures, rates, betas) x = { "image": imgs, "waveform": wf_tensors["input"], "scan_mask_image": scan_mask_image, "scan_mask_waveform": wf_tensors["scan_mask"], } return x, wf_tensors["sigmas"], wf_tensors["flux_scales"] @torch.no_grad() def predict_ec_joint(self, pil_images, sigmas, potentials, fluxes, times=None, n_samples=500, temperature=1.0): """Run Phase-2 joint CV inference: image + waveform together.""" if self.ec_joint_model is None: raise RuntimeError("EC joint model not loaded") x, sigmas_t, flux_scales_t = self._build_joint_input_ec( pil_images, sigmas, potentials, fluxes, times, ) pred = self.ec_joint_model.predict( x, scan_mask=None, sigmas=sigmas_t, flux_scales=flux_scales_t, n_samples=n_samples, temperature=temperature, ) return self._format_ec_pred(pred, self.ec_joint_mech_list) @torch.no_grad() def predict_tpd_joint(self, pil_images, betas, temperatures, rates, n_samples=500, temperature=1.0): """Run Phase-2 joint TPD inference: image + waveform together.""" if self.tpd_joint_model is None: raise RuntimeError("TPD joint model not loaded") x, sigmas_t, flux_scales_t = self._build_joint_input_tpd( pil_images, betas, temperatures, rates, ) pred = self.tpd_joint_model.predict( x, scan_mask=None, sigmas=sigmas_t, flux_scales=flux_scales_t, n_samples=n_samples, temperature=temperature, ) return self._format_tpd_pred(pred, self.tpd_joint_mech_list) @torch.no_grad() def predict_tpd_image(self, pil_images, betas, flux_scales=None, n_samples=500, temperature=1.0): """Run image-mode TPD inference. Returns the same shape as predict_tpd.""" if self.tpd_image_model is None: raise RuntimeError("TPD image model not loaded") if len(pil_images) != len(betas): raise ValueError( f"#images ({len(pil_images)}) must match #betas ({len(betas)})" ) tensors = self._build_image_input(pil_images, betas, flux_scales) pred = self.tpd_image_model.predict( tensors["input"], scan_mask=tensors["scan_mask"], sigmas=tensors["sigmas"], flux_scales=tensors["flux_scales"], n_samples=n_samples, temperature=temperature, ) return self._format_tpd_pred(pred, self.tpd_image_mech_list) def _format_ec_pred(self, pred, mech_list): probs = pred["mechanism_probs"][0].cpu().numpy() pred_idx = int(pred["mechanism_pred"][0].cpu().item()) pred_mech = mech_list[pred_idx] param_stats = {} samples_dict = {} for mech in mech_list: if pred["samples"][mech] is not None: s = pred["samples"][mech][0].cpu().numpy() samples_dict[mech] = s param_stats[mech] = { "names": MECHANISM_PARAMS[mech]["names"], "mean": s.mean(axis=0).tolist(), "std": s.std(axis=0).tolist(), "median": np.median(s, axis=0).tolist(), "q05": np.quantile(s, 0.05, axis=0).tolist(), "q95": np.quantile(s, 0.95, axis=0).tolist(), } ood_score = pred.get("ood_score") ood_score_val = (float(ood_score[0].cpu().item()) if ood_score is not None else None) return { "domain": "ec", "mechanism_probs": {m: float(probs[i]) for i, m in enumerate(mech_list)}, "mechanism_names": mech_list, "predicted_mechanism": pred_mech, "predicted_mechanism_idx": pred_idx, "parameter_stats": param_stats, "posterior_samples": samples_dict, "ood_score": ood_score_val, } def _format_tpd_pred(self, pred, mech_list): probs = pred["mechanism_probs"][0].cpu().numpy() pred_idx = int(pred["mechanism_pred"][0].cpu().item()) pred_mech = mech_list[pred_idx] param_stats = {} samples_dict = {} for mech in mech_list: if pred["samples"][mech] is not None: s = pred["samples"][mech][0].cpu().numpy() samples_dict[mech] = s param_stats[mech] = { "names": TPD_MECHANISM_PARAMS[mech]["names"], "mean": s.mean(axis=0).tolist(), "std": s.std(axis=0).tolist(), "median": np.median(s, axis=0).tolist(), "q05": np.quantile(s, 0.05, axis=0).tolist(), "q95": np.quantile(s, 0.95, axis=0).tolist(), } ood_score = pred.get("ood_score") ood_score_val = (float(ood_score[0].cpu().item()) if ood_score is not None else None) return { "domain": "tpd", "mechanism_probs": {m: float(probs[i]) for i, m in enumerate(mech_list)}, "mechanism_names": mech_list, "predicted_mechanism": pred_mech, "predicted_mechanism_idx": pred_idx, "parameter_stats": param_stats, "posterior_samples": samples_dict, "ood_score": ood_score_val, } # ===================================================================== # Hybrid predictors: run image-mode + waveform-mode in parallel and # combine, with selectable strategies. # ===================================================================== @staticmethod def _ensemble_results(image_result, waveform_result, param_names_lookup): """Combine image-mode + waveform-mode predictions. - mechanism_probs: arithmetic mean of the two per-mech dicts. - predicted_mechanism: argmax of the ensembled probs. - posterior_samples: per-mech, concatenate the samples from each model (when both produced samples), then recompute parameter_stats. - ood_score: max of the two scores (more conservative; image-mode's OOD means P(ID), so 'higher' is safer; we take min instead — see below). We surface the IMAGE-mode OOD score as the headline OOD, but if the waveform model also has one we take the lower of the two so the banner is shown when *either* is concerned. `param_names_lookup` is a dict mech -> list of parameter names; it provides 'names' for the recomputed parameter_stats. """ if image_result is None: return waveform_result if waveform_result is None: return image_result mech_list = image_result["mechanism_names"] probs_a = image_result["mechanism_probs"] probs_b = waveform_result["mechanism_probs"] ensemble_probs = { m: 0.5 * (probs_a.get(m, 0.0) + probs_b.get(m, 0.0)) for m in mech_list } s = sum(ensemble_probs.values()) if s > 0: ensemble_probs = {m: v / s for m, v in ensemble_probs.items()} sorted_probs = sorted(ensemble_probs.items(), key=lambda kv: -kv[1]) top_mech, _ = sorted_probs[0] top_idx = mech_list.index(top_mech) samples_dict = {} param_stats = {} for mech in mech_list: sa = image_result["posterior_samples"].get(mech) sb = waveform_result["posterior_samples"].get(mech) if sa is not None and sb is not None: combined = np.concatenate([sa, sb], axis=0) elif sa is not None: combined = sa elif sb is not None: combined = sb else: combined = None if combined is None: continue samples_dict[mech] = combined names = param_names_lookup.get(mech, [f"p{i}" for i in range(combined.shape[-1])]) param_stats[mech] = { "names": names, "mean": combined.mean(axis=0).tolist(), "std": combined.std(axis=0).tolist(), "median": np.median(combined, axis=0).tolist(), "q05": np.quantile(combined, 0.05, axis=0).tolist(), "q95": np.quantile(combined, 0.95, axis=0).tolist(), } ood_a = image_result.get("ood_score") ood_b = waveform_result.get("ood_score") ood_vals = [v for v in (ood_a, ood_b) if v is not None] ood_ensemble = min(ood_vals) if ood_vals else None # P(ID); lower=more concerning return { "domain": image_result["domain"], "mechanism_probs": ensemble_probs, "mechanism_names": mech_list, "predicted_mechanism": top_mech, "predicted_mechanism_idx": top_idx, "parameter_stats": param_stats, "posterior_samples": samples_dict, "ood_score": ood_ensemble, "_ensemble": True, } @staticmethod def _agreement_stats(image_result, waveform_result): if image_result is None or waveform_result is None: return { "available": False, "top_mech_match": None, "top_prob_image": None, "top_prob_waveform": None, } return { "available": True, "top_mech_match": ( image_result["predicted_mechanism"] == waveform_result["predicted_mechanism"] ), "top_prob_image": float(max(image_result["mechanism_probs"].values())), "top_prob_waveform": float(max(waveform_result["mechanism_probs"].values())), "image_mech": image_result["predicted_mechanism"], "waveform_mech": waveform_result["predicted_mechanism"], } def _predict_ec_two_paths( self, pil_images, sigmas, potentials, fluxes, times=None, n_samples=500, temperature=1.0, do_preprocess=True, ): """Run image-mode + waveform-mode + joint CV inference (whichever are available). Returns (image_result, waveform_result, joint_result, preprocessing_meta). Any result may be None if the corresponding model is missing or the input arrays were not provided. """ image_result = None waveform_result = None joint_result = None preproc_meta = [] # Preprocess images once; both image-only and joint paths reuse them. preprocessed = None if pil_images and (self.has_ec_image_model or self.has_ec_joint_model): if do_preprocess: from image_preprocessing import prepare_for_image_mode preprocessed = [] for p in pil_images: out, meta = prepare_for_image_mode(p) preprocessed.append(out) preproc_meta.append(meta) else: preprocessed = list(pil_images) preproc_meta = [{} for _ in pil_images] if self.has_ec_image_model and preprocessed is not None: flux_scales = None if fluxes is not None: flux_scales = [ float(np.log10(np.max(np.abs(np.asarray(f))) + 1e-30)) for f in fluxes ] try: image_result = self.predict_ec_image( preprocessed, sigmas, flux_scales=flux_scales, n_samples=n_samples, temperature=temperature, ) except Exception as exc: print(f"[SPARKPredictor] image-mode CV failed: {exc}") if self.ec_model is not None and potentials is not None and fluxes is not None: try: waveform_result = self.predict_ec( potentials, fluxes, sigmas, times=times, n_samples=n_samples, temperature=temperature, ) except Exception as exc: print(f"[SPARKPredictor] waveform CV failed: {exc}") if (self.has_ec_joint_model and preprocessed is not None and potentials is not None and fluxes is not None): try: joint_result = self.predict_ec_joint( preprocessed, sigmas, potentials, fluxes, times=times, n_samples=n_samples, temperature=temperature, ) except Exception as exc: print(f"[SPARKPredictor] joint CV failed: {exc}") return image_result, waveform_result, joint_result, preproc_meta def predict_ec_hybrid( self, pil_images, sigmas, potentials=None, fluxes=None, times=None, n_samples=500, temperature=1.0, mode="ensemble", do_preprocess=True, ood_fallback_threshold=0.3, ): """Hybrid CV inference combining image-mode and waveform-mode. Args: pil_images: list of PIL.Image (one per scan rate). Required for image-mode; if missing, image-mode is skipped. sigmas: list of dimensionless scan rates. potentials, fluxes: dimensionless waveforms (one list of arrays per scan rate). Required for waveform-mode; if missing, waveform-mode is skipped. times: optional dimensionless time arrays. mode: one of 'ensemble', 'image_only', 'digitize_only', 'auto_fallback'. do_preprocess: whether to run prepare_for_image_mode on inputs before image-mode (default True). ood_fallback_threshold: if mode='auto_fallback', image-mode's OOD score (P(ID)) below this threshold triggers fallback to waveform-mode. Returns dict with: headline: result dict to display (same shape as predict_ec). image_mode: image-mode result or None. waveform_mode: waveform result or None. preprocessing_meta: per-image preprocessing metadata list. agreement: agreement_stats(image, waveform) dict. method_used: human-readable string of which method produced 'headline'. """ (image_result, waveform_result, joint_result, preproc_meta) = ( self._predict_ec_two_paths( pil_images, sigmas, potentials, fluxes, times, n_samples=n_samples, temperature=temperature, do_preprocess=do_preprocess, ) ) agreement = self._agreement_stats(image_result, waveform_result) param_names_lookup = {m: MECHANISM_PARAMS[m]["names"] for m in MECHANISM_PARAMS} headline, method_used = self._select_headline( mode, image_result, waveform_result, param_names_lookup, ood_fallback_threshold, joint_result=joint_result, ) return { "headline": headline, "image_mode": image_result, "waveform_mode": waveform_result, "joint_mode": joint_result, "preprocessing_meta": preproc_meta, "agreement": agreement, "method_used": method_used, } def _predict_tpd_two_paths( self, pil_images, betas, temperatures, rates, n_samples=500, temperature=1.0, do_preprocess=True, ): image_result = None waveform_result = None joint_result = None preproc_meta = [] preprocessed = None if pil_images and (self.has_tpd_image_model or self.has_tpd_joint_model): if do_preprocess: from image_preprocessing import prepare_for_image_mode preprocessed = [] for p in pil_images: out, meta = prepare_for_image_mode(p) preprocessed.append(out) preproc_meta.append(meta) else: preprocessed = list(pil_images) preproc_meta = [{} for _ in pil_images] if self.has_tpd_image_model and preprocessed is not None: flux_scales = None if rates is not None: flux_scales = [ float(np.log10(np.max(np.abs(np.asarray(r))) + 1e-30)) for r in rates ] try: image_result = self.predict_tpd_image( preprocessed, betas, flux_scales=flux_scales, n_samples=n_samples, temperature=temperature, ) except Exception as exc: print(f"[SPARKPredictor] image-mode TPD failed: {exc}") if self.tpd_model is not None and temperatures is not None and rates is not None: try: waveform_result = self.predict_tpd( temperatures, rates, betas, n_samples=n_samples, temperature=temperature, ) except Exception as exc: print(f"[SPARKPredictor] waveform TPD failed: {exc}") if (self.has_tpd_joint_model and preprocessed is not None and temperatures is not None and rates is not None): try: joint_result = self.predict_tpd_joint( preprocessed, betas, temperatures, rates, n_samples=n_samples, temperature=temperature, ) except Exception as exc: print(f"[SPARKPredictor] joint TPD failed: {exc}") return image_result, waveform_result, joint_result, preproc_meta def predict_tpd_hybrid( self, pil_images, betas, temperatures=None, rates=None, n_samples=500, temperature=1.0, mode="ensemble", do_preprocess=True, ood_fallback_threshold=0.3, ): """Hybrid TPD inference. Same shape as predict_ec_hybrid.""" (image_result, waveform_result, joint_result, preproc_meta) = ( self._predict_tpd_two_paths( pil_images, betas, temperatures, rates, n_samples=n_samples, temperature=temperature, do_preprocess=do_preprocess, ) ) agreement = self._agreement_stats(image_result, waveform_result) param_names_lookup = {m: TPD_MECHANISM_PARAMS[m]["names"] for m in TPD_MECHANISM_PARAMS} headline, method_used = self._select_headline( mode, image_result, waveform_result, param_names_lookup, ood_fallback_threshold, joint_result=joint_result, ) return { "headline": headline, "image_mode": image_result, "waveform_mode": waveform_result, "joint_mode": joint_result, "preprocessing_meta": preproc_meta, "agreement": agreement, "method_used": method_used, } def _select_headline(self, mode, image_result, waveform_result, param_names_lookup, ood_fallback_threshold, joint_result=None): """Pick the headline result based on `mode`. Falls back gracefully when one of the available paths is unavailable. Supported modes: ``ensemble``, ``image_only``, ``digitize_only``, ``auto_fallback``, ``joint``. """ if (image_result is None and waveform_result is None and joint_result is None): raise RuntimeError( "Hybrid inference: image, waveform, and joint paths all " "failed or were unavailable." ) if mode == "joint": if joint_result is not None: return joint_result, "joint image+waveform (Phase 2)" if image_result is not None and waveform_result is not None: return ( self._ensemble_results( image_result, waveform_result, param_names_lookup), "ensemble (joint unavailable)", ) if image_result is not None: return image_result, "image-direct (joint unavailable)" return waveform_result, "digitize-then-infer (joint unavailable)" if mode == "image_only": if image_result is not None: return image_result, "image-direct" return waveform_result, "digitize-then-infer (image fallback)" if mode == "digitize_only": if waveform_result is not None: return waveform_result, "digitize-then-infer" return image_result, "image-direct (waveform fallback)" if mode == "auto_fallback": if image_result is None: return waveform_result, "digitize-then-infer (no image model)" ood = image_result.get("ood_score") if (ood is not None and ood < ood_fallback_threshold and waveform_result is not None): return waveform_result, ( f"digitize-then-infer (image OOD score " f"{ood:.2f} < {ood_fallback_threshold:.2f})" ) return image_result, "image-direct" # ensemble (default): prefer joint when present, else mix image+wave if joint_result is not None: return joint_result, "joint image+waveform (Phase 2)" if image_result is not None and waveform_result is not None: return ( self._ensemble_results( image_result, waveform_result, param_names_lookup), "ensemble (image + digitize)", ) if image_result is not None: return image_result, "image-direct (waveform unavailable)" return waveform_result, "digitize-then-infer (image unavailable)" def _prepare_ec_tensor(self, potentials, fluxes, times, sigmas): """ Build model input tensor from preprocessed dimensionless CV data. Args: potentials: list of 1-D arrays (dimensionless theta) fluxes: list of 1-D arrays (dimensionless flux) times: list of 1-D arrays (dimensionless time) or None sigmas: 1-D array of dimensionless scan rates Returns: dict of tensors ready for model.predict() """ from scipy.interpolate import interp1d n_scans = len(potentials) T_target = 672 pot_resampled = [] flux_resampled = [] time_resampled = [] flux_scales = [] for i in range(n_scans): pot = np.asarray(potentials[i], dtype=np.float32) flx = np.asarray(fluxes[i], dtype=np.float32) if times is not None and times[i] is not None: tim = np.asarray(times[i], dtype=np.float32) else: theta_range = pot.max() - pot.min() sigma = sigmas[i] total_time = 2.0 * theta_range / sigma tim = np.linspace(0, total_time, len(pot), dtype=np.float32) peak = np.max(np.abs(flx)) + 1e-30 flux_scales.append(np.log10(peak)) flx = flx / peak t_uniform = np.linspace(tim[0], tim[-1], T_target) pot_resampled.append( interp1d(tim, pot, kind="linear", fill_value="extrapolate")(t_uniform) ) flux_resampled.append( interp1d(tim, flx, kind="linear", fill_value="extrapolate")(t_uniform) ) time_resampled.append(t_uniform) pot_arr = np.stack(pot_resampled).astype(np.float32) flx_arr = np.stack(flux_resampled).astype(np.float32) tim_arr = np.stack(time_resampled).astype(np.float32) ns = self.ec_norm_stats if ns: pot_arr = (pot_arr - ns["potential"][0]) / ns["potential"][1] flx_arr = (flx_arr - ns["flux"][0]) / ns["flux"][1] tim_arr = (tim_arr - ns["time"][0]) / ns["time"][1] # [1, N, 3, T] waveforms = np.stack([pot_arr, flx_arr, tim_arr], axis=1) x = torch.from_numpy(waveforms).unsqueeze(0).to(self.device) scan_mask = torch.ones(1, n_scans, T_target, dtype=torch.bool, device=self.device) sigmas_t = torch.from_numpy( np.log10(np.asarray(sigmas, dtype=np.float32)) ).unsqueeze(0).to(self.device) flux_scales_t = torch.from_numpy( np.asarray(flux_scales, dtype=np.float32) ).unsqueeze(0).to(self.device) return { "input": x, "scan_mask": scan_mask, "sigmas": sigmas_t, "flux_scales": flux_scales_t, } def _prepare_tpd_tensor(self, temperatures, rates, betas): """ Build model input tensor from TPD data. Args: temperatures: list of 1-D arrays (K) rates: list of 1-D arrays (arb. units) betas: 1-D array of heating rates (K/s) Returns: dict of tensors ready for model.predict() """ from scipy.interpolate import interp1d n_rates = len(temperatures) T_target = 500 temp_resampled = [] rate_resampled = [] for i in range(n_rates): temp = np.asarray(temperatures[i], dtype=np.float32) rate = np.asarray(rates[i], dtype=np.float32) t_uniform = np.linspace(temp[0], temp[-1], T_target) temp_resampled.append(t_uniform) rate_resampled.append( interp1d(temp, rate, kind="linear", fill_value="extrapolate")(t_uniform) ) temp_arr = np.stack(temp_resampled).astype(np.float32) rate_arr = np.stack(rate_resampled).astype(np.float32) summary_t = None if getattr(self, 'tpd_use_summary', False): from preprocessing import extract_tpd_summary_stats hr_arr = np.asarray(betas, dtype=np.float32) lengths = np.full(n_rates, T_target, dtype=np.int32) summary = extract_tpd_summary_stats( temp_arr, rate_arr, lengths, hr_arr, n_rates) summary_t = torch.from_numpy(summary).unsqueeze(0).to(self.device) rate_scales = [] for i in range(n_rates): peak = np.max(np.abs(rate_arr[i])) + 1e-30 rate_scales.append(np.log10(peak)) rate_arr[i] /= peak ns = self.tpd_norm_stats if ns: temp_arr = (temp_arr - ns["temperature"][0]) / ns["temperature"][1] rate_arr = (rate_arr - ns["rate"][0]) / ns["rate"][1] # [1, N, 2, T] waveforms = np.stack([temp_arr, rate_arr], axis=1) x = torch.from_numpy(waveforms).unsqueeze(0).to(self.device) scan_mask = torch.ones(1, n_rates, T_target, dtype=torch.bool, device=self.device) sigmas_t = torch.from_numpy( np.log10(np.asarray(betas, dtype=np.float32)) ).unsqueeze(0).to(self.device) rate_scales_t = torch.from_numpy( np.asarray(rate_scales, dtype=np.float32) ).unsqueeze(0).to(self.device) result = { "input": x, "scan_mask": scan_mask, "sigmas": sigmas_t, "flux_scales": rate_scales_t, } if summary_t is not None: result["summary"] = summary_t return result @torch.no_grad() def predict_ec(self, potentials, fluxes, sigmas, times=None, n_samples=500, temperature=1.0): """ Run EC inference on dimensionless CV data. Args: potentials: list of 1-D arrays (dimensionless theta per scan rate) fluxes: list of 1-D arrays (dimensionless flux per scan rate) sigmas: list/array of dimensionless scan rates times: optional list of 1-D time arrays n_samples: posterior samples to draw temperature: sampling temperature (>1 broadens posteriors) Returns: dict with mechanism_probs, mechanism_names, predicted_mechanism, parameter_stats (per mechanism), posterior_samples (per mechanism) """ if self.ec_model is None: raise RuntimeError("EC model not loaded") tensors = self._prepare_ec_tensor(potentials, fluxes, times, sigmas) pred = self.ec_model.predict( tensors["input"], scan_mask=tensors["scan_mask"], sigmas=tensors["sigmas"], flux_scales=tensors["flux_scales"], n_samples=n_samples, temperature=temperature, ) mech_list = self.ec_mechanism_list probs = pred["mechanism_probs"][0].cpu().numpy() pred_idx = int(pred["mechanism_pred"][0].cpu().item()) pred_mech = mech_list[pred_idx] param_stats = {} samples_dict = {} for mech in mech_list: if pred["samples"][mech] is not None: s = pred["samples"][mech][0].cpu().numpy() # [n_samples, D] samples_dict[mech] = s param_stats[mech] = { "names": MECHANISM_PARAMS[mech]["names"], "mean": s.mean(axis=0).tolist(), "std": s.std(axis=0).tolist(), "median": np.median(s, axis=0).tolist(), "q05": np.quantile(s, 0.05, axis=0).tolist(), "q95": np.quantile(s, 0.95, axis=0).tolist(), } ood_score = pred.get("ood_score") ood_score_val = (float(ood_score[0].cpu().item()) if ood_score is not None else None) return { "domain": "ec", "mechanism_probs": {m: float(probs[i]) for i, m in enumerate(mech_list)}, "mechanism_names": mech_list, "predicted_mechanism": pred_mech, "predicted_mechanism_idx": pred_idx, "parameter_stats": param_stats, "posterior_samples": samples_dict, "ood_score": ood_score_val, } @torch.no_grad() def predict_tpd(self, temperatures, rates, betas, n_samples=500, temperature=1.0): """ Run TPD inference. Args: temperatures: list of 1-D arrays (K per heating rate) rates: list of 1-D arrays (signal per heating rate) betas: list/array of heating rates (K/s) n_samples: posterior samples to draw temperature: sampling temperature Returns: dict with mechanism_probs, parameter_stats, posterior_samples """ if self.tpd_model is None: raise RuntimeError("TPD model not loaded") tensors = self._prepare_tpd_tensor(temperatures, rates, betas) pred = self.tpd_model.predict( tensors["input"], scan_mask=tensors["scan_mask"], sigmas=tensors["sigmas"], flux_scales=tensors["flux_scales"], n_samples=n_samples, temperature=temperature, summary=tensors.get("summary"), ) mech_list = self.tpd_mechanism_list probs = pred["mechanism_probs"][0].cpu().numpy() pred_idx = int(pred["mechanism_pred"][0].cpu().item()) pred_mech = mech_list[pred_idx] param_stats = {} samples_dict = {} for mech in mech_list: if pred["samples"][mech] is not None: s = pred["samples"][mech][0].cpu().numpy() samples_dict[mech] = s param_stats[mech] = { "names": TPD_MECHANISM_PARAMS[mech]["names"], "mean": s.mean(axis=0).tolist(), "std": s.std(axis=0).tolist(), "median": np.median(s, axis=0).tolist(), "q05": np.quantile(s, 0.05, axis=0).tolist(), "q95": np.quantile(s, 0.95, axis=0).tolist(), } ood_score = pred.get("ood_score") ood_score_val = (float(ood_score[0].cpu().item()) if ood_score is not None else None) return { "domain": "tpd", "mechanism_probs": {m: float(probs[i]) for i, m in enumerate(mech_list)}, "mechanism_names": mech_list, "predicted_mechanism": pred_mech, "predicted_mechanism_idx": pred_idx, "parameter_stats": param_stats, "posterior_samples": samples_dict, "ood_score": ood_score_val, } # ===================================================================== # Signal Reconstruction # ===================================================================== def reconstruct_ec(self, result, potentials, fluxes, sigmas, base_params=None, mechanism=None): """ Reconstruct CV signals from inferred posterior median and compute metrics. Args: result: output dict from predict_ec() potentials: list of 1-D arrays (original dimensionless theta) fluxes: list of 1-D arrays (original dimensionless flux) sigmas: list of dimensionless scan rates base_params: dict of fixed simulation params; defaults used if None mechanism: which mechanism to reconstruct (default: predicted) Returns: dict with 'observed', 'reconstructed' curve lists, 'nrmse', 'r2' per scan rate, and 'mean_nrmse', 'mean_r2' """ from evaluate_reconstruction import ( reconstruct_ec_signal, signal_nrmse, signal_r2, ) mech = mechanism or result["predicted_mechanism"] stats = result["parameter_stats"].get(mech) if stats is None: return None theta_point = np.array(stats["median"]) if base_params is None: pot0 = np.asarray(potentials[0]) base_params = { "theta_i": float(pot0.max()), "theta_v": float(pot0.min()), "dA": 1.0, "C_A_bulk": 1.0, "C_B_bulk": 0.0, "kinetics": mech, } try: recon_results = reconstruct_ec_signal( theta_point, mech, base_params, sigmas, n_spatial=64 ) except Exception: return None observed_curves = [] recon_curves = [] conc_curves = [] nrmses = [] r2s = [] for i, (pot, flx, sigma) in enumerate(zip(potentials, fluxes, sigmas)): pot = np.asarray(pot) flx = np.asarray(flx) observed_curves.append({"x": pot, "y": flx}) if i < len(recon_results) and recon_results[i].get("success", False): rec = recon_results[i] rec_pot = np.asarray(rec["potential"]) rec_flx = np.asarray(rec["flux"]) n_obs = len(pot) n_rec = len(rec_pot) t_obs = np.linspace(0, 1, n_obs) t_rec = np.linspace(0, 1, n_rec) rec_flx_interp = np.interp(t_obs, t_rec, rec_flx) recon_curves.append({"x": pot, "y": rec_flx_interp}) nrmse_val = signal_nrmse(flx, rec_flx_interp) r2_val = signal_r2(flx, rec_flx_interp) nrmses.append(nrmse_val) r2s.append(r2_val) if "c_ox_surface" in rec and "c_red_surface" in rec: c_ox_interp = np.interp(t_obs, t_rec, np.asarray(rec["c_ox_surface"])) c_red_interp = np.interp(t_obs, t_rec, np.asarray(rec["c_red_surface"])) conc_curves.append({ "x": pot, "c_ox": c_ox_interp, "c_red": c_red_interp, }) else: conc_curves.append(None) else: recon_curves.append({"x": pot, "y": np.zeros_like(flx)}) nrmses.append(float("nan")) r2s.append(float("nan")) conc_curves.append(None) valid_nrmse = [v for v in nrmses if np.isfinite(v)] valid_r2 = [v for v in r2s if np.isfinite(v)] return { "observed": observed_curves, "reconstructed": recon_curves, "concentrations": conc_curves, "nrmse": nrmses, "r2": r2s, "mean_nrmse": float(np.mean(valid_nrmse)) if valid_nrmse else float("nan"), "mean_r2": float(np.mean(valid_r2)) if valid_r2 else float("nan"), } def reconstruct_tpd(self, result, temperatures, rates, betas, base_params=None, mechanism=None): """ Reconstruct TPD signals from inferred posterior median and compute metrics. Args: result: output dict from predict_tpd() temperatures: list of 1-D arrays (K) rates: list of 1-D arrays (signal) betas: list of heating rates (K/s) base_params: dict of fixed simulation params; defaults used if None mechanism: which mechanism to reconstruct (default: predicted) Returns: dict with 'observed', 'reconstructed' curve lists, 'nrmse', 'r2' per heating rate, and 'mean_nrmse', 'mean_r2' """ from evaluate_reconstruction import ( reconstruct_tpd_signal, signal_nrmse, signal_r2, ) mech = mechanism or result["predicted_mechanism"] stats = result["parameter_stats"].get(mech) if stats is None: return None theta_point = np.array(stats["median"]) if base_params is None: temp0 = np.asarray(temperatures[0]) base_params = { "mechanism": mech, "T_start": float(temp0.min()), "T_end": float(temp0.max()), "n_points": 500, } try: recon_results = reconstruct_tpd_signal( theta_point, mech, base_params, betas ) except Exception: return None observed_curves = [] recon_curves = [] nrmses = [] r2s = [] for i, (temp, rate, beta) in enumerate(zip(temperatures, rates, betas)): temp = np.asarray(temp) rate = np.asarray(rate) observed_curves.append({"x": temp, "y": rate}) if i < len(recon_results) and recon_results[i].get("success", False): rec = recon_results[i] rec_temp = np.asarray(rec["temperature"]) rec_rate = np.asarray(rec["rate"]) rec_rate_interp = np.interp(temp, rec_temp, rec_rate) recon_curves.append({"x": temp, "y": rec_rate_interp}) nrmse_val = signal_nrmse(rate, rec_rate_interp) r2_val = signal_r2(rate, rec_rate_interp) nrmses.append(nrmse_val) r2s.append(r2_val) else: recon_curves.append({"x": temp, "y": np.zeros_like(rate)}) nrmses.append(float("nan")) r2s.append(float("nan")) valid_nrmse = [v for v in nrmses if np.isfinite(v)] valid_r2 = [v for v in r2s if np.isfinite(v)] return { "observed": observed_curves, "reconstructed": recon_curves, "nrmse": nrmses, "r2": r2s, "mean_nrmse": float(np.mean(valid_nrmse)) if valid_nrmse else float("nan"), "mean_r2": float(np.mean(valid_r2)) if valid_r2 else float("nan"), }