ready2drop commited on
Commit
bb9eeda
ยท
verified ยท
1 Parent(s): ee42bae
Files changed (1) hide show
  1. app.py +140 -158
app.py CHANGED
@@ -1,16 +1,13 @@
1
  import argparse
2
  import os
3
- import io
4
- import base64
5
  import matplotlib.pyplot as plt
6
  import sys
7
- import bleach
8
  import gradio as gr
9
  import torch
10
- import numpy as np
11
  import pandas as pd
12
- import pickle
13
- from sklearn.preprocessing import StandardScaler
 
14
  from lime.lime_tabular import LimeTabularExplainer
15
  from pycaret.classification import *
16
  import warnings
@@ -102,20 +99,6 @@ def load_data(data_dir : str,
102
  #if only tabular use
103
  if modality == 'tabular':
104
  train_df = data
105
-
106
- print("--------------Scaling--------------")
107
- if modality in ['mm', 'tabular']:
108
- columns_to_scale = ['Hb', 'PLT', 'WBC', 'ALP', 'ALT',
109
- 'AST', 'CRP', 'BILIRUBIN', 'FIRST_SBP', 'FIRST_DBP', 'FIRST_HR', 'FIRST_RR',
110
- 'FIRST_BT','AGE']
111
-
112
- columns_to_scale_existing = [col for col in columns_to_scale if col in train_df.columns]
113
-
114
- if columns_to_scale_existing:
115
- scaler = MinMaxScaler()
116
- train_df[columns_to_scale_existing] = scaler.fit_transform(train_df[columns_to_scale_existing])
117
- else:
118
- print("No columns to scale.")
119
 
120
  if mode == 'train' or mode == 'test':
121
  print("--------------Class balance--------------")
@@ -219,11 +202,11 @@ def classify(tabular_data):
219
 
220
  # Convert input data to a pandas DataFrame
221
  input_data = pd.DataFrame([tabular_data], columns= tabular_header)
222
- print(f"Input DataFrame:\n{input_data}")
223
-
224
  # Use PyCaret's predict_model to make predictions
225
  prediction = predict_model(model, data=input_data)
226
- print('OK')
227
  # Extract predicted class and probability
228
  predicted_class = prediction.loc[0, "prediction_label"]
229
  class_probability = prediction.loc[0, "prediction_score"]
@@ -235,144 +218,143 @@ def classify(tabular_data):
235
  except Exception as e:
236
  return f"An error occurred during classification: {str(e)}"
237
 
238
- args = parse_args(sys.argv[1:])
239
- # x_train, y_train, x_val, y_val, x_test, y_test = load_data_and_prepare(args.data_dir, args.excel_file, args.modality, args.phase, args.smote)
240
- train = load_data_and_prepare(args.data_dir, args.excel_file, args.modality, args.phase, args.smote)
241
- model = load_model(args.model_name_or_path)
242
- device = torch.device(args.device)
243
-
244
-
245
- # Gradio
246
- examples = [
247
- [
248
- [['1', '0', '0', '104', '24', '10.6', '171', '14.54', '236', '182', '12.33', '3.2', '72']],
249
- "PT_NO = 10001862, VISIBLE_STONE_CT = True, REAL_STONE = True",
250
- ],
251
- [
252
- [['0', '1','0','106','18','13.6', '388', '21.13', '196', '118', '1.87', '2.7', '58']],
253
- "PT_NO = 10007376, VISIBLE_STONE_CT = True, REAL_STONE = True",
254
- ],
255
- [
256
- [['1', '0','1','205','18','9.3', '103', '8.45', '440', '100', '4.21', '4.5', '63']],
257
- "PT_NO = 10040285, VISIBLE_STONE_CT = False, REAL_STONE = True",
258
- ],
259
- [
260
- [['0', '1','1','130','20','12.1', '192', '8.63', '47', '59', '0.02', '0.4', '57']],
261
- "PT_NO = 10005545, VISIBLE_STONE_CT = False, REAL_STONE = False",
262
- ],
263
- ]
264
-
265
- tabular_header = ['DUCT_DILIATATION_8MM', 'DUCT_DILIATATION_10MM','PANCREATITIS','FIRST_SBP','FIRST_RR','Hb', 'PLT', 'WBC', 'ALP', 'AST', 'CRP', 'BILIRUBIN', 'AGE']
266
-
267
- description = """
268
- GPU ๋ฆฌ์†Œ์Šค ์ œ์•ฝ์œผ๋กœ ์ธํ•ด, ์˜จ๋ผ์ธ ๋ฐ๋ชจ์—์„œ๋Š” NVIDIA RTX 3090 24GB๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. \n
269
-
270
- **Note**: ํ˜„์žฌ ์ €ํฌ ๋ชจ๋ธ์€ **์ด๋‹ด๊ด€๊ฒฐ์„์ฆ**์˜ ๋ถ„์„ ๋ฐ ์ง„๋‹จ์„ ์ค‘์‹ฌ์œผ๋กœ ์ตœ์ ํ™”๋˜์–ด ์žˆ์œผ๋ฉฐ, ์ •ํ™•ํ•˜๊ณ  ์‹ ๋ขฐํ•  ์ˆ˜ ์žˆ๋Š” ๊ฒฐ๊ณผ๋ฅผ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค. \n
271
- ๋ชจ๋ธ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์€ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ๋ฅผ ์ฒ˜๋ฆฌํ•˜๋ฉฐ, ์•„๋ž˜์™€ ๊ฐ™์ด ๊ฐ๊ฐ **์ด์‚ฐํ˜•(discrete)** **์—ฐ์†ํ˜•(continuous)** ๋ฐ์ดํ„ฐ๋กœ ์ฒ˜๋ฆฌ๋ฉ๋‹ˆ๋‹ค. \n
272
-
273
- - ์ด์‚ฐํ˜• ๋ณ€์ˆ˜:
274
- - DUCT_DILIATATION_8MM
275
- - DUCT_DILIATATION_10MM
276
- - PANCREATITIS
277
-
278
- - ์—ฐ์†ํ˜• ๋ณ€์ˆ˜:
279
- - FIRST_SBP (Systolic blood pressure)
280
- - FIRST_RR (Respiratory rate)
281
- - Hb (Hemoglobin)
282
- - PLT (Platelet)
283
- - WBC (White Blood Cell)
284
- - ALP (Alkaline Phosphatase)
285
- - ALT (Alanine Aminotransferase)
286
- - AST (Aspartate Aminotransferase)
287
- - CRP (C-Reactive Protein)
288
- - BILIRUBIN
289
- - AGE
290
-
291
- **์ค‘์š”**: ์ž…๋ ฅ ๋ฐ์ดํ„ฐ์˜ ์ปฌ๋Ÿผ์ด ๋ณ€๊ฒฝ(์ถ”๊ฐ€, ์‚ญ์ œ)๋  ๊ฒฝ์šฐ, ๋ชจ๋ธ์˜ ์˜ˆ์ธก ๊ฒฐ๊ณผ๊ฐ€ ๋‹ฌ๋ผ์งˆ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. \n
292
- ๋”ฐ๋ผ์„œ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ์˜ ๊ตฌ์กฐ๋ฅผ ๋ณ€๊ฒฝํ•˜๊ธฐ ์ „์— ๋ชจ๋ธ์˜ ์žฌํ•™์Šต ๋˜๋Š” ์žฌ๊ฒ€์ฆ์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. \n
293
- """
294
-
295
- title_markdown = ("""
296
- # ์ž„์ƒ ๋ฐ์ดํ„ฐ ๊ธฐ๋ฐ˜ ๋จธ์‹ ๋Ÿฌ๋‹์„ ์ด์šฉํ•œ ์ด๋‹ด๊ด€์„ ์˜ˆ์ธก ๋ชจ๋ธ
297
- ## Development of a Common Bile Duct Stone Prediction Model Using Machine Learning Based on Clinical Data
298
- [๐Ÿ“–[Learn more about Common Bile Duct Stones (์ด๋‹ด๊ด€๊ฒฐ์„์ฆ)](https://namu.wiki/w/%EC%B4%9D%EB%8B%B4%EA%B4%80%EA%B2%B0%EC%84%9D%EC%A6%9D)]
299
- ### Copyright ยฉ 2024 Dongguk University (DGU) and Dongguk University Medical Center (DUMC). All rights reserved.
300
- """)
301
-
302
-
303
- # def explain_with_lime(tabular_data):
304
- # """
305
- # Apply LIME to explain predictions.
306
- # Args:
307
- # tabular_data (list): List of input data points (e.g., rows in a dataframe)
308
- # Returns:
309
- # str: HTML or image showing LIME explanation
310
- # """
311
- # input_data = np.array(tabular_data, dtype=float)
312
- # explainer = LimeTabularExplainer(
313
- # training_data=x_train.values, # Replace with your training data
314
- # feature_names=tabular_header,
315
- # class_names=['intermediate', 'High'], # Replace with actual class names
316
- # mode='classification'
317
- # )
318
-
319
- # explanation = explainer.explain_instance(
320
- # input_data[0], # Single instance to explain
321
- # model.predict_proba, # Probability prediction function
322
- # num_features=len(tabular_header)
323
- # )
324
-
325
- # # Plot LIME explanation
326
- # fig = explanation.as_pyplot_figure()
327
- # fig.set_size_inches(25, 8)
328
- # buf = io.BytesIO()
329
- # fig.savefig(buf, format='png')
330
- # buf.seek(0)
331
- # encoded_image = base64.b64encode(buf.read()).decode('utf-8')
332
- # buf.close()
333
- # plt.close(fig)
334
-
335
- # return f"<img src='data:image/png;base64,{encoded_image}'/>"
336
-
337
-
338
- tabular_header = ['DUCT_DILIATATION_8MM', 'DUCT_DILIATATION_10MM','PANCREATITIS','FIRST_SBP','FIRST_RR','Hb', 'PLT', 'WBC', 'ALP', 'AST', 'CRP', 'BILIRUBIN', 'AGE']
339
- tabular_dtype = ['number'] * len(tabular_header)
340
-
341
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
342
- gr.Markdown(title_markdown)
343
- gr.Markdown(description)
344
- with gr.Row():
345
- with gr.Column():
346
- tabular_input = gr.Dataframe(headers= tabular_header, datatype= tabular_dtype, label="Tabular Input", type="array", interactive=True, row_count=1, col_count=13)
347
- info = gr.Textbox(lines=1, label="Patient info", visible = False)
348
-
349
- with gr.Accordion("Parameters", open=False) as parameter_row:
350
- temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True,
351
- label="Temperature", )
352
- top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1, interactive=True, label="Top P", )
353
-
354
- with gr.Row():
355
- # btn_c = gr.ClearButton([tabular_input])
356
- btn_c = gr.Button("Clear")
357
- btn = gr.Button("Run")
358
-
359
-
360
 
