ready2drop committed on
Commit
d3c46ef
·
verified ·
1 Parent(s): fbcd943
Files changed (1) hide show
  1. app.py +9 -238
app.py CHANGED
@@ -12,150 +12,12 @@ from lime.lime_tabular import LimeTabularExplainer
12
  from pycaret.classification import *
13
  import warnings
14
  warnings.filterwarnings("ignore", category=FutureWarning, module="torch.storage")
15
- from sklearn.preprocessing import MinMaxScaler
16
- from sklearn.model_selection import train_test_split
17
- from sklearn.utils import resample
18
- from glob import glob
19
- from imblearn.over_sampling import SMOTE
20
 
 
 
21
 
22
 
23
- def load_data(data_dir : str,
24
- excel_file : str,
25
- mode : str = "train",
26
- modality : str = 'mm',
27
- phase : str = 'portal', # 'portal', 'pre-enhance', 'combine'
28
- smote = bool,
29
- ):
30
-
31
-
32
- print("--------------Load RawData--------------")
33
- df = pd.read_csv(os.path.join(data_dir, excel_file))
34
-
35
- #Inclusion
36
- print("--------------Inclusion--------------")
37
- print('Total : ', len(df))
38
-
39
- print("--------------fillNA--------------")
40
- # data = data.dropna()
41
- df.fillna(0.0,inplace=True)
42
- print(df['REAL_STONE'].value_counts())
43
-
44
- #Column rename
45
- df.rename(columns={'ID': 'patient_id', 'REAL_STONE':'target'}, inplace=True)
46
-
47
- # Final column set: patient_id, 13 tabular features, and target
48
- columns = ['patient_id','DUCT_DILIATATION_8MM', 'DUCT_DILIATATION_10MM','PANCREATITIS','FIRST_SBP','FIRST_RR','Hb', 'PLT', 'WBC', 'ALP', 'AST', 'CRP', 'BILIRUBIN', 'AGE','target']
49
-
50
- data = df[columns]
51
- data['patient_id'] = data['patient_id'].astype(str)
52
-
53
- image_list = sorted(glob(os.path.join(data_dir,"*.nii.gz")))
54
-
55
- def get_patient_data(image_number):
56
- row = data[data['patient_id'].astype(str).str.startswith(image_number)]
57
- return row.iloc[0, 1:].tolist() if not row.empty else None
58
-
59
- # Final column set: image_path, 13 tabular features, and target
60
- data_dict = {key: [] for key in ['image_path','DUCT_DILIATATION_8MM', 'DUCT_DILIATATION_10MM','PANCREATITIS','FIRST_SBP','FIRST_RR','Hb', 'PLT', 'WBC', 'ALP', 'AST', 'CRP', 'BILIRUBIN', 'AGE','target']}
61
-
62
-
63
-
64
- # Filter images based on the phase
65
- if phase == 'portal':
66
- # Filter the images for the 'portal' phase by checking for 'Portal' in the filename
67
- image_list = [img for img in image_list if 'Portal' in os.path.basename(img)]
68
- elif phase == 'pre-enhance':
69
- # Filter the images for the 'pre-enhance' phase by checking for 'Pre_enhance' in the filename
70
- image_list = [img for img in image_list if 'Pre_enhance' in os.path.basename(img)]
71
- elif phase == 'combine':
72
- # Include both 'portal' and 'pre-enhance' images for the 'combine' phase
73
- portal_images = [img for img in image_list if 'Portal' in os.path.basename(img)]
74
- pre_enhance_images = [img for img in image_list if 'Pre_enhance' in os.path.basename(img)]
75
- image_list = portal_images + pre_enhance_images
76
- else:
77
- raise ValueError("Invalid phase. Choose from ['portal', 'pre-enhance', 'combine']")
78
-
79
-
80
- for image_path in image_list:
81
- image_number = os.path.basename(image_path).split('_')[0]
82
- patient_data = get_patient_data(image_number)
83
- if patient_data:
84
- data_dict['image_path'].append(image_path)
85
- keys_list = list(data_dict.keys())[1:]
86
- for key, value in zip(keys_list, patient_data):
87
- if key == 'image_path':
88
- continue
89
- data_dict[key].append(value)
90
-
91
- if modality == 'image':
92
- data_dict = {k: data_dict[k] for k in ['image_path', 'target']}
93
-
94
- elif modality not in ['mm', 'tabular']:
95
- raise AssertionError("Select Modality for Feature engineering!")
96
-
97
- #Create a DataFrame from the dictionary
98
- train_df = pd.DataFrame(data_dict)
99
-
100
- #if only tabular use
101
- if modality == 'tabular':
102
- train_df = data
103
-
104
-
105
- if mode == 'train' or mode == 'test':
106
- print("--------------Class balance--------------")
107
- # undersampling
108
- majority_class = train_df[train_df['target'] == 1.0]
109
- minority_class = train_df[train_df['target'] == 0.0]
110
-
111
- # Undersample the majority class (target == 1.0) down to the size of the minority class (target == 0.0)
112
- undersampled_majority_class = resample(majority_class,
113
- replace=False,
114
- n_samples=len(minority_class),
115
- random_state=42)
116
-
117
- # Concatenate minority class and undersampled majority class
118
- data = pd.concat([undersampled_majority_class, minority_class])
119
-
120
- # print("--------------Class imbalance--------------")
121
- if smote: # Apply SMOTE if the flag is set
122
- data = train_df
123
- print(data['target'].value_counts())
124
- print("Applying SMOTE...")
125
- smote = SMOTE(sampling_strategy='all', random_state=42)
126
- X_data = data.drop(columns=['target'])
127
- y_data = data['target']
128
- X_data_res, y_data_res = smote.fit_resample(X_data, y_data)
129
- data_resampled = pd.DataFrame(X_data_res, columns=X_data.columns)
130
- data_resampled['target'] = y_data_res
131
- data = data_resampled # Continue with the SMOTE-resampled data
132
- print(data['target'].value_counts())
133
-
134
- train_data, test_data = train_test_split(data, test_size=0.3, stratify=data['target'], random_state=123)
135
- valid_data, test_data = train_test_split(test_data, test_size=0.4, stratify=test_data['target'], random_state=123)
136
-
137
- if mode == 'train':
138
- print("Train set shape:", train_data.shape)
139
- print("Validation set shape:", valid_data.shape)
140
- return train_data, valid_data
141
-
142
- elif mode == 'test':
143
- print("Test set shape:", test_data.shape)
144
- return test_data
145
-
146
- elif mode == 'pretrain' or mode == 'eval':
147
- pretrain_data, eval_data = train_test_split(train_df, test_size=0.1, random_state=123)
148
- if mode == 'pretrain':
149
- print("Pretrain set shape:", pretrain_data.shape)
150
- return pretrain_data
151
- elif mode == 'eval':
152
- print("Validation set shape:", eval_data.shape)
153
- return eval_data
154
-
155
- else:
156
- raise ValueError("Choose mode!")
157
-
158
-
159
  def parse_args(args):
