chinmaygarde committed
Commit 3ea6165 · unverified · 1 Parent(s): 6d2df77

Attempt an export to ONNX.
.gitattributes CHANGED
@@ -1,3 +1,4 @@
+exports/** filter=lfs diff=lfs merge=lfs -text
 *.jpg filter=lfs diff=lfs merge=lfs -text
 *.jpeg filter=lfs diff=lfs merge=lfs -text
 *.png filter=lfs diff=lfs merge=lfs -text
.gitignore CHANGED
@@ -51,3 +51,4 @@ checkpoints
 pretrain
 *.png
 *.jpg
+.claude/settings.local.json
.python-version ADDED
@@ -0,0 +1 @@
+3.12
README.md CHANGED
@@ -169,6 +169,72 @@ Visualize the sampling points (like Fig. 6 in the paper):
 python viz_sample_points.py --config configs/r50_nuimg_704x256.py --weights checkpoints/r50_nuimg_704x256.pth
 ```
 
+## Changes from upstream
+
+This fork adds ONNX export support targeting [ONNX Runtime's CoreML Execution Provider](https://onnxruntime.ai/docs/execution-providers/CoreML-ExecutionProvider.html) for inference on Apple Silicon (Mac Studio).
+
+### Dependency management
+
+- `pyproject.toml` / `uv.lock` — project dependencies managed with [uv](https://docs.astral.sh/uv/)
+- `justfile` — task runner for common operations
+
+### ONNX export
+
+Three code changes were required to make the model traceable with `torch.onnx.export`:
+
+**`models/sparsebev_sampling.py`** — `sampling_4d()`
+- Replaced 6-dimensional advanced tensor indexing (not supported by the ONNX tracer) with `torch.gather` for best-view selection
+
+**`models/csrc/wrapper.py`** — new `msmv_sampling_onnx()`
+- Added an ONNX-compatible sampling path that uses 4D `F.grid_sample` (ONNX opset 16+) and `torch.gather` for view selection, replacing the original 5D volumetric `grid_sample`, which is not in the ONNX spec
+- The existing CUDA kernel path (`msmv_sampling` / `msmv_sampling_pytorch`) is preserved and used when CUDA is available
+
+**`models/sparsebev_transformer.py`**
+- `SparseBEVTransformerDecoder.forward()`: added a fast path that accepts pre-computed `time_diff` and `lidar2img` tensors directly, bypassing the NumPy preprocessing that is not traceable
+- `SparseBEVTransformerDecoderLayer.forward()`: replaced a masked in-place assignment (`tensor[mask] = value`) with `torch.where`, which is ONNX-compatible
+- `SparseBEVSelfAttention.calc_bbox_dists()`: replaced a Python loop over the batch dimension with a vectorised `torch.norm` using broadcasting
+
+### New files
+
+| File | Purpose |
+|------|---------|
+| `export_onnx.py` | Exports the model to ONNX, runs ORT CPU + CoreML EP validation |
+| `models/onnx_wrapper.py` | Thin `nn.Module` wrapper that accepts pre-computed tensors instead of `img_metas` dicts |
+| `justfile` | `just onnx_export` / `just onnx_export_validate` |
+| `exports/` | ONNX model files tracked via Git LFS |
+
+### Running the export
+
+```bash
+just onnx_export
+# or with validation against PyTorch and CoreML EP:
+just onnx_export_validate
+```
+
+Exported models land in `exports/` as `sparsebev_{config}_opset{N}.onnx` (+ `.onnx.data` for weights).
+
+**Inference with ONNX Runtime:**
+
+```python
+import onnxruntime as ort
+sess = ort.InferenceSession(
+    'exports/sparsebev_r50_nuimg_704x256_400q_36ep_opset18.onnx',
+    providers=[('CoreMLExecutionProvider', {'MLComputeUnits': 'ALL'}),
+               'CPUExecutionProvider'],
+)
+cls_scores, bbox_preds = sess.run(None, {
+    'img': img_np,              # [1, 48, 3, 256, 704] float32 BGR
+    'lidar2img': lidar2img_np,  # [1, 48, 4, 4] float32
+    'time_diff': time_diff_np,  # [1, 8] float32, seconds since frame 0
+})
+# cls_scores: [6, 1, 400, 10] raw logits per decoder layer
+# bbox_preds: [6, 1, 400, 10] raw box params — decode with NMSFreeCoder
+```
+
+The `MLComputeUnits` option must be passed explicitly; without it, ONNX Runtime discards the CoreML EP on the first unsupported partition instead of falling back per-node.
+
+---
+
 ## Acknowledgements
 
 Many thanks to these excellent open-source projects:
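The `torch.where` rewrite described under `SparseBEVTransformerDecoderLayer.forward()` can be checked in isolation. A minimal sketch with toy values (not the model's real tensors):

```python
import torch

# Masked in-place assignment, as in the original code. The pattern
# `t[mask] = value` is not traceable by the ONNX exporter.
time_diff = torch.tensor([[0.0, 0.5, 1.0]])
inplace = time_diff.clone()
inplace[inplace < 1e-5] = 1.0

# Functional equivalent used by the fork: exports as a single ONNX Where.
functional = torch.where(time_diff < 1e-5,
                         torch.ones_like(time_diff), time_diff)

assert torch.equal(inplace, functional)
```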
export_onnx.py ADDED
@@ -0,0 +1,179 @@
+"""
+Export SparseBEV to ONNX for inference via ONNX Runtime CoreML EP.
+
+Usage:
+    python export_onnx.py \
+        --config configs/r50_nuimg_704x256_400q_36ep.py \
+        --weights checkpoints/r50_nuimg_704x256_400q_36ep.pth \
+        --out sparsebev.onnx
+
+Then run with CoreML EP:
+    import onnxruntime as ort, numpy as np
+    sess = ort.InferenceSession('sparsebev.onnx',
+                                providers=['CoreMLExecutionProvider',
+                                           'CPUExecutionProvider'])
+    outputs = sess.run(None, {'img': img_np, 'lidar2img': l2i_np, 'time_diff': td_np})
+    cls_scores, bbox_preds = outputs  # raw logits, apply NMSFreeCoder.decode() separately
+
+Input format (all float32 numpy arrays):
+    img        [1, 48, 3, 256, 704]  BGR, pixel values in [0, 255]
+    lidar2img  [1, 48, 4, 4]         LiDAR-to-image projection matrices
+    time_diff  [1, 8]                seconds since frame 0, one value per frame
+                                     (frame 0 = 0.0, frame k = timestamp[0] - timestamp[k])
+"""
+
+import argparse
+import sys
+from unittest.mock import MagicMock
+
+# mmcv is installed without compiled C++ ops (no mmcv-full on macOS).
+# SparseBEV doesn't use any mmcv ops at inference time, so stub out the
+# missing extension module before anything else imports mmcv.ops.
+sys.modules['mmcv._ext'] = MagicMock()
+
+import torch
+import numpy as np
+
+# Register all custom mmdet3d modules by importing the local package
+sys.path.insert(0, '.')
+import models  # noqa: F401  triggers __init__.py which registers DETECTORS etc.
+
+from mmcv import Config
+from mmdet3d.models import build_detector
+from mmcv.runner import load_checkpoint
+from models.onnx_wrapper import SparseBEVOnnxWrapper
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--config', default='configs/r50_nuimg_704x256_400q_36ep.py')
+    parser.add_argument('--weights', default='checkpoints/r50_nuimg_704x256_400q_36ep.pth')
+    parser.add_argument('--out-dir', default='exports',
+                        help='Directory to write the ONNX model into')
+    parser.add_argument('--out', default=None,
+                        help='Override output filename (default: derived from config + opset)')
+    parser.add_argument('--opset', type=int, default=18,
+                        help='ONNX opset version (18 recommended for torch 2.x)')
+    parser.add_argument('--validate', action='store_true',
+                        help='Run ORT inference and compare to PyTorch output')
+    return parser.parse_args()
+
+
+def build_dummy_inputs(num_frames=8, num_cameras=6, H=256, W=704):
+    """Return (img, lidar2img, time_diff) dummy tensors for export / validation."""
+    img = torch.zeros(1, num_frames * num_cameras, 3, H, W)
+    lidar2img = torch.eye(4).reshape(1, 1, 4, 4).expand(1, num_frames * num_cameras, 4, 4).contiguous()
+    time_diff = torch.zeros(1, num_frames)
+    return img, lidar2img, time_diff
+
+
+def main():
+    args = parse_args()
+
+    # ------------------------------------------------------------------ #
+    # Resolve output path
+    # ------------------------------------------------------------------ #
+    import os
+    os.makedirs(args.out_dir, exist_ok=True)
+
+    if args.out is None:
+        # Derive a descriptive name from the config stem.
+        # e.g. configs/r50_nuimg_704x256_400q_36ep.py
+        #   -> sparsebev_r50_nuimg_704x256_400q_36ep_opset18.onnx
+        config_stem = os.path.splitext(os.path.basename(args.config))[0]
+        args.out = os.path.join(args.out_dir,
+                                f'sparsebev_{config_stem}_opset{args.opset}.onnx')
+    else:
+        args.out = os.path.join(args.out_dir, os.path.basename(args.out))
+
+    # ------------------------------------------------------------------ #
+    # Load model
+    # ------------------------------------------------------------------ #
+    cfg = Config.fromfile(args.config)
+    model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
+    load_checkpoint(model, args.weights, map_location='cpu')
+    model.eval()
+
+    wrapper = SparseBEVOnnxWrapper(model).eval()
+
+    # ------------------------------------------------------------------ #
+    # Dummy inputs
+    # ------------------------------------------------------------------ #
+    img, lidar2img, time_diff = build_dummy_inputs()
+
+    # ------------------------------------------------------------------ #
+    # Reference PyTorch forward (for later numerical comparison)
+    # ------------------------------------------------------------------ #
+    with torch.no_grad():
+        ref_cls, ref_bbox = wrapper(img, lidar2img, time_diff)
+    print(f'PyTorch output shapes: cls={tuple(ref_cls.shape)} bbox={tuple(ref_bbox.shape)}')
+
+    # ------------------------------------------------------------------ #
+    # ONNX export
+    # ------------------------------------------------------------------ #
+    print(f'Exporting to {args.out} (opset {args.opset}) …')
+    torch.onnx.export(
+        wrapper,
+        (img, lidar2img, time_diff),
+        args.out,
+        opset_version=args.opset,
+        input_names=['img', 'lidar2img', 'time_diff'],
+        output_names=['cls_scores', 'bbox_preds'],
+        do_constant_folding=True,
+        verbose=False,
+    )
+    print('Export done.')
+
+    # ------------------------------------------------------------------ #
+    # ONNX model check
+    # ------------------------------------------------------------------ #
+    import onnx
+    model_proto = onnx.load(args.out)
+    onnx.checker.check_model(model_proto)
+    print('ONNX checker passed.')
+
+    # ------------------------------------------------------------------ #
+    # Optional: validate ORT CPU output against PyTorch
+    # ------------------------------------------------------------------ #
+    if args.validate:
+        import onnxruntime as ort
+
+        print('Running ORT CPU validation …')
+        sess = ort.InferenceSession(args.out, providers=['CPUExecutionProvider'])
+        feeds = {
+            'img': img.numpy(),
+            'lidar2img': lidar2img.numpy(),
+            'time_diff': time_diff.numpy(),
+        }
+        ort_cls, ort_bbox = sess.run(None, feeds)
+
+        cls_diff = np.abs(ref_cls.numpy() - ort_cls).max()
+        bbox_diff = np.abs(ref_bbox.numpy() - ort_bbox).max()
+        print(f'Max absolute diff — cls: {cls_diff:.6f}  bbox: {bbox_diff:.6f}')
+
+        if cls_diff < 5e-2 and bbox_diff < 5e-2:
+            print('Validation PASSED.')
+        else:
+            print('WARNING: diff is larger than expected — check for unsupported ops.')
+
+        # -------------------------------------------------------------- #
+        # CoreML EP — must pass MLComputeUnits explicitly; without it ORT
+        # discards the EP entirely on first partition error instead of
+        # falling back per-node to the CPU provider.
+        # -------------------------------------------------------------- #
+        print('\nRunning CoreML EP …')
+        sess_cml = ort.InferenceSession(
+            args.out,
+            providers=[
+                ('CoreMLExecutionProvider', {'MLComputeUnits': 'ALL'}),
+                'CPUExecutionProvider',
+            ],
+        )
+        cml_cls, cml_bbox = sess_cml.run(None, feeds)
+        cml_cls_diff = np.abs(ref_cls.numpy() - cml_cls).max()
+        cml_bbox_diff = np.abs(ref_bbox.numpy() - cml_bbox).max()
+        print(f'CoreML EP max diff — cls: {cml_cls_diff:.6f}  bbox: {cml_bbox_diff:.6f}')
+
+
+if __name__ == '__main__':
+    main()
justfile ADDED
@@ -0,0 +1,20 @@
+python := "uv run python"
+
+config := "configs/r50_nuimg_704x256_400q_36ep.py"
+weights := "checkpoints/r50_nuimg_704x256_400q_36ep.pth"
+out_dir := "exports"
+
+# Export the model to ONNX (output goes to exports/ with a descriptive name)
+onnx_export config=config weights=weights out_dir=out_dir:
+    {{ python }} export_onnx.py \
+        --config {{ config }} \
+        --weights {{ weights }} \
+        --out-dir {{ out_dir }}
+
+# Export and validate against PyTorch + CoreML EP
+onnx_export_validate config=config weights=weights out_dir=out_dir:
+    {{ python }} export_onnx.py \
+        --config {{ config }} \
+        --weights {{ weights }} \
+        --out-dir {{ out_dir }} \
+        --validate
main.py ADDED
@@ -0,0 +1,6 @@
+def main():
+    print("Hello from sparsebev!")
+
+
+if __name__ == "__main__":
+    main()
models/csrc/wrapper.py CHANGED
@@ -91,3 +91,57 @@ def msmv_sampling(mlvl_feats, sampling_locations, scale_weights):
         return MSMVSamplingC23456.apply(*mlvl_feats, sampling_locations, scale_weights)
     else:
         return msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights)
+
+
+def msmv_sampling_onnx(mlvl_feats, uv, view_idx, scale_weights):
+    """
+    ONNX-compatible multi-scale multi-view sampling using 4D F.grid_sample.
+
+    Replaces the 5D volumetric grid_sample used in msmv_sampling_pytorch with
+    separate per-view 4D grid_samples followed by a torch.gather for view
+    selection. All ops are in ONNX opset 16.
+
+    Args:
+        mlvl_feats:    list of [BTG, C, N, H, W] channel-first feature maps
+        uv:            [BTG, Q, P, 2] normalised (u, v) in [0, 1]
+        view_idx:      [BTG, Q, P] integer camera-view indices
+        scale_weights: [BTG, Q, P, L] softmax weights over pyramid levels
+    Returns:
+        [BTG, Q, C, P]
+    """
+    BTG, C, N, _, _ = mlvl_feats[0].shape
+    _, Q, P, _ = uv.shape
+
+    # Convert UV from [0, 1] to [-1, 1] for F.grid_sample
+    uv_gs = uv * 2.0 - 1.0  # [BTG, Q, P, 2]
+
+    # Tile UV for all N views: [BTG*N, Q, P, 2]
+    # Use expand+contiguous+reshape (maps to ONNX Expand, better CoreML EP support
+    # than repeat_interleave, which maps to ONNX Tile and can trip up CoreML)
+    uv_gs = uv_gs.unsqueeze(1).expand(BTG, N, Q, P, 2).contiguous().reshape(BTG * N, Q, P, 2)
+
+    # Pre-expand view_idx for gathering along the N dim: [BTG, C, 1, Q, P]
+    view_idx_g = view_idx[:, None, None, :, :].expand(BTG, C, 1, Q, P)
+
+    final = torch.zeros(BTG, C, Q, P, device=mlvl_feats[0].device, dtype=mlvl_feats[0].dtype)
+
+    for lvl, feat in enumerate(mlvl_feats):
+        _, _, _, H_lvl, W_lvl = feat.shape
+
+        # [BTG, C, N, H, W] -> [BTG, N, C, H, W] -> [BTG*N, C, H, W]
+        feat_4d = feat.permute(0, 2, 1, 3, 4).reshape(BTG * N, C, H_lvl, W_lvl)
+
+        # 4D grid_sample: [BTG*N, C, Q, P]
+        sampled = F.grid_sample(feat_4d, uv_gs, mode='bilinear', padding_mode='zeros', align_corners=True)
+
+        # [BTG*N, C, Q, P] -> [BTG, N, C, Q, P] -> [BTG, C, N, Q, P]
+        sampled = sampled.reshape(BTG, N, C, Q, P).permute(0, 2, 1, 3, 4)
+
+        # Gather the selected camera view: [BTG, C, 1, Q, P] -> [BTG, C, Q, P]
+        sampled = torch.gather(sampled, 2, view_idx_g).squeeze(2)
+
+        # Accumulate with per-level scale weight
+        w = scale_weights[..., lvl].reshape(BTG, 1, Q, P)
+        final = final + sampled * w
+
+    return final.permute(0, 2, 1, 3)  # [BTG, Q, C, P]
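A quick sanity check of the property `msmv_sampling_onnx` relies on: with `align_corners=True`, a 4D `F.grid_sample` whose normalised grid point lands exactly on a pixel centre reproduces that pixel's value. Toy shapes only, not the model's real dimensions:

```python
import torch
import torch.nn.functional as F

N, C, H, W = 1, 2, 4, 4
feat = torch.arange(N * C * H * W, dtype=torch.float32).reshape(N, C, H, W)

# Target pixel (row=1, col=2), normalised to [-1, 1] under align_corners=True:
# x = -1 maps to column 0, x = +1 to column W-1 (same for y and rows).
x = 2.0 * 2 / (W - 1) - 1.0
y = 2.0 * 1 / (H - 1) - 1.0
grid = torch.tensor([[[[x, y]]]])  # [N, H_out=1, W_out=1, 2], (x, y) order

out = F.grid_sample(feat, grid, mode='bilinear',
                    padding_mode='zeros', align_corners=True)

# Bilinear weights collapse onto the single pixel at (1, 2)
assert torch.allclose(out[0, :, 0, 0], feat[0, :, 1, 2])
```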
models/onnx_wrapper.py ADDED
@@ -0,0 +1,60 @@
+import torch
+import torch.nn as nn
+
+
+class SparseBEVOnnxWrapper(nn.Module):
+    """
+    Thin wrapper around SparseBEV for ONNX export.
+
+    Accepts pre-computed tensors instead of the img_metas dict so the graph
+    boundary is clean. Returns raw decoder logits without NMS or decoding so
+    post-processing can stay in Python.
+
+    Inputs (all float32):
+        img        [B, T*N, 3, H, W] — BGR images, will be normalised inside
+        lidar2img  [B, T*N, 4, 4]    — LiDAR-to-image projection matrices
+        time_diff  [B, T]            — seconds since the first frame (per frame,
+                                       averaged across the N cameras)
+
+    Outputs:
+        cls_scores [num_layers, B, Q, num_classes]
+        bbox_preds [num_layers, B, Q, 10]
+    """
+
+    def __init__(self, model, image_h=256, image_w=704, num_frames=8, num_cameras=6):
+        super().__init__()
+        self.model = model
+        self.image_h = image_h
+        self.image_w = image_w
+        self.num_frames = num_frames
+        self.num_cameras = num_cameras
+
+        # Disable stochastic augmentations that are meaningless at inference
+        self.model.use_grid_mask = False
+        # Disable FP16 casting decorators
+        self.model.fp16_enabled = False
+
+    def forward(self, img, lidar2img, time_diff):
+        B, TN, C, H, W = img.shape
+
+        # Build a minimal img_metas. Only the Python-constant fields are here;
+        # the tensor fields (time_diff, lidar2img) are injected as real tensors
+        # so the ONNX tracer includes them in the graph.
+        img_shape = (self.image_h, self.image_w, C)
+        img_metas = [{
+            'img_shape': [img_shape] * TN,
+            'ori_shape': [img_shape] * TN,
+            'time_diff': time_diff,   # tensor — flows into the ONNX graph
+            'lidar2img': lidar2img,   # tensor — flows into the ONNX graph
+        }]
+
+        # Backbone + FPN
+        img_feats = self.model.extract_feat(img=img, img_metas=img_metas)
+
+        # Detection head — returns raw predictions, no NMS
+        outs = self.model.pts_bbox_head(img_feats, img_metas)
+
+        cls_scores = outs['all_cls_scores']  # [num_layers, B, Q, num_classes]
+        bbox_preds = outs['all_bbox_preds']  # [num_layers, B, Q, 10]
+
+        return cls_scores, bbox_preds
models/sparsebev_sampling.py CHANGED
@@ -2,7 +2,7 @@ import torch
 import torch.nn.functional as F
 from .bbox.utils import decode_bbox
 from .utils import rotation_3d_in_axis, DUMP
-from .csrc.wrapper import msmv_sampling, msmv_sampling_pytorch
+from .csrc.wrapper import msmv_sampling, msmv_sampling_pytorch, msmv_sampling_onnx, MSMV_CUDA
 
 
 def make_sample_points(query_bbox, offset, pc_range):
@@ -88,38 +88,55 @@ def sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, im
     valid_mask = valid_mask.permute(0, 1, 3, 4, 2)  # [B, T, Q, GP, N]
     sample_points_cam = sample_points_cam.permute(0, 1, 3, 4, 2, 5)  # [B, T, Q, GP, N, 2]
 
-    # prepare batched indexing
-    i_batch = torch.arange(B, dtype=torch.long, device=sample_points.device)
-    i_query = torch.arange(Q, dtype=torch.long, device=sample_points.device)
-    i_time = torch.arange(T, dtype=torch.long, device=sample_points.device)
-    i_point = torch.arange(G * P, dtype=torch.long, device=sample_points.device)
-    i_batch = i_batch.view(B, 1, 1, 1, 1).expand(B, T, Q, G * P, 1)
-    i_time = i_time.view(1, T, 1, 1, 1).expand(B, T, Q, G * P, 1)
-    i_query = i_query.view(1, 1, Q, 1, 1).expand(B, T, Q, G * P, 1)
-    i_point = i_point.view(1, 1, 1, G * P, 1).expand(B, T, Q, G * P, 1)
-
     # we only keep at most one valid sampling point, see https://zhuanlan.zhihu.com/p/654821380
-    i_view = torch.argmax(valid_mask, dim=-1)[..., None]  # [B, T, Q, GP, 1]
-
-    # index the only one sampling point and its valid flag
-    sample_points_cam = sample_points_cam[i_batch, i_time, i_query, i_point, i_view, :]  # [B, Q, GP, 1, 2]
-    valid_mask = valid_mask[i_batch, i_time, i_query, i_point, i_view]  # [B, Q, GP, 1]
-
-    # treat the view index as a new axis for grid_sample and normalize the view index to [0, 1]
-    sample_points_cam = torch.cat([sample_points_cam, i_view[..., None].float() / (N - 1)], dim=-1)
-
-    # reorganize the tensor to stack T and G to the batch dim for better parallelism
-    sample_points_cam = sample_points_cam.reshape(B, T, Q, G, P, 1, 3)
-    sample_points_cam = sample_points_cam.permute(0, 1, 3, 2, 4, 5, 6)  # [B, T, G, Q, P, 1, 3]
-    sample_points_cam = sample_points_cam.reshape(B*T*G, Q, P, 3)
-
-    # reorganize the tensor to stack T and G to the batch dim for better parallelism
-    scale_weights = scale_weights.reshape(B, Q, G, T, P, -1)
-    scale_weights = scale_weights.permute(0, 2, 3, 1, 4, 5)
-    scale_weights = scale_weights.reshape(B*G*T, Q, P, -1)
-
-    # multi-scale multi-view grid sample
-    final = msmv_sampling(mlvl_feats, sample_points_cam, scale_weights)
+    i_view = torch.argmax(valid_mask, dim=-1, keepdim=True)  # [B, T, Q, GP, 1]
+
+    if MSMV_CUDA:
+        # Original fancy-indexing path (used with CUDA kernel on Linux/Windows)
+        i_batch = torch.arange(B, dtype=torch.long, device=sample_points.device)
+        i_query = torch.arange(Q, dtype=torch.long, device=sample_points.device)
+        i_time = torch.arange(T, dtype=torch.long, device=sample_points.device)
+        i_point = torch.arange(G * P, dtype=torch.long, device=sample_points.device)
+        i_batch = i_batch.view(B, 1, 1, 1, 1).expand(B, T, Q, G * P, 1)
+        i_time = i_time.view(1, T, 1, 1, 1).expand(B, T, Q, G * P, 1)
+        i_query = i_query.view(1, 1, Q, 1, 1).expand(B, T, Q, G * P, 1)
+        i_point = i_point.view(1, 1, 1, G * P, 1).expand(B, T, Q, G * P, 1)
+
+        sample_points_cam = sample_points_cam[i_batch, i_time, i_query, i_point, i_view, :]
+        valid_mask = valid_mask[i_batch, i_time, i_query, i_point, i_view]
+
+        # treat the view index as a new axis for grid_sample, normalise to [0, 1]
+        sample_points_cam = torch.cat([sample_points_cam, i_view[..., None].float() / (N - 1)], dim=-1)
+
+        sample_points_cam = sample_points_cam.reshape(B, T, Q, G, P, 1, 3)
+        sample_points_cam = sample_points_cam.permute(0, 1, 3, 2, 4, 5, 6)
+        sample_points_cam = sample_points_cam.reshape(B*T*G, Q, P, 3)
+
+        scale_weights = scale_weights.reshape(B, Q, G, T, P, -1)
+        scale_weights = scale_weights.permute(0, 2, 3, 1, 4, 5)
+        scale_weights = scale_weights.reshape(B*G*T, Q, P, -1)
+
+        final = msmv_sampling(mlvl_feats, sample_points_cam, scale_weights)
+    else:
+        # ONNX-compatible path: torch.gather + 4D grid_sample (no custom CUDA ops)
+        # Select best-view UV coords via gather [B, T, Q, GP, 1, 2]
+        i_view_uv = i_view.unsqueeze(-1).expand(B, T, Q, G * P, 1, 2)
+        sample_points_cam = torch.gather(sample_points_cam, 4, i_view_uv).squeeze(4)  # [B, T, Q, GP, 2]
+
+        # Reorganize UV to [B*T*G, Q, P, 2]
+        sample_points_cam = sample_points_cam.reshape(B, T, Q, G, P, 2)
+        sample_points_cam = sample_points_cam.permute(0, 1, 3, 2, 4, 5)  # [B, T, G, Q, P, 2]
+        sample_points_cam = sample_points_cam.reshape(B*T*G, Q, P, 2)
+
+        # Reorganize view_idx to [B*T*G, Q, P]
+        i_view = i_view.squeeze(4).reshape(B, T, Q, G, P)
+        i_view = i_view.permute(0, 1, 3, 2, 4).reshape(B*T*G, Q, P)
+
+        scale_weights = scale_weights.reshape(B, Q, G, T, P, -1)
+        scale_weights = scale_weights.permute(0, 2, 3, 1, 4, 5)
+        scale_weights = scale_weights.reshape(B*G*T, Q, P, -1)
+
+        final = msmv_sampling_onnx(mlvl_feats, sample_points_cam, i_view, scale_weights)
 
     # reorganize the sampled features
     C = final.shape[2]  # [BTG, Q, C, P]
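The two branches in `sampling_4d` are meant to pick the same per-point view; the gather-based selection can be checked against the fancy-indexing original on toy shapes (an illustrative sketch, not the project's test suite):

```python
import torch

# Toy shapes, not the model's real dimensions
B, T, Q, GP, N = 1, 2, 3, 4, 6
pts = torch.randn(B, T, Q, GP, N, 2)        # per-view UV coords
valid = torch.rand(B, T, Q, GP, N) > 0.5    # per-view validity

# Index of the first valid view along the N axis
i_view = torch.argmax(valid.long(), dim=-1, keepdim=True)  # [B, T, Q, GP, 1]

# Original path: batched advanced indexing (not ONNX-traceable at this rank)
i_b = torch.arange(B).view(B, 1, 1, 1, 1).expand(B, T, Q, GP, 1)
i_t = torch.arange(T).view(1, T, 1, 1, 1).expand(B, T, Q, GP, 1)
i_q = torch.arange(Q).view(1, 1, Q, 1, 1).expand(B, T, Q, GP, 1)
i_p = torch.arange(GP).view(1, 1, 1, GP, 1).expand(B, T, Q, GP, 1)
ref = pts[i_b, i_t, i_q, i_p, i_view, :].squeeze(4)        # [B, T, Q, GP, 2]

# ONNX path: gather along the view axis with a broadcast index
idx = i_view.unsqueeze(-1).expand(B, T, Q, GP, 1, 2)
out = torch.gather(pts, 4, idx).squeeze(4)                 # [B, T, Q, GP, 2]

assert torch.equal(out, ref)
```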
models/sparsebev_transformer.py CHANGED
@@ -56,18 +56,23 @@ class SparseBEVTransformerDecoder(BaseModule):
     def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
         cls_scores, bbox_preds = [], []
 
-        # calculate time difference according to timestamps
-        timestamps = np.array([m['img_timestamp'] for m in img_metas], dtype=np.float64)
-        timestamps = np.reshape(timestamps, [query_bbox.shape[0], -1, 6])
-        time_diff = timestamps[:, :1, :] - timestamps
-        time_diff = np.mean(time_diff, axis=-1).astype(np.float32)  # [B, F]
-        time_diff = torch.from_numpy(time_diff).to(query_bbox.device)  # [B, F]
-        img_metas[0]['time_diff'] = time_diff
-
-        # organize projections matrix and copy to CUDA
-        lidar2img = np.asarray([m['lidar2img'] for m in img_metas]).astype(np.float32)
-        lidar2img = torch.from_numpy(lidar2img).to(query_bbox.device)  # [B, N, 4, 4]
-        img_metas[0]['lidar2img'] = lidar2img
+        if isinstance(img_metas[0].get('time_diff'), torch.Tensor):
+            # ONNX export path: tensors pre-computed and injected by the wrapper
+            pass  # time_diff and lidar2img already set in img_metas[0]
+        else:
+            # Standard path: extract from img_metas using numpy
+            # calculate time difference according to timestamps
+            timestamps = np.array([m['img_timestamp'] for m in img_metas], dtype=np.float64)
+            timestamps = np.reshape(timestamps, [query_bbox.shape[0], -1, 6])
+            time_diff = timestamps[:, :1, :] - timestamps
+            time_diff = np.mean(time_diff, axis=-1).astype(np.float32)  # [B, F]
+            time_diff = torch.from_numpy(time_diff).to(query_bbox.device)  # [B, F]
+            img_metas[0]['time_diff'] = time_diff
+
+            # organize projections matrix and copy to CUDA
+            lidar2img = np.asarray([m['lidar2img'] for m in img_metas]).astype(np.float32)
+            lidar2img = torch.from_numpy(lidar2img).to(query_bbox.device)  # [B, N, 4, 4]
+            img_metas[0]['lidar2img'] = lidar2img
 
         # group image features in advance for sampling, see `sampling_4d` for more details
         for lvl, feat in enumerate(mlvl_feats):
@@ -178,9 +183,11 @@ class SparseBEVTransformerDecoderLayer(BaseModule):
         # calculate absolute velocity according to time difference
         time_diff = img_metas[0]['time_diff']  # [B, F]
         if time_diff.shape[1] > 1:
-            time_diff = time_diff.clone()
-            time_diff[time_diff < 1e-5] = 1.0
-            bbox_pred[..., 8:] = bbox_pred[..., 8:] / time_diff[:, 1:2, None]
+            time_diff = torch.where(time_diff < 1e-5, torch.ones_like(time_diff), time_diff)
+            bbox_pred = torch.cat([
+                bbox_pred[..., :8],
+                bbox_pred[..., 8:] / time_diff[:, 1:2, None],
+            ], dim=-1)
 
         if DUMP.enabled:
             query_bbox_dec = decode_bbox(query_bbox, self.pc_range)
@@ -236,16 +243,8 @@ class SparseBEVSelfAttention(BaseModule):
     @torch.no_grad()
     def calc_bbox_dists(self, bboxes):
         centers = decode_bbox(bboxes, self.pc_range)[..., :2]  # [B, Q, 2]
-
-        dist = []
-        for b in range(centers.shape[0]):
-            dist_b = torch.norm(centers[b].reshape(-1, 1, 2) - centers[b].reshape(1, -1, 2), dim=-1)
-            dist.append(dist_b[None, ...])
-
-        dist = torch.cat(dist, dim=0)  # [B, Q, Q]
-        dist = -dist
-
-        return dist
+        dist = torch.norm(centers.unsqueeze(2) - centers.unsqueeze(1), dim=-1)  # [B, Q, Q]
+        return -dist
 
 
 class SparseBEVSampling(BaseModule):
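The vectorised `calc_bbox_dists` can be checked against the per-batch loop it replaces; a toy sketch with random centres (illustrative only):

```python
import torch

B, Q = 2, 5
centers = torch.randn(B, Q, 2)  # stand-in for decoded box centres

# Loop version, as in the original code
ref = torch.stack([
    torch.norm(centers[b].reshape(-1, 1, 2) - centers[b].reshape(1, -1, 2), dim=-1)
    for b in range(B)
])                                                                  # [B, Q, Q]

# Vectorised version: broadcasting replaces the Python loop,
# which keeps the batch dimension inside the traced graph
out = torch.norm(centers.unsqueeze(2) - centers.unsqueeze(1), dim=-1)

assert torch.allclose(out, ref)
```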
pyproject.toml ADDED
@@ -0,0 +1,44 @@
+[project]
+name = "sparsebev"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.11"
+dependencies = [
+    "numpy>=2.4.3",
+    "onnx>=1.20.1",
+    "onnxruntime>=1.16",
+    "setuptools>=40,<72",  # <72 required: mmcv setup.py uses pkg_resources removed in >=72
+    "torch>=2.10.0",
+    "torchvision>=0.25.0",
+    # mmdet ecosystem — old packages with stale pins, needs --no-build-isolation
+    "mmdet==2.28.2",
+    "mmsegmentation==0.30.0",
+    "mmdet3d==1.0.0rc6",
+    "mmcv==1.7.0",
+    "fvcore>=0.1.5.post20221221",
+    "einops>=0.8.2",
+    "onnxscript>=0.6.2",
+]
+
+[tool.uv]
+# Build mmcv/mmdet without isolation so they see the pinned setuptools<72
+# (they import pkg_resources in setup.py, which was removed in setuptools>=72)
+no-build-isolation-package = ["mmcv", "mmdet", "mmdet3d", "mmsegmentation"]
+
+# mmdet3d==1.0.0rc6 has stale pins that conflict with Python 3.12 and modern torch.
+# Override to compatible modern versions.
+override-dependencies = [
+    "networkx>=2.5.1",
+    # mmdet3d pins numba==0.53.0 -> llvmlite==0.36.0, which only supports Python<3.10
+    "numba>=0.60.0",
+    "llvmlite>=0.43.0",
+    # setuptools>=72 removed pkg_resources as a top-level module; mmcv setup.py needs it
+    "setuptools<72",
+]
+
+[tool.uv.extra-build-dependencies]
+# mmdet3d/mmdet need torch at build time (they import it in setup.py)
+mmdet3d = ["torch"]
+mmdet = ["torch"]
+mmcv = ["setuptools"]
uv.lock ADDED
The diff for this file is too large to render. See raw diff