File size: 2,313 Bytes
3ea6165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torch.nn as nn


class SparseBEVOnnxWrapper(nn.Module):
    """
    Export-friendly adapter that puts a tensor-only interface in front of SparseBEV.

    The wrapped detector normally receives an ``img_metas`` list of dicts.  This
    wrapper assembles that structure itself from plain tensor inputs plus a few
    constructor-time constants, so the traced graph boundary consists only of
    tensors.  It hands back the head's ``all_cls_scores`` / ``all_bbox_preds``
    entries untouched: no NMS and no box decoding happen here, leaving
    post-processing to Python.

    Args:
        model: SparseBEV detector exposing ``extract_feat(img=..., img_metas=...)``
            and ``pts_bbox_head(feats, img_metas)``.  Its ``use_grid_mask`` and
            ``fp16_enabled`` attributes are overwritten with ``False``.
        image_h, image_w: height / width written into every ``img_shape`` and
            ``ori_shape`` entry of the generated metas.  They are used as-is and
            are neither read from nor checked against ``img`` (presumably so the
            tensor may be padded beyond the logical image size — confirm).
        num_frames, num_cameras: stored on the wrapper; ``forward`` does not
            read them.

    Inputs (float32 assumed, not enforced here):
        img        [B, T*N, 3, H, W]  — must be 5-D (it is unpacked as such).
                                        Per the original notes these are BGR
                                        images normalised inside the model —
                                        TODO confirm against the model config.
        lidar2img  [B, T*N, 4, 4]     — LiDAR-to-image projection matrices.
        time_diff  [B, T]             — presumably seconds since the first frame,
                                        averaged over the N cameras; verify
                                        against the caller.

    Outputs (shapes as documented for the head; not checked here):
        cls_scores  [num_layers, B, Q, num_classes]
        bbox_preds  [num_layers, B, Q, 10]
    """

    def __init__(self, model, image_h=256, image_w=704, num_frames=8, num_cameras=6):
        super().__init__()
        self.model = model
        self.image_h, self.image_w = image_h, image_w
        self.num_frames, self.num_cameras = num_frames, num_cameras

        detector = self.model
        # Grid masking is a training-time augmentation; keep it off for export.
        detector.use_grid_mask = False
        # Flag presumably consulted by mmcv-style fp16 decorators; force the
        # plain fp32 path while tracing.
        detector.fp16_enabled = False

    def forward(self, img, lidar2img, time_diff):
        """Run backbone/FPN and the detection head; return (cls_scores, bbox_preds)."""
        # The five-way unpack doubles as a rank check: a non-5-D ``img`` raises
        # ValueError here.  Only the view count and channel count are needed.
        _, num_views, channels, _, _ = img.shape

        # Plain Python values (ints / tuples) become constants in the trace,
        # while the two tensor entries are passed through by reference so the
        # ONNX tracer records them as real graph inputs.
        per_view_shape = (self.image_h, self.image_w, channels)
        meta = dict(
            img_shape=[per_view_shape] * num_views,
            ori_shape=[per_view_shape] * num_views,
            time_diff=time_diff,
            lidar2img=lidar2img,
        )
        img_metas = [meta]

        feats = self.model.extract_feat(img=img, img_metas=img_metas)
        # Raw head output only — no NMS / decoding inside the exported graph.
        head_out = self.model.pts_bbox_head(feats, img_metas)
        return head_out['all_cls_scores'], head_out['all_bbox_preds']