"""Streamlit app that classifies Generation 9 Pokémon images with a fastai learner."""

import io
import textwrap
from pathlib import Path

import streamlit as st
from fastai.vision.all import load_learner, PILImage

# Relative path: resolved against the current working directory, so the app
# must be launched from the folder that contains `models/` (e.g. the repo root).
MODEL_PATH = Path("models/pokemon_gen9_classifier_resnet101_after_cleaning.pkl")

# Folder scanned for clickable sample images (also relative to the working directory).
EXAMPLES_PATH = Path("examples")
# File suffixes accepted as example images, compared case-insensitively.
EXAMPLE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png"})
# Maximum number of example thumbnails shown.
MAX_EXAMPLES = 5
# Number of ranked predictions listed under the main result.
TOP_K = 5
# Badges for the first three ranks; every later rank gets "⭐".
RANK_BADGES = ("🥇", "🥈", "🥉")

# Custom CSS for modern UI.
# NOTE(review): the stylesheet body is empty, so this call currently renders nothing.
st.markdown(""" """, unsafe_allow_html=True)


@st.cache_resource
def load_model():
    """Load and cache the fastai learner stored at MODEL_PATH.

    Returns:
        The loaded learner, or None when the file is missing or fails to load.
        Both failure cases report the problem with ``st.error``. Because the
        function returns (rather than raises) on failure, ``st.cache_resource``
        caches the None result as well; after fixing the model file, clear the
        cache or restart the app.
    """
    if not MODEL_PATH.exists():
        st.error(f"❌ Model not found at {MODEL_PATH}")
        return None
    try:
        # load_learner unpickles the file: only ever point it at a trusted model.
        return load_learner(MODEL_PATH)
    except Exception as e:  # broad on purpose: any load failure is shown in the UI
        st.error(f"⚠️ Error loading model:\n\n{e}")
        return None


def predict(learner, img_bytes: bytes):
    """Classify an image given its raw encoded bytes (PNG/JPEG/...).

    Returns:
        ``(pred, probs)``: the predicted label and the per-class probabilities
        returned by ``learner.predict``. The UI pairs ``probs`` positionally
        with ``learner.dls.vocab``.
    """
    img = PILImage.create(io.BytesIO(img_bytes))
    pred, _pred_idx, probs = learner.predict(img)
    return pred, probs


def _find_example_images(folder):
    """Return the image files directly inside *folder*, sorted by path.

    Sorting makes thumbnail order (and the button keys derived from it)
    stable across platforms; suffixes are matched case-insensitively so
    files such as ``PIKACHU.JPG`` are not skipped.
    """
    return sorted(
        path
        for path in folder.iterdir()
        if path.is_file() and path.suffix.lower() in EXAMPLE_SUFFIXES
    )


def _render_examples():
    """Show up to MAX_EXAMPLES sample images, each with a "Use" button.

    Clicking a button stores that image's Path in
    ``st.session_state.example_image``; ``main`` consumes it later in the
    same script run.
    """
    st.markdown("---")
    st.markdown("### 🖼️ Try with Example Images")

    if not EXAMPLES_PATH.is_dir():
        st.info("💡 **Tip:** Create an 'examples' folder with sample Pokémon images to display them here!")
        return

    example_images = _find_example_images(EXAMPLES_PATH)[:MAX_EXAMPLES]
    if not example_images:
        st.info("No example images found in the 'examples' folder.")
        return

    cols = st.columns(len(example_images))
    for idx, (col, img_path) in enumerate(zip(cols, example_images)):
        with col:
            st.image(str(img_path), use_container_width=True, caption=img_path.stem)
            if st.button("Use", key=f"example_{idx}"):
                st.session_state.example_image = img_path


def _read_image_bytes(source):
    """Return the full contents of an example ``Path`` or an uploaded file."""
    if isinstance(source, Path):
        return source.read_bytes()
    # getvalue() returns the whole buffer regardless of the current read
    # position; read() would depend on whatever consumed the file earlier
    # (e.g. the st.image preview rendered just before).
    return source.getvalue()


def _render_prediction(pred, probs, vocab, top_k=TOP_K):
    """Render the top label, its confidence, and the top-k ranked classes."""
    st.markdown("### 🎯 Prediction Result")
    st.success(f"✨ **{pred}**")

    max_prob = float(probs.max())
    st.metric("Confidence", f"{max_prob * 100:.1f}%", delta=None)

    st.markdown(f"### 📊 Top {top_k} Predictions")
    # Convert to plain floats once so sorting and formatting work on Python numbers.
    ranked = sorted(
        ((label, float(p)) for label, p in zip(vocab, probs)),
        key=lambda item: item[1],
        reverse=True,
    )
    for rank, (label, p) in enumerate(ranked[:top_k]):
        badge = RANK_BADGES[rank] if rank < len(RANK_BADGES) else "⭐"
        st.markdown(f"{badge} {label} {p * 100:.1f}%")
        st.progress(p)


def _render_instructions():
    """Render the placeholder shown while no image is selected."""
    st.info("👆 Upload an image to get started!")
    st.markdown(textwrap.dedent("""
        ### How to use:
        1. 📤 Upload a Pokémon image (PNG, JPG, or JPEG)
        2. ⏳ Wait for the AI to analyze it
        3. 🎉 See the prediction and confidence scores!

        **Tip:** Use clear, well-lit images for best results!
        """))


def main():
    """Render the classifier page: header, examples, upload area, and results."""
    st.title("🎮 Pokémon Gen 9 Classifier")
    st.markdown("Upload a Pokémon image and discover which species it is!")

    learner = load_model()
    if learner is None:
        st.warning(
            "⚠️ Model not loaded. Please ensure the `.pkl` file is correctly placed under `models/` and committed with Git LFS."
        )
        st.stop()

    _render_examples()
    st.markdown("---")

    col1, col2 = st.columns([1, 1])

    # An example picked via its "Use" button is consumed on this run only;
    # removing it here means the next rerun falls back to the uploader.
    example_path = st.session_state.pop("example_image", None)

    with col1:
        file_upload = st.file_uploader(
            "Choose a Pokémon image",
            type=["png", "jpg", "jpeg"],
            help="Upload a clear image of a Generation 9 Pokémon",
        )
        # A real upload takes priority over a selected example.
        if file_upload is not None:
            image_source, preview = file_upload, file_upload
        elif example_path is not None:
            image_source, preview = example_path, str(example_path)
        else:
            image_source = preview = None

        if preview is not None:
            st.image(preview, use_container_width=True)

    with col2:
        if image_source is None:
            _render_instructions()
            return
        with st.spinner("🔍 Analyzing image..."):
            try:
                pred, probs = predict(learner, _read_image_bytes(image_source))
                _render_prediction(pred, probs, learner.dls.vocab)
            except Exception as e:  # surface any inference/rendering failure in the UI
                st.error(f"❌ Error during prediction: {e}")


if __name__ == "__main__":
    main()