trace / inference.py
bingyan user
Rebrand TRACE -> SPARK
8619a66
raw
history blame contribute delete
68.8 kB
"""
SPARK inference engine.
Loads trained EC and TPD models and runs end-to-end inference
from preprocessed arrays (dimensionless for CV, physical for TPD).
"""
import json
import sys
import os
from pathlib import Path
import numpy as np
import torch
import flow_model as _fm_module
import multi_mechanism_model as _mm_module
import tpd_model as _tpd_mod
import generate_tpd_data as _tpd_gen
from multi_mechanism_model import MultiMechanismFlow
from tpd_model import MultiMechanismFlowTPD
from flow_model import MECHANISM_PARAMS, ActNorm
from generate_tpd_data import TPD_MECHANISM_PARAMS
def _fix_actnorm_initialized(model):
    """Flag every ActNorm layer of ``model`` as already initialized.

    Checkpoints written before the ``_initialized`` buffer existed leave
    the flag at ``False`` after ``load_state_dict``. Without this fix-up
    the first forward pass would replace the trained ``log_scale``/``bias``
    with data-dependent statistics.
    """
    actnorm_layers = (
        layer for layer in model.modules() if isinstance(layer, ActNorm)
    )
    for layer in actnorm_layers:
        if layer.initialized:
            continue
        layer.initialized = True
class SPARKPredictor:
"""Unified predictor for both EC (cyclic voltammetry) and TPD domains."""
def __init__(self, ec_checkpoint=None, tpd_checkpoint=None, device=None,
ec_image_checkpoint=None, tpd_image_checkpoint=None,
ec_joint_checkpoint=None, tpd_joint_checkpoint=None):
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
self.ec_model = None
self.ec_norm_stats = None
self.tpd_model = None
self.tpd_norm_stats = None
# Optional image-input variants; loaded only if checkpoint paths
# are supplied. Each is a separate model with its own encoder.
self.ec_image_model = None
self.ec_image_mech_list = None
self.tpd_image_model = None
self.tpd_image_mech_list = None
# Optional Phase-2 joint (image + waveform) variants.
self.ec_joint_model = None
self.ec_joint_mech_list = None
self.tpd_joint_model = None
self.tpd_joint_mech_list = None
if ec_checkpoint is not None:
self._load_ec(ec_checkpoint)
if tpd_checkpoint is not None:
self._load_tpd(tpd_checkpoint)
if ec_image_checkpoint is not None:
self._load_ec_image(ec_image_checkpoint)
if tpd_image_checkpoint is not None:
self._load_tpd_image(tpd_image_checkpoint)
if ec_joint_checkpoint is not None:
self._load_ec_joint(ec_joint_checkpoint)
if tpd_joint_checkpoint is not None:
self._load_tpd_joint(tpd_joint_checkpoint)
@property
def has_ec_image_model(self) -> bool:
return self.ec_image_model is not None
@property
def has_tpd_image_model(self) -> bool:
return self.tpd_image_model is not None
@property
def has_ec_joint_model(self) -> bool:
return self.ec_joint_model is not None
@property
def has_tpd_joint_model(self) -> bool:
return self.tpd_joint_model is not None
def _load_ec(self, ckpt_path):
    """Load the waveform EC model plus any statistics files found near it.

    The model is rebuilt from the hyper-parameters stored under the
    checkpoint's ``"args"`` key, its weights are restored non-strictly,
    and it is moved to ``self.device`` in eval mode. Afterwards the
    checkpoint directory and its parent are searched for norm-stats and
    theta-stats JSON files.
    """
    ckpt_path = Path(ckpt_path)
    # NOTE(review): weights_only=False unpickles arbitrary objects, so only
    # checkpoints from a trusted source should be passed here.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    hparams = checkpoint["args"]

    # A checkpoint trained on a custom mechanism subset (e.g. v14_9mech)
    # needs the module-level MECHANISM_LIST patched in flow_model and
    # multi_mechanism_model BEFORE the model is constructed, so that the
    # classifier and flow heads get the checkpoint's sizes. Otherwise the
    # non-strict load below would silently drop classifier weights and
    # leave random-weight predictions.
    custom_mechs = hparams.get("mechanism_list")
    if custom_mechs is not None:
        patched = list(custom_mechs)
        _fm_module.MECHANISM_LIST = patched
        _mm_module.MECHANISM_LIST = patched
    self.ec_mechanism_list = list(_fm_module.MECHANISM_LIST)

    # Constructor argument -> fallback used when the checkpoint lacks it.
    defaults = (
        ("d_context", 128),
        ("d_model", 128),
        ("n_coupling_layers", 6),
        ("hidden_dim", 96),
        ("coupling_type", "spline"),
        ("n_bins", 8),
        ("tail_bound", 5.0),
        ("aggregation", "set_transformer"),
        ("use_summary_features", False),
    )
    self.ec_model = MultiMechanismFlow(
        **{name: hparams.get(name, fallback) for name, fallback in defaults}
    )
    model = self.ec_model

    # A trained OOD head must exist on the module before the weights are
    # loaded, otherwise its parameters would not be restored.
    weights = checkpoint["model_state_dict"]
    if any(key.startswith("ood_head.") for key in weights):
        ood_meta = checkpoint.get("ood_head_meta", {})
        model.init_ood_head(
            hidden_dim=ood_meta.get("hidden_dim", 64),
            extra_input_dim=ood_meta.get("extra_input_dim", 0),
            use_nll=ood_meta.get("use_nll", False),
            use_posterior_width=ood_meta.get("use_posterior_width", False),
        )

    missing, unexpected = model.load_state_dict(weights, strict=False)
    # Only buffers such as ActNorm._initialized may be missing/extra
    # without a warning.
    bad_missing = [k for k in missing if not k.endswith("_initialized")]
    bad_unexpected = [k for k in unexpected if not k.endswith("_initialized")]
    if bad_missing or bad_unexpected:
        print(f"[SPARKPredictor] WARNING: state_dict mismatch on EC ckpt.")
        report = ((" missing (", bad_missing), (" unexpected(", bad_unexpected))
        for label, keys in report:
            if keys:
                tail = " ..." if len(keys) > 6 else ""
                print(f"{label}{len(keys)}): {keys[:6]}{tail}")

    _fix_actnorm_initialized(model)
    model.to(self.device).eval()

    # Statistics files may sit next to the checkpoint or one level up.
    ckpt_dir = ckpt_path.parent
    stem = ckpt_path.stem.replace("best", "").rstrip("_")
    prefix = f"{stem}_" if stem else ""
    search_dirs = [ckpt_dir, ckpt_dir.parent]

    norm_names = [f"{prefix}norm_stats.json", "ec_norm_stats.json", "norm_stats.json"]
    for folder in search_dirs:
        # First existing candidate in this folder, in priority order.
        hit = next(
            (folder / name for name in norm_names if (folder / name).exists()),
            None,
        )
        if hit is not None:
            with open(hit) as fh:
                self.ec_norm_stats = json.load(fh)
        if self.ec_norm_stats is not None:
            break

    theta_names = [f"{prefix}theta_stats.json", "ec_theta_stats.json", "theta_stats.json"]
    for folder in search_dirs:
        hit = next(
            (folder / name for name in theta_names if (folder / name).exists()),
            None,
        )
        if hit is not None:
            with open(hit) as fh:
                self.ec_theta_stats = json.load(fh)
        # The attribute is only created once a file has been found.
        if getattr(self, "ec_theta_stats", None) is not None:
            break
