# Causal Discovery Algorithm Selection Meta-Learner
A meta-learning system that predicts the **top-3 causal discovery algorithms** for any discrete observational dataset, based on dataset meta-features.
## 🎯 What it Does
Given a new discrete dataset (pandas DataFrame), the system:
1. **Extracts 34 meta-features** (entropy, mutual information, chi² statistics, CI test probes, etc.)
2. **Predicts normalized SHD** for each of 9 algorithms via trained models
3. **Ranks and returns the top-3** algorithms expected to produce the most accurate CPDAG
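The three steps reduce to "featurize, score, sort". A minimal sketch, using a toy stand-in for the real 34-feature extractor (`toy_meta_features` below is illustrative, not the project's `features/extractor.py`) and any fitted multi-output regressor:

```python
import numpy as np
import pandas as pd

# Algorithm pool, in the same order as the model's outputs
ALGORITHMS = ["GES", "PC", "FCI", "K2", "HC", "Tabu", "GRaSP", "BOSS", "MMHC"]

def toy_meta_features(df: pd.DataFrame) -> np.ndarray:
    """Stand-in for the 34-feature extractor: a few cheap summaries."""
    n, p = df.shape
    entropies = [  # Shannon entropy of each column's empirical distribution
        -sum(q * np.log2(q) for q in df[c].value_counts(normalize=True))
        for c in df.columns
    ]
    return np.array([p, n, max(entropies), float(np.mean(entropies))])

def predict_top_k(df: pd.DataFrame, shd_model, k: int = 3) -> list:
    feats = toy_meta_features(df).reshape(1, -1)  # 1. extract meta-features
    pred_shd = shd_model.predict(feats)[0]        # 2. one normalized SHD per algorithm
    order = np.argsort(pred_shd)                  # 3. lower predicted SHD ranks higher
    return [ALGORITHMS[i] for i in order[:k]]
```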
## 📊 Performance (Leave-One-Network-Out Cross-Validation)
### Best Model: Pairwise-GBM Ranking
| Metric | Value |
|--------|-------|
| **Top-3 Hit Rate** | **71.3%** (true best algorithm is in predicted top-3) |
| **Mean Regret** | **0.011** (tiny SHD gap vs oracle selection) |
| **Median Regret** | **0.000** (majority of predictions are perfect) |
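Both headline metrics can be computed directly from a matrix of true normalized SHDs and the predicted rankings. This sketch uses one common convention, regret of the top-ranked pick versus the oracle pick (the project's exact definition may differ):

```python
import numpy as np

def selection_metrics(true_shd: np.ndarray, pred_rank: np.ndarray, k: int = 3):
    """true_shd: (n_datasets, n_algos) normalized SHD per run.
    pred_rank: (n_datasets, n_algos) predicted algorithm indices, best first.
    Returns (top-k hit rate, mean regret of the top-ranked pick)."""
    oracle = true_shd.argmin(axis=1)  # index of the truly best algorithm
    rows = np.arange(len(oracle))
    hit = np.mean([oracle[i] in pred_rank[i, :k] for i in rows])
    # Regret: SHD of our #1 pick minus SHD of the oracle's pick
    regret = true_shd[rows, pred_rank[:, 0]] - true_shd[rows, oracle]
    return float(hit), float(regret.mean())
```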
### Model Comparison (178 configs, 14 networks + augmented)
| Model | Top-3 Hit Rate | NDCG@3 | Mean Regret |
|-------|---------------|--------|-------------|
| **Pairwise-GBM** | **71.3%** | — | 0.011 |
| GBM-300-lr01 | 67.4% | 0.957 | 0.011 |
| RF-200 | 66.9% | 0.961 | 0.007 |
| RF-500 | 66.3% | 0.962 | 0.007 |
| GBM-500-lr05 | 65.2% | 0.948 | 0.013 |
### Progression
| Stage | Configs | Networks | Top-3 Hit Rate |
|-------|---------|----------|---------------|
| Initial (small nets) | 65 | 4 | 68.2% |
| All 14 networks | 122 | 14 | 70.5% |
| + Data augmentation | 178 | 14+aug | **71.3%** |
## 🧪 Algorithm Pool (9 algorithms)
| Algorithm | Family | Library | Output | Wins |
|-----------|--------|---------|--------|------|
| **GES** | Score-based | causal-learn | CPDAG | 47% |
| **PC** | Constraint-based | causal-learn | CPDAG | 32% |
| **FCI** | Constraint-based | causal-learn | PAG | 8% |
| **K2** | Score-based | pgmpy | DAG | 6% |
| **HC** | Score-based (greedy) | pgmpy | DAG | 3% |
| **Tabu** | Score-based (meta) | pgmpy | DAG | 2% |
| **GRaSP** | Permutation-based | causal-learn | CPDAG | 1% |
| **BOSS** | Permutation-based | causal-learn | CPDAG | 1% |
| **MMHC** | Hybrid | pgmpy | DAG | <1% |
## 🔬 Key Insight: Dependency Parsing Connection
This project was inspired by a structural parallel between **NLP dependency parsing** and **causal discovery**:
- Both predict **directed graphs** over nodes (words/variables)
- Both have **ground-truth annotations** (treebanks/bnlearn networks)
- Both use **arc-level evaluation** (UAS/LAS ↔ SHD/F1)
The biaffine pairwise scoring mechanism from Dozat & Manning (2017) was independently reinvented by AVICI and CauScale for causal structure learning — validating this connection.
### Top Predictive Meta-Features
1. `n_variables` (30%) — number of nodes in the graph
2. `max_pairwise_MI` (24%) — strongest pairwise dependency (≈ biaffine arc score)
3. `max_cramers_v` (8%) — strongest association strength
4. `max_entropy` (7%) — variable complexity
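The two pairwise features above can be computed with standard scikit-learn and SciPy calls. A sketch, assuming every column takes at least two distinct values (a constant column makes Cramér's V undefined):

```python
from itertools import combinations

import numpy as np
import pandas as pd
from scipy.stats import chi2_contingency
from sklearn.metrics import mutual_info_score

def pairwise_dependency_features(df: pd.DataFrame) -> dict:
    """max_pairwise_MI and max_cramers_v over all variable pairs."""
    mis, cvs = [], []
    for a, b in combinations(df.columns, 2):
        mis.append(mutual_info_score(df[a], df[b]))  # MI in nats
        table = pd.crosstab(df[a], df[b])
        chi2, _, _, _ = chi2_contingency(table, correction=False)
        n, k = table.to_numpy().sum(), min(table.shape)
        cvs.append(np.sqrt(chi2 / (n * (k - 1))))    # Cramér's V
    return {"max_pairwise_MI": max(mis), "max_cramers_v": max(cvs)}
```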
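PLACEHOLDER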
### Three Ideas Borrowed from Parsing
1. **Biaffine-style pairwise features**: MI and Cramér's V between all variable pairs = parsing's arc scores
2. **Pairwise ranking** (our best model): For each algorithm pair (A,B), predict which wins → count wins to rank. Inspired by pairwise tournament-style parser selection
3. **Cross-domain transfer**: Train on well-characterized bnlearn networks → predict on new unseen datasets (= cross-lingual parser transfer)
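Idea 2 in miniature: given any per-pair win predictor, ranking reduces to counting round-robin wins. A sketch, where `pair_winner` stands in for the fitted per-pair GBM classifier:

```python
from itertools import combinations

def tournament_rank(algos, pair_winner):
    """Rank algorithms by round-robin wins. `pair_winner(a, b)` must
    return either a or b (here, a classifier's predicted winner)."""
    wins = {a: 0 for a in algos}
    for a, b in combinations(algos, 2):
        wins[pair_winner(a, b)] += 1
    return sorted(algos, key=lambda a: -wins[a])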
## 🚀 Quick Start
```python
from causal_selection.meta_learner.predictor import predict_best_algorithms
import pandas as pd
# Load your discrete dataset
df = pd.read_csv("my_discrete_data.csv")
# Get top-3 recommendations
result = predict_best_algorithms(df, k=3)
# Prints ranked algorithms with predicted accuracy and confidence
```
## 📁 Project Structure
```
causal_selection/
├── data/
│ ├── generator.py # Load bnlearn networks, sample data, DAG→CPDAG
│ ├── bif_files/ # 14 bnlearn BIF files (asia through win95pts)
│ └── results/ # Benchmark CSVs: meta-features, SHD matrices
├── discovery/
│ ├── algorithms.py # 9 algorithm adapters with timeout handling
│ └── evaluator.py # SHD, F1, Precision, Recall computation
├── features/
│ └── extractor.py # 34 meta-features across 5 tiers
├── meta_learner/
│ ├── trainer.py # Multi-Output RF/GBM + LONO-CV evaluation
│ └── predictor.py # Inference: dataset → top-3 prediction
├── models/
│ ├── meta_learner.pkl # Trained GBM (multi-output fallback)
│ ├── pairwise_model.pkl # Pairwise ranking GBM (best model)
│ └── scaler.pkl # Feature scaler
├── benchmark.py # Full benchmark orchestration
├── run_benchmark.py # Resumable benchmark runner
└── augment_and_improve.py # Data augmentation + model improvement
```
## 📈 Benchmark Data
- **14 bnlearn networks**: asia, cancer, earthquake, sachs, survey, alarm, barley, child, insurance, mildew, water, hailfinder, hepar2, win95pts
- **178 dataset configs**: 122 original + 56 augmented (variable subsampling, sample-size variation, noise injection)
- **1,600+ algorithm runs**: 9 algorithms × 178 configs with per-algorithm timeout
### Data Augmentation Strategies
- **Variable subsampling**: Drop 20-40% of variables to create virtual sub-networks
- **Sample-size variation**: Generate N=300, 750, 1500, 3000 for each network
- **Noise injection**: Randomly flip 5-10% of categorical values
## 🔧 Dependencies
```
causal-learn>=0.1.4
pgmpy>=0.1.25
scikit-learn>=1.8
pandas
numpy
scipy
joblib
```
## 📚 References
- **Causal-Copilot** (arxiv:2504.13263) — Closest existing algorithm selection system
- **AVICI** (arxiv:2205.12934) — Amortized causal structure learning (biaffine architecture)
- **CauScale** (arxiv:2602.08629) — Scalable neural causal discovery
- **Dozat & Manning** (arxiv:1611.01734) — Deep Biaffine Attention for dependency parsing
- **TreeCRF** (arxiv:2005.00975) — Global structural training loss for parsing
- **SATzilla** (arxiv:1401.2474) — Algorithm selection via meta-learning
- **bnlearn** (bnlearn.com) — Bayesian network benchmark repository
## 🔮 Future Work (Phase 2)
1. **Biaffine neural encoder**: Pre-train a neural feature extractor that learns variable-pair "arc scores"
2. **Portfolio regret loss** (TreeCRF-inspired): Global ranking optimization instead of per-algorithm MSE
3. **Hyperparameter co-selection**: Predict not just which algorithm but optimal hyperparameters (CASH)
4. **Ensemble prediction**: Run top-3 and vote on edges across their CPDAGs
## License
MIT
|