""" MuLGIT Core Models Implements self-normalizing multi-omics integration networks based on SeNMo paper (2405.08226) with central dogma-inspired layer architecture. Architecture design: Layer 1 (DNA→RNA): Methylation → Gene Expression (epigenetic regulation) Layer 2 (RNA→Protein): Gene Expression → Protein Expression (translation) Layer 3 (Protein→Metabolite): Protein → Metabolite (metabolic pathways) Layer 4 (Metabolite→Phenotype): All upstream → Survival/Longevity outcome Each "layer" is a self-normalizing feed-forward block with SELU activations and AlphaDropout for training stability on high-dimensional low-sample data. References: - SeNMo: Self-Normalizing Multi-Omics (arxiv:2405.08226) - Self-Normalizing Neural Networks (Klambauer et al., NeurIPS 2017) """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Dict, List, Tuple from dataclasses import dataclass # ─── Self-Normalizing Building Blocks ─────────────────────────────────────── class SNNBlock(nn.Module): """ Single self-normalizing block: Linear → SELU → AlphaDropout. Matches SeNMo architecture from paper. """ def __init__(self, in_features: int, out_features: int, dropout: float = 0.1): super().__init__() self.linear = nn.Linear(in_features, out_features) self.selu = nn.SELU() self.dropout = nn.AlphaDropout(dropout) # SNN initialization (LeCun normal for SELU) nn.init.normal_(self.linear.weight, mean=0.0, std=1.0 / (in_features ** 0.5)) nn.init.zeros_(self.linear.bias) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.dropout(self.selu(self.linear(x))) class SNNStack(nn.Module): """ Stack of SNNBlocks forming a deep self-normalizing network, following SeNMo's layer dimensions: [1024, 512, 256, 128, 48, 48, 48]. """ def __init__( self, input_dim: int, hidden_dims: List[int] = None, output_dim: int = 48, dropout: float = 0.1058, ): super().__init__() if hidden_dims is None: hidden_dims = [1024, 512, 256, 128] layers = [] prev_dim = input_dim for h_dim in hidden_dims: layers.append(SNNBlock(prev_dim, h_dim, dropout)) prev_dim = h_dim # Output embedding layer layers.append(SNNBlock(prev_dim, output_dim, dropout)) self.network = nn.Sequential(*layers) self.output_dim = output_dim def forward(self, x: torch.Tensor) -> torch.Tensor: return self.network(x) # ─── Omics-Specific Encoders ──────────────────────────────────────────────── class OmicsEncoder(nn.Module): """ Encodes a single omics modality through SNN stack into a compact latent representation (48-dim, following SeNMo). """ def __init__( self, input_dim: int, hidden_dims: Optional[List[int]] = None, latent_dim: int = 48, dropout: float = 0.1058, ): super().__init__() self.input_dim = input_dim self.latent_dim = latent_dim self.encoder = SNNStack( input_dim=input_dim, hidden_dims=hidden_dims or [1024, 512, 256, 128], output_dim=latent_dim, dropout=dropout, ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.encoder(x) # ─── Central Dogma Layer ──────────────────────────────────────────────────── class CentralDogmaFusion(nn.Module): """ Multi-layer fusion following the central dogma: [Methylation, CNV] → [Gene Expression] → [miRNA] → [Protein] → Phenotype Each omics modality is separately encoded, then fused progressively through the biological information flow. Design inspired by Life-Code's sequential modeling + SeNMo's SNN fusion. """ def __init__( self, dim_methylation: int, dim_cnv: int, dim_mrna: int, dim_mirna: int, latent_dim: int = 48, dropout: float = 0.1058, ): super().__init__() # Individual modality encoders self.methylation_encoder = OmicsEncoder(dim_methylation, latent_dim=latent_dim, dropout=dropout) self.cnv_encoder = OmicsEncoder(dim_cnv, latent_dim=latent_dim, dropout=dropout) self.mrna_encoder = OmicsEncoder(dim_mrna, latent_dim=latent_dim, dropout=dropout) self.mirna_encoder = OmicsEncoder(dim_mirna, latent_dim=latent_dim, dropout=dropout) # Cross-layer fusion: DNA-level → RNA-level # [methylation + CNV] fused representation → conditions mRNA encoding dna_fusion_dim = latent_dim * 2 # methylation + CNV self.dna_to_rna = SNNBlock(dna_fusion_dim + latent_dim, latent_dim, dropout) # + mRNA # Full fusion: DNA + RNA → final representation full_fusion_dim = latent_dim * 3 # DNA_fused + mRNA + miRNA self.final_fusion = SNNStack( input_dim=full_fusion_dim, hidden_dims=[256, 128], output_dim=latent_dim, dropout=dropout, ) def forward( self, methylation: torch.Tensor, cnv: torch.Tensor, mrna: torch.Tensor, mirna: torch.Tensor, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Returns: fused: final fused representation (B, latent_dim) intermediates: dict of intermediate representations """ # Encode each modality z_meth = self.methylation_encoder(methylation) # DNA-level (epigenetic) z_cnv = self.cnv_encoder(cnv) # DNA-level (structural) z_mrna = self.mrna_encoder(mrna) # RNA-level z_mirna = self.mirna_encoder(mirna) # RNA-level (regulatory) # DNA-level fusion z_dna = torch.cat([z_meth, z_cnv], dim=-1) # (B, 2*latent_dim) # DNA→RNA: condition mRNA on DNA-level features z_dna_rna = torch.cat([z_dna, z_mrna], dim=-1) # (B, 3*latent_dim) z_dna_to_rna = self.dna_to_rna(z_dna_rna) # (B, latent_dim) # Full fusion: DNA + RNA + miRNA z_full = torch.cat([z_dna_to_rna, z_mrna, z_mirna], dim=-1) z_fused = self.final_fusion(z_full) intermediates = { "methylation": z_meth, "cnv": z_cnv, "mrna": z_mrna, "mirna": z_mirna, "dna_fused": z_dna, "dna_to_rna": z_dna_to_rna, } return z_fused, intermediates # ─── Survival Prediction Head ─────────────────────────────────────────────── class CoxPredictionHead(nn.Module): """ Outputs a risk score (log hazard ratio) from the fused representation. Trained with Cox proportional hazards loss. Following SeNMo: single linear layer produces risk score. """ def __init__(self, input_dim: int = 48): super().__init__() self.risk_layer = nn.Linear(input_dim, 1, bias=False) nn.init.normal_(self.risk_layer.weight, mean=0.0, std=0.001) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.risk_layer(x).squeeze(-1) # (B,) # ─── Full MuLGIT Model ────────────────────────────────────────────────────── class MuLGITModel(nn.Module): """ Full MuLGIT model: Multi-layer Genotype Integration Transformer. Combines: - CentralDogmaFusion: cross-layer multi-omics integration - CoxPredictionHead: survival/longevity risk prediction Training: Cox proportional hazards loss for survival outcomes. Can also be used with binary cross-entropy for longevity classification. """ def __init__( self, dim_methylation: int, dim_cnv: int, dim_mrna: int, dim_mirna: int, latent_dim: int = 48, dropout: float = 0.1058, num_classes: int = 1, # 1 for Cox regression, >1 for classification ): super().__init__() self.fusion = CentralDogmaFusion( dim_methylation=dim_methylation, dim_cnv=dim_cnv, dim_mrna=dim_mrna, dim_mirna=dim_mirna, latent_dim=latent_dim, dropout=dropout, ) # Survival prediction head self.cox_head = CoxPredictionHead(latent_dim) # Optional classification head (for tumor type etc.) if num_classes > 1: self.classifier = nn.Linear(latent_dim, num_classes) else: self.classifier = None self.latent_dim = latent_dim def forward( self, methylation: torch.Tensor, cnv: torch.Tensor, mrna: torch.Tensor, mirna: torch.Tensor, return_intermediates: bool = False, ) -> Dict[str, torch.Tensor]: """ Forward pass through the central dogma pipeline. Args: methylation: (B, D_meth) cnv: (B, D_cnv) mrna: (B, D_mrna) mirna: (B, D_mirna) return_intermediates: if True, return all layer representations Returns dict with: risk: predicted risk score (B,) logits: class logits if classifier present (B, num_classes) fused: final fused representation (B, latent_dim) intermediates: dict of per-layer representations (if requested) """ fused, intermediates = self.fusion(methylation, cnv, mrna, mirna) outputs = { "risk": self.cox_head(fused), "fused": fused, } if self.classifier is not None: outputs["logits"] = self.classifier(fused) if return_intermediates: outputs["intermediates"] = intermediates return outputs # ─── Model Configuration ──────────────────────────────────────────────────── @dataclass class MuLGITConfig: """Configuration for MuLGIT model.""" dim_methylation: int dim_cnv: int dim_mrna: int dim_mirna: int latent_dim: int = 48 dropout: float = 0.1058 num_classes: int = 1 learning_rate: float = 5.8e-4 # from SeNMo paper weight_decay: float = 0.00598 # from SeNMo paper batch_size: int = 256 # from SeNMo paper max_epochs: int = 100 # ─── Model Factory ────────────────────────────────────────────────────────── def create_mulgit_model( dim_methylation: int = 20000, dim_cnv: int = 20000, dim_mrna: int = 20000, dim_mirna: int = 2000, latent_dim: int = 48, dropout: float = 0.1058, **kwargs, ) -> MuLGITModel: """Create a MuLGIT model with sensible defaults.""" return MuLGITModel( dim_methylation=dim_methylation, dim_cnv=dim_cnv, dim_mrna=dim_mrna, dim_mirna=dim_mirna, latent_dim=latent_dim, dropout=dropout, **kwargs, ) # ─── Simpler Alternative: SeNMo-Style Single Fusion ───────────────────────── class SeNMoFusion(nn.Module): """ Direct multi-omics fusion as in SeNMo: concatenate all modalities → SNN. Simpler but doesn't model the central dogma flow explicitly. Good baseline to compare against MuLGIT. """ def __init__( self, dim_methylation: int, dim_cnv: int, dim_mrna: int, dim_mirna: int, latent_dim: int = 48, dropout: float = 0.1058, ): super().__init__() total_dim = dim_methylation + dim_cnv + dim_mrna + dim_mirna hidden_dims = [1024, 512, 256, 128] self.encoder = SNNStack(total_dim, hidden_dims, latent_dim, dropout) self.cox_head = CoxPredictionHead(latent_dim) def forward( self, methylation: torch.Tensor, cnv: torch.Tensor, mrna: torch.Tensor, mirna: torch.Tensor, ) -> Dict[str, torch.Tensor]: x = torch.cat([methylation, cnv, mrna, mirna], dim=-1) fused = self.encoder(x) return {"risk": self.cox_head(fused), "fused": fused}