"""
Causal Discovery Module

Identifies causal genetic factors and molecular interactions underlying
exceptional longevity. Combines structural causal models with deep
learning-based causal inference.

Methods implemented:
1. Causal Feature Selection via Information Bottleneck (Seq2Exp-inspired)
   - Learn binary masks that identify causal features from each omics layer
   - Beta distribution prior for sparsity
2. Causal Structure Learning via NOTEARS-inspired DAG constraint
   - Learn causal graph between molecular features
   - Differentiable acyclicity constraint
3. Causal Mediation Analysis
   - Identify mediated effects through the central dogma layers
   - Decompose total effect into direct and indirect (pathway-mediated)

References:
- Seq2Exp (arxiv:2502.13991): Causal regulatory element discovery
- Avici: Amortized causal structure learning in genomics
- NOTEARS: Non-combinatorial Optimization via Trace Exponential
  Augmented lagRangian Structure learning
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Tuple, Dict


# ─── Causal Feature Selection ────────────────────────────────────────────────

class CausalFeatureMask(nn.Module):
    """
    Learns a binary mask over input features identifying causal features.

    Inspired by Seq2Exp's information bottleneck. The mask is learned via the
    binary concrete (two-class Gumbel-softmax) reparameterization so that it
    stays differentiable during training.

    NOTE: ``prior_alpha`` / ``prior_beta`` are stored on the module but are not
    used by ``sparsity_loss`` in this class; the sparsity term is a plain mean
    of the selection probabilities.
    """

    def __init__(
        self,
        num_features: int,
        prior_alpha: float = 0.1,
        prior_beta: float = 0.9,
        temperature: float = 0.5,
    ):
        """
        Args:
            num_features: number of input features
            prior_alpha, prior_beta: Beta distribution parameters
                (skewed toward 0 to encourage sparse selection); stored only
            temperature: relaxation temperature (lower = more discrete)
        """
        super().__init__()
        # Learnable logits for each feature's selection probability
        self.logit_p = nn.Parameter(torch.zeros(num_features))
        self.prior_alpha = prior_alpha
        self.prior_beta = prior_beta
        self.temperature = temperature

    def forward(self, training: bool = True) -> torch.Tensor:
        """
        Return a soft (training) or hard (inference) mask of shape
        ``(num_features,)``.

        Args:
            training: if True, draw a relaxed stochastic mask in (0, 1);
                otherwise return a deterministic 0/1 mask obtained by
                thresholding ``sigmoid(logit_p)`` at 0.5.
        """
        if not training:
            # Hard threshold at 0.5
            return (torch.sigmoid(self.logit_p) > 0.5).float()

        # Binary concrete relaxation. The two-class Gumbel-softmax reduces to
        # sigmoid((logit + g1 - g2) / T), and the difference of two Gumbel
        # samples is Logistic noise: log(u) - log(1 - u). Using a single
        # Gumbel sample instead would bias the mask toward "selected".
        eps = torch.finfo(self.logit_p.dtype).eps
        u = torch.rand_like(self.logit_p).clamp(eps, 1.0 - eps)
        logistic_noise = torch.log(u) - torch.log1p(-u)
        return torch.sigmoid((self.logit_p + logistic_noise) / self.temperature)

    def sparsity_loss(self) -> torch.Tensor:
        """
        Sparsity regularization: penalize large selection probabilities.

        Returns the mean of ``sigmoid(logit_p)`` (an L1-style penalty, since
        the probabilities are non-negative), which pushes p → 0 for features
        that do not help the objective.
        """
        p = torch.sigmoid(self.logit_p)
        return p.mean()


class CausalOmicsSelector(nn.Module):
    """
    Per-modality causal feature selection.

    Holds one ``CausalFeatureMask`` per named modality and applies it to the
    matching input tensor.
    """

    def __init__(
        self,
        modality_dims: Dict[str, int],
        prior_alpha: float = 0.1,
        prior_beta: float = 0.9,
    ):
        """
        Args:
            modality_dims: mapping from modality name to its feature count
            prior_alpha, prior_beta: forwarded to every ``CausalFeatureMask``
        """
        super().__init__()
        self.masks = nn.ModuleDict({
            name: CausalFeatureMask(dim, prior_alpha, prior_beta)
            for name, dim in modality_dims.items()
        })
        self.modality_dims = modality_dims

    def forward(
        self,
        modalities: Dict[str, torch.Tensor],
        training: bool = True,
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """
        Apply causal masks to each modality.

        Args:
            modalities: mapping from modality name to its feature tensor;
                every name must have a mask registered in ``self.masks``
                (an unknown name raises ``KeyError``).
            training: forwarded to each mask (soft vs. hard mask).

        Returns:
            selected: masked features per modality
            masks: learned masks per modality
        """
        selected = {}
        masks = {}
        for name, x in modalities.items():
            mask = self.masks[name](training=training)
            # unsqueeze(0) adds a leading axis so the mask broadcasts over it
            selected[name] = x * mask.unsqueeze(0)
            masks[name] = mask
        return selected, masks

    def total_sparsity_loss(self) -> torch.Tensor:
        """Sum of sparsity losses across all modalities."""
        return sum(mask.sparsity_loss() for mask in self.masks.values())


# ─── Causal Graph Structure Learning ────────────────────────────────────────

class CausalGraphLearner(nn.Module):
    """
    Learns a causal graph (DAG) between a set of latent variables using
    a differentiable acyclicity constraint (NOTEARS-inspired).

    Adapted for molecular features: the learned adjacency matrix represents
    causal relationships between latent molecular representations.
    """

    def __init__(
        self,
        num_variables: int,
        hidden_dim: int = 64,
        lambda_dag: float = 1.0,
    ):
        """
        Args:
            num_variables: number of variables in the causal graph
            hidden_dim: dimension of each variable's representation
                (accepted for interface compatibility; not used here)
            lambda_dag: weight for the DAG constraint (stored only; callers
                are expected to apply it to ``dag_constraint()`` themselves)
        """
        super().__init__()
        # Learnable adjacency matrix (causal strengths)
        self.W = nn.Parameter(torch.zeros(num_variables, num_variables))
        self.num_variables = num_variables
        self.lambda_dag = lambda_dag
        nn.init.xavier_normal_(self.W)

    def forward(self) -> torch.Tensor:
        """Returns the learned weighted adjacency matrix."""
        return self.W

    def dag_constraint(self) -> torch.Tensor:
        """
        Differentiable DAG constraint (NOTEARS formulation).

        h(W) = trace(exp(W ∘ W)) - d, which is 0 iff W is a DAG.
        Returns h squared so the value is a non-negative loss.
        """
        squared = self.W * self.W  # element-wise square for non-negativity
        expm = torch.matrix_exp(squared)  # matrix exponential
        h = torch.trace(expm) - self.num_variables
        return h * h

    def causal_effects(self) -> torch.Tensor:
        """
        Compute total causal effects using the learned adjacency.

        For linear SEM: total effect = (I - W)^(-1)

        NOTE: ``torch.linalg.inv`` raises if (I - W) is singular.
        """
        W = self.W
        identity = torch.eye(self.num_variables, device=W.device)
        return torch.linalg.inv(identity - W)


# ─── Mediation Analysis ─────────────────────────────────────────────────────

class CausalMediationAnalyzer(nn.Module):
    """
    Analyzes causal mediation through the central dogma layers.

    For the path DNA → RNA → Protein → Phenotype, decomposes the total
    effect of a DNA feature on longevity into:
    - Direct effect (DNA → Phenotype, bypassing intermediates)
    - Indirect effects (DNA → RNA → Phenotype,
      DNA → RNA → Protein → Phenotype)

    This maps to the MuLGIT central dogma architecture.
    """

    def __init__(
        self,
        dna_dim: int,
        rna_dim: int,
        protein_dim: int,
    ):
        """
        Args:
            dna_dim: number of DNA features
            rna_dim: number of RNA features
            protein_dim: number of protein features
        """
        super().__init__()
        # Path-specific coefficients (all bias-free linear maps)
        self.dna_to_phenotype = nn.Linear(dna_dim, 1, bias=False)  # direct
        self.dna_to_rna = nn.Linear(dna_dim, rna_dim, bias=False)  # path 1
        self.rna_to_phenotype = nn.Linear(rna_dim, 1, bias=False)  # path 2
        self.rna_to_protein = nn.Linear(rna_dim, protein_dim, bias=False)  # path 3
        self.protein_to_phenotype = nn.Linear(protein_dim, 1, bias=False)  # path 4

    def decompose_effect(
        self,
        dna_features: torch.Tensor,
        rna_features: Optional[torch.Tensor] = None,
        protein_features: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Decompose total effect into direct and pathway-mediated effects.

        The mediators are *predicted* from ``dna_features`` through the
        learned linear maps; ``rna_features`` and ``protein_features`` are
        accepted but not used by this implementation.

        Returns dict with:
            total_effect: sum of the three effects below
            direct_effect: DNA → Phenotype (bypassing RNA/protein)
            dna_to_rna_effect: DNA → RNA → Phenotype
            dna_to_rna_to_protein_effect: DNA → RNA → Protein → Phenotype
        """
        direct = self.dna_to_phenotype(dna_features)

        rna_pred = self.dna_to_rna(dna_features)
        rna_effect = self.rna_to_phenotype(rna_pred)

        protein_pred = self.rna_to_protein(rna_pred)
        protein_effect = self.protein_to_phenotype(protein_pred)

        total = direct + rna_effect + protein_effect

        return {
            "total_effect": total,
            "direct_effect": direct,
            "dna_to_rna_effect": rna_effect,
            "dna_to_rna_to_protein_effect": protein_effect,
        }


