Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from fastai.vision.all import * | |
| from pathlib import Path | |
| import numpy as np | |
def load_model(model_path='bears_model_clean.pkl'):
    """Load the exported FastAI learner used for inference.

    Args:
        model_path: str or Path of the exported ``.pkl`` learner. Defaults to
            ``bears_model_clean.pkl`` relative to the current working directory.

    Returns:
        The loaded learner, or None if the file does not exist or loading
        failed. The reason is printed in either case.
    """
    path = Path(model_path)
    # Report a missing file explicitly instead of relying on load_learner's
    # exception text.
    if not path.is_file():
        print(f"Error loading model: file not found: {path}")
        return None
    try:
        # load_learner unpickles the file, and unpickling can execute arbitrary
        # code. Only point this at a model file you exported yourself or trust.
        return load_learner(path)
    except Exception as e:
        # Best-effort: the app still starts and reports the missing model.
        print(f"Error loading model: {e}")
        return None


# Load once at import time so every request reuses the same learner.
# None signals that loading failed.
learn = load_model()
def classify_bear(image):
    """Predict bear-type probabilities for a single image.

    Uses the module-level ``learn`` learner.

    Args:
        image: PIL Image or numpy array to classify.

    Returns:
        dict: ``{class_name: probability}`` for every class in the learner's
        vocabulary, in vocabulary order. On failure a single-entry status dict
        is returned instead:
        - ``{"Error": message}`` if the model is not loaded or prediction raised.
        - ``{"No Image": message}`` if ``image`` is None.
    """
    if learn is None:
        return {"Error": "Model not loaded properly"}
    if image is None:
        return {"No Image": "Please upload an image"}
    try:
        # predict() yields (decoded label, label index, per-class probabilities);
        # only the probabilities are needed here.
        _, _, probs = learn.predict(image)
        # probs is indexed in the same order as the vocabulary.
        return {
            class_name: float(probs[i])
            for i, class_name in enumerate(learn.dls.vocab)
        }
    except Exception as e:
        # UI boundary: report the failure as text rather than raising.
        return {"Error": f"Prediction failed: {e}"}
# Markdown blurb per species, keyed by a lowercase substring of the class name.
# Checked in insertion order; the first key contained in the class name wins.
_BEAR_INFO = {
    "black": "π» **Black Bear**: The most common bear in North America. They're excellent climbers and swimmers, with a varied omnivorous diet.",
    "grizzly": "π» **Grizzly Bear**: A powerful subspecies of brown bear found in North America. Known for their distinctive shoulder hump and long claws.",
    "polar": "π»ββοΈ **Polar Bear**: The largest bear species, perfectly adapted to Arctic life. They're excellent swimmers and primarily hunt seals.",
    "panda": "πΌ **Giant Panda**: A beloved bear species native to China, famous for their black and white coloring and bamboo diet.",
    "teddy": "π§Έ **Teddy Bear**: A stuffed toy bear! Named after President Theodore Roosevelt, these cuddly companions have been beloved by children for over a century.",
}


def get_bear_info(prediction_dict):
    """Describe the most likely bear type as Markdown.

    Args:
        prediction_dict: ``{class_name: confidence}`` mapping. It may instead
            be a status dict with an ``"Error"`` or ``"No Image"`` key.

    Returns:
        str: One of the following:
        - The error message, for an ``"Error"`` status dict.
        - An upload prompt, for a ``"No Image"`` status dict or an empty dict.
        - Otherwise, the blurb for the highest-confidence class followed by
          its confidence as a percentage.
    """
    if "Error" in prediction_dict:
        return prediction_dict["Error"]
    # An empty dict has no top class; max() would raise ValueError.
    if "No Image" in prediction_dict or not prediction_dict:
        return "Upload an image to learn about the bear species!"

    bear_type, confidence = max(prediction_dict.items(), key=lambda item: item[1])

    # Case-insensitive substring match, so e.g. "Grizzly Bear" finds "grizzly".
    lowered = bear_type.lower()
    info = next(
        (text for key, text in _BEAR_INFO.items() if key in lowered),
        f"π» **{bear_type}**: A type of bear!",
    )
    return f"{info}\n\n**Confidence**: {confidence:.1%}"
