|
|
from tkinter import Tk, filedialog

import autogluon
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from autogluon.tabular import TabularDataset, TabularPredictor
from sklearn.calibration import calibration_curve
from sklearn.metrics import confusion_matrix, f1_score, roc_auc_score, roc_curve
from sklearn.model_selection import train_test_split
|
|
|
|
|
|
|
|
def upload_file():
    """Ask the user to pick a training file through a native file dialog.

    A hidden Tk root window is created only to host the dialog and is
    destroyed again before returning, so repeated calls do not accumulate
    Tk interpreters.

    Returns:
        Whatever ``filedialog.askopenfilename`` returns: the selected path,
        or an empty value when the dialog is dismissed.
    """
    root = Tk()
    # Hide the empty root window; only the dialog should be visible.
    root.withdraw()
    try:
        return filedialog.askopenfilename(
            title="Select CSV File",
            filetypes=[("Training files", "*.xlsx *.csv")],
        )
    finally:
        root.destroy()
|
|
|
|
|
# Draw every figure produced by this module in Times New Roman by default.
plt.rcParams['font.family'] = 'Times New Roman'
|
|
# Name of the binary outcome column expected in every uploaded CSV.
_LABEL_COLUMN = 'hospital_expire_flag'
# Directory where the fitted predictor is saved by training and read back by
# external evaluation.
_MODEL_DIR = './autogluon/'
# Shared figure geometry and label font for all diagnostic plots.
_FIGSIZE = (5, 4)
_LABEL_FONT = {'family': 'Times New Roman', 'size': 15}


def _load_labelled_csv(file, label):
    """Read the uploaded CSV and make sure it contains the ``label`` column.

    ``file`` may be a path string or an object exposing the path as ``.name``.
    """
    if file is None:
        raise ValueError('No file was provided.')
    frame = pd.read_csv(getattr(file, 'name', file))
    if label not in frame.columns:
        raise ValueError(f"Column '{label}' not found in the uploaded file.")
    return frame


def _positive_class_probabilities(predictor, data, label):
    """Return the predicted positive-class probability per row as a 1-D array."""
    features = data.drop(columns=[label])
    # Go through the predictor's public API so its own feature preprocessing
    # and model selection are applied, instead of calling a trainer-internal
    # model on raw columns.
    return np.asarray(predictor.predict_proba(features, as_multiclass=False))


def _plot_roc(y_true, y_prob, path):
    """Save the ROC curve, annotated with its AUC, to ``path``."""
    auc = roc_auc_score(y_true, y_prob)
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    fig, ax = plt.subplots(figsize=_FIGSIZE)
    try:
        ax.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
        ax.plot([0, 1], [0, 1], linestyle='--')
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title('ROC Curve')
        sns.despine(ax=ax)
        ax.legend(loc='best')
        fig.savefig(path, dpi=200, bbox_inches='tight')
    finally:
        # Always release the figure; this runs inside a long-lived server.
        plt.close(fig)


def _plot_calibration(y_true, y_prob, path):
    """Save the 10-bin calibration curve to ``path``."""
    prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=10)
    fig, ax = plt.subplots(figsize=_FIGSIZE)
    try:
        # Observed fraction on x, mean predicted probability on y, matching
        # the axis labels below.
        ax.plot(prob_true, prob_pred, marker='o', label='Autogluon')
        ax.plot([0, 1], [0, 1], linestyle='--', label='Perfectly Calibrated')
        ax.set_ylabel('Predicted Probability', fontdict=_LABEL_FONT)
        ax.set_xlabel('True Probability', fontdict=_LABEL_FONT)
        ax.set_title('Calibration Curve', fontdict=_LABEL_FONT)
        sns.despine(ax=ax)
        ax.legend()
        fig.savefig(path, dpi=200, bbox_inches='tight')
    finally:
        plt.close(fig)


