jiehou commited on
Commit
33f8b18
·
1 Parent(s): 4892b16

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+
4
+ import gradio as gr
5
+
6
+
7
+ set_input = gr.inputs.Dataframe(type="numpy", row_count=10, col_count=3, headers=['Sample Index', 'Predicted Prob', 'Label (Y)'], datatype=["number", "number", "number"])
8
+ set_input2 = gr.inputs.Slider(0, 1, step = 0.1, default=0.4)
9
+
10
+ #set_output = gr.inputs.Textbox(label ='test')
11
+ set_output1 = gr.outputs.Dataframe(type="pandas", label = 'Predicted Labels',max_rows=10)
12
+
13
+ set_output2 = gr.outputs.Image(label="Confusion Matrix")
14
+ set_output3 = gr.outputs.Image(label="ROC curve")
15
+
16
+ def visualize_ROC(set_threshold,set_input):
17
+ prob = set_input[:,1]
18
+ pred_label = (prob >= set_threshold).astype(int)
19
+ actual_label = set_input[:,2]
20
+ import pandas as pd
21
+
22
+ data = {
23
+ 'Predicted Prob': prob,
24
+ 'Predicted Label': pred_label,
25
+ 'Actual Label': actual_label
26
+ }
27
+
28
+ import pandas as pd
29
+ import seaborn as sn
30
+ import matplotlib.pyplot as plt
31
+
32
+
33
+
34
+ df = pd.DataFrame(data)
35
+ confusion_matrix_results = confusion_matrix(df['Actual Label'], df['Predicted Label'])
36
+
37
+ fig, ax = plt.subplots(figsize=(12,4))
38
+ sn.heatmap(confusion_matrix_results, annot=True,annot_kws={"size": 20},cbar=False,
39
+ square=False,
40
+ fmt='g',
41
+ cmap=ListedColormap(['white']), linecolor='black',
42
+ linewidths=1.5)
43
+
44
+ sn.set(font_scale=2)
45
+ plt.xlabel("Predicted Label")
46
+ plt.ylabel("Actual Label")
47
+ plt.text(0.6,0.55,'(TN)')
48
+ plt.text(1.6,0.55,'(FP)')
49
+ plt.text(0.6,1.55,'(FN)')
50
+ plt.text(1.6,1.55,'(TP)')
51
+
52
+ ax.xaxis.tick_top()
53
+
54
+ ax.xaxis.set_ticks_position('top')
55
+ ax.xaxis.set_label_position('top')
56
+ plt.tight_layout()
57
+
58
+ plt.savefig('tmp.png', dpi=100)
59
+
60
+ ## get ROC curve
61
+ from sklearn.metrics import roc_curve
62
+ fpr_mod, tpr_mod, thrsholds_mod = roc_curve(df['Actual Label'], df['Predicted Prob'])
63
+
64
+ TP, FP, TN, FN = perf_measure(df['Actual Label'], df['Predicted Label'])
65
+
66
+ # Sensitivity, hit rate, recall, or true positive rate
67
+ try:
68
+ TPR = TP/(TP+FN)
69
+ except:
70
+ TPR = 0
71
+
72
+ # Fall out or false positive rate
73
+ try:
74
+ FPR = FP/(FP+TN)
75
+ except:
76
+ FPR = 0
77
+
78
+
79
+
80
+ fig, ax = plt.subplots(figsize=(12,8))
81
+
82
+ import matplotlib.pyplot as plt
83
+ import numpy as np
84
+ plt.rcParams["figure.autolayout"] = True
85
+ plt.rcParams['figure.facecolor'] = 'white'
86
+ m1, c1 = 1, 0
87
+ x = np.linspace(0, 1, 500)
88
+
89
+ plt.plot(fpr_mod, tpr_mod, label = 'ROC', c='black', linestyle='-')
90
+
91
+ plt.plot(x, x * m1 + c1, 'black', linestyle='--')
92
+ plt.xlim(0, 1)
93
+ plt.ylim(0, 1)
94
+ #xi = (c1 - c2) / (m2 - m1)
95
+ #yi = m1 * xi + c1
96
+ plt.axvline(x=FPR, color='gray', linestyle='--')
97
+ plt.axhline(y=TPR, color='gray', linestyle='--')
98
+ plt.scatter(FPR, TPR, color='red', s=300)
99
+
100
+ ax.set_facecolor("white")
101
+
102
+ ax.tick_params(axis='x', colors='black')
103
+ ax.tick_params(axis='y', colors='black')
104
+ ax.spines['left'].set_color('black')
105
+ ax.spines['bottom'].set_color('black')
106
+ ax.spines['top'].set_color('black')
107
+ ax.spines['right'].set_color('black')
108
+ plt.xlabel('False Positive Rate (1 - specificity)')
109
+ plt.ylabel('True Positive Rate (Recall)')
110
+ plt.text(TPR, TPR, 'TPR:%s, FPR:%s' % (FPR,TPR))
111
+ plt.title("ROC curve", fontsize=20)
112
+ plt.tight_layout()
113
+
114
+ plt.savefig('tmp2.png', dpi=100)
115
+
116
+ #return df,'tmp.png','tmp2.png'
117
+ return 'tmp.png','tmp2.png'
118
+
119
+ def get_example():
120
+
121
+ import numpy as np
122
+ import pandas as pd
123
+ np.random.seed(seed = 3)
124
+
125
+ N=100
126
+ pd_class1 = pd.DataFrame({'Sample Index': [i for i in range(1,int(N/2)+1)],'Predicted Prob': np.random.uniform(0.3,1,int(N/2)), 'Label (Y)': np.repeat(1,int(N/2))})
127
+ pd_class2 = pd.DataFrame({'Sample Index': [i for i in range(int(N/2)+1,N+1)],'Predicted Prob': np.random.uniform(0,0.7,int(N/2)), 'Label (Y)': np.repeat(0,int(N/2))})
128
+
129
+
130
+ pd_all = pd.concat([pd_class1, pd_class2]).reset_index(drop=True)
131
+ pd_all = pd_all.sample(frac=1)
132
+ return pd_all.to_numpy()
133
+
134
+
135
+ ### configure Gradio
136
+ interface = gr.Interface(fn=visualize_ROC,
137
+ inputs=[set_input2, set_input],
138
+ outputs=[set_output2,set_output3],
139
+ examples_per_page = 2,
140
+ examples=[
141
+ [0.5,get_example()],
142
+ [0.7,get_example()],
143
+ ],
144
+ title="CSCI4750/5750: ROC curve",
145
+ description= "Click examples below for a quick demo",
146
+ theme = 'huggingface',
147
+ layout = 'horizontal',
148
+ live=True
149
+ )
150
+
151
+
152
+ interface.launch(debug=True, height=1400, width=1400)