| """ |
| Full MuLGITPerturb Model. |
| |
| Combines the frozen MuLGIT backbone with trainable perturbation response |
| prediction components to produce: |
| 1. Per-gene expression change (Δ) with uncertainty (σ²) |
| 2. Predicted post-perturbation molecular state |
| 3. Pathway-level activation scores |
| 4. Risk shift prediction (change in survival/longevity risk) |
| |
| Architecture: |
| Input: (baseline_multi_omics, perturbation_descriptor) |
| └─ MuLGIT (frozen) → z_fused (baseline latent state) |
| └─ PerturbationEncoder → z_pert (perturbation embedding) |
| └─ FiLMConditioning → z_cond |
| └─ DeltaDecoderWithUncertainty → Δ, σ² |
| └─ PathwayDecoder → pathway scores |
| └─ RiskShiftHead → Δ_risk |
| Output: {delta, sigma2, y_post, pathway_scores, delta_risk, risk_baseline, risk_post, z_fused, z_pert, z_cond} |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, Dict, List, Tuple |
| import sys |
| import os |
|
|
| |
| from .config import MuLGITPerturbConfig |
| from .encoder import PerturbationEncoder, create_perturbation_encoder |
| from .conditioning import ( |
| FiLMConditioning, DeltaDecoderWithUncertainty, PathwayDecoder, |
| create_film_conditioning, create_delta_decoder, |
| ) |
|
|
|
|
class RiskShiftHead(nn.Module):
    """
    Regression head for the perturbation-induced change in MuLGIT risk.

    The baseline latent state and the perturbation embedding are
    concatenated along the last dimension and fed through a small SELU MLP
    that emits one scalar per sample, meant to be used as:
        risk_post = risk_baseline + Δ_risk

    Args:
        state_dim: size of the last dimension of the baseline state.
        pert_dim: size of the last dimension of the perturbation embedding.
    """

    def __init__(self, state_dim: int = 48, pert_dim: int = 768):
        super().__init__()
        # Layer widths: concatenated input -> 256 -> 128, then a final 1-unit output.
        widths = (state_dim + pert_dim, 256, 128)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend(
                (nn.Linear(fan_in, fan_out), nn.SELU(), nn.AlphaDropout(0.1))
            )
        layers.append(nn.Linear(widths[-1], 1))
        # Module order is Linear/SELU/AlphaDropout x2 + Linear, so the linear
        # layers keep the state_dict prefixes ``net.0``, ``net.3`` and ``net.6``.
        self.net = nn.Sequential(*layers)

    def forward(self, z_fused: torch.Tensor, z_pert: torch.Tensor) -> torch.Tensor:
        """Return Δ_risk with the trailing singleton dimension squeezed away."""
        joint = torch.cat((z_fused, z_pert), dim=-1)
        scores = self.net(joint)
        return scores.squeeze(-1)
|
|
|
|
class MuLGITPerturb(nn.Module):
    """
    Full perturbation prediction model built on a frozen MuLGIT backbone.

    Input:
        - baseline methylation, cnv, mrna, mirna tensors
        - perturbation: SMILES (list[str]) and/or gene_ids (LongTensor)
          + pert_types (list[str])

    Output (dict returned by ``forward``):
        - delta: (B, n_genes) predicted logFC per gene
        - sigma2: (B, n_genes) predicted variance per gene
        - y_post: (B, n_genes) predicted post-perturbation expression
        - pathway_scores: (B, n_pathways) pathway activation
        - delta_risk / risk_shift: (B,) change in survival risk (same tensor
          under both names)
        - risk_baseline, risk_post: backbone risk and risk after the shift
        - z_fused: (B, latent_dim) baseline state        (only if return_all)
        - z_pert: (B, pert_dim) perturbation embedding   (only if return_all)
        - z_cond: conditioned state fed to the decoder   (only if return_all)
    """

    def __init__(
        self,
        mulgit_model,
        config: "MuLGITPerturbConfig",
    ):
        """
        Args:
            mulgit_model: pretrained backbone; must expose ``latent_dim``,
                ``parameters()`` and ``eval()`` and be callable as in
                ``encode_baseline``.
            config: configuration used to size the trainable components.
        """
        super().__init__()
        self.config = config

        # Backbone is stored as a submodule of this model.
        self.mulgit = mulgit_model
        self.mulgit_latent_dim = mulgit_model.latent_dim

        # Freeze the backbone: no gradients, and keep it in eval mode.
        for p in self.mulgit.parameters():
            p.requires_grad = False
        self.mulgit.eval()

        # Trainable perturbation-response components.
        self.pert_encoder = create_perturbation_encoder(config)
        self.film = create_film_conditioning(config)
        self.delta_decoder = create_delta_decoder(config)

        # Pathway scores are decoded from the per-gene delta.
        self.pathway_decoder = PathwayDecoder(
            n_genes=config.n_output_genes,
        )

        # NOTE(review): the head is sized from config.mulgit_latent_dim, not
        # from mulgit_model.latent_dim — confirm callers keep the two equal.
        self.risk_head = RiskShiftHead(
            state_dim=config.mulgit_latent_dim,
            pert_dim=config.pert_output_dim,
        )

    def encode_baseline(
        self,
        methylation: torch.Tensor,
        cnv: torch.Tensor,
        mrna: torch.Tensor,
        mirna: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Encode baseline omics through the frozen MuLGIT backbone.

        Runs under ``torch.no_grad()`` and returns whatever the backbone
        returns for ``return_intermediates=True``; ``forward`` reads the
        ``"fused"`` and ``"risk"`` entries of that result.
        """
        with torch.no_grad():
            outputs = self.mulgit(
                methylation, cnv, mrna, mirna,
                return_intermediates=True,
            )
        return outputs

    def forward(
        self,
        methylation: torch.Tensor,
        cnv: torch.Tensor,
        mrna: torch.Tensor,
        mirna: torch.Tensor,
        smiles_list: Optional[List[str]] = None,
        gene_ids: Optional[torch.LongTensor] = None,
        pert_types: Optional[List[str]] = None,
        return_all: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass: baseline + perturbation → predicted response.

        Args:
            methylation: (B, D_methylation)
            cnv: (B, D_cnv)
            mrna: (B, D_mrna) — also used as baseline expression for
                y_post = mrna + delta; must have at least as many columns
                as ``delta``.
            mirna: (B, D_mirna)
            smiles_list: list of SMILES strings (length B), or None
            gene_ids: (B,) LongTensor of gene indices, or None
            pert_types: list of perturbation type strings, or None
            return_all: whether to also return z_fused, z_pert and z_cond

        Returns:
            dict with keys: delta, sigma2, y_post, pathway_scores,
            delta_risk, risk_shift (alias of delta_risk), risk_baseline,
            risk_post, and — when ``return_all`` — z_fused, z_pert, z_cond.

        Raises:
            ValueError: if ``mrna`` has fewer columns than ``delta``.
        """
        # 1. Baseline latent state and baseline risk from the frozen backbone.
        mulgit_outputs = self.encode_baseline(methylation, cnv, mrna, mirna)
        z_fused = mulgit_outputs["fused"]
        risk_baseline = mulgit_outputs["risk"]

        # 2. Perturbation embedding.
        z_pert = self.pert_encoder(smiles_list, gene_ids, pert_types)

        # 3. Condition the baseline state on the perturbation.
        z_cond = self.film(z_fused, z_pert)

        # 4. Per-gene change and its variance.
        delta, sigma2 = self.delta_decoder(z_cond)

        # 5. Post-perturbation expression. Only the first n_genes columns of
        #    mrna are used (presumably aligned with the decoder's gene order —
        #    verify against the data pipeline). Reject a narrower mrna
        #    explicitly: a single-column mrna would otherwise broadcast
        #    silently across all genes.
        n_genes = delta.shape[1]
        if mrna.shape[1] < n_genes:
            raise ValueError(
                f"mrna has {mrna.shape[1]} features but delta has {n_genes}; "
                "cannot compute y_post = mrna + delta"
            )
        y_post = mrna[:, :n_genes] + delta

        # 6. Pathway-level scores derived from the gene-level delta.
        pathway_scores = self.pathway_decoder(delta)

        # 7. Risk shift relative to the backbone's baseline risk.
        delta_risk = self.risk_head(z_fused, z_pert)
        risk_post = risk_baseline + delta_risk

        outputs = {
            "delta": delta,
            "sigma2": sigma2,
            "y_post": y_post,
            "pathway_scores": pathway_scores,
            "delta_risk": delta_risk,
            # Documented name for the same tensor; both keys are kept so
            # existing consumers of "delta_risk" keep working.
            "risk_shift": delta_risk,
            "risk_baseline": risk_baseline,
            "risk_post": risk_post,
        }

        if return_all:
            outputs["z_fused"] = z_fused
            outputs["z_pert"] = z_pert
            outputs["z_cond"] = z_cond

        return outputs

    def predict_counterfactual(
        self,
        methylation: torch.Tensor,
        cnv: torch.Tensor,
        mrna: torch.Tensor,
        mirna: torch.Tensor,
        smiles_list: Optional[List[str]] = None,
        gene_ids: Optional[torch.LongTensor] = None,
        pert_types: Optional[List[str]] = None,
        confidence_level: float = 0.95,
    ) -> Dict:
        """
        Predict perturbation response with confidence intervals.

        Runs the forward pass with ``return_all=False`` and adds a symmetric
        two-sided normal interval around ``delta``:
            ci = delta ± z * sqrt(sigma2),  z = Φ⁻¹((1 + confidence_level) / 2)

        Args:
            confidence_level: coverage of the interval, strictly in (0, 1).
            (remaining arguments are passed to ``forward`` unchanged)

        Returns:
            The ``forward`` output dict plus ``ci_low``, ``ci_high`` and
            ``confidence_level``.

        Raises:
            ValueError: if ``confidence_level`` is not strictly between 0 and 1.
        """
        # Fail fast: outside (0, 1) the normal quantile is NaN or infinite.
        if not 0.0 < confidence_level < 1.0:
            raise ValueError(
                "confidence_level must be strictly between 0 and 1, "
                f"got {confidence_level!r}"
            )

        # Go through __call__ (not .forward) so registered hooks still run.
        outputs = self(
            methylation, cnv, mrna, mirna,
            smiles_list, gene_ids, pert_types,
            return_all=False,
        )

        # Standard-normal quantile from the standard library; yields a plain
        # Python float, so no extra dependency is needed for this step.
        from statistics import NormalDist
        z_score = NormalDist().inv_cdf(1 - (1 - confidence_level) / 2)

        delta = outputs["delta"]
        half_width = z_score * torch.sqrt(outputs["sigma2"])

        outputs["ci_low"] = delta - half_width
        outputs["ci_high"] = delta + half_width
        outputs["confidence_level"] = confidence_level

        return outputs

    def train(self, mode: bool = True):
        """
        Set training mode. MuLGIT backbone stays frozen in eval mode.

        ``super().train(mode)`` also switches the backbone, so it is put back
        into eval mode immediately afterwards. Returns ``self``.
        """
        super().train(mode)
        self.mulgit.eval()
        return self
|
|
|
|
def create_mulgit_perturb(
    mulgit_model,
    config: Optional[MuLGITPerturbConfig] = None,
    **config_overrides,
) -> MuLGITPerturb:
    """
    Factory function to create a MuLGITPerturb model.

    Args:
        mulgit_model: pretrained MuLGITModel (from mulgit.models)
        config: MuLGITPerturbConfig, or None to build one with
            ``mulgit_latent_dim`` taken from ``mulgit_model.latent_dim``
        **config_overrides: override specific config values. Keys that are
            not attributes of the config are ignored with a warning.

    Returns:
        MuLGITPerturb model built from ``mulgit_model`` and the effective
        config. When overrides are given for a caller-supplied config, they
        are applied to a shallow copy, so the caller's object is left
        untouched; the effective config is available as ``model.config``.
    """
    import copy
    import warnings

    if config is None:
        # No config supplied: start from defaults, sized to the backbone.
        config = MuLGITPerturbConfig(
            mulgit_latent_dim=mulgit_model.latent_dim,
        )
    elif config_overrides:
        # Never mutate the caller's config: it may be shared by other models.
        config = copy.copy(config)

    unknown = []
    for key, value in config_overrides.items():
        if hasattr(config, key):
            setattr(config, key, value)
        else:
            unknown.append(key)

    if unknown:
        # Still best-effort (unknown keys are dropped), but no longer silent,
        # so a misspelled override name is visible to the caller.
        warnings.warn(
            "Ignoring unknown MuLGITPerturbConfig override(s): "
            + ", ".join(sorted(unknown)),
            stacklevel=2,
        )

    return MuLGITPerturb(mulgit_model, config)
|
|