ready2drop committed on
Commit
3c99552
·
verified ·
1 Parent(s): 23c6ac6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +378 -224
app.py CHANGED
@@ -1,224 +1,378 @@
1
- import argparse
2
- import os
3
- import io
4
- import base64
5
- import matplotlib.pyplot as plt
6
- import sys
7
- import bleach
8
- import gradio as gr
9
- import torch
10
- import numpy as np
11
- import pandas as pd
12
- import pickle
13
- from sklearn.preprocessing import StandardScaler
14
- from lime.lime_tabular import LimeTabularExplainer
15
- from pycaret.classification import *
16
- from bc_feature_engineering import load_data
17
- import warnings
18
- warnings.filterwarnings("ignore", category=FutureWarning, module="torch.storage")
19
-
20
- def parse_args(args):
21
- parser = argparse.ArgumentParser(description="M3D-LaMed chat")
22
- parser.add_argument('--data_dir', type=str, default="/mnt/c/Users/user/Downloads/DUMC_project/DUMC_total")
23
- parser.add_argument('--excel_file', type=str, default="dumc_1223_case3_duct_correct.csv")
24
- parser.add_argument('--modality', type=str, default="tabular")
25
- parser.add_argument('--phase', type=str, default="combine")
26
- parser.add_argument('--smote', type=bool, default=True)
27
- parser.add_argument('--model_name_or_path', type=str, default="logs/2025-01-13-18-16-test-tabular/ensemble_1", choices=[])
28
- parser.add_argument('--top_p', type=float, default=None)
29
- parser.add_argument('--temperature', type=float, default=1.0)
30
- parser.add_argument('--device', type=str, default="cuda", choices=["cuda", "cpu"])
31
-
32
- return parser.parse_args(args)
33
-
34
-
35
- def load_data_and_prepare(data_dir, excel_file, modality, phase, smote):
36
- # Load train, validation, and test data
37
- train_df,val_df = load_data(data_dir, excel_file, 'train', modality, phase, smote)
38
-
39
- train_df.drop(columns=['patient_id'],inplace = True)
40
- val_df.drop(columns=['patient_id'],inplace = True)
41
-
42
- train = pd.concat([train_df,val_df],axis=0)
43
-
44
- return train
45
-
46
-
47
- # Inference function
48
- def classify(tabular_data, model):
49
- """
50
- Perform classification on tabular data using a PyCaret pre-trained model.
51
-
52
- Args:
53
- tabular_data (list or array-like): Input data points (e.g., a single row of features)
54
- model (object): Pre-trained classification model from PyCaret
55
-
56
- Returns:
57
- str: Classification result and probabilities
58
- """
59
- try:
60
- # Ensure tabular_data is a 2D list and extract the first row
61
- if isinstance(tabular_data, list) and isinstance(tabular_data[0], list):
62
- tabular_data = tabular_data[0] # Extract the first row
63
- else:
64
- raise ValueError("Input data is not in the expected 2D list format.")
65
-
66
- # Convert input data to a pandas DataFrame
67
- input_data = pd.DataFrame([tabular_data], columns= tabular_header)
68
- print(f"Input DataFrame:\n{input_data}")
69
-
70
- # Use PyCaret's predict_model to make predictions
71
- prediction = predict_model(model, data=input_data)
72
- print('OK')
73
- # Extract predicted class and probability
74
- predicted_class = prediction.loc[0, "prediction_label"]
75
- class_probability = prediction.loc[0, "prediction_score"]
76
-
77
- # Format the result
78
- result = f"Predicted Class: {predicted_class}, Probability: {class_probability:.2f}"
79
- return result
80
-
81
- except Exception as e:
82
- return f"An error occurred during classification: {str(e)}"
83
-
84
- args = parse_args(sys.argv[1:])
85
- # x_train, y_train, x_val, y_val, x_test, y_test = load_data_and_prepare(args.data_dir, args.excel_file, args.modality, args.phase, args.smote)
86
- train = load_data_and_prepare(args.data_dir, args.excel_file, args.modality, args.phase, args.smote)
87
- model = load_model(args.model_name_or_path)
88
- device = torch.device(args.device)
89
-
90
-
91
- # Gradio
92
- examples = [
93
- [
94
- [['1', '0', '0', '104', '24', '10.6', '171', '14.54', '236', '182', '12.33', '3.2', '72']],
95
- "PT_NO = 10001862, VISIBLE_STONE_CT = True, REAL_STONE = True",
96
- ],
97
- [
98
- [['0', '1','0','106','18','13.6', '388', '21.13', '196', '118', '1.87', '2.7', '58']],
99
- "PT_NO = 10007376, VISIBLE_STONE_CT = True, REAL_STONE = True",
100
- ],
101
- [
102
- [['1', '0','1','205','18','9.3', '103', '8.45', '440', '100', '4.21', '4.5', '63']],
103
- "PT_NO = 10040285, VISIBLE_STONE_CT = False, REAL_STONE = True",
104
- ],
105
- [
106
- [['0', '1','1','130','20','12.1', '192', '8.63', '47', '59', '0.02', '0.4', '57']],
107
- "PT_NO = 10005545, VISIBLE_STONE_CT = False, REAL_STONE = False",
108
- ],
109
- ]
110
-
111
- tabular_header = ['DUCT_DILIATATION_8MM', 'DUCT_DILIATATION_10MM','PANCREATITIS','FIRST_SBP','FIRST_RR','Hb', 'PLT', 'WBC', 'ALP', 'AST', 'CRP', 'BILIRUBIN', 'AGE']
112
-
113
- description = """
114
- GPU ๋ฆฌ์†Œ์Šค ์ œ์•ฝ์œผ๋กœ ์ธํ•ด, ์˜จ๋ผ์ธ ๋ฐ๋ชจ์—์„œ๋Š” NVIDIA RTX 3090 24GB๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. \n
115
-
116
- **Note**: ํ˜„์žฌ ์ €ํฌ ๋ชจ๋ธ์€ **์ด๋‹ด๊ด€๊ฒฐ์„์ฆ**์˜ ๋ถ„์„ ๋ฐ ์ง„๋‹จ์„ ์ค‘์‹ฌ์œผ๋กœ ์ตœ์ ํ™”๋˜์–ด ์žˆ์œผ๋ฉฐ, ์ •ํ™•ํ•˜๊ณ  ์‹ ๋ขฐํ•  ์ˆ˜ ์žˆ๋Š” ๊ฒฐ๊ณผ๋ฅผ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค. \n
117
- ๋ชจ๋ธ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์€ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ๋ฅผ ์ฒ˜๋ฆฌํ•˜๋ฉฐ, ์•„๋ž˜์™€ ๊ฐ™์ด ๊ฐ๊ฐ **์ด์‚ฐํ˜•(discrete)** **์—ฐ์†ํ˜•(continuous)** ๋ฐ์ดํ„ฐ๋กœ ์ฒ˜๋ฆฌ๋ฉ๋‹ˆ๋‹ค. \n
118
-
119
- - ์ด์‚ฐํ˜• ๋ณ€์ˆ˜:
120
- - DUCT_DILIATATION_8MM
121
- - DUCT_DILIATATION_10MM
122
- - PANCREATITIS
123
-
124
- - ์—ฐ์†ํ˜• ๋ณ€์ˆ˜:
125
- - FIRST_SBP (Systolic blood pressure)
126
- - FIRST_RR (Respiratory rate)
127
- - Hb (Hemoglobin)
128
- - PLT (Platelet)
129
- - WBC (White Blood Cell)
130
- - ALP (Alkaline Phosphatase)
131
- - ALT (Alanine Aminotransferase)
132
- - AST (Aspartate Aminotransferase)
133
- - CRP (C-Reactive Protein)
134
- - BILIRUBIN
135
- - AGE
136
-
137
- **์ค‘์š”**: ์ž…๋ ฅ ๋ฐ์ดํ„ฐ์˜ ์ปฌ๋Ÿผ์ด ๋ณ€๊ฒฝ(์ถ”๊ฐ€, ์‚ญ์ œ)๋  ๊ฒฝ์šฐ, ๋ชจ๋ธ์˜ ์˜ˆ์ธก ๊ฒฐ๊ณผ๊ฐ€ ๋‹ฌ๋ผ์งˆ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. \n
138
- ๋”ฐ๋ผ์„œ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ์˜ ๊ตฌ์กฐ๋ฅผ ๋ณ€๊ฒฝํ•˜๊ธฐ ์ „์— ๋ชจ๋ธ์˜ ์žฌํ•™์Šต ๋˜๋Š” ์žฌ๊ฒ€์ฆ์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. \n
139
- """
140
-
141
- title_markdown = ("""
142
- # ์ž„์ƒ ๋ฐ์ดํ„ฐ ๊ธฐ๋ฐ˜ ๋จธ์‹ ๋Ÿฌ๋‹์„ ์ด์šฉํ•œ ์ด๋‹ด๊ด€์„ ์˜ˆ์ธก ๋ชจ๋ธ
143
- ## Development of a Common Bile Duct Stone Prediction Model Using Machine Learning Based on Clinical Data
144
- [๐Ÿ“–[Learn more about Common Bile Duct Stones (์ด๋‹ด๊ด€๊ฒฐ์„์ฆ)](https://namu.wiki/w/%EC%B4%9D%EB%8B%B4%EA%B4%80%EA%B2%B0%EC%84%9D%EC%A6%9D)]
145
- ### Copyright ยฉ 2024 Dongguk University (DGU) and Dongguk University Medical Center (DUMC). All rights reserved.
146
- """)
147
-
148
-
149
- # def explain_with_lime(tabular_data):
150
- # """
151
- # Apply LIME to explain predictions.
152
- # Args:
153
- # tabular_data (list): List of input data points (e.g., rows in a dataframe)
154
- # Returns:
155
- # str: HTML or image showing LIME explanation
156
- # """
157
- # input_data = np.array(tabular_data, dtype=float)
158
- # explainer = LimeTabularExplainer(
159
- # training_data=x_train.values, # Replace with your training data
160
- # feature_names=tabular_header,
161
- # class_names=['intermediate', 'High'], # Replace with actual class names
162
- # mode='classification'
163
- # )
164
-
165
- # explanation = explainer.explain_instance(
166
- # input_data[0], # Single instance to explain
167
- # model.predict_proba, # Probability prediction function
168
- # num_features=len(tabular_header)
169
- # )
170
-
171
- # # Plot LIME explanation
172
- # fig = explanation.as_pyplot_figure()
173
- # fig.set_size_inches(25, 8)
174
- # buf = io.BytesIO()
175
- # fig.savefig(buf, format='png')
176
- # buf.seek(0)
177
- # encoded_image = base64.b64encode(buf.read()).decode('utf-8')
178
- # buf.close()
179
- # plt.close(fig)
180
-
181
- # return f"<img src='data:image/png;base64,{encoded_image}'/>"
182
-
183
-
184
- tabular_header = ['DUCT_DILIATATION_8MM', 'DUCT_DILIATATION_10MM','PANCREATITIS','FIRST_SBP','FIRST_RR','Hb', 'PLT', 'WBC', 'ALP', 'AST', 'CRP', 'BILIRUBIN', 'AGE']
185
- tabular_dtype = ['number'] * len(tabular_header)
186
-
187
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
188
- gr.Markdown(title_markdown)
189
- gr.Markdown(description)
190
- with gr.Row():
191
- with gr.Column():
192
- tabular_input = gr.Dataframe(headers= tabular_header, datatype= tabular_dtype, label="Tabular Input", type="array", interactive=True, row_count=1, col_count=13)
193
- info = gr.Textbox(lines=1, label="Patient info", visible = False)
194
-
195
- with gr.Accordion("Parameters", open=False) as parameter_row:
196
- temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True,
197
- label="Temperature", )
198
- top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1, interactive=True, label="Top P", )
199
-
200
- with gr.Row():
201
- # btn_c = gr.ClearButton([tabular_input])
202
- btn_c = gr.Button("Clear")
203
- btn = gr.Button("Run")
204
-
205
-
206
-
207
-
208
- result_output = gr.Textbox(lines=2, label="Classification Result")
209
- lime_output = gr.HTML(label="LIME Explanation")
210
- gr.Examples(examples=examples, inputs=[tabular_input, info])
211
- btn.click(fn=classify, inputs=tabular_input, outputs=result_output)
212
- # btn.click(fn=explain_with_lime, inputs=tabular_input, outputs=lime_output) # Add LIME button
213
-
214
- # Clear functionality: resets inputs and outputs
215
- def clear_fields():
216
- return None, None, [[None] * len(tabular_header)]
217
-
218
- btn_c.click(fn=clear_fields, inputs=[], outputs=[result_output, lime_output, tabular_input])
219
-
220
-
221
- demo.queue()
222
- demo.launch(share=True)
223
-
224
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import io
4
+ import base64
5
+ import matplotlib.pyplot as plt
6
+ import sys
7
+ import bleach
8
+ import gradio as gr
9
+ import torch
10
+ import numpy as np
11
+ import pandas as pd
12
+ import pickle
13
+ from sklearn.preprocessing import StandardScaler
14
+ from lime.lime_tabular import LimeTabularExplainer
15
+ from pycaret.classification import *
16
+ import warnings
17
+ warnings.filterwarnings("ignore", category=FutureWarning, module="torch.storage")
18
+ from sklearn.preprocessing import MinMaxScaler
19
+ from sklearn.model_selection import train_test_split
20
+ from sklearn.utils import resample
21
+ from glob import glob
22
+ from imblearn.over_sampling import SMOTE
23
+
24
+
25
+
26
def load_data(data_dir: str,
              excel_file: str,
              mode: str = "train",
              modality: str = 'mm',
              phase: str = 'portal',  # 'portal', 'pre-enhance', 'combine'
              smote: bool = True,
              ):
    """Load the clinical CSV (plus optional CT image paths), preprocess and split.

    Pipeline: read CSV -> fill NA with 0.0 -> rename ID/REAL_STONE ->
    select feature columns -> (for image modalities) join image paths by
    patient id -> min-max scale continuous columns -> rebalance classes ->
    train/valid/test split.

    Args:
        data_dir: Directory holding the CSV file and the ``*.nii.gz`` images.
        excel_file: Name of the CSV file inside ``data_dir`` (CSV despite the name).
        mode: 'train' (returns train+valid), 'test', 'pretrain' or 'eval'.
        modality: 'mm' (image + tabular), 'tabular' or 'image'.
        phase: CT phase filter: 'portal', 'pre-enhance' or 'combine'.
        smote: Apply SMOTE oversampling before splitting.  BUG FIX: the
            original signature was ``smote = bool`` — the default was the
            ``bool`` *type* itself, which is truthy, so ``True`` preserves the
            effective default while making the signature a real annotation.

    Returns:
        mode == 'train': tuple ``(train_data, valid_data)``.
        mode == 'test' / 'pretrain' / 'eval': a single DataFrame.

    Raises:
        ValueError: Unknown ``phase`` or ``mode``.
        AssertionError: Unknown ``modality``.
    """
    print("--------------Load RawData--------------")
    df = pd.read_csv(os.path.join(data_dir, excel_file))

    # Inclusion
    print("--------------Inclusion--------------")
    print('Total : ', len(df))

    print("--------------fillNA--------------")
    df.fillna(0.0, inplace=True)
    print(df['REAL_STONE'].value_counts())

    # Column rename
    df.rename(columns={'ID': 'patient_id', 'REAL_STONE': 'target'}, inplace=True)

    # feature importance w/o VISIBLE_STONE_CT (n=11)
    columns = ['patient_id', 'DUCT_DILIATATION_8MM', 'DUCT_DILIATATION_10MM', 'Hb', 'PLT',
               'WBC', 'ALP', 'ALT', 'AST', 'CRP', 'BILIRUBIN', 'AGE', 'target']

    # .copy() prevents pandas' SettingWithCopyWarning on the cast below.
    data = df[columns].copy()
    data['patient_id'] = data['patient_id'].astype(str)

    image_list = sorted(glob(os.path.join(data_dir, "*.nii.gz")))

    def get_patient_data(image_number):
        # A patient matches an image when the patient id starts with the
        # numeric prefix of the image file name.
        row = data[data['patient_id'].astype(str).str.startswith(image_number)]
        return row.iloc[0, 1:].tolist() if not row.empty else None

    # feature importance w/o VISIBLE_STONE_CT (n=11)
    data_dict = {key: [] for key in ['image_path', 'DUCT_DILIATATION_8MM', 'DUCT_DILIATATION_10MM',
                                     'Hb', 'PLT', 'WBC', 'ALP', 'ALT', 'AST', 'CRP',
                                     'BILIRUBIN', 'AGE', 'target']}

    # Filter images based on the phase.
    if phase == 'portal':
        image_list = [img for img in image_list if 'Portal' in os.path.basename(img)]
    elif phase == 'pre-enhance':
        image_list = [img for img in image_list if 'Pre_enhance' in os.path.basename(img)]
    elif phase == 'combine':
        # Include both 'portal' and 'pre-enhance' images.
        portal_images = [img for img in image_list if 'Portal' in os.path.basename(img)]
        pre_enhance_images = [img for img in image_list if 'Pre_enhance' in os.path.basename(img)]
        image_list = portal_images + pre_enhance_images
    else:
        raise ValueError("Invalid phase. Choose from ['portal', 'pre-enhance', 'combine']")

    for image_path in image_list:
        image_number = os.path.basename(image_path).split('_')[0]
        patient_data = get_patient_data(image_number)
        if patient_data:
            data_dict['image_path'].append(image_path)
            keys_list = list(data_dict.keys())[1:]
            for key, value in zip(keys_list, patient_data):
                if key == 'image_path':
                    continue
                data_dict[key].append(value)

    if modality == 'image':
        data_dict = {k: data_dict[k] for k in ['image_path', 'target']}
    elif modality not in ['mm', 'tabular']:
        raise AssertionError("Select Modality for Feature engineering!")

    # Create a DataFrame from the dictionary.
    train_df = pd.DataFrame(data_dict)

    # If only tabular data is used, skip the image join entirely.
    if modality == 'tabular':
        train_df = data

    print("--------------Scaling--------------")
    if modality in ['mm', 'tabular']:
        columns_to_scale = ['Hb', 'PLT', 'WBC', 'ALP', 'ALT',
                            'AST', 'CRP', 'BILIRUBIN', 'FIRST_SBP', 'FIRST_DBP', 'FIRST_HR',
                            'FIRST_RR', 'FIRST_BT', 'AGE']

        columns_to_scale_existing = [col for col in columns_to_scale if col in train_df.columns]

        if columns_to_scale_existing:
            scaler = MinMaxScaler()
            train_df[columns_to_scale_existing] = scaler.fit_transform(train_df[columns_to_scale_existing])
        else:
            print("No columns to scale.")

    if mode == 'train' or mode == 'test':
        print("--------------Class balance--------------")
        # Undersampling: assumes target == 1.0 is the majority class — TODO confirm.
        majority_class = train_df[train_df['target'] == 1.0]
        minority_class = train_df[train_df['target'] == 0.0]

        # Undersample the majority class down to the minority class size.
        undersampled_majority_class = resample(majority_class,
                                               replace=False,
                                               n_samples=len(minority_class),
                                               random_state=42)

        # Concatenate minority class and undersampled majority class.
        data = pd.concat([undersampled_majority_class, minority_class])

        if smote:  # Apply SMOTE if the flag is set
            # NOTE(review): this restarts from the full train_df, discarding the
            # undersampling result computed just above — confirm this is intended.
            data = train_df
            print(data['target'].value_counts())
            print("Applying SMOTE...")
            # Renamed from `smote` so the oversampler no longer shadows the flag.
            sampler = SMOTE(sampling_strategy='all', random_state=42)
            X_data = data.drop(columns=['target'])
            y_data = data['target']
            X_data_res, y_data_res = sampler.fit_resample(X_data, y_data)
            data_resampled = pd.DataFrame(X_data_res, columns=X_data.columns)
            data_resampled['target'] = y_data_res
            data = data_resampled  # Update with resampled data
            print(data['target'].value_counts())

        # 70% train, then the remaining 30% split 60/40 into valid/test.
        train_data, test_data = train_test_split(data, test_size=0.3, stratify=data['target'], random_state=123)
        valid_data, test_data = train_test_split(test_data, test_size=0.4, stratify=test_data['target'], random_state=123)

    if mode == 'train':
        print("Train set shape:", train_data.shape)
        print("Validation set shape:", valid_data.shape)
        return train_data, valid_data

    elif mode == 'test':
        print("Test set shape:", test_data.shape)
        return test_data

    elif mode == 'pretrain' or mode == 'eval':
        # Pretrain/eval use a simple 90/10 split of the unbalanced frame.
        pretrain_data, eval_data = train_test_split(train_df, test_size=0.1, random_state=123)
        if mode == 'pretrain':
            print("Pretrain set shape:", pretrain_data.shape)
            return pretrain_data
        elif mode == 'eval':
            print("Validation set shape:", eval_data.shape)
            return eval_data

    else:
        raise ValueError("Choose mode!")
