BDR-AI commited on
Commit
4160c28
·
verified ·
1 Parent(s): c28ea98

Add ensemble predictor with 5-model architecture (Step 3/5)

Browse files
Files changed (1) hide show
  1. ensemble_predictor.py +219 -0
ensemble_predictor.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Ensemble Predictor - 5-Model Architecture with Meta Learning
3
+ Implements the Maysat method with weighted voting and stacked generalization
4
+ """
5
+
6
+ import pickle
7
+ import json
8
+ import os
9
+ import numpy as np
10
+ from typing import Dict, List, Tuple, Any
11
+
12
+ class EnsemblePredictor:
13
+ """
14
+ Ensemble fraud detection using 5 models + meta learner
15
+ - Random Forest (baseline)
16
+ - XGBoost (gradient boosting)
17
+ - LightGBM (fast training)
18
+ - CatBoost (categorical features)
19
+ - DistilBERT (text analysis via text_processor)
20
+ """
21
+
22
+ def __init__(self):
23
+ self.models = {}
24
+ self.meta_learner = None
25
+ self.scaler = None
26
+ self.encoder = None
27
+ self.feature_columns = None
28
+ self.model_weights = {
29
+ 'xgboost': 0.25,
30
+ 'lightgbm': 0.25,
31
+ 'catboost': 0.20,
32
+ 'random_forest': 0.15,
33
+ 'distilbert': 0.15
34
+ }
35
+ self.load_models()
36
+
37
+ def load_models(self):
38
+ """Load all model artifacts if available"""
39
+ try:
40
+ models_path = 'models/'
41
+
42
+ # Load Random Forest (baseline)
43
+ if os.path.exists(f'{models_path}fraud_rf_model.pkl'):
44
+ with open(f'{models_path}fraud_rf_model.pkl', 'rb') as f:
45
+ self.models['random_forest'] = pickle.load(f)
46
+ print("✓ Random Forest loaded")
47
+
48
+ # Load XGBoost
49
+ if os.path.exists(f'{models_path}fraud_xgb_model.pkl'):
50
+ with open(f'{models_path}fraud_xgb_model.pkl', 'rb') as f:
51
+ self.models['xgboost'] = pickle.load(f)
52
+ print("✓ XGBoost loaded")
53
+
54
+ # Load LightGBM
55
+ if os.path.exists(f'{models_path}fraud_lgb_model.pkl'):
56
+ with open(f'{models_path}fraud_lgb_model.pkl', 'rb') as f:
57
+ self.models['lightgbm'] = pickle.load(f)
58
+ print("✓ LightGBM loaded")
59
+
60
+ # Load CatBoost
61
+ if os.path.exists(f'{models_path}fraud_cat_model.pkl'):
62
+ with open(f'{models_path}fraud_cat_model.pkl', 'rb') as f:
63
+ self.models['catboost'] = pickle.load(f)
64
+ print("✓ CatBoost loaded")
65
+
66
+ # Load preprocessing artifacts
67
+ if os.path.exists(f'{models_path}fraud_scaler.pkl'):
68
+ with open(f'{models_path}fraud_scaler.pkl', 'rb') as f:
69
+ self.scaler = pickle.load(f)
70
+
71
+ if os.path.exists(f'{models_path}fraud_encoder.pkl'):
72
+ with open(f'{models_path}fraud_encoder.pkl', 'rb') as f:
73
+ self.encoder = pickle.load(f)
74
+
75
+ if os.path.exists(f'{models_path}feature_columns.json'):
76
+ with open(f'{models_path}feature_columns.json', 'r') as f:
77
+ self.feature_columns = json.load(f)
78
+
79
+ # Load meta learner if available
80
+ if os.path.exists(f'{models_path}meta_learner.pkl'):
81
+ with open(f'{models_path}meta_learner.pkl', 'rb') as f:
82
+ self.meta_learner = pickle.load(f)
83
+ print("✓ Meta Learner loaded")
84
+
85
+ print(f"✓ Ensemble loaded: {len(self.models)} models")
86
+
87
+ except Exception as e:
88
+ print(f"Model loading error: {e}")
89
+
90
+ def predict_ensemble(self, features: np.ndarray, text_score: float = None) -> Dict[str, Any]:
91
+ """
92
+ Predict using ensemble with weighted voting
93
+
94
+ Args:
95
+ features: Engineered features array
96
+ text_score: Optional text analysis score from DistilBERT
97
+
98
+ Returns:
99
+ Dictionary with ensemble prediction and individual model scores
100
+ """
101
+ if len(self.models) == 0:
102
+ return {
103
+ 'ensemble_score': None,
104
+ 'method': 'No models loaded',
105
+ 'individual_scores': {}
106
+ }
107
+
108
+ try:
109
+ # Scale features
110
+ if self.scaler is not None:
111
+ features_scaled = self.scaler.transform([features])
112
+ else:
113
+ features_scaled = np.array([features])
114
+
115
+ # Get predictions from each model
116
+ individual_scores = {}
117
+
118
+ for model_name, model in self.models.items():
119
+ try:
120
+ # Get probability of fraud (class 1)
121
+ if hasattr(model, 'predict_proba'):
122
+ prob = model.predict_proba(features_scaled)[0][1]
123
+ else:
124
+ prob = model.predict(features_scaled)[0]
125
+
126
+ individual_scores[model_name] = float(prob)
127
+ except Exception as e:
128
+ print(f"Error predicting with {model_name}: {e}")
129
+ individual_scores[model_name] = 0.0
130
+
131
+ # Add text score if available
132
+ if text_score is not None:
133
+ individual_scores['distilbert'] = text_score
134
+
135
+ # Ensemble prediction
136
+ if self.meta_learner is not None:
137
+ # Use meta learner (stacked generalization)
138
+ meta_features = np.array([[individual_scores.get(m, 0.0) for m in self.model_weights.keys()]])
139
+ ensemble_score = self.meta_learner.predict_proba(meta_features)[0][1]
140
+ method = "Meta Learner (Stacked)"
141
+ else:
142
+ # Use weighted voting
143
+ ensemble_score = 0.0
144
+ total_weight = 0.0
145
+
146
+ for model_name, weight in self.model_weights.items():
147
+ if model_name in individual_scores:
148
+ ensemble_score += individual_scores[model_name] * weight
149
+ total_weight += weight
150
+
151
+ if total_weight > 0:
152
+ ensemble_score /= total_weight
153
+
154
+ method = "Weighted Voting"
155
+
156
+ return {
157
+ 'ensemble_score': float(ensemble_score),
158
+ 'method': method,
159
+ 'individual_scores': individual_scores,
160
+ 'num_models': len(individual_scores)
161
+ }
162
+
163
+ except Exception as e:
164
+ print(f"Ensemble prediction error: {e}")
165
+ return {
166
+ 'ensemble_score': None,
167
+ 'method': 'Error',
168
+ 'individual_scores': {},
169
+ 'error': str(e)
170
+ }
171
+
172
+ def get_model_status(self) -> Dict[str, bool]:
173
+ """Check which models are loaded"""
174
+ return {
175
+ 'random_forest': 'random_forest' in self.models,
176
+ 'xgboost': 'xgboost' in self.models,
177
+ 'lightgbm': 'lightgbm' in self.models,
178
+ 'catboost': 'catboost' in self.models,
179
+ 'meta_learner': self.meta_learner is not None,
180
+ 'scaler': self.scaler is not None,
181
+ 'encoder': self.encoder is not None
182
+ }
183
+
184
+ def get_feature_importance(self, model_name: str = 'random_forest') -> List[Tuple[str, float]]:
185
+ """Get feature importance from specified model"""
186
+ if model_name not in self.models:
187
+ return []
188
+
189
+ model = self.models[model_name]
190
+
191
+ if hasattr(model, 'feature_importances_'):
192
+ importances = model.feature_importances_
193
+ if self.feature_columns:
194
+ return sorted(
195
+ zip(self.feature_columns, importances),
196
+ key=lambda x: x[1],
197
+ reverse=True
198
+ )
199
+
200
+ return []
201
+
202
+
203
+ # Test the ensemble
204
+ if __name__ == "__main__":
205
+ print("="*60)
206
+ print("Ensemble Predictor - Model Status Check")
207
+ print("="*60)
208
+
209
+ ensemble = EnsemblePredictor()
210
+ status = ensemble.get_model_status()
211
+
212
+ print("\nModel Status:")
213
+ for model, loaded in status.items():
214
+ status_icon = "✓" if loaded else "✗"
215
+ print(f" {status_icon} {model}: {'Loaded' if loaded else 'Not found'}")
216
+
217
+ print("\n" + "="*60)
218
+ print(f"Ensemble ready with {len(ensemble.models)} models")
219
+ print("="*60)