def _load_tpd(self, ckpt_path):
    """Load the waveform TPD model plus any statistics files found near it.

    The model is rebuilt from the hyper-parameters stored under the
    checkpoint's ``"args"`` key, its weights are restored non-strictly,
    and it is moved to ``self.device`` in eval mode. Afterwards the
    checkpoint directory and its parent are searched for norm-stats and
    theta-stats JSON files.
    """
    ckpt_path = Path(ckpt_path)
    # NOTE(review): weights_only=False unpickles arbitrary objects, so only
    # checkpoints from a trusted source should be passed here.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    hparams = checkpoint["args"]
    self.tpd_use_summary = hparams.get("use_summary_features", False)

    # As with EC: align the module globals AND pass mechanism_list
    # explicitly so the classifier/flow heads match the checkpoint shape
    # (e.g. tpd_11mech_v1 has 11 mechanisms, not the default 13).
    custom_mechs = hparams.get("mechanism_list")
    if custom_mechs is not None:
        patched = list(custom_mechs)
        _tpd_gen.TPD_MECHANISM_LIST = patched
        _tpd_gen.TPD_MECHANISM_TO_ID = {
            name: idx for idx, name in enumerate(patched)
        }
        _tpd_mod.TPD_MECHANISM_LIST = patched
        try:
            import dataset_tpd as _ds_tpd
        except ImportError:
            # dataset_tpd is optional at inference time.
            pass
        else:
            _ds_tpd.TPD_MECHANISM_LIST = patched
    self.tpd_mechanism_list = list(_tpd_gen.TPD_MECHANISM_LIST)

    # Constructor argument -> fallback used when the checkpoint lacks it.
    defaults = (
        ("d_context", 128),
        ("d_model", 128),
        ("n_coupling_layers", 6),
        ("hidden_dim", 96),
        ("coupling_type", "spline"),
        ("n_bins", 8),
        ("tail_bound", 5.0),
        ("use_bounded_flow", False),
    )
    model_kwargs = {
        name: hparams.get(name, fallback) for name, fallback in defaults
    }
    model_kwargs["use_summary_features"] = self.tpd_use_summary
    model_kwargs["mechanism_list"] = self.tpd_mechanism_list
    self.tpd_model = MultiMechanismFlowTPD(**model_kwargs)
    model = self.tpd_model

    # A trained OOD head must exist on the module before the weights are
    # loaded, otherwise its parameters would not be restored.
    weights = checkpoint["model_state_dict"]
    if any(key.startswith("ood_head.") for key in weights):
        ood_meta = checkpoint.get("ood_head_meta", {})
        model.init_ood_head(
            hidden_dim=ood_meta.get("hidden_dim", 64),
            extra_input_dim=ood_meta.get("extra_input_dim", 0),
        )

    missing, unexpected = model.load_state_dict(weights, strict=False)
    # Keys ending in "_initialized" may be missing/extra without a warning.
    bad_missing = [k for k in missing if not k.endswith("_initialized")]
    bad_unexpected = [k for k in unexpected if not k.endswith("_initialized")]
    if bad_missing or bad_unexpected:
        print(f"[SPARKPredictor] WARNING: state_dict mismatch on TPD ckpt.")
        report = ((" missing (", bad_missing), (" unexpected(", bad_unexpected))
        for label, keys in report:
            if keys:
                tail = " ..." if len(keys) > 6 else ""
                print(f"{label}{len(keys)}): {keys[:6]}{tail}")

    _fix_actnorm_initialized(model)
    model.to(self.device).eval()

    # Statistics files may sit next to the checkpoint or one level up.
    ckpt_dir = ckpt_path.parent
    stem = ckpt_path.stem.replace("best", "").rstrip("_")
    prefix = f"{stem}_" if stem else ""
    search_dirs = [ckpt_dir, ckpt_dir.parent]

    norm_names = [f"{prefix}norm_stats.json", "tpd_norm_stats.json", "norm_stats.json"]
    for folder in search_dirs:
        # First existing candidate in this folder, in priority order.
        hit = next(
            (folder / name for name in norm_names if (folder / name).exists()),
            None,
        )
        if hit is not None:
            with open(hit) as fh:
                self.tpd_norm_stats = json.load(fh)
        if self.tpd_norm_stats is not None:
            break

    theta_names = [f"{prefix}theta_stats.json", "tpd_theta_stats.json", "theta_stats.json"]
    for folder in search_dirs:
        hit = next(
            (folder / name for name in theta_names if (folder / name).exists()),
            None,
        )
        if hit is not None:
            with open(hit) as fh:
                self.tpd_theta_stats = json.load(fh)
        # The attribute is only created once a file has been found.
        if getattr(self, "tpd_theta_stats", None) is not None:
            break
# =====================================================================
# Image-input variants (parallel to the waveform models above)
# =====================================================================
@staticmethod
def _detect_input_mode(state_dict) -> str:
"""Detect input_mode from a checkpoint's state_dict keys.
Phase-1 image -> encoder.per_cv_encoder.stem.*
Phase-1 wave -> encoder.per_cv_encoder.conv.*
Phase-2 joint -> encoder.image_encoder.* + encoder.waveform_encoder.*
"""
if any(k.startswith("encoder.image_encoder.") for k in state_dict) and \
any(k.startswith("encoder.waveform_encoder.") for k in state_dict):
return "image+waveform"
if any(k.startswith("encoder.per_cv_encoder.stem.") for k in state_dict):
return "image"
return "waveform"
def _build_image_model_state(self, ckpt_path, builder_cls,
is_tpd=False, expected_input_mode="image"):
ckpt_path = Path(ckpt_path)
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
args = ckpt["args"] if isinstance(ckpt, dict) and "args" in ckpt else {}
# Older image-mode trainers didn't store input_mode in args; sniff
# the state_dict to disambiguate.
actual = (args.get("input_mode")
or self._detect_input_mode(ckpt["model_state_dict"]))
if actual != expected_input_mode:
raise ValueError(
f"{expected_input_mode!r} checkpoint expected; got input_mode="
f"{actual!r} for {ckpt_path}"
)
return ckpt_path, ckpt, args
def _load_ec_image(self, ckpt_path):
    """Load the image-input EC model from ``ckpt_path``.

    The checkpoint must be an ``"image"`` checkpoint (checked by
    ``_build_image_model_state``). The model is rebuilt from the stored
    hyper-parameters, its weights are restored non-strictly, and it is
    moved to ``self.device`` in eval mode.
    """
    ckpt_path, ckpt, hparams = self._build_image_model_state(
        ckpt_path, MultiMechanismFlow,
    )

    # Patch the module-level mechanism list before construction so the
    # model is built with the checkpoint's mechanism subset.
    custom_mechs = hparams.get("mechanism_list")
    if custom_mechs is not None:
        patched = list(custom_mechs)
        _fm_module.MECHANISM_LIST = patched
        _mm_module.MECHANISM_LIST = patched
    self.ec_image_mech_list = list(_fm_module.MECHANISM_LIST)

    # Constructor argument -> fallback used when the checkpoint lacks it.
    defaults = (
        ("d_context", 128),
        ("d_model", 128),
        ("n_coupling_layers", 6),
        ("hidden_dim", 96),
        ("coupling_type", "spline"),
        ("n_bins", 8),
        ("tail_bound", 5.0),
        ("aggregation", "set_transformer"),
        ("image_in_channels", 1),
    )
    model_kwargs = {
        name: hparams.get(name, fallback) for name, fallback in defaults
    }
    model_kwargs["use_summary_features"] = False
    model_kwargs["input_mode"] = "image"
    self.ec_image_model = MultiMechanismFlow(**model_kwargs)
    model = self.ec_image_model

    # A trained OOD head must exist on the module before the weights are
    # loaded, otherwise its parameters would not be restored.
    weights = ckpt["model_state_dict"]
    if any(key.startswith("ood_head.") for key in weights):
        ood_meta = ckpt.get("ood_head_meta", {})
        model.init_ood_head(
            hidden_dim=ood_meta.get("hidden_dim", 64),
            extra_input_dim=ood_meta.get("extra_input_dim", 0),
            use_nll=ood_meta.get("use_nll", False),
            use_posterior_width=ood_meta.get("use_posterior_width", False),
        )

    missing, unexpected = model.load_state_dict(weights, strict=False)
    # Keys ending in "_initialized" may be missing/extra without a warning.
    bad_missing = [k for k in missing if not k.endswith("_initialized")]
    bad_unexpected = [k for k in unexpected if not k.endswith("_initialized")]
    if bad_missing or bad_unexpected:
        print("[SPARKPredictor] WARNING: state_dict mismatch on EC image ckpt.")
        report = ((" missing (", bad_missing), (" unexpected(", bad_unexpected))
        for label, keys in report:
            if keys:
                tail = " ..." if len(keys) > 6 else ""
                print(f"{label}{len(keys)}): {keys[:6]}{tail}")

    _fix_actnorm_initialized(model)
    model.to(self.device).eval()
def _load_tpd_image(self, ckpt_path):
    """Load the image-input TPD model from ``ckpt_path``.

    The checkpoint must be an ``"image"`` checkpoint (checked by
    ``_build_image_model_state``). The model is rebuilt from the stored
    hyper-parameters, its weights are restored non-strictly, and it is
    moved to ``self.device`` in eval mode.
    """
    ckpt_path, ckpt, hparams = self._build_image_model_state(
        ckpt_path, MultiMechanismFlowTPD, is_tpd=True,
    )

    # Align the module-level TPD mechanism tables with the checkpoint
    # before the model is constructed.
    custom_mechs = hparams.get("mechanism_list")
    if custom_mechs is not None:
        patched = list(custom_mechs)
        _tpd_gen.TPD_MECHANISM_LIST = patched
        _tpd_gen.TPD_MECHANISM_TO_ID = {
            name: idx for idx, name in enumerate(patched)
        }
        _tpd_mod.TPD_MECHANISM_LIST = patched
        try:
            import dataset_tpd as _ds_tpd
        except ImportError:
            # dataset_tpd is optional at inference time.
            pass
        else:
            _ds_tpd.TPD_MECHANISM_LIST = patched
    self.tpd_image_mech_list = list(_tpd_gen.TPD_MECHANISM_LIST)

    # Constructor argument -> fallback used when the checkpoint lacks it.
    defaults = (
        ("d_context", 128),
        ("d_model", 128),
        ("n_coupling_layers", 6),
        ("hidden_dim", 96),
        ("coupling_type", "spline"),
        ("n_bins", 8),
        ("tail_bound", 5.0),
        ("use_bounded_flow", False),
        ("image_in_channels", 1),
    )
    model_kwargs = {
        name: hparams.get(name, fallback) for name, fallback in defaults
    }
    model_kwargs["use_summary_features"] = False
    model_kwargs["mechanism_list"] = self.tpd_image_mech_list
    model_kwargs["input_mode"] = "image"
    self.tpd_image_model = MultiMechanismFlowTPD(**model_kwargs)
    model = self.tpd_image_model

    # A trained OOD head must exist on the module before the weights are
    # loaded, otherwise its parameters would not be restored.
    weights = ckpt["model_state_dict"]
    if any(key.startswith("ood_head.") for key in weights):
        ood_meta = ckpt.get("ood_head_meta", {})
        model.init_ood_head(
            hidden_dim=ood_meta.get("hidden_dim", 64),
            extra_input_dim=ood_meta.get("extra_input_dim", 0),
        )

    missing, unexpected = model.load_state_dict(weights, strict=False)
    # Keys ending in "_initialized" may be missing/extra without a warning.
    bad_missing = [k for k in missing if not k.endswith("_initialized")]
    bad_unexpected = [k for k in unexpected if not k.endswith("_initialized")]
    if bad_missing or bad_unexpected:
        print("[SPARKPredictor] WARNING: state_dict mismatch on TPD image ckpt.")
        report = ((" missing (", bad_missing), (" unexpected(", bad_unexpected))
        for label, keys in report:
            if keys:
                tail = " ..." if len(keys) > 6 else ""
                print(f"{label}{len(keys)}): {keys[:6]}{tail}")

    _fix_actnorm_initialized(model)
    model.to(self.device).eval()
