Spaces:
Sleeping
Sleeping
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +23 -17
src/streamlit_app.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
import torch
|
|
|
|
| 3 |
from transformers import AutoTokenizer
|
| 4 |
-
from predict_utils import predict_injury # assumes predict_injury is defined
|
| 5 |
|
| 6 |
# โ
Load tokenizer
|
| 7 |
@st.cache_resource
|
|
@@ -11,32 +12,36 @@ def load_tokenizer():
|
|
| 11 |
# โ
Load model
|
| 12 |
@st.cache_resource
|
| 13 |
def load_model():
|
| 14 |
-
model_path = "model/final_injury_model.pt" #
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 16 |
model = torch.load(model_path, map_location=device)
|
| 17 |
-
if not os.path.exists(model_path):
|
| 18 |
-
st.error(f"Model file not found at: {model_path}")
|
| 19 |
-
return None
|
| 20 |
model.eval()
|
| 21 |
-
|
| 22 |
-
|
| 23 |
return model
|
|
|
|
|
|
|
| 24 |
def main():
|
| 25 |
-
st.title("NBA Injury Type & Duration Classifier
|
| 26 |
|
| 27 |
model = load_model()
|
| 28 |
tokenizer = load_tokenizer()
|
| 29 |
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
-
#
|
| 33 |
-
text = st.text_input("Injury
|
| 34 |
-
prior_injuries = st.number_input("Prior Injuries", min_value=0, value=1)
|
| 35 |
-
injury_type_id = st.selectbox("Injury Type", {"bone": 0, "muscle": 1, "joint": 2, "ligament": 3})
|
| 36 |
-
position_id = st.selectbox("Position", {"PG": 1, "SG": 2, "SF": 3, "PF": 4, "C": 5})
|
| 37 |
|
| 38 |
# prediction
|
| 39 |
-
if st.button("Predict"):
|
| 40 |
label_map_type = ['bone', 'muscle', 'joint', 'ligament']
|
| 41 |
label_map_duration = ['short', 'medium', 'long']
|
| 42 |
|
|
@@ -51,7 +56,8 @@ def main():
|
|
| 51 |
label_map_duration=label_map_duration
|
| 52 |
)
|
| 53 |
|
| 54 |
-
st.success(f"**Injury Type:** {type_label} ({type_conf:.1%})")
|
| 55 |
-
st.success(f"**
|
|
|
|
| 56 |
if __name__ == "__main__":
|
| 57 |
main()
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import torch
|
| 3 |
+
import os
|
| 4 |
from transformers import AutoTokenizer
|
| 5 |
+
from predict_utils import predict_injury # assumes predict_injury is defined in this module
|
| 6 |
|
| 7 |
# โ
Load tokenizer
|
| 8 |
@st.cache_resource
|
|
|
|
| 12 |
# โ
Load model
|
| 13 |
@st.cache_resource
|
| 14 |
def load_model():
|
| 15 |
+
model_path = "model/final_injury_model.pt" # adjust if needed
|
| 16 |
+
if not os.path.exists(model_path):
|
| 17 |
+
st.error(f"Model file not found at: {model_path}")
|
| 18 |
+
return None
|
| 19 |
+
|
| 20 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 21 |
model = torch.load(model_path, map_location=device)
|
|
|
|
|
|
|
|
|
|
| 22 |
model.eval()
|
|
|
|
|
|
|
| 23 |
return model
|
| 24 |
+
|
| 25 |
+
# โ
Main Streamlit app
|
| 26 |
def main():
|
| 27 |
+
st.title("๐ NBA Injury Type & Duration Classifier")
|
| 28 |
|
| 29 |
model = load_model()
|
| 30 |
tokenizer = load_tokenizer()
|
| 31 |
|
| 32 |
+
if model is None:
|
| 33 |
+
st.stop() # prevent rest of UI from running if model isn't loaded
|
| 34 |
+
|
| 35 |
+
st.markdown("Enter an injury report and structured data to get predictions.")
|
| 36 |
|
| 37 |
+
# user inputs
|
| 38 |
+
text = st.text_input("๐ Injury Description", "player has a sprained ankle")
|
| 39 |
+
prior_injuries = st.number_input("๐ Prior Injuries", min_value=0, value=1)
|
| 40 |
+
injury_type_id = st.selectbox("๐ฉป Injury Type", {"bone": 0, "muscle": 1, "joint": 2, "ligament": 3})
|
| 41 |
+
position_id = st.selectbox("๐ Player Position", {"PG": 1, "SG": 2, "SF": 3, "PF": 4, "C": 5})
|
| 42 |
|
| 43 |
# prediction
|
| 44 |
+
if st.button("๐ฎ Predict"):
|
| 45 |
label_map_type = ['bone', 'muscle', 'joint', 'ligament']
|
| 46 |
label_map_duration = ['short', 'medium', 'long']
|
| 47 |
|
|
|
|
| 56 |
label_map_duration=label_map_duration
|
| 57 |
)
|
| 58 |
|
| 59 |
+
st.success(f"**Predicted Injury Type:** `{type_label}` ({type_conf:.1%} confidence)")
|
| 60 |
+
st.success(f"**Predicted Duration:** `{duration_label}` ({duration_conf:.1%} confidence)")
|
| 61 |
+
|
| 62 |
if __name__ == "__main__":
|
| 63 |
main()
|