# NBA / app.py — uploaded by triflix
# Last commit: f51f04a "Update app.py" (verified)
import gradio as gr
import pickle
import pandas as pd
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import BaggingClassifier, RandomForestClassifier
from xgboost import XGBClassifier
# --- Module-level state shared by the loaders, the UI builder and predict_injury ---

# Artifacts (imputer, column lists) loaded for each model family.
LOADED_OBJECTS_LOGREG: dict = {}
LOADED_OBJECTS_TREE: dict = {}

# Unpickled estimators keyed by snake_case model name, e.g. "logistic_regression".
MODEL_CACHE: dict = {}

# Feature names in the order their input widgets are created (and therefore the
# order in which their values reach predict_injury).
logreg_ui_feature_order: list = []
tree_ui_feature_order: list = []

# gr.Number widgets, one per entry of the matching *_ui_feature_order list.
logreg_inputs_ui_components: list = []
tree_inputs_ui_components: list = []
# --- Sample values used to pre-fill the input widgets ---
# Each widget's default is looked up by feature name; a feature absent from these
# dicts gets no default and shows up blank. Entries tagged "imputable" are
# expected to be in the model's cols_to_impute, so the user may also clear them.
# (np.nan would also render as an empty gr.Number.)

# Logistic Regression sample — keys should match `logreg_ui_feature_order`.
DEFAULT_LOGREG_INPUTS = dict(
    AGE=28,
    AVG_SPEED=4.1,             # imputable
    DIST_MILES=2.6,            # imputable
    DRIVE_FGA=3.0,
    GP=70,
    PACE=99.5,
    PAINT_TOUCHES=5.5,
    PLAYER_HEIGHT_INCHES=78,   # imputable
    PLAYER_WEIGHT=215,         # imputable
    POSS=7100,
    PULL_UP_FG3A=1.8,          # imputable
    # Extend with any further features that appear in logreg_ui_feature_order.
)

# Tree-model sample — keys should match `tree_ui_feature_order`.
DEFAULT_TREE_INPUTS = dict(
    AGE=26,
    PLAYER_HEIGHT_INCHES=79,   # imputable
    PLAYER_WEIGHT=220,         # imputable
    GP=65,
    MIN=32.0,
    USG_PCT=0.22,
    PACE=100.0,
    POSS=7200,
    FGA_PG=16.0,
    DRIVES=11.0,
    DRIVE_FGA=5.5,
    DRIVE_PASSES=5.0,
    DIST_MILES=2.7,            # imputable
    AVG_SPEED=4.2,             # imputable
    PULL_UP_FGA=5.5,
    PULL_UP_FG3A=2.2,          # imputable
    TOUCHES=65.0,
    FRONT_CT_TOUCHES=33.0,
    AVG_SEC_PER_TOUCH=2.8,
    AVG_DRIB_PER_TOUCH=2.1,
    ELBOW_TOUCHES=2.3,
    POST_TOUCHES=3.5,
    PAINT_TOUCHES=6.0,
    # Extend with any further features that appear in tree_ui_feature_order.
)
# --- Loading functions (they also fix the order of each model's UI features) ---
def load_logreg_dependencies():
    """Load the Logistic Regression imputer, column lists and model from disk.

    On success the artifacts are stored in ``LOADED_OBJECTS_LOGREG`` (keys
    'imputer', 'cols_to_impute', 'features'), the model in
    ``MODEL_CACHE['logistic_regression']``, and ``logreg_ui_feature_order`` is
    set to the sorted union of the model features and the imputed columns.

    Nothing is committed to the globals until every file has loaded, so a
    failed attempt leaves no partial state and a later call retries cleanly.

    Returns:
        bool: True if the dependencies are (or already were) loaded; False if
        a file was missing or could not be unpickled.
    """
    global logreg_ui_feature_order
    if LOADED_OBJECTS_LOGREG:
        return True  # already loaded by an earlier call

    print("Loading Logistic Regression dependencies...")
    artifact_files = {
        'imputer': 'simple_imputer_logreg.pkl',
        'cols_to_impute': 'cols_to_impute_logreg.pkl',
        'features': 'stepwise_features_logreg.pkl',
    }
    try:
        # NOTE(review): pickle.load executes code from the file; only ship trusted .pkl artifacts.
        loaded = {}
        for key, path in artifact_files.items():
            with open(path, 'rb') as f:
                loaded[key] = pickle.load(f)
        with open('logisticregression.pkl', 'rb') as f:
            model = pickle.load(f)
        # Set union rather than `list + list`: if a pickle holds a pandas Index
        # or a tuple, `+` is element-wise / type-mismatched instead of concatenation.
        ui_order = sorted(set(loaded['features']) | set(loaded['cols_to_impute']))
    except FileNotFoundError as e:
        print(f"Error loading LogReg file: {e.filename}")
        gr.Warning(f"Required file for Logistic Regression not found: {e.filename}. This model will not work.")
        return False
    except Exception as e:
        print(f"Error loading LogReg dependencies: {e}")
        gr.Warning(f"Error loading dependencies for Logistic Regression: {e}")
        return False

    # Commit only once everything above succeeded.
    LOADED_OBJECTS_LOGREG.update(loaded)
    MODEL_CACHE['logistic_regression'] = model
    logreg_ui_feature_order = ui_order
    print("Logistic Regression dependencies loaded successfully.")
    print(f"LogReg UI features order set: {logreg_ui_feature_order}")
    return True
