| |
| import numpy as np |
| import joblib |
| import pandas as pd |
| from sklearn.preprocessing import MinMaxScaler |
| from lime import lime_tabular |
|
|
class MLModel():
    """Bundle a pickled classifier with its clinical preprocessing and a LIME explainer.

    The model input is built by concatenating, in this order: the
    ``overal_diagnosis`` values, the 14 clinical features listed in
    ``clinical_feature_columns``, and the ``jsws`` values. That vector is then
    reduced to ``selected_features`` using the column positions those features
    have in the pickled training frame ``X_train``.
    """

    def __init__(
        self,
        model_path='./ML_model/weights/xgboost_convnet_best.pkl',
        clinical_csv_path="./ML_model/data/clinical-2.csv",
        selected_features_path="./ML_model/data/selected_features.pkl",
        x_train_path="./ML_model/data/X_train.pkl",
    ):
        """Load all artifacts, fit the clinical preprocessing and build the explainer.

        Every path defaults to the original hard-coded location. Relative paths
        are resolved against the current working directory.

        Security note: ``joblib.load`` unpickles arbitrary objects; only load
        trusted files.
        """
        self.model = joblib.load(model_path)

        # Keep an untouched copy of the table; the working copy is scaled and
        # integer-encoded in place below.
        self.raw_clinical_data = pd.read_csv(clinical_csv_path)
        self.clinical_data = self.raw_clinical_data.copy()
        del self.clinical_data['ID']
        del self.clinical_data['SIDE']

        # Min-max scale numeric columns. The fitted scaler is reused to
        # transform user-supplied values in preprocess_clinical_input().
        self.columns_to_scale = ['AGE', 'HEIGHT', 'WEIGHT', 'MAX WEIGHT', 'BMI', 'KOOS PAIN SCORE']
        self.scaler = MinMaxScaler()
        self.clinical_data[self.columns_to_scale] = self.scaler.fit_transform(
            self.clinical_data[self.columns_to_scale]
        )
        # Medians of the *unscaled* numeric columns. Not read inside this
        # class; presumably consumed by callers (e.g. as form defaults) — verify.
        self.numeric_defaults = (
            self.raw_clinical_data[self.columns_to_scale].median(numeric_only=True).to_dict()
        )

        # Replace categorical columns by pd.factorize codes (order of first
        # appearance; missing values get code -1). The label -> code table is
        # keyed by str(label) so that user input can be encoded the same way.
        self.columns_to_convert = [
            'FREQUENT PAIN', 'SURGERY', 'RISK', 'SXKOA',
            'SWELLING', 'BENDING FULLY', 'SYMPTOMATIC', 'CREPITUS',
        ]
        self.mapping_dict = {}
        self.categorical_code_maps = {}
        for column in self.columns_to_convert:
            codes, unique_values = pd.factorize(self.clinical_data[column])
            self.clinical_data[column] = codes
            self.mapping_dict[column] = unique_values
            self.categorical_code_maps[column] = {
                str(value): idx for idx, value in enumerate(unique_values)
            }

        # Order of the clinical block inside the model's feature vector.
        self.clinical_feature_columns = [
            'AGE', 'HEIGHT', 'WEIGHT', 'MAX WEIGHT', 'BMI',
            'FREQUENT PAIN', 'SURGERY', 'RISK', 'SXKOA',
            'SWELLING', 'BENDING FULLY', 'SYMPTOMATIC',
            'CREPITUS', 'KOOS PAIN SCORE'
        ]

        # NOTE(review): assumes X_train's column layout equals the vector built
        # in _build_feature_vector (diagnosis, clinical, jsws) — confirm.
        self.selected_features = joblib.load(selected_features_path)
        self.X_train = joblib.load(x_train_path)
        self.selected_feature_index = [
            self.X_train.columns.get_loc(col) for col in self.selected_features
        ]

        # NOTE(review): five class labels; presumably these must match the
        # number of columns returned by self.model.predict_proba — confirm.
        self.explainer = lime_tabular.LimeTabularExplainer(
            self.X_train[self.selected_features].to_numpy(),
            feature_names=self.selected_features,
            class_names=['0', '1', '2', '3', '4'],
            mode='classification'
        )

    def get_clinical_data(self, filename):
        """Return every preprocessed clinical row whose FILENAME equals ``filename``.

        The result is a 2-D numpy array (possibly with zero rows). It holds all
        columns left after dropping ID and SIDE, including FILENAME itself.
        """
        row = self.clinical_data.loc[self.clinical_data['FILENAME'] == filename]
        return row.to_numpy()

    def has_clinical_data(self, filename):
        """Return True if at least one clinical row has FILENAME == ``filename``."""
        # bool(): hand callers a plain Python bool rather than numpy.bool_.
        return bool(self.clinical_data['FILENAME'].eq(filename).any())

    def get_categorical_options(self):
        """Map each categorical column to its known labels, as strings, in code order."""
        return {
            column: [str(v) for v in self.mapping_dict[column]]
            for column in self.columns_to_convert
        }

    def preprocess_clinical_input(self, clinical):
        """Encode a user-supplied clinical dict into the 14-value clinical vector.

        Numeric columns are converted with ``float`` and scaled with the scaler
        fitted in ``__init__``. Categorical columns are mapped to their factorize
        codes via ``str(value)``.

        Raises:
            KeyError: if ``clinical`` lacks one of the required columns.
            ValueError: if a numeric value is not convertible to float, or a
                categorical value is not one of the known labels.
        """
        numeric_input = pd.DataFrame([{col: float(clinical[col]) for col in self.columns_to_scale}])
        scaled_row = self.scaler.transform(numeric_input[self.columns_to_scale])[0]
        scaled_numeric_map = dict(zip(self.columns_to_scale, scaled_row))

        categorical_map = {}
        for column in self.columns_to_convert:
            value = str(clinical[column])
            codes = self.categorical_code_maps[column]
            if value not in codes:
                raise ValueError(f"Invalid value for {column}: {value}")
            categorical_map[column] = codes[value]

        # Numeric columns take precedence, matching the lookup order used
        # when assembling the vector.
        return np.array(
            [
                float(scaled_numeric_map[col]) if col in scaled_numeric_map
                else float(categorical_map[col])
                for col in self.clinical_feature_columns
            ],
            dtype=float,
        )

    def _resolve_clinical_features(self, clinical=None, filename=None):
        """Produce the 1-D clinical feature vector from one of three input forms.

        - ``clinical is None``: look up ``filename`` in the clinical table and
          return that row minus its first column (cast to float). Returns None
          when no row matches.
        - ``clinical`` is a dict: encode it with preprocess_clinical_input().
        - otherwise: treat ``clinical`` as an already-encoded vector (1-D, or
          2-D whose first row is used) of length len(clinical_feature_columns).

        Raises:
            ValueError: if neither input is given, or the vector has the wrong shape.
        """
        if clinical is None:
            if not filename:
                raise ValueError("Need clinical data or filename in OAI database")
            clinical_row = self.get_clinical_data(filename)
            if clinical_row.shape[0] == 0:
                return None
            # NOTE(review): assumes FILENAME is the first remaining CSV column
            # and the rest follow clinical_feature_columns order — confirm.
            return clinical_row[0, 1:].astype(float)

        if isinstance(clinical, dict):
            return self.preprocess_clinical_input(clinical)

        clinical_array = np.asarray(clinical, dtype=float)
        if clinical_array.ndim == 2:
            clinical_array = clinical_array[0]

        expected = len(self.clinical_feature_columns)
        # The ndim check turns scalar / higher-rank input into a clear
        # ValueError instead of an IndexError or a silently wrong shape.
        if clinical_array.ndim != 1 or clinical_array.shape[0] != expected:
            raise ValueError(
                f"Clinical input must have {expected} features, got shape {clinical_array.shape}"
            )

        return clinical_array

    def _build_feature_vector(self, overal_diagnosis, jsws, clinical=None, filename=None):
        """Assemble [diagnosis, clinical, jsws] and keep only the selected features.

        ``overal_diagnosis`` and ``jsws`` are flattened with np.ravel, so scalars,
        1-D and (1, k) array-likes are all accepted.

        Raises:
            ValueError: if clinical data is missing or invalid, or no row matches ``filename``.
        """
        clinical_features = self._resolve_clinical_features(clinical=clinical, filename=filename)
        if clinical_features is None:
            raise ValueError("clinical data not found")

        full_vector = np.concatenate(
            (np.ravel(overal_diagnosis), clinical_features, np.ravel(jsws))
        )
        return full_vector[self.selected_feature_index]

    def predict(self, overal_diagnosis, jsws, clinical=None, filename=None):
        """Return the model's prediction array (one entry) for a single sample.

        Raises:
            ValueError: see _build_feature_vector().
        """
        x = self._build_feature_vector(overal_diagnosis, jsws, clinical=clinical, filename=filename)
        # Pass the sample positionally: scikit-learn-style estimators (including
        # XGBoost's) name this parameter ``X``, so the former ``x=`` keyword raised TypeError.
        return self.model.predict(x.reshape(1, -1))

    def predict_explain(self, overal_diagnosis, jsws, clinical=None, filename=None):
        """Return ``(prediction, lime_explanation)`` for a single sample.

        The explanation covers all selected features for the top predicted label.

        Raises:
            ValueError: see _build_feature_vector().
        """
        x = self._build_feature_vector(overal_diagnosis, jsws, clinical=clinical, filename=filename)

        exp = self.explainer.explain_instance(
            x,
            self.model.predict_proba,
            num_features=len(self.selected_features),
            top_labels=1,
        )
        return self.model.predict(x.reshape(1, -1))[0], exp
|
|