# MuLGIT — mulgit/perturb/model.py
# Uploaded by vedatonuryilmaz (commit 27be5bc).
"""
Full MuLGITPerturb Model.
Combines the frozen MuLGIT backbone with trainable perturbation response
prediction components to produce:
1. Per-gene expression change (Δ) with uncertainty (σ²)
2. Predicted post-perturbation molecular state
3. Pathway-level activation scores
4. Risk shift prediction (change in survival/longevity risk)
Architecture:
Input: (baseline_multi_omics, perturbation_descriptor)
└─ MuLGIT (frozen) → z_fused (baseline latent state)
└─ PerturbationEncoder → z_pert (perturbation embedding)
└─ FiLMConditioning → z_cond
└─ DeltaDecoderWithUncertainty → Δ, σ²
└─ PathwayDecoder → pathway scores
└─ RiskShiftHead → Δ_risk
Output: {delta, sigma2, y_post, pathway_scores, delta_risk, risk_baseline, risk_post}, plus {z_fused, z_pert, z_cond} when return_all=True
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, List, Tuple
import sys
import os
# Import from sibling modules
from .config import MuLGITPerturbConfig
from .encoder import PerturbationEncoder, create_perturbation_encoder
from .conditioning import (
FiLMConditioning, DeltaDecoderWithUncertainty, PathwayDecoder,
create_film_conditioning, create_delta_decoder,
)
class RiskShiftHead(nn.Module):
    """
    Predict how much a perturbation shifts the MuLGIT survival risk score.

    The baseline latent state and the perturbation embedding are
    concatenated and fed through a small SELU / AlphaDropout MLP that emits
    one scalar per sample, read as Δ_risk:
        risk_post = risk_baseline + Δ_risk
    This ties perturbation prediction back to the survival outcome MuLGIT
    was originally trained on.

    Args:
        state_dim: width of the baseline latent state ``z_fused``.
        pert_dim: width of the perturbation embedding ``z_pert``.
    """
    def __init__(self, state_dim: int = 48, pert_dim: int = 768):
        super().__init__()
        width = state_dim + pert_dim
        # Built in the same order as a hand-written
        # Linear → SELU → AlphaDropout stack, so the ``net.<idx>`` keys
        # (0, 3, 6 for the Linear layers) stay checkpoint-compatible.
        stack = []
        for hidden in (256, 128):
            stack += [nn.Linear(width, hidden), nn.SELU(), nn.AlphaDropout(0.1)]
            width = hidden
        stack.append(nn.Linear(width, 1))
        self.net = nn.Sequential(*stack)
    def forward(self, z_fused: torch.Tensor, z_pert: torch.Tensor) -> torch.Tensor:
        """Return Δ_risk, one value per sample (trailing size-1 dim dropped)."""
        joint = torch.cat((z_fused, z_pert), dim=-1)
        scores = self.net(joint)  # (B, 1)
        return scores.squeeze(-1)  # (B,)
class MuLGITPerturb(nn.Module):
    """
    Full perturbation prediction model built on a frozen MuLGIT backbone.

    Input:
      - baseline_methylation, baseline_cnv, baseline_mrna, baseline_mirna
      - perturbation: SMILES (list[str]) and/or gene_ids (LongTensor) + pert_types (list[str])
    Output (dict returned by ``forward``):
      - delta: (B, n_genes) predicted logFC per gene
      - sigma2: (B, n_genes) predicted variance per gene
      - y_post: (B, n_genes) predicted post-perturbation expression
      - pathway_scores: (B, n_pathways) pathway activation
      - delta_risk: (B,) predicted change in survival risk
      - risk_baseline: (B,) MuLGIT risk for the baseline sample
      - risk_post: (B,) risk_baseline + delta_risk
      - z_fused: (B, latent_dim) baseline state          (return_all only)
      - z_pert: (B, pert_dim) perturbation embedding     (return_all only)
      - z_cond: FiLM-conditioned state                   (return_all only)
    """
    def __init__(
        self,
        mulgit_model,  # pretrained MuLGITModel (frozen)
        config: MuLGITPerturbConfig,
    ):
        super().__init__()
        self.config = config
        # ── Frozen MuLGIT backbone ───────────────────────────────────
        self.mulgit = mulgit_model
        self.mulgit_latent_dim = mulgit_model.latent_dim
        # Freeze ALL MuLGIT parameters; only the perturbation heads train.
        for p in self.mulgit.parameters():
            p.requires_grad = False
        # Backbone also runs in eval mode; ``train()`` below keeps it there.
        self.mulgit.eval()
        # ── Trainable perturbation components ────────────────────────
        # NOTE: creation order is kept stable (affects RNG init order and
        # state_dict key order for existing checkpoints).
        self.pert_encoder = create_perturbation_encoder(config)
        self.film = create_film_conditioning(config)
        self.delta_decoder = create_delta_decoder(config)
        # Pathway decoder: consumes the per-gene delta
        self.pathway_decoder = PathwayDecoder(
            n_genes=config.n_output_genes,
        )
        # Risk shift head: consumes [z_fused, z_pert]
        self.risk_head = RiskShiftHead(
            state_dim=config.mulgit_latent_dim,
            pert_dim=config.pert_output_dim,
        )
    def encode_baseline(
        self,
        methylation: torch.Tensor,
        cnv: torch.Tensor,
        mrna: torch.Tensor,
        mirna: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Encode baseline omics through the frozen MuLGIT backbone.

        Runs under ``torch.no_grad()`` so no autograd graph is built for the
        backbone. Returns MuLGIT's output dict (requested with
        ``return_intermediates=True``); ``forward`` reads its ``"fused"``
        and ``"risk"`` entries.
        """
        with torch.no_grad():
            outputs = self.mulgit(
                methylation, cnv, mrna, mirna,
                return_intermediates=True,
            )
        return outputs
    def forward(
        self,
        methylation: torch.Tensor,
        cnv: torch.Tensor,
        mrna: torch.Tensor,
        mirna: torch.Tensor,
        smiles_list: Optional[List[str]] = None,
        gene_ids: Optional[torch.LongTensor] = None,
        pert_types: Optional[List[str]] = None,
        return_all: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass: baseline + perturbation → predicted response.

        Args:
            methylation: (B, D_methylation)
            cnv: (B, D_cnv)
            mrna: (B, D_mrna) — also used as baseline expression for
                y_post = mrna[:, :n_genes] + delta; needs D_mrna >= n_genes
            mirna: (B, D_mirna)
            smiles_list: list of SMILES strings (length B), or None
            gene_ids: (B,) LongTensor of gene indices, or None
            pert_types: list of perturbation type strings, or None
            return_all: also return z_fused, z_pert and z_cond

        Returns:
            dict with keys: delta, sigma2, y_post, pathway_scores,
            delta_risk, risk_baseline, risk_post, and (if return_all)
            z_fused, z_pert, z_cond

        Raises:
            ValueError: if ``mrna`` has fewer columns than the decoder's
                predicted genes (slicing + broadcasting would otherwise
                silently produce a wrong y_post for a single-column mrna).
        """
        # 1. Encode baseline through frozen MuLGIT
        mulgit_outputs = self.encode_baseline(methylation, cnv, mrna, mirna)
        z_fused = mulgit_outputs["fused"]  # (B, latent_dim)
        risk_baseline = mulgit_outputs["risk"]  # expected (B,)
        # 2. Encode perturbation
        z_pert = self.pert_encoder(smiles_list, gene_ids, pert_types)  # (B, pert_dim)
        # 3. Condition baseline state on perturbation
        z_cond = self.film(z_fused, z_pert)  # (B, latent_dim)
        # 4. Predict Δ and uncertainty
        delta, sigma2 = self.delta_decoder(z_cond)  # (B, n_genes), (B, n_genes)
        # 5. Predicted post-perturbation expression. mrna is the baseline
        #    expression; if n_output_genes < D_mrna only the leading columns
        #    are used.
        n_genes = delta.shape[1]
        if mrna.shape[1] < n_genes:
            raise ValueError(
                f"mrna has {mrna.shape[1]} features but the delta decoder "
                f"predicts {n_genes} genes; baseline expression must cover "
                f"at least n_output_genes columns"
            )
        y_post = mrna[:, :n_genes] + delta
        # 6. Pathway scores
        pathway_scores = self.pathway_decoder(delta)  # (B, n_pathways)
        # 7. Risk shift
        delta_risk = self.risk_head(z_fused, z_pert)  # (B,)
        # If the backbone reports risk as (B, 1), adding a (B,) shift would
        # broadcast to (B, B); align it to delta_risk first.
        if risk_baseline.dim() == delta_risk.dim() + 1 and risk_baseline.shape[-1] == 1:
            risk_baseline = risk_baseline.squeeze(-1)
        risk_post = risk_baseline + delta_risk
        outputs = {
            "delta": delta,
            "sigma2": sigma2,
            "y_post": y_post,
            "pathway_scores": pathway_scores,
            "delta_risk": delta_risk,
            "risk_baseline": risk_baseline,
            "risk_post": risk_post,
        }
        if return_all:
            outputs["z_fused"] = z_fused
            outputs["z_pert"] = z_pert
            outputs["z_cond"] = z_cond
        return outputs
    def predict_counterfactual(
        self,
        methylation: torch.Tensor,
        cnv: torch.Tensor,
        mrna: torch.Tensor,
        mirna: torch.Tensor,
        smiles_list: Optional[List[str]] = None,
        gene_ids: Optional[torch.LongTensor] = None,
        pert_types: Optional[List[str]] = None,
        confidence_level: float = 0.95,
    ) -> Dict:
        """
        Predict perturbation response with per-gene confidence intervals.

        Calls ``forward(..., return_all=False)`` and, assuming a Gaussian
        predictive distribution, adds:
            ci_low  = delta - z * sqrt(sigma2)
            ci_high = delta + z * sqrt(sigma2)
        where z is the two-sided standard-normal quantile for
        ``confidence_level`` (≈1.96 for 0.95). ``"confidence_level"`` is
        echoed back in the result.

        Note: runs in the model's current train/eval mode with autograd
        enabled; in training mode dropout (e.g. in RiskShiftHead) is active.

        Raises:
            ValueError: if ``confidence_level`` is not strictly in (0, 1).
        """
        from statistics import NormalDist

        if not 0.0 < confidence_level < 1.0:
            raise ValueError(
                f"confidence_level must be in (0, 1), got {confidence_level!r}"
            )
        outputs = self.forward(
            methylation, cnv, mrna, mirna,
            smiles_list, gene_ids, pert_types,
            return_all=False,
        )
        # Two-sided quantile z_{α/2}; stdlib replaces scipy.stats.norm.ppf.
        z_score = NormalDist().inv_cdf(1 - (1 - confidence_level) / 2)
        delta = outputs["delta"]
        half_width = z_score * torch.sqrt(outputs["sigma2"])
        outputs["ci_low"] = delta - half_width
        outputs["ci_high"] = delta + half_width
        outputs["confidence_level"] = confidence_level
        return outputs
    def train(self, mode: bool = True):
        """
        Set train/eval mode on the trainable components.

        The MuLGIT backbone is always forced back to eval mode afterwards.
        Returns ``self``, like ``nn.Module.train``.
        """
        super().train(mode)
        self.mulgit.eval()  # always keep MuLGIT frozen
        return self
def create_mulgit_perturb(
    mulgit_model,
    config: Optional[MuLGITPerturbConfig] = None,
    **config_overrides,
) -> MuLGITPerturb:
    """
    Factory function to create a MuLGITPerturb model.

    Args:
        mulgit_model: pretrained MuLGITModel (from mulgit.models); its
            ``latent_dim`` seeds the default config.
        config: MuLGITPerturbConfig, or None to build a default one with
            ``mulgit_latent_dim`` taken from ``mulgit_model``. A supplied
            config is not modified: overrides are applied to a shallow copy.
        **config_overrides: override specific config values. Names the
            config does not define are ignored with a ``UserWarning``.

    Returns:
        MuLGITPerturb model ready for training
    """
    import copy
    import warnings

    if config is None:
        # Infer dimensions from the MuLGIT model
        config = MuLGITPerturbConfig(
            mulgit_latent_dim=mulgit_model.latent_dim,
        )
    elif config_overrides:
        # Don't leak overrides into the caller's config object.
        config = copy.copy(config)
    # Apply overrides; report (rather than silently drop) unknown names,
    # which are usually typos.
    unknown = []
    for key, value in config_overrides.items():
        if hasattr(config, key):
            setattr(config, key, value)
        else:
            unknown.append(key)
    if unknown:
        warnings.warn(
            "create_mulgit_perturb: ignoring unknown config override(s): "
            + ", ".join(sorted(unknown)),
            stacklevel=2,
        )
    return MuLGITPerturb(mulgit_model, config)