# bear-detector / app.py
# Author: williamj949
# Last commit: 8fbace9 ("update footer")
import gradio as gr
from fastai.vision.all import *
from pathlib import Path
import numpy as np
def load_model(model_path='bears_model_clean.pkl'):
    """Load the exported FastAI model.

    Args:
        model_path: Path (str or Path) to the exported learner ``.pkl``.
            Relative paths resolve against the current working directory.
            Defaults to ``'bears_model_clean.pkl'``.

    Returns:
        The loaded learner, or ``None`` when the file is missing or loading
        fails (the reason is printed so the app can still start).
    """
    path = Path(model_path)
    # Check up front so the most common failure (file not there) gets a clear
    # message instead of whatever exception the loader happens to raise.
    if not path.exists():
        print(f"Error loading model: file not found: {path}")
        return None
    try:
        # NOTE: load_learner unpickles the file -- only load trusted model files.
        return load_learner(path)
    except Exception as e:
        # Deliberate best-effort: callers check for None and degrade gracefully.
        print(f"Error loading model: {e}")
        return None
# Load the model once at import time; classify_bear and the __main__ block
# read this global. It is None when load_model() failed (the error is printed).
learn = load_model()
def classify_bear(image):
    """
    Run the loaded learner on an image and report per-class confidences.

    Args:
        image: PIL Image or numpy array to classify.

    Returns:
        dict: Every class name in ``learn.dls.vocab`` mapped to its probability
        as a float. If the model is missing, no image is given, or prediction
        raises, a single-entry status dict is returned instead (key ``"Error"``
        or ``"No Image"``, value a message string).
    """
    # Guard clauses: nothing to classify without a model or an input.
    if learn is None:
        return {"Error": "Model not loaded properly"}
    if image is None:
        return {"No Image": "Please upload an image"}

    try:
        # Only the third element (per-class probabilities, indexed by class
        # position) is needed; the decoded label and its index are ignored.
        _, _, probabilities = learn.predict(image)
        vocabulary = learn.dls.vocab
        return {
            label: float(probabilities[position])
            for position, label in enumerate(vocabulary)
        }
    except Exception as exc:
        return {"Error": f"Prediction failed: {str(exc)}"}
def get_bear_info(prediction_dict):
    """
    Get information about the predicted bear type.

    Args:
        prediction_dict: Dictionary from ``classify_bear`` -- either class
            name -> confidence (float), or a single ``"Error"`` / ``"No Image"``
            status entry. ``None`` or ``{}`` are treated as "no image".

    Returns:
        str: Markdown describing the most likely bear type and its confidence,
        the error message for an ``"Error"`` entry, or an upload prompt when
        there is no image / no predictions.
    """
    # An empty (or None) dict has no top prediction -- max() would raise.
    if not prediction_dict:
        return "Upload an image to learn about the bear species!"
    if "Error" in prediction_dict:
        return prediction_dict["Error"]
    if "No Image" in prediction_dict:
        return "Upload an image to learn about the bear species!"

    # Get the bear type with highest confidence
    bear_type, confidence = max(prediction_dict.items(), key=lambda item: item[1])

    # Bear information dictionary; keys are lowercase fragments matched
    # against the predicted label.
    bear_info = {
        "black": "🐻 **Black Bear**: The most common bear in North America. They're excellent climbers and swimmers, with a varied omnivorous diet.",
        "grizzly": "🐻 **Grizzly Bear**: A powerful subspecies of brown bear found in North America. Known for their distinctive shoulder hump and long claws.",
        # Polar-bear emoji is a ZWJ sequence (bear + ZWJ + snowflake + VS16);
        # written with escapes so the invisible joiner cannot be lost again.
        "polar": "\U0001F43B\u200d\u2744\ufe0f **Polar Bear**: The largest bear species, perfectly adapted to Arctic life. They're excellent swimmers and primarily hunt seals.",
        "panda": "🐼 **Giant Panda**: A beloved bear species native to China, famous for their black and white coloring and bamboo diet.",
        "teddy": "🧸 **Teddy Bear**: A stuffed toy bear! Named after President Theodore Roosevelt, these cuddly companions have been beloved by children for over a century.",
    }

    # First entry whose key occurs in the label (case insensitive); unknown
    # labels get a generic description.
    label = bear_type.lower()
    info = next(
        (text for key, text in bear_info.items() if key in label),
        f"🐻 **{bear_type}**: A type of bear!",
    )
    return f"{info}\n\n**Confidence**: {confidence:.1%}"