# ─── Causal Attribution ─────────────────────────────────────────────────────

def compute_feature_attribution(
    model: nn.Module,
    input_modalities: Dict[str, torch.Tensor],
    target: int = 0,
    n_steps: int = 20,
) -> Dict[str, torch.Tensor]:
    """
    Integrated Gradients-style causal attribution.

    Computes the contribution of each feature to the predicted risk score
    by integrating gradients along the path from baseline (zero) to input.
    One modality is interpolated at a time while the others are held at
    their actual values.

    Args:
        model: called as ``model(**modalities)``; its output must support
            ``output["risk"]`` and that value must be a tensor.
        input_modalities: mapping from modality name to input tensor.
        target: accepted for interface compatibility; currently unused.
        n_steps: number of interpolation points in [0, 1] (must be >= 1).

    Returns:
        Mapping from modality name to an attribution tensor shaped like the
        corresponding input.

    Raises:
        ValueError: if ``n_steps`` is smaller than 1 (the average of zero
            gradient samples would be NaN).
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    attributions = {}

    for name, x in input_modalities.items():
        baseline = torch.zeros_like(x)
        integrated_grad = torch.zeros_like(x)

        for alpha in torch.linspace(0, 1, n_steps):
            # Detach so the interpolated point is always a fresh leaf and the
            # gradient graph never reaches back into whatever produced x.
            interpolated = (baseline + alpha * (x - baseline)).detach()
            interpolated.requires_grad_(True)

            # Construct full input dict with only this modality interpolated
            full_input = dict(input_modalities)
            full_input[name] = interpolated

            # Forward pass
            output = model(**full_input)
            risk = output["risk"]

            # Gradient of risk w.r.t. interpolated input
            grad = torch.autograd.grad(risk.sum(), interpolated)[0]
            integrated_grad += grad.detach()

        # Average and multiply by (input - baseline)
        attributions[name] = (x - baseline) * (integrated_grad / n_steps)

    return attributions


def identify_causal_features(
    attributions: Dict[str, torch.Tensor],
    top_k: int = 100,
) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Identify top causal features from attribution scores.

    For each modality, takes the mean absolute attribution over dim 0 and
    keeps the ``min(top_k, n)`` largest entries, where ``n`` is the size of
    the first remaining dimension.
    Assumes each attribution tensor is (batch, features) — TODO confirm for
    higher-rank inputs, where ``topk`` would act on the last dimension.

    Returns dict mapping modality name to (top_indices, top_scores).
    """
    results = {}
    for name, attr in attributions.items():
        # Average attribution magnitude across batch
        mean_attr = attr.abs().mean(dim=0)
        # Get top-k features
        k = min(top_k, mean_attr.shape[0])
        top_scores, top_indices = torch.topk(mean_attr, k=k)
        results[name] = (top_indices, top_scores)
    return results