gopichandra commited on
Commit
d3a9708
·
verified ·
1 Parent(s): 373f17e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -11,25 +11,27 @@ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
11
  def predict_image(image: Image.Image):
12
  try:
13
  # Step 1: Use DETR for object detection
 
14
  inputs = processor(images=image, return_tensors="pt", padding=True)
15
-
 
16
  with torch.no_grad():
17
  outputs = model(**inputs)
18
-
19
- # Post-process the outputs for object detection
20
- target_sizes = torch.tensor([image.size[::-1]]) # Convert to (height, width)
21
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.5)[0]
22
 
23
- # Draw bounding boxes on the image
24
  draw = ImageDraw.Draw(image)
25
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
26
  box = [round(i, 2) for i in box.tolist()]
27
  draw.rectangle(box, outline="red", width=2)
28
  draw.text((box[0], box[1]), f"{model.config.id2label[label.item()]}: {round(score.item(), 3)}", fill="red")
29
 
30
- # Return the image with bounding boxes and description text
31
  return image, "Object detection complete."
32
-
33
  except Exception as e:
34
  return None, f"Error: {str(e)}"
35
 
 
11
  def predict_image(image: Image.Image):
12
  try:
13
  # Step 1: Use DETR for object detection
14
+ # Ensure padding is enabled when processing single images
15
  inputs = processor(images=image, return_tensors="pt", padding=True)
16
+
17
+ # Step 2: Run inference
18
  with torch.no_grad():
19
  outputs = model(**inputs)
20
+
21
+ # Step 3: Post-process the outputs for object detection
22
+ target_sizes = torch.tensor([image.size[::-1]]) # Convert image size to (height, width)
23
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.5)[0]
24
 
25
+ # Step 4: Draw bounding boxes on the image
26
  draw = ImageDraw.Draw(image)
27
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
28
  box = [round(i, 2) for i in box.tolist()]
29
  draw.rectangle(box, outline="red", width=2)
30
  draw.text((box[0], box[1]), f"{model.config.id2label[label.item()]}: {round(score.item(), 3)}", fill="red")
31
 
32
+ # Step 5: Return the processed image and description
33
  return image, "Object detection complete."
34
+
35
  except Exception as e:
36
  return None, f"Error: {str(e)}"
37