| """ |
| SPARK inference engine. |
| |
| Loads trained EC and TPD models and runs end-to-end inference |
| from preprocessed arrays (dimensionless for CV, physical for TPD). |
| """ |
|
|
| import json |
| import sys |
| import os |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
|
|
|
|
|
|
| import flow_model as _fm_module |
| import multi_mechanism_model as _mm_module |
| import tpd_model as _tpd_mod |
| import generate_tpd_data as _tpd_gen |
| from multi_mechanism_model import MultiMechanismFlow |
| from tpd_model import MultiMechanismFlowTPD |
| from flow_model import MECHANISM_PARAMS, ActNorm |
| from generate_tpd_data import TPD_MECHANISM_PARAMS |
|
|
|
|
| def _fix_actnorm_initialized(model): |
| """Mark all ActNorm layers as initialized after loading a checkpoint. |
| |
| Old checkpoints lack the ``_initialized`` buffer, so ``load_state_dict`` |
| leaves it at ``False``. The first forward pass would then overwrite the |
| trained ``log_scale``/``bias`` with data-dependent statistics. |
| """ |
| for module in model.modules(): |
| if isinstance(module, ActNorm) and not module.initialized: |
| module.initialized = True |
|
|
|
|
| class SPARKPredictor: |
| """Unified predictor for both EC (cyclic voltammetry) and TPD domains.""" |
|
|
| def __init__(self, ec_checkpoint=None, tpd_checkpoint=None, device=None, |
| ec_image_checkpoint=None, tpd_image_checkpoint=None, |
| ec_joint_checkpoint=None, tpd_joint_checkpoint=None): |
| if device is None: |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| self.device = device |
|
|
| self.ec_model = None |
| self.ec_norm_stats = None |
| self.tpd_model = None |
| self.tpd_norm_stats = None |
|
|
| |
| |
| self.ec_image_model = None |
| self.ec_image_mech_list = None |
| self.tpd_image_model = None |
| self.tpd_image_mech_list = None |
|
|
| |
| self.ec_joint_model = None |
| self.ec_joint_mech_list = None |
| self.tpd_joint_model = None |
| self.tpd_joint_mech_list = None |
|
|
| if ec_checkpoint is not None: |
| self._load_ec(ec_checkpoint) |
| if tpd_checkpoint is not None: |
| self._load_tpd(tpd_checkpoint) |
| if ec_image_checkpoint is not None: |
| self._load_ec_image(ec_image_checkpoint) |
| if tpd_image_checkpoint is not None: |
| self._load_tpd_image(tpd_image_checkpoint) |
| if ec_joint_checkpoint is not None: |
| self._load_ec_joint(ec_joint_checkpoint) |
| if tpd_joint_checkpoint is not None: |
| self._load_tpd_joint(tpd_joint_checkpoint) |
|
|
| @property |
| def has_ec_image_model(self) -> bool: |
| return self.ec_image_model is not None |
|
|
| @property |
| def has_tpd_image_model(self) -> bool: |
| return self.tpd_image_model is not None |
|
|
| @property |
| def has_ec_joint_model(self) -> bool: |
| return self.ec_joint_model is not None |
|
|
| @property |
| def has_tpd_joint_model(self) -> bool: |
| return self.tpd_joint_model is not None |
|
|
| def _load_ec(self, ckpt_path): |
| ckpt_path = Path(ckpt_path) |
| checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False) |
| args = checkpoint["args"] |
|
|
| |
| |
| |
| |
| |
| |
| if args.get("mechanism_list") is not None: |
| new_list = list(args["mechanism_list"]) |
| _fm_module.MECHANISM_LIST = new_list |
| _mm_module.MECHANISM_LIST = new_list |
| self.ec_mechanism_list = list(_fm_module.MECHANISM_LIST) |
|
|
| self.ec_model = MultiMechanismFlow( |
| d_context=args.get("d_context", 128), |
| d_model=args.get("d_model", 128), |
| n_coupling_layers=args.get("n_coupling_layers", 6), |
| hidden_dim=args.get("hidden_dim", 96), |
| coupling_type=args.get("coupling_type", "spline"), |
| n_bins=args.get("n_bins", 8), |
| tail_bound=args.get("tail_bound", 5.0), |
| aggregation=args.get("aggregation", "set_transformer"), |
| use_summary_features=args.get("use_summary_features", False), |
| ) |
|
|
| |
| |
| state = checkpoint["model_state_dict"] |
| if any(k.startswith("ood_head.") for k in state): |
| meta = checkpoint.get("ood_head_meta", {}) |
| self.ec_model.init_ood_head( |
| hidden_dim=meta.get("hidden_dim", 64), |
| extra_input_dim=meta.get("extra_input_dim", 0), |
| use_nll=meta.get("use_nll", False), |
| use_posterior_width=meta.get("use_posterior_width", False), |
| ) |
|
|
| missing, unexpected = self.ec_model.load_state_dict( |
| state, strict=False) |
| |
| suspicious_missing = [ |
| k for k in missing if not k.endswith("_initialized") |
| ] |
| suspicious_unexpected = [ |
| k for k in unexpected if not k.endswith("_initialized") |
| ] |
| if suspicious_missing or suspicious_unexpected: |
| print(f"[SPARKPredictor] WARNING: state_dict mismatch on EC ckpt.") |
| if suspicious_missing: |
| print(f" missing ({len(suspicious_missing)}): " |
| f"{suspicious_missing[:6]}{' ...' if len(suspicious_missing) > 6 else ''}") |
| if suspicious_unexpected: |
| print(f" unexpected({len(suspicious_unexpected)}): " |
| f"{suspicious_unexpected[:6]}{' ...' if len(suspicious_unexpected) > 6 else ''}") |
| _fix_actnorm_initialized(self.ec_model) |
| self.ec_model.to(self.device).eval() |
|
|
| |
| ckpt_dir = ckpt_path.parent |
| stem = ckpt_path.stem.replace("best", "").rstrip("_") |
| prefix = stem + "_" if stem else "" |
| for search_dir in [ckpt_dir, ckpt_dir.parent]: |
| for name_pattern in [f"{prefix}norm_stats.json", "ec_norm_stats.json", "norm_stats.json"]: |
| p = search_dir / name_pattern |
| if p.exists(): |
| with open(p) as f: |
| self.ec_norm_stats = json.load(f) |
| break |
| if self.ec_norm_stats is not None: |
| break |
| for search_dir in [ckpt_dir, ckpt_dir.parent]: |
| for name_pattern in [f"{prefix}theta_stats.json", "ec_theta_stats.json", "theta_stats.json"]: |
| p = search_dir / name_pattern |
| if p.exists(): |
| with open(p) as f: |
| self.ec_theta_stats = json.load(f) |
| break |
| if hasattr(self, "ec_theta_stats") and self.ec_theta_stats is not None: |
| break |
|
|
| def _load_tpd(self, ckpt_path): |
| ckpt_path = Path(ckpt_path) |
| checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False) |
| args = checkpoint["args"] |
|
|
| self.tpd_use_summary = args.get("use_summary_features", False) |
|
|
| |
| |
| |
| ckpt_mech_list = args.get("mechanism_list") |
| if ckpt_mech_list is not None: |
| new_list = list(ckpt_mech_list) |
| _tpd_gen.TPD_MECHANISM_LIST = new_list |
| _tpd_gen.TPD_MECHANISM_TO_ID = {m: i for i, m in enumerate(new_list)} |
| _tpd_mod.TPD_MECHANISM_LIST = new_list |
| try: |
| import dataset_tpd as _ds_tpd |
| _ds_tpd.TPD_MECHANISM_LIST = new_list |
| except ImportError: |
| pass |
| self.tpd_mechanism_list = list(_tpd_gen.TPD_MECHANISM_LIST) |
|
|
| self.tpd_model = MultiMechanismFlowTPD( |
| d_context=args.get("d_context", 128), |
| d_model=args.get("d_model", 128), |
| n_coupling_layers=args.get("n_coupling_layers", 6), |
| hidden_dim=args.get("hidden_dim", 96), |
| coupling_type=args.get("coupling_type", "spline"), |
| n_bins=args.get("n_bins", 8), |
| tail_bound=args.get("tail_bound", 5.0), |
| use_summary_features=self.tpd_use_summary, |
| use_bounded_flow=args.get("use_bounded_flow", False), |
| mechanism_list=self.tpd_mechanism_list, |
| ) |
|
|
| state = checkpoint["model_state_dict"] |
| if any(k.startswith("ood_head.") for k in state): |
| meta = checkpoint.get("ood_head_meta", {}) |
| self.tpd_model.init_ood_head( |
| hidden_dim=meta.get("hidden_dim", 64), |
| extra_input_dim=meta.get("extra_input_dim", 0), |
| ) |
|
|
| missing, unexpected = self.tpd_model.load_state_dict( |
| state, strict=False) |
| suspicious_missing = [ |
| k for k in missing if not k.endswith("_initialized") |
| ] |
| suspicious_unexpected = [ |
| k for k in unexpected if not k.endswith("_initialized") |
| ] |
| if suspicious_missing or suspicious_unexpected: |
| print(f"[SPARKPredictor] WARNING: state_dict mismatch on TPD ckpt.") |
| if suspicious_missing: |
| print(f" missing ({len(suspicious_missing)}): " |
| f"{suspicious_missing[:6]}{' ...' if len(suspicious_missing) > 6 else ''}") |
| if suspicious_unexpected: |
| print(f" unexpected({len(suspicious_unexpected)}): " |
| f"{suspicious_unexpected[:6]}{' ...' if len(suspicious_unexpected) > 6 else ''}") |
| _fix_actnorm_initialized(self.tpd_model) |
| self.tpd_model.to(self.device).eval() |
|
|
| |
| ckpt_dir = ckpt_path.parent |
| stem = ckpt_path.stem.replace("best", "").rstrip("_") |
| prefix = stem + "_" if stem else "" |
| for search_dir in [ckpt_dir, ckpt_dir.parent]: |
| for name_pattern in [f"{prefix}norm_stats.json", "tpd_norm_stats.json", "norm_stats.json"]: |
| p = search_dir / name_pattern |
| if p.exists(): |
| with open(p) as f: |
| self.tpd_norm_stats = json.load(f) |
| break |
| if self.tpd_norm_stats is not None: |
| break |
| for search_dir in [ckpt_dir, ckpt_dir.parent]: |
| for name_pattern in [f"{prefix}theta_stats.json", "tpd_theta_stats.json", "theta_stats.json"]: |
| p = search_dir / name_pattern |
| if p.exists(): |
| with open(p) as f: |
| self.tpd_theta_stats = json.load(f) |
| break |
| if hasattr(self, "tpd_theta_stats") and self.tpd_theta_stats is not None: |
| break |
|
|
| |
| |
| |
|
|
| @staticmethod |
| def _detect_input_mode(state_dict) -> str: |
| """Detect input_mode from a checkpoint's state_dict keys. |
| |
| Phase-1 image -> encoder.per_cv_encoder.stem.* |
| Phase-1 wave -> encoder.per_cv_encoder.conv.* |
| Phase-2 joint -> encoder.image_encoder.* + encoder.waveform_encoder.* |
| """ |
| if any(k.startswith("encoder.image_encoder.") for k in state_dict) and \ |
| any(k.startswith("encoder.waveform_encoder.") for k in state_dict): |
| return "image+waveform" |
| if any(k.startswith("encoder.per_cv_encoder.stem.") for k in state_dict): |
| return "image" |
| return "waveform" |
|
|
| def _build_image_model_state(self, ckpt_path, builder_cls, |
| is_tpd=False, expected_input_mode="image"): |
| ckpt_path = Path(ckpt_path) |
| ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) |
| args = ckpt["args"] if isinstance(ckpt, dict) and "args" in ckpt else {} |
| |
| |
| actual = (args.get("input_mode") |
| or self._detect_input_mode(ckpt["model_state_dict"])) |
| if actual != expected_input_mode: |
| raise ValueError( |
| f"{expected_input_mode!r} checkpoint expected; got input_mode=" |
| f"{actual!r} for {ckpt_path}" |
| ) |
| return ckpt_path, ckpt, args |
|
|
| def _load_ec_image(self, ckpt_path): |
| ckpt_path, ckpt, args = self._build_image_model_state( |
| ckpt_path, MultiMechanismFlow, |
| ) |
| if args.get("mechanism_list") is not None: |
| new_list = list(args["mechanism_list"]) |
| _fm_module.MECHANISM_LIST = new_list |
| _mm_module.MECHANISM_LIST = new_list |
| self.ec_image_mech_list = list(_fm_module.MECHANISM_LIST) |
|
|
| self.ec_image_model = MultiMechanismFlow( |
| d_context=args.get("d_context", 128), |
| d_model=args.get("d_model", 128), |
| n_coupling_layers=args.get("n_coupling_layers", 6), |
| hidden_dim=args.get("hidden_dim", 96), |
| coupling_type=args.get("coupling_type", "spline"), |
| n_bins=args.get("n_bins", 8), |
| tail_bound=args.get("tail_bound", 5.0), |
| aggregation=args.get("aggregation", "set_transformer"), |
| use_summary_features=False, |
| input_mode="image", |
| image_in_channels=args.get("image_in_channels", 1), |
| ) |
|
|
| state = ckpt["model_state_dict"] |
| if any(k.startswith("ood_head.") for k in state): |
| meta = ckpt.get("ood_head_meta", {}) |
| self.ec_image_model.init_ood_head( |
| hidden_dim=meta.get("hidden_dim", 64), |
| extra_input_dim=meta.get("extra_input_dim", 0), |
| use_nll=meta.get("use_nll", False), |
| use_posterior_width=meta.get("use_posterior_width", False), |
| ) |
|
|
| missing, unexpected = self.ec_image_model.load_state_dict( |
| state, strict=False) |
| suspicious_missing = [k for k in missing if not k.endswith("_initialized")] |
| suspicious_unexpected = [k for k in unexpected if not k.endswith("_initialized")] |
| if suspicious_missing or suspicious_unexpected: |
| print("[SPARKPredictor] WARNING: state_dict mismatch on EC image ckpt.") |
| if suspicious_missing: |
| print(f" missing ({len(suspicious_missing)}): " |
| f"{suspicious_missing[:6]}{' ...' if len(suspicious_missing) > 6 else ''}") |
| if suspicious_unexpected: |
| print(f" unexpected({len(suspicious_unexpected)}): " |
| f"{suspicious_unexpected[:6]}{' ...' if len(suspicious_unexpected) > 6 else ''}") |
| _fix_actnorm_initialized(self.ec_image_model) |
| self.ec_image_model.to(self.device).eval() |
|
|
| def _load_tpd_image(self, ckpt_path): |
| ckpt_path, ckpt, args = self._build_image_model_state( |
| ckpt_path, MultiMechanismFlowTPD, is_tpd=True, |
| ) |
| ckpt_mech_list = args.get("mechanism_list") |
| if ckpt_mech_list is not None: |
| new_list = list(ckpt_mech_list) |
| _tpd_gen.TPD_MECHANISM_LIST = new_list |
| _tpd_gen.TPD_MECHANISM_TO_ID = {m: i for i, m in enumerate(new_list)} |
| _tpd_mod.TPD_MECHANISM_LIST = new_list |
| try: |
| import dataset_tpd as _ds_tpd |
| _ds_tpd.TPD_MECHANISM_LIST = new_list |
| except ImportError: |
| pass |
| self.tpd_image_mech_list = list(_tpd_gen.TPD_MECHANISM_LIST) |
|
|
| self.tpd_image_model = MultiMechanismFlowTPD( |
| d_context=args.get("d_context", 128), |
| d_model=args.get("d_model", 128), |
| n_coupling_layers=args.get("n_coupling_layers", 6), |
| hidden_dim=args.get("hidden_dim", 96), |
| coupling_type=args.get("coupling_type", "spline"), |
| n_bins=args.get("n_bins", 8), |
| tail_bound=args.get("tail_bound", 5.0), |
| use_summary_features=False, |
| use_bounded_flow=args.get("use_bounded_flow", False), |
| mechanism_list=self.tpd_image_mech_list, |
| input_mode="image", |
| image_in_channels=args.get("image_in_channels", 1), |
| ) |
|
|
| state = ckpt["model_state_dict"] |
| if any(k.startswith("ood_head.") for k in state): |
| meta = ckpt.get("ood_head_meta", {}) |
| self.tpd_image_model.init_ood_head( |
| hidden_dim=meta.get("hidden_dim", 64), |
| extra_input_dim=meta.get("extra_input_dim", 0), |
| ) |
|
|
| missing, unexpected = self.tpd_image_model.load_state_dict( |
| state, strict=False) |
| suspicious_missing = [k for k in missing if not k.endswith("_initialized")] |
| suspicious_unexpected = [k for k in unexpected if not k.endswith("_initialized")] |
| if suspicious_missing or suspicious_unexpected: |
| print("[SPARKPredictor] WARNING: state_dict mismatch on TPD image ckpt.") |
| if suspicious_missing: |
| print(f" missing ({len(suspicious_missing)}): " |
| f"{suspicious_missing[:6]}{' ...' if len(suspicious_missing) > 6 else ''}") |
| if suspicious_unexpected: |
| print(f" unexpected({len(suspicious_unexpected)}): " |
| f"{suspicious_unexpected[:6]}{' ...' if len(suspicious_unexpected) > 6 else ''}") |
| _fix_actnorm_initialized(self.tpd_image_model) |
| self.tpd_image_model.to(self.device).eval() |
|
|
| |
| |
| |
|
|
| def _load_ec_joint(self, ckpt_path): |
| ckpt_path, ckpt, args = self._build_image_model_state( |
| ckpt_path, MultiMechanismFlow, |
| expected_input_mode="image+waveform", |
| ) |
| if args.get("mechanism_list") is not None: |
| new_list = list(args["mechanism_list"]) |
| _fm_module.MECHANISM_LIST = new_list |
| _mm_module.MECHANISM_LIST = new_list |
| self.ec_joint_mech_list = list(_fm_module.MECHANISM_LIST) |
|
|
| self.ec_joint_model = MultiMechanismFlow( |
| d_context=args.get("d_context", 128), |
| d_model=args.get("d_model", 128), |
| n_coupling_layers=args.get("n_coupling_layers", 6), |
| hidden_dim=args.get("hidden_dim", 96), |
| coupling_type=args.get("coupling_type", "spline"), |
| n_bins=args.get("n_bins", 8), |
| tail_bound=args.get("tail_bound", 5.0), |
| aggregation=args.get("aggregation", "set_transformer"), |
| use_summary_features=False, |
| input_mode="image+waveform", |
| image_in_channels=args.get("image_in_channels", 1), |
| ) |
|
|
| state = ckpt["model_state_dict"] |
| if any(k.startswith("ood_head.") for k in state): |
| meta = ckpt.get("ood_head_meta", {}) |
| self.ec_joint_model.init_ood_head( |
| hidden_dim=meta.get("hidden_dim", 64), |
| extra_input_dim=meta.get("extra_input_dim", 0), |
| use_nll=meta.get("use_nll", False), |
| use_posterior_width=meta.get("use_posterior_width", False), |
| ) |
|
|
| missing, unexpected = self.ec_joint_model.load_state_dict( |
| state, strict=False) |
| suspicious_missing = [k for k in missing if not k.endswith("_initialized")] |
| suspicious_unexpected = [k for k in unexpected if not k.endswith("_initialized")] |
| if suspicious_missing or suspicious_unexpected: |
| print("[SPARKPredictor] WARNING: state_dict mismatch on EC joint ckpt.") |
| if suspicious_missing: |
| print(f" missing ({len(suspicious_missing)}): " |
| f"{suspicious_missing[:6]}{' ...' if len(suspicious_missing) > 6 else ''}") |
| if suspicious_unexpected: |
| print(f" unexpected({len(suspicious_unexpected)}): " |
| f"{suspicious_unexpected[:6]}{' ...' if len(suspicious_unexpected) > 6 else ''}") |
| _fix_actnorm_initialized(self.ec_joint_model) |
| self.ec_joint_model.to(self.device).eval() |
|
|
| def _load_tpd_joint(self, ckpt_path): |
| ckpt_path, ckpt, args = self._build_image_model_state( |
| ckpt_path, MultiMechanismFlowTPD, is_tpd=True, |
| expected_input_mode="image+waveform", |
| ) |
| ckpt_mech_list = args.get("mechanism_list") |
| if ckpt_mech_list is not None: |
| new_list = list(ckpt_mech_list) |
| _tpd_gen.TPD_MECHANISM_LIST = new_list |
| _tpd_gen.TPD_MECHANISM_TO_ID = {m: i for i, m in enumerate(new_list)} |
| _tpd_mod.TPD_MECHANISM_LIST = new_list |
| try: |
| import dataset_tpd as _ds_tpd |
| _ds_tpd.TPD_MECHANISM_LIST = new_list |
| except ImportError: |
| pass |
| self.tpd_joint_mech_list = list(_tpd_gen.TPD_MECHANISM_LIST) |
|
|
| self.tpd_joint_model = MultiMechanismFlowTPD( |
| d_context=args.get("d_context", 128), |
| d_model=args.get("d_model", 128), |
| n_coupling_layers=args.get("n_coupling_layers", 6), |
| hidden_dim=args.get("hidden_dim", 96), |
| coupling_type=args.get("coupling_type", "spline"), |
| n_bins=args.get("n_bins", 8), |
| tail_bound=args.get("tail_bound", 5.0), |
| use_summary_features=False, |
| use_bounded_flow=args.get("use_bounded_flow", False), |
| mechanism_list=self.tpd_joint_mech_list, |
| input_mode="image+waveform", |
| image_in_channels=args.get("image_in_channels", 1), |
| ) |
|
|
| state = ckpt["model_state_dict"] |
| if any(k.startswith("ood_head.") for k in state): |
| meta = ckpt.get("ood_head_meta", {}) |
| self.tpd_joint_model.init_ood_head( |
| hidden_dim=meta.get("hidden_dim", 64), |
| extra_input_dim=meta.get("extra_input_dim", 0), |
| ) |
|
|
| missing, unexpected = self.tpd_joint_model.load_state_dict( |
| state, strict=False) |
| suspicious_missing = [k for k in missing if not k.endswith("_initialized")] |
| suspicious_unexpected = [k for k in unexpected if not k.endswith("_initialized")] |
| if suspicious_missing or suspicious_unexpected: |
| print("[SPARKPredictor] WARNING: state_dict mismatch on TPD joint ckpt.") |
| if suspicious_missing: |
| print(f" missing ({len(suspicious_missing)}): " |
| f"{suspicious_missing[:6]}{' ...' if len(suspicious_missing) > 6 else ''}") |
| if suspicious_unexpected: |
| print(f" unexpected({len(suspicious_unexpected)}): " |
| f"{suspicious_unexpected[:6]}{' ...' if len(suspicious_unexpected) > 6 else ''}") |
| _fix_actnorm_initialized(self.tpd_joint_model) |
| self.tpd_joint_model.to(self.device).eval() |
|
|
| @staticmethod |
| def _pil_to_grayscale_tensor(pil_image, target_size=224): |
| """Convert a PIL image to a [1, target_size, target_size] float |
| tensor in [0, 1].""" |
| from PIL import Image as PILImage |
| img = pil_image.convert("L") |
| if img.size != (target_size, target_size): |
| img = img.resize((target_size, target_size), PILImage.BILINEAR) |
| arr = np.asarray(img, dtype=np.float32) / 255.0 |
| return torch.from_numpy(arr).unsqueeze(0) |
|
|
| def _build_image_input(self, pil_images, sigmas, flux_scales, |
| target_size=224): |
| """Build image-mode model input from a list of PIL images. |
| |
| Args: |
| pil_images: list of PIL.Image objects (one per scan rate / heating rate). |
| sigmas: 1-D array of raw scan rates (V/s) or heating rates (K/s). |
| flux_scales: 1-D array of log10(peak |signal|) per scan/curve. |
| If None, set to zero (model still works since flux_scales is |
| additive conditioning that the network learned to handle). |
| target_size: image edge length expected by the encoder. |
| """ |
| n = len(pil_images) |
| imgs = torch.stack( |
| [self._pil_to_grayscale_tensor(p, target_size) for p in pil_images] |
| ) |
| x = imgs.unsqueeze(0).to(self.device) |
| scan_mask = torch.ones(1, n, dtype=torch.bool, device=self.device) |
| sigmas_log = np.log10(np.clip(np.asarray(sigmas, dtype=np.float32), |
| 1e-10, None)) |
| sigmas_t = torch.from_numpy(sigmas_log).unsqueeze(0).to(self.device) |
| if flux_scales is None: |
| fs_t = torch.zeros(1, n, dtype=torch.float32, device=self.device) |
| else: |
| fs_t = torch.from_numpy( |
| np.asarray(flux_scales, dtype=np.float32) |
| ).unsqueeze(0).to(self.device) |
| return { |
| "input": x, "scan_mask": scan_mask, |
| "sigmas": sigmas_t, "flux_scales": fs_t, |
| } |
|
|
| @torch.no_grad() |
| def predict_ec_image(self, pil_images, sigmas, flux_scales=None, |
| n_samples=500, temperature=1.0): |
| """Run image-mode CV inference. `pil_images` length should match `sigmas`. |
| |
| Returns the same dict shape as `predict_ec`. |
| """ |
| if self.ec_image_model is None: |
| raise RuntimeError("EC image model not loaded") |
| if len(pil_images) != len(sigmas): |
| raise ValueError( |
| f"#images ({len(pil_images)}) must match #sigmas ({len(sigmas)})" |
| ) |
| tensors = self._build_image_input(pil_images, sigmas, flux_scales) |
| pred = self.ec_image_model.predict( |
| tensors["input"], |
| scan_mask=tensors["scan_mask"], |
| sigmas=tensors["sigmas"], |
| flux_scales=tensors["flux_scales"], |
| n_samples=n_samples, |
| temperature=temperature, |
| ) |
| return self._format_ec_pred(pred, self.ec_image_mech_list) |
|
|
| def _build_joint_input_ec(self, pil_images, sigmas, |
| potentials, fluxes, times): |
| """Build the joint encoder's input dict for CV. |
| |
| Combines `_build_image_input` (image branch) with the same |
| normalization/resampling pipeline `_prepare_ec_tensor` uses for the |
| waveform branch, then fuses them into the dict shape the joint |
| encoder expects. |
| """ |
| n_scans = len(pil_images) |
| if len(sigmas) != n_scans or len(potentials) != n_scans \ |
| or len(fluxes) != n_scans: |
| raise ValueError( |
| "predict_ec_joint: pil_images, sigmas, potentials, fluxes " |
| "must all have the same length" |
| ) |
|
|
| |
| imgs = torch.stack( |
| [self._pil_to_grayscale_tensor(p, target_size=224) |
| for p in pil_images] |
| ).unsqueeze(0).to(self.device) |
| scan_mask_image = torch.ones(1, n_scans, dtype=torch.bool, |
| device=self.device) |
|
|
| |
| wf_tensors = self._prepare_ec_tensor(potentials, fluxes, times, sigmas) |
| x = { |
| "image": imgs, |
| "waveform": wf_tensors["input"], |
| "scan_mask_image": scan_mask_image, |
| "scan_mask_waveform": wf_tensors["scan_mask"], |
| } |
| return x, wf_tensors["sigmas"], wf_tensors["flux_scales"] |
|
|
| def _build_joint_input_tpd(self, pil_images, betas, temperatures, rates): |
| n_rates = len(pil_images) |
| if len(betas) != n_rates or len(temperatures) != n_rates \ |
| or len(rates) != n_rates: |
| raise ValueError( |
| "predict_tpd_joint: pil_images, betas, temperatures, rates " |
| "must all have the same length" |
| ) |
| imgs = torch.stack( |
| [self._pil_to_grayscale_tensor(p, target_size=224) |
| for p in pil_images] |
| ).unsqueeze(0).to(self.device) |
| scan_mask_image = torch.ones(1, n_rates, dtype=torch.bool, |
| device=self.device) |
| wf_tensors = self._prepare_tpd_tensor(temperatures, rates, betas) |
| x = { |
| "image": imgs, |
| "waveform": wf_tensors["input"], |
| "scan_mask_image": scan_mask_image, |
| "scan_mask_waveform": wf_tensors["scan_mask"], |
| } |
| return x, wf_tensors["sigmas"], wf_tensors["flux_scales"] |
|
|
| @torch.no_grad() |
| def predict_ec_joint(self, pil_images, sigmas, potentials, fluxes, |
| times=None, n_samples=500, temperature=1.0): |
| """Run Phase-2 joint CV inference: image + waveform together.""" |
| if self.ec_joint_model is None: |
| raise RuntimeError("EC joint model not loaded") |
| x, sigmas_t, flux_scales_t = self._build_joint_input_ec( |
| pil_images, sigmas, potentials, fluxes, times, |
| ) |
| pred = self.ec_joint_model.predict( |
| x, scan_mask=None, sigmas=sigmas_t, flux_scales=flux_scales_t, |
| n_samples=n_samples, temperature=temperature, |
| ) |
| return self._format_ec_pred(pred, self.ec_joint_mech_list) |
|
|
| @torch.no_grad() |
| def predict_tpd_joint(self, pil_images, betas, temperatures, rates, |
| n_samples=500, temperature=1.0): |
| """Run Phase-2 joint TPD inference: image + waveform together.""" |
| if self.tpd_joint_model is None: |
| raise RuntimeError("TPD joint model not loaded") |
| x, sigmas_t, flux_scales_t = self._build_joint_input_tpd( |
| pil_images, betas, temperatures, rates, |
| ) |
| pred = self.tpd_joint_model.predict( |
| x, scan_mask=None, sigmas=sigmas_t, flux_scales=flux_scales_t, |
| n_samples=n_samples, temperature=temperature, |
| ) |
| return self._format_tpd_pred(pred, self.tpd_joint_mech_list) |
|
|
| @torch.no_grad() |
| def predict_tpd_image(self, pil_images, betas, flux_scales=None, |
| n_samples=500, temperature=1.0): |
| """Run image-mode TPD inference. Returns the same shape as predict_tpd.""" |
| if self.tpd_image_model is None: |
| raise RuntimeError("TPD image model not loaded") |
| if len(pil_images) != len(betas): |
| raise ValueError( |
| f"#images ({len(pil_images)}) must match #betas ({len(betas)})" |
| ) |
| tensors = self._build_image_input(pil_images, betas, flux_scales) |
| pred = self.tpd_image_model.predict( |
| tensors["input"], |
| scan_mask=tensors["scan_mask"], |
| sigmas=tensors["sigmas"], |
| flux_scales=tensors["flux_scales"], |
| n_samples=n_samples, |
| temperature=temperature, |
| ) |
| return self._format_tpd_pred(pred, self.tpd_image_mech_list) |
|
|
| def _format_ec_pred(self, pred, mech_list): |
| probs = pred["mechanism_probs"][0].cpu().numpy() |
| pred_idx = int(pred["mechanism_pred"][0].cpu().item()) |
| pred_mech = mech_list[pred_idx] |
| param_stats = {} |
| samples_dict = {} |
| for mech in mech_list: |
| if pred["samples"][mech] is not None: |
| s = pred["samples"][mech][0].cpu().numpy() |
| samples_dict[mech] = s |
| param_stats[mech] = { |
| "names": MECHANISM_PARAMS[mech]["names"], |
| "mean": s.mean(axis=0).tolist(), |
| "std": s.std(axis=0).tolist(), |
| "median": np.median(s, axis=0).tolist(), |
| "q05": np.quantile(s, 0.05, axis=0).tolist(), |
| "q95": np.quantile(s, 0.95, axis=0).tolist(), |
| } |
| ood_score = pred.get("ood_score") |
| ood_score_val = (float(ood_score[0].cpu().item()) |
| if ood_score is not None else None) |
| return { |
| "domain": "ec", |
| "mechanism_probs": {m: float(probs[i]) for i, m in enumerate(mech_list)}, |
| "mechanism_names": mech_list, |
| "predicted_mechanism": pred_mech, |
| "predicted_mechanism_idx": pred_idx, |
| "parameter_stats": param_stats, |
| "posterior_samples": samples_dict, |
| "ood_score": ood_score_val, |
| } |
|
|
| def _format_tpd_pred(self, pred, mech_list): |
| probs = pred["mechanism_probs"][0].cpu().numpy() |
| pred_idx = int(pred["mechanism_pred"][0].cpu().item()) |
| pred_mech = mech_list[pred_idx] |
| param_stats = {} |
| samples_dict = {} |
| for mech in mech_list: |
| if pred["samples"][mech] is not None: |
| s = pred["samples"][mech][0].cpu().numpy() |
| samples_dict[mech] = s |
| param_stats[mech] = { |
| "names": TPD_MECHANISM_PARAMS[mech]["names"], |
| "mean": s.mean(axis=0).tolist(), |
| "std": s.std(axis=0).tolist(), |
| "median": np.median(s, axis=0).tolist(), |
| "q05": np.quantile(s, 0.05, axis=0).tolist(), |
| "q95": np.quantile(s, 0.95, axis=0).tolist(), |
| } |
| ood_score = pred.get("ood_score") |
| ood_score_val = (float(ood_score[0].cpu().item()) |
| if ood_score is not None else None) |
| return { |
| "domain": "tpd", |
| "mechanism_probs": {m: float(probs[i]) for i, m in enumerate(mech_list)}, |
| "mechanism_names": mech_list, |
| "predicted_mechanism": pred_mech, |
| "predicted_mechanism_idx": pred_idx, |
| "parameter_stats": param_stats, |
| "posterior_samples": samples_dict, |
| "ood_score": ood_score_val, |
| } |
|
|
| |
| |
| |
| |
|
|
| @staticmethod |
| def _ensemble_results(image_result, waveform_result, param_names_lookup): |
| """Combine image-mode + waveform-mode predictions. |
| |
| - mechanism_probs: arithmetic mean of the two per-mech dicts. |
| - predicted_mechanism: argmax of the ensembled probs. |
| - posterior_samples: per-mech, concatenate the samples from each |
| model (when both produced samples), then recompute parameter_stats. |
| - ood_score: max of the two scores (more conservative; image-mode's |
| OOD means P(ID), so 'higher' is safer; we take min instead — see |
| below). We surface the IMAGE-mode OOD score as the headline OOD, |
| but if the waveform model also has one we take the lower of the |
| two so the banner is shown when *either* is concerned. |
| |
| `param_names_lookup` is a dict mech -> list of parameter names; it |
| provides 'names' for the recomputed parameter_stats. |
| """ |
| if image_result is None: |
| return waveform_result |
| if waveform_result is None: |
| return image_result |
|
|
| mech_list = image_result["mechanism_names"] |
| probs_a = image_result["mechanism_probs"] |
| probs_b = waveform_result["mechanism_probs"] |
|
|
| ensemble_probs = { |
| m: 0.5 * (probs_a.get(m, 0.0) + probs_b.get(m, 0.0)) |
| for m in mech_list |
| } |
| s = sum(ensemble_probs.values()) |
| if s > 0: |
| ensemble_probs = {m: v / s for m, v in ensemble_probs.items()} |
| sorted_probs = sorted(ensemble_probs.items(), key=lambda kv: -kv[1]) |
| top_mech, _ = sorted_probs[0] |
| top_idx = mech_list.index(top_mech) |
|
|
| samples_dict = {} |
| param_stats = {} |
| for mech in mech_list: |
| sa = image_result["posterior_samples"].get(mech) |
| sb = waveform_result["posterior_samples"].get(mech) |
| if sa is not None and sb is not None: |
| combined = np.concatenate([sa, sb], axis=0) |
| elif sa is not None: |
| combined = sa |
| elif sb is not None: |
| combined = sb |
| else: |
| combined = None |
| if combined is None: |
| continue |
| samples_dict[mech] = combined |
| names = param_names_lookup.get(mech, [f"p{i}" for i in range(combined.shape[-1])]) |
| param_stats[mech] = { |
| "names": names, |
| "mean": combined.mean(axis=0).tolist(), |
| "std": combined.std(axis=0).tolist(), |
| "median": np.median(combined, axis=0).tolist(), |
| "q05": np.quantile(combined, 0.05, axis=0).tolist(), |
| "q95": np.quantile(combined, 0.95, axis=0).tolist(), |
| } |
|
|
| ood_a = image_result.get("ood_score") |
| ood_b = waveform_result.get("ood_score") |
| ood_vals = [v for v in (ood_a, ood_b) if v is not None] |
| ood_ensemble = min(ood_vals) if ood_vals else None |
|
|
| return { |
| "domain": image_result["domain"], |
| "mechanism_probs": ensemble_probs, |
| "mechanism_names": mech_list, |
| "predicted_mechanism": top_mech, |
| "predicted_mechanism_idx": top_idx, |
| "parameter_stats": param_stats, |
| "posterior_samples": samples_dict, |
| "ood_score": ood_ensemble, |
| "_ensemble": True, |
| } |
|
|
| @staticmethod |
| def _agreement_stats(image_result, waveform_result): |
| if image_result is None or waveform_result is None: |
| return { |
| "available": False, |
| "top_mech_match": None, |
| "top_prob_image": None, |
| "top_prob_waveform": None, |
| } |
| return { |
| "available": True, |
| "top_mech_match": ( |
| image_result["predicted_mechanism"] |
| == waveform_result["predicted_mechanism"] |
| ), |
| "top_prob_image": float(max(image_result["mechanism_probs"].values())), |
| "top_prob_waveform": float(max(waveform_result["mechanism_probs"].values())), |
| "image_mech": image_result["predicted_mechanism"], |
| "waveform_mech": waveform_result["predicted_mechanism"], |
| } |
|
|
| def _predict_ec_two_paths( |
| self, pil_images, sigmas, potentials, fluxes, times=None, |
| n_samples=500, temperature=1.0, do_preprocess=True, |
| ): |
| """Run image-mode + waveform-mode + joint CV inference (whichever are |
| available). Returns (image_result, waveform_result, joint_result, |
| preprocessing_meta). Any result may be None if the corresponding |
| model is missing or the input arrays were not provided. |
| """ |
| image_result = None |
| waveform_result = None |
| joint_result = None |
| preproc_meta = [] |
|
|
| |
| preprocessed = None |
| if pil_images and (self.has_ec_image_model or self.has_ec_joint_model): |
| if do_preprocess: |
| from image_preprocessing import prepare_for_image_mode |
| preprocessed = [] |
| for p in pil_images: |
| out, meta = prepare_for_image_mode(p) |
| preprocessed.append(out) |
| preproc_meta.append(meta) |
| else: |
| preprocessed = list(pil_images) |
| preproc_meta = [{} for _ in pil_images] |
|
|
| if self.has_ec_image_model and preprocessed is not None: |
| flux_scales = None |
| if fluxes is not None: |
| flux_scales = [ |
| float(np.log10(np.max(np.abs(np.asarray(f))) + 1e-30)) |
| for f in fluxes |
| ] |
| try: |
| image_result = self.predict_ec_image( |
| preprocessed, sigmas, flux_scales=flux_scales, |
| n_samples=n_samples, temperature=temperature, |
| ) |
| except Exception as exc: |
| print(f"[SPARKPredictor] image-mode CV failed: {exc}") |
|
|
| if self.ec_model is not None and potentials is not None and fluxes is not None: |
| try: |
| waveform_result = self.predict_ec( |
| potentials, fluxes, sigmas, times=times, |
| n_samples=n_samples, temperature=temperature, |
| ) |
| except Exception as exc: |
| print(f"[SPARKPredictor] waveform CV failed: {exc}") |
|
|
| if (self.has_ec_joint_model and preprocessed is not None |
| and potentials is not None and fluxes is not None): |
| try: |
| joint_result = self.predict_ec_joint( |
| preprocessed, sigmas, potentials, fluxes, times=times, |
| n_samples=n_samples, temperature=temperature, |
| ) |
| except Exception as exc: |
| print(f"[SPARKPredictor] joint CV failed: {exc}") |
|
|
| return image_result, waveform_result, joint_result, preproc_meta |
|
|
| def predict_ec_hybrid( |
| self, pil_images, sigmas, potentials=None, fluxes=None, times=None, |
| n_samples=500, temperature=1.0, mode="ensemble", |
| do_preprocess=True, ood_fallback_threshold=0.3, |
| ): |
| """Hybrid CV inference combining image-mode and waveform-mode. |
| |
| Args: |
| pil_images: list of PIL.Image (one per scan rate). Required for |
| image-mode; if missing, image-mode is skipped. |
| sigmas: list of dimensionless scan rates. |
| potentials, fluxes: dimensionless waveforms (one list of arrays |
| per scan rate). Required for waveform-mode; if missing, |
| waveform-mode is skipped. |
| times: optional dimensionless time arrays. |
| mode: one of 'ensemble', 'image_only', 'digitize_only', |
| 'auto_fallback'. |
| do_preprocess: whether to run prepare_for_image_mode on inputs |
| before image-mode (default True). |
| ood_fallback_threshold: if mode='auto_fallback', image-mode's |
| OOD score (P(ID)) below this threshold triggers fallback to |
| waveform-mode. |
| |
| Returns dict with: |
| headline: result dict to display (same shape as predict_ec). |
| image_mode: image-mode result or None. |
| waveform_mode: waveform result or None. |
| preprocessing_meta: per-image preprocessing metadata list. |
| agreement: agreement_stats(image, waveform) dict. |
| method_used: human-readable string of which method produced |
| 'headline'. |
| """ |
| (image_result, waveform_result, joint_result, preproc_meta) = ( |
| self._predict_ec_two_paths( |
| pil_images, sigmas, potentials, fluxes, times, |
| n_samples=n_samples, temperature=temperature, |
| do_preprocess=do_preprocess, |
| ) |
| ) |
| agreement = self._agreement_stats(image_result, waveform_result) |
| param_names_lookup = {m: MECHANISM_PARAMS[m]["names"] for m in MECHANISM_PARAMS} |
| headline, method_used = self._select_headline( |
| mode, image_result, waveform_result, |
| param_names_lookup, ood_fallback_threshold, |
| joint_result=joint_result, |
| ) |
| return { |
| "headline": headline, |
| "image_mode": image_result, |
| "waveform_mode": waveform_result, |
| "joint_mode": joint_result, |
| "preprocessing_meta": preproc_meta, |
| "agreement": agreement, |
| "method_used": method_used, |
| } |
|
|
| def _predict_tpd_two_paths( |
| self, pil_images, betas, temperatures, rates, |
| n_samples=500, temperature=1.0, do_preprocess=True, |
| ): |
| image_result = None |
| waveform_result = None |
| joint_result = None |
| preproc_meta = [] |
|
|
| preprocessed = None |
| if pil_images and (self.has_tpd_image_model or self.has_tpd_joint_model): |
| if do_preprocess: |
| from image_preprocessing import prepare_for_image_mode |
| preprocessed = [] |
| for p in pil_images: |
| out, meta = prepare_for_image_mode(p) |
| preprocessed.append(out) |
| preproc_meta.append(meta) |
| else: |
| preprocessed = list(pil_images) |
| preproc_meta = [{} for _ in pil_images] |
|
|
| if self.has_tpd_image_model and preprocessed is not None: |
| flux_scales = None |
| if rates is not None: |
| flux_scales = [ |
| float(np.log10(np.max(np.abs(np.asarray(r))) + 1e-30)) |
| for r in rates |
| ] |
| try: |
| image_result = self.predict_tpd_image( |
| preprocessed, betas, flux_scales=flux_scales, |
| n_samples=n_samples, temperature=temperature, |
| ) |
| except Exception as exc: |
| print(f"[SPARKPredictor] image-mode TPD failed: {exc}") |
|
|
| if self.tpd_model is not None and temperatures is not None and rates is not None: |
| try: |
| waveform_result = self.predict_tpd( |
| temperatures, rates, betas, |
| n_samples=n_samples, temperature=temperature, |
| ) |
| except Exception as exc: |
| print(f"[SPARKPredictor] waveform TPD failed: {exc}") |
|
|
| if (self.has_tpd_joint_model and preprocessed is not None |
| and temperatures is not None and rates is not None): |
| try: |
| joint_result = self.predict_tpd_joint( |
| preprocessed, betas, temperatures, rates, |
| n_samples=n_samples, temperature=temperature, |
| ) |
| except Exception as exc: |
| print(f"[SPARKPredictor] joint TPD failed: {exc}") |
|
|
| return image_result, waveform_result, joint_result, preproc_meta |
|
|
| def predict_tpd_hybrid( |
| self, pil_images, betas, temperatures=None, rates=None, |
| n_samples=500, temperature=1.0, mode="ensemble", |
| do_preprocess=True, ood_fallback_threshold=0.3, |
| ): |
| """Hybrid TPD inference. Same shape as predict_ec_hybrid.""" |
| (image_result, waveform_result, joint_result, preproc_meta) = ( |
| self._predict_tpd_two_paths( |
| pil_images, betas, temperatures, rates, |
| n_samples=n_samples, temperature=temperature, |
| do_preprocess=do_preprocess, |
| ) |
| ) |
| agreement = self._agreement_stats(image_result, waveform_result) |
| param_names_lookup = {m: TPD_MECHANISM_PARAMS[m]["names"] |
| for m in TPD_MECHANISM_PARAMS} |
| headline, method_used = self._select_headline( |
| mode, image_result, waveform_result, |
| param_names_lookup, ood_fallback_threshold, |
| joint_result=joint_result, |
| ) |
| return { |
| "headline": headline, |
| "image_mode": image_result, |
| "waveform_mode": waveform_result, |
| "joint_mode": joint_result, |
| "preprocessing_meta": preproc_meta, |
| "agreement": agreement, |
| "method_used": method_used, |
| } |
|
|
| def _select_headline(self, mode, image_result, waveform_result, |
| param_names_lookup, ood_fallback_threshold, |
| joint_result=None): |
| """Pick the headline result based on `mode`. Falls back gracefully |
| when one of the available paths is unavailable. |
| |
| Supported modes: ``ensemble``, ``image_only``, ``digitize_only``, |
| ``auto_fallback``, ``joint``. |
| """ |
| if (image_result is None and waveform_result is None |
| and joint_result is None): |
| raise RuntimeError( |
| "Hybrid inference: image, waveform, and joint paths all " |
| "failed or were unavailable." |
| ) |
|
|
| if mode == "joint": |
| if joint_result is not None: |
| return joint_result, "joint image+waveform (Phase 2)" |
| if image_result is not None and waveform_result is not None: |
| return ( |
| self._ensemble_results( |
| image_result, waveform_result, param_names_lookup), |
| "ensemble (joint unavailable)", |
| ) |
| if image_result is not None: |
| return image_result, "image-direct (joint unavailable)" |
| return waveform_result, "digitize-then-infer (joint unavailable)" |
|
|
| if mode == "image_only": |
| if image_result is not None: |
| return image_result, "image-direct" |
| return waveform_result, "digitize-then-infer (image fallback)" |
|
|
| if mode == "digitize_only": |
| if waveform_result is not None: |
| return waveform_result, "digitize-then-infer" |
| return image_result, "image-direct (waveform fallback)" |
|
|
| if mode == "auto_fallback": |
| if image_result is None: |
| return waveform_result, "digitize-then-infer (no image model)" |
| ood = image_result.get("ood_score") |
| if (ood is not None and ood < ood_fallback_threshold |
| and waveform_result is not None): |
| return waveform_result, ( |
| f"digitize-then-infer (image OOD score " |
| f"{ood:.2f} < {ood_fallback_threshold:.2f})" |
| ) |
| return image_result, "image-direct" |
|
|
| |
| if joint_result is not None: |
| return joint_result, "joint image+waveform (Phase 2)" |
| if image_result is not None and waveform_result is not None: |
| return ( |
| self._ensemble_results( |
| image_result, waveform_result, param_names_lookup), |
| "ensemble (image + digitize)", |
| ) |
| if image_result is not None: |
| return image_result, "image-direct (waveform unavailable)" |
| return waveform_result, "digitize-then-infer (image unavailable)" |
|
|
| def _prepare_ec_tensor(self, potentials, fluxes, times, sigmas): |
| """ |
| Build model input tensor from preprocessed dimensionless CV data. |
| |
| Args: |
| potentials: list of 1-D arrays (dimensionless theta) |
| fluxes: list of 1-D arrays (dimensionless flux) |
| times: list of 1-D arrays (dimensionless time) or None |
| sigmas: 1-D array of dimensionless scan rates |
| |
| Returns: |
| dict of tensors ready for model.predict() |
| """ |
| from scipy.interpolate import interp1d |
|
|
| n_scans = len(potentials) |
| T_target = 672 |
|
|
| pot_resampled = [] |
| flux_resampled = [] |
| time_resampled = [] |
| flux_scales = [] |
|
|
| for i in range(n_scans): |
| pot = np.asarray(potentials[i], dtype=np.float32) |
| flx = np.asarray(fluxes[i], dtype=np.float32) |
|
|
| if times is not None and times[i] is not None: |
| tim = np.asarray(times[i], dtype=np.float32) |
| else: |
| theta_range = pot.max() - pot.min() |
| sigma = sigmas[i] |
| total_time = 2.0 * theta_range / sigma |
| tim = np.linspace(0, total_time, len(pot), dtype=np.float32) |
|
|
| peak = np.max(np.abs(flx)) + 1e-30 |
| flux_scales.append(np.log10(peak)) |
| flx = flx / peak |
|
|
| t_uniform = np.linspace(tim[0], tim[-1], T_target) |
| pot_resampled.append( |
| interp1d(tim, pot, kind="linear", fill_value="extrapolate")(t_uniform) |
| ) |
| flux_resampled.append( |
| interp1d(tim, flx, kind="linear", fill_value="extrapolate")(t_uniform) |
| ) |
| time_resampled.append(t_uniform) |
|
|
| pot_arr = np.stack(pot_resampled).astype(np.float32) |
| flx_arr = np.stack(flux_resampled).astype(np.float32) |
| tim_arr = np.stack(time_resampled).astype(np.float32) |
|
|
| ns = self.ec_norm_stats |
| if ns: |
| pot_arr = (pot_arr - ns["potential"][0]) / ns["potential"][1] |
| flx_arr = (flx_arr - ns["flux"][0]) / ns["flux"][1] |
| tim_arr = (tim_arr - ns["time"][0]) / ns["time"][1] |
|
|
| |
| waveforms = np.stack([pot_arr, flx_arr, tim_arr], axis=1) |
| x = torch.from_numpy(waveforms).unsqueeze(0).to(self.device) |
| scan_mask = torch.ones(1, n_scans, T_target, dtype=torch.bool, device=self.device) |
| sigmas_t = torch.from_numpy( |
| np.log10(np.asarray(sigmas, dtype=np.float32)) |
| ).unsqueeze(0).to(self.device) |
| flux_scales_t = torch.from_numpy( |
| np.asarray(flux_scales, dtype=np.float32) |
| ).unsqueeze(0).to(self.device) |
|
|
| return { |
| "input": x, |
| "scan_mask": scan_mask, |
| "sigmas": sigmas_t, |
| "flux_scales": flux_scales_t, |
| } |
|
|
| def _prepare_tpd_tensor(self, temperatures, rates, betas): |
| """ |
| Build model input tensor from TPD data. |
| |
| Args: |
| temperatures: list of 1-D arrays (K) |
| rates: list of 1-D arrays (arb. units) |
| betas: 1-D array of heating rates (K/s) |
| |
| Returns: |
| dict of tensors ready for model.predict() |
| """ |
| from scipy.interpolate import interp1d |
|
|
| n_rates = len(temperatures) |
| T_target = 500 |
|
|
| temp_resampled = [] |
| rate_resampled = [] |
|
|
| for i in range(n_rates): |
| temp = np.asarray(temperatures[i], dtype=np.float32) |
| rate = np.asarray(rates[i], dtype=np.float32) |
|
|
| t_uniform = np.linspace(temp[0], temp[-1], T_target) |
| temp_resampled.append(t_uniform) |
| rate_resampled.append( |
| interp1d(temp, rate, kind="linear", fill_value="extrapolate")(t_uniform) |
| ) |
|
|
| temp_arr = np.stack(temp_resampled).astype(np.float32) |
| rate_arr = np.stack(rate_resampled).astype(np.float32) |
|
|
| summary_t = None |
| if getattr(self, 'tpd_use_summary', False): |
| from preprocessing import extract_tpd_summary_stats |
| hr_arr = np.asarray(betas, dtype=np.float32) |
| lengths = np.full(n_rates, T_target, dtype=np.int32) |
| summary = extract_tpd_summary_stats( |
| temp_arr, rate_arr, lengths, hr_arr, n_rates) |
| summary_t = torch.from_numpy(summary).unsqueeze(0).to(self.device) |
|
|
| rate_scales = [] |
| for i in range(n_rates): |
| peak = np.max(np.abs(rate_arr[i])) + 1e-30 |
| rate_scales.append(np.log10(peak)) |
| rate_arr[i] /= peak |
|
|
| ns = self.tpd_norm_stats |
| if ns: |
| temp_arr = (temp_arr - ns["temperature"][0]) / ns["temperature"][1] |
| rate_arr = (rate_arr - ns["rate"][0]) / ns["rate"][1] |
|
|
| |
| waveforms = np.stack([temp_arr, rate_arr], axis=1) |
| x = torch.from_numpy(waveforms).unsqueeze(0).to(self.device) |
| scan_mask = torch.ones(1, n_rates, T_target, dtype=torch.bool, device=self.device) |
| sigmas_t = torch.from_numpy( |
| np.log10(np.asarray(betas, dtype=np.float32)) |
| ).unsqueeze(0).to(self.device) |
| rate_scales_t = torch.from_numpy( |
| np.asarray(rate_scales, dtype=np.float32) |
| ).unsqueeze(0).to(self.device) |
|
|
| result = { |
| "input": x, |
| "scan_mask": scan_mask, |
| "sigmas": sigmas_t, |
| "flux_scales": rate_scales_t, |
| } |
| if summary_t is not None: |
| result["summary"] = summary_t |
| return result |
|
|
| @torch.no_grad() |
| def predict_ec(self, potentials, fluxes, sigmas, times=None, n_samples=500, temperature=1.0): |
| """ |
| Run EC inference on dimensionless CV data. |
| |
| Args: |
| potentials: list of 1-D arrays (dimensionless theta per scan rate) |
| fluxes: list of 1-D arrays (dimensionless flux per scan rate) |
| sigmas: list/array of dimensionless scan rates |
| times: optional list of 1-D time arrays |
| n_samples: posterior samples to draw |
| temperature: sampling temperature (>1 broadens posteriors) |
| |
| Returns: |
| dict with mechanism_probs, mechanism_names, predicted_mechanism, |
| parameter_stats (per mechanism), posterior_samples (per mechanism) |
| """ |
| if self.ec_model is None: |
| raise RuntimeError("EC model not loaded") |
|
|
| tensors = self._prepare_ec_tensor(potentials, fluxes, times, sigmas) |
| pred = self.ec_model.predict( |
| tensors["input"], |
| scan_mask=tensors["scan_mask"], |
| sigmas=tensors["sigmas"], |
| flux_scales=tensors["flux_scales"], |
| n_samples=n_samples, |
| temperature=temperature, |
| ) |
|
|
| mech_list = self.ec_mechanism_list |
| probs = pred["mechanism_probs"][0].cpu().numpy() |
| pred_idx = int(pred["mechanism_pred"][0].cpu().item()) |
| pred_mech = mech_list[pred_idx] |
|
|
| param_stats = {} |
| samples_dict = {} |
| for mech in mech_list: |
| if pred["samples"][mech] is not None: |
| s = pred["samples"][mech][0].cpu().numpy() |
| samples_dict[mech] = s |
| param_stats[mech] = { |
| "names": MECHANISM_PARAMS[mech]["names"], |
| "mean": s.mean(axis=0).tolist(), |
| "std": s.std(axis=0).tolist(), |
| "median": np.median(s, axis=0).tolist(), |
| "q05": np.quantile(s, 0.05, axis=0).tolist(), |
| "q95": np.quantile(s, 0.95, axis=0).tolist(), |
| } |
|
|
| ood_score = pred.get("ood_score") |
| ood_score_val = (float(ood_score[0].cpu().item()) |
| if ood_score is not None else None) |
|
|
| return { |
| "domain": "ec", |
| "mechanism_probs": {m: float(probs[i]) for i, m in enumerate(mech_list)}, |
| "mechanism_names": mech_list, |
| "predicted_mechanism": pred_mech, |
| "predicted_mechanism_idx": pred_idx, |
| "parameter_stats": param_stats, |
| "posterior_samples": samples_dict, |
| "ood_score": ood_score_val, |
| } |
|
|
| @torch.no_grad() |
| def predict_tpd(self, temperatures, rates, betas, n_samples=500, temperature=1.0): |
| """ |
| Run TPD inference. |
| |
| Args: |
| temperatures: list of 1-D arrays (K per heating rate) |
| rates: list of 1-D arrays (signal per heating rate) |
| betas: list/array of heating rates (K/s) |
| n_samples: posterior samples to draw |
| temperature: sampling temperature |
| |
| Returns: |
| dict with mechanism_probs, parameter_stats, posterior_samples |
| """ |
| if self.tpd_model is None: |
| raise RuntimeError("TPD model not loaded") |
|
|
| tensors = self._prepare_tpd_tensor(temperatures, rates, betas) |
| pred = self.tpd_model.predict( |
| tensors["input"], |
| scan_mask=tensors["scan_mask"], |
| sigmas=tensors["sigmas"], |
| flux_scales=tensors["flux_scales"], |
| n_samples=n_samples, |
| temperature=temperature, |
| summary=tensors.get("summary"), |
| ) |
|
|
| mech_list = self.tpd_mechanism_list |
| probs = pred["mechanism_probs"][0].cpu().numpy() |
| pred_idx = int(pred["mechanism_pred"][0].cpu().item()) |
| pred_mech = mech_list[pred_idx] |
|
|
| param_stats = {} |
| samples_dict = {} |
| for mech in mech_list: |
| if pred["samples"][mech] is not None: |
| s = pred["samples"][mech][0].cpu().numpy() |
| samples_dict[mech] = s |
| param_stats[mech] = { |
| "names": TPD_MECHANISM_PARAMS[mech]["names"], |
| "mean": s.mean(axis=0).tolist(), |
| "std": s.std(axis=0).tolist(), |
| "median": np.median(s, axis=0).tolist(), |
| "q05": np.quantile(s, 0.05, axis=0).tolist(), |
| "q95": np.quantile(s, 0.95, axis=0).tolist(), |
| } |
|
|
| ood_score = pred.get("ood_score") |
| ood_score_val = (float(ood_score[0].cpu().item()) |
| if ood_score is not None else None) |
|
|
| return { |
| "domain": "tpd", |
| "mechanism_probs": {m: float(probs[i]) for i, m in enumerate(mech_list)}, |
| "mechanism_names": mech_list, |
| "predicted_mechanism": pred_mech, |
| "predicted_mechanism_idx": pred_idx, |
| "parameter_stats": param_stats, |
| "posterior_samples": samples_dict, |
| "ood_score": ood_score_val, |
| } |
|
|
| |
| |
| |
|
|
| def reconstruct_ec(self, result, potentials, fluxes, sigmas, |
| base_params=None, mechanism=None): |
| """ |
| Reconstruct CV signals from inferred posterior median and compute metrics. |
| |
| Args: |
| result: output dict from predict_ec() |
| potentials: list of 1-D arrays (original dimensionless theta) |
| fluxes: list of 1-D arrays (original dimensionless flux) |
| sigmas: list of dimensionless scan rates |
| base_params: dict of fixed simulation params; defaults used if None |
| mechanism: which mechanism to reconstruct (default: predicted) |
| |
| Returns: |
| dict with 'observed', 'reconstructed' curve lists, |
| 'nrmse', 'r2' per scan rate, and 'mean_nrmse', 'mean_r2' |
| """ |
| from evaluate_reconstruction import ( |
| reconstruct_ec_signal, signal_nrmse, signal_r2, |
| ) |
|
|
| mech = mechanism or result["predicted_mechanism"] |
| stats = result["parameter_stats"].get(mech) |
| if stats is None: |
| return None |
|
|
| theta_point = np.array(stats["median"]) |
|
|
| if base_params is None: |
| pot0 = np.asarray(potentials[0]) |
| base_params = { |
| "theta_i": float(pot0.max()), |
| "theta_v": float(pot0.min()), |
| "dA": 1.0, |
| "C_A_bulk": 1.0, |
| "C_B_bulk": 0.0, |
| "kinetics": mech, |
| } |
|
|
| try: |
| recon_results = reconstruct_ec_signal( |
| theta_point, mech, base_params, sigmas, n_spatial=64 |
| ) |
| except Exception: |
| return None |
|
|
| observed_curves = [] |
| recon_curves = [] |
| conc_curves = [] |
| nrmses = [] |
| r2s = [] |
|
|
| for i, (pot, flx, sigma) in enumerate(zip(potentials, fluxes, sigmas)): |
| pot = np.asarray(pot) |
| flx = np.asarray(flx) |
| observed_curves.append({"x": pot, "y": flx}) |
|
|
| if i < len(recon_results) and recon_results[i].get("success", False): |
| rec = recon_results[i] |
| rec_pot = np.asarray(rec["potential"]) |
| rec_flx = np.asarray(rec["flux"]) |
|
|
| n_obs = len(pot) |
| n_rec = len(rec_pot) |
| t_obs = np.linspace(0, 1, n_obs) |
| t_rec = np.linspace(0, 1, n_rec) |
| rec_flx_interp = np.interp(t_obs, t_rec, rec_flx) |
| recon_curves.append({"x": pot, "y": rec_flx_interp}) |
| nrmse_val = signal_nrmse(flx, rec_flx_interp) |
| r2_val = signal_r2(flx, rec_flx_interp) |
| nrmses.append(nrmse_val) |
| r2s.append(r2_val) |
|
|
| if "c_ox_surface" in rec and "c_red_surface" in rec: |
| c_ox_interp = np.interp(t_obs, t_rec, np.asarray(rec["c_ox_surface"])) |
| c_red_interp = np.interp(t_obs, t_rec, np.asarray(rec["c_red_surface"])) |
| conc_curves.append({ |
| "x": pot, |
| "c_ox": c_ox_interp, |
| "c_red": c_red_interp, |
| }) |
| else: |
| conc_curves.append(None) |
| else: |
| recon_curves.append({"x": pot, "y": np.zeros_like(flx)}) |
| nrmses.append(float("nan")) |
| r2s.append(float("nan")) |
| conc_curves.append(None) |
|
|
| valid_nrmse = [v for v in nrmses if np.isfinite(v)] |
| valid_r2 = [v for v in r2s if np.isfinite(v)] |
|
|
| return { |
| "observed": observed_curves, |
| "reconstructed": recon_curves, |
| "concentrations": conc_curves, |
| "nrmse": nrmses, |
| "r2": r2s, |
| "mean_nrmse": float(np.mean(valid_nrmse)) if valid_nrmse else float("nan"), |
| "mean_r2": float(np.mean(valid_r2)) if valid_r2 else float("nan"), |
| } |
|
|
| def reconstruct_tpd(self, result, temperatures, rates, betas, |
| base_params=None, mechanism=None): |
| """ |
| Reconstruct TPD signals from inferred posterior median and compute metrics. |
| |
| Args: |
| result: output dict from predict_tpd() |
| temperatures: list of 1-D arrays (K) |
| rates: list of 1-D arrays (signal) |
| betas: list of heating rates (K/s) |
| base_params: dict of fixed simulation params; defaults used if None |
| mechanism: which mechanism to reconstruct (default: predicted) |
| |
| Returns: |
| dict with 'observed', 'reconstructed' curve lists, |
| 'nrmse', 'r2' per heating rate, and 'mean_nrmse', 'mean_r2' |
| """ |
| from evaluate_reconstruction import ( |
| reconstruct_tpd_signal, signal_nrmse, signal_r2, |
| ) |
|
|
| mech = mechanism or result["predicted_mechanism"] |
| stats = result["parameter_stats"].get(mech) |
| if stats is None: |
| return None |
|
|
| theta_point = np.array(stats["median"]) |
|
|
| if base_params is None: |
| temp0 = np.asarray(temperatures[0]) |
| base_params = { |
| "mechanism": mech, |
| "T_start": float(temp0.min()), |
| "T_end": float(temp0.max()), |
| "n_points": 500, |
| } |
|
|
| try: |
| recon_results = reconstruct_tpd_signal( |
| theta_point, mech, base_params, betas |
| ) |
| except Exception: |
| return None |
|
|
| observed_curves = [] |
| recon_curves = [] |
| nrmses = [] |
| r2s = [] |
|
|
| for i, (temp, rate, beta) in enumerate(zip(temperatures, rates, betas)): |
| temp = np.asarray(temp) |
| rate = np.asarray(rate) |
| observed_curves.append({"x": temp, "y": rate}) |
|
|
| if i < len(recon_results) and recon_results[i].get("success", False): |
| rec = recon_results[i] |
| rec_temp = np.asarray(rec["temperature"]) |
| rec_rate = np.asarray(rec["rate"]) |
|
|
| rec_rate_interp = np.interp(temp, rec_temp, rec_rate) |
| recon_curves.append({"x": temp, "y": rec_rate_interp}) |
| nrmse_val = signal_nrmse(rate, rec_rate_interp) |
| r2_val = signal_r2(rate, rec_rate_interp) |
|
|
| nrmses.append(nrmse_val) |
| r2s.append(r2_val) |
| else: |
| recon_curves.append({"x": temp, "y": np.zeros_like(rate)}) |
| nrmses.append(float("nan")) |
| r2s.append(float("nan")) |
|
|
| valid_nrmse = [v for v in nrmses if np.isfinite(v)] |
| valid_r2 = [v for v in r2s if np.isfinite(v)] |
|
|
| return { |
| "observed": observed_curves, |
| "reconstructed": recon_curves, |
| "nrmse": nrmses, |
| "r2": r2s, |
| "mean_nrmse": float(np.mean(valid_nrmse)) if valid_nrmse else float("nan"), |
| "mean_r2": float(np.mean(valid_r2)) if valid_r2 else float("nan"), |
| } |
|
|