361
-
362
- result_output = gr.Textbox(lines=2, label="Classification Result")
363
- lime_output = gr.HTML(label="LIME Explanation")
364
- gr.Examples(examples=examples, inputs=[tabular_input, info])
365
- btn.click(fn=classify, inputs=tabular_input, outputs=result_output)
366
- # btn.click(fn=explain_with_lime, inputs=tabular_input, outputs=lime_output) # Add LIME button
367
-
368
- # Clear functionality: resets inputs and outputs
369
- def clear_fields():
370
- return None, None, [[None] * len(tabular_header)]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
371
 
372
- btn_c.click(fn=clear_fields, inputs=[], outputs=[result_output, lime_output, tabular_input])
 
 
 
 
 
 
 
 
 
373
 
 
374
 
375
- demo.queue()
376
- demo.launch(share=True)
377
 
 
 
378
 
 
1
  import argparse
2
  import os
 
 
3
  import matplotlib.pyplot as plt
4
  import sys
 
5
  import gradio as gr
6
  import torch
 
7
  import pandas as pd
8
+ import numpy as np
9
+ import io
10
+ import base64
11
  from lime.lime_tabular import LimeTabularExplainer
12
  from pycaret.classification import *
13
  import warnings
 
99
  #if only tabular use
100
  if modality == 'tabular':
