| """ |
| Multi-Mechanism Normalizing Flow for Joint Mechanism Identification |
| and Bayesian Parameter Inference from Multi-Heating-Rate TPD Signals. |
| |
| Architecture mirrors the electrochemistry model (multi_mechanism_model.py) |
| but configured for TPD with 2 input channels (temperature, rate) and |
| 6 catalysis mechanisms. |
| |
| Reuses domain-agnostic components from flow_model.py and |
| multi_mechanism_model.py (SAB, PMA, MechanismClassifier, MechanismFlow). |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
|
|
| from flow_model import ( |
| SignalEncoder, |
| ActNorm, |
| ConditionalSplineCoupling, |
| ConditionalAffineCoupling, |
| ) |
| from multi_mechanism_model import MechanismClassifier, MechanismFlow, SAB, PMA, SummaryProjection, SUMMARY_DIM |
| from image_encoder import ImageEncoder |
|
|
| from generate_tpd_data import TPD_MECHANISM_LIST, TPD_MECHANISM_PARAMS |
|
|
| from tpd_bijectors import ( |
| Identity, Logit, AffineLogit, |
| apply_param_bijectors_forward, apply_param_bijectors_inverse, |
| all_identity, |
| ) |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# Per-parameter reparameterization specs, keyed by parameter name.
#
# Each value is a tuple whose first element names the bijector kind that
# `_make_bijector` understands; any remaining elements are that bijector's
# constructor arguments. Parameter names that do not appear here resolve to
# the Identity bijector (the lookup in `build_param_bijectors` uses `.get`).
TPD_PARAM_BIJECTORS = {
    # Parameters reparameterized with a plain Logit bijector
    # (presumably fractions living in (0, 1) -- TODO confirm against the
    # data generator).
    **dict.fromkeys(
        ('theta_0', 'theta_A0', 'theta_B0', 'theta_O0', 'f_site1'),
        ('logit',),
    ),
    # Reparameterized with AffineLogit(1.0, 10.0)
    # (presumably the lower/upper bounds of the parameter -- TODO confirm).
    'n_layers': ('affine_logit', 1.0, 10.0),
}
|
|
|
|
def _make_bijector(spec):
    """Instantiate the bijector described by `spec`.

    Args:
        spec: None, or a tuple whose first element is the bijector kind:
            ('identity',), ('logit',) or ('affine_logit', low, high).

    Returns:
        An Identity, Logit or AffineLogit instance. `None` and
        ('identity',) both yield Identity.

    Raises:
        ValueError: if the kind is not one of the three known kinds.
    """
    if spec is not None:
        tag = spec[0]
        if tag == 'affine_logit':
            # The spec carries the two AffineLogit constructor arguments.
            _, lo, hi = spec
            return AffineLogit(lo, hi)
        if tag == 'logit':
            return Logit()
        if tag != 'identity':
            raise ValueError(f"Unknown bijector spec: {spec}")
    # Reached for spec=None and for an explicit ('identity',) spec.
    return Identity()
|
|
|
|
def build_param_bijectors(mech: str):
    """Build the per-parameter bijectors for mechanism `mech`.

    The returned list is aligned with TPD_MECHANISM_PARAMS[mech]['names']:
    entry i is the bijector for the i-th parameter name. Names without an
    entry in TPD_PARAM_BIJECTORS get whatever `_make_bijector(None)` returns.
    """
    bijectors = []
    for param_name in TPD_MECHANISM_PARAMS[mech]['names']:
        spec = TPD_PARAM_BIJECTORS.get(param_name)
        bijectors.append(_make_bijector(spec))
    return bijectors
|
|
|
|
| |
| |
| |
|
|
class BoundedMechanismFlow(MechanismFlow):
    """MechanismFlow + per-parameter reparameterization bijectors.

    Density of the physical parameter `theta` is computed in the
    unbounded representation `u = bij(theta)`, then transformed back
    via the change-of-variables rule:

        log p_theta(theta) = log p_u(u) + log|du/dtheta|

    where `log p_u(u)` is exactly what the base MechanismFlow would
    compute on `u` (z-score normalize -> flow -> Gaussian base).

    `theta_mean` and `theta_std` (inherited from MechanismFlow) now
    refer to U-SPACE statistics, not physical-space statistics. Callers
    must compute them in u-space (see
    `MultiMechanismFlowTPD.compute_stats_for_mechanism`).
    """

    def __init__(self, theta_dim, param_bijectors, **kwargs):
        """
        Args:
            theta_dim: number of parameters of the mechanism.
            param_bijectors: sequence of `theta_dim` bijector modules, one
                per parameter, in parameter order.
            **kwargs: forwarded unchanged to MechanismFlow.__init__.
        """
        super().__init__(theta_dim=theta_dim, **kwargs)
        assert len(param_bijectors) == theta_dim, (
            f"BoundedMechanismFlow expects {theta_dim} bijectors, got {len(param_bijectors)}"
        )
        # Registered as a ModuleList so the bijectors follow .to()/state_dict.
        self.param_bijectors = nn.ModuleList(param_bijectors)
        # When every bijector is Identity all methods defer to the base class.
        self._bijector_is_noop = all_identity(self.param_bijectors)

    def log_prob(self, theta, context):
        """Return log p(theta | context) for physical-space `theta`.

        Args:
            theta: physical parameters, one row per sample.
            context: conditioning vectors, one row per sample.

        Returns:
            Per-sample log-density, clamped to [-50, 50].
        """
        if self._bijector_is_noop:
            return super().log_prob(theta, context)

        # theta -> u, together with log|du/dtheta|.
        u, log_det_b = apply_param_bijectors_forward(theta, list(self.param_bijectors))

        # Same pipeline as the base flow, but applied to u.
        u_norm = self.normalize_theta(u)
        if self.coupling_type == 'spline':
            u_norm = u_norm.clamp(-self.tail_bound, self.tail_bound)
        z, log_det_flow = self.inverse_flow(u_norm, context)

        # Standard-normal base density plus the Jacobian of the z-scoring.
        log_pz = -0.5 * (z ** 2 + math.log(2 * math.pi)).sum(dim=-1)
        log_det_norm = -torch.log(self.theta_std).sum()

        log_p = log_pz + log_det_flow + log_det_norm + log_det_b
        return log_p.clamp(min=-50.0, max=50.0)

    def _draw_u_samples(self, context, n_samples):
        """Draw `n_samples` u-space samples per context row.

        Returns a flat [B * n_samples, theta_dim] tensor; rows belonging to
        the same context are contiguous.
        """
        B = context.shape[0]
        context_rep = (context.unsqueeze(1)
                       .expand(-1, n_samples, -1)
                       .reshape(B * n_samples, -1))
        z = torch.randn(B * n_samples, self.theta_dim, device=context.device)
        u_norm, _ = self.forward_flow(z, context_rep)
        return self.denormalize_theta(u_norm)

    @torch.no_grad()
    def sample(self, context, n_samples=100, temperature=1.0):
        """Sample physical-space parameters without tracking gradients.

        Args:
            context: [B, d_context] conditioning vectors.
            n_samples: samples drawn per context row.
            temperature: scalar, or a tensor of per-parameter factors.
                Samples are rescaled about their per-context mean in
                u-space by this factor before mapping back to theta.

        Returns:
            Tensor of shape [B, n_samples, theta_dim].
        """
        if self._bijector_is_noop:
            return super().sample(context, n_samples=n_samples, temperature=temperature)

        B = context.shape[0]
        u = self._draw_u_samples(context, n_samples).reshape(B, n_samples, self.theta_dim)

        # Temperature scaling happens in u-space, before the inverse
        # bijectors map the samples back to physical space.
        if isinstance(temperature, torch.Tensor):
            scale = temperature.to(u.device).reshape(1, 1, -1)
        elif temperature != 1.0:
            scale = temperature
        else:
            scale = None
        if scale is not None:
            center = u.mean(dim=1, keepdim=True)
            u = center + scale * (u - center)

        u_flat = u.reshape(B * n_samples, self.theta_dim)
        theta_flat, _ = apply_param_bijectors_inverse(u_flat, list(self.param_bijectors))
        return theta_flat.reshape(B, n_samples, self.theta_dim)

    def sample_with_grad(self, context, n_samples=64):
        """Sample physical-space parameters with autograd enabled.

        Same as `sample` with temperature 1, but gradients can flow from the
        samples back to the flow and to `context`.

        Returns:
            Tensor of shape [B, n_samples, theta_dim].
        """
        if self._bijector_is_noop:
            return super().sample_with_grad(context, n_samples=n_samples)

        B = context.shape[0]
        u = self._draw_u_samples(context, n_samples)
        theta, _ = apply_param_bijectors_inverse(u, list(self.param_bijectors))
        return theta.reshape(B, n_samples, self.theta_dim)
|
|
|
|
class MultiScanEncoderTPD(nn.Module):
    """
    Encode a set of multi-heating-rate TPD curves into a single context vector.

    Architecture (Set Transformer):
    1. Shared per-curve encoder -> per-curve embedding
       - input_mode='waveform': 1-D CNN over [B*N, in_channels, T]
       - input_mode='image': 2-D CNN over [B*N, 1, H, W]
       - input_mode='image+waveform': both encoders, fused by an MLP
    2. Augment with [log10(heating_rate), log10(peak_rate)]
    3. SAB: self-attention across heating rates
    4. PMA: attention-based pooling to single vector
    5. rho MLP: project to final context

    Waveform input: x [B, N_beta, 2, T], scan_mask [B, N_beta, T],
        heating_rates [B, N_beta], rate_scales [B, N_beta]
    Image input: x [B, N_beta, 1, H, W], scan_mask [B, N_beta]
        (per-curve presence flag), heating_rates [B, N_beta],
        rate_scales [B, N_beta]
    Joint input: x is a dict with keys 'image' and 'waveform' (and
        optionally 'scan_mask_image' / 'scan_mask_waveform').
    Output: context [B, d_context]
    """

    def __init__(self, in_channels=2, d_model=128, d_context=128, n_heads=4,
                 input_mode='waveform', image_in_channels=1):
        super().__init__()
        if input_mode not in ('waveform', 'image', 'image+waveform'):
            raise ValueError(f"Unknown input_mode: {input_mode!r}")
        self.input_mode = input_mode
        if input_mode == 'waveform':
            self.per_cv_encoder = SignalEncoder(
                in_channels=in_channels, d_model=d_model, d_context=d_context,
            )
        elif input_mode == 'image':
            self.per_cv_encoder = ImageEncoder(
                in_channels=image_in_channels, d_model=d_model,
                d_context=d_context,
            )
        else:
            # Joint mode: one encoder per modality, fused per curve.
            self.image_encoder = ImageEncoder(
                in_channels=image_in_channels, d_model=d_model,
                d_context=d_context,
            )
            self.waveform_encoder = SignalEncoder(
                in_channels=in_channels, d_model=d_model, d_context=d_context,
            )
            self.joint_fusion = nn.Sequential(
                nn.Linear(2 * d_context, d_context),
                nn.GELU(),
                nn.Linear(d_context, d_context),
            )
        # +2 for the two scalar per-curve features appended to each embedding.
        self.cv_augment = nn.Sequential(
            nn.Linear(d_context + 2, d_context),
            nn.GELU(),
        )
        self.sab = SAB(d_context, n_heads=n_heads)
        self.pma = PMA(d_context, n_heads=n_heads, n_seeds=1)
        self.rho = nn.Sequential(
            nn.Linear(d_context, d_context),
            nn.GELU(),
            nn.Linear(d_context, d_context),
        )

    @staticmethod
    def _curve_padding_mask(mask):
        """Turn a validity mask into a per-curve padding mask (True = ignore).

        Accepts either a per-curve mask [B, N] or a per-timestep mask
        [B, N, T]; the latter marks a curve as padding when it has no valid
        timestep at all. Returns None when `mask` is None.
        """
        if mask is None:
            return None
        valid = mask.bool()
        if valid.dim() == 3:
            valid = valid.any(dim=-1)
        return ~valid

    def _aggregate(self, h, sigmas, flux_scales, cv_invalid):
        """Shared tail of every mode: augment -> SAB -> PMA -> rho.

        Args:
            h: [B, N, d_context] per-curve embeddings.
            sigmas, flux_scales: [B, N] scalar per-curve features, or None
                (treated as zeros).
            cv_invalid: [B, N] padding mask (True = ignore curve) or None.

        Returns:
            context: [B, d_context]
        """
        B, N = h.shape[0], h.shape[1]
        if sigmas is None:
            sigmas = torch.zeros(B, N, device=h.device)
        if flux_scales is None:
            flux_scales = torch.zeros(B, N, device=h.device)

        aug_features = torch.stack([sigmas, flux_scales], dim=-1)
        h = self.cv_augment(torch.cat([h, aug_features], dim=-1))

        h = self.sab(h, key_padding_mask=cv_invalid)
        h = self.pma(h, key_padding_mask=cv_invalid)
        # PMA was built with a single seed, so drop that singleton axis.
        h = h.squeeze(1)
        return self.rho(h)

    def forward(self, x, scan_mask=None, sigmas=None, flux_scales=None):
        """
        Waveform args:
            x: [B, N_beta, 2, T] multi-heating-rate TPD curves
            scan_mask: [B, N_beta, T] valid timestep mask
            sigmas: [B, N_beta] log10 heating rates
            flux_scales: [B, N_beta] log10(peak_rate) per curve

        Image args:
            x: [B, N_beta, 1, H, W] grayscale plot images in [0, 1]
            scan_mask: [B, N_beta] per-curve presence (or [B, N_beta, T]
                back-compat, reduced via .any(dim=-1))
            sigmas, flux_scales: same as waveform mode.

        Joint (image+waveform) args:
            x: dict with keys 'image' and 'waveform', optionally
                'scan_mask_image' and 'scan_mask_waveform' (the latter
                defaults to `scan_mask`).

        Returns:
            context: [B, d_context]

        Raises:
            ValueError: if `x` does not have the layout required by the
                configured input mode.
        """
        if self.input_mode == 'image':
            return self._forward_image(x, scan_mask=scan_mask,
                                       sigmas=sigmas, flux_scales=flux_scales)
        if self.input_mode == 'image+waveform':
            if not isinstance(x, dict):
                raise ValueError(
                    "image+waveform mode expects x to be a dict with keys "
                    "'image' and 'waveform'; got tensor"
                )
            return self._forward_joint(
                x['image'], x['waveform'],
                scan_mask_image=x.get('scan_mask_image'),
                scan_mask_waveform=x.get('scan_mask_waveform', scan_mask),
                sigmas=sigmas, flux_scales=flux_scales,
            )

        # Waveform mode.
        if x.dim() != 4:
            raise ValueError(
                f"Waveform mode expects x of shape [B, N, C, T]; got {tuple(x.shape)}"
            )
        B, N, C, T = x.shape

        # Fold the curve axis into the batch so one encoder handles all curves.
        x_flat = x.reshape(B * N, C, T)
        mask_flat = scan_mask.reshape(B * N, T) if scan_mask is not None else None
        h = self.per_cv_encoder(x_flat, mask=mask_flat).reshape(B, N, -1)

        return self._aggregate(h, sigmas, flux_scales,
                               self._curve_padding_mask(scan_mask))

    def _forward_joint(self, x_image, x_waveform, scan_mask_image=None,
                       scan_mask_waveform=None, sigmas=None, flux_scales=None):
        """Joint image+waveform forward; mirrors MultiScanEncoder._forward_joint.

        The image mask, when given, decides which curves are padding;
        otherwise the waveform mask is used. Either mask may be per-curve
        [B, N] or per-timestep [B, N, T].
        """
        if x_image.dim() != 5:
            raise ValueError(
                f"Joint mode x_image expected [B,N,1,H,W]; got {tuple(x_image.shape)}"
            )
        if x_waveform.dim() != 4:
            raise ValueError(
                f"Joint mode x_waveform expected [B,N,2,T]; got {tuple(x_waveform.shape)}"
            )
        B, N = x_image.shape[0], x_image.shape[1]

        x_img_flat = x_image.reshape(B * N, *x_image.shape[2:])
        h_img_flat = self.image_encoder(x_img_flat)

        T_w = x_waveform.shape[-1]
        x_wave_flat = x_waveform.reshape(B * N, x_waveform.shape[2], T_w)
        wave_mask_flat = (
            scan_mask_waveform.reshape(B * N, T_w)
            if scan_mask_waveform is not None else None
        )
        h_wave_flat = self.waveform_encoder(x_wave_flat, mask=wave_mask_flat)

        # Per-curve fusion of the two modality embeddings.
        h_joint_flat = self.joint_fusion(
            torch.cat([h_img_flat, h_wave_flat], dim=-1)
        )
        h = h_joint_flat.reshape(B, N, -1)

        presence = scan_mask_image if scan_mask_image is not None else scan_mask_waveform
        return self._aggregate(h, sigmas, flux_scales,
                               self._curve_padding_mask(presence))

    def _forward_image(self, x, scan_mask=None, sigmas=None, flux_scales=None):
        """Image-mode forward; mirrors MultiScanEncoder._forward_image."""
        if x.dim() != 5:
            raise ValueError(
                f"Image mode expects x of shape [B, N, C, H, W]; got {tuple(x.shape)}"
            )
        B, N, C, H, W = x.shape
        h = self.per_cv_encoder(x.reshape(B * N, C, H, W)).reshape(B, N, -1)
        return self._aggregate(h, sigmas, flux_scales,
                               self._curve_padding_mask(scan_mask))
|
|
|
|
class MultiMechanismFlowTPD(nn.Module):
    """
    Joint mechanism identification and parameter inference model for TPD.

    Combines:
    - Multi-heating-rate signal encoder (Set Transformer over per-curve embeddings)
    - Mechanism classifier (6 TPD mechanisms)
    - Per-mechanism normalizing flow heads

    If use_summary_features=True, replaces the signal encoder with a simple
    MLP projection from hand-crafted summary statistics (21-dim) to context
    space, keeping all other components identical.
    """

    def __init__(
        self,
        d_context=128,
        d_model=128,
        n_coupling_layers=6,
        hidden_dim=96,
        coupling_type='spline',
        n_bins=8,
        tail_bound=5.0,
        use_summary_features=False,
        mechanism_list=None,
        use_bounded_flow=False,
        input_mode='waveform',
        image_in_channels=1,
    ):
        super().__init__()
        mech_list = mechanism_list if mechanism_list is not None else TPD_MECHANISM_LIST
        self.n_mechanisms = len(mech_list)
        self.mechanism_list = mech_list
        self.d_context = d_context
        self.use_summary_features = use_summary_features
        self.use_bounded_flow = use_bounded_flow
        self.input_mode = input_mode

        # Exactly one of `encoder` / `summary_proj` is a module; the other
        # is None.
        if use_summary_features:
            self.summary_proj = SummaryProjection(
                summary_dim=SUMMARY_DIM, d_context=d_context,
            )
            self.encoder = None
        else:
            self.encoder = MultiScanEncoderTPD(
                in_channels=2, d_model=d_model, d_context=d_context,
                input_mode=input_mode,
                image_in_channels=image_in_channels,
            )
            self.summary_proj = None

        self.classifier = MechanismClassifier(
            d_context=d_context,
            n_mechanisms=self.n_mechanisms,
            hidden_dim=hidden_dim,
        )

        # One flow head per mechanism; all heads share these hyperparameters.
        flow_kwargs = dict(
            d_context=d_context,
            n_coupling_layers=n_coupling_layers,
            hidden_dim=hidden_dim,
            coupling_type=coupling_type,
            n_bins=n_bins,
            tail_bound=tail_bound,
        )
        self.flow_heads = nn.ModuleDict()
        for mech in mech_list:
            theta_dim = TPD_MECHANISM_PARAMS[mech]['dim']
            if use_bounded_flow:
                head = BoundedMechanismFlow(
                    theta_dim=theta_dim,
                    param_bijectors=build_param_bijectors(mech),
                    **flow_kwargs,
                )
            else:
                head = MechanismFlow(theta_dim=theta_dim, **flow_kwargs)
            self.flow_heads[mech] = head

        # The OOD head is optional and created later by `init_ood_head`.
        self.ood_head = None
        self.ood_head_extra_dim = 0

    def init_ood_head(self, hidden_dim=64, dropout=0.3, extra_input_dim=0):
        """Initialize the binary OOD detection head.

        Takes context vector + softmax probs (+ optional extra features)
        as input, outputs logit for P(in-distribution).

        Mirrors `MultiMechanismFlow.init_ood_head` so the same training
        recipe (`train_ood_head_tpd.py`) can be applied.

        Returns:
            The newly created head (also stored as `self.ood_head`).
        """
        input_dim = self.d_context + self.n_mechanisms + extra_input_dim
        self.ood_head = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )
        self.ood_head_extra_dim = extra_input_dim
        return self.ood_head

    def set_theta_stats(self, mechanism, mean, std):
        """Set normalization stats for a specific mechanism's flow head.

        IMPORTANT: when this model was constructed with use_bounded_flow=True,
        `mean` and `std` must already be in U-SPACE (post-bijector). Use
        `compute_stats_for_mechanism` below to convert raw physical samples.
        """
        self.flow_heads[mechanism].set_theta_stats(mean, std)

    def compute_stats_for_mechanism(self, mechanism, physical_thetas):
        """Compute mean/std for `set_theta_stats` from raw physical samples.

        - If use_bounded_flow=False: returns physical-space mean/std.
        - If use_bounded_flow=True : applies the per-parameter bijectors
          to `physical_thetas` first, then returns u-space mean/std.

        Args:
            mechanism: mechanism name in self.mechanism_list.
            physical_thetas: [N, theta_dim] tensor of training-set thetas
                in their natural physical units.

        Returns:
            (mean, std) tensors of shape [theta_dim] suitable for
            set_theta_stats(mechanism, ...). `std` is floored at 0.1, and is
            all ones when only a single sample is given.
        """
        t = physical_thetas
        if self.use_bounded_flow:
            head = self.flow_heads[mechanism]
            t, _ = apply_param_bijectors_forward(t, list(head.param_bijectors))
        mean = t.mean(dim=0)
        std = t.std(dim=0).clamp(min=0.1) if t.shape[0] > 1 else torch.ones_like(mean)
        return mean, std

    def encode_signal(self, x, scan_mask=None, sigmas=None, flux_scales=None,
                      summary=None):
        """Map the input to a [B, d_context] context vector.

        In summary mode only `summary` is used (`x` is ignored); otherwise
        `x` and the optional per-curve inputs go through the signal encoder.
        """
        if self.use_summary_features:
            assert summary is not None, "summary features required in summary mode"
            return self.summary_proj(summary)
        return self.encoder(x, scan_mask=scan_mask, sigmas=sigmas,
                            flux_scales=flux_scales)

    @staticmethod
    def _batch_meta(x):
        """Return (batch_size, device) from x, which may be a dict (joint mode)
        or a tensor (single-modality)."""
        if isinstance(x, dict):
            ref = x['image'] if 'image' in x else x['waveform']
            return ref.shape[0], ref.device
        return x.shape[0], x.device

    @staticmethod
    def _min_var_penalty(samples, min_log_var):
        """Hinge penalty on per-parameter sample log-variance below `min_log_var`.

        `samples` is [n, n_samples, theta_dim]; returns a scalar.
        """
        log_var = torch.log(samples.var(dim=1) + 1e-8)
        return torch.relu(min_log_var - log_var).mean()

    @staticmethod
    def _summarize_samples(samples):
        """Summary statistics over the sample axis (dim=1) of `samples`."""
        return {
            'mean': samples.mean(dim=1),
            'std': samples.std(dim=1),
            'median': samples.median(dim=1).values,
            'q05': samples.quantile(0.05, dim=1),
            'q95': samples.quantile(0.95, dim=1),
        }

    @staticmethod
    def _resolve_temperature(temperature, temperature_map, mech, device):
        """Pick the sampling temperature for `mech`.

        A per-parameter entry in `temperature_map` wins over the scalar
        `temperature`; it is materialized on `device` so it can be combined
        with samples living there.
        """
        if temperature_map is not None and mech in temperature_map:
            return torch.tensor(temperature_map[mech], dtype=torch.float32,
                                device=device)
        return temperature

    def _forward_impl(self, x, mechanism_ids, mech_theta, mech_theta_mask=None,
                      scan_mask=None, sigmas=None, flux_scales=None, summary=None,
                      return_calibration=False, cal_n_samples=64,
                      cal_levels=(0.5, 0.9), cal_beta=20.0,
                      return_context=False,
                      return_min_var=False, min_var_n_samples=32,
                      min_log_var=-6.0):
        """Shared implementation behind `forward` / `forward_with_calibration`.

        Args:
            x: encoder input (tensor, or dict in joint mode). Ignored in
                summary mode, where it may be None.
            mechanism_ids: [B] index of the true mechanism per sample.
            mech_theta: [B, max_dim] padded true parameters; only the first
                `dim` columns of a row are used for its mechanism.
            mech_theta_mask: accepted for interface compatibility; not used.
            return_calibration: also compute the coverage calibration loss
                (only for mechanisms with at least 4 samples in the batch).
            return_min_var: also compute the minimum-variance penalty
                (only for mechanisms with at least 2 samples in the batch).
            return_context: include the context tensor in the output.

        Returns:
            dict with 'logits' [B, n_mechanisms], 'nll' [B] and, when
            requested, 'cal_loss', 'min_var_loss' and 'context'.
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        logits = self.classifier(context)

        # Read batch size and device from the context rather than from `x`:
        # in summary mode `x` is never used and may legitimately be None.
        B, device = context.shape[0], context.device
        nll = torch.zeros(B, device=device)
        cal_losses = []
        mv_penalties = []

        for m_idx, mech in enumerate(self.mechanism_list):
            sel = (mechanism_ids == m_idx)
            if not sel.any():
                continue

            head = self.flow_heads[mech]
            theta_dim = TPD_MECHANISM_PARAMS[mech]['dim']
            ctx_m = context[sel]
            theta_m = mech_theta[sel, :theta_dim]

            log_p = head.log_prob(theta_m, ctx_m)
            # Replace non-finite log-probs by a constant so a single bad
            # sample does not poison the batch loss.
            bad = ~torch.isfinite(log_p)
            if bad.any():
                log_p = torch.where(bad, torch.full_like(log_p, -10.0).detach(), log_p)
            nll[sel] = -log_p

            need_cal = return_calibration and ctx_m.shape[0] >= 4
            need_mv = return_min_var and ctx_m.shape[0] >= 2

            if need_cal:
                samples = head.sample_with_grad(ctx_m, n_samples=cal_n_samples)

                # Weight parameters inversely to their posterior spread
                # (normalized to mean 1); no gradient through the weights.
                with torch.no_grad():
                    param_std = samples.std(dim=1).clamp(min=1e-4)
                    inv_spread_w = 1.0 / param_std
                    inv_spread_w = inv_spread_w / inv_spread_w.mean()

                for level in cal_levels:
                    alpha = (1.0 - level) / 2.0
                    lower = torch.quantile(samples, alpha, dim=1)
                    upper = torch.quantile(samples, 1 - alpha, dim=1)

                    inside = self._straight_through_containment(
                        theta_m, lower, upper, cal_beta)

                    per_sample_loss = (inside - level).pow(2)
                    cal_losses.append((per_sample_loss * inv_spread_w).mean())

                # Reuse the calibration samples for the variance penalty.
                if need_mv:
                    mv_penalties.append(self._min_var_penalty(samples, min_log_var))
            elif need_mv:
                mv_samples = head.sample_with_grad(
                    ctx_m, n_samples=min_var_n_samples,
                )
                mv_penalties.append(self._min_var_penalty(mv_samples, min_log_var))

        out = {'logits': logits, 'nll': nll}
        if return_calibration:
            if cal_losses:
                out['cal_loss'] = torch.stack(cal_losses).mean()
            else:
                out['cal_loss'] = torch.tensor(0.0, device=device)
        if return_min_var:
            if mv_penalties:
                out['min_var_loss'] = torch.stack(mv_penalties).mean()
            else:
                out['min_var_loss'] = torch.tensor(0.0, device=device)
        if return_context:
            out['context'] = context
        return out

    def forward(self, x, mechanism_ids, mech_theta, mech_theta_mask=None,
                scan_mask=None, sigmas=None, flux_scales=None, summary=None,
                return_calibration=False, cal_n_samples=64,
                cal_levels=(0.5, 0.9), cal_beta=20.0,
                return_context=False,
                return_min_var=False, min_var_n_samples=32,
                min_log_var=-6.0):
        """
        Compute classification logits and per-sample NLL for the true mechanism.

        Returns:
            dict with 'logits' [B, n_mechanisms] and 'nll' [B]
        """
        return self._forward_impl(
            x, mechanism_ids, mech_theta, mech_theta_mask,
            scan_mask=scan_mask, sigmas=sigmas, flux_scales=flux_scales,
            summary=summary,
            return_calibration=return_calibration,
            cal_n_samples=cal_n_samples,
            cal_levels=cal_levels,
            cal_beta=cal_beta,
            return_context=return_context,
            return_min_var=return_min_var,
            min_var_n_samples=min_var_n_samples,
            min_log_var=min_log_var,
        )

    @staticmethod
    def _straight_through_containment(theta, lower, upper, beta):
        """Hard containment indicator with straight-through gradient estimator.

        Forward: hard 0/1 indicator (unbiased coverage estimate).
        Backward: sigmoid gradient (smooth, trainable).
        """
        soft = (torch.sigmoid(beta * (theta - lower))
                * torch.sigmoid(beta * (upper - theta)))
        hard = ((theta >= lower) & (theta <= upper)).float()
        return hard + (soft - soft.detach())

    def forward_with_calibration(self, x, mechanism_ids, mech_theta,
                                 mech_theta_mask=None, scan_mask=None,
                                 sigmas=None, flux_scales=None,
                                 cal_n_samples=64, cal_levels=(0.5, 0.9),
                                 cal_beta=20.0, summary=None):
        """Forward pass with additional calibration loss (see MultiMechanismFlow)."""
        return self._forward_impl(
            x, mechanism_ids, mech_theta, mech_theta_mask,
            scan_mask=scan_mask, sigmas=sigmas, flux_scales=flux_scales,
            summary=summary,
            return_calibration=True,
            cal_n_samples=cal_n_samples,
            cal_levels=cal_levels,
            cal_beta=cal_beta,
        )

    @torch.no_grad()
    def predict(self, x, scan_mask=None, sigmas=None, flux_scales=None,
                n_samples=200, top_k=None, temperature=1.0,
                temperature_map=None, summary=None, ood_extra=None):
        """
        Full inference: classify mechanism, then sample parameters.

        Args:
            top_k: if given, only mechanisms that are among the top-k classes
                of at least one batch element are sampled; the others get
                None in 'samples' and 'stats'. Values larger than the number
                of mechanisms are treated as "all".
            temperature: scalar fallback (>1 broadens posteriors)
            temperature_map: dict mapping mechanism name -> list of
                per-parameter temperatures. Overrides scalar temperature.
            summary: [B, 21] hand-crafted summary stats (summary mode only)
            ood_extra: [B, extra_input_dim] extra OOD-head features. Required
                when the OOD head was created with extra_input_dim > 0;
                ignored otherwise.

        Returns:
            dict with mechanism_probs, mechanism_xdB, mechanism_pred, samples,
            stats and ood_score (None when no OOD head is attached)

        Raises:
            ValueError: if the OOD head expects extra features and
                `ood_extra` is missing or has the wrong feature size.
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        logits = self.classifier(context)
        probs = F.softmax(logits, dim=-1)
        pred = probs.argmax(dim=-1)

        # Per-class log-odds in decibels; clamped to keep the ratio finite.
        probs_clamped = probs.clamp(min=1e-7, max=1 - 1e-7)
        xdB = 10.0 * torch.log10(probs_clamped / (1.0 - probs_clamped))

        # The set of mechanisms worth sampling is computed once, not once per
        # loop iteration.
        if top_k is not None:
            k = min(top_k, probs.shape[-1])
            candidate_ids = set(probs.topk(k, dim=-1).indices.flatten().tolist())
        else:
            candidate_ids = None

        samples_dict = {}
        stats_dict = {}
        for m_idx, mech in enumerate(self.mechanism_list):
            if candidate_ids is not None and m_idx not in candidate_ids:
                samples_dict[mech] = None
                stats_dict[mech] = None
                continue

            T = self._resolve_temperature(temperature, temperature_map, mech,
                                          context.device)
            s = self.flow_heads[mech].sample(context, n_samples=n_samples,
                                             temperature=T)
            samples_dict[mech] = s
            stats_dict[mech] = self._summarize_samples(s)

        ood_score = None
        if self.ood_head is not None:
            ood_parts = [context, probs]
            if self.ood_head_extra_dim > 0:
                # The head was sized for extra features; without them the
                # first Linear layer would fail with an opaque shape error.
                if ood_extra is None:
                    raise ValueError(
                        f"OOD head expects {self.ood_head_extra_dim} extra "
                        "input feature(s); pass them via `ood_extra`"
                    )
                if ood_extra.shape[-1] != self.ood_head_extra_dim:
                    raise ValueError(
                        f"ood_extra has {ood_extra.shape[-1]} feature(s), "
                        f"expected {self.ood_head_extra_dim}"
                    )
                ood_parts.append(
                    ood_extra.to(device=context.device, dtype=context.dtype)
                )
            ood_logit = self.ood_head(torch.cat(ood_parts, dim=-1)).squeeze(-1)
            ood_score = torch.sigmoid(ood_logit)

        return {
            'mechanism_probs': probs,
            'mechanism_xdB': xdB,
            'mechanism_pred': pred,
            'samples': samples_dict,
            'stats': stats_dict,
            'ood_score': ood_score,
        }

    @torch.no_grad()
    def predict_single_mechanism(self, x, mechanism, scan_mask=None,
                                 sigmas=None, flux_scales=None, n_samples=1000,
                                 temperature=1.0, temperature_map=None,
                                 summary=None):
        """Sample parameters assuming a known mechanism.

        Runs without gradient tracking, like `predict`.

        Returns:
            dict with 'mean', 'std', 'median', 'q05', 'q95' (each reduced
            over the sample axis) and the raw 'samples'.
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        T = self._resolve_temperature(temperature, temperature_map, mechanism,
                                      context.device)
        samples = self.flow_heads[mechanism].sample(context, n_samples=n_samples,
                                                    temperature=T)
        return {**self._summarize_samples(samples), 'samples': samples}