172
+
173
+
174
def _str2bool(value):
    """Parse a boolean-ish CLI string.

    argparse's ``type=bool`` treats ANY non-empty string (including "False")
    as True; this helper gives ``--smote False`` its intended meaning.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() in ('1', 'true', 't', 'yes', 'y')


def parse_args(args):
    """Build and parse the command-line arguments for the demo app.

    Args:
        args: Argument list, e.g. ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with data, model, sampling and device settings.
    """
    parser = argparse.ArgumentParser(description="M3D-LaMed chat")
    parser.add_argument('--data_dir', type=str, default="/mnt/c/Users/user/Downloads/DUMC_project/DUMC_total")
    parser.add_argument('--excel_file', type=str, default="dumc_1223_case3_duct_correct.csv")
    parser.add_argument('--modality', type=str, default="tabular")
    parser.add_argument('--phase', type=str, default="combine")
    # BUG FIX: was type=bool, under which "--smote False" parsed as True.
    parser.add_argument('--smote', type=_str2bool, default=True)
    # BUG FIX: was choices=[], which rejected every explicitly supplied value.
    parser.add_argument('--model_name_or_path', type=str, default="logs/2025-01-13-18-16-test-tabular/ensemble_1")
    parser.add_argument('--top_p', type=float, default=None)
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--device', type=str, default="cuda", choices=["cuda", "cpu"])

    return parser.parse_args(args)
187
+
188
+
189
def load_data_and_prepare(data_dir, excel_file, modality, phase, smote):
    """Load the 'train' split and merge train + validation into one frame.

    ``load_data`` returns ``(train_df, val_df)`` for mode 'train'; the
    ``patient_id`` column is dropped in place from both before they are
    stacked row-wise into a single DataFrame.
    """
    splits = load_data(data_dir, excel_file, 'train', modality, phase, smote)

    prepared = []
    for frame in splits:
        frame.drop(columns=['patient_id'], inplace=True)
        prepared.append(frame)

    return pd.concat(prepared, axis=0)
199
+
200
+
201
# Inference function
def classify(tabular_data, model):
    """Classify one row of tabular input with a pre-trained PyCaret model.

    Args:
        tabular_data: 2D list (list of rows); only the first row is classified.
        model: Pre-trained classification pipeline from PyCaret.

    Returns:
        str: "Predicted Class: ..., Probability: ..." on success, otherwise a
        human-readable error description (exceptions are never propagated).
    """
    try:
        # Accept only a 2D list; anything else is reported as an error string.
        if not (isinstance(tabular_data, list) and isinstance(tabular_data[0], list)):
            raise ValueError("Input data is not in the expected 2D list format.")
        first_row = tabular_data[0]

        # Build a single-row DataFrame in the column order the model expects.
        frame = pd.DataFrame([first_row], columns=tabular_header)
        print(f"Input DataFrame:\n{frame}")

        # PyCaret appends prediction_label / prediction_score columns.
        prediction = predict_model(model, data=frame)
        print('OK')

        label = prediction.loc[0, "prediction_label"]
        score = prediction.loc[0, "prediction_score"]
        return f"Predicted Class: {label}, Probability: {score:.2f}"

    except Exception as e:
        return f"An error occurred during classification: {str(e)}"
237
+
238
# ---- Module-level setup: runs once at import/launch ----
args = parse_args(sys.argv[1:])
# x_train, y_train, x_val, y_val, x_test, y_test = load_data_and_prepare(args.data_dir, args.excel_file, args.modality, args.phase, args.smote)
# Combined train+validation frame; not referenced again in this file.
train = load_data_and_prepare(args.data_dir, args.excel_file, args.modality, args.phase, args.smote)
# PyCaret pipeline from disk (load_model comes from `pycaret.classification import *`).
model = load_model(args.model_name_or_path)
# NOTE(review): `device` is created but never used later in this file.
device = torch.device(args.device)
243
+
244
+
245
# Gradio
# Demo examples: each entry is [Dataframe rows, patient-info text]; the 13
# values per row follow the order of `tabular_header` below.
examples = [
    [
        [['1', '0', '0', '104', '24', '10.6', '171', '14.54', '236', '182', '12.33', '3.2', '72']],
        "PT_NO = 10001862, VISIBLE_STONE_CT = True, REAL_STONE = True",
    ],
    [
        [['0', '1','0','106','18','13.6', '388', '21.13', '196', '118', '1.87', '2.7', '58']],
        "PT_NO = 10007376, VISIBLE_STONE_CT = True, REAL_STONE = True",
    ],
    [
        [['1', '0','1','205','18','9.3', '103', '8.45', '440', '100', '4.21', '4.5', '63']],
        "PT_NO = 10040285, VISIBLE_STONE_CT = False, REAL_STONE = True",
    ],
    [
        [['0', '1','1','130','20','12.1', '192', '8.63', '47', '59', '0.02', '0.4', '57']],
        "PT_NO = 10005545, VISIBLE_STONE_CT = False, REAL_STONE = False",
    ],
]

# Column order used by `classify` when building the model input.
# NOTE(review): includes PANCREATITIS/FIRST_SBP/FIRST_RR, which are not among
# the columns selected in `load_data` — confirm this matches the saved model.
tabular_header = ['DUCT_DILIATATION_8MM', 'DUCT_DILIATATION_10MM','PANCREATITIS','FIRST_SBP','FIRST_RR','Hb', 'PLT', 'WBC', 'ALP', 'AST', 'CRP', 'BILIRUBIN', 'AGE']
266
+
267
# Markdown description shown in the demo UI below the title.
# NOTE(review): the non-ASCII text below appears mojibake'd (Korean UTF-8
# decoded with a wrong codec in this view) — it is a runtime string and is
# reproduced unchanged here; verify the file encoding before editing it.
description = """
GPU ๋ฆฌ์†Œ์Šค ์ œ์•ฝ์œผ๋กœ ์ธํ•ด, ์˜จ๋ผ์ธ ๋ฐ๋ชจ์—์„œ๋Š” NVIDIA RTX 3090 24GB๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. \n

