File size: 5,588 Bytes
036e7c4
 
 
 
 
 
 
 
 
 
 
7650763
 
036e7c4
 
 
 
 
 
7650763
036e7c4
 
 
7650763
036e7c4
 
 
7650763
 
 
 
 
 
 
 
036e7c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7650763
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
036e7c4
7650763
 
 
 
 
 
 
 
 
 
 
 
 
 
 
036e7c4
7650763
 
 
 
 
 
 
 
036e7c4
 
 
 
7650763
 
 
036e7c4
7650763
036e7c4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# from sklearn.ensemble import RandomForestClassifier
import numpy as np
import joblib
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from lime import lime_tabular

class MLModel():
    """Knee X-ray severity model wrapper.

    Loads a pre-trained XGBoost-style classifier together with the OAI
    clinical table, reproduces the training-time preprocessing (min-max
    scaling of numeric columns, factorization of categorical ones), and
    exposes prediction plus LIME-based explanation helpers.
    """

    def __init__(self):
        # Pickled, pre-trained classifier (sklearn-style estimator API).
        self.model = joblib.load('./ML_model/weights/xgboost_convnet_best.pkl')

        # Keep the raw table so unscaled statistics (medians) can be derived;
        # the working copy drops identifier columns not used as features.
        self.raw_clinical_data = pd.read_csv("./ML_model/data/clinical-2.csv")
        self.clinical_data = self.raw_clinical_data.copy()
        del self.clinical_data['ID']
        del self.clinical_data['SIDE']

        # Fit the scaler on the full table; the same fitted scaler is reused
        # for user-supplied rows in preprocess_clinical_input().
        self.columns_to_scale = ['AGE', 'HEIGHT', 'WEIGHT', 'MAX WEIGHT', 'BMI', 'KOOS PAIN SCORE']
        self.scaler = MinMaxScaler()
        self.clinical_data[self.columns_to_scale] = self.scaler.fit_transform(self.clinical_data[self.columns_to_scale])
        # Unscaled per-column medians. Unused inside this class; presumably
        # consumed by callers needing defaults for missing numeric fields —
        # TODO(review): confirm, or remove.
        self.numeric_defaults = self.raw_clinical_data[self.columns_to_scale].median(numeric_only=True).to_dict()

        # Factorize categoricals and remember the label<->code mapping in
        # both directions (labels list + string-label -> integer code).
        self.columns_to_convert = ['FREQUENT PAIN', 'SURGERY', 'RISK', 'SXKOA', 'SWELLING', 'BENDING FULLY', 'SYMPTOMATIC', 'CREPITUS']
        self.mapping_dict = {}
        self.categorical_code_maps = {}
        for column in self.columns_to_convert:
            self.clinical_data[column], unique_values = pd.factorize(self.clinical_data[column])
            self.mapping_dict[column] = unique_values
            self.categorical_code_maps[column] = {str(value): idx for idx, value in enumerate(unique_values)}

        # Canonical ordering of the 14 clinical features fed to the model.
        self.clinical_feature_columns = [
            'AGE', 'HEIGHT', 'WEIGHT', 'MAX WEIGHT', 'BMI',
            'FREQUENT PAIN', 'SURGERY', 'RISK', 'SXKOA',
            'SWELLING', 'BENDING FULLY', 'SYMPTOMATIC',
            'CREPITUS', 'KOOS PAIN SCORE'
        ]

        self.selected_features = joblib.load("./ML_model/data/selected_features.pkl")
        self.X_train = joblib.load("./ML_model/data/X_train.pkl")
        # Positions of the selected features inside the full concatenated
        # (diagnosis + clinical + JSW) feature vector.
        self.selected_feature_index = [self.X_train.columns.get_loc(col) for col in self.selected_features]

        self.explainer = lime_tabular.LimeTabularExplainer(
            self.X_train[self.selected_features].to_numpy(),
            feature_names=self.selected_features,
            class_names=['0', '1', '2', '3', '4'],
            mode='classification'
        )

    def get_clinical_data(self, filename):
        """Return preprocessed clinical row(s) matching *filename* as a 2-D
        numpy array (empty when the filename is unknown).

        NOTE: the FILENAME column is still present as column 0; callers
        (see _resolve_clinical_features) are responsible for stripping it.
        """
        row = self.clinical_data.loc[self.clinical_data['FILENAME'] == filename]
        return row.to_numpy()

    def has_clinical_data(self, filename):
        """Return True if *filename* exists in the OAI clinical table."""
        return self.clinical_data['FILENAME'].eq(filename).any()

    def get_categorical_options(self):
        """Map each categorical column name to its list of valid string labels."""
        return {
            column: [str(v) for v in self.mapping_dict[column]]
            for column in self.columns_to_convert
        }

    def preprocess_clinical_input(self, clinical):
        """Convert a dict of raw clinical values into the scaled/encoded
        14-element float vector the model expects.

        Raises:
            ValueError: if a categorical value is not a known label.
            KeyError: if a required column is missing from *clinical*.
        """
        numeric_input = pd.DataFrame([{col: float(clinical[col]) for col in self.columns_to_scale}])
        scaled_numeric = self.scaler.transform(numeric_input[self.columns_to_scale])[0]
        scaled_numeric_map = {
            col: scaled_numeric[idx]
            for idx, col in enumerate(self.columns_to_scale)
        }

        categorical_map = {}
        for column in self.columns_to_convert:
            value = str(clinical[column])
            if value not in self.categorical_code_maps[column]:
                raise ValueError(f"Invalid value for {column}: {value}")
            categorical_map[column] = self.categorical_code_maps[column][value]

        # Assemble in the canonical training-time column order.
        features = []
        for column in self.clinical_feature_columns:
            if column in scaled_numeric_map:
                features.append(float(scaled_numeric_map[column]))
            else:
                features.append(float(categorical_map[column]))

        return np.array(features, dtype=float)

    def _resolve_clinical_features(self, clinical=None, filename=None):
        """Normalize the accepted clinical inputs (dict, array-like, or OAI
        filename) into a flat float vector of 14 features.

        Returns None when *filename* was given but not found in the table.
        Raises ValueError when neither input is supplied, a categorical
        label is invalid, or the array has the wrong length.
        """
        if clinical is None:
            if not filename:
                raise ValueError("Need clinical data or filename in OAI database")
            clinical_row = self.get_clinical_data(filename)
            if clinical_row.shape[0] == 0:
                return None
            # Drop the leading FILENAME column (column 0).
            return clinical_row[0, 1:].astype(float)

        if isinstance(clinical, dict):
            return self.preprocess_clinical_input(clinical)

        clinical_array = np.asarray(clinical, dtype=float)
        if clinical_array.ndim == 2:
            # Accept a single-row 2-D array and flatten it.
            clinical_array = clinical_array[0]

        if clinical_array.shape[0] != len(self.clinical_feature_columns):
            raise ValueError("Clinical input must have 14 features")

        return clinical_array

    def predict(self, overal_diagnosis, jsws, clinical=None, filename=None):
        """Predict the severity class from diagnosis scores, clinical
        features, and JSW measurements.

        Returns the model's prediction array for one sample.
        Raises ValueError when clinical data cannot be resolved.
        """
        clinical_features = self._resolve_clinical_features(clinical=clinical, filename=filename)
        if clinical_features is None:
            raise ValueError("clinical data not found")

        x = np.concatenate((overal_diagnosis, clinical_features, np.array(jsws)))

        # BUGFIX: pass the (1, n_features) matrix positionally — sklearn-style
        # predict() takes `X`, not a keyword named `x`; this also matches the
        # positional call in predict_explain() below.
        return self.model.predict(x[self.selected_feature_index].reshape(1, -1))

    def predict_explain(self, overal_diagnosis, jsws, clinical=None, filename=None):
        """Like predict(), but also return a LIME explanation.

        Returns a tuple (predicted_label, lime_explanation).
        Raises ValueError when clinical data cannot be resolved.
        """
        clinical_features = self._resolve_clinical_features(clinical=clinical, filename=filename)
        if clinical_features is None:
            raise ValueError("clinical data not found")

        x = np.concatenate((overal_diagnosis, clinical_features, np.array(jsws)))[self.selected_feature_index]

        exp = self.explainer.explain_instance(x, self.model.predict_proba, num_features=len(self.selected_features), top_labels=1)
        return self.model.predict(x.reshape(1, -1))[0], exp