|
|
import os
import tarfile
import urllib.request

import cv2
import gradio as gr
import numpy as np
import tensorflow as tf
# TF2 Detection Model Zoo checkpoint (trained on COCO 2017) used by this app.
MODEL_NAME = "ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8"
MODEL_URL = (
    "http://download.tensorflow.org/models/object_detection/tf2/20200711/"
    f"{MODEL_NAME}.tar.gz"
)
# Directory the archive is unpacked into, so later runs can skip the download.
MODEL_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "tf_object_detection")


def _fetch_saved_model_dir(model_name=MODEL_NAME, url=MODEL_URL, cache_dir=MODEL_CACHE_DIR):
    """Return a local SavedModel directory, downloading and unpacking it on first use.

    ``tf.saved_model.load`` reads from filesystem paths (local disk and the
    file systems ``tf.io.gfile`` supports), not from HTTP URLs, so the zoo
    archive has to be fetched and extracted before it can be loaded.

    Args:
        model_name: Name of the top-level directory inside the archive.
        url: Location of the ``.tar.gz`` archive.
        cache_dir: Directory the archive is extracted into.

    Returns:
        Path to ``<cache_dir>/<model_name>/saved_model``.

    Raises:
        urllib.error.URLError: If the download fails.
        FileNotFoundError: If the archive does not contain the expected directory.
    """
    saved_model_dir = os.path.join(cache_dir, model_name, "saved_model")
    if os.path.isdir(saved_model_dir):
        return saved_model_dir

    os.makedirs(cache_dir, exist_ok=True)
    archive_path = os.path.join(cache_dir, f"{model_name}.tar.gz")
    urllib.request.urlretrieve(url, archive_path)
    try:
        with tarfile.open(archive_path, "r:gz") as archive:
            # The "data" filter (Python 3.12+ and security backports) rejects
            # absolute paths and path traversal inside the archive.
            extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
            archive.extractall(cache_dir, **extract_kwargs)
    finally:
        # The extracted tree is what gets cached; the archive is not needed.
        os.remove(archive_path)

    if not os.path.isdir(saved_model_dir):
        raise FileNotFoundError(f"{url} did not contain {model_name}/saved_model")
    return saved_model_dir


model = tf.saved_model.load(_fetch_saved_model_dir())
# COCO 2017 label map (class id -> display name) matching the COCO-trained
# detector. The 80 ids are spread over 1-90 with gaps, so lookups must go
# through the dict rather than by position.
_COCO_CLASS_NAMES = {
    1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane',
    6: 'bus', 7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light',
    11: 'fire hydrant', 13: 'stop sign', 14: 'parking meter', 15: 'bench',
    16: 'bird', 17: 'cat', 18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow',
    22: 'elephant', 23: 'bear', 24: 'zebra', 25: 'giraffe', 27: 'backpack',
    28: 'umbrella', 31: 'handbag', 32: 'tie', 33: 'suitcase', 34: 'frisbee',
    35: 'skis', 36: 'snowboard', 37: 'sports ball', 38: 'kite',
    39: 'baseball bat', 40: 'baseball glove', 41: 'skateboard',
    42: 'surfboard', 43: 'tennis racket', 44: 'bottle', 46: 'wine glass',
    47: 'cup', 48: 'fork', 49: 'knife', 50: 'spoon', 51: 'bowl',
    52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange', 56: 'broccoli',
    57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut', 61: 'cake',
    62: 'chair', 63: 'couch', 64: 'potted plant', 65: 'bed',
    67: 'dining table', 70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse',
    75: 'remote', 76: 'keyboard', 77: 'cell phone', 78: 'microwave',
    79: 'oven', 80: 'toaster', 81: 'sink', 82: 'refrigerator', 84: 'book',
    85: 'clock', 86: 'vase', 87: 'scissors', 88: 'teddy bear',
    89: 'hair drier', 90: 'toothbrush',
}