**Note**: ํ˜„์žฌ ์ €ํฌ ๋ชจ๋ธ์€ **์ด๋‹ด๊ด€๊ฒฐ์„์ฆ**์˜ ๋ถ„์„ ๋ฐ ์ง„๋‹จ์„ ์ค‘์‹ฌ์œผ๋กœ ์ตœ์ ํ™”๋˜์–ด ์žˆ์œผ๋ฉฐ, ์ •ํ™•ํ•˜๊ณ  ์‹ ๋ขฐํ•  ์ˆ˜ ์žˆ๋Š” ๊ฒฐ๊ณผ๋ฅผ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค. \n
๋ชจ๋ธ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์€ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ๋ฅผ ์ฒ˜๋ฆฌํ•˜๋ฉฐ, ์•„๋ž˜์™€ ๊ฐ™์ด ๊ฐ๊ฐ **์ด์‚ฐํ˜•(discrete)** **์—ฐ์†ํ˜•(continuous)** ๋ฐ์ดํ„ฐ๋กœ ์ฒ˜๋ฆฌ๋ฉ๋‹ˆ๋‹ค. \n

- ์ด์‚ฐํ˜• ๋ณ€์ˆ˜:
- DUCT_DILIATATION_8MM
- DUCT_DILIATATION_10MM
- PANCREATITIS

