haoranrr commited on
Commit
cc49bd9
·
verified ·
1 Parent(s): 1bd1629
Files changed (1) hide show
  1. app.py +326 -0
app.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import autogluon
2
+ from tkinter import Tk,filedialog
3
+ import pandas as pd
4
+ from sklearn.model_selection import train_test_split
5
+ from autogluon.tabular import TabularDataset,TabularPredictor
6
+ from sklearn.metrics import roc_auc_score,f1_score,roc_curve,confusion_matrix
7
+ import gradio as gr
8
+ import matplotlib.pyplot as plt
9
+ from sklearn.calibration import calibration_curve
10
+ import seaborn as sns
11
+ # Tkinter部分,用于上传csv文件
12
+ # 可以使用gradio默认的部分
13
+ def upload_file():
14
+ root=Tk()
15
+ root.withdraw()
16
+ file_path = filedialog.askopenfilename(title="Select CSV File", filetypes=[("Training files", "*.xlsx *.csv")])
17
+ return file_path
18
+ # 模型训练内部评估
19
+ plt.rc('font',family='Times New Roman')
20
+ def train_and_evaluate(file):
21
+ # 读取csv件
22
+ df = pd.read_csv(file.name)
23
+ label='hospital_expire_flag'
24
+ # 分割数据集
25
+ train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
26
+ # 训练模型
27
+ predictor=TabularPredictor(label=label,problem_type='binary',eval_metric='f1',path='./autogluon/').fit(train_df)
28
+ # 载入最佳模型
29
+ best_model=predictor._trainer.load_model(predictor.get_model_names()[-1])
30
+ # 预测概率
31
+ y_prob = best_model.predict_proba(test_df.drop(label,axis=1))
32
+
33
+ # 计算AUC
34
+ auc = roc_auc_score(test_df[label], y_prob)
35
+ # 绘制ROC曲线
36
+ fpr, tpr, _ = roc_curve(test_df[label], y_prob)
37
+ plt.figure(figsize=(5, 4))
38
+ plt.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
39
+ plt.plot([0, 1], [0, 1], linestyle='--')
40
+ plt.xlabel('False Positive Rate')
41
+ plt.ylabel('True Positive Rate')
42
+ plt.title('ROC Curve')
43
+ sns.despine()
44
+ plt.legend(loc='best')
45
+ plt.savefig('./roc_curve.png',dpi=200,bbox_inches='tight')
46
+
47
+ # 绘制校准曲线
48
+ prob_true, prob_pred = calibration_curve(test_df[label], y_prob, n_bins=10)
49
+ plt.figure(figsize=(5, 4))
50
+ plt.plot(prob_true, prob_pred, marker='o', label='Autogluon')
51
+ plt.plot([0, 1], [0, 1], linestyle='--', label='Perfectly Calibrated')
52
+ plt.ylabel('Predicted Probability',fontdict=dict(family='Times New Roman',size=15))
53
+ plt.xlabel('True Probability',fontdict=dict(family='Times New Roman',size=15))
54
+ plt.title('Calibration Curve',fontdict=dict(family='Times New Roman',size=15))
55
+ sns.despine()
56
+ plt.legend()
57
+ plt.savefig('./Calibration_curve.png',dpi=200,bbox_inches='tight')
58
+
59
+ # 绘制决策曲线
60
+ y_pred=y_prob
61
+ y_test=test_df[label]
62
+ thresh_group=np.arange(0, 1, 0.01)
63
+ net_benefit_model = np.array([])
64
+ for thresh in thresh_group:
65
+ y_pred_label = y_pred > thresh
66
+ tn, fp, fn, tp = confusion_matrix(y_test, y_pred_label).ravel()
67
+ n = len(y_test)
68
+ net_benefit = (tp / n) - (fp / n) * (thresh / (1 - thresh))
69
+ net_benefit_model = np.append(net_benefit_model, net_benefit)
70
+
71
+ net_benefit_all = np.array([])
72
+ tn, fp, fn, tp = confusion_matrix(y_test, y_test).ravel()
73
+ total = tp + tn
74
+ for thresh in thresh_group:
75
+ net_benefit_ = (tp / total) - (tn / total) * (thresh / (1 - thresh))
76
+ net_benefit_all = np.append(net_benefit_all, net_benefit_)
77
+ plt.figure(figsize=(5, 4))
78
+ ax=plt.gca()
79
+ ax.plot(thresh_group, net_benefit_model)
80
+ ax.plot(thresh_group, net_benefit_all, linestyle='--', label='Treat all')
81
+ ax.plot((0, 1), (0, 0), color='black', linestyle='--', label='Treat none')
82
+ ax.fill_between(thresh_group, net_benefit_model, 0, alpha=0.2)
83
+ ax.set_xlim(0, 1)
84
+ ax.set_ylim(net_benefit_model.min() - 0.15, net_benefit_model.max() + 0.15)
85
+ ax.set_xlabel('Threshold Probability', fontdict={'family': 'Times New Roman', 'fontsize': 15})
86
+ ax.set_ylabel('Net Benefit', fontdict={'family': 'Times New Roman', 'fontsize': 15})
87
+ ax.grid(which='minor')
88
+ ax.spines['right'].set_color((0.8, 0.8, 0.8))
89
+ ax.spines['top'].set_color((0.8, 0.8, 0.8))
90
+ ax.legend(loc='upper right')
91
+ sns.despine()
92
+ plt.title('Decision Curve',fontdict=dict(family='Times New Roman',size=15))
93
+ plt.savefig('./Decision_curve.png',dpi=200,bbox_inches='tight')
94
+
95
+ FI=predictor.feature_importance(test_df)
96
+ norm = plt.Normalize(min(FI.importance[:6]), max(FI.importance[:6]))
97
+ colors = plt.cm.viridis(norm(FI.importance[:6].values))
98
+ # 绘制棒图
99
+ plt.figure(figsize=(5,4))
100
+ plt.bar(FI.index[:6], FI.importance[:6],color=colors)
101
+ ax=plt.gca()
102
+ # 添加标题和标签
103
+ plt.title('Feature Importance',fontdict=dict(family='Times New Roman',size=15),pad=0.2)
104
+ plt.xlabel('Features')
105
+ plt.ylabel('Permutation Shuffling Values')
106
+ sns.despine()
107
+ plt.xticks(rotation=45)
108
+ plt.savefig('./feature_importance.png',dpi=200,bbox_inches='tight')
109
+
110
+ return './roc_curve.png','./Calibration_curve.png','./Decision_curve.png' ,'./feature_importance.png'
111
+ # 外部验证
112
+ def external_evaluate(file):
113
+ # 读取csv件
114
+ df = pd.read_csv(file.name)
115
+ label='hospital_expire_flag'
116
+ # 训练模型
117
+ predictor=TabularPredictor.load('./autogluon/')
118
+ # 载入最佳模型
119
+ best_model=predictor._trainer.load_model(predictor.get_model_names()[-1])
120
+ # 预测概率
121
+ y_prob = best_model.predict_proba(df.drop(label,axis=1))
122
+ # 计算AUC
123
+ auc = roc_auc_score(df[label], y_prob)
124
+ # 绘制ROC曲线
125
+ fpr, tpr, _ = roc_curve(df[label], y_prob)
126
+ plt.figure(figsize=(5, 4))
127
+ plt.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
128
+ plt.plot([0, 1], [0, 1], linestyle='--')
129
+ plt.xlabel('False Positive Rate')
130
+ plt.ylabel('True Positive Rate')
131
+ plt.title('ROC Curve')
132
+ sns.despine()
133
+ plt.legend(loc='best')
134
+ plt.savefig('./roc_curve_external.png',dpi=200,bbox_inches='tight')
135
+
136
+ # 绘制校准曲线
137
+ prob_true, prob_pred = calibration_curve(df[label], y_prob, n_bins=10)
138
+ plt.figure(figsize=(5, 4))
139
+ plt.plot(prob_true, prob_pred, marker='o', label='Autogluon')
140
+ plt.plot([0, 1], [0, 1], linestyle='--', label='Perfectly Calibrated')
141
+ plt.ylabel('Predicted Probability',fontdict=dict(family='Times New Roman',size=15))
142
+ plt.xlabel('True Probability',fontdict=dict(family='Times New Roman',size=15))
143
+ plt.title('Calibration Curve',fontdict=dict(family='Times New Roman',size=15))
144
+ sns.despine()
145
+ plt.legend()
146
+ plt.savefig('./Calibration_curve_external.png',dpi=200,bbox_inches='tight')
147
+
148
+ # 绘制决策曲线
149
+ y_pred=y_prob
150
+ y_test=df[label]
151
+ thresh_group=np.arange(0, 1, 0.01)
152
+ net_benefit_model = np.array([])
153
+ for thresh in thresh_group:
154
+ y_pred_label = y_pred > thresh
155
+ tn, fp, fn, tp = confusion_matrix(y_test, y_pred_label).ravel()
156
+ n = len(y_test)
157
+ net_benefit = (tp / n) - (fp / n) * (thresh / (1 - thresh))
158
+ net_benefit_model = np.append(net_benefit_model, net_benefit)
159
+
160
+ net_benefit_all = np.array([])
161
+ tn, fp, fn, tp = confusion_matrix(y_test, y_test).ravel()
162
+ total = tp + tn
163
+ for thresh in thresh_group:
164
+ net_benefit_ = (tp / total) - (tn / total) * (thresh / (1 - thresh))
165
+ net_benefit_all = np.append(net_benefit_all, net_benefit_)
166
+ plt.figure(figsize=(5, 4))
167
+ ax=plt.gca()
168
+ ax.plot(thresh_group, net_benefit_model)
169
+ ax.plot(thresh_group, net_benefit_all, linestyle='--', label='Treat all')
170
+ ax.plot((0, 1), (0, 0), color='black', linestyle='--', label='Treat none')
171
+ ax.fill_between(thresh_group, net_benefit_model, 0, alpha=0.2)
172
+ ax.set_xlim(0, 1)
173
+ ax.set_ylim(net_benefit_model.min() - 0.15, net_benefit_model.max() + 0.15)
174
+ ax.set_xlabel('Threshold Probability', fontdict={'family': 'Times New Roman', 'fontsize': 15})
175
+ ax.set_ylabel('Net Benefit', fontdict={'family': 'Times New Roman', 'fontsize': 15})
176
+ ax.grid(which='minor')
177
+ ax.spines['right'].set_color((0.8, 0.8, 0.8))
178
+ ax.spines['top'].set_color((0.8, 0.8, 0.8))
179
+ ax.legend(loc='upper right')
180
+ sns.despine()
181
+ plt.title('Decision Curve',fontdict=dict(family='Times New Roman',size=15),pad=0.01)
182
+ plt.savefig('./Decision_curve_external.png',dpi=200,bbox_inches='tight')
183
+ # 绘制棒图
184
+ FI=predictor.feature_importance(df)
185
+ norm = plt.Normalize(min(FI.importance[:6]), max(FI.importance[:6]))
186
+ colors = plt.cm.viridis(norm(FI.importance[:6].values))
187
+
188
+ plt.figure(figsize=(5,4))
189
+ plt.bar(FI.index[:6], FI.importance[:6],color=colors)
190
+ ax=plt.gca()
191
+ # 添加标题和标签
192
+ plt.title('Feature Importance',fontdict=dict(family='Times New Roman',size=15),pad=0.2)
193
+ plt.xlabel('Features')
194
+ plt.ylabel('Permutation Shuffling Values')
195
+ sns.despine()
196
+ plt.xticks(rotation=45)
197
+ plt.savefig('./feature_importance_external.png',dpi=200,bbox_inches='tight')
198
+
199
+ return './roc_curve_external.png','./Calibration_curve_external.png','./Decision_curve_external.png' ,'./feature_importance_external.png'
200
+ def preview_excel(file):
201
+ df = pd.read_csv(file.name)
202
+ return df.head(3)
203
+ import gradio as gr
204
+ import base64
205
+
206
+ # CSS styles for the interface
207
+ css = """
208
+ body {
209
+ background-color: #f8f9fa;
210
+ font-family: 'Arial', sans-serif;
211
+ }
212
+ #file_input, #external_file_input, #dataframe {
213
+ border: 2px dashed #007bff;
214
+ padding: 20px;
215
+ border-radius: 10px;
216
+ background-color: #fff;
217
+ }
218
+ #train_button, #evaluate_button, #dataframe_button {
219
+ background-color: #007bff;
220
+ color: gray; /* Changed to white for better contrast */
221
+ font-size: 18px;
222
+ border-radius: 5px;
223
+ margin-top: 10px;
224
+ transition: background-color 0.3s;
225
+ }
226
+ #train_button:hover, #evaluate_button:hover, #dataframe_button:hover {
227
+ background-color: #0056b3;
228
+ }
229
+ #roc_image, #calibration_image, #decision_image, #external_eval_image1, #external_eval_image2, #external_eval_image3 {
230
+ border: 1px solid #ddd;
231
+ border-radius: 10px;
232
+ padding: 10px;
233
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
234
+ }
235
+ h1 {
236
+ color: blue;
237
+ text-align: center;
238
+ font-size: 28px;
239
+ }
240
+ h2 {
241
+ color: #007bff;
242
+ text-align: center;
243
+ }
244
+ p {
245
+ color: #555;
246
+ text-align: center;
247
+ }
248
+ .spinner {
249
+ display: none;
250
+ text-align: center;
251
+ margin-top: 20px;
252
+ }
253
+ """
254
+
255
+ # Load and encode the background image
256
+ with open("D:/Haoran/科研/毕设/分析/模型部署/automl6.png", "rb") as image_file:
257
+ encoded_string = base64.b64encode(image_file.read()).decode()
258
+
259
+ # Create the HTML layout with a background image
260
+ background_image = f"""
261
+ <div style="position: relative; height: 30vh;">
262
+ <div style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; background-image: url('data:image/jpeg;base64,{encoded_string}'); background-size: contain; background-repeat: no-repeat; background-position: center; opacity: 0.7;">
263
+ </div>
264
+ <div style="position: absolute; top: 85%; left: 50%; transform: translate(-50%, -50%); text-align: center;">
265
+ <h1 style="color: blue; font-weight: bold; font-size: 45px; white-space: nowrap;">Clinical Prediction Model Training and Evaluation based on AutoML</h1>
266
+ <p>Upload your CSV file with a 'hospital_expire_flag' column for binary classification. The tool will train a model, evaluate it, and display ROC, Calibration, Decision curves and Feature Importance plot.</p>
267
+ </div>
268
+ </div>
269
+ """
270
+
271
+ # Create Gradio Blocks interface
272
+ with gr.Blocks(css=css) as interface:
273
+ gr.HTML(background_image)
274
+
275
+ with gr.Row():
276
+ file_input = gr.File(label='Upload Model Training CSV File', elem_id="file_input")
277
+
278
+ pre_button = gr.Button('Preview of the First 3 Rows', elem_id='dataframe_button')
279
+
280
+ with gr.Row():
281
+ dataframe = gr.DataFrame(elem_id='dataframe')
282
+
283
+ pre_button.click(fn=preview_excel, inputs=file_input, outputs=dataframe)
284
+
285
+ train_button = gr.Button("Train and Internal Evaluate", elem_id="train_button")
286
+
287
+ with gr.Row():
288
+ img1 = gr.Image(label="ROC Curve", type='filepath', elem_id="roc_image")
289
+ img2 = gr.Image(label="Calibration Curve", type='filepath', elem_id="calibration_image")
290
+ img3 = gr.Image(label="Decision Curve", type='filepath', elem_id="decision_image")
291
+ img4 = gr.Image(label="Feature Importance", type='filepath', elem_id="feature_importance_image")
292
+
293
+ spinner = gr.Markdown("<div class='spinner'>Training model... Please wait...</div>")
294
+
295
+
296
+ def handle_click(file):
297
+ spinner.update(value="正在训练模型,请稍候...", visible=True)
298
+ try:
299
+ results = train_and_evaluate(file)
300
+ return results
301
+ except Exception as e:
302
+ return f"训练失败: {str(e)}"
303
+ finally:
304
+ spinner.update(visible=False)
305
+ train_button.click(fn=handle_click, inputs=file_input, outputs=[img1, img2, img3, img4])
306
+ # External evaluation section
307
+ gr.Markdown("<h2 style='text-align: center;'>External Evaluation</h2>")
308
+ external_file_input = gr.File(label='Upload External Evaluation CSV File', elem_id="external_file_input")
309
+ evaluate_button = gr.Button("External Evaluate", elem_id="evaluate_button")
310
+ with gr.Row():
311
+ external_eval_image1 = gr.Image(label="ROC Curve", type='filepath', elem_id="external_eval_image1")
312
+ external_eval_image2 = gr.Image(label="Calibration Curve", type='filepath', elem_id="external_eval_image2")
313
+ external_eval_image3 = gr.Image(label="Decision Curve", type='filepath', elem_id="external_eval_image3")
314
+ external_eval_image4 = gr.Image(label="Feature Importance", type='filepath', elem_id="external_eval_image4")
315
+ def evaluate_click(file):
316
+ spinner.update(value="正在进行外部评估,请稍候...", visible=True)
317
+ try:
318
+ results = external_evaluate(file)
319
+ return results
320
+ except Exception as e:
321
+ return f"外部评估失败: {str(e)}"
322
+ finally:
323
+ spinner.update(visible=False)
324
+ evaluate_button.click(fn=evaluate_click, inputs=external_file_input, outputs=[external_eval_image1, external_eval_image2, external_eval_image3, external_eval_image4])
325
+ # Launch the interface
326
+ interface.launch()