# =====================================================================
# Phase-2 joint (image + waveform) loaders
# =====================================================================
def _load_ec_joint(self, ckpt_path):
    """Load the Phase-2 joint (image + waveform) EC model.

    The checkpoint must be an ``"image+waveform"`` checkpoint (checked by
    ``_build_image_model_state``). The model is rebuilt from the stored
    hyper-parameters, its weights are restored non-strictly, and it is
    moved to ``self.device`` in eval mode.
    """
    ckpt_path, ckpt, hparams = self._build_image_model_state(
        ckpt_path, MultiMechanismFlow,
        expected_input_mode="image+waveform",
    )

    # Patch the module-level mechanism list before construction so the
    # model is built with the checkpoint's mechanism subset.
    custom_mechs = hparams.get("mechanism_list")
    if custom_mechs is not None:
        patched = list(custom_mechs)
        _fm_module.MECHANISM_LIST = patched
        _mm_module.MECHANISM_LIST = patched
    self.ec_joint_mech_list = list(_fm_module.MECHANISM_LIST)

    # Constructor argument -> fallback used when the checkpoint lacks it.
    defaults = (
        ("d_context", 128),
        ("d_model", 128),
        ("n_coupling_layers", 6),
        ("hidden_dim", 96),
        ("coupling_type", "spline"),
        ("n_bins", 8),
        ("tail_bound", 5.0),
        ("aggregation", "set_transformer"),
        ("image_in_channels", 1),
    )
    model_kwargs = {
        name: hparams.get(name, fallback) for name, fallback in defaults
    }
    model_kwargs["use_summary_features"] = False
    model_kwargs["input_mode"] = "image+waveform"
    self.ec_joint_model = MultiMechanismFlow(**model_kwargs)
    model = self.ec_joint_model

    # A trained OOD head must exist on the module before the weights are
    # loaded, otherwise its parameters would not be restored.
    weights = ckpt["model_state_dict"]
    if any(key.startswith("ood_head.") for key in weights):
        ood_meta = ckpt.get("ood_head_meta", {})
        model.init_ood_head(
            hidden_dim=ood_meta.get("hidden_dim", 64),
            extra_input_dim=ood_meta.get("extra_input_dim", 0),
            use_nll=ood_meta.get("use_nll", False),
            use_posterior_width=ood_meta.get("use_posterior_width", False),
        )

    missing, unexpected = model.load_state_dict(weights, strict=False)
    # Keys ending in "_initialized" may be missing/extra without a warning.
    bad_missing = [k for k in missing if not k.endswith("_initialized")]
    bad_unexpected = [k for k in unexpected if not k.endswith("_initialized")]
    if bad_missing or bad_unexpected:
        print("[SPARKPredictor] WARNING: state_dict mismatch on EC joint ckpt.")
        report = ((" missing (", bad_missing), (" unexpected(", bad_unexpected))
        for label, keys in report:
            if keys:
                tail = " ..." if len(keys) > 6 else ""
                print(f"{label}{len(keys)}): {keys[:6]}{tail}")

    _fix_actnorm_initialized(model)
    model.to(self.device).eval()
def _load_tpd_joint(self, ckpt_path):
    """Load the Phase-2 joint (image + waveform) TPD model.

    The checkpoint must be an ``"image+waveform"`` checkpoint (checked by
    ``_build_image_model_state``). The model is rebuilt from the stored
    hyper-parameters, its weights are restored non-strictly, and it is
    moved to ``self.device`` in eval mode.
    """
    ckpt_path, ckpt, hparams = self._build_image_model_state(
        ckpt_path, MultiMechanismFlowTPD, is_tpd=True,
        expected_input_mode="image+waveform",
    )

    # Align the module-level TPD mechanism tables with the checkpoint
    # before the model is constructed.
    custom_mechs = hparams.get("mechanism_list")
    if custom_mechs is not None:
        patched = list(custom_mechs)
        _tpd_gen.TPD_MECHANISM_LIST = patched
        _tpd_gen.TPD_MECHANISM_TO_ID = {
            name: idx for idx, name in enumerate(patched)
        }
        _tpd_mod.TPD_MECHANISM_LIST = patched
        try:
            import dataset_tpd as _ds_tpd
        except ImportError:
            # dataset_tpd is optional at inference time.
            pass
        else:
            _ds_tpd.TPD_MECHANISM_LIST = patched
    self.tpd_joint_mech_list = list(_tpd_gen.TPD_MECHANISM_LIST)

    # Constructor argument -> fallback used when the checkpoint lacks it.
    defaults = (
        ("d_context", 128),
        ("d_model", 128),
        ("n_coupling_layers", 6),
        ("hidden_dim", 96),
        ("coupling_type", "spline"),
        ("n_bins", 8),
        ("tail_bound", 5.0),
        ("use_bounded_flow", False),
        ("image_in_channels", 1),
    )
    model_kwargs = {
        name: hparams.get(name, fallback) for name, fallback in defaults
    }
    model_kwargs["use_summary_features"] = False
    model_kwargs["mechanism_list"] = self.tpd_joint_mech_list
    model_kwargs["input_mode"] = "image+waveform"
    self.tpd_joint_model = MultiMechanismFlowTPD(**model_kwargs)
    model = self.tpd_joint_model

    # A trained OOD head must exist on the module before the weights are
    # loaded, otherwise its parameters would not be restored.
    weights = ckpt["model_state_dict"]
    if any(key.startswith("ood_head.") for key in weights):
        ood_meta = ckpt.get("ood_head_meta", {})
        model.init_ood_head(
            hidden_dim=ood_meta.get("hidden_dim", 64),
            extra_input_dim=ood_meta.get("extra_input_dim", 0),
        )

    missing, unexpected = model.load_state_dict(weights, strict=False)
    # Keys ending in "_initialized" may be missing/extra without a warning.
    bad_missing = [k for k in missing if not k.endswith("_initialized")]
    bad_unexpected = [k for k in unexpected if not k.endswith("_initialized")]
    if bad_missing or bad_unexpected:
        print("[SPARKPredictor] WARNING: state_dict mismatch on TPD joint ckpt.")
        report = ((" missing (", bad_missing), (" unexpected(", bad_unexpected))
        for label, keys in report:
            if keys:
                tail = " ..." if len(keys) > 6 else ""
                print(f"{label}{len(keys)}): {keys[:6]}{tail}")

    _fix_actnorm_initialized(model)
    model.to(self.device).eval()
@staticmethod
def _pil_to_grayscale_tensor(pil_image, target_size=224):
    """Turn a PIL image into a float tensor of shape
    [1, target_size, target_size] with values in [0, 1]."""
    from PIL import Image as PILImage

    wanted = (target_size, target_size)
    gray = pil_image.convert("L")
    if gray.size != wanted:
        # Only resample when the image is not already the target size.
        gray = gray.resize(wanted, PILImage.BILINEAR)
    pixels = np.asarray(gray, dtype=np.float32) / 255.0
    tensor = torch.from_numpy(pixels)
    return tensor.unsqueeze(0)  # [1, H, W]
def _build_image_input(self, pil_images, sigmas, flux_scales,
                       target_size=224):
    """Assemble the image-mode model inputs from a list of PIL images.

    Args:
        pil_images: list of PIL.Image objects (one per scan rate /
            heating rate).
        sigmas: 1-D array of raw scan rates (V/s) or heating rates (K/s);
            clipped below at 1e-10 and passed on as log10.
        flux_scales: 1-D array of log10(peak |signal|) per scan/curve, or
            ``None`` to use zeros.
        target_size: image edge length expected by the encoder.

    Returns:
        dict with keys ``"input"``, ``"scan_mask"``, ``"sigmas"`` and
        ``"flux_scales"``; every tensor is on ``self.device`` with a
        leading batch dimension of 1.
    """
    device = self.device
    n = len(pil_images)

    frames = [
        self._pil_to_grayscale_tensor(image, target_size)
        for image in pil_images
    ]
    stacked = torch.stack(frames)  # [N, 1, H, W]
    x = stacked.unsqueeze(0).to(device)  # [1, N, 1, H, W]
    scan_mask = torch.ones(1, n, dtype=torch.bool, device=device)

    raw_rates = np.asarray(sigmas, dtype=np.float32)
    log_rates = np.log10(np.clip(raw_rates, 1e-10, None))
    sigmas_t = torch.from_numpy(log_rates).unsqueeze(0).to(device)

    if flux_scales is not None:
        scales = np.asarray(flux_scales, dtype=np.float32)
        fs_t = torch.from_numpy(scales).unsqueeze(0).to(device)
    else:
        fs_t = torch.zeros(1, n, dtype=torch.float32, device=device)

    return dict(input=x, scan_mask=scan_mask, sigmas=sigmas_t, flux_scales=fs_t)
