Spaces:
Build error
Build error
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +6 -1
src/streamlit_app.py
CHANGED
|
@@ -11,10 +11,15 @@ def load_tokenizer():
|
|
| 11 |
# β
Load model
|
| 12 |
@st.cache_resource
|
| 13 |
def load_model():
|
| 14 |
-
model_path = "final_injury_model.pt" # or use a Hugging Face model ID if from hub
|
| 15 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 16 |
model = torch.load(model_path, map_location=device)
|
|
|
|
|
|
|
|
|
|
| 17 |
model.eval()
|
|
|
|
|
|
|
| 18 |
return model
|
| 19 |
def main():
|
| 20 |
st.title("NBA Injury Type & Duration Classifier π")
|
|
|
|
| 11 |
# β
Load model
|
| 12 |
@st.cache_resource
|
| 13 |
def load_model():
|
| 14 |
+
model_path = "model/final_injury_model.pt" # or use a Hugging Face model ID if from hub
|
| 15 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 16 |
model = torch.load(model_path, map_location=device)
|
| 17 |
+
if not os.path.exists(model_path):
|
| 18 |
+
st.error(f"Model file not found at: {model_path}")
|
| 19 |
+
return None
|
| 20 |
model.eval()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
return model
|
| 24 |
def main():
|
| 25 |
st.title("NBA Injury Type & Duration Classifier π")
|