def predict_and_explain(image):
    """Classify an image and build the matching species description.

    Args:
        image: Input image (PIL Image / numpy array, or None).

    Returns:
        tuple: ``(prediction_dict, explanation_text)``.
        - ``prediction_dict`` maps class names to float confidences for the
          label component. It is ``{}`` when classification returned a status
          message instead of scores.
        - ``explanation_text`` is Markdown, or the status message.
    """
    predictions = classify_bear(image)
    # Build the explanation from the raw result so status messages
    # ({"Error": ...} / {"No Image": ...}) are still surfaced as text.
    explanation = get_bear_info(predictions)
    # gr.Label needs numeric confidences; a dict whose values are message
    # strings breaks it. Clear the label and let the explanation carry the
    # message instead.
    if not all(isinstance(value, (int, float)) for value in predictions.values()):
        predictions = {}
    return predictions, explanation
def handle_image_change(image):
    """Refresh both outputs when the image component's value changes.

    A cleared component arrives as ``image=None``. That case short-circuits
    to an empty prediction and the upload prompt without touching the model.

    Args:
        image: The component's new value, or None after it was cleared.

    Returns:
        tuple: (prediction_dict, explanation_text)
    """
    if image is not None:
        return predict_and_explain(image)
    return {}, "Upload an image to learn about the bear species!"
def get_sample_images():
    """Collect the example images that are actually present on disk.

    Checks ``samples/<species>.jpg`` for black, grizzly, polar, panda and
    teddy, in that order. A found / not-found line is printed for each file.

    Returns:
        list: One single-element list ``[path]`` per existing file (one row
        per example), in the order checked.
    """
    found = []
    for species in ("black", "grizzly", "polar", "panda", "teddy"):
        candidate = f"samples/{species}.jpg"
        if not Path(candidate).exists():
            print(f"β οΈ Sample image not found: {candidate}")
            continue
        print(f"β Found sample image: {candidate}")
        found.append([candidate])
    return found
# Stylesheet for the Blocks app built by create_interface().
_CSS = """
.gradio-container {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.bear-title {
    text-align: center;
    color: #8B4513;
    font-size: 2.5em;
    margin-bottom: 20px;
}
.profile-links {
    text-align: center;
    margin: 20px 0;
    padding: 15px;
    background: linear-gradient(135deg, #FF6347 0%, #FFA500 100%);
    border-radius: 10px;
    color: white;
}
.profile-links a {
    color: #fff;
    text-decoration: none;
    margin: 0 15px;
    padding: 8px 16px;
    background: rgba(255,255,255,0.2);
    border-radius: 20px;
    transition: all 0.3s ease;
    display: inline-block;
}
.profile-links a:hover {
    background: rgba(255,255,255,0.3);
    transform: translateY(-2px);
}
.project-info {
    text-align: center;
    margin: 20px 0;
    padding: 15px;
    background: #f8f9fa;
    border-radius: 10px;
    border-left: 4px solid #8B4513;
}
"""

# Title banner, profile links and usage hint rendered above the app.
_HEADER_HTML = """
<div class="bear-title">
    π» Bear Species Detector πΌ
</div>
<div class="profile-links">
    <strong>π Explore More:</strong><br>
    <a href="https://www.kaggle.com/williamwillj" target="_blank">π Kaggle Profile</a>
    <a href="https://huggingface.co/williamj949" target="_blank">π€ HuggingFace Profile</a>
    <a href="https://www.kaggle.com/code/williamwillj/bear-detector/notebook" target="_blank">π Training Notebook</a>
</div>
<div class="project-info">
    <p style="font-size: 1.2em; color: #666; margin-bottom: 10px;">
        Upload an image of a bear and I'll tell you what species it is!<br>
        <em>Supports: Black Bear, Grizzly Bear, Polar Bear, Giant Panda, and even Teddy Bears! π§Έ</em>
    </p>
</div>
"""