def load_tree_dependencies():
    """Load the shared tree-model artifacts and each tree-based classifier.

    The shared artifacts ('imputer', 'cols_to_impute', 'features') are
    required. Each model file is then loaded independently into
    ``MODEL_CACHE`` under its snake_case name, so one missing model file no
    longer disables the others. ``tree_ui_feature_order`` is set to the model
    feature list, converted to a plain ``list``.

    Nothing is committed to the globals unless the shared artifacts and at
    least one model loaded, so a failed attempt leaves no partial state.

    Returns:
        bool: True if the dependencies are (or already were) loaded and at
        least one tree model is available; False otherwise.
    """
    global tree_ui_feature_order
    if LOADED_OBJECTS_TREE:
        return True  # already loaded by an earlier call

    print("Loading Tree-based model dependencies...")
    artifact_files = {
        'imputer': 'simple_imputer_tree.pkl',
        'cols_to_impute': 'cols_to_impute_tree.pkl',
        'features': 'tree_model_feature_columns.pkl',
    }
    try:
        # NOTE(review): pickle.load executes code from the file; only ship trusted .pkl artifacts.
        loaded = {}
        for key, path in artifact_files.items():
            with open(path, 'rb') as f:
                loaded[key] = pickle.load(f)
        # A plain list: if the pickle holds a pandas Index, truth-testing it
        # (`if tree_ui_feature_order:`) would raise ValueError.
        ui_order = list(loaded['features'])
    except FileNotFoundError as e:
        print(f"Error loading Tree model file: {e.filename}")
        gr.Warning(f"Required file for Tree models not found: {e.filename}. Some tree models may not work.")
        return False
    except Exception as e:
        print(f"Error loading Tree dependencies: {e}")
        gr.Warning(f"Error loading dependencies for Tree models: {e}")
        return False

    tree_model_files = {
        "Optimized Decision Tree": "optimized_decision_tree_model.pkl",
        "Bagging Classifier": "bagging_classifier_model.pkl",
        "XGBoost Classifier": "xgboost_classifier_model.pkl",
        "Random Forest Classifier": "random_forest_classifier_model.pkl",
    }
    loaded_models = {}
    for name, path in tree_model_files.items():
        # Load each model on its own: a missing file only disables that model
        # (predict_injury reports it as "not loaded" via the MODEL_CACHE check).
        try:
            with open(path, 'rb') as f:
                loaded_models[name.lower().replace(" ", "_")] = pickle.load(f)
        except FileNotFoundError as e:
            print(f"Error loading Tree model file: {e.filename}")
            gr.Warning(f"Required file for Tree models not found: {e.filename}. Some tree models may not work.")
        except Exception as e:
            print(f"Error loading Tree dependencies: {e}")
            gr.Warning(f"Error loading dependencies for Tree models: {e}")

    if not loaded_models:
        return False

    # Commit only once the shared artifacts and at least one model are in hand.
    LOADED_OBJECTS_TREE.update(loaded)
    MODEL_CACHE.update(loaded_models)
    tree_ui_feature_order = ui_order
    print("Tree-based model dependencies loaded successfully.")
    print(f"Tree UI features order set: {tree_ui_feature_order}")
    return True
