jiehou commited on
Commit
4a45429
·
verified ·
1 Parent(s): 951b39c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +311 -0
app.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import seaborn as sn
3
+ import matplotlib.pyplot as plt
4
+ from sklearn.metrics import confusion_matrix
5
+ from matplotlib.colors import ListedColormap
6
+
7
+ import numpy as np
8
+
9
+ import gradio as gr
10
+
11
+
12
# --- Gradio widgets -------------------------------------------------------
# Input 1: a 10x3 editable table of (sample index, predicted probability,
# true 0/1 label); delivered to the callback as a numpy array.
set_input = gr.Dataframe(type="numpy", row_count=10, col_count=3, headers=['Sample Index', 'Predicted Prob', 'Label (Y)'], datatype=["number", "number", "number"])
# Input 2: probability cut-off used to binarize the predicted probabilities.
# NOTE(review): the label says "Default = 0.5" but value=0.4 — confirm which
# default is intended.
set_input2 = gr.Slider(0, 1, step = 0.1, value=0.4, label="Set Probability Threshold (Default = 0.5)")

#set_output = gr.Textbox(label ='test')
# Output: predicted-label table (defined but not wired into the Interface
# below, whose outputs are only the three images).
set_output1 = gr.Dataframe(type="pandas", label = 'Predicted Labels',max_rows=10)

# Outputs: the three diagnostic plots saved by visualize_ROC.
set_output2 = gr.Image(label="Confusion Matrix")
set_output3 = gr.Image(label="ROC curve")
set_output4 = gr.Image(label="Threshold Tuning curve")
21
+
22
def perf_measure(y_actual, y_hat):
    """Count confusion-matrix cells for binary (0/1) labels.

    Args:
        y_actual: ground-truth labels (any iterable of 0/1 values).
        y_hat: predicted labels, iterated in lockstep with ``y_actual``.

    Returns:
        Tuple ``(TP, FP, TN, FN)`` of true/false positive/negative counts.
    """
    TP = 0
    FP = 0
    TN = 0
    FN = 0

    # Iterate value pairs directly instead of indexing with range(len(...)).
    # The elif chain reproduces the original four mutually-exclusive tests:
    # a prediction that is neither 0 nor 1 is counted nowhere, as before.
    for actual, pred in zip(y_actual, y_hat):
        if actual == pred == 1:
            TP += 1
        elif pred == 1:          # pred is 1 but actual differs -> false alarm
            FP += 1
        elif actual == pred == 0:
            TN += 1
        elif pred == 0:          # pred is 0 but actual differs -> miss
            FN += 1

    return (TP, FP, TN, FN)
39
+
40
+
41
def _safe_ratio(numerator, denominator):
    """Return numerator/denominator, or 0 when the denominator is zero.

    Replaces the original bare ``try/except`` guards, which swallowed every
    exception instead of only the intended division-by-zero case.
    """
    return numerator / denominator if denominator else 0


def _threshold_metrics(TP, FP, TN, FN):
    """Derive classification metrics from confusion-matrix counts.

    Returns:
        Tuple ``(recall, precision, specificity, fpr, f1, g_mean)``;
        each ratio falls back to 0 when its denominator is zero.
    """
    recall = _safe_ratio(TP, TP + FN)        # sensitivity / true positive rate
    precision = _safe_ratio(TP, TP + FP)
    specificity = _safe_ratio(TN, TN + FP)
    fpr = _safe_ratio(FP, FP + TN)           # fall-out / false positive rate
    f1 = _safe_ratio(2 * recall * precision, precision + recall)
    g_mean = np.sqrt(recall * specificity)   # np.sqrt never raises here
    return recall, precision, specificity, fpr, f1, g_mean


