# Hugging Face Space app by sandbox338 — commit 3bb0452 ("Update app.py").
import io
import os
from collections import Counter

import gradio as gr
import requests
from PIL import Image, ImageDraw
from transformers import pipeline
# Hugging Face Hub repository ID of the wildlife detection model.
MODEL_ID: str = "sandbox338/wild-life-model"
# Function to detect wildlife in images
def detect_wildlife(image):
# Convert image to bytes for API call
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format='JPEG')
img_byte_arr = img_byte_arr.getvalue()
# API endpoint for your model
API_URL = f"https://huggingface.co/sandbox338/wild-life-model"
headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
# Make prediction request
response = requests.post(API_URL, headers=headers, data=img_byte_arr)
if response.status_code != 200:
return None, f"Error: {response.status_code} - {response.text}"
# Parse the response
result = response.json()
# Draw predictions on image
image_with_boxes = draw_predictions(image, result)
# Create text summary of results
summary = create_summary(result)
return image_with_boxes, summary
# Function to draw bounding boxes on image
def draw_predictions(image, predictions):
img = image.copy()
draw = ImageDraw.Draw(img)
colors = {
"animal": (255, 0, 0), # Red
"bird": (0, 255, 0), # Green
"mammal": (0, 0, 255), # Blue
"reptile": (255, 255, 0) # Yellow
}
for pred in predictions:
box = pred.get("box", {})
label = pred.get("label", "unknown")
score = pred.get("score", 0)
# Get color based on label (default to purple if not in colors dict)
color = colors.get(label.lower(), (255, 0, 255))
# Draw rectangle
xmin, ymin = box.get("xmin", 0), box.get("ymin", 0)
xmax, ymax = box.get("xmax", 0), box.get("ymax", 0)
draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=3)
# Draw label
draw.text((xmin, ymin-15), f"{label}: {score:.2f}", fill=color)
return img
# Function to create text summary of detections
def create_summary(predictions):
if not predictions:
return "No wildlife detected in this image."
summary = f"Detected {len(predictions)} wildlife instances:\n\n"
# Count occurrences of each species
species_count = {}
for pred in predictions:
label = pred.get("label", "unknown")
score = pred.get("score", 0)
if label not in species_count:
species_count[label] = 0
species_count[label] += 1
# Add to summary
for species, count in species_count.items():
summary += f"- {species}: {count} {'instance' if count == 1 else 'instances'}\n"
return summary
## inputs
examples = [
["https://www.google.com/url?sa=i&url=https%3A%2F%2Fwww.ifaw.org%2Fanimals%2Fantelopes&psig=AOvVaw1S3fshLDMiIZ1r9YOa8NIg&ust=1747246824703000&source=images&cd=vfe&opi=89978449&ved=0CBQQjRxqFwoTCMCKlOWHoY0DFQAAAAAdAAAAABAE"],
["example_images/wildlife2.jpg"],
["example_images/wildlife3.jpg"]
]
# Create Gradio interface
interface = gr.Interface(
fn=detect_wildlife,
inputs=gr.Image(type="pil"),
outputs=[
gr.Image(label="Detection Results"),
gr.Textbox(label="Detection Summary")
],
title="Wildlife Detection Model",
description="Upload an image to detect wildlife using our trained model",
examples=examples
)
if __name__ == "__main__":
interface.launch()