lrschuman17 commited on
Commit
93984c5
ยท
verified ยท
1 Parent(s): 8764ccc

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +47 -38
src/streamlit_app.py CHANGED
@@ -1,63 +1,72 @@
 
1
  import streamlit as st
2
  import torch
3
- import os
4
  from transformers import AutoTokenizer
5
- from predict_utils import predict_injury # assumes predict_injury is defined in this module
6
 
7
- # โœ… Load tokenizer
8
  @st.cache_resource
9
  def load_tokenizer():
10
- return AutoTokenizer.from_pretrained("distilbert-base-uncased")
 
11
 
12
- # โœ… Load model
13
  @st.cache_resource
14
  def load_model():
15
- model_path = "model/final_injury_model.pt" # adjust if needed
 
 
16
  if not os.path.exists(model_path):
17
  st.error(f"Model file not found at: {model_path}")
18
  return None
19
 
20
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
  model = torch.load(model_path, map_location=device)
22
  model.eval()
23
  return model
24
 
25
- # โœ… Main Streamlit app
26
  def main():
27
- st.title("๐Ÿ€ NBA Injury Type & Duration Classifier")
 
28
 
29
  model = load_model()
30
  tokenizer = load_tokenizer()
31
 
32
- if model is None:
33
- st.stop() # prevent rest of UI from running if model isn't loaded
34
-
35
- st.markdown("Enter an injury report and structured data to get predictions.")
36
-
37
- # user inputs
38
- text = st.text_input("๐Ÿ“ Injury Description", "player has a sprained ankle")
39
- prior_injuries = st.number_input("๐Ÿ” Prior Injuries", min_value=0, value=1)
40
- injury_type_id = st.selectbox("๐Ÿฉป Injury Type", {"bone": 0, "muscle": 1, "joint": 2, "ligament": 3})
41
- position_id = st.selectbox("๐Ÿ€ Player Position", {"PG": 1, "SG": 2, "SF": 3, "PF": 4, "C": 5})
42
-
43
- # prediction
44
- if st.button("๐Ÿ”ฎ Predict"):
45
- label_map_type = ['bone', 'muscle', 'joint', 'ligament']
46
- label_map_duration = ['short', 'medium', 'long']
47
-
48
- type_label, type_conf, duration_label, duration_conf = predict_injury(
49
- model=model,
50
- tokenizer=tokenizer,
51
- text=text,
52
- prior_injuries=prior_injuries,
53
- injury_type_id=injury_type_id,
54
- position_id=position_id,
55
- label_map_type=label_map_type,
56
- label_map_duration=label_map_duration
57
- )
58
-
59
- st.success(f"**Predicted Injury Type:** `{type_label}` ({type_conf:.1%} confidence)")
60
- st.success(f"**Predicted Duration:** `{duration_label}` ({duration_conf:.1%} confidence)")
 
 
 
 
 
 
61
 
62
  if __name__ == "__main__":
63
  main()
 
1
+ import os
2
  import streamlit as st
3
  import torch
 
4
  from transformers import AutoTokenizer
5
+ from predict_utils import predict_injury
6
 
7
+ # ๐Ÿ”น Load tokenizer from Hugging Face
8
  @st.cache_resource
9
  def load_tokenizer():
10
+ tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
11
+ return tokenizer
12
 
13
+ # ๐Ÿ”น Load model from file
14
  @st.cache_resource
15
  def load_model():
16
+ model_path = "model/final_injury_model.pt"
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+
19
  if not os.path.exists(model_path):
20
  st.error(f"Model file not found at: {model_path}")
21
  return None
22
 
 
23
  model = torch.load(model_path, map_location=device)
24
  model.eval()
25
  return model
26
 
27
+ # ๐Ÿ’ก Main app
28
  def main():
29
+ st.set_page_config(page_title="NBA Injury Type & Duration Classifier", page_icon="๐Ÿ€")
30
+ st.title("NBA Injury Type & Duration Classifier ๐Ÿ€")
31
 
32
  model = load_model()
33
  tokenizer = load_tokenizer()
34
 
35
+ if model is None or tokenizer is None:
36
+ st.stop()
37
+
38
+ st.markdown("""
39
+ Enter an injury description and player details to get predicted injury type and expected recovery duration.
40
+ """)
41
+
42
+ # ๐Ÿ”น User Inputs
43
+ text = st.text_area("Injury description", "player has a sprained ankle")
44
+ prior_injuries = st.number_input("Number of Prior Injuries", min_value=0, value=1)
45
+ injury_type_id = st.selectbox("General Injury Type", {"bone": 0, "muscle": 1, "joint": 2, "ligament": 3})
46
+ position_id = st.selectbox("Player Position", {"PG": 1, "SG": 2, "SF": 3, "PF": 4, "C": 5})
47
+
48
+ # ๐Ÿ”น Prediction button
49
+ if st.button("Predict"):
50
+ label_map_type = ["bone", "muscle", "joint", "ligament"]
51
+ label_map_duration = ["short", "medium", "long"]
52
+
53
+ try:
54
+ type_label, type_conf, duration_label, duration_conf = predict_injury(
55
+ model=model,
56
+ tokenizer=tokenizer,
57
+ text=text,
58
+ prior_injuries=prior_injuries,
59
+ injury_type_id=injury_type_id,
60
+ position_id=position_id,
61
+ label_map_type=label_map_type,
62
+ label_map_duration=label_map_duration
63
+ )
64
+
65
+ st.success(f"**Predicted Injury Type:** {type_label} ({type_conf:.1%} confidence)")
66
+ st.success(f"**Expected Duration:** {duration_label} ({duration_conf:.1%} confidence)")
67
+
68
+ except Exception as e:
69
+ st.error(f"Prediction failed: {e}")
70
 
71
  if __name__ == "__main__":
72
  main()