lrschuman17 commited on
Commit
71897be
·
verified ·
1 Parent(s): 423a239

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Streamlit app for NBA Injury-Aware Performance Prediction
2
+ # Enhanced with attention-derived injury impact insights
3
+
4
+ import streamlit as st
5
+ import pandas as pd
6
+ import joblib
7
+ from sklearn.ensemble import RandomForestRegressor
8
+ import plotly.express as px
9
+ import plotly.graph_objects as go
10
+ from PIL import Image
11
+
12
+ # Set app config
13
+ st.set_page_config(
14
+ page_title="NBA Injury-Aware Predictor",
15
+ page_icon="🏀",
16
+ layout="centered"
17
+ )
18
+
19
+ # Load model and data
20
+ @st.cache_resource
21
+ def load_rf_model():
22
+ return joblib.load("rf_injury_change_model.pkl")
23
+
24
+ @st.cache_resource
25
+ def load_player_data():
26
+ return pd.read_csv("player_data.csv")
27
+
28
+ # Mapping utilities
29
+ position_mapping = {"PG": 1.0, "SG": 2.0, "SF": 3.0, "PF": 4.0, "C": 5.0}
30
+
31
+ def encode_injury_type(type_str, known_types):
32
+ return known_types.index(type_str) if type_str in known_types else -1
33
+
34
+ # App content
35
+ def main():
36
+ st.title("🏀 Injury-Aware NBA Player Predictor")
37
+ st.write("Predict performance changes post-injury using injury type, position, and context.")
38
+
39
+ player_data = load_player_data()
40
+ model = load_rf_model()
41
+
42
+ injury_types = sorted(player_data["injury_type"].dropna().unique())
43
+ player_list = sorted(player_data["player_name"].dropna().unique())
44
+
45
+ player = st.selectbox("Select a Player", player_list)
46
+ injury = st.selectbox("Hypothetical Injury", injury_types)
47
+
48
+ player_row = player_data[player_data.player_name == player].iloc[0]
49
+ position_numeric = position_mapping.get(player_row["position"], 0)
50
+
51
+ # Sidebar for editable fields
52
+ st.sidebar.subheader("Adjust Inputs")
53
+ age = st.sidebar.slider("Age", int(player_row.age)-5, int(player_row.age)+5, int(player_row.age))
54
+ height = st.sidebar.slider("Height (cm)", 160, 220, int(player_row.player_height))
55
+ weight = st.sidebar.slider("Weight (kg)", 60, 140, int(player_row.player_weight))
56
+ injury_occurrences = st.sidebar.slider("Prior Injuries", 0, 10, int(player_row.injury_occurrences or 1))
57
+
58
+ avg_days_injured = player_data[player_data.injury_type == injury]["days_injured"].mean()
59
+ days_injured = st.sidebar.slider("Estimated Days Injured", 0, 365, int(avg_days_injured or 30))
60
+
61
+ # Prepare input vector
62
+ encoded_type = encode_injury_type(injury, injury_types)
63
+ input_data = pd.DataFrame([{
64
+ "age": age,
65
+ "player_height": height,
66
+ "player_weight": weight,
67
+ "position": position_numeric,
68
+ "injury_type": encoded_type,
69
+ "injury_occurrences": injury_occurrences,
70
+ "days_injured": days_injured
71
+ }])
72
+
73
+ expected_features = model.feature_names_in_
74
+ input_data = input_data.reindex(columns=expected_features, fill_value=0)
75
+
76
+ if st.button("Predict 🔮"):
77
+ preds = model.predict(input_data)
78
+ labels = ["Change in PTS", "Change in REB", "Change in AST"]
79
+ pred_df = pd.DataFrame(preds, columns=labels)
80
+
81
+ st.subheader("Predicted Performance Changes")
82
+ st.dataframe(pred_df.style.format("{:.2f}"))
83
+
84
+ fig = go.Figure()
85
+ for col in labels:
86
+ fig.add_trace(go.Bar(x=[col], y=pred_df[col], name=col))
87
+
88
+ fig.update_layout(title="Predicted Impact", template="plotly_dark")
89
+ st.plotly_chart(fig)
90
+
91
+ if __name__ == "__main__":
92
+ main()