Oguzz07 commited on
Commit
9efaa0b
·
verified ·
1 Parent(s): c1c4e89

Add causal_selection/data/generator.py

Browse files
Files changed (1) hide show
  1. causal_selection/data/generator.py +291 -0
causal_selection/data/generator.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data generation module: load bnlearn networks, sample datasets, extract ground truth.
3
+ """
4
+ import os
5
+ import numpy as np
6
+ import pandas as pd
7
+ from pgmpy.readwrite import BIFReader
8
+ from pgmpy.sampling import BayesianModelSampling
9
+ import warnings
10
+ import logging
11
+
12
+ warnings.filterwarnings('ignore')
13
+ logger = logging.getLogger(__name__)
14
+
15
+ BIF_DIR = os.path.join(os.path.dirname(__file__), 'bif_files')
16
+
17
+ # Network tiers for CPU budget management
18
+ SMALL_NETWORKS = ['asia', 'cancer', 'earthquake', 'sachs', 'survey']
19
+ MEDIUM_NETWORKS = ['alarm', 'barley', 'child', 'insurance', 'mildew', 'water']
20
+ LARGE_NETWORKS = ['hailfinder', 'hepar2', 'win95pts']
21
+
22
+ ALL_NETWORKS = SMALL_NETWORKS + MEDIUM_NETWORKS + LARGE_NETWORKS
23
+
24
+ # Sample sizes per tier
25
+ SAMPLE_SIZES = {
26
+ 'small': [250, 500, 1000, 2000, 5000, 10000],
27
+ 'medium': [500, 1000, 2000, 5000],
28
+ 'large': [500, 1000, 2000],
29
+ }
30
+
31
+ SEEDS_PER_CONFIG = 3
32
+
33
+
34
+ def get_network_tier(name):
35
+ if name in SMALL_NETWORKS:
36
+ return 'small'
37
+ elif name in MEDIUM_NETWORKS:
38
+ return 'medium'
39
+ else:
40
+ return 'large'
41
+
42
+
43
+ def load_bn_model(name):
44
+ """Load a Bayesian network from BIF file."""
45
+ bif_path = os.path.join(BIF_DIR, f'{name}.bif')
46
+ if not os.path.exists(bif_path):
47
+ raise FileNotFoundError(f"BIF file not found: {bif_path}")
48
+ reader = BIFReader(bif_path)
49
+ model = reader.get_model()
50
+ return model
51
+
52
+
53
+ def get_true_dag_adjmat(model):
54
+ """Extract ground-truth DAG adjacency matrix from a BayesianNetwork model.
55
+
56
+ Returns:
57
+ adjmat: np.ndarray of shape (n_nodes, n_nodes), adjmat[i,j]=1 means i->j
58
+ node_names: list of node names (ordering)
59
+ """
60
+ nodes = sorted(model.nodes())
61
+ n = len(nodes)
62
+ node_idx = {node: i for i, node in enumerate(nodes)}
63
+ adjmat = np.zeros((n, n), dtype=int)
64
+ for parent, child in model.edges():
65
+ adjmat[node_idx[parent], node_idx[child]] = 1
66
+ return adjmat, nodes
67
+
68
+
69
+ def dag_to_cpdag(dag_adjmat):
70
+ """Convert a DAG adjacency matrix to its CPDAG (completed partially directed acyclic graph).
71
+
72
+ A CPDAG represents the Markov equivalence class:
73
+ - Compelled edges (in all DAGs of the class) remain directed
74
+ - Reversible edges become undirected (represented as bidirectional)
75
+
76
+ Uses the Chickering (2002) algorithm:
77
+ 1. Find all v-structures (i -> k <- j where i and j not adjacent)
78
+ 2. Apply Meek's orientation rules iteratively
79
+
80
+ Returns:
81
+ cpdag: np.ndarray, cpdag[i,j]=1 and cpdag[j,i]=0 means i->j (directed)
82
+ cpdag[i,j]=1 and cpdag[j,i]=1 means i--j (undirected)
83
+ """
84
+ n = dag_adjmat.shape[0]
85
+
86
+ # Start with skeleton (undirected)
87
+ skeleton = ((dag_adjmat + dag_adjmat.T) > 0).astype(int)
88
+ cpdag = skeleton.copy()
89
+
90
+ # Step 1: Find v-structures and orient them
91
+ # v-structure: i -> k <- j where i and j are NOT adjacent in skeleton
92
+ for k in range(n):
93
+ parents_of_k = np.where(dag_adjmat[:, k] == 1)[0]
94
+ for idx_a in range(len(parents_of_k)):
95
+ for idx_b in range(idx_a + 1, len(parents_of_k)):
96
+ i = parents_of_k[idx_a]
97
+ j = parents_of_k[idx_b]
98
+ # Check if i and j are NOT adjacent
99
+ if skeleton[i, j] == 0:
100
+ # This is a v-structure: i -> k <- j
101
+ # Orient both edges as directed in CPDAG
102
+ cpdag[i, k] = 1
103
+ cpdag[k, i] = 0
104
+ cpdag[j, k] = 1
105
+ cpdag[k, j] = 0
106
+
107
+ # Step 2: Apply Meek's rules iteratively until convergence
108
+ changed = True
109
+ while changed:
110
+ changed = False
111
+ for i in range(n):
112
+ for j in range(n):
113
+ if cpdag[i, j] == 1 and cpdag[j, i] == 1:
114
+ # i -- j is undirected, try to orient
115
+
116
+ # Rule 1: If k -> i -- j and k not adj j, then i -> j
117
+ for k in range(n):
118
+ if k != i and k != j:
119
+ if cpdag[k, i] == 1 and cpdag[i, k] == 0: # k -> i
120
+ if cpdag[k, j] == 0 and cpdag[j, k] == 0: # k not adj j
121
+ cpdag[j, i] = 0 # orient i -> j
122
+ changed = True
123
+
124
+ # Rule 2: If i -> k -> j and i -- j, then i -> j
125
+ if cpdag[i, j] == 1 and cpdag[j, i] == 1: # still undirected
126
+ for k in range(n):
127
+ if k != i and k != j:
128
+ if (cpdag[i, k] == 1 and cpdag[k, i] == 0 and # i -> k
129
+ cpdag[k, j] == 1 and cpdag[j, k] == 0): # k -> j
130
+ cpdag[j, i] = 0 # orient i -> j
131
+ changed = True
132
+
133
+ # Rule 3: If i -- k1 -> j and i -- k2 -> j and k1 not adj k2, then i -> j
134
+ if cpdag[i, j] == 1 and cpdag[j, i] == 1:
135
+ k_candidates = []
136
+ for k in range(n):
137
+ if k != i and k != j:
138
+ if (cpdag[i, k] == 1 and cpdag[k, i] == 1 and # i -- k
139
+ cpdag[k, j] == 1 and cpdag[j, k] == 0): # k -> j
140
+ k_candidates.append(k)
141
+ for idx_a in range(len(k_candidates)):
142
+ for idx_b in range(idx_a + 1, len(k_candidates)):
143
+ k1, k2 = k_candidates[idx_a], k_candidates[idx_b]
144
+ if cpdag[k1, k2] == 0 and cpdag[k2, k1] == 0: # not adjacent
145
+ cpdag[j, i] = 0 # orient i -> j
146
+ changed = True
147
+
148
+ return cpdag
149
+
150
+
151
+ def sample_dataset(model, n_samples, seed=42):
152
+ """Sample observational data from a Bayesian network.
153
+
154
+ Returns:
155
+ df: pd.DataFrame with integer-encoded discrete variables
156
+ """
157
+ np.random.seed(seed)
158
+ sampler = BayesianModelSampling(model)
159
+ try:
160
+ df = sampler.forward_sample(size=n_samples, seed=seed)
161
+ except TypeError:
162
+ # Fallback for pgmpy/pandas version compatibility issues
163
+ # Use bnlearn sampling or manual forward sampling
164
+ df = _manual_forward_sample(model, n_samples, seed)
165
+
166
+ # Ensure consistent column ordering (sorted)
167
+ df = df[sorted(df.columns)]
168
+
169
+ # Encode string/category columns as integers
170
+ for col in df.columns:
171
+ if df[col].dtype == object or df[col].dtype.name == 'category':
172
+ df[col] = df[col].astype('category').cat.codes
173
+
174
+ # Ensure all columns are numeric
175
+ df = df.apply(pd.to_numeric, errors='coerce').fillna(0).astype(int)
176
+
177
+ return df
178
+
179
+
180
+ def _manual_forward_sample(model, n_samples, seed=42):
181
+ """Manual forward sampling when pgmpy's sampler has compatibility issues."""
182
+ import networkx as nx
183
+
184
+ rng = np.random.RandomState(seed)
185
+ nodes = list(nx.topological_sort(model))
186
+
187
+ # Get CPDs
188
+ cpd_dict = {}
189
+ for cpd in model.get_cpds():
190
+ cpd_dict[cpd.variable] = cpd
191
+
192
+ samples = {node: [] for node in nodes}
193
+
194
+ for _ in range(n_samples):
195
+ sample = {}
196
+ for node in nodes:
197
+ cpd = cpd_dict[node]
198
+ parents = cpd.get_evidence()
199
+
200
+ if not parents:
201
+ # Root node - sample from marginal
202
+ probs = cpd.get_values().flatten()
203
+ probs = probs / probs.sum() # normalize
204
+ val = rng.choice(len(probs), p=probs)
205
+ else:
206
+ # Conditional sampling
207
+ parent_vals = tuple(sample[p] for p in parents)
208
+ # Get the column of CPT corresponding to parent values
209
+ values = cpd.get_values()
210
+ state_names = cpd.state_names
211
+
212
+ # Calculate column index from parent states
213
+ col_idx = 0
214
+ stride = 1
215
+ for p in reversed(parents):
216
+ p_card = len(state_names[p])
217
+ col_idx += sample[p] * stride
218
+ stride *= p_card
219
+
220
+ probs = values[:, col_idx]
221
+ probs = np.abs(probs)
222
+ probs = probs / probs.sum()
223
+ val = rng.choice(len(probs), p=probs)
224
+
225
+ sample[node] = val
226
+ samples[node].append(val)
227
+
228
+ return pd.DataFrame(samples)
229
+
230
+
231
+ def generate_all_datasets(networks=None, output_dir=None):
232
+ """Generate all dataset configurations.
233
+
234
+ Returns list of dicts with:
235
+ - network: str
236
+ - n_samples: int
237
+ - seed: int
238
+ - df: pd.DataFrame
239
+ - true_dag: np.ndarray
240
+ - true_cpdag: np.ndarray
241
+ - node_names: list
242
+ """
243
+ if networks is None:
244
+ networks = ALL_NETWORKS
245
+
246
+ configs = []
247
+ for net_name in networks:
248
+ tier = get_network_tier(net_name)
249
+ sample_sizes = SAMPLE_SIZES[tier]
250
+
251
+ logger.info(f"Loading network: {net_name}")
252
+ model = load_bn_model(net_name)
253
+ true_dag, node_names = get_true_dag_adjmat(model)
254
+ true_cpdag = dag_to_cpdag(true_dag)
255
+
256
+ for n_samples in sample_sizes:
257
+ for seed in range(SEEDS_PER_CONFIG):
258
+ try:
259
+ df = sample_dataset(model, n_samples, seed=seed)
260
+ config = {
261
+ 'network': net_name,
262
+ 'n_samples': n_samples,
263
+ 'seed': seed,
264
+ 'df': df,
265
+ 'true_dag': true_dag,
266
+ 'true_cpdag': true_cpdag,
267
+ 'node_names': node_names,
268
+ }
269
+ configs.append(config)
270
+ logger.info(f" {net_name} N={n_samples} seed={seed}: {df.shape}")
271
+ except Exception as e:
272
+ logger.error(f" FAILED {net_name} N={n_samples} seed={seed}: {e}")
273
+
274
+ return configs
275
+
276
+
277
+ if __name__ == '__main__':
278
+ logging.basicConfig(level=logging.INFO)
279
+
280
+ # Quick test
281
+ model = load_bn_model('asia')
282
+ dag, nodes = get_true_dag_adjmat(model)
283
+ cpdag = dag_to_cpdag(dag)
284
+
285
+ print(f"ASIA - nodes: {nodes}")
286
+ print(f"DAG adjacency:\n{dag}")
287
+ print(f"CPDAG adjacency:\n{cpdag}")
288
+
289
+ df = sample_dataset(model, 1000, seed=0)
290
+ print(f"\nSampled data: {df.shape}")
291
+ print(df.head())