# object / app.py
# stevafernandes's picture
# Update app.py
# 29a423a verified
import gradio as gr
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import keras_cv
import keras
# The 80 COCO category names; a label's position in the list is its class id.
# NOTE(review): load_model() requests the "yolo_v8_m_pascalvoc" preset, whose
# name points at Pascal VOC rather than COCO -- verify the id -> label mapping
# before using this table to caption that model's predictions.
COCO_CLASSES = [
    # people and vehicles
    "person",
    "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat",
    # street furniture
    "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
    # animals
    "bird", "cat", "dog", "horse", "sheep",
    "cow", "elephant", "bear", "zebra", "giraffe",
    # carried items
    "backpack", "umbrella", "handbag", "tie", "suitcase",
    # sports gear
    "frisbee", "skis", "snowboard", "sports ball", "kite",
    "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket",
    # tableware
    "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
    # food
    "banana", "apple", "sandwich", "orange", "broccoli",
    "carrot", "hot dog", "pizza", "donut", "cake",
    # furniture
    "chair", "couch", "potted plant", "bed", "dining table", "toilet",
    # electronics
    "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
    # appliances
    "microwave", "oven", "toaster", "sink", "refrigerator",
    # household objects
    "book", "clock", "vase", "scissors",
    "teddy bear", "hair drier", "toothbrush",
]
# Hex colours for bounding boxes. A class id selects its colour modulo the
# palette length, so the colours repeat once every 15 class ids.
COLORS = (
    "#FF6B6B #4ECDC4 #45B7D1 #96CEB4 #FFEAA7 "
    "#DDA0DD #98D8C8 #F7DC6F #BB8FCE #85C1E9 "
    "#F8C471 #82E0AA #F1948A #AED6F1 #D7BDE2"
).split()
def load_model():
    """Create the pretrained KerasCV YOLOv8 detector used by this app.

    Returns:
        The ``YOLOV8Detector`` built from the ``"yolo_v8_m_pascalvoc"`` preset,
        configured to report boxes in ``"xyxy"`` format.
    """
    preset_name = "yolo_v8_m_pascalvoc"
    detector_cls = keras_cv.models.YOLOV8Detector
    return detector_cls.from_preset(preset_name, bounding_box_format="xyxy")
# Load the detector once at import time; detect_objects() reads this
# module-level `model` on every call instead of loading it again.
print("Loading model...")
model = load_model()
print("Model loaded!")
# Class names for the "yolo_v8_m_pascalvoc" preset requested by load_model().
# That checkpoint is trained on the 20 Pascal VOC categories, so the class ids
# it predicts index this alphabetical VOC list, not the 80-entry COCO table.
PASCAL_VOC_CLASSES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "dining table", "dog", "horse", "motorbike", "person",
    "potted plant", "sheep", "sofa", "train", "tv monitor",
]

# Side length, in pixels, of the square image handed to the detector.
_MODEL_INPUT_SIZE = 640


def _to_numpy(value):
    """Return *value* as a NumPy array, unwrapping tensors that offer .numpy()."""
    if hasattr(value, "numpy"):
        value = value.numpy()
    return np.asarray(value)


def _class_label(cls_id):
    """Map a predicted class id to its name, or ``class_<id>`` if unknown."""
    if 0 <= cls_id < len(PASCAL_VOC_CLASSES):
        return PASCAL_VOC_CLASSES[cls_id]
    return f"class_{cls_id}"


def _load_label_font(size=16):
    """Return bold DejaVu Sans, or PIL's default font when it is not installed."""
    try:
        return ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", size
        )
    except OSError:
        return ImageFont.load_default()