# Same {'id': ..., 'name': ...} record shape per class as before; ids that
# are absent here are labelled 'N/A' by detect_objects.
category_index = {
    class_id: {'id': class_id, 'name': class_name}
    for class_id, class_name in _COCO_CLASS_NAMES.items()
}
def detect_objects(image, score_threshold=0.5):
    """Run the detector on ``image`` and draw labelled bounding boxes onto it.

    Args:
        image: Array of shape (height, width, channels), expected in RGB
            order. It is presumably uint8, which is what TF2 Detection Zoo
            SavedModels accept (TODO confirm for other models). The array is
            annotated in place.
        score_threshold: Only detections whose score is strictly greater
            than this value are drawn. The default of 0.5 matches the
            previously hard-coded cut-off.

    Returns:
        The same ``image`` object, now carrying the drawn boxes and labels.
    """
    # Box coordinates are normalised and get scaled by the image size; the
    # size does not change inside the loop, so read it once.
    height, width = image.shape[:2]

    # The model consumes a batch, so add a leading batch axis of size 1.
    input_tensor = tf.convert_to_tensor(image)[tf.newaxis, ...]
    detections = model(input_tensor)

    # 'num_detections' is a [1]-shaped tensor holding how many entries of
    # each batched output are valid for our single image.
    num_detections = int(detections.pop('num_detections')[0])
    detections = {key: value[0, :num_detections].numpy() for key, value in detections.items()}

    detection_classes = detections['detection_classes'].astype(np.int64)
    detection_boxes = detections['detection_boxes']
    detection_scores = detections['detection_scores']

    for class_id, box, score in zip(detection_classes, detection_boxes, detection_scores):
        if score > score_threshold:
            class_name = category_index.get(class_id, {'name': 'N/A'})['name']

            ymin, xmin, ymax, xmax = box
            start_x, start_y = int(xmin * width), int(ymin * height)
            end_x, end_y = int(xmax * width), int(ymax * height)

            cv2.rectangle(image, (start_x, start_y), (end_x, end_y), (0, 255, 0), 2)

            # Place the label just above the box. When the box touches the
            # top edge, move the label inside the box so it stays visible.
            label_y = start_y - 10 if start_y - 10 > 10 else start_y + 15
            label = f"{class_name}: {score:.2f}"
            cv2.putText(image, label, (start_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    return image
def gradio_interface(image):
    """Gradio callback: load the input image, run ``detect_objects``, return the result.

    Args:
        image: Either a path to an image file (what ``gr.Image(type="filepath")``
            supplies) or an array-like image such as a PIL image or numpy
            array, which is assumed to already be in RGB channel order.

    Returns:
        The annotated image as an RGB numpy array. Gradio displays numpy
        images as RGB, so the result is returned as-is rather than being
        converted to OpenCV's BGR order.

    Raises:
        gr.Error: If no image was supplied or a path cannot be decoded.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    if isinstance(image, str):
        # cv2.imread yields BGR pixels and reports failure by returning None
        # rather than raising.
        bgr = cv2.imread(image, cv2.IMREAD_COLOR)
        if bgr is None:
            raise gr.Error(f"Could not read an image from {image!r}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    else:
        # PIL and numpy inputs are already RGB, so no channel swap is applied.
        # np.array copies, so the caller's array is not drawn on.
        rgb = np.array(image)
        # The model expects exactly three colour channels.
        if rgb.ndim == 2 or (rgb.ndim == 3 and rgb.shape[2] == 1):
            rgb = cv2.cvtColor(rgb, cv2.COLOR_GRAY2RGB)
        elif rgb.ndim == 3 and rgb.shape[2] == 4:
            rgb = cv2.cvtColor(rgb, cv2.COLOR_RGBA2RGB)

    return detect_objects(rgb)
# Web UI definition. gr.inputs.* was deprecated in Gradio 3 and removed in
# Gradio 4; gr.Image is the supported component. type="filepath" makes Gradio
# pass gradio_interface the uploaded file's path as a string.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="filepath"),
    outputs="image",
    title="Object Detection with Bounding Boxes",
    description="Upload an image or provide a file path to detect objects.",
)

# Start the server only when run as a script, so importing this module does
# not block on launch().
if __name__ == "__main__":
    iface.launch()