# Objectdetection / app.py
# SuriRaja's picture
# Create app.py
# ec06154 verified
import torch
from transformers import DetrForObjectDetection, DetrImageProcessor
from PIL import Image
import gradio as gr
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# The detection model and its image processor are both loaded once, at import
# time, from the same pretrained identifier so they stay in sync.
_DETR_CHECKPOINT = "facebook/detr-resnet-50"
processor = DetrImageProcessor.from_pretrained(_DETR_CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(_DETR_CHECKPOINT)
def detect_car(
    image: Image.Image,
    *,
    threshold: float = 0.7,
    target_label: str = "car",
) -> Image.Image:
    """Return a rendering of *image* with detected cars outlined and labelled.

    The image is run through the module-level DETR ``model``; every detection
    whose class name equals ``target_label`` and whose confidence exceeds
    ``threshold`` is drawn as a red rectangle with a "<label>: <score>"
    caption.

    Args:
        image: Input picture as a PIL image.
        threshold: Minimum confidence (exclusive) a detection needs in order
            to be drawn.
        target_label: Class name, as found in ``model.config.id2label``, of
            the objects to keep.

    Returns:
        An RGB PIL image of the Matplotlib figure (12x8 inches at the figure's
        DPI) showing the input with the selected boxes overlaid.
    """
    # Work on an RGB copy so the processor and the plot see the same
    # 3-channel picture even if the caller hands in RGBA/palette/grayscale.
    rgb_image = image.convert("RGB")

    inputs = processor(images=rgb_image, return_tensors="pt")

    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); post-processing is given (height, width).
    target_sizes = torch.tensor([rgb_image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]

    fig, ax = plt.subplots(1, figsize=(12, 8))
    try:
        ax.imshow(rgb_image)

        detections = zip(results["scores"], results["labels"], results["boxes"])
        for score, label, box in detections:
            label_name = model.config.id2label[label.item()]
            # The app promises car-only output, so skip every other class.
            if label_name != target_label:
                continue

            xmin, ymin, xmax, ymax = box.tolist()
            ax.add_patch(
                patches.Rectangle(
                    (xmin, ymin),
                    xmax - xmin,
                    ymax - ymin,
                    linewidth=2,
                    edgecolor='red',
                    facecolor='none',
                )
            )
            ax.text(
                xmin,
                ymin,
                f"{label_name}: {score.item():.2f}",
                color='white',
                fontsize=12,
                bbox=dict(facecolor='red', alpha=0.5),
            )

        # Use the explicit figure/axes rather than pyplot's global "current"
        # figure, which may belong to another concurrent request.
        ax.axis('off')
        fig.tight_layout()

        # Rasterise and read the pixels back. ``tostring_rgb`` no longer
        # exists in recent Matplotlib releases, so go through the RGBA buffer
        # and take the size from the buffer itself.
        fig.canvas.draw()
        rgba = fig.canvas.buffer_rgba()
        height, width = rgba.shape[:2]
        rendered = Image.frombytes("RGBA", (width, height), rgba.tobytes())
        return rendered.convert("RGB")
    finally:
        # Always release the figure, even when inference or drawing fails.
        plt.close(fig)
# Web UI: one PIL image goes in, the annotated PIL image produced by
# detect_car comes out.
_image_input = gr.Image(type="pil")
_image_output = gr.Image(type="pil")

iface = gr.Interface(
    title="Car Detection with DETR",
    description="Upload an image and the model will detect cars with bounding boxes. Only cars will be displayed.",
    fn=detect_car,
    inputs=_image_input,
    outputs=_image_output,
)

# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    iface.launch()