Alfred Liu committed on
Commit d19bd3e · 1 Parent(s): b0800d3

Code release

.gitignore ADDED
@@ -0,0 +1,53 @@
+ # OS generated files
+ .DS_Store
+ .DS_Store?
+ ._*
+ .Spotlight-V100
+ .Trashes
+ ehthumbs.db
+ Thumbs.db
+
+ # Compiled source
+ build
+ debug
+ Debug
+ release
+ Release
+ x64
+ *.so
+ *.whl
+
+ # VS project files
+ *.sln
+ *.vcxproj
+ *.vcxproj.filters
+ *.vcxproj.user
+ *.rc
+ .vs
+
+ # Byte-compiled / optimized / DLL files
+ *__pycache__*
+ *.py[cod]
+ *$py.class
+
+ # Distribution / packaging
+ .Python
+ build
+ develop-eggs
+ dist
+ downloads
+
+ # IDE
+ .idea
+ .vscode
+ pyrightconfig.json
+
+ # Custom
+ data
+ outputs
+ prediction
+ submission
+ checkpoints
+ pretrain
+ *.png
+ *.jpg
README.md CHANGED
@@ -1,2 +1,132 @@
  # SparseBEV
- [ICCV 2023] SparseBEV: High-Performance Sparse 3D Object Detection \\ from Multi-Camera Videos
+
+ This is the official PyTorch implementation of the paper [SparseBEV: High-Performance Sparse 3D Object Detection from Multi-Camera Videos](https://arxiv.org/abs/2308.09244) (ICCV 2023).
+
+ ## Model Zoo
+
+ | Backbone | Pretrain | Input Size | Epochs | Training Cost | NDS | FPS | Config | Weights |
+ |----------|----------|------------|--------|---------------|-----|-----|--------|---------|
+ | R50 | [nuImages](https://github.com/open-mmlab/mmdetection3d/blob/main/configs/nuimages/cascade-mask-rcnn_r50_fpn_coco-20e_20e_nuim.py) | 704x256 | 36 | 28h (8x2080Ti) | 55.8 | 23.5 | [config](configs/r50_nuimg_704x256_400q_36ep.py) | [weights](https://drive.google.com/file/d/1C_Vn3iiSnSW1Dw1r0DkjJMwvHC5Y3zTN/view?usp=sharing) |
+
+ * FPS is measured on a machine with an AMD 5800X and an RTX 3090.
+ * Run-to-run noise is around 0.3 NDS.
+
+ ## Environment
+
+ Install PyTorch 2.0 + CUDA 11.8:
+
+ ```
+ conda create -n sparsebev python=3.8
+ conda activate sparsebev
+ conda install pytorch==2.0.0 torchvision==0.15.0 pytorch-cuda=11.8 -c pytorch -c nvidia
+ ```
+
+ or PyTorch 1.10.2 + CUDA 10.2 for older GPUs:
+
+ ```
+ conda create -n sparsebev python=3.8
+ conda activate sparsebev
+ conda install pytorch==1.10.2 torchvision==0.11.3 cudatoolkit=10.2 -c pytorch
+ ```
+
+ Install other dependencies:
+
+ ```
+ pip install openmim
+ mim install mmcv-full==1.6.0
+ mim install mmdet==2.28.2
+ mim install mmsegmentation==0.30.0
+ mim install mmdet3d==1.0.0rc6
+ pip install setuptools==59.5.0
+ pip install numpy==1.23.5
+ ```
+
+ Install turbojpeg and pillow-simd to speed up data loading (optional but recommended):
+
+ ```
+ sudo apt-get update
+ sudo apt-get install -y libturbojpeg
+ pip install pyturbojpeg
+ pip uninstall pillow
+ pip install pillow-simd==9.0.0.post1
+ ```
+
+ Compile CUDA extensions:
+
+ ```
+ cd models/csrc
+ python setup.py build_ext --inplace
+ ```
+
+ ## Prepare Dataset
+
+ 1. Download nuScenes from [https://www.nuscenes.org/nuscenes](https://www.nuscenes.org/nuscenes) and put it in `data/nuscenes`.
+ 2. Download the generated info files from [Google Drive](https://drive.google.com/file/d/1uyoUuSRIVScrm_CUpge6V_UzwDT61ODO/view?usp=sharing) and unzip them.
+ 3. Folder structure:
+
+ ```
+ data/nuscenes
+ ├── maps
+ ├── nuscenes_infos_test_sweep.pkl
+ ├── nuscenes_infos_train_mini_sweep.pkl
+ ├── nuscenes_infos_train_sweep.pkl
+ ├── nuscenes_infos_val_mini_sweep.pkl
+ ├── nuscenes_infos_val_sweep.pkl
+ ├── samples
+ ├── sweeps
+ ├── v1.0-mini
+ ├── v1.0-test
+ └── v1.0-trainval
+ ```
+
+ These `*.pkl` files can also be generated with our script `gen_sweep_info.py`.
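As a quick sanity check after generation, each entry in an info file should carry a `sweeps` list. The following is a minimal sketch using a made-up one-sample file (not part of the release; the dict layout mirrors what `gen_sweep_info.py` writes):

```python
import pickle
import tempfile

# hypothetical minimal structure mirroring a generated info file
infos = {'infos': [{'token': 'demo_token', 'cams': {}, 'sweeps': []}]}

# round-trip through a temporary pickle file
with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as f:
    pickle.dump(infos, f)
    path = f.name

with open(path, 'rb') as f:
    loaded = pickle.load(f)

# every sample info should have a (possibly empty) sweeps list
assert all('sweeps' in s for s in loaded['infos'])
```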
+
+ ## Training
+
+ Train SparseBEV with 8 GPUs:
+
+ ```
+ torchrun --nproc_per_node 8 train.py --config configs/r50_nuimg_704x256_400q_36ep.py
+ ```
+
+ Train SparseBEV with 4 GPUs (i.e., the last four GPUs):
+
+ ```
+ export CUDA_VISIBLE_DEVICES=4,5,6,7
+ torchrun --nproc_per_node 4 train.py --config configs/r50_nuimg_704x256_400q_36ep.py
+ ```
+
+ The per-GPU batch size is scaled automatically, so there is no need to modify `batch_size` in the config files.
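The per-GPU scaling can be sketched as follows (a simplification for illustration, not the actual `train.py` logic; the function name is ours):

```python
def per_gpu_batch_size(batch_size: int, world_size: int) -> int:
    """Split the config-level batch size evenly across GPUs."""
    assert batch_size % world_size == 0, "batch_size must divide evenly across GPUs"
    return batch_size // world_size

# with batch_size = 8 in the config: 1 sample per GPU on 8 GPUs, 2 on 4 GPUs
print(per_gpu_batch_size(8, 8), per_gpu_batch_size(8, 4))  # 1 2
```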
+
+ ## Evaluation
+
+ Single-GPU evaluation:
+
+ ```
+ export CUDA_VISIBLE_DEVICES=0
+ python val.py --config configs/r50_nuimg_704x256_400q_36ep.py --weights checkpoints/r50_nuimg_704x256_400q_36ep.pth
+ ```
+
+ Multi-GPU evaluation:
+
+ ```
+ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+ torchrun --nproc_per_node 8 val.py --config configs/r50_nuimg_704x256_400q_36ep.py --weights checkpoints/r50_nuimg_704x256_400q_36ep.pth
+ ```
+
+ ## Timing
+
+ FPS is measured with a single GPU:
+
+ ```
+ export CUDA_VISIBLE_DEVICES=0
+ python timing.py --config configs/r50_nuimg_704x256_400q_36ep.py --weights checkpoints/r50_nuimg_704x256_400q_36ep.pth
+ ```
+
+ ## Acknowledgements
+
+ Many thanks to these excellent open-source projects:
+
+ * 3D Detection: [DETR3D](https://github.com/WangYueFt/detr3d), [PETR](https://github.com/megvii-research/PETR), [BEVFormer](https://github.com/fundamentalvision/BEVFormer), [BEVDet](https://github.com/HuangJunJie2017/BEVDet), [StreamPETR](https://github.com/exiawsh/StreamPETR)
+ * 2D Detection: [AdaMixer](https://github.com/MCG-NJU/AdaMixer), [DN-DETR](https://github.com/IDEA-Research/DN-DETR)
+ * Codebase: [MMDetection3D](https://github.com/open-mmlab/mmdetection3d), [CamLiFlow](https://github.com/MCG-NJU/CamLiFlow)
configs/r101_nuimg_1408x512_900q_24ep.py ADDED
@@ -0,0 +1,96 @@
+ _base_ = ['./r50_nuimg_704x256_400q_36ep.py']
+
+ # For nuScenes we usually do 10-class detection
+ class_names = [
+     'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
+     'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
+ ]
+
+ # If the point cloud range is changed, the models should also change their
+ # point cloud range accordingly
+ point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
+ voxel_size = [0.2, 0.2, 8]
+
+ img_backbone = dict(
+     type='ResNet',
+     depth=101,
+     with_cp=True,
+ )
+
+ img_neck = dict(
+     type='FPN',
+     in_channels=[256, 512, 1024, 2048],
+     out_channels=256,
+     num_outs=5,
+ )
+
+ model = dict(
+     img_backbone=img_backbone,
+     img_neck=img_neck,
+     pts_bbox_head=dict(
+         num_query=900,
+         transformer=dict(num_levels=5)),
+ )
+
+ ida_aug_conf = {
+     'resize_lim': (0.38 * 2, 0.55 * 2),
+     'final_dim': (512, 1408),
+     'bot_pct_lim': (0.0, 0.0),
+     'rot_lim': (0.0, 0.0),
+     'H': 900, 'W': 1600,
+     'rand_flip': True,
+ }
+
+ train_pipeline = [
+     dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),
+     dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=7),
+     dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True, with_attr_label=False),
+     dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
+     dict(type='ObjectNameFilter', classes=class_names),
+     dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=True),
+     dict(type='GlobalRotScaleTransImage', rot_range=[-0.3925, 0.3925], scale_ratio_range=[0.95, 1.05]),
+     dict(type='DefaultFormatBundle3D', class_names=class_names),
+     dict(type='Collect3D', keys=['gt_bboxes_3d', 'gt_labels_3d', 'img'], meta_keys=(
+         'filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp'))
+ ]
+
+ test_pipeline = [
+     dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),
+     dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=7, test_mode=True),
+     dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=False),
+     dict(
+         type='MultiScaleFlipAug3D',
+         img_scale=(1600, 900),
+         pts_scale_ratio=1,
+         flip=False,
+         transforms=[
+             dict(type='DefaultFormatBundle3D', class_names=class_names, with_label=False),
+             dict(type='Collect3D', keys=['img'], meta_keys=(
+                 'filename', 'box_type_3d', 'ori_shape', 'img_shape', 'pad_shape',
+                 'lidar2img', 'img_timestamp'))
+         ])
+ ]
+
+ data = dict(
+     workers_per_gpu=4,
+     train=dict(pipeline=train_pipeline),
+     val=dict(pipeline=test_pipeline),
+     test=dict(pipeline=test_pipeline)
+ )
+
+ optimizer = dict(
+     type='AdamW',
+     lr=2e-4,
+     paramwise_cfg=dict(custom_keys={
+         'img_backbone': dict(lr_mult=0.2),
+         'sampling_offset': dict(lr_mult=0.1),
+     }),
+     weight_decay=0.01
+ )
+
+ # load pretrained weights
+ load_from = 'pretrain/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth'
+ revise_keys = [('backbone', 'img_backbone')]
+
+ total_epochs = 24
+ eval_config = dict(interval=total_epochs)
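Configs like this one only override fields of their `_base_` file. A minimal sketch of the merge semantics on plain dicts (illustrative only, not the actual `mmcv.Config` machinery):

```python
def merge(base, override):
    """Recursively apply override on top of base, keeping untouched keys."""
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge(out[key], value)  # descend into nested dicts
        else:
            out[key] = value  # leaf values are replaced wholesale
    return out

base = {'model': {'pts_bbox_head': {'num_query': 400, 'num_classes': 10}}}
child = {'model': {'pts_bbox_head': {'num_query': 900}}}

merged = merge(base, child)
# num_query is overridden to 900; num_classes is inherited from the base
print(merged['model']['pts_bbox_head'])
```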
configs/r50_in1k_704x256_900q_36ep.py ADDED
@@ -0,0 +1,21 @@
+ _base_ = ['./r50_nuimg_704x256_400q_36ep.py']
+
+ img_backbone = dict(pretrained='torchvision://resnet50')
+
+ model = dict(
+     img_backbone=img_backbone,
+     pts_bbox_head=dict(num_query=900)
+ )
+
+ optimizer = dict(
+     paramwise_cfg=dict(custom_keys={
+         'img_backbone': dict(lr_mult=0.4),
+         'sampling_offset': dict(lr_mult=0.1),
+     })
+ )
+
+ load_from = None
+ revise_keys = None
+
+ total_epochs = 36
+ eval_config = dict(interval=total_epochs)
configs/r50_nuimg_704x256_400q_36ep.py ADDED
@@ -0,0 +1,236 @@
+ dataset_type = 'CustomNuScenesDataset'
+ dataset_root = 'data/nuscenes/'
+
+ input_modality = dict(
+     use_lidar=False,
+     use_camera=True,
+     use_radar=False,
+     use_map=False,
+     use_external=True
+ )
+
+ # For nuScenes we usually do 10-class detection
+ class_names = [
+     'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier',
+     'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'
+ ]
+
+ # If the point cloud range is changed, the models should also change their
+ # point cloud range accordingly
+ point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
+ voxel_size = [0.2, 0.2, 8]
+
+ # arch config
+ embed_dims = 256
+ num_layers = 6
+ num_query = 400
+ num_frames = 8
+ num_levels = 4
+ num_points = 4
+
+ img_backbone = dict(
+     type='ResNet',
+     depth=50,
+     num_stages=4,
+     out_indices=(0, 1, 2, 3),
+     frozen_stages=1,
+     norm_cfg=dict(type='BN2d', requires_grad=True),
+     norm_eval=True,
+     style='pytorch',
+     with_cp=True)
+ img_neck = dict(
+     type='FPN',
+     in_channels=[256, 512, 1024, 2048],
+     out_channels=embed_dims,
+     num_outs=num_levels)
+ img_norm_cfg = dict(
+     mean=[123.675, 116.280, 103.530],
+     std=[58.395, 57.120, 57.375],
+     to_rgb=True)
+
+ model = dict(
+     type='SparseBEV',
+     data_aug=dict(
+         img_color_aug=True,  # move some augmentations to GPU
+         img_norm_cfg=img_norm_cfg,
+         img_pad_cfg=dict(size_divisor=32)),
+     stop_prev_grad=False,
+     img_backbone=img_backbone,
+     img_neck=img_neck,
+     pts_bbox_head=dict(
+         type='SparseBEVHead',
+         num_classes=10,
+         in_channels=embed_dims,
+         num_query=num_query,
+         query_denoising=True,
+         query_denoising_groups=10,
+         code_size=10,
+         code_weights=[2.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+         sync_cls_avg_factor=True,
+         transformer=dict(
+             type='SparseBEVTransformer',
+             embed_dims=embed_dims,
+             num_frames=num_frames,
+             num_points=num_points,
+             num_layers=num_layers,
+             num_levels=num_levels,
+             num_classes=10,
+             code_size=10,
+             pc_range=point_cloud_range),
+         bbox_coder=dict(
+             type='NMSFreeCoder',
+             post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
+             pc_range=point_cloud_range,
+             max_num=300,
+             voxel_size=voxel_size,
+             score_threshold=0.05,
+             num_classes=10),
+         positional_encoding=dict(
+             type='SinePositionalEncoding',
+             num_feats=embed_dims // 2,
+             normalize=True,
+             offset=-0.5),
+         loss_cls=dict(
+             type='FocalLoss',
+             use_sigmoid=True,
+             gamma=2.0,
+             alpha=0.25,
+             loss_weight=2.0),
+         loss_bbox=dict(type='L1Loss', loss_weight=0.25),
+         loss_iou=dict(type='GIoULoss', loss_weight=0.0)),
+     train_cfg=dict(pts=dict(
+         grid_size=[512, 512, 1],
+         voxel_size=voxel_size,
+         point_cloud_range=point_cloud_range,
+         out_size_factor=4,
+         assigner=dict(
+             type='HungarianAssigner3D',
+             cls_cost=dict(type='FocalLossCost', weight=2.0),
+             reg_cost=dict(type='BBox3DL1Cost', weight=0.25),
+             iou_cost=dict(type='IoUCost', weight=0.0),
+         )
+     ))
+ )
+
+ ida_aug_conf = {
+     'resize_lim': (0.38, 0.55),
+     'final_dim': (256, 704),
+     'bot_pct_lim': (0.0, 0.0),
+     'rot_lim': (0.0, 0.0),
+     'H': 900, 'W': 1600,
+     'rand_flip': True,
+ }
+
+ train_pipeline = [
+     dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),
+     dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=num_frames - 1),
+     dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True, with_attr_label=False),
+     dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
+     dict(type='ObjectNameFilter', classes=class_names),
+     dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=True),
+     dict(type='GlobalRotScaleTransImage', rot_range=[-0.3925, 0.3925], scale_ratio_range=[0.95, 1.05]),
+     dict(type='DefaultFormatBundle3D', class_names=class_names),
+     dict(type='Collect3D', keys=['gt_bboxes_3d', 'gt_labels_3d', 'img'], meta_keys=(
+         'filename', 'ori_shape', 'img_shape', 'pad_shape', 'lidar2img', 'img_timestamp'))
+ ]
+
+ test_pipeline = [
+     dict(type='LoadMultiViewImageFromFiles', to_float32=False, color_type='color'),
+     dict(type='LoadMultiViewImageFromMultiSweeps', sweeps_num=num_frames - 1, test_mode=True),
+     dict(type='RandomTransformImage', ida_aug_conf=ida_aug_conf, training=False),
+     dict(
+         type='MultiScaleFlipAug3D',
+         img_scale=(1600, 900),
+         pts_scale_ratio=1,
+         flip=False,
+         transforms=[
+             dict(type='DefaultFormatBundle3D', class_names=class_names, with_label=False),
+             dict(type='Collect3D', keys=['img'], meta_keys=(
+                 'filename', 'box_type_3d', 'ori_shape', 'img_shape', 'pad_shape',
+                 'lidar2img', 'img_timestamp'))
+         ])
+ ]
+
+ data = dict(
+     workers_per_gpu=8,
+     train=dict(
+         type=dataset_type,
+         data_root=dataset_root,
+         ann_file=dataset_root + 'nuscenes_infos_train_sweep.pkl',
+         pipeline=train_pipeline,
+         classes=class_names,
+         modality=input_modality,
+         test_mode=False,
+         use_valid_flag=True,
+         box_type_3d='LiDAR'),
+     val=dict(
+         type=dataset_type,
+         data_root=dataset_root,
+         ann_file=dataset_root + 'nuscenes_infos_val_sweep.pkl',
+         pipeline=test_pipeline,
+         classes=class_names,
+         modality=input_modality,
+         test_mode=True,
+         box_type_3d='LiDAR'),
+     test=dict(
+         type=dataset_type,
+         data_root=dataset_root,
+         ann_file=dataset_root + 'nuscenes_custom_infos_test.pkl',
+         pipeline=test_pipeline,
+         classes=class_names,
+         modality=input_modality,
+         test_mode=True,
+         box_type_3d='LiDAR')
+ )
+
+ optimizer = dict(
+     type='AdamW',
+     lr=2e-4,
+     paramwise_cfg=dict(custom_keys={
+         'img_backbone': dict(lr_mult=0.1),
+         'sampling_offset': dict(lr_mult=0.1),
+     }),
+     weight_decay=0.01
+ )
+
+ optimizer_config = dict(
+     type='Fp16OptimizerHook',
+     loss_scale=512.0,
+     grad_clip=dict(max_norm=35, norm_type=2)
+ )
+
+ # learning policy
+ lr_config = dict(
+     policy='CosineAnnealing',
+     warmup='linear',
+     warmup_iters=500,
+     warmup_ratio=1.0 / 3,
+     min_lr_ratio=1e-3
+ )
+ total_epochs = 36
+ batch_size = 8
+
+ # load pretrained weights
+ load_from = 'pretrain/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim_20201009_124951-40963960.pth'
+ revise_keys = [('backbone', 'img_backbone')]
+
+ # resume the last training
+ resume_from = None
+
+ # checkpointing
+ checkpoint_config = dict(interval=1, max_keep_ckpts=1)
+
+ # logging
+ log_config = dict(
+     interval=1,
+     hooks=[
+         dict(type='MyTextLoggerHook', interval=1, reset_flag=True),
+         dict(type='MyTensorboardLoggerHook', interval=500, reset_flag=True)
+     ]
+ )
+
+ # evaluation
+ eval_config = dict(interval=total_epochs)
+
+ # other flags
+ debug = False
gen_sweep_info.py ADDED
@@ -0,0 +1,112 @@
+ # Generate info files manually
+ import os
+ import mmcv
+ import tqdm
+ import pickle
+ import argparse
+ import numpy as np
+ from nuscenes import NuScenes
+ from pyquaternion import Quaternion
+
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data-root', default='data/nuscenes')
+ parser.add_argument('--version', default='v1.0-trainval')
+ args = parser.parse_args()
+
+
+ def get_cam_info(nusc, sample_data):
+     pose_record = nusc.get('ego_pose', sample_data['ego_pose_token'])
+     cs_record = nusc.get('calibrated_sensor', sample_data['calibrated_sensor_token'])
+
+     sensor2ego_translation = cs_record['translation']
+     ego2global_translation = pose_record['translation']
+     sensor2ego_rotation = Quaternion(cs_record['rotation']).rotation_matrix
+     ego2global_rotation = Quaternion(pose_record['rotation']).rotation_matrix
+     cam_intrinsic = np.array(cs_record['camera_intrinsic'])
+
+     sensor2global_rotation = sensor2ego_rotation.T @ ego2global_rotation.T
+     sensor2global_translation = sensor2ego_translation @ ego2global_rotation.T + ego2global_translation
+
+     return {
+         'data_path': os.path.join(args.data_root, sample_data['filename']),
+         'sensor2global_rotation': sensor2global_rotation,
+         'sensor2global_translation': sensor2global_translation,
+         'cam_intrinsic': cam_intrinsic,
+         'timestamp': sample_data['timestamp'],
+     }
+
+
+ def add_sweep_info(nusc, sample_infos):
+     for curr_id in tqdm.tqdm(range(len(sample_infos['infos']))):
+         sample = nusc.get('sample', sample_infos['infos'][curr_id]['token'])
+
+         cam_types = [
+             'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_BACK_RIGHT',
+             'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_FRONT_LEFT'
+         ]
+
+         curr_cams = dict()
+         for cam in cam_types:
+             curr_cams[cam] = nusc.get('sample_data', sample['data'][cam])
+
+         for cam in cam_types:
+             sample_data = nusc.get('sample_data', sample['data'][cam])
+             sweep_cam = get_cam_info(nusc, sample_data)
+             sample_infos['infos'][curr_id]['cams'][cam].update(sweep_cam)
+
+         # remove unnecessary keys
+         for cam in cam_types:
+             del sample_infos['infos'][curr_id]['cams'][cam]['sample_data_token']
+             del sample_infos['infos'][curr_id]['cams'][cam]['sensor2ego_translation']
+             del sample_infos['infos'][curr_id]['cams'][cam]['sensor2ego_rotation']
+             del sample_infos['infos'][curr_id]['cams'][cam]['ego2global_translation']
+             del sample_infos['infos'][curr_id]['cams'][cam]['ego2global_rotation']
+
+         sweep_infos = []
+         if sample['prev'] != '':  # add sweep frames between two key frames
+             for _ in range(5):
+                 sweep_info = dict()
+                 for cam in cam_types:
+                     if curr_cams[cam]['prev'] == '':
+                         sweep_info = sweep_infos[-1]
+                         break
+                     sample_data = nusc.get('sample_data', curr_cams[cam]['prev'])
+                     sweep_cam = get_cam_info(nusc, sample_data)
+                     curr_cams[cam] = sample_data
+                     sweep_info[cam] = sweep_cam
+                 sweep_infos.append(sweep_info)
+
+         sample_infos['infos'][curr_id]['sweeps'] = sweep_infos
+
+     return sample_infos
+
+
+ if __name__ == '__main__':
+     nusc = NuScenes(args.version, args.data_root)
+
+     if args.version == 'v1.0-trainval':
+         sample_infos = pickle.load(open(os.path.join(args.data_root, 'nuscenes_infos_train.pkl'), 'rb'))
+         sample_infos = add_sweep_info(nusc, sample_infos)
+         mmcv.dump(sample_infos, os.path.join(args.data_root, 'nuscenes_infos_train_sweep.pkl'))
+
+         sample_infos = pickle.load(open(os.path.join(args.data_root, 'nuscenes_infos_val.pkl'), 'rb'))
+         sample_infos = add_sweep_info(nusc, sample_infos)
+         mmcv.dump(sample_infos, os.path.join(args.data_root, 'nuscenes_infos_val_sweep.pkl'))
+
+     elif args.version == 'v1.0-test':
+         sample_infos = pickle.load(open(os.path.join(args.data_root, 'nuscenes_infos_test.pkl'), 'rb'))
+         sample_infos = add_sweep_info(nusc, sample_infos)
+         mmcv.dump(sample_infos, os.path.join(args.data_root, 'nuscenes_infos_test_sweep.pkl'))
+
+     elif args.version == 'v1.0-mini':
+         sample_infos = pickle.load(open(os.path.join(args.data_root, 'nuscenes_infos_train_mini.pkl'), 'rb'))
+         sample_infos = add_sweep_info(nusc, sample_infos)
+         mmcv.dump(sample_infos, os.path.join(args.data_root, 'nuscenes_infos_train_mini_sweep.pkl'))
+
+         sample_infos = pickle.load(open(os.path.join(args.data_root, 'nuscenes_infos_val_mini.pkl'), 'rb'))
+         sample_infos = add_sweep_info(nusc, sample_infos)
+         mmcv.dump(sample_infos, os.path.join(args.data_root, 'nuscenes_infos_val_mini_sweep.pkl'))
+
+     else:
+         raise ValueError(f'unknown version: {args.version}')
loaders/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .pipelines import __all__
+ from .nuscenes_dataset import CustomNuScenesDataset
+
+ __all__ = [
+     'CustomNuScenesDataset'
+ ]
loaders/builder.py ADDED
@@ -0,0 +1,49 @@
+ from functools import partial
+ from mmcv.parallel import collate
+ from mmcv.runner import get_dist_info
+ from torch.utils.data import DataLoader
+ from mmdet.datasets.builder import worker_init_fn
+ from mmdet.datasets.samplers import DistributedGroupSampler, DistributedSampler, GroupSampler
+
+
+ def build_dataloader(dataset,
+                      samples_per_gpu,
+                      workers_per_gpu,
+                      num_gpus=1,
+                      dist=True,
+                      shuffle=True,
+                      seed=None,
+                      **kwargs):
+
+     rank, world_size = get_dist_info()
+     if dist:
+         # DistributedGroupSampler shuffles the data so that the images on
+         # each GPU fall into the same group
+         if shuffle:
+             sampler = DistributedGroupSampler(
+                 dataset, samples_per_gpu, world_size, rank, seed=seed)
+         else:
+             sampler = DistributedSampler(
+                 dataset, world_size, rank, shuffle=False, seed=seed)
+         batch_size = samples_per_gpu
+         num_workers = workers_per_gpu
+     else:
+         sampler = GroupSampler(dataset, samples_per_gpu) if shuffle else None
+         batch_size = num_gpus * samples_per_gpu
+         num_workers = num_gpus * workers_per_gpu
+
+     init_fn = partial(
+         worker_init_fn, num_workers=num_workers, rank=rank,
+         seed=seed) if seed is not None else None
+
+     data_loader = DataLoader(
+         dataset,
+         batch_size=batch_size,
+         sampler=sampler,
+         num_workers=num_workers,
+         collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+         pin_memory=False,
+         worker_init_fn=init_fn,
+         **kwargs)
+
+     return data_loader
loaders/nuscenes_dataset.py ADDED
@@ -0,0 +1,88 @@
+ import os
+ import numpy as np
+ from mmdet.datasets import DATASETS
+ from mmdet3d.datasets import NuScenesDataset
+ from pyquaternion import Quaternion
+
+
+ @DATASETS.register_module()
+ class CustomNuScenesDataset(NuScenesDataset):
+
+     def collect_sweeps(self, index, into_past=60, into_future=0):
+         all_sweeps_prev = []
+         curr_index = index
+         while len(all_sweeps_prev) < into_past:
+             curr_sweeps = self.data_infos[curr_index]['sweeps']
+             if len(curr_sweeps) == 0:
+                 break
+             all_sweeps_prev.extend(curr_sweeps)
+             all_sweeps_prev.append(self.data_infos[curr_index - 1]['cams'])
+             curr_index = curr_index - 1
+
+         all_sweeps_next = []
+         curr_index = index + 1
+         while len(all_sweeps_next) < into_future:
+             if curr_index >= len(self.data_infos):
+                 break
+             curr_sweeps = self.data_infos[curr_index]['sweeps']
+             all_sweeps_next.extend(curr_sweeps[::-1])
+             all_sweeps_next.append(self.data_infos[curr_index]['cams'])
+             curr_index = curr_index + 1
+
+         return all_sweeps_prev, all_sweeps_next
+
+     def get_data_info(self, index):
+         info = self.data_infos[index]
+         sweeps_prev, sweeps_next = self.collect_sweeps(index)
+
+         ego2global_translation = info['ego2global_translation']
+         ego2global_rotation = info['ego2global_rotation']
+         lidar2ego_translation = info['lidar2ego_translation']
+         lidar2ego_rotation = info['lidar2ego_rotation']
+         ego2global_rotation = Quaternion(ego2global_rotation).rotation_matrix
+         lidar2ego_rotation = Quaternion(lidar2ego_rotation).rotation_matrix
+
+         input_dict = dict(
+             sample_idx=info['token'],
+             sweeps={'prev': sweeps_prev, 'next': sweeps_next},
+             timestamp=info['timestamp'] / 1e6,
+             ego2global_translation=ego2global_translation,
+             ego2global_rotation=ego2global_rotation,
+             lidar2ego_translation=lidar2ego_translation,
+             lidar2ego_rotation=lidar2ego_rotation,
+         )
+
+         if self.modality['use_camera']:
+             img_paths = []
+             img_timestamps = []
+             lidar2img_rts = []
+
+             for _, cam_info in info['cams'].items():
+                 img_paths.append(os.path.relpath(cam_info['data_path']))
+                 img_timestamps.append(cam_info['timestamp'] / 1e6)
+
+                 # obtain the lidar-to-image transformation matrix
+                 lidar2cam_r = np.linalg.inv(cam_info['sensor2lidar_rotation'])
+                 lidar2cam_t = cam_info['sensor2lidar_translation'] @ lidar2cam_r.T
+
+                 lidar2cam_rt = np.eye(4)
+                 lidar2cam_rt[:3, :3] = lidar2cam_r.T
+                 lidar2cam_rt[3, :3] = -lidar2cam_t
+
+                 intrinsic = cam_info['cam_intrinsic']
+                 viewpad = np.eye(4)
+                 viewpad[:intrinsic.shape[0], :intrinsic.shape[1]] = intrinsic
+                 lidar2img_rt = (viewpad @ lidar2cam_rt.T)
+                 lidar2img_rts.append(lidar2img_rt)
+
+             input_dict.update(dict(
+                 img_filename=img_paths,
+                 img_timestamp=img_timestamps,
+                 lidar2img=lidar2img_rts,
+             ))
+
+         if not self.test_mode:
+             annos = self.get_ann_info(index)
+             input_dict['ann_info'] = annos
+
+         return input_dict
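The `lidar2img` construction used above (a padded intrinsic `viewpad` times the transposed extrinsic) can be sanity-checked in isolation. A sketch with an identity extrinsic and a made-up intrinsic, not part of the release:

```python
import numpy as np

# identity extrinsic: pretend the lidar and camera frames coincide
lidar2cam_rt = np.eye(4)

# hypothetical 3x3 pinhole intrinsic (fx = fy = 1000, principal point (800, 450))
intrinsic = np.array([[1000.0, 0.0, 800.0],
                      [0.0, 1000.0, 450.0],
                      [0.0, 0.0, 1.0]])
viewpad = np.eye(4)
viewpad[:3, :3] = intrinsic

# same composition as in get_data_info()
lidar2img = viewpad @ lidar2cam_rt.T

# a point 10 m along the optical axis projects onto the principal point
pt = lidar2img @ np.array([0.0, 0.0, 10.0, 1.0])
uv = pt[:2] / pt[2]
print(uv)  # [800. 450.]
```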
loaders/pipelines/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .loading import LoadMultiViewImageFromMultiSweeps
+ from .transforms import PadMultiViewImage, NormalizeMultiviewImage, PhotoMetricDistortionMultiViewImage
+
+ __all__ = [
+     'LoadMultiViewImageFromMultiSweeps', 'PadMultiViewImage', 'NormalizeMultiviewImage',
+     'PhotoMetricDistortionMultiViewImage'
+ ]
loaders/pipelines/loading.py ADDED
@@ -0,0 +1,154 @@
+ import os
+ import mmcv
+ import numpy as np
+ from mmdet.datasets.builder import PIPELINES
+ from numpy.linalg import inv
+ from mmcv.runner import get_dist_info
+
+
+ def compose_lidar2img(ego2global_translation_curr,
+                       ego2global_rotation_curr,
+                       lidar2ego_translation_curr,
+                       lidar2ego_rotation_curr,
+                       sensor2global_translation_past,
+                       sensor2global_rotation_past,
+                       cam_intrinsic_past):
+
+     R = sensor2global_rotation_past @ (inv(ego2global_rotation_curr).T @ inv(lidar2ego_rotation_curr).T)
+     T = sensor2global_translation_past @ (inv(ego2global_rotation_curr).T @ inv(lidar2ego_rotation_curr).T)
+     T -= ego2global_translation_curr @ (inv(ego2global_rotation_curr).T @ inv(lidar2ego_rotation_curr).T) + lidar2ego_translation_curr @ inv(lidar2ego_rotation_curr).T
+
+     lidar2cam_r = inv(R.T)
+     lidar2cam_t = T @ lidar2cam_r.T
+
+     lidar2cam_rt = np.eye(4)
+     lidar2cam_rt[:3, :3] = lidar2cam_r.T
+     lidar2cam_rt[3, :3] = -lidar2cam_t
+
+     viewpad = np.eye(4)
+     viewpad[:cam_intrinsic_past.shape[0], :cam_intrinsic_past.shape[1]] = cam_intrinsic_past
+     lidar2img = (viewpad @ lidar2cam_rt.T).astype(np.float32)
+
+     return lidar2img
+
+
+ @PIPELINES.register_module()
+ class LoadMultiViewImageFromMultiSweeps(object):
+     def __init__(self,
+                  sweeps_num=5,
+                  color_type='color',
+                  test_mode=False):
+         self.sweeps_num = sweeps_num
+         self.color_type = color_type
+         self.test_mode = test_mode
+
+         self.train_interval = [4, 8]
+         self.test_interval = 6
+
+         try:
+             mmcv.use_backend('turbojpeg')
+         except ImportError:
+             mmcv.use_backend('cv2')
+
+     def load_offline(self, results):
+         cam_types = [
+             'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT',
+             'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT'
+         ]
+
+         if len(results['sweeps']['prev']) == 0:
+             for _ in range(self.sweeps_num):
+                 for j in range(len(cam_types)):
+                     results['img'].append(results['img'][j])
+                     results['img_timestamp'].append(results['img_timestamp'][j])
+                     results['filename'].append(results['filename'][j])
+                     results['lidar2img'].append(np.copy(results['lidar2img'][j]))
+         else:
+             if self.test_mode:
+                 interval = self.test_interval
+                 choices = [(k + 1) * interval - 1 for k in range(self.sweeps_num)]
+             elif len(results['sweeps']['prev']) <= self.sweeps_num:
+                 pad_len = self.sweeps_num - len(results['sweeps']['prev'])
+                 choices = list(range(len(results['sweeps']['prev']))) + [len(results['sweeps']['prev']) - 1] * pad_len
+             else:
+                 max_interval = len(results['sweeps']['prev']) // self.sweeps_num
+                 max_interval = min(max_interval, self.train_interval[1])
+                 min_interval = min(max_interval, self.train_interval[0])
+                 interval = np.random.randint(min_interval, max_interval + 1)
+                 choices = [(k + 1) * interval - 1 for k in range(self.sweeps_num)]
+
+             for idx in sorted(list(choices)):
+                 sweep_idx = min(idx, len(results['sweeps']['prev']) - 1)
+                 sweep = results['sweeps']['prev'][sweep_idx]
+
+                 if len(sweep.keys()) < len(cam_types):
+                     sweep = results['sweeps']['prev'][sweep_idx - 1]
+
+                 for sensor in cam_types:
+                     results['img'].append(mmcv.imread(sweep[sensor]['data_path'], self.color_type))
+                     results['img_timestamp'].append(sweep[sensor]['timestamp'] / 1e6)
+                     results['filename'].append(os.path.relpath(sweep[sensor]['data_path']))
+                     results['lidar2img'].append(compose_lidar2img(
+                         results['ego2global_translation'],
+                         results['ego2global_rotation'],
+                         results['lidar2ego_translation'],
+                         results['lidar2ego_rotation'],
+                         sweep[sensor]['sensor2global_translation'],
+                         sweep[sensor]['sensor2global_rotation'],
+                         sweep[sensor]['cam_intrinsic'],
+                     ))
+
+         return results
+
+     def load_online(self, results):
+         # only used when measuring FPS
+         assert self.test_mode
+         assert self.test_interval == 6
+
+         cam_types = [
+             'CAM_FRONT', 'CAM_FRONT_RIGHT', 'CAM_FRONT_LEFT',
+             'CAM_BACK', 'CAM_BACK_LEFT', 'CAM_BACK_RIGHT'
+         ]
+
+         if len(results['sweeps']['prev']) == 0:
+             for _ in range(self.sweeps_num):
+                 for j in range(len(cam_types)):
+                     results['img_timestamp'].append(results['img_timestamp'][j])
+                     results['filename'].append(results['filename'][j])
+                     results['lidar2img'].append(np.copy(results['lidar2img'][j]))
+         else:
+             interval = self.test_interval
+             choices = [(k + 1) * interval - 1 for k in range(self.sweeps_num)]
+
+             for idx in sorted(list(choices)):
+                 sweep_idx = min(idx, len(results['sweeps']['prev']) - 1)
+                 sweep = results['sweeps']['prev'][sweep_idx]
+
+                 if len(sweep.keys()) < len(cam_types):
+                     sweep = results['sweeps']['prev'][sweep_idx - 1]
+
+                 for sensor in cam_types:
+                     # skip loading history frames
+                     results['img_timestamp'].append(sweep[sensor]['timestamp'] / 1e6)
+                     results['filename'].append(os.path.relpath(sweep[sensor]['data_path']))
+                     results['lidar2img'].append(compose_lidar2img(
+                         results['ego2global_translation'],
136
+ results['ego2global_rotation'],
137
+ results['lidar2ego_translation'],
138
+ results['lidar2ego_rotation'],
139
+ sweep[sensor]['sensor2global_translation'],
140
+ sweep[sensor]['sensor2global_rotation'],
141
+ sweep[sensor]['cam_intrinsic'],
142
+ ))
143
+
144
+ return results
145
+
146
+ def __call__(self, results):
147
+ if self.sweeps_num == 0:
148
+ return results
149
+
150
+ world_size = get_dist_info()[1]
151
+ if world_size == 1 and self.test_mode:
152
+ return self.load_online(results)
153
+ else:
154
+ return self.load_offline(results)
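The sweep-selection logic in `load_offline`/`load_online` can be exercised in isolation. A minimal sketch (pure Python; `pick_sweep_indices` is a hypothetical helper mirroring the test-mode branch with `test_interval = 6`, not part of the release):

```python
def pick_sweep_indices(num_prev, sweeps_num, test_interval=6):
    # Test-mode branch: take every `test_interval`-th previous sweep,
    # clamped to the available history (as done via min(idx, num_prev - 1)).
    choices = [(k + 1) * test_interval - 1 for k in range(sweeps_num)]
    return [min(idx, num_prev - 1) for idx in sorted(choices)]

# With 5 sweeps and enough history, indices 5, 11, 17, 23, 29 are used;
# with only 10 previous sweeps, later picks clamp to the oldest frame.
print(pick_sweep_indices(num_prev=40, sweeps_num=5))
print(pick_sweep_indices(num_prev=10, sweeps_num=5))
```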
loaders/pipelines/transforms.py ADDED
@@ -0,0 +1,394 @@
+ import mmcv
+ import torch
+ import numpy as np
+ from PIL import Image
+ from numpy import random
+ from mmdet.datasets.builder import PIPELINES
+
+
+ @PIPELINES.register_module()
+ class PadMultiViewImage(object):
+     """Pad the multi-view image.
+     There are two padding modes: (1) pad to a fixed size and (2) pad to the
+     minimum size that is divisible by some number.
+     Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor".
+     Args:
+         size (tuple, optional): Fixed padding size.
+         size_divisor (int, optional): The divisor of padded size.
+         pad_val (float, optional): Padding value, 0 by default.
+     """
+
+     def __init__(self, size=None, size_divisor=None, pad_val=0):
+         self.size = size
+         self.size_divisor = size_divisor
+         self.pad_val = pad_val
+         # only one of size and size_divisor should be valid
+         assert size is not None or size_divisor is not None
+         assert size is None or size_divisor is None
+
+     def _pad_img(self, img):
+         if self.size_divisor is not None:
+             pad_h = int(np.ceil(img.shape[0] / self.size_divisor)) * self.size_divisor
+             pad_w = int(np.ceil(img.shape[1] / self.size_divisor)) * self.size_divisor
+         else:
+             pad_h, pad_w = self.size
+
+         pad_width = ((0, pad_h - img.shape[0]), (0, pad_w - img.shape[1]), (0, 0))
+         img = np.pad(img, pad_width, constant_values=self.pad_val)
+         return img
+
+     def _pad_imgs(self, results):
+         padded_img = [self._pad_img(img) for img in results['img']]
+
+         results['ori_shape'] = [img.shape for img in results['img']]
+         results['img'] = padded_img
+         results['img_shape'] = [img.shape for img in padded_img]
+         results['pad_shape'] = [img.shape for img in padded_img]
+         results['pad_fixed_size'] = self.size
+         results['pad_size_divisor'] = self.size_divisor
+
+     def __call__(self, results):
+         """Call function to pad images, masks, semantic segmentation maps.
+         Args:
+             results (dict): Result dict from loading pipeline.
+         Returns:
+             dict: Updated result dict.
+         """
+         self._pad_imgs(results)
+         return results
+
+     def __repr__(self):
+         repr_str = self.__class__.__name__
+         repr_str += f'(size={self.size}, '
+         repr_str += f'size_divisor={self.size_divisor}, '
+         repr_str += f'pad_val={self.pad_val})'
+         return repr_str
+
+
+ @PIPELINES.register_module()
+ class NormalizeMultiviewImage(object):
+     """Normalize the image.
+     Added key is "img_norm_cfg".
+     Args:
+         mean (sequence): Mean values of 3 channels.
+         std (sequence): Std values of 3 channels.
+         to_rgb (bool): Whether to convert the image from BGR to RGB,
+             default is true.
+     """
+
+     def __init__(self, mean, std, to_rgb=True):
+         self.mean = np.array(mean, dtype=np.float32).reshape(-1)
+         self.std = 1 / np.array(std, dtype=np.float32).reshape(-1)
+         self.to_rgb = to_rgb
+
+     def __call__(self, results):
+         """Call function to normalize images.
+         Args:
+             results (dict): Result dict from loading pipeline.
+         Returns:
+             dict: Normalized results, 'img_norm_cfg' key is added into
+                 result dict.
+         """
+         normalized_imgs = []
+
+         for img in results['img']:
+             img = img.astype(np.float32)
+             if self.to_rgb:
+                 img = img[..., ::-1]
+             img = img - self.mean
+             img = img * self.std
+             normalized_imgs.append(img)
+
+         results['img'] = normalized_imgs
+         results['img_norm_cfg'] = dict(
+             mean=self.mean,
+             std=self.std,
+             to_rgb=self.to_rgb
+         )
+         return results
+
+     def __repr__(self):
+         repr_str = self.__class__.__name__
+         repr_str += f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})'
+         return repr_str
+
+
+ @PIPELINES.register_module()
+ class PhotoMetricDistortionMultiViewImage:
+     """Apply photometric distortion to image sequentially, every transformation
+     is applied with a probability of 0.5. The position of random contrast is in
+     second or second to last.
+     1. random brightness
+     2. random contrast (mode 0)
+     3. convert color from BGR to HSV
+     4. random saturation
+     5. random hue
+     6. convert color from HSV to BGR
+     7. random contrast (mode 1)
+     8. randomly swap channels
+     Args:
+         brightness_delta (int): delta of brightness.
+         contrast_range (tuple): range of contrast.
+         saturation_range (tuple): range of saturation.
+         hue_delta (int): delta of hue.
+     """
+
+     def __init__(self,
+                  brightness_delta=32,
+                  contrast_range=(0.5, 1.5),
+                  saturation_range=(0.5, 1.5),
+                  hue_delta=18):
+         self.brightness_delta = brightness_delta
+         self.contrast_lower, self.contrast_upper = contrast_range
+         self.saturation_lower, self.saturation_upper = saturation_range
+         self.hue_delta = hue_delta
+
+     def __call__(self, results):
+         """Call function to perform photometric distortion on images.
+         Args:
+             results (dict): Result dict from loading pipeline.
+         Returns:
+             dict: Result dict with images distorted.
+         """
+         imgs = results['img']
+         new_imgs = []
+         for img in imgs:
+             ori_dtype = img.dtype
+             img = img.astype(np.float32)
+
+             # random brightness
+             if random.randint(2):
+                 delta = random.uniform(-self.brightness_delta,
+                                        self.brightness_delta)
+                 img += delta
+
+             # mode == 0 --> do random contrast first
+             # mode == 1 --> do random contrast last
+             mode = random.randint(2)
+             if mode == 1:
+                 if random.randint(2):
+                     alpha = random.uniform(self.contrast_lower,
+                                            self.contrast_upper)
+                     img *= alpha
+
+             # convert color from BGR to HSV
+             img = mmcv.bgr2hsv(img)
+
+             # random saturation
+             if random.randint(2):
+                 img[..., 1] *= random.uniform(self.saturation_lower,
+                                               self.saturation_upper)
+
+             # random hue
+             if random.randint(2):
+                 img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
+                 img[..., 0][img[..., 0] > 360] -= 360
+                 img[..., 0][img[..., 0] < 0] += 360
+
+             # convert color from HSV to BGR
+             img = mmcv.hsv2bgr(img)
+
+             # random contrast
+             if mode == 0:
+                 if random.randint(2):
+                     alpha = random.uniform(self.contrast_lower,
+                                            self.contrast_upper)
+                     img *= alpha
+
+             # randomly swap channels
+             if random.randint(2):
+                 img = img[..., random.permutation(3)]
+
+             new_imgs.append(img.astype(ori_dtype))
+
+         results['img'] = new_imgs
+         return results
+
+     def __repr__(self):
+         repr_str = self.__class__.__name__
+         repr_str += f'(\nbrightness_delta={self.brightness_delta},\n'
+         repr_str += 'contrast_range='
+         repr_str += f'{(self.contrast_lower, self.contrast_upper)},\n'
+         repr_str += 'saturation_range='
+         repr_str += f'{(self.saturation_lower, self.saturation_upper)},\n'
+         repr_str += f'hue_delta={self.hue_delta})'
+         return repr_str
+
+
+ @PIPELINES.register_module()
+ class RandomTransformImage(object):
+     def __init__(self, ida_aug_conf=None, training=True):
+         self.ida_aug_conf = ida_aug_conf
+         self.training = training
+
+     def __call__(self, results):
+         resize, resize_dims, crop, flip, rotate = self.sample_augmentation()
+
+         if len(results['lidar2img']) == len(results['img']):
+             for i in range(len(results['img'])):
+                 img = Image.fromarray(np.uint8(results['img'][i]))
+
+                 # resize, resize_dims, crop, flip, rotate = self._sample_augmentation()
+                 img, ida_mat = self.img_transform(
+                     img,
+                     resize=resize,
+                     resize_dims=resize_dims,
+                     crop=crop,
+                     flip=flip,
+                     rotate=rotate,
+                 )
+                 results['img'][i] = np.array(img).astype(np.uint8)
+                 results['lidar2img'][i] = ida_mat @ results['lidar2img'][i]
+
+         elif len(results['img']) == 6:
+             for i in range(len(results['img'])):
+                 img = Image.fromarray(np.uint8(results['img'][i]))
+
+                 # resize, resize_dims, crop, flip, rotate = self._sample_augmentation()
+                 img, ida_mat = self.img_transform(
+                     img,
+                     resize=resize,
+                     resize_dims=resize_dims,
+                     crop=crop,
+                     flip=flip,
+                     rotate=rotate,
+                 )
+                 results['img'][i] = np.array(img).astype(np.uint8)
+
+             for i in range(len(results['lidar2img'])):
+                 results['lidar2img'][i] = ida_mat @ results['lidar2img'][i]
+
+         else:
+             raise ValueError()
+
+         results['ori_shape'] = [img.shape for img in results['img']]
+         results['img_shape'] = [img.shape for img in results['img']]
+         results['pad_shape'] = [img.shape for img in results['img']]
+
+         return results
+
+     def img_transform(self, img, resize, resize_dims, crop, flip, rotate):
+         """
+         https://github.com/Megvii-BaseDetection/BEVStereo/blob/master/dataset/nusc_mv_det_dataset.py#L48
+         """
+         def get_rot(h):
+             return torch.Tensor([
+                 [np.cos(h), np.sin(h)],
+                 [-np.sin(h), np.cos(h)],
+             ])
+
+         ida_rot = torch.eye(2)
+         ida_tran = torch.zeros(2)
+
+         # adjust image
+         img = img.resize(resize_dims)
+         img = img.crop(crop)
+         if flip:
+             img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
+         img = img.rotate(rotate)
+
+         # post-homography transformation
+         ida_rot *= resize
+         ida_tran -= torch.Tensor(crop[:2])
+
+         if flip:
+             A = torch.Tensor([[-1, 0], [0, 1]])
+             b = torch.Tensor([crop[2] - crop[0], 0])
+             ida_rot = A.matmul(ida_rot)
+             ida_tran = A.matmul(ida_tran) + b
+
+         A = get_rot(rotate / 180 * np.pi)
+         b = torch.Tensor([crop[2] - crop[0], crop[3] - crop[1]]) / 2
+         b = A.matmul(-b) + b
+
+         ida_rot = A.matmul(ida_rot)
+         ida_tran = A.matmul(ida_tran) + b
+
+         ida_mat = torch.eye(4)
+         ida_mat[:2, :2] = ida_rot
+         ida_mat[:2, 2] = ida_tran
+
+         return img, ida_mat.numpy()
+
+     def sample_augmentation(self):
+         """
+         https://github.com/Megvii-BaseDetection/BEVStereo/blob/master/dataset/nusc_mv_det_dataset.py#L247
+         """
+         H, W = self.ida_aug_conf['H'], self.ida_aug_conf['W']
+         fH, fW = self.ida_aug_conf['final_dim']
+
+         if self.training:
+             resize = np.random.uniform(*self.ida_aug_conf['resize_lim'])
+             resize_dims = (int(W * resize), int(H * resize))
+             newW, newH = resize_dims
+             crop_h = int((1 - np.random.uniform(*self.ida_aug_conf['bot_pct_lim'])) * newH) - fH
+             crop_w = int(np.random.uniform(0, max(0, newW - fW)))
+             crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
+             flip = False
+             if self.ida_aug_conf['rand_flip'] and np.random.choice([0, 1]):
+                 flip = True
+             rotate = np.random.uniform(*self.ida_aug_conf['rot_lim'])
+         else:
+             resize = max(fH / H, fW / W)
+             resize_dims = (int(W * resize), int(H * resize))
+             newW, newH = resize_dims
+             crop_h = int((1 - np.mean(self.ida_aug_conf['bot_pct_lim'])) * newH) - fH
+             crop_w = int(max(0, newW - fW) / 2)
+             crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
+             flip = False
+             rotate = 0
+
+         return resize, resize_dims, crop, flip, rotate
+
+
+ @PIPELINES.register_module()
+ class GlobalRotScaleTransImage(object):
+     def __init__(self,
+                  rot_range=[-0.3925, 0.3925],
+                  scale_ratio_range=[0.95, 1.05],
+                  translation_std=[0, 0, 0]):
+         self.rot_range = rot_range
+         self.scale_ratio_range = scale_ratio_range
+         self.translation_std = translation_std
+
+     def __call__(self, results):
+         # random rotate
+         rot_angle = np.random.uniform(*self.rot_range)
+         self.rotate_z(results, rot_angle)
+         results["gt_bboxes_3d"].rotate(np.array(rot_angle))
+
+         # random scale
+         scale_ratio = np.random.uniform(*self.scale_ratio_range)
+         self.scale_xyz(results, scale_ratio)
+         results["gt_bboxes_3d"].scale(scale_ratio)
+
+         # TODO: support translation
+
+         return results
+
+     def rotate_z(self, results, rot_angle):
+         rot_cos = torch.cos(torch.tensor(rot_angle))
+         rot_sin = torch.sin(torch.tensor(rot_angle))
+
+         rot_mat = torch.tensor([
+             [rot_cos, -rot_sin, 0, 0],
+             [rot_sin, rot_cos, 0, 0],
+             [0, 0, 1, 0],
+             [0, 0, 0, 1],
+         ])
+         rot_mat_inv = torch.inverse(rot_mat)
+
+         for view in range(len(results['lidar2img'])):
+             results['lidar2img'][view] = (torch.tensor(results['lidar2img'][view]).float() @ rot_mat_inv).numpy()
+
+     def scale_xyz(self, results, scale_ratio):
+         scale_mat = torch.tensor([
+             [scale_ratio, 0, 0, 0],
+             [0, scale_ratio, 0, 0],
+             [0, 0, scale_ratio, 0],
+             [0, 0, 0, 1],
+         ])
+         scale_mat_inv = torch.inverse(scale_mat)
+
+         for view in range(len(results['lidar2img'])):
+             results['lidar2img'][view] = (torch.tensor(results['lidar2img'][view]).float() @ scale_mat_inv).numpy()
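The `lidar2img` updates in `rotate_z`/`scale_xyz` right-multiply each projection matrix by the inverse of the scene transform, so that a point transformed in lidar space still projects to the same pixel. A numpy sketch of the scaling case (`scale_lidar2img` is a hypothetical stand-in for `scale_xyz`, not part of the release):

```python
import numpy as np

def scale_lidar2img(lidar2img, scale_ratio):
    # Scaling the scene by s is compensated by right-multiplying every
    # lidar2img matrix with the inverse 4x4 scale matrix.
    scale_mat = np.diag([scale_ratio, scale_ratio, scale_ratio, 1.0])
    return lidar2img @ np.linalg.inv(scale_mat)

# Identity projection scaled by 2: spatial entries shrink to 0.5.
print(scale_lidar2img(np.eye(4), 2.0))
```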
models/__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .backbones import __all__
+ from .bbox import __all__
+ from .sparsebev import SparseBEV
+ from .sparsebev_head import SparseBEVHead
+ from .sparsebev_transformer import SparseBEVTransformer
+
+ __all__ = [
+     'SparseBEV', 'SparseBEVHead', 'SparseBEVTransformer'
+ ]
models/backbones/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .vovnet import VoVNet
+
+ __all__ = ['VoVNet']
models/backbones/vovnet.py ADDED
@@ -0,0 +1,383 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import warnings
+ import torch.utils.checkpoint as cp
+ from collections import OrderedDict
+ from mmcv.runner import BaseModule
+ from mmdet.models.builder import BACKBONES
+ from torch.nn.modules.batchnorm import _BatchNorm
+
+
+ VoVNet19_slim_dw_eSE = {
+     'stem': [64, 64, 64],
+     'stage_conv_ch': [64, 80, 96, 112],
+     'stage_out_ch': [112, 256, 384, 512],
+     'layer_per_block': 3,
+     'block_per_stage': [1, 1, 1, 1],
+     'eSE': True,
+     'dw': True
+ }
+
+ VoVNet19_dw_eSE = {
+     'stem': [64, 64, 64],
+     'stage_conv_ch': [128, 160, 192, 224],
+     'stage_out_ch': [256, 512, 768, 1024],
+     'layer_per_block': 3,
+     'block_per_stage': [1, 1, 1, 1],
+     'eSE': True,
+     'dw': True
+ }
+
+ VoVNet19_slim_eSE = {
+     'stem': [64, 64, 128],
+     'stage_conv_ch': [64, 80, 96, 112],
+     'stage_out_ch': [112, 256, 384, 512],
+     'layer_per_block': 3,
+     'block_per_stage': [1, 1, 1, 1],
+     'eSE': True,
+     'dw': False
+ }
+
+ VoVNet19_eSE = {
+     'stem': [64, 64, 128],
+     'stage_conv_ch': [128, 160, 192, 224],
+     'stage_out_ch': [256, 512, 768, 1024],
+     'layer_per_block': 3,
+     'block_per_stage': [1, 1, 1, 1],
+     'eSE': True,
+     'dw': False
+ }
+
+ VoVNet39_eSE = {
+     'stem': [64, 64, 128],
+     'stage_conv_ch': [128, 160, 192, 224],
+     'stage_out_ch': [256, 512, 768, 1024],
+     'layer_per_block': 5,
+     'block_per_stage': [1, 1, 2, 2],
+     'eSE': True,
+     'dw': False
+ }
+
+ VoVNet57_eSE = {
+     'stem': [64, 64, 128],
+     'stage_conv_ch': [128, 160, 192, 224],
+     'stage_out_ch': [256, 512, 768, 1024],
+     'layer_per_block': 5,
+     'block_per_stage': [1, 1, 4, 3],
+     'eSE': True,
+     'dw': False
+ }
+
+ VoVNet99_eSE = {
+     'stem': [64, 64, 128],
+     'stage_conv_ch': [128, 160, 192, 224],
+     'stage_out_ch': [256, 512, 768, 1024],
+     'layer_per_block': 5,
+     'block_per_stage': [1, 3, 9, 3],
+     'eSE': True,
+     'dw': False
+ }
+
+ _STAGE_SPECS = {
+     "V-19-slim-dw-eSE": VoVNet19_slim_dw_eSE,
+     "V-19-dw-eSE": VoVNet19_dw_eSE,
+     "V-19-slim-eSE": VoVNet19_slim_eSE,
+     "V-19-eSE": VoVNet19_eSE,
+     "V-39-eSE": VoVNet39_eSE,
+     "V-57-eSE": VoVNet57_eSE,
+     "V-99-eSE": VoVNet99_eSE,
+ }
+
+
+ def dw_conv3x3(in_channels, out_channels, module_name, postfix, stride=1, kernel_size=3, padding=1):
+     """3x3 depthwise-separable convolution with padding"""
+     return [
+         (
+             '{}_{}/dw_conv3x3'.format(module_name, postfix),
+             nn.Conv2d(
+                 in_channels,
+                 out_channels,
+                 kernel_size=kernel_size,
+                 stride=stride,
+                 padding=padding,
+                 groups=out_channels,
+                 bias=False
+             )
+         ),
+         (
+             '{}_{}/pw_conv1x1'.format(module_name, postfix),
+             nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, groups=1, bias=False)
+         ),
+         ('{}_{}/pw_norm'.format(module_name, postfix), nn.BatchNorm2d(out_channels)),
+         ('{}_{}/pw_relu'.format(module_name, postfix), nn.ReLU(inplace=True)),
+     ]
+
+
+ def conv3x3(in_channels, out_channels, module_name, postfix, stride=1, groups=1, kernel_size=3, padding=1):
+     """3x3 convolution with padding"""
+     return [
+         (
+             f"{module_name}_{postfix}/conv",
+             nn.Conv2d(
+                 in_channels,
+                 out_channels,
+                 kernel_size=kernel_size,
+                 stride=stride,
+                 padding=padding,
+                 groups=groups,
+                 bias=False,
+             ),
+         ),
+         (f"{module_name}_{postfix}/norm", nn.BatchNorm2d(out_channels)),
+         (f"{module_name}_{postfix}/relu", nn.ReLU(inplace=True)),
+     ]
+
+
+ def conv1x1(in_channels, out_channels, module_name, postfix, stride=1, groups=1, kernel_size=1, padding=0):
+     """1x1 convolution"""
+     return [
+         (
+             f"{module_name}_{postfix}/conv",
+             nn.Conv2d(
+                 in_channels,
+                 out_channels,
+                 kernel_size=kernel_size,
+                 stride=stride,
+                 padding=padding,
+                 groups=groups,
+                 bias=False,
+             ),
+         ),
+         (f"{module_name}_{postfix}/norm", nn.BatchNorm2d(out_channels)),
+         (f"{module_name}_{postfix}/relu", nn.ReLU(inplace=True)),
+     ]
+
+
+ class Hsigmoid(nn.Module):
+     def __init__(self, inplace=True):
+         super(Hsigmoid, self).__init__()
+         self.inplace = inplace
+
+     def forward(self, x):
+         return F.relu6(x + 3.0, inplace=self.inplace) / 6.0
+
+
+ class eSEModule(nn.Module):
+     def __init__(self, channel, reduction=4):
+         super(eSEModule, self).__init__()
+         self.avg_pool = nn.AdaptiveAvgPool2d(1)
+         self.fc = nn.Conv2d(channel, channel, kernel_size=1, padding=0)
+         self.hsigmoid = Hsigmoid()
+
+     def forward(self, x):
+         inputs = x
+         x = self.avg_pool(x)
+         x = self.fc(x)
+         x = self.hsigmoid(x)
+         return inputs * x
+
+
+ class _OSA_module(nn.Module):
+     def __init__(self, in_ch, stage_ch, concat_ch, layer_per_block, module_name, SE=False, identity=False, depthwise=False, with_cp=False):
+         super(_OSA_module, self).__init__()
+         self.with_cp = with_cp
+
+         self.identity = identity
+         self.depthwise = depthwise
+         self.isReduced = False
+         self.layers = nn.ModuleList()
+         in_channel = in_ch
+
+         if self.depthwise and in_channel != stage_ch:
+             self.isReduced = True
+             self.conv_reduction = nn.Sequential(
+                 OrderedDict(conv1x1(in_channel, stage_ch, "{}_reduction".format(module_name), "0"))
+             )
+
+         for i in range(layer_per_block):
+             if self.depthwise:
+                 self.layers.append(nn.Sequential(OrderedDict(dw_conv3x3(stage_ch, stage_ch, module_name, i))))
+             else:
+                 self.layers.append(nn.Sequential(OrderedDict(conv3x3(in_channel, stage_ch, module_name, i))))
+             in_channel = stage_ch
+
+         # feature aggregation
+         in_channel = in_ch + layer_per_block * stage_ch
+         self.concat = nn.Sequential(OrderedDict(conv1x1(in_channel, concat_ch, module_name, "concat")))
+
+         self.ese = eSEModule(concat_ch)
+
+     def _forward(self, x):
+         identity_feat = x
+
+         output = []
+         output.append(x)
+
+         if self.depthwise and self.isReduced:
+             x = self.conv_reduction(x)
+
+         for layer in self.layers:
+             x = layer(x)
+             output.append(x)
+
+         x = torch.cat(output, dim=1)
+         xt = self.concat(x)
+
+         xt = self.ese(xt)
+
+         if self.identity:
+             xt = xt + identity_feat
+
+         return xt
+
+     def forward(self, x):
+         if self.with_cp and self.training and x.requires_grad:
+             return cp.checkpoint(self._forward, x)
+         else:
+             return self._forward(x)
+
+
+ class _OSA_stage(nn.Sequential):
+     def __init__(self, in_ch, stage_ch, concat_ch, block_per_stage, layer_per_block, stage_num, SE=False, depthwise=False, with_cp=False):
+         super(_OSA_stage, self).__init__()
+         if not stage_num == 2:
+             self.add_module("Pooling", nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True))
+
+         if block_per_stage != 1:
+             SE = False
+
+         module_name = f"OSA{stage_num}_1"
+         self.add_module(
+             module_name, _OSA_module(in_ch, stage_ch, concat_ch, layer_per_block, module_name, SE, depthwise=depthwise, with_cp=with_cp)
+         )
+
+         for i in range(block_per_stage - 1):
+             if i != block_per_stage - 2:  # last block
+                 SE = False
+             module_name = f"OSA{stage_num}_{i + 2}"
+             self.add_module(
+                 module_name,
+                 _OSA_module(
+                     concat_ch,
+                     stage_ch,
+                     concat_ch,
+                     layer_per_block,
+                     module_name,
+                     SE,
+                     identity=True,
+                     depthwise=depthwise,
+                     with_cp=with_cp
+                 ),
+             )
+
+
+ @BACKBONES.register_module()
+ class VoVNet(BaseModule):
+     def __init__(self, spec_name, input_ch=3, out_features=None, frozen_stages=-1, norm_eval=True, with_cp=False, pretrained=None, init_cfg=None):
+         """
+         Args:
+             input_ch (int): the number of input channels
+             out_features (list[str]): names of the layers whose outputs should
+                 be returned in forward. Can be anything in "stem", "stage2" ...
+         """
+         super(VoVNet, self).__init__(init_cfg)
+         self.frozen_stages = frozen_stages
+         self.norm_eval = norm_eval
+
+         if isinstance(pretrained, str):
+             warnings.warn('DeprecationWarning: pretrained is deprecated, '
+                           'please use "init_cfg" instead')
+             self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+         stage_specs = _STAGE_SPECS[spec_name]
+
+         stem_ch = stage_specs["stem"]
+         config_stage_ch = stage_specs["stage_conv_ch"]
+         config_concat_ch = stage_specs["stage_out_ch"]
+         block_per_stage = stage_specs["block_per_stage"]
+         layer_per_block = stage_specs["layer_per_block"]
+         SE = stage_specs["eSE"]
+         depthwise = stage_specs["dw"]
+
+         self._out_features = out_features
+
+         # Stem module
+         conv_type = dw_conv3x3 if depthwise else conv3x3
+         stem = conv3x3(input_ch, stem_ch[0], "stem", "1", 2)
+         stem += conv_type(stem_ch[0], stem_ch[1], "stem", "2", 1)
+         stem += conv_type(stem_ch[1], stem_ch[2], "stem", "3", 2)
+         self.add_module("stem", nn.Sequential(OrderedDict(stem)))
+         current_stride = 4
+         self._out_feature_strides = {"stem": current_stride, "stage2": current_stride}
+         self._out_feature_channels = {"stem": stem_ch[2]}
+
+         stem_out_ch = [stem_ch[2]]
+         in_ch_list = stem_out_ch + config_concat_ch[:-1]
+
+         # OSA stages
+         self.stage_names = []
+         for i in range(4):  # num_stages
+             name = "stage%d" % (i + 2)  # stage 2 ... stage 5
+             self.stage_names.append(name)
+             self.add_module(
+                 name,
+                 _OSA_stage(
+                     in_ch_list[i],
+                     config_stage_ch[i],
+                     config_concat_ch[i],
+                     block_per_stage[i],
+                     layer_per_block,
+                     i + 2,
+                     SE,
+                     depthwise,
+                     with_cp=with_cp
+                 ),
+             )
+
+             self._out_feature_channels[name] = config_concat_ch[i]
+             if not i == 0:
+                 self._out_feature_strides[name] = current_stride = int(current_stride * 2)
+
+         # initialize weights
+         # self._initialize_weights()
+
+     def _initialize_weights(self):
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight)
+
+     def forward(self, x):
+         outputs = {}
+         x = self.stem(x)
+         if "stem" in self._out_features:
+             outputs["stem"] = x
+         for name in self.stage_names:
+             x = getattr(self, name)(x)
+             if name in self._out_features:
+                 outputs[name] = x
+
+         return outputs
+
+     def _freeze_stages(self):
+         if self.frozen_stages >= 0:
+             m = getattr(self, 'stem')
+             m.eval()
+             for param in m.parameters():
+                 param.requires_grad = False
+
+         for i in range(1, self.frozen_stages + 1):
+             m = getattr(self, f'stage{i+1}')
+             m.eval()
+             for param in m.parameters():
+                 param.requires_grad = False
+
+     def train(self, mode=True):
+         """Convert the model into training mode while keeping the
+         normalization layers frozen."""
+         super(VoVNet, self).train(mode)
+         self._freeze_stages()
+         if mode and self.norm_eval:
+             for m in self.modules():
+                 # trick: eval has an effect on BatchNorm only
+                 if isinstance(m, _BatchNorm):
+                     m.eval()
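The `Hsigmoid` gate used by the eSE block above is the hard sigmoid `relu6(x + 3) / 6`. A scalar sketch without torch (assumption: `hsigmoid` below matches the module's element-wise behaviour):

```python
def hsigmoid(x):
    # Hard sigmoid: clamp(x + 3, 0, 6) / 6, i.e. relu6(x + 3) / 6.
    # Saturates at 0 for x <= -3 and at 1 for x >= 3; 0.5 at the origin.
    return min(max(x + 3.0, 0.0), 6.0) / 6.0

print(hsigmoid(0.0))  # 0.5
```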
models/bbox/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .assigners import __all__
+ from .coders import __all__
+ from .match_costs import __all__
models/bbox/assigners/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .hungarian_assigner_3d import HungarianAssigner3D
+
+ __all__ = ['HungarianAssigner3D']
models/bbox/assigners/hungarian_assigner_3d.py ADDED
@@ -0,0 +1,93 @@
+ import torch
2
+
3
+ from mmdet.core.bbox.builder import BBOX_ASSIGNERS
4
+ from mmdet.core.bbox.assigners import AssignResult
5
+ from mmdet.core.bbox.assigners import BaseAssigner
6
+ from mmdet.core.bbox.match_costs import build_match_cost
7
+ from ..utils import normalize_bbox
8
+
9
+ try:
10
    from scipy.optimize import linear_sum_assignment
except ImportError:
    linear_sum_assignment = None


@BBOX_ASSIGNERS.register_module()
class HungarianAssigner3D(BaseAssigner):
    def __init__(self,
                 cls_cost=dict(type='ClassificationCost', weight=1.),
                 reg_cost=dict(type='BBoxL1Cost', weight=1.0),
                 iou_cost=dict(type='IoUCost', weight=0.0),
                 pc_range=None):
        self.cls_cost = build_match_cost(cls_cost)
        self.reg_cost = build_match_cost(reg_cost)
        self.iou_cost = build_match_cost(iou_cost)
        self.pc_range = pc_range

    def assign(self,
               bbox_pred,
               cls_pred,
               gt_bboxes,
               gt_labels,
               gt_bboxes_ignore=None,
               code_weights=None,
               with_velo=False):
        assert gt_bboxes_ignore is None, \
            'Only case when gt_bboxes_ignore is None is supported.'
        num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)

        # 1. assign -1 by default
        assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
                                              -1,
                                              dtype=torch.long)
        assigned_labels = bbox_pred.new_full((num_bboxes, ),
                                             -1,
                                             dtype=torch.long)
        if num_gts == 0 or num_bboxes == 0:
            # No ground truth or boxes, return empty assignment
            if num_gts == 0:
                # No ground truth, assign all to background
                assigned_gt_inds[:] = 0
            return AssignResult(
                num_gts, assigned_gt_inds, None, labels=assigned_labels)

        # 2. compute the weighted costs
        # classification and bbox cost
        cls_cost = self.cls_cost(cls_pred, gt_labels)
        # regression L1 cost
        normalized_gt_bboxes = normalize_bbox(gt_bboxes)

        if code_weights is not None:
            bbox_pred = bbox_pred * code_weights
            normalized_gt_bboxes = normalized_gt_bboxes * code_weights

        if with_velo:
            reg_cost = self.reg_cost(bbox_pred, normalized_gt_bboxes)
        else:
            reg_cost = self.reg_cost(bbox_pred[:, :8], normalized_gt_bboxes[:, :8])

        # weighted sum of the two costs above
        cost = cls_cost + reg_cost

        # 3. do Hungarian matching on CPU using linear_sum_assignment
        cost = cost.detach().cpu()
        cost = torch.nan_to_num(cost, nan=100.0, posinf=100.0, neginf=-100.0)

        if linear_sum_assignment is None:
            raise ImportError('Please run "pip install scipy" '
                              'to install scipy first.')

        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
        matched_row_inds = torch.from_numpy(matched_row_inds).to(
            bbox_pred.device)
        matched_col_inds = torch.from_numpy(matched_col_inds).to(
            bbox_pred.device)

        # 4. assign backgrounds and foregrounds
        # assign all indices to backgrounds first
        assigned_gt_inds[:] = 0
        # assign foregrounds based on matching results
        assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
        assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
        return AssignResult(
            num_gts, assigned_gt_inds, None, labels=assigned_labels)
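A minimal standalone sketch of the Hungarian matching step: `scipy.optimize.linear_sum_assignment` takes a `[num_pred, num_gt]` cost matrix and returns the row/column indices of the globally cheapest one-to-one matching, which is how `assign` above pairs predictions with ground-truth boxes. The cost values below are made up for illustration.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Hypothetical [num_pred=3, num_gt=3] matching cost matrix.
cost = np.array([[4.0, 1.0, 3.0],
                 [2.0, 0.0, 5.0],
                 [3.0, 2.0, 2.0]])

# Rows are predictions, columns are ground truths; the returned indices
# minimize the total cost over all one-to-one assignments.
row_ind, col_ind = linear_sum_assignment(cost)
total = cost[row_ind, col_ind].sum()
# row_ind = [0, 1, 2], col_ind = [1, 0, 2], total = 1 + 2 + 2 = 5
```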
models/bbox/coders/__init__.py ADDED
from .nms_free_coder import NMSFreeCoder

__all__ = ['NMSFreeCoder']
models/bbox/coders/nms_free_coder.py ADDED
import torch

from mmdet.core.bbox import BaseBBoxCoder
from mmdet.core.bbox.builder import BBOX_CODERS
from ..utils import denormalize_bbox


@BBOX_CODERS.register_module()
class NMSFreeCoder(BaseBBoxCoder):
    """Bbox coder for NMS-free detector.

    Args:
        pc_range (list[float]): Range of point cloud.
        post_center_range (list[float]): Limit of the center.
            Default: None.
        max_num (int): Max number of boxes to be kept. Default: 100.
        score_threshold (float): Threshold to filter boxes based on score.
            Default: None.
        num_classes (int): Number of classes. Default: 10.
    """
    def __init__(self,
                 pc_range,
                 voxel_size=None,
                 post_center_range=None,
                 max_num=100,
                 score_threshold=None,
                 num_classes=10):
        self.pc_range = pc_range
        self.voxel_size = voxel_size
        self.post_center_range = post_center_range
        self.max_num = max_num
        self.score_threshold = score_threshold
        self.num_classes = num_classes

    def encode(self):
        pass

    def decode_single(self, cls_scores, bbox_preds):
        """Decode bboxes of a single sample.

        Args:
            cls_scores (Tensor): Outputs from the classification head,
                shape [num_query, cls_out_channels]. Note that
                cls_out_channels should include the background class.
            bbox_preds (Tensor): Outputs from the regression head with
                normalized coordinate format
                (cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy).
                Shape [num_query, 10].
        Returns:
            dict: Decoded boxes, scores and labels.
        """
        max_num = self.max_num

        cls_scores = cls_scores.sigmoid()
        scores, indices = cls_scores.view(-1).topk(max_num)
        labels = indices % self.num_classes
        bbox_index = torch.div(indices, self.num_classes, rounding_mode='trunc')
        bbox_preds = bbox_preds[bbox_index]

        final_box_preds = denormalize_bbox(bbox_preds)
        final_scores = scores
        final_preds = labels

        # use score threshold
        if self.score_threshold is not None:
            thresh_mask = final_scores > self.score_threshold

        if self.post_center_range is not None:
            limit = torch.tensor(self.post_center_range, device=scores.device)
            mask = (final_box_preds[..., :3] >= limit[:3]).all(1)
            mask &= (final_box_preds[..., :3] <= limit[3:]).all(1)

            if self.score_threshold is not None:
                mask &= thresh_mask

            boxes3d = final_box_preds[mask]
            scores = final_scores[mask]
            labels = final_preds[mask]
            predictions_dict = {
                'bboxes': boxes3d,
                'scores': scores,
                'labels': labels
            }
        else:
            raise NotImplementedError(
                'Need to reorganize output as a batch; only the case where '
                'post_center_range is not None is supported for now!')

        return predictions_dict

    def decode(self, preds_dicts):
        """Decode bboxes for a batch.

        Args:
            preds_dicts (dict): Contains:
                all_cls_scores (Tensor): Outputs from the classification head,
                    shape [nb_dec, bs, num_query, cls_out_channels]. Note that
                    cls_out_channels should include the background class.
                all_bbox_preds (Tensor): Sigmoid outputs from the regression
                    head with normalized coordinate format
                    (cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy).
                    Shape [nb_dec, bs, num_query, 10].
        Returns:
            list[dict]: Decoded boxes for each sample.
        """
        all_cls_scores = preds_dicts['all_cls_scores'][-1]
        all_bbox_preds = preds_dicts['all_bbox_preds'][-1]

        batch_size = all_cls_scores.size()[0]
        predictions_list = []
        for i in range(batch_size):
            predictions_list.append(
                self.decode_single(all_cls_scores[i], all_bbox_preds[i]))

        return predictions_list
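A small stdlib illustration of the index arithmetic in `decode_single` above: after flattening the `[num_query, num_classes]` score matrix, each top-k index encodes both a query and a class. Modulo recovers the class id and integer division recovers the query (bbox) index. The numbers here are hypothetical.

```python
num_classes = 10

# e.g. one index returned by topk on the flattened [num_query, 10] scores:
flat_index = 57

label = flat_index % num_classes        # class id within the query's row
bbox_index = flat_index // num_classes  # which query (row) the score came from

# flat_index 57 = query 5, class 7
assert (bbox_index, label) == (5, 7)
```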
models/bbox/match_costs/__init__.py ADDED
from .match_cost import BBox3DL1Cost

__all__ = ['BBox3DL1Cost']
models/bbox/match_costs/match_cost.py ADDED
import torch
from mmdet.core.bbox.match_costs.builder import MATCH_COST


@MATCH_COST.register_module()
class BBox3DL1Cost(object):
    """BBox3DL1Cost.

    Args:
        weight (int | float, optional): loss weight
    """

    def __init__(self, weight=1.0):
        self.weight = weight

    def __call__(self, bbox_pred, gt_bboxes):
        """
        Args:
            bbox_pred (Tensor): Predicted 3D boxes in normalized format
                (cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy).
                Shape [num_query, code_size].
            gt_bboxes (Tensor): Ground truth boxes in the same normalized
                format. Shape [num_gt, code_size].
        Returns:
            torch.Tensor: bbox_cost value with weight
        """
        bbox_cost = torch.cdist(bbox_pred, gt_bboxes, p=1)
        return bbox_cost * self.weight


@MATCH_COST.register_module()
class BBoxBEVL1Cost(object):
    def __init__(self, weight, pc_range):
        self.weight = weight
        self.pc_range = pc_range

    def __call__(self, bboxes, gt_bboxes):
        pc_start = bboxes.new(self.pc_range[0:2])
        pc_range = bboxes.new(self.pc_range[3:5]) - bboxes.new(self.pc_range[0:2])
        # normalize the box center to [0, 1]
        normalized_bboxes_xy = (bboxes[:, :2] - pc_start) / pc_range
        normalized_gt_bboxes_xy = (gt_bboxes[:, :2] - pc_start) / pc_range
        reg_cost = torch.cdist(normalized_bboxes_xy, normalized_gt_bboxes_xy, p=1)
        return reg_cost * self.weight


@MATCH_COST.register_module()
class IoU3DCost(object):
    def __init__(self, weight):
        self.weight = weight

    def __call__(self, iou):
        iou_cost = -iou
        return iou_cost * self.weight
models/bbox/utils.py ADDED
import torch


def normalize_bbox(bboxes):
    cx = bboxes[..., 0:1]
    cy = bboxes[..., 1:2]
    cz = bboxes[..., 2:3]
    w = bboxes[..., 3:4].log()
    l = bboxes[..., 4:5].log()
    h = bboxes[..., 5:6].log()
    rot = bboxes[..., 6:7]

    if bboxes.size(-1) > 7:
        vx = bboxes[..., 7:8]
        vy = bboxes[..., 8:9]
        out = torch.cat([cx, cy, w, l, cz, h, rot.sin(), rot.cos(), vx, vy], dim=-1)
    else:
        out = torch.cat([cx, cy, w, l, cz, h, rot.sin(), rot.cos()], dim=-1)

    return out


def denormalize_bbox(normalized_bboxes):
    rot_sin = normalized_bboxes[..., 6:7]
    rot_cos = normalized_bboxes[..., 7:8]
    rot = torch.atan2(rot_sin, rot_cos)

    cx = normalized_bboxes[..., 0:1]
    cy = normalized_bboxes[..., 1:2]
    cz = normalized_bboxes[..., 4:5]

    w = normalized_bboxes[..., 2:3].exp()
    l = normalized_bboxes[..., 3:4].exp()
    h = normalized_bboxes[..., 5:6].exp()

    if normalized_bboxes.size(-1) > 8:
        vx = normalized_bboxes[..., 8:9]
        vy = normalized_bboxes[..., 9:10]
        out = torch.cat([cx, cy, cz, w, l, h, rot, vx, vy], dim=-1)
    else:
        out = torch.cat([cx, cy, cz, w, l, h, rot], dim=-1)

    return out


def encode_bbox(bboxes, pc_range=None):
    xyz = bboxes[..., 0:3].clone()
    wlh = bboxes[..., 3:6].log()
    rot = bboxes[..., 6:7]

    if pc_range is not None:
        xyz[..., 0] = (xyz[..., 0] - pc_range[0]) / (pc_range[3] - pc_range[0])
        xyz[..., 1] = (xyz[..., 1] - pc_range[1]) / (pc_range[4] - pc_range[1])
        xyz[..., 2] = (xyz[..., 2] - pc_range[2]) / (pc_range[5] - pc_range[2])

    if bboxes.shape[-1] > 7:
        vel = bboxes[..., 7:9].clone()
        return torch.cat([xyz, wlh, rot.sin(), rot.cos(), vel], dim=-1)
    else:
        return torch.cat([xyz, wlh, rot.sin(), rot.cos()], dim=-1)


def decode_bbox(bboxes, pc_range=None):
    xyz = bboxes[..., 0:3].clone()
    wlh = bboxes[..., 3:6].exp()
    rot = torch.atan2(bboxes[..., 6:7], bboxes[..., 7:8])

    if pc_range is not None:
        xyz[..., 0] = xyz[..., 0] * (pc_range[3] - pc_range[0]) + pc_range[0]
        xyz[..., 1] = xyz[..., 1] * (pc_range[4] - pc_range[1]) + pc_range[1]
        xyz[..., 2] = xyz[..., 2] * (pc_range[5] - pc_range[2]) + pc_range[2]

    if bboxes.shape[-1] > 8:
        vel = bboxes[..., 8:10].clone()
        return torch.cat([xyz, wlh, rot, vel], dim=-1)
    else:
        return torch.cat([xyz, wlh, rot], dim=-1)
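A minimal stdlib sketch (no torch) of the yaw encoding used by `encode_bbox`/`decode_bbox` above: the rotation is stored as a (sin, cos) pair and recovered with `atan2`, which avoids the 2π wrap-around ambiguity of regressing the raw angle directly.

```python
import math

def encode_rot(rot):
    # store the angle as a continuous (sin, cos) pair
    return (math.sin(rot), math.cos(rot))

def decode_rot(rot_sin, rot_cos):
    # atan2 recovers the angle in (-pi, pi] from the pair
    return math.atan2(rot_sin, rot_cos)

rot = -2.5  # any angle in (-pi, pi]
assert abs(decode_rot(*encode_rot(rot)) - rot) < 1e-9
```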
models/checkpoint.py ADDED
# https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html#checkpoint

import torch
import warnings
import weakref
from typing import Any, Iterable, List, Tuple

__all__ = [
    "checkpoint", "checkpoint_sequential", "CheckpointFunction",
    "check_backward_validity", "detach_variable", "get_device_states",
    "set_device_states",
]


def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue

            x = inp.detach()
            x.requires_grad = inp.requires_grad
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got unsupported input type: ",
            type(inputs).__name__)


def check_backward_validity(inputs: Iterable[Any]) -> None:
    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")


# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases. Compromise: Stash the RNG state for
# the device of all Tensor args.
#
# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_gpu_devices = list({arg.get_device() for arg in args
                            if isinstance(arg, torch.Tensor) and arg.is_cuda})

    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())

    return fwd_gpu_devices, fwd_gpu_states


def set_device_states(devices, states) -> None:
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)


def _get_autocast_kwargs():
    gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
                           "dtype": torch.get_autocast_gpu_dtype(),
                           "cache_enabled": torch.is_autocast_cache_enabled()}

    cpu_autocast_kwargs = {"enabled": torch.is_autocast_cpu_enabled(),
                           "dtype": torch.get_autocast_cpu_dtype(),
                           "cache_enabled": torch.is_autocast_cache_enabled()}

    return gpu_autocast_kwargs, cpu_autocast_kwargs


class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.gpu_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs()
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument.")
        # Copy the list to avoid modifying the original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in inputs with the appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(tuple(inputs))
            with torch.enable_grad(), \
                 torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
                 torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() only with tensors that require grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of the outputs has requires_grad=True,"
                " this checkpoint() is not necessary")
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)

        return (None, None) + grads


def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
    r"""Checkpoint a model or part of the model

    Checkpointing works by trading compute for memory. Rather than storing all
    intermediate activations of the entire computation graph for computing
    backward, the checkpointed part does **not** save intermediate activations,
    and instead recomputes them in the backward pass. It can be applied on any
    part of a model.

    Specifically, in the forward pass, :attr:`function` will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. Instead, the forward pass saves the inputs tuple and the
    :attr:`function` parameter. In the backward pass, the saved inputs and
    :attr:`function` are retrieved, and the forward pass is computed on
    :attr:`function` again, now tracking the intermediate activations, and then
    the gradients are calculated using these activation values.

    The output of :attr:`function` can contain non-Tensor values and gradient
    recording is only performed for the Tensor values. Note that if the output
    consists of nested structures (ex: custom objects, lists, dicts, etc.)
    consisting of Tensors, these Tensors nested in custom structures will not
    be considered as part of autograd.

    .. warning::
        If :attr:`function` invocation during backward does anything different
        than the one during forward, e.g., due to some global variable, the
        checkpointed version won't be equivalent, and unfortunately it can't be
        detected.

    .. warning::
        If ``use_reentrant=True`` is specified, then if the checkpointed segment
        contains tensors detached from the computational graph by `detach()` or
        `torch.no_grad()`, the backward pass will raise an error. This is
        because `checkpoint` makes all the outputs require gradients which
        causes issues when a tensor is defined to have no gradient in the model.
        To circumvent this, detach the tensors outside of the `checkpoint`
        function. Note that the checkpointed segment can contain tensors
        detached from the computational graph if ``use_reentrant=False`` is
        specified.

    .. warning::
        If ``use_reentrant=True`` is specified, at least one of the inputs needs
        to have :code:`requires_grad=True` if grads are needed for model inputs,
        otherwise the checkpointed part of the model won't have gradients. At
        least one of the outputs needs to have :code:`requires_grad=True` as
        well. Note that this does not apply if ``use_reentrant=False`` is
        specified.

    .. warning::
        If ``use_reentrant=True`` is specified, checkpointing currently only
        supports :func:`torch.autograd.backward` and only if its `inputs`
        argument is not passed. :func:`torch.autograd.grad`
        is not supported. If ``use_reentrant=False`` is specified, checkpointing
        will work with :func:`torch.autograd.grad`.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): Omit stashing and restoring
            the RNG state during each checkpoint.
            Default: ``True``
        use_reentrant(bool, optional): Use checkpointing
            implementation that requires re-entrant autograd.
            If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
            implementation that does not require re-entrant autograd. This
            allows ``checkpoint`` to support additional functionality, such as
            working as expected with ``torch.autograd.grad`` and support for
            keyword arguments input into the checkpointed function. Note that future
            versions of PyTorch will default to ``use_reentrant=False``.
            Default: ``True``
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
    """
    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs and use_reentrant:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    if use_reentrant:
        return CheckpointFunction.apply(function, preserve, *args)
    else:
        return _checkpoint_without_reentrant(
            function,
            preserve,
            *args,
            **kwargs,
        )


def checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwargs):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model in various segments
    and checkpoint each segment. All segments except the last will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. The inputs of each checkpointed segment will be saved for
    re-running the segment in the backward pass.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
        is not supported.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    .. warning::
        Since PyTorch 1.4, it allows only one Tensor as the input and
        intermediate outputs, just like :class:`torch.nn.Sequential`.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        input: A Tensor that is input to :attr:`functions`
        preserve_rng_state(bool, optional): Omit stashing and restoring
            the RNG state during each checkpoint.
            Default: ``True``
        use_reentrant(bool, optional): Use checkpointing
            implementation that requires re-entrant autograd.
            If ``use_reentrant=False`` is specified, ``checkpoint`` will use an
            implementation that does not require re-entrant autograd. This
            allows ``checkpoint`` to support additional functionality, such as
            working as expected with ``torch.autograd.grad`` and support for
            keyword arguments input into the checkpointed function.
            Default: ``True``

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Example:
        >>> # xdoctest: +SKIP("stub")
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    # Hack for keyword-only parameter in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input)
            return input
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = list(functions.children())

    segment_size = len(functions) // segments
    # the last chunk has to be non-volatile
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        input = checkpoint(
            run_function(start, end, functions),
            input,
            use_reentrant=use_reentrant,
            preserve_rng_state=preserve
        )
    return run_function(end + 1, len(functions) - 1, functions)(input)


def _checkpoint_without_reentrant(function, preserve_rng_state=True, *args, **kwargs):
    """Checkpointing without re-entrant autograd

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): Omit stashing and restoring
            the RNG state during each checkpoint.
            Default: ``True``
        *args: Arguments to pass in to the given ``function``.
        **kwargs: Keyword arguments to pass into the given ``function``.
    """
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()

    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_cuda_in_fwd = False
        if torch.cuda._initialized:
            had_cuda_in_fwd = True
            fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)

    # Custom class to be able to take weak references
    class Holder():
        pass

    # The Holder object for each of the saved objects is saved directly on the
    # SavedVariable and is cleared when reset_data() is called on it. We MUST make
    # sure that this is the only object having an owning reference to ensure that
    # the Tensor stored in storage is deleted as soon as the corresponding SavedVariable
    # data is cleared.
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    weak_holder_list = []

    def pack(x):
        # TODO(varal7): Instead of returning an abstract object, we can return things metadata (such as
        # size, device, ...) to catch certain cases of nondeterministic behavior of the forward
        res = Holder()
        weak_holder_list.append(weakref.ref(res))
        return res

    def unpack(x):
        unpack_counter = 0
        if len(storage) == 0:
            def inner_pack(inner):
                nonlocal unpack_counter
                unpack_counter += 1
                # If the holder went out of scope, the SavedVariable is dead and so
                # the value will never be read from the storage. Skip filling it.
                if weak_holder_list[unpack_counter - 1]() is None:
                    return
                # Use detach here to ensure we don't keep the temporary autograd
                # graph created during the second forward
                storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
                return

            def inner_unpack(packed):
                raise RuntimeError("You are calling backwards on a tensor that is never exposed. Please open an issue.")

            # Stash the surrounding rng state, and mimic the state that was
            # present at this time during forward. Restore the surrounding state
            # when we're done.
            rng_devices = []
            if preserve_rng_state and had_cuda_in_fwd:
                rng_devices = fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
                if preserve_rng_state:
                    torch.set_rng_state(fwd_cpu_state)
                    if had_cuda_in_fwd:
                        set_device_states(fwd_gpu_devices, fwd_gpu_states)

                with torch.enable_grad(), \
                     torch.cuda.amp.autocast(**gpu_autocast_kwargs), \
                     torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
                     torch.autograd.graph.saved_tensors_hooks(inner_pack, inner_unpack):
                    _unused = function(*args, **kwargs)

        if x not in storage:
            raise RuntimeError(
                "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                " recomputation being triggered in between, this is not currently supported. Please"
                " open an issue with details on your use case so that we can prioritize adding this."
            )

        return storage[x]

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        output = function(*args, **kwargs)
        if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:
            # Cuda was not initialized before running the forward, so we didn't
            # stash the CUDA state.
            raise RuntimeError(
                "PyTorch's CUDA state was initialized in the forward pass "
                "of a Checkpoint, which is not allowed. Please open an issue "
                "if you need this feature.")

    return output
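A toy stdlib sketch of the trade-off `checkpoint` above exploits (not the real `torch.utils.checkpoint` API or autograd mechanics): instead of keeping an intermediate result alive, only the inputs and the function are kept, and the result is recomputed each time it is needed, trading extra compute for lower memory. The class and function names here are hypothetical.

```python
calls = {"count": 0}

def expensive(x):
    # stands in for a checkpointed segment of the network
    calls["count"] += 1
    return x * x + 1

class Checkpointed:
    """Keeps only the input; recomputes the output on demand."""
    def __init__(self, fn, x):
        self.fn, self.x = fn, x  # no intermediate result is stored

    def recompute(self):
        return self.fn(self.x)   # re-run, like the second forward in backward

ckpt = Checkpointed(expensive, 3)
assert calls["count"] == 0   # nothing computed yet, nothing stored
y = ckpt.recompute()         # first forward
z = ckpt.recompute()         # recomputation, as in the backward pass
assert y == z == 10
assert calls["count"] == 2   # paid twice in compute, zero activation memory
```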
models/csrc/__init__.py ADDED
File without changes
models/csrc/msmv_sampling/msmv_sampling.cpp ADDED
@@ -0,0 +1,369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ #include "msmv_sampling.h"
+
+ #define MAX_POINT 32
+
+ void ms_deformable_im2col_cuda_c2345(
+ const float* feat_c2,
+ const float* feat_c3,
+ const float* feat_c4,
+ const float* feat_c5,
+ const int h_c2, const int w_c2,
+ const int h_c3, const int w_c3,
+ const int h_c4, const int w_c4,
+ const int h_c5, const int w_c5,
+ const float* data_sampling_loc,
+ const float* data_attn_weight,
+ const int batch_size,
+ const int channels,
+ const int num_views,
+ const int num_query,
+ const int num_point,
+ float* data_col
+ );
+
+ void ms_deformable_im2col_cuda_c23456(
+ const float* feat_c2,
+ const float* feat_c3,
+ const float* feat_c4,
+ const float* feat_c5,
+ const float* feat_c6,
+ const int h_c2, const int w_c2,
+ const int h_c3, const int w_c3,
+ const int h_c4, const int w_c4,
+ const int h_c5, const int w_c5,
+ const int h_c6, const int w_c6,
+ const float* data_sampling_loc,
+ const float* data_attn_weight,
+ const int batch_size,
+ const int channels,
+ const int num_views,
+ const int num_query,
+ const int num_point,
+ float* data_col
+ );
+
+ void ms_deformable_col2im_cuda_c2345(
+ const float* grad_col,
+ const float* feat_c2,
+ const float* feat_c3,
+ const float* feat_c4,
+ const float* feat_c5,
+ const int h_c2, const int w_c2,
+ const int h_c3, const int w_c3,
+ const int h_c4, const int w_c4,
+ const int h_c5, const int w_c5,
+ const float* data_sampling_loc,
+ const float* data_attn_weight,
+ const int batch_size,
+ const int channels,
+ const int num_views,
+ const int num_query,
+ const int num_point,
+ float* grad_value_c2,
+ float* grad_value_c3,
+ float* grad_value_c4,
+ float* grad_value_c5,
+ float* grad_sampling_loc,
+ float* grad_attn_weight
+ );
+
+ void ms_deformable_col2im_cuda_c23456(
+ const float *grad_col,
+ const float *feat_c2,
+ const float *feat_c3,
+ const float *feat_c4,
+ const float *feat_c5,
+ const float *feat_c6,
+ const int h_c2, const int w_c2,
+ const int h_c3, const int w_c3,
+ const int h_c4, const int w_c4,
+ const int h_c5, const int w_c5,
+ const int h_c6, const int w_c6,
+ const float *data_sampling_loc,
+ const float *data_attn_weight,
+ const int batch_size,
+ const int channels,
+ const int num_views,
+ const int num_query,
+ const int num_point,
+ float *grad_value_c2,
+ float *grad_value_c3,
+ float *grad_value_c4,
+ float *grad_value_c5,
+ float *grad_value_c6,
+ float *grad_sampling_loc,
+ float *grad_attn_weight
+ );
+
+ at::Tensor ms_deform_attn_cuda_c2345_forward(
+ const at::Tensor& feat_c2, // [B, N, H, W, C]
+ const at::Tensor& feat_c3, // [B, N, H, W, C]
+ const at::Tensor& feat_c4, // [B, N, H, W, C]
+ const at::Tensor& feat_c5, // [B, N, H, W, C]
+ const at::Tensor& sampling_loc, // [B, Q, P, 3]
+ const at::Tensor& attn_weight // [B, Q, P, 4]
+ ) {
+ AT_ASSERTM(feat_c2.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(feat_c3.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(feat_c4.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(feat_c5.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+ AT_ASSERTM(feat_c2.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(feat_c3.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(feat_c4.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(feat_c5.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
+
+ const int batch_size = feat_c2.size(0);
+ const int num_views = feat_c2.size(1);
+ const int channels = feat_c2.size(4);
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(2);
+ AT_ASSERTM(num_point <= MAX_POINT, "num_point exceeds MAX_POINT");
+
+ const int h_c2 = feat_c2.size(2);
+ const int w_c2 = feat_c2.size(3);
+ const int h_c3 = feat_c3.size(2);
+ const int w_c3 = feat_c3.size(3);
+ const int h_c4 = feat_c4.size(2);
+ const int w_c4 = feat_c4.size(3);
+ const int h_c5 = feat_c5.size(2);
+ const int w_c5 = feat_c5.size(3);
+
+ auto output = at::zeros({ batch_size, num_query, channels, num_point }, feat_c2.options());
+ ms_deformable_im2col_cuda_c2345(
+ feat_c2.data_ptr<float>(),
+ feat_c3.data_ptr<float>(),
+ feat_c4.data_ptr<float>(),
+ feat_c5.data_ptr<float>(),
+ h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5,
+ sampling_loc.data_ptr<float>(),
+ attn_weight.data_ptr<float>(),
+ batch_size, channels, num_views, num_query, num_point,
+ output.data_ptr<float>()
+ );
+
+ return output;
+ }
+
+ at::Tensor ms_deform_attn_cuda_c23456_forward(
+ const at::Tensor& feat_c2, // [B, N, H, W, C]
+ const at::Tensor& feat_c3, // [B, N, H, W, C]
+ const at::Tensor& feat_c4, // [B, N, H, W, C]
+ const at::Tensor& feat_c5, // [B, N, H, W, C]
+ const at::Tensor& feat_c6, // [B, N, H, W, C]
+ const at::Tensor& sampling_loc, // [B, Q, P, 3]
+ const at::Tensor& attn_weight // [B, Q, P, 5]
+ ) {
+ AT_ASSERTM(feat_c2.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(feat_c3.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(feat_c4.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(feat_c5.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(feat_c6.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+ AT_ASSERTM(feat_c2.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(feat_c3.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(feat_c4.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(feat_c5.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(feat_c6.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
+
+ const int batch_size = feat_c2.size(0);
+ const int num_views = feat_c2.size(1);
+ const int channels = feat_c2.size(4);
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(2);
+ AT_ASSERTM(num_point <= MAX_POINT, "num_point exceeds MAX_POINT");
+
+ const int h_c2 = feat_c2.size(2);
+ const int w_c2 = feat_c2.size(3);
+ const int h_c3 = feat_c3.size(2);
+ const int w_c3 = feat_c3.size(3);
+ const int h_c4 = feat_c4.size(2);
+ const int w_c4 = feat_c4.size(3);
+ const int h_c5 = feat_c5.size(2);
+ const int w_c5 = feat_c5.size(3);
+ const int h_c6 = feat_c6.size(2);
+ const int w_c6 = feat_c6.size(3);
+
+ auto output = at::zeros({ batch_size, num_query, channels, num_point }, feat_c2.options());
+ ms_deformable_im2col_cuda_c23456(
+ feat_c2.data_ptr<float>(),
+ feat_c3.data_ptr<float>(),
+ feat_c4.data_ptr<float>(),
+ feat_c5.data_ptr<float>(),
+ feat_c6.data_ptr<float>(),
+ h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6,
+ sampling_loc.data_ptr<float>(),
+ attn_weight.data_ptr<float>(),
+ batch_size, channels, num_views, num_query, num_point,
+ output.data_ptr<float>()
+ );
+
+ return output;
+ }
+
+ std::vector<at::Tensor> ms_deform_attn_cuda_c2345_backward(
+ const at::Tensor& grad_output,
+ const at::Tensor& feat_c2, // [B, N, H, W, C]
+ const at::Tensor& feat_c3, // [B, N, H, W, C]
+ const at::Tensor& feat_c4, // [B, N, H, W, C]
+ const at::Tensor& feat_c5, // [B, N, H, W, C]
+ const at::Tensor& sampling_loc, // [B, Q, P, 3]
+ const at::Tensor& attn_weight // [B, Q, P, 4]
+ ) {
+ AT_ASSERTM(feat_c2.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(feat_c3.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(feat_c4.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(feat_c5.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+ AT_ASSERTM(feat_c2.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(feat_c3.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(feat_c4.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(feat_c5.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
+ AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");
+
+ const int batch_size = feat_c2.size(0);
+ const int num_views = feat_c2.size(1);
+ const int channels = feat_c2.size(4);
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(2);
+ AT_ASSERTM(num_point <= MAX_POINT, "num_point exceeds MAX_POINT");
+
+ auto grad_value_c2 = at::zeros_like(feat_c2);
+ auto grad_value_c3 = at::zeros_like(feat_c3);
+ auto grad_value_c4 = at::zeros_like(feat_c4);
+ auto grad_value_c5 = at::zeros_like(feat_c5);
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
+ auto grad_attn_weight = at::zeros_like(attn_weight);
+
+ const int h_c2 = feat_c2.size(2);
+ const int w_c2 = feat_c2.size(3);
+ const int h_c3 = feat_c3.size(2);
+ const int w_c3 = feat_c3.size(3);
+ const int h_c4 = feat_c4.size(2);
+ const int w_c4 = feat_c4.size(3);
+ const int h_c5 = feat_c5.size(2);
+ const int w_c5 = feat_c5.size(3);
+
+ ms_deformable_col2im_cuda_c2345(
+ grad_output.data_ptr<float>(),
+ feat_c2.data_ptr<float>(),
+ feat_c3.data_ptr<float>(),
+ feat_c4.data_ptr<float>(),
+ feat_c5.data_ptr<float>(),
+ h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5,
+ sampling_loc.data_ptr<float>(),
+ attn_weight.data_ptr<float>(),
+ batch_size, channels, num_views, num_query, num_point,
+ grad_value_c2.data_ptr<float>(),
+ grad_value_c3.data_ptr<float>(),
+ grad_value_c4.data_ptr<float>(),
+ grad_value_c5.data_ptr<float>(),
+ grad_sampling_loc.data_ptr<float>(),
+ grad_attn_weight.data_ptr<float>()
+ );
+
+ return {
+ grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_sampling_loc, grad_attn_weight
+ };
+ }
+
+ std::vector<at::Tensor> ms_deform_attn_cuda_c23456_backward(
+ const at::Tensor& grad_output,
+ const at::Tensor& feat_c2, // [B, N, H, W, C]
+ const at::Tensor& feat_c3, // [B, N, H, W, C]
+ const at::Tensor& feat_c4, // [B, N, H, W, C]
+ const at::Tensor& feat_c5, // [B, N, H, W, C]
+ const at::Tensor& feat_c6, // [B, N, H, W, C]
+ const at::Tensor& sampling_loc, // [B, Q, P, 3]
+ const at::Tensor& attn_weight // [B, Q, P, 5]
+ ) {
+ AT_ASSERTM(feat_c2.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(feat_c3.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(feat_c4.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(feat_c5.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(feat_c6.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+ AT_ASSERTM(feat_c2.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(feat_c3.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(feat_c4.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(feat_c5.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(feat_c6.is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.is_cuda(), "attn_weight must be a CUDA tensor");
+ AT_ASSERTM(grad_output.is_cuda(), "grad_output must be a CUDA tensor");
+
+ const int batch_size = feat_c2.size(0);
+ const int num_views = feat_c2.size(1);
+ const int channels = feat_c2.size(4);
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(2);
+ AT_ASSERTM(num_point <= MAX_POINT, "num_point exceeds MAX_POINT");
+
+ auto grad_value_c2 = at::zeros_like(feat_c2);
+ auto grad_value_c3 = at::zeros_like(feat_c3);
+ auto grad_value_c4 = at::zeros_like(feat_c4);
+ auto grad_value_c5 = at::zeros_like(feat_c5);
+ auto grad_value_c6 = at::zeros_like(feat_c6);
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
+ auto grad_attn_weight = at::zeros_like(attn_weight);
+
+ const int h_c2 = feat_c2.size(2);
+ const int w_c2 = feat_c2.size(3);
+ const int h_c3 = feat_c3.size(2);
+ const int w_c3 = feat_c3.size(3);
+ const int h_c4 = feat_c4.size(2);
+ const int w_c4 = feat_c4.size(3);
+ const int h_c5 = feat_c5.size(2);
+ const int w_c5 = feat_c5.size(3);
+ const int h_c6 = feat_c6.size(2);
+ const int w_c6 = feat_c6.size(3);
+
+ ms_deformable_col2im_cuda_c23456(
+ grad_output.data_ptr<float>(),
+ feat_c2.data_ptr<float>(),
+ feat_c3.data_ptr<float>(),
+ feat_c4.data_ptr<float>(),
+ feat_c5.data_ptr<float>(),
+ feat_c6.data_ptr<float>(),
+ h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6,
+ sampling_loc.data_ptr<float>(),
+ attn_weight.data_ptr<float>(),
+ batch_size, channels, num_views, num_query, num_point,
+ grad_value_c2.data_ptr<float>(),
+ grad_value_c3.data_ptr<float>(),
+ grad_value_c4.data_ptr<float>(),
+ grad_value_c5.data_ptr<float>(),
+ grad_value_c6.data_ptr<float>(),
+ grad_sampling_loc.data_ptr<float>(),
+ grad_attn_weight.data_ptr<float>()
+ );
+
+ return {
+ grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6, grad_sampling_loc, grad_attn_weight
+ };
+ }
+
+ #ifdef TORCH_EXTENSION_NAME
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("_ms_deform_attn_cuda_c2345_forward", &ms_deform_attn_cuda_c2345_forward, "pass");
+ m.def("_ms_deform_attn_cuda_c2345_backward", &ms_deform_attn_cuda_c2345_backward, "pass");
+ m.def("_ms_deform_attn_cuda_c23456_forward", &ms_deform_attn_cuda_c23456_forward, "pass");
+ m.def("_ms_deform_attn_cuda_c23456_backward", &ms_deform_attn_cuda_c23456_backward, "pass");
+ }
+ #endif
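
The forward entry points above reduce, per query point and pyramid level, to one bilinear lookup at a normalized location (with the kernels' `align_corners=True` convention, i.e. pixel = normalized * (size - 1)), scaled by that level's attention weight and summed over levels. As a sanity check, here is a minimal pure-Python sketch of that per-point computation; `bilinear_sample` and `sample_point` are illustrative names, not part of this extension:

```python
import math

def bilinear_sample(feat, h_norm, w_norm):
    """Bilinearly sample a 2-D grid feat[H][W] at a normalized location in [0, 1].

    Mirrors the CUDA kernels' align_corners=True mapping:
    pixel coordinate = normalized coordinate * (size - 1).
    Taps falling outside the grid contribute zero, as in the CUDA code.
    """
    H, W = len(feat), len(feat[0])
    h = h_norm * (H - 1)
    w = w_norm * (W - 1)
    h0, w0 = int(math.floor(h)), int(math.floor(w))
    lh, lw = h - h0, w - w0
    out = 0.0
    for dh, dw, wt in ((0, 0, (1 - lh) * (1 - lw)),
                       (0, 1, (1 - lh) * lw),
                       (1, 0, lh * (1 - lw)),
                       (1, 1, lh * lw)):
        hi, wi = h0 + dh, w0 + dw
        if 0 <= hi < H and 0 <= wi < W:
            out += wt * feat[hi][wi]
    return out

def sample_point(feats, attn_weights, h_norm, w_norm):
    """One sampling point: sum over pyramid levels of attn_weight * bilinear sample."""
    return sum(w * bilinear_sample(f, h_norm, w_norm)
               for f, w in zip(feats, attn_weights))
```

For example, sampling `[[0, 1], [2, 3]]` at the grid center (0.5, 0.5) averages all four taps, giving 1.5.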
models/csrc/msmv_sampling/msmv_sampling.h ADDED
@@ -0,0 +1,43 @@
+ #pragma once
+
+ #include <torch/extension.h>
+
+ at::Tensor ms_deform_attn_cuda_c2345_forward(
+ const at::Tensor& feat_c2, // [B, N, H, W, C]
+ const at::Tensor& feat_c3, // [B, N, H, W, C]
+ const at::Tensor& feat_c4, // [B, N, H, W, C]
+ const at::Tensor& feat_c5, // [B, N, H, W, C]
+ const at::Tensor& sampling_loc, // [B, Q, P, 3]
+ const at::Tensor& attn_weight // [B, Q, P, 4]
+ );
+
+ std::vector<at::Tensor> ms_deform_attn_cuda_c2345_backward(
+ const at::Tensor& grad_output,
+ const at::Tensor& feat_c2, // [B, N, H, W, C]
+ const at::Tensor& feat_c3, // [B, N, H, W, C]
+ const at::Tensor& feat_c4, // [B, N, H, W, C]
+ const at::Tensor& feat_c5, // [B, N, H, W, C]
+ const at::Tensor& sampling_loc, // [B, Q, P, 3]
+ const at::Tensor& attn_weight // [B, Q, P, 4]
+ );
+
+ at::Tensor ms_deform_attn_cuda_c23456_forward(
+ const at::Tensor& feat_c2, // [B, N, H, W, C]
+ const at::Tensor& feat_c3, // [B, N, H, W, C]
+ const at::Tensor& feat_c4, // [B, N, H, W, C]
+ const at::Tensor& feat_c5, // [B, N, H, W, C]
+ const at::Tensor& feat_c6, // [B, N, H, W, C]
+ const at::Tensor& sampling_loc, // [B, Q, P, 3]
+ const at::Tensor& attn_weight // [B, Q, P, 5]
+ );
+
+ std::vector<at::Tensor> ms_deform_attn_cuda_c23456_backward(
+ const at::Tensor& grad_output,
+ const at::Tensor& feat_c2, // [B, N, H, W, C]
+ const at::Tensor& feat_c3, // [B, N, H, W, C]
+ const at::Tensor& feat_c4, // [B, N, H, W, C]
+ const at::Tensor& feat_c5, // [B, N, H, W, C]
+ const at::Tensor& feat_c6, // [B, N, H, W, C]
+ const at::Tensor& sampling_loc, // [B, Q, P, 3]
+ const at::Tensor& attn_weight // [B, Q, P, 5]
+ );
models/csrc/msmv_sampling/msmv_sampling_backward.cu ADDED
@@ -0,0 +1,448 @@
+ /*!
+ * Modified from Deformable DETR
+ */
+
+ #include <cstdio>
+ #include <iostream>
+ #include <algorithm>
+ #include <cstring>
+ #include <cuda_runtime.h>
+ #include <device_launch_parameters.h>
+ #include <torch/extension.h>
+ #include <ATen/ATen.h>
+ #include <ATen/cuda/CUDAContext.h>
+ #include <THC/THCAtomics.cuh>
+
+ #define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+ i < (n); \
+ i += blockDim.x * gridDim.x)
+
+ #define CUDA_NUM_THREADS 512
+ #define MAX_POINT 32
+
+ inline int GET_BLOCKS(const int N, const int num_threads)
+ {
+ return (N + num_threads - 1) / num_threads;
+ }
+
+ __device__ void ms_deform_attn_col2im_bilinear(const float *&bottom_data,
+ const int &height, const int &width, const int &channels,
+ const float &h, const float &w, const int &c,
+ const float &top_grad,
+ const float &attn_weight,
+ const float *&grad_value,
+ float *&grad_sampling_loc,
+ float *&grad_attn_weight)
+ {
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const float lh = h - h_low;
+ const float lw = w - w_low;
+ const float hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+
+ const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const float top_grad_value = top_grad * attn_weight;
+ float grad_h_weight = 0, grad_w_weight = 0;
+
+ float *grad_ptr;
+
+ float v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + c;
+ grad_ptr = const_cast<float *>(grad_value + ptr1);
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_ptr, w1 * top_grad_value);
+ }
+ float v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + c;
+ grad_ptr = const_cast<float *>(grad_value + ptr2);
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_ptr, w2 * top_grad_value);
+ }
+ float v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + c;
+ grad_ptr = const_cast<float *>(grad_value + ptr3);
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_ptr, w3 * top_grad_value);
+ }
+ float v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + c;
+ grad_ptr = const_cast<float *>(grad_value + ptr4);
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_ptr, w4 * top_grad_value);
+ }
+
+ const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ atomicAdd(grad_attn_weight, top_grad * val);
+ atomicAdd(grad_sampling_loc, (width - 1) * grad_w_weight * top_grad_value);
+ atomicAdd(grad_sampling_loc + 1, (height - 1) * grad_h_weight * top_grad_value);
+ }
+
+ // Global-memory variant: gradients are accumulated directly with atomicAdd
+ __global__ void ms_deformable_col2im_gpu_kernel_gm_c2345(
+ const float *grad_col,
+ const float *feat_c2,
+ const float *feat_c3,
+ const float *feat_c4,
+ const float *feat_c5,
+ const int h_c2, const int w_c2,
+ const int h_c3, const int w_c3,
+ const int h_c4, const int w_c4,
+ const int h_c5, const int w_c5,
+ const float *data_sampling_loc,
+ const float *data_attn_weight,
+ const int batch_size,
+ const int channels,
+ const int num_views,
+ const int num_query,
+ const int num_point,
+ float *grad_value_c2,
+ float *grad_value_c3,
+ float *grad_value_c4,
+ float *grad_value_c5,
+ float *grad_sampling_loc,
+ float *grad_attn_weight)
+ {
+ CUDA_KERNEL_LOOP(index, batch_size * num_query * channels * num_point)
+ { // index ranges over batch_size x num_query x channels x num_point
+
+ int _temp = index;
+ const int p_col = _temp % num_point;
+ _temp /= num_point;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const float top_grad = grad_col[index];
+
+ // Sampling location in range [0, 1]
+ int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;
+ const float loc_w = data_sampling_loc[data_loc_ptr];
+ const float loc_h = data_sampling_loc[data_loc_ptr + 1];
+ const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));
+
+ // Attn weights
+ int data_weight_ptr = sampling_index * num_point * 4 + p_col * 4;
+
+ const float weight_c2 = data_attn_weight[data_weight_ptr];
+ const float weight_c3 = data_attn_weight[data_weight_ptr + 1];
+ const float weight_c4 = data_attn_weight[data_weight_ptr + 2];
+ const float weight_c5 = data_attn_weight[data_weight_ptr + 3];
+
+ // const float h_im = loc_h * spatial_h - 0.5; // align_corners = False
+ // const float w_im = loc_w * spatial_w - 0.5;
+
+ // C2 Feature
+ float h_im = loc_h * (h_c2 - 1); // align_corners = True
+ float w_im = loc_w * (w_c2 - 1);
+
+ float *grad_location_ptr = grad_sampling_loc + data_loc_ptr;
+ float *grad_weights_ptr = grad_attn_weight + data_weight_ptr;
+
+ if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2)
+ {
+ const float *feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
+ const float *grad_c2_ptr = grad_value_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
+ ms_deform_attn_col2im_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col,
+ top_grad, weight_c2,
+ grad_c2_ptr, grad_location_ptr, grad_weights_ptr);
+ }
+
+ grad_weights_ptr += 1;
+
+ // C3 Feature
+ h_im = loc_h * (h_c3 - 1); // align_corners = True
+ w_im = loc_w * (w_c3 - 1);
+
+ if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3)
+ {
+ const float *feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
+ const float *grad_c3_ptr = grad_value_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
+ ms_deform_attn_col2im_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col,
+ top_grad, weight_c3,
+ grad_c3_ptr, grad_location_ptr, grad_weights_ptr);
+ }
+
+ grad_weights_ptr += 1;
+
+ // C4 Feature
+ h_im = loc_h * (h_c4 - 1); // align_corners = True
+ w_im = loc_w * (w_c4 - 1);
+
+ if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4)
+ {
+ const float *feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
+ const float *grad_c4_ptr = grad_value_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
+ ms_deform_attn_col2im_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col,
+ top_grad, weight_c4,
+ grad_c4_ptr, grad_location_ptr, grad_weights_ptr);
+ }
+
+ grad_weights_ptr += 1;
+
+ // C5 Feature
+ h_im = loc_h * (h_c5 - 1); // align_corners = True
+ w_im = loc_w * (w_c5 - 1);
+
+ if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5)
+ {
+ const float *feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
+ const float *grad_c5_ptr = grad_value_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
+ ms_deform_attn_col2im_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col,
+ top_grad, weight_c5,
+ grad_c5_ptr, grad_location_ptr, grad_weights_ptr);
+ }
+ }
+ }
+
+ __global__ void ms_deformable_col2im_gpu_kernel_gm_c23456(
+ const float *grad_col,
+ const float *feat_c2,
+ const float *feat_c3,
+ const float *feat_c4,
+ const float *feat_c5,
+ const float *feat_c6,
+ const int h_c2, const int w_c2,
+ const int h_c3, const int w_c3,
+ const int h_c4, const int w_c4,
+ const int h_c5, const int w_c5,
+ const int h_c6, const int w_c6,
+ const float *data_sampling_loc,
+ const float *data_attn_weight,
+ const int batch_size,
+ const int channels,
+ const int num_views,
+ const int num_query,
+ const int num_point,
+ float *grad_value_c2,
+ float *grad_value_c3,
+ float *grad_value_c4,
+ float *grad_value_c5,
+ float *grad_value_c6,
+ float *grad_sampling_loc,
+ float *grad_attn_weight)
+ {
+ CUDA_KERNEL_LOOP(index, batch_size * num_query * channels * num_point)
+ { // index ranges over batch_size x num_query x channels x num_point
+
+ int _temp = index;
+ const int p_col = _temp % num_point;
+ _temp /= num_point;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const float top_grad = grad_col[index];
+
+ // Sampling location in range [0, 1]
+ int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;
+ const float loc_w = data_sampling_loc[data_loc_ptr];
+ const float loc_h = data_sampling_loc[data_loc_ptr + 1];
+ const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));
+
+ // Attn weights
+ int data_weight_ptr = sampling_index * num_point * 5 + p_col * 5;
+
+ const float weight_c2 = data_attn_weight[data_weight_ptr];
+ const float weight_c3 = data_attn_weight[data_weight_ptr + 1];
+ const float weight_c4 = data_attn_weight[data_weight_ptr + 2];
+ const float weight_c5 = data_attn_weight[data_weight_ptr + 3];
+ const float weight_c6 = data_attn_weight[data_weight_ptr + 4];
+
+ // const float h_im = loc_h * spatial_h - 0.5; // align_corners = False
+ // const float w_im = loc_w * spatial_w - 0.5;
+
+ // C2 Feature
+ float h_im = loc_h * (h_c2 - 1); // align_corners = True
+ float w_im = loc_w * (w_c2 - 1);
+
+ float *grad_location_ptr = grad_sampling_loc + data_loc_ptr;
+ float *grad_weights_ptr = grad_attn_weight + data_weight_ptr;
+
+ if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2)
+ {
+ const float *feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
+ const float *grad_c2_ptr = grad_value_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
+ ms_deform_attn_col2im_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col,
+ top_grad, weight_c2,
+ grad_c2_ptr, grad_location_ptr, grad_weights_ptr);
+ }
+
+ grad_weights_ptr += 1;
+
+ // C3 Feature
+ h_im = loc_h * (h_c3 - 1); // align_corners = True
+ w_im = loc_w * (w_c3 - 1);
+
+ if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3)
+ {
+ const float *feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
+ const float *grad_c3_ptr = grad_value_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
+ ms_deform_attn_col2im_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col,
+ top_grad, weight_c3,
+ grad_c3_ptr, grad_location_ptr, grad_weights_ptr);
+ }
+
+ grad_weights_ptr += 1;
+
+ // C4 Feature
+ h_im = loc_h * (h_c4 - 1); // align_corners = True
+ w_im = loc_w * (w_c4 - 1);
+
+ if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4)
+ {
+ const float *feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
+ const float *grad_c4_ptr = grad_value_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
+ ms_deform_attn_col2im_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col,
+ top_grad, weight_c4,
+ grad_c4_ptr, grad_location_ptr, grad_weights_ptr);
+ }
+
+ grad_weights_ptr += 1;
+
+ // C5 Feature
+ h_im = loc_h * (h_c5 - 1); // align_corners = True
+ w_im = loc_w * (w_c5 - 1);
+
+ if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5)
+ {
+ const float *feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
+ const float *grad_c5_ptr = grad_value_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
+ ms_deform_attn_col2im_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col,
+ top_grad, weight_c5,
+ grad_c5_ptr, grad_location_ptr, grad_weights_ptr);
+ }
+
+ grad_weights_ptr += 1;
+
+ // C6 Feature
+ h_im = loc_h * (h_c6 - 1); // align_corners = True
+ w_im = loc_w * (w_c6 - 1);
+
+ if (h_im > -1 && w_im > -1 && h_im < h_c6 && w_im < w_c6)
+ {
+ const float *feat_c6_ptr = feat_c6 + b_col * num_views * h_c6 * w_c6 * channels + loc_v * h_c6 * w_c6 * channels;
+ const float *grad_c6_ptr = grad_value_c6 + b_col * num_views * h_c6 * w_c6 * channels + loc_v * h_c6 * w_c6 * channels;
+ ms_deform_attn_col2im_bilinear(feat_c6_ptr, h_c6, w_c6, channels, h_im, w_im, c_col,
+ top_grad, weight_c6,
+ grad_c6_ptr, grad_location_ptr, grad_weights_ptr);
+ }
+ }
+ }
+
+ void ms_deformable_col2im_cuda_c2345(
+ const float *grad_col,
+ const float *feat_c2,
+ const float *feat_c3,
+ const float *feat_c4,
+ const float *feat_c5,
+ const int h_c2, const int w_c2,
+ const int h_c3, const int w_c3,
+ const int h_c4, const int w_c4,
+ const int h_c5, const int w_c5,
+ const float *data_sampling_loc,
+ const float *data_attn_weight,
+ const int batch_size,
+ const int channels,
+ const int num_views,
+ const int num_query,
+ const int num_point,
+ float *grad_value_c2,
+ float *grad_value_c3,
+ float *grad_value_c4,
+ float *grad_value_c5,
+ float *grad_sampling_loc,
+ float *grad_attn_weight)
+ {
+ const int num_kernels = batch_size * num_query * channels * num_point;
+ const int num_threads = (channels * num_point > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels * num_point;
+
+ ms_deformable_col2im_gpu_kernel_gm_c2345 <<<GET_BLOCKS(num_kernels, num_threads), num_threads>>>(
+ grad_col, feat_c2, feat_c3, feat_c4, feat_c5,
+ h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5,
+ data_sampling_loc, data_attn_weight,
+ batch_size, channels, num_views, num_query, num_point,
+ grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5,
+ grad_sampling_loc, grad_attn_weight);
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_col2im_cuda_c2345: %s\n", cudaGetErrorString(err));
+ }
+ }
+
+ void ms_deformable_col2im_cuda_c23456(
+ const float *grad_col,
+ const float *feat_c2,
+ const float *feat_c3,
+ const float *feat_c4,
+ const float *feat_c5,
+ const float *feat_c6,
+ const int h_c2, const int w_c2,
+ const int h_c3, const int w_c3,
+ const int h_c4, const int w_c4,
+ const int h_c5, const int w_c5,
+ const int h_c6, const int w_c6,
+ const float *data_sampling_loc,
+ const float *data_attn_weight,
+ const int batch_size,
+ const int channels,
+ const int num_views,
+ const int num_query,
+ const int num_point,
+ float *grad_value_c2,
+ float *grad_value_c3,
+ float *grad_value_c4,
+ float *grad_value_c5,
+ float *grad_value_c6,
+ float *grad_sampling_loc,
+ float *grad_attn_weight)
+ {
+ const int num_kernels = batch_size * num_query * channels * num_point;
+ const int num_threads = (channels * num_point > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels * num_point;
+
+ ms_deformable_col2im_gpu_kernel_gm_c23456 <<<GET_BLOCKS(num_kernels, num_threads), num_threads>>>(
436
+ grad_col, feat_c2, feat_c3, feat_c4, feat_c5, feat_c6,
437
+ h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6,
438
+ data_sampling_loc, data_attn_weight,
439
+ batch_size, channels, num_views, num_query, num_point,
440
+ grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6,
441
+ grad_sampling_loc, grad_attn_weight);
442
+
443
+ cudaError_t err = cudaGetLastError();
444
+ if (err != cudaSuccess)
445
+ {
446
+ printf("error in ms_deformable_col2im_cuda_c23456: %s\n", cudaGetErrorString(err));
447
+ }
448
+ }
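The backward kernel above scatters the incoming gradient onto the four bilinear corner pixels with the same weights the forward pass used. A pure-Python sketch of that weighting (names are illustrative, not the CUDA API; the real kernel uses `atomicAdd` into `grad_value_*`):

```python
import math

def bilinear_corner_weights(h, w):
    """Corner weights (w1..w4) for a sample at fractional pixel (h, w)."""
    lh, lw = h - math.floor(h), w - math.floor(w)
    hh, hw = 1 - lh, 1 - lw
    return hh * hw, hh * lw, lh * hw, lh * lw

def scatter_value_grad(top_grad, attn_weight, h, w):
    """Gradient each of the four corner pixels receives in col2im."""
    return [wi * top_grad * attn_weight for wi in bilinear_corner_weights(h, w)]

# The corner weights always sum to 1, so the scattered gradients
# conserve top_grad * attn_weight exactly.
grads = scatter_value_grad(1.0, 0.5, h=2.25, w=3.75)
```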
models/csrc/msmv_sampling/msmv_sampling_forward.cu ADDED
@@ -0,0 +1,333 @@
+ /*!
+  * Modified from Deformable DETR
+  */
+
+ #include <cstdio>
+ #include <algorithm>
+ #include <cstring>
+ #include <cuda_runtime.h>
+ #include <device_launch_parameters.h>
+ #include <torch/extension.h>
+ #include <ATen/ATen.h>
+ #include <ATen/cuda/CUDAContext.h>
+ #include <THC/THCAtomics.cuh>
+
+ #define CUDA_KERNEL_LOOP(i, n) \
+     for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+          i < (n); \
+          i += blockDim.x * gridDim.x)
+
+ #define CUDA_NUM_THREADS 512
+ #define MAX_POINT 32
+
+ inline int GET_BLOCKS(const int N, const int num_threads) {
+     return (N + num_threads - 1) / num_threads;
+ }
+
+ __device__ float ms_deform_attn_im2col_bilinear(
+     const float*& bottom_data,
+     const int& height, const int& width, const int& channels,
+     const float& h, const float& w, const int& c) {
+
+     const int h_low = floor(h);
+     const int w_low = floor(w);
+     const int h_high = h_low + 1;
+     const int w_high = w_low + 1;
+
+     const float lh = h - h_low;
+     const float lw = w - w_low;
+     const float hh = 1 - lh, hw = 1 - lw;
+
+     const int w_stride = channels;
+     const int h_stride = width * w_stride;
+     const int h_low_ptr_offset = h_low * h_stride;
+     const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+     const int w_low_ptr_offset = w_low * w_stride;
+     const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+
+     float v1 = 0;
+     if (h_low >= 0 && w_low >= 0) {
+         const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + c;
+         v1 = bottom_data[ptr1];
+     }
+     float v2 = 0;
+     if (h_low >= 0 && w_high <= width - 1) {
+         const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + c;
+         v2 = bottom_data[ptr2];
+     }
+     float v3 = 0;
+     if (h_high <= height - 1 && w_low >= 0) {
+         const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + c;
+         v3 = bottom_data[ptr3];
+     }
+     float v4 = 0;
+     if (h_high <= height - 1 && w_high <= width - 1) {
+         const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + c;
+         v4 = bottom_data[ptr4];
+     }
+
+     const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+     const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+
+     return val;
+ }
+
+ __global__ void ms_deformable_im2col_gpu_kernel_c2345(
+     const float* feat_c2,
+     const float* feat_c3,
+     const float* feat_c4,
+     const float* feat_c5,
+     const int h_c2, const int w_c2,
+     const int h_c3, const int w_c3,
+     const int h_c4, const int w_c4,
+     const int h_c5, const int w_c5,
+     const float* data_sampling_loc,
+     const float* data_attn_weight,
+     const int batch_size,
+     const int channels,
+     const int num_views,
+     const int num_query,
+     const int num_point,
+     float* data_col) {
+
+     float res[MAX_POINT];
+
+     CUDA_KERNEL_LOOP(index, batch_size * num_query * channels) { // n: bs x query x channels
+         int _temp = index;
+         const int c_col = _temp % channels;
+         _temp /= channels;
+         const int sampling_index = _temp;
+         _temp /= num_query;
+         const int b_col = _temp;
+
+         for (int p_col = 0; p_col < num_point; ++p_col) { res[p_col] = 0; }
+
+         for (int p_col = 0; p_col < num_point; ++p_col) {
+             // Sampling location in range [0, 1]
+             int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;
+             const float loc_w = data_sampling_loc[data_loc_ptr];
+             const float loc_h = data_sampling_loc[data_loc_ptr + 1];
+             const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));
+
+             // Attn weights
+             int data_weight_ptr = sampling_index * num_point * 4 + p_col * 4;
+             const float weight_c2 = data_attn_weight[data_weight_ptr];
+             const float weight_c3 = data_attn_weight[data_weight_ptr + 1];
+             const float weight_c4 = data_attn_weight[data_weight_ptr + 2];
+             const float weight_c5 = data_attn_weight[data_weight_ptr + 3];
+
+             // const float h_im = loc_h * spatial_h - 0.5; // align_corners = False
+             // const float w_im = loc_w * spatial_w - 0.5;
+
+             // C2 Feature
+             float h_im = loc_h * (h_c2 - 1); // align_corners = True
+             float w_im = loc_w * (w_c2 - 1);
+
+             if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2) {
+                 const float* feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
+                 res[p_col] += ms_deform_attn_im2col_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col) * weight_c2;
+             }
+
+             // C3 Feature
+             h_im = loc_h * (h_c3 - 1); // align_corners = True
+             w_im = loc_w * (w_c3 - 1);
+
+             if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3) {
+                 const float* feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
+                 res[p_col] += ms_deform_attn_im2col_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col) * weight_c3;
+             }
+
+             // C4 Feature
+             h_im = loc_h * (h_c4 - 1); // align_corners = True
+             w_im = loc_w * (w_c4 - 1);
+
+             if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4) {
+                 const float* feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
+                 res[p_col] += ms_deform_attn_im2col_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col) * weight_c4;
+             }
+
+             // C5 Feature
+             h_im = loc_h * (h_c5 - 1); // align_corners = True
+             w_im = loc_w * (w_c5 - 1);
+
+             if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5) {
+                 const float* feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
+                 res[p_col] += ms_deform_attn_im2col_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col) * weight_c5;
+             }
+         }
+
+         for (int p_col = 0; p_col < num_point; ++p_col) {
+             float* data_col_ptr = data_col + index * num_point + p_col;
+             *data_col_ptr = res[p_col];
+         }
+     }
+ }
+
+ __global__ void ms_deformable_im2col_gpu_kernel_c23456(
+     const float* feat_c2,
+     const float* feat_c3,
+     const float* feat_c4,
+     const float* feat_c5,
+     const float* feat_c6,
+     const int h_c2, const int w_c2,
+     const int h_c3, const int w_c3,
+     const int h_c4, const int w_c4,
+     const int h_c5, const int w_c5,
+     const int h_c6, const int w_c6,
+     const float* data_sampling_loc,
+     const float* data_attn_weight,
+     const int batch_size,
+     const int channels,
+     const int num_views,
+     const int num_query,
+     const int num_point,
+     float* data_col) {
+
+     float res[MAX_POINT];
+
+     CUDA_KERNEL_LOOP(index, batch_size * num_query * channels) { // n: bs x query x channels
+         int _temp = index;
+         const int c_col = _temp % channels;
+         _temp /= channels;
+         const int sampling_index = _temp;
+         _temp /= num_query;
+         const int b_col = _temp;
+
+         for (int p_col = 0; p_col < num_point; ++p_col) { res[p_col] = 0; }
+
+         for (int p_col = 0; p_col < num_point; ++p_col) {
+             // Sampling location in range [0, 1]
+             int data_loc_ptr = sampling_index * num_point * 3 + p_col * 3;
+             const float loc_w = data_sampling_loc[data_loc_ptr];
+             const float loc_h = data_sampling_loc[data_loc_ptr + 1];
+             const int loc_v = round(data_sampling_loc[data_loc_ptr + 2] * (num_views - 1));
+
+             // Attn weights
+             int data_weight_ptr = sampling_index * num_point * 5 + p_col * 5;
+             const float weight_c2 = data_attn_weight[data_weight_ptr];
+             const float weight_c3 = data_attn_weight[data_weight_ptr + 1];
+             const float weight_c4 = data_attn_weight[data_weight_ptr + 2];
+             const float weight_c5 = data_attn_weight[data_weight_ptr + 3];
+             const float weight_c6 = data_attn_weight[data_weight_ptr + 4];
+
+             // const float h_im = loc_h * spatial_h - 0.5; // align_corners = False
+             // const float w_im = loc_w * spatial_w - 0.5;
+
+             // C2 Feature
+             float h_im = loc_h * (h_c2 - 1); // align_corners = True
+             float w_im = loc_w * (w_c2 - 1);
+
+             if (h_im > -1 && w_im > -1 && h_im < h_c2 && w_im < w_c2) {
+                 const float* feat_c2_ptr = feat_c2 + b_col * num_views * h_c2 * w_c2 * channels + loc_v * h_c2 * w_c2 * channels;
+                 res[p_col] += ms_deform_attn_im2col_bilinear(feat_c2_ptr, h_c2, w_c2, channels, h_im, w_im, c_col) * weight_c2;
+             }
+
+             // C3 Feature
+             h_im = loc_h * (h_c3 - 1); // align_corners = True
+             w_im = loc_w * (w_c3 - 1);
+
+             if (h_im > -1 && w_im > -1 && h_im < h_c3 && w_im < w_c3) {
+                 const float* feat_c3_ptr = feat_c3 + b_col * num_views * h_c3 * w_c3 * channels + loc_v * h_c3 * w_c3 * channels;
+                 res[p_col] += ms_deform_attn_im2col_bilinear(feat_c3_ptr, h_c3, w_c3, channels, h_im, w_im, c_col) * weight_c3;
+             }
+
+             // C4 Feature
+             h_im = loc_h * (h_c4 - 1); // align_corners = True
+             w_im = loc_w * (w_c4 - 1);
+
+             if (h_im > -1 && w_im > -1 && h_im < h_c4 && w_im < w_c4) {
+                 const float* feat_c4_ptr = feat_c4 + b_col * num_views * h_c4 * w_c4 * channels + loc_v * h_c4 * w_c4 * channels;
+                 res[p_col] += ms_deform_attn_im2col_bilinear(feat_c4_ptr, h_c4, w_c4, channels, h_im, w_im, c_col) * weight_c4;
+             }
+
+             // C5 Feature
+             h_im = loc_h * (h_c5 - 1); // align_corners = True
+             w_im = loc_w * (w_c5 - 1);
+
+             if (h_im > -1 && w_im > -1 && h_im < h_c5 && w_im < w_c5) {
+                 const float* feat_c5_ptr = feat_c5 + b_col * num_views * h_c5 * w_c5 * channels + loc_v * h_c5 * w_c5 * channels;
+                 res[p_col] += ms_deform_attn_im2col_bilinear(feat_c5_ptr, h_c5, w_c5, channels, h_im, w_im, c_col) * weight_c5;
+             }
+
+             // C6 Feature
+             h_im = loc_h * (h_c6 - 1); // align_corners = True
+             w_im = loc_w * (w_c6 - 1);
+
+             if (h_im > -1 && w_im > -1 && h_im < h_c6 && w_im < w_c6) {
+                 const float* feat_c6_ptr = feat_c6 + b_col * num_views * h_c6 * w_c6 * channels + loc_v * h_c6 * w_c6 * channels;
+                 res[p_col] += ms_deform_attn_im2col_bilinear(feat_c6_ptr, h_c6, w_c6, channels, h_im, w_im, c_col) * weight_c6;
+             }
+         }
+
+         for (int p_col = 0; p_col < num_point; ++p_col) {
+             float* data_col_ptr = data_col + index * num_point + p_col;
+             *data_col_ptr = res[p_col];
+         }
+     }
+ }
+
+ void ms_deformable_im2col_cuda_c2345(
+     const float* feat_c2,
+     const float* feat_c3,
+     const float* feat_c4,
+     const float* feat_c5,
+     const int h_c2, const int w_c2,
+     const int h_c3, const int w_c3,
+     const int h_c4, const int w_c4,
+     const int h_c5, const int w_c5,
+     const float* data_sampling_loc,
+     const float* data_attn_weight,
+     const int batch_size,
+     const int channels,
+     const int num_views,
+     const int num_query,
+     const int num_point,
+     float* data_col) {
+
+     const int num_kernels = batch_size * num_query * channels;
+     const int num_threads = CUDA_NUM_THREADS;
+
+     ms_deformable_im2col_gpu_kernel_c2345 <<<GET_BLOCKS(num_kernels, num_threads), num_threads>>> (
+         feat_c2, feat_c3, feat_c4, feat_c5, h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5,
+         data_sampling_loc, data_attn_weight, batch_size, channels, num_views, num_query, num_point, data_col
+     );
+
+     cudaError_t err = cudaGetLastError();
+     if (err != cudaSuccess) {
+         printf("error in ms_deformable_im2col_cuda_c2345: %s\n", cudaGetErrorString(err));
+     }
+ }
+
+ void ms_deformable_im2col_cuda_c23456(
+     const float* feat_c2,
+     const float* feat_c3,
+     const float* feat_c4,
+     const float* feat_c5,
+     const float* feat_c6,
+     const int h_c2, const int w_c2,
+     const int h_c3, const int w_c3,
+     const int h_c4, const int w_c4,
+     const int h_c5, const int w_c5,
+     const int h_c6, const int w_c6,
+     const float* data_sampling_loc,
+     const float* data_attn_weight,
+     const int batch_size,
+     const int channels,
+     const int num_views,
+     const int num_query,
+     const int num_point,
+     float* data_col) {
+
+     const int num_kernels = batch_size * num_query * channels;
+     const int num_threads = CUDA_NUM_THREADS;
+
+     ms_deformable_im2col_gpu_kernel_c23456 <<<GET_BLOCKS(num_kernels, num_threads), num_threads>>> (
+         feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, h_c2, w_c2, h_c3, w_c3, h_c4, w_c4, h_c5, w_c5, h_c6, w_c6,
+         data_sampling_loc, data_attn_weight, batch_size, channels, num_views, num_query, num_point, data_col
+     );
+
+     cudaError_t err = cudaGetLastError();
+     if (err != cudaSuccess) {
+         printf("error in ms_deformable_im2col_cuda_c23456: %s\n", cudaGetErrorString(err));
+     }
+ }
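The forward kernel samples each feature level with `align_corners = True`, mapping a normalized location in [0, 1] to pixel coordinate `loc * (size - 1)` and zeroing out-of-range corners. A pure-Python reference of that sampling rule (illustrative names, not part of the extension):

```python
import math

def sample_bilinear(feat, h, w):
    """Bilinear lookup mirroring ms_deform_attn_im2col_bilinear:
    corners outside the map contribute zero. `feat` is an H x W list."""
    H, W = len(feat), len(feat[0])
    h_low, w_low = math.floor(h), math.floor(w)
    lh, lw = h - h_low, w - w_low
    val = 0.0
    for hi, wi, wt in [
        (h_low,     w_low,     (1 - lh) * (1 - lw)),
        (h_low,     w_low + 1, (1 - lh) * lw),
        (h_low + 1, w_low,     lh * (1 - lw)),
        (h_low + 1, w_low + 1, lh * lw),
    ]:
        if 0 <= hi < H and 0 <= wi < W:
            val += wt * feat[hi][wi]
    return val

def sample_normalized(feat, loc_h, loc_w):
    """align_corners = True: loc in [0, 1] maps to pixel loc * (size - 1)."""
    H, W = len(feat), len(feat[0])
    return sample_bilinear(feat, loc_h * (H - 1), loc_w * (W - 1))

feat = [[0.0, 1.0], [2.0, 3.0]]
# loc (0, 0) hits the top-left pixel exactly; (1, 1) the bottom-right;
# (0.5, 0.5) averages all four pixels.
```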
models/csrc/setup.py ADDED
@@ -0,0 +1,24 @@
+ from setuptools import setup
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+
+ def get_ext_modules():
+     return [
+         CUDAExtension(
+             name='_msmv_sampling_cuda',
+             sources=[
+                 'msmv_sampling/msmv_sampling.cpp',
+                 'msmv_sampling/msmv_sampling_forward.cu',
+                 'msmv_sampling/msmv_sampling_backward.cu'
+             ],
+             include_dirs=['msmv_sampling']
+         )
+     ]
+
+
+ setup(
+     name='csrc',
+     ext_modules=get_ext_modules(),
+     cmdclass={'build_ext': BuildExtension}
+ )
+
models/csrc/wrapper.py ADDED
@@ -0,0 +1,87 @@
+ import torch
+ import torch.nn.functional as F
+ from ._msmv_sampling_cuda import _ms_deform_attn_cuda_c2345_forward, _ms_deform_attn_cuda_c2345_backward
+ from ._msmv_sampling_cuda import _ms_deform_attn_cuda_c23456_forward, _ms_deform_attn_cuda_c23456_backward
+
+
+ def msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights):
+     """
+     mlvl_feats: list of [B, N, H_i, W_i, C]
+     sampling_locations: [B, Q, P, 3]
+     scale_weights: [B, Q, P, len(mlvl_feats)]
+     """
+     assert scale_weights.shape[-1] == len(mlvl_feats)
+
+     B, _, _, _, C = mlvl_feats[0].shape
+     _, Q, P, _ = sampling_locations.shape
+
+     sampling_locations = sampling_locations * 2 - 1
+     sampling_locations = sampling_locations[:, :, :, None, :]  # [B, Q, P, 1, 3]
+
+     final = torch.zeros([B, C, Q, P], device=mlvl_feats[0].device)
+
+     for lvl, feat in enumerate(mlvl_feats):
+         feat = feat.permute(0, 4, 1, 2, 3)
+         out = F.grid_sample(
+             feat, sampling_locations, mode='bilinear',
+             padding_mode='zeros', align_corners=True,
+         )[..., 0]  # [B, C, Q, P]
+         out = out * scale_weights[..., lvl].reshape(B, 1, Q, P)
+         final += out
+
+     return final.permute(0, 2, 1, 3)
+
+
+ class MSMVSamplingC2345(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights):
+         ctx.save_for_backward(feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights)
+
+         assert callable(_ms_deform_attn_cuda_c2345_forward)
+         return _ms_deform_attn_cuda_c2345_forward(
+             feat_c2, feat_c3, feat_c4, feat_c5,
+             sampling_locations, scale_weights)
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         feat_c2, feat_c3, feat_c4, feat_c5, sampling_locations, scale_weights = ctx.saved_tensors
+
+         assert callable(_ms_deform_attn_cuda_c2345_backward)
+         grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_sampling_loc, grad_attn_weight = \
+             _ms_deform_attn_cuda_c2345_backward(
+                 grad_output.contiguous(),
+                 feat_c2, feat_c3, feat_c4, feat_c5,
+                 sampling_locations, scale_weights
+             )
+
+         return grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_sampling_loc, grad_attn_weight
+
+
+ class MSMVSamplingC23456(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights):
+         ctx.save_for_backward(feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights)
+
+         assert callable(_ms_deform_attn_cuda_c23456_forward)
+         return _ms_deform_attn_cuda_c23456_forward(
+             feat_c2, feat_c3, feat_c4, feat_c5, feat_c6,
+             sampling_locations, scale_weights)
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         feat_c2, feat_c3, feat_c4, feat_c5, feat_c6, sampling_locations, scale_weights = ctx.saved_tensors
+
+         assert callable(_ms_deform_attn_cuda_c23456_backward)
+         grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6, grad_sampling_loc, grad_attn_weight = \
+             _ms_deform_attn_cuda_c23456_backward(
+                 grad_output.contiguous(),
+                 feat_c2, feat_c3, feat_c4, feat_c5, feat_c6,
+                 sampling_locations, scale_weights
+             )
+
+         return grad_value_c2, grad_value_c3, grad_value_c4, grad_value_c5, grad_value_c6, grad_sampling_loc, grad_attn_weight
+
+
+ def msmv_sampling(mlvl_feats, sampling_locations, scale_weights):
+     if len(mlvl_feats) == 4:
+         return MSMVSamplingC2345.apply(*mlvl_feats, sampling_locations, scale_weights)
+     elif len(mlvl_feats) == 5:
+         return MSMVSamplingC23456.apply(*mlvl_feats, sampling_locations, scale_weights)
+     else:
+         return msmv_sampling_pytorch(mlvl_feats, sampling_locations, scale_weights)
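The PyTorch fallback maps the [0, 1] sampling locations to `grid_sample`'s [-1, 1] range via `loc * 2 - 1`, while the CUDA kernels sample at `loc * (size - 1)`. With `align_corners=True` these two conventions address the same pixel; a small pure-Python check (no torch needed):

```python
def to_grid_sample_coord(loc):
    """wrapper.py's mapping from [0, 1] to grid_sample's [-1, 1] range."""
    return loc * 2 - 1

def grid_to_pixel(g, size):
    """With align_corners=True, grid_sample maps g in [-1, 1] to pixel
    (g + 1) / 2 * (size - 1)."""
    return (g + 1) / 2 * (size - 1)

# Both paths land on the same pixel coordinate as the CUDA kernels'
# loc * (size - 1), for any loc in [0, 1].
loc, size = 0.3, 25
pixel_via_grid = grid_to_pixel(to_grid_sample_coord(loc), size)
pixel_via_cuda = loc * (size - 1)
```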
models/sparsebev.py ADDED
@@ -0,0 +1,322 @@
1
+ import queue
2
+ import torch
3
+ import numpy as np
4
+ from mmcv.runner import force_fp32, auto_fp16
5
+ from mmcv.runner import get_dist_info
6
+ from mmcv.runner.fp16_utils import cast_tensor_type
7
+ from mmdet.models import DETECTORS
8
+ from mmdet3d.core import bbox3d2result
9
+ from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
10
+ from .utils import GridMask, pad_multiple, GpuPhotoMetricDistortion
11
+
12
+
13
+ @DETECTORS.register_module()
14
+ class SparseBEV(MVXTwoStageDetector):
15
+ def __init__(self,
16
+ data_aug=None,
17
+ stop_prev_grad=False,
18
+ pts_voxel_layer=None,
19
+ pts_voxel_encoder=None,
20
+ pts_middle_encoder=None,
21
+ pts_fusion_layer=None,
22
+ img_backbone=None,
23
+ pts_backbone=None,
24
+ img_neck=None,
25
+ pts_neck=None,
26
+ pts_bbox_head=None,
27
+ img_roi_head=None,
28
+ img_rpn_head=None,
29
+ train_cfg=None,
30
+ test_cfg=None,
31
+ pretrained=None):
32
+ super(SparseBEV, self).__init__(pts_voxel_layer, pts_voxel_encoder,
33
+ pts_middle_encoder, pts_fusion_layer,
34
+ img_backbone, pts_backbone, img_neck, pts_neck,
35
+ pts_bbox_head, img_roi_head, img_rpn_head,
36
+ train_cfg, test_cfg, pretrained)
37
+ self.data_aug = data_aug
38
+ self.stop_prev_grad = stop_prev_grad
39
+ self.color_aug = GpuPhotoMetricDistortion()
40
+ self.grid_mask = GridMask(ratio=0.5, prob=0.7)
41
+ self.use_grid_mask = True
42
+
43
+ self.memory = {}
44
+ self.queue = queue.Queue()
45
+
46
+ @auto_fp16(apply_to=('img'), out_fp32=True)
47
+ def extract_img_feat(self, img):
48
+ if self.use_grid_mask:
49
+ img = self.grid_mask(img)
50
+
51
+ img_feats = self.img_backbone(img)
52
+
53
+ if isinstance(img_feats, dict):
54
+ img_feats = list(img_feats.values())
55
+
56
+ if self.with_img_neck:
57
+ img_feats = self.img_neck(img_feats)
58
+
59
+ return img_feats
60
+
61
+ def extract_feat(self, img, img_metas):
62
+ if isinstance(img, list):
63
+ img = torch.stack(img, dim=0)
64
+
65
+ assert img.dim() == 5
66
+
67
+ B, N, C, H, W = img.size()
68
+ img = img.view(B * N, C, H, W)
69
+ img = img.float()
70
+
71
+ # move some augmentations to GPU
72
+ if self.data_aug is not None:
73
+ if 'img_color_aug' in self.data_aug and self.data_aug['img_color_aug'] and self.training:
74
+ img = self.color_aug(img)
75
+
76
+ if 'img_norm_cfg' in self.data_aug:
77
+ img_norm_cfg = self.data_aug['img_norm_cfg']
78
+
79
+ norm_mean = torch.tensor(img_norm_cfg['mean'], device=img.device)
80
+ norm_std = torch.tensor(img_norm_cfg['std'], device=img.device)
81
+
82
+ if img_norm_cfg['to_rgb']:
83
+ img = img[:, [2, 1, 0], :, :] # BGR to RGB
84
+
85
+ img = img - norm_mean.reshape(1, 3, 1, 1)
86
+ img = img / norm_std.reshape(1, 3, 1, 1)
87
+
88
+ for b in range(B):
89
+ img_shape = (img.shape[2], img.shape[3], img.shape[1])
90
+ img_metas[b]['img_shape'] = [img_shape for _ in range(N)]
91
+ img_metas[b]['ori_shape'] = [img_shape for _ in range(N)]
92
+
93
+ if 'img_pad_cfg' in self.data_aug:
94
+ img_pad_cfg = self.data_aug['img_pad_cfg']
95
+ img = pad_multiple(img, img_metas, size_divisor=img_pad_cfg['size_divisor'])
96
+
97
+ input_shape = img.shape[-2:]
98
+ # update real input shape of each single img
99
+ for img_meta in img_metas:
100
+ img_meta.update(input_shape=input_shape)
101
+
102
+ if self.training and self.stop_prev_grad:
103
+ H, W = input_shape
104
+ img = img.reshape(B, -1, 6, C, H, W)
105
+
106
+ img_grad = img[:, :1]
107
+ img_nograd = img[:, 1:]
108
+
109
+ all_img_feats = [self.extract_img_feat(img_grad.reshape(-1, C, H, W))]
110
+
111
+ with torch.no_grad():
112
+ self.eval()
113
+ for k in range(img_nograd.shape[1]):
114
+ all_img_feats.append(self.extract_img_feat(img_nograd[:, k].reshape(-1, C, H, W)))
115
+ self.train()
116
+
117
+ img_feats = []
118
+ for lvl in range(len(all_img_feats[0])):
119
+ C, H, W = all_img_feats[0][lvl].shape[1:]
120
+ img_feat = torch.cat([feat[lvl].reshape(B, -1, 6, C, H, W) for feat in all_img_feats], dim=1)
121
+ img_feat = img_feat.reshape(-1, C, H, W)
122
+ img_feats.append(img_feat)
123
+ else:
124
+ img_feats = self.extract_img_feat(img)
125
+
126
+ img_feats_reshaped = []
127
+ for img_feat in img_feats:
128
+ BN, C, H, W = img_feat.size()
129
+ img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))
130
+
131
+ return img_feats_reshaped
132
+
133
+ def forward_pts_train(self,
134
+ pts_feats,
135
+ gt_bboxes_3d,
136
+ gt_labels_3d,
137
+ img_metas,
138
+ gt_bboxes_ignore=None):
139
+ """Forward function for point cloud branch.
140
+ Args:
141
+ pts_feats (list[torch.Tensor]): Features of point cloud branch
142
+ gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
143
+ boxes for each sample.
144
+ gt_labels_3d (list[torch.Tensor]): Ground truth labels for
145
+ boxes of each sampole
146
+ img_metas (list[dict]): Meta information of samples.
147
+ gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
148
+ boxes to be ignored. Defaults to None.
149
+ Returns:
150
+ dict: Losses of each branch.
151
+ """
152
+ outs = self.pts_bbox_head(pts_feats, img_metas)
153
+ loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs]
154
+ losses = self.pts_bbox_head.loss(*loss_inputs)
155
+
156
+ return losses
157
+
158
+ @force_fp32(apply_to=('img', 'points'))
159
+ def forward(self, return_loss=True, **kwargs):
160
+ """Calls either forward_train or forward_test depending on whether
161
+ return_loss=True.
162
+ Note this setting will change the expected inputs. When
163
+ `return_loss=True`, img and img_metas are single-nested (i.e.
164
+ torch.Tensor and list[dict]), and when `resturn_loss=False`, img and
165
+ img_metas should be double nested (i.e. list[torch.Tensor],
166
+ list[list[dict]]), with the outer list indicating test time
167
+ augmentations.
168
+ """
169
+ if return_loss:
170
+ return self.forward_train(**kwargs)
171
+ else:
172
+ return self.forward_test(**kwargs)
173
+
174
+ def forward_train(self,
175
+ points=None,
176
+ img_metas=None,
177
+ gt_bboxes_3d=None,
178
+ gt_labels_3d=None,
179
+ gt_labels=None,
180
+ gt_bboxes=None,
181
+ img=None,
182
+ proposals=None,
183
+ gt_bboxes_ignore=None,
184
+ img_depth=None,
185
+ img_mask=None):
186
+ """Forward training function.
187
+ Args:
188
+ points (list[torch.Tensor], optional): Points of each sample.
189
+ Defaults to None.
190
+ img_metas (list[dict], optional): Meta information of each sample.
191
+ Defaults to None.
192
+ gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
193
+ Ground truth 3D boxes. Defaults to None.
194
+ gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
195
+ of 3D boxes. Defaults to None.
196
+ gt_labels (list[torch.Tensor], optional): Ground truth labels
197
+ of 2D boxes in images. Defaults to None.
198
+ gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
199
+ images. Defaults to None.
200
+ img (torch.Tensor optional): Images of each sample with shape
201
+ (N, C, H, W). Defaults to None.
202
+ proposals ([list[torch.Tensor], optional): Predicted proposals
203
+ used for training Fast RCNN. Defaults to None.
204
+ gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
205
+ 2D boxes in images to be ignored. Defaults to None.
206
+ Returns:
207
+ dict: Losses of different branches.
208
+ """
209
+ img_feats = self.extract_feat(img, img_metas)
210
+
211
+ for i in range(len(img_metas)):
212
+ img_metas[i]['gt_bboxes_3d'] = gt_bboxes_3d[i]
213
+ img_metas[i]['gt_labels_3d'] = gt_labels_3d[i]
214
+
215
+ losses = self.forward_pts_train(img_feats, gt_bboxes_3d, gt_labels_3d, img_metas, gt_bboxes_ignore)
216
+
217
+ return losses
218
+
219
+ def forward_test(self, img_metas, img=None, **kwargs):
220
+ for var, name in [(img_metas, 'img_metas')]:
221
+ if not isinstance(var, list):
222
+ raise TypeError('{} must be a list, but got {}'.format(
223
+ name, type(var)))
224
+ img = [img] if img is None else img
225
+ return self.simple_test(img_metas[0], img[0], **kwargs)
226
+
227
+ def simple_test_pts(self, x, img_metas, rescale=False):
228
+ outs = self.pts_bbox_head(x, img_metas)
229
+ bbox_list = self.pts_bbox_head.get_bboxes(outs, img_metas[0], rescale=rescale)
230
+
231
+ bbox_results = [
232
+ bbox3d2result(bboxes, scores, labels)
233
+ for bboxes, scores, labels in bbox_list
234
+ ]
235
+
236
+ return bbox_results
237
+
238
+ def simple_test(self, img_metas, img=None, rescale=False):
239
+ world_size = get_dist_info()[1]
240
+ if world_size == 1: # online
241
+            return self.simple_test_online(img_metas, img, rescale)
+        elif world_size > 1:  # offline
+            return self.simple_test_offline(img_metas, img, rescale)
+
+    def simple_test_offline(self, img_metas, img=None, rescale=False):
+        self.fp16_enabled = False
+        img_feats = self.extract_feat(img=img, img_metas=img_metas)
+
+        bbox_list = [dict() for _ in range(len(img_metas))]
+        bbox_pts = self.simple_test_pts(img_feats, img_metas, rescale=rescale)
+        for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
+            result_dict['pts_bbox'] = pts_bbox
+
+        return bbox_list
+
+    def simple_test_online(self, img_metas, img=None, rescale=False):
+        self.fp16_enabled = False
+        assert len(img_metas) == 1  # batch_size = 1
+
+        B, N, C, H, W = img.shape
+        img = img.reshape(B, N // 6, 6, C, H, W)
+
+        img_filenames = img_metas[0]['filename']
+        num_frames = len(img_filenames) // 6
+        # assert num_frames == img.shape[1]
+
+        img_shape = (H, W, C)
+        img_metas[0]['img_shape'] = [img_shape for _ in range(len(img_filenames))]
+        img_metas[0]['ori_shape'] = [img_shape for _ in range(len(img_filenames))]
+        img_metas[0]['pad_shape'] = [img_shape for _ in range(len(img_filenames))]
+
+        img_feats_large, img_metas_large = [], []
+
+        for i in range(num_frames):
+            img_indices = list(np.arange(i * 6, (i + 1) * 6))
+
+            img_curr_large = img[:, 0]  # [B, 6, C, H, W]
+            img_metas_curr_large = [{}]
+
+            for k in img_metas[0].keys():
+                if isinstance(img_metas[0][k], list):
+                    img_metas_curr_large[0][k] = [img_metas[0][k][i] for i in img_indices]
+
+            if img_filenames[img_indices[0]] in self.memory:
+                img_feats_curr_large = self.memory[img_filenames[img_indices[0]]]
+            else:
+                assert i == 0
+                img_feats_curr_large = self.extract_feat(img_curr_large, img_metas_curr_large)
+                self.memory[img_filenames[img_indices[0]]] = img_feats_curr_large
+                self.queue.put(img_filenames[img_indices[0]])
+
+            img_feats_large.append(img_feats_curr_large)
+            img_metas_large.append(img_metas_curr_large)
+
+        # reorganize
+        feat_levels = len(img_feats_large[0])
+        img_feats_large_reorganized = []
+        for j in range(feat_levels):
+            feat_l = torch.cat([img_feats_large[i][j] for i in range(len(img_feats_large))], dim=0)
+            feat_l = feat_l.flatten(0, 1)[None, ...]
+            img_feats_large_reorganized.append(feat_l)
+
+        img_metas_large_reorganized = img_metas_large[0]
+        for i in range(1, len(img_metas_large)):
+            for k, v in img_metas_large[i][0].items():
+                if isinstance(v, list):
+                    img_metas_large_reorganized[0][k].extend(v)
+
+        img_feats = img_feats_large_reorganized
+        img_metas = img_metas_large_reorganized
+        img_feats = cast_tensor_type(img_feats, torch.half, torch.float32)
+
+        bbox_list = [dict() for _ in range(1)]
+        bbox_pts = self.simple_test_pts(img_feats, img_metas, rescale=rescale)
+        for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
+            result_dict['pts_bbox'] = pts_bbox
+
+        while self.queue.qsize() >= 8:
+            pop_key = self.queue.get()
+            self.memory.pop(pop_key)
+
+        return bbox_list
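The online test path above memoizes per-frame image features in `self.memory` keyed by the first camera filename, tracks insertion order in `self.queue`, and evicts oldest-first once the cache reaches 8 frames. The same pattern can be sketched in isolation; `FeatureCache`, `max_size`, and `get_or_compute` are hypothetical names for illustration only.

```python
# Minimal sketch of the FIFO feature cache used in simple_test_online.
# FeatureCache / get_or_compute are hypothetical names, not part of the repo.
import queue

class FeatureCache:
    def __init__(self, max_size=8):
        self.memory = {}            # filename -> cached feature
        self.queue = queue.Queue()  # insertion order, for oldest-first eviction
        self.max_size = max_size

    def get_or_compute(self, key, compute_fn):
        if key in self.memory:
            return self.memory[key]      # cache hit: reuse the stored feature
        feat = compute_fn()              # cache miss: extract features once
        self.memory[key] = feat
        self.queue.put(key)
        # mirror the `while self.queue.qsize() >= 8` eviction loop
        while self.queue.qsize() >= self.max_size:
            self.memory.pop(self.queue.get())
        return feat
```

With batch size 1 and sequential frames, each keyframe's backbone features are therefore computed exactly once and reused by later sliding windows.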
models/sparsebev_head.py ADDED
@@ -0,0 +1,469 @@
+import math
+import torch
+import torch.nn as nn
+from mmcv.runner import force_fp32
+from mmdet.core import multi_apply, reduce_mean
+from mmdet.models import HEADS
+from mmdet.models.dense_heads import DETRHead
+from mmdet3d.core.bbox.coders import build_bbox_coder
+from mmdet3d.core.bbox.structures.lidar_box3d import LiDARInstance3DBoxes
+from .bbox.utils import normalize_bbox, encode_bbox
+
+
+@HEADS.register_module()
+class SparseBEVHead(DETRHead):
+    def __init__(self,
+                 *args,
+                 num_classes,
+                 in_channels,
+                 query_denoising=True,
+                 query_denoising_groups=10,
+                 bbox_coder=None,
+                 code_size=10,
+                 code_weights=[1.0] * 10,
+                 train_cfg=dict(),
+                 test_cfg=dict(max_per_img=100),
+                 **kwargs):
+        self.code_size = code_size
+        self.code_weights = code_weights
+        self.num_classes = num_classes
+        self.in_channels = in_channels
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+        self.fp16_enabled = False
+        self.embed_dims = in_channels
+
+        super(SparseBEVHead, self).__init__(num_classes, in_channels, train_cfg=train_cfg, test_cfg=test_cfg, **kwargs)
+
+        self.code_weights = nn.Parameter(torch.tensor(self.code_weights), requires_grad=False)
+        self.bbox_coder = build_bbox_coder(bbox_coder)
+        self.pc_range = self.bbox_coder.pc_range
+
+        self.dn_enabled = query_denoising
+        self.dn_group_num = query_denoising_groups
+        self.dn_weight = 1.0
+        self.dn_bbox_noise_scale = 0.5
+        self.dn_label_noise_scale = 0.5
+
+    def _init_layers(self):
+        self.init_query_bbox = nn.Embedding(self.num_query, 10)  # (x, y, z, w, l, h, sin, cos, vx, vy)
+        self.label_enc = nn.Embedding(self.num_classes + 1, self.embed_dims - 1)  # DAB-DETR
+
+        nn.init.zeros_(self.init_query_bbox.weight[:, 2:3])
+        nn.init.zeros_(self.init_query_bbox.weight[:, 8:10])
+        nn.init.constant_(self.init_query_bbox.weight[:, 5:6], 1.5)
+
+        grid_size = int(math.sqrt(self.num_query))
+        assert grid_size * grid_size == self.num_query
+        x = y = torch.arange(grid_size)
+        xx, yy = torch.meshgrid(x, y, indexing='ij')  # [0, grid_size - 1]
+        xy = torch.cat([xx[..., None], yy[..., None]], dim=-1)
+        xy = (xy + 0.5) / grid_size  # [0.5, grid_size - 0.5] / grid_size ~= (0, 1)
+        with torch.no_grad():
+            self.init_query_bbox.weight[:, :2] = xy.reshape(-1, 2)  # [Q, 2]
+
+    def init_weights(self):
+        self.transformer.init_weights()
+
+    def forward(self, mlvl_feats, img_metas):
+        query_bbox = self.init_query_bbox.weight.clone()  # [Q, 10]
+        # query_bbox[..., :3] = query_bbox[..., :3].sigmoid()
+
+        B = mlvl_feats[0].shape[0]
+        query_bbox, query_feat, attn_mask, mask_dict = self.prepare_for_dn_input(B, query_bbox, self.label_enc, img_metas)
+
+        cls_scores, bbox_preds = self.transformer(
+            query_bbox,
+            query_feat,
+            mlvl_feats,
+            attn_mask=attn_mask,
+            img_metas=img_metas,
+        )
+
+        bbox_preds[..., 0] = bbox_preds[..., 0] * (self.pc_range[3] - self.pc_range[0]) + self.pc_range[0]
+        bbox_preds[..., 1] = bbox_preds[..., 1] * (self.pc_range[4] - self.pc_range[1]) + self.pc_range[1]
+        bbox_preds[..., 2] = bbox_preds[..., 2] * (self.pc_range[5] - self.pc_range[2]) + self.pc_range[2]
+
+        bbox_preds = torch.cat([
+            bbox_preds[..., 0:2],
+            bbox_preds[..., 3:5],
+            bbox_preds[..., 2:3],
+            bbox_preds[..., 5:10],
+        ], dim=-1)  # [cx, cy, w, l, cz, h, sin, cos, vx, vy]
+
+        if mask_dict is not None and mask_dict['pad_size'] > 0:
+            output_known_cls_scores = cls_scores[:, :, :mask_dict['pad_size'], :]
+            output_known_bbox_preds = bbox_preds[:, :, :mask_dict['pad_size'], :]
+            output_cls_scores = cls_scores[:, :, mask_dict['pad_size']:, :]
+            output_bbox_preds = bbox_preds[:, :, mask_dict['pad_size']:, :]
+            mask_dict['output_known_lbs_bboxes'] = (output_known_cls_scores, output_known_bbox_preds)
+            outs = {
+                'all_cls_scores': output_cls_scores,
+                'all_bbox_preds': output_bbox_preds,
+                'enc_cls_scores': None,
+                'enc_bbox_preds': None,
+                'dn_mask_dict': mask_dict,
+            }
+        else:
+            outs = {
+                'all_cls_scores': cls_scores,
+                'all_bbox_preds': bbox_preds,
+                'enc_cls_scores': None,
+                'enc_bbox_preds': None,
+            }
+
+        return outs
+
+    def prepare_for_dn_input(self, batch_size, init_query_bbox, label_enc, img_metas):
+        device = init_query_bbox.device
+        indicator0 = torch.zeros([self.num_query, 1], device=device)
+        init_query_feat = label_enc.weight[self.num_classes].repeat(self.num_query, 1)
+        init_query_feat = torch.cat([init_query_feat, indicator0], dim=1)
+
+        if self.training and self.dn_enabled:
+            targets = [{
+                'bboxes': torch.cat([m['gt_bboxes_3d'].gravity_center,
+                                     m['gt_bboxes_3d'].tensor[:, 3:]], dim=1).cuda(),
+                'labels': m['gt_labels_3d'].cuda().long()
+            } for m in img_metas]
+
+            known = [torch.ones_like(t['labels'], device=device) for t in targets]
+            known_num = [sum(k) for k in known]
+
+            # can be modified to selectively denoise some labels or boxes; also known label prediction
+            unmask_bbox = unmask_label = torch.cat(known)
+            labels = torch.cat([t['labels'] for t in targets]).clone()
+            bboxes = torch.cat([t['bboxes'] for t in targets]).clone()
+            batch_idx = torch.cat([torch.full_like(t['labels'].long(), i) for i, t in enumerate(targets)])
+
+            known_indice = torch.nonzero(unmask_label + unmask_bbox)
+            known_indice = known_indice.view(-1)
+
+            # add noise
+            known_indice = known_indice.repeat(self.dn_group_num, 1).view(-1)
+            known_labels = labels.repeat(self.dn_group_num, 1).view(-1)
+            known_bid = batch_idx.repeat(self.dn_group_num, 1).view(-1)
+            known_bboxs = bboxes.repeat(self.dn_group_num, 1)  # 9
+            known_labels_expand = known_labels.clone()
+            known_bbox_expand = known_bboxs.clone()
+
+            # noise on the box
+            if self.dn_bbox_noise_scale > 0:
+                wlh = known_bbox_expand[..., 3:6].clone()
+                rand_prob = torch.rand_like(known_bbox_expand) * 2 - 1.0
+                known_bbox_expand[..., 0:3] += torch.mul(rand_prob[..., 0:3], wlh / 2) * self.dn_bbox_noise_scale
+                # known_bbox_expand[..., 3:6] += torch.mul(rand_prob[..., 3:6], wlh) * self.dn_bbox_noise_scale
+                # known_bbox_expand[..., 6:7] += torch.mul(rand_prob[..., 6:7], 3.14159) * self.dn_bbox_noise_scale
+
+            known_bbox_expand = encode_bbox(known_bbox_expand, self.pc_range)
+            known_bbox_expand[..., 0:3].clamp_(min=0.0, max=1.0)
+            # nn.init.constant(known_bbox_expand[..., 8:10], 0.0)
+
+            # noise on the label
+            if self.dn_label_noise_scale > 0:
+                p = torch.rand_like(known_labels_expand.float())
+                chosen_indice = torch.nonzero(p < self.dn_label_noise_scale).view(-1)  # usually half of bbox noise
+                new_label = torch.randint_like(chosen_indice, 0, self.num_classes)  # randomly put a new one here
+                known_labels_expand.scatter_(0, chosen_indice, new_label)
+
+            known_feat_expand = label_enc(known_labels_expand)
+            indicator1 = torch.ones([known_feat_expand.shape[0], 1], device=device)  # add dn part indicator
+            known_feat_expand = torch.cat([known_feat_expand, indicator1], dim=1)
+
+            # construct final query
+            dn_single_pad = int(max(known_num))
+            dn_pad_size = int(dn_single_pad * self.dn_group_num)
+            dn_query_bbox = torch.zeros([dn_pad_size, init_query_bbox.shape[-1]], device=device)
+            dn_query_feat = torch.zeros([dn_pad_size, self.embed_dims], device=device)
+            input_query_bbox = torch.cat([dn_query_bbox, init_query_bbox], dim=0).repeat(batch_size, 1, 1)
+            input_query_feat = torch.cat([dn_query_feat, init_query_feat], dim=0).repeat(batch_size, 1, 1)
+
+            if len(known_num):
+                map_known_indice = torch.cat([torch.tensor(range(num)) for num in known_num])  # [1,2, 1,2,3]
+                map_known_indice = torch.cat([map_known_indice + dn_single_pad * i for i in range(self.dn_group_num)]).long()
+
+            if len(known_bid):
+                input_query_bbox[known_bid.long(), map_known_indice] = known_bbox_expand
+                input_query_feat[(known_bid.long(), map_known_indice)] = known_feat_expand
+
+            total_size = dn_pad_size + self.num_query
+            attn_mask = torch.ones([total_size, total_size], device=device) < 0
+
+            # matching queries cannot see the reconstruction (denoising) part
+            attn_mask[dn_pad_size:, :dn_pad_size] = True
+            for i in range(self.dn_group_num):
+                if i == 0:
+                    attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), dn_single_pad * (i + 1):dn_pad_size] = True
+                if i == self.dn_group_num - 1:
+                    attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), :dn_single_pad * i] = True
+                else:
+                    attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), dn_single_pad * (i + 1):dn_pad_size] = True
+                    attn_mask[dn_single_pad * i:dn_single_pad * (i + 1), :dn_single_pad * i] = True
+
+            mask_dict = {
+                'known_indice': torch.as_tensor(known_indice).long(),
+                'batch_idx': torch.as_tensor(batch_idx).long(),
+                'map_known_indice': torch.as_tensor(map_known_indice).long(),
+                'known_lbs_bboxes': (known_labels, known_bboxs),
+                'pad_size': dn_pad_size
+            }
+        else:
+            input_query_bbox = init_query_bbox.repeat(batch_size, 1, 1)
+            input_query_feat = init_query_feat.repeat(batch_size, 1, 1)
+            attn_mask = None
+            mask_dict = None
+
+        return input_query_bbox, input_query_feat, attn_mask, mask_dict
+
+    def prepare_for_dn_loss(self, mask_dict):
+        cls_scores, bbox_preds = mask_dict['output_known_lbs_bboxes']
+        known_labels, known_bboxs = mask_dict['known_lbs_bboxes']
+        map_known_indice = mask_dict['map_known_indice'].long()
+        known_indice = mask_dict['known_indice'].long()
+        batch_idx = mask_dict['batch_idx'].long()
+        bid = batch_idx[known_indice]
+        num_tgt = known_indice.numel()
+
+        if len(cls_scores) > 0:
+            cls_scores = cls_scores.permute(1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)
+            bbox_preds = bbox_preds.permute(1, 2, 0, 3)[(bid, map_known_indice)].permute(1, 0, 2)
+
+        return known_labels, known_bboxs, cls_scores, bbox_preds, num_tgt
+
+    def dn_loss_single(self,
+                       cls_scores,
+                       bbox_preds,
+                       known_bboxs,
+                       known_labels,
+                       num_total_pos=None):
+        # Compute the average number of gt boxes across all gpus
+        num_total_pos = cls_scores.new_tensor([num_total_pos])
+        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1.0).item()
+
+        # cls loss
+        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
+        bbox_weights = torch.ones_like(bbox_preds)
+        label_weights = torch.ones_like(known_labels)
+        loss_cls = self.loss_cls(
+            cls_scores,
+            known_labels.long(),
+            label_weights,
+            avg_factor=num_total_pos
+        )
+
+        # regression L1 loss
+        bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))
+        normalized_bbox_targets = normalize_bbox(known_bboxs)
+        isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1)
+        bbox_weights = bbox_weights * self.code_weights
+        loss_bbox = self.loss_bbox(
+            bbox_preds[isnotnan, :10],
+            normalized_bbox_targets[isnotnan, :10],
+            bbox_weights[isnotnan, :10],
+            avg_factor=num_total_pos
+        )
+
+        loss_cls = self.dn_weight * torch.nan_to_num(loss_cls)
+        loss_bbox = self.dn_weight * torch.nan_to_num(loss_bbox)
+
+        return loss_cls, loss_bbox
+
+    @force_fp32(apply_to=('preds_dicts'))
+    def calc_dn_loss(self, loss_dict, preds_dicts, num_dec_layers):
+        known_labels, known_bboxs, cls_scores, bbox_preds, num_tgt = \
+            self.prepare_for_dn_loss(preds_dicts['dn_mask_dict'])
+
+        all_known_bboxs_list = [known_bboxs for _ in range(num_dec_layers)]
+        all_known_labels_list = [known_labels for _ in range(num_dec_layers)]
+        all_num_tgts_list = [num_tgt for _ in range(num_dec_layers)]
+
+        dn_losses_cls, dn_losses_bbox = multi_apply(
+            self.dn_loss_single, cls_scores, bbox_preds,
+            all_known_bboxs_list, all_known_labels_list, all_num_tgts_list)
+
+        loss_dict['loss_cls_dn'] = dn_losses_cls[-1]
+        loss_dict['loss_bbox_dn'] = dn_losses_bbox[-1]
+
+        num_dec_layer = 0
+        for loss_cls_i, loss_bbox_i in zip(dn_losses_cls[:-1], dn_losses_bbox[:-1]):
+            loss_dict[f'd{num_dec_layer}.loss_cls_dn'] = loss_cls_i
+            loss_dict[f'd{num_dec_layer}.loss_bbox_dn'] = loss_bbox_i
+            num_dec_layer += 1
+
+        return loss_dict
+
+    def _get_target_single(self,
+                           cls_score,
+                           bbox_pred,
+                           gt_labels,
+                           gt_bboxes,
+                           gt_bboxes_ignore=None):
+        num_bboxes = bbox_pred.size(0)
+
+        # assigner and sampler
+        assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes, gt_labels, gt_bboxes_ignore, self.code_weights, True)
+        sampling_result = self.sampler.sample(assign_result, bbox_pred, gt_bboxes)
+        pos_inds = sampling_result.pos_inds
+        neg_inds = sampling_result.neg_inds
+
+        # label targets
+        labels = gt_bboxes.new_full((num_bboxes, ), self.num_classes, dtype=torch.long)
+        labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
+        label_weights = gt_bboxes.new_ones(num_bboxes)
+
+        # bbox targets
+        bbox_targets = torch.zeros_like(bbox_pred)[..., :9]
+        bbox_weights = torch.zeros_like(bbox_pred)
+        bbox_weights[pos_inds] = 1.0
+
+        # DETR
+        bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes
+        return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, neg_inds)
+
+    def get_targets(self,
+                    cls_scores_list,
+                    bbox_preds_list,
+                    gt_bboxes_list,
+                    gt_labels_list,
+                    gt_bboxes_ignore_list=None):
+        assert gt_bboxes_ignore_list is None, \
+            'Only supports for gt_bboxes_ignore setting to None.'
+        num_imgs = len(cls_scores_list)
+        gt_bboxes_ignore_list = [gt_bboxes_ignore_list for _ in range(num_imgs)]
+
+        (labels_list, label_weights_list, bbox_targets_list,
+         bbox_weights_list, pos_inds_list, neg_inds_list) = multi_apply(
+            self._get_target_single, cls_scores_list, bbox_preds_list,
+            gt_labels_list, gt_bboxes_list, gt_bboxes_ignore_list)
+        num_total_pos = sum((inds.numel() for inds in pos_inds_list))
+        num_total_neg = sum((inds.numel() for inds in neg_inds_list))
+        return (labels_list, label_weights_list, bbox_targets_list,
+                bbox_weights_list, num_total_pos, num_total_neg)
+
+    def loss_single(self,
+                    cls_scores,
+                    bbox_preds,
+                    gt_bboxes_list,
+                    gt_labels_list,
+                    gt_bboxes_ignore_list=None):
+        num_imgs = cls_scores.size(0)
+        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
+        bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
+        cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
+                                           gt_bboxes_list, gt_labels_list, gt_bboxes_ignore_list)
+        (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+         num_total_pos, num_total_neg) = cls_reg_targets
+
+        labels = torch.cat(labels_list, 0)
+        label_weights = torch.cat(label_weights_list, 0)
+        bbox_targets = torch.cat(bbox_targets_list, 0)
+        bbox_weights = torch.cat(bbox_weights_list, 0)
+
+        # classification loss
+        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
+        # construct weighted avg_factor to match with the official DETR repo
+        cls_avg_factor = num_total_pos * 1.0 + \
+            num_total_neg * self.bg_cls_weight
+        if self.sync_cls_avg_factor:
+            cls_avg_factor = reduce_mean(
+                cls_scores.new_tensor([cls_avg_factor]))
+
+        cls_avg_factor = max(cls_avg_factor, 1)
+        loss_cls = self.loss_cls(
+            cls_scores, labels, label_weights, avg_factor=cls_avg_factor)
+
+        # Compute the average number of gt boxes across all gpus, for
+        # normalization purposes
+        num_total_pos = loss_cls.new_tensor([num_total_pos])
+        num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()
+
+        # regression L1 loss
+        bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))
+        normalized_bbox_targets = normalize_bbox(bbox_targets)
+        isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1)
+        bbox_weights = bbox_weights * self.code_weights
+
+        loss_bbox = self.loss_bbox(
+            bbox_preds[isnotnan, :10],
+            normalized_bbox_targets[isnotnan, :10],
+            bbox_weights[isnotnan, :10],
+            avg_factor=num_total_pos
+        )
+
+        loss_cls = torch.nan_to_num(loss_cls)
+        loss_bbox = torch.nan_to_num(loss_bbox)
+
+        return loss_cls, loss_bbox
+
+    @force_fp32(apply_to=('preds_dicts'))
+    def loss(self,
+             gt_bboxes_list,
+             gt_labels_list,
+             preds_dicts,
+             gt_bboxes_ignore=None):
+        assert gt_bboxes_ignore is None, \
+            f'{self.__class__.__name__} only supports ' \
+            f'for gt_bboxes_ignore setting to None.'
+
+        all_cls_scores = preds_dicts['all_cls_scores']
+        all_bbox_preds = preds_dicts['all_bbox_preds']
+        enc_cls_scores = preds_dicts['enc_cls_scores']
+        enc_bbox_preds = preds_dicts['enc_bbox_preds']
+
+        num_dec_layers = len(all_cls_scores)
+        device = gt_labels_list[0].device
+        gt_bboxes_list = [torch.cat(
+            (gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]),
+            dim=1).to(device) for gt_bboxes in gt_bboxes_list]
+
+        all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
+        all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
+        all_gt_bboxes_ignore_list = [gt_bboxes_ignore for _ in range(num_dec_layers)]
+
+        losses_cls, losses_bbox = multi_apply(
+            self.loss_single, all_cls_scores, all_bbox_preds,
+            all_gt_bboxes_list, all_gt_labels_list,
+            all_gt_bboxes_ignore_list)
+
+        loss_dict = dict()
+        # loss of proposal generated from encode feature map
+        if enc_cls_scores is not None:
+            binary_labels_list = [
+                torch.zeros_like(gt_labels_list[i])
+                for i in range(len(all_gt_labels_list))
+            ]
+            enc_loss_cls, enc_losses_bbox = \
+                self.loss_single(enc_cls_scores, enc_bbox_preds,
+                                 gt_bboxes_list, binary_labels_list, gt_bboxes_ignore)
+            loss_dict['enc_loss_cls'] = enc_loss_cls
+            loss_dict['enc_loss_bbox'] = enc_losses_bbox
+
+        if 'dn_mask_dict' in preds_dicts and preds_dicts['dn_mask_dict'] is not None:
+            loss_dict = self.calc_dn_loss(loss_dict, preds_dicts, num_dec_layers)
+
+        # loss from the last decoder layer
+        loss_dict['loss_cls'] = losses_cls[-1]
+        loss_dict['loss_bbox'] = losses_bbox[-1]
+
+        # loss from other decoder layers
+        num_dec_layer = 0
+        for loss_cls_i, loss_bbox_i in zip(losses_cls[:-1], losses_bbox[:-1]):
+            loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
+            loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
+            num_dec_layer += 1
+        return loss_dict
+
+    @force_fp32(apply_to=('preds_dicts'))
+    def get_bboxes(self, preds_dicts, img_metas, rescale=False):
+        preds_dicts = self.bbox_coder.decode(preds_dicts)
+        num_samples = len(preds_dicts)
+        ret_list = []
+        for i in range(num_samples):
+            preds = preds_dicts[i]
+            bboxes = preds['bboxes']
+            bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
+            bboxes = LiDARInstance3DBoxes(bboxes, 9)
+            scores = preds['scores']
+            labels = preds['labels']
+            ret_list.append([bboxes, scores, labels])
+        return ret_list
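`prepare_for_dn_input` builds a boolean attention mask (True = blocked) so that each denoising group only attends to itself and the matching queries never attend to any denoising slot. A minimal numpy sketch of the same masking rules, with the hypothetical helper name `build_dn_attn_mask`:

```python
# Sketch of the denoising attention mask from prepare_for_dn_input.
# build_dn_attn_mask is a hypothetical name; True means "attention blocked".
import numpy as np

def build_dn_attn_mask(dn_single_pad, dn_group_num, num_query):
    pad = dn_single_pad * dn_group_num       # total denoising slots
    total = pad + num_query                  # DN slots first, then matching queries
    mask = np.zeros((total, total), dtype=bool)
    mask[pad:, :pad] = True                  # matching part cannot see the DN part
    for i in range(dn_group_num):
        lo, hi = dn_single_pad * i, dn_single_pad * (i + 1)
        mask[lo:hi, hi:pad] = True           # group i cannot see later groups
        mask[lo:hi, :lo] = True              # group i cannot see earlier groups
    return mask
```

The DN rows are left free to attend to the matching queries, matching the original loop's behavior; only the reverse direction and cross-group attention are blocked.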
models/sparsebev_sampling.py ADDED
@@ -0,0 +1,102 @@
+import torch
+import torch.nn.functional as F
+from .bbox.utils import decode_bbox
+from .utils import rotation_3d_in_axis, DUMP
+from .csrc.wrapper import msmv_sampling, msmv_sampling_pytorch
+
+
+def make_sample_points(query_bbox, offset, pc_range):
+    '''
+    query_bbox: [B, Q, 10]
+    offset: [B, Q, num_points, 4], normalized by stride
+    '''
+    query_bbox = decode_bbox(query_bbox, pc_range)  # [B, Q, 9]
+
+    xyz = query_bbox[..., 0:3]  # [B, Q, 3]
+    wlh = query_bbox[..., 3:6]  # [B, Q, 3]
+    ang = query_bbox[..., 6:7]  # [B, Q, 1]
+
+    delta_xyz = offset[..., 0:3]  # [B, Q, P, 3]
+    delta_xyz = wlh[:, :, None, :] * delta_xyz  # [B, Q, P, 3]
+    delta_xyz = rotation_3d_in_axis(delta_xyz, ang)  # [B, Q, P, 3]
+    sample_xyz = xyz[:, :, None, :] + delta_xyz  # [B, Q, P, 3]
+
+    return sample_xyz  # [B, Q, P, 3]
+
+
+def sampling_4d(sample_points, mlvl_feats, scale_weights, lidar2img, image_h, image_w, eps=1e-5):
+    B, Q, T, G, P, _ = sample_points.shape  # [B, Q, T, G, P, 4]
+    N = 6
+
+    sample_points = sample_points.reshape(B, Q, T, G * P, 3)
+
+    # get the projection matrix
+    lidar2img = lidar2img[:, :, None, None, :, :]  # [B, TN, 1, 1, 4, 4]
+    lidar2img = lidar2img.expand(B, T*N, Q, G * P, 4, 4)
+    lidar2img = lidar2img.reshape(B, T, N, Q, G*P, 4, 4)
+
+    # expand the points
+    ones = torch.ones_like(sample_points[..., :1])
+    sample_points = torch.cat([sample_points, ones], dim=-1)  # [B, Q, GP, 4]
+    sample_points = sample_points[:, :, None, ..., None]  # [B, Q, T, GP, 4]
+    sample_points = sample_points.expand(B, Q, N, T, G * P, 4, 1)
+    sample_points = sample_points.transpose(1, 3)  # [B, T, N, Q, GP, 4, 1]
+
+    # project 3d sampling points to image
+    sample_points_cam = torch.matmul(lidar2img, sample_points).squeeze(-1)  # [B, T, N, Q, GP, 4]
+
+    # homo coord -> pixel coord
+    homo = sample_points_cam[..., 2:3]
+    homo_nonzero = torch.maximum(homo, torch.zeros_like(homo) + eps)
+    sample_points_cam = sample_points_cam[..., 0:2] / homo_nonzero  # [B, T, N, Q, GP, 2]
+
+    # normalize
+    sample_points_cam[..., 0] /= image_w
+    sample_points_cam[..., 1] /= image_h
+
+    # check if out of image
+    valid_mask = ((homo > eps)
+                  & (sample_points_cam[..., 1:2] > 0.0)
+                  & (sample_points_cam[..., 1:2] < 1.0)
+                  & (sample_points_cam[..., 0:1] > 0.0)
+                  & (sample_points_cam[..., 0:1] < 1.0)
+                  ).squeeze(-1).float()  # [B, T, N, Q, GP]
+
+    if DUMP.enabled:
+        torch.save(torch.cat([sample_points_cam, homo_nonzero], dim=-1),
+                   '{}/sample_points_cam_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
+        torch.save(valid_mask,
+                   '{}/sample_points_cam_valid_mask_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
+
+    valid_mask = valid_mask.permute(0, 1, 3, 4, 2)  # [B, T, Q, GP, N]
+    sample_points_cam = sample_points_cam.permute(0, 1, 3, 4, 2, 5)  # [B, T, Q, GP, N, 2]
+
+    i_batch = torch.arange(B, dtype=torch.long, device=sample_points.device)
+    i_query = torch.arange(Q, dtype=torch.long, device=sample_points.device)
+    i_time = torch.arange(T, dtype=torch.long, device=sample_points.device)
+    i_point = torch.arange(G * P, dtype=torch.long, device=sample_points.device)
+    i_batch = i_batch.view(B, 1, 1, 1, 1).expand(B, T, Q, G * P, 1)
+    i_time = i_time.view(1, T, 1, 1, 1).expand(B, T, Q, G * P, 1)
+    i_query = i_query.view(1, 1, Q, 1, 1).expand(B, T, Q, G * P, 1)
+    i_point = i_point.view(1, 1, 1, G * P, 1).expand(B, T, Q, G * P, 1)
+    i_view = torch.argmax(valid_mask, dim=-1)[..., None]  # [B, T, Q, GP, 1]
+
+    sample_points_cam = sample_points_cam[i_batch, i_time, i_query, i_point, i_view, :]  # [B, Q, GP, 1, 2]
+    valid_mask = valid_mask[i_batch, i_time, i_query, i_point, i_view]  # [B, Q, GP, 1]
+
+    sample_points_cam = torch.cat([sample_points_cam, i_view[..., None].float() / 5], dim=-1)
+    sample_points_cam = sample_points_cam.reshape(B, T, Q, G, P, 1, 3)
+    sample_points_cam = sample_points_cam.permute(0, 1, 3, 2, 4, 5, 6)  # [B, T, G, Q, P, 1, 3]
+    sample_points_cam = sample_points_cam.reshape(B*T*G, Q, P, 3)
+
+    scale_weights = scale_weights.reshape(B, Q, G, T, P, -1)
+    scale_weights = scale_weights.permute(0, 2, 3, 1, 4, 5)
+    scale_weights = scale_weights.reshape(B*G*T, Q, P, -1)
+
+    final = msmv_sampling(mlvl_feats, sample_points_cam, scale_weights)
+    C = final.shape[2]  # [BTG, Q, C, P]
+    final = final.reshape(B, T, G, Q, C, P)
+    final = final.permute(0, 3, 2, 1, 5, 4)
+    final = final.flatten(3, 4)  # [B, Q, G, FP, C]
+
+    return final
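The core of `sampling_4d` is a pinhole projection: lift each 3D point to homogeneous coordinates, multiply by the 4x4 `lidar2img` matrix, divide by the clamped depth, and keep only points that land in front of the camera and inside the image. A single-point numpy sketch (the helper name `project_points` is hypothetical):

```python
# Numpy sketch of the projection step in sampling_4d.
# project_points is a hypothetical name, not part of the repo.
import numpy as np

def project_points(points_xyz, lidar2img, image_w, image_h, eps=1e-5):
    ones = np.ones_like(points_xyz[..., :1])
    pts = np.concatenate([points_xyz, ones], axis=-1)   # homogeneous coords [P, 4]
    cam = pts @ lidar2img.T                             # apply 4x4 projection
    depth = cam[..., 2:3]
    uv = cam[..., :2] / np.maximum(depth, eps)          # pixel coords, depth clamped
    uv_norm = uv / np.array([image_w, image_h])         # normalize to (0, 1)
    valid = ((depth[..., 0] > eps)                      # in front of the camera
             & (uv_norm[..., 0] > 0) & (uv_norm[..., 0] < 1)
             & (uv_norm[..., 1] > 0) & (uv_norm[..., 1] < 1))
    return uv_norm, valid
```

The depth clamp mirrors `homo_nonzero` above: it avoids division blow-ups for points at or behind the image plane, while the validity mask still rejects them.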
models/sparsebev_transformer.py ADDED
@@ -0,0 +1,370 @@
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import torch.nn.functional as F
5
+ from mmcv.runner import BaseModule
6
+ from mmcv.cnn import bias_init_with_prob
7
+ from mmcv.cnn.bricks.transformer import MultiheadAttention, FFN
8
+ from mmdet.models.utils.builder import TRANSFORMER
9
+ from .bbox.utils import decode_bbox
10
+ from .utils import inverse_sigmoid, DUMP
11
+ from .sparsebev_sampling import sampling_4d, make_sample_points
12
+ from .checkpoint import checkpoint as cp
13
+
14
+
15
+ @TRANSFORMER.register_module()
16
+ class SparseBEVTransformer(BaseModule):
17
+ def __init__(self, embed_dims, num_frames=8, num_points=4, num_layers=6, num_levels=4, num_classes=10, code_size=10, pc_range=[], init_cfg=None):
18
+ assert init_cfg is None, 'To prevent abnormal initialization ' \
19
+ 'behavior, init_cfg is not allowed to be set'
20
+ super(SparseBEVTransformer, self).__init__(init_cfg=init_cfg)
21
+
22
+ self.embed_dims = embed_dims
23
+ self.pc_range = pc_range
24
+
25
+ self.decoder = SparseBEVTransformerDecoder(embed_dims, num_frames, num_points, num_layers, num_levels, num_classes, code_size, pc_range=pc_range)
26
+
27
+ @torch.no_grad()
28
+ def init_weights(self):
29
+ self.decoder.init_weights()
30
+
31
+ def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
32
+ cls_scores, bbox_preds = self.decoder(query_bbox, query_feat, mlvl_feats, attn_mask, img_metas)
33
+
34
+ cls_scores = torch.nan_to_num(cls_scores)
35
+ bbox_preds = torch.nan_to_num(bbox_preds)
36
+
37
+ return cls_scores, bbox_preds
38
+
39
+
40
+ class SparseBEVTransformerDecoder(BaseModule):
41
+ def __init__(self, embed_dims, num_frames=8, num_points=4, num_layers=6, num_levels=4, num_classes=10, code_size=10, pc_range=[], init_cfg=None):
42
+ super(SparseBEVTransformerDecoder, self).__init__(init_cfg)
43
+ self.num_layers = num_layers
44
+ self.pc_range = pc_range
45
+
46
+ self.decoder_layer = SparseBEVTransformerDecoderLayer(
47
+ embed_dims, num_frames, num_points, num_levels, num_classes, code_size, pc_range=pc_range
48
+ )
49
+
50
+ @torch.no_grad()
51
+ def init_weights(self):
52
+ self.decoder_layer.init_weights()
53
+
54
+ def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
55
+ cls_scores, bbox_preds = [], []
56
+
57
+ timestamps = np.array([m['img_timestamp'] for m in img_metas], dtype=np.float64)
58
+ timestamps = np.reshape(timestamps, [query_bbox.shape[0], -1, 6])
59
+ time_diff = timestamps[:, :1, :] - timestamps
60
+ time_diff = np.mean(time_diff, axis=-1).astype(np.float32) # [B, F]
61
+ time_diff = torch.from_numpy(time_diff).to(query_bbox.device) # [B, F]
62
+ img_metas[0]['time_diff'] = time_diff
63
+
64
+ lidar2img = np.asarray([m['lidar2img'] for m in img_metas]).astype(np.float32)
65
+ lidar2img = torch.from_numpy(lidar2img).to(query_bbox.device) # [B, N, 4, 4]
66
+ img_metas[0]['lidar2img'] = lidar2img
67
+
68
+ for lvl, feat in enumerate(mlvl_feats):
69
+ B, TN, GC, H, W = feat.shape # [B, TN, GC, H, W]
70
+ N, T, G, C = 6, TN // 6, 4, GC // 4
71
+ feat = feat.reshape(B, T, N, G, C, H, W)
72
+ feat = feat.permute(0, 1, 3, 2, 5, 6, 4) # [B, T, G, N, H, W, C]
73
+ feat = feat.reshape(B*T*G, N, H, W, C) # [BTG, C, N, H, W]
74
+ mlvl_feats[lvl] = feat.contiguous()
75
+
76
+ for i in range(self.num_layers):
77
+ DUMP.stage_count = i
78
+
79
+ query_feat, cls_score, bbox_pred = self.decoder_layer(
80
+ query_bbox, query_feat, mlvl_feats, attn_mask, img_metas
81
+ )
82
+ query_bbox = bbox_pred.clone().detach()
83
+
84
+ cls_scores.append(cls_score)
85
+ bbox_preds.append(bbox_pred)
86
+
87
+ cls_scores = torch.stack(cls_scores)
88
+ bbox_preds = torch.stack(bbox_preds)
89
+
90
+ return cls_scores, bbox_preds
91
+
92
+
93
+ class SparseBEVTransformerDecoderLayer(BaseModule):
94
+ def __init__(self, embed_dims, num_frames=8, num_points=4, num_levels=4, num_classes=10, code_size=10, num_cls_fcs=2, num_reg_fcs=2, pc_range=[], init_cfg=None):
95
+ super(SparseBEVTransformerDecoderLayer, self).__init__(init_cfg)
96
+
97
+ self.embed_dims = embed_dims
98
+ self.num_classes = num_classes
99
+ self.code_size = code_size
100
+ self.pc_range = pc_range
101
+
102
+ self.position_encoder = nn.Sequential(
103
+ nn.Linear(3, self.embed_dims),
104
+ nn.LayerNorm(self.embed_dims),
105
+ nn.ReLU(inplace=True),
106
+ nn.Linear(self.embed_dims, self.embed_dims),
107
+ nn.LayerNorm(self.embed_dims),
108
+ nn.ReLU(inplace=True),
109
+ )
110
+
111
+ self.self_attn = SparseBEVSelfAttention(embed_dims, num_heads=8, dropout=0.1, pc_range=pc_range)
112
+ self.sampling = SparseBEVSampling(embed_dims, num_frames=num_frames, num_groups=4, num_points=num_points, num_levels=num_levels, pc_range=pc_range)
113
+ self.mixing = AdaptiveMixing(in_dim=embed_dims, in_points=num_points * num_frames, n_groups=4, out_points=128)
114
+ self.ffn = FFN(embed_dims, feedforward_channels=512, ffn_drop=0.1)
115
+
116
+ self.norm1 = nn.LayerNorm(embed_dims)
117
+ self.norm2 = nn.LayerNorm(embed_dims)
118
+ self.norm3 = nn.LayerNorm(embed_dims)
119
+
120
+ cls_branch = []
121
+ for _ in range(num_cls_fcs):
122
+ cls_branch.append(nn.Linear(self.embed_dims, self.embed_dims))
123
+ cls_branch.append(nn.LayerNorm(self.embed_dims))
124
+ cls_branch.append(nn.ReLU(inplace=True))
125
+ cls_branch.append(nn.Linear(self.embed_dims, self.num_classes))
126
+ self.cls_branch = nn.Sequential(*cls_branch)
127
+
128
+ reg_branch = []
129
+ for _ in range(num_reg_fcs):
130
+ reg_branch.append(nn.Linear(self.embed_dims, self.embed_dims))
131
+ reg_branch.append(nn.ReLU(inplace=True))
132
+ reg_branch.append(nn.Linear(self.embed_dims, self.code_size))
133
+ self.reg_branch = nn.Sequential(*reg_branch)
134
+
135
+ @torch.no_grad()
136
+ def init_weights(self):
137
+ self.self_attn.init_weights()
138
+ self.sampling.init_weights()
139
+ self.mixing.init_weights()
140
+
141
+ bias_init = bias_init_with_prob(0.01)
142
+ nn.init.constant_(self.cls_branch[-1].bias, bias_init)
143
+
144
+ def refine_bbox(self, bbox_proposal, bbox_delta):
145
+ xyz = inverse_sigmoid(bbox_proposal[..., 0:3])
146
+ xyz_delta = bbox_delta[..., 0:3]
147
+ xyz_new = torch.sigmoid(xyz_delta + xyz)
148
+
149
+ return torch.cat([xyz_new, bbox_delta[..., 3:]], dim=-1)
150
+
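`refine_bbox` above applies the regressed delta in inverse-sigmoid (logit) space, so the refined center always lands back inside the normalized (0, 1) range after the sigmoid. A minimal scalar sketch of that idea (the helper name `refine_center` is hypothetical, not from the repo):

```python
import math

def refine_center(cx, delta, eps=1e-5):
    # Scalar sketch of refine_bbox for one coordinate: move in logit space,
    # then squash back with sigmoid so the result stays inside (0, 1).
    cx = min(max(cx, eps), 1 - eps)
    logit = math.log(cx / (1 - cx))
    return 1 / (1 + math.exp(-(logit + delta)))

print(refine_center(0.5, 0.0))  # zero delta leaves the center at 0.5
print(refine_center(0.9, 5.0))  # a large delta still stays below 1.0
```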
151
+ def forward(self, query_bbox, query_feat, mlvl_feats, attn_mask, img_metas):
152
+ """
153
+ query_bbox: [B, Q, 10] [cx, cy, cz, w, h, d, rot.sin, rot.cos, vx, vy]
154
+ """
155
+ query_pos = self.position_encoder(query_bbox[..., :3])
156
+ query_feat = query_feat + query_pos
157
+
158
+ query_feat = self.norm1(self.self_attn(query_bbox, query_feat, attn_mask))
159
+ sampled_feat = self.sampling(query_bbox, query_feat, mlvl_feats, img_metas)
160
+ query_feat = self.norm2(self.mixing(sampled_feat, query_feat))
161
+ query_feat = self.norm3(self.ffn(query_feat))
162
+
163
+ cls_score = self.cls_branch(query_feat) # [B, Q, num_classes]
164
+ bbox_pred = self.reg_branch(query_feat) # [B, Q, code_size]
165
+ bbox_pred = self.refine_bbox(query_bbox, bbox_pred)
166
+
167
+ time_diff = img_metas[0]['time_diff'] # [B, F]
168
+ if time_diff.shape[1] > 1:
169
+ time_diff = time_diff.clone()
170
+ time_diff[time_diff < 1e-5] = 1.0
171
+ bbox_pred[..., 8:] = bbox_pred[..., 8:] / time_diff[:, 1:2, None]
172
+
173
+ if DUMP.enabled:
174
+ query_bbox_dec = decode_bbox(query_bbox, self.pc_range)
175
+ bbox_pred_dec = decode_bbox(bbox_pred, self.pc_range)
176
+ cls_score_sig = torch.sigmoid(cls_score)
177
+ torch.save(query_bbox_dec, '{}/query_bbox_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
178
+ torch.save(bbox_pred_dec, '{}/bbox_pred_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
179
+ torch.save(cls_score_sig, '{}/cls_score_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
180
+
181
+ return query_feat, cls_score, bbox_pred
182
+
183
+
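The time normalization above converts the regressed displacement over the frame window into a velocity, guarding near-zero time gaps by substituting 1.0. A minimal scalar sketch of that step (the function name is hypothetical):

```python
def displacement_to_velocity(disp, dt, eps=1e-5):
    # Sketch of the time normalization: the head regresses a displacement
    # over the frame window; dividing by a (guarded) nonzero time gap
    # yields a velocity, mirroring time_diff[time_diff < 1e-5] = 1.0.
    dt = 1.0 if dt < eps else dt
    return disp / dt

print(displacement_to_velocity(1.5, 0.5))  # 1.5 m over 0.5 s -> 3.0 m/s
print(displacement_to_velocity(1.5, 0.0))  # dt guarded to 1.0 -> 1.5
```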
184
+ class SparseBEVSelfAttention(BaseModule):
185
+ def __init__(self, embed_dims=256, num_heads=8, dropout=0.1, pc_range=[], init_cfg=None):
186
+ super().__init__(init_cfg)
187
+ self.pc_range = pc_range
188
+
189
+ self.attention = MultiheadAttention(embed_dims, num_heads, dropout, batch_first=True)
190
+ self.gen_tau = nn.Linear(embed_dims, num_heads)
191
+
192
+ @torch.no_grad()
193
+ def init_weights(self):
194
+ nn.init.zeros_(self.gen_tau.weight)
195
+ nn.init.uniform_(self.gen_tau.bias, 0.0, 2.0)
196
+
197
+ def inner_forward(self, query_bbox, query_feat, pre_attn_mask):
198
+ """
199
+ query_bbox: [B, Q, 10]
200
+ query_feat: [B, Q, C]
201
+ """
202
+ dist = self.calc_bbox_dists(query_bbox)
203
+ tau = self.gen_tau(query_feat) # [B, Q, 8]
204
+
205
+ if DUMP.enabled:
206
+ torch.save(tau, '{}/sasa_tau_stage{}.pth'.format(DUMP.out_dir, DUMP.stage_count))
207
+
208
+ tau = tau.permute(0, 2, 1) # [B, 8, Q]
209
+ attn_mask = dist[:, None, :, :] * tau[..., None] # [B, 8, Q, Q]
210
+ if pre_attn_mask is not None:
211
+ attn_mask[:, :, pre_attn_mask] = float('-inf')
212
+ attn_mask = attn_mask.flatten(0, 1) # [Bx8, Q, Q]
213
+ return self.attention(query_feat, attn_mask=attn_mask)
214
+
215
+ def forward(self, query_bbox, query_feat, pre_attn_mask):
216
+ if self.training and query_feat.requires_grad:
217
+ return cp(self.inner_forward, query_bbox, query_feat, pre_attn_mask, use_reentrant=False)
218
+ else:
219
+ return self.inner_forward(query_bbox, query_feat, pre_attn_mask)
220
+
221
+ @torch.no_grad()
222
+ def calc_bbox_dists(self, bboxes):
223
+ centers = decode_bbox(bboxes, self.pc_range)[..., :2] # [B, Q, 2]
224
+
225
+ dist = []
226
+ for b in range(centers.shape[0]):
227
+ dist_b = torch.norm(centers[b].reshape(-1, 1, 2) - centers[b].reshape(1, -1, 2), dim=-1)
228
+ dist.append(dist_b[None, ...])
229
+
230
+ dist = torch.cat(dist, dim=0) # [B, Q, Q]
231
+ dist = -dist
232
+
233
+ return dist
234
+
235
+
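`calc_bbox_dists` builds the scale-adaptive self-attention bias: negative pairwise Euclidean distances between BEV centers, later multiplied per-head by the learned `tau`. A minimal pure-Python sketch of the distance matrix (the helper name is hypothetical):

```python
import math

def neg_pairwise_dists(centers):
    # Sketch of calc_bbox_dists: negative pairwise Euclidean distances
    # between BEV centers; nearer queries get a less negative bias.
    return [[-math.dist(a, b) for b in centers] for a in centers]

centers = [(0.0, 0.0), (3.0, 4.0)]
bias = neg_pairwise_dists(centers)
# diagonal is 0 (a query attends freely to itself), off-diagonal is -5
```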
236
+ class SparseBEVSampling(BaseModule):
237
+ def __init__(self, embed_dims=256, num_frames=4, num_groups=4, num_points=8, num_levels=4, pc_range=[], init_cfg=None):
238
+ super().__init__(init_cfg)
239
+
240
+ self.num_frames = num_frames
241
+ self.num_points = num_points
242
+ self.num_groups = num_groups
243
+ self.num_levels = num_levels
244
+ self.pc_range = pc_range
245
+
246
+ self.sampling_offset = nn.Linear(embed_dims, num_groups * num_points * 3)
247
+ self.scale_weights = nn.Linear(embed_dims, num_groups * num_points * num_levels)
248
+
249
+ def init_weights(self):
250
+ bias = self.sampling_offset.bias.data.view(self.num_groups * self.num_points, 3)
251
+ nn.init.zeros_(self.sampling_offset.weight)
252
+ nn.init.uniform_(bias[:, 0:3], -0.5, 0.5)
253
+
254
+ def inner_forward(self, query_bbox, query_feat, mlvl_feats, img_metas):
255
+ '''
256
+ query_bbox: [B, Q, 10]
257
+ query_feat: [B, Q, C]
258
+ '''
259
+ B, Q = query_bbox.shape[:2]
260
+ image_h, image_w, _ = img_metas[0]['img_shape'][0]
261
+
262
+ # sampling offset of all frames
263
+ sampling_offset = self.sampling_offset(query_feat)
264
+ sampling_offset = sampling_offset.view(B, Q, self.num_groups * self.num_points, 3)
265
+ sampling_points = make_sample_points(query_bbox, sampling_offset, self.pc_range) # [B, Q, GP, 3]
266
+ sampling_points = sampling_points.reshape(B, Q, 1, self.num_groups, self.num_points, 3)
267
+ sampling_points = sampling_points.expand(B, Q, self.num_frames, self.num_groups, self.num_points, 3)
268
+
269
+ # warp sample points based on velocity
270
+ time_diff = img_metas[0]['time_diff'] # [B, F]
271
+ time_diff = time_diff[:, None, :, None] # [B, 1, F, 1]
272
+ vel = query_bbox[..., 8:].detach() # [B, Q, 2]
273
+ vel = vel[:, :, None, :] # [B, Q, 1, 2]
274
+ dist = vel * time_diff # [B, Q, F, 2]
275
+ dist = dist[:, :, :, None, None, :] # [B, Q, F, 1, 1, 2]
276
+ sampling_points = torch.cat([
277
+ sampling_points[..., 0:2] - dist,
278
+ sampling_points[..., 2:3]
279
+ ], dim=-1)
280
+
281
+ # scale weights
282
+ scale_weights = self.scale_weights(query_feat).view(B, Q, self.num_groups, 1, self.num_points, self.num_levels)
283
+ scale_weights = torch.softmax(scale_weights, dim=-1)
284
+ scale_weights = scale_weights.expand(B, Q, self.num_groups, self.num_frames, self.num_points, self.num_levels)
285
+
286
+ # sampling
287
+ sampled_feats = sampling_4d(
288
+ sampling_points,
289
+ mlvl_feats,
290
+ scale_weights,
291
+ img_metas[0]['lidar2img'],
292
+ image_h, image_w
293
+ ) # [B, Q, G, FP, C]
294
+
295
+ return sampled_feats
296
+
297
+ def forward(self, query_bbox, query_feat, mlvl_feats, img_metas):
298
+ if self.training and query_feat.requires_grad:
299
+ return cp(self.inner_forward, query_bbox, query_feat, mlvl_feats, img_metas, use_reentrant=False)
300
+ else:
301
+ return self.inner_forward(query_bbox, query_feat, mlvl_feats, img_metas)
302
+
303
+
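The velocity warp in `SparseBEVSampling.inner_forward` shifts each sampling point backwards by the object's displacement `v * dt`, so the same object is sampled at its position in earlier frames. A minimal scalar sketch (the function name is hypothetical):

```python
def warp_point(x, y, vx, vy, dt):
    # Sketch of the velocity warp: a sampling point at (x, y) is moved back
    # by the object's displacement v * dt to track it into an earlier frame.
    return x - vx * dt, y - vy * dt

print(warp_point(10.0, 5.0, 2.0, 0.0, 0.5))  # (9.0, 5.0)
```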
304
+ class AdaptiveMixing(nn.Module):
305
+ def __init__(self, in_dim, in_points, n_groups=1, query_dim=None, out_dim=None, out_points=None):
306
+ super(AdaptiveMixing, self).__init__()
307
+
308
+ out_dim = out_dim if out_dim is not None else in_dim
309
+ out_points = out_points if out_points is not None else in_points
310
+ query_dim = query_dim if query_dim is not None else in_dim
311
+
312
+ self.query_dim = query_dim
313
+ self.in_dim = in_dim
314
+ self.in_points = in_points
315
+ self.n_groups = n_groups
316
+ self.out_dim = out_dim
317
+ self.out_points = out_points
318
+
319
+ self.eff_in_dim = in_dim // n_groups
320
+ self.eff_out_dim = out_dim // n_groups
321
+
322
+ self.m_parameters = self.eff_in_dim * self.eff_out_dim
323
+ self.s_parameters = self.in_points * self.out_points
324
+ self.total_parameters = self.m_parameters + self.s_parameters
325
+
326
+ self.parameter_generator = nn.Linear(self.query_dim, self.n_groups * self.total_parameters)
327
+ self.out_proj = nn.Linear(self.eff_out_dim * self.out_points * self.n_groups, self.query_dim)
328
+ self.act = nn.ReLU(inplace=True)
329
+
330
+ @torch.no_grad()
331
+ def init_weights(self):
332
+ nn.init.zeros_(self.parameter_generator.weight)
333
+
334
+ def inner_forward(self, x, query):
335
+ B, Q, G, P, C = x.shape
336
+ assert G == self.n_groups
337
+ assert P == self.in_points
338
+ assert C == self.eff_in_dim
339
+
340
+ '''generate mixing parameters'''
341
+ params = self.parameter_generator(query)
342
+ params = params.reshape(B*Q, G, -1)
343
+ out = x.reshape(B*Q, G, P, C)
344
+
345
+ M, S = params.split([self.m_parameters, self.s_parameters], 2)
346
+ M = M.reshape(B*Q, G, self.eff_in_dim, self.eff_out_dim)
347
+ S = S.reshape(B*Q, G, self.out_points, self.in_points)
348
+
349
+ '''adaptive channel mixing'''
350
+ out = torch.matmul(out, M)
351
+ out = F.layer_norm(out, [out.size(-2), out.size(-1)])
352
+ out = self.act(out)
353
+
354
+ '''adaptive point mixing'''
355
+ out = torch.matmul(S, out) # implicitly transpose and matmul
356
+ out = F.layer_norm(out, [out.size(-2), out.size(-1)])
357
+ out = self.act(out)
358
+
359
+ '''linear transformation to query dim'''
360
+ out = out.reshape(B, Q, -1)
361
+ out = self.out_proj(out)
362
+ out = query + out
363
+
364
+ return out
365
+
366
+ def forward(self, x, query):
367
+ if self.training and x.requires_grad:
368
+ return cp(self.inner_forward, x, query, use_reentrant=False)
369
+ else:
370
+ return self.inner_forward(x, query)
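`AdaptiveMixing` generates its per-query mixing weights dynamically: one channel-mixing matrix (`eff_in_dim x eff_out_dim`) and one point-mixing matrix (`in_points x out_points`) per group. A minimal sketch of the parameter-count bookkeeping, matching `parameter_generator`'s output width (the helper name is hypothetical):

```python
def mixing_param_count(in_dim, in_points, n_groups, out_dim=None, out_points=None):
    # Sketch of AdaptiveMixing's bookkeeping: per-group channel-mixing
    # matrix (eff_in_dim x eff_out_dim) plus point-mixing matrix
    # (in_points x out_points), summed over all groups.
    out_dim = out_dim or in_dim
    out_points = out_points or in_points
    eff_in, eff_out = in_dim // n_groups, out_dim // n_groups
    m = eff_in * eff_out
    s = in_points * out_points
    return n_groups * (m + s)

# Config used in the decoder layer above: embed_dims=256, 4 groups,
# in_points = num_points * num_frames = 4 * 8 = 32, out_points = 128.
print(mixing_param_count(256, 32, 4, out_points=128))  # 32768
```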
models/utils.py ADDED
@@ -0,0 +1,309 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from numpy import random
6
+
7
+
8
+ class GridMask(nn.Module):
9
+ def __init__(self, ratio=0.5, prob=0.7):
10
+ super(GridMask, self).__init__()
11
+ self.ratio = ratio
12
+ self.prob = prob
13
+
14
+ def forward(self, x):
15
+ if np.random.rand() > self.prob or not self.training:
16
+ return x
17
+
18
+ n, c, h, w = x.size()
19
+ x = x.view(-1, h, w)
20
+ hh = int(1.5 * h)
21
+ ww = int(1.5 * w)
22
+
23
+ d = np.random.randint(2, h)
24
+ l = min(max(int(d * self.ratio + 0.5), 1), d - 1)
25
+ mask = np.ones((hh, ww), np.uint8)
26
+ st_h = np.random.randint(d)
27
+ st_w = np.random.randint(d)
28
+
29
+ for i in range(hh // d):
30
+ s = d*i + st_h
31
+ t = min(s + l, hh)
32
+ mask[s:t, :] = 0
33
+
34
+ for i in range(ww // d):
35
+ s = d*i + st_w
36
+ t = min(s + l, ww)
37
+ mask[:, s:t] = 0
38
+
39
+ mask = mask[(hh-h)//2:(hh-h)//2+h, (ww-w)//2:(ww-w)//2+w]
40
+ mask = torch.tensor(mask, dtype=x.dtype, device=x.device)
41
+ mask = 1 - mask
42
+ mask = mask.expand_as(x)
43
+ x = x * mask
44
+
45
+ return x.view(n, c, h, w)
46
+
47
+
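`GridMask` zeroes out a fraction `ratio` of each grid period `d`, clamping the stripe width so at least one pixel is kept and at least one is dropped. A minimal sketch of that width computation (the helper name is hypothetical):

```python
def grid_stripe(d, ratio):
    # Sketch of GridMask's stripe width: round ratio * d to the nearest
    # integer, then clamp to [1, d - 1] so the mask is never all-keep
    # or all-drop within one period.
    return min(max(int(d * ratio + 0.5), 1), d - 1)

print(grid_stripe(10, 0.5))  # 5 of every 10 pixels are masked
print(grid_stripe(2, 0.9))   # clamped to d - 1 = 1
```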
48
+ def rotation_3d_in_axis(points, angles):
49
+ assert points.shape[-1] == 3
50
+ assert angles.shape[-1] == 1
51
+ angles = angles[..., 0]
52
+
53
+ n_points = points.shape[-2]
54
+ input_dims = angles.shape
55
+
56
+ if len(input_dims) > 1:
57
+ points = points.reshape(-1, n_points, 3)
58
+ angles = angles.reshape(-1)
59
+
60
+ rot_sin = torch.sin(angles)
61
+ rot_cos = torch.cos(angles)
62
+ ones = torch.ones_like(rot_cos)
63
+ zeros = torch.zeros_like(rot_cos)
64
+
65
+ rot_mat_T = torch.stack([
66
+ rot_cos, rot_sin, zeros,
67
+ -rot_sin, rot_cos, zeros,
68
+ zeros, zeros, ones,
69
+ ]).transpose(0, 1).reshape(-1, 3, 3)
70
+
71
+ points = torch.bmm(points, rot_mat_T)
72
+
73
+ if len(input_dims) > 1:
74
+ points = points.reshape(*input_dims, n_points, 3)
75
+
76
+ return points
77
+
78
+
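`rotation_3d_in_axis` rotates points about the z axis via `points @ rot_mat_T`, with `rot_mat_T` rows `[cos, sin, 0]`, `[-sin, cos, 0]`, `[0, 0, 1]` — i.e. a standard counter-clockwise yaw rotation. A minimal scalar sketch of the xy part (the function name is hypothetical):

```python
import math

def rotate_xy(x, y, angle):
    # Scalar sketch of the z-axis rotation: row vector [x, y] times
    # rot_mat_T gives a counter-clockwise rotation by `angle`.
    return (x * math.cos(angle) - y * math.sin(angle),
            x * math.sin(angle) + y * math.cos(angle))

print(rotate_xy(1.0, 0.0, math.pi / 2))  # ~(0, 1)
```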
79
+ def inverse_sigmoid(x, eps=1e-5):
80
+ """Inverse of the sigmoid function.
+ Args:
+ x (Tensor): Input tensor with values in [0, 1].
+ eps (float): Epsilon to avoid numerical overflow. Defaults to 1e-5.
+ Returns:
+ Tensor: The inverse sigmoid of x, with the same shape as the input.
+ """
91
+ x = x.clamp(min=0, max=1)
92
+ x1 = x.clamp(min=eps)
93
+ x2 = (1 - x).clamp(min=eps)
94
+ return torch.log(x1 / x2)
95
+
96
+
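A minimal scalar sketch of the clamped logit above, checking that it round-trips through the sigmoid away from the clamped endpoints:

```python
import math

def inverse_sigmoid(x, eps=1e-5):
    # Scalar sketch of the clamped logit: clamp into [0, 1], floor both
    # numerator and denominator at eps, then take log(x / (1 - x)).
    x = min(max(x, 0.0), 1.0)
    x1 = max(x, eps)
    x2 = max(1 - x, eps)
    return math.log(x1 / x2)

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

# Round trip: sigmoid(inverse_sigmoid(p)) ~ p away from the endpoints.
print(sigmoid(inverse_sigmoid(0.3)))
```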
97
+ def pad_multiple(inputs, img_metas, size_divisor=32):
98
+ _, _, img_h, img_w = inputs.shape
99
+
100
+ pad_h = 0 if img_h % size_divisor == 0 else size_divisor - (img_h % size_divisor)
101
+ pad_w = 0 if img_w % size_divisor == 0 else size_divisor - (img_w % size_divisor)
102
+
103
+ B = len(img_metas)
104
+ N = len(img_metas[0]['ori_shape'])
105
+
106
+ for b in range(B):
107
+ img_metas[b]['img_shape'] = [(img_h + pad_h, img_w + pad_w, 3) for _ in range(N)]
108
+ img_metas[b]['pad_shape'] = [(img_h + pad_h, img_w + pad_w, 3) for _ in range(N)]
109
+
110
+ if pad_h == 0 and pad_w == 0:
111
+ return inputs
112
+ else:
113
+ return F.pad(inputs, [0, pad_w, 0, pad_h], value=0)
114
+
115
+
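`pad_multiple` pads the image so both spatial sizes become multiples of `size_divisor` (needed for the FPN strides). A minimal sketch of the padding arithmetic (the helper name is hypothetical):

```python
def pad_amount(size, divisor=32):
    # Sketch of pad_multiple's arithmetic: how much to pad so `size`
    # becomes a multiple of `divisor` (0 if it already is).
    return 0 if size % divisor == 0 else divisor - (size % divisor)

print(pad_amount(450))  # 450 -> 480, pad 30
print(pad_amount(256))  # already a multiple of 32, pad 0
```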
116
+ def rgb_to_hsv(image: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
117
+ r"""Convert an image from RGB to HSV.
118
+
+ The image data is assumed to be in the range (0, 255).
+
+ Args:
+ image: RGB image to be converted to HSV with shape of :math:`(*, 3, H, W)`.
+ eps: scalar to enforce numerical stability.
+
+ Returns:
+ HSV version of the image with shape of :math:`(*, 3, H, W)`.
+ The H channel is in the range 0..360, S in 0..1, and V in 0..255.
+
+ .. note::
+ See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
+ color_conversions.html>`__.
+
+ Example:
+ >>> input = torch.rand(2, 3, 4, 5) * 255
+ >>> output = rgb_to_hsv(input) # 2x3x4x5
138
+ """
139
+ if not isinstance(image, torch.Tensor):
140
+ raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")
141
+
142
+ if len(image.shape) < 3 or image.shape[-3] != 3:
143
+ raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")
144
+
145
+ image = image / 255.0
146
+
147
+ max_rgb, argmax_rgb = image.max(-3)
148
+ min_rgb, argmin_rgb = image.min(-3)
149
+ deltac = max_rgb - min_rgb
150
+
151
+ v = max_rgb
152
+ s = deltac / (max_rgb + eps)
153
+
154
+ deltac = torch.where(deltac == 0, torch.ones_like(deltac), deltac)
155
+ rc, gc, bc = torch.unbind((max_rgb.unsqueeze(-3) - image), dim=-3)
156
+
157
+ h1 = bc - gc
158
+ h2 = (rc - bc) + 2.0 * deltac
159
+ h3 = (gc - rc) + 4.0 * deltac
160
+
161
+ h = torch.stack((h1, h2, h3), dim=-3) / deltac.unsqueeze(-3)
162
+ h = torch.gather(h, dim=-3, index=argmax_rgb.unsqueeze(-3)).squeeze(-3)
163
+ h = (h / 6.0) % 1.0
164
+
165
+ h = h * 360.0
166
+ v = v * 255.0
167
+
168
+ return torch.stack((h, s, v), dim=-3)
169
+
170
+
171
+ def hsv_to_rgb(image: torch.Tensor) -> torch.Tensor:
172
+ r"""Convert an image from HSV to RGB.
173
+
174
+ The H channel is assumed to be in the range 0..360, S in 0..1, and V in 0..255.
+
+ Args:
+ image: HSV image to be converted to RGB with shape of :math:`(*, 3, H, W)`.
+
+ Returns:
+ RGB version of the image with shape of :math:`(*, 3, H, W)`.
+
+ Example:
+ >>> hsv = rgb_to_hsv(torch.rand(2, 3, 4, 5) * 255)
+ >>> output = hsv_to_rgb(hsv) # 2x3x4x5
185
+ """
186
+ if not isinstance(image, torch.Tensor):
187
+ raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")
188
+
189
+ if len(image.shape) < 3 or image.shape[-3] != 3:
190
+ raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")
191
+
192
+ h: torch.Tensor = image[..., 0, :, :] / 360.0
193
+ s: torch.Tensor = image[..., 1, :, :]
194
+ v: torch.Tensor = image[..., 2, :, :] / 255.0
195
+
196
+ hi: torch.Tensor = torch.floor(h * 6) % 6
197
+ f: torch.Tensor = ((h * 6) % 6) - hi
198
+ one: torch.Tensor = torch.tensor(1.0, device=image.device, dtype=image.dtype)
199
+ p: torch.Tensor = v * (one - s)
200
+ q: torch.Tensor = v * (one - f * s)
201
+ t: torch.Tensor = v * (one - (one - f) * s)
202
+
203
+ hi = hi.long()
204
+ indices: torch.Tensor = torch.stack([hi, hi + 6, hi + 12], dim=-3)
205
+ out = torch.stack((v, q, p, p, t, v, t, v, v, q, p, p, p, p, t, v, v, q), dim=-3)
206
+ out = torch.gather(out, -3, indices)
207
+ out = out * 255.0
208
+
209
+ return out
210
+
211
+
212
+ class GpuPhotoMetricDistortion:
213
+ """Apply photometric distortion to image sequentially, every transformation
214
+ is applied with a probability of 0.5. The position of random contrast is in
215
+ second or second to last.
216
+ 1. random brightness
217
+ 2. random contrast (mode 0)
218
+ 3. convert color from BGR to HSV
219
+ 4. random saturation
220
+ 5. random hue
221
+ 6. convert color from HSV to BGR
222
+ 7. random contrast (mode 1)
223
+ 8. randomly swap channels
224
+ Args:
225
+ brightness_delta (int): delta of brightness.
226
+ contrast_range (tuple): range of contrast.
227
+ saturation_range (tuple): range of saturation.
228
+ hue_delta (int): delta of hue.
229
+ """
230
+
231
+ def __init__(self,
232
+ brightness_delta=32,
233
+ contrast_range=(0.5, 1.5),
234
+ saturation_range=(0.5, 1.5),
235
+ hue_delta=18):
236
+ self.brightness_delta = brightness_delta
237
+ self.contrast_lower, self.contrast_upper = contrast_range
238
+ self.saturation_lower, self.saturation_upper = saturation_range
239
+ self.hue_delta = hue_delta
240
+
241
+ def __call__(self, imgs):
242
+ """Call function to perform photometric distortion on images.
243
+ Args:
244
+ imgs (Tensor): Batch of images with shape [N, 3, H, W] in BGR order.
+ Returns:
+ Tensor: Distorted images with the same shape and channel order.
247
+ """
248
+ imgs = imgs[:, [2, 1, 0], :, :] # BGR to RGB
249
+
250
+ contrast_modes = []
251
+ for _ in range(imgs.shape[0]):
252
+ # mode == 0 --> do random contrast first
253
+ # mode == 1 --> do random contrast last
254
+ contrast_modes.append(random.randint(2))
255
+
256
+ for idx in range(imgs.shape[0]):
257
+ # random brightness
258
+ if random.randint(2):
259
+ delta = random.uniform(-self.brightness_delta, self.brightness_delta)
260
+ imgs[idx] += delta
261
+
262
+ if contrast_modes[idx] == 0:
263
+ if random.randint(2):
264
+ alpha = random.uniform(self.contrast_lower, self.contrast_upper)
265
+ imgs[idx] *= alpha
266
+
267
+ # convert color from RGB to HSV (channels were swapped to RGB above)
268
+ imgs = rgb_to_hsv(imgs)
269
+
270
+ for idx in range(imgs.shape[0]):
271
+ # random saturation
272
+ if random.randint(2):
273
+ imgs[idx, 1] *= random.uniform(self.saturation_lower, self.saturation_upper)
274
+
275
+ # random hue
276
+ if random.randint(2):
277
+ imgs[idx, 0] += random.uniform(-self.hue_delta, self.hue_delta)
278
+
279
+ imgs[:, 0][imgs[:, 0] > 360] -= 360
280
+ imgs[:, 0][imgs[:, 0] < 0] += 360
281
+
282
+ # convert color from HSV to RGB
283
+ imgs = hsv_to_rgb(imgs)
284
+
285
+ for idx in range(imgs.shape[0]):
286
+ # random contrast
287
+ if contrast_modes[idx] == 1:
288
+ if random.randint(2):
289
+ alpha = random.uniform(self.contrast_lower, self.contrast_upper)
290
+ imgs[idx] *= alpha
291
+
292
+ # randomly swap channels
293
+ if random.randint(2):
294
+ imgs[idx] = imgs[idx, random.permutation(3)]
295
+
296
+ imgs = imgs[:, [2, 1, 0], :, :] # RGB to BGR
297
+
298
+ return imgs
299
+
300
+
301
+ class DumpConfig:
302
+ def __init__(self):
303
+ self.enabled = False
304
+ self.out_dir = 'outputs'
305
+ self.stage_count = 0
306
+ self.frame_count = 0
307
+
308
+
309
+ DUMP = DumpConfig()
timing.py ADDED
@@ -0,0 +1,100 @@
1
+ import time
2
+ import utils
3
+ import logging
4
+ import argparse
5
+ import importlib
6
+ import torch
7
+ import torch.distributed
8
+ import torch.backends.cudnn as cudnn
9
+ from mmcv import Config, DictAction
10
+ from mmcv.parallel import MMDataParallel
11
+ from mmcv.runner import load_checkpoint
12
+ from mmdet.apis import set_random_seed
13
+ from mmdet3d.datasets import build_dataset, build_dataloader
14
+ from mmdet3d.models import build_model
15
+
16
+
17
+ def main():
18
+ parser = argparse.ArgumentParser(description='Validate a detector')
19
+ parser.add_argument('--config', required=True)
20
+ parser.add_argument('--weights', required=True)
21
+ parser.add_argument('--num_warmup', default=10)
22
+ parser.add_argument('--samples', default=500)
23
+ parser.add_argument('--log-interval', default=50, help='interval of logging')
24
+ parser.add_argument('--override', nargs='+', action=DictAction)
25
+ args = parser.parse_args()
26
+
27
+ # parse configs
28
+ cfgs = Config.fromfile(args.config)
29
+ if args.override is not None:
30
+ cfgs.merge_from_dict(args.override)
31
+
32
+ # register custom module
33
+ importlib.import_module('models')
34
+ importlib.import_module('loaders')
35
+
36
+ # MMCV, please shut up
37
+ from mmcv.utils.logging import logger_initialized
38
+ logger_initialized['root'] = logging.Logger(__name__, logging.WARNING)
39
+ logger_initialized['mmcv'] = logging.Logger(__name__, logging.WARNING)
40
+ utils.init_logging(None, cfgs.debug)
41
+
42
+ # you need GPUs
43
+ assert torch.cuda.is_available() and torch.cuda.device_count() == 1
44
+ logging.info('Using GPU: %s' % torch.cuda.get_device_name(0))
45
+ torch.cuda.set_device(0)
46
+
47
+ logging.info('Setting random seed: 0')
48
+ set_random_seed(0, deterministic=True)
49
+ cudnn.benchmark = True
50
+
51
+ logging.info('Loading validation set from %s' % cfgs.data.val.data_root)
52
+ val_dataset = build_dataset(cfgs.data.val)
53
+ val_loader = build_dataloader(
54
+ val_dataset,
55
+ samples_per_gpu=1,
56
+ workers_per_gpu=cfgs.data.workers_per_gpu,
57
+ num_gpus=1,
58
+ dist=False,
59
+ shuffle=False,
60
+ seed=0,
61
+ )
62
+
63
+ logging.info('Creating model: %s' % cfgs.model.type)
64
+ model = build_model(cfgs.model)
65
+ model.cuda()
66
+
67
+ assert torch.cuda.device_count() == 1
68
+ model = MMDataParallel(model, [0])
69
+
70
+ logging.info('Loading checkpoint from %s' % args.weights)
71
+ load_checkpoint(
72
+ model, args.weights, map_location='cuda', strict=False,
73
+ logger=logging.Logger(__name__, logging.ERROR)
74
+ )
75
+ model.eval()
76
+
77
+ pure_inf_time = 0
78
+ with torch.no_grad():
79
+ for i, data in enumerate(val_loader):
80
+ torch.cuda.synchronize()
81
+ start_time = time.perf_counter()
82
+
83
+ model(return_loss=False, rescale=True, **data)
84
+
85
+ torch.cuda.synchronize()
86
+ elapsed = time.perf_counter() - start_time
87
+
88
+ if i >= args.num_warmup:
89
+ pure_inf_time += elapsed
90
+ if (i + 1) % args.log_interval == 0:
91
+ fps = (i + 1 - args.num_warmup) / pure_inf_time
92
+ print(f'Done sample [{i + 1:<3}/ {args.samples}], '
93
+ f'fps: {fps:.1f} sample / s')
94
+
95
+ if (i + 1) == args.samples:
96
+ break
97
+
98
+
99
+ if __name__ == '__main__':
100
+ main()
train.py ADDED
@@ -0,0 +1,180 @@
1
+ import os
2
+ import utils
3
+ import shutil
4
+ import logging
5
+ import argparse
6
+ import importlib
7
+ import torch
8
+ import torch.distributed as dist
9
+ from datetime import datetime
10
+ from mmcv import Config, DictAction
11
+ from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
12
+ from mmcv.runner import EpochBasedRunner, build_optimizer, load_checkpoint
13
+ from mmdet.apis import set_random_seed
14
+ from mmdet.core import DistEvalHook, EvalHook
15
+ from mmdet3d.datasets import build_dataset
16
+ from mmdet3d.models import build_model
17
+ from loaders.builder import build_dataloader
18
+
19
+
20
+ def main():
21
+ parser = argparse.ArgumentParser(description='Train a detector')
22
+ parser.add_argument('--config', required=True)
23
+ parser.add_argument('--override', nargs='+', action=DictAction)
24
+ parser.add_argument('--local_rank', type=int, default=0)
25
+ parser.add_argument('--world_size', type=int, default=1)
26
+ args = parser.parse_args()
27
+
28
+ # parse configs
29
+ cfgs = Config.fromfile(args.config)
30
+ if args.override is not None:
31
+ cfgs.merge_from_dict(args.override)
32
+
33
+ # register custom module
34
+ importlib.import_module('models')
35
+ importlib.import_module('loaders')
36
+
37
+ # MMCV, please shut up
38
+ from mmcv.utils.logging import logger_initialized
39
+ logger_initialized['root'] = logging.Logger(__name__, logging.WARNING)
40
+ logger_initialized['mmcv'] = logging.Logger(__name__, logging.WARNING)
41
+ logger_initialized['mmdet3d'] = logging.Logger(__name__, logging.WARNING)
42
+
43
+ # you need GPUs
44
+ assert torch.cuda.is_available()
45
+
46
+ # determine local_rank and world_size
47
+ if 'LOCAL_RANK' not in os.environ:
48
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
49
+
50
+ if 'WORLD_SIZE' not in os.environ:
51
+ os.environ['WORLD_SIZE'] = str(args.world_size)
52
+
53
+ local_rank = int(os.environ['LOCAL_RANK'])
54
+ world_size = int(os.environ['WORLD_SIZE'])
55
+
56
+ if local_rank == 0:
57
+ # resume or start a new run
58
+ if cfgs.resume_from is not None:
59
+ assert os.path.isfile(cfgs.resume_from)
60
+ work_dir = os.path.dirname(cfgs.resume_from)
61
+ else:
62
+ run_name = ''
63
+ if not cfgs.debug:
64
+ run_name = input('Name your run (leave blank for default): ')
65
+ if run_name == '':
66
+ run_name = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")
67
+
68
+ work_dir = os.path.join('outputs', cfgs.model.type, run_name)
69
+ if os.path.exists(work_dir): # must be an empty dir
70
+ if input('Path "%s" already exists, overwrite it? [Y/n] ' % work_dir) == 'n':
71
+ print('Bye.')
72
+ exit(0)
73
+ shutil.rmtree(work_dir)
74
+
75
+ os.makedirs(work_dir, exist_ok=False)
76
+
77
+ # init logging, backup code
78
+ utils.init_logging(os.path.join(work_dir, 'train.log'), cfgs.debug)
79
+ utils.backup_code(work_dir)
80
+ logging.info('Logs will be saved to %s' % work_dir)
81
+
82
+ else:
83
+ # disable logging on other workers
84
+ logging.root.disabled = True
85
+ work_dir = '/tmp'
86
+
87
+ logging.info('Using GPU: %s' % torch.cuda.get_device_name(local_rank))
88
+ torch.cuda.set_device(local_rank)
89
+
90
+ if world_size > 1:
91
+ logging.info('Initializing DDP with %d GPUs...' % world_size)
92
+ dist.init_process_group('nccl', init_method='env://')
93
+
94
+ logging.info('Setting random seed: 0')
95
+ set_random_seed(0, deterministic=True)
96
+
97
+ logging.info('Loading training set from %s' % cfgs.dataset_root)
98
+ train_dataset = build_dataset(cfgs.data.train)
99
+ train_loader = build_dataloader(
100
+ train_dataset,
101
+ samples_per_gpu=cfgs.batch_size // world_size,
102
+ workers_per_gpu=cfgs.data.workers_per_gpu,
103
+ num_gpus=world_size,
104
+ dist=world_size > 1,
105
+ shuffle=True,
106
+ seed=0,
107
+ )
108
+
109
+ logging.info('Loading validation set from %s' % cfgs.dataset_root)
110
+ val_dataset = build_dataset(cfgs.data.val)
111
+ val_loader = build_dataloader(
112
+ val_dataset,
113
+ samples_per_gpu=1,
114
+ workers_per_gpu=cfgs.data.workers_per_gpu,
115
+ num_gpus=world_size,
116
+ dist=world_size > 1,
117
+ shuffle=False
118
+ )
119
+
120
+ logging.info('Creating model: %s' % cfgs.model.type)
121
+ model = build_model(cfgs.model)
122
+ model.init_weights()
123
+ model.cuda()
124
+ model.train()
125
+
126
+ n_params = sum([p.numel() for p in model.parameters() if p.requires_grad])
127
+ logging.info('Trainable parameters: %d (%.1fM)' % (n_params, n_params / 1e6))
128
+ logging.info('Batch size per GPU: %d' % (cfgs.batch_size // world_size))
129
+
130
+ if world_size > 1:
131
+ model = MMDistributedDataParallel(model, [local_rank], broadcast_buffers=False)
132
+ else:
133
+ model = MMDataParallel(model, [0])
134
+
135
+ logging.info('Creating optimizer: %s' % cfgs.optimizer.type)
136
+ optimizer = build_optimizer(model, cfgs.optimizer)
137
+
138
+ runner = EpochBasedRunner(
139
+ model,
140
+ optimizer=optimizer,
141
+ work_dir=work_dir,
142
+ logger=logging.root,
143
+ max_epochs=cfgs.total_epochs,
144
+ meta=dict(),
145
+ )
146
+
147
+ runner.register_lr_hook(cfgs.lr_config)
148
+ runner.register_optimizer_hook(cfgs.optimizer_config)
149
+ runner.register_checkpoint_hook(cfgs.checkpoint_config)
150
+ runner.register_logger_hooks(cfgs.log_config)
151
+ runner.register_timer_hook(dict(type='IterTimerHook'))
152
+ runner.register_custom_hooks(dict(type='DistSamplerSeedHook'))
153
+
154
+ if cfgs.eval_config['interval'] > 0:
155
+ if world_size > 1:
156
+ runner.register_hook(DistEvalHook(val_loader, interval=cfgs.eval_config['interval'], gpu_collect=True))
157
+ else:
158
+ runner.register_hook(EvalHook(val_loader, interval=cfgs.eval_config['interval']))
159
+
160
+ if cfgs.resume_from is not None:
161
+ logging.info('Resuming from %s' % cfgs.resume_from)
162
+ runner.resume(cfgs.resume_from)
163
+
164
+ elif cfgs.load_from is not None:
165
+ logging.info('Loading checkpoint from %s' % cfgs.load_from)
166
+ if cfgs.revise_keys is not None:
167
+ load_checkpoint(
168
+ model, cfgs.load_from, map_location='cpu',
169
+ revise_keys=cfgs.revise_keys
170
+ )
171
+ else:
172
+ load_checkpoint(
173
+ model, cfgs.load_from, map_location='cpu',
174
+ )
175
+
176
+ runner.run([train_loader], [('train', 1)])
177
+
178
+
179
+ if __name__ == '__main__':
180
+ main()
utils.py ADDED
@@ -0,0 +1,188 @@
+ import os
+ import sys
+ import glob
+ import torch
+ import shutil
+ import logging
+ import datetime
+ from mmcv.runner.hooks import HOOKS
+ from mmcv.runner.hooks.logger import LoggerHook, TextLoggerHook
+ from mmcv.runner.dist_utils import master_only
+ from torch.utils.tensorboard import SummaryWriter
+
+
+ def init_logging(filename=None, debug=False):
+     logging.root = logging.RootLogger('DEBUG' if debug else 'INFO')
+     formatter = logging.Formatter('[%(asctime)s][%(levelname)s] - %(message)s')
+
+     stream_handler = logging.StreamHandler(sys.stdout)
+     stream_handler.setFormatter(formatter)
+     logging.root.addHandler(stream_handler)
+
+     if filename is not None:
+         file_handler = logging.FileHandler(filename)
+         file_handler.setFormatter(formatter)
+         logging.root.addHandler(file_handler)
+
+
+ def backup_code(work_dir, verbose=False):
+     base_dir = os.path.dirname(os.path.abspath(__file__))
+     for pattern in ['*.py', 'configs/*.py', 'models/*.py', 'loaders/*.py', 'loaders/pipelines/*.py']:
+         for file in glob.glob(pattern):
+             src = os.path.join(base_dir, file)
+             dst = os.path.join(work_dir, 'backup', os.path.dirname(file))
+
+             if verbose:
+                 logging.info('Copying %s -> %s' % (os.path.relpath(src), os.path.relpath(dst)))
+
+             os.makedirs(dst, exist_ok=True)
+             shutil.copy2(src, dst)
+
+
+ @HOOKS.register_module()
+ class MyTextLoggerHook(TextLoggerHook):
+     def _log_info(self, log_dict, runner):
+         # print exp name for users to distinguish experiments
+         # at every ``interval_exp_name`` iterations and at the end of each epoch
+         if runner.meta is not None and 'exp_name' in runner.meta:
+             if (self.every_n_iters(runner, self.interval_exp_name)) or (
+                     self.by_epoch and self.end_of_epoch(runner)):
+                 exp_info = f'Exp name: {runner.meta["exp_name"]}'
+                 runner.logger.info(exp_info)
+
+         # by epoch: Epoch [4][100/1000]
+         # by iter: Iter [100/100000]
+         if self.by_epoch:
+             log_str = f'Epoch [{log_dict["epoch"]}/{runner.max_epochs}]' \
+                       f'[{log_dict["iter"]}/{len(runner.data_loader)}] '
+         else:
+             log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}] '
+
+         log_str += 'loss: %.2f, ' % log_dict['loss']
+
+         if 'time' in log_dict.keys():
+             # MOD: skip the first iteration since its timing is not accurate
+             if runner.iter == self.start_iter:
+                 time_sec_avg = log_dict['time']
+             else:
+                 self.time_sec_tot += (log_dict['time'] * self.interval)
+                 time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter)
+
+             eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
+             eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+             log_str += f'eta: {eta_str}, '
+             log_str += f'time: {log_dict["time"]:.2f}s, ' \
+                        f'data: {log_dict["data_time"] * 1000:.0f}ms, '
+             # GPU memory statistics
+             if torch.cuda.is_available():
+                 log_str += f'mem: {log_dict["memory"]}M'
+
+         runner.logger.info(log_str)
+
+     def log(self, runner):
+         if 'eval_iter_num' in runner.log_buffer.output:
+             # this doesn't modify runner.iter and is independent of by_epoch
+             cur_iter = runner.log_buffer.output.pop('eval_iter_num')
+         else:
+             cur_iter = self.get_iter(runner, inner_iter=True)
+
+         log_dict = {
+             'mode': self.get_mode(runner),
+             'epoch': self.get_epoch(runner),
+             'iter': cur_iter
+         }
+
+         # only record lr of the first param group
+         cur_lr = runner.current_lr()
+         if isinstance(cur_lr, list):
+             log_dict['lr'] = cur_lr[0]
+         else:
+             assert isinstance(cur_lr, dict)
+             log_dict['lr'] = {}
+             for k, lr_ in cur_lr.items():
+                 assert isinstance(lr_, list)
+                 log_dict['lr'].update({k: lr_[0]})
+
+         if 'time' in runner.log_buffer.output:
+             # GPU memory statistics
+             if torch.cuda.is_available():
+                 log_dict['memory'] = self._get_max_memory(runner)
+
+         log_dict = dict(log_dict, **runner.log_buffer.output)
+
+         # MOD: disable writing to files
+         # self._dump_log(log_dict, runner)
+         self._log_info(log_dict, runner)
+
+         return log_dict
+
+     def after_train_epoch(self, runner):
+         if runner.log_buffer.ready:
+             metrics = self.get_loggable_tags(runner)
+             runner.logger.info('--- Evaluation Results ---')
+             runner.logger.info('mAP: %.4f' % metrics['val/pts_bbox_NuScenes/mAP'])
+             runner.logger.info('mATE: %.4f' % metrics['val/pts_bbox_NuScenes/mATE'])
+             runner.logger.info('mASE: %.4f' % metrics['val/pts_bbox_NuScenes/mASE'])
+             runner.logger.info('mAOE: %.4f' % metrics['val/pts_bbox_NuScenes/mAOE'])
+             runner.logger.info('mAVE: %.4f' % metrics['val/pts_bbox_NuScenes/mAVE'])
+             runner.logger.info('mAAE: %.4f' % metrics['val/pts_bbox_NuScenes/mAAE'])
+             runner.logger.info('NDS: %.4f' % metrics['val/pts_bbox_NuScenes/NDS'])
+
+
+ @HOOKS.register_module()
+ class MyTensorboardLoggerHook(LoggerHook):
+     def __init__(self, log_dir=None, interval=10, ignore_last=True, reset_flag=False, by_epoch=True):
+         super(MyTensorboardLoggerHook, self).__init__(
+             interval, ignore_last, reset_flag, by_epoch)
+         self.log_dir = log_dir
+
+     @master_only
+     def before_run(self, runner):
+         super(MyTensorboardLoggerHook, self).before_run(runner)
+         if self.log_dir is None:
+             self.log_dir = runner.work_dir
+         self.writer = SummaryWriter(self.log_dir)
+
+     @master_only
+     def log(self, runner):
+         tags = self.get_loggable_tags(runner)
+
+         for key, value in tags.items():
+             # MOD: merge into the 'train' group
+             if key == 'learning_rate':
+                 key = 'train/learning_rate'
+
+             # MOD: skip momentum
+             ignore = False
+             if key == 'momentum':
+                 ignore = True
+
+             # MOD: skip intermediate losses
+             for i in range(5):
+                 if key[:13] == 'train/d%d.loss' % i:
+                     ignore = True
+
+             if key[:3] == 'val':
+                 metric_name = key[22:]  # strip the 'val/pts_bbox_NuScenes/' prefix
+                 if metric_name in ['mAP', 'mATE', 'mASE', 'mAOE', 'mAVE', 'mAAE', 'NDS']:
+                     key = 'val/' + metric_name
+                 else:
+                     ignore = True
+
+             if self.get_mode(runner) == 'train' and key[:5] != 'train':
+                 ignore = True
+
+             if self.get_mode(runner) != 'train' and key[:3] != 'val':
+                 ignore = True
+
+             if ignore:
+                 continue
+
+             if key[:5] == 'train':
+                 self.writer.add_scalar(key, value, self.get_iter(runner))
+             elif key[:3] == 'val':
+                 self.writer.add_scalar(key, value, self.get_epoch(runner))
+
+     @master_only
+     def after_run(self, runner):
+         self.writer.close()
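The ETA logic in `MyTextLoggerHook._log_info` above averages the per-iteration time accumulated since `start_iter` (skipping the warm-up first iteration) and extrapolates over the remaining iterations. The arithmetic, isolated as a sketch with made-up numbers:

```python
import datetime

def eta_string(time_sec_tot, start_iter, cur_iter, max_iters):
    # Average seconds per iteration since start_iter (the first
    # iteration is excluded because it includes warm-up cost).
    time_sec_avg = time_sec_tot / (cur_iter - start_iter)
    # Extrapolate over the iterations that remain.
    eta_sec = time_sec_avg * (max_iters - cur_iter - 1)
    return str(datetime.timedelta(seconds=int(eta_sec)))

# 100 iterations logged so far at 0.5 s each, 10,100 iterations total
print(eta_string(50.0, 0, 100, 10100))  # 0.5 * 9999 s -> '1:23:19'
```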
val.py ADDED
@@ -0,0 +1,137 @@
+ import os
+ import utils
+ import logging
+ import argparse
+ import importlib
+ import torch
+ import torch.distributed as dist
+ import torch.backends.cudnn as cudnn
+ from mmcv import Config
+ from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+ from mmcv.runner import load_checkpoint
+ from mmdet.apis import set_random_seed, multi_gpu_test, single_gpu_test
+ from mmdet3d.datasets import build_dataset, build_dataloader
+ from mmdet3d.models import build_model
+
+
+ def evaluate(dataset, results, epoch):
+     metrics = dataset.evaluate(results, jsonfile_prefix=None)
+
+     mAP = metrics['pts_bbox_NuScenes/mAP']
+     mATE = metrics['pts_bbox_NuScenes/mATE']
+     mASE = metrics['pts_bbox_NuScenes/mASE']
+     mAOE = metrics['pts_bbox_NuScenes/mAOE']
+     mAVE = metrics['pts_bbox_NuScenes/mAVE']
+     mAAE = metrics['pts_bbox_NuScenes/mAAE']
+     NDS = metrics['pts_bbox_NuScenes/NDS']
+
+     logging.info('--- Evaluation Results (Epoch %d) ---' % epoch)
+     logging.info('mAP: %.4f' % mAP)
+     logging.info('mATE: %.4f' % mATE)
+     logging.info('mASE: %.4f' % mASE)
+     logging.info('mAOE: %.4f' % mAOE)
+     logging.info('mAVE: %.4f' % mAVE)
+     logging.info('mAAE: %.4f' % mAAE)
+     logging.info('NDS: %.4f' % NDS)
+
+     return {
+         'mAP': mAP,
+         'mATE': mATE,
+         'mASE': mASE,
+         'mAOE': mAOE,
+         'mAVE': mAVE,
+         'mAAE': mAAE,
+         'NDS': NDS,
+     }
+
+
+ def main():
+     parser = argparse.ArgumentParser(description='Validate a detector')
+     parser.add_argument('--config', required=True)
+     parser.add_argument('--weights', required=True)
+     parser.add_argument('--local_rank', type=int, default=0)
+     parser.add_argument('--world_size', type=int, default=1)
+     parser.add_argument('--batch_size', type=int, default=1)
+     args = parser.parse_args()
+
+     # parse configs
+     cfgs = Config.fromfile(args.config)
+
+     # register custom modules
+     importlib.import_module('models')
+     importlib.import_module('loaders')
+
+     # silence mmcv's default loggers
+     from mmcv.utils.logging import logger_initialized
+     logger_initialized['root'] = logging.Logger(__name__, logging.WARNING)
+     logger_initialized['mmcv'] = logging.Logger(__name__, logging.WARNING)
+
+     # a GPU is required
+     assert torch.cuda.is_available()
+
+     # determine local_rank and world_size
+     if 'LOCAL_RANK' not in os.environ:
+         os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+     if 'WORLD_SIZE' not in os.environ:
+         os.environ['WORLD_SIZE'] = str(args.world_size)
+
+     local_rank = int(os.environ['LOCAL_RANK'])
+     world_size = int(os.environ['WORLD_SIZE'])
+
+     if local_rank == 0:
+         utils.init_logging(None, cfgs.debug)
+     else:
+         logging.root.disabled = True
+
+     logging.info('Using GPU: %s' % torch.cuda.get_device_name(local_rank))
+     torch.cuda.set_device(local_rank)
+
+     if world_size > 1:
+         logging.info('Initializing DDP with %d GPUs...' % world_size)
+         dist.init_process_group('nccl', init_method='env://')
+
+     logging.info('Setting random seed: 0')
+     set_random_seed(0, deterministic=True)
+     cudnn.benchmark = True
+
+     logging.info('Loading validation set from %s' % cfgs.data.val.data_root)
+     val_dataset = build_dataset(cfgs.data.val)
+     val_loader = build_dataloader(
+         val_dataset,
+         samples_per_gpu=args.batch_size,
+         workers_per_gpu=cfgs.data.workers_per_gpu,
+         num_gpus=world_size,
+         dist=world_size > 1,
+         shuffle=False,
+         seed=0,
+     )
+
+     logging.info('Creating model: %s' % cfgs.model.type)
+     model = build_model(cfgs.model)
+     model.cuda()
+
+     if world_size > 1:
+         model = MMDistributedDataParallel(model, [local_rank], broadcast_buffers=False)
+     else:
+         model = MMDataParallel(model, [0])
+
+     if os.path.isfile(args.weights):
+         logging.info('Loading checkpoint from %s' % args.weights)
+         load_checkpoint(
+             model, args.weights, map_location='cuda', strict=True,
+             logger=logging.Logger(__name__, logging.ERROR)
+         )
+
+     if world_size > 1:
+         results = multi_gpu_test(model, val_loader, gpu_collect=True)
+     else:
+         results = single_gpu_test(model, val_loader)
+
+     if local_rank == 0:
+         evaluate(val_dataset, results, -1)
+
+
+ if __name__ == '__main__':
+     main()
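For reference, the NDS value that `evaluate` logs is computed inside the nuScenes devkit as a weighted combination of mAP and the five true-positive error metrics (mATE, mASE, mAOE, mAVE, mAAE), each clipped to 1 and inverted. A sketch of that formula with made-up metric values:

```python
def nds(mAP, tp_errors):
    # nuScenes Detection Score:
    # NDS = (5 * mAP + sum over the 5 TP metrics of (1 - min(1, err))) / 10
    return (5 * mAP + sum(1 - min(1.0, e) for e in tp_errors)) / 10

# hypothetical metrics: mATE, mASE, mAOE, mAVE, mAAE
score = nds(0.448, [0.581, 0.271, 0.373, 0.247, 0.190])
print(round(score, 3))  # 0.558
```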