MuLGIT / mulgit /causal.py
vedatonuryilmaz's picture
Upload mulgit/causal.py with huggingface_hub
28ff549 verified
"""
Causal Discovery Module
Identifies causal genetic factors and molecular interactions underlying
exceptional longevity. Combines structural causal models with deep
learning-based causal inference.
Methods implemented:
1. Causal Feature Selection via Information Bottleneck (Seq2Exp-inspired)
- Learn binary masks that identify causal features from each omics layer
- Beta distribution prior for sparsity
2. Causal Structure Learning via NOTEARS-inspired DAG constraint
- Learn causal graph between molecular features
- Differentiable acyclicity constraint
3. Causal Mediation Analysis
- Identify mediated effects through the central dogma layers
- Decompose total effect into direct and indirect (pathway-mediated)
References:
- Seq2Exp (arxiv:2502.13991): Causal regulatory element discovery
- Avici: Amortized causal structure learning in genomics
- NOTEARS: Non-combinatorial Optimization via Trace Exponential
Augmented lagRangian Structure learning
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Tuple, Dict
# ─── Causal Feature Selection ────────────────────────────────────────────────
class CausalFeatureMask(nn.Module):
    """
    Learns a per-feature selection mask identifying causal features.

    Each feature owns a learnable logit. During training a relaxed
    (differentiable) Bernoulli sample is drawn per feature using the binary
    Concrete reparameterization; at inference the mask is a hard 0/1
    threshold on the selection probability.

    Note: ``prior_alpha`` and ``prior_beta`` are stored on the module but are
    not used by ``sparsity_loss``, which penalizes the mean selection
    probability instead of evaluating a Beta prior.
    """

    def __init__(
        self,
        num_features: int,
        prior_alpha: float = 0.1,
        prior_beta: float = 0.9,
        temperature: float = 0.5,
    ):
        """
        Args:
            num_features: number of input features (length of the mask)
            prior_alpha, prior_beta: Beta distribution parameters; stored as
                attributes only
            temperature: relaxation temperature (lower = more discrete)
        """
        super().__init__()
        # Learnable logits for each feature's selection probability.
        # Zero logits correspond to an initial selection probability of 0.5.
        self.logit_p = nn.Parameter(torch.zeros(num_features))
        self.prior_alpha = prior_alpha
        self.prior_beta = prior_beta
        self.temperature = temperature

    def forward(self, training: bool = True) -> torch.Tensor:
        """
        Return a soft (training) or hard (inference) mask of shape
        ``(num_features,)``.

        Args:
            training: if True, draw a relaxed stochastic mask in (0, 1);
                otherwise return a deterministic 0/1 mask.
        """
        if training:
            # Binary Concrete relaxation: the noise for a two-class relaxation
            # is Logistic(0, 1), i.e. the difference of two Gumbel samples.
            # A single Gumbel sample has non-zero mean and would bias every
            # feature toward being selected regardless of its logit.
            # rand_like samples from [0, 1); only the lower end needs
            # clamping to keep log(u) finite.
            u = torch.rand_like(self.logit_p).clamp_min(1e-8)
            logistic_noise = torch.log(u) - torch.log1p(-u)
            relaxed_logits = (self.logit_p + logistic_noise) / self.temperature
            return torch.sigmoid(relaxed_logits)

        # Hard threshold at 0.5 (strictly greater, so a zero logit is "off").
        return (torch.sigmoid(self.logit_p) > 0.5).float()

    def sparsity_loss(self) -> torch.Tensor:
        """
        Sparsity regularizer: mean selection probability over all features.

        Since sigmoid outputs are non-negative this is an L1 penalty
        normalized by the number of features.
        """
        selection_prob = torch.sigmoid(self.logit_p)
        # L1 penalty: encourages p -> 0 for non-causal features
        return selection_prob.mean()
class CausalOmicsSelector(nn.Module):
    """
    Per-modality causal feature selection.

    Holds one ``CausalFeatureMask`` per named omics layer and multiplies each
    modality's features by its mask.
    """

    def __init__(
        self,
        modality_dims: Dict[str, int],
        prior_alpha: float = 0.1,
        prior_beta: float = 0.9,
    ):
        """
        Args:
            modality_dims: mapping from modality name to its feature count
            prior_alpha, prior_beta: forwarded to every ``CausalFeatureMask``
        """
        super().__init__()
        per_modality = {}
        for modality, width in modality_dims.items():
            per_modality[modality] = CausalFeatureMask(
                width, prior_alpha, prior_beta
            )
        self.masks = nn.ModuleDict(per_modality)
        self.modality_dims = modality_dims

    def forward(
        self, modalities: Dict[str, torch.Tensor], training: bool = True
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """
        Apply the matching mask to every modality in ``modalities``.

        Args:
            modalities: mapping from modality name to its feature tensor;
                every key must have a mask registered in ``self.masks``
            training: passed through to each mask (soft vs. hard mask)

        Returns:
            selected: masked features per modality
            masks: the mask used for each modality
        """
        selected: Dict[str, torch.Tensor] = {}
        masks: Dict[str, torch.Tensor] = {}
        for modality, features in modalities.items():
            feature_mask = self.masks[modality](training=training)
            masks[modality] = feature_mask
            # A leading axis of size 1 lets the mask broadcast over the batch.
            selected[modality] = features * feature_mask.unsqueeze(0)
        return selected, masks

    def total_sparsity_loss(self) -> torch.Tensor:
        """Sum of sparsity losses across all modalities."""
        total = 0
        for feature_mask in self.masks.values():
            total = total + feature_mask.sparsity_loss()
        return total
# ─── Causal Graph Structure Learning ────────────────────────────────────────
class CausalGraphLearner(nn.Module):
    """
    Learnable weighted adjacency matrix with a NOTEARS-inspired
    differentiable acyclicity penalty.

    The single parameter ``W`` is a ``(num_variables, num_variables)`` matrix
    of causal strengths between variables.
    """

    def __init__(
        self,
        num_variables: int,
        hidden_dim: int = 64,
        lambda_dag: float = 1.0,
    ):
        """
        Args:
            num_variables: number of variables (nodes) in the causal graph
            hidden_dim: dimension of each variable's representation; accepted
                but not used inside this class
            lambda_dag: weight for the DAG constraint; stored as an attribute
                and not applied by ``dag_constraint`` itself
        """
        super().__init__()
        self.num_variables = num_variables
        self.lambda_dag = lambda_dag
        # Learnable adjacency matrix (causal strengths), Xavier-initialized.
        adjacency = torch.zeros(num_variables, num_variables)
        self.W = nn.Parameter(adjacency)
        nn.init.xavier_normal_(self.W)

    def forward(self) -> torch.Tensor:
        """Return the learned weighted adjacency matrix (the parameter itself)."""
        return self.W

    def dag_constraint(self) -> torch.Tensor:
        """
        Differentiable DAG penalty (NOTEARS formulation).

        With ``h = trace(exp(W * W)) - d`` (element-wise product, matrix
        exponential), ``h`` is zero iff ``W`` encodes a DAG. The returned
        value is ``h`` squared.
        """
        # Element-wise square makes every edge weight non-negative.
        nonneg_weights = torch.mul(self.W, self.W)
        expm = torch.matrix_exp(nonneg_weights)
        violation = torch.trace(expm) - self.num_variables
        return violation * violation

    def causal_effects(self) -> torch.Tensor:
        """
        Compute total causal effects from the learned adjacency.

        For a linear SEM the total-effect matrix is ``(I - W)^(-1)``.
        """
        identity = torch.eye(self.num_variables, device=self.W.device)
        return torch.linalg.inv(identity - self.W)
# ─── Mediation Analysis ─────────────────────────────────────────────────────
class CausalMediationAnalyzer(nn.Module):
    """
    Analyzes causal mediation through the central dogma layers.

    For the path DNA -> RNA -> Protein -> Phenotype, decomposes the total
    effect of DNA features on the phenotype into:
    - Direct effect (DNA -> Phenotype, bypassing intermediates)
    - Indirect effects (DNA -> RNA -> Phenotype and
      DNA -> RNA -> Protein -> Phenotype)

    Every path segment is a bias-free linear map.
    """

    def __init__(
        self,
        dna_dim: int,
        rna_dim: int,
        protein_dim: int,
    ):
        """
        Args:
            dna_dim: number of DNA-level features
            rna_dim: number of RNA-level features
            protein_dim: number of protein-level features
        """
        super().__init__()
        # Path-specific coefficients (no bias terms).
        self.dna_to_phenotype = nn.Linear(dna_dim, 1, bias=False)  # direct
        self.dna_to_rna = nn.Linear(dna_dim, rna_dim, bias=False)
        self.rna_to_phenotype = nn.Linear(rna_dim, 1, bias=False)
        self.rna_to_protein = nn.Linear(rna_dim, protein_dim, bias=False)
        self.protein_to_phenotype = nn.Linear(protein_dim, 1, bias=False)

    def decompose_effect(
        self,
        dna_features: torch.Tensor,
        rna_features: Optional[torch.Tensor] = None,
        protein_features: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Decompose the total effect into direct and pathway-mediated effects.

        Mediator values are always predicted from ``dna_features``;
        ``rna_features`` and ``protein_features`` are accepted but not used.

        Returns dict with:
            total_effect: sum of the three effects below
            direct_effect: DNA -> Phenotype (bypassing RNA/protein)
            dna_to_rna_effect: DNA -> RNA -> Phenotype
            dna_to_rna_to_protein_effect: DNA -> RNA -> Protein -> Phenotype
        """
        rna_hat = self.dna_to_rna(dna_features)
        protein_hat = self.rna_to_protein(rna_hat)

        effects = {
            "direct_effect": self.dna_to_phenotype(dna_features),
            "dna_to_rna_effect": self.rna_to_phenotype(rna_hat),
            "dna_to_rna_to_protein_effect": self.protein_to_phenotype(protein_hat),
        }
        combined = (
            effects["direct_effect"]
            + effects["dna_to_rna_effect"]
            + effects["dna_to_rna_to_protein_effect"]
        )
        return {"total_effect": combined, **effects}