def _net_benefit_curves(y_true, y_prob, thresholds):
    """Compute decision-curve net benefit for the model and for "treat all".

    At threshold ``t`` the net benefit is ``TP/n - FP/n * t / (1 - t)``, where
    a row is flagged when its probability is strictly above ``t``.  "Treat
    all" flags every row, so its TP and FP counts are the class totals.
    Rows whose label equals 1 are counted as positive.

    Returns:
        ``(model_net_benefit, treat_all_net_benefit)``, one value per threshold.
    """
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    is_positive = y_true == 1
    n = y_true.size
    odds = thresholds / (1 - thresholds)

    true_pos = np.empty(thresholds.size)
    false_pos = np.empty(thresholds.size)
    for i, thresh in enumerate(thresholds):
        flagged = y_prob > thresh
        true_pos[i] = np.count_nonzero(flagged & is_positive)
        false_pos[i] = np.count_nonzero(flagged & ~is_positive)
    model = true_pos / n - false_pos / n * odds

    n_positive = np.count_nonzero(is_positive)
    treat_all = n_positive / n - (n - n_positive) / n * odds
    return model, treat_all


def _plot_decision_curve(y_true, y_prob, path, title_pad=None):
    """Save the decision curve (net benefit vs. threshold) to ``path``."""
    # Stops at 0.99, so the odds term t / (1 - t) never divides by zero.
    thresholds = np.arange(0, 1, 0.01)
    model, treat_all = _net_benefit_curves(y_true, y_prob, thresholds)
    fig, ax = plt.subplots(figsize=_FIGSIZE)
    try:
        ax.plot(thresholds, model)
        ax.plot(thresholds, treat_all, linestyle='--', label='Treat all')
        ax.plot((0, 1), (0, 0), color='black', linestyle='--', label='Treat none')
        ax.fill_between(thresholds, model, 0, alpha=0.2)
        ax.set_xlim(0, 1)
        # Scale the y-axis to the model curve; "treat all" drops steeply at
        # high thresholds and would otherwise dominate the plot.
        ax.set_ylim(model.min() - 0.15, model.max() + 0.15)
        ax.set_xlabel('Threshold Probability', fontdict=_LABEL_FONT)
        ax.set_ylabel('Net Benefit', fontdict=_LABEL_FONT)
        ax.grid(which='minor')
        ax.spines['right'].set_color((0.8, 0.8, 0.8))
        ax.spines['top'].set_color((0.8, 0.8, 0.8))
        ax.legend(loc='upper right')
        sns.despine(ax=ax)
        ax.set_title('Decision Curve', fontdict=_LABEL_FONT, pad=title_pad)
        fig.savefig(path, dpi=200, bbox_inches='tight')
    finally:
        plt.close(fig)


def _plot_feature_importance(predictor, data, path, top_n=6):
    """Save a bar chart of the first ``top_n`` feature-importance rows to ``path``."""
    top = predictor.feature_importance(data)['importance'].iloc[:top_n]
    # Colour each bar by its importance relative to the plotted range.
    norm = plt.Normalize(top.min(), top.max())
    colors = plt.cm.viridis(norm(top.values))
    fig, ax = plt.subplots(figsize=_FIGSIZE)
    try:
        ax.bar(top.index, top, color=colors)
        ax.set_title('Feature Importance', fontdict=_LABEL_FONT, pad=0.2)
        ax.set_xlabel('Features')
        ax.set_ylabel('Permutation Shuffling Values')
        sns.despine(ax=ax)
        ax.tick_params(axis='x', labelrotation=45)
        fig.savefig(path, dpi=200, bbox_inches='tight')
    finally:
        plt.close(fig)


def _evaluate_and_plot(predictor, data, label, suffix='', decision_title_pad=None):
    """Score ``data`` with ``predictor`` and write the four diagnostic plots.

    Returns:
        Tuple of the four image paths, in the order ROC, calibration,
        decision curve, feature importance.
    """
    y_true = data[label]
    y_prob = _positive_class_probabilities(predictor, data, label)

    # NOTE(review): file names are fixed, so concurrent requests overwrite
    # each other's images.
    roc_path = f'./roc_curve{suffix}.png'
    calibration_path = f'./Calibration_curve{suffix}.png'
    decision_path = f'./Decision_curve{suffix}.png'
    importance_path = f'./feature_importance{suffix}.png'

    _plot_roc(y_true, y_prob, roc_path)
    _plot_calibration(y_true, y_prob, calibration_path)
    _plot_decision_curve(y_true, y_prob, decision_path, title_pad=decision_title_pad)
    _plot_feature_importance(predictor, data, importance_path)
    return roc_path, calibration_path, decision_path, importance_path


