# MuLGIT / mulgit/models.py
# vedatonuryilmaz's picture
# Upload mulgit/models.py with huggingface_hub
# 3ce5098 verified
"""
MuLGIT Core Models
Implements self-normalizing multi-omics integration networks based on SeNMo
paper (2405.08226) with central dogma-inspired layer architecture.
Architecture design:
Layer 1 (DNA→RNA): Methylation → Gene Expression (epigenetic regulation)
Layer 2 (RNA→Protein): Gene Expression → Protein Expression (translation)
Layer 3 (Protein→Metabolite): Protein → Metabolite (metabolic pathways)
Layer 4 (Metabolite→Phenotype): All upstream → Survival/Longevity outcome
Each "layer" is a self-normalizing feed-forward block with SELU activations
and AlphaDropout for training stability on high-dimensional low-sample data.
References:
- SeNMo: Self-Normalizing Multi-Omics (arxiv:2405.08226)
- Self-Normalizing Neural Networks (Klambauer et al., NeurIPS 2017)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, List, Tuple
from dataclasses import dataclass
# ─── Self-Normalizing Building Blocks ───────────────────────────────────────
class SNNBlock(nn.Module):
    """
    One self-normalizing unit: Linear → SELU → AlphaDropout.

    The linear weights are drawn from a zero-mean normal distribution with
    standard deviation ``1 / sqrt(in_features)`` (LeCun normal) and the bias
    starts at zero.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        dropout: probability passed to ``nn.AlphaDropout``.
    """

    def __init__(self, in_features: int, out_features: int, dropout: float = 0.1):
        super().__init__()

        # Registration order (linear, selu, dropout) is kept stable so that
        # module iteration and state_dict layout do not change.
        self.linear = nn.Linear(in_features, out_features)
        self.selu = nn.SELU()
        self.dropout = nn.AlphaDropout(dropout)

        # LeCun-normal weights, zero bias.
        fan_in_root = in_features ** 0.5
        lecun_std = 1.0 / fan_in_root
        nn.init.normal_(self.linear.weight, mean=0.0, std=lecun_std)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the affine map, then SELU, then alpha dropout."""
        projected = self.linear(x)
        activated = self.selu(projected)
        return self.dropout(activated)
class SNNStack(nn.Module):
    """
    Sequential stack of SNNBlocks forming a deep self-normalizing network.

    The stack consists of one block per entry in ``hidden_dims`` followed by
    a single output block of width ``output_dim``. With the defaults the
    block widths are therefore ``[1024, 512, 256, 128, 48]``.

    Args:
        input_dim: number of features in the last dimension of the input.
        hidden_dims: widths of the hidden blocks, in order. ``None`` selects
            ``[1024, 512, 256, 128]``; an empty list yields only the output
            block.
        output_dim: width of the final (embedding) block.
        dropout: dropout probability passed to every SNNBlock.

    Attributes:
        network: ``nn.Sequential`` holding the blocks.
        output_dim: width of the final block, kept for downstream sizing.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: Optional[List[int]] = None,
        output_dim: int = 48,
        dropout: float = 0.1058,
    ):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [1024, 512, 256, 128]

        # Consecutive pairs of widths give the (in, out) sizes of each block;
        # the last pair is the output embedding block.
        widths = [input_dim, *hidden_dims, output_dim]
        self.network = nn.Sequential(
            *(
                SNNBlock(fan_in, fan_out, dropout)
                for fan_in, fan_out in zip(widths, widths[1:])
            )
        )
        self.output_dim = output_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through every block in order and return the embedding."""
        return self.network(x)