def visualize_ROC(set_threshold, set_input):
    """Render the three diagnostic plots for a binary classifier.

    Args:
        set_threshold: probability cut-off used to turn predicted
            probabilities into hard 0/1 labels.
        set_input: (n, 3) array-like whose columns are
            [sample index, predicted probability, actual 0/1 label].

    Returns:
        Tuple of saved image paths:
        ``('tmp.png', 'tmp2.png', 'tmp3.png')`` — confusion matrix,
        ROC curve, and threshold-tuning curve respectively.
    """
    from sklearn.metrics import roc_curve  # local import, as in the original

    prob = set_input[:, 1]
    actual_label = set_input[:, 2]
    pred_label = (prob >= set_threshold).astype(int)

    df = pd.DataFrame({
        'Predicted Prob': prob,
        'Predicted Label': pred_label,
        'Actual Label': actual_label,
    })

    # --- Plot 1: confusion matrix at the chosen threshold -----------------
    confusion_matrix_results = confusion_matrix(df['Actual Label'], df['Predicted Label'])

    fig, ax = plt.subplots(figsize=(12, 4))
    sn.heatmap(confusion_matrix_results, annot=True, annot_kws={"size": 20}, cbar=False,
               square=False,
               fmt='g',
               cmap=ListedColormap(['white']), linecolor='black',
               linewidths=1.5)

    sn.set(font_scale=2)
    plt.xlabel("Predicted Label")
    plt.ylabel("Actual Label")
    # Label each quadrant with its confusion-matrix role.
    plt.text(0.6, 0.55, '(TN)')
    plt.text(1.6, 0.55, '(FP)')
    plt.text(0.6, 1.55, '(FN)')
    plt.text(1.6, 1.55, '(TP)')

    # Put the predicted-label axis on top of the heatmap.
    ax.xaxis.tick_top()
    ax.xaxis.set_ticks_position('top')
    ax.xaxis.set_label_position('top')
    plt.tight_layout()

    plt.savefig('tmp.png', dpi=100)
    plt.close(fig)  # avoid leaking figures across repeated Gradio calls

    # --- Plot 2: ROC curve with the current operating point ---------------
    fpr_mod, tpr_mod, _ = roc_curve(df['Actual Label'], df['Predicted Prob'])

    TP, FP, TN, FN = perf_measure(df['Actual Label'], df['Predicted Label'])
    recall, precision, specificity, FPR, f1_score_cur, g_mean_cur = \
        _threshold_metrics(TP, FP, TN, FN)
    TPR = recall  # true positive rate == recall

    fig, ax = plt.subplots(figsize=(12, 8))
    plt.rcParams["figure.autolayout"] = True
    plt.rcParams['figure.facecolor'] = 'white'

    plt.plot(fpr_mod, tpr_mod, label='ROC', c='blue', linestyle='-')
    diag = np.linspace(0, 1, 500)
    plt.plot(diag, diag, 'black', linestyle='--')  # chance line y = x
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    # Cross-hairs + marker at the operating point implied by set_threshold.
    plt.axvline(x=FPR, color='gray', linestyle='--')
    plt.axhline(y=TPR, color='gray', linestyle='--')
    plt.scatter(FPR, TPR, color='red', s=300)

    ax.set_facecolor("white")
    ax.tick_params(axis='x', colors='black')
    ax.tick_params(axis='y', colors='black')
    for side in ('left', 'bottom', 'top', 'right'):
        ax.spines[side].set_color('black')
    plt.xlabel('False Positive Rate (1 - specificity)')
    plt.ylabel('True Positive Rate (Recall)')
    plt.text(FPR, TPR, 'FPR:%s, TPR:%s' % (round(FPR, 2), round(TPR, 2)))
    plt.title("ROC curve", fontsize=20)
    plt.tight_layout()

    plt.savefig('tmp2.png', dpi=100)
    plt.close(fig)

    # --- Plot 3: F1-score and G-mean across the threshold grid ------------
    thres_list = []
    f1_score_list = []
    g_mean_list = []
    for thres in np.arange(0, 1, 0.01):
        labels_at_thres = (prob >= thres).astype(int)
        tp, fp, tn, fn = perf_measure(actual_label, labels_at_thres)
        _, _, _, _, f1_at_thres, g_mean_at_thres = _threshold_metrics(tp, fp, tn, fn)
        thres_list.append(thres)
        f1_score_list.append(f1_at_thres)
        g_mean_list.append(g_mean_at_thres)

    fig, ax = plt.subplots(figsize=(12, 8))
    plt.rcParams["figure.autolayout"] = True
    plt.rcParams['figure.facecolor'] = 'white'

    plt.plot(thres_list, f1_score_list, label='F1-score', c='black', linestyle='-')
    plt.plot(thres_list, g_mean_list, label='G-mean', c='red', linestyle='-')
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    # Mark the currently selected threshold and its scores.
    plt.axvline(x=set_threshold, color='gray', linestyle='--')
    plt.axhline(y=f1_score_cur, color='gray', linestyle='--')
    plt.scatter(set_threshold, f1_score_cur, color='red', s=300)
    plt.scatter(set_threshold, g_mean_cur, color='red', s=300)

    ax.set_facecolor("white")
    ax.tick_params(axis='x', colors='black')
    ax.tick_params(axis='y', colors='black')
    for side in ('left', 'bottom', 'top', 'right'):
        ax.spines[side].set_color('black')
    plt.xlabel('Threshold cut-off')
    plt.ylabel('F1-score & G-mean')
    plt.legend(loc='upper left')
    plt.text(set_threshold, f1_score_cur, 'F1-score:%s' % (round(f1_score_cur, 2)))
    plt.text(set_threshold, g_mean_cur, 'G-mean:%s' % (round(g_mean_cur, 2)))
    plt.title("Threshold tuning curves (F1-score & G-mean)", fontsize=20)
    plt.tight_layout()

    plt.savefig('tmp3.png', dpi=100)
    plt.close(fig)

    return 'tmp.png', 'tmp2.png', 'tmp3.png'
