Spaces:
Sleeping
Sleeping
| import io | |
| from pathlib import Path | |
| import streamlit as st | |
| from fastai.vision.all import load_learner, PILImage | |
# Path to the exported FastAI learner. It is resolved relative to this script
# (not the current working directory) so the model is found no matter which
# directory `streamlit run` is launched from.
MODEL_PATH = (
    Path(__file__).resolve().parent
    / "models"
    / "pokemon_gen9_classifier_resnet101_after_cleaning.pkl"
)
# Page-wide stylesheet for the custom UI, injected as raw HTML on every rerun.
# The `.subtitle` class is used by the header paragraph rendered in main().
# NOTE(review): selectors such as `.prediction-card`, `.prob-bar`, `.prob-fill`
# and `.prob-label` do not appear in any HTML this file emits — confirm they are
# unused before pruning them.
_CUSTOM_CSS = """
<style>
/* Main container styling */
.main {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 2rem;
}
/* Card-like containers */
.stApp {
max-width: 1200px;
margin: 0 auto;
}
/* Title styling */
h1 {
color: white !important;
text-align: center;
font-size: 3rem !important;
font-weight: 800 !important;
margin-bottom: 0.5rem !important;
text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
}
/* Subtitle styling */
.subtitle {
text-align: center;
color: rgba(255,255,255,0.9);
font-size: 1.2rem;
margin-bottom: 2rem;
}
/* File uploader styling */
.stFileUploader {
background: white;
border-radius: 15px;
padding: 2rem;
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
}
/* Prediction result card */
.prediction-card {
background: white;
border-radius: 15px;
padding: 2rem;
margin-top: 2rem;
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
}
/* Success message styling */
.stSuccess {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white !important;
border-radius: 10px;
font-size: 1.5rem;
font-weight: bold;
text-align: center;
padding: 1rem;
}
/* Progress bar */
.stProgress > div > div {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
}
/* Buttons */
.stButton > button {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
border-radius: 10px;
padding: 0.75rem 2rem;
font-size: 1.1rem;
font-weight: 600;
transition: all 0.3s ease;
}
.stButton > button:hover {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
}
/* Image container */
.uploaded-image {
border-radius: 15px;
overflow: hidden;
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
margin: 2rem 0;
}
/* Probability bars */
.prob-bar {
background: #f0f2f6;
border-radius: 10px;
height: 40px;
margin: 0.5rem 0;
overflow: hidden;
position: relative;
}
.prob-fill {
height: 100%;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
display: flex;
align-items: center;
padding: 0 1rem;
color: white;
font-weight: 600;
transition: width 0.5s ease;
}
.prob-label {
position: absolute;
left: 1rem;
top: 50%;
transform: translateY(-50%);
font-weight: 600;
z-index: 1;
}
</style>
"""

# unsafe_allow_html is required, otherwise the <style> tag is escaped as text.
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
@st.cache_resource(show_spinner="Loading model...")
def _load_learner_cached(model_path: str):
    """Unpickle the FastAI learner at *model_path* and keep it for later reruns.

    ``st.cache_resource`` returns the same object on subsequent calls with the
    same argument, so the (large) model is not reloaded on every rerun. If
    loading raises, nothing is cached and the next call tries again.
    """
    # NOTE: load_learner unpickles the file; only point this at trusted models.
    return load_learner(model_path)


def load_model():
    """Load the FastAI learner, reusing a cached copy across Streamlit reruns.

    Returns:
        The loaded learner, or ``None`` when the model file is missing or
        cannot be loaded. In both failure cases an error is shown in the app.
    """
    # Checked on every call (outside the cache) so a missing file is reported
    # each rerun and a later-added file is picked up without a restart.
    if not MODEL_PATH.exists():
        st.error(f"❌ Model not found at {MODEL_PATH}")
        return None
    try:
        # str() gives the cache a simple, stable hash key.
        return _load_learner_cached(str(MODEL_PATH))
    except Exception as e:  # UI boundary: report any load failure instead of crashing
        st.error(f"⚠️ Error loading model:\n\n{e}")
        return None
def predict(learner, img_bytes: bytes):
    """Classify an encoded image with *learner*.

    Args:
        learner: The FastAI learner returned by ``load_model``.
        img_bytes: Raw bytes of the encoded image file.

    Returns:
        ``(label, probabilities)``: the first and third items of
        ``learner.predict``'s result (decoded class and class probabilities).
    """
    image = PILImage.create(io.BytesIO(img_bytes))
    label, _, probabilities = learner.predict(image)
    return label, probabilities
# File suffixes (compared case-insensitively) accepted as gallery examples.
_EXAMPLE_SUFFIXES = (".jpg", ".jpeg", ".png")
# Maximum number of example thumbnails shown in the gallery.
_MAX_EXAMPLES = 5
# Number of ranked predictions listed below the main result.
_TOP_K = 5
# Badges for ranks 1-3; every lower rank gets _FALLBACK_BADGE.
_RANK_BADGES = ("🥇", "🥈", "🥉")
_FALLBACK_BADGE = "⭐"

