| """ |
| MuLGIT Core Models |
| |
| Implements self-normalizing multi-omics integration networks based on SeNMo |
| paper (2405.08226) with central dogma-inspired layer architecture. |
| |
| Architecture design: |
| Layer 1 (DNA→RNA): Methylation → Gene Expression (epigenetic regulation) |
| Layer 2 (RNA→Protein): Gene Expression → Protein Expression (translation) |
| Layer 3 (Protein→Metabolite): Protein → Metabolite (metabolic pathways) |
| Layer 4 (Metabolite→Phenotype): All upstream → Survival/Longevity outcome |
| |
| Each "layer" is a self-normalizing feed-forward block with SELU activations |
| and AlphaDropout for training stability on high-dimensional low-sample data. |
| |
| References: |
| - SeNMo: Self-Normalizing Multi-Omics (arxiv:2405.08226) |
| - Self-Normalizing Neural Networks (Klambauer et al., NeurIPS 2017) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, Dict, List, Tuple |
| from dataclasses import dataclass |
|
|
|
|
| |
|
|
class SNNBlock(nn.Module):
    """
    One self-normalizing unit: Linear → SELU → AlphaDropout.

    The linear weights are drawn from a zero-mean normal distribution with
    standard deviation ``1 / sqrt(in_features)`` and the bias starts at zero.

    Args:
        in_features: size of the last dimension of the input tensor.
        out_features: size of the last dimension of the output tensor.
        dropout: drop probability handed to ``nn.AlphaDropout``.
    """

    def __init__(self, in_features: int, out_features: int, dropout: float = 0.1):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.selu = nn.SELU()
        self.dropout = nn.AlphaDropout(dropout)

        # Overwrite nn.Linear's default init: N(0, 1/fan_in) weights, zero bias.
        weight_std = 1.0 / (in_features ** 0.5)
        nn.init.normal_(self.linear.weight, mean=0.0, std=weight_std)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the affine map, then SELU, then alpha dropout."""
        projected = self.linear(x)
        activated = self.selu(projected)
        return self.dropout(activated)
|
|
|
|
class SNNStack(nn.Module):
    """
    Sequential stack of SNNBlocks forming a deep self-normalizing network.

    One SNNBlock is built per entry of ``hidden_dims`` and a final SNNBlock
    projects to ``output_dim``; every layer, including the last, therefore
    ends in SELU + AlphaDropout. With the defaults the widths are
    ``input_dim → 1024 → 512 → 256 → 128 → 48``.

    The SeNMo layout cited for this design is
    ``[1024, 512, 256, 128, 48, 48, 48]``; the defaults do NOT build the two
    extra 48-wide layers. Pass ``hidden_dims=[1024, 512, 256, 128, 48, 48]``
    to obtain that depth.

    Args:
        input_dim: number of input features.
        hidden_dims: widths of the hidden blocks, in order. ``None`` selects
            ``[1024, 512, 256, 128]``; an empty list yields a single block
            mapping ``input_dim`` directly to ``output_dim``.
        output_dim: width of the final block (also stored as
            ``self.output_dim``).
        dropout: AlphaDropout probability shared by every block.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: Optional[List[int]] = None,
        output_dim: int = 48,
        dropout: float = 0.1058,
    ):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [1024, 512, 256, 128]

        # Consecutive pairs of this list are the (in, out) sizes of each block.
        widths = [input_dim, *hidden_dims, output_dim]
        blocks = [
            SNNBlock(fan_in, fan_out, dropout)
            for fan_in, fan_out in zip(widths, widths[1:])
        ]

        self.network = nn.Sequential(*blocks)
        self.output_dim = output_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through every block in order."""
        return self.network(x)
