# Hugging Face file-viewer header (commented out so the module parses):
# shuraimi's picture · Update app.py · 9e0fb8d verified
import io
from pathlib import Path
import streamlit as st
from fastai.vision.all import load_learner, PILImage
# Location of the exported FastAI learner. It is resolved against this file's
# directory (not the process working directory), so the app finds the model
# no matter which directory it is launched from.
MODEL_PATH = (
    Path(__file__).resolve().parent
    / "models"
    / "pokemon_gen9_classifier_resnet101_after_cleaning.pkl"
)
# Page-wide stylesheet, injected as raw HTML through st.markdown.
# It restyles Streamlit's own widget classes (.stApp, .stFileUploader,
# .stButton, .stProgress, ...) and defines custom classes used by HTML
# snippets in this file (e.g. .subtitle, .uploaded-image).
# NOTE(review): .prediction-card, .prob-bar, .prob-fill and .prob-label are not
# referenced by the visible code; confirm whether they are still needed.
_CUSTOM_CSS = """
<style>
/* Main container styling */
.main {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 2rem;
}
/* Card-like containers */
.stApp {
max-width: 1200px;
margin: 0 auto;
}
/* Title styling */
h1 {
color: white !important;
text-align: center;
font-size: 3rem !important;
font-weight: 800 !important;
margin-bottom: 0.5rem !important;
text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
}
/* Subtitle styling */
.subtitle {
text-align: center;
color: rgba(255,255,255,0.9);
font-size: 1.2rem;
margin-bottom: 2rem;
}
/* File uploader styling */
.stFileUploader {
background: white;
border-radius: 15px;
padding: 2rem;
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
}
/* Prediction result card */
.prediction-card {
background: white;
border-radius: 15px;
padding: 2rem;
margin-top: 2rem;
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
}
/* Success message styling */
.stSuccess {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white !important;
border-radius: 10px;
font-size: 1.5rem;
font-weight: bold;
text-align: center;
padding: 1rem;
}
/* Progress bar */
.stProgress > div > div {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
}
/* Buttons */
.stButton > button {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
border-radius: 10px;
padding: 0.75rem 2rem;
font-size: 1.1rem;
font-weight: 600;
transition: all 0.3s ease;
}
.stButton > button:hover {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
}
/* Image container */
.uploaded-image {
border-radius: 15px;
overflow: hidden;
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
margin: 2rem 0;
}
/* Probability bars */
.prob-bar {
background: #f0f2f6;
border-radius: 10px;
height: 40px;
margin: 0.5rem 0;
overflow: hidden;
position: relative;
}
.prob-fill {
height: 100%;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
display: flex;
align-items: center;
padding: 0 1rem;
color: white;
font-weight: 600;
transition: width 0.5s ease;
}
.prob-label {
position: absolute;
left: 1rem;
top: 50%;
transform: translateY(-50%);
font-weight: 600;
z-index: 1;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
@st.cache_resource
def _load_learner_cached(path: Path):
    """Load the FastAI learner stored at *path*, memoised by Streamlit.

    Nothing is caught here on purpose: only a successful return value is kept
    in the resource cache, so a failed load is attempted again on the next
    rerun instead of a failure result being remembered for the process lifetime.
    """
    # NOTE(review): load_learner unpickles the file; only point it at trusted
    # model artifacts.
    return load_learner(path)


def load_model():
    """Return the FastAI learner, or None if the model is missing or cannot be loaded.

    Problems are reported to the page with ``st.error`` and turned into a
    ``None`` return value, which the caller is expected to check. Successful
    loads are cached via ``_load_learner_cached``. Failures are not cached, so a
    model file that appears later (e.g. once Git LFS has fetched it) is picked
    up on the next rerun without restarting the app.

    Returns:
        The loaded learner, or ``None`` on failure.
    """
    if not MODEL_PATH.exists():
        st.error(f"❌ Model not found at {MODEL_PATH}")
        return None
    try:
        return _load_learner_cached(MODEL_PATH)
    except Exception as e:
        # UI boundary: surface any load failure (corrupt / incompatible pickle,
        # version mismatch, ...) instead of crashing the page.
        st.error(f"⚠️ Error loading model:\n\n{e}")
        return None


# Keep the cached-function API (``load_model.clear()``) that callers had when
# load_model itself was decorated with @st.cache_resource.
load_model.clear = _load_learner_cached.clear
def predict(learner, img_bytes: bytes):
    """Classify one encoded image (PNG/JPEG bytes) with *learner*.

    Returns:
        A ``(label, probs)`` tuple taken from ``learner.predict``. The class
        index that ``learner.predict`` also returns is dropped.
    """
    image = PILImage.create(io.BytesIO(img_bytes))
    label, _class_idx, class_probs = learner.predict(image)
    return label, class_probs
def _list_example_images(folder: Path) -> list:
    """Return the .jpg/.jpeg/.png files directly inside *folder*, sorted by name.

    Extensions are matched case-insensitively, so ``FOO.JPG`` is found on
    case-sensitive filesystems too. Sorting makes the gallery order (and the
    widget keys derived from it) deterministic. If *folder* is not a
    directory, an empty list is returned.
    """
    if not folder.is_dir():
        return []
    images = [
        p
        for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in {".jpg", ".jpeg", ".png"}
    ]
    return sorted(images, key=lambda p: p.name.lower())


def _render_example_gallery(examples_path: Path) -> None:
    """Show up to five example images, each with a "Use" button.

    When a button is clicked, that image's Path is stored in
    ``st.session_state.example_image``. ``main`` consumes it later in the same run.
    """
    st.markdown("### 🖼️ Try with Example Images")
    if not examples_path.exists():
        st.info("💡 **Tip:** Create an 'examples' folder with sample Pokémon images to display them here!")
        return
    example_images = _list_example_images(examples_path)[:5]  # show max 5 examples
    if not example_images:
        st.info("No example images found in the 'examples' folder.")
        return
    cols = st.columns(len(example_images))
    for idx, (col, img_path) in enumerate(zip(cols, example_images)):
        with col:
            st.image(str(img_path), use_container_width=True, caption=img_path.stem)
            if st.button("Use", key=f"example_{idx}"):
                st.session_state.example_image = img_path


def _read_image_bytes(source) -> bytes:
    """Return the full contents of an example ``Path`` or an uploaded file object.

    Uploads are read with ``getvalue()`` rather than ``read()``. The result
    then does not depend on the stream's current position: a prior ``read()``
    would otherwise leave nothing to read.
    """
    if isinstance(source, Path):
        return source.read_bytes()
    return source.getvalue()


def _top_predictions(vocab, probs, k: int = 5) -> list:
    """Return the ``k`` highest-probability ``(label, prob)`` pairs, best first.

    *probs* is paired element-wise with *vocab*. Each probability is converted
    with ``float`` for the sort key, so the ordering does not rely on tensor
    comparison semantics. Ties keep their original vocab order.
    """
    ranked = sorted(zip(vocab, probs), key=lambda pair: float(pair[1]), reverse=True)
    return ranked[:k]


def _render_prediction(learner, img_bytes: bytes) -> None:
    """Classify *img_bytes* and render the top label, its confidence and the top 5."""
    pred, probs = predict(learner, img_bytes)

    # Main prediction
    st.markdown("### 🎯 Prediction Result")
    st.success(f"✨ **{pred}**")

    max_prob = float(probs.max())
    st.metric(
        "Confidence",
        f"{max_prob*100:.1f}%",
        delta=None
    )

    # Top predictions with visual progress bars
    st.markdown("### 📊 Top 5 Predictions")
    medals = ("🥇", "🥈", "🥉")
    for i, (label, p) in enumerate(_top_predictions(learner.dls.vocab, probs, k=5)):
        prob_percent = float(p) * 100
        icon = medals[i] if i < len(medals) else "⭐"
        st.markdown(f"""