# Load each model family once, at import time. The resulting flags decide which
# models the UI offers, and predict_injury re-checks them before predicting.
LOGREG_READY: bool = load_logreg_dependencies()
TREE_READY: bool = load_tree_dependencies()
# --- Prediction ---

# Tree-based models offered in the UI; their MODEL_CACHE keys are the snake_case names.
_TREE_MODEL_CHOICES = (
    "Optimized Decision Tree",
    "Bagging Classifier",
    "XGBoost Classifier",
    "Random Forest Classifier",
)


class _PredictionError(Exception):
    """A user-facing failure; ``message``/``detail`` become the first two outputs."""

    def __init__(self, message, detail=""):
        super().__init__(message)
        self.message = message
        self.detail = detail


def _resolve_model(model_choice, all_ui_values):
    """Map the chosen model to its UI values and loaded artifacts.

    Returns:
        (raw_input_data, ui_order, imputer, cols_to_impute, final_features, model),
        where raw_input_data maps each UI feature name to the value Gradio passed.

    Raises:
        _PredictionError: the choice is unknown or its model is not loaded.
    """
    if model_choice == "Logistic Regression":
        model_key = 'logistic_regression'
        if not LOGREG_READY or model_key not in MODEL_CACHE:
            raise _PredictionError("Logistic Regression model not loaded. Check logs.")
        ui_order = logreg_ui_feature_order
        # The LogReg widgets come first in the click() input list.
        ui_values = all_ui_values[:len(logreg_inputs_ui_components)]
        artifacts = LOADED_OBJECTS_LOGREG
    elif model_choice in _TREE_MODEL_CHOICES:
        model_key = model_choice.lower().replace(" ", "_")
        if not TREE_READY or model_key not in MODEL_CACHE:
            raise _PredictionError(f"{model_choice} model not loaded. Check logs.")
        ui_order = tree_ui_feature_order
        # The tree widgets follow the LogReg widgets.
        start = len(logreg_inputs_ui_components)
        ui_values = all_ui_values[start:start + len(tree_ui_feature_order)]
        artifacts = LOADED_OBJECTS_TREE
    else:
        raise _PredictionError("Invalid model choice.")

    raw_input_data = dict(zip(ui_order, ui_values))
    return (
        raw_input_data,
        ui_order,
        artifacts['imputer'],
        artifacts['cols_to_impute'],
        artifacts['features'],
        MODEL_CACHE[model_key],
    )


def _coerce_inputs(raw_input_data, cols_to_impute):
    """Convert UI values to floats; blank/None become NaN for the imputer.

    Raises:
        _PredictionError: a non-numeric value was given for a column that is
            not imputed.
    """
    imputed = set(cols_to_impute)
    processed = {}
    for key, value in raw_input_data.items():
        if value is None or (isinstance(value, str) and value.strip() == ""):
            processed[key] = np.nan
            continue
        try:
            processed[key] = float(value)
        except (ValueError, TypeError):
            if key not in imputed:
                raise _PredictionError(
                    f"Invalid input for '{key}'. Please provide a number.",
                    f"Error with input '{key}'",
                )
            processed[key] = np.nan  # unparseable but imputable: let the imputer fill it
    return processed


def _impute(df, imputer, cols_to_impute):
    """Fill the imputed columns of *df* in place using *imputer* and return it.

    The imputer always receives exactly ``cols_to_impute`` in that order
    (columns absent from the UI are supplied as NaN), instead of only the
    subset that happens to be a UI column, which would change the width of
    the matrix it was fitted on.
    """
    impute_cols = list(cols_to_impute)
    if not impute_cols:
        return df
    try:
        subset = df.reindex(columns=impute_cols)
        filled = pd.DataFrame(imputer.transform(subset), columns=impute_cols, index=df.index)
        for col in impute_cols:
            if col in df.columns:
                df[col] = filled[col]
    except Exception as e:
        print(f"Error during imputation: {e}")
        raise _PredictionError(f"Error during data imputation: {e}", str(e)) from e
    return df


def _select_features(df, final_features):
    """Return the columns of *df* the model expects, in the model's order."""
    try:
        # list(): a tuple key would be treated as one column label by pandas.
        final_df = df[list(final_features)]
    except KeyError as e:
        print(f"Error: Missing feature column for the model's final input: {e}")
        raise _PredictionError(
            f"Internal error: Missing expected feature {e}. Check model feature lists.",
            f"KeyError: {e}",
        ) from e
    except Exception as e:
        print(f"Error during final feature selection: {e}")
        raise _PredictionError(f"Error during data preparation: {e}", str(e)) from e
    print("Final DataFrame for model prediction:\n", final_df.head())
    return final_df


