AseemD committed on
Commit
991c693
·
verified ·
1 Parent(s): 9c2a3c9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +248 -0
app.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import shap
3
+ import lime
4
+ from lime.lime_tabular import LimeTabularExplainer
5
+ import pandas as pd
6
+ import numpy as np
7
+ import joblib
8
+ import seaborn as sns
9
+ import xgboost as xgb
10
+ import streamlit.components.v1 as components
11
+ from sklearn.ensemble import RandomForestClassifier
12
+ from sklearn.linear_model import LogisticRegression
13
+ from sklearn.datasets import make_classification
14
+ from sklearn.model_selection import train_test_split
15
+ import matplotlib.pyplot as plt
16
+ from imblearn.pipeline import Pipeline as imbPipeline
17
+
18
+
19
def load_dataset(name):
    """Load and preprocess the dataset identified by *name*.

    Parameters
    ----------
    name : str
        One of "Financial", "NLP" or "Healthcare".

    Returns
    -------
    pandas.DataFrame
        Cleaned dataset with categorical columns numerically encoded.

    Raises
    ------
    ValueError
        If *name* is not one of the recognised dataset names.
    """
    if name == "Financial":
        # Replace with your dataset
        data = pd.read_csv("datasets/loan_approval_dataset.csv")
        data.columns = data.columns.str.strip()
        data = data.drop(columns=['loan_id'])
        # Remove leading/trailing spaces from the categorical column values
        data['education'] = data['education'].str.strip()
        data['self_employed'] = data['self_employed'].str.strip()
        data['loan_status'] = data['loan_status'].str.strip()
        # Encode categorical variables
        data['education'] = data['education'].map({'Graduate': 1, 'Not Graduate': 0})
        data['self_employed'] = data['self_employed'].map({'Yes': 1, 'No': 0})
        data['loan_status'] = data['loan_status'].map({'Approved': 1, 'Rejected': 0})

    elif name == "NLP":
        # Replace with your dataset and all the preprocessing steps
        data = pd.read_csv("datasets/nlp_dataset.csv")

    elif name == "Healthcare":
        data = pd.read_csv("datasets/healthcare_dataset.csv")
        data.columns = data.columns.str.strip()
        data = data.drop_duplicates()
        data = data[data['gender'] != 'Other']

        def recategorize_smoking(smoking_status):
            # Collapse raw smoking categories into 3 ordinal buckets.
            # NOTE(review): values outside these lists map to None (NaN after
            # apply) — confirm the raw column only contains these categories.
            if smoking_status in ['never', 'No Info']:
                return 0
            elif smoking_status == 'current':
                return 1
            elif smoking_status in ['ever', 'former', 'not current']:
                return 2

        data['smoking_history'] = data['smoking_history'].apply(recategorize_smoking)
        data['gender'] = data['gender'].map({'Male': 0, 'Female': 1})

    else:
        # Previously an unknown name fell through to a NameError on `data`;
        # fail fast with an explicit message instead.
        raise ValueError(f"Unknown dataset name: {name!r}")

    return data
56
+
57
def load_models(dataset_name):
    """Load the pre-trained model bundle for *dataset_name*.

    Parameters
    ----------
    dataset_name : str
        One of "Financial", "NLP" or "Healthcare".

    Returns
    -------
    dict
        Mapping of model display name -> fitted estimator. For "Healthcare"
        the pickle holds a single pipeline, wrapped here under "Random Forest"
        so all callers see a uniform dict interface.

    Raises
    ------
    ValueError
        If *dataset_name* is not recognised (previously this returned None
        implicitly, deferring the failure to the caller).
    """
    if dataset_name == "Financial":
        return joblib.load("models/loan_models.pkl")
    elif dataset_name == "NLP":
        return joblib.load("models/nlp_models.pkl")
    elif dataset_name == "Healthcare":
        model_path = "models/healthcare_models.pkl"
        model = joblib.load(model_path)
        return {"Random Forest": model}
    raise ValueError(f"Unknown dataset name: {dataset_name!r}")