# ─── Omics-Specific Encoders ────────────────────────────────────────────────
class OmicsEncoder(nn.Module):
    """
    Encode a single omics modality into a compact latent vector with an
    SNNStack (48-dim by default, following SeNMo).

    Args:
        input_dim: number of input features for this modality.
        hidden_dims: widths of the hidden SNN blocks. ``None`` selects the
            default ``[1024, 512, 256, 128]``; an explicit empty list is
            honoured and produces a single input → latent block.
        latent_dim: size of the produced embedding.
        dropout: dropout probability forwarded to the stack.

    Attributes:
        input_dim: the ``input_dim`` given at construction.
        latent_dim: the ``latent_dim`` given at construction.
        encoder: the underlying SNNStack.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: Optional[List[int]] = None,
        latent_dim: int = 48,
        dropout: float = 0.1058,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Test against None instead of truthiness: ``hidden_dims or default``
        # would silently replace a deliberately empty list with the default
        # four-layer configuration.
        if hidden_dims is None:
            hidden_dims = [1024, 512, 256, 128]

        self.encoder = SNNStack(
            input_dim=input_dim,
            hidden_dims=hidden_dims,
            output_dim=latent_dim,
            dropout=dropout,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the latent embedding produced by the SNN stack for ``x``."""
        return self.encoder(x)
