File size: 7,567 Bytes
3ea6165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""
Export SparseBEV to ONNX for inference via ONNX Runtime CoreML EP.

Usage:
    python export_onnx.py \
        --config configs/r50_nuimg_704x256_400q_36ep.py \
        --weights checkpoints/r50_nuimg_704x256_400q_36ep.pth \
        --out sparsebev.onnx

Then run with CoreML EP (the model is written into --out-dir, 'exports' by default):
    import onnxruntime as ort, numpy as np
    sess = ort.InferenceSession('exports/sparsebev.onnx',
                                providers=['CoreMLExecutionProvider',
                                           'CPUExecutionProvider'])
    outputs = sess.run(None, {'img': img_np, 'lidar2img': l2i_np, 'time_diff': td_np})
    cls_scores, bbox_preds = outputs  # raw logits, apply NMSFreeCoder.decode() separately

Input format (all float32 numpy arrays):
    img        [1, 48, 3, 256, 704]  BGR, pixel values in [0, 255]
    lidar2img  [1, 48, 4, 4]         LiDAR-to-image projection matrices
    time_diff  [1, 8]                seconds since frame-0, one value per frame
                                     (frame 0 = 0.0, frame k = timestamp[0] - timestamp[k])
"""

import argparse
import sys
from unittest.mock import MagicMock

# mmcv is installed without compiled C++ ops (no mmcv-full on macOS).
# SparseBEV doesn't use any mmcv ops at inference time, so stub out the
# missing extension module before anything else imports mmcv.ops.
sys.modules['mmcv._ext'] = MagicMock()

import torch
import numpy as np

# Register all custom mmdet3d modules by importing the local package
sys.path.insert(0, '.')
import models  # noqa: F401  triggers __init__.py which registers DETECTORS etc.

from mmcv import Config
from mmdet3d.models import build_detector
from mmcv.runner import load_checkpoint
from models.onnx_wrapper import SparseBEVOnnxWrapper


def parse_args():
    """Parse the command-line options of the export script.

    Returns:
        argparse.Namespace with the attributes ``config``, ``weights``,
        ``out_dir``, ``out``, ``opset`` (int) and ``validate`` (bool).
    """
    cli = argparse.ArgumentParser()

    # Plain string-valued options as (flag, default, help) triples; a help of
    # None is argparse's own default, i.e. the option has no help text.
    string_options = (
        ('--config', 'configs/r50_nuimg_704x256_400q_36ep.py', None),
        ('--weights', 'checkpoints/r50_nuimg_704x256_400q_36ep.pth', None),
        ('--out-dir', 'exports', 'Directory to write the ONNX model into'),
        ('--out', None,
         'Override output filename (default: derived from config + opset)'),
    )
    for flag, default, help_text in string_options:
        cli.add_argument(flag, default=default, help=help_text)

    cli.add_argument(
        '--opset',
        type=int,
        default=18,
        help='ONNX opset version (18 recommended for torch 2.x)',
    )
    cli.add_argument(
        '--validate',
        action='store_true',
        help='Run ORT inference and compare to PyTorch output',
    )
    return cli.parse_args()


def build_dummy_inputs(num_frames=8, num_cameras=6, H=256, W=704):
    """Build placeholder ``(img, lidar2img, time_diff)`` tensors for export / validation.

    Returns:
        A 3-tuple of float tensors with batch size 1:
          * img       -- zeros of shape [1, num_frames * num_cameras, 3, H, W]
          * lidar2img -- 4x4 identity matrices, shape [1, num_frames * num_cameras, 4, 4]
          * time_diff -- zeros of shape [1, num_frames]
    """
    num_views = num_frames * num_cameras

    # repeat() materialises one identity matrix per view in a fresh,
    # contiguous tensor of shape [1, num_views, 4, 4].
    identity_per_view = torch.eye(4).repeat(1, num_views, 1, 1)

    return (
        torch.zeros(1, num_views, 3, H, W),
        identity_per_view,
        torch.zeros(1, num_frames),
    )


def _resolve_output_path(out_dir, out, config, opset):
    """Create ``out_dir`` if needed and return the ONNX file path inside it.

    Args:
        out_dir: Directory that receives the exported model.
        out: Optional filename override. Only its basename is used, so any
            directory component is dropped in favour of ``out_dir``.
        config: Path of the model config; its stem names the default file.
        opset: ONNX opset version, embedded in the default filename.

    Returns:
        The output path as a string, always located directly in ``out_dir``.
    """
    import os

    os.makedirs(out_dir, exist_ok=True)

    if out is None:
        # Derive a descriptive name from the config stem.
        # e.g. configs/r50_nuimg_704x256_400q_36ep.py
        #   -> sparsebev_r50_nuimg_704x256_400q_36ep_opset18.onnx
        config_stem = os.path.splitext(os.path.basename(config))[0]
        filename = f'sparsebev_{config_stem}_opset{opset}.onnx'
    else:
        filename = os.path.basename(out)
    return os.path.join(out_dir, filename)


def _max_abs_diff(reference, actual):
    """Return the largest element-wise |reference - actual|.

    ``reference`` is a torch tensor, ``actual`` a numpy array of the same shape.
    """
    return np.abs(reference.numpy() - actual).max()


def main():
    """Export the wrapped detector to ONNX and optionally validate the result.

    Steps, in order:
      1. Parse the CLI options and resolve the output path inside ``--out-dir``.
      2. Build the detector from the config, load the checkpoint on CPU and
         wrap it in ``SparseBEVOnnxWrapper``.
      3. Run one PyTorch forward pass on dummy inputs as the numerical reference.
      4. Export with ``torch.onnx.export`` and run the ONNX checker on the file.
      5. With ``--validate``: run the model through ONNX Runtime on the CPU
         provider and, when the installed onnxruntime build offers it, on the
         CoreML provider, printing the max absolute difference to the
         PyTorch reference each time.

    A validation mismatch only prints a warning; it does not raise.
    """
    args = parse_args()

    # ------------------------------------------------------------------ #
    # Resolve output path
    # ------------------------------------------------------------------ #
    args.out = _resolve_output_path(args.out_dir, args.out, args.config, args.opset)

    # ------------------------------------------------------------------ #
    # Load model
    # ------------------------------------------------------------------ #
    cfg = Config.fromfile(args.config)
    model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
    load_checkpoint(model, args.weights, map_location='cpu')
    model.eval()

    wrapper = SparseBEVOnnxWrapper(model).eval()

    # ------------------------------------------------------------------ #
    # Dummy inputs
    # ------------------------------------------------------------------ #
    img, lidar2img, time_diff = build_dummy_inputs()

    # ------------------------------------------------------------------ #
    # Reference PyTorch forward (for later numerical comparison)
    # ------------------------------------------------------------------ #
    with torch.no_grad():
        ref_cls, ref_bbox = wrapper(img, lidar2img, time_diff)
    print(f'PyTorch output shapes: cls={tuple(ref_cls.shape)}  bbox={tuple(ref_bbox.shape)}')

    # ------------------------------------------------------------------ #
    # ONNX export
    # ------------------------------------------------------------------ #
    print(f'Exporting to {args.out} (opset {args.opset}) …')
    torch.onnx.export(
        wrapper,
        (img, lidar2img, time_diff),
        args.out,
        opset_version=args.opset,
        input_names=['img', 'lidar2img', 'time_diff'],
        output_names=['cls_scores', 'bbox_preds'],
        do_constant_folding=True,
        verbose=False,
    )
    print('Export done.')

    # ------------------------------------------------------------------ #
    # ONNX model check
    # ------------------------------------------------------------------ #
    # Imported lazily so the module can be imported without onnx installed.
    import onnx
    model_proto = onnx.load(args.out)
    onnx.checker.check_model(model_proto)
    print('ONNX checker passed.')

    # ------------------------------------------------------------------ #
    # Optional: validate ORT CPU output against PyTorch
    # ------------------------------------------------------------------ #
    if not args.validate:
        return

    # onnxruntime is only required when --validate is requested.
    import onnxruntime as ort

    print('Running ORT CPU validation …')
    sess = ort.InferenceSession(args.out, providers=['CPUExecutionProvider'])
    feeds = {
        'img':       img.numpy(),
        'lidar2img': lidar2img.numpy(),
        'time_diff': time_diff.numpy(),
    }
    ort_cls, ort_bbox = sess.run(None, feeds)

    cls_diff = _max_abs_diff(ref_cls, ort_cls)
    bbox_diff = _max_abs_diff(ref_bbox, ort_bbox)
    print(f'Max absolute diff — cls: {cls_diff:.6f}   bbox: {bbox_diff:.6f}')

    if cls_diff < 5e-2 and bbox_diff < 5e-2:
        print('Validation PASSED.')
    else:
        print('WARNING: diff is larger than expected — check for unsupported ops.')

    # ------------------------------------------------------------------ #
    # CoreML EP — only meaningful when this onnxruntime build ships it.
    # Requesting an unavailable provider makes ORT warn and run on the
    # remaining providers, which would report CPU numbers under a
    # "CoreML EP" label.
    # ------------------------------------------------------------------ #
    if 'CoreMLExecutionProvider' not in ort.get_available_providers():
        print('\nCoreMLExecutionProvider is not available in this onnxruntime '
              'build — skipping CoreML EP run.')
        return

    # ------------------------------------------------------------------ #
    # CoreML EP — must pass MLComputeUnits explicitly; without it ORT
    # discards the EP entirely on first partition error instead of falling
    # back per-node to the CPU provider.
    # ------------------------------------------------------------------ #
    print('\nRunning CoreML EP …')
    sess_cml = ort.InferenceSession(
        args.out,
        providers=[
            ('CoreMLExecutionProvider', {'MLComputeUnits': 'ALL'}),
            'CPUExecutionProvider',
        ],
    )
    cml_cls, cml_bbox = sess_cml.run(None, feeds)
    cml_cls_diff = _max_abs_diff(ref_cls, cml_cls)
    cml_bbox_diff = _max_abs_diff(ref_bbox, cml_bbox)
    print(f'CoreML EP max diff — cls: {cml_cls_diff:.6f}   bbox: {cml_bbox_diff:.6f}')


# Run the export only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()