def _format_prediction(model, final_df):
    """Predict on the single row in *final_df* and return (result_text, probability_text)."""
    try:
        pred_proba = model.predict_proba(final_df)[0]
        prediction = model.predict(final_df)[0]
        # Find the columns for labels 1 (injured) and 0 via classes_ rather than
        # assuming predict_proba is ordered [0, 1]; fall back to that order.
        classes = list(getattr(model, 'classes_', []))
        injured_idx = classes.index(1) if 1 in classes else 1
        not_injured_idx = classes.index(0) if 0 in classes else 0
        proba_injured = pred_proba[injured_idx]
        proba_not_injured = pred_proba[not_injured_idx]
    except Exception as e:
        print(f"Error during prediction: {e}")
        raise _PredictionError(f"Error during prediction: {e}", str(e)) from e

    if prediction == 1:
        result_text = "Prediction: Likely Injured"
        probability_text = f"Probability of Injury: {proba_injured:.2%}\nProbability of No Injury: {proba_not_injured:.2%}"
    else:
        result_text = "Prediction: Likely Not Injured"
        probability_text = f"Probability of No Injury: {proba_not_injured:.2%}\nProbability of Injury: {proba_injured:.2%}"
    return result_text, probability_text


def predict_injury(model_choice, *all_ui_values):
    """Run the selected model on the values entered in the UI.

    Args:
        model_choice: Dropdown label of the model to use.
        *all_ui_values: Values of every input widget, LogReg widgets first and
            tree-model widgets after (the order wired in the click handler).

    Returns:
        tuple[str, str, str]: (result, probabilities, status). On failure the
        first two carry an error message and detail, and status is "".
    """
    print(f"Model choice: {model_choice}")
    try:
        raw_input_data, ui_order, imputer, cols_to_impute, final_features, model = _resolve_model(
            model_choice, all_ui_values
        )
        print(f"Raw input data mapped for model: {raw_input_data}")
        processed = _coerce_inputs(raw_input_data, cols_to_impute)
        # One-row frame whose columns follow the UI order, so every UI feature exists.
        df = pd.DataFrame([processed]).reindex(columns=ui_order)
        print("DataFrame before imputation:\n", df.head())
        df = _impute(df, imputer, cols_to_impute)
        print("DataFrame after imputation:\n", df.head())
        final_df = _select_features(df, final_features)
        result_text, probability_text = _format_prediction(model, final_df)
    except _PredictionError as err:
        return err.message, err.detail, ""
    return result_text, probability_text, "Prediction successful."
# --- Create the gr.Number input widgets, pre-filled with the sample defaults ---
def _build_number_inputs(feature_order, defaults, impute_cols, id_prefix):
    """Return one gr.Number per feature in *feature_order*.

    Each widget is labelled from the feature name, pre-filled from *defaults*
    (None when absent, i.e. blank), tagged as blankable when the feature is in
    *impute_cols*, and given the elem_id ``f"{id_prefix}_{feature}"``.
    """
    imputable = set(impute_cols)
    widgets = []
    for feature in feature_order:
        label = feature.replace("_", " ").title()
        if feature in imputable:
            label += " (can be blank for mean imputation)"
        widgets.append(
            gr.Number(label=label, value=defaults.get(feature, None), elem_id=f"{id_prefix}_{feature}")
        )
    return widgets


# len() rather than truthiness: the feature orders come from pickled objects,
# and a pandas Index raises ValueError when tested with `if`.
if LOGREG_READY and len(logreg_ui_feature_order) > 0:
    print(f"Generating UI components for LogReg with defaults. UI Features: {logreg_ui_feature_order}")
    logreg_inputs_ui_components.extend(
        _build_number_inputs(
            logreg_ui_feature_order,
            DEFAULT_LOGREG_INPUTS,
            LOADED_OBJECTS_LOGREG.get('cols_to_impute', []),
            "logreg",
        )
    )
else:
    if not LOGREG_READY:
        print("Logistic Regression dependencies not loaded.")
    if len(logreg_ui_feature_order) == 0:
        print("LogReg UI feature order not set.")

