"""Streamlit app that classifies uploaded images with a pickled FastAI learner."""

import io
from pathlib import Path

import streamlit as st
from fastai.vision.all import load_learner, PILImage

# Resolve the model relative to this script rather than the process CWD, so the
# app finds `models/` next to this file no matter where `streamlit run` is
# launched from (e.g. on Hugging Face Spaces).
MODEL_PATH = (
    Path(__file__).resolve().parent
    / "models"
    / "pokemon_gen9_classifier_resnet101_after_cleaning.pkl"
)

# Number of highest-probability classes listed under the prediction.
TOP_K = 5


@st.cache_resource
def _load_learner_cached():
    """Load the FastAI learner from MODEL_PATH and cache it across reruns.

    Raises whatever ``load_learner`` raises. A call that raises is not cached by
    Streamlit, so a failed load is retried on the next rerun instead of being
    remembered until the server restarts.
    """
    # NOTE: load_learner unpickles the file; only point MODEL_PATH at a trusted,
    # self-produced model file.
    return load_learner(MODEL_PATH)


def load_model():
    """Return the cached FastAI learner, or None if it is missing or incompatible.

    On failure an ``st.error`` message explaining the problem is rendered. Only a
    successfully loaded learner is cached, so adding or fixing the model file
    takes effect on the next rerun.
    """
    if not MODEL_PATH.exists():
        st.error(f"❌ Model not found at {MODEL_PATH}")
        return None
    try:
        return _load_learner_cached()
    except Exception as e:  # UI boundary: report any unpickling/version error
        st.error(f"⚠️ Error loading model:\n\n{e}")
        return None


def predict(learner, img_bytes: bytes):
    """Classify encoded image bytes.

    Args:
        learner: FastAI learner as returned by ``load_model``.
        img_bytes: Raw contents of an image file (e.g. PNG/JPEG).

    Returns:
        ``(pred, probs)``: the predicted label and the per-class probabilities,
        aligned with ``learner.dls.vocab``.
    """
    img = PILImage.create(io.BytesIO(img_bytes))
    pred, _pred_idx, probs = learner.predict(img)
    return pred, probs


def _top_predictions(vocab, probs, k=TOP_K):
    """Return up to ``k`` (label, probability) pairs, highest probability first."""
    # float() turns 0-d tensor elements into plain numbers for sorting/formatting.
    scored = [(label, float(p)) for label, p in zip(vocab, probs)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


def main():
    """Render the UI: load the model, accept an image upload, show predictions."""
    st.title("🎯 FastAI Image Classifier")
    st.write("Upload an image and the model will predict its class.")

    learner = load_model()
    if learner is None:
        st.warning(
            "Model not loaded. Please ensure the `.pkl` file is correctly placed under `models/` and committed with Git LFS."
        )
        st.stop()

    uploaded_file = st.file_uploader("📤 Choose an image...", type=["png", "jpg", "jpeg"])
    if uploaded_file is None:
        return

    # NOTE(review): `use_column_width` is deprecated in recent Streamlit releases
    # in favour of `use_container_width`; switch once the pinned version supports it.
    st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)

    with st.spinner("Predicting..."):
        try:
            # getvalue() returns the whole upload regardless of stream position,
            # whereas read() depends on st.image not having consumed the stream.
            pred, probs = predict(learner, uploaded_file.getvalue())
            st.success(f"✅ Predicted: **{pred}**")

            st.write("### Top Predictions:")
            for label, p in _top_predictions(learner.dls.vocab, probs):
                st.write(f"- {label}: {p:.4f}")
        except Exception as e:  # UI boundary: surface decode/inference failures
            st.error(f"Error during prediction: {e}")


if __name__ == "__main__":
    main()