# ─── Causal Attribution ─────────────────────────────────────────────────────
def compute_feature_attribution(
    model: nn.Module,
    input_modalities: Dict[str, torch.Tensor],
    target: int = 0,
    n_steps: int = 20,
) -> Dict[str, torch.Tensor]:
    """
    Integrated Gradients-style causal attribution.

    For each modality in turn, interpolates that modality from an all-zero
    baseline to its actual value (other modalities held at their actual
    values), averages the gradient of ``output["risk"].sum()`` with respect
    to the interpolated input over ``n_steps`` points, and multiplies by
    ``input - baseline``.

    Args:
        model: callable invoked as ``model(**modalities)``; must return a
            mapping containing a ``"risk"`` tensor
        input_modalities: mapping from modality name to input tensor
        target: accepted for interface compatibility; not used
        n_steps: number of interpolation points (inclusive of both ends)

    Returns:
        Mapping from modality name to an attribution tensor with the same
        shape as the corresponding input. A modality the risk does not
        depend on receives all-zero attributions.

    Raises:
        ValueError: if ``n_steps`` is less than 1.
    """
    if n_steps < 1:
        # Zero steps would divide an empty sum by zero and return NaNs.
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    # Interpolation coefficients are shared by every modality.
    alphas = torch.linspace(0, 1, n_steps)

    attributions = {}
    for name, x in input_modalities.items():
        baseline = torch.zeros_like(x)
        delta = x - baseline
        integrated_grad = torch.zeros_like(x)

        for alpha in alphas:
            # Detach so the interpolated point is a fresh leaf: gradients are
            # taken w.r.t. it alone and no graph back to the caller's tensor
            # is retained.
            interpolated = (baseline + alpha * delta).detach().requires_grad_(True)

            # Full input dict with only the current modality replaced.
            output = model(**{**input_modalities, name: interpolated})
            risk_total = output["risk"].sum()

            if not risk_total.requires_grad:
                # Output is independent of every differentiable input,
                # so this step contributes a zero gradient.
                continue

            # allow_unused: the risk may depend on model parameters but not
            # on this particular modality; treat that as a zero gradient.
            (grad,) = torch.autograd.grad(
                risk_total, interpolated, allow_unused=True
            )
            if grad is not None:
                integrated_grad += grad.detach()

        # Average gradient times (input - baseline).
        attributions[name] = delta * (integrated_grad / n_steps)
    return attributions
def identify_causal_features(
    attributions: Dict[str, torch.Tensor],
    top_k: int = 100,
) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Identify top causal features from attribution scores.

    For each modality, takes the absolute attribution, averages it over the
    first (batch) axis, and keeps the ``top_k`` largest entries along the
    last axis.

    Args:
        attributions: mapping from modality name to attribution tensor;
            assumed to be at least 2-D with batch first, e.g.
            ``(batch, features)`` -- TODO confirm against callers
        top_k: maximum number of features to keep per modality; capped at
            the size of the ranked axis

    Returns:
        Dict mapping modality name to ``(top_indices, top_scores)``, with
        scores sorted in descending order.
    """
    results = {}
    for name, attr in attributions.items():
        # Average absolute attribution across the batch axis.
        mean_attr = attr.abs().mean(dim=0)
        # torch.topk ranks along the last axis, so cap k by that axis; for
        # (batch, features) input this is the number of features.
        k = min(top_k, mean_attr.shape[-1])
        top_scores, top_indices = torch.topk(mean_attr, k=k)
        results[name] = (top_indices, top_scores)
    return results