gopichandra committed on
Commit
c1f9c76
·
verified ·
1 Parent(s): d3a9708

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -10,26 +10,26 @@ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
10
  # Prediction function
11
  def predict_image(image: Image.Image):
12
  try:
13
- # Step 1: Use DETR for object detection
14
- # Ensure padding is enabled when processing single images
15
  inputs = processor(images=image, return_tensors="pt", padding=True)
16
-
17
- # Step 2: Run inference
18
  with torch.no_grad():
19
  outputs = model(**inputs)
20
 
21
- # Step 3: Post-process the outputs for object detection
22
- target_sizes = torch.tensor([image.size[::-1]]) # Convert image size to (height, width)
 
23
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.5)[0]
24
-
25
  # Step 4: Draw bounding boxes on the image
26
  draw = ImageDraw.Draw(image)
27
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
28
  box = [round(i, 2) for i in box.tolist()]
29
  draw.rectangle(box, outline="red", width=2)
30
  draw.text((box[0], box[1]), f"{model.config.id2label[label.item()]}: {round(score.item(), 3)}", fill="red")
31
-
32
- # Step 5: Return the processed image and description
33
  return image, "Object detection complete."
34
 
35
  except Exception as e:
 
10
  # Prediction function
11
  def predict_image(image: Image.Image):
12
  try:
13
+ # Step 1: Preprocess the image with padding enabled
 
14
  inputs = processor(images=image, return_tensors="pt", padding=True)
15
+
16
+ # Step 2: Run inference without gradients
17
  with torch.no_grad():
18
  outputs = model(**inputs)
19
 
20
+ # Step 3: Post-process the outputs
21
+ # Convert image size to (height, width)
22
+ target_sizes = torch.tensor([image.size[::-1]])
23
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.5)[0]
24
+
25
  # Step 4: Draw bounding boxes on the image
26
  draw = ImageDraw.Draw(image)
27
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
28
  box = [round(i, 2) for i in box.tolist()]
29
  draw.rectangle(box, outline="red", width=2)
30
  draw.text((box[0], box[1]), f"{model.config.id2label[label.item()]}: {round(score.item(), 3)}", fill="red")
31
+
32
+ # Step 5: Return the processed image with status
33
  return image, "Object detection complete."
34
 
35
  except Exception as e: