trace / tpd_model.py
bingyan user
Phase-2 joint image+waveform models + From-Image hybrid pipeline
7c95286
Raw
History Blame Contribute Delete
33.9 kB
"""
Multi-Mechanism Normalizing Flow for Joint Mechanism Identification
and Bayesian Parameter Inference from Multi-Heating-Rate TPD Signals.
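Architecture mirrors the electrochemistry model (multi_mechanism_model.py)
but is configured for TPD: waveform inputs have 2 channels (temperature,
rate), and the default mechanism list (TPD_MECHANISM_LIST) contains 6
catalysis mechanisms. Plot-image and joint image+waveform input modes are
also supported.
Reuses domain-agnostic components from flow_model.py (SignalEncoder),
multi_mechanism_model.py (SAB, PMA, SummaryProjection, MechanismClassifier,
MechanismFlow) and image_encoder.py (ImageEncoder).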
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from flow_model import (
SignalEncoder,
ActNorm,
ConditionalSplineCoupling,
ConditionalAffineCoupling,
)
from multi_mechanism_model import MechanismClassifier, MechanismFlow, SAB, PMA, SummaryProjection, SUMMARY_DIM
from image_encoder import ImageEncoder
from generate_tpd_data import TPD_MECHANISM_LIST, TPD_MECHANISM_PARAMS
from tpd_bijectors import (
Identity, Logit, AffineLogit,
apply_param_bijectors_forward, apply_param_bijectors_inverse,
all_identity,
)
# =============================================================================
# Per-parameter bijector registry (TPD only).
#
# Keys are parameter names as they appear in
# TPD_MECHANISM_PARAMS[mech]['names']; values are bijector specs consumed by
# _make_bijector. Names missing from the registry fall back to the Identity
# bijector, i.e. plain unbounded z-score normalization.
#
# Only models built with use_bounded_flow=True read this registry; models
# trained without that flag (including v1 checkpoints) never consult it.
# =============================================================================
TPD_PARAM_BIJECTORS = dict(
    # Coverage parameters confined to [0, 1]. Under z-score-only
    # normalization these previously suffered posterior collapse /
    # over-coverage.
    theta_0=('logit',),
    theta_A0=('logit',),
    theta_B0=('logit',),
    theta_O0=('logit',),
    # Site fraction in (0, 1).
    f_site1=('logit',),
    # Initial layer count: sampled in [1.5, 8.0] but modelled as bounded in
    # [1.0, 10.0], leaving slack at both ends.
    n_layers=('affine_logit', 1.0, 10.0),
    # delta, alpha_cov and omega are tied to Ed/Ea (a joint constraint) and
    # are not bounded on their own, so they are left to Identity.
)
def _make_bijector(spec):
    """Instantiate the bijector described by one registry entry.

    Args:
        spec: ``None`` (parameter has no registry entry) or a tuple whose
            first element names the kind: ``('identity',)``, ``('logit',)``
            or ``('affine_logit', low, high)``.

    Returns:
        A newly constructed bijector object.

    Raises:
        ValueError: for an unrecognised kind, or an ``affine_logit`` spec
            that does not carry exactly two bounds.
    """
    # Unregistered parameters keep the historical unbounded behaviour.
    if spec is None:
        return Identity()
    head = spec[0]
    if head == 'affine_logit':
        _, lo, hi = spec
        return AffineLogit(lo, hi)
    if head == 'logit':
        return Logit()
    if head == 'identity':
        return Identity()
    raise ValueError(f"Unknown bijector spec: {spec}")
def build_param_bijectors(mech: str):
    """Build the per-parameter bijector list for mechanism ``mech``.

    Entry ``i`` corresponds to ``TPD_MECHANISM_PARAMS[mech]['names'][i]``;
    names without a TPD_PARAM_BIJECTORS entry get an Identity bijector.
    """
    param_names = TPD_MECHANISM_PARAMS[mech]['names']
    specs = [TPD_PARAM_BIJECTORS.get(name) for name in param_names]
    return [_make_bijector(spec) for spec in specs]
# =============================================================================
# Bounded variant of MechanismFlow (TPD-only)
# =============================================================================
class BoundedMechanismFlow(MechanismFlow):
    """MechanismFlow with per-parameter reparameterization bijectors.

    The density of a physical parameter vector ``theta`` is evaluated in the
    unbounded representation ``u = bij(theta)`` and mapped back with the
    change-of-variables rule::

        log p_theta(theta) = log p_u(u) + log|du/dtheta|

    where ``log p_u(u)`` is what the base MechanismFlow computes on ``u``
    (z-score normalize -> flow -> standard-normal base).

    ``theta_mean`` / ``theta_std`` (inherited from MechanismFlow) therefore
    hold U-SPACE statistics, not physical-space ones. Callers must compute
    them in u-space, e.g. with
    ``MultiMechanismFlowTPD.compute_stats_for_mechanism``.

    When every bijector is an identity, each public method delegates to the
    base class unchanged.
    """

    def __init__(self, theta_dim, param_bijectors, **kwargs):
        # Explicit raise (not assert) so the check survives `python -O`.
        if len(param_bijectors) != theta_dim:
            raise ValueError(
                f"BoundedMechanismFlow expects {theta_dim} bijectors, "
                f"got {len(param_bijectors)}"
            )
        super().__init__(theta_dim=theta_dim, **kwargs)
        # Registered as a ModuleList so anything the bijectors hold moves
        # with .to(device) and is part of the module tree.
        self.param_bijectors = nn.ModuleList(param_bijectors)
        # Identity-only bijectors reproduce MechanismFlow exactly -> fast path.
        self._bijector_is_noop = all_identity(self.param_bijectors)

    def _bijector_list(self):
        """Bijectors as a plain list, the form passed to apply_param_bijectors_*."""
        return list(self.param_bijectors)

    def _draw_u_samples(self, context, n_samples):
        """Draw de-normalized u-space samples, ``n_samples`` per context row.

        Returns a ``[B * n_samples, theta_dim]`` tensor ordered row-major over
        (batch row, sample). Gradients flow unless the caller disables them.
        """
        batch = context.shape[0]
        context_rep = (context.unsqueeze(1)
                       .expand(-1, n_samples, -1)
                       .reshape(batch * n_samples, -1))
        z = torch.randn(batch * n_samples, self.theta_dim, device=context.device)
        u_norm, _ = self.forward_flow(z, context_rep)
        return self.denormalize_theta(u_norm)

    # ---- log p(theta | context) ------------------------------------------
    def log_prob(self, theta, context):
        """Log-density of physical ``theta`` given ``context``.

        Args:
            theta: [N, theta_dim] physical parameter values.
            context: [N, d_context] conditioning vectors.

        Returns:
            [N] log-probabilities, clamped to [-50, 50] as in the base flow.
        """
        if self._bijector_is_noop:
            return super().log_prob(theta, context)
        # 1. theta -> u in unbounded space, accumulating log|du/dtheta|.
        u, log_det_b = apply_param_bijectors_forward(theta, self._bijector_list())
        # 2. z-score normalize in u-space, then invert the flow.
        u_norm = self.normalize_theta(u)
        if self.coupling_type == 'spline':
            # Spline couplings are only defined inside [-tail_bound, tail_bound].
            u_norm = u_norm.clamp(-self.tail_bound, self.tail_bound)
        z, log_det_flow = self.inverse_flow(u_norm, context)
        # 3. Standard-normal base density plus the change-of-variables terms
        #    for the flow, the z-score normalization (-sum log std) and the
        #    bijector.
        log_pz = -0.5 * (z ** 2 + math.log(2 * math.pi)).sum(dim=-1)
        log_det_norm = -torch.log(self.theta_std).sum()
        log_p = log_pz + log_det_flow + log_det_norm + log_det_b
        return log_p.clamp(min=-50.0, max=50.0)

    # ---- sampling --------------------------------------------------------
    @torch.no_grad()
    def sample(self, context, n_samples=100, temperature=1.0):
        """Draw physical-space samples, shape [B, n_samples, theta_dim].

        ``temperature`` (scalar, or tensor with one entry per parameter)
        scales each row's sample spread around its sample mean. It is applied
        in u-space, so inflated samples remain valid bijector pre-images.
        """
        if self._bijector_is_noop:
            return super().sample(context, n_samples=n_samples, temperature=temperature)
        B = context.shape[0]
        u = self._draw_u_samples(context, n_samples).reshape(B, n_samples, self.theta_dim)

        if isinstance(temperature, torch.Tensor):
            # Match u's dtype too, so e.g. half-precision samples are not
            # silently promoted by a float32 temperature tensor.
            scale = temperature.to(device=u.device, dtype=u.dtype).reshape(1, 1, -1)
        elif temperature != 1.0:
            scale = temperature
        else:
            scale = None
        if scale is not None:
            mu = u.mean(dim=1, keepdim=True)
            u = mu + scale * (u - mu)

        # Back to physical theta space.
        theta_flat, _ = apply_param_bijectors_inverse(
            u.reshape(B * n_samples, self.theta_dim), self._bijector_list()
        )
        return theta_flat.reshape(B, n_samples, self.theta_dim)

    def sample_with_grad(self, context, n_samples=64):
        """Differentiable sampling, shape [B, n_samples, theta_dim].

        Gradients propagate through both the flow and the bijector inverse
        (used by the calibration / min-variance losses).
        """
        if self._bijector_is_noop:
            return super().sample_with_grad(context, n_samples=n_samples)
        B = context.shape[0]
        u = self._draw_u_samples(context, n_samples)
        theta, _ = apply_param_bijectors_inverse(u, self._bijector_list())
        return theta.reshape(B, n_samples, self.theta_dim)
class MultiScanEncoderTPD(nn.Module):
    """
    Encode a set of multi-heating-rate TPD curves into a single context vector.

    Architecture (Set Transformer):
      1. Shared per-curve encoder -> per-curve embedding
         - input_mode='waveform':       1-D CNN over [B*N, in_channels, T]
         - input_mode='image':          2-D CNN over [B*N, C, H, W]
         - input_mode='image+waveform': both encoders, fused by an MLP
      2. Append the per-curve scalars passed as ``sigmas`` and
         ``flux_scales`` (log10 heating rate, log10 peak rate)
      3. SAB: self-attention across heating rates
      4. PMA: attention-based pooling to a single vector
      5. rho MLP: project to the final context

    See ``forward`` for the accepted input shapes per mode.
    Output: context [B, d_context]
    """

    def __init__(self, in_channels=2, d_model=128, d_context=128, n_heads=4,
                 input_mode='waveform', image_in_channels=1):
        super().__init__()
        if input_mode not in ('waveform', 'image', 'image+waveform'):
            raise ValueError(f"Unknown input_mode: {input_mode!r}")
        self.input_mode = input_mode
        if input_mode == 'waveform':
            self.per_cv_encoder = SignalEncoder(
                in_channels=in_channels, d_model=d_model, d_context=d_context,
            )
        elif input_mode == 'image':
            self.per_cv_encoder = ImageEncoder(
                in_channels=image_in_channels, d_model=d_model,
                d_context=d_context,
            )
        else:
            # Joint mode: parallel image + waveform encoders + fusion MLP.
            self.image_encoder = ImageEncoder(
                in_channels=image_in_channels, d_model=d_model,
                d_context=d_context,
            )
            self.waveform_encoder = SignalEncoder(
                in_channels=in_channels, d_model=d_model, d_context=d_context,
            )
            self.joint_fusion = nn.Sequential(
                nn.Linear(2 * d_context, d_context),
                nn.GELU(),
                nn.Linear(d_context, d_context),
            )
        self.cv_augment = nn.Sequential(
            nn.Linear(d_context + 2, d_context),
            nn.GELU(),
        )
        self.sab = SAB(d_context, n_heads=n_heads)
        self.pma = PMA(d_context, n_heads=n_heads, n_seeds=1)
        self.rho = nn.Sequential(
            nn.Linear(d_context, d_context),
            nn.GELU(),
            nn.Linear(d_context, d_context),
        )

    @staticmethod
    def _curve_padding_mask(mask):
        """Convert a presence/validity mask to a [B, N] key-padding mask.

        Accepts a per-curve presence flag [B, N] or a per-timestep validity
        mask [B, N, T] (a curve is present if any timestep is valid).
        Returns None when ``mask`` is None; True marks an absent curve.
        """
        if mask is None:
            return None
        if mask.dim() == 3:
            return ~mask.any(dim=-1)
        return ~mask.bool()

    def _aggregate(self, h, sigmas, flux_scales, cv_invalid):
        """Pool per-curve embeddings ``h`` [B, N, d_context] to [B, d_context].

        Appends the two per-curve scalars (zeros when not supplied), then runs
        cv_augment -> SAB -> PMA -> rho, with ``cv_invalid`` as key-padding mask.
        """
        B, N = h.shape[0], h.shape[1]
        if sigmas is None:
            sigmas = torch.zeros(B, N, device=h.device)
        if flux_scales is None:
            flux_scales = torch.zeros(B, N, device=h.device)
        aug_features = torch.stack([sigmas, flux_scales], dim=-1)
        h = self.cv_augment(torch.cat([h, aug_features], dim=-1))
        h = self.sab(h, key_padding_mask=cv_invalid)
        h = self.pma(h, key_padding_mask=cv_invalid)  # [B, 1, d_context]
        return self.rho(h.squeeze(1))

    def forward(self, x, scan_mask=None, sigmas=None, flux_scales=None):
        """Encode a batch of curve sets into context vectors.

        Waveform mode:
            x:         [B, N_beta, in_channels, T] multi-heating-rate TPD curves
            scan_mask: [B, N_beta, T] valid-timestep mask (optional)
        Image mode:
            x:         [B, N_beta, C, H, W] plot images in [0, 1]
            scan_mask: [B, N_beta] per-curve presence, or [B, N_beta, T]
                       (reduced via .any(dim=-1))
        Joint mode ('image+waveform'):
            x:         dict with 'image' [B, N_beta, C, H, W] and 'waveform'
                       [B, N_beta, in_channels, T]; optional
                       'scan_mask_image' (either image-mask form) and
                       'scan_mask_waveform' [B, N_beta, T] (defaults to
                       ``scan_mask``). The image mask, when given, decides
                       which curves are padding.
        All modes:
            sigmas:      [B, N_beta] log10 heating rates (zeros if None)
            flux_scales: [B, N_beta] log10(peak_rate) per curve (zeros if None)

        Returns:
            context: [B, d_context]

        Raises:
            ValueError: on a non-dict joint-mode input or wrongly shaped tensors.
        """
        if self.input_mode == 'image':
            return self._forward_image(x, scan_mask=scan_mask,
                                       sigmas=sigmas, flux_scales=flux_scales)
        if self.input_mode == 'image+waveform':
            if not isinstance(x, dict):
                raise ValueError(
                    "image+waveform mode expects x to be a dict with keys "
                    "'image' and 'waveform'; got tensor"
                )
            return self._forward_joint(
                x['image'], x['waveform'],
                scan_mask_image=x.get('scan_mask_image'),
                scan_mask_waveform=x.get('scan_mask_waveform', scan_mask),
                sigmas=sigmas, flux_scales=flux_scales,
            )
        B, N, C, T = x.shape
        x_flat = x.reshape(B * N, C, T)
        mask_flat = scan_mask.reshape(B * N, T) if scan_mask is not None else None
        h = self.per_cv_encoder(x_flat, mask=mask_flat).reshape(B, N, -1)
        return self._aggregate(h, sigmas, flux_scales,
                               self._curve_padding_mask(scan_mask))

    def _forward_joint(self, x_image, x_waveform, scan_mask_image=None,
                       scan_mask_waveform=None, sigmas=None, flux_scales=None):
        """Joint image+waveform forward; mirrors MultiScanEncoder._forward_joint.

        Each curve is encoded by both encoders; the two embeddings are fused
        by ``joint_fusion`` before set aggregation.
        """
        if x_image.dim() != 5:
            raise ValueError(
                f"Joint mode x_image expected [B,N,1,H,W]; got {tuple(x_image.shape)}"
            )
        if x_waveform.dim() != 4:
            raise ValueError(
                f"Joint mode x_waveform expected [B,N,2,T]; got {tuple(x_waveform.shape)}"
            )
        B, N = x_image.shape[0], x_image.shape[1]
        # Without this check, shapes like [2,3,...] vs [3,2,...] reshape
        # to the same B*N and silently pair the wrong curves.
        if tuple(x_waveform.shape[:2]) != (B, N):
            raise ValueError(
                "Joint mode x_image and x_waveform must share [B, N]; got "
                f"{tuple(x_image.shape[:2])} vs {tuple(x_waveform.shape[:2])}"
            )

        h_img_flat = self.image_encoder(x_image.reshape(B * N, *x_image.shape[2:]))

        C_w, T_w = x_waveform.shape[2], x_waveform.shape[3]
        wave_mask_flat = (
            scan_mask_waveform.reshape(B * N, T_w)
            if scan_mask_waveform is not None else None
        )
        h_wave_flat = self.waveform_encoder(
            x_waveform.reshape(B * N, C_w, T_w), mask=wave_mask_flat,
        )

        h = self.joint_fusion(
            torch.cat([h_img_flat, h_wave_flat], dim=-1)
        ).reshape(B, N, -1)

        # Image presence mask wins; accepts [B, N] or [B, N, T] like image mode.
        mask_source = scan_mask_image if scan_mask_image is not None else scan_mask_waveform
        return self._aggregate(h, sigmas, flux_scales,
                               self._curve_padding_mask(mask_source))

    def _forward_image(self, x, scan_mask=None, sigmas=None, flux_scales=None):
        """Image-mode forward; mirrors MultiScanEncoder._forward_image."""
        if x.dim() != 5:
            raise ValueError(
                f"Image mode expects x of shape [B, N, C, H, W]; got {tuple(x.shape)}"
            )
        B, N, C, H, W = x.shape
        h = self.per_cv_encoder(x.reshape(B * N, C, H, W)).reshape(B, N, -1)
        return self._aggregate(h, sigmas, flux_scales,
                               self._curve_padding_mask(scan_mask))
class MultiMechanismFlowTPD(nn.Module):
    """
    Joint mechanism identification and parameter inference model for TPD.

    Combines:
    - a multi-heating-rate signal encoder (Set Transformer over per-curve
      embeddings; waveform, image or joint image+waveform input),
    - a mechanism classifier over ``mechanism_list`` (defaults to
      TPD_MECHANISM_LIST),
    - one normalizing-flow head per mechanism (BoundedMechanismFlow when
      ``use_bounded_flow=True``, otherwise MechanismFlow).

    If ``use_summary_features=True`` the signal encoder is replaced by a
    projection from hand-crafted summary statistics (SUMMARY_DIM-dim) to
    context space, keeping all other components identical.
    """

    def __init__(
        self,
        d_context=128,
        d_model=128,
        n_coupling_layers=6,
        hidden_dim=96,
        coupling_type='spline',
        n_bins=8,
        tail_bound=5.0,
        use_summary_features=False,
        mechanism_list=None,
        use_bounded_flow=False,
        input_mode='waveform',
        image_in_channels=1,
    ):
        super().__init__()
        mech_list = mechanism_list if mechanism_list is not None else TPD_MECHANISM_LIST
        self.n_mechanisms = len(mech_list)
        self.mechanism_list = mech_list
        self.d_context = d_context
        self.use_summary_features = use_summary_features
        self.use_bounded_flow = use_bounded_flow
        self.input_mode = input_mode
        if use_summary_features:
            self.summary_proj = SummaryProjection(
                summary_dim=SUMMARY_DIM, d_context=d_context,
            )
            self.encoder = None
        else:
            self.encoder = MultiScanEncoderTPD(
                in_channels=2, d_model=d_model, d_context=d_context,
                input_mode=input_mode,
                image_in_channels=image_in_channels,
            )
            self.summary_proj = None
        self.classifier = MechanismClassifier(
            d_context=d_context,
            n_mechanisms=self.n_mechanisms,
            hidden_dim=hidden_dim,
        )
        self.flow_heads = nn.ModuleDict()
        for mech in mech_list:
            theta_dim = TPD_MECHANISM_PARAMS[mech]['dim']
            if use_bounded_flow:
                self.flow_heads[mech] = BoundedMechanismFlow(
                    theta_dim=theta_dim,
                    param_bijectors=build_param_bijectors(mech),
                    d_context=d_context,
                    n_coupling_layers=n_coupling_layers,
                    hidden_dim=hidden_dim,
                    coupling_type=coupling_type,
                    n_bins=n_bins,
                    tail_bound=tail_bound,
                )
            else:
                self.flow_heads[mech] = MechanismFlow(
                    theta_dim=theta_dim,
                    d_context=d_context,
                    n_coupling_layers=n_coupling_layers,
                    hidden_dim=hidden_dim,
                    coupling_type=coupling_type,
                    n_bins=n_bins,
                    tail_bound=tail_bound,
                )
        self.ood_head = None
        self.ood_head_extra_dim = 0

    def init_ood_head(self, hidden_dim=64, dropout=0.3, extra_input_dim=0):
        """Initialize the binary OOD detection head.

        Input is the context vector + softmax probs (+ ``extra_input_dim``
        optional extra features); output is a logit for P(in-distribution).
        Mirrors ``MultiMechanismFlow.init_ood_head`` so the same training
        recipe (``train_ood_head_tpd.py``) can be applied.

        Returns:
            The newly created head (also stored as ``self.ood_head``).
        """
        input_dim = self.d_context + self.n_mechanisms + extra_input_dim
        self.ood_head = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )
        self.ood_head_extra_dim = extra_input_dim
        return self.ood_head

    def set_theta_stats(self, mechanism, mean, std):
        """Set normalization stats for one mechanism's flow head.

        IMPORTANT: when the model was constructed with use_bounded_flow=True,
        ``mean`` and ``std`` must already be in U-SPACE (post-bijector). Use
        ``compute_stats_for_mechanism`` to convert raw physical samples.
        """
        self.flow_heads[mechanism].set_theta_stats(mean, std)

    def compute_stats_for_mechanism(self, mechanism, physical_thetas):
        """Compute mean/std for ``set_theta_stats`` from raw physical samples.

        - use_bounded_flow=False: physical-space mean/std.
        - use_bounded_flow=True:  the per-parameter bijectors are applied to
          ``physical_thetas`` first, and u-space mean/std are returned.

        Args:
            mechanism: mechanism name in self.mechanism_list.
            physical_thetas: [N, theta_dim] training-set thetas in their
                natural physical units.

        Returns:
            (mean, std) tensors of shape [theta_dim]. std is floored at 0.1;
            with a single sample it is all ones.

        Raises:
            ValueError: if ``physical_thetas`` has no rows (the mean would be
                NaN and silently poison the flow's normalization).
        """
        if physical_thetas.shape[0] == 0:
            raise ValueError(
                f"compute_stats_for_mechanism({mechanism!r}): physical_thetas is empty"
            )
        t = physical_thetas
        if self.use_bounded_flow:
            head = self.flow_heads[mechanism]
            t, _ = apply_param_bijectors_forward(t, list(head.param_bijectors))
        mean = t.mean(dim=0)
        std = t.std(dim=0).clamp(min=0.1) if t.shape[0] > 1 else torch.ones_like(mean)
        return mean, std

    def encode_signal(self, x, scan_mask=None, sigmas=None, flux_scales=None,
                      summary=None):
        """Map the input (or summary features) to context vectors [B, d_context].

        In summary mode ``x`` is ignored and ``summary`` is required.

        Raises:
            ValueError: if summary mode is active and ``summary`` is None.
        """
        if self.use_summary_features:
            # Explicit raise (not assert) so the check survives `python -O`.
            if summary is None:
                raise ValueError("summary features required in summary mode")
            return self.summary_proj(summary)
        return self.encoder(x, scan_mask=scan_mask, sigmas=sigmas,
                            flux_scales=flux_scales)

    @staticmethod
    def _batch_meta(x):
        """Return (batch_size, device) from x, which may be a dict (joint mode)
        or a tensor (single-modality). Retained for external callers."""
        if isinstance(x, dict):
            ref = x['image'] if 'image' in x else x['waveform']
            return ref.shape[0], ref.device
        return x.shape[0], x.device

    @staticmethod
    def _resolve_temperature(mech, temperature, temperature_map):
        """Per-parameter temperature tensor from ``temperature_map`` if it has
        ``mech``; otherwise the scalar fallback ``temperature``."""
        if temperature_map is not None and mech in temperature_map:
            return torch.tensor(temperature_map[mech], dtype=torch.float32)
        return temperature

    @staticmethod
    def _sample_stats(samples):
        """Summary statistics over the sample axis (dim 1) of ``samples``."""
        return {
            'mean': samples.mean(dim=1),
            'std': samples.std(dim=1),
            'median': samples.median(dim=1).values,
            'q05': samples.quantile(0.05, dim=1),
            'q95': samples.quantile(0.95, dim=1),
        }

    def _calibration_terms(self, samples, theta, cal_levels, cal_beta):
        """One coverage-calibration loss per level in ``cal_levels``.

        ``samples`` is [B_m, S, D] and ``theta`` is [B_m, D]. Per-parameter
        squared coverage error is weighted by the inverse posterior spread
        (normalized to mean 1, no gradient through the weights).
        """
        with torch.no_grad():
            param_std = samples.std(dim=1).clamp(min=1e-4)        # [B_m, D]
            inv_spread_w = 1.0 / param_std
            inv_spread_w = inv_spread_w / inv_spread_w.mean()
        terms = []
        for level in cal_levels:
            alpha = (1.0 - level) / 2.0
            lower = torch.quantile(samples, alpha, dim=1)          # [B_m, D]
            upper = torch.quantile(samples, 1 - alpha, dim=1)      # [B_m, D]
            inside = self._straight_through_containment(
                theta, lower, upper, cal_beta)                     # [B_m, D]
            per_sample_loss = (inside - level).pow(2)
            terms.append((per_sample_loss * inv_spread_w).mean())
        return terms

    @staticmethod
    def _min_var_penalty(samples, min_log_var):
        """Hinge penalty on per-parameter log-variance below ``min_log_var``."""
        log_var = torch.log(samples.var(dim=1) + 1e-8)
        return torch.relu(min_log_var - log_var).mean()

    def _forward_impl(self, x, mechanism_ids, mech_theta, mech_theta_mask=None,
                      scan_mask=None, sigmas=None, flux_scales=None, summary=None,
                      return_calibration=False, cal_n_samples=64,
                      cal_levels=(0.5, 0.9), cal_beta=20.0,
                      return_context=False,
                      return_min_var=False, min_var_n_samples=32,
                      min_log_var=-6.0):
        """Shared implementation of ``forward`` / ``forward_with_calibration``.

        ``mech_theta_mask`` is accepted for API compatibility but unused:
        each mechanism reads the first ``dim`` columns of ``mech_theta``.
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        logits = self.classifier(context)
        # Derive batch size/device from the context rather than `x`: in
        # summary mode `x` is ignored and may legitimately be None.
        B, device = context.shape[0], context.device
        nll = torch.zeros(B, device=device)
        cal_losses = []
        mv_penalties = []
        for m_idx, mech in enumerate(self.mechanism_list):
            sel = (mechanism_ids == m_idx)
            if not sel.any():
                continue
            theta_dim = TPD_MECHANISM_PARAMS[mech]['dim']
            ctx_m = context[sel]
            theta_m = mech_theta[sel, :theta_dim]
            log_p = self.flow_heads[mech].log_prob(theta_m, ctx_m)
            bad = ~torch.isfinite(log_p)
            if bad.any():
                # Replace non-finite values with a constant so one bad sample
                # cannot poison the batch loss.
                log_p = torch.where(bad, torch.full_like(log_p, -10.0).detach(), log_p)
            nll[sel] = -log_p

            # Sample-based losses need enough rows for stable statistics.
            need_cal = return_calibration and ctx_m.shape[0] >= 4
            need_mv = return_min_var and ctx_m.shape[0] >= 2
            if need_cal:
                samples = self.flow_heads[mech].sample_with_grad(
                    ctx_m, n_samples=cal_n_samples,
                )
                cal_losses.extend(
                    self._calibration_terms(samples, theta_m, cal_levels, cal_beta)
                )
                if need_mv:
                    # Reuse the calibration samples for the variance penalty.
                    mv_penalties.append(self._min_var_penalty(samples, min_log_var))
            elif need_mv:
                mv_samples = self.flow_heads[mech].sample_with_grad(
                    ctx_m, n_samples=min_var_n_samples,
                )
                mv_penalties.append(self._min_var_penalty(mv_samples, min_log_var))

        out = {'logits': logits, 'nll': nll}
        if return_calibration:
            out['cal_loss'] = (torch.stack(cal_losses).mean() if cal_losses
                               else torch.tensor(0.0, device=device))
        if return_min_var:
            out['min_var_loss'] = (torch.stack(mv_penalties).mean() if mv_penalties
                                   else torch.tensor(0.0, device=device))
        if return_context:
            out['context'] = context
        return out

    def forward(self, x, mechanism_ids, mech_theta, mech_theta_mask=None,
                scan_mask=None, sigmas=None, flux_scales=None, summary=None,
                return_calibration=False, cal_n_samples=64,
                cal_levels=(0.5, 0.9), cal_beta=20.0,
                return_context=False,
                return_min_var=False, min_var_n_samples=32,
                min_log_var=-6.0):
        """
        Compute classification logits and per-sample NLL for the true mechanism.

        Returns:
            dict with 'logits' [B, n_mechanisms] and 'nll' [B], plus
            'cal_loss' / 'min_var_loss' / 'context' when requested.
        """
        return self._forward_impl(
            x, mechanism_ids, mech_theta, mech_theta_mask,
            scan_mask=scan_mask, sigmas=sigmas, flux_scales=flux_scales,
            summary=summary,
            return_calibration=return_calibration,
            cal_n_samples=cal_n_samples,
            cal_levels=cal_levels,
            cal_beta=cal_beta,
            return_context=return_context,
            return_min_var=return_min_var,
            min_var_n_samples=min_var_n_samples,
            min_log_var=min_log_var,
        )

    @staticmethod
    def _straight_through_containment(theta, lower, upper, beta):
        """Hard containment indicator with straight-through gradient estimator.

        Forward: hard 0/1 indicator (unbiased coverage estimate).
        Backward: sigmoid gradient (smooth, trainable).
        """
        soft = (torch.sigmoid(beta * (theta - lower))
                * torch.sigmoid(beta * (upper - theta)))
        hard = ((theta >= lower) & (theta <= upper)).float()
        return hard + (soft - soft.detach())  # STE: hard forward, soft backward

    def forward_with_calibration(self, x, mechanism_ids, mech_theta,
                                 mech_theta_mask=None, scan_mask=None,
                                 sigmas=None, flux_scales=None,
                                 cal_n_samples=64, cal_levels=(0.5, 0.9),
                                 cal_beta=20.0, summary=None):
        """Forward pass with additional calibration loss (see MultiMechanismFlow)."""
        return self._forward_impl(
            x, mechanism_ids, mech_theta, mech_theta_mask,
            scan_mask=scan_mask, sigmas=sigmas, flux_scales=flux_scales,
            summary=summary,
            return_calibration=True,
            cal_n_samples=cal_n_samples,
            cal_levels=cal_levels,
            cal_beta=cal_beta,
        )

    @torch.no_grad()
    def predict(self, x, scan_mask=None, sigmas=None, flux_scales=None,
                n_samples=200, top_k=None, temperature=1.0,
                temperature_map=None, summary=None, ood_extra_features=None):
        """
        Full inference: classify mechanism, then sample parameters.

        Args:
            top_k: if set, only mechanisms in any row's top-k are sampled
                (others get None); values above n_mechanisms are clamped.
            temperature: scalar fallback (>1 broadens posteriors).
            temperature_map: dict mapping mechanism name -> list of
                per-parameter temperatures. Overrides scalar temperature.
            summary: [B, SUMMARY_DIM] hand-crafted summary stats
                (summary mode only).
            ood_extra_features: [B, extra_input_dim] extra OOD-head inputs;
                required when the head was built with extra_input_dim > 0.

        Returns:
            dict with mechanism_probs, mechanism_xdB, mechanism_pred,
            samples, stats and ood_score (None without an OOD head).

        Raises:
            ValueError: if the OOD head expects extra features and none are given.
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        logits = self.classifier(context)
        probs = F.softmax(logits, dim=-1)
        pred = probs.argmax(dim=-1)
        probs_clamped = probs.clamp(min=1e-7, max=1 - 1e-7)
        xdB = 10.0 * torch.log10(probs_clamped / (1.0 - probs_clamped))

        # Indices of mechanisms in any row's top-k (None = sample them all).
        active = None
        if top_k is not None:
            k = min(int(top_k), probs.shape[-1])  # topk fails if k > n_classes
            active = set(probs.topk(k, dim=-1).indices.flatten().tolist())

        samples_dict = {}
        stats_dict = {}
        for m_idx, mech in enumerate(self.mechanism_list):
            if active is not None and m_idx not in active:
                samples_dict[mech] = None
                stats_dict[mech] = None
                continue
            T = self._resolve_temperature(mech, temperature, temperature_map)
            s = self.flow_heads[mech].sample(context, n_samples=n_samples,
                                             temperature=T)
            samples_dict[mech] = s
            stats_dict[mech] = self._sample_stats(s)

        ood_score = None
        if self.ood_head is not None:
            ood_parts = [context, probs]
            if ood_extra_features is not None:
                ood_parts.append(ood_extra_features.to(context))
            elif self.ood_head_extra_dim:
                raise ValueError(
                    f"OOD head expects {self.ood_head_extra_dim} extra input "
                    "features; pass ood_extra_features"
                )
            ood_logit = self.ood_head(torch.cat(ood_parts, dim=-1)).squeeze(-1)
            ood_score = torch.sigmoid(ood_logit)

        return {
            'mechanism_probs': probs,
            'mechanism_xdB': xdB,
            'mechanism_pred': pred,
            'samples': samples_dict,
            'stats': stats_dict,
            'ood_score': ood_score,
        }

    @torch.no_grad()
    def predict_single_mechanism(self, x, mechanism, scan_mask=None,
                                 sigmas=None, flux_scales=None, n_samples=1000,
                                 temperature=1.0, temperature_map=None,
                                 summary=None):
        """Sample parameters assuming a known mechanism.

        Runs without autograd, like ``predict``. Returns the summary
        statistics ('mean', 'std', 'median', 'q05', 'q95') plus the raw
        'samples' [B, n_samples, theta_dim].
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        T = self._resolve_temperature(mechanism, temperature, temperature_map)
        samples = self.flow_heads[mechanism].sample(context, n_samples=n_samples,
                                                    temperature=T)
        result = self._sample_stats(samples)
        result['samples'] = samples
        return result
def count_parameters(model):
    """Return the number of trainable (requires_grad) scalar parameters in ``model``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
if __name__ == "__main__":
    # Smoke test: one synthetic example per mechanism, waveform input mode.
    mech_names = TPD_MECHANISM_LIST
    n_mechs = len(mech_names)
    B, N_beta, T = n_mechs, 3, 500

    x = torch.randn(B, N_beta, 2, T)
    scan_mask = torch.ones(B, N_beta, T, dtype=torch.bool)
    sigmas = torch.randn(B, N_beta)
    flux_scales = torch.randn(B, N_beta)
    mechanism_ids = torch.arange(n_mechs)

    # Row i belongs to mechanism i, so its valid theta prefix has that
    # mechanism's parameter dimension.
    theta_dims = [TPD_MECHANISM_PARAMS[name]['dim'] for name in mech_names]
    max_dim = max(theta_dims)
    mech_theta = torch.randn(B, max_dim)
    mech_theta_mask = torch.zeros(B, max_dim, dtype=torch.bool)
    for row, dim in enumerate(theta_dims):
        mech_theta_mask[row, :dim] = True

    rule = "=" * 60
    print(rule)
    print("Testing MultiMechanismFlowTPD (multi-heating-rate, Set Transformer)")
    print(rule)

    model = MultiMechanismFlowTPD(
        d_context=128,
        d_model=128,
        n_coupling_layers=8,
        hidden_dim=128,
        coupling_type='affine',
    )
    n_total = count_parameters(model)
    print(f"Total parameters: {n_total:,}")
    print(f" Encoder: {count_parameters(model.encoder):,}")
    print(f" Classifier: {count_parameters(model.classifier):,}")
    for name, dim in zip(mech_names, theta_dims):
        n_flow = count_parameters(model.flow_heads[name])
        print(f" Flow ({name}, dim={dim}): {n_flow:,}")

    out = model(
        x, mechanism_ids, mech_theta, mech_theta_mask,
        scan_mask=scan_mask, sigmas=sigmas, flux_scales=flux_scales,
    )
    print("\nForward pass:")
    print(f" Logits shape: {out['logits'].shape}")
    print(f" NLL shape: {out['nll'].shape}")
    print(f" NLL values: {out['nll']}")

    pred = model.predict(
        x, scan_mask=scan_mask, sigmas=sigmas,
        flux_scales=flux_scales, n_samples=100,
    )
    print("\nPrediction:")
    print(f" Mechanism probs shape: {pred['mechanism_probs'].shape}")
    print(f" Mechanism xdB shape: {pred['mechanism_xdB'].shape}")
    print(f" Predicted mechanisms: {pred['mechanism_pred']}")
    for name in mech_names:
        drawn = pred['samples'][name]
        if drawn is not None:
            print(f" {name} samples shape: {drawn.shape}")