276
+
277
def get_example(n=100):
    """Build a deterministic synthetic prediction table for the demo.

    One quarter of the rows are positives (probabilities in [0.4, 0.8)),
    the rest negatives (probabilities in [0, 0.7)), shuffled together.
    Seeded so repeated calls return the same table.

    Args:
        n: total number of rows (default 100, matching the original
            hard-coded size).

    Returns:
        (n, 3) numpy array with columns
        [Sample Index, Predicted Prob, Label (Y)].
    """
    np.random.seed(seed=42)

    n_pos = n // 4
    n_neg = n - n_pos
    pd_class1 = pd.DataFrame({
        'Sample Index': list(range(1, n_pos + 1)),
        'Predicted Prob': np.random.uniform(0.4, 0.8, n_pos),
        'Label (Y)': np.repeat(1, n_pos),
    })
    pd_class2 = pd.DataFrame({
        'Sample Index': list(range(n_pos + 1, n + 1)),
        'Predicted Prob': np.random.uniform(0, 0.7, n_neg),
        'Label (Y)': np.repeat(0, n_neg),
    })

    pd_all = pd.concat([pd_class1, pd_class2]).reset_index(drop=True)
    # Shuffle rows, then renumber the sample index 1..n.
    pd_all = pd_all.sample(frac=1).reset_index(drop=True)
    pd_all['Sample Index'] = list(range(1, n + 1))
    return pd_all.to_numpy()
292
+
293
+
294
+ ### configure Gradio
295
### configure Gradio
# Wire the threshold slider and the data table into visualize_ROC; the three
# file paths it returns feed the three image outputs.
interface = gr.Interface(fn=visualize_ROC,
                         inputs=[set_input2, set_input],
                         outputs=[set_output2, set_output3, set_output4],
                         examples_per_page = 2,
                         # Two canned examples differing only in threshold;
                         # get_example() is seeded, so both share one table.
                         examples=[
                             [0.5, get_example()],
                             [0.7, get_example()],
                         ],
                         title="ML Demo for Receiver Operating Characteristic (ROC) curve",
                         description= "Click examples below for a quick demo",
                         theme = 'huggingface',
                         #layout = 'horizontal',
                         live=True  # re-run visualize_ROC on every input change
                         )


# debug=True blocks and streams errors to the console; the large
# height/width size the embedded app view.
interface.launch(debug=True, height=1400, width=2800)