def predict_and_explain(image):
    """
    Classify an image and build the matching explanation text.

    Args:
        image: Input image, passed straight to ``classify_bear``.

    Returns:
        tuple: ``(label_value, explanation_text)``. ``label_value`` is the
        class -> confidence dict when prediction succeeded, or ``{}`` when
        ``classify_bear`` returned an ``"Error"`` / ``"No Image"`` status
        dict; that status message is surfaced via ``explanation_text``.
    """
    predictions = classify_bear(image)
    explanation = get_bear_info(predictions)
    # The status dicts map to message strings, but gr.Label treats dict values
    # as numeric confidences, so only hand it all-numeric results.
    is_numeric = all(isinstance(value, (int, float)) for value in predictions.values())
    return (predictions if is_numeric else {}), explanation
def handle_image_change(image):
    """
    React to the image input changing, including being cleared.

    Args:
        image: The new image value; ``None`` when the input was cleared.

    Returns:
        tuple: ``({}, prompt)`` when there is no image, otherwise whatever
        ``predict_and_explain`` returns for the image.
    """
    if image is not None:
        return predict_and_explain(image)
    # Cleared input: blank the label and restore the initial prompt.
    return {}, "Upload an image to learn about the bear species!"
def get_sample_images():
    """
    Get the list of bundled sample images that actually exist.

    Paths are relative to the current working directory. Each candidate's
    presence or absence is printed.

    Returns:
        list: One single-element list ``[path]`` per existing image, the
        row format ``gr.Examples`` takes for a single input component.
    """
    sample_paths = [
        "samples/black.jpg",
        "samples/grizzly.jpg",
        "samples/polar.jpg",
        "samples/panda.jpg",
        "samples/teddy.jpg",
    ]
    existing_samples = []
    for path in sample_paths:
        if Path(path).exists():
            existing_samples.append([path])
            print(f"✅ Found sample image: {path}")
        else:
            print(f"⚠️ Sample image not found: {path}")
    return existing_samples
