File size: 17,084 Bytes
5dc5516
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f51f04a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
import gradio as gr
import pickle
import pandas as pd
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import BaggingClassifier, RandomForestClassifier
from xgboost import XGBClassifier

# --- Module-level state ---
# Preprocessing artifacts per model family, filled in by the loader functions
# (keys used in this file: 'imputer', 'cols_to_impute', 'features').
LOADED_OBJECTS_LOGREG, LOADED_OBJECTS_TREE = {}, {}
# Unpickled estimators, keyed by a snake_case model name.
MODEL_CACHE = {}
# Feature names in the order the input widgets are created for each family.
logreg_ui_feature_order, tree_ui_feature_order = [], []
# The input widgets themselves, one per entry of the matching feature order.
logreg_inputs_ui_components, tree_inputs_ui_components = [], []

# --- Default sample inputs ---
# Each mapping pre-fills the input widgets of one model family; its keys are
# expected to match the corresponding UI feature order. A feature without an
# entry here is looked up with a None fallback when the widgets are built.
# Entries tagged "imputable" are the ones the original author marked as
# candidates for imputation.

# Sample row for the Logistic Regression inputs
# (keys should match `logreg_ui_feature_order`).
DEFAULT_LOGREG_INPUTS = {
    'AGE': 28,
    'AVG_SPEED': 4.1,            # imputable
    'DIST_MILES': 2.6,           # imputable
    'DRIVE_FGA': 3.0,
    'GP': 70,
    'PACE': 99.5,
    'PAINT_TOUCHES': 5.5,
    'PLAYER_HEIGHT_INCHES': 78,  # imputable
    'PLAYER_WEIGHT': 215,        # imputable
    'POSS': 7100,
    'PULL_UP_FG3A': 1.8,         # imputable
    # Extend with any further feature that ends up in logreg_ui_feature_order.
}

# Sample row for the tree-based model inputs
# (keys should match `tree_ui_feature_order`; this list is the longer one).
DEFAULT_TREE_INPUTS = {
    'AGE': 26,
    'PLAYER_HEIGHT_INCHES': 79,  # imputable
    'PLAYER_WEIGHT': 220,        # imputable
    'GP': 65,
    'MIN': 32.0,
    'USG_PCT': 0.22,
    'PACE': 100.0,
    'POSS': 7200,
    'FGA_PG': 16.0,
    'DRIVES': 11.0,
    'DRIVE_FGA': 5.5,
    'DRIVE_PASSES': 5.0,
    'DIST_MILES': 2.7,           # imputable
    'AVG_SPEED': 4.2,            # imputable
    'PULL_UP_FGA': 5.5,
    'PULL_UP_FG3A': 2.2,         # imputable
    'TOUCHES': 65.0,
    'FRONT_CT_TOUCHES': 33.0,
    'AVG_SEC_PER_TOUCH': 2.8,
    'AVG_DRIB_PER_TOUCH': 2.1,
    'ELBOW_TOUCHES': 2.3,
    'POST_TOUCHES': 3.5,
    'PAINT_TOUCHES': 6.0,
    # Extend with any further feature that ends up in tree_ui_feature_order.
}


