# Causal Discovery Algorithm Selection Meta-Learner

A meta-learning system that predicts the **top-3 causal discovery algorithms** for any discrete observational dataset, based on dataset meta-features.

## 🎯 What it Does

Given a new discrete dataset (pandas DataFrame), the system:
1. **Extracts 34 meta-features** (entropy, mutual information, chi² statistics, CI test probes, etc.)
2. **Predicts normalized SHD** for each of 9 algorithms via trained models
3. **Ranks and returns the top-3** algorithms expected to produce the most accurate CPDAG
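
The ranking step is straightforward once per-algorithm SHD predictions exist. A minimal sketch (the prediction values are stand-ins for the trained models' output):

```python
import numpy as np

ALGORITHMS = ["GES", "PC", "FCI", "K2", "HC", "Tabu", "GRaSP", "BOSS", "MMHC"]

def rank_top_k(predicted_shd, k=3):
    """Rank algorithms by predicted normalized SHD (lower is better)."""
    order = np.argsort(predicted_shd)  # ascending: best first
    return [ALGORITHMS[i] for i in order[:k]]

# Stand-in predictions (one normalized SHD per algorithm)
preds = np.array([0.12, 0.15, 0.30, 0.28, 0.35, 0.33, 0.40, 0.41, 0.50])
print(rank_top_k(preds))  # ['GES', 'PC', 'K2']
```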

## 📊 Performance (Leave-One-Network-Out Cross-Validation)

### Best Model: Pairwise-GBM Ranking

| Metric | Value |
|--------|-------|
| **Top-3 Hit Rate** | **71.3%** (true best algorithm is in predicted top-3) |
| **Mean Regret** | **0.011** (tiny SHD gap vs oracle selection) |
| **Median Regret** | **0.000** (majority of predictions are perfect) |
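
Regret here is the normalized-SHD gap between the algorithm the system picks and the oracle (best possible) choice on that dataset. A sketch of the per-dataset computation, with illustrative numbers:

```python
import numpy as np

def selection_regret(true_shd, selected_idx):
    """Normalized-SHD gap between the chosen algorithm and the oracle best."""
    return true_shd[selected_idx] - true_shd.min()

# Ground-truth normalized SHD for three algorithms on one dataset
true_shd = np.array([0.25, 0.125, 0.5])
print(selection_regret(true_shd, selected_idx=0))  # 0.125 (oracle is index 1)
```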

### Model Comparison (178 configs, 14 networks + augmented)

| Model | Top-3 Hit Rate | NDCG@3 | Mean Regret |
|-------|---------------|--------|-------------|
| **Pairwise-GBM** | **71.3%** | — | 0.011 |
| GBM-300-lr01 | 67.4% | 0.957 | 0.011 |
| RF-200 | 66.9% | 0.961 | 0.007 |
| RF-500 | 66.3% | 0.962 | 0.007 |
| GBM-500-lr05 | 65.2% | 0.948 | 0.013 |

### Progression

| Stage | Configs | Networks | Top-3 Hit Rate |
|-------|---------|----------|---------------|
| Initial (small nets) | 65 | 4 | 68.2% |
| All 14 networks | 122 | 14 | 70.5% |
| + Data augmentation | 178 | 14+aug | **71.3%** |

## 🧪 Algorithm Pool (9 algorithms)

| Algorithm | Family | Library | Output | Wins |
|-----------|--------|---------|--------|------|
| **GES** | Score-based | causal-learn | CPDAG | 47% |
| **PC** | Constraint-based | causal-learn | CPDAG | 32% |
| **FCI** | Constraint-based | causal-learn | PAG | 8% |
| **K2** | Score-based | pgmpy | DAG | 6% |
| **HC** | Score-based (greedy) | pgmpy | DAG | 3% |
| **Tabu** | Score-based (meta) | pgmpy | DAG | 2% |
| **GRaSP** | Permutation-based | causal-learn | CPDAG | 1% |
| **BOSS** | Permutation-based | causal-learn | CPDAG | 1% |
| **MMHC** | Hybrid | pgmpy | DAG | <1% |

## 🔬 Key Insight: Dependency Parsing Connection

This project was inspired by a structural parallel between **NLP dependency parsing** and **causal discovery**:
- Both predict **directed graphs** over nodes (words/variables)
- Both have **ground-truth annotations** (treebanks/bnlearn networks)
- Both use **arc-level evaluation** (UAS/LAS ↔ SHD/F1)
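
To make the SHD side of that parallel concrete, here is a simplified structural Hamming distance over directed adjacency matrices (production CPDAG implementations additionally handle partially directed edges, but the arc-level counting is the same idea):

```python
import numpy as np

def shd(adj_true, adj_pred):
    """Simplified structural Hamming distance: count node pairs whose
    edge status (absent / i->j / j->i) differs between the two graphs."""
    n = adj_true.shape[0]
    d = 0
    for i in range(n):
        for j in range(i + 1, n):
            if (adj_true[i, j], adj_true[j, i]) != (adj_pred[i, j], adj_pred[j, i]):
                d += 1
    return d

A = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])  # X -> Y -> Z
B = np.array([[0, 1, 0], [0, 0, 0], [0, 1, 0]])  # X -> Y <- Z
print(shd(A, B))  # 1 (the Y-Z edge is reversed)
```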

The biaffine pairwise scoring mechanism from Dozat & Manning (2017) was independently reinvented by AVICI and CauScale for causal structure learning — validating this connection.

### Top Predictive Meta-Features
1. `n_variables` (30%) — network size (number of nodes)
2. `max_pairwise_MI` (24%) — strongest pairwise dependency (≈ biaffine arc score)
3. `max_cramers_v` (8%) — strongest association strength
4. `max_entropy` (7%) — variable complexity
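
The `max_pairwise_MI` feature can be sketched with scikit-learn's `mutual_info_score` (this is an illustration of the idea, not the project's extractor):

```python
import numpy as np
import pandas as pd
from sklearn.metrics import mutual_info_score

def max_pairwise_mi(df):
    """Largest mutual information over all variable pairs —
    the 'strongest arc score' meta-feature."""
    cols = list(df.columns)
    best = 0.0
    for i in range(len(cols)):
        for j in range(i + 1, len(cols)):
            best = max(best, mutual_info_score(df[cols[i]], df[cols[j]]))
    return best

rng = np.random.default_rng(0)
x = rng.integers(0, 2, 500)
df = pd.DataFrame({"A": x, "B": x, "C": rng.integers(0, 2, 500)})  # A and B identical
print(max_pairwise_mi(df))  # close to ln(2) ≈ 0.693 nats for identical binary vars
```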

### Three Ideas Borrowed from Parsing
1. **Biaffine-style pairwise features**: MI and Cramér's V between all variable pairs = parsing's arc scores
2. **Pairwise ranking** (our best model): for each algorithm pair (A, B), predict which wins, then count wins to rank. Inspired by pairwise tournament-style parser selection.
3. **Cross-domain transfer**: Train on well-characterized bnlearn networks → predict on new unseen datasets (= cross-lingual parser transfer)
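
Idea 2 can be sketched as a round-robin tournament; here a fixed preference order stands in for the trained pairwise GBM:

```python
from itertools import combinations

ALGORITHMS = ["GES", "PC", "FCI"]

def rank_by_wins(beats, algorithms):
    """Tournament ranking: ask a pairwise model which algorithm of each
    pair wins, then sort by total win count (Borda-style)."""
    wins = {a: 0 for a in algorithms}
    for a, b in combinations(algorithms, 2):
        wins[a if beats(a, b) else b] += 1
    return sorted(algorithms, key=lambda a: -wins[a])

# Stand-in for the trained pairwise model: a fixed preference order
order = {"GES": 0, "PC": 1, "FCI": 2}
beats = lambda a, b: order[a] < order[b]
print(rank_by_wins(beats, ALGORITHMS))  # ['GES', 'PC', 'FCI']
```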

## 🚀 Quick Start

```python
from causal_selection.meta_learner.predictor import predict_best_algorithms
import pandas as pd

# Load your discrete dataset
df = pd.read_csv("my_discrete_data.csv")

# Get top-3 recommendations
result = predict_best_algorithms(df, k=3)
# result: ranked algorithms with predicted accuracy and confidence
```

## 📁 Project Structure

```
causal_selection/
├── data/
│   ├── generator.py          # Load bnlearn networks, sample data, DAG→CPDAG
│   ├── bif_files/            # 14 bnlearn BIF files (asia through win95pts)
│   └── results/              # Benchmark CSVs: meta-features, SHD matrices
├── discovery/
│   ├── algorithms.py         # 9 algorithm adapters with timeout handling
│   └── evaluator.py          # SHD, F1, Precision, Recall computation
├── features/
│   └── extractor.py          # 34 meta-features across 5 tiers
├── meta_learner/
│   ├── trainer.py            # Multi-Output RF/GBM + LONO-CV evaluation
│   └── predictor.py          # Inference: dataset → top-3 prediction
├── models/
│   ├── meta_learner.pkl      # Trained GBM (multi-output fallback)
│   ├── pairwise_model.pkl    # Pairwise ranking GBM (best model)
│   └── scaler.pkl            # Feature scaler
├── benchmark.py              # Full benchmark orchestration
├── run_benchmark.py          # Resumable benchmark runner
└── augment_and_improve.py    # Data augmentation + model improvement
```

## 📈 Benchmark Data

- **14 bnlearn networks**: asia, cancer, earthquake, sachs, survey, alarm, barley, child, insurance, mildew, water, hailfinder, hepar2, win95pts
- **178 dataset configs**: 122 original + 56 augmented (variable subsampling, sample-size variation, noise injection)
- **1,600+ algorithm runs**: 9 algorithms × 178 configs with per-algorithm timeout

### Data Augmentation Strategies
- **Variable subsampling**: Drop 20-40% of variables to create virtual sub-networks
- **Sample-size variation**: Generate N=300, 750, 1500, 3000 for each network
- **Noise injection**: Randomly flip 5-10% of categorical values
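
The noise-injection strategy can be sketched as re-drawing a fraction of cells from each column's observed categories (a simplified stand-in for `augment_and_improve.py`; re-drawing may occasionally keep the original value, so slightly fewer than `flip_frac` of cells actually change):

```python
import numpy as np
import pandas as pd

def inject_noise(df, flip_frac=0.05, seed=0):
    """Randomly re-draw ~flip_frac of each column's cells from that
    column's observed categories (noise-injection augmentation)."""
    rng = np.random.default_rng(seed)
    out = df.copy()
    for col in out.columns:
        cats = out[col].unique()
        mask = rng.random(len(out)) < flip_frac
        out.loc[mask, col] = rng.choice(cats, size=mask.sum())
    return out

df = pd.DataFrame({"A": ["yes", "no"] * 300, "B": ["lo", "mid", "hi"] * 200})
noisy = inject_noise(df, flip_frac=0.10)
```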

## 🔧 Dependencies

```
causal-learn>=0.1.4
pgmpy>=0.1.25
scikit-learn>=1.8
pandas
numpy
scipy
joblib
```

## 📚 References

- **Causal-Copilot** (arxiv:2504.13263) — Closest existing algorithm selection system
- **AVICI** (arxiv:2205.12934) — Amortized causal structure learning (biaffine architecture)
- **CauScale** (arxiv:2602.08629) — Scalable neural causal discovery
- **Dozat & Manning** (arxiv:1611.01734) — Deep Biaffine Attention for dependency parsing
- **TreeCRF** (arxiv:2005.00975) — Global structural training loss for parsing
- **SATzilla** (arxiv:1401.2474) — Algorithm selection via meta-learning
- **bnlearn** (bnlearn.com) — Bayesian network benchmark repository

## 🔮 Future Work (Phase 2)
1. **Biaffine neural encoder**: Pre-train a neural feature extractor that learns variable-pair "arc scores"
2. **Portfolio regret loss** (TreeCRF-inspired): Global ranking optimization instead of per-algorithm MSE
3. **Hyperparameter co-selection**: Predict not just which algorithm but optimal hyperparameters (CASH)
4. **Ensemble prediction**: Run top-3 and vote on edges across their CPDAGs

## License

MIT