101
  train_df = data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  if mode == 'train' or mode == 'test':
104
  print("--------------Class balance--------------")
 
202
 
203
  # Convert input data to a pandas DataFrame
204
  input_data = pd.DataFrame([tabular_data], columns= tabular_header)
205
+ print(f"Original Input DataFrame:\n{input_data}")
206
+
207
  # Use PyCaret's predict_model to make predictions
208
  prediction = predict_model(model, data=input_data)
209
+
210
  # Extract predicted class and probability
211
  predicted_class = prediction.loc[0, "prediction_label"]
212
  class_probability = prediction.loc[0, "prediction_score"]
 
218
  except Exception as e:
219
  return f"An error occurred during classification: {str(e)}"
220
 
221
+ if __name__ == '__main__':
222
+ args = parse_args(sys.argv[1:])
223
+ train = load_data_and_prepare(args.data_dir, args.excel_file, args.modality, args.phase, args.smote)
224
+ model = load_model(args.model_name_or_path)
225
+ device = torch.device(args.device)
226
+
227
+
228
+ # Gradio
229
+ examples = [
230
+ [
231
+ [['1', '0', '0', '104', '24', '10.6', '171', '14.54', '236', '182', '12.33', '3.2', '72']],
232
+ "PT_NO = 10001862, VISIBLE_STONE_CT = True, REAL_STONE = True",
233
+ ],
234
+ [
235
+ [['0', '1','0','106','18','13.6', '388', '21.13', '196', '118', '1.87', '2.7', '58']],
236
+ "PT_NO = 10007376, VISIBLE_STONE_CT = True, REAL_STONE = True",
237
+ ],
238
+ [
239
+ [['1', '0','1','205','18','9.3', '103', '8.45', '440', '100', '4.21', '4.5', '63']],
240
+ "PT_NO = 10040285, VISIBLE_STONE_CT = False, REAL_STONE = True",
241
+ ],
242
+ [
243
+ [['0', '1','1','130','20','12.1', '192', '8.63', '47', '59', '0.02', '0.4', '57']],
244
+ "PT_NO = 10005545, VISIBLE_STONE_CT = False, REAL_STONE = False",
245
+ ],
246
+ ]
247
+
248
+ tabular_header = ['DUCT_DILIATATION_8MM', 'DUCT_DILIATATION_10MM','PANCREATITIS','FIRST_SBP','FIRST_RR','Hb', 'PLT', 'WBC', 'ALP', 'AST', 'CRP', 'BILIRUBIN', 'AGE']
249
+
250
+ description = """
251
+ GPU ๋ฆฌ์†Œ์Šค ์ œ์•ฝ์œผ๋กœ ์ธํ•ด, ์˜จ๋ผ์ธ ๋ฐ๋ชจ์—์„œ๋Š” NVIDIA RTX 3090 24GB๋ฅผ ์‚ฌ์šฉํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. \n
252
+
253
+ **Note**: ํ˜„์žฌ ์ €ํฌ ๋ชจ๋ธ์€ **์ด๋‹ด๊ด€๊ฒฐ์„์ฆ**์˜ ๋ถ„์„ ๋ฐ ์ง„๋‹จ์„ ์ค‘์‹ฌ์œผ๋กœ ์ตœ์ ํ™”๋˜์–ด ์žˆ์œผ๋ฉฐ, ์ •ํ™•ํ•˜๊ณ  ์‹ ๋ขฐํ•  ์ˆ˜ ์žˆ๋Š” ๊ฒฐ๊ณผ๋ฅผ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค. \n
254
+ ๋ชจ๋ธ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์€ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ๋ฅผ ์ฒ˜๋ฆฌํ•˜๋ฉฐ, ์•„๋ž˜์™€ ๊ฐ™์ด ๊ฐ๊ฐ **์ด์‚ฐํ˜•(discrete)** **์—ฐ์†ํ˜•(continuous)** ๋ฐ์ดํ„ฐ๋กœ ์ฒ˜๋ฆฌ๋ฉ๋‹ˆ๋‹ค. \n
255
+
256
+ - ์ด์‚ฐํ˜• ๋ณ€์ˆ˜:
257
+ - DUCT_DILIATATION_8MM
258
+ - DUCT_DILIATATION_10MM
259
+ - PANCREATITIS
260
+
261
+ - ์—ฐ์†ํ˜• ๋ณ€์ˆ˜:
262
+ - FIRST_SBP (Systolic blood pressure)
263
+ - FIRST_RR (Respiratory rate)
264
+ - Hb (Hemoglobin)
265
+ - PLT (Platelet)
266
+ - WBC (White Blood Cell)
267
+ - ALP (Alkaline Phosphatase)
268
+ - ALT (Alanine Aminotransferase)
269
+ - AST (Aspartate Aminotransferase)
270
+ - CRP (C-Reactive Protein)
271
+ - BILIRUBIN
272
+ - AGE
273
+
274
+ **์ค‘์š”**: ์ž…๋ ฅ ๋ฐ์ดํ„ฐ์˜ ์ปฌ๋Ÿผ์ด ๋ณ€๊ฒฝ(์ถ”๊ฐ€, ์‚ญ์ œ)๋  ๊ฒฝ์šฐ, ๋ชจ๋ธ์˜ ์˜ˆ์ธก ๊ฒฐ๊ณผ๊ฐ€ ๋‹ฌ๋ผ์งˆ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. \n
275
+ ๋”ฐ๋ผ์„œ ์ž…๋ ฅ ๋ฐ์ดํ„ฐ์˜ ๊ตฌ์กฐ๋ฅผ ๋ณ€๊ฒฝํ•˜๊ธฐ ์ „์— ๋ชจ๋ธ์˜ ์žฌํ•™์Šต ๋˜๋Š” ์žฌ๊ฒ€์ฆ์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. \n
276
+ """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
277
 
