gopichandra commited on
Commit
d3e250c
·
verified ·
1 Parent(s): c1f9c76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -13,23 +13,26 @@ def predict_image(image: Image.Image):
13
  # Step 1: Preprocess the image with padding enabled
14
  inputs = processor(images=image, return_tensors="pt", padding=True)
15
 
16
- # Step 2: Run inference without gradients
 
 
 
17
  with torch.no_grad():
18
  outputs = model(**inputs)
19
 
20
- # Step 3: Post-process the outputs
21
  # Convert image size to (height, width)
22
  target_sizes = torch.tensor([image.size[::-1]])
23
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.5)[0]
24
 
25
- # Step 4: Draw bounding boxes on the image
26
  draw = ImageDraw.Draw(image)
27
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
28
  box = [round(i, 2) for i in box.tolist()]
29
  draw.rectangle(box, outline="red", width=2)
30
  draw.text((box[0], box[1]), f"{model.config.id2label[label.item()]}: {round(score.item(), 3)}", fill="red")
31
 
32
- # Step 5: Return the processed image with status
33
  return image, "Object detection complete."
34
 
35
  except Exception as e:
 
13
  # Step 1: Preprocess the image with padding enabled
14
  inputs = processor(images=image, return_tensors="pt", padding=True)
15
 
16
+ # Step 2: Ensure the model works on single-image batch by adding a batch dimension
17
+ inputs = {key: val.unsqueeze(0) if len(val.shape) == 3 else val for key, val in inputs.items()}
18
+
19
+ # Step 3: Run inference without gradients
20
  with torch.no_grad():
21
  outputs = model(**inputs)
22
 
23
+ # Step 4: Post-process the outputs
24
  # Convert image size to (height, width)
25
  target_sizes = torch.tensor([image.size[::-1]])
26
  results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.5)[0]
27
 
28
+ # Step 5: Draw bounding boxes on the image
29
  draw = ImageDraw.Draw(image)
30
  for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
31
  box = [round(i, 2) for i in box.tolist()]
32
  draw.rectangle(box, outline="red", width=2)
33
  draw.text((box[0], box[1]), f"{model.config.id2label[label.item()]}: {round(score.item(), 3)}", fill="red")
34
 
35
+ # Step 6: Return the processed image with status
36
  return image, "Object detection complete."
37
 
38
  except Exception as e: