"""Gradio demo: detect cars in an uploaded image with DETR (ResNet-50 backbone)."""

import io

import gradio as gr
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import torch
from PIL import Image
from transformers import DetrForObjectDetection, DetrImageProcessor

MODEL_NAME = "facebook/detr-resnet-50"
# Class name (as spelled in ``model.config.id2label``) that the demo displays.
TARGET_LABEL = "car"
# Detections scoring at or below this value are discarded.
CONFIDENCE_THRESHOLD = 0.7

# Load the pre-trained model and its matching processor once, at import time,
# so every request served by the interface reuses them.
model = DetrForObjectDetection.from_pretrained(MODEL_NAME)
processor = DetrImageProcessor.from_pretrained(MODEL_NAME)


def _render_detections(image: Image.Image, detections: list) -> Image.Image:
    """Draw labelled bounding boxes over *image* and return the rendering.

    Args:
        image: Picture to use as the background.
        detections: ``(label_name, score, (xmin, ymin, xmax, ymax))`` tuples
            with box corners in pixel coordinates of *image*.

    Returns:
        An RGB image of the rendered Matplotlib figure.
    """
    fig, ax = plt.subplots(1, figsize=(12, 8))
    try:
        ax.imshow(image)
        for label_name, score, (xmin, ymin, xmax, ymax) in detections:
            ax.add_patch(
                patches.Rectangle(
                    (xmin, ymin),
                    xmax - xmin,
                    ymax - ymin,
                    linewidth=2,
                    edgecolor="red",
                    facecolor="none",
                )
            )
            ax.text(
                xmin,
                ymin,
                f"{label_name}: {score:.2f}",
                color="white",
                fontsize=12,
                bbox=dict(facecolor="red", alpha=0.5),
            )

        # Use the Axes/Figure methods rather than ``plt.axis``/``plt.tight_layout``:
        # the pyplot versions act on the global "current" figure, which is not
        # guaranteed to be this one when requests are handled concurrently.
        ax.axis("off")
        fig.tight_layout()

        # Render through ``savefig`` instead of ``canvas.tostring_rgb()``, which
        # was deprecated in Matplotlib 3.8 and later removed.
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        # Always release the figure, even if drawing fails, so pyplot's figure
        # registry does not grow with every request.
        plt.close(fig)

    buffer.seek(0)
    # ``convert`` forces the lazily-opened PNG to be fully decoded before the
    # buffer goes out of scope, and drops the alpha channel.
    return Image.open(buffer).convert("RGB")


def detect_car(image: Image.Image) -> Image.Image:
    """Detect cars in *image* and return a copy annotated with bounding boxes.

    Only detections whose class is ``TARGET_LABEL`` and whose score exceeds
    ``CONFIDENCE_THRESHOLD`` are drawn.

    Args:
        image: The uploaded picture.

    Returns:
        A rendering of the picture with one red box and a ``"<label>: <score>"``
        caption per detected car.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Normalise to 3-channel RGB so grayscale/RGBA inputs are handled uniformly.
    image = image.convert("RGB")
    inputs = processor(images=image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(
        outputs,
        target_sizes=target_sizes,
        threshold=CONFIDENCE_THRESHOLD,
    )[0]

    detections = []
    for score, label, box in zip(
        results["scores"], results["labels"], results["boxes"]
    ):
        label_name = model.config.id2label[label.item()]
        if label_name != TARGET_LABEL:
            continue
        detections.append((label_name, score.item(), box.tolist()))

    return _render_detections(image, detections)


# Gradio interface to upload images and get object detection results.
iface = gr.Interface(
    fn=detect_car,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Car Detection with DETR",
    description="Upload an image and the model will detect cars with bounding boxes. Only cars will be displayed.",
)

if __name__ == "__main__":
    iface.launch()