|
|
|
|
def count_parameters(model):
    """Return the total number of elements in `model`'s trainable parameters.

    Parameters with `requires_grad=False` are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build a model, run one forward pass and one prediction on
    # random inputs, with one batch element per mechanism.
    n_mechs = len(TPD_MECHANISM_LIST)
    B = n_mechs
    N_beta = 3
    T = 500
    mech_dims = {m: TPD_MECHANISM_PARAMS[m]['dim'] for m in TPD_MECHANISM_LIST}

    # Random inputs (the order of the randn calls is kept so that seeded
    # runs reproduce the same numbers).
    x = torch.randn(B, N_beta, 2, T)
    scan_mask = torch.ones(B, N_beta, T, dtype=torch.bool)
    sigmas = torch.randn(B, N_beta)
    flux_scales = torch.randn(B, N_beta)
    mechanism_ids = torch.arange(n_mechs)

    # Padded parameter tensor plus a mask flagging each row's real entries;
    # row i belongs to mechanism i.
    max_dim = max(mech_dims.values())
    mech_theta = torch.randn(B, max_dim)
    mech_theta_mask = torch.zeros(B, max_dim, dtype=torch.bool)
    for row, mech_name in enumerate(TPD_MECHANISM_LIST):
        mech_theta_mask[row, :mech_dims[mech_name]] = True

    banner = "=" * 60
    print(banner)
    print("Testing MultiMechanismFlowTPD (multi-heating-rate, Set Transformer)")
    print(banner)

    model = MultiMechanismFlowTPD(
        d_context=128,
        d_model=128,
        n_coupling_layers=8,
        hidden_dim=128,
        coupling_type='affine',
    )

    # Parameter counts, in total and per component.
    total_params = count_parameters(model)
    print(f"Total parameters: {total_params:,}")
    print(f" Encoder: {count_parameters(model.encoder):,}")
    print(f" Classifier: {count_parameters(model.classifier):,}")
    for mech in TPD_MECHANISM_LIST:
        head_params = count_parameters(model.flow_heads[mech])
        print(f" Flow ({mech}, dim={mech_dims[mech]}): {head_params:,}")

    out = model(x, mechanism_ids, mech_theta, mech_theta_mask,
                scan_mask=scan_mask, sigmas=sigmas, flux_scales=flux_scales)
    print("\nForward pass:")
    print(f" Logits shape: {out['logits'].shape}")
    print(f" NLL shape: {out['nll'].shape}")
    print(f" NLL values: {out['nll']}")

    pred = model.predict(x, scan_mask=scan_mask, sigmas=sigmas,
                         flux_scales=flux_scales, n_samples=100)
    print("\nPrediction:")
    print(f" Mechanism probs shape: {pred['mechanism_probs'].shape}")
    print(f" Mechanism xdB shape: {pred['mechanism_xdB'].shape}")
    print(f" Predicted mechanisms: {pred['mechanism_pred']}")
    for mech in TPD_MECHANISM_LIST:
        mech_samples = pred['samples'][mech]
        if mech_samples is not None:
            print(f" {mech} samples shape: {mech_samples.shape}")
|
|