mfarnas commited on
Commit
4da4fcb
·
1 Parent(s): 80ed7f2

move st_shap to inference_utils

Browse files
src/inference_utils.py CHANGED
@@ -1,6 +1,8 @@
1
  import pandas as pd
2
  import streamlit as st
 
3
  from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score, brier_score_loss, log_loss
 
4
 
5
  def compute_metrics(y_true, y_pred_proba, threshold=0.5):
6
  y_pred = (y_pred_proba >= threshold).astype(int)
@@ -37,3 +39,51 @@ def add_predictions(df, probs):
37
  )
38
 
39
  return df_styled
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
  import streamlit as st
3
+ import shap
4
  from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score, brier_score_loss, log_loss
5
+ import streamlit.components.v1 as components
6
 
7
  def compute_metrics(y_true, y_pred_proba, threshold=0.5):
8
  y_pred = (y_pred_proba >= threshold).astype(int)
 
39
  )
40
 
41
  return df_styled
42
+
43
+ def st_shap(plot, height=None):
44
+ shap_html = f"<head>{shap.getjs()}</head><body>{plot.html()}</body>"
45
+ components.html(shap_html, height=height)
46
+
47
+ def ensemble_shap(models, X, model_weights=None):
48
+ """
49
+ Compute ensemble SHAP values for a list of tree-based models.
50
+ Returns a shap.Explanation with mean SHAP values across models.
51
+ """
52
+ import numpy as np
53
+ import shap
54
+
55
+ all_values = []
56
+ all_base_values = []
57
+
58
+ for model in models:
59
+ explainer = shap.TreeExplainer(model)
60
+ shap_values = explainer(X)
61
+
62
+ # Handle binary classification
63
+ if shap_values.values.ndim == 3:
64
+ # safer class selection
65
+ class_index = getattr(model, "classes_", [0, 1]).index(1)
66
+ shap_values = shap.Explanation(
67
+ values=shap_values.values[:, :, class_index],
68
+ base_values=shap_values.base_values[:, class_index],
69
+ data=X,
70
+ feature_names=X.columns
71
+ )
72
+
73
+ all_values.append(shap_values.values)
74
+ all_base_values.append(shap_values.base_values)
75
+
76
+ # Handle weights
77
+ if model_weights is None:
78
+ model_weights = np.ones(len(models))
79
+ model_weights = np.array(model_weights) / np.sum(model_weights)
80
+
81
+ mean_values = np.average(all_values, axis=0, weights=model_weights)
82
+ mean_base = np.average(all_base_values, axis=0, weights=model_weights)
83
+
84
+ return shap.Explanation(
85
+ values=mean_values,
86
+ base_values=mean_base,
87
+ data=X,
88
+ feature_names=X.columns
89
+ )
src/pages/3_Preprocessing_and_Training.py CHANGED
@@ -1,26 +1,23 @@
 
 
 
 
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
  import matplotlib.pyplot as plt
5
  from catboost import CatBoostClassifier, cv, Pool
 
 
6
  from model_utils import get_model, save_model, save_model_ensemble, ensemble_predict
7
  from preprocess_utils import load_train_features
8
  from preprocess_utils import preprocess_pipeline as preprocess
9
- from inference_utils import compute_metrics
10
  from sidebar import sidebar
11
- from sklearn.model_selection import StratifiedKFold
12
- import os
13
- from pathlib import Path
14
- import pyarrow.parquet as pq
15
  import shap
16
  import lime
17
  import lime.lime_tabular
18
- # Add this helper function at the top of the file
19
- import streamlit.components.v1 as components
20
-
21
- def st_shap(plot, height=None):
22
- shap_html = f"<head>{shap.getjs()}</head><body>{plot.html()}</body>"
23
- components.html(shap_html, height=height)
24
 
25
  LOCAL = False
26
 
@@ -255,7 +252,7 @@ if "trained_model" in st.session_state or "trained_models" in st.session_state:
255
  shap_values_selected.values[sample_idx, :],
256
  X_force.iloc[sample_idx, :]
257
  ),
258
- height=250
259
  )
260
 
261
  # ---- Display feature + SHAP values for selected single-sample ----
 
1
+ import os
2
+ from pathlib import Path
3
+ import pyarrow.parquet as pq
4
+
5
  import streamlit as st
6
  import pandas as pd
7
  import numpy as np
8
  import matplotlib.pyplot as plt
9
  from catboost import CatBoostClassifier, cv, Pool
10
+ from sklearn.model_selection import StratifiedKFold
11
+
12
  from model_utils import get_model, save_model, save_model_ensemble, ensemble_predict
13
  from preprocess_utils import load_train_features
14
  from preprocess_utils import preprocess_pipeline as preprocess
15
+ from inference_utils import compute_metrics, st_shap
16
  from sidebar import sidebar
17
+
 
 
 
18
  import shap
19
  import lime
20
  import lime.lime_tabular
 
 
 
 
 
 
21
 
22
  LOCAL = False
23
 
 
252
  shap_values_selected.values[sample_idx, :],
253
  X_force.iloc[sample_idx, :]
254
  ),
255
+ height=200
256
  )
257
 
258
  # ---- Display feature + SHAP values for selected single-sample ----