"""Streamlit front end for the NBA injury type & duration classifier.

Loads a Hugging Face tokenizer and a locally saved PyTorch model, collects an
injury description plus a few player features from the user, and displays the
predictions returned by ``predict_utils.predict_injury``.
"""

import os

import streamlit as st
import torch
from transformers import AutoTokenizer

from predict_utils import predict_injury

# Path is resolved relative to the current working directory of the server.
MODEL_PATH = "model/final_injury_model.pt"
TOKENIZER_NAME = "distilbert-base-uncased"

# Select-box label -> integer ID passed to predict_injury as a feature.
INJURY_TYPE_IDS = {"bone": 0, "muscle": 1, "joint": 2, "ligament": 3}
POSITION_IDS = {"PG": 1, "SG": 2, "SF": 3, "PF": 4, "C": 5}

# Label lists handed to predict_injury for decoding its outputs.
LABEL_MAP_TYPE = ["bone", "muscle", "joint", "ligament"]
LABEL_MAP_DURATION = ["short", "medium", "long"]


# 🔹 Load tokenizer from Hugging Face
@st.cache_resource
def load_tokenizer():
    """Return the pretrained tokenizer, cached across reruns and sessions."""
    return AutoTokenizer.from_pretrained(TOKENIZER_NAME)


@st.cache_resource
def _load_model_from_disk(model_path, device_type):
    """Deserialize the model at *model_path* onto *device_type* in eval mode.

    Cached per (model_path, device_type). Only called once the file is known
    to exist, so a missing file is never cached as a permanent failure.
    """
    device = torch.device(device_type)
    # The checkpoint is a whole pickled module (we call .eval() on the result),
    # not a bare state_dict, so weights_only must be False: PyTorch >= 2.6
    # defaults it to True and would refuse to load this file.
    # SECURITY: unpickling can execute arbitrary code - only point MODEL_PATH
    # at a file you trust.
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.eval()
    return model


# 🔹 Load model from file
def load_model():
    """Return the trained model, or ``None`` after showing an error if its file is missing.

    The existence check runs outside the cache on every rerun, so a model file
    that appears later is picked up without restarting the server.
    """
    if not os.path.isfile(MODEL_PATH):
        st.error(f"Model file not found at: {MODEL_PATH}")
        return None
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    return _load_model_from_disk(MODEL_PATH, device_type)


# 💡 Main app
def main():
    """Render the classifier UI and run a prediction when the button is pressed."""
    # set_page_config must be the first Streamlit command executed in a run.
    st.set_page_config(page_title="NBA Injury Type & Duration Classifier", page_icon="🏀")
    st.title("NBA Injury Type & Duration Classifier 🏀")

    model = load_model()
    tokenizer = load_tokenizer()
    if model is None or tokenizer is None:
        st.stop()

    st.markdown(
        "Enter an injury description and player details to get predicted "
        "injury type and expected recovery duration."
    )

    # 🔹 User Inputs
    text = st.text_area("Injury description", "player has a sprained ankle")
    prior_injuries = st.number_input("Number of Prior Injuries", min_value=0, value=1)
    # st.selectbox returns the selected *option*; when given a dict it offers
    # (and returns) the keys, so map the chosen label back to its integer ID.
    injury_type_label = st.selectbox("General Injury Type", list(INJURY_TYPE_IDS))
    injury_type_id = INJURY_TYPE_IDS[injury_type_label]
    position_label = st.selectbox("Player Position", list(POSITION_IDS))
    position_id = POSITION_IDS[position_label]

    # 🔹 Prediction button
    if not st.button("Predict"):
        return

    try:
        type_label, type_conf, duration_label, duration_conf = predict_injury(
            model=model,
            tokenizer=tokenizer,
            text=text,
            prior_injuries=prior_injuries,
            injury_type_id=injury_type_id,
            position_id=position_id,
            label_map_type=LABEL_MAP_TYPE,
            label_map_duration=LABEL_MAP_DURATION,
        )
        st.success(f"**Predicted Injury Type:** {type_label} ({type_conf:.1%} confidence)")
        st.success(f"**Expected Duration:** {duration_label} ({duration_conf:.1%} confidence)")
    except Exception as e:  # UI boundary: surface any failure to the user.
        st.error(f"Prediction failed: {e}")


if __name__ == "__main__":
    main()