|
|
|
|
| |
|
|
class OmicsEncoder(nn.Module):
    """
    Encodes a single omics modality through an SNNStack into a compact
    latent representation (48-dim by default).

    Args:
        input_dim: number of features of the modality.
        hidden_dims: hidden widths forwarded to SNNStack. ``None`` selects
            ``[1024, 512, 256, 128]``. An explicitly empty list is honoured
            and produces a single projection block, matching SNNStack's own
            handling of ``hidden_dims``.
        latent_dim: size of the produced latent vector.
        dropout: AlphaDropout probability used in every block.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: Optional[List[int]] = None,
        latent_dim: int = 48,
        dropout: float = 0.1058,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Test against None rather than truthiness: `hidden_dims or default`
        # would silently replace a deliberate empty list with the deep default.
        if hidden_dims is None:
            hidden_dims = [1024, 512, 256, 128]

        self.encoder = SNNStack(
            input_dim=input_dim,
            hidden_dims=hidden_dims,
            output_dim=latent_dim,
            dropout=dropout,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` to its latent representation."""
        return self.encoder(x)
|
|
|
|
| |
|
|
class CentralDogmaFusion(nn.Module):
    """
    Progressive multi-omics fusion ordered after the central dogma.

    Each of the four modalities (methylation, CNV, mRNA, miRNA) gets its own
    OmicsEncoder. The latents are then combined in stages:

        1. DNA level:   concat(methylation, cnv)
        2. DNA → RNA:   SNNBlock(concat(DNA level, mrna))
        3. Final:       SNNStack(concat(DNA→RNA, mrna, mirna))

    Design inspired by Life-Code's sequential modeling + SeNMo's SNN fusion.

    Args:
        dim_methylation: feature count of the methylation input.
        dim_cnv: feature count of the CNV input.
        dim_mrna: feature count of the mRNA input.
        dim_mirna: feature count of the miRNA input.
        latent_dim: width of every per-modality latent and of the fused output.
        dropout: AlphaDropout probability used throughout.
    """

    def __init__(
        self,
        dim_methylation: int,
        dim_cnv: int,
        dim_mrna: int,
        dim_mirna: int,
        latent_dim: int = 48,
        dropout: float = 0.1058,
    ):
        super().__init__()

        def build_encoder(n_features: int) -> nn.Module:
            return OmicsEncoder(n_features, latent_dim=latent_dim, dropout=dropout)

        # Per-modality encoders.
        self.methylation_encoder = build_encoder(dim_methylation)
        self.cnv_encoder = build_encoder(dim_cnv)
        self.mrna_encoder = build_encoder(dim_mrna)
        self.mirna_encoder = build_encoder(dim_mirna)

        # Both fusion stages consume three latent vectors side by side:
        #   stage 2: methylation + cnv + mrna
        #   stage 3: dna_to_rna + mrna + mirna
        triple_width = 3 * latent_dim

        self.dna_to_rna = SNNBlock(triple_width, latent_dim, dropout)
        self.final_fusion = SNNStack(
            input_dim=triple_width,
            hidden_dims=[256, 128],
            output_dim=latent_dim,
            dropout=dropout,
        )

    def forward(
        self,
        methylation: torch.Tensor,
        cnv: torch.Tensor,
        mrna: torch.Tensor,
        mirna: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Encode the four modalities and fuse them stage by stage.

        Returns:
            fused: final fused representation, last dimension ``latent_dim``.
            intermediates: dict with the per-modality latents
                (``methylation``, ``cnv``, ``mrna``, ``mirna``), the
                concatenated DNA-level latent (``dna_fused``) and the output
                of the DNA → RNA block (``dna_to_rna``).
        """
        # Encoders run in a fixed order (the dict literal evaluates top-down).
        reps = {
            "methylation": self.methylation_encoder(methylation),
            "cnv": self.cnv_encoder(cnv),
            "mrna": self.mrna_encoder(mrna),
            "mirna": self.mirna_encoder(mirna),
        }

        # Stage 1: DNA-level latents side by side.
        dna_level = torch.cat((reps["methylation"], reps["cnv"]), dim=-1)

        # Stage 2: DNA-level context joined with mRNA, compressed to latent_dim.
        rna_level = self.dna_to_rna(torch.cat((dna_level, reps["mrna"]), dim=-1))

        # Stage 3: everything downstream goes through the final SNN stack.
        fusion_input = torch.cat((rna_level, reps["mrna"], reps["mirna"]), dim=-1)
        fused = self.final_fusion(fusion_input)

        reps["dna_fused"] = dna_level
        reps["dna_to_rna"] = rna_level
        return fused, reps
|
|
|
|
| |
|
|
class CoxPredictionHead(nn.Module):
    """
    Maps a fused representation to a scalar risk score (log hazard ratio).
    Intended to be trained with a Cox proportional hazards loss.

    Following SeNMo, the head is a single bias-free linear layer; its weights
    are initialised from a normal distribution with std 0.001.

    Args:
        input_dim: size of the last dimension of the incoming representation.
    """

    def __init__(self, input_dim: int = 48):
        super().__init__()
        self.risk_layer = nn.Linear(input_dim, 1, bias=False)
        nn.init.normal_(self.risk_layer.weight, mean=0.0, std=0.001)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the risk score with the trailing size-1 dimension removed."""
        risk = self.risk_layer(x)
        return risk.squeeze(-1)
|
|
|
|
| |
|
|
class MuLGITModel(nn.Module):
    """
    Full MuLGIT model: Multi-layer Genotype Integration Transformer.

    Components:
        - ``fusion`` (CentralDogmaFusion): cross-layer multi-omics integration.
        - ``cox_head`` (CoxPredictionHead): survival/longevity risk score.
        - ``classifier`` (nn.Linear or None): class logits; only created when
          ``num_classes > 1``.

    Training: Cox proportional hazards loss for survival outcomes.
    Can also be used with binary cross-entropy for longevity classification.

    Args:
        dim_methylation: feature count of the methylation input.
        dim_cnv: feature count of the CNV input.
        dim_mrna: feature count of the mRNA input.
        dim_mirna: feature count of the miRNA input.
        latent_dim: width of the fused representation.
        dropout: AlphaDropout probability passed to the fusion module.
        num_classes: number of classes for the optional classifier head.
    """

    def __init__(
        self,
        dim_methylation: int,
        dim_cnv: int,
        dim_mrna: int,
        dim_mirna: int,
        latent_dim: int = 48,
        dropout: float = 0.1058,
        num_classes: int = 1,
    ):
        super().__init__()
        self.fusion = CentralDogmaFusion(
            dim_methylation=dim_methylation,
            dim_cnv=dim_cnv,
            dim_mrna=dim_mrna,
            dim_mirna=dim_mirna,
            latent_dim=latent_dim,
            dropout=dropout,
        )

        # Survival head is always present.
        self.cox_head = CoxPredictionHead(latent_dim)

        # Classification head exists only for genuinely multi-class setups.
        self.classifier = nn.Linear(latent_dim, num_classes) if num_classes > 1 else None

        self.latent_dim = latent_dim

    def forward(
        self,
        methylation: torch.Tensor,
        cnv: torch.Tensor,
        mrna: torch.Tensor,
        mirna: torch.Tensor,
        return_intermediates: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the central dogma pipeline.

        Args:
            methylation: (B, D_meth)
            cnv: (B, D_cnv)
            mrna: (B, D_mrna)
            mirna: (B, D_mirna)
            return_intermediates: if True, include the fusion module's
                intermediate representations in the result.

        Returns:
            Dict with:
                risk: predicted risk score (B,)
                fused: final fused representation (B, latent_dim)
                logits: class logits (B, num_classes), only when a
                    classifier head exists
                intermediates: dict of per-layer representations, only when
                    ``return_intermediates`` is True
        """
        fused, intermediates = self.fusion(methylation, cnv, mrna, mirna)

        result = {"risk": self.cox_head(fused)}
        result["fused"] = fused

        if self.classifier is not None:
            result["logits"] = self.classifier(fused)
        if return_intermediates:
            result["intermediates"] = intermediates

        return result
|
|
|
|
| |
|
|
@dataclass
class MuLGITConfig:
    """
    Hyperparameter bundle for a MuLGIT model.

    The four ``dim_*`` fields are required; everything else has a default.
    The first seven fields mirror the constructor arguments of MuLGITModel.
    The remaining fields are optimisation settings — presumably read by
    training code outside this block (not confirmed from this file).
    """

    # Input feature counts, one per omics modality (no defaults).
    dim_methylation: int
    dim_cnv: int
    dim_mrna: int
    dim_mirna: int

    # Model shape and regularisation.
    latent_dim: int = 48
    dropout: float = 0.1058
    num_classes: int = 1

    # Optimisation settings.
    learning_rate: float = 5.8e-4
    weight_decay: float = 0.00598
    batch_size: int = 256
    max_epochs: int = 100
|
|
|
|
| |
|
|
def create_mulgit_model(
    dim_methylation: int = 20000,
    dim_cnv: int = 20000,
    dim_mrna: int = 20000,
    dim_mirna: int = 2000,
    latent_dim: int = 48,
    dropout: float = 0.1058,
    **kwargs,
) -> MuLGITModel:
    """
    Build a MuLGITModel, filling in default sizes for anything not given.

    Args:
        dim_methylation: feature count of the methylation input.
        dim_cnv: feature count of the CNV input.
        dim_mrna: feature count of the mRNA input.
        dim_mirna: feature count of the miRNA input.
        latent_dim: width of the fused representation.
        dropout: AlphaDropout probability.
        **kwargs: forwarded unchanged to the MuLGITModel constructor.

    Returns:
        The newly constructed model.
    """
    model_args = {
        "dim_methylation": dim_methylation,
        "dim_cnv": dim_cnv,
        "dim_mrna": dim_mrna,
        "dim_mirna": dim_mirna,
        "latent_dim": latent_dim,
        "dropout": dropout,
    }
    return MuLGITModel(**model_args, **kwargs)
|
|
|
|
| |
|
|
class SeNMoFusion(nn.Module):
    """
    Direct multi-omics fusion as in SeNMo: concatenate all modalities → SNN.

    A single SNNStack encodes the concatenated features and a
    CoxPredictionHead turns the latent into a risk score. Simpler than
    CentralDogmaFusion because the central dogma flow is not modelled
    explicitly; useful as a baseline to compare against MuLGIT.

    Args:
        dim_methylation: feature count of the methylation input.
        dim_cnv: feature count of the CNV input.
        dim_mrna: feature count of the mRNA input.
        dim_mirna: feature count of the miRNA input.
        latent_dim: width of the encoded representation.
        dropout: AlphaDropout probability used in the encoder.
    """

    def __init__(
        self,
        dim_methylation: int,
        dim_cnv: int,
        dim_mrna: int,
        dim_mirna: int,
        latent_dim: int = 48,
        dropout: float = 0.1058,
    ):
        super().__init__()
        # The encoder sees every modality at once, so its input width is the sum.
        concat_width = sum((dim_methylation, dim_cnv, dim_mrna, dim_mirna))
        self.encoder = SNNStack(
            input_dim=concat_width,
            hidden_dims=[1024, 512, 256, 128],
            output_dim=latent_dim,
            dropout=dropout,
        )
        self.cox_head = CoxPredictionHead(latent_dim)

    def forward(
        self,
        methylation: torch.Tensor,
        cnv: torch.Tensor,
        mrna: torch.Tensor,
        mirna: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Concatenate the inputs along the last dimension, encode, and score.

        Returns:
            Dict with ``risk`` (output of the Cox head) and ``fused`` (the
            encoder output).
        """
        stacked_inputs = torch.cat((methylation, cnv, mrna, mirna), dim=-1)
        latent = self.encoder(stacked_inputs)
        risk = self.cox_head(latent)
        return {"risk": risk, "fused": latent}
|
|