@torch.no_grad()
def predict_ec_image(self, pil_images, sigmas, flux_scales=None,
n_samples=500, temperature=1.0):
"""Run image-mode CV inference. `pil_images` length should match `sigmas`.
Returns the same dict shape as `predict_ec`.
"""
if self.ec_image_model is None:
raise RuntimeError("EC image model not loaded")
if len(pil_images) != len(sigmas):
raise ValueError(
f"#images ({len(pil_images)}) must match #sigmas ({len(sigmas)})"
)
tensors = self._build_image_input(pil_images, sigmas, flux_scales)
pred = self.ec_image_model.predict(
tensors["input"],
scan_mask=tensors["scan_mask"],
sigmas=tensors["sigmas"],
flux_scales=tensors["flux_scales"],
n_samples=n_samples,
temperature=temperature,
)
return self._format_ec_pred(pred, self.ec_image_mech_list)
def _build_joint_input_ec(self, pil_images, sigmas,
potentials, fluxes, times):
"""Build the joint encoder's input dict for CV.
Combines `_build_image_input` (image branch) with the same
normalization/resampling pipeline `_prepare_ec_tensor` uses for the
waveform branch, then fuses them into the dict shape the joint
encoder expects.
"""
n_scans = len(pil_images)
if len(sigmas) != n_scans or len(potentials) != n_scans \
or len(fluxes) != n_scans:
raise ValueError(
"predict_ec_joint: pil_images, sigmas, potentials, fluxes "
"must all have the same length"
)
# Image branch: [1, N, 1, H, W]
imgs = torch.stack(
[self._pil_to_grayscale_tensor(p, target_size=224)
for p in pil_images]
).unsqueeze(0).to(self.device)
scan_mask_image = torch.ones(1, n_scans, dtype=torch.bool,
device=self.device)
# Waveform branch: re-use the existing tensor builder.
wf_tensors = self._prepare_ec_tensor(potentials, fluxes, times, sigmas)
x = {
"image": imgs,
"waveform": wf_tensors["input"],
"scan_mask_image": scan_mask_image,
"scan_mask_waveform": wf_tensors["scan_mask"],
}
return x, wf_tensors["sigmas"], wf_tensors["flux_scales"]
def _build_joint_input_tpd(self, pil_images, betas, temperatures, rates):
n_rates = len(pil_images)
if len(betas) != n_rates or len(temperatures) != n_rates \
or len(rates) != n_rates:
raise ValueError(
"predict_tpd_joint: pil_images, betas, temperatures, rates "
"must all have the same length"
)
imgs = torch.stack(
[self._pil_to_grayscale_tensor(p, target_size=224)
for p in pil_images]
).unsqueeze(0).to(self.device)
scan_mask_image = torch.ones(1, n_rates, dtype=torch.bool,
device=self.device)
wf_tensors = self._prepare_tpd_tensor(temperatures, rates, betas)
x = {
"image": imgs,
"waveform": wf_tensors["input"],
"scan_mask_image": scan_mask_image,
"scan_mask_waveform": wf_tensors["scan_mask"],
}
return x, wf_tensors["sigmas"], wf_tensors["flux_scales"]
@torch.no_grad()
def predict_ec_joint(self, pil_images, sigmas, potentials, fluxes,
times=None, n_samples=500, temperature=1.0):
"""Run Phase-2 joint CV inference: image + waveform together."""
if self.ec_joint_model is None:
raise RuntimeError("EC joint model not loaded")
x, sigmas_t, flux_scales_t = self._build_joint_input_ec(
pil_images, sigmas, potentials, fluxes, times,
)
pred = self.ec_joint_model.predict(
x, scan_mask=None, sigmas=sigmas_t, flux_scales=flux_scales_t,
n_samples=n_samples, temperature=temperature,
)
return self._format_ec_pred(pred, self.ec_joint_mech_list)
@torch.no_grad()
def predict_tpd_joint(self, pil_images, betas, temperatures, rates,
n_samples=500, temperature=1.0):
"""Run Phase-2 joint TPD inference: image + waveform together."""
if self.tpd_joint_model is None:
raise RuntimeError("TPD joint model not loaded")
x, sigmas_t, flux_scales_t = self._build_joint_input_tpd(
pil_images, betas, temperatures, rates,
)
pred = self.tpd_joint_model.predict(
x, scan_mask=None, sigmas=sigmas_t, flux_scales=flux_scales_t,
n_samples=n_samples, temperature=temperature,
)
return self._format_tpd_pred(pred, self.tpd_joint_mech_list)
@torch.no_grad()
def predict_tpd_image(self, pil_images, betas, flux_scales=None,
n_samples=500, temperature=1.0):
"""Run image-mode TPD inference. Returns the same shape as predict_tpd."""
if self.tpd_image_model is None:
raise RuntimeError("TPD image model not loaded")
if len(pil_images) != len(betas):
raise ValueError(
f"#images ({len(pil_images)}) must match #betas ({len(betas)})"
)
tensors = self._build_image_input(pil_images, betas, flux_scales)
pred = self.tpd_image_model.predict(
tensors["input"],
scan_mask=tensors["scan_mask"],
sigmas=tensors["sigmas"],
flux_scales=tensors["flux_scales"],
n_samples=n_samples,
temperature=temperature,
)
return self._format_tpd_pred(pred, self.tpd_image_mech_list)
def _format_ec_pred(self, pred, mech_list):
probs = pred["mechanism_probs"][0].cpu().numpy()
pred_idx = int(pred["mechanism_pred"][0].cpu().item())
pred_mech = mech_list[pred_idx]
param_stats = {}
samples_dict = {}
for mech in mech_list:
if pred["samples"][mech] is not None:
s = pred["samples"][mech][0].cpu().numpy()
samples_dict[mech] = s
param_stats[mech] = {
"names": MECHANISM_PARAMS[mech]["names"],
"mean": s.mean(axis=0).tolist(),
"std": s.std(axis=0).tolist(),
"median": np.median(s, axis=0).tolist(),
"q05": np.quantile(s, 0.05, axis=0).tolist(),
"q95": np.quantile(s, 0.95, axis=0).tolist(),
}
ood_score = pred.get("ood_score")
ood_score_val = (float(ood_score[0].cpu().item())
if ood_score is not None else None)
return {
"domain": "ec",
"mechanism_probs": {m: float(probs[i]) for i, m in enumerate(mech_list)},
"mechanism_names": mech_list,
"predicted_mechanism": pred_mech,
"predicted_mechanism_idx": pred_idx,
"parameter_stats": param_stats,
"posterior_samples": samples_dict,
"ood_score": ood_score_val,
}
def _format_tpd_pred(self, pred, mech_list):
probs = pred["mechanism_probs"][0].cpu().numpy()
pred_idx = int(pred["mechanism_pred"][0].cpu().item())
pred_mech = mech_list[pred_idx]
param_stats = {}
samples_dict = {}
for mech in mech_list:
if pred["samples"][mech] is not None:
s = pred["samples"][mech][0].cpu().numpy()
samples_dict[mech] = s
param_stats[mech] = {
"names": TPD_MECHANISM_PARAMS[mech]["names"],
"mean": s.mean(axis=0).tolist(),
"std": s.std(axis=0).tolist(),
"median": np.median(s, axis=0).tolist(),
"q05": np.quantile(s, 0.05, axis=0).tolist(),
"q95": np.quantile(s, 0.95, axis=0).tolist(),
}
ood_score = pred.get("ood_score")
ood_score_val = (float(ood_score[0].cpu().item())
if ood_score is not None else None)
return {
"domain": "tpd",
"mechanism_probs": {m: float(probs[i]) for i, m in enumerate(mech_list)},
"mechanism_names": mech_list,
"predicted_mechanism": pred_mech,
"predicted_mechanism_idx": pred_idx,
"parameter_stats": param_stats,
"posterior_samples": samples_dict,
"ood_score": ood_score_val,
}
# =====================================================================
# Hybrid predictors: run image-mode + waveform-mode in parallel and
# combine, with selectable strategies.
# =====================================================================
@staticmethod
def _ensemble_results(image_result, waveform_result, param_names_lookup):
"""Combine image-mode + waveform-mode predictions.
- mechanism_probs: arithmetic mean of the two per-mech dicts.
- predicted_mechanism: argmax of the ensembled probs.
- posterior_samples: per-mech, concatenate the samples from each
model (when both produced samples), then recompute parameter_stats.
- ood_score: max of the two scores (more conservative; image-mode's
OOD means P(ID), so 'higher' is safer; we take min instead — see
below). We surface the IMAGE-mode OOD score as the headline OOD,
but if the waveform model also has one we take the lower of the
two so the banner is shown when *either* is concerned.
`param_names_lookup` is a dict mech -> list of parameter names; it
provides 'names' for the recomputed parameter_stats.
"""
if image_result is None:
return waveform_result
if waveform_result is None:
return image_result
mech_list = image_result["mechanism_names"]
probs_a = image_result["mechanism_probs"]
probs_b = waveform_result["mechanism_probs"]
ensemble_probs = {
m: 0.5 * (probs_a.get(m, 0.0) + probs_b.get(m, 0.0))
for m in mech_list
}
s = sum(ensemble_probs.values())
if s > 0:
ensemble_probs = {m: v / s for m, v in ensemble_probs.items()}
sorted_probs = sorted(ensemble_probs.items(), key=lambda kv: -kv[1])
top_mech, _ = sorted_probs[0]
top_idx = mech_list.index(top_mech)
samples_dict = {}
param_stats = {}
for mech in mech_list:
sa = image_result["posterior_samples"].get(mech)
sb = waveform_result["posterior_samples"].get(mech)
if sa is not None and sb is not None:
combined = np.concatenate([sa, sb], axis=0)
elif sa is not None:
combined = sa
elif sb is not None:
combined = sb
else:
combined = None
if combined is None:
continue
samples_dict[mech] = combined
names = param_names_lookup.get(mech, [f"p{i}" for i in range(combined.shape[-1])])
param_stats[mech] = {
"names": names,
"mean": combined.mean(axis=0).tolist(),
"std": combined.std(axis=0).tolist(),
"median": np.median(combined, axis=0).tolist(),
"q05": np.quantile(combined, 0.05, axis=0).tolist(),
"q95": np.quantile(combined, 0.95, axis=0).tolist(),
}
ood_a = image_result.get("ood_score")
ood_b = waveform_result.get("ood_score")
ood_vals = [v for v in (ood_a, ood_b) if v is not None]
ood_ensemble = min(ood_vals) if ood_vals else None # P(ID); lower=more concerning
return {
"domain": image_result["domain"],
"mechanism_probs": ensemble_probs,
"mechanism_names": mech_list,
"predicted_mechanism": top_mech,
"predicted_mechanism_idx": top_idx,
"parameter_stats": param_stats,
"posterior_samples": samples_dict,
"ood_score": ood_ensemble,
"_ensemble": True,
}
@staticmethod
def _agreement_stats(image_result, waveform_result):
if image_result is None or waveform_result is None:
return {
"available": False,
"top_mech_match": None,
"top_prob_image": None,
"top_prob_waveform": None,
}
return {
"available": True,
"top_mech_match": (
image_result["predicted_mechanism"]
== waveform_result["predicted_mechanism"]
),
"top_prob_image": float(max(image_result["mechanism_probs"].values())),
"top_prob_waveform": float(max(waveform_result["mechanism_probs"].values())),
"image_mech": image_result["predicted_mechanism"],
"waveform_mech": waveform_result["predicted_mechanism"],
}
def _predict_ec_two_paths(
self, pil_images, sigmas, potentials, fluxes, times=None,
n_samples=500, temperature=1.0, do_preprocess=True,
):
"""Run image-mode + waveform-mode + joint CV inference (whichever are
available). Returns (image_result, waveform_result, joint_result,
preprocessing_meta). Any result may be None if the corresponding
model is missing or the input arrays were not provided.
"""
image_result = None
waveform_result = None
joint_result = None
preproc_meta = []
# Preprocess images once; both image-only and joint paths reuse them.
preprocessed = None
if pil_images and (self.has_ec_image_model or self.has_ec_joint_model):
if do_preprocess:
from image_preprocessing import prepare_for_image_mode
preprocessed = []
for p in pil_images:
out, meta = prepare_for_image_mode(p)
preprocessed.append(out)
preproc_meta.append(meta)
else:
preprocessed = list(pil_images)
preproc_meta = [{} for _ in pil_images]
if self.has_ec_image_model and preprocessed is not None:
flux_scales = None
if fluxes is not None:
flux_scales = [
float(np.log10(np.max(np.abs(np.asarray(f))) + 1e-30))
for f in fluxes
]
try:
image_result = self.predict_ec_image(
preprocessed, sigmas, flux_scales=flux_scales,
n_samples=n_samples, temperature=temperature,
)
except Exception as exc:
print(f"[SPARKPredictor] image-mode CV failed: {exc}")
if self.ec_model is not None and potentials is not None and fluxes is not None:
try:
waveform_result = self.predict_ec(
potentials, fluxes, sigmas, times=times,
n_samples=n_samples, temperature=temperature,
)
except Exception as exc:
print(f"[SPARKPredictor] waveform CV failed: {exc}")
if (self.has_ec_joint_model and preprocessed is not None
and potentials is not None and fluxes is not None):
try:
joint_result = self.predict_ec_joint(
preprocessed, sigmas, potentials, fluxes, times=times,
n_samples=n_samples, temperature=temperature,
)
except Exception as exc:
print(f"[SPARKPredictor] joint CV failed: {exc}")
return image_result, waveform_result, joint_result, preproc_meta
def predict_ec_hybrid(
    self, pil_images, sigmas, potentials=None, fluxes=None, times=None,
    n_samples=500, temperature=1.0, mode="ensemble",
    do_preprocess=True, ood_fallback_threshold=0.3,
):
    """Hybrid CV inference combining image, waveform and joint paths.

    Args:
        pil_images: list of PIL.Image (one per scan rate). Required for
            image-mode; if missing, image-mode is skipped.
        sigmas: list of dimensionless scan rates.
        potentials, fluxes: dimensionless waveforms (one list of arrays
            per scan rate). Required for waveform-mode; if missing,
            waveform-mode is skipped.
        times: optional dimensionless time arrays.
        n_samples: number of posterior samples, forwarded to each path.
        temperature: sampling temperature, forwarded to each path.
        mode: one of 'ensemble', 'image_only', 'digitize_only',
            'auto_fallback', 'joint'; interpreted by ``_select_headline``.
        do_preprocess: whether to run prepare_for_image_mode on inputs
            before image-mode (default True).
        ood_fallback_threshold: if mode='auto_fallback', image-mode's
            OOD score (P(ID)) below this threshold triggers fallback to
            waveform-mode.

    Returns dict with:
        headline: result dict to display (same shape as predict_ec).
        image_mode: image-mode result or None.
        waveform_mode: waveform result or None.
        joint_mode: joint image+waveform result or None.
        preprocessing_meta: per-image preprocessing metadata list.
        agreement: value of ``_agreement_stats(image, waveform)``.
        method_used: human-readable string of which method produced
            'headline'.
    """
    path_outputs = self._predict_ec_two_paths(
        pil_images, sigmas, potentials, fluxes, times,
        n_samples=n_samples, temperature=temperature,
        do_preprocess=do_preprocess,
    )
    image_result, waveform_result, joint_result, preproc_meta = path_outputs

    agreement = self._agreement_stats(image_result, waveform_result)

    # Mechanism name -> ordered parameter names, handed to the headline
    # selector.
    names_by_mechanism = {}
    for mech_name in MECHANISM_PARAMS:
        names_by_mechanism[mech_name] = MECHANISM_PARAMS[mech_name]["names"]

    headline, method_used = self._select_headline(
        mode, image_result, waveform_result,
        names_by_mechanism, ood_fallback_threshold,
        joint_result=joint_result,
    )

    report = {"headline": headline}
    report["image_mode"] = image_result
    report["waveform_mode"] = waveform_result
    report["joint_mode"] = joint_result
    report["preprocessing_meta"] = preproc_meta
    report["agreement"] = agreement
    report["method_used"] = method_used
    return report
