"""Gradio demo: object detection with a pretrained KerasCV YOLOv8 model."""

from functools import lru_cache

import gradio as gr
import keras
import keras_cv
import numpy as np
from PIL import Image, ImageDraw, ImageFont

# COCO class labels (80 classes).
# NOTE: retained only for backward compatibility with code that imports this
# name. The preset loaded below is trained on Pascal VOC, so detections are
# labelled with PASCAL_VOC_CLASSES instead.
COCO_CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
    "truck", "boat", "traffic light", "fire hydrant", "stop sign",
    "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag",
    "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite",
    "baseball bat", "baseball glove", "skateboard", "surfboard",
    "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot",
    "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant",
    "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote",
    "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
    "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
    "hair drier", "toothbrush",
]

# Pascal VOC class labels (20 classes), indexed by the class id the
# "yolo_v8_m_pascalvoc" preset emits. The order follows the class mapping
# used in the KerasCV object-detection guide for this preset.
PASCAL_VOC_CLASSES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "dining table", "dog", "horse", "motorbike", "person",
    "potted plant", "sheep", "sofa", "train", "tv monitor",
]

# Color palette for bounding boxes
COLORS = [
    "#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4", "#FFEAA7",
    "#DDA0DD", "#98D8C8", "#F7DC6F", "#BB8FCE", "#85C1E9",
    "#F8C471", "#82E0AA", "#F1948A", "#AED6F1", "#D7BDE2",
]

# Preset name passed to YOLOV8Detector.from_preset.
MODEL_PRESET = "yolo_v8_m_pascalvoc"

# Side length (pixels) of the square image fed to the model.
INPUT_SIZE = 640

_BOLD_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"


def load_model():
    """Load pretrained YOLOv8 model from KerasCV.

    Returns:
        The detector built from ``MODEL_PRESET``, configured to report boxes
        in ``"xyxy"`` format.
    """
    return keras_cv.models.YOLOV8Detector.from_preset(
        MODEL_PRESET,
        bounding_box_format="xyxy",
    )


print("Loading model...")
model = load_model()
print("Model loaded!")


@lru_cache(maxsize=1)
def _label_font():
    """Return the font used for box labels, loading it only once.

    Falls back to Pillow's built-in default font when the DejaVu font file
    is not present on this machine.
    """
    try:
        return ImageFont.truetype(_BOLD_FONT_PATH, 16)
    except OSError:
        return ImageFont.load_default()


def _to_numpy(values):
    """Return *values* as a NumPy array, calling ``.numpy()`` when available."""
    if hasattr(values, "numpy"):
        return values.numpy()
    return np.asarray(values)


def _class_label(cls_id):
    """Map a class id to a readable name, or ``class_<id>`` if out of range."""
    if 0 <= cls_id < len(PASCAL_VOC_CLASSES):
        return PASCAL_VOC_CLASSES[cls_id]
    return f"class_{cls_id}"


def _draw_label(draw, left, top, text, color, font):
    """Draw a filled label tag with *text*, anchored to a box's top-left corner.

    The tag sits just above the box. When that would place it above the top
    edge of the image, it is drawn just inside the box instead so the text
    stays visible.
    """
    bbox = draw.textbbox((0, 0), text, font=font)
    text_w = bbox[2] - bbox[0]
    text_h = bbox[3] - bbox[1]

    tag_top = top - text_h - 6
    if tag_top < 0:
        tag_top = max(top, 0.0)

    draw.rectangle(
        [left, tag_top, left + text_w + 8, tag_top + text_h + 6],
        fill=color,
    )
    draw.text((left + 4, tag_top + 2), text, fill="white", font=font)


def detect_objects(image, confidence_threshold=0.5):
    """Run object detection on a single image.

    Args:
        image: Image as a NumPy array (as delivered by a ``gr.Image`` with
            ``type="numpy"``), or ``None`` when nothing was uploaded.
        confidence_threshold: Detections scoring below this value are
            ignored.

    Returns:
        A ``(annotated_image, status)`` tuple matching the two Gradio
        outputs. ``annotated_image`` is a PIL image with boxes and labels
        drawn on it, or ``None`` when no input image was given; ``status``
        is a short human-readable summary.
    """
    if image is None:
        # Always return one value per declared output component.
        return None, "No image provided"

    # Normalise to 3-channel RGB so grayscale / RGBA uploads still produce
    # a 3-channel model input.
    orig_image = Image.fromarray(image).convert("RGB")
    orig_w, orig_h = orig_image.size

    # Resize for model input
    resized = orig_image.resize((INPUT_SIZE, INPUT_SIZE))
    input_batch = np.expand_dims(np.asarray(resized, dtype="float32"), axis=0)

    # Run prediction (verbose=0 keeps the per-request progress bar out of
    # the server log); take the first and only item of the batch.
    predictions = model.predict(input_batch, verbose=0)
    boxes = _to_numpy(predictions["boxes"][0])
    classes = _to_numpy(predictions["classes"][0])
    confidence = _to_numpy(predictions["confidence"][0])

    # Boxes are in resized-image coordinates; these factors map them back
    # to the original image.
    scale_x = orig_w / INPUT_SIZE
    scale_y = orig_h / INPUT_SIZE

    draw = ImageDraw.Draw(orig_image)
    font = _label_font()

    detections_found = 0
    for box, cls, raw_score in zip(boxes, classes, confidence):
        score = float(raw_score)
        if score < confidence_threshold:
            continue

        cls_id = int(cls)
        label = _class_label(cls_id)
        color = COLORS[cls_id % len(COLORS)]

        # Order the corners so rectangle() never receives an inverted box.
        left, right = sorted((float(box[0]) * scale_x, float(box[2]) * scale_x))
        top, bottom = sorted((float(box[1]) * scale_y, float(box[3]) * scale_y))

        # Draw bounding box
        draw.rectangle([left, top, right, bottom], outline=color, width=3)

        # Draw label background + text
        _draw_label(draw, left, top, f"{label} {score:.0%}", color, font)
        detections_found += 1

    status = (
        f"Found {detections_found} object(s)"
        if detections_found
        else "No objects detected"
    )
    return orig_image, status


# Build the Gradio interface
with gr.Blocks(title="Keras Object Detection") as demo:
    gr.Markdown("# Object Detection with KerasCV YOLOv8")
    gr.Markdown("Upload an image to detect objects using a pretrained YOLOv8 model.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload Image", type="numpy")
            threshold = gr.Slider(
                minimum=0.1,
                maximum=0.95,
                value=0.5,
                step=0.05,
                label="Confidence Threshold",
            )
            run_btn = gr.Button("Detect Objects", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Detections")
            status_text = gr.Textbox(label="Status", interactive=False)

    run_btn.click(
        fn=detect_objects,
        inputs=[input_image, threshold],
        outputs=[output_image, status_text],
    )

# Launch only when run as a script so importing this module (e.g. to reuse
# detect_objects or demo) does not start a server.
if __name__ == "__main__":
    demo.launch()