Spaces:
Sleeping
Sleeping
File size: 2,085 Bytes
aab7641 3208c00 aab7641 b160f6c 9b0cb23 b160f6c aab7641 b160f6c aab7641 08cb881 b160f6c 08cb881 aab7641 b160f6c aab7641 b160f6c aab7641 b160f6c aab7641 b160f6c aab7641 b160f6c aab7641 08cb881 b160f6c aab7641 b160f6c 08cb881 b160f6c 08cb881 b160f6c aab7641 3208c00 aab7641 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 | import io
from pathlib import Path
import streamlit as st
from fastai.vision.all import load_learner, PILImage
# Location of the exported FastAI learner.
# NOTE: this is a *relative* path, resolved against the process's current working
# directory (presumably the repository root when launched via `streamlit run`, as is
# typical on Hugging Face Spaces — confirm for your deployment).
MODEL_PATH = Path("models/pokemon_gen9_classifier_resnet101_after_cleaning.pkl")
@st.cache_resource
def load_model():
    """Return the cached FastAI learner stored at ``MODEL_PATH``, or ``None``.

    The result is memoised with ``st.cache_resource``, so the learner is loaded
    once and reused across reruns. If the file is missing, or loading it fails
    (e.g. an incompatible export), the reason is shown with ``st.error`` and
    ``None`` is returned.

    NOTE: a ``None`` result is cached as well, so after fixing a missing or broken
    model file the cache must be cleared (or the app restarted) for it to be picked up.
    """
    model_file = MODEL_PATH
    if model_file.exists():
        try:
            # load_learner unpickles the file — only ever point this at a trusted model.
            return load_learner(model_file)
        except Exception as load_err:
            # Report the failure in the UI instead of crashing the whole app.
            st.error(f"⚠️ Error loading model:\n\n{load_err}")
            return None
    st.error(f"❌ Model not found at {model_file}")
    return None
def predict(learner, img_bytes: bytes):
    """Classify a single uploaded image.

    Args:
        learner: The FastAI learner to run (as returned by ``load_model``).
        img_bytes: The raw, still-encoded image file contents.

    Returns:
        A ``(pred, probs)`` tuple: the predicted label and the per-class
        probabilities, both exactly as produced by ``learner.predict``.
    """
    img = PILImage.create(io.BytesIO(img_bytes))
    # learner.predict yields (prediction, class index, probabilities); the index
    # is not needed by callers, so it is discarded.
    pred, _, probs = learner.predict(img)
    return pred, probs
def _top_predictions(vocab, probs, k=5):
    """Return the ``k`` most probable ``(label, probability)`` pairs, highest first."""
    # float() accepts plain numbers and 0-d tensors alike, so ranking and the
    # ``:.4f`` formatting below never depend on tensor comparison semantics.
    scored = [(label, float(p)) for label, p in zip(vocab, probs)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


def main():
    """Render the app: load the model, accept an image upload, show predictions."""
    st.title("🎯 FastAI Image Classifier")
    st.write("Upload an image and the model will predict its class.")

    learner = load_model()
    if learner is None:
        st.warning(
            "Model not loaded. Please ensure the `.pkl` file is correctly placed under `models/` and committed with Git LFS."
        )
        st.stop()

    uploaded_file = st.file_uploader("📤 Choose an image...", type=["png", "jpg", "jpeg"])
    if uploaded_file is None:
        return

    # NOTE(review): `use_column_width` is deprecated in recent Streamlit releases
    # (superseded by `use_container_width`); kept as-is because the pinned
    # Streamlit version is not visible here — switch once it is confirmed.
    st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)

    with st.spinner("Predicting..."):
        try:
            # getvalue() returns the whole upload regardless of the stream
            # position; read() would only return what remains after any other
            # consumer of the same UploadedFile (e.g. st.image) moved the cursor.
            pred, probs = predict(learner, uploaded_file.getvalue())
            st.success(f"✅ Predicted: **{pred}**")

            # Show top-5 predictions
            top = _top_predictions(learner.dls.vocab, probs, k=5)
            st.write("### Top Predictions:")
            for label, p in top:
                st.write(f"- {label}: {p:.4f}")
        except Exception as e:
            # UI boundary: surface any prediction failure instead of crashing.
            st.error(f"Error during prediction: {e}")


if __name__ == "__main__":
    main()
|