File size: 3,248 Bytes
71897be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# Streamlit app for NBA Injury-Aware Performance Prediction
# Enhanced with attention-derived injury impact insights

import streamlit as st
import pandas as pd
import joblib
from sklearn.ensemble import RandomForestRegressor
import plotly.express as px
import plotly.graph_objects as go
from PIL import Image

# Configure the browser tab (title, icon) and the page layout for the app.
st.set_page_config(
    layout="centered",
    page_icon="🏀",
    page_title="NBA Injury-Aware Predictor",
)

# Load model and data
@st.cache_resource
def load_rf_model():
    return joblib.load("rf_injury_change_model.pkl")

@st.cache_data
def load_player_data():
    """Read ``player_data.csv`` from the working directory into a DataFrame.

    Uses ``st.cache_data`` rather than ``st.cache_resource``. ``cache_data``
    is Streamlit's decorator for data objects such as DataFrames: each caller
    gets its own copy. With ``cache_resource``, every session and rerun
    would share (and could mutate) one DataFrame instance.
    """
    return pd.read_csv("player_data.csv")

# Mapping utilities
position_mapping = {"PG": 1.0, "SG": 2.0, "SF": 3.0, "PF": 4.0, "C": 5.0}

def encode_injury_type(type_str, known_types):
    """Return the index of ``type_str`` within ``known_types``.

    Args:
        type_str: Injury type label to encode.
        known_types: Sequence of known labels; its order defines the codes.

    Returns:
        The zero-based position of ``type_str``, or ``-1`` if it is absent.
    """
    try:
        return known_types.index(type_str)
    except ValueError:
        # Unknown label: signal it with the sentinel code -1.
        return -1

def clamp_slider_value(value, default, lo=None, hi=None):
    """Coerce a possibly-missing numeric value into a valid slider default.

    Args:
        value: Raw value, e.g. read from a CSV row; may be ``None`` or NaN.
        default: Fallback used when ``value`` is missing.
        lo: Optional inclusive lower bound (the slider's ``min_value``).
        hi: Optional inclusive upper bound (the slider's ``max_value``).

    Returns:
        ``int(value)`` (or ``int(default)`` when missing), clamped to ``[lo, hi]``.

    ``int()`` raises ``ValueError`` on NaN, and ``st.slider`` rejects an
    initial value outside its range, so CSV-derived defaults go through here.
    Unlike ``x or default``, a genuine ``0`` is kept rather than replaced.
    """
    if value is None or pd.isna(value):
        value = default
    result = int(value)
    if lo is not None:
        result = max(lo, result)
    if hi is not None:
        result = min(hi, result)
    return result


# App content
def main():
    """Render the predictor page: inputs first, predictions on button press.

    Loads the player CSV and the cached model, lets the user pick a player
    and a hypothetical injury, and exposes editable features in the sidebar.
    On "Predict" it shows a table and a bar chart of the model's outputs.
    """
    st.title("🏀 Injury-Aware NBA Player Predictor")
    st.write("Predict performance changes post-injury using injury type, position, and context.")

    player_data = load_player_data()
    model = load_rf_model()

    injury_types = sorted(player_data["injury_type"].dropna().unique())
    player_list = sorted(player_data["player_name"].dropna().unique())

    # With empty options st.selectbox returns None, and the row lookup below
    # would raise IndexError; stop with a readable message instead.
    if not player_list or not injury_types:
        st.error("player_data.csv has no players or no injury types to choose from.")
        return

    player = st.selectbox("Select a Player", player_list)
    injury = st.selectbox("Hypothetical Injury", injury_types)

    # If a player appears on several rows, the first one supplies the defaults.
    player_row = player_data[player_data.player_name == player].iloc[0]
    # Positions not in position_mapping encode as 0.
    position_numeric = position_mapping.get(player_row["position"], 0)

    # Sidebar for editable fields. Defaults come from the CSV, so they are
    # NaN-guarded and clamped into each slider's range. Missing height and
    # weight fall back to the slider midpoint; 25 is an arbitrary age fallback.
    st.sidebar.subheader("Adjust Inputs")
    base_age = clamp_slider_value(player_row.age, 25)
    age = st.sidebar.slider("Age", base_age - 5, base_age + 5, base_age)
    height = st.sidebar.slider(
        "Height (cm)", 160, 220,
        clamp_slider_value(player_row.player_height, 190, 160, 220),
    )
    weight = st.sidebar.slider(
        "Weight (kg)", 60, 140,
        clamp_slider_value(player_row.player_weight, 100, 60, 140),
    )
    injury_occurrences = st.sidebar.slider(
        "Prior Injuries", 0, 10,
        clamp_slider_value(player_row.injury_occurrences, 1, 0, 10),
    )

    # The mean is NaN when every row of this type lacks days_injured; fall back to 30.
    avg_days_injured = player_data[player_data.injury_type == injury]["days_injured"].mean()
    days_injured = st.sidebar.slider(
        "Estimated Days Injured", 0, 365,
        clamp_slider_value(avg_days_injured, 30, 0, 365),
    )

    # Prepare input vector.
    # NOTE(review): the code is the index into the *sorted* injury types present
    # in the current CSV. Confirm this matches the encoding used at training time.
    encoded_type = encode_injury_type(injury, injury_types)
    input_data = pd.DataFrame([{
        "age": age,
        "player_height": height,
        "player_weight": weight,
        "position": position_numeric,
        "injury_type": encoded_type,
        "injury_occurrences": injury_occurrences,
        "days_injured": days_injured
    }])

    # Match the model's training column order. Columns the model expects but
    # the app does not collect are filled with 0; extra columns are dropped.
    expected_features = model.feature_names_in_
    input_data = input_data.reindex(columns=expected_features, fill_value=0)

    if st.button("Predict 🔮"):
        # Assumes a multi-output model yielding one row of three targets in
        # PTS, REB, AST order. TODO: confirm against the training script.
        preds = model.predict(input_data)
        labels = ["Change in PTS", "Change in REB", "Change in AST"]
        pred_df = pd.DataFrame(preds, columns=labels)

        st.subheader("Predicted Performance Changes")
        st.dataframe(pred_df.style.format("{:.2f}"))

        fig = go.Figure()
        for col in labels:
            fig.add_trace(go.Bar(x=[col], y=pred_df[col], name=col))

        fig.update_layout(title="Predicted Impact", template="plotly_dark")
        st.plotly_chart(fig)

# Entry point. NOTE(review): assumes the app is launched via `streamlit run`,
# which executes this script as __main__. Confirm for other launch methods.
if __name__ == "__main__":
    main()