Upload inference_axmodel.py
Browse files- inference_axmodel.py +1002 -0
inference_axmodel.py
ADDED
|
@@ -0,0 +1,1002 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import argparse
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import os.path as osp
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
import axengine as axe
|
| 9 |
+
from collections import defaultdict
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def parse_args():
    """Parse command-line arguments for extracted-data inference."""
    p = argparse.ArgumentParser(description='BEVFormer AXEngine Inference from Extracted Data')
    # Positional arguments (order defines the CLI).
    p.add_argument('model', help='AXModel path')
    p.add_argument('config_json', help='JSON config file path')
    p.add_argument('data_dir', help='extracted data directory (extracted_data)')
    # Optional arguments with their defaults.
    p.add_argument('--output-dir', default='./inference_results_extracted', help='output directory')
    p.add_argument('--score-thr', type=float, default=0.1, help='score threshold')
    p.add_argument('--fps', type=int, default=3, help='video fps')
    p.add_argument('--start-scene', type=int, default=0, help='start scene index')
    p.add_argument('--end-scene', type=int, default=None, help='end scene index (None for all)')
    return p.parse_args()
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def load_axmodel(axmodel_path):
    """Load an AXModel and return an axengine inference session.

    Args:
        axmodel_path: path to the compiled .axmodel file.

    Returns:
        axe.InferenceSession created with the AxEngine execution provider.
    """
    # Try AxEngineExecutionProvider instead of AXCLRTExecutionProvider
    providers = ['AxEngineExecutionProvider']
    session = axe.InferenceSession(axmodel_path, providers=providers)
    return session
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def load_config_from_json(config_path):
    """Read a JSON configuration file and return its contents as a dict."""
    with open(config_path, 'r') as fp:
        return json.load(fp)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def preprocess_image(img_path, img_norm_cfg, target_size=(480, 800)):
    """Preprocess image: load, resize, normalize

    Args:
        img_path: path to image file
        img_norm_cfg: normalization config with 'mean', 'std', 'to_rgb'
        target_size: (H, W) target size

    Returns:
        img: (C, H, W) normalized numpy array, float32

    Raises:
        ValueError: if the image cannot be read from img_path.
    """
    # Load image (cv2.imread returns BGR, or None on failure)
    img = cv2.imread(img_path)
    if img is None:
        raise ValueError(f"Cannot load image: {img_path}")

    # Convert BGR to RGB if needed
    if img_norm_cfg.get('to_rgb', True):
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Resize if needed
    # NOTE(review): img.shape[:2] is a tuple, so target_size must also be a
    # tuple for this equality test; a list always compares unequal and
    # forces a resize every call — confirm callers pass a tuple.
    if img.shape[:2] != target_size:
        img = cv2.resize(img, (target_size[1], target_size[0]))  # (W, H)

    # Convert to float and normalize with per-channel mean/std
    img = img.astype(np.float32)
    mean = np.array(img_norm_cfg.get('mean', [123.675, 116.28, 103.53]), dtype=np.float32)
    std = np.array(img_norm_cfg.get('std', [58.395, 57.12, 57.375]), dtype=np.float32)

    img = (img - mean) / std
    img = img.transpose(2, 0, 1)  # (H, W, C) -> (C, H, W)

    return img
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def load_data(data_dir, scene_name, frame_idx):
    """Load one frame's images, calibration and CAN-bus data from disk.

    Expects the layout produced by the extraction step:
    ``<data_dir>/<scene_name>/meta_XXXXXX.json`` plus one
    ``cam_CC_XXXXXX.png`` per camera.

    Args:
        data_dir: data directory path
        scene_name: scene name (scene token)
        frame_idx: frame index (sample index)

    Returns:
        img: (1, N, C, H, W) numpy array
        lidar2img: (1, N, 4, 4) numpy array
        can_bus: (1, 18) numpy array
        meta: dict with metadata
    """
    scene_dir = osp.join(data_dir, scene_name)

    # Load per-frame metadata (calibration, normalization config, ...)
    meta_path = osp.join(scene_dir, f'meta_{frame_idx:06d}.json')
    with open(meta_path, 'r') as f:
        meta = json.load(f)

    # Get normalization config (falls back to ImageNet-style defaults)
    img_norm_cfg = meta.get('img_norm_cfg', {
        'mean': [123.675, 116.28, 103.53],
        'std': [58.395, 57.12, 57.375],
        'to_rgb': True
    })

    # Get image shape; only the first camera's (H, W) is used as target size
    img_shape = meta.get('img_shape', [[480, 800, 3]] * 6)
    target_size = (img_shape[0][0], img_shape[0][1])  # (H, W)

    # Load and preprocess images for all cameras
    num_cams = meta.get('num_cams', 6)
    imgs = []
    for cam_idx in range(num_cams):
        img_path = osp.join(scene_dir, f'cam_{cam_idx:02d}_{frame_idx:06d}.png')
        img = preprocess_image(img_path, img_norm_cfg, target_size)
        imgs.append(img)

    # Stack images: (N, C, H, W) -> (1, N, C, H, W)
    img = np.stack(imgs, axis=0)  # (N, C, H, W)
    img = img[np.newaxis, ...]  # (1, N, C, H, W)

    # Load lidar2img: (N, 4, 4) -> (1, N, 4, 4)
    lidar2img = np.array(meta['lidar2img'], dtype=np.float32)  # (N, 4, 4)
    lidar2img = lidar2img[np.newaxis, ...]  # (1, N, 4, 4)

    # Load can_bus: (18,) -> (1, 18)
    can_bus = np.array(meta['can_bus'], dtype=np.float32)  # (18,)
    can_bus = can_bus[np.newaxis, ...]  # (1, 18)

    return img, lidar2img, can_bus, meta
