# xDesCO / ML_model / model.py
# tuannt
# update_check_clinical_data
# 7650763
# from sklearn.ensemble import RandomForestClassifier
import numpy as np
import joblib
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from lime import lime_tabular
class MLModel():
    """Classifier wrapper that merges image-derived features with clinical data.

    On construction the class loads a pickled model, a clinical CSV, the
    training matrix and the list of selected feature names from fixed paths
    relative to the current working directory. It then exposes:

    * lookup / preprocessing of the 14 clinical features, either from the
      CSV (by ``FILENAME``) or from caller-supplied values;
    * ``predict`` for a plain prediction;
    * ``predict_explain`` for a prediction plus a LIME explanation.
    """

    def __init__(self):
        # NOTE(security): joblib.load unpickles arbitrary Python objects. Only
        # load artifacts that ship with the project, never user-supplied files.
        self.model = joblib.load('./ML_model/weights/xgboost_convnet_best.pkl')

        # Keep the untouched CSV next to the scaled/encoded working copy.
        self.raw_clinical_data = pd.read_csv("./ML_model/data/clinical-2.csv")
        self.clinical_data = self.raw_clinical_data.copy()
        del self.clinical_data['ID']
        del self.clinical_data['SIDE']

        # Numeric columns are min-max scaled; the fitted scaler is reused for
        # caller-supplied input in preprocess_clinical_input().
        self.columns_to_scale = ['AGE', 'HEIGHT', 'WEIGHT', 'MAX WEIGHT', 'BMI', 'KOOS PAIN SCORE']
        self.scaler = MinMaxScaler()
        self.clinical_data[self.columns_to_scale] = self.scaler.fit_transform(
            self.clinical_data[self.columns_to_scale]
        )
        # Medians of the *unscaled* numeric columns. Not read inside this
        # class; presumably consumed by callers as default values.
        self.numeric_defaults = self.raw_clinical_data[self.columns_to_scale].median(numeric_only=True).to_dict()

        # Categorical columns are replaced by integer codes from pd.factorize.
        # mapping_dict keeps the code -> value index, categorical_code_maps
        # the reverse lookup keyed by str(value).
        self.columns_to_convert = ['FREQUENT PAIN', 'SURGERY', 'RISK', 'SXKOA', 'SWELLING', 'BENDING FULLY', 'SYMPTOMATIC', 'CREPITUS']
        self.mapping_dict = {}
        self.categorical_code_maps = {}
        for column in self.columns_to_convert:
            self.clinical_data[column], unique_values = pd.factorize(self.clinical_data[column])
            self.mapping_dict[column] = unique_values
            self.categorical_code_maps[column] = {str(value): idx for idx, value in enumerate(unique_values)}

        # Order in which clinical features are laid out for the model.
        self.clinical_feature_columns = [
            'AGE', 'HEIGHT', 'WEIGHT', 'MAX WEIGHT', 'BMI',
            'FREQUENT PAIN', 'SURGERY', 'RISK', 'SXKOA',
            'SWELLING', 'BENDING FULLY', 'SYMPTOMATIC',
            'CREPITUS', 'KOOS PAIN SCORE'
        ]

        self.selected_features = joblib.load("./ML_model/data/selected_features.pkl")
        self.X_train = joblib.load("./ML_model/data/X_train.pkl")
        # Positions of the selected features inside the full feature vector.
        self.selected_feature_index = [self.X_train.columns.get_loc(col) for col in self.selected_features]

        self.explainer = lime_tabular.LimeTabularExplainer(
            self.X_train[self.selected_features].to_numpy(),
            feature_names=self.selected_features,
            class_names=['0', '1', '2', '3', '4'],
            mode='classification'
        )

    def get_clinical_data(self, filename):
        """Return the preprocessed CSV rows whose ``FILENAME`` equals *filename*.

        The result is a 2-D NumPy array with one row per match (zero rows when
        nothing matches). No column is removed here: every column of
        ``clinical_data`` is still present, and ``_resolve_clinical_features``
        skips the first one when it slices the row.
        """
        row = self.clinical_data.loc[self.clinical_data['FILENAME'] == filename]
        return row.to_numpy()

    def has_clinical_data(self, filename) -> bool:
        """Return ``True`` when at least one CSV row has this ``FILENAME``."""
        # Series.any() yields numpy.bool_; convert so the result is a plain
        # Python bool (JSON-serialisable, safe for ``is True`` checks).
        return bool(self.clinical_data['FILENAME'].eq(filename).any())

    def get_categorical_options(self):
        """Return ``{column: [allowed values as str]}`` for categorical columns."""
        return {
            column: [str(v) for v in self.mapping_dict[column]]
            for column in self.columns_to_convert
        }

    def preprocess_clinical_input(self, clinical):
        """Convert a raw clinical mapping into the model's clinical feature vector.

        Args:
            clinical: Mapping with one entry per name in ``columns_to_scale``
                (anything ``float()`` accepts) and per name in
                ``columns_to_convert`` (compared via ``str(value)``).

        Returns:
            1-D float array ordered like ``clinical_feature_columns``.

        Raises:
            KeyError: A required column is missing from *clinical*.
            ValueError: A numeric value cannot be parsed, or a categorical
                value was not seen when the CSV was factorized.
        """
        numeric_input = pd.DataFrame([{col: float(clinical[col]) for col in self.columns_to_scale}])
        scaled_numeric = self.scaler.transform(numeric_input[self.columns_to_scale])[0]
        scaled_numeric_map = dict(zip(self.columns_to_scale, scaled_numeric))

        categorical_map = {}
        for column in self.columns_to_convert:
            value = str(clinical[column])
            if value not in self.categorical_code_maps[column]:
                raise ValueError(f"Invalid value for {column}: {value}")
            categorical_map[column] = self.categorical_code_maps[column][value]

        # Every feature column is either scaled-numeric or categorical.
        features = [
            float(scaled_numeric_map[column])
            if column in scaled_numeric_map
            else float(categorical_map[column])
            for column in self.clinical_feature_columns
        ]
        return np.array(features, dtype=float)

    def _resolve_clinical_features(self, clinical=None, filename=None):
        """Produce the clinical feature vector from whichever source is given.

        Args:
            clinical: ``None`` (look up *filename* in the CSV), a dict (run
                through ``preprocess_clinical_input``), or an array-like of
                already-encoded features; a 2-D array contributes its first
                row only.
            filename: ``FILENAME`` key used when *clinical* is ``None``.

        Returns:
            1-D float array, or ``None`` when *filename* has no CSV row.

        Raises:
            ValueError: Neither source was given, or the array-like input does
                not hold exactly ``len(clinical_feature_columns)`` values.
        """
        if clinical is None:
            if not filename:
                raise ValueError("Need clinical data or filename in OAI database")
            clinical_row = self.get_clinical_data(filename)
            if clinical_row.shape[0] == 0:
                return None
            # Skip the first column of the row (presumably FILENAME; confirm
            # against the CSV column order) and keep the rest as floats.
            return clinical_row[0, 1:].astype(float)

        if isinstance(clinical, dict):
            return self.preprocess_clinical_input(clinical)

        clinical_array = np.asarray(clinical, dtype=float)
        if clinical_array.ndim == 2:
            clinical_array = clinical_array[0]
        expected = len(self.clinical_feature_columns)
        # Checking ndim first turns scalars and >2-D inputs into the same
        # ValueError instead of an IndexError or a silently mis-shaped array.
        if clinical_array.ndim != 1 or clinical_array.shape[0] != expected:
            raise ValueError(f"Clinical input must have {expected} features")
        return clinical_array

    def _build_feature_vector(self, overal_diagnosis, jsws, clinical=None, filename=None):
        """Concatenate all inputs and keep only the model's selected features."""
        clinical_features = self._resolve_clinical_features(clinical=clinical, filename=filename)
        if clinical_features is None:
            raise ValueError("clinical data not found")
        full_vector = np.concatenate((overal_diagnosis, clinical_features, np.array(jsws)))
        return full_vector[self.selected_feature_index]

    def predict(self, overal_diagnosis, jsws, clinical=None, filename=None):
        """Predict for one sample.

        Args:
            overal_diagnosis: 1-D array-like placed first in the feature vector.
            jsws: 1-D array-like placed last in the feature vector.
            clinical: See ``_resolve_clinical_features``.
            filename: See ``_resolve_clinical_features``.

        Returns:
            Whatever ``self.model.predict`` returns for a single-row batch.

        Raises:
            ValueError: Clinical data is missing or malformed.
        """
        features = self._build_feature_vector(overal_diagnosis, jsws, clinical=clinical, filename=filename)
        # Pass the batch positionally: scikit-learn style estimators name this
        # parameter ``X``, so the former ``x=`` keyword raised TypeError.
        return self.model.predict(features.reshape(1, -1))

    def predict_explain(self, overal_diagnosis, jsws, clinical=None, filename=None):
        """Predict for one sample and explain the prediction with LIME.

        Takes the same arguments as ``predict``.

        Returns:
            Tuple ``(prediction, explanation)``: the first element of the
            model's prediction for the sample, and the LIME explanation
            object covering all selected features for the top label.

        Raises:
            ValueError: Clinical data is missing or malformed.
        """
        features = self._build_feature_vector(overal_diagnosis, jsws, clinical=clinical, filename=filename)
        exp = self.explainer.explain_instance(
            features,
            self.model.predict_proba,
            num_features=len(self.selected_features),
            top_labels=1,
        )
        return self.model.predict(features.reshape(1, -1))[0], exp