# --- Loading Functions (Modified to define UI feature orders earlier) ---
def load_logreg_dependencies():
    """Load the Logistic Regression artifacts and derive its UI feature order.

    Unpickles the imputer, the imputable-column list, the stepwise feature
    list and the fitted model from the current working directory. On success
    the three preprocessing artifacts are stored in ``LOADED_OBJECTS_LOGREG``,
    the model in ``MODEL_CACHE['logistic_regression']``, and
    ``logreg_ui_feature_order`` becomes the sorted union of the model features
    and the imputable columns.

    Everything is staged locally and committed to the module-level caches only
    once every file has loaded. A failed attempt therefore leaves the caches
    empty, so a later call retries instead of reporting a half-loaded state
    as ready.

    Returns:
        bool: True if the dependencies are (or already were) loaded, False if
        a file was missing or any other error occurred while loading.
    """
    global logreg_ui_feature_order
    if LOADED_OBJECTS_LOGREG:
        # Already committed by an earlier successful call.
        return True

    print("Loading Logistic Regression dependencies...")
    artifact_files = {
        'imputer': 'simple_imputer_logreg.pkl',
        'cols_to_impute': 'cols_to_impute_logreg.pkl',
        'features': 'stepwise_features_logreg.pkl',
    }
    try:
        # NOTE: pickle.load can execute arbitrary code during loading; these
        # files must come from a trusted source.
        staged = {}
        for key, path in artifact_files.items():
            with open(path, 'rb') as f:
                staged[key] = pickle.load(f)
        with open('logisticregression.pkl', 'rb') as f:
            model = pickle.load(f)

        # Coerce to lists before concatenating: `+` is only concatenation for
        # plain lists, and the pickles may hold another sequence type.
        ui_order = sorted(set(list(staged['features']) + list(staged['cols_to_impute'])))
    except FileNotFoundError as e:
        print(f"Error loading LogReg file: {e.filename}")
        gr.Warning(f"Required file for Logistic Regression not found: {e.filename}. This model will not work.")
        return False
    except Exception as e:
        print(f"Error loading LogReg dependencies: {e}")
        gr.Warning(f"Error loading dependencies for Logistic Regression: {e}")
        return False

    # Commit only after every step above succeeded.
    LOADED_OBJECTS_LOGREG.update(staged)
    MODEL_CACHE['logistic_regression'] = model
    logreg_ui_feature_order = ui_order
    print("Logistic Regression dependencies loaded successfully.")
    print(f"LogReg UI features order set: {logreg_ui_feature_order}")
    return True

def load_tree_dependencies():
    """Load the tree-model artifacts and all four tree-based estimators.

    Unpickles the imputer, the imputable-column list and the model feature
    list, then the Decision Tree, Bagging, XGBoost and Random Forest models,
    all from the current working directory. On success the preprocessing
    artifacts are stored in ``LOADED_OBJECTS_TREE``, each model is stored in
    ``MODEL_CACHE`` under its lower-cased, underscore-joined display name, and
    ``tree_ui_feature_order`` is set to the model feature list.

    Everything is staged locally and committed to the module-level caches only
    once every file has loaded. A failed attempt therefore leaves the caches
    untouched, so a later call retries instead of reporting a half-loaded
    state as ready.

    Returns:
        bool: True if the dependencies are (or already were) loaded, False if
        a file was missing or any other error occurred while loading.
    """
    global tree_ui_feature_order
    if LOADED_OBJECTS_TREE:
        # Already committed by an earlier successful call.
        return True

    print("Loading Tree-based model dependencies...")
    artifact_files = {
        'imputer': 'simple_imputer_tree.pkl',
        'cols_to_impute': 'cols_to_impute_tree.pkl',
        'features': 'tree_model_feature_columns.pkl',
    }
    tree_model_files = {
        "Optimized Decision Tree": "optimized_decision_tree_model.pkl",
        "Bagging Classifier": "bagging_classifier_model.pkl",
        "XGBoost Classifier": "xgboost_classifier_model.pkl",
        "Random Forest Classifier": "random_forest_classifier_model.pkl",
    }
    try:
        # NOTE: pickle.load can execute arbitrary code during loading; these
        # files must come from a trusted source.
        staged_objects = {}
        for key, path in artifact_files.items():
            with open(path, 'rb') as f:
                staged_objects[key] = pickle.load(f)

        staged_models = {}
        for name, path in tree_model_files.items():
            with open(path, 'rb') as f:
                staged_models[name.lower().replace(" ", "_")] = pickle.load(f)

        # Keep the UI order as a plain list: it is truth-tested when the
        # widgets are built, which is ambiguous for array-like containers.
        ui_order = list(staged_objects['features'])
    except FileNotFoundError as e:
        print(f"Error loading Tree model file: {e.filename}")
        gr.Warning(f"Required file for Tree models not found: {e.filename}. Some tree models may not work.")
        return False
    except Exception as e:
        print(f"Error loading Tree dependencies: {e}")
        gr.Warning(f"Error loading dependencies for Tree models: {e}")
        return False

    # Commit only after every step above succeeded.
    LOADED_OBJECTS_TREE.update(staged_objects)
    MODEL_CACHE.update(staged_models)
    tree_ui_feature_order = ui_order
    print("Tree-based model dependencies loaded successfully.")
    print(f"Tree UI features order set: {tree_ui_feature_order}")
    return True