|
| 129 |
+
|
| 130 |
+
# Per-class drawing colors, BGR tuples (OpenCV convention), keyed by the
# same class ids used in post_process_outputs_np's class_score_thrs
# (0=Car ... 9=Traffic cone).
CLASS_COLORS = {
    0: (0, 255, 0), 1: (255, 255, 0), 2: (0, 0, 255), 3: (0, 165, 255),
    4: (255, 0, 255), 5: (0, 255, 255), 6: (128, 0, 128), 7: (255, 165, 0),
    8: (0, 0, 255), 9: (128, 128, 128),
}
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def denormalize_bbox_np(normalized_bboxes, pc_range):
    """Denormalize network bbox predictions using numpy only.

    Args:
        normalized_bboxes: (..., 8) or (..., 10) array laid out as
            [cx, cy, log_w, log_l, cz, log_h, sin(rot), cos(rot), (vx, vy)].
        pc_range: point-cloud range; unused here, kept for API compatibility.

    Returns:
        (..., 7) or (..., 9) array [cx, cy, cz, w, l, h, rot, (vx, vy)].
    """
    # rotation: recover yaw from its sine/cosine encoding
    rot_sine = normalized_bboxes[..., 6:7]
    rot_cosine = normalized_bboxes[..., 7:8]
    rot = np.arctan2(rot_sine, rot_cosine)

    # center in the bev
    cx = normalized_bboxes[..., 0:1]
    cy = normalized_bboxes[..., 1:2]
    cz = normalized_bboxes[..., 4:5]

    # sizes are predicted in log space
    w = np.exp(normalized_bboxes[..., 2:3])
    l = np.exp(normalized_bboxes[..., 3:4])
    h = np.exp(normalized_bboxes[..., 5:6])

    if normalized_bboxes.shape[-1] > 8:
        # velocity — indexed with `...` so inputs with extra leading dims
        # also work (the previous `[:, 8:9]` only handled 2-D input)
        vx = normalized_bboxes[..., 8:9]
        vy = normalized_bboxes[..., 9:10]
        return np.concatenate([cx, cy, cz, w, l, h, rot, vx, vy], axis=-1)
    return np.concatenate([cx, cy, cz, w, l, h, rot], axis=-1)
