Spaces:
Sleeping
Sleeping
test
Browse files
app.py
CHANGED
|
@@ -12,150 +12,12 @@ from lime.lime_tabular import LimeTabularExplainer
|
|
| 12 |
from pycaret.classification import *
|
| 13 |
import warnings
|
| 14 |
warnings.filterwarnings("ignore", category=FutureWarning, module="torch.storage")
|
| 15 |
-
from sklearn.preprocessing import MinMaxScaler
|
| 16 |
-
from sklearn.model_selection import train_test_split
|
| 17 |
-
from sklearn.utils import resample
|
| 18 |
-
from glob import glob
|
| 19 |
-
from imblearn.over_sampling import SMOTE
|
| 20 |
|
|
|
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
-
|
| 24 |
-
excel_file : str,
|
| 25 |
-
mode : str = "train",
|
| 26 |
-
modality : str = 'mm',
|
| 27 |
-
phase : str = 'portal', # 'portal', 'pre-enhance', 'combine'
|
| 28 |
-
smote = bool,
|
| 29 |
-
):
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
print("--------------Load RawData--------------")
|
| 33 |
-
df = pd.read_csv(os.path.join(data_dir, excel_file))
|
| 34 |
-
|
| 35 |
-
#Inclusion
|
| 36 |
-
print("--------------Inclusion--------------")
|
| 37 |
-
print('Total : ', len(df))
|
| 38 |
-
|
| 39 |
-
print("--------------fillNA--------------")
|
| 40 |
-
# data = data.dropna()
|
| 41 |
-
df.fillna(0.0,inplace=True)
|
| 42 |
-
print(df['REAL_STONE'].value_counts())
|
| 43 |
-
|
| 44 |
-
#Column rename
|
| 45 |
-
df.rename(columns={'ID': 'patient_id', 'REAL_STONE':'target'}, inplace=True)
|
| 46 |
-
|
| 47 |
-
# Final(n=11)
|
| 48 |
-
columns = ['patient_id','DUCT_DILIATATION_8MM', 'DUCT_DILIATATION_10MM','PANCREATITIS','FIRST_SBP','FIRST_RR','Hb', 'PLT', 'WBC', 'ALP', 'AST', 'CRP', 'BILIRUBIN', 'AGE','target']
|
| 49 |
-
|
| 50 |
-
data = df[columns]
|
| 51 |
-
data['patient_id'] = data['patient_id'].astype(str)
|
| 52 |
-
|
| 53 |
-
image_list = sorted(glob(os.path.join(data_dir,"*.nii.gz")))
|
| 54 |
-
|
| 55 |
-
def get_patient_data(image_number):
|
| 56 |
-
row = data[data['patient_id'].astype(str).str.startswith(image_number)]
|
| 57 |
-
return row.iloc[0, 1:].tolist() if not row.empty else None
|
| 58 |
-
|
| 59 |
-
# Final(n=11)
|
| 60 |
-
data_dict = {key: [] for key in ['image_path','DUCT_DILIATATION_8MM', 'DUCT_DILIATATION_10MM','PANCREATITIS','FIRST_SBP','FIRST_RR','Hb', 'PLT', 'WBC', 'ALP', 'AST', 'CRP', 'BILIRUBIN', 'AGE','target']}
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
# Filter images based on the phase
|
| 65 |
-
if phase == 'portal':
|
| 66 |
-
# Filter the images for the 'portal' phase by checking for 'Portal' in the filename
|
| 67 |
-
image_list = [img for img in image_list if 'Portal' in os.path.basename(img)]
|
| 68 |
-
elif phase == 'pre-enhance':
|
| 69 |
-
# Filter the images for the 'pre-enhance' phase by checking for 'Pre_enhance' in the filename
|
| 70 |
-
image_list = [img for img in image_list if 'Pre_enhance' in os.path.basename(img)]
|
| 71 |
-
elif phase == 'combine':
|
| 72 |
-
# Include both 'portal' and 'pre-enhance' images for the 'combine' phase
|
| 73 |
-
portal_images = [img for img in image_list if 'Portal' in os.path.basename(img)]
|
| 74 |
-
pre_enhance_images = [img for img in image_list if 'Pre_enhance' in os.path.basename(img)]
|
| 75 |
-
image_list = portal_images + pre_enhance_images
|
| 76 |
-
else:
|
| 77 |
-
raise ValueError("Invalid phase. Choose from ['portal', 'pre-enhance', 'combine']")
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
for image_path in image_list:
|
| 81 |
-
image_number = os.path.basename(image_path).split('_')[0]
|
| 82 |
-
patient_data = get_patient_data(image_number)
|
| 83 |
-
if patient_data:
|
| 84 |
-
data_dict['image_path'].append(image_path)
|
| 85 |
-
keys_list = list(data_dict.keys())[1:]
|
| 86 |
-
for key, value in zip(keys_list, patient_data):
|
| 87 |
-
if key == 'image_path':
|
| 88 |
-
continue
|
| 89 |
-
data_dict[key].append(value)
|
| 90 |
-
|
| 91 |
-
if modality == 'image':
|
| 92 |
-
data_dict = {k: data_dict[k] for k in ['image_path', 'target']}
|
| 93 |
-
|
| 94 |
-
elif modality not in ['mm', 'tabular']:
|
| 95 |
-
raise AssertionError("Select Modality for Feature engineering!")
|
| 96 |
-
|
| 97 |
-
#Create a DataFrame from the dictionary
|
| 98 |
-
train_df = pd.DataFrame(data_dict)
|
| 99 |
-
|
| 100 |
-
#if only tabular use
|
| 101 |
-
if modality == 'tabular':
|
| 102 |
-
train_df = data
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
if mode == 'train' or mode == 'test':
|
| 106 |
-
print("--------------Class balance--------------")
|
| 107 |
-
# undersampling
|
| 108 |
-
majority_class = train_df[train_df['target'] == 1.0]
|
| 109 |
-
minority_class = train_df[train_df['target'] == 0.0]
|
| 110 |
-
|
| 111 |
-
# Undersample the majority class to match the number of '1's in the minority class
|
| 112 |
-
undersampled_majority_class = resample(majority_class,
|
| 113 |
-
replace=False,
|
| 114 |
-
n_samples=len(minority_class),
|
| 115 |
-
random_state=42)
|
| 116 |
-
|
| 117 |
-
# Concatenate minority class and undersampled majority class
|
| 118 |
-
data = pd.concat([undersampled_majority_class, minority_class])
|
| 119 |
-
|
| 120 |
-
# print("--------------Class imbalance--------------")
|
| 121 |
-
if smote: # Apply SMOTE if the flag is set
|
| 122 |
-
data = train_df
|
| 123 |
-
print(data['target'].value_counts())
|
| 124 |
-
print("Applying SMOTE...")
|
| 125 |
-
smote = SMOTE(sampling_strategy='all', random_state=42)
|
| 126 |
-
X_data = data.drop(columns=['target'])
|
| 127 |
-
y_data = data['target']
|
| 128 |
-
X_data_res, y_data_res = smote.fit_resample(X_data, y_data)
|
| 129 |
-
data_resampled = pd.DataFrame(X_data_res, columns=X_data.columns)
|
| 130 |
-
data_resampled['target'] = y_data_res
|
| 131 |
-
data = data_resampled # Update train_data with resampled data
|
| 132 |
-
print(data['target'].value_counts())
|
| 133 |
-
|
| 134 |
-
train_data, test_data = train_test_split(data, test_size=0.3, stratify=data['target'], random_state=123)
|
| 135 |
-
valid_data, test_data = train_test_split(test_data, test_size=0.4, stratify=test_data['target'], random_state=123)
|
| 136 |
-
|
| 137 |
-
if mode == 'train':
|
| 138 |
-
print("Train set shape:", train_data.shape)
|
| 139 |
-
print("Validation set shape:", valid_data.shape)
|
| 140 |
-
return train_data, valid_data
|
| 141 |
-
|
| 142 |
-
elif mode == 'test':
|
| 143 |
-
print("Test set shape:", test_data.shape)
|
| 144 |
-
return test_data
|
| 145 |
-
|
| 146 |
-
elif mode == 'pretrain' or mode == 'eval':
|
| 147 |
-
pretrain_data, eval_data = train_test_split(train_df, test_size=0.1, random_state=123)
|
| 148 |
-
if mode == 'pretrain':
|
| 149 |
-
print("Pretrain set shape:", pretrain_data.shape)
|
| 150 |
-
return pretrain_data
|
| 151 |
-
elif mode == 'eval':
|
| 152 |
-
print("Validation set shape:", eval_data.shape)
|
| 153 |
-
return eval_data
|
| 154 |
-
|
| 155 |
-
else:
|
| 156 |
-
raise ValueError("Choose mode!")
|
| 157 |
-
|
| 158 |
-
|
| 159 |
def parse_args(args):
|
| 160 |
parser = argparse.ArgumentParser(description="CBD Classification")
|
| 161 |
parser.add_argument('--data_dir', type=str, default="./")
|
|
@@ -231,107 +93,16 @@ def classify(tabular_data):
|
|
| 231 |
|
| 232 |
if __name__ == '__main__':
|
| 233 |
args = parse_args(sys.argv[1:])
|
| 234 |
-
train = load_data_and_prepare(args.data_dir, args.excel_file, args.
|
| 235 |
model = load_model(args.model_name_or_path)
|
| 236 |
device = torch.device(args.device)
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
[
|
| 242 |
-
[['1', '0', '0', '104', '24', '10.6', '171', '14.54', '236', '182', '12.33', '3.2', '72']],
|
| 243 |
-
"PT_NO = 10001862, VISIBLE_STONE_CT = True, REAL_STONE = True",
|
| 244 |
-
],
|
| 245 |
-
[
|
| 246 |
-
[['0', '1','0','106','18','13.6', '388', '21.13', '196', '118', '1.87', '2.7', '58']],
|
| 247 |
-
"PT_NO = 10007376, VISIBLE_STONE_CT = True, REAL_STONE = True",
|
| 248 |
-
],
|
| 249 |
-
[
|
| 250 |
-
[['1', '0','1','205','18','9.3', '103', '8.45', '440', '100', '4.21', '4.5', '63']],
|
| 251 |
-
"PT_NO = 10040285, VISIBLE_STONE_CT = False, REAL_STONE = True",
|
| 252 |
-
],
|
| 253 |
-
[
|
| 254 |
-
[['0', '1','1','130','20','12.1', '192', '8.63', '47', '59', '0.02', '0.4', '57']],
|
| 255 |
-
"PT_NO = 10005545, VISIBLE_STONE_CT = False, REAL_STONE = False",
|
| 256 |
-
],
|
| 257 |
-
]
|
| 258 |
-
|
| 259 |
-
tabular_header = ['DUCT_DILIATATION_8MM', 'DUCT_DILIATATION_10MM','PANCREATITIS','FIRST_SBP','FIRST_RR','Hb', 'PLT', 'WBC', 'ALP', 'AST', 'CRP', 'BILIRUBIN', 'AGE']
|
| 260 |
-
|
| 261 |
-
description = """
|
| 262 |
-
GPU 리소스 제약으로 인해, 온라인 데모에서는 NVIDIA RTX 3090 24GB를 사용하고 있습니다. \n
|
| 263 |
-
|
| 264 |
-
**Note**: 현재 저희 모델은 **총담관결석증**의 분석 및 진단을 중심으로 최적화되어 있으며, 정확하고 신뢰할 수 있는 결과를 제공합니다. \n
|
| 265 |
-
모델은 다음과 같은 입력 데이터를 처리하며, 아래와 같이 각각 **이산형(discrete)** **연속형(continuous)** 데이터로 처리됩니다. \n
|
| 266 |
-
|
| 267 |
-
- 이산형 변수:
|
| 268 |
-
- DUCT_DILIATATION_8MM
|
| 269 |
-
- DUCT_DILIATATION_10MM
|
| 270 |
-
- PANCREATITIS
|
| 271 |
-
|
| 272 |
-
- 연속형 변수:
|
| 273 |
-
- FIRST_SBP (Systolic blood pressure)
|
| 274 |
-
- FIRST_RR (Respiratory rate)
|
| 275 |
-
- Hb (Hemoglobin)
|
| 276 |
-
- PLT (Platelet)
|
| 277 |
-
- WBC (White Blood Cell)
|
| 278 |
-
- ALP (Alkaline Phosphatase)
|
| 279 |
-
- ALT (Alanine Aminotransferase)
|
| 280 |
-
- AST (Aspartate Aminotransferase)
|
| 281 |
-
- CRP (C-Reactive Protein)
|
| 282 |
-
- BILIRUBIN
|
| 283 |
-
- AGE
|
| 284 |
-
|
| 285 |
-
**중요**: 입력 데이터의 컬럼이 변경(추가, 삭제)될 경우, 모델의 예측 결과가 달라질 수 있습니다. \n
|
| 286 |
-
따라서 입력 데이터의 구조를 변경하기 전에 모델의 재학습 또는 재검증이 필요합니다. \n
|
| 287 |
-
"""
|
| 288 |
-
|
| 289 |
-
title_markdown = ("""
|
| 290 |
-
# 임상 데이터 기반 머신러닝을 이용한 총담관석 예측 모델
|
| 291 |
-
## Development of a Common Bile Duct Stone Prediction Model Using Machine Learning Based on Clinical Data
|
| 292 |
-
[📖[Learn more about Common Bile Duct Stones (총담관결석증)](https://namu.wiki/w/%EC%B4%9D%EB%8B%B4%EA%B4%80%EA%B2%B0%EC%84%9D%EC%A6%9D)]
|
| 293 |
-
### Copyright © 2024 Dongguk University (DGU) and Dongguk University Medical Center (DUMC). All rights reserved.
|
| 294 |
-
""")
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
# def explain_with_lime(tabular_data):
|
| 298 |
-
# """
|
| 299 |
-
# Apply LIME to explain predictions.
|
| 300 |
-
# Args:
|
| 301 |
-
# tabular_data (list): List of input data points (e.g., rows in a dataframe)
|
| 302 |
-
# Returns:
|
| 303 |
-
# str: HTML or image showing LIME explanation
|
| 304 |
-
# """
|
| 305 |
-
# input_data = np.array(tabular_data, dtype=float)
|
| 306 |
-
# explainer = LimeTabularExplainer(
|
| 307 |
-
# training_data=x_train.values, # Replace with your training data
|
| 308 |
-
# feature_names=tabular_header,
|
| 309 |
-
# class_names=['intermediate', 'High'], # Replace with actual class names
|
| 310 |
-
# mode='classification'
|
| 311 |
-
# )
|
| 312 |
-
|
| 313 |
-
# explanation = explainer.explain_instance(
|
| 314 |
-
# input_data[0], # Single instance to explain
|
| 315 |
-
# model.predict_proba, # Probability prediction function
|
| 316 |
-
# num_features=len(tabular_header)
|
| 317 |
-
# )
|
| 318 |
-
|
| 319 |
-
# # Plot LIME explanation
|
| 320 |
-
# fig = explanation.as_pyplot_figure()
|
| 321 |
-
# fig.set_size_inches(25, 8)
|
| 322 |
-
# buf = io.BytesIO()
|
| 323 |
-
# fig.savefig(buf, format='png')
|
| 324 |
-
# buf.seek(0)
|
| 325 |
-
# encoded_image = base64.b64encode(buf.read()).decode('utf-8')
|
| 326 |
-
# buf.close()
|
| 327 |
-
# plt.close(fig)
|
| 328 |
-
|
| 329 |
-
# return f"<img src='data:image/png;base64,{encoded_image}'/>"
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
tabular_header = ['DUCT_DILIATATION_8MM', 'DUCT_DILIATATION_10MM','PANCREATITIS','FIRST_SBP','FIRST_RR','Hb', 'PLT', 'WBC', 'ALP', 'AST', 'CRP', 'BILIRUBIN', 'AGE']
|
| 333 |
tabular_dtype = ['number'] * len(tabular_header)
|
| 334 |
|
|
|
|
| 335 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 336 |
gr.Markdown(title_markdown)
|
| 337 |
gr.Markdown(description)
|
|
|
|
| 12 |
from pycaret.classification import *
|
| 13 |
import warnings
|
| 14 |
warnings.filterwarnings("ignore", category=FutureWarning, module="torch.storage")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
+
from util import load_data
|
| 17 |
+
import view
|
| 18 |
|
| 19 |
|
| 20 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def parse_args(args):
|
| 22 |
parser = argparse.ArgumentParser(description="CBD Classification")
|
| 23 |
parser.add_argument('--data_dir', type=str, default="./")
|
|
|
|
| 93 |
|
| 94 |
if __name__ == '__main__':
|
| 95 |
args = parse_args(sys.argv[1:])
|
| 96 |
+
train = load_data_and_prepare(args.data_dir, args.excel_file, args.mode, args.scale, args.smote)
|
| 97 |
model = load_model(args.model_name_or_path)
|
| 98 |
device = torch.device(args.device)
|
| 99 |
+
examples = view.examples
|
| 100 |
+
description = view.description
|
| 101 |
+
title_markdown = view.title_markdown
|
| 102 |
+
tabular_header = view.tabular_header
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
tabular_dtype = ['number'] * len(tabular_header)
|
| 104 |
|
| 105 |
+
|
| 106 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 107 |
gr.Markdown(title_markdown)
|
| 108 |
gr.Markdown(description)
|