# Load all artifacts once at import time. The resulting flags are checked when
# the input widgets are built, when the model dropdown is populated, and at the
# start of every prediction.
LOGREG_READY = load_logreg_dependencies()
TREE_READY = load_tree_dependencies()

# --- Prediction Function ---
def _parse_ui_number(value):
    """Convert one raw widget value to a float, mapping blank input to NaN.

    ``None`` and empty or whitespace-only strings become ``np.nan``. Anything
    else goes through ``float()``, which raises ValueError or TypeError for
    non-numeric input.
    """
    if value is None or (isinstance(value, str) and value.strip() == ""):
        return np.nan
    return float(value)


def predict_injury(model_choice, *all_ui_values):
    """Run the selected model on the values entered in the UI.

    Args:
        model_choice: Display name of the model picked in the dropdown.
        *all_ui_values: Values of every input widget, positionally: the
            Logistic Regression inputs first, then the tree-model inputs.
            Only the slice belonging to the chosen model family is used.

    Returns:
        A 3-tuple of strings ``(result, details, status)``. On success these
        are the predicted label, the class probabilities and
        ``"Prediction successful."``. On any failure the first two carry an
        error description and the status is empty. No exception is raised for
        bad input, a missing model, or a failing pipeline step.
    """
    print(f"Model choice: {model_choice}")

    # Resolve the artifacts, model and slice of UI values for the chosen family.
    if model_choice == "Logistic Regression":
        if not LOGREG_READY or 'logistic_regression' not in MODEL_CACHE:
            return "Logistic Regression model not loaded. Check logs.", "", ""
        artifacts = LOADED_OBJECTS_LOGREG
        model = MODEL_CACHE['logistic_regression']
        ui_feature_order = logreg_ui_feature_order
        ui_values = all_ui_values[:len(logreg_ui_feature_order)]
    elif model_choice in ("Optimized Decision Tree", "Bagging Classifier",
                          "XGBoost Classifier", "Random Forest Classifier"):
        cache_key = model_choice.lower().replace(" ", "_")
        if not TREE_READY or cache_key not in MODEL_CACHE:
            return f"{model_choice} model not loaded. Check logs.", "", ""
        artifacts = LOADED_OBJECTS_TREE
        model = MODEL_CACHE[cache_key]
        ui_feature_order = tree_ui_feature_order
        # Everything after the LogReg inputs belongs to the tree models.
        ui_values = all_ui_values[len(logreg_ui_feature_order):]
    else:
        return "Invalid model choice.", "", ""

    imputer = artifacts['imputer']
    cols_to_impute_for_model = artifacts['cols_to_impute']
    final_model_features = artifacts['features']

    # zip() would silently drop features or values on a length mismatch, so
    # surface a wiring problem instead of predicting on a truncated row.
    if len(ui_values) != len(ui_feature_order):
        detail = f"expected {len(ui_feature_order)} input values, got {len(ui_values)}"
        print(f"Input count mismatch for {model_choice}: {detail}")
        return f"Internal error: {detail}.", "Input count mismatch", ""

    raw_input_data = dict(zip(ui_feature_order, ui_values))
    print(f"Raw input data mapped for model: {raw_input_data}")

    # Coerce every value to float. Missing values (blank or unparsable) are
    # only acceptable for columns the imputer will fill in; for any other
    # column they would reach the model as NaN, so reject them up front.
    processed_values = {}
    for key, value in raw_input_data.items():
        try:
            number = _parse_ui_number(value)
        except (ValueError, TypeError):
            number = np.nan
        if np.isnan(number) and key not in cols_to_impute_for_model:
            return f"Invalid input for '{key}'. Please provide a number.", f"Error with input '{key}'", ""
        processed_values[key] = number

    # Single-row frame with columns in UI order, so the column selections
    # below address known names.
    df = pd.DataFrame([processed_values]).reindex(columns=ui_feature_order)
    print("DataFrame before imputation:\n", df.head())

    # Imputation
    try:
        imputable_cols_in_df = [col for col in cols_to_impute_for_model if col in df.columns]
        if imputable_cols_in_df:
            # Transform a copy of the subset, then write the result back.
            df[imputable_cols_in_df] = imputer.transform(df[imputable_cols_in_df].copy())
        print("DataFrame after imputation:\n", df.head())
    except Exception as e:
        print(f"Error during imputation: {e}")
        return f"Error during data imputation: {e}", str(e), ""

    # Feature selection and ordering for the final model input
    try:
        final_df = df[final_model_features]
        print("Final DataFrame for model prediction:\n", final_df.head())
    except KeyError as e:
        print(f"Error: Missing feature column for the model's final input: {e}")
        return f"Internal error: Missing expected feature {e}. Check model feature lists.", f"KeyError: {e}", ""
    except Exception as e:
        print(f"Error during final feature selection: {e}")
        return f"Error during data preparation: {e}", str(e), ""

    # Prediction
    try:
        pred_proba = model.predict_proba(final_df)[0]
        prediction = model.predict(final_df)[0]
        # Index 0 is reported as "no injury" and index 1 as "injury".
        proba_not_injured, proba_injured = pred_proba[0], pred_proba[1]
        if prediction == 1:
            result_text = "Prediction: Likely Injured"
            probability_text = f"Probability of Injury: {proba_injured:.2%}\nProbability of No Injury: {proba_not_injured:.2%}"
        else:
            result_text = "Prediction: Likely Not Injured"
            probability_text = f"Probability of No Injury: {proba_not_injured:.2%}\nProbability of Injury: {proba_injured:.2%}"
        return result_text, probability_text, "Prediction successful."
    except Exception as e:
        print(f"Error during prediction: {e}")
        return f"Error during prediction: {e}", str(e), ""