def _predict_tpd_two_paths(
    self, pil_images, betas, temperatures, rates,
    n_samples=500, temperature=1.0, do_preprocess=True,
):
    """Run every available TPD inference path on the same inputs.

    Up to three paths are attempted: image-mode (``predict_tpd_image``),
    waveform-mode (``predict_tpd``) and joint (``predict_tpd_joint``). A
    path is skipped when its model or its inputs are missing; a path that
    raises is reported on stdout and its result stays ``None``.

    Args:
        pil_images: images, one per heating rate; may be empty or None.
        betas: heating rates, forwarded unchanged to every path.
        temperatures, rates: waveform arrays; both must be non-None for
            the waveform and joint paths to run.
        n_samples: forwarded to every path.
        temperature: sampling temperature, forwarded to every path.
        do_preprocess: when True each image is passed through
            ``prepare_for_image_mode``; otherwise images are used as given
            and the metadata list holds one empty dict per image.

    Returns:
        Tuple ``(image_result, waveform_result, joint_result,
        preprocessing_meta)``; any of the three results may be None.
    """
    use_image = self.has_tpd_image_model
    use_joint = self.has_tpd_joint_model
    have_waveforms = temperatures is not None and rates is not None

    # Images are prepared a single time and shared by the image-only and
    # joint paths.
    images = None
    meta_list = []
    if pil_images and (use_image or use_joint):
        if not do_preprocess:
            images = list(pil_images)
            meta_list = [{} for _ in pil_images]
        else:
            from image_preprocessing import prepare_for_image_mode
            images = []
            for img in pil_images:
                prepared, info = prepare_for_image_mode(img)
                images.append(prepared)
                meta_list.append(info)

    image_result = waveform_result = joint_result = None

    if use_image and images is not None:
        if rates is None:
            flux_scales = None
        else:
            # log10 of each curve's peak |rate|; the epsilon keeps log10
            # finite for an all-zero trace.
            flux_scales = [
                float(np.log10(np.max(np.abs(np.asarray(trace))) + 1e-30))
                for trace in rates
            ]
        try:
            image_result = self.predict_tpd_image(
                images, betas, flux_scales=flux_scales,
                n_samples=n_samples, temperature=temperature,
            )
        except Exception as exc:
            print(f"[SPARKPredictor] image-mode TPD failed: {exc}")

    if self.tpd_model is not None and have_waveforms:
        try:
            waveform_result = self.predict_tpd(
                temperatures, rates, betas,
                n_samples=n_samples, temperature=temperature,
            )
        except Exception as exc:
            print(f"[SPARKPredictor] waveform TPD failed: {exc}")

    if use_joint and images is not None and have_waveforms:
        try:
            joint_result = self.predict_tpd_joint(
                images, betas, temperatures, rates,
                n_samples=n_samples, temperature=temperature,
            )
        except Exception as exc:
            print(f"[SPARKPredictor] joint TPD failed: {exc}")

    return image_result, waveform_result, joint_result, meta_list
def predict_tpd_hybrid(
    self, pil_images, betas, temperatures=None, rates=None,
    n_samples=500, temperature=1.0, mode="ensemble",
    do_preprocess=True, ood_fallback_threshold=0.3,
):
    """Hybrid TPD inference. Same shape as predict_ec_hybrid.

    Runs the image, waveform and joint TPD paths through
    ``_predict_tpd_two_paths`` and lets ``_select_headline`` pick the
    result to display according to ``mode``.

    Returns dict with keys ``headline``, ``image_mode``, ``waveform_mode``,
    ``joint_mode``, ``preprocessing_meta``, ``agreement`` and
    ``method_used``.
    """
    path_outputs = self._predict_tpd_two_paths(
        pil_images, betas, temperatures, rates,
        n_samples=n_samples, temperature=temperature,
        do_preprocess=do_preprocess,
    )
    image_result, waveform_result, joint_result, preproc_meta = path_outputs

    agreement = self._agreement_stats(image_result, waveform_result)

    # Mechanism name -> ordered parameter names, handed to the headline
    # selector.
    names_by_mechanism = {}
    for mech_name in TPD_MECHANISM_PARAMS:
        names_by_mechanism[mech_name] = (
            TPD_MECHANISM_PARAMS[mech_name]["names"])

    headline, method_used = self._select_headline(
        mode, image_result, waveform_result,
        names_by_mechanism, ood_fallback_threshold,
        joint_result=joint_result,
    )

    report = {"headline": headline}
    report["image_mode"] = image_result
    report["waveform_mode"] = waveform_result
    report["joint_mode"] = joint_result
    report["preprocessing_meta"] = preproc_meta
    report["agreement"] = agreement
    report["method_used"] = method_used
    return report
def _select_headline(self, mode, image_result, waveform_result,
                     param_names_lookup, ood_fallback_threshold,
                     joint_result=None):
    """Pick the headline result based on `mode`. Falls back gracefully
    when one of the available paths is unavailable.

    Supported modes: ``ensemble``, ``image_only``, ``digitize_only``,
    ``auto_fallback``, ``joint``. Any other value is treated like
    ``ensemble``.

    Args:
        mode: selection strategy (see above).
        image_result: image-mode result dict or None.
        waveform_result: waveform-mode result dict or None.
        param_names_lookup: mechanism -> parameter names mapping, passed
            to ``_ensemble_results`` when image and waveform are mixed.
        ood_fallback_threshold: in ``auto_fallback`` mode, an image
            ``ood_score`` below this value switches to the waveform
            result (when one exists).
        joint_result: joint image+waveform result dict or None.

    Returns:
        Tuple ``(headline_result, method_used)``; the headline is never
        None.

    Raises:
        RuntimeError: if all three results are None.
    """
    if (image_result is None and waveform_result is None
            and joint_result is None):
        raise RuntimeError(
            "Hybrid inference: image, waveform, and joint paths all "
            "failed or were unavailable."
        )

    have_both = image_result is not None and waveform_result is not None
    # Used by the single-path modes when neither the requested path nor
    # its sibling produced anything: the guard above guarantees that the
    # joint result exists in that situation.
    joint_last_resort = (
        joint_result,
        "joint image+waveform (image and waveform unavailable)",
    )

    if mode == "joint":
        if joint_result is not None:
            return joint_result, "joint image+waveform (Phase 2)"
        if have_both:
            return (
                self._ensemble_results(
                    image_result, waveform_result, param_names_lookup),
                "ensemble (joint unavailable)",
            )
        if image_result is not None:
            return image_result, "image-direct (joint unavailable)"
        return waveform_result, "digitize-then-infer (joint unavailable)"

    if mode == "image_only":
        if image_result is not None:
            return image_result, "image-direct"
        if waveform_result is not None:
            return waveform_result, "digitize-then-infer (image fallback)"
        return joint_last_resort

    if mode == "digitize_only":
        if waveform_result is not None:
            return waveform_result, "digitize-then-infer"
        if image_result is not None:
            return image_result, "image-direct (waveform fallback)"
        return joint_last_resort

    if mode == "auto_fallback":
        if image_result is None:
            if waveform_result is not None:
                return (waveform_result,
                        "digitize-then-infer (no image model)")
            return joint_last_resort
        ood = image_result.get("ood_score")
        if (ood is not None and ood < ood_fallback_threshold
                and waveform_result is not None):
            return waveform_result, (
                f"digitize-then-infer (image OOD score "
                f"{ood:.2f} < {ood_fallback_threshold:.2f})"
            )
        return image_result, "image-direct"

    # ensemble (default): prefer joint when present, else mix image+wave
    if joint_result is not None:
        return joint_result, "joint image+waveform (Phase 2)"
    if have_both:
        return (
            self._ensemble_results(
                image_result, waveform_result, param_names_lookup),
            "ensemble (image + digitize)",
        )
    if image_result is not None:
        return image_result, "image-direct (waveform unavailable)"
    return waveform_result, "digitize-then-infer (image unavailable)"
def _prepare_ec_tensor(self, potentials, fluxes, times, sigmas):
    """
    Build model input tensor from preprocessed dimensionless CV data.

    Each scan is resampled onto 672 uniformly spaced time points, its flux
    is divided by its own peak magnitude (the log10 of that peak is
    returned separately), and ``self.ec_norm_stats`` is applied when set.

    Args:
        potentials: list of 1-D arrays (dimensionless theta)
        fluxes: list of 1-D arrays (dimensionless flux)
        times: list of 1-D arrays (dimensionless time) or None; a missing
            time axis is synthesised as a linear ramp from 0 to
            ``2 * (theta.max() - theta.min()) / sigma``
        sigmas: 1-D array of dimensionless scan rates

    Returns:
        dict of tensors ready for model.predict(): ``input``
        ([1, N, 3, T] with channels potential/flux/time), ``scan_mask``,
        ``sigmas`` (log10 of the scan rates) and ``flux_scales``.
    """
    from scipy.interpolate import interp1d

    n_scans = len(potentials)
    T_target = 672

    def _resample(t_src, values, t_new):
        # Linear interpolation onto the uniform grid.
        return interp1d(
            t_src, values, kind="linear", fill_value="extrapolate")(t_new)

    pot_rows, flux_rows, time_rows, log_peaks = [], [], [], []
    for idx in range(n_scans):
        theta = np.asarray(potentials[idx], dtype=np.float32)
        current = np.asarray(fluxes[idx], dtype=np.float32)

        if times is not None and times[idx] is not None:
            t_axis = np.asarray(times[idx], dtype=np.float32)
        else:
            span = theta.max() - theta.min()
            duration = 2.0 * span / sigmas[idx]
            t_axis = np.linspace(0, duration, len(theta), dtype=np.float32)

        # Per-scan peak normalisation; the epsilon avoids dividing by zero.
        peak = np.max(np.abs(current)) + 1e-30
        log_peaks.append(np.log10(peak))
        current = current / peak

        grid = np.linspace(t_axis[0], t_axis[-1], T_target)
        pot_rows.append(_resample(t_axis, theta, grid))
        flux_rows.append(_resample(t_axis, current, grid))
        time_rows.append(grid)

    channels = {
        "potential": np.stack(pot_rows).astype(np.float32),
        "flux": np.stack(flux_rows).astype(np.float32),
        "time": np.stack(time_rows).astype(np.float32),
    }

    stats = self.ec_norm_stats
    if stats:
        # Each entry of the stats dict is indexed as (offset, scale).
        channels = {
            key: (arr - stats[key][0]) / stats[key][1]
            for key, arr in channels.items()
        }

    # [1, N, 3, T]
    stacked = np.stack(
        [channels["potential"], channels["flux"], channels["time"]], axis=1)
    model_input = torch.from_numpy(stacked).unsqueeze(0).to(self.device)

    mask = torch.ones(
        1, n_scans, T_target, dtype=torch.bool, device=self.device)

    log_sigmas = np.log10(np.asarray(sigmas, dtype=np.float32))
    sigmas_t = torch.from_numpy(log_sigmas).unsqueeze(0).to(self.device)

    scales = np.asarray(log_peaks, dtype=np.float32)
    flux_scales_t = torch.from_numpy(scales).unsqueeze(0).to(self.device)

    return {
        "input": model_input,
        "scan_mask": mask,
        "sigmas": sigmas_t,
        "flux_scales": flux_scales_t,
    }
