# Streamlit app for NBA Injury-Aware Performance Prediction
# Predicts post-injury stat changes from injury type, position, and player
# context using a pre-trained model loaded with joblib.
import streamlit as st
import pandas as pd
import joblib
from sklearn.ensemble import RandomForestRegressor
import plotly.express as px
import plotly.graph_objects as go
from PIL import Image

# Set app config (Streamlit requires this to be the first st.* call).
st.set_page_config(
    page_title="NBA Injury-Aware Predictor",
    page_icon="🏀",
    layout="centered",
)


# Load model and data
@st.cache_resource
def load_rf_model():
    """Load the pickled model from ``rf_injury_change_model.pkl`` (cached).

    NOTE(review): joblib.load unpickles arbitrary objects -- only ship trusted
    model files with this app.
    """
    return joblib.load("rf_injury_change_model.pkl")


@st.cache_resource
def load_player_data():
    """Load the player table from ``player_data.csv`` (cached)."""
    return pd.read_csv("player_data.csv")


# Mapping utilities
position_mapping = {"PG": 1.0, "SG": 2.0, "SF": 3.0, "PF": 4.0, "C": 5.0}


def encode_injury_type(type_str, known_types):
    """Return the index of ``type_str`` in ``known_types``, or -1 if absent.

    NOTE(review): the encoding is the position in the sorted list of injury
    types present in the current CSV; confirm it matches the encoding used
    when the model was trained.
    """
    return known_types.index(type_str) if type_str in known_types else -1


def _as_slider_int(value, fallback, lo=None, hi=None):
    """Convert ``value`` to an int usable as a slider's initial value.

    Missing values (None/NaN) are replaced by ``fallback``. The result is
    clamped into ``[lo, hi]`` when bounds are given, because ``st.slider``
    rejects an initial value outside its range.
    """
    result = fallback if value is None or pd.isna(value) else int(value)
    if lo is not None:
        result = max(lo, result)
    if hi is not None:
        result = min(hi, result)
    return int(result)


# App content
def main():
    """Render the predictor UI and show model predictions when requested."""
    st.title("🏀 Injury-Aware NBA Player Predictor")
    st.write("Predict performance changes post-injury using injury type, position, and context.")

    player_data = load_player_data()
    model = load_rf_model()

    injury_types = sorted(player_data["injury_type"].dropna().unique())
    player_list = sorted(player_data["player_name"].dropna().unique())
    if not player_list or not injury_types:
        # With no choices the selectboxes return None and the lookups below fail.
        st.error("player_data.csv has no players or injury types to choose from.")
        return

    player = st.selectbox("Select a Player", player_list)
    injury = st.selectbox("Hypothetical Injury", injury_types)

    # A player may appear on several rows; the first row supplies the defaults.
    player_row = player_data[player_data.player_name == player].iloc[0]
    # Positions not in the mapping are encoded as 0.
    position_numeric = position_mapping.get(player_row["position"], 0)

    # Sidebar for editable fields
    st.sidebar.subheader("Adjust Inputs")
    # Missing age falls back to the dataset median (or 25 if no ages exist).
    age_fallback = _as_slider_int(player_data["age"].median(), 25)
    base_age = _as_slider_int(player_row.age, age_fallback)
    age = st.sidebar.slider("Age", base_age - 5, base_age + 5, base_age)
    # Missing height/weight fall back to the midpoint of the slider range.
    height = st.sidebar.slider(
        "Height (cm)", 160, 220, _as_slider_int(player_row.player_height, 190, 160, 220)
    )
    weight = st.sidebar.slider(
        "Weight (kg)", 60, 140, _as_slider_int(player_row.player_weight, 100, 60, 140)
    )
    # Missing counts default to 1 as before; a recorded 0 is now preserved.
    injury_occurrences = st.sidebar.slider(
        "Prior Injuries", 0, 10, _as_slider_int(player_row.injury_occurrences, 1, 0, 10)
    )
    # Mean is NaN when every days_injured value for this type is missing.
    avg_days_injured = player_data[player_data.injury_type == injury]["days_injured"].mean()
    days_injured = st.sidebar.slider(
        "Estimated Days Injured", 0, 365, _as_slider_int(avg_days_injured, 30, 0, 365)
    )

    # Prepare input vector
    encoded_type = encode_injury_type(injury, injury_types)
    input_data = pd.DataFrame([{
        "age": age,
        "player_height": height,
        "player_weight": weight,
        "position": position_numeric,
        "injury_type": encoded_type,
        "injury_occurrences": injury_occurrences,
        "days_injured": days_injured,
    }])

    # feature_names_in_ only exists when the model was fitted on a DataFrame;
    # when present, align columns to the training order and zero-fill extras.
    expected_features = getattr(model, "feature_names_in_", None)
    if expected_features is not None:
        input_data = input_data.reindex(columns=expected_features, fill_value=0)

    if st.button("Predict 🔮"):
        # Force a 2-D (rows, outputs) array so single-output models also work.
        preds = model.predict(input_data).reshape(len(input_data), -1)
        labels = ["Change in PTS", "Change in REB", "Change in AST"]
        if preds.shape[1] != len(labels):
            labels = [f"Output {i + 1}" for i in range(preds.shape[1])]
        pred_df = pd.DataFrame(preds, columns=labels)

        st.subheader("Predicted Performance Changes")
        st.dataframe(pred_df.style.format("{:.2f}"))

        fig = go.Figure()
        for col in labels:
            fig.add_trace(go.Bar(x=[col], y=pred_df[col], name=col))
        fig.update_layout(title="Predicted Impact", template="plotly_dark")
        st.plotly_chart(fig)


if __name__ == "__main__":
    main()