File size: 6,871 Bytes
3c99552
 
 
 
 
 
 
bb9eeda
 
 
3c99552
 
 
 
 
51edc87
d3c46ef
3c99552
 
d3c46ef
3c99552
fbcd943
f2e6a5e
 
ae9b7f8
 
3c99552
f2e6a5e
3c99552
 
 
 
 
 
bc32892
ab30ac8
3c99552
 
 
ab30ac8
3c99552
 
ab30ac8
3c99552
 
bb9eeda
ab30ac8
3c99552
 
ab30ac8
3c99552
 
 
 
ab30ac8
 
56ffca6
ab30ac8
 
56ffca6
ab30ac8
56ffca6
ab30ac8
 
56ffca6
ab30ac8
 
 
 
 
 
3c99552
 
 
 
 
51edc87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21333bd
 
51edc87
21333bd
 
 
 
 
 
 
 
 
 
 
 
 
 
51edc87
21333bd
 
 
 
 
 
 
 
 
51edc87
21333bd
51edc87
bb9eeda
 
d3c46ef
bb9eeda
d3c46ef
 
 
 
bb9eeda
 
d3c46ef
bb9eeda
 
 
 
 
4b9b1da
bb9eeda
 
 
 
 
 
 
 
3c99552
bb9eeda
 
 
 
 
21333bd
bb9eeda
 
 
 
3c99552
bb9eeda
3c99552
 
bb9eeda
 
3c99552
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import argparse
import os
import matplotlib.pyplot as plt
import sys
import gradio as gr
import torch
import pandas as pd
import numpy as np
import io
import base64
from lime.lime_tabular import LimeTabularExplainer
from pycaret.classification import *
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, module="torch.storage")

