|
|
|
|
|
|
|
|
|
|
|
import streamlit as st |
|
|
import pandas as pd |
|
|
import joblib |
|
|
from sklearn.ensemble import RandomForestRegressor |
|
|
import plotly.express as px |
|
|
import plotly.graph_objects as go |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
# Page-level settings: browser-tab title, tab icon, and a centered layout.
# The icon literal had been reduced to a stray Thai digit by an encoding
# round-trip (only the first byte of a 4-byte emoji survived); it is restored
# here as the basketball emoji -- TODO(review): confirm the intended icon.
st.set_page_config(
    page_title="NBA Injury-Aware Predictor",
    page_icon="🏀",
    layout="centered",
)
|
|
|
|
|
|
|
|
@st.cache_resource
def load_rf_model():
    """Load the serialized model from ``rf_injury_change_model.pkl``.

    The path is relative, so it is resolved against the process working
    directory. ``st.cache_resource`` keeps the loaded object around so the
    file is not re-read on every script rerun.
    """
    model_path = "rf_injury_change_model.pkl"
    loaded_model = joblib.load(model_path)
    return loaded_model
|
|
|
|
|
# ``st.cache_data`` (not ``st.cache_resource``) is the cache meant for
# serializable data such as DataFrames: each caller receives its own copy, so
# an accidental in-place edit in one rerun/session cannot leak into others.
@st.cache_data
def load_player_data():
    """Read ``player_data.csv`` into a DataFrame.

    The path is relative, so it is resolved against the process working
    directory.

    Returns:
        The parsed CSV as a ``pandas.DataFrame``.
    """
    return pd.read_csv("player_data.csv")
|
|
|
|
|
|
|
|
# Numeric code for each position abbreviation, used as the model's
# "position" feature.
position_mapping = {
    "PG": 1.0,
    "SG": 2.0,
    "SF": 3.0,
    "PF": 4.0,
    "C": 5.0,
}
|
|
|
|
|
def encode_injury_type(type_str, known_types):
    """Return the position of ``type_str`` within ``known_types``.

    The code is the index of the first matching entry; ``-1`` is returned
    when ``type_str`` does not occur in ``known_types``.
    """
    try:
        code = known_types.index(type_str)
    except ValueError:
        # .index() signals "not present" with ValueError.
        code = -1
    return code
|
|
|
|
|
|
|
|
def _int_or_default(value, default):
    """Return ``int(value)``, or ``default`` when ``value`` is missing or zero.

    Behaves like ``int(value or default)`` but also treats ``None``/NaN as
    missing: NaN is truthy, so the plain ``or`` idiom would pass it on to
    ``int()`` and raise ``ValueError``.
    """
    if pd.isna(value) or not value:
        return default
    return int(value)


def main():
    """Render the page: pick a player and injury, tune inputs, show predictions.

    Reads the player table and model through ``load_player_data`` and
    ``load_rf_model``, builds a one-row feature frame from the widget values,
    and, when the button is pressed, displays the model output as a table
    and a bar chart.
    """
    # Emoji in the title and button label were mis-encoded in the file and are
    # restored here -- TODO(review): confirm the basketball is the intended one.
    st.title("🏀 Injury-Aware NBA Player Predictor")
    st.write(
        "Predict performance changes post-injury using injury type, "
        "position, and context."
    )

    player_data = load_player_data()
    model = load_rf_model()

    injury_types = sorted(player_data["injury_type"].dropna().unique())
    player_list = sorted(player_data["player_name"].dropna().unique())

    # With nothing to choose from, the selectboxes return None and the row
    # lookup below would fail with IndexError; stop with a message instead.
    if not player_list or not injury_types:
        st.error("No player or injury data available in player_data.csv.")
        return

    player = st.selectbox("Select a Player", player_list)
    injury = st.selectbox("Hypothetical Injury", injury_types)

    # A player may appear on several rows; the first one supplies defaults.
    player_row = player_data[player_data.player_name == player].iloc[0]
    # Positions missing from the mapping are encoded as 0.
    position_numeric = position_mapping.get(player_row["position"], 0)

    st.sidebar.subheader("Adjust Inputs")

    # NOTE(review): age, height and weight are assumed to be present for
    # every player; a NaN here still raises in int().
    base_age = int(player_row.age)
    age = st.sidebar.slider("Age", base_age - 5, base_age + 5, base_age)

    # Streamlit rejects a slider whose default lies outside [min, max], so
    # each fixed range below is widened just enough to contain the default.
    base_height = int(player_row.player_height)
    height = st.sidebar.slider(
        "Height (cm)",
        min(160, base_height),
        max(220, base_height),
        base_height,
    )

    base_weight = int(player_row.player_weight)
    weight = st.sidebar.slider(
        "Weight (kg)",
        min(60, base_weight),
        max(140, base_weight),
        base_weight,
    )

    base_injuries = _int_or_default(player_row.injury_occurrences, 1)
    injury_occurrences = st.sidebar.slider(
        "Prior Injuries",
        min(0, base_injuries),
        max(10, base_injuries),
        base_injuries,
    )

    # The mean is NaN when no row of this injury type has a days_injured
    # value; fall back to 30 days in that case.
    same_injury = player_data.injury_type == injury
    avg_days_injured = player_data.loc[same_injury, "days_injured"].mean()
    base_days = _int_or_default(avg_days_injured, 30)
    days_injured = st.sidebar.slider(
        "Estimated Days Injured",
        min(0, base_days),
        max(365, base_days),
        base_days,
    )

    encoded_type = encode_injury_type(injury, injury_types)
    input_data = pd.DataFrame(
        [
            {
                "age": age,
                "player_height": height,
                "player_weight": weight,
                "position": position_numeric,
                "injury_type": encoded_type,
                "injury_occurrences": injury_occurrences,
                "days_injured": days_injured,
            }
        ]
    )

    # Put the columns in the order the model reports; any expected feature
    # not built above is filled with 0.
    expected_features = model.feature_names_in_
    input_data = input_data.reindex(columns=expected_features, fill_value=0)

    if st.button("Predict 🔮"):
        preds = model.predict(input_data)
        # Assumes the model emits exactly three outputs in this order --
        # TODO(review): confirm against the training script.
        labels = ["Change in PTS", "Change in REB", "Change in AST"]
        pred_df = pd.DataFrame(preds, columns=labels)

        st.subheader("Predicted Performance Changes")
        st.dataframe(pred_df.style.format("{:.2f}"))

        # One bar trace per predicted stat.
        fig = go.Figure()
        for col in labels:
            fig.add_trace(go.Bar(x=[col], y=pred_df[col], name=col))

        fig.update_layout(title="Predicted Impact", template="plotly_dark")
        st.plotly_chart(fig)
|
|
|
|
|
# Script entry point: build the page only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|