Oguzz07 committed · verified
Commit eabf58d · 1 Parent(s): 6f65cd1

Add causal_selection/meta_learner/predictor.py

causal_selection/meta_learner/predictor.py ADDED
@@ -0,0 +1,128 @@
+ """
+ Inference pipeline: given a new discrete dataset, predict the top-3 causal discovery algorithms.
+ """
+ import numpy as np
+ import pandas as pd
+ import logging
+ import json
+
+ from causal_selection.features.extractor import extract_all_features, FEATURE_NAMES, features_to_vector
+ from causal_selection.meta_learner.trainer import load_model, ALGO_NAMES
+ from causal_selection.discovery.algorithms import ALGORITHM_POOL
+
+ logger = logging.getLogger(__name__)
+
+
+ def predict_best_algorithms(df, k=3, model=None, scaler=None, verbose=True):
+     """Given a new discrete dataset, predict the top-k best causal discovery algorithms.
+
+     Args:
+         df: pd.DataFrame with integer-encoded discrete columns
+         k: number of top algorithms to recommend
+         model: pre-loaded model (optional, loaded from disk if None)
+         scaler: pre-loaded scaler (optional)
+         verbose: print details
+
+     Returns:
+         dict with:
+             - 'top_k': list of (algo_name, predicted_score) tuples, best first
+             - 'full_ranking': list of all (algo_name, predicted_score)
+             - 'meta_features': dict of extracted features
+             - 'confidence': estimated confidence based on prediction spread
+     """
+     # Load model if not provided
+     if model is None or scaler is None:
+         model, scaler = load_model()
+
+     # Extract meta-features
+     if verbose:
+         print(f"Dataset shape: {df.shape}")
+         print("Extracting meta-features...")
+
+     features = extract_all_features(df)
+     feature_vector = features_to_vector(features).reshape(1, -1)
+
+     # Scale and predict
+     X_scaled = scaler.transform(feature_vector)
+     predicted_scores = model.predict(X_scaled)[0]  # normalized SHD predictions
+
+     # Rank algorithms (lower predicted score = better)
+     ranking_indices = np.argsort(predicted_scores)
+
+     full_ranking = [(ALGO_NAMES[i], float(predicted_scores[i])) for i in ranking_indices]
+     top_k = full_ranking[:k]
+
+     # Confidence: how much better is top-1 than the rest?
+     scores_sorted = sorted(predicted_scores)
+     spread = scores_sorted[-1] - scores_sorted[0] if len(scores_sorted) > 1 else 0
+     gap_top1_top2 = scores_sorted[1] - scores_sorted[0] if len(scores_sorted) > 1 else 0
+
+     result = {
+         'top_k': top_k,
+         'full_ranking': full_ranking,
+         'meta_features': features,
+         'confidence': {
+             'score_spread': spread,
+             'top1_top2_gap': gap_top1_top2,
+             'recommendation': _get_confidence_text(gap_top1_top2, spread),
+         }
+     }
+
+     if verbose:
+         print(f"\n{'='*60}")
+         print(f"TOP-{k} ALGORITHM RECOMMENDATIONS")
+         print(f"{'='*60}")
+         for rank, (algo, score) in enumerate(top_k, 1):
+             algo_info = ALGORITHM_POOL[algo]
+             print(f"\n  #{rank}: {algo}")
+             print(f"      Predicted nSHD: {score:.4f}")
+             print(f"      Family:  {algo_info['family']}")
+             print(f"      Output:  {algo_info['output_type']}")
+             print(f"      Library: {algo_info['library']}")
+
+         print(f"\n{'='*60}")
+         print("FULL RANKING")
+         print(f"{'='*60}")
+         for rank, (algo, score) in enumerate(full_ranking, 1):
+             marker = "  <<<" if rank <= k else ""
+             print(f"  {rank:2d}. {algo:15s}  nSHD={score:.4f}{marker}")
+
+         print(f"\nConfidence: {result['confidence']['recommendation']}")
+
+         # Key dataset properties
+         print(f"\n{'='*60}")
+         print("DATASET CHARACTERISTICS")
+         print(f"{'='*60}")
+         print(f"  Variables:         {features['n_variables']:.0f}")
+         print(f"  Samples:           {features['n_samples']:.0f}")
+         print(f"  N/P ratio:         {features['n_over_p']:.1f}")
+         print(f"  Avg cardinality:   {features['avg_cardinality']:.1f}")
+         print(f"  Density proxy:     {features['density_proxy']:.3f}")
+         print(f"  Mean MI:           {features['mean_pairwise_MI']:.4f}")
+         print(f"  V-structure proxy: {features['v_structure_proxy']:.3f}")
+
+     return result
+
+
+ def _get_confidence_text(gap, spread):
+     """Generate a human-readable confidence assessment."""
+     if spread < 0.01:
+         return "LOW - All algorithms predicted to perform similarly. Consider running top-3 and comparing."
+     elif gap > 0.05:
+         return "HIGH - Clear winner predicted. Top-1 algorithm strongly recommended."
+     elif gap > 0.02:
+         return "MEDIUM - Top algorithms are close. Running top-3 recommended for comparison."
+     else:
+         return "LOW-MEDIUM - Marginal differences between top algorithms. Run all top-3."
+
+
+ if __name__ == '__main__':
+     logging.basicConfig(level=logging.INFO)
+
+     # Demo: predict on the Sachs network
+     from causal_selection.data.generator import load_bn_model, sample_dataset
+
+     model = load_bn_model('sachs')
+     df = sample_dataset(model, 2000, seed=99)
+
+     result = predict_best_algorithms(df, k=3, verbose=True)
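
For reference, a minimal usage sketch against this module (not part of the commit): the CSV filename is hypothetical, and pre-loading the model/scaler is optional, since predict_best_algorithms falls back to load_model() when they are None. Only functions defined or imported in the file above are used.

import pandas as pd

from causal_selection.discovery.algorithms import ALGORITHM_POOL
from causal_selection.meta_learner.predictor import predict_best_algorithms
from causal_selection.meta_learner.trainer import load_model

# Hypothetical input: columns must already be integer-encoded discrete variables.
df = pd.read_csv("my_discrete_data.csv")

# Load the trained meta-learner and scaler once; reuse them across repeated calls.
model, scaler = load_model()
result = predict_best_algorithms(df, k=3, model=model, scaler=scaler, verbose=False)

# Lower predicted nSHD is better, so top_k[0] is the recommended algorithm.
best_algo, best_nshd = result['top_k'][0]
print(f"Best predicted algorithm: {best_algo} (predicted nSHD {best_nshd:.4f})")
print(f"Confidence: {result['confidence']['recommendation']}")
print(f"Library: {ALGORITHM_POOL[best_algo]['library']}")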