# --- Create Gradio UI components with default values ---
def _make_number_inputs(feature_order, defaults, imputable_cols, id_prefix):
    """Return one gr.Number per feature, pre-filled from ``defaults``.

    The label is the title-cased feature name with underscores replaced by
    spaces; features listed in ``imputable_cols`` get a hint that they may be
    left blank. The element id is ``<id_prefix>_<feature>``.
    """
    widgets = []
    for feature in feature_order:
        label = feature.replace("_", " ").title()
        if feature in imputable_cols:
            label += " (can be blank for mean imputation)"
        widgets.append(
            gr.Number(label=label, value=defaults.get(feature, None), elem_id=f"{id_prefix}_{feature}")
        )
    return widgets


# Logistic Regression inputs: built only when loading succeeded and produced
# a non-empty feature order; otherwise log why nothing was built.
if not (LOGREG_READY and logreg_ui_feature_order):
    if not LOGREG_READY:
        print("Logistic Regression dependencies not loaded.")
    if not logreg_ui_feature_order:
        print("LogReg UI feature order not set.")
else:
    print(f"Generating UI components for LogReg with defaults. UI Features: {logreg_ui_feature_order}")
    logreg_inputs_ui_components.extend(
        _make_number_inputs(
            logreg_ui_feature_order,
            DEFAULT_LOGREG_INPUTS,
            LOADED_OBJECTS_LOGREG.get('cols_to_impute', []),
            "logreg",
        )
    )

# Tree-model inputs: same rules as above.
if not (TREE_READY and tree_ui_feature_order):
    if not TREE_READY:
        print("Tree model dependencies not loaded.")
    if not tree_ui_feature_order:
        print("Tree UI feature order not set.")
else:
    print(f"Generating UI components for Tree Models with defaults. UI Features: {tree_ui_feature_order}")
    tree_inputs_ui_components.extend(
        _make_number_inputs(
            tree_ui_feature_order,
            DEFAULT_TREE_INPUTS,
            LOADED_OBJECTS_TREE.get('cols_to_impute', []),
            "tree",
        )
    )