from util import load_data_and_prepare
import view


       
def _str2bool(value):
    """Parse a command-line string such as "true"/"false" into a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for
    every non-empty string (including "False"), so boolean flags could never
    be switched off from the command line.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(args):
    """Parse command-line options for the CBD classification app.

    Args:
        args: List of argument strings, typically ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``data_dir``, ``excel_file``, ``mode``,
        ``scale``, ``smote`` and ``model_name_or_path``.
    """
    parser = argparse.ArgumentParser(description="CBD Classification")
    parser.add_argument('--data_dir', type=str, default="./data/")
    parser.add_argument('--excel_file', type=str, default="DUMC_final.csv")
    parser.add_argument('--mode', type=str, default="train")
    # Use an explicit parser so "--scale False" / "--smote False" actually disable them.
    parser.add_argument('--scale', type=_str2bool, default=True)
    parser.add_argument('--smote', type=_str2bool, default=True)
    # No `choices`: an empty choices list made every user-supplied path invalid.
    parser.add_argument('--model_name_or_path', type=str, default="./data/model")

    return parser.parse_args(args)



# Inference function
def classify(tabular_data):
    """Classify one patient row and return a human-readable risk statement.

    Reads the module-level ``model`` and ``tabular_header`` set up in the
    ``__main__`` block.

    Args:
        tabular_data: Value from the Gradio ``Dataframe`` (``type="array"``),
            expected to be a list of rows; only the first row is classified.

    Returns:
        A string describing the estimated probability of a common bile duct
        stone. Any failure is returned as an error string rather than raised,
        so the UI always gets a displayable message.
    """
    try:
        # Check for an empty list before indexing so the user sees the format
        # message instead of a bare "list index out of range".
        if not (
            isinstance(tabular_data, list)
            and tabular_data
            and isinstance(tabular_data[0], list)
        ):
            raise ValueError("Input data is not in the expected 2D list format.")
        row = tabular_data[0]  # only the first row is classified

        # Convert input data to a pandas DataFrame
        input_data = pd.DataFrame([row], columns=tabular_header)
        print(f"Original Input DataFrame:\n{input_data}")

        # Use PyCaret's predict_model to make predictions
        prediction = predict_model(model, data=input_data)

        # `prediction_score` is the confidence of the *predicted* label, not of
        # the stone class. Convert it to P(stone) so a confident "no stone"
        # prediction is not reported as a high stone probability.
        # Assumes the stone class is encoded as label 1 (as in
        # predict_proba_fn) — TODO confirm against the training labels.
        predicted_label = prediction["prediction_label"].iloc[0]
        score = float(prediction["prediction_score"].iloc[0])
        class_probability = score if predicted_label == 1 else 1.0 - score

        # Map the stone probability onto low / intermediate / high risk bands.
        if class_probability < 0.34:
            result = (
                f"This analysis estimates a low probability ({class_probability:.2f}) of a common bile duct stone. "
                 "Please consult a medical professional for final diagnosis."
            )
        elif class_probability < 0.67:
            result = (
                f"Based on the provided data, this tool estimates an intermediate probability ({class_probability:.2f}) "
                "of a common bile duct stone. Further medical review is recommended."
            )
        else:  # class_probability >= 0.67
            result = (
                f"Based on the provided data, this tool estimates a high probability ({class_probability:.2f}) "
                "of a common bile duct stone. Further medical review is necessary."
            )

        return result

    except Exception as e:
        # UI boundary: surface the problem to the user instead of crashing the app.
        return f"An error occurred during classification: {str(e)}"

# Inference function
def predict_proba_fn(instance):
    """Return per-class probabilities in the shape LIME expects.

    Wraps PyCaret's ``predict_model`` and reads the module-level ``model`` and
    ``train`` (whose columns name the features).

    Args:
        instance: A single sample (1-D) or a batch of samples (2-D),
            array-like, with features ordered as ``train.columns``.

    Returns:
        ``numpy.ndarray`` of shape ``(n_samples, 2)`` with columns
        ``[P(class 0), P(class 1)]``; each row sums to 1.
    """
    instance = np.asarray(instance)
    # LIME may pass a single sample; predict_model needs a 2-D table.
    if instance.ndim == 1:
        instance = instance.reshape(1, -1)

    instance_df = pd.DataFrame(instance, columns=train.columns)
    predictions = predict_model(model, data=instance_df)

    # prediction_score is the probability of the predicted label. If the label
    # is 1 it is P(class 1); if 0, P(class 1) is its complement. The
    # non-predicted class must get 1 - score (not 0) so rows sum to 1.
    labels = predictions['prediction_label'].to_numpy()
    scores = predictions['prediction_score'].to_numpy(dtype=float)
    class_1_prob = np.where(labels == 1, scores, 1.0 - scores)
    class_0_prob = 1.0 - class_1_prob

    return np.column_stack([class_0_prob, class_1_prob])


# def explain_with_lime(tabular_data):
#     instance = np.array(tabular_data[0],dtype='float')
    
#     # Create an explainer instance for classification
#     explainer = LimeTabularExplainer(
#         training_data=train.values,  # Use your training data
#         feature_names=tabular_header,
#         class_names=['intermediate', 'High'],  # Replace with actual class names
#         mode='classification'
#     )

#     # LIME expects a 2D numpy array or DataFrame for input, and we need to provide the correct number of features
#     explanation = explainer.explain_instance(
#         instance,  # Single instance (first row of the tabular data)
#         predict_proba_fn,  # The prediction function
#         num_features=len(tabular_header)  # Number of features to display in the explanation
#     )
    
#     # Plot LIME explanation
#     fig = explanation.as_pyplot_figure()
#     fig.set_size_inches(25, 8) 
#     buf = io.BytesIO()
#     fig.savefig(buf, format='png')
#     buf.seek(0)
#     encoded_image = base64.b64encode(buf.read()).decode('utf-8')
#     buf.close()
#     plt.close(fig)
    
#     return f"<img src='data:image/png;base64,{encoded_image}'/>"

if __name__ == '__main__':
    args = parse_args(sys.argv[1:])
    # train, model and tabular_header are module-level names read as globals
    # by classify() and predict_proba_fn().
    train = load_data_and_prepare(args.data_dir, args.excel_file, args.mode, args.scale, args.smote)
    model = load_model(args.model_name_or_path)
    examples = view.examples
    description = view.description
    title_markdown = view.title_markdown
    tabular_header = view.tabular_header
    tabular_dtype = ['number'] * len(tabular_header)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown(title_markdown)
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                # Column count follows the header instead of a hard-coded 11,
                # so the grid stays in sync if view.tabular_header changes.
                tabular_input = gr.Dataframe(
                    headers=tabular_header,
                    datatype=tabular_dtype,
                    label="Tabular Input",
                    type="array",
                    interactive=True,
                    row_count=1,
                    col_count=len(tabular_header),
                )
                # Hidden field; only referenced as a second input of gr.Examples below.
                info = gr.Textbox(lines=1, label="Patient info", visible=False)

                with gr.Row():
                    # btn_c = gr.ClearButton([tabular_input])
                    btn_c = gr.Button("Clear")
                    btn = gr.Button("Run")

        result_output = gr.Textbox(lines=2, label="Classification Result")
        # Output slot for the LIME explanation (its click handler is currently disabled).
        lime_output = gr.HTML(label="LIME Explanation")
        gr.Examples(examples=examples, inputs=[tabular_input, info])
        btn.click(fn=classify, inputs=tabular_input, outputs=result_output)
        # btn.click(fn=explain_with_lime, inputs=tabular_input, outputs=lime_output)  # Add LIME button

        def clear_fields():
            """Reset both outputs and replace the input grid with one empty row."""
            return None, None, [[None] * len(tabular_header)]

        btn_c.click(fn=clear_fields, inputs=[], outputs=[result_output, lime_output, tabular_input])

    demo.queue()
    # NOTE(review): share=True exposes a public share link; confirm this is
    # intended for an app that receives patient data.
    demo.launch(share=True)