def _prepare_tpd_tensor(self, temperatures, rates, betas):
    """
    Build model input tensor from TPD data.

    Each curve is resampled onto 500 uniformly spaced temperatures between
    its first and last sample, divided by its own peak magnitude (the
    log10 of that peak is returned separately), and
    ``self.tpd_norm_stats`` is applied when set. When
    ``self.tpd_use_summary`` is truthy, summary statistics are computed
    from the resampled curves before peak normalisation and returned under
    ``summary``.

    Args:
        temperatures: list of 1-D arrays (K)
        rates: list of 1-D arrays (arb. units)
        betas: 1-D array of heating rates (K/s)

    Returns:
        dict of tensors ready for model.predict(): ``input``
        ([1, N, 2, T] with channels temperature/rate), ``scan_mask``,
        ``sigmas`` (log10 of the heating rates), ``flux_scales`` and,
        optionally, ``summary``.
    """
    from scipy.interpolate import interp1d

    n_rates = len(temperatures)
    T_target = 500

    temp_rows, rate_rows = [], []
    for idx in range(n_rates):
        temp_src = np.asarray(temperatures[idx], dtype=np.float32)
        rate_src = np.asarray(rates[idx], dtype=np.float32)
        grid = np.linspace(temp_src[0], temp_src[-1], T_target)
        resampler = interp1d(
            temp_src, rate_src, kind="linear", fill_value="extrapolate")
        temp_rows.append(grid)
        rate_rows.append(resampler(grid))

    temp_arr = np.stack(temp_rows).astype(np.float32)
    rate_arr = np.stack(rate_rows).astype(np.float32)

    # Summary statistics see the un-normalised rates, so they must be
    # computed before the in-place peak scaling below.
    summary_t = None
    if getattr(self, 'tpd_use_summary', False):
        from preprocessing import extract_tpd_summary_stats
        heating = np.asarray(betas, dtype=np.float32)
        lengths = np.full(n_rates, T_target, dtype=np.int32)
        summary = extract_tpd_summary_stats(
            temp_arr, rate_arr, lengths, heating, n_rates)
        summary_t = torch.from_numpy(summary).unsqueeze(0).to(self.device)

    log_peaks = []
    for row in rate_arr:
        # ``row`` is a view, so the division rescales ``rate_arr`` itself.
        peak = np.max(np.abs(row)) + 1e-30
        log_peaks.append(np.log10(peak))
        row /= peak

    stats = self.tpd_norm_stats
    if stats:
        # Each entry of the stats dict is indexed as (offset, scale).
        temp_arr = (temp_arr - stats["temperature"][0]) / stats["temperature"][1]
        rate_arr = (rate_arr - stats["rate"][0]) / stats["rate"][1]

    # [1, N, 2, T]
    stacked = np.stack([temp_arr, rate_arr], axis=1)
    model_input = torch.from_numpy(stacked).unsqueeze(0).to(self.device)

    mask = torch.ones(
        1, n_rates, T_target, dtype=torch.bool, device=self.device)

    log_betas = np.log10(np.asarray(betas, dtype=np.float32))
    sigmas_t = torch.from_numpy(log_betas).unsqueeze(0).to(self.device)

    scales = np.asarray(log_peaks, dtype=np.float32)
    rate_scales_t = torch.from_numpy(scales).unsqueeze(0).to(self.device)

    tensors = {
        "input": model_input,
        "scan_mask": mask,
        "sigmas": sigmas_t,
        "flux_scales": rate_scales_t,
    }
    if summary_t is not None:
        tensors["summary"] = summary_t
    return tensors
@torch.no_grad()
def predict_ec(self, potentials, fluxes, sigmas, times=None, n_samples=500, temperature=1.0):
    """
    Run EC inference on dimensionless CV data.

    Args:
        potentials: list of 1-D arrays (dimensionless theta per scan rate)
        fluxes: list of 1-D arrays (dimensionless flux per scan rate)
        sigmas: list/array of dimensionless scan rates
        times: optional list of 1-D time arrays
        n_samples: posterior samples to draw
        temperature: sampling temperature (>1 broadens posteriors)

    Returns:
        dict with domain ("ec"), mechanism_probs, mechanism_names,
        predicted_mechanism, predicted_mechanism_idx, parameter_stats
        (per mechanism), posterior_samples (per mechanism) and ood_score
        (float, or None when the model reports none)

    Raises:
        RuntimeError: if no EC model is loaded.
    """
    if self.ec_model is None:
        raise RuntimeError("EC model not loaded")

    batch = self._prepare_ec_tensor(potentials, fluxes, times, sigmas)
    pred = self.ec_model.predict(
        batch["input"],
        scan_mask=batch["scan_mask"],
        sigmas=batch["sigmas"],
        flux_scales=batch["flux_scales"],
        n_samples=n_samples,
        temperature=temperature,
    )

    mech_list = self.ec_mechanism_list
    # Only the first batch element is reported.
    probs = pred["mechanism_probs"][0].cpu().numpy()
    pred_idx = int(pred["mechanism_pred"][0].cpu().item())
    pred_mech = mech_list[pred_idx]

    param_stats = {}
    samples_dict = {}
    for mech in mech_list:
        draws = pred["samples"][mech]
        if draws is None:
            # No posterior samples were produced for this mechanism.
            continue
        s = draws[0].cpu().numpy()  # [n_samples, D]
        samples_dict[mech] = s
        summary = {"names": MECHANISM_PARAMS[mech]["names"]}
        summary["mean"] = s.mean(axis=0).tolist()
        summary["std"] = s.std(axis=0).tolist()
        summary["median"] = np.median(s, axis=0).tolist()
        summary["q05"] = np.quantile(s, 0.05, axis=0).tolist()
        summary["q95"] = np.quantile(s, 0.95, axis=0).tolist()
        param_stats[mech] = summary

    ood_score = pred.get("ood_score")
    if ood_score is None:
        ood_score_val = None
    else:
        ood_score_val = float(ood_score[0].cpu().item())

    mechanism_probs = {
        name: float(probs[k]) for k, name in enumerate(mech_list)
    }
    return {
        "domain": "ec",
        "mechanism_probs": mechanism_probs,
        "mechanism_names": mech_list,
        "predicted_mechanism": pred_mech,
        "predicted_mechanism_idx": pred_idx,
        "parameter_stats": param_stats,
        "posterior_samples": samples_dict,
        "ood_score": ood_score_val,
    }