# --- Create Gradio Interface ---
# Lays out the page: title, model selector, output widgets, one input column
# per model family (only one visible at a time), the predict button, and the
# event wiring between them.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("<h1>NBA Player Injury Predictor</h1>", elem_id="main_title")
    gr.Markdown("Select a model and enter player statistics. Fields are pre-filled with sample data. Adjust as needed. "
                "Blank fields for certain stats will be mean-imputed by the model.", elem_id="sub_title")

    # Offer only the model families whose dependencies loaded AND whose input
    # widgets were actually created.
    model_choices = []
    if LOGREG_READY and logreg_inputs_ui_components: model_choices.append("Logistic Regression")
    if TREE_READY and tree_inputs_ui_components: model_choices.extend(["Optimized Decision Tree", "Bagging Classifier", "XGBoost Classifier", "Random Forest Classifier"])

    if not model_choices:
        # Nothing usable: show an error banner instead of the form.
        gr.Markdown("## CRITICAL ERROR: No models could be loaded or UI components generated. Check console logs.")
    else:
        # The first available model is preselected.
        model_selector = gr.Dropdown(choices=model_choices, label="Select Model", value=model_choices[0], elem_id="model_selector_dropdown")

        # Output components, in the order predict_injury returns its values.
        output_prediction = gr.Label(label="Prediction Result", elem_id="prediction_label")
        output_probability = gr.Textbox(label="Probabilities", lines=2, interactive=False, elem_id="probability_textbox")
        # Receives the third return value (status text); created hidden.
        output_status = gr.Textbox(label="Status", lines=1, interactive=False, elem_id="status_textbox", visible=False)


        # Initial column visibility follows the preselected model.
        initial_model_is_logreg = model_choices[0] == "Logistic Regression"

        # Input column for Logistic Regression. The widgets were created
        # earlier at module level and are placed here via render().
        with gr.Column(visible=initial_model_is_logreg, elem_id="logreg_inputs_column") as logreg_inputs_group:
            gr.Markdown("<h3>Logistic Regression Inputs</h3>", elem_id="logreg_header")
            if LOGREG_READY and logreg_inputs_ui_components:
                with gr.Accordion("Show/Hide Logistic Regression Stats", open=True):
                    for component in logreg_inputs_ui_components:
                        component.render()
            else:
                gr.Markdown("*(Logistic Regression model failed to load or has no inputs defined)*")

        # Input column shared by all tree-based models.
        with gr.Column(visible=not initial_model_is_logreg, elem_id="tree_inputs_column") as tree_inputs_group:
            gr.Markdown("<h3>Tree-Based Model Inputs</h3>", elem_id="tree_header")
            if TREE_READY and tree_inputs_ui_components:
                with gr.Accordion("Show/Hide Tree Model Stats", open=True):
                    for component in tree_inputs_ui_components:
                        component.render()
            else:
                gr.Markdown("*(Tree-based models failed to load or have no inputs defined)*")

        predict_button = gr.Button("Predict Injury", variant="primary", elem_id="predict_button_main")

        def update_input_visibility(selected_model_value):
            """Show the input column matching the selected model, hide the other."""
            is_logreg = selected_model_value == "Logistic Regression"
            # Updates are keyed by the component they apply to.
            return {
                logreg_inputs_group: gr.update(visible=is_logreg),
                tree_inputs_group: gr.update(visible=not is_logreg)
            }

        model_selector.change(
            fn=update_input_visibility,
            inputs=model_selector,
            outputs=[logreg_inputs_group, tree_inputs_group]
        )

        # Order matters: predict_injury slices its positional values assuming
        # the LogReg inputs come first, followed by the tree inputs.
        all_ui_input_components_ordered = logreg_inputs_ui_components + tree_inputs_ui_components

        predict_button.click(
            fn=predict_injury,
            inputs=[model_selector] + all_ui_input_components_ordered,
            # predict_injury returns three values; the last one is the status.
            outputs=[output_prediction, output_probability, output_status]
        )

if __name__ == "__main__":
    # model_choices is assigned inside the gr.Blocks context above; an empty
    # list means no model family could be offered in the UI.
    if not model_choices:
        print("No models available. Gradio app may not function correctly.")
    # NOTE(review): launch() is called with default options. An earlier comment
    # here mentioned show_error, but no such argument is passed - confirm
    # whether show_error=True was intended.
    app.launch()