|
| 167 |
+
|
| 168 |
+
def decode_bboxes_custom_np(all_cls_scores, all_bbox_preds, pc_range, post_center_range, max_num=100, score_threshold=None, num_classes=10):
    """Decode raw head outputs into per-sample boxes/scores/labels.

    Args:
        all_cls_scores: (num_layers, bs, num_query, num_classes) logits;
            only the last decoder layer is used.
        all_bbox_preds: (num_layers, bs, num_query, 10) normalized boxes.
        pc_range: point-cloud range, forwarded to denormalize_bbox_np.
        post_center_range: [x0, y0, z0, x1, y1, z1] center filter, or None.
        max_num: number of top-scoring (query, class) pairs to keep.
        score_threshold: optional minimum score (adaptively relaxed, see below).
        num_classes: number of classes per query.

    Returns:
        List (len = batch size) of dicts with 'bboxes', 'scores', 'labels'.
    """
    # Use output from the last decoder layer
    all_cls_scores = all_cls_scores[-1]  # (bs, num_query, num_classes)
    all_bbox_preds = all_bbox_preds[-1]  # (bs, num_query, 10)

    batch_size = all_cls_scores.shape[0]
    predictions_list = []

    for i in range(batch_size):
        cls_scores = all_cls_scores[i]  # (num_query, num_classes)
        bbox_preds = all_bbox_preds[i]  # (num_query, 10)

        # Apply sigmoid (logits -> per-class probabilities)
        cls_scores = 1.0 / (1.0 + np.exp(-cls_scores))

        # TopK selection over the flattened (query, class) grid; the label
        # and the owning query index are recovered by mod / integer div.
        cls_scores_flat = cls_scores.reshape(-1)
        topk_indices = np.argsort(cls_scores_flat)[::-1][:max_num]
        scores = cls_scores_flat[topk_indices]
        labels = topk_indices % num_classes
        bbox_index = topk_indices // num_classes
        bbox_preds = bbox_preds[bbox_index]

        # Denormalize bbox
        final_box_preds = denormalize_bbox_np(bbox_preds, pc_range)  # (max_num, 9)
        final_scores = scores
        final_preds = labels

        # Apply score threshold. If nothing passes, relax the threshold by
        # 10% at a time; below 0.01 give up and keep everything so the
        # frame is never completely empty due to thresholding alone.
        if score_threshold is not None:
            thresh_mask = final_scores > score_threshold
            tmp_score = score_threshold
            while thresh_mask.sum() == 0:
                tmp_score *= 0.9
                if tmp_score < 0.01:
                    thresh_mask = np.ones(len(final_scores), dtype=bool)
                    break
                thresh_mask = final_scores >= tmp_score
        else:
            thresh_mask = np.ones(len(final_scores), dtype=bool)

        # Apply post processing range filtering (box center must lie inside
        # post_center_range); combined with the score mask.
        if post_center_range is not None:
            post_center_range_arr = np.array(post_center_range)
            mask = (final_box_preds[..., :3] >= post_center_range_arr[:3]).all(1)
            mask &= (final_box_preds[..., :3] <= post_center_range_arr[3:]).all(1)
            mask &= thresh_mask

            boxes3d = final_box_preds[mask]
            scores = final_scores[mask]
            labels = final_preds[mask]
        else:
            boxes3d = final_box_preds[thresh_mask]
            scores = final_scores[thresh_mask]
            labels = final_preds[thresh_mask]

        predictions_list.append({
            'bboxes': boxes3d,
            'scores': scores,
            'labels': labels
        })

    return predictions_list
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def get_bboxes_custom_np(preds_dicts, pc_range, post_center_range, max_num=100, score_threshold=None, num_classes=10):
    """Decode head outputs and convert boxes to bottom-center convention.

    Args:
        preds_dicts: dict with 'all_cls_scores' and 'all_bbox_preds'.
        pc_range / post_center_range / max_num / score_threshold /
        num_classes: forwarded to decode_bboxes_custom_np.

    Returns:
        List of (bboxes (M, 9), scores (M,), labels (M,)) tuples, one per
        batch sample; empty samples yield zero-length arrays.
    """
    # Decode bounding boxes
    preds_list = decode_bboxes_custom_np(
        preds_dicts['all_cls_scores'],
        preds_dicts['all_bbox_preds'],
        pc_range,
        post_center_range,
        max_num,
        score_threshold,
        num_classes
    )

    num_samples = len(preds_list)
    ret_list = []

    for i in range(num_samples):
        preds = preds_list[i]
        bboxes = preds['bboxes']

        if len(bboxes) == 0:
            ret_list.append((
                np.zeros((0, 9), dtype=np.float32),
                np.zeros((0,), dtype=np.float32),
                np.zeros((0,), dtype=np.int64)
            ))
            continue

        # Adjust z coordinate: convert center z to bottom center z
        # (mutates the decoded array in place)
        bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5

        # Shrink box dimensions: multiply w, l, h by 0.9 to fix oversized boxes
        # (empirical correction, also in place)
        bboxes[:, 3:6] = bboxes[:, 3:6] * 0.9

        scores = preds['scores']
        labels = preds['labels']

        ret_list.append((bboxes, scores, labels))

    return ret_list
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def format_bbox_result_np(bboxes, scores, labels):
    """Pack one sample's detections into the standard result dict."""
    result = {}
    result['boxes_3d'] = bboxes
    result['scores_3d'] = scores
    result['labels_3d'] = labels
    return result
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def rotation_3d_in_axis_np(points, angles, axis=2):
    """Rotate batched point sets about the z-axis.

    Args:
        points: (N, M, 3) array of per-box points.
        angles: (N,) array of rotation angles in radians.
        axis: rotation axis; only 2 / -1 (z-axis) is supported.

    Returns:
        (N, M, 3) rotated points.

    Raises:
        ValueError: for any axis other than the z-axis.
    """
    if axis != 2 and axis != -1:
        raise ValueError(f'Only axis=2 (z-axis) is supported for LiDAR boxes')

    c = np.cos(angles)
    s = np.sin(angles)
    one = np.ones_like(c)
    zero = np.zeros_like(c)

    # One z-rotation matrix per box, assembled row by row: (N, 3, 3)
    rot_mat = np.stack([
        np.stack([c, -s, zero], axis=-1),
        np.stack([s, c, zero], axis=-1),
        np.stack([zero, zero, one], axis=-1),
    ], axis=-2).astype(points.dtype)

    # Batched product: (N, M, 3) @ (N, 3, 3) -> (N, M, 3)
    return np.einsum('aij,ajk->aik', points, rot_mat)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def compute_bbox_corners_np(bboxes):
    """Compute the 8 corner coordinates of each 3D box.

    Args:
        bboxes: (N, >=7) array [x, y, z, w, l, h, yaw, ...] where (x, y, z)
            is the bottom-face center of the box.

    Returns:
        (N, 8, 3) array of corner coordinates.
    """
    n_boxes = len(bboxes)
    if n_boxes == 0:
        return np.zeros((0, 8, 3), dtype=np.float32)

    dtype = bboxes.dtype
    centers = bboxes[:, :3]
    # dims ordered as [x_size, y_size, z_size] = [l, w, h]
    dims = np.concatenate([bboxes[:, 4:5], bboxes[:, 3:4], bboxes[:, 5:6]], axis=1)
    # Empirical calibration: offset yaw by -(90 - 10) = -80 degrees
    yaws = bboxes[:, 6] - (np.pi / 2.0 - np.pi / 18.0)

    # Unit-cube corners, reordered to [0, 1, 3, 2, 4, 5, 7, 6] and shifted
    # so the origin is the bottom-face center ([0.5, 0.5, 0]).
    corners_norm = np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1).astype(dtype)
    corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]] - np.array([0.5, 0.5, 0], dtype=dtype)

    # Scale by box dimensions: (N, 8, 3)
    corners = dims[:, np.newaxis, :] * corners_norm[np.newaxis, :, :]

    # Rotate about the z-axis (inlined, equivalent to
    # rotation_3d_in_axis_np with axis=2), then translate to the centers.
    cos_y = np.cos(yaws)
    sin_y = np.sin(yaws)
    rot_mat = np.zeros((n_boxes, 3, 3), dtype=corners.dtype)
    rot_mat[:, 0, 0] = cos_y
    rot_mat[:, 0, 1] = -sin_y
    rot_mat[:, 1, 0] = sin_y
    rot_mat[:, 1, 1] = cos_y
    rot_mat[:, 2, 2] = 1.0
    corners = np.einsum('aij,ajk->aik', corners, rot_mat)
    corners = corners + centers[:, np.newaxis, :]
    return corners
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def draw_bbox3d_on_img_custom_np(bboxes, raw_img, lidar2img_rt, color=(0, 255, 0), thickness=2):
    """Project 3D boxes through lidar2img and draw their wireframes.

    Args:
        bboxes: (N, >=7) boxes [x, y, z, w, l, h, yaw, ...]; may be empty.
        raw_img: (H, W, 3) image to draw on (a copy is made).
        lidar2img_rt: 4x4 (or flattenable to 4x4) projection matrix.
        color: BGR line color.
        thickness: line thickness in pixels.

    Returns:
        uint8 copy of raw_img with the projected edges drawn.
    """
    img = raw_img.copy()

    if len(bboxes) == 0:
        return img

    if not isinstance(bboxes, np.ndarray):
        bboxes = np.array(bboxes)
    if not isinstance(lidar2img_rt, np.ndarray):
        lidar2img_rt = np.array(lidar2img_rt)

    lidar2img_rt = lidar2img_rt.reshape(4, 4)

    # Compute corners
    corners_3d = compute_bbox_corners_np(bboxes)  # (N, 8, 3)

    num_bbox = corners_3d.shape[0]

    # Homogeneous coordinates for projection
    corners_3d_flat = corners_3d.reshape(-1, 3)  # (N*8, 3)
    ones = np.ones((corners_3d_flat.shape[0], 1), dtype=np.float32)
    pts_4d = np.concatenate([corners_3d_flat, ones], axis=-1)  # (N*8, 4)

    # Project
    pts_2d = pts_4d @ lidar2img_rt.T  # (N*8, 4)

    # Perspective division. NOTE(review): depth is clamped to >= 1e-5, so
    # corners behind the camera are not culled and may produce spurious
    # lines — confirm acceptable for this visualization.
    pts_2d[:, 2] = np.clip(pts_2d[:, 2], a_min=1e-5, a_max=1e5)
    pts_2d[:, 0] /= pts_2d[:, 2]
    pts_2d[:, 1] /= pts_2d[:, 2]

    imgfov_pts_2d = pts_2d[:, :2].reshape(num_bbox, 8, 2)

    # The 12 edges of the box, as pairs of corner indices
    line_indices = ((0, 1), (0, 3), (0, 4), (1, 2), (1, 5), (3, 2), (3, 7),
                    (4, 5), (4, 7), (2, 6), (5, 6), (6, 7))

    # Image size is loop-invariant (cv2.line draws in place without
    # changing the shape) — hoisted out of the per-edge loop.
    h, w = img.shape[:2]

    for i in range(num_bbox):
        corners = imgfov_pts_2d[i].astype(np.int32)
        for start, end in line_indices:
            pt1 = (int(corners[start, 0]), int(corners[start, 1]))
            pt2 = (int(corners[end, 0]), int(corners[end, 1]))
            # Draw only if at least one endpoint is inside the image
            if (0 <= pt1[0] < w and 0 <= pt1[1] < h) or (0 <= pt2[0] < w and 0 <= pt2[1] < h):
                cv2.line(img, pt1, pt2, color, thickness, cv2.LINE_AA)

    return img.astype(np.uint8)
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
def post_process_outputs_np(all_cls_scores, all_bbox_preds, config, score_thr=0.1):
    """Full numpy post-processing: decode, per-class threshold, circle NMS.

    Args:
        all_cls_scores / all_bbox_preds: raw detection-head outputs.
        config: parsed JSON config; config['model']['bbox_coder'] supplies
            pc_range, post_center_range, max_num, num_classes and an
            optional score_threshold.
        score_thr: fallback threshold for class ids not listed in
            class_score_thrs below.

    Returns:
        List of result dicts ({'boxes_3d', 'scores_3d', 'labels_3d'}),
        one per batch sample.
    """
    bbox_coder = config['model']['bbox_coder']
    pc_range = bbox_coder['pc_range']
    post_center_range = bbox_coder['post_center_range']
    max_num = bbox_coder['max_num']
    score_threshold = bbox_coder.get('score_threshold', None)
    num_classes = bbox_coder['num_classes']

    preds_dicts = {
        'all_cls_scores': all_cls_scores,
        'all_bbox_preds': all_bbox_preds
    }

    bbox_list = get_bboxes_custom_np(
        preds_dicts, pc_range, post_center_range,
        max_num, score_threshold, num_classes
    )

    results = []
    for bboxes, scores, labels in bbox_list:
        # Per-class score thresholds; any class id found here overrides the
        # score_thr argument (currently a flat 0.3 for all ten classes, so
        # score_thr only applies to unknown class ids).
        class_score_thrs = {
            0: 0.3,  # Car
            1: 0.3,  # Truck
            2: 0.3,  # Construction vehicle
            3: 0.3,  # Bus
            4: 0.3,  # Trailer
            5: 0.3,  # Barrier
            6: 0.3,  # Motorcycle
            7: 0.3,  # Bicycle
            8: 0.3,  # Pedestrian
            9: 0.3,  # Traffic cone
        }
        default_thr = score_thr

        keep_indices = []
        for i in range(len(scores)):
            cls_id = int(labels[i])
            thr = class_score_thrs.get(cls_id, default_thr)
            if scores[i] > thr:
                keep_indices.append(i)

        # Nothing survived thresholding -> empty result for this sample
        if len(keep_indices) == 0:
            results.append(format_bbox_result_np(
                np.zeros((0, 9), dtype=np.float32),
                np.zeros((0,), dtype=np.float32),
                np.zeros((0,), dtype=np.int64)
            ))
            continue

        keep_indices = np.array(keep_indices, dtype=np.int64)
        bboxes = bboxes[keep_indices]
        scores = scores[keep_indices]
        labels = labels[keep_indices]

        # Circle NMS: per-class suppression radii in BEV meters
        dist_thrs = {
            0: 2.0, 1: 3.0, 2: 2.5, 3: 4.0, 4: 3.0,
            5: 1.0, 6: 1.5, 7: 1.0, 8: 0.5, 9: 0.3,
        }

        if len(scores) > 0:
            keep_nms = circle_nms_np(bboxes, scores, labels, dist_thrs)
            if len(keep_nms) > 0:
                bboxes = bboxes[keep_nms]
                scores = scores[keep_nms]
                labels = labels[keep_nms]
            else:
                results.append(format_bbox_result_np(
                    np.zeros((0, 9), dtype=np.float32),
                    np.zeros((0,), dtype=np.float32),
                    np.zeros((0,), dtype=np.int64)
                ))
                continue

        results.append(format_bbox_result_np(bboxes, scores, labels))

    return results
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def circle_nms_np(bboxes, scores, labels, dist_thrs):
    """Greedy class-aware circle NMS on BEV box centers.

    Boxes are visited in descending score order; any later box of the same
    class whose (x, y) center lies within the class radius of a kept box
    is suppressed.

    Args:
        bboxes: (N, >=2) array; columns 0-1 are the BEV center.
        scores: (N,) detection scores.
        labels: (N,) integer class labels.
        dist_thrs: {class_id: radius} mapping; 1.0 for unknown ids.

    Returns:
        int64 array of kept indices into the ORIGINAL (unsorted) arrays.
    """
    if len(bboxes) == 0:
        return np.array([], dtype=np.int64)

    # Work in score-descending order; order[pos] maps a sorted position
    # back to the caller's original index.
    order = np.argsort(scores)[::-1]
    centers = bboxes[order][:, :2]
    cls_sorted = labels[order]

    removed = np.zeros(len(order), dtype=bool)
    kept = []

    for pos in range(len(order)):
        if removed[pos]:
            continue
        kept.append(order[pos])

        cls = int(cls_sorted[pos])
        radius = dist_thrs.get(cls, 1.0)

        if pos + 1 < len(order):
            gaps = np.linalg.norm(centers[pos + 1:] - centers[pos], axis=1)
            hit = (gaps < radius) & (cls_sorted[pos + 1:] == cls)
            removed[pos + 1:][hit] = True

    return np.array(kept, dtype=np.int64)
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def denormalize_img_np(img_array, img_norm_cfg):
    """Invert normalization: (C, H, W) RGB floats -> (H, W, C) uint8 BGR."""
    cfg_mean = np.array(img_norm_cfg.get('mean', [123.675, 116.28, 103.53]))
    cfg_std = np.array(img_norm_cfg.get('std', [58.395, 57.12, 57.375]))

    # Channel-first input is moved to channel-last; anything else is
    # assumed to already be (H, W, C).
    hwc = img_array.transpose(1, 2, 0) if img_array.ndim == 3 else img_array

    # Undo (x - mean) / std, clamp to valid pixel range, convert for OpenCV
    restored = np.clip(hwc * cfg_std + cfg_mean, 0, 255).astype(np.uint8)
    return cv2.cvtColor(restored, cv2.COLOR_RGB2BGR)
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def draw_bev_map(bboxes, labels, scores, pc_range, bev_size=(800, 800), score_thr=0.1):
    """Draw BEV (Bird's Eye View) map with detections

    Fix over the previous version: frames with zero detections used to
    return BEFORE the final 90-degree rotation / horizontal flip / label,
    so empty frames were oriented differently from populated ones. The
    finalization now always runs. Unused local `score` removed.

    Args:
        bboxes: (N, 9) numpy array, format: [x, y, z, w, l, h, yaw, vx, vy]
        labels: (N,) numpy array, class labels
        scores: (N,) numpy array, detection scores
        pc_range: [x_min, y_min, z_min, x_max, y_max, z_max]
        bev_size: (width, height) of BEV image
        score_thr: score threshold; boxes at or below it are not drawn

    Returns:
        bev_img: (H, W, 3) numpy array, BEV visualization
    """
    bev_w, bev_h = bev_size  # BEV image size
    bev_img = np.ones((bev_h, bev_w, 3), dtype=np.uint8) * 255  # White background

    # Point-cloud extents define the metric <-> pixel mapping
    x_min, y_min, z_min, x_max, y_max, z_max = pc_range
    x_range = x_max - x_min
    y_range = y_max - y_min

    # Draw grid lines
    grid_color = (200, 200, 200)  # Light gray grid lines
    for i in range(-5, 6):
        x = x_min + (i + 5) * x_range / 10
        y = y_min + (i + 5) * y_range / 10
        # Vertical lines (y direction in LiDAR -> x direction in image)
        img_x = int((y - y_min) / y_range * bev_w)
        if 0 <= img_x < bev_w:
            cv2.line(bev_img, (img_x, 0), (img_x, bev_h), grid_color, 1)
        # Horizontal lines (x direction in LiDAR -> y direction in image, flipped)
        img_y = int((x_max - x) / x_range * bev_h)
        if 0 <= img_y < bev_h:
            cv2.line(bev_img, (0, img_y), (bev_w, img_y), grid_color, 1)

    # Draw center lines (ego vehicle position) - darker on white background
    center_x = int((0 - y_min) / y_range * bev_w)
    center_y = int((x_max - 0) / x_range * bev_h)
    cv2.line(bev_img, (center_x, 0), (center_x, bev_h), (150, 150, 150), 2)
    cv2.line(bev_img, (0, center_y), (bev_w, center_y), (150, 150, 150), 2)

    # Ego-vehicle marker, sized directly in pixels
    ego_length_px = 30  # pixels (representing ~4.5m, along x-axis rightward)
    ego_width_px = 12  # pixels (representing ~1.8m, along y-axis downward)

    ego_corners_local = np.array([
        [ego_length_px//2, -ego_width_px//2],  # front-top (head)
        [ego_length_px//2, ego_width_px//2],  # front-bottom
        [-ego_length_px//2, ego_width_px//2],  # back-bottom
        [-ego_length_px//2, -ego_width_px//2],  # back-top
    ], dtype=np.float32)

    rotation_angle_90 = np.pi / 2  # 90 degrees in radians
    cos_rot_90 = np.cos(rotation_angle_90)
    sin_rot_90 = np.sin(rotation_angle_90)
    rot_mat_90 = np.array([[cos_rot_90, -sin_rot_90], [sin_rot_90, cos_rot_90]])

    # Two consecutive 90-degree rotations (= 180 degrees); kept as-is to
    # preserve the established marker orientation.
    ego_corners_rotated_90 = ego_corners_local @ rot_mat_90.T
    ego_corners_rotated = ego_corners_rotated_90 @ rot_mat_90.T

    # Translate to image coordinates (center position)
    ego_corners = []
    for corner in ego_corners_rotated:
        corner_img_x = int(center_x + corner[0])
        corner_img_y = int(center_y + corner[1])
        ego_corners.append([corner_img_x, corner_img_y])
    ego_corners = np.array(ego_corners, dtype=np.int32)

    # Draw filled rectangle for the ego vehicle
    cv2.fillPoly(bev_img, [ego_corners], (0, 0, 255))  # Red filled
    cv2.polylines(bev_img, [ego_corners], True, (0, 0, 0), 2)  # Black outline

    # Ego heading arrow (same double-90-degree rotation as the marker)
    arrow_length = ego_length_px // 2
    initial_direction = np.array([1.0, 0.0])
    arrow_dir_rotated_90 = initial_direction @ rot_mat_90.T
    arrow_dir_rotated = arrow_dir_rotated_90 @ rot_mat_90.T
    arrow_end_x = int(center_x + arrow_length * arrow_dir_rotated[0])
    arrow_end_y = int(center_y + arrow_length * arrow_dir_rotated[1])
    cv2.arrowedLine(bev_img, (center_x, center_y), (arrow_end_x, arrow_end_y),
                    (0, 0, 0), 3, tipLength=0.3)  # Black arrow

    # Filter detections by score (guarded so empty inputs skip drawing but
    # still fall through to the shared finalization below)
    if len(bboxes) > 0 and score_thr > 0:
        mask = scores > score_thr
        bboxes = bboxes[mask]
        labels = labels[mask]
        scores = scores[mask]

    default_color = (255, 255, 255)

    for i in range(len(bboxes)):
        box = bboxes[i]
        label = int(labels[i])
        color = CLASS_COLORS.get(label, default_color)

        x, y, z = box[0], box[1], box[2]  # center position
        w, l, h = box[3], box[4], box[5]  # width, length, height
        yaw = box[6]  # yaw angle

        yaw = yaw - np.pi / 2.0  # Subtract 90 degrees (counterclockwise)

        # Convert to image coordinates
        # Note: In LiDAR coordinate, x is forward, y is left, z is up
        # In BEV image (top-down view):
        #   - x (forward) -> image y (downward, flipped)
        #   - y (left) -> image x (rightward)
        img_x = int((y - y_min) / y_range * bev_w)
        img_y = int((x_max - x) / x_range * bev_h)  # Flip x for top-down view

        # Skip if outside image
        if not (0 <= img_x < bev_w and 0 <= img_y < bev_h):
            continue

        # Calculate box dimensions in image space
        box_w_px = int(w / x_range * bev_w)
        box_l_px = int(l / y_range * bev_h)

        # Draw rotated rectangle: 4 corners in LiDAR coordinates
        cos_yaw = np.cos(yaw)
        sin_yaw = np.sin(yaw)

        # Box corners relative to center (in LiDAR frame: x forward, y left)
        corners_local = np.array([
            [l/2, w/2],  # front-right
            [l/2, -w/2],  # front-left
            [-l/2, -w/2],  # back-left
            [-l/2, w/2]  # back-right
        ])

        # Rotate corners
        rot_mat = np.array([[cos_yaw, -sin_yaw], [sin_yaw, cos_yaw]])
        corners_rotated = corners_local @ rot_mat.T

        # Translate to world coordinates and convert to image space
        corners_img = []
        for corner in corners_rotated:
            corner_x = x + corner[0]  # x in LiDAR (forward)
            corner_y = y + corner[1]  # y in LiDAR (left)
            corner_img_x = int((corner_y - y_min) / y_range * bev_w)  # y -> img_x
            corner_img_y = int((x_max - corner_x) / x_range * bev_h)  # x -> img_y (flipped)
            corners_img.append([corner_img_x, corner_img_y])

        corners_img = np.array(corners_img, dtype=np.int32)

        # Draw filled polygon (semi-transparent on white background)
        overlay = bev_img.copy()
        cv2.fillPoly(overlay, [corners_img], color)
        cv2.addWeighted(overlay, 0.5, bev_img, 0.5, 0, bev_img)
        # Draw outline (black on white background)
        cv2.polylines(bev_img, [corners_img], True, (0, 0, 0), 2)

        # Draw direction arrow (forward direction) - black on white
        # In LiDAR: forward is +x, left is +y
        # In BEV image: x -> img_y (flipped), y -> img_x
        arrow_length = max(box_l_px // 2, 10)
        arrow_end_x = int(img_x + arrow_length * sin_yaw)  # y component -> img_x
        arrow_end_y = int(img_y - arrow_length * cos_yaw)  # x component -> img_y (flipped)
        cv2.arrowedLine(bev_img, (img_x, img_y), (arrow_end_x, arrow_end_y),
                        (0, 0, 0), 2, tipLength=0.3)  # Black arrow

        # Draw center point
        cv2.circle(bev_img, (img_x, img_y), 3, (0, 0, 0), -1)  # Black center point

    # Finalization — now ALWAYS runs, so empty frames match populated ones.
    # Rotate BEV map counterclockwise by 90 degrees (map only, not text)
    center = (bev_w // 2, bev_h // 2)
    rotation_matrix = cv2.getRotationMatrix2D(center, 90, 1.0)  # 90 degrees counterclockwise
    bev_img = cv2.warpAffine(bev_img, rotation_matrix, (bev_w, bev_h), borderValue=(255, 255, 255))

    # Flip horizontally to fix mirror effect
    bev_img = cv2.flip(bev_img, 1)  # 1 for horizontal flip

    # Corner label, drawn after the rotation so it stays upright
    text = 'BEV Map'
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    thickness = 2
    (text_width, text_height), baseline = cv2.getTextSize(text, font, font_scale, thickness)
    text_x = bev_w - text_width - 10
    text_y = text_height + 10
    cv2.putText(bev_img, text, (text_x, text_y), font, font_scale, (0, 0, 0), thickness)

    return bev_img
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
def visualize_results_np(img, result, lidar2img, img_norm_cfg, class_names, score_thr=0.3, pc_range=None):
    """Render 3D detection results onto the camera images and a BEV map.

    Args:
        img: input image array; shape (1, num_cams, C, H, W) when 5-D,
            otherwise treated as a single camera.
        result: dict with optional 'boxes_3d', 'scores_3d', 'labels_3d' arrays.
        lidar2img: per-camera projection matrices, indexed as lidar2img[0, cam_idx].
        img_norm_cfg: normalization config forwarded to denormalize_img_np.
        class_names: class-name list (colors come from CLASS_COLORS, keyed by label id).
        score_thr: detections with score <= score_thr are dropped; a value <= 0
            (or missing scores) keeps every box.
        pc_range: point-cloud range [x_min, y_min, z_min, x_max, y_max, z_max];
            defaults to the 51.2 m nuScenes square.

    Returns:
        One stitched np.uint8 image: a 2x3 camera grid for 6 cameras (or a
        horizontal strip otherwise) with the BEV map appended on the right.
    """
    num_cams = img.shape[1] if img.ndim == 5 else 1
    raw_imgs = [denormalize_img_np(img[0, cam_idx], img_norm_cfg) for cam_idx in range(num_cams)]
    boxes_3d = result.get('boxes_3d')
    scores_3d = result.get('scores_3d')
    labels_3d = result.get('labels_3d')
    vis_imgs = []
    boxes_3d_for_bev = labels_3d_for_bev = scores_3d_for_bev = None

    if boxes_3d is not None and len(boxes_3d) > 0:
        # Build the keep-all mask from the box count: np.ones_like(scores_3d)
        # would raise when scores_3d is None (e.g. score_thr <= 0 with no scores).
        if score_thr > 0 and scores_3d is not None:
            mask = scores_3d > score_thr
        else:
            mask = np.ones(len(boxes_3d), dtype=bool)
        if np.any(mask):
            boxes_3d = boxes_3d[mask]
            if scores_3d is not None:
                scores_3d = scores_3d[mask]
            labels_3d = labels_3d[mask]
            # Keep copies for the BEV rendering below; the per-camera loop
            # reuses the filtered arrays directly.
            boxes_3d_for_bev = boxes_3d.copy()
            labels_3d_for_bev = labels_3d.copy()
            scores_3d_for_bev = scores_3d.copy() if scores_3d is not None else None
            for cam_idx, vis_img in enumerate(raw_imgs):
                vis_img = vis_img.copy()
                if lidar2img.shape[1] > cam_idx:
                    cam_lidar2img = lidar2img[0, cam_idx]
                    for box, label in zip(boxes_3d, labels_3d):
                        color = CLASS_COLORS.get(int(label), (255, 255, 255))
                        try:
                            vis_img = draw_bbox3d_on_img_custom_np(box[None], vis_img, cam_lidar2img, color=color, thickness=2)
                        except Exception:
                            # Best-effort drawing: a box that fails to project
                            # must not abort the whole visualization.
                            pass
                vis_imgs.append(vis_img)
        else:
            # Boxes exist but none pass the threshold: show the raw images
            # rather than drawing sub-threshold boxes or emitting no cameras.
            vis_imgs = raw_imgs
    else:
        vis_imgs = raw_imgs

    if pc_range is None:
        pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]

    # BEV panel: square with side equal to the camera image height.
    if boxes_3d_for_bev is not None and len(boxes_3d_for_bev) > 0:
        bev_size = (vis_imgs[0].shape[0], vis_imgs[0].shape[0]) if vis_imgs else (800, 800)
        bev_img = draw_bev_map(boxes_3d_for_bev, labels_3d_for_bev, scores_3d_for_bev, pc_range, bev_size=bev_size, score_thr=score_thr)
    else:
        bev_size = (vis_imgs[0].shape[0], vis_imgs[0].shape[0]) if vis_imgs else (800, 800)
        bev_img = np.full((bev_size[1], bev_size[0], 3), 255, np.uint8)
        cv2.putText(bev_img, 'BEV Map (No Detections)', (10, bev_size[1]//2), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 2)

    if len(vis_imgs) == 6:
        # 6-camera layout (nuScenes-style): equalize heights, reorder to
        # [front-left, front, front-right] / [back-left, back, back-right],
        # mirroring the rear cameras so left/right match the driver's view.
        target_height = max(cam.shape[0] for cam in vis_imgs)
        resized_imgs = [cam if cam.shape[0] == target_height else cv2.resize(cam, (int(cam.shape[1] * target_height / cam.shape[0]), target_height)) for cam in vis_imgs]

        reordered_imgs = [
            resized_imgs[2], resized_imgs[0], resized_imgs[1],
            cv2.flip(resized_imgs[4], 1), cv2.flip(resized_imgs[3], 1), cv2.flip(resized_imgs[5], 1)
        ]
        top_row = np.hstack(reordered_imgs[:3])
        bottom_row = np.hstack(reordered_imgs[3:])
        left_side = np.vstack([top_row, bottom_row])
        bev_img = cv2.resize(bev_img, (int(bev_img.shape[1] * left_side.shape[0] / bev_img.shape[0]), left_side.shape[0]))
        vis_img = np.hstack([left_side, bev_img])
    elif len(vis_imgs) > 1:
        # Generic multi-camera layout: one horizontal strip plus the BEV map.
        target_height = max(cam.shape[0] for cam in vis_imgs)
        resized_imgs = [cam if cam.shape[0] == target_height else cv2.resize(cam, (int(cam.shape[1] * target_height / cam.shape[0]), target_height)) for cam in vis_imgs]
        if bev_img.shape[0] != target_height:
            bev_img = cv2.resize(bev_img, (int(bev_img.shape[1] * target_height / bev_img.shape[0]), target_height))
        vis_img = np.hstack([np.hstack(resized_imgs), bev_img])
    else:
        # Single camera (or none): camera next to BEV, or BEV alone.
        cam_img = vis_imgs[0] if vis_imgs else bev_img
        if bev_img.shape[0] != cam_img.shape[0]:
            bev_img = cv2.resize(bev_img, (int(bev_img.shape[1] * cam_img.shape[0] / bev_img.shape[0]), cam_img.shape[0]))
        vis_img = np.hstack([cam_img, bev_img]) if vis_imgs else bev_img

    return vis_img
|
| 797 |
+
|
| 798 |
+
|
| 799 |
+
def create_video_from_images(image_dir, output_video_path, fps=3):
    """Encode all still images in *image_dir* into one video file.

    Frames are read in sorted filename order, resized to the first frame's
    (1080p-capped) resolution, and written with OpenCV's VideoWriter.

    Args:
        image_dir: directory containing .png/.jpg/.jpeg frames.
        output_video_path: path of the video to write (e.g. a .mp4 file).
        fps: output frame rate.
    """
    # NOTE: the original `import subprocess` was unused and has been removed.
    image_files = sorted(f for f in os.listdir(image_dir) if f.endswith(('.png', '.jpg', '.jpeg')))
    if not image_files:
        return  # nothing to encode

    first_img = cv2.imread(osp.join(image_dir, image_files[0]))
    if first_img is None:
        return  # first frame unreadable; cannot determine output size

    height, width = first_img.shape[:2]

    # Cap the output at 1080p while preserving the aspect ratio.
    max_width, max_height = 1920, 1080
    if width > max_width or height > max_height:
        scale = min(max_width / width, max_height / height)
        width, height = int(width * scale), int(height * scale)

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
    if not video_writer.isOpened():
        # mp4v codec unavailable in this OpenCV build: fall back to XVID.
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        video_writer = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
    if not video_writer.isOpened():
        return  # no usable codec; avoid silently "writing" to a dead writer

    try:
        for img_file in tqdm(image_files, desc=f"Creating video: {osp.basename(output_video_path)}"):
            img_path = osp.join(image_dir, img_file)
            img = cv2.imread(img_path)
            if img is not None:
                if img.shape[:2] != (height, width):
                    img = cv2.resize(img, (width, height))
                video_writer.write(img)
    finally:
        # Release even if a frame raises, so the container is finalized.
        video_writer.release()
|
| 832 |
+
|
| 833 |
+
def main():
    """Run AXModel BEV inference over all scenes and save per-scene visualizations and videos."""
    args = parse_args()

    # Load configuration from JSON
    config = load_config_from_json(args.config_json)

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Load AXModel inference session
    ax_session = load_axmodel(args.model)

    # Get BEV grid size and embedding width from the transformer config;
    # these determine the expected prev_bev tensor shape below.
    transformer_cfg = config['model']['transformer']
    bev_h = transformer_cfg['bev_h']
    bev_w = transformer_cfg['bev_w']
    embed_dims = transformer_cfg['embed_dims']

    # Load scene index: maps scene name -> {'samples': [frame indices, ...]}
    scene_index_path = osp.join(args.data_dir, 'scene_index.json')
    with open(scene_index_path, 'r') as f:
        scene_index_data = json.load(f)

    scenes_dict = scene_index_data['scenes']
    scene_names = list(scenes_dict.keys())

    # Clamp the requested scene range to what actually exists.
    end_scene = args.end_scene if args.end_scene is not None else len(scene_names)
    end_scene = min(end_scene, len(scene_names))

    # Temporal state carried between consecutive frames of the same scene:
    # the previous BEV feature and the ego pose/yaw used to compute deltas.
    prev_frame_info = {
        'prev_bev': None,
        'scene_token': None,
        'prev_pos': np.zeros(3, dtype=np.float32),
        'prev_angle': 0.0,
    }

    scene_results = defaultdict(list)

    # Process all scenes
    for scene_idx in range(args.start_scene, end_scene):
        scene_name = scene_names[scene_idx]
        scene_info = scenes_dict[scene_name]
        sample_indices = scene_info['samples']
        num_frames = len(sample_indices)

        print(f"Processing scene {scene_idx+1}/{len(scene_names)}: {scene_name} ({num_frames} frames)")

        # Reset temporal state at every scene boundary: the previous scene's
        # BEV features and ego pose must not leak into the new scene.
        if scene_name != prev_frame_info['scene_token']:
            prev_frame_info['prev_bev'] = None
            prev_frame_info['prev_pos'] = np.zeros(3, dtype=np.float32)
            prev_frame_info['prev_angle'] = 0.0

        prev_frame_info['scene_token'] = scene_name

        # Process all frames in this scene
        for local_idx, frame_idx in enumerate(tqdm(sample_indices, desc=f"Scene {scene_name}")):
            # Load one frame: images, projection matrices, CAN bus signals, metadata.
            img, lidar2img, can_bus, meta = load_data(args.data_dir, scene_name, frame_idx)

            # Process can_bus: the model consumes position/yaw *deltas*
            # relative to the previous frame, not absolute values.
            curr_can_bus_np = can_bus[0]  # (18,)

            tmp_pos = curr_can_bus_np[:3].copy()
            tmp_angle = curr_can_bus_np[-1]

            delta_can_bus_np = curr_can_bus_np.copy()

            if prev_frame_info['prev_bev'] is not None and prev_frame_info['scene_token'] == scene_name:
                # Subsequent frame of the same scene: delta = current - previous.
                delta_can_bus_np[:3] -= prev_frame_info['prev_pos']
                delta_can_bus_np[-1] -= prev_frame_info['prev_angle']
            else:
                # First frame of a scene: no previous pose, so zero deltas.
                delta_can_bus_np[:3] = 0.0
                delta_can_bus_np[-1] = 0.0

            prev_frame_info['prev_pos'] = tmp_pos
            prev_frame_info['prev_angle'] = tmp_angle

            # Resolve the prev_bev input shape: prefer the model's declared
            # shape, substituting (bev_h*bev_w, 1, embed_dims) for any
            # dynamic (string/negative) dimensions.
            prev_bev_input = next((inp for inp in ax_session.get_inputs() if inp.name == 'prev_bev'), None)
            expected_shape = (bev_h * bev_w, 1, embed_dims)
            if prev_bev_input is not None:
                expected_shape = list(prev_bev_input.shape)
                for i, dim in enumerate(expected_shape):
                    if isinstance(dim, str) or dim < 0:
                        expected_shape[i] = (bev_h * bev_w, 1, embed_dims)[i] if i < 3 else 1
                expected_shape = tuple(expected_shape)

            if prev_frame_info['prev_bev'] is None:
                # No history yet: feed an all-zero BEV feature.
                prev_bev = np.zeros(expected_shape, dtype=np.float32)
            else:
                prev_bev = prev_frame_info['prev_bev']
                if prev_bev.shape != expected_shape and len(prev_bev.shape) == 3:
                    prev_bev = prev_bev.reshape(expected_shape)

            # Prepare AXEngine inputs (all float32).
            img_np = img.astype(np.float32)
            lidar2img_np = lidar2img.astype(np.float32)
            can_bus_np = delta_can_bus_np.reshape(1, -1).astype(np.float32)

            # Only feed inputs the session actually declares.
            input_names = [inp.name for inp in ax_session.get_inputs()]
            ax_inputs = {}
            for name in input_names:
                if name == 'img':
                    ax_inputs['img'] = img_np
                elif name == 'can_bus':
                    ax_inputs['can_bus'] = can_bus_np
                elif name == 'lidar2img':
                    ax_inputs['lidar2img'] = lidar2img_np
                elif name == 'prev_bev':
                    ax_inputs['prev_bev'] = prev_bev

            # Run inference
            ax_outputs = ax_session.run(None, ax_inputs)
            bev_embed, all_cls_scores, all_bbox_preds = ax_outputs

            # Carry this frame's BEV feature forward as the next frame's history.
            prev_frame_info['prev_bev'] = bev_embed

            # Post-process raw head outputs into per-sample detection dicts.
            results = post_process_outputs_np(
                all_cls_scores, all_bbox_preds, config, args.score_thr
            )

            # Visualize: stitched camera grid + BEV map for this frame.
            img_norm_cfg = config['img_norm']
            class_names = config['dataset']['class_names']
            pc_range = config['model']['bbox_coder']['pc_range']
            vis_img = visualize_results_np(
                img, results[0], lidar2img, img_norm_cfg, class_names, args.score_thr, pc_range=pc_range
            )

            scene_results[scene_name].append({
                'frame_idx': local_idx,
                'result': results[0],
                'vis_img': vis_img,
                'meta': meta
            })

    # Save results: per scene, write each frame as a PNG and then encode a video.
    for scene_name, frames in tqdm(scene_results.items(), desc="Save scene results"):
        scene_dir = osp.join(args.output_dir, scene_name)
        os.makedirs(scene_dir, exist_ok=True)
        images_dir = osp.join(scene_dir, 'images')
        os.makedirs(images_dir, exist_ok=True)

        for local_idx, frame_data in enumerate(frames):
            vis_img = frame_data['vis_img']

            if vis_img is None:
                continue

            if not isinstance(vis_img, np.ndarray):
                vis_img = np.array(vis_img)

            # Normalize to uint8: scale up [0, 1] float images, cast the rest.
            if vis_img.dtype != np.uint8:
                vis_img = (vis_img * 255).astype(np.uint8) if vis_img.max() <= 1.0 else vis_img.astype(np.uint8)

            # Convert channel-first (C, H, W) layouts to HWC for cv2.imwrite.
            if len(vis_img.shape) == 3 and vis_img.shape[0] in (1, 3):
                vis_img = vis_img.transpose(1, 2, 0)

            if vis_img.shape[0] > 0 and vis_img.shape[1] > 0:
                cv2.imwrite(osp.join(images_dir, f'frame_{local_idx:06d}.png'), vis_img)

        create_video_from_images(images_dir, osp.join(scene_dir, f'{scene_name}_result.mp4'), args.fps)
        print(f"✓ Scene {scene_name}: {len(frames)} frames, video: {osp.join(scene_dir, f'{scene_name}_result.mp4')}")
|
| 998 |
+
|
| 999 |
+
|
| 1000 |
+
# Script entry point: run the full inference + visualization pipeline.
if __name__ == '__main__':
    main()
|
| 1002 |
+
|