# Author: shuraimi
# Last change: "Update src/app.py" (commit 9b0cb23, verified)
import io
from pathlib import Path
import streamlit as st
from fastai.vision.all import load_learner, PILImage
# Location of the exported FastAI learner. NOTE: this is a *relative* path,
# resolved against the current working directory, so the app has to be
# launched from the directory that contains `models/`.
MODEL_PATH = Path("models") / "pokemon_gen9_classifier_resnet101_after_cleaning.pkl"
@st.cache_resource
def _load_learner_cached():
    """Deserialize the FastAI learner stored at ``MODEL_PATH``.

    Only successful loads are cached: if ``load_learner`` raises, the
    exception propagates and Streamlit stores nothing, so the load is
    retried on the next rerun.

    Note: ``load_learner`` unpickles the file, so it must only ever point at a
    trusted model artifact (here, one committed to the repository).
    """
    return load_learner(MODEL_PATH)


def load_model():
    """Load the FastAI learner, reporting problems in the Streamlit UI.

    The expensive deserialization is cached by ``_load_learner_cached``.
    This wrapper is deliberately *not* cached. A missing or broken model file
    is therefore re-checked on every rerun instead of being cached as ``None``
    until the process restarts.

    Returns:
        The loaded learner, or ``None`` if the model file is missing or
        cannot be loaded (an ``st.error`` message is shown in both cases).
    """
    if not MODEL_PATH.exists():
        st.error(f"❌ Model not found at {MODEL_PATH}")
        return None
    try:
        return _load_learner_cached()
    except Exception as e:  # UI boundary: show any load failure instead of crashing
        st.error(f"⚠️ Error loading model:\n\n{e}")
        return None
def predict(learner, img_bytes: bytes):
    """Classify a single image supplied as raw, encoded file bytes.

    Args:
        learner: the FastAI learner to run (as returned by ``load_model``).
        img_bytes: the image file contents (e.g. PNG/JPEG bytes).

    Returns:
        ``(pred, probs)``: the predicted label and the per-class
        probabilities, exactly as produced by ``learner.predict``. The class
        index that ``learner.predict`` also returns is discarded.
    """
    buffer = io.BytesIO(img_bytes)
    label, _idx, class_probs = learner.predict(PILImage.create(buffer))
    return label, class_probs
def main():
    """Render the Streamlit UI: load the model, accept an upload, show predictions.

    Calls ``st.stop()`` to end the script run early if the model could not
    be loaded. Prediction errors are shown with ``st.error`` rather than
    raised.
    """
    st.title("🎯 FastAI Image Classifier")
    st.write("Upload an image and the model will predict its class.")

    learner = load_model()
    if learner is None:
        st.warning(
            "Model not loaded. Please ensure the `.pkl` file is correctly placed under `models/` and committed with Git LFS."
        )
        st.stop()

    uploaded_file = st.file_uploader("📤 Choose an image...", type=["png", "jpg", "jpeg"])
    if uploaded_file is None:
        return

    # NOTE(review): `use_column_width` is deprecated in newer Streamlit releases
    # in favour of `use_container_width`; switch once the pinned version allows.
    st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)

    with st.spinner("Predicting..."):
        try:
            # getvalue() returns the whole upload regardless of the current
            # stream position. read() would return b"" if anything (possibly
            # st.image above) had already advanced the stream.
            pred, probs = predict(learner, uploaded_file.getvalue())
            st.success(f"✅ Predicted: **{pred}**")

            # Show the top-5 classes, highest probability first.
            vocab = learner.dls.vocab
            probs_list = sorted(zip(vocab, probs), key=lambda x: x[1], reverse=True)
            st.write("### Top Predictions:")
            for label, p in probs_list[:5]:
                st.write(f"- {label}: {p:.4f}")
        except Exception as e:  # UI boundary: report instead of crashing the app
            st.error(f"Error during prediction: {e}")
# Run the app when this file is executed as the main script.
if __name__ == "__main__":
    main()