qqc1989 commited on
Commit
a0abbe9
·
verified ·
1 Parent(s): 2b42e3c

Upload yolo11_axera.py

Browse files
Files changed (1) hide show
  1. yolo11_axera.py +216 -0
yolo11_axera.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import axengine as axe
2
+ import numpy as np
3
+ import cv2
4
+ import argparse
5
+ from dataclasses import dataclass
6
+
7
+ # COCO Class Names
8
+ COCO_CLASSES = [
9
+ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
10
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
11
+ 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
12
+ 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
13
+ 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
14
+ 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
15
+ 'kite', 'baseball bat', 'baseball glove', 'skateboard',
16
+ 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
17
+ 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
18
+ 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
19
+ 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
20
+ 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
21
+ 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
22
+ 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
23
+ 'teddy bear', 'hair drier', 'toothbrush'
24
+ ]
25
+
26
+ @dataclass
27
+ class Object:
28
+ bbox: list # [x0, y0, width, height]
29
+ label: int
30
+ prob: float
31
+
32
+ def sigmoid(x):
33
+ return 1 / (1 + np.exp(-x))
34
+
35
+ def softmax(x, axis=-1):
36
+ x = x - np.max(x, axis=axis, keepdims=True)
37
+ e_x = np.exp(x)
38
+ return e_x / np.sum(e_x, axis=axis, keepdims=True)
39
+
40
+ def decode_distributions(feat, reg_max=16):
41
+ prob = softmax(feat, axis=-1)
42
+ dis = np.sum(prob * np.arange(reg_max), axis=-1)
43
+ return dis
44
+
45
+ def preprocess(image_path, input_size):
46
+ image = cv2.imread(image_path)
47
+ if image is None:
48
+ raise FileNotFoundError(f"Unable to read image file: {image_path}")
49
+ original_shape = image.shape[:2]
50
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
51
+ resized_image = cv2.resize(image, input_size)
52
+ input_tensor = np.expand_dims(resized_image, axis=0).astype(np.uint8)
53
+ return input_tensor, original_shape, image
54
+
55
+ def postprocess(outputs, original_shape, input_size, confidence_threshold, nms_threshold, reg_max=16):
56
+ heads = [
57
+ {'output': outputs[0], 'grid_size': input_size[0] // 8, 'stride': 8},
58
+ {'output': outputs[1], 'grid_size': input_size[0] // 16, 'stride': 16},
59
+ {'output': outputs[2], 'grid_size': input_size[0] // 32, 'stride': 32}
60
+ ]
61
+ detections = []
62
+ num_classes = 80
63
+ bbox_channels = 4 * reg_max
64
+ class_channels = num_classes
65
+
66
+ for head in heads:
67
+ output = head['output']
68
+ batch_size, grid_h, grid_w, channels = output.shape
69
+ stride = head['stride']
70
+
71
+ bbox_part = output[:, :, :, :bbox_channels]
72
+ class_part = output[:, :, :, bbox_channels:]
73
+
74
+ bbox_part = bbox_part.reshape(batch_size, grid_h, grid_w, 4, reg_max)
75
+ bbox_part = bbox_part.reshape(grid_h * grid_w, 4, reg_max)
76
+ class_part = class_part.reshape(batch_size, grid_h * grid_w, class_channels)
77
+
78
+ for b in range(batch_size):
79
+ for i in range(grid_h * grid_w):
80
+ h = i // grid_w
81
+ w = i % grid_w
82
+ class_scores = class_part[b, i, :]
83
+ class_id = np.argmax(class_scores)
84
+ class_score = class_scores[class_id]
85
+ box_prob = sigmoid(class_score)
86
+ if box_prob < confidence_threshold:
87
+ continue
88
+ bbox = bbox_part[i, :, :]
89
+ dis_left = decode_distributions(bbox[0, :], reg_max)
90
+ dis_top = decode_distributions(bbox[1, :], reg_max)
91
+ dis_right = decode_distributions(bbox[2, :], reg_max)
92
+ dis_bottom = decode_distributions(bbox[3, :], reg_max)
93
+ pb_cx = (w + 0.5) * stride
94
+ pb_cy = (h + 0.5) * stride
95
+ x0 = pb_cx - dis_left * stride
96
+ y0 = pb_cy - dis_top * stride
97
+ x1 = pb_cx + dis_right * stride
98
+ y1 = pb_cy + dis_bottom * stride
99
+ scale_x = original_shape[1] / input_size[0]
100
+ scale_y = original_shape[0] / input_size[1]
101
+ x0 = np.clip(x0 * scale_x, 0, original_shape[1] - 1)
102
+ y0 = np.clip(y0 * scale_y, 0, original_shape[0] - 1)
103
+ x1 = np.clip(x1 * scale_x, 0, original_shape[1] - 1)
104
+ y1 = np.clip(y1 * scale_y, 0, original_shape[0] - 1)
105
+ width = x1 - x0
106
+ height = y1 - y0
107
+ detections.append(Object(
108
+ bbox=[float(x0), float(y0), float(width), float(height)],
109
+ label=int(class_id),
110
+ prob=float(box_prob)
111
+ ))
112
+
113
+ if len(detections) == 0:
114
+ return []
115
+ boxes = np.array([d.bbox for d in detections])
116
+ scores = np.array([d.prob for d in detections])
117
+ class_ids = np.array([d.label for d in detections])
118
+
119
+ final_detections = []
120
+ unique_classes = np.unique(class_ids)
121
+ for cls in unique_classes:
122
+ idxs = np.where(class_ids == cls)[0]
123
+ cls_boxes = boxes[idxs]
124
+ cls_scores = scores[idxs]
125
+ x1_cls = cls_boxes[:, 0]
126
+ y1_cls = cls_boxes[:, 1]
127
+ x2_cls = cls_boxes[:, 0] + cls_boxes[:, 2]
128
+ y2_cls = cls_boxes[:, 1] + cls_boxes[:, 3]
129
+ areas = (x2_cls - x1_cls) * (y2_cls - y1_cls)
130
+ order = cls_scores.argsort()[::-1]
131
+ keep = []
132
+ while order.size > 0:
133
+ i = order[0]
134
+ keep.append(i)
135
+ if order.size == 1:
136
+ break
137
+ xx1 = np.maximum(x1_cls[i], x1_cls[order[1:]])
138
+ yy1 = np.maximum(y1_cls[i], y1_cls[order[1:]])
139
+ xx2 = np.minimum(x2_cls[i], x2_cls[order[1:]])
140
+ yy2 = np.minimum(y2_cls[i], y2_cls[order[1:]])
141
+ w = np.maximum(0, xx2 - xx1)
142
+ h = np.maximum(0, yy2 - yy1)
143
+ intersection = w * h
144
+ iou = intersection / (areas[i] + areas[order[1:]] - intersection)
145
+ inds = np.where(iou <= nms_threshold)[0]
146
+ order = order[inds + 1]
147
+ for idx in keep:
148
+ final_detections.append(Object(
149
+ bbox=cls_boxes[idx].tolist(),
150
+ label=int(cls),
151
+ prob=float(cls_scores[idx])
152
+ ))
153
+ return final_detections
154
+
155
+ def main():
156
+ parser = argparse.ArgumentParser(description="YOLO11 AXEngine Inference")
157
+ parser.add_argument('--model', type=str, default='yolo11x.axmodel', help='Model path')
158
+ parser.add_argument('--image', type=str, default='dog.jpg', help='Image path')
159
+ parser.add_argument('--conf', type=float, default=0.45, help='Confidence threshold')
160
+ parser.add_argument('--nms', type=float, default=0.45, help='NMS threshold')
161
+ parser.add_argument('--size', type=int, nargs=2, default=[640, 640], help='Input size W H')
162
+ parser.add_argument('--regmax', type=int, default=16, help='DFL reg_max value')
163
+ args = parser.parse_args()
164
+
165
+ try:
166
+ input_tensor, original_shape, original_image = preprocess(args.image, tuple(args.size))
167
+ except FileNotFoundError as e:
168
+ print(e)
169
+ return
170
+
171
+ try:
172
+ session = axe.InferenceSession(args.model)
173
+ except Exception as e:
174
+ print(f"Error loading model: {e}")
175
+ return
176
+
177
+ input_name = session.get_inputs()[0].name
178
+ output_names = [output.name for output in session.get_outputs()]
179
+
180
+ try:
181
+ outputs = session.run(output_names, {input_name: input_tensor})
182
+ except Exception as e:
183
+ print(f"Error during inference: {e}")
184
+ return
185
+
186
+ try:
187
+ detections = postprocess(
188
+ outputs,
189
+ original_shape,
190
+ tuple(args.size),
191
+ args.conf,
192
+ args.nms,
193
+ reg_max=args.regmax
194
+ )
195
+ except Exception as e:
196
+ print(f"Error during post-processing: {e}")
197
+ return
198
+
199
+ for det in detections:
200
+ bbox = det.bbox
201
+ score = det.prob
202
+ class_id = det.label
203
+ if class_id >= len(COCO_CLASSES):
204
+ label = f"cls{class_id}: {score:.2f}"
205
+ else:
206
+ label = f"{COCO_CLASSES[class_id]}: {score:.2f}"
207
+ x, y, w, h = map(int, bbox)
208
+ cv2.rectangle(original_image, (x, y), (x + w, y + h), (0, 255, 0), 2)
209
+ cv2.putText(original_image, label, (x, y - 10),
210
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
211
+
212
+ cv2.imwrite('detections.png', cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR))
213
+ print("结果已保存到 detections.png")
214
+
215
+ if __name__ == '__main__':
216
+ main()