"""
Multi-Mechanism Normalizing Flow for Joint Mechanism Identification
and Bayesian Parameter Inference from Multi-Heating-Rate TPD Signals.

Architecture mirrors the electrochemistry model (multi_mechanism_model.py)
but configured for TPD with 2 input channels (temperature, rate) and
6 catalysis mechanisms.

Reuses domain-agnostic components from flow_model.py and
multi_mechanism_model.py (SAB, PMA, MechanismClassifier, MechanismFlow).
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from flow_model import (
    SignalEncoder, ActNorm, ConditionalSplineCoupling, ConditionalAffineCoupling,
)
from multi_mechanism_model import MechanismClassifier, MechanismFlow, SAB, PMA, SummaryProjection, SUMMARY_DIM
from image_encoder import ImageEncoder
from generate_tpd_data import TPD_MECHANISM_LIST, TPD_MECHANISM_PARAMS
from tpd_bijectors import (
    Identity, Logit, AffineLogit,
    apply_param_bijectors_forward, apply_param_bijectors_inverse,
    all_identity,
)


# =============================================================================
# Per-parameter bijector registry (TPD-only).
#
# Maps parameter NAME -> bijector spec. Names are matched against
# TPD_MECHANISM_PARAMS[mech]['names']. Any name not present here defaults
# to Identity (current behavior, unbounded z-score normalization).
#
# This registry is consulted only when a model is constructed with
# use_bounded_flow=True. Models trained without that flag (incl. v1
# checkpoints) ignore it entirely.
# =============================================================================
TPD_PARAM_BIJECTORS = {
    # Strictly-bounded coverage parameters in [0, 1]. These are the
    # parameters that previously suffered posterior collapse / over-coverage
    # under z-score-only normalization.
    'theta_0': ('logit',),
    'theta_A0': ('logit',),
    'theta_B0': ('logit',),
    'theta_O0': ('logit',),
    # Site fraction in (0, 1).
    'f_site1': ('logit',),
    # Initial layer count, sampled in [1.5, 8.0]; treat as bounded
    # in [1.0, 10.0] (with eps slack).
    'n_layers': ('affine_logit', 1.0, 10.0),
    # Note: delta, alpha_cov, omega are tied to Ed/Ea (joint constraint)
    # and not strictly bounded on their own; defer to Identity.
}


def _make_bijector(spec):
    """Instantiate the bijector described by a registry `spec`.

    Args:
        spec: None, or a tuple whose first element is the bijector kind
            ('identity', 'logit', or 'affine_logit'). For 'affine_logit' the
            tuple must be exactly (kind, low, high).

    Returns:
        An Identity, Logit, or AffineLogit instance. `None` maps to Identity.

    Raises:
        ValueError: if the kind is not one of the three known kinds.
    """
    if spec is None:
        return Identity()
    kind = spec[0]
    if kind == 'identity':
        return Identity()
    if kind == 'logit':
        return Logit()
    if kind == 'affine_logit':
        _, low, high = spec
        return AffineLogit(low, high)
    raise ValueError(f"Unknown bijector spec: {spec}")


def build_param_bijectors(mech: str):
    """Return a list of bijectors (one per parameter) for `mech`, in the
    same order as TPD_MECHANISM_PARAMS[mech]['names'].

    Parameter names missing from TPD_PARAM_BIJECTORS get an Identity bijector.
    """
    names = TPD_MECHANISM_PARAMS[mech]['names']
    return [_make_bijector(TPD_PARAM_BIJECTORS.get(n)) for n in names]


# =============================================================================
# Bounded variant of MechanismFlow (TPD-only)
# =============================================================================
class BoundedMechanismFlow(MechanismFlow):
    """MechanismFlow + per-parameter reparameterization bijectors.

    Density of the physical parameter `theta` is computed in the unbounded
    representation `u = bij(theta)`, then transformed back via the
    change-of-variables rule:

        log p_theta(theta) = log p_u(u) + log|du/dtheta|

    where `log p_u(u)` is exactly what the base MechanismFlow would compute
    on `u` (z-score normalize -> flow -> Gaussian base).

    `theta_mean` and `theta_std` (inherited from MechanismFlow) now refer to
    U-SPACE statistics, not physical-space statistics. Callers must compute
    them in u-space (see `MultiMechanismFlowTPD.compute_stats_for_mechanism`).

    When every bijector is an Identity, all methods delegate to the parent
    implementation unchanged.
    """

    def __init__(self, theta_dim, param_bijectors, **kwargs):
        """
        Args:
            theta_dim: number of physical parameters for this mechanism.
            param_bijectors: sequence of `theta_dim` bijector modules, one
                per parameter, in parameter order.
            **kwargs: forwarded unchanged to MechanismFlow.__init__.
        """
        super().__init__(theta_dim=theta_dim, **kwargs)
        assert len(param_bijectors) == theta_dim, (
            f"BoundedMechanismFlow expects {theta_dim} bijectors, got {len(param_bijectors)}"
        )
        # Register as a ModuleList so any future parameters in bijectors
        # would move with .to(device); currently bijectors are stateless.
        self.param_bijectors = nn.ModuleList(param_bijectors)
        # Cached so the all-Identity case can short-circuit to the parent.
        self._bijector_is_noop = all_identity(self.param_bijectors)

    # ---- log p(theta | context) ------------------------------------------
    def log_prob(self, theta, context):
        """Return log p(theta | context) for physical-space `theta`.

        The result is clamped to [-50, 50].
        """
        if self._bijector_is_noop:
            return super().log_prob(theta, context)

        # 1. theta -> u in unbounded space, accumulate log|du/dtheta|.
        u, log_det_b = apply_param_bijectors_forward(theta, list(self.param_bijectors))

        # 2. z-score normalize in u-space, then pass through the flow.
        u_norm = self.normalize_theta(u)
        if self.coupling_type == 'spline':
            u_norm = u_norm.clamp(-self.tail_bound, self.tail_bound)
        z, log_det_flow = self.inverse_flow(u_norm, context)

        # 3. Standard normal base + change-of-variables for normalization
        #    and bijector. Note: log_det_norm = -sum log(theta_std), exactly as
        #    in MechanismFlow.log_prob.
        log_pz = -0.5 * (z ** 2 + math.log(2 * math.pi)).sum(dim=-1)
        log_det_norm = -torch.log(self.theta_std).sum()
        log_p = log_pz + log_det_flow + log_det_norm + log_det_b
        return log_p.clamp(min=-50.0, max=50.0)

    # ---- sampling --------------------------------------------------------
    def _sample_u_space(self, context, n_samples):
        """Draw `n_samples` u-space samples per context row.

        Shared by `sample` and `sample_with_grad`; gradient tracking is
        decided by the caller (this helper adds no `no_grad` of its own).

        Returns:
            Tensor of shape [B * n_samples, theta_dim] in u-space
            (denormalized, not yet mapped back through the bijectors).
        """
        B = context.shape[0]
        # Repeat each context row n_samples times: [B, d] -> [B * n_samples, d].
        context_rep = (context.unsqueeze(1)
                       .expand(-1, n_samples, -1)
                       .reshape(B * n_samples, -1))
        z = torch.randn(B * n_samples, self.theta_dim, device=context.device)
        u_norm, _ = self.forward_flow(z, context_rep)
        return self.denormalize_theta(u_norm)

    @torch.no_grad()
    def sample(self, context, n_samples=100, temperature=1.0):
        """Sample physical-space parameters without tracking gradients.

        Args:
            context: [B, d_context] conditioning vectors.
            n_samples: samples drawn per context row.
            temperature: scalar, or tensor of per-parameter factors, by which
                the spread of the samples around their per-row mean is scaled.

        Returns:
            Tensor of shape [B, n_samples, theta_dim].
        """
        if self._bijector_is_noop:
            return super().sample(context, n_samples=n_samples, temperature=temperature)

        B = context.shape[0]
        u = self._sample_u_space(context, n_samples).reshape(B, n_samples, self.theta_dim)

        # Apply temperature inflation in U-SPACE (unbounded), so inflated
        # samples remain valid pre-images of the bijector.
        if isinstance(temperature, torch.Tensor):
            scale = temperature.to(u.device).reshape(1, 1, -1)
        elif temperature != 1.0:
            scale = temperature
        else:
            scale = None
        if scale is not None:
            mu = u.mean(dim=1, keepdim=True)
            u = mu + scale * (u - mu)

        # Bring back to physical theta space.
        u_flat = u.reshape(B * n_samples, self.theta_dim)
        theta_flat, _ = apply_param_bijectors_inverse(u_flat, list(self.param_bijectors))
        return theta_flat.reshape(B, n_samples, self.theta_dim)

    def sample_with_grad(self, context, n_samples=64):
        """Sample physical-space parameters while keeping the autograd graph.

        Returns:
            Tensor of shape [B, n_samples, theta_dim].
        """
        if self._bijector_is_noop:
            return super().sample_with_grad(context, n_samples=n_samples)

        B = context.shape[0]
        u = self._sample_u_space(context, n_samples)
        # Bijector inverse is differentiable; gradients propagate through
        # both the flow and the bijector for the calibration loss.
        theta, _ = apply_param_bijectors_inverse(u, list(self.param_bijectors))
        return theta.reshape(B, n_samples, self.theta_dim)


class MultiScanEncoderTPD(nn.Module):
    """
    Encode a set of multi-heating-rate TPD curves into a single context vector.

    Architecture (Set Transformer):
      1. Shared per-curve encoder -> per-curve embedding
           - input_mode='waveform': 1-D CNN over [B*N, in_channels, T]
           - input_mode='image':    2-D CNN over [B*N, 1, H, W]
           - input_mode='image+waveform': both encoders in parallel, fused
             by an MLP
      2. Augment with [log10(heating_rate), log10(peak_rate)]
      3. SAB: self-attention across heating rates
      4. PMA: attention-based pooling to single vector
      5. rho MLP: project to final context

    Waveform input: x [B, N_beta, 2, T], scan_mask [B, N_beta, T],
                    heating_rates [B, N_beta], rate_scales [B, N_beta]
    Image input:    x [B, N_beta, 1, H, W], scan_mask [B, N_beta]
                    (per-curve presence flag), heating_rates [B, N_beta],
                    rate_scales [B, N_beta]
    Output: context [B, d_context]

    In `forward`, the heating rates and rate scales are passed as the
    `sigmas` and `flux_scales` arguments respectively.
    """

    def __init__(self, in_channels=2, d_model=128, d_context=128, n_heads=4,
                 input_mode='waveform', image_in_channels=1):
        """
        Raises:
            ValueError: if `input_mode` is not 'waveform', 'image', or
                'image+waveform'.
        """
        super().__init__()
        if input_mode not in ('waveform', 'image', 'image+waveform'):
            raise ValueError(f"Unknown input_mode: {input_mode!r}")
        self.input_mode = input_mode

        if input_mode == 'waveform':
            self.per_cv_encoder = SignalEncoder(
                in_channels=in_channels, d_model=d_model, d_context=d_context,
            )
        elif input_mode == 'image':
            self.per_cv_encoder = ImageEncoder(
                in_channels=image_in_channels, d_model=d_model, d_context=d_context,
            )
        else:
            # Joint mode: parallel image + waveform encoders + fusion MLP.
            self.image_encoder = ImageEncoder(
                in_channels=image_in_channels, d_model=d_model, d_context=d_context,
            )
            self.waveform_encoder = SignalEncoder(
                in_channels=in_channels, d_model=d_model, d_context=d_context,
            )
            self.joint_fusion = nn.Sequential(
                nn.Linear(2 * d_context, d_context),
                nn.GELU(),
                nn.Linear(d_context, d_context),
            )

        # "+ 2" = the two per-curve scalars (sigmas, flux_scales) appended
        # to each curve embedding before set attention.
        self.cv_augment = nn.Sequential(
            nn.Linear(d_context + 2, d_context),
            nn.GELU(),
        )

        self.sab = SAB(d_context, n_heads=n_heads)
        self.pma = PMA(d_context, n_heads=n_heads, n_seeds=1)

        self.rho = nn.Sequential(
            nn.Linear(d_context, d_context),
            nn.GELU(),
            nn.Linear(d_context, d_context),
        )

    @staticmethod
    def _curve_padding_mask(scan_mask):
        """Turn a validity mask into a per-curve padding mask.

        Args:
            scan_mask: None, a 3-D per-timestep validity mask (reduced with
                `.any(dim=-1)`), or a per-curve presence mask.

        Returns:
            None when no mask is given, otherwise a boolean tensor that is
            True where a curve is padded / absent.
        """
        if scan_mask is None:
            return None
        valid = scan_mask.bool()
        if valid.dim() == 3:
            valid = valid.any(dim=-1)
        return ~valid

    def _pool_curves(self, h, sigmas, flux_scales, cv_invalid):
        """Shared tail of every input mode: augment -> SAB -> PMA -> rho.

        Args:
            h: [B, N, d_context] per-curve embeddings.
            sigmas, flux_scales: [B, N] per-curve scalars; each defaults to
                zeros when None.
            cv_invalid: optional [B, N] boolean padding mask passed to the
                attention blocks as `key_padding_mask`.

        Returns:
            context: [B, d_context]
        """
        B, N = h.shape[0], h.shape[1]
        if sigmas is None:
            sigmas = h.new_zeros(B, N)
        if flux_scales is None:
            flux_scales = h.new_zeros(B, N)
        aug_features = torch.stack([sigmas, flux_scales], dim=-1)
        h = self.cv_augment(torch.cat([h, aug_features], dim=-1))

        h = self.sab(h, key_padding_mask=cv_invalid)
        h = self.pma(h, key_padding_mask=cv_invalid)  # [B, 1, d_context]
        return self.rho(h.squeeze(1))

    def forward(self, x, scan_mask=None, sigmas=None, flux_scales=None):
        """
        Waveform args:
            x: [B, N_beta, 2, T] multi-heating-rate TPD curves
            scan_mask: [B, N_beta, T] valid timestep mask
            sigmas: [B, N_beta] log10 heating rates
            flux_scales: [B, N_beta] log10(peak_rate) per curve

        Image args:
            x: [B, N_beta, 1, H, W] grayscale plot images in [0, 1]
            scan_mask: [B, N_beta] per-curve presence (or [B, N_beta, T]
                       back-compat, reduced via .any(dim=-1))
            sigmas, flux_scales: same as waveform mode.

        Joint ('image+waveform') args:
            x: dict with keys 'image' and 'waveform', optionally
               'scan_mask_image' and 'scan_mask_waveform' (the latter falls
               back to the `scan_mask` argument).

        Returns:
            context: [B, d_context]

        Raises:
            ValueError: in joint mode when `x` is not a dict.
        """
        if self.input_mode == 'image':
            return self._forward_image(x, scan_mask=scan_mask,
                                       sigmas=sigmas, flux_scales=flux_scales)
        if self.input_mode == 'image+waveform':
            if not isinstance(x, dict):
                raise ValueError(
                    "image+waveform mode expects x to be a dict with keys "
                    "'image' and 'waveform'; got tensor"
                )
            return self._forward_joint(
                x['image'], x['waveform'],
                scan_mask_image=x.get('scan_mask_image'),
                scan_mask_waveform=x.get('scan_mask_waveform', scan_mask),
                sigmas=sigmas, flux_scales=flux_scales,
            )

        B, N, C, T = x.shape
        # Fold the curve axis into the batch so one encoder handles all curves.
        x_flat = x.reshape(B * N, C, T)
        mask_flat = scan_mask.reshape(B * N, T) if scan_mask is not None else None
        h = self.per_cv_encoder(x_flat, mask=mask_flat).reshape(B, N, -1)

        cv_invalid = self._curve_padding_mask(scan_mask)  # [B, N] True = padded
        return self._pool_curves(h, sigmas, flux_scales, cv_invalid)

    def _forward_joint(self, x_image, x_waveform,
                       scan_mask_image=None, scan_mask_waveform=None,
                       sigmas=None, flux_scales=None):
        """Joint image+waveform forward; mirrors MultiScanEncoder._forward_joint.

        The curve-level padding mask comes from `scan_mask_image` when given
        (per-curve, or 3-D back-compat form as in image mode), otherwise it
        is derived from `scan_mask_waveform`.

        Raises:
            ValueError: if `x_image` is not 5-D or `x_waveform` is not 4-D.
        """
        if x_image.dim() != 5:
            raise ValueError(
                f"Joint mode x_image expected [B,N,1,H,W]; got {tuple(x_image.shape)}"
            )
        if x_waveform.dim() != 4:
            raise ValueError(
                f"Joint mode x_waveform expected [B,N,2,T]; got {tuple(x_waveform.shape)}"
            )
        B, N = x_image.shape[0], x_image.shape[1]

        x_img_flat = x_image.reshape(B * N, *x_image.shape[2:])
        h_img_flat = self.image_encoder(x_img_flat)

        T_w = x_waveform.shape[-1]
        x_wave_flat = x_waveform.reshape(B * N, x_waveform.shape[2], T_w)
        wave_mask_flat = (
            scan_mask_waveform.reshape(B * N, T_w)
            if scan_mask_waveform is not None else None
        )
        h_wave_flat = self.waveform_encoder(x_wave_flat, mask=wave_mask_flat)

        h_joint_flat = self.joint_fusion(
            torch.cat([h_img_flat, h_wave_flat], dim=-1)
        )
        h = h_joint_flat.reshape(B, N, -1)

        curve_mask = scan_mask_image if scan_mask_image is not None else scan_mask_waveform
        cv_invalid = self._curve_padding_mask(curve_mask)
        return self._pool_curves(h, sigmas, flux_scales, cv_invalid)

    def _forward_image(self, x, scan_mask=None, sigmas=None, flux_scales=None):
        """Image-mode forward; mirrors MultiScanEncoder._forward_image.

        Raises:
            ValueError: if `x` is not 5-D.
        """
        if x.dim() != 5:
            raise ValueError(
                f"Image mode expects x of shape [B, N, C, H, W]; got {tuple(x.shape)}"
            )
        B, N, C, H, W = x.shape
        x_flat = x.reshape(B * N, C, H, W)
        h = self.per_cv_encoder(x_flat).reshape(B, N, -1)

        cv_invalid = self._curve_padding_mask(scan_mask)
        return self._pool_curves(h, sigmas, flux_scales, cv_invalid)


class MultiMechanismFlowTPD(nn.Module):
    """
    Joint mechanism identification and parameter inference model for TPD.

    Combines:
    - Multi-heating-rate signal encoder (Set Transformer over per-curve embeddings)
    - Mechanism classifier (6 TPD mechanisms)
    - Per-mechanism normalizing flow heads

    If use_summary_features=True, replaces the signal encoder with a simple
    MLP projection from hand-crafted summary statistics (21-dim) to context
    space, keeping all other components identical.
    """

    def __init__(
        self,
        d_context=128,
        d_model=128,
        n_coupling_layers=6,
        hidden_dim=96,
        coupling_type='spline',
        n_bins=8,
        tail_bound=5.0,
        use_summary_features=False,
        mechanism_list=None,
        use_bounded_flow=False,
        input_mode='waveform',
        image_in_channels=1,
    ):
        """
        Args:
            mechanism_list: mechanism names to build heads for; defaults to
                TPD_MECHANISM_LIST. Each name must be a key of
                TPD_MECHANISM_PARAMS.
            use_bounded_flow: build BoundedMechanismFlow heads (with the
                bijectors from `build_param_bijectors`) instead of plain
                MechanismFlow heads.
            Remaining arguments are forwarded to the encoder, classifier,
            and flow heads.
        """
        super().__init__()
        mech_list = mechanism_list if mechanism_list is not None else TPD_MECHANISM_LIST
        self.n_mechanisms = len(mech_list)
        self.mechanism_list = mech_list
        self.d_context = d_context
        self.use_summary_features = use_summary_features
        self.use_bounded_flow = use_bounded_flow
        self.input_mode = input_mode

        if use_summary_features:
            self.summary_proj = SummaryProjection(
                summary_dim=SUMMARY_DIM, d_context=d_context,
            )
            self.encoder = None
        else:
            self.encoder = MultiScanEncoderTPD(
                in_channels=2, d_model=d_model, d_context=d_context,
                input_mode=input_mode, image_in_channels=image_in_channels,
            )
            self.summary_proj = None

        self.classifier = MechanismClassifier(
            d_context=d_context,
            n_mechanisms=self.n_mechanisms,
            hidden_dim=hidden_dim,
        )

        # Hyper-parameters common to both head types.
        flow_kwargs = dict(
            d_context=d_context,
            n_coupling_layers=n_coupling_layers,
            hidden_dim=hidden_dim,
            coupling_type=coupling_type,
            n_bins=n_bins,
            tail_bound=tail_bound,
        )
        self.flow_heads = nn.ModuleDict()
        for mech in mech_list:
            theta_dim = TPD_MECHANISM_PARAMS[mech]['dim']
            if use_bounded_flow:
                self.flow_heads[mech] = BoundedMechanismFlow(
                    theta_dim=theta_dim,
                    param_bijectors=build_param_bijectors(mech),
                    **flow_kwargs,
                )
            else:
                self.flow_heads[mech] = MechanismFlow(theta_dim=theta_dim, **flow_kwargs)

        # Optional OOD head; created on demand by `init_ood_head`.
        self.ood_head = None
        self.ood_head_extra_dim = 0

    def init_ood_head(self, hidden_dim=64, dropout=0.3, extra_input_dim=0):
        """Initialize the binary OOD detection head.

        Takes context vector + softmax probs (+ optional extra features) as
        input, outputs logit for P(in-distribution).

        Mirrors `MultiMechanismFlow.init_ood_head` so the same training recipe
        (`train_ood_head_tpd.py`) can be applied.

        The new head is placed on the same device as the model's existing
        parameters, so it can be created after the model was moved.

        Returns:
            The newly created head (also stored as `self.ood_head`).
        """
        input_dim = self.d_context + self.n_mechanisms + extra_input_dim
        head = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )
        reference = next(self.parameters(), None)
        if reference is not None:
            head = head.to(reference.device)
        self.ood_head = head
        self.ood_head_extra_dim = extra_input_dim
        return self.ood_head

    def set_theta_stats(self, mechanism, mean, std):
        """Set normalization stats for a specific mechanism's flow head.

        IMPORTANT: when this model was constructed with use_bounded_flow=True,
        `mean` and `std` must already be in U-SPACE (post-bijector). Use
        `compute_stats_for_mechanism` below to convert raw physical samples.
        """
        self.flow_heads[mechanism].set_theta_stats(mean, std)

    def compute_stats_for_mechanism(self, mechanism, physical_thetas):
        """Compute mean/std for `set_theta_stats` from raw physical samples.

        - If use_bounded_flow=False: returns physical-space mean/std.
        - If use_bounded_flow=True : applies the per-parameter bijectors to
          `physical_thetas` first, then returns u-space mean/std.

        The std is floored at 0.1; with a single sample it is all ones.

        Args:
            mechanism: mechanism name in self.mechanism_list.
            physical_thetas: [N, theta_dim] tensor of training-set thetas
                in their natural physical units.

        Returns:
            (mean, std) tensors of shape [theta_dim] suitable for
            set_theta_stats(mechanism, ...).
        """
        t = physical_thetas
        if self.use_bounded_flow:
            head = self.flow_heads[mechanism]
            t, _ = apply_param_bijectors_forward(t, list(head.param_bijectors))
        mean = t.mean(dim=0)
        std = t.std(dim=0).clamp(min=0.1) if t.shape[0] > 1 else torch.ones_like(mean)
        return mean, std

    def encode_signal(self, x, scan_mask=None, sigmas=None, flux_scales=None,
                      summary=None):
        """Return the context vector for a batch.

        In summary mode only `summary` is used (and is required); otherwise
        the signal encoder is run on `x` and its side inputs.
        """
        if self.use_summary_features:
            assert summary is not None, "summary features required in summary mode"
            return self.summary_proj(summary)
        return self.encoder(x, scan_mask=scan_mask, sigmas=sigmas,
                            flux_scales=flux_scales)

    @staticmethod
    def _batch_meta(x):
        """Return (batch_size, device) from x, which may be a dict (joint mode)
        or a tensor (single-modality)."""
        if isinstance(x, dict):
            ref = x['image'] if 'image' in x else x['waveform']
            return ref.shape[0], ref.device
        return x.shape[0], x.device

    def _calibration_terms(self, samples, theta_m, cal_levels, cal_beta):
        """Coverage-calibration loss terms for one mechanism, one per level.

        For each nominal level, a central interval is taken from the sample
        quantiles, the (straight-through) containment of the true theta is
        compared with the level, and the squared error is weighted by the
        normalized inverse posterior spread.
        """
        with torch.no_grad():
            param_std = samples.std(dim=1).clamp(min=1e-4)  # [B_m, D]
            inv_spread_w = 1.0 / param_std                  # [B_m, D]
            inv_spread_w = inv_spread_w / inv_spread_w.mean()

        terms = []
        for level in cal_levels:
            alpha = (1.0 - level) / 2.0
            lower = torch.quantile(samples, alpha, dim=1)      # [B_m, D]
            upper = torch.quantile(samples, 1 - alpha, dim=1)  # [B_m, D]
            inside = self._straight_through_containment(
                theta_m, lower, upper, cal_beta)               # [B_m, D]
            per_sample_loss = (inside - level).pow(2)          # [B_m, D]
            terms.append((per_sample_loss * inv_spread_w).mean())
        return terms

    @staticmethod
    def _min_var_penalty(samples, min_log_var):
        """Mean hinge penalty on sample log-variance falling below `min_log_var`."""
        log_var = torch.log(samples.var(dim=1) + 1e-8)
        return torch.relu(min_log_var - log_var).mean()

    def _forward_impl(self, x, mechanism_ids, mech_theta, mech_theta_mask=None,
                      scan_mask=None, sigmas=None, flux_scales=None,
                      summary=None,
                      return_calibration=False,
                      cal_n_samples=64, cal_levels=(0.5, 0.9), cal_beta=20.0,
                      return_context=False,
                      return_min_var=False,
                      min_var_n_samples=32, min_log_var=-6.0):
        """Shared implementation behind `forward` and `forward_with_calibration`.

        Args:
            mechanism_ids: [B] index into `self.mechanism_list` per sample.
            mech_theta: [B, max_dim] padded true parameters; only the first
                `dim` columns of each row's mechanism are read.
            mech_theta_mask: accepted for interface compatibility; not read.
            return_calibration: also compute 'cal_loss' (mechanism groups
                with fewer than 4 samples are skipped).
            return_min_var: also compute 'min_var_loss' (mechanism groups
                with fewer than 2 samples are skipped).
            return_context: also return the context vectors.

        Returns:
            dict with 'logits' and 'nll', plus the optional entries above.
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        logits = self.classifier(context)

        # Batch size / device come from the context rather than `x`, so that
        # summary mode also works when no raw signal is supplied.
        B, device = context.shape[0], context.device
        nll = torch.zeros(B, device=device)
        cal_losses = []
        mv_penalties = []

        for m_idx, mech in enumerate(self.mechanism_list):
            sel = (mechanism_ids == m_idx)
            if not sel.any():
                continue

            head = self.flow_heads[mech]
            theta_dim = TPD_MECHANISM_PARAMS[mech]['dim']
            ctx_m = context[sel]
            theta_m = mech_theta[sel, :theta_dim]

            log_p = head.log_prob(theta_m, ctx_m)
            bad = ~torch.isfinite(log_p)
            if bad.any():
                # Non-finite entries become a constant that carries no gradient.
                log_p = torch.where(bad, torch.full_like(log_p, -10.0).detach(), log_p)
            # Cast so the write never fails on a dtype mismatch with `nll`.
            nll[sel] = (-log_p).to(nll.dtype)

            n_in_group = ctx_m.shape[0]
            need_cal = return_calibration and n_in_group >= 4
            need_mv = return_min_var and n_in_group >= 2

            if need_cal:
                samples = head.sample_with_grad(ctx_m, n_samples=cal_n_samples)
                cal_losses.extend(
                    self._calibration_terms(samples, theta_m, cal_levels, cal_beta))
                if need_mv:
                    # Reuse the calibration samples for the variance floor.
                    mv_penalties.append(self._min_var_penalty(samples, min_log_var))
            elif need_mv:
                mv_samples = head.sample_with_grad(ctx_m, n_samples=min_var_n_samples)
                mv_penalties.append(self._min_var_penalty(mv_samples, min_log_var))

        out = {'logits': logits, 'nll': nll}
        if return_calibration:
            if cal_losses:
                out['cal_loss'] = torch.stack(cal_losses).mean()
            else:
                out['cal_loss'] = torch.tensor(0.0, device=device)
        if return_min_var:
            if mv_penalties:
                out['min_var_loss'] = torch.stack(mv_penalties).mean()
            else:
                out['min_var_loss'] = torch.tensor(0.0, device=device)
        if return_context:
            out['context'] = context
        return out

    def forward(self, x, mechanism_ids, mech_theta, mech_theta_mask=None,
                scan_mask=None, sigmas=None, flux_scales=None, summary=None,
                return_calibration=False,
                cal_n_samples=64, cal_levels=(0.5, 0.9), cal_beta=20.0,
                return_context=False,
                return_min_var=False,
                min_var_n_samples=32, min_log_var=-6.0):
        """
        Compute classification logits and per-sample NLL for the true mechanism.

        Returns:
            dict with 'logits' [B, n_mechanisms] and 'nll' [B], plus
            'cal_loss' / 'min_var_loss' / 'context' when the corresponding
            `return_*` flag is set.
        """
        return self._forward_impl(
            x, mechanism_ids, mech_theta, mech_theta_mask,
            scan_mask=scan_mask, sigmas=sigmas, flux_scales=flux_scales,
            summary=summary,
            return_calibration=return_calibration,
            cal_n_samples=cal_n_samples, cal_levels=cal_levels, cal_beta=cal_beta,
            return_context=return_context,
            return_min_var=return_min_var,
            min_var_n_samples=min_var_n_samples, min_log_var=min_log_var,
        )

    @staticmethod
    def _straight_through_containment(theta, lower, upper, beta):
        """Hard containment indicator with straight-through gradient estimator.

        Forward: hard 0/1 indicator (unbiased coverage estimate).
        Backward: sigmoid gradient (smooth, trainable).
        """
        soft = (torch.sigmoid(beta * (theta - lower)) *
                torch.sigmoid(beta * (upper - theta)))
        hard = ((theta >= lower) & (theta <= upper)).float()
        return hard + (soft - soft.detach())  # STE: hard forward, soft backward

    def forward_with_calibration(self, x, mechanism_ids, mech_theta,
                                 mech_theta_mask=None,
                                 scan_mask=None, sigmas=None, flux_scales=None,
                                 cal_n_samples=64, cal_levels=(0.5, 0.9),
                                 cal_beta=20.0, summary=None):
        """Forward pass with additional calibration loss (see MultiMechanismFlow)."""
        return self._forward_impl(
            x, mechanism_ids, mech_theta, mech_theta_mask,
            scan_mask=scan_mask, sigmas=sigmas, flux_scales=flux_scales,
            summary=summary,
            return_calibration=True,
            cal_n_samples=cal_n_samples, cal_levels=cal_levels, cal_beta=cal_beta,
        )

    @staticmethod
    def _resolve_temperature(mechanism, temperature, temperature_map, device):
        """Pick the sampling temperature for `mechanism`.

        Returns the per-parameter tensor built from `temperature_map` (on
        `device`) when the map has an entry for the mechanism, otherwise the
        scalar `temperature`.
        """
        if temperature_map is not None and mechanism in temperature_map:
            return torch.tensor(temperature_map[mechanism],
                                dtype=torch.float32, device=device)
        return temperature

    @staticmethod
    def _summarize_samples(samples):
        """Per-parameter summary statistics over the sample axis (dim=1)."""
        return {
            'mean': samples.mean(dim=1),
            'std': samples.std(dim=1),
            'median': samples.median(dim=1).values,
            'q05': samples.quantile(0.05, dim=1),
            'q95': samples.quantile(0.95, dim=1),
        }

    @torch.no_grad()
    def predict(self, x, scan_mask=None, sigmas=None, flux_scales=None,
                n_samples=200, top_k=None, temperature=1.0,
                temperature_map=None, summary=None):
        """
        Full inference: classify mechanism, then sample parameters.

        Args:
            top_k: if given, only mechanisms that appear in the top-k
                predictions of at least one batch element are sampled; the
                others get None entries. Values above the number of
                mechanisms are clamped.
            temperature: scalar fallback (>1 broadens posteriors)
            temperature_map: dict mapping mechanism name -> list of
                per-parameter temperatures. Overrides scalar temperature.
            summary: [B, 21] hand-crafted summary stats (summary mode only)

        Returns:
            dict with mechanism_probs, mechanism_xdB, mechanism_pred,
            samples, stats, ood_score (None when no OOD head exists)
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        logits = self.classifier(context)
        probs = F.softmax(logits, dim=-1)
        pred = probs.argmax(dim=-1)

        # Log-odds of each class probability expressed in dB.
        probs_clamped = probs.clamp(min=1e-7, max=1 - 1e-7)
        xdB = 10.0 * torch.log10(probs_clamped / (1.0 - probs_clamped))

        # Resolve the top-k membership once instead of once per mechanism.
        if top_k is not None:
            k = min(top_k, self.n_mechanisms)
            active = set(probs.topk(k, dim=-1).indices.reshape(-1).tolist())
        else:
            active = None

        samples_dict = {}
        stats_dict = {}
        for m_idx, mech in enumerate(self.mechanism_list):
            if active is not None and m_idx not in active:
                samples_dict[mech] = None
                stats_dict[mech] = None
                continue

            T = self._resolve_temperature(mech, temperature, temperature_map,
                                          context.device)
            s = self.flow_heads[mech].sample(context, n_samples=n_samples,
                                             temperature=T)
            samples_dict[mech] = s
            stats_dict[mech] = self._summarize_samples(s)

        ood_score = None
        if self.ood_head is not None:
            # NOTE(review): only [context, probs] are fed here; a head built
            # with extra_input_dim > 0 expects a wider input -- confirm how
            # the extra features should be supplied at inference time.
            ood_input = torch.cat([context, probs], dim=-1)
            ood_logit = self.ood_head(ood_input).squeeze(-1)
            ood_score = torch.sigmoid(ood_logit)

        return {
            'mechanism_probs': probs,
            'mechanism_xdB': xdB,
            'mechanism_pred': pred,
            'samples': samples_dict,
            'stats': stats_dict,
            'ood_score': ood_score,
        }

    @torch.no_grad()
    def predict_single_mechanism(self, x, mechanism, scan_mask=None,
                                 sigmas=None, flux_scales=None, n_samples=1000,
                                 temperature=1.0, temperature_map=None,
                                 summary=None):
        """Sample parameters assuming a known mechanism.

        Runs without gradient tracking, like `predict`.

        Returns:
            dict with 'mean', 'std', 'median', 'q05', 'q95' (each reduced
            over the sample axis) and the raw 'samples'.
        """
        context = self.encode_signal(x, scan_mask=scan_mask, sigmas=sigmas,
                                     flux_scales=flux_scales, summary=summary)
        T = self._resolve_temperature(mechanism, temperature, temperature_map,
                                      context.device)
        samples = self.flow_heads[mechanism].sample(context, n_samples=n_samples,
                                                    temperature=T)
        result = self._summarize_samples(samples)
        result['samples'] = samples
        return result


def count_parameters(model):
    """Return the number of trainable (requires_grad) scalar parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def _smoke_test():
    """Build a small model on random data and print shapes as a sanity check."""
    n_mechs = len(TPD_MECHANISM_LIST)
    # One sample per mechanism so every flow head is exercised.
    B, N_beta, T = n_mechs, 3, 500

    x = torch.randn(B, N_beta, 2, T)
    scan_mask = torch.ones(B, N_beta, T, dtype=torch.bool)
    sigmas = torch.randn(B, N_beta)
    flux_scales = torch.randn(B, N_beta)

    mechanism_ids = torch.arange(n_mechs)
    max_dim = max(TPD_MECHANISM_PARAMS[m]['dim'] for m in TPD_MECHANISM_LIST)
    mech_theta = torch.randn(B, max_dim)
    mech_theta_mask = torch.zeros(B, max_dim, dtype=torch.bool)
    for i, mid in enumerate(mechanism_ids.tolist()):
        d = TPD_MECHANISM_PARAMS[TPD_MECHANISM_LIST[mid]]['dim']
        mech_theta_mask[i, :d] = True

    print("=" * 60)
    print("Testing MultiMechanismFlowTPD (multi-heating-rate, Set Transformer)")
    print("=" * 60)

    model = MultiMechanismFlowTPD(
        d_context=128, d_model=128, n_coupling_layers=8,
        hidden_dim=128, coupling_type='affine',
    )

    total_params = count_parameters(model)
    print(f"Total parameters: {total_params:,}")
    print(f" Encoder: {count_parameters(model.encoder):,}")
    print(f" Classifier: {count_parameters(model.classifier):,}")
    for mech in TPD_MECHANISM_LIST:
        print(f" Flow ({mech}, dim={TPD_MECHANISM_PARAMS[mech]['dim']}): "
              f"{count_parameters(model.flow_heads[mech]):,}")

    out = model(x, mechanism_ids, mech_theta, mech_theta_mask,
                scan_mask=scan_mask, sigmas=sigmas, flux_scales=flux_scales)
    print(f"\nForward pass:")
    print(f" Logits shape: {out['logits'].shape}")
    print(f" NLL shape: {out['nll'].shape}")
    print(f" NLL values: {out['nll']}")

    pred = model.predict(x, scan_mask=scan_mask, sigmas=sigmas,
                         flux_scales=flux_scales, n_samples=100)
    print(f"\nPrediction:")
    print(f" Mechanism probs shape: {pred['mechanism_probs'].shape}")
    print(f" Mechanism xdB shape: {pred['mechanism_xdB'].shape}")
    print(f" Predicted mechanisms: {pred['mechanism_pred']}")
    for mech in TPD_MECHANISM_LIST:
        if pred['samples'][mech] is not None:
            print(f" {mech} samples shape: {pred['samples'][mech].shape}")


if __name__ == "__main__":
    _smoke_test()