<div style="margin: 1rem 0;">
<div style="display: flex; justify-content: space-between; margin-bottom: 0.3rem;">
<span style="font-weight: 600; color: #1f2937;">{icon} {label}</span>
<span style="font-weight: 600; color: #667eea;">{prob_percent:.1f}%</span>
</div>
</div>
""", unsafe_allow_html=True)
        st.progress(float(p))


def main():
    """Render the classifier page: header, example gallery, uploader and results.

    Stops the script run early (``st.stop``) when the model cannot be loaded.
    An uploaded file takes priority over an example chosen with a "Use" button.
    """
    # Header
    st.title("🎮 Pokémon Gen 9 Classifier")
    st.markdown('<p class="subtitle">Upload a Pokémon image and discover which species it is!</p>', unsafe_allow_html=True)

    learner = load_model()
    if learner is None:
        st.warning(
            "⚠️ Model not loaded. Please ensure the `.pkl` file is correctly placed under `models/` and committed with Git LFS."
        )
        st.stop()

    st.markdown("---")
    # NOTE(review): resolved against the process working directory.
    _render_example_gallery(Path("examples"))
    st.markdown("---")

    col1, col2 = st.columns([1, 1])

    # A "Use" click stores the chosen example for this run only. Remove it
    # right away so the selection does not persist into later reruns.
    uploaded_file = None
    display_image = None
    if "example_image" in st.session_state:
        uploaded_file = st.session_state.example_image
        display_image = str(uploaded_file)
        del st.session_state.example_image

    with col1:
        file_upload = st.file_uploader(
            "Choose a Pokémon image",
            type=["png", "jpg", "jpeg"],
            help="Upload a clear image of a Generation 9 Pokémon",
        )
        # Prioritize file upload over example
        if file_upload is not None:
            uploaded_file = file_upload
            display_image = file_upload
        if display_image is not None:
            # NOTE(review): each st.markdown call is rendered as its own
            # element, so this <div> probably does not wrap the image; verify.
            st.markdown('<div class="uploaded-image">', unsafe_allow_html=True)
            st.image(display_image, use_container_width=True)
            st.markdown('</div>', unsafe_allow_html=True)

    with col2:
        if uploaded_file is not None:
            with st.spinner("🔍 Analyzing image..."):
                try:
                    _render_prediction(learner, _read_image_bytes(uploaded_file))
                except Exception as e:
                    # UI boundary: report the failure instead of crashing the page.
                    st.error(f"❌ Error during prediction: {e}")
        else:
            # Placeholder when no image is uploaded
            st.info("👆 Upload an image to get started!")
            st.markdown("""
### How to use:
1. 📤 Upload a Pokémon image (PNG, JPG, or JPEG)
2. ⏳ Wait for the AI to analyze it
3. 🎉 See the prediction and confidence scores!
**Tip:** Use clear, well-lit images for best results!
""")


if __name__ == "__main__":
    main()