import gradio as gr from fastai.vision.all import * from pathlib import Path import numpy as np def load_model(): """Load the exported FastAI model""" try: model_path = Path('bears_model_clean.pkl') learn = load_learner(model_path) return learn except Exception as e: print(f"Error loading model: {e}") return None learn = load_model() def classify_bear(image): """ Detect bear species from uploaded image Args: image: PIL Image or numpy array Returns: dict: Prediction probabilities for each bear type """ if learn is None: return {"Error": "Model not loaded properly"} if image is None: return {"No Image": "Please upload an image"} try: # Make prediction pred, pred_idx, probs = learn.predict(image) # Get class names class_names = learn.dls.vocab # Create confidence dictionary confidences = {} for i, class_name in enumerate(class_names): confidences[class_name] = float(probs[i]) return confidences except Exception as e: return {"Error": f"Prediction failed: {str(e)}"} def get_bear_info(prediction_dict): """ Get information about the predicted bear type Args: prediction_dict: Dictionary with prediction confidences Returns: str: Information about the most likely bear type """ if "Error" in prediction_dict: return prediction_dict["Error"] if "No Image" in prediction_dict: return "Upload an image to learn about the bear species!" # Get the bear type with highest confidence top_prediction = max(prediction_dict.items(), key=lambda x: x[1]) bear_type = top_prediction[0] confidence = top_prediction[1] # Bear information dictionary bear_info = { "black": "🐻 **Black Bear**: The most common bear in North America. They're excellent climbers and swimmers, with a varied omnivorous diet.", "grizzly": "🐻 **Grizzly Bear**: A powerful subspecies of brown bear found in North America. Known for their distinctive shoulder hump and long claws.", "polar": "🐻❄️ **Polar Bear**: The largest bear species, perfectly adapted to Arctic life. They're excellent swimmers and primarily hunt seals.", "panda": "🐼 **Giant Panda**: A beloved bear species native to China, famous for their black and white coloring and bamboo diet.", "teddy": "🧸 **Teddy Bear**: A stuffed toy bear! Named after President Theodore Roosevelt, these cuddly companions have been beloved by children for over a century." } # Find matching bear info (case insensitive) info = "" for key, value in bear_info.items(): if key.lower() in bear_type.lower(): info = value break if not info: info = f"🐻 **{bear_type}**: A type of bear!" return f"{info}\n\n**Confidence**: {confidence:.1%}" def predict_and_explain(image): """ Main function that combines prediction and explanation Args: image: Input image Returns: tuple: (prediction_dict, explanation_text) """ predictions = classify_bear(image) explanation = get_bear_info(predictions) return predictions, explanation def handle_image_change(image): """ Handle image change events with proper None checking Args: image: Input image (can be None when cleared) Returns: tuple: (prediction_dict, explanation_text) """ if image is None: return {}, "Upload an image to learn about the bear species!" return predict_and_explain(image) def get_sample_images(): """ Get list of sample images if they exist Returns: list: List of image paths for examples """ sample_paths = [ "samples/black.jpg", "samples/grizzly.jpg", "samples/polar.jpg", "samples/panda.jpg", "samples/teddy.jpg" ] existing_samples = [] for path in sample_paths: if Path(path).exists(): existing_samples.append([path]) print(f"✅ Found sample image: {path}") else: print(f"⚠️ Sample image not found: {path}") return existing_samples def create_interface(): """Create and configure the Gradio interface""" css = """ .gradio-container { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; } .bear-title { text-align: center; color: #8B4513; font-size: 2.5em; margin-bottom: 20px; } .profile-links { text-align: center; margin: 20px 0; padding: 15px; background: linear-gradient(135deg, #FF6347 0%, #FFA500 100%); border-radius: 10px; color: white; } .profile-links a { color: #fff; text-decoration: none; margin: 0 15px; padding: 8px 16px; background: rgba(255,255,255,0.2); border-radius: 20px; transition: all 0.3s ease; display: inline-block; } .profile-links a:hover { background: rgba(255,255,255,0.3); transform: translateY(-2px); } .project-info { text-align: center; margin: 20px 0; padding: 15px; background: #f8f9fa; border-radius: 10px; border-left: 4px solid #8B4513; } """ with gr.Blocks(css=css, title="🐻 Bear Species Detector") as demo: gr.HTML("""
Upload an image of a bear and I'll tell you what species it is!
Supports: Black Bear, Grizzly Bear, Polar Bear, Giant Panda, and even Teddy Bears! 🧸
💡 Add sample images to the 'samples/' folder to see examples here!
""") with gr.Column(): # Prediction output prediction_output = gr.Label( label="Prediction Confidence 📊", num_top_classes=5 ) # Bear information output info_output = gr.Markdown( label="Bear Information 📖", value="Upload an image to learn about the bear species!" ) # Connect the interface submit_btn.click( fn=predict_and_explain, inputs=image_input, outputs=[prediction_output, info_output] ) # Also trigger on image upload image_input.change( fn=handle_image_change, inputs=image_input, outputs=[prediction_output, info_output] ) gr.HTML("""Check out my Kaggle profile and Hugging Face profile for more ML projects!
Built with ❤️ using FastAI and Gradio