fangmingguo committed
Commit 224aed4 · verified · 1 Parent(s): eb94ef6

Upload inference_axmodel.py

Files changed (1)
  1. inference_axmodel.py +1002 -0
inference_axmodel.py ADDED
@@ -0,0 +1,1002 @@
+ #!/usr/bin/env python3
+ import argparse
+ import json
+ import os
+ import os.path as osp
+ from collections import defaultdict
+
+ import cv2
+ import numpy as np
+ import axengine as axe
+ from tqdm import tqdm
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description='BEVFormer AXEngine Inference from Extracted Data')
+     parser.add_argument('model', help='AXModel path')
+     parser.add_argument('config_json', help='JSON config file path')
+     parser.add_argument('data_dir', help='extracted data directory (extracted_data)')
+     parser.add_argument('--output-dir', default='./inference_results_extracted', help='output directory')
+     parser.add_argument('--score-thr', type=float, default=0.1, help='score threshold')
+     parser.add_argument('--fps', type=int, default=3, help='video fps')
+     parser.add_argument('--start-scene', type=int, default=0, help='start scene index')
+     parser.add_argument('--end-scene', type=int, default=None, help='end scene index (None for all)')
+     return parser.parse_args()
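+
+ # Example invocation (model and config file names are illustrative):
+ #   python3 inference_axmodel.py bevformer.axmodel config.json extracted_data \
+ #       --output-dir ./inference_results_extracted --score-thr 0.3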
+
+
+ def load_axmodel(axmodel_path):
+     """Load AXModel"""
+     # Try AxEngineExecutionProvider instead of AXCLRTExecutionProvider
+     providers = ['AxEngineExecutionProvider']
+     session = axe.InferenceSession(axmodel_path, providers=providers)
+     return session
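+
+ # Note: axengine mirrors the onnxruntime InferenceSession API. To our knowledge,
+ # AxEngineExecutionProvider runs on the on-board NPU, while
+ # AXCLRTExecutionProvider targets AXCL (PCIe accelerator card) deployments.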
+
+
+ def load_config_from_json(config_path):
+     """Load configuration from JSON file"""
+     with open(config_path, 'r') as f:
+         config = json.load(f)
+     return config
+
+
+ def preprocess_image(img_path, img_norm_cfg, target_size=(480, 800)):
+     """Preprocess image: load, resize, normalize
+
+     Args:
+         img_path: path to image file
+         img_norm_cfg: normalization config with 'mean', 'std', 'to_rgb'
+         target_size: (H, W) target size
+
+     Returns:
+         img: (C, H, W) normalized numpy array, float32
+     """
+     # Load image
+     img = cv2.imread(img_path)
+     if img is None:
+         raise ValueError(f"Cannot load image: {img_path}")
+
+     # Convert BGR to RGB if needed
+     if img_norm_cfg.get('to_rgb', True):
+         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+     # Resize if needed
+     if img.shape[:2] != target_size:
+         img = cv2.resize(img, (target_size[1], target_size[0]))  # (W, H)
+
+     # Convert to float and normalize
+     img = img.astype(np.float32)
+     mean = np.array(img_norm_cfg.get('mean', [123.675, 116.28, 103.53]), dtype=np.float32)
+     std = np.array(img_norm_cfg.get('std', [58.395, 57.12, 57.375]), dtype=np.float32)
+
+     img = (img - mean) / std
+     img = img.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)
+
+     return img
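+
+ # The default mean/std are the ImageNet statistics scaled to the 0-255 range
+ # ([0.485, 0.456, 0.406] * 255 and [0.229, 0.224, 0.225] * 255), i.e. the usual
+ # BEVFormer img_norm_cfg.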
+
+
+ def load_data(data_dir, scene_name, frame_idx):
+     """Load data
+
+     Args:
+         data_dir: data directory path
+         scene_name: scene name (scene token)
+         frame_idx: frame index (sample index)
+
+     Returns:
+         img: (1, N, C, H, W) numpy array
+         lidar2img: (1, N, 4, 4) numpy array
+         can_bus: (1, 18) numpy array
+         meta: dict with metadata
+     """
+     scene_dir = osp.join(data_dir, scene_name)
+
+     # Load meta
+     meta_path = osp.join(scene_dir, f'meta_{frame_idx:06d}.json')
+     with open(meta_path, 'r') as f:
+         meta = json.load(f)
+
+     # Get normalization config
+     img_norm_cfg = meta.get('img_norm_cfg', {
+         'mean': [123.675, 116.28, 103.53],
+         'std': [58.395, 57.12, 57.375],
+         'to_rgb': True
+     })
+
+     # Get image shape
+     img_shape = meta.get('img_shape', [[480, 800, 3]] * 6)
+     target_size = (img_shape[0][0], img_shape[0][1])  # (H, W)
+
+     # Load images for all cameras
+     num_cams = meta.get('num_cams', 6)
+     imgs = []
+     for cam_idx in range(num_cams):
+         img_path = osp.join(scene_dir, f'cam_{cam_idx:02d}_{frame_idx:06d}.png')
+         img = preprocess_image(img_path, img_norm_cfg, target_size)
+         imgs.append(img)
+
+     # Stack images: (N, C, H, W) -> (1, N, C, H, W)
+     img = np.stack(imgs, axis=0)  # (N, C, H, W)
+     img = img[np.newaxis, ...]  # (1, N, C, H, W)
+
+     # Load lidar2img: (N, 4, 4) -> (1, N, 4, 4)
+     lidar2img = np.array(meta['lidar2img'], dtype=np.float32)  # (N, 4, 4)
+     lidar2img = lidar2img[np.newaxis, ...]  # (1, N, 4, 4)
+
+     # Load can_bus: (18,) -> (1, 18)
+     can_bus = np.array(meta['can_bus'], dtype=np.float32)  # (18,)
+     can_bus = can_bus[np.newaxis, ...]  # (1, 18)
+
+     return img, lidar2img, can_bus, meta
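+
+ # Expected on-disk layout per scene, inferred from the paths above:
+ #   <data_dir>/<scene_name>/meta_000000.json
+ #   <data_dir>/<scene_name>/cam_00_000000.png ... cam_05_000000.png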
+
+ CLASS_COLORS = {
+     0: (0, 255, 0), 1: (255, 255, 0), 2: (0, 0, 255), 3: (0, 165, 255),
+     4: (255, 0, 255), 5: (0, 255, 255), 6: (128, 0, 128), 7: (255, 165, 0),
+     8: (0, 0, 255), 9: (128, 128, 128),
+ }
+
+
+ def denormalize_bbox_np(normalized_bboxes, pc_range):
+     """Denormalize bbox using numpy only"""
+     # rotation
+     rot_sine = normalized_bboxes[..., 6:7]
+     rot_cosine = normalized_bboxes[..., 7:8]
+
+     rot = np.arctan2(rot_sine, rot_cosine)
+
+     # center in the bev
+     cx = normalized_bboxes[..., 0:1]
+     cy = normalized_bboxes[..., 1:2]
+     cz = normalized_bboxes[..., 4:5]
+
+     # size
+     w = normalized_bboxes[..., 2:3]
+     l = normalized_bboxes[..., 3:4]
+     h = normalized_bboxes[..., 5:6]
+
+     w = np.exp(w)
+     l = np.exp(l)
+     h = np.exp(h)
+
+     if normalized_bboxes.shape[-1] > 8:
+         # velocity
+         vx = normalized_bboxes[:, 8:9]
+         vy = normalized_bboxes[:, 9:10]
+         denormalized_bboxes = np.concatenate([cx, cy, cz, w, l, h, rot, vx, vy], axis=-1)
+     else:
+         denormalized_bboxes = np.concatenate([cx, cy, cz, w, l, h, rot], axis=-1)
+     return denormalized_bboxes
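+
+ # Box encoding assumed here (the DETR3D/BEVFormer convention): the head emits
+ # (cx, cy, log w, log l, cz, log h, sin yaw, cos yaw[, vx, vy]); decoding applies
+ # exp() to the sizes and atan2(sin, cos) to recover yaw.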
+
+
+ def decode_bboxes_custom_np(all_cls_scores, all_bbox_preds, pc_range, post_center_range,
+                             max_num=100, score_threshold=None, num_classes=10):
+     """Custom bbox decode function"""
+     # Use output from the last decoder layer
+     all_cls_scores = all_cls_scores[-1]  # (bs, num_query, num_classes)
+     all_bbox_preds = all_bbox_preds[-1]  # (bs, num_query, 10)
+
+     batch_size = all_cls_scores.shape[0]
+     predictions_list = []
+
+     for i in range(batch_size):
+         cls_scores = all_cls_scores[i]  # (num_query, num_classes)
+         bbox_preds = all_bbox_preds[i]  # (num_query, 10)
+
+         # Apply sigmoid
+         cls_scores = 1.0 / (1.0 + np.exp(-cls_scores))
+
+         # TopK selection
+         cls_scores_flat = cls_scores.reshape(-1)
+         topk_indices = np.argsort(cls_scores_flat)[::-1][:max_num]
+         scores = cls_scores_flat[topk_indices]
+         labels = topk_indices % num_classes
+         bbox_index = topk_indices // num_classes
+         bbox_preds = bbox_preds[bbox_index]
+
+         # Denormalize bbox
+         final_box_preds = denormalize_bbox_np(bbox_preds, pc_range)  # (max_num, 9)
+         final_scores = scores
+         final_preds = labels
+
+         # Apply score threshold
+         if score_threshold is not None:
+             thresh_mask = final_scores > score_threshold
+             tmp_score = score_threshold
+             while thresh_mask.sum() == 0:
+                 tmp_score *= 0.9
+                 if tmp_score < 0.01:
+                     thresh_mask = np.ones(len(final_scores), dtype=bool)
+                     break
+                 thresh_mask = final_scores >= tmp_score
+         else:
+             thresh_mask = np.ones(len(final_scores), dtype=bool)
+
+         # Apply post processing range filtering
+         if post_center_range is not None:
+             post_center_range_arr = np.array(post_center_range)
+             mask = (final_box_preds[..., :3] >= post_center_range_arr[:3]).all(1)
+             mask &= (final_box_preds[..., :3] <= post_center_range_arr[3:]).all(1)
+             mask &= thresh_mask
+
+             boxes3d = final_box_preds[mask]
+             scores = final_scores[mask]
+             labels = final_preds[mask]
+         else:
+             boxes3d = final_box_preds[thresh_mask]
+             scores = final_scores[thresh_mask]
+             labels = final_preds[thresh_mask]
+
+         predictions_list.append({
+             'bboxes': boxes3d,
+             'scores': scores,
+             'labels': labels
+         })
+
+     return predictions_list
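+
+ # The top-k runs over the flattened (num_query * num_classes) score matrix, so a
+ # single flat index encodes both query and class: with num_classes=10, flat
+ # index 57 means query 57 // 10 = 5 and label 57 % 10 = 7.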
+
+
+ def get_bboxes_custom_np(preds_dicts, pc_range, post_center_range, max_num=100,
+                          score_threshold=None, num_classes=10):
+     """Custom get_bboxes function"""
+     # Decode bounding boxes
+     preds_list = decode_bboxes_custom_np(
+         preds_dicts['all_cls_scores'],
+         preds_dicts['all_bbox_preds'],
+         pc_range,
+         post_center_range,
+         max_num,
+         score_threshold,
+         num_classes
+     )
+
+     num_samples = len(preds_list)
+     ret_list = []
+
+     for i in range(num_samples):
+         preds = preds_list[i]
+         bboxes = preds['bboxes']
+
+         if len(bboxes) == 0:
+             ret_list.append((
+                 np.zeros((0, 9), dtype=np.float32),
+                 np.zeros((0,), dtype=np.float32),
+                 np.zeros((0,), dtype=np.int64)
+             ))
+             continue
+
+         # Adjust z coordinate: convert center z to bottom center z
+         bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
+
+         # Shrink box dimensions: multiply w, l, h by 0.9 to fix oversized boxes
+         bboxes[:, 3:6] = bboxes[:, 3:6] * 0.9
+
+         scores = preds['scores']
+         labels = preds['labels']
+
+         ret_list.append((bboxes, scores, labels))
+
+     return ret_list
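+
+ # The z shift converts the predicted gravity-center z to a bottom-center z (the
+ # mmdet3d LiDAR box convention); the 0.9 shrink is an empirical visual tweak by
+ # the author, not part of the standard BEVFormer decoder.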
+
+
+ def format_bbox_result_np(bboxes, scores, labels):
+     return {
+         'boxes_3d': bboxes,
+         'scores_3d': scores,
+         'labels_3d': labels
+     }
+
+
+ def rotation_3d_in_axis_np(points, angles, axis=2):
+     """Rotate points by angles according to axis"""
+     rot_sin = np.sin(angles)
+     rot_cos = np.cos(angles)
+     ones = np.ones_like(rot_cos)
+     zeros = np.zeros_like(rot_cos)
+
+     if axis == 2 or axis == -1:
+         # Rotate around z-axis
+         # Build rotation matrix: (N, 3, 3)
+         N = len(angles)
+         rot_mat = np.zeros((N, 3, 3), dtype=points.dtype)
+         rot_mat[:, 0, 0] = rot_cos
+         rot_mat[:, 0, 1] = -rot_sin
+         rot_mat[:, 0, 2] = zeros
+         rot_mat[:, 1, 0] = rot_sin
+         rot_mat[:, 1, 1] = rot_cos
+         rot_mat[:, 1, 2] = zeros
+         rot_mat[:, 2, 0] = zeros
+         rot_mat[:, 2, 1] = zeros
+         rot_mat[:, 2, 2] = ones
+
+         # Rotation: (N, M, 3) @ (N, 3, 3) -> (N, M, 3)
+         return np.einsum('aij,ajk->aik', points, rot_mat)
+     else:
+         raise ValueError('Only axis=2 (z-axis) is supported for LiDAR boxes')
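+
+ # The einsum is a batched matrix multiply: for each box a it computes
+ # points[a] @ rot_mat[a], rotating that box's corners about the z-axis.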
+
+
+ def compute_bbox_corners_np(bboxes):
+     """Compute 8 corners of 3D bbox"""
+     if len(bboxes) == 0:
+         return np.zeros((0, 8, 3), dtype=np.float32)
+
+     dtype = bboxes.dtype
+
+     # Extract bbox parameters
+     centers = bboxes[:, :3]  # (N, 3) [x, y, z] - the bottom center
+     w = bboxes[:, 3:4]  # width (y direction)
+     l = bboxes[:, 4:5]  # length (x direction)
+     h = bboxes[:, 5:6]  # height (z direction)
+     dims = np.concatenate([l, w, h], axis=1)  # (N, 3) [x_size, y_size, z_size] = [l, w, h]
+     yaws = bboxes[:, 6]  # (N,) yaw angle
+
+     # Fix: offset yaw by -80 degrees
+     yaws = yaws - (np.pi / 2.0 - np.pi / 18.0)
+
+     # Generate corners
+     corners_norm = np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1).astype(dtype)
+
+     # Rearrange to [0, 1, 3, 2, 4, 5, 7, 6]
+     corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]]
+
+     # Use relative origin [0.5, 0.5, 0] (bottom center)
+     corners_norm = corners_norm - np.array([0.5, 0.5, 0], dtype=dtype)
+
+     # Scale corners: dims is [x_size, y_size, z_size]
+     corners = dims[:, np.newaxis, :] * corners_norm[np.newaxis, :, :]  # (N, 8, 3)
+
+     # Rotate around z-axis
+     corners = rotation_3d_in_axis_np(corners, yaws, axis=2)
+
+     # Translate to center point
+     corners += centers[:, np.newaxis, :]
+
+     return corners
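+
+ # With the (0.5, 0.5, 0) origin, the unit-cube corners are centered in x/y but
+ # anchored at z=0, so adding `centers` (bottom centers) places each box on its
+ # base; the [0, 1, 3, 2, 4, 5, 7, 6] reorder matches the edge table used in
+ # draw_bbox3d_on_img_custom_np below.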
+
+
+ def draw_bbox3d_on_img_custom_np(bboxes, raw_img, lidar2img_rt, color=(0, 255, 0), thickness=2):
+     """Custom 3D bbox drawing"""
+     img = raw_img.copy()
+
+     if len(bboxes) == 0:
+         return img
+
+     if not isinstance(bboxes, np.ndarray):
+         bboxes = np.array(bboxes)
+     if not isinstance(lidar2img_rt, np.ndarray):
+         lidar2img_rt = np.array(lidar2img_rt)
+
+     lidar2img_rt = lidar2img_rt.reshape(4, 4)
+
+     # Compute corners
+     corners_3d = compute_bbox_corners_np(bboxes)  # (N, 8, 3)
+
+     num_bbox = corners_3d.shape[0]
+
+     # Project to 2D
+     corners_3d_flat = corners_3d.reshape(-1, 3)  # (N*8, 3)
+     ones = np.ones((corners_3d_flat.shape[0], 1), dtype=np.float32)
+     pts_4d = np.concatenate([corners_3d_flat, ones], axis=-1)  # (N*8, 4)
+
+     # Project
+     pts_2d = pts_4d @ lidar2img_rt.T  # (N*8, 4)
+
+     # Perspective division
+     pts_2d[:, 2] = np.clip(pts_2d[:, 2], a_min=1e-5, a_max=1e5)
+     pts_2d[:, 0] /= pts_2d[:, 2]
+     pts_2d[:, 1] /= pts_2d[:, 2]
+
+     imgfov_pts_2d = pts_2d[:, :2].reshape(num_bbox, 8, 2)
+
+     line_indices = ((0, 1), (0, 3), (0, 4), (1, 2), (1, 5), (3, 2), (3, 7),
+                     (4, 5), (4, 7), (2, 6), (5, 6), (6, 7))
+
+     h, w = img.shape[:2]
+     for i in range(num_bbox):
+         corners = imgfov_pts_2d[i].astype(np.int32)
+         for start, end in line_indices:
+             pt1 = (int(corners[start, 0]), int(corners[start, 1]))
+             pt2 = (int(corners[end, 0]), int(corners[end, 1]))
+             # Check if points are within image range
+             if (0 <= pt1[0] < w and 0 <= pt1[1] < h) or (0 <= pt2[0] < w and 0 <= pt2[1] < h):
+                 cv2.line(img, pt1, pt2, color, thickness, cv2.LINE_AA)
+
+     return img.astype(np.uint8)
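+
+ # Caveat: depth is clipped to >= 1e-5 rather than culling corners behind the
+ # camera, so such corners can project to extreme pixel coordinates; the
+ # on-screen check above suppresses lines whose endpoints both land off-screen.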
+
+
+ def post_process_outputs_np(all_cls_scores, all_bbox_preds, config, score_thr=0.1):
+     bbox_coder = config['model']['bbox_coder']
+     pc_range = bbox_coder['pc_range']
+     post_center_range = bbox_coder['post_center_range']
+     max_num = bbox_coder['max_num']
+     score_threshold = bbox_coder.get('score_threshold', None)
+     num_classes = bbox_coder['num_classes']
+
+     preds_dicts = {
+         'all_cls_scores': all_cls_scores,
+         'all_bbox_preds': all_bbox_preds
+     }
+
+     bbox_list = get_bboxes_custom_np(
+         preds_dicts, pc_range, post_center_range,
+         max_num, score_threshold, num_classes
+     )
+
+     results = []
+     for bboxes, scores, labels in bbox_list:
+         # Set class score thresholds
+         class_score_thrs = {
+             0: 0.3,  # Car
+             1: 0.3,  # Truck
+             2: 0.3,  # Construction vehicle
+             3: 0.3,  # Bus
+             4: 0.3,  # Trailer
+             5: 0.3,  # Barrier
+             6: 0.3,  # Motorcycle
+             7: 0.3,  # Bicycle
+             8: 0.3,  # Pedestrian
+             9: 0.3,  # Traffic cone
+         }
+         default_thr = score_thr
+
+         keep_indices = []
+         for i in range(len(scores)):
+             cls_id = int(labels[i])
+             thr = class_score_thrs.get(cls_id, default_thr)
+             if scores[i] > thr:
+                 keep_indices.append(i)
+
+         if len(keep_indices) == 0:
+             results.append(format_bbox_result_np(
+                 np.zeros((0, 9), dtype=np.float32),
+                 np.zeros((0,), dtype=np.float32),
+                 np.zeros((0,), dtype=np.int64)
+             ))
+             continue
+
+         keep_indices = np.array(keep_indices, dtype=np.int64)
+         bboxes = bboxes[keep_indices]
+         scores = scores[keep_indices]
+         labels = labels[keep_indices]
+
+         # Circle NMS
+         dist_thrs = {
+             0: 2.0, 1: 3.0, 2: 2.5, 3: 4.0, 4: 3.0,
+             5: 1.0, 6: 1.5, 7: 1.0, 8: 0.5, 9: 0.3,
+         }
+
+         if len(scores) > 0:
+             keep_nms = circle_nms_np(bboxes, scores, labels, dist_thrs)
+             if len(keep_nms) > 0:
+                 bboxes = bboxes[keep_nms]
+                 scores = scores[keep_nms]
+                 labels = labels[keep_nms]
+             else:
+                 results.append(format_bbox_result_np(
+                     np.zeros((0, 9), dtype=np.float32),
+                     np.zeros((0,), dtype=np.float32),
+                     np.zeros((0,), dtype=np.int64)
+                 ))
+                 continue
+
+         results.append(format_bbox_result_np(bboxes, scores, labels))
+
+     return results
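+
+ # Pipeline summary: decode -> per-class score gate (every class uses 0.3 here,
+ # so the table is effectively a uniform threshold) -> class-aware circle NMS
+ # with per-class radii in meters.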
+
+
+ def circle_nms_np(bboxes, scores, labels, dist_thrs):
+     if len(bboxes) == 0:
+         return np.array([], dtype=np.int64)
+
+     keep = []
+     order = np.argsort(scores)[::-1]
+     bboxes = bboxes[order]
+     scores = scores[order]
+     labels = labels[order]
+
+     pts = bboxes[:, :2]
+     labels_np = labels
+
+     suppressed = np.zeros(len(bboxes), dtype=bool)
+
+     for i in range(len(bboxes)):
+         if suppressed[i]:
+             continue
+         keep.append(order[i])
+
+         curr_cls = int(labels_np[i])
+         radius = dist_thrs.get(curr_cls, 1.0)
+
+         if i + 1 < len(bboxes):
+             dists = np.linalg.norm(pts[i+1:] - pts[i], axis=1)
+             idx_to_suppress = np.where(
+                 (dists < radius) & (labels_np[i+1:] == curr_cls)
+             )[0]
+             suppressed[i+1:][idx_to_suppress] = True
+
+     return np.array(keep, dtype=np.int64)
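+
+ # Circle NMS in the CenterPoint style: greedily keep the highest-scoring box,
+ # then suppress later boxes of the same class whose BEV centers fall within the
+ # class radius. `keep` holds indices into the caller's original (pre-sort) arrays.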
+
+
+ def denormalize_img_np(img_array, img_norm_cfg):
+     """Denormalize image array (C, H, W) to (H, W, C) BGR"""
+     mean = np.array(img_norm_cfg.get('mean', [123.675, 116.28, 103.53]))
+     std = np.array(img_norm_cfg.get('std', [58.395, 57.12, 57.375]))
+
+     # (C, H, W) RGB -> (H, W, C) RGB
+     if img_array.ndim == 3:
+         img = img_array.transpose(1, 2, 0)
+     else:
+         img = img_array
+     img = (img * std + mean)
+     img = np.clip(img, 0, 255).astype(np.uint8)
+     img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+     return img
+
+
+ def draw_bev_map(bboxes, labels, scores, pc_range, bev_size=(800, 800), score_thr=0.1):
+     """Draw BEV (Bird's Eye View) map with detections
+
+     Args:
+         bboxes: (N, 9) numpy array, format: [x, y, z, w, l, h, yaw, vx, vy]
+         labels: (N,) numpy array, class labels
+         scores: (N,) numpy array, detection scores
+         pc_range: [x_min, y_min, z_min, x_max, y_max, z_max]
+         bev_size: (width, height) of BEV image
+         score_thr: score threshold
+
+     Returns:
+         bev_img: (H, W, 3) numpy array, BEV visualization
+     """
+     bev_w, bev_h = bev_size  # BEV image size
+     bev_img = np.ones((bev_h, bev_w, 3), dtype=np.uint8) * 255  # White background
+
+     # Draw grid
+     x_min, y_min, z_min, x_max, y_max, z_max = pc_range
+     x_range = x_max - x_min
+     y_range = y_max - y_min
+
+     # Draw grid lines
+     grid_color = (200, 200, 200)  # Light gray grid lines
+     for i in range(-5, 6):
+         x = x_min + (i + 5) * x_range / 10
+         y = y_min + (i + 5) * y_range / 10
+         # Vertical lines (y direction in LiDAR -> x direction in image)
+         img_x = int((y - y_min) / y_range * bev_w)
+         if 0 <= img_x < bev_w:
+             cv2.line(bev_img, (img_x, 0), (img_x, bev_h), grid_color, 1)
+         # Horizontal lines (x direction in LiDAR -> y direction in image, flipped)
+         img_y = int((x_max - x) / x_range * bev_h)
+         if 0 <= img_y < bev_h:
+             cv2.line(bev_img, (0, img_y), (bev_w, img_y), grid_color, 1)
+
+     # Draw center lines (ego vehicle position) - darker on white background
+     center_x = int((0 - y_min) / y_range * bev_w)
+     center_y = int((x_max - 0) / x_range * bev_h)
+     cv2.line(bev_img, (center_x, 0), (center_x, bev_h), (150, 150, 150), 2)
+     cv2.line(bev_img, (0, center_y), (bev_w, center_y), (150, 150, 150), 2)
+
+     ego_length_px = 30  # pixels (representing ~4.5m, along x-axis rightward)
+     ego_width_px = 12   # pixels (representing ~1.8m, along y-axis downward)
+
+     ego_corners_local = np.array([
+         [ego_length_px // 2, -ego_width_px // 2],   # front-top (head)
+         [ego_length_px // 2, ego_width_px // 2],    # front-bottom
+         [-ego_length_px // 2, ego_width_px // 2],   # back-bottom
+         [-ego_length_px // 2, -ego_width_px // 2],  # back-top
+     ], dtype=np.float32)
+
+     rotation_angle_90 = np.pi / 2  # 90 degrees in radians
+     cos_rot_90 = np.cos(rotation_angle_90)
+     sin_rot_90 = np.sin(rotation_angle_90)
+     rot_mat_90 = np.array([[cos_rot_90, -sin_rot_90], [sin_rot_90, cos_rot_90]])
+
+     ego_corners_rotated_90 = ego_corners_local @ rot_mat_90.T
+     ego_corners_rotated = ego_corners_rotated_90 @ rot_mat_90.T
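+     # Two successive 90-degree rotations amount to rotating the ego marker by
+     # 180 degrees; presumably this pairs with the rotate-and-flip applied to the
+     # whole map at the end of this function.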
+
+     # Translate to image coordinates (center position)
+     ego_corners = []
+     for corner in ego_corners_rotated:
+         corner_img_x = int(center_x + corner[0])
+         corner_img_y = int(center_y + corner[1])
+         ego_corners.append([corner_img_x, corner_img_y])
+     ego_corners = np.array(ego_corners, dtype=np.int32)
+
+     # Draw filled rectangle
+     cv2.fillPoly(bev_img, [ego_corners], (0, 0, 255))  # Red filled
+     cv2.polylines(bev_img, [ego_corners], True, (0, 0, 0), 2)  # Black outline
+
+     arrow_length = ego_length_px // 2
+     initial_direction = np.array([1.0, 0.0])
+     arrow_dir_rotated_90 = initial_direction @ rot_mat_90.T
+     arrow_dir_rotated = arrow_dir_rotated_90 @ rot_mat_90.T
+     arrow_end_x = int(center_x + arrow_length * arrow_dir_rotated[0])
+     arrow_end_y = int(center_y + arrow_length * arrow_dir_rotated[1])
+     cv2.arrowedLine(bev_img, (center_x, center_y), (arrow_end_x, arrow_end_y),
+                     (0, 0, 0), 3, tipLength=0.3)  # Black arrow
+
+     if len(bboxes) == 0:
+         return bev_img
+
+     if score_thr > 0:
+         mask = scores > score_thr
+         bboxes = bboxes[mask]
+         labels = labels[mask]
+         scores = scores[mask]
+
+     if len(bboxes) == 0:
+         return bev_img
+
+     default_color = (255, 255, 255)
+
+     for i in range(len(bboxes)):
+         box = bboxes[i]
+         label = int(labels[i])
+         score = float(scores[i])
+         color = CLASS_COLORS.get(label, default_color)
+
+         x, y, z = box[0], box[1], box[2]  # center position
+         w, l, h = box[3], box[4], box[5]  # width, length, height
+         yaw = box[6]  # yaw angle
+
+         yaw = yaw - np.pi / 2.0  # Subtract 90 degrees (counterclockwise)
+
+         # Convert to image coordinates
+         # Note: In LiDAR coordinate, x is forward, y is left, z is up
+         # In BEV image (top-down view):
+         #   - x (forward) -> image y (downward, flipped)
+         #   - y (left) -> image x (rightward)
+         # So: img_x = (y - y_min) / y_range * bev_w
+         #     img_y = (x_max - x) / x_range * bev_h (flip x to get top-down view)
+         img_x = int((y - y_min) / y_range * bev_w)
+         img_y = int((x_max - x) / x_range * bev_h)  # Flip x for top-down view
+
+         # Skip if outside image
+         if not (0 <= img_x < bev_w and 0 <= img_y < bev_h):
+             continue
+
+         # Calculate box dimensions in image space
+         box_w_px = int(w / x_range * bev_w)
+         box_l_px = int(l / y_range * bev_h)
+
+         # Draw rotated rectangle
+         # Calculate 4 corners of the box in LiDAR coordinates
+         cos_yaw = np.cos(yaw)
+         sin_yaw = np.sin(yaw)
+
+         # Box corners relative to center (in LiDAR frame: x forward, y left)
+         corners_local = np.array([
+             [l / 2, w / 2],    # front-right
+             [l / 2, -w / 2],   # front-left
+             [-l / 2, -w / 2],  # back-left
+             [-l / 2, w / 2]    # back-right
+         ])
+
+         # Rotate corners
+         rot_mat = np.array([[cos_yaw, -sin_yaw], [sin_yaw, cos_yaw]])
+         corners_rotated = corners_local @ rot_mat.T
+
+         # Translate to world coordinates and convert to image space
+         corners_img = []
+         for corner in corners_rotated:
+             corner_x = x + corner[0]  # x in LiDAR (forward)
+             corner_y = y + corner[1]  # y in LiDAR (left)
+             corner_img_x = int((corner_y - y_min) / y_range * bev_w)  # y -> img_x
+             corner_img_y = int((x_max - corner_x) / x_range * bev_h)  # x -> img_y (flipped)
+             corners_img.append([corner_img_x, corner_img_y])
+
+         corners_img = np.array(corners_img, dtype=np.int32)
+
+         # Draw filled polygon (semi-transparent on white background)
+         overlay = bev_img.copy()
+         cv2.fillPoly(overlay, [corners_img], color)
+         cv2.addWeighted(overlay, 0.5, bev_img, 0.5, 0, bev_img)
+         # Draw outline (black on white background)
+         cv2.polylines(bev_img, [corners_img], True, (0, 0, 0), 2)
+
+         # Draw direction arrow (forward direction) - black on white
+         # In LiDAR: forward is +x, left is +y
+         # In BEV image: x -> img_y (flipped), y -> img_x
+         # So rotation: img_x += sin(yaw) * length, img_y -= cos(yaw) * length
+         arrow_length = max(box_l_px // 2, 10)
+         arrow_end_x = int(img_x + arrow_length * sin_yaw)  # y component -> img_x
+         arrow_end_y = int(img_y - arrow_length * cos_yaw)  # x component -> img_y (flipped)
+         cv2.arrowedLine(bev_img, (img_x, img_y), (arrow_end_x, arrow_end_y),
+                         (0, 0, 0), 2, tipLength=0.3)  # Black arrow
+
+         # Draw center point
+         cv2.circle(bev_img, (img_x, img_y), 3, (0, 0, 0), -1)  # Black center point
+
+     # Rotate BEV map counterclockwise by 90 degrees (map only, not text)
+     center = (bev_w // 2, bev_h // 2)
+     rotation_matrix = cv2.getRotationMatrix2D(center, 90, 1.0)  # 90 degrees counterclockwise
+     bev_img = cv2.warpAffine(bev_img, rotation_matrix, (bev_w, bev_h), borderValue=(255, 255, 255))
+
+     # Flip horizontally to fix mirror effect
+     bev_img = cv2.flip(bev_img, 1)  # 1 for horizontal flip
+
+     text = 'BEV Map'
+     font = cv2.FONT_HERSHEY_SIMPLEX
+     font_scale = 1
+     thickness = 2
+     (text_width, text_height), baseline = cv2.getTextSize(text, font, font_scale, thickness)
+     text_x = bev_w - text_width - 10
+     text_y = text_height + 10
+     cv2.putText(bev_img, text, (text_x, text_y), font, font_scale, (0, 0, 0), thickness)
+
+     return bev_img
+
+
+ def visualize_results_np(img, result, lidar2img, img_norm_cfg, class_names, score_thr=0.3, pc_range=None):
+     num_cams = img.shape[1] if img.ndim == 5 else 1
+     raw_imgs = [denormalize_img_np(img[0, cam_idx], img_norm_cfg) for cam_idx in range(num_cams)]
+     boxes_3d = result.get('boxes_3d')
+     scores_3d = result.get('scores_3d')
+     labels_3d = result.get('labels_3d')
+     vis_imgs = []
+     boxes_3d_for_bev = labels_3d_for_bev = scores_3d_for_bev = None
+
+     if boxes_3d is not None and len(boxes_3d) > 0:
+         mask = (scores_3d > score_thr) if (score_thr > 0 and scores_3d is not None) else np.ones_like(scores_3d, dtype=bool)
+         if np.any(mask):
+             boxes_3d = boxes_3d[mask]
+             scores_3d = scores_3d[mask]
+             labels_3d = labels_3d[mask]
+             boxes_3d_for_bev = boxes_3d.copy()
+             labels_3d_for_bev = labels_3d.copy()
+             scores_3d_for_bev = scores_3d.copy()
+             for cam_idx, vis_img in enumerate(raw_imgs):
+                 vis_img = vis_img.copy()
+                 if lidar2img.shape[1] > cam_idx:
+                     cam_lidar2img = lidar2img[0, cam_idx]
+                     for box, label in zip(boxes_3d, labels_3d):
+                         color = CLASS_COLORS.get(int(label), (255, 255, 255))
+                         try:
+                             vis_img = draw_bbox3d_on_img_custom_np(box[None], vis_img, cam_lidar2img, color=color, thickness=2)
+                         except Exception:
+                             pass
+                 vis_imgs.append(vis_img)
+         else:
+             vis_imgs = raw_imgs
+     else:
+         vis_imgs = raw_imgs
+
+     if pc_range is None:
+         pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
+
+     if boxes_3d_for_bev is not None and len(boxes_3d_for_bev) > 0:
+         bev_size = (vis_imgs[0].shape[0], vis_imgs[0].shape[0]) if vis_imgs else (800, 800)
+         bev_img = draw_bev_map(boxes_3d_for_bev, labels_3d_for_bev, scores_3d_for_bev, pc_range, bev_size=bev_size, score_thr=score_thr)
+     else:
+         bev_size = (vis_imgs[0].shape[0], vis_imgs[0].shape[0]) if vis_imgs else (800, 800)
+         bev_img = np.full((bev_size[1], bev_size[0], 3), 255, np.uint8)
+         cv2.putText(bev_img, 'BEV Map (No Detections)', (10, bev_size[1] // 2), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)
+
+     if len(vis_imgs) == 6:
+         target_height = max(im.shape[0] for im in vis_imgs)
+         resized_imgs = [im if im.shape[0] == target_height else cv2.resize(im, (int(im.shape[1] * target_height / im.shape[0]), target_height)) for im in vis_imgs]
+
+         reordered_imgs = [
+             resized_imgs[2], resized_imgs[0], resized_imgs[1],
+             cv2.flip(resized_imgs[4], 1), cv2.flip(resized_imgs[3], 1), cv2.flip(resized_imgs[5], 1)
+         ]
+         top_row = np.hstack(reordered_imgs[:3])
+         bottom_row = np.hstack(reordered_imgs[3:])
+         left_side = np.vstack([top_row, bottom_row])
+         bev_img = cv2.resize(bev_img, (int(bev_img.shape[1] * left_side.shape[0] / bev_img.shape[0]), left_side.shape[0]))
+         vis_img = np.hstack([left_side, bev_img])
+     elif len(vis_imgs) > 1:
+         target_height = max(im.shape[0] for im in vis_imgs)
+         resized_imgs = [im if im.shape[0] == target_height else cv2.resize(im, (int(im.shape[1] * target_height / im.shape[0]), target_height)) for im in vis_imgs]
+         if bev_img.shape[0] != target_height:
+             bev_img = cv2.resize(bev_img, (int(bev_img.shape[1] * target_height / bev_img.shape[0]), target_height))
+         vis_img = np.hstack([np.hstack(resized_imgs), bev_img])
+     else:
+         cam_img = vis_imgs[0] if vis_imgs else bev_img
+         if bev_img.shape[0] != cam_img.shape[0]:
+             bev_img = cv2.resize(bev_img, (int(bev_img.shape[1] * cam_img.shape[0] / bev_img.shape[0]), cam_img.shape[0]))
+         vis_img = np.hstack([cam_img, bev_img]) if vis_imgs else bev_img
+
+     return vis_img
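+
+ # The 3x2 mosaic above assumes the standard nuScenes camera ordering
+ # (0 FRONT, 1 FRONT_RIGHT, 2 FRONT_LEFT, 3 BACK, 4 BACK_LEFT, 5 BACK_RIGHT);
+ # the rear views are mirrored so the bottom row reads like a rear-view mirror.
+ # Verify this order against whatever produced extracted_data.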
+
+
+ def create_video_from_images(image_dir, output_video_path, fps=3):
+     image_files = sorted([f for f in os.listdir(image_dir) if f.endswith(('.png', '.jpg', '.jpeg'))])
+     if len(image_files) == 0:
+         return
+
+     first_img = cv2.imread(osp.join(image_dir, image_files[0]))
+     if first_img is None:
+         return
+
+     height, width = first_img.shape[:2]
+
+     max_width, max_height = 1920, 1080
+     if width > max_width or height > max_height:
+         scale = min(max_width / width, max_height / height)
+         width, height = int(width * scale), int(height * scale)
+
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     video_writer = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
+     if not video_writer.isOpened():
+         fourcc = cv2.VideoWriter_fourcc(*'XVID')
+         video_writer = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
+
+     for img_file in tqdm(image_files, desc=f"Creating video: {osp.basename(output_video_path)}"):
+         img_path = osp.join(image_dir, img_file)
+         img = cv2.imread(img_path)
+         if img is not None:
+             if img.shape[:2] != (height, width):
+                 img = cv2.resize(img, (width, height))
+             video_writer.write(img)
+
+     video_writer.release()
+
+
+ def main():
+     args = parse_args()
+
+     # Load configuration from JSON
+     config = load_config_from_json(args.config_json)
+
+     # Create output directory
+     os.makedirs(args.output_dir, exist_ok=True)
+
+     # Load AXModel
+     ax_session = load_axmodel(args.model)
+
+     # Get model parameters from config
+     transformer_cfg = config['model']['transformer']
+     bev_h = transformer_cfg['bev_h']
+     bev_w = transformer_cfg['bev_w']
+     embed_dims = transformer_cfg['embed_dims']
+
+     # Load scene index
+     scene_index_path = osp.join(args.data_dir, 'scene_index.json')
+     with open(scene_index_path, 'r') as f:
+         scene_index_data = json.load(f)
+
+     scenes_dict = scene_index_data['scenes']
+     scene_names = list(scenes_dict.keys())
+
+     end_scene = args.end_scene if args.end_scene is not None else len(scene_names)
+     end_scene = min(end_scene, len(scene_names))
+
+     prev_frame_info = {
+         'prev_bev': None,
+         'scene_token': None,
+         'prev_pos': np.zeros(3, dtype=np.float32),
+         'prev_angle': 0.0,
+     }
+
+     scene_results = defaultdict(list)
+
+     # Process all scenes
+     for scene_idx in range(args.start_scene, end_scene):
+         scene_name = scene_names[scene_idx]
+         scene_info = scenes_dict[scene_name]
+         sample_indices = scene_info['samples']
+         num_frames = len(sample_indices)
+
+         print(f"Processing scene {scene_idx+1}/{len(scene_names)}: {scene_name} ({num_frames} frames)")
+
+         # Reset prev_bev for new scene
+         if scene_name != prev_frame_info['scene_token']:
+             prev_frame_info['prev_bev'] = None
+             prev_frame_info['prev_pos'] = np.zeros(3, dtype=np.float32)
+             prev_frame_info['prev_angle'] = 0.0
+
+         prev_frame_info['scene_token'] = scene_name
+
+         # Process all frames in this scene
+         for local_idx, frame_idx in enumerate(tqdm(sample_indices, desc=f"Scene {scene_name}")):
+             # Load data
+             img, lidar2img, can_bus, meta = load_data(args.data_dir, scene_name, frame_idx)
+
+             # Process can_bus (compute delta)
+             curr_can_bus_np = can_bus[0]  # (18,)
+
+             tmp_pos = curr_can_bus_np[:3].copy()
+             tmp_angle = curr_can_bus_np[-1]
+
+             delta_can_bus_np = curr_can_bus_np.copy()
+
+             if prev_frame_info['prev_bev'] is not None and prev_frame_info['scene_token'] == scene_name:
+                 delta_can_bus_np[:3] -= prev_frame_info['prev_pos']
+                 delta_can_bus_np[-1] -= prev_frame_info['prev_angle']
+             else:
+                 delta_can_bus_np[:3] = 0.0
+                 delta_can_bus_np[-1] = 0.0
+
+             prev_frame_info['prev_pos'] = tmp_pos
+             prev_frame_info['prev_angle'] = tmp_angle
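+
+             # BEVFormer-style temporal fusion: can_bus[:3] carries the ego
+             # translation delta and can_bus[-1] the yaw delta relative to the
+             # previous frame so prev_bev can be aligned; both are zeroed at the
+             # start of a scene, when prev_bev is still all zeros.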
+
+             # Prepare prev_bev
+             prev_bev_input = next((inp for inp in ax_session.get_inputs() if inp.name == 'prev_bev'), None)
+             expected_shape = (bev_h * bev_w, 1, embed_dims)
+             if prev_bev_input is not None:
+                 expected_shape = list(prev_bev_input.shape)
+                 for i, dim in enumerate(expected_shape):
+                     if isinstance(dim, str) or dim < 0:
+                         expected_shape[i] = (bev_h * bev_w, 1, embed_dims)[i] if i < 3 else 1
+                 expected_shape = tuple(expected_shape)
+
+             if prev_frame_info['prev_bev'] is None:
+                 prev_bev = np.zeros(expected_shape, dtype=np.float32)
+             else:
+                 prev_bev = prev_frame_info['prev_bev']
+                 if prev_bev.shape != expected_shape and len(prev_bev.shape) == 3:
+                     prev_bev = prev_bev.reshape(expected_shape)
+
+             # Prepare AXEngine inputs
+             img_np = img.astype(np.float32)
+             lidar2img_np = lidar2img.astype(np.float32)
+             can_bus_np = delta_can_bus_np.reshape(1, -1).astype(np.float32)
+
+             input_names = [inp.name for inp in ax_session.get_inputs()]
+             ax_inputs = {}
+             for name in input_names:
+                 if name == 'img':
+                     ax_inputs['img'] = img_np
+                 elif name == 'can_bus':
+                     ax_inputs['can_bus'] = can_bus_np
+                 elif name == 'lidar2img':
+                     ax_inputs['lidar2img'] = lidar2img_np
+                 elif name == 'prev_bev':
+                     ax_inputs['prev_bev'] = prev_bev
+
+             # Run inference
+             ax_outputs = ax_session.run(None, ax_inputs)
+             bev_embed, all_cls_scores, all_bbox_preds = ax_outputs
+
+             prev_frame_info['prev_bev'] = bev_embed
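+             # bev_embed should have shape (bev_h * bev_w, 1, embed_dims); feeding
+             # it back as prev_bev on the next frame is what makes this loop
+             # stateful across a scene.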
+
+             # Post-process
+             results = post_process_outputs_np(
+                 all_cls_scores, all_bbox_preds, config, args.score_thr
+             )
+
+             # Visualize
+             img_norm_cfg = config['img_norm']
+             class_names = config['dataset']['class_names']
+             pc_range = config['model']['bbox_coder']['pc_range']
+             vis_img = visualize_results_np(
+                 img, results[0], lidar2img, img_norm_cfg, class_names, args.score_thr, pc_range=pc_range
+             )
+
+             scene_results[scene_name].append({
+                 'frame_idx': local_idx,
+                 'result': results[0],
+                 'vis_img': vis_img,
+                 'meta': meta
+             })
+
+     # Save results
+     for scene_name, frames in tqdm(scene_results.items(), desc="Save scene results"):
+         scene_dir = osp.join(args.output_dir, scene_name)
+         os.makedirs(scene_dir, exist_ok=True)
+         images_dir = osp.join(scene_dir, 'images')
+         os.makedirs(images_dir, exist_ok=True)
+
+         for local_idx, frame_data in enumerate(frames):
+             vis_img = frame_data['vis_img']
+
+             if vis_img is None:
+                 continue
+
+             if not isinstance(vis_img, np.ndarray):
+                 vis_img = np.array(vis_img)
+
+             if vis_img.dtype != np.uint8:
+                 vis_img = (vis_img * 255).astype(np.uint8) if vis_img.max() <= 1.0 else vis_img.astype(np.uint8)
+
+             if len(vis_img.shape) == 3 and vis_img.shape[0] in (1, 3):
+                 vis_img = vis_img.transpose(1, 2, 0)
+
+             if vis_img.shape[0] > 0 and vis_img.shape[1] > 0:
+                 cv2.imwrite(osp.join(images_dir, f'frame_{local_idx:06d}.png'), vis_img)
+
+         create_video_from_images(images_dir, osp.join(scene_dir, f'{scene_name}_result.mp4'), args.fps)
+         print(f"✓ Scene {scene_name}: {len(frames)} frames, video: {osp.join(scene_dir, f'{scene_name}_result.mp4')}")
+
+
+ if __name__ == '__main__':
+     main()