jordan0811 committed (verified)
Commit 1a45478 · Parent(s): 9bd8577

Upload axmodel_inf.py with huggingface_hub

Files changed (1)
  1. axmodel_inf.py +257 -0
axmodel_inf.py ADDED
"""
DEIMv2: Real-Time Object Detection Meets DINOv3
Copyright (c) 2025 The DEIMv2 Authors. All Rights Reserved.
---------------------------------------------------------------------------------
Modified from D-FINE (https://github.com/Peterande/D-FINE)
Copyright (c) 2024 The D-FINE Authors. All Rights Reserved.
"""


import cv2
import numpy as np
import axengine as ort  # axengine exposes an onnxruntime-style InferenceSession API
import torch
import torchvision
import torchvision.transforms as T
from PIL import Image, ImageDraw
from torch import nn
import torch.nn.functional as F


def mod(a, b):
    # Integer modulo written without the `%` operator, which some export and
    # compilation toolchains do not support.
    out = a - a // b * b
    return out
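
# For example, mod(torch.tensor([5, 7]), 3) yields tensor([2, 1]), matching
# Python's `%` for non-negative operands.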


class PostProcessor(nn.Module):
    __share__ = [
        'num_classes',
        'use_focal_loss',
        'num_top_queries',
        'remap_mscoco_category'
    ]

    def __init__(
        self,
        num_classes=80,
        use_focal_loss=True,
        num_top_queries=300,
        remap_mscoco_category=False
    ) -> None:
        super().__init__()
        self.use_focal_loss = use_focal_loss
        self.num_top_queries = num_top_queries
        self.num_classes = int(num_classes)
        self.remap_mscoco_category = remap_mscoco_category
        self.deploy_mode = False

    def extra_repr(self) -> str:
        return f'use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}'

    def forward(self, outputs, orig_target_sizes: torch.Tensor):
        logits, boxes = outputs['pred_logits'], outputs['pred_boxes']

        # Convert normalized cxcywh boxes to absolute xyxy pixel coordinates.
        bbox_pred = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy')
        bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1)

        if self.use_focal_loss:
            scores = torch.sigmoid(logits)
            scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1)
            labels = mod(index, self.num_classes)  # equivalent to index % self.num_classes
            index = index // self.num_classes
            boxes = bbox_pred.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1]))
        else:
            scores = F.softmax(logits, dim=-1)[:, :, :-1]
            scores, labels = scores.max(dim=-1)
            boxes = bbox_pred
            if scores.shape[1] > self.num_top_queries:
                scores, index = torch.topk(scores, self.num_top_queries, dim=-1)
                labels = torch.gather(labels, dim=1, index=index)
                boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))

        if self.deploy_mode:
            return labels, boxes, scores

        if self.remap_mscoco_category:
            from ..data.dataset import mscoco_label2category
            labels = torch.tensor([mscoco_label2category[int(x.item())] for x in labels.flatten()]) \
                .to(boxes.device).reshape(labels.shape)

        results = []
        for lab, box, sco in zip(labels, boxes, scores):
            result = dict(labels=lab, boxes=box, scores=sco)
            results.append(result)

        return results

    def deploy(self):
        self.eval()
        self.deploy_mode = True
        return self
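
# A minimal sketch of the post-processor on dummy tensors; the 300-query /
# 80-class shapes below are assumptions for illustration, not read from the
# compiled model:
#
#   pp = PostProcessor().deploy()
#   preds = {"pred_logits": torch.rand(1, 300, 80),
#            "pred_boxes": torch.rand(1, 300, 4)}
#   labels, boxes, scores = pp(preds, torch.tensor([[640, 640]]))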


def resize_with_aspect_ratio(image, size, interpolation=Image.BILINEAR):
    """Resizes an image while maintaining aspect ratio and pads it to a square."""
    original_width, original_height = image.size
    ratio = min(size / original_width, size / original_height)
    new_width = int(original_width * ratio)
    new_height = int(original_height * ratio)
    image = image.resize((new_width, new_height), interpolation)

    # Create a new square canvas and paste the resized image centered on it
    new_image = Image.new("RGB", (size, size))
    new_image.paste(image, ((size - new_width) // 2, (size - new_height) // 2))
    return new_image, ratio, (size - new_width) // 2, (size - new_height) // 2
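
# Worked example with illustrative numbers: a 1280x720 input at size=640 gives
# ratio = min(640/1280, 640/720) = 0.5, a 640x360 resize, and centering
# offsets (pad_w, pad_h) = (0, 140) on the 640x640 canvas.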


def draw(images, labels, boxes, scores, ratios, paddings, thrh=0.4):
    result_images = []
    for i, im in enumerate(images):
        drawer = ImageDraw.Draw(im)
        scr = scores[i]
        lab = labels[i][scr > thrh]
        box = boxes[i][scr > thrh]
        scr = scr[scr > thrh]

        ratio = ratios[i]
        pad_w, pad_h = paddings[i]

        for lbl, bb in zip(lab, box):
            # Undo the letterbox transform: remove the padding offset, then
            # rescale back to the original image resolution.
            bb = [
                (bb[0] - pad_w) / ratio,
                (bb[1] - pad_h) / ratio,
                (bb[2] - pad_w) / ratio,
                (bb[3] - pad_h) / ratio,
            ]
            drawer.rectangle(bb, outline='red')
            drawer.text((bb[0], bb[1]), text=str(lbl), fill='blue')

        result_images.append(im)
    return result_images
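
# Continuing the illustrative 1280x720 example: a padded-canvas point at
# (320, 240) maps back to ((320 - 0) / 0.5, (240 - 140) / 0.5) = (640, 200)
# in the original frame.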


def process_image(sess, im_pil, size=640, model_size='s'):
    post_processor = PostProcessor().deploy()

    # Resize the image while preserving aspect ratio
    resized_im_pil, ratio, pad_w, pad_h = resize_with_aspect_ratio(im_pil, size)
    # The padded input is square, so the (h, w) order here is interchangeable.
    orig_size = torch.tensor([[resized_im_pil.size[1], resized_im_pil.size[0]]])

    # The smallest model variants skip ImageNet normalization.
    transforms = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        if model_size not in ['atto', 'femto', 'pico', 'n']
        else T.Lambda(lambda x: x)
    ])
    im_data = transforms(resized_im_pil).unsqueeze(0)

    output = sess.run(
        output_names=None,
        input_feed={'images': im_data.numpy()}
    )

    # The compiled model is assumed to return (pred_logits, pred_boxes) in this order.
    output = {"pred_logits": torch.from_numpy(output[0]), "pred_boxes": torch.from_numpy(output[1])}
    output = post_processor(output, orig_size)
    labels, boxes, scores = output[0].numpy(), output[1].numpy(), output[2].numpy()

    result_images = draw(
        [im_pil], labels, boxes, scores,
        [ratio], [(pad_w, pad_h)]
    )
    result_images[0].save('result.jpg')
    print("Image processing complete. Result saved as 'result.jpg'.")


def process_video(sess, video_path, size=640, model_size='s'):
    post_processor = PostProcessor().deploy()
    cap = cv2.VideoCapture(video_path)

    # Get video properties
    fps = cap.get(cv2.CAP_PROP_FPS)
    orig_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    orig_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Define the codec and create a VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter('onnx_result.mp4', fourcc, fps, (orig_w, orig_h))

    frame_count = 0
    print("Processing video frames...")
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        # Convert the BGR OpenCV frame to an RGB PIL image
        frame_pil = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        # Resize the frame while preserving aspect ratio
        resized_frame_pil, ratio, pad_w, pad_h = resize_with_aspect_ratio(frame_pil, size)
        orig_size = torch.tensor([[resized_frame_pil.size[1], resized_frame_pil.size[0]]])

        transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            if model_size not in ['atto', 'femto', 'pico', 'n']
            else T.Lambda(lambda x: x)
        ])
        im_data = transforms(resized_frame_pil).unsqueeze(0)

        # Matching process_image: the compiled model takes only `images` and
        # returns raw (pred_logits, pred_boxes); post-processing runs on the host.
        output = sess.run(
            output_names=None,
            input_feed={'images': im_data.numpy()}
        )
        output = {"pred_logits": torch.from_numpy(output[0]), "pred_boxes": torch.from_numpy(output[1])}
        output = post_processor(output, orig_size)
        labels, boxes, scores = output[0].numpy(), output[1].numpy(), output[2].numpy()

        # Draw detections on the original frame
        result_images = draw(
            [frame_pil], labels, boxes, scores,
            [ratio], [(pad_w, pad_h)]
        )
        frame_with_detections = result_images[0]

        # Convert back to an OpenCV image
        frame = cv2.cvtColor(np.array(frame_with_detections), cv2.COLOR_RGB2BGR)

        # Write the frame
        out.write(frame)
        frame_count += 1

        if frame_count % 10 == 0:
            print(f"Processed {frame_count} frames...")

    cap.release()
    out.release()
    print("Video processing complete. Result saved as 'onnx_result.mp4'.")


def main(args):
    """Main function."""
    # Load the compiled axmodel
    sess = ort.InferenceSession(args.axmodel)
    # Infer the input resolution from the NCHW input shape
    size = sess.get_inputs()[0].shape[2]

    input_path = args.input

    try:
        # Try to open the input as an image
        im_pil = Image.open(input_path).convert('RGB')
        process_image(sess, im_pil, size, args.model_size)
    except IOError:
        # Not an image, process as video
        process_video(sess, input_path, size, args.model_size)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--axmodel', type=str, default="compiled.axmodel", help='Path to the axmodel model file.')
    parser.add_argument('--input', type=str, required=True, help='Path to the input image or video file.')
    parser.add_argument('-ms', '--model-size', type=str, required=True,
                        choices=['atto', 'femto', 'pico', 'n', 's', 'm', 'l', 'x'],
                        help='Model size.')
    args = parser.parse_args()
    main(args)
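
# Example invocations (file names are illustrative):
#   python axmodel_inf.py --axmodel compiled.axmodel --input demo.jpg -ms s
#   python axmodel_inf.py --axmodel compiled.axmodel --input demo.mp4 -ms s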