vedatonuryilmaz commited on
Commit
28ff549
Β·
verified Β·
1 Parent(s): 31a5e86

Upload mulgit/causal.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mulgit/causal.py +307 -0
mulgit/causal.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Causal Discovery Module
3
+
4
+ Identifies causal genetic factors and molecular interactions underlying
5
+ exceptional longevity. Combines structural causal models with deep
6
+ learning-based causal inference.
7
+
8
+ Methods implemented:
9
+ 1. Causal Feature Selection via Information Bottleneck (Seq2Exp-inspired)
10
+ - Learn binary masks that identify causal features from each omics layer
11
+ - Beta distribution prior for sparsity
12
+
13
+ 2. Causal Structure Learning via NOTEARS-inspired DAG constraint
14
+ - Learn causal graph between molecular features
15
+ - Differentiable acyclicity constraint
16
+
17
+ 3. Causal Mediation Analysis
18
+ - Identify mediated effects through the central dogma layers
19
+ - Decompose total effect into direct and indirect (pathway-mediated)
20
+
21
+ References:
22
+ - Seq2Exp (arxiv:2502.13991): Causal regulatory element discovery
23
+ - Avici: Amortized causal structure learning in genomics
24
+ - NOTEARS: Non-combinatorial Optimization via Trace Exponential
25
+ Augmented lagRangian Structure learning
26
+ """
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+ from typing import Optional, List, Tuple, Dict
32
+
33
+
34
+ # ─── Causal Feature Selection ────────────────────────────────────────────────
35
+
36
+ class CausalFeatureMask(nn.Module):
37
+ """
38
+ Learns a binary mask over input features identifying causal features.
39
+
40
+ Inspired by Seq2Exp's information bottleneck: uses Beta distribution
41
+ prior to encourage sparsity. The mask is learned via the
42
+ concrete/Gumbel-softmax reparameterization for differentiability.
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ num_features: int,
48
+ prior_alpha: float = 0.1,
49
+ prior_beta: float = 0.9,
50
+ temperature: float = 0.5,
51
+ ):
52
+ """
53
+ Args:
54
+ num_features: number of input features
55
+ prior_alpha, prior_beta: Beta distribution parameters (skewed
56
+ toward 0 to encourage sparse selection)
57
+ temperature: Gumbel-softmax temperature (lower = more discrete)
58
+ """
59
+ super().__init__()
60
+ # Learnable logits for each feature's selection probability
61
+ self.logit_p = nn.Parameter(torch.zeros(num_features))
62
+ self.prior_alpha = prior_alpha
63
+ self.prior_beta = prior_beta
64
+ self.temperature = temperature
65
+
66
+ def forward(self, training: bool = True) -> torch.Tensor:
67
+ """
68
+ Returns a soft (training) or hard (inference) binary mask.
69
+ """
70
+ if training:
71
+ # Concrete distribution (Gumbel-softmax)
72
+ u = torch.rand_like(self.logit_p)
73
+ gumbel = -torch.log(-torch.log(u + 1e-8) + 1e-8)
74
+ logits = (self.logit_p + gumbel) / self.temperature
75
+ mask = torch.sigmoid(logits)
76
+ else:
77
+ # Hard threshold at 0.5
78
+ mask = (torch.sigmoid(self.logit_p) > 0.5).float()
79
+ return mask
80
+
81
+ def sparsity_loss(self) -> torch.Tensor:
82
+ """
83
+ Sparsity regularization: penalize large selection probabilities.
84
+ Uses L1 norm of sigmoid(logit) to encourage zeros.
85
+ """
86
+ p = torch.sigmoid(self.logit_p)
87
+ # L1 penalty: encourages p β†’ 0 for non-causal features
88
+ return p.mean()
89
+
90
+
91
+ class CausalOmicsSelector(nn.Module):
92
+ """
93
+ Per-modality causal feature selection.
94
+ Selects which features from each omics layer are causal for the outcome.
95
+ """
96
+
97
+ def __init__(
98
+ self,
99
+ modality_dims: Dict[str, int],
100
+ prior_alpha: float = 0.1,
101
+ prior_beta: float = 0.9,
102
+ ):
103
+ super().__init__()
104
+ self.masks = nn.ModuleDict({
105
+ name: CausalFeatureMask(dim, prior_alpha, prior_beta)
106
+ for name, dim in modality_dims.items()
107
+ })
108
+ self.modality_dims = modality_dims
109
+
110
+ def forward(
111
+ self, modalities: Dict[str, torch.Tensor], training: bool = True
112
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
113
+ """
114
+ Apply causal masks to each modality.
115
+
116
+ Returns:
117
+ selected: masked features per modality
118
+ masks: learned masks per modality
119
+ """
120
+ selected = {}
121
+ masks = {}
122
+ for name, x in modalities.items():
123
+ mask = self.masks[name](training=training)
124
+ selected[name] = x * mask.unsqueeze(0) # broadcast over batch
125
+ masks[name] = mask
126
+ return selected, masks
127
+
128
+ def total_sparsity_loss(self) -> torch.Tensor:
129
+ """Sum of sparsity losses across all modalities."""
130
+ return sum(self.masks[name].sparsity_loss() for name in self.masks)
131
+
132
+
133
+ # ─── Causal Graph Structure Learning ────────────────────────────────────────
134
+
135
+ class CausalGraphLearner(nn.Module):
136
+ """
137
+ Learns a causal graph (DAG) between a set of latent variables using
138
+ a differentiable acyclicity constraint (NOTEARS-inspired).
139
+
140
+ Adapted for molecular features: the learned adjacency matrix represents
141
+ causal relationships between latent molecular representations.
142
+ """
143
+
144
+ def __init__(
145
+ self,
146
+ num_variables: int,
147
+ hidden_dim: int = 64,
148
+ lambda_dag: float = 1.0,
149
+ ):
150
+ """
151
+ Args:
152
+ num_variables: number of variables in the causal graph
153
+ hidden_dim: dimension of each variable's representation
154
+ lambda_dag: weight for the DAG constraint
155
+ """
156
+ super().__init__()
157
+ # Learnable adjacency matrix (causal strengths)
158
+ self.W = nn.Parameter(torch.zeros(num_variables, num_variables))
159
+ self.num_variables = num_variables
160
+ self.lambda_dag = lambda_dag
161
+ nn.init.xavier_normal_(self.W)
162
+
163
+ def forward(self) -> torch.Tensor:
164
+ """Returns the learned weighted adjacency matrix."""
165
+ return self.W
166
+
167
+ def dag_constraint(self) -> torch.Tensor:
168
+ """
169
+ Differentiable DAG constraint (NOTEARS formulation).
170
+ trace(exp(W * W)) - d = 0 iff W is a DAG.
171
+ """
172
+ W = self.W * self.W # element-wise square for non-negativity
173
+ M = torch.matrix_exp(W) # matrix exponential
174
+ h = torch.trace(M) - self.num_variables
175
+ return h * h # squared to ensure non-negative loss
176
+
177
+ def causal_effects(self) -> torch.Tensor:
178
+ """
179
+ Compute total causal effects using the learned adjacency.
180
+ For linear SEM: total effect = (I - W)^(-1)
181
+ """
182
+ W = self.W
183
+ I = torch.eye(self.num_variables, device=W.device)
184
+ total_effects = torch.linalg.inv(I - W)
185
+ return total_effects
186
+
187
+
188
+ # ─── Mediation Analysis ─────────────────────────────────────────────────────
189
+
190
+ class CausalMediationAnalyzer(nn.Module):
191
+ """
192
+ Analyzes causal mediation through the central dogma layers.
193
+
194
+ For the path DNA β†’ RNA β†’ Protein β†’ Phenotype, decomposes the total
195
+ effect of a DNA feature on longevity into:
196
+ - Direct effect (DNA β†’ Phenotype, bypassing intermediates)
197
+ - Indirect effects (DNA β†’ RNA β†’ Phenotype, DNA β†’ RNA β†’ Protein β†’ Phenotype)
198
+
199
+ This maps to the MuLGIT central dogma architecture.
200
+ """
201
+
202
+ def __init__(
203
+ self,
204
+ dna_dim: int,
205
+ rna_dim: int,
206
+ protein_dim: int,
207
+ ):
208
+ super().__init__()
209
+ # Path-specific coefficients
210
+ self.dna_to_phenotype = nn.Linear(dna_dim, 1, bias=False) # direct
211
+ self.dna_to_rna = nn.Linear(dna_dim, rna_dim, bias=False) # path 1
212
+ self.rna_to_phenotype = nn.Linear(rna_dim, 1, bias=False) # path 2
213
+ self.rna_to_protein = nn.Linear(rna_dim, protein_dim, bias=False) # path 3
214
+ self.protein_to_phenotype = nn.Linear(protein_dim, 1, bias=False) # path 4
215
+
216
+ def decompose_effect(
217
+ self,
218
+ dna_features: torch.Tensor,
219
+ rna_features: Optional[torch.Tensor] = None,
220
+ protein_features: Optional[torch.Tensor] = None,
221
+ ) -> Dict[str, torch.Tensor]:
222
+ """
223
+ Decompose total effect into direct and pathway-mediated effects.
224
+
225
+ Returns dict with:
226
+ total_effect: combined effect on phenotype
227
+ direct_effect: DNA β†’ Phenotype (bypassing RNA/protein)
228
+ dna_rna_effect: DNA β†’ RNA β†’ Phenotype
229
+ dna_rna_protein_effect: DNA β†’ RNA β†’ Protein β†’ Phenotype
230
+ """
231
+ direct = self.dna_to_phenotype(dna_features)
232
+
233
+ rna_pred = self.dna_to_rna(dna_features)
234
+ rna_effect = self.rna_to_phenotype(rna_pred)
235
+
236
+ protein_pred = self.rna_to_protein(rna_pred)
237
+ protein_effect = self.protein_to_phenotype(protein_pred)
238
+
239
+ total = direct + rna_effect + protein_effect
240
+
241
+ return {
242
+ "total_effect": total,
243
+ "direct_effect": direct,
244
+ "dna_to_rna_effect": rna_effect,
245
+ "dna_to_rna_to_protein_effect": protein_effect,
246
+ }
247
+
248
+
249
+ # ─── Causal Attribution ─────────────────────────────────────────────────────
250
+
251
+ def compute_feature_attribution(
252
+ model: nn.Module,
253
+ input_modalities: Dict[str, torch.Tensor],
254
+ target: int = 0,
255
+ n_steps: int = 20,
256
+ ) -> Dict[str, torch.Tensor]:
257
+ """
258
+ Integrated Gradients-style causal attribution.
259
+
260
+ Computes the contribution of each feature to the predicted risk score
261
+ by integrating gradients along the path from baseline (zero) to input.
262
+ """
263
+ attributions = {}
264
+
265
+ for name, x in input_modalities.items():
266
+ baseline = torch.zeros_like(x)
267
+ integrated_grad = torch.zeros_like(x)
268
+
269
+ for alpha in torch.linspace(0, 1, n_steps):
270
+ interpolated = baseline + alpha * (x - baseline)
271
+ interpolated.requires_grad_(True)
272
+
273
+ # Construct full input dict
274
+ full_input = {k: v for k, v in input_modalities.items()}
275
+ full_input[name] = interpolated
276
+
277
+ # Forward pass
278
+ output = model(**full_input)
279
+ risk = output["risk"]
280
+
281
+ # Gradient of risk w.r.t. interpolated input
282
+ grad = torch.autograd.grad(risk.sum(), interpolated)[0]
283
+ integrated_grad += grad.detach()
284
+
285
+ # Average and multiply by (input - baseline)
286
+ attributions[name] = (x - baseline) * (integrated_grad / n_steps)
287
+
288
+ return attributions
289
+
290
+
291
+ def identify_causal_features(
292
+ attributions: Dict[str, torch.Tensor],
293
+ top_k: int = 100,
294
+ ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
295
+ """
296
+ Identify top causal features from attribution scores.
297
+
298
+ Returns dict mapping modality name to (top_indices, top_scores).
299
+ """
300
+ results = {}
301
+ for name, attr in attributions.items():
302
+ # Average attribution across batch
303
+ mean_attr = attr.abs().mean(dim=0)
304
+ # Get top-k features
305
+ top_scores, top_indices = torch.topk(mean_attr, k=min(top_k, mean_attr.shape[0]))
306
+ results[name] = (top_indices, top_scores)
307
+ return results