@torch.no_grad()
def predict_tpd(self, temperatures, rates, betas, n_samples=500, temperature=1.0):
    """
    Run TPD inference.

    Args:
        temperatures: list of 1-D arrays (K per heating rate)
        rates: list of 1-D arrays (signal per heating rate)
        betas: list/array of heating rates (K/s)
        n_samples: posterior samples to draw
        temperature: sampling temperature

    Returns:
        dict with domain ("tpd"), mechanism_probs, mechanism_names,
        predicted_mechanism, predicted_mechanism_idx, parameter_stats,
        posterior_samples and ood_score (float, or None when the model
        reports none)

    Raises:
        RuntimeError: if no TPD model is loaded.
    """
    if self.tpd_model is None:
        raise RuntimeError("TPD model not loaded")

    batch = self._prepare_tpd_tensor(temperatures, rates, betas)
    pred = self.tpd_model.predict(
        batch["input"],
        scan_mask=batch["scan_mask"],
        sigmas=batch["sigmas"],
        flux_scales=batch["flux_scales"],
        n_samples=n_samples,
        temperature=temperature,
        # Present only when the tensor builder produced summary stats.
        summary=batch.get("summary"),
    )

    mech_list = self.tpd_mechanism_list
    # Only the first batch element is reported.
    probs = pred["mechanism_probs"][0].cpu().numpy()
    pred_idx = int(pred["mechanism_pred"][0].cpu().item())
    pred_mech = mech_list[pred_idx]

    param_stats = {}
    samples_dict = {}
    for mech in mech_list:
        draws = pred["samples"][mech]
        if draws is None:
            # No posterior samples were produced for this mechanism.
            continue
        s = draws[0].cpu().numpy()
        samples_dict[mech] = s
        summary = {"names": TPD_MECHANISM_PARAMS[mech]["names"]}
        summary["mean"] = s.mean(axis=0).tolist()
        summary["std"] = s.std(axis=0).tolist()
        summary["median"] = np.median(s, axis=0).tolist()
        summary["q05"] = np.quantile(s, 0.05, axis=0).tolist()
        summary["q95"] = np.quantile(s, 0.95, axis=0).tolist()
        param_stats[mech] = summary

    ood_score = pred.get("ood_score")
    if ood_score is None:
        ood_score_val = None
    else:
        ood_score_val = float(ood_score[0].cpu().item())

    mechanism_probs = {
        name: float(probs[k]) for k, name in enumerate(mech_list)
    }
    return {
        "domain": "tpd",
        "mechanism_probs": mechanism_probs,
        "mechanism_names": mech_list,
        "predicted_mechanism": pred_mech,
        "predicted_mechanism_idx": pred_idx,
        "parameter_stats": param_stats,
        "posterior_samples": samples_dict,
        "ood_score": ood_score_val,
    }
# =====================================================================
# Signal Reconstruction
# =====================================================================
def reconstruct_ec(self, result, potentials, fluxes, sigmas,
                   base_params=None, mechanism=None):
    """
    Reconstruct CV signals from inferred posterior median and compute metrics.

    Args:
        result: output dict from predict_ec()
        potentials: list of 1-D arrays (original dimensionless theta)
        fluxes: list of 1-D arrays (original dimensionless flux)
        sigmas: list of dimensionless scan rates
        base_params: dict of fixed simulation params; defaults used if None
        mechanism: which mechanism to reconstruct (default: predicted)

    Returns:
        None when ``result`` has no parameter statistics for the chosen
        mechanism or when the simulator call raises (the error is printed).
        Otherwise a dict with 'observed', 'reconstructed' and
        'concentrations' curve lists, 'nrmse', 'r2' per scan rate, and
        'mean_nrmse', 'mean_r2' (means over the finite per-scan values,
        NaN when there are none).
    """
    from evaluate_reconstruction import (
        reconstruct_ec_signal, signal_nrmse, signal_r2,
    )

    mech = mechanism or result["predicted_mechanism"]
    stats = result["parameter_stats"].get(mech)
    if stats is None:
        return None
    theta_point = np.array(stats["median"])

    if base_params is None:
        # Default sweep limits are taken from the first observed scan.
        pot0 = np.asarray(potentials[0])
        base_params = {
            "theta_i": float(pot0.max()),
            "theta_v": float(pot0.min()),
            "dA": 1.0,
            "C_A_bulk": 1.0,
            "C_B_bulk": 0.0,
            "kinetics": mech,
        }

    try:
        recon_results = reconstruct_ec_signal(
            theta_point, mech, base_params, sigmas, n_spatial=64
        )
    except Exception as exc:
        # Reconstruction is best-effort, but report why it failed instead
        # of discarding the error, as the inference paths of this class do.
        print(f"[SPARKPredictor] CV reconstruction failed: {exc}")
        return None

    observed_curves = []
    recon_curves = []
    conc_curves = []
    nrmses = []
    r2s = []
    # ``sigmas`` stays in the zip so the loop still stops at the shortest
    # of the three inputs; its values are not needed here.
    for i, (pot, flx, _sigma) in enumerate(zip(potentials, fluxes, sigmas)):
        pot = np.asarray(pot)
        flx = np.asarray(flx)
        observed_curves.append({"x": pot, "y": flx})

        succeeded = (i < len(recon_results)
                     and recon_results[i].get("success", False))
        if not succeeded:
            # Placeholder curve and NaN metrics keep the lists aligned
            # with the observed scans.
            recon_curves.append({"x": pot, "y": np.zeros_like(flx)})
            nrmses.append(float("nan"))
            r2s.append(float("nan"))
            conc_curves.append(None)
            continue

        rec = recon_results[i]
        rec_flx = np.asarray(rec["flux"])
        # Observed and simulated traces are aligned by normalised sample
        # position (0..1), not by potential value.
        t_obs = np.linspace(0, 1, len(pot))
        t_rec = np.linspace(0, 1, len(np.asarray(rec["potential"])))
        rec_flx_interp = np.interp(t_obs, t_rec, rec_flx)
        recon_curves.append({"x": pot, "y": rec_flx_interp})
        nrmses.append(signal_nrmse(flx, rec_flx_interp))
        r2s.append(signal_r2(flx, rec_flx_interp))

        if "c_ox_surface" in rec and "c_red_surface" in rec:
            conc_curves.append({
                "x": pot,
                "c_ox": np.interp(
                    t_obs, t_rec, np.asarray(rec["c_ox_surface"])),
                "c_red": np.interp(
                    t_obs, t_rec, np.asarray(rec["c_red_surface"])),
            })
        else:
            conc_curves.append(None)

    valid_nrmse = [v for v in nrmses if np.isfinite(v)]
    valid_r2 = [v for v in r2s if np.isfinite(v)]
    return {
        "observed": observed_curves,
        "reconstructed": recon_curves,
        "concentrations": conc_curves,
        "nrmse": nrmses,
        "r2": r2s,
        "mean_nrmse": float(np.mean(valid_nrmse)) if valid_nrmse else float("nan"),
        "mean_r2": float(np.mean(valid_r2)) if valid_r2 else float("nan"),
    }
def reconstruct_tpd(self, result, temperatures, rates, betas,
                    base_params=None, mechanism=None):
    """
    Reconstruct TPD signals from inferred posterior median and compute metrics.

    Args:
        result: output dict from predict_tpd()
        temperatures: list of 1-D arrays (K)
        rates: list of 1-D arrays (signal)
        betas: list of heating rates (K/s)
        base_params: dict of fixed simulation params; defaults used if None
        mechanism: which mechanism to reconstruct (default: predicted)

    Returns:
        None when ``result`` has no parameter statistics for the chosen
        mechanism or when the simulator call raises (the error is printed).
        Otherwise a dict with 'observed', 'reconstructed' curve lists,
        'nrmse', 'r2' per heating rate, and 'mean_nrmse', 'mean_r2'
        (means over the finite per-curve values, NaN when there are none).
    """
    from evaluate_reconstruction import (
        reconstruct_tpd_signal, signal_nrmse, signal_r2,
    )

    mech = mechanism or result["predicted_mechanism"]
    stats = result["parameter_stats"].get(mech)
    if stats is None:
        return None
    theta_point = np.array(stats["median"])

    if base_params is None:
        # Default temperature window is taken from the first observed curve.
        temp0 = np.asarray(temperatures[0])
        base_params = {
            "mechanism": mech,
            "T_start": float(temp0.min()),
            "T_end": float(temp0.max()),
            "n_points": 500,
        }

    try:
        recon_results = reconstruct_tpd_signal(
            theta_point, mech, base_params, betas
        )
    except Exception as exc:
        # Reconstruction is best-effort, but report why it failed instead
        # of discarding the error, as the inference paths of this class do.
        print(f"[SPARKPredictor] TPD reconstruction failed: {exc}")
        return None

    observed_curves = []
    recon_curves = []
    nrmses = []
    r2s = []
    # ``betas`` stays in the zip so the loop still stops at the shortest
    # of the three inputs; its values are not needed here.
    for i, (temp, rate, _beta) in enumerate(zip(temperatures, rates, betas)):
        temp = np.asarray(temp)
        rate = np.asarray(rate)
        observed_curves.append({"x": temp, "y": rate})

        succeeded = (i < len(recon_results)
                     and recon_results[i].get("success", False))
        if not succeeded:
            # Placeholder curve and NaN metrics keep the lists aligned
            # with the observed curves.
            recon_curves.append({"x": temp, "y": np.zeros_like(rate)})
            nrmses.append(float("nan"))
            r2s.append(float("nan"))
            continue

        rec = recon_results[i]
        # Evaluate the simulated curve at the observed temperatures.
        rec_rate_interp = np.interp(
            temp, np.asarray(rec["temperature"]), np.asarray(rec["rate"]))
        recon_curves.append({"x": temp, "y": rec_rate_interp})
        nrmses.append(signal_nrmse(rate, rec_rate_interp))
        r2s.append(signal_r2(rate, rec_rate_interp))

    valid_nrmse = [v for v in nrmses if np.isfinite(v)]
    valid_r2 = [v for v in r2s if np.isfinite(v)]
    return {
        "observed": observed_curves,
        "reconstructed": recon_curves,
        "nrmse": nrmses,
        "r2": r2s,
        "mean_nrmse": float(np.mean(valid_nrmse)) if valid_nrmse else float("nan"),
        "mean_r2": float(np.mean(valid_r2)) if valid_r2 else float("nan"),
    }