| """ |
| Causal Discovery Module |
| |
| Identifies causal genetic factors and molecular interactions underlying |
| exceptional longevity. Combines structural causal models with deep |
| learning-based causal inference. |
| |
| Methods implemented: |
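1. Causal Feature Selection via Information Bottleneck (Seq2Exp-inspired)
   - Learn relaxed binary masks (sampled with stochastic noise during
     training) that identify causal features in each omics layer
   - Sparsity penalty on the mean selection probability (Beta-prior
     hyperparameters are accepted but not yet used by the loss)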
| |
| 2. Causal Structure Learning via NOTEARS-inspired DAG constraint |
| - Learn causal graph between molecular features |
| - Differentiable acyclicity constraint |
| |
| 3. Causal Mediation Analysis |
| - Identify mediated effects through the central dogma layers |
| - Decompose total effect into direct and indirect (pathway-mediated) |
| |
| References: |
| - Seq2Exp (arxiv:2502.13991): Causal regulatory element discovery |
| - Avici: Amortized causal structure learning in genomics |
- NOTEARS (Zheng et al., 2018, "DAGs with NO TEARS"): Non-combinatorial
  Optimization via Trace Exponential and Augmented lagRangian for
  Structure learning
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, List, Tuple, Dict |
|
|
|
|
| |
|
|
class CausalFeatureMask(nn.Module):
    """
    Learns a relaxed binary mask over input features identifying causal features.

    Inspired by Seq2Exp's information bottleneck. Each feature ``i`` has a
    learnable logit ``logit_p[i]``; its selection probability is
    ``sigmoid(logit_p[i])``. During training the mask is sampled with the
    binary-concrete (relaxed Bernoulli) reparameterization so gradients flow
    through the sample. At inference the mask is thresholded to hard 0/1.

    Note: ``prior_alpha`` / ``prior_beta`` are stored on the module but are
    not used by :meth:`sparsity_loss`, which penalizes the mean selection
    probability.
    """

    def __init__(
        self,
        num_features: int,
        prior_alpha: float = 0.1,
        prior_beta: float = 0.9,
        temperature: float = 0.5,
    ):
        """
        Args:
            num_features: number of input features (length of the mask).
            prior_alpha, prior_beta: Beta distribution parameters (skewed
                toward 0 to encourage sparse selection). Stored only; not
                used by the current loss.
            temperature: relaxation temperature (lower = closer to 0/1).
        """
        super().__init__()
        # Zero logits -> every feature starts with selection probability 0.5.
        self.logit_p = nn.Parameter(torch.zeros(num_features))
        self.prior_alpha = prior_alpha
        self.prior_beta = prior_beta
        self.temperature = temperature

    def forward(self, training: bool = True) -> torch.Tensor:
        """
        Return a mask of shape ``(num_features,)``.

        Args:
            training: if True, draw a stochastic relaxed mask with values in
                (0, 1); otherwise return a hard 0/1 mask where
                ``sigmoid(logit_p) > 0.5``. This argument is independent of
                the module's own ``self.training`` flag.
        """
        if training:
            # Binary concrete: add Logistic(0, 1) noise, log(u) - log(1 - u),
            # so that P(sample > 0.5) == sigmoid(logit_p) and the training
            # distribution agrees with the inference threshold. A single
            # Gumbel draw would be biased (its mean is ~0.577).
            u = torch.rand_like(self.logit_p)
            eps = torch.finfo(u.dtype).eps
            u = u.clamp(eps, 1.0 - eps)  # keep both logs finite
            logistic_noise = torch.log(u) - torch.log1p(-u)
            mask = torch.sigmoid((self.logit_p + logistic_noise) / self.temperature)
        else:
            mask = (torch.sigmoid(self.logit_p) > 0.5).float()
        return mask

    def sparsity_loss(self) -> torch.Tensor:
        """
        Sparsity regularizer: mean selection probability.

        Equal to the L1 norm of ``sigmoid(logit_p)`` divided by the number of
        features. Minimizing it pushes selection probabilities toward 0.
        """
        p = torch.sigmoid(self.logit_p)
        return p.mean()
|
|
|
|
class CausalOmicsSelector(nn.Module):
    """
    Per-modality causal feature selection.

    Holds one ``CausalFeatureMask`` per omics layer (keyed by modality name)
    and applies it to that layer's features. The masks select which features
    are treated as causal for the outcome.
    """

    def __init__(
        self,
        modality_dims: Dict[str, int],
        prior_alpha: float = 0.1,
        prior_beta: float = 0.9,
        temperature: float = 0.5,
    ):
        """
        Args:
            modality_dims: modality name -> number of features in that layer.
            prior_alpha, prior_beta: forwarded to every ``CausalFeatureMask``.
            temperature: relaxation temperature forwarded to every mask.
                Its default matches the mask's own default.
        """
        super().__init__()
        self.masks = nn.ModuleDict({
            name: CausalFeatureMask(dim, prior_alpha, prior_beta, temperature)
            for name, dim in modality_dims.items()
        })
        # Copy so later mutation of the caller's dict cannot desync from masks.
        self.modality_dims = dict(modality_dims)

    def forward(
        self, modalities: Dict[str, torch.Tensor], training: bool = True
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """
        Apply causal masks to each modality.

        Every key in ``modalities`` must have a mask; a missing one raises
        ``KeyError`` from the ``ModuleDict`` lookup. Each mask is broadcast
        along a new leading axis. Presumably inputs are
        ``(batch, num_features)``; confirm against callers.

        Returns:
            selected: masked features per modality
            masks: the mask used for each modality
        """
        selected = {}
        masks = {}
        for name, x in modalities.items():
            mask = self.masks[name](training=training)
            selected[name] = x * mask.unsqueeze(0)
            masks[name] = mask
        return selected, masks

    def total_sparsity_loss(self) -> torch.Tensor:
        """
        Sum of sparsity losses across all modalities.

        Always returns a tensor. With no modalities the result is a 0-d zero
        tensor, so callers can add it to a loss and call ``backward()``.
        """
        if not self.masks:
            return torch.zeros(())
        return sum(self.masks[name].sparsity_loss() for name in self.masks)
|
|
|
|
| |
|
|
class CausalGraphLearner(nn.Module):
    """
    Differentiable (NOTEARS-style) causal graph learner.

    Owns a dense, learnable weighted adjacency matrix ``W`` of shape
    ``(num_variables, num_variables)`` over latent molecular variables. It
    exposes the trace-exponential acyclicity penalty that pushes ``W`` toward
    a DAG, plus linear-SEM total effects derived from ``W``.
    """

    def __init__(
        self,
        num_variables: int,
        hidden_dim: int = 64,
        lambda_dag: float = 1.0,
    ):
        """
        Args:
            num_variables: number of nodes in the graph
            hidden_dim: currently unused by this class
            lambda_dag: weight for the DAG constraint. It is stored on the
                module but not applied inside ``dag_constraint``.
                NOTE(review): presumably the caller scales the penalty by it.
        """
        super().__init__()
        size = num_variables
        adjacency = nn.Parameter(torch.zeros(size, size))
        self.W = adjacency
        self.num_variables = size
        self.lambda_dag = lambda_dag
        # Random Xavier-normal start. The diagonal is not masked, so self-loops
        # are discouraged only through the acyclicity penalty.
        nn.init.xavier_normal_(self.W)

    def forward(self) -> torch.Tensor:
        """Return the learned weighted adjacency parameter itself."""
        return self.W

    def dag_constraint(self) -> torch.Tensor:
        """
        Squared NOTEARS acyclicity violation.

        h(W) = trace(exp(W ∘ W)) - d, which is 0 iff W is a DAG. Returns h**2.
        """
        expm = torch.matrix_exp(self.W * self.W)  # elementwise square -> non-negative
        violation = expm.trace() - self.num_variables
        return violation * violation

    def causal_effects(self) -> torch.Tensor:
        """
        Linear-SEM total effects, (I - W)^(-1).

        For a DAG (nilpotent W) this equals I + W + W^2 + ...: the summed path
        effects plus the identity on the diagonal. ``torch.linalg.inv``
        raises if ``I - W`` is singular.
        """
        weights = self.W
        identity = torch.eye(self.num_variables, device=weights.device)
        return torch.linalg.inv(identity - weights)
|
|
|
|
| |
|
|
class CausalMediationAnalyzer(nn.Module):
    """
    Linear causal mediation model over the central-dogma layers.

    For the path DNA → RNA → Protein → Phenotype, the effect of the DNA
    features on the phenotype score is split into three additive parts:
      - direct:          DNA → Phenotype (bypassing intermediates)
      - via RNA:         DNA → RNA → Phenotype
      - via RNA+protein: DNA → RNA → Protein → Phenotype

    Every edge is a bias-free ``nn.Linear`` layer, so each part is linear in
    the DNA input. This maps to the MuLGIT central dogma architecture.
    """

    def __init__(
        self,
        dna_dim: int,
        rna_dim: int,
        protein_dim: int,
    ):
        super().__init__()
        # Submodules are created in a fixed order. Default weight init draws
        # from the global RNG, so the order matters for seeded reproducibility.

        # DNA -> phenotype (direct edge)
        self.dna_to_phenotype = nn.Linear(dna_dim, 1, bias=False)

        # DNA -> RNA -> phenotype
        self.dna_to_rna = nn.Linear(dna_dim, rna_dim, bias=False)
        self.rna_to_phenotype = nn.Linear(rna_dim, 1, bias=False)

        # RNA -> protein -> phenotype
        self.rna_to_protein = nn.Linear(rna_dim, protein_dim, bias=False)
        self.protein_to_phenotype = nn.Linear(protein_dim, 1, bias=False)

    def decompose_effect(
        self,
        dna_features: torch.Tensor,
        rna_features: Optional[torch.Tensor] = None,
        protein_features: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Split the DNA effect on the phenotype into direct and mediated parts.

        RNA and protein values are *predicted* from the DNA input.
        ``rna_features`` and ``protein_features`` are accepted but not used.

        Args:
            dna_features: tensor whose last dimension is ``dna_dim``.

        Returns:
            dict with keys (each tensor's last dimension has size 1):
                "total_effect": direct + via-RNA + via-protein
                "direct_effect": DNA → Phenotype
                "dna_to_rna_effect": DNA → RNA → Phenotype
                "dna_to_rna_to_protein_effect": DNA → RNA → Protein → Phenotype
        """
        direct_part = self.dna_to_phenotype(dna_features)

        rna_hat = self.dna_to_rna(dna_features)
        via_rna = self.rna_to_phenotype(rna_hat)

        protein_hat = self.rna_to_protein(rna_hat)
        via_protein = self.protein_to_phenotype(protein_hat)

        effects = {}
        effects["total_effect"] = direct_part + via_rna + via_protein
        effects["direct_effect"] = direct_part
        effects["dna_to_rna_effect"] = via_rna
        effects["dna_to_rna_to_protein_effect"] = via_protein
        return effects
|
|
|
|
| |
|
|
def compute_feature_attribution(
    model: nn.Module,
    input_modalities: Dict[str, torch.Tensor],
    target: int = 0,
    n_steps: int = 20,
) -> Dict[str, torch.Tensor]:
    """
    Integrated Gradients-style attribution per input modality.

    Each modality is handled independently:
      - Its input is interpolated from an all-zero baseline to its observed
        value, while every other modality is held at its observed value.
      - Gradients of ``model(**inputs)["risk"].sum()`` with respect to the
        interpolated input are averaged over the ``n_steps`` points of
        ``torch.linspace(0, 1, n_steps)`` (both endpoints included).
      - The averaged gradient is multiplied by ``(x - baseline)``.

    Args:
        model: module called as ``model(**modalities)``. It must return a
            mapping that contains a ``"risk"`` tensor.
        input_modalities: modality name -> input tensor.
        target: currently unused; the sum over all entries of ``"risk"`` is
            attributed. NOTE(review): presumably meant to select an output
            index. Confirm before relying on it.
        n_steps: number of interpolation points; must be >= 1.

    Returns:
        Dict mapping each modality name to an attribution tensor with the
        same shape as that modality's input.

    Raises:
        ValueError: if ``n_steps < 1`` (the average would divide by zero).
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    alphas = torch.linspace(0, 1, n_steps)
    attributions = {}

    for name, x in input_modalities.items():
        baseline = torch.zeros_like(x)
        integrated_grad = torch.zeros_like(x)

        for alpha in alphas:
            # Detach so each path point is a fresh leaf. Gradients are taken
            # w.r.t. it alone and never traverse x's own autograd history.
            point = (baseline + alpha * (x - baseline)).detach().requires_grad_(True)

            full_input = dict(input_modalities)
            full_input[name] = point

            # Gradients are needed even if the caller is inside torch.no_grad().
            with torch.enable_grad():
                risk = model(**full_input)["risk"]
                grad = torch.autograd.grad(risk.sum(), point)[0]
            integrated_grad += grad.detach()

        attributions[name] = (x - baseline) * (integrated_grad / n_steps)

    return attributions
|
|
|
|
def identify_causal_features(
    attributions: Dict[str, torch.Tensor],
    top_k: int = 100,
) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Rank features by mean absolute attribution and keep the top ``top_k``.

    Args:
        attributions: modality name -> attribution tensor.
            - Two or more dims: absolute values are averaged over dim 0
              (e.g. the batch axis).
            - 1-D: taken directly as per-feature scores (absolute value).
        top_k: maximum number of features kept per modality. It is clipped to
            the size of the last dimension, which ``torch.topk`` ranks over.

    Returns:
        Dict mapping modality name to ``(top_indices, top_scores)``, sorted by
        descending score along the last dimension.
    """
    results = {}
    for name, attr in attributions.items():
        scores = attr.abs()
        # A 1-D tensor is already per-feature; reducing it would yield a
        # 0-d scalar that topk cannot rank.
        if scores.dim() > 1:
            scores = scores.mean(dim=0)

        # topk works on the last dim, so clamp k against that dim's size.
        k = min(top_k, scores.shape[-1])
        top_scores, top_indices = torch.topk(scores, k=k)
        results[name] = (top_indices, top_scores)
    return results
|
|