Spaces:
Sleeping
Sleeping
import logging
import os

import streamlit as st
import torch
from transformers import AutoTokenizer

from predict_utils import predict_injury
@st.cache_resource
def load_tokenizer(model_name: str = "distilbert-base-uncased"):
    """Return the Hugging Face tokenizer for ``model_name``.

    Decorated with ``st.cache_resource`` so the tokenizer is created once per
    server process instead of on every Streamlit rerun (Streamlit re-executes
    the whole script on each widget interaction).

    Args:
        model_name: Hub id or local path passed to
            ``AutoTokenizer.from_pretrained``; defaults to the original
            ``"distilbert-base-uncased"``.
    """
    return AutoTokenizer.from_pretrained(model_name)
@st.cache_resource
def _load_model_from_disk(model_path: str, device_name: str):
    """Load the pickled model at ``model_path`` onto ``device_name`` (cached per process)."""
    device = torch.device(device_name)
    # The checkpoint is a whole pickled module (we call .eval() on the result),
    # not a state_dict, so weights_only must be False: PyTorch >= 2.6 defaults to
    # weights_only=True and would refuse to unpickle it.
    # NOTE(review): unpickling can execute arbitrary code — only load trusted files.
    model = torch.load(model_path, map_location=device, weights_only=False)
    model.eval()
    return model


def load_model(model_path: str = "model/final_injury_model.pt"):
    """Load the injury classifier from ``model_path``.

    The model is placed on CUDA when available, otherwise on CPU, and switched
    to eval mode. The expensive load is cached across Streamlit reruns; the
    existence check is not, so a file added later is picked up.

    Args:
        model_path: Path to the saved model; defaults to the original
            ``"model/final_injury_model.pt"``.

    Returns:
        The loaded model, or ``None`` (after showing a Streamlit error) when
        ``model_path`` does not exist.
    """
    if not os.path.exists(model_path):
        st.error(f"Model file not found at: {model_path}")
        return None
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    return _load_model_from_disk(model_path, device_name)
def main():
    """Render the NBA injury type & duration classifier page.

    Loads the model and tokenizer (stopping the script if either is
    unavailable) and collects an injury description plus player details.
    When "Predict" is pressed, it shows the injury type and expected recovery
    duration returned by ``predict_injury``, each with its confidence.
    """
    # NOTE(review): the emoji in the icon/title was mis-encoded in this copy of
    # the file (it rendered as "π"); 🏀 is a best-guess restoration.
    st.set_page_config(page_title="NBA Injury Type & Duration Classifier", page_icon="🏀")
    st.title("NBA Injury Type & Duration Classifier 🏀")

    model = load_model()
    tokenizer = load_tokenizer()
    if model is None or tokenizer is None:
        st.stop()

    st.markdown(
        "Enter an injury description and player details to get predicted "
        "injury type and expected recovery duration."
    )

    # Display label -> integer id passed to predict_injury.
    injury_type_ids = {"bone": 0, "muscle": 1, "joint": 2, "ligament": 3}
    position_ids = {"PG": 1, "SG": 2, "SF": 3, "PF": 4, "C": 5}

    # User inputs
    text = st.text_area("Injury description", "player has a sprained ankle")
    prior_injuries = st.number_input("Number of Prior Injuries", min_value=0, value=1)
    # st.selectbox returns the selected option itself, and a dict iterates over
    # its keys. Passing the dict directly therefore yielded the label ("bone"),
    # not the id (0), so map the chosen label to its id explicitly.
    injury_type_name = st.selectbox("General Injury Type", list(injury_type_ids))
    injury_type_id = injury_type_ids[injury_type_name]
    position_name = st.selectbox("Player Position", list(position_ids))
    position_id = position_ids[position_name]

    if not st.button("Predict"):
        return
    if not text.strip():
        st.warning("Please enter an injury description.")
        return

    # Class names handed to predict_injury; presumably index-aligned with the
    # model's type/duration outputs — TODO confirm against training labels.
    label_map_type = ["bone", "muscle", "joint", "ligament"]
    label_map_duration = ["short", "medium", "long"]
    try:
        type_label, type_conf, duration_label, duration_conf = predict_injury(
            model=model,
            tokenizer=tokenizer,
            text=text,
            prior_injuries=prior_injuries,
            injury_type_id=injury_type_id,
            position_id=position_id,
            label_map_type=label_map_type,
            label_map_duration=label_map_duration,
        )
        st.success(f"**Predicted Injury Type:** {type_label} ({type_conf:.1%} confidence)")
        st.success(f"**Expected Duration:** {duration_label} ({duration_conf:.1%} confidence)")
    except Exception as e:  # UI boundary: report on the page rather than crash it
        # Keep the full traceback in the server log; the page shows only the message.
        logging.getLogger(__name__).exception("Prediction failed")
        st.error(f"Prediction failed: {e}")


# Entry point when the file is executed as a script (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()