# One ranked-prediction row: badge and label on the left, percentage on the right.
_PREDICTION_ROW_HTML = """
<div style="margin: 1rem 0;">
<div style="display: flex; justify-content: space-between; margin-bottom: 0.3rem;">
<span style="font-weight: 600; color: #1f2937;">{badge} {label}</span>
<span style="font-weight: 600; color: #667eea;">{percent:.1f}%</span>
</div>
</div>
"""

# Instructions shown in the results column until an image is chosen.
_HOW_TO_USE_MD = """
### How to use:
1. 📤 Upload a Pokémon image (PNG, JPG, or JPEG)
2. ⏳ Wait for the AI to analyze it
3. 🎉 See the prediction and confidence scores!
**Tip:** Use clear, well-lit images for best results!
"""


def _render_examples(examples_path: Path):
    """Render the example-image gallery stored in *examples_path*.

    Returns:
        The ``Path`` of the example whose "Use" button was clicked during this
        rerun, or ``None`` if nothing was clicked or no examples are available.
    """
    if not examples_path.is_dir():
        st.info("💡 **Tip:** Create an 'examples' folder with sample Pokémon images to display them here!")
        return None
    # Sorted so the gallery order (and which files survive the cap) is stable
    # across platforms; suffixes are matched case-insensitively (e.g. ".JPG").
    example_images = sorted(
        path
        for path in examples_path.iterdir()
        if path.is_file() and path.suffix.lower() in _EXAMPLE_SUFFIXES
    )[:_MAX_EXAMPLES]
    if not example_images:
        st.info("No example images found in the 'examples' folder.")
        return None

    selected = None
    columns = st.columns(len(example_images))
    for idx, (column, img_path) in enumerate(zip(columns, example_images)):
        with column:
            st.image(str(img_path), use_container_width=True, caption=img_path.stem)
            # The gallery is drawn before the prediction in the same rerun, so
            # the clicked example can be returned directly.
            if st.button("Use", key=f"example_{idx}"):
                selected = img_path
    return selected


def _read_image_bytes(source) -> bytes:
    """Return the raw bytes of an example ``Path`` or a Streamlit uploaded file."""
    if isinstance(source, Path):
        return source.read_bytes()
    # getvalue() returns the whole buffer regardless of the current stream
    # position, whereas read() yields b"" once something has consumed it.
    return source.getvalue()


def _render_prediction(learner, source) -> None:
    """Classify *source* and render the top label, its confidence and the top-K list."""
    with st.spinner("🔍 Analyzing image..."):
        try:
            pred, probs = predict(learner, _read_image_bytes(source))

            st.markdown("### 🎯 Prediction Result")
            st.success(f"✨ **{pred}**")

            max_prob = float(probs.max())
            st.metric("Confidence", f"{max_prob*100:.1f}%", delta=None)

            st.markdown(f"### 📊 Top {_TOP_K} Predictions")
            ranked = sorted(
                zip(learner.dls.vocab, map(float, probs)),
                key=lambda item: item[1],
                reverse=True,
            )[:_TOP_K]
            for rank, (label, prob) in enumerate(ranked):
                badge = _RANK_BADGES[rank] if rank < len(_RANK_BADGES) else _FALLBACK_BADGE
                st.markdown(
                    _PREDICTION_ROW_HTML.format(badge=badge, label=label, percent=prob * 100),
                    unsafe_allow_html=True,
                )
                # st.progress rejects floats outside [0.0, 1.0]; guard against
                # any rounding drift in the model's probabilities.
                st.progress(min(max(prob, 0.0), 1.0))
        except Exception as e:  # UI boundary: show the failure instead of crashing
            st.error(f"❌ Error during prediction: {e}")


def main():
    """Render the classifier page: header, example gallery, uploader and results."""
    st.title("🎮 Pokémon Gen 9 Classifier")
    st.markdown('<p class="subtitle">Upload a Pokémon image and discover which species it is!</p>', unsafe_allow_html=True)

    learner = load_model()
    if learner is None:
        st.warning(
            "⚠️ Model not loaded. Please ensure the `.pkl` file is correctly placed under `models/` and committed with Git LFS."
        )
        st.stop()

    st.markdown("---")
    st.markdown("### 🖼️ Try with Example Images")
    # Resolved relative to this script (not the CWD) so the gallery is found
    # regardless of where `streamlit run` is launched from.
    selected_example = _render_examples(Path(__file__).resolve().parent / "examples")
    st.markdown("---")

    col1, col2 = st.columns([1, 1])

    with col1:
        file_upload = st.file_uploader(
            "Choose a Pokémon image",
            type=["png", "jpg", "jpeg"],
            help="Upload a clear image of a Generation 9 Pokémon",
        )
        # An uploaded file takes priority over a clicked example.
        source = file_upload if file_upload is not None else selected_example
        if source is not None:
            # Shown directly: each st.markdown call is its own element, so an
            # HTML <div> opened in one call cannot wrap this image.
            st.image(str(source) if isinstance(source, Path) else source, use_container_width=True)

    with col2:
        if source is not None:
            _render_prediction(learner, source)
        else:
            st.info("👆 Upload an image to get started!")
            st.markdown(_HOW_TO_USE_MD)
# Script entry point.
# NOTE(review): assumes the app is started with `streamlit run`, which executes
# this file as `__main__` on every rerun — confirm if another launcher is used.
if __name__ == "__main__":
    main()