278
+ title_markdown = ("""
279
+ # ์ž„์ƒ ๋ฐ์ดํ„ฐ ๊ธฐ๋ฐ˜ ๋จธ์‹ ๋Ÿฌ๋‹์„ ์ด์šฉํ•œ ์ด๋‹ด๊ด€์„ ์˜ˆ์ธก ๋ชจ๋ธ
280
+ ## Development of a Common Bile Duct Stone Prediction Model Using Machine Learning Based on Clinical Data
281
+ [๐Ÿ“–[Learn more about Common Bile Duct Stones (์ด๋‹ด๊ด€๊ฒฐ์„์ฆ)](https://namu.wiki/w/%EC%B4%9D%EB%8B%B4%EA%B4%80%EA%B2%B0%EC%84%9D%EC%A6%9D)]
282
+ ### Copyright ยฉ 2024 Dongguk University (DGU) and Dongguk University Medical Center (DUMC). All rights reserved.
283
+ """)
284
+
285
+
286
+ # def explain_with_lime(tabular_data):
287
+ # """
288
+ # Apply LIME to explain predictions.
289
+ # Args:
290
+ # tabular_data (list): List of input data points (e.g., rows in a dataframe)
291
+ # Returns:
292
+ # str: HTML or image showing LIME explanation
293
+ # """
294
+ # input_data = np.array(tabular_data, dtype=float)
295
+ # explainer = LimeTabularExplainer(
296
+ # training_data=x_train.values, # Replace with your training data
297
+ # feature_names=tabular_header,
298
+ # class_names=['intermediate', 'High'], # Replace with actual class names
299
+ # mode='classification'
300
+ # )
301
+
302
+ # explanation = explainer.explain_instance(
303
+ # input_data[0], # Single instance to explain
304
+ # model.predict_proba, # Probability prediction function
305
+ # num_features=len(tabular_header)
306
+ # )
307
+
308
+ # # Plot LIME explanation
309
+ # fig = explanation.as_pyplot_figure()
310
+ # fig.set_size_inches(25, 8)
311
+ # buf = io.BytesIO()
312
+ # fig.savefig(buf, format='png')
313
+ # buf.seek(0)
314
+ # encoded_image = base64.b64encode(buf.read()).decode('utf-8')
315
+ # buf.close()
316
+ # plt.close(fig)
317
+
318
+ # return f"<img src='data:image/png;base64,{encoded_image}'/>"
319
+
320
+
321
+ tabular_header = ['DUCT_DILIATATION_8MM', 'DUCT_DILIATATION_10MM','PANCREATITIS','FIRST_SBP','FIRST_RR','Hb', 'PLT', 'WBC', 'ALP', 'AST', 'CRP', 'BILIRUBIN', 'AGE']
322
+ tabular_dtype = ['number'] * len(tabular_header)
323
+
324
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
325
+ gr.Markdown(title_markdown)
326
+ gr.Markdown(description)
327
+ with gr.Row():
328
+ with gr.Column():
329
+ tabular_input = gr.Dataframe(headers= tabular_header, datatype= tabular_dtype, label="Tabular Input", type="array", interactive=True, row_count=1, col_count=13)
330
+ info = gr.Textbox(lines=1, label="Patient info", visible = False)
331
+
332
+ with gr.Accordion("Parameters", open=False) as parameter_row:
333
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True,
334
+ label="Temperature", )
335
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.4, step=0.1, interactive=True, label="Top P", )
336
+
337
+ with gr.Row():
338
+ # btn_c = gr.ClearButton([tabular_input])
339
+ btn_c = gr.Button("Clear")
340
+ btn = gr.Button("Run")
341
+
342
+
343
 
344
+
345
+ result_output = gr.Textbox(lines=2, label="Classification Result")
346
+ lime_output = gr.HTML(label="LIME Explanation")
347
+ gr.Examples(examples=examples, inputs=[tabular_input, info])
348
+ btn.click(fn=classify, inputs=tabular_input, outputs=result_output)
349
+ # btn.click(fn=explain_with_lime, inputs=tabular_input, outputs=lime_output) # Add LIME button
350
+
351
+ # Clear functionality: resets inputs and outputs
352
+ def clear_fields():
353
+ return None, None, [[None] * len(tabular_header)]
354
 
355
+ btn_c.click(fn=clear_fields, inputs=[], outputs=[result_output, lime_output, tabular_input])
356
 
 
 
357
 
358
+ demo.queue()
359
+ demo.launch(share=True)
360