DatSplit committed on
Commit
cbda6dd
·
verified ·
1 Parent(s): 11defca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -3
app.py CHANGED
@@ -9,8 +9,76 @@ import torch
9
  from PIL import Image, ImageColor
10
  from torchvision.utils import draw_bounding_boxes
11
  import rfdetr.datasets.transforms as T
 
12
 
 
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def process_categories() -> tuple:
15
  with open("categories.json") as fp:
16
  categories = json.load(fp)
@@ -79,10 +147,33 @@ def inference(image_path, model_name, bbox_threshold):
79
  )
80
 
81
  ort_inputs = {ort_session.get_inputs()[0].name: tensor_img.cpu().numpy()}
82
- ort_outs = ort_session.run(None, ort_inputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- boxes, labels, scores = ort_outs
85
- return draw_predictions(boxes, labels, scores, torch.from_numpy(np.array(img)), score_threshold=bbox_threshold)
86
 
87
 
88
 
 
9
  from PIL import Image, ImageColor
10
  from torchvision.utils import draw_bounding_boxes
11
  import rfdetr.datasets.transforms as T
12
+ from torchvision.ops import box_convert
13
 
14
+ def _box_yxyx_to_xyxy(boxes: torch.Tensor) -> torch.Tensor:
15
+ """Convert bounding boxes from (y1, x1, y2, x2) format to (x1, y1, x2, y2) format.
16
 
17
+ Args:
18
+ boxes (torch.Tensor): A tensor of bounding boxes in the (y1, x1, y2, x2) format.
19
+
20
+ Returns:
21
+ torch.Tensor: A tensor of bounding boxes in the (x1, y1, x2, y2) format.
22
+ """
23
+ y1, x1, y2, x2 = boxes.unbind(-1)
24
+ boxes = torch.stack((x1, y1, x2, y2), dim=-1)
25
+ return boxes
26
+
27
+
28
+ def _box_xyxy_to_yxyx(boxes: torch.Tensor) -> torch.Tensor:
29
+ """Convert bounding boxes from (x1, y1, x2, y2) format to (y1, x1, y2, x2) format.
30
+
31
+ Args:
32
+ boxes (torch.Tensor): A tensor of bounding boxes in the (x1, y1, x2, y2) format.
33
+
34
+ Returns:
35
+ torch.Tensor: A tensor of bounding boxes in the (y1, x1, y2, x2) format.
36
+ """
37
+ x1, y1, x2, y2 = boxes.unbind(-1)
38
+ boxes = torch.stack((y1, x1, y2, x2), dim=-1)
39
+ return boxes
40
+
41
+
42
# Adapted from: https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py#L168
def extended_box_convert(
    boxes: torch.Tensor, in_fmt: str, out_fmt: str
) -> torch.Tensor:
    """Convert boxes from ``in_fmt`` to ``out_fmt``.

    Supported in_fmt and out_fmt are:
    - 'xyxy': corners, (x1, y1) top left and (x2, y2) bottom right. This is the
      format that torchvision utilities expect.
    - 'xywh': top-left corner plus width and height.
    - 'cxcywh': box center plus width and height.
    - 'yxyx': corners, (y1, x1) top left and (y2, x2) bottom right. This is the
      format that the `amrcnn` model outputs.

    Args:
        boxes (Tensor[N, 4]): boxes which will be converted.
        in_fmt (str): Input format. One of ['xyxy', 'xywh', 'cxcywh', 'yxyx'].
        out_fmt (str): Output format. One of ['xyxy', 'xywh', 'cxcywh', 'yxyx'].

    Returns:
        Tensor[N, 4]: Boxes in the converted format.
    """
    if in_fmt == "yxyx":
        # Normalize the non-torchvision format to xyxy up front so the
        # remaining logic only deals with formats box_convert understands.
        y1, x1, y2, x2 = boxes.unbind(-1)
        boxes = torch.stack((x1, y1, x2, y2), dim=-1)
        in_fmt = "xyxy"

    if out_fmt != "yxyx":
        # Every remaining pairing is natively supported by torchvision.
        return box_convert(boxes, in_fmt=in_fmt, out_fmt=out_fmt)

    # Producing yxyx: route through xyxy first if needed, then swap pairs.
    if in_fmt != "xyxy":
        boxes = box_convert(boxes, in_fmt=in_fmt, out_fmt="xyxy")
    x1, y1, x2, y2 = boxes.unbind(-1)
    return torch.stack((y1, x1, y2, x2), dim=-1)
80
+
81
+
82
  def process_categories() -> tuple:
83
  with open("categories.json") as fp:
84
  categories = json.load(fp)
 
147
  )
148
 
149
  ort_inputs = {ort_session.get_inputs()[0].name: tensor_img.cpu().numpy()}
150
+ pred_boxes, logits = ort_session.run(['dets', 'labels'], ort_inputs)
151
+
152
+ scores = torch.sigmoid(torch.from_numpy(logits))
153
+ max_scores, pred_labels = scores.max(-1)
154
+ mask = max_scores > bbox_threshold
155
+
156
+ pred_boxes = torch.from_numpy(pred_boxes[0])
157
+ image_w, image_h = img.size
158
+
159
+ pred_boxes_abs = pred_boxes.clone()
160
+ pred_boxes_abs[:, 0] *= image_w
161
+ pred_boxes_abs[:, 1] *= image_h
162
+ pred_boxes_abs[:, 2] *= image_w
163
+ pred_boxes_abs[:, 3] *= image_h
164
+
165
+ mask = mask.squeeze(0)
166
+
167
+ filtered_boxes = extended_box_convert(
168
+ pred_boxes_abs[mask], in_fmt="cxcywh", out_fmt="xyxy"
169
+ )
170
+ filtered_scores = max_scores.squeeze(0)[mask]
171
+ filtered_labels = pred_labels.squeeze(0)[mask]
172
+
173
+ img_tensor = torch.from_numpy(np.array(img)).permute(2, 0, 1)
174
+
175
+ return draw_predictions(filtered_boxes, filtered_labels, filtered_scores, img_tensor, score_threshold=bbox_threshold)
176
 
 
 
177
 
178
 
179