def _draw_detection(draw, box, text, color, font, canvas_size):
    """Draw one bounding box and its caption, keeping both on the canvas."""
    canvas_w, canvas_h = canvas_size
    # Order and clamp the corners: the detector can predict coordinates a
    # little outside the frame, and recent Pillow versions reject rectangles
    # whose second corner precedes the first.
    x1, x2 = sorted(min(max(v, 0.0), canvas_w - 1) for v in (box[0], box[2]))
    y1, y2 = sorted(min(max(v, 0.0), canvas_h - 1) for v in (box[1], box[3]))
    draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    text_w = right - left
    text_h = bottom - top
    label_h = text_h + 6
    # Prefer the strip just above the box; if that would fall off the top of
    # the image, put the caption inside the box along its top edge instead.
    label_top = y1 - label_h if y1 >= label_h else y1
    draw.rectangle(
        [x1, label_top, x1 + text_w + 8, label_top + label_h], fill=color
    )
    # Subtract the glyph offsets so the text sits centred in the padded strip.
    draw.text(
        (x1 + 4 - left, label_top + 3 - top), text, fill="white", font=font
    )


def detect_objects(image, confidence_threshold=0.5):
    """Run object detection on a single image and draw the results on it.

    Args:
        image: The picture as a NumPy array (the Gradio input component is
            declared with ``type="numpy"``), or ``None`` when nothing has been
            uploaded.
        confidence_threshold: Detections scoring below this value are skipped.

    Returns:
        An ``(annotated_image, status)`` tuple: the PIL image with boxes and
        captions drawn on it, and a short status message. When ``image`` is
        ``None`` the result is ``(None, message)`` so that both outputs wired
        to this handler always receive a value.
    """
    if image is None:
        return None, "Please upload an image first."

    # Normalise to three channels so greyscale / RGBA uploads produce the
    # same input layout as ordinary RGB pictures.
    orig_image = Image.fromarray(image).convert("RGB")
    orig_w, orig_h = orig_image.size

    # Resize for model input (plain squash; boxes are scaled back below).
    resized = orig_image.resize((_MODEL_INPUT_SIZE, _MODEL_INPUT_SIZE))
    input_batch = np.expand_dims(np.array(resized, dtype="float32"), axis=0)

    # Run prediction and take the single batch element.
    predictions = model.predict(input_batch)
    boxes = _to_numpy(predictions["boxes"][0])
    classes = _to_numpy(predictions["classes"][0])
    confidence = _to_numpy(predictions["confidence"][0])

    draw = ImageDraw.Draw(orig_image)
    font = _load_label_font()

    # Factors that map coordinates in the resized frame back to the original.
    scale_x = orig_w / _MODEL_INPUT_SIZE
    scale_y = orig_h / _MODEL_INPUT_SIZE

    detections_found = 0
    for box, raw_cls, raw_score in zip(boxes, classes, confidence):
        score = float(raw_score)
        if score < confidence_threshold:
            continue

        cls_id = int(raw_cls)
        scaled_box = (
            float(box[0]) * scale_x,
            float(box[1]) * scale_y,
            float(box[2]) * scale_x,
            float(box[3]) * scale_y,
        )
        _draw_detection(
            draw,
            scaled_box,
            f"{_class_label(cls_id)} {score:.0%}",
            COLORS[cls_id % len(COLORS)],
            font,
            (orig_w, orig_h),
        )
        detections_found += 1

    status = (
        f"Found {detections_found} object(s)"
        if detections_found
        else "No objects detected"
    )
    return orig_image, status
# Assemble the Gradio page: inputs on the left, results on the right.
with gr.Blocks(title="Keras Object Detection") as demo:
    gr.Markdown("# Object Detection with KerasCV YOLOv8")
    gr.Markdown("Upload an image to detect objects using a pretrained YOLOv8 model.")

    with gr.Row():
        # Left column: the picture, the score cut-off and the trigger button.
        with gr.Column():
            input_image = gr.Image(type="numpy", label="Upload Image")
            threshold = gr.Slider(
                label="Confidence Threshold",
                value=0.5,
                minimum=0.1,
                maximum=0.95,
                step=0.05,
            )
            run_btn = gr.Button("Detect Objects", variant="primary")

        # Right column: the annotated picture and a one-line summary.
        with gr.Column():
            output_image = gr.Image(label="Detections")
            status_text = gr.Textbox(interactive=False, label="Status")

    # Clicking the button feeds (image, threshold) to detect_objects and
    # routes its two return values to the image and the status box.
    run_btn.click(
        outputs=[output_image, status_text],
        inputs=[input_image, threshold],
        fn=detect_objects,
    )

demo.launch()