66
+
67
+
68
def main():
    """Streamlit entry point: interactive LIME/SHAP model explanations.

    Lets the user pick a dataset ("Financial" loan approval or "Healthcare"
    diabetes), an interpretability method (LIME or SHAP), and a test
    instance, then renders force/summary plots (SHAP) or an HTML
    explanation and predictor-effect chart (LIME).
    """
    plt.style.use('default')
    #st.set_option('deprecation.showPyplotGlobalUse', False)
    st.title("Model Interpretability Visualization with LIME and SHAP")

    # Create different sections for each dataset
    st.subheader("1. Select a Dataset")
    dataset = st.selectbox("Choose a dataset:", ["Financial", "Healthcare"])

    # Perform different interpretability methods on the first dataset
    if dataset == "Financial":
        # 1. Load the dataset
        X = load_dataset(dataset)
        st.write(f"{dataset} Dataset Sample")
        st.write(X.head())

        # 2. Select interpretability method
        st.subheader("2. Select an Interpretability Method")
        method = st.selectbox("Choose an interpretability method:", ["LIME", "SHAP"])

        if method == "SHAP":
            st.subheader("3. Interpretability using SHAP")
            # SHAP analysis — only the XGBoost model is explained here
            loaded_models = load_models(dataset)
            model = loaded_models['XG Boost']
            sns.set_style('whitegrid')
            # Drop the target column; SHAP explains feature columns only
            X = X.drop(columns=["loan_status"]).copy()
            X = X.astype(float)
            explainer = shap.Explainer(model)
            shap_values = explainer(X)

            # Visualize SHAP values
            idx = st.slider("Select Test Instance", 0, len(X) - 1, 0)
            st.write("SHAP Force Plot for a Single Prediction")
            shap.force_plot(explainer.expected_value, shap_values[idx].values, X.iloc[idx], matplotlib=True, show=False)
            st.pyplot(bbox_inches='tight')
            st.write("SHAP Summary Plot")
            shap.summary_plot(shap_values, X, show=False)
            st.pyplot(bbox_inches='tight')
            st.write("SHAP Bar Plot")
            shap.summary_plot(shap_values, X, plot_type="bar", show=False)
            st.pyplot(bbox_inches='tight')

        elif method == "LIME":
            st.subheader("3. Interpretability using LIME")
            # Choose model type
            model_choice = st.radio("Select Model", ["Logistic Regression", 'Decision Tree', 'XG Boost', "Random Forest"])
            loaded_models = load_models(dataset)
            model = loaded_models[model_choice]
            sns.set_style('whitegrid')
            # Last column is the target (loan_status); everything else is a feature
            x = X.iloc[:, :-1].values
            y = X.iloc[:, -1].values
            X_train, X_test, y_train, y_test = train_test_split(x, y,
                                                                test_size=0.25,
                                                                random_state=42)
            target = ['Rejected', 'Approved']
            labels = {'0': 'Rejected', '1': 'Approved'}
            idx = st.slider("Select Test Instance", 0, len(X_test) - 1, 0)

            # Explain the prediction instance using LIME
            explainer = lime.lime_tabular.LimeTabularExplainer(
                X_train,
                feature_names=list(X.columns),
                class_names=target,
                discretize_continuous=True,
            )
            exp = explainer.explain_instance(
                X_test[idx],
                model.predict_proba,
            )

            # Visualize the explanation; use a context manager so the HTML
            # file handle is closed (it was previously leaked)
            st.write("LIME Explanation")
            exp.save_to_file('lime_explanation.html')
            with open('lime_explanation.html', 'r', encoding='utf-8') as html_file:
                components.html(html_file.read(), height=600)
            st.write('True label:', labels[str(y_test[idx])])
            st.write("Effect of Predictors")
            exp.as_pyplot_figure()
            st.pyplot(bbox_inches='tight')

    # Perform different interpretability methods on the second dataset
    elif dataset == "Healthcare":
        data = load_dataset(dataset)
        st.write(f"{dataset} Dataset Sample")
        st.write(data.head())

        st.subheader("2. Select an Interpretability Method")
        method = st.selectbox("Choose an interpretability method:", ["LIME", "SHAP"])

        loaded_models = load_models(dataset)
        model = loaded_models.get('Random Forest')

        # NOTE(review): the upper bound 24031 is hard-coded — presumably the
        # row count of the preprocessed healthcare dataset; confirm it matches
        # len(data) or derive it dynamically.
        idx = st.slider("Select Test Instance", 0, 24031, 0)

        if method == "SHAP":
            st.subheader("3. Interpretability using SHAP")
            loaded_models = load_models(dataset)
            model = loaded_models.get('Random Forest')
            if model and isinstance(model, imbPipeline):
                st.write("Model loaded and is a valid pipeline.")
                try:
                    if 'classifier' in model.named_steps:
                        # SHAP's TreeExplainer needs the bare tree model, not
                        # the full imblearn pipeline
                        tree_model = model.named_steps['classifier']
                        if isinstance(tree_model, RandomForestClassifier):
                            explainer = shap.TreeExplainer(tree_model)
                            X_shap = data.drop(columns=["diabetes"])
                            st.write(f"Data shape for SHAP: {X_shap.shape}")

                            # Subsample to keep the SHAP computation fast in the UI
                            sample_size = 1000
                            X_sample = X_shap.sample(n=sample_size, random_state=42)
                            st.write(f"Using a sample of {sample_size} instances for SHAP analysis.")

                            shap_values = explainer.shap_values(X_sample)

                            st.write(f"SHAP values shape: {np.array(shap_values).shape}")

                            # This slider is scoped to the sample and overrides
                            # the dataset-wide one above
                            idx = st.slider("Select Test Instance", 0, len(X_sample) - 1, 0)
                            st.write("SHAP Force Plot for a Single Prediction")
                            # Index [1] selects the positive ("diabetic") class
                            shap.force_plot(explainer.expected_value[1], shap_values[1][idx, :], X_sample.iloc[idx, :], matplotlib=True, show=False)
                            st.pyplot(bbox_inches='tight')

                            st.write("SHAP Summary Plot")
                            shap.summary_plot(shap_values[1], X_sample, show=False)
                            st.pyplot(bbox_inches='tight')

                            st.write("SHAP Bar Plot")
                            shap.summary_plot(shap_values[1], X_sample, plot_type="bar", show=False)
                            st.pyplot(bbox_inches='tight')
                        else:
                            st.error("The classifier in the pipeline is not a RandomForest.")
                    else:
                        st.error("RandomForest classifier not found in the pipeline.")
                except Exception as e:
                    st.error(f"Error during SHAP analysis: {e}")
            else:
                st.error("Model could not be loaded or is not a valid RandomForest pipeline.")

        elif method == "LIME":
            st.subheader("3. Interpretability using LIME")
            model_choice = st.radio("Select Model", ["Random Forest"])
            model = loaded_models.get('Random Forest')
            sns.set_style('whitegrid')
            x = data.drop(columns=["diabetes"])
            y = data["diabetes"]
            X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.25, random_state=42)

            target = ['Non-Diabetic', 'Diabetic']

            explainer = LimeTabularExplainer(
                X_train.values,
                feature_names=X_train.columns.tolist(),
                class_names=target,
                verbose=True,
                mode='classification'
            )

            instance = X_test.iloc[idx].values.reshape(1, -1)

            def model_predict(instance):
                # Re-wrap the ndarray as a DataFrame so the pipeline sees
                # named feature columns
                return model.predict_proba(pd.DataFrame(instance, columns=X_train.columns))

            exp = explainer.explain_instance(
                data_row=instance[0],
                predict_fn=model_predict
            )

            # Render the LIME HTML explanation; context manager closes the
            # handle (previously leaked)
            st.write("LIME Explanation")
            exp.save_to_file('lime_explanation.html')
            with open('lime_explanation.html', 'r', encoding='utf-8') as html_file:
                components.html(html_file.read(), height=600)
            st.write('True label:', target[y_test.iloc[idx]])
            st.write("Effect of Predictors")
            exp.as_pyplot_figure()
            st.pyplot(bbox_inches='tight')
245
+
246
+
247
# Script entry point: run the Streamlit app when executed directly.
if __name__ == "__main__":
    main()