# bird-detector / app.py
# jijinAI — "Update app.py" (commit f0f2319, verified)
import gradio as gr
import requests
from PIL import Image
from transformers import pipeline
import io
# Markdown table of bird species (English name / Māori name). It is rendered
# verbatim by a gr.Markdown component inside the "Show Supported Bird Species"
# accordion. Rows with an empty second column have no Māori name listed here.
bird_list_md = """
*Note: This is a demonstration.*
| English Name | Māori Name |
| :--- | :--- |
| Antipodean albatross | Toroa |
| Auckland Island shag | Kawau o Motu Maha |
| Auckland Island teal | Tete kakariki |
| Australasian bittern | Matuku-hurepo |
| Australasian crested grebe | Pūteketeke |
| Black-fronted tern | Tarapirohe |
| Black noddy | |
| Black petrel | Taiko |
| Black robin | Karure |
| Black stilt | Kaki |
| Blue duck | Whio |
| Caspian tern | Taranui |
| Chatham Island oystercatcher | Torea tai |
| Chatham Island pigeon | Parea |
| Chatham Island shag | Papua |
| Chatham Island snipe | |
| Chatham Island taiko | Taiko |
| Chatham petrel | Ranguru |
| Eastern rockhopper penguin | Tawaki piki toka |
| Fairy tern | Tara iti |
| Forbes' parakeet | |
| Foveaux shag | Mapo |
| Great spotted kiwi | Roroa |
| Grey-headed mollymawk | Toroa |
| Grey duck | Parera |
| Hutton's shearwater | Kaikoura titi |
| Kakapo | |
| Kea | |
| Kermadec storm petrel | |
| Long-tailed cuckoo | Koekoea |
| Masked booby | |
| New Zealand king shag | Kawau pateketeke |
| New Zealand storm petrel | Takahikare-raro |
| Northern royal albatross | Toroa |
| Okarito brown kiwi | Rowi |
| Orange-fronted parakeet | Kakariki karaka |
| Pitt Island shag | Kawau o Rangihaute |
| Reef heron | Matuku moana |
| Rock wren | Piwauwau |
| Salvin's mollymawk | Toroa |
| Shore plover | Tuturuatu |
| South Island takahe | Takahe |
| Southern royal albatross | Toroa |
| Spotted shag | Kawau tikitiki |
| Stitchbird | Hihi |
| Subantarctic skua | Hakoakoa |
| Whenua Hou diving petrel | Kuaka Whenua Hou |
| White-bellied storm petrel | |
| White heron | Kotuku |
| White tern | |
| Yellow-eyed penguin | Hoiho |
"""
# Build the Hugging Face image-classification pipeline a single time, at
# import, so that every classification request reuses the same loaded model
# instead of paying the load cost per call.
image_classifier = pipeline(
    task="image-classification",
    model="jijinAI/bird-detection",
)
def _load_image_from_url(image_url):
    """Download ``image_url`` and decode the response body into a PIL image.

    Args:
        image_url: Address of the image to fetch (already stripped of
            surrounding whitespace by the caller).

    Returns:
        A fully decoded ``PIL.Image.Image``.

    Raises:
        gr.Error: If the HTTP request fails, or if the response body cannot
            be decoded as an image.
    """
    try:
        response = requests.get(image_url, timeout=5)
        response.raise_for_status()  # Turn 4xx/5xx responses into exceptions.
        image = Image.open(io.BytesIO(response.content))
        # Image.open() only reads the header; force the pixel data to be
        # decoded here so truncated/corrupt downloads are reported as a
        # friendly gr.Error rather than failing later inside the classifier.
        image.load()
    except requests.exceptions.RequestException as e:
        # Must come before the OSError handler: RequestException derives
        # from IOError, and network failures deserve their own message.
        raise gr.Error(
            f"Could not retrieve image from URL. Please check the link. Error: {e}"
        ) from e
    except (IOError, Image.DecompressionBombError) as e:
        # DecompressionBombError is not an OSError, so it is listed explicitly.
        raise gr.Error("The URL did not point to a valid image file.") from e
    return image


def classify_image(image_file, image_url):
    """Classify an image supplied either as an upload or as a URL.

    The uploaded image takes priority; the URL is only used when no image
    was uploaded.

    Args:
        image_file: The uploaded image (a PIL image, as delivered by the
            ``gr.Image(type="pil")`` input), or ``None``.
        image_url: URL of an image to download, or an empty string / ``None``.

    Returns:
        A ``(image, confidences)`` tuple, where ``confidences`` maps each
        predicted label to its score, as expected by ``gr.Label``.

    Raises:
        gr.Error: If neither input is provided, the URL cannot be fetched,
            or the downloaded data is not a valid image.
    """
    # --- 1. Input Handling ---
    # Pasted URLs often carry stray whitespace; a whitespace-only value is
    # treated the same as "no URL given".
    url = image_url.strip() if image_url else ""

    if image_file is not None:
        image = image_file
    elif url:
        image = _load_image_from_url(url)
    else:
        raise gr.Error("Please upload an image or provide a URL.")

    # --- 2. Classification ---
    # Uses the module-level pipeline so the model is not reloaded per call.
    results = image_classifier(image)

    # --- 3. Format Output ---
    # Flatten the model's list of {'label', 'score'} dicts into a single
    # {label: score} mapping for the Gradio Label component.
    confidences = {item['label']: item['score'] for item in results}

    return image, confidences
# --- 4. Build the Gradio UI with Blocks so the layout can be customised ---
# NOTE(review): the nesting below is reconstructed from the section comments
# ("Input Column" / "Output Column") — confirm the accordion and the button
# belong inside the input column.
EXAMPLE_IMAGE_URL = "https://www.nzbirdsonline.org.nz/assets/95597/1691022013-adult-20long-tailed-20cuckoo-20calling-20while-20perched-20within-20beech-20tree.jpg?auto=format&fit=crop&w=1200"

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🖼️ Image Classifier")
    gr.Markdown("Upload an image from your computer or paste a URL to classify it.")

    with gr.Row():
        # Input column: one tab per input method, the species list, and
        # the submit button.
        with gr.Column(scale=1):
            with gr.Tab("Upload Image"):
                input_image_file = gr.Image(
                    type="pil",
                    label="Upload Image File",
                    height=300,
                    width=300,
                )
            with gr.Tab("Image URL"):
                input_image_url = gr.Textbox(label="Enter Image URL")
            with gr.Accordion("Show Supported Bird Species", open=False):
                gr.Markdown(bird_list_md)
            submit_btn = gr.Button("Classify Image", variant="primary")

        # Output column: the image that was classified plus its top labels.
        with gr.Column(scale=2):
            output_image = gr.Image(
                label="Processed Image",
                height=300,
                width=300,
            )
            output_label = gr.Label(num_top_classes=5, label="Top 5 Labels")

    # Run the classifier when the button is pressed.
    submit_btn.click(
        fn=classify_image,
        inputs=[input_image_file, input_image_url],
        outputs=[output_image, output_label],
    )

    # Ready-made example (no upload, URL only) for users to try.
    gr.Examples(
        examples=[[None, EXAMPLE_IMAGE_URL]],
        inputs=[input_image_file, input_image_url],
        outputs=[output_image, output_label],
        fn=classify_image,
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()