lrschuman17 commited on
Commit
8764ccc
ยท
verified ยท
1 Parent(s): 30a0266

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +23 -17
src/streamlit_app.py CHANGED
@@ -1,7 +1,8 @@
1
  import streamlit as st
2
  import torch
 
3
  from transformers import AutoTokenizer
4
- from predict_utils import predict_injury # assumes predict_injury is defined here
5
 
6
  # โœ… Load tokenizer
7
  @st.cache_resource
@@ -11,32 +12,36 @@ def load_tokenizer():
11
  # โœ… Load model
12
  @st.cache_resource
13
  def load_model():
14
- model_path = "model/final_injury_model.pt" # or use a Hugging Face model ID if from hub
 
 
 
 
15
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
  model = torch.load(model_path, map_location=device)
17
- if not os.path.exists(model_path):
18
- st.error(f"Model file not found at: {model_path}")
19
- return None
20
  model.eval()
21
-
22
-
23
  return model
 
 
24
  def main():
25
- st.title("NBA Injury Type & Duration Classifier ๐Ÿ€")
26
 
27
  model = load_model()
28
  tokenizer = load_tokenizer()
29
 
30
- st.markdown("Enter injury note and details to get predictions")
 
 
 
31
 
32
- # form inputs
33
- text = st.text_input("Injury description", "player has a sprained ankle")
34
- prior_injuries = st.number_input("Prior Injuries", min_value=0, value=1)
35
- injury_type_id = st.selectbox("Injury Type", {"bone": 0, "muscle": 1, "joint": 2, "ligament": 3})
36
- position_id = st.selectbox("Position", {"PG": 1, "SG": 2, "SF": 3, "PF": 4, "C": 5})
37
 
38
  # prediction
39
- if st.button("Predict"):
40
  label_map_type = ['bone', 'muscle', 'joint', 'ligament']
41
  label_map_duration = ['short', 'medium', 'long']
42
 
@@ -51,7 +56,8 @@ def main():
51
  label_map_duration=label_map_duration
52
  )
53
 
54
- st.success(f"**Injury Type:** {type_label} ({type_conf:.1%})")
55
- st.success(f"**Recovery Duration:** {duration_label} ({duration_conf:.1%})")
 
56
  if __name__ == "__main__":
57
  main()
 
1
  import streamlit as st
2
  import torch
3
+ import os
4
  from transformers import AutoTokenizer
5
+ from predict_utils import predict_injury # assumes predict_injury is defined in this module
6
 
7
  # โœ… Load tokenizer
8
  @st.cache_resource
 
12
  # โœ… Load model
13
  @st.cache_resource
14
  def load_model():
15
+ model_path = "model/final_injury_model.pt" # adjust if needed
16
+ if not os.path.exists(model_path):
17
+ st.error(f"Model file not found at: {model_path}")
18
+ return None
19
+
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
  model = torch.load(model_path, map_location=device)
 
 
 
22
  model.eval()
 
 
23
  return model
24
+
25
+ # โœ… Main Streamlit app
26
  def main():
27
+ st.title("๐Ÿ€ NBA Injury Type & Duration Classifier")
28
 
29
  model = load_model()
30
  tokenizer = load_tokenizer()
31
 
32
+ if model is None:
33
+ st.stop() # prevent rest of UI from running if model isn't loaded
34
+
35
+ st.markdown("Enter an injury report and structured data to get predictions.")
36
 
37
+ # user inputs
38
+ text = st.text_input("๐Ÿ“ Injury Description", "player has a sprained ankle")
39
+ prior_injuries = st.number_input("๐Ÿ” Prior Injuries", min_value=0, value=1)
40
+ injury_type_id = st.selectbox("๐Ÿฉป Injury Type", {"bone": 0, "muscle": 1, "joint": 2, "ligament": 3})
41
+ position_id = st.selectbox("๐Ÿ€ Player Position", {"PG": 1, "SG": 2, "SF": 3, "PF": 4, "C": 5})
42
 
43
  # prediction
44
+ if st.button("๐Ÿ”ฎ Predict"):
45
  label_map_type = ['bone', 'muscle', 'joint', 'ligament']
46
  label_map_duration = ['short', 'medium', 'long']
47
 
 
56
  label_map_duration=label_map_duration
57
  )
58
 
59
+ st.success(f"**Predicted Injury Type:** `{type_label}` ({type_conf:.1%} confidence)")
60
+ st.success(f"**Predicted Duration:** `{duration_label}` ({duration_conf:.1%} confidence)")
61
+
62
  if __name__ == "__main__":
63
  main()