def train_and_evaluate(file):
    """Train a binary AutoGluon model on the uploaded CSV and evaluate it.

    The data is split 80/20 with a fixed seed; the model is fitted on the
    larger part, saved under ``./autogluon/``, and evaluated on the held-out
    part.

    Args:
        file: Path string, or an object exposing the path as ``.name``.
            The CSV must contain a ``hospital_expire_flag`` column.

    Returns:
        Tuple of four image paths: ROC curve, calibration curve, decision
        curve and feature-importance chart.

    Raises:
        ValueError: If no file is given or the label column is missing.
    """
    df = _load_labelled_csv(file, _LABEL_COLUMN)
    train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
    predictor = TabularPredictor(
        label=_LABEL_COLUMN,
        problem_type='binary',
        eval_metric='f1',
        path=_MODEL_DIR,
    ).fit(train_df)
    return _evaluate_and_plot(predictor, test_df, _LABEL_COLUMN)


def external_evaluate(file):
    """Evaluate the previously trained model on an external CSV.

    Loads the predictor saved under ``./autogluon/`` and scores every row of
    the uploaded file.

    Args:
        file: Path string, or an object exposing the path as ``.name``.
            The CSV must contain a ``hospital_expire_flag`` column.

    Returns:
        Tuple of four ``*_external.png`` image paths: ROC curve, calibration
        curve, decision curve and feature-importance chart.

    Raises:
        ValueError: If no file is given or the label column is missing.
    """
    df = _load_labelled_csv(file, _LABEL_COLUMN)
    predictor = TabularPredictor.load(_MODEL_DIR)
    return _evaluate_and_plot(
        predictor,
        df,
        _LABEL_COLUMN,
        suffix='_external',
        decision_title_pad=0.01,
    )
|
|
def preview_excel(file):
    """Return the first three rows of the uploaded CSV for on-screen preview.

    Despite the name, the file is parsed as CSV.

    Args:
        file: Path string, or an object exposing the path as ``.name``.
            ``None`` (nothing uploaded yet) yields an empty frame instead of
            an error.

    Returns:
        pandas.DataFrame with at most three rows.
    """
    if file is None:
        return pd.DataFrame()
    path = getattr(file, 'name', file)
    # Only three rows are shown, so do not parse the rest of the file.
    return pd.read_csv(path, nrows=3)
