Oguzz07 committed on
Commit
78e2a75
·
verified ·
1 Parent(s): 9efaa0b

Add causal_selection/discovery/algorithms.py

Browse files
causal_selection/discovery/algorithms.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""
Algorithm adapters: run each causal discovery algorithm with timeout handling.

All algorithms take a pandas DataFrame (integer-encoded discrete data) and return
an adjacency matrix (np.ndarray) representing the learned graph.
"""
import logging
import signal
import time
import traceback
from functools import wraps

import numpy as np
import pandas as pd

# Module-level logger shared by the timeout helpers and algorithm runners.
logger = logging.getLogger(__name__)
15
+
16
+
17
class TimeoutError(Exception):
    """Raised when an algorithm exceeds its time budget.

    NOTE: this deliberately-named class shadows the builtin ``TimeoutError``
    inside this module. It derives directly from ``Exception`` (not from
    ``OSError`` like the builtin), so ``except OSError`` will not catch it.
    """
19
+
20
+
21
def timeout_handler(signum, frame):
    """Signal handler that aborts the running call by raising ``TimeoutError``.

    Args:
        signum: Signal number passed by the ``signal`` machinery (unused).
        frame: Interrupted stack frame (unused).

    Raises:
        TimeoutError: Always, with the message "Algorithm timed out".
    """
    raise TimeoutError("Algorithm timed out")
23
+
24
+
25
def run_with_timeout(func, timeout_sec, *args, **kwargs):
    """Run ``func(*args, **kwargs)`` under a SIGALRM-based timeout (Unix only).

    Args:
        func: Callable to execute.
        timeout_sec: Time budget in seconds (int or float). ``0`` arms no
            timer, i.e. the call runs without a timeout.
        *args, **kwargs: Forwarded to ``func``.

    Returns:
        Whatever ``func`` returns.

    Raises:
        TimeoutError: If the timer fires before ``func`` returns.
        Exception: Anything raised by ``func`` propagates unchanged.
    """
    old_handler = signal.signal(signal.SIGALRM, timeout_handler)
    try:
        # Arm the timer inside the try block so the handler is restored even
        # if arming itself fails. setitimer (unlike alarm) accepts floats.
        signal.setitimer(signal.ITIMER_REAL, timeout_sec)
        return func(*args, **kwargs)
    except TimeoutError:
        logger.warning("Timeout after %ss", timeout_sec)
        raise
    finally:
        # Cancel the timer BEFORE restoring the previous handler; otherwise a
        # late SIGALRM could be delivered to the old (possibly default,
        # process-terminating) handler.
        signal.setitimer(signal.ITIMER_REAL, 0)
        signal.signal(signal.SIGALRM, old_handler)
39
+
40
+
41
def safe_run(algo_name, func, timeout_sec, *args, **kwargs):
    """Run algorithm with timeout and exception handling.

    Args:
        algo_name: Label used in log messages.
        func: Callable implementing the algorithm.
        timeout_sec: Time budget forwarded to ``run_with_timeout``.
        *args, **kwargs: Forwarded to ``func``.

    Returns:
        (adjmat, runtime, status): adjacency matrix (``None`` on failure),
        elapsed time in seconds, and a status string: ``'success'``,
        ``'timeout'`` or ``'error: <message truncated to 100 chars>'``.
    """
    # perf_counter is monotonic, so the measured runtime is not distorted by
    # system clock adjustments during long runs.
    start = time.perf_counter()
    try:
        adjmat = run_with_timeout(func, timeout_sec, *args, **kwargs)
    except TimeoutError:
        runtime = time.perf_counter() - start
        logger.warning(f"{algo_name}: TIMEOUT after {runtime:.1f}s")
        return None, runtime, 'timeout'
    except Exception as e:
        # Deliberate catch-all boundary: one failing algorithm must not abort
        # a whole benchmark run; the failure is reported through the status.
        runtime = time.perf_counter() - start
        logger.error(f"{algo_name}: ERROR after {runtime:.1f}s: {e}")
        logger.debug(traceback.format_exc())
        # Fall back to the exception type when the message is empty so the
        # status never degenerates to a bare 'error: '.
        detail = str(e) or type(e).__name__
        return None, runtime, f'error: {detail[:100]}'
    return adjmat, time.perf_counter() - start, 'success'
61
+
62
+
63
+ # ============================================================
64
+ # CONSTRAINT-BASED ALGORITHMS (causal-learn)
65
+ # ============================================================
66
+
67
def run_pc_discrete(df, alpha=0.01, stable=True):
    """PC algorithm for discrete data using G-squared test.

    Args:
        df: DataFrame whose values are cast to ``int`` before the search.
        alpha: Significance level passed to causal-learn's ``pc``.
        stable: Passed through as ``pc``'s ``stable`` flag.

    Returns:
        Adjacency matrix produced by ``_causallearn_graph_to_adjmat``; rows and
        columns follow the column order of ``df``.
    """
    from causallearn.search.ConstraintBased.PC import pc
    from causallearn.utils.cit import gsq

    data = df.values.astype(int)
    cg = pc(data, alpha=alpha, indep_test=gsq, stable=stable,
            show_progress=False)

    # cg.G.graph is causal-learn's endpoint matrix (numpy array).
    return _causallearn_graph_to_adjmat(cg.G.graph)
79
+
80
+
81
def run_fci(df, alpha=0.05, depth=4):
    """FCI algorithm - outputs a PAG, converted to an adjacency matrix.

    Args:
        df: DataFrame whose values are cast to ``int`` before the search.
        alpha: Significance level passed to causal-learn's ``fci``.
        depth: Passed through as ``fci``'s ``depth`` argument.

    Returns:
        Adjacency matrix produced by ``_causallearn_pag_to_adjmat``; rows and
        columns follow the column order of ``df``.
    """
    from causallearn.search.ConstraintBased.FCI import fci
    from causallearn.utils.cit import gsq

    data = df.values.astype(int)
    # fci returns (graph, edge list); only the graph's endpoint matrix is used.
    g, _edges = fci(data, independence_test_method=gsq, alpha=alpha,
                    depth=depth, show_progress=False)

    return _causallearn_pag_to_adjmat(g.graph)
92
+
93
+
94
+ # ============================================================
95
+ # SCORE-BASED ALGORITHMS (causal-learn)
96
+ # ============================================================
97
+
98
def run_ges_causallearn(df, score_func='local_score_BDeu'):
    """GES algorithm from causal-learn.

    Args:
        df: DataFrame whose values are cast to ``int`` before the search.
        score_func: Name of the causal-learn local score to use.

    Returns:
        Adjacency matrix produced by ``_causallearn_graph_to_adjmat``; rows and
        columns follow the column order of ``df``.
    """
    from causallearn.search.ScoreBased.GES import ges

    data = df.values.astype(int)
    record = ges(data, score_func=score_func, maxP=None,
                 parameters=None)

    # The learned graph is stored under the 'G' key of the returned record.
    return _causallearn_graph_to_adjmat(record['G'].graph)
108
+
109
+
110
def run_boss(df, score_func='local_score_BDeu'):
    """BOSS (Best Order Score Search) algorithm.

    Args:
        df: DataFrame whose values are cast to ``int`` before the search.
        score_func: Name of the causal-learn local score to use.

    Returns:
        Adjacency matrix produced by ``_causallearn_graph_to_adjmat``; rows and
        columns follow the column order of ``df``.
    """
    from causallearn.search.PermutationBased.BOSS import boss

    data = df.values.astype(int)
    cg = boss(data, score_func=score_func,
              parameters=None)

    return _causallearn_graph_to_adjmat(cg.graph)
120
+
121
+
122
def run_grasp(df, score_func='local_score_BDeu', depth=3):
    """GRaSP (Greedy relaxation of Sparsest Permutation) algorithm.

    Args:
        df: DataFrame whose values are cast to ``int`` before the search.
        score_func: Name of the causal-learn local score to use.
        depth: Passed through as ``grasp``'s ``depth`` argument.

    Returns:
        Adjacency matrix produced by ``_causallearn_graph_to_adjmat``; rows and
        columns follow the column order of ``df``.
    """
    from causallearn.search.PermutationBased.GRaSP import grasp

    data = df.values.astype(int)
    cg = grasp(data, score_func=score_func, depth=depth,
               parameters=None)

    return _causallearn_graph_to_adjmat(cg.graph)
132
+
133
+
134
+ # ============================================================
135
+ # SCORE-BASED ALGORITHMS (pgmpy)
136
+ # ============================================================
137
+
138
def run_ges_pgmpy(df, scoring_method='bicscore'):
    """GES stand-in built on pgmpy's ``HillClimbSearch``.

    This does NOT run GES: it runs ``HillClimbSearch.estimate`` with
    ``max_indegree=4`` as a proxy.

    Args:
        df: Discrete data, passed to the pgmpy score and estimator.
        scoring_method: ``'bicscore'`` selects ``BicScore``; any other value
            selects ``BDeuScore``.

    Returns:
        Adjacency matrix from ``_pgmpy_model_to_adjmat``; rows and columns are
        indexed by ``sorted(df.columns)``.
    """
    from pgmpy.estimators import HillClimbSearch, BicScore, BDeuScore

    scoring = BicScore(df) if scoring_method == 'bicscore' else BDeuScore(df)
    hc = HillClimbSearch(df)
    best_model = hc.estimate(scoring_method=scoring, max_indegree=4,
                             max_iter=100000, epsilon=0.0001)

    # NOTE(review): nodes are indexed in sorted-name order here, whereas the
    # causal-learn runners use df's column order; the two agree only when
    # df.columns is already sorted - confirm against the caller.
    return _pgmpy_model_to_adjmat(best_model, sorted(df.columns))
150
+
151
+
152
def run_hc(df, scoring_method='bicscore', max_indegree=3, max_iter=100000,
           tabu_length=0):
    """Hill-Climbing algorithm (pgmpy ``HillClimbSearch``).

    Args:
        df: Discrete data, passed to the pgmpy score and estimator.
        scoring_method: ``'bicscore'`` selects ``BicScore``; any other value
            selects ``BDeuScore``.
        max_indegree: Maximum number of parents per node.
        max_iter: Maximum number of search iterations.
        tabu_length: Length of the tabu list. Defaults to 0 (plain hill
            climbing). Without passing this explicitly, pgmpy's own default
            (100 per the pgmpy docs - confirm for the installed version) is
            used, which made this function behave exactly like ``run_tabu``.

    Returns:
        Adjacency matrix from ``_pgmpy_model_to_adjmat``; rows and columns are
        indexed by ``sorted(df.columns)``.
    """
    from pgmpy.estimators import HillClimbSearch, BicScore, BDeuScore

    scoring = BicScore(df) if scoring_method == 'bicscore' else BDeuScore(df)
    hc = HillClimbSearch(df)
    best_model = hc.estimate(scoring_method=scoring, max_indegree=max_indegree,
                             max_iter=max_iter, epsilon=0.0001,
                             tabu_length=tabu_length)

    # NOTE(review): nodes are indexed in sorted-name order here, whereas the
    # causal-learn runners use df's column order; the two agree only when
    # df.columns is already sorted - confirm against the caller.
    return _pgmpy_model_to_adjmat(best_model, sorted(df.columns))
162
+
163
+
164
def run_tabu(df, scoring_method='bicscore', tabu_length=100, max_indegree=3, max_iter=100000):
    """Tabu Search algorithm (pgmpy ``HillClimbSearch`` with a tabu list).

    Args:
        df: Discrete data, passed to the pgmpy score and estimator.
        scoring_method: ``'bicscore'`` selects ``BicScore``; any other value
            selects ``BDeuScore``.
        tabu_length: Length of the tabu list passed to ``estimate``.
        max_indegree: Maximum number of parents per node.
        max_iter: Maximum number of search iterations.

    Returns:
        Adjacency matrix from ``_pgmpy_model_to_adjmat``; rows and columns are
        indexed by ``sorted(df.columns)``.
    """
    from pgmpy.estimators import HillClimbSearch, BicScore, BDeuScore

    scoring = BicScore(df) if scoring_method == 'bicscore' else BDeuScore(df)
    hc = HillClimbSearch(df)
    best_model = hc.estimate(scoring_method=scoring, max_indegree=max_indegree,
                             max_iter=max_iter, epsilon=0.0001,
                             tabu_length=tabu_length)

    # NOTE(review): nodes are indexed in sorted-name order here, whereas the
    # causal-learn runners use df's column order; the two agree only when
    # df.columns is already sorted - confirm against the caller.
    return _pgmpy_model_to_adjmat(best_model, sorted(df.columns))
175
+
176
+
177
def run_mmhc(df, scoring_method='bdeuscore', significance_level=0.01):
    """Max-Min Hill-Climbing (MMHC) hybrid algorithm.

    Args:
        df: Discrete data, passed to the pgmpy score and estimator.
        scoring_method: ``'bdeuscore'`` selects ``BDeuScore``; any other value
            selects ``BicScore``.
        significance_level: Passed through to ``MmhcEstimator.estimate``.

    Returns:
        Adjacency matrix from ``_pgmpy_model_to_adjmat``; rows and columns are
        indexed by ``sorted(df.columns)``.
    """
    from pgmpy.estimators import MmhcEstimator, BDeuScore, BicScore

    mmhc = MmhcEstimator(df)
    scoring = BDeuScore(df) if scoring_method == 'bdeuscore' else BicScore(df)
    best_model = mmhc.estimate(scoring_method=scoring,
                               significance_level=significance_level)

    # NOTE(review): nodes are indexed in sorted-name order here, whereas the
    # causal-learn runners use df's column order; the two agree only when
    # df.columns is already sorted - confirm against the caller.
    return _pgmpy_model_to_adjmat(best_model, sorted(df.columns))
187
+
188
+
189
def run_k2(df, max_parents=3):
    """Hill-climbing search scored with the K2 score.

    This is NOT the classic ordering-based K2 procedure: no node ordering is
    supplied and no random restarts are performed. It runs pgmpy's
    ``HillClimbSearch`` with ``K2Score`` and an in-degree cap.

    Args:
        df: Discrete data, passed to the pgmpy score and estimator.
        max_parents: Maximum number of parents per node (``max_indegree``).

    Returns:
        Adjacency matrix from ``_pgmpy_model_to_adjmat``; rows and columns are
        indexed by ``sorted(df.columns)``.
    """
    from pgmpy.estimators import HillClimbSearch, K2Score

    scoring = K2Score(df)
    hc = HillClimbSearch(df)
    best_model = hc.estimate(scoring_method=scoring, max_indegree=max_parents,
                             max_iter=100000, epsilon=0.0001)

    # NOTE(review): nodes are indexed in sorted-name order here, whereas the
    # causal-learn runners use df's column order; the two agree only when
    # df.columns is already sorted - confirm against the caller.
    return _pgmpy_model_to_adjmat(best_model, sorted(df.columns))
201
+
202
+
203
+ # ============================================================
204
+ # CONVERSION UTILITIES
205
+ # ============================================================
206
+
207
+ def _causallearn_graph_to_adjmat(graph_matrix):
208
+ """Convert causal-learn's graph representation to standard adjacency matrix.
209
+
210
+ causal-learn encoding:
211
+ graph[i,j] = -1 and graph[j,i] = 1 means i -> j
212
+ graph[i,j] = -1 and graph[j,i] = -1 means i -- j (undirected)
213
+ graph[i,j] = 1 and graph[j,i] = 1 means i <-> j (bidirected)
214
+
215
+ Our encoding:
216
+ adj[i,j] = 1 and adj[j,i] = 0 means i -> j
217
+ adj[i,j] = 1 and adj[j,i] = 1 means i -- j (undirected)
218
+ """
219
+ n = graph_matrix.shape[0]
220
+ adj = np.zeros((n, n), dtype=int)
221
+
222
+ for i in range(n):
223
+ for j in range(n):
224
+ if i == j:
225
+ continue
226
+ if graph_matrix[i, j] == -1 and graph_matrix[j, i] == 1:
227
+ # i -> j (tail at i, arrowhead at j)
228
+ adj[i, j] = 1
229
+ elif graph_matrix[i, j] == -1 and graph_matrix[j, i] == -1:
230
+ # i -- j (undirected)
231
+ adj[i, j] = 1
232
+ adj[j, i] = 1
233
+ elif graph_matrix[i, j] == 1 and graph_matrix[j, i] == 1:
234
+ # i <-> j (bidirected) - treat as undirected for CPDAG comparison
235
+ adj[i, j] = 1
236
+ adj[j, i] = 1
237
+
238
+ return adj
239
+
240
+
241
+ def _causallearn_pag_to_adjmat(pag_matrix):
242
+ """Convert PAG (from FCI) to adjacency matrix.
243
+
244
+ PAG encoding in causal-learn:
245
+ 1 = arrowhead (>), -1 = tail (-), 2 = circle (o)
246
+
247
+ We extract: definite directed edges and definite adjacencies.
248
+ """
249
+ n = pag_matrix.shape[0]
250
+ adj = np.zeros((n, n), dtype=int)
251
+
252
+ for i in range(n):
253
+ for j in range(n):
254
+ if i == j:
255
+ continue
256
+ # i -> j: tail at i (-1), arrowhead at j (1)
257
+ if pag_matrix[i, j] == -1 and pag_matrix[j, i] == 1:
258
+ adj[i, j] = 1
259
+ # i -- j or i o-o j or i o-> j: treat as undirected edge
260
+ elif pag_matrix[i, j] != 0 and pag_matrix[j, i] != 0:
261
+ adj[i, j] = 1
262
+ adj[j, i] = 1
263
+
264
+ return adj
265
+
266
+
267
+ def _pgmpy_model_to_adjmat(model, node_names):
268
+ """Convert pgmpy DAGModel to adjacency matrix."""
269
+ n = len(node_names)
270
+ node_idx = {name: i for i, name in enumerate(node_names)}
271
+ adj = np.zeros((n, n), dtype=int)
272
+
273
+ for parent, child in model.edges():
274
+ if parent in node_idx and child in node_idx:
275
+ adj[node_idx[parent], node_idx[child]] = 1
276
+
277
+ return adj
278
+
279
+
280
+ # ============================================================
281
+ # ALGORITHM REGISTRY
282
+ # ============================================================
283
+
284
# Registry of available algorithms. Each entry maps a display name to:
#   func        - runner taking the DataFrame as first positional argument
#   kwargs      - keyword arguments passed to the runner
#   library     - backing library label
#   output_type - kind of graph the runner's output represents
#   family      - algorithm family label
ALGORITHM_POOL = {
    'PC_discrete': {
        'func': run_pc_discrete,
        'kwargs': {'alpha': 0.01, 'stable': True},
        'library': 'causal_learn',
        'output_type': 'cpdag',
        'family': 'constraint',
    },
    'FCI': {
        'func': run_fci,
        'kwargs': {'alpha': 0.05, 'depth': 4},
        'library': 'causal_learn',
        'output_type': 'pag',
        'family': 'constraint',
    },
    'GES': {
        'func': run_ges_causallearn,
        'kwargs': {'score_func': 'local_score_BDeu'},
        'library': 'causal_learn',
        'output_type': 'cpdag',
        'family': 'score',
    },
    'BOSS': {
        'func': run_boss,
        'kwargs': {'score_func': 'local_score_BDeu'},
        'library': 'causal_learn',
        'output_type': 'cpdag',
        'family': 'permutation',
    },
    'GRaSP': {
        'func': run_grasp,
        'kwargs': {'score_func': 'local_score_BDeu', 'depth': 3},
        'library': 'causal_learn',
        'output_type': 'cpdag',
        'family': 'permutation',
    },
    'HC': {
        'func': run_hc,
        'kwargs': {'scoring_method': 'bicscore', 'max_indegree': 3, 'max_iter': 100000},
        'library': 'pgmpy',
        'output_type': 'dag',
        'family': 'score',
    },
    'Tabu': {
        'func': run_tabu,
        'kwargs': {'scoring_method': 'bicscore', 'tabu_length': 100, 'max_indegree': 3, 'max_iter': 100000},
        'library': 'pgmpy',
        'output_type': 'dag',
        'family': 'score',
    },
    'MMHC': {
        'func': run_mmhc,
        'kwargs': {'scoring_method': 'bdeuscore', 'significance_level': 0.01},
        'library': 'pgmpy',
        'output_type': 'dag',
        'family': 'hybrid',
    },
    'K2': {
        'func': run_k2,
        'kwargs': {'max_parents': 3},
        'library': 'pgmpy',
        'output_type': 'dag',
        'family': 'score',
    },
}
349
+
350
+
351
def run_algorithm(algo_name, df, timeout_sec=600):
    """Run a single algorithm on a dataset.

    Args:
        algo_name: Key into ``ALGORITHM_POOL``.
        df: Dataset passed as the first positional argument to the runner.
        timeout_sec: Time budget in seconds forwarded to ``safe_run``.

    Returns:
        dict with keys: adjmat, runtime, status, output_type, family.
        ``adjmat`` is ``None`` when the run timed out or failed.

    Raises:
        ValueError: If ``algo_name`` is not registered in ``ALGORITHM_POOL``.
    """
    if algo_name not in ALGORITHM_POOL:
        raise ValueError(f"Unknown algorithm: {algo_name}")

    algo = ALGORITHM_POOL[algo_name]
    func = algo['func']
    # Copy so the registry's kwargs dict is never the object handed out.
    kwargs = algo['kwargs'].copy()

    adjmat, runtime, status = safe_run(
        algo_name, func, timeout_sec, df, **kwargs
    )

    return {
        'adjmat': adjmat,
        'runtime': runtime,
        'status': status,
        'output_type': algo['output_type'],
        'family': algo['family'],
    }
375
+
376
+
377
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    # Quick smoke test: run every registered algorithm on the Asia network.
    from causal_selection.data.generator import load_bn_model, sample_dataset

    model = load_bn_model('asia')
    df = sample_dataset(model, 1000, seed=0)

    print("Testing on ASIA (N=1000)...")
    for algo_name in ALGORITHM_POOL:
        result = run_algorithm(algo_name, df, timeout_sec=60)
        status = result['status']
        runtime = result['runtime']
        if result['adjmat'] is not None:
            # Sum of matrix entries; an undirected edge contributes 2.
            n_edges = result['adjmat'].sum()
            print(f"  {algo_name:15s}: {status:10s} {runtime:6.2f}s edges={n_edges}")
        else:
            print(f"  {algo_name:15s}: {status:20s} {runtime:6.2f}s")