# Placeholder shown instead of gr.Examples when no sample images exist.
_NO_SAMPLES_HTML = """
<p style="text-align: center; color: #888; font-style: italic;">
    π‘ Add sample images to the 'samples/' folder to see examples here!
</p>
"""

# Project summary and credits rendered below the app.
_FOOTER_HTML = """
<div style="text-align: center; margin-top: 30px; padding: 20px; background: #f8f9fa; border-radius: 10px;">
    <h3 style="color: #8B4513; margin-bottom: 15px;">π Project Highlights</h3>
    <div style="display: flex; justify-content: center; gap: 30px; flex-wrap: wrap; margin-bottom: 15px;">
        <div>
            <strong>π― Accuracy:</strong> 96%+ on the data set
        </div>
        <div>
            <strong>π§ Tech Stack:</strong> FastAI + Gradio
        </div>
    </div>
    <p style="color: #666; font-size: 0.9em;">
        Check out my
        <a href="https://www.kaggle.com/williamwillj" target="_blank" style="color: #8B4513;">Kaggle profile</a>
        and
        <a href="https://huggingface.co/williamj949" target="_blank" style="color: #8B4513;">Hugging Face profile</a>
        for more ML projects!
    </p>
    <p style="color: #888; margin-top: 15px;">
        Built with β€οΈ using FastAI and Gradio
    </p>
</div>
"""


def create_interface():
    """Build the Gradio Blocks UI for the bear classifier.

    Layout, top to bottom:
    - A header.
    - A two-column row. Left: image upload, detect button, and the examples
      (or a hint when none exist). Right: a confidence label and a Markdown
      species description.
    - A footer.

    Both the detect button and any change of the image run
    ``handle_image_change``. An empty or cleared image therefore resets the
    outputs instead of being sent for classification.

    Returns:
        gr.Blocks: The assembled app (not launched).
    """
    with gr.Blocks(css=_CSS, title="π» Bear Species Detector") as demo:
        gr.HTML(_HEADER_HTML)

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    label="Upload Bear Image πΈ",
                    type="pil",
                    height=400,
                )
                submit_btn = gr.Button(
                    "Detect Bear Type! π",
                    variant="primary",
                    size="lg",
                )
                # Only offer examples for sample files that actually exist.
                sample_images = get_sample_images()
                if sample_images:
                    gr.Examples(
                        examples=sample_images,
                        inputs=image_input,
                        label="Try these examples:",
                    )
                else:
                    gr.HTML(_NO_SAMPLES_HTML)

            with gr.Column():
                prediction_output = gr.Label(
                    label="Prediction Confidence π",
                    num_top_classes=5,
                )
                info_output = gr.Markdown(
                    label="Bear Information π",
                    value="Upload an image to learn about the bear species!",
                )

        outputs = [prediction_output, info_output]
        # Send the button through the same None-aware handler as the change
        # event. Clicking with no image would otherwise send gr.Label a
        # {"No Image": <message>} dict whose value is not a confidence score.
        submit_btn.click(fn=handle_image_change, inputs=image_input, outputs=outputs)
        # Re-run prediction on upload; clearing the image resets the outputs.
        image_input.change(fn=handle_image_change, inputs=image_input, outputs=outputs)

        gr.HTML(_FOOTER_HTML)
    return demo
# Main execution
if __name__ == "__main__":
    # Report model status up front. The UI launches either way; if the model
    # is missing, prediction requests answer "Model not loaded properly".
    if learn is None:
        print("β Error: Could not load the model. Please ensure 'bears_model_clean.pkl' is in the correct path.")
        print("π‘ Tip: Update the model_path in the load_model() function to point to your saved model.")
    else:
        print("β Model loaded successfully!")
        print(f"π Classes: {learn.dls.vocab}")

    demo = create_interface()
    demo.launch(
        share=True,  # Request a temporary public share link (not supported on Hugging Face Spaces)
        server_name="0.0.0.0",  # Listen on all network interfaces
        server_port=7860,  # Gradio's default port
        show_error=True,  # Show handler errors in the browser
    )