""" Full MuLGITPerturb Model. Combines the frozen MuLGIT backbone with trainable perturbation response prediction components to produce: 1. Per-gene expression change (Δ) with uncertainty (σ²) 2. Predicted post-perturbation molecular state 3. Pathway-level activation scores 4. Risk shift prediction (change in survival/longevity risk) Architecture: Input: (baseline_multi_omics, perturbation_descriptor) └─ MuLGIT (frozen) → z_fused (baseline latent state) └─ PerturbationEncoder → z_pert (perturbation embedding) └─ FiLMConditioning → z_cond └─ DeltaDecoderWithUncertainty → Δ, σ² └─ PathwayDecoder → pathway scores └─ RiskShiftHead → Δ_risk Output: {delta, sigma2, y_post, pathway_scores, risk_shift, z_fused, z_pert} """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Dict, List, Tuple import sys import os # Import from sibling modules from .config import MuLGITPerturbConfig from .encoder import PerturbationEncoder, create_perturbation_encoder from .conditioning import ( FiLMConditioning, DeltaDecoderWithUncertainty, PathwayDecoder, create_film_conditioning, create_delta_decoder, ) class RiskShiftHead(nn.Module): """ Predicts the change in MuLGIT survival risk caused by a perturbation. Takes the perturbation embedding and the baseline state, predicts how much the risk score should change: risk_post = risk_baseline + Δ_risk This links perturbation prediction back to the survival outcome that MuLGIT was originally trained for. """ def __init__(self, state_dim: int = 48, pert_dim: int = 768): super().__init__() combined_dim = state_dim + pert_dim self.net = nn.Sequential( nn.Linear(combined_dim, 256), nn.SELU(), nn.AlphaDropout(0.1), nn.Linear(256, 128), nn.SELU(), nn.AlphaDropout(0.1), nn.Linear(128, 1), ) def forward(self, z_fused: torch.Tensor, z_pert: torch.Tensor) -> torch.Tensor: """Predict risk shift Δ_risk.""" combined = torch.cat([z_fused, z_pert], dim=-1) return self.net(combined).squeeze(-1) # (B,) class MuLGITPerturb(nn.Module): """ Full perturbation prediction model built on frozen MuLGIT backbone. Input: - baseline_methylation, baseline_cnv, baseline_mrna, baseline_mirna - perturbation: SMILES (list[str]) and/or gene_ids (LongTensor) + pert_types (list[str]) Output: - delta: (B, n_genes) predicted logFC per gene - sigma2: (B, n_genes) predicted variance per gene - y_post: (B, n_genes) predicted post-perturbation expression - pathway_scores: (B, n_pathways) pathway activation - risk_shift: (B,) change in survival risk - z_fused: (B, latent_dim) baseline state - z_pert: (B, pert_dim) perturbation embedding """ def __init__( self, mulgit_model, # pretrained MuLGITModel (frozen) config: MuLGITPerturbConfig, ): super().__init__() self.config = config # ── Frozen MuLGIT backbone ─────────────────────────────────── self.mulgit = mulgit_model self.mulgit_latent_dim = mulgit_model.latent_dim # Freeze ALL MuLGIT parameters for p in self.mulgit.parameters(): p.requires_grad = False self.mulgit.eval() # ── Trainable perturbation components ──────────────────────── self.pert_encoder = create_perturbation_encoder(config) self.film = create_film_conditioning(config) self.delta_decoder = create_delta_decoder(config) # Pathway decoder self.pathway_decoder = PathwayDecoder( n_genes=config.n_output_genes, ) # Risk shift head self.risk_head = RiskShiftHead( state_dim=config.mulgit_latent_dim, pert_dim=config.pert_output_dim, ) def encode_baseline( self, methylation: torch.Tensor, cnv: torch.Tensor, mrna: torch.Tensor, mirna: torch.Tensor, ) -> Dict[str, torch.Tensor]: """ Encode baseline omics through frozen MuLGIT. Returns the fused latent state and intermediates. """ with torch.no_grad(): outputs = self.mulgit( methylation, cnv, mrna, mirna, return_intermediates=True, ) return outputs def forward( self, methylation: torch.Tensor, cnv: torch.Tensor, mrna: torch.Tensor, mirna: torch.Tensor, smiles_list: Optional[List[str]] = None, gene_ids: Optional[torch.LongTensor] = None, pert_types: Optional[List[str]] = None, return_all: bool = True, ) -> Dict[str, torch.Tensor]: """ Forward pass: baseline + perturbation → predicted response. Args: methylation: (B, D_methylation) cnv: (B, D_cnv) mrna: (B, D_mrna) — also used as baseline expression for y_post = mrna + delta mirna: (B, D_mirna) smiles_list: list of SMILES strings (length B), or None gene_ids: (B,) LongTensor of gene indices, or None pert_types: list of perturbation type strings, or None return_all: whether to return all intermediate representations Returns: dict with keys: delta, sigma2, y_post, pathway_scores, risk_shift, z_fused, z_pert, risk_baseline, risk_post """ # 1. Encode baseline through frozen MuLGIT mulgit_outputs = self.encode_baseline(methylation, cnv, mrna, mirna) z_fused = mulgit_outputs["fused"] # (B, latent_dim) risk_baseline = mulgit_outputs["risk"] # (B,) # 2. Encode perturbation z_pert = self.pert_encoder(smiles_list, gene_ids, pert_types) # (B, pert_dim) # 3. Condition baseline state on perturbation z_cond = self.film(z_fused, z_pert) # (B, latent_dim) # 4. Predict Δ and uncertainty delta, sigma2 = self.delta_decoder(z_cond) # (B, n_genes), (B, n_genes) # 5. Predicted post-perturbation expression # mrna is the baseline expression; we predict delta from it # If n_output_genes != n_mrna, we use a subset y_post = mrna[:, :delta.shape[1]] + delta # 6. Pathway scores pathway_scores = self.pathway_decoder(delta) # (B, n_pathways) # 7. Risk shift delta_risk = self.risk_head(z_fused, z_pert) # (B,) risk_post = risk_baseline + delta_risk outputs = { "delta": delta, "sigma2": sigma2, "y_post": y_post, "pathway_scores": pathway_scores, "delta_risk": delta_risk, "risk_baseline": risk_baseline, "risk_post": risk_post, } if return_all: outputs["z_fused"] = z_fused outputs["z_pert"] = z_pert outputs["z_cond"] = z_cond return outputs def predict_counterfactual( self, methylation: torch.Tensor, cnv: torch.Tensor, mrna: torch.Tensor, mirna: torch.Tensor, smiles_list: Optional[List[str]] = None, gene_ids: Optional[torch.LongTensor] = None, pert_types: Optional[List[str]] = None, confidence_level: float = 0.95, ) -> Dict: """ Predict perturbation response with confidence intervals. Returns a structured dict suitable for downstream analysis, including per-gene effect sizes with uncertainty. """ outputs = self.forward( methylation, cnv, mrna, mirna, smiles_list, gene_ids, pert_types, return_all=False, ) # Compute confidence intervals from predicted variance # Assuming Gaussian predictive distribution: # CI: delta ± z_α/2 * sqrt(sigma2) from scipy.stats import norm z_score = norm.ppf(1 - (1 - confidence_level) / 2) # ~1.96 for 95% delta = outputs["delta"] sigma2 = outputs["sigma2"] ci_low = delta - z_score * torch.sqrt(sigma2) ci_high = delta + z_score * torch.sqrt(sigma2) outputs["ci_low"] = ci_low outputs["ci_high"] = ci_high outputs["confidence_level"] = confidence_level return outputs def train(self, mode: bool = True): """ Set training mode. MuLGIT backbone stays frozen in eval mode. """ super().train(mode) self.mulgit.eval() # always keep MuLGIT frozen return self def create_mulgit_perturb( mulgit_model, config: Optional[MuLGITPerturbConfig] = None, **config_overrides, ) -> MuLGITPerturb: """ Factory function to create a MuLGITPerturb model. Args: mulgit_model: pretrained MuLGITModel (from mulgit.models) config: MuLGITPerturbConfig, or None to use defaults **config_overrides: override specific config values Returns: MuLGITPerturb model ready for training """ if config is None: # Infer dimensions from the MuLGIT model config = MuLGITPerturbConfig( mulgit_latent_dim=mulgit_model.latent_dim, ) # Apply overrides for key, value in config_overrides.items(): if hasattr(config, key): setattr(config, key, value) return MuLGITPerturb(mulgit_model, config)