import io from pathlib import Path import streamlit as st from fastai.vision.all import load_learner, PILImage # ✅ Correct absolute path for Hugging Face Spaces MODEL_PATH = Path("models/pokemon_gen9_classifier_resnet101_after_cleaning.pkl") # Custom CSS for modern UI st.markdown(""" """, unsafe_allow_html=True) @st.cache_resource def load_model(): """Load and cache the FastAI learner. Returns None if model missing or incompatible.""" if not MODEL_PATH.exists(): st.error(f"❌ Model not found at {MODEL_PATH}") return None try: learner = load_learner(MODEL_PATH) return learner except Exception as e: st.error(f"⚠️ Error loading model:\n\n{e}") return None def predict(learner, img_bytes: bytes): """Make a prediction on uploaded image bytes.""" img = PILImage.create(io.BytesIO(img_bytes)) pred, pred_idx, probs = learner.predict(img) return pred, probs def main(): # Header st.title("🎮 Pokémon Gen 9 Classifier") st.markdown('
Upload a Pokémon image and discover which species it is!
', unsafe_allow_html=True) learner = load_model() if learner is None: st.warning( "⚠️ Model not loaded. Please ensure the `.pkl` file is correctly placed under `models/` and committed with Git LFS." ) st.stop() # Example images section st.markdown("---") st.markdown("### 🖼️ Try with Example Images") # Define example images path (adjust this to your actual examples folder) examples_path = Path("examples") if examples_path.exists(): example_images = list(examples_path.glob("*.jpg")) + list(examples_path.glob("*.png")) + list(examples_path.glob("*.jpeg")) if example_images: # Display examples in a grid cols = st.columns(min(5, len(example_images))) for idx, img_path in enumerate(example_images[:5]): # Show max 5 examples with cols[idx]: st.image(str(img_path), use_container_width=True, caption=img_path.stem) if st.button(f"Use", key=f"example_{idx}"): # Store the selected example in session state st.session_state.example_image = img_path else: st.info("No example images found in the 'examples' folder.") else: st.info("💡 **Tip:** Create an 'examples' folder with sample Pokémon images to display them here!") st.markdown("---") # Create two columns for better layout col1, col2 = st.columns([1, 1]) # Check if example image was selected uploaded_file = None display_image = None if 'example_image' in st.session_state: example_path = st.session_state.example_image uploaded_file = example_path display_image = str(example_path) del st.session_state.example_image # Clear after use with col1: file_upload = st.file_uploader( "Choose a Pokémon image", type=["png", "jpg", "jpeg"], help="Upload a clear image of a Generation 9 Pokémon" ) # Prioritize file upload over example if file_upload is not None: uploaded_file = file_upload display_image = file_upload if display_image is not None: st.markdown('