def create_interface():
    """Create and configure the Gradio interface.

    Layout:
    - a header with profile links;
    - a two-column row: the image input, submit button and optional examples
      on the left, and the confidence label and bear description on the right;
    - a project-highlights footer.

    Returns:
        gr.Blocks: The configured (not yet launched) demo.
    """
    css = """
    .gradio-container {
        font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    }
    .bear-title {
        text-align: center;
        color: #8B4513;
        font-size: 2.5em;
        margin-bottom: 20px;
    }
    .profile-links {
        text-align: center;
        margin: 20px 0;
        padding: 15px;
        background: linear-gradient(135deg, #FF6347 0%, #FFA500 100%);
        border-radius: 10px;
        color: white;
    }
    .profile-links a {
        color: #fff;
        text-decoration: none;
        margin: 0 15px;
        padding: 8px 16px;
        background: rgba(255,255,255,0.2);
        border-radius: 20px;
        transition: all 0.3s ease;
        display: inline-block;
    }
    .profile-links a:hover {
        background: rgba(255,255,255,0.3);
        transform: translateY(-2px);
    }
    .project-info {
        text-align: center;
        margin: 20px 0;
        padding: 15px;
        background: #f8f9fa;
        border-radius: 10px;
        border-left: 4px solid #8B4513;
    }
    """
    with gr.Blocks(css=css, title="🐻 Bear Species Detector") as demo:
        gr.HTML("""
        <div class="bear-title">
            🐻 Bear Species Detector 🐼
        </div>
        <div class="profile-links">
            <strong>🔗 Explore More:</strong><br>
            <a href="https://www.kaggle.com/williamwillj" target="_blank">📊 Kaggle Profile</a>
            <a href="https://huggingface.co/williamj949" target="_blank">🤗 HuggingFace Profile</a>
            <a href="https://www.kaggle.com/code/williamwillj/bear-detector/notebook" target="_blank">📋 Training Notebook</a>
        </div>
        <div class="project-info">
            <p style="font-size: 1.2em; color: #666; margin-bottom: 10px;">
                Upload an image of a bear and I'll tell you what species it is!<br>
                <em>Supports: Black Bear, Grizzly Bear, Polar Bear, Giant Panda, and even Teddy Bears! 🧸</em>
            </p>
        </div>
        """)
        with gr.Row():
            with gr.Column():
                # Image input
                image_input = gr.Image(
                    label="Upload Bear Image 📸",
                    type="pil",
                    height=400,
                )
                # Submit button
                submit_btn = gr.Button(
                    "Detect Bear Type! 🔍",
                    variant="primary",
                    size="lg",
                )
                # Only show examples if we have sample images
                sample_images = get_sample_images()
                if sample_images:
                    gr.Examples(
                        examples=sample_images,
                        inputs=image_input,
                        label="Try these examples:",
                    )
                else:
                    gr.HTML("""
                    <p style="text-align: center; color: #888; font-style: italic;">
                        💡 Add sample images to the 'samples/' folder to see examples here!
                    </p>
                    """)
            with gr.Column():
                # Prediction output
                prediction_output = gr.Label(
                    label="Prediction Confidence 📊",
                    num_top_classes=5,
                )
                # Bear information output
                info_output = gr.Markdown(
                    label="Bear Information 📖",
                    value="Upload an image to learn about the bear species!",
                )

        # Connect the interface. The button goes through handle_image_change,
        # which handles a missing image itself; calling predict_and_explain
        # directly would push classify_bear's {"No Image": "<text>"} status
        # dict into the Label.
        submit_btn.click(
            fn=handle_image_change,
            inputs=image_input,
            outputs=[prediction_output, info_output],
        )
        # Also trigger on image upload (and reset when the image is cleared)
        image_input.change(
            fn=handle_image_change,
            inputs=image_input,
            outputs=[prediction_output, info_output],
        )

        gr.HTML("""
        <div style="text-align: center; margin-top: 30px; padding: 20px; background: #f8f9fa; border-radius: 10px;">
            <h3 style="color: #8B4513; margin-bottom: 15px;">🚀 Project Highlights</h3>
            <div style="display: flex; justify-content: center; gap: 30px; flex-wrap: wrap; margin-bottom: 15px;">
                <div>
                    <strong>🎯 Accuracy:</strong> 96%+ on the data set
                </div>
                <div>
                    <strong>🔧 Tech Stack:</strong> FastAI + Gradio
                </div>
            </div>
            <p style="color: #666; font-size: 0.9em;">
                Check out my
                <a href="https://www.kaggle.com/williamwillj" target="_blank" style="color: #8B4513;">Kaggle profile</a>
                and
                <a href="https://huggingface.co/williamj949" target="_blank" style="color: #8B4513;">Hugging Face profile</a>
                for more ML projects!
            </p>
            <p style="color: #888; margin-top: 15px;">
                Built with ❤️ using FastAI and Gradio
            </p>
        </div>
        """)
    return demo
# Main execution
if __name__ == "__main__":
    # Report model status up front. The UI launches either way; when `learn`
    # is None, classify_bear answers with "Model not loaded properly".
    if learn is None:
        print("❌ Error: Could not load the model. Please ensure 'bears_model_clean.pkl' is in the correct path.")
        print("💡 Tip: Update the model_path in the load_model() function to point to your saved model.")
    else:
        print("✅ Model loaded successfully!")
        print(f"📋 Classes: {learn.dls.vocab}")
    demo = create_interface()
    demo.launch(
        share=True,  # Also request a temporary public share link
        server_name="0.0.0.0",  # Listen on all network interfaces
        server_port=7860,  # Default Gradio port
        show_error=True,  # Surface handler errors in the UI
    )