# File size: 6,871 Bytes (web-view header, commit hashes and line-number gutter removed)
import argparse
import os
import matplotlib.pyplot as plt
import sys
import gradio as gr
import torch
import pandas as pd
import numpy as np
import io
import base64
from lime.lime_tabular import LimeTabularExplainer
from pycaret.classification import *
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="torch.storage")
from util import load_data_and_prepare
import view
def _str2bool(value):
    """Convert a command-line token such as "true", "0" or "no" to a bool.

    argparse's ``type=bool`` simply calls ``bool()`` on the raw string, so
    every non-empty value -- including "False" -- would parse as True.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays falsy, as it was under ``bool("")``.
    if text in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(args):
    """Parse the command-line options of the CBD classification app.

    Args:
        args: List of argument strings, typically ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``data_dir``, ``excel_file``, ``mode``,
        ``scale``, ``smote`` and ``model_name_or_path``.
    """
    parser = argparse.ArgumentParser(description="CBD Classification")
    parser.add_argument('--data_dir', type=str, default="./data/")
    parser.add_argument('--excel_file', type=str, default="DUMC_final.csv")
    parser.add_argument('--mode', type=str, default="train")
    # Parse booleans explicitly so that "--scale False" really disables scaling.
    parser.add_argument('--scale', type=_str2bool, default=True)
    parser.add_argument('--smote', type=_str2bool, default=True)
    # No ``choices`` restriction: an empty choices list rejects every value
    # given on the command line, leaving only the default usable.
    parser.add_argument('--model_name_or_path', type=str, default="./data/model")
    return parser.parse_args(args)
# Inference function
def classify(tabular_data):
    """Classify one row of tabular input and return a human-readable message.

    Reads the module-level ``model`` and ``tabular_header`` that the script
    entry point sets up.

    Args:
        tabular_data: 2D list (list of rows); only the first row is used.

    Returns:
        A sentence describing the estimated probability band (low,
        intermediate or high), or an error message string if anything fails.
        Exceptions are never propagated to the caller.
    """
    try:
        # Ensure tabular_data is a non-empty 2D list before indexing into it,
        # so an empty table reports the format error instead of an IndexError.
        if (
            not isinstance(tabular_data, list)
            or not tabular_data
            or not isinstance(tabular_data[0], list)
        ):
            raise ValueError("Input data is not in the expected 2D list format.")
        first_row = tabular_data[0]

        # Convert input data to a single-row pandas DataFrame
        input_data = pd.DataFrame([first_row], columns=tabular_header)
        print(f"Original Input DataFrame:\n{input_data}")

        # Use PyCaret's predict_model to make predictions
        prediction = predict_model(model, data=input_data)

        predicted_class = prediction.loc[0, "prediction_label"]
        class_probability = prediction.loc[0, "prediction_score"]

        # prediction_score is the confidence of the *predicted* label, so a
        # confident negative must be flipped to get the positive-class
        # probability that the messages below report.
        # NOTE(review): assumes label 1 is the stone-positive class, matching
        # predict_proba_fn -- confirm against the training target encoding.
        if predicted_class != 1:
            class_probability = 1 - class_probability

        # Generate appropriate output based on the probability band
        if class_probability < 0.34:
            result = (
                f"This analysis estimates a low probability ({class_probability:.2f}) of a common bile duct stone. "
                "Please consult a medical professional for final diagnosis."
            )
        elif class_probability < 0.67:
            result = (
                f"Based on the provided data, this tool estimates an intermediate probability ({class_probability:.2f}) "
                "of a common bile duct stone. Further medical review is recommended."
            )
        else:  # class_probability >= 0.67
            result = (
                f"Based on the provided data, this tool estimates a high probability ({class_probability:.2f}) "
                "of a common bile duct stone. Further medical review is necessary."
            )
        return result
    except Exception as e:
        # UI boundary: surface the failure as text in the result box.
        return f"An error occurred during classification: {str(e)}"
# Class-probability function for LIME (wraps PyCaret's predict_model)
def predict_proba_fn(instance):
    """Return per-class probabilities using PyCaret's predict_model.

    Reads the module-level ``model`` and ``train`` that the script entry
    point sets up.

    Args:
        instance: NumPy array holding one sample (1-D) or several samples
            (2-D, one row per sample), with columns ordered like
            ``train.columns``.

    Returns:
        2-D NumPy array with columns ``[class_0_prob, class_1_prob]``; each
        row sums to 1.
    """
    # Promote a single sample to 2-D
    if instance.ndim == 1:
        instance = instance.reshape(1, -1)

    # Convert to a DataFrame with the training column names
    instance_df = pd.DataFrame(instance, columns=train.columns)

    # Run the prediction through predict_model
    predictions = predict_model(model, data=instance_df)

    # prediction_score is the confidence of the predicted label: the class-1
    # probability is the score when the label is 1, and 1 - score otherwise.
    score = predictions['prediction_score']
    predictions['class_1_prob'] = np.where(predictions['prediction_label'] == 1,
                                           score,
                                           1 - score)
    # The complement, so that each row is a proper probability distribution.
    predictions['class_0_prob'] = 1 - predictions['class_1_prob']

    # Return class_0_prob and class_1_prob
    return predictions[['class_0_prob', 'class_1_prob']].values
# def explain_with_lime(tabular_data):
# instance = np.array(tabular_data[0],dtype='float')
# # Create an explainer instance for classification
# explainer = LimeTabularExplainer(
# training_data=train.values, # Use your training data
# feature_names=tabular_header,
# class_names=['intermediate', 'High'], # Replace with actual class names
# mode='classification'
# )
# # LIME expects a 2D numpy array or DataFrame for input, and we need to provide the correct number of features
# explanation = explainer.explain_instance(
# instance, # Single instance (first row of the tabular data)
# predict_proba_fn, # The prediction function
# num_features=len(tabular_header) # Number of features to display in the explanation
# )
# # Plot LIME explanation
# fig = explanation.as_pyplot_figure()
# fig.set_size_inches(25, 8)
# buf = io.BytesIO()
# fig.savefig(buf, format='png')
# buf.seek(0)
# encoded_image = base64.b64encode(buf.read()).decode('utf-8')
# buf.close()
# plt.close(fig)
# return f"<img src='data:image/png;base64,{encoded_image}'/>"
if __name__ == '__main__':
    # Script entry point: load the data and model, build the Gradio UI, launch it.
    # NOTE(review): the source had lost its indentation; the nesting of the
    # layout containers below is reconstructed -- confirm against the original.
    args = parse_args(sys.argv[1:])

    # `train`, `model` and `tabular_header` must stay module-level names:
    # classify() and predict_proba_fn() read them as globals.
    train = load_data_and_prepare(args.data_dir, args.excel_file, args.mode, args.scale, args.smote)
    model = load_model(args.model_name_or_path)

    examples = view.examples
    description = view.description
    title_markdown = view.title_markdown
    tabular_header = view.tabular_header
    tabular_dtype = ['number'] * len(tabular_header)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(title_markdown)
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                # Column count follows the header list rather than a
                # hard-coded 11, so it cannot drift from the headers/dtypes.
                tabular_input = gr.Dataframe(
                    headers=tabular_header,
                    datatype=tabular_dtype,
                    label="Tabular Input",
                    type="array",
                    interactive=True,
                    row_count=1,
                    col_count=len(tabular_header),
                )
                # Hidden textbox; it only receives the second field of each example.
                info = gr.Textbox(lines=1, label="Patient info", visible=False)
                with gr.Row():
                    # btn_c = gr.ClearButton([tabular_input])
                    btn_c = gr.Button("Clear")
                    btn = gr.Button("Run")
            result_output = gr.Textbox(lines=2, label="Classification Result")
            lime_output = gr.HTML(label="LIME Explanation")
        gr.Examples(examples=examples, inputs=[tabular_input, info])
        btn.click(fn=classify, inputs=tabular_input, outputs=result_output)
        # btn.click(fn=explain_with_lime, inputs=tabular_input, outputs=lime_output) # Add LIME button

        # Clear functionality: resets inputs and outputs
        def clear_fields():
            """Return empty values for the result box, the LIME pane and a one-row table."""
            return None, None, [[None] * len(tabular_header)]

        btn_c.click(fn=clear_fields, inputs=[], outputs=[result_output, lime_output, tabular_input])

    demo.queue()
    demo.launch(share=True)
|