- ์—ฐ์†ํ˜• ๋ณ€์ˆ˜:
- FIRST_SBP (Systolic blood pressure)
- FIRST_RR (Respiratory rate)
- Hb (Hemoglobin)
- PLT (Platelet)
- WBC (White Blood Cell)
- ALP (Alkaline Phosphatase)
- ALT (Alanine Aminotransferase)
- AST (Aspartate Aminotransferase)
- CRP (C-Reactive Protein)
- BILIRUBIN
- AGE

**์ค‘์š”**: ์ž…๋ ฅ ๋ฐ์ดํ„ฐ์˜ ์ปฌ๋Ÿผ์ด ๋ณ€๊ฒฝ(์ถ”๊ฐ€, ์‚ญ์ œ)๋  ๊ฒฝ์šฐ, ๋ชจ๋ธ์˜ ์˜ˆ์ธก ๊ฒฐ๊ณผ๊ฐ€ ๋‹ฌ๋ผ์งˆ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. \n
๋”ฐ๋ผ์„œ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ์˜ ๊ตฌ์กฐ๋ฅผ ๋ณ€๊ฒฝํ•˜๊ธฐ ์ „์— ๋ชจ๋ธ์˜ ์žฌํ•™์Šต ๋˜๋Š” ์žฌ๊ฒ€์ฆ์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. \n
"""

# Page title markdown (project name, paper-style subtitle, reference link,
# copyright). Same encoding caveat as `description` above.
title_markdown = ("""
# ์ž„์ƒ ๋ฐ์ดํ„ฐ ๊ธฐ๋ฐ˜ ๋จธ์‹ ๋Ÿฌ๋‹์„ ์ด์šฉํ•œ ์ด๋‹ด๊ด€์„ ์˜ˆ์ธก ๋ชจ๋ธ
## Development of a Common Bile Duct Stone Prediction Model Using Machine Learning Based on Clinical Data
[๐Ÿ“–[Learn more about Common Bile Duct Stones (์ด๋‹ด๊ด€๊ฒฐ์„์ฆ)](https://namu.wiki/w/%EC%B4%9D%EB%8B%B4%EA%B4%80%EA%B2%B0%EC%84%9D%EC%A6%9D)]
### Copyright ยฉ 2024 Dongguk University (DGU) and Dongguk University Medical Center (DUMC). All rights reserved.
""")
301
+
302
+
303
+ # def explain_with_lime(tabular_data):
304
+ # """
305
+ # Apply LIME to explain predictions.
306
+ # Args:
307
+ # tabular_data (list): List of input data points (e.g., rows in a dataframe)
308
+ # Returns:
309
+ # str: HTML or image showing LIME explanation
310
+ # """
311
+ # input_data = np.array(tabular_data, dtype=float)
312
+ # explainer = LimeTabularExplainer(
313
+ # training_data=x_train.values, # Replace with your training data
314
+ # feature_names=tabular_header,
315
+ # class_names=['intermediate', 'High'], # Replace with actual class names
316
+ # mode='classification'
317
+ # )
318
+
319
+ # explanation = explainer.explain_instance(
320
+ # input_data[0], # Single instance to explain
321
+ # model.predict_proba, # Probability prediction function
322
+ # num_features=len(tabular_header)
323
+ # )
324
+
325
+ # # Plot LIME explanation
326
+ # fig = explanation.as_pyplot_figure()
327
+ # fig.set_size_inches(25, 8)
328
+ # buf = io.BytesIO()
329
+ # fig.savefig(buf, format='png')
330
+ # buf.seek(0)
331
+ # encoded_image = base64.b64encode(buf.read()).decode('utf-8')
332
+ # buf.close()
333
+ # plt.close(fig)
334
+
335
+ # return f"<img src='data:image/png;base64,{encoded_image}'/>"
336
+
337
+
338
# Column order for the Gradio Dataframe input widget.
# NOTE(review): duplicates the `tabular_header` defined earlier in this file;
# consider keeping a single definition.
tabular_header = ['DUCT_DILIATATION_8MM', 'DUCT_DILIATATION_10MM','PANCREATITIS','FIRST_SBP','FIRST_RR','Hb', 'PLT', 'WBC', 'ALP', 'AST', 'CRP', 'BILIRUBIN', 'AGE']
# All 13 columns are edited as numbers in the input grid.
tabular_dtype = ['number'] * len(tabular_header)
340
+
341
# ---- Gradio UI definition and launch ----
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(title_markdown)
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            # 1x13 editable grid; column order must match `tabular_header`.
            tabular_input = gr.Dataframe(headers=tabular_header, datatype=tabular_dtype,
                                         label="Tabular Input", type="array",
                                         interactive=True, row_count=1, col_count=13)
            # Hidden field populated by the examples gallery.
            info = gr.Textbox(lines=1, label="Patient info", visible=False)

            with gr.Accordion("Parameters", open=False) as parameter_row:
                # NOTE(review): these sliders are not wired to any event below.
                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1,
                                        interactive=True, label="Temperature")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1,
                                  interactive=True, label="Top P")

            with gr.Row():
                btn_c = gr.Button("Clear")
                btn = gr.Button("Run")

    result_output = gr.Textbox(lines=2, label="Classification Result")
    lime_output = gr.HTML(label="LIME Explanation")
    gr.Examples(examples=examples, inputs=[tabular_input, info])

    # BUG FIX: `classify(tabular_data, model)` takes two arguments but only one
    # input component was wired, so every click raised a TypeError for the
    # missing `model` argument. Bind the module-level model explicitly.
    btn.click(fn=lambda rows: classify(rows, model), inputs=tabular_input, outputs=result_output)
    # btn.click(fn=explain_with_lime, inputs=tabular_input, outputs=lime_output)  # Add LIME button

    # Clear functionality: resets both outputs and the input grid.
    def clear_fields():
        return None, None, [[None] * len(tabular_header)]

    btn_c.click(fn=clear_fields, inputs=[], outputs=[result_output, lime_output, tabular_input])


demo.queue()
demo.launch(share=True)
377
+
378
+