# InjuryDetection / app.py
# Author: lrschuman17
# Commit: 71897be ("Create app.py", verified)
# Streamlit app for NBA Injury-Aware Performance Prediction
# Enhanced with attention-derived injury impact insights
import streamlit as st
import pandas as pd
import joblib
from sklearn.ensemble import RandomForestRegressor
import plotly.express as px
import plotly.graph_objects as go
from PIL import Image
# Set app config: browser-tab title, favicon (basketball emoji) and a
# centered page layout.
st.set_page_config(
    page_title="NBA Injury-Aware Predictor",
    page_icon="🏀",
    layout="centered",
)
# Load model and data
@st.cache_resource
def load_rf_model():
return joblib.load("rf_injury_change_model.pkl")
@st.cache_resource
def load_player_data(path="player_data.csv"):
    """Read the player/injury table that populates the app's choices.

    Wrapped in ``st.cache_resource`` so the CSV is parsed once per distinct
    ``path`` instead of on every Streamlit rerun.

    NOTE(review): ``cache_resource`` hands every caller the same DataFrame
    object, so callers must not mutate the result in place.

    Args:
        path: Location of the CSV file. Defaults to the file bundled next to
            this app.

    Returns:
        A ``pandas.DataFrame`` with the CSV contents.
    """
    return pd.read_csv(path)
# Mapping utilities
# Numeric (ordinal) encoding of the five standard playing positions, from
# point guard (1.0) through center (5.0).
position_mapping = dict(
    PG=1.0,
    SG=2.0,
    SF=3.0,
    PF=4.0,
    C=5.0,
)
def encode_injury_type(type_str, known_types):
    """Return the position of ``type_str`` in ``known_types``.

    Injury types that are not in ``known_types`` map to the sentinel ``-1``
    instead of raising.
    """
    if type_str not in known_types:
        return -1
    return known_types.index(type_str)
# App content
def _slider_default(value, low, high, fallback):
    """Turn a possibly-missing data value into a valid slider default.

    Missing values (``None``/NaN) are replaced by ``fallback``. The value is
    then truncated to ``int`` and clamped into ``[low, high]``, because a
    Streamlit slider rejects a default that lies outside its range. Unlike
    ``int(value or fallback)``, a legitimate ``0`` is kept, and NaN (which is
    truthy) no longer reaches ``int()`` and crashes.
    """
    if pd.isna(value):
        value = fallback
    return max(low, min(high, int(value)))


def main():
    """Render the injury-aware predictor page.

    Loads the cached model and player table. The user picks a player and a
    hypothetical injury type. The model inputs are exposed as sidebar sliders
    seeded from the player's data. When "Predict" is pressed, the model's
    predicted stat changes are shown as a table and a bar chart.
    """
    st.title("🏀 Injury-Aware NBA Player Predictor")
    st.write("Predict performance changes post-injury using injury type, position, and context.")
    player_data = load_player_data()
    model = load_rf_model()

    injury_types = sorted(player_data["injury_type"].dropna().unique())
    player_list = sorted(player_data["player_name"].dropna().unique())
    if not injury_types or not player_list:
        # With no options the selectboxes return None, and the row lookup
        # below would raise IndexError.
        st.error("Player data contains no players or injury types to choose from.")
        return

    player = st.selectbox("Select a Player", player_list)
    injury = st.selectbox("Hypothetical Injury", injury_types)
    # A player may appear on several rows; the first one seeds the inputs.
    player_row = player_data[player_data.player_name == player].iloc[0]
    # Positions outside position_mapping encode as 0.
    position_numeric = position_mapping.get(player_row["position"], 0)

    # Sidebar for editable fields. Defaults come from the data but are made
    # safe for the slider ranges (missing -> fallback, then clamped).
    st.sidebar.subheader("Adjust Inputs")
    base_age = int(player_row.age)
    age = st.sidebar.slider("Age", base_age - 5, base_age + 5, base_age)
    height = st.sidebar.slider(
        "Height (cm)", 160, 220, _slider_default(player_row.player_height, 160, 220, 190)
    )
    weight = st.sidebar.slider(
        "Weight (kg)", 60, 140, _slider_default(player_row.player_weight, 60, 140, 100)
    )
    injury_occurrences = st.sidebar.slider(
        "Prior Injuries", 0, 10, _slider_default(player_row.injury_occurrences, 0, 10, 1)
    )
    # The mean is NaN when no row of this injury type has days_injured; the
    # helper then falls back to 30 days.
    avg_days_injured = player_data.loc[
        player_data.injury_type == injury, "days_injured"
    ].mean()
    days_injured = st.sidebar.slider(
        "Estimated Days Injured", 0, 365, _slider_default(avg_days_injured, 0, 365, 30)
    )

    # Prepare input vector
    encoded_type = encode_injury_type(injury, injury_types)
    input_data = pd.DataFrame([{
        "age": age,
        "player_height": height,
        "player_weight": weight,
        "position": position_numeric,
        "injury_type": encoded_type,
        "injury_occurrences": injury_occurrences,
        "days_injured": days_injured,
    }])
    # Align to the model's training columns and their order. Features the UI
    # does not supply become 0; columns the model does not know are dropped.
    input_data = input_data.reindex(columns=model.feature_names_in_, fill_value=0)

    if st.button("Predict 🔮"):
        preds = model.predict(input_data)
        # NOTE(review): assumes the model predicts exactly these three targets,
        # in this order. Confirm against the training code.
        labels = ["Change in PTS", "Change in REB", "Change in AST"]
        pred_df = pd.DataFrame(preds, columns=labels)
        st.subheader("Predicted Performance Changes")
        st.dataframe(pred_df.style.format("{:.2f}"))
        fig = go.Figure()
        for col in labels:
            fig.add_trace(go.Bar(x=[col], y=pred_df[col], name=col))
        fig.update_layout(title="Predicted Impact", template="plotly_dark")
        st.plotly_chart(fig)
# Script entry point: render the app when this module is executed as the main
# script rather than imported.
if __name__ == "__main__":
    main()