# ─── Central Dogma Layer ────────────────────────────────────────────────────
class CentralDogmaFusion(nn.Module):
    """
    Progressive multi-omics fusion ordered along the central dogma.

    Flow implemented here:
        [Methylation, CNV] → conditioned with [mRNA] → joined with
        [mRNA, miRNA] → fused representation

    Every modality is first embedded by its own OmicsEncoder. The two
    DNA-level embeddings are concatenated, combined with the mRNA embedding
    through one SNNBlock, and the result is concatenated with the mRNA and
    miRNA embeddings and reduced by an SNNStack to ``latent_dim``.

    Design inspired by Life-Code's sequential modeling + SeNMo's SNN fusion.

    Args:
        dim_methylation: number of methylation input features.
        dim_cnv: number of CNV input features.
        dim_mrna: number of mRNA input features.
        dim_mirna: number of miRNA input features.
        latent_dim: embedding width used by every encoder and fusion stage.
        dropout: dropout probability used throughout.
    """

    def __init__(
        self,
        dim_methylation: int,
        dim_cnv: int,
        dim_mrna: int,
        dim_mirna: int,
        latent_dim: int = 48,
        dropout: float = 0.1058,
    ):
        super().__init__()

        # One encoder per modality, all sharing the same latent width/dropout.
        encoder_kwargs = {"latent_dim": latent_dim, "dropout": dropout}
        self.methylation_encoder = OmicsEncoder(dim_methylation, **encoder_kwargs)
        self.cnv_encoder = OmicsEncoder(dim_cnv, **encoder_kwargs)
        self.mrna_encoder = OmicsEncoder(dim_mrna, **encoder_kwargs)
        self.mirna_encoder = OmicsEncoder(dim_mirna, **encoder_kwargs)

        # DNA → RNA stage: input is methylation + CNV + mRNA embeddings.
        self.dna_to_rna = SNNBlock(3 * latent_dim, latent_dim, dropout)

        # Final stage: input is DNA→RNA output + mRNA + miRNA embeddings.
        self.final_fusion = SNNStack(
            input_dim=3 * latent_dim,
            hidden_dims=[256, 128],
            output_dim=latent_dim,
            dropout=dropout,
        )

    def forward(
        self,
        methylation: torch.Tensor,
        cnv: torch.Tensor,
        mrna: torch.Tensor,
        mirna: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Encode the four modalities and fuse them stage by stage.

        Returns:
            fused: final fused representation (B, latent_dim)
            intermediates: dict with the per-modality embeddings
                ("methylation", "cnv", "mrna", "mirna"), the concatenated
                DNA-level embedding ("dna_fused") and the DNA→RNA stage
                output ("dna_to_rna")
        """
        # Per-modality embeddings.
        meth_latent = self.methylation_encoder(methylation)  # DNA-level (epigenetic)
        cnv_latent = self.cnv_encoder(cnv)                   # DNA-level (structural)
        mrna_latent = self.mrna_encoder(mrna)                # RNA-level
        mirna_latent = self.mirna_encoder(mirna)             # RNA-level (regulatory)

        # DNA-level embedding: (B, 2*latent_dim)
        dna_latent = torch.cat([meth_latent, cnv_latent], dim=-1)

        # Condition mRNA on the DNA-level features: (B, 3*latent_dim) → (B, latent_dim)
        dna_rna_input = torch.cat([dna_latent, mrna_latent], dim=-1)
        dna_rna_latent = self.dna_to_rna(dna_rna_input)

        # Final fusion over DNA→RNA output, mRNA and miRNA embeddings.
        fusion_input = torch.cat([dna_rna_latent, mrna_latent, mirna_latent], dim=-1)
        fused = self.final_fusion(fusion_input)

        intermediates = {
            "methylation": meth_latent,
            "cnv": cnv_latent,
            "mrna": mrna_latent,
            "mirna": mirna_latent,
            "dna_fused": dna_latent,
            "dna_to_rna": dna_rna_latent,
        }
        return fused, intermediates
# ─── Survival Prediction Head ───────────────────────────────────────────────
class CoxPredictionHead(nn.Module):
    """
    Map a fused representation to a scalar risk score (log hazard ratio).

    Intended to be trained with a Cox proportional hazards loss. Following
    SeNMo, the head is a single bias-free linear layer; its weights are
    initialised from a zero-mean normal with standard deviation 0.001.

    Args:
        input_dim: width of the incoming representation (default 48).
    """

    def __init__(self, input_dim: int = 48):
        super().__init__()
        self.risk_layer = nn.Linear(input_dim, 1, bias=False)
        nn.init.normal_(self.risk_layer.weight, mean=0.0, std=0.001)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one risk score per sample, with the trailing unit axis removed."""
        risk_column = self.risk_layer(x)  # (..., 1)
        return risk_column.squeeze(-1)    # (B,)
# ─── Full MuLGIT Model ──────────────────────────────────────────────────────
class MuLGITModel(nn.Module):
    """
    Full MuLGIT model: Multi-layer Genotype Integration Transformer.

    Components:
    - ``fusion`` (CentralDogmaFusion): cross-layer multi-omics integration
    - ``cox_head`` (CoxPredictionHead): survival/longevity risk prediction
    - ``classifier`` (optional ``nn.Linear``): class logits, created only
      when ``num_classes > 1``; otherwise the attribute is ``None``

    Training: Cox proportional hazards loss for survival outcomes.
    Can also be used with binary cross-entropy for longevity classification.

    Args:
        dim_methylation: number of methylation input features.
        dim_cnv: number of CNV input features.
        dim_mrna: number of mRNA input features.
        dim_mirna: number of miRNA input features.
        latent_dim: width of the fused representation.
        dropout: dropout probability forwarded to the fusion network.
        num_classes: 1 for Cox regression only, >1 to add a classifier head.
    """

    def __init__(
        self,
        dim_methylation: int,
        dim_cnv: int,
        dim_mrna: int,
        dim_mirna: int,
        latent_dim: int = 48,
        dropout: float = 0.1058,
        num_classes: int = 1,  # 1 for Cox regression, >1 for classification
    ):
        super().__init__()

        self.fusion = CentralDogmaFusion(
            dim_methylation=dim_methylation,
            dim_cnv=dim_cnv,
            dim_mrna=dim_mrna,
            dim_mirna=dim_mirna,
            latent_dim=latent_dim,
            dropout=dropout,
        )

        # Survival head is always present.
        self.cox_head = CoxPredictionHead(latent_dim)

        # Classification head (for tumor type etc.) only when requested.
        self.classifier = nn.Linear(latent_dim, num_classes) if num_classes > 1 else None

        self.latent_dim = latent_dim

    def forward(
        self,
        methylation: torch.Tensor,
        cnv: torch.Tensor,
        mrna: torch.Tensor,
        mirna: torch.Tensor,
        return_intermediates: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the central dogma pipeline.

        Args:
            methylation: (B, D_meth)
            cnv: (B, D_cnv)
            mrna: (B, D_mrna)
            mirna: (B, D_mirna)
            return_intermediates: if True, also return the per-layer
                representations produced by the fusion module

        Returns dict with:
            risk: predicted risk score (B,)
            fused: final fused representation (B, latent_dim)
            logits: class logits (B, num_classes), only if a classifier exists
            intermediates: dict of per-layer representations (if requested)
        """
        fused, per_layer = self.fusion(methylation, cnv, mrna, mirna)

        result = {}
        result["risk"] = self.cox_head(fused)
        result["fused"] = fused

        if self.classifier is not None:
            result["logits"] = self.classifier(fused)

        if return_intermediates:
            result["intermediates"] = per_layer

        return result
# ─── Model Configuration ────────────────────────────────────────────────────
@dataclass
class MuLGITConfig:
    """
    Configuration for MuLGIT model.

    The four ``dim_*`` fields are required; everything else has a default.
    Field names for the architecture section mirror the keyword arguments
    of ``MuLGITModel``.
    """

    # --- Input feature counts per modality (required) ---
    dim_methylation: int
    dim_cnv: int
    dim_mrna: int
    dim_mirna: int

    # --- Architecture ---
    latent_dim: int = 48
    dropout: float = 0.1058
    num_classes: int = 1

    # --- Optimisation ---
    learning_rate: float = 5.8e-4  # from SeNMo paper
    weight_decay: float = 0.00598  # from SeNMo paper
    batch_size: int = 256  # from SeNMo paper
    max_epochs: int = 100
# ─── Model Factory ──────────────────────────────────────────────────────────
def create_mulgit_model(
    dim_methylation: int = 20000,
    dim_cnv: int = 20000,
    dim_mrna: int = 20000,
    dim_mirna: int = 2000,
    latent_dim: int = 48,
    dropout: float = 0.1058,
    **kwargs,
) -> MuLGITModel:
    """
    Create a MuLGIT model with sensible defaults.

    Args:
        dim_methylation: number of methylation input features.
        dim_cnv: number of CNV input features.
        dim_mrna: number of mRNA input features.
        dim_mirna: number of miRNA input features.
        latent_dim: width of the fused representation.
        dropout: dropout probability.
        **kwargs: any further keyword arguments accepted by ``MuLGITModel``
            (for example ``num_classes``), passed through unchanged.

    Returns:
        A newly constructed ``MuLGITModel``.
    """
    named_settings = {
        "dim_methylation": dim_methylation,
        "dim_cnv": dim_cnv,
        "dim_mrna": dim_mrna,
        "dim_mirna": dim_mirna,
        "latent_dim": latent_dim,
        "dropout": dropout,
    }
    return MuLGITModel(**named_settings, **kwargs)
# ─── Simpler Alternative: SeNMo-Style Single Fusion ─────────────────────────
class SeNMoFusion(nn.Module):
    """
    Direct multi-omics fusion as in SeNMo: concatenate all modalities → SNN.

    All four inputs are joined along the feature axis and encoded by a
    single SNNStack, followed by a CoxPredictionHead. Simpler than MuLGIT
    because it does not model the central dogma flow explicitly; useful as
    a baseline to compare against.

    Args:
        dim_methylation: number of methylation input features.
        dim_cnv: number of CNV input features.
        dim_mrna: number of mRNA input features.
        dim_mirna: number of miRNA input features.
        latent_dim: width of the fused representation.
        dropout: dropout probability used by the encoder.
    """

    def __init__(
        self,
        dim_methylation: int,
        dim_cnv: int,
        dim_mrna: int,
        dim_mirna: int,
        latent_dim: int = 48,
        dropout: float = 0.1058,
    ):
        super().__init__()
        # The encoder input is the concatenation of every modality.
        concat_width = sum((dim_methylation, dim_cnv, dim_mrna, dim_mirna))
        self.encoder = SNNStack(
            concat_width,
            [1024, 512, 256, 128],
            latent_dim,
            dropout,
        )
        self.cox_head = CoxPredictionHead(latent_dim)

    def forward(
        self,
        methylation: torch.Tensor,
        cnv: torch.Tensor,
        mrna: torch.Tensor,
        mirna: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Concatenate the modalities on the last axis, encode, and score.

        Returns dict with:
            risk: output of the Cox head for the fused representation
            fused: encoder output for the concatenated input
        """
        modalities = [methylation, cnv, mrna, mirna]
        fused = self.encoder(torch.cat(modalities, dim=-1))
        risk = self.cox_head(fused)
        return {"risk": risk, "fused": fused}