import numpy as np
import joblib
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from lime import lime_tabular


class MLModel:
    """Wrap a pickled classifier that combines image, clinical and JSW features.

    Construction loads, relative to the current working directory:
    the pickled model, the clinical CSV (scaled/encoded in memory), the list
    of selected feature names and the training matrix used to build a LIME
    tabular explainer.
    """

    def __init__(self):
        self.model = joblib.load('./ML_model/weights/xgboost_convnet_best.pkl')
        self.raw_clinical_data = pd.read_csv("./ML_model/data/clinical-2.csv")
        self.clinical_data = self.raw_clinical_data.copy()
        # ID and SIDE are not model inputs.
        del self.clinical_data['ID']
        del self.clinical_data['SIDE']

        self.columns_to_scale = ['AGE', 'HEIGHT', 'WEIGHT', 'MAX WEIGHT', 'BMI', 'KOOS PAIN SCORE']
        # Fit once on the whole clinical table; preprocess_clinical_input reuses
        # this scaler so user-supplied values are scaled the same way.
        self.scaler = MinMaxScaler()
        self.clinical_data[self.columns_to_scale] = self.scaler.fit_transform(
            self.clinical_data[self.columns_to_scale]
        )
        # Medians of the *unscaled* numeric columns (not used inside this class).
        self.numeric_defaults = (
            self.raw_clinical_data[self.columns_to_scale].median(numeric_only=True).to_dict()
        )

        self.columns_to_convert = [
            'FREQUENT PAIN', 'SURGERY', 'RISK', 'SXKOA', 'SWELLING',
            'BENDING FULLY', 'SYMPTOMATIC', 'CREPITUS',
        ]
        self.mapping_dict = {}  # column -> unique values, indexed by integer code
        self.categorical_code_maps = {}  # column -> {str(value): integer code}
        for column in self.columns_to_convert:
            codes, unique_values = pd.factorize(self.clinical_data[column])
            self.clinical_data[column] = codes
            self.mapping_dict[column] = unique_values
            self.categorical_code_maps[column] = {
                str(value): idx for idx, value in enumerate(unique_values)
            }

        # Order of clinical features expected by the model (see preprocess_clinical_input).
        self.clinical_feature_columns = [
            'AGE', 'HEIGHT', 'WEIGHT', 'MAX WEIGHT', 'BMI', 'FREQUENT PAIN',
            'SURGERY', 'RISK', 'SXKOA', 'SWELLING', 'BENDING FULLY',
            'SYMPTOMATIC', 'CREPITUS', 'KOOS PAIN SCORE',
        ]
        self.selected_features = joblib.load("./ML_model/data/selected_features.pkl")
        self.X_train = joblib.load("./ML_model/data/X_train.pkl")
        # Positions of the selected features within the full concatenated vector.
        self.selected_feature_index = [
            self.X_train.columns.get_loc(col) for col in self.selected_features
        ]
        self.explainer = lime_tabular.LimeTabularExplainer(
            self.X_train[self.selected_features].to_numpy(),
            feature_names=self.selected_features,
            class_names=['0', '1', '2', '3', '4'],
            mode='classification',
        )

    def get_clinical_data(self, filename):
        """Return the preprocessed clinical row(s) whose FILENAME equals *filename*.

        The result is ``DataFrame.to_numpy()`` of the matching rows and still
        contains the FILENAME column; an empty array means no match.
        """
        row = self.clinical_data.loc[self.clinical_data['FILENAME'] == filename]
        return row.to_numpy()

    def has_clinical_data(self, filename):
        """Return True if at least one clinical row has FILENAME == *filename*."""
        return bool(self.clinical_data['FILENAME'].eq(filename).any())

    def get_categorical_options(self):
        """Return, per categorical column, the accepted values as strings (in code order)."""
        return {
            column: [str(v) for v in self.mapping_dict[column]]
            for column in self.columns_to_convert
        }

    def preprocess_clinical_input(self, clinical):
        """Convert a user-supplied clinical dict into the model's clinical feature vector.

        Args:
            clinical: mapping with a value for every name in ``columns_to_scale``
                (convertible to float) and every name in ``columns_to_convert``
                (compared by ``str()`` against the values seen in the CSV).

        Returns:
            1-D float array ordered like ``clinical_feature_columns``.

        Raises:
            KeyError: a required column is missing from *clinical*.
            ValueError: a numeric value cannot be converted to float, or a
                categorical value was not seen in the CSV.
        """
        numeric_input = pd.DataFrame([{col: float(clinical[col]) for col in self.columns_to_scale}])
        scaled_numeric = self.scaler.transform(numeric_input[self.columns_to_scale])[0]
        scaled_numeric_map = {
            col: scaled_numeric[idx] for idx, col in enumerate(self.columns_to_scale)
        }

        categorical_map = {}
        for column in self.columns_to_convert:
            value = str(clinical[column])
            if value not in self.categorical_code_maps[column]:
                raise ValueError(f"Invalid value for {column}: {value}")
            categorical_map[column] = self.categorical_code_maps[column][value]

        features = [
            float(scaled_numeric_map[column]) if column in scaled_numeric_map
            else float(categorical_map[column])
            for column in self.clinical_feature_columns
        ]
        return np.array(features, dtype=float)

    def _resolve_clinical_features(self, clinical=None, filename=None):
        """Return the clinical feature vector from *clinical* or a CSV lookup by *filename*.

        *clinical* may be a dict (see preprocess_clinical_input) or an
        already-encoded array of length ``len(clinical_feature_columns)``
        (a 2-D input uses its first row). When *clinical* is None the row is
        looked up by *filename*; returns None if no row matches.
        """
        if clinical is None:
            if not filename:
                raise ValueError("Need clinical data or filename in OAI database")
            rows = self.clinical_data.loc[self.clinical_data['FILENAME'] == filename]
            if rows.shape[0] == 0:
                return None
            # Drop FILENAME by name instead of assuming it is the first column.
            return rows.drop(columns=['FILENAME']).to_numpy()[0].astype(float)

        if isinstance(clinical, dict):
            return self.preprocess_clinical_input(clinical)

        clinical_array = np.asarray(clinical, dtype=float)
        if clinical_array.ndim == 2:
            clinical_array = clinical_array[0]
        expected = len(self.clinical_feature_columns)
        if clinical_array.shape[0] != expected:
            raise ValueError(f"Clinical input must have {expected} features")
        return clinical_array

    def _build_feature_vector(self, overal_diagnosis, jsws, clinical=None, filename=None):
        """Concatenate image, clinical and JSW features and keep only the selected ones.

        Raises ValueError("clinical data not found") when a filename lookup misses.
        """
        clinical_features = self._resolve_clinical_features(clinical=clinical, filename=filename)
        if clinical_features is None:
            raise ValueError("clinical data not found")
        # ravel so lists / (1, k) arrays concatenate like the 1-D arrays used before.
        x = np.concatenate((np.ravel(overal_diagnosis), clinical_features, np.ravel(jsws)))
        return x[self.selected_feature_index]

    def predict(self, overal_diagnosis, jsws, clinical=None, filename=None):
        """Return ``self.model.predict`` on one assembled sample (array of length 1)."""
        x = self._build_feature_vector(overal_diagnosis, jsws, clinical=clinical, filename=filename)
        # Passed positionally: sklearn/XGBoost estimators name this parameter ``X``,
        # so the former ``x=`` keyword did not match; this mirrors predict_explain.
        return self.model.predict(x.reshape(1, -1))

    def predict_explain(self, overal_diagnosis, jsws, clinical=None, filename=None):
        """Return ``(predicted_label, lime_explanation)`` for one assembled sample."""
        x = self._build_feature_vector(overal_diagnosis, jsws, clinical=clinical, filename=filename)
        exp = self.explainer.explain_instance(
            x,
            self.model.predict_proba,
            num_features=len(self.selected_features),
            top_labels=1,
        )
        return self.model.predict(x.reshape(1, -1))[0], exp