|
|
import gradio as gr |
|
|
import base64 |
|
|
|
|
|
|
|
|
# Stylesheet passed to gr.Blocks. Selectors target the elem_id values assigned
# to the components below; the image-card rule lists all eight result images.
css = """
body {
    background-color: #f8f9fa;
    font-family: 'Arial', sans-serif;
}
#file_input, #external_file_input, #dataframe {
    border: 2px dashed #007bff;
    padding: 20px;
    border-radius: 10px;
    background-color: #fff;
}
#train_button, #evaluate_button, #dataframe_button {
    background-color: #007bff;
    color: white; /* Changed to white for better contrast */
    font-size: 18px;
    border-radius: 5px;
    margin-top: 10px;
    transition: background-color 0.3s;
}
#train_button:hover, #evaluate_button:hover, #dataframe_button:hover {
    background-color: #0056b3;
}
#roc_image, #calibration_image, #decision_image, #feature_importance_image, #external_eval_image1, #external_eval_image2, #external_eval_image3, #external_eval_image4 {
    border: 1px solid #ddd;
    border-radius: 10px;
    padding: 10px;
    box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
h1 {
    color: blue;
    text-align: center;
    font-size: 28px;
}
h2 {
    color: #007bff;
    text-align: center;
}
p {
    color: #555;
    text-align: center;
}
.spinner {
    display: none;
    text-align: center;
    margin-top: 20px;
}
"""
|
|
|
|
|
|
|
|
# Banner image embedded (base64) into the page header. It is decorative only,
# so a missing or unreadable file must not stop the app from starting.
HEADER_IMAGE_PATH = "D:/Haoran/科研/毕设/分析/模型部署/automl6.png"
try:
    with open(HEADER_IMAGE_PATH, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode()
except OSError as exc:
    print(f"Header image not loaded ({exc}); continuing without it.")
    encoded_string = ""
|
|
|
|
|
|
|
|
# Page header: the banner image as a semi-transparent background layer with
# the title and usage hint positioned on top of it. The image is a .png file,
# so the data URI declares image/png.
background_image = f"""
<div style="position: relative; height: 30vh;">
    <div style="position: absolute; top: 0; left: 0; width: 100%; height: 100%; background-image: url('data:image/png;base64,{encoded_string}'); background-size: contain; background-repeat: no-repeat; background-position: center; opacity: 0.7;">
    </div>
    <div style="position: absolute; top: 85%; left: 50%; transform: translate(-50%, -50%); text-align: center;">
        <h1 style="color: blue; font-weight: bold; font-size: 45px; white-space: nowrap;">Clinical Prediction Model Training and Evaluation based on AutoML</h1>
        <p>Upload your CSV file with a 'hospital_expire_flag' column for binary classification. The tool will train a model, evaluate it, and display ROC, Calibration, Decision curves and Feature Importance plot.</p>
    </div>
</div>
"""
|
|
|
|
|
|
|
|
# Page layout: header, training section (upload, preview, train, four result
# images) and external-evaluation section (upload, evaluate, four result images).
with gr.Blocks(css=css) as interface:
    gr.HTML(background_image)

    with gr.Row():
        file_input = gr.File(label='Upload Model Training CSV File', elem_id="file_input")

    # NOTE(review): the original indentation was lost; the preview button is
    # assumed to sit below the upload row rather than inside it - confirm.
    pre_button = gr.Button('Preview of the First 3 Rows', elem_id='dataframe_button')

    with gr.Row():
        dataframe = gr.DataFrame(elem_id='dataframe')

    pre_button.click(fn=preview_excel, inputs=file_input, outputs=dataframe)

    train_button = gr.Button("Train and Internal Evaluate", elem_id="train_button")

    with gr.Row():
        img1 = gr.Image(label="ROC Curve", type='filepath', elem_id="roc_image")
        img2 = gr.Image(label="Calibration Curve", type='filepath', elem_id="calibration_image")
        img3 = gr.Image(label="Decision Curve", type='filepath', elem_id="decision_image")
        img4 = gr.Image(label="Feature Importance", type='filepath', elem_id="feature_importance_image")

    # Placeholder status element; the `.spinner` rule in the stylesheet keeps
    # it hidden.
    spinner = gr.Markdown("<div class='spinner'>Training model... Please wait...</div>")

    def handle_click(file):
        """Train on the uploaded file and return the four internal-evaluation images.

        Failures are raised as ``gr.Error`` so the message reaches the user;
        returning a single string would not match the four image outputs.
        The earlier ``spinner.update(...)`` calls were dropped because their
        return value was discarded, so they never changed the page.
        """
        try:
            return train_and_evaluate(file)
        except Exception as e:
            raise gr.Error(f"训练失败: {e}") from e

    train_button.click(fn=handle_click, inputs=file_input, outputs=[img1, img2, img3, img4])

    gr.Markdown("<h2 style='text-align: center;'>External Evaluation</h2>")
    external_file_input = gr.File(label='Upload External Evaluation CSV File', elem_id="external_file_input")
    evaluate_button = gr.Button("External Evaluate", elem_id="evaluate_button")
    with gr.Row():
        external_eval_image1 = gr.Image(label="ROC Curve", type='filepath', elem_id="external_eval_image1")
        external_eval_image2 = gr.Image(label="Calibration Curve", type='filepath', elem_id="external_eval_image2")
        external_eval_image3 = gr.Image(label="Decision Curve", type='filepath', elem_id="external_eval_image3")
        external_eval_image4 = gr.Image(label="Feature Importance", type='filepath', elem_id="external_eval_image4")

    def evaluate_click(file):
        """Evaluate the saved model on the external file and return its four images.

        Failures are raised as ``gr.Error`` for the same reason as in the
        training handler.
        """
        try:
            return external_evaluate(file)
        except Exception as e:
            raise gr.Error(f"外部评估失败: {e}") from e

    evaluate_button.click(
        fn=evaluate_click,
        inputs=external_file_input,
        outputs=[external_eval_image1, external_eval_image2, external_eval_image3, external_eval_image4],
    )

interface.launch()