SuriRaja commited on
Commit
ec06154
·
verified ·
1 Parent(s): e348275

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import DetrForObjectDetection, DetrImageProcessor
3
+ from PIL import Image
4
+ import gradio as gr
5
+ import matplotlib.pyplot as plt
6
+ import matplotlib.patches as patches
7
+
8
+ # Load pre-trained model and processor
9
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
10
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
11
+
12
+ def detect_car(image: Image.Image) -> Image.Image:
13
+ # Preprocess the input image
14
+ inputs = processor(images=image, return_tensors="pt")
15
+
16
+ # Run the model to get predictions
17
+ outputs = model(**inputs)
18
+
19
+ # Postprocess the outputs to get bounding boxes and labels
20
+ target_sizes = torch.tensor([image.size[::-1]]) # (height, width)
21
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
22
+
23
+ # Plotting the image with bounding boxes for objects
24
+ fig, ax = plt.subplots(1, figsize=(12, 8))
25
+ ax.imshow(image)
26
+
27
+ for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
28
+ if score > 0.7: # Confidence threshold for detecting cars
29
+ xmin, ymin, xmax, ymax = box.detach().numpy()
30
+ width, height = xmax - xmin, ymax - ymin
31
+ rect = patches.Rectangle((xmin, ymin), width, height, linewidth=2, edgecolor='red', facecolor='none')
32
+ ax.add_patch(rect)
33
+ ax.text(xmin, ymin, f"{model.config.id2label[label.item()]}: {score:.2f}",
34
+ color='white', fontsize=12, bbox=dict(facecolor='red', alpha=0.5))
35
+
36
+ # Convert the plot to an image
37
+ plt.axis('off')
38
+ plt.tight_layout()
39
+
40
+ # Save the figure to a canvas and convert to image
41
+ fig.canvas.draw()
42
+ result_img = Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
43
+ plt.close(fig)
44
+ return result_img
45
+
46
+ # Gradio interface to upload images and get object detection results
47
+ iface = gr.Interface(
48
+ fn=detect_car,
49
+ inputs=gr.Image(type="pil"),
50
+ outputs=gr.Image(type="pil"),
51
+ title="Car Detection with DETR",
52
+ description="Upload an image and the model will detect cars with bounding boxes. Only cars will be displayed."
53
+ )
54
+
55
+ if __name__ == "__main__":
56
+ iface.launch()