if TREE_READY and len(tree_ui_feature_order) > 0:
    print(f"Generating UI components for Tree Models with defaults. UI Features: {tree_ui_feature_order}")
    tree_inputs_ui_components.extend(
        _build_number_inputs(
            tree_ui_feature_order,
            DEFAULT_TREE_INPUTS,
            LOADED_OBJECTS_TREE.get('cols_to_impute', []),
            "tree",
        )
    )
else:
    if not TREE_READY:
        print("Tree model dependencies not loaded.")
    if len(tree_ui_feature_order) == 0:
        print("Tree UI feature order not set.")
# --- Assemble the Gradio interface ---
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("<h1>NBA Player Injury Predictor</h1>", elem_id="main_title")
    gr.Markdown(
        "Select a model and enter player statistics. Fields are pre-filled with sample data. Adjust as needed. "
        "Blank fields for certain stats will be mean-imputed by the model.",
        elem_id="sub_title",
    )

    # A model family is offered only if it loaded AND its input widgets exist.
    logreg_available = bool(LOGREG_READY and logreg_inputs_ui_components)
    tree_available = bool(TREE_READY and tree_inputs_ui_components)
    tree_choice_names = [
        "Optimized Decision Tree",
        "Bagging Classifier",
        "XGBoost Classifier",
        "Random Forest Classifier",
    ]
    model_choices = (["Logistic Regression"] if logreg_available else []) + (
        tree_choice_names if tree_available else []
    )

    if model_choices:
        first_choice = model_choices[0]
        model_selector = gr.Dropdown(
            choices=model_choices, label="Select Model", value=first_choice, elem_id="model_selector_dropdown"
        )

        # Outputs, filled by predict_injury's (result, probabilities, status) tuple.
        output_prediction = gr.Label(label="Prediction Result", elem_id="prediction_label")
        output_probability = gr.Textbox(label="Probabilities", lines=2, interactive=False, elem_id="probability_textbox")
        output_status = gr.Textbox(label="Status", lines=1, interactive=False, elem_id="status_textbox", visible=False)

        # Only the input column matching the initially selected model is shown.
        initial_model_is_logreg = first_choice == "Logistic Regression"

        with gr.Column(visible=initial_model_is_logreg, elem_id="logreg_inputs_column") as logreg_inputs_group:
            gr.Markdown("<h3>Logistic Regression Inputs</h3>", elem_id="logreg_header")
            if not logreg_available:
                gr.Markdown("*(Logistic Regression model failed to load or has no inputs defined)*")
            else:
                with gr.Accordion("Show/Hide Logistic Regression Stats", open=True):
                    for widget in logreg_inputs_ui_components:
                        widget.render()  # widgets were created outside this Blocks context

        with gr.Column(visible=not initial_model_is_logreg, elem_id="tree_inputs_column") as tree_inputs_group:
            gr.Markdown("<h3>Tree-Based Model Inputs</h3>", elem_id="tree_header")
            if not tree_available:
                gr.Markdown("*(Tree-based models failed to load or have no inputs defined)*")
            else:
                with gr.Accordion("Show/Hide Tree Model Stats", open=True):
                    for widget in tree_inputs_ui_components:
                        widget.render()

        predict_button = gr.Button("Predict Injury", variant="primary", elem_id="predict_button_main")

        def update_input_visibility(selected_model_value):
            """Show the LogReg inputs for Logistic Regression, the tree inputs otherwise."""
            show_logreg = selected_model_value == "Logistic Regression"
            # Positional updates, in the same order as `outputs` below.
            return gr.update(visible=show_logreg), gr.update(visible=not show_logreg)

        model_selector.change(
            fn=update_input_visibility,
            inputs=model_selector,
            outputs=[logreg_inputs_group, tree_inputs_group],
        )

        # predict_injury relies on this ordering: LogReg widgets first, then tree widgets.
        all_ui_input_components_ordered = [*logreg_inputs_ui_components, *tree_inputs_ui_components]
        predict_button.click(
            fn=predict_injury,
            inputs=[model_selector, *all_ui_input_components_ordered],
            outputs=[output_prediction, output_probability, output_status],
        )
    else:
        gr.Markdown("## CRITICAL ERROR: No models could be loaded or UI components generated. Check console logs.")
if __name__ == "__main__":
    if not model_choices:
        print("No models available. Gradio app may not function correctly.")
    # The original inline comment intended show_error but never passed it:
    # show_error=True displays errors raised in event handlers in the browser.
    app.launch(show_error=True)