160
  parser = argparse.ArgumentParser(description="CBD Classification")
161
  parser.add_argument('--data_dir', type=str, default="./")
@@ -231,107 +93,16 @@ def classify(tabular_data):
231
 
232
  if __name__ == '__main__':
233
  args = parse_args(sys.argv[1:])
234
- train = load_data_and_prepare(args.data_dir, args.excel_file, args.modality, args.phase, args.smote)
235
  model = load_model(args.model_name_or_path)
236
  device = torch.device(args.device)
237
-
238
-
239
- # Gradio
240
- examples = [
241
- [
242
- [['1', '0', '0', '104', '24', '10.6', '171', '14.54', '236', '182', '12.33', '3.2', '72']],
243
- "PT_NO = 10001862, VISIBLE_STONE_CT = True, REAL_STONE = True",
244
- ],
245
- [
246
- [['0', '1','0','106','18','13.6', '388', '21.13', '196', '118', '1.87', '2.7', '58']],
247
- "PT_NO = 10007376, VISIBLE_STONE_CT = True, REAL_STONE = True",
248
- ],
249
- [
250
- [['1', '0','1','205','18','9.3', '103', '8.45', '440', '100', '4.21', '4.5', '63']],
251
- "PT_NO = 10040285, VISIBLE_STONE_CT = False, REAL_STONE = True",
252
- ],
253
- [
254
- [['0', '1','1','130','20','12.1', '192', '8.63', '47', '59', '0.02', '0.4', '57']],
255
- "PT_NO = 10005545, VISIBLE_STONE_CT = False, REAL_STONE = False",
256
- ],
257
- ]
258
-
259
- tabular_header = ['DUCT_DILIATATION_8MM', 'DUCT_DILIATATION_10MM','PANCREATITIS','FIRST_SBP','FIRST_RR','Hb', 'PLT', 'WBC', 'ALP', 'AST', 'CRP', 'BILIRUBIN', 'AGE']
260
-
261
- description = """
262
- GPU 리소스 제약으로 인해, 온라인 데모에서는 NVIDIA RTX 3090 24GB를 사용하고 있습니다. \n
263
-
264
- **Note**: 현재 저희 모델은 **총담관결석증**의 분석 및 진단을 중심으로 최적화되어 있으며, 정확하고 신뢰할 수 있는 결과를 제공합니다. \n
265
- 모델은 다음과 같은 입력 데이터를 처리하며, 아래와 같이 각각 **이산형(discrete)** **연속형(continuous)** 데이터로 처리됩니다. \n
266
-
267
- - 이산형 변수:
268
- - DUCT_DILIATATION_8MM
269
- - DUCT_DILIATATION_10MM
270
- - PANCREATITIS
271
-
272
- - 연속형 변수:
273
- - FIRST_SBP (Systolic blood pressure)
274
- - FIRST_RR (Respiratory rate)
275
- - Hb (Hemoglobin)
276
- - PLT (Platelet)
277
- - WBC (White Blood Cell)
278
- - ALP (Alkaline Phosphatase)
279
- - ALT (Alanine Aminotransferase)
280
- - AST (Aspartate Aminotransferase)
281
- - CRP (C-Reactive Protein)
282
- - BILIRUBIN
283
- - AGE
284
-
285
- **중요**: 입력 데이터의 컬럼이 변경(추가, 삭제)될 경우, 모델의 예측 결과가 달라질 수 있습니다. \n
286
- 따라서 입력 데이터의 구조를 변경하기 전에 모델의 재학습 또는 재검증이 필요합니다. \n
287
- """
288
-
289
- title_markdown = ("""
290
- # 임상 데이터 기반 머신러닝을 이용한 총담관석 예측 모델
291
- ## Development of a Common Bile Duct Stone Prediction Model Using Machine Learning Based on Clinical Data
292
- [📖[Learn more about Common Bile Duct Stones (총담관결석증)](https://namu.wiki/w/%EC%B4%9D%EB%8B%B4%EA%B4%80%EA%B2%B0%EC%84%9D%EC%A6%9D)]
293
- ### Copyright © 2024 Dongguk University (DGU) and Dongguk University Medical Center (DUMC). All rights reserved.
294
- """)
295
-
296
-
297
- # def explain_with_lime(tabular_data):
298
- # """
299
- # Apply LIME to explain predictions.
300
- # Args:
301
- # tabular_data (list): List of input data points (e.g., rows in a dataframe)
302
- # Returns:
303
- # str: HTML or image showing LIME explanation
304
- # """
305
- # input_data = np.array(tabular_data, dtype=float)
306
- # explainer = LimeTabularExplainer(
307
- # training_data=x_train.values, # Replace with your training data
308
- # feature_names=tabular_header,
309
- # class_names=['intermediate', 'High'], # Replace with actual class names
310
- # mode='classification'
311
- # )
312
-
313
- # explanation = explainer.explain_instance(
314
- # input_data[0], # Single instance to explain
315
- # model.predict_proba, # Probability prediction function
316
- # num_features=len(tabular_header)
317
- # )
318
-
319
- # # Plot LIME explanation
320
- # fig = explanation.as_pyplot_figure()
321
- # fig.set_size_inches(25, 8)
322
- # buf = io.BytesIO()
323
- # fig.savefig(buf, format='png')
324
- # buf.seek(0)
325
- # encoded_image = base64.b64encode(buf.read()).decode('utf-8')
326
- # buf.close()
327
- # plt.close(fig)
328
-
329
- # return f"<img src='data:image/png;base64,{encoded_image}'/>"
330
-
331
-
332
- tabular_header = ['DUCT_DILIATATION_8MM', 'DUCT_DILIATATION_10MM','PANCREATITIS','FIRST_SBP','FIRST_RR','Hb', 'PLT', 'WBC', 'ALP', 'AST', 'CRP', 'BILIRUBIN', 'AGE']
333
  tabular_dtype = ['number'] * len(tabular_header)
334
 
 
335
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
336
  gr.Markdown(title_markdown)
337
  gr.Markdown(description)
 
12
  from pycaret.classification import *
13
  import warnings
14
  warnings.filterwarnings("ignore", category=FutureWarning, module="torch.storage")
 
 
 
 
 
15
 
16
+ from util import load_data
17
+ import view
18
 
19
 
20
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def parse_args(args):
22
  parser = argparse.ArgumentParser(description="CBD Classification")
23
  parser.add_argument('--data_dir', type=str, default="./")
 
93
 
94
  if __name__ == '__main__':
95
  args = parse_args(sys.argv[1:])
96
+ train = load_data_and_prepare(args.data_dir, args.excel_file, args.mode, args.scale, args.smote)
97
  model = load_model(args.model_name_or_path)
98
  device = torch.device(args.device)
99
+ examples = view.examples
100
+ description = view.description
101
+ title_markdown = view.title_markdown
102
+ tabular_header = view.tabular_header
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  tabular_dtype = ['number'] * len(tabular_header)
104
 
105
+
106
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
107
  gr.Markdown(title_markdown)
108
  gr.Markdown(description)